ABAO77's picture
Upload 60 files
172064c verified
from typing import Optional

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

from src.utils.logger import logger
# Default model instances, shared by every caller of get_llm() that does not
# supply its own API key. All of them use temperature=1.
llm_2_0 = ChatGoogleGenerativeAI(
    temperature=1,
    model="gemini-2.0-flash",
)
llm_2_5_flash_preview = ChatGoogleGenerativeAI(
    temperature=1,
    model="gemini-2.5-flash-preview-05-20",
    # Explicitly None here; note that get_llm() with an api_key passes
    # thinking_budget=0 unless reasoning=True.
    thinking_budget=None,
)
llm_2_0_flash_lite = ChatGoogleGenerativeAI(
    temperature=1,
    model="gemini-2.0-flash-lite",
)

# Default embeddings model
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
def get_llm(
    model_name: str = "gemini-2.0-flash",
    api_key: Optional[str] = None,
    include_thoughts: bool = False,
    reasoning: bool = False,
) -> BaseChatModel:
    """
    Get LLM instance based on model name and optional API key.

    Args:
        model_name: Name of the model to use. One of "gemini-2.0-flash",
            "gemini-2.5-flash-preview-05-20" or "gemini-2.0-flash-lite".
        api_key: Optional API key for authentication. When truthy, a new
            ChatGoogleGenerativeAI is built with ``google_api_key=api_key``;
            otherwise the shared module-level instance for ``model_name``
            is returned.
        include_thoughts: Passed through as ``include_thoughts`` for
            "gemini-2.5-flash-preview-05-20". Only used when ``api_key``
            is given.
        reasoning: For "gemini-2.5-flash-preview-05-20", selects
            ``thinking_budget=None`` when True and ``thinking_budget=0``
            when False. Only used when ``api_key`` is given.

    Returns:
        Configured ChatGoogleGenerativeAI instance (temperature=1).

    Raises:
        ValueError: If model name is not supported.
    """
    # Shared instances returned when no custom key is supplied. Their keys
    # are also the set of supported model names.
    shared_models = {
        "gemini-2.0-flash": llm_2_0,
        "gemini-2.5-flash-preview-05-20": llm_2_5_flash_preview,
        "gemini-2.0-flash-lite": llm_2_0_flash_lite,
    }

    if api_key:
        logger.warning("Using custom API key")
        if model_name in shared_models:
            extra_kwargs = {}
            if model_name == "gemini-2.5-flash-preview-05-20":
                # This function only forwards thinking options for the
                # 2.5 preview model.
                extra_kwargs = {
                    "include_thoughts": include_thoughts,
                    "thinking_budget": None if reasoning else 0,
                }
            return ChatGoogleGenerativeAI(
                model=model_name,
                temperature=1,
                google_api_key=api_key,
                **extra_kwargs,
            )

    if model_name not in shared_models:
        raise ValueError(f"Unknown model: {model_name}")

    # NOTE(review): without an api_key, include_thoughts and reasoning are
    # ignored and the shared 2.5 instance (built with thinking_budget=None)
    # is returned, whereas the keyed path defaults to thinking_budget=0.
    # Confirm this asymmetry is intended before relying on either flag.
    return shared_models[model_name]