amitabh3 commited on
Commit
5964bb9
·
verified ·
1 Parent(s): 268763c

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, cv2, numpy as np, torch
2
+ from fastapi import FastAPI, File, UploadFile
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from starlette.responses import StreamingResponse
5
+ from PIL import Image
6
+ import archs # copy your model definition file into this repo
7
+
8
+ # ----- CORS: allow your GitHub Pages origin -----
9
+ app = FastAPI()
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["https://<your-github-username>.github.io"],
13
+ allow_methods=["POST"],
14
+ allow_headers=["*"],
15
+ )
16
+
17
+ # ----- load model once -----
18
+ MODEL_PATH = "model.pth"
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model = archs.NestedUNet(num_classes=1, input_channels=3, deep_supervision=False)
21
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
22
+ model.to(DEVICE).eval()
23
+
24
+ # ----- utils -----
25
+ def preprocess(pil):
26
+ im = pil.resize((512, 512)) # same as training
27
+ arr = np.asarray(im).astype("float32")/255
28
+ ten = torch.from_numpy(arr.transpose(2,0,1)).unsqueeze(0)
29
+ return ten.to(DEVICE)
30
+
31
+ def postprocess(pred, alpha=0.4):
32
+ mask = (torch.sigmoid(pred)[0,0].cpu().numpy() > .5).astype("uint8")
33
+ mask_rgb = np.zeros((*mask.shape,3), np.uint8)
34
+ mask_rgb[mask==1] = (255,0,0)
35
+ return mask_rgb, mask
36
+
37
+ def overlay(img, mask_rgb, alpha=0.4):
38
+ blend = (img*(1-alpha) + mask_rgb*alpha).astype("uint8")
39
+ out = img.copy()
40
+ out[mask_rgb[:,:,0]>0] = blend[mask_rgb[:,:,0]>0]
41
+ return out
42
+
43
+ # ----- endpoint -----
44
+ @app.post("/segment")
45
+ async def segment(file: UploadFile = File(...)):
46
+ raw = await file.read()
47
+ pil = Image.open(io.BytesIO(raw)).convert("RGB")
48
+ input_t = preprocess(pil)
49
+ with torch.no_grad():
50
+ pred = model(input_t)
51
+ if isinstance(pred,(list,tuple)): pred = pred[-1]
52
+ mask_rgb,_ = postprocess(pred)
53
+ result = overlay(np.array(pil.resize((512,512))), mask_rgb)
54
+ buf = io.BytesIO()
55
+ Image.fromarray(result).save(buf, format="PNG")
56
+ buf.seek(0)
57
+ return StreamingResponse(buf, media_type="image/png")