"""
SAM3 Static Image Segmentation - Correct Implementation

Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation.
"""

import asyncio
import base64
import io
import logging

import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the SAM3 processor and model from the local ./model directory.
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)

model.eval()
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")


class VRAMManager:
    """Caps concurrent GPU inferences and reports basic VRAM usage."""

    def __init__(self):
        # Allow at most two requests to run inference at the same time.
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        if not torch.cuda.is_available():
            return {}
        return {
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "free_gb": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1e9,
            "processing_now": self.processing_count
        }

    async def acquire(self, rid):
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        self.processing_count -= 1
        self.semaphore.release()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")


class Request(BaseModel):
    inputs: str       # base64-encoded image bytes
    parameters: dict  # expects {"classes": ["<text prompt>", ...]}


def run_inference(image_b64: str, classes: list, request_id: str):
    """
    Sam3Model inference for static images with text prompts.

    Uses official SAM3 processor post-processing for correct mask generation.
    """
    try:
        # Decode the base64 image payload.
        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")

        # Pair each text prompt with a copy of the image so the processor builds
        # one (image, text) batch element per class.
        images_batch = [pil_image] * len(classes)
        inputs = processor(
            images=images_batch,
            text=classes,
            return_tensors="pt",
        )

        # Original size as (height, width); PIL reports size as (width, height).
        original_size = [pil_image.size[1], pil_image.size[0]]
        original_sizes = torch.tensor([original_size] * len(classes))
        inputs["original_sizes"] = original_sizes

        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
        logger.info(f"[{request_id}] Original size: {pil_image.size} (W x H)")

        if torch.cuda.is_available():
            # Move tensors to the GPU, casting floating-point tensors to the
            # model's dtype (e.g. float16) and leaving integer tensors unchanged.
            model_dtype = next(model.parameters()).dtype
            inputs = {
                k: v.cuda().to(model_dtype)
                if isinstance(v, torch.Tensor) and v.dtype.is_floating_point
                else v.cuda() if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")

        with torch.no_grad():
            outputs = model(**inputs)
        logger.info(f"[{request_id}] Forward pass successful!")
        logger.info(f"[{request_id}] Output type: {type(outputs)}")

        logger.info(f"[{request_id}] Using processor.post_process_instance_segmentation()")

        # Preferred path: let the processor convert raw outputs into per-instance
        # binary masks at the original image resolution.
        try:
            processed = processor.post_process_instance_segmentation(
                outputs,
                threshold=0.3,
                mask_threshold=0.5,
                target_sizes=original_sizes.tolist(),
            )

            logger.info(f"[{request_id}] Post-processing successful!")
            logger.info(f"[{request_id}] Number of batched results: {len(processed)}")

        except Exception as proc_error:
            logger.error(f"[{request_id}] Post-processing failed: {proc_error}")
            logger.info(f"[{request_id}] Falling back to manual processing")

            results = []

            # Locate the raw mask logits on the model output.
            if hasattr(outputs, 'pred_masks'):
                pred_masks = outputs.pred_masks
            elif hasattr(outputs, 'masks'):
                pred_masks = outputs.masks
            elif isinstance(outputs, dict) and 'pred_masks' in outputs:
                pred_masks = outputs['pred_masks']
            else:
                raise ValueError("Cannot find masks in model output")

            logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")

            for i, cls in enumerate(classes):
                if i < pred_masks.shape[1]:
                    # Assumes mask queries for all classes are stacked along dim 1
                    # of the first batch element, one 2D mask per class.
                    mask_tensor = pred_masks[0, i]

                    # Resize the mask logits to the original image resolution if needed.
                    if mask_tensor.shape[-2:] != pil_image.size[::-1]:
                        mask_tensor = torch.nn.functional.interpolate(
                            mask_tensor.unsqueeze(0).unsqueeze(0),
                            size=pil_image.size[::-1],
                            mode='bilinear',
                            align_corners=False
                        ).squeeze()

                    # Threshold sigmoid probabilities into a 0/255 uint8 mask.
                    probs = torch.sigmoid(mask_tensor)
                    binary_mask = (probs > 0.5).float().cpu().numpy().astype("uint8") * 255
                else:
                    binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")

                # Encode the mask as a base64 PNG.
                pil_mask = Image.fromarray(binary_mask, mode="L")
                buf = io.BytesIO()
                pil_mask.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                score = 1.0
                if hasattr(outputs, 'pred_logits') and i < outputs.pred_logits.shape[1]:
                    # Assumes a scalar logit per class query.
                    score = float(torch.sigmoid(outputs.pred_logits[0, i]).cpu())

                results.append({
                    "label": cls,
                    "mask": mask_b64,
                    "score": score
                })

            logger.info(f"[{request_id}] Completed (fallback): {len(results)} masks generated")
            return results

        # Main path: flatten the post-processed per-class results into a single
        # list of instance records.
        results = []

        total_instances = 0
        for i, cls in enumerate(classes):
            class_result = processed[i]

            num_instances = len(class_result['masks']) if 'masks' in class_result else 0
            total_instances += num_instances

            if num_instances > 0:
                logger.info(f"[{request_id}] {cls}: {num_instances} instance(s) detected")

                for j in range(num_instances):
                    # Convert the post-processed instance mask to a 0/255 uint8 image.
                    mask_np = class_result['masks'][j].cpu().numpy().astype("uint8") * 255

                    # Encode the mask as a base64 PNG.
                    pil_mask = Image.fromarray(mask_np, mode="L")
                    buf = io.BytesIO()
                    pil_mask.save(buf, format="PNG")
                    mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                    score = float(class_result['scores'][j]) if 'scores' in class_result else 1.0

                    # Percentage of image pixels covered by this mask (for logging only).
                    coverage = (mask_np > 0).sum() / mask_np.size * 100

                    results.append({
                        "label": cls,
                        "mask": mask_b64,
                        "score": score,
                        "instance_id": j
                    })

                    logger.info(f"[{request_id}] └─ Instance {j}: score={score:.3f}, coverage={coverage:.2f}%")
            else:
                logger.info(f"[{request_id}] {cls}: No instances detected")

        logger.info(f"[{request_id}] Completed: {total_instances} instance(s) across {len(classes)} class(es)")
        return results

    except Exception as e:
        logger.error(f"[{request_id}] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


@app.post("/")
async def predict(req: Request):
    request_id = str(id(req))[:8]
    try:
        # Cap concurrent GPU work via the semaphore, then run the blocking
        # inference in a worker thread so the event loop stays responsive.
        await vram_manager.acquire(request_id)
        try:
            results = await asyncio.to_thread(
                run_inference,
                req.inputs,
                req.parameters.get("classes", []),
                request_id
            )
            return results
        finally:
            vram_manager.release(request_id)
    except Exception as e:
        logger.error(f"[{request_id}] Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "model": model.__class__.__name__,
        "gpu_available": torch.cuda.is_available(),
        "vram": vram_manager.get_vram_status()
    }


@app.get("/metrics")
async def metrics():
    return vram_manager.get_vram_status()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
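
# Example client call (illustrative sketch; assumes the service is reachable at
# http://localhost:7860 and that "image.jpg" exists locally):
#
#   import base64, requests
#   payload = {
#       "inputs": base64.b64encode(open("image.jpg", "rb").read()).decode("utf-8"),
#       "parameters": {"classes": ["person", "car"]},
#   }
#   masks = requests.post("http://localhost:7860/", json=payload).json()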