diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..87b12e9227ca873988067e17d9bf5b9a026070bc --- /dev/null +++ b/.env.example @@ -0,0 +1,72 @@ +# 代理服务配置文件示例 +# 复制此文件为 .env 并根据需要修改配置值 + +# ========== API 基础配置 ========== +# 客户端访问本服务使用的 Bearer 密钥,不是上游 Z.AI 用户 Token +# 上游用户 Token 请在管理后台导入,由数据库 Token 池统一管理 +AUTH_TOKEN=sk-your-api-key + +# 跳过客户端认证(仅开发环境使用) +SKIP_AUTH_TOKEN=false + +# ========== 用户 Token 池配置 ========== +# 仅作用于管理后台导入的 Z.AI 用户 Token +# 失败多少次后标记为不可用 +TOKEN_FAILURE_THRESHOLD=3 + +# 失败 Token 多久后重新参与调度(秒) +TOKEN_RECOVERY_TIMEOUT=1800 + +# 定时扫描服务端目录导入 Token +TOKEN_AUTO_IMPORT_ENABLED=false + +# 自动导入的服务端本地目录 +TOKEN_AUTO_IMPORT_SOURCE_DIR= + +# 自动导入扫描间隔(秒) +TOKEN_AUTO_IMPORT_INTERVAL=300 + +# 定时维护 Token 池 +TOKEN_AUTO_MAINTENANCE_ENABLED=false + +# 自动维护执行间隔(秒) +TOKEN_AUTO_MAINTENANCE_INTERVAL=1800 + +# 自动维护动作开关 +TOKEN_AUTO_REMOVE_DUPLICATES=true +TOKEN_AUTO_HEALTH_CHECK=true +TOKEN_AUTO_DELETE_INVALID=false + +# ========== 匿名 Guest 会话池 ========== +# false: 禁用 guest 匿名池,仅使用后台导入的用户 Token 池 +# true: 启用 guest 匿名池;当没有可用用户 Token 时允许匿名会话 +ANONYMOUS_MODE=true + +# 预热和维持的 guest 会话数量 +GUEST_POOL_SIZE=10 + +# ========== 服务器配置 ========== +LISTEN_PORT=8080 +SERVICE_NAME=api-proxy-server +DEBUG_LOGGING=false + +# Nginx 反向代理路径前缀(可选) +ROOT_PATH= + +# Function Call 功能开关 +TOOL_SUPPORT=true + +# 工具调用扫描限制(字符数) +SCAN_LIMIT=200000 + +# SQLite 数据库路径 +DB_PATH=tokens.db + +# ========== 代理配置 ========== +# HTTP_PROXY=http://127.0.0.1:7890 +# HTTPS_PROXY=http://127.0.0.1:7890 +# SOCKS5_PROXY=socks5://127.0.0.1:1080 + +# ========== 管理后台认证 ========== +ADMIN_PASSWORD=admin123 +SESSION_SECRET_KEY=your-secret-key-change-in-production diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dfe0770424b2a19faf507a501ebfc23be8f54e7b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..94e5dc92561ccaca2774fcbbeea5b77b193296ae --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,64 @@ +name: Build and Push Docker Image + +on: + push: + branches: + - main + tags: + - 'v*' + +env: + IMAGE_NAME: z-ai2api-python + +jobs: + docker: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Login to Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: | + ghcr.io/${{ github.repository }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,value=latest,enable={{is_default_branch}} + + - name: Build and push + uses: docker/build-push-action@v5 + with: + context: . + file: ./deploy/Dockerfile + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..69f8a8698cf4db10d62e7b703b535076ad1b1796 --- /dev/null +++ b/.gitignore @@ -0,0 +1,181 @@ +# Custom +.vs/ +.vscode/ +.idea/ +.conda/ +*.zip +*.txt +*.pid +docs/ +output/ +main.build/ +main.dist/ +main.onefile-build/ +*report.xml +*.yaml +logs/ +backup/ +uv.lock +AGENTS.md +*.db + +# AI Toolset +.augment/ +.cursor/ +.claude/ +CLAUDE.md + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ +.ace-tool/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e8f5578f2e26366d0f292a785aa952035f3d40b7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +FROM python:3.12-slim + +# Set environment variables +ENV LISTEN_PORT=7860 +ENV DB_PATH=/app/data/tokens.db +ENV PYTHONUNBUFFERED=1 + +# Set working directory +WORKDIR /app + +# Create data and logs directories and set permissions +# HF Spaces runs as user 1000, so we make sure it can write to these directories +RUN mkdir -p /app/data /app/logs && \ + chmod -R 777 /app/data /app/logs + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Ensure all files are accessible +RUN chmod -R 777 /app + +# Expose port +EXPOSE 7860 + +# Run the application +CMD ["python", "main.py"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d86c4ab26f3c0bdf7eabcbf887e5bfc0b9653212 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 ZyphrZero + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5f06ada691a9bef5f03ccde5e079b008f1e087f5 --- /dev/null +++ b/README.md @@ -0,0 +1,132 @@ +--- +title: Z.ai API +emoji: 🚀 +colorFrom: blue +colorTo: indigo +sdk: docker +app_port: 7860 +--- + +# z-ai2api_python + +基于 FastAPI + Granian 的 GLM 代理服务 +适合本地开发、自托管代理、Token 池管理和兼容客户端接入 + +中文简体 / [English](README_EN.md) + +## 特性 + +- 兼容 `OpenAI`、`Claude Code`、`Anthropic` 风格请求 +- 支持流式响应、工具调用、Thinking 模型 +- 内置 Token 池,支持轮询、失败熔断、恢复和健康检查 +- 提供后台页面:仪表盘、Token 管理、配置管理、实时日志 +- 使用 SQLite 存储 Token 和请求日志,部署简单 +- 支持本地运行和 Docker / Docker Compose 部署 + +## 快速开始 + +### 环境要求 + +- Python `3.9` 到 `3.12` +- 推荐使用 `uv` + +### 本地启动 + +```bash +git clone https://github.com/ZyphrZero/z.ai2api_python.git +cd z.ai2api_python + +uv sync +cp .env.example .env +uv run python main.py +``` + +首次启动会自动初始化数据库。 + +默认地址: + +- API 根路径:`http://127.0.0.1:8080` +- OpenAI 文档:`http://127.0.0.1:8080/docs` +- 管理后台:`http://127.0.0.1:8080/admin` + +### Docker Compose + +```bash +docker compose -f deploy/docker-compose.yml up -d --build +``` + +更多部署说明见 [deploy/README_DOCKER.md](deploy/README_DOCKER.md)。 + +## 最小配置 + +至少建议确认这些环境变量: + +| 变量 | 说明 | +| --- | --- | +| `AUTH_TOKEN` | 客户端访问本服务使用的 Bearer Token | +| `ADMIN_PASSWORD` | 管理后台登录密码,默认值必须修改 | +| `LISTEN_PORT` | 服务监听端口,默认 `8080` | +| `ANONYMOUS_MODE` | 是否启用匿名模式 | +| `GUEST_POOL_SIZE` | 匿名池容量 | +| `DB_PATH` | SQLite 数据库路径 | +| `TOKEN_FAILURE_THRESHOLD` | Token 连续失败阈值 | +| `TOKEN_RECOVERY_TIMEOUT` | Token 恢复等待时间 | + +完整配置请看 [.env.example](.env.example)。 + +## 管理后台 + +管理后台统一入口: + +- `/admin`:仪表盘 +- `/admin/tokens`:Token 管理 +- `/admin/config`:配置管理 +- `/admin/logs`:实时日志 + +## 常用命令 + +```bash +# 启动服务 +uv run python main.py + +# 运行测试 +uv run pytest + +# 运行一个现有 smoke test +uv run python tests/test_simple_signature.py + +# Lint +uv run ruff check app tests main.py +``` + +## 兼容接口 + +常见接口入口: + +- OpenAI 兼容:`/v1/chat/completions` +- Anthropic 兼容:`/v1/messages` +- Claude Code 兼容:`/anthropic/v1/messages` + +模型映射和默认模型可在 `.env` 或后台配置页中调整。 + +## ⭐ Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=ZyphrZero/z.ai2api_python&type=Date)](https://star-history.com/#ZyphrZero/z.ai2api_python&Date) + +## 许可证 + +本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件。 + +## 免责声明 + +- **本项目仅供学习和研究使用,切勿用于其他用途** +- 本项目与 Z.AI 官方无关 +- 使用前请确保遵守 Z.AI 的服务条款 +- 请勿用于商业用途或违反使用条款的场景 +- 用户需自行承担使用风险 + +--- + +
+Made with ❤️ by the community +
diff --git a/README_EN.md b/README_EN.md new file mode 100644 index 0000000000000000000000000000000000000000..cbfde08e41dc433937c861a06248f0911b6f8c32 --- /dev/null +++ b/README_EN.md @@ -0,0 +1,123 @@ +# z-ai2api_python + +GLM proxy service based on FastAPI + Granian +Suitable for local development, self-hosted proxy, Token pool management, and compatible client access + +English / [中文简体](README.md) + +## Features + +- Compatible with `OpenAI`, `Claude Code`, `Anthropic` style requests +- Supports streaming responses, tool calls, Thinking models +- Built-in Token pool, supports polling, failure circuit breaker, recovery, and health checks +- Provides admin panel: Dashboard, Token management, Configuration management, Real-time logs +- Uses SQLite to store Tokens and request logs, simple deployment +- Supports local running and Docker / Docker Compose deployment + +## Quick Start + +### Environment Requirements + +- Python `3.9` to `3.12` +- Recommend using `uv` + +### Local Startup + +```bash +git clone https://github.com/ZyphrZero/z.ai2api_python.git +cd z.ai2api_python + +uv sync +cp .env.example .env +uv run python main.py +``` + +First startup will automatically initialize the database. + +Default addresses: + +- API root path: `http://127.0.0.1:8080` +- OpenAI docs: `http://127.0.0.1:8080/docs` +- Admin panel: `http://127.0.0.1:8080/admin` + +### Docker Compose + +```bash +docker compose -f deploy/docker-compose.yml up -d --build +``` + +More deployment instructions see [deploy/README_DOCKER.md](deploy/README_DOCKER.md). + +## Minimum Configuration + +At least suggest confirming these environment variables: + +| Variable | Description | +| --- | --- | +| `AUTH_TOKEN` | Bearer Token used by clients to access this service | +| `ADMIN_PASSWORD` | Admin panel login password, default value must be changed | +| `LISTEN_PORT` | Service listening port, default `8080` | +| `ANONYMOUS_MODE` | Whether to enable anonymous mode | +| `GUEST_POOL_SIZE` | Anonymous pool capacity | +| `DB_PATH` | SQLite database path | +| `TOKEN_FAILURE_THRESHOLD` | Token consecutive failure threshold | +| `TOKEN_RECOVERY_TIMEOUT` | Token recovery wait time | + +Complete configuration please see [.env.example](.env.example). + +## Admin Panel + +Admin panel unified entry: + +- `/admin`: Dashboard +- `/admin/tokens`: Token management +- `/admin/config`: Configuration management +- `/admin/logs`: Real-time logs + +## Common Commands + +```bash +# Start service +uv run python main.py + +# Run tests +uv run pytest + +# Run an existing smoke test +uv run python tests/test_simple_signature.py + +# Lint +uv run ruff check app tests main.py +``` + +## Compatible Interfaces + +Common interface entries: + +- OpenAI compatible: `/v1/chat/completions` +- Anthropic compatible: `/v1/messages` +- Claude Code compatible: `/anthropic/v1/messages` + +Model mapping and default model can be adjusted in `.env` or admin configuration page. + +## ⭐ Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=ZyphrZero/z.ai2api_python&type=Date)](https://star-history.com/#ZyphrZero/z.ai2api_python&Date) + +## License + +This project uses MIT license - see [LICENSE](LICENSE) file for details. + +## Disclaimer + +- **This project is for learning and research use only, do not use for other purposes** +- This project is not affiliated with Z.AI official +- Please ensure compliance with Z.AI's terms of service before use +- Do not use for commercial purposes or scenarios that violate terms of service +- Users must bear their own usage risks + +--- + +
+Made with ❤️ by the community +
\ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..101e6a38bb631969e65ff40b512eb45668947047 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from app import core, models, utils + +__all__ = ["core", "models", "utils"] diff --git a/app/admin/__init__.py b/app/admin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36cacc8ac4ca1ac050d6d7f340c82fc4de024519 --- /dev/null +++ b/app/admin/__init__.py @@ -0,0 +1,3 @@ +""" +管理后台模块初始化 +""" diff --git a/app/admin/api.py b/app/admin/api.py new file mode 100644 index 0000000000000000000000000000000000000000..af6b040f3fe786984d9ad4d8b05cbba841049f1e --- /dev/null +++ b/app/admin/api.py @@ -0,0 +1,1111 @@ +""" +管理后台 API 接口 +用于 htmx 调用的 HTML 片段返回 +""" +from datetime import datetime +from html import escape +from pathlib import Path +import re +from typing import Optional + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.templating import Jinja2Templates + +from app.admin.auth import require_auth +from app.admin.config_manager import ( + read_env_content, + reset_env_to_example, + save_form_config, + save_source_config, +) +from app.admin.stats import collect_admin_stats, normalize_trend_window +from app.services.request_log_dao import get_request_log_dao +from app.utils.logger import logger + +router = APIRouter(prefix="/admin/api", tags=["admin-api"]) +templates = Jinja2Templates(directory="app/templates") +DEFAULT_TOKEN_NAMESPACE = "zai" + + +# ==================== 认证 API ==================== + +@router.post("/login") +async def login(request: Request): + """管理后台登录""" + from app.admin.auth import create_session + + try: + data = await request.json() + password = data.get("password", "") + + # 创建 session + session_token = create_session(password) + + if session_token: + # 登录成功,设置 cookie + response = JSONResponse({ + "success": True, + "message": "登录成功" + }) + response.set_cookie( + key="admin_session", + value=session_token, + httponly=True, + max_age=86400, # 24小时 + samesite="lax" + ) + logger.info("✅ 管理后台登录成功") + return response + else: + # 密码错误 + logger.warning("❌ 管理后台登录失败:密码错误") + return JSONResponse({ + "success": False, + "message": "密码错误" + }, status_code=401) + + except Exception as e: + logger.error(f"❌ 登录异常: {e}") + return JSONResponse({ + "success": False, + "message": "登录失败" + }, status_code=500) + + +@router.post("/logout") +async def logout(request: Request): + """管理后台登出""" + from app.admin.auth import delete_session, get_session_token_from_request + + session_token = get_session_token_from_request(request) + delete_session(session_token) + + # 清除 cookie + response = JSONResponse({ + "success": True, + "message": "已登出" + }) + response.delete_cookie("admin_session") + logger.info("✅ 管理后台已登出") + return response + + +async def reload_settings(): + """热重载配置(重新加载环境变量并更新 settings 对象)""" + from dotenv import load_dotenv + + from app.core.config import settings + from app.utils.logger import setup_logger + + # 重新加载 .env 文件 + load_dotenv(override=True) + + # 重新创建 Settings 对象并更新全局配置 + new_settings = type(settings)() + + # 更新全局 settings 的所有属性 + for field_name in new_settings.model_fields.keys(): + setattr(settings, field_name, getattr(new_settings, field_name)) + + # 重新初始化 logger(使用新的 DEBUG_LOGGING 配置) + setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING) + + logger.info(f"🔄 配置已热重载 (DEBUG_LOGGING={settings.DEBUG_LOGGING})") + + +def _build_alert( + message: str, + *, + title: str, + level: str, + status_code: int = 200, +) -> HTMLResponse: + level_classes = { + "success": "bg-green-100 border-green-400 text-green-700", + "warning": "bg-yellow-100 border-yellow-400 text-yellow-700", + "error": "bg-red-100 border-red-400 text-red-700", + "info": "bg-blue-100 border-blue-400 text-blue-700", + } + classes = level_classes.get(level, level_classes["info"]) + safe_title = escape(title) + safe_message = escape(message) + return HTMLResponse( + f""" + + """, + status_code=status_code, + ) + + +def _with_hx_trigger(response: HTMLResponse, event_name: str) -> HTMLResponse: + response.headers["HX-Trigger"] = event_name + return response + + +def _get_int_query_param( + request: Request, + name: str, + default: int, + *, + minimum: int = 1, + maximum: Optional[int] = None, +) -> int: + """解析查询参数中的正整数,非法值回退到默认值。""" + raw_value = request.query_params.get(name) + if raw_value is None: + return default + + try: + value = int(str(raw_value).strip()) + except (TypeError, ValueError): + return default + + value = max(minimum, value) + if maximum is not None: + value = min(value, maximum) + return value + + +def _build_pagination( + *, + total_items: int, + page: int, + page_size: int, +) -> dict: + """构建分页上下文。""" + total_items = max(0, int(total_items)) + page_size = max(1, int(page_size)) + total_pages = max(1, (total_items + page_size - 1) // page_size) + current_page = min(max(1, int(page)), total_pages) + + if total_items == 0: + start_item = 0 + end_item = 0 + else: + start_item = (current_page - 1) * page_size + 1 + end_item = min(total_items, current_page * page_size) + + return { + "current_page": current_page, + "page_size": page_size, + "total_items": total_items, + "total_pages": total_pages, + "has_previous": current_page > 1, + "has_next": current_page < total_pages, + "previous_page": max(1, current_page - 1), + "next_page": min(total_pages, current_page + 1), + "start_item": start_item, + "end_item": end_item, + } + + +def _normalize_display_value(value: str) -> str: + normalized = re.sub(r"[^a-z0-9]+", "", str(value or "").casefold()) + return normalized + + +def _is_redundant_source(source: str, client_name: str) -> bool: + normalized_source = _normalize_display_value(source) + normalized_client = _normalize_display_value(client_name) + if not normalized_source: + return True + if not normalized_client: + return False + return normalized_source == normalized_client + + +def _humanize_protocol(protocol: str) -> str: + normalized = str(protocol or "").strip().lower() + if normalized == "openai": + return "OpenAI" + if normalized == "anthropic": + return "Anthropic" + if normalized == "unknown": + return "Unknown" + return normalized or "Unknown" + + +@router.get( + "/dashboard/usage-trend", + response_class=JSONResponse, + dependencies=[Depends(require_auth)], +) +async def get_dashboard_usage_trend(request: Request): + """返回仪表盘趋势图数据。""" + trend_window = normalize_trend_window( + request.query_params.get("window") + ) + dao = get_request_log_dao() + trend_points = await dao.get_provider_usage_trend( + DEFAULT_TOKEN_NAMESPACE, + window=trend_window, + ) + return JSONResponse( + { + "window": trend_window, + "points": trend_points, + } + ) + + +def _validate_directory_path(source_dir: str) -> str: + if not source_dir: + raise ValueError("请先填写服务端可访问的本地目录路径。") + + source_path = Path(source_dir).expanduser() + if not source_path.exists(): + raise ValueError(f"导入目录不存在: {source_path}") + if not source_path.is_dir(): + raise ValueError(f"导入路径不是目录: {source_path}") + + return str(source_path) + + +@router.get("/token-pool", response_class=HTMLResponse) +async def get_token_pool_status(request: Request): + """获取 Token 池状态(HTML 片段)""" + from app.utils.token_pool import get_token_pool + + token_pool = get_token_pool() + + if not token_pool: + # Token 池未初始化 + context = { + "request": request, + "tokens": [], + } + return templates.TemplateResponse("components/token_pool.html", context) + + # 获取 token 状态统计 + pool_status = token_pool.get_pool_status() + tokens_info = [] + + for idx, token_info in enumerate(pool_status.get("tokens", []), 1): + is_available = token_info.get("is_available", False) + is_healthy = token_info.get("is_healthy", False) + + # 确定状态和颜色 + if is_healthy: + status = "健康" + status_color = "bg-green-100 text-green-800" + elif is_available: + status = "可用" + status_color = "bg-yellow-100 text-yellow-800" + else: + status = "失败" + status_color = "bg-red-100 text-red-800" + + # 格式化最后使用时间 + last_success = token_info.get("last_success_time", 0) + if last_success > 0: + from datetime import datetime + last_used = datetime.fromtimestamp(last_success).strftime("%Y-%m-%d %H:%M:%S") + else: + last_used = "从未使用" + + tokens_info.append({ + "index": idx, + "key": token_info.get("token", "")[:20] + "...", + "status": status, + "status_color": status_color, + "last_used": last_used, + "failure_count": token_info.get("failure_count", 0), + "success_rate": token_info.get("success_rate", "0%"), + "token_type": token_info.get("token_type", "unknown"), + }) + + context = { + "request": request, + "tokens": tokens_info, + } + + return templates.TemplateResponse("components/token_pool.html", context) + + +@router.get("/recent-logs", response_class=HTMLResponse) +async def get_recent_logs(request: Request): + """获取最近的请求日志(HTML 片段)""" + dao = get_request_log_dao() + page_size = _get_int_query_param( + request, + "page_size", + 12, + maximum=50, + ) + requested_page = _get_int_query_param(request, "page", 1, maximum=100000) + total_count = await dao.count_logs() + pagination = _build_pagination( + total_items=total_count, + page=requested_page, + page_size=page_size, + ) + + rows = await dao.get_recent_logs( + limit=page_size, + offset=(pagination["current_page"] - 1) * page_size, + ) + logs = [] + for row in rows: + timestamp = ( + row.get("timestamp") + or row.get("created_at") + or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ) + success = bool(row.get("success")) + status_code = int( + row.get("status_code") or (200 if success else 500) + ) + duration_value = float(row.get("duration") or 0.0) + first_token_value = float(row.get("first_token_time") or 0.0) + source = row.get("source") or "unknown" + client_name = row.get("client_name") or "Unknown" + provider = row.get("provider") or "-" + source_display = ( + "" + if _is_redundant_source(source, client_name) + else source + ) + provider_display = "" if provider == "zai" else provider + logs.append( + { + "timestamp": timestamp, + "endpoint": row.get("endpoint") or "-", + "model": row.get("model") or "-", + "provider": provider, + "provider_display": provider_display, + "source": source, + "source_display": source_display, + "protocol": row.get("protocol") or "unknown", + "protocol_display": _humanize_protocol( + row.get("protocol") or "unknown" + ), + "client_name": client_name, + "success": success, + "status_code": status_code, + "duration_display": f"{duration_value:.2f}s", + "first_token_display": ( + f"{first_token_value:.2f}s" + if first_token_value > 0 + else "--" + ), + "input_tokens": int(row.get("input_tokens") or 0), + "output_tokens": int(row.get("output_tokens") or 0), + "cache_creation_tokens": int( + row.get("cache_creation_tokens") or 0 + ), + "cache_read_tokens": int( + row.get("cache_read_tokens") or 0 + ), + "error_message": row.get("error_message") or "", + } + ) + + context = { + "request": request, + "logs": logs, + "page": pagination, + } + + return templates.TemplateResponse("components/recent_logs.html", context) + + +@router.post("/config/save", dependencies=[Depends(require_auth)]) +async def save_config(request: Request): + """保存结构化配置并热重载。""" + try: + form_data = await request.form() + await save_form_config( + form_data, + reload_callback=reload_settings, + ) + logger.info("✅ 结构化配置已保存") + return _with_hx_trigger( + _build_alert( + "配置已保存并热重载,页面即将刷新。", + title="保存成功!", + level="success", + ), + "admin-config-refresh", + ) + except ValueError as exc: + return _build_alert( + str(exc), + title="校验失败!", + level="error", + status_code=400, + ) + except Exception as exc: + logger.error(f"❌ 配置保存失败: {exc}") + return _build_alert( + f"保存失败: {exc}", + title="错误!", + level="error", + status_code=500, + ) + + +@router.post("/config/source", dependencies=[Depends(require_auth)]) +async def save_config_source(request: Request): + """保存 .env 源文件并热重载。""" + try: + form_data = await request.form() + await save_source_config( + str(form_data.get("env_content", "")), + reload_callback=reload_settings, + ) + logger.info("✅ 配置源文件已保存") + return _with_hx_trigger( + _build_alert( + ".env 源文件已保存并热重载,页面即将刷新。", + title="保存成功!", + level="success", + ), + "admin-config-refresh", + ) + except ValueError as exc: + return _build_alert( + str(exc), + title="源文件校验失败!", + level="error", + status_code=400, + ) + except Exception as exc: + logger.error(f"❌ 源文件保存失败: {exc}") + return _build_alert( + f"源文件保存失败: {exc}", + title="错误!", + level="error", + status_code=500, + ) + + +@router.post("/config/reset", dependencies=[Depends(require_auth)]) +async def reset_config(): + """将配置重置为 .env.example 并热重载。""" + try: + await reset_env_to_example(reload_callback=reload_settings) + logger.info("✅ 配置已重置为 .env.example 默认值") + return _with_hx_trigger( + _build_alert( + "配置已恢复为 .env.example 默认值,页面即将刷新。", + title="已重置!", + level="success", + ), + "admin-config-refresh", + ) + except FileNotFoundError: + logger.error("❌ 未找到 .env.example,无法重置配置") + return _build_alert( + "未找到 .env.example,无法重置配置。", + title="错误!", + level="error", + status_code=404, + ) + except Exception as exc: + logger.error(f"❌ 配置重置失败: {exc}") + return _build_alert( + f"重置失败: {exc}", + title="错误!", + level="error", + status_code=500, + ) + + +@router.get("/env-preview", dependencies=[Depends(require_auth)]) +async def get_env_preview(): + """获取 .env 文件预览""" + try: + content = read_env_content() + if not content: + content = "# .env 文件不存在" + return HTMLResponse(f"
{escape(content)}
") + except Exception as exc: + return HTMLResponse(f"
# 读取失败: {escape(str(exc))}
") + + +@router.get("/live-logs", response_class=HTMLResponse) +async def get_live_logs(): + """获取实时日志(最新 50 行)""" + import os + from datetime import datetime + + logs = [] + + # 尝试读取日志文件 + log_dir = "logs" + if os.path.exists(log_dir): + log_files = sorted([f for f in os.listdir(log_dir) if f.endswith('.log')], reverse=True) + if log_files: + log_file = os.path.join(log_dir, log_files[0]) + try: + with open(log_file, 'r', encoding='utf-8') as f: + # 读取最后 50 行 + lines = f.readlines()[-50:] + logs = lines + except Exception as e: + logs = [f"# [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 读取日志失败: {str(e)}"] + + if not logs: + logs = [f"# [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 暂无日志数据"] + + html = "" + for log in logs: + log_line = log.strip() + if not log_line: + continue + + # 根据日志级别设置颜色和样式 + if "ERROR" in log_line or "CRITICAL" in log_line: + color_class = "text-red-400 font-semibold" + icon = "❌" + elif "WARNING" in log_line or "WARN" in log_line: + color_class = "text-yellow-400" + icon = "⚠️" + elif "SUCCESS" in log_line or "✅" in log_line: + color_class = "text-green-400" + icon = "✅" + elif "INFO" in log_line: + color_class = "text-blue-400" + icon = "ℹ️" + elif "DEBUG" in log_line: + color_class = "text-gray-400 text-xs" + icon = "🔍" + else: + color_class = "text-gray-300" + icon = "•" + + # 转义 HTML 特殊字符 + log_escaped = log_line.replace('<', '<').replace('>', '>') + + html += f'
{icon} {log_escaped}
' + + return HTMLResponse(html) + + +# ==================== Token 管理 API ==================== + +@router.get("/tokens/list", response_class=HTMLResponse) +async def get_tokens_list(request: Request): + """获取 Token 列表(HTML 片段)""" + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + page_size = _get_int_query_param( + request, + "page_size", + 20, + maximum=100, + ) + requested_page = _get_int_query_param(request, "page", 1, maximum=100000) + total_count = await dao.count_tokens_by_provider( + DEFAULT_TOKEN_NAMESPACE, + enabled_only=False, + ) + pagination = _build_pagination( + total_items=total_count, + page=requested_page, + page_size=page_size, + ) + tokens = await dao.get_tokens_by_provider( + DEFAULT_TOKEN_NAMESPACE, + enabled_only=False, + limit=page_size, + offset=(pagination["current_page"] - 1) * page_size, + ) + + context = { + "request": request, + "tokens": tokens, + "page": pagination, + } + + return templates.TemplateResponse("components/token_list.html", context) + + +@router.post("/tokens/add") +async def add_tokens(request: Request): + """添加 Token""" + from app.services.token_dao import get_token_dao + from app.utils.token_pool import get_token_pool + + form_data = await request.form() + single_token = form_data.get("single_token", "").strip() + bulk_tokens = form_data.get("bulk_tokens", "").strip() + + dao = get_token_dao() + added_count = 0 + failed_count = 0 + + # 添加单个 Token(带验证) + if single_token: + token_id = await dao.add_token( + DEFAULT_TOKEN_NAMESPACE, + single_token, + validate=True, + ) + if token_id: + added_count += 1 + else: + failed_count += 1 + + # 批量添加 Token(带验证) + if bulk_tokens: + # 支持换行和逗号分隔 + tokens = [] + for line in bulk_tokens.split('\n'): + line = line.strip() + if ',' in line: + tokens.extend([t.strip() for t in line.split(',') if t.strip()]) + elif line: + tokens.append(line) + + success, failed = await dao.bulk_add_tokens( + DEFAULT_TOKEN_NAMESPACE, + tokens, + validate=True, + ) + added_count += success + failed_count += failed + + # 同步 Token 池状态(如果有新增成功的 Token) + if added_count > 0: + pool = get_token_pool() + if pool: + await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE) + logger.info(f"✅ Token 池已同步,新增 {added_count} 个 Token") + + # 生成响应 + if added_count > 0 and failed_count == 0: + return HTMLResponse(f""" + + """) + elif added_count > 0 and failed_count > 0: + return HTMLResponse(f""" + + """) + else: + return HTMLResponse(""" + + """) + + +@router.post("/tokens/import-directory", dependencies=[Depends(require_auth)]) +async def import_tokens_from_directory_api(request: Request): + """从本地目录导入 token 文件。""" + from app.core.config import settings + from app.services.token_automation import run_directory_import + + form_data = await request.form() + source_dir = str( + form_data.get("source_dir") + or settings.TOKEN_AUTO_IMPORT_SOURCE_DIR + or "" + ).strip() + try: + source_dir = _validate_directory_path(source_dir) + except ValueError as exc: + return _build_alert( + str(exc), + title="导入失败!", + level="error", + status_code=400, + ) + + try: + summary = await run_directory_import( + source_dir, + provider=DEFAULT_TOKEN_NAMESPACE, + validate=True, + ) + except (FileNotFoundError, NotADirectoryError) as exc: + return _build_alert( + str(exc), + title="导入失败!", + level="error", + status_code=400, + ) + except RuntimeError as exc: + return _build_alert( + str(exc), + title="导入稍后重试", + level="warning", + status_code=409, + ) + except Exception as exc: + logger.exception(f"❌ 本地目录导入 Token 失败: {exc}") + return _build_alert( + f"目录扫描或入库异常: {exc}", + title="导入失败!", + level="error", + status_code=500, + ) + + if summary.imported_count > 0: + title = "导入成功!" if summary.failed_count == 0 else "导入完成!" + detail = ( + f"目录 {summary.source_dir} 共扫描 {summary.scanned_files} 个文件," + f"成功导入 {summary.imported_count} 个 Token," + f"重复 {summary.duplicate_count} 个," + f"无效 JSON {summary.invalid_json_count} 个," + f"缺少 token {summary.missing_token_count} 个," + f"验证失败 {summary.invalid_token_count} 个。" + ) + return _build_alert( + detail, + title=title, + level="success" if summary.failed_count == 0 else "warning", + ) + + return _build_alert( + ( + f"目录 {summary.source_dir} 共扫描 {summary.scanned_files} 个文件," + f"其中重复 {summary.duplicate_count} 个,无效 JSON {summary.invalid_json_count} 个," + f"缺少 token {summary.missing_token_count} 个,验证失败 {summary.invalid_token_count} 个。" + ), + title="未导入任何 Token!", + level="warning", + ) + + +@router.post("/tokens/auto-import/save", dependencies=[Depends(require_auth)]) +async def save_auto_import_settings(request: Request): + """兼容旧入口,提示用户改到配置管理页。""" + return _build_alert( + "自动导入配置入口已迁移到 /admin/config#tokens,当前页面仅保留手动执行入口。", + title="入口已迁移", + level="info", + ) + + +@router.post("/tokens/maintenance/save", dependencies=[Depends(require_auth)]) +async def save_auto_maintenance_settings(request: Request): + """兼容旧入口,提示用户改到配置管理页。""" + return _build_alert( + "自动维护配置入口已迁移到 /admin/config#tokens,当前页面仅保留手动执行入口。", + title="入口已迁移", + level="info", + ) + + +@router.post("/tokens/maintenance/run", dependencies=[Depends(require_auth)]) +async def run_token_maintenance_api(request: Request): + """立即执行一次 Token 维护。""" + from app.core.config import settings + from app.services.token_automation import run_token_maintenance + + form_data = await request.form() + action_fields = ( + "auto_remove_duplicates", + "auto_health_check", + "auto_delete_invalid", + ) + has_explicit_actions = any(field in form_data for field in action_fields) + + if has_explicit_actions: + remove_duplicates = "auto_remove_duplicates" in form_data + run_health_check = "auto_health_check" in form_data + delete_invalid = "auto_delete_invalid" in form_data + else: + remove_duplicates = settings.TOKEN_AUTO_REMOVE_DUPLICATES + run_health_check = settings.TOKEN_AUTO_HEALTH_CHECK + delete_invalid = settings.TOKEN_AUTO_DELETE_INVALID + + if not any((remove_duplicates, run_health_check, delete_invalid)): + return _build_alert( + "当前没有可执行的维护动作,请先到 /admin/config#tokens 配置至少一个维护动作。", + title="未执行维护!", + level="warning", + status_code=400, + ) + + try: + summary = await run_token_maintenance( + provider=DEFAULT_TOKEN_NAMESPACE, + remove_duplicates=remove_duplicates, + run_health_check=run_health_check, + delete_invalid_tokens=delete_invalid, + ) + except RuntimeError as exc: + return _build_alert( + str(exc), + title="维护稍后重试", + level="warning", + status_code=409, + ) + except Exception as exc: + logger.exception(f"❌ 手动执行 Token 维护失败: {exc}") + return _build_alert( + f"Token 维护失败: {exc}", + title="维护失败!", + level="error", + status_code=500, + ) + + return _build_alert( + ( + f"本次维护共去重 {summary.duplicate_removed_count} 个," + f"测活 {summary.checked_count} 个(有效 {summary.valid_count} / " + f"匿名 {summary.guest_count} / 无效 {summary.invalid_count})," + f"删除失效 Token {summary.deleted_invalid_count} 个。" + ), + title="维护完成!", + level="success", + ) + + +@router.post("/tokens/toggle/{token_id}") +async def toggle_token(token_id: int, enabled: bool): + """切换 Token 启用状态""" + from app.services.token_dao import get_token_dao + from app.utils.token_pool import get_token_pool + + dao = get_token_dao() + await dao.update_token_status(token_id, enabled) + + # 同步 Token 池状态 + pool = get_token_pool() + if pool: + # 获取 Token 的提供商信息 + async with dao.get_connection() as conn: + cursor = await conn.execute("SELECT provider FROM tokens WHERE id = ?", (token_id,)) + row = await cursor.fetchone() + if row: + provider = row[0] + await pool.sync_from_database(provider) + logger.info("✅ Token 池已同步") + + # 根据状态返回不同样式的按钮 + if enabled: + button_class = "bg-green-100 text-green-800 hover:bg-green-200" + indicator_class = "bg-green-500" + label = "已启用" + next_state = "false" + else: + button_class = "bg-red-100 text-red-800 hover:bg-red-200" + indicator_class = "bg-red-500" + label = "已禁用" + next_state = "true" + + return HTMLResponse(f""" + + """) + + +@router.delete("/tokens/delete/{token_id}") +async def delete_token(token_id: int): + """删除 Token""" + from app.services.token_dao import get_token_dao + from app.utils.token_pool import get_token_pool + + dao = get_token_dao() + + # 获取 Token 信息以确定提供商 + async with dao.get_connection() as conn: + cursor = await conn.execute("SELECT provider FROM tokens WHERE id = ?", (token_id,)) + row = await cursor.fetchone() + provider = row[0] if row else "zai" + + await dao.delete_token(token_id) + + # 同步 Token 池状态 + pool = get_token_pool() + if pool: + await pool.sync_from_database(provider) + logger.info("✅ Token 池已同步") + + return HTMLResponse("") # 返回空内容,让 htmx 移除元素 + + +@router.get("/tokens/stats", response_class=HTMLResponse) +async def get_tokens_stats(request: Request): + """获取 Token 统计信息(HTML 片段)""" + stats_data = await collect_admin_stats(DEFAULT_TOKEN_NAMESPACE) + + context = { + "request": request, + "stats": stats_data, + } + + return templates.TemplateResponse("components/token_stats.html", context) + + +@router.post("/tokens/validate") +async def validate_tokens(): + """批量验证 Token""" + from app.services.token_dao import get_token_dao + from app.utils.token_pool import get_token_pool + + dao = get_token_dao() + + # 执行批量验证 + stats = await dao.validate_all_tokens(DEFAULT_TOKEN_NAMESPACE) + + pool = get_token_pool() + if pool: + await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE) + + valid_count = stats.get("valid", 0) + guest_count = stats.get("guest", 0) + invalid_count = stats.get("invalid", 0) + + # 生成通知消息 + if guest_count > 0: + message_class = "bg-yellow-100 border-yellow-400 text-yellow-700" + message = f"验证完成:有效 {valid_count} 个,匿名 {guest_count} 个,无效 {invalid_count} 个。匿名 Token 已标记。" + elif invalid_count > 0: + message_class = "bg-blue-100 border-blue-400 text-blue-700" + message = f"验证完成:有效 {valid_count} 个,无效 {invalid_count} 个。" + else: + message_class = "bg-green-100 border-green-400 text-green-700" + message = f"验证完成:所有 {valid_count} 个 Token 均有效!" + + return HTMLResponse(f""" + + """) + + +@router.post("/tokens/validate-single/{token_id}") +async def validate_single_token(request: Request, token_id: int): + """验证单个 Token 并返回更新后的行""" + from app.services.token_dao import get_token_dao + from app.utils.token_pool import get_token_pool + + dao = get_token_dao() + + # 验证 Token + await dao.validate_and_update_token(token_id) + + pool = get_token_pool() + if pool: + await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE) + + # 获取更新后的 Token 信息 + async with dao.get_connection() as conn: + cursor = await conn.execute(""" + SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests, + ts.last_success_time, ts.last_failure_time + FROM tokens t + LEFT JOIN token_stats ts ON t.id = ts.token_id + WHERE t.id = ? + """, (token_id,)) + row = await cursor.fetchone() + + if row: + # 返回更新后的单行 HTML + token = dict(row) + context = { + "request": request, + "token": token, + } + # 使用单行模板渲染 + return templates.TemplateResponse("components/token_row.html", context) + else: + return HTMLResponse("") + + +@router.post("/tokens/health-check") +async def health_check_tokens(): + """执行 Token 池健康检查""" + from app.utils.token_pool import get_token_pool + + pool = get_token_pool() + + if not pool: + return HTMLResponse(""" + + """) + + # 执行健康检查 + await pool.health_check_all() + + # 获取健康状态 + status = pool.get_pool_status() + healthy_count = status.get("healthy_tokens", 0) + total_count = status.get("total_tokens", 0) + + if healthy_count == total_count: + message_class = "bg-green-100 border-green-400 text-green-700" + message = f"所有 {total_count} 个 Token 均健康!" + elif healthy_count > 0: + message_class = "bg-blue-100 border-blue-400 text-blue-700" + message = f"健康检查完成:{healthy_count}/{total_count} 个 Token 健康。" + else: + message_class = "bg-red-100 border-red-400 text-red-700" + message = f"警告:0/{total_count} 个 Token 健康,请检查配置。" + + return HTMLResponse(f""" + + """) + + +@router.post("/tokens/sync-pool") +async def sync_token_pool(): + """手动同步 Token 池(从数据库重新加载)""" + from app.utils.token_pool import get_token_pool + + pool = get_token_pool() + + if not pool: + return HTMLResponse(""" + + """) + + # 从数据库同步 + await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE) + + # 获取同步后的状态 + status = pool.get_pool_status() + total_count = status.get("total_tokens", 0) + available_count = status.get("available_tokens", 0) + user_count = status.get("user_tokens", 0) + + logger.info( + f"✅ Token 池手动同步完成,总计 {total_count} 个 Token, 可用 {available_count} 个, 认证用户 {user_count} 个" + ) + + if total_count == 0: + message_class = "bg-yellow-100 border-yellow-400 text-yellow-700" + message = "同步完成:当前没有可用 Token,请在数据库中启用 Token。" + elif available_count == 0: + message_class = "bg-orange-100 border-orange-400 text-orange-700" + message = f"同步完成:共 {total_count} 个 Token,但无可用 Token(可能都已禁用)。" + else: + message_class = "bg-green-100 border-green-400 text-green-700" + message = f"同步完成:共 {total_count} 个 Token,{available_count} 个可用,{user_count} 个认证用户。" + + return HTMLResponse(f""" + + """) diff --git a/app/admin/auth.py b/app/admin/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..654cae372fc6425ce820b0296e4b923574d1dae6 --- /dev/null +++ b/app/admin/auth.py @@ -0,0 +1,129 @@ +""" +管理后台认证中间件 +""" +from fastapi import Request, HTTPException, status +from fastapi.responses import RedirectResponse +from typing import Optional +import hashlib +import secrets +from datetime import datetime, timedelta + +from app.core.config import settings + +# 简单的内存 Session 存储(生产环境建议使用 Redis) +_sessions = {} + +# Session 有效期(小时) +SESSION_EXPIRE_HOURS = 24 + + +def generate_session_token() -> str: + """生成随机 session token""" + return secrets.token_urlsafe(32) + + +def create_session(password: str) -> Optional[str]: + """ + 创建 session + + Args: + password: 用户输入的密码 + + Returns: + session_token 或 None(密码错误) + """ + # 验证密码 + if password != settings.ADMIN_PASSWORD: + return None + + # 生成 session token + session_token = generate_session_token() + + # 存储 session(包含过期时间) + _sessions[session_token] = { + "created_at": datetime.now(), + "expires_at": datetime.now() + timedelta(hours=SESSION_EXPIRE_HOURS), + "authenticated": True + } + + return session_token + + +def verify_session(session_token: Optional[str]) -> bool: + """ + 验证 session 是否有效 + + Args: + session_token: Session token + + Returns: + 是否已认证 + """ + if not session_token: + return False + + session = _sessions.get(session_token) + if not session: + return False + + # 检查是否过期 + if datetime.now() > session["expires_at"]: + # 删除过期 session + del _sessions[session_token] + return False + + return session.get("authenticated", False) + + +def delete_session(session_token: Optional[str]): + """删除 session(登出)""" + if session_token and session_token in _sessions: + del _sessions[session_token] + + +def get_session_token_from_request(request: Request) -> Optional[str]: + """从请求中获取 session token""" + return request.cookies.get("admin_session") + + +async def require_auth(request: Request): + """ + 认证依赖项:要求用户已登录 + + 在路由中使用: + @router.get("/admin", dependencies=[Depends(require_auth)]) + """ + session_token = get_session_token_from_request(request) + + if not verify_session(session_token): + # 未认证,重定向到登录页 + raise HTTPException( + status_code=status.HTTP_303_SEE_OTHER, + detail="未登录", + headers={"Location": "/admin/login"} + ) + + +def get_authenticated_user(request: Request) -> bool: + """ + 获取当前认证状态(用于模板) + + Returns: + 是否已认证 + """ + session_token = get_session_token_from_request(request) + return verify_session(session_token) + + +def cleanup_expired_sessions(): + """清理过期的 session(定时任务调用)""" + now = datetime.now() + expired_tokens = [ + token for token, session in _sessions.items() + if now > session["expires_at"] + ] + + for token in expired_tokens: + del _sessions[token] + + return len(expired_tokens) diff --git a/app/admin/config_manager.py b/app/admin/config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0b14123c3c5bd4c8b68dad1f247fb204b2fd051c --- /dev/null +++ b/app/admin/config_manager.py @@ -0,0 +1,682 @@ +"""Admin config metadata and helpers for the configuration console.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Awaitable, Callable, Mapping + +from dotenv import dotenv_values + +from app.core.config import settings +from app.utils.env_file import update_env_file +from app.utils.logger import logger + +ENV_PATH = Path(".env") +ENV_EXAMPLE_PATH = Path(".env.example") +_ENV_SOURCE_LINE_PATTERN = re.compile( + r"^\s*(?:export\s+)?[A-Za-z_][A-Za-z0-9_]*\s*=.*$" +) + + +@dataclass(frozen=True) +class ConfigFieldSpec: + key: str + label: str + description: str + value_type: str + default_value: object + input_type: str = "text" + placeholder: str = "" + required: bool = False + wide: bool = False + sensitive: bool = False + restart_required: bool = False + min_value: int | None = None + max_value: int | None = None + + +@dataclass(frozen=True) +class ConfigSectionSpec: + id: str + title: str + description: str + fields: tuple[ConfigFieldSpec, ...] + + +CONFIG_SECTIONS: tuple[ConfigSectionSpec, ...] = ( + ConfigSectionSpec( + id="access", + title="接入与认证", + description="控制上游接口地址、客户端鉴权和 Function Call 行为。", + fields=( + ConfigFieldSpec( + key="API_ENDPOINT", + label="上游 API 地址", + description="代理请求实际转发到的上游聊天完成接口。", + value_type="str", + default_value="https://chat.z.ai/api/v2/chat/completions", + input_type="url", + placeholder="https://chat.z.ai/api/v2/chat/completions", + required=True, + wide=True, + ), + ConfigFieldSpec( + key="AUTH_TOKEN", + label="客户端认证密钥", + description="客户端访问本服务时使用的 Bearer Token。", + value_type="str", + default_value="sk-your-api-key", + input_type="password", + placeholder="sk-your-api-key", + wide=True, + sensitive=True, + ), + ConfigFieldSpec( + key="SKIP_AUTH_TOKEN", + label="跳过客户端认证", + description="仅建议开发环境使用,开启后不校验 AUTH_TOKEN。", + value_type="bool", + default_value=False, + ), + ConfigFieldSpec( + key="TOOL_SUPPORT", + label="启用 Function Call", + description="允许 OpenAI 兼容接口使用工具调用能力。", + value_type="bool", + default_value=True, + ), + ConfigFieldSpec( + key="SCAN_LIMIT", + label="工具调用扫描限制", + description="Function Call 扫描的最大字符数。", + value_type="int", + default_value=200000, + input_type="number", + min_value=1, + placeholder="200000", + ), + ), + ), + ConfigSectionSpec( + id="server", + title="服务运行", + description="服务监听、日志、数据库路径和反向代理前缀。", + fields=( + ConfigFieldSpec( + key="SERVICE_NAME", + label="服务名称", + description="显示在进程列表中的服务名称。", + value_type="str", + default_value="api-proxy-server", + placeholder="api-proxy-server", + required=True, + restart_required=True, + ), + ConfigFieldSpec( + key="LISTEN_PORT", + label="监听端口", + description="HTTP 服务监听端口。", + value_type="int", + default_value=8080, + input_type="number", + min_value=1, + max_value=65535, + required=True, + restart_required=True, + placeholder="8080", + ), + ConfigFieldSpec( + key="ROOT_PATH", + label="反向代理路径前缀", + description="例如 /api,部署在子路径时使用。", + value_type="str", + default_value="", + placeholder="/api", + restart_required=True, + ), + ConfigFieldSpec( + key="DEBUG_LOGGING", + label="启用调试日志", + description="开启后会输出更详细的调试信息。", + value_type="bool", + default_value=False, + ), + ConfigFieldSpec( + key="DB_PATH", + label="数据库路径", + description="SQLite 数据库文件位置。", + value_type="str", + default_value="tokens.db", + placeholder="tokens.db", + required=True, + wide=True, + restart_required=True, + ), + ), + ), + ConfigSectionSpec( + id="tokens", + title="Token 池策略", + description="失败判定、恢复时间和自动导入、自动维护计划任务。", + fields=( + ConfigFieldSpec( + key="TOKEN_FAILURE_THRESHOLD", + label="失败阈值", + description="连续失败多少次后将 Token 标记为不可用。", + value_type="int", + default_value=3, + input_type="number", + min_value=1, + required=True, + restart_required=True, + ), + ConfigFieldSpec( + key="TOKEN_RECOVERY_TIMEOUT", + label="恢复超时(秒)", + description="失败 Token 重新参与调度前的等待时间。", + value_type="int", + default_value=1800, + input_type="number", + min_value=1, + required=True, + restart_required=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_IMPORT_ENABLED", + label="启用自动导入", + description="按固定周期扫描服务端目录并导入 Token。", + value_type="bool", + default_value=False, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_IMPORT_SOURCE_DIR", + label="自动导入目录", + description="服务端本地目录,开启自动导入时需要可访问。", + value_type="str", + default_value="", + placeholder="E:\\tokens\\input", + wide=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_IMPORT_INTERVAL", + label="自动导入间隔(秒)", + description="自动导入的扫描周期。", + value_type="int", + default_value=300, + input_type="number", + min_value=1, + required=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_MAINTENANCE_ENABLED", + label="启用自动维护", + description="定时执行去重、健康检查和删除失效 Token。", + value_type="bool", + default_value=False, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_MAINTENANCE_INTERVAL", + label="自动维护间隔(秒)", + description="自动维护的执行周期。", + value_type="int", + default_value=1800, + input_type="number", + min_value=1, + required=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_REMOVE_DUPLICATES", + label="自动去重", + description="自动维护时清理重复 Token。", + value_type="bool", + default_value=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_HEALTH_CHECK", + label="自动健康检查", + description="自动维护时验证 Token 可用性。", + value_type="bool", + default_value=True, + ), + ConfigFieldSpec( + key="TOKEN_AUTO_DELETE_INVALID", + label="自动删除失效 Token", + description="自动维护时移除已验证为无效的 Token。", + value_type="bool", + default_value=False, + ), + ), + ), + ConfigSectionSpec( + id="guest", + title="匿名 Guest 会话池", + description="没有用户 Token 时,仅控制是否启用匿名池和池容量。", + fields=( + ConfigFieldSpec( + key="ANONYMOUS_MODE", + label="启用匿名模式", + description="无可用用户 Token 时允许使用匿名会话。", + value_type="bool", + default_value=True, + restart_required=True, + ), + ConfigFieldSpec( + key="GUEST_POOL_SIZE", + label="Guest 池容量", + description="启动和维持的 guest 会话数量。", + value_type="int", + default_value=3, + input_type="number", + min_value=1, + required=True, + restart_required=True, + ), + ), + ), + ConfigSectionSpec( + id="models", + title="模型映射", + description="映射 OpenAI 兼容模型名到上游 Z.AI 实际模型名。", + fields=( + ConfigFieldSpec( + key="GLM45_MODEL", + label="GLM 4.5", + description="标准 GLM 4.5 模型标识。", + value_type="str", + default_value="GLM-4.5", + placeholder="GLM-4.5", + required=True, + ), + ConfigFieldSpec( + key="GLM45_THINKING_MODEL", + label="GLM 4.5 Thinking", + description="推理增强版 GLM 4.5 模型标识。", + value_type="str", + default_value="GLM-4.5-Thinking", + placeholder="GLM-4.5-Thinking", + required=True, + ), + ConfigFieldSpec( + key="GLM45_SEARCH_MODEL", + label="GLM 4.5 Search", + description="搜索增强版 GLM 4.5 模型标识。", + value_type="str", + default_value="GLM-4.5-Search", + placeholder="GLM-4.5-Search", + required=True, + ), + ConfigFieldSpec( + key="GLM45_AIR_MODEL", + label="GLM 4.5 Air", + description="轻量版 GLM 4.5 模型标识。", + value_type="str", + default_value="GLM-4.5-Air", + placeholder="GLM-4.5-Air", + required=True, + ), + ConfigFieldSpec( + key="GLM46V_MODEL", + label="GLM 4.6V", + description="视觉模型标识。", + value_type="str", + default_value="GLM-4.6V", + placeholder="GLM-4.6V", + required=True, + ), + ConfigFieldSpec( + key="GLM5_MODEL", + label="GLM 5", + description="GLM 5 模型标识。", + value_type="str", + default_value="GLM-5", + placeholder="GLM-5", + required=True, + ), + ConfigFieldSpec( + key="GLM47_MODEL", + label="GLM 4.7", + description="GLM 4.7 主模型标识。", + value_type="str", + default_value="GLM-4.7", + placeholder="GLM-4.7", + required=True, + ), + ConfigFieldSpec( + key="GLM47_THINKING_MODEL", + label="GLM 4.7 Thinking", + description="GLM 4.7 推理版模型标识。", + value_type="str", + default_value="GLM-4.7-Thinking", + placeholder="GLM-4.7-Thinking", + required=True, + ), + ConfigFieldSpec( + key="GLM47_SEARCH_MODEL", + label="GLM 4.7 Search", + description="GLM 4.7 搜索版模型标识。", + value_type="str", + default_value="GLM-4.7-Search", + placeholder="GLM-4.7-Search", + required=True, + ), + ConfigFieldSpec( + key="GLM47_ADVANCED_SEARCH_MODEL", + label="GLM 4.7 Advanced Search", + description="GLM 4.7 高级搜索模型标识。", + value_type="str", + default_value="GLM-4.7-advanced-search", + placeholder="GLM-4.7-advanced-search", + required=True, + wide=True, + ), + ), + ), + ConfigSectionSpec( + id="proxy", + title="代理网络", + description="上游访问使用的 HTTP、HTTPS 和 SOCKS5 代理。", + fields=( + ConfigFieldSpec( + key="HTTP_PROXY", + label="HTTP 代理", + description="例如 http://127.0.0.1:7890。", + value_type="str", + default_value="", + placeholder="http://127.0.0.1:7890", + wide=True, + ), + ConfigFieldSpec( + key="HTTPS_PROXY", + label="HTTPS 代理", + description="例如 http://127.0.0.1:7890。", + value_type="str", + default_value="", + placeholder="http://127.0.0.1:7890", + wide=True, + ), + ConfigFieldSpec( + key="SOCKS5_PROXY", + label="SOCKS5 代理", + description="例如 socks5://127.0.0.1:1080。", + value_type="str", + default_value="", + placeholder="socks5://127.0.0.1:1080", + wide=True, + ), + ), + ), + ConfigSectionSpec( + id="admin", + title="后台安全", + description="管理后台密码和会话密钥。修改后建议重新登录。", + fields=( + ConfigFieldSpec( + key="ADMIN_PASSWORD", + label="后台密码", + description="管理后台登录密码。", + value_type="str", + default_value="admin123", + input_type="password", + placeholder="admin123", + required=True, + sensitive=True, + ), + ConfigFieldSpec( + key="SESSION_SECRET_KEY", + label="会话密钥", + description="用于后台会话签名的密钥。", + value_type="str", + default_value="your-secret-key-change-in-production", + input_type="password", + placeholder="your-secret-key-change-in-production", + required=True, + sensitive=True, + wide=True, + ), + ), + ), +) + +CONFIG_FIELD_SPECS = { + field.key: field + for section in CONFIG_SECTIONS + for field in section.fields +} +MANAGED_ENV_KEYS = tuple(CONFIG_FIELD_SPECS.keys()) +ReloadCallback = Callable[[], Awaitable[None]] + + +def read_env_content(env_path: str | Path = ENV_PATH) -> str: + path = Path(env_path) + if not path.exists(): + return "" + return path.read_text(encoding="utf-8") + + +def validate_env_source(content: str) -> str: + normalized = content.replace("\r\n", "\n").replace("\r", "\n") + + for line_number, line in enumerate(normalized.splitlines(), start=1): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if not _ENV_SOURCE_LINE_PATTERN.match(line): + raise ValueError( + f"第 {line_number} 行不是合法的 KEY=VALUE 格式。" + ) + + return normalized + + +def build_config_page_data( + *, + settings_obj: Any = settings, + env_path: str | Path = ENV_PATH, + env_example_path: str | Path = ENV_EXAMPLE_PATH, +) -> dict[str, Any]: + env_file = Path(env_path) + env_content = read_env_content(env_file) + env_values = dotenv_values(env_file) if env_file.exists() else {} + sections: list[dict[str, Any]] = [] + total_fields = 0 + overridden_fields = 0 + sensitive_fields = 0 + restart_required_fields = 0 + + for section in CONFIG_SECTIONS: + rendered_fields: list[dict[str, Any]] = [] + for field in section.fields: + total_fields += 1 + if field.sensitive: + sensitive_fields += 1 + if field.restart_required: + restart_required_fields += 1 + + is_overridden = field.key in env_values + if is_overridden: + overridden_fields += 1 + + value = getattr(settings_obj, field.key, field.default_value) + if value is None: + value = "" + + rendered_fields.append( + { + "key": field.key, + "label": field.label, + "description": field.description, + "value_type": field.value_type, + "value": value, + "input_type": field.input_type, + "placeholder": field.placeholder, + "required": field.required, + "wide": field.wide, + "sensitive": field.sensitive, + "restart_required": field.restart_required, + "min_value": field.min_value, + "max_value": field.max_value, + "source_label": ".env" if is_overridden else "默认值", + "source_badge_class": ( + "bg-emerald-50 text-emerald-700 ring-emerald-200" + if is_overridden + else "bg-slate-100 text-slate-600 ring-slate-200" + ), + } + ) + + sections.append( + { + "id": section.id, + "title": section.title, + "description": section.description, + "fields": rendered_fields, + "field_count": len(rendered_fields), + } + ) + + return { + "sections": sections, + "env_content": env_content, + "overview": { + "total_sections": len(CONFIG_SECTIONS), + "total_fields": total_fields, + "overridden_fields": overridden_fields, + "default_fields": total_fields - overridden_fields, + "sensitive_fields": sensitive_fields, + "restart_required_fields": restart_required_fields, + "env_exists": env_file.exists(), + "env_path": str(env_file.resolve()), + "env_line_count": len(env_content.splitlines()) if env_content else 0, + "example_exists": Path(env_example_path).exists(), + }, + } + + +def build_form_updates(form_data: Mapping[str, Any]) -> dict[str, object]: + updates: dict[str, object] = {} + + for key in MANAGED_ENV_KEYS: + field = CONFIG_FIELD_SPECS[key] + + if field.value_type == "bool": + updates[key] = key in form_data + continue + + raw_value = str(form_data.get(key, "") or "").strip() + if field.required and raw_value == "": + raise ValueError(f"{field.label} 不能为空。") + + if field.value_type == "int": + try: + parsed = int(raw_value) + except ValueError as exc: + raise ValueError(f"{field.label} 必须是整数。") from exc + + if field.min_value is not None and parsed < field.min_value: + raise ValueError( + f"{field.label} 不能小于 {field.min_value}。" + ) + if field.max_value is not None and parsed > field.max_value: + raise ValueError( + f"{field.label} 不能大于 {field.max_value}。" + ) + updates[key] = parsed + continue + + updates[key] = raw_value + + return updates + + +async def _apply_env_change( + writer: Callable[[Path], None], + *, + reload_callback: ReloadCallback, + env_path: str | Path = ENV_PATH, +) -> None: + path = Path(env_path) + had_existing_file = path.exists() + previous_content = read_env_content(path) if had_existing_file else "" + + try: + writer(path) + await reload_callback() + except Exception: + if had_existing_file: + path.write_text(previous_content, encoding="utf-8") + elif path.exists(): + path.unlink() + + try: + await reload_callback() + except Exception as restore_exc: + logger.warning(f"⚠️ 回滚配置后重新加载失败: {restore_exc}") + raise + + +async def save_form_config( + form_data: Mapping[str, Any], + *, + reload_callback: ReloadCallback, + env_path: str | Path = ENV_PATH, +) -> dict[str, object]: + updates = build_form_updates(form_data) + + async def _reload() -> None: + await reload_callback() + + def _writer(target_path: Path) -> None: + update_env_file(updates, env_path=target_path) + + await _apply_env_change(_writer, reload_callback=_reload, env_path=env_path) + return updates + + +async def save_source_config( + env_content: str, + *, + reload_callback: ReloadCallback, + env_path: str | Path = ENV_PATH, +) -> None: + normalized = validate_env_source(env_content) + + def _writer(target_path: Path) -> None: + content = normalized.rstrip("\n") + target_path.write_text( + f"{content}\n" if content else "", + encoding="utf-8", + ) + + await _apply_env_change( + _writer, + reload_callback=reload_callback, + env_path=env_path, + ) + + +async def reset_env_to_example( + *, + reload_callback: ReloadCallback, + env_path: str | Path = ENV_PATH, + env_example_path: str | Path = ENV_EXAMPLE_PATH, +) -> None: + example_path = Path(env_example_path) + if not example_path.exists(): + raise FileNotFoundError(".env.example 不存在") + + example_content = example_path.read_text(encoding="utf-8") + + def _writer(target_path: Path) -> None: + content = example_content.rstrip("\n") + target_path.write_text( + f"{content}\n" if content else "", + encoding="utf-8", + ) + + await _apply_env_change( + _writer, + reload_callback=reload_callback, + env_path=env_path, + ) diff --git a/app/admin/routes.py b/app/admin/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..98a0b2e8702f75fb951012cb360f0402f4a38bd5 --- /dev/null +++ b/app/admin/routes.py @@ -0,0 +1,109 @@ +""" +管理后台路由模块 +""" +from datetime import datetime + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates + +from app.admin.auth import require_auth +from app.admin.config_manager import build_config_page_data +from app.admin.stats import ( + DEFAULT_TREND_WINDOW, + TREND_WINDOW_OPTIONS, + collect_admin_stats, + get_process_uptime, +) + +router = APIRouter(prefix="/admin", tags=["admin"]) +templates = Jinja2Templates(directory="app/templates") +DEFAULT_TOKEN_NAMESPACE = "zai" + + +@router.get("/login", response_class=HTMLResponse) +async def login_page(request: Request): + """登录页面""" + return templates.TemplateResponse("login.html", {"request": request}) + + +@router.get("/", response_class=HTMLResponse, dependencies=[Depends(require_auth)]) +async def dashboard(request: Request): + """仪表盘首页""" + stats = await collect_admin_stats( + DEFAULT_TOKEN_NAMESPACE, + trend_window=DEFAULT_TREND_WINDOW, + ) + stats["uptime"] = get_process_uptime() + + context = { + "request": request, + "stats": stats, + "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "trend_windows": TREND_WINDOW_OPTIONS, + } + + return templates.TemplateResponse("index.html", context) + + +@router.get( + "/config", + response_class=HTMLResponse, + dependencies=[Depends(require_auth)], +) +async def config_page(request: Request): + """配置管理页面""" + page_data = build_config_page_data() + + context = { + "request": request, + "sections": page_data["sections"], + "env_content": page_data["env_content"], + "overview": page_data["overview"], + } + return templates.TemplateResponse("config.html", context) + + +@router.get("/logs", response_class=HTMLResponse, dependencies=[Depends(require_auth)]) +async def logs_page(request: Request): + """实时日志页面""" + context = { + "request": request, + } + return templates.TemplateResponse("logs.html", context) + + +@router.get( + "/tokens", + response_class=HTMLResponse, + dependencies=[Depends(require_auth)], +) +async def tokens_page(request: Request): + """Token 管理页面""" + from app.core.config import settings + + maintenance_actions: list[str] = [] + if settings.TOKEN_AUTO_REMOVE_DUPLICATES: + maintenance_actions.append("删除重复 Token") + if settings.TOKEN_AUTO_HEALTH_CHECK: + maintenance_actions.append("批量测活") + if settings.TOKEN_AUTO_DELETE_INVALID: + maintenance_actions.append("删除失效 Token") + + context = { + "request": request, + "automation": { + "config_url": "/admin/config#tokens", + "import_enabled": settings.TOKEN_AUTO_IMPORT_ENABLED, + "import_source_dir": settings.TOKEN_AUTO_IMPORT_SOURCE_DIR, + "import_interval": settings.TOKEN_AUTO_IMPORT_INTERVAL, + "has_import_source_dir": bool( + settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip() + ), + "maintenance_enabled": settings.TOKEN_AUTO_MAINTENANCE_ENABLED, + "maintenance_interval": settings.TOKEN_AUTO_MAINTENANCE_INTERVAL, + "maintenance_actions": maintenance_actions, + "has_maintenance_actions": bool(maintenance_actions), + }, + } + return templates.TemplateResponse("tokens.html", context) diff --git a/app/admin/stats.py b/app/admin/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..2299621f406564cd1fca2e1e70237e669bc8ebce --- /dev/null +++ b/app/admin/stats.py @@ -0,0 +1,184 @@ +"""管理后台统计聚合辅助函数。""" + +from __future__ import annotations + +import os +import time +from typing import Any, Dict, Optional + +import psutil + +from app.services.request_log_dao import RequestLogDAO, get_request_log_dao +from app.services.token_dao import TokenDAO, get_token_dao +from app.utils.token_pool import TokenPool, get_token_pool + +_TOKEN_POOL_SENTINEL = object() +DEFAULT_TREND_WINDOW = "7d" +TREND_WINDOW_OPTIONS = ( + {"key": "24h", "label": "24 小时"}, + {"key": "7d", "label": "7 天"}, + {"key": "30d", "label": "30 天"}, +) + + +def _coerce_int(value: Any) -> int: + """将数据库聚合结果安全转换为整数。""" + return int(value or 0) + + +def calculate_success_rate( + successful_requests: int, + total_requests: int, +) -> float: + """计算成功率百分比。""" + if total_requests <= 0: + return 0.0 + return round(successful_requests / total_requests * 100, 1) + + +def format_compact_number(value: Any) -> str: + """格式化大数字,便于仪表盘展示。""" + number = int(value or 0) + if number >= 1_000_000: + return f"{number / 1_000_000:.1f}M" + if number >= 10_000: + return f"{number / 10_000:.1f}万" + if number >= 1_000: + return f"{number / 1_000:.1f}k" + return str(number) + + +def normalize_trend_window(value: Any) -> str: + """规范化趋势窗口参数,非法值回退到默认值。""" + normalized = str(value or "").strip().lower() + if normalized in {"24h", "7d", "30d"}: + return normalized + if normalized == "1d": + return "24h" + return DEFAULT_TREND_WINDOW + + +def format_uptime(total_seconds: int) -> str: + """格式化运行时长。""" + total_seconds = max(0, int(total_seconds)) + days, remainder = divmod(total_seconds, 86400) + hours, remainder = divmod(remainder, 3600) + minutes, seconds = divmod(remainder, 60) + + parts = [] + if days: + parts.append(f"{days}天") + if days or hours: + parts.append(f"{hours}小时") + if days or hours or minutes: + parts.append(f"{minutes}分钟") + parts.append(f"{seconds}秒") + + return " ".join(parts) + + +def get_process_uptime() -> str: + """获取当前进程运行时长。""" + created_at = psutil.Process(os.getpid()).create_time() + return format_uptime(int(time.time() - created_at)) + + +async def collect_admin_stats( + provider: str, + *, + token_dao: Optional[TokenDAO] = None, + request_log_dao: Optional[RequestLogDAO] = None, + token_pool: Any = _TOKEN_POOL_SENTINEL, + trend_window: str = DEFAULT_TREND_WINDOW, +) -> Dict[str, Any]: + """聚合管理后台所需的 Token 与请求统计。""" + token_dao = token_dao or get_token_dao() + request_log_dao = request_log_dao or get_request_log_dao() + if token_pool is _TOKEN_POOL_SENTINEL: + token_pool = get_token_pool() + trend_window = normalize_trend_window(trend_window) + + token_counts = await token_dao.get_provider_token_counts(provider) + request_stats = await request_log_dao.get_provider_request_stats(provider) + usage_trend = await request_log_dao.get_provider_usage_trend( + provider, + window=trend_window, + ) + + pool_status: Dict[str, Any] = {} + if isinstance(token_pool, TokenPool) or hasattr(token_pool, "get_pool_status"): + pool_status = token_pool.get_pool_status() if token_pool else {} + + total_tokens = _coerce_int(token_counts.get("total_tokens")) + enabled_tokens = _coerce_int(token_counts.get("enabled_tokens")) + user_tokens = _coerce_int(token_counts.get("user_tokens")) + guest_tokens = _coerce_int(token_counts.get("guest_tokens")) + unknown_tokens = _coerce_int(token_counts.get("unknown_tokens")) + + pool_total_tokens = _coerce_int(pool_status.get("total_tokens")) + if pool_total_tokens == 0 and token_pool is None: + pool_total_tokens = max(0, enabled_tokens - guest_tokens) + + available_tokens = _coerce_int(pool_status.get("available_tokens")) + healthy_tokens = _coerce_int(pool_status.get("healthy_tokens")) + unhealthy_tokens = _coerce_int(pool_status.get("unhealthy_tokens")) + + total_requests = _coerce_int(request_stats.get("total_requests")) + successful_requests = _coerce_int(request_stats.get("successful_requests")) + failed_requests = _coerce_int(request_stats.get("failed_requests")) + input_tokens = _coerce_int(request_stats.get("input_tokens")) + output_tokens = _coerce_int(request_stats.get("output_tokens")) + total_consumed_tokens = _coerce_int(request_stats.get("total_tokens")) + cache_creation_tokens = _coerce_int( + request_stats.get("cache_creation_tokens") + ) + cache_read_tokens = _coerce_int(request_stats.get("cache_read_tokens")) + cache_creation_requests = _coerce_int( + request_stats.get("cache_creation_requests") + ) + cache_hit_requests = _coerce_int(request_stats.get("cache_hit_requests")) + average_latency = round(float(request_stats.get("avg_duration") or 0.0), 2) + average_first_token_latency = round( + float(request_stats.get("avg_first_token_time") or 0.0), + 2, + ) + total_cache_tokens = cache_creation_tokens + cache_read_tokens + + return { + "total_tokens": total_tokens, + "enabled_tokens": enabled_tokens, + "user_tokens": user_tokens, + "guest_tokens": guest_tokens, + "unknown_tokens": unknown_tokens, + "pool_total_tokens": pool_total_tokens, + "available_tokens": available_tokens, + "healthy_tokens": healthy_tokens, + "unhealthy_tokens": unhealthy_tokens, + "total_requests": total_requests, + "successful_requests": successful_requests, + "failed_requests": failed_requests, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_consumed_tokens": total_consumed_tokens, + "cache_creation_tokens": cache_creation_tokens, + "cache_read_tokens": cache_read_tokens, + "total_cache_tokens": total_cache_tokens, + "cache_creation_requests": cache_creation_requests, + "cache_hit_requests": cache_hit_requests, + "average_latency": average_latency, + "average_first_token_latency": average_first_token_latency, + "trend_window": trend_window, + "usage_trend": usage_trend, + "total_consumed_tokens_display": format_compact_number( + total_consumed_tokens + ), + "total_cache_tokens_display": format_compact_number( + total_cache_tokens + ), + "input_tokens_display": format_compact_number(input_tokens), + "output_tokens_display": format_compact_number(output_tokens), + "success_rate": calculate_success_rate( + successful_requests, + total_requests, + ), + } diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6323e5f05522c2198c151ae459d5de7790089fc8 --- /dev/null +++ b/app/core/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from app.core import claude, config, openai + +__all__ = ["claude", "config", "openai"] diff --git a/app/core/claude.py b/app/core/claude.py new file mode 100644 index 0000000000000000000000000000000000000000..d817ccd7427f5cef05f778f4ed7c59199c61d23d --- /dev/null +++ b/app/core/claude.py @@ -0,0 +1,582 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import json +import math +import time +import uuid +from typing import Any, AsyncGenerator, Dict, List, Optional + +from fastapi import APIRouter, Header, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from app.core.claude_compat import ( + build_non_stream_response, + claude_messages_to_openai, + claude_tool_choice_to_openai, + claude_tools_to_openai, + extract_text, + make_claude_id, + sse_content_block_delta, + sse_content_block_start, + sse_content_block_stop, + sse_error, + sse_message_delta, + sse_message_start, + sse_message_stop, + sse_ping, +) +from app.core.config import settings +from app.core.openai import get_upstream_client +from app.models.schemas import Message, OpenAIRequest +from app.utils.logger import get_logger +from app.utils.request_logging import ( + extract_openai_usage, + extract_claude_usage, + wrap_claude_stream_with_logging, + write_request_log, +) +from app.utils.request_source import detect_request_source, format_request_source + +logger = get_logger() +router = APIRouter() + + +def _resolve_claude_model(model: Any) -> str: + """Map Claude/Claude Code model aliases to local upstream-supported models.""" + if not isinstance(model, str) or not model.strip(): + return settings.GLM5_MODEL + + raw_model = model.strip() + normalized = raw_model.casefold() + if normalized.endswith("[1m]"): + normalized = normalized[:-4].rstrip() + + direct_models = { + settings.GLM45_MODEL.casefold(): settings.GLM45_MODEL, + settings.GLM45_THINKING_MODEL.casefold(): settings.GLM45_THINKING_MODEL, + settings.GLM45_SEARCH_MODEL.casefold(): settings.GLM45_SEARCH_MODEL, + settings.GLM45_AIR_MODEL.casefold(): settings.GLM45_AIR_MODEL, + settings.GLM46V_MODEL.casefold(): settings.GLM46V_MODEL, + settings.GLM5_MODEL.casefold(): settings.GLM5_MODEL, + settings.GLM47_MODEL.casefold(): settings.GLM47_MODEL, + settings.GLM47_THINKING_MODEL.casefold(): settings.GLM47_THINKING_MODEL, + settings.GLM47_SEARCH_MODEL.casefold(): settings.GLM47_SEARCH_MODEL, + settings.GLM47_ADVANCED_SEARCH_MODEL.casefold(): settings.GLM47_ADVANCED_SEARCH_MODEL, + } + if normalized in direct_models: + return direct_models[normalized] + + alias_map = { + "default": settings.GLM5_MODEL, + "sonnet": settings.GLM5_MODEL, + "haiku": settings.GLM45_AIR_MODEL, + "opus": settings.GLM5_MODEL, + "opusplan": settings.GLM47_THINKING_MODEL, + } + if normalized in alias_map: + return alias_map[normalized] + + if normalized.startswith("claude-sonnet") or normalized.startswith("claude-3-7-sonnet") or normalized.startswith("claude-3-5-sonnet"): + return settings.GLM5_MODEL + if normalized.startswith("claude-opus") or normalized.startswith("claude-4-opus"): + return settings.GLM5_MODEL + if normalized.startswith("claude-haiku") or normalized.startswith("claude-3-5-haiku"): + return settings.GLM45_AIR_MODEL + + return raw_model + + +def _estimate_tokens(text: str) -> int: + if not text: + return 0 + return max(1, math.ceil(len(text) / 2)) + + +def _extract_api_key( + authorization: Optional[str], + x_api_key: Optional[str], +) -> Optional[str]: + if x_api_key: + return x_api_key + if authorization and authorization.startswith("Bearer "): + return authorization[7:] + return None + + +def _claude_error_response( + message: str, + status_code: int, + error_type: str, +) -> JSONResponse: + return JSONResponse( + status_code=status_code, + content={ + "type": "error", + "error": {"type": error_type, "message": message}, + }, + ) + + +def _build_openai_request(body: Dict[str, Any]) -> OpenAIRequest: + system = body.get("system") + claude_messages = body.get("messages", []) + openai_messages = claude_messages_to_openai(system, claude_messages) + openai_tools = claude_tools_to_openai(body.get("tools")) + tool_choice = claude_tool_choice_to_openai(body.get("tool_choice")) + + thinking = body.get("thinking") + enable_thinking = None + if isinstance(thinking, dict): + thinking_type = thinking.get("type") + if thinking_type == "enabled": + enable_thinking = True + elif thinking_type == "disabled": + enable_thinking = False + + messages = [Message.model_validate(message) for message in openai_messages] + resolved_model = _resolve_claude_model(body.get("model", settings.GLM5_MODEL)) + if resolved_model != body.get("model", settings.GLM5_MODEL): + logger.info( + f"🔀 Claude 模型映射: " + f"{body.get('model', settings.GLM5_MODEL)} -> {resolved_model}" + ) + + return OpenAIRequest( + model=resolved_model, + messages=messages, + stream=bool(body.get("stream", False)), + temperature=body.get("temperature"), + max_tokens=body.get("max_tokens"), + tools=openai_tools, + tool_choice=tool_choice, + enable_thinking=enable_thinking, + ) + + +def _build_prompt_text(body: Dict[str, Any]) -> str: + prompt_parts: List[str] = [] + system = body.get("system") + if system: + prompt_parts.append(extract_text(system)) + + for message in body.get("messages", []): + content = message.get("content") if isinstance(message, dict) else None + text = extract_text(content) + if text: + prompt_parts.append(text) + + return "\n".join(part for part in prompt_parts if part) + + +def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]: + if not isinstance(tool_calls, list): + return [] + + normalized: List[Dict[str, Any]] = [] + seen_ids = set() + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + tool_call_id = tool_call.get("id") or f"call_{uuid.uuid4().hex[:24]}" + if tool_call_id in seen_ids: + continue + seen_ids.add(tool_call_id) + + function_data = ( + tool_call.get("function") + if isinstance(tool_call.get("function"), dict) + else {} + ) + arguments = function_data.get("arguments", "{}") + if not isinstance(arguments, str): + try: + arguments = json.dumps(arguments, ensure_ascii=False) + except Exception: + arguments = "{}" + + normalized.append( + { + "id": tool_call_id, + "type": "function", + "function": { + "name": function_data.get("name", ""), + "arguments": arguments, + }, + } + ) + + return normalized + + +def _convert_openai_response_to_claude(response: Dict[str, Any], msg_id: str) -> Dict[str, Any]: + choice = ((response.get("choices") or [{}])[0]) if isinstance(response, dict) else {} + message = choice.get("message") or {} + reasoning = message.get("reasoning_content") + usage = extract_openai_usage(response) + return build_non_stream_response( + msg_id=msg_id, + model=response.get("model", settings.GLM5_MODEL), + reasoning_parts=[reasoning] if isinstance(reasoning, str) and reasoning else [], + answer_text=message.get("content") or "", + tool_calls=_normalize_tool_calls(message.get("tool_calls")), + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + ) + + +async def _stream_openai_to_claude( + openai_stream: AsyncGenerator[str, None], + msg_id: str, + model: str, + input_tokens: int, +) -> AsyncGenerator[str, None]: + reasoning_parts: List[str] = [] + answer_parts: List[str] = [] + tool_calls: List[Dict[str, Any]] = [] + block_index = 0 + thinking_started = False + final_input_tokens = input_tokens + final_output_tokens = 0 + cache_creation_tokens = 0 + cache_read_tokens = 0 + + yield sse_message_start(msg_id, model, input_tokens) + yield sse_ping() + + try: + async for chunk in openai_stream: + if not chunk.startswith("data: "): + continue + + payload_text = chunk[6:].strip() + if not payload_text or payload_text == "[DONE]": + continue + + payload = json.loads(payload_text) + if isinstance(payload, dict) and "error" in payload: + error = payload.get("error") or {} + yield sse_error( + error.get("type", "api_error"), + error.get("message", "Upstream error"), + ) + return + + choice = ((payload.get("choices") or [{}])[0]) if isinstance(payload, dict) else {} + delta = choice.get("delta") or {} + + reasoning_delta = delta.get("reasoning_content") + if reasoning_delta: + if not thinking_started: + yield sse_content_block_start( + block_index, + {"type": "thinking", "thinking": ""}, + ) + thinking_started = True + + reasoning_parts.append(reasoning_delta) + yield sse_content_block_delta( + block_index, + {"type": "thinking_delta", "thinking": reasoning_delta}, + ) + + content_delta = delta.get("content") + if content_delta: + answer_parts.append(content_delta) + + if payload.get("usage"): + usage = extract_openai_usage(payload) + if usage["input_tokens"] > 0: + final_input_tokens = usage["input_tokens"] + if usage["output_tokens"] > 0: + final_output_tokens = usage["output_tokens"] + if usage["cache_creation_tokens"] > 0: + cache_creation_tokens = usage["cache_creation_tokens"] + if usage["cache_read_tokens"] > 0: + cache_read_tokens = usage["cache_read_tokens"] + + tool_calls.extend(_normalize_tool_calls(delta.get("tool_calls"))) + + if thinking_started: + yield sse_content_block_stop(block_index) + block_index += 1 + + answer_text = "".join(answer_parts) + if answer_text: + yield sse_content_block_start(block_index, {"type": "text", "text": ""}) + yield sse_content_block_delta( + block_index, + {"type": "text_delta", "text": answer_text}, + ) + yield sse_content_block_stop(block_index) + block_index += 1 + + if tool_calls: + for tool_call in tool_calls: + function_data = tool_call.get("function") or {} + tool_id = tool_call.get( + "id", + f"toolu_{uuid.uuid4().hex[:20]}", + ).replace("call_", "toolu_") + yield sse_content_block_start( + block_index, + { + "type": "tool_use", + "id": tool_id, + "name": function_data.get("name", ""), + "input": {}, + }, + ) + yield sse_content_block_delta( + block_index, + { + "type": "input_json_delta", + "partial_json": function_data.get("arguments", "{}"), + }, + ) + yield sse_content_block_stop(block_index) + block_index += 1 + + if not final_output_tokens: + final_output_tokens = _estimate_tokens( + "".join(reasoning_parts) + answer_text + ) + + yield sse_message_delta( + "tool_use" if tool_calls else "end_turn", + final_output_tokens, + input_tokens=final_input_tokens, + cache_creation_tokens=cache_creation_tokens, + cache_read_tokens=cache_read_tokens, + ) + yield sse_message_stop() + except Exception as exc: + logger.error(f"❌ Claude 流式响应转换失败: {exc}") + yield sse_error("api_error", str(exc)) + + +@router.post("/v1/messages") +@router.post("/anthropic/v1/messages") +async def claude_messages( + request: Request, + authorization: Optional[str] = Header(None), + x_api_key: Optional[str] = Header(None, alias="x-api-key"), +): + source_info = detect_request_source( + request, + protocol_hint="anthropic", + ) + source_prefix = format_request_source(source_info) + started_at = time.perf_counter() + requested_model = "unknown" + + try: + body = await request.json() + except Exception: + await write_request_log( + provider="zai", + model=requested_model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=400, + error_message="Invalid JSON body", + ) + return _claude_error_response( + "Invalid JSON body", + 400, + "invalid_request_error", + ) + + requested_model = str(body.get("model") or "unknown") + source_info = detect_request_source( + request, + protocol_hint="anthropic", + model_hint=body.get("model"), + ) + source_prefix = format_request_source(source_info) + + if not settings.SKIP_AUTH_TOKEN: + api_key = _extract_api_key(authorization, x_api_key) + if not api_key: + await write_request_log( + provider="zai", + model=requested_model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=401, + error_message="Missing API key", + ) + return _claude_error_response( + "Missing API key", + 401, + "authentication_error", + ) + if api_key != settings.AUTH_TOKEN: + await write_request_log( + provider="zai", + model=requested_model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=401, + error_message="Invalid API key", + ) + return _claude_error_response( + "Invalid API key", + 401, + "authentication_error", + ) + + try: + openai_request = _build_openai_request(body) + except Exception as exc: + await write_request_log( + provider="zai", + model=requested_model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=400, + error_message=f"Invalid request: {exc}", + ) + return _claude_error_response( + f"Invalid request: {exc}", + 400, + "invalid_request_error", + ) + + if not openai_request.messages: + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=400, + error_message="messages is required", + ) + return _claude_error_response( + "messages is required", + 400, + "invalid_request_error", + ) + logger.info( + f"{source_prefix} 🤖 收到 Claude 请求 - 模型: {body.get('model')}, 映射模型: {openai_request.model}, 流式: {openai_request.stream}, 消息数: {len(openai_request.messages)}, 工具数: {len(openai_request.tools) if openai_request.tools else 0}" + ) + + msg_id = make_claude_id() + input_tokens = _estimate_tokens(_build_prompt_text(body)) + + try: + client = get_upstream_client() + result = await client.chat_completion(openai_request) + except Exception as exc: + logger.error(f"{source_prefix} ❌ Claude 请求处理失败: {exc}") + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=500, + error_message=str(exc), + ) + return _claude_error_response(str(exc), 500, "api_error") + + if isinstance(result, dict) and "error" in result: + error = result.get("error") or {} + error_code = error.get("code") + status_code = error_code if isinstance(error_code, int) else 500 + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=status_code, + error_message=error.get("message", "Unknown upstream error"), + ) + return _claude_error_response( + error.get("message", "Unknown upstream error"), + status_code, + error.get("type", "api_error"), + ) + + if openai_request.stream: + if not hasattr(result, "__aiter__"): + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=500, + error_message="Expected streaming response", + ) + return _claude_error_response( + "Expected streaming response", + 500, + "api_error", + ) + + return StreamingResponse( + wrap_claude_stream_with_logging( + _stream_openai_to_claude( + result, + msg_id, + openai_request.model, + input_tokens, + ), + provider="zai", + model=openai_request.model, + source_info=source_info, + started_at=started_at, + input_tokens=input_tokens, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + }, + ) + + if not isinstance(result, dict): + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=500, + error_message="Expected non-streaming response payload", + ) + return _claude_error_response( + "Expected non-streaming response payload", + 500, + "api_error", + ) + + response_data = _convert_openai_response_to_claude(result, msg_id) + if not response_data.get("usage", {}).get("input_tokens"): + response_data["usage"]["input_tokens"] = input_tokens + usage = extract_claude_usage(response_data) + await write_request_log( + provider="zai", + model=openai_request.model, + source_info=source_info, + success=True, + started_at=started_at, + status_code=200, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + total_tokens=usage["total_tokens"], + ) + return JSONResponse(content=response_data) diff --git a/app/core/claude_compat.py b/app/core/claude_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26ec1ab2de23547a39f926d5de6197b4e2427d --- /dev/null +++ b/app/core/claude_compat.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Claude Messages API 兼容辅助函数。""" + +from __future__ import annotations + +import json +import uuid +from typing import Any, Optional + + +def extract_text(content: Any) -> str: + """Extract plain text from Claude/OpenAI mixed content blocks.""" + if isinstance(content, str): + return content + + if isinstance(content, list): + return " ".join( + str(block.get("text", "")) + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ).strip() + + return str(content) if content else "" + + +def claude_messages_to_openai(system: Any, messages: list[dict]) -> list[dict]: + """Convert Claude messages payload into OpenAI-style messages.""" + converted: list[dict] = [] + + if system: + if isinstance(system, str): + converted.append({"role": "system", "content": system}) + elif isinstance(system, list): + system_text = [ + block.get("text", "") + for block in system + if isinstance(block, dict) and block.get("type") == "text" + ] + if system_text: + converted.append({ + "role": "system", + "content": "\n".join(system_text), + }) + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + + if role == "assistant" and isinstance(content, list): + text_parts: list[str] = [] + tool_calls: list[dict] = [] + + for block in content: + if not isinstance(block, dict): + continue + + block_type = block.get("type") + if block_type == "text": + text_parts.append(block.get("text", "")) + elif block_type == "tool_use": + tool_calls.append( + { + "id": block.get( + "id", + f"call_{uuid.uuid4().hex[:24]}", + ), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": json.dumps( + block.get("input", {}), + ensure_ascii=False, + ), + }, + } + ) + + openai_message: dict = { + "role": "assistant", + "content": " ".join(text_parts).strip() or None, + } + if tool_calls: + openai_message["tool_calls"] = tool_calls + converted.append(openai_message) + continue + + if role == "user" and isinstance(content, list): + has_tool_result = any( + isinstance(block, dict) and block.get("type") == "tool_result" + for block in content + ) + if has_tool_result: + for block in content: + if not isinstance(block, dict): + continue + + block_type = block.get("type") + if block_type == "tool_result": + result_content = block.get("content", "") + if isinstance(result_content, str): + rendered = result_content + elif isinstance(result_content, list): + rendered = " ".join( + item.get("text", "") + for item in result_content + if isinstance(item, dict) + and item.get("type") == "text" + ) + else: + rendered = str(result_content) + + converted.append( + { + "role": "tool", + "tool_call_id": block.get("tool_use_id", ""), + "content": rendered, + } + ) + elif block_type == "text": + converted.append( + {"role": "user", "content": block.get("text", "")} + ) + continue + + converted.append({"role": role, "content": extract_text(content)}) + + return converted + + +def claude_tools_to_openai(tools: Optional[list[dict]]) -> Optional[list[dict]]: + """Convert Claude tool schemas into OpenAI function tools.""" + if not tools: + return None + + converted = [ + { + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("input_schema", {}), + }, + } + for tool in tools + if isinstance(tool, dict) + ] + return converted or None + + +def claude_tool_choice_to_openai(tool_choice: Any) -> Any: + """Convert Claude tool_choice payload into OpenAI-compatible form.""" + if not isinstance(tool_choice, dict): + return tool_choice + + tool_choice_type = tool_choice.get("type", "auto") + if tool_choice_type == "auto": + return "auto" + if tool_choice_type == "any": + return "required" + if tool_choice_type == "none": + return "none" + if tool_choice_type == "tool": + name = tool_choice.get("name", "") + if name: + return {"type": "function", "function": {"name": name}} + return tool_choice + + +def make_claude_id() -> str: + """Generate a Claude-style message id.""" + return f"msg_{uuid.uuid4().hex[:24]}" + + +def build_tool_call_blocks(tool_calls: list[dict]) -> list[dict]: + """Convert OpenAI tool calls to Claude tool_use blocks.""" + blocks = [] + for tool_call in tool_calls: + function_data = ( + tool_call.get("function") + if isinstance(tool_call.get("function"), dict) + else {} + ) + arguments = function_data.get("arguments", "{}") + try: + input_data = json.loads(arguments) if isinstance(arguments, str) else arguments + except Exception: + input_data = {} + + blocks.append( + { + "type": "tool_use", + "id": tool_call.get( + "id", + f"toolu_{uuid.uuid4().hex[:20]}", + ).replace("call_", "toolu_"), + "name": function_data.get("name", ""), + "input": input_data, + } + ) + return blocks + + +def build_non_stream_response( + msg_id: str, + model: str, + reasoning_parts: list[str], + answer_text: str, + tool_calls: Optional[list[dict]], + input_tokens: int, + output_tokens: int, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, +) -> dict: + """Build a Claude non-streaming message response.""" + content: list[dict] = [] + if reasoning_parts: + content.append( + {"type": "thinking", "thinking": "".join(reasoning_parts)} + ) + if answer_text: + content.append({"type": "text", "text": answer_text}) + elif not tool_calls: + content.append({"type": "text", "text": ""}) + if tool_calls: + content.extend(build_tool_call_blocks(tool_calls)) + + return { + "id": msg_id, + "type": "message", + "role": "assistant", + "content": content, + "model": model, + "stop_reason": "tool_use" if tool_calls else "end_turn", + "stop_sequence": None, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_input_tokens": cache_creation_tokens, + "cache_read_input_tokens": cache_read_tokens, + }, + } + + +def sse(event: str, data: dict) -> str: + """Format a Claude SSE event.""" + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def sse_message_start( + msg_id: str, + model: str, + input_tokens: int, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, +) -> str: + """Create Claude message_start SSE event.""" + return sse( + "message_start", + { + "type": "message_start", + "message": { + "id": msg_id, + "type": "message", + "role": "assistant", + "content": [], + "model": model, + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": input_tokens, + "cache_creation_input_tokens": cache_creation_tokens, + "cache_read_input_tokens": cache_read_tokens, + "output_tokens": 0, + }, + }, + }, + ) + + +def sse_ping() -> str: + """Create Claude ping SSE event.""" + return sse("ping", {"type": "ping"}) + + +def sse_content_block_start(index: int, block: dict) -> str: + """Create Claude content_block_start SSE event.""" + return sse( + "content_block_start", + { + "type": "content_block_start", + "index": index, + "content_block": block, + }, + ) + + +def sse_content_block_delta(index: int, delta: dict) -> str: + """Create Claude content_block_delta SSE event.""" + return sse( + "content_block_delta", + {"type": "content_block_delta", "index": index, "delta": delta}, + ) + + +def sse_content_block_stop(index: int) -> str: + """Create Claude content_block_stop SSE event.""" + return sse( + "content_block_stop", + {"type": "content_block_stop", "index": index}, + ) + + +def sse_message_delta( + stop_reason: str, + output_tokens: int, + *, + input_tokens: int = 0, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, +) -> str: + """Create Claude message_delta SSE event.""" + return sse( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": None}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_input_tokens": cache_creation_tokens, + "cache_read_input_tokens": cache_read_tokens, + }, + }, + ) + + +def sse_message_stop() -> str: + """Create Claude message_stop SSE event.""" + return sse("message_stop", {"type": "message_stop"}) + + +def sse_error(error_type: str, message: str) -> str: + """Create Claude error SSE event.""" + return sse( + "error", + { + "type": "error", + "error": {"type": error_type, "message": message}, + }, + ) diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d830de82a00dc1e1bc21b17195215140d7def612 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +from typing import Optional + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings""" + + # API Configuration + API_ENDPOINT: str = "https://chat.z.ai/api/v2/chat/completions" + + # Authentication + AUTH_TOKEN: Optional[str] = os.getenv("AUTH_TOKEN") + + # Token池配置 + TOKEN_FAILURE_THRESHOLD: int = int( + os.getenv("TOKEN_FAILURE_THRESHOLD", "3") + ) + TOKEN_RECOVERY_TIMEOUT: int = int( + os.getenv("TOKEN_RECOVERY_TIMEOUT", "1800") + ) + TOKEN_AUTO_IMPORT_ENABLED: bool = ( + os.getenv("TOKEN_AUTO_IMPORT_ENABLED", "false").lower() == "true" + ) + TOKEN_AUTO_IMPORT_SOURCE_DIR: str = os.getenv("TOKEN_AUTO_IMPORT_SOURCE_DIR", "") + TOKEN_AUTO_IMPORT_INTERVAL: int = int( + os.getenv("TOKEN_AUTO_IMPORT_INTERVAL", "300") + ) + TOKEN_AUTO_MAINTENANCE_ENABLED: bool = ( + os.getenv("TOKEN_AUTO_MAINTENANCE_ENABLED", "false").lower() == "true" + ) + TOKEN_AUTO_MAINTENANCE_INTERVAL: int = int( + os.getenv("TOKEN_AUTO_MAINTENANCE_INTERVAL", "1800") + ) + TOKEN_AUTO_REMOVE_DUPLICATES: bool = ( + os.getenv("TOKEN_AUTO_REMOVE_DUPLICATES", "true").lower() == "true" + ) + TOKEN_AUTO_HEALTH_CHECK: bool = ( + os.getenv("TOKEN_AUTO_HEALTH_CHECK", "true").lower() == "true" + ) + TOKEN_AUTO_DELETE_INVALID: bool = ( + os.getenv("TOKEN_AUTO_DELETE_INVALID", "false").lower() == "true" + ) + + # Model Configuration + GLM45_MODEL: str = os.getenv("GLM45_MODEL", "GLM-4.5") + GLM45_THINKING_MODEL: str = os.getenv("GLM45_THINKING_MODEL", "GLM-4.5-Thinking") + GLM45_SEARCH_MODEL: str = os.getenv("GLM45_SEARCH_MODEL", "GLM-4.5-Search") + GLM45_AIR_MODEL: str = os.getenv("GLM45_AIR_MODEL", "GLM-4.5-Air") + GLM46V_MODEL: str = os.getenv("GLM46V_MODEL", "GLM-4.6V") + GLM5_MODEL: str = os.getenv("GLM5_MODEL", "GLM-5") + GLM47_MODEL: str = os.getenv("GLM47_MODEL", "GLM-4.7") + GLM47_THINKING_MODEL: str = os.getenv("GLM47_THINKING_MODEL", "GLM-4.7-Thinking") + GLM47_SEARCH_MODEL: str = os.getenv("GLM47_SEARCH_MODEL", "GLM-4.7-Search") + GLM47_ADVANCED_SEARCH_MODEL: str = os.getenv( + "GLM47_ADVANCED_SEARCH_MODEL", + "GLM-4.7-advanced-search", + ) + + # Server Configuration + LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080")) + DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true" + SERVICE_NAME: str = os.getenv("SERVICE_NAME", "api-proxy-server") + ROOT_PATH: str = os.getenv("ROOT_PATH", "") + + ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true" + GUEST_POOL_SIZE: int = int(os.getenv("GUEST_POOL_SIZE", "3")) + TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true" + SCAN_LIMIT: int = int(os.getenv("SCAN_LIMIT", "200000")) + SKIP_AUTH_TOKEN: bool = os.getenv("SKIP_AUTH_TOKEN", "false").lower() == "true" + + # Proxy Configuration + HTTP_PROXY: Optional[str] = os.getenv("HTTP_PROXY") + HTTPS_PROXY: Optional[str] = os.getenv("HTTPS_PROXY") + SOCKS5_PROXY: Optional[str] = os.getenv("SOCKS5_PROXY") + + # Admin Panel Authentication + ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD", "admin123") + SESSION_SECRET_KEY: str = os.getenv( + "SESSION_SECRET_KEY", + "your-secret-key-change-in-production", + ) + DB_PATH: str = os.getenv("DB_PATH", "tokens.db") + + model_config = SettingsConfigDict( + env_file=".env", + extra="ignore", # 忽略额外字段,防止环境变量中的未知字段导致验证错误 + ) + + +settings = Settings() diff --git a/app/core/openai.py b/app/core/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..b17b3e97f331af11d017511f916d3292ba23a765 --- /dev/null +++ b/app/core/openai.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import json +import time +from typing import Optional + +from fastapi import APIRouter, Header, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from app.core.config import settings +from app.models.schemas import ( + Choice, + Message, + Model, + ModelsResponse, + OpenAIRequest, + OpenAIResponse, + Usage, +) +from app.core.upstream import UpstreamClient +from app.utils.logger import get_logger +from app.utils.request_logging import ( + extract_openai_usage, + wrap_openai_stream_with_logging, + write_request_log, +) +from app.utils.request_source import detect_request_source, format_request_source + +logger = get_logger() +router = APIRouter() + +_upstream_client: Optional[UpstreamClient] = None + + +def get_upstream_client() -> UpstreamClient: + """获取懒加载的上游适配器单例。""" + global _upstream_client + if _upstream_client is None: + _upstream_client = UpstreamClient() + return _upstream_client + + +async def handle_non_stream_response(stream_response, request: OpenAIRequest) -> JSONResponse: + """处理非流式响应。""" + logger.info("📄 开始处理非流式响应") + + full_content = [] + async for chunk_data in stream_response(): + if chunk_data.startswith("data: "): + chunk_str = chunk_data[6:].strip() + if chunk_str and chunk_str != "[DONE]": + try: + chunk = json.loads(chunk_str) + if "choices" in chunk and chunk["choices"]: + choice = chunk["choices"][0] + if "delta" in choice and "content" in choice["delta"]: + content = choice["delta"]["content"] + if content: + full_content.append(content) + except json.JSONDecodeError: + continue + + response_data = OpenAIResponse( + id=f"chatcmpl-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model=request.model, + choices=[ + Choice( + index=0, + message=Message( + role="assistant", + content="".join(full_content), + tool_calls=None, + ), + finish_reason="stop", + ) + ], + usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + logger.info("✅ 非流式响应处理完成") + return JSONResponse(content=response_data.model_dump(exclude_none=True)) + + +@router.get("/v1/models") +async def list_models(): + """返回当前服务支持的模型列表。""" + try: + client = get_upstream_client() + current_time = int(time.time()) + response = ModelsResponse( + data=[ + Model(id=model_id, created=current_time, owned_by=settings.SERVICE_NAME) + for model_id in client.get_supported_models() + ] + ) + return JSONResponse(content=response.model_dump(exclude_none=True)) + except Exception as exc: + logger.error(f"❌ 获取模型列表失败: {exc}") + raise HTTPException(status_code=500, detail=f"Failed to list models: {exc}") + + +@router.post("/v1/chat/completions") +async def chat_completions( + body: OpenAIRequest, + http_request: Request, + authorization: Optional[str] = Header(None), +): + """直接调用上游适配器处理请求。""" + source_info = detect_request_source( + http_request, + protocol_hint="openai", + model_hint=body.model, + ) + source_prefix = format_request_source(source_info) + started_at = time.perf_counter() + + role = body.messages[0].role if body.messages else "unknown" + logger.info( + f"{source_prefix} 😶‍🌫️ 收到客户端请求 - 模型: {body.model}, 流式: {body.stream}, 消息数: {len(body.messages)}, 角色: {role}, 工具数: {len(body.tools) if body.tools else 0}" + ) + + try: + if not settings.SKIP_AUTH_TOKEN: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") + + api_key = authorization[7:] + if api_key != settings.AUTH_TOKEN: + raise HTTPException(status_code=401, detail="Invalid API key") + + client = get_upstream_client() + result = await client.chat_completion(body) + + if isinstance(result, dict) and "error" in result: + error_info = result["error"] + error_message = error_info.get("message", "Unknown upstream error") + error_code = error_info.get("code") + status_code = 404 if error_code == "model_not_found" else 500 + raise HTTPException(status_code=status_code, detail=error_message) + + if body.stream: + if hasattr(result, "__aiter__"): + return StreamingResponse( + wrap_openai_stream_with_logging( + result, + provider="zai", + model=body.model, + source_info=source_info, + started_at=started_at, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + }, + ) + raise HTTPException( + status_code=500, + detail="Expected streaming response but got non-streaming result", + ) + + if isinstance(result, dict): + usage = extract_openai_usage(result) + await write_request_log( + provider="zai", + model=body.model, + source_info=source_info, + success="error" not in result, + started_at=started_at, + status_code=200 if "error" not in result else 500, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + total_tokens=usage["total_tokens"], + error_message=(result.get("error") or {}).get("message") if isinstance(result, dict) else None, + ) + return JSONResponse(content=result) + + response = await handle_non_stream_response(result, body) + response_body = json.loads(response.body) + usage = extract_openai_usage(response_body) + await write_request_log( + provider="zai", + model=body.model, + source_info=source_info, + success=True, + started_at=started_at, + status_code=200, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + total_tokens=usage["total_tokens"], + ) + return response + + except HTTPException as exc: + await write_request_log( + provider="zai", + model=body.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=exc.status_code, + error_message=str(exc.detail), + ) + raise + except Exception as exc: + logger.error(f"{source_prefix} ❌ 请求处理失败: {exc}") + await write_request_log( + provider="zai", + model=body.model, + source_info=source_info, + success=False, + started_at=started_at, + status_code=500, + error_message=str(exc), + ) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(exc)}") diff --git a/app/core/openai_compat.py b/app/core/openai_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cccdc094e617e447e1bdb11cd68afec2c982a72a --- /dev/null +++ b/app/core/openai_compat.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""OpenAI 兼容响应辅助函数。""" + +import json +import time +import uuid +from typing import Any, Dict, List, Optional + +from app.utils.logger import get_logger + +logger = get_logger() +SYSTEM_FINGERPRINT = "fp_api_proxy_001" + + +def create_chat_id() -> str: + """生成聊天 ID。""" + return f"chatcmpl-{uuid.uuid4().hex}" + + +def create_openai_chunk( + chat_id: str, + model: str, + delta: Dict[str, Any], + finish_reason: Optional[str] = None, +) -> Dict[str, Any]: + """创建 OpenAI 格式的流式响应块。""" + return { + "id": chat_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + "logprobs": None, + } + ], + "system_fingerprint": SYSTEM_FINGERPRINT, + } + + +def create_openai_response( + chat_id: str, + model: str, + content: str, + usage: Optional[Dict[str, int]] = None, +) -> Dict[str, Any]: + """创建 OpenAI 格式的非流式响应。""" + return { + "id": chat_id, + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": usage + or { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + "system_fingerprint": SYSTEM_FINGERPRINT, + } + + +def create_openai_response_with_reasoning( + chat_id: str, + model: str, + content: str, + reasoning_content: Optional[str] = None, + usage: Optional[Dict[str, int]] = None, + tool_calls: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """创建包含 reasoning/tool_calls 的 OpenAI 响应。""" + message: Dict[str, Any] = { + "role": "assistant", + "content": content, + } + + if reasoning_content and reasoning_content.strip(): + message["reasoning_content"] = reasoning_content + + if tool_calls: + message["tool_calls"] = tool_calls + + return { + "id": chat_id, + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": message, + "finish_reason": "tool_calls" if tool_calls else "stop", + "logprobs": None, + } + ], + "usage": usage + or { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + "system_fingerprint": SYSTEM_FINGERPRINT, + } + + +async def format_sse_chunk(chunk: Dict[str, Any]) -> str: + """格式化 SSE 响应块。""" + return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" + + +async def format_sse_done() -> str: + """格式化 SSE 结束标记。""" + return "data: [DONE]\n\n" + + +def handle_error(error: Exception, context: str = "") -> Dict[str, Any]: + """统一错误处理。""" + error_msg = f"上游{context}错误: {str(error)}" if context else f"上游错误: {str(error)}" + logger.error(error_msg) + return { + "error": { + "message": error_msg, + "type": "upstream_error", + "code": "internal_error", + } + } diff --git a/app/core/upstream.py b/app/core/upstream.py new file mode 100644 index 0000000000000000000000000000000000000000..ab183a5fd15b336829b7133c113d6630001461e3 --- /dev/null +++ b/app/core/upstream.py @@ -0,0 +1,2245 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""上游适配器。""" + +import asyncio +import base64 +import json +import random +import time +import uuid +from datetime import datetime, timezone +from typing import Any, AsyncGenerator, Dict, List, Optional, Set, Tuple, Union +from urllib.parse import urlencode + +import httpx + +from app.core.config import settings +from app.core.openai_compat import ( + create_openai_chunk, + create_openai_response_with_reasoning, + format_sse_chunk, + handle_error, +) +from app.models.schemas import OpenAIRequest +from app.utils.fe_version import get_latest_fe_version +from app.utils.guest_session_pool import get_guest_session_pool +from app.utils.logger import get_logger +from app.utils.signature import generate_signature +from app.utils.token_pool import get_token_pool +from app.utils.tool_call_handler import ( + parse_and_extract_tool_calls, +) +from app.utils.user_agent import get_random_user_agent + +logger = get_logger() + +DEFAULT_ZAI_BASE_URL = "https://chat.z.ai" +CHAT_BOOTSTRAP_MAX_CONTENT_LEN = 500 +DEFAULT_PLATFORM = "web" +DEFAULT_CLIENT_VERSION = "0.0.1" +DEFAULT_TIMEZONE = "Asia/Shanghai" +DEFAULT_LANGUAGE = "zh-CN" +DEFAULT_SCREEN_WIDTH = "1920" +DEFAULT_SCREEN_HEIGHT = "1080" +DEFAULT_VIEWPORT_WIDTH = "944" +DEFAULT_VIEWPORT_HEIGHT = "919" +DEFAULT_VIEWPORT_SIZE = f"{DEFAULT_VIEWPORT_WIDTH}x{DEFAULT_VIEWPORT_HEIGHT}" +DEFAULT_SCREEN_RESOLUTION = f"{DEFAULT_SCREEN_WIDTH}x{DEFAULT_SCREEN_HEIGHT}" +DEFAULT_COLOR_DEPTH = "24" +DEFAULT_PIXEL_RATIO = "1.25" +DEFAULT_MAX_TOUCH_POINTS = "10" +DEFAULT_TIMEZONE_OFFSET = "-480" +DEFAULT_PAGE_TITLE = "Z.ai Chat Proxy" +DEFAULT_COMPLETION_FEATURES = [ + {"type": "mcp", "server": "vibe-coding", "status": "hidden"}, + {"type": "mcp", "server": "ppt-maker", "status": "hidden"}, + {"type": "mcp", "server": "image-search", "status": "hidden"}, + {"type": "mcp", "server": "deep-research", "status": "hidden"}, + {"type": "tool_selector", "server": "tool_selector", "status": "hidden"}, + {"type": "mcp", "server": "advanced-search", "status": "hidden"}, +] +GLM46V_MCP_SERVERS = [ + "vlm-image-search", + "vlm-image-recognition", + "vlm-image-processing", +] +GLM46V_SELECTED_FEATURES = [ + {"type": "mcp", "server": "vlm-image-search", "status": "selected"}, + {"type": "mcp", "server": "vlm-image-recognition", "status": "selected"}, + {"type": "mcp", "server": "vlm-image-processing", "status": "selected"}, +] + +def generate_uuid() -> str: + """生成UUID v4""" + return str(uuid.uuid4()) + +def get_dynamic_headers( + chat_id: str = "", + browser_type: Optional[str] = None, +) -> Dict[str, str]: + """生成上游请求所需的动态浏览器 headers。""" + browser_choices = [ + "chrome", + "chrome", + "chrome", + "edge", + "edge", + "firefox", + "safari", + ] + selected_browser = browser_type or random.choice(browser_choices) + user_agent = get_random_user_agent(selected_browser) + fe_version = get_latest_fe_version() + + chrome_version = "139" + edge_version = "139" + + if "Chrome/" in user_agent: + try: + chrome_version = user_agent.split("Chrome/")[1].split(".")[0] + except Exception: + pass + + if "Edg/" in user_agent: + try: + edge_version = user_agent.split("Edg/")[1].split(".")[0] + sec_ch_ua = ( + f'"Microsoft Edge";v="{edge_version}", ' + f'"Chromium";v="{chrome_version}", "Not_A Brand";v="24"' + ) + except Exception: + sec_ch_ua = ( + f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", ' + f'"Google Chrome";v="{chrome_version}"' + ) + elif "Firefox/" in user_agent: + sec_ch_ua = None + else: + sec_ch_ua = ( + f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", ' + f'"Google Chrome";v="{chrome_version}"' + ) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Connection": "keep-alive", + "Cache-Control": "no-cache", + "User-Agent": user_agent, + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "X-FE-Version": fe_version, + "Origin": "https://chat.z.ai", + } + + if sec_ch_ua: + headers["sec-ch-ua"] = sec_ch_ua + headers["sec-ch-ua-mobile"] = "?0" + headers["sec-ch-ua-platform"] = '"Windows"' + + if chat_id: + headers["Referer"] = f"https://chat.z.ai/c/{chat_id}" + else: + headers["Referer"] = "https://chat.z.ai/" + + return headers + +def _urlsafe_b64decode(data: str) -> bytes: + """Decode a URL-safe base64 string with proper padding.""" + if isinstance(data, str): + data_bytes = data.encode("utf-8") + else: + data_bytes = data + padding = b"=" * (-len(data_bytes) % 4) + return base64.urlsafe_b64decode(data_bytes + padding) + + +def _decode_jwt_payload(token: str) -> Dict[str, Any]: + """Decode JWT payload without verification to extract metadata.""" + try: + parts = token.split(".") + if len(parts) < 2: + return {} + payload_raw = _urlsafe_b64decode(parts[1]) + return json.loads(payload_raw.decode("utf-8", errors="ignore")) + except Exception: + return {} + + +def _extract_user_id_from_token(token: str) -> str: + """Extract user_id from a JWT's payload. Fallback to 'guest'.""" + payload = _decode_jwt_payload(token) if token else {} + for key in ("id", "user_id", "uid", "sub"): + val = payload.get(key) + if isinstance(val, (str, int)) and str(val): + return str(val) + return "guest" + + +def _extract_text_from_content(content: Any) -> str: + """Extract text parts from OpenAI-compatible content payloads.""" + if isinstance(content, str): + return content + + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + return " ".join(part for part in parts if part).strip() + + if content is None: + return "" + + try: + return json.dumps(content, ensure_ascii=False) + except Exception: + return str(content) + + +def _stringify_tool_arguments(arguments: Any) -> str: + """Normalize tool-call arguments into a JSON string.""" + if isinstance(arguments, str): + return arguments + + try: + return json.dumps(arguments or {}, ensure_ascii=False) + except Exception: + return "{}" + + +def _build_tool_call_index( + messages: List[Dict[str, Any]], +) -> Dict[str, Dict[str, str]]: + """Index assistant tool calls by id for later tool-result messages.""" + index: Dict[str, Dict[str, str]] = {} + + for message in messages: + if message.get("role") != "assistant": + continue + + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list): + continue + + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + tool_call_id = tool_call.get("id") + function_data = ( + tool_call.get("function") + if isinstance(tool_call.get("function"), dict) + else {} + ) + name = str(function_data.get("name", "")).strip() + if not isinstance(tool_call_id, str) or not name: + continue + + index[tool_call_id] = { + "name": name, + "arguments": _stringify_tool_arguments( + function_data.get("arguments") + ), + } + + return index + + +def _format_tool_result_message( + tool_name: str, + tool_arguments: str, + result_content: str, +) -> str: + """Serialize a tool result into a text block the upstream can consume.""" + return ( + "\n" + f"{tool_name}\n" + f"{tool_arguments}\n" + f"{result_content}\n" + "" + ) + + +def _format_assistant_tool_calls(tool_calls: List[Dict[str, Any]]) -> str: + """Serialize historical assistant tool calls into a text block.""" + blocks: List[str] = [] + + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + function_data = ( + tool_call.get("function") + if isinstance(tool_call.get("function"), dict) + else {} + ) + name = str(function_data.get("name", "")).strip() + if not name: + continue + + arguments = _stringify_tool_arguments(function_data.get("arguments")) + blocks.append( + "\n" + f"{name}\n" + f"{arguments}\n" + "" + ) + + if not blocks: + return "" + + return "\n" + "\n".join(blocks) + "\n" + + +def _preprocess_openai_messages( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Normalize OpenAI history into shapes accepted by the upstream service.""" + tool_call_index = _build_tool_call_index(messages) + normalized: List[Dict[str, Any]] = [] + + for message in messages: + if not isinstance(message, dict): + continue + + role = message.get("role") + + if role == "developer": + converted = dict(message) + converted["role"] = "system" + normalized.append(converted) + continue + + if role == "tool": + tool_call_id = message.get("tool_call_id") + content = _extract_text_from_content(message.get("content")) + tool_info = tool_call_index.get( + tool_call_id, + { + "name": str(message.get("name") or "unknown_tool"), + "arguments": "{}", + }, + ) + normalized.append( + { + "role": "user", + "content": _format_tool_result_message( + tool_info["name"], + tool_info["arguments"], + content, + ), + } + ) + continue + + if role == "assistant" and isinstance(message.get("tool_calls"), list): + content = _extract_text_from_content(message.get("content")) + tool_calls_text = _format_assistant_tool_calls(message["tool_calls"]) + merged_content = "\n".join( + part for part in (content, tool_calls_text) if part + ).strip() + normalized.append({"role": "assistant", "content": merged_content}) + continue + + normalized.append(dict(message)) + + return normalized + + +def _extract_last_user_text(messages: List[Dict[str, Any]]) -> str: + """Extract the last user text from the original OpenAI message history.""" + for message in reversed(messages): + if message.get("role") != "user": + continue + content = _extract_text_from_content(message.get("content")) + if content: + return content + return "" + + + +class UpstreamClient: + """当前服务使用的上游适配器。""" + + def __init__(self): + self.name = "upstream" + self.logger = logger + self.api_endpoint = settings.API_ENDPOINT + + # 当前上游特定配置 + self.base_url = DEFAULT_ZAI_BASE_URL + self.auth_url = f"{self.base_url}/api/v1/auths/" + + # 模型映射 + self.model_mapping = { + settings.GLM45_MODEL: "0727-360B-API", # GLM-4.5 + settings.GLM45_THINKING_MODEL: "0727-360B-API", # GLM-4.5-Thinking + settings.GLM45_SEARCH_MODEL: "0727-360B-API", # GLM-4.5-Search + settings.GLM45_AIR_MODEL: "0727-106B-API", # GLM-4.5-Air + settings.GLM46V_MODEL: "glm-4.6v", # GLM-4.6V多模态 + settings.GLM5_MODEL: "glm-5", # GLM-5 + settings.GLM47_MODEL: "glm-4.7", # GLM-4.7 + settings.GLM47_THINKING_MODEL: "glm-4.7", # GLM-4.7-Thinking + settings.GLM47_SEARCH_MODEL: "glm-4.7", # GLM-4.7-Search + settings.GLM47_ADVANCED_SEARCH_MODEL: "glm-4.7", # GLM-4.7-advanced-search + } + + def _get_guest_retry_limit(self) -> int: + """匿名号池可提供的最大重试预算。""" + if not settings.ANONYMOUS_MODE: + return 0 + + guest_pool = get_guest_session_pool() + if not guest_pool: + return max(2, settings.GUEST_POOL_SIZE + 1) + + pool_status = guest_pool.get_pool_status() + available_sessions = int( + pool_status.get("valid_sessions") + or pool_status.get("available_sessions") + or 0 + ) + return max(2, available_sessions + 1) + + def _get_authenticated_retry_limit(self) -> int: + """认证号池与静态 Token 可提供的最大重试预算。""" + available_tokens = 0 + token_pool = get_token_pool() + if token_pool: + available_tokens = int( + token_pool.get_pool_status().get("available_tokens", 0) or 0 + ) + + return max(0, available_tokens) + + def _get_total_retry_limit(self) -> int: + """综合认证号池与匿名号池的最大尝试次数。""" + return max( + 1, + self._get_authenticated_retry_limit() + self._get_guest_retry_limit(), + ) + + def _is_guest_auth(self, transformed: Dict[str, Any]) -> bool: + """判断当前请求是否使用匿名会话。""" + return str(transformed.get("auth_mode") or "") == "guest" + + def _should_retry_guest_session( + self, + status_code: int, + is_concurrency_limited: bool, + attempt: int, + max_attempts: int, + transformed: Dict[str, Any], + ) -> bool: + """判断匿名号池是否需要刷新会话后重试。""" + return ( + self._is_guest_auth(transformed) + and (status_code == 401 or is_concurrency_limited) + and attempt + 1 < max_attempts + ) + + def _should_retry_authenticated_session( + self, + status_code: int, + is_concurrency_limited: bool, + attempt: int, + max_attempts: int, + transformed: Dict[str, Any], + ) -> bool: + """判断认证号池是否需要切号重试。""" + current_token = str(transformed.get("token") or "") + return ( + not self._is_guest_auth(transformed) + and bool(current_token) + and (status_code == 401 or is_concurrency_limited) + and attempt + 1 < max_attempts + ) + + async def _release_guest_session(self, transformed: Dict[str, Any]): + """释放当前匿名会话占用。""" + if not self._is_guest_auth(transformed): + return + + guest_pool = get_guest_session_pool() + guest_user_id = str( + transformed.get("guest_user_id") or transformed.get("user_id") or "" + ) + if guest_pool and guest_user_id: + guest_pool.release(guest_user_id) + + async def _report_guest_session_failure( + self, + transformed: Dict[str, Any], + *, + is_concurrency_limited: bool = False, + ): + """上报匿名会话失败并补齐新会话。""" + if not self._is_guest_auth(transformed): + return + + guest_pool = get_guest_session_pool() + guest_user_id = str( + transformed.get("guest_user_id") or transformed.get("user_id") or "" + ) + if not guest_pool or not guest_user_id: + return + + if is_concurrency_limited: + await guest_pool.cleanup_idle_chats() + + await guest_pool.report_failure(guest_user_id) + + async def _refresh_guest_request( + self, + request: OpenAIRequest, + attempt: int, + excluded_tokens: Set[str], + excluded_guest_user_ids: Set[str], + failed_transformed: Dict[str, Any], + is_concurrency_limited: bool = False, + ) -> Dict[str, Any]: + """匿名会话失效或并发受限后切换会话并重签请求。""" + retry_number = attempt + 2 + self.logger.warning( + "🔄 匿名会话不可用,正在切换匿名会话并进行第 " + f"{retry_number} 次请求" + ) + await self._report_guest_session_failure( + failed_transformed, + is_concurrency_limited=is_concurrency_limited, + ) + return await self.transform_request( + request, + excluded_tokens=excluded_tokens, + excluded_guest_user_ids=excluded_guest_user_ids, + ) + + async def _refresh_authenticated_request( + self, + request: OpenAIRequest, + attempt: int, + excluded_tokens: Set[str], + excluded_guest_user_ids: Set[str], + ) -> Dict[str, Any]: + """认证模式下切换到下一枚 Token,并允许回退匿名池。""" + retry_number = attempt + 2 + self.logger.warning( + "🔄 检测到认证会话不可用,正在切换认证 Token/回退匿名池并进行第 " + f"{retry_number} 次请求" + ) + return await self.transform_request( + request, + excluded_tokens=excluded_tokens, + excluded_guest_user_ids=excluded_guest_user_ids, + ) + + def _extract_upstream_error_details( + self, + status_code: int, + error_text: str, + ) -> Tuple[Optional[int], str]: + """解析上游错误响应中的 code/message。""" + parsed_code: Optional[int] = None + parsed_message = (error_text or "").strip() + + try: + payload = json.loads(error_text) + except Exception: + return parsed_code, parsed_message + + if not isinstance(payload, dict): + return parsed_code, parsed_message + + candidates = [ + payload, + payload.get("error") if isinstance(payload.get("error"), dict) else None, + payload.get("detail") if isinstance(payload.get("detail"), dict) else None, + payload.get("data") if isinstance(payload.get("data"), dict) else None, + ] + + for candidate in candidates: + if not isinstance(candidate, dict): + continue + + code = candidate.get("code") + if isinstance(code, int): + parsed_code = code + elif isinstance(code, str) and code.isdigit(): + parsed_code = int(code) + + for key in ("message", "msg", "detail", "error"): + value = candidate.get(key) + if isinstance(value, str) and value.strip(): + parsed_message = value.strip() + break + + if parsed_code is not None or parsed_message: + break + + return parsed_code, parsed_message + + def _is_concurrency_limited( + self, + status_code: int, + error_code: Optional[int], + error_message: str, + ) -> bool: + """判断是否为上游并发限制/429 场景。""" + message = (error_message or "").casefold() + return ( + status_code == 429 + or error_code == 429 + or "concurrency" in message + or "too many requests" in message + or "并发" in error_message + ) + + def get_supported_models(self) -> List[str]: + """获取支持的模型列表""" + return [ + settings.GLM45_MODEL, + settings.GLM45_THINKING_MODEL, + settings.GLM45_SEARCH_MODEL, + settings.GLM45_AIR_MODEL, + settings.GLM46V_MODEL, + settings.GLM5_MODEL, + settings.GLM47_MODEL, + settings.GLM47_THINKING_MODEL, + settings.GLM47_SEARCH_MODEL, + settings.GLM47_ADVANCED_SEARCH_MODEL, + ] + + def _requires_persisted_chat(self, upstream_model_id: str) -> bool: + """需要挂载真实 chat 会话的上游模型。""" + return bool( + self._get_model_request_profile(upstream_model_id)["use_persisted_chat"] + ) + + def _get_model_request_profile(self, upstream_model_id: str) -> Dict[str, Any]: + """返回模型专属的请求配置。""" + if upstream_model_id == "glm-4.6v": + return { + "use_persisted_chat": True, + "preview_mode": False, + "mcp_servers": list(GLM46V_MCP_SERVERS), + "feature_entries": [dict(item) for item in GLM46V_SELECTED_FEATURES], + "default_enable_thinking": True, + } + + if upstream_model_id == "glm-5": + return { + "use_persisted_chat": False, + "preview_mode": True, + "mcp_servers": [], + "feature_entries": [], + "default_enable_thinking": True, + } + + return { + "use_persisted_chat": upstream_model_id == "glm-4.7", + "preview_mode": True, + "mcp_servers": [], + "feature_entries": [], + "default_enable_thinking": None, + } + + def _build_request_variables(self) -> Dict[str, str]: + """构建上游请求需要的运行时变量。""" + now = datetime.now() + return { + "{{USER_NAME}}": "Guest", + "{{USER_LOCATION}}": "Unknown", + "{{CURRENT_DATETIME}}": now.strftime("%Y-%m-%d %H:%M:%S"), + "{{CURRENT_DATE}}": now.strftime("%Y-%m-%d"), + "{{CURRENT_TIME}}": now.strftime("%H:%M:%S"), + "{{CURRENT_WEEKDAY}}": now.strftime("%A"), + "{{CURRENT_TIMEZONE}}": DEFAULT_TIMEZONE, + "{{USER_LANGUAGE}}": DEFAULT_LANGUAGE, + } + + def _build_browser_query_params( + self, + *, + chat_id: str, + token: str, + user_id: str, + user_agent: str, + timestamp_ms: int, + ) -> Dict[str, str]: + """构建 GLM-4.7 所需的浏览器指纹查询参数。""" + now = datetime.now(timezone.utc) + browser_name = "Chrome" + if "Edg/" in user_agent: + browser_name = "Microsoft Edge" + elif "Firefox/" in user_agent: + browser_name = "Firefox" + elif "Safari/" in user_agent and "Chrome/" not in user_agent: + browser_name = "Safari" + + return { + "version": DEFAULT_CLIENT_VERSION, + "platform": DEFAULT_PLATFORM, + "token": token, + "user_agent": user_agent, + "language": DEFAULT_LANGUAGE, + "languages": DEFAULT_LANGUAGE, + "timezone": DEFAULT_TIMEZONE, + "cookie_enabled": "true", + "screen_width": DEFAULT_SCREEN_WIDTH, + "screen_height": DEFAULT_SCREEN_HEIGHT, + "screen_resolution": DEFAULT_SCREEN_RESOLUTION, + "viewport_height": DEFAULT_VIEWPORT_HEIGHT, + "viewport_width": DEFAULT_VIEWPORT_WIDTH, + "viewport_size": DEFAULT_VIEWPORT_SIZE, + "color_depth": DEFAULT_COLOR_DEPTH, + "pixel_ratio": DEFAULT_PIXEL_RATIO, + "current_url": f"{self.base_url}/c/{chat_id}", + "pathname": f"/c/{chat_id}", + "search": "", + "hash": "", + "host": "chat.z.ai", + "hostname": "chat.z.ai", + "protocol": "https:", + "referrer": "", + "title": DEFAULT_PAGE_TITLE, + "timezone_offset": DEFAULT_TIMEZONE_OFFSET, + "local_time": ( + now.strftime("%Y-%m-%dT%H:%M:%S.") + + f"{now.microsecond // 1000:03d}Z" + ), + "utc_time": now.strftime("%a, %d %b %Y %H:%M:%S GMT"), + "is_mobile": "false", + "is_touch": "false", + "max_touch_points": DEFAULT_MAX_TOUCH_POINTS, + "browser_name": browser_name, + "os_name": "Windows", + "signature_timestamp": str(timestamp_ms), + } + + def _build_signed_completion_request( + self, + *, + prompt: str, + chat_id: str, + token: str, + user_id: str, + user_agent: str, + use_browser_fingerprint: bool, + ) -> Tuple[str, str, str]: + """构建上游 completions 的签名 URL 与请求头元数据。""" + timestamp_ms = int(time.time() * 1000) + request_id = generate_uuid() + core_params = { + "requestId": request_id, + "timestamp": str(timestamp_ms), + "user_id": user_id, + } + canonical_payload = ",".join( + f"{key},{value}" for key, value in sorted(core_params.items()) + ) + signature = generate_signature( + e=canonical_payload, + t=prompt or "", + s=timestamp_ms, + )["signature"] + query_params = dict(core_params) + if use_browser_fingerprint: + query_params.update( + self._build_browser_query_params( + chat_id=chat_id, + token=token, + user_id=user_id, + user_agent=user_agent, + timestamp_ms=timestamp_ms, + ) + ) + else: + query_params.update( + { + "token": token, + "version": DEFAULT_CLIENT_VERSION, + "platform": DEFAULT_PLATFORM, + "current_url": f"{self.base_url}/c/{chat_id}", + "pathname": f"/c/{chat_id}", + "signature_timestamp": str(timestamp_ms), + } + ) + + return ( + f"{self.api_endpoint}?{urlencode(query_params)}", + signature, + str(timestamp_ms), + ) + + async def _create_upstream_chat( + self, + *, + prompt: str, + model: str, + token: str, + headers: Dict[str, str], + enable_thinking: bool, + web_search: bool, + user_message_id: Optional[str] = None, + files: Optional[List[Dict[str, Any]]] = None, + feature_entries: Optional[List[Dict[str, Any]]] = None, + mcp_servers: Optional[List[str]] = None, + ) -> str: + """为 GLM-4.7 系列创建上游真实 chat 会话。""" + init_content = prompt[:CHAT_BOOTSTRAP_MAX_CONTENT_LEN] + if len(prompt) > CHAT_BOOTSTRAP_MAX_CONTENT_LEN: + init_content = init_content + "..." + + message_id = user_message_id or generate_uuid() + timestamp_seconds = int(time.time()) + chat_features = ( + [dict(item) for item in feature_entries] + if feature_entries + else [ + { + "type": "tool_selector", + "server": "tool_selector_h", + "status": "hidden", + } + ] + ) + body = { + "chat": { + "id": "", + "title": "新聊天", + "models": [model], + "params": {}, + "history": { + "messages": { + message_id: { + "id": message_id, + "parentId": None, + "childrenIds": [], + "role": "user", + "content": init_content, + **({"files": [dict(item) for item in files]} if files else {}), + "timestamp": timestamp_seconds, + "models": [model], + } + }, + "currentId": message_id, + }, + "tags": [], + "flags": [], + "features": chat_features, + "mcp_servers": list(mcp_servers or []), + "enable_thinking": enable_thinking, + "auto_web_search": web_search, + "message_version": 1, + "extra": {}, + "timestamp": int(time.time() * 1000), + } + } + request_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {token}", + "User-Agent": headers["User-Agent"], + "Accept-Language": headers.get("Accept-Language", DEFAULT_LANGUAGE), + "Origin": self.base_url, + "Referer": f"{self.base_url}/", + } + async with httpx.AsyncClient( + base_url=self.base_url, + timeout=self._build_timeout(), + limits=self._build_limits(), + proxy=self._get_proxy_config(), + follow_redirects=True, + ) as client: + response = await client.post( + "/api/v1/chats/new", + headers=request_headers, + json=body, + ) + + if response.status_code != 200: + raise RuntimeError( + f"上游创建 chat 失败: {response.status_code} {response.text}" + ) + + payload = response.json() + chat_id = str(payload.get("id") or payload.get("chat", {}).get("id") or "") + if not chat_id: + raise RuntimeError("上游创建 chat 成功但未返回 chat_id") + return chat_id + + def _build_glm47_completion_body( + self, + *, + model: str, + messages: List[Dict[str, Any]], + prompt: str, + chat_id: str, + enable_thinking: bool, + web_search: bool, + files: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]], + tool_choice: Any, + temperature: Optional[float], + max_tokens: Optional[int], + mcp_servers: List[str], + preview_mode: bool, + feature_entries: Optional[List[Dict[str, Any]]], + message_id: str, + current_user_message_id: str, + current_user_message_parent_id: Optional[str], + ) -> Dict[str, Any]: + """构建兼容持久化 chat 模型的精简 completions 请求体。""" + params: Dict[str, Any] = {} + if temperature is not None: + params["temperature"] = temperature + if max_tokens is not None: + params["max_tokens"] = max_tokens + + body: Dict[str, Any] = { + "stream": True, + "model": model, + "messages": messages, + "signature_prompt": prompt, + "params": params, + "extra": {}, + "features": { + "image_generation": False, + "web_search": web_search, + "auto_web_search": web_search, + "preview_mode": preview_mode, + "flags": [], + "enable_thinking": enable_thinking, + }, + "variables": self._build_request_variables(), + "chat_id": chat_id, + "id": message_id, + "current_user_message_id": current_user_message_id, + "current_user_message_parent_id": current_user_message_parent_id, + "background_tasks": { + "title_generation": True, + "tags_generation": True, + }, + } + if files: + body["files"] = files + if mcp_servers: + body["mcp_servers"] = mcp_servers + if tools: + body["tools"] = tools + if tool_choice is not None: + body["tool_choice"] = tool_choice + return body + + def _clean_reasoning_delta(self, delta_content: str) -> str: + """清理思考阶段的 details 包裹内容。""" + if not delta_content: + return "" + + if delta_content.startswith("\n>" in delta_content: + return delta_content.split("\n>")[-1].strip() + if "\n" in delta_content: + return delta_content.split("\n")[-1].lstrip("> ").strip() + + return delta_content + + def _extract_answer_content(self, text: str) -> str: + """提取思考结束后的答案正文。""" + if not text: + return "" + + if "\n" in text: + return text.split("\n")[-1] + + if "" in text: + return text.split("")[-1].lstrip() + + return text + + def _normalize_tool_calls( + self, + raw_tool_calls: Any, + start_index: int = 0, + ) -> List[Dict[str, Any]]: + """标准化上游工具调用为 OpenAI 兼容格式。""" + if not raw_tool_calls: + return [] + + tool_calls = raw_tool_calls if isinstance(raw_tool_calls, list) else [raw_tool_calls] + normalized: List[Dict[str, Any]] = [] + + for offset, tool_call in enumerate(tool_calls): + if not isinstance(tool_call, dict): + continue + + function_data = tool_call.get("function") or {} + normalized.append( + { + "index": tool_call.get("index", start_index + offset), + "id": tool_call.get("id") or f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": function_data.get("name", ""), + "arguments": function_data.get("arguments", ""), + }, + } + ) + + return normalized + + def _format_search_results(self, data: Dict[str, Any]) -> str: + """将上游搜索结果格式化为可追加的 Markdown 引用。""" + search_info = data.get("results") or data.get("sources") or data.get("citations") + if not isinstance(search_info, list) or not search_info: + return "" + + citations = [] + for index, item in enumerate(search_info, 1): + if not isinstance(item, dict): + continue + + title = item.get("title") or item.get("name") or f"Result {index}" + url = item.get("url") or item.get("link") + if url: + citations.append(f"[{index}] [{title}]({url})") + + if not citations: + return "" + + return "\n\n---\n" + "\n".join(citations) + + def _get_proxy_config(self) -> Optional[str]: + """Get proxy configuration from settings""" + # In httpx 0.28.1, proxy parameter expects a single URL string + # Support HTTP_PROXY, HTTPS_PROXY and SOCKS5_PROXY + + if settings.HTTPS_PROXY: + self.logger.info(f"🔄 使用HTTPS代理: {settings.HTTPS_PROXY}") + return settings.HTTPS_PROXY + + if settings.HTTP_PROXY: + self.logger.info(f"🔄 使用HTTP代理: {settings.HTTP_PROXY}") + return settings.HTTP_PROXY + + if settings.SOCKS5_PROXY: + self.logger.info(f"🔄 使用SOCKS5代理: {settings.SOCKS5_PROXY}") + return settings.SOCKS5_PROXY + + return None + + def _build_timeout(self, read_timeout: float = 30.0) -> httpx.Timeout: + """Create httpx timeout settings tuned for upstream chat traffic.""" + return httpx.Timeout( + connect=5.0, + read=read_timeout, + write=10.0, + pool=5.0, + ) + + def _build_limits(self) -> httpx.Limits: + """Create conservative connection-pool limits for upstream requests.""" + return httpx.Limits( + max_keepalive_connections=5, + max_connections=10, + ) + + async def _fetch_direct_guest_auth(self) -> Dict[str, Any]: + """匿名号池缺席时,兜底直连拉取一个访客令牌。""" + max_retries = 3 + + for retry_count in range(max_retries): + try: + headers = get_dynamic_headers() + self.logger.debug( + f"尝试获取访客令牌 (第{retry_count + 1}次): {self.auth_url}" + ) + + proxies = self._get_proxy_config() + async with httpx.AsyncClient( + timeout=self._build_timeout(), + follow_redirects=True, + limits=self._build_limits(), + proxy=proxies, + ) as client: + response = await client.get(self.auth_url, headers=headers) + + if response.status_code == 200: + data = response.json() + token = str(data.get("token") or "").strip() + if token: + user_id = str( + data.get("id") + or data.get("user_id") + or _extract_user_id_from_token(token) + ) + username = str( + data.get("name") + or str(data.get("email") or "").split("@")[0] + or "Guest" + ) + self.logger.info( + f"✅ 直连获取匿名令牌成功: {token[:20]}..." + ) + return { + "token": token, + "user_id": user_id, + "username": username or "Guest", + "auth_mode": "guest", + "token_source": "guest_direct", + "guest_user_id": user_id, + } + + self.logger.warning(f"响应中未找到 token 字段: {data}") + elif response.status_code == 405: + self.logger.error( + "🚫 请求被 WAF 拦截 (405),无法直连获取匿名令牌" + ) + break + else: + self.logger.warning( + f"直连获取匿名令牌失败,状态码: {response.status_code}" + ) + except httpx.TimeoutException as exc: + self.logger.warning( + f"直连获取匿名令牌超时 (第{retry_count + 1}次): {exc}" + ) + except httpx.ConnectError as exc: + self.logger.warning( + f"直连获取匿名令牌连接错误 (第{retry_count + 1}次): {exc}" + ) + except json.JSONDecodeError as exc: + self.logger.warning( + f"直连获取匿名令牌 JSON 解析错误 (第{retry_count + 1}次): {exc}" + ) + except Exception as exc: + self.logger.warning( + f"直连获取匿名令牌失败 (第{retry_count + 1}次): {exc}" + ) + + if retry_count + 1 < max_retries: + await asyncio.sleep(2) + + return { + "token": "", + "user_id": "guest", + "username": "Guest", + "auth_mode": "guest", + "token_source": "guest_direct", + "guest_user_id": None, + } + + async def get_auth_info( + self, + excluded_tokens: Optional[Set[str]] = None, + excluded_guest_user_ids: Optional[Set[str]] = None, + ) -> Dict[str, Any]: + """优先获取认证 Token,必要时回退匿名会话池。""" + token_pool = get_token_pool() + if token_pool: + token = token_pool.get_next_token(exclude_tokens=excluded_tokens) + if token: + user_id = _extract_user_id_from_token(token) + self.logger.debug(f"从认证号池获取令牌: {token[:20]}...") + return { + "token": token, + "user_id": user_id, + "username": "User", + "auth_mode": "authenticated", + "token_source": "auth_pool", + "guest_user_id": None, + } + + if settings.ANONYMOUS_MODE: + guest_pool = get_guest_session_pool() + if guest_pool: + try: + session = await guest_pool.acquire( + exclude_user_ids=excluded_guest_user_ids + ) + self.logger.info( + "🫥 认证池不可用,回退匿名会话池: " + f"user_id={session.user_id}" + ) + return { + "token": session.token, + "user_id": session.user_id, + "username": session.username, + "auth_mode": "guest", + "token_source": "guest_pool", + "guest_user_id": session.user_id, + } + except Exception as exc: + self.logger.warning(f"匿名会话池获取失败,转为直连访客鉴权: {exc}") + + return await self._fetch_direct_guest_auth() + + self.logger.error("❌ 无法获取有效的上游令牌") + return { + "token": "", + "user_id": "", + "username": "", + "auth_mode": "authenticated", + "token_source": "none", + "guest_user_id": None, + } + + async def mark_token_failure(self, token: str, error: Exception = None): + """标记token使用失败""" + token_pool = get_token_pool() + if token_pool: + await token_pool.record_token_failure(token, error) + + async def upload_image( + self, + data_url: str, + chat_id: str, + token: str, + user_id: str, + auth_mode: str = "authenticated", + ) -> Optional[Dict]: + """上传 base64 编码的图片到上游服务器。 + + Args: + data_url: data:image/xxx;base64,... 格式的图片数据 + chat_id: 当前对话ID + token: 认证令牌 + user_id: 用户ID + auth_mode: 当前鉴权模式,guest 模式下禁止上传 + + Returns: + 上传成功返回完整的文件信息字典,失败返回 None + """ + if auth_mode == "guest" or not data_url.startswith("data:"): + return None + + try: + # 解析 data URL + header, encoded = data_url.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] if ":" in header else "image/jpeg" + + # 解码 base64 数据 + image_data = base64.b64decode(encoded) + filename = str(uuid.uuid4()) + + self.logger.debug(f"📤 上传图片: {filename}, 大小: {len(image_data)} bytes") + + # 构建上传请求 + upload_url = f"{self.base_url}/api/v1/files/" + headers = { + "Accept": "*/*", + "Accept-Language": "zh-CN,zh;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Origin": f"{self.base_url}", + "Pragma": "no-cache", + "Referer": ( + f"{self.base_url}/c/{chat_id}" if chat_id else f"{self.base_url}/" + ), + "Sec-Ch-Ua": '"Microsoft Edge";v="141", "Not?A_Brand";v="8", "Chromium";v="141"', + "Sec-Ch-Ua-Mobile": "?0", + "Sec-Ch-Ua-Platform": '"Windows"', + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-origin", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36 Edg/141.0.0.0", + "Authorization": f"Bearer {token}", + } + + # Get proxy configuration + proxies = self._get_proxy_config() + + # 使用 httpx 上传文件 + async with httpx.AsyncClient( + timeout=self._build_timeout(), + limits=self._build_limits(), + proxy=proxies, + ) as client: + files = { + "file": (filename, image_data, mime_type) + } + response = await client.post(upload_url, files=files, headers=headers) + + if response.status_code == 200: + result = response.json() + file_id = result.get("id") + file_name = result.get("filename") + file_size = len(image_data) + + self.logger.info(f"✅ 图片上传成功: {file_id}_{file_name}") + + # 返回符合上游格式的文件信息 + current_timestamp = int(time.time()) + return { + "type": "image", + "file": { + "id": file_id, + "user_id": user_id, + "hash": None, + "filename": file_name, + "data": {}, + "meta": { + "name": file_name, + "content_type": mime_type, + "size": file_size, + "data": {}, + }, + "created_at": current_timestamp, + "updated_at": current_timestamp + }, + "id": file_id, + "url": f"/api/v1/files/{file_id}/content", + "name": file_name, + "status": "uploaded", + "size": file_size, + "error": "", + "itemId": str(uuid.uuid4()), + "media": "image" + } + else: + self.logger.error(f"❌ 图片上传失败: {response.status_code} - {response.text}") + return None + + except Exception as e: + self.logger.error(f"❌ 图片上传异常: {e}") + return None + + async def transform_request( + self, + request: OpenAIRequest, + excluded_tokens: Optional[Set[str]] = None, + excluded_guest_user_ids: Optional[Set[str]] = None, + ) -> Dict[str, Any]: + """转换 OpenAI 请求为上游格式。""" + self.logger.info(f"🔄 转换 OpenAI 请求到上游格式: {request.model}") + + raw_messages = [ + message.model_dump(exclude_none=True) + for message in request.messages + ] + normalized_messages = _preprocess_openai_messages(raw_messages) + + auth_info = await self.get_auth_info( + excluded_tokens=excluded_tokens, + excluded_guest_user_ids=excluded_guest_user_ids, + ) + token = str(auth_info.get("token") or "") + if not token: + raise RuntimeError("无法获取上游认证令牌") + + user_id = str(auth_info.get("user_id") or _extract_user_id_from_token(token)) + auth_mode = str(auth_info.get("auth_mode") or "authenticated") + token_source = str(auth_info.get("token_source") or "unknown") + guest_user_id = auth_info.get("guest_user_id") + # 确定请求的模型特性 + last_user_text = _extract_last_user_text(raw_messages) + requested_model = request.model + is_thinking_model = "-thinking" in requested_model.casefold() + is_search_model = "-search" in requested_model.casefold() + is_advanced_search = requested_model == settings.GLM47_ADVANCED_SEARCH_MODEL + upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API") + tools = request.tools if settings.TOOL_SUPPORT and request.tools else None + tool_choice = getattr(request, "tool_choice", None) + model_profile = self._get_model_request_profile(upstream_model_id) + enable_thinking = request.enable_thinking + if enable_thinking is None: + default_enable_thinking = model_profile["default_enable_thinking"] + enable_thinking = ( + default_enable_thinking + if default_enable_thinking is not None + else is_thinking_model + ) + + web_search = request.web_search + if web_search is None: + web_search = is_search_model or is_advanced_search + + use_persisted_chat = bool(model_profile["use_persisted_chat"]) + preview_mode = bool(model_profile["preview_mode"]) + feature_entries = list(model_profile["feature_entries"]) + persisted_user_message_id = generate_uuid() if use_persisted_chat else None + persisted_assistant_message_id = generate_uuid() if use_persisted_chat else None + + mcp_servers = list(model_profile["mcp_servers"]) + if is_advanced_search and "advanced-search" not in mcp_servers: + mcp_servers.append("advanced-search") + self.logger.info("🔍 检测到高级搜索模型,添加 advanced-search MCP 服务器") + + headers = get_dynamic_headers( + browser_type="chrome" if use_persisted_chat else None, + ) + chat_id = generate_uuid() + + # 处理消息格式 - 上游使用单独的 files 字段传递图片 + messages = [] + files = [] + upload_chat_id = "" if use_persisted_chat else chat_id + + for msg in normalized_messages: + role = str(msg.get("role", "user")) + content = msg.get("content") + + if isinstance(content, str): + messages.append({"role": role, "content": content}) + continue + + if not isinstance(content, list): + continue + + text_parts = [] + image_parts = [] + for part in content: + image_url = None + if hasattr(part, "type"): + if part.type == "text" and hasattr(part, "text"): + text_parts.append(part.text or "") + elif part.type == "image_url" and hasattr(part, "image_url"): + if hasattr(part.image_url, "url"): + image_url = part.image_url.url + elif ( + isinstance(part.image_url, dict) + and "url" in part.image_url + ): + image_url = part.image_url["url"] + elif isinstance(part, dict): + if part.get("type") == "text": + text_parts.append(part.get("text", "")) + elif part.get("type") == "image_url": + image_url = part.get("image_url", {}).get("url", "") + elif isinstance(part, str): + text_parts.append(part) + + if not image_url: + continue + + self.logger.debug(f"✅ 检测到图片: {image_url[:50]}...") + if image_url.startswith("data:") and auth_mode != "guest": + self.logger.info("🔄 上传 base64 图片到上游服务") + file_info = await self.upload_image( + image_url, + upload_chat_id, + token, + user_id, + auth_mode=auth_mode, + ) + if not file_info: + self.logger.warning("⚠️ 图片上传失败") + text_parts.append("[系统提示: 图片上传失败]") + continue + + files.append(file_info) + self.logger.info("✅ 图片已添加到 files 数组") + if persisted_user_message_id: + file_info["ref_user_msg_id"] = persisted_user_message_id + image_ref = str(file_info["id"]) + image_parts.append( + { + "type": "image_url", + "image_url": {"url": image_ref}, + } + ) + self.logger.debug(f"📎 图片引用: {image_ref}") + continue + + if auth_mode != "guest": + self.logger.warning("⚠️ 非 base64 图片或匿名模式,保留原始URL") + image_parts.append( + { + "type": "image_url", + "image_url": {"url": image_url}, + } + ) + + message_content = [] + combined_text = " ".join(text_parts).strip() + if combined_text: + message_content.append({"type": "text", "text": combined_text}) + message_content.extend(image_parts) + if message_content: + messages.append({"role": role, "content": message_content}) + + if use_persisted_chat: + chat_id = await self._create_upstream_chat( + prompt=last_user_text, + model=upstream_model_id, + token=token, + headers=headers, + enable_thinking=enable_thinking, + web_search=web_search, + user_message_id=persisted_user_message_id, + files=files or None, + feature_entries=feature_entries or None, + mcp_servers=mcp_servers or None, + ) + self.logger.info(f"🧩 已为 {requested_model} 创建上游 chat: {chat_id}") + headers["Referer"] = f"{self.base_url}/c/{chat_id}" + + if use_persisted_chat: + body = self._build_glm47_completion_body( + model=upstream_model_id, + messages=messages, + prompt=last_user_text, + chat_id=chat_id, + enable_thinking=enable_thinking, + web_search=web_search, + files=files, + tools=tools, + tool_choice=tool_choice, + temperature=request.temperature, + max_tokens=request.max_tokens, + mcp_servers=mcp_servers, + preview_mode=preview_mode, + feature_entries=feature_entries or None, + message_id=persisted_assistant_message_id or generate_uuid(), + current_user_message_id=persisted_user_message_id or generate_uuid(), + current_user_message_parent_id=None, + ) + else: + message_id = generate_uuid() + session_id = generate_uuid() + body = { + "stream": True, + "model": upstream_model_id, + "messages": messages, + "signature_prompt": last_user_text, + "files": files, + "params": {}, + "extra": {}, + "features": { + "image_generation": False, + "web_search": web_search, + "auto_web_search": web_search, + "preview_mode": preview_mode, + "flags": [], + "features": [ + dict(item) + for item in (feature_entries or DEFAULT_COMPLETION_FEATURES) + ], + "enable_thinking": enable_thinking, + }, + "background_tasks": { + "title_generation": False, + "tags_generation": False, + }, + "mcp_servers": mcp_servers, + "variables": self._build_request_variables(), + "model_item": { + "id": upstream_model_id, + "name": requested_model, + "owned_by": settings.SERVICE_NAME, + }, + "chat_id": chat_id, + "id": message_id, + "session_id": session_id, + "current_user_message_id": message_id, + "current_user_message_parent_id": None, + } + if tools: + body["tools"] = tools + if tool_choice is not None: + body["tool_choice"] = tool_choice + self.logger.info(f"🔧 工具调用将直接透传到上游: {len(tools)} 个工具") + else: + body["tools"] = None + if request.temperature is not None: + body["params"]["temperature"] = request.temperature + if request.max_tokens is not None: + body["params"]["max_tokens"] = request.max_tokens + + try: + signed_url, signature, timestamp_ms = self._build_signed_completion_request( + prompt=last_user_text, + chat_id=chat_id, + token=token, + user_id=user_id, + user_agent=headers["User-Agent"], + use_browser_fingerprint=use_persisted_chat, + ) + logger.debug( + "[上游] 生成签名成功: %s... (user_id=%s, timestamp=%s)", + signature[:16], + user_id, + timestamp_ms, + ) + except Exception as e: + logger.error(f"[上游] 签名生成失败: {e}") + signature = "" + timestamp_ms = "0" + signed_url = self.api_endpoint + + fe_version = headers.get("X-FE-Version") or get_latest_fe_version() + headers.update( + { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "Accept": "*/*" if use_persisted_chat else "application/json", + "X-FE-Version": fe_version, + "X-Signature": signature, + } + ) + + logger.debug( + "[上游] 请求头: Authorization=Bearer *****, X-Signature=%s...", + signature[:16] if signature else "(空)", + ) + logger.debug( + "[上游] URL 参数: timestamp=%s, user_id=%s, persisted_chat=%s", + timestamp_ms, + user_id, + use_persisted_chat, + ) + + # 存储当前token用于错误处理 + self._current_token = token + + return { + "url": signed_url, + "headers": headers, + "body": body, + "token": token, + "chat_id": chat_id, + "model": requested_model, + "user_id": user_id, + "auth_mode": auth_mode, + "token_source": token_source, + "guest_user_id": guest_user_id, + } + + async def chat_completion( + self, + request: OpenAIRequest, + **kwargs + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """聊天完成接口。""" + self.logger.info(f"🔄 {self.name} 处理请求: {request.model}") + self.logger.debug(f" 消息数量: {len(request.messages)}") + self.logger.debug(f" 流式模式: {request.stream}") + + try: + transformed = await self.transform_request(request) + + if request.stream: + return self._create_stream_response(request, transformed) + + proxies = self._get_proxy_config() + max_attempts = self._get_total_retry_limit() + excluded_tokens: Set[str] = set() + excluded_guest_user_ids: Set[str] = set() + + for attempt in range(max_attempts): + async with httpx.AsyncClient( + timeout=self._build_timeout(read_timeout=60.0), + limits=self._build_limits(), + proxy=proxies, + ) as client: + response = await client.post( + transformed["url"], + headers=transformed["headers"], + json=transformed["body"], + ) + + error_code, error_message = self._extract_upstream_error_details( + response.status_code, + response.text, + ) + is_concurrency_limited = self._is_concurrency_limited( + response.status_code, + error_code, + error_message, + ) + + if self._should_retry_guest_session( + response.status_code, + is_concurrency_limited, + attempt, + max_attempts, + transformed, + ): + guest_user_id = str( + transformed.get("guest_user_id") + or transformed.get("user_id") + or "" + ) + if guest_user_id: + excluded_guest_user_ids.add(guest_user_id) + transformed = await self._refresh_guest_request( + request, + attempt, + excluded_tokens, + excluded_guest_user_ids, + transformed, + is_concurrency_limited=is_concurrency_limited, + ) + continue + + if self._should_retry_authenticated_session( + response.status_code, + is_concurrency_limited, + attempt, + max_attempts, + transformed, + ): + current_token = str(transformed.get("token") or "") + if current_token: + excluded_tokens.add(current_token) + await self.mark_token_failure( + current_token, + Exception(error_message or "上游认证会话不可用"), + ) + self.logger.warning( + "⚠️ 认证会话不可用,准备切换认证 Token/回退匿名池: " + f"{current_token[:20]}..." + ) + transformed = await self._refresh_authenticated_request( + request, + attempt, + excluded_tokens, + excluded_guest_user_ids, + ) + continue + + if not response.is_success: + error_msg = f"上游 API 错误: {response.status_code}" + if not self._is_guest_auth(transformed): + current_token = str(transformed.get("token") or "") + if current_token: + await self.mark_token_failure( + current_token, + Exception(error_message or error_msg), + ) + await self._release_guest_session(transformed) + self.logger.error(f"❌ {self.name} 响应失败: {error_msg}") + return handle_error(Exception(error_message or error_msg)) + + try: + result = await self.transform_response(response, request, transformed) + finally: + await self._release_guest_session(transformed) + + if not self._is_guest_auth(transformed): + current_token = str(transformed.get("token") or "") + if current_token: + token_pool = get_token_pool() + if token_pool: + await token_pool.record_token_success(current_token) + + return result + + except Exception as e: + self.logger.error(f"❌ {self.name} 响应失败: {str(e)}") + return handle_error(e, "请求处理") + + async def _create_stream_response( + self, + request: OpenAIRequest, + transformed: Dict[str, Any] + ) -> AsyncGenerator[str, None]: + """创建流式响应,并在首包前支持双池重试。""" + max_attempts = self._get_total_retry_limit() + excluded_tokens: Set[str] = set() + excluded_guest_user_ids: Set[str] = set() + current_token = str(transformed.get("token") or "") + + try: + proxies = self._get_proxy_config() + + async with httpx.AsyncClient( + timeout=self._build_timeout(read_timeout=180.0), + http2=True, + limits=self._build_limits(), + proxy=proxies, + ) as client: + for attempt in range(max_attempts): + self.logger.info(f"🎯 发送请求到上游: {transformed['url']}") + async with client.stream( + "POST", + transformed["url"], + json=transformed["body"], + headers=transformed["headers"], + ) as response: + error_text = await response.aread() if response.status_code != 200 else b"" + error_msg = error_text.decode("utf-8", errors="ignore") + error_code, parsed_error_message = ( + self._extract_upstream_error_details( + response.status_code, + error_msg, + ) + if response.status_code != 200 + else (None, "") + ) + is_concurrency_limited = self._is_concurrency_limited( + response.status_code, + error_code, + parsed_error_message, + ) + + if self._should_retry_guest_session( + response.status_code, + is_concurrency_limited, + attempt, + max_attempts, + transformed, + ): + guest_user_id = str( + transformed.get("guest_user_id") + or transformed.get("user_id") + or "" + ) + if guest_user_id: + excluded_guest_user_ids.add(guest_user_id) + transformed = await self._refresh_guest_request( + request, + attempt, + excluded_tokens, + excluded_guest_user_ids, + transformed, + is_concurrency_limited=is_concurrency_limited, + ) + current_token = str(transformed.get("token") or "") + continue + + if self._should_retry_authenticated_session( + response.status_code, + is_concurrency_limited, + attempt, + max_attempts, + transformed, + ): + if current_token: + excluded_tokens.add(current_token) + await self.mark_token_failure( + current_token, + Exception( + parsed_error_message or "上游认证会话不可用" + ), + ) + self.logger.warning( + "⚠️ 流式请求命中认证会话限制,准备切号/回退匿名池: " + f"{current_token[:20]}..." + ) + transformed = await self._refresh_authenticated_request( + request, + attempt, + excluded_tokens, + excluded_guest_user_ids, + ) + current_token = str(transformed.get("token") or "") + continue + + if response.status_code != 200: + self.logger.error(f"❌ 上游返回错误: {response.status_code}") + if error_msg: + self.logger.error(f"❌ 错误详情: {error_msg}") + + if not self._is_guest_auth(transformed) and current_token: + await self.mark_token_failure( + current_token, + Exception( + parsed_error_message + or f"Upstream error: {response.status_code}" + ), + ) + await self._release_guest_session(transformed) + + if response.status_code == 405: + self.logger.error( + "🚫 请求被上游 WAF 拦截,可能是请求头或签名异常" + ) + error_response = { + "error": { + "message": ( + "请求被上游WAF拦截(405 Method Not Allowed)," + "可能是请求头或签名异常,请稍后重试..." + ), + "type": "waf_blocked", + "code": 405, + } + } + else: + error_response = { + "error": { + "message": parsed_error_message + or f"Upstream error: {response.status_code}", + "type": "upstream_error", + "code": error_code or response.status_code, + } + } + yield f"data: {json.dumps(error_response)}\n\n" + yield "data: [DONE]\n\n" + return + + chat_id = transformed["chat_id"] + model = transformed["model"] + try: + async for chunk in self._handle_stream_response( + response, + chat_id, + model, + request, + transformed, + ): + yield chunk + finally: + await self._release_guest_session(transformed) + + if not self._is_guest_auth(transformed) and current_token: + token_pool = get_token_pool() + if token_pool: + await token_pool.record_token_success(current_token) + return + except Exception as e: + self.logger.error(f"❌ 流处理错误: {e}") + import traceback + self.logger.error(traceback.format_exc()) + if self._is_guest_auth(transformed): + await self._release_guest_session(transformed) + elif current_token: + await self.mark_token_failure(current_token, e) + + error_response = { + "error": { + "message": str(e), + "type": "stream_error" + } + } + yield f"data: {json.dumps(error_response)}\n\n" + yield "data: [DONE]\n\n" + return + + async def transform_response( + self, + response: httpx.Response, + request: OpenAIRequest, + transformed: Dict[str, Any] + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """转换上游响应为 OpenAI 格式。""" + chat_id = transformed["chat_id"] + model = transformed["model"] + + if request.stream: + return self._handle_stream_response(response, chat_id, model, request, transformed) + else: + return await self._handle_non_stream_response(response, chat_id, model) + + async def _handle_stream_response( + self, + response: httpx.Response, + chat_id: str, + model: str, + request: OpenAIRequest, + transformed: Dict[str, Any] + ) -> AsyncGenerator[str, None]: + """处理上游流式响应""" + self.logger.info("✅ 上游响应成功,开始处理 SSE 流") + + has_tools = settings.TOOL_SUPPORT and bool(request.tools) + buffered_content = "" + usage_info: Dict[str, int] = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + tool_calls_accum: List[Dict[str, Any]] = [] + has_sent_role = False + finished = False + line_count = 0 + + async def ensure_role_sent() -> Optional[str]: + nonlocal has_sent_role + if has_sent_role: + return None + + has_sent_role = True + return await format_sse_chunk( + create_openai_chunk(chat_id, model, {"role": "assistant"}) + ) + + async def finalize_stream() -> AsyncGenerator[str, None]: + nonlocal finished, tool_calls_accum + if finished: + return + + if has_tools and not tool_calls_accum: + parsed_tool_calls, _ = parse_and_extract_tool_calls(buffered_content) + normalized = self._normalize_tool_calls(parsed_tool_calls) + if normalized: + tool_calls_accum = normalized + role_output = await ensure_role_sent() + if role_output: + yield role_output + for tool_call in normalized: + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"tool_calls": [tool_call]}, + ) + ) + + if not has_sent_role: + role_output = await ensure_role_sent() + if role_output: + yield role_output + + finish_reason = "tool_calls" if tool_calls_accum else "stop" + finish_chunk = create_openai_chunk( + chat_id, + model, + {}, + finish_reason, + ) + finish_chunk["usage"] = usage_info + yield await format_sse_chunk(finish_chunk) + yield "data: [DONE]\n\n" + finished = True + + try: + async for line in response.aiter_lines(): + line_count += 1 + if not line: + continue + + current_line = line.strip() + if not current_line.startswith("data:"): + continue + + chunk_str = current_line[5:].strip() + if not chunk_str: + continue + + if chunk_str == "[DONE]": + async for final_chunk in finalize_stream(): + yield final_chunk + continue + + try: + chunk = json.loads(chunk_str) + except json.JSONDecodeError as error: + self.logger.debug(f"❌ JSON解析错误: {error}, 内容: {chunk_str[:1000]}") + continue + + chunk_type = chunk.get("type") + data = chunk.get("data", {}) if chunk_type == "chat:completion" else chunk + if not isinstance(data, dict): + continue + + phase = data.get("phase") + delta_content = data.get("delta_content", "") + edit_content = data.get("edit_content", "") + + if phase and phase != getattr(self, "_last_phase", None): + self.logger.info(f"📈 SSE 阶段: {phase}") + self._last_phase = phase + + if data.get("usage"): + usage_info = data["usage"] + + if delta_content: + buffered_content += delta_content + elif edit_content: + buffered_content += edit_content + + direct_tool_calls = self._normalize_tool_calls( + data.get("tool_calls"), + len(tool_calls_accum), + ) + if direct_tool_calls: + role_output = await ensure_role_sent() + if role_output: + yield role_output + tool_calls_accum.extend(direct_tool_calls) + for tool_call in direct_tool_calls: + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"tool_calls": [tool_call]}, + ) + ) + + if phase == "thinking" and delta_content: + cleaned = self._clean_reasoning_delta(delta_content) + if cleaned: + role_output = await ensure_role_sent() + if role_output: + yield role_output + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"reasoning_content": cleaned}, + ) + ) + + elif phase == "answer": + text = delta_content or self._extract_answer_content(edit_content) + if text: + role_output = await ensure_role_sent() + if role_output: + yield role_output + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"content": text}, + ) + ) + + elif phase == "other": + other_text = self._extract_answer_content(edit_content) + if other_text: + role_output = await ensure_role_sent() + if role_output: + yield role_output + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"content": other_text}, + ) + ) + + elif phase == "search" or chunk_type == "web_search": + citation_text = self._format_search_results(data) + if citation_text: + role_output = await ensure_role_sent() + if role_output: + yield role_output + yield await format_sse_chunk( + create_openai_chunk( + chat_id, + model, + {"content": citation_text}, + ) + ) + + if data.get("done"): + async for final_chunk in finalize_stream(): + yield final_chunk + return + + self.logger.info(f"✅ SSE 流处理完成,共处理 {line_count} 行数据") + + if not finished: + async for final_chunk in finalize_stream(): + yield final_chunk + + except Exception as e: + self.logger.error(f"❌ 流式响应处理错误: {e}") + import traceback + self.logger.error(traceback.format_exc()) + yield await format_sse_chunk( + create_openai_chunk(chat_id, model, {}, "stop") + ) + yield "data: [DONE]\n\n" + + async def _handle_non_stream_response( + self, + response: httpx.Response, + chat_id: str, + model: str + ) -> Dict[str, Any]: + """处理非流式响应,聚合上游 SSE 为一次性 OpenAI 响应。""" + final_content = "" + reasoning_content = "" + tool_calls_accum: List[Dict[str, Any]] = [] + usage_info: Dict[str, int] = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + + try: + async for line in response.aiter_lines(): + if not line: + continue + + line = line.strip() + if not line.startswith("data:"): + try: + maybe_err = json.loads(line) + if isinstance(maybe_err, dict) and ( + "error" in maybe_err or "code" in maybe_err or "message" in maybe_err + ): + msg = ( + (maybe_err.get("error") or {}).get("message") + if isinstance(maybe_err.get("error"), dict) + else maybe_err.get("message") + ) or "上游返回错误" + return handle_error(Exception(msg), "API响应") + except Exception: + pass + continue + + data_str = line[5:].strip() + if not data_str or data_str in ("[DONE]", "DONE", "done"): + continue + + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + chunk_type = chunk.get("type") + data = chunk.get("data", {}) if chunk_type == "chat:completion" else chunk + if not isinstance(data, dict): + continue + + phase = data.get("phase") + delta_content = data.get("delta_content", "") + edit_content = data.get("edit_content", "") + + if data.get("usage"): + usage_info = data["usage"] + + if phase == "thinking" and delta_content: + reasoning_content += self._clean_reasoning_delta(delta_content) + + elif phase == "answer": + if delta_content: + final_content += delta_content + elif edit_content: + final_content += self._extract_answer_content(edit_content) + + elif phase == "other" and edit_content: + final_content += self._extract_answer_content(edit_content) + + elif phase == "search" or chunk_type == "web_search": + final_content += self._format_search_results(data) + + tool_calls_accum.extend( + self._normalize_tool_calls( + data.get("tool_calls"), + len(tool_calls_accum), + ) + ) + + except Exception as e: + self.logger.error(f"❌ 非流式响应处理错误: {e}") + import traceback + self.logger.error(traceback.format_exc()) + return handle_error(e, "非流式聚合") + + if not tool_calls_accum: + parsed_tool_calls, cleaned_content = parse_and_extract_tool_calls(final_content) + normalized = self._normalize_tool_calls(parsed_tool_calls) + if normalized: + tool_calls_accum = normalized + final_content = cleaned_content + + final_content = (final_content or "").strip() + reasoning_content = (reasoning_content or "").strip() + + if not final_content and reasoning_content: + final_content = reasoning_content + + return create_openai_response_with_reasoning( + chat_id, + model, + final_content, + reasoning_content, + usage_info, + tool_calls_accum or None, + ) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..daad5f4488268bebe2c495c63a408caf0ef1c881 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from app.models import schemas + +__all__ = ["schemas"] diff --git a/app/models/request_log.py b/app/models/request_log.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec9f980e0e40ad228ddbe99ce53730f15d6a2eb --- /dev/null +++ b/app/models/request_log.py @@ -0,0 +1,35 @@ +"""请求日志数据库模型。""" + +from app.core.config import settings + +DB_PATH = settings.DB_PATH + +# 创建请求日志表的SQL +SQL_CREATE_REQUEST_LOGS_TABLE = """ +CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + provider TEXT NOT NULL, + endpoint TEXT DEFAULT '', + source TEXT DEFAULT 'unknown', + protocol TEXT DEFAULT 'unknown', + client_name TEXT DEFAULT 'Unknown', + model TEXT NOT NULL, + status_code INTEGER DEFAULT 200, + success BOOLEAN NOT NULL, + duration REAL, + first_token_time REAL, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + cache_creation_tokens INTEGER DEFAULT 0, + cache_read_tokens INTEGER DEFAULT 0, + total_tokens INTEGER DEFAULT 0, + error_message TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_request_logs_timestamp ON request_logs(timestamp); +CREATE INDEX IF NOT EXISTS idx_request_logs_model ON request_logs(model); +CREATE INDEX IF NOT EXISTS idx_request_logs_provider ON request_logs(provider); +CREATE INDEX IF NOT EXISTS idx_request_logs_source ON request_logs(source); +""" diff --git a/app/models/schemas.py b/app/models/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4fa63e49155843f1cb4c5d717a6bc709369e4d --- /dev/null +++ b/app/models/schemas.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, List, Optional, Any, Union, Literal +from pydantic import BaseModel + + +class ImageUrl(BaseModel): + """Image URL model for vision content""" + url: str + + +class ContentPart(BaseModel): + """Content part model for OpenAI's new content format""" + + type: str + text: Optional[str] = None + image_url: Optional[ImageUrl] = None # 添加 image_url 字段 + + +class Message(BaseModel): + """Chat message model""" + + role: str + content: Optional[Union[str, List[ContentPart]]] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + tool_call_id: Optional[str] = None + name: Optional[str] = None + + +class OpenAIRequest(BaseModel): + """OpenAI-compatible request model""" + + model: str + messages: List[Message] + stream: Optional[bool] = False + temperature: Optional[float] = None + max_tokens: Optional[int] = None + tools: Optional[List[Dict[str, Any]]] = None + tool_choice: Optional[Any] = None + enable_thinking: Optional[bool] = None + web_search: Optional[bool] = None + + +class ModelItem(BaseModel): + """Model information item""" + + id: str + name: str + owned_by: str + + +class UpstreamRequest(BaseModel): + """Upstream service request model""" + + stream: bool + model: str + messages: List[Message] + params: Dict[str, Any] = {} + features: Dict[str, Any] = {} + signature_prompt: Optional[str] = None + files: Optional[List[Dict[str, Any]]] = None + extra: Optional[Dict[str, Any]] = None + background_tasks: Optional[Dict[str, bool]] = None + chat_id: Optional[str] = None + id: Optional[str] = None + session_id: Optional[str] = None + current_user_message_id: Optional[str] = None + current_user_message_parent_id: Optional[str] = None + mcp_servers: Optional[List[str]] = None + model_item: Optional[Dict[str, Any]] = {} # Model item dictionary + tools: Optional[List[Dict[str, Any]]] = None # Add tools field for OpenAI compatibility + tool_choice: Optional[Any] = None + variables: Optional[Dict[str, str]] = None + model_config = {"protected_namespaces": ()} + + +class Delta(BaseModel): + """Stream delta model""" + + role: Optional[str] = None + content: Optional[str] = "" or None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + + +class Choice(BaseModel): + """Response choice model""" + + index: int + message: Optional[Message] = None + delta: Optional[Delta] = None + finish_reason: Optional[str] = None + + +class Usage(BaseModel): + """Token usage statistics""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class OpenAIResponse(BaseModel): + """OpenAI-compatible response model""" + + id: str + object: str + created: int + model: str + choices: List[Choice] + usage: Optional[Usage] = None + + +class UpstreamError(BaseModel): + """Upstream error model""" + + detail: str + code: int + + +class UpstreamDataInner(BaseModel): + """Inner upstream data model""" + + error: Optional[UpstreamError] = None + + +class UpstreamDataData(BaseModel): + """Upstream data content model""" + + delta_content: str = "" + edit_content: str = "" + phase: str = "" + done: bool = False + results: Optional[List[Dict[str, Any]]] = None + sources: Optional[List[Dict[str, Any]]] = None + citations: Optional[List[Dict[str, Any]]] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + usage: Optional[Usage] = None + error: Optional[UpstreamError] = None + inner: Optional[UpstreamDataInner] = None + + +class UpstreamData(BaseModel): + """Upstream data model""" + + type: str + data: UpstreamDataData + error: Optional[UpstreamError] = None + + +class Model(BaseModel): + """Model information for listing""" + + id: str + object: str = "model" + created: int + owned_by: str + + +class ModelsResponse(BaseModel): + """Models list response model""" + + object: str = "list" + data: List[Model] diff --git a/app/models/token_db.py b/app/models/token_db.py new file mode 100644 index 0000000000000000000000000000000000000000..91795c5d1d51e1fa6c92c52ab3315fe99f8149a7 --- /dev/null +++ b/app/models/token_db.py @@ -0,0 +1,44 @@ +"""Token 数据库模型定义。""" + +from app.core.config import settings + +SQL_CREATE_TABLES = """ +-- Token 配置表 +CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider TEXT NOT NULL, -- 提供商: zai + token TEXT NOT NULL UNIQUE, -- Token 值(唯一) + token_type TEXT DEFAULT 'user', -- Token 类型: user, guest, unknown + is_enabled BOOLEAN DEFAULT 1, -- 是否启用 + priority INTEGER DEFAULT 0, -- 优先级(用于排序) + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(provider, token) -- 同一提供商内 Token 唯一 +); + +-- Token 使用统计表 +CREATE TABLE IF NOT EXISTS token_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER NOT NULL, + total_requests INTEGER DEFAULT 0, + successful_requests INTEGER DEFAULT 0, + failed_requests INTEGER DEFAULT 0, + last_success_time DATETIME, + last_failure_time DATETIME, + FOREIGN KEY (token_id) REFERENCES tokens(id) ON DELETE CASCADE +); + +-- 创建索引 +CREATE INDEX IF NOT EXISTS idx_tokens_provider ON tokens(provider); +CREATE INDEX IF NOT EXISTS idx_tokens_enabled ON tokens(is_enabled); +CREATE INDEX IF NOT EXISTS idx_token_stats_token_id ON token_stats(token_id); + +-- 触发器:自动更新 updated_at +CREATE TRIGGER IF NOT EXISTS update_tokens_timestamp +AFTER UPDATE ON tokens +BEGIN + UPDATE tokens SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; +END; +""" + +DB_PATH = settings.DB_PATH diff --git a/app/services/request_log_dao.py b/app/services/request_log_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..53e14f735c224091959c7e88f4f7ad5e71bd13f1 --- /dev/null +++ b/app/services/request_log_dao.py @@ -0,0 +1,630 @@ +""" +请求日志数据访问层 (DAO) +提供请求日志的 CRUD 操作和查询功能 +""" +import os +import sqlite3 +from contextlib import asynccontextmanager +from datetime import datetime, timedelta +from typing import Dict, List, Optional + +import aiosqlite + +from app.models.request_log import DB_PATH, SQL_CREATE_REQUEST_LOGS_TABLE +from app.utils.logger import logger + + +def _format_sqlite_datetime(value: datetime) -> str: + """格式化为 SQLite `CURRENT_TIMESTAMP` 兼容的时间字符串。""" + return value.strftime("%Y-%m-%d %H:%M:%S") + + +def _normalize_trend_window(window: Optional[str], days: Optional[int]) -> str: + """统一趋势窗口参数,兼容旧版 `days` 调用。""" + if window: + normalized = str(window).strip().lower() + elif days == 30: + normalized = "30d" + elif days == 1: + normalized = "24h" + else: + normalized = "7d" + + if normalized in {"24h", "7d", "30d"}: + return normalized + if normalized == "1d": + return "24h" + return "7d" + + +class RequestLogDAO: + """请求日志数据访问对象""" + + def __init__(self, db_path: str = DB_PATH): + """初始化 DAO""" + self.db_path = db_path + self._ensure_db_directory() + self._init_db() + + def _ensure_db_directory(self): + """确保数据库目录存在""" + db_dir = os.path.dirname(self.db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + + def _init_db(self): + """初始化数据库表""" + try: + conn = sqlite3.connect(self.db_path) + conn.executescript(SQL_CREATE_REQUEST_LOGS_TABLE) + self._ensure_columns(conn) + conn.commit() + conn.close() + logger.debug("请求日志表初始化成功") + except Exception as e: + logger.error(f"初始化请求日志表失败: {e}") + + def _ensure_columns(self, conn: sqlite3.Connection): + """为旧数据库补齐新增列。""" + cursor = conn.execute("PRAGMA table_info(request_logs)") + existing_columns = {row[1] for row in cursor.fetchall()} + required_columns = { + "endpoint": "TEXT DEFAULT ''", + "source": "TEXT DEFAULT 'unknown'", + "protocol": "TEXT DEFAULT 'unknown'", + "client_name": "TEXT DEFAULT 'Unknown'", + "status_code": "INTEGER DEFAULT 200", + "cache_creation_tokens": "INTEGER DEFAULT 0", + "cache_read_tokens": "INTEGER DEFAULT 0", + } + + for column, definition in required_columns.items(): + if column in existing_columns: + continue + conn.execute( + f"ALTER TABLE request_logs ADD COLUMN {column} {definition}" + ) + + @asynccontextmanager + async def get_connection(self): + """获取异步数据库连接""" + conn = await aiosqlite.connect(self.db_path) + conn.row_factory = aiosqlite.Row + try: + yield conn + finally: + await conn.close() + + async def add_log( + self, + provider: str, + endpoint: str, + source: str, + protocol: str, + client_name: str, + model: str, + status_code: int, + success: bool, + duration: float = 0.0, + first_token_time: float = 0.0, + input_tokens: int = 0, + output_tokens: int = 0, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, + total_tokens: Optional[int] = None, + error_message: str = None + ) -> int: + """ + 添加请求日志 + + Args: + provider: 提供商名称 + endpoint: 请求端点 + source: 请求来源标识 + protocol: 协议类型 + client_name: 客户端名称 + model: 模型名称 + status_code: 请求状态码 + success: 是否成功 + duration: 总耗时(秒) + first_token_time: 首字延迟(秒) + input_tokens: 输入 token 数 + output_tokens: 输出 token 数 + cache_creation_tokens: 缓存创建 token 数 + cache_read_tokens: 缓存命中 token 数 + total_tokens: 总 token 数 + error_message: 错误信息 + + Returns: + 日志 ID + """ + if total_tokens is None: + total_tokens = input_tokens + output_tokens + + async with self.get_connection() as conn: + cursor = await conn.execute( + """ + INSERT INTO request_logs + (provider, endpoint, source, protocol, client_name, model, + status_code, success, duration, first_token_time, + input_tokens, output_tokens, cache_creation_tokens, + cache_read_tokens, total_tokens, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + provider, + endpoint, + source, + protocol, + client_name, + model, + status_code, + success, + duration, + first_token_time, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_tokens, + error_message, + ) + ) + await conn.commit() + return cursor.lastrowid + + async def get_recent_logs( + self, + limit: int = 100, + offset: int = 0, + provider: str = None, + model: str = None, + success: bool = None, + source: str = None, + ) -> List[Dict]: + """ + 获取最近的请求日志 + + Args: + limit: 返回数量限制 + provider: 过滤提供商 + model: 过滤模型 + success: 过滤成功/失败状态 + + Returns: + 日志列表 + """ + query = "SELECT * FROM request_logs WHERE 1=1" + params = [] + + if provider: + query += " AND provider = ?" + params.append(provider) + + if model: + query += " AND model = ?" + params.append(model) + + if success is not None: + query += " AND success = ?" + params.append(success) + + if source: + query += " AND source = ?" + params.append(source) + + query += " ORDER BY timestamp DESC, id DESC LIMIT ? OFFSET ?" + params.extend([limit, max(0, offset)]) + + async with self.get_connection() as conn: + cursor = await conn.execute(query, params) + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + async def count_logs( + self, + provider: str = None, + model: str = None, + success: bool = None, + source: str = None, + ) -> int: + """统计日志总数。""" + query = "SELECT COUNT(*) AS total_count FROM request_logs WHERE 1=1" + params = [] + + if provider: + query += " AND provider = ?" + params.append(provider) + + if model: + query += " AND model = ?" + params.append(model) + + if success is not None: + query += " AND success = ?" + params.append(success) + + if source: + query += " AND source = ?" + params.append(source) + + async with self.get_connection() as conn: + cursor = await conn.execute(query, params) + row = await cursor.fetchone() + return int(row["total_count"] or 0) if row else 0 + + async def get_logs_by_time_range( + self, + start_time: datetime, + end_time: datetime, + provider: str = None, + model: str = None + ) -> List[Dict]: + """ + 按时间范围获取日志 + + Args: + start_time: 开始时间 + end_time: 结束时间 + provider: 过滤提供商 + model: 过滤模型 + + Returns: + 日志列表 + """ + query = "SELECT * FROM request_logs WHERE timestamp BETWEEN ? AND ?" + params = [ + _format_sqlite_datetime(start_time), + _format_sqlite_datetime(end_time), + ] + + if provider: + query += " AND provider = ?" + params.append(provider) + + if model: + query += " AND model = ?" + params.append(model) + + query += " ORDER BY timestamp DESC, id DESC" + + async with self.get_connection() as conn: + cursor = await conn.execute(query, params) + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + async def get_provider_request_stats(self, provider: Optional[str] = None) -> Dict: + """聚合请求日志统计,可按提供商过滤。""" + query = """ + SELECT + COUNT(*) as total_requests, + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successful_requests, + SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) as failed_requests, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(total_tokens) as total_tokens, + SUM(cache_creation_tokens) as cache_creation_tokens, + SUM(cache_read_tokens) as cache_read_tokens, + SUM( + CASE WHEN cache_creation_tokens > 0 THEN 1 ELSE 0 END + ) as cache_creation_requests, + SUM( + CASE WHEN cache_read_tokens > 0 THEN 1 ELSE 0 END + ) as cache_hit_requests, + AVG(duration) as avg_duration, + AVG( + CASE + WHEN first_token_time > 0 THEN first_token_time + ELSE NULL + END + ) as avg_first_token_time + FROM request_logs + """ + params: List[object] = [] + + if provider: + query += " WHERE provider = ?" + params.append(provider) + + try: + async with self.get_connection() as conn: + cursor = await conn.execute(query, params) + row = await cursor.fetchone() + + if not row: + return { + "total_requests": 0, + "successful_requests": 0, + "failed_requests": 0, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0, + "cache_creation_requests": 0, + "cache_hit_requests": 0, + "avg_duration": 0.0, + "avg_first_token_time": 0.0, + } + + return { + "total_requests": int(row["total_requests"] or 0), + "successful_requests": int(row["successful_requests"] or 0), + "failed_requests": int(row["failed_requests"] or 0), + "input_tokens": int(row["input_tokens"] or 0), + "output_tokens": int(row["output_tokens"] or 0), + "total_tokens": int(row["total_tokens"] or 0), + "cache_creation_tokens": int( + row["cache_creation_tokens"] or 0 + ), + "cache_read_tokens": int(row["cache_read_tokens"] or 0), + "cache_creation_requests": int( + row["cache_creation_requests"] or 0 + ), + "cache_hit_requests": int(row["cache_hit_requests"] or 0), + "avg_duration": float(row["avg_duration"] or 0.0), + "avg_first_token_time": float( + row["avg_first_token_time"] or 0.0 + ), + } + except Exception as e: + logger.error(f"❌ 获取请求统计失败: {e}") + return { + "total_requests": 0, + "successful_requests": 0, + "failed_requests": 0, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0, + "cache_creation_requests": 0, + "cache_hit_requests": 0, + "avg_duration": 0.0, + "avg_first_token_time": 0.0, + } + + async def get_provider_usage_trend( + self, + provider: Optional[str] = None, + days: Optional[int] = None, + *, + window: Optional[str] = None, + now: Optional[datetime] = None, + ) -> List[Dict]: + """按窗口聚合最近一段时间的请求与 token 趋势。""" + trend_window = _normalize_trend_window(window, days) + current_time = now or datetime.utcnow() + + if trend_window == "24h": + bucket_count = 24 + current_hour = current_time.replace( + minute=0, + second=0, + microsecond=0, + ) + start_time = current_hour - timedelta(hours=bucket_count - 1) + bucket_expression = "strftime('%Y-%m-%d %H:00:00', timestamp)" + row_key = "trend_bucket" + label_format = "%H:%M" + tooltip_format = "%Y-%m-%d %H:00" + rows = await self._query_usage_trend_rows( + provider, + start_time, + bucket_expression, + row_key, + ) + rows_by_bucket = {str(row[row_key]): dict(row) for row in rows} + trend: List[Dict] = [] + + for offset in range(bucket_count): + bucket_time = start_time + timedelta(hours=offset) + bucket_key = bucket_time.strftime("%Y-%m-%d %H:00:00") + trend.append( + self._build_usage_trend_point( + row=rows_by_bucket.get(bucket_key, {}), + bucket=bucket_key, + label=bucket_time.strftime(label_format), + tooltip_label=bucket_time.strftime(tooltip_format), + ) + ) + + return trend + + bucket_count = 30 if trend_window == "30d" else 7 + current_date = current_time.date() + start_date = current_date - timedelta(days=bucket_count - 1) + start_time = datetime.combine(start_date, datetime.min.time()) + rows = await self._query_usage_trend_rows( + provider, + start_time, + "DATE(timestamp)", + "trend_bucket", + ) + rows_by_bucket = { + str(row["trend_bucket"]): dict(row) + for row in rows + } + trend = [] + + for offset in range(bucket_count): + bucket_date = start_date + timedelta(days=offset) + bucket_key = bucket_date.isoformat() + trend.append( + self._build_usage_trend_point( + row=rows_by_bucket.get(bucket_key, {}), + bucket=bucket_key, + label=bucket_date.strftime("%m-%d"), + tooltip_label=bucket_date.strftime("%Y-%m-%d"), + ) + ) + + return trend + + async def _query_usage_trend_rows( + self, + provider: Optional[str], + start_time: datetime, + bucket_expression: str, + bucket_alias: str, + ) -> list[aiosqlite.Row]: + query = f""" + SELECT + {bucket_expression} as {bucket_alias}, + COUNT(*) as total_requests, + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successful_requests, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(total_tokens) as total_tokens, + SUM(cache_creation_tokens) as cache_creation_tokens, + SUM(cache_read_tokens) as cache_read_tokens + FROM request_logs + WHERE timestamp >= ? + """ + params: List[object] = [_format_sqlite_datetime(start_time)] + + if provider: + query += " AND provider = ?" + params.append(provider) + + query += f" GROUP BY {bucket_expression} ORDER BY {bucket_alias} ASC" + + async with self.get_connection() as conn: + cursor = await conn.execute(query, params) + return await cursor.fetchall() + + def _build_usage_trend_point( + self, + *, + row: Dict, + bucket: str, + label: str, + tooltip_label: str, + ) -> Dict: + total_requests = int(row.get("total_requests") or 0) + successful_requests = int(row.get("successful_requests") or 0) + cache_creation_tokens = int(row.get("cache_creation_tokens") or 0) + cache_read_tokens = int(row.get("cache_read_tokens") or 0) + + return { + "bucket": bucket, + "label": label, + "tooltip_label": tooltip_label, + "total_requests": total_requests, + "successful_requests": successful_requests, + "failed_requests": max(0, total_requests - successful_requests), + "input_tokens": int(row.get("input_tokens") or 0), + "output_tokens": int(row.get("output_tokens") or 0), + "total_tokens": int(row.get("total_tokens") or 0), + "cache_creation_tokens": cache_creation_tokens, + "cache_read_tokens": cache_read_tokens, + "cache_total_tokens": ( + cache_creation_tokens + cache_read_tokens + ), + "success_rate": round( + ( + successful_requests / total_requests * 100 + ) if total_requests > 0 else 0, + 1, + ), + } + + async def get_model_stats_from_db(self, hours: int = 24) -> Dict: + """ + 从数据库获取模型统计(最近N小时) + + Args: + hours: 小时数 + + Returns: + 模型统计数据 + """ + start_time = datetime.utcnow() - timedelta(hours=hours) + + async with self.get_connection() as conn: + cursor = await conn.execute( + """ + SELECT + model, + COUNT(*) as total, + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as success, + SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) as failed, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(total_tokens) as total_tokens, + AVG(duration) as avg_duration, + AVG(first_token_time) as avg_first_token_time + FROM request_logs + WHERE timestamp >= ? + GROUP BY model + ORDER BY total DESC + """, + (_format_sqlite_datetime(start_time),) + ) + rows = await cursor.fetchall() + + result = {} + for row in rows: + model = row['model'] + result[model] = { + 'total': row['total'], + 'success': row['success'], + 'failed': row['failed'], + 'input_tokens': row['input_tokens'] or 0, + 'output_tokens': row['output_tokens'] or 0, + 'total_tokens': row['total_tokens'] or 0, + 'avg_duration': round(row['avg_duration'] or 0, 2), + 'avg_first_token_time': round(row['avg_first_token_time'] or 0, 2), + 'success_rate': round( + (row['success'] / row['total'] * 100) + if row['total'] > 0 + else 0, + 1, + ), + } + + return result + + async def delete_old_logs(self, days: int = 30) -> int: + """ + 删除旧日志 + + Args: + days: 保留天数 + + Returns: + 删除的记录数 + """ + cutoff_time = datetime.utcnow() - timedelta(days=days) + + async with self.get_connection() as conn: + cursor = await conn.execute( + "DELETE FROM request_logs WHERE timestamp < ?", + (_format_sqlite_datetime(cutoff_time),) + ) + await conn.commit() + return cursor.rowcount + + +# 全局单例实例 +_request_log_dao: Optional[RequestLogDAO] = None + + +def get_request_log_dao() -> RequestLogDAO: + """ + 获取请求日志 DAO 单例 + + Returns: + RequestLogDAO 实例 + """ + global _request_log_dao + if _request_log_dao is None: + _request_log_dao = RequestLogDAO() + return _request_log_dao + + +def init_request_log_dao(): + """初始化请求日志 DAO""" + global _request_log_dao + _request_log_dao = RequestLogDAO() + return _request_log_dao diff --git a/app/services/token_automation.py b/app/services/token_automation.py new file mode 100644 index 0000000000000000000000000000000000000000..abf07e8fa4f933465b506e167ab64fd1932bbe87 --- /dev/null +++ b/app/services/token_automation.py @@ -0,0 +1,278 @@ +"""Background automation for token import and maintenance.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Optional + +from app.core.config import settings +from app.services.token_dao import TokenDAO, get_token_dao +from app.services.token_importer import TokenImportSummary, import_tokens_from_directory +from app.utils.logger import logger +from app.utils.token_pool import TokenPool, get_token_pool + +DEFAULT_TOKEN_PROVIDER = "zai" +_AUTO_IMPORT_LOCK = asyncio.Lock() +_AUTO_MAINTENANCE_LOCK = asyncio.Lock() + + +@dataclass(frozen=True) +class TokenMaintenanceSummary: + provider: str + checked_count: int = 0 + duplicate_removed_count: int = 0 + valid_count: int = 0 + guest_count: int = 0 + invalid_count: int = 0 + deleted_invalid_count: int = 0 + + +async def run_directory_import( + source_dir: str, + *, + provider: str = DEFAULT_TOKEN_PROVIDER, + validate: bool = True, + dao: Optional[TokenDAO] = None, + pool: Optional[TokenPool] = None, +) -> TokenImportSummary: + """Import tokens from a configured directory and refresh the pool if needed.""" + if _AUTO_IMPORT_LOCK.locked(): + raise RuntimeError("目录导入任务正在执行,请稍后再试") + + async with _AUTO_IMPORT_LOCK: + summary = await import_tokens_from_directory( + source_dir, + provider=provider, + validate=validate, + dao=dao, + ) + + active_pool = pool if pool is not None else get_token_pool() + if active_pool and summary.imported_count > 0: + await active_pool.sync_from_database(provider) + logger.info("✅ 目录导入后已同步 Token 池") + + return summary + + +async def run_token_maintenance( + *, + provider: str = DEFAULT_TOKEN_PROVIDER, + remove_duplicates: bool = True, + run_health_check: bool = True, + delete_invalid_tokens: bool = False, + dao: Optional[TokenDAO] = None, + pool: Optional[TokenPool] = None, +) -> TokenMaintenanceSummary: + """Run dedupe, validation, and invalid-token cleanup as one maintenance cycle.""" + if _AUTO_MAINTENANCE_LOCK.locked(): + raise RuntimeError("Token 自动维护任务正在执行,请稍后再试") + + token_dao = dao or get_token_dao() + duplicate_removed_count = 0 + checked_count = 0 + valid_count = 0 + guest_count = 0 + invalid_count = 0 + deleted_invalid_count = 0 + + async with _AUTO_MAINTENANCE_LOCK: + if remove_duplicates: + duplicate_removed_count = await token_dao.remove_duplicate_tokens(provider) + + should_validate = run_health_check or delete_invalid_tokens + invalid_token_ids: list[int] = [] + + if should_validate: + validation_result = await token_dao.validate_tokens_detailed(provider) + checked_count = int(validation_result.get("checked", 0) or 0) + valid_count = int(validation_result.get("valid", 0) or 0) + guest_count = int(validation_result.get("guest", 0) or 0) + invalid_count = int(validation_result.get("invalid", 0) or 0) + invalid_token_ids = list( + validation_result.get("invalid_token_ids", []) or [] + ) + + if delete_invalid_tokens and invalid_token_ids: + deleted_invalid_count = await token_dao.delete_tokens_by_ids( + invalid_token_ids + ) + + active_pool = pool if pool is not None else get_token_pool() + if active_pool: + await active_pool.sync_from_database(provider) + logger.info("✅ Token 维护后已同步 Token 池") + + return TokenMaintenanceSummary( + provider=provider, + checked_count=checked_count, + duplicate_removed_count=duplicate_removed_count, + valid_count=valid_count, + guest_count=guest_count, + invalid_count=invalid_count, + deleted_invalid_count=deleted_invalid_count, + ) + + +class TokenAutomationScheduler: + """Run token import and maintenance loops in the application background.""" + + def __init__(self) -> None: + self._stop_event = asyncio.Event() + self._tasks: list[asyncio.Task] = [] + self._import_warning: Optional[str] = None + self._maintenance_warning: Optional[str] = None + + async def start(self) -> None: + if self._tasks: + return + + self._stop_event.clear() + self._tasks = [ + asyncio.create_task( + self._auto_import_loop(), + name="token-auto-import", + ), + asyncio.create_task( + self._auto_maintenance_loop(), + name="token-auto-maintenance", + ), + ] + logger.info("✅ Token 自动任务调度器已启动") + + async def stop(self) -> None: + if not self._tasks: + return + + self._stop_event.set() + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + self._tasks.clear() + self._import_warning = None + self._maintenance_warning = None + logger.info("🛑 Token 自动任务调度器已停止") + + async def _auto_import_loop(self) -> None: + while not self._stop_event.is_set(): + wait_seconds = 15 + try: + if settings.TOKEN_AUTO_IMPORT_ENABLED: + wait_seconds = max(int(settings.TOKEN_AUTO_IMPORT_INTERVAL), 30) + source_dir = settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip() + if not source_dir: + self._log_import_warning_once( + "已启用自动导入,但未配置导入目录" + ) + else: + self._import_warning = None + summary = await run_directory_import( + source_dir, + provider=DEFAULT_TOKEN_PROVIDER, + ) + logger.info( + "🔄 自动导入完成: scanned={} imported={} duplicate={} invalid={}", + summary.scanned_files, + summary.imported_count, + summary.duplicate_count, + summary.invalid_json_count + summary.invalid_token_count, + ) + except asyncio.CancelledError: + raise + except RuntimeError as exc: + logger.info(f"⏭️ 跳过本轮自动导入: {exc}") + except (FileNotFoundError, NotADirectoryError) as exc: + self._log_import_warning_once(str(exc)) + except Exception as exc: + logger.exception(f"❌ 自动导入 Token 失败: {exc}") + + await self._wait_or_stop(wait_seconds) + + async def _auto_maintenance_loop(self) -> None: + while not self._stop_event.is_set(): + wait_seconds = 15 + try: + if settings.TOKEN_AUTO_MAINTENANCE_ENABLED: + wait_seconds = max( + int(settings.TOKEN_AUTO_MAINTENANCE_INTERVAL), + 30, + ) + if not self._has_enabled_maintenance_action(): + self._log_maintenance_warning_once( + "已启用自动维护,但未选择任何维护动作" + ) + else: + self._maintenance_warning = None + summary = await run_token_maintenance( + provider=DEFAULT_TOKEN_PROVIDER, + remove_duplicates=settings.TOKEN_AUTO_REMOVE_DUPLICATES, + run_health_check=settings.TOKEN_AUTO_HEALTH_CHECK, + delete_invalid_tokens=settings.TOKEN_AUTO_DELETE_INVALID, + ) + logger.info( + "🧹 自动维护完成: dedupe={} checked={} valid={} guest={} invalid={} deleted={}", + summary.duplicate_removed_count, + summary.checked_count, + summary.valid_count, + summary.guest_count, + summary.invalid_count, + summary.deleted_invalid_count, + ) + except asyncio.CancelledError: + raise + except RuntimeError as exc: + logger.info(f"⏭️ 跳过本轮自动维护: {exc}") + except Exception as exc: + logger.exception(f"❌ Token 自动维护失败: {exc}") + + await self._wait_or_stop(wait_seconds) + + async def _wait_or_stop(self, timeout: int) -> None: + try: + await asyncio.wait_for(self._stop_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return + + def _has_enabled_maintenance_action(self) -> bool: + return any( + ( + settings.TOKEN_AUTO_REMOVE_DUPLICATES, + settings.TOKEN_AUTO_HEALTH_CHECK, + settings.TOKEN_AUTO_DELETE_INVALID, + ) + ) + + def _log_import_warning_once(self, message: str) -> None: + if self._import_warning == message: + return + self._import_warning = message + logger.warning(f"⚠️ {message}") + + def _log_maintenance_warning_once(self, message: str) -> None: + if self._maintenance_warning == message: + return + self._maintenance_warning = message + logger.warning(f"⚠️ {message}") + + +_scheduler: Optional[TokenAutomationScheduler] = None + + +def get_token_automation_scheduler() -> TokenAutomationScheduler: + global _scheduler + if _scheduler is None: + _scheduler = TokenAutomationScheduler() + return _scheduler + + +async def start_token_automation_scheduler() -> None: + await get_token_automation_scheduler().start() + + +async def stop_token_automation_scheduler() -> None: + global _scheduler + if _scheduler is None: + return + await _scheduler.stop() diff --git a/app/services/token_dao.py b/app/services/token_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f1058e9f11bad47df8ccb96ff14e004906e786 --- /dev/null +++ b/app/services/token_dao.py @@ -0,0 +1,664 @@ +""" +Token 数据访问层 (DAO) +提供 Token 的 CRUD 操作和查询功能 +""" +import os +import sqlite3 +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional, Tuple + +import aiosqlite + +from app.models.token_db import DB_PATH, SQL_CREATE_TABLES +from app.utils.logger import logger + + +class TokenDAO: + """Token 数据访问对象""" + + def __init__(self, db_path: str = DB_PATH): + """初始化 DAO""" + self.db_path = db_path + self._ensure_db_directory() + + def _ensure_db_directory(self): + """确保数据库目录存在""" + db_dir = os.path.dirname(self.db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + + @asynccontextmanager + async def get_connection(self): + """获取异步数据库连接""" + conn = await aiosqlite.connect(self.db_path) + conn.row_factory = aiosqlite.Row # 返回字典式结果 + + # 启用外键约束(SQLite 默认关闭) + await conn.execute("PRAGMA foreign_keys = ON") + + try: + yield conn + finally: + await conn.close() + + def get_sync_connection(self): + """获取同步数据库连接(用于初始化)""" + conn = sqlite3.connect(self.db_path) + # 启用外键约束 + conn.execute("PRAGMA foreign_keys = ON") + return conn + + async def init_database(self): + """初始化数据库表结构""" + try: + # 使用同步连接创建表(避免异步初始化问题) + conn = self.get_sync_connection() + conn.executescript(SQL_CREATE_TABLES) + conn.commit() + conn.close() + except Exception as e: + logger.error(f"❌ Token 数据库初始化失败: {e}") + raise + + # ==================== Token CRUD 操作 ==================== + + async def add_token( + self, + provider: str, + token: str, + token_type: str = "user", + priority: int = 0, + validate: bool = True + ) -> Optional[int]: + """ + 添加新 Token(可选验证) + + Args: + provider: 提供商名称 + token: Token 值 + token_type: Token 类型(如果 validate=True 将被验证结果覆盖) + priority: 优先级 + validate: 是否验证 Token(仅针对 zai 提供商) + + Returns: + token_id 或 None(验证失败或已存在) + """ + try: + # 对于 zai 提供商,强制验证 Token + if provider == "zai" and validate: + from app.utils.token_pool import ZAITokenValidator + + validated_type, is_valid, error_msg = await ZAITokenValidator.validate_token(token) + + # 拒绝 guest token + if validated_type == "guest": + logger.warning(f"🚫 拒绝添加匿名用户 Token: {token[:20]}... - {error_msg}") + return None + + # 拒绝无效 token + if not is_valid: + logger.warning(f"🚫 Token 验证失败: {token[:20]}... - {error_msg}") + return None + + # 使用验证后的类型 + token_type = validated_type + + async with self.get_connection() as conn: + cursor = await conn.execute(""" + INSERT OR IGNORE INTO tokens (provider, token, token_type, priority) + VALUES (?, ?, ?, ?) + """, (provider, token, token_type, priority)) + + await conn.commit() + + if cursor.lastrowid > 0: + # 同时创建统计记录 + await conn.execute(""" + INSERT INTO token_stats (token_id) + VALUES (?) + """, (cursor.lastrowid,)) + await conn.commit() + logger.info(f"✅ 添加 Token: {provider} ({token_type}) - {token[:20]}...") + return cursor.lastrowid + else: + logger.warning(f"⚠️ Token 已存在: {provider} - {token[:20]}...") + return None + except Exception as e: + logger.error(f"❌ 添加 Token 失败: {e}") + return None + + async def get_tokens_by_provider( + self, + provider: str, + enabled_only: bool = True, + limit: Optional[int] = None, + offset: int = 0, + ) -> List[Dict]: + """ + 获取指定提供商的所有 Token + + Args: + provider: 提供商名称 + enabled_only: 是否只返回启用的 Token + """ + try: + async with self.get_connection() as conn: + query = """ + SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests, + ts.last_success_time, ts.last_failure_time + FROM tokens t + LEFT JOIN token_stats ts ON t.id = ts.token_id + WHERE t.provider = ? + """ + params = [provider] + + if enabled_only: + query += " AND t.is_enabled = 1" + + query += " ORDER BY t.priority DESC, t.id ASC" + + if limit is not None: + query += " LIMIT ? OFFSET ?" + params.extend([limit, max(0, offset)]) + + cursor = await conn.execute(query, params) + rows = await cursor.fetchall() + + return [dict(row) for row in rows] + except Exception as e: + logger.error(f"❌ 查询 Token 失败: {e}") + return [] + + async def get_all_tokens(self, enabled_only: bool = False) -> List[Dict]: + """获取所有 Token""" + try: + async with self.get_connection() as conn: + query = """ + SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests, + ts.last_success_time, ts.last_failure_time + FROM tokens t + LEFT JOIN token_stats ts ON t.id = ts.token_id + """ + + if enabled_only: + query += " WHERE t.is_enabled = 1" + + query += " ORDER BY t.provider, t.priority DESC, t.id ASC" + + cursor = await conn.execute(query) + rows = await cursor.fetchall() + + return [dict(row) for row in rows] + except Exception as e: + logger.error(f"❌ 查询所有 Token 失败: {e}") + return [] + + async def update_token_status(self, token_id: int, is_enabled: bool): + """更新 Token 启用状态""" + try: + async with self.get_connection() as conn: + await conn.execute(""" + UPDATE tokens SET is_enabled = ? WHERE id = ? + """, (is_enabled, token_id)) + await conn.commit() + logger.info(f"✅ 更新 Token 状态: id={token_id}, enabled={is_enabled}") + except Exception as e: + logger.error(f"❌ 更新 Token 状态失败: {e}") + + async def update_token_type(self, token_id: int, token_type: str): + """更新 Token 类型""" + try: + async with self.get_connection() as conn: + await conn.execute(""" + UPDATE tokens SET token_type = ? WHERE id = ? + """, (token_type, token_id)) + await conn.commit() + logger.info(f"✅ 更新 Token 类型: id={token_id}, type={token_type}") + except Exception as e: + logger.error(f"❌ 更新 Token 类型失败: {e}") + + async def delete_token(self, token_id: int): + """删除 Token(级联删除统计数据)""" + try: + async with self.get_connection() as conn: + await conn.execute("DELETE FROM tokens WHERE id = ?", (token_id,)) + await conn.commit() + logger.info(f"✅ 删除 Token: id={token_id}") + except Exception as e: + logger.error(f"❌ 删除 Token 失败: {e}") + + async def delete_tokens_by_ids(self, token_ids: List[int]) -> int: + """批量删除 Token(级联删除统计数据)""" + if not token_ids: + return 0 + + try: + placeholders = ",".join("?" for _ in token_ids) + async with self.get_connection() as conn: + await conn.execute( + f"DELETE FROM tokens WHERE id IN ({placeholders})", + token_ids, + ) + cursor = await conn.execute("SELECT changes()") + row = await cursor.fetchone() + await conn.commit() + + deleted_count = int(row[0] if row else 0) + logger.info(f"✅ 批量删除 Token: {deleted_count} 个") + return deleted_count + except Exception as e: + logger.error(f"❌ 批量删除 Token 失败: {e}") + return 0 + + async def delete_tokens_by_provider(self, provider: str): + """删除指定提供商的所有 Token""" + try: + async with self.get_connection() as conn: + await conn.execute("DELETE FROM tokens WHERE provider = ?", (provider,)) + await conn.commit() + logger.info(f"✅ 删除提供商所有 Token: {provider}") + except Exception as e: + logger.error(f"❌ 删除提供商 Token 失败: {e}") + + # ==================== Token 统计操作 ==================== + + async def record_success(self, token_id: int): + """记录 Token 使用成功""" + try: + async with self.get_connection() as conn: + await conn.execute(""" + UPDATE token_stats + SET total_requests = total_requests + 1, + successful_requests = successful_requests + 1, + last_success_time = CURRENT_TIMESTAMP + WHERE token_id = ? + """, (token_id,)) + await conn.commit() + except Exception as e: + logger.error(f"❌ 记录成功失败: {e}") + + async def record_failure(self, token_id: int): + """记录 Token 使用失败""" + try: + async with self.get_connection() as conn: + await conn.execute(""" + UPDATE token_stats + SET total_requests = total_requests + 1, + failed_requests = failed_requests + 1, + last_failure_time = CURRENT_TIMESTAMP + WHERE token_id = ? + """, (token_id,)) + await conn.commit() + except Exception as e: + logger.error(f"❌ 记录失败失败: {e}") + + async def get_token_stats(self, token_id: int) -> Optional[Dict]: + """获取 Token 统计信息""" + try: + async with self.get_connection() as conn: + cursor = await conn.execute(""" + SELECT * FROM token_stats WHERE token_id = ? + """, (token_id,)) + row = await cursor.fetchone() + return dict(row) if row else None + except Exception as e: + logger.error(f"❌ 获取统计信息失败: {e}") + return None + + # ==================== 批量操作 ==================== + + async def bulk_add_tokens( + self, + provider: str, + tokens: List[str], + token_type: str = "user", + validate: bool = True + ) -> Tuple[int, int]: + """ + 批量添加 Token(可选验证) + + Args: + provider: 提供商名称 + tokens: Token 列表 + token_type: Token 类型(如果 validate=True 将被覆盖) + validate: 是否验证 Token(仅针对 zai) + + Returns: + (成功添加数量, 失败数量) + """ + added_count = 0 + failed_count = 0 + + for token in tokens: + if token.strip(): # 过滤空 token + token_id = await self.add_token( + provider, + token.strip(), + token_type, + validate=validate + ) + if token_id: + added_count += 1 + else: + failed_count += 1 + + logger.info(f"✅ 批量添加完成: {provider} - 成功 {added_count}/{len(tokens)},失败 {failed_count}") + return added_count, failed_count + + async def replace_tokens(self, provider: str, tokens: List[str], + token_type: str = "user"): + """ + 替换指定提供商的所有 Token(先删除后添加) + """ + # 删除旧 Token + await self.delete_tokens_by_provider(provider) + + # 添加新 Token + added_count = await self.bulk_add_tokens(provider, tokens, token_type) + + logger.info(f"✅ 替换 Token 完成: {provider} - {added_count} 个") + return added_count + + async def remove_duplicate_tokens(self, provider: Optional[str] = None) -> int: + """ + 删除重复 Token,保留每个 provider/token 组合中排序靠前的一条记录。 + + 正常情况下唯一约束会阻止重复数据,这里主要处理历史数据或手工导入异常。 + """ + try: + tokens = ( + await self.get_tokens_by_provider(provider, enabled_only=False) + if provider + else await self.get_all_tokens(enabled_only=False) + ) + + seen_keys: set[tuple[str, str]] = set() + duplicate_ids: list[int] = [] + + for token_record in tokens: + token_value = str(token_record.get("token") or "").strip() + token_provider = str(token_record.get("provider") or "") + key = (token_provider, token_value) + + if key in seen_keys: + duplicate_ids.append(int(token_record["id"])) + continue + + seen_keys.add(key) + + deleted_count = await self.delete_tokens_by_ids(duplicate_ids) + if deleted_count > 0: + logger.info(f"✅ 已清理重复 Token: {deleted_count} 个") + return deleted_count + except Exception as e: + logger.error(f"❌ 清理重复 Token 失败: {e}") + return 0 + + # ==================== 实用方法 ==================== + + async def get_token_by_value(self, provider: str, token: str) -> Optional[Dict]: + """根据 Token 值查询""" + try: + async with self.get_connection() as conn: + cursor = await conn.execute(""" + SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests + FROM tokens t + LEFT JOIN token_stats ts ON t.id = ts.token_id + WHERE t.provider = ? AND t.token = ? + """, (provider, token)) + row = await cursor.fetchone() + return dict(row) if row else None + except Exception as e: + logger.error(f"❌ 查询 Token 失败: {e}") + return None + + async def get_provider_stats(self, provider: str) -> Dict: + """获取提供商统计信息""" + try: + async with self.get_connection() as conn: + cursor = await conn.execute(""" + SELECT + COUNT(*) as total_tokens, + SUM(CASE WHEN is_enabled = 1 THEN 1 ELSE 0 END) as enabled_tokens, + SUM(ts.total_requests) as total_requests, + SUM(ts.successful_requests) as successful_requests, + SUM(ts.failed_requests) as failed_requests + FROM tokens t + LEFT JOIN token_stats ts ON t.id = ts.token_id + WHERE t.provider = ? + """, (provider,)) + row = await cursor.fetchone() + return dict(row) if row else {} + except Exception as e: + logger.error(f"❌ 获取提供商统计失败: {e}") + return {} + + async def get_provider_token_counts(self, provider: str) -> Dict[str, int]: + """聚合提供商的 Token 数量与类型分布。""" + try: + async with self.get_connection() as conn: + cursor = await conn.execute( + """ + SELECT + COUNT(*) as total_tokens, + SUM(CASE WHEN is_enabled = 1 THEN 1 ELSE 0 END) as enabled_tokens, + SUM(CASE WHEN token_type = 'user' THEN 1 ELSE 0 END) as user_tokens, + SUM(CASE WHEN token_type = 'guest' THEN 1 ELSE 0 END) as guest_tokens, + SUM(CASE WHEN token_type = 'unknown' THEN 1 ELSE 0 END) as unknown_tokens + FROM tokens + WHERE provider = ? + """, + (provider,), + ) + row = await cursor.fetchone() + + if not row: + return { + "total_tokens": 0, + "enabled_tokens": 0, + "user_tokens": 0, + "guest_tokens": 0, + "unknown_tokens": 0, + } + + return { + "total_tokens": int(row["total_tokens"] or 0), + "enabled_tokens": int(row["enabled_tokens"] or 0), + "user_tokens": int(row["user_tokens"] or 0), + "guest_tokens": int(row["guest_tokens"] or 0), + "unknown_tokens": int(row["unknown_tokens"] or 0), + } + except Exception as e: + logger.error(f"❌ 获取 Token 数量统计失败: {e}") + return { + "total_tokens": 0, + "enabled_tokens": 0, + "user_tokens": 0, + "guest_tokens": 0, + "unknown_tokens": 0, + } + + async def count_tokens_by_provider( + self, + provider: str, + enabled_only: bool = False, + ) -> int: + """统计提供商下的 Token 总数。""" + try: + async with self.get_connection() as conn: + query = "SELECT COUNT(*) AS total_count FROM tokens WHERE provider = ?" + params: List[object] = [provider] + if enabled_only: + query += " AND is_enabled = 1" + + cursor = await conn.execute(query, params) + row = await cursor.fetchone() + + return int(row["total_count"] or 0) if row else 0 + except Exception as e: + logger.error(f"❌ 统计 Token 总数失败: {e}") + return 0 + + # ==================== Token 验证操作 ==================== + + async def validate_and_update_token(self, token_id: int) -> bool: + """ + 验证单个 Token 并更新其类型 + + Args: + token_id: Token 数据库 ID + + Returns: + 是否为有效的认证用户 Token + """ + try: + # 获取 Token 信息 + async with self.get_connection() as conn: + cursor = await conn.execute(""" + SELECT provider, token FROM tokens WHERE id = ? + """, (token_id,)) + row = await cursor.fetchone() + + if not row: + logger.error(f"❌ Token ID {token_id} 不存在") + return False + + provider = row["provider"] + token = row["token"] + + if provider != "zai": + logger.info(f"⏭️ 跳过非 zai 提供商的 Token 验证: {provider}") + return True + + # 验证 Token + from app.utils.token_pool import ZAITokenValidator + + token_type, is_valid, error_msg = await ZAITokenValidator.validate_token(token) + + # 更新 Token 类型 + await self.update_token_type(token_id, token_type) + + if not is_valid: + logger.warning(f"⚠️ Token 验证失败: id={token_id}, type={token_type}, error={error_msg}") + + return is_valid + + except Exception as e: + logger.error(f"❌ 验证 Token 失败: {e}") + return False + + async def validate_tokens_detailed(self, provider: str = "zai") -> Dict[str, Any]: + """ + 批量验证所有 Token,并返回详细结果。 + + Returns: + { + "checked": 数量, + "valid": 数量, + "guest": 数量, + "invalid": 数量, + "invalid_token_ids": [id, ...], + } + """ + try: + tokens = await self.get_tokens_by_provider(provider, enabled_only=False) + + if not tokens: + logger.warning(f"⚠️ 没有需要验证的 {provider} Token") + return { + "checked": 0, + "valid": 0, + "guest": 0, + "invalid": 0, + "invalid_token_ids": [], + } + + logger.info(f"🔍 开始批量验证 {len(tokens)} 个 {provider} Token...") + + from app.utils.token_pool import ZAITokenValidator + + stats: Dict[str, Any] = { + "checked": len(tokens), + "valid": 0, + "guest": 0, + "invalid": 0, + "invalid_token_ids": [], + } + + for token_record in tokens: + token_id = int(token_record["id"]) + token = str(token_record["token"]) + + token_type, is_valid, error_msg = await ZAITokenValidator.validate_token( + token + ) + await self.update_token_type(token_id, token_type) + + if token_type == "user" and is_valid: + stats["valid"] += 1 + elif token_type == "guest": + stats["guest"] += 1 + stats["invalid_token_ids"].append(token_id) + else: + stats["invalid"] += 1 + stats["invalid_token_ids"].append(token_id) + if error_msg: + logger.warning( + "⚠️ Token 验证失败: id={}, type={}, error={}", + token_id, + token_type, + error_msg, + ) + + logger.info( + "✅ 批量验证完成: 有效 {}, 匿名 {}, 无效 {}", + stats["valid"], + stats["guest"], + stats["invalid"], + ) + return stats + + except Exception as e: + logger.error(f"❌ 批量验证失败: {e}") + return { + "checked": 0, + "valid": 0, + "guest": 0, + "invalid": 0, + "invalid_token_ids": [], + } + + async def validate_all_tokens(self, provider: str = "zai") -> Dict[str, int]: + """ + 批量验证所有 Token + + Args: + provider: 提供商名称(默认 zai) + + Returns: + 统计结果 {"valid": 数量, "guest": 数量, "invalid": 数量} + """ + stats = await self.validate_tokens_detailed(provider) + return { + "valid": int(stats.get("valid", 0) or 0), + "guest": int(stats.get("guest", 0) or 0), + "invalid": int(stats.get("invalid", 0) or 0), + } + + +# 全局单例 +_token_dao: Optional[TokenDAO] = None + + +def get_token_dao() -> TokenDAO: + """获取全局 TokenDAO 实例""" + global _token_dao + if _token_dao is None: + _token_dao = TokenDAO() + return _token_dao + + +async def init_token_database(): + """初始化 Token 数据库""" + dao = get_token_dao() + await dao.init_database() diff --git a/app/services/token_importer.py b/app/services/token_importer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff83ecaed37196fc03bdbe33d41c46b770fab0c --- /dev/null +++ b/app/services/token_importer.py @@ -0,0 +1,138 @@ +"""本地目录 token 导入服务。""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from app.services.token_dao import TokenDAO, get_token_dao +from app.utils.logger import logger + + +@dataclass(frozen=True) +class TokenImportSummary: + source_dir: str + scanned_files: int + imported_count: int + duplicate_count: int + invalid_json_count: int + missing_token_count: int + invalid_token_count: int + + @property + def failed_count(self) -> int: + return ( + self.duplicate_count + + self.invalid_json_count + + self.missing_token_count + + self.invalid_token_count + ) + + +def _load_token_payload(file_path: Path) -> dict: + try: + return json.loads(file_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"JSON 解析失败: {exc}") from exc + + +async def import_tokens_from_directory( + source_dir: str | Path, + *, + provider: str = "zai", + validate: bool = True, + dao: Optional[TokenDAO] = None, +) -> TokenImportSummary: + """ + 从本地目录导入 token。 + + 目录中的每个 JSON 文件应至少包含 `token` 字段。 + """ + source_path = Path(source_dir).expanduser().resolve() + if not source_path.exists(): + raise FileNotFoundError(f"导入目录不存在: {source_path}") + if not source_path.is_dir(): + raise NotADirectoryError(f"导入路径不是目录: {source_path}") + + token_dao = dao or get_token_dao() + token_files = sorted(source_path.rglob("*.json")) + seen_tokens: set[str] = set() + imported_count = 0 + duplicate_count = 0 + invalid_json_count = 0 + missing_token_count = 0 + invalid_token_count = 0 + + for file_path in token_files: + try: + payload = _load_token_payload(file_path) + except ValueError as exc: + invalid_json_count += 1 + logger.warning(f"⚠️ 跳过无效 JSON 文件: {file_path} - {exc}") + continue + + if not isinstance(payload, dict): + invalid_json_count += 1 + logger.warning(f"⚠️ 跳过非对象 JSON 文件: {file_path}") + continue + + token = str(payload.get("token") or "").strip() + email = str(payload.get("email") or "").strip() + if not token: + missing_token_count += 1 + logger.warning(f"⚠️ 文件缺少 token 字段: {file_path}") + continue + + if token in seen_tokens: + duplicate_count += 1 + logger.info(f"↩️ 跳过本批次重复 Token: {file_path.name}") + continue + seen_tokens.add(token) + + existing = await token_dao.get_token_by_value(provider, token) + if existing is not None: + duplicate_count += 1 + logger.info( + "↩️ Token 已存在,跳过导入: {} ({})", + file_path.name, + email or "unknown", + ) + continue + + token_id = await token_dao.add_token( + provider=provider, + token=token, + token_type="user", + validate=validate, + ) + if token_id is None: + invalid_token_count += 1 + logger.warning(f"⚠️ Token 导入失败: {file_path.name} ({email or 'unknown'})") + continue + + imported_count += 1 + logger.info(f"✅ 已导入 Token: {file_path.name} ({email or 'unknown'})") + + summary = TokenImportSummary( + source_dir=str(source_path), + scanned_files=len(token_files), + imported_count=imported_count, + duplicate_count=duplicate_count, + invalid_json_count=invalid_json_count, + missing_token_count=missing_token_count, + invalid_token_count=invalid_token_count, + ) + logger.info( + "✅ Token 目录导入完成: " + "scanned={}, imported={}, duplicate={}, invalid_json={}, " + "missing_token={}, invalid_token={}", + summary.scanned_files, + summary.imported_count, + summary.duplicate_count, + summary.invalid_json_count, + summary.missing_token_count, + summary.invalid_token_count, + ) + return summary diff --git a/app/templates/base.html b/app/templates/base.html new file mode 100644 index 0000000000000000000000000000000000000000..e7871f1a7f47b467650589c989c806617274fe42 --- /dev/null +++ b/app/templates/base.html @@ -0,0 +1,201 @@ + + + + + + {% block title %}管理后台{% endblock %} - API 控制台 + + + + + + + + + + + + + + + + + {% block extra_head %}{% endblock %} + + +
+ + + +
+ + + + +
+ +
+ + +
+ {% block content %}{% endblock %} +
+
+
+
+ + {% block extra_scripts %}{% endblock %} + + diff --git a/app/templates/components/recent_logs.html b/app/templates/components/recent_logs.html new file mode 100644 index 0000000000000000000000000000000000000000..60929a0c1c187dbad29da0fadce0c13c2987dede --- /dev/null +++ b/app/templates/components/recent_logs.html @@ -0,0 +1,128 @@ + + + +
+
+ {% if logs %} + + + + + + + + + + + + + + {% for log in logs %} + + + + + + + + + + {% endfor %} + +
时间请求标记输入 / 输出缓存创建 / 命中用时 / 首字状态
+ {{ log.timestamp }} + +
+ + {{ log.model }} + + + {{ log.endpoint }} + +
+ {% if log.error_message %} +

+ {{ log.error_message }} +

+ {% endif %} +
+
+ + {{ log.client_name }} + + + {{ log.protocol_display }} + + {% if log.source_display %} + + {{ log.source_display }} + + {% endif %} + {% if log.provider_display %} + + {{ log.provider_display }} + + {% endif %} +
+
+ 输入 {{ log.input_tokens }} + / + 输出 {{ log.output_tokens }} + + 创建 {{ log.cache_creation_tokens }} + / + 命中 {{ log.cache_read_tokens }} + + 用时 {{ log.duration_display }} + / + 首字 {{ log.first_token_display }} + + + {{ "成功" if log.success else "失败" }} + + + HTTP {{ log.status_code }} + +
+ {% else %} +
+ + + +

暂无请求日志

+
+ {% endif %} +
+ +
+
+ {% if page.total_items > 0 %} + 显示第 {{ page.start_item }} - {{ page.end_item }} 条,共 {{ page.total_items }} 条 + {% else %} + 暂无日志数据 + {% endif %} +
+
+ + + 第 {{ page.current_page }} / {{ page.total_pages }} 页 + + +
+
+
diff --git a/app/templates/components/token_list.html b/app/templates/components/token_list.html new file mode 100644 index 0000000000000000000000000000000000000000..50cd3e1377472a54cb8ce2ff4c12e88619223cda --- /dev/null +++ b/app/templates/components/token_list.html @@ -0,0 +1,114 @@ + + + +
+
+ {% if tokens %} + + + + + + + + + + + + + + + {% for token in tokens %} + {% include "components/token_row.html" %} + {% endfor %} + +
IDToken类型健康度状态使用统计创建时间操作
+ {% else %} +
+ + + +

暂无 Token

+

点击右上角"添加 Token"按钮开始添加

+
+ {% endif %} +
+ +
+
+ {% if page.total_items > 0 %} + 显示第 {{ page.start_item }} - {{ page.end_item }} 条,共 {{ page.total_items }} 个 Token + {% else %} + 暂无 Token 数据 + {% endif %} +
+
+ + + 第 {{ page.current_page }} / {{ page.total_pages }} 页 + + +
+
+
+ + + + + + diff --git a/app/templates/components/token_pool.html b/app/templates/components/token_pool.html new file mode 100644 index 0000000000000000000000000000000000000000..8f37941675962f88565d318af215271ef857b712 --- /dev/null +++ b/app/templates/components/token_pool.html @@ -0,0 +1,40 @@ + +
+ {% for token in tokens %} +
+
+ Token #{{ token.index }} + + {{ token.status }} + +
+
+
+ {{ token.key }} +
+
类型: + {% if token.token_type == 'user' %} + 认证用户 + {% elif token.token_type == 'guest' %} + 匿名用户 + {% else %} + 未知 + {% endif %} +
+
成功率: {{ token.success_rate }}
+
失败次数: {{ token.failure_count }}
+
最后使用: {{ token.last_used }}
+
+
+ {% endfor %} + + {% if not tokens %} +
+ + + +

暂无 Token 配置

+

请在配置管理页面添加 Token

+
+ {% endif %} +
diff --git a/app/templates/components/token_row.html b/app/templates/components/token_row.html new file mode 100644 index 0000000000000000000000000000000000000000..32df001d19c3bb1606ad27529729d0d5dc9bc2b5 --- /dev/null +++ b/app/templates/components/token_row.html @@ -0,0 +1,154 @@ + +{% set success_rate = (token.successful_requests / token.total_requests * 100) if token.total_requests else 0 %} +{% set is_healthy = (token.token_type == 'user' and token.is_enabled and (success_rate >= 50 or token.total_requests <= 3)) %} + + + {{ token.id }} + + +
+ + {{ token.token[:30] }}... + + +
+ + + {% if token.token_type == 'user' %} + + + + + 认证用户 + + {% elif token.token_type == 'guest' %} + + + + + 匿名用户 + + {% else %} + + + + + 未知 + + {% endif %} + + + +
+ {% if is_healthy %} +
+ + + + 健康 +
+ {% elif token.token_type == 'guest' %} +
+ + + + 匿名 +
+ {% elif not token.is_enabled %} +
+ + + + 已禁用 +
+ {% else %} +
+ + + + 不健康 +
+ {% endif %} +
+ + + + + + {% if token.total_requests %} +
+
+ 成功: + {{ token.successful_requests }} +
+
+ 失败: + {{ token.failed_requests }} +
+
+ 成功率: + + {{ "%.1f"|format(success_rate) }}% + +
+ +
+
+
+
+ {% else %} + 未使用 + {% endif %} + + +
+ {{ token.created_at[:10] if token.created_at else 'N/A' }} + {{ token.created_at[11:19] if token.created_at else '' }} +
+ + +
+ + + + +
+ + diff --git a/app/templates/components/token_stats.html b/app/templates/components/token_stats.html new file mode 100644 index 0000000000000000000000000000000000000000..1d1bc4e05ad86079d554b8667dcc890ebcec2ed0 --- /dev/null +++ b/app/templates/components/token_stats.html @@ -0,0 +1,125 @@ + +
+ +
+
+
+
+ + + +
+
+
+
Token 总数
+
{{ stats.total_tokens }}
+
+
+
+
+
+ + +
+
+
+
+ + + +
+
+
+
已启用
+
+
{{ stats.enabled_tokens }}
+ {% if stats.total_tokens > 0 %} +
+ {{ "%.0f"|format(stats.enabled_tokens / stats.total_tokens * 100) }}% +
+ {% endif %} +
+
+
+
+
+
+ + +
+
+
+
+ + + +
+
+
+
认证用户
+
+
{{ stats.user_tokens }}
+ {% if stats.guest_tokens > 0 %} +
+ + + + {{ stats.guest_tokens }} 个匿名 +
+ {% endif %} +
+
+
+
+
+
+ + +
+
+
+
+ {% if stats.total_requests > 0 %} + {% set success_rate = (stats.successful_requests / stats.total_requests * 100) %} + {% if success_rate >= 80 %} + + + + {% elif success_rate >= 50 %} + + + + {% else %} + + + + {% endif %} + {% else %} + + + + {% endif %} +
+
+
+
总成功率
+
+ {% if stats.total_requests > 0 %} + {% set success_rate = (stats.successful_requests / stats.total_requests * 100) %} +
+ {{ "%.1f"|format(success_rate) }}% +
+
+ {{ stats.successful_requests }} / {{ stats.total_requests }} 请求 +
+ {% else %} +
N/A
+
暂无请求
+ {% endif %} +
+
+
+
+
+
+
diff --git a/app/templates/config.html b/app/templates/config.html new file mode 100644 index 0000000000000000000000000000000000000000..6c39091e3b388c0066c68be6a24a65844ff49d03 --- /dev/null +++ b/app/templates/config.html @@ -0,0 +1,344 @@ +{% extends "base.html" %} + +{% block title %}配置管理{% endblock %} + +{% block extra_head %} + +{% endblock %} + +{% macro section_link(section) -%} + +{%- endmacro %} + +{% macro render_field(field) -%} +
+ {% if field.value_type == 'bool' %} + + {% else %} +
+
+ + {% if field.sensitive %} + + {% endif %} +
+

{{ field.description }}

+ +
+ {% endif %} + +
+ + {{ field.source_label }} + + {% if field.restart_required %} + + 需重启 + + {% endif %} + {% if field.sensitive %} + + 敏感字段 + + {% endif %} +
+
+{%- endmacro %} + +{% block content %} +
+
+
+
+
+

Admin Config Center

+

集中管理运行参数,并支持直接编辑 `.env` 源文件

+

+ 结构化表单适合日常操作,源文件模式适合批量调整、复制完整配置或保留注释。两种模式都会在保存后立即热重载。 +

+
+ +
+ + + +
+
+ +
+
+

受管字段

+

{{ overview.total_fields }}

+

{{ overview.total_sections }} 个分组

+
+
+

.env 覆写

+

{{ overview.overridden_fields }}

+

{{ overview.default_fields }} 个字段仍在使用默认值

+
+
+

敏感字段

+

{{ overview.sensitive_fields }}

+

{{ overview.restart_required_fields }} 个字段修改后建议重启

+
+
+

源文件状态

+

{{ '.env 已存在' if overview.env_exists else '.env 尚未创建' }}

+

{{ overview.env_line_count }} 行,{{ '.env.example 可用' if overview.example_exists else '缺少 .env.example' }}

+
+
+
+
+ +
+ + {% if not overview.env_exists %} +
+ 当前工作目录中尚未找到 `.env` 文件。你可以直接保存表单或源文件,系统会自动创建它。 +
+ {% endif %} + +
+ + +
+
+ {% for section in sections %} +
+
+
+
+

{{ section.title }}

+

{{ section.description }}

+
+ + {{ section.field_count }} 个字段 + +
+
+
+ {% for field in section.fields %} + {{ render_field(field) }} + {% endfor %} +
+
+ {% endfor %} + +
+
+

+ 表单模式会保留未知配置,只更新当前页面管理的字段。 +

+
+ + +
+
+
+
+ +
+
+
+
+
+

`.env` 源文件编辑器

+

+ 直接编辑源文件内容。该模式会原样覆盖整个 `.env`,适合保留注释、批量调整和整体替换。 +

+
+ + {{ overview.env_path }} + +
+
+
+
+ 保存前会进行基础语法检查,确保每一行是合法的 `KEY=VALUE` 结构;如果热重载失败,系统会自动回滚到原来的 `.env`。 +
+ +
+
+ +
+
+

+ 源文件模式会直接覆盖整个 `.env`。保存后会自动刷新页面并保留当前视图。 +

+
+ + +
+
+
+
+
+
+
+{% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/app/templates/index.html b/app/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..0292cadc62063d28d24e34246e550fdfb54f0f37 --- /dev/null +++ b/app/templates/index.html @@ -0,0 +1,588 @@ +{% extends "base.html" %} + +{% block title %}仪表盘{% endblock %} + +{% block content %} +
+
+
+
+
+

Usage Dashboard

+

查看请求消耗、缓存效果、延迟表现和使用趋势

+

+ 统计来源于请求日志数据库,覆盖输入输出 Token、缓存创建与命中、成功率、平均延迟,以及最近 24 小时、7 天、30 天的使用趋势。 +

+
+
+

Last Update

+

{{ current_time }}

+

运行时间 {{ stats.uptime }}

+
+
+
+
+ +
+
+
+
+

总请求数

+

{{ stats.total_requests }}

+
+ + + + + +
+

成功 {{ stats.successful_requests }} / 失败 {{ stats.failed_requests }}

+
+ +
+
+
+

总消耗 Token 数

+

{{ stats.total_consumed_tokens_display }}

+
+ + + + + +
+

累计 {{ stats.total_consumed_tokens }} Tokens

+
+ +
+
+
+

缓存 Token

+

{{ stats.total_cache_tokens_display }}

+
+ + + + + +
+

创建 {{ stats.cache_creation_tokens }} / 命中 {{ stats.cache_read_tokens }}

+
+ +
+
+
+

成功率

+

{{ stats.success_rate }}%

+
+ + + + + +
+

图表支持切换 24 小时 / 7 天 / 30 天

+
+ +
+
+
+

输入 Token

+

{{ stats.input_tokens_display }}

+
+ + + + + +
+

累计 {{ stats.input_tokens }} Tokens

+
+ +
+
+
+

输出 Token

+

{{ stats.output_tokens_display }}

+
+ + + + + +
+

累计 {{ stats.output_tokens }} Tokens

+
+ +
+
+
+

平均延迟

+

{{ "%.2f"|format(stats.average_latency) }}s

+
+ + + + + +
+

平均首字延迟 {{ "%.2f"|format(stats.average_first_token_latency) }}s

+
+ +
+
+
+

Token 池健康度

+

{{ stats.healthy_tokens }}/{{ stats.pool_total_tokens }}

+
+ + + + + +
+

可用 {{ stats.available_tokens }} / 已启用 {{ stats.enabled_tokens }} / 认证 {{ stats.user_tokens }}

+
+
+ +
+
+
+
+

使用趋势图

+

+ 最近 7 天按天聚合的请求量、输入输出与缓存变化。 +

+
+
+ {% set trend_window_options = trend_windows if trend_windows is defined and trend_windows else [ + {'key': '24h', 'label': '24 小时'}, + {'key': '7d', 'label': '7 天'}, + {'key': '30d', 'label': '30 天'} + ] %} + {% for option in trend_window_options %} + + {% endfor %} +
+
+
+ 蓝柱: 请求量 + 紫线: 输入 + 红线: 输出 + 绿线: 缓存创建 + 黄线: 缓存命中 +
+
+ +
+ +
+ +
+
+

缓存创建 / 命中

+

按请求次数和 Token 数量查看缓存是否真的生效。

+
+
+

缓存创建

+

{{ stats.cache_creation_requests }}

+

共创建 {{ stats.cache_creation_tokens }} Tokens

+
+
+

缓存命中

+

{{ stats.cache_hit_requests }}

+

共命中 {{ stats.cache_read_tokens }} Tokens

+
+
+
+ +
+

输入 / 输出画像

+

对比 Prompt 与 Completion 的消耗分布。

+ {% set usage_total = stats.input_tokens + stats.output_tokens %} + {% set input_ratio = (stats.input_tokens / usage_total * 100) if usage_total > 0 else 0 %} + {% set output_ratio = (stats.output_tokens / usage_total * 100) if usage_total > 0 else 0 %} +
+
+
+ 输入 Token + {{ "%.1f"|format(input_ratio) }}% +
+
+
+
+

{{ stats.input_tokens }} Tokens

+
+
+
+ 输出 Token + {{ "%.1f"|format(output_ratio) }}% +
+
+
+
+

{{ stats.output_tokens }} Tokens

+
+
+
+
+
+ +
+
+

最近请求日志

+
+ +
+
+
+
+
+ + + + + 加载中... +
+
+
+
+
+{% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/app/templates/login.html b/app/templates/login.html new file mode 100644 index 0000000000000000000000000000000000000000..f0defb7746297faf1d4f6dddb0a694da9b494260 --- /dev/null +++ b/app/templates/login.html @@ -0,0 +1,143 @@ + + + + + + 登录 - API 控制台 + + + + + + + + + + +
+
+ +
+
+ + + +
+

+ API 管理后台 +

+

+ 请输入管理密码以继续 +

+
+ + +
+ + +
+
+
+ + + +
+
+

+
+
+
+ + +
+
+
+ + +
+
+ +
+ +
+
+ + +
+

+ 默认密码:admin123(请在 .env 中修改 ADMIN_PASSWORD) +

+
+
+
+
+ + diff --git a/app/templates/logs.html b/app/templates/logs.html new file mode 100644 index 0000000000000000000000000000000000000000..2316039dea1b6127143fc23a335a194a3c1509ef --- /dev/null +++ b/app/templates/logs.html @@ -0,0 +1,59 @@ +{% extends "base.html" %} + +{% block title %}实时日志{% endblock %} + +{% block content %} +
+
+
+

实时日志

+

滚动查看服务当前输出的最新日志

+
+
+ +
+
+ +
+
+

日志流

+
+ +
+
+
+
+
+ 加载中... +
+
+
+
+
+{% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/app/templates/tokens.html b/app/templates/tokens.html new file mode 100644 index 0000000000000000000000000000000000000000..b570e990a92d12d1b9c09ec09c709e26625ae2c4 --- /dev/null +++ b/app/templates/tokens.html @@ -0,0 +1,487 @@ +{% extends "base.html" %} + +{% block title %}Token 管理{% endblock %} + +{% block content %} +
+ +
+
+

Token 管理

+

管理和维护当前服务使用的 Token

+
+
+ + + + + + 打开配置中心 + + + +
+
+ + +
+ +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ + +
+
+
+
+

目录导入策略

+

+ 配置入口已迁移到配置管理页。这里仅展示当前策略,并允许立即执行一次导入。 +

+
+ + {{ '定时已开启' if automation.import_enabled else '定时已关闭' }} + +
+ + {% if automation.has_import_source_dir %} +
+ 手动导入会复用当前配置的目录和验证逻辑,重复 Token 会自动跳过。 +
+ {% else %} +
+ 还没有配置导入目录,无法执行手动导入。请先到配置管理页设置 `TOKEN_AUTO_IMPORT_SOURCE_DIR`。 +
+ {% endif %} + +
+
+
Token 目录
+
+ {{ automation.import_source_dir or '未配置' }} +
+
+
+
+
扫描间隔
+
{{ automation.import_interval }} 秒
+
+
+
配置位置
+
配置管理 / Token 池策略
+
+
+
+ +
+ + 去配置中心修改 + + +
+
+ +
+
+
+

自动维护策略

+

+ 维护动作和定时间隔统一在配置管理页设置。这里仅执行当前已配置的维护策略。 +

+
+ + {{ '定时已开启' if automation.maintenance_enabled else '定时已关闭' }} + +
+ + {% if automation.has_maintenance_actions %} +
+ 手动维护会按当前配置顺序执行去重、测活和失效清理,不再在本页单独维护另一套选项。 +
+ {% else %} +
+ 当前没有配置任何维护动作。请先到配置管理页勾选至少一个维护动作。 +
+ {% endif %} + +
+
+
+
维护间隔
+
{{ automation.maintenance_interval }} 秒
+
+
+
配置位置
+
配置管理 / Token 池策略
+
+
+
+
当前维护动作
+
+ {% if automation.maintenance_actions %} + {% for action in automation.maintenance_actions %} + {{ action }} + {% endfor %} + {% else %} + 未配置 + {% endif %} +
+
+
+ +
+ + 去配置中心修改 + + +
+
+
+ + +
+
+

+ Token 列表 + +

+
+ + +
+
+
+ +
+ + + + +
+
+
+ + + + + + +
+{% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4a7eb51af06f2d96bdd7bcce3c24fdd9b2d469 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from app.utils import reload_config, logger + +__all__ = ["reload_config", "logger"] diff --git a/app/utils/env_file.py b/app/utils/env_file.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fcc0cd9b9be69be9610afa6a4232922f0dceeb --- /dev/null +++ b/app/utils/env_file.py @@ -0,0 +1,59 @@ +"""Helpers for updating .env files without dropping unrelated settings.""" + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Mapping + +_ENV_KEY_PATTERN = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*=") + + +def _serialize_env_value(value: object) -> str: + if isinstance(value, bool): + return "true" if value else "false" + + text = "" if value is None else str(value) + if not text: + return "" + + if any(char.isspace() for char in text) or any( + char in text for char in ["#", '"', "\\", "'"] + ): + if "'" not in text: + return f"'{text}'" + + escaped = text.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + + return text + + +def update_env_file( + updates: Mapping[str, object], + env_path: str | Path = ".env", +) -> None: + """Update selected keys inside a .env file while preserving other lines.""" + path = Path(env_path) + lines = path.read_text(encoding="utf-8").splitlines() if path.exists() else [] + remaining_updates = {key: _serialize_env_value(value) for key, value in updates.items()} + + for index, line in enumerate(lines): + match = _ENV_KEY_PATTERN.match(line) + if not match: + continue + + key = match.group(1) + if key not in remaining_updates: + continue + + lines[index] = f"{key}={remaining_updates.pop(key)}" + + if remaining_updates: + if lines and lines[-1].strip(): + lines.append("") + for key, value in remaining_updates.items(): + lines.append(f"{key}={value}") + + content = "\n".join(lines).rstrip() + path.write_text(f"{content}\n" if content else "", encoding="utf-8") diff --git a/app/utils/fe_version.py b/app/utils/fe_version.py new file mode 100644 index 0000000000000000000000000000000000000000..de28498b353b533457efd94259525dde47a76cc3 --- /dev/null +++ b/app/utils/fe_version.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Utility helpers for resolving the latest X-FE-Version value from chat.z.ai. + +The upstream service embeds the current front-end release identifier inside +its landing page static asset URLs (e.g. `prod-fe-1.0.107`). The helpers in +this module fetch the landing page, extract the version string, and cache it +with a configurable TTL so the expensive network fetch only happens when +necessary. +""" + +from __future__ import annotations + +import re +import time +from typing import Optional + +import httpx + +from app.utils.logger import get_logger +from app.utils.user_agent import get_random_user_agent + +# Base URL to probe for the version string. +FE_VERSION_SOURCE_URL = "https://chat.z.ai" + +# Cache TTL in seconds (default: 30 minutes). +CACHE_TTL_SECONDS = 1800 + +_logger = get_logger() +_version_pattern = re.compile(r"prod-fe-\d+\.\d+\.\d+") + +_cached_version: str = "" +_cached_at: float = 0.0 + + +def _extract_version(page_content: str) -> Optional[str]: + """Extract the version string from the page content.""" + if not page_content: + return None + + matches = _version_pattern.findall(page_content) + if not matches: + return None + + # Choose the highest lexical value to guard against mixed versions. + return max(matches) + + + + +def _should_use_cache(force_refresh: bool) -> bool: + """Determine whether the cached value can be reused.""" + if force_refresh: + return False + if not _cached_version: + return False + if _cached_at <= 0: + return False + return (time.time() - _cached_at) < CACHE_TTL_SECONDS + + +def get_latest_fe_version(force_refresh: bool = False) -> str: + """ + Resolve the latest X-FE-Version value from chat.z.ai. + + The lookup order is: + 1. Cached value within TTL. + 2. Remote fetch from chat.z.ai. + + Raises: + Exception: If unable to fetch the version from the remote source. + """ + global _cached_version, _cached_at + + if _should_use_cache(force_refresh): + return _cached_version + + try: + headers = {"User-Agent": get_random_user_agent("chrome")} + except Exception: + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" + ) + } + + try: + with httpx.Client(timeout=10.0, follow_redirects=True) as client: + response = client.get(FE_VERSION_SOURCE_URL, headers=headers) + response.raise_for_status() + version = _extract_version(response.text) + if version: + if version != _cached_version: + _logger.info(f"[Z.AI] Detected X-FE-Version update: {version}") + _cached_version = version + _cached_at = time.time() + return version + + _logger.error("[Z.AI] Unable to locate X-FE-Version in landing page") + raise Exception("Unable to locate X-FE-Version in landing page") + except Exception as exc: + _logger.error(f"[Z.AI] Failed to fetch X-FE-Version from {FE_VERSION_SOURCE_URL}: {exc}") + raise Exception(f"Failed to fetch X-FE-Version: {exc}") + + +def refresh_fe_version() -> str: + """Force refresh the cached version by bypassing the TTL.""" + return get_latest_fe_version(force_refresh=True) diff --git a/app/utils/guest_session_pool.py b/app/utils/guest_session_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc82d5d852149bbbf4e8a2b345b56c1e7c9f149 --- /dev/null +++ b/app/utils/guest_session_pool.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""匿名访客会话池。""" + +import asyncio +import random +import time +from dataclasses import dataclass, field +from threading import Lock +from typing import Dict, List, Optional, Set + +import httpx + +from app.core.config import settings +from app.utils.fe_version import get_latest_fe_version +from app.utils.logger import logger +from app.utils.user_agent import get_random_user_agent + +AUTH_URL = "https://chat.z.ai/api/v1/auths/" +CHATS_URL = "https://chat.z.ai/api/v1/chats/" +AUTH_HTTP_MAX_KEEPALIVE_CONNECTIONS = 20 +AUTH_HTTP_MAX_CONNECTIONS = 50 +GUEST_SESSION_TTL_SECONDS = 480 +GUEST_SESSION_TTL_JITTER_SECONDS = 60 +GUEST_SESSION_MIN_TTL_SECONDS = 180 +GUEST_POOL_MAINTENANCE_INTERVAL_SECONDS = 30 +GUEST_CLEANUP_PARALLELISM = 4 +CAPACITY_FILL_ATTEMPT_MULTIPLIER = 3 +CAPACITY_FILL_MIN_ATTEMPTS = 3 +MAX_DUPLICATE_LOG_USER_IDS = 3 + + +def _get_proxy_config() -> Optional[str]: + """获取代理配置。""" + if settings.HTTPS_PROXY: + return settings.HTTPS_PROXY + if settings.HTTP_PROXY: + return settings.HTTP_PROXY + if settings.SOCKS5_PROXY: + return settings.SOCKS5_PROXY + return None + + +def _build_timeout(read_timeout: float = 30.0) -> httpx.Timeout: + """构建访客会话相关请求超时。""" + return httpx.Timeout( + connect=5.0, + read=read_timeout, + write=10.0, + pool=5.0, + ) + + +def _build_limits() -> httpx.Limits: + """构建访客会话相关连接池限制。""" + return httpx.Limits( + max_keepalive_connections=AUTH_HTTP_MAX_KEEPALIVE_CONNECTIONS, + max_connections=AUTH_HTTP_MAX_CONNECTIONS, + ) + + +def _build_async_client(read_timeout: float = 30.0) -> httpx.AsyncClient: + """构建访客会话相关 HTTP 客户端。""" + return httpx.AsyncClient( + timeout=_build_timeout(read_timeout), + follow_redirects=True, + limits=_build_limits(), + proxy=_get_proxy_config(), + ) + + +def _build_dynamic_headers(chat_id: str = "") -> Dict[str, str]: + """生成匿名访客鉴权所需浏览器请求头。""" + browser_choices = [ + "chrome", + "chrome", + "chrome", + "edge", + "edge", + "firefox", + "safari", + ] + browser_type = random.choice(browser_choices) + user_agent = get_random_user_agent(browser_type) + fe_version = get_latest_fe_version() + + chrome_version = "139" + edge_version = "139" + + if "Chrome/" in user_agent: + try: + chrome_version = user_agent.split("Chrome/")[1].split(".")[0] + except Exception: + pass + + if "Edg/" in user_agent: + try: + edge_version = user_agent.split("Edg/")[1].split(".")[0] + sec_ch_ua = ( + f'"Microsoft Edge";v="{edge_version}", ' + f'"Chromium";v="{chrome_version}", "Not_A Brand";v="24"' + ) + except Exception: + sec_ch_ua = ( + f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", ' + f'"Google Chrome";v="{chrome_version}"' + ) + elif "Firefox/" in user_agent: + sec_ch_ua = None + else: + sec_ch_ua = ( + f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", ' + f'"Google Chrome";v="{chrome_version}"' + ) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Connection": "keep-alive", + "Cache-Control": "no-cache", + "User-Agent": user_agent, + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "X-FE-Version": fe_version, + "Origin": "https://chat.z.ai", + } + + if sec_ch_ua: + headers["sec-ch-ua"] = sec_ch_ua + headers["sec-ch-ua-mobile"] = "?0" + headers["sec-ch-ua-platform"] = '"Windows"' + + if chat_id: + headers["Referer"] = f"https://chat.z.ai/c/{chat_id}" + else: + headers["Referer"] = "https://chat.z.ai/" + + return headers + + +def _build_session_expiry() -> float: + """为新会话分配带抖动的过期时间,避免整池同时失效。""" + jitter = random.uniform( + -GUEST_SESSION_TTL_JITTER_SECONDS, + GUEST_SESSION_TTL_JITTER_SECONDS, + ) + ttl_seconds = max( + GUEST_SESSION_MIN_TTL_SECONDS, + GUEST_SESSION_TTL_SECONDS + jitter, + ) + return time.time() + ttl_seconds + + +@dataclass +class GuestSession: + """单个匿名访客会话。""" + + token: str + user_id: str + username: str + created_at: float = field(default_factory=time.time) + expires_at: float = field(default_factory=_build_session_expiry) + active_requests: int = 0 + valid: bool = True + failure_count: int = 0 + last_failure_time: float = 0.0 + + @property + def age(self) -> float: + """会话存活时间。""" + return time.time() - self.created_at + + @property + def is_expired(self) -> bool: + """判断会话是否已过期。""" + return time.time() >= self.expires_at + + def snapshot(self) -> Dict[str, str]: + """获取当前会话快照。""" + return { + "token": self.token, + "user_id": self.user_id, + "username": self.username, + } + + +class GuestSessionPool: + """匿名访客会话池,支持最小负载获取与失败替换。""" + + def __init__(self, pool_size: int = 3): + self.pool_size = max(1, pool_size) + self._lock = Lock() + self._sessions: Dict[str, GuestSession] = {} + self._maintenance_task: Optional[asyncio.Task] = None + self._http_client: Optional[httpx.AsyncClient] = None + self._client_lock = asyncio.Lock() + self._capacity_lock = asyncio.Lock() + self._background_tasks: Set[asyncio.Task] = set() + self._cleanup_parallelism = GUEST_CLEANUP_PARALLELISM + self._maintenance_interval = GUEST_POOL_MAINTENANCE_INTERVAL_SECONDS + + async def _get_http_client(self) -> httpx.AsyncClient: + """获取可复用的 HTTP 客户端,减少频繁建连开销。""" + if self._http_client is not None: + return self._http_client + + async with self._client_lock: + if self._http_client is None: + self._http_client = _build_async_client() + return self._http_client + + async def _close_http_client(self): + """关闭可复用的 HTTP 客户端。""" + async with self._client_lock: + client = self._http_client + self._http_client = None + + if client is not None: + await client.aclose() + + def _track_background_task(self, coro) -> asyncio.Task: + """跟踪后台任务,避免清理阻塞前台重试路径。""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + + def _on_done(done_task: asyncio.Task): + self._background_tasks.discard(done_task) + try: + done_task.result() + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning(f"⚠️ 匿名会话后台任务异常: {exc}") + + task.add_done_callback(_on_done) + return task + + async def _wait_background_tasks(self): + """等待当前已注册的后台任务结束。""" + pending = list(self._background_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + async def _delete_sessions_concurrently(self, sessions: List[GuestSession]): + """并发清理多枚匿名会话,加快池维护速度。""" + if not sessions: + return + + semaphore = asyncio.Semaphore(self._cleanup_parallelism) + + async def _cleanup(session: GuestSession): + async with semaphore: + await self._delete_all_chats(session) + + await asyncio.gather(*(_cleanup(session) for session in sessions)) + + async def _create_session(self) -> GuestSession: + """创建一个新的匿名访客会话。""" + headers = _build_dynamic_headers() + + # 访客鉴权会写入 cookie,复用同一个 client 会把“新建会话”粘回旧访客身份。 + async with _build_async_client() as auth_client: + response = await auth_client.get(AUTH_URL, headers=headers) + + if response.status_code != 200: + raise RuntimeError( + f"匿名会话创建失败: HTTP {response.status_code} {response.text[:200]}" + ) + + data = response.json() + token = str(data.get("token") or "").strip() + user_id = str( + data.get("id") or data.get("user_id") or data.get("uid") or "" + ).strip() + username = str( + data.get("name") + or str(data.get("email") or "").split("@")[0] + or f"guest-{user_id[:8] or 'session'}" + ).strip() + + if not token: + raise RuntimeError(f"匿名会话创建失败: 未返回 token {data}") + if not user_id: + user_id = f"guest-{token[:12]}" + + logger.info( + f"🫥 创建匿名会话成功: user_id={user_id}, username={username or 'Guest'}" + ) + return GuestSession( + token=token, + user_id=user_id, + username=username or "Guest", + ) + + async def _delete_all_chats(self, session: GuestSession) -> bool: + """删除匿名会话的全部对话,尽量释放并发占用。""" + headers = _build_dynamic_headers() + headers.update( + { + "Authorization": f"Bearer {session.token}", + "Accept": "application/json", + "Content-Type": "application/json", + } + ) + + try: + client = await self._get_http_client() + response = await client.delete(CHATS_URL, headers=headers) + + if response.status_code == 200: + logger.info(f"🧹 已清理匿名会话聊天记录: {session.user_id}") + return True + + logger.warning( + f"⚠️ 清理匿名会话聊天记录失败: {session.user_id}, " + f"HTTP {response.status_code}, body={response.text[:200]}" + ) + except Exception as exc: + logger.warning(f"⚠️ 清理匿名会话聊天记录异常: {session.user_id}, {exc}") + + return False + + def _list_valid_sessions( + self, + exclude_user_ids: Optional[Set[str]] = None, + ) -> List[GuestSession]: + """获取有效匿名会话列表。""" + excluded = exclude_user_ids or set() + with self._lock: + return [ + session + for session in self._sessions.values() + if self._is_session_usable(session) + and session.user_id not in excluded + ] + + def _is_session_usable(self, session: GuestSession) -> bool: + """判断会话当前是否还能继续分配。""" + return session.valid and not session.is_expired + + def _should_retire_session(self, session: GuestSession) -> bool: + """判断会话是否应当从池中回收。""" + return session.active_requests == 0 and not self._is_session_usable(session) + + def _can_replace_session(self, session: GuestSession) -> bool: + """判断当前池内会话是否允许被新的同 user_id 会话替换。""" + return self._should_retire_session(session) + + def _store_session(self, session: GuestSession) -> bool: + """仅在会话唯一或旧会话已过期时写入会话池。""" + with self._lock: + existing = self._sessions.get(session.user_id) + if existing and not self._can_replace_session(existing): + return False + self._sessions[session.user_id] = session + return True + + def _log_duplicate_sessions(self, action: str, user_ids: List[str]): + """记录重复会话,避免补池时静默覆盖。""" + if not user_ids: + return + + sample = ", ".join(user_ids[:MAX_DUPLICATE_LOG_USER_IDS]) + logger.warning( + f"⚠️ 匿名会话池{action}收到重复会话,已忽略: " + f"count={len(user_ids)}, user_ids={sample}" + ) + + def _register_create_results(self, action: str, results: List[object]) -> int: + """写入新创建的会话,并显式忽略重复 user_id。""" + created = 0 + duplicate_user_ids: List[str] = [] + + for result in results: + if isinstance(result, GuestSession): + if self._store_session(result): + created += 1 + else: + duplicate_user_ids.append(result.user_id) + continue + + if isinstance(result, Exception): + logger.warning(f"⚠️ 匿名会话池{action}失败: {result}") + + self._log_duplicate_sessions(action, duplicate_user_ids) + return created + + def _get_fill_attempt_budget(self, missing_count: int) -> int: + """为补池/获取会话计算显式尝试上限,避免重复会话导致死循环。""" + scaled_budget = max(1, missing_count) * CAPACITY_FILL_ATTEMPT_MULTIPLIER + minimum_budget = max(1, missing_count) + CAPACITY_FILL_MIN_ATTEMPTS + return max(scaled_budget, minimum_budget) + + def _pop_retired_sessions(self) -> List[GuestSession]: + """移除当前所有可回收的失效会话。""" + retired_sessions: List[GuestSession] = [] + + with self._lock: + for user_id, session in list(self._sessions.items()): + if self._should_retire_session(session): + retired_sessions.append(self._sessions.pop(user_id)) + + return retired_sessions + + async def _ensure_capacity(self): + """补齐匿名会话池容量。""" + async with self._capacity_lock: + attempts_left = self._get_fill_attempt_budget( + self.pool_size - len(self._list_valid_sessions()) + ) + + while attempts_left > 0: + need = self.pool_size - len(self._list_valid_sessions()) + if need <= 0: + return + + batch_size = min(need, attempts_left) + results = await asyncio.gather( + *[self._create_session() for _ in range(batch_size)], + return_exceptions=True, + ) + attempts_left -= batch_size + + created = self._register_create_results("补齐", results) + if created == 0 and attempts_left == 0: + break + + remaining = self.pool_size - len(self._list_valid_sessions()) + if remaining > 0: + logger.warning( + "⚠️ 匿名会话池补齐未达到目标容量: " + f"missing={remaining}, current={len(self._list_valid_sessions())}" + ) + + async def _maintenance_loop(self): + """后台维护:回收过期/失效会话,并补齐池容量。""" + while True: + try: + await asyncio.sleep(self._maintenance_interval) + retired_sessions = self._pop_retired_sessions() + await self._delete_sessions_concurrently(retired_sessions) + + await self._ensure_capacity() + except asyncio.CancelledError: + return + except Exception as exc: + logger.warning(f"⚠️ 匿名会话池后台维护异常: {exc}") + + async def initialize(self): + """初始化匿名会话池。""" + if self._maintenance_task: + return + + await self._ensure_capacity() + created = len(self._list_valid_sessions()) + + if created == 0: + fallback = await self._create_session() + if not self._store_session(fallback): + raise RuntimeError( + "匿名会话池初始化失败: 无法写入唯一匿名会话" + ) + created = len(self._list_valid_sessions()) + + logger.info(f"✅ 匿名会话池初始化完成: {created} 个会话") + self._maintenance_task = asyncio.create_task(self._maintenance_loop()) + + async def close(self): + """关闭匿名会话池。""" + if self._maintenance_task: + self._maintenance_task.cancel() + try: + await self._maintenance_task + except asyncio.CancelledError: + pass + self._maintenance_task = None + + with self._lock: + sessions = list(self._sessions.values()) + self._sessions.clear() + + await self._wait_background_tasks() + idle_sessions = [ + session for session in sessions if session.active_requests == 0 + ] + await self._delete_sessions_concurrently(idle_sessions) + await self._close_http_client() + + async def acquire( + self, + exclude_user_ids: Optional[Set[str]] = None, + ) -> GuestSession: + """按最小忙碌度获取一个可用匿名会话。""" + excluded = exclude_user_ids or set() + attempts_left = self._get_fill_attempt_budget(len(excluded) + 1) + + while attempts_left > 0: + candidates = self._list_valid_sessions(exclude_user_ids=excluded) + if candidates: + session = min( + candidates, + key=lambda item: (item.active_requests, item.created_at), + ) + with self._lock: + current = self._sessions.get(session.user_id) + if ( + current + and self._is_session_usable(current) + and current.user_id not in excluded + ): + current.active_requests += 1 + return current + + new_session = await self._create_session() + attempts_left -= 1 + if new_session.user_id in excluded: + logger.warning( + "⚠️ 获取匿名会话时命中排除 user_id,已忽略: " + f"{new_session.user_id}" + ) + continue + + if not self._store_session(new_session): + logger.warning( + "⚠️ 获取匿名会话时命中重复 user_id,已重试: " + f"{new_session.user_id}" + ) + continue + + with self._lock: + current = self._sessions.get(new_session.user_id) + if current and self._is_session_usable(current): + current.active_requests += 1 + return current + + raise RuntimeError("匿名会话池获取失败: 未能创建唯一匿名会话") + + def release(self, user_id: str): + """释放一个匿名会话占用。""" + retired_session: Optional[GuestSession] = None + + with self._lock: + session = self._sessions.get(user_id) + if session: + session.active_requests = max(0, session.active_requests - 1) + if self._should_retire_session(session): + retired_session = self._sessions.pop(user_id) + + if retired_session: + logger.info(f"🧹 已回收过期匿名会话: {retired_session.user_id}") + self._track_background_task(self._delete_all_chats(retired_session)) + self._track_background_task(self._ensure_capacity()) + + async def report_failure(self, user_id: Optional[str] = None): + """标记匿名会话失效,并尝试补一个新会话。""" + session: Optional[GuestSession] = None + + if user_id: + with self._lock: + session = self._sessions.pop(user_id, None) + if session: + session.valid = False + session.failure_count += 1 + session.last_failure_time = time.time() + session.active_requests = 0 + + if session: + self._track_background_task(self._delete_all_chats(session)) + logger.warning(f"⚠️ 已淘汰匿名会话: {session.user_id}") + + await self._ensure_capacity() + + async def refresh_auth(self, failed_user_id: Optional[str] = None): + """兼容 glm-demo 命名:刷新匿名会话。""" + await self.report_failure(failed_user_id) + + async def cleanup_idle_chats(self): + """清理当前空闲匿名会话的聊天记录。""" + with self._lock: + idle_sessions = [ + session + for session in self._sessions.values() + if self._is_session_usable(session) and session.active_requests == 0 + ] + + await self._delete_sessions_concurrently(idle_sessions) + + def get_pool_status(self) -> Dict[str, int]: + """获取匿名会话池状态。""" + with self._lock: + sessions = list(self._sessions.values()) + + valid_sessions = [ + session for session in sessions if self._is_session_usable(session) + ] + busy_sessions = [ + session for session in valid_sessions if session.active_requests > 0 + ] + + return { + "total_sessions": len(sessions), + "valid_sessions": len(valid_sessions), + "available_sessions": len( + [session for session in valid_sessions if session.active_requests == 0] + ), + "busy_sessions": len(busy_sessions), + "expired_sessions": len( + [session for session in sessions if session.is_expired] + ), + } + + +_guest_session_pool: Optional[GuestSessionPool] = None +_guest_pool_lock = Lock() + + +def get_guest_session_pool() -> Optional[GuestSessionPool]: + """获取全局匿名会话池。""" + return _guest_session_pool + + +async def initialize_guest_session_pool( + pool_size: int = 3, +) -> GuestSessionPool: + """初始化全局匿名会话池。""" + global _guest_session_pool + + with _guest_pool_lock: + if _guest_session_pool is None: + _guest_session_pool = GuestSessionPool(pool_size=pool_size) + pool = _guest_session_pool + + await pool.initialize() + return pool + + +async def close_guest_session_pool(): + """关闭全局匿名会话池。""" + global _guest_session_pool + + with _guest_pool_lock: + pool = _guest_session_pool + _guest_session_pool = None + + if pool: + await pool.close() diff --git a/app/utils/logger.py b/app/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..202c5a439431ab146777408bc586e6d45feca810 --- /dev/null +++ b/app/utils/logger.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import sys +from pathlib import Path +from loguru import logger + +# Global logger instance +app_logger = None + + +def setup_logger(log_dir, log_retention_days=7, log_rotation="1 day", debug_mode=False): + """ + Create a logger instance + + Parameters: + log_dir (str): 日志目录 + log_retention_days (int): 日志保留天数 + log_rotation (str): 日志轮转间隔 + debug_mode (bool): 是否开启调试模式 + """ + global app_logger + + # 移除所有现有的日志处理器(支持热重载) + logger.remove() + + log_level = "DEBUG" if debug_mode else "INFO" + + console_format = ( + "{time:HH:mm:ss} | {level: <8} | {message}" + if not debug_mode + else "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | " + "{name}:{function}:{line} | {message}" + ) + + # 添加控制台输出(根据 debug_mode 设置级别) + logger.add(sys.stderr, level=log_level, format=console_format, colorize=True) + + # 只有在 debug_mode 时才添加文件输出 + if debug_mode: + try: + log_path = Path(log_dir) + log_path.mkdir(parents=True, exist_ok=True) + + log_file = log_path / "{time:YYYY-MM-DD}.log" + file_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}" + + logger.add( + str(log_file), + level=log_level, + format=file_format, + rotation=log_rotation, + retention=f"{log_retention_days} days", + encoding="utf-8", + compression="zip", + enqueue=True, + catch=True, + ) + except (PermissionError, OSError) as e: + # 如果无法创建日志目录或文件,降级为仅控制台输出 + logger.warning(f"⚠️ 无法创建日志文件 ({e}),将仅使用控制台输出") + + app_logger = logger + + return logger + + +def get_logger(): + """Get the logger instance""" + global app_logger + if app_logger is None: + # 如果没有设置过logger,使用默认配置 + logger.remove() # 移除所有现有处理器 + logger.add(sys.stderr, level="INFO", format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}") + app_logger = logger + return app_logger + + +if __name__ == "__main__": + """Test the logger""" + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + try: + setup_logger(temp_dir, debug_mode=True) + + logger.debug("这是一条调试日志") + logger.info("这是一条信息日志") + logger.warning("这是一条警告日志") + logger.error("这是一条错误日志") + logger.critical("这是一条严重日志") + + try: + 1 / 0 + except ZeroDivisionError: + logger.exception("发生了除零异常") + + print("✅ 日志测试完成") + + logger.remove() + + except Exception as e: + print(f"❌ 日志测试失败: {e}") + logger.remove() + raise diff --git a/app/utils/reload_config.py b/app/utils/reload_config.py new file mode 100644 index 0000000000000000000000000000000000000000..36398a1b3a27317410a3b674d30cb54eeee303df --- /dev/null +++ b/app/utils/reload_config.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +热重载配置模块 +定义 Granian 服务器热重载时需要忽略的目录和文件模式 +""" + +# 忽略的目录列表 +RELOAD_IGNORE_DIRS = [ + "logs", # 忽略日志目录 + "storage", # 忽略存储目录 + "__pycache__", # 忽略 Python 缓存 + ".git", # 忽略 git 目录 + ".github", # 忽略 GitHub 相关目录 + ".vscode", # 忽略 VSCode 配置目录 + "deploy", # 忽略部署相关目录 + ".idea", # 忽略 IntelliJ IDEA 配置目录 + "node_modules", # 忽略 node_modules + "migrations", # 忽略数据库迁移目录 + ".pytest_cache", # 忽略 pytest 缓存 + ".venv", # 忽略虚拟环境 + "venv", # 忽略虚拟环境 + "env", # 忽略环境目录 + ".mypy_cache", # 忽略 mypy 缓存 + ".ruff_cache", # 忽略 ruff 缓存 + "dist", # 忽略构建分发目录 + "build", # 忽略构建目录 + ".coverage", # 忽略测试覆盖率文件 + "htmlcov", # 忽略覆盖率报告目录 + "tests", # 忽略测试目录 + "z-ai2api-server.pid", # 忽略 PID 文件 + "app\\templates" # 忽略模板目录 +] + +# 忽略的文件模式(正则表达式) +RELOAD_IGNORE_PATTERNS = [ + # 日志文件 + r".*\.log$", + r".*\.log\.\d+$", + # 数据库文件 + r".*\.sqlite3.*", + r".*\.db$", + r".*\.db-.*$", + # Python 相关 + r".*\.pyc$", + r".*\.pyo$", + r".*\.pyd$", + # 临时文件 + r".*\.tmp$", + r".*\.temp$", + r".*\.swp$", + r".*\.swo$", + r".*~$", + # 系统文件 + r".*\.DS_Store$", + r".*Thumbs\.db$", + r".*\.directory$", + # 编辑器文件 + r".*\.vscode.*", + r".*\.idea.*", + # 测试和覆盖率 + r".*\.coverage$", + r".*\.pytest_cache.*", + # 构建文件 + r".*\.egg-info.*", + r".*\.wheel$", + r".*\.whl$", + # 版本控制 + r".*\.git.*", + r".*\.gitignore$", + r".*\.gitkeep$", + # 配置文件备份 + r".*\.bak$", + r".*\.backup$", + r".*\.orig$", + # 锁文件 + r".*\.lock$", + r".*\.pid$", +] + +# 监视的路径(只监视应用相关代码) +RELOAD_WATCH_PATHS = [ + "app", # 应用主目录 + "main.py", # 主入口文件 +] + +# 热重载配置 +RELOAD_CONFIG = { + "reload_ignore_dirs": RELOAD_IGNORE_DIRS, + "reload_ignore_patterns": RELOAD_IGNORE_PATTERNS, + "reload_paths": RELOAD_WATCH_PATHS, + "reload_tick": 500, # 监视频率(毫秒) +} diff --git a/app/utils/request_logging.py b/app/utils/request_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e13dd423bf5f87313656fea6e1dda40d0e1cea58 --- /dev/null +++ b/app/utils/request_logging.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""请求日志写库与流式日志包装。""" + +from __future__ import annotations + +import json +import time +from typing import Any, AsyncGenerator, Dict, Optional + +from app.services.request_log_dao import get_request_log_dao +from app.utils.logger import get_logger +from app.utils.request_source import RequestSourceInfo + +logger = get_logger() + + +def _coerce_int(value: Any) -> int: + try: + return int(value or 0) + except (TypeError, ValueError): + return 0 + + +def _merge_usage( + current: Dict[str, int], + update: Dict[str, int], + *, + include_cache_in_total: bool, +) -> Dict[str, int]: + merged = dict(current) + + for key in ( + "input_tokens", + "output_tokens", + "cache_creation_tokens", + "cache_read_tokens", + ): + value = _coerce_int(update.get(key)) + if value > 0: + merged[key] = value + + total_tokens = _coerce_int(update.get("total_tokens")) + if total_tokens > 0: + merged["total_tokens"] = total_tokens + return merged + + merged["total_tokens"] = ( + merged["input_tokens"] + merged["output_tokens"] + ) + if include_cache_in_total: + merged["total_tokens"] += ( + merged["cache_creation_tokens"] + merged["cache_read_tokens"] + ) + + return merged + + +def extract_openai_usage(response: Dict[str, Any]) -> Dict[str, int]: + """Extract usage from an OpenAI-compatible response payload.""" + usage = response.get("usage") or {} + prompt_details = usage.get("prompt_tokens_details") or {} + input_details = usage.get("input_token_details") or {} + + input_tokens = _coerce_int( + usage.get("prompt_tokens") or usage.get("input_tokens") + ) + output_tokens = _coerce_int( + usage.get("completion_tokens") or usage.get("output_tokens") + ) + cache_creation_tokens = _coerce_int( + usage.get("cache_creation_input_tokens") + or prompt_details.get("cache_creation_tokens") + or input_details.get("cache_creation_input_tokens") + or input_details.get("cache_creation_tokens") + ) + cache_read_tokens = _coerce_int( + usage.get("cache_read_input_tokens") + or prompt_details.get("cached_tokens") + or prompt_details.get("cache_read_tokens") + or input_details.get("cached_tokens") + or input_details.get("cache_read_input_tokens") + or input_details.get("cache_read_tokens") + ) + total_tokens = _coerce_int(usage.get("total_tokens")) + if total_tokens <= 0: + total_tokens = input_tokens + output_tokens + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_tokens": cache_creation_tokens, + "cache_read_tokens": cache_read_tokens, + "total_tokens": total_tokens, + } + + +def extract_claude_usage(response: Dict[str, Any]) -> Dict[str, int]: + """Extract usage from a Claude-compatible response payload.""" + usage = response.get("usage") or {} + input_tokens = _coerce_int( + usage.get("input_tokens") or usage.get("prompt_tokens") + ) + output_tokens = _coerce_int( + usage.get("output_tokens") or usage.get("completion_tokens") + ) + cache_creation_tokens = _coerce_int( + usage.get("cache_creation_input_tokens") + or usage.get("cache_creation_tokens") + ) + cache_read_tokens = _coerce_int( + usage.get("cache_read_input_tokens") + or usage.get("cached_tokens") + or usage.get("cache_read_tokens") + ) + total_tokens = _coerce_int(usage.get("total_tokens")) + if total_tokens <= 0: + total_tokens = ( + input_tokens + + output_tokens + + cache_creation_tokens + + cache_read_tokens + ) + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_creation_tokens": cache_creation_tokens, + "cache_read_tokens": cache_read_tokens, + "total_tokens": total_tokens, + } + + +async def write_request_log( + *, + provider: str, + model: str, + source_info: RequestSourceInfo, + success: bool, + started_at: float, + status_code: int = 200, + first_token_time: float = 0.0, + input_tokens: int = 0, + output_tokens: int = 0, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, + total_tokens: Optional[int] = None, + error_message: Optional[str] = None, +) -> None: + """Persist a request log entry without breaking request handling.""" + duration = max(0.0, time.perf_counter() - started_at) + try: + dao = get_request_log_dao() + await dao.add_log( + provider=provider, + endpoint=source_info.endpoint, + source=source_info.source, + protocol=source_info.protocol, + client_name=source_info.client_name, + model=model, + status_code=status_code, + success=success, + duration=duration, + first_token_time=first_token_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_creation_tokens=cache_creation_tokens, + cache_read_tokens=cache_read_tokens, + total_tokens=total_tokens, + error_message=error_message, + ) + except Exception as exc: + logger.error(f"写入请求日志失败: {exc}") + + +def _openai_payload_has_output(payload: Dict[str, Any]) -> bool: + choice = ((payload.get("choices") or [{}])[0]) if isinstance(payload, dict) else {} + delta = choice.get("delta") or {} + return bool( + delta.get("content") + or delta.get("reasoning_content") + or delta.get("tool_calls") + ) + + +async def wrap_openai_stream_with_logging( + stream: AsyncGenerator[str, None], + *, + provider: str, + model: str, + source_info: RequestSourceInfo, + started_at: float, +) -> AsyncGenerator[str, None]: + """Wrap OpenAI SSE stream and persist completion metadata.""" + success = True + status_code = 200 + error_message: Optional[str] = None + first_token_time = 0.0 + usage = { + "input_tokens": 0, + "output_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0, + "total_tokens": 0, + } + + try: + async for chunk in stream: + if chunk.startswith("data: "): + payload_text = chunk[6:].strip() + if payload_text and payload_text != "[DONE]": + try: + payload = json.loads(payload_text) + except json.JSONDecodeError: + payload = None + + if isinstance(payload, dict): + if "error" in payload: + success = False + error = payload.get("error") or {} + error_message = ( + error.get("message") + or "Unknown stream error" + ) + status_code = int(error.get("code") or 500) + else: + if ( + not first_token_time + and _openai_payload_has_output(payload) + ): + first_token_time = max( + 0.0, + time.perf_counter() - started_at, + ) + if payload.get("usage"): + usage = _merge_usage( + usage, + extract_openai_usage(payload), + include_cache_in_total=False, + ) + + yield chunk + except Exception as exc: + success = False + status_code = 500 + error_message = str(exc) + raise + finally: + await write_request_log( + provider=provider, + model=model, + source_info=source_info, + success=success, + started_at=started_at, + status_code=status_code, + first_token_time=first_token_time, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + total_tokens=usage["total_tokens"], + error_message=error_message, + ) + + +async def wrap_claude_stream_with_logging( + stream: AsyncGenerator[str, None], + *, + provider: str, + model: str, + source_info: RequestSourceInfo, + started_at: float, + input_tokens: int, +) -> AsyncGenerator[str, None]: + """Wrap Claude SSE stream and persist completion metadata.""" + success = True + status_code = 200 + error_message: Optional[str] = None + first_token_time = 0.0 + usage = { + "input_tokens": input_tokens, + "output_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0, + "total_tokens": input_tokens, + } + current_event: Optional[str] = None + + try: + async for chunk in stream: + if chunk.startswith("event: "): + current_event = chunk[7:].strip() + elif chunk.startswith("data: "): + payload_text = chunk[6:].strip() + try: + payload = json.loads(payload_text) + except json.JSONDecodeError: + payload = None + + if isinstance(payload, dict): + if current_event == "content_block_delta" and not first_token_time: + first_token_time = max(0.0, time.perf_counter() - started_at) + if payload.get("usage"): + usage = _merge_usage( + usage, + extract_claude_usage(payload), + include_cache_in_total=True, + ) + elif current_event == "error": + success = False + status_code = 500 + error = payload.get("error") or {} + error_message = error.get("message") or "Claude stream error" + + yield chunk + except Exception as exc: + success = False + status_code = 500 + error_message = str(exc) + raise + finally: + await write_request_log( + provider=provider, + model=model, + source_info=source_info, + success=success, + started_at=started_at, + status_code=status_code, + first_token_time=first_token_time, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_tokens=usage["cache_creation_tokens"], + cache_read_tokens=usage["cache_read_tokens"], + total_tokens=usage["total_tokens"], + error_message=error_message, + ) diff --git a/app/utils/request_source.py b/app/utils/request_source.py new file mode 100644 index 0000000000000000000000000000000000000000..02c68c915801d53a1851994c71379350239fb052 --- /dev/null +++ b/app/utils/request_source.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""请求来源识别辅助函数。""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any, Optional + +from fastapi import Request + + +ANTHROPIC_MODEL_PREFIXES = ( + "claude-", + "claude.", +) +ANTHROPIC_MODEL_ALIASES = { + "sonnet", + "opus", + "haiku", + "opusplan", +} + + +@dataclass(frozen=True) +class RequestSourceInfo: + """Normalized request-source metadata for logging.""" + + source: str + protocol: str + client_name: str + endpoint: str + user_agent: str + + +def _normalize_source_name(value: str) -> str: + normalized = re.sub(r"[^a-zA-Z0-9._-]+", "_", value.strip().lower()) + return normalized.strip("_") or "unknown" + + +def _looks_like_anthropic_model(model_hint: Optional[str]) -> bool: + if not isinstance(model_hint, str): + return False + + normalized = model_hint.strip().casefold() + if normalized in ANTHROPIC_MODEL_ALIASES: + return True + + return normalized.startswith(ANTHROPIC_MODEL_PREFIXES) + + +def detect_request_source( + request: Request, + protocol_hint: Optional[str] = None, + model_hint: Optional[str] = None, +) -> RequestSourceInfo: + """Detect the request source from headers, path, and model hints.""" + headers = request.headers + endpoint = request.url.path + user_agent = (headers.get("user-agent") or "").strip() + user_agent_normalized = user_agent.casefold() + + protocol = (protocol_hint or "").strip().lower() + if not protocol: + if headers.get("anthropic-version") or "/messages" in endpoint: + protocol = "anthropic" + elif "/chat/completions" in endpoint: + protocol = "openai" + else: + protocol = "unknown" + + explicit_source = headers.get("x-request-source") or headers.get("x-client-source") + if explicit_source: + source = _normalize_source_name(explicit_source) + return RequestSourceInfo( + source=source, + protocol=protocol, + client_name=explicit_source.strip(), + endpoint=endpoint, + user_agent=user_agent, + ) + + if any(token in user_agent_normalized for token in ("claude-code", "claude code", "claude-cli", "claude/")): + source = "claude_code" + client_name = "Claude Code" + elif "anthropic" in user_agent_normalized: + source = "anthropic_sdk" + client_name = "Anthropic SDK" + elif "openai" in user_agent_normalized: + source = "openai_sdk" + client_name = "OpenAI SDK" + elif "curl/" in user_agent_normalized: + source = "curl" + client_name = "curl" + elif any(token in user_agent_normalized for token in ("python-httpx", "httpx/", "python-requests", "requests/")): + source = "custom_http_client" + client_name = "HTTP Client" + elif "mozilla/" in user_agent_normalized: + source = "browser" + client_name = "Browser" + elif protocol == "anthropic": + source = "claude_family" if _looks_like_anthropic_model(model_hint) else "anthropic_compatible" + client_name = "Claude/Anthropic Compatible" + elif protocol == "openai": + source = "openai_compatible" + client_name = "OpenAI Compatible" + else: + source = "unknown" + client_name = "Unknown" + + return RequestSourceInfo( + source=source, + protocol=protocol, + client_name=client_name, + endpoint=endpoint, + user_agent=user_agent, + ) + + +def format_request_source(info: RequestSourceInfo) -> str: + """Render request-source metadata into a compact log prefix.""" + return ( + f"[source={info.source}]" + f"[protocol={info.protocol}]" + f"[client={info.client_name}]" + f"[endpoint={info.endpoint}]" + ) diff --git a/app/utils/signature.py b/app/utils/signature.py new file mode 100644 index 0000000000000000000000000000000000000000..785967a448ff2669f4e93f8a2aa66c53af85673f --- /dev/null +++ b/app/utils/signature.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Z.AI 签名工具模块 +""" + +import hmac +import hashlib +import base64 +from typing import Dict + + +def generate_signature(e: str, t: str, s: int) -> dict: + """Generate signature matching JavaScript zs function. + + Args: + e: canonical metadata string, e.g. "requestId,,timestamp,,user_id," + t: latest user message text that feeds into the signature prompt (may be empty) + s: timestamp in milliseconds + + Returns: + Dictionary with signature and timestamp + """ + # r = Number(s) - convert to number (already a number in Python) + r = s + # i = s - timestamp as string + i = str(s) + + # n = new TextEncoder + # a = n.encode(t) + a = t.encode('utf-8') + + # w = btoa(String.fromCharCode(...a)) + # This is equivalent to base64 encoding the UTF-8 bytes + w = base64.b64encode(a).decode('ascii') + + # c = `${e}|${w}|${i}` + c = f"{e}|{w}|{i}" + + # E = Math.floor(r / (5 * 60 * 1e3)) + E = r // (5 * 60 * 1000) + + # A = CryptoJS.HmacSHA256(`${E}`, "key-@@@@)))()((9))-xxxx&&&%%%%%") + secret = "key-@@@@)))()((9))-xxxx&&&%%%%%" + A = hmac.new(secret.encode('utf-8'), str(E).encode('utf-8'), hashlib.sha256).hexdigest() + + # k = CryptoJS.HmacSHA256(c, A).toString() + k = hmac.new(A.encode('utf-8'), c.encode('utf-8'), hashlib.sha256).hexdigest() + + # return n.encode(c), { signature: k, timestamp: i } + # Note: n.encode(c) is not used in the return value, so we ignore it + return { + "signature": k, + "timestamp": i + } diff --git a/app/utils/token_pool.py b/app/utils/token_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..7681313bbae55b51030a0e723c876a329a99b363 --- /dev/null +++ b/app/utils/token_pool.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Token 池管理器 - 基于数据库的 Token 轮询和健康检查 + +核心功能: +1. Token 轮询机制 - 负载均衡和容错 +2. Z.AI 官方认证接口验证 - 基于 role 字段区分用户类型 +3. Token 健康度监控 - 自动禁用失败 Token +4. 数据库集成 - 与 TokenDAO 协同工作 +""" + +import asyncio +import time +from dataclasses import dataclass +from threading import Lock +from typing import Dict, List, Optional, Set, Tuple + +import httpx + +from app.utils.logger import logger + + +# ==================== Token 状态管理 ==================== + + +@dataclass +class TokenStatus: + """Token 运行时状态(内存中)""" + token: str + token_id: int # 数据库 ID,用于同步统计 + token_type: str = "unknown" # "user", "guest", "unknown" + is_available: bool = True + failure_count: int = 0 + last_failure_time: float = 0.0 + last_success_time: float = 0.0 + total_requests: int = 0 + successful_requests: int = 0 + db_synced_successful_requests: int = 0 + db_synced_failed_requests: int = 0 + + @property + def success_rate(self) -> float: + """成功率""" + if self.total_requests == 0: + return 1.0 + return self.successful_requests / self.total_requests + + @property + def failed_requests(self) -> int: + """失败次数。""" + return max(0, self.total_requests - self.successful_requests) + + @property + def is_healthy(self) -> bool: + """ + Token 健康状态判断 + + 健康标准: + 1. 必须是认证用户 Token (token_type = "user") + 2. 当前可用 (is_available = True) + 3. 成功率 >= 50% 或总请求数 <= 3(新 Token 容错) + + 注意: + - guest Token 永远不健康 + - unknown Token 永远不健康 + """ + # guest 和 unknown token 永远不健康 + if self.token_type != "user": + return False + + # 不可用的 token 不健康 + if not self.is_available: + return False + + # 新 token 容错:请求数很少时,只要没失败就健康 + if self.total_requests <= 3: + return self.failure_count == 0 + + # 基于成功率判断 + return self.success_rate >= 0.5 + + +# ==================== Token 验证服务 ==================== + + +class ZAITokenValidator: + """Z.AI Token 验证器(使用官方认证接口)""" + + AUTH_URL = "https://chat.z.ai/api/v1/auths/" + + @staticmethod + def get_headers(token: str) -> Dict[str, str]: + """构建认证请求头""" + return { + "Accept": "*/*", + "Accept-Language": "zh-CN,zh;q=0.9", + "Authorization": f"Bearer {token}", + "Connection": "keep-alive", + "Content-Type": "application/json", + "DNT": "1", + "Referer": "https://chat.z.ai/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-origin", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36", + "sec-ch-ua": '"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"Windows"' + } + + @classmethod + async def validate_token(cls, token: str) -> Tuple[str, bool, Optional[str]]: + """ + 验证 Token 有效性并返回类型 + + Args: + token: 待验证的 Token + + Returns: + (token_type, is_valid, error_message) + - token_type: "user" | "guest" | "unknown" + - is_valid: True 表示是有效的认证用户 Token + - error_message: 失败原因(仅在 is_valid=False 时有值) + """ + try: + async with httpx.AsyncClient(timeout=15.0) as client: + response = await client.get( + cls.AUTH_URL, + headers=cls.get_headers(token) + ) + + # 解析响应 + return cls._parse_auth_response(response) + + except httpx.TimeoutException: + return ("unknown", False, "请求超时") + except httpx.ConnectError: + return ("unknown", False, "连接失败") + except Exception as e: + return ("unknown", False, f"验证异常: {str(e)}") + + @staticmethod + def _parse_auth_response(response: httpx.Response) -> Tuple[str, bool, Optional[str]]: + """ + 解析 Z.AI 认证接口响应 + + 响应格式示例: + { + "id": "...", + "email": "user@example.com", + "role": "user" # 或 "guest" + } + + 验证规则: + - role: "user" → 认证用户 Token(有效,可添加) + - role: "guest" → 匿名用户 Token(无效,拒绝添加) + - 其他情况 → 无效 Token + """ + # 检查 HTTP 状态码 + if response.status_code != 200: + return ("unknown", False, f"HTTP {response.status_code}") + + try: + data = response.json() + + # 验证响应格式 + if not isinstance(data, dict): + return ("unknown", False, "无效的响应格式") + + # 检查是否包含错误信息 + if "error" in data or "message" in data: + error_msg = data.get("error") or data.get("message", "未知错误") + return ("unknown", False, str(error_msg)) + + # 核心验证:检查 role 字段 + role = data.get("role") + + if role == "user": + return ("user", True, None) + elif role == "guest": + return ("guest", False, "匿名用户 Token 不允许添加") + else: + return ("unknown", False, f"未知 role: {role}") + + except (ValueError, Exception) as e: + return ("unknown", False, f"解析响应失败: {str(e)}") + + +# ==================== Token 池管理器 ==================== + + +class TokenPool: + """Token 池管理器(数据库驱动)""" + + def __init__( + self, + tokens: List[Tuple[int, str, str]], # [(token_id, token_value, token_type), ...] + failure_threshold: int = 3, + recovery_timeout: int = 1800 + ): + """ + 初始化 Token 池 + + Args: + tokens: Token 列表 [(token_id, token_value, token_type), ...] + failure_threshold: 失败阈值,超过此次数将标记为不可用 + recovery_timeout: 恢复超时时间(秒),失败 Token 在此时间后重新尝试 + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self._lock = Lock() + self._current_index = 0 + + # 初始化 Token 状态(内存中) + self.token_statuses: Dict[str, TokenStatus] = {} + self.token_id_map: Dict[str, int] = {} # token -> token_id 映射 + + for token_id, token_value, token_type in tokens: + if token_value and token_value not in self.token_statuses: + self.token_statuses[token_value] = TokenStatus( + token=token_value, + token_id=token_id, + token_type=token_type + ) + self.token_id_map[token_value] = token_id + + if not self.token_statuses: + logger.warning("⚠️ Token 池为空,将依赖匿名模式") + + def get_next_token(self, exclude_tokens: Optional[Set[str]] = None) -> Optional[str]: + """ + 获取下一个可用的认证用户 Token(轮询算法) + + Returns: + 可用的 Token 字符串,如果没有可用 Token 则返回 None + """ + with self._lock: + if not self.token_statuses: + return None + + excluded = exclude_tokens or set() + + available_tokens = self._get_available_user_tokens() + if excluded: + available_tokens = [ + token for token in available_tokens if token not in excluded + ] + if not available_tokens: + # 尝试恢复过期的失败 Token + self._try_recover_failed_tokens() + available_tokens = self._get_available_user_tokens() + if excluded: + available_tokens = [ + token for token in available_tokens if token not in excluded + ] + + if not available_tokens: + logger.warning("⚠️ 没有可用的认证用户 Token") + return None + + # 轮询选择 + token = available_tokens[self._current_index % len(available_tokens)] + self._current_index = (self._current_index + 1) % len(available_tokens) + + return token + + def _get_available_user_tokens(self) -> List[str]: + """ + 获取当前可用的认证用户 Token 列表 + + 过滤条件: + 1. is_available = True + 2. token_type == "user" + """ + available_user_tokens = [ + status.token for status in self.token_statuses.values() + if status.is_available and status.token_type == "user" + ] + + # 警告:如果有 guest token 但没有 user token + if not available_user_tokens and self.token_statuses: + guest_count = sum( + 1 for status in self.token_statuses.values() + if status.token_type == "guest" + ) + if guest_count > 0: + logger.warning(f"⚠️ 检测到 {guest_count} 个匿名用户 Token,轮询机制将跳过这些 Token") + + return available_user_tokens + + def _try_recover_failed_tokens(self): + """尝试恢复失败的 Token(仅针对认证用户 Token)""" + current_time = time.time() + recovered_count = 0 + + for status in self.token_statuses.values(): + # 只恢复认证用户 Token + if ( + status.token_type == "user" + and not status.is_available + and current_time - status.last_failure_time > self.recovery_timeout + ): + status.is_available = True + status.failure_count = 0 + recovered_count += 1 + logger.info(f"🔄 恢复失败 Token: {status.token[:20]}...") + + if recovered_count > 0: + logger.info(f"✅ 恢复了 {recovered_count} 个失败的 Token") + + def mark_token_success(self, token: str): + """标记 Token 使用成功""" + with self._lock: + if token in self.token_statuses: + status = self.token_statuses[token] + status.total_requests += 1 + status.successful_requests += 1 + status.last_success_time = time.time() + status.failure_count = 0 # 重置失败计数 + + if not status.is_available: + status.is_available = True + logger.info(f"✅ Token 恢复可用: {token[:20]}...") + + def mark_token_failure(self, token: str, error: Exception = None): + """标记 Token 使用失败""" + with self._lock: + if token in self.token_statuses: + status = self.token_statuses[token] + status.total_requests += 1 + status.failure_count += 1 + status.last_failure_time = time.time() + + if status.failure_count >= self.failure_threshold: + status.is_available = False + logger.warning(f"🚫 Token 已禁用: {token[:20]}... (失败 {status.failure_count} 次)") + + async def record_token_success(self, token: str, dao=None): + """标记成功并实时同步数据库统计。""" + self.mark_token_success(token) + + token_id = self.get_token_id(token) + if token_id is None: + return + + if dao is None: + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + + try: + await dao.record_success(token_id) + except Exception as e: + logger.error(f"❌ 同步 Token 成功统计失败: {e}") + return + + with self._lock: + if token in self.token_statuses: + self.token_statuses[token].db_synced_successful_requests += 1 + + async def record_token_failure(self, token: str, error: Exception = None, dao=None): + """标记失败并实时同步数据库统计。""" + self.mark_token_failure(token, error) + + token_id = self.get_token_id(token) + if token_id is None: + return + + if dao is None: + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + + try: + await dao.record_failure(token_id) + except Exception as e: + logger.error(f"❌ 同步 Token 失败统计失败: {e}") + return + + with self._lock: + if token in self.token_statuses: + self.token_statuses[token].db_synced_failed_requests += 1 + + def get_token_id(self, token: str) -> Optional[int]: + """获取 Token 的数据库 ID""" + return self.token_id_map.get(token) + + def get_pool_status(self) -> Dict: + """获取 Token 池状态信息""" + with self._lock: + available_count = len(self._get_available_user_tokens()) + total_count = len(self.token_statuses) + healthy_count = sum(1 for status in self.token_statuses.values() if status.is_healthy) + + # 统计各类型 Token + user_count = sum(1 for s in self.token_statuses.values() if s.token_type == "user") + guest_count = sum(1 for s in self.token_statuses.values() if s.token_type == "guest") + unknown_count = sum(1 for s in self.token_statuses.values() if s.token_type == "unknown") + + status_info = { + "total_tokens": total_count, + "available_tokens": available_count, + "unavailable_tokens": total_count - available_count, + "healthy_tokens": healthy_count, + "unhealthy_tokens": total_count - healthy_count, + "user_tokens": user_count, + "guest_tokens": guest_count, + "unknown_tokens": unknown_count, + "current_index": self._current_index, + "tokens": [] + } + + for token, status in self.token_statuses.items(): + status_info["tokens"].append({ + "token": f"{token[:10]}...{token[-10:]}", + "token_id": status.token_id, + "token_type": status.token_type, + "is_available": status.is_available, + "failure_count": status.failure_count, + "success_count": status.successful_requests, + "success_rate": f"{status.success_rate:.2%}", + "total_requests": status.total_requests, + "is_healthy": status.is_healthy, + "last_failure_time": status.last_failure_time, + "last_success_time": status.last_success_time + }) + + return status_info + + def update_token_type(self, token: str, token_type: str): + """更新 Token 类型(用于健康检查后更新)""" + with self._lock: + if token in self.token_statuses: + old_type = self.token_statuses[token].token_type + self.token_statuses[token].token_type = token_type + + if old_type != token_type: + logger.info(f"🔄 更新 Token 类型: {token[:20]}... {old_type} → {token_type}") + + async def health_check_token(self, token: str) -> bool: + """ + 异步健康检查单个 Token(使用 Z.AI 官方认证接口) + + Args: + token: 要检查的 Token + + Returns: + Token 是否健康(True = 有效的认证用户 Token) + """ + token_type, is_valid, error_message = await ZAITokenValidator.validate_token(token) + + # 更新 Token 类型 + self.update_token_type(token, token_type) + + # 更新状态 + if is_valid: + await self.record_token_success(token) + else: + await self.record_token_failure( + token, + Exception(error_message or "验证失败"), + ) + + return is_valid + + async def health_check_all(self): + """异步健康检查所有 Token""" + if not self.token_statuses: + logger.warning("⚠️ Token 池为空,跳过健康检查") + return + + total_tokens = len(self.token_statuses) + logger.info(f"🔍 开始 Token 池健康检查... (共 {total_tokens} 个 Token)") + + # 并发执行所有 Token 的健康检查 + tasks = [ + self.health_check_token(token) + for token in self.token_statuses.keys() + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计结果 + healthy_count = sum(1 for r in results if r is True) + failed_count = sum(1 for r in results if r is False) + exception_count = sum(1 for r in results if isinstance(r, Exception)) + + health_rate = (healthy_count / total_tokens) * 100 if total_tokens > 0 else 0 + + if healthy_count == 0 and total_tokens > 0: + logger.warning(f"⚠️ 健康检查完成: 0/{total_tokens} 个 Token 健康 - 请检查 Token 配置") + elif failed_count > 0: + logger.warning(f"⚠️ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康 ({health_rate:.1f}%)") + else: + logger.info(f"✅ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康") + + if exception_count > 0: + logger.error(f"💥 {exception_count} 个 Token 检查异常") + + async def sync_from_database(self, provider: str = "zai"): + """ + 从数据库同步 Token 状态(禁用/启用状态) + + Args: + provider: 提供商名称 + + 说明: + - 从数据库读取最新的 Token 启用状态 + - 如果数据库中 Token 被禁用,则从池中移除 + - 如果数据库中有新增的启用 Token,则添加到池中 + - 保留现有 Token 的运行时统计(请求数、成功率等) + """ + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + + # 从数据库加载所有启用的认证用户 Token + token_records = await dao.get_tokens_by_provider(provider, enabled_only=True) + + # 构建数据库中的 Token 映射 + db_tokens = { + record["token"]: (record["id"], record.get("token_type", "unknown")) + for record in token_records + if record.get("token_type") != "guest" # 过滤 guest token + } + + with self._lock: + # 1. 移除已在数据库中禁用的 Token + tokens_to_remove = [] + for token_value in list(self.token_statuses.keys()): + if token_value not in db_tokens: + tokens_to_remove.append(token_value) + + for token_value in tokens_to_remove: + del self.token_statuses[token_value] + del self.token_id_map[token_value] + logger.info(f"🗑️ 从池中移除已禁用 Token: {token_value[:20]}...") + + # 2. 添加新启用的 Token + new_tokens_count = 0 + for token_value, (token_id, token_type) in db_tokens.items(): + if token_value not in self.token_statuses: + self.token_statuses[token_value] = TokenStatus( + token=token_value, + token_id=token_id, + token_type=token_type + ) + self.token_id_map[token_value] = token_id + new_tokens_count += 1 + logger.info(f"➕ 添加新启用 Token: {token_value[:20]}...") + + # 3. 更新现有 Token 的类型(如果数据库中有更新) + for token_value, (token_id, token_type) in db_tokens.items(): + if token_value in self.token_statuses: + old_type = self.token_statuses[token_value].token_type + if old_type != token_type: + self.token_statuses[token_value].token_type = token_type + logger.info(f"🔄 更新 Token 类型: {token_value[:20]}... {old_type} → {token_type}") + + logger.info( + f"✅ Token 池同步完成: " + f"当前 {len(self.token_statuses)} 个 Token " + f"(移除 {len(tokens_to_remove)}, 新增 {new_tokens_count})" + ) + + +# ==================== 全局实例管理 ==================== + + +_token_pool: Optional[TokenPool] = None +_pool_lock = Lock() + + +def get_token_pool() -> Optional[TokenPool]: + """获取全局 Token 池实例""" + return _token_pool + + +async def initialize_token_pool_from_db( + provider: str = "zai", + failure_threshold: int = 3, + recovery_timeout: int = 1800 +) -> Optional[TokenPool]: + """ + 从数据库初始化全局 Token 池 + + Args: + provider: 提供商名称(当前仅使用 zai) + failure_threshold: 失败阈值 + recovery_timeout: 恢复超时时间(秒) + + Returns: + TokenPool 实例(即使没有 Token 也会创建空池) + """ + global _token_pool + + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + + # 从数据库加载 Token(只加载启用的认证用户 Token) + token_records = await dao.get_tokens_by_provider(provider, enabled_only=True) + + # 转换为 TokenPool 所需格式 + tokens = [] + if token_records: + tokens = [ + (record["id"], record["token"], record.get("token_type", "unknown")) + for record in token_records + ] + + # 过滤掉 guest token(不应该在数据库中,但防御性检查) + user_tokens = [ + (tid, tval, ttype) for tid, tval, ttype in tokens + if ttype != "guest" + ] + + if len(user_tokens) < len(tokens): + guest_count = len(tokens) - len(user_tokens) + logger.warning(f"⚠️ 过滤了 {guest_count} 个匿名用户 Token") + + tokens = user_tokens + + # 始终创建 Token 池实例(即使为空) + with _pool_lock: + _token_pool = TokenPool(tokens, failure_threshold, recovery_timeout) + + if not tokens: + logger.warning(f"⚠️ {provider} 没有有效的认证用户 Token,已创建空 Token 池") + else: + logger.info(f"🔧 从数据库初始化 Token 池({provider}),共 {len(tokens)} 个 Token") + + return _token_pool + + +async def sync_token_stats_to_db(): + """ + 将内存中的 Token 统计同步到数据库 + + 应在服务关闭或定期调用,确保统计数据不丢失 + """ + pool = get_token_pool() + if not pool: + return + + from app.services.token_dao import get_token_dao + + dao = get_token_dao() + + pending_updates = [] + with pool._lock: + for token, status in pool.token_statuses.items(): + pending_success = max( + 0, + status.successful_requests - status.db_synced_successful_requests, + ) + pending_failure = max( + 0, + status.failed_requests - status.db_synced_failed_requests, + ) + if pending_success > 0 or pending_failure > 0: + pending_updates.append( + ( + token, + status.token_id, + pending_success, + pending_failure, + ) + ) + + for token, token_id, pending_success, pending_failure in pending_updates: + for _ in range(pending_success): + await dao.record_success(token_id) + for _ in range(pending_failure): + await dao.record_failure(token_id) + + with pool._lock: + if token in pool.token_statuses: + status = pool.token_statuses[token] + status.db_synced_successful_requests += pending_success + status.db_synced_failed_requests += pending_failure + + logger.info("✅ Token 统计已同步到数据库") diff --git a/app/utils/tool_call_handler.py b/app/utils/tool_call_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a82620f3de903f11198c8d6f5947237acb1bbe84 --- /dev/null +++ b/app/utils/tool_call_handler.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +工具调用处理模块 +""" + +import json +import re +from typing import Dict, List, Any, Optional, Tuple +from app.utils.logger import get_logger + +logger = get_logger() + + +def generate_tool_prompt(tools: Optional[List[Dict[str, Any]]]) -> str: + """ + 生成工具调用提示词 + 将 OpenAI tools 定义转换为 Markdown 格式的说明文档 + + Args: + tools: OpenAI 格式的工具定义列表 + + Returns: + str: Markdown 格式的工具使用说明 + """ + if not tools or len(tools) == 0: + return "" + + tool_definitions = [] + + for tool in tools: + if tool.get("type") != "function": + continue + + function_spec = tool.get("function", {}) + function_name = function_spec.get("name", "unknown") + function_description = function_spec.get("description", "") + parameters = function_spec.get("parameters", {}) + + # 创建结构化的工具定义 + tool_info = [ + f"## {function_name}", + f"**Purpose**: {function_description}" + ] + + # 添加参数详情 + parameter_properties = parameters.get("properties", {}) + required_parameters = set(parameters.get("required", [])) + + if parameter_properties: + tool_info.append("**Parameters**:") + for param_name, param_info in parameter_properties.items(): + param_type = param_info.get("type", "string") + param_desc = param_info.get("description", "") + is_required = param_name in required_parameters + required_str = " (required)" if is_required else " (optional)" + tool_info.append(f"- `{param_name}` ({param_type}){required_str}: {param_desc}") + + tool_definitions.append("\n".join(tool_info)) + + # 组合完整的提示词 + prompt = ( + "\n\n---\n" + "# Available Tools\n\n" + + "\n\n".join(tool_definitions) + + "\n\n" + "**Tool Invocation Format**:\n" + "To use a tool, include a JSON block with this structure:\n" + '{"tool_calls": [{"id": "call_ID", "type": "function", "function": {"name": "TOOL_NAME", "arguments": "JSON_STRING"}}]}\n\n' + "**Rules**:\n" + "- Use tool ONLY when user explicitly requests an action that matches a tool's purpose\n" + "- For normal conversation, respond naturally WITHOUT any tool calls\n" + "- The `arguments` must be a JSON string, not an object\n" + "- Multiple tools can be called by adding more items to the array\n" + "---\n\n" + ) + + logger.debug(f"生成工具提示词,包含 {len(tool_definitions)} 个工具定义") + return prompt + + +def process_messages_with_tools( + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]], + tool_choice: str = "auto" +) -> List[Dict[str, Any]]: + """ + 将工具定义注入到消息列表中 + + Args: + messages: 原始消息列表 + tools: 工具定义列表 + tool_choice: 工具选择策略 ("auto", "none", 等) + + Returns: + List[Dict]: 处理后的消息列表 + """ + if not tools or tool_choice == "none": + return messages + + tools_prompt = generate_tool_prompt(tools) + if not tools_prompt: + return messages + + processed = [] + has_system = any(m.get("role") == "system" for m in messages) + + if has_system: + # 如果有 system 消息,将工具提示追加到第一个 system 消息 + for msg in messages: + if msg.get("role") == "system": + new_msg = msg.copy() + content = new_msg.get("content", "") + if isinstance(content, list): + # 多模态内容 + content_str = " ".join([ + item.get("text", "") if item.get("type") == "text" else "" + for item in content + ]) + else: + content_str = str(content) + new_msg["content"] = content_str + tools_prompt + processed.append(new_msg) + else: + processed.append(msg) + else: + # 没有 system 消息,创建一个新的 system 消息 + processed.append({ + "role": "system", + "content": f"You are a helpful assistant with access to tools.{tools_prompt}" + }) + processed.extend(messages) + + logger.debug(f"工具提示已注入到消息列表,共 {len(processed)} 条消息") + return processed + + +def parse_and_extract_tool_calls(content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]: + """ + 从响应内容中提取 tool_calls JSON + + Args: + content: 模型返回的文本内容 + + Returns: + Tuple[Optional[List], str]: (提取的 tool_calls 列表, 清理后的内容) + """ + if not content or not content.strip(): + return None, content + + tool_calls = None + cleaned_content = content + + # 方法1: 尝试解析 JSON 代码块中的 tool_calls + # 匹配 ```json ... ``` 或 ```...``` + json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```' + json_blocks = re.findall(json_block_pattern, content) + + for json_str in json_blocks: + try: + parsed_data = json.loads(json_str) + if "tool_calls" in parsed_data: + tool_calls = parsed_data["tool_calls"] + if tool_calls and isinstance(tool_calls, list): + # 确保 arguments 字段是字符串 + for tc in tool_calls: + if tc.get("function"): + func = tc["function"] + if func.get("arguments"): + if isinstance(func["arguments"], dict): + # 转换对象为 JSON 字符串 + func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False) + elif not isinstance(func["arguments"], str): + func["arguments"] = str(func["arguments"]) + logger.debug(f"从 JSON 代码块中提取到 {len(tool_calls)} 个工具调用") + break + except json.JSONDecodeError: + continue + + # 方法2: 尝试从文本中直接查找 JSON 对象 + if not tool_calls: + # 查找包含 "tool_calls" 的 JSON 对象 + i = 0 + scannable_text = content + while i < len(scannable_text): + if scannable_text[i] == '{': + # 尝试找到匹配的闭合括号 + brace_count = 1 + j = i + 1 + in_string = False + escape_next = False + + while j < len(scannable_text) and brace_count > 0: + if escape_next: + escape_next = False + elif scannable_text[j] == '\\': + escape_next = True + elif scannable_text[j] == '"': + in_string = not in_string + elif not in_string: + if scannable_text[j] == '{': + brace_count += 1 + elif scannable_text[j] == '}': + brace_count -= 1 + j += 1 + + if brace_count == 0: + # 找到完整的 JSON 对象 + json_candidate = scannable_text[i:j] + try: + parsed_data = json.loads(json_candidate) + if "tool_calls" in parsed_data: + tool_calls = parsed_data["tool_calls"] + if tool_calls and isinstance(tool_calls, list): + # 确保 arguments 字段是字符串 + for tc in tool_calls: + if tc.get("function"): + func = tc["function"] + if func.get("arguments"): + if isinstance(func["arguments"], dict): + func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False) + elif not isinstance(func["arguments"], str): + func["arguments"] = str(func["arguments"]) + logger.debug(f"从内联 JSON 中提取到 {len(tool_calls)} 个工具调用") + break + except json.JSONDecodeError: + pass + + i = j + else: + i += 1 + + # 清理内容 - 移除包含 tool_calls 的 JSON + if tool_calls: + cleaned_content = remove_tool_json_content(content) + + return tool_calls, cleaned_content + + +def remove_tool_json_content(content: str) -> str: + """ + 从响应内容中移除工具调用 JSON + + Args: + content: 原始响应内容 + + Returns: + str: 清理后的内容 + """ + if not content: + return content + + # 步骤1: 移除 JSON 代码块中包含 tool_calls 的部分 + cleaned_text = content + + # 匹配 ```json ... ``` 或 ```...``` + def replace_json_block(match): + json_content = match.group(1) + try: + parsed_data = json.loads(json_content) + if "tool_calls" in parsed_data: + return "" # 移除整个代码块 + except json.JSONDecodeError: + pass + return match.group(0) # 保留原文 + + json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```' + cleaned_text = re.sub(json_block_pattern, replace_json_block, cleaned_text) + + # 步骤2: 移除内联的 tool JSON - 使用括号平衡方法 + result = [] + i = 0 + + while i < len(cleaned_text): + if cleaned_text[i] == '{': + # 尝试找到匹配的闭合括号 + brace_count = 1 + j = i + 1 + in_string = False + escape_next = False + + while j < len(cleaned_text) and brace_count > 0: + if escape_next: + escape_next = False + elif cleaned_text[j] == '\\': + escape_next = True + elif cleaned_text[j] == '"': + in_string = not in_string + elif not in_string: + if cleaned_text[j] == '{': + brace_count += 1 + elif cleaned_text[j] == '}': + brace_count -= 1 + j += 1 + + if brace_count == 0: + # 找到完整的 JSON 对象,检查是否包含 tool_calls + json_candidate = cleaned_text[i:j] + try: + parsed = json.loads(json_candidate) + if "tool_calls" in parsed: + # 这是一个工具调用,跳过它 + i = j + continue + except json.JSONDecodeError: + pass + + # 不是工具调用或无法解析,保留这个字符 + result.append(cleaned_text[i]) + i += 1 + else: + result.append(cleaned_text[i]) + i += 1 + + cleaned_result = "".join(result).strip() + + # 移除多余的空白行 + cleaned_result = re.sub(r'\n{3,}', '\n\n', cleaned_result) + + logger.debug(f"内容清理完成,原始长度: {len(content)}, 清理后长度: {len(cleaned_result)}") + return cleaned_result + + +def content_to_string(content: Any) -> str: + """ + 将消息内容转换为字符串 + + Args: + content: 消息内容,可能是字符串或列表(多模态) + + Returns: + str: 字符串格式的内容 + """ + if isinstance(content, str): + return content + elif isinstance(content, list): + # 多模态内容,提取文本部分 + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif isinstance(item, str): + text_parts.append(item) + return " ".join(text_parts) + else: + return str(content) diff --git a/app/utils/user_agent.py b/app/utils/user_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6bbbe0df929ca66cbabb881e13d403f57d1866 --- /dev/null +++ b/app/utils/user_agent.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +用户代理工具模块 +提供动态随机用户代理生成功能 +""" + +import random +from typing import Dict, Optional +from fake_useragent import UserAgent + +# 全局 UserAgent 实例(单例模式) +_user_agent_instance: Optional[UserAgent] = None + + +def get_user_agent_instance() -> UserAgent: + """获取或创建 UserAgent 实例(单例模式)""" + global _user_agent_instance + if _user_agent_instance is None: + _user_agent_instance = UserAgent() + return _user_agent_instance + + +def get_random_user_agent(browser_type: Optional[str] = None) -> str: + """ + 获取随机用户代理字符串 + + Args: + browser_type: 指定浏览器类型 ('chrome', 'firefox', 'safari', 'edge') + 如果为 None,则随机选择 + + Returns: + str: 用户代理字符串 + """ + ua = get_user_agent_instance() + + # 如果没有指定浏览器类型,随机选择一个(偏向 Chrome 和 Edge) + if browser_type is None: + browser_choices = ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"] + browser_type = random.choice(browser_choices) + + # 根据浏览器类型获取用户代理 + if browser_type == "chrome": + user_agent = ua.chrome + elif browser_type == "edge": + user_agent = ua.edge + elif browser_type == "firefox": + user_agent = ua.firefox + elif browser_type == "safari": + user_agent = ua.safari + else: + user_agent = ua.random + + return user_agent + + +# 通用 UserAgent headers 生成函数 +def get_dynamic_headers( + referer: Optional[str] = None, + origin: Optional[str] = None, + browser_type: Optional[str] = None, + additional_headers: Optional[Dict[str, str]] = None +) -> Dict[str, str]: + """ + 生成动态浏览器 headers,包含随机 User-Agent + + Args: + referer: 引用页面 URL + origin: 源站 URL + browser_type: 指定浏览器类型 + additional_headers: 额外的 headers + + Returns: + Dict[str, str]: 包含动态 User-Agent 的 headers + """ + user_agent = get_random_user_agent(browser_type) + + # 基础 headers + headers = { + "User-Agent": user_agent, + "Accept": "application/json, text/event-stream", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Accept-Encoding": "gzip, deflate, br", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Pragma": "no-cache", + } + + # 添加可选的 headers + if referer: + headers["Referer"] = referer + + if origin: + headers["Origin"] = origin + + # 根据用户代理添加浏览器特定的 headers + if "Chrome/" in user_agent or "Edg/" in user_agent: + # Chrome/Edge 特定的 headers + chrome_version = "139" + edge_version = "139" + + try: + if "Chrome/" in user_agent: + chrome_version = user_agent.split("Chrome/")[1].split(".")[0] + except: + pass + + try: + if "Edg/" in user_agent: + edge_version = user_agent.split("Edg/")[1].split(".")[0] + sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"' + else: + sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"' + except: + sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"' + + headers.update({ + "sec-ch-ua": sec_ch_ua, + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"Windows"', + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-origin", + }) + + # 添加额外的 headers + if additional_headers: + headers.update(additional_headers) + + return headers + + diff --git a/deploy/.dockerignore b/deploy/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..5418af614bcf80c659dbbd9c383dd8b2961ba61f --- /dev/null +++ b/deploy/.dockerignore @@ -0,0 +1,54 @@ +# Git files +.git +.gitignore +.gitattributes + +# Python cache +__pycache__ +*.py[cod] +*$py.class +*.so +.Python + +# Virtual environments +venv/ +env/ +ENV/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Documentation +*.md +!README.md +docs/ + +# Test files +tests/ +pytest.ini +.pytest_cache/ + +# Local data (will be mounted as volumes) +*.db +*.sqlite +*.sqlite3 +logs/ +data/ + +# Build artifacts +build/ +dist/ +*.egg-info/ + +# Docker files in parent directory +Dockerfile +docker-compose.yml +.dockerignore + +# Other +.env.local +.DS_Store diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..50cec592a3ccbcd64416a37501afccfe36a7248a --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,35 @@ +# ============================================== +# Z.AI API Server - Docker 环境变量配置示例 +# ============================================== + +# 管理后台密码 +ADMIN_PASSWORD=admin123 + +# API 认证密钥 (用于验证客户端请求) +AUTH_TOKEN=sk-your-api-key-here + +# 是否跳过 API Key 验证 (开发环境可设为 true) +SKIP_AUTH_TOKEN=false + +# 调试日志 (生产环境建议设为 false) +DEBUG_LOGGING=true + +# 匿名模式 (允许无 token 访问,需要配合 SKIP_AUTH_TOKEN=true) +ANONYMOUS_MODE=false + +# Function Call 功能开关 (是否支持工具调用) +TOOL_SUPPORT=true + +# 工具调用扫描限制 (字符数) +SCAN_LIMIT=200000 + +# 数据库路径 (Docker 环境使用持久化卷) +DB_PATH=/app/data/tokens.db + +# Token 池配置 +TOKEN_FAILURE_THRESHOLD=3 +TOKEN_RECOVERY_TIMEOUT=300 + +# 服务配置 +SERVICE_NAME=Z.AI_API_Server +LISTEN_PORT=8080 diff --git a/deploy/Dockerfile b/deploy/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4dd5c3796b188159627088f97545b6fdf358c811 --- /dev/null +++ b/deploy/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.12-slim + +# Set working directory +WORKDIR /app + +# Create data and logs directories with proper permissions +RUN mkdir -p /app/data /app/logs && \ + chmod 755 /app/data /app/logs + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Set environment variable for database path +ENV DB_PATH=/app/data/tokens.db + +# Expose port +EXPOSE 8080 + +# Run the application +CMD ["python", "main.py"] diff --git a/deploy/NGINX_SETUP.md b/deploy/NGINX_SETUP.md new file mode 100644 index 0000000000000000000000000000000000000000..7f45afdd9b13c085d35e4c0fe7daecdd814b77bc --- /dev/null +++ b/deploy/NGINX_SETUP.md @@ -0,0 +1,278 @@ +# Nginx 反向代理部署指南 + +本文档说明如何在 Nginx 反向代理后部署 Z.AI2API,支持自定义路径前缀。 + +## 问题说明 + +在使用 Nginx 反向代理时,如果需要将服务部署在自定义路径前缀下(例如 `http://domain.com/ai2api`), +需要正确配置 `ROOT_PATH` 环境变量,否则会出现以下问题: + +- 后台管理页面跳转错误(缺少路径前缀) +- API 接口请求 404(路径不完整) +- 静态资源加载失败 + +## 解决方案 + +### 1. 配置环境变量 + +在 `.env` 文件中设置 `ROOT_PATH` 变量,值为 Nginx 配置的 location 路径: + +```bash +# 示例:部署在 /ai2api 路径下 +ROOT_PATH=/ai2api +``` + +**重要**: `ROOT_PATH` 必须与 Nginx 配置中的 `location` 路径完全一致。 + +### 2. 配置 Nginx + +参考 `deploy/nginx.conf.example` 文件,选择合适的配置模板。 + +#### 基础配置示例 + +```nginx +server { + listen 80; + server_name your-domain.com; + + location /ai2api { + # 代理到后端服务 + proxy_pass http://127.0.0.1:8080; + + # 传递原始请求信息 + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # SSE 流式响应支持 + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_buffering off; + proxy_cache off; + + # 超时设置 + proxy_read_timeout 300s; + } +} +``` + +### 3. Docker Compose 配置 + +如果使用 Docker 部署,需要在 `docker-compose.yml` 中添加 `ROOT_PATH` 环境变量: + +```yaml +version: '3.8' +services: + ai2api: + image: z-ai2api:latest + environment: + - ROOT_PATH=/ai2api + - LISTEN_PORT=8080 + # ... 其他环境变量 + ports: + - "8080:8080" +``` + +### 4. 重启服务 + +```bash +# 重载 Nginx 配置 +sudo nginx -t +sudo systemctl reload nginx + +# 重启应用(Docker) +docker-compose restart + +# 或重启应用(直接运行) +# 停止服务后重新启动 +``` + +## 访问地址 + +配置完成后,服务访问地址如下: + +- **API 端点**: `http://your-domain.com/ai2api/v1/chat/completions` +- **模型列表**: `http://your-domain.com/ai2api/v1/models` +- **管理后台**: `http://your-domain.com/ai2api/admin/login` +- **根路径**: `http://your-domain.com/ai2api/` + +## 配置示例 + +### 示例 1: 部署在 /api 路径下 + +**.env 配置**: +```bash +ROOT_PATH=/api +``` + +**Nginx 配置**: +```nginx +location /api { + proxy_pass http://127.0.0.1:8080; + # ... 其他配置 +} +``` + +**访问地址**: `http://domain.com/api/admin/login` + +### 示例 2: 部署在根路径(无前缀) + +**.env 配置**: +```bash +ROOT_PATH= +``` + +**Nginx 配置**: +```nginx +location / { + proxy_pass http://127.0.0.1:8080; + # ... 其他配置 +} +``` + +**访问地址**: `http://domain.com/admin/login` + +### 示例 3: 多级路径前缀 + +**.env 配置**: +```bash +ROOT_PATH=/services/ai/chat +``` + +**Nginx 配置**: +```nginx +location /services/ai/chat { + proxy_pass http://127.0.0.1:8080; + # ... 其他配置 +} +``` + +**访问地址**: `http://domain.com/services/ai/chat/admin/login` + +## 常见问题排查 + +### 1. 404 错误 + +**现象**: 访问页面或 API 时返回 404 + +**可能原因**: +- `ROOT_PATH` 配置与 Nginx location 路径不匹配 +- Nginx 配置错误或未重载 + +**解决方法**: +- 检查 `.env` 中的 `ROOT_PATH` 是否与 Nginx `location` 完全一致 +- 确认 Nginx 配置无误: `sudo nginx -t` +- 重载 Nginx: `sudo systemctl reload nginx` +- 重启应用服务 + +### 2. 静态资源加载失败 + +**现象**: 管理后台页面样式错乱,控制台显示 CSS/JS 404 + +**可能原因**: +- `ROOT_PATH` 未配置或配置错误 +- 静态文件路径未包含前缀 + +**解决方法**: +- 确保 `ROOT_PATH` 正确配置并重启服务 +- 检查浏览器开发者工具中的资源请求路径 + +### 3. 流式响应中断 + +**现象**: SSE 流式响应提前终止或无法正常工作 + +**可能原因**: +- Nginx 启用了缓冲 +- 超时时间设置过短 + +**解决方法**: +在 Nginx 配置中添加: +```nginx +proxy_buffering off; +proxy_cache off; +proxy_read_timeout 300s; +``` + +### 4. CORS 错误 + +**现象**: 浏览器控制台显示跨域请求被阻止 + +**可能原因**: +- Nginx 未正确传递请求头 + +**解决方法**: +确保 Nginx 配置中包含: +```nginx +proxy_set_header Host $host; +proxy_set_header X-Forwarded-Proto $scheme; +``` + +## 验证配置 + +配置完成后,可以通过以下方式验证: + +1. **访问健康检查端点**: + ```bash + curl http://your-domain.com/ai2api/v1/models + ``` + +2. **访问管理后台**: + 在浏览器打开 `http://your-domain.com/ai2api/admin/login` + +3. **测试 API 请求**: + ```bash + curl -X POST http://your-domain.com/ai2api/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-api-key" \ + -d '{ + "model": "GLM-4.6", + "messages": [{"role": "user", "content": "Hello"}], + "stream": false + }' + ``` + +## 进阶配置 + +### HTTPS 配置 + +```nginx +server { + listen 443 ssl http2; + server_name your-domain.com; + + ssl_certificate /path/to/cert.pem; + ssl_certificate_key /path/to/key.pem; + + location /ai2api { + proxy_pass http://127.0.0.1:8080; + proxy_set_header X-Forwarded-Proto https; + # ... 其他配置 + } +} +``` + +### 负载均衡 + +```nginx +upstream ai2api_backend { + server 127.0.0.1:8080; + server 127.0.0.1:8081; + server 127.0.0.1:8082; +} + +server { + listen 80; + location /ai2api { + proxy_pass http://ai2api_backend; + # ... 其他配置 + } +} +``` + +## 参考资料 + +- [FastAPI Behind a Proxy](https://fastapi.tiangolo.com/advanced/behind-a-proxy/) +- [Nginx Proxy Module](http://nginx.org/en/docs/http/ngx_http_proxy_module.html) +- 完整配置示例: `deploy/nginx.conf.example` diff --git a/deploy/README_DOCKER.md b/deploy/README_DOCKER.md new file mode 100644 index 0000000000000000000000000000000000000000..a739b4a468b23d79dcebe9c7eb9a33a3eb78c4a4 --- /dev/null +++ b/deploy/README_DOCKER.md @@ -0,0 +1,357 @@ +# Docker 部署文档 + +## 快速部署 + +### 方式一: 使用预构建镜像 (推荐) + +从 Docker Hub 拉取镜像: + +```bash +# 拉取最新镜像 +docker pull zyphrzero/z-ai2api-python:latest + +# 创建数据目录 +mkdir -p data logs + +# 快速启动 +docker run -d \ + --name z-ai-api-server \ + -p 8080:8080 \ + -e ADMIN_PASSWORD=admin123 \ + -e AUTH_TOKEN=sk-your-api-key \ + -e ANONYMOUS_MODE=true \ + -e DB_PATH=/app/data/tokens.db \ + -v $(pwd)/data:/app/data \ + -v $(pwd)/logs:/app/logs \ + --restart unless-stopped \ + zyphrzero/z-ai2api-python:latest +``` + +**优势**: +- ✅ 无需本地构建,节省时间 +- ✅ GitHub Actions 自动化构建,保证质量 +- ✅ 多架构支持 (amd64/arm64) +- ✅ 镜像已优化,体积更小 + +### 方式二: 使用本地构建 + +适用于需要自定义修改代码的场景: + +```bash +# 进入部署目录 +cd deploy + +# 启动服务 (会自动构建镜像) +docker compose up -d + +# 查看日志 +docker compose logs -f api-server +``` + +服务将在 `http://localhost:8080` 启动。 + +## 架构说明 + +### 持久化存储 + +容器使用卷映射实现数据持久化: + +```yaml +volumes: + - ./data:/app/data # 数据库存储 (tokens.db) + - ./logs:/app/logs # 应用日志 +``` + +**目录结构**: +``` +deploy/ +├── data/ +│ └── tokens.db # SQLite 数据库 (自动创建) +├── logs/ # 应用日志 (自动创建) +├── docker-compose.yml +├── Dockerfile +└── README_DOCKER.md +``` + +### 环境变量 + +核心配置参数 (在 `docker-compose.yml` 中设置): + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| `DB_PATH` | `/app/data/tokens.db` | 数据库文件路径 | +| `ADMIN_PASSWORD` | `admin123` | 管理后台密码 | +| `AUTH_TOKEN` | `sk-your-api-key` | API 认证密钥 | +| `SKIP_AUTH_TOKEN` | `false` | 跳过 API 验证 | +| `ANONYMOUS_MODE` | `true` | 匿名访问模式 | +| `DEBUG_LOGGING` | `true` | 调试日志开关 | +| `TOOL_SUPPORT` | `true` | Function Call 支持 | + +**生产环境建议**: +- 修改 `ADMIN_PASSWORD` 和 `AUTH_TOKEN` +- 设置 `DEBUG_LOGGING=false` +- 设置 `ANONYMOUS_MODE=false` + +## 运维操作 + +### 服务管理 + +```bash +# 启动服务 +docker compose up -d + +# 停止服务 +docker compose down + +# 重启服务 +docker compose restart + +# 查看状态 +docker compose ps + +# 实时日志 +docker compose logs -f +``` + +### 更新应用 + +**使用预构建镜像**: + +```bash +# 停止当前容器 +docker compose down + +# 拉取最新镜像 +docker pull zyphrzero/z-ai2api-python:latest + +# 启动新版本 (数据会自动保留) +docker compose up -d + +# 清理旧镜像 +docker image prune -f +``` + +**使用本地构建**: + +```bash +# 拉取最新代码 +git pull + +# 重新构建并启动 (数据会保留) +docker compose up -d --build + +# 清理旧镜像 +docker image prune -f +``` + +### 数据备份与恢复 + +**备份**: +```bash +# 备份数据库 +cp ./data/tokens.db ./data/tokens.db.backup.$(date +%Y%m%d_%H%M%S) + +# 完整备份 +tar -czf backup_$(date +%Y%m%d_%H%M%S).tar.gz ./data ./logs +``` + +**恢复**: +```bash +# 停止服务 +docker compose down + +# 恢复数据库 +cp ./data/tokens.db.backup.20250116_120000 ./data/tokens.db + +# 启动服务 +docker compose up -d +``` + +### 数据库迁移 + +如需从其他位置迁移现有数据库: + +```bash +# 使用迁移脚本 +./migrate_db.sh /path/to/existing/tokens.db + +# 或手动复制 +cp /opt/1panel/docker/compose/k2think/tokens.db ./data/ +chmod 644 ./data/tokens.db + +# 启动服务 +docker compose up -d +``` + +## 故障排查 + +### 数据库初始化失败 + +**错误**: `unable to open database file` + +**原因**: 目录权限或卷映射问题 + +**解决**: +```bash +# 停止容器 +docker compose down + +# 确保目录存在 +mkdir -p ./data ./logs + +# 设置权限 +chmod 755 ./data ./logs + +# 重新构建并启动 +docker compose up -d --build +``` + +### 容器无法启动 + +**检查步骤**: +```bash +# 查看详细日志 +docker compose logs api-server + +# 检查容器状态 +docker compose ps + +# 验证配置文件 +docker compose config +``` + +### 端口冲突 + +如端口 8080 被占用,修改 `docker-compose.yml`: +```yaml +ports: + - "8081:8080" # 映射到宿主机 8081 端口 +``` + +### 健康检查失败 + +```bash +# 检查健康状态 +docker compose ps + +# 手动测试接口 +curl http://localhost:8080/v1/models + +# 进入容器排查 +docker exec -it z-ai-api-server bash +``` + +## API 访问 + +| 端点 | 地址 | 说明 | +|------|------|------| +| API 根路径 | `http://localhost:8080` | OpenAI 兼容 API | +| 模型列表 | `http://localhost:8080/v1/models` | 获取可用模型 | +| 管理后台 | `http://localhost:8080/admin` | Web 管理界面 | +| API 文档 | `http://localhost:8080/docs` | OpenAPI/Swagger 文档 | +| 健康检查 | `http://localhost:8080/v1/models` | 服务健康状态 | + +## 高级配置 + +### 自定义数据库路径 + +修改 `docker-compose.yml` 使用外部路径: + +```yaml +volumes: + - /opt/mydata:/app/data # 使用绝对路径 + +environment: + - DB_PATH=/app/data/tokens.db +``` + +### 使用 .env 文件 + +创建 `.env` 文件 (基于 `.env.example`): + +```bash +cp .env.example .env +# 编辑配置 +vim .env +``` + +修改 `docker-compose.yml`: +```yaml +services: + api-server: + env_file: .env +``` + +### 启用日志轮转 + +在生产环境配置 Docker 日志驱动: + +```yaml +services: + api-server: + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" +``` + +### 资源限制 + +限制容器资源使用: + +```yaml +services: + api-server: + deploy: + resources: + limits: + cpus: '2' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M +``` + +## 监控与日志 + +### 查看日志 + +```bash +# 实时日志 +docker compose logs -f + +# 最近100行 +docker compose logs --tail=100 + +# 特定时间段 +docker compose logs --since 30m + +# 导出日志 +docker compose logs > app.log +``` + +### 容器指标 + +```bash +# 资源使用情况 +docker stats z-ai-api-server + +# 容器详情 +docker inspect z-ai-api-server +``` + +## 安全建议 + +1. **修改默认密码**: 更改 `ADMIN_PASSWORD` 和 `AUTH_TOKEN` +2. **限制网络访问**: 生产环境使用反向代理 (Nginx/Caddy) +3. **启用 HTTPS**: 配置 SSL 证书 +4. **定期备份**: 自动化数据库备份任务 +5. **日志审计**: 定期检查 `request_logs` 表 +6. **最小权限**: 避免以 root 运行容器 + +## 参考资料 + +- [Docker Compose 文档](https://docs.docker.com/compose/) +- [项目主 README](../README.md) +- [配置示例](.env.example) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..7125f4c9b750ee8c81f3cb36551e366327fb8aa0 --- /dev/null +++ b/deploy/docker-compose.yml @@ -0,0 +1,35 @@ +services: + api-server: + build: + context: .. + dockerfile: deploy/Dockerfile + container_name: z-ai-api-server + ports: + - "8080:8080" + volumes: + # 数据库持久化存储 + - ./data:/app/data + # 日志持久化存储(可选) + - ./logs:/app/logs + environment: + - ADMIN_PASSWORD=admin123 + # Auth Configuration + - AUTH_TOKEN=sk-your-api-key + # 是否跳过api key验证 + - SKIP_AUTH_TOKEN=false + # 调试日志 + - DEBUG_LOGGING=true + # 匿名模式 + - ANONYMOUS_MODE=true + # Function Call 功能开关 + - TOOL_SUPPORT=true + # 工具调用扫描限制(字符数) + - SCAN_LIMIT=200000 + # 数据库路径 - 使用持久化卷 + - DB_PATH=/app/data/tokens.db + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/v1/models"] + interval: 30s + timeout: 10s + retries: 3 diff --git a/deploy/nginx.conf.example b/deploy/nginx.conf.example new file mode 100644 index 0000000000000000000000000000000000000000..a163e32a2c17a5f7b29c7d41da1d17bcf2332e6a --- /dev/null +++ b/deploy/nginx.conf.example @@ -0,0 +1,157 @@ +# Nginx reverse proxy configuration example for Z.AI2API +# This example shows how to deploy the service behind Nginx with a custom path prefix + +# Example 1: Deploy at http://your-domain.com/ai2api +server { + listen 80; + server_name your-domain.com; + + # Forward requests with /ai2api prefix to the backend service + location /ai2api { + # Remove trailing slash redirect (optional, but recommended) + rewrite ^(/ai2api)$ $1/ permanent; + + # Proxy to the backend service + proxy_pass http://127.0.0.1:8080; + + # Pass original host and IP information + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # IMPORTANT: Tell the backend about the path prefix + # This ensures all generated URLs include the prefix + proxy_set_header X-Forwarded-Prefix /ai2api; + + # WebSocket and SSE support (for streaming responses) + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + + # Disable buffering for streaming responses + proxy_buffering off; + proxy_cache off; + + # Timeout settings (adjust as needed) + proxy_connect_timeout 60s; + proxy_send_timeout 300s; + proxy_read_timeout 300s; + } +} + +# Example 2: Deploy at http://your-domain.com/api/chat +server { + listen 80; + server_name example.com; + + location /api/chat { + # Proxy configuration + proxy_pass http://127.0.0.1:8080; + + # Headers + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Forwarded-Prefix /api/chat; + + # SSE/WebSocket support + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_buffering off; + proxy_cache off; + } +} + +# Example 3: Deploy with SSL (HTTPS) +server { + listen 443 ssl http2; + server_name secure.example.com; + + # SSL configuration + ssl_certificate /path/to/cert.pem; + ssl_certificate_key /path/to/key.pem; + + location /ai2api { + proxy_pass http://127.0.0.1:8080; + + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto https; + proxy_set_header X-Forwarded-Prefix /ai2api; + + # SSE/WebSocket support + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_buffering off; + proxy_cache off; + + # Security headers (optional) + add_header X-Content-Type-Options nosniff; + add_header X-Frame-Options DENY; + add_header X-XSS-Protection "1; mode=block"; + } +} + +# Example 4: Load balancing with multiple backend instances +upstream ai2api_backend { + # Round-robin by default + server 127.0.0.1:8080; + server 127.0.0.1:8081; + server 127.0.0.1:8082; + + # Or use least connections + # least_conn; + + # Or use IP hash for session persistence + # ip_hash; +} + +server { + listen 80; + server_name loadbalanced.example.com; + + location /ai2api { + proxy_pass http://ai2api_backend; + + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Forwarded-Prefix /ai2api; + + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_buffering off; + proxy_cache off; + } +} + +# Important Notes: +# +# 1. Set ROOT_PATH in your .env file to match the Nginx location path: +# ROOT_PATH=/ai2api +# +# 2. Restart both Nginx and the application after configuration changes: +# sudo systemctl reload nginx +# docker-compose restart (or restart your application) +# +# 3. Access URLs will include the prefix: +# - Admin panel: http://your-domain.com/ai2api/admin/login +# - API endpoint: http://your-domain.com/ai2api/v1/chat/completions +# - Health check: http://your-domain.com/ai2api/v1/models +# +# 4. For Docker deployments, make sure to: +# - Add ROOT_PATH to docker-compose.yml environment variables +# - Expose the container port (8080 by default) +# +# 5. Common issues: +# - 404 errors: Check that ROOT_PATH matches the Nginx location path exactly +# - CORS errors: Verify proxy headers are set correctly +# - Streaming not working: Ensure proxy_buffering is off +# - Admin panel CSS/JS not loading: Confirm static files are served with the prefix diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..49392464d68af1752d4731fbfb8c7123bce61ee2 --- /dev/null +++ b/main.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import sys +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from granian import Granian + +from app.admin import api as admin_api +from app.admin import routes as admin_routes +from app.core import claude, openai +from app.core.config import settings +from app.core.upstream import UpstreamClient +from app.utils.logger import setup_logger +from app.utils.reload_config import RELOAD_CONFIG + +# Setup logger +logger = setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING) + + +async def warmup_upstream_client(): + """可选预热上游适配器,提前初始化动态依赖。""" + try: + client = UpstreamClient() + logger.info( + f"✅ 上游适配器已就绪,支持 {len(client.get_supported_models())} 个模型" + ) + except Exception as exc: + logger.warning(f"⚠️ 上游适配器预热失败: {exc}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 初始化 Token 数据库 + from app.services.request_log_dao import init_request_log_dao + from app.services.token_automation import ( + run_directory_import, + start_token_automation_scheduler, + stop_token_automation_scheduler, + ) + from app.services.token_dao import init_token_database + + await init_token_database() + init_request_log_dao() + + if ( + settings.TOKEN_AUTO_IMPORT_ENABLED + and settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip() + ): + try: + await run_directory_import( + settings.TOKEN_AUTO_IMPORT_SOURCE_DIR, + provider="zai", + ) + logger.info("✅ 启动阶段已完成一次目录自动导入") + except Exception as exc: + logger.warning(f"⚠️ 启动阶段目录自动导入失败: {exc}") + + # 从数据库初始化认证 token 池 + from app.utils.token_pool import initialize_token_pool_from_db + + token_pool = await initialize_token_pool_from_db( + provider="zai", + failure_threshold=settings.TOKEN_FAILURE_THRESHOLD, + recovery_timeout=settings.TOKEN_RECOVERY_TIMEOUT, + ) + + if not token_pool and not settings.ANONYMOUS_MODE: + logger.warning( + "⚠️ 未找到可用 Token 且未启用匿名模式,服务可能无法正常工作" + ) + + if settings.ANONYMOUS_MODE: + from app.utils.guest_session_pool import initialize_guest_session_pool + + guest_pool = await initialize_guest_session_pool( + pool_size=settings.GUEST_POOL_SIZE, + ) + guest_status = guest_pool.get_pool_status() + logger.info( + "🫥 匿名会话池已就绪: " + f"{guest_status.get('valid_sessions', 0)} 个可用会话" + ) + + await warmup_upstream_client() + await start_token_automation_scheduler() + + yield + + logger.info("🔄 应用正在关闭...") + + await stop_token_automation_scheduler() + + if settings.ANONYMOUS_MODE: + from app.utils.guest_session_pool import close_guest_session_pool + + await close_guest_session_pool() + + +# Create FastAPI app with lifespan +# root_path is used for reverse proxy path prefix (e.g., /api or /path-prefix) +app = FastAPI(lifespan=lifespan, root_path=settings.ROOT_PATH) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization"], +) + +# 挂载web端静态文件目录 +try: + app.mount("/static", StaticFiles(directory="app/static"), name="static") +except RuntimeError: + # 如果 static 目录不存在,创建它 + os.makedirs("app/static/css", exist_ok=True) + os.makedirs("app/static/js", exist_ok=True) + app.mount("/static", StaticFiles(directory="app/static"), name="static") + +# Include API routers +app.include_router(openai.router) +app.include_router(claude.router) + +# Include admin routers +app.include_router(admin_routes.router) +app.include_router(admin_api.router) + + +@app.options("/") +async def handle_options(): + """Handle OPTIONS requests""" + return Response(status_code=200) + + +@app.get("/") +async def root(): + """Root endpoint""" + return {"message": "OpenAI Compatible API Server"} + + +def run_server(): + service_name = settings.SERVICE_NAME + + logger.info(f"🚀 启动 {service_name} 服务...") + logger.info(f"📡 监听地址: 0.0.0.0:{settings.LISTEN_PORT}") + logger.info(f"🔧 调试模式: {'开启' if settings.DEBUG_LOGGING else '关闭'}") + logger.info(f"🔐 匿名模式: {'开启' if settings.ANONYMOUS_MODE else '关闭'}") + + try: + Granian( + "main:app", + interface="asgi", + address="0.0.0.0", + port=settings.LISTEN_PORT, + reload=False, # 生产环境请关闭热重载 + process_name=service_name, # 设置进程名称 + **RELOAD_CONFIG, # 热重载配置 + ).serve() + except KeyboardInterrupt: + logger.info("🛑 收到中断信号,正在关闭服务...") + except Exception as e: + logger.error(f"❌ 服务启动失败: {e}") + sys.exit(1) + + +if __name__ == "__main__": + run_server() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..a68ee5538c15fecf8723604d6bc9cb34c317dec0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,71 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "z-ai2api-python" +version = "0.1.0" +description = "一个为 Z.ai 提供 OpenAI 兼容接口的 Python 代理服务" +readme = "README.md" +requires-python = ">=3.9,<=3.12" +license = { text = "MIT" } +authors = [{ name = "Contributors" }] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Internet :: WWW/HTTP :: HTTP Servers", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = [ + "fastapi==0.116.1", + "granian[reload,pname]==2.5.2", + "httpx[http2,socks]==0.28.1", + "pydantic==2.11.7", + "pydantic-settings==2.10.1", + "pydantic-core==2.33.2", + "typing-inspection==0.4.1", + "fake-useragent==2.2.0", + "loguru==0.7.3", + "psutil>=7.0.0", + "json-repair==0.44.1", + "jinja2==3.1.4", + "aiosqlite==0.20.0", + "python-multipart==0.0.12", + "python-dotenv==1.0.1" +] + +[project.scripts] +z-ai2api = "main:app" + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.uv] +dev-dependencies = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "requests>=2.30.0", + "ruff>=0.1.0", +] + +[tool.ruff] +line-length = 88 +target-version = "py38" +select = ["E", "F", "I", "B"] +ignore = [] + +[tool.ruff.isort] +known-first-party = [] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] diff --git a/tests/real_upstream_test_utils.py b/tests/real_upstream_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..67a64d74a340d0c423ec04b79a0853c10aaa2a04 --- /dev/null +++ b/tests/real_upstream_test_utils.py @@ -0,0 +1,65 @@ +import os +from typing import Any + +import pytest + +from app.core import upstream as upstream_module +from app.core.upstream import UpstreamClient, _extract_user_id_from_token + +REAL_AUTH_TOKEN_ENV = "REAL_AUTH_TOKEN_ENV" +RED_2X2_PNG_DATA_URL = ( + "data:image/png;base64," + "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAEElEQVR42mP4z8AARAwQCgAf7gP9Y167WwAAAABJRU5ErkJggg==" +) + +def install_real_auth(monkeypatch) -> str: + token = os.getenv(REAL_AUTH_TOKEN_ENV, "").strip() + if not token: + pytest.skip(f"需要设置环境变量 {REAL_AUTH_TOKEN_ENV}") + + user_id = _extract_user_id_from_token(token) + if not user_id or user_id == "guest": + raise AssertionError(f"{REAL_AUTH_TOKEN_ENV} 不是可解析的认证 token") + + async def fake_get_auth_info( + self, + excluded_tokens=None, + excluded_guest_user_ids=None, + ): + return { + "token": token, + "user_id": user_id, + "username": "RealUser", + "auth_mode": "authenticated", + "token_source": "env", + "guest_user_id": None, + } + + monkeypatch.setattr(UpstreamClient, "get_auth_info", fake_get_auth_info) + monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None) + monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: None) + return token + + +def install_real_anonymous(monkeypatch) -> None: + monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None) + monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: None) + monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True) + + +def extract_content(payload: dict[str, Any]) -> str: + assert isinstance(payload, dict), payload + assert "error" not in payload, payload + + choices = payload.get("choices") or [] + assert choices, payload + + message = choices[0].get("message") or {} + content = str(message.get("content") or "").strip() + assert content, payload + return content + + +def assert_usage_present(payload: dict[str, Any]) -> None: + usage = payload.get("usage") or {} + assert int(usage.get("total_tokens") or 0) > 0, payload diff --git a/tests/test_admin_config.py b/tests/test_admin_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b49ddc85ac4c3f32ce9bd07d5cbf842030ebcd6e --- /dev/null +++ b/tests/test_admin_config.py @@ -0,0 +1,231 @@ +from types import SimpleNamespace +from urllib.parse import urlencode + +import pytest +from jinja2 import Environment, FileSystemLoader +from starlette.requests import Request + +from app.admin import api as admin_api +from app.admin.config_manager import ( + CONFIG_FIELD_SPECS, + build_config_page_data, + save_form_config, + save_source_config, + validate_env_source, +) + + +def _build_form_payload(**overrides): + payload = {} + + for key, field in CONFIG_FIELD_SPECS.items(): + value = overrides[key] if key in overrides else field.default_value + if field.value_type == "bool": + if value: + payload[key] = "on" + continue + payload[key] = "" if value is None else str(value) + + return payload + + +def _make_form_request(path: str, data: dict[str, str]) -> Request: + body = urlencode(data, doseq=True).encode() + sent = False + + async def receive(): + nonlocal sent + if sent: + return {"type": "http.request", "body": b"", "more_body": False} + sent = True + return {"type": "http.request", "body": body, "more_body": False} + + scope = { + "type": "http", + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": b"", + "headers": [ + ( + b"content-type", + b"application/x-www-form-urlencoded", + ) + ], + "client": ("testclient", 50000), + "server": ("testserver", 80), + } + return Request(scope, receive) + + +@pytest.mark.asyncio +async def test_build_config_page_data_includes_sections_and_override_status( + tmp_path, +): + env_path = tmp_path / ".env" + example_path = tmp_path / ".env.example" + env_path.write_text( + "API_ENDPOINT=https://example.com/v1/chat\nDEBUG_LOGGING=true\n", + encoding="utf-8", + ) + example_path.write_text("SERVICE_NAME=example\n", encoding="utf-8") + + settings_stub = SimpleNamespace( + API_ENDPOINT="https://example.com/v1/chat", + DEBUG_LOGGING=True, + GLM5_MODEL="GLM-5", + ADMIN_PASSWORD="secret", + ) + + page_data = build_config_page_data( + settings_obj=settings_stub, + env_path=env_path, + env_example_path=example_path, + ) + + assert page_data["overview"]["total_sections"] >= 7 + assert page_data["overview"]["total_fields"] >= 35 + assert page_data["overview"]["overridden_fields"] == 2 + assert page_data["overview"]["example_exists"] is True + + field_map = { + field["key"]: field + for section in page_data["sections"] + for field in section["fields"] + } + + assert field_map["API_ENDPOINT"]["source_label"] == ".env" + assert field_map["DEBUG_LOGGING"]["source_label"] == ".env" + assert field_map["GLM5_MODEL"]["source_label"] == "默认值" + assert field_map["ADMIN_PASSWORD"]["sensitive"] is True + + +@pytest.mark.asyncio +async def test_save_form_config_preserves_unmanaged_lines_and_updates_fields( + tmp_path, +): + env_path = tmp_path / ".env" + env_path.write_text( + "CUSTOM_FLAG=keep\nSERVICE_NAME=old-service\n", + encoding="utf-8", + ) + + reloaded = False + + async def fake_reload(): + nonlocal reloaded + reloaded = True + + payload = _build_form_payload( + SERVICE_NAME="new-service", + LISTEN_PORT=9090, + ROOT_PATH="/edge", + DEBUG_LOGGING=False, + TOKEN_AUTO_IMPORT_ENABLED=True, + TOKEN_AUTO_IMPORT_SOURCE_DIR="/srv/tokens", + HTTP_PROXY="http://127.0.0.1:7890", + ADMIN_PASSWORD="new-admin-password", + ) + + updates = await save_form_config( + payload, + reload_callback=fake_reload, + env_path=env_path, + ) + content = env_path.read_text(encoding="utf-8") + + assert reloaded is True + assert updates["SERVICE_NAME"] == "new-service" + assert updates["LISTEN_PORT"] == 9090 + assert updates["TOKEN_AUTO_IMPORT_ENABLED"] is True + assert "CUSTOM_FLAG=keep" in content + assert "SERVICE_NAME=new-service" in content + assert "LISTEN_PORT=9090" in content + assert "ROOT_PATH=/edge" in content + assert "TOKEN_AUTO_IMPORT_ENABLED=true" in content + assert "TOKEN_AUTO_IMPORT_SOURCE_DIR=/srv/tokens" in content + assert "HTTP_PROXY=http://127.0.0.1:7890" in content + + +@pytest.mark.asyncio +async def test_save_source_config_rolls_back_file_when_reload_fails(tmp_path): + env_path = tmp_path / ".env" + env_path.write_text("SERVICE_NAME=old-service\n", encoding="utf-8") + + async def failing_reload(): + raise RuntimeError("reload failed") + + with pytest.raises(RuntimeError, match="reload failed"): + await save_source_config( + "SERVICE_NAME=new-service\nLISTEN_PORT=8081\n", + reload_callback=failing_reload, + env_path=env_path, + ) + + assert env_path.read_text(encoding="utf-8") == "SERVICE_NAME=old-service\n" + + +@pytest.mark.asyncio +async def test_save_config_endpoint_returns_refresh_trigger(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + (tmp_path / ".env").write_text("SERVICE_NAME=before\n", encoding="utf-8") + + async def fake_reload(): + return None + + monkeypatch.setattr(admin_api, "reload_settings", fake_reload) + + request = _make_form_request( + "/admin/api/config/save", + _build_form_payload( + SERVICE_NAME="after", + LISTEN_PORT=8081, + DEBUG_LOGGING=True, + ), + ) + response = await admin_api.save_config(request) + body = response.body.decode("utf-8") + + assert response.status_code == 200 + assert response.headers["HX-Trigger"] == "admin-config-refresh" + assert "保存成功" in body + assert "SERVICE_NAME=after" in (tmp_path / ".env").read_text(encoding="utf-8") + + +@pytest.mark.asyncio +async def test_save_config_source_endpoint_rejects_invalid_source( + tmp_path, + monkeypatch, +): + monkeypatch.chdir(tmp_path) + (tmp_path / ".env").write_text("SERVICE_NAME=before\n", encoding="utf-8") + + async def fake_reload(): + return None + + monkeypatch.setattr(admin_api, "reload_settings", fake_reload) + + request = _make_form_request( + "/admin/api/config/source", + {"env_content": "SERVICE_NAME=after\nnot-valid-line\n"}, + ) + response = await admin_api.save_config_source(request) + body = response.body.decode("utf-8") + + assert response.status_code == 400 + assert "KEY=VALUE" in body + assert (tmp_path / ".env").read_text(encoding="utf-8") == "SERVICE_NAME=before\n" + + +def test_validate_env_source_rejects_invalid_lines(): + with pytest.raises(ValueError, match="KEY=VALUE"): + validate_env_source("SERVICE_NAME=ok\nbad line\n") + + +def test_config_template_compiles(): + env = Environment(loader=FileSystemLoader("app/templates")) + template = env.get_template("config.html") + + assert template is not None diff --git a/tests/test_admin_stats.py b/tests/test_admin_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..c0608beebe1db15392f71ebac954517ec42655d0 --- /dev/null +++ b/tests/test_admin_stats.py @@ -0,0 +1,491 @@ +import json +from datetime import datetime +from urllib.parse import urlencode + +import pytest +from starlette.requests import Request + +from app.admin import api as admin_api +from app.admin.stats import collect_admin_stats, format_uptime +from app.services import token_dao as token_dao_module +from app.services.request_log_dao import RequestLogDAO +from app.services.token_dao import TokenDAO +from app.utils import token_pool as token_pool_module +from app.utils.token_pool import TokenPool, sync_token_stats_to_db + + +class DummyPool: + def __init__(self, status): + self._status = status + + def get_pool_status(self): + return self._status + + +def _make_get_request(path: str, query: dict[str, str] | None = None) -> Request: + query_string = urlencode(query or {}).encode() + + async def receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + scope = { + "type": "http", + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [], + "client": ("testclient", 50000), + "server": ("testserver", 80), + } + return Request(scope, receive) + + +@pytest.mark.asyncio +async def test_collect_admin_stats_uses_request_logs_and_token_inventory(tmp_path): + db_path = tmp_path / "admin_stats.db" + token_dao = TokenDAO(str(db_path)) + await token_dao.init_database() + request_log_dao = RequestLogDAO(str(db_path)) + + await token_dao.add_token("zai", "token-user-1", validate=False) + await token_dao.add_token("zai", "token-user-2", validate=False) + await token_dao.add_token( + "zai", + "token-guest-1", + token_type="guest", + validate=False, + ) + unknown_token_id = await token_dao.add_token( + "zai", + "token-unknown-1", + token_type="unknown", + validate=False, + ) + await token_dao.update_token_status(int(unknown_token_id), False) + + await request_log_dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.5, + input_tokens=100, + output_tokens=40, + cache_read_tokens=20, + total_tokens=140, + ) + await request_log_dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=500, + success=False, + duration=1.2, + input_tokens=60, + output_tokens=10, + cache_creation_tokens=15, + total_tokens=70, + error_message="upstream failed", + ) + await request_log_dao.add_log( + provider="zai", + endpoint="/v1/messages", + source="pytest", + protocol="anthropic", + client_name="pytest", + model="glm-4.5", + status_code=200, + success=True, + duration=0.9, + input_tokens=30, + output_tokens=20, + total_tokens=50, + ) + await request_log_dao.add_log( + provider="other", + endpoint="/ignored", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-ignored", + status_code=200, + success=True, + duration=0.1, + ) + + stats = await collect_admin_stats( + "zai", + token_dao=token_dao, + request_log_dao=request_log_dao, + token_pool=DummyPool( + { + "total_tokens": 2, + "available_tokens": 1, + "healthy_tokens": 1, + "unhealthy_tokens": 1, + } + ), + ) + + assert stats["total_tokens"] == 4 + assert stats["enabled_tokens"] == 3 + assert stats["user_tokens"] == 2 + assert stats["guest_tokens"] == 1 + assert stats["unknown_tokens"] == 1 + assert stats["pool_total_tokens"] == 2 + assert stats["available_tokens"] == 1 + assert stats["healthy_tokens"] == 1 + assert stats["unhealthy_tokens"] == 1 + assert stats["total_requests"] == 3 + assert stats["successful_requests"] == 2 + assert stats["failed_requests"] == 1 + assert stats["success_rate"] == pytest.approx(66.7) + assert stats["input_tokens"] == 190 + assert stats["output_tokens"] == 70 + assert stats["total_consumed_tokens"] == 260 + assert stats["cache_creation_tokens"] == 15 + assert stats["cache_read_tokens"] == 20 + assert stats["total_cache_tokens"] == 35 + assert stats["cache_creation_requests"] == 1 + assert stats["cache_hit_requests"] == 1 + assert stats["average_latency"] == pytest.approx(0.87, rel=1e-2) + assert stats["trend_window"] == "7d" + assert len(stats["usage_trend"]) == 7 + assert stats["usage_trend"][-1]["total_tokens"] == 260 + assert stats["usage_trend"][-1]["cache_total_tokens"] == 35 + + +@pytest.mark.asyncio +async def test_get_model_stats_from_db_includes_recent_same_day_logs(tmp_path): + dao = RequestLogDAO(str(tmp_path / "request_logs.db")) + + await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.25, + input_tokens=10, + output_tokens=20, + ) + + stats = await dao.get_model_stats_from_db(hours=1) + + assert "glm-5" in stats + assert stats["glm-5"]["total"] == 1 + assert stats["glm-5"]["success"] == 1 + assert stats["glm-5"]["failed"] == 0 + + +@pytest.mark.asyncio +async def test_request_log_dao_supports_count_and_offset_pagination(tmp_path): + dao = RequestLogDAO(str(tmp_path / "request_logs_paging.db")) + + for index in range(5): + await dao.add_log( + provider="zai", + endpoint=f"/v1/chat/completions/{index}", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.1, + ) + + total_count = await dao.count_logs(provider="zai") + paged_logs = await dao.get_recent_logs( + limit=2, + offset=2, + provider="zai", + ) + + assert total_count == 5 + assert len(paged_logs) == 2 + assert paged_logs[0]["endpoint"] == "/v1/chat/completions/2" + assert paged_logs[1]["endpoint"] == "/v1/chat/completions/1" + + +@pytest.mark.asyncio +async def test_request_log_dao_returns_usage_trend_with_missing_days_filled( + tmp_path, +): + dao = RequestLogDAO(str(tmp_path / "request_logs_trend.db")) + + await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.2, + input_tokens=12, + output_tokens=8, + cache_read_tokens=3, + total_tokens=20, + ) + + trend = await dao.get_provider_usage_trend("zai", days=7) + + assert len(trend) == 7 + assert sum(day["total_requests"] for day in trend) == 1 + assert trend[-1]["total_tokens"] == 20 + assert trend[-1]["cache_total_tokens"] == 3 + + +@pytest.mark.asyncio +async def test_request_log_dao_returns_hourly_usage_trend_with_missing_hours( + tmp_path, +): + dao = RequestLogDAO(str(tmp_path / "request_logs_hourly_trend.db")) + log_id = await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.2, + input_tokens=18, + output_tokens=7, + cache_creation_tokens=5, + cache_read_tokens=3, + total_tokens=25, + ) + + async with dao.get_connection() as conn: + await conn.execute( + "UPDATE request_logs SET timestamp = ? WHERE id = ?", + ("2026-03-10 12:00:00", log_id), + ) + await conn.commit() + + trend = await dao.get_provider_usage_trend( + "zai", + window="24h", + now=datetime(2026, 3, 10, 12, 0, 0), + ) + + assert len(trend) == 24 + assert trend[-1]["label"] == "12:00" + assert trend[-1]["tooltip_label"] == "2026-03-10 12:00" + assert trend[-1]["input_tokens"] == 18 + assert trend[-1]["output_tokens"] == 7 + assert trend[-1]["cache_creation_tokens"] == 5 + assert trend[-1]["cache_read_tokens"] == 3 + assert sum(point["total_requests"] for point in trend) == 1 + assert all(point["total_requests"] == 0 for point in trend[:-1]) + + +@pytest.mark.asyncio +async def test_dashboard_usage_trend_api_returns_requested_window( + tmp_path, + monkeypatch, +): + dao = RequestLogDAO(str(tmp_path / "request_logs_api_trend.db")) + log_id = await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest", + model="glm-5", + status_code=200, + success=True, + duration=0.2, + input_tokens=30, + output_tokens=12, + cache_read_tokens=4, + total_tokens=42, + ) + + async with dao.get_connection() as conn: + await conn.execute( + "UPDATE request_logs SET timestamp = ? WHERE id = ?", + ("2026-03-10 09:00:00", log_id), + ) + await conn.commit() + + async def fixed_usage_trend(provider, days=None, *, window=None, now=None): + return await RequestLogDAO.get_provider_usage_trend( + dao, + provider, + days=days, + window=window, + now=datetime(2026, 3, 10, 12, 0, 0), + ) + + monkeypatch.setattr(dao, "get_provider_usage_trend", fixed_usage_trend) + monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao) + request = _make_get_request( + "/admin/api/dashboard/usage-trend", + {"window": "24h"}, + ) + + response = await admin_api.get_dashboard_usage_trend(request) + payload = json.loads(response.body.decode("utf-8")) + + assert response.status_code == 200 + assert payload["window"] == "24h" + assert len(payload["points"]) == 24 + assert payload["points"][-4]["input_tokens"] == 30 + assert payload["points"][-4]["cache_read_tokens"] == 4 + + +@pytest.mark.asyncio +async def test_recent_logs_component_includes_usage_cache_and_latency_fields( + tmp_path, + monkeypatch, +): + dao = RequestLogDAO(str(tmp_path / "request_logs_recent_component.db")) + await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="pytest", + protocol="openai", + client_name="pytest-client", + model="glm-5", + status_code=200, + success=True, + duration=1.25, + first_token_time=0.42, + input_tokens=111, + output_tokens=22, + cache_creation_tokens=9, + cache_read_tokens=7, + total_tokens=133, + ) + + monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao) + request = _make_get_request( + "/admin/api/recent-logs", + {"page": "1", "page_size": "12"}, + ) + + response = await admin_api.get_recent_logs(request) + body = response.body.decode("utf-8") + + assert response.status_code == 200 + assert "请求" in body + assert "标记" in body + assert "输入 / 输出" in body + assert "缓存创建 / 命中" in body + assert "用时 / 首字" in body + assert "111" in body + assert "22" in body + assert "9" in body + assert "7" in body + assert "1.25s" in body + assert "0.42s" in body + + +@pytest.mark.asyncio +async def test_recent_logs_component_deduplicates_client_and_source_labels( + tmp_path, + monkeypatch, +): + dao = RequestLogDAO(str(tmp_path / "request_logs_recent_dedupe.db")) + await dao.add_log( + provider="zai", + endpoint="/v1/chat/completions", + source="browser", + protocol="openai", + client_name="Browser", + model="glm-5", + status_code=200, + success=True, + duration=1.0, + ) + + monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao) + request = _make_get_request( + "/admin/api/recent-logs", + {"page": "1", "page_size": "12"}, + ) + + response = await admin_api.get_recent_logs(request) + body = response.body.decode("utf-8") + + assert response.status_code == 200 + assert "Browser" in body + assert "OpenAI" in body + assert "glm-5" in body + assert ">browser<" not in body + assert ">zai<" not in body + + +@pytest.mark.asyncio +async def test_token_dao_supports_count_and_offset_pagination(tmp_path): + dao = TokenDAO(str(tmp_path / "tokens_paging.db")) + await dao.init_database() + + for index in range(5): + await dao.add_token("zai", f"token-{index}", validate=False) + + total_count = await dao.count_tokens_by_provider("zai", enabled_only=False) + paged_tokens = await dao.get_tokens_by_provider( + "zai", + enabled_only=False, + limit=2, + offset=2, + ) + + assert total_count == 5 + assert len(paged_tokens) == 2 + assert paged_tokens[0]["token"] == "token-2" + assert paged_tokens[1]["token"] == "token-3" + + +@pytest.mark.asyncio +async def test_token_pool_realtime_usage_stats_sync_to_db(tmp_path, monkeypatch): + dao = TokenDAO(str(tmp_path / "token_usage.db")) + await dao.init_database() + token_id = await dao.add_token("zai", "token-usage", validate=False) + assert token_id is not None + + pool = TokenPool([(token_id, "token-usage", "user")]) + + await pool.record_token_success("token-usage", dao=dao) + await pool.record_token_failure("token-usage", Exception("boom"), dao=dao) + + stats = await dao.get_token_stats(token_id) + assert stats is not None + assert stats["total_requests"] == 2 + assert stats["successful_requests"] == 1 + assert stats["failed_requests"] == 1 + + monkeypatch.setattr(token_pool_module, "_token_pool", pool) + monkeypatch.setattr(token_dao_module, "_token_dao", dao) + + await sync_token_stats_to_db() + + stats_after_sync = await dao.get_token_stats(token_id) + assert stats_after_sync is not None + assert stats_after_sync["total_requests"] == 2 + assert stats_after_sync["successful_requests"] == 1 + assert stats_after_sync["failed_requests"] == 1 + + +def test_format_uptime_formats_seconds_minutes_and_hours(): + assert format_uptime(59) == "59秒" + assert format_uptime(3661) == "1小时 1分钟 1秒" diff --git a/tests/test_admin_tokens.py b/tests/test_admin_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..23752b24f42390c3861e4047ace9497a5e45b720 --- /dev/null +++ b/tests/test_admin_tokens.py @@ -0,0 +1,154 @@ +from urllib.parse import urlencode + +import pytest +from jinja2 import Environment, FileSystemLoader +from starlette.requests import Request + +from app.admin import api as admin_api +from app.core.config import settings +from app.services.token_automation import TokenMaintenanceSummary +from app.services.token_importer import TokenImportSummary + + +def _make_form_request(path: str, data: dict[str, str] | None = None) -> Request: + encoded = urlencode(data or {}, doseq=True).encode() + sent = False + + async def receive(): + nonlocal sent + if sent: + return {"type": "http.request", "body": b"", "more_body": False} + sent = True + return {"type": "http.request", "body": encoded, "more_body": False} + + scope = { + "type": "http", + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": b"", + "headers": [ + ( + b"content-type", + b"application/x-www-form-urlencoded", + ) + ], + "client": ("testclient", 50000), + "server": ("testserver", 80), + } + return Request(scope, receive) + + +@pytest.mark.asyncio +async def test_import_directory_uses_configured_source_dir_when_form_empty( + tmp_path, + monkeypatch, +): + source_dir = tmp_path / "tokens" + source_dir.mkdir() + monkeypatch.setattr( + settings, + "TOKEN_AUTO_IMPORT_SOURCE_DIR", + str(source_dir), + ) + + called: dict[str, object] = {} + + async def fake_run_directory_import( + source_dir_arg, + *, + provider, + validate, + ): + called["source_dir"] = source_dir_arg + called["provider"] = provider + called["validate"] = validate + return TokenImportSummary( + source_dir=str(source_dir), + scanned_files=1, + imported_count=1, + duplicate_count=0, + invalid_json_count=0, + missing_token_count=0, + invalid_token_count=0, + ) + + import app.services.token_automation as token_automation + + monkeypatch.setattr( + token_automation, + "run_directory_import", + fake_run_directory_import, + ) + + response = await admin_api.import_tokens_from_directory_api( + _make_form_request("/admin/api/tokens/import-directory"), + ) + body = response.body.decode("utf-8") + + assert response.status_code == 200 + assert called["source_dir"] == str(source_dir) + assert called["provider"] == "zai" + assert called["validate"] is True + assert "导入成功" in body + + +@pytest.mark.asyncio +async def test_run_maintenance_uses_configured_actions_when_form_empty( + monkeypatch, +): + monkeypatch.setattr(settings, "TOKEN_AUTO_REMOVE_DUPLICATES", True) + monkeypatch.setattr(settings, "TOKEN_AUTO_HEALTH_CHECK", False) + monkeypatch.setattr(settings, "TOKEN_AUTO_DELETE_INVALID", True) + + called: dict[str, object] = {} + + async def fake_run_token_maintenance( + *, + provider, + remove_duplicates, + run_health_check, + delete_invalid_tokens, + ): + called["provider"] = provider + called["remove_duplicates"] = remove_duplicates + called["run_health_check"] = run_health_check + called["delete_invalid_tokens"] = delete_invalid_tokens + return TokenMaintenanceSummary( + provider=provider, + checked_count=2, + duplicate_removed_count=1, + valid_count=1, + guest_count=0, + invalid_count=1, + deleted_invalid_count=1, + ) + + import app.services.token_automation as token_automation + + monkeypatch.setattr( + token_automation, + "run_token_maintenance", + fake_run_token_maintenance, + ) + + response = await admin_api.run_token_maintenance_api( + _make_form_request("/admin/api/tokens/maintenance/run"), + ) + body = response.body.decode("utf-8") + + assert response.status_code == 200 + assert called["provider"] == "zai" + assert called["remove_duplicates"] is True + assert called["run_health_check"] is False + assert called["delete_invalid_tokens"] is True + assert "维护完成" in body + + +def test_tokens_template_compiles(): + env = Environment(loader=FileSystemLoader("app/templates")) + template = env.get_template("tokens.html") + + assert template is not None diff --git a/tests/test_dependency_metadata.py b/tests/test_dependency_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..47f6130b24dcd09a099f53109060f1ea95be561c --- /dev/null +++ b/tests/test_dependency_metadata.py @@ -0,0 +1,14 @@ +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] + + +def test_requirements_enable_httpx_socks_support(): + requirements = (ROOT / "requirements.txt").read_text(encoding="utf-8") + assert "httpx[http2,socks]==0.28.1" in requirements + + +def test_pyproject_enable_httpx_socks_support(): + pyproject = (ROOT / "pyproject.toml").read_text(encoding="utf-8") + assert '"httpx[http2,socks]==0.28.1"' in pyproject diff --git a/tests/test_glm45_real_token.py b/tests/test_glm45_real_token.py new file mode 100644 index 0000000000000000000000000000000000000000..9fef459176192a4ae3af5a56c9f7e2a80c983210 --- /dev/null +++ b/tests/test_glm45_real_token.py @@ -0,0 +1,33 @@ +import pytest + +from app.core.config import settings +from app.core.upstream import UpstreamClient +from app.models.schemas import Message, OpenAIRequest +from tests.real_upstream_test_utils import ( + assert_usage_present, + extract_content, + install_real_anonymous, +) + + +@pytest.mark.asyncio +async def test_glm45_with_real_anonymous_request(monkeypatch): + install_real_anonymous(monkeypatch) + + client = UpstreamClient() + request = OpenAIRequest( + model=settings.GLM45_MODEL, + messages=[ + Message( + role="user", + content="请只输出字符串 GLM45_OK,不要输出任何其他内容。", + ) + ], + stream=False, + ) + + payload = await client.chat_completion(request) + content = extract_content(payload) + + assert "GLM45_OK" in content + assert_usage_present(payload) diff --git a/tests/test_glm46v_real_token.py b/tests/test_glm46v_real_token.py new file mode 100644 index 0000000000000000000000000000000000000000..95734ca1b70e8e8077a4f65948a16e14a3de85ec --- /dev/null +++ b/tests/test_glm46v_real_token.py @@ -0,0 +1,43 @@ +import pytest + +from app.core.config import settings +from app.core.upstream import UpstreamClient +from app.models.schemas import ContentPart, ImageUrl, Message, OpenAIRequest +from tests.real_upstream_test_utils import ( + RED_2X2_PNG_DATA_URL, + assert_usage_present, + extract_content, + install_real_auth, +) + + +@pytest.mark.asyncio +async def test_glm46v_with_real_auth_token_and_image(monkeypatch): + install_real_auth(monkeypatch) + + client = UpstreamClient() + request = OpenAIRequest( + model=settings.GLM46V_MODEL, + messages=[ + Message( + role="user", + content=[ + ContentPart( + type="text", + text="请判断这张图片的主色调。如果它是红色,只输出 RED_OK。", + ), + ContentPart( + type="image_url", + image_url=ImageUrl(url=RED_2X2_PNG_DATA_URL), + ), + ], + ) + ], + stream=False, + ) + + payload = await client.chat_completion(request) + content = extract_content(payload) + + assert "RED_OK" in content + assert_usage_present(payload) diff --git a/tests/test_glm47_request_bootstrap.py b/tests/test_glm47_request_bootstrap.py new file mode 100644 index 0000000000000000000000000000000000000000..3940370be4b783dec6f626cf0c955cb1ba5e518e --- /dev/null +++ b/tests/test_glm47_request_bootstrap.py @@ -0,0 +1,444 @@ +from urllib.parse import parse_qs, urlparse + +import pytest + +from app.core import upstream as upstream_module +from app.core.upstream import UpstreamClient +from app.models.schemas import ContentPart, ImageUrl, Message, OpenAIRequest + +FAKE_HEADERS = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Connection": "keep-alive", + "Cache-Control": "no-cache", + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/144.0.0.0 Safari/537.36" + ), + "Accept-Language": "zh-CN", + "X-FE-Version": "prod-fe-test", + "Origin": "https://chat.z.ai", + "Referer": "https://chat.z.ai/", +} + + +def _make_request(model: str) -> OpenAIRequest: + return OpenAIRequest( + model=model, + messages=[Message(role="user", content="请用一句话回答:你好")], + stream=True, + ) + + +async def _fake_get_auth_info(self, excluded_tokens=None, excluded_guest_user_ids=None): + return { + "token": "auth-token", + "user_id": "user-123", + "username": "User", + "auth_mode": "authenticated", + "token_source": "auth_pool", + "guest_user_id": None, + } + + +@pytest.mark.asyncio +async def test_glm47_request_bootstraps_chat_and_uses_browser_signature(monkeypatch): + create_chat_calls: list[dict] = [] + browser_type_calls: list[str | None] = [] + + def fake_headers(chat_id: str = "", browser_type=None): + browser_type_calls.append(browser_type) + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fake_create_chat( + self, + *, + prompt, + model, + token, + headers, + enable_thinking, + web_search, + user_message_id, + files, + feature_entries, + mcp_servers, + ): + create_chat_calls.append( + { + "prompt": prompt, + "model": model, + "token": token, + "user_agent": headers["User-Agent"], + "enable_thinking": enable_thinking, + "web_search": web_search, + "user_message_id": user_message_id, + "files": files, + "feature_entries": feature_entries, + "mcp_servers": mcp_servers, + } + ) + return "persisted-chat-id" + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + transformed = await client.transform_request(_make_request("GLM-4.7")) + parsed_url = urlparse(transformed["url"]) + query = parse_qs(parsed_url.query) + + assert len(create_chat_calls) == 1 + assert create_chat_calls[0]["prompt"] == "请用一句话回答:你好" + assert create_chat_calls[0]["model"] == "glm-4.7" + assert create_chat_calls[0]["token"] == "auth-token" + assert create_chat_calls[0]["user_agent"] == FAKE_HEADERS["User-Agent"] + assert create_chat_calls[0]["enable_thinking"] is False + assert create_chat_calls[0]["web_search"] is False + assert create_chat_calls[0]["files"] is None + assert create_chat_calls[0]["feature_entries"] is None + assert create_chat_calls[0]["mcp_servers"] is None + assert create_chat_calls[0]["user_message_id"] + assert browser_type_calls == ["chrome"] + assert transformed["chat_id"] == "persisted-chat-id" + assert transformed["headers"]["Accept"] == "*/*" + assert transformed["headers"]["Referer"] == "https://chat.z.ai/c/persisted-chat-id" + assert query["current_url"] == ["https://chat.z.ai/c/persisted-chat-id"] + assert query["pathname"] == ["/c/persisted-chat-id"] + assert query["user_agent"] == [FAKE_HEADERS["User-Agent"]] + assert query["timezone"] == ["Asia/Shanghai"] + assert transformed["body"]["chat_id"] == "persisted-chat-id" + assert transformed["body"]["current_user_message_id"] == ( + create_chat_calls[0]["user_message_id"] + ) + assert transformed["body"]["features"]["enable_thinking"] is False + assert transformed["body"]["background_tasks"] == { + "title_generation": True, + "tags_generation": True, + } + assert "session_id" not in transformed["body"] + assert "model_item" not in transformed["body"] + + +@pytest.mark.asyncio +async def test_glm47_thinking_defaults_to_enable_thinking(monkeypatch): + create_chat_calls: list[dict] = [] + + def fake_headers(chat_id: str = "", browser_type=None): + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fake_create_chat( + self, + *, + prompt, + model, + token, + headers, + enable_thinking, + web_search, + user_message_id, + files, + feature_entries, + mcp_servers, + ): + create_chat_calls.append( + { + "model": model, + "enable_thinking": enable_thinking, + "web_search": web_search, + "user_message_id": user_message_id, + "files": files, + "feature_entries": feature_entries, + "mcp_servers": mcp_servers, + } + ) + return "thinking-chat-id" + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + transformed = await client.transform_request(_make_request("GLM-4.7-Thinking")) + + assert len(create_chat_calls) == 1 + assert create_chat_calls[0]["model"] == "glm-4.7" + assert create_chat_calls[0]["enable_thinking"] is True + assert create_chat_calls[0]["web_search"] is False + assert create_chat_calls[0]["files"] is None + assert create_chat_calls[0]["feature_entries"] is None + assert create_chat_calls[0]["mcp_servers"] is None + assert create_chat_calls[0]["user_message_id"] + assert transformed["body"]["features"]["enable_thinking"] is True + assert transformed["body"]["current_user_message_id"] == ( + create_chat_calls[0]["user_message_id"] + ) + + +@pytest.mark.asyncio +async def test_non_glm47_request_keeps_legacy_request_shape(monkeypatch): + def fake_headers(chat_id: str = "", browser_type=None): + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fail_create_chat(self, **kwargs): + raise AssertionError("GLM-4.5 不应触发 create_chat") + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + transformed = await client.transform_request(_make_request("GLM-4.5")) + query = parse_qs(urlparse(transformed["url"]).query) + + assert transformed["headers"]["Accept"] == "application/json" + assert transformed["chat_id"] != "persisted-chat-id" + assert "user_agent" not in query + assert "session_id" in transformed["body"] + assert transformed["body"]["model_item"]["name"] == "GLM-4.5" + + +@pytest.mark.asyncio +async def test_glm5_defaults_to_enable_thinking(monkeypatch): + def fake_headers(chat_id: str = "", browser_type=None): + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fail_create_chat(self, **kwargs): + raise AssertionError("GLM-5 不应触发 create_chat") + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + transformed = await client.transform_request(_make_request("GLM-5")) + query = parse_qs(urlparse(transformed["url"]).query) + + assert transformed["headers"]["Accept"] == "application/json" + assert transformed["body"]["model"] == "glm-5" + assert transformed["body"]["features"]["enable_thinking"] is True + assert transformed["body"]["features"]["preview_mode"] is True + assert "session_id" in transformed["body"] + assert "user_agent" not in query + + +@pytest.mark.asyncio +async def test_glm5_allows_explicitly_disabling_thinking(monkeypatch): + def fake_headers(chat_id: str = "", browser_type=None): + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fail_create_chat(self, **kwargs): + raise AssertionError("GLM-5 不应触发 create_chat") + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + request = OpenAIRequest( + model="GLM-5", + messages=[Message(role="user", content="请用一句话回答:你好")], + stream=False, + enable_thinking=False, + ) + + transformed = await client.transform_request(request) + + assert transformed["body"]["features"]["enable_thinking"] is False + + +@pytest.mark.asyncio +async def test_glm46v_uses_persisted_chat_and_visual_features(monkeypatch): + create_chat_calls: list[dict] = [] + upload_calls: list[dict] = [] + browser_type_calls: list[str | None] = [] + + def fake_headers(chat_id: str = "", browser_type=None): + browser_type_calls.append(browser_type) + headers = dict(FAKE_HEADERS) + headers["Referer"] = ( + f"https://chat.z.ai/c/{chat_id}" + if chat_id + else FAKE_HEADERS["Referer"] + ) + return headers + + async def fake_create_chat( + self, + *, + prompt, + model, + token, + headers, + enable_thinking, + web_search, + user_message_id, + files, + feature_entries, + mcp_servers, + ): + create_chat_calls.append( + { + "prompt": prompt, + "model": model, + "token": token, + "user_agent": headers["User-Agent"], + "enable_thinking": enable_thinking, + "web_search": web_search, + "user_message_id": user_message_id, + "files": files, + "feature_entries": feature_entries, + "mcp_servers": mcp_servers, + } + ) + return "vision-chat-id" + + async def fake_upload_image( + self, + data_url, + chat_id, + token, + user_id, + auth_mode="authenticated", + ): + upload_calls.append( + { + "data_url": data_url, + "chat_id": chat_id, + "token": token, + "user_id": user_id, + "auth_mode": auth_mode, + } + ) + return { + "type": "image", + "file": { + "id": "file-id", + "user_id": user_id, + "filename": "file.png", + "data": {}, + "meta": { + "name": "file.png", + "content_type": "image/png", + "size": 4, + "data": {}, + }, + "created_at": 1, + "updated_at": 1, + }, + "id": "file-id", + "url": "/api/v1/files/file-id/content", + "name": "file.png", + "status": "uploaded", + "size": 4, + "error": "", + "itemId": "item-id", + "media": "image", + } + + monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info) + monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat) + monkeypatch.setattr(UpstreamClient, "upload_image", fake_upload_image) + monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers) + + client = UpstreamClient() + request = OpenAIRequest( + model="GLM-4.6V", + messages=[ + Message( + role="user", + content=[ + ContentPart(type="text", text="请判断图片主色调"), + ContentPart( + type="image_url", + image_url=ImageUrl(url="data:image/png;base64,AAAA"), + ), + ], + ) + ], + stream=False, + ) + + transformed = await client.transform_request(request) + query = parse_qs(urlparse(transformed["url"]).query) + + assert len(create_chat_calls) == 1 + assert create_chat_calls[0]["prompt"] == "请判断图片主色调" + assert create_chat_calls[0]["model"] == "glm-4.6v" + assert create_chat_calls[0]["token"] == "auth-token" + assert create_chat_calls[0]["user_agent"] == FAKE_HEADERS["User-Agent"] + assert create_chat_calls[0]["enable_thinking"] is True + assert create_chat_calls[0]["web_search"] is False + assert ( + create_chat_calls[0]["feature_entries"] + == upstream_module.GLM46V_SELECTED_FEATURES + ) + assert create_chat_calls[0]["mcp_servers"] == upstream_module.GLM46V_MCP_SERVERS + assert create_chat_calls[0]["user_message_id"] + assert create_chat_calls[0]["files"][0]["id"] == "file-id" + assert create_chat_calls[0]["files"][0]["ref_user_msg_id"] == ( + create_chat_calls[0]["user_message_id"] + ) + assert upload_calls == [ + { + "data_url": "data:image/png;base64,AAAA", + "chat_id": "", + "token": "auth-token", + "user_id": "user-123", + "auth_mode": "authenticated", + } + ] + assert browser_type_calls == ["chrome"] + assert transformed["chat_id"] == "vision-chat-id" + assert transformed["headers"]["Accept"] == "*/*" + assert transformed["headers"]["Referer"] == "https://chat.z.ai/c/vision-chat-id" + assert query["current_url"] == ["https://chat.z.ai/c/vision-chat-id"] + assert query["pathname"] == ["/c/vision-chat-id"] + assert query["user_agent"] == [FAKE_HEADERS["User-Agent"]] + assert transformed["body"]["current_user_message_id"] == ( + create_chat_calls[0]["user_message_id"] + ) + assert transformed["body"]["features"]["enable_thinking"] is True + assert transformed["body"]["features"]["preview_mode"] is False + assert "features" not in transformed["body"]["features"] + assert transformed["body"]["mcp_servers"] == upstream_module.GLM46V_MCP_SERVERS + assert transformed["body"]["files"][0]["id"] == "file-id" + assert transformed["body"]["files"][0]["ref_user_msg_id"] == ( + create_chat_calls[0]["user_message_id"] + ) + assert transformed["body"]["messages"][0]["content"][1]["image_url"]["url"] == ( + "file-id" + ) + assert "session_id" not in transformed["body"] diff --git a/tests/test_glm5_real_token.py b/tests/test_glm5_real_token.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6de719cc0ab2ca3b7ab6c9508c463978159a4c --- /dev/null +++ b/tests/test_glm5_real_token.py @@ -0,0 +1,33 @@ +import pytest + +from app.core.config import settings +from app.core.upstream import UpstreamClient +from app.models.schemas import Message, OpenAIRequest +from tests.real_upstream_test_utils import ( + assert_usage_present, + extract_content, + install_real_anonymous, +) + + +@pytest.mark.asyncio +async def test_glm5_with_real_anonymous_request(monkeypatch): + install_real_anonymous(monkeypatch) + + client = UpstreamClient() + request = OpenAIRequest( + model=settings.GLM5_MODEL, + messages=[ + Message( + role="user", + content="请只输出字符串 GLM5_OK,不要输出任何其他内容。", + ) + ], + stream=False, + ) + + payload = await client.chat_completion(request) + content = extract_content(payload) + + assert "GLM5_OK" in content + assert_usage_present(payload) diff --git a/tests/test_guest_pool_concurrency.py b/tests/test_guest_pool_concurrency.py new file mode 100644 index 0000000000000000000000000000000000000000..b85d40603b94a048a8bdf549adcfed653d4be544 --- /dev/null +++ b/tests/test_guest_pool_concurrency.py @@ -0,0 +1,222 @@ +import asyncio +import types +from dataclasses import dataclass, field +from unittest.mock import AsyncMock + +import pytest + +from app.core import upstream as upstream_module +from app.core.upstream import UpstreamClient +from app.models.schemas import Message, OpenAIRequest +from app.utils.guest_session_pool import GuestSession, GuestSessionPool + +POOL_SIZE = 8 +REQUEST_COUNT = 64 +REQUEST_DELAY_SECONDS = 0.03 +FAILURE_POOL_SIZE = 4 +FAILURE_REQUEST_COUNT = 24 +FAILURE_DELAY_SECONDS = 0.02 + + +def _make_session(user_id: str, token_suffix: str) -> GuestSession: + return GuestSession( + token=f"token-{token_suffix}", + user_id=user_id, + username=f"Guest-{user_id}", + ) + + +def _make_request() -> OpenAIRequest: + return OpenAIRequest( + model="GLM-4.5", + messages=[Message(role="user", content="ping")], + stream=False, + ) + + +@dataclass +class LoadState: + active_posts: int = 0 + peak_posts: int = 0 + failed_once: set[str] = field(default_factory=set) + + +class FakeResponse: + def __init__(self, status_code: int, text: str): + self.status_code = status_code + self.text = text + + @property + def is_success(self) -> bool: + return 200 <= self.status_code < 300 + + +def _build_fake_async_client(handler): + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, headers=None, json=None): + return await handler(url, headers or {}, json or {}) + + return FakeAsyncClient + + +async def _build_pool(monkeypatch, pool_size: int) -> GuestSessionPool: + pool = GuestSessionPool(pool_size=pool_size) + counter = 0 + + async def fake_create_session() -> GuestSession: + nonlocal counter + counter += 1 + return _make_session(f"guest-{counter}", str(counter)) + + monkeypatch.setattr(pool, "_create_session", fake_create_session) + monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None)) + monkeypatch.setattr(pool, "_delete_all_chats", AsyncMock(return_value=True)) + await pool.initialize() + await asyncio.sleep(0) + return pool + + +def _bind_guest_request_flow( + client, + pool: GuestSessionPool, + assigned_user_ids: list[str], +): + async def fake_transform_request( + self, + request, + excluded_tokens=None, + excluded_guest_user_ids=None, + ): + session = await pool.acquire(exclude_user_ids=excluded_guest_user_ids) + assigned_user_ids.append(session.user_id) + return { + "url": f"https://upstream.test/{session.user_id}", + "headers": {"x-guest-user-id": session.user_id}, + "body": {"model": request.model}, + "token": session.token, + "chat_id": f"chat-{session.user_id}", + "model": request.model, + "user_id": session.user_id, + "auth_mode": "guest", + "token_source": "guest_pool", + "guest_user_id": session.user_id, + } + + async def fake_transform_response(self, response, request, transformed): + return { + "ok": response.is_success, + "guest_user_id": transformed["guest_user_id"], + "status_code": response.status_code, + } + + client.transform_request = types.MethodType(fake_transform_request, client) + client.transform_response = types.MethodType(fake_transform_response, client) + + +def _patch_upstream_globals(monkeypatch, pool: GuestSessionPool, async_client_cls): + monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: pool) + monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None) + monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True) + monkeypatch.setattr(upstream_module.httpx, "AsyncClient", async_client_cls) + + +def _build_handler( + delay: float, + state: LoadState, + failure_users: set[str] | None = None, +): + lock = asyncio.Lock() + failures = failure_users or set() + + async def handler(url, headers, body): + user_id = headers["x-guest-user-id"] + async with lock: + state.active_posts += 1 + state.peak_posts = max(state.peak_posts, state.active_posts) + + try: + await asyncio.sleep(delay) + if user_id in failures and user_id not in state.failed_once: + state.failed_once.add(user_id) + return FakeResponse(401, '{"message":"expired"}') + return FakeResponse(200, "{}") + finally: + async with lock: + state.active_posts -= 1 + + return handler + + +@pytest.mark.asyncio +async def test_guest_pool_handles_many_concurrent_requests(monkeypatch): + pool = await _build_pool(monkeypatch, POOL_SIZE) + assigned_user_ids: list[str] = [] + state = LoadState() + client = UpstreamClient() + handler = _build_handler(REQUEST_DELAY_SECONDS, state) + + _bind_guest_request_flow(client, pool, assigned_user_ids) + _patch_upstream_globals( + monkeypatch, + pool, + _build_fake_async_client(handler), + ) + + results = await asyncio.gather( + *(client.chat_completion(_make_request()) for _ in range(REQUEST_COUNT)) + ) + pool_status = pool.get_pool_status() + + assert all(result.get("ok") is True for result in results) + assert len(set(assigned_user_ids)) == POOL_SIZE + assert state.peak_posts >= POOL_SIZE + assert pool_status == { + "total_sessions": POOL_SIZE, + "valid_sessions": POOL_SIZE, + "available_sessions": POOL_SIZE, + "busy_sessions": 0, + "expired_sessions": 0, + } + + await pool.close() + + +@pytest.mark.asyncio +async def test_guest_pool_recovers_from_failures_under_concurrency(monkeypatch): + pool = await _build_pool(monkeypatch, FAILURE_POOL_SIZE) + assigned_user_ids: list[str] = [] + state = LoadState() + client = UpstreamClient() + failure_users = {"guest-1", "guest-2"} + handler = _build_handler(FAILURE_DELAY_SECONDS, state, failure_users) + + _bind_guest_request_flow(client, pool, assigned_user_ids) + _patch_upstream_globals( + monkeypatch, + pool, + _build_fake_async_client(handler), + ) + + results = await asyncio.gather( + *(client.chat_completion(_make_request()) for _ in range(FAILURE_REQUEST_COUNT)) + ) + pool_status = pool.get_pool_status() + current_user_ids = set(pool._sessions) + + assert all(result.get("ok") is True for result in results) + assert state.failed_once == failure_users + assert "guest-1" not in current_user_ids + assert "guest-2" not in current_user_ids + assert pool_status["busy_sessions"] == 0 + assert pool_status["valid_sessions"] == FAILURE_POOL_SIZE + + await pool.close() diff --git a/tests/test_guest_session_pool.py b/tests/test_guest_session_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..27800d9ed503699613c8f27b1d393404c97a84ea --- /dev/null +++ b/tests/test_guest_session_pool.py @@ -0,0 +1,97 @@ +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest + +from app.utils import guest_session_pool as guest_pool_module +from app.utils.guest_session_pool import GuestSession, GuestSessionPool + + +def _make_session(user_id: str, token_suffix: str) -> GuestSession: + return GuestSession( + token=f"token-{token_suffix}", + user_id=user_id, + username=f"Guest-{user_id}", + ) + + +@pytest.mark.asyncio +async def test_ensure_capacity_returns_when_only_duplicate_user_ids_are_created( + monkeypatch, +): + pool = GuestSessionPool(pool_size=2) + create_calls = 0 + + async def fake_create_session() -> GuestSession: + nonlocal create_calls + create_calls += 1 + return _make_session("duplicate-user", str(create_calls)) + + monkeypatch.setattr(pool, "_create_session", fake_create_session) + + await asyncio.wait_for(pool._ensure_capacity(), timeout=0.2) + + assert create_calls >= 1 + assert set(pool._sessions) == {"duplicate-user"} + assert len(pool._sessions) == 1 + + +@pytest.mark.asyncio +async def test_initialize_logs_unique_session_count_when_results_contain_duplicates( + monkeypatch, +): + pool = GuestSessionPool(pool_size=3) + sessions = [ + _make_session("user-1", "1"), + _make_session("user-1", "2"), + _make_session("user-2", "3"), + _make_session("user-1", "4"), + _make_session("user-2", "5"), + _make_session("user-1", "6"), + _make_session("user-2", "7"), + _make_session("user-1", "8"), + _make_session("user-2", "9"), + ] + info_mock = Mock() + + async def fake_create_session() -> GuestSession: + return sessions.pop(0) + + monkeypatch.setattr(pool, "_create_session", fake_create_session) + monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None)) + monkeypatch.setattr(guest_pool_module.logger, "info", info_mock) + monkeypatch.setattr(guest_pool_module.logger, "warning", Mock()) + + await pool.initialize() + await asyncio.sleep(0) + + assert set(pool._sessions) == {"user-1", "user-2"} + assert any( + call.args == ("✅ 匿名会话池初始化完成: 2 个会话",) + for call in info_mock.call_args_list + ) + + +@pytest.mark.asyncio +async def test_acquire_skips_duplicate_excluded_session_without_overwriting_pool( + monkeypatch, +): + pool = GuestSessionPool(pool_size=2) + existing = _make_session("user-1", "seed") + pool._sessions[existing.user_id] = existing + created_sessions = [ + _make_session("user-1", "duplicate"), + _make_session("user-2", "fresh"), + ] + + async def fake_create_session() -> GuestSession: + return created_sessions.pop(0) + + monkeypatch.setattr(pool, "_create_session", fake_create_session) + + acquired = await pool.acquire(exclude_user_ids={"user-1"}) + + assert acquired.user_id == "user-2" + assert acquired.active_requests == 1 + assert set(pool._sessions) == {"user-1", "user-2"} + assert pool._sessions["user-1"].token == "token-seed" diff --git a/tests/test_request_logging.py b/tests/test_request_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d46e64a936975f4f72f77b11bb1e43a9fc100b --- /dev/null +++ b/tests/test_request_logging.py @@ -0,0 +1,48 @@ +from app.utils.request_logging import ( + extract_claude_usage, + extract_openai_usage, +) + + +def test_extract_openai_usage_supports_cached_prompt_details(): + usage = extract_openai_usage( + { + "usage": { + "prompt_tokens": 120, + "completion_tokens": 45, + "total_tokens": 165, + "prompt_tokens_details": { + "cached_tokens": 32, + }, + } + } + ) + + assert usage == { + "input_tokens": 120, + "output_tokens": 45, + "cache_creation_tokens": 0, + "cache_read_tokens": 32, + "total_tokens": 165, + } + + +def test_extract_claude_usage_supports_cache_token_fields(): + usage = extract_claude_usage( + { + "usage": { + "input_tokens": 200, + "output_tokens": 80, + "cache_creation_input_tokens": 64, + "cache_read_input_tokens": 48, + } + } + ) + + assert usage == { + "input_tokens": 200, + "output_tokens": 80, + "cache_creation_tokens": 64, + "cache_read_tokens": 48, + "total_tokens": 392, + } diff --git a/tests/test_token_automation.py b/tests/test_token_automation.py new file mode 100644 index 0000000000000000000000000000000000000000..37445758bf571a54eb610ea77259f926ff13de0b --- /dev/null +++ b/tests/test_token_automation.py @@ -0,0 +1,51 @@ +import pytest + +from app.services.token_automation import run_token_maintenance +from app.services.token_dao import TokenDAO +from app.utils.token_pool import ZAITokenValidator + + +@pytest.mark.asyncio +async def test_run_token_maintenance_deletes_invalid_tokens_after_validation( + tmp_path, + monkeypatch, +): + dao = TokenDAO(str(tmp_path / "tokens.db")) + await dao.init_database() + + await dao.add_token("zai", "token-valid", validate=False) + await dao.add_token("zai", "token-guest", validate=False) + await dao.add_token("zai", "token-invalid", validate=False) + + async def fake_validate_token(cls, token): + mapping = { + "token-valid": ("user", True, None), + "token-guest": ("guest", False, "guest token"), + "token-invalid": ("unknown", False, "token expired"), + } + return mapping[token] + + monkeypatch.setattr( + ZAITokenValidator, + "validate_token", + classmethod(fake_validate_token), + ) + + summary = await run_token_maintenance( + provider="zai", + remove_duplicates=False, + run_health_check=False, + delete_invalid_tokens=True, + dao=dao, + pool=None, + ) + + assert summary.checked_count == 3 + assert summary.valid_count == 1 + assert summary.guest_count == 1 + assert summary.invalid_count == 1 + assert summary.deleted_invalid_count == 2 + + remaining_tokens = await dao.get_tokens_by_provider("zai", enabled_only=False) + assert [token["token"] for token in remaining_tokens] == ["token-valid"] + assert remaining_tokens[0]["token_type"] == "user" diff --git a/tests/test_token_importer.py b/tests/test_token_importer.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5d34c50dd84f9478cc194857a10b91fcb3218d --- /dev/null +++ b/tests/test_token_importer.py @@ -0,0 +1,70 @@ +import json + +import pytest + +from app.services.token_dao import TokenDAO +from app.services.token_importer import import_tokens_from_directory + + +@pytest.mark.asyncio +async def test_import_tokens_from_directory_handles_duplicates_and_invalid_files( + tmp_path, +): + source_dir = tmp_path / "source_tokens" + source_dir.mkdir() + + (source_dir / "token_valid_1.json").write_text( + json.dumps( + { + "email": "alpha@example.com", + "token": "token-alpha", + "token_source": "context.cookie:token", + } + ), + encoding="utf-8", + ) + (source_dir / "token_valid_2.json").write_text( + json.dumps( + { + "email": "beta@example.com", + "token": "token-beta", + "token_source": "context.cookie:token", + } + ), + encoding="utf-8", + ) + (source_dir / "token_duplicate.json").write_text( + json.dumps( + { + "email": "alpha-dup@example.com", + "token": "token-alpha", + } + ), + encoding="utf-8", + ) + (source_dir / "token_missing.json").write_text( + json.dumps({"email": "missing@example.com"}), + encoding="utf-8", + ) + (source_dir / "token_invalid.json").write_text("{invalid json", encoding="utf-8") + + dao = TokenDAO(str(tmp_path / "tokens.db")) + await dao.init_database() + + summary = await import_tokens_from_directory( + source_dir, + provider="zai", + validate=False, + dao=dao, + ) + + assert summary.scanned_files == 5 + assert summary.imported_count == 2 + assert summary.duplicate_count == 1 + assert summary.missing_token_count == 1 + assert summary.invalid_json_count == 1 + assert summary.invalid_token_count == 0 + + tokens = await dao.get_tokens_by_provider("zai", enabled_only=False) + imported_values = {item["token"] for item in tokens} + assert imported_values == {"token-alpha", "token-beta"} diff --git a/tests/test_upstream_dual_pool.py b/tests/test_upstream_dual_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..2989a736bb67f9df7873d371b963d743f33fae3a --- /dev/null +++ b/tests/test_upstream_dual_pool.py @@ -0,0 +1,407 @@ +import asyncio +import types +from dataclasses import dataclass +from unittest.mock import AsyncMock + +import pytest + +from app.core import upstream as upstream_module +from app.core.upstream import UpstreamClient +from app.models.schemas import Message, OpenAIRequest +from app.utils.guest_session_pool import GuestSession, GuestSessionPool + +AUTH_POOL_SIZE = 2 +GUEST_POOL_SIZE = 2 +AUTH_REQUEST_COUNT = 6 +MIXED_REQUEST_DELAY = 0.01 + + +def _make_request() -> OpenAIRequest: + return OpenAIRequest( + model="GLM-4.5", + messages=[Message(role="user", content="ping")], + stream=False, + ) + + +def _make_guest_session(user_id: str) -> GuestSession: + return GuestSession( + token=f"guest-token-{user_id}", + user_id=user_id, + username=f"Guest-{user_id}", + ) + + +@dataclass +class StubTokenPool: + tokens: list[str] + + def __post_init__(self): + self.failure_tokens: list[str] = [] + self.success_tokens: list[str] = [] + + def get_next_token(self, exclude_tokens=None): + excluded = exclude_tokens or set() + for token in self.tokens: + if token not in excluded: + return token + return None + + async def record_token_failure(self, token: str, error=None, dao=None): + self.failure_tokens.append(token) + + async def record_token_success(self, token: str, dao=None): + self.success_tokens.append(token) + + def get_pool_status(self): + return {"available_tokens": len(self.tokens)} + + +class FakeResponse: + def __init__(self, status_code: int, text: str = "{}"): + self.status_code = status_code + self.text = text + + @property + def is_success(self) -> bool: + return 200 <= self.status_code < 300 + + +def _build_fake_async_client(handler): + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, headers=None, json=None): + return await handler(headers or {}) + + return FakeAsyncClient + + +async def _build_guest_pool( + monkeypatch, + *, + pool_size: int, + user_ids: list[str], +) -> GuestSessionPool: + pool = GuestSessionPool(pool_size=pool_size) + queue = iter(user_ids) + + async def fake_create_session() -> GuestSession: + return _make_guest_session(next(queue)) + + monkeypatch.setattr(pool, "_create_session", fake_create_session) + monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None)) + monkeypatch.setattr(pool, "_delete_all_chats", AsyncMock(return_value=True)) + await pool.initialize() + await asyncio.sleep(0) + return pool + + +def _patch_upstream_dependencies( + monkeypatch, + *, + token_pool, + guest_pool, + async_client_cls, +): + monkeypatch.setattr(upstream_module, "get_token_pool", lambda: token_pool) + monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: guest_pool) + monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True) + monkeypatch.setattr( + upstream_module.settings, + "GUEST_POOL_SIZE", + guest_pool.pool_size if guest_pool else 1, + ) + monkeypatch.setattr(upstream_module.httpx, "AsyncClient", async_client_cls) + + +def _bind_minimal_request_flow(client: UpstreamClient, captures: list[dict]): + async def fake_transform_request( + self, + request, + excluded_tokens=None, + excluded_guest_user_ids=None, + ): + auth_info = await self.get_auth_info( + excluded_tokens=excluded_tokens, + excluded_guest_user_ids=excluded_guest_user_ids, + ) + captures.append(dict(auth_info)) + return { + "url": "https://upstream.test/chat", + "headers": { + "x-token": str(auth_info["token"]), + "x-token-source": str(auth_info["token_source"]), + "x-guest-user-id": str(auth_info.get("guest_user_id") or ""), + }, + "body": {"model": request.model}, + "token": auth_info["token"], + "chat_id": "chat-id", + "model": request.model, + "user_id": auth_info["user_id"], + "auth_mode": auth_info["auth_mode"], + "token_source": auth_info["token_source"], + "guest_user_id": auth_info["guest_user_id"], + } + + async def fake_transform_response(self, response, request, transformed): + return { + "ok": response.is_success, + "token_source": transformed["token_source"], + "token": transformed["token"], + "guest_user_id": transformed["guest_user_id"], + } + + client.transform_request = types.MethodType(fake_transform_request, client) + client.transform_response = types.MethodType(fake_transform_response, client) + + +async def _run_chat_requests(client: UpstreamClient, count: int) -> list[dict]: + tasks = [client.chat_completion(_make_request()) for _ in range(count)] + return await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def test_authenticated_tokens_are_used_before_guest_pool(monkeypatch): + token_pool = StubTokenPool(["auth-1"]) + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=GUEST_POOL_SIZE, + user_ids=["guest-1", "guest-2"], + ) + captures: list[dict] = [] + acquire_calls = 0 + + async def counted_acquire(*args, **kwargs): + nonlocal acquire_calls + acquire_calls += 1 + return await original_acquire(*args, **kwargs) + + async def handler(headers): + await asyncio.sleep(MIXED_REQUEST_DELAY) + return FakeResponse(200) + + client = UpstreamClient() + original_acquire = guest_pool.acquire + monkeypatch.setattr(guest_pool, "acquire", counted_acquire) + _bind_minimal_request_flow(client, captures) + _patch_upstream_dependencies( + monkeypatch, + token_pool=token_pool, + guest_pool=guest_pool, + async_client_cls=_build_fake_async_client(handler), + ) + + try: + results = await _run_chat_requests(client, AUTH_REQUEST_COUNT) + pool_status = guest_pool.get_pool_status() + finally: + await guest_pool.close() + + assert all(result["ok"] is True for result in results) + assert all(item["token_source"] == "auth_pool" for item in captures) + assert acquire_calls == 0 + assert token_pool.success_tokens == ["auth-1"] * AUTH_REQUEST_COUNT + assert token_pool.failure_tokens == [] + assert pool_status["busy_sessions"] == 0 + assert pool_status["available_sessions"] == GUEST_POOL_SIZE + + +@pytest.mark.asyncio +async def test_authenticated_401_retries_next_token_before_guest_fallback(monkeypatch): + token_pool = StubTokenPool(["auth-1", "auth-2"]) + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=GUEST_POOL_SIZE, + user_ids=["guest-1", "guest-2"], + ) + captures: list[dict] = [] + acquire_calls = 0 + + async def counted_acquire(*args, **kwargs): + nonlocal acquire_calls + acquire_calls += 1 + return await original_acquire(*args, **kwargs) + + async def handler(headers): + token = headers["x-token"] + if token == "auth-1": + return FakeResponse(401, '{"message":"expired"}') + return FakeResponse(200) + + client = UpstreamClient() + original_acquire = guest_pool.acquire + monkeypatch.setattr(guest_pool, "acquire", counted_acquire) + _bind_minimal_request_flow(client, captures) + _patch_upstream_dependencies( + monkeypatch, + token_pool=token_pool, + guest_pool=guest_pool, + async_client_cls=_build_fake_async_client(handler), + ) + + try: + result = await client.chat_completion(_make_request()) + finally: + await guest_pool.close() + + assert result["ok"] is True + assert [item["token"] for item in captures] == ["auth-1", "auth-2"] + assert [item["token_source"] for item in captures] == ["auth_pool", "auth_pool"] + assert token_pool.failure_tokens == ["auth-1"] + assert token_pool.success_tokens == ["auth-2"] + assert acquire_calls == 0 + + +@pytest.mark.asyncio +async def test_authenticated_pool_exhaustion_falls_back_to_guest(monkeypatch): + token_pool = StubTokenPool(["auth-1", "auth-2"]) + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=GUEST_POOL_SIZE, + user_ids=["guest-1", "guest-2", "guest-3"], + ) + captures: list[dict] = [] + + async def handler(headers): + if headers["x-token-source"] == "auth_pool": + return FakeResponse(401, '{"message":"expired"}') + return FakeResponse(200) + + client = UpstreamClient() + _bind_minimal_request_flow(client, captures) + _patch_upstream_dependencies( + monkeypatch, + token_pool=token_pool, + guest_pool=guest_pool, + async_client_cls=_build_fake_async_client(handler), + ) + + try: + result = await client.chat_completion(_make_request()) + pool_status = guest_pool.get_pool_status() + finally: + await guest_pool.close() + + assert result["ok"] is True + assert [item["token_source"] for item in captures] == [ + "auth_pool", + "auth_pool", + "guest_pool", + ] + assert token_pool.failure_tokens == ["auth-1", "auth-2"] + assert token_pool.success_tokens == [] + assert result["guest_user_id"] + assert pool_status["busy_sessions"] == 0 + + +@pytest.mark.asyncio +async def test_guest_retry_is_isolated_and_does_not_pollute_auth_stats(monkeypatch): + token_pool = StubTokenPool(["auth-1", "auth-2"]) + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=GUEST_POOL_SIZE, + user_ids=["guest-1", "guest-2", "guest-3", "guest-4"], + ) + captures: list[dict] = [] + + async def handler(headers): + source = headers["x-token-source"] + guest_user_id = headers["x-guest-user-id"] + if source == "auth_pool": + return FakeResponse(401, '{"message":"expired"}') + if guest_user_id == "guest-1": + return FakeResponse(401, '{"message":"expired"}') + return FakeResponse(200) + + client = UpstreamClient() + _bind_minimal_request_flow(client, captures) + _patch_upstream_dependencies( + monkeypatch, + token_pool=token_pool, + guest_pool=guest_pool, + async_client_cls=_build_fake_async_client(handler), + ) + + try: + result = await client.chat_completion(_make_request()) + pool_status = guest_pool.get_pool_status() + finally: + await guest_pool.close() + + guest_ids = [ + item["guest_user_id"] + for item in captures + if item["token_source"] == "guest_pool" + ] + + assert result["ok"] is True + assert [item["token"] for item in captures[:2]] == ["auth-1", "auth-2"] + assert token_pool.failure_tokens == ["auth-1", "auth-2"] + assert token_pool.success_tokens == [] + assert guest_ids[0] == "guest-1" + assert guest_ids[1] != "guest-1" + assert pool_status["busy_sessions"] == 0 + assert pool_status["valid_sessions"] == GUEST_POOL_SIZE + + +@pytest.mark.asyncio +async def test_cleanup_idle_chats_only_touches_idle_valid_sessions(monkeypatch): + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=3, + user_ids=["guest-1", "guest-2", "guest-3"], + ) + deleted_user_ids: list[str] = [] + + async def fake_delete_all_chats(session: GuestSession): + deleted_user_ids.append(session.user_id) + return True + + monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats) + guest_pool._sessions["guest-2"].active_requests = 1 + + try: + await guest_pool.cleanup_idle_chats() + deleted_before_close = list(deleted_user_ids) + finally: + await guest_pool.close() + + assert deleted_before_close == ["guest-1", "guest-3"] + + +@pytest.mark.asyncio +async def test_report_failure_only_retires_target_guest_session(monkeypatch): + guest_pool = await _build_guest_pool( + monkeypatch, + pool_size=3, + user_ids=["guest-1", "guest-2", "guest-3", "guest-4"], + ) + deleted_user_ids: list[str] = [] + + async def fake_delete_all_chats(session: GuestSession): + deleted_user_ids.append(session.user_id) + return True + + monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats) + + try: + await guest_pool.report_failure("guest-1") + await asyncio.sleep(0) + current_user_ids = set(guest_pool._sessions) + deleted_before_close = list(deleted_user_ids) + finally: + await guest_pool.close() + + assert "guest-1" not in current_user_ids + assert "guest-2" in current_user_ids + assert "guest-3" in current_user_ids + assert "guest-4" in current_user_ids + assert deleted_before_close == ["guest-1"]