Spaces:
Sleeping
Sleeping
| from langchain.chat_models import BaseChatModel | |
| from src.config import config | |
| from src.utils.logging import get_logger | |
| logger = get_logger("model_config") | |
| class ModelConfigurator: | |
| _main_model_instance: BaseChatModel = None | |
| _subagent_model_instance: BaseChatModel = None | |
| _fallback_models_instances: list[BaseChatModel] = None | |
| _summarization_model_instance: BaseChatModel = None | |
| _confidence_scoring_model_instance: BaseChatModel = None | |
| _language_detector_model_instance: BaseChatModel = None | |
| def get_language_detector_model(cls) -> BaseChatModel: | |
| if cls._confidence_scoring_model_instance: | |
| return cls._confidence_scoring_model_instance | |
| try: | |
| from langchain_openai import ChatOpenAI | |
| cls._language_detector_model_instance = ChatOpenAI( | |
| model='gpt-4o-mini', | |
| openai_api_key=config.llm.get_api_key(), | |
| max_tokens=3072, | |
| temperature=0.00, | |
| timeout=60, | |
| request_timeout=60, | |
| ) | |
| logger.info(f"Initialized language detection model") | |
| return cls._language_detector_model_instance | |
| except Exception as e: | |
| logger.error(f"Failed to initialize language detection model: {e}") | |
| raise e | |
| def get_confidence_scoring_model(cls) -> BaseChatModel: | |
| if cls._confidence_scoring_model_instance: | |
| return cls._confidence_scoring_model_instance | |
| try: | |
| from langchain_openai import ChatOpenAI | |
| cls._confidence_scoring_model_instance = ChatOpenAI( | |
| model='gpt-4o-mini', | |
| openai_api_key=config.llm.get_api_key(), | |
| max_tokens=3072, | |
| temperature=0.00, | |
| timeout=60, | |
| request_timeout=60, | |
| ) | |
| logger.info(f"Initialized confidence scoring model") | |
| return cls._confidence_scoring_model_instance | |
| except Exception as e: | |
| logger.error(f"Failed to initialize confidence scoring model: {e}") | |
| raise e | |
| def get_summarization_model(cls) -> BaseChatModel: | |
| if cls._summarization_model_instance: | |
| return cls._summarization_model_instance | |
| try: | |
| # Add custom summarization model initialization here if needed | |
| cls._summarization_model_instance = cls.get_main_agent_model() | |
| logger.info(f"Initialized summarization model '{config.llm.LLM_PROVIDER.name}:{config.llm.get_default_model()}'") | |
| return cls._summarization_model_instance | |
| except Exception as e: | |
| logger.error(f"Failed to initialize the summarization model: {e}") | |
| raise e | |
| def get_subagent_model(cls) -> BaseChatModel: | |
| if cls._subagent_model_instance: | |
| return cls._subagent_model_instance | |
| subagent_provider = config.llm.LLM_PROVIDER | |
| subagent_model = ( | |
| 'gpt-5-mini' | |
| if subagent_provider.base == 'openai' | |
| else config.llm.get_default_model(subagent_provider) | |
| ) | |
| cls._subagent_model_instance = cls._initialize_model( | |
| provider=subagent_provider, | |
| model=subagent_model, | |
| ) | |
| logger.info( | |
| f"Initialized subagent model '{subagent_provider.name}:{subagent_model}'" | |
| ) | |
| return cls._subagent_model_instance | |
| def get_main_agent_model(cls) -> BaseChatModel: | |
| """Initialize the language model based on config.""" | |
| if cls._main_model_instance: | |
| return cls._main_model_instance | |
| try: | |
| cls._main_model_instance = cls._initialize_model( | |
| provider=config.llm.LLM_PROVIDER, | |
| model=config.llm.get_default_model() | |
| ) | |
| logger.info(f"Initialized main agent model '{config.llm.LLM_PROVIDER.name}:{config.llm.get_default_model()}'") | |
| return cls._main_model_instance | |
| except Exception as e: | |
| logger.error(f"Failed to initialize the main agent model for provider '{config.llm.LLM_PROVIDER.name}': {e}") | |
| raise e | |
| def get_fallback_models(cls) -> list[BaseChatModel]: | |
| if cls._fallback_models_instances != None: | |
| return cls._fallback_models_instances | |
| cls._fallback_models_instances = cls._initialize_fallback_models() | |
| if len(cls._fallback_models_instances) == 0: | |
| logger.warning("No fallback models were initialized! Response generation may result in unexpected errors!") | |
| return cls._fallback_models_instances | |
| def _initialize_fallback_models(cls) -> list[BaseChatModel]: | |
| fallback_models_instances = [] | |
| for fallback_provider, fallback_model in config.llm.get_fallback_models(): | |
| try: | |
| fallback_model_instance = cls._initialize_model( | |
| provider=fallback_provider, | |
| model=fallback_model, | |
| ) | |
| logger.info(f"Initialized fallback model '{fallback_provider.name}:{fallback_model}'") | |
| fallback_models_instances.append(fallback_model_instance) | |
| except Exception as e: | |
| logger.error(f"Failed to initialize the fallback model {fallback_provider.name}:{fallback_model}: {e}; skipping...") | |
| return fallback_models_instances | |
| def _initialize_model(cls, provider, model: str) -> BaseChatModel: | |
| try: | |
| match provider.name: | |
| case 'groq': | |
| from langchain_groq import ChatGroq | |
| return ChatGroq( | |
| model=model, | |
| groq_api_key=config.llm.get_api_key(), | |
| temperature=0.01, | |
| ) | |
| case ( 'open_router:openai' | |
| | 'open_router:alibaba' | |
| | 'open_router:nvidia' | |
| | 'open_router:meituan'): | |
| from langchain_openai import ChatOpenAI | |
| return ChatOpenAI( | |
| model=model, | |
| base_url=config.llm.OPEN_ROUTER_BASE_URL, | |
| api_key=config.llm.get_api_key(), | |
| temperature=0.01, | |
| ) | |
| case 'open_router:deepseek': | |
| from langchain_deepseek import ChatDeepSeek | |
| return ChatDeepSeek( | |
| model=model, | |
| api_key=config.llm.OPEN_ROUTER_API_KEY, | |
| api_base=config.llm.OPEN_ROUTER_BASE_URL, | |
| ) | |
| case 'openai': | |
| from langchain_openai import ChatOpenAI | |
| return ChatOpenAI( | |
| model=model, | |
| openai_api_key=config.llm.get_api_key(), | |
| max_tokens=3072, | |
| temperature=0.01, | |
| timeout=60, | |
| request_timeout=60, | |
| ) | |
| case 'ollama': | |
| from langchain_ollama import ChatOllama | |
| return ChatOllama( | |
| model=model, | |
| base_url=config.llm.OLLAMA_BASE_URL, | |
| temperature=0.01, | |
| reasoning=config.llm.get_reasoning_support(), | |
| num_predict=2048, | |
| ) | |
| case _: | |
| raise ValueError(f"Unsupported LLM provider: {provider.name}") | |
| except Exception as e: | |
| raise e | |