"""
Model Registry - Central configuration and factory for all LLM models.
"""
import os
import gc
from typing import Any, Dict, List, Optional

from app.models.base_llm import BaseLLM
from app.models.huggingface_inference_api import HuggingFaceInferenceAPI
from app.models.transformers_model import TransformersModel

# Model configuration
MODEL_CONFIG = {
    "bielik-1.5b-transformer": {
        "id": "speakleash/Bielik-1.5B-v3.0-Instruct",
        "type": "transformers",
        "size": "1.5B",
        "polish_support": "excellent",
        "use_8bit": False,
        "device_map": "auto",
    },
    "bielik-11b-transformer": {
        "id": "speakleash/Bielik-11B-v2.3-Instruct",
        "type": "transformers",
        "size": "11B",
        "polish_support": "excellent",
        "use_8bit": True,
        "device_map": "auto",
        "enable_cpu_offload": True,
    },
    "llama-3.1-8b": {
        "id": "meta-llama/Llama-3.1-8B-Instruct",
        "type": "inference_api",
        "size": "8B",
        "polish_support": "good",
    },
}
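
# How the per-entry keys above are consumed (see _create_model below): "type"
# selects the backend, where "transformers" loads weights locally and
# "inference_api" goes through the hosted Hugging Face endpoint; "use_8bit",
# "device_map" and "enable_cpu_offload" apply only to the transformers
# backend, while "size" and "polish_support" are informational metadata
# surfaced by get_model_info().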

LOCAL_MODEL_BASE = os.getenv("MODEL_DIR", "/app/pretrain_model")


class ModelRegistry:
    """Factory and cache for LLM models, tracking the active local model."""

    def __init__(self):
        self._models: Dict[str, BaseLLM] = {}  # loaded model instances, keyed by registry name
        self._config = MODEL_CONFIG.copy()
        self._active_local_model: Optional[str] = None  # name of the currently loaded transformers model

    def _create_model(self, name: str) -> BaseLLM:
        """Instantiate (but do not initialize) the model described by `name`."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        model_type = config["type"]
        model_id = config["id"]
        if model_type == "transformers":
            use_8bit = config.get("use_8bit", True)
            device_map = config.get("device_map", "auto")
            enable_cpu_offload = config.get("enable_cpu_offload", False)
            return TransformersModel(
                name=name,
                model_id=model_id,
                use_8bit=use_8bit,
                device_map=device_map,
                enable_cpu_offload=enable_cpu_offload,
            )
        elif model_type == "inference_api":
            return HuggingFaceInferenceAPI(name=name, model_id=model_id)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    async def get_model(self, name: str) -> BaseLLM:
        """Return a ready model, loading it on first use.

        When switching between local transformers models, the previously
        active one is unloaded first so its GPU memory can be reclaimed.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        is_local = self._config[name]["type"] == "transformers"
        # Unload the previously active local model before loading a different
        # one, so two sets of local weights never occupy GPU memory at once.
        if is_local and self._active_local_model and self._active_local_model != name:
            print(f"Switching models: unloading '{self._active_local_model}' to load '{name}'")
            await self._unload_model(self._active_local_model)
        if name not in self._models:
            model = self._create_model(name)
            await model.initialize()
            self._models[name] = model
        # Only local models are tracked as "active"; API-backed models hold
        # no local GPU memory and can coexist with a loaded local model.
        if is_local:
            self._active_local_model = name
        return self._models[name]

    async def _unload_model(self, name: str) -> None:
        """Release a loaded model instance and reclaim its memory."""
        if name in self._models:
            model = self._models[name]
            # Give the model a chance to release its own resources
            # (e.g. GPU tensors) before we drop our reference to it.
            if hasattr(model, "cleanup"):
                await model.cleanup()
            del self._models[name]
            gc.collect()
            print(f"Model '{name}' unloaded.")

    def get_model_info(self, name: str) -> Dict[str, Any]:
        """Return configuration and load status for a single model."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        return {
            "name": name,
            "model_id": config["id"],
            "type": config["type"],
            "size": config.get("size", "unknown"),
            "polish_support": config.get("polish_support", "unknown"),
            "loaded": name in self._models,
            "active": name == self._active_local_model,
        }

    def get_available_model_names(self) -> List[str]:
        """Return list of all available model names."""
        return list(self._config.keys())

    def list_models(self) -> List[Dict[str, Any]]:
        """Return list of all models with their info."""
        return [self.get_model_info(name) for name in self._config.keys()]

    def get_loaded_models(self) -> List[str]:
        """Return list of currently loaded model names."""
        return list(self._models.keys())

    def get_active_model(self) -> Optional[str]:
        """Return name of the currently active local model."""
        return self._active_local_model

    async def load_model(self, name: str) -> Dict[str, Any]:
        """Explicitly load a model and return its info."""
        await self.get_model(name)
        return self.get_model_info(name)

    async def unload_model(self, name: str) -> Dict[str, str]:
        """Explicitly unload a model and free its memory."""
        if name in self._models:
            await self._unload_model(name)
            if self._active_local_model == name:
                self._active_local_model = None
            return {"status": "success", "message": f"Model '{name}' unloaded"}
        return {"status": "error", "message": f"Model '{name}' not loaded"}

    async def unload_all_models(self) -> Dict[str, str]:
        """Unload all loaded models and free GPU memory."""
        loaded_models = list(self._models.keys())
        for model_name in loaded_models:
            await self._unload_model(model_name)
        self._active_local_model = None
        return {"status": "success", "message": f"Unloaded {len(loaded_models)} models"}


# Shared module-level registry instance.
registry = ModelRegistry()
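

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not one of the app's entry
# points): shows the intended lifecycle of the registry. The load step
# assumes the model weights and runtime dependencies are actually available,
# so treat this as documentation rather than a test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Inspect the configured models without loading anything.
        for info in registry.list_models():
            print(info["name"], info["type"], "loaded:", info["loaded"])

        # Load the small local model; requesting a different local model
        # later would unload this one automatically inside get_model().
        await registry.load_model("bielik-1.5b-transformer")
        print("active:", registry.get_active_model())

        # Free everything when done.
        print(await registry.unload_all_models())

    asyncio.run(_demo())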