"""Model factory for creating LLM instances."""
import os
from langchain_openai import ChatOpenAI, OpenAI
from langchain_core.language_models.base import BaseLanguageModel
from langchain_google_genai import ChatGoogleGenerativeAI
from .utils.error_handler import ErrorMessages, check_model_api_key
class Models:
"""Factory class for creating language model instances."""
@staticmethod
def get_model(model_name: str, **kwargs) -> BaseLanguageModel:
"""
Get a language model instance based on the model name.
Args:
model_name: The name of the model to instantiate
**kwargs: Additional arguments to pass to the model constructor
Returns:
A configured language model instance
Raises:
ValueError: If the model is not supported or API key is missing
"""
# Check for required API keys before initializing
api_key_error = check_model_api_key(model_name)
if api_key_error:
raise ValueError(api_key_error)
match model_name:
case "gpt-4.1-mini" | "gpt-4o-mini" | "gpt-4" | "gpt-3.5-turbo":
return ChatOpenAI(model_name=model_name, **kwargs)
case name if name.startswith("text-"):
return OpenAI(model_name=model_name, **kwargs)
case name if name.startswith("gemini-"):
return ChatGoogleGenerativeAI(model=model_name, **kwargs)
case "alias-large" | "alias-fast":
return ChatOpenAI(
model=model_name,
api_key=os.getenv("BLABLADOR_API_KEY"),
base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
**kwargs
)
case _:
raise ValueError(f"Unsupported model: {model_name}") |