# cole/src/language_model/private_language_model_factory.py
import os
from typing import Union
from predictions.all_llms import private_llm
from src.language_model.anthropic_wrapper import AnthropicWrapper
from src.language_model.cohere_wrapper import CohereWrapper
from src.language_model.deepseek_wrapper import DeepSeekWrapper
from src.language_model.mistral_wrapper import MistralWrapper
from src.language_model.open_ai_wrapper import OpenAIWrapper
from src.language_model.xai_wrapper import XAIWrapper
def get_api_key(model_name: str) -> str:
    """Return the API key for the provider that serves ``model_name``.

    The provider is resolved by membership lookup in the ``private_llm``
    registry (first matching provider wins, in the order below), then the
    corresponding key is read from the environment.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The API key string for the model's provider.

    Raises:
        ValueError: If the model name belongs to no known provider, or if
            the provider's environment variable is unset.
    """
    # Provider -> name of the environment variable holding its API key.
    # NOTE(review): the variable naming is inconsistent (lowercase names vs
    # XAI_API_KEY) — kept verbatim because deployments already define these
    # exact variables.
    provider_env_vars = {
        "openai": "openai_api_key",
        "anthropic": "anthropic_token",
        "deepseek": "deepseek_token",
        "mistral": "mistral_token",
        "xai": "XAI_API_KEY",
        "openrouter": "open_route_api_key",
        "cohere": "cohere_api_key",
    }
    for provider, key_name in provider_env_vars.items():
        if model_name in private_llm[provider]:
            break
    else:
        raise ValueError(f"Model name {model_name} not found.")

    api_key = os.getenv(key_name)
    if api_key is None:
        raise ValueError(f"API key {key_name} not found.")
    return api_key
def private_language_model_factory(model_name: str):
    """Instantiate the wrapper for a private (API-hosted) language model.

    Dispatches on the provider that lists ``model_name`` in the
    ``private_llm`` registry and builds the matching wrapper, configured
    with the API key read from the environment via ``get_api_key``.

    Args:
        model_name: Name of the model; must appear in ``private_llm["all"]``.

    Returns:
        A provider-specific wrapper instance.

    Raises:
        ValueError: If ``model_name`` is not a known private model, or if
            the provider's API key is not set in the environment.
        NotImplementedError: If the model is in ``private_llm["all"]`` but
            no provider branch handles it.
    """
    # Guard clause: fail fast on unknown models instead of nesting the
    # whole dispatch under an ``if``.
    if model_name not in private_llm["all"]:
        raise ValueError(f"Model name {model_name} not found.")

    api_key = get_api_key(model_name)

    if model_name in private_llm["openai"]:
        if "o1" in model_name and "o1-mini" not in model_name:
            # Otherwise o1 takes too many tokens and stops the process.
            extra_params = {"reasoning_effort": "low"}
        else:
            extra_params = {}
        # "mini" models are run without function calling.
        use_function_calling = "mini" not in model_name
        model = OpenAIWrapper(
            model_name=model_name,
            api_key=api_key,
            extra_params=extra_params,
            use_function_calling=use_function_calling,
        )
    elif model_name in private_llm["anthropic"]:
        model = AnthropicWrapper(
            model_name=model_name,
            api_key=api_key,
            extra_params={"max_tokens": 5012},
        )
    elif model_name in private_llm["deepseek"]:
        # DeepSeek reasoner does not support function calling, so disable
        # it for that model only. BUGFIX: the original used ``==``, which
        # enabled function calling exactly for "deepseek-reasoner",
        # contradicting its own comment.
        model = DeepSeekWrapper(
            model_name=model_name,
            api_key=api_key,
            extra_params={"timeout": 120},
            use_function_calling=model_name != "deepseek-reasoner",
        )
    elif model_name in private_llm["xai"]:
        model = XAIWrapper(
            model_name=model_name, api_key=api_key, extra_params={}
        )
    elif model_name in private_llm["mistral"]:
        model = MistralWrapper(
            model_name=model_name, api_key=api_key, extra_params={}
        )
    elif model_name in private_llm["openrouter"]:
        # OpenRouter exposes an OpenAI-compatible endpoint, so the OpenAI
        # wrapper is reused with a custom base URL and a long timeout.
        model = OpenAIWrapper(
            model_name=model_name,
            api_key=api_key,
            extra_params={},
            use_function_calling=True,
            base_url="https://openrouter.ai/api/v1",
            timeout=480,
        )
    elif model_name in private_llm["cohere"]:
        # Cohere is reached through its OpenAI-compatibility endpoint.
        model = CohereWrapper(
            model_name=model_name,
            api_key=api_key,
            extra_params={},
            use_function_calling=True,
            base_url="https://api.cohere.ai/compatibility/v1",
            timeout=480,
        )
    else:
        raise NotImplementedError("Not implemented yet.")
    return model