Patryk Studzinski committed on
Commit
371aac9
·
1 Parent(s): 9ecca89

refactor: enhance model unloading and memory management for improved GPU efficiency

Browse files
app/models/registry.py CHANGED
@@ -74,11 +74,17 @@ class ModelRegistry:
74
  async def get_model(self, name: str) -> BaseLLM:
75
  config = self._config[name]
76
 
 
 
 
 
 
77
  if name not in self._models:
78
  model = self._create_model(name)
79
  await model.initialize()
80
  self._models[name] = model
81
 
 
82
  return self._models[name]
83
 
84
  async def _unload_model(self, name: str) -> None:
@@ -123,8 +129,20 @@ class ModelRegistry:
123
  return self.get_model_info(name)
124
 
125
  async def unload_model(self, name: str) -> Dict[str, str]:
126
- """Explicitly unload a model."""
127
- await self._unload_model(name)
128
- return {"status": "unloaded", "model": name}
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  registry = ModelRegistry()
 
74
  async def get_model(self, name: str) -> BaseLLM:
75
  config = self._config[name]
76
 
77
+ # Unload previously active model to free GPU memory when switching models
78
+ if self._active_local_model and self._active_local_model != name:
79
+ print(f"Switching models: unloading '{self._active_local_model}' to load '{name}'")
80
+ await self._unload_model(self._active_local_model)
81
+
82
  if name not in self._models:
83
  model = self._create_model(name)
84
  await model.initialize()
85
  self._models[name] = model
86
 
87
+ self._active_local_model = name
88
  return self._models[name]
89
 
90
  async def _unload_model(self, name: str) -> None:
 
129
  return self.get_model_info(name)
130
 
131
  async def unload_model(self, name: str) -> Dict[str, str]:
132
+ """Explicitly unload a model and free its memory."""
133
+ if name in self._models:
134
+ await self._unload_model(name)
135
+ if self._active_local_model == name:
136
+ self._active_local_model = None
137
+ return {"status": "success", "message": f"Model '{name}' unloaded"}
138
+ return {"status": "error", "message": f"Model '{name}' not loaded"}
139
+
140
+ async def unload_all_models(self) -> Dict[str, str]:
141
+ """Unload all loaded models and free GPU memory."""
142
+ loaded_models = list(self._models.keys())
143
+ for model_name in loaded_models:
144
+ await self._unload_model(model_name)
145
+ self._active_local_model = None
146
+ return {"status": "success", "message": f"Unloaded {len(loaded_models)} models"}
147
 
148
  registry = ModelRegistry()
app/models/transformers_model.py CHANGED
@@ -70,6 +70,18 @@ class TransformersModel(BaseLLM):
70
 
71
  def _load_model(self) -> None:
72
  """Load model with optimal device configuration and quantization support."""
 
 
 
 
 
 
 
 
 
 
 
 
73
  # Check GPU availability with detailed diagnostics
74
  cuda_available = torch.cuda.is_available()
75
  cuda_device_count = torch.cuda.device_count() if cuda_available else 0
@@ -322,6 +334,8 @@ class TransformersModel(BaseLLM):
322
 
323
  async def cleanup(self) -> None:
324
  """Free memory."""
 
 
325
  if self.model:
326
  del self.model
327
  self.model = None
@@ -330,8 +344,17 @@ class TransformersModel(BaseLLM):
330
  self.tokenizer = None
331
  self._initialized = False
332
 
 
 
 
333
  # Clear CUDA cache if available
334
  if torch.cuda.is_available():
335
  torch.cuda.empty_cache()
 
 
 
 
 
 
336
 
337
- print(f"[{self.name}] Transformers Model unloaded")
 
70
 
71
  def _load_model(self) -> None:
72
  """Load model with optimal device configuration and quantization support."""
73
+ import gc
74
+
75
+ # Set PyTorch environment variables for optimal memory management
76
+ if not os.getenv("PYTORCH_CUDA_ALLOC_CONF"):
77
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
78
+ print(f"[{self.name}] Set PYTORCH_CUDA_ALLOC_CONF to prevent GPU memory fragmentation")
79
+
80
+ # Force garbage collection before loading new model
81
+ gc.collect()
82
+ if torch.cuda.is_available():
83
+ torch.cuda.empty_cache()
84
+
85
  # Check GPU availability with detailed diagnostics
86
  cuda_available = torch.cuda.is_available()
87
  cuda_device_count = torch.cuda.device_count() if cuda_available else 0
 
334
 
335
  async def cleanup(self) -> None:
336
  """Free memory."""
337
+ import gc
338
+
339
  if self.model:
340
  del self.model
341
  self.model = None
 
344
  self.tokenizer = None
345
  self._initialized = False
346
 
347
+ # Aggressive cleanup
348
+ gc.collect() # Force garbage collection
349
+
350
  # Clear CUDA cache if available
351
  if torch.cuda.is_available():
352
  torch.cuda.empty_cache()
353
+ try:
354
+ # Empty reserved memory too (PyTorch 2.0+)
355
+ device_id = torch.cuda.current_device()
356
+ torch.cuda.reset_peak_memory_stats(device_id)
357
+ except:
358
+ pass
359
 
360
+ print(f"[{self.name}] Transformers Model unloaded and memory freed")