# app.py (debug‑friendly overlay)
"""FastAPI service that segments an uploaded image with a UNet-family model.

Endpoints
---------
GET  /ping     plain-text "pong" liveness check.
POST /segment  PNG visualisation of the upload resized to INP_W x INP_H: a red
               contour around the predicted foreground, or a JET heat-map of
               the raw probabilities when no pixel reaches even TH_FALL.
"""
import io

import albumentations as A
import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from starlette.responses import PlainTextResponse, StreamingResponse

from backend_src import archs

# ─────────── config ────────────
CKPT = "models/best_model.pth"
ARCH = "NestedUNet"
INP_H = INP_W = 512
DEEP_SUP = False
TH_MAIN, TH_FALL = .50, .25   # primary threshold; fallback used if primary mask is empty
ALPHA = .60                   # heat-map weight when blending over the frame
FRONTENDS = ["https://amitabhm1.github.io", "http://localhost:4173"]
# ───────────────────────────────

device = "cuda" if torch.cuda.is_available() else "cpu"
model = archs.__dict__[ARCH](1, 3, DEEP_SUP).to(device)
state = torch.load(CKPT, map_location=device)
# Strip the "module." prefix that wrapped (e.g. DataParallel) checkpoints carry.
state = {k.removeprefix("module."): v for k, v in state.items()}
# strict=False tolerates key mismatches; report them so a wrong checkpoint
# does not silently leave layers at their random initialisation.
incompat = model.load_state_dict(state, strict=False)
if incompat.missing_keys or incompat.unexpected_keys:
    print(f"⚠ checkpoint/model key mismatch: "
          f"{len(incompat.missing_keys)} missing {incompat.missing_keys[:5]}, "
          f"{len(incompat.unexpected_keys)} unexpected {incompat.unexpected_keys[:5]}")
model.eval()
print("✓ model loaded on", device)

resize = A.Compose([A.Resize(INP_H, INP_W)])


def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a (1, 3, INP_H, INP_W) float32 tensor on `device`.

    Pixels are scaled to [0, 1] with no mean/std normalisation. The original
    author flags this as matching the training pipeline — keep them in sync.
    """
    arr = np.array(pil.convert("RGB")).astype("float32") / 255.0  # ★ match training!
    arr = resize(image=arr)["image"]                              # (INP_H, INP_W, 3)
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)


def mask_from_prob(prob: np.ndarray) -> np.ndarray:
    """Binarise `prob` at TH_MAIN, falling back to TH_FALL if that mask is empty.

    Returns a uint8 array of 0/1 with the same shape as `prob`.
    """
    m = (prob >= TH_MAIN).astype("uint8")
    return m if m.sum() else (prob >= TH_FALL).astype("uint8")


def overlay(img: np.ndarray, np_mask: np.ndarray) -> np.ndarray:
    """Return a copy of RGB `img` with the outer contours of `np_mask` drawn in red."""
    # [-2] picks the contour list on both OpenCV 3 (3 returns) and OpenCV 4 (2 returns).
    cnts = cv2.findContours(np_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    out = img.copy()
    # Frames are RGB (they come from PIL), so red is (255, 0, 0);
    # OpenCV's BGR red (0, 0, 255) would render blue here.
    cv2.drawContours(out, cnts, -1, (255, 0, 0), 2)
    return out


def heatmap(img: np.ndarray, prob: np.ndarray) -> np.ndarray:
    """Blend a JET colour map of `prob` (values in [0, 1]) over RGB `img`."""
    hm = cv2.applyColorMap((prob * 255).astype("uint8"), cv2.COLORMAP_JET)
    hm = cv2.cvtColor(hm, cv2.COLOR_BGR2RGB)  # applyColorMap emits BGR; frame is RGB
    return cv2.addWeighted(img, 1 - ALPHA, hm, ALPHA, 0)


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)


@app.get("/ping")
def ping():
    """Liveness probe."""
    return PlainTextResponse("pong")


@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment the uploaded image and stream back a PNG visualisation.

    The PNG is the upload resized to INP_W x INP_H. It shows a red foreground
    contour, or a probability heat-map when no pixel reaches TH_FALL.

    Raises:
        HTTPException(400): the upload is empty or cannot be decoded as an image.
    """
    raw = await file.read()
    try:
        pil = Image.open(io.BytesIO(raw))
        pil.load()  # force full decode so truncated files fail here, not later
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="unreadable image") from err
    # Normalise mode (RGBA / L / P …) so the display frame is always 3-channel RGB.
    pil = pil.convert("RGB")

    inp = to_tensor(pil)
    with torch.inference_mode():
        out = model(inp)
    # Deep-supervision models return one output per head; the last is the final one.
    if isinstance(out, (list, tuple)):
        out = out[-1]
    prob = torch.sigmoid(out)[0, 0].cpu().numpy()

    mmax, mmin = prob.max(), prob.min()
    mask = mask_from_prob(prob)
    print(f"{file.filename}: min={mmin:.3f} max={mmax:.3f} fg={int(mask.sum())}")

    frame = np.array(pil.resize((INP_W, INP_H), Image.BILINEAR))
    vis = overlay(frame, mask) if mask.sum() else heatmap(frame, prob)

    buf = io.BytesIO()
    Image.fromarray(vis).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")