Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| from openai import OpenAI | |
| import anthropic | |
| from google import genai | |
| # --------------------------------------------------------------------------- | |
| # Model Registry | |
| # Each entry: display_name -> {provider, model_id, base_url (None = default), env_var} | |
| # --------------------------------------------------------------------------- | |
| MODEL_REGISTRY: dict[str, dict] = { | |
| "GPT-4o (OpenAI)": { | |
| "provider": "openai", | |
| "model_id": "gpt-4o", | |
| "base_url": None, | |
| "env_var": "OPENAI_API_KEY", | |
| "env_base_url": "OPENAI_BASE_URL", | |
| "env_model_id": "OPENAI_MODEL_ID", | |
| }, | |
| "GPT-4o-mini (OpenAI)": { | |
| "provider": "openai", | |
| "model_id": "gpt-4o-mini", | |
| "base_url": None, | |
| "env_var": "OPENAI_API_KEY", | |
| "env_base_url": "OPENAI_BASE_URL", | |
| "env_model_id": "OPENAI_MINI_MODEL_ID", | |
| }, | |
| "Claude Sonnet 4 (Anthropic)": { | |
| "provider": "anthropic", | |
| "model_id": "claude-sonnet-4-6", | |
| "base_url": None, | |
| "env_var": "ANTHROPIC_API_KEY", | |
| "env_base_url": "ANTHROPIC_BASE_URL", | |
| "env_model_id": "ANTHROPIC_MODEL_ID", | |
| }, | |
| "Gemini 2.0 Flash (Google)": { | |
| "provider": "gemini", | |
| "model_id": "gemini-2.0-flash", | |
| "base_url": None, | |
| "env_var": "GOOGLE_API_KEY", | |
| "env_base_url": "GOOGLE_BASE_URL", | |
| "env_model_id": "GOOGLE_MODEL_ID", | |
| }, | |
| "Qwen-Plus (Alibaba)": { | |
| "provider": "openai_compat", | |
| "model_id": "qwen-plus", | |
| "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", | |
| "env_var": "DASHSCOPE_API_KEY", | |
| "env_base_url": "DASHSCOPE_BASE_URL", | |
| "env_model_id": "DASHSCOPE_MODEL_ID", | |
| }, | |
| "Yi-Large (01.AI)": { | |
| "provider": "openai_compat", | |
| "model_id": "yi-large", | |
| "base_url": "https://api.01.ai/v1", | |
| "env_var": "YI_API_KEY", | |
| "env_base_url": "YI_BASE_URL", | |
| "env_model_id": "YI_MODEL_ID", | |
| }, | |
| } | |
| MODEL_NAMES = list(MODEL_REGISTRY.keys()) | |
def get_model_defaults(display_name: str) -> tuple[str, str]:
    """Look up the effective default (base_url, model_id) for a registry model.

    For each of the two values, a non-empty environment variable (named by the
    entry's ``env_base_url`` / ``env_model_id`` field) wins over the value
    hardcoded in the registry. An unknown ``display_name`` yields ``("", "")``.
    """
    spec = MODEL_REGISTRY.get(display_name, {})
    env = os.environ

    url_override = env.get(spec.get("env_base_url", ""), "")
    id_override = env.get(spec.get("env_model_id", ""), "")

    if url_override:
        base_url = url_override
    else:
        # The registry stores None for "provider default"; normalize to "".
        base_url = spec.get("base_url") or ""

    model_id = id_override if id_override else spec.get("model_id", "")
    return base_url, model_id
| def _resolve_key(env_var: str, user_key: str | None) -> str: | |
| """Return user-provided key if non-empty, else fall back to env var.""" | |
| if user_key and user_key.strip(): | |
| return user_key.strip() | |
| key = os.environ.get(env_var, "") | |
| if not key: | |
| raise ValueError( | |
| f"No API key provided and environment variable {env_var} is not set." | |
| ) | |
| return key | |
| # --------------------------------------------------------------------------- | |
| # Provider dispatch | |
| # --------------------------------------------------------------------------- | |
def _call_openai(model_id: str, prompt: str, api_key: str, base_url: str | None) -> str:
    """Send ``prompt`` as a single user message through the OpenAI client.

    ``base_url`` is handed to the client unchanged (None selects the client's
    default endpoint). Returns the message content of the first choice.
    """
    conversation = [{"role": "user", "content": prompt}]
    completion = OpenAI(api_key=api_key, base_url=base_url).chat.completions.create(
        model=model_id,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def _call_anthropic(
    model_id: str, prompt: str, api_key: str, base_url: str | None = None
) -> str:
    """Send ``prompt`` as a single user message through the Anthropic client.

    Args:
        model_id: Model identifier passed to the Messages API.
        prompt: User message text.
        api_key: Anthropic API key.
        base_url: Optional API base URL override; when empty or None the
            client is constructed exactly as before (SDK default endpoint).

    Returns:
        The text of the first content block of the response.
    """
    client_kwargs: dict = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
    client = anthropic.Anthropic(**client_kwargs)
    resp = client.messages.create(
        model=model_id,
        max_tokens=4096,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.content[0].text


def _call_gemini(
    model_id: str, prompt: str, api_key: str, base_url: str | None = None
) -> str:
    """Send ``prompt`` to a Gemini model through the google-genai client.

    Args:
        model_id: Model identifier passed to ``generate_content``.
        prompt: Prompt text, passed as ``contents``.
        api_key: Google API key.
        base_url: Optional API base URL override; when empty or None the
            client is constructed exactly as before (SDK default endpoint).

    Returns:
        The ``text`` attribute of the response.
    """
    client_kwargs: dict = {"api_key": api_key}
    if base_url:
        # NOTE(review): google-genai takes endpoint overrides through
        # http_options rather than a top-level argument - confirm against
        # the installed SDK version.
        client_kwargs["http_options"] = {"base_url": base_url}
    client = genai.Client(**client_kwargs)
    resp = client.models.generate_content(model=model_id, contents=prompt)
    return resp.text


def call_model(
    display_name: str,
    prompt: str,
    user_key: str | None = None,
    user_base_url: str | None = None,
    user_model_id: str | None = None,
) -> str:
    """Call a reference model from the registry.

    User-supplied base_url / model_id override env-var defaults, which in turn
    override the hardcoded registry values. The resolved base_url is forwarded
    to every provider, not only the OpenAI-compatible ones.

    Args:
        display_name: Key into MODEL_REGISTRY.
        prompt: Prompt text sent as a single user message.
        user_key: Optional API key; blank falls back to the entry's env var.
        user_base_url: Optional base URL override; blank means "use default".
        user_model_id: Optional model id override; blank means "use default".

    Returns:
        The model's text response.

    Raises:
        ValueError: If ``display_name`` is not in the registry, no API key can
            be resolved, or the entry names an unknown provider.
    """
    entry = MODEL_REGISTRY.get(display_name)
    if entry is None:
        raise ValueError(f"Unknown model: {display_name}")

    api_key = _resolve_key(entry["env_var"], user_key)
    provider = entry["provider"]

    # Resolve: user input > env var > registry default
    default_base_url, default_model_id = get_model_defaults(display_name)
    model_id = (user_model_id or "").strip() or default_model_id
    # An empty result becomes None so each client falls back to its own default.
    base_url = (user_base_url or "").strip() or default_base_url or None

    if provider in ("openai", "openai_compat"):
        return _call_openai(model_id, prompt, api_key, base_url)
    if provider == "anthropic":
        return _call_anthropic(model_id, prompt, api_key, base_url)
    if provider == "gemini":
        return _call_gemini(model_id, prompt, api_key, base_url)
    raise ValueError(f"Unknown provider: {provider}")
def call_custom_endpoint(
    base_url: str, model_name: str, prompt: str, api_key: str
) -> str:
    """Query a user-supplied Dify chat application and return its answer.

    Dify API docs: https://docs.dify.ai/en/guides/application-publishing/developing-with-apis

    The request is a blocking POST to ``{base_url}/chat-messages``, so
    ``base_url`` should be the Dify API base, e.g. https://api.dify.ai/v1.
    ``model_name`` is accepted but not used by this function.

    Raises:
        ValueError: If ``base_url`` or ``api_key`` is blank, or if the
            response JSON has no non-empty ``answer`` field.
        requests.HTTPError: Propagated from ``raise_for_status`` on an HTTP
            error status.
    """
    endpoint = (base_url or "").strip()
    if not endpoint:
        raise ValueError("API endpoint URL is required for your Dify model.")

    secret = (api_key or "").strip()
    if not secret:
        raise ValueError("API Key (Secret Key) is required for Dify.")

    # Drop trailing slashes so the joined URL has exactly one separator.
    target = f"{endpoint.rstrip('/')}/chat-messages"
    request_headers = {
        "Authorization": f"Bearer {secret}",
        "Content-Type": "application/json",
    }
    body = {
        "inputs": {},
        "query": prompt,
        "response_mode": "blocking",
        "user": "llm-compare-user",
    }

    response = requests.post(target, json=body, headers=request_headers, timeout=120)
    response.raise_for_status()

    data = response.json()
    answer = data.get("answer", "")
    if answer:
        return answer
    raise ValueError(f"Dify returned no answer. Full response: {data}")