"""
Model Manager Module

This module provides centralized management of AI models for the HiveGPT Agent
system. It handles loading, caching, and lifecycle management of language
models (LLMs), with coroutine-safe loading for asyncio-based callers.

The ModelManager class offers:
- Lazy loading and caching of language models
- Coroutine-safe model loading guarded by an asyncio lock
- Integration with ModelRouter for model discovery
- Memory-efficient model reuse across requests
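
The caching pattern behind `load_llm_model` can be sketched on its own: check
the cache, take the lock, check again, then load. The sketch below is minimal,
self-contained and uses only asyncio. `CachedLoader` and `fake_load` are
hypothetical stand-ins for ModelManager and the ModelRouter/ChatOpenAI setup.
They are not part of this module.

```python
import asyncio
from typing import Callable, Dict, List


class CachedLoader:
    # Hypothetical stand-in for ModelManager: caches one object per name.

    def __init__(self, load: Callable[[str], object]) -> None:
        self._cache: Dict[str, object] = {}
        self._lock = asyncio.Lock()
        self._load = load

    async def get(self, name: str) -> object:
        # Fast path: there is no await between the check and the return.
        if name in self._cache:
            return self._cache[name]
        async with self._lock:
            # Re-check: another coroutine may have loaded it while we waited.
            if name not in self._cache:
                await asyncio.sleep(0)  # yield control, simulating slow I/O
                self._cache[name] = self._load(name)
            return self._cache[name]


load_calls: List[str] = []


def fake_load(name: str) -> object:
    # Hypothetical loader standing in for ModelRouter + ChatOpenAI.
    load_calls.append(name)
    return object()


async def main() -> list:
    # Create the loader (and its lock) inside the running event loop.
    loader = CachedLoader(fake_load)
    # Ten concurrent requests for the same model name.
    return await asyncio.gather(*(loader.get("demo-model") for _ in range(10)))


results = asyncio.run(main())
print(len(load_calls))  # 1
print(all(r is results[0] for r in results))  # True
```

The second membership check inside the lock is what keeps the load count at
one. Without it, each coroutine that was waiting on the lock would load its own
copy once it acquired the lock.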

Key Features:
- Per-instance model cache; share one ModelManager across the application for
  consistent model access
- Async/await support for non-blocking operations
- Automatic model caching to improve performance
- Model loading errors propagate to the caller, and nothing is cached for a
  model that failed to load

Author: HiveNetCode
License: Private
"""

import asyncio
from typing import Dict, Any, Optional

from langchain_openai import ChatOpenAI

from ComputeAgent.models.model_router import ModelRouter, LLMModel
from constant import Constants


class ModelManager:
    """
    Centralized manager for AI model loading, caching, and lifecycle management.
    
    This class implements a coroutine-safe cache for language models, providing
    efficient model reuse across the application. It integrates with
    ModelRouter to discover available models and handles the initialization
    and configuration of ChatOpenAI instances.
    
    Models are cached per instance, not at the class level. The application
    should therefore share a single ModelManager to get memory-efficient and
    consistent model access throughout its lifecycle.
    
    Attributes:
        _llm_models: Cache of loaded ChatOpenAI instances, keyed by model name
        _llm_lock: asyncio lock that serializes model loading across coroutines
        model_router: Interface to the model discovery service
    """

    def __init__(self):
        """
        Initialize the ModelManager with an empty model cache and an async lock.

        Sets up the internal data structures for model caching and
        coroutine-safe access, and initializes the ModelRouter for model
        discovery.
        """
        # Cache of ChatOpenAI instances, keyed by model name
        self._llm_models: Dict[str, ChatOpenAI] = {}
        
        # Serializes model loading across coroutines (an asyncio lock is not thread-safe)
        self._llm_lock = asyncio.Lock()
        
        # Model discovery and configuration
        self.model_router = ModelRouter()

    async def load_llm_model(self, model_name: str) -> ChatOpenAI:
        """
        Asynchronously loads and returns a language model for the specified model name.

        If the model is already cached in the instance dictionary `_llm_models`,
        it is returned immediately. Otherwise the method acquires `_llm_lock` and
        checks the cache again, because another coroutine may have loaded the
        model while this one waited. It then retrieves the model information from
        the Model Router and creates a `ChatOpenAI` instance that points at the
        model's `openai_endpoint` with temperature 0.1. Finally, it caches that
        instance for future calls.

        Args:
            model_name (str): The name of the language model to load.

        Returns:
            ChatOpenAI: An instance of the loaded language model.
        """
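        # Fast path: no await between the check and the return, so no lock needed.
        if model_name in self._llm_models:
            return self._llm_models[model_name]
        async with self._llm_lock:
            # Re-check under the lock: another coroutine may have loaded the
            # model while this one was waiting to acquire it.
            if model_name not in self._llm_models:
                loaded_model: LLMModel = self.model_router.get_llm_model(model_name)
                llm = ChatOpenAI(
                    model_name=model_name,
                    api_key=Constants.MODEL_ROUTER_TOKEN,
                    base_url=loaded_model.openai_endpoint,
                    temperature=0.1,
                )
                # Cache only after successful construction, so a failed load
                # is retried on the next call.
                self._llm_models[model_name] = llm
            return self._llm_models[model_name]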