Spaces:
Running
Running
| import os | |
| from typing import Union | |
| from predictions.all_llms import private_llm | |
| from src.language_model.anthropic_wrapper import AnthropicWrapper | |
| from src.language_model.cohere_wrapper import CohereWrapper | |
| from src.language_model.deepseek_wrapper import DeepSeekWrapper | |
| from src.language_model.mistral_wrapper import MistralWrapper | |
| from src.language_model.open_ai_wrapper import OpenAIWrapper | |
| from src.language_model.xai_wrapper import XAIWrapper | |
| def get_api_key(model_name: str) -> Union[str, None]: | |
| if model_name in private_llm["openai"]: | |
| key_name = "openai_api_key" | |
| elif model_name in private_llm["anthropic"]: | |
| key_name = "anthropic_token" | |
| elif model_name in private_llm["deepseek"]: | |
| key_name = "deepseek_token" | |
| elif model_name in private_llm["mistral"]: | |
| key_name = "mistral_token" | |
| elif model_name in private_llm["xai"]: | |
| key_name = "XAI_API_KEY" | |
| elif model_name in private_llm["openrouter"]: | |
| key_name = "open_route_api_key" | |
| elif model_name in private_llm["cohere"]: | |
| key_name = "cohere_api_key" | |
| else: | |
| raise ValueError(f"Model name {model_name} not found.") | |
| api_key = os.getenv(key_name, None) | |
| if api_key is None: | |
| raise ValueError(f"API key {key_name} not found.") | |
| return api_key | |
| def private_language_model_factory(model_name): | |
| if model_name in private_llm["all"]: | |
| api_key = get_api_key(model_name) | |
| if model_name in private_llm["openai"]: | |
| if "o1" in model_name and not "o1-mini" in model_name: | |
| extra_params = { | |
| "reasoning_effort": "low" | |
| } # Otherwise take too many tokens and stop the process. | |
| else: | |
| extra_params = {} | |
| if "mini" in model_name: | |
| use_function_calling = False | |
| else: | |
| use_function_calling = True | |
| model = OpenAIWrapper( | |
| model_name=model_name, | |
| api_key=api_key, | |
| extra_params=extra_params, | |
| use_function_calling=use_function_calling, | |
| ) | |
| elif model_name in private_llm["anthropic"]: | |
| extra_params = {"max_tokens": 5012} | |
| model = AnthropicWrapper( | |
| model_name=model_name, api_key=api_key, extra_params=extra_params | |
| ) | |
| elif model_name in private_llm["deepseek"]: | |
| extra_params = {"timeout": 120} | |
| # DeepSeek reasoner does not support function calling | |
| use_function_calling = model_name == "deepseek-reasoner" | |
| model = DeepSeekWrapper( | |
| model_name=model_name, | |
| api_key=api_key, | |
| extra_params=extra_params, | |
| use_function_calling=use_function_calling, | |
| ) | |
| elif model_name in private_llm["xai"]: | |
| extra_params = {} | |
| model = XAIWrapper( | |
| model_name=model_name, api_key=api_key, extra_params=extra_params | |
| ) | |
| elif model_name in private_llm["mistral"]: | |
| extra_params = {} | |
| model = MistralWrapper( | |
| model_name=model_name, api_key=api_key, extra_params=extra_params | |
| ) | |
| elif model_name in private_llm["openrouter"]: | |
| extra_params = {} | |
| use_function_calling = True | |
| model = OpenAIWrapper( | |
| model_name=model_name, | |
| api_key=api_key, | |
| extra_params=extra_params, | |
| use_function_calling=use_function_calling, | |
| base_url="https://openrouter.ai/api/v1", | |
| timeout=480, | |
| ) | |
| elif model_name in private_llm["cohere"]: | |
| extra_params = {} | |
| use_function_calling = True | |
| model = CohereWrapper( | |
| model_name=model_name, | |
| api_key=api_key, | |
| extra_params=extra_params, | |
| use_function_calling=use_function_calling, | |
| base_url="https://api.cohere.ai/compatibility/v1", | |
| timeout=480, | |
| ) | |
| else: | |
| raise NotImplementedError("Not implemented yet.") | |
| else: | |
| raise ValueError(f"Model name {model_name} not found.") | |
| return model | |