Spaces:
Sleeping
Sleeping
File size: 12,025 Bytes
0234c58 c2c8715 42a4892 4780d8d c2c8715 0234c58 4780d8d 0234c58 4780d8d 0234c58 6e06a36 0234c58 42a4892 0234c58 6e06a36 42a4892 6e06a36 0234c58 42a4892 0234c58 42a4892 6e06a36 42a4892 6e06a36 4780d8d 0234c58 c2c8715 0234c58 4780d8d 0234c58 4780d8d 0234c58 4780d8d 0234c58 4780d8d 0234c58 6e06a36 0234c58 6e06a36 0234c58 6e06a36 0234c58 6e06a36 0234c58 42a4892 0234c58 4780d8d 0234c58 4780d8d 0234c58 4780d8d 0234c58 6e06a36 0234c58 4780d8d 0234c58 6e06a36 4780d8d 6e06a36 4780d8d 6e06a36 0234c58 6e06a36 0234c58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 |
"""Model management module for batch processing optimization.
This module provides model loading and caching infrastructure to support
efficient batch processing of multiple slides by loading models once instead
of reloading for each slide.
"""
import gc
import pickle
from pathlib import Path
from typing import Dict, Optional
import torch
from loguru import logger
from mosaic.data_directory import get_data_directory
from mosaic.hardware import IS_T4_GPU, GPU_NAME
from mussel.models import ModelType, get_model_factory
class ModelCache:
    """Container for pre-loaded models with T4-aware memory management.

    This class manages loading and caching of all models used in the slide
    analysis pipeline. It implements adaptive memory management that adjusts
    behavior based on GPU type (T4 vs A100) to avoid out-of-memory errors.

    Attributes:
        ctranspath_model: Pre-loaded CTransPath feature extraction model
        optimus_model: Pre-loaded Optimus feature extraction model
        marker_classifier: Pre-loaded marker classifier model
        aeon_model: Pre-loaded Aeon cancer subtype prediction model
        paladin_models: Dict mapping a Paladin model key -> loaded model.
            ``load_paladin_model_for_inference`` keys entries by
            ``str(model_path)`` and only populates it in caching mode.
        is_t4_gpu: Whether running on a T4 GPU (16GB memory)
        aggressive_memory_mgmt: If True, aggressively free Paladin models after use
        device: torch.device for GPU/CPU placement
    """

    def __init__(
        self,
        ctranspath_model=None,
        optimus_model=None,
        marker_classifier=None,
        aeon_model=None,
        is_t4_gpu=False,
        aggressive_memory_mgmt=False,
        device: Optional[torch.device] = None,
    ):
        self.ctranspath_model = ctranspath_model
        self.optimus_model = optimus_model
        self.marker_classifier = marker_classifier
        self.aeon_model = aeon_model
        self.paladin_models: Dict[str, torch.nn.Module] = {}
        self.is_t4_gpu = is_t4_gpu
        self.aggressive_memory_mgmt = aggressive_memory_mgmt
        # Default to CUDA when available, otherwise CPU.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def cleanup_paladin(self):
        """Aggressively free all cached Paladin models from memory.

        Used on T4 GPUs to free memory between inferences. Does nothing when
        no Paladin models are cached.
        """
        if not self.paladin_models:
            return
        logger.debug(f"Cleaning up {len(self.paladin_models)} Paladin models")
        # Drop the cache's references; a model is freed once no caller holds it.
        self.paladin_models.clear()
        # Collect first so unreachable models (including reference cycles) are
        # actually released, THEN hand their cached CUDA blocks back to the
        # driver. The reverse order leaves those blocks in PyTorch's cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def cleanup(self):
        """Release all models and free GPU memory.

        Called at the end of batch processing to ensure clean shutdown. Logs
        the remaining allocated GPU memory when CUDA is available.
        """
        logger.info("Cleaning up all models from memory")
        # Clean up Paladin models
        self.cleanup_paladin()
        # Drop references to the core models so they can be garbage collected.
        self.ctranspath_model = None
        self.optimus_model = None
        self.marker_classifier = None
        self.aeon_model = None
        # Force garbage collection before clearing the GPU allocator cache
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            mem_allocated = torch.cuda.memory_allocated() / (1024**3)
            logger.info(f"GPU memory after cleanup: {mem_allocated:.2f} GB")
def _log_gpu_memory(device: torch.device) -> None:
    """Log currently allocated CUDA memory (GB) when `device` is a CUDA device."""
    if device.type == "cuda":
        mem = torch.cuda.memory_allocated() / (1024**3)
        logger.info(f" GPU memory: {mem:.2f} GB")


def load_all_models(
    use_gpu: bool = True,
    aggressive_memory_mgmt: Optional[bool] = None,
) -> ModelCache:
    """Load core models once for batch processing.

    Loads CTransPath, Optimus, Marker Classifier, and Aeon models into memory.
    Paladin models are loaded on-demand via load_paladin_model_for_inference().

    Args:
        use_gpu: If True, load models to GPU when CUDA is available. If CUDA
            is not available, a warning is logged and models go to CPU.
        aggressive_memory_mgmt: Memory management strategy:
            - None: Auto-detect. On a CUDA device this follows the detected
              GPU type (T4 = True, otherwise False); on CPU it is False.
            - True: T4-style aggressive cleanup (load/delete Paladin models)
            - False: A100-style caching (keep Paladin models loaded)

    Returns:
        ModelCache instance with all core models loaded. Its
        ``aggressive_memory_mgmt`` attribute is always a bool, never None.

    Raises:
        FileNotFoundError: If the CTransPath, marker classifier, or Aeon model
            file is not found in the model data directory.

    Note:
        The marker classifier and Aeon model are loaded with ``pickle``, which
        can execute arbitrary code; the data directory must be trusted.
    """
    logger.info("=" * 80)
    logger.info("BATCH PROCESSING: Loading models (this happens ONCE per batch)")
    logger.info("=" * 80)

    # Resolve device and memory strategy using centralized GPU detection
    device = torch.device("cpu")
    if use_gpu and torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_memory_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"GPU detected: {GPU_NAME}")
        logger.info(f"GPU total memory: {gpu_memory_total:.2f} GB")
        # Log initial GPU memory
        mem_before = torch.cuda.memory_allocated() / (1024**3)
        logger.info(f"GPU memory before loading models: {mem_before:.2f} GB")
        # Auto-detect memory management strategy based on centralized hardware detection
        if aggressive_memory_mgmt is None:
            aggressive_memory_mgmt = IS_T4_GPU
            strategy = "AGGRESSIVE (T4)" if IS_T4_GPU else "CACHING (High-Memory GPU)"
            logger.info(f"Memory management strategy: {strategy}")
            if IS_T4_GPU:
                logger.info(" β Paladin models will be loaded and freed per slide")
            else:
                logger.info(
                    " β Paladin models will be cached and reused across slides"
                )
    elif use_gpu:
        logger.warning("GPU requested but CUDA not available, falling back to CPU")
        use_gpu = False

    # Every path that did not auto-detect above (explicit use_gpu=False, or the
    # CPU fallback) defaults to caching mode. Previously use_gpu=False left the
    # flag as None, which then leaked into ModelCache.
    if aggressive_memory_mgmt is None:
        aggressive_memory_mgmt = False

    # Get model data directory (HF cache or local data/)
    data_dir = get_data_directory()
    logger.info(f"Using model data directory: {data_dir}")

    # Load CTransPath model
    logger.info("Loading CTransPath model...")
    ctranspath_path = data_dir / "ctranspath.pth"
    if not ctranspath_path.exists():
        raise FileNotFoundError(f"CTransPath model not found at {ctranspath_path}")
    ctranspath_factory = get_model_factory(ModelType.CTRANSPATH)
    ctranspath_model = ctranspath_factory.get_model(
        str(ctranspath_path), use_gpu=use_gpu, gpu_device_id=0 if use_gpu else None
    )
    logger.info("β CTransPath model loaded")
    _log_gpu_memory(device)

    # Load Optimus model from Hugging Face Hub
    logger.info("Loading Optimus model from bioptimus/H-optimus-0...")
    optimus_factory = get_model_factory(ModelType.OPTIMUS)
    optimus_model = optimus_factory.get_model(
        model_path="hf-hub:bioptimus/H-optimus-0",
        use_gpu=use_gpu,
        gpu_device_id=0 if use_gpu else None,
    )
    logger.info("β Optimus model loaded")
    _log_gpu_memory(device)

    # Load Marker Classifier (not moved to a device here)
    logger.info("Loading Marker Classifier...")
    marker_classifier_path = data_dir / "marker_classifier.pkl"
    if not marker_classifier_path.exists():
        raise FileNotFoundError(
            f"Marker classifier not found at {marker_classifier_path}"
        )
    with open(marker_classifier_path, "rb") as f:
        marker_classifier = pickle.load(f)  # nosec - trusted model file
    logger.info("β Marker Classifier loaded")
    _log_gpu_memory(device)

    # Load Aeon model and put it in inference mode on the target device
    logger.info("Loading Aeon model...")
    aeon_path = data_dir / "aeon_model.pkl"
    if not aeon_path.exists():
        raise FileNotFoundError(f"Aeon model not found at {aeon_path}")
    with open(aeon_path, "rb") as f:
        aeon_model = pickle.load(f)  # nosec - trusted model file
    aeon_model.to(device)
    aeon_model.eval()
    logger.info("β Aeon model loaded")
    _log_gpu_memory(device)

    # Log final memory usage
    logger.info("-" * 80)
    if device.type == "cuda":
        mem_allocated = torch.cuda.memory_allocated() / (1024**3)
        logger.info("β All core models loaded to GPU")
        logger.info(f" Total GPU memory used: {mem_allocated:.2f} GB")
        logger.info(" These models will be REUSED for all slides in this batch")
    else:
        logger.info("β All core models loaded to CPU")
        logger.info(" These models will be REUSED for all slides in this batch")
    logger.info("-" * 80)

    return ModelCache(
        ctranspath_model=ctranspath_model,
        optimus_model=optimus_model,
        marker_classifier=marker_classifier,
        aeon_model=aeon_model,
        is_t4_gpu=IS_T4_GPU,
        aggressive_memory_mgmt=aggressive_memory_mgmt,
        device=device,
    )
def load_paladin_model_for_inference(
    cache: ModelCache,
    model_path: Path,
) -> torch.nn.Module:
    """Load a single Paladin model for inference, downloading on-demand if needed.

    Implements adaptive loading strategy:
    - T4 GPU (aggressive mode): Load model fresh, caller must delete after use
    - A100 GPU (caching mode): Check cache, load if needed, return cached model

    If the model file doesn't exist locally, downloads it from the
    ``PDM-Group/paladin-aeon-models`` repository on the Hugging Face Hub.

    Args:
        cache: ModelCache instance managing loaded models
        model_path: Path to the Paladin model file. When it does not exist
            locally it must lie inside the model data directory, because its
            path relative to that directory is used as the Hub filename.

    Returns:
        Loaded Paladin model (moved to ``cache.device``, in eval mode)

    Raises:
        ValueError: If the file is missing locally and ``model_path`` is not
            inside the model data directory.

    Note:
        On T4 GPUs, caller MUST delete the model and call torch.cuda.empty_cache()
        after inference to avoid OOM errors. The model file is unpickled, so it
        must come from a trusted source.
    """
    # Key on the requested path (before any download redirection) so repeat
    # requests for the same model hit the cache.
    model_key = str(model_path)

    # Check cache first (only used in non-aggressive mode)
    if not cache.aggressive_memory_mgmt and model_key in cache.paladin_models:
        logger.info(f" β Using CACHED Paladin model: {model_path.name} (no disk I/O!)")
        return cache.paladin_models[model_key]

    # Download model from HF Hub if it doesn't exist locally
    if not model_path.exists():
        # Imported lazily so cache hits / local files don't need huggingface_hub
        from huggingface_hub import hf_hub_download

        logger.info(
            f" β¬ Downloading Paladin model from HuggingFace Hub: {model_path.name}"
        )
        # Derive the repo-relative filename from the data directory
        data_dir = get_data_directory()
        try:
            relative_path = model_path.relative_to(data_dir)
        except ValueError as err:
            raise ValueError(
                f"Paladin model {model_path} is missing locally and is not inside "
                f"the model data directory {data_dir}; cannot derive its "
                f"Hugging Face Hub filename"
            ) from err
        downloaded_path = hf_hub_download(
            repo_id="PDM-Group/paladin-aeon-models",
            # Hub repo paths always use "/", even on Windows.
            filename=relative_path.as_posix(),
            # NOTE(review): assumes data_dir sits two levels below the HF cache
            # root; confirm against get_data_directory().
            cache_dir=data_dir.parent.parent,
        )
        model_path = Path(downloaded_path)
        logger.info(f" β Downloaded to: {model_path}")

    # Load model from disk
    if cache.aggressive_memory_mgmt:
        logger.info(
            f" β Loading Paladin model: {model_path.name} (will free after use)"
        )
    else:
        logger.info(
            f" β Loading Paladin model: {model_path.name} (will cache for reuse)"
        )
    with open(model_path, "rb") as f:
        model = pickle.load(f)  # nosec - trusted model file
    model.to(cache.device)
    model.eval()

    # Cache if not in aggressive mode
    if not cache.aggressive_memory_mgmt:
        cache.paladin_models[model_key] = model
        logger.info(" β Cached Paladin model for future reuse")
    return model
|