File size: 2,824 Bytes
96d978b
 
4100272
c206e8b
4100272
c206e8b
96d978b
c206e8b
96d978b
 
 
 
 
 
 
 
 
9d4f263
4100272
96d978b
 
 
c206e8b
 
96d978b
9d4f263
96d978b
9d4f263
96d978b
 
 
 
c206e8b
96d978b
 
 
c206e8b
96d978b
 
c206e8b
96d978b
c206e8b
 
96d978b
 
 
c206e8b
96d978b
 
 
c206e8b
96d978b
 
c206e8b
96d978b
 
 
 
 
c206e8b
96d978b
 
 
 
 
 
 
c206e8b
96d978b
 
9d4f263
96d978b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# app.py  (debug‑friendly overlay)
import io, cv2, numpy as np, torch, albumentations as A
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, PlainTextResponse
from backend_src import archs

# ─────────── config ────────────
# Checkpoint file and the architecture name looked up in backend_src.archs.
CKPT = "models/best_model.pth"
ARCH = "NestedUNet"

# Height and width that every input image is resized to.
INP_H = 512
INP_W = 512

# Passed as the third constructor argument of the architecture.
DEEP_SUP = False

# Primary mask threshold, and the looser one tried when the primary finds nothing.
TH_MAIN = 0.50
TH_FALL = 0.25

# Weight of the colour-map layer when blending the heat-map.
ALPHA = 0.60

# Origins accepted by the CORS middleware.
FRONTENDS = [
    "https://amitabhm1.github.io",
    "http://localhost:4173",
]
# ───────────────────────────────

# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Look the architecture class up by name in backend_src.archs and build it.
# NOTE(review): the positional arguments are presumably
# (num_classes, input_channels, deep_supervision) — confirm against archs.
model = vars(archs)[ARCH](1, 3, DEEP_SUP).to(device)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state = torch.load(CKPT, map_location=device)
# Strip a leading "module." from every key (presumably left behind by a
# DataParallel-wrapped training run) so the names match the bare model.
state = {key.removeprefix("module."): value for key, value in state.items()}

# strict=False tolerates key mismatches, so surface them: otherwise a wrong
# checkpoint would silently leave some or all layers at their initial values.
_load_report = model.load_state_dict(state, strict=False)
if _load_report.missing_keys:
    print("! keys missing from checkpoint:", _load_report.missing_keys)
if _load_report.unexpected_keys:
    print("! unexpected keys in checkpoint:", _load_report.unexpected_keys)
model.eval()
print("βœ“ model loaded on", device)

# Albumentations pipeline that rescales an image array to the network input size.
resize = A.Compose([A.Resize(height=INP_H, width=INP_W)])

def to_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a single-item batch tensor on ``device``.

    The image is forced to RGB, scaled to float32 in the 0-1 range, passed
    through the module-level ``resize`` pipeline, reordered from HWC to CHW
    and given a leading batch axis of size 1.
    """
    rgb = np.array(pil.convert("RGB"))
    scaled = rgb.astype("float32") / 255.0  # ★ match training!
    resized = resize(image=scaled)["image"]
    chw = resized.transpose(2, 0, 1)
    batch = torch.from_numpy(chw).unsqueeze(0)
    return batch.to(device)

def mask_from_prob(prob: np.ndarray) -> np.ndarray:
    """Binarise a probability map into a uint8 mask of zeros and ones.

    Pixels at or above ``TH_MAIN`` are foreground. When that selects no
    pixel at all, the looser ``TH_FALL`` threshold is applied instead.
    """
    primary = (prob >= TH_MAIN).astype("uint8")
    if primary.sum():
        return primary
    fallback = prob >= TH_FALL
    return fallback.astype("uint8")

def overlay(img, np_mask):
    """Return a copy of ``img`` with the outer contours of ``np_mask`` in red.

    Args:
        img: Image array in RGB channel order, which is what ``segment``
            supplies (it builds the frame from a PIL image). Not modified.
        np_mask: Single-channel uint8 mask; non-zero pixels are foreground.

    Returns:
        A new array like ``img`` with a 2-pixel-wide red outline drawn
        around each external contour of the mask.
    """
    found = cv2.findContours(np_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). Contours are second-to-last in both.
    contours = found[-2]
    canvas = img.copy()
    # The frame is RGB, so red is (255, 0, 0); (0, 0, 255) would render blue.
    cv2.drawContours(canvas, contours, -1, (255, 0, 0), 2)
    return canvas

def heatmap(img, prob):
    """Blend a JET colour map of ``prob`` over ``img``.

    Args:
        img: 3-channel uint8 image in RGB order, with the same height and
            width as ``prob``.
        prob: Array of probabilities in the 0-1 range.

    Returns:
        ``img`` weighted by ``1 - ALPHA`` plus the colour map weighted by
        ``ALPHA``, in RGB order.
    """
    levels = (prob * 255).astype("uint8")
    colour_bgr = cv2.applyColorMap(levels, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so that it matches the RGB frame and
    # high probabilities show as red rather than blue once saved through PIL.
    colour_rgb = cv2.cvtColor(colour_bgr, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img, 1 - ALPHA, colour_rgb, ALPHA, 0)

# ASGI application; the CORS middleware admits the origins listed in FRONTENDS.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=FRONTENDS,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)

@app.get("/ping")
def ping():
    """Liveness probe: reply with the plain-text body ``pong``."""
    return PlainTextResponse("pong")

@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Segment an uploaded image and return a PNG visualisation.

    The upload is decoded, run through the model, and thresholded with
    ``mask_from_prob``. When the mask has any foreground the response is the
    resized image with the mask contour drawn on it; otherwise it is the
    resized image blended with a heat-map of the probabilities.

    Args:
        file: The uploaded image file.

    Returns:
        A streaming ``image/png`` response, or a plain-text 400 response when
        the upload cannot be decoded as an image.
    """
    raw = await file.read()
    try:
        # Normalise to 3-channel RGB once. Grayscale, palette and RGBA uploads
        # would otherwise yield a frame whose channel count cannot be blended
        # with the 3-channel colour map. convert() also forces the lazy
        # decode, so truncated or corrupt files fail here.
        pil = Image.open(io.BytesIO(raw)).convert("RGB")
    except (OSError, ValueError):
        # PIL's UnidentifiedImageError is an OSError subclass.
        return PlainTextResponse("could not decode image", status_code=400)

    inp = to_tensor(pil)

    with torch.inference_mode():
        out = model(inp)
        # With deep supervision the model may return several outputs; the
        # last one is used.
        if DEEP_SUP and isinstance(out, (list, tuple)):
            out = out[-1]
    prob = torch.sigmoid(out)[0, 0].cpu().numpy()
    mmax, mmin = prob.max(), prob.min()
    mask = mask_from_prob(prob)
    print(f"{file.filename}: min={mmin:.3f}  max={mmax:.3f}  fg={mask.sum()}")

    # PIL's resize takes (width, height); the frame then matches prob and mask.
    frame = np.array(pil.resize((INP_W, INP_H), Image.BILINEAR))
    if mask.sum():
        vis = overlay(frame, mask)
    else:
        vis = heatmap(frame, prob)

    buf = io.BytesIO()
    Image.fromarray(vis).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")