# bielik_app_service/app/models/registry.py
# Author: Patryk Studzinski
# refactor: enhance model unloading and memory management for improved GPU efficiency
# Commit: 371aac9
"""
Model Registry - Central configuration and factory for all LLM models.
"""
import os
import gc
from typing import Dict, List, Any, Optional
from app.models.base_llm import BaseLLM
from app.models.huggingface_inference_api import HuggingFaceInferenceAPI
from app.models.transformers_model import TransformersModel
# Model configuration: maps a short registry name to everything needed to
# instantiate the model. "type" selects the backend class ("transformers" runs
# in-process, "inference_api" calls the HuggingFace hosted API); "use_8bit",
# "device_map" and "enable_cpu_offload" are transformers-only loading options.
MODEL_CONFIG = {
"bielik-1.5b-transformer": {
"id": "speakleash/Bielik-1.5B-v3.0-Instruct",
"type": "transformers",
"size": "1.5B",
"polish_support": "excellent",
"use_8bit": False,  # small enough to load in full precision
"device_map": "auto"
},
"bielik-11b-transformer": {
"id": "speakleash/Bielik-11B-v2.3-Instruct",
"type": "transformers",
"size": "11B",
"polish_support": "excellent",
"use_8bit": True,  # 11B needs 8-bit quantization to fit on typical GPUs
"device_map": "auto",
"enable_cpu_offload": True  # spill layers to CPU RAM when VRAM is short
},
"llama-3.1-8b": {
"id": "meta-llama/Llama-3.1-8B-Instruct",
"type": "inference_api",  # remote model: no local GPU memory used
"polish_support": "good",
"size": "8B",
}
}
# Base directory for locally cached model weights, overridable via MODEL_DIR.
# NOTE(review): not referenced in this module — presumably consumed by the
# model classes or other modules; verify before removing.
LOCAL_MODEL_BASE = os.getenv("MODEL_DIR", "/app/pretrain_model")
class ModelRegistry:
    """Factory and cache for LLM backends.

    Keeps at most one local (transformers) model resident at a time: loading
    a different local model first unloads the active one to free GPU memory.
    Remote inference-API models are cached too, but never trigger an unload
    and are never recorded as the active *local* model.
    """

    def __init__(self):
        # Initialized model instances, keyed by registry name.
        self._models: Dict[str, BaseLLM] = {}
        # Private copy so runtime lookups never mutate the module-level default.
        self._config = MODEL_CONFIG.copy()
        # Name of the currently resident local (GPU) model, if any.
        self._active_local_model: Optional[str] = None

    def _is_local(self, name: str) -> bool:
        """Return True if the named model runs in-process and occupies GPU/CPU memory."""
        return self._config[name]["type"] == "transformers"

    def _create_model(self, name: str) -> BaseLLM:
        """Instantiate (but do not initialize) the backend for `name`.

        Raises:
            ValueError: if `name` is not configured or has an unknown type.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        model_type = config["type"]
        model_id = config["id"]
        if model_type == "transformers":
            return TransformersModel(
                name=name,
                model_id=model_id,
                use_8bit=config.get("use_8bit", True),
                device_map=config.get("device_map", "auto"),
                enable_cpu_offload=config.get("enable_cpu_offload", False)
            )
        elif model_type == "inference_api":
            return HuggingFaceInferenceAPI(name=name, model_id=model_id)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    async def get_model(self, name: str) -> BaseLLM:
        """Return an initialized model, loading it on first use.

        Requesting a local model other than the active one unloads the active
        model first. Requesting a remote inference-API model does NOT evict
        the resident local model — calling a hosted API costs no GPU memory,
        so throwing away a loaded model would be pure waste.

        Raises:
            ValueError: if `name` is not a configured model (consistent with
                _create_model, instead of leaking a bare KeyError).
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        is_local = self._is_local(name)
        if is_local and self._active_local_model and self._active_local_model != name:
            print(f"Switching models: unloading '{self._active_local_model}' to load '{name}'")
            await self._unload_model(self._active_local_model)
            # Clear immediately so a failure during the load below cannot
            # leave a stale "active" pointer to an already-unloaded model.
            self._active_local_model = None
        if name not in self._models:
            model = self._create_model(name)
            await model.initialize()
            self._models[name] = model
        if is_local:
            self._active_local_model = name
        return self._models[name]

    async def _unload_model(self, name: str) -> None:
        """Drop a cached model and aggressively reclaim its memory."""
        if name not in self._models:
            return
        model = self._models.pop(name)
        # Let the backend release its own resources (GPU tensors, sessions).
        if hasattr(model, 'cleanup'):
            await model.cleanup()
        del model
        gc.collect()
        # Best-effort: hand freed CUDA blocks back to the driver so the next
        # model can allocate them. torch may be absent in API-only deployments.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass
        print(f"Model '{name}' unloaded.")

    def get_model_info(self, name: str) -> Dict[str, Any]:
        """Return a status/config summary for one model.

        Raises:
            ValueError: if `name` is not a configured model.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        return {
            "name": name,
            "model_id": config["id"],
            "type": config["type"],
            "size": config.get("size", "unknown"),
            "polish_support": config.get("polish_support", "unknown"),
            "loaded": name in self._models,
            "active": name == self._active_local_model
        }

    def get_available_model_names(self) -> List[str]:
        """Return list of all available model names."""
        return list(self._config.keys())

    def list_models(self) -> List[Dict[str, Any]]:
        """Return list of all models with their info."""
        return [self.get_model_info(name) for name in self._config.keys()]

    def get_loaded_models(self) -> List[str]:
        """Return list of currently loaded model names."""
        return list(self._models.keys())

    def get_active_model(self) -> Optional[str]:
        """Return name of currently active local model."""
        return self._active_local_model

    async def load_model(self, name: str) -> Dict[str, Any]:
        """Explicitly load a model and return its info."""
        await self.get_model(name)
        return self.get_model_info(name)

    async def unload_model(self, name: str) -> Dict[str, str]:
        """Explicitly unload a model and free its memory."""
        if name in self._models:
            # Clear the active pointer BEFORE unloading so no intermediate
            # state reports an unloaded model as active.
            if self._active_local_model == name:
                self._active_local_model = None
            await self._unload_model(name)
            return {"status": "success", "message": f"Model '{name}' unloaded"}
        return {"status": "error", "message": f"Model '{name}' not loaded"}

    async def unload_all_models(self) -> Dict[str, str]:
        """Unload all loaded models and free GPU memory."""
        loaded_models = list(self._models.keys())
        self._active_local_model = None
        for model_name in loaded_models:
            await self._unload_model(model_name)
        return {"status": "success", "message": f"Unloaded {len(loaded_models)} models"}
registry = ModelRegistry()