Spaces:
Running
Running
| import yaml | |
| from typing import List, Dict, Any | |
| from loguru import logger | |
| from .base import RerankerModel | |
| from .cross_encoder import SentenceTransformersReranker, QwenReranker | |
class ModelManager:
    """
    Manager for reranking models with preloading and configuration.

    This class loads model configurations from a YAML file (default: config.yaml),
    instantiates and manages multiple reranker models, and provides methods to preload,
    retrieve, and list the available models. Supports a default model if model_id is
    not provided.

    Attributes:
        models (Dict[str, RerankerModel]): Loaded model instances keyed by model ID.
        model_configs (Dict[str, Dict[str, Any]]): Per-model configuration read from the
            ``models`` key of the YAML file. Always a mapping (empty when loading failed
            or the key is missing/null).
        default_model_id (str | None): Value of the ``default_model`` key in the YAML
            file, or None when it is absent or the file could not be loaded.
    """

    def __init__(self, config_path: str = 'config.yaml'):
        """
        Initialize the ModelManager and load model configurations from a YAML file.

        Loading is best-effort: any failure (missing file, invalid YAML, unexpected
        top-level structure) is logged and leaves the manager with no model
        configurations and no default model instead of raising.

        Args:
            config_path (str): Path to the YAML configuration file. Defaults to 'config.yaml'.

        Side Effects:
            Initializes an empty dictionary for loaded models.
            Loads model configuration into self.model_configs.
            Sets the default model ID from config.
        """
        self.models: Dict[str, RerankerModel] = {}
        try:
            # Explicit encoding so the file is read the same way on every platform.
            with open(config_path, 'r', encoding='utf-8') as f:
                # An empty YAML document parses to None; treat it as an empty mapping.
                config_data = yaml.safe_load(f) or {}
            # A bare "models:" key parses to None as well; normalise it to {} so that
            # preload_all_models() and list_models() can always iterate over it.
            self.model_configs = config_data.get('models') or {}
            self.default_model_id = config_data.get('default_model')
            logger.info(f"Loaded model configs from {config_path}")
        except Exception as e:
            # Deliberately broad: the manager must stay constructible without a config.
            logger.error(f"Failed to load {config_path}: {e}")
            self.model_configs = {}
            self.default_model_id = None

    async def preload_all_models(self):
        """
        Preload all models defined in the configuration file.

        Iterates through all model configurations, instantiates the reranker class that
        matches the entry's ``model_type`` (SentenceTransformersReranker for
        "sentence_transformers", QwenReranker for "qwen"), calls ``load()`` on it and
        stores it in self.models. Logs the status of each model load and a summary at
        the end.

        This method does not raise for per-model problems: an unknown ``model_type`` or
        any exception while building/loading a model is logged and the loop continues
        with the next model.
        """
        # Maps the configured model_type to the class that implements it.
        reranker_classes = {
            "sentence_transformers": SentenceTransformersReranker,
            "qwen": QwenReranker,
        }

        logger.info(f"Starting preload of {len(self.model_configs)} reranking models...")
        for model_id, config in self.model_configs.items():
            try:
                logger.info(f"Loading {model_id}...")
                model_type = config["model_type"]
                reranker_cls = reranker_classes.get(model_type)
                if reranker_cls is None:
                    logger.error(f"Unknown model type: {model_type}")
                    continue

                model = reranker_cls(
                    model_id=model_id,
                    model_name=config["model_name"],
                    model_type=model_type,
                )
                # NOTE(review): load() is invoked synchronously inside this coroutine;
                # if it blocks for long, the event loop is blocked too - confirm that
                # this is acceptable at the call site (e.g. startup only).
                model.load()
                self.models[model_id] = model
                logger.success(f"Successfully preloaded {model_id}")
            except Exception as e:
                # Covers missing config keys as well as constructor/load failures.
                logger.error(f"Failed to preload {model_id}: {e}")

        loaded_count = sum(1 for m in self.models.values() if m.loaded)
        logger.success(f"Preloaded {loaded_count}/{len(self.model_configs)} models successfully")

    def get_model(self, model_id: str = None) -> "RerankerModel":
        """
        Retrieve a loaded model instance by its ID, or use the default model if not specified.

        Args:
            model_id (str, optional): The unique identifier of the model to retrieve.
                If None, the configured default model ID is used.

        Returns:
            RerankerModel: The loaded reranker model instance.

        Raises:
            ValueError: If no model_id is given and no default is configured, if the
                model is not present in self.models, or if it is present but its
                ``loaded`` flag is false.
        """
        if model_id is None:
            if not self.default_model_id:
                raise ValueError("No model_id provided and no default_model set in config.yaml")
            model_id = self.default_model_id

        if model_id not in self.models:
            raise ValueError(f"Model {model_id} not found")

        model = self.models[model_id]
        if not model.loaded:
            raise ValueError(f"Model {model_id} not loaded")
        return model

    def list_models(self) -> List[Dict[str, Any]]:
        """
        List all configured models with their configuration and load status.

        Returns:
            List[Dict[str, Any]]: One dictionary per entry in self.model_configs with
            the keys ``id``, ``name``, ``type``, ``language``, ``description``,
            ``repository`` and ``loaded``. Configuration values that are not set are
            reported as None; ``loaded`` is False for models that have not been
            preloaded.
        """
        models_info = []
        for model_id, config in self.model_configs.items():
            # An entry declared without a body (e.g. "model_a:") parses to None.
            config = config or {}
            model = self.models.get(model_id)
            models_info.append({
                "id": model_id,
                "name": config.get("model_name"),
                "type": config.get("model_type"),
                # The config key is plural ("languages"); the output key is singular.
                "language": config.get("languages"),
                "description": config.get("description"),
                "repository": config.get("repository"),
                "loaded": model.loaded if model else False,
            })
        return models_info