| from fastapi import FastAPI, Request
|
| from fastapi.responses import Response, JSONResponse
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from PIL import Image
|
| import torch
|
| import numpy as np
|
| import base64
|
| from io import BytesIO
|
|
|
| from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
|
|
|
# Choose the compute device *before* building the model. Per the SAM 2 build
# helpers, the model is moved to a device while it is constructed and that
# device defaults to CUDA, so the CPU fallback only works if it is passed in.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level predictor shared by every request handled by this process.
# NOTE(review): the hub id names "sam2-hiera-large" while the checkpoint path
# names "sam2.1_hiera_large" — confirm which weights are intended and that the
# `checkpoint` keyword is actually honoured by `from_pretrained`.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large",
    checkpoint="checkpoints/sam2.1_hiera_large.pt",
    device=device,
)

# Keep the explicit move/eval so the model state is the same regardless of
# what the builder did with the arguments above.
predictor.model.to(device).eval()
|
|
|
# ASGI application object; the endpoint in this file is attached to it.
app = FastAPI()

# Fully permissive CORS: any origin, any method, any header, credentials on.
# NOTE(review): FastAPI's CORS documentation says wildcard values should not
# be combined with allow_credentials=True — confirm whether credentialed
# cross-origin requests are really needed, or list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
def _bad_request(message):
    """Build a 400 JSON response carrying *message* under the "error" key."""
    return JSONResponse(status_code=400, content={"error": message})


def _decode_rgb_image(image_base64):
    """Decode a base64 string into an RGB PIL image.

    Raises:
        ValueError: the payload is not valid base64.
        OSError: the decoded bytes cannot be read as an image.
    """
    raw = base64.b64decode(image_base64)
    return Image.open(BytesIO(raw)).convert("RGB")


def _mask_to_png(mask):
    """Encode a boolean (H, W) mask as PNG bytes.

    The RGB channels stay zero; alpha is 128 inside the mask and 0 outside.
    """
    height, width = mask.shape
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[..., 3] = mask.astype(np.uint8) * 128

    buf = BytesIO()
    # An (H, W, 4) uint8 array is interpreted as RGBA without an explicit mode.
    Image.fromarray(rgba).save(buf, format="PNG")
    return buf.getvalue()


@app.post("/sam2_segment/")
async def sam2_segment(request: Request):
    """Segment an image around the given click points and return a PNG overlay.

    Expects a JSON object with:
        image: base64-encoded image data.
        point_coords: non-empty list of [x, y] pairs.
        point_labels: list of integer labels, one per point.

    Each point is run through the predictor on its own and the resulting
    masks are OR-ed together.

    Returns:
        200 with an ``image/png`` body (transparent except for the union mask),
        400 with ``{"error": ...}`` when the request is malformed,
        500 with ``{"error": ...}`` for any other failure.
    """
    try:
        # --- Request parsing: anything wrong here is the client's fault. ---
        try:
            data = await request.json()
        except ValueError:  # includes json.JSONDecodeError
            return _bad_request("Request body must be valid JSON.")
        if not isinstance(data, dict):
            return _bad_request("Request body must be a JSON object.")

        image_base64 = data.get("image")
        point_coords = data.get("point_coords", [])
        point_labels = data.get("point_labels", [])

        if not image_base64 or not isinstance(image_base64, str):
            return _bad_request("image must be supplied as a base64 string.")

        if (
            not isinstance(point_coords, list)
            or not isinstance(point_labels, list)
            or not point_coords
            or len(point_coords) != len(point_labels)
        ):
            return _bad_request(
                "point_coords and point_labels must be supplied and have equal length."
            )

        # Convert once up front so malformed entries are rejected before any
        # model work is done.
        try:
            coords = np.asarray(point_coords, dtype=np.float32)
            labels = np.asarray(point_labels, dtype=np.int32)
        except (TypeError, ValueError, OverflowError):
            return _bad_request(
                "point_coords must be a list of [x, y] pairs and point_labels a list of integers."
            )
        if coords.ndim != 2 or coords.shape[1] != 2 or labels.ndim != 1:
            return _bad_request(
                "point_coords must be a list of [x, y] pairs and point_labels a list of integers."
            )

        try:
            pil_img = _decode_rgb_image(image_base64)
        except (ValueError, OSError):
            return _bad_request("image is not a valid base64-encoded image.")

        # --- Inference ---
        np_img = np.array(pil_img)
        union_mask = np.zeros(np_img.shape[:2], dtype=bool)

        # NOTE(review): inference runs synchronously inside this async handler
        # and uses the shared module-level predictor — confirm that serving one
        # request at a time is acceptable.
        with torch.inference_mode():
            predictor.set_image(np_img)
            # Slice so each prompt keeps the (1, 2) / (1,) shapes of a
            # single-point query.
            for point, label in zip(coords[:, None, :], labels[:, None]):
                masks, scores, _ = predictor.predict(
                    point_coords=point,
                    point_labels=label,
                )
                # Several candidate masks may come back; per the SAM 2
                # predictor docs the second return value scores them, so keep
                # the best-rated one rather than whichever happens to be first.
                best = masks[int(np.argmax(scores))]
                union_mask |= best.astype(bool)

        return Response(content=_mask_to_png(union_mask), media_type="image/png")
    except Exception as e:
        # Last-resort boundary: report unexpected failures as a 500.
        print("ERROR:", str(e))
        return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|