Spaces:
Sleeping
Sleeping
AdarshRajDS commited on
Commit ·
bef3d34
1
Parent(s): 7a5f7fb
Fix Resnet
Browse files
app.py
CHANGED
|
@@ -25,6 +25,26 @@ app.add_middleware(
|
|
| 25 |
|
| 26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# ------------------
|
| 29 |
# Load main model (ConvNeXt)
|
| 30 |
# ------------------
|
|
@@ -150,3 +170,26 @@ async def explain_gradcam(file: UploadFile):
|
|
| 150 |
img_t = transform(img).to(device)
|
| 151 |
cam = gradcam.generate(img_t, mold_idx)
|
| 152 |
return {"gradcam": cam.tolist()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
|
| 28 |
+
# ------------------
|
| 29 |
+
# Load baseline model (ResNet)
|
| 30 |
+
# ------------------
|
| 31 |
+
resnet_ckpt = torch.load(
|
| 32 |
+
"best_resnet_multitask.pth",
|
| 33 |
+
map_location=device
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
resnet_classes = resnet_ckpt.get("classes") or []
|
| 37 |
+
resnet_num_classes = len(resnet_classes) if resnet_classes else 9
|
| 38 |
+
resnet_mold_idx = (
|
| 39 |
+
resnet_classes.index("mold")
|
| 40 |
+
if resnet_classes else 4
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
|
| 44 |
+
resnet_model.load_state_dict(resnet_ckpt["model"])
|
| 45 |
+
resnet_model.eval()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
# ------------------
|
| 49 |
# Load main model (ConvNeXt)
|
| 50 |
# ------------------
|
|
|
|
| 170 |
img_t = transform(img).to(device)
|
| 171 |
cam = gradcam.generate(img_t, mold_idx)
|
| 172 |
return {"gradcam": cam.tolist()}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@app.post("/predict/resnet")
|
| 176 |
+
async def predict_resnet(file: UploadFile):
|
| 177 |
+
img = Image.open(io.BytesIO(await file.read())).convert("RGB")
|
| 178 |
+
img_t = transform(img).to(device)
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
out = resnet_model(img_t.unsqueeze(0))
|
| 182 |
+
cp = torch.softmax(out["class"], 1)[0]
|
| 183 |
+
bp = torch.softmax(out["bio"], 1)[0]
|
| 184 |
+
|
| 185 |
+
mold_p = cp[resnet_mold_idx].item()
|
| 186 |
+
bio_p = bp[1].item()
|
| 187 |
+
|
| 188 |
+
decision = final_decision(mold_p, bio_p)
|
| 189 |
+
|
| 190 |
+
return {
|
| 191 |
+
"decision": decision,
|
| 192 |
+
"mold_probability": round(mold_p, 3),
|
| 193 |
+
"biological_probability": round(bio_p, 3),
|
| 194 |
+
}
|
| 195 |
+
|