Spaces:
Running
Running
| import os | |
| from typing import Any | |
| from dotenv import load_dotenv | |
| from langchain_groq import ChatGroq | |
| from langchain_community.embeddings import FastEmbedEmbeddings | |
| from project.utils.config_loader import load_config | |
| from project.logger.logging import get_logger | |
| logger = get_logger(__name__) | |
class ModelLoader:
    """Create the chat model and embedding model described by the config.

    The configuration object is whatever ``load_config(config_path)`` returns;
    it is only ever read through ``.get``. Two sections are consulted:
    ``llm`` (by :meth:`load_llm`) and ``embedding_model``
    (by :meth:`load_embeddings`).
    """

    def __init__(self, config_path: str = None):
        """Load environment variables and configuration.

        Args:
            config_path: Forwarded unchanged to ``load_config``. Defaults to
                ``None``.
        """
        # Populate os.environ from a .env file before any key is looked up.
        load_dotenv()
        self.config = load_config(config_path)
        self._load_api_keys()
        logger.info("ModelLoader initialized")

    def _load_api_keys(self):
        """Look up ``GROQ_API_KEY`` and log when it is available.

        A missing key is not an error here; nothing is logged or raised in
        that case.
        """
        groq_key = os.getenv('GROQ_API_KEY')
        if not groq_key:
            return
        # Writes back the value just read. load_llm builds ChatGroq without
        # passing a key, so the client presumably reads it from the
        # environment -- TODO confirm against the langchain_groq docs.
        os.environ['GROQ_API_KEY'] = groq_key
        logger.info("GROQ API key loaded")

    def load_llm(self) -> Any:
        """Instantiate the chat model selected by the ``llm`` config section.

        Recognised keys (with fallbacks): ``provider`` (``'langchain_groq'``),
        ``model`` (``'openai/gpt-oss-20b'``), ``temperature`` (``0.1``) and
        ``max_tokens`` (``2048``).

        Returns:
            A ``ChatGroq`` instance.

        Raises:
            ValueError: If ``provider`` is anything other than
                ``'langchain_groq'``.
            Exception: Any error raised while constructing the model is
                logged and re-raised unchanged.
        """
        # `or {}` treats a section that is present but null/empty like a
        # missing one instead of failing on `None.get(...)`.
        llm_config = self.config.get('llm') or {}
        provider = llm_config.get('provider', 'langchain_groq')
        try:
            if provider != 'langchain_groq':
                raise ValueError(f"Unsupported LLM provider: {provider}")

            # Resolve the name once so the log line reports the model that was
            # actually requested, including when the default is used.
            model_name = llm_config.get('model', 'openai/gpt-oss-20b')
            model = ChatGroq(
                model=model_name,
                temperature=llm_config.get('temperature', 0.1),
                max_tokens=llm_config.get('max_tokens', 2048),
            )
            logger.info(f"Loaded Groq LLM: {model_name}")
            return model
        except Exception as e:
            logger.error(f"Failed to load LLM: {e}")
            raise

    def load_embeddings(self) -> Any:
        """Instantiate the embedding model from the ``embedding_model`` section.

        Recognised keys (with fallbacks): ``provider`` (``'fastembedding'``)
        and ``model_name`` (``'BAAI/bge-small-en-v1.5'``).

        Returns:
            A ``FastEmbedEmbeddings`` instance.

        Raises:
            ValueError: If ``provider`` is anything other than
                ``'fastembedding'``.
            Exception: Any error raised while constructing the embeddings is
                logged and re-raised unchanged.
        """
        # `or {}` treats a section that is present but null/empty like a
        # missing one instead of failing on `None.get(...)`.
        embed_config = self.config.get('embedding_model') or {}
        provider = embed_config.get('provider', 'fastembedding')
        try:
            if provider != 'fastembedding':
                raise ValueError(f"Unsupported embedding provider: {provider}")

            # Resolve the name once so the log line matches the model that was
            # actually requested, including when the default is used.
            model_name = embed_config.get('model_name', 'BAAI/bge-small-en-v1.5')
            embeddings = FastEmbedEmbeddings(model_name=model_name)
            logger.info(f"Loaded FastEmbed: {model_name}")
            return embeddings
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            raise