Spaces:
Sleeping
Sleeping
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings | |
| from src.utils.logger import logger | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| # Default model instances | |
| llm_2_0 = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=1) | |
| llm_2_5_flash_preview = ChatGoogleGenerativeAI( | |
| model="gemini-2.5-flash-preview-05-20", temperature=1, thinking_budget=None | |
| ) | |
| llm_2_0_flash_lite = ChatGoogleGenerativeAI( | |
| model="gemini-2.0-flash-lite", temperature=1 | |
| ) | |
| # Default embeddings model | |
| embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004") | |
| def get_llm( | |
| model_name: str = "gemini-2.0-flash", | |
| api_key: str = None, | |
| include_thoughts: bool = False, | |
| reasoning: bool = False, | |
| ) -> BaseChatModel: | |
| """ | |
| Get LLM instance based on model name and optional API key. | |
| Args: | |
| model_name: Name of the model to use | |
| api_key: Optional API key for authentication | |
| Returns: | |
| Configured ChatGoogleGenerativeAI instance | |
| Raises: | |
| ValueError: If model name is not supported | |
| """ | |
| if api_key: | |
| logger.warning("Using custom API key") | |
| if model_name == "gemini-2.0-flash": | |
| return ChatGoogleGenerativeAI( | |
| model=model_name, | |
| temperature=1, | |
| google_api_key=api_key, | |
| ) | |
| elif model_name == "gemini-2.5-flash-preview-05-20": | |
| return ChatGoogleGenerativeAI( | |
| model=model_name, | |
| temperature=1, | |
| google_api_key=api_key, | |
| include_thoughts=include_thoughts, | |
| thinking_budget=None if reasoning else 0, | |
| ) | |
| elif model_name == "gemini-2.0-flash-lite": | |
| return ChatGoogleGenerativeAI( | |
| model=model_name, | |
| temperature=1, | |
| google_api_key=api_key, | |
| ) | |
| if model_name == "gemini-2.0-flash": | |
| return llm_2_0 | |
| elif model_name == "gemini-2.5-flash-preview-05-20": | |
| return llm_2_5_flash_preview | |
| elif model_name == "gemini-2.0-flash-lite": | |
| return llm_2_0_flash_lite | |
| raise ValueError(f"Unknown model: {model_name}") | |