| | """ |
| | SAM3 Static Image Segmentation - Correct Implementation |
| | |
| | Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation. |
| | """ |
| | import base64 |
| | import io |
| | import asyncio |
| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | from transformers import AutoProcessor, AutoModel |
| | import logging |
| |
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the SAM3 processor and model from a local checkpoint directory.
# trust_remote_code=True is set — presumably the checkpoint ships custom
# modeling code (TODO confirm). Half precision is used only when a GPU exists.
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)

# Inference-only service: switch off training-mode layers and move to GPU.
model.eval()
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")
|
| | |
class VRAMManager:
    """Caps concurrent inference requests and reports CUDA memory usage."""

    def __init__(self):
        # At most two requests may run inference at the same time.
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        """Return a snapshot of CUDA memory stats, or {} when no GPU is present."""
        if not torch.cuda.is_available():
            return {}
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        snapshot = {
            "total_gb": total_bytes / 1e9,
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "free_gb": (total_bytes - torch.cuda.memory_reserved()) / 1e9,
            "processing_now": self.processing_count,
        }
        return snapshot

    async def acquire(self, rid):
        """Wait for a free slot, then claim it for request *rid*."""
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        """Return the slot held by *rid* and flush the CUDA allocator cache."""
        self.processing_count -= 1
        self.semaphore.release()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
| |
|
# Module-level singletons shared by every request handler.
vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")
| |
|
class Request(BaseModel):
    """POST body schema for the segmentation endpoint.

    NOTE(review): the name shadows the common FastAPI ``Request`` concept;
    here it is used only as the pydantic body model.
    """

    # Base64-encoded image bytes.
    inputs: str
    # Expected to carry a "classes" list of text prompts (see predict()).
    parameters: dict
| |
|
| |
|
def run_inference(image_b64: str, classes: list, request_id: str) -> list:
    """
    Sam3Model inference for static images with text prompts.

    Uses official SAM3 processor post-processing for correct mask generation,
    with a manual fallback if post-processing fails.

    Args:
        image_b64: Base64-encoded image bytes (any PIL-readable format).
        classes: Text prompts; the image is repeated once per class so each
            batch element carries a single prompt.
        request_id: Short identifier used only for log correlation.

    Returns:
        List of dicts with "label", base64-PNG "mask", "score" (and
        "instance_id" on the primary path).

    Raises:
        Re-raises any exception after logging the traceback.
    """
    try:
        # Decode the payload into an RGB PIL image.
        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")

        # One (image, text) pair per class: the same image is repeated so the
        # batch dimension enumerates the prompts.
        images_batch = [pil_image] * len(classes)
        inputs = processor(
            images=images_batch,
            text=classes,
            return_tensors="pt"
        )

        # PIL size is (W, H); build (H, W) entries, one per batch element.
        original_size = [pil_image.size[1], pil_image.size[0]]
        original_sizes = torch.tensor([original_size] * len(classes))
        inputs["original_sizes"] = original_sizes

        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
        logger.info(f"[{request_id}] Original size: {pil_image.size} (W x H)")

        # Move tensors to GPU; cast only floating-point tensors to the model's
        # dtype (fp16 on GPU) so integer tensors keep their dtype.
        if torch.cuda.is_available():
            model_dtype = next(model.parameters()).dtype
            inputs = {
                k: v.cuda().to(model_dtype) if isinstance(v, torch.Tensor) and v.dtype.is_floating_point else v.cuda() if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")

        with torch.no_grad():
            outputs = model(**inputs)
            logger.info(f"[{request_id}] Forward pass successful!")

        logger.info(f"[{request_id}] Output type: {type(outputs)}")

        logger.info(f"[{request_id}] Using processor.post_process_instance_segmentation()")

        try:
            # Primary path: let the processor build per-instance masks/scores.
            processed = processor.post_process_instance_segmentation(
                outputs,
                threshold=0.3,
                mask_threshold=0.5,
                target_sizes=original_sizes.tolist()
            )

            logger.info(f"[{request_id}] Post-processing successful!")
            logger.info(f"[{request_id}] Number of batched results: {len(processed)}")

        except Exception as proc_error:
            # Fallback path: build one binary mask per class directly from the
            # raw model output, then return from inside this handler.
            logger.error(f"[{request_id}] Post-processing failed: {proc_error}")
            logger.info(f"[{request_id}] Falling back to manual processing")

            results = []

            # Raw masks live under different attribute names depending on the
            # output class; probe the known variants.
            if hasattr(outputs, 'pred_masks'):
                pred_masks = outputs.pred_masks
            elif hasattr(outputs, 'masks'):
                pred_masks = outputs.masks
            elif isinstance(outputs, dict) and 'pred_masks' in outputs:
                pred_masks = outputs['pred_masks']
            else:
                raise ValueError("Cannot find masks in model output")

            logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")

            for i, cls in enumerate(classes):
                # NOTE(review): indexes pred_masks[0, i] — i.e. assumes dim 1
                # enumerates classes of batch element 0, while the batching
                # above puts one class per batch element (dim 0). Confirm
                # against the actual pred_masks layout.
                if i < pred_masks.shape[1]:
                    mask_tensor = pred_masks[0, i]

                    # Upsample mask logits to full image resolution (H, W).
                    if mask_tensor.shape[-2:] != pil_image.size[::-1]:
                        mask_tensor = torch.nn.functional.interpolate(
                            mask_tensor.unsqueeze(0).unsqueeze(0),
                            size=pil_image.size[::-1],
                            mode='bilinear',
                            align_corners=False
                        ).squeeze()

                    # Sigmoid + 0.5 threshold -> 0/255 uint8 binary mask.
                    probs = torch.sigmoid(mask_tensor)
                    binary_mask = (probs > 0.5).float().cpu().numpy().astype("uint8") * 255
                else:
                    # No prediction for this class: emit an all-black mask.
                    binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")

                # Encode the mask as a base64 grayscale PNG.
                pil_mask = Image.fromarray(binary_mask, mode="L")
                buf = io.BytesIO()
                pil_mask.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                # Best-effort confidence; defaults to 1.0 when logits absent.
                score = 1.0
                if hasattr(outputs, 'pred_logits') and i < outputs.pred_logits.shape[1]:
                    # NOTE(review): assumes pred_logits[0, i] is a scalar —
                    # float() raises if it is a vector; confirm.
                    score = float(torch.sigmoid(outputs.pred_logits[0, i]).cpu())

                results.append({
                    "label": cls,
                    "mask": mask_b64,
                    "score": score
                })

            logger.info(f"[{request_id}] Completed (fallback): {len(results)} masks generated")
            return results

        # Primary path: flatten per-class processor results into one list,
        # one entry per detected instance.
        results = []

        total_instances = 0
        for i, cls in enumerate(classes):
            # processed[i] corresponds to batch element i == class i.
            class_result = processed[i]

            num_instances = len(class_result['masks']) if 'masks' in class_result else 0
            total_instances += num_instances

            if num_instances > 0:
                logger.info(f"[{request_id}] {cls}: {num_instances} instance(s) detected")

                for j in range(num_instances):
                    # Boolean/0-1 mask tensor -> 0/255 uint8 image.
                    mask_np = class_result['masks'][j].cpu().numpy().astype("uint8") * 255

                    pil_mask = Image.fromarray(mask_np, mode="L")
                    buf = io.BytesIO()
                    pil_mask.save(buf, format="PNG")
                    mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                    score = float(class_result['scores'][j]) if 'scores' in class_result else 1.0

                    # Percentage of image pixels covered — used for logging only.
                    coverage = (mask_np > 0).sum() / mask_np.size * 100

                    results.append({
                        "label": cls,
                        "mask": mask_b64,
                        "score": score,
                        "instance_id": j
                    })

                    logger.info(f"[{request_id}] └─ Instance {j}: score={score:.3f}, coverage={coverage:.2f}%")
            else:
                logger.info(f"[{request_id}] {cls}: No instances detected")

        logger.info(f"[{request_id}] Completed: {total_instances} instance(s) across {len(classes)} class(es)")
        return results

    except Exception as e:
        # Log with full traceback, then propagate so the endpoint returns 500.
        logger.error(f"[{request_id}] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
| |
|
| |
|
| | @app.post("/") |
| | async def predict(req: Request): |
| | request_id = str(id(req))[:8] |
| | try: |
| | await vram_manager.acquire(request_id) |
| | try: |
| | results = await asyncio.to_thread( |
| | run_inference, |
| | req.inputs, |
| | req.parameters.get("classes", []), |
| | request_id |
| | ) |
| | return results |
| | finally: |
| | vram_manager.release(request_id) |
| | except Exception as e: |
| | logger.error(f"[{request_id}] Error: {str(e)}") |
| | raise HTTPException(status_code=500, detail=str(e)) |
| |
|
| |
|
| | @app.get("/health") |
| | async def health(): |
| | return { |
| | "status": "healthy", |
| | "model": model.__class__.__name__, |
| | "gpu_available": torch.cuda.is_available(), |
| | "vram": vram_manager.get_vram_status() |
| | } |
| |
|
| |
|
| | @app.get("/metrics") |
| | async def metrics(): |
| | return vram_manager.get_vram_status() |
| |
|
| |
|
| | if __name__ == "__main__": |
| | import uvicorn |
| | uvicorn.run(app, host="0.0.0.0", port=7860, workers=1) |
| |
|