Spaces:
Sleeping
Sleeping
Patryk Studzinski committed on
Commit ·
371aac9
1
Parent(s): 9ecca89
refactor: enhance model unloading and memory management for improved GPU efficiency
Browse files- app/models/registry.py +21 -3
- app/models/transformers_model.py +24 -1
app/models/registry.py
CHANGED
|
@@ -74,11 +74,17 @@ class ModelRegistry:
|
|
| 74 |
async def get_model(self, name: str) -> BaseLLM:
|
| 75 |
config = self._config[name]
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if name not in self._models:
|
| 78 |
model = self._create_model(name)
|
| 79 |
await model.initialize()
|
| 80 |
self._models[name] = model
|
| 81 |
|
|
|
|
| 82 |
return self._models[name]
|
| 83 |
|
| 84 |
async def _unload_model(self, name: str) -> None:
|
|
@@ -123,8 +129,20 @@ class ModelRegistry:
|
|
| 123 |
return self.get_model_info(name)
|
| 124 |
|
| 125 |
async def unload_model(self, name: str) -> Dict[str, str]:
|
| 126 |
-
"""Explicitly unload a model."""
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
registry = ModelRegistry()
|
|
|
|
| 74 |
async def get_model(self, name: str) -> BaseLLM:
    """Return the (possibly cached) model for *name*, loading it on first use.

    Only one local model is kept resident at a time: asking for a model
    other than the currently active one first unloads the active model so
    its GPU memory is released before the new one is created.
    """
    # Dict subscript raises KeyError for unknown model names up front.
    config = self._config[name]

    previous = self._active_local_model
    # Evict the resident model when the caller requests a different one.
    if previous and previous != name:
        print(f"Switching models: unloading '{previous}' to load '{name}'")
        await self._unload_model(previous)

    cached = self._models.get(name)
    if cached is None:
        cached = self._create_model(name)
        await cached.initialize()
        self._models[name] = cached

    self._active_local_model = name
    return cached
|
| 89 |
|
| 90 |
async def _unload_model(self, name: str) -> None:
|
|
|
|
| 129 |
return self.get_model_info(name)
|
| 130 |
|
| 131 |
async def unload_model(self, name: str) -> Dict[str, str]:
    """Release the named model and its memory, reporting the outcome.

    Returns a status dict: ``success`` when the model was resident and has
    been unloaded, ``error`` when no such model is currently loaded.
    """
    if name not in self._models:
        return {"status": "error", "message": f"Model '{name}' not loaded"}

    await self._unload_model(name)
    # Clear the active-model marker if we just removed the active model.
    if self._active_local_model == name:
        self._active_local_model = None
    return {"status": "success", "message": f"Model '{name}' unloaded"}
|
| 139 |
+
|
| 140 |
+
async def unload_all_models(self) -> Dict[str, str]:
    """Release every resident model, then report how many were unloaded."""
    # Snapshot the names up front — _unload_model presumably drops entries
    # from self._models, so iterating the dict directly would be unsafe.
    names = list(self._models)
    for name in names:
        await self._unload_model(name)
    self._active_local_model = None
    return {"status": "success", "message": f"Unloaded {len(names)} models"}
|
| 147 |
|
| 148 |
registry = ModelRegistry()
|
app/models/transformers_model.py
CHANGED
|
@@ -70,6 +70,18 @@ class TransformersModel(BaseLLM):
|
|
| 70 |
|
| 71 |
def _load_model(self) -> None:
|
| 72 |
"""Load model with optimal device configuration and quantization support."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
# Check GPU availability with detailed diagnostics
|
| 74 |
cuda_available = torch.cuda.is_available()
|
| 75 |
cuda_device_count = torch.cuda.device_count() if cuda_available else 0
|
|
@@ -322,6 +334,8 @@ class TransformersModel(BaseLLM):
|
|
| 322 |
|
| 323 |
async def cleanup(self) -> None:
|
| 324 |
"""Free memory."""
|
|
|
|
|
|
|
| 325 |
if self.model:
|
| 326 |
del self.model
|
| 327 |
self.model = None
|
|
@@ -330,8 +344,17 @@ class TransformersModel(BaseLLM):
|
|
| 330 |
self.tokenizer = None
|
| 331 |
self._initialized = False
|
| 332 |
|
|
|
|
|
|
|
|
|
|
| 333 |
# Clear CUDA cache if available
|
| 334 |
if torch.cuda.is_available():
|
| 335 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
-
print(f"[{self.name}] Transformers Model unloaded")
|
|
|
|
| 70 |
|
| 71 |
def _load_model(self) -> None:
|
| 72 |
"""Load model with optimal device configuration and quantization support."""
|
| 73 |
+
import gc
|
| 74 |
+
|
| 75 |
+
# Set PyTorch environment variables for optimal memory management
|
| 76 |
+
if not os.getenv("PYTORCH_CUDA_ALLOC_CONF"):
|
| 77 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 78 |
+
print(f"[{self.name}] Set PYTORCH_CUDA_ALLOC_CONF to prevent GPU memory fragmentation")
|
| 79 |
+
|
| 80 |
+
# Force garbage collection before loading new model
|
| 81 |
+
gc.collect()
|
| 82 |
+
if torch.cuda.is_available():
|
| 83 |
+
torch.cuda.empty_cache()
|
| 84 |
+
|
| 85 |
# Check GPU availability with detailed diagnostics
|
| 86 |
cuda_available = torch.cuda.is_available()
|
| 87 |
cuda_device_count = torch.cuda.device_count() if cuda_available else 0
|
|
|
|
| 334 |
|
| 335 |
async def cleanup(self) -> None:
|
| 336 |
"""Free memory."""
|
| 337 |
+
import gc
|
| 338 |
+
|
| 339 |
if self.model:
|
| 340 |
del self.model
|
| 341 |
self.model = None
|
|
|
|
| 344 |
self.tokenizer = None
|
| 345 |
self._initialized = False
|
| 346 |
|
| 347 |
+
# Aggressive cleanup
|
| 348 |
+
gc.collect() # Force garbage collection
|
| 349 |
+
|
| 350 |
# Clear CUDA cache if available
|
| 351 |
if torch.cuda.is_available():
|
| 352 |
torch.cuda.empty_cache()
|
| 353 |
+
try:
|
| 354 |
+
# Empty reserved memory too (PyTorch 2.0+)
|
| 355 |
+
device_id = torch.cuda.current_device()
|
| 356 |
+
torch.cuda.reset_peak_memory_stats(device_id)
|
| 357 |
+
except:
|
| 358 |
+
pass
|
| 359 |
|
| 360 |
+
print(f"[{self.name}] Transformers Model unloaded and memory freed")
|