File size: 7,143 Bytes
f741a7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import io
import base64
import torch
import numpy as np
import cv2
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Union
from PIL import Image
from transformers import Sam3Processor, Sam3Model

app = FastAPI(title="SAM 3 API", description="Segment Anything Model 3 API for HF Spaces")

# CORS Setup - Allow all origins for simplicity in this demo, restrict in production.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with credentialed requests (the CORS spec forbids it, and the middleware would
# otherwise echo back any requesting origin). The endpoints in this file do not
# read cookies or auth headers, so nothing here needs credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global Model Variables ---
# Inference device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both stay None until the startup hook has loaded them; /predict checks this.
processor = None
model = None

# --- Startup Event ---
@app.on_event("startup")
async def startup_event():
    """Load the SAM 3 processor and model into the module-level globals.

    A load failure is printed instead of raised, so the application still
    starts; the globals that were not assigned keep their previous value.
    """
    global model, processor
    checkpoint = "facebook/sam3"
    print(f"Loading SAM 3 Model on {device}...")
    try:
        processor = Sam3Processor.from_pretrained(checkpoint)
        loaded_model = Sam3Model.from_pretrained(checkpoint)
        model = loaded_model.to(device)
    except Exception as e:
        # In a real deployed environment, we might want to crash or retry.
        # For now, the error is only reported.
        print(f"Error loading model: {e}")
    else:
        print("Model loaded successfully!")

# --- Data Models ---
class Point(BaseModel):
    """A single point prompt, used when ``prompt_type`` is ``'point'``.

    ``predict`` forwards each point to the processor as ``[x, y]`` together
    with its ``label``.
    """
    # Coordinates of the point; presumably image pixels — TODO confirm.
    x: int
    y: int
    label: int  # 1 for positive, 0 for negative

class Box(BaseModel):
    """A box prompt, used when ``prompt_type`` is ``'box'``.

    ``predict`` forwards each box to the processor as ``[x1, y1, x2, y2]``
    together with its ``label``.
    """
    # Box corners; presumably (x1, y1) is top-left and (x2, y2) is
    # bottom-right in image pixels — TODO confirm.
    x1: int
    y1: int
    x2: int
    y2: int
    label: int = 1 # 1 for positive, 0 for negative

class InferenceRequest(BaseModel):
    """JSON body accepted by ``POST /predict``.

    Only the field matching ``prompt_type`` is read by ``predict``; the
    others may be omitted.
    """
    # Base64 encoded image. If the string contains a comma (e.g. a data URL),
    # decode_image keeps only the part after it.
    image: str  # Base64 encoded image
    prompt_type: str  # 'point', 'box', 'text', 'everything'
    # Must be non-empty when prompt_type is 'point'.
    points: Optional[List[Point]] = None
    # Must be non-empty when prompt_type is 'box'.
    boxes: Optional[List[Box]] = None
    # Must be non-empty when prompt_type is 'text'.
    text_prompt: Optional[str] = None

# --- Helper Functions ---
def decode_image(base64_string):
    """Decode a base64 string into an RGB PIL image.

    If the string contains a comma (as in ``data:image/png;base64,<payload>``),
    the segment right after the first comma is used as the payload; otherwise
    the whole string is decoded.
    """
    has_prefix = "," in base64_string
    payload = base64_string.split(",")[1] if has_prefix else base64_string
    raw_bytes = base64.b64decode(payload)
    # Normalise to 3-channel RGB regardless of the source image mode.
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

def encode_image(image: Image.Image):
    """Serialise a PIL image as PNG and return the bytes as a base64 string."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")

def numpy_to_base64_mask(mask_np):
    """Turn a mask array into a base64-encoded PNG string.

    ``mask_np`` is bool or uint8 (0/1); it is scaled by 255 and cast to
    uint8 before being wrapped in a PIL image.
    """
    pixel_values = mask_np * 255
    mask_image = Image.fromarray(pixel_values.astype(np.uint8))
    return encode_image(mask_image)

# --- Endpoints ---

@app.get("/")
def home():
    """Report that the service is up and which device it selected."""
    status_payload = dict(status="running", device=device)
    return status_payload

def _build_model_inputs(request, image):
    """Turn the request's prompt into processor inputs moved to ``device``.

    Args:
        request: The validated ``InferenceRequest``.
        image: The decoded RGB image the prompt refers to.

    Returns:
        The processor output for the selected prompt type, on ``device``.

    Raises:
        HTTPException: 400 when ``prompt_type`` is unknown or the prompt data
            required by that type is missing or empty.
    """
    prompt_type = request.prompt_type

    if prompt_type == "text":
        if not request.text_prompt:
            raise HTTPException(status_code=400, detail="Text prompt required")
        return processor(
            images=image, text=request.text_prompt, return_tensors="pt"
        ).to(device)

    if prompt_type == "box":
        if not request.boxes:
            raise HTTPException(status_code=400, detail="Box prompt required")
        # Single image per request, so both lists get a leading batch dim of 1:
        # boxes are [batch, num_boxes, 4] and labels are [batch, num_boxes]
        # (one scalar label per box, not a nested one-element list).
        input_boxes = [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]]
        input_labels = [[b.label for b in request.boxes]]
        return processor(
            images=image,
            input_boxes=input_boxes,
            input_boxes_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "point":
        if not request.points:
            raise HTTPException(status_code=400, detail="Point prompt required")
        # NOTE(review): confirm that Sam3Processor accepts point prompts and
        # that these nesting levels match what it expects — TODO verify
        # against the transformers SAM3 documentation.
        input_points = [[[p.x, p.y] for p in request.points]]
        input_labels = [[[p.label] for p in request.points]]
        return processor(
            images=image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "everything":
        # No automatic mask generator is wired up here; approximate
        # "segment everything" with a generic text prompt instead.
        return processor(images=image, text="objects", return_tensors="pt").to(device)

    raise HTTPException(status_code=400, detail="Invalid prompt type")


@app.post("/predict")
async def predict(request: InferenceRequest):
    """Run SAM 3 on the submitted image and return the predicted instances.

    Args:
        request: Image (base64) plus the prompt selecting what to segment.

    Returns:
        A dict with ``masks`` (base64 PNG strings, one per instance),
        ``scores``, ``boxes`` ([x1, y1, x2, y2] per instance) and ``count``.

    Raises:
        HTTPException: 503 when the model or processor is not loaded,
            400 for an undecodable image or an invalid/missing prompt,
            500 for any other failure during inference or post-processing.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        try:
            image = decode_image(request.image)
        except (ValueError, OSError) as exc:
            # Bad base64 (binascii.Error is a ValueError) or bytes that are
            # not a readable image: this is a client error, not a server one.
            raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc

        inputs = _build_model_inputs(request, image)

        # Inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=[image.size[::-1]],  # PIL size is (width, height); reversed to (height, width)
        )[0]

        # Convert results to a JSON-serializable format.
        # results['masks'] is a boolean tensor of shape (num_masks, H, W)
        masks = results['masks'].cpu().numpy()
        scores = results['scores'].cpu().numpy().tolist()
        boxes_out = results['boxes'].cpu().numpy().tolist()  # [x1, y1, x2, y2]

        return {
            "masks": [numpy_to_base64_mask(mask) for mask in masks],
            "scores": scores,
            "boxes": boxes_out,
            "count": len(scores),
        }

    except HTTPException:
        # Deliberate 4xx responses raised above must reach the client as-is
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running this file directly.
    import uvicorn

    # Listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)