File size: 739 Bytes
cac5c9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from PIL import Image
import torch
from torchvision import transforms
from model import load_model

# Load every ensemble member once, at import time, so predict() reuses the
# same weights on every call.
model1 = load_model("m1.safetensors")
model2 = load_model("m2.safetensors")
model3 = load_model("m3.safetensors")

models = [model1, model2, model3]

# Put each member in inference mode so layers such as Dropout and BatchNorm
# behave deterministically. eval() is idempotent, so this is harmless if
# load_model() already did it.
# NOTE(review): assumes load_model() returns a torch.nn.Module — confirm.
for _member in models:
    _member.eval()
del _member

# Preprocessing applied to every input image: resize to a fixed 224x224
# resolution, then convert the PIL image into a float tensor.
# NOTE(review): there is no Normalize step here — confirm this matches the
# preprocessing the models were trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def predict(image):
    """Return the ensemble's averaged sigmoid probability for one image.

    Args:
        image: Anything ``PIL.Image.open`` accepts — a filesystem path or a
            binary file-like object.

    Returns:
        float: The unweighted mean of ``torch.sigmoid(model(x)).item()`` over
        every model in the module-level ``models`` list.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("predict() needs at least one model in `models`")

    # The context manager releases the file Pillow opened. Pillow only
    # auto-closes single-frame images after load(), so multi-frame inputs
    # (e.g. GIF/TIFF) would otherwise leak a file handle. convert() returns
    # an independent in-memory copy, so `img` stays valid after the block.
    with Image.open(image) as raw:
        img = raw.convert("RGB")

    # Add a leading batch dimension of size 1.
    x = transform(img).unsqueeze(0)

    with torch.no_grad():
        # .item() requires each model to emit a single-element tensor —
        # presumably one logit of a binary classifier; TODO confirm.
        probs = [torch.sigmoid(model(x)).item() for model in models]

    # Ensemble by averaging the per-model probabilities.
    return sum(probs) / len(probs)