Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| from dotenv import load_dotenv | |
| from langchain_openai import AzureChatOpenAI, ChatOpenAI | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
# Module-level logger. Only the logger object is created here; no handlers
# or levels are configured in this module.
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# API version read from the environment by langchain_openai's Azure client.
# setdefault (rather than assignment) keeps any value the user already
# provided via the real environment or the .env file loaded above.
os.environ.setdefault("OPENAI_API_VERSION", "2024-12-01-preview")
def get_llm(model_name, provider="openai", api_key=None, **kwargs):
    """
    Get a language model instance based on the specified provider and model name.

    Args:
        model_name (str): Name of the model to use. For ``provider="azure"``
            this is the Azure deployment name. If falsy, the "openai" and
            "gemini" providers fall back to "gpt-4o-mini" and
            "gemini-2.0-flash" respectively.
        provider (str): The provider of the model ("openai", "gemini", "azure").
        api_key (str, optional): API key for the provider. If None, the
            provider's environment variable is used instead
            (OPENAI_API_KEY, GOOGLE_API_KEY or AZURE_OPENAI_API_KEY).
        **kwargs: Extra keyword arguments forwarded unchanged to the chat
            model constructor.

    Returns:
        An instance of the language model, or None if ``provider`` is not
        one of the supported values.
    """
    if provider == "openai":
        model = model_name or "gpt-4o-mini"
        logger.info("Using OpenAI provider with model: %s", model)
        # Use provided API key or fall back to environment variable
        return ChatOpenAI(
            model=model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            **kwargs,
        )

    if provider == "gemini":
        model = model_name or "gemini-2.0-flash"
        logger.info("Using Gemini provider with model: %s", model)
        # Use provided API key or fall back to environment variable
        return ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key or os.getenv("GOOGLE_API_KEY"),
            **kwargs,
        )

    if provider == "azure":
        logger.info("Using Azure OpenAI provider with deployment: %s", model_name)
        # The endpoint and API version are not passed here; AzureChatOpenAI
        # resolves them from AZURE_OPENAI_ENDPOINT / OPENAI_API_VERSION in the
        # environment unless the caller supplies them via kwargs.
        return AzureChatOpenAI(
            azure_deployment=model_name,
            api_key=api_key or os.getenv("AZURE_OPENAI_API_KEY"),
            **kwargs,
        )

    # Unsupported provider: keep the historical None return, but make it visible.
    logger.warning("Unsupported LLM provider: %r", provider)
    return None