Spaces:
Sleeping
Sleeping
| import os | |
| import tiktoken | |
| from typing import Union | |
| from azure.identity import DefaultAzureCredential, get_bearer_token_provider | |
| from agents import OpenAIChatCompletionsModel | |
| from openai import AsyncOpenAI, AsyncAzureOpenAI | |
| from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_ollama import OllamaEmbeddings | |
| from huggingface_hub import login | |
class OpenAIModelFactory:
    """
    Factory for creating OpenAI-SDK compatible model instances (using the 'agents' library).

    Supports multiple providers via the OpenAI-compatible API format:
    "openai", "azure", "google"/"gemini", "groq", and "ollama".
    """

    @staticmethod  # was missing: without it, instance calls would bind `provider` to self
    def get_model(provider: str = "openai",  # openai, azure, google, groq, ollama
                  model_name: str = "gpt-4o",
                  temperature: float = 0
                  ) -> OpenAIChatCompletionsModel:
        """
        Return an OpenAIChatCompletionsModel backed by the requested provider.

        Args:
            provider: Provider key (case-insensitive): "openai", "azure",
                "google"/"gemini", "groq", or "ollama".
            model_name: Model (or Azure deployment) name passed to the client.
            temperature: Kept for interface compatibility. NOTE(review): it is
                currently unused — nothing in this body forwards it to the
                client or model wrapper; confirm whether callers expect it
                to take effect.

        Returns:
            A configured OpenAIChatCompletionsModel.

        Raises:
            ValueError: If `provider` is not one of the supported keys.
            KeyError: If a required environment variable is missing.
        """
        key = provider.lower()  # normalize once instead of in every branch

        if key == "azure":
            # Azure auth uses an AAD bearer token rather than an API key;
            # endpoint and API version come from the environment.
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
            client = AsyncAzureOpenAI(
                azure_endpoint=os.environ["AZURE_OPENAI_API_URI"],
                api_version=os.environ["AZURE_OPENAI_API_VERSION"],
                azure_ad_token_provider=token_provider,
            )
        elif key == "openai":
            client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
        elif key in ("google", "gemini"):
            # Gemini exposes an OpenAI-compatible endpoint.
            client = AsyncOpenAI(
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=os.environ["GOOGLE_API_KEY"],
            )
        elif key == "groq":
            client = AsyncOpenAI(
                base_url="https://api.groq.com/openai/v1",
                api_key=os.environ["GROQ_API_KEY"],
            )
        elif key == "ollama":
            # Local Ollama server; the api_key is a required-but-ignored placeholder.
            client = AsyncOpenAI(
                base_url="http://localhost:11434/v1",
                api_key="ollama",
            )
        else:
            raise ValueError(f"Unsupported provider for OpenAIModelFactory: {provider}")

        # Single exit point: all providers wrap their async client identically.
        return OpenAIChatCompletionsModel(model=model_name, openai_client=client)
def num_tokens_from_messages(messages, model: str = "gpt-4o"):
    """
    Return the number of tokens used by a list of messages.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the common cl100k_base encoding.
        enc = tiktoken.get_encoding("cl100k_base")

    per_message_overhead = 3
    total = 0
    for msg in messages:
        total += per_message_overhead
        for field, content in msg.items():
            if field == "name":
                total += 1
            # String values are encoded directly.
            if isinstance(content, str):
                total += len(enc.encode(content))
            elif field == "content" and isinstance(content, list):
                # Multimodal content: encode text parts, flat estimate per image.
                for piece in content:
                    if not isinstance(piece, dict):
                        continue
                    kind = piece.get("type")
                    if kind == "text":
                        total += len(enc.encode(piece.get("text", "")))
                    elif kind == "image_url":
                        total += 85
    total += 3  # every reply is primed with assistant tokens
    return total
class EmbeddingFactory:
    """
    A static utility class to create and return Embedding Model instances.

    Supported providers: "azure", "openai", "ollama", "huggingface".
    """

    @staticmethod  # was missing: without it, instance calls would bind `provider` to self
    def get_embedding_model(provider: str = "openai",
                            model_name: str = "text-embedding-3-small"
                            ) -> Union[AzureOpenAIEmbeddings, OpenAIEmbeddings, OllamaEmbeddings, HuggingFaceEmbeddings]:
        """
        Return a LangChain embeddings instance for the requested provider.

        Args:
            provider: Provider key (case-insensitive): "azure", "openai",
                "ollama", or "huggingface".
            model_name: Embedding model name; for Azure it is the fallback
                deployment name when AZURE_OPENAI_EMBEDDING_DEPLOYMENT is unset.

        Returns:
            A configured embeddings object from the matching LangChain package.

        Raises:
            ValueError: If `provider` is not one of the supported keys.
            KeyError: If a required environment variable is missing.
        """
        key = provider.lower()  # normalize once instead of in every branch

        if key == "azure":
            # Azure auth uses an AAD bearer token rather than an API key.
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
            return AzureOpenAIEmbeddings(
                azure_endpoint=os.environ["AZURE_OPENAI_API_URI"],
                azure_deployment=os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", model_name),
                api_version=os.environ["AZURE_OPENAI_API_VERSION"],
                azure_ad_token_provider=token_provider,
            )
        elif key == "openai":
            return OpenAIEmbeddings(
                api_key=os.environ["OPENAI_API_KEY"],
                model=model_name,
            )
        elif key == "ollama":
            return OllamaEmbeddings(model=model_name)
        elif key == "huggingface":
            # Authenticate to the Hub only when a token is configured
            # (needed for gated/private models; read HF_TOKEN once).
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                login(token=hf_token)
            return HuggingFaceEmbeddings(model_name=model_name)
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")
| # ================================================================================================= | |
| # GLOBAL HELPER FUNCTIONS | |
| # ================================================================================================= | |
def get_model(provider: str = "openai", model_name: str = "gpt-4o"):
    """
    Global helper to get an OpenAI-SDK compatible model.
    Defaults to OpenAI provider and gpt-4o.
    """
    # Thin convenience wrapper: delegate straight to the factory,
    # pinning temperature at 0 for deterministic-style defaults.
    return OpenAIModelFactory.get_model(provider=provider,
                                        model_name=model_name,
                                        temperature=0)
def get_model_json(model_name: str = "gpt-4o-2024-08-06", provider: str = "openai"):
    """
    Global helper to get a JSON-capable model (Structured Outputs).
    Defaults to gpt-4o-2024-08-06 on OpenAI.
    """
    # Same factory path as get_model(); only the default model differs,
    # chosen because it supports Structured Outputs.
    return OpenAIModelFactory.get_model(provider=provider,
                                        model_name=model_name,
                                        temperature=0)