File size: 9,583 Bytes
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Model loading and inference for NeuroSAM 3 application.
Handles SAM 3 model initialization and inference operations.
"""

from typing import Optional, Dict, Any
import torch
import spaces
from PIL import Image
from logger_config import logger
from config import (
    SAM_MODEL_ID,
    HF_TOKEN,
    DEFAULT_THRESHOLD,
    DEFAULT_MASK_THRESHOLD,
    GPU_DURATION_SECONDS,
)

# Import the SAM 3 classes; they only exist in transformers built from the
# GitHub main branch, so fall back gracefully when running an older release.
try:
    from transformers import Sam3Processor, Sam3Model
    SAM3_AVAILABLE = True
except ImportError:
    logger.warning("Sam3Processor/Sam3Model not found in transformers.")
    logger.warning("SAM3 requires transformers from GitHub main branch.")
    logger.warning("Install with: pip install git+https://github.com/huggingface/transformers.git")
    SAM3_AVAILABLE = False
    # Placeholders so later references resolve; initialize_model() checks
    # SAM3_AVAILABLE before touching them.
    Sam3Processor = None
    Sam3Model = None

# Module-level singletons populated by initialize_model() (reset to None on
# load failure); read via is_model_loaded()/get_model()/get_processor().
model: Optional[Any] = None
processor: Optional[Any] = None


def initialize_model() -> bool:
    """
    Load the SAM 3 model and processor into the module-level globals.

    The weights stay on the CPU at load time: CUDA must not be initialized in
    the main process (HF Spaces Stateless GPU), so the move to GPU is deferred
    until inference runs under the @spaces.GPU decorator.

    Returns:
        True if model loaded successfully, False otherwise
    """
    global model, processor

    # Guard: the SAM 3 classes are only present in transformers@main.
    if not SAM3_AVAILABLE:
        logger.error("SAM 3 classes not available in transformers library.")
        logger.error("Install with: pip install git+https://github.com/huggingface/transformers.git")
        return False

    # Guard: the checkpoint is gated, so a token is mandatory.
    if HF_TOKEN is None:
        logger.warning("Cannot load model: HF_TOKEN not set")
        model = None
        processor = None
        return False

    try:
        logger.info(f"Loading SAM 3 model: {SAM_MODEL_ID}")

        # Load as float32 on CPU; half precision is applied later, on-GPU.
        loaded_model = Sam3Model.from_pretrained(
            SAM_MODEL_ID,
            torch_dtype=torch.float32,
            token=HF_TOKEN,
        )
        loaded_model.eval()
        loaded_processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=HF_TOKEN)

        # Publish to the module globals only once both pieces loaded cleanly.
        model = loaded_model
        processor = loaded_processor

        logger.info(f"SAM 3 Model loaded successfully on CPU! ({SAM_MODEL_ID})")
        logger.info("Model will be moved to GPU when inference is called")
        return True

    except Exception as exc:
        logger.error(f"Failed to load SAM 3 model: {exc}", exc_info=True)
        for hint in (
            "Ensure you have:",
            "  1. transformers from GitHub main branch for SAM 3 support",
            "     Install with: pip install git+https://github.com/huggingface/transformers.git",
            "  2. Valid Hugging Face token with access to SAM 3",
            "  3. Sufficient memory for the model",
        ):
            logger.error(hint)
        model = None
        processor = None
        return False


def is_model_loaded() -> bool:
    """Report whether both the model and the processor are available."""
    return not (model is None or processor is None)


def get_model() -> Optional[Any]:
    """Return the module-level SAM 3 model, or None when not loaded."""
    return model


def get_processor() -> Optional[Any]:
    """Return the module-level SAM 3 processor, or None when not loaded."""
    return processor


def to_serializable(obj: Any) -> Any:
    """
    Recursively replace every torch.Tensor with a numpy array so the result is
    safe to serialize — no PyTorch tensors (CPU or CUDA) survive the walk.

    Args:
        obj: Object to convert

    Returns:
        Serializable object
    """
    # Tensors first: .cpu() handles both CPU- and CUDA-resident tensors.
    if isinstance(obj, torch.Tensor):
        arr = obj.cpu().numpy()
        logger.debug(f"Converted tensor to numpy: shape={arr.shape}, dtype={arr.dtype}")
        return arr

    # Containers: recurse, preserving the container kind.
    if isinstance(obj, dict):
        return {key: to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [to_serializable(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)

    # Plain Python primitives pass through untouched.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj

    # numpy scalars expose .item() and unwrap to a Python primitive.
    if hasattr(obj, 'item'):
        return obj.item()

    # Last resort: stringify anything unrecognized rather than fail.
    logger.warning(f"Unknown type encountered: {type(obj)}, converting to string")
    return str(obj)


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_sam3_inference(
    pil_image: Image.Image,
    prompt_text: str,
    threshold: float = DEFAULT_THRESHOLD,
    mask_threshold: float = DEFAULT_MASK_THRESHOLD
) -> Optional[Dict[str, Any]]:
    """
    Run SAM 3 inference - optimized for medical imaging.

    Args:
        pil_image: PIL Image to segment
        prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull")
        threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images).
                   Lower values (0.0-0.3) are more permissive and better for subtle features.
                   Higher values (0.5-1.0) require high confidence, may miss detections.
        mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images).
                       Lower values preserve more detail. Higher values create sharper masks.
                       Medical images often benefit from 0.0 to capture subtle boundaries.

    Returns:
        results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed

    Raises:
        ValueError: If the model/processor have not been loaded (checked before
                    any GPU work, so callers fail fast instead of burning quota).

    Note:
        Default thresholds (0.1, 0.0) are optimized for medical imaging where features
        may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5)
        may be more appropriate.

        Error asymmetry: a missing model raises ValueError, whereas any failure
        during inference itself is logged and returns None.
    """
    if not is_model_loaded():
        logger.error("Model not loaded - please check HF_TOKEN and model availability")
        raise ValueError(
            "SAM 3 model not loaded. Please check that HF_TOKEN is set correctly "
            "and the model is accessible."
        )
    
    try:
        # Determine device and move model to GPU if available
        # (CUDA initialization happens here, inside @spaces.GPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"Using device: {device}")
        
        # Move model to device and set appropriate dtype
        # Note: For nn.Module, .to() modifies in-place and returns self
        # IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued
        # and processed one at a time, so there's NO concurrent access to the model.
        # This makes in-place modification safe despite model being a global variable.
        dtype = torch.float16 if device == "cuda" else torch.float32
        model.to(device=device, dtype=dtype)
        logger.debug(f"Model moved to {device} with dtype {dtype}")
        
        # Prepare inputs - matching official implementation
        inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)
        
        # Convert float32 inputs to model dtype (float16 for GPU)
        # - matching official implementation
        # NOTE(review): only float32 tensors are converted, so integer tensors
        # (e.g. token ids) keep their dtype — presumably intentional; confirm.
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)
        
        with torch.no_grad():
            outputs = model(**inputs)
        
        logger.debug("Inference complete, processing results...")
        
        # Post-process using processor method - matching official implementation
        # NOTE(review): assumes the processor output is dict-like (supports
        # `in` and .get) — confirm against the installed transformers version.
        # Falls back to the PIL size (reversed to height-width order) when the
        # processor did not record original_sizes.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
            if "original_sizes" in inputs
            else [pil_image.size[::-1]]
        )[0]  # Get first batch result
        
        # Debug-only inspection of the result structure (shapes, devices).
        logger.debug(f"Results type: {type(results)}")
        if isinstance(results, dict):
            logger.debug(f"Results keys: {results.keys()}")
            for key, value in results.items():
                logger.debug(f"  - {key}: type={type(value)}")
                if isinstance(value, torch.Tensor):
                    logger.debug(
                        f"    tensor device={value.device}, "
                        f"shape={value.shape}, dtype={value.dtype}"
                    )
                elif isinstance(value, list) and len(value) > 0:
                    logger.debug(f"    list length={len(value)}, first item type={type(value[0])}")
                    if isinstance(value[0], torch.Tensor):
                        logger.debug(f"    first tensor device={value[0].device}")
        
        # CRITICAL: Convert ALL tensors to numpy arrays before returning
        # This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary
        # Numpy arrays are safely serializable without triggering CUDA init
        logger.debug("Converting all tensors to numpy arrays...")
        results = to_serializable(results)
        
        logger.debug("All tensors converted to serializable format")
        
        # Move model back to CPU to free GPU memory (important for Spaces)
        model.to("cpu")
        logger.debug("Model moved back to CPU")
        
        return results
        
    except Exception as e:
        logger.error(f"Error during SAM 3 inference: {e}", exc_info=True)
        # Make sure to move model back to CPU even on error
        if model is not None:
            try:
                model.to("cpu")
            except RuntimeError as cleanup_error:
                logger.warning(f"Could not move model back to CPU: {cleanup_error}")
        return None