amitabh3 commited on
Commit
4891b1a
·
verified ·
1 Parent(s): 72c1765

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -49
app.py CHANGED
@@ -1,57 +1,31 @@
1
- import io, cv2, numpy as np, torch
2
- from fastapi import FastAPI, File, UploadFile
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from starlette.responses import StreamingResponse
5
- from PIL import Image
6
- import archs # copy your model definition file into this repo
 
 
 
 
 
7
 
8
- # ----- CORS: allow your GitHub Pages origin -----
9
  app = FastAPI()
 
 
 
10
  app.add_middleware(
11
  CORSMiddleware,
12
- allow_origins=["https://<your-github-username>.github.io"],
 
13
  allow_methods=["POST"],
14
  allow_headers=["*"],
15
  )
16
 
17
- # ----- load model once -----
18
- MODEL_PATH = "model.pth"
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
- model = archs.NestedUNet(num_classes=1, input_channels=3, deep_supervision=False)
21
- model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
22
- model.to(DEVICE).eval()
23
-
24
- # ----- utils -----
25
- def preprocess(pil):
26
- im = pil.resize((512, 512)) # same as training
27
- arr = np.asarray(im).astype("float32")/255
28
- ten = torch.from_numpy(arr.transpose(2,0,1)).unsqueeze(0)
29
- return ten.to(DEVICE)
30
-
31
- def postprocess(pred, alpha=0.4):
32
- mask = (torch.sigmoid(pred)[0,0].cpu().numpy() > .5).astype("uint8")
33
- mask_rgb = np.zeros((*mask.shape,3), np.uint8)
34
- mask_rgb[mask==1] = (255,0,0)
35
- return mask_rgb, mask
36
-
37
- def overlay(img, mask_rgb, alpha=0.4):
38
- blend = (img*(1-alpha) + mask_rgb*alpha).astype("uint8")
39
- out = img.copy()
40
- out[mask_rgb[:,:,0]>0] = blend[mask_rgb[:,:,0]>0]
41
- return out
42
-
43
- # ----- endpoint -----
44
- @app.post("/segment")
45
- async def segment(file: UploadFile = File(...)):
46
- raw = await file.read()
47
- pil = Image.open(io.BytesIO(raw)).convert("RGB")
48
- input_t = preprocess(pil)
49
- with torch.no_grad():
50
- pred = model(input_t)
51
- if isinstance(pred,(list,tuple)): pred = pred[-1]
52
- mask_rgb,_ = postprocess(pred)
53
- result = overlay(np.array(pil.resize((512,512))), mask_rgb)
54
- buf = io.BytesIO()
55
- Image.fromarray(result).save(buf, format="PNG")
56
- buf.seek(0)
57
- return StreamingResponse(buf, media_type="image/png")
 
1
+ # app.py (key lines only)
2
+
3
+ from fastapi import FastAPI, UploadFile, File
4
+ from backend_src import archs # ← path fixed
5
+ import torch, albumentations as A, cv2, numpy as np
6
+
7
+ # ------------------------------- CFG ---------------------------------
8
+ CKPT_PATH = "models/best_model.pth"
9
+ ARCH_NAME = "NestedUNet"
10
+ NUM_CLASSES = 1
11
+ # ---------------------------------------------------------------------
12
 
 
13
  app = FastAPI()
14
+
15
+ # ----- CORS (GitHub Pages) -----
16
+ from fastapi.middleware.cors import CORSMiddleware
17
  app.add_middleware(
18
  CORSMiddleware,
19
+ allow_origins=["https://amitabhm.github.io", # or "*" while testing
20
+ "http://localhost:4173"], # vite / local dev
21
  allow_methods=["POST"],
22
  allow_headers=["*"],
23
  )
24
 
25
+ # ------------------------------ LOAD MODEL ---------------------------
26
+ model = archs.__dict__[ARCH_NAME](NUM_CLASSES, 3, False)
27
+ state = torch.load(CKPT_PATH, map_location="cpu")
28
+ # strip "module." keys if you trained with DDP
29
+ state = {k.replace("module.", ""): v for k, v in state.items()}
30
+ model.load_state_dict(state, strict=False)
31
+ model.eval()