# llm_compare / providers.py
# crossingk's picture
# Update providers.py
# d10f45c verified
import os
import requests
from openai import OpenAI
import anthropic
from google import genai
# ---------------------------------------------------------------------------
# Model Registry
# Each entry: display_name -> {provider, model_id, base_url (None = default),
#                              env_var, env_base_url, env_model_id}
# ---------------------------------------------------------------------------
# Maps the display name of each reference model to the settings needed to call
# it: "provider" selects the dispatch branch, "model_id" / "base_url" are the
# hardcoded defaults (base_url None = default), "env_var" names the environment
# variable holding the API key, and "env_base_url" / "env_model_id" name the
# environment variables that may override the two defaults.
MODEL_REGISTRY: dict[str, dict] = {
    "GPT-4o (OpenAI)": dict(
        provider="openai",
        model_id="gpt-4o",
        base_url=None,
        env_var="OPENAI_API_KEY",
        env_base_url="OPENAI_BASE_URL",
        env_model_id="OPENAI_MODEL_ID",
    ),
    "GPT-4o-mini (OpenAI)": dict(
        provider="openai",
        model_id="gpt-4o-mini",
        base_url=None,
        env_var="OPENAI_API_KEY",
        env_base_url="OPENAI_BASE_URL",
        env_model_id="OPENAI_MINI_MODEL_ID",
    ),
    "Claude Sonnet 4 (Anthropic)": dict(
        provider="anthropic",
        model_id="claude-sonnet-4-6",
        base_url=None,
        env_var="ANTHROPIC_API_KEY",
        env_base_url="ANTHROPIC_BASE_URL",
        env_model_id="ANTHROPIC_MODEL_ID",
    ),
    "Gemini 2.0 Flash (Google)": dict(
        provider="gemini",
        model_id="gemini-2.0-flash",
        base_url=None,
        env_var="GOOGLE_API_KEY",
        env_base_url="GOOGLE_BASE_URL",
        env_model_id="GOOGLE_MODEL_ID",
    ),
    "Qwen-Plus (Alibaba)": dict(
        provider="openai_compat",
        model_id="qwen-plus",
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        env_var="DASHSCOPE_API_KEY",
        env_base_url="DASHSCOPE_BASE_URL",
        env_model_id="DASHSCOPE_MODEL_ID",
    ),
    "Yi-Large (01.AI)": dict(
        provider="openai_compat",
        model_id="yi-large",
        base_url="https://api.01.ai/v1",
        env_var="YI_API_KEY",
        env_base_url="YI_BASE_URL",
        env_model_id="YI_MODEL_ID",
    ),
}

# Display names in registry (insertion) order.
MODEL_NAMES = list(MODEL_REGISTRY)
def get_model_defaults(display_name: str) -> tuple[str, str]:
    """Look up the effective (base_url, model_id) defaults for a registry model.

    For each of the two values, a non-empty environment variable (named by the
    entry's ``env_base_url`` / ``env_model_id``) wins over the value hardcoded
    in the registry. A display name that is not in the registry yields
    ``("", "")``; a registry ``base_url`` of ``None`` is reported as ``""``.
    """
    spec = MODEL_REGISTRY.get(display_name, {})

    url_override = os.environ.get(spec.get("env_base_url", ""), "")
    if url_override:
        resolved_url = url_override
    else:
        # Registry may store None for "use the default"; normalise to "".
        resolved_url = spec.get("base_url") or ""

    id_override = os.environ.get(spec.get("env_model_id", ""), "")
    resolved_id = id_override if id_override else spec.get("model_id", "")

    return resolved_url, resolved_id
def _resolve_key(env_var: str, user_key: str | None) -> str:
"""Return user-provided key if non-empty, else fall back to env var."""
if user_key and user_key.strip():
return user_key.strip()
key = os.environ.get(env_var, "")
if not key:
raise ValueError(
f"No API key provided and environment variable {env_var} is not set."
)
return key
# ---------------------------------------------------------------------------
# Provider dispatch
# ---------------------------------------------------------------------------
def _call_openai(model_id: str, prompt: str, api_key: str, base_url: str | None) -> str:
    """Send ``prompt`` as a single user message through the OpenAI client.

    The client is created with the given ``api_key`` and ``base_url``, and the
    content of the first returned choice's message is returned.
    """
    conversation = [{"role": "user", "content": prompt}]
    completion = OpenAI(api_key=api_key, base_url=base_url).chat.completions.create(
        model=model_id,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def _call_anthropic(
    model_id: str, prompt: str, api_key: str, base_url: str | None = None
) -> str:
    """Send ``prompt`` as a single user message through the Anthropic client.

    Args:
        model_id: Model identifier passed to ``messages.create``.
        prompt: Text of the user message.
        api_key: Anthropic API key.
        base_url: Optional API base URL; ``None`` leaves the endpoint choice
            to the SDK.

    Returns:
        The text of the first content block of the response.
    """
    client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
    resp = client.messages.create(
        model=model_id,
        max_tokens=4096,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.content[0].text


def _call_gemini(
    model_id: str, prompt: str, api_key: str, base_url: str | None = None
) -> str:
    """Send ``prompt`` to a Gemini model through the google-genai client.

    Args:
        model_id: Model identifier passed to ``generate_content``.
        prompt: Prompt text, sent as the request contents.
        api_key: Google API key.
        base_url: Optional API base URL; when falsy, no HTTP options are
            passed and the client uses its own defaults.

    Returns:
        The ``text`` attribute of the response.
    """
    # Only build http_options when an override exists so the default client
    # construction is untouched otherwise.
    http_options = {"base_url": base_url} if base_url else None
    client = genai.Client(api_key=api_key, http_options=http_options)
    resp = client.models.generate_content(model=model_id, contents=prompt)
    return resp.text


def call_model(
    display_name: str,
    prompt: str,
    user_key: str | None = None,
    user_base_url: str | None = None,
    user_model_id: str | None = None,
) -> str:
    """Call a reference model from the registry.

    User-supplied base_url / model_id override env-var defaults, which in turn
    override the hardcoded registry values. The resolved base URL is forwarded
    to every provider, not only the OpenAI-compatible ones.

    Args:
        display_name: Key into ``MODEL_REGISTRY``.
        prompt: Prompt text sent to the model.
        user_key: Optional API key; blank/``None`` falls back to the entry's
            environment variable.
        user_base_url: Optional base URL override; blank/``None`` is ignored.
        user_model_id: Optional model id override; blank/``None`` is ignored.

    Returns:
        The model's reply as returned by the provider-specific helper.

    Raises:
        ValueError: If ``display_name`` is not registered, no API key can be
            resolved, or the entry names an unknown provider.
    """
    entry = MODEL_REGISTRY.get(display_name)
    if entry is None:
        raise ValueError(f"Unknown model: {display_name}")

    api_key = _resolve_key(entry["env_var"], user_key)
    provider = entry["provider"]

    # Resolve: user input > env var > registry default
    default_base_url, default_model_id = get_model_defaults(display_name)
    model_id = (user_model_id or "").strip() or default_model_id
    # An empty result becomes None so each SDK falls back to its own endpoint.
    base_url = (user_base_url or "").strip() or default_base_url or None

    if provider in ("openai", "openai_compat"):
        return _call_openai(model_id, prompt, api_key, base_url)
    elif provider == "anthropic":
        return _call_anthropic(model_id, prompt, api_key, base_url)
    elif provider == "gemini":
        return _call_gemini(model_id, prompt, api_key, base_url)
    else:
        raise ValueError(f"Unknown provider: {provider}")
def call_custom_endpoint(
    base_url: str, model_name: str, prompt: str, api_key: str
) -> str:
    """Query a user-supplied Dify application endpoint (left column).

    Dify API docs: https://docs.dify.ai/en/guides/application-publishing/developing-with-apis

    ``base_url`` is the Dify API base, e.g. https://api.dify.ai/v1; the request
    is POSTed to ``{base_url}/chat-messages`` (for Chat apps) in blocking mode
    with a 120-second timeout. ``model_name`` is accepted but not used by this
    function.

    Returns the ``answer`` field of the JSON response.

    Raises ``ValueError`` when ``base_url`` or ``api_key`` is missing/blank or
    when the response carries no answer; HTTP error statuses are raised by
    ``raise_for_status``.
    """
    endpoint_root = base_url.strip() if base_url else ""
    if not endpoint_root:
        raise ValueError("API endpoint URL is required for your Dify model.")

    secret = api_key.strip() if api_key else ""
    if not secret:
        raise ValueError("API Key (Secret Key) is required for Dify.")

    response = requests.post(
        # Drop trailing slashes so the joined path has exactly one separator.
        f"{endpoint_root.rstrip('/')}/chat-messages",
        json={
            "inputs": {},
            "query": prompt,
            "response_mode": "blocking",
            "user": "llm-compare-user",
        },
        headers={
            "Authorization": f"Bearer {secret}",
            "Content-Type": "application/json",
        },
        timeout=120,
    )
    response.raise_for_status()

    body = response.json()
    reply = body.get("answer", "")
    if reply:
        return reply
    raise ValueError(f"Dify returned no answer. Full response: {body}")