# File: sam3/src/app.py
# Author: Thibaut — commit 647f69c ("Reorganize repository with clean
# separation of concerns"), 8.38 kB.
"""
SAM3 Static Image Segmentation - Correct Implementation
Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation.
"""
import base64
import io
import asyncio
import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModel
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load SAM3 model for STATIC IMAGES
# The checkpoint is expected at ./model (bundled alongside the app);
# trust_remote_code is required because SAM3 ships custom modeling code.
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    # fp16 halves VRAM on GPU; CPU fallback stays fp32 (fp16 CPU ops are limited).
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)
model.eval()  # inference-only: disables dropout/batch-norm updates
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")
# Simple concurrency control
class VRAMManager:
    """Gate concurrent GPU inference and report VRAM usage.

    A semaphore caps in-flight inference at two requests; a plain counter
    tracks how many are running so /health and /metrics can report it.
    """

    def __init__(self):
        # At most two requests may run inference simultaneously.
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        """Return a GPU-memory snapshot in GB, or {} when CUDA is absent."""
        if not torch.cuda.is_available():
            return {}
        props = torch.cuda.get_device_properties(0)
        return {
            "total_gb": props.total_memory / 1e9,
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "free_gb": (props.total_memory - torch.cuda.memory_reserved()) / 1e9,
            "processing_now": self.processing_count,
        }

    async def acquire(self, rid):
        """Wait for a free slot, then count this request as running."""
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        """Free this request's slot and release cached GPU memory."""
        self.processing_count -= 1
        self.semaphore.release()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Module-level singletons: one concurrency gate shared by all requests,
# and the FastAPI application object.
vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")
class Request(BaseModel):
    """Request body for the "/" endpoint.

    NOTE(review): the name shadows FastAPI's own `Request`; harmless here
    because `fastapi.Request` is never imported in this module.
    """
    inputs: str  # base64-encoded image bytes
    parameters: dict  # expected to contain {"classes": [text prompt, ...]}
def run_inference(image_b64: str, classes: list, request_id: str):
    """Run text-prompted SAM3 segmentation on one static image.

    Args:
        image_b64: Base64-encoded image bytes in any PIL-readable format.
        classes: Text prompts, one per object class to segment.
        request_id: Short identifier used to tag log lines.

    Returns:
        list[dict]: One entry per class with keys "label" (the prompt),
        "mask" (base64-encoded single-channel PNG sized like the input
        image; 255 = object, 0 = background) and "score" (confidence,
        1.0 when the model reports none).

    Raises:
        Exception: any decode/inference error is logged and re-raised.
    """
    # Robustness fix: an empty prompt list would otherwise reach the
    # processor as an empty image batch and fail deep inside the model call.
    if not classes:
        return []
    try:
        # Decode the input image.
        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")
        # Sam3Model pairs each text prompt with one image, so the single
        # input image is repeated once per class to form the batch.
        images_batch = [pil_image] * len(classes)
        inputs = processor(
            images=images_batch,
            text=classes,
            return_tensors="pt"
        )
        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
        if torch.cuda.is_available():
            # Cast only floating-point tensors to the model dtype (fp16 on
            # GPU); integer tensors such as input_ids must keep theirs.
            model_dtype = next(model.parameters()).dtype
            moved = {}
            for key, value in inputs.items():
                if isinstance(value, torch.Tensor):
                    value = value.cuda()
                    if value.dtype.is_floating_point:
                        value = value.to(model_dtype)
                moved[key] = value
            inputs = moved
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")
            logger.info(f"[{request_id}] Input keys: {list(inputs.keys())}")
        # Inference only; no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)
        logger.info(f"[{request_id}] Forward pass successful!")
        logger.info(f"[{request_id}] Output type: {type(outputs)}")
        pred_masks = _find_pred_masks(outputs, request_id)
        logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")
        # Assumed layout: [batch, num_objects, h, w] — TODO(review):
        # confirm against the Sam3Model output documentation.
        num_masks = pred_masks.shape[1] if pred_masks.ndim > 1 else 1
        logger.info(f"[{request_id}] Batch size: {pred_masks.shape[0]}, Num masks: {num_masks}")
        target_hw = pil_image.size[::-1]  # PIL size is (w, h); tensors use (h, w)
        results = []
        for i, cls in enumerate(classes):
            if i < num_masks:
                binary_mask = _mask_to_binary(pred_masks, i, target_hw)
            else:
                # No mask produced for this class: emit an all-background mask.
                binary_mask = np.zeros(target_hw, dtype="uint8")
            # Encode the mask as a base64 PNG string.
            pil_mask = Image.fromarray(binary_mask, mode="L")
            buf = io.BytesIO()
            pil_mask.save(buf, format="PNG")
            mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
            results.append({
                "label": cls,
                "mask": mask_b64,
                "score": _score_for(outputs, i)
            })
        logger.info(f"[{request_id}] Completed: {len(results)} masks generated")
        return results
    except Exception as e:
        logger.error(f"[{request_id}] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


def _find_pred_masks(outputs, request_id):
    """Locate the mask tensor on the model output, whatever its spelling."""
    if hasattr(outputs, 'pred_masks'):
        return outputs.pred_masks
    if hasattr(outputs, 'masks'):
        return outputs.masks
    if isinstance(outputs, dict) and 'pred_masks' in outputs:
        return outputs['pred_masks']
    logger.error(f"[{request_id}] Unexpected output format")
    logger.error(f"Output attributes: {dir(outputs) if not isinstance(outputs, dict) else outputs.keys()}")
    raise ValueError("Cannot find masks in model output")


def _mask_to_binary(pred_masks, i, target_hw):
    """Select mask i, resize it to target_hw, and threshold to a 0/255 uint8 image."""
    if pred_masks.ndim == 4:    # [batch, num, h, w]
        mask_tensor = pred_masks[0, i]
    elif pred_masks.ndim == 3:  # [num, h, w]
        mask_tensor = pred_masks[i]
    else:                       # already a single [h, w] mask
        mask_tensor = pred_masks
    if tuple(mask_tensor.shape[-2:]) != tuple(target_hw):
        # Bug fix vs. the original: squeeze only the two dims added for
        # interpolate. A bare .squeeze() would also drop a genuine
        # height/width dimension of size 1.
        mask_tensor = torch.nn.functional.interpolate(
            mask_tensor.unsqueeze(0).unsqueeze(0),
            size=target_hw,
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)
    # Logits > 0 count as foreground.
    return (mask_tensor > 0.0).float().cpu().numpy().astype("uint8") * 255


def _score_for(outputs, i):
    """Best-effort confidence for mask i; defaults to 1.0 when unavailable."""
    if hasattr(outputs, 'pred_scores') and i < outputs.pred_scores.shape[1]:
        return float(outputs.pred_scores[0, i].cpu())
    if hasattr(outputs, 'scores') and i < len(outputs.scores):
        s = outputs.scores[i]
        return float(s.cpu() if hasattr(s, 'cpu') else s)
    return 1.0
@app.post("/")
async def predict(req: Request):
    """Segment the submitted image for each requested class.

    Inference runs in a worker thread (it is blocking torch code) behind
    the VRAM concurrency gate; any failure is surfaced as HTTP 500.
    """
    rid = str(id(req))[:8]
    try:
        await vram_manager.acquire(rid)
        try:
            return await asyncio.to_thread(
                run_inference,
                req.inputs,
                req.parameters.get("classes", []),
                rid,
            )
        finally:
            # Always free the slot, even when inference raises.
            vram_manager.release(rid)
    except Exception as exc:
        logger.error(f"[{rid}] Error: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
@app.get("/health")
async def health():
    """Report liveness plus the loaded model class, GPU and VRAM status."""
    report = {"status": "healthy"}
    report["model"] = type(model).__name__
    report["gpu_available"] = torch.cuda.is_available()
    report["vram"] = vram_manager.get_vram_status()
    return report
@app.get("/metrics")
async def metrics():
    """Expose the current VRAM/concurrency snapshot for scraping."""
    snapshot = vram_manager.get_vram_status()
    return snapshot
if __name__ == "__main__":
    import uvicorn
    # Single worker: the model is loaded once at import time and there is
    # only one GPU to share; port 7860 is the Hugging Face Spaces default.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)