| from __future__ import annotations |
|
|
| import asyncio |
| import base64 |
| import contextlib |
| import hashlib |
| import json |
| import os |
| import re |
| import sqlite3 |
| import time |
| import uuid |
| import xml.etree.ElementTree as ET |
| from contextlib import asynccontextmanager |
| from datetime import UTC, datetime, timedelta |
| from pathlib import Path |
| from typing import Any |
| from urllib.parse import urlparse |
| from zoneinfo import ZoneInfo |
|
|
| import httpx |
| from fastapi import Depends, FastAPI, Header, HTTPException, Request, status |
| from fastapi.middleware.gzip import GZipMiddleware |
| from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse |
| from fastapi.staticfiles import StaticFiles |
|
|
|
|
# Project root: the parent of the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parent.parent
STATIC_DIR = BASE_DIR / "static"
# SQLite database file; overridable through the DATABASE_PATH environment variable.
DB_PATH = Path(os.getenv("DATABASE_PATH", BASE_DIR / "data.sqlite3"))
# Upstream base URL: NVIDIA_API_BASE takes precedence over NIM_BASE_URL; trailing slashes are removed.
RAW_NVIDIA_API_BASE = os.getenv("NVIDIA_API_BASE", os.getenv("NIM_BASE_URL", "https://integrate.api.nvidia.com/v1")).rstrip("/")
# Append "/v1" unless the configured base already ends with it.
NVIDIA_API_BASE = RAW_NVIDIA_API_BASE if RAW_NVIDIA_API_BASE.endswith("/v1") else f"{RAW_NVIDIA_API_BASE}/v1"
CHAT_COMPLETIONS_URL = f"{NVIDIA_API_BASE}/chat/completions"
MODELS_URL = f"{NVIDIA_API_BASE}/models"
# HTTP client tuning, each read from an environment variable of the same name.
REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "180"))
MAX_UPSTREAM_CONNECTIONS = int(os.getenv("MAX_UPSTREAM_CONNECTIONS", "512"))
MAX_KEEPALIVE_CONNECTIONS = int(os.getenv("MAX_KEEPALIVE_CONNECTIONS", "128"))
# Requested interval between model-list refreshes, in minutes.
MODEL_SYNC_INTERVAL_MINUTES = int(os.getenv("MODEL_SYNC_INTERVAL_MINUTES", "30"))
PUBLIC_HISTORY_BUCKETS = int(os.getenv("PUBLIC_HISTORY_BUCKETS", "22"))
HEALTH_SUMMARY_WINDOW_MINUTES = 120
UPSTREAM_TIMEOUT_RETRIES = 1
# Width, in minutes, of the time buckets produced by bucket_start().
BUCKET_MINUTES = 10
# Comma-separated default model ids, used when MODEL_LIST is not set in the environment.
DEFAULT_MONITORED_MODELS = "z-ai/glm5,z-ai/glm4.7,minimaxai/minimax-m2.5,minimaxai/minimax-m2.7,moonshotai/kimi-k2.5,deepseek-ai/deepseek-v3.2,google/gemma-4-31b-it,qwen/qwen3.5-397b-a17b"
# Parsed model ids: entries are stripped and empty ones are dropped.
MODEL_LIST = [item.strip() for item in os.getenv("MODEL_LIST", DEFAULT_MONITORED_MODELS).split(",") if item.strip()]
# Timezone used by utcnow(); defaults to Asia/Shanghai.
APP_TIMEZONE = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
ANTHROPIC_API_VERSION = "2023-06-01"
# Beta flag name checked by build_anthropic_thinking_config().
ANTHROPIC_INTERLEAVED_THINKING_BETA = "interleaved-thinking-2025-05-14"
# Smallest thinking budget accepted by build_anthropic_thinking_config().
ANTHROPIC_MIN_THINKING_BUDGET_TOKENS = 1024
ANTHROPIC_SERVER_TOOL_MAX_ITERATIONS = 8
ANTHROPIC_SERVER_TOOL_PREFIXES = (
    "web_search_",
    "web_fetch_",
    "code_execution_",
    "advisor_",
    "tool_search_tool_",
    "mcp_toolset",
)
# NOTE(review): the name says RSS but the default is the plain Bing search endpoint - confirm with the consumer.
WEB_SEARCH_RSS_URL = os.getenv("WEB_SEARCH_RSS_URL", "https://www.bing.com/search")
WEB_SEARCH_DEFAULT_MAX_RESULTS = int(os.getenv("WEB_SEARCH_DEFAULT_MAX_RESULTS", "5"))
WEB_SEARCH_MAX_QUERY_LENGTH = int(os.getenv("WEB_SEARCH_MAX_QUERY_LENGTH", "512"))
|
|
# Module-level shared state. Each value starts unset/empty and is filled in at runtime.
# Created on first use by get_http_client().
http_client: httpx.AsyncClient | None = None
# Model list and its sync timestamp, as last stored by refresh_official_models().
model_cache: list[dict[str, Any]] = []
model_cache_synced_at: str | None = None
# Created on first use by get_model_cache_lock().
model_cache_lock: asyncio.Lock | None = None
model_sync_task: asyncio.Task[None] | None = None
# Matches one <think>...</think> span (case-insensitive, may span lines); group 1 is the inner text.
THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
|
|
|
|
def utcnow() -> datetime:
    """Return the current timezone-aware time expressed in ``APP_TIMEZONE``.

    Despite the name, the value is not converted to UTC.
    """
    return datetime.now(tz=APP_TIMEZONE)
|
|
|
|
def utcnow_iso() -> str:
    """Return the result of ``utcnow()`` as an ISO 8601 string."""
    current = utcnow()
    return current.isoformat()
|
|
|
|
def json_dumps(value: Any) -> str:
    """Serialize *value* to JSON text, keeping non-ASCII characters unescaped."""
    encoded = json.dumps(value, ensure_ascii=False)
    return encoded
|
|
|
|
def hash_api_key(api_key: str) -> str:
    """Return the hexadecimal SHA-256 digest of the UTF-8 encoded *api_key*."""
    hasher = hashlib.sha256()
    hasher.update(api_key.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
| def normalize_provider(model_id: str, owned_by: str | None = None) -> str: |
| if owned_by: |
| return owned_by |
| if "/" in model_id: |
| return model_id.split("/", 1)[0] |
| return "unknown" |
|
|
|
|
def bucket_start(dt: datetime | None = None) -> datetime:
    """Floor *dt* (or the current time when omitted) to the start of its bucket.

    Minutes are rounded down to a multiple of ``BUCKET_MINUTES``; seconds and
    microseconds are zeroed.
    """
    moment = dt if dt else utcnow()
    floored_minute = (moment.minute // BUCKET_MINUTES) * BUCKET_MINUTES
    return moment.replace(minute=floored_minute, second=0, microsecond=0)
|
|
|
|
def bucket_label(value: str) -> str:
    """Format an ISO 8601 timestamp string as ``HH:MM``.

    Strings that cannot be parsed as ISO timestamps are returned unchanged.
    """
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return value
    else:
        return f"{parsed:%H:%M}"
|
|
|
|
def get_db_connection() -> sqlite3.Connection:
    """Open a new SQLite connection to ``DB_PATH`` with the gateway's pragmas.

    The parent directory is created when missing, rows are returned as
    ``sqlite3.Row``, and the connection may be used from other threads
    (``check_same_thread=False``).
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(DB_PATH, check_same_thread=False, timeout=30.0)
    connection.row_factory = sqlite3.Row
    # Applied in this order on every new connection.
    for pragma in (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA foreign_keys=ON",
        "PRAGMA busy_timeout=30000",
    ):
        connection.execute(pragma)
    return connection
|
|
|
|
def init_db() -> None:
    """Create the tables and indexes if they do not exist and seed the totals row.

    The single ``gateway_totals`` row (id 1) is inserted only when absent.
    The connection is always closed, including on failure.
    """
    with contextlib.closing(get_db_connection()) as conn:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS response_records (
            response_id TEXT PRIMARY KEY,
            api_key_hash TEXT NOT NULL,
            parent_response_id TEXT,
            model_id TEXT NOT NULL,
            request_json TEXT NOT NULL,
            input_items_json TEXT NOT NULL,
            output_json TEXT NOT NULL,
            output_items_json TEXT NOT NULL,
            status TEXT NOT NULL,
            success INTEGER NOT NULL,
            latency_ms REAL,
            error_message TEXT,
            created_at TEXT NOT NULL
            );

            CREATE INDEX IF NOT EXISTS idx_response_api_hash ON response_records(api_key_hash);
            CREATE INDEX IF NOT EXISTS idx_response_parent ON response_records(parent_response_id);
            CREATE INDEX IF NOT EXISTS idx_response_model_created ON response_records(model_id, created_at);

            CREATE TABLE IF NOT EXISTS metric_buckets (
            bucket_start TEXT NOT NULL,
            model_id TEXT NOT NULL,
            total_count INTEGER NOT NULL DEFAULT 0,
            success_count INTEGER NOT NULL DEFAULT 0,
            total_latency_ms REAL NOT NULL DEFAULT 0,
            PRIMARY KEY(bucket_start, model_id)
            );

            CREATE TABLE IF NOT EXISTS gateway_totals (
            id INTEGER PRIMARY KEY CHECK(id = 1),
            total_requests INTEGER NOT NULL DEFAULT 0,
            total_success INTEGER NOT NULL DEFAULT 0,
            total_latency_ms REAL NOT NULL DEFAULT 0,
            updated_at TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS official_models_cache (
            id TEXT PRIMARY KEY,
            object TEXT NOT NULL,
            created INTEGER,
            owned_by TEXT,
            synced_at TEXT NOT NULL
            );
            """
        )
        # Seed the singleton totals row; a no-op when it already exists.
        conn.execute(
            """
            INSERT OR IGNORE INTO gateway_totals (id, total_requests, total_success, total_latency_ms, updated_at)
            VALUES (1, 0, 0, 0, ?)
            """,
            (utcnow_iso(),),
        )
        conn.commit()
|
|
|
|
async def run_db(fn, *args, **kwargs):
    """Run the blocking callable *fn* in a worker thread and return its result.

    Positional and keyword arguments are forwarded to *fn* unchanged.
    """
    outcome = await asyncio.to_thread(fn, *args, **kwargs)
    return outcome
|
|
|
|
async def get_http_client() -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient``, creating it when missing or closed.

    A new client uses ``REQUEST_TIMEOUT_SECONDS`` and the configured
    connection limits.
    """
    global http_client
    if http_client is not None and not http_client.is_closed:
        return http_client
    http_client = httpx.AsyncClient(
        timeout=REQUEST_TIMEOUT_SECONDS,
        limits=httpx.Limits(
            max_connections=MAX_UPSTREAM_CONNECTIONS,
            max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS,
        ),
    )
    return http_client
|
|
|
|
async def get_model_cache_lock() -> asyncio.Lock:
    """Return the module-level model-cache lock, creating it on first use."""
    global model_cache_lock
    lock = model_cache_lock
    if lock is None:
        lock = model_cache_lock = asyncio.Lock()
    return lock
|
|
|
|
def load_cached_models_from_db() -> tuple[list[dict[str, Any]], str | None]:
    """Read the cached model list from ``official_models_cache``.

    Returns ``(models, synced_at)`` with models ordered by id and
    ``synced_at`` taken from the first row, or ``([], None)`` when the table
    is empty.
    """
    model_fields = ("id", "object", "created", "owned_by")
    conn = get_db_connection()
    try:
        cursor = conn.execute(
            "SELECT id, object, created, owned_by, synced_at FROM official_models_cache ORDER BY id ASC"
        )
        records = cursor.fetchall()
        if not records:
            return [], None
        cached = [{field: record[field] for field in model_fields} for record in records]
        return cached, records[0]["synced_at"]
    finally:
        conn.close()
|
|
|
|
def save_models_to_db(models: list[dict[str, Any]], synced_at: str) -> None:
    """Replace the contents of ``official_models_cache`` with *models*.

    Entries without a truthy ``id`` are skipped and later duplicates override
    earlier ones. Rows are written in id order, all stamped with *synced_at*.
    """
    deduped: dict[str, dict[str, Any]] = {
        entry.get("id"): entry for entry in models if entry.get("id")
    }

    with contextlib.closing(get_db_connection()) as conn:
        conn.execute("DELETE FROM official_models_cache")
        rows = [
            (
                key,
                deduped[key].get("object", "model"),
                deduped[key].get("created"),
                deduped[key].get("owned_by") or normalize_provider(key),
                synced_at,
            )
            for key in sorted(deduped)
        ]
        conn.executemany(
            """
            INSERT INTO official_models_cache (id, object, created, owned_by, synced_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            rows,
        )
        conn.commit()
|
|
|
|
async def refresh_official_models(force: bool = False) -> list[dict[str, Any]]:
    """Return the model list, fetching it from ``MODELS_URL`` when needed.

    A non-empty in-memory cache is returned as-is unless *force* is true.
    Otherwise the list is downloaded under the cache lock, normalized,
    written to the database and stored in the module-level cache.

    Raises whatever the HTTP client raises, including for error statuses.
    """
    global model_cache, model_cache_synced_at
    allow_cached = not force
    if allow_cached and model_cache:
        return model_cache
    async with await get_model_cache_lock():
        # Re-check: another task may have filled the cache while we waited.
        if allow_cached and model_cache:
            return model_cache
        client = await get_http_client()
        response = await client.get(MODELS_URL, headers={"Accept": "application/json"})
        response.raise_for_status()
        payload = response.json()
        raw_models = payload.get("data") or payload.get("models") or []
        normalized: list[dict[str, Any]] = []
        for entry in raw_models:
            if not isinstance(entry, dict) or not entry.get("id"):
                continue
            normalized.append(
                {
                    "id": entry.get("id"),
                    "object": entry.get("object", "model"),
                    "created": entry.get("created"),
                    "owned_by": entry.get("owned_by") or normalize_provider(entry.get("id", "")),
                }
            )
        stamp = utcnow_iso()
        await run_db(save_models_to_db, normalized, stamp)
        model_cache = normalized
        model_cache_synced_at = stamp
        return normalized
|
|
|
|
async def model_sync_loop() -> None:
    """Refresh the official model list forever, sleeping between attempts.

    The pause is ``MODEL_SYNC_INTERVAL_MINUTES`` converted to seconds, with a
    floor of 300 seconds. A failed refresh is logged and retried on the next
    cycle, so one bad sync never stops the loop.
    """
    import logging

    logger = logging.getLogger(__name__)
    interval_seconds = max(300, MODEL_SYNC_INTERVAL_MINUTES * 60)
    while True:
        try:
            await refresh_official_models(force=True)
        except Exception:
            # Best-effort sync: keep looping, but leave a trace instead of
            # discarding the error silently.
            logger.warning(
                "Official model sync failed; retrying in %s seconds.",
                interval_seconds,
                exc_info=True,
            )
        await asyncio.sleep(interval_seconds)
|
|
|
|
def extract_user_api_key(
    authorization: str | None = Header(default=None),
    x_api_key: str | None = Header(default=None),
    x_nvidia_api_key: str | None = Header(default=None),
) -> str:
    """Extract the caller's API key from the request headers.

    Sources are tried in order: ``Authorization: Bearer <token>``,
    ``X-API-Key``, then ``X-NVIDIA-API-Key``. The ``Bearer`` scheme is matched
    case-insensitively, and a source whose value is empty or whitespace is
    skipped so a later header can still supply the key.

    Raises:
        HTTPException: 401 when no header yields a non-empty key.
    """
    bearer_token = ""
    if authorization:
        scheme, _, credentials = authorization.strip().partition(" ")
        if scheme.lower() == "bearer":
            bearer_token = credentials.strip()

    for candidate in (bearer_token, x_api_key, x_nvidia_api_key):
        token = candidate.strip() if candidate else ""
        if token:
            return token
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请通过 Authorization Bearer 或 X-API-Key 提供你的 NIM Key。")
|
|
def normalize_content(content: Any, role: str) -> list[dict[str, Any]]:
    """Convert message content of any shape into a list of content parts.

    Plain strings become ``output_text`` parts for the assistant role and
    ``input_text`` parts otherwise. In a list, dict parts of a known type are
    kept as-is, other dicts are kept only if they carry ``text``, and other
    values are stringified.
    """
    text_type = "output_text" if role == "assistant" else "input_text"
    if content is None:
        return []
    if isinstance(content, str):
        return [{"type": text_type, "text": content}]
    if isinstance(content, dict):
        if "text" in content:
            return [{"type": content.get("type", "input_text"), "text": content.get("text", "")}]
        return [{"type": "input_text", "text": json_dumps(content)}]
    if not isinstance(content, list):
        return [{"type": "input_text", "text": str(content)}]

    passthrough_types = {"input_text", "output_text", "text", "tool_call", "function_call"}
    parts: list[dict[str, Any]] = []
    for part in content:
        if isinstance(part, str):
            parts.append({"type": text_type, "text": part})
        elif not isinstance(part, dict):
            parts.append({"type": "input_text", "text": str(part)})
        elif part.get("type") in passthrough_types:
            parts.append(part)
        elif "text" in part:
            parts.append({"type": part.get("type", "input_text"), "text": part.get("text", "")})
        # Dict parts with neither a known type nor text are dropped.
    return parts
|
|
|
|
def normalize_input_items(value: Any) -> list[dict[str, Any]]:
    """Normalize an ``input`` value into a list of typed items.

    ``None`` yields an empty list and a string becomes one user message. A
    single dict, or any non-iterable scalar such as a number, is treated as a
    one-element list so it is converted rather than raising ``TypeError``.

    Each element becomes a ``message``, ``function_call`` or
    ``function_call_output`` item. Unrecognized dicts are embedded as JSON
    text in a user message.
    """

    def user_message(text: str) -> dict[str, Any]:
        return {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]}

    if value is None:
        return []
    if isinstance(value, str):
        return [user_message(value)]
    if isinstance(value, dict) or not hasattr(value, "__iter__"):
        value = [value]

    items: list[dict[str, Any]] = []
    for item in value:
        if isinstance(item, str):
            items.append(user_message(item))
            continue
        if not isinstance(item, dict):
            items.append(user_message(str(item)))
            continue
        item_type = item.get("type")
        if item_type == "message" or item.get("role"):
            role = item.get("role", "user")
            items.append({"type": "message", "role": role, "content": normalize_content(item.get("content"), role)})
            continue
        if item_type == "function_call_output":
            output = item.get("output")
            if not isinstance(output, str):
                output = json_dumps(output) if output is not None else ""
            items.append({"type": "function_call_output", "call_id": item.get("call_id"), "output": output})
            continue
        if item_type == "function_call":
            arguments = item.get("arguments", "{}")
            if not isinstance(arguments, str):
                arguments = json_dumps(arguments)
            items.append({
                "type": "function_call",
                # Generate an id when the caller did not supply one.
                "call_id": item.get("call_id") or f"call_{uuid.uuid4().hex[:12]}",
                "name": item.get("name"),
                "arguments": arguments,
            })
            continue
        if item_type in {"input_text", "output_text", "text"}:
            items.append(user_message(item.get("text", "")))
            continue
        items.append(user_message(json_dumps(item)))
    return items
|
|
|
|
def extract_text_from_content(content: Any) -> str:
    """Flatten message content into plain text.

    Lists are joined with newlines from their string parts and their
    text-typed dict parts, with empty pieces skipped. A dict yields its
    ``text`` or its JSON encoding; other values are stringified.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return str(content.get("text", "")) if "text" in content else json_dumps(content)
    if not isinstance(content, list):
        return str(content)

    text_types = {"input_text", "output_text", "text"}
    pieces: list[str] = []
    for part in content:
        if isinstance(part, str):
            piece = part
        elif isinstance(part, dict) and part.get("type") in text_types:
            piece = str(part.get("text", ""))
        else:
            continue
        if piece:
            pieces.append(piece)
    return "\n".join(pieces)
|
|
|
|
def items_to_chat_messages(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Translate normalized items into chat-completions messages.

    Consecutive ``function_call`` items are grouped into one assistant
    message with ``tool_calls``. Each ``function_call_output`` becomes a
    ``tool`` message. ``system`` and ``developer`` roles both map to
    ``system``. Items of any other type are ignored.
    """
    messages: list[dict[str, Any]] = []
    queued_calls: list[dict[str, Any]] = []

    def emit_queued_calls() -> None:
        # Flush the buffered tool calls as a single assistant message.
        if queued_calls:
            messages.append({"role": "assistant", "content": "", "tool_calls": list(queued_calls)})
            queued_calls.clear()

    for item in items:
        kind = item.get("type")
        if kind == "function_call":
            queued_calls.append(
                {
                    "id": item.get("call_id") or f"call_{uuid.uuid4().hex[:12]}",
                    "type": "function",
                    "function": {"name": item.get("name"), "arguments": item.get("arguments", "{}")},
                }
            )
        elif kind == "function_call_output":
            emit_queued_calls()
            messages.append({"role": "tool", "tool_call_id": item.get("call_id"), "content": item.get("output", "")})
        elif kind == "message":
            emit_queued_calls()
            role = item.get("role", "user")
            text = extract_text_from_content(item.get("content"))
            chat_role = "system" if role in {"system", "developer"} else role
            messages.append({"role": chat_role, "content": text})

    emit_queued_calls()
    return [entry for entry in messages if entry.get("content") is not None or entry.get("tool_calls")]
|
|
|
|
def response_tools_to_chat_tools(tools: Any) -> list[dict[str, Any]]:
    """Convert function tool definitions to the chat-completions tool shape.

    Accepts both the flat form (``name`` on the tool itself) and the nested
    form (``function: {name: ...}``). Non-function tools and tools without a
    name are skipped. A missing ``parameters`` gets an empty object schema.
    """
    chat_tools: list[dict[str, Any]] = []
    for tool in tools or []:
        if not (isinstance(tool, dict) and tool.get("type") == "function"):
            continue
        nested = tool.get("function")
        spec = nested if isinstance(nested, dict) else tool
        name = spec.get("name")
        if name:
            chat_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "description": spec.get("description"),
                        "parameters": spec.get("parameters") or {"type": "object", "properties": {}},
                    },
                }
            )
    return chat_tools
|
|
|
|
def normalize_tool_choice(tool_choice: Any, tools: list[dict[str, Any]]) -> tuple[Any, list[dict[str, Any]]]:
    """Translate a ``tool_choice`` value into chat-completions form.

    Returns ``(tool_choice, tools)``:

    * ``None`` or unsupported values -> ``(None, tools)``.
    * A string is passed through unchanged.
    * ``{"type": "function", ...}`` with a name -> a forced function choice.
    * ``{"type": "allowed_tools", ...}`` -> the mode string (default
      ``"auto"``) and *tools* filtered to the allowed names.

    Names are accepted both flat (``name``) and nested (``function.name``) in
    either branch. Malformed ``allowed_tools`` entries are ignored instead of
    raising.
    """

    def entry_name(entry: Any) -> Any:
        # Resolve a tool name from a string, a flat dict or a nested dict.
        if isinstance(entry, str):
            return entry
        if not isinstance(entry, dict):
            return None
        nested = entry.get("function")
        return entry.get("name") or (nested.get("name") if isinstance(nested, dict) else None)

    if tool_choice is None:
        return None, tools
    if isinstance(tool_choice, str):
        return tool_choice, tools
    if not isinstance(tool_choice, dict):
        return None, tools

    choice_type = tool_choice.get("type")
    if choice_type == "function":
        function_name = entry_name(tool_choice)
        if function_name:
            return {"type": "function", "function": {"name": function_name}}, tools
    if choice_type == "allowed_tools":
        allowed = tool_choice.get("tools") or []
        allowed_names = {
            name
            for name in (entry_name(entry) for entry in allowed)
            if isinstance(name, str) and name
        }
        filtered_tools = [tool for tool in tools if tool["function"]["name"] in allowed_names]
        mode = tool_choice.get("mode", "auto")
        return mode if isinstance(mode, str) else "auto", filtered_tools
    return None, tools
|
|
|
|
def build_chat_payload(body: dict[str, Any], items: list[dict[str, Any]]) -> dict[str, Any]:
    """Build the chat-completions request payload from a request body.

    Copies the model, converted messages, tools and tool choice, the sampling
    options that are set, and maps ``max_output_tokens`` to ``max_tokens``.
    ``instructions`` is prepended as a system message.

    ``text.format`` is mapped to ``response_format``. For ``json_schema`` the
    nested ``json_schema`` object is used when present; otherwise the schema
    is assembled from the flat ``name`` / ``description`` / ``schema`` /
    ``strict`` fields, so flat-style requests no longer send an empty schema.
    """
    tools = response_tools_to_chat_tools(body.get("tools"))
    tool_choice, tools = normalize_tool_choice(body.get("tool_choice"), tools)
    payload: dict[str, Any] = {"model": body.get("model"), "messages": items_to_chat_messages(items)}
    if tools:
        payload["tools"] = tools
    if tool_choice is not None:
        payload["tool_choice"] = tool_choice

    # Options forwarded only when explicitly set: (body key, payload key).
    for source_key, target_key in (
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("parallel_tool_calls", "parallel_tool_calls"),
        ("max_output_tokens", "max_tokens"),
    ):
        option_value = body.get(source_key)
        if option_value is not None:
            payload[target_key] = option_value

    if body.get("instructions"):
        payload["messages"] = [{"role": "system", "content": body["instructions"]}] + payload["messages"]

    text_config = body.get("text") or {}
    text_format = text_config.get("format") if isinstance(text_config, dict) else None
    if isinstance(text_format, dict):
        format_type = text_format.get("type")
        if format_type == "json_object":
            payload["response_format"] = {"type": "json_object"}
        elif format_type == "json_schema":
            json_schema = text_format.get("json_schema")
            if not json_schema:
                # Flat form: the schema fields sit directly on the format object.
                json_schema = {
                    key: text_format[key]
                    for key in ("name", "description", "schema", "strict")
                    if text_format.get(key) is not None
                }
            payload["response_format"] = {"type": "json_schema", "json_schema": json_schema}
    return payload
|
|
|
|
def anthropic_content_to_blocks(content: Any) -> list[dict[str, Any]]:
    """Normalize message content into a list of content-block dicts.

    Dicts are kept as-is, strings become ``text`` blocks, and any other value
    is stringified into a ``text`` block. ``None`` yields an empty list.
    """

    def as_block(part: Any) -> dict[str, Any]:
        if isinstance(part, dict):
            return part
        return {"type": "text", "text": part if isinstance(part, str) else str(part)}

    if content is None:
        return []
    if isinstance(content, list):
        return [as_block(part) for part in content]
    return [as_block(content)]
|
|
|
|
def extract_anthropic_text(value: Any) -> str:
    """Recursively extract readable text from a content value.

    Text blocks yield their text and thinking blocks their ``thinking``.
    Redacted thinking yields an empty string, and image or document blocks a
    placeholder. Other dicts recurse into ``content`` when present and are
    JSON-encoded otherwise. Lists are joined with newlines, skipping empty
    pieces.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, list):
        return "\n".join(text for text in map(extract_anthropic_text, value) if text)
    if not isinstance(value, dict):
        return str(value)

    kind = value.get("type")
    if kind in {"text", "input_text", "output_text"}:
        return str(value.get("text", ""))
    if kind == "thinking":
        return str(value.get("thinking", ""))
    if kind == "redacted_thinking":
        return ""
    if kind in {"image", "document"}:
        return f"[{kind} content omitted]"
    if "content" in value:
        return extract_anthropic_text(value.get("content"))
    return json_dumps(value)
|
|
|
|
def anthropic_result_block_to_text(block: dict[str, Any]) -> str:
    """Render a tool-result block as the text for a tool message.

    A successful ``tool_result`` returns its string content directly, or the
    extracted plain text when that is non-empty and differs from the raw JSON.
    Every other case returns a JSON envelope holding ``content`` and, when
    set on the block, ``is_error``.
    """
    content = block.get("content")
    error_flag = block.get("is_error")
    if block.get("type") == "tool_result" and not error_flag:
        if isinstance(content, str):
            return content
        plain_text = extract_anthropic_text(content)
        if plain_text and plain_text != json_dumps(content):
            return plain_text

    if error_flag is not None:
        envelope: dict[str, Any] = {"is_error": bool(error_flag), "content": content}
    else:
        envelope = {"content": content}
    return json_dumps(envelope)
|
|
|
|
def is_anthropic_tool_result_block(block: dict[str, Any]) -> bool:
    """Return True when the block's type is ``tool_result`` or ends with ``_tool_result``."""
    kind = block.get("type")
    if not isinstance(kind, str):
        return False
    return kind == "tool_result" or kind.endswith("_tool_result")
|
|
|
|
| def parse_anthropic_beta_header(value: str | None) -> set[str]: |
| if not value: |
| return set() |
| return {item.strip() for item in value.split(",") if item.strip()} |
|
|
|
|
def has_anthropic_message_prefill(messages: Any) -> bool:
    """Report whether the conversation ends with an assistant prefill.

    True when the last message is an assistant message containing at least
    one block whose type is not ``redacted_thinking``.
    """
    if not (isinstance(messages, list) and messages):
        return False
    tail = messages[-1]
    if not isinstance(tail, dict) or tail.get("role") != "assistant":
        return False
    for block in anthropic_content_to_blocks(tail.get("content")):
        if block.get("type") != "redacted_thinking":
            return True
    return False
|
|
|
|
def is_forced_anthropic_tool_choice(tool_choice: Any) -> bool:
    """Return True when *tool_choice* forces the model to call a tool.

    Forced values are the strings ``"any"`` and ``"required"``, and dicts
    whose ``type`` is ``"any"`` or ``"tool"``.
    """
    if isinstance(tool_choice, str):
        return tool_choice in {"any", "required"}
    if isinstance(tool_choice, dict):
        return tool_choice.get("type") in {"any", "tool"}
    return False
|
|
|
|
| def build_anthropic_thinking_config(body: dict[str, Any], anthropic_beta: str | None) -> dict[str, Any]: |
| thinking = body.get("thinking") |
| enabled = False |
| budget_tokens = None |
|
|
| if thinking is None: |
| return {"enabled": False, "budget_tokens": None, "synthetic_signature": True} |
| if isinstance(thinking, bool): |
| thinking = {"type": "enabled" if thinking else "disabled"} |
| if not isinstance(thinking, dict): |
| raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="thinking 字段必须是对象或布尔值。") |
|
|
| raw_thinking_type = thinking.get("type") |
| if isinstance(raw_thinking_type, bool): |
| thinking_type = "enabled" if raw_thinking_type else "disabled" |
| elif isinstance(raw_thinking_type, str): |
| lowered_type = raw_thinking_type.strip().lower() |
| if lowered_type in {"enabled", "enable", "on", "true"}: |
| thinking_type = "enabled" |
| elif lowered_type in {"disabled", "disable", "off", "false"}: |
| thinking_type = "disabled" |
| else: |
| thinking_type = lowered_type |
| elif "enabled" in thinking: |
| thinking_type = "enabled" if bool(thinking.get("enabled")) else "disabled" |
| elif any(key in thinking for key in ("budget_tokens", "budgetTokens")): |
| thinking_type = "enabled" |
| else: |
| raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="thinking.type 仅支持 enabled 或 disabled。") |
|
|
| if thinking_type == "disabled": |
| return {"enabled": False, "budget_tokens": None, "synthetic_signature": True} |
| if thinking_type != "enabled": |
| raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="thinking.type 仅支持 enabled 或 disabled。") |
|
|
| budget_tokens = thinking.get("budget_tokens", thinking.get("budgetTokens")) |
| if isinstance(budget_tokens, str): |
| try: |
| budget_tokens = int(budget_tokens.strip()) |
| except ValueError as exc: |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="thinking.enabled 时必须提供整数类型的 budget_tokens。", |
| ) from exc |
| if not isinstance(budget_tokens, int): |
| raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="thinking.enabled 时必须提供整数类型的 budget_tokens。") |
| if budget_tokens < ANTHROPIC_MIN_THINKING_BUDGET_TOKENS: |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail=f"thinking.budget_tokens 不能小于 {ANTHROPIC_MIN_THINKING_BUDGET_TOKENS}。", |
| ) |
|
|
| max_tokens = body.get("max_tokens") |
| beta_flags = parse_anthropic_beta_header(anthropic_beta) |
| interleaved_thinking = ( |
| ANTHROPIC_INTERLEAVED_THINKING_BETA in beta_flags |
| and bool(body.get("tools")) |
| ) |
| if isinstance(max_tokens, int) and budget_tokens >= max_tokens and not interleaved_thinking: |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="thinking.budget_tokens 必须小于 max_tokens;只有启用 interleaved thinking beta 且使用工具时可以例外。", |
| ) |
|
|
| if body.get("temperature") is not None: |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="Anthropic thinking 模式不支持自定义 temperature。", |
| ) |
|
|
| top_p = body.get("top_p") |
| if top_p is not None: |
| try: |
| top_p_value = float(top_p) |
| except (TypeError, ValueError) as exc: |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="top_p 必须是数值。", |
| ) from exc |
| if not (0.95 <= top_p_value <= 1.0): |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="Anthropic thinking 模式下 top_p 只能在 0.95 到 1.0 之间。", |
| ) |
|
|
| if is_forced_anthropic_tool_choice(body.get("tool_choice")): |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="Anthropic thinking 模式不支持 forced tool choice。", |
| ) |
|
|
| if has_anthropic_message_prefill(body.get("messages")): |
| raise HTTPException( |
| status_code=status.HTTP_400_BAD_REQUEST, |
| detail="Anthropic thinking 模式不支持 assistant prefill。", |
| ) |
|
|
| enabled = True |
| return { |
| "enabled": enabled, |
| "budget_tokens": budget_tokens, |
| "interleaved": interleaved_thinking, |
| "synthetic_signature": True, |
| } |
|
|
|
|
| def build_synthetic_thinking_signature(model_id: str | None, thinking_text: str) -> str: |
| digest = hashlib.sha256(f"{model_id or 'unknown'}\n{thinking_text}".encode("utf-8")).digest() |
| encoded = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") |
| return f"nimthinking_{encoded}" |
|
|
|
|
| def split_anthropic_thinking_blocks(text: str, model_id: str | None) -> list[dict[str, Any]]: |
| if not text: |
| return [] |
|
|
| blocks: list[dict[str, Any]] = [] |
| cursor = 0 |
| for match in THINK_TAG_PATTERN.finditer(text): |
| before = text[cursor:match.start()] |
| if before.strip(): |
| blocks.append({"type": "text", "text": before.strip()}) |
|
|
| thinking_text = match.group(1).strip() |
| if thinking_text: |
| blocks.append( |
| { |
| "type": "thinking", |
| "thinking": thinking_text, |
| "signature": build_synthetic_thinking_signature(model_id, thinking_text), |
| } |
| ) |
| cursor = match.end() |
|
|
| if cursor == 0: |
| return [{"type": "text", "text": text.strip()}] if text.strip() else [] |
|
|
| after = text[cursor:] |
| if after.strip(): |
| blocks.append({"type": "text", "text": after.strip()}) |
| return blocks |
|
|
|
|
def build_bash_tool_schema() -> dict[str, Any]:
    """Return the JSON schema for the bash tool's input (``command`` and ``restart``)."""
    properties: dict[str, Any] = {}
    properties["command"] = {
        "type": "string",
        "description": "The shell command to execute in the persistent bash session.",
    }
    properties["restart"] = {
        "type": "boolean",
        "description": "Restart the persistent bash session before running the next command.",
    }
    return {"type": "object", "properties": properties}
|
|
|
|
| def build_text_editor_tool_schema(tool_type: str | None) -> dict[str, Any]: |
| commands = ["view", "create", "str_replace", "insert"] |
| if tool_type and (tool_type.endswith("20241022") or tool_type.endswith("20250124")): |
| commands.append("undo_edit") |
| return { |
| "type": "object", |
| "properties": { |
| "command": { |
| "type": "string", |
| "enum": commands, |
| "description": "The editor operation to perform.", |
| }, |
| "path": {"type": "string", "description": "Absolute or relative path to the target file."}, |
| "view_range": { |
| "type": "array", |
| "items": {"type": "integer"}, |
| "minItems": 2, |
| "maxItems": 2, |
| "description": "Inclusive start/end line numbers for view operations.", |
| }, |
| "file_text": {"type": "string", "description": "Full file contents when creating a file."}, |
| "old_str": {"type": "string", "description": "Existing text to replace."}, |
| "new_str": {"type": "string", "description": "Replacement text for str_replace."}, |
| "insert_line": {"type": "integer", "description": "Line number to insert text before."}, |
| "insert_text": {"type": "string", "description": "Text to insert."}, |
| }, |
| } |
|
|
|
|
def build_memory_tool_schema() -> dict[str, Any]:
    """Return the JSON schema for the memory tool's input."""

    def string_field(description: str) -> dict[str, Any]:
        return {"type": "string", "description": description}

    properties: dict[str, Any] = {
        "command": {
            "type": "string",
            "enum": ["view", "create", "str_replace", "insert", "delete", "rename"],
            "description": "The memory operation to perform under the memory directory.",
        },
        "path": string_field("Path to the memory file."),
        "new_path": string_field("New path when renaming a memory file."),
        "view_range": {
            "type": "array",
            "items": {"type": "integer"},
            "minItems": 2,
            "maxItems": 2,
            "description": "Inclusive start/end line numbers for view operations.",
        },
        "file_text": string_field("Full file contents when creating a memory file."),
        "old_str": string_field("Existing text to replace."),
        "new_str": string_field("Replacement text for str_replace."),
        "insert_line": {"type": "integer", "description": "Line number to insert text before."},
        "insert_text": string_field("Text to insert."),
    }
    return {"type": "object", "properties": properties}
|
|
|
|
| def build_computer_tool_schema(tool_type: str | None) -> dict[str, Any]: |
| actions = [ |
| "screenshot", |
| "left_click", |
| "right_click", |
| "middle_click", |
| "double_click", |
| "triple_click", |
| "mouse_move", |
| "left_click_drag", |
| "left_mouse_down", |
| "left_mouse_up", |
| "scroll", |
| "type", |
| "key", |
| "hold_key", |
| "wait", |
| ] |
| if tool_type and tool_type.endswith("20251124"): |
| actions.append("zoom") |
| return { |
| "type": "object", |
| "properties": { |
| "action": { |
| "type": "string", |
| "enum": actions, |
| "description": "The computer action to perform.", |
| }, |
| "coordinate": { |
| "type": "array", |
| "items": {"type": "integer"}, |
| "minItems": 2, |
| "maxItems": 2, |
| "description": "X/Y coordinate for click and move actions.", |
| }, |
| "start_coordinate": { |
| "type": "array", |
| "items": {"type": "integer"}, |
| "minItems": 2, |
| "maxItems": 2, |
| "description": "Start coordinate for drag actions.", |
| }, |
| "end_coordinate": { |
| "type": "array", |
| "items": {"type": "integer"}, |
| "minItems": 2, |
| "maxItems": 2, |
| "description": "End coordinate for drag actions.", |
| }, |
| "text": {"type": "string", "description": "Text to type or zoom target text."}, |
| "key": {"type": "string", "description": "Keyboard key or key chord to press."}, |
| "duration": {"type": "number", "description": "Optional wait duration in seconds."}, |
| "scroll_direction": { |
| "type": "string", |
| "enum": ["up", "down", "left", "right"], |
| "description": "Scroll direction.", |
| }, |
| "scroll_amount": {"type": "integer", "description": "Scroll distance in pixels or wheel units."}, |
| "region": { |
| "type": "array", |
| "items": {"type": "integer"}, |
| "minItems": 4, |
| "maxItems": 4, |
| "description": "Optional region [left, top, width, height] for screenshots.", |
| }, |
| "modifiers": { |
| "type": "array", |
| "items": {"type": "string"}, |
| "description": "Modifier keys to hold during the action.", |
| }, |
| }, |
| } |
|
|
|
|
def build_web_search_tool_schema() -> dict[str, Any]:
    """Return the JSON schema for the web search tool's input (a required ``query``)."""
    query_field = {
        "type": "string",
        "description": "The web search query to execute.",
    }
    return {
        "type": "object",
        "properties": {"query": query_field},
        "required": ["query"],
    }
|
|
|
|
def normalize_domain_name(value: str) -> str:
    """Reduce a URL or bare domain to a lowercase host name for domain matching.

    ``https://`` is assumed when *value* carries no scheme. Userinfo, the port,
    a leading ``www.`` and any path are dropped, so ``https://www.Example.com:8443/x``
    becomes ``example.com``. Returns an empty string when no host can be found.
    """
    parsed = urlparse(value if "://" in value else f"https://{value}")
    # ``hostname`` strips ``user:pass@`` and ``:port``; ``netloc`` keeps both, which
    # made hosts such as ``example.com:443`` fail every allow/block list comparison.
    host = (parsed.hostname or parsed.path or "").strip().lower()
    if host.startswith("www."):
        host = host[4:]
    # When the host came from the path fallback it may still carry a path suffix.
    return host.split("/", 1)[0]
|
|
|
|
def normalize_domain_list(values: Any) -> list[str]:
    """Normalize a list of domain entries, dropping blanks.

    Anything that is not a ``list`` yields an empty list. Each entry is
    stringified, skipped when blank, passed through ``normalize_domain_name``
    and kept only if the normalized result is non-empty.
    """
    if not isinstance(values, list):
        return []
    domains: list[str] = []
    for entry in values:
        if not str(entry).strip():
            continue
        domain = normalize_domain_name(str(entry))
        if domain:
            domains.append(domain)
    return domains
|
|
|
|
def domain_matches(host: str, domain: str) -> bool:
    """Return True when *host* equals *domain* or ends with ``.`` + *domain*."""
    if host == domain:
        return True
    suffix = f".{domain}"
    return host.endswith(suffix)
|
|
|
|
def filter_web_search_url(url: str, allowed_domains: list[str], blocked_domains: list[str]) -> bool:
    """Decide whether a search result URL passes the domain filters.

    Returns False when the host cannot be extracted, when an allow list is
    given and the host matches none of its entries, or when the host matches
    an entry of the block list. Returns True otherwise.
    """
    try:
        host = normalize_domain_name(url)
    except Exception:
        # Any failure while extracting the host rejects the result.
        return False
    if not host:
        return False

    def listed(domains: list[str]) -> bool:
        return any(domain_matches(host, candidate) for candidate in domains)

    if allowed_domains and not listed(allowed_domains):
        return False
    return not (blocked_domains and listed(blocked_domains))
|
|
|
|
def build_web_search_tool_description(raw_tool: dict[str, Any]) -> str:
    """Build the description of the web search tool.

    Starts from the tool's own ``description`` (or a default sentence) and
    appends one paragraph per non-empty domain restriction: first the allowed
    domains, then the blocked domains.
    """
    text = raw_tool.get("description") or "Search the public web and return concise result metadata."
    restrictions = (
        ("Only search and return results from these domains: ", normalize_domain_list(raw_tool.get("allowed_domains"))),
        ("Do not return results from these domains: ", normalize_domain_list(raw_tool.get("blocked_domains"))),
    )
    for lead_in, domains in restrictions:
        if domains:
            text = f"{text}\n\n{lead_in}{', '.join(domains)}."
    return text
|
|
|
|
def build_server_tool_result_error_block(
    result_type: str,
    tool_use_id: str,
    error_code: str,
    message: str | None = None,
) -> dict[str, Any]:
    """Build an error result block for a server-side tool call.

    The block has type *result_type* and carries a nested error item whose type
    is ``<result_type>_error``. ``message`` is included only when non-empty.
    """
    error_details: dict[str, Any] = {
        "type": f"{result_type}_error",
        "error_code": error_code,
        **({"message": message} if message else {}),
    }
    return dict(
        type=result_type,
        tool_use_id=tool_use_id,
        content=error_details,
        is_error=True,
    )
|
|
|
|
def build_web_search_encrypted_content(payload: dict[str, Any]) -> str:
    """Encode *payload* as an opaque ``nimsearch_`` token.

    The payload is serialized with ``json_dumps``, URL-safe base64 encoded and
    stripped of ``=`` padding. Despite the name, no encryption is applied here.
    """
    token = base64.urlsafe_b64encode(json_dumps(payload).encode("utf-8"))
    unpadded = token.decode("ascii").rstrip("=")
    return f"nimsearch_{unpadded}"
|
|
|
|
def parse_rss_results(xml_text: str) -> list[dict[str, Any]]:
    """Extract search results from an RSS document.

    Every ``item`` element below the root becomes a dict with ``title``,
    ``url``, ``snippet`` and ``page_age`` (``None`` when ``pubDate`` is blank).
    Items without a title or link are skipped. Unparseable XML yields ``[]``.
    """
    # NOTE(review): the standard-library parser is not hardened against
    # maliciously constructed XML; confirm the feed source is trusted.
    try:
        root = ET.fromstring(xml_text)
    except ET.ParseError:
        return []

    def field(node: ET.Element, tag: str) -> str:
        return (node.findtext(tag) or "").strip()

    entries: list[dict[str, Any]] = []
    for node in root.findall(".//item"):
        title, link = field(node, "title"), field(node, "link")
        if title and link:
            entries.append(
                {
                    "title": title,
                    "url": link,
                    "snippet": field(node, "description"),
                    "page_age": field(node, "pubDate") or None,
                }
            )
    return entries
|
|
|
|
def append_anthropic_tool_examples(description: str | None, examples: Any) -> str | None:
    """Append up to two serialized input examples to a tool description.

    *description* is returned unchanged when *examples* is not a non-empty
    list. Otherwise the first two examples are serialized with ``json_dumps``
    and appended as a separate paragraph, or returned alone when there is no
    description.
    """
    if not (isinstance(examples, list) and examples):
        return description
    rendered = f"Input examples: {json_dumps(examples[:2])}"
    return f"{description}\n\n{rendered}" if description else rendered
|
|
|
|
def normalize_anthropic_tool_name(tool_type: str | None, fallback_name: str | None) -> str | None:
    """Pick the function name for an Anthropic tool definition.

    An explicit *fallback_name* always wins. Otherwise the name is derived from
    the versioned tool type prefix (``bash_``, ``text_editor_``, ``computer_``,
    ``memory_``). Returns ``None`` when no name can be determined.
    """
    if fallback_name:
        return fallback_name
    # The type comes straight from request JSON, so it may be a number, list or
    # dict; treat anything that is not a string as "no usable type" instead of
    # failing on ``.startswith``.
    if not isinstance(tool_type, str):
        return None
    prefix_to_name = (
        ("bash_", "bash"),
        ("text_editor_", "str_replace_based_edit_tool"),
        ("computer_", "computer"),
        ("memory_", "memory"),
    )
    for prefix, name in prefix_to_name:
        if tool_type.startswith(prefix):
            return name
    return None
|
|
|
|
def anthropic_tools_to_chat_tools(tools: Any) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
    """Convert Anthropic tool definitions into chat-completions function tools.

    Returns ``(chat_tools, metadata_by_name)``: the function tools in input
    order, and per-tool metadata keyed by tool name. ``web_search_*`` tools are
    marked with ``server_execution == "web_search"`` and a ``uses`` counter set
    to 0.

    Raises:
        HTTPException: 400 when a tool does not allow the ``direct`` caller, is
            an ``mcp_toolset``, sets both domain lists on a web search tool, is
            another server tool type, or has a type but no ``input_schema``.
    """
    chat_tools: list[dict[str, Any]] = []
    metadata_by_name: dict[str, dict[str, Any]] = {}

    # (type prefix, default description, schema builder) for the client tools
    # whose schemas are built by this gateway.
    builtin_tools = (
        (
            "bash_",
            "Execute shell commands in a persistent bash session.",
            lambda _tool_type: build_bash_tool_schema(),
        ),
        (
            "text_editor_",
            "View and edit text files with command-based operations.",
            lambda tool_type: build_text_editor_tool_schema(tool_type),
        ),
        (
            "computer_",
            "Interact with a computer UI using screenshots, clicks, typing, keys, scrolling, and drag actions.",
            lambda tool_type: build_computer_tool_schema(tool_type),
        ),
        (
            "memory_",
            "Read and edit persistent memory files with command-based operations.",
            lambda _tool_type: build_memory_tool_schema(),
        ),
    )

    def bad_request(detail: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

    def function_tool(name: Any, description: Any, parameters: Any) -> dict[str, Any]:
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": parameters,
            },
        }

    for raw_tool in tools or []:
        if not isinstance(raw_tool, dict):
            continue

        tool_type = raw_tool.get("type")
        has_type_string = isinstance(tool_type, str)
        tool_name = normalize_anthropic_tool_name(tool_type, raw_tool.get("name"))

        callers = raw_tool.get("allowed_callers") or ["direct"]
        if isinstance(callers, (list, tuple, set)):
            caller_names = {str(caller) for caller in callers if caller}
        else:
            caller_names = {str(callers)}

        if "direct" not in caller_names:
            raise bad_request(
                f"当前网关暂不支持仅允许 programmatic caller 的工具:{tool_name or tool_type or 'unknown'}。"
            )

        if tool_type == "mcp_toolset":
            raise bad_request(
                f"当前网关暂不支持 Anthropic 服务端工具 '{tool_type}';请改用客户端工具或自定义 tools。"
            )

        if has_type_string and tool_type.startswith("web_search_"):
            tool_name = tool_name or "web_search"
            allowed_domains = normalize_domain_list(raw_tool.get("allowed_domains"))
            blocked_domains = normalize_domain_list(raw_tool.get("blocked_domains"))
            if allowed_domains and blocked_domains:
                raise bad_request("web_search 工具不能同时设置 allowed_domains 和 blocked_domains。")

            chat_tools.append(
                function_tool(
                    tool_name,
                    build_web_search_tool_description(raw_tool),
                    build_web_search_tool_schema(),
                )
            )
            user_location = raw_tool.get("user_location")
            max_uses = raw_tool.get("max_uses")
            metadata_by_name[tool_name] = {
                "anthropic_type": tool_type,
                "allowed_callers": sorted(caller_names) or ["direct"],
                "server_execution": "web_search",
                "allowed_domains": allowed_domains,
                "blocked_domains": blocked_domains,
                "user_location": user_location if isinstance(user_location, dict) else None,
                "max_uses": max_uses if isinstance(max_uses, int) and max_uses > 0 else None,
                "uses": 0,
            }
            continue

        if has_type_string and tool_type.startswith(ANTHROPIC_SERVER_TOOL_PREFIXES):
            raise bad_request(
                f"当前网关暂不支持 Anthropic 服务端工具 '{tool_type}';目前已兼容 web_search 工具,其他服务端工具仍需单独适配。"
            )

        builtin = None
        if has_type_string:
            builtin = next((entry for entry in builtin_tools if tool_type.startswith(entry[0])), None)

        if builtin is not None:
            _prefix, default_description, build_schema = builtin
            description = raw_tool.get("description") or default_description
            parameters = build_schema(tool_type)
        else:
            # Custom tool: needs a name, and a schema whenever a type is given.
            if not tool_name:
                continue
            if tool_type and raw_tool.get("input_schema") is None:
                raise bad_request(f"当前网关暂不支持 Anthropic 工具类型 '{tool_type}'。")
            description = raw_tool.get("description")
            parameters = raw_tool.get("input_schema") or {"type": "object", "properties": {}}

        description = append_anthropic_tool_examples(description, raw_tool.get("input_examples"))
        chat_tools.append(function_tool(tool_name, description, parameters))
        metadata_by_name[tool_name] = {
            "anthropic_type": tool_type or "custom",
            "allowed_callers": sorted(caller_names) or ["direct"],
        }

    return chat_tools, metadata_by_name
|
|
|
|
def normalize_anthropic_tool_choice(tool_choice: Any) -> tuple[Any, bool | None]:
    """Translate an Anthropic ``tool_choice`` into chat-completions form.

    Returns ``(choice, parallel_tool_calls)``. String input maps ``"any"`` to
    ``"required"`` and passes other strings through. Dict input maps
    ``auto``/``none`` to themselves, ``any`` to ``"required"`` and a named
    ``tool`` to a function selector; anything else yields ``None``.
    ``parallel_tool_calls`` is the negation of ``disable_parallel_tool_use``
    when that key is set on a dict, otherwise ``None``.
    """
    if isinstance(tool_choice, str):
        return ("required" if tool_choice == "any" else tool_choice), None
    if not isinstance(tool_choice, dict):
        # Covers None as well as unsupported value types.
        return None, None

    disable_parallel = tool_choice.get("disable_parallel_tool_use")
    parallel_tool_calls = None if disable_parallel is None else not bool(disable_parallel)

    kind = tool_choice.get("type")
    choice: Any = None
    if kind in {"auto", "none"}:
        choice = kind
    elif kind == "any":
        choice = "required"
    elif kind == "tool" and (selected_name := tool_choice.get("name")):
        choice = {"type": "function", "function": {"name": selected_name}}
    return choice, parallel_tool_calls
|
|
|
|
def anthropic_messages_to_chat_messages(body: dict[str, Any]) -> list[dict[str, Any]]:
    """Convert an Anthropic request body into chat-completions messages.

    The ``system`` field becomes a leading system message when it has text.
    Assistant messages are flattened into one message: text blocks joined by
    newlines, thinking blocks wrapped in ``<think>`` tags, redacted thinking
    dropped and tool-use blocks turned into ``tool_calls``. For every other
    role, tool-result blocks become ``tool`` messages and the text around them
    is emitted as ``system`` (for ``system``/``developer`` roles) or ``user``
    messages, preserving block order.
    """
    converted: list[dict[str, Any]] = []

    def emit_text(role: Any, chunks: list[str]) -> None:
        # Join buffered text and emit it as a single message, if non-empty.
        joined = "\n".join(chunk for chunk in chunks if chunk)
        if joined:
            text_role = "system" if role in {"system", "developer"} else "user"
            converted.append({"role": text_role, "content": joined})

    system_prompt = extract_anthropic_text(body.get("system"))
    if system_prompt:
        converted.append({"role": "system", "content": system_prompt})

    for entry in body.get("messages") or []:
        if not isinstance(entry, dict):
            # Bare strings and other non-dict entries are treated as user text.
            content = entry if isinstance(entry, str) else str(entry)
            converted.append({"role": "user", "content": content})
            continue

        role = entry.get("role", "user")
        blocks = anthropic_content_to_blocks(entry.get("content"))

        if role == "assistant":
            texts: list[str] = []
            calls: list[dict[str, Any]] = []
            for block in blocks:
                kind = block.get("type")
                if kind == "text":
                    texts.append(str(block.get("text", "")))
                elif kind == "thinking":
                    thought = str(block.get("thinking", "")).strip()
                    if thought:
                        texts.append(f"<think>\n{thought}\n</think>")
                elif kind == "redacted_thinking":
                    continue
                elif kind in {"tool_use", "server_tool_use"}:
                    call_arguments = block.get("input")
                    if not isinstance(call_arguments, str):
                        call_arguments = json_dumps(call_arguments or {})
                    calls.append(
                        {
                            "id": block.get("id") or f"toolu_{uuid.uuid4().hex[:24]}",
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": call_arguments,
                            },
                        }
                    )
                else:
                    fallback_text = extract_anthropic_text(block)
                    if fallback_text:
                        texts.append(fallback_text)

            if texts or calls:
                assistant_message: dict[str, Any] = {
                    "role": "assistant",
                    "content": "\n".join(filter(None, texts)),
                }
                if calls:
                    assistant_message["tool_calls"] = calls
                converted.append(assistant_message)
            continue

        buffered: list[str] = []
        for block in blocks:
            if not is_anthropic_tool_result_block(block):
                block_text = extract_anthropic_text(block)
                if block_text:
                    buffered.append(block_text)
                continue

            # Flush text first so it stays ahead of the tool message.
            emit_text(role, buffered)
            buffered = []
            tool_use_id = block.get("tool_use_id") or block.get("id")
            result_text = anthropic_result_block_to_text(block)
            if tool_use_id:
                converted.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_use_id,
                        "content": result_text,
                    }
                )
            elif result_text:
                # A result without an id cannot be a tool message; keep it as text.
                buffered.append(result_text)
        emit_text(role, buffered)

    return [
        message
        for message in converted
        if message.get("content") is not None or message.get("tool_calls")
    ]
|
|
|
|
def build_anthropic_chat_payload(
    body: dict[str, Any],
    anthropic_beta: str | None = None,
) -> tuple[dict[str, Any], list[dict[str, Any]], dict[str, dict[str, Any]], dict[str, Any]]:
    """Build the upstream chat-completions payload for an Anthropic request.

    Returns ``(payload, messages, tool_metadata, thinking_config)``. Thinking
    is switched on via ``chat_template_kwargs`` and ``nvext`` when the thinking
    config is enabled, and explicitly switched off when the request carries a
    ``thinking`` field that is not enabled. Tool and sampling fields are copied
    only when present.

    Raises:
        HTTPException: 400 when the request sets ``mcp_servers``.
    """
    thinking_config = build_anthropic_thinking_config(body, anthropic_beta)
    if body.get("mcp_servers"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="当前网关暂不支持 Anthropic mcp_servers 直连能力。",
        )

    messages = anthropic_messages_to_chat_messages(body)
    tools, tool_metadata = anthropic_tools_to_chat_tools(body.get("tools"))
    tool_choice, parallel_tool_calls = normalize_anthropic_tool_choice(body.get("tool_choice"))

    payload: dict[str, Any] = {
        "model": body.get("model"),
        "messages": messages,
        "max_tokens": body.get("max_tokens"),
    }

    if thinking_config["enabled"]:
        payload.update(
            chat_template_kwargs={"enable_thinking": True},
            nvext={"max_thinking_tokens": thinking_config["budget_tokens"]},
        )
    elif body.get("thinking") is not None:
        payload["chat_template_kwargs"] = {"enable_thinking": False}

    if tools:
        payload["tools"] = tools
    if tool_choice is not None:
        payload["tool_choice"] = tool_choice
    if tools and parallel_tool_calls is not None:
        payload["parallel_tool_calls"] = parallel_tool_calls

    for sampling_field in ("temperature", "top_p"):
        sampling_value = body.get(sampling_field)
        if sampling_value is not None:
            payload[sampling_field] = sampling_value

    stop_sequences = body.get("stop_sequences")
    if stop_sequences:
        payload["stop"] = stop_sequences
    return payload, messages, tool_metadata, thinking_config
|
|
|
|
def parse_anthropic_tool_input(arguments: Any) -> dict[str, Any]:
    """Coerce upstream tool-call arguments into a dict for a tool-use ``input``.

    Dicts are returned as-is and ``None`` becomes ``{}``. Strings are parsed as
    JSON: a JSON object is returned directly, any other JSON value is wrapped as
    ``{"value": ...}`` and unparseable text as ``{"raw_input": ...}``. Blank
    strings mean "no arguments" and yield ``{}``. Any other type is wrapped as
    ``{"value": ...}``.
    """
    if isinstance(arguments, dict):
        return arguments
    if arguments is None:
        return {}
    if not isinstance(arguments, str):
        return {"value": arguments}
    # A blank argument string is a call without arguments, not malformed input.
    if not arguments.strip():
        return {}
    try:
        parsed = json.loads(arguments)
    except (ValueError, RecursionError):
        # JSONDecodeError is a ValueError; RecursionError covers extreme nesting.
        return {"raw_input": arguments}
    return parsed if isinstance(parsed, dict) else {"value": parsed}
|
|
|
|
def normalize_anthropic_message_id(message_id: Any) -> str:
    """Return *message_id* if it is a ``msg_``-prefixed string, else a new random id.

    Generated ids are ``msg_`` followed by 24 hex characters of a UUID4.
    """
    already_valid = isinstance(message_id, str) and message_id.startswith("msg_")
    if already_valid:
        return message_id
    random_suffix = uuid.uuid4().hex[:24]
    return f"msg_{random_suffix}"
|
|
|
|
def normalize_anthropic_tool_use_id(tool_use_id: Any) -> str:
    """Return *tool_use_id* if it is already ``toolu_``/``srvtoolu_`` prefixed.

    Any other value is replaced by ``toolu_`` followed by 24 hex characters of
    a UUID4.
    """
    accepted_prefixes = ("toolu_", "srvtoolu_")
    if not isinstance(tool_use_id, str) or not tool_use_id.startswith(accepted_prefixes):
        return f"toolu_{uuid.uuid4().hex[:24]}"
    return tool_use_id
|
|
|
|
def anthropic_block_to_chat_tool_call(block: dict[str, Any]) -> dict[str, Any]:
    """Convert a tool-use block into a chat-completions ``tool_calls`` entry.

    A non-string ``input`` is serialized with ``json_dumps`` (missing or empty
    input becomes ``{}``). A missing ``id`` is replaced by a generated one.
    """
    raw_input = block.get("input") or {}
    serialized_input = raw_input if isinstance(raw_input, str) else json_dumps(raw_input)
    call_id = block.get("id") or normalize_anthropic_tool_use_id(None)
    function_part = {"name": block.get("name"), "arguments": serialized_input}
    return {"id": call_id, "type": "function", "function": function_part}
|
|
|
|
def anthropic_blocks_to_chat_assistant_message(blocks: list[dict[str, Any]]) -> dict[str, Any]:
    """Collapse content blocks into one chat-completions assistant message.

    Non-blank text blocks and ``<think>``-wrapped thinking blocks are joined
    with newlines into ``content``. ``tool_use`` and ``server_tool_use`` blocks
    become ``tool_calls``. Other block types are ignored.
    """
    parts: list[str] = []
    calls: list[dict[str, Any]] = []

    for block in blocks:
        kind = block.get("type")
        if kind == "text":
            stripped_text = str(block.get("text", "")).strip()
            if stripped_text:
                parts.append(stripped_text)
        elif kind == "thinking":
            stripped_thought = str(block.get("thinking", "")).strip()
            if stripped_thought:
                parts.append(f"<think>\n{stripped_thought}\n</think>")
        elif kind in {"tool_use", "server_tool_use"}:
            calls.append(anthropic_block_to_chat_tool_call(block))

    # Every collected part is non-empty, so a plain join is enough.
    assistant_message: dict[str, Any] = {"role": "assistant", "content": "\n".join(parts)}
    if calls:
        assistant_message["tool_calls"] = calls
    return assistant_message
|
|
|
|
def is_server_tool_block(block: dict[str, Any], tool_metadata: dict[str, dict[str, Any]]) -> bool:
    """Return True for a ``server_tool_use`` block whose tool runs on this server.

    The tool must have a name and its metadata must carry a truthy
    ``server_execution`` entry.
    """
    name = block.get("name")
    if block.get("type") != "server_tool_use" or not name:
        return False
    metadata = tool_metadata.get(name, {})
    return bool(metadata.get("server_execution"))
|
|
|
|
def count_server_tool_blocks(content_blocks: list[dict[str, Any]]) -> dict[str, int]:
    """Count ``server_tool_use`` blocks per tool name.

    Blocks without a name are counted under ``"unknown"``. Keys appear in
    first-seen order.
    """
    names = [
        str(block.get("name") or "unknown")
        for block in content_blocks
        if block.get("type") == "server_tool_use"
    ]
    tally: dict[str, int] = {}
    for name in names:
        tally[name] = tally.get(name, 0) + 1
    return tally
|
|
|
|
def merge_anthropic_usage(base_usage: dict[str, Any], extra_usage: dict[str, Any]) -> dict[str, Any]:
    """Add two usage dicts together into a new dict.

    ``input_tokens`` and ``output_tokens`` are summed, with missing or ``None``
    values counted as 0. ``server_tool_use`` counters are summed per key and
    included only when at least one counter exists. Neither input is modified;
    keys other than these three are not carried over.
    """
    base = base_usage or {}
    extra = extra_usage or {}

    merged: dict[str, Any] = {
        token_field: (base.get(token_field) or 0) + (extra.get(token_field) or 0)
        for token_field in ("input_tokens", "output_tokens")
    }

    # Copy so the caller's nested counter dict is left untouched.
    server_counts = dict(base.get("server_tool_use") or {})
    for counter_name, increment in (extra.get("server_tool_use") or {}).items():
        server_counts[counter_name] = (server_counts.get(counter_name) or 0) + (increment or 0)
    if server_counts:
        merged["server_tool_use"] = server_counts
    return merged
|
|
|
|
async def execute_web_search_tool(block: dict[str, Any], metadata: dict[str, Any]) -> tuple[dict[str, Any], str, dict[str, int]]:
    """Run one web search for a ``server_tool_use`` block.

    Returns ``(result_block, model_payload, usage)``: the result block for the
    response, the JSON text fed back to the model as tool output, and a
    ``web_search_requests`` usage counter.

    A missing query, an over-long query, an exhausted ``max_uses`` budget and
    HTTP failures all return an error result block instead of raising.
    ``metadata["uses"]`` is incremented once the validation checks pass.
    """
    tool_use_id = block.get("id") or normalize_anthropic_tool_use_id(None)
    raw_input = block.get("input")
    tool_input = raw_input if isinstance(raw_input, dict) else {}
    query = str(
        tool_input.get("query")
        or tool_input.get("q")
        or tool_input.get("search_query")
        or ""
    ).strip()

    def failure(
        error_code: str,
        message: str,
        model_error: dict[str, Any],
        request_count: int,
    ) -> tuple[dict[str, Any], str, dict[str, int]]:
        # Shared shape of every error return from this function.
        error_block = build_server_tool_result_error_block(
            "web_search_tool_result",
            tool_use_id,
            error_code,
            message,
        )
        return error_block, json_dumps(model_error), {"web_search_requests": request_count}

    if not query:
        return failure("invalid_input", "web_search 需要 query 字段。", {"error": "missing_query"}, 1)

    if len(query) > WEB_SEARCH_MAX_QUERY_LENGTH:
        return failure(
            "query_too_long",
            f"query 长度不能超过 {WEB_SEARCH_MAX_QUERY_LENGTH} 个字符。",
            {"error": "query_too_long", "query": query},
            1,
        )

    max_uses = metadata.get("max_uses")
    if isinstance(max_uses, int) and metadata.get("uses", 0) >= max_uses:
        return failure(
            "max_uses_exceeded",
            "web_search 已达到当前请求允许的最大调用次数。",
            {"error": "max_uses_exceeded", "query": query},
            0,
        )

    metadata["uses"] = metadata.get("uses", 0) + 1

    # Append the configured location fields to the query text sent upstream.
    search_query = query
    location = metadata.get("user_location")
    if isinstance(location, dict) and location:
        location_parts = [
            str(location.get(part_key)).strip()
            for part_key in ("city", "region", "country")
            if location.get(part_key)
        ]
        if location_parts:
            search_query = f"{query} {' '.join(location_parts)}"

    client = await get_http_client()
    try:
        response = await client.get(
            WEB_SEARCH_RSS_URL,
            params={"format": "rss", "q": search_query},
            headers={
                "Accept": "application/rss+xml, application/xml;q=0.9, text/xml;q=0.8",
                "User-Agent": "nim4cc/1.0 (+https://github.com/Geek66666/nim4cc)",
            },
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        return failure(
            "search_unavailable",
            f"web_search 上游请求失败:HTTP {exc.response.status_code}",
            {"error": "search_unavailable", "query": query},
            1,
        )
    except httpx.RequestError as exc:
        return failure(
            "search_unavailable",
            f"web_search 请求异常:{exc}",
            {"error": "search_unavailable", "query": query},
            1,
        )

    allowed_domains = metadata.get("allowed_domains") or []
    blocked_domains = metadata.get("blocked_domains") or []
    result_limit = max(1, WEB_SEARCH_DEFAULT_MAX_RESULTS)
    accepted = [
        candidate
        for candidate in parse_rss_results(response.text)
        if filter_web_search_url(candidate.get("url", ""), allowed_domains, blocked_domains)
    ][:result_limit]

    outward_results: list[dict[str, Any]] = []
    model_results: list[dict[str, Any]] = []
    for hit in accepted:
        url, title = hit.get("url"), hit.get("title")
        snippet, page_age = hit.get("snippet"), hit.get("page_age")
        token = build_web_search_encrypted_content(
            {
                "query": query,
                "url": url,
                "title": title,
                "snippet": snippet,
                "page_age": page_age,
                "retrieved_at": utcnow_iso(),
            }
        )

        # Client-facing entry: no snippet, page_age only when present.
        outward_item: dict[str, Any] = {
            "type": "web_search_result",
            "url": url,
            "title": title,
            "encrypted_content": token,
        }
        if page_age:
            outward_item["page_age"] = page_age
        outward_results.append(outward_item)

        # Model-facing entry: includes the snippet text.
        model_results.append(
            {
                "url": url,
                "title": title,
                "snippet": snippet,
                "page_age": page_age,
                "encrypted_content": token,
            }
        )

    result_block = {
        "type": "web_search_tool_result",
        "tool_use_id": tool_use_id,
        "content": outward_results,
    }
    model_payload = json_dumps({"query": query, "results": model_results})
    return result_block, model_payload, {"web_search_requests": 1}
|
|
|
|
async def execute_anthropic_server_tool_block(
    block: dict[str, Any],
    tool_metadata: dict[str, dict[str, Any]],
) -> tuple[dict[str, Any], str, dict[str, int]]:
    """Dispatch a server tool block to its executor.

    Returns ``(result_block, model_payload, usage)``. Only tools whose metadata
    has ``server_execution == "web_search"`` are executed. A tool without
    metadata yields an ``unknown_tool`` error block and any other tool an
    ``unsupported_tool`` error block, both with empty usage.
    """
    tool_name = block.get("name")
    metadata = tool_metadata.get(tool_name or "")

    if metadata and metadata.get("server_execution") == "web_search":
        return await execute_web_search_tool(block, metadata)

    if not metadata:
        error_code = "unknown_tool"
        message = f"未找到服务端工具 {tool_name} 的元数据。"
    else:
        error_code = "unsupported_tool"
        message = f"当前网关尚未实现服务端工具 {tool_name}。"

    error_block = build_server_tool_result_error_block(
        "tool_result",
        block.get("id") or normalize_anthropic_tool_use_id(None),
        error_code,
        message,
    )
    return error_block, json_dumps({"error": error_code}), {}
|
|
|
|
def anthropic_stop_reason(finish_reason: str | None, content_blocks: list[dict[str, Any]]) -> str:
    """Map an upstream finish reason to an Anthropic ``stop_reason``.

    Returns ``"tool_use"`` when any block is a ``tool_use`` block or the finish
    reason is ``"tool_calls"``, ``"max_tokens"`` for ``"length"``, and
    ``"end_turn"`` otherwise. A ``tool_use`` block takes precedence over
    ``"length"``.
    """
    has_client_tool_use = any(block.get("type") == "tool_use" for block in content_blocks)
    if has_client_tool_use or finish_reason == "tool_calls":
        return "tool_use"
    return "max_tokens" if finish_reason == "length" else "end_turn"
|
|
|
|
def chat_completion_to_anthropic_message(
    body: dict[str, Any],
    upstream_json: dict[str, Any],
    tool_metadata: dict[str, dict[str, Any]],
    thinking_config: dict[str, Any],
) -> dict[str, Any]:
    """Convert an upstream chat completion into an Anthropic message dict.

    Assistant text becomes a single text block, or is split into thinking and
    text blocks when thinking is enabled. Each tool call becomes a
    ``server_tool_use`` block (for tools with ``server_execution`` metadata) or
    a ``tool_use`` block carrying a ``direct`` caller. Usage is mapped from
    ``prompt_tokens``/``completion_tokens``.
    """
    upstream_message, finish_reason = extract_upstream_message(upstream_json)
    assistant_text, tool_calls = extract_text_and_tool_calls(upstream_message)

    content_blocks: list[dict[str, Any]] = []
    if assistant_text:
        if thinking_config.get("enabled"):
            content_blocks.extend(split_anthropic_thinking_blocks(assistant_text, body.get("model")))
        else:
            content_blocks.append({"type": "text", "text": assistant_text})

    for tool_call in tool_calls:
        tool_name = tool_call.get("name")
        runs_on_server = bool(tool_metadata.get(tool_name or "", {}).get("server_execution"))

        upstream_id = tool_call.get("id")
        if isinstance(upstream_id, str) and upstream_id.startswith("srvtoolu_"):
            block_id = upstream_id
        elif runs_on_server:
            # Server tools always get a ``srvtoolu_`` id.
            block_id = f"srvtoolu_{uuid.uuid4().hex[:24]}"
        else:
            block_id = normalize_anthropic_tool_use_id(upstream_id)

        tool_block: dict[str, Any] = {
            "type": "server_tool_use" if runs_on_server else "tool_use",
            "id": block_id,
            "name": tool_name,
            "input": parse_anthropic_tool_input(tool_call.get("arguments")),
        }
        if not runs_on_server:
            tool_block["caller"] = {"type": "direct"}
        content_blocks.append(tool_block)

    upstream_usage = upstream_json.get("usage") or {}
    return {
        "id": normalize_anthropic_message_id(upstream_json.get("id")),
        "type": "message",
        "role": "assistant",
        "content": content_blocks,
        "model": body.get("model"),
        "stop_reason": anthropic_stop_reason(finish_reason, content_blocks),
        "stop_sequence": None,
        "usage": {
            "input_tokens": upstream_usage.get("prompt_tokens"),
            "output_tokens": upstream_usage.get("completion_tokens"),
        },
    }
|
|
|
|
def build_anthropic_storage_items(body: dict[str, Any]) -> list[dict[str, Any]]:
    """List the request's system prompt and messages as role/content items.

    A ``system`` value that is not ``None`` comes first. Dict messages are
    passed through unchanged (same objects); any other message is stringified
    into a user item.
    """
    system_value = body.get("system")
    system_items = [] if system_value is None else [{"role": "system", "content": system_value}]
    message_items = [
        message if isinstance(message, dict) else {"role": "user", "content": str(message)}
        for message in body.get("messages") or []
    ]
    return system_items + message_items
|
|
|
|
async def create_anthropic_message_with_server_tools(
    api_key: str,
    body: dict[str, Any],
    chat_payload: dict[str, Any],
    tool_metadata: dict[str, dict[str, Any]],
    thinking_config: dict[str, Any],
) -> tuple[dict[str, Any], float]:
    """Run the upstream model, executing server tools until it stops asking for them.

    Each iteration posts the conversation upstream and converts the reply to
    Anthropic blocks. The loop ends when the reply contains a client
    ``tool_use`` block or no executable server tool block. Otherwise every
    server tool block is executed and its result is appended both to the
    returned content and to the conversation as a ``tool`` message. After
    ``ANTHROPIC_SERVER_TOOL_MAX_ITERATIONS`` rounds a ``max_iterations_exceeded``
    error block is appended.

    Returns ``(message_payload, latency_ms)``: the last upstream message with
    content and usage accumulated over all rounds, and the wall-clock time of
    the whole loop in milliseconds.

    Raises:
        HTTPException: 502 when no upstream message was produced.
    """
    request_messages = list(chat_payload.get("messages") or [])
    base_payload = {key: value for key, value in chat_payload.items() if key != "messages"}
    accumulated_content: list[dict[str, Any]] = []
    accumulated_usage: dict[str, Any] = {}
    last_message_payload: dict[str, Any] | None = None
    iterations_exhausted = False
    started = time.perf_counter()

    for _ in range(ANTHROPIC_SERVER_TOOL_MAX_ITERATIONS):
        loop_payload = {**base_payload, "messages": request_messages}
        upstream_json, _latency_ms = await post_nvidia_chat_completion(api_key, loop_payload)
        loop_message = chat_completion_to_anthropic_message(body, upstream_json, tool_metadata, thinking_config)
        last_message_payload = loop_message
        accumulated_usage = merge_anthropic_usage(accumulated_usage, loop_message.get("usage") or {})

        loop_content = list(loop_message.get("content") or [])
        accumulated_content.extend(loop_content)

        client_tool_blocks = [block for block in loop_content if block.get("type") == "tool_use"]
        server_tool_blocks = [block for block in loop_content if is_server_tool_block(block, tool_metadata)]

        # Client tools must be run by the caller; no server tools means we are done.
        if client_tool_blocks or not server_tool_blocks:
            break

        request_messages.append(anthropic_blocks_to_chat_assistant_message(loop_content))
        executed_usage: dict[str, Any] = {}
        for block in server_tool_blocks:
            result_block, model_tool_content, usage_increment = await execute_anthropic_server_tool_block(block, tool_metadata)
            accumulated_content.append(result_block)
            executed_usage = merge_anthropic_usage(executed_usage, {"server_tool_use": usage_increment})
            request_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": block.get("id"),
                    "content": model_tool_content,
                }
            )
        accumulated_usage = merge_anthropic_usage(accumulated_usage, executed_usage)
    else:
        # The loop ran out of iterations without the model finishing its turn.
        iterations_exhausted = True
        accumulated_content.append(
            build_server_tool_result_error_block(
                "tool_result",
                f"srvtoolu_{uuid.uuid4().hex[:24]}",
                "max_iterations_exceeded",
                "服务端工具循环次数过多,已停止继续执行。",
            )
        )

    if last_message_payload is None:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="上游未返回有效的 Anthropic 兼容消息。",
        )

    stop_reason = anthropic_stop_reason(None, accumulated_content)
    # Recomputing from content alone loses the upstream finish reason, so a final
    # reply cut off by the token limit would be reported as "end_turn". Keep
    # "max_tokens" from the last round when the loop ended on that reply.
    if (
        not iterations_exhausted
        and stop_reason == "end_turn"
        and last_message_payload.get("stop_reason") == "max_tokens"
    ):
        stop_reason = "max_tokens"

    message_payload = {
        **last_message_payload,
        "content": accumulated_content,
        "stop_reason": stop_reason,
        "usage": accumulated_usage,
    }
    latency_ms = round((time.perf_counter() - started) * 1000, 2)
    return message_payload, latency_ms
|
|
|
|
def build_anthropic_streaming_response(message_payload: dict[str, Any], anthropic_version: str | None) -> StreamingResponse:
    """Replay a finished Anthropic message as a server-sent event stream.

    Emits ``message_start``, then for each content block a
    ``content_block_start`` / optional ``content_block_delta`` /
    ``content_block_stop`` sequence, then ``message_delta`` and
    ``message_stop``. Text, thinking and tool-use blocks start empty and deliver
    their payload in a single delta; tool result blocks are sent whole in the
    start event. Blocks of any other type are not streamed.

    The response carries the ``anthropic-version`` header (falling back to
    ``ANTHROPIC_API_VERSION``) and disables caching.
    """

    def sse(event_name: str, payload: dict[str, Any]) -> str:
        return f"event: {event_name}\ndata: {json_dumps(payload)}\n\n"

    async def event_stream() -> Any:
        opening_message = {
            **message_payload,
            "content": [],
            "stop_reason": None,
            "stop_sequence": None,
        }
        yield sse("message_start", {"type": "message_start", "message": opening_message})

        # Counted separately from the position in ``content`` so that a skipped
        # block never leaves a gap in the streamed block indexes.
        index = 0
        for block in message_payload.get("content") or []:
            block_type = block.get("type")
            deltas: list[dict[str, Any]] = []

            if block_type == "text":
                start_block: dict[str, Any] = {"type": "text", "text": ""}
                text_value = str(block.get("text", ""))
                if text_value:
                    deltas.append({"type": "text_delta", "text": text_value})
            elif block_type == "thinking":
                start_block = {"type": "thinking", "thinking": ""}
                thinking_text = str(block.get("thinking", ""))
                if thinking_text:
                    deltas.append({"type": "thinking_delta", "thinking": thinking_text})
                signature = block.get("signature")
                if signature:
                    deltas.append({"type": "signature_delta", "signature": signature})
            elif block_type in ("tool_use", "server_tool_use"):
                start_block = {**block, "input": {}}
                input_json = json_dumps(block.get("input") or {})
                if input_json:
                    deltas.append({"type": "input_json_delta", "partial_json": input_json})
            elif block_type == "tool_result" or (
                isinstance(block_type, str) and block_type.endswith("_tool_result")
            ):
                # Plain ``tool_result`` blocks do not match the suffix test, so
                # they are named explicitly to avoid being dropped.
                start_block = block
            else:
                continue

            yield sse("content_block_start", {"type": "content_block_start", "index": index, "content_block": start_block})
            for delta in deltas:
                yield sse("content_block_delta", {"type": "content_block_delta", "index": index, "delta": delta})
            yield sse("content_block_stop", {"type": "content_block_stop", "index": index})
            index += 1

        yield sse(
            "message_delta",
            {
                "type": "message_delta",
                "delta": {
                    "stop_reason": message_payload.get("stop_reason"),
                    "stop_sequence": message_payload.get("stop_sequence"),
                },
                "usage": {"output_tokens": (message_payload.get("usage") or {}).get("output_tokens")},
            },
        )
        yield "event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n"

    headers = {
        "anthropic-version": anthropic_version or ANTHROPIC_API_VERSION,
        "cache-control": "no-cache",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
|
|
|
|
| def extract_upstream_message(upstream_json: dict[str, Any]) -> tuple[dict[str, Any], str | None]: |
| choices = upstream_json.get("choices") or [] |
| if not choices: |
| return {}, None |
| choice = choices[0] or {} |
| return choice.get("message") or {}, choice.get("finish_reason") |
|
|
|
|
def extract_text_and_tool_calls(message: dict[str, Any]) -> tuple[str, list[dict[str, Any]]]:
    """Collect the assistant text and the tool calls carried by a chat message.

    Text is taken from a string ``content`` or, for a list ``content``, from
    plain strings, stringified non-dict entries, and dict parts typed
    ``input_text`` / ``output_text`` / ``text``. Tool calls are taken from
    content parts typed ``tool_call`` / ``function_call`` and from
    ``message["tool_calls"]``.

    Returns:
        ``(text, tool_calls)``. ``text`` is the non-empty chunks joined with
        newlines and stripped. Each tool call is a dict with ``id``, ``name``
        and string ``arguments``; only the first call per ``id`` is kept.
    """
    text_part_types = {"input_text", "output_text", "text"}
    call_part_types = {"tool_call", "function_call"}

    def arguments_as_text(raw: Any) -> str:
        # Missing/empty arguments become "{}"; non-strings are serialised.
        value = raw or "{}"
        return value if isinstance(value, str) else json_dumps(value)

    def generated_call_id() -> str:
        return f"call_{uuid.uuid4().hex[:12]}"

    fragments: list[str] = []
    collected: list[dict[str, Any]] = []

    body = message.get("content")
    if isinstance(body, str):
        fragments.append(body)
    elif isinstance(body, list):
        for piece in body:
            if isinstance(piece, str):
                fragments.append(piece)
            elif not isinstance(piece, dict):
                fragments.append(str(piece))
            elif piece.get("type") in text_part_types:
                fragments.append(str(piece.get("text", "")))
            elif piece.get("type") in call_part_types:
                collected.append(
                    {
                        "id": piece.get("id") or piece.get("call_id") or generated_call_id(),
                        "name": piece.get("name"),
                        "arguments": arguments_as_text(piece.get("arguments")),
                    }
                )

    for entry in message.get("tool_calls") or []:
        if not isinstance(entry, dict):
            continue
        function_part = entry.get("function") or {}
        collected.append(
            {
                "id": entry.get("id") or generated_call_id(),
                "name": function_part.get("name") or entry.get("name"),
                "arguments": arguments_as_text(function_part.get("arguments") or entry.get("arguments")),
            }
        )

    # Insertion-ordered dict: the first call seen for an id wins.
    first_by_id: dict[str, dict[str, Any]] = {}
    for call in collected:
        first_by_id.setdefault(call["id"], call)

    joined_text = "\n".join(fragment for fragment in fragments if fragment).strip()
    return joined_text, list(first_by_id.values())
|
|
| def build_choice_alias(output_items: list[dict[str, Any]], finish_reason: str | None) -> list[dict[str, Any]]: |
| content_parts: list[dict[str, Any]] = [] |
| for item in output_items: |
| if item.get("type") == "message": |
| for part in item.get("content", []): |
| content_parts.append({"type": part.get("type", "output_text"), "text": part.get("text", "")}) |
| elif item.get("type") == "function_call": |
| arguments = item.get("arguments") or "{}" |
| try: |
| parsed_arguments = json.loads(arguments) |
| except Exception: |
| parsed_arguments = arguments |
| content_parts.append({"type": "tool_call", "id": item.get("call_id"), "name": item.get("name"), "arguments": parsed_arguments}) |
| return [{"index": 0, "message": {"role": "assistant", "content": content_parts}, "finish_reason": finish_reason or "stop"}] |
|
|
|
|
def chat_completion_to_response(body: dict[str, Any], upstream_json: dict[str, Any], previous_response_id: str | None) -> dict[str, Any]:
    """Convert an upstream chat completion into a ``response`` object payload.

    Args:
        body: The original request body (supplies ``model``, ``text`` and
            ``parallel_tool_calls``).
        upstream_json: Decoded chat-completions response.
        previous_response_id: Id echoed back in ``previous_response_id``.

    Returns:
        A dict with ``object == "response"`` containing one ``message`` output
        item when there is assistant text, one ``function_call`` item per tool
        call, mapped token usage, a ``choices`` alias and upstream metadata.
    """
    message, finish_reason = extract_upstream_message(upstream_json)
    text, calls = extract_text_and_tool_calls(message)
    upstream_id = upstream_json.get("id")
    # Reuse the upstream id when present, otherwise mint one.
    response_id = upstream_id or f"resp_{uuid.uuid4().hex}"

    output: list[dict[str, Any]] = []
    if text:
        output.append(
            {
                "id": f"msg_{uuid.uuid4().hex[:24]}",
                "type": "message",
                "status": "completed",
                "role": "assistant",
                "content": [{"type": "output_text", "text": text, "annotations": []}],
            }
        )
    output.extend(
        {
            "id": f"fc_{uuid.uuid4().hex[:24]}",
            "type": "function_call",
            "status": "completed",
            "call_id": call["id"],
            "name": call.get("name"),
            "arguments": call.get("arguments", "{}"),
        }
        for call in calls
    )

    upstream_usage = upstream_json.get("usage") or {}
    usage_block = {
        "input_tokens": upstream_usage.get("prompt_tokens"),
        "output_tokens": upstream_usage.get("completion_tokens"),
        "total_tokens": upstream_usage.get("total_tokens"),
    }
    upstream_block = {
        "id": upstream_id,
        "object": upstream_json.get("object", "chat.completion"),
        "finish_reason": finish_reason or "stop",
    }
    return {
        "id": response_id,
        "object": "response",
        "created_at": int(time.time()),
        "status": "completed",
        "model": body.get("model"),
        "output": output,
        "output_text": text,
        "parallel_tool_calls": bool(body.get("parallel_tool_calls", True)),
        "previous_response_id": previous_response_id,
        "store": True,
        "text": body.get("text") or {"format": {"type": "text"}},
        "usage": usage_block,
        "choices": build_choice_alias(output, finish_reason),
        "upstream": upstream_block,
    }
|
|
|
|
def store_success_record(api_key_hash: str, model_id: str, request_body: dict[str, Any], input_items: list[dict[str, Any]], response_payload: dict[str, Any], latency_ms: float) -> None:
    """Persist a successful exchange and bump the success counters.

    Inserts (or replaces) the ``response_records`` row keyed by
    ``response_payload["id"]``, then increments the current metric bucket for
    the model and the global ``gateway_totals`` row. All three statements are
    committed together.

    Args:
        api_key_hash: Hash of the caller's API key, stored as the row owner.
        model_id: Model the request was sent to.
        request_body: Original request body, stored as JSON.
        input_items: Input items for this turn, stored as JSON.
        response_payload: Payload returned to the client; must contain ``id``.
        latency_ms: Upstream latency added to the latency totals.
    """
    insert_record_sql = """
        INSERT OR REPLACE INTO response_records (
        response_id, api_key_hash, parent_response_id, model_id, request_json,
        input_items_json, output_json, output_items_json, status, success,
        latency_ms, error_message, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    upsert_bucket_sql = """
        INSERT INTO metric_buckets (bucket_start, model_id, total_count, success_count, total_latency_ms)
        VALUES (?, ?, 1, 1, ?)
        ON CONFLICT(bucket_start, model_id) DO UPDATE SET
        total_count = total_count + 1,
        success_count = success_count + 1,
        total_latency_ms = total_latency_ms + excluded.total_latency_ms
        """
    update_totals_sql = """
        UPDATE gateway_totals
        SET total_requests = total_requests + 1,
        total_success = total_success + 1,
        total_latency_ms = total_latency_ms + ?,
        updated_at = ?
        WHERE id = 1
        """

    with contextlib.closing(get_db_connection()) as conn:
        timestamp = utcnow_iso()
        bucket_key = bucket_start().isoformat()

        # Prefer the "output" list, fall back to "content", else store [].
        candidates = (response_payload.get("output"), response_payload.get("content"))
        stored_items = next((value for value in candidates if isinstance(value, list)), [])

        record_values = (
            response_payload["id"],
            api_key_hash,
            request_body.get("previous_response_id"),
            model_id,
            json_dumps(request_body),
            json_dumps(input_items),
            json_dumps(response_payload),
            json_dumps(stored_items),
            response_payload.get("status", "completed"),
            1,
            latency_ms,
            None,
            timestamp,
        )
        conn.execute(insert_record_sql, record_values)
        conn.execute(upsert_bucket_sql, (bucket_key, model_id, latency_ms))
        conn.execute(update_totals_sql, (latency_ms, timestamp))
        conn.commit()
|
|
|
|
def store_failure_metric(model_id: str, error_message: str) -> None:
    """Count one failed request for ``model_id``.

    Increments ``total_count`` of the current metric bucket (creating it if
    needed) and ``total_requests`` of the global totals row, leaving the
    success counters untouched.

    Args:
        model_id: Model the failed request targeted.
        error_message: Accepted for the caller's convenience; this function
            does not store it.
    """
    upsert_bucket_sql = """
        INSERT INTO metric_buckets (bucket_start, model_id, total_count, success_count, total_latency_ms)
        VALUES (?, ?, 1, 0, 0)
        ON CONFLICT(bucket_start, model_id) DO UPDATE SET
        total_count = total_count + 1
        """
    update_totals_sql = """
        UPDATE gateway_totals
        SET total_requests = total_requests + 1,
        updated_at = ?
        WHERE id = 1
        """
    with contextlib.closing(get_db_connection()) as conn:
        timestamp = utcnow_iso()
        bucket_key = bucket_start().isoformat()
        conn.execute(upsert_bucket_sql, (bucket_key, model_id))
        conn.execute(update_totals_sql, (timestamp,))
        conn.commit()
|
|
|
|
| def load_previous_conversation_items(api_key_hash: str, previous_response_id: str | None) -> list[dict[str, Any]]: |
| if not previous_response_id: |
| return [] |
| conn = get_db_connection() |
| try: |
| items: list[dict[str, Any]] = [] |
| current = previous_response_id |
| chain: list[sqlite3.Row] = [] |
| while current: |
| row = conn.execute( |
| "SELECT * FROM response_records WHERE response_id = ? AND api_key_hash = ?", |
| (current, api_key_hash), |
| ).fetchone() |
| if not row: |
| raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"previous_response_id '{current}' 不存在,或不属于当前 Key。") |
| chain.append(row) |
| current = row["parent_response_id"] |
| for row in reversed(chain): |
| items.extend(json.loads(row["input_items_json"])) |
| items.extend(json.loads(row["output_items_json"])) |
| return items |
| finally: |
| conn.close() |
|
|
|
|
def load_response_record(api_key_hash: str, response_id: str) -> dict[str, Any]:
    """Fetch the stored response payload for ``response_id``.

    Args:
        api_key_hash: Hash of the caller's key; the row must be owned by it.
        response_id: Id of the stored response.

    Returns:
        The decoded ``output_json`` column of the matching row.

    Raises:
        HTTPException: 404 when no row matches both the id and the key hash.
    """
    lookup_sql = "SELECT output_json FROM response_records WHERE response_id = ? AND api_key_hash = ?"
    with contextlib.closing(get_db_connection()) as conn:
        record = conn.execute(lookup_sql, (response_id, api_key_hash)).fetchone()
        if not record:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到对应响应,或当前 Key 无权访问。")
        stored_payload = record["output_json"]
        return json.loads(stored_payload)
|
|
|
|
def load_dashboard_data() -> dict[str, Any]:
    """Assemble the dashboard payload from the metrics tables.

    For every model in ``MODEL_LIST`` the payload carries its all-time call
    count, one point per history bucket (``PUBLIC_HISTORY_BUCKETS`` buckets of
    ``BUCKET_MINUTES`` minutes ending at the current bucket) and the success
    rate over the health window (``HEALTH_SUMMARY_WINDOW_MINUTES``). Rates are
    percentages rounded to one decimal, or ``None`` when there were no calls.

    Returns:
        A dict with ``generated_at``, ``bucket_minutes``,
        ``health_window_minutes``, ``total_requests``, ``average_health`` (mean
        of the per-model health-window rates that are not ``None``) and
        ``models``.
    """

    def success_percent(successes: int, total: int) -> float | None:
        return round((successes / total) * 100, 1) if total else None

    with contextlib.closing(get_db_connection()) as conn:
        totals_row = conn.execute("SELECT * FROM gateway_totals WHERE id = 1").fetchone()
        total_requests = totals_row["total_requests"] if totals_row else 0

        current_bucket = bucket_start()

        def bucket_series(count: int) -> list[str]:
            # ISO timestamps of the last ``count`` buckets, oldest first.
            return [
                (current_bucket - timedelta(minutes=BUCKET_MINUTES * offset)).isoformat()
                for offset in reversed(range(count))
            ]

        bucket_points = bucket_series(PUBLIC_HISTORY_BUCKETS)
        health_window_points = bucket_series(max(1, HEALTH_SUMMARY_WINDOW_MINUTES // BUCKET_MINUTES))

        totals_by_model: dict[str, int] = {}
        recent_rows: list[sqlite3.Row] = []
        if MODEL_LIST:
            placeholders = ",".join("?" for _ in MODEL_LIST)
            totals_by_model = {
                row["model_id"]: row["total_count"]
                for row in conn.execute(
                    f"SELECT model_id, COALESCE(SUM(total_count), 0) AS total_count FROM metric_buckets WHERE model_id IN ({placeholders}) GROUP BY model_id",
                    MODEL_LIST,
                ).fetchall()
            }
            # Query from the earlier of the two window starts.
            window_starts = [series[0] for series in (bucket_points, health_window_points) if series and series[0]]
            since = min(window_starts) if window_starts else utcnow_iso()
            recent_rows = conn.execute(
                f"SELECT bucket_start, model_id, total_count, success_count FROM metric_buckets WHERE model_id IN ({placeholders}) AND bucket_start >= ? ORDER BY bucket_start ASC",
                [*MODEL_LIST, since],
            ).fetchall()

        rows_by_model: dict[str, dict[str, sqlite3.Row]] = {}
        for row in recent_rows:
            rows_by_model.setdefault(row["model_id"], {})[row["bucket_start"]] = row

        models: list[dict[str, Any]] = []
        health_window_rates: list[float] = []
        for model_id in MODEL_LIST:
            model_rows = rows_by_model.get(model_id, {})

            points: list[dict[str, Any]] = []
            for bucket_value in bucket_points:
                row = model_rows.get(bucket_value)
                total_count, success_count = (row["total_count"], row["success_count"]) if row else (0, 0)
                points.append(
                    {
                        "bucket_start": bucket_value,
                        "label": bucket_label(bucket_value),
                        "total_count": total_count,
                        "success_count": success_count,
                        "success_rate": success_percent(success_count, total_count),
                    }
                )
            # Rate of the newest bucket, only when that bucket saw traffic.
            newest_point = points[-1] if points else None
            latest_bucket_rate = newest_point["success_rate"] if newest_point and newest_point["total_count"] else None

            window_rows = [model_rows[bucket_value] for bucket_value in health_window_points if bucket_value in model_rows]
            health_window_total = sum(row["total_count"] for row in window_rows)
            health_window_success = sum(row["success_count"] for row in window_rows)
            health_window_rate = success_percent(health_window_success, health_window_total)
            if health_window_rate is not None:
                health_window_rates.append(health_window_rate)

            models.append(
                {
                    "model_id": model_id,
                    "provider": normalize_provider(model_id),
                    "total_calls": totals_by_model.get(model_id, 0),
                    "latest_success_rate": health_window_rate,
                    "average_success_rate": health_window_rate,
                    "health_window_success_rate": health_window_rate,
                    "health_window_total_count": health_window_total,
                    "health_window_success_count": health_window_success,
                    "health_window_minutes": HEALTH_SUMMARY_WINDOW_MINUTES,
                    "latest_bucket_success_rate": latest_bucket_rate,
                    "points": points,
                }
            )

        average_health = None
        if health_window_rates:
            average_health = round(sum(health_window_rates) / len(health_window_rates), 1)
        return {
            "generated_at": utcnow_iso(),
            "bucket_minutes": BUCKET_MINUTES,
            "health_window_minutes": HEALTH_SUMMARY_WINDOW_MINUTES,
            "total_requests": total_requests,
            "average_health": average_health,
            "models": models,
        }
|
|
|
|
def build_catalog_payload() -> dict[str, Any]:
    """Group the cached models by provider for the catalog endpoint.

    Models are ordered by ``id`` within each provider, and providers are
    ordered case-insensitively by name.

    Returns:
        A dict with ``generated_at``, ``synced_at`` (the cache sync time),
        ``total_models`` and ``providers`` (each with ``provider``, ``count``
        and ``models``).
    """
    by_provider: dict[str, list[dict[str, Any]]] = {}
    ordered_models = sorted(model_cache, key=lambda item: item.get("id", ""))
    for entry in ordered_models:
        owner = normalize_provider(entry.get("id", ""), entry.get("owned_by"))
        by_provider.setdefault(owner, []).append(entry)

    providers: list[dict[str, Any]] = []
    for owner in sorted(by_provider, key=lambda name: name.lower()):
        owned = by_provider[owner]
        providers.append({"provider": owner, "count": len(owned), "models": owned})

    return {
        "generated_at": utcnow_iso(),
        "synced_at": model_cache_synced_at,
        "total_models": len(model_cache),
        "providers": providers,
    }
|
|
|
|
async def post_nvidia_chat_completion(api_key: str, payload: dict[str, Any]) -> tuple[dict[str, Any], float]:
    """POST a chat-completions payload upstream and return the decoded reply.

    Timeouts are retried up to ``UPSTREAM_TIMEOUT_RETRIES`` extra times; any
    other transport error fails immediately.

    Args:
        api_key: Bearer token forwarded in the ``Authorization`` header.
        payload: JSON body sent to ``CHAT_COMPLETIONS_URL``.

    Returns:
        ``(response_json, latency_ms)``. Latency is measured from before the
        first attempt, so it includes time spent on timed-out attempts.

    Raises:
        HTTPException: the upstream status code when it is >= 400; 502 on a
            non-timeout transport error or when a successful reply is not a
            JSON object; 504 when every attempt timed out.
    """
    client = await get_http_client()
    started = time.perf_counter()
    total_attempts = UPSTREAM_TIMEOUT_RETRIES + 1
    last_timeout: httpx.TimeoutException | None = None
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    for _ in range(total_attempts):
        try:
            response = await client.post(CHAT_COMPLETIONS_URL, headers=request_headers, json=payload)
        except httpx.TimeoutException as exc:
            # Must be handled before RequestError, which is its base class.
            last_timeout = exc
            continue
        except httpx.RequestError as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"NVIDIA NIM 请求异常:{exc}",
            ) from exc

        latency_ms = round((time.perf_counter() - started) * 1000, 2)
        if response.status_code >= 400:
            try:
                error_payload = response.json()
                detail = error_payload.get("error", {}).get("message") or json_dumps(error_payload)
            except Exception:
                # Error body is not the expected JSON shape: relay the raw text.
                detail = response.text
            raise HTTPException(status_code=response.status_code, detail=f"NVIDIA NIM 请求失败:{detail}")

        # A 2xx reply that is not a JSON object is an upstream fault; report it
        # as 502 rather than letting it surface as an unexplained 500.
        try:
            decoded = response.json()
        except ValueError as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="NVIDIA NIM 返回了无法解析的 JSON 响应。",
            ) from exc
        if not isinstance(decoded, dict):
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="NVIDIA NIM 返回了无法解析的 JSON 响应。",
            )
        return decoded, latency_ms

    detail = f"NVIDIA NIM 请求超时,已自动重试 {UPSTREAM_TIMEOUT_RETRIES} 次后仍未成功。"
    if last_timeout and str(last_timeout):
        detail = f"{detail} 最后错误:{last_timeout}"
    raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=detail)
|
|
|
|
def render_html(filename: str) -> HTMLResponse:
    """Return the UTF-8 contents of ``STATIC_DIR / filename`` as an HTML response."""
    page_path = STATIC_DIR / filename
    markup = page_path.read_text(encoding="utf-8")
    return HTMLResponse(content=markup, media_type="text/html; charset=utf-8")
|
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Set up and tear down the module-level state around the app's lifetime.

    Startup: initialise the database, load the cached model list, create the
    cache lock and HTTP client, attempt one model refresh (forced when the
    cache is empty; failures are ignored) and start the model sync task.

    Shutdown: cancel and await the sync task, close the HTTP client if it is
    still open, and reset the client, task and lock globals to ``None``.
    """
    global model_cache, model_cache_synced_at, model_sync_task, http_client, model_cache_lock

    init_db()
    model_cache, model_cache_synced_at = await run_db(load_cached_models_from_db)
    model_cache_lock = asyncio.Lock()
    http_client = await get_http_client()

    # Best effort: the app still starts when the first refresh fails.
    with contextlib.suppress(Exception):
        await refresh_official_models(force=not model_cache)

    model_sync_task = asyncio.create_task(model_sync_loop())
    try:
        yield
    finally:
        if model_sync_task is not None:
            model_sync_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await model_sync_task
        if http_client is not None and not http_client.is_closed:
            await http_client.aclose()
        http_client = model_sync_task = model_cache_lock = None
|
|
|
|
# Application instance; startup and shutdown work is delegated to ``lifespan``.
app = FastAPI(title="NIM Responses Gateway", lifespan=lifespan)
# Gzip responses, with the middleware's minimum size set to 1000.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Serve the files in STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def homepage() -> HTMLResponse:
    """Serve ``index.html`` at the site root."""
    page = render_html("index.html")
    return page
|
|
|
|
@app.get("/model_list", response_class=HTMLResponse)
async def models_page() -> HTMLResponse:
    """Serve ``models.html`` at ``/model_list``."""
    page = render_html("models.html")
    return page
|
|
|
|
@app.get("/api/dashboard")
async def dashboard_api() -> dict[str, Any]:
    """Return the dashboard payload produced by ``load_dashboard_data``."""
    dashboard = await run_db(load_dashboard_data)
    return dashboard
|
|
|
|
@app.get("/api/catalog")
async def catalog_api() -> dict[str, Any]:
    """Return the provider-grouped model catalog.

    When the model cache is empty a forced refresh is attempted first; a
    failed refresh is ignored and the (possibly empty) catalog is returned.
    """
    cache_is_empty = not model_cache
    if cache_is_empty:
        with contextlib.suppress(Exception):
            await refresh_official_models(force=True)
    return build_catalog_payload()
|
|
|
|
async def build_models_response() -> dict[str, Any]:
    """Return the cached models as a ``{"object": "list", "data": [...]}`` payload.

    An empty cache triggers a forced refresh first; unlike ``catalog_api``,
    errors from that refresh propagate to the caller.
    """
    cache_is_empty = not model_cache
    if cache_is_empty:
        await refresh_official_models(force=True)
    listing: dict[str, Any] = {"object": "list", "data": model_cache}
    return listing
|
|
|
|
@app.get("/v1/models")
async def list_models_v1() -> dict[str, Any]:
    """List the cached models at ``/v1/models``."""
    listing = await build_models_response()
    return listing
|
|
|
|
@app.get("/models")
async def list_models() -> dict[str, Any]:
    """List the cached models at ``/models`` (same payload as ``/v1/models``)."""
    listing = await build_models_response()
    return listing
|
|
|
|
async def fetch_response_record(response_id: str, api_key: str) -> dict[str, Any]:
    """Load a stored response for the caller identified by ``api_key``.

    The key is hashed with ``hash_api_key`` and the lookup is delegated to
    ``load_response_record`` through ``run_db``.
    """
    owner_hash = hash_api_key(api_key)
    record = await run_db(load_response_record, owner_hash, response_id)
    return record
|
|
|
|
@app.post("/v1/messages")
async def create_anthropic_message(
    request: Request,
    api_key: str = Depends(extract_user_api_key),
    anthropic_version: str | None = Header(default=None),
    anthropic_beta: str | None = Header(default=None),
):
    """Route ``POST /v1/messages`` to ``create_anthropic_message_impl``.

    The API key comes from ``extract_user_api_key``; the two ``anthropic_*``
    values are read from request headers and default to ``None``.
    """
    result = await create_anthropic_message_impl(
        request,
        api_key,
        anthropic_version,
        anthropic_beta,
    )
    return result
|
|
|
|
@app.get("/v1/responses/{response_id}")
async def get_response_v1(response_id: str, api_key: str = Depends(extract_user_api_key)) -> dict[str, Any]:
    """Return the stored response ``response_id`` for the calling key."""
    stored = await fetch_response_record(response_id, api_key)
    return stored
|
|
|
|
@app.get("/responses/{response_id}")
async def get_response(response_id: str, api_key: str = Depends(extract_user_api_key)) -> dict[str, Any]:
    """Return the stored response ``response_id`` (unversioned route)."""
    stored = await fetch_response_record(response_id, api_key)
    return stored
|
|
|
|
@app.post("/v1/responses")
async def create_response_v1(request: Request, api_key: str = Depends(extract_user_api_key)):
    """Route ``POST /v1/responses`` to ``create_response_impl``."""
    result = await create_response_impl(request, api_key)
    return result
|
|
|
|
@app.post("/responses")
async def create_response(request: Request, api_key: str = Depends(extract_user_api_key)):
    """Route ``POST /responses`` (unversioned) to ``create_response_impl``."""
    result = await create_response_impl(request, api_key)
    return result
|
|
|
|
async def create_anthropic_message_impl(
    request: Request,
    api_key: str,
    anthropic_version: str | None,
    anthropic_beta: str | None = None,
):
    """Handle a Messages-style request by proxying it to chat completions.

    Validates the body, builds the upstream chat payload, calls the upstream
    (through the server-tool path when any tool is marked for server
    execution), stores the outcome, and returns either a JSON response or a
    streaming response when ``stream`` is truthy.

    Args:
        request: Incoming request whose body is the JSON payload.
        api_key: Caller's API key, forwarded upstream and hashed for storage.
        anthropic_version: Value echoed in the ``anthropic-version`` response
            header; falls back to ``ANTHROPIC_API_VERSION``.
        anthropic_beta: Passed to ``build_anthropic_chat_payload``.

    Raises:
        HTTPException: 400 for a body that is not valid JSON, is not an
            object, or lacks ``model`` / ``messages`` / ``max_tokens``;
            upstream and storage ``HTTPException``s are re-raised; any other
            failure becomes a 500. A failure metric is recorded (best effort)
            before either kind of error leaves the upstream/storage step.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # Malformed JSON (or undecodable bytes) is a client error, not a 500.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请求体必须是 JSON 对象。") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请求体必须是 JSON 对象。")
    if not body.get("model"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 model 字段。")
    messages = body.get("messages")
    if messages is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 messages 字段。")
    if not isinstance(messages, list):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="messages 字段必须是数组。")
    if body.get("max_tokens") is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 max_tokens 字段。")

    model_id = body.get("model")
    api_key_hash = hash_api_key(api_key)
    storage_items = build_anthropic_storage_items(body)
    chat_payload, _chat_messages, tool_metadata, thinking_config = build_anthropic_chat_payload(body, anthropic_beta)
    has_server_tools = any(meta.get("server_execution") for meta in tool_metadata.values())

    try:
        if has_server_tools:
            message_payload, latency_ms = await create_anthropic_message_with_server_tools(
                api_key,
                body,
                chat_payload,
                tool_metadata,
                thinking_config,
            )
        else:
            upstream_json, latency_ms = await post_nvidia_chat_completion(api_key, chat_payload)
            message_payload = chat_completion_to_anthropic_message(body, upstream_json, tool_metadata, thinking_config)
        await run_db(store_success_record, api_key_hash, model_id, body, storage_items, message_payload, latency_ms)
    except HTTPException as exc:
        # Metrics are best effort; never let them mask the real error.
        with contextlib.suppress(Exception):
            await run_db(store_failure_metric, model_id, str(exc.detail))
        raise
    except Exception as exc:
        with contextlib.suppress(Exception):
            await run_db(store_failure_metric, model_id, str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="网关处理 Anthropic Messages 请求时发生内部错误。",
        ) from exc

    resolved_version = anthropic_version or ANTHROPIC_API_VERSION
    if body.get("stream"):
        return build_anthropic_streaming_response(message_payload, resolved_version)

    return JSONResponse(content=message_payload, headers={"anthropic-version": resolved_version})
|
|
|
|
async def create_response_impl(request: Request, api_key: str):
    """Handle a Responses-style request by proxying it to chat completions.

    Validates the body, prepends the stored conversation referenced by
    ``previous_response_id`` to the new input, calls the upstream, stores the
    outcome, and returns the response payload. When ``stream`` is truthy the
    already-complete payload is replayed as a sequence of server-sent events.

    Args:
        request: Incoming request whose body is the JSON payload.
        api_key: Caller's API key, forwarded upstream and hashed for storage.

    Raises:
        HTTPException: 400 for a body that is not valid JSON, is not an
            object, or lacks ``model`` / ``input``; errors from loading the
            previous conversation propagate unchanged; upstream and storage
            ``HTTPException``s are re-raised; any other failure in that step
            becomes a 500. A failure metric is recorded (best effort) for
            upstream/storage errors.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # Malformed JSON (or undecodable bytes) is a client error, not a 500.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请求体必须是 JSON 对象。") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请求体必须是 JSON 对象。")
    if not body.get("model"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 model 字段。")
    if body.get("input") is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 input 字段。")

    model_id = body.get("model")
    previous_response_id = body.get("previous_response_id")
    api_key_hash = hash_api_key(api_key)
    input_items = normalize_input_items(body.get("input"))
    previous_items = await run_db(load_previous_conversation_items, api_key_hash, previous_response_id)
    merged_items = previous_items + input_items
    chat_payload = build_chat_payload(body, merged_items)

    try:
        upstream_json, latency_ms = await post_nvidia_chat_completion(api_key, chat_payload)
        response_payload = chat_completion_to_response(body, upstream_json, previous_response_id)
        await run_db(store_success_record, api_key_hash, model_id, body, input_items, response_payload, latency_ms)
    except HTTPException as exc:
        # Metrics are best effort; never let them mask the real error.
        with contextlib.suppress(Exception):
            await run_db(store_failure_metric, model_id, str(exc.detail))
        raise
    except Exception as exc:
        with contextlib.suppress(Exception):
            await run_db(store_failure_metric, model_id, str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="网关处理请求时发生内部错误。",
        ) from exc

    if not body.get("stream"):
        return response_payload

    def sse(event_type: str, **fields: Any) -> str:
        # One server-sent event; the event name doubles as the payload "type".
        return f"event: {event_type}\ndata: {json_dumps({'type': event_type, **fields})}\n\n"

    async def event_stream() -> Any:
        yield sse(
            "response.created",
            response={"id": response_payload["id"], "model": response_payload["model"], "status": "in_progress"},
        )
        for index, item in enumerate(response_payload.get("output") or []):
            yield sse("response.output_item.added", output_index=index, item=item)
            if item.get("type") == "message":
                text_value = extract_text_from_content(item.get("content"))
                if text_value:
                    yield sse("response.output_text.delta", output_index=index, delta=text_value)
                    yield sse("response.output_text.done", output_index=index, text=text_value)
            if item.get("type") == "function_call":
                yield sse(
                    "response.function_call_arguments.done",
                    output_index=index,
                    arguments=item.get("arguments", "{}"),
                    call_id=item.get("call_id"),
                )
            yield sse("response.output_item.done", output_index=index, item=item)
        yield sse("response.completed", response=response_payload)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|