""" SAM3 Static Image Segmentation - Correct Implementation Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation. """ import base64 import io import asyncio import torch import numpy as np from PIL import Image from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import AutoProcessor, AutoModel import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Load SAM3 model for STATIC IMAGES processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True) model = AutoModel.from_pretrained( "./model", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True ) model.eval() if torch.cuda.is_available(): model.cuda() logger.info(f"GPU: {torch.cuda.get_device_name()}") logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation") # Simple concurrency control class VRAMManager: def __init__(self): self.semaphore = asyncio.Semaphore(2) self.processing_count = 0 def get_vram_status(self): if not torch.cuda.is_available(): return {} return { "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9, "allocated_gb": torch.cuda.memory_allocated() / 1e9, "free_gb": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1e9, "processing_now": self.processing_count } async def acquire(self, rid): await self.semaphore.acquire() self.processing_count += 1 def release(self, rid): self.processing_count -= 1 self.semaphore.release() if torch.cuda.is_available(): torch.cuda.empty_cache() vram_manager = VRAMManager() app = FastAPI(title="SAM3 Static Image API") class Request(BaseModel): inputs: str parameters: dict def run_inference(image_b64: str, classes: list, request_id: str): """ Sam3Model inference for static images with text prompts According to HuggingFace docs, Sam3Model uses: - processor(images=image, text=text_prompts) - model.forward(pixel_values, input_ids, ...) """ try: # Decode image image_bytes = base64.b64decode(image_b64) pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}") # Process with Sam3Processor # Sam3Model expects: batch of images matching text prompts # For multiple objects in ONE image, repeat the image for each class images_batch = [pil_image] * len(classes) inputs = processor( images=images_batch, # Repeat image for each text prompt text=classes, # List of text prompts return_tensors="pt" ) logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images") # Move to GPU and match model dtype if torch.cuda.is_available(): model_dtype = next(model.parameters()).dtype inputs = { k: v.cuda().to(model_dtype) if isinstance(v, torch.Tensor) and v.dtype.is_floating_point else v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items() } logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})") logger.info(f"[{request_id}] Input keys: {list(inputs.keys())}") # Sam3Model Inference with torch.no_grad(): # Sam3Model.forward() accepts pixel_values, input_ids, etc. outputs = model(**inputs) logger.info(f"[{request_id}] Forward pass successful!") logger.info(f"[{request_id}] Output type: {type(outputs)}") logger.info(f"[{request_id}] Output attributes: {dir(outputs)}") # Extract masks from outputs # Sam3Model returns masks in outputs.pred_masks if hasattr(outputs, 'pred_masks'): pred_masks = outputs.pred_masks logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}") elif hasattr(outputs, 'masks'): pred_masks = outputs.masks logger.info(f"[{request_id}] masks shape: {pred_masks.shape}") elif isinstance(outputs, dict) and 'pred_masks' in outputs: pred_masks = outputs['pred_masks'] logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}") else: logger.error(f"[{request_id}] Unexpected output format") logger.error(f"Output attributes: {dir(outputs) if not isinstance(outputs, dict) else outputs.keys()}") raise ValueError("Cannot find masks in model output") # Process masks results = [] # pred_masks typically: [batch, num_objects, height, width] batch_size = pred_masks.shape[0] num_masks = pred_masks.shape[1] if len(pred_masks.shape) > 1 else 1 logger.info(f"[{request_id}] Batch size: {batch_size}, Num masks: {num_masks}") for i, cls in enumerate(classes): if i < num_masks: # Get mask for this class/object if len(pred_masks.shape) == 4: # [batch, num, h, w] mask_tensor = pred_masks[0, i] # [h, w] elif len(pred_masks.shape) == 3: # [num, h, w] mask_tensor = pred_masks[i] else: mask_tensor = pred_masks # Resize to original size if needed if mask_tensor.shape[-2:] != pil_image.size[::-1]: mask_tensor = torch.nn.functional.interpolate( mask_tensor.unsqueeze(0).unsqueeze(0), size=pil_image.size[::-1], mode='bilinear', align_corners=False ).squeeze() # Convert to binary mask binary_mask = (mask_tensor > 0.0).float().cpu().numpy().astype("uint8") * 255 else: # No mask available for this class binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8") # Convert to PNG pil_mask = Image.fromarray(binary_mask, mode="L") buf = io.BytesIO() pil_mask.save(buf, format="PNG") mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8") # Get confidence score if available score = 1.0 if hasattr(outputs, 'pred_scores') and i < outputs.pred_scores.shape[1]: score = float(outputs.pred_scores[0, i].cpu()) elif hasattr(outputs, 'scores') and i < len(outputs.scores): score = float(outputs.scores[i].cpu() if hasattr(outputs.scores[i], 'cpu') else outputs.scores[i]) results.append({ "label": cls, "mask": mask_b64, "score": score }) logger.info(f"[{request_id}] Completed: {len(results)} masks generated") return results except Exception as e: logger.error(f"[{request_id}] Failed: {str(e)}") import traceback traceback.print_exc() raise @app.post("/") async def predict(req: Request): request_id = str(id(req))[:8] try: await vram_manager.acquire(request_id) try: results = await asyncio.to_thread( run_inference, req.inputs, req.parameters.get("classes", []), request_id ) return results finally: vram_manager.release(request_id) except Exception as e: logger.error(f"[{request_id}] Error: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/health") async def health(): return { "status": "healthy", "model": model.__class__.__name__, "gpu_available": torch.cuda.is_available(), "vram": vram_manager.get_vram_status() } @app.get("/metrics") async def metrics(): return vram_manager.get_vram_status() if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)