|
|
|
|
|
import io

import albumentations as A
import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from starlette.responses import PlainTextResponse, StreamingResponse

from backend_src import archs
|
|
|
|
|
|
|
|
# Path of the checkpoint file handed to torch.load at start-up.
CKPT = "models/best_model.pth"
# Name of the model class looked up in backend_src.archs.
ARCH = "NestedUNet"
# Height and width every upload is resized to, both for the network input
# and for the image that is sent back.
INP_H = INP_W = 512
# Forwarded to the model constructor; when true and the model returns a
# list/tuple, only its last element is used as the prediction.
DEEP_SUP = False
# Primary probability threshold, and the lower fallback threshold that is
# tried when the primary one produces an empty mask.
TH_MAIN, TH_FALL = 0.50, 0.25
# Weight of the colour map when it is blended over the image.
ALPHA = 0.60
# Origins the CORS middleware accepts requests from.
FRONTENDS = ["https://amitabhm1.github.io", "http://localhost:4173"]
|
|
|
|
|
|
|
|
# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the positional arguments are presumably
# (num_classes, input_channels, deep_supervision) — confirm against archs.
model = getattr(archs, ARCH)(1, 3, DEEP_SUP).to(device)

# NOTE: torch.load unpickles the file, so only point CKPT at a checkpoint
# that comes from a trusted source.
state = torch.load(CKPT, map_location=device)
# Strip a leading "module." from every key (presumably left behind by a
# DataParallel-style wrapper at training time) so the names match `model`.
state = {key.removeprefix("module."): value for key, value in state.items()}

# strict=False tolerates key mismatches; report them instead of silently
# serving a network that was only partially initialised.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "checkpoint key mismatch:",
        f"missing={list(load_result.missing_keys)}",
        f"unexpected={list(load_result.unexpected_keys)}",
    )

model.eval()
print("β model loaded on", device)
|
|
|
|
|
# Albumentations pipeline that rescales an image array to INP_H x INP_W.
resize = A.Compose([A.Resize(INP_H, INP_W)])
|
|
|
|
|
def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a batched network input.

    The image is converted to RGB, scaled from 0-255 to 0-1 as float32,
    passed through the module-level ``resize`` transform, rearranged from
    height x width x channel to channel-first order, given a leading batch
    axis of size 1 and moved to ``device``.

    Args:
        pil: Source image in any mode PIL can convert to RGB.

    Returns:
        A float32 tensor on ``device`` with a batch axis of size 1.
    """
    rgb = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    resized = resize(image=rgb)["image"]
    channel_first = resized.transpose(2, 0, 1)
    return torch.from_numpy(channel_first).unsqueeze(0).to(device)
|
|
|
|
|
def mask_from_prob(prob: np.ndarray, *, th_main=None, th_fall=None) -> np.ndarray:
    """Binarise a probability map, relaxing the threshold if nothing is found.

    Args:
        prob: Array of per-pixel probabilities.
        th_main: Primary threshold; defaults to the module constant
            ``TH_MAIN`` when omitted.
        th_fall: Fallback threshold used only when the primary threshold
            selects no pixel; defaults to ``TH_FALL`` when omitted.

    Returns:
        A uint8 array shaped like ``prob`` holding 1 where the probability
        reaches the applied threshold and 0 elsewhere. The result can still
        be all zeros if even the fallback threshold selects nothing.
    """
    # Resolve defaults at call time so later changes to the constants apply.
    main = TH_MAIN if th_main is None else th_main
    fall = TH_FALL if th_fall is None else th_fall

    primary = (prob >= main).astype("uint8")
    if primary.any():
        return primary
    return (prob >= fall).astype("uint8")
|
|
|
|
|
def overlay(img, np_mask):
    """Return a copy of ``img`` with the outer outlines of ``np_mask`` drawn on it.

    Args:
        img: Image array to draw on; it is copied, not modified.
        np_mask: Binary mask array whose external contours are traced.

    Returns:
        A new array equal to ``img`` plus 2-pixel-wide contour lines.
    """
    # findContours returns (contours, hierarchy) on OpenCV 4 but
    # (image, contours, hierarchy) on OpenCV 3; [-2] is the contour list in both.
    contours = cv2.findContours(
        np_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    out = img.copy()
    # NOTE(review): (0, 0, 255) renders blue on an RGB array and red on a BGR
    # one — confirm which is intended for the PIL-derived frame segment passes.
    cv2.drawContours(out, contours, -1, (0, 0, 255), 2)
    return out
|
|
|
|
|
def heatmap(img, prob):
    """Blend a JET colour map of ``prob`` over ``img``.

    Args:
        img: RGB uint8 image array, as produced from a PIL image.
        prob: Probability map with values in 0-1 and the same height and
            width as ``img``.

    Returns:
        The weighted blend of ``img`` (weight ``1 - ALPHA``) and the colour
        map (weight ``ALPHA``).
    """
    colour_bgr = cv2.applyColorMap((prob * 255).astype("uint8"), cv2.COLORMAP_JET)
    # applyColorMap emits BGR; the frame is RGB and is later encoded by PIL,
    # so swap channels to keep high probabilities red rather than blue.
    colour_rgb = cv2.cvtColor(colour_bgr, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img, 1 - ALPHA, colour_rgb, ALPHA, 0)
|
|
|
|
|
app = FastAPI()
# Let the browser front-ends listed in FRONTENDS call this API cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)
|
|
|
|
|
@app.get("/ping")
def ping():
    """Liveness probe: reply with the plain text ``pong``."""
    return PlainTextResponse("pong")
|
|
|
|
|
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment an uploaded image and return a PNG visualisation.

    The upload is decoded, run through the model, and thresholded with
    ``mask_from_prob``. If the mask contains any foreground, its contours
    are drawn on the resized image; otherwise a heat map of the raw
    probabilities is blended over it.

    Args:
        file: The uploaded image file.

    Returns:
        A streaming ``image/png`` response of size INP_W x INP_H.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    raw = await file.read()
    try:
        # Image.open is lazy; convert() forces the decode so corrupt or
        # truncated uploads are rejected here. Working in RGB from this point
        # on also keeps RGBA / greyscale / palette uploads compatible with
        # the 3-channel drawing helpers.
        pil = Image.open(io.BytesIO(raw)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="uploaded file is not a readable image"
        ) from err

    inp = to_tensor(pil)

    # NOTE(review): inference runs synchronously inside this coroutine.
    with torch.inference_mode():
        out = model(inp)
        if DEEP_SUP and isinstance(out, (list, tuple)):
            out = out[-1]
        # First batch item, first output channel.
        prob = torch.sigmoid(out)[0, 0].cpu().numpy()

    mask = mask_from_prob(prob)
    foreground = mask.sum()
    print(f"{file.filename}: min={prob.min():.3f} max={prob.max():.3f} fg={foreground}")

    frame = np.array(pil.resize((INP_W, INP_H), Image.BILINEAR))
    vis = overlay(frame, mask) if foreground else heatmap(frame, prob)

    buf = io.BytesIO()
    Image.fromarray(vis).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")