File size: 739 Bytes
cac5c9d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | from PIL import Image
import torch
from torchvision import transforms
from model import load_model
# Load every ensemble member once, at import time, so that every predict()
# call reuses the same weights.
model1 = load_model("m1.safetensors")
model2 = load_model("m2.safetensors")
model3 = load_model("m3.safetensors")
models = [model1, model2, model3]

# torch.no_grad() (used in predict) only disables gradient tracking; it does
# not switch layers such as Dropout/BatchNorm to inference behaviour. Put
# every member in eval mode explicitly. This is a no-op if load_model already
# did so.
# NOTE(review): assumes load_model returns a torch.nn.Module — confirm.
for _member in models:
    _member.eval()
del _member
# Preprocessing pipeline shared by every call to predict():
#   1. Resize to exactly 224x224. A (h, w) tuple means the aspect ratio is
#      not preserved.
#   2. Convert the PIL image to a float tensor. For the RGB images that
#      predict() passes in, this gives a C x H x W tensor scaled to [0, 1].
# NOTE(review): there is no Normalize step — confirm this matches the
# preprocessing used when m1/m2/m3 were trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
def predict(image):
    """Return the ensemble's averaged sigmoid probability for one image.

    Args:
        image: Anything ``PIL.Image.open`` accepts, e.g. a filesystem path or
            a binary file-like object.

    Returns:
        float: The unweighted mean of ``torch.sigmoid(model(x)).item()`` over
        every model in the module-level ``models`` list.
    """
    # Decode inside a context manager so the file handle PIL opened is
    # released right away instead of leaking until garbage collection.
    # convert() fully loads the pixels into a new image, so using ``img``
    # after the file is closed is safe.
    with Image.open(image) as src:
        img = src.convert("RGB")

    # Prepend a batch dimension of size 1.
    x = transform(img).unsqueeze(0)

    with torch.no_grad():
        # .item() requires each model to return a single-element tensor.
        # NOTE(review): presumably one logit per image — confirm for m1/m2/m3.
        probs = [torch.sigmoid(model(x)).item() for model in models]

    # Ensemble by simple (unweighted) averaging of per-model probabilities.
    return sum(probs) / len(probs)