AdarshRajDS commited on
Commit
511fc83
·
1 Parent(s): fdcce44

Add model file with Git and updated app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -42
app.py CHANGED
@@ -1,68 +1,90 @@
1
  from fastapi import FastAPI, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
- import torch
5
- import io
6
  from pathlib import Path
7
  from torchvision import transforms
8
 
9
  from model import MultiTaskResNet50
10
- from decision import final_decision
 
 
 
11
 
12
- app = FastAPI(
13
- title="Mold Detection API",
14
- description="FastAPI backend for mold detection using multi-task ResNet50",
15
- version="1.0.0"
16
- )
17
 
18
- # Add CORS middleware for frontend
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=["*"], # In production, replace with specific frontend URL
22
- allow_credentials=True,
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
28
 
29
- # Model path for HuggingFace Spaces (flat structure)
30
- model_path = Path("resnet50_multitask_mold.pth")
31
-
32
- print(f"Loading model from: {model_path.absolute()}")
33
- print(f"Model exists: {model_path.exists()}")
34
-
35
- model = MultiTaskResNet50()
36
- model.load_state_dict(torch.load(str(model_path), map_location=device))
37
- model.eval().to(device)
38
- print("✅ Model loaded successfully")
39
 
 
40
  transform = transforms.Compose([
41
  transforms.Resize((224,224)),
42
  transforms.ToTensor(),
43
- transforms.Normalize(
44
- mean=[0.485,0.456,0.406],
45
- std=[0.229,0.224,0.225]
46
- )
47
  ])
48
 
49
- @app.get("/")
50
- async def root():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  return {
52
- "status": "healthy",
53
- "message": "Mold Detection API is running",
54
- "endpoint": "/predict",
55
- "method": "POST",
56
- "docs": "/docs"
 
 
 
 
 
57
  }
58
 
59
- @app.get("/health")
60
- async def health():
61
- return {"status": "healthy"}
62
-
63
- @app.post("/predict")
64
- async def predict(file: UploadFile):
65
- image_bytes = await file.read()
66
- img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
67
- img_tensor = transform(img).to(device)
68
- return final_decision(model, img_tensor)
 
1
  from fastapi import FastAPI, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
+ import torch, io
 
5
  from pathlib import Path
6
  from torchvision import transforms
7
 
8
  from model import MultiTaskResNet50
9
+ from decision import final_decision # v1
10
+ from advanced_decision import *
11
+ from gradcam import GradCAM
12
+ from dino import *
13
 
14
+ app = FastAPI(title="Mold Detection API v2")
 
 
 
 
15
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
 
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ mold_idx = 4
25
 
26
+ # Load model
27
+ model = MultiTaskResNet50().to(device)
28
+ model.load_state_dict(torch.load("resnet50_multitask_bio.pth", map_location=device))
29
+ model.eval()
 
 
 
 
 
 
30
 
31
+ # Transforms
32
  transform = transforms.Compose([
33
  transforms.Resize((224,224)),
34
  transforms.ToTensor(),
35
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
 
 
 
36
  ])
37
 
38
+ # Grad-CAM
39
+ gradcam = GradCAM(model, model.backbone.layer4[-1].conv3)
40
+
41
+ # DINO
42
+ dino = load_dino(device)
43
+ mold_embs = build_embeddings(dino, transform, "mold_reference_images", device)
44
+
45
+ @app.post("/predict/v1")
46
+ async def predict_v1(file: UploadFile):
47
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
48
+ img_t = transform(img).to(device)
49
+ return final_decision(model, img_t)
50
+
51
+ @app.post("/predict/v2")
52
+ async def predict_v2(file: UploadFile):
53
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
54
+ img_t = transform(img).to(device)
55
+
56
+ with torch.no_grad():
57
+ out = model(img_t.unsqueeze(0))
58
+ cp = torch.softmax(out["class"],1)[0]
59
+ bp = torch.softmax(out["bio"],1)[0]
60
+
61
+ mold_p = cp[mold_idx].item()
62
+ bio_p = bp[1].item()
63
+
64
+ mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
65
+ patch_ratio = patch_consistency(model, img, transform, mold_idx, device)
66
+ dino_sim = similarity(dino, mold_embs, img, transform, device)
67
+
68
+ decision = final_decision_v2(
69
+ mold_p, bio_p, std_p, patch_ratio, dino_sim
70
+ )
71
+
72
  return {
73
+ "decision": decision,
74
+ "model_outputs": {
75
+ "mold_probability": round(mold_p,3),
76
+ "biological_probability": round(bio_p,3)
77
+ },
78
+ "confidence_checks": {
79
+ "uncertainty": round(std_p,3),
80
+ "patch_ratio": round(patch_ratio,3),
81
+ "dino_similarity": round(dino_sim,3)
82
+ }
83
  }
84
 
85
+ @app.post("/explain/gradcam")
86
+ async def explain_gradcam(file: UploadFile):
87
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
88
+ img_t = transform(img).to(device)
89
+ cam = gradcam.generate(img_t, mold_idx)
90
+ return {"gradcam": cam.tolist()}