File size: 1,856 Bytes
e555dc3
 
24f6114
e555dc3
 
24f6114
c2a036b
e555dc3
 
 
c2a036b
e555dc3
 
c2a036b
 
e555dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1d311a
 
a53c022
 
 
e1d311a
 
e555dc3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Model factory for creating LLM instances."""

import os
from langchain_openai import ChatOpenAI, OpenAI
from langchain_core.language_models.base import BaseLanguageModel
from langchain_google_genai import ChatGoogleGenerativeAI

from .utils.error_handler import ErrorMessages, check_model_api_key


class Models:
    """Factory class for creating language model instances."""

    @staticmethod
    def get_model(model_name: str, **kwargs) -> BaseLanguageModel:
        """
        Get a language model instance based on the model name.

        Args:
            model_name: The name of the model to instantiate
            **kwargs: Additional arguments to pass to the model constructor

        Returns:
            A configured language model instance

        Raises:
            ValueError: If the model is not supported or API key is missing
        """
        # Check for required API keys before initializing
        api_key_error = check_model_api_key(model_name)
        if api_key_error:
            raise ValueError(api_key_error)

        match model_name:
            case "gpt-4.1-mini" | "gpt-4o-mini" | "gpt-4" | "gpt-3.5-turbo":
                return ChatOpenAI(model_name=model_name, **kwargs)
            case name if name.startswith("text-"):
                return OpenAI(model_name=model_name, **kwargs)
            case name if name.startswith("gemini-"):
                return ChatGoogleGenerativeAI(model=model_name, **kwargs)
            case "alias-large" | "alias-fast":
                return ChatOpenAI(
                    model=model_name,
                    api_key=os.getenv("BLABLADOR_API_KEY"),
                    base_url="https://api.helmholtz-blablador.fz-juelich.de/v1",
                    **kwargs
                )
            case _:
                raise ValueError(f"Unsupported model: {model_name}")