| import os
|
| import io
|
| import base64
|
| import torch
|
| import numpy as np
|
| import cv2
|
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from pydantic import BaseModel
|
| from typing import List, Optional, Union
|
| from PIL import Image
|
| from transformers import Sam3Processor, Sam3Model
|
|
|
# ASGI application object; routes and event handlers below register on it.
app = FastAPI(title="SAM 3 API", description="Segment Anything Model 3 API for HF Spaces")

# Open CORS policy: any origin, method and header may call the API.
# Credentials stay disabled on purpose: a wildcard origin must not be combined
# with credentialed requests (the CORS middleware documentation requires
# explicit origins when allow_credentials is True), and this service defines
# no cookies or authentication that would need them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Module-level slots for the processor/model pair; they hold None until the
# startup handler assigns them.
processor = None
model = None
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the SAM 3 processor and model into the module-level globals.

    Any loading failure is printed and swallowed so the application still
    starts; a global that was not assigned keeps its previous value.
    """
    global model, processor

    checkpoint = "facebook/sam3"
    print(f"Loading SAM 3 Model on {device}...")
    try:
        processor = Sam3Processor.from_pretrained(checkpoint)
        loaded_model = Sam3Model.from_pretrained(checkpoint)
        model = loaded_model.to(device)
    except Exception as exc:
        print(f"Error loading model: {exc}")
    else:
        print("Model loaded successfully!")
|
|
|
|
|
|
|
|
|
class Point(BaseModel):
    """A single point prompt supplied by the client."""

    # Horizontal coordinate, forwarded to the processor as the first element
    # of an [x, y] pair (presumably image pixels -- confirm with the client).
    x: int
    # Vertical coordinate, forwarded as the second element of the pair.
    y: int
    # Label forwarded to the processor alongside the point.
    # NOTE(review): presumably 1 = foreground, 0 = background -- confirm.
    label: int
|
|
|
class Box(BaseModel):
    """A box prompt supplied by the client, forwarded as [x1, y1, x2, y2]."""

    # First corner of the box (presumably top-left, in pixels -- confirm).
    x1: int
    y1: int
    # Second corner of the box (presumably bottom-right, in pixels -- confirm).
    x2: int
    y2: int
    # Label forwarded to the processor with the box; defaults to 1.
    # NOTE(review): presumably 1 = positive, 0 = negative -- confirm.
    label: int = 1
|
|
|
class InferenceRequest(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Base64-encoded image; a "data:...," style header before a comma is
    # stripped by decode_image before decoding.
    image: str
    # Selects the prompt branch in predict: "text", "box", "point" or
    # "everything"; any other value is rejected there.
    prompt_type: str
    # Point prompts, required when prompt_type is "point".
    points: Optional[List[Point]] = None
    # Box prompts, required when prompt_type is "box".
    boxes: Optional[List[Box]] = None
    # Text prompt, required when prompt_type is "text".
    text_prompt: Optional[str] = None
|
|
|
|
|
def decode_image(base64_string):
    """Decode a base64 string (optionally a data URL) into an RGB PIL image.

    When the string contains a comma, only the segment between the first and
    second comma is decoded, which drops a ``data:...;base64,`` style header.
    """
    payload = base64_string.split(",")[1] if "," in base64_string else base64_string
    raw_bytes = base64.b64decode(payload)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream).convert("RGB")
|
|
|
def encode_image(image: Image.Image):
    """Serialize *image* as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
|
|
def numpy_to_base64_mask(mask_np):
    """Convert a mask array into a base64-encoded PNG string.

    The array is multiplied by 255 and cast to ``uint8`` before being wrapped
    in a PIL image (presumably masks hold 0/1 or boolean values, giving a
    black/white image -- confirm against the post-processing output).
    """
    pixels = (mask_np * 255).astype(np.uint8)
    return encode_image(Image.fromarray(pixels))
|
|
|
|
|
|
|
@app.get("/")
def home():
    """Report that the service is running and which device it selected."""
    return dict(status="running", device=device)
|
|
|
def _decode_request_image(encoded):
    """Decode the request image, mapping malformed payloads to HTTP 400."""
    try:
        return decode_image(encoded)
    except (ValueError, OSError) as exc:
        # binascii.Error (bad base64) is a ValueError; PIL raises OSError
        # subclasses for unreadable or truncated image data.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc


def _build_inputs(request, image):
    """Turn the request's prompt into processor inputs moved to ``device``.

    Raises:
        HTTPException: 400 when the prompt required by ``prompt_type`` is
            missing or ``prompt_type`` is not a supported value.
    """
    prompt_type = request.prompt_type

    if prompt_type == "text":
        if not request.text_prompt:
            raise HTTPException(status_code=400, detail="Text prompt required")
        prompt = {"text": request.text_prompt}

    elif prompt_type == "box":
        if not request.boxes:
            raise HTTPException(status_code=400, detail="Box prompt required")
        # Documented layout: boxes are [batch, num_boxes, 4] and their labels
        # [batch, num_boxes], so each label is a scalar rather than a list.
        prompt = {
            "input_boxes": [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]],
            "input_boxes_labels": [[b.label for b in request.boxes]],
        }

    elif prompt_type == "point":
        if not request.points:
            raise HTTPException(status_code=400, detail="Point prompt required")
        # NOTE(review): the documented Sam3Processor prompts are text and
        # boxes; confirm that it accepts input_points/input_labels (point
        # prompts may require the separate tracker model and processor).
        prompt = {
            "input_points": [[[p.x, p.y] for p in request.points]],
            "input_labels": [[[p.label] for p in request.points]],
        }

    elif prompt_type == "everything":
        # No prompt-free mode is used here: "everything" is approximated by
        # the generic text concept "objects".
        prompt = {"text": "objects"}

    else:
        raise HTTPException(status_code=400, detail="Invalid prompt type")

    return processor(images=image, return_tensors="pt", **prompt).to(device)


@app.post("/predict")
async def predict(request: InferenceRequest):
    """Run SAM 3 on the request image with the selected prompt.

    Args:
        request: Image plus prompt description; ``prompt_type`` selects which
            of ``text_prompt``, ``boxes`` or ``points`` is used.

    Returns:
        A dict with ``masks`` (one base64 PNG string per instance),
        ``scores``, ``boxes`` and ``count`` (number of instances).

    Raises:
        HTTPException: 503 while the model or processor is not loaded, 400
            for an undecodable image or an invalid/missing prompt, and 500
            for any other failure during inference.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        image = _decode_request_image(request.image)
        inputs = _build_inputs(request, image)

        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); target sizes are given as (height, width).
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=[image.size[::-1]],
        )[0]

        masks = results["masks"].cpu().numpy()
        scores = results["scores"].cpu().numpy().tolist()
        boxes_out = results["boxes"].cpu().numpy().tolist()

        return {
            "masks": [numpy_to_base64_mask(mask) for mask in masks],
            "scores": scores,
            "boxes": boxes_out,
            "count": len(scores),
        }

    except HTTPException:
        # Deliberate client-facing errors (400s) must not be rewrapped as 500.
        raise
    except Exception as exc:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|
|
|