"""
Model loading and inference for NeuroSAM 3 application.
Handles SAM 3 model initialization and inference operations.
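
Typical usage (illustrative sketch; assumes HF_TOKEN is configured in config.py):

    from models import initialize_model, run_sam3_inference
    initialize_model()
    results = run_sam3_inference(pil_image, "brain")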
"""
from typing import Optional, Dict, Any
import numpy as np
import torch
import spaces
from PIL import Image
from logger_config import logger
from config import (
SAM_MODEL_ID,
HF_TOKEN,
DEFAULT_THRESHOLD,
DEFAULT_MASK_THRESHOLD,
GPU_DURATION_SECONDS,
)
# Try to import SAM 3 classes
try:
from transformers import Sam3Processor, Sam3Model
SAM3_AVAILABLE = True
except ImportError:
logger.warning("Sam3Processor/Sam3Model not found in transformers.")
logger.warning("SAM3 requires transformers from GitHub main branch.")
logger.warning("Install with: pip install git+https://github.com/huggingface/transformers.git")
SAM3_AVAILABLE = False
Sam3Processor = None
Sam3Model = None
# Global model and processor instances
model: Optional[Any] = None
processor: Optional[Any] = None
def initialize_model() -> bool:
"""
Initialize SAM 3 model and processor.
Returns:
True if model loaded successfully, False otherwise
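
    Example (illustrative; typically called once at application startup):
        >>> if not initialize_model():
        ...     logger.error("Running without a model; inference is disabled")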
"""
global model, processor
if not SAM3_AVAILABLE:
logger.error("SAM 3 classes not available in transformers library.")
logger.error("Install with: pip install git+https://github.com/huggingface/transformers.git")
return False
if HF_TOKEN is None:
logger.warning("Cannot load model: HF_TOKEN not set")
model = None
processor = None
return False
try:
logger.info(f"Loading SAM 3 model: {SAM_MODEL_ID}")
# Load model on CPU to avoid CUDA initialization in main process
# (for HF Spaces Stateless GPU)
model = Sam3Model.from_pretrained(
SAM_MODEL_ID,
torch_dtype=torch.float32, # Load as float32 on CPU
token=HF_TOKEN
)
processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=HF_TOKEN)
model.eval()
logger.info(f"SAM 3 Model loaded successfully on CPU! ({SAM_MODEL_ID})")
logger.info("Model will be moved to GPU when inference is called")
return True
except Exception as e:
logger.error(f"Failed to load SAM 3 model: {e}", exc_info=True)
logger.error("Ensure you have:")
logger.error(" 1. transformers from GitHub main branch for SAM 3 support")
logger.error(" Install with: pip install git+https://github.com/huggingface/transformers.git")
logger.error(" 2. Valid Hugging Face token with access to SAM 3")
logger.error(" 3. Sufficient memory for the model")
model = None
processor = None
return False
def is_model_loaded() -> bool:
"""Check if model is loaded."""
return model is not None and processor is not None
def get_model() -> Optional[Any]:
"""Get the model instance."""
return model
def get_processor() -> Optional[Any]:
"""Get the processor instance."""
return processor
def to_serializable(obj: Any) -> Any:
"""
Convert all tensors to numpy arrays or Python primitives for safe serialization.
This ensures NO PyTorch tensors (CPU or CUDA) are in the return value.
Args:
obj: Object to convert
Returns:
Serializable object
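
    Example:
        >>> out = to_serializable({"scores": torch.tensor([0.9, 0.4]), "count": 2})
        >>> type(out["scores"]).__name__, out["count"]
        ('ndarray', 2)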
"""
if isinstance(obj, torch.Tensor):
# Convert to numpy array (works for both CPU and CUDA tensors)
result = obj.cpu().numpy()
logger.debug(f"Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}")
return result
elif isinstance(obj, dict):
return {k: to_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [to_serializable(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(to_serializable(item) for item in obj)
elif isinstance(obj, (int, float, str, bool, type(None))):
return obj
    elif isinstance(obj, np.ndarray):
        # Already a numpy array (e.g. from a previously converted tensor): pass through.
        # Checked before the 'item' branch, since multi-element ndarrays also have .item()
        # but would raise if it were called on them.
        return obj
    elif hasattr(obj, 'item'):  # numpy scalar (np.float32, np.int64, ...)
        return obj.item()
else:
# For unknown types, try to convert to string representation
logger.warning(f"Unknown type encountered: {type(obj)}, converting to string")
return str(obj)
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_sam3_inference(
pil_image: Image.Image,
prompt_text: str,
threshold: float = DEFAULT_THRESHOLD,
mask_threshold: float = DEFAULT_MASK_THRESHOLD
) -> Optional[Dict[str, Any]]:
"""
Run SAM 3 inference - optimized for medical imaging.
Args:
pil_image: PIL Image to segment
prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull")
threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images).
Lower values (0.0-0.3) are more permissive and better for subtle features.
Higher values (0.5-1.0) require high confidence, may miss detections.
mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images).
Lower values preserve more detail. Higher values create sharper masks.
Medical images often benefit from 0.0 to capture subtle boundaries.
    Returns:
        Results dict with 'masks' and 'scores' as numpy arrays or lists,
        or None if inference failed.
    Raises:
        ValueError: If the model is not loaded (check HF_TOKEN and model access).
    Note:
        Default thresholds (0.1, 0.0) are tuned for medical imaging, where features
        may be subtle or low-contrast. For natural images, higher thresholds
        (e.g. 0.5, 0.5) may be more appropriate.
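
    Example (illustrative; the file name and prompt are placeholders):
        >>> from PIL import Image
        >>> img = Image.open("scan.png").convert("RGB")
        >>> results = run_sam3_inference(img, "tumor")
        >>> if results is not None:
        ...     print(f"{len(results['masks'])} instances found")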
"""
if not is_model_loaded():
logger.error("Model not loaded - please check HF_TOKEN and model availability")
raise ValueError(
"SAM 3 model not loaded. Please check that HF_TOKEN is set correctly "
"and the model is accessible."
)
try:
# Determine device and move model to GPU if available
# (CUDA initialization happens here, inside @spaces.GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.debug(f"Using device: {device}")
# Move model to device and set appropriate dtype
# Note: For nn.Module, .to() modifies in-place and returns self
# IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued
# and processed one at a time, so there's NO concurrent access to the model.
# This makes in-place modification safe despite model being a global variable.
dtype = torch.float16 if device == "cuda" else torch.float32
model.to(device=device, dtype=dtype)
logger.debug(f"Model moved to {device} with dtype {dtype}")
# Prepare inputs - matching official implementation
inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)
# Convert float32 inputs to model dtype (float16 for GPU)
# - matching official implementation
for key in inputs:
if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
inputs[key] = inputs[key].to(model.dtype)
with torch.no_grad():
outputs = model(**inputs)
logger.debug("Inference complete, processing results...")
        # Post-process using processor method - matching official implementation.
        # PIL's .size is (width, height); target_sizes expects (height, width).
        if "original_sizes" in inputs:
            target_sizes = inputs["original_sizes"].tolist()
        else:
            target_sizes = [pil_image.size[::-1]]
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=target_sizes,
        )[0]  # first (and only) batch entry
logger.debug(f"Results type: {type(results)}")
if isinstance(results, dict):
logger.debug(f"Results keys: {results.keys()}")
for key, value in results.items():
logger.debug(f" - {key}: type={type(value)}")
if isinstance(value, torch.Tensor):
logger.debug(
f" tensor device={value.device}, "
f"shape={value.shape}, dtype={value.dtype}"
)
elif isinstance(value, list) and len(value) > 0:
logger.debug(f" list length={len(value)}, first item type={type(value[0])}")
if isinstance(value[0], torch.Tensor):
logger.debug(f" first tensor device={value[0].device}")
# CRITICAL: Convert ALL tensors to numpy arrays before returning
# This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary
# Numpy arrays are safely serializable without triggering CUDA init
logger.debug("Converting all tensors to numpy arrays...")
results = to_serializable(results)
logger.debug("All tensors converted to serializable format")
        # Move model back to CPU and release cached GPU memory (important for Spaces)
        model.to("cpu")
        if device == "cuda":
            torch.cuda.empty_cache()
        logger.debug("Model moved back to CPU")
return results
except Exception as e:
logger.error(f"Error during SAM 3 inference: {e}", exc_info=True)
# Make sure to move model back to CPU even on error
if model is not None:
try:
model.to("cpu")
except RuntimeError as cleanup_error:
logger.warning(f"Could not move model back to CPU: {cleanup_error}")
return None