# image_model/interface.py
# Author: Simma7 — commit cac5c9d ("Create interface.py")
from PIL import Image
import torch
from torchvision import transforms
from model import load_model
# Load every ensemble member once, at import time, in checkpoint order.
# NOTE(review): nothing here switches the models to eval mode; confirm that
# load_model does, otherwise any dropout / batch-norm layers would make
# predictions non-deterministic.
models = [
    load_model(checkpoint)
    for checkpoint in ("m1.safetensors", "m2.safetensors", "m3.safetensors")
]
# Individual handles, bound to the same objects held in `models`.
model1, model2, model3 = models
# Preprocessing pipeline: resize to a fixed 224x224, then convert the PIL
# image to a float tensor (ToTensor yields a (C, H, W) tensor scaled to
# [0, 1] for 8-bit input).
# NOTE(review): there is no Normalize step — confirm this matches the
# preprocessing the checkpoints were trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
def predict(image):
    """Return the ensemble-averaged sigmoid probability for a single image.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a filesystem path or a
            binary file-like object), or an already-decoded
            ``PIL.Image.Image``.

    Returns:
        float: the unweighted mean, over every model in the module-level
        ``models`` list, of ``torch.sigmoid(model(x)).item()``.
    """
    if isinstance(image, Image.Image):
        # Already decoded (e.g. handed over by a UI framework); the original
        # code raised in Image.open for this input.
        img = image.convert("RGB")
    else:
        # Image.open keeps the underlying file open until the image is
        # closed; the context manager guarantees that happens.
        # convert() returns a new, fully loaded image, so `img` stays valid
        # after `src` is closed.
        with Image.open(image) as src:
            img = src.convert("RGB")

    # transform -> (C, H, W); unsqueeze adds a batch dimension of size 1.
    # NOTE(review): x stays on the CPU; if load_model places models on a GPU
    # this call will fail — confirm the device.
    x = transform(img).unsqueeze(0)

    probs = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for model in models:
            out = model(x)
            # .item() requires `out` to hold exactly one element, i.e. one
            # logit per image — NOTE(review): confirm the model output shape.
            probs.append(torch.sigmoid(out).item())

    # Ensemble: plain average of the per-model probabilities.
    return sum(probs) / len(probs)