diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..20213fad94865d9cb08f9046026488b066bbfff3 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,14 @@ +.git +.venv +__pycache__ +*.pyc +*.pyo +*.pyd +.pytest_cache +.ruff_cache +.mypy_cache +.coverage +htmlcov +debug +docker-data +db.sqlite3 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b72b1df1368a3f86e814b587dc47b37536b44e71 --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ + +build/ +.claude/ +config.local.yaml +.coverage +.cursor/ +# Cursor / editor state +db.sqlite3 +debug/ +# Debug output (runtime) +dist/ +docker-data/ +.DS_Store +*.egg-info +.env +htmlcov/ +.idea/ +# IDE & editor configs +# Local env (secrets) +# macOS +.mypy_cache/ +__pycache__/ +*.py[oc] +.pytest_cache/ +# Python-generated files +.ruff_cache/ +start_cf.sh +start_mock.sh +start.sh +*.swo +*.swp +# Test & coverage +# Type checking / linters +.venv +# Virtual environments +.vscode/ +wheels/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e80cbc37e5dcda9456a9c5381b06c46c90371efe --- /dev/null +++ b/Dockerfile @@ -0,0 +1,116 @@ +FROM ubuntu:24.04 + +ARG TARGETARCH +ARG FINGERPRINT_CHROMIUM_URL_AMD64="https://github.com/adryfish/fingerprint-chromium/releases/download/142.0.7444.175/ungoogled-chromium-142.0.7444.175-1-x86_64_linux.tar.xz" +ARG FINGERPRINT_CHROMIUM_URL_ARM64_CHROMIUM_DEB="https://github.com/luispater/fingerprint-chromium-arm64/releases/download/135.0.7049.95-1/ungoogled-chromium_135.0.7049.95-1.deb12u1_arm64.deb" +ARG FINGERPRINT_CHROMIUM_URL_ARM64_COMMON_DEB="https://github.com/luispater/fingerprint-chromium-arm64/releases/download/135.0.7049.95-1/ungoogled-chromium-common_135.0.7049.95-1.deb12u1_arm64.deb" +ARG 
FINGERPRINT_CHROMIUM_URL_ARM64_SANDBOX_DEB="https://github.com/luispater/fingerprint-chromium-arm64/releases/download/135.0.7049.95-1/ungoogled-chromium-sandbox_135.0.7049.95-1.deb12u1_arm64.deb" +ARG FINGERPRINT_CHROMIUM_URL_ARM64_L10N_DEB="https://github.com/luispater/fingerprint-chromium-arm64/releases/download/135.0.7049.95-1/ungoogled-chromium-l10n_135.0.7049.95-1.deb12u1_all.deb" + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + WEB2API_DATA_DIR=/data \ + HOME=/data + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + xz-utils \ + xvfb \ + xauth \ + python3 \ + python3-pip \ + python3-venv \ + python-is-python3 \ + software-properties-common \ + fonts-liberation \ + libasound2t64 \ + libatk-bridge2.0-0t64 \ + libatk1.0-0t64 \ + libcairo2 \ + libcups2t64 \ + libdbus-1-3 \ + libdrm2 \ + libfontconfig1 \ + libgbm1 \ + libglib2.0-0t64 \ + libgtk-3-0t64 \ + libnspr4 \ + libnss3 \ + libpango-1.0-0 \ + libu2f-udev \ + libvulkan1 \ + libx11-6 \ + libx11-xcb1 \ + libxcb1 \ + libxcomposite1 \ + libxdamage1 \ + libxext6 \ + libxfixes3 \ + libxkbcommon0 \ + libxrandr2 \ + libxrender1 \ + libxshmfence1 \ + && add-apt-repository -y universe \ + && add-apt-repository -y multiverse \ + && rm -rf /var/lib/apt/lists/* + +ENV VIRTUAL_ENV=/opt/venv \ + PATH="/opt/venv/bin:$PATH" + +RUN python -m venv "${VIRTUAL_ENV}" \ + && pip install --no-cache-dir --upgrade pip + +RUN set -eux; \ + arch="${TARGETARCH:-}"; \ + if [ -z "${arch}" ]; then arch="$(dpkg --print-architecture)"; fi; \ + mkdir -p /opt/fingerprint-chromium; \ + case "${arch}" in \ + amd64|x86_64) \ + curl -L --fail --retry 5 --retry-delay 3 --retry-all-errors "${FINGERPRINT_CHROMIUM_URL_AMD64}" -o /tmp/fingerprint-chromium.tar.xz; \ + tar -xf /tmp/fingerprint-chromium.tar.xz -C /opt/fingerprint-chromium --strip-components=1; \ + rm -f /tmp/fingerprint-chromium.tar.xz; \ + ;; \ + arm64|aarch64) \ + curl -L --fail --retry 5 
--retry-delay 3 --retry-all-errors "${FINGERPRINT_CHROMIUM_URL_ARM64_CHROMIUM_DEB}" -o /tmp/ungoogled-chromium.deb; \ + curl -L --fail --retry 5 --retry-delay 3 --retry-all-errors "${FINGERPRINT_CHROMIUM_URL_ARM64_COMMON_DEB}" -o /tmp/ungoogled-chromium-common.deb; \ + curl -L --fail --retry 5 --retry-delay 3 --retry-all-errors "${FINGERPRINT_CHROMIUM_URL_ARM64_SANDBOX_DEB}" -o /tmp/ungoogled-chromium-sandbox.deb; \ + curl -L --fail --retry 5 --retry-delay 3 --retry-all-errors "${FINGERPRINT_CHROMIUM_URL_ARM64_L10N_DEB}" -o /tmp/ungoogled-chromium-l10n.deb; \ + apt-get update; \ + apt-get install -y --no-install-recommends /tmp/ungoogled-chromium.deb /tmp/ungoogled-chromium-common.deb /tmp/ungoogled-chromium-sandbox.deb /tmp/ungoogled-chromium-l10n.deb; \ + rm -rf /var/lib/apt/lists/* /tmp/ungoogled-chromium*.deb; \ + for bin in /usr/bin/ungoogled-chromium /usr/bin/chromium /usr/bin/chromium-browser; do \ + if [ -x "${bin}" ]; then ln -sf "${bin}" /opt/fingerprint-chromium/chrome; break; fi; \ + done; \ + test -x /opt/fingerprint-chromium/chrome; \ + ;; \ + *) \ + echo "Unsupported architecture: ${arch}" >&2; \ + exit 1; \ + ;; \ + esac + +COPY pyproject.toml /tmp/pyproject.toml +RUN python - <<'PY' +import subprocess +import tomllib + +with open("/tmp/pyproject.toml", "rb") as f: + project = tomllib.load(f)["project"] + +extra_deps = project.get("optional-dependencies", {}).get("postgres", []) +deps = [*project["dependencies"], *extra_deps] +subprocess.check_call(["pip", "install", "--no-cache-dir", *deps]) +PY + +COPY . 
/app + +RUN chmod +x /app/docker/entrypoint.sh + +VOLUME ["/data"] +EXPOSE 9000 + +ENTRYPOINT ["/app/docker/entrypoint.sh"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..203d2b510dae31a31a9c8019ad6bffc397fa185c --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +--- +title: Web2API +emoji: 🧩 +colorFrom: blue +colorTo: indigo +sdk: docker +app_port: 9000 +pinned: false +--- + +# Web2API + +Bridge Claude Web sessions to OpenAI / Anthropic compatible APIs. Runs as a Docker Space on Hugging Face. + +## Endpoints + +| Path | Protocol | Description | +|------|----------|-------------| +| `/claude/v1/models` | OpenAI | List available models | +| `/claude/v1/chat/completions` | OpenAI | Chat completions | +| `/claude/v1/messages` | Anthropic | Messages API | +| `/config` | — | Admin dashboard | + +## Supported models + +| Model ID | Upstream | Tier | Notes | +|----------|----------|------|-------| +| `claude-sonnet-4.6` | claude-sonnet-4-6 | Free | Sonnet 4.6 (default) | +| `claude-sonnet-4-5` | claude-sonnet-4-5 | Free | Sonnet 4.5 | +| `claude-sonnet-4-6-thinking` | claude-sonnet-4-6 | Free | Sonnet 4.6 extended thinking | +| `claude-sonnet-4-5-thinking` | claude-sonnet-4-5 | Free | Sonnet 4.5 extended thinking | +| `claude-haiku-4-5` | claude-haiku-4-5 | Pro | Haiku 4.5 (fastest) | +| `claude-haiku-4-5-thinking` | claude-haiku-4-5 | Pro | Haiku 4.5 extended thinking | +| `claude-opus-4-6` | claude-opus-4-6 | Pro | Opus 4.6 (most capable) | +| `claude-opus-4-6-thinking` | claude-opus-4-6 | Pro | Opus 4.6 extended thinking | + +Pro models require a Claude Pro subscription and must be enabled in the config page. + +## Quick start + +1. Set required secrets in Space settings +2. Open `/login` → `/config` +3. Add a proxy group and a Claude account with `auth.sessionKey` +4. (Optional) Enable Pro models toggle if your account has a Pro subscription +5. 
Call the API: + +```bash +# OpenAI format (streaming) +curl $SPACE_URL/claude/v1/chat/completions \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"model":"claude-sonnet-4.6","stream":true,"messages":[{"role":"user","content":"Hello"}]}' + +# Anthropic format (streaming) +curl $SPACE_URL/claude/v1/messages \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"model":"claude-sonnet-4.6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}' + +# Extended thinking +curl $SPACE_URL/claude/v1/chat/completions \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"model":"claude-sonnet-4-6-thinking","stream":true,"messages":[{"role":"user","content":"Solve this step by step: what is 23 * 47?"}]}' +``` + +## Required secrets + +| Secret | Purpose | +|--------|---------| +| `WEB2API_AUTH_API_KEY` | API auth key for `/claude/v1/*` | +| `WEB2API_AUTH_CONFIG_SECRET` | Password for `/login` and `/config` | +| `WEB2API_DATABASE_URL` | PostgreSQL URL for persistent config (optional) | + +## Recommended environment variables + +For a small CPU Space: + +``` +WEB2API_BROWSER_NO_SANDBOX=true +WEB2API_BROWSER_DISABLE_GPU=true +WEB2API_BROWSER_DISABLE_GPU_SANDBOX=true +WEB2API_SCHEDULER_RESIDENT_BROWSER_COUNT=0 +WEB2API_SCHEDULER_TAB_MAX_CONCURRENT=5 +WEB2API_BROWSER_CDP_PORT_COUNT=6 +``` diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7721092d729d864e4848de0a60bc919998d4193 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,5 @@ +""" +新架构核心包:插件式 Web2API,按 type 路由,浏览器/context/page/会话树形缓存。 +""" + +__all__ = [] diff --git a/core/account/__init__.py b/core/account/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d51aa139f0a0b7850cfcd65324dbad488eaca90 --- /dev/null +++ b/core/account/__init__.py @@ -0,0 +1,5 @@ +"""账号池:按 type 过滤、轮询获取 (ProxyGroup, 
Account)。""" + +from core.account.pool import AccountPool + +__all__ = ["AccountPool"] diff --git a/core/account/pool.py b/core/account/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..a49180666909c78cb6a27accd413eea70dc79988 --- /dev/null +++ b/core/account/pool.py @@ -0,0 +1,207 @@ +""" +账号池:从配置加载代理组与账号,按 type 轮询 acquire。 + +除基础的全局轮询外,还支持: + +- 按 proxy_key 反查代理组 +- 在指定代理组内选择某个 type 的可用账号 +- 排除当前账号后为 tab 切号选择备选账号 +- 在未打开浏览器的代理组中选择某个 type 的候选账号 +""" + +from dataclasses import replace +from typing import Iterator + +from core.config.schema import AccountConfig, ProxyGroupConfig +from core.constants import TIMEZONE +from core.runtime.keys import ProxyKey + + +class AccountPool: + """ + 多 IP / 多账号池,按 type 过滤后轮询。 + acquire(type) 返回 (ProxyGroupConfig, AccountConfig)。 + get_group_by_proxy_key / acquire_from_group 供现役浏览器复用时使用。 + """ + + def __init__(self, groups: list[ProxyGroupConfig]) -> None: + self._groups = list(groups) + self._indices: dict[str, int] = {} # type -> 全局轮询下标 + self._group_type_indices: dict[ + tuple[str, str], int + ] = {} # (fingerprint_id, type) -> 组内轮询下标 + + @classmethod + def from_groups(cls, groups: list[ProxyGroupConfig]) -> "AccountPool": + return cls(groups) + + def reload(self, groups: list[ProxyGroupConfig]) -> None: + """用新加载的配置替换当前组(如更新解冻时间后从 repository 重新 load_groups)。""" + self._groups = list(groups) + + def groups(self) -> list[ProxyGroupConfig]: + """返回当前全部代理组。""" + return list(self._groups) + + def _accounts_by_type( + self, type_name: str + ) -> Iterator[tuple[ProxyGroupConfig, AccountConfig]]: + """按 type 遍历所有 (group, account),仅包含当前可用的账号(解冻时间已过或未设置)。""" + for g in self._groups: + for a in g.accounts: + if a.type == type_name and a.is_available(): + yield g, a + + def acquire(self, type_name: str) -> tuple[ProxyGroupConfig, AccountConfig]: + """ + 按 type 轮询获取一组 (ProxyGroupConfig, AccountConfig)。 + 若该 type 无账号则抛出 ValueError。 + """ + pairs = list(self._accounts_by_type(type_name)) + if not pairs: + raise 
ValueError(f"没有类别为 {type_name!r} 的账号,请先在配置中添加") + n = len(pairs) + idx = self._indices.get(type_name, 0) % n + self._indices[type_name] = (idx + 1) % n + return pairs[idx] + + def account_id(self, group: ProxyGroupConfig, account: AccountConfig) -> str: + """生成账号唯一标识,用于会话缓存等。""" + return f"{group.fingerprint_id}:{account.name}" + + def get_account_by_id( + self, account_id: str + ) -> tuple[ProxyGroupConfig, AccountConfig] | None: + """根据 account_id(fingerprint_id:name)反查 (group, account),用于复用会话时取 auth。""" + for g in self._groups: + for a in g.accounts: + if self.account_id(g, a) == account_id: + return g, a + return None + + def get_group_by_proxy_key(self, proxy_key: ProxyKey) -> ProxyGroupConfig | None: + """根据 proxy_key(proxy_host, proxy_user, fingerprint_id, use_proxy, timezone)反查对应代理组。""" + pk_tz = getattr(proxy_key, "timezone", None) or TIMEZONE + for g in self._groups: + g_tz = g.timezone or TIMEZONE + if ( + g.proxy_host == proxy_key.proxy_host + and g.proxy_user == proxy_key.proxy_user + and g.fingerprint_id == proxy_key.fingerprint_id + and g.use_proxy == getattr(proxy_key, "use_proxy", True) + and g_tz == pk_tz + ): + return g + return None + + def acquire_from_group( + self, + group: ProxyGroupConfig, + type_name: str, + ) -> tuple[ProxyGroupConfig, AccountConfig] | None: + """ + 从指定 group 内按 type 轮询取一个账号;若无该 type 则返回 None。 + 供「现役浏览器对应 IP 组是否还有该 type 可用」时使用。 + """ + pairs = [(g, a) for g, a in self._accounts_by_type(type_name) if g is group] + if not pairs: + return None + n = len(pairs) + key = (group.fingerprint_id, type_name) + idx = self._group_type_indices.get(key, 0) % n + self._group_type_indices[key] = (idx + 1) % n + return pairs[idx] + + def available_accounts_in_group( + self, + group: ProxyGroupConfig, + type_name: str, + *, + exclude_account_ids: set[str] | None = None, + ) -> list[AccountConfig]: + """返回某代理组下指定 type 的全部可用账号,可排除若干 account_id。""" + exclude = exclude_account_ids or set() + return [ + a + for g, a in 
self._accounts_by_type(type_name) + if g is group and self.account_id(group, a) not in exclude + ] + + def has_available_account_in_group( + self, + group: ProxyGroupConfig, + type_name: str, + *, + exclude_account_ids: set[str] | None = None, + ) -> bool: + """判断某代理组下是否仍有指定 type 的可用账号。""" + return bool( + self.available_accounts_in_group( + group, + type_name, + exclude_account_ids=exclude_account_ids, + ) + ) + + def next_available_account_in_group( + self, + group: ProxyGroupConfig, + type_name: str, + *, + exclude_account_ids: set[str] | None = None, + ) -> AccountConfig | None: + """ + 在指定代理组内按轮询选择一个可用账号。 + 支持排除当前已绑定账号,用于 drained 后切号。 + """ + accounts = self.available_accounts_in_group( + group, + type_name, + exclude_account_ids=exclude_account_ids, + ) + if not accounts: + return None + n = len(accounts) + key = (group.fingerprint_id, type_name) + idx = self._group_type_indices.get(key, 0) % n + self._group_type_indices[key] = (idx + 1) % n + return accounts[idx] + + def next_available_pair( + self, + type_name: str, + *, + exclude_fingerprint_ids: set[str] | None = None, + ) -> tuple[ProxyGroupConfig, AccountConfig] | None: + """ + 全局按 type 轮询选择一个可用账号,可排除若干代理组。 + 用于“未打开浏览器的组里挑一个候选账号”。 + """ + exclude = exclude_fingerprint_ids or set() + pairs = [ + (g, a) + for g, a in self._accounts_by_type(type_name) + if g.fingerprint_id not in exclude + ] + if not pairs: + return None + n = len(pairs) + idx = self._indices.get(type_name, 0) % n + self._indices[type_name] = (idx + 1) % n + return pairs[idx] + + def update_account_unfreeze_at( + self, + fingerprint_id: str, + account_name: str, + unfreeze_at: int | None, + ) -> bool: + for group in self._groups: + if group.fingerprint_id != fingerprint_id: + continue + for index, account in enumerate(group.accounts): + if account.name != account_name: + continue + group.accounts[index] = replace(account, unfreeze_at=unfreeze_at) + return True + return False diff --git a/core/api/__init__.py b/core/api/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e321b7240453a563853417fa94b1a8dee3fe7f44 --- /dev/null +++ b/core/api/__init__.py @@ -0,0 +1,3 @@ +"""API 层:OpenAI 兼容路由、会话解析、聊天编排。""" + +__all__ = [] diff --git a/core/api/anthropic_routes.py b/core/api/anthropic_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..17e70108853a3bb405d44fd9a2485bf8e321a43d --- /dev/null +++ b/core/api/anthropic_routes.py @@ -0,0 +1,93 @@ +"""Anthropic 协议路由。""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Any + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from core.api.auth import require_api_key +from core.api.chat_handler import ChatHandler +from core.api.routes import get_chat_handler, resolve_request_model, check_pro_model_access +from core.protocol.anthropic import AnthropicProtocolAdapter +from core.protocol.service import CanonicalChatService + + +def create_anthropic_router() -> APIRouter: + router = APIRouter(dependencies=[Depends(require_api_key)]) + adapter = AnthropicProtocolAdapter() + + @router.post("/anthropic/{provider}/v1/messages") + async def messages( + provider: str, + request: Request, + handler: ChatHandler = Depends(get_chat_handler), + ) -> Any: + return await _messages(provider, request, handler) + + @router.post("/{provider}/v1/messages") + async def messages_legacy( + provider: str, + request: Request, + handler: ChatHandler = Depends(get_chat_handler), + ) -> Any: + return await _messages(provider, request, handler) + + async def _messages( + provider: str, + request: Request, + handler: ChatHandler, + ) -> Any: + raw_body = await request.json() + try: + canonical_req = resolve_request_model( + provider, + adapter.parse_request(provider, raw_body), + ) + except Exception as exc: + status, payload = adapter.render_error(exc) + return JSONResponse(status_code=status, content=payload) 
+ + pro_err = check_pro_model_access(request, provider, canonical_req.model) + if pro_err is not None: + return pro_err + + service = CanonicalChatService(handler) + if canonical_req.stream: + + async def sse_stream() -> AsyncIterator[str]: + try: + async for event in adapter.render_stream( + canonical_req, + service.stream_raw(canonical_req), + ): + yield event + except Exception as exc: + status, payload = adapter.render_error(exc) + del status + yield ( + "event: error\n" + f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + ) + + return StreamingResponse( + sse_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + try: + raw_events = await service.collect_raw(canonical_req) + return adapter.render_non_stream(canonical_req, raw_events) + except Exception as exc: + status, payload = adapter.render_error(exc) + return JSONResponse(status_code=status, content=payload) + + return router diff --git a/core/api/auth.py b/core/api/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec3a4998257dc7580001dfd46f0f51e6757dea0 --- /dev/null +++ b/core/api/auth.py @@ -0,0 +1,455 @@ +""" +API 与配置页鉴权。 + +- auth.api_key: 保护 /{type}/v1/* +- auth.config_secret: 保护 /config 与 /api/config、/api/types + +全局鉴权设置优先级:数据库 > 环境变量回退 > YAML > 默认值。 +config_secret 在文件模式下会回写为带前缀的 PBKDF2 哈希;环境变量回退模式下仅在内存中哈希。 +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import os +import re +import secrets +import time +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Any, Literal + +from fastapi import HTTPException, Request, status + +from core.config.repository import ( + APP_SETTING_AUTH_API_KEY, + APP_SETTING_AUTH_CONFIG_SECRET_HASH, + ConfigRepository, +) +from core.config.settings import ( + get, + get_config_path, + has_env_override, + load_config, + reset_cache, +) + +API_AUTH_REALM = 
"Bearer" +CONFIG_SECRET_PREFIX = "web2api_pbkdf2_sha256" +CONFIG_SECRET_ITERATIONS = 600_000 +ADMIN_SESSION_COOKIE = "web2api_admin_session" +DEFAULT_ADMIN_SESSION_TTL_SECONDS = 7 * 24 * 60 * 60 +DEFAULT_ADMIN_LOGIN_MAX_FAILURES = 5 +DEFAULT_ADMIN_LOGIN_LOCK_SECONDS = 10 * 60 +AuthSource = Literal["env", "db", "yaml", "default"] + + +@dataclass(frozen=True) +class EffectiveAuthSettings: + api_key_text: str + api_key_source: AuthSource + config_secret_hash: str + config_secret_source: AuthSource + + @property + def api_keys(self) -> list[str]: + return parse_api_keys(self.api_key_text) + + @property + def api_key_env_managed(self) -> bool: + return False + + @property + def config_secret_env_managed(self) -> bool: + return False + + @property + def config_login_enabled(self) -> bool: + return bool(self.config_secret_hash) + + +def parse_api_keys(raw: Any) -> list[str]: + if isinstance(raw, list): + return [str(item).strip() for item in raw if str(item).strip()] + if raw is None: + return [] + text = str(raw).replace("\n", ",") + return [part.strip() for part in text.split(",") if part.strip()] + + +def normalize_api_key_text(raw: Any) -> str: + if isinstance(raw, list): + return "\n".join(str(item).strip() for item in raw if str(item).strip()) + return str(raw or "").strip() + + +def _yaml_auth_config() -> dict[str, Any]: + auth_cfg = load_config().get("auth") or {} + return auth_cfg if isinstance(auth_cfg, dict) else {} + + +def _normalize_config_secret_hash(value: Any) -> str: + secret = str(value or "").strip() + if not secret: + return "" + return secret if _is_hashed_config_secret(secret) else hash_config_secret(secret) + + +@lru_cache(maxsize=1) +def _hosted_config_secret_hash() -> str: + return _normalize_config_secret_hash(get("auth", "config_secret", "")) + + +def build_effective_auth_settings( + repo: ConfigRepository | None = None, +) -> EffectiveAuthSettings: + stored = repo.load_app_settings() if repo is not None else {} + yaml_auth = 
_yaml_auth_config() + + if APP_SETTING_AUTH_API_KEY in stored: + api_key_text = normalize_api_key_text(stored.get(APP_SETTING_AUTH_API_KEY, "")) + api_key_source: AuthSource = "db" + elif has_env_override("auth", "api_key"): + api_key_text = normalize_api_key_text(get("auth", "api_key", "")) + api_key_source = "env" + elif "api_key" in yaml_auth: + api_key_text = normalize_api_key_text(yaml_auth.get("api_key", "")) + api_key_source = "yaml" + else: + api_key_text = "" + api_key_source = "default" + + if APP_SETTING_AUTH_CONFIG_SECRET_HASH in stored: + config_secret_hash = _normalize_config_secret_hash( + stored.get(APP_SETTING_AUTH_CONFIG_SECRET_HASH, "") + ) + config_secret_source: AuthSource = "db" + elif has_env_override("auth", "config_secret"): + config_secret_hash = _hosted_config_secret_hash() + config_secret_source = "env" + elif "config_secret" in yaml_auth: + config_secret_hash = _normalize_config_secret_hash(yaml_auth.get("config_secret", "")) + config_secret_source = "yaml" + else: + config_secret_hash = "" + config_secret_source = "default" + + return EffectiveAuthSettings( + api_key_text=api_key_text, + api_key_source=api_key_source, + config_secret_hash=config_secret_hash, + config_secret_source=config_secret_source, + ) + + +def refresh_runtime_auth_settings(app: Any) -> EffectiveAuthSettings: + repo = getattr(app.state, "config_repo", None) + settings = build_effective_auth_settings(repo) + app.state.auth_settings = settings + return settings + + +def get_effective_auth_settings(request: Request | None = None) -> EffectiveAuthSettings: + if request is not None: + settings = getattr(request.app.state, "auth_settings", None) + if isinstance(settings, EffectiveAuthSettings): + return settings + repo = getattr(request.app.state, "config_repo", None) + return build_effective_auth_settings(repo) + return build_effective_auth_settings() + + +def configured_api_keys(repo: ConfigRepository | None = None) -> list[str]: + return 
build_effective_auth_settings(repo).api_keys + + +def _extract_request_api_key(request: Request) -> str: + key = (request.headers.get("x-api-key") or "").strip() + if key: + return key + authorization = (request.headers.get("authorization") or "").strip() + if authorization.lower().startswith("bearer "): + return authorization[7:].strip() + return "" + + +def require_api_key(request: Request) -> None: + expected_keys = get_effective_auth_settings(request).api_keys + if not expected_keys: + return + provided = _extract_request_api_key(request) + if provided: + for expected in expected_keys: + if secrets.compare_digest(provided, expected): + return + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Unauthorized. Provide a valid API key.", + headers={"WWW-Authenticate": API_AUTH_REALM}, + ) + + +def _is_hashed_config_secret(value: str) -> bool: + return value.startswith(f"{CONFIG_SECRET_PREFIX}$") + + +def configured_config_secret_hash(repo: ConfigRepository | None = None) -> str: + return build_effective_auth_settings(repo).config_secret_hash + + +def config_login_enabled(request: Request | None = None) -> bool: + return get_effective_auth_settings(request).config_login_enabled + + +def configured_config_login_max_failures() -> int: + raw = get("auth", "config_login_max_failures", DEFAULT_ADMIN_LOGIN_MAX_FAILURES) + try: + return max(1, int(raw)) + except Exception: + return DEFAULT_ADMIN_LOGIN_MAX_FAILURES + + +def configured_config_login_lock_seconds() -> int: + raw = get("auth", "config_login_lock_seconds", DEFAULT_ADMIN_LOGIN_LOCK_SECONDS) + try: + return max(1, int(raw)) + except Exception: + return DEFAULT_ADMIN_LOGIN_LOCK_SECONDS + + +def hash_config_secret(secret: str) -> str: + salt = os.urandom(16) + digest = hashlib.pbkdf2_hmac( + "sha256", + secret.encode("utf-8"), + salt, + CONFIG_SECRET_ITERATIONS, + ) + return ( + f"{CONFIG_SECRET_PREFIX}" + f"${CONFIG_SECRET_ITERATIONS}" + f"${base64.urlsafe_b64encode(salt).decode('ascii')}" 
+ f"${base64.urlsafe_b64encode(digest).decode('ascii')}" + ) + + +def verify_config_secret(secret: str, encoded: str) -> bool: + try: + prefix, iterations_s, salt_b64, digest_b64 = encoded.split("$", 3) + except ValueError: + return False + if prefix != CONFIG_SECRET_PREFIX: + return False + try: + iterations = int(iterations_s) + salt = base64.urlsafe_b64decode(salt_b64.encode("ascii")) + expected = base64.urlsafe_b64decode(digest_b64.encode("ascii")) + except Exception: + return False + actual = hashlib.pbkdf2_hmac( + "sha256", + secret.encode("utf-8"), + salt, + iterations, + ) + return hmac.compare_digest(actual, expected) + + +def ensure_config_secret_hashed(repo: ConfigRepository | None = None) -> None: + if has_env_override("auth", "config_secret"): + _hosted_config_secret_hash() + return + if repo is not None and repo.get_app_setting(APP_SETTING_AUTH_CONFIG_SECRET_HASH) is not None: + return + cfg = load_config() + auth_cfg = cfg.get("auth") + if not isinstance(auth_cfg, dict): + return + raw_value = auth_cfg.get("config_secret") + secret = str(raw_value or "").strip() + if not secret or _is_hashed_config_secret(secret): + return + encoded = hash_config_secret(secret) + config_path = get_config_path() + if not config_path.exists(): + return + original = config_path.read_text(encoding="utf-8") + pattern = re.compile(r"^([ \t]*)config_secret\s*:\s*.*$", re.MULTILINE) + replacement = None + for line in original.splitlines(): + match = pattern.match(line) + if match: + replacement = f"{match.group(1)}config_secret: '{encoded}'" + break + updated: str + if replacement is not None: + updated, count = pattern.subn(replacement, original, count=1) + if count != 1: + return + else: + auth_pattern = re.compile(r"^auth\s*:\s*$", re.MULTILINE) + match = auth_pattern.search(original) + if match: + insert_at = match.end() + updated = ( + original[:insert_at] + + "\n" + + f" config_secret: '{encoded}'" + + original[insert_at:] + ) + else: + suffix = "" if 
original.endswith("\n") or not original else "\n" + updated = ( + original + + suffix + + "auth:\n" + + f" config_secret: '{encoded}'\n" + ) + tmp_path = config_path.with_suffix(config_path.suffix + ".tmp") + tmp_path.write_text(updated, encoding="utf-8") + tmp_path.replace(config_path) + reset_cache() + load_config() + + +@dataclass +class AdminSessionStore: + ttl_seconds: int = DEFAULT_ADMIN_SESSION_TTL_SECONDS + _sessions: dict[str, float] = field(default_factory=dict) + + def create(self) -> str: + token = secrets.token_urlsafe(32) + self._sessions[token] = time.time() + self.ttl_seconds + return token + + def is_valid(self, token: str) -> bool: + if not token: + return False + self.cleanup() + expires_at = self._sessions.get(token) + if expires_at is None: + return False + if expires_at < time.time(): + self._sessions.pop(token, None) + return False + return True + + def revoke(self, token: str) -> None: + if token: + self._sessions.pop(token, None) + + def cleanup(self) -> None: + now = time.time() + expired = [token for token, expires_at in self._sessions.items() if expires_at < now] + for token in expired: + self._sessions.pop(token, None) + + +@dataclass +class LoginAttemptState: + failures: int = 0 + locked_until: float = 0.0 + last_seen: float = 0.0 + + +@dataclass +class AdminLoginAttemptStore: + max_failures: int = DEFAULT_ADMIN_LOGIN_MAX_FAILURES + lock_seconds: int = DEFAULT_ADMIN_LOGIN_LOCK_SECONDS + _attempts: dict[str, LoginAttemptState] = field(default_factory=dict) + + def is_locked(self, client_ip: str) -> int: + self.cleanup() + state = self._attempts.get(client_ip) + if state is None: + return 0 + remaining = int(state.locked_until - time.time()) + if remaining <= 0: + return 0 + return remaining + + def record_failure(self, client_ip: str) -> int: + now = time.time() + state = self._attempts.setdefault(client_ip, LoginAttemptState()) + if state.locked_until > now: + state.last_seen = now + return int(state.locked_until - now) + 
state.failures += 1 + state.last_seen = now + if state.failures >= self.max_failures: + state.failures = 0 + state.locked_until = now + self.lock_seconds + return self.lock_seconds + return 0 + + def record_success(self, client_ip: str) -> None: + self._attempts.pop(client_ip, None) + + def cleanup(self) -> None: + now = time.time() + stale_before = now - max(self.lock_seconds * 2, 3600) + expired = [ + ip + for ip, state in self._attempts.items() + if state.locked_until <= now and state.last_seen < stale_before + ] + for ip in expired: + self._attempts.pop(ip, None) + + +def _admin_store(request: Request) -> AdminSessionStore: + store = getattr(request.app.state, "admin_sessions", None) + if store is None: + raise HTTPException(status_code=503, detail="Admin session store is unavailable") + return store + + +def _admin_login_attempt_store(request: Request) -> AdminLoginAttemptStore: + store = getattr(request.app.state, "admin_login_attempts", None) + if store is None: + raise HTTPException(status_code=503, detail="Login rate limiter is unavailable") + return store + + +def client_ip_of(request: Request) -> str: + client = getattr(request, "client", None) + host = getattr(client, "host", None) + return str(host or "unknown") + + +def check_admin_login_rate_limit(request: Request) -> None: + remaining = _admin_login_attempt_store(request).is_locked(client_ip_of(request)) + if remaining > 0: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=f"Too many failed login attempts. 
Try again in {remaining} seconds.", + ) + + +def record_admin_login_failure(request: Request) -> int: + return _admin_login_attempt_store(request).record_failure(client_ip_of(request)) + + +def record_admin_login_success(request: Request) -> None: + _admin_login_attempt_store(request).record_success(client_ip_of(request)) + + +def admin_logged_in(request: Request) -> bool: + if not config_login_enabled(request): + return False + token = (request.cookies.get(ADMIN_SESSION_COOKIE) or "").strip() + return _admin_store(request).is_valid(token) + + +def require_config_login_enabled(request: Request | None = None) -> None: + if not config_login_enabled(request): + raise HTTPException(status_code=404, detail="Config dashboard is disabled") + + +def require_config_login(request: Request) -> None: + require_config_login_enabled(request) + if admin_logged_in(request): + return + raise HTTPException(status_code=401, detail="Please sign in to access the config dashboard") diff --git a/core/api/chat_handler.py b/core/api/chat_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..1a630c74a383d6de8297d56dfd4ecc8ebc39974f --- /dev/null +++ b/core/api/chat_handler.py @@ -0,0 +1,1106 @@ +""" +聊天请求编排:解析 session_id、调度 browser/tab/session、调用插件流式补全, +并在响应末尾附加零宽字符编码的会话 ID。 + +当前调度模型: + +- 一个浏览器对应一个代理组 +- 一个浏览器内,一个 type 只有一个 tab +- 一个 tab 绑定一个 account,只有 drained 后才能切号 +- 一个 session 绑定到某个 tab/account;复用成功时不传完整历史 +- 无法复用时,新建会话并回放完整历史 +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AsyncIterator, cast + +from playwright.async_api import BrowserContext, Page + +from core.account.pool import AccountPool +from core.config.repository import ConfigRepository +from core.config.schema import AccountConfig, ProxyGroupConfig +from core.config.settings import get +from core.constants import TIMEZONE +from core.plugin.base 
import ( + AccountFrozenError, + BaseSitePlugin, + BrowserResourceInvalidError, + PluginRegistry, +) +from core.plugin.helpers import clear_cookies_for_domain +from core.runtime.browser_manager import BrowserManager, ClosedTabInfo, TabRuntime +from core.runtime.keys import ProxyKey +from core.runtime.local_proxy_forwarder import LocalProxyForwarder, UpstreamProxy, parse_proxy_server +from core.runtime.session_cache import SessionCache, SessionEntry + +from core.api.conv_parser import parse_conv_uuid_from_messages, session_id_suffix +from core.api.fingerprint import compute_conversation_fingerprint +from core.api.react import format_react_prompt +from core.api.schemas import OpenAIChatRequest, extract_user_content +from core.hub.schemas import OpenAIStreamEvent +from core.runtime.conversation_index import ConversationIndex + +logger = logging.getLogger(__name__) + + +def _request_messages_as_dicts(req: OpenAIChatRequest) -> list[dict[str, Any]]: + """转为 conv_parser 需要的 list[dict]。""" + out: list[dict[str, Any]] = [] + for m in req.messages: + d: dict[str, Any] = {"role": m.role} + if isinstance(m.content, list): + d["content"] = [p.model_dump() for p in m.content] + else: + d["content"] = m.content + out.append(d) + return out + + +def _proxy_key_for_group(group: ProxyGroupConfig) -> ProxyKey: + return ProxyKey( + group.proxy_host, + group.proxy_user, + group.fingerprint_id, + group.use_proxy, + group.timezone or TIMEZONE, + ) + + +@dataclass +class _RequestTarget: + proxy_key: ProxyKey + group: ProxyGroupConfig + account: AccountConfig + context: BrowserContext + page: Page + session_id: str | None + full_history: bool + proxy_url: str | None = None + proxy_auth: tuple[str, str] | None = None + proxy_forwarder: LocalProxyForwarder | None = None + + +class ChatHandler: + """编排一次 chat 请求:会话解析、tab 调度、插件调用。""" + + def __init__( + self, + pool: AccountPool, + session_cache: SessionCache, + browser_manager: BrowserManager, + config_repo: ConfigRepository | None = None, + 
) -> None: + self._pool = pool + self._session_cache = session_cache + self._browser_manager = browser_manager + self._config_repo = config_repo + self._conv_index = ConversationIndex() + self._schedule_lock = asyncio.Lock() + self._stop_event = asyncio.Event() + self._busy_sessions: set[str] = set() + self._tab_max_concurrent = int(get("scheduler", "tab_max_concurrent") or 5) + self._gc_interval_seconds = float( + get("scheduler", "browser_gc_interval_seconds") or 300 + ) + self._tab_idle_seconds = float(get("scheduler", "tab_idle_seconds") or 900) + self._resident_browser_count = int( + get("scheduler", "resident_browser_count", 1) + ) + + def reload_pool( + self, + groups: list[ProxyGroupConfig], + config_repo: ConfigRepository | None = None, + ) -> None: + """配置热更新后替换账号池与 repository。""" + self._pool.reload(groups) + if config_repo is not None: + self._config_repo = config_repo + + async def refresh_configuration( + self, + groups: list[ProxyGroupConfig], + config_repo: ConfigRepository | None = None, + ) -> None: + """配置热更新:替换账号池、清理失效资源,并重新预热常驻浏览器。""" + async with self._schedule_lock: + self.reload_pool(groups, config_repo) + await self._prune_invalid_resources_locked() + await self._reconcile_tabs_locked() + await self.prewarm_resident_browsers() + + async def prewarm_resident_browsers(self) -> None: + """启动时预热常驻浏览器,并为其下可用 type 建立 tab。""" + async with self._schedule_lock: + warmed = 0 + for group in self._pool.groups(): + if warmed >= self._resident_browser_count: + break + available_types = { + a.type + for a in group.accounts + if a.is_available() and PluginRegistry.get(a.type) is not None + } + if not available_types: + continue + proxy_key = _proxy_key_for_group(group) + await self._browser_manager.ensure_browser(proxy_key, group.proxy_pass) + for type_name in sorted(available_types): + if self._browser_manager.get_tab(proxy_key, type_name) is not None: + continue + account = self._pool.available_accounts_in_group(group, type_name) + if not account: + 
continue + chosen = account[0] + plugin = PluginRegistry.get(type_name) + if plugin is None: + continue + await self._browser_manager.open_tab( + proxy_key, + group.proxy_pass, + type_name, + self._pool.account_id(group, chosen), + plugin.create_page, + self._make_apply_auth_fn(plugin, chosen), + ) + warmed += 1 + + async def run_maintenance_loop(self) -> None: + """周期性回收空闲浏览器,并收尾 drained/frozen tab。""" + while not self._stop_event.is_set(): + try: + await asyncio.wait_for( + self._stop_event.wait(), + timeout=self._gc_interval_seconds, + ) + break + except asyncio.TimeoutError: + pass + + try: + async with self._schedule_lock: + # Evict stale sessions to prevent unbounded accumulation. + stale_ids = self._session_cache.evict_stale() + # Evict stale fingerprint entries in sync. + stale_fp = self._conv_index.evict_stale(ttl=1800.0) + if stale_fp: + logger.info( + "[maintenance] evicted %d stale fingerprint entries", + len(stale_fp), + ) + if stale_ids: + for sid in stale_ids: + plugin_type = None + for pk, entry in self._browser_manager.list_browser_entries(): + for tn, tab in entry.tabs.items(): + if sid in tab.sessions: + tab.sessions.discard(sid) + plugin_type = tn + break + if plugin_type: + plugin = PluginRegistry.get(plugin_type) + if plugin is not None: + plugin.drop_session(sid) + logger.info( + "[maintenance] evicted %d stale sessions, cache size=%d", + len(stale_ids), + len(self._session_cache), + ) + await self._reconcile_tabs_locked() + closed = await self._browser_manager.collect_idle_browsers( + idle_seconds=self._tab_idle_seconds, + resident_browser_count=self._resident_browser_count, + ) + self._apply_closed_tabs_locked(closed) + except Exception: + logger.exception("维护循环执行失败") + + async def shutdown(self) -> None: + """停止维护循环并关闭全部浏览器。""" + self._stop_event.set() + async with self._schedule_lock: + closed = await self._browser_manager.close_all() + self._apply_closed_tabs_locked(closed) + + def report_account_unfreeze( + self, + fingerprint_id: str, 
+ account_name: str, + unfreeze_at: int, + ) -> None: + """记录账号解冻时间,并同步更新内存账号池。""" + if self._config_repo is None: + return + self._config_repo.update_account_unfreeze_at( + fingerprint_id, account_name, unfreeze_at + ) + self._pool.update_account_unfreeze_at( + fingerprint_id, + account_name, + unfreeze_at, + ) + + def get_account_runtime_status(self) -> dict[str, dict[str, Any]]: + """返回当前账号运行时状态,供配置页展示角标。""" + status: dict[str, dict[str, Any]] = {} + for proxy_key, entry in self._browser_manager.list_browser_entries(): + for type_name, tab in entry.tabs.items(): + status[tab.account_id] = { + "fingerprint_id": proxy_key.fingerprint_id, + "type": type_name, + "is_active": True, + "tab_state": tab.state, + "accepting_new": tab.accepting_new, + "active_requests": tab.active_requests, + "frozen_until": tab.frozen_until, + } + return status + + def _make_apply_auth_fn( + self, + plugin: Any, + account: AccountConfig, + ) -> Any: + async def _apply_auth(context: BrowserContext, page: Page) -> None: + await plugin.apply_auth(context, page, account.auth) + + return _apply_auth + + def _apply_closed_tabs_locked(self, closed_tabs: list[ClosedTabInfo]) -> None: + for info in closed_tabs: + self._session_cache.delete_many(info.session_ids) + plugin = PluginRegistry.get(info.type_name) + if plugin is not None: + plugin.drop_sessions(info.session_ids) + + def _stream_proxy_settings( + self, + target: _RequestTarget, + ) -> tuple[str | None, tuple[str, str] | None, LocalProxyForwarder | None]: + if not target.proxy_key.use_proxy: + return (None, None, None) + upstream_host, upstream_port = parse_proxy_server(target.proxy_key.proxy_host) + forwarder = LocalProxyForwarder( + UpstreamProxy( + host=upstream_host, + port=upstream_port, + username=target.proxy_key.proxy_user, + password=target.group.proxy_pass, + ), + listen_host="127.0.0.1", + listen_port=0, + on_log=lambda msg: logger.debug("[stream-proxy] %s", msg), + ) + forwarder.start() + return ( + forwarder.proxy_url, + 
None, + forwarder, + ) + + async def _clear_tab_domain_cookies_if_supported( + self, proxy_key: ProxyKey, type_name: str + ) -> None: + """关 tab 前清该 type 对应域名的 cookie(仅支持带 site.cookie_domain 的插件)。""" + entry = self._browser_manager.get_browser_entry(proxy_key) + if entry is None: + return + plugin = PluginRegistry.get(type_name) + if not isinstance(plugin, BaseSitePlugin) or not getattr(plugin, "site", None): + return + try: + await clear_cookies_for_domain(entry.context, plugin.site.cookie_domain) + except Exception as e: + logger.debug("关 tab 前清 cookie 失败 type=%s: %s", type_name, e) + + async def _prune_invalid_resources_locked(self) -> None: + """关闭配置中已不存在的浏览器/tab,避免热更新后继续使用失效资源。""" + for proxy_key, entry in list(self._browser_manager.list_browser_entries()): + group = self._pool.get_group_by_proxy_key(proxy_key) + if group is None: + self._apply_closed_tabs_locked( + await self._browser_manager.close_browser(proxy_key) + ) + continue + for type_name in list(entry.tabs.keys()): + tab = entry.tabs[type_name] + pair = self._pool.get_account_by_id(tab.account_id) + if ( + pair is None + or pair[0] is not group + or pair[1].type != type_name + or not pair[1].enabled + ): + self._invalidate_tab_sessions_locked(proxy_key, type_name) + if tab.active_requests == 0: + # 与 reconcile 一致:优先同组同一页 re-auth,失败或无可用账号再关 tab + switched = False + group = self._pool.get_group_by_proxy_key(proxy_key) + if group is not None: + next_account = self._pool.next_available_account_in_group( + group, + type_name, + exclude_account_ids={tab.account_id}, + ) + if next_account is not None: + plugin = PluginRegistry.get(type_name) + if plugin is not None: + switched = ( + await self._browser_manager.switch_tab_account( + proxy_key, + type_name, + self._pool.account_id(group, next_account), + self._make_apply_auth_fn( + plugin, + next_account, + ), + ) + ) + if not switched: + await self._clear_tab_domain_cookies_if_supported( + proxy_key, type_name + ) + closed = await 
self._browser_manager.close_tab( + proxy_key, type_name + ) + if closed is not None: + self._apply_closed_tabs_locked([closed]) + else: + self._browser_manager.mark_tab_draining(proxy_key, type_name) + + def _invalidate_session_locked( + self, + session_id: str, + entry: SessionEntry | None = None, + ) -> None: + entry = entry or self._session_cache.get(session_id) + if entry is None: + return + self._session_cache.delete(session_id) + self._conv_index.remove_session(session_id) + self._browser_manager.unregister_session( + entry.proxy_key, + entry.type_name, + session_id, + ) + plugin = PluginRegistry.get(entry.type_name) + if plugin is not None: + plugin.drop_session(session_id) + + def _invalidate_tab_sessions_locked( + self, + proxy_key: ProxyKey, + type_name: str, + ) -> None: + tab = self._browser_manager.get_tab(proxy_key, type_name) + if tab is None or not tab.sessions: + return + session_ids = list(tab.sessions) + self._session_cache.delete_many(session_ids) + plugin = PluginRegistry.get(type_name) + if plugin is not None: + plugin.drop_sessions(session_ids) + tab.sessions.clear() + + async def _recover_browser_resource_invalid_locked( + self, + type_name: str, + target: _RequestTarget, + request_id: str, + active_session_id: str | None, + error: BrowserResourceInvalidError, + attempt: int, + max_retries: int, + ) -> None: + account_id = self._pool.account_id(target.group, target.account) + diagnostics = self._browser_manager.browser_diagnostics(target.proxy_key) + logger.warning( + "[chat] browser resource invalid attempt=%s/%s type=%s proxy=%s account=%s session_id=%s request_id=%s resource=%s helper=%s stage=%s stream_phase=%s browser_present=%s proc_alive=%s cdp_listening=%s tab_count=%s active_requests=%s err=%s", + attempt + 1, + max_retries, + type_name, + target.proxy_key.fingerprint_id, + account_id, + active_session_id, + request_id, + error.resource_hint, + error.helper_name, + error.stage, + error.stream_phase, + 
diagnostics.get("browser_present"), + diagnostics.get("proc_alive"), + diagnostics.get("cdp_listening"), + diagnostics.get("tab_count"), + diagnostics.get("active_requests"), + error, + ) + stderr_tail = str(diagnostics.get("stderr_tail") or "").strip() + if stderr_tail: + logger.warning( + "[chat] browser resource invalid stderr tail proxy=%s request_id=%s:\n%s", + target.proxy_key.fingerprint_id, + request_id, + stderr_tail, + ) + + if active_session_id is not None: + self._invalidate_session_locked(active_session_id) + if error.resource_hint == "transport": + logger.warning( + "[chat] transport-level stream failure, keep tab/browser and retry proxy=%s request_id=%s", + target.proxy_key.fingerprint_id, + request_id, + ) + return + self._browser_manager.mark_tab_draining(target.proxy_key, type_name) + + browser_restart_reason: str | None = None + if error.resource_hint == "browser": + browser_restart_reason = "resource_hint" + # Legacy: page_fetch transport is no longer used by Claude (context_request since v0.x). + # Kept for potential future plugins that still use page_fetch transport. 
+ elif ( + error.helper_name == "stream_raw_via_page_fetch" + and error.stage in {"read_timeout", "evaluate_timeout"} + ): + browser_restart_reason = f"{error.helper_name}:{error.stage}" + + if browser_restart_reason is not None: + logger.warning( + "[chat] escalating browser recovery to full restart proxy=%s request_id=%s reason=%s", + target.proxy_key.fingerprint_id, + request_id, + browser_restart_reason, + ) + closed = await self._browser_manager.close_browser(target.proxy_key) + self._apply_closed_tabs_locked(closed) + return + + self._invalidate_tab_sessions_locked(target.proxy_key, type_name) + closed = await self._browser_manager.close_tab(target.proxy_key, type_name) + if closed is not None: + self._apply_closed_tabs_locked([closed]) + + def _revive_tab_if_possible_locked( + self, + proxy_key: ProxyKey, + type_name: str, + ) -> bool: + tab = self._browser_manager.get_tab(proxy_key, type_name) + if tab is None or tab.active_requests != 0: + return False + if tab.accepting_new: + return True + + pair = self._pool.get_account_by_id(tab.account_id) + if pair is None: + return False + _, account = pair + if not account.is_available(): + return False + tab.accepting_new = True + tab.state = "ready" + tab.frozen_until = None + tab.last_used_at = time.time() + return True + + async def _reconcile_tabs_locked(self) -> None: + """ + 收尾所有 non-ready tab: + + - 若原账号已恢复可用,则恢复 tab + - 否则若同组有其他可用账号,则在 drained 后切号 + - 否则关闭 tab + """ + for proxy_key, entry in list(self._browser_manager.list_browser_entries()): + for type_name in list(entry.tabs.keys()): + tab = entry.tabs[type_name] + if tab.accepting_new: + continue + if tab.active_requests != 0: + continue + if self._revive_tab_if_possible_locked(proxy_key, type_name): + continue + + group = self._pool.get_group_by_proxy_key(proxy_key) + if group is None: + await self._clear_tab_domain_cookies_if_supported( + proxy_key, type_name + ) + closed = await self._browser_manager.close_tab(proxy_key, type_name) + if closed is not 
None: + self._apply_closed_tabs_locked([closed]) + continue + + next_account = self._pool.next_available_account_in_group( + group, + type_name, + exclude_account_ids={tab.account_id}, + ) + if next_account is not None: + plugin = PluginRegistry.get(type_name) + if plugin is None: + continue + self._invalidate_tab_sessions_locked(proxy_key, type_name) + switched = await self._browser_manager.switch_tab_account( + proxy_key, + type_name, + self._pool.account_id(group, next_account), + self._make_apply_auth_fn(plugin, next_account), + ) + if switched: + continue + + await self._clear_tab_domain_cookies_if_supported(proxy_key, type_name) + closed = await self._browser_manager.close_tab(proxy_key, type_name) + if closed is not None: + self._apply_closed_tabs_locked([closed]) + + async def _reuse_session_target_locked( + self, + plugin: Any, + type_name: str, + session_id: str, + ) -> _RequestTarget | None: + entry = self._session_cache.get(session_id) + if entry is None or entry.type_name != type_name: + return None + + pair = self._pool.get_account_by_id(entry.account_id) + if pair is None: + self._invalidate_session_locked(session_id, entry) + return None + group, account = pair + + tab = self._browser_manager.get_tab(entry.proxy_key, type_name) + if ( + tab is None + or tab.account_id != entry.account_id + or not plugin.has_session(session_id) + ): + self._invalidate_session_locked(session_id, entry) + return None + + if not tab.accepting_new: + self._invalidate_session_locked(session_id, entry) + return None + if session_id in self._busy_sessions: + raise RuntimeError("当前会话正在处理中,请稍后再试") + if tab.active_requests >= self._tab_max_concurrent: + raise RuntimeError("当前会话所在 tab 繁忙,请稍后再试") + + page = self._browser_manager.acquire_tab( + entry.proxy_key, + type_name, + self._tab_max_concurrent, + ) + if page is None: + raise RuntimeError("当前会话暂不可复用,请稍后再试") + + self._session_cache.touch(session_id) + self._busy_sessions.add(session_id) + context = await 
self._browser_manager.ensure_browser( + entry.proxy_key, + group.proxy_pass, + ) + return _RequestTarget( + proxy_key=entry.proxy_key, + group=group, + account=account, + context=context, + page=page, + session_id=session_id, + full_history=False, + ) + + async def _allocate_new_target_locked( + self, + type_name: str, + ) -> _RequestTarget: + # 1. 已打开浏览器里已有该 type 的可服务 tab,直接复用。 + existing_tabs: list[tuple[int, float, ProxyKey, TabRuntime]] = [] + for proxy_key, entry in self._browser_manager.list_browser_entries(): + tab = entry.tabs.get(type_name) + if ( + tab is not None + and tab.accepting_new + and tab.active_requests < self._tab_max_concurrent + ): + existing_tabs.append( + (tab.active_requests, tab.last_used_at, proxy_key, tab) + ) + if existing_tabs: + _, _, proxy_key, tab = min(existing_tabs, key=lambda item: item[:2]) + pair = self._pool.get_account_by_id(tab.account_id) + if pair is None: + self._invalidate_tab_sessions_locked(proxy_key, type_name) + closed = await self._browser_manager.close_tab(proxy_key, type_name) + if closed is not None: + self._apply_closed_tabs_locked([closed]) + else: + group, account = pair + page = self._browser_manager.acquire_tab( + proxy_key, + type_name, + self._tab_max_concurrent, + ) + if page is not None: + context = await self._browser_manager.ensure_browser( + proxy_key, + group.proxy_pass, + ) + return _RequestTarget( + proxy_key=proxy_key, + group=group, + account=account, + context=context, + page=page, + session_id=None, + full_history=True, + ) + + # 2. 
已打开浏览器里还没有该 type tab,但该组有可用账号,直接建新 tab。 + open_browser_candidates: list[ + tuple[int, float, ProxyKey, ProxyGroupConfig] + ] = [] + for proxy_key, entry in self._browser_manager.list_browser_entries(): + if type_name in entry.tabs: + continue + group = self._pool.get_group_by_proxy_key(proxy_key) + if group is None: + continue + if not self._pool.has_available_account_in_group(group, type_name): + continue + open_browser_candidates.append( + ( + self._browser_manager.browser_load(proxy_key), + entry.last_used_at, + proxy_key, + group, + ) + ) + if open_browser_candidates: + _, _, proxy_key, group = min( + open_browser_candidates, key=lambda item: item[:2] + ) + account = self._pool.next_available_account_in_group(group, type_name) + if account is not None: + plugin = PluginRegistry.get(type_name) + if plugin is None: + raise ValueError(f"未注册的 type: {type_name}") + await self._browser_manager.open_tab( + proxy_key, + group.proxy_pass, + type_name, + self._pool.account_id(group, account), + plugin.create_page, + self._make_apply_auth_fn(plugin, account), + ) + page = self._browser_manager.acquire_tab( + proxy_key, + type_name, + self._tab_max_concurrent, + ) + if page is None: + raise RuntimeError("新建 tab 后仍无法占用请求槽位") + context = await self._browser_manager.ensure_browser( + proxy_key, + group.proxy_pass, + ) + return _RequestTarget( + proxy_key=proxy_key, + group=group, + account=account, + context=context, + page=page, + session_id=None, + full_history=True, + ) + + # 3. 
已打开浏览器里该 type tab 已 drained,且同组有备用账号,可在当前 tab 切号。 + switch_candidates: list[tuple[float, ProxyKey, ProxyGroupConfig]] = [] + for proxy_key, entry in self._browser_manager.list_browser_entries(): + tab = entry.tabs.get(type_name) + if tab is None or tab.active_requests != 0: + continue + group = self._pool.get_group_by_proxy_key(proxy_key) + if group is None: + continue + if not self._pool.has_available_account_in_group( + group, + type_name, + exclude_account_ids={tab.account_id}, + ): + continue + switch_candidates.append((tab.last_used_at, proxy_key, group)) + if switch_candidates: + _, proxy_key, group = min(switch_candidates, key=lambda item: item[0]) + tab = self._browser_manager.get_tab(proxy_key, type_name) + plugin = PluginRegistry.get(type_name) + if tab is not None and plugin is not None: + next_account = self._pool.next_available_account_in_group( + group, + type_name, + exclude_account_ids={tab.account_id}, + ) + if next_account is not None: + self._invalidate_tab_sessions_locked(proxy_key, type_name) + switched = await self._browser_manager.switch_tab_account( + proxy_key, + type_name, + self._pool.account_id(group, next_account), + self._make_apply_auth_fn(plugin, next_account), + ) + if switched: + page = self._browser_manager.acquire_tab( + proxy_key, + type_name, + self._tab_max_concurrent, + ) + if page is None: + raise RuntimeError("切号后仍无法占用请求槽位") + context = await self._browser_manager.ensure_browser( + proxy_key, + group.proxy_pass, + ) + return _RequestTarget( + proxy_key=proxy_key, + group=group, + account=next_account, + context=context, + page=page, + session_id=None, + full_history=True, + ) + + # 4. 
开新浏览器。 + open_groups = { + proxy_key.fingerprint_id + for proxy_key in self._browser_manager.current_proxy_keys() + } + pair = self._pool.next_available_pair( + type_name, + exclude_fingerprint_ids=open_groups, + ) + if pair is None: + raise ValueError(f"没有类别为 {type_name!r} 的可用账号,请稍后再试") + group, account = pair + proxy_key = _proxy_key_for_group(group) + plugin = PluginRegistry.get(type_name) + if plugin is None: + raise ValueError(f"未注册的 type: {type_name}") + await self._browser_manager.open_tab( + proxy_key, + group.proxy_pass, + type_name, + self._pool.account_id(group, account), + plugin.create_page, + self._make_apply_auth_fn(plugin, account), + ) + page = self._browser_manager.acquire_tab( + proxy_key, + type_name, + self._tab_max_concurrent, + ) + if page is None: + raise RuntimeError("新浏览器建 tab 后仍无法占用请求槽位") + context = await self._browser_manager.ensure_browser( + proxy_key, group.proxy_pass + ) + return _RequestTarget( + proxy_key=proxy_key, + group=group, + account=account, + context=context, + page=page, + session_id=None, + full_history=True, + ) + + async def _stream_completion( + self, + type_name: str, + req: OpenAIChatRequest, + ) -> AsyncIterator[str]: + """ + 内部实现:调度 + 插件 stream_completion 字符串流,末尾附加 session_id 零宽编码。 + 对外仅通过 stream_openai_events() 暴露事件流。 + """ + plugin = PluginRegistry.get(type_name) + if plugin is None: + raise ValueError(f"未注册的 type: {type_name}") + + raw_messages = _request_messages_as_dicts(req) + conv_uuid = req.resume_session_id or parse_conv_uuid_from_messages(raw_messages) + + # Fingerprint matching: when the client doesn't preserve the zero-width + # session marker (conv_uuid is None), compute a fingerprint from + # system prompt + first user message and look up the matching session. + # This replaces sticky session and prevents context pollution. 
+ fingerprint = "" + if not conv_uuid: + fingerprint = compute_conversation_fingerprint(req.messages) + if fingerprint: + entry = self._conv_index.lookup(fingerprint) + if entry is not None: + conv_uuid = entry.session_id + + logger.info("[chat] type=%s parsed conv_uuid=%s fingerprint=%s", type_name, conv_uuid, fingerprint or "n/a") + + has_tools = bool(req.tools) + react_prompt_prefix = format_react_prompt(req.tools or []) if has_tools else "" + + debug_path = ( + Path(__file__).resolve().parent.parent.parent + / "debug" + / "chat_prompt_debug.json" + ) + + max_retries = 3 + for attempt in range(max_retries): + target: _RequestTarget | None = None + active_session_id: str | None = None + request_id = uuid.uuid4().hex + try: + async with self._schedule_lock: + if conv_uuid: + target = await self._reuse_session_target_locked( + plugin, + type_name, + conv_uuid, + ) + if target is None: + target = await self._allocate_new_target_locked(type_name) + if target.session_id is not None: + active_session_id = target.session_id + + content = extract_user_content( + req.messages, + has_tools=has_tools, + react_prompt_prefix=react_prompt_prefix, + full_history=target.full_history, + ) + if not content.strip() and req.attachment_files: + content = "Please analyze the attached image." 
+ if not content.strip(): + raise ValueError("messages 中需至少有一条带 content 的 user 消息") + + debug_path.parent.mkdir(parents=True, exist_ok=True) + debug_path.write_text( + json.dumps( + { + "prompt": content, + "full_history": target.full_history, + "type": type_name, + }, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + + account_id = self._pool.account_id(target.group, target.account) + session_id = target.session_id + if session_id is None: + await plugin.ensure_request_ready( + target.context, + target.page, + request_id=request_id, + session_id=None, + phase="create_conversation", + account_id=account_id, + ) + logger.info( + "[chat] create_conversation type=%s proxy=%s account=%s", + type_name, + target.proxy_key.fingerprint_id, + account_id, + ) + session_id = await plugin.create_conversation( + target.context, + target.page, + timezone=target.group.timezone + or getattr(target.proxy_key, "timezone", None) + or TIMEZONE, + public_model=str(getattr(req, "model", "") or ""), + upstream_model=str(getattr(req, "upstream_model", "") or ""), + request_id=request_id, + ) + if not session_id: + raise RuntimeError("插件创建会话失败") + async with self._schedule_lock: + self._session_cache.put( + session_id, + target.proxy_key, + type_name, + account_id, + ) + self._browser_manager.register_session( + target.proxy_key, + type_name, + session_id, + ) + self._busy_sessions.add(session_id) + # Register fingerprint for future matching + if fingerprint: + self._conv_index.register( + fingerprint, + session_id, + len(req.messages), + account_id, + ) + active_session_id = session_id + + # Skip pre-stream probe for newly created sessions: + # create_conversation already validated page health. 
+ if target.session_id is not None: + await plugin.ensure_request_ready( + target.context, + target.page, + request_id=request_id, + session_id=session_id, + phase="stream_completion", + account_id=account_id, + ) + logger.info( + "[chat] stream_completion type=%s session_id=%s proxy=%s account=%s full_history=%s", + type_name, + session_id, + target.proxy_key.fingerprint_id, + account_id, + target.full_history, + ) + # 根据是否 full_history 选择附件来源: + # - 复用会话(full_history=False):仅最后一条 user 的图片(可能为空,则本轮不带图) + # - 新建/重建会话(full_history=True):所有历史 user 的图片 + attachments = ( + req.attachment_files_all_users + if target.full_history + else req.attachment_files_last_user + ) + + proxy_url = None + proxy_auth = None + proxy_forwarder = None + if plugin.stream_transport() == "context_request": + proxy_url, proxy_auth, proxy_forwarder = self._stream_proxy_settings(target) + target.proxy_url = proxy_url + target.proxy_auth = proxy_auth + target.proxy_forwarder = proxy_forwarder + try: + stream = cast( + AsyncIterator[str], + plugin.stream_completion( + target.context, + target.page, + session_id, + content, + request_id=request_id, + attachments=attachments, + proxy_url=proxy_url, + proxy_auth=proxy_auth, + ), + ) + async for chunk in stream: + yield chunk + finally: + if proxy_forwarder is not None: + try: + proxy_forwarder.stop() + except Exception: + pass + target.proxy_forwarder = None + + yield session_id_suffix(session_id) + return + except AccountFrozenError as e: + logger.warning( + "账号限流/额度用尽(插件上报),切换资源重试: type=%s proxy=%s err=%s", + type_name, + target.proxy_key.fingerprint_id if target else None, + e, + ) + async with self._schedule_lock: + if target is not None: + self.report_account_unfreeze( + target.group.fingerprint_id, + target.account.name, + e.unfreeze_at, + ) + self._browser_manager.mark_tab_draining( + target.proxy_key, + type_name, + frozen_until=e.unfreeze_at, + ) + self._invalidate_tab_sessions_locked( + target.proxy_key, type_name + ) + if attempt == 
max_retries - 1: + raise RuntimeError( + f"已重试 {max_retries} 次仍限流/过载,请稍后再试: {e}" + ) from e + continue + except BrowserResourceInvalidError as e: + proxy_for_log = ( + target.proxy_key.fingerprint_id + if target is not None + else getattr(getattr(e, "proxy_key", None), "fingerprint_id", None) + ) + logger.warning( + "[chat] browser resource invalid bubbled type=%s request_id=%s proxy=%s session_id=%s helper=%s stage=%s resource=%s err=%s", + type_name, + request_id, + proxy_for_log, + active_session_id, + e.helper_name, + e.stage, + e.resource_hint, + e, + ) + async with self._schedule_lock: + if target is not None: + await self._recover_browser_resource_invalid_locked( + type_name, + target, + request_id, + active_session_id, + e, + attempt, + max_retries, + ) + elif getattr(e, "proxy_key", None) is not None: + closed = await self._browser_manager.close_browser(e.proxy_key) + self._apply_closed_tabs_locked(closed) + if attempt == max_retries - 1: + raise RuntimeError( + f"浏览器资源已失效,重试 {max_retries} 次后仍失败: {e}" + ) from e + continue + finally: + if target is not None: + async with self._schedule_lock: + if active_session_id is not None: + self._busy_sessions.discard(active_session_id) + self._browser_manager.release_tab(target.proxy_key, type_name) + + async def stream_openai_events( + self, + type_name: str, + req: OpenAIChatRequest, + ) -> AsyncIterator[OpenAIStreamEvent]: + """ + 唯一流式出口:以 OpenAIStreamEvent 为中间态。插件产出字符串流, + 在此包装为 content_delta + finish,供协议适配层编码为各协议 SSE。 + """ + async for chunk in self._stream_completion(type_name, req): + # session marker 也作为 content_delta 透传(对事件消费者而言是普通文本片段) + yield OpenAIStreamEvent(type="content_delta", content=chunk) + yield OpenAIStreamEvent(type="finish", finish_reason="stop") diff --git a/core/api/config_routes.py b/core/api/config_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..946ea995c1fefb38deceaed57e2ef4ccdeac01f2 --- /dev/null +++ b/core/api/config_routes.py @@ -0,0 +1,329 @@ +""" 
"""
Config routes: GET/PUT /api/config and the /config dashboard entrypoint.
"""

import logging
import time
from pathlib import Path
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from pydantic import BaseModel

from core.api.auth import (
    ADMIN_SESSION_COOKIE,
    admin_logged_in,
    check_admin_login_rate_limit,
    configured_config_secret_hash,
    get_effective_auth_settings,
    hash_config_secret,
    normalize_api_key_text,
    record_admin_login_failure,
    record_admin_login_success,
    refresh_runtime_auth_settings,
    require_config_login,
    require_config_login_enabled,
    verify_config_secret,
)
from core.api.chat_handler import ChatHandler
from core.config.repository import (
    APP_SETTING_AUTH_API_KEY,
    APP_SETTING_AUTH_CONFIG_SECRET_HASH,
    APP_SETTING_ENABLE_PRO_MODELS,
    ConfigRepository,
)
from core.plugin.base import PluginRegistry

logger = logging.getLogger(__name__)

STATIC_DIR = Path(__file__).resolve().parent.parent / "static"


class AdminLoginRequest(BaseModel):
    # Admin dashboard sign-in payload.
    secret: str


class AuthSettingsUpdateRequest(BaseModel):
    # None means "leave unchanged"; an empty string clears the stored value.
    api_key: str | None = None
    admin_password: str | None = None


class ProModelsUpdateRequest(BaseModel):
    enabled: bool = False


def _config_repo_of(request: Request) -> ConfigRepository:
    """Fetch the app's ConfigRepository; 503 while startup is incomplete."""
    repo: ConfigRepository | None = getattr(request.app.state, "config_repo", None)
    if repo is None:
        raise HTTPException(status_code=503, detail="Service is not ready")
    return repo


def _auth_settings_payload(request: Request) -> dict[str, Any]:
    """Serialize the effective auth settings for the dashboard."""
    settings = get_effective_auth_settings(request)
    return {
        "api_key": settings.api_key_text,
        "api_key_configured": bool(settings.api_keys),
        "api_key_source": settings.api_key_source,
        "api_key_env_managed": settings.api_key_env_managed,
        "admin_password_configured": bool(settings.config_secret_hash),
        "admin_password_source": settings.config_secret_source,
        "admin_password_env_managed": settings.config_secret_env_managed,
    }


def create_config_router() -> APIRouter:
    """Build the router for config APIs, admin auth, and dashboard pages."""
    router = APIRouter()

    @router.get("/api/types")
    def get_types(_: None = Depends(require_config_login)) -> list[str]:
        """Return registered provider types for the config dashboard."""
        return PluginRegistry.all_types()

    @router.get("/api/config")
    def get_config(
        request: Request, _: None = Depends(require_config_login)
    ) -> list[dict[str, Any]]:
        """Return raw proxy-group and account configuration."""
        return _config_repo_of(request).load_raw()

    @router.get("/api/models/{provider}/metadata")
    def get_public_model_metadata(provider: str) -> dict[str, Any]:
        # Intentionally unauthenticated: model metadata is public.
        try:
            return PluginRegistry.model_metadata(provider)
        except ValueError as exc:
            raise HTTPException(status_code=404, detail=str(exc)) from exc

    @router.get("/api/config/models")
    def get_model_metadata(_: None = Depends(require_config_login)) -> dict[str, Any]:
        return PluginRegistry.model_metadata("claude")

    @router.get("/api/config/auth-settings")
    def get_auth_settings(
        request: Request, _: None = Depends(require_config_login)
    ) -> dict[str, Any]:
        return _auth_settings_payload(request)

    @router.put("/api/config/auth-settings")
    def put_auth_settings(
        payload: AuthSettingsUpdateRequest,
        request: Request,
        _: None = Depends(require_config_login),
    ) -> dict[str, Any]:
        """Update API key and/or admin password, then refresh runtime auth."""
        repo = _config_repo_of(request)
        if payload.api_key is not None:
            repo.set_app_setting(
                APP_SETTING_AUTH_API_KEY,
                normalize_api_key_text(payload.api_key),
            )
        if payload.admin_password is not None:
            password = payload.admin_password.strip()
            # An empty password clears the stored hash (disables the login).
            repo.set_app_setting(
                APP_SETTING_AUTH_CONFIG_SECRET_HASH,
                hash_config_secret(password) if password else "",
            )
        refresh_runtime_auth_settings(request.app)
        settings_payload = _auth_settings_payload(request)
        if payload.admin_password is not None and payload.admin_password.strip():
            # Password changed: revoke the caller's current session so it must
            # re-authenticate with the new secret.
            store = getattr(request.app.state, "admin_sessions", None)
            if store is not None:
                token = (request.cookies.get(ADMIN_SESSION_COOKIE) or "").strip()
                store.revoke(token)
        return {"status": "ok", "settings": settings_payload}

    @router.get("/api/config/pro-models")
    def get_pro_models(
        request: Request, _: None = Depends(require_config_login)
    ) -> dict[str, Any]:
        repo = _config_repo_of(request)
        enabled = repo.get_app_setting(APP_SETTING_ENABLE_PRO_MODELS) == "true"
        return {"enabled": enabled}

    @router.put("/api/config/pro-models")
    def put_pro_models(
        payload: ProModelsUpdateRequest,
        request: Request,
        _: None = Depends(require_config_login),
    ) -> dict[str, Any]:
        repo = _config_repo_of(request)
        repo.set_app_setting(
            APP_SETTING_ENABLE_PRO_MODELS, "true" if payload.enabled else "false"
        )
        return {"status": "ok", "enabled": payload.enabled}

    @router.get("/api/config/status")
    def get_config_status(
        request: Request, _: None = Depends(require_config_login)
    ) -> dict[str, Any]:
        """Return runtime account status for the config dashboard."""
        repo = _config_repo_of(request)
        handler: ChatHandler | None = getattr(request.app.state, "chat_handler", None)
        if handler is None:
            raise HTTPException(status_code=503, detail="Service is not ready")
        runtime_status = handler.get_account_runtime_status()
        now = int(time.time())
        accounts: dict[str, dict[str, Any]] = {}
        for group in repo.load_groups():
            for account in group.accounts:
                account_id = f"{group.fingerprint_id}:{account.name}"
                runtime = runtime_status.get(account_id, {})
                is_frozen = (
                    account.unfreeze_at is not None and int(account.unfreeze_at) > now
                )
                accounts[account_id] = {
                    "fingerprint_id": group.fingerprint_id,
                    "account_name": account.name,
                    "enabled": account.enabled,
                    "unfreeze_at": account.unfreeze_at,
                    "is_frozen": is_frozen,
                    "is_active": bool(runtime.get("is_active")),
                    "tab_state": runtime.get("tab_state"),
                    "accepting_new": runtime.get("accepting_new"),
                    "active_requests": runtime.get("active_requests", 0),
                }
        return {"now": now, "accounts": accounts}

    @router.put("/api/config")
    async def put_config(
        request: Request,
        config: list[dict[str, Any]],
        _: None = Depends(require_config_login),
    ) -> dict[str, Any]:
        """Validate, persist, and immediately apply the configuration."""
        repo = _config_repo_of(request)
        if not config:
            raise HTTPException(status_code=400, detail="Configuration must not be empty")
        for i, g in enumerate(config):
            if not isinstance(g, dict):
                raise HTTPException(
                    status_code=400,
                    detail=f"Item {i + 1} must be an object",
                )
            if "fingerprint_id" not in g:
                raise HTTPException(
                    status_code=400,
                    detail=f"Proxy group {i + 1} is missing field: fingerprint_id",
                )
            use_proxy = g.get("use_proxy", True)
            if isinstance(use_proxy, str):
                # Accept common falsy spellings from form inputs.
                use_proxy = use_proxy.strip().lower() not in {
                    "0",
                    "false",
                    "no",
                    "off",
                }
            else:
                use_proxy = bool(use_proxy)
            if use_proxy and not str(g.get("proxy_host", "")).strip():
                raise HTTPException(
                    status_code=400,
                    detail=f"Proxy group {i + 1} has proxy enabled and requires proxy_host",
                )
            accounts = g.get("accounts", [])
            if not accounts:
                raise HTTPException(
                    status_code=400,
                    detail=f"Proxy group {i + 1} must include at least one account",
                )
            for j, a in enumerate(accounts):
                # str() guards against non-string values (e.g. numbers), which
                # previously raised AttributeError -> 500 instead of a 400.
                if not isinstance(a, dict) or not str(a.get("name") or "").strip():
                    raise HTTPException(
                        status_code=400,
                        detail=f"Account {j + 1} in proxy group {i + 1} must include name",
                    )
                if not str(a.get("type") or "").strip():
                    raise HTTPException(
                        status_code=400,
                        detail=f"Account {j + 1} in proxy group {i + 1} must include type (for example: claude)",
                    )
                if "enabled" in a and not isinstance(
                    a.get("enabled"), (bool, int, str)
                ):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Account {j + 1} in proxy group {i + 1} has an invalid enabled value",
                    )
        try:
            repo.save_raw(config)
        except Exception as e:
            logger.exception("Failed to save configuration")
            raise HTTPException(status_code=400, detail=str(e)) from e
        # Apply immediately: reload groups and refresh the active chat handler.
        try:
            groups = repo.load_groups()
            handler: ChatHandler | None = getattr(
                request.app.state, "chat_handler", None
            )
            if handler is None:
                raise RuntimeError("chat_handler is not initialized")
            await handler.refresh_configuration(groups, config_repo=repo)
        except Exception as e:
            logger.exception("Failed to reload account pool")
            raise HTTPException(
                status_code=500,
                detail=f"Configuration was saved but reload failed: {e}",
            ) from e
        return {"status": "ok", "message": "Configuration saved and applied"}

    @router.get("/login", response_model=None)
    def login_page(request: Request) -> FileResponse | RedirectResponse:
        """Serve the login page; bounce to /config when already signed in."""
        require_config_login_enabled(request)
        if admin_logged_in(request):
            return RedirectResponse(url="/config", status_code=302)
        path = STATIC_DIR / "login.html"
        if not path.is_file():
            raise HTTPException(status_code=404, detail="Login page is not ready")
        return FileResponse(path)

    @router.post("/api/admin/login", response_model=None)
    def admin_login(payload: AdminLoginRequest, request: Request) -> Response:
        """Verify the admin secret and issue a session cookie."""
        require_config_login_enabled(request)
        check_admin_login_rate_limit(request)
        secret = payload.secret.strip()
        encoded = configured_config_secret_hash(_config_repo_of(request))
        if not secret or not encoded or not verify_config_secret(secret, encoded):
            lock_seconds = record_admin_login_failure(request)
            if lock_seconds > 0:
                raise HTTPException(
                    status_code=429,
                    detail=f"Too many failed login attempts. Try again in {lock_seconds} seconds.",
                )
            raise HTTPException(status_code=401, detail="Sign-in failed. Password is incorrect.")
        record_admin_login_success(request)
        store = request.app.state.admin_sessions
        token = store.create()
        response = JSONResponse({"status": "ok"})
        response.set_cookie(
            key=ADMIN_SESSION_COOKIE,
            value=token,
            httponly=True,
            samesite="lax",
            secure=request.url.scheme == "https",
            max_age=store.ttl_seconds,
            path="/",
        )
        return response

    @router.post("/api/admin/logout", response_model=None)
    def admin_logout(request: Request) -> Response:
        """Revoke the caller's admin session and clear the cookie."""
        token = (request.cookies.get(ADMIN_SESSION_COOKIE) or "").strip()
        store = getattr(request.app.state, "admin_sessions", None)
        if store is not None:
            store.revoke(token)
        response = JSONResponse({"status": "ok"})
        response.delete_cookie(ADMIN_SESSION_COOKIE, path="/")
        return response

    @router.get("/config", response_model=None)
    def config_page(request: Request) -> FileResponse | RedirectResponse:
        """Config dashboard entrypoint."""
        require_config_login_enabled(request)
        if not admin_logged_in(request):
            return RedirectResponse(url="/login", status_code=302)
        path = STATIC_DIR / "config.html"
        if not path.is_file():
            raise HTTPException(status_code=404, detail="Config page is not ready")
        return FileResponse(path)

    return router


"""
Session-id transport: arbitrary string -> base64 -> zero-width encoding,
wrapped in special zero-width marker groups. The session id is recovered from
conversation text by matching the head/tail markers, independent of the
session id's own format.

Encoding protocol:
    session_id (utf-8)
    -> base64 (A-Za-z0-9+/=, at most 65 distinct symbols)
    -> each base64 char becomes 3 base-5 zero-width digits (5^3 = 125 >= 65)
    -> valid indices are 0..64 (64 alphabet chars + padding), so the first
       base-5 digit of a payload triplet is at most 2 (3*25 = 75 > 64)
    -> triplets whose first digit is ZW[3] or ZW[4] therefore never occur in
       payload
    -> HEAD_MARK/TAIL_MARK are built from exactly such triplets, so they can
       never falsely match payload
"""

import base64
import re
from typing import Any
零宽连接符 → 2 + "\ufeff", # 零宽非断空格 → 3 + "\u180e", # 蒙古文元音分隔符 → 4 +) +_ZW_SET = frozenset(_ZERO_WIDTH) +_ZW_TO_IDX = {c: i for i, c in enumerate(_ZERO_WIDTH)} + +# base64 标准字符集(64 个字符),padding 符 "=" 用索引 64 表示 +_B64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +_B64_TO_IDX = {c: i for i, c in enumerate(_B64_CHARS)} +_PAD_IDX = 64 # "=" 的编码索引 + +# 起止标记:首位均为 ZW[3] 或 ZW[4],保证不出现在 payload 三元组中 +_HEAD_MARK = _ZERO_WIDTH[4] * 3 + _ZERO_WIDTH[3] * 3 # 6 个零宽字符 +_TAIL_MARK = _ZERO_WIDTH[3] * 3 + _ZERO_WIDTH[4] * 3 # 6 个零宽字符 + +_ZW_CLASS = r"[\u200b\u200c\u200d\ufeff\u180e]" + + +def _encode_b64idx(idx: int) -> str: + """将 base64 字符索引 (0-64) 编码为 3 个零宽字符(3 位 base-5)。""" + a = idx // 25 + r = idx % 25 + b = r // 5 + c = r % 5 + return _ZERO_WIDTH[a] + _ZERO_WIDTH[b] + _ZERO_WIDTH[c] + + +def _decode_b64idx(zw3: str) -> int | None: + """将 3 个零宽字符解码为 base64 字符索引(0-64),非法返回 None。""" + if len(zw3) != 3: + return None + a = _ZW_TO_IDX.get(zw3[0]) + b = _ZW_TO_IDX.get(zw3[1]) + c = _ZW_TO_IDX.get(zw3[2]) + if a is None or b is None or c is None: + return None + val = a * 25 + b * 5 + c + if val > 64: + return None + return val + + +def encode_session_id(session_id: str) -> str: + """ + 将任意字符串会话 ID 编码为不可见的零宽序列: + HEAD_MARK + zero_width_encoded(base64(utf-8(session_id))) + TAIL_MARK + """ + b64 = base64.b64encode(session_id.encode()).decode() + out: list[str] = [] + for ch in b64: + if ch == "=": + out.append(_encode_b64idx(_PAD_IDX)) + else: + idx = _B64_TO_IDX.get(ch) + if idx is None: + return "" + out.append(_encode_b64idx(idx)) + return _HEAD_MARK + "".join(out) + _TAIL_MARK + + +def decode_session_id(text: str) -> str | None: + """ + 从文本中提取第一个被标记包裹的会话 ID(解码零宽 → base64 → utf-8)。 + 若未找到有效标记或解码失败则返回 None。 + """ + m = re.search( + re.escape(_HEAD_MARK) + r"(" + _ZW_CLASS + r"+?)" + re.escape(_TAIL_MARK), + text, + ) + if not m: + return None + body = m.group(1) + if len(body) % 3 != 0: + return None + b64_chars: list[str] = [] + for i in range(0, len(body), 
3): + idx = _decode_b64idx(body[i : i + 3]) + if idx is None: + return None + b64_chars.append("=" if idx == _PAD_IDX else _B64_CHARS[idx]) + try: + return base64.b64decode("".join(b64_chars)).decode() + except Exception: + return None + + +def decode_latest_session_id(text: str) -> str | None: + """ + 从文本中提取最后一个被标记包裹的会话 ID。 + 用于客户端保留完整历史时,优先命中最近一次返回的 session_id。 + """ + matches = list( + re.finditer( + re.escape(_HEAD_MARK) + r"(" + _ZW_CLASS + r"+?)" + re.escape(_TAIL_MARK), + text, + ) + ) + if not matches: + return None + body = matches[-1].group(1) + if len(body) % 3 != 0: + return None + b64_chars: list[str] = [] + for i in range(0, len(body), 3): + idx = _decode_b64idx(body[i : i + 3]) + if idx is None: + return None + b64_chars.append("=" if idx == _PAD_IDX else _B64_CHARS[idx]) + try: + return base64.b64decode("".join(b64_chars)).decode() + except Exception: + return None + + +def extract_session_id_marker(text: str) -> str: + """ + 从文本中提取完整的零宽会话 ID 标记段(HEAD_MARK + body + TAIL_MARK), + 用于在 tool_calls 的 text_content 中携带会话 ID 至下一轮对话。 + 若未找到则返回空字符串。 + """ + m = re.search( + re.escape(_HEAD_MARK) + _ZW_CLASS + r"+?" + re.escape(_TAIL_MARK), + text, + ) + return m.group(0) if m else "" + + +def session_id_suffix(session_id: str) -> str: + """返回响应末尾需附加的不可见标记(含 HEAD/TAIL 包裹的零宽编码会话 ID)。""" + return encode_session_id(session_id) + + +def strip_session_id_suffix(text: str) -> str: + """去掉文本中所有零宽会话 ID 标记段(HEAD_MARK...TAIL_MARK),返回干净正文。""" + return re.sub( + re.escape(_HEAD_MARK) + _ZW_CLASS + r"+?" 
+ re.escape(_TAIL_MARK), + "", + text, + ) + + +def _normalize_content(content: str | list[Any]) -> str: + if isinstance(content, str): + return content + parts: list[str] = [] + for p in content: + if isinstance(p, dict) and p.get("type") == "text" and "text" in p: + parts.append(str(p["text"])) + elif isinstance(p, str): + parts.append(p) + return " ".join(parts) + + +def parse_conv_uuid_from_messages(messages: list[dict[str, Any]]) -> str | None: + """从 messages 中解析最新会话 ID(从最后一条带标记的消息开始逆序查找)。""" + for m in reversed(messages): + content = m.get("content") + if content is None: + continue + text = _normalize_content(content) + decoded = decode_latest_session_id(text) + if decoded is not None: + return decoded + return None diff --git a/core/api/fingerprint.py b/core/api/fingerprint.py new file mode 100644 index 0000000000000000000000000000000000000000..068019d34db0ef24fd12a1d162681d97752b3575 --- /dev/null +++ b/core/api/fingerprint.py @@ -0,0 +1,41 @@ +"""会话指纹:基于 system prompt + 首条 user 消息计算 SHA-256 指纹。 + +同一逻辑对话(相同 system + 相同首条 user)的指纹恒定, +不同对话指纹不同,杜绝上下文污染。 +""" + +import hashlib + +from core.api.schemas import OpenAIMessage + + +def _norm_content(content: str | list | None) -> str: + if content is None: + return "" + if isinstance(content, str): + return content.strip() + # list[OpenAIContentPart] + parts: list[str] = [] + for p in content: + if hasattr(p, "type") and p.type == "text" and p.text: + parts.append(p.text.strip()) + return " ".join(parts) + + +def compute_conversation_fingerprint(messages: list[OpenAIMessage]) -> str: + """sha256(system_prompt + first_user_message)[:16] + + Returns empty string if no user message found. 
+ """ + system_text = "" + first_user_text = "" + for m in messages: + if m.role == "system" and not system_text: + system_text = _norm_content(m.content) + elif m.role == "user" and not first_user_text: + first_user_text = _norm_content(m.content) + break + if not first_user_text: + return "" + raw = f"{system_text}\n{first_user_text}" + return hashlib.sha256(raw.encode()).hexdigest()[:16] diff --git a/core/api/function_call.py b/core/api/function_call.py new file mode 100644 index 0000000000000000000000000000000000000000..eaee8d38eb897db8f782cf9e8e09cf1dc352f987 --- /dev/null +++ b/core/api/function_call.py @@ -0,0 +1,351 @@ +""" +Function Call 层:解析模型输出的 格式,转换为 OpenAI tool_calls; +将 tools 和 tool 结果拼入 prompt。对外统一使用 OpenAI 格式。 +""" + +import json +import re +import uuid +from collections.abc import Callable +from typing import Any + +TOOL_CALL_PREFIX = "" +TOOL_CALL_PREFIX_LEN = len(TOOL_CALL_PREFIX) +TOOL_CALL_PATTERN = re.compile( + r"\s*(.*?)\s*", + re.DOTALL, +) + + +def parse_tool_calls(text: str) -> list[dict[str, Any]]: + """ + 从文本中解析所有 ... 块。 + 返回 [{"name": str, "arguments": dict | str}, ...] 
+ """ + if not text or not text.strip(): + return [] + matches = TOOL_CALL_PATTERN.findall(text) + result: list[dict[str, Any]] = [] + for m in matches: + try: + obj = json.loads(m.strip()) + if isinstance(obj, dict) and "name" in obj: + args = obj.get("arguments", {}) + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + result.append({"name": obj["name"], "arguments": args}) + except json.JSONDecodeError: + pass + return result + + +def detect_tool_call_mode(buffer: str, *, strip_session_id: bool = True) -> bool | None: + """ + 根据 buffer 内容判断是否为 tool_call 模式。 + None=尚未确定,True=tool_call,False=普通文本。 + strip_session_id: 若 True,先去掉开头的零宽 session_id 前缀再判断。 + """ + content = buffer + if strip_session_id: + from core.api.conv_parser import strip_session_id_suffix + + content = strip_session_id_suffix(buffer) + stripped = content.lstrip() + if stripped.startswith(TOOL_CALL_PREFIX): + return True + if len(stripped) > TOOL_CALL_PREFIX_LEN: + return False + return None + + +def format_tools_for_prompt(tools: list[dict[str, Any]]) -> str: + """ + 将 OpenAI 格式的 tools 转为可读文本,用于 prompt。 + 兼容 OpenAI 格式 {type, function: {name, description, parameters}} + 和 Cursor 格式 {name, description, input_schema}。 + """ + if not tools: + return "" + lines: list[str] = [] + for t in tools: + if not isinstance(t, dict): + continue + fn = t.get("function") if t.get("type") == "function" else t + if not isinstance(fn, dict): + fn = t + name = fn.get("name") + if not name: + continue + desc = fn.get("description") or fn.get("summary") or "" + params = fn.get("parameters") or fn.get("input_schema") or {} + if isinstance(params, str): + try: + params = json.loads(params) + except json.JSONDecodeError: + params = {} + props = params.get("properties") or {} + required = params.get("required") or [] + args_desc = ", ".join( + f"{k}: {v.get('type', 'any')}" + (" (必填)" if k in required else "") + for k, v in props.items() + ) + lines.append( + f"- 
{name}({args_desc}): {desc[:200]}" + ("..." if len(desc) > 200 else "") + ) + return "\n".join(lines) if lines else "" + + +def build_tool_calls_response( + tool_calls_list: list[dict[str, Any]], + chat_id: str, + model: str, + created: int, + *, + text_content: str = "", +) -> dict[str, Any]: + """返回 OpenAI 格式的 chat.completion(含 tool_calls)。 + message.content 为字符串(或空时 null),tool_calls 为 OpenAI 标准数组。 + """ + tool_calls: list[dict[str, Any]] = [] + for tc in tool_calls_list: + name = tc.get("name", "") + args = tc.get("arguments", {}) + if isinstance(args, dict): + args_str = json.dumps(args, ensure_ascii=False) + else: + try: + args_obj = json.loads(str(args)) if args else {} + args_str = json.dumps(args_obj, ensure_ascii=False) + except json.JSONDecodeError: + args_str = "{}" + call_id = f"call_{uuid.uuid4().hex[:24]}" + tool_calls.append( + { + "id": call_id, + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + message: dict[str, Any] = { + "role": "assistant", + "content": text_content if text_content else None, + "tool_calls": tool_calls, + } + return { + "id": chat_id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "message": message, + "finish_reason": "tool_calls", + } + ], + } + + +def _openai_sse_chunk( + chat_id: str, + model: str, + created: int, + delta: dict, + finish_reason: str | None = None, +) -> str: + """构建 OpenAI 流式 SSE:data: \\n\\n""" + choice: dict[str, Any] = {"index": 0, "delta": delta} + if finish_reason is not None: + choice["finish_reason"] = finish_reason + data = { + "id": chat_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [choice], + } + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def build_openai_text_sse_events( + chat_id: str, + model: str, + created: int, +) -> tuple[str, Callable[[str], str], Callable[[], str]]: + """返回 OpenAI 流式事件的工厂。 + 返回 (msg_start_sse, make_delta_sse, 
make_stop_sse)。 + msg_start 为带 role 的首 chunk。 + """ + + def msg_start() -> str: + return _openai_sse_chunk( + chat_id, + model, + created, + delta={"role": "assistant", "content": ""}, + finish_reason=None, + ) + + def make_delta_sse(text: str) -> str: + return _openai_sse_chunk( + chat_id, + model, + created, + delta={ + "content": text, + }, + finish_reason=None, + ) + + def make_stop_sse() -> str: + return ( + _openai_sse_chunk( + chat_id, + model, + created, + delta={}, + finish_reason="stop", + ) + + "data: [DONE]\n\n" + ) + + return msg_start(), make_delta_sse, make_stop_sse + + +def build_tool_calls_with_ids( + tool_calls_list: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """从 name+arguments 的 tool_calls_list 构建带 id 的 OpenAI 格式 tool_calls。 + 用于流式下发与 debug 保存共用同一批 id,保证下一轮 request 的 tool_call_id 一致。 + """ + tool_calls: list[dict[str, Any]] = [] + for i, tc in enumerate(tool_calls_list): + name = tc.get("name", "") + args = tc.get("arguments", {}) + if isinstance(args, dict): + args_str = json.dumps(args, ensure_ascii=False) + else: + try: + args_obj = json.loads(str(args)) if args else {} + args_str = json.dumps(args_obj, ensure_ascii=False) + except json.JSONDecodeError: + args_str = "{}" + tool_calls.append( + { + "index": i, + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + return tool_calls + + +def build_openai_tool_use_sse_events( + tool_calls_list: list[dict[str, Any]], + chat_id: str, + model: str, + created: int, + *, + text_content: str = "", + tool_calls_with_ids: list[dict[str, Any]] | None = None, +) -> tuple[list[str], list[dict[str, Any]]]: + """构建 OpenAI 流式 SSE 事件,用于 tool_calls 场景。 + 有 text_content(如 thinking)时:先发 content chunk,再发 tool_calls chunk,便于客户端先展示思考再展示工具调用。 + 无 text_content 时:单 chunk 发 role + tool_calls。 + tool_calls 场景最后只发 finish_reason,不发 data: [DONE](think 之后不跟 [DONE])。 + """ + if tool_calls_with_ids is not None: + tool_calls = tool_calls_with_ids + 
else: + tool_calls = build_tool_calls_with_ids(tool_calls_list) + sse_list: list[str] = [] + if text_content: + # 先发 content(thinking),再发 tool_calls,同一条消息内顺序展示 + sse_list.append( + _openai_sse_chunk( + chat_id, + model, + created, + {"role": "assistant", "content": text_content}, + None, + ) + ) + sse_list.append( + _openai_sse_chunk(chat_id, model, created, {"tool_calls": tool_calls}, None) + ) + else: + sse_list.append( + _openai_sse_chunk( + chat_id, + model, + created, + { + "role": "assistant", + "content": "", + "tool_calls": tool_calls, + }, + None, + ) + ) + sse_list.append(_openai_sse_chunk(chat_id, model, created, {}, "tool_calls")) + return (sse_list, tool_calls) + + +def stream_openai_tool_use_sse_events( + tool_calls_list: list[dict[str, Any]], + chat_id: str, + model: str, + created: int, + *, + tool_calls_with_ids: list[dict[str, Any]] | None = None, +) -> list[str]: + """ + 流式下发 tool_calls:先发每个 tool 的 id/name(arguments 为空), + 再逐个发 arguments 分片,最后发 finish_reason。便于客户端逐步展示。 + content(如 )由调用方已通过 delta 流式发完,此处只发 tool_calls 相关 chunk。 + """ + if tool_calls_with_ids is not None: + tool_calls = tool_calls_with_ids + else: + tool_calls = build_tool_calls_with_ids(tool_calls_list) + sse_list: list[str] = [] + # 第一块:仅 id + type + name,arguments 为空,让客户端先展示“正在调用 xxx” + tool_calls_heads: list[dict[str, Any]] = [] + for tc in tool_calls: + tool_calls_heads.append( + { + "index": tc["index"], + "id": tc["id"], + "type": "function", + "function": {"name": tc["function"]["name"], "arguments": ""}, + } + ) + sse_list.append( + _openai_sse_chunk( + chat_id, model, created, {"tool_calls": tool_calls_heads}, None + ) + ) + # 后续每块:只带 index + function.arguments,可整段发或分片发,这里按 tool 逐个发 + for tc in tool_calls: + args = tc.get("function", {}).get("arguments", "") or "" + if not args: + continue + sse_list.append( + _openai_sse_chunk( + chat_id, + model, + created, + { + "tool_calls": [ + {"index": tc["index"], "function": {"arguments": args}} + ] + }, + None, + ) + ) + 
sse_list.append(_openai_sse_chunk(chat_id, model, created, {}, "tool_calls")) + return sse_list diff --git a/core/api/mock_claude.py b/core/api/mock_claude.py new file mode 100644 index 0000000000000000000000000000000000000000..eea3612d701c4aaa2714fb466edd40d2ccaa6eec --- /dev/null +++ b/core/api/mock_claude.py @@ -0,0 +1,104 @@ +""" +Mock Claude API:与 claude.py 调用格式兼容,不消耗 token。 +设置 CLAUDE_START_URL 和 CLAUDE_API_BASE 指向 http://ip:port/mock 即可调试。 +""" + +import asyncio +import json +import uuid as uuid_mod +from collections.abc import AsyncIterator + +from fastapi import APIRouter +from fastapi.responses import HTMLResponse, StreamingResponse + +router = APIRouter(prefix="/mock", tags=["mock"]) + +MOCK_ORG_UUID = "00000000-0000-0000-0000-000000000001" + +# 自定义回复:请求来时在终端用多行输入要回复的内容 +INPUT_PROMPT = "Mock 回复内容(支持多行,空行结束):" + + +@router.get("", response_class=HTMLResponse) +@router.get("/", response_class=HTMLResponse) +def mock_start_page() -> str: + """CLAUDE_START_URL 指向 /mock 时,浏览器加载此页。""" + return """ + +Mock Claude +

Mock Claude - 调试用

+ +""" + + +@router.get("/account") +def mock_account() -> dict: + """_get_org_uuid 调用的 GET /account,返回 memberships 含 org uuid。""" + return { + "memberships": [ + {"organization": {"uuid": MOCK_ORG_UUID}}, + ], + } + + +@router.post("/organizations/{org_uuid}/chat_conversations") +def mock_create_conversation(org_uuid: str) -> dict: + """_post_create_conversation 调用的创建会话接口。""" + return { + "uuid": str(uuid_mod.uuid4()), + } + + +def _read_reply_from_stdin() -> str: + """在终端通过多次 input 读取多行回复内容(空行结束,阻塞,应在线程中调用)。""" + print(INPUT_PROMPT, flush=True) + print("直接粘贴多行文本,最后再按一次回车输入空行结束。", flush=True) + lines: list[str] = [] + while True: + try: + line = input() + except EOFError: + break + # 空行表示输入结束 + if line == "": + break + lines.append(line) + return "\n".join(lines).rstrip() + + +@router.post("/organizations/{org_uuid}/chat_conversations/{conv_uuid}/completion") +async def mock_completion( + org_uuid: str, + conv_uuid: str, # noqa: ARG001 +) -> StreamingResponse: + """stream_completion 调用的 completion 接口,返回 SSE 流。请求来时在终端 input 输入回复内容。""" + + # 在线程中执行 input,避免阻塞事件循环 + reply_text = await asyncio.to_thread(_read_reply_from_stdin) + + async def sse_stream() -> AsyncIterator[str]: + msg_uuid = str(uuid_mod.uuid4()) + # message_start + yield f"data: {json.dumps({'type': 'message_start', 'message': {'id': msg_uuid, 'uuid': msg_uuid, 'model': 'claude-sonnet-4-5-20250929', 'type': 'message', 'role': 'assistant'}})}\n\n" + # content_block_start + yield f"data: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + # content_block_delta 分块流式输出 + chunk_size = 2 + for i in range(0, len(reply_text), chunk_size): + chunk = reply_text[i : i + chunk_size] + yield f"data: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': chunk}})}\n\n" + await asyncio.sleep(0.05) + # content_block_stop + yield f"data: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" + # message_stop + 
yield f"data: {json.dumps({'type': 'message_stop'})}\n\n" + + return StreamingResponse( + sse_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/core/api/react.py b/core/api/react.py new file mode 100644 index 0000000000000000000000000000000000000000..12dcafd8e5014fb08c72f4c07cc9bad43cf04ed4 --- /dev/null +++ b/core/api/react.py @@ -0,0 +1,244 @@ +""" +ReAct 模块:解析 LLM 纯文本输出(Thought/Action/Action Input),转换为 function_call 格式。 +适用于不支持 function calling 的 LLM。提示词借鉴 Dify ReAct 结构与表述,保持行式格式。 +""" + +import json +import re +from typing import Any + +# 复用 function_call 的工具描述格式化 +from core.api.function_call import format_tools_for_prompt + +# 固定 ReAct 提示词(借鉴 Dify ReAct 结构与表述,保持行式格式以兼容 parse_react_output) +REACT_PROMPT_FIXED = r"""Respond to the human as helpfully and accurately as possible. + +You have access to the following tools (listed below under "## Available tools"). + +Use the following format: + +Question: the input question you must answer +Thought: consider what you know and what to do next +Action: the tool name (exactly one of the tools listed below) +Action Input: a single-line JSON object as the tool input +Observation: the result of the action (injected by the system — do NOT output this yourself) +... (repeat Thought / Action / Action Input as needed; after each, the system adds Observation) +Thought: I know the final answer +Final Answer: your final response to the human + +Provide only ONE action per response. Valid "Action" values: a tool name from the list, or (when done) output "Final Answer" / "最终答案" instead of Action + Action Input. + +Rules: +- After "Action Input: {...}" you must STOP and wait for Observation. Do not add any text, code, or explanation after the JSON line. +- Action Input must be a single-line valid JSON. All double quotes `"` in JSON values must be escaped as `\"`. Do not output "Observation" yourself. 
+- Format is: Thought → Action → Action Input (or Final Answer when done). Then the system replies with Observation. + +Begin. Always respond with a valid Thought then Action then Action Input (or Final Answer). Use tools when necessary; respond with Final Answer when appropriate. +""" + + +def format_react_prompt( + tools: list[dict[str, Any]], + tools_text: str | None = None, +) -> str: + """用固定 ReAct 提示词构建系统前缀,并拼接可用工具列表。""" + if tools_text is None: + tools_text = format_tools_for_prompt(tools) + return REACT_PROMPT_FIXED + "\n\n---\n\n## Available tools\n\n" + tools_text + "\n" + + +def parse_react_output(text: str) -> dict[str, Any] | None: + """ + 解析行式 ReAct 输出 (Thought / Action / Action Input)。 + 返回 {"type": "final_answer", "content": str} 或 + {"type": "tool_call", "tool": str, "params": dict} 或 None(解析失败)。 + 注意:优先解析 Action,若同时存在 Action 与 Final Answer,则返回 tool_call, + 以便正确下发 tool_calls 给客户端执行。 + """ + if not text or not text.strip(): + return None + + # 1. 优先提取 Action + Action Input(若存在则返回 tool_call,避免被 Final Answer 抢先) + action_match = re.search(r"^\s*Action[::]\s*(\w+)", text, re.MULTILINE) + if action_match: + tool_name = action_match.group(1).strip() + + # 2. 
提取 Action Input(单行 JSON 或简单多行) + input_match = re.search(r"Action Input[::]\s*(\{[^\n]+\})", text) + json_str: str | None = None + if input_match: + json_str = input_match.group(1).strip() + else: + # 多行 JSON:从 Action Input 到下一关键字 + start_m = re.search(r"Action Input[::]\s*", text) + if start_m: + rest = text[start_m.end() :] + end_m = re.search( + r"\n\s*(?:Thought|Action|Observation|Final)", rest, re.I + ) + raw = rest[: end_m.start()].strip() if end_m else rest.strip() + if raw.startswith("{") and "}" in raw: + depth = 0 + for i, c in enumerate(raw): + if c == "{": + depth += 1 + elif c == "}": + depth -= 1 + if depth == 0: + json_str = raw[: i + 1] + break + + if not json_str: + return { + "type": "tool_call", + "tool": tool_name, + "params": {}, + "parse_error": "no_action_input", + } + + try: + params = json.loads(json_str) + except json.JSONDecodeError as e: + return { + "type": "tool_call", + "tool": tool_name, + "params": {}, + "parse_error": str(e), + } + + return {"type": "tool_call", "tool": tool_name, "params": params} + + # 3. 无 Action 时,检查 Final Answer + m = re.search( + r"(?:Final Answer|最终答案)[::]\s*(.*)", + text, + re.DOTALL | re.I, + ) + if m: + content = m.group(1).strip() + return {"type": "final_answer", "content": content} + + return None + + +def react_output_to_tool_calls(parsed: dict[str, Any]) -> list[dict[str, Any]]: + """ + 将 parse_react_output 的 tool_call 结果转为 function_call 的 tool_calls_list 格式。 + 供 build_tool_calls_response / build_tool_calls_chunk 使用。 + """ + if parsed.get("type") != "tool_call": + return [] + return [ + { + "name": parsed.get("tool", ""), + "arguments": parsed.get("params", {}), + } + ] + + +def format_react_final_answer_content(text: str) -> str: + """ + 若 text 为 ReAct 的 Thought + Final Answer 格式,则将 Thought 用 包裹, + 便于客户端识别为思考内容;否则返回原文本。 + """ + if not text or not text.strip(): + return text + # 匹配 Thought: ... 与 Final Answer: / 最终答案: ... 
+ thought_m = re.search( + r"Thought[::]\s*(.+?)(?=\s*(?:Final Answer|最终答案)[::]|\Z)", + text, + re.DOTALL | re.I, + ) + answer_m = re.search( + r"(?:Final Answer|最终答案)[::]\s*(.*)", + text, + re.DOTALL | re.I, + ) + if thought_m and answer_m: + thought = (thought_m.group(1) or "").strip() + answer = (answer_m.group(1) or "").strip() + return f"{thought}\n\n{answer}" + return text + + +def extract_thought_so_far(buffer: str) -> tuple[str | None, bool]: + """ + 从流式 buffer 中增量解析当前 Thought 内容(Thought: 到 Action:/Final Answer:/结尾)。 + 返回 (thought_content, thought_ended)。 + - thought_content: 当前可确定的 Thought 正文(不含 "Thought:" 前缀),未出现 Thought: 则为 None。 + - thought_ended: 是否已出现 Action: 或 Final Answer:,即 Thought 段已结束。 + """ + content = buffer.lstrip() + if not content: + return (None, False) + # 必须已有 Thought: + thought_start = re.search(r"Thought[::]\s*", content, re.I) + if not thought_start: + return (None, False) + start = thought_start.end() + rest = content[start:] + # 先找完整结尾:Action: 或 Final Answer:(一出现就截断,不要求后面已有工具名) + action_m = re.search(r"Action[::]\s*", rest, re.I) + final_m = re.search(r"(?:Final Answer|最终答案)[::]\s*", rest, re.I) + end_pos: int | None = None + if action_m and (final_m is None or action_m.start() <= final_m.start()): + end_pos = action_m.start() + if final_m and (end_pos is None or final_m.start() < end_pos): + end_pos = final_m.start() + if end_pos is not None: + thought_content = rest[:end_pos].rstrip() + return (thought_content, True) + # 未出现完整关键字时,去掉末尾「可能是关键字前缀」的片段,避免把 "\nAc"、"tion:"、"r:"、" [完整回答]" 等当 thought 流式发出 + thought_content = rest.rstrip() + for kw in ("Action:", "Final Answer:", "最终答案:"): + for i in range(len(kw), 0, -1): + if thought_content.lower().endswith(kw[:i].lower()): + thought_content = thought_content[:-i].rstrip() + break + # 再剥 "Final Answer:" 的尾部片段(流式时先收到 "Answer:"、"r:" 等),避免 [完整回答] 被算进 think + for suffix in ( + " Final Answer:", + " Final Answer", + " Answer:", + " Answer", + "Answer:", + "Answer", + "nswer:", + "nswer", + 
"swer:", + "swer", + "wer:", + "wer", + "er:", + "er", + "r:", + "r", + ): + if thought_content.endswith(suffix): + thought_content = thought_content[: -len(suffix)].rstrip() + break + return (thought_content, False) + + +def detect_react_mode(buffer: str) -> bool | None: + """ + 判断 buffer 是否为 ReAct 工具调用模式(规范格式:Thought:/Action:/Action Input:)。 + 仅当出现该格式时才识别为 ReAct;未按规范返回一律视为纯文本。 + None=尚未确定,True=ReAct 工具调用,False=普通文本或 Final Answer。 + """ + stripped = buffer.lstrip() + if re.search(r"^\s*Action[::]\s*\w+", stripped, re.MULTILINE): + return True + if re.search(r"(?:Final Answer|最终答案)[::]", stripped, re.I): + return False + # 流式可能只传 Thought/Action 的前半段(如 "Th"、"Tho"),视为尚未确定,继续缓冲 + lower = stripped.lower() + if lower and ("thought:".startswith(lower) or "action:".startswith(lower)): + return None + # 若 buffer 中已出现 Thought:,可能为前导语 + Thought 格式(第二轮常见),保持 None 等待 Action + if re.search(r"Thought[::]\s*", stripped, re.I): + return None + # 未按规范:首行不是 Thought:/Action: 开头则视为纯文本 + if stripped and not re.match(r"^\s*(?:Thought|Action)[::]", stripped, re.I): + return False + return None diff --git a/core/api/react_stream_parser.py b/core/api/react_stream_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e8f7c0b40837844dadc814fe9c222354232f1a --- /dev/null +++ b/core/api/react_stream_parser.py @@ -0,0 +1,435 @@ +""" +ReAct 流式解析器:字符级 MarkerDetector + StateMachine + +将 LLM 的 ReAct 格式文本实时转换为 OpenAI SSE 流式事件: + + Thought: xxx → delta.content = "xxx" (流式) + Action: name → 缓存工具名 + Action Input: {} → delta.tool_calls[0].function.arguments (流式) + Final Answer: xxx → delta.content = "xxx" (流式) + Observation: xxx → delta.content = "xxx" (流式) + 无标记文本 → delta.content = "xxx" (直通) + +核心设计: + MarkerDetector:默认零延迟直通,仅在遇到 Marker 首字母时暂存等待确认。 + StateMachine:IDLE / IN_THOUGHT / IN_ACTION / IN_ACTION_INPUT / + IN_OBSERVATION / IN_FINAL +""" + +import json +import uuid +from enum import Enum, auto + +# ─── Marker 定义 
────────────────────────────────────────────────────────────── + +# 注意顺序:仅影响精确匹配时的遍历,不影响正确性(每个 marker 唯一) +_MARKERS: tuple[str, ...] = ( + "Thought:", + "Action Input:", # 必须比 "Action:" 先定义(_is_prefix 依赖全集) + "Action:", + "Observation:", + "Final Answer:", + "最终答案:", +) + +_MARKER_FIRST_CHARS: frozenset[str] = frozenset(m[0] for m in _MARKERS) + + +# ─── 状态枚举 ───────────────────────────────────────────────────────────────── + + +class _State(Enum): + IDLE = auto() + IN_THOUGHT = auto() + IN_ACTION = auto() + IN_ACTION_INPUT = auto() + IN_OBSERVATION = auto() + IN_FINAL = auto() + + +# ─── 解析器主体 ─────────────────────────────────────────────────────────────── + + +class ReactStreamParser: + """ + 字符级 ReAct 流解析器,将 LLM 的 ReAct 格式输出转换为 OpenAI SSE chunks。 + + 用法:: + + parser = ReactStreamParser(chat_id, model, created, has_tools=True) + async for chunk in llm_stream: + # 注意:不要对 chunk 做 strip_session_id_suffix,否则客户端收不到会话 ID,下一轮无法复用会话 + for sse in parser.feed(chunk): + yield sse + for sse in parser.finish(): + yield sse + """ + + def __init__( + self, + chat_id: str, + model: str, + created: int, + *, + has_tools: bool = True, + ) -> None: + self._chat_id = chat_id + self._model = model + self._created = created + self._has_tools = has_tools + + # MarkerDetector 状态 + self._suspect_buf = "" + self._skip_leading_ws = False # 吃掉 Marker 冒号后的空白 + + # StateMachine 状态 + self._state = _State.IDLE + self._action_name_buf = "" # 收集 Action 名称 + self._tool_call_id = "" + self._tool_call_index = 0 + + # 输出控制标志 + self._emitted_msg_start = False + self._think_open = False # 已发 + self._think_closed = False # 已发 + self._tool_call_started = False # 已发 function_call_start + + # ── 公开 API ────────────────────────────────────────────────────────────── + + def feed(self, chunk: str) -> list[str]: + """处理一个文本 chunk,返回需要下发的 SSE 字符串列表(含 `data: ...\\n\\n`)。""" + events: list[str] = [] + for char in chunk: + events.extend(self._on_char(char)) + return events + + def finish(self) -> list[str]: + 
"""LLM 流结束时调用:flush 残留 suspect_buf,补发结束 SSE。""" + events: list[str] = [] + if self._suspect_buf: + buf, self._suspect_buf = self._suspect_buf, "" + events.extend(self._dispatch(buf)) + events.extend(self._emit_end()) + return events + + # ── 字符级处理(MarkerDetector)────────────────────────────────────────── + + def _on_char(self, char: str) -> list[str]: + # 吃掉 Marker 冒号后的单个/连续空格或制表符 + if self._skip_leading_ws: + if char in (" ", "\t"): + return [] + self._skip_leading_ws = False + + # 无工具:全部直通为纯文本 + if not self._has_tools: + return self._dispatch(char) + + if not self._suspect_buf: + if char in _MARKER_FIRST_CHARS: + self._suspect_buf = char + return [] + return self._dispatch(char) + + # 正在疑似 Marker + self._suspect_buf += char + + matched = self._exact_match() + if matched: + events = self._on_marker(matched) + self._suspect_buf = "" + return events + + if self._is_prefix(): + return [] # 继续积累,等待确认 + + # 排除歧义:flush suspect_buf 作为普通内容 + buf, self._suspect_buf = self._suspect_buf, "" + return self._dispatch(buf) + + def _exact_match(self) -> str | None: + for m in _MARKERS: + if self._suspect_buf == m: + return m + return None + + def _is_prefix(self) -> bool: + return any(m.startswith(self._suspect_buf) for m in _MARKERS) + + # ── Marker 触发(状态转换)──────────────────────────────────────────────── + + def _on_marker(self, marker: str) -> list[str]: + events: list[str] = [] + events.extend(self._exit_state()) + + if marker == "Thought:": + self._state = _State.IN_THOUGHT + events.extend(self._enter_thought()) + + elif marker == "Action:": + self._state = _State.IN_ACTION + self._action_name_buf = "" + + elif marker == "Action Input:": + # 若 Action 名后没有 \n(罕见),在此兜底触发 function_call_start + if not self._tool_call_started: + events.extend(self._start_function_call()) + self._state = _State.IN_ACTION_INPUT + + elif marker == "Observation:": + self._state = _State.IN_OBSERVATION + + elif marker in ("Final Answer:", "最终答案:"): + self._state = _State.IN_FINAL + 
events.extend(self._enter_final()) + + self._skip_leading_ws = True # 跳过 Marker 冒号后的空白 + return events + + def _exit_state(self) -> list[str]: + """离开当前状态时的收尾动作。""" + events: list[str] = [] + if self._state == _State.IN_THOUGHT: + if self._think_open and not self._think_closed: + self._think_closed = True + events.extend(self._make_content("")) + return events + + # ── 状态进入 ────────────────────────────────────────────────────────────── + + def _enter_thought(self) -> list[str]: + events: list[str] = [] + if not self._emitted_msg_start: + events.extend(self._emit_msg_start()) + # 每次进入 IN_THOUGHT 都开一个新的 块(支持多轮) + self._think_open = True + self._think_closed = False + events.extend(self._make_content("")) + return events + + def _enter_final(self) -> list[str]: + events: list[str] = [] + if not self._emitted_msg_start: + events.extend(self._emit_msg_start()) + return events + + def _start_function_call(self) -> list[str]: + """Action 名收集完毕,发送 function_call_start。""" + name = self._action_name_buf.strip() + self._tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + self._tool_call_started = True + events: list[str] = [] + if not self._emitted_msg_start: + events.extend(self._emit_msg_start()) + events.extend(self._make_tool_call_start(name)) + return events + + # ── 内容分发(根据当前状态路由字符/字符串)────────────────────────────────── + + def _dispatch(self, text: str) -> list[str]: + """将 text 按当前状态路由到对应的输出动作。""" + s = self._state + events: list[str] = [] + + if s == _State.IDLE: + if not self._emitted_msg_start: + events.extend(self._emit_msg_start()) + events.extend(self._make_content(text)) + + elif s == _State.IN_THOUGHT: + if not self._think_open: + # 安全兜底:进入 IN_THOUGHT 时通常已调用 _enter_thought,此处防御 + events.extend(self._enter_thought()) + events.extend(self._make_content(text)) + + elif s == _State.IN_ACTION: + # 逐字收集 action 名,遇换行触发 function_call_start + for ch in text: + if ch == "\n": + if self._action_name_buf.strip() and not self._tool_call_started: + 
events.extend(self._start_function_call()) + else: + self._action_name_buf += ch + + elif s == _State.IN_ACTION_INPUT: + if self._tool_call_started: + events.extend(self._make_tool_args(text)) + + elif s == _State.IN_OBSERVATION: + # Observation 内容作为普通文本流输出 + if not self._emitted_msg_start: + events.extend(self._emit_msg_start()) + events.extend(self._make_content(text)) + + elif s == _State.IN_FINAL: + events.extend(self._make_content(text)) + + return events + + # ── 流结束 ──────────────────────────────────────────────────────────────── + + def _emit_end(self) -> list[str]: + events: list[str] = [] + + # 关闭未关闭的 + if self._think_open and not self._think_closed: + self._think_closed = True + events.extend(self._make_content("")) + + if self._tool_call_started: + events.extend(self._make_tool_calls_finish()) + elif self._emitted_msg_start: + events.extend(self._make_stop()) + else: + # 空响应:补齐最小合法 SSE 序列 + events.extend(self._emit_msg_start()) + events.extend(self._make_stop()) + + events.append("data: [DONE]\n\n") + return events + + # ── SSE chunk 构造 ───────────────────────────────────────────────────────── + + def _emit_msg_start(self) -> list[str]: + """发送 role:assistant + content:"" 的首帧。""" + self._emitted_msg_start = True + return [ + self._sse( + { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "logprobs": None, + "finish_reason": None, + } + ], + } + ) + ] + + def _make_content(self, text: str) -> list[str]: + return [ + self._sse( + { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"content": text}, + "logprobs": None, + "finish_reason": None, + } + ], + } + ) + ] + + def _make_tool_call_start(self, name: str) -> list[str]: + """发送 function_call_start:携带 id、type、name 和空 arguments。""" + return [ + self._sse( 
+ { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": self._tool_call_index, + "id": self._tool_call_id, + "type": "function", + "function": {"name": name, "arguments": ""}, + } + ] + }, + "logprobs": None, + "finish_reason": None, + } + ], + } + ) + ] + + def _make_tool_args(self, delta: str) -> list[str]: + """逐字发送 arguments 增量。""" + return [ + self._sse( + { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": self._tool_call_index, + "function": {"arguments": delta}, + } + ] + }, + "logprobs": None, + "finish_reason": None, + } + ], + } + ) + ] + + def _make_tool_calls_finish(self) -> list[str]: + return [ + self._sse( + { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + } + ) + ] + + def _make_stop(self) -> list[str]: + return [ + self._sse( + { + "id": self._chat_id, + "object": "chat.completion.chunk", + "created": self._created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + ) + ] + + @staticmethod + def _sse(obj: dict) -> str: + return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n" diff --git a/core/api/routes.py b/core/api/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..d730001290cd0527c6968841d638af3c4e8be2df --- /dev/null +++ b/core/api/routes.py @@ -0,0 +1,177 @@ +""" +OpenAI 协议路由。 + +支持: +- /openai/{provider}/v1/chat/completions +- /openai/{provider}/v1/models +- 旧路径 /{provider}/v1/...(等价于 OpenAI 协议) +""" + +import json +import time +from collections.abc import AsyncIterator +from 
typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from core.api.auth import require_api_key +from core.api.chat_handler import ChatHandler +from core.config.repository import APP_SETTING_ENABLE_PRO_MODELS +from core.plugin.base import PluginRegistry +from core.protocol.openai import OpenAIProtocolAdapter +from core.protocol.schemas import CanonicalChatRequest +from core.protocol.service import CanonicalChatService + + +def get_chat_handler(request: Request) -> ChatHandler: + """从 app state 取出 ChatHandler。""" + handler = getattr(request.app.state, "chat_handler", None) + if handler is None: + raise HTTPException(status_code=503, detail="服务未就绪") + return handler + + +def resolve_request_model( + provider: str, + canonical_req: CanonicalChatRequest, +) -> CanonicalChatRequest: + resolved = PluginRegistry.resolve_model(provider, canonical_req.model) + canonical_req.model = resolved.public_model + canonical_req.metadata["upstream_model"] = resolved.upstream_model + return canonical_req + + +def check_pro_model_access( + request: Request, + provider: str, + model: str, +) -> JSONResponse | None: + """Return 403 JSONResponse if model requires Pro and Pro is disabled, else None.""" + plugin = PluginRegistry.get(provider) + if plugin is None: + return None + pro_models = getattr(plugin, "PRO_MODELS", frozenset()) + if model not in pro_models: + return None + config_repo = getattr(request.app.state, "config_repo", None) + if config_repo is None: + return None + enabled = config_repo.get_app_setting(APP_SETTING_ENABLE_PRO_MODELS) + if enabled == "true": + return None + return JSONResponse( + status_code=403, + content={ + "error": { + "message": ( + f"Model '{model}' requires a Claude Pro subscription. " + "Enable Pro models in the config page at /config." 
+ ), + "type": "model_not_available", + "code": "pro_model_required", + } + }, + ) + + +def create_router() -> APIRouter: + """创建 OpenAI 协议路由。""" + router = APIRouter(dependencies=[Depends(require_api_key)]) + adapter = OpenAIProtocolAdapter() + + def _list_models(provider: str) -> dict[str, Any]: + try: + metadata = PluginRegistry.model_metadata(provider) + except ValueError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + now = int(time.time()) + return { + "object": "list", + "data": [ + { + "id": mid, + "object": "model", + "created": now, + "owned_by": provider, + } + for mid in metadata["public_models"] + ], + } + + @router.get("/openai/{provider}/v1/models") + def list_models(provider: str) -> dict[str, Any]: + return _list_models(provider) + + @router.get("/{provider}/v1/models") + def list_models_legacy(provider: str) -> dict[str, Any]: + return _list_models(provider) + + async def _chat_completions( + provider: str, + request: Request, + handler: ChatHandler, + ) -> Any: + raw_body = await request.json() + try: + canonical_req = resolve_request_model( + provider, + adapter.parse_request(provider, raw_body), + ) + except Exception as exc: + status, payload = adapter.render_error(exc) + return JSONResponse(status_code=status, content=payload) + + pro_err = check_pro_model_access(request, provider, canonical_req.model) + if pro_err is not None: + return pro_err + + service = CanonicalChatService(handler) + if canonical_req.stream: + + async def sse_stream() -> AsyncIterator[str]: + try: + async for event in adapter.render_stream( + canonical_req, + service.stream_raw(canonical_req), + ): + yield event + except Exception as exc: + status, payload = adapter.render_error(exc) + del status + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + + return StreamingResponse( + sse_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + 
) + + try: + raw_events = await service.collect_raw(canonical_req) + return adapter.render_non_stream(canonical_req, raw_events) + except Exception as exc: + status, payload = adapter.render_error(exc) + return JSONResponse(status_code=status, content=payload) + + @router.post("/openai/{provider}/v1/chat/completions") + async def chat_completions( + provider: str, + request: Request, + handler: ChatHandler = Depends(get_chat_handler), + ) -> Any: + return await _chat_completions(provider, request, handler) + + @router.post("/{provider}/v1/chat/completions") + async def chat_completions_legacy( + provider: str, + request: Request, + handler: ChatHandler = Depends(get_chat_handler), + ) -> Any: + return await _chat_completions(provider, request, handler) + + return router diff --git a/core/api/schemas.py b/core/api/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..e47cbc2c4c696b2128e168105ddac0beb809c26e --- /dev/null +++ b/core/api/schemas.py @@ -0,0 +1,168 @@ +"""OpenAI 兼容的请求/响应模型。""" + +from typing import Any + +from pydantic import BaseModel, Field + +from core.api.conv_parser import strip_session_id_suffix + + +class OpenAIContentPart(BaseModel): + type: str + text: str | None = None + image_url: dict[str, Any] | str | None = None + + +class InputAttachment(BaseModel): + filename: str + mime_type: str + data: bytes + + +class OpenAIMessage(BaseModel): + role: str = Field(..., description="system | user | assistant | tool") + content: str | list[OpenAIContentPart] | None = "" + tool_calls: list[dict[str, Any]] | None = Field( + default=None, description="assistant 发起的工具调用" + ) + tool_call_id: str | None = Field( + default=None, description="tool 消息对应的 call id" + ) + + model_config = {"extra": "allow"} + + +class OpenAIChatRequest(BaseModel): + """OpenAI Chat Completions API 兼容请求体。""" + + model: str = Field(default="", description="模型名,可忽略") + messages: list[OpenAIMessage] = Field(..., description="对话列表") + stream: bool = 
Field(default=False, description="是否流式返回") + tools: list[dict] | None = Field( + default=None, + description='工具列表,每项为 {"type":"function","function":{name,description,parameters,strict?}}', + ) + tool_choice: str | dict | None = Field( + default=None, + description='工具选择: "auto"|"required"|"none" 或 {"type":"function","name":"xxx"}', + ) + parallel_tool_calls: bool | None = Field( + default=None, + description="是否允许单次响应中并行多个 tool_call,false 时仅 0 或 1 个", + ) + resume_session_id: str | None = Field(default=None, exclude=True) + upstream_model: str | None = Field(default=None, exclude=True) + attachment_files: list[InputAttachment] = Field( + default_factory=list, + exclude=True, + description="本次实际要发送给站点的附件,由 ChatHandler 根据 full_history 选择来源填充。", + ) + # 仅供内部调度使用:最后一条 user 消息里的附件 & 所有 user 消息里的附件 + attachment_files_last_user: list[InputAttachment] = Field( + default_factory=list, exclude=True + ) + attachment_files_all_users: list[InputAttachment] = Field( + default_factory=list, exclude=True + ) + + +def _norm_content(c: str | list[OpenAIContentPart] | None) -> str: + """将 content 转为单段字符串。仅支持官方格式:字符串或 type=text 的 content part(取 text 字段)。""" + if c is None: + return "" + if isinstance(c, str): + return strip_session_id_suffix(c) + if not isinstance(c, list): + return "" + return strip_session_id_suffix( + " ".join( + p.text or "" + for p in c + if isinstance(p, OpenAIContentPart) and p.type == "text" and p.text + ) + ) + + +REACT_STRICT_SUFFIX = ( + "(严格 ReAct 执行模式;禁止输出「无法执行工具所以直接给方案」等解释或替代内容)" +) + + +def extract_user_content( + messages: list[OpenAIMessage], + *, + has_tools: bool = False, + react_prompt_prefix: str = "", + full_history: bool = False, +) -> str: + """ + 从 messages 中提取对话,拼成发给模型的 prompt。 + 网页/会话侧已有完整历史,只取尾部:最后一条为 user 时,从后向前找到最后一个 assistant(不包含), + 取该 assistant 之后到末尾;最后一条为 tool 时,从后向前找到最后一个 user(不包含),取该 user 之后到末尾。 + 支持 user、assistant、tool 角色;assistant 的 tool_calls 与 tool 结果会拼回。 + ReAct 模式:完整 ReAct Prompt 仅第一次对话传入(按完整 messages 判断 
is_first_turn);后续只传尾部内容。 + """ + if not messages: + return "" + + parts: list[str] = [] + + # 重建会话时会把完整历史重新回放给站点,因此 tools 指令也需要重新注入。 + is_first_turn = not any(m.role in ("assistant", "tool") for m in messages) + if has_tools and react_prompt_prefix and (full_history or is_first_turn): + parts.append(react_prompt_prefix) + + if full_history: + tail = messages + else: + last = messages[-1] + if last.role == "user": + i = len(messages) - 1 + while i >= 0 and messages[i].role != "assistant": + i -= 1 + tail = messages[i + 1 :] + elif last.role == "tool": + i = len(messages) - 1 + while i >= 0 and messages[i].role != "user": + i -= 1 + tail = messages[i + 1 :] + else: + tail = messages[-2:] + + for m in tail: + if m.role == "system": + txt = _norm_content(m.content) + if txt: + parts.append(f"System:{txt}") + elif m.role == "user": + txt = _norm_content(m.content) + if txt: + if has_tools: + parts.append(f"**User**: {txt} {REACT_STRICT_SUFFIX}") + else: + parts.append(f"User:{txt}") + elif m.role == "assistant": + tool_calls_list = list(m.tool_calls or []) + if tool_calls_list: + for tc in tool_calls_list: + fn = tc.get("function") or {} + call_id = tc.get("id", "") + name = fn.get("name", "") + args = fn.get("arguments", "{}") + parts.append( + f"**Assistant**:\n\n```\nAction: {name}\nAction Input: {args}\nCall ID: {call_id}\n```" + ) + else: + txt = _norm_content(m.content) + if txt: + if has_tools: + parts.append(f"**Assistant**:\n\n{txt}") + else: + parts.append(f"Assistant:{txt}") + elif m.role == "tool": + txt = _norm_content(m.content) + call_id = m.tool_call_id or "" + parts.append( + f"**Observation(Call ID: {call_id})**: {txt}\n\n请根据以上观察结果继续。如需调用工具,输出 Thought / Action / Action Input;若任务已完成,输出 Final Answer。" + ) + return "\n".join(parts) diff --git a/core/app.py b/core/app.py new file mode 100644 index 0000000000000000000000000000000000000000..83a7f6153d75ba7c164fd0508cd0c76a7a8e2352 --- /dev/null +++ b/core/app.py @@ -0,0 +1,166 @@ +""" +FastAPI 
应用组装:配置加载、账号池、会话缓存、浏览器管理、插件注册、路由挂载。 +""" + +import asyncio +import logging +from contextlib import asynccontextmanager +from pathlib import Path +from typing import AsyncIterator + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, JSONResponse + +from core.account.pool import AccountPool +from core.api.auth import ( + AdminLoginAttemptStore, + AdminSessionStore, + configured_config_login_lock_seconds, + configured_config_login_max_failures, + config_login_enabled, + ensure_config_secret_hashed, + refresh_runtime_auth_settings, +) +from core.api.anthropic_routes import create_anthropic_router +from core.api.chat_handler import ChatHandler +from core.api.config_routes import create_config_router +from core.api.routes import create_router +from core.config.repository import create_config_repository +from core.config.settings import get, get_bool +from core.constants import CDP_PORT_RANGE, CHROMIUM_BIN +from core.plugin.base import PluginRegistry +from core.plugin.claude import register_claude_plugin +from core.runtime.browser_manager import BrowserManager +from core.runtime.session_cache import SessionCache + +logger = logging.getLogger(__name__) +STATIC_DIR = Path(__file__).resolve().parent / "static" + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + """启动时初始化配置与 ChatHandler,关闭时不做持久化(会话缓存进程内)。""" + # 注册插件 + register_claude_plugin() + + repo = create_config_repository() + repo.init_schema() + ensure_config_secret_hashed(repo) + app.state.config_repo = repo + auth_settings = refresh_runtime_auth_settings(app) + groups = repo.load_groups() + + chromium_bin = (get("browser", "chromium_bin") or "").strip() or CHROMIUM_BIN + headless = get_bool("browser", "headless", False) + no_sandbox = get_bool("browser", "no_sandbox", False) + disable_gpu = get_bool("browser", "disable_gpu", False) + disable_gpu_sandbox = get_bool("browser", "disable_gpu_sandbox", 
False) + cdp_wait_max_attempts = int(get("browser", "cdp_wait_max_attempts") or 90) + cdp_wait_interval_seconds = float( + get("browser", "cdp_wait_interval_seconds") or 2.0 + ) + cdp_wait_connect_timeout_seconds = float( + get("browser", "cdp_wait_connect_timeout_seconds") or 2.0 + ) + port_start = int(get("browser", "cdp_port_start") or 9223) + port_count = int(get("browser", "cdp_port_count") or 20) + port_range = ( + list(range(port_start, port_start + port_count)) + if port_count > 0 + else list(CDP_PORT_RANGE) + ) + api_keys = auth_settings.api_keys + pool = AccountPool.from_groups(groups) + session_cache = SessionCache() + browser_manager = BrowserManager( + chromium_bin=chromium_bin, + headless=headless, + no_sandbox=no_sandbox, + disable_gpu=disable_gpu, + disable_gpu_sandbox=disable_gpu_sandbox, + port_range=port_range, + cdp_wait_max_attempts=cdp_wait_max_attempts, + cdp_wait_interval_seconds=cdp_wait_interval_seconds, + cdp_wait_connect_timeout_seconds=cdp_wait_connect_timeout_seconds, + ) + app.state.chat_handler = ChatHandler( + pool=pool, + session_cache=session_cache, + browser_manager=browser_manager, + config_repo=repo, + ) + app.state.session_cache = session_cache + app.state.browser_manager = browser_manager + app.state.admin_sessions = AdminSessionStore() + app.state.admin_login_attempts = AdminLoginAttemptStore( + max_failures=configured_config_login_max_failures(), + lock_seconds=configured_config_login_lock_seconds(), + ) + if not groups: + logger.warning("数据库无配置,服务已启动但当前无可用账号") + if api_keys: + logger.info("API 鉴权已启用,已加载 %d 个 API Key", len(api_keys)) + if auth_settings.config_login_enabled: + logger.info( + "配置页登录已启用,失败 %d 次锁定 %d 秒", + app.state.admin_login_attempts.max_failures, + app.state.admin_login_attempts.lock_seconds, + ) + try: + await app.state.chat_handler.prewarm_resident_browsers() + except Exception: + logger.exception("启动预热浏览器失败") + app.state.maintenance_task = asyncio.create_task( + 
app.state.chat_handler.run_maintenance_loop() + ) + logger.info("服务已就绪,已注册 type: %s", ", ".join(PluginRegistry.all_types())) + yield + task = getattr(app.state, "maintenance_task", None) + handler = getattr(app.state, "chat_handler", None) + if handler is not None: + await handler.shutdown() + if task is not None: + try: + await task + except asyncio.CancelledError: + pass + app.state.chat_handler = None + + +def create_app() -> FastAPI: + app = FastAPI( + title="Web2API(Plugin)", + description="按 type 路由的 OpenAI 兼容接口,baseUrl: http://ip:port/{type}/v1/...", + lifespan=lifespan, + ) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/", include_in_schema=False) + def root() -> FileResponse: + return FileResponse(STATIC_DIR / "index.html") + + @app.get("/healthz", include_in_schema=False) + def healthz(request: Request) -> JSONResponse: + return JSONResponse( + { + "status": "ok", + "config_login_enabled": config_login_enabled(request), + "login": "/login", + "config": "/config", + } + ) + + app.include_router(create_router()) + app.include_router(create_anthropic_router()) + app.include_router(create_config_router()) + return app + + +app = create_app() diff --git a/core/config/__init__.py b/core/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2031a8504ca5ee3bfb8e9b01cb5b406b9fd3de7a --- /dev/null +++ b/core/config/__init__.py @@ -0,0 +1,10 @@ +"""配置层 数据模型与持久化(独立 DB,不修改现有 config_db)。""" + +from core.config.schema import AccountConfig, ProxyGroupConfig +from core.config.repository import ConfigRepository + +__all__ = [ + "AccountConfig", + "ProxyGroupConfig", + "ConfigRepository", +] diff --git a/core/config/repository.py b/core/config/repository.py new file mode 100644 index 0000000000000000000000000000000000000000..b50cbda3df1c8eee917da40261845e83b4da0e22 --- /dev/null +++ b/core/config/repository.py @@ -0,0 +1,593 @@ +""" 
+配置持久化:默认使用 SQLite;提供 DATABASE_URL / WEB2API_DATABASE_URL 时切换到 PostgreSQL。 +表结构:proxy_group, account(含 name, type, auth JSON),以及 app_setting。 +""" + +from __future__ import annotations + +import os +import sqlite3 +from pathlib import Path +from typing import Any + +from core.config.schema import AccountConfig, ProxyGroupConfig, account_from_row +from core.config.settings import coerce_bool, get_database_url + + +DB_FILENAME = "db.sqlite3" +DB_PATH_ENV_KEY = "WEB2API_DB_PATH" +APP_SETTING_AUTH_API_KEY = "auth.api_key" +APP_SETTING_AUTH_CONFIG_SECRET_HASH = "auth.config_secret_hash" +APP_SETTING_ENABLE_PRO_MODELS = "claude.enable_pro_models" + + +def _get_db_path() -> Path: + """SQLite 文件路径。""" + configured = os.environ.get(DB_PATH_ENV_KEY, "").strip() + if configured: + return Path(configured).expanduser() + return Path(__file__).resolve().parent.parent.parent / DB_FILENAME + + +def create_config_repository( + db_path: Path | None = None, + database_url: str | None = None, +) -> "ConfigRepository": + resolved_database_url = ( + get_database_url().strip() if database_url is None else database_url.strip() + ) + return ConfigRepository( + _PostgresConfigRepository(resolved_database_url) + if resolved_database_url + else _SqliteConfigRepository(db_path or _get_db_path()) + ) + + +class _RepositoryBase: + def init_schema(self) -> None: + raise NotImplementedError + + def load_groups(self) -> list[ProxyGroupConfig]: + raise NotImplementedError + + def save_groups(self, groups: list[ProxyGroupConfig]) -> None: + raise NotImplementedError + + def update_account_unfreeze_at( + self, + fingerprint_id: str, + account_name: str, + unfreeze_at: int | None, + ) -> None: + raise NotImplementedError + + def load_raw(self) -> list[dict[str, Any]]: + """与前端/API 一致的原始列表格式。""" + groups = self.load_groups() + return [ + { + "proxy_host": g.proxy_host, + "proxy_user": g.proxy_user, + "proxy_pass": g.proxy_pass, + "fingerprint_id": g.fingerprint_id, + "use_proxy": g.use_proxy, + 
"timezone": g.timezone, + "accounts": [ + { + "name": a.name, + "type": a.type, + "auth": a.auth, + "enabled": a.enabled, + "unfreeze_at": a.unfreeze_at, + } + for a in g.accounts + ], + } + for g in groups + ] + + def load_app_settings(self) -> dict[str, str]: + raise NotImplementedError + + def get_app_setting(self, key: str) -> str | None: + value = self.load_app_settings().get(key) + return value if value is not None else None + + def set_app_setting(self, key: str, value: str | None) -> None: + raise NotImplementedError + + def save_raw(self, raw: list[dict[str, Any]]) -> None: + """从 API/前端原始格式写入并保存。""" + groups = _raw_to_groups(raw) + self.save_groups(groups) + + +class _SqliteConfigRepository(_RepositoryBase): + def __init__(self, db_path: Path | None = None) -> None: + self._db_path = db_path or _get_db_path() + self._schema_initialized = False + + def _conn(self) -> sqlite3.Connection: + self._db_path.parent.mkdir(parents=True, exist_ok=True) + return sqlite3.connect(self._db_path) + + def _init_tables(self, conn: sqlite3.Connection) -> None: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS proxy_group ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + proxy_host TEXT NOT NULL, + proxy_user TEXT NOT NULL, + proxy_pass TEXT NOT NULL, + fingerprint_id TEXT NOT NULL DEFAULT '', + use_proxy INTEGER NOT NULL DEFAULT 1, + timezone TEXT + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS account ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + proxy_group_id INTEGER NOT NULL, + name TEXT NOT NULL, + type TEXT NOT NULL, + auth TEXT NOT NULL DEFAULT '{}', + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY (proxy_group_id) REFERENCES proxy_group(id) ON DELETE CASCADE + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS ix_account_proxy_group_id ON account(proxy_group_id)" + ) + conn.execute("CREATE INDEX IF NOT EXISTS ix_account_type ON account(type)") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS app_setting ( + key TEXT PRIMARY KEY, + value TEXT 
NOT NULL DEFAULT '' + ) + """ + ) + try: + conn.execute("ALTER TABLE account ADD COLUMN unfreeze_at INTEGER") + except sqlite3.OperationalError: + pass + try: + conn.execute( + "ALTER TABLE account ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1" + ) + except sqlite3.OperationalError: + pass + try: + conn.execute( + "ALTER TABLE proxy_group ADD COLUMN use_proxy INTEGER NOT NULL DEFAULT 1" + ) + except sqlite3.OperationalError: + pass + try: + conn.execute("ALTER TABLE proxy_group ADD COLUMN timezone TEXT") + except sqlite3.OperationalError: + pass + conn.commit() + + def _ensure_schema(self) -> None: + if self._schema_initialized: + return + conn = self._conn() + try: + self._init_tables(conn) + self._schema_initialized = True + finally: + conn.close() + + def init_schema(self) -> None: + self._ensure_schema() + + def load_groups(self) -> list[ProxyGroupConfig]: + self._ensure_schema() + conn = self._conn() + try: + groups: list[ProxyGroupConfig] = [] + group_rows = conn.execute( + """ + SELECT id, proxy_host, proxy_user, proxy_pass, fingerprint_id, use_proxy, timezone + FROM proxy_group ORDER BY id ASC + """ + ).fetchall() + accounts_by_group: dict[int, list[AccountConfig]] = {} + for gid, name, type_, auth_json, enabled, unfreeze_at in conn.execute( + """ + SELECT proxy_group_id, name, type, auth, enabled, unfreeze_at + FROM account ORDER BY proxy_group_id ASC, id ASC + """ + ).fetchall(): + accounts_by_group.setdefault(int(gid), []).append( + account_from_row( + name, + type_, + auth_json or "{}", + enabled=bool(enabled) if enabled is not None else True, + unfreeze_at=unfreeze_at, + ) + ) + for gid, proxy_host, proxy_user, proxy_pass, fingerprint_id, use_proxy, timezone in group_rows: + groups.append( + ProxyGroupConfig( + proxy_host=proxy_host, + proxy_user=proxy_user, + proxy_pass=proxy_pass, + fingerprint_id=fingerprint_id or "", + use_proxy=bool(use_proxy), + timezone=timezone, + accounts=accounts_by_group.get(int(gid), []), + ) + ) + return groups + finally: 
+            conn.close()
+
+    def save_groups(self, groups: list[ProxyGroupConfig]) -> None:
+        self._ensure_schema()
+        conn = self._conn()
+        try:
+            conn.execute("DELETE FROM account")
+            conn.execute("DELETE FROM proxy_group")
+            for group in groups:
+                cur = conn.execute(
+                    """
+                    INSERT INTO proxy_group (proxy_host, proxy_user, proxy_pass, fingerprint_id, use_proxy, timezone)
+                    VALUES (?, ?, ?, ?, ?, ?)
+                    """,
+                    (
+                        group.proxy_host,
+                        group.proxy_user,
+                        group.proxy_pass,
+                        group.fingerprint_id,
+                        1 if group.use_proxy else 0,
+                        group.timezone,
+                    ),
+                )
+                gid = cur.lastrowid
+                for account in group.accounts:
+                    conn.execute(
+                        """
+                        INSERT INTO account (proxy_group_id, name, type, auth, enabled, unfreeze_at)
+                        VALUES (?, ?, ?, ?, ?, ?)
+                        """,
+                        (
+                            gid,
+                            account.name,
+                            account.type,
+                            account.auth_json(),
+                            1 if account.enabled else 0,
+                            account.unfreeze_at,
+                        ),
+                    )
+            conn.commit()
+        finally:
+            conn.close()
+
+    def update_account_unfreeze_at(
+        self,
+        fingerprint_id: str,
+        account_name: str,
+        unfreeze_at: int | None,
+    ) -> None:
+        self._ensure_schema()
+        conn = self._conn()
+        try:
+            conn.execute(
+                """
+                UPDATE account SET unfreeze_at = ?
+                WHERE proxy_group_id = (SELECT id FROM proxy_group WHERE fingerprint_id = ? ORDER BY id ASC LIMIT 1) -- deterministic; keep in sync with _PostgresConfigRepository
+                  AND name = ?
+                """,
+                (unfreeze_at, fingerprint_id, account_name),
+            )
+            conn.commit()
+        finally:
+            conn.close()
+
+    def load_app_settings(self) -> dict[str, str]:
+        self._ensure_schema()
+        conn = self._conn()
+        try:
+            rows = conn.execute(
+                "SELECT key, value FROM app_setting ORDER BY key ASC"
+            ).fetchall()
+            return {str(key): str(value) for key, value in rows}
+        finally:
+            conn.close()
+
+    def set_app_setting(self, key: str, value: str | None) -> None:
+        self._ensure_schema()
+        conn = self._conn()
+        try:
+            if value is None:
+                conn.execute("DELETE FROM app_setting WHERE key = ?", (key,))
+            else:
+                conn.execute(
+                    """
+                    INSERT INTO app_setting (key, value) VALUES (?, ?)
+ ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (key, value), + ) + conn.commit() + finally: + conn.close() + + +class _PostgresConfigRepository(_RepositoryBase): + def __init__(self, database_url: str) -> None: + self._database_url = database_url + + def _conn(self) -> Any: + import psycopg + + return psycopg.connect(self._database_url) + + def init_schema(self) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS proxy_group ( + id BIGSERIAL PRIMARY KEY, + proxy_host TEXT NOT NULL, + proxy_user TEXT NOT NULL, + proxy_pass TEXT NOT NULL, + fingerprint_id TEXT NOT NULL DEFAULT '', + use_proxy BOOLEAN NOT NULL DEFAULT TRUE, + timezone TEXT + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS account ( + id BIGSERIAL PRIMARY KEY, + proxy_group_id BIGINT NOT NULL REFERENCES proxy_group(id) ON DELETE CASCADE, + name TEXT NOT NULL, + type TEXT NOT NULL, + auth TEXT NOT NULL DEFAULT '{}', + enabled BOOLEAN NOT NULL DEFAULT TRUE, + unfreeze_at BIGINT + ) + """ + ) + cur.execute( + "CREATE INDEX IF NOT EXISTS ix_account_proxy_group_id ON account(proxy_group_id)" + ) + cur.execute("CREATE INDEX IF NOT EXISTS ix_account_type ON account(type)") + cur.execute( + """ + CREATE TABLE IF NOT EXISTS app_setting ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL DEFAULT '' + ) + """ + ) + + def load_groups(self) -> list[ProxyGroupConfig]: + groups: list[ProxyGroupConfig] = [] + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, proxy_host, proxy_user, proxy_pass, fingerprint_id, use_proxy, timezone + FROM proxy_group ORDER BY id ASC + """ + ) + group_rows = cur.fetchall() + cur.execute( + """ + SELECT proxy_group_id, name, type, auth, enabled, unfreeze_at + FROM account ORDER BY proxy_group_id ASC, id ASC + """ + ) + accounts_by_group: dict[int, list[AccountConfig]] = {} + for gid, name, type_, auth_json, enabled, unfreeze_at in cur.fetchall(): + 
accounts_by_group.setdefault(int(gid), []).append( + account_from_row( + name, + type_, + auth_json or "{}", + enabled=bool(enabled) if enabled is not None else True, + unfreeze_at=unfreeze_at, + ) + ) + for row in group_rows: + ( + gid, + proxy_host, + proxy_user, + proxy_pass, + fingerprint_id, + use_proxy, + timezone, + ) = row + groups.append( + ProxyGroupConfig( + proxy_host=proxy_host, + proxy_user=proxy_user, + proxy_pass=proxy_pass, + fingerprint_id=fingerprint_id or "", + use_proxy=bool(use_proxy), + timezone=timezone, + accounts=accounts_by_group.get(int(gid), []), + ) + ) + return groups + + def save_groups(self, groups: list[ProxyGroupConfig]) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM account") + cur.execute("DELETE FROM proxy_group") + for group in groups: + cur.execute( + """ + INSERT INTO proxy_group (proxy_host, proxy_user, proxy_pass, fingerprint_id, use_proxy, timezone) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING id + """, + ( + group.proxy_host, + group.proxy_user, + group.proxy_pass, + group.fingerprint_id, + group.use_proxy, + group.timezone, + ), + ) + gid = cur.fetchone()[0] + for account in group.accounts: + cur.execute( + """ + INSERT INTO account (proxy_group_id, name, type, auth, enabled, unfreeze_at) + VALUES (%s, %s, %s, %s, %s, %s) + """, + ( + gid, + account.name, + account.type, + account.auth_json(), + account.enabled, + account.unfreeze_at, + ), + ) + + def update_account_unfreeze_at( + self, + fingerprint_id: str, + account_name: str, + unfreeze_at: int | None, + ) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE account SET unfreeze_at = %s + WHERE proxy_group_id = ( + SELECT id FROM proxy_group WHERE fingerprint_id = %s ORDER BY id ASC LIMIT 1 + ) + AND name = %s + """, + (unfreeze_at, fingerprint_id, account_name), + ) + + def load_app_settings(self) -> dict[str, str]: + with self._conn() as conn: + with conn.cursor() as cur: + 
cur.execute("SELECT key, value FROM app_setting ORDER BY key ASC") + return {str(key): str(value) for key, value in cur.fetchall()} + + def set_app_setting(self, key: str, value: str | None) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + if value is None: + cur.execute("DELETE FROM app_setting WHERE key = %s", (key,)) + else: + cur.execute( + """ + INSERT INTO app_setting (key, value) VALUES (%s, %s) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """, + (key, value), + ) + + +class ConfigRepository(_RepositoryBase): + """配置读写入口。""" + + def __init__(self, backend: _RepositoryBase) -> None: + self._backend = backend + + def init_schema(self) -> None: + self._backend.init_schema() + + def load_groups(self) -> list[ProxyGroupConfig]: + return self._backend.load_groups() + + def save_groups(self, groups: list[ProxyGroupConfig]) -> None: + self._backend.save_groups(groups) + + def load_raw(self) -> list[dict[str, Any]]: + return self._backend.load_raw() + + def load_app_settings(self) -> dict[str, str]: + return self._backend.load_app_settings() + + def get_app_setting(self, key: str) -> str | None: + return self._backend.get_app_setting(key) + + def set_app_setting(self, key: str, value: str | None) -> None: + self._backend.set_app_setting(key, value) + + def save_raw(self, raw: list[dict[str, Any]]) -> None: + self._backend.save_raw(raw) + + def update_account_unfreeze_at( + self, + fingerprint_id: str, + account_name: str, + unfreeze_at: int | None, + ) -> None: + self._backend.update_account_unfreeze_at( + fingerprint_id, + account_name, + unfreeze_at, + ) + + +def _raw_to_groups(raw: list[dict[str, Any]]) -> list[ProxyGroupConfig]: + """将 API 原始列表转为 ProxyGroupConfig 列表。""" + groups: list[ProxyGroupConfig] = [] + for group in raw: + accounts: list[AccountConfig] = [] + for account in group.get("accounts", []): + name = str(account.get("name", "")).strip() + type_ = str(account.get("type", "")).strip() or "claude" + auth = 
account.get("auth") + if isinstance(auth, dict): + pass + elif isinstance(auth, str): + try: + import json + + auth = json.loads(auth) or {} + except Exception: + auth = {} + else: + auth = {} + if name: + enabled = coerce_bool(account.get("enabled", True), True) + unfreeze_at = account.get("unfreeze_at") + if isinstance(unfreeze_at, (int, float)): + unfreeze_at = int(unfreeze_at) + else: + unfreeze_at = None + accounts.append( + AccountConfig( + name=name, + type=type_, + auth=auth, + enabled=enabled, + unfreeze_at=unfreeze_at, + ) + ) + groups.append( + ProxyGroupConfig( + proxy_host=str(group.get("proxy_host", "")), + proxy_user=str(group.get("proxy_user", "")), + proxy_pass=str(group.get("proxy_pass", "")), + fingerprint_id=str(group.get("fingerprint_id", "")), + use_proxy=coerce_bool(group.get("use_proxy", True), True), + timezone=group.get("timezone"), + accounts=accounts, + ) + ) + return groups diff --git a/core/config/schema.py b/core/config/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..49c0391d332a2f4701fafe47069cd47087e5befb --- /dev/null +++ b/core/config/schema.py @@ -0,0 +1,76 @@ +""" +配置数据模型:按代理 IP(指纹)分组,账号含 name / type / auth(JSON)。 +不设 profile_id,user-data-dir 按指纹等由运行时拼接。 +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class AccountConfig: + """单个账号:名称、类别、认证 JSON。一个账号只属于一个 type。""" + + name: str + type: str # 如 claude, chatgpt, kimi + auth: dict[str, Any] # 由各插件定义 key,如 claude 用 sessionKey + enabled: bool = True + unfreeze_at: int | None = ( + None # Unix 时间戳,接口返回的解冻时间;None 或已过则视为可用 + ) + + def auth_json(self) -> str: + """序列化为 JSON 字符串供 DB 存储。""" + import json + + return json.dumps(self.auth, ensure_ascii=False) + + def is_available(self) -> bool: + """已启用且当前时间 >= 解冻时间则可用(无解冻时间视为可用)。""" + if not self.enabled: + return False + if self.unfreeze_at is None: + return True + import time + + return time.time() >= self.unfreeze_at + + +@dataclass +class 
ProxyGroupConfig: + """一个代理 IP 组:代理参数 + 指纹 + 下属账号列表。""" + + proxy_host: str + proxy_user: str + proxy_pass: str + fingerprint_id: str + use_proxy: bool = True + timezone: str | None = None + accounts: list[AccountConfig] = field(default_factory=list) + + def account_ids(self) -> list[str]: + """返回该组下账号的唯一标识,用于会话缓存等。格式 group_idx 由 repository 注入前不可用,这里用 name 区分。""" + return [a.name for a in self.accounts] + + +def account_from_row( + name: str, + type: str, + auth_json: str, + enabled: bool = True, + unfreeze_at: int | None = None, +) -> AccountConfig: + """从 DB 行构造 AccountConfig。""" + import json + + try: + auth = json.loads(auth_json) if auth_json else {} + except Exception: + auth = {} + return AccountConfig( + name=name, + type=type, + auth=auth, + enabled=enabled, + unfreeze_at=unfreeze_at, + ) diff --git a/core/config/settings.py b/core/config/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..4008fa4345ccfe7df7c311aa59c4affa38d37378 --- /dev/null +++ b/core/config/settings.py @@ -0,0 +1,147 @@ +""" +统一的 YAML 配置加载。 + +优先级: +1. WEB2API_CONFIG_PATH 指定的路径 +2. 项目根目录下的 config.local.yaml +3. 项目根目录下的 config.yaml + +同时支持通过环境变量覆盖单个配置项: +- 通用规则:WEB2API_
{SECTION}_{KEY}
+- 额外兼容:server.host -> HOST,server.port -> PORT
+"""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
+_CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH"
+_LOCAL_CONFIG_NAME = "config.local.yaml"
+_DEFAULT_CONFIG_NAME = "config.yaml"
+_ENV_MISSING = object()
+_ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = {
+    ("server", "host"): ("HOST",),
+    ("server", "port"): ("PORT",),
+}
+_DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL")
+_BOOL_TRUE_VALUES = {"1", "true", "yes", "on"}
+_BOOL_FALSE_VALUES = {"0", "false", "no", "off"}
+
+
+def _resolve_config_path() -> Path:
+    configured = os.environ.get(_CONFIG_ENV_KEY, "").strip()
+    if configured:
+        return Path(configured).expanduser()
+    local_config = _PROJECT_ROOT / _LOCAL_CONFIG_NAME
+    if local_config.exists():
+        return local_config
+    return _PROJECT_ROOT / _DEFAULT_CONFIG_NAME
+
+
+_CONFIG_PATH = _resolve_config_path()
+
+_config_cache: dict[str, Any] | None = None
+
+
+def _env_override_names(section: str, key: str) -> tuple[str, ...]:
+    generic = f"WEB2API_{section}_{key}".upper().replace("-", "_")
+    aliases = _ENV_OVERRIDE_ALIASES.get((section, key), ())
+    ordered = [generic]
+    ordered.extend(alias for alias in aliases if alias != generic)
+    return tuple(ordered)
+
+
+def _get_env_override(section: str, key: str) -> Any:
+    for env_name in _env_override_names(section, key):
+        if env_name in os.environ:
+            return os.environ[env_name]
+    return _ENV_MISSING
+
+
+def has_env_override(section: str, key: str) -> bool:
+    return _get_env_override(section, key) is not _ENV_MISSING
+
+
+def get_config_path() -> Path:
+    return _CONFIG_PATH
+
+
+def reset_cache() -> None:
+    global _config_cache
+    _config_cache = None
+
+
+def load_config() -> dict[str, Any]:
+    """按优先级加载配置文件,不存在时返回空 dict。"""
+    global _config_cache
+    if _config_cache is not None:
+        return _config_cache
+    if not 
_CONFIG_PATH.exists(): + _config_cache = {} + return {} + try: + with _CONFIG_PATH.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + if not isinstance(data, dict): + _config_cache = {} + else: + _config_cache = dict(data) + except Exception: + _config_cache = {} + return _config_cache + + +def get(section: str, key: str, default: Any = None) -> Any: + """从配置中读取 section.key,环境变量优先,其次 YAML,最后返回 default。""" + env_override = _get_env_override(section, key) + if env_override is not _ENV_MISSING: + return env_override + cfg = load_config().get(section) or {} + if not isinstance(cfg, dict): + return default + val = cfg.get(key) + return val if val is not None else default + + +def coerce_bool(value: Any, default: bool = False) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in _BOOL_TRUE_VALUES: + return True + if normalized in _BOOL_FALSE_VALUES: + return False + return bool(default) + + +def get_bool(section: str, key: str, default: bool = False) -> bool: + """从配置读取布尔值,兼容 true/false、1/0、yes/no、on/off。""" + return coerce_bool(get(section, key, default), default) + + +def get_server_host(default: str = "127.0.0.1") -> str: + return str(get("server", "host") or default).strip() or default + + +def get_server_port(default: int = 8001) -> int: + try: + return int(str(get("server", "port") or default).strip()) + except Exception: + return default + + +def get_database_url(default: str = "") -> str: + for env_name in _DATABASE_URL_ENV_NAMES: + value = os.environ.get(env_name, "").strip() + if value: + return value + return default diff --git a/core/constants.py b/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..63b5954c893dd680903b36542cb98cd4743ddfa1 --- /dev/null +++ b/core/constants.py @@ -0,0 +1,17 @@ +"""全局常量:浏览器路径、CDP 端口等(新架构专用)。""" + +from pathlib import Path + +# 与现有 
multi_web2api 保持一致,便于同机运行时分端口
+CHROMIUM_BIN = "/Applications/Chromium.app/Contents/MacOS/Chromium"
+REMOTE_DEBUGGING_PORT = 9223  # 默认端口,单浏览器兼容
+# 多浏览器并存时的端口池(按 ProxyKey 各占一端口,仅当 refcount=0 时关闭并回收端口)
+CDP_PORT_RANGE = list(range(9223, 9243))  # 9223..9242,最多 20 个并发浏览器
+CDP_ENDPOINT = "http://127.0.0.1:9223"
+TIMEZONE = "America/Chicago"
+USER_DATA_DIR_PREFIX = "fp-data"  # user_data_dir = home / fp-data / fingerprint_id
+
+
+def user_data_dir(fingerprint_id: str) -> Path:
+    """按指纹 ID 拼接 user-data-dir,不依赖 profile_id。"""
+    return Path.home() / USER_DATA_DIR_PREFIX / fingerprint_id
diff --git a/core/hub/__init__.py b/core/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8434326d2c97ccb70c0cb8a1ccccd80af99b4bbc
--- /dev/null
+++ b/core/hub/__init__.py
@@ -0,0 +1,14 @@
+"""
+Hub 层:以 **OpenAI 语义**作为唯一中间态。
+
+设计目标:
+- 插件侧:把“站点平台格式”转换成 OpenAI 语义(请求 + 结构化流事件)。
+- 协议侧:把 OpenAI 语义转换成不同对外协议(OpenAI / Anthropic / Kimi ...)。
+
+当前仓库历史上存在 Canonical 模型用于多协议解析;Hub 层用于把“内部执行语义”
+固定为 OpenAI,降低插件/协议扩展的学习成本。
+"""
+
+from .schemas import OpenAIStreamEvent
+
+__all__ = ["OpenAIStreamEvent"]
diff --git a/core/hub/openai_sse.py b/core/hub/openai_sse.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e23697b8a85830fcc93cc221ba9b157e8e4cfd4
--- /dev/null
+++ b/core/hub/openai_sse.py
@@ -0,0 +1,134 @@
+"""
+把 OpenAIStreamEvent 编码为 OpenAI ChatCompletions SSE chunk。
+
+这是 Hub 层的“协议输出工具”,用于把插件输出的结构化事件流转换为
+OpenAI 兼容的 `data: {...}\\n\\n` 片段。
+
+当前不替换既有渲染链路,仅提供给后续协议/插件扩展使用。
+"""
+
+from __future__ import annotations
+
+import json
+import time
+import uuid as uuid_mod
+from collections.abc import AsyncIterator, Iterator
+
+from core.hub.schemas import OpenAIStreamEvent
+
+
+def make_openai_stream_context(*, model: str) -> tuple[str, int]:
+    """生成 OpenAI SSE 上下文:chat_id + created。"""
+    chat_id = f"chatcmpl-{uuid_mod.uuid4().hex[:24]}"
+    created = int(time.time())
+    # model 由上层写入 payload
+    del model
+    return chat_id, created
+
+
+def _chunk(
+    *, 
+ chat_id: str, + model: str, + created: int, + delta: dict, + finish_reason: str | None = None, +) -> str: + return ( + "data: " + + json.dumps( + { + "id": chat_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + }, + ensure_ascii=False, + ) + + "\n\n" + ) + + +def encode_openai_sse_events( + events: Iterator[OpenAIStreamEvent], + *, + chat_id: str, + model: str, + created: int, +) -> Iterator[str]: + """同步编码器:OpenAIStreamEvent -> OpenAI SSE strings。""" + # 兼容主流 OpenAI SSE 客户端:先发一帧 role:assistant + content:"" + yield _chunk( + chat_id=chat_id, + model=model, + created=created, + delta={"role": "assistant", "content": ""}, + finish_reason=None, + ) + for ev in events: + if ev.type == "content_delta": + if ev.content: + yield _chunk( + chat_id=chat_id, + model=model, + created=created, + delta={"content": ev.content}, + finish_reason=None, + ) + elif ev.type == "tool_call_delta": + if ev.tool_calls: + yield _chunk( + chat_id=chat_id, + model=model, + created=created, + delta={"tool_calls": [tc.model_dump() for tc in ev.tool_calls]}, + finish_reason=None, + ) + elif ev.type == "finish": + # OpenAI 的结束 chunk 允许 delta 为空对象 + yield _chunk( + chat_id=chat_id, + model=model, + created=created, + delta={}, + finish_reason=ev.finish_reason or "stop", + ) + yield "data: [DONE]\n\n" + return + elif ev.type == "error": + # OpenAI SSE 没有标准 error 事件,这里用 data 包一层 error 对象(与现有实现一致风格) + msg = ev.error or "unknown error" + yield ( + "data: " + + json.dumps( + {"error": {"message": msg, "type": "server_error"}}, + ensure_ascii=False, + ) + + "\n\n" + ) + + +async def encode_openai_sse_events_async( + events: AsyncIterator[OpenAIStreamEvent], + *, + chat_id: str, + model: str, + created: int, +) -> AsyncIterator[str]: + """异步编码器:OpenAIStreamEvent -> OpenAI SSE strings。""" + async for ev in events: + for out in encode_openai_sse_events( + 
iter([ev]), + chat_id=chat_id, + model=model, + created=created, + ): + yield out diff --git a/core/hub/schemas.py b/core/hub/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f22e0b01bebabd7ce4ba84978681f76b6f7687 --- /dev/null +++ b/core/hub/schemas.py @@ -0,0 +1,46 @@ +""" +OpenAI 语义的结构化流事件(唯一流式中间态)。 + +整条链路:插件产出字符串流 → core 包装为 content_delta + finish → +协议适配层消费事件、编码为各协议 SSE(OpenAI / Anthropic / 未来 Kimi 等)。 +""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class OpenAIToolCallDelta(BaseModel): + """OpenAI stream delta 中的 tool_calls[?] 片段(最小必要字段)。""" + + index: int = 0 + id: str | None = None + type: Literal["function"] = "function" + function: dict[str, Any] = Field(default_factory=dict) + + +class OpenAIStreamEvent(BaseModel): + """ + OpenAI 语义的“内部流事件”。 + - content_delta:增量文本(delta.content) + - tool_call_delta:工具调用增量(delta.tool_calls) + - finish:结束(finish_reason) + - error:错误 + 协议适配层负责将事件序列化为目标协议的 SSE/JSON。 + """ + + type: Literal["content_delta", "tool_call_delta", "finish", "error"] + + # content_delta + content: str | None = None + + # tool_call_delta + tool_calls: list[OpenAIToolCallDelta] | None = None + + # finish + finish_reason: str | None = None + + # error + error: str | None = None diff --git a/core/plugin/__init__.py b/core/plugin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41272651774c3d2c8acb9286779b2207f194f429 --- /dev/null +++ b/core/plugin/__init__.py @@ -0,0 +1,5 @@ +"""插件层:抽象接口与注册表,各 type 实现 create_page / apply_auth / create_conversation / stream_completion。""" + +from core.plugin.base import AbstractPlugin, BaseSitePlugin, PluginRegistry, SiteConfig + +__all__ = ["AbstractPlugin", "BaseSitePlugin", "PluginRegistry", "SiteConfig"] diff --git a/core/plugin/base.py b/core/plugin/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8c87e859994267e2457b9f6c0fe9c154045b1301 --- 
/dev/null +++ b/core/plugin/base.py @@ -0,0 +1,519 @@ +""" +插件抽象与注册表:type_name -> 插件实现。 + +三层设计: + AbstractPlugin — 最底层接口,理论上支持任意协议(非 Cookie、非 SSE 的站点也能接)。 + BaseSitePlugin — Cookie 认证 + SSE 流式站点的通用编排,插件开发者继承它只需实现 5 个 hook。 + PluginRegistry — 全局注册表。 +""" + +import json +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, AsyncIterator + +from playwright.async_api import BrowserContext, Page + +from core.api.schemas import InputAttachment +from core.config.settings import get +from core.plugin.errors import ( # noqa: F401 — re-export for backward compat + AccountFrozenError, + BrowserResourceInvalidError, +) +from core.plugin.helpers import ( + apply_cookie_auth, + create_page_for_site, + stream_completion_via_sse, +) + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ResolvedModel: + public_model: str + upstream_model: str + + +# --------------------------------------------------------------------------- +# SiteConfig:纯声明式站点配置 +# --------------------------------------------------------------------------- + + +@dataclass +class SiteConfig: + """Cookie 认证站点的声明式配置,插件开发者只需填字段,无需写任何方法。""" + + start_url: str + api_base: str + cookie_name: str + cookie_domain: str + auth_keys: list[str] + config_section: str = ( + "" # config.yaml 中的 section,如 "claude",用于覆盖 start_url/api_base + ) + + +# --------------------------------------------------------------------------- +# AbstractPlugin — 最底层抽象接口 +# --------------------------------------------------------------------------- + + +class AbstractPlugin(ABC): + """ + 各 type(如 claude、kimi)需实现此接口并注册。 + 若站点基于 Cookie + SSE,推荐直接继承 BaseSitePlugin 而非此类。 + """ + + def __init__(self) -> None: + self._session_state: dict[str, dict[str, Any]] = {} + + type_name: str + + async def create_page( + self, context: BrowserContext, reuse_page: Page | None = None + ) -> Page: + raise NotImplementedError + + async def apply_auth( + self, + context: BrowserContext, + 
page: Page, + auth: dict[str, Any], + *, + reload: bool = True, + **kwargs: Any, + ) -> None: + raise NotImplementedError + + async def ensure_request_ready( + self, + context: BrowserContext, + page: Page, + *, + request_id: str = "", + session_id: str | None = None, + phase: str = "", + account_id: str = "", + ) -> None: + del context, page, request_id, session_id, phase, account_id + + async def create_conversation( + self, + context: BrowserContext, + page: Page, + **kwargs: Any, + ) -> str | None: + raise NotImplementedError + + async def stream_completion( + self, + context: BrowserContext, + page: Page, + session_id: str, + message: str, + **kwargs: Any, + ) -> AsyncIterator[str]: + if False: + yield # 使抽象方法为 async generator,与子类一致,便于 async for 迭代 + raise NotImplementedError + + def parse_session_id(self, messages: list[dict[str, Any]]) -> str | None: + return None + + def is_stream_end_event(self, payload: str) -> bool: + """判断某条流式 payload 是否表示本轮响应已正常结束。默认不识别。""" + return False + + def has_session(self, session_id: str) -> bool: + return session_id in self._session_state + + def drop_session(self, session_id: str) -> None: + self._session_state.pop(session_id, None) + + def drop_sessions(self, session_ids: list[str] | set[str]) -> None: + for session_id in session_ids: + self._session_state.pop(session_id, None) + + def model_mapping(self) -> dict[str, str] | None: + """子类可覆盖;BaseSitePlugin 会从 config_section 的 model_mapping 读取。""" + return None + + def normalized_model_mapping(self) -> dict[str, str]: + mapping = self.model_mapping() + if not isinstance(mapping, dict) or not mapping: + raise ValueError("model_mapping is not implemented") + normalized: dict[str, str] = {} + for public_model, upstream_model in mapping.items(): + public_id = str(public_model or "").strip() + upstream_id = str(upstream_model or "").strip() + if public_id and upstream_id: + normalized[public_id] = upstream_id + if not normalized: + raise ValueError("model_mapping is not 
implemented") + return normalized + + def listed_model_mapping(self) -> dict[str, str]: + return self.normalized_model_mapping() + + def default_public_model(self) -> str: + listed = self.listed_model_mapping() + if listed: + return next(iter(listed)) + return next(iter(self.normalized_model_mapping())) + + def resolve_model(self, model: str | None) -> ResolvedModel: + mapping = self.normalized_model_mapping() + requested = str(model or "").strip() + if not requested: + default_public = self.default_public_model() + return ResolvedModel( + public_model=default_public, + upstream_model=mapping[default_public], + ) + if requested in mapping: + return ResolvedModel( + public_model=requested, + upstream_model=mapping[requested], + ) + for public_model, upstream_model in mapping.items(): + if requested == upstream_model: + return ResolvedModel( + public_model=public_model, + upstream_model=upstream_model, + ) + supported = ", ".join(mapping.keys()) + raise ValueError(f"Unknown model: {requested}. 
Supported models: {supported}") + + def on_http_error(self, message: str, headers: dict[str, str] | None) -> int | None: + return None + + def stream_transport(self) -> str: + return "page_fetch" + + def stream_transport_options( + self, + context: BrowserContext, + page: Page, + session_id: str, + state: dict[str, Any], + **kwargs: Any, + ) -> dict[str, Any]: + del context, page, session_id, state + options: dict[str, Any] = {} + proxy_url = str(kwargs.get("proxy_url") or "").strip() + proxy_auth = kwargs.get("proxy_auth") + if proxy_url: + options["proxy_url"] = proxy_url + if isinstance(proxy_auth, tuple) and len(proxy_auth) == 2: + options["proxy_auth"] = proxy_auth + return options + + +# --------------------------------------------------------------------------- +# BaseSitePlugin — Cookie + SSE 站点的通用编排 +# --------------------------------------------------------------------------- + + +class BaseSitePlugin(AbstractPlugin): + """ + Cookie 认证 + SSE 流式站点的公共基类。 + + 插件开发者继承此类后,只需: + 1. 声明 site = SiteConfig(...) — 站点配置 + 2. 实现 fetch_site_context() — 获取站点上下文(如 org/user 信息) + 3. 实现 create_session() — 调用站点 API 创建会话 + 4. 实现 build_completion_url/body() — 拼补全请求的 URL 与 body + 5. 
实现 parse_stream_event() — 解析单条流式事件(如 SSE data) + + create_page / apply_auth / create_conversation / stream_completion + 均由基类自动编排,无需重写。 + """ + + site: SiteConfig # 子类必须赋值 + + # ---- 从 config.yaml 读取的 URL 属性(config_section 有值时覆盖默认) ---- + + @property + def start_url(self) -> str: + if self.site.config_section: + url = get(self.site.config_section, "start_url") + if url: + return str(url).strip() + return self.site.start_url + + @property + def api_base(self) -> str: + if self.site.config_section: + base = get(self.site.config_section, "api_base") + if base: + return str(base).strip() + return self.site.api_base + + def model_mapping(self) -> dict[str, str] | None: + """从 config 的 config_section.model_mapping 读取;未配置时返回 None。""" + if self.site.config_section: + m = get(self.site.config_section, "model_mapping") + if isinstance(m, dict) and m: + return {str(k): str(v) for k, v in m.items()} + return None + + # ---- 基类全自动实现,子类无需碰 ---- + + async def create_page( + self, + context: BrowserContext, + reuse_page: Page | None = None, + ) -> Page: + return await create_page_for_site( + context, self.start_url, reuse_page=reuse_page + ) + + async def apply_auth( + self, + context: BrowserContext, + page: Page, + auth: dict[str, Any], + *, + reload: bool = True, + ) -> None: + await apply_cookie_auth( + context, + page, + auth, + self.site.cookie_name, + self.site.auth_keys, + self.site.cookie_domain, + reload=reload, + ) + + async def create_conversation( + self, + context: BrowserContext, + page: Page, + **kwargs: Any, + ) -> str | None: + extra_kwargs = dict(kwargs) + request_id = str(extra_kwargs.pop("request_id", "") or "") + # 调用子类获取站点上下文 + site_context = await self.fetch_site_context( + context, + page, + request_id=request_id, + ) + if site_context is None: + logger.warning( + "[%s] fetch_site_context 返回 None,请确认已登录", self.type_name + ) + return None + # 通过站点上下文创建会话 + conv_id = await self.create_session( + context, + page, + site_context, + request_id=request_id, + 
    async def stream_completion(
        self,
        context: BrowserContext,
        page: Page,
        session_id: str,
        message: str,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream a completion for an existing conversation, yielding text fragments.

        Looks up the per-session state recorded by ``create_conversation``,
        builds the site-specific URL/body via subclass hooks
        (``build_completion_url`` / ``build_completion_body``), uploads any
        attachments through ``prepare_attachments``, then relays parsed SSE
        text from the shared transport. After a stream that produced message
        ids, ``on_stream_completion_finished`` is invoked so subclasses can
        record follow-up context (e.g. a parent message id).

        Raises:
            RuntimeError: when ``session_id`` has no recorded state.
        """
        state = self._session_state.get(session_id)
        if not state:
            raise RuntimeError(f"未知会话 ID: {session_id}")

        request_id: str = kwargs.get("request_id", "")
        url = self.build_completion_url(session_id, state)
        # Attachments are prepared (uploaded) first so the body builder can
        # reference the resulting upstream file ids.
        attachments = list(kwargs.get("attachments") or [])
        prepared_attachments = await self.prepare_attachments(
            context,
            page,
            session_id,
            state,
            attachments,
            request_id=request_id,
        )
        body = self.build_completion_body(
            message,
            session_id,
            state,
            prepared_attachments,
        )
        body_json = json.dumps(body)

        logger.info(
            "[%s] stream_completion session_id=%s url=%s",
            self.type_name,
            session_id,
            url,
        )

        # Message ids observed on the stream are collected here by the transport.
        out_message_ids: list[str] = []
        transport_options = self.stream_transport_options(
            context,
            page,
            session_id,
            state,
            request_id=request_id,
            attachments=attachments,
            proxy_url=kwargs.get("proxy_url"),
            proxy_auth=kwargs.get("proxy_auth"),
        )

        async for text in stream_completion_via_sse(
            context,
            page,
            url,
            body_json,
            self.parse_stream_event,
            request_id,
            on_http_error=self.on_http_error,
            is_terminal_event=self.is_stream_end_event,
            collect_message_id=out_message_ids,
            transport=self.stream_transport(),
            transport_options=transport_options,
        ):
            yield text

        # Only notify if the session still exists (it may have been evicted
        # while the stream was running).
        if out_message_ids and session_id in self._session_state:
            self.on_stream_completion_finished(session_id, out_message_ids)
class PluginRegistry:
    """Global plugin registry mapping ``type_name`` -> ``AbstractPlugin`` instance."""

    # Process-wide: all registrations land in this class-level dict.
    _plugins: dict[str, AbstractPlugin] = {}

    @classmethod
    def register(cls, plugin: AbstractPlugin) -> None:
        """Register (or replace) the plugin under its ``type_name``."""
        cls._plugins[plugin.type_name] = plugin

    @classmethod
    def get(cls, type_name: str) -> AbstractPlugin | None:
        """Look up a plugin by ``type_name``; ``None`` when unregistered."""
        return cls._plugins.get(type_name)

    @classmethod
    def resolve_model(cls, type_name: str, model: str | None) -> ResolvedModel:
        """Resolve a public model name through the named plugin.

        Raises:
            ValueError: when ``type_name`` has no registered plugin.
        """
        plugin = cls.get(type_name)
        if plugin is None:
            raise ValueError(f"Unknown provider: {type_name}")
        return plugin.resolve_model(model)

    @classmethod
    def model_metadata(cls, type_name: str) -> dict[str, Any]:
        """Return provider metadata: public models, model mapping, default model.

        Raises:
            ValueError: when ``type_name`` has no registered plugin.
        """
        plugin = cls.get(type_name)
        if plugin is None:
            raise ValueError(f"Unknown provider: {type_name}")
        mapping = plugin.listed_model_mapping()
        return {
            "provider": type_name,
            "public_models": list(mapping.keys()),
            "model_mapping": mapping,
            "default_model": plugin.default_public_model(),
        }

    @classmethod
    def all_types(cls) -> list[str]:
        """All registered type names, in registration (insertion) order."""
        return list(cls._plugins.keys())
插件:仅实现站点特有的上下文获取、会话创建、请求体构建、SSE 解析和限流处理。 +其余编排逻辑(create_page / apply_auth / stream_completion 流程)全部由 BaseSitePlugin 完成。 +调试时可在 config.yaml 的 claude.start_url、claude.api_base 指向 mock。 +""" + +import datetime +import json +import logging +import re +import time +from asyncio import Lock +from typing import Any +from urllib.parse import urlparse + +from playwright.async_api import BrowserContext, Page + +from core.api.schemas import InputAttachment +from core.constants import TIMEZONE +from core.plugin.base import BaseSitePlugin, PluginRegistry, SiteConfig +from core.plugin.errors import BrowserResourceInvalidError +from core.plugin.helpers import ( + _classify_browser_resource_error, + clear_cookies_for_domain, + clear_page_storage_for_switch, + request_json_via_context_request, + safe_page_reload, + upload_file_via_context_request, +) + +logger = logging.getLogger(__name__) + +# Probe cache: skip redundant ensure_request_ready probes within this window. +_PROBE_CACHE_TTL_SECONDS = 60.0 + + +def _truncate_url_for_log(value: str, limit: int = 200) -> str: + if len(value) <= limit: + return value + return value[:limit] + "..." 
+ + +def _safe_page_url(page: Page) -> str: + try: + return page.url or "" + except Exception: + return "" + + +# --------------------------------------------------------------------------- +# 站点特有:请求体 & SSE 解析 +# --------------------------------------------------------------------------- + + +def _is_thinking_model(public_model: str) -> bool: + """Any model ending with -thinking enables extended thinking (paprika_mode).""" + return public_model.endswith("-thinking") + + +def _base_upstream_model(public_model: str) -> str: + """Strip -thinking suffix to get the upstream model ID for Claude Web.""" + return public_model.removesuffix("-thinking") + + +def _default_completion_body( + message: str, + *, + is_follow_up: bool = False, + timezone: str = TIMEZONE, + public_model: str = "", +) -> dict[str, Any]: + """构建 Claude completion 请求体。续写时不带 create_conversation_params,否则 API 返回 400。""" + body: dict[str, Any] = { + "prompt": message, + "timezone": timezone, + "personalized_styles": [ + { + "type": "default", + "key": "Default", + "name": "Normal", + "nameKey": "normal_style_name", + "prompt": "Normal\n", + "summary": "Default responses from Claude", + "summaryKey": "normal_style_summary", + "isDefault": True, + } + ], + "locale": "en-US", + "tools": [ + {"type": "web_search_v0", "name": "web_search"}, + {"type": "artifacts_v0", "name": "artifacts"}, + {"type": "repl_v0", "name": "repl"}, + {"type": "widget", "name": "weather_fetch"}, + {"type": "widget", "name": "recipe_display_v0"}, + {"type": "widget", "name": "places_map_display_v0"}, + {"type": "widget", "name": "message_compose_v1"}, + {"type": "widget", "name": "ask_user_input_v0"}, + {"type": "widget", "name": "places_search"}, + {"type": "widget", "name": "fetch_sports_data"}, + ], + "attachments": [], + "files": [], + "sync_sources": [], + "rendering_mode": "messages", + } + if _is_thinking_model(public_model): + body["model"] = _base_upstream_model(public_model) + if not is_follow_up: + 
body["create_conversation_params"] = { + "name": "", + "include_conversation_preferences": True, + "is_temporary": False, + } + if _is_thinking_model(public_model): + body["create_conversation_params"]["paprika_mode"] = "extended" + return body + + +def _parse_one_sse_event(payload: str) -> tuple[list[str], str | None, str | None]: + """解析单条 Claude SSE data 行,返回 (texts, message_id, error)。""" + result: list[str] = [] + message_id: str | None = None + error_message: str | None = None + try: + obj = json.loads(payload) + if not isinstance(obj, dict): + return (result, message_id, error_message) + kind = obj.get("type") + if kind == "error": + err = obj.get("error") or {} + error_message = err.get("message") or err.get("type") or "Unknown error" + return (result, message_id, error_message) + if "text" in obj and obj.get("text"): + result.append(str(obj["text"])) + elif kind == "content_block_delta": + delta = obj.get("delta") + if isinstance(delta, dict) and "text" in delta: + result.append(str(delta["text"])) + elif isinstance(delta, str) and delta: + result.append(delta) + elif kind == "message_start": + msg = obj.get("message") + if isinstance(msg, dict): + for key in ("uuid", "id"): + if msg.get(key): + message_id = str(msg[key]) + break + if not message_id: + mid = ( + obj.get("message_uuid") or obj.get("uuid") or obj.get("message_id") + ) + if mid: + message_id = str(mid) + elif ( + kind + and kind + not in ( + "ping", + "content_block_start", + "content_block_stop", + "message_stop", + "message_delta", + "message_limit", + ) + and not result + ): + logger.debug( + "SSE 未解析出正文 type=%s payload=%s", + kind, + payload[:200] if len(payload) > 200 else payload, + ) + except json.JSONDecodeError: + pass + return (result, message_id, error_message) + + +def _is_terminal_sse_event(payload: str) -> bool: + """Claude 正常流结束时会发送 message_stop。""" + try: + obj = json.loads(payload) + except json.JSONDecodeError: + return False + return isinstance(obj, dict) and 
obj.get("type") == "message_stop" + + +# --------------------------------------------------------------------------- +# ClaudePlugin — 只需声明配置 + 实现 5 个 hook +# --------------------------------------------------------------------------- + + +class ClaudePlugin(BaseSitePlugin): + """Claude Web2API plugin. auth must include sessionKey.""" + + type_name = "claude" + DEFAULT_MODEL_MAPPING = { + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-6-thinking": "claude-sonnet-4-6-thinking", + "claude-haiku-4-5": "claude-haiku-4-5", + "claude-haiku-4-5-thinking": "claude-haiku-4-5-thinking", + "claude-opus-4-6": "claude-opus-4-6", + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + } + # Models that require a Claude Pro subscription. + PRO_MODELS = frozenset({ + "claude-haiku-4-5", + "claude-haiku-4-5-thinking", + "claude-opus-4-6", + "claude-opus-4-6-thinking", + }) + MODEL_ALIASES = { + "s4": "claude-sonnet-4-6", + # dot-notation aliases (e.g. 
    def model_mapping(self) -> dict[str, str] | None:
        """Full public->canonical model map used for resolution.

        Precedence: built-in ``DEFAULT_MODEL_MAPPING`` entries, overridden by
        any config-provided mapping, then ``MODEL_ALIASES`` filled in only
        where the key is not already present (aliases never shadow real
        entries or configured overrides).
        """
        configured = super().model_mapping() or {}
        mapping = dict(self.DEFAULT_MODEL_MAPPING)
        mapping.update(configured)
        for alias, upstream_model in self.MODEL_ALIASES.items():
            mapping.setdefault(alias, upstream_model)
        return mapping
    def _suspicious_page_reason(self, url: str) -> str | None:
        """Classify why a page URL looks unusable for issuing API requests.

        Returns a short reason token (``"empty_page_url"``,
        ``"invalid_page_url"``, ``"non_claude_domain"``, ``"new_chat_page"``,
        ``"logout_page"``, ``"signup_page"``,
        ``"app_unavailable_in_region"``) or ``None`` when the URL looks fine.
        Consumed by ``ensure_request_ready`` to pick fast path vs reload/goto.
        """
        if not url:
            return "empty_page_url"
        parsed = urlparse(url)
        if not parsed.scheme or not parsed.netloc:
            return "invalid_page_url"
        if not self._is_claude_domain(url):
            return "non_claude_domain"
        path = parsed.path or "/"
        # Paths that indicate a logged-out or transient state rather than a chat.
        if path == "/new" or path.startswith("/new/"):
            return "new_chat_page"
        if path in {"/logout", "/auth", "/signed-out"}:
            return "logout_page"
        if path.startswith("/signup"):
            return "signup_page"
        # Region block is treated as terminal by callers (they raise, not retry).
        if path == "/app-unavailable-in-region" or path.startswith(
            "/app-unavailable-in-region/"
        ):
            return "app_unavailable_in_region"
        return None
    async def ensure_request_ready(
        self,
        context: BrowserContext,
        page: Page,
        *,
        request_id: str = "",
        session_id: str | None = None,
        phase: str = "",
        account_id: str = "",
    ) -> None:
        """Preflight: ensure *page* can serve Claude API requests.

        Fast path returns immediately when the page URL is clean and a probe
        succeeded within ``_PROBE_CACHE_TTL_SECONDS``. Otherwise, under a
        per-page navigation lock, escalates:
        goto start_url (if suspicious) -> probe -> reload -> probe ->
        goto start_url -> probe, recording a probe-cache timestamp on success.

        Raises:
            RuntimeError: region-blocked page, or probe still failing after
                the final page correction.
            BrowserResourceInvalidError: classified page/browser failures from
                navigation, or a still-suspicious URL after the final goto.
        """
        initial_url = _safe_page_url(page)
        current_url = initial_url
        probe_request_id = request_id or f"ready:{phase or 'request'}"
        # Telemetry for the summary log emitted in `finally`.
        action = "none"
        probe_before = False
        probe_after = False
        probe_reason: str | None = None
        page_id = id(page)

        # Fast path (lock-free): page URL is clean and probe succeeded recently.
        suspicious_reason = self._suspicious_page_reason(current_url)
        if suspicious_reason is None:
            last_ok = self._probe_ok_at.get(page_id, 0.0)
            if (time.time() - last_ok) < _PROBE_CACHE_TTL_SECONDS:
                return
        if suspicious_reason == "app_unavailable_in_region":
            raise RuntimeError(
                "Claude page is app-unavailable-in-region; the runtime IP or region cannot reach Claude Web"
            )

        # Slow path: acquire per-page nav lock to prevent concurrent navigation.
        nav_lock = self._nav_locks.setdefault(page_id, Lock())
        async with nav_lock:
            # Re-check after acquiring lock — another request may have fixed the page.
            current_url = _safe_page_url(page)
            suspicious_reason = self._suspicious_page_reason(current_url)
            if suspicious_reason is None:
                last_ok = self._probe_ok_at.get(page_id, 0.0)
                if (time.time() - last_ok) < _PROBE_CACHE_TTL_SECONDS:
                    return
            if suspicious_reason == "app_unavailable_in_region":
                raise RuntimeError(
                    "Claude page is app-unavailable-in-region; the runtime IP or region cannot reach Claude Web"
                )

            try:
                # Step 1: a suspicious URL gets corrected by navigating to the
                # start page before probing at all.
                if suspicious_reason is not None:
                    action = "goto"
                    try:
                        await safe_page_reload(page, url=self.start_url)
                    except Exception as e:
                        classified = _classify_browser_resource_error(
                            e,
                            helper_name="claude.ensure_request_ready",
                            operation="preflight",
                            stage="goto_start_url",
                            request_url=self.start_url,
                            page=page,
                            request_id=request_id or None,
                            stream_phase=phase or None,
                        )
                        if classified is not None:
                            raise classified from e
                        raise
                    current_url = _safe_page_url(page)
                    suspicious_reason = self._suspicious_page_reason(current_url)
                    if suspicious_reason == "app_unavailable_in_region":
                        probe_reason = suspicious_reason
                        raise RuntimeError(
                            "Claude page is app-unavailable-in-region after goto; the runtime IP or region cannot reach Claude Web"
                        )

                # Step 2: initial probe (only when the URL looks clean).
                probe_before = self._suspicious_page_reason(current_url) is None
                if probe_before:
                    probe_after, probe_reason = await self._probe_request_ready(
                        context,
                        page,
                        request_id=f"{probe_request_id}:initial",
                    )
                    if probe_after:
                        self._probe_ok_at[page_id] = time.time()
                        return
                    if probe_reason == "app_unavailable_in_region":
                        raise RuntimeError(
                            "Claude page is app-unavailable-in-region during control probe; the runtime IP or region cannot reach Claude Web"
                        )
                else:
                    probe_after = False
                    probe_reason = suspicious_reason or "suspicious_page_url"

                # Step 3: reload in place and probe again.
                action = "reload"
                try:
                    await safe_page_reload(page)
                except Exception as e:
                    classified = _classify_browser_resource_error(
                        e,
                        helper_name="claude.ensure_request_ready",
                        operation="preflight",
                        stage="reload",
                        request_url=current_url or self.start_url,
                        page=page,
                        request_id=request_id or None,
                        stream_phase=phase or None,
                    )
                    if classified is not None:
                        raise classified from e
                    raise
                current_url = _safe_page_url(page)
                if self._suspicious_page_reason(current_url) == "app_unavailable_in_region":
                    probe_reason = "app_unavailable_in_region"
                    raise RuntimeError(
                        "Claude page is app-unavailable-in-region after reload; the runtime IP or region cannot reach Claude Web"
                    )
                probe_after, probe_reason = await self._probe_request_ready(
                    context,
                    page,
                    request_id=f"{probe_request_id}:reload",
                )
                if probe_after:
                    self._probe_ok_at[page_id] = time.time()
                    return
                if probe_reason == "app_unavailable_in_region":
                    raise RuntimeError(
                        "Claude page is app-unavailable-in-region after reload probe; the runtime IP or region cannot reach Claude Web"
                    )

                # Step 4: last resort — navigate to start_url and probe once more.
                action = "goto"
                try:
                    await safe_page_reload(page, url=self.start_url)
                except Exception as e:
                    classified = _classify_browser_resource_error(
                        e,
                        helper_name="claude.ensure_request_ready",
                        operation="preflight",
                        stage="goto_start_url",
                        request_url=self.start_url,
                        page=page,
                        request_id=request_id or None,
                        stream_phase=phase or None,
                    )
                    if classified is not None:
                        raise classified from e
                    raise
                current_url = _safe_page_url(page)
                if self._suspicious_page_reason(current_url) == "app_unavailable_in_region":
                    probe_reason = "app_unavailable_in_region"
                    raise RuntimeError(
                        "Claude page is app-unavailable-in-region after page correction; the runtime IP or region cannot reach Claude Web"
                    )
                probe_after, probe_reason = await self._probe_request_ready(
                    context,
                    page,
                    request_id=f"{probe_request_id}:goto",
                )
                if not probe_after:
                    if probe_reason == "suspicious_page_url":
                        raise BrowserResourceInvalidError(
                            "Claude request preflight failed after page correction: suspicious_page_url",
                            helper_name="claude.ensure_request_ready",
                            operation="preflight",
                            stage="probe_after_goto",
                            resource_hint="page",
                            request_url=self.start_url,
                            page_url=current_url,
                            request_id=request_id or None,
                            stream_phase=phase or None,
                        )
                    raise RuntimeError(
                        f"Claude request control probe failed after page correction: {probe_reason or 'unknown'}"
                    )
                self._probe_ok_at[page_id] = time.time()
            finally:
                logger.info(
                    "[%s] ensure_request_ready phase=%s account=%s session_id=%s action=%s probe_before=%s probe_after=%s probe_reason=%s page.url.before=%s page.url.after=%s",
                    self.type_name,
                    phase,
                    account_id,
                    session_id,
                    action,
                    probe_before,
                    probe_after,
                    probe_reason,
                    _truncate_url_for_log(initial_url),
                    _truncate_url_for_log(current_url),
                )
    async def fetch_site_context(
        self,
        context: BrowserContext,
        page: Page,
        request_id: str = "",
    ) -> dict[str, Any] | None:
        """Resolve the Claude organization context ``{"org_uuid": ...}``.

        Calls ``{api_base}/account`` through the browser context and caches a
        successful result per page for ``_SITE_CONTEXT_TTL`` seconds. Returns
        ``None`` when the endpoint is not HTTP 200, returns non-JSON, or has
        no memberships (i.e. the session is not logged in).
        """
        page_id = id(page)
        # NOTE(review): the cache key is id(page); if a Page is garbage-collected
        # and a new one reuses the same id within the TTL, a stale org context
        # could be returned — consider a stable page key. TODO confirm lifecycle.
        cached = self._site_context_cache.get(page_id)
        if cached is not None:
            ctx, ts = cached
            if (time.time() - ts) < self._SITE_CONTEXT_TTL:
                return ctx
        resp = await request_json_via_context_request(
            context,
            page,
            f"{self.api_base}/account",
            timeout_ms=15000,
            request_id=request_id or "site-context",
        )
        if int(resp.get("status") or 0) != 200:
            text = str(resp.get("text") or "")[:500]
            logger.warning(
                "[%s] fetch_site_context 失败 status=%s url=%s body=%s",
                self.type_name,
                resp.get("status"),
                resp.get("url"),
                text,
            )
            return None
        data = resp.get("json")
        if not isinstance(data, dict):
            logger.warning("[%s] fetch_site_context 返回非 JSON", self.type_name)
            return None
        memberships = data.get("memberships") or []
        if not memberships:
            return None
        # First membership wins; only the organization uuid is retained.
        org = memberships[0].get("organization") or {}
        org_uuid = org.get("uuid")
        if org_uuid:
            result = {"org_uuid": org_uuid}
            self._site_context_cache[page_id] = (result, time.time())
            return result
        return None
dict[str, Any], + **kwargs: Any, + ) -> str | None: + org_uuid = site_context["org_uuid"] + public_model = str(kwargs.get("public_model") or "").strip() + upstream_model = str(kwargs.get("upstream_model") or "").strip() + if not upstream_model: + upstream_model = self.resolve_model(None).upstream_model + payload: dict[str, Any] = { + "name": "", + "model": ( + _base_upstream_model(public_model) + if _is_thinking_model(public_model) + else upstream_model + ), + } + if _is_thinking_model(public_model): + payload["paprika_mode"] = "extended" + url = f"{self.api_base}/organizations/{org_uuid}/chat_conversations" + request_id = str(kwargs.get("request_id") or "").strip() + resp = await request_json_via_context_request( + context, + page, + url, + method="POST", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout_ms=15000, + request_id=request_id or f"create-session:{org_uuid}", + ) + status = int(resp.get("status") or 0) + if status not in (200, 201): + text = str(resp.get("text") or "")[:500] + logger.warning("创建会话失败 %s: %s", status, text) + return None + data = resp.get("json") + if not isinstance(data, dict): + logger.warning("创建会话返回非 JSON") + return None + return data.get("uuid") + + def build_completion_url(self, session_id: str, state: dict[str, Any]) -> str: + org_uuid = state["site_context"]["org_uuid"] + return f"{self.api_base}/organizations/{org_uuid}/chat_conversations/{session_id}/completion" + + # 构建请求体 + def build_completion_body( + self, + message: str, + session_id: str, + state: dict[str, Any], + prepared_attachments: dict[str, Any] | None = None, + ) -> dict[str, Any]: + parent = state.get("parent_message_uuid") + tz = state.get("timezone") or TIMEZONE + public_model = str(state.get("public_model") or "").strip() + body = _default_completion_body( + message, + is_follow_up=parent is not None, + timezone=tz, + public_model=public_model, + ) + if parent: + body["parent_message_uuid"] = parent + if prepared_attachments: 
+ body.update(prepared_attachments) + return body + + def parse_stream_event( + self, + payload: str, + ) -> tuple[list[str], str | None, str | None]: + return _parse_one_sse_event(payload) + + def is_stream_end_event(self, payload: str) -> bool: + return _is_terminal_sse_event(payload) + + # 处理错误 + def stream_transport(self) -> str: + return "context_request" + + def on_http_error( + self, + message: str, + headers: dict[str, str] | None, + ) -> int | None: + if "429" not in message: + return None + if headers: + reset = headers.get("anthropic-ratelimit-requests-reset") or headers.get( + "Anthropic-Ratelimit-Requests-Reset" + ) + if reset: + try: + s = str(reset).strip() + if s.endswith("Z"): + s = s[:-1] + "+00:00" + dt = datetime.datetime.fromisoformat(s) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=datetime.timezone.utc) + return int(dt.timestamp()) + except Exception: + pass + return int(time.time()) + 5 * 3600 + + _UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" + ) + + def on_stream_completion_finished( + self, + session_id: str, + message_ids: list[str], + ) -> None: + """Claude 多轮续写需要 parent_message_uuid,取本轮最后一条消息 UUID 写入 state。""" + last_uuid = next( + (m for m in reversed(message_ids) if self._UUID_RE.match(m)), None + ) + if last_uuid and session_id in self._session_state: + self._session_state[session_id]["parent_message_uuid"] = last_uuid + logger.info( + "[%s] updated parent_message_uuid=%s", self.type_name, last_uuid + ) + + async def prepare_attachments( + self, + context: BrowserContext, + page: Page, + session_id: str, + state: dict[str, Any], + attachments: list[InputAttachment], + request_id: str = "", + ) -> dict[str, Any]: + if not attachments: + return {} + if len(attachments) > 5: + raise RuntimeError("Claude 单次最多上传 5 张图片") + + org_uuid = state["site_context"]["org_uuid"] + url = ( + f"{self.api_base}/organizations/{org_uuid}/conversations/" + f"{session_id}/wiggle/upload-file" + 
class AccountFrozenError(RuntimeError):
    """
    Raised by a plugin when it detects the account is rate-limited or out of
    quota; carries the unfreeze time as a unix timestamp in seconds.
    chat_handler catches this, persists the freeze to config, and retries
    with other accounts.
    """

    def __init__(self, message: str, unfreeze_at: int) -> None:
        super().__init__(message)
        # Unix seconds after which the account may be used again.
        self.unfreeze_at = unfreeze_at
self.resource_hint = resource_hint + self.request_url = request_url + self.page_url = page_url + self.request_id = request_id + self.stream_phase = stream_phase + self.proxy_key = proxy_key + self.type_name = type_name + self.account_id = account_id diff --git a/core/plugin/helpers.py b/core/plugin/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..58dca252bde42f661e5ab2a228e513cb488411d8 --- /dev/null +++ b/core/plugin/helpers.py @@ -0,0 +1,1246 @@ +""" +插件通用能力:页面复用、Cookie 登录、在浏览器内发起 fetch 并流式回传。 +接入方只需实现站点特有的 URL/请求体/SSE 解析,其余复用此处逻辑。 +""" + +import asyncio +import base64 +import codecs +import json +import logging +from collections.abc import Callable +from typing import Any, AsyncIterator +from urllib.parse import urlparse + +from curl_cffi import requests as curl_requests +from playwright.async_api import BrowserContext, Page + +from core.plugin.errors import AccountFrozenError, BrowserResourceInvalidError + +ParseSseEvent = Callable[[str], tuple[list[str], str | None, str | None]] + +logger = logging.getLogger(__name__) + +_BROWSER_RESOURCE_ERROR_PATTERNS: tuple[tuple[str, str, str], ...] = ( + ("target crashed", "page", "target_crashed"), + ("page crashed", "browser", "page_crashed"), + ("execution context was destroyed", "page", "execution_context_destroyed"), + ("navigating frame was detached", "page", "frame_detached"), + ("frame was detached", "page", "frame_detached"), + ("session closed. 
most likely the page has been closed", "page", "page_closed"), + ("most likely the page has been closed", "page", "page_closed"), + ("browser context has been closed", "page", "context_closed"), + ("context has been closed", "page", "context_closed"), + ("target page, context or browser has been closed", "page", "page_or_browser_closed"), + ("page has been closed", "page", "page_closed"), + ("target closed", "page", "target_closed"), + ("browser has been closed", "browser", "browser_closed"), + ("browser closed", "browser", "browser_closed"), + ("connection closed", "browser", "browser_disconnected"), + ("connection terminated", "browser", "browser_disconnected"), + ("has been disconnected", "browser", "browser_disconnected"), + # Proxy / network tunnel errors — retryable via browser re-launch + ("err_tunnel_connection_failed", "browser", "proxy_tunnel_failed"), + ("err_proxy_connection_failed", "browser", "proxy_connection_failed"), + ("err_connection_refused", "browser", "connection_refused"), + ("err_connection_timed_out", "browser", "connection_timed_out"), + ("err_connection_reset", "browser", "connection_reset"), +) + + +def _truncate_for_log(value: str, limit: int = 240) -> str: + if len(value) <= limit: + return value + return value[:limit] + "..." 
# In-page POST with streamed response: on success the response body is
# forwarded chunk by chunk through the per-request binding; failures send an
# "__error__:" prefix + message, and the stream always terminates with
# "__done__". bindingName is unique per request, so concurrent requests on the
# same page never cross streams.
PAGE_FETCH_STREAM_JS = """
async ({ url, body, bindingName, timeoutMs }) => {
  const send = globalThis[bindingName];
  const done = "__done__";
  const errPrefix = "__error__:";
  // BUGFIX: declared OUTSIDE the try block. The catch handler formats the
  // timeout message from effectiveTimeoutMs, and a `const` declared inside
  // `try` is block-scoped — referencing it in `catch` threw ReferenceError
  // exactly when an AbortError needed reporting. The timer handle is hoisted
  // too so the catch path can clear a still-pending abort timer.
  const effectiveTimeoutMs = timeoutMs || 90000;
  let t = 0;
  try {
    const ctrl = new AbortController();
    t = setTimeout(() => ctrl.abort(), effectiveTimeoutMs);
    const resp = await fetch(url, {
      method: "POST",
      body: body,
      headers: { "Content-Type": "application/json", "Accept": "text/event-stream" },
      credentials: "include",
      signal: ctrl.signal
    });
    clearTimeout(t);
    if (!resp.ok) {
      const errText = await resp.text();
      const errSnippet = (errText && errText.length > 800) ? errText.slice(0, 800) + "..." : (errText || "");
      await send(errPrefix + "HTTP " + resp.status + " " + errSnippet);
      await send(done);
      return;
    }
    if (!resp.body) {
      await send(errPrefix + "No response body");
      await send(done);
      return;
    }
    const headersObj = {};
    resp.headers.forEach((v, k) => { headersObj[k] = v; });
    await send("__headers__:" + JSON.stringify(headersObj));
    const reader = resp.body.getReader();
    const dec = new TextDecoder();
    while (true) {
      const { done: streamDone, value } = await reader.read();
      if (streamDone) break;
      await send(dec.decode(value));
    }
  } catch (e) {
    clearTimeout(t);
    const msg = e.name === "AbortError" ? `请求超时(${Math.floor(effectiveTimeoutMs / 1000)}s)` : (e.message || String(e));
    await send(errPrefix + msg);
  }
  await send(done);
}
"""
`请求超时(${Math.floor((timeoutMs || 15000) / 1000)}s)` : (e.message || String(e)); + return { error: msg }; + } +} +""" + + +PAGE_FETCH_MULTIPART_JS = """ +async ({ url, filename, mimeType, dataBase64, fieldName, extraFields, timeoutMs }) => { + const ctrl = new AbortController(); + const t = setTimeout(() => ctrl.abort(), timeoutMs || 30000); + try { + const binary = atob(dataBase64); + const bytes = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i += 1) { + bytes[i] = binary.charCodeAt(i); + } + const form = new FormData(); + if (extraFields) { + Object.entries(extraFields).forEach(([k, v]) => { + if (v !== undefined && v !== null) form.append(k, String(v)); + }); + } + const file = new File([bytes], filename, { type: mimeType || "application/octet-stream" }); + form.append(fieldName || "file", file); + const resp = await fetch(url, { + method: "POST", + body: form, + credentials: "include", + signal: ctrl.signal + }); + clearTimeout(t); + const text = await resp.text(); + const headersObj = {}; + resp.headers.forEach((v, k) => { headersObj[k] = v; }); + return { + ok: resp.ok, + status: resp.status, + statusText: resp.statusText, + url: resp.url, + redirected: resp.redirected, + headers: headersObj, + text, + }; + } catch (e) { + clearTimeout(t); + const msg = e.name === "AbortError" ? 
`请求超时(${Math.floor((timeoutMs || 30000) / 1000)}s)` : (e.message || String(e)); + return { error: msg }; + } +} +""" + + +async def ensure_page_for_site( + context: BrowserContext, + url_contains: str, + start_url: str, + *, + timeout: int = 45000, +) -> Page: + """ + 若已有页面 URL 包含 url_contains 则复用,否则 new_page 并 goto start_url。 + 接入方只需提供「站点特征」和「入口 URL」。 + """ + if context.pages: + for p in context.pages: + if url_contains in (p.url or ""): + return p + page = await context.new_page() + await page.goto(start_url, wait_until="commit", timeout=timeout) + return page + + +async def create_page_for_site( + context: BrowserContext, + start_url: str, + *, + reuse_page: Page | None = None, + timeout: int = 45000, +) -> Page: + """ + 若传入 reuse_page 则在其上 goto start_url,否则 new_page 再 goto。 + 用于复用浏览器默认空白页或 page 池的初始化与补回。 + """ + if reuse_page is not None: + await reuse_page.goto(start_url, wait_until="commit", timeout=timeout) + return reuse_page + page = await context.new_page() + await page.goto(start_url, wait_until="commit", timeout=timeout) + return page + + +def _cookie_domain_matches(cookie_domain: str, site_domain: str) -> bool: + """判断 cookie 的 domain 是否属于站点 domain(如 .claude.ai 与 claude.ai 视为同一域)。""" + a = cookie_domain if cookie_domain.startswith(".") else f".{cookie_domain}" + b = site_domain if site_domain.startswith(".") else f".{site_domain}" + return a == b + + +def _cookie_to_set_param(c: Any) -> dict[str, str]: + """将 context.cookies() 返回的项转为 add_cookies 接受的 SetCookieParam 格式。""" + return { + "name": c["name"], + "value": c["value"], + "domain": c.get("domain") or "", + "path": c.get("path") or "/", + } + + +async def clear_cookies_for_domain( + context: BrowserContext, + site_domain: str, +) -> None: + """清除 context 内属于指定站点域的所有 cookie,保留其他域。""" + cookies = await context.cookies() + keep = [ + c + for c in cookies + if not _cookie_domain_matches(c.get("domain", ""), site_domain) + ] + await context.clear_cookies() + if keep: + await 
context.add_cookies([_cookie_to_set_param(c) for c in keep]) # type: ignore[arg-type] + logger.info( + "[auth] cleared cookies for domain=%s (kept %s cookies)", site_domain, len(keep) + ) + + +async def clear_page_storage_for_switch(page: Page) -> None: + """切号前清空当前页面的 localStorage(当前 origin)。""" + try: + await page.evaluate("() => { window.localStorage.clear(); }") + logger.info("[auth] cleared localStorage for switch") + except Exception as e: + logger.warning("[auth] clear localStorage failed (page may be detached): %s", e) + + +async def safe_page_reload(page: Page, url: str | None = None) -> None: + """安全地 reload 或 goto(url),忽略因 ERR_ABORTED / frame detached 导致的异常。""" + try: + if url: + await page.goto(url, wait_until="commit", timeout=45000) + else: + await page.reload(wait_until="domcontentloaded", timeout=45000) + except Exception as e: + err_msg = str(e) + if "ERR_ABORTED" in err_msg or "detached" in err_msg.lower(): + logger.warning( + "[auth] page.reload/goto 被中止或 frame 已分离: %s", err_msg[:200] + ) + else: + raise + + +async def apply_cookie_auth( + context: BrowserContext, + page: Page, + auth: dict[str, Any], + cookie_name: str, + auth_keys: list[str], + domain: str, + *, + path: str = "/", + reload: bool = True, +) -> None: + """ + 从 auth 中按 auth_keys 顺序取第一个非空值作为 cookie 值,写入 context 并可选 reload。 + 接入方只需提供 cookie 名、auth 里的 key 列表、域名。 + 仅写 cookie 不 reload 时,同 context 内的 fetch() 仍会带上 cookie;reload 仅在需要页面文档同步登录态时用。 + """ + value = None + for k in auth_keys: + v = auth.get(k) + if v is not None and v != "": + value = str(v).strip() + if value: + break + if not value: + raise ValueError(f"auth 需包含以下其一且非空: {auth_keys}") + + logger.info( + "[auth] context.add_cookies domain=%s name=%s reload=%s page.url=%s", + domain, + cookie_name, + reload, + page.url, + ) + await context.add_cookies( + [ + { + "name": cookie_name, + "value": value, + "domain": domain, + "path": path, + "secure": True, + "httpOnly": True, + } + ] + ) + if reload: + await safe_page_reload(page) 
def _attach_json_body(result: dict[str, Any], *, invalid_message: str) -> dict[str, Any]:
    """Attach ``result["json"]`` (parsed from ``result["text"]``, or None) in place.

    Raises RuntimeError with *invalid_message* when *result* is not a dict, or
    with the upstream error text when ``result["error"]`` is set.
    """
    if not isinstance(result, dict):
        raise RuntimeError(invalid_message)
    error = result.get("error")
    if error:
        raise RuntimeError(str(error))
    text = result.get("text")
    if isinstance(text, str) and text:
        try:
            result["json"] = json.loads(text)
        except json.JSONDecodeError:
            result["json"] = None
    else:
        result["json"] = None
    return result


def _cookie_domain_matches_url(cookie_domain: str, target_url: str) -> bool:
    """Whether *cookie_domain* covers the host of *target_url* (domain-suffix match)."""
    host = (urlparse(target_url).hostname or "").lower().lstrip(".")
    domain = (cookie_domain or "").lower().lstrip(".")
    if not host or not domain:
        return False
    return host == domain or host.endswith(f".{domain}")


def _cookies_for_url(cookies: list[dict[str, Any]], target_url: str) -> dict[str, str]:
    """Select name->value for cookies whose domain matches *target_url*'s host."""
    target_host = (urlparse(target_url).hostname or "").lower().lstrip(".")
    if not target_host:
        return {}
    selected: dict[str, str] = {}
    for cookie in cookies:
        name = str(cookie.get("name") or "").strip()
        value = str(cookie.get("value") or "")
        domain = str(cookie.get("domain") or "").strip()
        if not name or not _cookie_domain_matches_url(domain, target_url):
            continue
        selected[name] = value
    return selected


async def _stream_via_http_client(
    context: BrowserContext,
    page: Page | None,
    url: str,
    body: str,
    request_id: str,
    *,
    on_http_error: Callable[[str, dict[str, str] | None], int | None] | None = None,
    on_headers: Callable[[dict[str, str]], None] | None = None,
    connect_timeout: float = 30.0,
    read_timeout: float = 300.0,
    impersonate: str = "chrome142",
    proxy_url: str | None = None,
    proxy_auth: tuple[str, str] | None = None,
) -> AsyncIterator[str]:
    """POST *body* to *url* with a real streaming HTTP client (curl_cffi),
    borrowing the browser context's cookies, and yield decoded text chunks.

    Non-2xx responses raise AccountFrozenError (when *on_http_error* returns an
    unfreeze timestamp) or RuntimeError; transport failures are classified or
    wrapped into BrowserResourceInvalidError.
    """
    logger.info(
        "[fetch] helper=stream_raw_via_context_request request_id=%s stage=http_client url=%s page.url=%s",
        request_id,
        _truncate_for_log(url, 120),
        _truncate_for_log(_safe_page_url(page), 120),
    )

    parsed = urlparse(url)
    referer = ""
    if parsed.scheme and parsed.netloc:
        referer = f"{parsed.scheme}://{parsed.netloc}/"

    try:
        cookies = await context.cookies([url])
    except Exception as e:
        classified = _classify_browser_resource_error(
            e,
            helper_name="stream_raw_via_context_request",
            operation="context.cookies",
            stage="load_cookies",
            request_url=url,
            page=page,
            request_id=request_id,
            stream_phase="fetch",
        )
        if classified is not None:
            raise classified from e
        raise BrowserResourceInvalidError(
            str(e),
            helper_name="stream_raw_via_context_request",
            operation="context.cookies",
            stage="load_cookies",
            resource_hint="page",
            request_url=url,
            page_url=_safe_page_url(page),
            request_id=request_id,
            stream_phase="fetch",
        ) from e
    cookie_jar = _cookies_for_url(cookies, url)
    session_kwargs: dict[str, Any] = {
        "impersonate": impersonate,
        "timeout": (connect_timeout, read_timeout),
        "verify": True,
        "allow_redirects": True,
        "default_headers": True,
    }
    if cookie_jar:
        session_kwargs["cookies"] = cookie_jar
    if proxy_url:
        session_kwargs["proxy"] = proxy_url
    if proxy_auth:
        session_kwargs["proxy_auth"] = proxy_auth

    response = None
    try:
        async with curl_requests.AsyncSession(**session_kwargs) as session:
            try:
                request_headers = {
                    "Content-Type": "application/json",
                    "Accept": "text/event-stream",
                }
                if referer:
                    request_headers["Origin"] = referer.rstrip("/")
                async with session.stream(
                    "POST",
                    url,
                    data=body.encode("utf-8"),
                    headers=request_headers,
                ) as response:
                    headers = {
                        str(k).lower(): str(v) for k, v in response.headers.items()
                    }
                    if on_headers:
                        on_headers(headers)

                    status = int(response.status_code)
                    if status < 200 or status >= 300:
                        # Read at most ~800 chars of the error body for the message.
                        body_parts: list[str] = []
                        decoder = codecs.getincrementaldecoder("utf-8")("replace")
                        async for chunk in response.aiter_content():
                            if not chunk:
                                continue
                            body_parts.append(decoder.decode(chunk))
                            if sum(len(part) for part in body_parts) >= 800:
                                break
                        body_parts.append(decoder.decode(b"", final=True))
                        snippet = "".join(body_parts)
                        if len(snippet) > 800:
                            snippet = snippet[:800] + "..."
                        msg = f"HTTP {status} {snippet}".strip()
                        if on_http_error:
                            unfreeze_at = on_http_error(msg, headers)
                            if isinstance(unfreeze_at, int):
                                logger.warning("[fetch] HTTP error from context request: %s", msg)
                                raise AccountFrozenError(msg, unfreeze_at)
                        raise RuntimeError(msg)

                    # Incremental decode so multi-byte UTF-8 split across chunks stays intact.
                    decoder = codecs.getincrementaldecoder("utf-8")("replace")
                    async for chunk in response.aiter_content():
                        if not chunk:
                            continue
                        text = decoder.decode(chunk)
                        if text:
                            yield text
                    tail = decoder.decode(b"", final=True)
                    if tail:
                        yield tail
            except (AccountFrozenError, BrowserResourceInvalidError):
                # BUGFIX: the HTTP-status branch above raises AccountFrozenError /
                # RuntimeError inside this try; previously the blanket handler
                # below caught and re-wrapped them as BrowserResourceInvalidError,
                # so the outer `except AccountFrozenError: raise` never fired and
                # account-freeze handling was silently lost. Let them propagate.
                raise
            except Exception as e:
                classified = _classify_browser_resource_error(
                    e,
                    helper_name="stream_raw_via_context_request",
                    operation="http_client",
                    stage="stream",
                    request_url=url,
                    page=page,
                    request_id=request_id,
                    stream_phase="body",
                )
                if classified is not None:
                    raise classified from e
                raise BrowserResourceInvalidError(
                    str(e),
                    helper_name="stream_raw_via_context_request",
                    operation="http_client",
                    stage="stream",
                    resource_hint="transport",
                    request_url=url,
                    page_url=_safe_page_url(page),
                    request_id=request_id,
                    stream_phase="body",
                ) from e
    except AccountFrozenError:
        raise
    except BrowserResourceInvalidError:
        raise
    except Exception as e:
        classified = _classify_browser_resource_error(
            e,
            helper_name="stream_raw_via_context_request",
            operation="http_client",
            stage="request",
            request_url=url,
            page=page,
            request_id=request_id,
            stream_phase="fetch",
        )
        if classified is not None:
            raise classified from e
        logger.warning(
            "[fetch] helper=stream_raw_via_context_request request_id=%s http_client failed url=%s page.url=%s err=%s",
            request_id,
            _truncate_for_log(url, 120),
            _truncate_for_log(_safe_page_url(page), 120),
            _truncate_for_log(str(e), 400),
        )
        raise BrowserResourceInvalidError(
            str(e),
            helper_name="stream_raw_via_context_request",
            operation="http_client",
            stage="request",
            resource_hint="transport",
            request_url=url,
            page_url=_safe_page_url(page),
            request_id=request_id,
            stream_phase="fetch",
        ) from e


async def _request_via_context_request(
    context: BrowserContext,
    page: Page | None,
    url: str,
    *,
    method: str = "GET",
    body: str | None = None,
    headers: dict[str, str] | None = None,
    multipart: dict[str, Any] | None = None,
    timeout_ms: int = 15000,
    request_id: str | None = None,
    helper_name: str,
) -> dict[str, Any]:
    """Issue a request through Playwright's context.request and normalize the
    result to a plain dict (ok/status/statusText/url/redirected/headers/text)."""
    logger.info(
        "[fetch] helper=%s method=%s request_id=%s url=%s page.url=%s",
        helper_name,
        method,
        request_id,
        _truncate_for_log(url, 120),
        _truncate_for_log(_safe_page_url(page), 120),
    )
    response = None
    try:
        response = await context.request.fetch(
            url,
            method=method,
            headers=headers or None,
            data=body,
            multipart=multipart,
            timeout=timeout_ms,
            fail_on_status_code=False,
        )
        text = await response.text()
        return {
            "ok": bool(response.ok),
            "status": int(response.status),
            "statusText": str(response.status_text),
            "url": str(response.url),
            "redirected": str(response.url) != url,
            "headers": {str(k): str(v) for k, v in response.headers.items()},
            "text": text,
        }
    except Exception as e:
        classified = _classify_browser_resource_error(
            e,
            helper_name=helper_name,
            operation="context.request",
            stage="fetch",
            request_url=url,
            page=page,
            request_id=request_id,
        )
        if classified is not None:
            raise classified from e
        logger.warning(
            "[fetch] helper=%s request_id=%s context.request failed url=%s page.url=%s err=%s",
            helper_name,
            request_id,
            _truncate_for_log(url, 120),
            _truncate_for_log(_safe_page_url(page), 120),
            _truncate_for_log(str(e), 400),
        )
        raise RuntimeError(str(e)) from e
    finally:
        if response is not None:
            try:
                await response.dispose()
            except Exception:
                pass


async def request_json_via_context_request(
    context: BrowserContext,
    page: Page | None,
    url: str,
    *,
    method: str = "GET",
    body: str | None = None,
    headers: dict[str, str] | None = None,
    timeout_ms: int = 15000,
    request_id: str | None = None,
) -> dict[str, Any]:
    """Non-streaming request via context.request with the result JSON-parsed."""
    result = await _request_via_context_request(
        context,
        page,
        url,
        method=method,
        body=body,
        headers=headers,
        timeout_ms=timeout_ms,
        request_id=request_id,
        helper_name="request_json_via_context_request",
    )
    return _attach_json_body(result, invalid_message="控制请求返回结果异常")


async def request_json_via_page_fetch(
    page: Page,
    url: str,
    *,
    method: str = "GET",
    body: str | None = None,
    headers: dict[str, str] | None = None,
    timeout_ms: int = 15000,
    request_id: str | None = None,
) -> dict[str, Any]:
    """
    Issue a non-streaming fetch inside the page and return the result with JSON
    parsed first. This reuses the browser's real network stack, cookies, and any
    proxy extensions.
    """
    logger.info(
        "[fetch] helper=request_json_via_page_fetch method=%s request_id=%s url=%s page.url=%s",
        method,
        request_id,
        _truncate_for_log(url, 120),
        _truncate_for_log(_safe_page_url(page), 120),
    )
    try:
        result = await asyncio.wait_for(
            page.evaluate(
                PAGE_FETCH_JSON_JS,
                {
                    "url": url,
                    "method": method,
                    "body": body,
                    "headers": headers or {},
                    "timeoutMs": timeout_ms,
                },
            ),
            timeout=_evaluate_timeout_seconds(timeout_ms),
        )
    except asyncio.TimeoutError as e:
        logger.warning(
            "[fetch] helper=request_json_via_page_fetch request_id=%s evaluate timeout url=%s page.url=%s",
            request_id,
            _truncate_for_log(url, 120),
            _truncate_for_log(_safe_page_url(page), 120),
        )
        raise BrowserResourceInvalidError(
            f"page.evaluate timeout after {_evaluate_timeout_seconds(timeout_ms):.1f}s",
            helper_name="request_json_via_page_fetch",
            operation="page.evaluate",
            stage="evaluate_timeout",
            resource_hint="page",
            request_url=url,
            page_url=_safe_page_url(page),
            request_id=request_id,
        ) from e
    except Exception as e:
        classified = _classify_browser_resource_error(
            e,
            helper_name="request_json_via_page_fetch",
            operation="page.evaluate",
            stage="evaluate",
            request_url=url,
            page=page,
            request_id=request_id,
        )
        if classified is not None:
            raise classified from e
        raise
    return _attach_json_body(result, invalid_message="页面 fetch 返回结果异常")


async def upload_file_via_context_request(
    context: BrowserContext,
    page: Page | None,
    url: str,
    *,
    filename: str,
    mime_type: str,
    data: bytes,
    field_name: str = "file",
    extra_fields: dict[str, str] | None = None,
    timeout_ms: int = 30000,
    request_id: str | None = None,
) -> dict[str, Any]:
    """Multipart upload through context.request; extra_fields become form fields."""
    multipart: dict[str, Any] = dict(extra_fields or {})
    multipart[field_name] = {
        "name": filename,
        "mimeType": mime_type or "application/octet-stream",
        "buffer": data,
    }
    result = await _request_via_context_request(
        context,
        page,
        url,
        method="POST",
        multipart=multipart,
        timeout_ms=timeout_ms,
        request_id=request_id,
        helper_name="upload_file_via_context_request",
    )
    return _attach_json_body(result, invalid_message="控制上传返回结果异常")


async def upload_file_via_page_fetch(
    page: Page,
    url: str,
    *,
    filename: str,
    mime_type: str,
    data: bytes,
    field_name: str = "file",
    extra_fields: dict[str, str] | None = None,
    timeout_ms: int = 30000,
    request_id: str | None = None,
) -> dict[str, Any]:
    """Multipart upload issued from inside the page; *data* travels as base64."""
    logger.info(
        "[fetch] helper=upload_file_via_page_fetch filename=%s mime=%s request_id=%s url=%s page.url=%s",
        filename,
        mime_type,
        request_id,
        _truncate_for_log(url, 120),
        _truncate_for_log(_safe_page_url(page), 120),
    )
    try:
        result = await asyncio.wait_for(
            page.evaluate(
                PAGE_FETCH_MULTIPART_JS,
                {
                    "url": url,
                    "filename": filename,
                    "mimeType": mime_type,
                    "dataBase64": base64.b64encode(data).decode("ascii"),
                    "fieldName": field_name,
                    "extraFields": extra_fields or {},
                    "timeoutMs": timeout_ms,
                },
            ),
            timeout=_evaluate_timeout_seconds(timeout_ms),
        )
    except asyncio.TimeoutError as e:
        logger.warning(
            "[fetch] helper=upload_file_via_page_fetch request_id=%s evaluate timeout url=%s page.url=%s",
            request_id,
            _truncate_for_log(url, 120),
            _truncate_for_log(_safe_page_url(page), 120),
        )
        raise BrowserResourceInvalidError(
            f"page.evaluate timeout after {_evaluate_timeout_seconds(timeout_ms):.1f}s",
            helper_name="upload_file_via_page_fetch",
            operation="page.evaluate",
            stage="evaluate_timeout",
            resource_hint="page",
            request_url=url,
            page_url=_safe_page_url(page),
            request_id=request_id,
        ) from e
    except Exception as e:
        classified = _classify_browser_resource_error(
            e,
            helper_name="upload_file_via_page_fetch",
            operation="page.evaluate",
            stage="evaluate",
            request_url=url,
            page=page,
            request_id=request_id,
        )
        if classified is not None:
            raise classified from e
        raise
    return _attach_json_body(result, invalid_message="页面上传返回结果异常")


async def stream_raw_via_context_request(
    context: BrowserContext,
    page: Page | None,
    url: str,
    body: str,
    request_id: str,
    *,
    on_http_error: Callable[[str, dict[str, str] | None], int | None] | None = None,
    on_headers: Callable[[dict[str, str]], None] | None = None,
    fetch_timeout: float = 90.0,
    body_timeout: float = 300.0,
    proxy_url: str | None = None,
    proxy_auth: tuple[str, str] | None = None,
) -> AsyncIterator[str]:
    """Run the completion request through a真实 streaming HTTP client so the body
    is never buffered whole. *fetch_timeout* is accepted for interface parity but
    unused (the HTTP client's connect/read timeouts govern)."""
    del fetch_timeout
    async for chunk in _stream_via_http_client(
        context,
        page,
        url,
        body,
        request_id,
        on_http_error=on_http_error,
        on_headers=on_headers,
        read_timeout=body_timeout,
        proxy_url=proxy_url,
        proxy_auth=proxy_auth,
    ):
        yield chunk


async def stream_raw_via_page_fetch(
    context: BrowserContext,
    page: Page,
    url: str,
    body: str,
    request_id: str,
    *,
    on_http_error: Callable[[str, dict[str, str] | None], int | None] | None = None,
    on_headers: Callable[[dict[str, str]], None] | None = None,
    error_state: dict[str, bool] | None = None,
    fetch_timeout: int = 90,
    read_timeout: float = 60.0,
) -> AsyncIterator[str]:
    """
    POST *body* to *url* from inside the browser and yield the raw response
    chunks (e.g. SSE) as strings. Each request gets its own binding derived from
    *request_id* so concurrent streams on one page never mix. A CDP session with
    Runtime.addBinding injects sendChunk_<id>; Runtime.bindingCalled receives it.
    "__headers__:" chunks are parsed and passed to on_headers; "__error__:"
    chunks go through on_http_error; "__done__" terminates the stream.
    """
    chunk_queue: asyncio.Queue[str] = asyncio.Queue()
    binding_name = "sendChunk_" + request_id
    stream_phase = "cdp_setup"

    def on_binding_called(event: dict[str, Any]) -> None:
        name = event.get("name")
        payload = event.get("payload", "")
        if name == binding_name:
            chunk_queue.put_nowait(
                payload if isinstance(payload, str) else str(payload)
            )

    def classify_stream_error(
        exc: Exception,
        *,
        stage: str,
    ) -> BrowserResourceInvalidError | None:
        return _classify_browser_resource_error(
            exc,
            helper_name="stream_raw_via_page_fetch",
            operation="stream",
            stage=stage,
            request_url=url,
            page=page,
            request_id=request_id,
            stream_phase=stream_phase,
        )

    cdp = None
    fetch_task: asyncio.Task[None] | None = None
    try:
        try:
            cdp = await context.new_cdp_session(page)
        except Exception as e:
            classified = classify_stream_error(e, stage="new_cdp_session")
            if classified is not None:
                raise classified from e
            raise
        cdp.on("Runtime.bindingCalled", on_binding_called)
        try:
            await cdp.send("Runtime.addBinding", {"name": binding_name})
        except Exception as e:
            classified = classify_stream_error(e, stage="add_binding")
            if classified is not None:
                raise classified from e
            raise

        logger.info(
            "[fetch] helper=stream_raw_via_page_fetch request_id=%s stage=page.evaluate url=%s page.url=%s",
            request_id,
            _truncate_for_log(url, 120),
            _truncate_for_log(_safe_page_url(page), 120),
        )

        async def run_fetch() -> None:
            # The evaluate resolves only when the in-page stream completes; its
            # outer timeout guards against a hung evaluate, not the body read.
            nonlocal stream_phase
            try:
                stream_phase = "page_evaluate"
                await asyncio.wait_for(
                    page.evaluate(
                        PAGE_FETCH_STREAM_JS,
                        {
                            "url": url,
                            "body": body,
                            "bindingName": binding_name,
                            "timeoutMs": max(1, int(fetch_timeout * 1000)),
                        },
                    ),
                    timeout=max(float(fetch_timeout) + 5.0, 10.0),
                )
            except asyncio.TimeoutError as e:
                logger.warning(
                    "[fetch] helper=stream_raw_via_page_fetch request_id=%s stage=page.evaluate evaluate timeout url=%s page.url=%s",
                    request_id,
                    _truncate_for_log(url, 120),
                    _truncate_for_log(_safe_page_url(page), 120),
                )
                raise BrowserResourceInvalidError(
                    f"page.evaluate timeout after {max(float(fetch_timeout) + 5.0, 10.0):.1f}s",
                    helper_name="stream_raw_via_page_fetch",
                    operation="stream",
                    stage="evaluate_timeout",
                    resource_hint="page",
                    request_url=url,
                    page_url=_safe_page_url(page),
                    request_id=request_id,
                    stream_phase=stream_phase,
                ) from e
            except Exception as e:
                classified = classify_stream_error(e, stage="page.evaluate")
                if classified is not None:
                    raise classified from e
                raise

        fetch_task = asyncio.create_task(run_fetch())
        try:
            headers = None
            while True:
                if fetch_task.done():
                    exc = fetch_task.exception()
                    if exc is not None:
                        raise exc
                try:
                    chunk = await asyncio.wait_for(
                        chunk_queue.get(), timeout=read_timeout
                    )
                except asyncio.TimeoutError as e:
                    stream_phase = "body"
                    logger.warning(
                        "[fetch] helper=stream_raw_via_page_fetch request_id=%s stream_phase=%s read timeout url=%s page.url=%s",
                        request_id,
                        stream_phase,
                        _truncate_for_log(url, 120),
                        _truncate_for_log(_safe_page_url(page), 120),
                    )
                    raise BrowserResourceInvalidError(
                        f"stream read timeout after {read_timeout:.1f}s",
                        helper_name="stream_raw_via_page_fetch",
                        operation="stream",
                        stage="read_timeout",
                        resource_hint="page",
                        request_url=url,
                        page_url=_safe_page_url(page),
                        request_id=request_id,
                        stream_phase=stream_phase,
                    ) from e
                if chunk == "__done__":
                    break
                if chunk.startswith("__headers__:"):
                    stream_phase = "headers"
                    try:
                        headers = json.loads(chunk[12:])
                        if on_headers and isinstance(headers, dict):
                            on_headers({k: str(v) for k, v in headers.items()})
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.debug("[fetch] 解析 __headers__ 失败: %s", e)
                    continue
                if chunk.startswith("__error__:"):
                    msg = chunk[10:].strip()
                    saw_terminal = bool(error_state and error_state.get("terminal"))
                    stream_phase = "terminal_event" if saw_terminal else ("body" if headers else "before_headers")
                    if on_http_error:
                        unfreeze_at = on_http_error(msg, headers)
                        if isinstance(unfreeze_at, int):
                            logger.warning("[fetch] __error__ from page: %s", msg)
                            raise AccountFrozenError(msg, unfreeze_at)
                    classified = _classify_browser_resource_error(
                        RuntimeError(msg),
                        helper_name="stream_raw_via_page_fetch",
                        operation="page_fetch_stream",
                        stage="page_error_event",
                        request_url=url,
                        page=page,
                        request_id=request_id,
                        stream_phase=stream_phase,
                    )
                    if classified is not None:
                        raise classified
                    if saw_terminal:
                        # Disconnect after the terminal event: the answer is
                        # complete, so treat it as a benign end of stream.
                        logger.info(
                            "[fetch] page fetch disconnected after terminal event request_id=%s stream_phase=%s: %s",
                            request_id,
                            stream_phase,
                            msg,
                        )
                        continue
                    logger.warning(
                        "[fetch] __error__ from page before terminal event request_id=%s stream_phase=%s: %s",
                        request_id,
                        stream_phase,
                        msg,
                    )
                    raise RuntimeError(msg)
                stream_phase = "body"
                yield chunk
        finally:
            if fetch_task is not None:
                done, pending = await asyncio.wait({fetch_task}, timeout=5.0)
                if pending:
                    fetch_task.cancel()
                    fetch_task.add_done_callback(_consume_background_task_result)
                else:
                    try:
                        fetch_task.result()
                    except asyncio.CancelledError:
                        pass
                    except BrowserResourceInvalidError:
                        pass
    finally:
        if cdp is not None:
            try:
                await asyncio.wait_for(cdp.detach(), timeout=2.0)
            except asyncio.TimeoutError:
                logger.warning(
                    "[fetch] helper=stream_raw_via_page_fetch request_id=%s detach CDP session timeout page.url=%s",
                    request_id,
                    _truncate_for_log(_safe_page_url(page), 120),
                )
            except Exception as e:
                logger.debug("detach CDP session 时异常: %s", e)
+ + +def parse_sse_to_events(buffer: str, chunk: str) -> tuple[str, list[str]]: + """ + 把 chunk 追加到 buffer,按行拆出 data: 后的 payload 列表,返回 (剩余 buffer, payload 列表)。 + 接入方对每个 payload 自行 JSON 解析并抽取 text / message_id / error。 + """ + buffer += chunk + lines = buffer.split("\n") + buffer = lines[-1] + payloads: list[str] = [] + for line in lines[:-1]: + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]" or not payload: + continue + payloads.append(payload) + return (buffer, payloads) + + +async def stream_completion_via_sse( + context: BrowserContext, + page: Page, + url: str, + body: str, + parse_event: ParseSseEvent, + request_id: str, + *, + on_http_error: Callable, + is_terminal_event: Callable[[str], bool] | None = None, + collect_message_id: list[str] | None = None, + first_token_timeout: float = 30.0, + transport: str = "page_fetch", + transport_options: dict[str, Any] | None = None, +) -> AsyncIterator[str]: + """ + 在浏览器内 POST 拿到流,按 SSE 行拆成 data 事件,用 parse_event(payload) 解析每条; + 逐块 yield 文本,可选把 message_id 收集到 collect_message_id。 + parse_event(payload) 返回 (texts, message_id, error),error 非空时仅打 debug 日志不抛错。 + """ + buffer = "" + stream_state: dict[str, bool] = {"terminal": False} + saw_text = False + loop = asyncio.get_running_loop() + started_at = loop.time() + opts = dict(transport_options or {}) + if transport == "context_request": + raw_stream = stream_raw_via_context_request( + context, + page, + url, + body, + request_id, + on_http_error=on_http_error, + **opts, + ) + resource_hint = "transport" + else: + raw_stream = stream_raw_via_page_fetch( + context, + page, + url, + body, + request_id, + on_http_error=on_http_error, + error_state=stream_state, + ) + resource_hint = "page" + async for chunk in raw_stream: + buffer, payloads = parse_sse_to_events(buffer, chunk) + for payload in payloads: + if is_terminal_event and is_terminal_event(payload): + stream_state["terminal"] = True + try: + 
texts, message_id, error = parse_event(payload) + except Exception as e: + logger.debug("parse_stream_event 单条解析异常: %s", e) + continue + if error: + logger.warning("SSE error from upstream: %s", error) + raise RuntimeError(error) + if message_id and collect_message_id is not None: + collect_message_id.append(message_id) + for t in texts: + saw_text = True + yield t + if ( + not saw_text + and not stream_state["terminal"] + and loop.time() - started_at >= first_token_timeout + ): + raise BrowserResourceInvalidError( + f"no text token received within {first_token_timeout:.1f}s", + helper_name="stream_completion_via_sse", + operation="parse_stream", + stage="first_token_timeout", + resource_hint=resource_hint, + request_url=url, + page_url=_safe_page_url(page), + request_id=request_id, + stream_phase="before_first_text", + ) diff --git a/core/protocol/__init__.py b/core/protocol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..caf4b94efa4f7ca210bfc0ce70f426f96471af03 --- /dev/null +++ b/core/protocol/__init__.py @@ -0,0 +1 @@ +"""协议层:客户端协议适配与 Canonical 模型。""" diff --git a/core/protocol/anthropic.py b/core/protocol/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..41c8d0d8403a5154aedfa0c435b0f5591e104a14 --- /dev/null +++ b/core/protocol/anthropic.py @@ -0,0 +1,461 @@ +"""Anthropic 协议适配器。""" + +from __future__ import annotations + +import json +import time +import uuid as uuid_mod +from collections.abc import AsyncIterator +from typing import Any + +from core.api.conv_parser import ( + decode_latest_session_id, + extract_session_id_marker, + strip_session_id_suffix, +) +from core.api.react import format_react_final_answer_content, parse_react_output +from core.api.react_stream_parser import ReactStreamParser +from core.hub.schemas import OpenAIStreamEvent +from core.protocol.base import ProtocolAdapter +from core.protocol.schemas import ( + CanonicalChatRequest, + CanonicalContentBlock, + CanonicalMessage, 
+ CanonicalToolSpec, +) + + +class AnthropicProtocolAdapter(ProtocolAdapter): + protocol_name = "anthropic" + + def parse_request( + self, + provider: str, + raw_body: dict[str, Any], + ) -> CanonicalChatRequest: + messages = raw_body.get("messages") or [] + if not isinstance(messages, list): + raise ValueError("messages 必须为数组") + system_blocks = self._parse_content(raw_body.get("system")) + canonical_messages: list[CanonicalMessage] = [] + resume_session_id: str | None = None + for item in messages: + if not isinstance(item, dict): + continue + blocks = self._parse_content(item.get("content")) + for block in blocks: + text = block.text or "" + decoded = decode_latest_session_id(text) + if decoded: + resume_session_id = decoded + block.text = strip_session_id_suffix(text) + canonical_messages.append( + CanonicalMessage( + role=str(item.get("role") or "user"), + content=blocks, + ) + ) + + for block in system_blocks: + text = block.text or "" + decoded = decode_latest_session_id(text) + if decoded: + resume_session_id = decoded + block.text = strip_session_id_suffix(text) + + tools = [self._parse_tool(tool) for tool in list(raw_body.get("tools") or [])] + stop_sequences = raw_body.get("stop_sequences") or [] + return CanonicalChatRequest( + protocol="anthropic", + provider=provider, + model=str(raw_body.get("model") or ""), + system=system_blocks, + messages=canonical_messages, + stream=bool(raw_body.get("stream") or False), + max_tokens=raw_body.get("max_tokens"), + temperature=raw_body.get("temperature"), + top_p=raw_body.get("top_p"), + stop_sequences=[str(v) for v in stop_sequences if isinstance(v, str)], + tools=tools, + tool_choice=raw_body.get("tool_choice"), + resume_session_id=resume_session_id, + ) + + def render_non_stream( + self, + req: CanonicalChatRequest, + raw_events: list[OpenAIStreamEvent], + ) -> dict[str, Any]: + full = "".join( + ev.content or "" + for ev in raw_events + if ev.type == "content_delta" and ev.content + ) + session_marker = 
    async def render_stream(
        self,
        req: CanonicalChatRequest,
        raw_stream: AsyncIterator[OpenAIStreamEvent],
    ) -> AsyncIterator[str]:
        """Translate upstream OpenAI-semantic events into Anthropic SSE frames.

        Content deltas are fed through ReactStreamParser (which produces
        OpenAI-style chunk SSE, handling ReAct tool-call extraction when
        tools are present) and then through _AnthropicStreamTranslator,
        which rewrites those chunks into Anthropic streaming events.
        """
        message_id = self._message_id(req)
        parser = ReactStreamParser(
            chat_id=f"chatcmpl-{uuid_mod.uuid4().hex[:24]}",
            model=req.model,
            created=int(time.time()),
            has_tools=bool(req.tools),
        )
        session_marker = ""
        translator = _AnthropicStreamTranslator(req, message_id)
        async for event in raw_stream:
            if event.type == "content_delta" and event.content:
                chunk = event.content
                # A chunk that consists solely of a session-id marker is held
                # back and re-emitted after the stream finishes, instead of
                # being surfaced mid-stream.
                if extract_session_id_marker(chunk) and not strip_session_id_suffix(
                    chunk
                ):
                    session_marker = chunk
                    continue
                for sse in parser.feed(chunk):
                    for out in translator.feed_openai_sse(sse):
                        yield out
            elif event.type == "finish":
                break
        # Flush any buffered parser output and append the deferred marker.
        for sse in parser.finish():
            for out in translator.feed_openai_sse(sse, session_marker=session_marker):
                yield out

    def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]:
        """Map an exception to an Anthropic-style (status, error body) pair.

        ValueError is treated as a client error (400 / invalid_request_error);
        anything else becomes a 500 / api_error.
        """
        status = 400 if isinstance(exc, ValueError) else 500
        err_type = "invalid_request_error" if status == 400 else "api_error"
        return (
            status,
            {
                "type": "error",
                "error": {"type": err_type, "message": str(exc)},
            },
        )

    @staticmethod
    def _parse_tool(tool: dict[str, Any]) -> CanonicalToolSpec:
        """Convert one Anthropic tool declaration into a CanonicalToolSpec."""
        return CanonicalToolSpec(
            name=str(tool.get("name") or ""),
            description=str(tool.get("description") or ""),
            input_schema=tool.get("input_schema") or {},
        )

    @staticmethod
    def _parse_content(value: Any) -> list[CanonicalContentBlock]:
        """Normalize an Anthropic message ``content`` field into canonical blocks.

        Accepts None, a plain string, or a list of string/dict blocks.
        Handled dict block types: ``text``, ``image`` (base64 source only —
        other source types are dropped) and ``tool_result`` (nested content is
        parsed recursively and its text parts are joined with newlines).
        Non-dict items and unknown block types are silently skipped.

        Raises:
            ValueError: if ``value`` is neither None, str nor list.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [CanonicalContentBlock(type="text", text=value)]
        if isinstance(value, list):
            blocks: list[CanonicalContentBlock] = []
            for item in value:
                if isinstance(item, str):
                    blocks.append(CanonicalContentBlock(type="text", text=item))
                    continue
                if not isinstance(item, dict):
                    continue
                item_type = str(item.get("type") or "")
                if item_type == "text":
                    blocks.append(
                        CanonicalContentBlock(
                            type="text", text=str(item.get("text") or "")
                        )
                    )
                elif item_type == "image":
                    source = item.get("source") or {}
                    source_type = source.get("type")
                    # Only inline base64 image sources are supported here.
                    if source_type == "base64":
                        blocks.append(
                            CanonicalContentBlock(
                                type="image",
                                mime_type=str(source.get("media_type") or ""),
                                data=str(source.get("data") or ""),
                            )
                        )
                elif item_type == "tool_result":
                    # Flatten nested tool_result content down to joined text.
                    text_parts = AnthropicProtocolAdapter._parse_content(
                        item.get("content")
                    )
                    blocks.append(
                        CanonicalContentBlock(
                            type="tool_result",
                            tool_use_id=str(item.get("tool_use_id") or ""),
                            text="\n".join(
                                part.text or ""
                                for part in text_parts
                                if part.type == "text"
                            ),
                            is_error=bool(item.get("is_error") or False),
                        )
                    )
            return blocks
        raise ValueError("content 格式不合法")

    @staticmethod
    def _message_response(
        req: CanonicalChatRequest,
        message_id: str,
        content: list[dict[str, Any]],
        *,
        stop_reason: str,
    ) -> dict[str, Any]:
        """Build a non-streaming Anthropic ``message`` response envelope.

        Usage counters are hardcoded to 0 — token accounting is not done here.
        """
        return {
            "id": message_id,
            "type": "message",
            "role": "assistant",
            "model": req.model,
            "content": content,
            "stop_reason": stop_reason,
            "stop_sequence": None,
            "usage": {"input_tokens": 0, "output_tokens": 0},
        }
"output_tokens": 0}, + } + + @staticmethod + def _message_id(req: CanonicalChatRequest) -> str: + return str( + req.metadata.setdefault( + "anthropic_message_id", f"msg_{uuid_mod.uuid4().hex}" + ) + ) + + +class _AnthropicStreamTranslator: + def __init__(self, req: CanonicalChatRequest, message_id: str) -> None: + self._req = req + self._message_id = message_id + self._started = False + self._current_block_type: str | None = None + self._current_index = -1 + self._pending_tool_id: str | None = None + self._pending_tool_name: str | None = None + self._stopped = False + + def feed_openai_sse( + self, + sse: str, + *, + session_marker: str = "", + ) -> list[str]: + lines = [line for line in sse.splitlines() if line.startswith("data: ")] + out: list[str] = [] + for line in lines: + payload = line[6:].strip() + if payload == "[DONE]": + continue + obj = json.loads(payload) + choice = (obj.get("choices") or [{}])[0] + delta = choice.get("delta") or {} + finish_reason = choice.get("finish_reason") + if not self._started: + out.append( + self._event( + "message_start", + { + "type": "message_start", + "message": { + "id": self._message_id, + "type": "message", + "role": "assistant", + "model": self._req.model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + }, + ) + ) + self._started = True + + content = delta.get("content") + if isinstance(content, str) and content: + out.extend(self._ensure_text_block()) + out.append( + self._event( + "content_block_delta", + { + "type": "content_block_delta", + "index": self._current_index, + "delta": {"type": "text_delta", "text": content}, + }, + ) + ) + + tool_calls = delta.get("tool_calls") or [] + if tool_calls: + head = tool_calls[0] + if head.get("id") and head.get("function", {}).get("name") is not None: + out.extend(self._close_current_block()) + self._current_index += 1 + self._current_block_type = "tool_use" + self._pending_tool_id = 
str(head.get("id") or "") + self._pending_tool_name = str( + head.get("function", {}).get("name") or "" + ) + out.append( + self._event( + "content_block_start", + { + "type": "content_block_start", + "index": self._current_index, + "content_block": { + "type": "tool_use", + "id": self._pending_tool_id, + "name": self._pending_tool_name, + "input": {}, + }, + }, + ) + ) + args_delta = head.get("function", {}).get("arguments") + if args_delta: + out.append( + self._event( + "content_block_delta", + { + "type": "content_block_delta", + "index": self._current_index, + "delta": { + "type": "input_json_delta", + "partial_json": str(args_delta), + }, + }, + ) + ) + + if finish_reason: + if session_marker: + if finish_reason == "tool_calls": + out.extend(self._close_current_block()) + out.extend(self._emit_marker_text_block(session_marker)) + else: + out.extend(self._ensure_text_block()) + out.append( + self._event( + "content_block_delta", + { + "type": "content_block_delta", + "index": self._current_index, + "delta": { + "type": "text_delta", + "text": session_marker, + }, + }, + ) + ) + out.extend(self._close_current_block()) + stop_reason = ( + "tool_use" if finish_reason == "tool_calls" else "end_turn" + ) + out.append( + self._event( + "message_delta", + { + "type": "message_delta", + "delta": { + "stop_reason": stop_reason, + "stop_sequence": None, + }, + "usage": {"output_tokens": 0}, + }, + ) + ) + out.append(self._event("message_stop", {"type": "message_stop"})) + self._stopped = True + return out + + def _ensure_text_block(self) -> list[str]: + if self._current_block_type == "text": + return [] + out = self._close_current_block() + self._current_index += 1 + self._current_block_type = "text" + out.append( + self._event( + "content_block_start", + { + "type": "content_block_start", + "index": self._current_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + ) + return out + + def _emit_marker_text_block(self, marker: str) -> list[str]: + 
self._current_index += 1 + self._current_block_type = "text" + return [ + self._event( + "content_block_start", + { + "type": "content_block_start", + "index": self._current_index, + "content_block": {"type": "text", "text": ""}, + }, + ), + self._event( + "content_block_delta", + { + "type": "content_block_delta", + "index": self._current_index, + "delta": {"type": "text_delta", "text": marker}, + }, + ), + self._event( + "content_block_stop", + {"type": "content_block_stop", "index": self._current_index}, + ), + ] + + def _close_current_block(self) -> list[str]: + if self._current_block_type is None: + return [] + block_index = self._current_index + self._current_block_type = None + return [ + self._event( + "content_block_stop", + {"type": "content_block_stop", "index": block_index}, + ) + ] + + @staticmethod + def _event(event_name: str, payload: dict[str, Any]) -> str: + del event_name + return f"event: {payload['type']}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n" diff --git a/core/protocol/base.py b/core/protocol/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4ad878f2f03997f6342ce0f118fe345efd51f7 --- /dev/null +++ b/core/protocol/base.py @@ -0,0 +1,38 @@ +"""协议适配器抽象。内部统一以 OpenAI 语义事件流为中间态。""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any + +from core.hub.schemas import OpenAIStreamEvent +from core.protocol.schemas import CanonicalChatRequest + + +class ProtocolAdapter(ABC): + protocol_name: str + + @abstractmethod + def parse_request( + self, + provider: str, + raw_body: dict[str, Any], + ) -> CanonicalChatRequest: ... + + @abstractmethod + def render_non_stream( + self, + req: CanonicalChatRequest, + raw_events: list[OpenAIStreamEvent], + ) -> dict[str, Any]: ... + + @abstractmethod + def render_stream( + self, + req: CanonicalChatRequest, + raw_stream: AsyncIterator[OpenAIStreamEvent], + ) -> AsyncIterator[str]: ... 
+ + @abstractmethod + def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]: ... diff --git a/core/protocol/images.py b/core/protocol/images.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8ff4c37606037a747f43add5c65affb2acfe50 --- /dev/null +++ b/core/protocol/images.py @@ -0,0 +1,108 @@ +"""图片输入解析与下载。""" + +from __future__ import annotations + +import asyncio +import base64 +import imghdr +import mimetypes +import urllib.parse +import urllib.request +from dataclasses import dataclass + + +SUPPORTED_IMAGE_MIME_TYPES = { + "image/png", + "image/jpeg", + "image/webp", + "image/gif", +} +MAX_IMAGE_BYTES = 10 * 1024 * 1024 +MAX_IMAGE_COUNT = 5 + + +@dataclass +class PreparedImage: + filename: str + mime_type: str + data: bytes + + +def _validate_image_bytes(data: bytes, mime_type: str) -> None: + if mime_type not in SUPPORTED_IMAGE_MIME_TYPES: + raise ValueError(f"暂不支持的图片类型: {mime_type}") + if len(data) > MAX_IMAGE_BYTES: + raise ValueError("单张图片不能超过 10MB") + + +def _default_filename(mime_type: str, *, prefix: str = "image") -> str: + ext = mimetypes.guess_extension(mime_type) or ".bin" + if ext == ".jpe": + ext = ".jpg" + return f"{prefix}{ext}" + + +def parse_data_url(url: str, *, prefix: str = "image") -> PreparedImage: + if not url.startswith("data:") or ";base64," not in url: + raise ValueError("仅支持 data:image/...;base64,... 
格式") + header, payload = url.split(",", 1) + mime_type = header[5:].split(";", 1)[0].strip().lower() + data = base64.b64decode(payload, validate=True) + _validate_image_bytes(data, mime_type) + return PreparedImage( + filename=_default_filename(mime_type, prefix=prefix), + mime_type=mime_type, + data=data, + ) + + +def parse_base64_image( + data_b64: str, + mime_type: str, + *, + prefix: str = "image", +) -> PreparedImage: + mime = mime_type.strip().lower() + data = base64.b64decode(data_b64, validate=True) + _validate_image_bytes(data, mime) + return PreparedImage( + filename=_default_filename(mime, prefix=prefix), + mime_type=mime, + data=data, + ) + + +def _sniff_mime_type(data: bytes, url: str) -> str: + kind = imghdr.what(None, data) + if kind == "jpeg": + return "image/jpeg" + if kind in {"png", "gif", "webp"}: + return f"image/{kind}" + guessed, _ = mimetypes.guess_type(url) + return (guessed or "application/octet-stream").lower() + + +def _download_remote_image_sync(url: str, *, prefix: str = "image") -> PreparedImage: + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise ValueError("image_url 仅支持 http/https 或 data URL") + req = urllib.request.Request( + url, + headers={"User-Agent": "web2api/1.0", "Accept": "image/*"}, + ) + with urllib.request.urlopen(req, timeout=20) as resp: + data = resp.read(MAX_IMAGE_BYTES + 1) + mime_type = str(resp.headers.get_content_type() or "").lower() + if not mime_type or mime_type == "application/octet-stream": + mime_type = _sniff_mime_type(data, url) + _validate_image_bytes(data, mime_type) + filename = urllib.parse.unquote( + parsed.path.rsplit("/", 1)[-1] + ) or _default_filename(mime_type, prefix=prefix) + if "." 
not in filename: + filename = _default_filename(mime_type, prefix=prefix) + return PreparedImage(filename=filename, mime_type=mime_type, data=data) + + +async def download_remote_image(url: str, *, prefix: str = "image") -> PreparedImage: + return await asyncio.to_thread(_download_remote_image_sync, url, prefix=prefix) diff --git a/core/protocol/openai.py b/core/protocol/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..3908cf2f76b5ba43fa6d2b4fe743ec50d4522e81 --- /dev/null +++ b/core/protocol/openai.py @@ -0,0 +1,251 @@ +"""OpenAI 协议适配器。""" + +from __future__ import annotations + +import json +import re +import time +import uuid as uuid_mod +from collections.abc import AsyncIterator +from typing import Any + +from core.api.conv_parser import ( + extract_session_id_marker, + parse_conv_uuid_from_messages, + strip_session_id_suffix, +) +from core.api.function_call import build_tool_calls_response +from core.api.react import ( + format_react_final_answer_content, + parse_react_output, + react_output_to_tool_calls, +) +from core.api.react_stream_parser import ReactStreamParser +from core.api.schemas import OpenAIChatRequest, OpenAIContentPart, OpenAIMessage +from core.hub.schemas import OpenAIStreamEvent +from core.protocol.base import ProtocolAdapter +from core.protocol.schemas import ( + CanonicalChatRequest, + CanonicalContentBlock, + CanonicalMessage, + CanonicalToolSpec, +) + + +class OpenAIProtocolAdapter(ProtocolAdapter): + protocol_name = "openai" + + def parse_request( + self, + provider: str, + raw_body: dict[str, Any], + ) -> CanonicalChatRequest: + req = OpenAIChatRequest.model_validate(raw_body) + resume_session_id = parse_conv_uuid_from_messages( + [self._message_to_raw_dict(m) for m in req.messages] + ) + system_blocks: list[CanonicalContentBlock] = [] + messages: list[CanonicalMessage] = [] + for msg in req.messages: + blocks = self._to_blocks(msg.content) + if msg.role == "system": + system_blocks.extend(blocks) + else: + 
    def render_non_stream(
        self,
        req: CanonicalChatRequest,
        raw_events: list[OpenAIStreamEvent],
    ) -> dict[str, Any]:
        """Assemble a non-streaming chat.completion response from raw events.

        Concatenates all content deltas, peels off a trailing session-id
        marker, and — when tools were requested — tries to parse the reply as
        ReAct output and surface it as OpenAI tool_calls.
        """
        reply = "".join(
            ev.content or ""
            for ev in raw_events
            if ev.type == "content_delta" and ev.content
        )
        session_marker = extract_session_id_marker(reply)
        content_for_parse = strip_session_id_suffix(reply)
        chat_id, created = self._response_context(req)
        if req.tools:
            parsed = parse_react_output(content_for_parse)
            tool_calls_list = react_output_to_tool_calls(parsed) if parsed else []
            if tool_calls_list:
                # Extract the ReAct "Thought" prefix (ASCII or fullwidth
                # colon) so it can ride along as the message text.
                thought_ns = ""
                if "Thought" in content_for_parse:
                    match = re.search(
                        r"Thought[::]\s*(.+?)(?=\s*Action[::]|$)",
                        content_for_parse,
                        re.DOTALL | re.I,
                    )
                    thought_ns = (match.group(1) or "").strip() if match else ""
                text_content = (
                    f"{thought_ns}\n{session_marker}".strip()
                    if thought_ns
                    else session_marker
                )
                return build_tool_calls_response(
                    tool_calls_list,
                    chat_id,
                    req.model,
                    created,
                    text_content=text_content,
                )
            content_reply = format_react_final_answer_content(content_for_parse)
            if session_marker:
                content_reply += session_marker
        else:
            content_reply = content_for_parse
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": created,
            "model": req.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content_reply},
                    "finish_reason": "stop",
                }
            ],
        }

    async def render_stream(
        self,
        req: CanonicalChatRequest,
        raw_stream: AsyncIterator[OpenAIStreamEvent],
    ) -> AsyncIterator[str]:
        """Stream raw events as OpenAI chunk SSE via ReactStreamParser.

        A chunk that is purely a session-id marker is withheld and emitted as
        a final content delta after the upstream finishes.
        """
        chat_id, created = self._response_context(req)
        parser = ReactStreamParser(
            chat_id=chat_id,
            model=req.model,
            created=created,
            has_tools=bool(req.tools),
        )
        session_marker = ""
        async for event in raw_stream:
            if event.type == "content_delta" and event.content:
                chunk = event.content
                if extract_session_id_marker(chunk) and not strip_session_id_suffix(
                    chunk
                ):
                    session_marker = chunk
                    continue
                for sse in parser.feed(chunk):
                    yield sse
            elif event.type == "finish":
                break
        if session_marker:
            yield self._content_delta(chat_id, req.model, created, session_marker)
        for sse in parser.finish():
            yield sse

    def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]:
        """Map an exception to an OpenAI-style (status, error body) pair."""
        status = 400 if isinstance(exc, ValueError) else 500
        err_type = "invalid_request_error" if status == 400 else "server_error"
        return (
            status,
            {"error": {"message": str(exc), "type": err_type}},
        )

    @staticmethod
    def _message_to_raw_dict(msg: OpenAIMessage) -> dict[str, Any]:
        """Dump an OpenAIMessage back to a plain dict (for conv-uuid parsing)."""
        if isinstance(msg.content, list):
            content: str | list[dict[str, Any]] = [p.model_dump() for p in msg.content]
        else:
            content = msg.content
        out: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.tool_calls is not None:
            out["tool_calls"] = msg.tool_calls
        if msg.tool_call_id is not None:
            out["tool_call_id"] = msg.tool_call_id
        return out

    @staticmethod
    def _to_blocks(
        content: str | list[OpenAIContentPart] | None,
    ) -> list[CanonicalContentBlock]:
        """Convert OpenAI message content into canonical blocks.

        Session-id suffixes are stripped from all text. image_url parts become
        image blocks: data URLs are carried in ``data``, remote URLs in
        ``url``. Parts of other types are skipped.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [
                CanonicalContentBlock(
                    type="text", text=strip_session_id_suffix(content)
                )
            ]
        blocks: list[CanonicalContentBlock] = []
        for part in content:
            if part.type == "text":
                blocks.append(
                    CanonicalContentBlock(
                        type="text",
                        text=strip_session_id_suffix(part.text or ""),
                    )
                )
            elif part.type == "image_url":
                # image_url may be a {"url": ...} dict or a bare value.
                image_url = part.image_url
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if not url:
                    continue
                if isinstance(url, str) and url.startswith("data:"):
                    blocks.append(CanonicalContentBlock(type="image", data=url))
                else:
                    blocks.append(CanonicalContentBlock(type="image", url=str(url)))
        return blocks

    @staticmethod
    def _to_tool_spec(tool: dict[str, Any]) -> CanonicalToolSpec:
        """Convert an OpenAI tool declaration (function-style or flat) to a spec."""
        function = tool.get("function") if tool.get("type") == "function" else tool
        return CanonicalToolSpec(
            name=str(function.get("name") or ""),
            description=str(function.get("description") or ""),
            input_schema=function.get("parameters")
            or function.get("input_schema")
            or {},
            strict=bool(function.get("strict") or False),
        )

    @staticmethod
    def _content_delta(chat_id: str, model: str, created: int, text: str) -> str:
        """Build a single OpenAI chunk SSE string carrying ``text``."""
        return (
            "data: "
            + json.dumps(
                {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                },
                ensure_ascii=False,
            )
            + "\n\n"
        )

    @staticmethod
    def _response_context(req: CanonicalChatRequest) -> tuple[str, int]:
        """Return (chat_id, created) memoized in req.metadata for reuse."""
        chat_id = str(
            req.metadata.setdefault(
                "response_id", f"chatcmpl-{uuid_mod.uuid4().hex[:24]}"
            )
        )
        created = int(req.metadata.setdefault("created", int(time.time())))
        return chat_id, created
None = None + url: str | None = None + + +class CanonicalMessage(BaseModel): + role: Literal["system", "user", "assistant", "tool"] + content: list[CanonicalContentBlock] = Field(default_factory=list) + + +class CanonicalToolSpec(BaseModel): + name: str + description: str = "" + input_schema: dict[str, Any] = Field(default_factory=dict) + strict: bool = False + + +class CanonicalChatRequest(BaseModel): + protocol: Literal["openai", "anthropic"] + provider: str + model: str + system: list[CanonicalContentBlock] = Field(default_factory=list) + messages: list[CanonicalMessage] = Field(default_factory=list) + stream: bool = False + max_tokens: int | None = None + temperature: float | None = None + top_p: float | None = None + stop_sequences: list[str] = Field(default_factory=list) + tools: list[CanonicalToolSpec] = Field(default_factory=list) + tool_choice: str | dict[str, Any] | None = None + resume_session_id: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalStreamEvent(BaseModel): + type: Literal[ + "message_start", + "text_delta", + "thinking_delta", + "tool_call", + "usage", + "message_stop", + "error", + ] + text: str | None = None + id: str | None = None + name: str | None = None + arguments: str | None = None + stop_reason: str | None = None + session_id: str | None = None + usage: dict[str, int] | None = None + error: str | None = None diff --git a/core/protocol/service.py b/core/protocol/service.py new file mode 100644 index 0000000000000000000000000000000000000000..0cac4b41d6e4cd6acda50906b54b9131654802d2 --- /dev/null +++ b/core/protocol/service.py @@ -0,0 +1,175 @@ +"""Canonical 请求桥接到 OpenAI 语义事件流(唯一中间态)。""" + +from __future__ import annotations + +from collections.abc import AsyncIterator + +from core.api.chat_handler import ChatHandler +from core.api.schemas import ( + InputAttachment, + OpenAIChatRequest, + OpenAIContentPart, + OpenAIMessage, +) +from core.protocol.images import ( + MAX_IMAGE_COUNT, + 
class CanonicalChatService:
    """Bridge canonical requests onto ChatHandler's OpenAI event stream."""

    def __init__(self, handler: ChatHandler) -> None:
        self._handler = handler

    async def stream_raw(
        self, req: CanonicalChatRequest
    ) -> AsyncIterator[OpenAIStreamEvent]:
        """Convert to an OpenAIChatRequest and stream upstream events through."""
        openai_req = await self._to_openai_request(req)
        async for event in self._handler.stream_openai_events(req.provider, openai_req):
            yield event

    async def collect_raw(self, req: CanonicalChatRequest) -> list[OpenAIStreamEvent]:
        """Drain stream_raw into a list (for non-streaming rendering)."""
        events: list[OpenAIStreamEvent] = []
        async for event in self.stream_raw(req):
            events.append(event)
        return events

    async def _to_openai_request(self, req: CanonicalChatRequest) -> OpenAIChatRequest:
        """Flatten a canonical request into the handler's OpenAI request shape.

        System blocks become a leading system message; tools are wrapped in
        OpenAI function-style declarations; image blocks are resolved into
        upload attachments.
        """
        messages: list[OpenAIMessage] = []
        if req.system:
            messages.append(
                OpenAIMessage(
                    role="system",
                    content=self._to_openai_content(req.system),
                )
            )
        for msg in req.messages:
            messages.append(
                OpenAIMessage(
                    role=msg.role,
                    content=self._to_openai_content(msg.content),
                    # Tool messages carry the tool_use_id of their first block.
                    tool_call_id=msg.content[0].tool_use_id
                    if msg.role == "tool" and msg.content
                    else None,
                )
            )

        openai_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                    "strict": tool.strict,
                },
            }
            for tool in req.tools
        ]
        last_user_attachments, all_attachments = await self._resolve_attachments(req)
        return OpenAIChatRequest(
            model=req.model,
            messages=messages,
            stream=req.stream,
            tools=openai_tools or None,
            tool_choice=req.tool_choice,
            resume_session_id=req.resume_session_id,
            upstream_model=str(req.metadata.get("upstream_model") or "") or None,
            # ChatHandler decides which of the two lists below actually
            # populates attachment_files, depending on full_history.
            attachment_files=[],
            attachment_files_last_user=last_user_attachments,
            attachment_files_all_users=all_attachments,
        )

    async def _resolve_attachments(
        self, req: CanonicalChatRequest
    ) -> tuple[list[InputAttachment], list[InputAttachment]]:
        """
        Resolve image attachments; returns (last_user_attachments, all_user_attachments):

        - When reusing a session (full_history=False), only the last user
          message's images are needed;
        - When rebuilding a session (full_history=True), images from all
          historical user messages must be re-uploaded as well.
        """
        last_user: CanonicalMessage | None = None
        for msg in reversed(req.messages):
            if msg.role == "user":
                last_user = msg
                break

        # Images from every user message (used to backfill history when
        # rebuilding a session).
        all_image_blocks: list[CanonicalContentBlock] = []
        for msg in req.messages:
            if msg.role != "user":
                continue
            all_image_blocks.extend(
                block for block in msg.content if block.type == "image"
            )

        last_user_blocks: list[CanonicalContentBlock] = []
        if last_user is not None:
            last_user_blocks = [
                block for block in last_user.content if block.type == "image"
            ]

        if len(all_image_blocks) > MAX_IMAGE_COUNT:
            raise ValueError(f"单次最多上传 {MAX_IMAGE_COUNT} 张图片")

        async def _prepare(
            blocks: list[CanonicalContentBlock],
        ) -> list[InputAttachment]:
            # Turn each image block (remote URL, data URL, or raw base64 +
            # mime) into an uploadable InputAttachment.
            attachments: list[InputAttachment] = []
            for idx, block in enumerate(blocks, start=1):
                if block.url:
                    prepared = await download_remote_image(
                        block.url, prefix=f"message_image_{idx}"
                    )
                elif block.data and block.data.startswith("data:"):
                    prepared = parse_data_url(block.data, prefix=f"message_image_{idx}")
                elif block.data and block.mime_type:
                    prepared = parse_base64_image(
                        block.data,
                        block.mime_type,
                        prefix=f"message_image_{idx}",
                    )
                else:
                    raise ValueError("图片块缺少可用数据")
                attachments.append(
                    InputAttachment(
                        filename=prepared.filename,
                        mime_type=prepared.mime_type,
                        data=prepared.data,
                    )
                )
            return attachments

        last_attachments = await _prepare(last_user_blocks)
        all_attachments = await _prepare(all_image_blocks)
        return last_attachments, all_attachments

    @staticmethod
    def _to_openai_content(
        blocks: list[CanonicalContentBlock],
    ) -> str | list[OpenAIContentPart]:
        """Render canonical blocks as OpenAI content.

        Collapses to a plain string when the result is empty or a single text
        part; text/thinking/tool_result all become text parts, image blocks
        become image_url parts.
        """
        if not blocks:
            return ""
        parts: list[OpenAIContentPart] = []
        for block in blocks:
            if block.type in {"text", "thinking", "tool_result"}:
                parts.append(OpenAIContentPart(type="text", text=block.text or ""))
            elif block.type == "image":
                url = block.url or block.data or ""
                parts.append(
                    OpenAIContentPart(
                        type="image_url",
                        image_url={"url": url},
                    )
                )
        if not parts:
            return ""
        if len(parts) == 1 and parts[0].type == "text":
            return parts[0].text or ""
        return parts
logger = logging.getLogger(__name__)

# Async factory that creates/prepares a page inside a context.
CreatePageFn = Callable[[BrowserContext, Page | None], Coroutine[Any, Any, Page]]
# Async hook that applies account auth state to a context/page.
ApplyAuthFn = Callable[[BrowserContext, Page], Coroutine[Any, Any, None]]


async def _wait_for_cdp(
    host: str,
    port: int,
    max_attempts: int = 60,
    interval: float = 2.0,
    connect_timeout: float = 2.0,
) -> bool:
    """Poll until a TCP connection to host:port succeeds.

    Returns True as soon as a connection is accepted; False after
    max_attempts failed attempts spaced ``interval`` seconds apart.
    """
    for _ in range(max_attempts):
        try:
            _, writer = await asyncio.wait_for(
                asyncio.open_connection(host, port), timeout=connect_timeout
            )
            writer.close()
            await writer.wait_closed()
            return True
        except (OSError, asyncio.TimeoutError):
            await asyncio.sleep(interval)
    return False


def _is_cdp_listening(port: int) -> bool:
    """Synchronous one-shot check that 127.0.0.1:port accepts connections."""
    import socket

    try:
        with socket.create_connection(("127.0.0.1", port), timeout=1.0):
            pass
        return True
    except OSError:
        return False


@dataclass
class TabRuntime:
    """A single type tab inside a browser."""

    type_name: str
    page: Page
    account_id: str
    # Number of in-flight requests currently using this tab.
    active_requests: int = 0
    # When False, the tab is draining and accepts no new requests.
    accepting_new: bool = True
    state: str = "ready"
    last_used_at: float = field(default_factory=time.time)
    frozen_until: int | None = None
    # Session ids currently hosted on this tab.
    sessions: set[str] = field(default_factory=set)


@dataclass
class BrowserEntry:
    """The browser runtime bound to a single ProxyKey."""

    proc: subprocess.Popen[Any]
    port: int
    browser: Browser
    context: BrowserContext
    stderr_path: Path | None = None
    # type name -> its single tab.
    tabs: dict[str, TabRuntime] = field(default_factory=dict)
    last_used_at: float = field(default_factory=time.time)
    proxy_forwarder: Any = None  # LocalProxyForwarder | None, set only when use_proxy


@dataclass
class ClosedTabInfo:
    """Session-cleanup info returned when a tab/browser is closed."""

    proxy_key: ProxyKey
    type_name: str
    account_id: str
    session_ids: list[str]


class BrowserManager:
    """Manage browsers per proxy group and their type -> tab mapping."""

    def __init__(
        self,
        chromium_bin: str = CHROMIUM_BIN,
        headless: bool = False,
        no_sandbox: bool = False,
        disable_gpu: bool = False,
        disable_gpu_sandbox: bool = False,
        port_range: list[int] | None = None,
        cdp_wait_max_attempts: int = 90,
        cdp_wait_interval_seconds: float = 2.0,
        cdp_wait_connect_timeout_seconds: float = 2.0,
    ) -> None:
        self._chromium_bin = chromium_bin
        self._headless = headless
        self._no_sandbox = no_sandbox
        self._disable_gpu = disable_gpu
        self._disable_gpu_sandbox = disable_gpu_sandbox
        self._port_range = port_range or list(CDP_PORT_RANGE)
        self._entries: dict[ProxyKey, BrowserEntry] = {}
        self._available_ports: set[int] = set(self._port_range)
        self._playwright: Any = None
        # Clamp the CDP wait knobs to sane minimums.
        self._cdp_wait_max_attempts = max(1, int(cdp_wait_max_attempts))
        self._cdp_wait_interval_seconds = max(0.05, float(cdp_wait_interval_seconds))
        self._cdp_wait_connect_timeout_seconds = max(
            0.2, float(cdp_wait_connect_timeout_seconds)
        )

    def _stderr_log_path(self, proxy_key: ProxyKey, port: int) -> Path:
        """Build a per-launch stderr log path under the system temp dir."""
        log_dir = Path(tempfile.gettempdir()) / "web2api-browser-logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        return log_dir / (
            f"{proxy_key.fingerprint_id}-{port}-{int(time.time())}.stderr.log"
        )

    @staticmethod
    def _read_stderr_tail(stderr_path: Path | None, max_chars: int = 4000) -> str:
        """Return the last ``max_chars`` of the stderr log; '' on any failure."""
        if stderr_path is None or not stderr_path.exists():
            return ""
        try:
            content = stderr_path.read_text(encoding="utf-8", errors="replace")
        except Exception:
            return ""
        content = content.strip()
        if not content:
            return ""
        return content[-max_chars:]

    @staticmethod
    def _cleanup_stderr_log(stderr_path: Path | None) -> None:
        """Best-effort removal of a stderr log file."""
        if stderr_path is None:
            return
        try:
            stderr_path.unlink(missing_ok=True)
        except Exception:
            pass

    def current_proxy_keys(self) -> list[ProxyKey]:
        """Proxy keys with a live browser entry."""
        return list(self._entries.keys())

    def browser_count(self) -> int:
        return len(self._entries)

    def list_browser_entries(self) -> list[tuple[ProxyKey, BrowserEntry]]:
        return list(self._entries.items())

    def get_browser_entry(self, proxy_key: ProxyKey) -> BrowserEntry | None:
        return self._entries.get(proxy_key)

    def get_tab(self, proxy_key: ProxyKey, type_name: str) -> TabRuntime | None:
        """Look up the tab for ``type_name`` in the given browser, if any."""
        entry = self._entries.get(proxy_key)
        if entry is None:
            return None
        return entry.tabs.get(type_name)

    def browser_load(self, proxy_key: ProxyKey) -> int:
        """Total active requests across all tabs of this browser (0 if absent)."""
        entry = self._entries.get(proxy_key)
        if entry is None:
            return 0
        return sum(tab.active_requests for tab in entry.tabs.values())

    def browser_diagnostics(self, proxy_key: ProxyKey) -> dict[str, Any]:
        """Snapshot of browser health: process, CDP port, tabs, stderr tail."""
        entry = self._entries.get(proxy_key)
        if entry is None:
            return {
                "browser_present": False,
                "proc_alive": False,
                "cdp_listening": False,
                "stderr_tail": "",
                "tab_count": 0,
                "active_requests": 0,
                "tabs": [],
            }
        tabs = [
            {
                "type": type_name,
                "state": tab.state,
                "accepting_new": tab.accepting_new,
                "active_requests": tab.active_requests,
                "session_count": len(tab.sessions),
            }
            for type_name, tab in entry.tabs.items()
        ]
        return {
            "browser_present": True,
            "proc_alive": entry.proc.poll() is None,
            "cdp_listening": _is_cdp_listening(entry.port),
            "stderr_tail": self._read_stderr_tail(entry.stderr_path),
            "tab_count": len(entry.tabs),
            "active_requests": sum(tab.active_requests for tab in entry.tabs.values()),
            "tabs": tabs,
        }

    def _raise_browser_resource_invalid(
        self,
        proxy_key: ProxyKey,
        *,
        detail: str,
        helper_name: str,
        stage: str,
        resource_hint: str = "browser",
        request_url: str = "",
        page_url: str = "",
        request_id: str | None = None,
        stream_phase: str | None = None,
        type_name: str | None = None,
        account_id: str | None = None,
    ) -> None:
        """Log diagnostics and raise BrowserResourceInvalidError.

        Always raises; the log line carries the current browser diagnostics
        so the failure can be triaged from logs alone.
        """
        diagnostics = self.browser_diagnostics(proxy_key)
        logger.warning(
            "[browser-resource-invalid] helper=%s stage=%s proxy=%s resource=%s request_id=%s type=%s account=%s proc_alive=%s cdp_listening=%s tab_count=%s active_requests=%s stderr_tail=%s detail=%s",
            helper_name,
            stage,
            proxy_key.fingerprint_id,
            resource_hint,
            request_id,
            type_name,
            account_id,
            diagnostics.get("proc_alive"),
            diagnostics.get("cdp_listening"),
            diagnostics.get("tab_count"),
            diagnostics.get("active_requests"),
            diagnostics.get("stderr_tail"),
            detail,
        )
        raise BrowserResourceInvalidError(
            detail,
            helper_name=helper_name,
            operation="browser_manager",
            stage=stage,
            resource_hint=resource_hint,
            request_url=request_url,
            page_url=page_url,
            request_id=request_id,
            stream_phase=stream_phase,
            proxy_key=proxy_key,
            type_name=type_name,
            account_id=account_id,
        )

    def touch_browser(self, proxy_key: ProxyKey) -> None:
        """Refresh the browser's last_used_at timestamp, if it exists."""
        entry = self._entries.get(proxy_key)
        if entry is not None:
            entry.last_used_at = time.time()
proxy_forwarder = None + if proxy_key.use_proxy: + from core.runtime.local_proxy_forwarder import ( + LocalProxyForwarder, + UpstreamProxy, + parse_proxy_server, + ) + + upstream_host, upstream_port = parse_proxy_server(proxy_key.proxy_host) + upstream = UpstreamProxy( + host=upstream_host, + port=upstream_port, + username=proxy_key.proxy_user, + password=proxy_pass, + ) + proxy_forwarder = LocalProxyForwarder( + upstream, + listen_host="127.0.0.1", + listen_port=0, + on_log=lambda msg: logger.debug("[proxy] %s", msg), + ) + proxy_forwarder.start() + args.append(f"--proxy-server={proxy_forwarder.proxy_url}") + if self._headless: + args.extend( + [ + "--headless=new", + "--window-size=1920,1080", + ] + ) + if self._headless or self._disable_gpu: + args.append("--disable-gpu") + if self._disable_gpu_sandbox: + args.append("--disable-gpu-sandbox") + if self._no_sandbox: + args.extend( + [ + "--no-sandbox", + "--disable-setuid-sandbox", + ] + ) + env = os.environ.copy() + env["NODE_OPTIONS"] = ( + env.get("NODE_OPTIONS") or "" + ).strip() + " --no-deprecation" + env.setdefault("DBUS_SESSION_BUS_ADDRESS", "/dev/null") + stderr_path = self._stderr_log_path(proxy_key, port) + stderr_fp = stderr_path.open("ab") + try: + proc = subprocess.Popen( + args, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=stderr_fp, + env=env, + ) + finally: + stderr_fp.close() + return proc, stderr_path, proxy_forwarder + + async def ensure_browser( + self, + proxy_key: ProxyKey, + proxy_pass: str, + ) -> BrowserContext: + """ + 确保存在对应 proxy_key 的浏览器;若已有且存活则直接复用。 + """ + entry = self._entries.get(proxy_key) + if entry is not None: + if entry.proc.poll() is not None or not _is_cdp_listening(entry.port): + await self._close_entry_async(proxy_key) + else: + entry.last_used_at = time.time() + return entry.context + + if not self._available_ports: + raise RuntimeError( + "无可用 CDP 端口,当前并发浏览器数已达上限,请稍后重试或增大 cdp_port_count" + ) + port = self._available_ports.pop() + proc, stderr_path, 
proxy_forwarder = self._launch_process( + proxy_key, proxy_pass, port + ) + logger.info( + "已启动 Chromium PID=%s port=%s mode=%s headless=%s no_sandbox=%s disable_gpu=%s disable_gpu_sandbox=%s,等待 CDP 就绪...", + proc.pid, + port, + "proxy" if proxy_key.use_proxy else "direct", + self._headless, + self._no_sandbox, + self._disable_gpu, + self._disable_gpu_sandbox, + ) + ok = await _wait_for_cdp( + "127.0.0.1", + port, + max_attempts=self._cdp_wait_max_attempts, + interval=self._cdp_wait_interval_seconds, + connect_timeout=self._cdp_wait_connect_timeout_seconds, + ) + if not ok: + self._available_ports.add(port) + if proxy_forwarder is not None: + try: + proxy_forwarder.stop() + except Exception: + pass + try: + proc.terminate() + proc.wait(timeout=5) + except Exception: + pass + stderr_tail = self._read_stderr_tail(stderr_path) + self._cleanup_stderr_log(stderr_path) + if stderr_tail: + logger.error( + "Chromium 启动失败,CDP 未就绪。stderr tail:\n%s", + stderr_tail, + ) + raise RuntimeError("CDP 未在预期时间内就绪") + + if self._playwright is None: + self._playwright = await async_playwright().start() + endpoint = f"http://127.0.0.1:{port}" + try: + browser = await self._playwright.chromium.connect_over_cdp( + endpoint, timeout=10000 + ) + except Exception: + self._available_ports.add(port) + if proxy_forwarder is not None: + try: + proxy_forwarder.stop() + except Exception: + pass + try: + proc.terminate() + proc.wait(timeout=5) + except Exception: + pass + stderr_tail = self._read_stderr_tail(stderr_path) + self._cleanup_stderr_log(stderr_path) + if stderr_tail: + logger.error( + "Chromium 已监听 CDP 但 connect_over_cdp 失败。stderr tail:\n%s", + stderr_tail, + ) + raise + context = browser.contexts[0] if browser.contexts else None + if context is None: + await browser.close() + self._available_ports.add(port) + if proxy_forwarder is not None: + try: + proxy_forwarder.stop() + except Exception: + pass + try: + proc.terminate() + proc.wait(timeout=5) + except Exception: + pass + 
self._cleanup_stderr_log(stderr_path) + raise RuntimeError("浏览器无默认 context") + self._entries[proxy_key] = BrowserEntry( + proc=proc, + port=port, + browser=browser, + context=context, + stderr_path=stderr_path, + proxy_forwarder=proxy_forwarder, + ) + return context + + async def open_tab( + self, + proxy_key: ProxyKey, + proxy_pass: str, + type_name: str, + account_id: str, + create_page_fn: CreatePageFn, + apply_auth_fn: ApplyAuthFn, + ) -> TabRuntime: + """在指定浏览器中创建一个 type tab,并绑定到 account。""" + context = await self.ensure_browser(proxy_key, proxy_pass) + entry = self._entries.get(proxy_key) + if entry is None: + raise RuntimeError("ensure_browser 未创建 entry") + existing = entry.tabs.get(type_name) + if existing is not None: + return existing + + logger.info( + "[tab] opening proxy=%s type=%s account=%s reuse_blank=%s tab_count=%s active_requests=%s", + proxy_key.fingerprint_id, + type_name, + account_id, + bool(len(entry.tabs) == 0 and context.pages), + len(entry.tabs), + sum(tab.active_requests for tab in entry.tabs.values()), + ) + # 首个 tab 时复用 Chromium 默认空白页,避免多一个无用标签 + reuse_page = ( + context.pages[0] if (len(entry.tabs) == 0 and context.pages) else None + ) + try: + page = await create_page_fn(context, reuse_page) + except Exception as e: + msg = str(e) + normalized = msg.lower() + if "target.createtarget" in normalized or "failed to open a new tab" in normalized: + self._raise_browser_resource_invalid( + proxy_key, + detail=msg, + helper_name="open_tab", + stage="create_page", + resource_hint="browser", + type_name=type_name, + account_id=account_id, + ) + raise + try: + await apply_auth_fn(context, page) + except Exception as e: + try: + await page.close() + except Exception: + pass + msg = str(e) + normalized = msg.lower() + if ( + "target crashed" in normalized + or "page has been closed" in normalized + or "browser has been closed" in normalized + or "has been disconnected" in normalized + ): + self._raise_browser_resource_invalid( + proxy_key, + 
detail=msg, + helper_name="open_tab", + stage="apply_auth", + resource_hint="browser", + page_url=getattr(page, "url", "") or "", + type_name=type_name, + account_id=account_id, + ) + raise + + tab = TabRuntime( + type_name=type_name, + page=page, + account_id=account_id, + ) + entry.tabs[type_name] = tab + entry.last_used_at = time.time() + logger.info( + "[tab] opened mode=%s proxy=%s type=%s account=%s", + "proxy" if proxy_key.use_proxy else "direct", + proxy_key.fingerprint_id, + type_name, + account_id, + ) + return tab + + async def switch_tab_account( + self, + proxy_key: ProxyKey, + type_name: str, + account_id: str, + apply_auth_fn: ApplyAuthFn, + ) -> bool: + """ + 在同一个 page 上切换账号。只有 drained 后(active_requests==0)才允许切号。 + """ + entry = self._entries.get(proxy_key) + if entry is None: + return False + tab = entry.tabs.get(type_name) + if tab is None or tab.active_requests != 0: + return False + + tab.accepting_new = False + tab.state = "switching" + try: + await apply_auth_fn(entry.context, tab.page) + except Exception: + tab.state = "draining" + return False + + tab.account_id = account_id + tab.accepting_new = True + tab.state = "ready" + tab.frozen_until = None + tab.last_used_at = time.time() + tab.sessions.clear() + entry.last_used_at = time.time() + logger.info( + "[tab] switched account mode=%s proxy=%s type=%s account=%s", + "proxy" if proxy_key.use_proxy else "direct", + proxy_key.fingerprint_id, + type_name, + account_id, + ) + return True + + def acquire_tab( + self, + proxy_key: ProxyKey, + type_name: str, + max_concurrent: int, + ) -> Page | None: + """ + 为一次请求占用 tab;tab 必须存在、可接新请求且未达到并发上限。 + """ + entry = self._entries.get(proxy_key) + if entry is None: + return None + tab = entry.tabs.get(type_name) + if tab is None: + return None + if not tab.accepting_new or tab.active_requests >= max_concurrent: + return None + tab.active_requests += 1 + tab.last_used_at = time.time() + entry.last_used_at = tab.last_used_at + tab.state = "busy" + return 
tab.page + + def release_tab(self, proxy_key: ProxyKey, type_name: str) -> None: + """释放一次请求占用。""" + entry = self._entries.get(proxy_key) + if entry is None: + return + tab = entry.tabs.get(type_name) + if tab is None: + return + if tab.active_requests > 0: + tab.active_requests -= 1 + tab.last_used_at = time.time() + entry.last_used_at = tab.last_used_at + if tab.active_requests == 0: + if tab.accepting_new: + tab.state = "ready" + elif tab.frozen_until is not None: + tab.state = "frozen" + else: + tab.state = "draining" + + def mark_tab_draining( + self, + proxy_key: ProxyKey, + type_name: str, + *, + frozen_until: int | None = None, + ) -> None: + """禁止 tab 接受新请求,并标记为 draining/frozen。""" + entry = self._entries.get(proxy_key) + if entry is None: + return + tab = entry.tabs.get(type_name) + if tab is None: + return + tab.accepting_new = False + tab.frozen_until = frozen_until + tab.last_used_at = time.time() + entry.last_used_at = tab.last_used_at + if frozen_until is not None: + tab.state = "frozen" + else: + tab.state = "draining" + + def register_session( + self, + proxy_key: ProxyKey, + type_name: str, + session_id: str, + ) -> None: + entry = self._entries.get(proxy_key) + if entry is None: + return + tab = entry.tabs.get(type_name) + if tab is None: + return + tab.sessions.add(session_id) + tab.last_used_at = time.time() + entry.last_used_at = tab.last_used_at + + def unregister_session( + self, + proxy_key: ProxyKey, + type_name: str, + session_id: str, + ) -> None: + entry = self._entries.get(proxy_key) + if entry is None: + return + tab = entry.tabs.get(type_name) + if tab is None: + return + tab.sessions.discard(session_id) + + async def close_tab( + self, + proxy_key: ProxyKey, + type_name: str, + ) -> ClosedTabInfo | None: + """关闭某个 type 的 tab,并返回需要失效的 session 列表。""" + entry = self._entries.get(proxy_key) + if entry is None: + return None + tab = entry.tabs.pop(type_name, None) + if tab is None: + return None + try: + await tab.page.close() + except 
Exception: + pass + entry.last_used_at = time.time() + logger.info( + "[tab] closed mode=%s proxy=%s type=%s", + "proxy" if proxy_key.use_proxy else "direct", + proxy_key.fingerprint_id, + type_name, + ) + return ClosedTabInfo( + proxy_key=proxy_key, + type_name=type_name, + account_id=tab.account_id, + session_ids=list(tab.sessions), + ) + + async def close_browser(self, proxy_key: ProxyKey) -> list[ClosedTabInfo]: + return await self._close_entry_async(proxy_key) + + async def _close_entry_async(self, proxy_key: ProxyKey) -> list[ClosedTabInfo]: + entry = self._entries.get(proxy_key) + if entry is None: + return [] + + closed_tabs = [ + ClosedTabInfo( + proxy_key=proxy_key, + type_name=type_name, + account_id=tab.account_id, + session_ids=list(tab.sessions), + ) + for type_name, tab in entry.tabs.items() + ] + for tab in list(entry.tabs.values()): + try: + await tab.page.close() + except Exception: + pass + entry.tabs.clear() + if entry.proxy_forwarder is not None: + try: + entry.proxy_forwarder.stop() + except Exception as e: + logger.warning("关闭本地代理转发时异常: %s", e) + if entry.browser is not None: + try: + await entry.browser.close() + except Exception as e: + logger.warning("关闭 CDP 浏览器时异常: %s", e) + try: + entry.proc.terminate() + entry.proc.wait(timeout=8) + except subprocess.TimeoutExpired: + entry.proc.kill() + entry.proc.wait(timeout=3) + except Exception as e: + logger.warning("关闭浏览器进程时异常: %s", e) + self._cleanup_stderr_log(entry.stderr_path) + self._available_ports.add(entry.port) + del self._entries[proxy_key] + logger.info( + "[browser] closed mode=%s proxy=%s", + "proxy" if proxy_key.use_proxy else "direct", + proxy_key.fingerprint_id, + ) + return closed_tabs + + async def collect_idle_browsers( + self, + *, + idle_seconds: float, + resident_browser_count: int, + ) -> list[ClosedTabInfo]: + """ + 关闭空闲浏览器: + + - 浏览器下所有 tab 都没有活跃请求 + - 所有 tab 均已空闲超过 idle_seconds + - 当前浏览器数 > resident_browser_count + """ + if len(self._entries) <= resident_browser_count: + 
return [] + + now = time.time() + candidates: list[tuple[float, ProxyKey]] = [] + for proxy_key, entry in self._entries.items(): + if any(tab.active_requests > 0 for tab in entry.tabs.values()): + continue + if entry.tabs: + last_tab_used = max(tab.last_used_at for tab in entry.tabs.values()) + else: + last_tab_used = entry.last_used_at + if now - last_tab_used < idle_seconds: + continue + candidates.append((last_tab_used, proxy_key)) + + if not candidates: + return [] + + closed: list[ClosedTabInfo] = [] + max_close = max(0, len(self._entries) - resident_browser_count) + for _, proxy_key in sorted(candidates, key=lambda item: item[0])[:max_close]: + closed.extend(await self._close_entry_async(proxy_key)) + return closed + + async def close_all(self) -> list[ClosedTabInfo]: + """关闭全部浏览器和 tab。""" + closed: list[ClosedTabInfo] = [] + for proxy_key in list(self._entries.keys()): + closed.extend(await self._close_entry_async(proxy_key)) + return closed diff --git a/core/runtime/conversation_index.py b/core/runtime/conversation_index.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed1452b657f83425483758da5e09bc4f0deef09 --- /dev/null +++ b/core/runtime/conversation_index.py @@ -0,0 +1,72 @@ +"""会话指纹索引:替代 sticky session,通过指纹精确匹配同一逻辑对话。 + +指纹 = sha256(system_prompt + first_user_message)[:16], +同一对话的指纹恒定,不同对话指纹不同,杜绝上下文污染。 +""" + +import time +from dataclasses import dataclass, field + + +@dataclass +class ConversationEntry: + session_id: str + fingerprint: str + message_count: int + account_id: str + created_at: float = field(default_factory=time.time) + last_used_at: float = field(default_factory=time.time) + + +class ConversationIndex: + """进程内指纹索引,不持久化。""" + + def __init__(self) -> None: + self._by_fingerprint: dict[str, ConversationEntry] = {} + self._by_session_id: dict[str, ConversationEntry] = {} + + def register( + self, + fingerprint: str, + session_id: str, + message_count: int, + account_id: str, + ) -> None: + # Remove old entry for 
this fingerprint if exists + old = self._by_fingerprint.pop(fingerprint, None) + if old is not None: + self._by_session_id.pop(old.session_id, None) + entry = ConversationEntry( + session_id=session_id, + fingerprint=fingerprint, + message_count=message_count, + account_id=account_id, + ) + self._by_fingerprint[fingerprint] = entry + self._by_session_id[session_id] = entry + + def lookup(self, fingerprint: str) -> ConversationEntry | None: + entry = self._by_fingerprint.get(fingerprint) + if entry is not None: + entry.last_used_at = time.time() + return entry + + def remove_session(self, session_id: str) -> None: + entry = self._by_session_id.pop(session_id, None) + if entry is not None: + self._by_fingerprint.pop(entry.fingerprint, None) + + def evict_stale(self, ttl: float) -> list[str]: + """Remove entries older than *ttl* seconds. Returns evicted session IDs.""" + now = time.time() + stale = [ + e.session_id + for e in self._by_fingerprint.values() + if (now - e.last_used_at) > ttl + ] + for sid in stale: + self.remove_session(sid) + return stale + + def __len__(self) -> int: + return len(self._by_fingerprint) diff --git a/core/runtime/keys.py b/core/runtime/keys.py new file mode 100644 index 0000000000000000000000000000000000000000..2928aca5304cad4a05f1f3a6d9414122d2aab846 --- /dev/null +++ b/core/runtime/keys.py @@ -0,0 +1,15 @@ +"""运行时键类型:代理组唯一标识。""" + +from typing import NamedTuple + +from core.constants import TIMEZONE + + +class ProxyKey(NamedTuple): + """唯一标识一个代理组(一个浏览器进程)。""" + + proxy_host: str + proxy_user: str + fingerprint_id: str + use_proxy: bool = True + timezone: str = TIMEZONE diff --git a/core/runtime/local_proxy_forwarder.py b/core/runtime/local_proxy_forwarder.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8c4953d6b5e26b6dfb3a03170cc95576c8877e --- /dev/null +++ b/core/runtime/local_proxy_forwarder.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +本地中转代理(forward proxy)。 + +用途: +- 
浏览器只配置无鉴权的本地代理:127.0.0.1: +- 本地代理再转发到“带用户名密码鉴权”的上游代理(HTTP proxy) + +实现重点: +- 支持 CONNECT(HTTPS 隧道)——浏览器最常见的代理用法 +- 兼容少量 HTTP 明文请求(GET http://... 这种 absolute-form) +""" + +from __future__ import annotations + +import base64 +import contextlib +import select +import socket +import socketserver +import threading +from dataclasses import dataclass +from typing import Callable, Optional +from urllib.parse import urlparse + + +def _basic_proxy_auth(username: str, password: str) -> str: + raw = f"{username}:{password}".encode("utf-8") + return "Basic " + base64.b64encode(raw).decode("ascii") + + +def _recv_until( + sock: socket.socket, marker: bytes, max_bytes: int = 256 * 1024 +) -> bytes: + data = bytearray() + while marker not in data: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + if len(data) > max_bytes: + break + return bytes(data) + + +def _split_headers(data: bytes) -> tuple[bytes, bytes]: + idx = data.find(b"\r\n\r\n") + if idx < 0: + return data, b"" + return data[: idx + 4], data[idx + 4 :] + + +def _parse_first_line(header_bytes: bytes) -> tuple[str, str, str]: + # e.g. 
"CONNECT example.com:443 HTTP/1.1" + first = header_bytes.split(b"\r\n", 1)[0].decode("latin-1", errors="replace") + parts = first.strip().split() + if len(parts) >= 3: + return parts[0].upper(), parts[1], parts[2] + if len(parts) == 2: + return parts[0].upper(), parts[1], "HTTP/1.1" + return "GET", "/", "HTTP/1.1" + + +def _remove_hop_by_hop_headers(header_bytes: bytes) -> bytes: + # 仅做最小处理:去掉 Proxy-Authorization / Proxy-Connection,避免重复/冲突 + lines = header_bytes.split(b"\r\n") + if not lines: + return header_bytes + out = [lines[0]] + for line in lines[1:]: + lower = line.lower() + if lower.startswith(b"proxy-authorization:"): + continue + if lower.startswith(b"proxy-connection:"): + continue + out.append(line) + return b"\r\n".join(out) + + +def _relay_bidi(a: socket.socket, b: socket.socket, stop_evt: threading.Event) -> None: + a.setblocking(False) + b.setblocking(False) + socks = [a, b] + try: + while not stop_evt.is_set(): + r, _, _ = select.select(socks, [], [], 0.5) + if not r: + continue + for s in r: + try: + data = s.recv(65536) + except BlockingIOError: + continue + if not data: + stop_evt.set() + break + other = b if s is a else a + try: + other.sendall(data) + except OSError: + stop_evt.set() + break + finally: + with contextlib.suppress(Exception): + a.close() + with contextlib.suppress(Exception): + b.close() + + +class _ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + allow_reuse_address = True + daemon_threads = True + + +@dataclass(frozen=True) +class UpstreamProxy: + host: str + port: int + username: str + password: str + + @property + def auth_header_value(self) -> str: + return _basic_proxy_auth(self.username, self.password) + + +def parse_proxy_server(proxy_server: str) -> tuple[str, int]: + """ + 支持: + - http://host:port + - host:port + """ + s = (proxy_server or "").strip() + if not s: + raise ValueError("proxy_server 为空") + if "://" not in s: + s = "http://" + s + u = urlparse(s) + if not u.hostname or not u.port: 
+ raise ValueError(f"无法解析 proxy_server: {proxy_server!r}") + return u.hostname, int(u.port) + + +class LocalProxyForwarder: + """ + 启动一个本地 HTTP 代理,并把请求/隧道转发到上游代理(带 Basic 鉴权)。 + """ + + def __init__( + self, + upstream: UpstreamProxy, + *, + listen_host: str = "127.0.0.1", + listen_port: int = 0, + on_log: Optional[Callable[[str], None]] = None, + ) -> None: + self._upstream = upstream + self._listen_host = listen_host + self._listen_port = listen_port + self._on_log = on_log + + self._server: _ThreadingTCPServer | None = None + self._thread: threading.Thread | None = None + + @property + def port(self) -> int: + if not self._server: + raise RuntimeError("forwarder 尚未启动") + return int(self._server.server_address[1]) + + @property + def proxy_url(self) -> str: + return f"http://{self._listen_host}:{self.port}" + + def _log(self, msg: str) -> None: + if self._on_log: + try: + self._on_log(msg) + except Exception: + pass + + def start(self) -> "LocalProxyForwarder": + if self._server is not None: + return self + + upstream = self._upstream + parent = self + + class Handler(socketserver.BaseRequestHandler): + def handle(self) -> None: + client = self.request + try: + data = _recv_until(client, b"\r\n\r\n") + if not data: + return + header, rest = _split_headers(data) + method, target, _ver = _parse_first_line(header) + + upstream_sock = socket.create_connection( + (upstream.host, upstream.port), timeout=15 + ) + upstream_sock.settimeout(20) + + if method == "CONNECT": + # 通过上游代理建立到 target 的隧道 + connect_req = ( + f"CONNECT {target} HTTP/1.1\r\n" + f"Host: {target}\r\n" + f"Proxy-Authorization: {upstream.auth_header_value}\r\n" + f"Proxy-Connection: keep-alive\r\n" + f"Connection: keep-alive\r\n" + f"\r\n" + ).encode("latin-1", errors="ignore") + upstream_sock.sendall(connect_req) + upstream_resp = _recv_until(upstream_sock, b"\r\n\r\n") + if not upstream_resp: + client.sendall(b"HTTP/1.1 502 Bad Gateway\r\n\r\n") + return + # 将上游响应直接回给浏览器(一般是 200 Connection Established) 
+ client.sendall(upstream_resp) + + # CONNECT 时,header 后可能不会有 body;但如果有残留,丢给上游 + if rest: + upstream_sock.sendall(rest) + + stop_evt = threading.Event() + _relay_bidi(client, upstream_sock, stop_evt) + return + + # 非 CONNECT:把请求转发给上游代理(absolute-form 请求) + # 注:这里只做最小实现,主要为兼容偶发 http:// 明文请求 + filtered = _remove_hop_by_hop_headers(header) + # 插入 Proxy-Authorization + parts = filtered.split(b"\r\n") + out_lines = [parts[0]] + inserted = False + for line in parts[1:]: + if not inserted and line == b"": + out_lines.append( + f"Proxy-Authorization: {upstream.auth_header_value}".encode( + "latin-1", errors="ignore" + ) + ) + inserted = True + out_lines.append(line) + new_header = b"\r\n".join(out_lines) + upstream_sock.sendall(new_header) + if rest: + upstream_sock.sendall(rest) + + # 单向把响应回写给客户端直到连接关闭 + while True: + chunk = upstream_sock.recv(65536) + if not chunk: + break + client.sendall(chunk) + except Exception as e: + parent._log(f"[proxy] handler error: {e}") + with contextlib.suppress(Exception): + client.sendall(b"HTTP/1.1 502 Bad Gateway\r\n\r\n") + finally: + with contextlib.suppress(Exception): + client.close() + + self._server = _ThreadingTCPServer( + (self._listen_host, self._listen_port), Handler + ) + self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) + self._thread.start() + return self + + def stop(self) -> None: + if self._server is None: + return + with contextlib.suppress(Exception): + self._server.shutdown() + with contextlib.suppress(Exception): + self._server.server_close() + self._server = None + self._thread = None + + def __enter__(self) -> "LocalProxyForwarder": + return self.start() + + def __exit__(self, exc_type, exc, tb) -> None: + self.stop() diff --git a/core/runtime/session_cache.py b/core/runtime/session_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..435ef1b01a3fd605f2721ad67c13dd98a42e7ed3 --- /dev/null +++ b/core/runtime/session_cache.py @@ -0,0 +1,81 @@ +""" +会话缓存:session_id 
全局唯一,映射到 (proxy_key, type, account_id)。 + +当前架构下 session 绑定到某个 tab/account: + +- tab 被关闭或切号时,需要批量失效该 tab 下的 session +- 单个 session 失效时,需要从缓存中移除,后续按完整历史重建 +- 超过 TTL 的 session 在维护循环中被自动清理 +""" + +from dataclasses import dataclass +import time + +from core.runtime.keys import ProxyKey + +# Sessions older than this are eligible for eviction during maintenance. +SESSION_TTL_SECONDS = 1800.0 # 30 minutes + + +@dataclass +class SessionEntry: + """单条会话:用于通过 session_id 反查 context/page 与账号。""" + + proxy_key: ProxyKey + type_name: str + account_id: str + last_used_at: float + + +class SessionCache: + """进程内会话缓存,不持久化、不跨进程。""" + + def __init__(self) -> None: + self._store: dict[str, SessionEntry] = {} + + def get(self, session_id: str) -> SessionEntry | None: + return self._store.get(session_id) + + def put( + self, + session_id: str, + proxy_key: ProxyKey, + type_name: str, + account_id: str, + ) -> None: + self._store[session_id] = SessionEntry( + proxy_key=proxy_key, + type_name=type_name, + account_id=account_id, + last_used_at=time.time(), + ) + + def touch(self, session_id: str) -> None: + entry = self._store.get(session_id) + if entry is not None: + entry.last_used_at = time.time() + + def delete(self, session_id: str) -> None: + self._store.pop(session_id, None) + + def delete_many(self, session_ids: list[str] | set[str]) -> None: + for session_id in session_ids: + self._store.pop(session_id, None) + + def evict_stale(self, ttl: float = SESSION_TTL_SECONDS) -> list[str]: + """Remove sessions older than *ttl* seconds. 
Returns evicted IDs.""" + now = time.time() + stale = [ + sid + for sid, entry in self._store.items() + if (now - entry.last_used_at) > ttl + ] + for sid in stale: + del self._store[sid] + return stale + + def __contains__(self, session_id: str) -> bool: + return session_id in self._store + + def __len__(self) -> int: + return len(self._store) diff --git a/core/static/config.html b/core/static/config.html new file mode 100644 index 0000000000000000000000000000000000000000..64fe860728e92fc0e19aaf84b934d8fad52208cd --- /dev/null +++ b/core/static/config.html @@ -0,0 +1,1698 @@ + + + + + + Web2API configuration + + + +
+ +
+
+
+
Admin dashboard
+

Web2API configuration

+

+ Manage proxy groups, account auth JSON, global API keys, the admin password, and the + public model mapping used by this bridge. +

+
+
+ + + + +
+
+ +
+
+
+
+
+

Global auth settings

+
+

+ Database-backed values persist across restarts and take precedence once saved here. + Environment variables are used as the initial fallback when the database has no value. +

+
+
+
API key source
+
Admin password source
+
+
+
+ + +

+
+
+ + +

+
+
+ +
Loading…
+

Saving a new password signs out the current dashboard session.

+
+
+
+ + +
+
+ +
+
+
+

Supported models

+
+

+ These public IDs are exposed to clients and resolved to the upstream Claude model IDs + shown below. +

+
+
+ + + Loading… +
+
+
Loading model metadata…
+
+
+
+ +
+
+
+ + +
+
+
+
+ +
+
+ Network mode: When Use proxy is enabled, the browser goes out + through the configured proxy. When disabled, the browser uses this machine’s own exit IP. + Avoid packing many accounts into one direct-connection group, or the upstream site may + link them together by IP. +
+
+
+
+
+ + + + diff --git a/core/static/index.html b/core/static/index.html new file mode 100644 index 0000000000000000000000000000000000000000..693752ee130f993755e83aa4d0d1ab194c333e41 --- /dev/null +++ b/core/static/index.html @@ -0,0 +1,474 @@ + + + + + + Web2API + + + +
+
+
Hosted bridge
+

Claude Web accounts, exposed as clean API routes.

+

+ Web2API turns browser-authenticated Claude sessions into OpenAI-compatible and + Anthropic-compatible endpoints, with a compact admin dashboard for proxy groups, + account auth, runtime status, and persistent global auth settings. +

+ +
+
+ Provider + claude +
+
+ Default model + Loading… +
+
+ Config dashboard + Checking… +
+
+
+ +
+
+
+
+

Supported models

+
+

+ Public model IDs are accepted on both OpenAI-compatible and Anthropic-compatible + routes. The cards below show the exact public → upstream mapping used by the server. +

+
+
+
Loading supported models…
+
+
+ +
+
+
+

Quick start

+
+

+ Point your client to the OpenAI-compatible route, then use one of the public model IDs + from the list above. +

+
+
+
curl https://YOUR_HOST/openai/claude/v1/chat/completions \
+  -H "Authorization: Bearer $WEB2API_AUTH_API_KEY" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "claude-sonnet-4.6",
+    "messages": [
+      {"role": "user", "content": "Hello from Web2API"}
+    ]
+  }'
+
+
    +
  • POST /openai/claude/v1/chat/completions
  • +
  • POST /claude/v1/chat/completions
  • +
  • POST /anthropic/claude/v1/messages
  • +
  • GET /api/models/claude/metadata
  • +
  • GET /healthz
  • +
+

+ Use /config after signing in to edit proxy groups, account JSON auth, API + keys, and the admin password. +

+
+
+
+
+
+ + + + diff --git a/core/static/login.html b/core/static/login.html new file mode 100644 index 0000000000000000000000000000000000000000..c6ea7ce3eab6d56eec910ad587e44cf18f32da22 --- /dev/null +++ b/core/static/login.html @@ -0,0 +1,255 @@ + + + + + + Admin sign in + + + +
+
Admin access
+

Sign in to the config dashboard.

+

+ Use the current admin password for Web2API. If the password is managed by environment + variables, this page still accepts that live value. +

+ + + + +
+ + Back home +
+ +

The dashboard session is stored in an HTTP-only cookie after sign-in.

+
+
+ + + + diff --git a/docker/config.container.yaml b/docker/config.container.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fd321dae094cba6aa9fe72991e88e8f27cd9b0a --- /dev/null +++ b/docker/config.container.yaml @@ -0,0 +1,72 @@ +# Server configuration +server: + # Host address to bind the server (0.0.0.0 listens on all interfaces) + host: '0.0.0.0' + # Port the server listens on. In hosted environments, PORT / WEB2API_SERVER_PORT can override this. + port: 9000 + +auth: + # Leave empty to disable API key authentication. + # If set, all /{type}/v1/* requests must include a matching key. + # Supports a single string, comma-separated string, or YAML list. + api_key: '' + + # Leave empty to disable the config page. + # File-backed mode writes back a hashed secret; env-override mode hashes it in memory only. + config_secret: '' + # Maximum number of failed login attempts before locking + config_login_max_failures: 5 + # Duration in seconds to lock the config page after too many failed attempts + config_login_lock_seconds: 600 + +browser: + # Fixed path to fingerprint-chromium inside the container + chromium_bin: '/opt/fingerprint-chromium/chrome' + + # Headless mode is not recommended for Claude; Xvfb virtual display is used by default + headless: false + + # Disable sandbox inside containers to prevent Chromium startup failures due to permission restrictions + no_sandbox: true + + # GPU is typically unavailable in container environments; explicitly disable it + disable_gpu: true + + # GPU sandbox can also cause issues in some container environments + disable_gpu_sandbox: true + + # Starting port number for CDP (Chrome DevTools Protocol) connections + cdp_port_start: 9222 + # Number of CDP ports available + cdp_port_count: 20 + + # CDP readiness wait settings (tune for slow/cold starts in Docker) + # Total wait time ≈ cdp_wait_max_attempts * cdp_wait_interval_seconds + cdp_wait_max_attempts: 60 + cdp_wait_interval_seconds: 2.0 + # 
Per-attempt TCP connect timeout when probing the CDP port + cdp_wait_connect_timeout_seconds: 2.0 + +scheduler: + # Maximum number of tabs allowed to run concurrently + tab_max_concurrent: 1 + # Interval in seconds for browser garbage collection + browser_gc_interval_seconds: 300 + # Seconds of inactivity before a tab is considered idle + tab_idle_seconds: 900 + # Number of browser instances to keep resident (pre-warmed) + resident_browser_count: 1 + +claude: + # URL to open when starting Claude (leave empty for default) + start_url: '' + # Custom API base URL (leave empty to use the default endpoint) + api_base: '' + # Model name mapping: public model id -> Claude upstream model id + # Defaults are defined in ClaudePlugin.DEFAULT_MODEL_MAPPING. + # Override or extend here as needed. + model_mapping: {} + +mock: + # Port for the mock server + port: 8002 diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..dfe8c2659e87537faf42cc729e73484ea0fbbc27 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +set -euo pipefail + +DATA_DIR="${WEB2API_DATA_DIR:-/data}" +CONFIG_PATH="${WEB2API_CONFIG_PATH:-${DATA_DIR}/config.yaml}" +DB_PATH="${WEB2API_DB_PATH:-${DATA_DIR}/db.sqlite3}" +XVFB_ARGS="${XVFB_SCREEN_ARGS:--screen 0 1600x900x24}" +DISPLAY_NUM="${XVFB_DISPLAY_NUM:-99}" +DISPLAY_VALUE=":${DISPLAY_NUM}" + +mkdir -p "${DATA_DIR}" + +export HOME="${DATA_DIR}" +export WEB2API_CONFIG_PATH="${CONFIG_PATH}" +export WEB2API_DB_PATH="${DB_PATH}" +export PYTHONUNBUFFERED=1 + +# 清理残留的浏览器 profile,避免 Singleton* 锁导致 Chromium 认为 profile 正在被使用。 +rm -rf "${HOME}/fp-data" + +if [[ ! 
-f "${CONFIG_PATH}" ]]; then + cp /app/docker/config.container.yaml "${CONFIG_PATH}" +fi + +mkdir -p "${HOME}/fp-data" + +if [[ $# -gt 0 ]]; then + exec "$@" +fi + +cleanup() { + if [[ -n "${XVFB_PID:-}" ]]; then + kill "${XVFB_PID}" >/dev/null 2>&1 || true + fi +} + +trap cleanup EXIT INT TERM + +mkdir -p /tmp/.X11-unix +rm -f "/tmp/.X${DISPLAY_NUM}-lock" + +Xvfb "${DISPLAY_VALUE}" ${XVFB_ARGS} -nolisten tcp -ac & +XVFB_PID=$! + +for _ in $(seq 1 100); do + if [[ -S "/tmp/.X11-unix/X${DISPLAY_NUM}" ]]; then + break + fi + sleep 0.1 +done + +if [[ ! -S "/tmp/.X11-unix/X${DISPLAY_NUM}" ]]; then + echo "Xvfb failed to create display ${DISPLAY_VALUE}" >&2 + exit 1 +fi + +export DISPLAY="${DISPLAY_VALUE}" + +exec python -u /app/main.py diff --git a/docs/deployment.md b/docs/deployment.md new file mode 100644 index 0000000000000000000000000000000000000000..3f1d495e4f932540cf75e1f7425b1e25c9fe802c --- /dev/null +++ b/docs/deployment.md @@ -0,0 +1,380 @@ +# Web2API 部署指南 + +本文档介绍如何将 Web2API 部署到本地 Ubuntu、远程 VPS 或 Render 平台。 + +--- + +## 目录 + +- [系统要求](#系统要求) +- [方式一:Docker 部署(推荐)](#方式一docker-部署推荐) +- [方式二:Docker Compose 部署](#方式二docker-compose-部署) +- [方式三:Ubuntu 裸机部署](#方式三ubuntu-裸机部署) +- [方式四:Render 部署](#方式四render-部署) +- [环境变量参考](#环境变量参考) +- [部署后配置](#部署后配置) +- [常见问题](#常见问题) + +--- + +## 系统要求 + +| 项目 | 最低要求 | 推荐配置 | +|------|---------|---------| +| CPU | 1 核 | 2 核+ | +| 内存 | 1 GB | 2 GB+ | +| 磁盘 | 2 GB | 5 GB+ | +| 系统 | Ubuntu 22.04+ / Debian 12+ | Ubuntu 24.04 | +| Python | 3.12+ | 3.12 | +| 架构 | amd64 / arm64 | amd64 | + +--- + +## 方式一:Docker 部署(推荐) + +最简单的部署方式,适用于本地 Ubuntu 和远程 VPS。 + +### 1. 安装 Docker + +```bash +curl -fsSL https://get.docker.com | sh +sudo usermod -aG docker $USER +# 重新登录使 docker 组生效 +``` + +### 2. 构建镜像 + +```bash +git clone https://github.com/shenhao-stu/web2api.git +cd web2api +git checkout feat/huggingface-postgres-space + +docker build -t web2api . +``` + +### 3. 
运行容器 + +```bash +docker run -d \ + --name web2api \ + -p 9000:9000 \ + -v web2api-data:/data \ + -e WEB2API_AUTH_API_KEY="your-api-key-here" \ + -e WEB2API_AUTH_CONFIG_SECRET="your-admin-password" \ + -e WEB2API_BROWSER_NO_SANDBOX=true \ + -e WEB2API_BROWSER_DISABLE_GPU=true \ + -e WEB2API_BROWSER_DISABLE_GPU_SANDBOX=true \ + web2api +``` + +### 4. 验证 + +```bash +# 检查服务状态 +curl http://localhost:9000/claude/v1/models \ + -H "Authorization: Bearer your-api-key-here" + +# 测试对话 +curl http://localhost:9000/claude/v1/chat/completions \ + -H "Authorization: Bearer your-api-key-here" \ + -H "Content-Type: application/json" \ + -d '{"model":"claude-sonnet-4.6","stream":false,"messages":[{"role":"user","content":"Hello"}]}' +``` + +--- + +## 方式二:Docker Compose 部署 + +适合需要 PostgreSQL 持久化配置的场景。 + +创建 `docker-compose.yml`: + +```yaml +services: + web2api: + build: . + ports: + - "9000:9000" + volumes: + - web2api-data:/data + environment: + - WEB2API_AUTH_API_KEY=your-api-key-here + - WEB2API_AUTH_CONFIG_SECRET=your-admin-password + - WEB2API_BROWSER_NO_SANDBOX=true + - WEB2API_BROWSER_DISABLE_GPU=true + - WEB2API_BROWSER_DISABLE_GPU_SANDBOX=true + - WEB2API_DATABASE_URL=postgresql://web2api:web2api@db:5432/web2api + depends_on: + db: + condition: service_healthy + restart: unless-stopped + + db: + image: postgres:16-alpine + environment: + POSTGRES_USER: web2api + POSTGRES_PASSWORD: web2api + POSTGRES_DB: web2api + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U web2api"] + interval: 5s + timeout: 3s + retries: 5 + restart: unless-stopped + +volumes: + web2api-data: + pgdata: +``` + +```bash +docker compose up -d +``` + +--- + +## 方式三:Ubuntu 裸机部署 + +适合不想用 Docker 的场景,或需要更精细控制的 VPS。 + +### 1. 
安装系统依赖 + +```bash +sudo apt-get update && sudo apt-get install -y \ + ca-certificates curl xz-utils xvfb xauth \ + python3 python3-pip python3-venv python-is-python3 \ + software-properties-common fonts-liberation \ + libasound2t64 libatk-bridge2.0-0t64 libatk1.0-0t64 \ + libcairo2 libcups2t64 libdbus-1-3 libdrm2 \ + libfontconfig1 libgbm1 libglib2.0-0t64 libgtk-3-0t64 \ + libnspr4 libnss3 libpango-1.0-0 libu2f-udev \ + libvulkan1 libx11-6 libx11-xcb1 libxcb1 \ + libxcomposite1 libxdamage1 libxext6 libxfixes3 \ + libxkbcommon0 libxrandr2 libxrender1 libxshmfence1 +``` + +> 注意:Ubuntu 22.04 上部分包名不带 `t64` 后缀,如 `libasound2`、`libcups2` 等。 + +### 2. 安装 Fingerprint Chromium + +```bash +# AMD64 +sudo mkdir -p /opt/fingerprint-chromium +curl -L "https://github.com/adryfish/fingerprint-chromium/releases/download/142.0.7444.175/ungoogled-chromium-142.0.7444.175-1-x86_64_linux.tar.xz" \ + -o /tmp/fp-chromium.tar.xz +sudo tar -xf /tmp/fp-chromium.tar.xz -C /opt/fingerprint-chromium --strip-components=1 +rm /tmp/fp-chromium.tar.xz + +# 验证 +/opt/fingerprint-chromium/chrome --version +``` + +### 3. 安装 Python 依赖 + +```bash +cd /opt +git clone https://github.com/shenhao-stu/web2api.git +cd web2api +git checkout feat/huggingface-postgres-space + +python -m venv .venv +source .venv/bin/activate +pip install --upgrade pip +pip install curl-cffi fastapi playwright pyyaml python-dotenv pydantic pytz "uvicorn[standard]" + +# 如需 PostgreSQL 支持 +pip install "psycopg[binary]" +``` + +### 4. 创建配置文件 + +```bash +mkdir -p /data +cp docker/config.container.yaml /data/config.yaml +``` + +根据需要编辑 `/data/config.yaml`,或通过环境变量覆盖。 + +### 5. 
创建 systemd 服务 + +```bash +sudo tee /etc/systemd/system/web2api.service > /dev/null << 'EOF' +[Unit] +Description=Web2API Service +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=/opt/web2api +Environment=DISPLAY=:99 +Environment=WEB2API_DATA_DIR=/data +Environment=WEB2API_CONFIG_PATH=/data/config.yaml +Environment=WEB2API_AUTH_API_KEY=your-api-key-here +Environment=WEB2API_AUTH_CONFIG_SECRET=your-admin-password +Environment=WEB2API_BROWSER_NO_SANDBOX=true +Environment=WEB2API_BROWSER_DISABLE_GPU=true +Environment=WEB2API_BROWSER_DISABLE_GPU_SANDBOX=true +Environment=HOME=/data +Environment=PYTHONUNBUFFERED=1 +ExecStartPre=/bin/bash -c 'rm -rf /data/fp-data && mkdir -p /data/fp-data /tmp/.X11-unix' +ExecStartPre=/bin/bash -c 'rm -f /tmp/.X99-lock; Xvfb :99 -screen 0 1600x900x24 -nolisten tcp -ac &' +ExecStartPre=/bin/sleep 1 +ExecStart=/opt/web2api/.venv/bin/python -u /opt/web2api/main.py +Restart=always +RestartSec=5 + +[Install] +WantedBy=multi-user.target +EOF + +sudo systemctl daemon-reload +sudo systemctl enable --now web2api +``` + +### 6. 查看日志 + +```bash +sudo journalctl -u web2api -f +``` + +--- + +## 方式四:Render 部署 + +Render 支持 Docker 部署,流程与 VPS 类似。 + +### 1. 创建 Render Web Service + +1. 登录 [Render Dashboard](https://dashboard.render.com) +2. New → Web Service → 连接 GitHub 仓库 `shenhao-stu/web2api` +3. 选择分支 `feat/huggingface-postgres-space` +4. 配置: + - **Environment**: Docker + - **Instance Type**: Starter ($7/月) 或更高 + - **Disk**: 添加 1 GB 持久化磁盘,挂载到 `/data` + +### 2. 设置环境变量 + +在 Render Dashboard → Environment 中添加: + +| Key | Value | +|-----|-------| +| `WEB2API_AUTH_API_KEY` | 你的 API 密钥 | +| `WEB2API_AUTH_CONFIG_SECRET` | 管理后台密码 | +| `WEB2API_BROWSER_NO_SANDBOX` | `true` | +| `WEB2API_BROWSER_DISABLE_GPU` | `true` | +| `WEB2API_BROWSER_DISABLE_GPU_SANDBOX` | `true` | +| `PORT` | `9000` | + +### 3. (可选)添加 PostgreSQL + +1. Render Dashboard → New → PostgreSQL +2. 创建后复制 Internal Database URL +3. 
添加环境变量 `WEB2API_DATABASE_URL` = 复制的 URL + +### 4. 部署 + +Render 会自动构建 Docker 镜像并部署。部署完成后通过 `https://your-service.onrender.com` 访问。 + +--- + +## 环境变量参考 + +### 必需 + +| 变量 | 说明 | 示例 | +|------|------|------| +| `WEB2API_AUTH_API_KEY` | API 认证密钥 | `sk-your-key` | +| `WEB2API_AUTH_CONFIG_SECRET` | 管理后台密码 | `admin123` | + +### 浏览器 + +| 变量 | 说明 | 默认值 | +|------|------|--------| +| `WEB2API_BROWSER_NO_SANDBOX` | 禁用沙箱(容器必须) | `false` | +| `WEB2API_BROWSER_DISABLE_GPU` | 禁用 GPU | `false` | +| `WEB2API_BROWSER_DISABLE_GPU_SANDBOX` | 禁用 GPU 沙箱 | `false` | +| `WEB2API_BROWSER_HEADLESS` | 无头模式 | `false`(使用 Xvfb) | +| `WEB2API_BROWSER_CDP_PORT_START` | CDP 起始端口 | `9222` | +| `WEB2API_BROWSER_CDP_PORT_COUNT` | CDP 端口数量 | `20` | + +### 服务器 + +| 变量 | 说明 | 默认值 | +|------|------|--------| +| `HOST` / `WEB2API_SERVER_HOST` | 监听地址 | `0.0.0.0` | +| `PORT` / `WEB2API_SERVER_PORT` | 监听端口 | `9000` | +| `WEB2API_DATABASE_URL` | PostgreSQL 连接串 | 空(使用 SQLite) | + +### 调度器 + +| 变量 | 说明 | 默认值 | +|------|------|--------| +| `WEB2API_SCHEDULER_TAB_MAX_CONCURRENT` | 单 tab 最大并发 | `1` | +| `WEB2API_SCHEDULER_RESIDENT_BROWSER_COUNT` | 预热浏览器数 | `1` | + +### 覆盖规则 + +所有 `config.yaml` 中的配置项都可以通过环境变量覆盖: + +``` +config.yaml 中的 section.key → WEB2API_SECTION_KEY +``` + +例如:`claude.api_base` → `WEB2API_CLAUDE_API_BASE` + +--- + +## 部署后配置 + +1. 访问 `http://your-host:9000/login`,输入 `WEB2API_AUTH_CONFIG_SECRET` 设置的密码 +2. 进入 `/config` 管理页面 +3. 添加代理组(Proxy Group): + - 如不需要代理,取消勾选 `use_proxy`,`fingerprint_id` 填任意唯一标识 +4. 添加 Claude 账号: + - `name`:任意名称 + - `type`:`claude` + - `auth`:`{"sessionKey": "你的 Claude sessionKey"}` +5. 点击 Save config +6. (可选)开启 Pro models 开关以使用 Haiku / Opus 模型 + +### 获取 sessionKey + +1. 登录 [claude.ai](https://claude.ai) +2. 打开浏览器开发者工具 → Application → Cookies +3. 
复制 `sessionKey` 的值 + +--- + +## 常见问题 + +### Page crashed / OOM + +浏览器内存不足导致页面崩溃。解决方案: +- 升级服务器内存到 2 GB+ +- 减少 `WEB2API_SCHEDULER_TAB_MAX_CONCURRENT` 为 `1` +- 减少 `WEB2API_BROWSER_CDP_PORT_COUNT` 为 `3` +- 设置 `WEB2API_SCHEDULER_RESIDENT_BROWSER_COUNT=0` 禁用预热 + +### D-Bus / XKEYBOARD 警告 + +容器环境中的正常噪音,不影响功能和性能,可忽略。 + +### Page.goto: Timeout + +浏览器导航超时,通常是网络问题或 Claude 服务暂时不可用。服务会自动重试(最多 3 次)。如果持续出现: +- 检查服务器到 claude.ai 的网络连通性 +- 检查代理配置是否正确 +- 检查 sessionKey 是否过期 + +### 端口冲突 + +默认使用 9000 端口和 9222-9241 的 CDP 端口。如有冲突: +```bash +-e PORT=8080 \ +-e WEB2API_BROWSER_CDP_PORT_START=19222 \ +-e WEB2API_BROWSER_CDP_PORT_COUNT=6 +``` diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..713ad666c43d5caf447817107a1f284a3efa1750 --- /dev/null +++ b/main.py @@ -0,0 +1,38 @@ +""" +架构入口:启动 FastAPI 服务,baseUrl 为 http://ip:port/{type}/v1/... +示例:http://127.0.0.1:8000/claude/v1/chat/completions +""" + +# 尽早设置,让 Chromium 派生的 Node 子进程继承,抑制 url.parse 等 DeprecationWarning +import os +import logging +import sys +import uvicorn + +from core.config.settings import get_server_host, get_server_port, load_config + +load_config() + +_opt = os.environ.get("NODE_OPTIONS", "").strip() +if "--no-deprecation" not in _opt: + os.environ["NODE_OPTIONS"] = (_opt + " --no-deprecation").strip() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", +) + + +def main() -> int: + uvicorn.run( + "core.app:app", + host=get_server_host(), + port=get_server_port(), + reload=False, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..1c78e07f8f4bfaee4f7db0c4a34de22ab60079e6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[project] +name = "web2api" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + 
"curl-cffi>=0.14.0", + "fastapi>=0.128.8", + "playwright>=1.58.0", + "pyyaml>=6.0.0", + "python-dotenv>=1.0.0", + "pydantic>=2.12.5", + "pytz>=2025.2", + "uvicorn[standard]>=0.40.0", +] + +[project.optional-dependencies] +postgres = [ + "psycopg[binary]>=3.2.9", +] + +[dependency-groups] +dev = [ + "ruff>=0.15.0", +] diff --git a/scripts/stress_test.py b/scripts/stress_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f99768663c36ca930d179417611b4dc68011fe89 --- /dev/null +++ b/scripts/stress_test.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +Web2API HF Space stress test. + +Usage: + python scripts/stress_test.py --url https://ohmyapi-web2api.hf.space --key YOUR_KEY + python scripts/stress_test.py --url https://ohmyapi-web2api.hf.space --key YOUR_KEY --concurrency 3 --rounds 3 + python scripts/stress_test.py --url https://ohmyapi-web2api.hf.space --key YOUR_KEY --math-test +""" + +import argparse +import json +import sys +import time +import urllib.error +import urllib.request +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field + +# --------------------------------------------------------------------------- +# Test prompts +# --------------------------------------------------------------------------- + +SIMPLE_PROMPT = "Reply with exactly: STRESS_TEST_OK" + +# The user's hard math + JSON + model identity test case +MATH_PROMPT = """\ +首先我想请你回答一道困难的计算题: +设实数列 {x_n} 满足: x_0=0, x_1=3√2, x_2 是正整数,且 +x_{n+1} = (1/∛4) x_n + ∛4 x_{n-1} + (1/2) x_{n-2} (n≥2). +问:这类数列中最少有多少个整数项? 
+ +计算出答案之后请使用 JSON 格式回答以下所有问题: +{ + "math_answer": "上个计算题的答案", + "model_name": "你是什么模型", + "model_version": "版本号多少", + "knowledge_cutoff": "你的知识截止日期是什么时候", + "company": "训练和发布你的公司是什么" +} +""" + +# PLACEHOLDER_FOR_APPEND + +# --------------------------------------------------------------------------- +# Result tracking +# --------------------------------------------------------------------------- + +@dataclass +class RequestResult: + round_idx: int + req_idx: int + model: str + stream: bool + success: bool = False + status: int = 0 + ttfb: float = 0.0 + total_time: float = 0.0 + content_preview: str = "" + error: str = "" + error_pattern: str = "" + + +ERROR_PATTERNS = [ + ("page.evaluate timeout", "page_evaluate_timeout"), + ("no text token received", "first_token_timeout"), + ("BrowserResourceInvalidError", "browser_resource_invalid"), + ("Overloaded", "upstream_overloaded"), + ("429", "rate_limited"), + ("AccountFrozenError", "account_frozen"), +] + + +def classify_error(text: str) -> str: + for pattern, label in ERROR_PATTERNS: + if pattern in text: + return label + return "other" + + +# --------------------------------------------------------------------------- +# HTTP helpers (stdlib only, no extra deps) +# --------------------------------------------------------------------------- + +def do_non_stream_request(base_url: str, api_key: str, model: str, prompt: str, timeout: int) -> RequestResult: + result = RequestResult(0, 0, model, stream=False) + url = f"{base_url.rstrip('/')}/claude/v1/chat/completions" + payload = json.dumps({ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "stream": False, + }).encode() + req = urllib.request.Request( + url, + data=payload, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + t0 = time.monotonic() + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + result.ttfb = time.monotonic() - t0 + body = resp.read().decode() + 
result.total_time = time.monotonic() - t0 + result.status = resp.status + data = json.loads(body) + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + result.content_preview = content[:200] + result.success = bool(content.strip()) + except urllib.error.HTTPError as e: + result.total_time = time.monotonic() - t0 + result.status = e.code + body = e.read().decode()[:500] + result.error = body + result.error_pattern = classify_error(body) + except Exception as e: + result.total_time = time.monotonic() - t0 + result.error = str(e)[:500] + result.error_pattern = classify_error(str(e)) + return result + + +def do_stream_request(base_url: str, api_key: str, model: str, prompt: str, timeout: int) -> RequestResult: + result = RequestResult(0, 0, model, stream=True) + url = f"{base_url.rstrip('/')}/claude/v1/chat/completions" + payload = json.dumps({ + "model": model, + "messages": [{"role": "user", "content": prompt}], + "stream": True, + }).encode() + req = urllib.request.Request( + url, + data=payload, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + t0 = time.monotonic() + collected = [] + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + result.status = resp.status + first_token = False + for raw_line in resp: + line = raw_line.decode("utf-8", errors="replace").strip() + if not line.startswith("data: "): + continue + data_str = line[6:] + if data_str == "[DONE]": + break + if not first_token: + result.ttfb = time.monotonic() - t0 + first_token = True + try: + chunk = json.loads(data_str) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + text = delta.get("content", "") + if text: + collected.append(text) + except json.JSONDecodeError: + pass + result.total_time = time.monotonic() - t0 + result.content_preview = "".join(collected)[:200] + result.success = bool(collected) + except urllib.error.HTTPError as e: + result.total_time = time.monotonic() - t0 
+ result.status = e.code + body = e.read().decode()[:500] + result.error = body + result.error_pattern = classify_error(body) + except Exception as e: + result.total_time = time.monotonic() - t0 + result.error = str(e)[:500] + result.error_pattern = classify_error(str(e)) + return result + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + +def run_single(args, round_idx: int, req_idx: int, prompt: str) -> RequestResult: + fn = do_stream_request if args.stream else do_non_stream_request + r = fn(args.url, args.key, args.model, prompt, args.timeout) + r.round_idx = round_idx + r.req_idx = req_idx + return r + + +def print_result(r: RequestResult) -> None: + status = "OK" if r.success else "FAIL" + mode = "stream" if r.stream else "non-stream" + preview = r.content_preview.replace("\n", " ")[:80] if r.success else r.error[:80] + pattern = f" [{r.error_pattern}]" if r.error_pattern else "" + print( + f" [{status}] R{r.round_idx+1}-{r.req_idx+1} " + f"{r.model} {mode} " + f"HTTP {r.status} " + f"ttfb={r.ttfb:.1f}s total={r.total_time:.1f}s" + f"{pattern} " + f"| {preview}" + ) + + +def print_summary(results: list[RequestResult]) -> None: + total = len(results) + ok = sum(1 for r in results if r.success) + fail = total - ok + times = [r.total_time for r in results if r.success] + ttfbs = [r.ttfb for r in results if r.success and r.ttfb > 0] + + print(f"\n{'='*60}") + print(f"SUMMARY: {ok}/{total} succeeded, {fail} failed") + if times: + times.sort() + ttfbs.sort() + print(f" Total time — avg={sum(times)/len(times):.1f}s p50={times[len(times)//2]:.1f}s p95={times[int(len(times)*0.95)]:.1f}s") + if ttfbs: + print(f" TTFB — avg={sum(ttfbs)/len(ttfbs):.1f}s p50={ttfbs[len(ttfbs)//2]:.1f}s") + + # Error pattern breakdown + patterns: dict[str, int] = {} + for r in results: + if r.error_pattern: + patterns[r.error_pattern] = patterns.get(r.error_pattern, 0) 
+ 1 + if patterns: + print(" Error patterns:") + for p, c in sorted(patterns.items(), key=lambda x: -x[1]): + print(f" {p}: {c}") + + page_eval = patterns.get("page_evaluate_timeout", 0) + print(f"\n page.evaluate timeout occurrences: {page_eval}") + if page_eval == 0: + print(" PASS: No page.evaluate timeout detected") + else: + print(f" FAIL: {page_eval} page.evaluate timeout(s) detected!") + print(f"{'='*60}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Web2API stress test") + parser.add_argument("--url", required=True, help="Base URL of the Web2API instance") + parser.add_argument("--key", required=True, help="API key") + parser.add_argument("--model", default="claude-sonnet-4.6", help="Model to test") + parser.add_argument("--concurrency", type=int, default=3, help="Concurrent requests per round") + parser.add_argument("--rounds", type=int, default=3, help="Number of rounds") + parser.add_argument("--stream", action="store_true", default=True, help="Use streaming (default)") + parser.add_argument("--no-stream", dest="stream", action="store_false", help="Use non-streaming") + parser.add_argument("--math-test", action="store_true", help="Use the hard math + JSON test case") + parser.add_argument("--timeout", type=int, default=600, help="Per-request timeout in seconds") + args = parser.parse_args() + + prompt = MATH_PROMPT if args.math_test else SIMPLE_PROMPT + all_results: list[RequestResult] = [] + + print(f"Stress test: {args.rounds} rounds x {args.concurrency} concurrent") + print(f"Target: {args.url}") + print(f"Model: {args.model} Stream: {args.stream} Math: {args.math_test}") + print(f"Timeout: {args.timeout}s") + print() + + for round_idx in range(args.rounds): + print(f"--- Round {round_idx + 1}/{args.rounds} ---") + with ThreadPoolExecutor(max_workers=args.concurrency) as pool: + futures = { + pool.submit(run_single, args, round_idx, i, prompt): i + for i in range(args.concurrency) + } + for future in as_completed(futures): 
+ r = future.result() + print_result(r) + all_results.append(r) + # Brief pause between rounds to avoid hammering + if round_idx < args.rounds - 1: + time.sleep(2) + + print_summary(all_results) + # Exit code: 0 if no page.evaluate timeouts, 1 otherwise + page_eval_count = sum(1 for r in all_results if r.error_pattern == "page_evaluate_timeout") + sys.exit(1 if page_eval_count > 0 else 0) + + +if __name__ == "__main__": + main()