File size: 2,222 Bytes
6c58cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from typing import Any, Optional

from dotenv import load_dotenv
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_groq import ChatGroq

from project.logger.logging import get_logger
from project.utils.config_loader import load_config

# Module-level logger used by ModelLoader for its status and error messages.
logger = get_logger(__name__)


class ModelLoader:
    """Build the chat LLM and the embedding model described by the project config.

    ``load_dotenv()`` runs before the configuration is read, so entries from a
    ``.env`` file are present in ``os.environ``. ``ChatGroq`` is constructed
    without an explicit key and presumably picks up ``GROQ_API_KEY`` from the
    environment (per langchain_groq docs — confirm if the client changes).
    """

    # Fallbacks used when the corresponding config key is absent.
    DEFAULT_LLM_PROVIDER = 'langchain_groq'
    DEFAULT_LLM_MODEL = 'openai/gpt-oss-20b'
    DEFAULT_LLM_TEMPERATURE = 0.1
    DEFAULT_LLM_MAX_TOKENS = 2048
    DEFAULT_EMBED_PROVIDER = 'fastembedding'
    DEFAULT_EMBED_MODEL = 'BAAI/bge-small-en-v1.5'

    def __init__(self, config_path: Optional[str] = None):
        """Load ``.env`` variables, then the configuration.

        Args:
            config_path: Passed straight to ``load_config``. ``None`` defers the
                choice of file to ``load_config`` (behaviour defined there).
        """
        load_dotenv()
        # Normalise a ``None`` config to an empty dict so the ``.get`` lookups
        # in the loaders below cannot fail with AttributeError.
        self.config = load_config(config_path) or {}
        self._load_api_keys()
        logger.info("ModelLoader initialized")

    def _load_api_keys(self):
        """Report whether the Groq API key is available in the environment.

        ``load_dotenv`` has already copied ``.env`` entries into ``os.environ``,
        so the key does not need to be re-exported here. A missing key is
        logged as a warning so the misconfiguration surfaces at start-up.
        """
        if os.getenv('GROQ_API_KEY'):
            logger.info("GROQ API key loaded")
        else:
            logger.warning("GROQ_API_KEY is not set; Groq LLM requests are likely to fail")

    def _section(self, key: str) -> dict:
        """Return config section ``key``, treating a missing or null section as ``{}``."""
        return (self.config or {}).get(key) or {}

    def load_llm(self) -> Any:
        """Instantiate the chat model described by the ``llm`` config section.

        Recognised keys are ``provider`` (only ``'langchain_groq'`` is
        supported), ``model``, ``temperature`` and ``max_tokens``. Missing keys
        fall back to the ``DEFAULT_LLM_*`` class constants.

        Returns:
            A ``ChatGroq`` instance.

        Raises:
            ValueError: If ``provider`` is not supported.
            Exception: Any error from constructing the client. Every failure is
                logged before being re-raised unchanged.
        """
        llm_config = self._section('llm')
        provider = llm_config.get('provider', self.DEFAULT_LLM_PROVIDER)
        try:
            if provider != 'langchain_groq':
                raise ValueError(f"Unsupported LLM provider: {provider}")
            model_name = llm_config.get('model', self.DEFAULT_LLM_MODEL)
            model = ChatGroq(
                model=model_name,
                temperature=llm_config.get('temperature', self.DEFAULT_LLM_TEMPERATURE),
                max_tokens=llm_config.get('max_tokens', self.DEFAULT_LLM_MAX_TOKENS),
            )
        except Exception as e:
            logger.error(f"Failed to load LLM: {e}")
            raise
        # Log the model actually used (the default when config omits it)
        # rather than the raw config value, which would print ``None``.
        logger.info(f"Loaded Groq LLM: {model_name}")
        return model

    def load_embeddings(self) -> Any:
        """Instantiate the embedding model described by ``embedding_model`` config.

        Recognised keys are ``provider`` (only ``'fastembedding'`` is supported)
        and ``model_name``. Missing keys fall back to the ``DEFAULT_EMBED_*``
        class constants.

        Returns:
            A ``FastEmbedEmbeddings`` instance.

        Raises:
            ValueError: If ``provider`` is not supported.
            Exception: Any error from constructing the embedder. Every failure
                is logged before being re-raised unchanged.
        """
        embed_config = self._section('embedding_model')
        provider = embed_config.get('provider', self.DEFAULT_EMBED_PROVIDER)
        try:
            if provider != 'fastembedding':
                raise ValueError(f"Unsupported embedding provider: {provider}")
            model_name = embed_config.get('model_name', self.DEFAULT_EMBED_MODEL)
            embeddings = FastEmbedEmbeddings(model_name=model_name)
        except Exception as e:
            logger.error(f"Failed to load embeddings: {e}")
            raise
        # Same as load_llm: report the effective model name, not the raw config value.
        logger.info(f"Loaded FastEmbed: {model_name}")
        return embeddings