"""
Model loading and inference for NeuroSAM 3 application.
Handles SAM 3 model initialization and inference operations.
"""
from typing import Optional, Dict, Any
import torch
import spaces
from PIL import Image
from logger_config import logger
from config import (
    SAM_MODEL_ID,
    HF_TOKEN,
    DEFAULT_THRESHOLD,
    DEFAULT_MASK_THRESHOLD,
    GPU_DURATION_SECONDS,
)

# Try to import SAM 3 classes
try:
    from transformers import Sam3Processor, Sam3Model
    SAM3_AVAILABLE = True
except ImportError:
    logger.warning("Sam3Processor/Sam3Model not found in transformers.")
    logger.warning("SAM3 requires transformers from GitHub main branch.")
    logger.warning("Install with: pip install git+https://github.com/huggingface/transformers.git")
    SAM3_AVAILABLE = False
    Sam3Processor = None
    Sam3Model = None

# Global model and processor instances
model: Optional[Any] = None
processor: Optional[Any] = None


def initialize_model() -> bool:
    """
    Initialize SAM 3 model and processor.

    Returns:
        True if model loaded successfully, False otherwise
    """
    global model, processor
    if not SAM3_AVAILABLE:
        logger.error("SAM 3 classes not available in transformers library.")
        logger.error("Install with: pip install git+https://github.com/huggingface/transformers.git")
        return False
    if HF_TOKEN is None:
        logger.warning("Cannot load model: HF_TOKEN not set")
        model = None
        processor = None
        return False
    try:
        logger.info(f"Loading SAM 3 model: {SAM_MODEL_ID}")
        # Load model on CPU to avoid CUDA initialization in main process
        # (for HF Spaces Stateless GPU)
        model = Sam3Model.from_pretrained(
            SAM_MODEL_ID,
            torch_dtype=torch.float32,  # Load as float32 on CPU
            token=HF_TOKEN
        )
        processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=HF_TOKEN)
        model.eval()
        logger.info(f"SAM 3 model loaded successfully on CPU! ({SAM_MODEL_ID})")
        logger.info("Model will be moved to GPU when inference is called")
        return True
    except Exception as e:
        logger.error(f"Failed to load SAM 3 model: {e}", exc_info=True)
        logger.error("Ensure you have:")
        logger.error("  1. transformers from GitHub main branch for SAM 3 support")
        logger.error("     Install with: pip install git+https://github.com/huggingface/transformers.git")
        logger.error("  2. Valid Hugging Face token with access to SAM 3")
        logger.error("  3. Sufficient memory for the model")
        model = None
        processor = None
        return False


def is_model_loaded() -> bool:
    """Check if model is loaded."""
    return model is not None and processor is not None


def get_model() -> Optional[Any]:
    """Get the model instance."""
    return model


def get_processor() -> Optional[Any]:
    """Get the processor instance."""
    return processor


def to_serializable(obj: Any) -> Any:
    """
    Convert all tensors to numpy arrays or Python primitives for safe serialization.
    This ensures NO PyTorch tensors (CPU or CUDA) are in the return value.

    Args:
        obj: Object to convert

    Returns:
        Serializable object
    """
    if isinstance(obj, torch.Tensor):
        # Convert to numpy array (works for both CPU and CUDA tensors)
        result = obj.cpu().numpy()
        logger.debug(f"Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}")
        return result
    elif isinstance(obj, dict):
        return {k: to_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [to_serializable(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(to_serializable(item) for item in obj)
    elif isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif hasattr(obj, 'item'):  # numpy scalar
        return obj.item()
    else:
        # For unknown types, try to convert to string representation
        logger.warning(f"Unknown type encountered: {type(obj)}, converting to string")
        return str(obj)
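
# Illustrative sketch of the conversion (assumed shapes; kept as a comment so
# nothing executes at import time). A nested results structure comes back as
# plain numpy/Python values with no torch.Tensor anywhere inside:
#
#     raw = {"masks": torch.zeros(2, 8, 8), "scores": [torch.tensor(0.93)]}
#     safe = to_serializable(raw)
#     # safe["masks"] -> numpy.ndarray; safe["scores"][0] -> 0-d numpy array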


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_sam3_inference(
    pil_image: Image.Image,
    prompt_text: str,
    threshold: float = DEFAULT_THRESHOLD,
    mask_threshold: float = DEFAULT_MASK_THRESHOLD
) -> Optional[Dict[str, Any]]:
    """
    Run SAM 3 inference - optimized for medical imaging.

    Args:
        pil_image: PIL Image to segment
        prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull")
        threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images).
            Lower values (0.0-0.3) are more permissive and better for subtle features.
            Higher values (0.5-1.0) require high confidence and may miss detections.
        mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images).
            Lower values preserve more detail; higher values create sharper masks.
            Medical images often benefit from 0.0 to capture subtle boundaries.

    Returns:
        Results dict with 'masks' and 'scores' as numpy arrays or lists, or None if inference failed.

    Note:
        Default thresholds (0.1, 0.0) are optimized for medical imaging where features
        may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5)
        may be more appropriate.
    """
    if not is_model_loaded():
        logger.error("Model not loaded - please check HF_TOKEN and model availability")
        raise ValueError(
            "SAM 3 model not loaded. Please check that HF_TOKEN is set correctly "
            "and the model is accessible."
        )
    try:
        # Determine device and move model to GPU if available
        # (CUDA initialization happens here, inside @spaces.GPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"Using device: {device}")

        # Move model to device and set appropriate dtype.
        # Note: For nn.Module, .to() modifies in-place and returns self.
        # IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued
        # and processed one at a time, so there is NO concurrent access to the model.
        # This makes in-place modification safe despite model being a global variable.
        dtype = torch.float16 if device == "cuda" else torch.float32
        model.to(device=device, dtype=dtype)
        logger.debug(f"Model moved to {device} with dtype {dtype}")

        # Prepare inputs - matching official implementation
        inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)

        # Convert float32 inputs to model dtype (float16 on GPU)
        # - matching official implementation
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)
        logger.debug("Inference complete, processing results...")

        # Post-process using processor method - matching official implementation
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
            if "original_sizes" in inputs
            else [pil_image.size[::-1]]  # PIL size is (W, H); target_sizes expects (H, W)
        )[0]  # Get first batch result

        logger.debug(f"Results type: {type(results)}")
        if isinstance(results, dict):
            logger.debug(f"Results keys: {results.keys()}")
            for key, value in results.items():
                logger.debug(f"  - {key}: type={type(value)}")
                if isinstance(value, torch.Tensor):
                    logger.debug(
                        f"    tensor device={value.device}, "
                        f"shape={value.shape}, dtype={value.dtype}"
                    )
                elif isinstance(value, list) and len(value) > 0:
                    logger.debug(f"    list length={len(value)}, first item type={type(value[0])}")
                    if isinstance(value[0], torch.Tensor):
                        logger.debug(f"    first tensor device={value[0].device}")

        # CRITICAL: Convert ALL tensors to numpy arrays before returning.
        # This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary;
        # numpy arrays are safely serializable without triggering CUDA init.
        logger.debug("Converting all tensors to numpy arrays...")
        results = to_serializable(results)
        logger.debug("All tensors converted to serializable format")

        # Move model back to CPU to free GPU memory (important for Spaces)
        model.to("cpu")
        logger.debug("Model moved back to CPU")
        return results
    except Exception as e:
        logger.error(f"Error during SAM 3 inference: {e}", exc_info=True)
        # Make sure to move the model back to CPU even on error
        if model is not None:
            try:
                model.to("cpu")
            except RuntimeError as cleanup_error:
                logger.warning(f"Could not move model back to CPU: {cleanup_error}")
        return None
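

# Minimal smoke test - an illustrative sketch, not part of the application flow.
# It assumes a local environment with transformers built from the GitHub main
# branch, an HF_TOKEN with access to SAM_MODEL_ID, and enough memory for the
# model. (Outside HF Spaces the @spaces.GPU decorator is expected to act as a
# no-op, but verify this against your installed `spaces` version.)
if __name__ == "__main__":
    from PIL import ImageDraw

    if initialize_model():
        # Synthetic test image: a white disc on black, segmented by text prompt
        demo = Image.new("RGB", (256, 256), "black")
        ImageDraw.Draw(demo).ellipse((64, 64, 192, 192), fill="white")

        # Permissive thresholds, as recommended above for subtle features
        result = run_sam3_inference(demo, "circle", threshold=0.1, mask_threshold=0.0)
        if result is not None:
            logger.info(f"Result keys: {list(result.keys())}")
        else:
            logger.info("Inference returned no result")
    else:
        logger.info("Model could not be initialized; see errors above")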