amitabh3 commited on
Commit
c206e8b
·
verified ·
1 Parent(s): 9d4f263

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -67
app.py CHANGED
@@ -1,84 +1,123 @@
1
- # ---------- app.py (backend root of your HF Space) -----------------
2
- import io, cv2, torch, albumentations as A, numpy as np
 
3
  from PIL import Image
4
- from fastapi import FastAPI, File, UploadFile
5
  from fastapi.middleware.cors import CORSMiddleware
6
- from starlette.responses import StreamingResponse, JSONResponse
7
-
8
- # -------------------------------------------------------------------
9
- CKPT = "models/best_model.pth"
10
- ARCH_NAME = "NestedUNet"
11
- INPUT_H = INPUT_W = 512
12
- THRESH = 0.50
13
- ALPHA = 0.40
14
- DEEP_SUP = False
15
- FRONT_ORIGIN = "https://amitabhm1.github.io"
16
- # -------------------------------------------------------------------
17
-
18
- # ---------- import model definition from backend_src ----------------
19
- from backend_src import archs
20
- # --------------------------------------------------------------------
21
-
22
- app = FastAPI(title="Polyp‑Seg‑API")
 
 
 
 
 
 
 
 
 
23
 
24
  app.add_middleware(
25
  CORSMiddleware,
26
- allow_origins=[FRONT_ORIGIN, "http://localhost:4173"],
27
- allow_methods=["POST"],
28
  allow_headers=["*"],
29
  )
30
 
31
- # ---------------- load checkpoint (fail fast) -----------------------
32
- if not torch.cuda.is_available():
33
- print("⚠ CUDA not available – running on CPU")
34
-
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
- net = archs.__dict__[ARCH_NAME](1, 3, DEEP_SUP).to(device)
37
-
38
- try:
39
- state = torch.load(CKPT, map_location=device)
40
- except FileNotFoundError as e:
41
- raise SystemExit(f"❌ Checkpoint {CKPT} not found in the Space") from e
42
 
43
- state = {k.replace("module.", ""): v for k, v in state.items()}
44
- missing, unexpected = net.load_state_dict(state, strict=False)
45
- if missing:
46
- raise SystemExit(f"❌ Checkpoint is missing keys: {missing[:5]} …")
 
47
 
48
- net.eval()
49
- print("✓ model loaded on", device)
50
 
51
- # ---------------- preprocessing -------------------------------------
52
- xform = A.Compose([A.Resize(INPUT_H, INPUT_W)])
 
 
 
53
 
54
- def predict_overlay(pil: Image.Image) -> bytes:
 
55
  img = np.array(pil.convert("RGB"))
56
- img_rz = xform(image=img)["image"]
57
- ten = torch.from_numpy(img_rz.astype("float32")/255).permute(2,0,1)
58
- ten = ten.unsqueeze(0).to(device)
59
-
60
- with torch.no_grad():
61
- logits = net(ten)
62
- if DEEP_SUP and isinstance(logits,(list,tuple)): logits = logits[-1]
63
- mask = (torch.sigmoid(logits)[0,0].cpu().numpy() > THRESH).astype("uint8")
64
-
65
- # colourise mask in red
66
- red = np.zeros_like(img_rz); red[:] = (0,0,255)
67
- blend = (img_rz*(1-ALPHA) + red*ALPHA).astype("uint8")
68
- out = img_rz.copy()
69
- out[mask==1] = blend[mask==1]
70
-
71
- buf = io.BytesIO()
72
- Image.fromarray(out).save(buf, format="PNG")
73
- buf.seek(0); return buf
74
-
75
- # ---------------- endpoint -----------------------------------------
76
- @app.post("/segment")
 
 
 
 
 
 
 
 
 
77
  async def segment(file: UploadFile = File(...)):
78
  try:
79
- pil = Image.open(io.BytesIO(await file.read()))
80
- except Exception:
81
- return JSONResponse({"error":"invalid image"}, status_code=400)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- png = predict_overlay(pil)
84
- return StreamingResponse(png, media_type="image/png")
 
 
1
# app.py ─── Hugging-Face Space backend (Python 3.10 + FastAPI + Torch 2.x)
# ──────────────────────────────────────────────────────────────────────────
import io

import cv2
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse

# ------------------------------------------------------------------------
# 1️⃣ MODEL + TRAINING PARAMS
# ------------------------------------------------------------------------
CKPT_PATH = "models/best_model.pth"   # checkpoint, tracked via Git-LFS
ARCH_NAME = "NestedUNet"              # key looked up in backend_src.archs
INPUT_H = INPUT_W = 512               # resize used during training
DEEP_SUP = False                      # deep supervision off at inference
THRESH_PRIMARY = 0.50                 # probability threshold, first try
THRESH_FALLBACK = 0.25                # retry threshold if mask comes out empty
ALPHA = 0.60                          # overlay transparency (0-1)
                                      # NOTE(review): ALPHA is unreferenced in
                                      # this file since the overlay switched to
                                      # contour drawing — confirm before removal
20
# ------------------------------------------------------------------------
# 2️⃣ ALLOWED FRONT-ENDS (GitHub Pages, localhost for dev, etc.)
# ------------------------------------------------------------------------
FRONTENDS = [
    "https://amitabhm1.github.io",   # ← your GitHub-Pages root
    "http://localhost:4173",         # vite / React dev server
    "http://127.0.0.1:5500",         # plain HTML preview
]

# ------------------------------------------------------------------------
# 3️⃣ FASTAPI APP + CORS
# ------------------------------------------------------------------------
# Only the whitelisted origins may call the API from a browser.
app = FastAPI(title="Polyp Segmentation API", version="1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)
40
 
41
# ------------------------------------------------------------------------
# 4️⃣ LOAD NETWORK ONCE AT STARTUP (fail fast on a bad checkpoint)
# ------------------------------------------------------------------------
from backend_src import archs

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("⚠ CUDA not available – running on CPU")

# Build the architecture; assumes signature (num_classes=1, in_channels=3,
# deep_supervision) — matches the prior usage in this repo, TODO confirm.
model = archs.__dict__[ARCH_NAME](1, 3, DEEP_SUP).to(device)

# Abort at startup with a clear message instead of a 500 on first request.
try:
    state = torch.load(CKPT_PATH, map_location=device)
except FileNotFoundError as e:
    raise SystemExit(f"❌ Checkpoint {CKPT_PATH} not found in the Space") from e

state = {k.replace("module.", ""): v for k, v in state.items()}  # DDP → single-GPU
# strict=False tolerates extra keys, but missing weights mean the wrong
# checkpoint was shipped — fail fast rather than serve garbage masks.
missing, unexpected = model.load_state_dict(state, strict=False)
if missing:
    raise SystemExit(f"❌ Checkpoint is missing keys: {missing[:5]} …")

model.eval()
print(f"✓ model loaded ({device}) – params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
 
54
 
55
# ------------------------------------------------------------------------
# 5️⃣ PRE- / POST-PROCESS HELPERS
# ------------------------------------------------------------------------
import albumentations as A

# Same deterministic resize the network saw during training.
_resize = A.Compose([A.Resize(INPUT_H, INPUT_W, interpolation=cv2.INTER_LINEAR)])

def preprocess(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a model-ready tensor.

    Returns a float32 tensor of shape 1×3×INPUT_H×INPUT_W, scaled to [0, 1]
    and moved to *device*.  (Fixed: the previous annotation claimed
    ``np.ndarray`` although a ``torch.Tensor`` is returned.)
    """
    img = np.array(pil.convert("RGB"))
    img = _resize(image=img)["image"].astype("float32") / 255.0
    ten = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
    return ten
67
+
68
def make_mask(prob: np.ndarray) -> np.ndarray:
    """Binarize a probability map into a uint8 mask.

    Thresholds at THRESH_PRIMARY first; if that yields an all-zero mask,
    retries at the laxer THRESH_FALLBACK so low-confidence polyps are not
    dropped entirely.  May still return an empty mask if nothing clears
    the fallback threshold.
    """
    mask = np.zeros_like(prob, dtype="uint8")
    for thr in (THRESH_PRIMARY, THRESH_FALLBACK):
        mask = (prob >= thr).astype("uint8")
        if mask.any():
            break
    return mask
74
+
75
def overlay_contour(img: np.ndarray, mask: np.ndarray,
                    color=(255, 0, 0), thickness=2) -> np.ndarray:
    """Draw red contours around the mask regions and return a copy of *img*.

    Parameters
    ----------
    img : H×W×3 uint8 array in **RGB** order (as produced by
        ``np.array(PIL.Image)`` — not OpenCV's BGR).
    mask : H×W uint8 binary mask (1 = foreground).
    color : contour colour in RGB order.  Fixed: the previous default
        ``(0, 0, 255)`` is red only under BGR; on these RGB arrays it
        rendered the contours *blue*.
    thickness : contour line width in pixels.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = img.copy()
    cv2.drawContours(out, contours, -1, color, thickness)
    return out
82
+
83
# ------------------------------------------------------------------------
# 6️⃣ HEALTH-CHECK
# ------------------------------------------------------------------------
@app.get("/ping", summary="Liveness/ready probe")
def ping():
    """Plain-text liveness probe; replies ``pong`` when the Space is up."""
    return PlainTextResponse("pong")
89
+
90
# ------------------------------------------------------------------------
# 7️⃣ MAIN INFERENCE ENDPOINT
# ------------------------------------------------------------------------
@app.post("/segment", summary="Upload an image, get PNG with mask overlay")
async def segment(file: UploadFile = File(...)):
    """Segment polyps in an uploaded image.

    On success returns a PNG: the input resized to INPUT_W×INPUT_H with the
    predicted mask outlined.  On failure returns ``{"error": ...}`` JSON —
    fixed: with a proper 4xx/5xx status instead of the previous bare dict,
    which FastAPI served as HTTP 200 so the JS client could not detect errors.
    """
    # Function-scope import keeps the module's import block unchanged.
    from starlette.responses import JSONResponse

    # ---- decode the upload: a bad file is the client's fault → 400 ------
    try:
        raw_bytes = await file.read()
        pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    except Exception:
        return JSONResponse({"error": "invalid image"}, status_code=400)

    # ---- inference: anything failing here is on the server → 500 --------
    try:
        inp = preprocess(pil_img)

        with torch.inference_mode():
            out = model(inp)
            if DEEP_SUP and isinstance(out, (list, tuple)):
                out = out[-1]  # deep supervision: last head is the final mask

        prob = torch.sigmoid(out)[0, 0].cpu().numpy()
        mask = make_mask(prob)

        # ---- overlay contours on the network-input-sized original ----
        resized = pil_img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
        overlay = overlay_contour(np.array(resized), mask)

        buf = io.BytesIO()
        Image.fromarray(overlay).save(buf, format="PNG")
        buf.seek(0)

        # ● optional debugging in Space logs
        print(f"foreground px: {mask.sum():,} – file: {file.filename}")

        return StreamingResponse(buf, media_type="image/png")

    except Exception as e:
        # Keep the same JSON structure so the JS client can always parse it,
        # but signal failure via the status code.
        return JSONResponse({"error": str(e)}, status_code=500)