File size: 2,824 Bytes
96d978b 4100272 c206e8b 4100272 c206e8b 96d978b c206e8b 96d978b 9d4f263 4100272 96d978b c206e8b 96d978b 9d4f263 96d978b 9d4f263 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b c206e8b 96d978b 9d4f263 96d978b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# app.py (debug-friendly overlay)
import io, cv2, numpy as np, torch, albumentations as A
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, PlainTextResponse
from backend_src import archs
# ----------- config ------------
CKPT = "models/best_model.pth"  # checkpoint file loaded once at startup
ARCH = "NestedUNet"  # class name looked up in backend_src.archs
INP_H = INP_W = 512  # model input resolution (square)
DEEP_SUP = False  # when True the model returns a list of outputs (deep supervision)
TH_MAIN, TH_FALL = .50, .25  # primary threshold; looser fallback when primary finds no foreground
ALPHA = .60  # heatmap blend weight in [0, 1]
FRONTENDS = ["https://amitabhm1.github.io","http://localhost:4173"]  # CORS allow-list
# -------------------------------
# One-time device selection and model initialization (runs at import time).
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): positional args (1, 3, DEEP_SUP) presumably map to
# (num_classes, input_channels, deep_supervision) — confirm against archs.
model = archs.__dict__[ARCH](1, 3, DEEP_SUP).to(device)
state = torch.load(CKPT, map_location=device)
# Checkpoints saved under nn.DataParallel prefix every key with "module.".
state = {k.removeprefix("module."): v for k, v in state.items()}
# strict=False tolerates missing/unexpected keys instead of raising.
model.load_state_dict(state, strict=False)
model.eval()
print("\u2713 model loaded on", device)  # was mojibake ("β") for a check mark
# Resize-only preprocessing pipeline, shared by to_tensor().
resize = A.Compose([A.Resize(INP_H, INP_W)])
def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a normalized (1, 3, H, W) float tensor on `device`.

    Scaling to [0, 1] must match the normalization used during training.

    Fix: the inline comment had wrapped onto its own line ("match training!"),
    which was a syntax error; it is merged back into the comment here.
    """
    arr = np.array(pil.convert("RGB")).astype("float32") / 255.0  # match training!
    arr = resize(image=arr)["image"]  # albumentations resize to (INP_H, INP_W)
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
def mask_from_prob(prob: np.ndarray, th_main=None, th_fall=None) -> np.ndarray:
    """Binarize a probability map, with a looser fallback threshold.

    Args:
        prob: probability map with values in [0, 1].
        th_main: primary threshold; defaults to module-level TH_MAIN.
        th_fall: fallback threshold, used only when thresholding at
            `th_main` yields no foreground pixels at all; defaults to
            module-level TH_FALL.

    Returns:
        uint8 array of 0/1 values, same shape as `prob`.
    """
    # Resolve defaults lazily so the thresholds stay configurable via the
    # module constants while callers may override per call.
    th_main = TH_MAIN if th_main is None else th_main
    th_fall = TH_FALL if th_fall is None else th_fall
    m = (prob >= th_main).astype("uint8")
    # Empty primary mask -> retry with the looser fallback threshold.
    return m if m.sum() else (prob >= th_fall).astype("uint8")
def overlay(img, np_mask):  # red contour
    """Draw the mask's outer contours in red on a copy of `img`.

    `img` originates from PIL (RGB channel order) and the result is saved
    back through PIL, so the red triple must be (255, 0, 0). The original
    used cv2's BGR convention (0, 0, 255), which rendered *blue* in the
    final PNG despite the "red contour" comment.
    """
    cnts, _ = cv2.findContours(np_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = img.copy()
    cv2.drawContours(out, cnts, -1, (255, 0, 0), 2)  # red in RGB order
    return out
def heatmap(img, prob):
    """Blend a JET colormap of `prob` over `img` and return the result.

    cv2.applyColorMap emits BGR, but `img` comes from PIL in RGB order and
    the blended output is saved via PIL — so convert the colormap to RGB
    first, otherwise the heatmap's red/blue channels appear swapped.
    """
    hm = cv2.applyColorMap((prob * 255).astype("uint8"), cv2.COLORMAP_JET)
    hm = cv2.cvtColor(hm, cv2.COLOR_BGR2RGB)  # fix red/blue swap in saved PNG
    return cv2.addWeighted(img, 1 - ALPHA, hm, ALPHA, 0)
# Create the API and allow only the known frontends to call it cross-origin.
app = FastAPI()
app.add_middleware(CORSMiddleware,
allow_origins=FRONTENDS, allow_methods=["POST","GET","OPTIONS"], allow_headers=["*"])
# Liveness probe: GET /ping -> plain-text "pong".
@app.get("/ping")
def ping(): return PlainTextResponse("pong")
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment an uploaded image and stream a PNG visualization back.

    Returns a contour overlay when the thresholded mask contains any
    foreground, otherwise a probability heatmap so the caller can still
    inspect the raw activations.
    """
    raw = await file.read()
    pil = Image.open(io.BytesIO(raw))
    inp = to_tensor(pil)
    with torch.inference_mode():
        out = model(inp)
        if DEEP_SUP and isinstance(out, (list, tuple)):
            out = out[-1]  # deep supervision: the last head is the usable output
        prob = torch.sigmoid(out)[0, 0].cpu().numpy()
    mmax, mmin = prob.max(), prob.min()
    mask = mask_from_prob(prob)
    print(f"{file.filename}: min={mmin:.3f} max={mmax:.3f} fg={mask.sum()}")
    # Force RGB so grayscale (HxW) or RGBA (HxWx4) uploads still yield an
    # HxWx3 frame — overlay()/heatmap() blend against 3-channel arrays and
    # would fail on any other shape.
    frame = np.array(pil.convert("RGB").resize((INP_W, INP_H), Image.BILINEAR))
    vis = overlay(frame, mask) if mask.sum() else heatmap(frame, prob)
    buf = io.BytesIO()
    Image.fromarray(vis).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")