# polyp-api / app.py
# Last commit: 96d978b "Update app.py" by amitabh3 (verified)
# app.py (debug‑friendly overlay)
import io

import albumentations as A
import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from starlette.responses import PlainTextResponse, StreamingResponse

from backend_src import archs
# ── configuration ─────────────────────────────────────────────
CKPT = "models/best_model.pth"   # checkpoint file handed to torch.load
ARCH = "NestedUNet"              # attribute name looked up in backend_src.archs
INP_H, INP_W = 512, 512          # network input height / width in pixels
DEEP_SUP = False                 # forwarded to the architecture constructor
TH_MAIN = 0.50                   # primary foreground-probability threshold
TH_FALL = 0.25                   # looser threshold used when TH_MAIN selects nothing
ALPHA = 0.60                     # heat-map weight when blending over the frame
FRONTENDS = [                    # origins allowed by the CORS middleware
    "https://amitabhm1.github.io",
    "http://localhost:4173",
]
# ──────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
# Instantiate the architecture by name. The positional args (1, 3, DEEP_SUP)
# are presumably (num_classes, input_channels, deep_supervision) — confirm
# against backend_src.archs.
model = archs.__dict__[ARCH](1, 3, DEEP_SUP).to(device)
state = torch.load(CKPT, map_location=device)
# Strip a leading "module." (added when a model is saved while wrapped,
# e.g. in nn.DataParallel) so keys match the bare model.
state = {k.removeprefix("module."): v for k, v in state.items()}
# strict=False tolerates key mismatches, but a silent mismatch leaves those
# layers at their random initialisation. Report the mismatch so a wrong or
# stale checkpoint is visible in the logs instead of only in bad masks.
_load_report = model.load_state_dict(state, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch - "
        f"{len(_load_report.missing_keys)} missing "
        f"(e.g. {_load_report.missing_keys[:3]}), "
        f"{len(_load_report.unexpected_keys)} unexpected "
        f"(e.g. {_load_report.unexpected_keys[:3]})"
    )
model.eval()
print("✓ model loaded on", device)
# Resize pipeline applied to every input image before inference.
resize = A.Compose([A.Resize(INP_H, INP_W)])
def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a batched float32 model input on ``device``.

    The image is forced to RGB and scaled to [0, 1]; the original author
    notes this scaling is meant to match training. It is then resized to
    INP_H x INP_W by the module-level ``resize`` pipeline and reordered from
    HWC to a 1 x 3 x INP_H x INP_W tensor.
    """
    hwc = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    hwc = resize(image=hwc)["image"]
    chw = np.transpose(hwc, (2, 0, 1))
    return torch.from_numpy(chw)[None].to(device)
def mask_from_prob(prob: np.ndarray) -> np.ndarray:
    """Binarise a probability map into a uint8 0/1 mask.

    Pixels with ``prob >= TH_MAIN`` are foreground. If that selects no
    pixel at all, the looser ``TH_FALL`` threshold is used instead (its mask
    is returned even if it is also empty).
    """
    for threshold in (TH_MAIN, TH_FALL):
        mask = (prob >= threshold).astype("uint8")
        if mask.sum():
            return mask
    return mask
def overlay(img, np_mask):
    """Return a copy of ``img`` with the outer contours of ``np_mask`` in red.

    Args:
        img: uint8 image, expected in RGB channel order (as produced by
            ``np.array`` on a PIL image). It is not modified.
        np_mask: 2-D mask the same height/width as ``img``; any non-zero
            pixel counts as foreground.

    Returns:
        A new array with 2-px red outer contours drawn on it.
    """
    # findContours needs a contiguous single-channel 8-bit image.
    mask_u8 = np.ascontiguousarray(np_mask, dtype=np.uint8)
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); [-2] selects the contours in both versions.
    contours = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    out = img.copy()
    # The frame is RGB, so red is (255, 0, 0). OpenCV's BGR "red"
    # (0, 0, 255) would render blue here.
    cv2.drawContours(out, contours, -1, (255, 0, 0), 2)
    return out
def heatmap(img, prob):
    """Blend a JET colour map of ``prob`` over ``img``.

    Args:
        img: uint8 image in RGB channel order, same height/width as ``prob``.
        prob: 2-D array of probabilities, nominally in [0, 1].

    Returns:
        uint8 RGB image equal to ``(1 - ALPHA) * img + ALPHA * colour_map``.
    """
    # Clip before the uint8 cast so out-of-range values saturate rather
    # than wrap around.
    levels = (np.clip(prob, 0.0, 1.0) * 255).astype("uint8")
    hm = cv2.applyColorMap(levels, cv2.COLORMAP_JET)
    # applyColorMap produces BGR while the frame is RGB. Swap channels
    # before blending, otherwise high probabilities show up blue, not red.
    hm = cv2.cvtColor(hm, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img, 1 - ALPHA, hm, ALPHA, 0)
app = FastAPI()
# Allow cross-origin browser requests from the front-ends listed in FRONTENDS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)
@app.get("/ping")
def ping():
    """Health-check endpoint: always answers with the plain-text body ``pong``."""
    reply = PlainTextResponse("pong")
    return reply
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment an uploaded image and stream back a PNG visualisation.

    The upload is decoded with PIL, converted to RGB, and run through the
    model at INP_H x INP_W. If the thresholded mask (see ``mask_from_prob``)
    has any foreground pixel, the response shows its contour on the resized
    frame. Otherwise it shows a probability heat-map, so an empty prediction
    can still be inspected. Min/max probability and foreground pixel count
    are printed for debugging.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    raw = await file.read()
    try:
        pil = Image.open(io.BytesIO(raw))
        # Image.open is lazy. Force decoding here so corrupt or truncated
        # uploads fail inside this try (-> 400) instead of later (-> 500).
        pil.load()
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="uploaded file is not a readable image"
        ) from err
    # Normalise the mode once so the model input and the visualisation frame
    # share the same 3-channel RGB pixels. RGBA/L/P frames would otherwise
    # break the blend in the overlay/heat-map helpers.
    pil = pil.convert("RGB")
    inp = to_tensor(pil)
    with torch.inference_mode():
        out = model(inp)
    # Deep-supervision variants return one output per head; the last is the
    # final prediction.
    if isinstance(out, (list, tuple)):
        out = out[-1]
    prob = torch.sigmoid(out)[0, 0].cpu().numpy()
    mask = mask_from_prob(prob)
    fg = int(mask.sum())
    print(f"{file.filename}: min={prob.min():.3f} max={prob.max():.3f} fg={fg}")
    frame = np.array(pil.resize((INP_W, INP_H), Image.BILINEAR))
    vis = overlay(frame, mask) if fg else heatmap(frame, prob)
    buf = io.BytesIO()
    Image.fromarray(vis).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")