File size: 5,555 Bytes
7f8bfb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87ef7db
7f8bfb2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from typing import Any, Dict, List, Optional

import yaml
from loguru import logger

from .base import RerankerModel
from .cross_encoder import SentenceTransformersReranker, QwenReranker


class ModelManager:
    """
    Manager for reranking models with preloading and configuration.

    Reads model definitions from a YAML file (default: config.yaml), instantiates
    the configured reranker models on request, and offers lookup and listing of
    those models. A default model ID from the config is used when a caller does
    not name a model explicitly.

    Attributes:
        models (Dict[str, RerankerModel]): Loaded model instances keyed by model ID.
        model_configs (Dict[str, Dict[str, Any]]): The ``models`` mapping from the YAML file.
        default_model_id (Optional[str]): The ``default_model`` value from the YAML file,
            or None when it is absent or the config could not be read.
    """

    def __init__(self, config_path: str = 'config.yaml') -> None:
        """
        Initialize the ModelManager and load model configurations from a YAML file.

        Loading is best-effort: any failure is logged and the manager starts with
        an empty configuration instead of raising.

        Args:
            config_path (str): Path to the YAML configuration file. Defaults to 'config.yaml'.

        Side Effects:
            Sets self.models to an empty dictionary.
            Sets self.model_configs from the ``models`` key (empty dict on failure).
            Sets self.default_model_id from the ``default_model`` key (None on failure).
        """
        self.models: Dict[str, RerankerModel] = {}
        # Fallback values; they stay in place whenever the config cannot be used.
        self.model_configs: Dict[str, Dict[str, Any]] = {}
        self.default_model_id: Optional[str] = None

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except Exception as e:
            # Deliberately broad: a missing or malformed config must not prevent
            # construction. Report the path that was actually used.
            logger.error(f"Failed to load {config_path}: {e}")
            return

        # yaml.safe_load returns None for an empty document; treat it as "no settings".
        if config_data is None:
            config_data = {}
        if not isinstance(config_data, dict):
            logger.error(
                f"Failed to load {config_path}: top-level YAML value must be a mapping, "
                f"got {type(config_data).__name__}"
            )
            return

        # `models: null` yields None; normalise so later len()/.items() calls are safe.
        model_configs = config_data.get('models') or {}
        if not isinstance(model_configs, dict):
            logger.error(
                f"Failed to load {config_path}: 'models' must be a mapping, "
                f"got {type(model_configs).__name__}"
            )
            model_configs = {}

        self.model_configs = model_configs
        self.default_model_id = config_data.get('default_model')
        logger.info(f"Loaded model configs from {config_path}")

    async def preload_all_models(self) -> None:
        """
        Preload all models defined in the configuration file.

        For every configured model, picks the reranker class matching its
        ``model_type`` (SentenceTransformersReranker for "sentence_transformers",
        QwenReranker for "qwen"), calls its ``load()`` method, and stores the
        instance in self.models. Entries with an unknown type are logged and skipped.

        This method does not raise on a per-model failure: the error is logged and
        the loop continues with the next model. A summary is logged at the end,
        at warning level when fewer models are loaded than are configured.

        Note:
            The body contains no ``await``; each ``model.load()`` call runs
            synchronously inside this coroutine.
        """
        # Maps the config's model_type value to the class that implements it.
        reranker_classes = {
            "sentence_transformers": SentenceTransformersReranker,
            "qwen": QwenReranker,
        }

        total = len(self.model_configs)
        logger.info(f"Starting preload of {total} reranking models...")

        for model_id, config in self.model_configs.items():
            try:
                logger.info(f"Loading {model_id}...")

                model_type = config["model_type"]
                reranker_cls = reranker_classes.get(model_type)
                if reranker_cls is None:
                    logger.error(f"Unknown model type: {model_type}")
                    continue

                model = reranker_cls(
                    model_id=model_id,
                    model_name=config["model_name"],
                    model_type=model_type,
                )
                model.load()
                # Register only after load() returned without raising.
                self.models[model_id] = model
                logger.success(f"Successfully preloaded {model_id}")

            except Exception as e:
                # Best-effort: one broken entry must not stop the remaining models.
                logger.error(f"Failed to preload {model_id}: {e}")

        loaded_count = sum(1 for m in self.models.values() if m.loaded)
        summary = f"Preloaded {loaded_count}/{total} models successfully"
        if loaded_count >= total:
            logger.success(summary)
        else:
            # Make partial failures visible instead of reporting them as success.
            logger.warning(summary)

    def get_model(self, model_id: Optional[str] = None) -> RerankerModel:
        """
        Retrieve a loaded model instance by its ID, or use the default model if not specified.

        Args:
            model_id (str, optional): The unique identifier of the model to retrieve.
                If None, self.default_model_id is used.

        Returns:
            RerankerModel: The loaded reranker model instance.

        Raises:
            ValueError: If no ID is given and no default is configured, if the ID is
                not present in self.models, or if the model's ``loaded`` flag is falsy.
        """
        if model_id is None:
            if not self.default_model_id:
                raise ValueError("No model_id provided and no default_model set in config.yaml")
            model_id = self.default_model_id

        model = self.models.get(model_id)
        if model is None:
            raise ValueError(f"Model {model_id} not found")
        if not model.loaded:
            raise ValueError(f"Model {model_id} not loaded")

        return model

    def list_models(self) -> List[Dict[str, Any]]:
        """
        List all configured models with their configuration and load status.

        Returns:
            List[Dict[str, Any]]: One dictionary per configured model with the keys
            "id", "name", "type", "language", "description", "repository" and "loaded".
            Values missing from the configuration are None; "loaded" is False for
            models that have no instance in self.models.
        """
        models_info = []
        for model_id, config in self.model_configs.items():
            # A model declared with no body in YAML (`foo:`) parses to None.
            if not isinstance(config, dict):
                config = {}
            model = self.models.get(model_id)
            models_info.append({
                "id": model_id,
                "name": config.get("model_name"),
                "type": config.get("model_type"),
                "language": config.get("languages"),
                "description": config.get("description"),
                "repository": config.get("repository"),
                "loaded": model.loaded if model else False,
            })
        return models_info