Spaces:
Sleeping
Sleeping
File size: 2,264 Bytes
172064c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from typing import Optional

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

from src.utils.logger import logger
# Shared chat-model instances, constructed once when this module is imported.
# get_llm() hands these out whenever it is called without an API key.
llm_2_0 = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=1,
)
# The 2.5 preview instance passes thinking_budget explicitly as None.
llm_2_5_flash_preview = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-preview-05-20",
    temperature=1,
    thinking_budget=None,
)
llm_2_0_flash_lite = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-lite",
    temperature=1,
)

# Shared embeddings instance.
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
def get_llm(
    model_name: str = "gemini-2.0-flash",
    api_key: Optional[str] = None,
    include_thoughts: bool = False,
    reasoning: bool = False,
) -> BaseChatModel:
    """
    Get an LLM instance for a supported Gemini model.

    Without an API key the shared module-level instance for the model is
    returned. With a (non-empty) API key a fresh ChatGoogleGenerativeAI is
    built with that key.

    Args:
        model_name: Name of the model to use. Supported values are
            "gemini-2.0-flash", "gemini-2.5-flash-preview-05-20" and
            "gemini-2.0-flash-lite".
        api_key: Optional API key. An empty string is treated like None.
        include_thoughts: Forwarded as ``include_thoughts`` to the
            "gemini-2.5-flash-preview-05-20" model. Only used when
            ``api_key`` is given.
        reasoning: For "gemini-2.5-flash-preview-05-20" with an ``api_key``:
            True passes ``thinking_budget=None``, False passes
            ``thinking_budget=0``.

    Returns:
        A chat model instance for ``model_name``.

    Raises:
        ValueError: If model name is not supported.
    """
    flash = "gemini-2.0-flash"
    flash_thinking = "gemini-2.5-flash-preview-05-20"
    flash_lite = "gemini-2.0-flash-lite"

    # The warning is emitted before validation, so it is logged even when the
    # model name turns out to be unknown.
    if api_key:
        logger.warning("Using custom API key")

    if model_name not in (flash, flash_thinking, flash_lite):
        raise ValueError(f"Unknown model: {model_name}")

    if not api_key:
        # NOTE(review): include_thoughts and reasoning are ignored on this
        # path - the shared 2.5 instance is returned as configured at import
        # time. Confirm whether callers without a key expect these flags to
        # take effect.
        shared_instances = {
            flash: llm_2_0,
            flash_thinking: llm_2_5_flash_preview,
            flash_lite: llm_2_0_flash_lite,
        }
        return shared_instances[model_name]

    # Only the 2.5 preview model receives the thinking-related options.
    thinking_options = {}
    if model_name == flash_thinking:
        thinking_options = {
            "include_thoughts": include_thoughts,
            "thinking_budget": None if reasoning else 0,
        }

    return ChatGoogleGenerativeAI(
        model=model_name,
        temperature=1,
        google_api_key=api_key,
        **thinking_options,
    )
|