mila2030 commited on
Commit
1ed2808
·
verified ·
1 Parent(s): e46693e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import Response, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import base64
8
+ from io import BytesIO
9
+
10
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
11
+
12
+ predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ predictor.model.to(device).eval()
15
+
16
+ app = FastAPI()
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"], allow_credentials=True,
20
+ allow_methods=["*"], allow_headers=["*"]
21
+ )
22
+
23
+ @app.post("/sam2_segment/")
24
+ async def sam2_segment(request: Request):
25
+ try:
26
+ data = await request.json()
27
+ image_base64 = data.get("image")
28
+ point_coords = data.get("point_coords", [])
29
+ point_labels = data.get("point_labels", [])
30
+
31
+ if (
32
+ not image_base64
33
+ or not isinstance(point_coords, list)
34
+ or not isinstance(point_labels, list)
35
+ or len(point_coords) == 0
36
+ or len(point_coords) != len(point_labels)
37
+ ):
38
+ return JSONResponse(status_code=400, content={"error": "point_coords and point_labels must be supplied and have equal length."})
39
+
40
+ img_bytes = base64.b64decode(image_base64)
41
+ pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
42
+ np_img = np.array(pil_img)
43
+
44
+ h, w = pil_img.height, pil_img.width
45
+ union_mask = np.zeros((h, w), dtype=np.uint8)
46
+
47
+ # Run SAM2 separately for each point, accumulate masks
48
+ with torch.inference_mode():
49
+ predictor.set_image(np_img)
50
+ for pt, label in zip(point_coords, point_labels):
51
+ pt_np = np.array([pt], dtype=np.float32)
52
+ label_np = np.array([label], dtype=np.int32)
53
+ masks, _, _ = predictor.predict(
54
+ point_coords=pt_np,
55
+ point_labels=label_np,
56
+ )
57
+ union_mask = np.logical_or(union_mask, masks[0]).astype(np.uint8)
58
+
59
+ rgba = np.zeros((h, w, 4), dtype=np.uint8)
60
+ rgba[..., 3] = union_mask * 128 # 128 = semi-transparent
61
+
62
+ out_img = Image.fromarray(rgba, mode="RGBA")
63
+ buf = BytesIO()
64
+ out_img.save(buf, format="PNG")
65
+ return Response(content=buf.getvalue(), media_type="image/png")
66
+ except Exception as e:
67
+ print("ERROR:", str(e))
68
+ return JSONResponse(status_code=500, content={"error": str(e)})