|
|
""" |
|
|
Model loading and inference for NeuroSAM 3 application. |
|
|
Handles SAM 3 model initialization and inference operations. |
|
|
""" |
|
|
|
|
|
from typing import Optional, Dict, Any

import numpy as np
import torch
import spaces
from PIL import Image

from logger_config import logger
from config import (
    SAM_MODEL_ID,
    HF_TOKEN,
    DEFAULT_THRESHOLD,
    DEFAULT_MASK_THRESHOLD,
    GPU_DURATION_SECONDS,
)
|
|
|
|
|
|
|
|
try: |
|
|
from transformers import Sam3Processor, Sam3Model |
|
|
SAM3_AVAILABLE = True |
|
|
except ImportError: |
|
|
logger.warning("Sam3Processor/Sam3Model not found in transformers.") |
|
|
logger.warning("SAM3 requires transformers from GitHub main branch.") |
|
|
logger.warning("Install with: pip install git+https://github.com/huggingface/transformers.git") |
|
|
SAM3_AVAILABLE = False |
|
|
Sam3Processor = None |
|
|
Sam3Model = None |
|
|
|
|
|
|
|
|
model: Optional[Any] = None |
|
|
processor: Optional[Any] = None |
|
|
|
|
|
|
|
|
def initialize_model() -> bool: |
|
|
""" |
|
|
Initialize SAM 3 model and processor. |
|
|
|
|
|
Returns: |
|
|
True if model loaded successfully, False otherwise |
|
|
""" |
|
|
global model, processor |
|
|
|
|
|
if not SAM3_AVAILABLE: |
|
|
logger.error("SAM 3 classes not available in transformers library.") |
|
|
logger.error("Install with: pip install git+https://github.com/huggingface/transformers.git") |
|
|
return False |
|
|
|
|
|
if HF_TOKEN is None: |
|
|
logger.warning("Cannot load model: HF_TOKEN not set") |
|
|
model = None |
|
|
processor = None |
|
|
return False |
|
|
|
|
|
try: |
|
|
logger.info(f"Loading SAM 3 model: {SAM_MODEL_ID}") |
|
|
|
|
|
|
|
|
|
|
|
model = Sam3Model.from_pretrained( |
|
|
SAM_MODEL_ID, |
|
|
torch_dtype=torch.float32, |
|
|
token=HF_TOKEN |
|
|
) |
|
|
processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=HF_TOKEN) |
|
|
model.eval() |
|
|
|
|
|
logger.info(f"SAM 3 Model loaded successfully on CPU! ({SAM_MODEL_ID})") |
|
|
logger.info("Model will be moved to GPU when inference is called") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load SAM 3 model: {e}", exc_info=True) |
|
|
logger.error("Ensure you have:") |
|
|
logger.error(" 1. transformers from GitHub main branch for SAM 3 support") |
|
|
logger.error(" Install with: pip install git+https://github.com/huggingface/transformers.git") |
|
|
logger.error(" 2. Valid Hugging Face token with access to SAM 3") |
|
|
logger.error(" 3. Sufficient memory for the model") |
|
|
model = None |
|
|
processor = None |
|
|
return False |
|
|
|
|
|
|
|
|
def is_model_loaded() -> bool: |
|
|
"""Check if model is loaded.""" |
|
|
return model is not None and processor is not None |
|
|
|
|
|
|
|
|
def get_model() -> Optional[Any]: |
|
|
"""Get the model instance.""" |
|
|
return model |
|
|
|
|
|
|
|
|
def get_processor() -> Optional[Any]: |
|
|
"""Get the processor instance.""" |
|
|
return processor |
|
|
|
|
|
|
|
|
def to_serializable(obj: Any) -> Any: |
|
|
""" |
|
|
Convert all tensors to numpy arrays or Python primitives for safe serialization. |
|
|
This ensures NO PyTorch tensors (CPU or CUDA) are in the return value. |
|
|
|
|
|
Args: |
|
|
obj: Object to convert |
|
|
|
|
|
Returns: |
|
|
Serializable object |
|
|
""" |
|
|
if isinstance(obj, torch.Tensor): |
|
|
|
|
|
result = obj.cpu().numpy() |
|
|
logger.debug(f"Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}") |
|
|
return result |
|
|
elif isinstance(obj, dict): |
|
|
return {k: to_serializable(v) for k, v in obj.items()} |
|
|
elif isinstance(obj, list): |
|
|
return [to_serializable(item) for item in obj] |
|
|
elif isinstance(obj, tuple): |
|
|
return tuple(to_serializable(item) for item in obj) |
|
|
elif isinstance(obj, (int, float, str, bool, type(None))): |
|
|
return obj |
|
|
elif hasattr(obj, 'item'): |
|
|
return obj.item() |
|
|
else: |
|
|
|
|
|
logger.warning(f"Unknown type encountered: {type(obj)}, converting to string") |
|
|
return str(obj) |
|
|
|
|
|
|
|
|
@spaces.GPU(duration=GPU_DURATION_SECONDS) |
|
|
def run_sam3_inference( |
|
|
pil_image: Image.Image, |
|
|
prompt_text: str, |
|
|
threshold: float = DEFAULT_THRESHOLD, |
|
|
mask_threshold: float = DEFAULT_MASK_THRESHOLD |
|
|
) -> Optional[Dict[str, Any]]: |
|
|
""" |
|
|
Run SAM 3 inference - optimized for medical imaging. |
|
|
|
|
|
Args: |
|
|
pil_image: PIL Image to segment |
|
|
prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull") |
|
|
threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images). |
|
|
Lower values (0.0-0.3) are more permissive and better for subtle features. |
|
|
Higher values (0.5-1.0) require high confidence, may miss detections. |
|
|
mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images). |
|
|
Lower values preserve more detail. Higher values create sharper masks. |
|
|
Medical images often benefit from 0.0 to capture subtle boundaries. |
|
|
|
|
|
Returns: |
|
|
results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed |
|
|
|
|
|
Note: |
|
|
Default thresholds (0.1, 0.0) are optimized for medical imaging where features |
|
|
may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5) |
|
|
may be more appropriate. |
|
|
""" |
|
|
if not is_model_loaded(): |
|
|
logger.error("Model not loaded - please check HF_TOKEN and model availability") |
|
|
raise ValueError( |
|
|
"SAM 3 model not loaded. Please check that HF_TOKEN is set correctly " |
|
|
"and the model is accessible." |
|
|
) |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
logger.debug(f"Using device: {device}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dtype = torch.float16 if device == "cuda" else torch.float32 |
|
|
model.to(device=device, dtype=dtype) |
|
|
logger.debug(f"Model moved to {device} with dtype {dtype}") |
|
|
|
|
|
|
|
|
inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device) |
|
|
|
|
|
|
|
|
|
|
|
for key in inputs: |
|
|
if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32: |
|
|
inputs[key] = inputs[key].to(model.dtype) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model(**inputs) |
|
|
|
|
|
logger.debug("Inference complete, processing results...") |
|
|
|
|
|
|
|
|
results = processor.post_process_instance_segmentation( |
|
|
outputs, |
|
|
threshold=threshold, |
|
|
mask_threshold=mask_threshold, |
|
|
target_sizes=inputs.get("original_sizes").tolist() |
|
|
if "original_sizes" in inputs |
|
|
else [pil_image.size[::-1]] |
|
|
)[0] |
|
|
|
|
|
logger.debug(f"Results type: {type(results)}") |
|
|
if isinstance(results, dict): |
|
|
logger.debug(f"Results keys: {results.keys()}") |
|
|
for key, value in results.items(): |
|
|
logger.debug(f" - {key}: type={type(value)}") |
|
|
if isinstance(value, torch.Tensor): |
|
|
logger.debug( |
|
|
f" tensor device={value.device}, " |
|
|
f"shape={value.shape}, dtype={value.dtype}" |
|
|
) |
|
|
elif isinstance(value, list) and len(value) > 0: |
|
|
logger.debug(f" list length={len(value)}, first item type={type(value[0])}") |
|
|
if isinstance(value[0], torch.Tensor): |
|
|
logger.debug(f" first tensor device={value[0].device}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug("Converting all tensors to numpy arrays...") |
|
|
results = to_serializable(results) |
|
|
|
|
|
logger.debug("All tensors converted to serializable format") |
|
|
|
|
|
|
|
|
model.to("cpu") |
|
|
logger.debug("Model moved back to CPU") |
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error during SAM 3 inference: {e}", exc_info=True) |
|
|
|
|
|
if model is not None: |
|
|
try: |
|
|
model.to("cpu") |
|
|
except RuntimeError as cleanup_error: |
|
|
logger.warning(f"Could not move model back to CPU: {cleanup_error}") |
|
|
return None |
|
|
|
|
|
|