amitabh3 committed on
Commit
9d4f263
·
verified ·
1 Parent(s): 26dc955

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -61
app.py CHANGED
@@ -1,81 +1,84 @@
1
- # app.py  Hugging‑Face Space backend
2
  import io, cv2, torch, albumentations as A, numpy as np
3
  from PIL import Image
4
  from fastapi import FastAPI, File, UploadFile
5
  from fastapi.middleware.cors import CORSMiddleware
6
- from starlette.responses import StreamingResponse
7
 
8
- # ---------- model code lives inside backend_src -------------
9
- from backend_src import archs # <- you copied this
10
- # ------------------------------------------------------------
11
-
12
- ### CONFIG #### -------------------------------------------------
13
- CKPT = "models/best_model.pth" # uploaded via Git‑LFS
14
  ARCH_NAME = "NestedUNet"
15
- INPUT_H = INPUT_W = 512 # what you trained with
16
- THRESH = 0.5 # binarisation
 
17
  DEEP_SUP = False
18
- ALPHA = 0.40 # overlay transparency
19
- ALLOWED_FRONT = "https://amitabhm1.github.io" # GitHub‑Pages origin
20
- ################ -------------------------------------------------
 
 
 
21
 
22
- # ---------------- FastAPI (plus simple CORS) ------------------
23
- app = FastAPI(title="Polyp Segmentation API")
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
27
- allow_origins=[ALLOWED_FRONT, "http://localhost:4173", "http://127.0.0.1:5500"],
28
- allow_methods=["POST", "OPTIONS"],
29
  allow_headers=["*"],
30
  )
31
- # --------------- load model once at startup -------------------
 
 
 
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
- model = archs.__dict__[ARCH_NAME](num_classes=1,
34
- input_channels=3,
35
- deep_supervision=DEEP_SUP).to(device)
36
- state = torch.load(CKPT, map_location=device)
37
- # strip “module.” if checkpoint came from DDP
38
- state = {k.replace("module.", ""): v for k, v in state.items()}
39
- model.load_state_dict(state, strict=False)
40
- model.eval()
41
- print("✓ model loaded on", device)
42
-
43
- # ---------- transforms = pad/resize exactly as during train ----
44
- transform = A.Compose([
45
- A.Resize(height=INPUT_H, width=INPUT_W, interpolation=cv2.INTER_LINEAR),
46
- ])
47
-
48
- # ------------------------- helpers -----------------------------
49
- def infer_overlay(pil_img: Image.Image) -> bytes:
50
- orig = np.array(pil_img.convert("RGB"))
51
- trans = transform(image=orig)["image"]
52
- ten = torch.from_numpy(trans.astype("float32")/255.0).permute(2,0,1).unsqueeze(0).to(device)
 
 
 
53
 
54
  with torch.no_grad():
55
- out = model(ten)
56
- if DEEP_SUP and isinstance(out, (list,tuple)):
57
- out = out[-1]
58
-
59
- mask = (torch.sigmoid(out)[0,0].cpu().numpy() > THRESH).astype("uint8")
60
- # colour mask
61
- mask_rgb = np.zeros((*mask.shape,3), dtype="uint8")
62
- mask_rgb[mask==1] = (255, 0, 0) # red overlay
63
- # blend
64
- blended = (trans*(1-ALPHA) + mask_rgb*ALPHA).astype("uint8")
65
-
66
- # encode PNG
67
  buf = io.BytesIO()
68
- Image.fromarray(blended).save(buf, format="PNG")
69
- buf.seek(0)
70
- return buf.getvalue()
71
- # ---------------------------------------------------------------
72
 
73
- @app.post("/segment", summary="Upload an image, get overlay PNG")
 
74
  async def segment(file: UploadFile = File(...)):
75
  try:
76
- img_bytes = await file.read()
77
- pil = Image.open(io.BytesIO(img_bytes))
78
- png_bytes = infer_overlay(pil)
79
- return StreamingResponse(io.BytesIO(png_bytes), media_type="image/png")
80
- except Exception as e:
81
- return {"error": str(e)}
 
1
+ # ---------- app.py (backend root of your HF Space) -----------------
2
  import io, cv2, torch, albumentations as A, numpy as np
3
  from PIL import Image
4
  from fastapi import FastAPI, File, UploadFile
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from starlette.responses import StreamingResponse, JSONResponse
7
 
8
# -------------------------------------------------------------------
# Deployment configuration for this Space.
CKPT = "models/best_model.pth"          # checkpoint uploaded via Git-LFS
ARCH_NAME = "NestedUNet"                # key looked up in backend_src.archs
INPUT_H = INPUT_W = 512                 # network input size used at training time
THRESH = 0.50                           # sigmoid threshold for binarising the mask
ALPHA = 0.40                            # overlay blend weight for masked pixels
DEEP_SUP = False                        # deep supervision: model returns a list of outputs
FRONT_ORIGIN = "https://amitabhm1.github.io"  # GitHub-Pages frontend origin
# -------------------------------------------------------------------

# ---------- import model definition from backend_src ----------------
from backend_src import archs
# --------------------------------------------------------------------

app = FastAPI(title="Polyp‑Seg‑API")

# Allow the deployed frontend plus a local Vite dev server to call the API.
# Only POST is exposed; CORS pre-flight OPTIONS is handled by the middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[FRONT_ORIGIN, "http://localhost:4173"],
    allow_methods=["POST"],
    allow_headers=["*"],
)
30
+
31
# ---------------- load checkpoint (fail fast) -----------------------
# Warn early so the Space logs make a CPU-only deployment obvious.
if not torch.cuda.is_available():
    print("⚠ CUDA not available – running on CPU")

device = "cuda" if torch.cuda.is_available() else "cpu"
# Positional args mirror the arch constructor seen elsewhere in this repo:
# (num_classes=1, input_channels=3, deep_supervision=DEEP_SUP).
net = archs.__dict__[ARCH_NAME](1, 3, DEEP_SUP).to(device)

try:
    state = torch.load(CKPT, map_location=device)
except FileNotFoundError as e:
    # Abort startup instead of serving a randomly-initialised model.
    raise SystemExit(f"❌ Checkpoint {CKPT} not found in the Space") from e

# Strip the "module." prefix left behind by DataParallel/DDP checkpoints.
state = {k.replace("module.", ""): v for k, v in state.items()}
missing, unexpected = net.load_state_dict(state, strict=False)
if missing:
    # strict=False tolerates extra keys, but missing weights mean the
    # checkpoint does not match this architecture — fail fast.
    raise SystemExit(f"❌ Checkpoint is missing keys: {missing[:5]} …")

net.eval()
print("✓ model loaded on", device)

# ---------------- preprocessing -------------------------------------
# Resize only — must match the transform used during training (512x512).
xform = A.Compose([A.Resize(INPUT_H, INPUT_W)])
53
+
54
def predict_overlay(pil: Image.Image) -> io.BytesIO:
    """Run the segmentation net on *pil* and return an overlay PNG.

    The image is resized to the training resolution, pushed through the
    network, and pixels where the predicted mask is positive are blended
    with red at ``ALPHA`` opacity.

    Returns an ``io.BytesIO`` rewound to position 0, ready to be wrapped
    in a ``StreamingResponse`` (the previous ``-> bytes`` annotation was
    wrong — a buffer, not raw bytes, is returned).
    """
    img = np.array(pil.convert("RGB"))
    img_rz = xform(image=img)["image"]
    ten = torch.from_numpy(img_rz.astype("float32") / 255).permute(2, 0, 1)
    ten = ten.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = net(ten)
        if DEEP_SUP and isinstance(logits, (list, tuple)):
            logits = logits[-1]  # deep supervision: keep the final head only
    mask = (torch.sigmoid(logits)[0, 0].cpu().numpy() > THRESH).astype("uint8")

    # Colourise the mask in red.  The array is in RGB (PIL) channel order,
    # so red is (255, 0, 0) — the previous (0, 0, 255) was BGR and
    # rendered the overlay blue.
    red = np.zeros_like(img_rz)
    red[:] = (255, 0, 0)
    blend = (img_rz * (1 - ALPHA) + red * ALPHA).astype("uint8")
    out = img_rz.copy()
    out[mask == 1] = blend[mask == 1]  # blend only inside the mask

    buf = io.BytesIO()
    Image.fromarray(out).save(buf, format="PNG")
    buf.seek(0)
    return buf
 
 
74
 
75
# ---------------- endpoint -----------------------------------------
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Decode the uploaded file and stream back the red-overlay PNG."""
    try:
        pil = Image.open(io.BytesIO(await file.read()))
    except Exception:
        # Undecodable upload -> explicit client error instead of a 500.
        return JSONResponse({"error": "invalid image"}, status_code=400)
    return StreamingResponse(predict_overlay(pil), media_type="image/png")