"""FastAPI service exposing Segment Anything (SAM) for CVAT auto-annotation.

Loads a ViT-B SAM checkpoint once at startup and serves point-prompted
segmentation over three endpoints: a health check, a generic /segment
upload endpoint, and a CVAT-specific /predict endpoint returning RLE masks.
"""

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import numpy as np
import torch
import io
import base64
import json

app = FastAPI()

# Allow cross-origin requests so the CVAT frontend can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the SAM model once at import time; use GPU when available.
sam_checkpoint = "sam_vit_b.pth"
model_type = "vit_b"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam)


@app.get("/")
def read_root():
    """Health-check endpoint."""
    return {"status": "SAM API is running"}


@app.post("/segment")
async def segment_image(file: UploadFile = File(...)):
    """Segment an uploaded image using a single point prompt at its center.

    Returns the highest-scoring candidate mask as a nested boolean list
    together with its confidence score. Any failure is surfaced as a 500.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_np = np.array(image)

        # Prompt SAM with the image center rather than a fixed coordinate.
        height, width = image_np.shape[:2]
        center_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])  # 1 = foreground point

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=center_point,
            point_labels=input_label,
            multimask_output=True,  # request several candidate masks
        )

        # Keep only the highest-scoring candidate.
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx].astype(bool)

        return {
            "score": float(scores[best_mask_idx]),
            "mask": mask.tolist(),
        }
    except Exception as e:
        # Chain the original exception so the server log keeps the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/models")
def list_models():
    """Model-discovery endpoint describing what this service provides."""
    return {
        "models": [
            {
                "name": "sam-cvat",
                "type": "segmentation",
                "labels": ["object"],
            }
        ]
    }


@app.post("/predict")
async def predict_for_cvat(body: str = Form(...)):
    """CVAT-specific prediction endpoint.

    Expects a JSON form field containing a base64-encoded ``image`` and an
    optional ``points`` list of [x, y] prompts; falls back to the image
    center when no points are supplied. Returns the best mask in CVAT's
    run-length-encoded format. Any failure is surfaced as a 500.
    """
    try:
        data = json.loads(body)
        image_data = data.get('image', '')

        # Decode the base64-encoded image payload.
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_np = np.array(image)

        # Point prompts from the CVAT request; default to the image center.
        points = data.get('points', [])
        if not points:
            height, width = image_np.shape[:2]
            points = [[width // 2, height // 2]]

        input_points = np.array(points)
        input_labels = np.ones(len(points))  # every prompt is foreground

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Keep only the highest-scoring candidate.
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx].astype(bool)

        # Convert the mask to the shape/RLE payload CVAT expects.
        height, width = mask.shape
        rle = mask_to_rle(mask)

        return {
            "model": "sam-cvat",
            "annotations": [{
                "name": "object",
                "score": float(scores[best_mask_idx]),
                "mask": {
                    "rle": rle,
                    "width": width,
                    "height": height,
                },
                "type": "mask",
            }],
        }
    except Exception as e:
        # Chain the original exception so the server log keeps the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e


def mask_to_rle(mask):
    """Convert a 2-D boolean mask to the RLE format expected by CVAT.

    The encoding is a flat list of run lengths over the row-major flattened
    mask, alternating background/foreground and always starting with a
    background (0) run — a leading 0 is emitted when the first pixel is
    foreground. An empty mask encodes as ``[0]``.

    Vectorized with NumPy instead of a per-pixel Python loop; output is
    identical to the previous element-by-element implementation.
    """
    flat = np.asarray(mask).flatten().astype(np.uint8)
    if flat.size == 0:
        return [0]
    # Positions where the value changes mark run boundaries.
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], change, [flat.size]))
    runs = np.diff(boundaries).tolist()
    # CVAT RLE starts with a background run; prepend 0 if the mask starts at 1.
    if flat[0] == 1:
        runs.insert(0, 0)
    return runs