| | """ |
| | SAM3 Static Image Segmentation - Correct Implementation |
| | |
| | Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation. |
| | """ |
| | import base64 |
| | import io |
| | import asyncio |
| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | from transformers import AutoProcessor, AutoModel |
| | import logging |
| |
|
# Service-wide logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the SAM3 processor and weights from the local ./model directory.
# NOTE(review): assumes ./model holds a SAM3 checkpoint whose modeling code
# ships with the checkpoint — trust_remote_code=True is required for that.
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    # fp16 halves VRAM on GPU; CPU stays fp32 (many fp16 ops are unsupported on CPU).
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)

# Inference only: freeze dropout/norm behavior, then move weights onto the GPU.
model.eval()
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")
| |
|
| | |
class VRAMManager:
    """Gates concurrent GPU inference and reports CUDA memory usage."""

    def __init__(self):
        # At most two requests may run inference at the same time.
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        """Return a snapshot of CUDA memory counters, or {} when no GPU exists."""
        if not torch.cuda.is_available():
            return {}
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        allocated_bytes = torch.cuda.memory_allocated()
        reserved_bytes = torch.cuda.memory_reserved()
        return {
            "total_gb": total_bytes / 1e9,
            "allocated_gb": allocated_bytes / 1e9,
            "free_gb": (total_bytes - reserved_bytes) / 1e9,
            "processing_now": self.processing_count,
        }

    async def acquire(self, rid):
        """Wait for a free processing slot, then claim it."""
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        """Give the slot back and drop any cached CUDA allocations."""
        self.processing_count -= 1
        self.semaphore.release()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
| |
|
# Single shared concurrency/VRAM gate and the ASGI application object.
vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")
| |
|
class Request(BaseModel):
    """HF-Inference-style request body for the segmentation endpoint."""

    # Base64-encoded image bytes.
    inputs: str
    # Options dict; the "classes" key holds the list of text prompts to segment.
    parameters: dict
| |
|
| |
|
def run_inference(image_b64: str, classes: list, request_id: str):
    """
    Sam3Model inference for static images with text prompts.

    Decodes a base64 image, runs one (image, text-prompt) pair per class
    through the model, and converts each predicted mask to a base64 PNG.

    Args:
        image_b64: Base64-encoded image bytes (any PIL-readable format).
        classes: Text prompts; one result entry is produced per prompt.
        request_id: Correlation id used only for logging.

    Returns:
        list[dict]: One entry per class: {"label": str, "mask": base64 PNG
        (L mode, 255 = foreground), "score": float (1.0 fallback)}.

    Raises:
        ValueError: If no mask tensor can be located on the model output.
        Exception: Any decode/inference error is logged and re-raised.
    """
    try:
        if not classes:
            # No prompts: an empty batch would fail inside the processor,
            # so answer immediately with an empty result list.
            return []

        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")

        # The processor expects one image per text prompt, so the same image
        # is repeated len(classes) times.
        inputs = processor(
            images=[pil_image] * len(classes),
            text=classes,
            return_tensors="pt"
        )
        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")

        if torch.cuda.is_available():
            # Float tensors must match the model's dtype (fp16 on GPU);
            # integer tensors (e.g. input_ids) only change device.
            model_dtype = next(model.parameters()).dtype
            moved = {}
            for key, value in inputs.items():
                if isinstance(value, torch.Tensor):
                    value = value.cuda()
                    if value.dtype.is_floating_point:
                        value = value.to(model_dtype)
                moved[key] = value
            inputs = moved
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")

        logger.info(f"[{request_id}] Input keys: {list(inputs.keys())}")

        with torch.no_grad():
            outputs = model(**inputs)
        logger.info(f"[{request_id}] Forward pass successful!")

        pred_masks = _find_pred_masks(outputs, request_id)

        # PIL size is (W, H); tensors and numpy arrays use (H, W).
        target_hw = pil_image.size[::-1]
        num_masks = pred_masks.shape[1] if pred_masks.dim() > 1 else 1
        logger.info(f"[{request_id}] Batch size: {pred_masks.shape[0]}, Num masks: {num_masks}")

        results = []
        for i, cls in enumerate(classes):
            if i < num_masks:
                binary_mask = _binarize_mask(_select_mask(pred_masks, i), target_hw)
            else:
                # Fewer predicted masks than prompts: emit an all-background mask.
                binary_mask = np.zeros(target_hw, dtype="uint8")
            results.append({
                "label": cls,
                "mask": _encode_png_b64(binary_mask),
                "score": _extract_score(outputs, i),
            })

        logger.info(f"[{request_id}] Completed: {len(results)} masks generated")
        return results

    except Exception:
        # logger.exception records the traceback through the logging system
        # (replaces the old inline traceback.print_exc() to stderr).
        logger.exception(f"[{request_id}] Failed")
        raise


def _find_pred_masks(outputs, request_id):
    """Locate the mask tensor on the model output, whatever its exact form."""
    for attr in ("pred_masks", "masks"):
        if hasattr(outputs, attr):
            masks = getattr(outputs, attr)
            logger.info(f"[{request_id}] {attr} shape: {masks.shape}")
            return masks
    if isinstance(outputs, dict) and "pred_masks" in outputs:
        masks = outputs["pred_masks"]
        logger.info(f"[{request_id}] pred_masks shape: {masks.shape}")
        return masks
    logger.error(f"[{request_id}] Unexpected output format")
    raise ValueError("Cannot find masks in model output")


def _select_mask(pred_masks, i):
    """Pick the i-th mask from a (B, N, H, W) or (N, H, W) tensor."""
    if pred_masks.dim() == 4:
        return pred_masks[0, i]
    if pred_masks.dim() == 3:
        return pred_masks[i]
    # Unknown layout: fall back to the raw tensor.
    return pred_masks


def _binarize_mask(mask_tensor, target_hw):
    """Resize a raw mask to (H, W) and threshold it to a uint8 {0, 255} array."""
    # Collapse leading singleton dims so interpolate always sees a 2D plane
    # (guards the fallback path where the raw tensor may carry extra dims).
    while mask_tensor.dim() > 2 and mask_tensor.shape[0] == 1:
        mask_tensor = mask_tensor[0]
    if tuple(mask_tensor.shape[-2:]) != tuple(target_hw):
        mask_tensor = torch.nn.functional.interpolate(
            # .float() keeps bilinear resize working on CPU, where fp16
            # interpolation is not supported.
            mask_tensor.unsqueeze(0).unsqueeze(0).float(),
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        ).squeeze()
    # Logit > 0 corresponds to foreground probability > 0.5.
    return (mask_tensor > 0.0).float().cpu().numpy().astype("uint8") * 255


def _encode_png_b64(mask_array):
    """Encode a 2D uint8 array as a base64 PNG string (L mode)."""
    buf = io.BytesIO()
    Image.fromarray(mask_array, mode="L").save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _extract_score(outputs, i):
    """Best-effort per-mask confidence; 1.0 when the model exposes none."""
    if hasattr(outputs, "pred_scores") and i < outputs.pred_scores.shape[1]:
        return float(outputs.pred_scores[0, i].cpu())
    if hasattr(outputs, "scores") and i < len(outputs.scores):
        score = outputs.scores[i]
        return float(score.cpu() if hasattr(score, "cpu") else score)
    return 1.0
| |
|
| |
|
@app.post("/")
async def predict(req: Request):
    """Segment the request image once per prompted class; errors become HTTP 500."""
    rid = str(id(req))[:8]
    try:
        await vram_manager.acquire(rid)
        try:
            prompts = req.parameters.get("classes", [])
            # The model call is blocking, so it runs in a worker thread to
            # keep the event loop responsive for /health and /metrics.
            return await asyncio.to_thread(run_inference, req.inputs, prompts, rid)
        finally:
            # Always free the VRAM slot, success or failure.
            vram_manager.release(rid)
    except Exception as err:
        logger.error(f"[{rid}] Error: {str(err)}")
        raise HTTPException(status_code=500, detail=str(err))
| |
|
| |
|
@app.get("/health")
async def health():
    """Liveness probe reporting model class, GPU availability and VRAM counters."""
    report = {"status": "healthy"}
    report["model"] = model.__class__.__name__
    report["gpu_available"] = torch.cuda.is_available()
    report["vram"] = vram_manager.get_vram_status()
    return report
| |
|
| |
|
@app.get("/metrics")
async def metrics():
    """Expose the current VRAM/concurrency counters for monitoring."""
    snapshot = vram_manager.get_vram_status()
    return snapshot
| |
|
| |
|
# Run the API directly; single worker because the model lives in this process.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
| |
|