import os
import io
import base64
from contextlib import asynccontextmanager
from typing import List, Optional, Union

import torch
import numpy as np
import cv2
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
from transformers import Sam3Processor, Sam3Model

# --- Global Model Variables ---
# Filled in by startup_event(). If loading fails, `model` (and possibly
# `processor`) stays None and the /predict endpoint answers 503.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None


# --- Startup ---
async def startup_event():
    """Load the SAM 3 processor and model into the module-level globals.

    Loading errors are printed rather than raised, so the server still starts;
    in that case the globals that failed to load remain None.
    """
    global model, processor
    print(f"Loading SAM 3 Model on {device}...")
    try:
        processor = Sam3Processor.from_pretrained("facebook/sam3")
        model = Sam3Model.from_pretrained("facebook/sam3").to(device)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # In a real deployed environment, we might want to crash or retry.
        # For now, we print error.


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Load the model before the app begins serving requests.

    Replaces the deprecated ``@app.on_event("startup")`` hook.
    """
    await startup_event()
    yield


app = FastAPI(
    title="SAM 3 API",
    description="Segment Anything Model 3 API for HF Spaces",
    lifespan=lifespan,
)

# CORS Setup - Allow all for simplicity in this demo, restrict in production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Data Models ---
class Point(BaseModel):
    """A single point prompt (x, y) on the input image."""

    x: int
    y: int
    label: int  # 1 for positive, 0 for negative


class Box(BaseModel):
    """A box prompt given as corner coordinates [x1, y1, x2, y2]."""

    x1: int
    y1: int
    x2: int
    y2: int
    label: int = 1  # 1 for positive, 0 for negative


class InferenceRequest(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    image: str  # Base64 encoded image; a "data:...," prefix is stripped
    prompt_type: str  # 'point', 'box', 'text', 'everything'
    points: Optional[List[Point]] = None
    boxes: Optional[List[Box]] = None
    text_prompt: Optional[str] = None


# --- Helper Functions ---
def decode_image(base64_string):
    """Decode a base64 string (optionally a data URL) into an RGB PIL image.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``).
        OSError: if PIL cannot read the decoded bytes as an image.
    """
    if "," in base64_string:
        # Drop a data-URL header such as "data:image/png;base64,".
        base64_string = base64_string.split(",")[1]
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    return image


def encode_image(image: Image.Image):
    """Serialize a PIL image as PNG and return it base64-encoded (no data-URL prefix)."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def numpy_to_base64_mask(mask_np):
    """Encode a 0/1 (or boolean) mask as a base64 PNG with foreground = 255."""
    # mask_np is bool or uint8 (0/1)
    mask_img = Image.fromarray((mask_np * 255).astype(np.uint8))
    return encode_image(mask_img)


def _decode_request_image(data: str) -> Image.Image:
    """Decode the request image, turning malformed input into an HTTP 400."""
    try:
        return decode_image(data)
    except (ValueError, OSError) as e:
        # binascii.Error subclasses ValueError; PIL.UnidentifiedImageError subclasses OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e


def _build_inputs(request: InferenceRequest, image: Image.Image):
    """Build processor inputs (moved to ``device``) for the request's prompt type.

    Raises:
        HTTPException: 400 when the prompt type is unknown or its prompt data is missing.
    """
    prompt_type = request.prompt_type

    if prompt_type == "text":
        if not request.text_prompt:
            raise HTTPException(status_code=400, detail="Text prompt required")
        return processor(images=image, text=request.text_prompt, return_tensors="pt").to(device)

    if prompt_type == "box":
        if not request.boxes:
            raise HTTPException(status_code=400, detail="Box prompt required")
        # Boxes: [batch=1][num_boxes][x1, y1, x2, y2]; labels: [batch=1][num_boxes].
        input_boxes = [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]]
        input_boxes_labels = [[b.label for b in request.boxes]]
        return processor(
            images=image,
            input_boxes=input_boxes,
            input_boxes_labels=input_boxes_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "point":
        if not request.points:
            raise HTTPException(status_code=400, detail="Point prompt required")
        # NOTE(review): confirm that Sam3Processor/Sam3Model accept point prompts
        # (transformers ships a separate SAM 3 tracker model for points) and that
        # this label nesting matches what the processor expects.
        # Format: [[ [x, y], ... ]] - Batch size 1
        input_points = [[[p.x, p.y] for p in request.points]]
        input_labels = [[[p.label] for p in request.points]]
        return processor(
            images=image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "everything":
        # No automatic "segment everything" generator is wired up, so approximate
        # it with a generic concept prompt.
        return processor(images=image, text="objects", return_tensors="pt").to(device)

    raise HTTPException(status_code=400, detail="Invalid prompt type")


# --- Endpoints ---
@app.get("/")
def home():
    """Health check: report that the server is running and which device it uses."""
    return {"status": "running", "device": device}


@app.post("/predict")
async def predict(request: InferenceRequest):
    """Segment the request image with SAM 3 using the requested prompt.

    Returns:
        A dict with ``masks`` (base64 PNGs), ``scores``, ``boxes``
        ([x1, y1, x2, y2] per mask) and ``count``.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for a malformed image
            or prompt, 500 for any other failure during inference.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        image = _decode_request_image(request.image)
        inputs = _build_inputs(request, image)

        # Inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process; PIL size is (width, height), the processor wants (height, width).
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=[image.size[::-1]],
        )[0]

        # Convert results to JSON-serializable format
        # results['masks'] is a boolean tensor of shape (num_masks, H, W)
        masks = results["masks"].cpu().numpy()
        scores = results["scores"].cpu().numpy().tolist()
        boxes_out = results["boxes"].cpu().numpy().tolist()  # [x1, y1, x2, y2]

        return {
            "masks": [numpy_to_base64_mask(mask) for mask in masks],
            "scores": scores,
            "boxes": boxes_out,
            "count": len(scores),
        }
    except HTTPException:
        # Deliberate client/server errors raised above keep their status code
        # instead of being rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)