File size: 7,573 Bytes
0a372e8
 
 
 
 
 
 
 
 
95cdb75
0a372e8
 
593a090
 
0a372e8
593a090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a372e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95cdb75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a372e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95cdb75
0a372e8
95cdb75
 
0a372e8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from langchain.chat_models import BaseChatModel
from config import LLMProvider, LLMProviderConfiguration as llmconf

from src.utils.logging import get_logger

logger = get_logger("model_config")

class ModelConfigurator:
    _main_model_instance: BaseChatModel = None
    _subagent_model_instance: BaseChatModel = None
    _fallback_models_instances: list[BaseChatModel] = None
    _summarization_model_instance: BaseChatModel = None
    _confidence_scoring_model_instance: BaseChatModel = None 
    _language_detector_model_instance: BaseChatModel = None
    
    @classmethod 
    def get_language_detector_model(cls) -> BaseChatModel:
        if cls._confidence_scoring_model_instance:
            return cls._confidence_scoring_model_instance
        try:
            from langchain_openai import ChatOpenAI
            cls._language_detector_model_instance = ChatOpenAI(
                model='gpt-4o-mini',
                openai_api_key=llmconf.get_api_key(),
                max_tokens=3072,
                temperature=0.00,
                timeout=60,
                request_timeout=60,
            )
            logger.info(f"Initialized language detection model")
            return cls._language_detector_model_instance
        except Exception as e:
            logger.error(f"Failed to initialize language detection model: {e}")
            raise e

    @classmethod
    def get_confidence_scoring_model(cls) -> BaseChatModel:
        if cls._confidence_scoring_model_instance:
            return cls._confidence_scoring_model_instance
        
        try:
            from langchain_openai import ChatOpenAI
            cls._confidence_scoring_model_instance = ChatOpenAI(
                model='gpt-4o-mini',
                openai_api_key=llmconf.get_api_key(),
                max_tokens=3072,
                temperature=0.00,
                timeout=60,
                request_timeout=60,
            )
            logger.info(f"Initialized confidence scoring model")
            return cls._confidence_scoring_model_instance
        except Exception as e:
            logger.error(f"Failed to initialize confidence scoring model: {e}")
            raise e


    @classmethod
    def get_summarization_model(cls) -> BaseChatModel:
        if cls._summarization_model_instance:
            return cls._summarization_model_instance
        
        try:
            # Add custom summarization model initialization here if needed
            cls._summarization_model_instance = cls.get_main_agent_model()
            logger.info(f"Initialized summarization model '{llmconf.LLM_PROVIDER.name}:{llmconf.get_default_model()}'")
            return cls._summarization_model_instance
        except Exception as e:
            logger.error(f"Failed to initialize the summarization model: {e}")
            raise e

    @classmethod
    def get_subagent_model(cls) -> BaseChatModel:
        if cls._subagent_model_instance:
            return cls._subagent_model_instance
        
        from langchain_openai import ChatOpenAI
        cls._subagent_model_instance = ChatOpenAI(
            model='gpt-5.1',
            openai_api_key=llmconf.get_api_key(),
            max_tokens=3072,
            temperature=0.01,
            timeout=60,
            request_timeout=60,
        )
        return cls._subagent_model_instance


    @classmethod
    def get_main_agent_model(cls) -> BaseChatModel:
        """Initialize the language model based on config."""
        if cls._main_model_instance:
            return cls._main_model_instance

        try:
            cls._main_model_instance = cls._initialize_model(
                provider=llmconf.LLM_PROVIDER,
                model=llmconf.get_default_model()
            )
            logger.info(f"Initialized main agent model '{llmconf.LLM_PROVIDER.name}:{llmconf.get_default_model()}'")
            return cls._main_model_instance
        except Exception as e: 
            logger.error(f"Failed to initialize the main agent model for provider '{llmconf.LLM_PROVIDER.name}': {e}")
            raise e


    @classmethod
    def get_fallback_models(cls) -> list[BaseChatModel]:
        if cls._fallback_models_instances != None:
            return cls._fallback_models_instances 

        cls._fallback_models_instances = cls._initialize_fallback_models()
        if len(cls._fallback_models_instances) == 0:
            logger.warning("No fallback models were initialized! Response generation may result in unexpected errors!")
        return cls._fallback_models_instances


    @classmethod
    def _initialize_fallback_models(cls) -> list[BaseChatModel]:
        fallback_models_instances = []
        for fallback_provider, fallback_model in llmconf.get_fallback_models().items():
            try:
                fallback_model_instance = cls._initialize_model(
                    provider=fallback_provider,
                    model=fallback_model,
                )
                logger.info(f"Initialized fallback model '{fallback_provider.name}:{fallback_model}'")
                fallback_models_instances.append(fallback_model_instance)
            except Exception as e:
                logger.error(f"Failed to initialize the fallback model {fallback_provider.name}:{fallback_model}: {e}; skipping...")
        return fallback_models_instances


    @classmethod
    def _initialize_model(cls, provider: LLMProvider, model: str) -> BaseChatModel:
        try:
            match provider.name:
                case 'groq':
                    from langchain_groq import ChatGroq
                    return ChatGroq(
                        model=model,
                        groq_api_key=llmconf.get_api_key(),
                        temperature=0.01,
                    )
                case (  'open_router:openai' 
                      | 'open_router:alibaba' 
                      | 'open_router:nvidia'
                      | 'open_router:meituan'):
                    from langchain_openai import ChatOpenAI
                    return ChatOpenAI(
                        model=model,
                        base_url=llmconf.OPEN_ROUTER_BASE_URL,
                        api_key=llmconf.get_api_key(),
                        temperature=0.01,
                    )
                case 'open_router:deepseek':
                    from langchain_deepseek import ChatDeepSeek
                    return ChatDeepSeek(
                        model=model,
                        api_key=llmconf.OPEN_ROUTER_API_KEY,
                        api_base=llmconf.OPEN_ROUTER_BASE_URL,
                    )
                case 'openai':
                    from langchain_openai import ChatOpenAI
                    return ChatOpenAI(
                        model=model,
                        openai_api_key=llmconf.get_api_key(),
                        max_tokens=3072,
                        temperature=0.01,
                        timeout=60,
                        request_timeout=60,
                    )
                case 'ollama':
                    from langchain_ollama import ChatOllama
                    return ChatOllama(
                        model=model,
                        base_url=llmconf.OLLAMA_BASE_URL,
                        temperature=0.01,
                        reasoning=llmconf.get_reasoning_support(),
                        num_predict=2048,
                    )
                case _:
                    raise ValueError(f"Unsupported LLM provider: {provider.name}")
        except Exception as e:
            raise e