amitabh3 commited on
Commit
4100272
·
verified ·
1 Parent(s): 4891b1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -20
app.py CHANGED
@@ -1,31 +1,81 @@
1
- # app.py (key lines only)
 
 
 
 
 
2
 
3
- from fastapi import FastAPI, UploadFile, File
4
- from backend_src import archs # path fixed
5
- import torch, albumentations as A, cv2, numpy as np
6
 
7
- # ------------------------------- CFG ---------------------------------
8
- CKPT_PATH = "models/best_model.pth"
9
- ARCH_NAME = "NestedUNet"
10
- NUM_CLASSES = 1
11
- # ---------------------------------------------------------------------
 
 
 
 
12
 
13
- app = FastAPI()
 
14
 
15
- # ----- CORS (GitHub Pages) -----
16
- from fastapi.middleware.cors import CORSMiddleware
17
  app.add_middleware(
18
  CORSMiddleware,
19
- allow_origins=["https://amitabhm.github.io", # or "*" while testing
20
- "http://localhost:4173"], # vite / local dev
21
  allow_methods=["POST"],
22
  allow_headers=["*"],
23
  )
24
-
25
- # ------------------------------ LOAD MODEL ---------------------------
26
- model = archs.__dict__[ARCH_NAME](NUM_CLASSES, 3, False)
27
- state = torch.load(CKPT_PATH, map_location="cpu")
28
- # strip "module." keys if you trained with DDP
29
- state = {k.replace("module.", ""): v for k, v in state.items()}
 
 
30
  model.load_state_dict(state, strict=False)
31
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py – Hugging‑Face Space backend
2
+ import io, cv2, torch, albumentations as A, numpy as np
3
+ from PIL import Image
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from starlette.responses import StreamingResponse
7
 
8
+ # ---------- model code lives inside backend_src -------------
9
+ from backend_src import archs # <- you copied this
10
+ # ------------------------------------------------------------
11
 
12
### CONFIG #### -------------------------------------------------
# Deployment knobs for the segmentation service.
CKPT = "models/best_model.pth"                 # checkpoint, uploaded via Git-LFS
ARCH_NAME = "NestedUNet"                       # key into backend_src.archs
INPUT_H = 512                                  # network input height (training size)
INPUT_W = INPUT_H                              # square input: width == height
THRESH = 0.5                                   # sigmoid binarisation threshold
DEEP_SUP = False                               # deep supervision off -> single output head
ALPHA = 0.40                                   # overlay transparency
ALLOWED_FRONT = "https://amitabhm1.github.io"  # GitHub-Pages front-end origin
################ -------------------------------------------------
21
 
22
# ---------------- FastAPI (plus simple CORS) ------------------
app = FastAPI(title="Polyp Segmentation API")

# Browser origins allowed to call the API: the production GitHub-Pages
# front-end plus two common local-dev servers (Vite preview, Live Server).
_origins = [
    ALLOWED_FRONT,
    "http://localhost:4173",
    "http://127.0.0.1:5500",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_origins,
    allow_methods=["POST"],
    allow_headers=["*"],
)
31
# --------------- load model once at startup -------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the architecture by name from backend_src.archs.
model = getattr(archs, ARCH_NAME)(
    num_classes=1,
    input_channels=3,
    deep_supervision=DEEP_SUP,
).to(device)

# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable for a checkpoint we uploaded ourselves, but confirm
# the file is trusted.
state = torch.load(CKPT, map_location=device)

# Drop the "module." prefix that DistributedDataParallel adds to keys.
state = {key.replace("module.", ""): weight for key, weight in state.items()}

model.load_state_dict(state, strict=False)
model.eval()
print("✓ model loaded on", device)
42
+
43
# ---------- transforms = pad/resize exactly as during train ----
# Single deterministic resize back to the training resolution.
_resize = A.Resize(height=INPUT_H, width=INPUT_W, interpolation=cv2.INTER_LINEAR)
transform = A.Compose([_resize])
47
+
48
+ # ------------------------- helpers -----------------------------
49
def infer_overlay(pil_img: Image.Image) -> bytes:
    """Run the segmentation model on *pil_img* and return a PNG overlay.

    The image is resized to the training resolution (INPUT_H x INPUT_W),
    pushed through the network, and the thresholded mask is alpha-blended
    in red on top of the resized input.

    Parameters
    ----------
    pil_img : PIL.Image.Image
        Input image in any PIL-supported mode; converted to RGB here.

    Returns
    -------
    bytes
        The blended overlay encoded as PNG.
    """
    orig = np.array(pil_img.convert("RGB"))   # HWC, uint8
    trans = transform(image=orig)["image"]    # resized to INPUT_H x INPUT_W

    # HWC uint8 -> NCHW float in [0, 1] on the model's device.
    ten = (
        torch.from_numpy(trans.astype("float32") / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    with torch.no_grad():
        out = model(ten)
        # Deep supervision yields one output per decoder stage; keep the last.
        if DEEP_SUP and isinstance(out, (list, tuple)):
            out = out[-1]

    # Binary mask from the single-channel logit map.
    mask = (torch.sigmoid(out)[0, 0].cpu().numpy() > THRESH).astype(bool)

    # BUG FIX: blend only inside the predicted region.  The previous code
    # mixed a mostly-zero colour image over the whole frame, which darkened
    # every background pixel by (1 - ALPHA) instead of tinting just the mask.
    blended = trans.copy()
    red = np.array((255, 0, 0), dtype="float32")
    blended[mask] = (trans[mask] * (1 - ALPHA) + red * ALPHA).astype("uint8")

    # Encode as PNG in memory and hand back the raw bytes.
    buf = io.BytesIO()
    Image.fromarray(blended).save(buf, format="PNG")
    return buf.getvalue()
71
+ # ---------------------------------------------------------------
72
+
73
@app.post("/segment", summary="Upload an image, get overlay PNG")
async def segment(file: UploadFile = File(...)):
    """Accept an uploaded image and stream back the red-overlay PNG.

    Parameters
    ----------
    file : UploadFile
        The image posted by the front-end (any PIL-readable format).

    Returns
    -------
    StreamingResponse
        ``image/png`` overlay on success.

    BUG FIX: failures previously returned ``{"error": ...}`` with HTTP 200,
    so the front-end could mistake an error for a valid image.  Errors now
    carry HTTP 500 with the same JSON body shape.
    """
    # Local import keeps this fix self-contained; starlette is already a
    # dependency of this file (StreamingResponse above).
    from starlette.responses import JSONResponse

    try:
        img_bytes = await file.read()
        pil = Image.open(io.BytesIO(img_bytes))
        png_bytes = infer_overlay(pil)
        return StreamingResponse(io.BytesIO(png_bytes), media_type="image/png")
    except Exception as e:  # boundary handler: report the error, don't crash the worker
        return JSONResponse({"error": str(e)}, status_code=500)