File size: 7,751 Bytes
2cb39a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import tiktoken
from typing import Union
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from agents import OpenAIChatCompletionsModel
from openai import AsyncOpenAI, AsyncAzureOpenAI
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import OllamaEmbeddings
from huggingface_hub import login

class OpenAIModelFactory:
    """Factory for creating OpenAI-SDK compatible model instances (using the 'agents' library).

    Supports multiple providers via the OpenAI-compatible API format:
    OpenAI, Azure OpenAI (AAD auth), Google Gemini, Groq, and Ollama.
    """

    @staticmethod
    def get_model(provider: str = "openai",  # openai, azure, google, groq, ollama
                  model_name: str = "gpt-4o",
                  temperature: float = 0
                  ) -> OpenAIChatCompletionsModel:
        """Return an OpenAIChatCompletionsModel for the requested provider.

        Args:
            provider: One of "openai", "azure", "google"/"gemini", "groq",
                "ollama" (case-insensitive).
            model_name: Model (or deployment) name passed to the client.
            temperature: NOTE(review): accepted for interface compatibility but
                currently unused — OpenAIChatCompletionsModel is constructed
                without it in every branch; set temperature per-request instead.

        Raises:
            ValueError: If the provider is not supported.
            KeyError: If a required environment variable is missing.
        """
        # Normalize once instead of calling .lower() in every branch.
        key = provider.lower()

        # ----------------------------------------------------------------------
        # AZURE OPENAI — uses AAD bearer-token auth, not an API key.
        # ----------------------------------------------------------------------
        if key == "azure":
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
            client = AsyncAzureOpenAI(
                azure_endpoint=os.environ["AZURE_OPENAI_API_URI"],
                api_version=os.environ["AZURE_OPENAI_API_VERSION"],
                azure_ad_token_provider=token_provider,
            )
            return OpenAIChatCompletionsModel(model=model_name, openai_client=client)

        # ----------------------------------------------------------------------
        # STANDARD OPENAI
        # ----------------------------------------------------------------------
        elif key == "openai":
            client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
            return OpenAIChatCompletionsModel(model=model_name, openai_client=client)

        # ----------------------------------------------------------------------
        # GOOGLE (GEMINI) via OpenAI Compat
        # ----------------------------------------------------------------------
        elif key in ("google", "gemini"):
            GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
            client = AsyncOpenAI(
                base_url=GEMINI_BASE_URL,
                api_key=os.environ["GOOGLE_API_KEY"]
            )
            return OpenAIChatCompletionsModel(model=model_name, openai_client=client)

        # ----------------------------------------------------------------------
        # GROQ via OpenAI Compat
        # ----------------------------------------------------------------------
        elif key == "groq":
            GROQ_BASE_URL = "https://api.groq.com/openai/v1"
            client = AsyncOpenAI(
                base_url=GROQ_BASE_URL,
                api_key=os.environ["GROQ_API_KEY"]
            )
            return OpenAIChatCompletionsModel(model=model_name, openai_client=client)

        # ----------------------------------------------------------------------
        # OLLAMA via OpenAI Compat — local server; key is a required placeholder.
        # ----------------------------------------------------------------------
        elif key == "ollama":
            client = AsyncOpenAI(
                base_url="http://localhost:11434/v1",
                api_key="ollama"
            )
            return OpenAIChatCompletionsModel(model=model_name, openai_client=client)

        # ----------------------------------------------------------------------
        # UNSUPPORTED
        # ----------------------------------------------------------------------
        else:
            raise ValueError(f"Unsupported provider for OpenAIModelFactory: {provider}")

    @staticmethod
    def num_tokens_from_messages(messages, model: str = "gpt-4o") -> int:
        """Return the number of tokens used by a list of chat messages.

        Args:
            messages: Iterable of dicts in OpenAI chat format; "content" may be
                a string or a list of parts (text / image_url).
            model: Model name used to pick the tiktoken encoding; falls back to
                cl100k_base for unknown models.

        Returns:
            Estimated token count, including per-message and reply-priming
            overhead.
        """
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name — cl100k_base is a reasonable default encoding.
            encoding = tiktoken.get_encoding("cl100k_base")

        tokens_per_message = 3  # per-message framing overhead
        num_tokens = 0

        for message in messages:
            num_tokens += tokens_per_message
            for msg_key, value in message.items():
                if msg_key == "name":
                    num_tokens += 1  # "name" field costs one extra token

                # Encode values if they are strings.
                if isinstance(value, str):
                    num_tokens += len(encoding.encode(value))
                elif isinstance(value, list) and msg_key == "content":
                    # Multimodal content: count text parts; flat-rate images.
                    for part in value:
                        if isinstance(part, dict) and part.get("type") == "text":
                            num_tokens += len(encoding.encode(part.get("text", "")))
                        elif isinstance(part, dict) and part.get("type") == "image_url":
                            num_tokens += 85  # low-detail image flat cost

        num_tokens += 3  # every reply is primed with assistant-role tokens
        return num_tokens


class EmbeddingFactory:
    """Static utility class that builds embedding-model instances per provider."""

    @staticmethod
    def get_embedding_model(provider: str = "openai",

                            model_name: str = "text-embedding-3-small"

                            ) -> Union[AzureOpenAIEmbeddings, OpenAIEmbeddings, OllamaEmbeddings, HuggingFaceEmbeddings]:
        """Return a LangChain embeddings instance for the given provider.

        Supported providers (case-insensitive): azure, openai, ollama,
        huggingface. Raises ValueError for anything else.
        """
        selected = provider.lower()

        # Azure path authenticates with AAD bearer tokens rather than an API key.
        if selected == "azure":
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
            return AzureOpenAIEmbeddings(
                azure_endpoint=os.environ["AZURE_OPENAI_API_URI"],
                azure_deployment=os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", model_name),
                api_version=os.environ["AZURE_OPENAI_API_VERSION"],
                azure_ad_token_provider=token_provider,
            )

        if selected == "openai":
            return OpenAIEmbeddings(
                api_key=os.environ["OPENAI_API_KEY"],
                model=model_name
            )

        if selected == "ollama":
            return OllamaEmbeddings(model=model_name)

        if selected == "huggingface":
            # Log in first so gated models are reachable, when a token is set.
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                login(token=os.environ.get("HF_TOKEN"))
            return HuggingFaceEmbeddings(model_name=model_name)

        raise ValueError(f"Unsupported embedding provider: {provider}")


# =================================================================================================
# GLOBAL HELPER FUNCTIONS
# =================================================================================================

def get_model(provider:str = "openai", model_name:str = "gpt-4o"):
    """Module-level convenience wrapper around OpenAIModelFactory.get_model.

    Returns an OpenAI-SDK compatible model; defaults to the OpenAI provider
    with gpt-4o and temperature 0.
    """
    factory_kwargs = {
        "provider": provider,
        "model_name": model_name,
        "temperature": 0,
    }
    return OpenAIModelFactory.get_model(**factory_kwargs)

def get_model_json(model_name: str = "gpt-4o-2024-08-06", provider: str = "openai"):
    """Module-level helper for a JSON-capable model (Structured Outputs).

    Defaults to gpt-4o-2024-08-06 on the OpenAI provider, temperature 0.
    """
    factory_kwargs = {
        "provider": provider,
        "model_name": model_name,
        "temperature": 0,
    }
    return OpenAIModelFactory.get_model(**factory_kwargs)