AdarshRajDS commited on
Commit
bef3d34
·
1 Parent(s): 7a5f7fb

Fix Resnet

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py CHANGED
@@ -25,6 +25,26 @@ app.add_middleware(
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # ------------------
29
  # Load main model (ConvNeXt)
30
  # ------------------
@@ -150,3 +170,26 @@ async def explain_gradcam(file: UploadFile):
150
  img_t = transform(img).to(device)
151
  cam = gradcam.generate(img_t, mold_idx)
152
  return {"gradcam": cam.tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
28
+ # ------------------
29
+ # Load baseline model (ResNet)
30
+ # ------------------
31
+ resnet_ckpt = torch.load(
32
+ "best_resnet_multitask.pth",
33
+ map_location=device
34
+ )
35
+
36
+ resnet_classes = resnet_ckpt.get("classes") or []
37
+ resnet_num_classes = len(resnet_classes) if resnet_classes else 9
38
+ resnet_mold_idx = (
39
+ resnet_classes.index("mold")
40
+ if resnet_classes else 4
41
+ )
42
+
43
+ resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
44
+ resnet_model.load_state_dict(resnet_ckpt["model"])
45
+ resnet_model.eval()
46
+
47
+
48
  # ------------------
49
  # Load main model (ConvNeXt)
50
  # ------------------
 
170
  img_t = transform(img).to(device)
171
  cam = gradcam.generate(img_t, mold_idx)
172
  return {"gradcam": cam.tolist()}
173
+
174
+
175
+ @app.post("/predict/resnet")
176
+ async def predict_resnet(file: UploadFile):
177
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
178
+ img_t = transform(img).to(device)
179
+
180
+ with torch.no_grad():
181
+ out = resnet_model(img_t.unsqueeze(0))
182
+ cp = torch.softmax(out["class"], 1)[0]
183
+ bp = torch.softmax(out["bio"], 1)[0]
184
+
185
+ mold_p = cp[resnet_mold_idx].item()
186
+ bio_p = bp[1].item()
187
+
188
+ decision = final_decision(mold_p, bio_p)
189
+
190
+ return {
191
+ "decision": decision,
192
+ "mold_probability": round(mold_p, 3),
193
+ "biological_probability": round(bio_p, 3),
194
+ }
195
+