# File size: 7,143 Bytes (revision f741a7b)
import os
import io
import base64
import torch
import numpy as np
import cv2
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Union
from PIL import Image
from transformers import Sam3Processor, Sam3Model
app = FastAPI(title="SAM 3 API", description="Segment Anything Model 3 API for HF Spaces")

# CORS setup: every origin is allowed for simplicity in this demo; restrict
# `allow_origins` to the known front-end hosts in production.
# Credentials are disabled on purpose: FastAPI's CORS documentation states
# that origins, methods and headers must be listed explicitly (no "*") when
# `allow_credentials` is True, and nothing in this API relies on cookies or
# other browser credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Global Model Variables ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Placeholders at import time; the startup hook assigns the loaded objects.
model, processor = None, None
# --- Startup Event ---
@app.on_event("startup")
async def startup_event():
    """Load the SAM 3 processor and model into the module-level globals.

    Loading errors are printed instead of raised, so the application still
    starts; `model` / `processor` then keep whatever value they had when the
    failure happened.
    """
    global model, processor

    checkpoint = "facebook/sam3"
    print(f"Loading SAM 3 Model on {device}...")
    try:
        processor = Sam3Processor.from_pretrained(checkpoint)
        loaded_model = Sam3Model.from_pretrained(checkpoint)
        model = loaded_model.to(device)
        print("Model loaded successfully!")
    except Exception as exc:
        # Deliberately non-fatal for now. In a real deployment we might want
        # to crash or retry instead of only reporting the error.
        print(f"Error loading model: {exc}")
# --- Data Models ---
class Point(BaseModel):
    """One point prompt; `x` and `y` are forwarded to the processor as an [x, y] pair."""

    # NOTE(review): presumably pixel coordinates in the uploaded image -- confirm.
    x: int
    y: int
    label: int  # 1 for positive, 0 for negative
class Box(BaseModel):
    """One box prompt, forwarded to the processor as [x1, y1, x2, y2]."""

    # NOTE(review): presumably (x1, y1) is the top-left and (x2, y2) the
    # bottom-right corner in pixels -- confirm against the processor docs.
    x1: int
    y1: int
    x2: int
    y2: int
    label: int = 1  # 1 for positive, 0 for negative
class InferenceRequest(BaseModel):
    """Request body of the /predict endpoint.

    Which optional field must be filled depends on `prompt_type`: `points`
    for 'point', `boxes` for 'box', `text_prompt` for 'text'; 'everything'
    needs none of them.
    """

    image: str  # Base64 encoded image (a "data:...," URL prefix is tolerated by decode_image)
    prompt_type: str  # 'point', 'box', 'text', 'everything'
    points: Optional[List[Point]] = None
    boxes: Optional[List[Box]] = None
    text_prompt: Optional[str] = None
# --- Helper Functions ---
def decode_image(base64_string):
    """Decode a base64 string (optionally a data URL) into an RGB PIL image."""
    if "," in base64_string:
        # Keep only the segment right after the first comma; for a data URL
        # the part before it is the "data:<mime>;base64" header.
        _header, base64_string, *_rest = base64_string.split(",")

    raw_bytes = base64.b64decode(base64_string)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream).convert("RGB")
def encode_image(image: Image.Image):
    """Serialize `image` as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def numpy_to_base64_mask(mask_np):
    """Convert a binary mask array into a base64-encoded PNG string."""
    # mask_np is bool or uint8 (0/1); scale it so set pixels become 255.
    scaled = np.multiply(mask_np, 255)
    pixels = scaled.astype(np.uint8)
    return encode_image(Image.fromarray(pixels))
# --- Endpoints ---
@app.get("/")
def home():
    """Report that the service is running and which torch device it uses."""
    payload = {"status": "running"}
    payload["device"] = device
    return payload
def _build_inputs(request, image):
    """Turn the request's prompt into processor inputs moved to `device`.

    Args:
        request: The validated InferenceRequest.
        image: The decoded RGB PIL image.

    Returns:
        The processor output, already moved to `device`.

    Raises:
        HTTPException: 400 when the prompt required by `prompt_type` is
            missing or when `prompt_type` is not one of the known values.
    """
    prompt_type = request.prompt_type

    if prompt_type == "text":
        if not request.text_prompt:
            raise HTTPException(status_code=400, detail="Text prompt required")
        return processor(images=image, text=request.text_prompt, return_tensors="pt").to(device)

    if prompt_type == "box":
        if not request.boxes:
            raise HTTPException(status_code=400, detail="Box prompt required")
        # Batch size 1. Boxes are nested [image, box, 4 coords]; labels are
        # nested [image, box], i.e. one scalar label per box.
        input_boxes = [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]]
        input_boxes_labels = [[b.label for b in request.boxes]]
        return processor(
            images=image,
            input_boxes=input_boxes,
            input_boxes_labels=input_boxes_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "point":
        if not request.points:
            raise HTTPException(status_code=400, detail="Point prompt required")
        # Batch size 1.
        # NOTE(review): nesting kept exactly as originally written; confirm
        # that Sam3Processor accepts point prompts and which nesting depth it
        # expects for `input_points` / `input_labels`.
        input_points = [[[p.x, p.y] for p in request.points]]
        input_labels = [[[p.label] for p in request.points]]
        return processor(
            images=image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "everything":
        # No automatic mask generator is wired up here, so "everything" is
        # approximated with the generic text prompt "objects".
        return processor(images=image, text="objects", return_tensors="pt").to(device)

    raise HTTPException(status_code=400, detail="Invalid prompt type")


@app.post("/predict")
async def predict(request: InferenceRequest):
    """Segment the uploaded image according to the requested prompt.

    Returns:
        A dict with `masks` (base64 PNG strings, one per mask), `scores`,
        `boxes` ([x1, y1, x2, y2] per mask) and `count`.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for invalid
            image data or an invalid/missing prompt, 500 for any other
            failure during preprocessing, inference or post-processing.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        try:
            image = decode_image(request.image)
        except (ValueError, OSError) as exc:
            # Malformed base64 (ValueError) or bytes that are not a readable
            # image (OSError) are client errors, not server faults.
            raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc

        inputs = _build_inputs(request, image)

        # NOTE: inference is synchronous and runs inside this coroutine.
        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=[image.size[::-1]],  # PIL gives (width, height); this wants (height, width)
        )[0]

        # Convert results to a JSON-serializable format.
        # results['masks'] is a boolean tensor of shape (num_masks, H, W)
        masks = results['masks'].cpu().numpy()
        scores = results['scores'].cpu().numpy().tolist()
        boxes_out = results['boxes'].cpu().numpy().tolist()  # [x1, y1, x2, y2]

        return {
            "masks": [numpy_to_base64_mask(mask) for mask in masks],
            "scores": scores,
            "boxes": boxes_out,
            "count": len(scores),
        }
    except HTTPException:
        # Deliberate 4xx responses must reach the client unchanged instead of
        # being re-wrapped as a 500 by the generic handler below.
        raise
    except Exception as exc:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(exc)) from exc
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|