amitabh3 committed on
Commit
96d978b
·
verified ·
1 Parent(s): c206e8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -101
app.py CHANGED
@@ -1,123 +1,73 @@
1
- # app.py ─── HuggingFace Space backend (Python 3.10 + FastAPI + Torch 2.x)
2
- # ──────────────────────────────────────────────────────────────────────────
3
- import io, cv2, numpy as np, torch
4
  from PIL import Image
5
  from fastapi import FastAPI, UploadFile, File
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from starlette.responses import StreamingResponse, PlainTextResponse
 
8
 
9
# ------------------------------------------------------------------------
# 1️⃣ MODEL + TRAINING PARAMS
# ------------------------------------------------------------------------
CKPT_PATH = "models/best_model.pth" # tracked via Git‑LFS
ARCH_NAME = "NestedUNet"
INPUT_H = INPUT_W = 512 # resize used during training
DEEP_SUP = False  # deep supervision off → model returns a single output head
THRESH_PRIMARY = 0.50 # first try
THRESH_FALLBACK= 0.25 # if mask is empty
ALPHA = 0.60 # overlay transparency (0‑1)

# ------------------------------------------------------------------------
# 2️⃣ ALLOWED FRONT‑ENDS (GitHub Pages, localhost for dev, etc.)
# ------------------------------------------------------------------------
FRONTENDS = [
    "https://amitabhm1.github.io", # ← your GitHub‑Pages root
    "http://localhost:4173", # vite / React dev server
    "http://127.0.0.1:5500" # plain HTML preview
]

# ------------------------------------------------------------------------
# 3️⃣ FASTAPI APP + CORS
# ------------------------------------------------------------------------
app = FastAPI(title="Polyp Segmentation API", version="1.0")

# Browsers are only allowed to call this API from the origins listed above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)
40
 
41
# ------------------------------------------------------------------------
# 4️⃣ LOAD NETWORK ONCE AT STARTUP
# ------------------------------------------------------------------------
from backend_src import archs
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): assumes the archs constructor signature is
# (num_classes, in_channels, deep_supervision) — confirm against backend_src.archs
model = archs.__dict__[ARCH_NAME](1, 3, DEEP_SUP).to(device)
state = torch.load(CKPT_PATH, map_location=device)
state = {k.replace("module.", ""): v for k, v in state.items()} # DDP → single‑GPU
# strict=False tolerates missing/unexpected keys left over from training wrappers
model.load_state_dict(state, strict=False)
model.eval()

print(f"✓ model loaded ({device}) – params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# ------------------------------------------------------------------------
# 5️⃣ PRE‑ / POST‑PROCESS HELPERS
# ------------------------------------------------------------------------
import albumentations as A
# single resize transform matching the training input size
_resize = A.Compose([A.Resize(INPUT_H, INPUT_W, interpolation=cv2.INTER_LINEAR)])
 
61
def preprocess(pil: Image.Image) -> np.ndarray:
    """Turn a PIL image into a normalized 1×3×H×W float32 tensor on `device`."""
    rgb = np.array(pil.convert("RGB"))
    scaled = _resize(image=rgb)["image"].astype("float32") / 255.0
    chw = scaled.transpose(2, 0, 1)  # HWC → CHW for torch
    return torch.from_numpy(chw).unsqueeze(0).to(device)
67
 
68
def make_mask(prob: np.ndarray) -> np.ndarray:
    """Binarize `prob` at THRESH_PRIMARY; if nothing survives,
    retry once at the looser THRESH_FALLBACK."""
    primary = (prob >= THRESH_PRIMARY).astype("uint8")
    if primary.sum() > 0:
        return primary
    return (prob >= THRESH_FALLBACK).astype("uint8")
74
 
75
def overlay_contour(img: np.ndarray, mask: np.ndarray,
                    color=(0,0,255), thickness=2) -> np.ndarray:
    """Draw red contours over the original image"""
    # NOTE(review): the caller passes a PIL-derived RGB array, where (0,0,255)
    # is blue, not red — (0,0,255) is red only in cv2's BGR convention.
    # Confirm the intended color before changing the default.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = img.copy()
    cv2.drawContours(out, contours, -1, color, thickness)
    return out
82
 
83
- # ------------------------------------------------------------------------
84
- # 6️⃣ HEALTH‑CHECK
85
- # ------------------------------------------------------------------------
86
@app.get("/ping", summary="Liveness/ready probe")
def ping():
    """Health-check endpoint — always answers 'pong' as plain text."""
    response = PlainTextResponse("pong")
    return response
89
-
90
- # ------------------------------------------------------------------------
91
- # 7️⃣ MAIN INFERENCE ENDPOINT
92
- # ------------------------------------------------------------------------
93
@app.post("/segment", summary="Upload an image, get PNG with mask overlay")
async def segment(file: UploadFile = File(...)):
    """Run the model on an uploaded image and stream back a PNG overlay.

    On any failure a JSON body {"error": ...} is returned instead of raising,
    so the JS client can always parse the response.
    """
    try:
        raw_bytes = await file.read()
        pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        inp = preprocess(pil_img)

        # forward pass without autograd bookkeeping
        with torch.inference_mode():
            out = model(inp)
            if DEEP_SUP and isinstance(out, (list, tuple)):
                # deep supervision yields one output per decoder stage; keep the last
                out = out[-1]

        # single-channel probability map at the model's input resolution
        prob = torch.sigmoid(out)[0,0].cpu().numpy()
        mask = make_mask(prob)

        # ---- overlay ----
        # resize the original to the mask's resolution before drawing contours
        resized = pil_img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
        overlay = overlay_contour(np.array(resized), mask)

        buf = io.BytesIO()
        Image.fromarray(overlay).save(buf, format="PNG")
        buf.seek(0)

        # ● optional debugging in Space logs
        print(f"foreground px: {mask.sum():,} – file: {file.filename}")

        return StreamingResponse(buf, media_type="image/png")

    except Exception as e:
        # keep same structure so JS can always parse JSON on error
        return {"error": str(e)}
 
1
+ # app.py (debug-friendly overlay)
2
+ import io, cv2, numpy as np, torch, albumentations as A
 
3
  from PIL import Image
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from starlette.responses import StreamingResponse, PlainTextResponse
7
+ from backend_src import archs
8
 
9
# ─────────── config ────────────
CKPT = "models/best_model.pth"  # checkpoint weights (tracked via Git-LFS)
ARCH = "NestedUNet"             # key into the backend_src.archs registry
INP_H = INP_W = 512             # resize used during training
DEEP_SUP = False                # deep supervision off → model returns a single output
TH_MAIN, TH_FALL = .50, .25     # primary threshold; fallback if the mask comes out empty
ALPHA = .60                     # heatmap overlay transparency (0-1)
FRONTENDS = ["https://amitabhm1.github.io","http://localhost:4173"]  # CORS allow-list
# ───────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
 
19
# Load the network once at startup; weights stay in memory for all requests.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): assumes the archs constructor signature is
# (num_classes, in_channels, deep_supervision) — confirm against backend_src.archs
model = archs.__dict__[ARCH](1,3,DEEP_SUP).to(device)
state = torch.load(CKPT, map_location=device)
# strip DistributedDataParallel's "module." prefix so keys match a single-GPU model
state = {k.removeprefix("module."):v for k,v in state.items()}
model.load_state_dict(state, strict=False)
model.eval()
print("✓ model loaded on", device)

# resize transform matching the training input size
resize = A.Compose([A.Resize(INP_H, INP_W)])
 
 
 
 
 
 
28
 
29
def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a 1×3×H×W float tensor on `device`.

    Resizes the uint8 RGB array FIRST, then scales to [0, 1] — the same
    order the original preprocessing used, so interpolation operates on
    the same data the model saw during training (the previous version
    normalized before resizing).
    """
    arr = np.array(pil.convert("RGB"))                         # H×W×3 uint8
    arr = resize(image=arr)["image"].astype("float32") / 255.0  # resize, then normalize
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
 
 
33
 
34
def mask_from_prob(prob: np.ndarray) -> np.ndarray:
    """Threshold `prob` at TH_MAIN; if the result is empty, retry at TH_FALL."""
    primary = (prob >= TH_MAIN).astype("uint8")
    if primary.sum():
        return primary
    return (prob >= TH_FALL).astype("uint8")
 
 
 
37
 
38
def overlay(img, np_mask):
    """Draw the mask outline in red on an RGB frame.

    `img` comes from PIL and is in RGB channel order, so red is
    (255, 0, 0) — the previous (0, 0, 255) is cv2's BGR red and
    actually rendered blue, contradicting the "red contour" intent.
    """
    cnts, _ = cv2.findContours(np_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = img.copy()
    cv2.drawContours(out, cnts, -1, (255, 0, 0), 2)  # red in RGB order
    return out
43
 
44
def heatmap(img, prob):
    """Blend a JET colormap of `prob` over an RGB frame.

    cv2.applyColorMap returns BGR while `img` is RGB (from PIL), so the
    heatmap must be converted to RGB before blending — otherwise hot and
    cold colors come out swapped in the PNG the client sees.
    """
    hm = cv2.applyColorMap((prob * 255).astype("uint8"), cv2.COLORMAP_JET)
    hm = cv2.cvtColor(hm, cv2.COLOR_BGR2RGB)  # match the frame's channel order
    return cv2.addWeighted(img, 1 - ALPHA, hm, ALPHA, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
# FastAPI app + CORS restricted to the known frontends in FRONTENDS.
app = FastAPI()
app.add_middleware(CORSMiddleware,
    allow_origins=FRONTENDS, allow_methods=["POST","GET","OPTIONS"], allow_headers=["*"])
51
 
52
@app.get("/ping")
def ping():
    """Liveness probe — always answers 'pong' as plain text."""
    return PlainTextResponse("pong")
 
54
 
55
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment polyps in an uploaded image and stream back a PNG.

    Returns a red-contour overlay when the mask is non-empty, otherwise a
    JET probability heatmap (debug aid for all-background predictions).
    On any failure a JSON body {"error": ...} is returned, so the JS
    client can always parse the response — restoring the error contract
    the previous version of this endpoint provided.
    """
    try:
        raw = await file.read()
        # force 3-channel RGB: RGBA or grayscale uploads would otherwise
        # produce a frame whose channel count breaks overlay()/heatmap()
        pil = Image.open(io.BytesIO(raw)).convert("RGB")
        inp = to_tensor(pil)

        # forward pass without autograd bookkeeping
        with torch.inference_mode():
            out = model(inp)
            if DEEP_SUP and isinstance(out, (list, tuple)):
                out = out[-1]  # deep supervision: keep the final stage
            prob = torch.sigmoid(out)[0, 0].cpu().numpy()

        mmax, mmin = prob.max(), prob.min()
        mask = mask_from_prob(prob)
        # debug line visible in the Space logs
        print(f"{file.filename}: min={mmin:.3f} max={mmax:.3f} fg={mask.sum()}")

        # resize the original to the mask's resolution before visualizing
        frame = np.array(pil.resize((INP_W, INP_H), Image.BILINEAR))
        vis = overlay(frame, mask) if mask.sum() else heatmap(frame, prob)

        buf = io.BytesIO()
        Image.fromarray(vis).save(buf, format="PNG")
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")

    except Exception as e:
        # JSON error body so the frontend can always parse the response
        return {"error": str(e)}