sam2-fastapi / app.py
mila2030's picture
Update app.py
68874b0 verified
import base64
import logging
import threading
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Locally stored SAM2.1 Hiera-Large weights and the model config that matches them.
# SAM2ImagePredictor.from_pretrained() expects a Hugging Face Hub model id
# (e.g. "facebook/sam2.1-hiera-large") rather than a local .pt path, so a local
# checkpoint is loaded through build_sam2(config, checkpoint) instead.
SAM2_CHECKPOINT = "checkpoint/sam2.1_hiera_large.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Choose the device before building the model: build_sam2 places the model on
# its `device` argument (default "cuda"), which would fail on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = SAM2ImagePredictor(build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device))
# Explicitly keep the model on `device` in inference (eval) mode.
predictor.model.to(device).eval()
# CORS policy: any origin, method and header may call the API from a browser.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; the endpoint visible in this file reads no cookies or auth
# headers, so credentials may not be needed — confirm with the front-end.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
logger = logging.getLogger(__name__)

# `predictor` is one module-level object and set_image()/predict() mutate its
# state, so requests running in worker threads must not interleave those calls.
_predictor_lock = threading.Lock()


class _BadRequestError(Exception):
    """Client-side input problem; reported to the caller as HTTP 400."""


async def _read_json_body(request):
    """Return the request body parsed as a JSON object (dict).

    Raises:
        _BadRequestError: If the body is not valid JSON or not a JSON object.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        raise _BadRequestError("request body must be valid JSON.") from exc
    if not isinstance(data, dict):
        raise _BadRequestError("request body must be a JSON object.")
    return data


def _parse_points(point_coords, point_labels):
    """Validate the prompt points and convert them to arrays.

    Args:
        point_coords: List of ``[x, y]`` pixel coordinates, one per prompt.
        point_labels: List of labels, same length as ``point_coords``.

    Returns:
        ``(coords, labels)`` with shapes ``(N, 2)`` float32 and ``(N,)`` int32.

    Raises:
        _BadRequestError: If the lists are missing, empty, of unequal length,
            non-numeric, or not shaped as N pairs / N scalars.
    """
    if (
        not isinstance(point_coords, list)
        or not isinstance(point_labels, list)
        or not point_coords
        or len(point_coords) != len(point_labels)
    ):
        raise _BadRequestError(
            "point_coords and point_labels must be supplied and have equal length."
        )
    try:
        coords = np.asarray(point_coords, dtype=np.float32)
        labels = np.asarray(point_labels, dtype=np.int32)
    except (TypeError, ValueError) as exc:
        raise _BadRequestError(
            f"point_coords and point_labels must be numeric: {exc}"
        ) from exc
    if coords.ndim != 2 or coords.shape[1] != 2 or labels.ndim != 1:
        raise _BadRequestError(
            "each point_coords entry must be [x, y] and each point_labels entry a single number."
        )
    return coords, labels


def _decode_image(image_base64):
    """Decode a base64 image string into an RGB uint8 array of shape (H, W, 3).

    A ``data:<mime>;base64,`` prefix (as produced by browser canvases) is
    accepted and stripped.

    Raises:
        _BadRequestError: If the value is missing, not valid base64, or not a
            decodable image.
    """
    if not isinstance(image_base64, str) or not image_base64:
        raise _BadRequestError("image must be a non-empty base64 string.")
    # Without stripping the data-URL header, its letters (all valid base64
    # characters) would be decoded as part of the image bytes.
    if image_base64.startswith("data:") and "," in image_base64:
        image_base64 = image_base64.split(",", 1)[1]
    try:
        img_bytes = base64.b64decode(image_base64)
    except ValueError as exc:  # binascii.Error subclasses ValueError
        raise _BadRequestError(f"image is not valid base64: {exc}") from exc
    try:
        with Image.open(BytesIO(img_bytes)) as img:
            return np.array(img.convert("RGB"))
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise _BadRequestError(f"image could not be decoded: {exc}") from exc


def _segment_union(np_img, coords, labels):
    """Run SAM2 once per prompt point and OR the resulting masks together.

    Each point is prompted on its own, i.e. treated as a separate object rather
    than as one multi-point prompt.

    Returns:
        ``(H, W)`` uint8 array: 1 inside the union of the masks, 0 elsewhere.
    """
    union = np.zeros(np_img.shape[:2], dtype=bool)
    with _predictor_lock, torch.inference_mode():
        predictor.set_image(np_img)
        # Slicing with None yields one (1, 2) coordinate / (1,) label prompt per step.
        for point, label in zip(coords[:, None, :], labels[:, None]):
            masks, scores, _ = predictor.predict(
                point_coords=point,
                point_labels=label,
                multimask_output=True,
            )
            # SAM2 returns several candidate masks that are not sorted by
            # quality, so masks[0] is not necessarily the best; use the score.
            union |= masks[int(np.argmax(scores))].astype(bool)
    return union.astype(np.uint8)


def _mask_to_png(mask):
    """Encode a (H, W) 0/1 mask as RGBA PNG bytes: black, alpha 128 inside the mask."""
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[..., 3] = mask * 128  # 128 = semi-transparent
    buf = BytesIO()
    # A (H, W, 4) uint8 array is inferred as mode "RGBA".
    Image.fromarray(rgba).save(buf, format="PNG")
    return buf.getvalue()


@app.post("/sam2_segment/")
async def sam2_segment(request: Request):
    """Segment the objects under the given points and return them as a PNG overlay.

    Request body (JSON object):
        image: base64-encoded image (a ``data:...;base64,`` prefix is accepted).
        point_coords: list of ``[x, y]`` pixel coordinates, one per object.
        point_labels: list of SAM2 point labels, same length as ``point_coords``.

    Returns:
        ``image/png`` RGBA image the size of the input, fully transparent except
        alpha 128 inside the union of the per-point masks. On failure, a JSON
        ``{"error": ...}`` body with status 400 (bad input) or 500 (other errors).
    """
    try:
        data = await _read_json_body(request)
        coords, labels = _parse_points(
            data.get("point_coords", []), data.get("point_labels", [])
        )
        # Decoding, inference and PNG encoding are blocking work; running them in
        # the thread pool keeps the event loop responsive to other requests.
        np_img = await run_in_threadpool(_decode_image, data.get("image"))
        mask = await run_in_threadpool(_segment_union, np_img, coords, labels)
        png_bytes = await run_in_threadpool(_mask_to_png, mask)
    except _BadRequestError as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})
    except Exception as exc:  # top-level boundary: always answer with JSON
        logger.exception("sam2_segment failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return Response(content=png_bytes, media_type="image/png")