Simma7 commited on
Commit
cac5c9d
·
verified ·
1 Parent(s): 0a6c6ed

Create interface.py

Browse files
Files changed (1) hide show
  1. interface.py +33 -0
interface.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from torchvision import transforms
4
+ from model import load_model
5
+
6
+ # 🔥 Load ALL models
7
+ model1 = load_model("m1.safetensors")
8
+ model2 = load_model("m2.safetensors")
9
+ model3 = load_model("m3.safetensors")
10
+
11
+ models = [model1, model2, model3]
12
+
13
+ transform = transforms.Compose([
14
+ transforms.Resize((224, 224)),
15
+ transforms.ToTensor()
16
+ ])
17
+
18
+ def predict(image):
19
+ img = Image.open(image).convert("RGB")
20
+ x = transform(img).unsqueeze(0)
21
+
22
+ probs = []
23
+
24
+ with torch.no_grad():
25
+ for model in models:
26
+ out = model(x)
27
+ prob = torch.sigmoid(out).item()
28
+ probs.append(prob)
29
+
30
+ # 🔥 Ensemble (average)
31
+ final_prob = sum(probs) / len(probs)
32
+
33
+ return final_prob