Spaces:
Build error
Build error
import base64
import logging
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Local SAM 2.1 Hiera-L checkpoint and the Hydra config (bundled with the
# sam2 package) that matches it.
SAM2_CHECKPOINT = "checkpoint/sam2.1_hiera_large.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

device = "cuda" if torch.cuda.is_available() else "cpu"

# SAM2ImagePredictor.from_pretrained() expects a Hugging Face model id
# (e.g. "facebook/sam2.1-hiera-large"), not a local file path, so a local
# checkpoint is loaded through build_sam2 instead. Passing device= explicitly
# keeps CPU-only hosts working; mode="eval" puts the model in inference mode.
predictor = SAM2ImagePredictor(
    build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device, mode="eval")
)

app = FastAPI()
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any site make credentialed requests; nothing visible here reads cookies or
# auth headers, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _bad_request(message: str) -> JSONResponse:
    """Return a 400 JSON response of the form ``{"error": message}``."""
    return JSONResponse(status_code=400, content={"error": message})


def _parse_points(
    point_coords: list, point_labels: list
) -> tuple[np.ndarray, np.ndarray]:
    """Validate raw JSON point prompts and convert them to SAM2 input arrays.

    Returns ``(coords, labels)``: ``coords`` is an ``(N, 2)`` float32 array of
    ``[x, y]`` points and ``labels`` an ``(N,)`` int32 array.

    Raises:
        ValueError: if an entry is not an ``[x, y]`` pair of finite numbers, or
            a label is not 0 (background) or 1 (foreground).
    """
    coords_error = "each point_coords entry must be an [x, y] pair of finite numbers."
    try:
        coords = np.asarray(point_coords, dtype=np.float32)
    except (TypeError, ValueError) as exc:  # ragged or non-numeric entries
        raise ValueError(coords_error) from exc
    # np.isfinite also rejects JSON nulls, which numpy converts to NaN.
    if coords.ndim != 2 or coords.shape[1] != 2 or not np.isfinite(coords).all():
        raise ValueError(coords_error)
    if any(label not in (0, 1) for label in point_labels):
        raise ValueError("point_labels entries must be 1 (foreground) or 0 (background).")
    return coords, np.asarray(point_labels, dtype=np.int32)


def _decode_image(image_base64: str) -> np.ndarray:
    """Decode a base64-encoded image into an ``(H, W, 3)`` uint8 RGB array.

    A ``data:<mime>;base64,`` prefix, as produced by browser data URLs, is
    stripped before decoding.

    Raises:
        ValueError: malformed base64 (``binascii.Error`` subclasses ValueError).
        OSError: bytes Pillow cannot read as an image
            (``PIL.UnidentifiedImageError`` subclasses OSError).
        Image.DecompressionBombError: image exceeds Pillow's pixel limit.
    """
    if image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]
    img_bytes = base64.b64decode(image_base64)
    with Image.open(BytesIO(img_bytes)) as pil_img:
        return np.array(pil_img.convert("RGB"))


def _segment_union(
    np_img: np.ndarray, coords: np.ndarray, labels: np.ndarray
) -> np.ndarray:
    """Segment one object per foreground point and OR the masks together.

    Each foreground point (label 1) is prompted together with every background
    point (label 0), so background clicks carve regions out of each object
    instead of being segmented as objects of their own. For each prompt the
    highest-scoring SAM2 candidate mask is kept.

    Returns an ``(H, W)`` bool mask; it is all False when no point is foreground.
    """
    union_mask = np.zeros(np_img.shape[:2], dtype=bool)
    is_fg = labels == 1
    if not is_fg.any():
        return union_mask
    bg_coords = coords[~is_fg]
    bg_labels = labels[~is_fg]
    with torch.inference_mode():
        # The image is set once per request; every predict() below reuses it.
        predictor.set_image(np_img)
        for fg_point in coords[is_fg]:
            prompt_coords = np.concatenate([fg_point[None, :], bg_coords])
            prompt_labels = np.concatenate([[1], bg_labels]).astype(np.int32)
            masks, scores, _ = predictor.predict(
                point_coords=prompt_coords,
                point_labels=prompt_labels,
                multimask_output=True,
            )
            # Candidate masks are not ordered by quality; keep the one with
            # the highest predicted score rather than blindly taking masks[0].
            union_mask |= masks[int(np.argmax(scores))].astype(bool)
    return union_mask


def _mask_to_png(mask: np.ndarray) -> bytes:
    """Encode an ``(H, W)`` bool mask as PNG bytes of an RGBA overlay.

    Pixels inside the mask are black with alpha 128 (semi-transparent); all
    other pixels are fully transparent.
    """
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[..., 3] = np.where(mask, 128, 0)  # 128 = semi-transparent
    buf = BytesIO()
    # An (H, W, 4) uint8 array is inferred as RGBA; the explicit mode=
    # argument to fromarray is deprecated in recent Pillow releases.
    Image.fromarray(rgba).save(buf, format="PNG")
    return buf.getvalue()


# NOTE(review): no @app.post(...) / @app.api_route(...) decorator registers this
# handler on `app` in the visible code, so FastAPI never serves it. Add the
# route decorator with the path the client actually calls.
async def sam2_segment(request: Request):
    """Segment the objects under the given click points and return a PNG overlay.

    Expects a JSON object body::

        {"image": "<base64 image, optionally a data: URL>",
         "point_coords": [[x, y], ...],
         "point_labels": [1 or 0, ...]}

    Returns an ``image/png`` response the size of the input image whose alpha
    channel is 128 inside the union of the segmented objects and 0 elsewhere.
    Malformed requests get a 400 and unexpected failures a 500, both with an
    ``{"error": "..."}`` JSON body.
    """
    try:
        try:
            data = await request.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return _bad_request("Request body must be valid JSON.")
        if not isinstance(data, dict):
            return _bad_request("Request body must be a JSON object.")

        image_base64 = data.get("image")
        point_coords = data.get("point_coords", [])
        point_labels = data.get("point_labels", [])
        if not image_base64 or not isinstance(image_base64, str):
            return _bad_request("image must be a non-empty base64-encoded string.")
        if (
            not isinstance(point_coords, list)
            or not isinstance(point_labels, list)
            or len(point_coords) == 0
            or len(point_coords) != len(point_labels)
        ):
            return _bad_request(
                "point_coords and point_labels must be supplied and have equal length."
            )

        try:
            coords, labels = _parse_points(point_coords, point_labels)
        except ValueError as exc:
            return _bad_request(str(exc))

        try:
            np_img = _decode_image(image_base64)
        except (ValueError, OSError, Image.DecompressionBombError) as exc:
            return _bad_request(f"image could not be decoded: {exc}")

        # Inference runs synchronously here (no await), blocking the event loop
        # for its duration; this also keeps concurrent requests from
        # interleaving set_image()/predict() on the shared module-level
        # predictor. Do not offload to a thread pool without adding a lock.
        union_mask = _segment_union(np_img, coords, labels)
        return Response(content=_mask_to_png(union_mask), media_type="image/png")
    except Exception as exc:
        # Top-level boundary: log the full traceback, report a 500 to the client.
        logging.getLogger(__name__).exception("sam2_segment failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})