File size: 2,264 Bytes
172064c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Optional

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

from src.utils.logger import logger

# Default model instances, created once at import time and shared by all
# callers of get_llm() that do not pass a custom API key.
# NOTE(review): these constructors read credentials from the environment
# (presumably GOOGLE_API_KEY) — confirm deployment sets it before import.
llm_2_0 = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=1)
# thinking_budget=None leaves the model's thinking budget at the provider
# default (unlimited/auto), unlike the explicit 0 used in get_llm() when
# reasoning is disabled.
llm_2_5_flash_preview = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-preview-05-20", temperature=1, thinking_budget=None
)
llm_2_0_flash_lite = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-lite", temperature=1
)
# Default embeddings model
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")


def get_llm(
    model_name: str = "gemini-2.0-flash",
    api_key: Optional[str] = None,
    include_thoughts: bool = False,
    reasoning: bool = False,
) -> BaseChatModel:
    """
    Get LLM instance based on model name and optional API key.

    Without an API key, a shared module-level default instance is returned;
    with one, a fresh instance bound to that key is constructed.

    Args:
        model_name: Name of the model to use. Supported values:
            "gemini-2.0-flash", "gemini-2.5-flash-preview-05-20",
            "gemini-2.0-flash-lite".
        api_key: Optional API key for authentication. When omitted, the
            shared defaults (which read credentials from the environment)
            are used.
        include_thoughts: Whether to include model "thought" parts in the
            response. Only honored for the 2.5 preview model AND only when
            api_key is provided; otherwise ignored (the shared default
            instance is returned as-is).
        reasoning: When True, leave the thinking budget unlimited
            (thinking_budget=None); when False, disable thinking
            (thinking_budget=0). Same caveat as include_thoughts: only
            applied for the 2.5 preview model with a custom api_key.

    Returns:
        Configured ChatGoogleGenerativeAI instance.

    Raises:
        ValueError: If model name is not supported.
    """
    if api_key:
        # Log the fact only — never the key material itself.
        logger.warning("Using custom API key")
        # The two 2.0 variants take identical kwargs; one branch serves both.
        if model_name in ("gemini-2.0-flash", "gemini-2.0-flash-lite"):
            return ChatGoogleGenerativeAI(
                model=model_name,
                temperature=1,
                google_api_key=api_key,
            )
        if model_name == "gemini-2.5-flash-preview-05-20":
            return ChatGoogleGenerativeAI(
                model=model_name,
                temperature=1,
                google_api_key=api_key,
                include_thoughts=include_thoughts,
                thinking_budget=None if reasoning else 0,
            )
        # Unknown model with a custom key: fail the same way as below
        # instead of silently falling through to the shared defaults.
        raise ValueError(f"Unknown model: {model_name}")

    # No custom key: hand back the shared, pre-built default instance.
    defaults = {
        "gemini-2.0-flash": llm_2_0,
        "gemini-2.5-flash-preview-05-20": llm_2_5_flash_preview,
        "gemini-2.0-flash-lite": llm_2_0_flash_lite,
    }
    try:
        return defaults[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None