""" Model loading and inference for NeuroSAM 3 application. Handles SAM 3 model initialization and inference operations. """ from typing import Optional, Dict, Any import torch import spaces from PIL import Image from logger_config import logger from config import ( SAM_MODEL_ID, HF_TOKEN, DEFAULT_THRESHOLD, DEFAULT_MASK_THRESHOLD, GPU_DURATION_SECONDS, ) # Try to import SAM 3 classes try: from transformers import Sam3Processor, Sam3Model SAM3_AVAILABLE = True except ImportError: logger.warning("Sam3Processor/Sam3Model not found in transformers.") logger.warning("SAM3 requires transformers from GitHub main branch.") logger.warning("Install with: pip install git+https://github.com/huggingface/transformers.git") SAM3_AVAILABLE = False Sam3Processor = None Sam3Model = None # Global model and processor instances model: Optional[Any] = None processor: Optional[Any] = None def initialize_model() -> bool: """ Initialize SAM 3 model and processor. Returns: True if model loaded successfully, False otherwise """ global model, processor if not SAM3_AVAILABLE: logger.error("SAM 3 classes not available in transformers library.") logger.error("Install with: pip install git+https://github.com/huggingface/transformers.git") return False if HF_TOKEN is None: logger.warning("Cannot load model: HF_TOKEN not set") model = None processor = None return False try: logger.info(f"Loading SAM 3 model: {SAM_MODEL_ID}") # Load model on CPU to avoid CUDA initialization in main process # (for HF Spaces Stateless GPU) model = Sam3Model.from_pretrained( SAM_MODEL_ID, torch_dtype=torch.float32, # Load as float32 on CPU token=HF_TOKEN ) processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=HF_TOKEN) model.eval() logger.info(f"SAM 3 Model loaded successfully on CPU! ({SAM_MODEL_ID})") logger.info("Model will be moved to GPU when inference is called") return True except Exception as e: logger.error(f"Failed to load SAM 3 model: {e}", exc_info=True) logger.error("Ensure you have:") logger.error(" 1. transformers from GitHub main branch for SAM 3 support") logger.error(" Install with: pip install git+https://github.com/huggingface/transformers.git") logger.error(" 2. Valid Hugging Face token with access to SAM 3") logger.error(" 3. Sufficient memory for the model") model = None processor = None return False def is_model_loaded() -> bool: """Check if model is loaded.""" return model is not None and processor is not None def get_model() -> Optional[Any]: """Get the model instance.""" return model def get_processor() -> Optional[Any]: """Get the processor instance.""" return processor def to_serializable(obj: Any) -> Any: """ Convert all tensors to numpy arrays or Python primitives for safe serialization. This ensures NO PyTorch tensors (CPU or CUDA) are in the return value. Args: obj: Object to convert Returns: Serializable object """ if isinstance(obj, torch.Tensor): # Convert to numpy array (works for both CPU and CUDA tensors) result = obj.cpu().numpy() logger.debug(f"Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}") return result elif isinstance(obj, dict): return {k: to_serializable(v) for k, v in obj.items()} elif isinstance(obj, list): return [to_serializable(item) for item in obj] elif isinstance(obj, tuple): return tuple(to_serializable(item) for item in obj) elif isinstance(obj, (int, float, str, bool, type(None))): return obj elif hasattr(obj, 'item'): # numpy scalar return obj.item() else: # For unknown types, try to convert to string representation logger.warning(f"Unknown type encountered: {type(obj)}, converting to string") return str(obj) @spaces.GPU(duration=GPU_DURATION_SECONDS) def run_sam3_inference( pil_image: Image.Image, prompt_text: str, threshold: float = DEFAULT_THRESHOLD, mask_threshold: float = DEFAULT_MASK_THRESHOLD ) -> Optional[Dict[str, Any]]: """ Run SAM 3 inference - optimized for medical imaging. Args: pil_image: PIL Image to segment prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull") threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images). Lower values (0.0-0.3) are more permissive and better for subtle features. Higher values (0.5-1.0) require high confidence, may miss detections. mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images). Lower values preserve more detail. Higher values create sharper masks. Medical images often benefit from 0.0 to capture subtle boundaries. Returns: results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed Note: Default thresholds (0.1, 0.0) are optimized for medical imaging where features may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5) may be more appropriate. """ if not is_model_loaded(): logger.error("Model not loaded - please check HF_TOKEN and model availability") raise ValueError( "SAM 3 model not loaded. Please check that HF_TOKEN is set correctly " "and the model is accessible." ) try: # Determine device and move model to GPU if available # (CUDA initialization happens here, inside @spaces.GPU) device = "cuda" if torch.cuda.is_available() else "cpu" logger.debug(f"Using device: {device}") # Move model to device and set appropriate dtype # Note: For nn.Module, .to() modifies in-place and returns self # IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued # and processed one at a time, so there's NO concurrent access to the model. # This makes in-place modification safe despite model being a global variable. dtype = torch.float16 if device == "cuda" else torch.float32 model.to(device=device, dtype=dtype) logger.debug(f"Model moved to {device} with dtype {dtype}") # Prepare inputs - matching official implementation inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device) # Convert float32 inputs to model dtype (float16 for GPU) # - matching official implementation for key in inputs: if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32: inputs[key] = inputs[key].to(model.dtype) with torch.no_grad(): outputs = model(**inputs) logger.debug("Inference complete, processing results...") # Post-process using processor method - matching official implementation results = processor.post_process_instance_segmentation( outputs, threshold=threshold, mask_threshold=mask_threshold, target_sizes=inputs.get("original_sizes").tolist() if "original_sizes" in inputs else [pil_image.size[::-1]] )[0] # Get first batch result logger.debug(f"Results type: {type(results)}") if isinstance(results, dict): logger.debug(f"Results keys: {results.keys()}") for key, value in results.items(): logger.debug(f" - {key}: type={type(value)}") if isinstance(value, torch.Tensor): logger.debug( f" tensor device={value.device}, " f"shape={value.shape}, dtype={value.dtype}" ) elif isinstance(value, list) and len(value) > 0: logger.debug(f" list length={len(value)}, first item type={type(value[0])}") if isinstance(value[0], torch.Tensor): logger.debug(f" first tensor device={value[0].device}") # CRITICAL: Convert ALL tensors to numpy arrays before returning # This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary # Numpy arrays are safely serializable without triggering CUDA init logger.debug("Converting all tensors to numpy arrays...") results = to_serializable(results) logger.debug("All tensors converted to serializable format") # Move model back to CPU to free GPU memory (important for Spaces) model.to("cpu") logger.debug("Model moved back to CPU") return results except Exception as e: logger.error(f"Error during SAM 3 inference: {e}", exc_info=True) # Make sure to move model back to CPU even on error if model is not None: try: model.to("cpu") except RuntimeError as cleanup_error: logger.warning(f"Could not move model back to CPU: {cleanup_error}") return None