"""Ensemble inference: average the sigmoid scores of several models on one image."""

from PIL import Image
import torch
from torchvision import transforms

from model import load_model

# Load ALL ensemble members once, at import time.
model1 = load_model("m1.safetensors")
model2 = load_model("m2.safetensors")
model3 = load_model("m3.safetensors")
models = [model1, model2, model3]


def _to_eval_mode(model_list):
    """Switch every torch module in *model_list* to evaluation mode.

    Dropout / batch-norm layers behave differently in training mode, which
    would make predictions nondeterministic. ``eval()`` is idempotent, so
    this is harmless if ``load_model`` already did it; objects that are not
    ``torch.nn.Module`` instances are left untouched.
    """
    for member in model_list:
        if isinstance(member, torch.nn.Module):
            member.eval()


_to_eval_mode(models)

# Preprocessing: fixed 224x224 resize, then PIL -> float tensor (C, H, W) in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image):
    """Return the ensemble probability for a single image.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a filesystem path or a
            binary file-like object).

    Returns:
        The arithmetic mean of ``sigmoid(model(x))`` over every model in
        ``models``, as a Python float.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("no models loaded for the ensemble")

    # Context manager releases the file handle Pillow opens for a path;
    # convert() loads the pixels and returns an independent copy, so `img`
    # remains usable after the source is closed.
    with Image.open(image) as src:
        img = src.convert("RGB")

    x = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

    probs = []
    with torch.no_grad():
        for model in models:
            out = model(x)
            # NOTE(review): .item() assumes each model emits exactly one
            # logit (binary classifier) -- confirm against model.load_model.
            probs.append(torch.sigmoid(out).item())

    # Ensemble by simple averaging.
    return sum(probs) / len(probs)