"""FastAPI service exposing SAM2 point-prompted segmentation as a PNG overlay."""

import base64
import logging
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

logger = logging.getLogger(__name__)

# Alpha written into the overlay wherever the mask is set (128 = semi-transparent).
OVERLAY_ALPHA = 128

_POINTS_ERROR = "point_coords and point_labels must be supplied and have equal length."

# NOTE(review): upstream SAM2's `from_pretrained` appears to expect a Hugging
# Face model id (e.g. "facebook/sam2.1-hiera-large") rather than a local
# checkpoint path -- confirm this argument against the installed sam2 version.
predictor = SAM2ImagePredictor.from_pretrained("checkpoint/sam2.1_hiera_large.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor.model.to(device).eval()

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests. The FastAPI CORS docs advise listing
# origins explicitly when credentials are enabled. Left unchanged so existing
# clients keep working -- tighten before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class _BadRequestError(Exception):
    """Raised when the client payload is unusable; mapped to HTTP 400."""


def _parse_payload(data):
    """Validate the decoded JSON body and return ``(image_b64, coords, labels)``.

    ``coords`` is a float32 array of shape ``(n, 2)`` and ``labels`` an int32
    array of shape ``(n,)``, both built from the request's ``point_coords``
    and ``point_labels`` lists.

    Raises:
        _BadRequestError: if the body is not an object, the image is missing,
            or the point lists are missing, empty, mismatched or malformed.
    """
    if not isinstance(data, dict):
        raise _BadRequestError("Request body must be a JSON object.")

    image_base64 = data.get("image")
    if not image_base64 or not isinstance(image_base64, str):
        raise _BadRequestError("image must be supplied as a base64-encoded string.")

    point_coords = data.get("point_coords", [])
    point_labels = data.get("point_labels", [])
    if (
        not isinstance(point_coords, list)
        or not isinstance(point_labels, list)
        or not point_coords
        or len(point_coords) != len(point_labels)
    ):
        raise _BadRequestError(_POINTS_ERROR)

    try:
        coords = np.asarray(point_coords, dtype=np.float32)
        labels = np.asarray(point_labels, dtype=np.int32)
    except (TypeError, ValueError, OverflowError) as exc:
        raise _BadRequestError(
            "point_coords must contain numeric pairs and point_labels integers."
        ) from exc

    if coords.ndim != 2 or coords.shape[1] != 2 or labels.ndim != 1:
        raise _BadRequestError(
            "Each entry of point_coords must be a pair and each label a scalar."
        )
    return image_base64, coords, labels


def _decode_image(image_base64):
    """Decode a base64 string into an RGB ``PIL.Image``.

    Raises:
        _BadRequestError: if the string is not valid base64 or the decoded
            bytes cannot be opened as an image.
    """
    try:
        img_bytes = base64.b64decode(image_base64)
        return Image.open(BytesIO(img_bytes)).convert("RGB")
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        raise _BadRequestError("image is not a valid base64-encoded image.") from exc


def _segment_union(np_img, coords, labels):
    """Run the predictor once per point and return the union of the masks.

    Each point is submitted as its own single-point prompt, and the first
    mask returned for every prompt is OR-ed into the result.

    Returns:
        Boolean array with the same height and width as ``np_img``.
    """
    union_mask = np.zeros(np_img.shape[:2], dtype=bool)
    with torch.inference_mode():
        predictor.set_image(np_img)
        for idx in range(len(coords)):
            masks, _, _ = predictor.predict(
                point_coords=coords[idx:idx + 1],
                point_labels=labels[idx:idx + 1],
            )
            # NOTE(review): the first returned mask is used as-is; if the
            # predictor returns several candidates, consider picking the one
            # with the highest score instead -- confirm the desired behaviour.
            union_mask |= np.asarray(masks[0]).astype(bool)
    return union_mask


def _encode_overlay(mask):
    """Render a boolean mask as PNG bytes of a transparent RGBA overlay.

    Colour channels are all zero; alpha is ``OVERLAY_ALPHA`` where the mask
    is set and 0 elsewhere.
    """
    height, width = mask.shape
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[..., 3] = mask.astype(np.uint8) * OVERLAY_ALPHA
    buf = BytesIO()
    # A (h, w, 4) uint8 array is interpreted as RGBA without an explicit mode.
    Image.fromarray(rgba).save(buf, format="PNG")
    return buf.getvalue()


@app.post("/sam2_segment/")
async def sam2_segment(request: Request):
    """Segment an image from point prompts and return a PNG mask overlay.

    The JSON body must contain ``image`` (base64-encoded image bytes),
    ``point_coords`` (non-empty list of ``[x, y]`` pairs) and
    ``point_labels`` (list of integers of the same length).

    Returns:
        ``image/png`` response on success; a JSON ``{"error": ...}`` body with
        status 400 for an invalid request or 500 for an unexpected failure.
    """
    # Inference runs synchronously inside this coroutine, so requests are
    # processed one at a time. The module-level predictor holds per-image
    # state (set_image), so do not move this to a thread pool without adding
    # a lock around the predictor.
    try:
        try:
            data = await request.json()
        except ValueError as exc:
            raise _BadRequestError("Request body must be valid JSON.") from exc

        image_base64, coords, labels = _parse_payload(data)
        pil_img = _decode_image(image_base64)
        union_mask = _segment_union(np.array(pil_img), coords, labels)
        return Response(content=_encode_overlay(union_mask), media_type="image/png")
    except _BadRequestError as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})
    except Exception as exc:
        # Top-level boundary: log with traceback and report a generic 500.
        logger.exception("sam2_segment failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})