diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..87b12e9227ca873988067e17d9bf5b9a026070bc
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,72 @@
+# 代理服务配置文件示例
+# 复制此文件为 .env 并根据需要修改配置值
+
+# ========== API 基础配置 ==========
+# 客户端访问本服务使用的 Bearer 密钥,不是上游 Z.AI 用户 Token
+# 上游用户 Token 请在管理后台导入,由数据库 Token 池统一管理
+AUTH_TOKEN=sk-your-api-key
+
+# 跳过客户端认证(仅开发环境使用)
+SKIP_AUTH_TOKEN=false
+
+# ========== 用户 Token 池配置 ==========
+# 仅作用于管理后台导入的 Z.AI 用户 Token
+# 失败多少次后标记为不可用
+TOKEN_FAILURE_THRESHOLD=3
+
+# 失败 Token 多久后重新参与调度(秒)
+TOKEN_RECOVERY_TIMEOUT=1800
+
+# 定时扫描服务端目录导入 Token
+TOKEN_AUTO_IMPORT_ENABLED=false
+
+# 自动导入的服务端本地目录
+TOKEN_AUTO_IMPORT_SOURCE_DIR=
+
+# 自动导入扫描间隔(秒)
+TOKEN_AUTO_IMPORT_INTERVAL=300
+
+# 定时维护 Token 池
+TOKEN_AUTO_MAINTENANCE_ENABLED=false
+
+# 自动维护执行间隔(秒)
+TOKEN_AUTO_MAINTENANCE_INTERVAL=1800
+
+# 自动维护动作开关
+TOKEN_AUTO_REMOVE_DUPLICATES=true
+TOKEN_AUTO_HEALTH_CHECK=true
+TOKEN_AUTO_DELETE_INVALID=false
+
+# ========== 匿名 Guest 会话池 ==========
+# false: 禁用 guest 匿名池,仅使用后台导入的用户 Token 池
+# true: 启用 guest 匿名池;当没有可用用户 Token 时允许匿名会话
+ANONYMOUS_MODE=true
+
+# 预热和维持的 guest 会话数量
+GUEST_POOL_SIZE=10
+
+# ========== 服务器配置 ==========
+LISTEN_PORT=8080
+SERVICE_NAME=api-proxy-server
+DEBUG_LOGGING=false
+
+# Nginx 反向代理路径前缀(可选)
+ROOT_PATH=
+
+# Function Call 功能开关
+TOOL_SUPPORT=true
+
+# 工具调用扫描限制(字符数)
+SCAN_LIMIT=200000
+
+# SQLite 数据库路径
+DB_PATH=tokens.db
+
+# ========== 代理配置 ==========
+# HTTP_PROXY=http://127.0.0.1:7890
+# HTTPS_PROXY=http://127.0.0.1:7890
+# SOCKS5_PROXY=socks5://127.0.0.1:1080
+
+# ========== 管理后台认证 ==========
+ADMIN_PASSWORD=admin123
+SESSION_SECRET_KEY=your-secret-key-change-in-production
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..dfe0770424b2a19faf507a501ebfc23be8f54e7b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000000000000000000000000000000000000..94e5dc92561ccaca2774fcbbeea5b77b193296ae
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,64 @@
+name: Build and Push Docker Image
+
+on:
+ push:
+ branches:
+ - main
+ tags:
+ - 'v*'
+
+env:
+ IMAGE_NAME: z-ai2api-python
+
+jobs:
+ docker:
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Login to GitHub Container Registry
+ uses: docker/login-action@v3
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Login to Docker Hub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Extract metadata
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ ghcr.io/${{ github.repository }}
+ ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}
+ tags: |
+ type=ref,event=branch
+ type=semver,pattern={{version}}
+ type=semver,pattern={{major}}.{{minor}}
+ type=raw,value=latest,enable={{is_default_branch}}
+
+ - name: Build and push
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ file: ./deploy/Dockerfile
+ platforms: linux/amd64,linux/arm64
+ push: true
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+ cache-from: type=gha
+ cache-to: type=gha,mode=max
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..69f8a8698cf4db10d62e7b703b535076ad1b1796
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,181 @@
+# Custom
+.vs/
+.vscode/
+.idea/
+.conda/
+*.zip
+*.txt
+*.pid
+docs/
+output/
+main.build/
+main.dist/
+main.onefile-build/
+*report.xml
+*.yaml
+logs/
+backup/
+uv.lock
+AGENTS.md
+*.db
+
+# AI Toolset
+.augment/
+.cursor/
+.claude/
+CLAUDE.md
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+.ace-tool/
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e8f5578f2e26366d0f292a785aa952035f3d40b7
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,30 @@
+FROM python:3.12-slim
+
+# Set environment variables
+ENV LISTEN_PORT=7860
+ENV DB_PATH=/app/data/tokens.db
+ENV PYTHONUNBUFFERED=1
+
+# Set working directory
+WORKDIR /app
+
+# Create data and logs directories and set permissions
+# HF Spaces runs as user 1000, so we make sure it can write to these directories
+RUN mkdir -p /app/data /app/logs && \
+ chmod -R 777 /app/data /app/logs
+
+# Install dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Ensure all files are accessible
+RUN chmod -R 777 /app
+
+# Expose port
+EXPOSE 7860
+
+# Run the application
+CMD ["python", "main.py"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d86c4ab26f3c0bdf7eabcbf887e5bfc0b9653212
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 ZyphrZero
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5f06ada691a9bef5f03ccde5e079b008f1e087f5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,132 @@
+---
+title: Z.ai API
+emoji: 🚀
+colorFrom: blue
+colorTo: indigo
+sdk: docker
+app_port: 7860
+---
+
+# z-ai2api_python
+
+基于 FastAPI + Granian 的 GLM 代理服务
+适合本地开发、自托管代理、Token 池管理和兼容客户端接入
+
+中文简体 / [English](README_EN.md)
+
+## 特性
+
+- 兼容 `OpenAI`、`Claude Code`、`Anthropic` 风格请求
+- 支持流式响应、工具调用、Thinking 模型
+- 内置 Token 池,支持轮询、失败熔断、恢复和健康检查
+- 提供后台页面:仪表盘、Token 管理、配置管理、实时日志
+- 使用 SQLite 存储 Token 和请求日志,部署简单
+- 支持本地运行和 Docker / Docker Compose 部署
+
+## 快速开始
+
+### 环境要求
+
+- Python `3.9` 到 `3.12`
+- 推荐使用 `uv`
+
+### 本地启动
+
+```bash
+git clone https://github.com/ZyphrZero/z.ai2api_python.git
+cd z.ai2api_python
+
+uv sync
+cp .env.example .env
+uv run python main.py
+```
+
+首次启动会自动初始化数据库。
+
+默认地址:
+
+- API 根路径:`http://127.0.0.1:8080`
+- OpenAI 文档:`http://127.0.0.1:8080/docs`
+- 管理后台:`http://127.0.0.1:8080/admin`
+
+### Docker Compose
+
+```bash
+docker compose -f deploy/docker-compose.yml up -d --build
+```
+
+更多部署说明见 [deploy/README_DOCKER.md](deploy/README_DOCKER.md)。
+
+## 最小配置
+
+至少建议确认这些环境变量:
+
+| 变量 | 说明 |
+| --- | --- |
+| `AUTH_TOKEN` | 客户端访问本服务使用的 Bearer Token |
+| `ADMIN_PASSWORD` | 管理后台登录密码,默认值必须修改 |
+| `LISTEN_PORT` | 服务监听端口,默认 `8080` |
+| `ANONYMOUS_MODE` | 是否启用匿名模式 |
+| `GUEST_POOL_SIZE` | 匿名池容量 |
+| `DB_PATH` | SQLite 数据库路径 |
+| `TOKEN_FAILURE_THRESHOLD` | Token 连续失败阈值 |
+| `TOKEN_RECOVERY_TIMEOUT` | Token 恢复等待时间 |
+
+完整配置请看 [.env.example](.env.example)。
+
+## 管理后台
+
+管理后台统一入口:
+
+- `/admin`:仪表盘
+- `/admin/tokens`:Token 管理
+- `/admin/config`:配置管理
+- `/admin/logs`:实时日志
+
+## 常用命令
+
+```bash
+# 启动服务
+uv run python main.py
+
+# 运行测试
+uv run pytest
+
+# 运行一个现有 smoke test
+uv run python tests/test_simple_signature.py
+
+# Lint
+uv run ruff check app tests main.py
+```
+
+## 兼容接口
+
+常见接口入口:
+
+- OpenAI 兼容:`/v1/chat/completions`
+- Anthropic 兼容:`/v1/messages`
+- Claude Code 兼容:`/anthropic/v1/messages`
+
+模型映射和默认模型可在 `.env` 或后台配置页中调整。
+
+## ⭐ Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=ZyphrZero/z.ai2api_python&type=Date)](https://star-history.com/#ZyphrZero/z.ai2api_python&Date)
+
+## 许可证
+
+本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件。
+
+## 免责声明
+
+- **本项目仅供学习和研究使用,切勿用于其他用途**
+- 本项目与 Z.AI 官方无关
+- 使用前请确保遵守 Z.AI 的服务条款
+- 请勿用于商业用途或违反使用条款的场景
+- 用户需自行承担使用风险
+
+---
+
+
+Made with ❤️ by the community
+
diff --git a/README_EN.md b/README_EN.md
new file mode 100644
index 0000000000000000000000000000000000000000..cbfde08e41dc433937c861a06248f0911b6f8c32
--- /dev/null
+++ b/README_EN.md
@@ -0,0 +1,123 @@
+# z-ai2api_python
+
+GLM proxy service based on FastAPI + Granian
+Suitable for local development, self-hosted proxy, Token pool management, and compatible client access
+
+English / [中文简体](README.md)
+
+## Features
+
+- Compatible with `OpenAI`, `Claude Code`, `Anthropic` style requests
+- Supports streaming responses, tool calls, Thinking models
+- Built-in Token pool with round-robin rotation, failure circuit breaking, recovery, and health checks
+- Provides admin panel: Dashboard, Token management, Configuration management, Real-time logs
+- Uses SQLite to store Tokens and request logs, simple deployment
+- Supports local running and Docker / Docker Compose deployment
+
+## Quick Start
+
+### Environment Requirements
+
+- Python `3.9` to `3.12`
+- `uv` is recommended
+
+### Local Startup
+
+```bash
+git clone https://github.com/ZyphrZero/z.ai2api_python.git
+cd z.ai2api_python
+
+uv sync
+cp .env.example .env
+uv run python main.py
+```
+
+First startup will automatically initialize the database.
+
+Default addresses:
+
+- API root path: `http://127.0.0.1:8080`
+- OpenAI docs: `http://127.0.0.1:8080/docs`
+- Admin panel: `http://127.0.0.1:8080/admin`
+
+### Docker Compose
+
+```bash
+docker compose -f deploy/docker-compose.yml up -d --build
+```
+
+For more deployment instructions, see [deploy/README_DOCKER.md](deploy/README_DOCKER.md).
+
+## Minimum Configuration
+
+At a minimum, review these environment variables:
+
+| Variable | Description |
+| --- | --- |
+| `AUTH_TOKEN` | Bearer Token used by clients to access this service |
+| `ADMIN_PASSWORD` | Admin panel login password, default value must be changed |
+| `LISTEN_PORT` | Service listening port, default `8080` |
+| `ANONYMOUS_MODE` | Whether to enable anonymous mode |
+| `GUEST_POOL_SIZE` | Anonymous pool capacity |
+| `DB_PATH` | SQLite database path |
+| `TOKEN_FAILURE_THRESHOLD` | Token consecutive failure threshold |
+| `TOKEN_RECOVERY_TIMEOUT` | Token recovery wait time |
+
+For the complete configuration, see [.env.example](.env.example).
+
+## Admin Panel
+
+Admin panel unified entry:
+
+- `/admin`: Dashboard
+- `/admin/tokens`: Token management
+- `/admin/config`: Configuration management
+- `/admin/logs`: Real-time logs
+
+## Common Commands
+
+```bash
+# Start service
+uv run python main.py
+
+# Run tests
+uv run pytest
+
+# Run an existing smoke test
+uv run python tests/test_simple_signature.py
+
+# Lint
+uv run ruff check app tests main.py
+```
+
+## Compatible Interfaces
+
+Common interface entries:
+
+- OpenAI compatible: `/v1/chat/completions`
+- Anthropic compatible: `/v1/messages`
+- Claude Code compatible: `/anthropic/v1/messages`
+
+Model mapping and default model can be adjusted in `.env` or admin configuration page.
+
+## ⭐ Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=ZyphrZero/z.ai2api_python&type=Date)](https://star-history.com/#ZyphrZero/z.ai2api_python&Date)
+
+## License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## Disclaimer
+
+- **This project is for learning and research use only, do not use for other purposes**
+- This project is not affiliated with Z.AI
+- Please ensure compliance with Z.AI's terms of service before use
+- Do not use for commercial purposes or scenarios that violate terms of service
+- Users must bear their own usage risks
+
+---
+
+
+Made with ❤️ by the community
+
\ No newline at end of file
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..101e6a38bb631969e65ff40b512eb45668947047
--- /dev/null
+++ b/app/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from app import core, models, utils
+
+__all__ = ["core", "models", "utils"]
diff --git a/app/admin/__init__.py b/app/admin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36cacc8ac4ca1ac050d6d7f340c82fc4de024519
--- /dev/null
+++ b/app/admin/__init__.py
@@ -0,0 +1,3 @@
+"""
+管理后台模块初始化
+"""
diff --git a/app/admin/api.py b/app/admin/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6b040f3fe786984d9ad4d8b05cbba841049f1e
--- /dev/null
+++ b/app/admin/api.py
@@ -0,0 +1,1111 @@
+"""
+管理后台 API 接口
+用于 htmx 调用的 HTML 片段返回
+"""
+from datetime import datetime
+from html import escape
+from pathlib import Path
+import re
+from typing import Optional
+
+from fastapi import APIRouter, Depends, Request
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.templating import Jinja2Templates
+
+from app.admin.auth import require_auth
+from app.admin.config_manager import (
+ read_env_content,
+ reset_env_to_example,
+ save_form_config,
+ save_source_config,
+)
+from app.admin.stats import collect_admin_stats, normalize_trend_window
+from app.services.request_log_dao import get_request_log_dao
+from app.utils.logger import logger
+
+# Router for the admin endpoints in this module; every route is served under /admin/api.
+router = APIRouter(prefix="/admin/api", tags=["admin-api"])
+# Jinja2 template loader used to render the HTML fragments returned by the routes below.
+templates = Jinja2Templates(directory="app/templates")
+# Provider namespace passed to the token and request-log DAOs by the routes in this module.
+DEFAULT_TOKEN_NAMESPACE = "zai"
+
+
+# ==================== 认证 API ====================
+
+@router.post("/login")
+async def login(request: Request):
+ """管理后台登录"""
+ from app.admin.auth import create_session
+
+ try:
+ data = await request.json()
+ password = data.get("password", "")
+
+ # 创建 session
+ session_token = create_session(password)
+
+ if session_token:
+ # 登录成功,设置 cookie
+ response = JSONResponse({
+ "success": True,
+ "message": "登录成功"
+ })
+ response.set_cookie(
+ key="admin_session",
+ value=session_token,
+ httponly=True,
+ max_age=86400, # 24小时
+ samesite="lax"
+ )
+ logger.info("✅ 管理后台登录成功")
+ return response
+ else:
+ # 密码错误
+ logger.warning("❌ 管理后台登录失败:密码错误")
+ return JSONResponse({
+ "success": False,
+ "message": "密码错误"
+ }, status_code=401)
+
+ except Exception as e:
+ logger.error(f"❌ 登录异常: {e}")
+ return JSONResponse({
+ "success": False,
+ "message": "登录失败"
+ }, status_code=500)
+
+
+@router.post("/logout")
+async def logout(request: Request):
+ """管理后台登出"""
+ from app.admin.auth import delete_session, get_session_token_from_request
+
+ session_token = get_session_token_from_request(request)
+ delete_session(session_token)
+
+ # 清除 cookie
+ response = JSONResponse({
+ "success": True,
+ "message": "已登出"
+ })
+ response.delete_cookie("admin_session")
+ logger.info("✅ 管理后台已登出")
+ return response
+
+
+async def reload_settings():
+ """热重载配置(重新加载环境变量并更新 settings 对象)"""
+ from dotenv import load_dotenv
+
+ from app.core.config import settings
+ from app.utils.logger import setup_logger
+
+ # 重新加载 .env 文件
+ load_dotenv(override=True)
+
+ # 重新创建 Settings 对象并更新全局配置
+ new_settings = type(settings)()
+
+ # 更新全局 settings 的所有属性
+ for field_name in new_settings.model_fields.keys():
+ setattr(settings, field_name, getattr(new_settings, field_name))
+
+ # 重新初始化 logger(使用新的 DEBUG_LOGGING 配置)
+ setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING)
+
+ logger.info(f"🔄 配置已热重载 (DEBUG_LOGGING={settings.DEBUG_LOGGING})")
+
+
+def _build_alert(
+ message: str,
+ *,
+ title: str,
+ level: str,
+ status_code: int = 200,
+) -> HTMLResponse:
+ """Build an HTML alert fragment from an escaped title and message.
+
+ Args:
+ message: Alert body text; HTML-escaped before insertion.
+ title: Alert heading text; HTML-escaped before insertion.
+ level: One of "success", "warning", "error" or "info"; any other
+ value falls back to the "info" classes.
+ status_code: HTTP status of the returned response (default 200).
+
+ Returns:
+ An HTMLResponse carrying the fragment.
+ """
+ # CSS utility classes for each supported alert level.
+ level_classes = {
+ "success": "bg-green-100 border-green-400 text-green-700",
+ "warning": "bg-yellow-100 border-yellow-400 text-yellow-700",
+ "error": "bg-red-100 border-red-400 text-red-700",
+ "info": "bg-blue-100 border-blue-400 text-blue-700",
+ }
+ # NOTE(review): `classes` is not referenced by the markup visible below - confirm the fragment still applies it.
+ classes = level_classes.get(level, level_classes["info"])
+ # Escape caller-supplied text so it cannot inject markup into the fragment.
+ safe_title = escape(title)
+ safe_message = escape(message)
+ return HTMLResponse(
+ f"""
+
+ {safe_title}
+ {safe_message}
+
+ """,
+ status_code=status_code,
+ )
+
+
+def _with_hx_trigger(response: HTMLResponse, event_name: str) -> HTMLResponse:
+ response.headers["HX-Trigger"] = event_name
+ return response
+
+
+def _get_int_query_param(
+ request: Request,
+ name: str,
+ default: int,
+ *,
+ minimum: int = 1,
+ maximum: Optional[int] = None,
+) -> int:
+ """解析查询参数中的正整数,非法值回退到默认值。"""
+ raw_value = request.query_params.get(name)
+ if raw_value is None:
+ return default
+
+ try:
+ value = int(str(raw_value).strip())
+ except (TypeError, ValueError):
+ return default
+
+ value = max(minimum, value)
+ if maximum is not None:
+ value = min(value, maximum)
+ return value
+
+
+def _build_pagination(
+ *,
+ total_items: int,
+ page: int,
+ page_size: int,
+) -> dict:
+ """构建分页上下文。"""
+ total_items = max(0, int(total_items))
+ page_size = max(1, int(page_size))
+ total_pages = max(1, (total_items + page_size - 1) // page_size)
+ current_page = min(max(1, int(page)), total_pages)
+
+ if total_items == 0:
+ start_item = 0
+ end_item = 0
+ else:
+ start_item = (current_page - 1) * page_size + 1
+ end_item = min(total_items, current_page * page_size)
+
+ return {
+ "current_page": current_page,
+ "page_size": page_size,
+ "total_items": total_items,
+ "total_pages": total_pages,
+ "has_previous": current_page > 1,
+ "has_next": current_page < total_pages,
+ "previous_page": max(1, current_page - 1),
+ "next_page": min(total_pages, current_page + 1),
+ "start_item": start_item,
+ "end_item": end_item,
+ }
+
+
+def _normalize_display_value(value: str) -> str:
+ normalized = re.sub(r"[^a-z0-9]+", "", str(value or "").casefold())
+ return normalized
+
+
+def _is_redundant_source(source: str, client_name: str) -> bool:
+ normalized_source = _normalize_display_value(source)
+ normalized_client = _normalize_display_value(client_name)
+ if not normalized_source:
+ return True
+ if not normalized_client:
+ return False
+ return normalized_source == normalized_client
+
+
+def _humanize_protocol(protocol: str) -> str:
+ normalized = str(protocol or "").strip().lower()
+ if normalized == "openai":
+ return "OpenAI"
+ if normalized == "anthropic":
+ return "Anthropic"
+ if normalized == "unknown":
+ return "Unknown"
+ return normalized or "Unknown"
+
+
+@router.get(
+ "/dashboard/usage-trend",
+ response_class=JSONResponse,
+ dependencies=[Depends(require_auth)],
+)
+async def get_dashboard_usage_trend(request: Request):
+ """返回仪表盘趋势图数据。"""
+ trend_window = normalize_trend_window(
+ request.query_params.get("window")
+ )
+ dao = get_request_log_dao()
+ trend_points = await dao.get_provider_usage_trend(
+ DEFAULT_TOKEN_NAMESPACE,
+ window=trend_window,
+ )
+ return JSONResponse(
+ {
+ "window": trend_window,
+ "points": trend_points,
+ }
+ )
+
+
+def _validate_directory_path(source_dir: str) -> str:
+ if not source_dir:
+ raise ValueError("请先填写服务端可访问的本地目录路径。")
+
+ source_path = Path(source_dir).expanduser()
+ if not source_path.exists():
+ raise ValueError(f"导入目录不存在: {source_path}")
+ if not source_path.is_dir():
+ raise ValueError(f"导入路径不是目录: {source_path}")
+
+ return str(source_path)
+
+
+@router.get("/token-pool", response_class=HTMLResponse)
+async def get_token_pool_status(request: Request):
+ """获取 Token 池状态(HTML 片段)"""
+ from app.utils.token_pool import get_token_pool
+
+ token_pool = get_token_pool()
+
+ if not token_pool:
+ # Token 池未初始化
+ context = {
+ "request": request,
+ "tokens": [],
+ }
+ return templates.TemplateResponse("components/token_pool.html", context)
+
+ # 获取 token 状态统计
+ pool_status = token_pool.get_pool_status()
+ tokens_info = []
+
+ for idx, token_info in enumerate(pool_status.get("tokens", []), 1):
+ is_available = token_info.get("is_available", False)
+ is_healthy = token_info.get("is_healthy", False)
+
+ # 确定状态和颜色
+ if is_healthy:
+ status = "健康"
+ status_color = "bg-green-100 text-green-800"
+ elif is_available:
+ status = "可用"
+ status_color = "bg-yellow-100 text-yellow-800"
+ else:
+ status = "失败"
+ status_color = "bg-red-100 text-red-800"
+
+ # 格式化最后使用时间
+ last_success = token_info.get("last_success_time", 0)
+ if last_success > 0:
+ from datetime import datetime
+ last_used = datetime.fromtimestamp(last_success).strftime("%Y-%m-%d %H:%M:%S")
+ else:
+ last_used = "从未使用"
+
+ tokens_info.append({
+ "index": idx,
+ "key": token_info.get("token", "")[:20] + "...",
+ "status": status,
+ "status_color": status_color,
+ "last_used": last_used,
+ "failure_count": token_info.get("failure_count", 0),
+ "success_rate": token_info.get("success_rate", "0%"),
+ "token_type": token_info.get("token_type", "unknown"),
+ })
+
+ context = {
+ "request": request,
+ "tokens": tokens_info,
+ }
+
+ return templates.TemplateResponse("components/token_pool.html", context)
+
+
+@router.get("/recent-logs", response_class=HTMLResponse)
+async def get_recent_logs(request: Request):
+ """获取最近的请求日志(HTML 片段)"""
+ dao = get_request_log_dao()
+ page_size = _get_int_query_param(
+ request,
+ "page_size",
+ 12,
+ maximum=50,
+ )
+ requested_page = _get_int_query_param(request, "page", 1, maximum=100000)
+ total_count = await dao.count_logs()
+ pagination = _build_pagination(
+ total_items=total_count,
+ page=requested_page,
+ page_size=page_size,
+ )
+
+ rows = await dao.get_recent_logs(
+ limit=page_size,
+ offset=(pagination["current_page"] - 1) * page_size,
+ )
+ logs = []
+ for row in rows:
+ timestamp = (
+ row.get("timestamp")
+ or row.get("created_at")
+ or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ )
+ success = bool(row.get("success"))
+ status_code = int(
+ row.get("status_code") or (200 if success else 500)
+ )
+ duration_value = float(row.get("duration") or 0.0)
+ first_token_value = float(row.get("first_token_time") or 0.0)
+ source = row.get("source") or "unknown"
+ client_name = row.get("client_name") or "Unknown"
+ provider = row.get("provider") or "-"
+ source_display = (
+ ""
+ if _is_redundant_source(source, client_name)
+ else source
+ )
+ provider_display = "" if provider == "zai" else provider
+ logs.append(
+ {
+ "timestamp": timestamp,
+ "endpoint": row.get("endpoint") or "-",
+ "model": row.get("model") or "-",
+ "provider": provider,
+ "provider_display": provider_display,
+ "source": source,
+ "source_display": source_display,
+ "protocol": row.get("protocol") or "unknown",
+ "protocol_display": _humanize_protocol(
+ row.get("protocol") or "unknown"
+ ),
+ "client_name": client_name,
+ "success": success,
+ "status_code": status_code,
+ "duration_display": f"{duration_value:.2f}s",
+ "first_token_display": (
+ f"{first_token_value:.2f}s"
+ if first_token_value > 0
+ else "--"
+ ),
+ "input_tokens": int(row.get("input_tokens") or 0),
+ "output_tokens": int(row.get("output_tokens") or 0),
+ "cache_creation_tokens": int(
+ row.get("cache_creation_tokens") or 0
+ ),
+ "cache_read_tokens": int(
+ row.get("cache_read_tokens") or 0
+ ),
+ "error_message": row.get("error_message") or "",
+ }
+ )
+
+ context = {
+ "request": request,
+ "logs": logs,
+ "page": pagination,
+ }
+
+ return templates.TemplateResponse("components/recent_logs.html", context)
+
+
+@router.post("/config/save", dependencies=[Depends(require_auth)])
+async def save_config(request: Request):
+ """保存结构化配置并热重载。"""
+ try:
+ form_data = await request.form()
+ await save_form_config(
+ form_data,
+ reload_callback=reload_settings,
+ )
+ logger.info("✅ 结构化配置已保存")
+ return _with_hx_trigger(
+ _build_alert(
+ "配置已保存并热重载,页面即将刷新。",
+ title="保存成功!",
+ level="success",
+ ),
+ "admin-config-refresh",
+ )
+ except ValueError as exc:
+ return _build_alert(
+ str(exc),
+ title="校验失败!",
+ level="error",
+ status_code=400,
+ )
+ except Exception as exc:
+ logger.error(f"❌ 配置保存失败: {exc}")
+ return _build_alert(
+ f"保存失败: {exc}",
+ title="错误!",
+ level="error",
+ status_code=500,
+ )
+
+
+@router.post("/config/source", dependencies=[Depends(require_auth)])
+async def save_config_source(request: Request):
+ """保存 .env 源文件并热重载。"""
+ try:
+ form_data = await request.form()
+ await save_source_config(
+ str(form_data.get("env_content", "")),
+ reload_callback=reload_settings,
+ )
+ logger.info("✅ 配置源文件已保存")
+ return _with_hx_trigger(
+ _build_alert(
+ ".env 源文件已保存并热重载,页面即将刷新。",
+ title="保存成功!",
+ level="success",
+ ),
+ "admin-config-refresh",
+ )
+ except ValueError as exc:
+ return _build_alert(
+ str(exc),
+ title="源文件校验失败!",
+ level="error",
+ status_code=400,
+ )
+ except Exception as exc:
+ logger.error(f"❌ 源文件保存失败: {exc}")
+ return _build_alert(
+ f"源文件保存失败: {exc}",
+ title="错误!",
+ level="error",
+ status_code=500,
+ )
+
+
+@router.post("/config/reset", dependencies=[Depends(require_auth)])
+async def reset_config():
+ """将配置重置为 .env.example 并热重载。"""
+ try:
+ await reset_env_to_example(reload_callback=reload_settings)
+ logger.info("✅ 配置已重置为 .env.example 默认值")
+ return _with_hx_trigger(
+ _build_alert(
+ "配置已恢复为 .env.example 默认值,页面即将刷新。",
+ title="已重置!",
+ level="success",
+ ),
+ "admin-config-refresh",
+ )
+ except FileNotFoundError:
+ logger.error("❌ 未找到 .env.example,无法重置配置")
+ return _build_alert(
+ "未找到 .env.example,无法重置配置。",
+ title="错误!",
+ level="error",
+ status_code=404,
+ )
+ except Exception as exc:
+ logger.error(f"❌ 配置重置失败: {exc}")
+ return _build_alert(
+ f"重置失败: {exc}",
+ title="错误!",
+ level="error",
+ status_code=500,
+ )
+
+
+@router.get("/env-preview", dependencies=[Depends(require_auth)])
+async def get_env_preview():
+ """Return the content of the .env file, HTML-escaped, for preview.
+
+ Any exception raised while reading is reported in the response body
+ instead of propagating.
+ """
+ try:
+ content = read_env_content()
+ # Placeholder shown when read_env_content() returns nothing.
+ if not content:
+ content = "# .env 文件不存在"
+ # Escape so the file content is displayed as text rather than parsed as markup.
+ return HTMLResponse(f"{escape(content)}")
+ except Exception as exc:
+ return HTMLResponse(f"# 读取失败: {escape(str(exc))}")
+
+
+@router.get("/live-logs", response_class=HTMLResponse)
+async def get_live_logs():
+ """获取实时日志(最新 50 行)"""
+ import os
+ from datetime import datetime
+
+ logs = []
+
+ # 尝试读取日志文件
+ log_dir = "logs"
+ if os.path.exists(log_dir):
+ log_files = sorted([f for f in os.listdir(log_dir) if f.endswith('.log')], reverse=True)
+ if log_files:
+ log_file = os.path.join(log_dir, log_files[0])
+ try:
+ with open(log_file, 'r', encoding='utf-8') as f:
+ # 读取最后 50 行
+ lines = f.readlines()[-50:]
+ logs = lines
+ except Exception as e:
+ logs = [f"# [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 读取日志失败: {str(e)}"]
+
+ if not logs:
+ logs = [f"# [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 暂无日志数据"]
+
+ html = ""
+ for log in logs:
+ log_line = log.strip()
+ if not log_line:
+ continue
+
+ # 根据日志级别设置颜色和样式
+ if "ERROR" in log_line or "CRITICAL" in log_line:
+ color_class = "text-red-400 font-semibold"
+ icon = "❌"
+ elif "WARNING" in log_line or "WARN" in log_line:
+ color_class = "text-yellow-400"
+ icon = "⚠️"
+ elif "SUCCESS" in log_line or "✅" in log_line:
+ color_class = "text-green-400"
+ icon = "✅"
+ elif "INFO" in log_line:
+ color_class = "text-blue-400"
+ icon = "ℹ️"
+ elif "DEBUG" in log_line:
+ color_class = "text-gray-400 text-xs"
+ icon = "🔍"
+ else:
+ color_class = "text-gray-300"
+ icon = "•"
+
+ # 转义 HTML 特殊字符
+ log_escaped = log_line.replace('<', '<').replace('>', '>')
+
+ html += f'{icon} {log_escaped}
'
+
+ return HTMLResponse(html)
+
+
+# ==================== Token 管理 API ====================
+
+@router.get("/tokens/list", response_class=HTMLResponse)
+async def get_tokens_list(request: Request):
+ """获取 Token 列表(HTML 片段)"""
+ from app.services.token_dao import get_token_dao
+
+ dao = get_token_dao()
+ page_size = _get_int_query_param(
+ request,
+ "page_size",
+ 20,
+ maximum=100,
+ )
+ requested_page = _get_int_query_param(request, "page", 1, maximum=100000)
+ total_count = await dao.count_tokens_by_provider(
+ DEFAULT_TOKEN_NAMESPACE,
+ enabled_only=False,
+ )
+ pagination = _build_pagination(
+ total_items=total_count,
+ page=requested_page,
+ page_size=page_size,
+ )
+ tokens = await dao.get_tokens_by_provider(
+ DEFAULT_TOKEN_NAMESPACE,
+ enabled_only=False,
+ limit=page_size,
+ offset=(pagination["current_page"] - 1) * page_size,
+ )
+
+ context = {
+ "request": request,
+ "tokens": tokens,
+ "page": pagination,
+ }
+
+ return templates.TemplateResponse("components/token_list.html", context)
+
+
@router.post("/tokens/add", dependencies=[Depends(require_auth)])
async def add_tokens(request: Request):
    """Add one token and/or a batch of tokens submitted through a form.

    Form fields:
        single_token: a single token string (optional).
        bulk_tokens: several tokens separated by newlines and/or commas
            (optional).

    Every token is stored with ``validate=True``. When at least one token was
    added the in-memory token pool is re-synchronised from the database.
    Returns an HTML fragment describing the outcome.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.services.token_dao import get_token_dao
    from app.utils.token_pool import get_token_pool

    form_data = await request.form()
    single_token = form_data.get("single_token", "").strip()
    bulk_tokens = form_data.get("bulk_tokens", "").strip()

    dao = get_token_dao()
    added_count = 0
    failed_count = 0

    # Single token (validated on insert)
    if single_token:
        token_id = await dao.add_token(
            DEFAULT_TOKEN_NAMESPACE,
            single_token,
            validate=True,
        )
        if token_id:
            added_count += 1
        else:
            failed_count += 1

    # Bulk tokens (validated on insert); newline and comma both separate
    # entries, blank entries are dropped.
    if bulk_tokens:
        tokens = [
            candidate.strip()
            for line in bulk_tokens.split('\n')
            for candidate in line.split(',')
            if candidate.strip()
        ]

        success, failed = await dao.bulk_add_tokens(
            DEFAULT_TOKEN_NAMESPACE,
            tokens,
            validate=True,
        )
        added_count += success
        failed_count += failed

    # Refresh the token pool only when something was actually added
    if added_count > 0:
        pool = get_token_pool()
        if pool:
            await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE)
            logger.info(f"✅ Token 池已同步,新增 {added_count} 个 Token")

    # Build the response fragment
    if added_count > 0 and failed_count == 0:
        return HTMLResponse(f"""

        成功!
        已添加 {added_count} 个有效 Token

        """)
    elif added_count > 0 and failed_count > 0:
        return HTMLResponse(f"""

        部分成功!
        已添加 {added_count} 个 Token,{failed_count} 个失败(可能是重复、无效或匿名 Token)

        """)
    else:
        return HTMLResponse("""

        失败!
        所有 Token 添加失败(可能是重复、无效或匿名 Token)

        """)
+
+
@router.post("/tokens/import-directory", dependencies=[Depends(require_auth)])
async def import_tokens_from_directory_api(request: Request):
    """Import token files from a directory on the server.

    The directory comes from the ``source_dir`` form field, falling back to
    ``settings.TOKEN_AUTO_IMPORT_SOURCE_DIR``. Validation and import errors
    are turned into alert fragments with status 400, 409 or 500; otherwise
    an alert summarising the import counters is returned.
    """
    from app.core.config import settings
    from app.services.token_automation import run_directory_import

    form = await request.form()
    requested_dir = (
        form.get("source_dir")
        or settings.TOKEN_AUTO_IMPORT_SOURCE_DIR
        or ""
    )
    try:
        source_dir = _validate_directory_path(str(requested_dir).strip())
    except ValueError as exc:
        return _build_alert(
            str(exc),
            title="导入失败!",
            level="error",
            status_code=400,
        )

    try:
        summary = await run_directory_import(
            source_dir,
            provider=DEFAULT_TOKEN_NAMESPACE,
            validate=True,
        )
    except (FileNotFoundError, NotADirectoryError) as exc:
        return _build_alert(
            str(exc),
            title="导入失败!",
            level="error",
            status_code=400,
        )
    except RuntimeError as exc:
        return _build_alert(
            str(exc),
            title="导入稍后重试",
            level="warning",
            status_code=409,
        )
    except Exception as exc:
        logger.exception(f"❌ 本地目录导入 Token 失败: {exc}")
        return _build_alert(
            f"目录扫描或入库异常: {exc}",
            title="导入失败!",
            level="error",
            status_code=500,
        )

    # Nothing imported: report only the reasons files were rejected
    if not summary.imported_count > 0:
        nothing_imported = (
            f"目录 {summary.source_dir} 共扫描 {summary.scanned_files} 个文件,"
            f"其中重复 {summary.duplicate_count} 个,无效 JSON {summary.invalid_json_count} 个,"
            f"缺少 token {summary.missing_token_count} 个,验证失败 {summary.invalid_token_count} 个。"
        )
        return _build_alert(
            nothing_imported,
            title="未导入任何 Token!",
            level="warning",
        )

    # At least one token imported: success if no failures, warning otherwise
    title = "导入成功!" if summary.failed_count == 0 else "导入完成!"
    detail = (
        f"目录 {summary.source_dir} 共扫描 {summary.scanned_files} 个文件,"
        f"成功导入 {summary.imported_count} 个 Token,"
        f"重复 {summary.duplicate_count} 个,"
        f"无效 JSON {summary.invalid_json_count} 个,"
        f"缺少 token {summary.missing_token_count} 个,"
        f"验证失败 {summary.invalid_token_count} 个。"
    )
    level = "success" if summary.failed_count == 0 else "warning"
    return _build_alert(detail, title=title, level=level)
+
+
@router.post("/tokens/auto-import/save", dependencies=[Depends(require_auth)])
async def save_auto_import_settings(request: Request):
    """Legacy endpoint kept for compatibility; points users to the config page."""
    notice = "自动导入配置入口已迁移到 /admin/config#tokens,当前页面仅保留手动执行入口。"
    return _build_alert(notice, title="入口已迁移", level="info")
+
+
@router.post("/tokens/maintenance/save", dependencies=[Depends(require_auth)])
async def save_auto_maintenance_settings(request: Request):
    """Legacy endpoint kept for compatibility; points users to the config page."""
    notice = "自动维护配置入口已迁移到 /admin/config#tokens,当前页面仅保留手动执行入口。"
    return _build_alert(notice, title="入口已迁移", level="info")
+
+
@router.post("/tokens/maintenance/run", dependencies=[Depends(require_auth)])
async def run_token_maintenance_api(request: Request):
    """Run one token-maintenance pass immediately.

    If the form carries any of the three action fields, exactly the submitted
    ones are enabled; otherwise the actions come from the configured
    ``TOKEN_AUTO_*`` settings. Returns an alert fragment with the result.
    """
    from app.core.config import settings
    from app.services.token_automation import run_token_maintenance

    form = await request.form()
    submitted = {
        name: name in form
        for name in (
            "auto_remove_duplicates",
            "auto_health_check",
            "auto_delete_invalid",
        )
    }

    if any(submitted.values()):
        # Explicit selection from the form wins over the stored settings
        remove_duplicates = submitted["auto_remove_duplicates"]
        run_health_check = submitted["auto_health_check"]
        delete_invalid = submitted["auto_delete_invalid"]
    else:
        remove_duplicates = settings.TOKEN_AUTO_REMOVE_DUPLICATES
        run_health_check = settings.TOKEN_AUTO_HEALTH_CHECK
        delete_invalid = settings.TOKEN_AUTO_DELETE_INVALID

    if not (remove_duplicates or run_health_check or delete_invalid):
        return _build_alert(
            "当前没有可执行的维护动作,请先到 /admin/config#tokens 配置至少一个维护动作。",
            title="未执行维护!",
            level="warning",
            status_code=400,
        )

    try:
        summary = await run_token_maintenance(
            provider=DEFAULT_TOKEN_NAMESPACE,
            remove_duplicates=remove_duplicates,
            run_health_check=run_health_check,
            delete_invalid_tokens=delete_invalid,
        )
    except RuntimeError as exc:
        return _build_alert(
            str(exc),
            title="维护稍后重试",
            level="warning",
            status_code=409,
        )
    except Exception as exc:
        logger.exception(f"❌ 手动执行 Token 维护失败: {exc}")
        return _build_alert(
            f"Token 维护失败: {exc}",
            title="维护失败!",
            level="error",
            status_code=500,
        )

    report = (
        f"本次维护共去重 {summary.duplicate_removed_count} 个,"
        f"测活 {summary.checked_count} 个(有效 {summary.valid_count} / "
        f"匿名 {summary.guest_count} / 无效 {summary.invalid_count}),"
        f"删除失效 Token {summary.deleted_invalid_count} 个。"
    )
    return _build_alert(report, title="维护完成!", level="success")
+
+
@router.post("/tokens/toggle/{token_id}", dependencies=[Depends(require_auth)])
async def toggle_token(token_id: int, enabled: bool):
    """Enable or disable a token and refresh the token pool.

    Args:
        token_id: database id of the token.
        enabled: the state the token should have after the call.

    Returns an HTML fragment for the toggle button.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.services.token_dao import get_token_dao
    from app.utils.token_pool import get_token_pool

    dao = get_token_dao()
    await dao.update_token_status(token_id, enabled)

    # Re-sync the pool for the provider this token belongs to
    pool = get_token_pool()
    if pool:
        async with dao.get_connection() as conn:
            cursor = await conn.execute("SELECT provider FROM tokens WHERE id = ?", (token_id,))
            row = await cursor.fetchone()
        if row:
            provider = row[0]
            await pool.sync_from_database(provider)
            logger.info("✅ Token 池已同步")

    # Styling values for the button that reflects the new state
    if enabled:
        button_class = "bg-green-100 text-green-800 hover:bg-green-200"
        indicator_class = "bg-green-500"
        label = "已启用"
        next_state = "false"
    else:
        button_class = "bg-red-100 text-red-800 hover:bg-red-200"
        indicator_class = "bg-red-500"
        label = "已禁用"
        next_state = "true"

    # NOTE(review): the fragment below contains no markup in this view, so the
    # styling values above are unused here — confirm against the real template.
    return HTMLResponse(f"""

    """)
+
+
@router.delete("/tokens/delete/{token_id}", dependencies=[Depends(require_auth)])
async def delete_token(token_id: int):
    """Delete a token and refresh the token pool.

    The token's provider is looked up before deletion so the matching pool
    can be re-synchronised afterwards; when the row is not found the default
    namespace is used. Returns an empty body.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.services.token_dao import get_token_dao
    from app.utils.token_pool import get_token_pool

    dao = get_token_dao()

    # Look up the provider first: it is no longer available after deletion
    async with dao.get_connection() as conn:
        cursor = await conn.execute("SELECT provider FROM tokens WHERE id = ?", (token_id,))
        row = await cursor.fetchone()
    provider = row[0] if row else DEFAULT_TOKEN_NAMESPACE

    await dao.delete_token(token_id)

    # Re-sync the token pool
    pool = get_token_pool()
    if pool:
        await pool.sync_from_database(provider)
        logger.info("✅ Token 池已同步")

    return HTMLResponse("")  # empty body so htmx removes the element
+
+
@router.get(
    "/tokens/stats",
    response_class=HTMLResponse,
    dependencies=[Depends(require_auth)],
)
async def get_tokens_stats(request: Request):
    """Render the token statistics as an HTML fragment.

    Collects the statistics for the default namespace and renders
    ``components/token_stats.html``.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    stats_data = await collect_admin_stats(DEFAULT_TOKEN_NAMESPACE)

    context = {
        "request": request,
        "stats": stats_data,
    }

    return templates.TemplateResponse("components/token_stats.html", context)
+
+
@router.post("/tokens/validate", dependencies=[Depends(require_auth)])
async def validate_tokens():
    """Validate every token of the default namespace and report the counts.

    After validation the token pool (if any) is re-synchronised from the
    database. Returns an HTML fragment with the valid / guest / invalid
    counters read from the DAO result.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.services.token_dao import get_token_dao
    from app.utils.token_pool import get_token_pool

    dao = get_token_dao()

    # Run the batch validation
    stats = await dao.validate_all_tokens(DEFAULT_TOKEN_NAMESPACE)

    pool = get_token_pool()
    if pool:
        await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE)

    valid_count = stats.get("valid", 0)
    guest_count = stats.get("guest", 0)
    invalid_count = stats.get("invalid", 0)

    # Pick the notification colour and wording by the worst outcome present
    if guest_count > 0:
        message_class = "bg-yellow-100 border-yellow-400 text-yellow-700"
        message = f"验证完成:有效 {valid_count} 个,匿名 {guest_count} 个,无效 {invalid_count} 个。匿名 Token 已标记。"
    elif invalid_count > 0:
        message_class = "bg-blue-100 border-blue-400 text-blue-700"
        message = f"验证完成:有效 {valid_count} 个,无效 {invalid_count} 个。"
    else:
        message_class = "bg-green-100 border-green-400 text-green-700"
        message = f"验证完成:所有 {valid_count} 个 Token 均有效!"

    return HTMLResponse(f"""

    批量验证完成!
    {message}

    """)
+
+
@router.post(
    "/tokens/validate-single/{token_id}",
    dependencies=[Depends(require_auth)],
)
async def validate_single_token(request: Request, token_id: int):
    """Validate one token and return its refreshed table row.

    After validation the token pool (if any) is re-synchronised, the token is
    re-read together with its statistics and rendered through
    ``components/token_row.html``. An empty body is returned when the token
    no longer exists.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.services.token_dao import get_token_dao
    from app.utils.token_pool import get_token_pool

    dao = get_token_dao()

    # Validate the token and persist the result
    await dao.validate_and_update_token(token_id)

    pool = get_token_pool()
    if pool:
        await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE)

    # Re-read the token with its request statistics
    async with dao.get_connection() as conn:
        cursor = await conn.execute("""
            SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests,
                   ts.last_success_time, ts.last_failure_time
            FROM tokens t
            LEFT JOIN token_stats ts ON t.id = ts.token_id
            WHERE t.id = ?
        """, (token_id,))
        row = await cursor.fetchone()

    if row:
        # Render the single-row template with the updated data
        token = dict(row)
        context = {
            "request": request,
            "token": token,
        }
        return templates.TemplateResponse("components/token_row.html", context)
    else:
        return HTMLResponse("")
+
+
@router.post("/tokens/health-check", dependencies=[Depends(require_auth)])
async def health_check_tokens():
    """Run a health check over the token pool and report the result.

    Returns a hint fragment when the pool is not initialised; otherwise runs
    ``health_check_all`` and reports healthy versus total tokens taken from
    ``get_pool_status``.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.utils.token_pool import get_token_pool

    pool = get_token_pool()

    if not pool:
        return HTMLResponse("""

        提示!
        Token 池未初始化,请重启服务。

        """)

    # Run the health check
    await pool.health_check_all()

    # Read the resulting pool status
    status = pool.get_pool_status()
    healthy_count = status.get("healthy_tokens", 0)
    total_count = status.get("total_tokens", 0)

    if healthy_count == total_count:
        message_class = "bg-green-100 border-green-400 text-green-700"
        message = f"所有 {total_count} 个 Token 均健康!"
    elif healthy_count > 0:
        message_class = "bg-blue-100 border-blue-400 text-blue-700"
        message = f"健康检查完成:{healthy_count}/{total_count} 个 Token 健康。"
    else:
        message_class = "bg-red-100 border-red-400 text-red-700"
        message = f"警告:0/{total_count} 个 Token 健康,请检查配置。"

    return HTMLResponse(f"""

    健康检查完成!
    {message}

    """)
+
+
@router.post("/tokens/sync-pool", dependencies=[Depends(require_auth)])
async def sync_token_pool():
    """Reload the token pool from the database on demand.

    Returns a hint fragment when the pool is not initialised; otherwise
    re-synchronises the default namespace and reports total, available and
    user token counts taken from ``get_pool_status``.

    The endpoint requires an authenticated admin session, matching the other
    token-management endpoints in this module that declare ``require_auth``.
    """
    from app.utils.token_pool import get_token_pool

    pool = get_token_pool()

    if not pool:
        return HTMLResponse("""

        提示!
        Token 池未初始化,请重启服务。

        """)

    # Reload from the database
    await pool.sync_from_database(DEFAULT_TOKEN_NAMESPACE)

    # Read the status after the sync
    status = pool.get_pool_status()
    total_count = status.get("total_tokens", 0)
    available_count = status.get("available_tokens", 0)
    user_count = status.get("user_tokens", 0)

    logger.info(
        f"✅ Token 池手动同步完成,总计 {total_count} 个 Token, 可用 {available_count} 个, 认证用户 {user_count} 个"
    )

    if total_count == 0:
        message_class = "bg-yellow-100 border-yellow-400 text-yellow-700"
        message = "同步完成:当前没有可用 Token,请在数据库中启用 Token。"
    elif available_count == 0:
        message_class = "bg-orange-100 border-orange-400 text-orange-700"
        message = f"同步完成:共 {total_count} 个 Token,但无可用 Token(可能都已禁用)。"
    else:
        message_class = "bg-green-100 border-green-400 text-green-700"
        message = f"同步完成:共 {total_count} 个 Token,{available_count} 个可用,{user_count} 个认证用户。"

    return HTMLResponse(f"""

    Token 池同步完成!
    {message}

    """)
diff --git a/app/admin/auth.py b/app/admin/auth.py
new file mode 100644
index 0000000000000000000000000000000000000000..654cae372fc6425ce820b0296e4b923574d1dae6
--- /dev/null
+++ b/app/admin/auth.py
@@ -0,0 +1,129 @@
+"""
+管理后台认证中间件
+"""
+from fastapi import Request, HTTPException, status
+from fastapi.responses import RedirectResponse
+from typing import Optional
+import hashlib
+import secrets
+from datetime import datetime, timedelta
+
+from app.core.config import settings
+
# Simple in-memory session store (Redis is recommended for production).
# Maps session token -> dict with "created_at", "expires_at", "authenticated".
_sessions = {}

# Session lifetime in hours
SESSION_EXPIRE_HOURS = 24
+
+
def generate_session_token() -> str:
    """Return a fresh random, URL-safe session token."""
    token = secrets.token_urlsafe(32)
    return token
+
+
def create_session(password: str) -> Optional[str]:
    """Create an admin session if the password is correct.

    Args:
        password: the password entered by the user.

    Returns:
        The new session token, or ``None`` when the password does not match
        ``settings.ADMIN_PASSWORD``.
    """
    if not isinstance(password, str):
        return None

    # Constant-time comparison so response timing does not reveal how much of
    # the password matched. Compared as UTF-8 bytes because compare_digest
    # rejects str arguments containing non-ASCII characters.
    supplied = password.encode("utf-8")
    expected = settings.ADMIN_PASSWORD.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        return None

    session_token = generate_session_token()

    # Read the clock once so created_at and expires_at are exactly
    # SESSION_EXPIRE_HOURS apart.
    now = datetime.now()
    _sessions[session_token] = {
        "created_at": now,
        "expires_at": now + timedelta(hours=SESSION_EXPIRE_HOURS),
        "authenticated": True
    }

    return session_token
+
+
def verify_session(session_token: Optional[str]) -> bool:
    """Tell whether a session token refers to a live, authenticated session.

    Args:
        session_token: the token to check; may be ``None`` or empty.

    Returns:
        ``True`` only for a stored, unexpired session flagged as
        authenticated. An expired session is removed from the store as a
        side effect.
    """
    record = _sessions.get(session_token) if session_token else None
    if not record:
        return False

    if datetime.now() > record["expires_at"]:
        # Expired: drop it so the store does not keep stale entries
        del _sessions[session_token]
        return False

    return record.get("authenticated", False)
+
+
def delete_session(session_token: Optional[str]):
    """Remove a session from the store (logout); unknown tokens are ignored."""
    if not session_token:
        return
    if session_token in _sessions:
        _sessions.pop(session_token)
+
+
def get_session_token_from_request(request: Request) -> Optional[str]:
    """Read the ``admin_session`` cookie from the request, if present."""
    cookies = request.cookies
    return cookies.get("admin_session")
+
+
async def require_auth(request: Request):
    """Dependency that only lets logged-in admins through.

    Usage on a route:
        @router.get("/admin", dependencies=[Depends(require_auth)])

    Raises:
        HTTPException: status 303 with a ``Location: /admin/login`` header
            when the request carries no valid session.
    """
    if verify_session(get_session_token_from_request(request)):
        return

    # Not authenticated: redirect to the login page
    raise HTTPException(
        status_code=status.HTTP_303_SEE_OTHER,
        detail="未登录",
        headers={"Location": "/admin/login"},
    )
+
+
def get_authenticated_user(request: Request) -> bool:
    """Report whether the request belongs to a logged-in admin (for templates).

    Returns:
        ``True`` when the request's session cookie maps to a valid session.
    """
    return verify_session(get_session_token_from_request(request))
+
+
def cleanup_expired_sessions():
    """Drop every expired session (meant to be called by a scheduled task).

    Returns:
        The number of sessions that were removed.
    """
    cutoff = datetime.now()
    # Collect first, delete afterwards: the dict must not change size while
    # it is being iterated.
    stale = []
    for key, record in _sessions.items():
        if cutoff > record["expires_at"]:
            stale.append(key)

    for key in stale:
        del _sessions[key]

    return len(stale)
diff --git a/app/admin/config_manager.py b/app/admin/config_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b14123c3c5bd4c8b68dad1f247fb204b2fd051c
--- /dev/null
+++ b/app/admin/config_manager.py
@@ -0,0 +1,682 @@
+"""Admin config metadata and helpers for the configuration console."""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Awaitable, Callable, Mapping
+
+from dotenv import dotenv_values
+
+from app.core.config import settings
+from app.utils.env_file import update_env_file
+from app.utils.logger import logger
+
# Default locations of the live env file and its template (relative paths).
ENV_PATH = Path(".env")
ENV_EXAMPLE_PATH = Path(".env.example")
# A non-comment line is accepted when it looks like `[export ]KEY=VALUE`,
# where KEY starts with a letter or underscore.
_ENV_SOURCE_LINE_PATTERN = re.compile(
    r"^\s*(?:export\s+)?[A-Za-z_][A-Za-z0-9_]*\s*=.*$"
)
+
+
@dataclass(frozen=True)
class ConfigFieldSpec:
    """Immutable description of one configuration field shown in the admin console."""

    key: str  # environment variable / settings attribute name
    label: str  # human-readable field label
    description: str  # help text for the field
    value_type: str  # "str", "int" or "bool" in the specs of this module
    default_value: object  # used when the settings object lacks the attribute
    input_type: str = "text"  # e.g. "text", "url", "password", "number"
    placeholder: str = ""
    required: bool = False  # empty values are rejected when True
    wide: bool = False
    sensitive: bool = False
    restart_required: bool = False
    min_value: int | None = None  # inclusive lower bound for int fields
    max_value: int | None = None  # inclusive upper bound for int fields
+
+
@dataclass(frozen=True)
class ConfigSectionSpec:
    """Immutable description of one group of configuration fields."""

    id: str  # section identifier (e.g. "tokens")
    title: str  # section heading
    description: str  # short explanation of the section
    fields: tuple[ConfigFieldSpec, ...]  # fields in display order
+
+
# Declarative catalogue of every setting managed through the admin config
# console, grouped into sections. The lookups defined after this table
# (CONFIG_FIELD_SPECS, MANAGED_ENV_KEYS) are derived from it.
CONFIG_SECTIONS: tuple[ConfigSectionSpec, ...] = (
    # --- Access and client authentication ---
    ConfigSectionSpec(
        id="access",
        title="接入与认证",
        description="控制上游接口地址、客户端鉴权和 Function Call 行为。",
        fields=(
            ConfigFieldSpec(
                key="API_ENDPOINT",
                label="上游 API 地址",
                description="代理请求实际转发到的上游聊天完成接口。",
                value_type="str",
                default_value="https://chat.z.ai/api/v2/chat/completions",
                input_type="url",
                placeholder="https://chat.z.ai/api/v2/chat/completions",
                required=True,
                wide=True,
            ),
            ConfigFieldSpec(
                key="AUTH_TOKEN",
                label="客户端认证密钥",
                description="客户端访问本服务时使用的 Bearer Token。",
                value_type="str",
                default_value="sk-your-api-key",
                input_type="password",
                placeholder="sk-your-api-key",
                wide=True,
                sensitive=True,
            ),
            ConfigFieldSpec(
                key="SKIP_AUTH_TOKEN",
                label="跳过客户端认证",
                description="仅建议开发环境使用,开启后不校验 AUTH_TOKEN。",
                value_type="bool",
                default_value=False,
            ),
            ConfigFieldSpec(
                key="TOOL_SUPPORT",
                label="启用 Function Call",
                description="允许 OpenAI 兼容接口使用工具调用能力。",
                value_type="bool",
                default_value=True,
            ),
            ConfigFieldSpec(
                key="SCAN_LIMIT",
                label="工具调用扫描限制",
                description="Function Call 扫描的最大字符数。",
                value_type="int",
                default_value=200000,
                input_type="number",
                min_value=1,
                placeholder="200000",
            ),
        ),
    ),
    # --- Server runtime ---
    ConfigSectionSpec(
        id="server",
        title="服务运行",
        description="服务监听、日志、数据库路径和反向代理前缀。",
        fields=(
            ConfigFieldSpec(
                key="SERVICE_NAME",
                label="服务名称",
                description="显示在进程列表中的服务名称。",
                value_type="str",
                default_value="api-proxy-server",
                placeholder="api-proxy-server",
                required=True,
                restart_required=True,
            ),
            ConfigFieldSpec(
                key="LISTEN_PORT",
                label="监听端口",
                description="HTTP 服务监听端口。",
                value_type="int",
                default_value=8080,
                input_type="number",
                min_value=1,
                max_value=65535,
                required=True,
                restart_required=True,
                placeholder="8080",
            ),
            ConfigFieldSpec(
                key="ROOT_PATH",
                label="反向代理路径前缀",
                description="例如 /api,部署在子路径时使用。",
                value_type="str",
                default_value="",
                placeholder="/api",
                restart_required=True,
            ),
            ConfigFieldSpec(
                key="DEBUG_LOGGING",
                label="启用调试日志",
                description="开启后会输出更详细的调试信息。",
                value_type="bool",
                default_value=False,
            ),
            ConfigFieldSpec(
                key="DB_PATH",
                label="数据库路径",
                description="SQLite 数据库文件位置。",
                value_type="str",
                default_value="tokens.db",
                placeholder="tokens.db",
                required=True,
                wide=True,
                restart_required=True,
            ),
        ),
    ),
    # --- Token pool policy, auto import and auto maintenance ---
    ConfigSectionSpec(
        id="tokens",
        title="Token 池策略",
        description="失败判定、恢复时间和自动导入、自动维护计划任务。",
        fields=(
            ConfigFieldSpec(
                key="TOKEN_FAILURE_THRESHOLD",
                label="失败阈值",
                description="连续失败多少次后将 Token 标记为不可用。",
                value_type="int",
                default_value=3,
                input_type="number",
                min_value=1,
                required=True,
                restart_required=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_RECOVERY_TIMEOUT",
                label="恢复超时(秒)",
                description="失败 Token 重新参与调度前的等待时间。",
                value_type="int",
                default_value=1800,
                input_type="number",
                min_value=1,
                required=True,
                restart_required=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_IMPORT_ENABLED",
                label="启用自动导入",
                description="按固定周期扫描服务端目录并导入 Token。",
                value_type="bool",
                default_value=False,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_IMPORT_SOURCE_DIR",
                label="自动导入目录",
                description="服务端本地目录,开启自动导入时需要可访问。",
                value_type="str",
                default_value="",
                placeholder="E:\\tokens\\input",
                wide=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_IMPORT_INTERVAL",
                label="自动导入间隔(秒)",
                description="自动导入的扫描周期。",
                value_type="int",
                default_value=300,
                input_type="number",
                min_value=1,
                required=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_MAINTENANCE_ENABLED",
                label="启用自动维护",
                description="定时执行去重、健康检查和删除失效 Token。",
                value_type="bool",
                default_value=False,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_MAINTENANCE_INTERVAL",
                label="自动维护间隔(秒)",
                description="自动维护的执行周期。",
                value_type="int",
                default_value=1800,
                input_type="number",
                min_value=1,
                required=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_REMOVE_DUPLICATES",
                label="自动去重",
                description="自动维护时清理重复 Token。",
                value_type="bool",
                default_value=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_HEALTH_CHECK",
                label="自动健康检查",
                description="自动维护时验证 Token 可用性。",
                value_type="bool",
                default_value=True,
            ),
            ConfigFieldSpec(
                key="TOKEN_AUTO_DELETE_INVALID",
                label="自动删除失效 Token",
                description="自动维护时移除已验证为无效的 Token。",
                value_type="bool",
                default_value=False,
            ),
        ),
    ),
    # --- Anonymous guest session pool ---
    ConfigSectionSpec(
        id="guest",
        title="匿名 Guest 会话池",
        description="没有用户 Token 时,仅控制是否启用匿名池和池容量。",
        fields=(
            ConfigFieldSpec(
                key="ANONYMOUS_MODE",
                label="启用匿名模式",
                description="无可用用户 Token 时允许使用匿名会话。",
                value_type="bool",
                default_value=True,
                restart_required=True,
            ),
            ConfigFieldSpec(
                key="GUEST_POOL_SIZE",
                label="Guest 池容量",
                description="启动和维持的 guest 会话数量。",
                value_type="int",
                default_value=3,
                input_type="number",
                min_value=1,
                required=True,
                restart_required=True,
            ),
        ),
    ),
    # --- Model name mapping ---
    ConfigSectionSpec(
        id="models",
        title="模型映射",
        description="映射 OpenAI 兼容模型名到上游 Z.AI 实际模型名。",
        fields=(
            ConfigFieldSpec(
                key="GLM45_MODEL",
                label="GLM 4.5",
                description="标准 GLM 4.5 模型标识。",
                value_type="str",
                default_value="GLM-4.5",
                placeholder="GLM-4.5",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM45_THINKING_MODEL",
                label="GLM 4.5 Thinking",
                description="推理增强版 GLM 4.5 模型标识。",
                value_type="str",
                default_value="GLM-4.5-Thinking",
                placeholder="GLM-4.5-Thinking",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM45_SEARCH_MODEL",
                label="GLM 4.5 Search",
                description="搜索增强版 GLM 4.5 模型标识。",
                value_type="str",
                default_value="GLM-4.5-Search",
                placeholder="GLM-4.5-Search",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM45_AIR_MODEL",
                label="GLM 4.5 Air",
                description="轻量版 GLM 4.5 模型标识。",
                value_type="str",
                default_value="GLM-4.5-Air",
                placeholder="GLM-4.5-Air",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM46V_MODEL",
                label="GLM 4.6V",
                description="视觉模型标识。",
                value_type="str",
                default_value="GLM-4.6V",
                placeholder="GLM-4.6V",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM5_MODEL",
                label="GLM 5",
                description="GLM 5 模型标识。",
                value_type="str",
                default_value="GLM-5",
                placeholder="GLM-5",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM47_MODEL",
                label="GLM 4.7",
                description="GLM 4.7 主模型标识。",
                value_type="str",
                default_value="GLM-4.7",
                placeholder="GLM-4.7",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM47_THINKING_MODEL",
                label="GLM 4.7 Thinking",
                description="GLM 4.7 推理版模型标识。",
                value_type="str",
                default_value="GLM-4.7-Thinking",
                placeholder="GLM-4.7-Thinking",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM47_SEARCH_MODEL",
                label="GLM 4.7 Search",
                description="GLM 4.7 搜索版模型标识。",
                value_type="str",
                default_value="GLM-4.7-Search",
                placeholder="GLM-4.7-Search",
                required=True,
            ),
            ConfigFieldSpec(
                key="GLM47_ADVANCED_SEARCH_MODEL",
                label="GLM 4.7 Advanced Search",
                description="GLM 4.7 高级搜索模型标识。",
                value_type="str",
                default_value="GLM-4.7-advanced-search",
                placeholder="GLM-4.7-advanced-search",
                required=True,
                wide=True,
            ),
        ),
    ),
    # --- Outbound proxies ---
    ConfigSectionSpec(
        id="proxy",
        title="代理网络",
        description="上游访问使用的 HTTP、HTTPS 和 SOCKS5 代理。",
        fields=(
            ConfigFieldSpec(
                key="HTTP_PROXY",
                label="HTTP 代理",
                description="例如 http://127.0.0.1:7890。",
                value_type="str",
                default_value="",
                placeholder="http://127.0.0.1:7890",
                wide=True,
            ),
            ConfigFieldSpec(
                key="HTTPS_PROXY",
                label="HTTPS 代理",
                description="例如 http://127.0.0.1:7890。",
                value_type="str",
                default_value="",
                placeholder="http://127.0.0.1:7890",
                wide=True,
            ),
            ConfigFieldSpec(
                key="SOCKS5_PROXY",
                label="SOCKS5 代理",
                description="例如 socks5://127.0.0.1:1080。",
                value_type="str",
                default_value="",
                placeholder="socks5://127.0.0.1:1080",
                wide=True,
            ),
        ),
    ),
    # --- Admin console security ---
    ConfigSectionSpec(
        id="admin",
        title="后台安全",
        description="管理后台密码和会话密钥。修改后建议重新登录。",
        fields=(
            ConfigFieldSpec(
                key="ADMIN_PASSWORD",
                label="后台密码",
                description="管理后台登录密码。",
                value_type="str",
                default_value="admin123",
                input_type="password",
                placeholder="admin123",
                required=True,
                sensitive=True,
            ),
            ConfigFieldSpec(
                key="SESSION_SECRET_KEY",
                label="会话密钥",
                description="用于后台会话签名的密钥。",
                value_type="str",
                default_value="your-secret-key-change-in-production",
                input_type="password",
                placeholder="your-secret-key-change-in-production",
                required=True,
                sensitive=True,
                wide=True,
            ),
        ),
    ),
)
+
# Flat lookup: field key -> spec, across all sections.
CONFIG_FIELD_SPECS = {
    field.key: field
    for section in CONFIG_SECTIONS
    for field in section.fields
}
# Keys the form editor reads and writes, in declaration order.
MANAGED_ENV_KEYS = tuple(CONFIG_FIELD_SPECS.keys())
# Async zero-argument callable awaited after the env file has been changed.
ReloadCallback = Callable[[], Awaitable[None]]
+
+
def read_env_content(env_path: str | Path = ENV_PATH) -> str:
    """Return the UTF-8 text of the env file, or ``""`` when it does not exist."""
    target = Path(env_path)
    return target.read_text(encoding="utf-8") if target.exists() else ""
+
+
def validate_env_source(content: str) -> str:
    """Check raw env-file text and return it with ``\\n`` line endings.

    Blank lines and lines starting with ``#`` are accepted as they are; every
    other line must match the module's KEY=VALUE pattern.

    Raises:
        ValueError: for the first offending line, naming its 1-based number.
    """
    normalized = content.replace("\r\n", "\n").replace("\r", "\n")

    line_number = 0
    for line in normalized.splitlines():
        line_number += 1
        text = line.strip()
        is_blank_or_comment = text == "" or text.startswith("#")
        if is_blank_or_comment or _ENV_SOURCE_LINE_PATTERN.match(line):
            continue
        raise ValueError(
            f"第 {line_number} 行不是合法的 KEY=VALUE 格式。"
        )

    return normalized
+
+
def build_config_page_data(
    *,
    settings_obj: Any = settings,
    env_path: str | Path = ENV_PATH,
    env_example_path: str | Path = ENV_EXAMPLE_PATH,
) -> dict[str, Any]:
    """Assemble everything the configuration page template needs.

    Args:
        settings_obj: object whose attributes provide the current values.
        env_path: location of the env file.
        env_example_path: location of the env template file.

    Returns:
        A dict with ``sections`` (rendered field descriptions per section),
        ``env_content`` (raw env file text) and ``overview`` (counters and
        file facts).
    """
    env_file = Path(env_path)
    env_content = read_env_content(env_file)
    env_values = dotenv_values(env_file) if env_file.exists() else {}

    def _describe(spec: ConfigFieldSpec) -> dict[str, Any]:
        """Build the template dict for one field."""
        from_env = spec.key in env_values
        current = getattr(settings_obj, spec.key, spec.default_value)
        if from_env:
            badge = "bg-emerald-50 text-emerald-700 ring-emerald-200"
        else:
            badge = "bg-slate-100 text-slate-600 ring-slate-200"
        return {
            "key": spec.key,
            "label": spec.label,
            "description": spec.description,
            "value_type": spec.value_type,
            # None is shown as an empty input
            "value": "" if current is None else current,
            "input_type": spec.input_type,
            "placeholder": spec.placeholder,
            "required": spec.required,
            "wide": spec.wide,
            "sensitive": spec.sensitive,
            "restart_required": spec.restart_required,
            "min_value": spec.min_value,
            "max_value": spec.max_value,
            "source_label": ".env" if from_env else "默认值",
            "source_badge_class": badge,
        }

    sections: list[dict[str, Any]] = []
    for section in CONFIG_SECTIONS:
        rendered = [_describe(spec) for spec in section.fields]
        sections.append(
            {
                "id": section.id,
                "title": section.title,
                "description": section.description,
                "fields": rendered,
                "field_count": len(rendered),
            }
        )

    # Overview counters over every declared field
    all_specs = [spec for section in CONFIG_SECTIONS for spec in section.fields]
    total_fields = len(all_specs)
    overridden_fields = sum(1 for spec in all_specs if spec.key in env_values)
    sensitive_fields = sum(1 for spec in all_specs if spec.sensitive)
    restart_required_fields = sum(
        1 for spec in all_specs if spec.restart_required
    )

    overview = {
        "total_sections": len(CONFIG_SECTIONS),
        "total_fields": total_fields,
        "overridden_fields": overridden_fields,
        "default_fields": total_fields - overridden_fields,
        "sensitive_fields": sensitive_fields,
        "restart_required_fields": restart_required_fields,
        "env_exists": env_file.exists(),
        "env_path": str(env_file.resolve()),
        "env_line_count": len(env_content.splitlines()) if env_content else 0,
        "example_exists": Path(env_example_path).exists(),
    }
    return {
        "sections": sections,
        "env_content": env_content,
        "overview": overview,
    }
+
+
def build_form_updates(form_data: Mapping[str, Any]) -> dict[str, object]:
    """Turn submitted form data into typed values for every managed key.

    Bool fields are ``True`` exactly when their key is present in the form.
    Other fields are stripped strings; int fields are additionally parsed and
    range-checked.

    Raises:
        ValueError: when a required field is empty, an int field is not an
            integer, or an int lies outside its configured bounds.
    """
    parsed_values: dict[str, object] = {}

    for key in MANAGED_ENV_KEYS:
        spec = CONFIG_FIELD_SPECS[key]

        if spec.value_type == "bool":
            parsed_values[key] = key in form_data
            continue

        text = str(form_data.get(key, "") or "").strip()
        if spec.required and not text:
            raise ValueError(f"{spec.label} 不能为空。")

        if spec.value_type != "int":
            parsed_values[key] = text
            continue

        try:
            number = int(text)
        except ValueError as exc:
            raise ValueError(f"{spec.label} 必须是整数。") from exc

        lower, upper = spec.min_value, spec.max_value
        if lower is not None and number < lower:
            raise ValueError(
                f"{spec.label} 不能小于 {spec.min_value}。"
            )
        if upper is not None and number > upper:
            raise ValueError(
                f"{spec.label} 不能大于 {spec.max_value}。"
            )
        parsed_values[key] = number

    return parsed_values
+
+
async def _apply_env_change(
    writer: Callable[[Path], None],
    *,
    reload_callback: ReloadCallback,
    env_path: str | Path = ENV_PATH,
) -> None:
    """Write the env file and reload; roll the file back if either step fails.

    Args:
        writer: callable that writes the new content to the given path.
        reload_callback: awaited after a successful write, and once more
            after a rollback.
        env_path: location of the env file.

    Raises:
        Exception: whatever ``writer`` or ``reload_callback`` raised; it is
            re-raised after the previous file state has been restored.
    """
    path = Path(env_path)
    # Snapshot the current state so it can be restored on failure
    had_existing_file = path.exists()
    previous_content = read_env_content(path) if had_existing_file else ""

    try:
        writer(path)
        await reload_callback()
    except Exception:
        # Restore the previous content, or remove a file that did not exist
        # before the attempt
        if had_existing_file:
            path.write_text(previous_content, encoding="utf-8")
        elif path.exists():
            path.unlink()

        # Reload again so the running state matches the restored file; a
        # failure here is only logged so the original error still propagates
        try:
            await reload_callback()
        except Exception as restore_exc:
            logger.warning(f"⚠️ 回滚配置后重新加载失败: {restore_exc}")
        raise
+
+
async def save_form_config(
    form_data: Mapping[str, Any],
    *,
    reload_callback: ReloadCallback,
    env_path: str | Path = ENV_PATH,
) -> dict[str, object]:
    """Validate form data, persist it to the env file and reload.

    Returns:
        The typed values that were written.

    Raises:
        ValueError: when the form data fails validation (nothing is written).
    """
    updates = build_form_updates(form_data)

    await _apply_env_change(
        lambda target_path: update_env_file(updates, env_path=target_path),
        reload_callback=reload_callback,
        env_path=env_path,
    )
    return updates
+
+
async def save_source_config(
    env_content: str,
    *,
    reload_callback: ReloadCallback,
    env_path: str | Path = ENV_PATH,
) -> None:
    """Validate raw env text, write it to the env file and reload.

    The written text ends with exactly one newline, or is empty when the
    validated content contains nothing but newlines.

    Raises:
        ValueError: when a line is not valid KEY=VALUE syntax.
    """
    body = validate_env_source(env_content).rstrip("\n")
    payload = f"{body}\n" if body else ""

    def _writer(target_path: Path) -> None:
        target_path.write_text(payload, encoding="utf-8")

    await _apply_env_change(
        _writer,
        reload_callback=reload_callback,
        env_path=env_path,
    )
+
+
async def reset_env_to_example(
    *,
    reload_callback: ReloadCallback,
    env_path: str | Path = ENV_PATH,
    env_example_path: str | Path = ENV_EXAMPLE_PATH,
) -> None:
    """Overwrite the env file with the template's content and reload.

    Raises:
        FileNotFoundError: when the template file does not exist.
    """
    template = Path(env_example_path)
    if not template.exists():
        raise FileNotFoundError(".env.example 不存在")

    body = template.read_text(encoding="utf-8").rstrip("\n")
    # Exactly one trailing newline, or nothing for an empty template
    payload = f"{body}\n" if body else ""

    def _writer(target_path: Path) -> None:
        target_path.write_text(payload, encoding="utf-8")

    await _apply_env_change(
        _writer,
        reload_callback=reload_callback,
        env_path=env_path,
    )
diff --git a/app/admin/routes.py b/app/admin/routes.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a0b2e8702f75fb951012cb360f0402f4a38bd5
--- /dev/null
+++ b/app/admin/routes.py
@@ -0,0 +1,109 @@
+"""
+管理后台路由模块
+"""
+from datetime import datetime
+
+from fastapi import APIRouter, Depends, Request
+from fastapi.responses import HTMLResponse
+from fastapi.templating import Jinja2Templates
+
+from app.admin.auth import require_auth
+from app.admin.config_manager import build_config_page_data
+from app.admin.stats import (
+ DEFAULT_TREND_WINDOW,
+ TREND_WINDOW_OPTIONS,
+ collect_admin_stats,
+ get_process_uptime,
+)
+
# Router for the admin HTML pages, mounted under /admin.
router = APIRouter(prefix="/admin", tags=["admin"])
# Jinja2 environment rooted at app/templates.
templates = Jinja2Templates(directory="app/templates")
# Token namespace used for the dashboard statistics in this module.
DEFAULT_TOKEN_NAMESPACE = "zai"
+
+
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Serve the login page."""
    context = {"request": request}
    return templates.TemplateResponse("login.html", context)
+
+
@router.get("/", response_class=HTMLResponse, dependencies=[Depends(require_auth)])
async def dashboard(request: Request):
    """Serve the dashboard landing page with statistics and process uptime."""
    stats = await collect_admin_stats(
        DEFAULT_TOKEN_NAMESPACE,
        trend_window=DEFAULT_TREND_WINDOW,
    )
    stats["uptime"] = get_process_uptime()

    rendered_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "stats": stats,
            "current_time": rendered_at,
            "trend_windows": TREND_WINDOW_OPTIONS,
        },
    )
+
+
@router.get(
    "/config",
    response_class=HTMLResponse,
    dependencies=[Depends(require_auth)],
)
async def config_page(request: Request):
    """Serve the configuration management page."""
    page_data = build_config_page_data()

    context = {"request": request}
    # Hand the template exactly the three parts it uses
    for name in ("sections", "env_content", "overview"):
        context[name] = page_data[name]
    return templates.TemplateResponse("config.html", context)
+
+
@router.get("/logs", response_class=HTMLResponse, dependencies=[Depends(require_auth)])
async def logs_page(request: Request):
    """Serve the live log page."""
    return templates.TemplateResponse("logs.html", {"request": request})
+
+
@router.get(
    "/tokens",
    response_class=HTMLResponse,
    dependencies=[Depends(require_auth)],
)
async def tokens_page(request: Request):
    """Serve the token management page with the current automation settings."""
    from app.core.config import settings

    # Labels of the maintenance actions that are currently switched on
    action_flags = (
        (settings.TOKEN_AUTO_REMOVE_DUPLICATES, "删除重复 Token"),
        (settings.TOKEN_AUTO_HEALTH_CHECK, "批量测活"),
        (settings.TOKEN_AUTO_DELETE_INVALID, "删除失效 Token"),
    )
    maintenance_actions: list[str] = [
        label for switched_on, label in action_flags if switched_on
    ]

    automation = {
        "config_url": "/admin/config#tokens",
        "import_enabled": settings.TOKEN_AUTO_IMPORT_ENABLED,
        "import_source_dir": settings.TOKEN_AUTO_IMPORT_SOURCE_DIR,
        "import_interval": settings.TOKEN_AUTO_IMPORT_INTERVAL,
        "has_import_source_dir": bool(
            settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip()
        ),
        "maintenance_enabled": settings.TOKEN_AUTO_MAINTENANCE_ENABLED,
        "maintenance_interval": settings.TOKEN_AUTO_MAINTENANCE_INTERVAL,
        "maintenance_actions": maintenance_actions,
        "has_maintenance_actions": bool(maintenance_actions),
    }
    return templates.TemplateResponse(
        "tokens.html",
        {"request": request, "automation": automation},
    )
diff --git a/app/admin/stats.py b/app/admin/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..2299621f406564cd1fca2e1e70237e669bc8ebce
--- /dev/null
+++ b/app/admin/stats.py
@@ -0,0 +1,184 @@
+"""管理后台统计聚合辅助函数。"""
+
+from __future__ import annotations
+
+import os
+import time
+from typing import Any, Dict, Optional
+
+import psutil
+
+from app.services.request_log_dao import RequestLogDAO, get_request_log_dao
+from app.services.token_dao import TokenDAO, get_token_dao
+from app.utils.token_pool import TokenPool, get_token_pool
+
+_TOKEN_POOL_SENTINEL = object()  # Default meaning "token_pool not passed", so an explicit None stays distinguishable.
+DEFAULT_TREND_WINDOW = "7d"  # Returned by normalize_trend_window() for unrecognized input.
+TREND_WINDOW_OPTIONS = (  # Trend window choices: lookup key plus display label.
+ {"key": "24h", "label": "24 小时"},
+ {"key": "7d", "label": "7 天"},
+ {"key": "30d", "label": "30 天"},
+)
+
+
+def _coerce_int(value: Any) -> int:
+    """Convert an aggregate value to int, mapping None and other falsy values to 0."""
+    return int(value) if value else 0
+
+
+def calculate_success_rate(
+    successful_requests: int,
+    total_requests: int,
+) -> float:
+    """Return the success percentage rounded to one decimal; 0.0 when there are no requests."""
+    if total_requests > 0:
+        return round(successful_requests / total_requests * 100, 1)
+    return 0.0
+
+
+def format_compact_number(value: Any) -> str:
+    """Abbreviate a count for dashboard display using an M, 万 or k suffix."""
+    count = int(value or 0)
+    scales = ((1_000_000, "M"), (10_000, "万"), (1_000, "k"))
+    for divisor, suffix in scales:
+        # Thresholds are checked from largest to smallest; the first match wins.
+        if count >= divisor:
+            return f"{count / divisor:.1f}{suffix}"
+    # Below one thousand the plain integer is shown.
+    return str(count)
+
+
+def normalize_trend_window(value: Any) -> str:
+    """Map user input to a supported trend window key, falling back to the default."""
+    candidate = str(value or "").strip().lower()
+    if candidate == "1d":
+        # "1d" is accepted as an alias of "24h".
+        return "24h"
+    supported = ("24h", "7d", "30d")
+    return candidate if candidate in supported else DEFAULT_TREND_WINDOW
+
+
+def format_uptime(total_seconds: int) -> str:
+    """Render a duration in seconds as days/hours/minutes/seconds text.
+
+    Leading zero-valued units are omitted; seconds are always shown.
+    """
+    clamped = max(0, int(total_seconds))
+    minutes_total, seconds = divmod(clamped, 60)
+    hours_total, minutes = divmod(minutes_total, 60)
+    days, hours = divmod(hours_total, 24)
+
+    units = [(days, "天"), (hours, "小时"), (minutes, "分钟")]
+    # Drop leading zero units, but keep every unit after the first non-zero one.
+    while units and not units[0][0]:
+        units.pop(0)
+    units.append((seconds, "秒"))
+
+    return " ".join(f"{amount}{label}" for amount, label in units)
+
+
+def get_process_uptime() -> str:
+    """Return the formatted uptime of the current process."""
+    process = psutil.Process(os.getpid())
+    return format_uptime(int(time.time() - process.create_time()))
+
+
+async def collect_admin_stats(
+ provider: str,
+ *,
+ token_dao: Optional[TokenDAO] = None,
+ request_log_dao: Optional[RequestLogDAO] = None,
+ token_pool: Any = _TOKEN_POOL_SENTINEL,
+ trend_window: str = DEFAULT_TREND_WINDOW,
+) -> Dict[str, Any]:
+ """Aggregate the token and request statistics needed by the admin dashboard."""
+ token_dao = token_dao or get_token_dao()
+ request_log_dao = request_log_dao or get_request_log_dao()
+ if token_pool is _TOKEN_POOL_SENTINEL:  # not passed at all; an explicit None is kept as-is
+ token_pool = get_token_pool()
+ trend_window = normalize_trend_window(trend_window)
+ # NOTE(review): the three lookups below are awaited one after another.
+ token_counts = await token_dao.get_provider_token_counts(provider)
+ request_stats = await request_log_dao.get_provider_request_stats(provider)
+ usage_trend = await request_log_dao.get_provider_usage_trend(
+ provider,
+ window=trend_window,
+ )
+ # Pool status comes from any object exposing get_pool_status(); otherwise it stays empty.
+ pool_status: Dict[str, Any] = {}
+ if isinstance(token_pool, TokenPool) or hasattr(token_pool, "get_pool_status"):
+ pool_status = token_pool.get_pool_status() if token_pool else {}
+
+ total_tokens = _coerce_int(token_counts.get("total_tokens"))
+ enabled_tokens = _coerce_int(token_counts.get("enabled_tokens"))
+ user_tokens = _coerce_int(token_counts.get("user_tokens"))
+ guest_tokens = _coerce_int(token_counts.get("guest_tokens"))
+ unknown_tokens = _coerce_int(token_counts.get("unknown_tokens"))
+ # With no pool at all (None), fall back to enabled tokens minus guest tokens.
+ pool_total_tokens = _coerce_int(pool_status.get("total_tokens"))
+ if pool_total_tokens == 0 and token_pool is None:
+ pool_total_tokens = max(0, enabled_tokens - guest_tokens)
+
+ available_tokens = _coerce_int(pool_status.get("available_tokens"))
+ healthy_tokens = _coerce_int(pool_status.get("healthy_tokens"))
+ unhealthy_tokens = _coerce_int(pool_status.get("unhealthy_tokens"))
+
+ total_requests = _coerce_int(request_stats.get("total_requests"))
+ successful_requests = _coerce_int(request_stats.get("successful_requests"))
+ failed_requests = _coerce_int(request_stats.get("failed_requests"))
+ input_tokens = _coerce_int(request_stats.get("input_tokens"))
+ output_tokens = _coerce_int(request_stats.get("output_tokens"))
+ total_consumed_tokens = _coerce_int(request_stats.get("total_tokens"))
+ cache_creation_tokens = _coerce_int(
+ request_stats.get("cache_creation_tokens")
+ )
+ cache_read_tokens = _coerce_int(request_stats.get("cache_read_tokens"))
+ cache_creation_requests = _coerce_int(
+ request_stats.get("cache_creation_requests")
+ )
+ cache_hit_requests = _coerce_int(request_stats.get("cache_hit_requests"))
+ average_latency = round(float(request_stats.get("avg_duration") or 0.0), 2)
+ average_first_token_latency = round(
+ float(request_stats.get("avg_first_token_time") or 0.0),
+ 2,
+ )
+ total_cache_tokens = cache_creation_tokens + cache_read_tokens
+
+ return {
+ "total_tokens": total_tokens,
+ "enabled_tokens": enabled_tokens,
+ "user_tokens": user_tokens,
+ "guest_tokens": guest_tokens,
+ "unknown_tokens": unknown_tokens,
+ "pool_total_tokens": pool_total_tokens,
+ "available_tokens": available_tokens,
+ "healthy_tokens": healthy_tokens,
+ "unhealthy_tokens": unhealthy_tokens,
+ "total_requests": total_requests,
+ "successful_requests": successful_requests,
+ "failed_requests": failed_requests,
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "total_consumed_tokens": total_consumed_tokens,
+ "cache_creation_tokens": cache_creation_tokens,
+ "cache_read_tokens": cache_read_tokens,
+ "total_cache_tokens": total_cache_tokens,
+ "cache_creation_requests": cache_creation_requests,
+ "cache_hit_requests": cache_hit_requests,
+ "average_latency": average_latency,
+ "average_first_token_latency": average_first_token_latency,
+ "trend_window": trend_window,
+ "usage_trend": usage_trend,
+ "total_consumed_tokens_display": format_compact_number(
+ total_consumed_tokens
+ ),
+ "total_cache_tokens_display": format_compact_number(
+ total_cache_tokens
+ ),
+ "input_tokens_display": format_compact_number(input_tokens),
+ "output_tokens_display": format_compact_number(output_tokens),
+ "success_rate": calculate_success_rate(
+ successful_requests,
+ total_requests,
+ ),
+ }
diff --git a/app/core/__init__.py b/app/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6323e5f05522c2198c151ae459d5de7790089fc8
--- /dev/null
+++ b/app/core/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from app.core import claude, config, openai
+
+__all__ = ["claude", "config", "openai"]  # Names exported by `from app.core import *`.
diff --git a/app/core/claude.py b/app/core/claude.py
new file mode 100644
index 0000000000000000000000000000000000000000..d817ccd7427f5cef05f778f4ed7c59199c61d23d
--- /dev/null
+++ b/app/core/claude.py
@@ -0,0 +1,582 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import json
+import math
+import time
+import uuid
+from typing import Any, AsyncGenerator, Dict, List, Optional
+
+from fastapi import APIRouter, Header, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from app.core.claude_compat import (
+ build_non_stream_response,
+ claude_messages_to_openai,
+ claude_tool_choice_to_openai,
+ claude_tools_to_openai,
+ extract_text,
+ make_claude_id,
+ sse_content_block_delta,
+ sse_content_block_start,
+ sse_content_block_stop,
+ sse_error,
+ sse_message_delta,
+ sse_message_start,
+ sse_message_stop,
+ sse_ping,
+)
+from app.core.config import settings
+from app.core.openai import get_upstream_client
+from app.models.schemas import Message, OpenAIRequest
+from app.utils.logger import get_logger
+from app.utils.request_logging import (
+ extract_openai_usage,
+ extract_claude_usage,
+ wrap_claude_stream_with_logging,
+ write_request_log,
+)
+from app.utils.request_source import detect_request_source, format_request_source
+
+logger = get_logger()  # Logger used by the helpers and the endpoint in this module.
+router = APIRouter()  # Router on which the Claude-compatible POST routes below are registered.
+
+
+def _resolve_claude_model(model: Any) -> str:
+    """Map a Claude / Claude Code model name to an upstream model (default for blank input)."""
+    if not isinstance(model, str) or not model.strip():
+        return settings.GLM5_MODEL
+
+    raw_model = model.strip()
+    lookup_key = raw_model.casefold()
+    if lookup_key.endswith("[1m]"):
+        # A trailing "[1m]" marker is ignored when matching.
+        lookup_key = lookup_key[:-4].rstrip()
+
+    upstream_models = (
+        settings.GLM45_MODEL,
+        settings.GLM45_THINKING_MODEL,
+        settings.GLM45_SEARCH_MODEL,
+        settings.GLM45_AIR_MODEL,
+        settings.GLM46V_MODEL,
+        settings.GLM5_MODEL,
+        settings.GLM47_MODEL,
+        settings.GLM47_THINKING_MODEL,
+        settings.GLM47_SEARCH_MODEL,
+        settings.GLM47_ADVANCED_SEARCH_MODEL,
+    )
+    table = {
+        "default": settings.GLM5_MODEL,
+        "sonnet": settings.GLM5_MODEL,
+        "haiku": settings.GLM45_AIR_MODEL,
+        "opus": settings.GLM5_MODEL,
+        "opusplan": settings.GLM47_THINKING_MODEL,
+    }
+    # Exact upstream names are added last so they take priority over the short aliases.
+    table.update({name.casefold(): name for name in upstream_models})
+    if lookup_key in table:
+        return table[lookup_key]
+
+    glm5_prefixes = ("claude-sonnet", "claude-3-7-sonnet", "claude-3-5-sonnet")
+    if lookup_key.startswith(glm5_prefixes + ("claude-opus", "claude-4-opus")):
+        return settings.GLM5_MODEL
+    if lookup_key.startswith(("claude-haiku", "claude-3-5-haiku")):
+        return settings.GLM45_AIR_MODEL
+
+    # Names matching no rule are forwarded as given (whitespace-trimmed).
+    return raw_model
+
+
+def _estimate_tokens(text: str) -> int:
+    """Estimate tokens as half the character count, rounded up (0 for empty text)."""
+    char_count = len(text) if text else 0
+    return (char_count + 1) // 2
+
+
+def _extract_api_key(
+    authorization: Optional[str],
+    x_api_key: Optional[str],
+) -> Optional[str]:
+    """Return the x-api-key value, else the Bearer token (scheme matched case-insensitively)."""
+    if x_api_key:
+        return x_api_key
+    scheme, _, token = (authorization or "").strip().partition(" ")
+    return (token.strip() or None) if scheme.lower() == "bearer" else None
+
+
+def _claude_error_response(
+    message: str,
+    status_code: int,
+    error_type: str,
+) -> JSONResponse:
+    """Wrap an error in the Claude-style ``{"type": "error", ...}`` JSON envelope."""
+    error_detail = {"type": error_type, "message": message}
+    envelope = {"type": "error", "error": error_detail}
+    return JSONResponse(
+        status_code=status_code,
+        content=envelope,
+    )
+
+
+def _build_openai_request(body: Dict[str, Any]) -> OpenAIRequest:
+    """Translate a Claude Messages request body into an ``OpenAIRequest``."""
+    converted_messages = claude_messages_to_openai(
+        body.get("system"),
+        body.get("messages", []),
+    )
+    converted_tools = claude_tools_to_openai(body.get("tools"))
+    converted_choice = claude_tool_choice_to_openai(body.get("tool_choice"))
+
+    # Only an explicit "enabled"/"disabled" thinking type sets the flag; otherwise it stays None.
+    thinking_cfg = body.get("thinking")
+    enable_thinking = None
+    if isinstance(thinking_cfg, dict):
+        thinking_type = thinking_cfg.get("type")
+        if thinking_type in ("enabled", "disabled"):
+            enable_thinking = thinking_type == "enabled"
+
+    validated = [Message.model_validate(item) for item in converted_messages]
+    requested_model = body.get("model", settings.GLM5_MODEL)
+    resolved_model = _resolve_claude_model(requested_model)
+    if resolved_model != requested_model:
+        logger.info(f"🔀 Claude 模型映射: {requested_model} -> {resolved_model}")
+
+    request_fields = {
+        "model": resolved_model,
+        "messages": validated,
+        "stream": bool(body.get("stream", False)),
+        "temperature": body.get("temperature"),
+        "max_tokens": body.get("max_tokens"),
+        "tools": converted_tools,
+        "tool_choice": converted_choice,
+        "enable_thinking": enable_thinking,
+    }
+    return OpenAIRequest(**request_fields)
+
+
+def _build_prompt_text(body: Dict[str, Any]) -> str:
+    """Join the system prompt and every message's text with newlines."""
+    system_prompt = body.get("system")
+    texts: List[str] = []
+    if system_prompt:
+        texts.append(extract_text(system_prompt))
+
+    for entry in body.get("messages", []):
+        entry_content = entry.get("content") if isinstance(entry, dict) else None
+        texts.append(extract_text(entry_content))
+
+    # Empty fragments are dropped so they do not produce blank lines.
+    return "\n".join(filter(None, texts))
+
+
+def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
+    """Coerce raw tool calls into well-formed OpenAI ``function`` tool calls.
+
+    Non-list input yields an empty list. Non-dict entries and entries that
+    repeat an already seen id are dropped; entries without an id get a
+    generated one, and non-string arguments are serialized to JSON.
+    """
+    result: List[Dict[str, Any]] = []
+    known_ids = set()
+    for entry in tool_calls if isinstance(tool_calls, list) else ():
+        if not isinstance(entry, dict):
+            continue
+
+        call_id = entry.get("id") or f"call_{uuid.uuid4().hex[:24]}"
+        if call_id in known_ids:
+            continue
+        known_ids.add(call_id)
+
+        fn = entry.get("function")
+        if not isinstance(fn, dict):
+            fn = {}
+
+        raw_args = fn.get("arguments", "{}")
+        if isinstance(raw_args, str):
+            serialized = raw_args
+        else:
+            try:
+                serialized = json.dumps(raw_args, ensure_ascii=False)
+            except Exception:
+                # Unserializable arguments degrade to an empty JSON object.
+                serialized = "{}"
+
+        normalized_call = {
+            "id": call_id,
+            "type": "function",
+            "function": {"name": fn.get("name", ""), "arguments": serialized},
+        }
+        result.append(normalized_call)
+    return result
+
+
+def _convert_openai_response_to_claude(response: Dict[str, Any], msg_id: str) -> Dict[str, Any]:
+    """Build a Claude message from the first choice of an OpenAI-style response."""
+    first_choice = (response.get("choices") or [{}])[0] if isinstance(response, dict) else {}
+    assistant_message = first_choice.get("message") or {}
+    reasoning_text = assistant_message.get("reasoning_content")
+    token_usage = extract_openai_usage(response)
+    usage_keys = ("input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens")
+
+    return build_non_stream_response(
+        msg_id=msg_id,
+        model=response.get("model", settings.GLM5_MODEL),
+        reasoning_parts=[reasoning_text] if isinstance(reasoning_text, str) and reasoning_text else [],
+        answer_text=assistant_message.get("content") or "",
+        tool_calls=_normalize_tool_calls(assistant_message.get("tool_calls")),
+        **{key: token_usage[key] for key in usage_keys},
+    )
+
+
+async def _stream_openai_to_claude(
+    openai_stream: AsyncGenerator[str, None],
+    msg_id: str,
+    model: str,
+    input_tokens: int,
+) -> AsyncGenerator[str, None]:
+    """Re-emit an OpenAI-style SSE chat stream as Claude Messages SSE events.
+
+    Reasoning deltas are forwarded immediately as a ``thinking`` block. Answer
+    text and tool calls are buffered and emitted as complete blocks once the
+    upstream stream ends, followed by ``message_delta`` and ``message_stop``.
+    An upstream error payload, or any exception, ends the stream with an
+    ``error`` event.
+
+    Args:
+        openai_stream: Upstream chunks; only ``data: <json>`` chunks are used.
+        msg_id: Claude message id placed in ``message_start``.
+        model: Model name placed in ``message_start``.
+        input_tokens: Prompt token estimate, replaced if upstream reports usage.
+    """
+    reasoning_parts: List[str] = []
+    answer_parts: List[str] = []
+    tool_calls: List[Dict[str, Any]] = []
+    block_index = 0
+    thinking_started = False
+    final_input_tokens = input_tokens
+    final_output_tokens = 0
+    cache_creation_tokens = 0
+    cache_read_tokens = 0
+
+    yield sse_message_start(msg_id, model, input_tokens)
+    yield sse_ping()
+
+    try:
+        async for chunk in openai_stream:
+            if not chunk.startswith("data: "):
+                continue
+
+            payload_text = chunk[6:].strip()
+            if not payload_text or payload_text == "[DONE]":
+                continue
+
+            payload = json.loads(payload_text)
+            if not isinstance(payload, dict):
+                # A JSON array/scalar carries no choices or usage; skipping it
+                # avoids the AttributeError that ``payload.get`` would raise.
+                continue
+            if "error" in payload:
+                error = payload.get("error") or {}
+                yield sse_error(
+                    error.get("type", "api_error"),
+                    error.get("message", "Upstream error"),
+                )
+                return
+
+            choice = (payload.get("choices") or [{}])[0]
+            delta = choice.get("delta") or {}
+
+            reasoning_delta = delta.get("reasoning_content")
+            if reasoning_delta:
+                if not thinking_started:
+                    thinking_block = {"type": "thinking", "thinking": ""}
+                    yield sse_content_block_start(block_index, thinking_block)
+                    thinking_started = True
+
+                reasoning_parts.append(reasoning_delta)
+                thinking_delta = {"type": "thinking_delta", "thinking": reasoning_delta}
+                yield sse_content_block_delta(block_index, thinking_delta)
+
+            content_delta = delta.get("content")
+            if content_delta:
+                answer_parts.append(content_delta)
+
+            if payload.get("usage"):
+                # Upstream-reported counts replace the estimates when positive.
+                usage = extract_openai_usage(payload)
+                if usage["input_tokens"] > 0:
+                    final_input_tokens = usage["input_tokens"]
+                if usage["output_tokens"] > 0:
+                    final_output_tokens = usage["output_tokens"]
+                if usage["cache_creation_tokens"] > 0:
+                    cache_creation_tokens = usage["cache_creation_tokens"]
+                if usage["cache_read_tokens"] > 0:
+                    cache_read_tokens = usage["cache_read_tokens"]
+
+            tool_calls.extend(_normalize_tool_calls(delta.get("tool_calls")))
+
+        if thinking_started:
+            yield sse_content_block_stop(block_index)
+            block_index += 1
+
+        answer_text = "".join(answer_parts)
+        if answer_text:
+            yield sse_content_block_start(block_index, {"type": "text", "text": ""})
+            text_delta = {"type": "text_delta", "text": answer_text}
+            yield sse_content_block_delta(block_index, text_delta)
+            yield sse_content_block_stop(block_index)
+            block_index += 1
+
+        for tool_call in tool_calls:
+            function_data = tool_call.get("function") or {}
+            fallback_id = f"toolu_{uuid.uuid4().hex[:20]}"
+            tool_id = tool_call.get("id", fallback_id).replace("call_", "toolu_")
+            tool_name = function_data.get("name", "")
+            tool_block = {"type": "tool_use", "id": tool_id, "name": tool_name, "input": {}}
+            yield sse_content_block_start(block_index, tool_block)
+            arguments_json = function_data.get("arguments", "{}")
+            json_delta = {"type": "input_json_delta", "partial_json": arguments_json}
+            yield sse_content_block_delta(block_index, json_delta)
+            yield sse_content_block_stop(block_index)
+            block_index += 1
+
+        if not final_output_tokens:
+            final_output_tokens = _estimate_tokens(
+                "".join(reasoning_parts) + answer_text
+            )
+
+        yield sse_message_delta(
+            "tool_use" if tool_calls else "end_turn",
+            final_output_tokens,
+            input_tokens=final_input_tokens,
+            cache_creation_tokens=cache_creation_tokens,
+            cache_read_tokens=cache_read_tokens,
+        )
+        yield sse_message_stop()
+    except Exception as exc:
+        logger.error(f"❌ Claude 流式响应转换失败: {exc}")
+        yield sse_error("api_error", str(exc))
+
+
+@router.post("/v1/messages")
+@router.post("/anthropic/v1/messages")
+async def claude_messages(
+    request: Request,
+    authorization: Optional[str] = Header(None),
+    x_api_key: Optional[str] = Header(None, alias="x-api-key"),
+):
+    """Handle Claude Messages API requests by proxying them to the upstream client.
+
+    The JSON body is authenticated (``x-api-key`` or Bearer token, unless
+    ``SKIP_AUTH_TOKEN`` is set), converted to an ``OpenAIRequest`` and sent
+    upstream. The reply is returned as Claude SSE events when streaming was
+    requested, otherwise as Claude message JSON. Failures get a Claude-style
+    error body; a body that is not a JSON object is rejected with HTTP 400.
+    """
+    source_info = detect_request_source(request, protocol_hint="anthropic")
+    started_at = time.perf_counter()
+    requested_model = "unknown"
+
+    try:
+        body = await request.json()
+        if not isinstance(body, dict):
+            # Valid JSON that is not an object (e.g. a list) has no ``.get``.
+            raise ValueError("JSON body must be an object")
+    except Exception:
+        await write_request_log(
+            provider="zai",
+            model=requested_model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=400,
+            error_message="Invalid JSON body",
+        )
+        return _claude_error_response("Invalid JSON body", 400, "invalid_request_error")
+
+    requested_model = str(body.get("model") or "unknown")
+    # Detect the source again now that the requested model is available as a hint.
+    source_info = detect_request_source(
+        request,
+        protocol_hint="anthropic",
+        model_hint=body.get("model"),
+    )
+    source_prefix = format_request_source(source_info)
+
+    if not settings.SKIP_AUTH_TOKEN:
+        api_key = _extract_api_key(authorization, x_api_key)
+        if not api_key:
+            await write_request_log(
+                provider="zai",
+                model=requested_model,
+                source_info=source_info,
+                success=False,
+                started_at=started_at,
+                status_code=401,
+                error_message="Missing API key",
+            )
+            return _claude_error_response("Missing API key", 401, "authentication_error")
+        if api_key != settings.AUTH_TOKEN:
+            await write_request_log(
+                provider="zai",
+                model=requested_model,
+                source_info=source_info,
+                success=False,
+                started_at=started_at,
+                status_code=401,
+                error_message="Invalid API key",
+            )
+            return _claude_error_response(
+                "Invalid API key",
+                401,
+                "authentication_error",
+            )
+
+    try:
+        openai_request = _build_openai_request(body)
+    except Exception as exc:
+        await write_request_log(
+            provider="zai",
+            model=requested_model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=400,
+            error_message=f"Invalid request: {exc}",
+        )
+        return _claude_error_response(
+            f"Invalid request: {exc}",
+            400,
+            "invalid_request_error",
+        )
+
+    if not openai_request.messages:
+        await write_request_log(
+            provider="zai",
+            model=openai_request.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=400,
+            error_message="messages is required",
+        )
+        return _claude_error_response(
+            "messages is required",
+            400,
+            "invalid_request_error",
+        )
+    logger.info(
+        f"{source_prefix} 🤖 收到 Claude 请求 - 模型: {body.get('model')}, 映射模型: {openai_request.model}, 流式: {openai_request.stream}, 消息数: {len(openai_request.messages)}, 工具数: {len(openai_request.tools) if openai_request.tools else 0}"
+    )
+
+    msg_id = make_claude_id()
+    input_tokens = _estimate_tokens(_build_prompt_text(body))
+
+    try:
+        client = get_upstream_client()
+        result = await client.chat_completion(openai_request)
+    except Exception as exc:
+        logger.error(f"{source_prefix} ❌ Claude 请求处理失败: {exc}")
+        await write_request_log(
+            provider="zai",
+            model=openai_request.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=500,
+            error_message=str(exc),
+        )
+        return _claude_error_response(str(exc), 500, "api_error")
+
+    if isinstance(result, dict) and "error" in result:
+        error = result.get("error") or {}
+        error_code = error.get("code")
+        status_code = error_code if isinstance(error_code, int) else 500
+        await write_request_log(
+            provider="zai",
+            model=openai_request.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=status_code,
+            error_message=error.get("message", "Unknown upstream error"),
+        )
+        return _claude_error_response(
+            error.get("message", "Unknown upstream error"),
+            status_code,
+            error.get("type", "api_error"),
+        )
+
+    if openai_request.stream:
+        if not hasattr(result, "__aiter__"):
+            await write_request_log(
+                provider="zai",
+                model=openai_request.model,
+                source_info=source_info,
+                success=False,
+                started_at=started_at,
+                status_code=500,
+                error_message="Expected streaming response",
+            )
+            return _claude_error_response(
+                "Expected streaming response",
+                500,
+                "api_error",
+            )
+
+        return StreamingResponse(
+            wrap_claude_stream_with_logging(
+                _stream_openai_to_claude(
+                    result,
+                    msg_id,
+                    openai_request.model,
+                    input_tokens,
+                ),
+                provider="zai",
+                model=openai_request.model,
+                source_info=source_info,
+                started_at=started_at,
+                input_tokens=input_tokens,
+            ),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "Access-Control-Allow-Origin": "*",
+            },
+        )
+
+    if not isinstance(result, dict):
+        await write_request_log(
+            provider="zai",
+            model=openai_request.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=500,
+            error_message="Expected non-streaming response payload",
+        )
+        return _claude_error_response(
+            "Expected non-streaming response payload",
+            500,
+            "api_error",
+        )
+
+    response_data = _convert_openai_response_to_claude(result, msg_id)
+    if not response_data.get("usage", {}).get("input_tokens"):
+        response_data["usage"]["input_tokens"] = input_tokens
+    usage = extract_claude_usage(response_data)
+    await write_request_log(
+        provider="zai",
+        model=openai_request.model,
+        source_info=source_info,
+        success=True,
+        started_at=started_at,
+        status_code=200,
+        input_tokens=usage["input_tokens"],
+        output_tokens=usage["output_tokens"],
+        cache_creation_tokens=usage["cache_creation_tokens"],
+        cache_read_tokens=usage["cache_read_tokens"],
+        total_tokens=usage["total_tokens"],
+    )
+    return JSONResponse(content=response_data)
diff --git a/app/core/claude_compat.py b/app/core/claude_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae26ec1ab2de23547a39f926d5de6197b4e2427d
--- /dev/null
+++ b/app/core/claude_compat.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""Claude Messages API 兼容辅助函数。"""
+
+from __future__ import annotations
+
+import json
+import uuid
+from typing import Any, Optional
+
+
+def extract_text(content: Any) -> str:
+    """Return the text of a string, or of the ``text`` blocks in a content list."""
+    if isinstance(content, str):
+        return content
+    if not isinstance(content, list):
+        return str(content) if content else ""
+
+    # Only blocks typed "text" contribute; fragments are space-separated.
+    fragments = []
+    for item in content:
+        if isinstance(item, dict) and item.get("type") == "text":
+            fragments.append(str(item.get("text", "")))
+    return " ".join(fragments).strip()
+
+
+def claude_messages_to_openai(system: Any, messages: list[dict]) -> list[dict]:
+    """Convert a Claude ``system`` + ``messages`` payload to OpenAI chat messages.
+
+    Args:
+        system: System prompt, either a string or a list of text blocks.
+        messages: Claude messages. ``None`` is treated as empty and entries
+            that are not dicts are skipped instead of raising.
+
+    Returns:
+        OpenAI-style message dicts. Assistant ``tool_use`` blocks become
+        ``tool_calls``; user ``tool_result`` blocks become ``tool`` messages.
+    """
+    converted: list[dict] = []
+
+    if isinstance(system, str) and system:
+        converted.append({"role": "system", "content": system})
+    elif isinstance(system, list):
+        system_text = [
+            block.get("text", "")
+            for block in system
+            if isinstance(block, dict) and block.get("type") == "text"
+        ]
+        if system_text:
+            converted.append({"role": "system", "content": "\n".join(system_text)})
+
+    for message in messages or []:
+        if not isinstance(message, dict):
+            # Malformed entries are ignored rather than aborting the conversion.
+            continue
+        role = message.get("role", "user")
+        content = message.get("content", "")
+
+        if role == "assistant" and isinstance(content, list):
+            text_parts: list[str] = []
+            tool_calls: list[dict] = []
+            for block in content:
+                if not isinstance(block, dict):
+                    continue
+                block_type = block.get("type")
+                if block_type == "text":
+                    text_parts.append(block.get("text", ""))
+                elif block_type == "tool_use":
+                    arguments = json.dumps(block.get("input", {}), ensure_ascii=False)
+                    tool_calls.append(
+                        {
+                            "id": block.get("id", f"call_{uuid.uuid4().hex[:24]}"),
+                            "type": "function",
+                            "function": {
+                                "name": block.get("name", ""),
+                                "arguments": arguments,
+                            },
+                        }
+                    )
+
+            openai_message: dict = {
+                "role": "assistant",
+                # Text blocks are space-joined; empty text becomes None.
+                "content": " ".join(text_parts).strip() or None,
+            }
+            if tool_calls:
+                openai_message["tool_calls"] = tool_calls
+            converted.append(openai_message)
+            continue
+
+        if role == "user" and isinstance(content, list):
+            has_tool_result = any(
+                isinstance(block, dict) and block.get("type") == "tool_result"
+                for block in content
+            )
+            if has_tool_result:
+                # Each tool_result becomes its own "tool" message and each text
+                # block its own "user" message; other block types are dropped.
+                for block in content:
+                    if not isinstance(block, dict):
+                        continue
+                    block_type = block.get("type")
+                    if block_type == "tool_result":
+                        result_content = block.get("content", "")
+                        if isinstance(result_content, str):
+                            rendered = result_content
+                        elif isinstance(result_content, list):
+                            rendered = " ".join(
+                                item.get("text", "")
+                                for item in result_content
+                                if isinstance(item, dict) and item.get("type") == "text"
+                            )
+                        else:
+                            rendered = str(result_content)
+                        converted.append(
+                            {
+                                "role": "tool",
+                                "tool_call_id": block.get("tool_use_id", ""),
+                                "content": rendered,
+                            }
+                        )
+                    elif block_type == "text":
+                        converted.append({"role": "user", "content": block.get("text", "")})
+                continue
+
+        converted.append({"role": role, "content": extract_text(content)})
+
+    return converted
+
+
+def claude_tools_to_openai(tools: Optional[list[dict]]) -> Optional[list[dict]]:
+    """Map Claude tool definitions onto OpenAI ``function`` tool entries.
+
+    Returns ``None`` when there are no tools or none of them is a dict.
+    """
+    openai_tools: list[dict] = []
+    for spec in tools or []:
+        if not isinstance(spec, dict):
+            continue
+        function_def = {
+            "name": spec.get("name", ""),
+            "description": spec.get("description", ""),
+            "parameters": spec.get("input_schema", {}),
+        }
+        openai_tools.append({"type": "function", "function": function_def})
+
+    # An empty result is reported as None, the same as missing input.
+    return openai_tools or None
+
+
+def claude_tool_choice_to_openai(tool_choice: Any) -> Any:
+    """Translate a Claude ``tool_choice`` object into its OpenAI equivalent.
+
+    Non-dict values and dicts that cannot be mapped are returned unchanged.
+    """
+    if not isinstance(tool_choice, dict):
+        return tool_choice
+
+    choice_type = tool_choice.get("type", "auto")
+    simple_modes = (("auto", "auto"), ("any", "required"), ("none", "none"))
+    for claude_name, openai_name in simple_modes:
+        if choice_type == claude_name:
+            return openai_name
+    tool_name = tool_choice.get("name", "") if choice_type == "tool" else ""
+    if tool_name:
+        return {"type": "function", "function": {"name": tool_name}}
+    return tool_choice
+
+
+def make_claude_id() -> str:
+    """Return a new ``msg_``-prefixed id followed by 24 random hex characters."""
+    return "msg_" + uuid.uuid4().hex[:24]
+
+
+def build_tool_call_blocks(tool_calls: list[dict]) -> list[dict]:
+    """Convert OpenAI tool calls to Claude ``tool_use`` blocks.
+
+    Non-dict entries are skipped, a missing/None id is replaced by a generated
+    one, and arguments that do not decode to a JSON object give ``input == {}``.
+    """
+    blocks = []
+    for tool_call in tool_calls or []:
+        if not isinstance(tool_call, dict):
+            continue
+        function_data = tool_call.get("function")
+        if not isinstance(function_data, dict):
+            function_data = {}
+        arguments = function_data.get("arguments", "{}")
+        try:
+            input_data = json.loads(arguments) if isinstance(arguments, str) else arguments
+        except (ValueError, RecursionError):
+            input_data = {}
+        # ``or`` also replaces an explicit ``"id": None``, which ``get`` would return.
+        tool_id = str(tool_call.get("id") or f"toolu_{uuid.uuid4().hex[:20]}")
+        block = {"type": "tool_use", "id": tool_id.replace("call_", "toolu_")}
+        block["name"] = function_data.get("name", "")
+        # Claude's ``input`` must be an object, so other JSON values are discarded.
+        block["input"] = input_data if isinstance(input_data, dict) else {}
+        blocks.append(block)
+
+    return blocks
+
+
+def build_non_stream_response(
+    msg_id: str,
+    model: str,
+    reasoning_parts: list[str],
+    answer_text: str,
+    tool_calls: Optional[list[dict]],
+    input_tokens: int,
+    output_tokens: int,
+    cache_creation_tokens: int = 0,
+    cache_read_tokens: int = 0,
+) -> dict:
+    """Assemble the JSON body of a non-streaming Claude message response."""
+    blocks: list[dict] = []
+    if reasoning_parts:
+        blocks.append({"type": "thinking", "thinking": "".join(reasoning_parts)})
+    # A text block is always present unless the reply consists of tool calls only.
+    if answer_text or not tool_calls:
+        blocks.append({"type": "text", "text": answer_text or ""})
+    if tool_calls:
+        blocks.extend(build_tool_call_blocks(tool_calls))
+
+    stop_reason = "tool_use" if tool_calls else "end_turn"
+
+    usage = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cache_creation_input_tokens": cache_creation_tokens,
+        "cache_read_input_tokens": cache_read_tokens,
+    }
+    return {
+        "id": msg_id,
+        "type": "message",
+        "role": "assistant",
+        "content": blocks,
+        "model": model,
+        "stop_reason": stop_reason,
+        "stop_sequence": None,
+        "usage": usage,
+    }
+
+
+def sse(event: str, data: dict) -> str:
+    """Serialize one server-sent event: an ``event:`` line, a JSON ``data:`` line and a blank line."""
+    return "event: {}\ndata: {}\n\n".format(event, json.dumps(data, ensure_ascii=False))
+
+
+def sse_message_start(
+    msg_id: str,
+    model: str,
+    input_tokens: int,
+    cache_creation_tokens: int = 0,
+    cache_read_tokens: int = 0,
+) -> str:
+    """Build the Claude ``message_start`` SSE event for a new assistant message."""
+    initial_usage = {
+        "input_tokens": input_tokens,
+        "cache_creation_input_tokens": cache_creation_tokens,
+        "cache_read_input_tokens": cache_read_tokens,
+        "output_tokens": 0,
+    }
+    # The message starts with empty content and zero output tokens.
+    message_stub = {
+        "id": msg_id,
+        "type": "message",
+        "role": "assistant",
+        "content": [],
+        "model": model,
+        "stop_reason": None,
+        "stop_sequence": None,
+        "usage": initial_usage,
+    }
+    return sse(
+        "message_start",
+        {"type": "message_start", "message": message_stub},
+    )
+
+
+def sse_ping() -> str:
+    """Build a Claude ``ping`` SSE event."""
+    return sse(event="ping", data={"type": "ping"})
+
+
+def sse_content_block_start(index: int, block: dict) -> str:
+    """Build the ``content_block_start`` SSE event for the block at ``index``."""
+    event_name = "content_block_start"
+    payload = {
+        "type": event_name,
+        "index": index,
+        "content_block": block,
+    }
+
+    return sse(event_name, payload)
+
+
+def sse_content_block_delta(index: int, delta: dict) -> str:
+    """Build a ``content_block_delta`` SSE event carrying ``delta`` for block ``index``."""
+    event_name = "content_block_delta"
+    payload = {"type": event_name, "index": index, "delta": delta}
+
+    return sse(event_name, payload)
+
+
+def sse_content_block_stop(index: int) -> str:
+    """Build the ``content_block_stop`` SSE event for the block at ``index``."""
+    event_name = "content_block_stop"
+    payload = {"type": event_name, "index": index}
+
+    return sse(event_name, payload)
+
+
+def sse_message_delta(
+    stop_reason: str,
+    output_tokens: int,
+    *,
+    input_tokens: int = 0,
+    cache_creation_tokens: int = 0,
+    cache_read_tokens: int = 0,
+) -> str:
+    """Build the ``message_delta`` SSE event with the stop reason and usage counts."""
+    final_usage = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cache_creation_input_tokens": cache_creation_tokens,
+        "cache_read_input_tokens": cache_read_tokens,
+    }
+    payload = {
+        "type": "message_delta",
+        "delta": {"stop_reason": stop_reason, "stop_sequence": None},
+        "usage": final_usage,
+    }
+
+    return sse("message_delta", payload)
+
+
+def sse_message_stop() -> str:
+    """Build the Claude ``message_stop`` SSE event."""
+    return sse(event="message_stop", data={"type": "message_stop"})
+
+
+def sse_error(error_type: str, message: str) -> str:
+    """Build a Claude ``error`` SSE event with the given error type and message."""
+    error_detail = {
+        "type": error_type,
+        "message": message,
+    }
+    payload = {"type": "error", "error": error_detail}
+
+    return sse("error", payload)
diff --git a/app/core/config.py b/app/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d830de82a00dc1e1bc21b17195215140d7def612
--- /dev/null
+++ b/app/core/config.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+from typing import Optional
+
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    """Application settings; each default below is read from the process environment when this class body runs."""
+
+    # Upstream chat-completions endpoint (hard-coded, no environment default)
+    API_ENDPOINT: str = "https://chat.z.ai/api/v2/chat/completions"
+
+    # Bearer key that clients of this service must present (None when AUTH_TOKEN is unset)
+    AUTH_TOKEN: Optional[str] = os.getenv("AUTH_TOKEN")
+
+    # Token pool configuration
+    TOKEN_FAILURE_THRESHOLD: int = int(  # int() raises ValueError here if the variable is not numeric
+        os.getenv("TOKEN_FAILURE_THRESHOLD", "3")
+    )
+    TOKEN_RECOVERY_TIMEOUT: int = int(
+        os.getenv("TOKEN_RECOVERY_TIMEOUT", "1800")
+    )
+    TOKEN_AUTO_IMPORT_ENABLED: bool = (  # flag defaults are True only for the literal "true" (any case)
+        os.getenv("TOKEN_AUTO_IMPORT_ENABLED", "false").lower() == "true"
+    )
+    TOKEN_AUTO_IMPORT_SOURCE_DIR: str = os.getenv("TOKEN_AUTO_IMPORT_SOURCE_DIR", "")
+    TOKEN_AUTO_IMPORT_INTERVAL: int = int(
+        os.getenv("TOKEN_AUTO_IMPORT_INTERVAL", "300")
+    )
+    TOKEN_AUTO_MAINTENANCE_ENABLED: bool = (
+        os.getenv("TOKEN_AUTO_MAINTENANCE_ENABLED", "false").lower() == "true"
+    )
+    TOKEN_AUTO_MAINTENANCE_INTERVAL: int = int(
+        os.getenv("TOKEN_AUTO_MAINTENANCE_INTERVAL", "1800")
+    )
+    TOKEN_AUTO_REMOVE_DUPLICATES: bool = (
+        os.getenv("TOKEN_AUTO_REMOVE_DUPLICATES", "true").lower() == "true"
+    )
+    TOKEN_AUTO_HEALTH_CHECK: bool = (
+        os.getenv("TOKEN_AUTO_HEALTH_CHECK", "true").lower() == "true"
+    )
+    TOKEN_AUTO_DELETE_INVALID: bool = (
+        os.getenv("TOKEN_AUTO_DELETE_INVALID", "false").lower() == "true"
+    )
+
+    # Public model names exposed by this service (each overridable via its environment variable)
+    GLM45_MODEL: str = os.getenv("GLM45_MODEL", "GLM-4.5")
+    GLM45_THINKING_MODEL: str = os.getenv("GLM45_THINKING_MODEL", "GLM-4.5-Thinking")
+    GLM45_SEARCH_MODEL: str = os.getenv("GLM45_SEARCH_MODEL", "GLM-4.5-Search")
+    GLM45_AIR_MODEL: str = os.getenv("GLM45_AIR_MODEL", "GLM-4.5-Air")
+    GLM46V_MODEL: str = os.getenv("GLM46V_MODEL", "GLM-4.6V")
+    GLM5_MODEL: str = os.getenv("GLM5_MODEL", "GLM-5")
+    GLM47_MODEL: str = os.getenv("GLM47_MODEL", "GLM-4.7")
+    GLM47_THINKING_MODEL: str = os.getenv("GLM47_THINKING_MODEL", "GLM-4.7-Thinking")
+    GLM47_SEARCH_MODEL: str = os.getenv("GLM47_SEARCH_MODEL", "GLM-4.7-Search")
+    GLM47_ADVANCED_SEARCH_MODEL: str = os.getenv(
+        "GLM47_ADVANCED_SEARCH_MODEL",
+        "GLM-4.7-advanced-search",
+    )
+
+    # Server configuration
+    LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080"))
+    DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true"  # on unless explicitly disabled
+    SERVICE_NAME: str = os.getenv("SERVICE_NAME", "api-proxy-server")
+    ROOT_PATH: str = os.getenv("ROOT_PATH", "")
+
+    ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true"
+    GUEST_POOL_SIZE: int = int(os.getenv("GUEST_POOL_SIZE", "3"))
+    TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
+    SCAN_LIMIT: int = int(os.getenv("SCAN_LIMIT", "200000"))
+    SKIP_AUTH_TOKEN: bool = os.getenv("SKIP_AUTH_TOKEN", "false").lower() == "true"  # True disables the client key check
+
+    # Outbound proxy configuration (all optional)
+    HTTP_PROXY: Optional[str] = os.getenv("HTTP_PROXY")
+    HTTPS_PROXY: Optional[str] = os.getenv("HTTPS_PROXY")
+    SOCKS5_PROXY: Optional[str] = os.getenv("SOCKS5_PROXY")
+
+    # Admin panel authentication
+    ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD", "admin123")  # NOTE(review): weak fallback; set ADMIN_PASSWORD in real deployments
+    SESSION_SECRET_KEY: str = os.getenv(
+        "SESSION_SECRET_KEY",
+        "your-secret-key-change-in-production",  # NOTE(review): placeholder secret; must be overridden outside development
+    )
+    DB_PATH: str = os.getenv("DB_PATH", "tokens.db")
+
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        extra="ignore",  # ignore extra fields so unknown environment entries do not cause validation errors
+    )
+
+
+settings = Settings()  # module-level instance; other modules import it as ``from app.core.config import settings``
diff --git a/app/core/openai.py b/app/core/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..b17b3e97f331af11d017511f916d3292ba23a765
--- /dev/null
+++ b/app/core/openai.py
@@ -0,0 +1,224 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import json
+import time
+from typing import Optional
+
+from fastapi import APIRouter, Header, HTTPException, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from app.core.config import settings
+from app.models.schemas import (
+ Choice,
+ Message,
+ Model,
+ ModelsResponse,
+ OpenAIRequest,
+ OpenAIResponse,
+ Usage,
+)
+from app.core.upstream import UpstreamClient
+from app.utils.logger import get_logger
+from app.utils.request_logging import (
+ extract_openai_usage,
+ wrap_openai_stream_with_logging,
+ write_request_log,
+)
+from app.utils.request_source import detect_request_source, format_request_source
+
+logger = get_logger()
+router = APIRouter()  # the route handlers below register themselves on this router
+
+_upstream_client: Optional[UpstreamClient] = None  # cache filled on first use by get_upstream_client()
+
+
+def get_upstream_client() -> UpstreamClient:
+    """Return the shared ``UpstreamClient``, constructing it on first use."""
+    global _upstream_client
+    # Build the adapter lazily; later calls reuse the cached instance.
+    _upstream_client = UpstreamClient() if _upstream_client is None else _upstream_client
+    return _upstream_client
+
+
+async def handle_non_stream_response(stream_response, request: OpenAIRequest) -> JSONResponse:
+    """Collapse an SSE chunk stream into a single non-streaming chat completion.
+
+    ``stream_response`` is called without arguments and iterated with
+    ``async for``. Each ``data: `` chunk whose first choice carries a non-empty
+    ``delta.content`` contributes text; chunks that are not valid JSON are
+    skipped. Usage counters in the result are always zero.
+    """
+    logger.info("📄 开始处理非流式响应")
+    marker = "data: "
+    pieces = []
+    async for raw_chunk in stream_response():
+        if not raw_chunk.startswith(marker):
+            continue
+        body = raw_chunk[len(marker):].strip()
+        if not body or body == "[DONE]":
+            continue
+        try:
+            parsed = json.loads(body)
+        except json.JSONDecodeError:
+            continue
+        if "choices" not in parsed or not parsed["choices"]:
+            continue
+        first_choice = parsed["choices"][0]
+        if "delta" in first_choice and "content" in first_choice["delta"]:
+            text = first_choice["delta"]["content"]
+            if text:
+                pieces.append(text)
+
+    message = Message(role="assistant", content="".join(pieces), tool_calls=None)
+    response_data = OpenAIResponse(
+        id=f"chatcmpl-{int(time.time())}",
+        object="chat.completion",
+        created=int(time.time()),
+        model=request.model,
+        choices=[Choice(index=0, message=message, finish_reason="stop")],
+        usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+    )
+
+    logger.info("✅ 非流式响应处理完成")
+    return JSONResponse(content=response_data.model_dump(exclude_none=True))
+
+
+@router.get("/v1/models")
+async def list_models():
+    """List the supported models in OpenAI ``/v1/models`` form; any failure becomes HTTP 500."""
+    try:
+        model_ids = get_upstream_client().get_supported_models()
+        created_at = int(time.time())  # one timestamp shared by every entry
+        models = [
+            Model(id=model_id, created=created_at, owned_by=settings.SERVICE_NAME)
+            for model_id in model_ids
+        ]
+        payload = ModelsResponse(data=models).model_dump(exclude_none=True)
+        return JSONResponse(content=payload)
+    except Exception as exc:
+        # Log the cause, then surface it to the client as a 500.
+        logger.error(f"❌ 获取模型列表失败: {exc}")
+        raise HTTPException(status_code=500, detail=f"Failed to list models: {exc}")
+
+
+@router.post("/v1/chat/completions")
+async def chat_completions(
+    body: OpenAIRequest,
+    http_request: Request,
+    authorization: Optional[str] = Header(None),
+):
+    """Authenticate the client, forward the request to the upstream adapter and log the outcome."""
+    import hmac  # function-scope import: only the key check below needs it
+    source_info = detect_request_source(http_request, protocol_hint="openai", model_hint=body.model)
+    source_prefix = format_request_source(source_info)
+    started_at = time.perf_counter()
+
+    role = body.messages[0].role if body.messages else "unknown"
+    logger.info(
+        f"{source_prefix} 😶🌫️ 收到客户端请求 - 模型: {body.model}, 流式: {body.stream}, 消息数: {len(body.messages)}, 角色: {role}, 工具数: {len(body.tools) if body.tools else 0}"
+    )
+
+    try:
+        if not settings.SKIP_AUTH_TOKEN:
+            if not authorization or not authorization.startswith("Bearer "):
+                raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
+
+            # Constant-time comparison so response timing does not reveal how much of the key matched.
+            expected_key = settings.AUTH_TOKEN
+            if expected_key is None or not hmac.compare_digest(
+                authorization[7:].encode("utf-8"), expected_key.encode("utf-8")
+            ):
+                raise HTTPException(status_code=401, detail="Invalid API key")
+
+        client = get_upstream_client()
+        result = await client.chat_completion(body)
+
+        if isinstance(result, dict) and "error" in result:
+            error_info = result["error"]
+            error_message = error_info.get("message", "Unknown upstream error")
+            error_code = error_info.get("code")
+            status_code = 404 if error_code == "model_not_found" else 500
+            raise HTTPException(status_code=status_code, detail=error_message)
+
+        if body.stream:
+            if hasattr(result, "__aiter__"):
+                return StreamingResponse(
+                    wrap_openai_stream_with_logging(
+                        result,
+                        provider="zai",
+                        model=body.model,
+                        source_info=source_info,
+                        started_at=started_at,
+                    ),
+                    media_type="text/event-stream",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Access-Control-Allow-Origin": "*",
+                    },
+                )
+            raise HTTPException(
+                status_code=500,
+                detail="Expected streaming response but got non-streaming result",
+            )
+
+        if isinstance(result, dict):
+            usage = extract_openai_usage(result)
+            await write_request_log(
+                provider="zai",
+                model=body.model,
+                source_info=source_info,
+                success="error" not in result,  # always True here: error results were raised above
+                started_at=started_at,
+                status_code=200 if "error" not in result else 500,
+                input_tokens=usage["input_tokens"],
+                output_tokens=usage["output_tokens"],
+                cache_creation_tokens=usage["cache_creation_tokens"],
+                cache_read_tokens=usage["cache_read_tokens"],
+                total_tokens=usage["total_tokens"],
+                error_message=(result.get("error") or {}).get("message") if isinstance(result, dict) else None,
+            )
+            return JSONResponse(content=result)
+
+        response = await handle_non_stream_response(result, body)  # non-dict result: aggregate its chunks
+        response_body = json.loads(response.body)
+        usage = extract_openai_usage(response_body)
+        await write_request_log(
+            provider="zai",
+            model=body.model,
+            source_info=source_info,
+            success=True,
+            started_at=started_at,
+            status_code=200,
+            input_tokens=usage["input_tokens"],
+            output_tokens=usage["output_tokens"],
+            cache_creation_tokens=usage["cache_creation_tokens"],
+            cache_read_tokens=usage["cache_read_tokens"],
+            total_tokens=usage["total_tokens"],
+        )
+        return response
+
+    except HTTPException as exc:  # already an HTTP error: log it and re-raise unchanged
+        await write_request_log(
+            provider="zai",
+            model=body.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=exc.status_code,
+            error_message=str(exc.detail),
+        )
+        raise
+    except Exception as exc:  # anything unexpected is logged and reported as a 500
+        logger.error(f"{source_prefix} ❌ 请求处理失败: {exc}")
+        await write_request_log(
+            provider="zai",
+            model=body.model,
+            source_info=source_info,
+            success=False,
+            started_at=started_at,
+            status_code=500,
+            error_message=str(exc),
+        )
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(exc)}") from exc
diff --git a/app/core/openai_compat.py b/app/core/openai_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..cccdc094e617e447e1bdb11cd68afec2c982a72a
--- /dev/null
+++ b/app/core/openai_compat.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""OpenAI 兼容响应辅助函数。"""
+
+import json
+import time
+import uuid
+from typing import Any, Dict, List, Optional
+
+from app.utils.logger import get_logger
+
+logger = get_logger()
+SYSTEM_FINGERPRINT = "fp_api_proxy_001"  # constant placed in the system_fingerprint field of every payload built below
+
+
+def create_chat_id() -> str:
+    """Return a new chat id: ``chatcmpl-`` followed by a random UUID4 hex string."""
+    return "chatcmpl-" + uuid.uuid4().hex
+
+
+def create_openai_chunk(
+    chat_id: str,
+    model: str,
+    delta: Dict[str, Any],
+    finish_reason: Optional[str] = None,
+) -> Dict[str, Any]:
+    """Build one OpenAI ``chat.completion.chunk`` payload.
+
+    The chunk has a single choice (index 0) holding ``delta`` and
+    ``finish_reason``; ``created`` is the current Unix time in seconds.
+    """
+    # ``logprobs`` is always emitted as None.
+    choice = {"index": 0, "delta": delta, "finish_reason": finish_reason, "logprobs": None}
+    chunk: Dict[str, Any] = {
+        "id": chat_id,
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [choice],
+        "system_fingerprint": SYSTEM_FINGERPRINT,
+    }
+    return chunk
+
+
+def create_openai_response(
+    chat_id: str,
+    model: str,
+    content: str,
+    usage: Optional[Dict[str, int]] = None,
+) -> Dict[str, Any]:
+    """Build a non-streaming OpenAI ``chat.completion`` payload.
+
+    The response has one assistant choice with ``finish_reason`` ``"stop"``.
+    A missing or empty ``usage`` is replaced by all-zero token counters.
+    """
+    if not usage:
+        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+    choice = {
+        "index": 0,
+        "message": {"role": "assistant", "content": content},
+        "finish_reason": "stop",
+        "logprobs": None,
+    }
+    return {
+        "id": chat_id,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [choice],
+        "usage": usage,
+        "system_fingerprint": SYSTEM_FINGERPRINT,
+    }
+
+
+def create_openai_response_with_reasoning(
+    chat_id: str,
+    model: str,
+    content: str,
+    reasoning_content: Optional[str] = None,
+    usage: Optional[Dict[str, int]] = None,
+    tool_calls: Optional[List[Dict[str, Any]]] = None,
+) -> Dict[str, Any]:
+    """Build a ``chat.completion`` payload with optional reasoning and tool calls.
+
+    ``reasoning_content`` is attached only when it contains non-whitespace
+    text. A non-empty ``tool_calls`` is attached and switches the finish
+    reason to ``"tool_calls"``; otherwise it is ``"stop"``. A missing or empty
+    ``usage`` is replaced by all-zero token counters.
+    """
+    message: Dict[str, Any] = {"role": "assistant", "content": content}
+    if reasoning_content and reasoning_content.strip():
+        message["reasoning_content"] = reasoning_content
+
+    finish_reason = "stop"
+    if tool_calls:
+        message["tool_calls"] = tool_calls
+        finish_reason = "tool_calls"
+
+    if not usage:
+        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+    choice = {
+        "index": 0,
+        "message": message,
+        "finish_reason": finish_reason,
+        "logprobs": None,
+    }
+    return {
+        "id": chat_id,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [choice],
+        "usage": usage,
+        "system_fingerprint": SYSTEM_FINGERPRINT,
+    }
+
+
+async def format_sse_chunk(chunk: Dict[str, Any]) -> str:
+    """Serialize ``chunk`` as one SSE ``data:`` frame, keeping non-ASCII text unescaped."""
+    return "data: " + json.dumps(chunk, ensure_ascii=False) + "\n\n"
+
+
+async def format_sse_done() -> str:
+    """Return the SSE frame that marks the end of the stream (``[DONE]``)."""
+    return "data: [DONE]" + "\n\n"
+
+
+def handle_error(error: Exception, context: str = "") -> Dict[str, Any]:
+    """Log an upstream failure and return it as an OpenAI-style error payload."""
+    # A truthy ``context`` is spliced into the message between its two fixed parts.
+    if context:
+        error_msg = f"上游{context}错误: {str(error)}"
+    else:
+        error_msg = f"上游错误: {str(error)}"
+    logger.error(error_msg)
+    # The text that was logged is also what the caller receives.
+    details = {"message": error_msg, "type": "upstream_error", "code": "internal_error"}
+    return {"error": details}
diff --git a/app/core/upstream.py b/app/core/upstream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab183a5fd15b336829b7133c113d6630001461e3
--- /dev/null
+++ b/app/core/upstream.py
@@ -0,0 +1,2245 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""上游适配器。"""
+
+import asyncio
+import base64
+import json
+import random
+import time
+import uuid
+from datetime import datetime, timezone
+from typing import Any, AsyncGenerator, Dict, List, Optional, Set, Tuple, Union
+from urllib.parse import urlencode
+
+import httpx
+
+from app.core.config import settings
+from app.core.openai_compat import (
+ create_openai_chunk,
+ create_openai_response_with_reasoning,
+ format_sse_chunk,
+ handle_error,
+)
+from app.models.schemas import OpenAIRequest
+from app.utils.fe_version import get_latest_fe_version
+from app.utils.guest_session_pool import get_guest_session_pool
+from app.utils.logger import get_logger
+from app.utils.signature import generate_signature
+from app.utils.token_pool import get_token_pool
+from app.utils.tool_call_handler import (
+ parse_and_extract_tool_calls,
+)
+from app.utils.user_agent import get_random_user_agent
+
+logger = get_logger()
+
+DEFAULT_ZAI_BASE_URL = "https://chat.z.ai"  # base of the auth URL and of the per-chat page URLs
+CHAT_BOOTSTRAP_MAX_CONTENT_LEN = 500
+DEFAULT_PLATFORM = "web"
+DEFAULT_CLIENT_VERSION = "0.0.1"
+DEFAULT_TIMEZONE = "Asia/Shanghai"
+DEFAULT_LANGUAGE = "zh-CN"
+DEFAULT_SCREEN_WIDTH = "1920"  # numeric values are kept as strings: they are sent as query parameters
+DEFAULT_SCREEN_HEIGHT = "1080"
+DEFAULT_VIEWPORT_WIDTH = "944"
+DEFAULT_VIEWPORT_HEIGHT = "919"
+DEFAULT_VIEWPORT_SIZE = f"{DEFAULT_VIEWPORT_WIDTH}x{DEFAULT_VIEWPORT_HEIGHT}"
+DEFAULT_SCREEN_RESOLUTION = f"{DEFAULT_SCREEN_WIDTH}x{DEFAULT_SCREEN_HEIGHT}"
+DEFAULT_COLOR_DEPTH = "24"
+DEFAULT_PIXEL_RATIO = "1.25"
+DEFAULT_MAX_TOUCH_POINTS = "10"
+DEFAULT_TIMEZONE_OFFSET = "-480"  # presumably minutes in the JS getTimezoneOffset() convention (UTC+8) — confirm
+DEFAULT_PAGE_TITLE = "Z.ai Chat Proxy"
+DEFAULT_COMPLETION_FEATURES = [
+    {"type": "mcp", "server": "vibe-coding", "status": "hidden"},
+    {"type": "mcp", "server": "ppt-maker", "status": "hidden"},
+    {"type": "mcp", "server": "image-search", "status": "hidden"},
+    {"type": "mcp", "server": "deep-research", "status": "hidden"},
+    {"type": "tool_selector", "server": "tool_selector", "status": "hidden"},
+    {"type": "mcp", "server": "advanced-search", "status": "hidden"},
+]
+GLM46V_MCP_SERVERS = [  # copied into the "glm-4.6v" request profile as its mcp_servers
+    "vlm-image-search",
+    "vlm-image-recognition",
+    "vlm-image-processing",
+]
+GLM46V_SELECTED_FEATURES = [  # copied into the "glm-4.6v" request profile as its feature_entries
+    {"type": "mcp", "server": "vlm-image-search", "status": "selected"},
+    {"type": "mcp", "server": "vlm-image-recognition", "status": "selected"},
+    {"type": "mcp", "server": "vlm-image-processing", "status": "selected"},
+]
+
+def generate_uuid() -> str:
+    """Return a random UUID (version 4) in its canonical string form."""
+    return f"{uuid.uuid4()}"
+
+def get_dynamic_headers(
+    chat_id: str = "",
+    browser_type: Optional[str] = None,
+) -> Dict[str, str]:
+    """Build browser-like headers for an upstream request; ``chat_id`` only affects ``Referer``."""
+    browser_choices = [  # repeated entries weight the random pick: chrome 3, edge 2, firefox 1, safari 1
+        "chrome",
+        "chrome",
+        "chrome",
+        "edge",
+        "edge",
+        "firefox",
+        "safari",
+    ]
+    selected_browser = browser_type or random.choice(browser_choices)  # an explicit browser_type wins
+    user_agent = get_random_user_agent(selected_browser)
+    fe_version = get_latest_fe_version()
+
+    chrome_version = "139"  # fallback major version when the User-Agent has no Chrome/ token
+    edge_version = "139"
+
+    if "Chrome/" in user_agent:
+        try:
+            chrome_version = user_agent.split("Chrome/")[1].split(".")[0]  # major version only
+        except Exception:
+            pass  # keep the fallback version
+
+    if "Edg/" in user_agent:
+        try:
+            edge_version = user_agent.split("Edg/")[1].split(".")[0]
+            sec_ch_ua = (
+                f'"Microsoft Edge";v="{edge_version}", '
+                f'"Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
+            )
+        except Exception:
+            sec_ch_ua = (  # parsing failed: fall back to the Chrome-style client hint
+                f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", '
+                f'"Google Chrome";v="{chrome_version}"'
+            )
+    elif "Firefox/" in user_agent:
+        sec_ch_ua = None  # no client-hint headers are added for Firefox (see the check below)
+    else:
+        sec_ch_ua = (  # NOTE(review): this branch also handles Safari User-Agents — confirm that is intended
+            f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", '
+            f'"Google Chrome";v="{chrome_version}"'
+        )
+
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json, text/event-stream",
+        "Connection": "keep-alive",
+        "Cache-Control": "no-cache",
+        "User-Agent": user_agent,
+        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
+        "X-FE-Version": fe_version,
+        "Origin": "https://chat.z.ai",
+    }
+
+    if sec_ch_ua:
+        headers["sec-ch-ua"] = sec_ch_ua
+        headers["sec-ch-ua-mobile"] = "?0"
+        headers["sec-ch-ua-platform"] = '"Windows"'  # fixed value, independent of the chosen User-Agent
+
+    if chat_id:
+        headers["Referer"] = f"https://chat.z.ai/c/{chat_id}"
+    else:
+        headers["Referer"] = "https://chat.z.ai/"
+
+    return headers
+
+def _urlsafe_b64decode(data: str) -> bytes:
+    """Decode URL-safe base64 ``data``, restoring any stripped ``=`` padding.
+
+    Accepts text (encoded as UTF-8 first) or an already-encoded bytes value.
+    """
+    raw = data.encode("utf-8") if isinstance(data, str) else data
+    missing = -len(raw) % 4  # number of ``=`` needed to reach a multiple of 4
+    return base64.urlsafe_b64decode(raw + b"=" * missing)
+
+
+def _decode_jwt_payload(token: str) -> Dict[str, Any]:
+    """Decode a JWT payload without verifying it; return ``{}`` if unusable.
+
+    "Unusable" covers malformed tokens and payloads that are not a JSON object.
+    """
+    try:
+        payload = json.loads(_urlsafe_b64decode(token.split(".")[1]).decode("utf-8", errors="ignore"))
+    except Exception:  # best-effort metadata lookup: any failure means "no payload"
+        return {}
+    return payload if isinstance(payload, dict) else {}
+
+
+def _extract_user_id_from_token(token: str) -> str:
+    """Return the first usable ``id``/``user_id``/``uid``/``sub`` claim as text, else ``"guest"``."""
+    claims = {} if not token else _decode_jwt_payload(token)
+    # Claims are checked lazily, in priority order.
+    candidates = (claims.get(name) for name in ("id", "user_id", "uid", "sub"))
+    # Only str/int claims count; an empty string is treated as absent.
+    usable = (str(item) for item in candidates if isinstance(item, (str, int)) and str(item))
+    return next(usable, "guest")
+
+
+def _extract_text_from_content(content: Any) -> str:
+    """Flatten an OpenAI-style ``content`` value into plain text.
+
+    Strings pass through, ``None`` becomes ``""``, lists yield their ``text``
+    parts joined by spaces; anything else is JSON-encoded (or ``str``-ed).
+    """
+    if content is None or isinstance(content, str):
+        return content or ""
+    if isinstance(content, list):
+        texts = [
+            str(entry.get("text", ""))
+            for entry in content
+            if isinstance(entry, dict) and entry.get("type") == "text"
+        ]
+        return " ".join(filter(None, texts)).strip()
+    try:
+        return json.dumps(content, ensure_ascii=False)
+    except Exception:
+        return str(content)
+
+
+def _stringify_tool_arguments(arguments: Any) -> str:
+    """Return tool-call ``arguments`` as a JSON string (``"{}"`` if encoding fails)."""
+    # Strings are returned untouched; everything else is JSON-encoded.
+    if not isinstance(arguments, str):
+        try:
+            arguments = json.dumps(arguments if arguments else {}, ensure_ascii=False)
+        except Exception:
+            arguments = "{}"
+    return arguments
+
+
+def _build_tool_call_index(
+    messages: List[Dict[str, Any]],
+) -> Dict[str, Dict[str, str]]:
+    """Index assistant tool calls by id for later tool-result messages.
+
+    Returns ``{tool_call_id: {"name": ..., "arguments": ...}}`` for every
+    assistant tool call that has a string id and a non-empty function name.
+    Entries that are not dicts (messages or tool calls) are skipped, matching
+    the tolerance of ``_preprocess_openai_messages``; a repeated id keeps the
+    last occurrence.
+    """
+    index: Dict[str, Dict[str, str]] = {}
+    for message in messages:
+        # Guard first: ``.get`` on a non-dict entry would raise AttributeError.
+        if not isinstance(message, dict) or message.get("role") != "assistant":
+            continue
+
+        tool_calls = message.get("tool_calls")
+        if not isinstance(tool_calls, list):
+            continue
+
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, dict):
+                continue
+            tool_call_id = tool_call.get("id")
+            function_data = tool_call.get("function")
+            if not isinstance(function_data, dict):
+                function_data = {}
+            name = str(function_data.get("name", "")).strip()
+            if not isinstance(tool_call_id, str) or not name:
+                continue
+            index[tool_call_id] = {
+                "name": name,
+                "arguments": _stringify_tool_arguments(function_data.get("arguments")),
+            }
+    return index
+
+
+def _format_tool_result_message(
+    tool_name: str,
+    tool_arguments: str,
+    result_content: str,
+) -> str:
+    """Join tool name, arguments and result content, one per line, into a text block for the upstream."""
+    return (
+        "\n"  # NOTE(review): opening literal is only a newline — a wrapper tag may be missing here; confirm
+        f"{tool_name}\n"
+        f"{tool_arguments}\n"
+        f"{result_content}\n"
+        ""  # NOTE(review): empty closing literal — a closing tag may be missing here; confirm
+    )
+
+
+def _format_assistant_tool_calls(tool_calls: List[Dict[str, Any]]) -> str:
+    """Serialize historical assistant tool calls into a text block (``""`` when none are usable)."""
+    blocks: List[str] = []
+
+    for tool_call in tool_calls:
+        if not isinstance(tool_call, dict):
+            continue  # ignore malformed entries
+
+        function_data = (
+            tool_call.get("function")
+            if isinstance(tool_call.get("function"), dict)
+            else {}
+        )
+        name = str(function_data.get("name", "")).strip()
+        if not name:
+            continue  # a call without a function name is not rendered
+
+        arguments = _stringify_tool_arguments(function_data.get("arguments"))
+        blocks.append(
+            "\n"  # NOTE(review): opening literal is only a newline — a wrapper tag may be missing here; confirm
+            f"{name}\n"
+            f"{arguments}\n"
+            ""  # NOTE(review): empty closing literal — a closing tag may be missing here; confirm
+        )
+
+    if not blocks:
+        return ""
+
+    return "\n" + "\n".join(blocks) + "\n"  # blocks are newline-separated and wrapped in newlines
+
+
+def _preprocess_openai_messages(
+    messages: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+    """Rewrite OpenAI chat history into message shapes the upstream accepts.
+
+    * non-dict entries are dropped;
+    * ``developer`` messages become ``system`` messages;
+    * ``tool`` results become ``user`` messages rendered by
+      ``_format_tool_result_message``, using the name/arguments of the
+      assistant tool call with the same ``tool_call_id`` when one exists;
+    * assistant messages with a ``tool_calls`` list are flattened to text;
+    * everything else is shallow-copied unchanged.
+
+    The input list and its dicts are not modified.
+    """
+    calls_by_id = _build_tool_call_index(messages)
+    rewritten: List[Dict[str, Any]] = []
+
+    for entry in messages:
+        if not isinstance(entry, dict):
+            continue
+        role = entry.get("role")
+
+        if role == "tool":
+            call_id = entry.get("tool_call_id")
+            result_text = _extract_text_from_content(entry.get("content"))
+            # Fallback for results whose originating call is not in the history.
+            fallback = {
+                "name": str(entry.get("name") or "unknown_tool"),
+                "arguments": "{}",
+            }
+            call = calls_by_id.get(call_id, fallback)
+            user_text = _format_tool_result_message(
+                call["name"],
+                call["arguments"],
+                result_text,
+            )
+            converted = {"role": "user", "content": user_text}
+        elif role == "developer":
+            # Same message, only the role is renamed.
+            converted = {**entry, "role": "system"}
+        elif role == "assistant" and isinstance(entry.get("tool_calls"), list):
+            spoken = _extract_text_from_content(entry.get("content"))
+            calls_text = _format_assistant_tool_calls(entry["tool_calls"])
+            # Keep whichever parts are non-empty, separated by a newline.
+            merged = "\n".join(filter(None, (spoken, calls_text))).strip()
+            converted = {"role": "assistant", "content": merged}
+        else:
+            converted = dict(entry)
+
+        rewritten.append(converted)
+
+    return rewritten
+
+
+def _extract_last_user_text(messages: List[Dict[str, Any]]) -> str:
+    """Return the text of the latest user message that has any, else ``""`` (non-dict entries are skipped)."""
+    for message in reversed(messages):
+        if not isinstance(message, dict) or message.get("role") != "user":  # guard: non-dicts have no .get
+            continue
+        content = _extract_text_from_content(message.get("content"))
+        if content:
+            return content
+    return ""
+
+
+
+class UpstreamClient:
+ """当前服务使用的上游适配器。"""
+
+    def __init__(self):
+        """Set the endpoint URLs and the public-name to upstream-model-id mapping."""
+        self.name = "upstream"
+        self.logger = logger
+        self.api_endpoint = settings.API_ENDPOINT
+        self.base_url = DEFAULT_ZAI_BASE_URL
+        self.auth_url = f"{self.base_url}/api/v1/auths/"
+        # Several public model names share one upstream model id; insertion order is kept.
+        aliases_by_upstream_id = (
+            ("0727-360B-API", (settings.GLM45_MODEL, settings.GLM45_THINKING_MODEL, settings.GLM45_SEARCH_MODEL)),
+            ("0727-106B-API", (settings.GLM45_AIR_MODEL,)),
+            ("glm-4.6v", (settings.GLM46V_MODEL,)),
+            ("glm-5", (settings.GLM5_MODEL,)),
+            (
+                "glm-4.7",
+                (settings.GLM47_MODEL, settings.GLM47_THINKING_MODEL, settings.GLM47_SEARCH_MODEL, settings.GLM47_ADVANCED_SEARCH_MODEL),
+            ),
+        )
+        self.model_mapping = {}
+        for upstream_id, public_names in aliases_by_upstream_id:
+            for public_name in public_names:
+                self.model_mapping[public_name] = upstream_id
+
+    def _get_guest_retry_limit(self) -> int:
+        """Return how many attempts the anonymous guest pool can back.
+
+        0 when anonymous mode is off; otherwise at least 2, sized from the
+        pool's reported session count (or ``GUEST_POOL_SIZE`` without a pool).
+        """
+        if not settings.ANONYMOUS_MODE:
+            return 0
+        pool = get_guest_session_pool()
+        if pool:
+            status = pool.get_pool_status()
+            # Prefer ``valid_sessions``; fall back to ``available_sessions``, then 0.
+            count = int(status.get("valid_sessions") or status.get("available_sessions") or 0)
+        else:
+            count = settings.GUEST_POOL_SIZE
+        return max(2, count + 1)
+
+    def _get_authenticated_retry_limit(self) -> int:
+        """Return the retry budget backed by the authenticated token pool.
+
+        This is the pool's ``available_tokens`` count, or 0 without a pool.
+        """
+        pool = get_token_pool()
+        if not pool:
+            return 0
+        reported = pool.get_pool_status().get("available_tokens", 0) or 0
+        return max(0, int(reported))  # never negative
+
+    def _get_total_retry_limit(self) -> int:
+        """Return the overall attempt cap: authenticated plus guest budgets, at least 1."""
+        authenticated_budget = self._get_authenticated_retry_limit()
+        guest_budget = self._get_guest_retry_limit()
+        # Always allow one attempt even when both pools report nothing.
+        return max(1, authenticated_budget + guest_budget)
+
+    def _is_guest_auth(self, transformed: Dict[str, Any]) -> bool:
+        """Return True when ``transformed["auth_mode"]`` is ``"guest"``."""
+        return "guest" == str(transformed.get("auth_mode") or "")
+
+    def _should_retry_guest_session(
+        self,
+        status_code: int,
+        is_concurrency_limited: bool,
+        attempt: int,
+        max_attempts: int,
+        transformed: Dict[str, Any],
+    ) -> bool:
+        """Return whether a failed guest request should be retried with another guest session."""
+        # Retry only on HTTP 401 or a concurrency limit, and only while attempts remain.
+        retryable_failure = status_code == 401 or is_concurrency_limited
+        attempts_remain = attempt + 1 < max_attempts
+        uses_guest = self._is_guest_auth(transformed)
+        return uses_guest and retryable_failure and attempts_remain
+
+    def _should_retry_authenticated_session(
+        self,
+        status_code: int,
+        is_concurrency_limited: bool,
+        attempt: int,
+        max_attempts: int,
+        transformed: Dict[str, Any],
+    ) -> bool:
+        """Return whether a failed token-authenticated request should be retried with another token."""
+        # Needs a non-empty token, a 401 or concurrency limit, and remaining attempts.
+        has_token = bool(str(transformed.get("token") or ""))
+        retryable_failure = status_code == 401 or is_concurrency_limited
+        attempts_remain = attempt + 1 < max_attempts
+        if self._is_guest_auth(transformed):
+            return False
+        return has_token and retryable_failure and attempts_remain
+
+    async def _release_guest_session(self, transformed: Dict[str, Any]):
+        """Release the guest session used by ``transformed`` via the guest pool.
+
+        Does nothing for non-guest requests, or when no pool or user id is available.
+        """
+        if self._is_guest_auth(transformed):
+            pool = get_guest_session_pool()
+            # ``guest_user_id`` wins; ``user_id`` is the fallback key.
+            session_owner = str(transformed.get("guest_user_id") or transformed.get("user_id") or "")
+            if pool and session_owner:
+                pool.release(session_owner)
+
+    async def _report_guest_session_failure(
+        self,
+        transformed: Dict[str, Any],
+        *,
+        is_concurrency_limited: bool = False,
+    ):
+        """Report a failed guest session to the guest pool.
+
+        For concurrency-limited failures the pool's idle chats are cleaned up
+        first. Non-guest requests, a missing pool or a missing user id are no-ops.
+        """
+        if not self._is_guest_auth(transformed):
+            return
+
+        pool = get_guest_session_pool()
+        failed_owner = str(transformed.get("guest_user_id") or transformed.get("user_id") or "")
+        if pool and failed_owner:
+            # Clean up idle chats before reporting when the upstream limited concurrency.
+            if is_concurrency_limited:
+                await pool.cleanup_idle_chats()
+            await pool.report_failure(failed_owner)
+
+    async def _refresh_guest_request(
+        self,
+        request: OpenAIRequest,
+        attempt: int,
+        excluded_tokens: Set[str],
+        excluded_guest_user_ids: Set[str],
+        failed_transformed: Dict[str, Any],
+        is_concurrency_limited: bool = False,
+    ) -> Dict[str, Any]:
+        """Report the failed guest session, then rebuild the request for the next try.
+
+        The rebuilt request comes from ``transform_request`` with the given
+        exclusion sets passed through unchanged.
+        """
+        next_attempt_no = attempt + 2  # presumably ``attempt`` is 0-based, so this numbers the upcoming try
+        self.logger.warning(
+            f"🔄 匿名会话不可用,正在切换匿名会话并进行第 {next_attempt_no} 次请求"
+        )
+        await self._report_guest_session_failure(
+            failed_transformed, is_concurrency_limited=is_concurrency_limited
+        )
+        return await self.transform_request(
+            request, excluded_tokens=excluded_tokens, excluded_guest_user_ids=excluded_guest_user_ids
+        )
+
+    async def _refresh_authenticated_request(
+        self,
+        request: OpenAIRequest,
+        attempt: int,
+        excluded_tokens: Set[str],
+        excluded_guest_user_ids: Set[str],
+    ) -> Dict[str, Any]:
+        """Log the switch and rebuild the request via ``transform_request`` for the next try.
+
+        Both exclusion sets are passed through unchanged.
+        """
+        next_attempt_no = attempt + 2  # presumably ``attempt`` is 0-based, so this numbers the upcoming try
+        self.logger.warning(
+            f"🔄 检测到认证会话不可用,正在切换认证 Token/回退匿名池并进行第 {next_attempt_no} 次请求"
+        )
+        return await self.transform_request(
+            request, excluded_tokens=excluded_tokens, excluded_guest_user_ids=excluded_guest_user_ids
+        )
+
+    def _extract_upstream_error_details(
+        self,
+        status_code: int,
+        error_text: str,
+    ) -> Tuple[Optional[int], str]:
+        """Extract ``(code, message)`` from an upstream error body.
+
+        The body is parsed as JSON and searched at the top level and then in its
+        ``error``, ``detail`` and ``data`` objects. The code and the message are
+        each taken from the first object that provides one. Without a usable
+        message the stripped raw text is returned; ``status_code`` is not used.
+        """
+        raw_text = (error_text or "").strip()
+        try:
+            payload = json.loads(error_text)
+        except Exception:
+            return None, raw_text
+        if not isinstance(payload, dict):
+            return None, raw_text
+
+        parsed_code: Optional[int] = None
+        parsed_message = ""
+        # Search order: the body itself, then its error/detail/data objects.
+        nested = (payload.get(key) for key in ("error", "detail", "data"))
+        for candidate in (payload, *nested):
+            if not isinstance(candidate, dict):
+                continue
+            code = candidate.get("code")
+            if parsed_code is None:
+                if isinstance(code, int):
+                    parsed_code = code
+                elif isinstance(code, str) and code.isdigit():
+                    parsed_code = int(code)
+            if not parsed_message:
+                for key in ("message", "msg", "detail", "error"):
+                    value = candidate.get(key)
+                    if isinstance(value, str) and value.strip():
+                        parsed_message = value.strip()
+                        break
+            # The raw text is only a fallback, so it must not end the search early.
+            if parsed_code is not None and parsed_message:
+                break
+
+        return parsed_code, parsed_message or raw_text
+
+    def _is_concurrency_limited(
+        self,
+        status_code: int,
+        error_code: Optional[int],
+        error_message: str,
+    ) -> bool:
+        """Return True when an upstream failure looks like rate/concurrency limiting.
+
+        True for HTTP status or body code 429, or a message mentioning ``concurrency``,
+        ``too many requests`` (case-insensitive) or ``并发``. A ``None`` message is accepted.
+        """
+        message = (error_message or "").casefold()
+        if status_code == 429 or error_code == 429:
+            return True
+        return any(marker in message for marker in ("concurrency", "too many requests", "并发"))
+
+    def get_supported_models(self) -> List[str]:
+        """Return the public model names this service advertises.
+
+        Names are read from ``settings`` at call time, in a fixed order.
+        """
+        # One line per model family; the order is the listing order.
+        setting_names = (
+            "GLM45_MODEL", "GLM45_THINKING_MODEL", "GLM45_SEARCH_MODEL", "GLM45_AIR_MODEL",
+            "GLM46V_MODEL",
+            "GLM5_MODEL",
+            "GLM47_MODEL", "GLM47_THINKING_MODEL", "GLM47_SEARCH_MODEL", "GLM47_ADVANCED_SEARCH_MODEL",
+        )
+        return [getattr(settings, name) for name in setting_names]
+
+
+    def _requires_persisted_chat(self, upstream_model_id: str) -> bool:
+        """Return whether ``upstream_model_id`` must be attached to a real upstream chat."""
+        profile = self._get_model_request_profile(upstream_model_id)
+        # The per-model profile is the single source of truth for this flag.
+        return bool(profile["use_persisted_chat"])
+
+    def _get_model_request_profile(self, upstream_model_id: str) -> Dict[str, Any]:
+        """Return the request options specific to ``upstream_model_id``.
+
+        The profile has the keys ``use_persisted_chat``, ``preview_mode``,
+        ``mcp_servers``, ``feature_entries`` and ``default_enable_thinking``.
+        ``glm-4.6v`` and ``glm-5`` override parts of the generic profile; list
+        values are fresh objects on every call.
+        """
+        profile: Dict[str, Any] = {
+            "use_persisted_chat": upstream_model_id == "glm-4.7",
+            "preview_mode": True,
+            "mcp_servers": [],
+            "feature_entries": [],
+            "default_enable_thinking": None,
+        }
+
+        if upstream_model_id == "glm-4.6v":
+            profile.update(
+                use_persisted_chat=True,
+                preview_mode=False,
+                mcp_servers=list(GLM46V_MCP_SERVERS),
+                feature_entries=[dict(item) for item in GLM46V_SELECTED_FEATURES],
+                default_enable_thinking=True,
+            )
+        elif upstream_model_id == "glm-5":
+            profile["default_enable_thinking"] = True
+        return profile
+
+    def _build_request_variables(self) -> Dict[str, str]:
+        """Return the ``{{PLACEHOLDER}}`` to value map built for upstream requests."""
+        stamp = datetime.now().strftime  # naive local time of the server, captured once
+        return {
+            "{{USER_NAME}}": "Guest",
+            "{{USER_LOCATION}}": "Unknown",
+            "{{CURRENT_DATETIME}}": stamp("%Y-%m-%d %H:%M:%S"),
+            "{{CURRENT_DATE}}": stamp("%Y-%m-%d"),
+            "{{CURRENT_TIME}}": stamp("%H:%M:%S"),
+            "{{CURRENT_WEEKDAY}}": stamp("%A"),
+            "{{CURRENT_TIMEZONE}}": DEFAULT_TIMEZONE,  # NOTE(review): fixed zone vs. local clock above — confirm
+            "{{USER_LANGUAGE}}": DEFAULT_LANGUAGE,
+        }
+
+    def _build_browser_query_params(
+        self,
+        *,
+        chat_id: str,
+        token: str,
+        user_id: str,  # accepted but not read anywhere in this method
+        user_agent: str,
+        timestamp_ms: int,
+    ) -> Dict[str, str]:
+        """Build the browser-fingerprint query parameters (all string values) for a signed request."""
+        now = datetime.now(timezone.utc)
+        browser_name = "Chrome"  # default when no other marker matches the User-Agent
+        if "Edg/" in user_agent:
+            browser_name = "Microsoft Edge"
+        elif "Firefox/" in user_agent:
+            browser_name = "Firefox"
+        elif "Safari/" in user_agent and "Chrome/" not in user_agent:  # only a UA without a Chrome/ token counts as Safari
+            browser_name = "Safari"
+
+        return {
+            "version": DEFAULT_CLIENT_VERSION,
+            "platform": DEFAULT_PLATFORM,
+            "token": token,
+            "user_agent": user_agent,
+            "language": DEFAULT_LANGUAGE,
+            "languages": DEFAULT_LANGUAGE,
+            "timezone": DEFAULT_TIMEZONE,
+            "cookie_enabled": "true",
+            "screen_width": DEFAULT_SCREEN_WIDTH,
+            "screen_height": DEFAULT_SCREEN_HEIGHT,
+            "screen_resolution": DEFAULT_SCREEN_RESOLUTION,
+            "viewport_height": DEFAULT_VIEWPORT_HEIGHT,
+            "viewport_width": DEFAULT_VIEWPORT_WIDTH,
+            "viewport_size": DEFAULT_VIEWPORT_SIZE,
+            "color_depth": DEFAULT_COLOR_DEPTH,
+            "pixel_ratio": DEFAULT_PIXEL_RATIO,
+            "current_url": f"{self.base_url}/c/{chat_id}",
+            "pathname": f"/c/{chat_id}",
+            "search": "",
+            "hash": "",
+            "host": "chat.z.ai",
+            "hostname": "chat.z.ai",
+            "protocol": "https:",
+            "referrer": "",
+            "title": DEFAULT_PAGE_TITLE,
+            "timezone_offset": DEFAULT_TIMEZONE_OFFSET,
+            "local_time": (  # despite the key name this is the UTC time: ISO-8601, millisecond precision, "Z" suffix
+                now.strftime("%Y-%m-%dT%H:%M:%S.")
+                + f"{now.microsecond // 1000:03d}Z"
+            ),
+            "utc_time": now.strftime("%a, %d %b %Y %H:%M:%S GMT"),
+            "is_mobile": "false",
+            "is_touch": "false",
+            "max_touch_points": DEFAULT_MAX_TOUCH_POINTS,
+            "browser_name": browser_name,
+            "os_name": "Windows",  # fixed value, independent of the User-Agent
+            "signature_timestamp": str(timestamp_ms),
+        }
+
def _build_signed_completion_request(
    self,
    *,
    prompt: str,
    chat_id: str,
    token: str,
    user_id: str,
    user_agent: str,
    use_browser_fingerprint: bool,
) -> Tuple[str, str, str]:
    """Build the signed completions URL plus the signature metadata.

    Args:
        prompt: Text that is signed together with the canonical payload
            (an empty string is used when falsy).
        chat_id: Chat identifier placed in the URL parameters.
        token: Upstream token placed in the URL parameters.
        user_id: Upstream user id; part of the signed payload.
        user_agent: User-Agent used for the browser fingerprint variant.
        use_browser_fingerprint: When true, the full browser fingerprint
            parameters are appended; otherwise a short fixed set is used.

    Returns:
        ``(signed_url, signature, timestamp_ms_as_string)``.
    """
    signed_at_ms = int(time.time() * 1000)
    signed_at = str(signed_at_ms)
    signing_fields = {
        "requestId": generate_uuid(),
        "timestamp": signed_at,
        "user_id": user_id,
    }

    # Canonical payload: "key,value" pairs joined by commas, keys sorted.
    canonical = ",".join(
        f"{name},{signing_fields[name]}" for name in sorted(signing_fields)
    )
    signed = generate_signature(e=canonical, t=prompt or "", s=signed_at_ms)
    signature = signed["signature"]

    if use_browser_fingerprint:
        extra_params = self._build_browser_query_params(
            chat_id=chat_id,
            token=token,
            user_id=user_id,
            user_agent=user_agent,
            timestamp_ms=signed_at_ms,
        )
    else:
        chat_path = f"/c/{chat_id}"
        extra_params = {
            "token": token,
            "version": DEFAULT_CLIENT_VERSION,
            "platform": DEFAULT_PLATFORM,
            "current_url": f"{self.base_url}{chat_path}",
            "pathname": chat_path,
            "signature_timestamp": signed_at,
        }

    # Signed fields come first; the extra parameters are layered on top.
    query = {**signing_fields, **extra_params}
    signed_url = f"{self.api_endpoint}?{urlencode(query)}"
    return signed_url, signature, signed_at
+
async def _create_upstream_chat(
    self,
    *,
    prompt: str,
    model: str,
    token: str,
    headers: Dict[str, str],
    enable_thinking: bool,
    web_search: bool,
    user_message_id: Optional[str] = None,
    files: Optional[List[Dict[str, Any]]] = None,
    feature_entries: Optional[List[Dict[str, Any]]] = None,
    mcp_servers: Optional[List[str]] = None,
) -> str:
    """Create a persisted chat on the upstream service and return its id.

    Posts to ``/api/v1/chats/new`` with a single user message whose content
    is the prompt truncated to ``CHAT_BOOTSTRAP_MAX_CONTENT_LEN`` characters
    (``"..."`` is appended when truncation happened).

    Args:
        prompt: Text of the user message that seeds the chat.
        model: Upstream model id recorded on the chat and the message.
        token: Bearer token used for the request.
        headers: Source of ``User-Agent`` (required) and ``Accept-Language``.
        enable_thinking: Stored as the chat's ``enable_thinking`` flag.
        web_search: Stored as the chat's ``auto_web_search`` flag.
        user_message_id: Id for the seeded message; generated when omitted.
        files: Optional file records attached to the seeded message.
        feature_entries: Optional chat features; a hidden ``tool_selector``
            entry is used when none are given.
        mcp_servers: Optional MCP server names stored on the chat.

    Returns:
        The chat id reported by the upstream service.

    Raises:
        RuntimeError: If the upstream responds with a non-200 status, the
            body is not valid JSON, or no chat id can be found in it.
    """
    init_content = prompt[:CHAT_BOOTSTRAP_MAX_CONTENT_LEN]
    if len(prompt) > CHAT_BOOTSTRAP_MAX_CONTENT_LEN:
        init_content = init_content + "..."

    message_id = user_message_id or generate_uuid()
    timestamp_seconds = int(time.time())
    if feature_entries:
        chat_features = [dict(item) for item in feature_entries]
    else:
        chat_features = [
            {
                "type": "tool_selector",
                "server": "tool_selector_h",
                "status": "hidden",
            }
        ]

    # Key order is kept stable: "files" (when present) sits between
    # "content" and "timestamp".
    user_message: Dict[str, Any] = {
        "id": message_id,
        "parentId": None,
        "childrenIds": [],
        "role": "user",
        "content": init_content,
    }
    if files:
        user_message["files"] = [dict(item) for item in files]
    user_message["timestamp"] = timestamp_seconds
    user_message["models"] = [model]

    body = {
        "chat": {
            "id": "",
            "title": "新聊天",
            "models": [model],
            "params": {},
            "history": {
                "messages": {message_id: user_message},
                "currentId": message_id,
            },
            "tags": [],
            "flags": [],
            "features": chat_features,
            "mcp_servers": list(mcp_servers or []),
            "enable_thinking": enable_thinking,
            "auto_web_search": web_search,
            "message_version": 1,
            "extra": {},
            "timestamp": int(time.time() * 1000),
        }
    }
    request_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {token}",
        "User-Agent": headers["User-Agent"],
        "Accept-Language": headers.get("Accept-Language", DEFAULT_LANGUAGE),
        "Origin": self.base_url,
        "Referer": f"{self.base_url}/",
    }
    async with httpx.AsyncClient(
        base_url=self.base_url,
        timeout=self._build_timeout(),
        limits=self._build_limits(),
        proxy=self._get_proxy_config(),
        follow_redirects=True,
    ) as client:
        response = await client.post(
            "/api/v1/chats/new",
            headers=request_headers,
            json=body,
        )

    if response.status_code != 200:
        raise RuntimeError(
            f"上游创建 chat 失败: {response.status_code} {response.text}"
        )

    try:
        payload = response.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError; report it as an upstream
        # failure instead of leaking a bare parsing error.
        raise RuntimeError(f"上游创建 chat 响应解析失败: {exc}") from exc

    # The id may live at the top level or under "chat"; tolerate a null or
    # non-object "chat" and a non-object payload instead of crashing.
    chat_id = ""
    if isinstance(payload, dict):
        nested_chat = payload.get("chat")
        nested_id = nested_chat.get("id") if isinstance(nested_chat, dict) else None
        chat_id = str(payload.get("id") or nested_id or "")
    if not chat_id:
        raise RuntimeError("上游创建 chat 成功但未返回 chat_id")
    return chat_id
+
+ def _build_glm47_completion_body(
+ self,
+ *,
+ model: str,
+ messages: List[Dict[str, Any]],
+ prompt: str,
+ chat_id: str,
+ enable_thinking: bool,
+ web_search: bool,
+ files: List[Dict[str, Any]],
+ tools: Optional[List[Dict[str, Any]]],
+ tool_choice: Any,
+ temperature: Optional[float],
+ max_tokens: Optional[int],
+ mcp_servers: List[str],
+ preview_mode: bool,
+ feature_entries: Optional[List[Dict[str, Any]]],
+ message_id: str,
+ current_user_message_id: str,
+ current_user_message_parent_id: Optional[str],
+ ) -> Dict[str, Any]:
+ """构建兼容持久化 chat 模型的精简 completions 请求体。"""
+ params: Dict[str, Any] = {}
+ if temperature is not None:
+ params["temperature"] = temperature
+ if max_tokens is not None:
+ params["max_tokens"] = max_tokens
+
+ body: Dict[str, Any] = {
+ "stream": True,
+ "model": model,
+ "messages": messages,
+ "signature_prompt": prompt,
+ "params": params,
+ "extra": {},
+ "features": {
+ "image_generation": False,
+ "web_search": web_search,
+ "auto_web_search": web_search,
+ "preview_mode": preview_mode,
+ "flags": [],
+ "enable_thinking": enable_thinking,
+ },
+ "variables": self._build_request_variables(),
+ "chat_id": chat_id,
+ "id": message_id,
+ "current_user_message_id": current_user_message_id,
+ "current_user_message_parent_id": current_user_message_parent_id,
+ "background_tasks": {
+ "title_generation": True,
+ "tags_generation": True,
+ },
+ }
+ if files:
+ body["files"] = files
+ if mcp_servers:
+ body["mcp_servers"] = mcp_servers
+ if tools:
+ body["tools"] = tools
+ if tool_choice is not None:
+ body["tool_choice"] = tool_choice
+ return body
+
+ def _clean_reasoning_delta(self, delta_content: str) -> str:
+ """清理思考阶段的 details 包裹内容。"""
+ if not delta_content:
+ return ""
+
+ if delta_content.startswith("\n>" in delta_content:
+ return delta_content.split("\n>")[-1].strip()
+ if "\n" in delta_content:
+ return delta_content.split("\n")[-1].lstrip("> ").strip()
+
+ return delta_content
+
+ def _extract_answer_content(self, text: str) -> str:
+ """提取思考结束后的答案正文。"""
+ if not text:
+ return ""
+
+ if " \n" in text:
+ return text.split("\n")[-1]
+
+ if "" in text:
+ return text.split("")[-1].lstrip()
+
+ return text
+
+ def _normalize_tool_calls(
+ self,
+ raw_tool_calls: Any,
+ start_index: int = 0,
+ ) -> List[Dict[str, Any]]:
+ """标准化上游工具调用为 OpenAI 兼容格式。"""
+ if not raw_tool_calls:
+ return []
+
+ tool_calls = raw_tool_calls if isinstance(raw_tool_calls, list) else [raw_tool_calls]
+ normalized: List[Dict[str, Any]] = []
+
+ for offset, tool_call in enumerate(tool_calls):
+ if not isinstance(tool_call, dict):
+ continue
+
+ function_data = tool_call.get("function") or {}
+ normalized.append(
+ {
+ "index": tool_call.get("index", start_index + offset),
+ "id": tool_call.get("id") or f"call_{uuid.uuid4().hex[:24]}",
+ "type": "function",
+ "function": {
+ "name": function_data.get("name", ""),
+ "arguments": function_data.get("arguments", ""),
+ },
+ }
+ )
+
+ return normalized
+
+ def _format_search_results(self, data: Dict[str, Any]) -> str:
+ """将上游搜索结果格式化为可追加的 Markdown 引用。"""
+ search_info = data.get("results") or data.get("sources") or data.get("citations")
+ if not isinstance(search_info, list) or not search_info:
+ return ""
+
+ citations = []
+ for index, item in enumerate(search_info, 1):
+ if not isinstance(item, dict):
+ continue
+
+ title = item.get("title") or item.get("name") or f"Result {index}"
+ url = item.get("url") or item.get("link")
+ if url:
+ citations.append(f"[{index}] [{title}]({url})")
+
+ if not citations:
+ return ""
+
+ return "\n\n---\n" + "\n".join(citations)
+
def _get_proxy_config(self) -> Optional[str]:
    """Return the proxy URL to use for upstream requests, if any.

    The settings are consulted in the order HTTPS_PROXY, HTTP_PROXY,
    SOCKS5_PROXY; the first non-empty value is logged and returned as a
    single URL string.

    Returns:
        The selected proxy URL, or ``None`` when no proxy is configured.
    """
    for scheme in ("HTTPS", "HTTP", "SOCKS5"):
        proxy_url = getattr(settings, f"{scheme}_PROXY")
        if proxy_url:
            self.logger.info(f"🔄 使用{scheme}代理: {proxy_url}")
            return proxy_url
    return None
+
def _build_timeout(self, read_timeout: float = 30.0) -> httpx.Timeout:
    """Build the httpx timeout configuration for upstream requests.

    Args:
        read_timeout: Read timeout in seconds; the connect (5s), write
            (10s) and pool (5s) timeouts are fixed.

    Returns:
        The configured ``httpx.Timeout``.
    """
    phases = {
        "connect": 5.0,
        "read": read_timeout,
        "write": 10.0,
        "pool": 5.0,
    }
    return httpx.Timeout(**phases)
+
def _build_limits(self) -> httpx.Limits:
    """Build the connection-pool limits for upstream requests.

    Returns:
        ``httpx.Limits`` allowing at most 10 connections, 5 of which may be
        kept alive.
    """
    pool_limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
    return pool_limits
+
async def _fetch_direct_guest_auth(self) -> Dict[str, Any]:
    """Fetch a guest token straight from the auth endpoint as a fallback.

    Up to three attempts are made against ``self.auth_url`` with a two
    second pause between them. A 405 response aborts the retries
    immediately; other failures are logged and retried.

    Returns:
        An auth-info dict (``token``, ``user_id``, ``username``,
        ``auth_mode``, ``token_source``, ``guest_user_id``). When no token
        could be obtained, ``token`` is ``""`` and ``guest_user_id`` is
        ``None``.
    """
    total_attempts = 3
    # Checked in order; anything else falls back to the generic label.
    failure_labels = (
        (httpx.TimeoutException, "直连获取匿名令牌超时"),
        (httpx.ConnectError, "直连获取匿名令牌连接错误"),
        (json.JSONDecodeError, "直连获取匿名令牌 JSON 解析错误"),
    )

    for attempt_no in range(1, total_attempts + 1):
        try:
            request_headers = get_dynamic_headers()
            self.logger.debug(
                f"尝试获取访客令牌 (第{attempt_no}次): {self.auth_url}"
            )

            proxy_url = self._get_proxy_config()
            async with httpx.AsyncClient(
                timeout=self._build_timeout(),
                follow_redirects=True,
                limits=self._build_limits(),
                proxy=proxy_url,
            ) as client:
                response = await client.get(self.auth_url, headers=request_headers)

            status = response.status_code
            if status == 200:
                data = response.json()
                token = str(data.get("token") or "").strip()
                if not token:
                    self.logger.warning(f"响应中未找到 token 字段: {data}")
                else:
                    user_id = str(
                        data.get("id")
                        or data.get("user_id")
                        or _extract_user_id_from_token(token)
                    )
                    email_local_part = str(data.get("email") or "").split("@")[0]
                    username = str(data.get("name") or email_local_part or "Guest")
                    self.logger.info(
                        f"✅ 直连获取匿名令牌成功: {token[:20]}..."
                    )
                    return {
                        "token": token,
                        "user_id": user_id,
                        "username": username or "Guest",
                        "auth_mode": "guest",
                        "token_source": "guest_direct",
                        "guest_user_id": user_id,
                    }
            elif status == 405:
                # Retrying does not help once the request is blocked.
                self.logger.error(
                    "🚫 请求被 WAF 拦截 (405),无法直连获取匿名令牌"
                )
                break
            else:
                self.logger.warning(
                    f"直连获取匿名令牌失败,状态码: {status}"
                )
        except Exception as exc:
            label = next(
                (
                    text
                    for exc_type, text in failure_labels
                    if isinstance(exc, exc_type)
                ),
                "直连获取匿名令牌失败",
            )
            self.logger.warning(f"{label} (第{attempt_no}次): {exc}")

        if attempt_no < total_attempts:
            await asyncio.sleep(2)

    return {
        "token": "",
        "user_id": "guest",
        "username": "Guest",
        "auth_mode": "guest",
        "token_source": "guest_direct",
        "guest_user_id": None,
    }
+
async def get_auth_info(
    self,
    excluded_tokens: Optional[Set[str]] = None,
    excluded_guest_user_ids: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Pick upstream credentials: token pool first, then guest fallbacks.

    Args:
        excluded_tokens: Tokens the token pool must not hand out.
        excluded_guest_user_ids: Guest user ids the guest session pool must
            not hand out.

    Returns:
        An auth-info dict with ``token``, ``user_id``, ``username``,
        ``auth_mode``, ``token_source`` and ``guest_user_id``. When no
        credentials are available and anonymous mode is off, ``token`` is
        ``""`` and ``token_source`` is ``"none"``.
    """
    auth_pool = get_token_pool()
    pooled_token = (
        auth_pool.get_next_token(exclude_tokens=excluded_tokens)
        if auth_pool
        else None
    )
    if pooled_token:
        owner_id = _extract_user_id_from_token(pooled_token)
        self.logger.debug(f"从认证号池获取令牌: {pooled_token[:20]}...")
        return {
            "token": pooled_token,
            "user_id": owner_id,
            "username": "User",
            "auth_mode": "authenticated",
            "token_source": "auth_pool",
            "guest_user_id": None,
        }

    if not settings.ANONYMOUS_MODE:
        self.logger.error("❌ 无法获取有效的上游令牌")
        return {
            "token": "",
            "user_id": "",
            "username": "",
            "auth_mode": "authenticated",
            "token_source": "none",
            "guest_user_id": None,
        }

    guest_pool = get_guest_session_pool()
    if guest_pool:
        try:
            guest = await guest_pool.acquire(
                exclude_user_ids=excluded_guest_user_ids
            )
            self.logger.info(
                "🫥 认证池不可用,回退匿名会话池: "
                f"user_id={guest.user_id}"
            )
            return {
                "token": guest.token,
                "user_id": guest.user_id,
                "username": guest.username,
                "auth_mode": "guest",
                "token_source": "guest_pool",
                "guest_user_id": guest.user_id,
            }
        except Exception as exc:
            self.logger.warning(f"匿名会话池获取失败,转为直连访客鉴权: {exc}")

    # Last resort: no pool (or the pool failed), ask the auth endpoint.
    return await self._fetch_direct_guest_auth()
+
async def mark_token_failure(
    self,
    token: str,
    error: Optional[Exception] = None,
) -> None:
    """Record a failed use of ``token`` in the token pool.

    Args:
        token: The upstream token that failed.
        error: The error that caused the failure, if known.

    Does nothing when no token pool is available.
    """
    token_pool = get_token_pool()
    if token_pool:
        await token_pool.record_token_failure(token, error)
+
async def upload_image(
    self,
    data_url: str,
    chat_id: str,
    token: str,
    user_id: str,
    auth_mode: str = "authenticated",
) -> Optional[Dict]:
    """Upload a base64 ``data:`` URL image to the upstream file endpoint.

    Args:
        data_url: Image data in ``data:image/xxx;base64,...`` form.
        chat_id: Current chat id; only used to build the ``Referer`` header.
        token: Bearer token used for the upload.
        user_id: User id recorded in the returned file record.
        auth_mode: Current auth mode; uploads are refused in ``"guest"``
            mode.

    Returns:
        The upstream-style file record on success. ``None`` when the upload
        is refused (guest mode, not a ``data:`` URL), when the upstream
        answers with a non-200 status or without a file id, or when any
        error occurs (errors are logged, never raised).
    """
    if auth_mode == "guest" or not data_url.startswith("data:"):
        return None

    try:
        # Split "data:<mime>;base64" from the payload.
        header, separator, encoded = data_url.partition(",")
        if not separator:
            raise ValueError("malformed data URL: missing ',' separator")

        # An empty media type (e.g. "data:;base64,...") falls back to JPEG.
        media_type = header[len("data:"):].split(";", 1)[0].strip()
        mime_type = media_type or "image/jpeg"

        image_data = base64.b64decode(encoded)
        filename = str(uuid.uuid4())

        self.logger.debug(f"📤 上传图片: {filename}, 大小: {len(image_data)} bytes")

        upload_url = f"{self.base_url}/api/v1/files/"
        headers = {
            "Accept": "*/*",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Origin": f"{self.base_url}",
            "Pragma": "no-cache",
            "Referer": (
                f"{self.base_url}/c/{chat_id}" if chat_id else f"{self.base_url}/"
            ),
            "Sec-Ch-Ua": '"Microsoft Edge";v="141", "Not?A_Brand";v="8", "Chromium";v="141"',
            "Sec-Ch-Ua-Mobile": "?0",
            "Sec-Ch-Ua-Platform": '"Windows"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36 Edg/141.0.0.0",
            "Authorization": f"Bearer {token}",
        }

        proxies = self._get_proxy_config()

        async with httpx.AsyncClient(
            timeout=self._build_timeout(),
            limits=self._build_limits(),
            proxy=proxies,
        ) as client:
            files = {"file": (filename, image_data, mime_type)}
            response = await client.post(upload_url, files=files, headers=headers)

        if response.status_code != 200:
            self.logger.error(f"❌ 图片上传失败: {response.status_code} - {response.text}")
            return None

        result = response.json()
        file_id = result.get("id")
        if not file_id:
            # Without an id the caller cannot reference the file; treat the
            # upload as failed instead of returning a record with id None.
            self.logger.error(f"❌ 图片上传失败: {response.status_code} - {response.text}")
            return None

        file_name = result.get("filename") or filename
        file_size = len(image_data)

        self.logger.info(f"✅ 图片上传成功: {file_id}_{file_name}")

        # File record shaped the way the upstream expects it.
        current_timestamp = int(time.time())
        return {
            "type": "image",
            "file": {
                "id": file_id,
                "user_id": user_id,
                "hash": None,
                "filename": file_name,
                "data": {},
                "meta": {
                    "name": file_name,
                    "content_type": mime_type,
                    "size": file_size,
                    "data": {},
                },
                "created_at": current_timestamp,
                "updated_at": current_timestamp,
            },
            "id": file_id,
            "url": f"/api/v1/files/{file_id}/content",
            "name": file_name,
            "status": "uploaded",
            "size": file_size,
            "error": "",
            "itemId": str(uuid.uuid4()),
            "media": "image",
        }

    except Exception as e:
        # Best effort by design: the caller degrades to a text notice.
        self.logger.error(f"❌ 图片上传异常: {e}")
        return None
+
+ async def transform_request(
+ self,
+ request: OpenAIRequest,
+ excluded_tokens: Optional[Set[str]] = None,
+ excluded_guest_user_ids: Optional[Set[str]] = None,
+ ) -> Dict[str, Any]:
+ """Convert an OpenAI-style request into an upstream request description.
+
+ Obtains credentials through get_auth_info, normalizes the messages
+ (uploading base64 images where allowed), builds the completion body and
+ the signed URL, and assembles the request headers.
+
+ Args:
+ request: The incoming OpenAI-compatible request.
+ excluded_tokens: Tokens that must not be selected from the token pool.
+ excluded_guest_user_ids: Guest user ids that must not be selected.
+
+ Returns:
+ A dict with the keys url, headers, body, token, chat_id, model,
+ user_id, auth_mode, token_source and guest_user_id.
+
+ Raises:
+ RuntimeError: If no upstream token could be obtained.
+ """
+ self.logger.info(f"🔄 转换 OpenAI 请求到上游格式: {request.model}")
+
+ raw_messages = [
+ message.model_dump(exclude_none=True)
+ for message in request.messages
+ ]
+ normalized_messages = _preprocess_openai_messages(raw_messages)
+
+ # Resolve the credentials used for this request.
+ auth_info = await self.get_auth_info(
+ excluded_tokens=excluded_tokens,
+ excluded_guest_user_ids=excluded_guest_user_ids,
+ )
+ token = str(auth_info.get("token") or "")
+ if not token:
+ raise RuntimeError("无法获取上游认证令牌")
+
+ user_id = str(auth_info.get("user_id") or _extract_user_id_from_token(token))
+ auth_mode = str(auth_info.get("auth_mode") or "authenticated")
+ token_source = str(auth_info.get("token_source") or "unknown")
+ guest_user_id = auth_info.get("guest_user_id")
+ # Determine the feature flags implied by the requested model
+ last_user_text = _extract_last_user_text(raw_messages)
+ requested_model = request.model
+ is_thinking_model = "-thinking" in requested_model.casefold()
+ is_search_model = "-search" in requested_model.casefold()
+ is_advanced_search = requested_model == settings.GLM47_ADVANCED_SEARCH_MODEL
+ # Unknown model names fall back to the "0727-360B-API" upstream id.
+ upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
+ tools = request.tools if settings.TOOL_SUPPORT and request.tools else None
+ tool_choice = getattr(request, "tool_choice", None)
+ model_profile = self._get_model_request_profile(upstream_model_id)
+ # An explicit request flag wins; otherwise the profile default, then the
+ # "-thinking" suffix of the model name.
+ enable_thinking = request.enable_thinking
+ if enable_thinking is None:
+ default_enable_thinking = model_profile["default_enable_thinking"]
+ enable_thinking = (
+ default_enable_thinking
+ if default_enable_thinking is not None
+ else is_thinking_model
+ )
+
+ web_search = request.web_search
+ if web_search is None:
+ web_search = is_search_model or is_advanced_search
+
+ use_persisted_chat = bool(model_profile["use_persisted_chat"])
+ preview_mode = bool(model_profile["preview_mode"])
+ feature_entries = list(model_profile["feature_entries"])
+ # Message ids are only pre-generated for the persisted-chat flow.
+ persisted_user_message_id = generate_uuid() if use_persisted_chat else None
+ persisted_assistant_message_id = generate_uuid() if use_persisted_chat else None
+
+ mcp_servers = list(model_profile["mcp_servers"])
+ if is_advanced_search and "advanced-search" not in mcp_servers:
+ mcp_servers.append("advanced-search")
+ self.logger.info("🔍 检测到高级搜索模型,添加 advanced-search MCP 服务器")
+
+ headers = get_dynamic_headers(
+ browser_type="chrome" if use_persisted_chat else None,
+ )
+ chat_id = generate_uuid()
+
+ # Normalize message content - the upstream takes images through a separate files field
+ messages = []
+ files = []
+ # In the persisted-chat flow the real chat id does not exist yet, so
+ # uploads are made with an empty chat id.
+ upload_chat_id = "" if use_persisted_chat else chat_id
+
+ for msg in normalized_messages:
+ role = str(msg.get("role", "user"))
+ content = msg.get("content")
+
+ if isinstance(content, str):
+ messages.append({"role": role, "content": content})
+ continue
+
+ if not isinstance(content, list):
+ continue
+
+ text_parts = []
+ image_parts = []
+ # Parts may be objects with attributes, plain dicts, or bare strings.
+ for part in content:
+ image_url = None
+ if hasattr(part, "type"):
+ if part.type == "text" and hasattr(part, "text"):
+ text_parts.append(part.text or "")
+ elif part.type == "image_url" and hasattr(part, "image_url"):
+ if hasattr(part.image_url, "url"):
+ image_url = part.image_url.url
+ elif (
+ isinstance(part.image_url, dict)
+ and "url" in part.image_url
+ ):
+ image_url = part.image_url["url"]
+ elif isinstance(part, dict):
+ if part.get("type") == "text":
+ text_parts.append(part.get("text", ""))
+ elif part.get("type") == "image_url":
+ image_url = part.get("image_url", {}).get("url", "")
+ elif isinstance(part, str):
+ text_parts.append(part)
+
+ if not image_url:
+ continue
+
+ self.logger.debug(f"✅ 检测到图片: {image_url[:50]}...")
+ # Base64 images are uploaded (never in guest mode) and then
+ # referenced by the returned file id.
+ if image_url.startswith("data:") and auth_mode != "guest":
+ self.logger.info("🔄 上传 base64 图片到上游服务")
+ file_info = await self.upload_image(
+ image_url,
+ upload_chat_id,
+ token,
+ user_id,
+ auth_mode=auth_mode,
+ )
+ if not file_info:
+ self.logger.warning("⚠️ 图片上传失败")
+ text_parts.append("[系统提示: 图片上传失败]")
+ continue
+
+ files.append(file_info)
+ self.logger.info("✅ 图片已添加到 files 数组")
+ if persisted_user_message_id:
+ file_info["ref_user_msg_id"] = persisted_user_message_id
+ image_ref = str(file_info["id"])
+ image_parts.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_ref},
+ }
+ )
+ self.logger.debug(f"📎 图片引用: {image_ref}")
+ continue
+
+ # Images that were not uploaded keep their original URL.
+ if auth_mode != "guest":
+ self.logger.warning("⚠️ 非 base64 图片或匿名模式,保留原始URL")
+ image_parts.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_url},
+ }
+ )
+
+ # Text parts are merged into one text entry, followed by the images.
+ message_content = []
+ combined_text = " ".join(text_parts).strip()
+ if combined_text:
+ message_content.append({"type": "text", "text": combined_text})
+ message_content.extend(image_parts)
+ if message_content:
+ messages.append({"role": role, "content": message_content})
+
+ # Persisted-chat models need a real upstream chat; its id replaces the
+ # locally generated one.
+ if use_persisted_chat:
+ chat_id = await self._create_upstream_chat(
+ prompt=last_user_text,
+ model=upstream_model_id,
+ token=token,
+ headers=headers,
+ enable_thinking=enable_thinking,
+ web_search=web_search,
+ user_message_id=persisted_user_message_id,
+ files=files or None,
+ feature_entries=feature_entries or None,
+ mcp_servers=mcp_servers or None,
+ )
+ self.logger.info(f"🧩 已为 {requested_model} 创建上游 chat: {chat_id}")
+ headers["Referer"] = f"{self.base_url}/c/{chat_id}"
+
+ if use_persisted_chat:
+ body = self._build_glm47_completion_body(
+ model=upstream_model_id,
+ messages=messages,
+ prompt=last_user_text,
+ chat_id=chat_id,
+ enable_thinking=enable_thinking,
+ web_search=web_search,
+ files=files,
+ tools=tools,
+ tool_choice=tool_choice,
+ temperature=request.temperature,
+ max_tokens=request.max_tokens,
+ mcp_servers=mcp_servers,
+ preview_mode=preview_mode,
+ feature_entries=feature_entries or None,
+ message_id=persisted_assistant_message_id or generate_uuid(),
+ current_user_message_id=persisted_user_message_id or generate_uuid(),
+ current_user_message_parent_id=None,
+ )
+ else:
+ message_id = generate_uuid()
+ session_id = generate_uuid()
+ body = {
+ "stream": True,
+ "model": upstream_model_id,
+ "messages": messages,
+ "signature_prompt": last_user_text,
+ "files": files,
+ "params": {},
+ "extra": {},
+ "features": {
+ "image_generation": False,
+ "web_search": web_search,
+ "auto_web_search": web_search,
+ "preview_mode": preview_mode,
+ "flags": [],
+ "features": [
+ dict(item)
+ for item in (feature_entries or DEFAULT_COMPLETION_FEATURES)
+ ],
+ "enable_thinking": enable_thinking,
+ },
+ "background_tasks": {
+ "title_generation": False,
+ "tags_generation": False,
+ },
+ "mcp_servers": mcp_servers,
+ "variables": self._build_request_variables(),
+ "model_item": {
+ "id": upstream_model_id,
+ "name": requested_model,
+ "owned_by": settings.SERVICE_NAME,
+ },
+ "chat_id": chat_id,
+ "id": message_id,
+ "session_id": session_id,
+ "current_user_message_id": message_id,
+ "current_user_message_parent_id": None,
+ }
+ if tools:
+ body["tools"] = tools
+ if tool_choice is not None:
+ body["tool_choice"] = tool_choice
+ self.logger.info(f"🔧 工具调用将直接透传到上游: {len(tools)} 个工具")
+ else:
+ body["tools"] = None
+ if request.temperature is not None:
+ body["params"]["temperature"] = request.temperature
+ if request.max_tokens is not None:
+ body["params"]["max_tokens"] = request.max_tokens
+
+ # A signing failure is logged and the request continues unsigned
+ # against the bare endpoint.
+ try:
+ signed_url, signature, timestamp_ms = self._build_signed_completion_request(
+ prompt=last_user_text,
+ chat_id=chat_id,
+ token=token,
+ user_id=user_id,
+ user_agent=headers["User-Agent"],
+ use_browser_fingerprint=use_persisted_chat,
+ )
+ logger.debug(
+ "[上游] 生成签名成功: %s... (user_id=%s, timestamp=%s)",
+ signature[:16],
+ user_id,
+ timestamp_ms,
+ )
+ except Exception as e:
+ logger.error(f"[上游] 签名生成失败: {e}")
+ signature = ""
+ timestamp_ms = "0"
+ signed_url = self.api_endpoint
+
+ fe_version = headers.get("X-FE-Version") or get_latest_fe_version()
+ headers.update(
+ {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ "Accept": "*/*" if use_persisted_chat else "application/json",
+ "X-FE-Version": fe_version,
+ "X-Signature": signature,
+ }
+ )
+
+ logger.debug(
+ "[上游] 请求头: Authorization=Bearer *****, X-Signature=%s...",
+ signature[:16] if signature else "(空)",
+ )
+ logger.debug(
+ "[上游] URL 参数: timestamp=%s, user_id=%s, persisted_chat=%s",
+ timestamp_ms,
+ user_id,
+ use_persisted_chat,
+ )
+
+ # Remember the current token for error handling
+ self._current_token = token
+
+ return {
+ "url": signed_url,
+ "headers": headers,
+ "body": body,
+ "token": token,
+ "chat_id": chat_id,
+ "model": requested_model,
+ "user_id": user_id,
+ "auth_mode": auth_mode,
+ "token_source": token_source,
+ "guest_user_id": guest_user_id,
+ }
+
+ async def chat_completion(
+ self,
+ request: OpenAIRequest,
+ **kwargs
+ ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
+ """Handle a chat completion request.
+
+ Streaming requests are delegated to _create_stream_response. For
+ non-streaming requests the upstream call is attempted up to
+ _get_total_retry_limit() times, refreshing the guest session or the
+ authenticated token between attempts when the retry predicates say so.
+
+ Args:
+ request: The incoming OpenAI-compatible request.
+ **kwargs: Accepted but not read by this method.
+
+ Returns:
+ The async generator from _create_stream_response for streaming
+ requests, otherwise the result of transform_response, or the value
+ of handle_error when the request fails.
+ """
+ self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
+ self.logger.debug(f" 消息数量: {len(request.messages)}")
+ self.logger.debug(f" 流式模式: {request.stream}")
+
+ try:
+ transformed = await self.transform_request(request)
+
+ if request.stream:
+ return self._create_stream_response(request, transformed)
+
+ proxies = self._get_proxy_config()
+ max_attempts = self._get_total_retry_limit()
+ # Credentials that already failed during this request.
+ excluded_tokens: Set[str] = set()
+ excluded_guest_user_ids: Set[str] = set()
+
+ for attempt in range(max_attempts):
+ async with httpx.AsyncClient(
+ timeout=self._build_timeout(read_timeout=60.0),
+ limits=self._build_limits(),
+ proxy=proxies,
+ ) as client:
+ response = await client.post(
+ transformed["url"],
+ headers=transformed["headers"],
+ json=transformed["body"],
+ )
+
+ error_code, error_message = self._extract_upstream_error_details(
+ response.status_code,
+ response.text,
+ )
+ is_concurrency_limited = self._is_concurrency_limited(
+ response.status_code,
+ error_code,
+ error_message,
+ )
+
+ # Guest retry: exclude the current guest user and rebuild the request.
+ if self._should_retry_guest_session(
+ response.status_code,
+ is_concurrency_limited,
+ attempt,
+ max_attempts,
+ transformed,
+ ):
+ guest_user_id = str(
+ transformed.get("guest_user_id")
+ or transformed.get("user_id")
+ or ""
+ )
+ if guest_user_id:
+ excluded_guest_user_ids.add(guest_user_id)
+ transformed = await self._refresh_guest_request(
+ request,
+ attempt,
+ excluded_tokens,
+ excluded_guest_user_ids,
+ transformed,
+ is_concurrency_limited=is_concurrency_limited,
+ )
+ continue
+
+ # Authenticated retry: mark the token as failed, exclude it, and
+ # rebuild the request.
+ if self._should_retry_authenticated_session(
+ response.status_code,
+ is_concurrency_limited,
+ attempt,
+ max_attempts,
+ transformed,
+ ):
+ current_token = str(transformed.get("token") or "")
+ if current_token:
+ excluded_tokens.add(current_token)
+ await self.mark_token_failure(
+ current_token,
+ Exception(error_message or "上游认证会话不可用"),
+ )
+ self.logger.warning(
+ "⚠️ 认证会话不可用,准备切换认证 Token/回退匿名池: "
+ f"{current_token[:20]}..."
+ )
+ transformed = await self._refresh_authenticated_request(
+ request,
+ attempt,
+ excluded_tokens,
+ excluded_guest_user_ids,
+ )
+ continue
+
+ # Failure without a retry: record it for non-guest tokens and
+ # return an error result.
+ if not response.is_success:
+ error_msg = f"上游 API 错误: {response.status_code}"
+ if not self._is_guest_auth(transformed):
+ current_token = str(transformed.get("token") or "")
+ if current_token:
+ await self.mark_token_failure(
+ current_token,
+ Exception(error_message or error_msg),
+ )
+ await self._release_guest_session(transformed)
+ self.logger.error(f"❌ {self.name} 响应失败: {error_msg}")
+ return handle_error(Exception(error_message or error_msg))
+
+ try:
+ result = await self.transform_response(response, request, transformed)
+ finally:
+ await self._release_guest_session(transformed)
+
+ # Success: report it to the token pool for non-guest tokens.
+ if not self._is_guest_auth(transformed):
+ current_token = str(transformed.get("token") or "")
+ if current_token:
+ token_pool = get_token_pool()
+ if token_pool:
+ await token_pool.record_token_success(current_token)
+
+ return result
+
+ except Exception as e:
+ self.logger.error(f"❌ {self.name} 响应失败: {str(e)}")
+ return handle_error(e, "请求处理")
+
+ async def _create_stream_response(
+ self,
+ request: OpenAIRequest,
+ transformed: Dict[str, Any]
+ ) -> AsyncGenerator[str, None]:
+ """Create the streaming response, retrying across both pools before the first chunk.
+
+ Sends the transformed request upstream with client.stream. A non-200
+ response is either retried with fresh guest/authenticated credentials
+ or reported to the client as an SSE error event followed by [DONE].
+ A 200 response is relayed through _handle_stream_response.
+
+ Args:
+ request: The incoming OpenAI-compatible request.
+ transformed: The upstream request description from transform_request.
+
+ Yields:
+ SSE-formatted strings.
+ """
+ max_attempts = self._get_total_retry_limit()
+ # Credentials that already failed during this request.
+ excluded_tokens: Set[str] = set()
+ excluded_guest_user_ids: Set[str] = set()
+ current_token = str(transformed.get("token") or "")
+
+ try:
+ proxies = self._get_proxy_config()
+
+ async with httpx.AsyncClient(
+ timeout=self._build_timeout(read_timeout=180.0),
+ http2=True,
+ limits=self._build_limits(),
+ proxy=proxies,
+ ) as client:
+ for attempt in range(max_attempts):
+ self.logger.info(f"🎯 发送请求到上游: {transformed['url']}")
+ async with client.stream(
+ "POST",
+ transformed["url"],
+ json=transformed["body"],
+ headers=transformed["headers"],
+ ) as response:
+ # The body is only read eagerly for error responses; a 200
+ # body is consumed later as a stream.
+ error_text = await response.aread() if response.status_code != 200 else b""
+ error_msg = error_text.decode("utf-8", errors="ignore")
+ error_code, parsed_error_message = (
+ self._extract_upstream_error_details(
+ response.status_code,
+ error_msg,
+ )
+ if response.status_code != 200
+ else (None, "")
+ )
+ is_concurrency_limited = self._is_concurrency_limited(
+ response.status_code,
+ error_code,
+ parsed_error_message,
+ )
+
+ # Guest retry: exclude the current guest user and rebuild the request.
+ if self._should_retry_guest_session(
+ response.status_code,
+ is_concurrency_limited,
+ attempt,
+ max_attempts,
+ transformed,
+ ):
+ guest_user_id = str(
+ transformed.get("guest_user_id")
+ or transformed.get("user_id")
+ or ""
+ )
+ if guest_user_id:
+ excluded_guest_user_ids.add(guest_user_id)
+ transformed = await self._refresh_guest_request(
+ request,
+ attempt,
+ excluded_tokens,
+ excluded_guest_user_ids,
+ transformed,
+ is_concurrency_limited=is_concurrency_limited,
+ )
+ current_token = str(transformed.get("token") or "")
+ continue
+
+ # Authenticated retry: mark the token as failed, exclude it,
+ # and rebuild the request.
+ if self._should_retry_authenticated_session(
+ response.status_code,
+ is_concurrency_limited,
+ attempt,
+ max_attempts,
+ transformed,
+ ):
+ if current_token:
+ excluded_tokens.add(current_token)
+ await self.mark_token_failure(
+ current_token,
+ Exception(
+ parsed_error_message or "上游认证会话不可用"
+ ),
+ )
+ self.logger.warning(
+ "⚠️ 流式请求命中认证会话限制,准备切号/回退匿名池: "
+ f"{current_token[:20]}..."
+ )
+ transformed = await self._refresh_authenticated_request(
+ request,
+ attempt,
+ excluded_tokens,
+ excluded_guest_user_ids,
+ )
+ current_token = str(transformed.get("token") or "")
+ continue
+
+ # Failure without a retry: report it to the client as an SSE
+ # error event and end the stream.
+ if response.status_code != 200:
+ self.logger.error(f"❌ 上游返回错误: {response.status_code}")
+ if error_msg:
+ self.logger.error(f"❌ 错误详情: {error_msg}")
+
+ if not self._is_guest_auth(transformed) and current_token:
+ await self.mark_token_failure(
+ current_token,
+ Exception(
+ parsed_error_message
+ or f"Upstream error: {response.status_code}"
+ ),
+ )
+ await self._release_guest_session(transformed)
+
+ if response.status_code == 405:
+ self.logger.error(
+ "🚫 请求被上游 WAF 拦截,可能是请求头或签名异常"
+ )
+ error_response = {
+ "error": {
+ "message": (
+ "请求被上游WAF拦截(405 Method Not Allowed),"
+ "可能是请求头或签名异常,请稍后重试..."
+ ),
+ "type": "waf_blocked",
+ "code": 405,
+ }
+ }
+ else:
+ error_response = {
+ "error": {
+ "message": parsed_error_message
+ or f"Upstream error: {response.status_code}",
+ "type": "upstream_error",
+ "code": error_code or response.status_code,
+ }
+ }
+ yield f"data: {json.dumps(error_response)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ # Success: relay the upstream stream to the client.
+ chat_id = transformed["chat_id"]
+ model = transformed["model"]
+ try:
+ async for chunk in self._handle_stream_response(
+ response,
+ chat_id,
+ model,
+ request,
+ transformed,
+ ):
+ yield chunk
+ finally:
+ await self._release_guest_session(transformed)
+
+ if not self._is_guest_auth(transformed) and current_token:
+ token_pool = get_token_pool()
+ if token_pool:
+ await token_pool.record_token_success(current_token)
+ return
+ except Exception as e:
+ self.logger.error(f"❌ 流处理错误: {e}")
+ import traceback
+ self.logger.error(traceback.format_exc())
+ # Guest sessions are released; authenticated tokens are marked failed.
+ if self._is_guest_auth(transformed):
+ await self._release_guest_session(transformed)
+ elif current_token:
+ await self.mark_token_failure(current_token, e)
+
+ error_response = {
+ "error": {
+ "message": str(e),
+ "type": "stream_error"
+ }
+ }
+ yield f"data: {json.dumps(error_response)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
async def transform_response(
    self,
    response: httpx.Response,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Convert an upstream response into the OpenAI-compatible form.

    Args:
        response: The upstream HTTP response.
        request: The original request; its ``stream`` flag picks the path.
        transformed: The upstream request description; ``chat_id`` and
            ``model`` are read from it.

    Returns:
        The generator from ``_handle_stream_response`` for streaming
        requests, otherwise the awaited result of
        ``_handle_non_stream_response``.
    """
    chat_id, model = transformed["chat_id"], transformed["model"]

    if not request.stream:
        return await self._handle_non_stream_response(response, chat_id, model)
    return self._handle_stream_response(response, chat_id, model, request, transformed)
+
async def _handle_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> AsyncGenerator[str, None]:
    """Convert the upstream SSE stream into OpenAI-style SSE chunks.

    Only ``data:`` lines of ``response`` are processed. Each JSON payload is
    mapped by its ``phase``: ``thinking`` becomes ``reasoning_content``;
    ``answer``, ``other`` and ``search`` (or a ``web_search`` chunk type)
    become ``content``; explicit ``tool_calls`` are forwarded as tool-call
    deltas. The assistant role chunk is emitted once, before the first delta.
    The stream ends with a finish chunk carrying the usage data followed by
    ``data: [DONE]``.

    Args:
        response: Upstream streaming HTTP response.
        chat_id: Completion id placed in every emitted chunk.
        model: Model name placed in every emitted chunk.
        request: Client request; ``request.tools`` enables the tool-call
            fallback parsing of buffered text at the end of the stream.
        transformed: Request context (not read by this method).

    Yields:
        SSE-formatted strings ready to be written to the client.
    """
    self.logger.info("✅ 上游响应成功,开始处理 SSE 流")

    has_tools = settings.TOOL_SUPPORT and bool(request.tools)
    buffered_content = ""
    usage_info: Dict[str, int] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    tool_calls_accum: List[Dict[str, Any]] = []
    has_sent_role = False
    finished = False
    line_count = 0
    # Kept local (not on ``self``) so concurrent streams sharing this
    # instance do not overwrite each other's phase tracking.
    last_phase: Optional[str] = None

    async def ensure_role_sent() -> Optional[str]:
        """Return the role chunk the first time it is needed, else None."""
        nonlocal has_sent_role
        if has_sent_role:
            return None

        has_sent_role = True
        return await format_sse_chunk(
            create_openai_chunk(chat_id, model, {"role": "assistant"})
        )

    async def emit_delta(delta: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Yield the role chunk (if still pending) followed by ``delta``."""
        role_output = await ensure_role_sent()
        if role_output:
            yield role_output
        yield await format_sse_chunk(
            create_openai_chunk(chat_id, model, delta)
        )

    async def finalize_stream() -> AsyncGenerator[str, None]:
        """Emit the closing chunks exactly once."""
        nonlocal finished, tool_calls_accum
        if finished:
            return

        # No structured tool calls arrived: try to recover them from the
        # text buffered so far.
        if has_tools and not tool_calls_accum:
            parsed_tool_calls, _ = parse_and_extract_tool_calls(buffered_content)
            normalized = self._normalize_tool_calls(parsed_tool_calls)
            if normalized:
                tool_calls_accum = normalized
                for tool_call in normalized:
                    async for piece in emit_delta({"tool_calls": [tool_call]}):
                        yield piece

        role_output = await ensure_role_sent()
        if role_output:
            yield role_output

        finish_reason = "tool_calls" if tool_calls_accum else "stop"
        finish_chunk = create_openai_chunk(
            chat_id,
            model,
            {},
            finish_reason,
        )
        finish_chunk["usage"] = usage_info
        yield await format_sse_chunk(finish_chunk)
        yield "data: [DONE]\n\n"
        finished = True

    try:
        async for line in response.aiter_lines():
            line_count += 1
            if not line:
                continue

            current_line = line.strip()
            if not current_line.startswith("data:"):
                continue

            chunk_str = current_line[5:].strip()
            if not chunk_str:
                continue

            if chunk_str == "[DONE]":
                async for final_chunk in finalize_stream():
                    yield final_chunk
                continue

            try:
                chunk = json.loads(chunk_str)
            except json.JSONDecodeError as error:
                self.logger.debug(f"❌ JSON解析错误: {error}, 内容: {chunk_str[:1000]}")
                continue

            # A JSON scalar or array carries nothing usable; skip it
            # instead of failing on ``.get`` and aborting the stream.
            if not isinstance(chunk, dict):
                continue

            chunk_type = chunk.get("type")
            data = chunk.get("data", {}) if chunk_type == "chat:completion" else chunk
            if not isinstance(data, dict):
                continue

            phase = data.get("phase")
            delta_content = data.get("delta_content", "")
            edit_content = data.get("edit_content", "")

            if phase and phase != last_phase:
                self.logger.info(f"📈 SSE 阶段: {phase}")
                last_phase = phase

            if data.get("usage"):
                usage_info = data["usage"]

            # Buffer the raw text for the tool-call fallback in
            # finalize_stream().
            if delta_content:
                buffered_content += delta_content
            elif edit_content:
                buffered_content += edit_content

            direct_tool_calls = self._normalize_tool_calls(
                data.get("tool_calls"),
                len(tool_calls_accum),
            )
            if direct_tool_calls:
                tool_calls_accum.extend(direct_tool_calls)
                for tool_call in direct_tool_calls:
                    async for piece in emit_delta({"tool_calls": [tool_call]}):
                        yield piece

            # Work out which delta (if any) this payload produces.
            outgoing: Optional[Dict[str, Any]] = None
            if phase == "thinking" and delta_content:
                cleaned = self._clean_reasoning_delta(delta_content)
                if cleaned:
                    outgoing = {"reasoning_content": cleaned}
            elif phase == "answer":
                text = delta_content or self._extract_answer_content(edit_content)
                if text:
                    outgoing = {"content": text}
            elif phase == "other":
                other_text = self._extract_answer_content(edit_content)
                if other_text:
                    outgoing = {"content": other_text}
            elif phase == "search" or chunk_type == "web_search":
                citation_text = self._format_search_results(data)
                if citation_text:
                    outgoing = {"content": citation_text}

            if outgoing is not None:
                async for piece in emit_delta(outgoing):
                    yield piece

            if data.get("done"):
                async for final_chunk in finalize_stream():
                    yield final_chunk
                return

        self.logger.info(f"✅ SSE 流处理完成,共处理 {line_count} 行数据")

        if not finished:
            async for final_chunk in finalize_stream():
                yield final_chunk

    except Exception as e:
        self.logger.error(f"❌ 流式响应处理错误: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        # Do not emit a second terminator if the stream was already closed.
        if not finished:
            yield await format_sse_chunk(
                create_openai_chunk(chat_id, model, {}, "stop")
            )
            yield "data: [DONE]\n\n"
+
async def _handle_non_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str
) -> Dict[str, Any]:
    """Aggregate the upstream SSE stream into one OpenAI-style response.

    ``data:`` lines are parsed and accumulated by ``phase`` (``thinking`` into
    the reasoning text; ``answer``, ``other`` and ``search`` into the answer
    text). A non-``data:`` line that parses as a JSON object containing
    ``error``, ``code`` or ``message`` is treated as an upstream error and
    ends processing through ``handle_error``. If no structured tool calls
    were received, the answer text is scanned for embedded tool calls. When
    the answer text is empty, the reasoning text is used as the answer.

    Args:
        response: Upstream HTTP response whose body is an SSE stream.
        chat_id: Completion id for the resulting response.
        model: Model name for the resulting response.

    Returns:
        The value built by ``create_openai_response_with_reasoning``, or the
        value returned by ``handle_error`` on failure.
    """
    final_content = ""
    reasoning_content = ""
    tool_calls_accum: List[Dict[str, Any]] = []
    usage_info: Dict[str, int] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }

    try:
        async for line in response.aiter_lines():
            if not line:
                continue

            line = line.strip()
            if not line.startswith("data:"):
                # Only the JSON parse is guarded here, so a failure inside
                # handle_error() is no longer silently swallowed.
                try:
                    maybe_err = json.loads(line)
                except ValueError:
                    continue

                if isinstance(maybe_err, dict) and (
                    "error" in maybe_err or "code" in maybe_err or "message" in maybe_err
                ):
                    error_field = maybe_err.get("error")
                    if isinstance(error_field, dict):
                        msg = error_field.get("message")
                    else:
                        msg = maybe_err.get("message")
                    return handle_error(Exception(msg or "上游返回错误"), "API响应")
                continue

            data_str = line[5:].strip()
            if not data_str or data_str in ("[DONE]", "DONE", "done"):
                continue

            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                continue

            # A JSON scalar or array has no fields to read; skip it rather
            # than failing the whole aggregation on ``.get``.
            if not isinstance(chunk, dict):
                continue

            chunk_type = chunk.get("type")
            data = chunk.get("data", {}) if chunk_type == "chat:completion" else chunk
            if not isinstance(data, dict):
                continue

            phase = data.get("phase")
            delta_content = data.get("delta_content", "")
            edit_content = data.get("edit_content", "")

            if data.get("usage"):
                usage_info = data["usage"]

            if phase == "thinking" and delta_content:
                reasoning_content += self._clean_reasoning_delta(delta_content)

            elif phase == "answer":
                if delta_content:
                    final_content += delta_content
                elif edit_content:
                    final_content += self._extract_answer_content(edit_content)

            elif phase == "other" and edit_content:
                final_content += self._extract_answer_content(edit_content)

            elif phase == "search" or chunk_type == "web_search":
                final_content += self._format_search_results(data)

            tool_calls_accum.extend(
                self._normalize_tool_calls(
                    data.get("tool_calls"),
                    len(tool_calls_accum),
                )
            )

    except Exception as e:
        self.logger.error(f"❌ 非流式响应处理错误: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        return handle_error(e, "非流式聚合")

    # Fall back to tool calls embedded in the answer text.
    if not tool_calls_accum:
        parsed_tool_calls, cleaned_content = parse_and_extract_tool_calls(final_content)
        normalized = self._normalize_tool_calls(parsed_tool_calls)
        if normalized:
            tool_calls_accum = normalized
            final_content = cleaned_content

    final_content = (final_content or "").strip()
    reasoning_content = (reasoning_content or "").strip()

    # Never return an empty answer when reasoning text is available.
    if not final_content and reasoning_content:
        final_content = reasoning_content

    return create_openai_response_with_reasoning(
        chat_id,
        model,
        final_content,
        reasoning_content,
        usage_info,
        tool_calls_accum or None,
    )
diff --git a/app/models/__init__.py b/app/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..daad5f4488268bebe2c495c63a408caf0ef1c881
--- /dev/null
+++ b/app/models/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from app.models import schemas
+
+__all__ = ["schemas"]
diff --git a/app/models/request_log.py b/app/models/request_log.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec9f980e0e40ad228ddbe99ce53730f15d6a2eb
--- /dev/null
+++ b/app/models/request_log.py
@@ -0,0 +1,35 @@
+"""请求日志数据库模型。"""
+
+from app.core.config import settings
+
# Path of the SQLite database file, read from application settings.
DB_PATH = settings.DB_PATH

# DDL script for the request-log table and its indexes. Every statement uses
# IF NOT EXISTS, so running the script repeatedly is harmless.
SQL_CREATE_REQUEST_LOGS_TABLE = """
CREATE TABLE IF NOT EXISTS request_logs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
    provider TEXT NOT NULL,
    endpoint TEXT DEFAULT '',
    source TEXT DEFAULT 'unknown',
    protocol TEXT DEFAULT 'unknown',
    client_name TEXT DEFAULT 'Unknown',
    model TEXT NOT NULL,
    status_code INTEGER DEFAULT 200,
    success BOOLEAN NOT NULL,
    duration REAL,
    first_token_time REAL,
    input_tokens INTEGER DEFAULT 0,
    output_tokens INTEGER DEFAULT 0,
    cache_creation_tokens INTEGER DEFAULT 0,
    cache_read_tokens INTEGER DEFAULT 0,
    total_tokens INTEGER DEFAULT 0,
    error_message TEXT,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

CREATE INDEX IF NOT EXISTS idx_request_logs_timestamp ON request_logs(timestamp);
CREATE INDEX IF NOT EXISTS idx_request_logs_model ON request_logs(model);
CREATE INDEX IF NOT EXISTS idx_request_logs_provider ON request_logs(provider);
CREATE INDEX IF NOT EXISTS idx_request_logs_source ON request_logs(source);
"""
diff --git a/app/models/schemas.py b/app/models/schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e4fa63e49155843f1cb4c5d717a6bc709369e4d
--- /dev/null
+++ b/app/models/schemas.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from typing import Dict, List, Optional, Any, Union, Literal
+from pydantic import BaseModel
+
+
class ImageUrl(BaseModel):
    """Image reference carried by a vision content part."""

    # URL string identifying the image.
    url: str
+
+
class ContentPart(BaseModel):
    """One element of a message whose content is given as a list of parts."""

    # Discriminator naming the kind of part.
    type: str
    # Text payload, when the part carries text.
    text: Optional[str] = None
    # Image payload, when the part carries an image.
    image_url: Optional[ImageUrl] = None
+
+
class Message(BaseModel):
    """A single chat message.

    ``content`` may be a plain string or a list of ``ContentPart`` items.
    """

    role: str
    content: Optional[Union[str, List[ContentPart]]] = None
    reasoning_content: Optional[str] = None
    # Tool-call related fields; all optional.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
+
+
class OpenAIRequest(BaseModel):
    """Request body accepted by the OpenAI-compatible chat endpoint."""

    model: str
    messages: List[Message]
    # Streaming is off unless the client asks for it.
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Any] = None
    # Optional switches beyond the basic chat fields above.
    enable_thinking: Optional[bool] = None
    web_search: Optional[bool] = None
+
+
class ModelItem(BaseModel):
    """Descriptive record for a single model."""

    id: str
    name: str
    owned_by: str
+
+
class UpstreamRequest(BaseModel):
    """Payload sent to the upstream service."""

    stream: bool
    model: str
    messages: List[Message]
    params: Dict[str, Any] = {}
    features: Dict[str, Any] = {}
    signature_prompt: Optional[str] = None
    files: Optional[List[Dict[str, Any]]] = None
    extra: Optional[Dict[str, Any]] = None
    background_tasks: Optional[Dict[str, bool]] = None
    # Conversation / message identifiers.
    chat_id: Optional[str] = None
    id: Optional[str] = None
    session_id: Optional[str] = None
    current_user_message_id: Optional[str] = None
    current_user_message_parent_id: Optional[str] = None
    mcp_servers: Optional[List[str]] = None
    # Model item dictionary.
    model_item: Optional[Dict[str, Any]] = {}
    # Tool fields kept for OpenAI compatibility.
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Any] = None
    variables: Optional[Dict[str, str]] = None
    # NOTE(review): presumably clears pydantic's protected "model_" namespace so
    # the ``model_item`` field is accepted without a warning - confirm.
    model_config = {"protected_namespaces": ()}
+
+
class Delta(BaseModel):
    """Incremental message fragment carried by a streaming choice."""

    role: Optional[str] = None
    # The default used to be written as `"" or None`, which always evaluates
    # to None; it is now spelled out directly.
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
+
+
class Choice(BaseModel):
    """One choice entry of a response.

    Either ``message`` or ``delta`` may be set; both are optional.
    """

    index: int
    message: Optional[Message] = None
    delta: Optional[Delta] = None
    finish_reason: Optional[str] = None
+
+
class Usage(BaseModel):
    """Token counters for a request; every counter defaults to zero."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
+
+
class OpenAIResponse(BaseModel):
    """Response body in the OpenAI-compatible format."""

    id: str
    object: str
    created: int
    model: str
    choices: List[Choice]
    # Usage is optional and may be absent.
    usage: Optional[Usage] = None
+
+
class UpstreamError(BaseModel):
    """Error details reported by the upstream service."""

    # Error description text.
    detail: str
    # Numeric error code.
    code: int
+
+
class UpstreamDataInner(BaseModel):
    """Nested section of upstream data; only holds an optional error."""

    error: Optional[UpstreamError] = None
+
+
class UpstreamDataData(BaseModel):
    """Content section of an upstream data event."""

    # Text fields and the phase they belong to; empty by default.
    delta_content: str = ""
    edit_content: str = ""
    phase: str = ""
    done: bool = False
    # Optional result / citation lists.
    results: Optional[List[Dict[str, Any]]] = None
    sources: Optional[List[Dict[str, Any]]] = None
    citations: Optional[List[Dict[str, Any]]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    usage: Optional[Usage] = None
    # An error may appear directly or inside ``inner``.
    error: Optional[UpstreamError] = None
    inner: Optional[UpstreamDataInner] = None
+
+
class UpstreamData(BaseModel):
    """Envelope of an upstream data event: a type tag plus its content."""

    type: str
    data: UpstreamDataData
    error: Optional[UpstreamError] = None
+
+
class Model(BaseModel):
    """Entry of the model listing; ``object`` defaults to ``"model"``."""

    id: str
    object: str = "model"
    created: int
    owned_by: str
+
+
class ModelsResponse(BaseModel):
    """Model listing response; ``object`` defaults to ``"list"``."""

    object: str = "list"
    data: List[Model]
diff --git a/app/models/token_db.py b/app/models/token_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..91795c5d1d51e1fa6c92c52ab3315fe99f8149a7
--- /dev/null
+++ b/app/models/token_db.py
@@ -0,0 +1,44 @@
+"""Token 数据库模型定义。"""
+
+from app.core.config import settings
+
# DDL script for the token tables: the ``tokens`` table, the ``token_stats``
# table (rows cascade-delete with their token), supporting indexes, and a
# trigger that refreshes ``tokens.updated_at`` after each update. Every
# statement uses IF NOT EXISTS.
SQL_CREATE_TABLES = """
-- Token 配置表
CREATE TABLE IF NOT EXISTS tokens (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    provider TEXT NOT NULL, -- 提供商: zai
    token TEXT NOT NULL UNIQUE, -- Token 值(唯一)
    token_type TEXT DEFAULT 'user', -- Token 类型: user, guest, unknown
    is_enabled BOOLEAN DEFAULT 1, -- 是否启用
    priority INTEGER DEFAULT 0, -- 优先级(用于排序)
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    UNIQUE(provider, token) -- 同一提供商内 Token 唯一
);

-- Token 使用统计表
CREATE TABLE IF NOT EXISTS token_stats (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    token_id INTEGER NOT NULL,
    total_requests INTEGER DEFAULT 0,
    successful_requests INTEGER DEFAULT 0,
    failed_requests INTEGER DEFAULT 0,
    last_success_time DATETIME,
    last_failure_time DATETIME,
    FOREIGN KEY (token_id) REFERENCES tokens(id) ON DELETE CASCADE
);

-- 创建索引
CREATE INDEX IF NOT EXISTS idx_tokens_provider ON tokens(provider);
CREATE INDEX IF NOT EXISTS idx_tokens_enabled ON tokens(is_enabled);
CREATE INDEX IF NOT EXISTS idx_token_stats_token_id ON token_stats(token_id);

-- 触发器:自动更新 updated_at
CREATE TRIGGER IF NOT EXISTS update_tokens_timestamp
AFTER UPDATE ON tokens
BEGIN
    UPDATE tokens SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
END;
"""

# Path of the SQLite database file, read from application settings.
DB_PATH = settings.DB_PATH
diff --git a/app/services/request_log_dao.py b/app/services/request_log_dao.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e14f735c224091959c7e88f4f7ad5e71bd13f1
--- /dev/null
+++ b/app/services/request_log_dao.py
@@ -0,0 +1,630 @@
+"""
+请求日志数据访问层 (DAO)
+提供请求日志的 CRUD 操作和查询功能
+"""
+import os
+import sqlite3
+from contextlib import asynccontextmanager
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional
+
+import aiosqlite
+
+from app.models.request_log import DB_PATH, SQL_CREATE_REQUEST_LOGS_TABLE
+from app.utils.logger import logger
+
+
def _format_sqlite_datetime(value: datetime) -> str:
    """Format *value* like SQLite's ``CURRENT_TIMESTAMP`` text.

    ``CURRENT_TIMESTAMP`` is stored as UTC in ``YYYY-MM-DD HH:MM:SS`` form, so
    a timezone-aware datetime is converted to UTC first. A naive datetime is
    formatted as-is (the caller is expected to pass UTC). Sub-second
    precision is dropped.

    Args:
        value: The datetime to format.

    Returns:
        The formatted timestamp string.
    """
    if value.tzinfo is not None and value.utcoffset() is not None:
        from datetime import timezone

        value = value.astimezone(timezone.utc)
    return value.strftime("%Y-%m-%d %H:%M:%S")
+
+
def _normalize_trend_window(window: Optional[str], days: Optional[int]) -> str:
    """Resolve the trend window to one of ``"24h"``, ``"7d"`` or ``"30d"``.

    A non-empty ``window`` wins (trimmed, case-insensitive, ``"1d"`` accepted
    as an alias of ``"24h"``). Otherwise the legacy ``days`` argument is
    mapped: 30 -> ``"30d"``, 1 -> ``"24h"``. Anything unrecognised falls back
    to ``"7d"``.
    """
    if window:
        candidate = str(window).strip().lower()
    else:
        candidate = "30d" if days == 30 else "24h" if days == 1 else "7d"

    if candidate == "1d":
        return "24h"
    return candidate if candidate in ("24h", "7d", "30d") else "7d"
+
+
class RequestLogDAO:
    """Data-access object for the ``request_logs`` SQLite table.

    The schema is created synchronously in ``__init__`` using ``sqlite3``.
    All query methods are coroutines that open one ``aiosqlite`` connection
    per call and close it when done.
    """

    # Keys of the aggregate returned by get_provider_request_stats(), in
    # output order: integer-valued counters first, then float averages.
    _STAT_INT_FIELDS = (
        "total_requests",
        "successful_requests",
        "failed_requests",
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "cache_creation_tokens",
        "cache_read_tokens",
        "cache_creation_requests",
        "cache_hit_requests",
    )
    _STAT_FLOAT_FIELDS = ("avg_duration", "avg_first_token_time")

    def __init__(self, db_path: str = DB_PATH):
        """Store the database path and make sure directory and schema exist."""
        self.db_path = db_path
        self._ensure_db_directory()
        self._init_db()

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive datetime.

        Equivalent to the deprecated ``datetime.utcnow()``.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _empty_request_stats(cls) -> Dict:
        """Return the all-zero aggregate used for empty or failed queries."""
        stats: Dict = {name: 0 for name in cls._STAT_INT_FIELDS}
        stats.update({name: 0.0 for name in cls._STAT_FLOAT_FIELDS})
        return stats

    @staticmethod
    def _build_log_filters(
        provider: Optional[str],
        model: Optional[str],
        success: Optional[bool],
        source: Optional[str],
    ) -> tuple:
        """Build the optional ``AND ...`` filter clauses and their parameters.

        Returns:
            ``(sql_fragment, params)``; the fragment is made only of fixed
            text with ``?`` placeholders.
        """
        clauses: List[str] = []
        params: List[object] = []

        if provider:
            clauses.append(" AND provider = ?")
            params.append(provider)
        if model:
            clauses.append(" AND model = ?")
            params.append(model)
        # ``success`` is tri-state: None means "do not filter".
        if success is not None:
            clauses.append(" AND success = ?")
            params.append(success)
        if source:
            clauses.append(" AND source = ?")
            params.append(source)

        return "".join(clauses), params

    def _ensure_db_directory(self):
        """Create the directory holding the database file if it is missing."""
        db_dir = os.path.dirname(self.db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)

    def _init_db(self):
        """Create the table and indexes and add any missing columns.

        Failures are logged, not raised.
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                conn.executescript(SQL_CREATE_REQUEST_LOGS_TABLE)
                self._ensure_columns(conn)
                conn.commit()
            finally:
                # Close even when schema setup fails, so the handle is not
                # leaked.
                conn.close()
            logger.debug("请求日志表初始化成功")
        except Exception as e:
            logger.error(f"初始化请求日志表失败: {e}")

    def _ensure_columns(self, conn: sqlite3.Connection):
        """Add columns introduced after the first schema to older databases."""
        cursor = conn.execute("PRAGMA table_info(request_logs)")
        # Row index 1 of PRAGMA table_info is the column name.
        existing_columns = {row[1] for row in cursor.fetchall()}
        required_columns = {
            "endpoint": "TEXT DEFAULT ''",
            "source": "TEXT DEFAULT 'unknown'",
            "protocol": "TEXT DEFAULT 'unknown'",
            "client_name": "TEXT DEFAULT 'Unknown'",
            "status_code": "INTEGER DEFAULT 200",
            "cache_creation_tokens": "INTEGER DEFAULT 0",
            "cache_read_tokens": "INTEGER DEFAULT 0",
        }

        for column, definition in required_columns.items():
            if column in existing_columns:
                continue
            # Names and definitions come from the constant mapping above.
            conn.execute(
                f"ALTER TABLE request_logs ADD COLUMN {column} {definition}"
            )

    @asynccontextmanager
    async def get_connection(self):
        """Yield an ``aiosqlite`` connection with ``Row`` results.

        The connection is closed on exit.
        """
        conn = await aiosqlite.connect(self.db_path)
        conn.row_factory = aiosqlite.Row
        try:
            yield conn
        finally:
            await conn.close()

    async def add_log(
        self,
        provider: str,
        endpoint: str,
        source: str,
        protocol: str,
        client_name: str,
        model: str,
        status_code: int,
        success: bool,
        duration: float = 0.0,
        first_token_time: float = 0.0,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        total_tokens: Optional[int] = None,
        error_message: Optional[str] = None
    ) -> int:
        """Insert one request-log row.

        Args:
            provider: Provider name.
            endpoint: Request endpoint.
            source: Request source identifier.
            protocol: Protocol type.
            client_name: Client name.
            model: Model name.
            status_code: Request status code.
            success: Whether the request succeeded.
            duration: Total duration in seconds.
            first_token_time: Time to first token in seconds.
            input_tokens: Input token count.
            output_tokens: Output token count.
            cache_creation_tokens: Cache-creation token count.
            cache_read_tokens: Cache-read token count.
            total_tokens: Total token count; defaults to
                ``input_tokens + output_tokens`` when omitted.
            error_message: Error message, if any.

        Returns:
            The id of the inserted row.
        """
        if total_tokens is None:
            total_tokens = input_tokens + output_tokens

        async with self.get_connection() as conn:
            cursor = await conn.execute(
                """
                INSERT INTO request_logs
                (provider, endpoint, source, protocol, client_name, model,
                status_code, success, duration, first_token_time,
                input_tokens, output_tokens, cache_creation_tokens,
                cache_read_tokens, total_tokens, error_message)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    provider,
                    endpoint,
                    source,
                    protocol,
                    client_name,
                    model,
                    status_code,
                    success,
                    duration,
                    first_token_time,
                    input_tokens,
                    output_tokens,
                    cache_creation_tokens,
                    cache_read_tokens,
                    total_tokens,
                    error_message,
                )
            )
            await conn.commit()
            return cursor.lastrowid

    async def get_recent_logs(
        self,
        limit: int = 100,
        offset: int = 0,
        provider: str = None,
        model: str = None,
        success: bool = None,
        source: str = None,
    ) -> List[Dict]:
        """Return the most recent log rows, newest first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip; negative values are treated as 0.
            provider: Only rows for this provider.
            model: Only rows for this model.
            success: Only rows with this success flag (None = no filter).
            source: Only rows from this source.

        Returns:
            A list of rows as dictionaries.
        """
        where_sql, params = self._build_log_filters(provider, model, success, source)
        query = (
            "SELECT * FROM request_logs WHERE 1=1"
            + where_sql
            + " ORDER BY timestamp DESC, id DESC LIMIT ? OFFSET ?"
        )
        params.extend([limit, max(0, offset)])

        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            rows = await cursor.fetchall()
            return [dict(row) for row in rows]

    async def count_logs(
        self,
        provider: str = None,
        model: str = None,
        success: bool = None,
        source: str = None,
    ) -> int:
        """Count log rows matching the same filters as ``get_recent_logs``."""
        where_sql, params = self._build_log_filters(provider, model, success, source)
        query = "SELECT COUNT(*) AS total_count FROM request_logs WHERE 1=1" + where_sql

        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            row = await cursor.fetchone()
            return int(row["total_count"] or 0) if row else 0

    async def get_logs_by_time_range(
        self,
        start_time: datetime,
        end_time: datetime,
        provider: str = None,
        model: str = None
    ) -> List[Dict]:
        """Return log rows whose timestamp lies in ``[start_time, end_time]``.

        Args:
            start_time: Start of the range (inclusive).
            end_time: End of the range (inclusive).
            provider: Only rows for this provider.
            model: Only rows for this model.

        Returns:
            A list of rows as dictionaries, newest first.
        """
        query = "SELECT * FROM request_logs WHERE timestamp BETWEEN ? AND ?"
        params = [
            _format_sqlite_datetime(start_time),
            _format_sqlite_datetime(end_time),
        ]

        if provider:
            query += " AND provider = ?"
            params.append(provider)

        if model:
            query += " AND model = ?"
            params.append(model)

        query += " ORDER BY timestamp DESC, id DESC"

        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            rows = await cursor.fetchall()
            return [dict(row) for row in rows]

    async def get_provider_request_stats(self, provider: Optional[str] = None) -> Dict:
        """Aggregate request statistics, optionally for a single provider.

        Returns:
            A dictionary with the keys listed in ``_STAT_INT_FIELDS`` and
            ``_STAT_FLOAT_FIELDS``. All values are zero when there is no data
            or the query fails (the failure is logged).
        """
        query = """
            SELECT
                COUNT(*) as total_requests,
                SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successful_requests,
                SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) as failed_requests,
                SUM(input_tokens) as input_tokens,
                SUM(output_tokens) as output_tokens,
                SUM(total_tokens) as total_tokens,
                SUM(cache_creation_tokens) as cache_creation_tokens,
                SUM(cache_read_tokens) as cache_read_tokens,
                SUM(
                    CASE WHEN cache_creation_tokens > 0 THEN 1 ELSE 0 END
                ) as cache_creation_requests,
                SUM(
                    CASE WHEN cache_read_tokens > 0 THEN 1 ELSE 0 END
                ) as cache_hit_requests,
                AVG(duration) as avg_duration,
                AVG(
                    CASE
                        WHEN first_token_time > 0 THEN first_token_time
                        ELSE NULL
                    END
                ) as avg_first_token_time
            FROM request_logs
        """
        params: List[object] = []

        if provider:
            query += " WHERE provider = ?"
            params.append(provider)

        try:
            async with self.get_connection() as conn:
                cursor = await conn.execute(query, params)
                row = await cursor.fetchone()

            stats = self._empty_request_stats()
            if row:
                # SUM/AVG yield NULL on an empty table, hence the ``or 0``.
                for name in self._STAT_INT_FIELDS:
                    stats[name] = int(row[name] or 0)
                for name in self._STAT_FLOAT_FIELDS:
                    stats[name] = float(row[name] or 0.0)
            return stats
        except Exception as e:
            logger.error(f"❌ 获取请求统计失败: {e}")
            return self._empty_request_stats()

    async def get_provider_usage_trend(
        self,
        provider: Optional[str] = None,
        days: Optional[int] = None,
        *,
        window: Optional[str] = None,
        now: Optional[datetime] = None,
    ) -> List[Dict]:
        """Return request and token trend points for a recent window.

        The window is ``"24h"`` (24 hourly buckets) or ``"7d"`` / ``"30d"``
        (daily buckets), resolved by ``_normalize_trend_window``. Buckets
        without any rows are returned as zero-valued points, so the result
        always has one entry per bucket in chronological order.

        Args:
            provider: Only rows for this provider.
            days: Legacy way to select the window.
            window: Preferred way to select the window.
            now: Reference time; defaults to the current UTC time.
        """
        trend_window = _normalize_trend_window(window, days)
        current_time = now or self._utc_now()

        if trend_window == "24h":
            bucket_count = 24
            current_hour = current_time.replace(
                minute=0,
                second=0,
                microsecond=0,
            )
            # The current hour is the last bucket.
            start_time = current_hour - timedelta(hours=bucket_count - 1)
            bucket_expression = "strftime('%Y-%m-%d %H:00:00', timestamp)"
            row_key = "trend_bucket"
            label_format = "%H:%M"
            tooltip_format = "%Y-%m-%d %H:00"
            rows = await self._query_usage_trend_rows(
                provider,
                start_time,
                bucket_expression,
                row_key,
            )
            rows_by_bucket = {str(row[row_key]): dict(row) for row in rows}
            trend: List[Dict] = []

            for offset in range(bucket_count):
                bucket_time = start_time + timedelta(hours=offset)
                bucket_key = bucket_time.strftime("%Y-%m-%d %H:00:00")
                trend.append(
                    self._build_usage_trend_point(
                        row=rows_by_bucket.get(bucket_key, {}),
                        bucket=bucket_key,
                        label=bucket_time.strftime(label_format),
                        tooltip_label=bucket_time.strftime(tooltip_format),
                    )
                )

            return trend

        bucket_count = 30 if trend_window == "30d" else 7
        current_date = current_time.date()
        # Today is the last bucket.
        start_date = current_date - timedelta(days=bucket_count - 1)
        start_time = datetime.combine(start_date, datetime.min.time())
        rows = await self._query_usage_trend_rows(
            provider,
            start_time,
            "DATE(timestamp)",
            "trend_bucket",
        )
        rows_by_bucket = {
            str(row["trend_bucket"]): dict(row)
            for row in rows
        }
        trend = []

        for offset in range(bucket_count):
            bucket_date = start_date + timedelta(days=offset)
            bucket_key = bucket_date.isoformat()
            trend.append(
                self._build_usage_trend_point(
                    row=rows_by_bucket.get(bucket_key, {}),
                    bucket=bucket_key,
                    label=bucket_date.strftime("%m-%d"),
                    tooltip_label=bucket_date.strftime("%Y-%m-%d"),
                )
            )

        return trend

    async def _query_usage_trend_rows(
        self,
        provider: Optional[str],
        start_time: datetime,
        bucket_expression: str,
        bucket_alias: str,
    ) -> list[aiosqlite.Row]:
        """Run the grouped trend query for rows at or after ``start_time``.

        ``bucket_expression`` and ``bucket_alias`` are interpolated into the
        SQL text, so they must be trusted constants, never user input.
        """
        query = f"""
            SELECT
                {bucket_expression} as {bucket_alias},
                COUNT(*) as total_requests,
                SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successful_requests,
                SUM(input_tokens) as input_tokens,
                SUM(output_tokens) as output_tokens,
                SUM(total_tokens) as total_tokens,
                SUM(cache_creation_tokens) as cache_creation_tokens,
                SUM(cache_read_tokens) as cache_read_tokens
            FROM request_logs
            WHERE timestamp >= ?
        """
        params: List[object] = [_format_sqlite_datetime(start_time)]

        if provider:
            query += " AND provider = ?"
            params.append(provider)

        query += f" GROUP BY {bucket_expression} ORDER BY {bucket_alias} ASC"

        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            return await cursor.fetchall()

    def _build_usage_trend_point(
        self,
        *,
        row: Dict,
        bucket: str,
        label: str,
        tooltip_label: str,
    ) -> Dict:
        """Turn one (possibly empty) aggregate row into a trend point.

        Missing or NULL values count as zero. ``success_rate`` is a
        percentage rounded to one decimal place.
        """
        total_requests = int(row.get("total_requests") or 0)
        successful_requests = int(row.get("successful_requests") or 0)
        cache_creation_tokens = int(row.get("cache_creation_tokens") or 0)
        cache_read_tokens = int(row.get("cache_read_tokens") or 0)

        return {
            "bucket": bucket,
            "label": label,
            "tooltip_label": tooltip_label,
            "total_requests": total_requests,
            "successful_requests": successful_requests,
            "failed_requests": max(0, total_requests - successful_requests),
            "input_tokens": int(row.get("input_tokens") or 0),
            "output_tokens": int(row.get("output_tokens") or 0),
            "total_tokens": int(row.get("total_tokens") or 0),
            "cache_creation_tokens": cache_creation_tokens,
            "cache_read_tokens": cache_read_tokens,
            "cache_total_tokens": (
                cache_creation_tokens + cache_read_tokens
            ),
            "success_rate": round(
                (
                    successful_requests / total_requests * 100
                ) if total_requests > 0 else 0,
                1,
            ),
        }

    async def get_model_stats_from_db(self, hours: int = 24) -> Dict:
        """Return per-model statistics for the last ``hours`` hours.

        Rows without a measured first-token time (value 0 or NULL) are left
        out of ``avg_first_token_time``, matching
        ``get_provider_request_stats``.

        Args:
            hours: Size of the look-back window in hours.

        Returns:
            A mapping of model name to its statistics, ordered by request
            count (descending).
        """
        start_time = self._utc_now() - timedelta(hours=hours)

        async with self.get_connection() as conn:
            cursor = await conn.execute(
                """
                SELECT
                    model,
                    COUNT(*) as total,
                    SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as success,
                    SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) as failed,
                    SUM(input_tokens) as input_tokens,
                    SUM(output_tokens) as output_tokens,
                    SUM(total_tokens) as total_tokens,
                    AVG(duration) as avg_duration,
                    AVG(
                        CASE
                            WHEN first_token_time > 0 THEN first_token_time
                            ELSE NULL
                        END
                    ) as avg_first_token_time
                FROM request_logs
                WHERE timestamp >= ?
                GROUP BY model
                ORDER BY total DESC
                """,
                (_format_sqlite_datetime(start_time),)
            )
            rows = await cursor.fetchall()

            result = {}
            for row in rows:
                model = row['model']
                result[model] = {
                    'total': row['total'],
                    'success': row['success'],
                    'failed': row['failed'],
                    'input_tokens': row['input_tokens'] or 0,
                    'output_tokens': row['output_tokens'] or 0,
                    'total_tokens': row['total_tokens'] or 0,
                    'avg_duration': round(row['avg_duration'] or 0, 2),
                    'avg_first_token_time': round(row['avg_first_token_time'] or 0, 2),
                    'success_rate': round(
                        (row['success'] / row['total'] * 100)
                        if row['total'] > 0
                        else 0,
                        1,
                    ),
                }

            return result

    async def delete_old_logs(self, days: int = 30) -> int:
        """Delete log rows older than ``days`` days.

        Args:
            days: Number of days of logs to keep.

        Returns:
            The number of deleted rows.
        """
        cutoff_time = self._utc_now() - timedelta(days=days)

        async with self.get_connection() as conn:
            cursor = await conn.execute(
                "DELETE FROM request_logs WHERE timestamp < ?",
                (_format_sqlite_datetime(cutoff_time),)
            )
            await conn.commit()
            return cursor.rowcount
+
+
# Module-wide DAO instance, created on first use.
_request_log_dao: Optional[RequestLogDAO] = None


def get_request_log_dao() -> RequestLogDAO:
    """Return the shared ``RequestLogDAO``, creating it on the first call.

    Returns:
        The module-wide ``RequestLogDAO`` instance.
    """
    global _request_log_dao
    if _request_log_dao is not None:
        return _request_log_dao

    _request_log_dao = RequestLogDAO()
    return _request_log_dao
+
+
def init_request_log_dao():
    """Replace the shared DAO with a freshly constructed one and return it."""
    global _request_log_dao
    fresh_dao = RequestLogDAO()
    _request_log_dao = fresh_dao
    return fresh_dao
diff --git a/app/services/token_automation.py b/app/services/token_automation.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf07e8fa4f933465b506e167ab64fd1932bbe87
--- /dev/null
+++ b/app/services/token_automation.py
@@ -0,0 +1,278 @@
+"""Background automation for token import and maintenance."""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+from typing import Optional
+
+from app.core.config import settings
+from app.services.token_dao import TokenDAO, get_token_dao
+from app.services.token_importer import TokenImportSummary, import_tokens_from_directory
+from app.utils.logger import logger
+from app.utils.token_pool import TokenPool, get_token_pool
+
# Provider used by the automation helpers when the caller does not name one.
DEFAULT_TOKEN_PROVIDER = "zai"
# One lock per job type; run_directory_import / run_token_maintenance refuse
# to start while the matching lock is held.
_AUTO_IMPORT_LOCK = asyncio.Lock()
_AUTO_MAINTENANCE_LOCK = asyncio.Lock()
+
+
@dataclass(frozen=True)
class TokenMaintenanceSummary:
    """Immutable counters describing the outcome of one maintenance cycle."""

    # Provider the cycle ran against.
    provider: str
    # Every counter defaults to 0, so a step that was skipped reports zero.
    checked_count: int = 0
    duplicate_removed_count: int = 0
    valid_count: int = 0
    guest_count: int = 0
    invalid_count: int = 0
    deleted_invalid_count: int = 0
+
+
async def run_directory_import(
    source_dir: str,
    *,
    provider: str = DEFAULT_TOKEN_PROVIDER,
    validate: bool = True,
    dao: Optional[TokenDAO] = None,
    pool: Optional[TokenPool] = None,
) -> TokenImportSummary:
    """Import tokens from ``source_dir`` and resync the pool when rows were added.

    Args:
        source_dir: Directory handed to ``import_tokens_from_directory``.
        provider: Provider the tokens are imported for.
        validate: Forwarded to the importer.
        dao: Optional DAO forwarded to the importer.
        pool: Pool to resync; ``get_token_pool()`` is used when omitted.

    Returns:
        The importer's summary.

    Raises:
        RuntimeError: Another directory import currently holds the lock.
    """
    # Refuse instead of queueing behind a run that is already in progress.
    if _AUTO_IMPORT_LOCK.locked():
        raise RuntimeError("目录导入任务正在执行,请稍后再试")

    async with _AUTO_IMPORT_LOCK:
        result = await import_tokens_from_directory(
            source_dir,
            provider=provider,
            validate=validate,
            dao=dao,
        )

        target_pool = get_token_pool() if pool is None else pool
        # Only resync when something new actually landed in the database.
        if target_pool and result.imported_count > 0:
            await target_pool.sync_from_database(provider)
            logger.info("✅ 目录导入后已同步 Token 池")

        return result
+
+
async def run_token_maintenance(
    *,
    provider: str = DEFAULT_TOKEN_PROVIDER,
    remove_duplicates: bool = True,
    run_health_check: bool = True,
    delete_invalid_tokens: bool = False,
    dao: Optional[TokenDAO] = None,
    pool: Optional[TokenPool] = None,
) -> TokenMaintenanceSummary:
    """Run one maintenance cycle: dedupe, validate, then optionally purge invalid tokens.

    Args:
        provider: Provider whose tokens are maintained.
        remove_duplicates: Remove duplicate token rows first.
        run_health_check: Validate the provider's tokens.
        delete_invalid_tokens: Delete tokens the validation reported as
            invalid; this forces validation even if ``run_health_check``
            is False.
        dao: DAO to use; ``get_token_dao()`` when omitted.
        pool: Pool to resync afterwards; ``get_token_pool()`` when omitted.

    Returns:
        Counters for the cycle.

    Raises:
        RuntimeError: Another maintenance run currently holds the lock.
    """
    if _AUTO_MAINTENANCE_LOCK.locked():
        raise RuntimeError("Token 自动维护任务正在执行,请稍后再试")

    store = dao or get_token_dao()
    removed_duplicates = 0
    purged = 0
    # Keys mirror the validation report; they stay 0 when validation is skipped.
    counts = {"checked": 0, "valid": 0, "guest": 0, "invalid": 0}

    async with _AUTO_MAINTENANCE_LOCK:
        if remove_duplicates:
            removed_duplicates = await store.remove_duplicate_tokens(provider)

        bad_ids: list[int] = []
        if run_health_check or delete_invalid_tokens:
            report = await store.validate_tokens_detailed(provider)
            for key in counts:
                counts[key] = int(report.get(key, 0) or 0)
            bad_ids = list(report.get("invalid_token_ids", []) or [])

        if delete_invalid_tokens and bad_ids:
            purged = await store.delete_tokens_by_ids(bad_ids)

        target_pool = get_token_pool() if pool is None else pool
        if target_pool:
            await target_pool.sync_from_database(provider)
            logger.info("✅ Token 维护后已同步 Token 池")

    return TokenMaintenanceSummary(
        provider=provider,
        checked_count=counts["checked"],
        duplicate_removed_count=removed_duplicates,
        valid_count=counts["valid"],
        guest_count=counts["guest"],
        invalid_count=counts["invalid"],
        deleted_invalid_count=purged,
    )
+
+
class TokenAutomationScheduler:
    """Drive the periodic token import and maintenance jobs as background tasks."""

    def __init__(self) -> None:
        """Create an idle scheduler with no tasks and no warning logged yet."""
        self._stop_event = asyncio.Event()
        self._tasks: list[asyncio.Task] = []
        # Last warning emitted per loop, so an unchanged warning is logged once.
        self._import_warning: Optional[str] = None
        self._maintenance_warning: Optional[str] = None

    async def start(self) -> None:
        """Spawn both background loops; a no-op when they already exist."""
        if self._tasks:
            return

        self._stop_event.clear()
        import_task = asyncio.create_task(
            self._auto_import_loop(),
            name="token-auto-import",
        )
        maintenance_task = asyncio.create_task(
            self._auto_maintenance_loop(),
            name="token-auto-maintenance",
        )
        self._tasks = [import_task, maintenance_task]
        logger.info("✅ Token 自动任务调度器已启动")

    async def stop(self) -> None:
        """Signal, cancel and await both loops; a no-op when nothing is running."""
        if not self._tasks:
            return

        self._stop_event.set()
        for running in self._tasks:
            running.cancel()

        # return_exceptions=True collects each task's CancelledError here
        # instead of letting it propagate to the caller.
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()
        self._import_warning = None
        self._maintenance_warning = None
        logger.info("🛑 Token 自动任务调度器已停止")

    async def _auto_import_loop(self) -> None:
        """Periodically import tokens from the configured directory until stopped."""
        while not self._stop_event.is_set():
            # Poll every 15s while the feature is off, so enabling it at
            # runtime is picked up without a restart.
            delay = 15
            try:
                if settings.TOKEN_AUTO_IMPORT_ENABLED:
                    # The configured interval is floored at 30 seconds.
                    delay = max(int(settings.TOKEN_AUTO_IMPORT_INTERVAL), 30)
                    directory = settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip()
                    if directory:
                        self._import_warning = None
                        result = await run_directory_import(
                            directory,
                            provider=DEFAULT_TOKEN_PROVIDER,
                        )
                        logger.info(
                            "🔄 自动导入完成: scanned={} imported={} duplicate={} invalid={}",
                            result.scanned_files,
                            result.imported_count,
                            result.duplicate_count,
                            result.invalid_json_count + result.invalid_token_count,
                        )
                    else:
                        self._log_import_warning_once(
                            "已启用自动导入,但未配置导入目录"
                        )
            except asyncio.CancelledError:
                raise
            except RuntimeError as exc:
                # run_directory_import raises RuntimeError when a run is in progress.
                logger.info(f"⏭️ 跳过本轮自动导入: {exc}")
            except (FileNotFoundError, NotADirectoryError) as exc:
                self._log_import_warning_once(str(exc))
            except Exception as exc:
                logger.exception(f"❌ 自动导入 Token 失败: {exc}")

            await self._wait_or_stop(delay)

    async def _auto_maintenance_loop(self) -> None:
        """Periodically run token maintenance with the configured actions until stopped."""
        while not self._stop_event.is_set():
            delay = 15
            try:
                if settings.TOKEN_AUTO_MAINTENANCE_ENABLED:
                    delay = max(
                        int(settings.TOKEN_AUTO_MAINTENANCE_INTERVAL),
                        30,
                    )
                    if self._has_enabled_maintenance_action():
                        self._maintenance_warning = None
                        result = await run_token_maintenance(
                            provider=DEFAULT_TOKEN_PROVIDER,
                            remove_duplicates=settings.TOKEN_AUTO_REMOVE_DUPLICATES,
                            run_health_check=settings.TOKEN_AUTO_HEALTH_CHECK,
                            delete_invalid_tokens=settings.TOKEN_AUTO_DELETE_INVALID,
                        )
                        logger.info(
                            "🧹 自动维护完成: dedupe={} checked={} valid={} guest={} invalid={} deleted={}",
                            result.duplicate_removed_count,
                            result.checked_count,
                            result.valid_count,
                            result.guest_count,
                            result.invalid_count,
                            result.deleted_invalid_count,
                        )
                    else:
                        self._log_maintenance_warning_once(
                            "已启用自动维护,但未选择任何维护动作"
                        )
            except asyncio.CancelledError:
                raise
            except RuntimeError as exc:
                # run_token_maintenance raises RuntimeError when a run is in progress.
                logger.info(f"⏭️ 跳过本轮自动维护: {exc}")
            except Exception as exc:
                logger.exception(f"❌ Token 自动维护失败: {exc}")

            await self._wait_or_stop(delay)

    async def _wait_or_stop(self, timeout: int) -> None:
        """Sleep up to ``timeout`` seconds, returning early once stop is signalled."""
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            # Timing out is the normal "interval elapsed" path.
            pass

    def _has_enabled_maintenance_action(self) -> bool:
        """Return True when at least one maintenance action is switched on."""
        return bool(
            settings.TOKEN_AUTO_REMOVE_DUPLICATES
            or settings.TOKEN_AUTO_HEALTH_CHECK
            or settings.TOKEN_AUTO_DELETE_INVALID
        )

    def _log_import_warning_once(self, message: str) -> None:
        """Log an import warning unless it repeats the previous one."""
        if self._import_warning != message:
            self._import_warning = message
            logger.warning(f"⚠️ {message}")

    def _log_maintenance_warning_once(self, message: str) -> None:
        """Log a maintenance warning unless it repeats the previous one."""
        if self._maintenance_warning != message:
            self._maintenance_warning = message
            logger.warning(f"⚠️ {message}")
+
+
# Process-wide scheduler instance, created lazily by the accessor below.
_scheduler: Optional[TokenAutomationScheduler] = None


def get_token_automation_scheduler() -> TokenAutomationScheduler:
    """Return the shared scheduler, creating it on first use."""
    global _scheduler
    scheduler = _scheduler
    if scheduler is None:
        scheduler = _scheduler = TokenAutomationScheduler()
    return scheduler
+
+
async def start_token_automation_scheduler() -> None:
    """Start the shared scheduler, creating it first if needed."""
    scheduler = get_token_automation_scheduler()
    await scheduler.start()
+
+
async def stop_token_automation_scheduler() -> None:
    """Stop the shared scheduler if one was ever created; otherwise do nothing."""
    global _scheduler
    if _scheduler is not None:
        await _scheduler.stop()
diff --git a/app/services/token_dao.py b/app/services/token_dao.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4f1058e9f11bad47df8ccb96ff14e004906e786
--- /dev/null
+++ b/app/services/token_dao.py
@@ -0,0 +1,664 @@
+"""
+Token 数据访问层 (DAO)
+提供 Token 的 CRUD 操作和查询功能
+"""
+import os
+import sqlite3
+from contextlib import asynccontextmanager
+from typing import Any, Dict, List, Optional, Tuple
+
+import aiosqlite
+
+from app.models.token_db import DB_PATH, SQL_CREATE_TABLES
+from app.utils.logger import logger
+
+
+class TokenDAO:
+ """Token 数据访问对象"""
+
+ def __init__(self, db_path: str = DB_PATH):
+ """初始化 DAO"""
+ self.db_path = db_path
+ self._ensure_db_directory()
+
+ def _ensure_db_directory(self):
+ """确保数据库目录存在"""
+ db_dir = os.path.dirname(self.db_path)
+ if db_dir and not os.path.exists(db_dir):
+ os.makedirs(db_dir, exist_ok=True)
+
+ @asynccontextmanager
+ async def get_connection(self):
+ """获取异步数据库连接"""
+ conn = await aiosqlite.connect(self.db_path)
+ conn.row_factory = aiosqlite.Row # 返回字典式结果
+
+ # 启用外键约束(SQLite 默认关闭)
+ await conn.execute("PRAGMA foreign_keys = ON")
+
+ try:
+ yield conn
+ finally:
+ await conn.close()
+
+ def get_sync_connection(self):
+ """获取同步数据库连接(用于初始化)"""
+ conn = sqlite3.connect(self.db_path)
+ # 启用外键约束
+ conn.execute("PRAGMA foreign_keys = ON")
+ return conn
+
+ async def init_database(self):
+ """初始化数据库表结构"""
+ try:
+ # 使用同步连接创建表(避免异步初始化问题)
+ conn = self.get_sync_connection()
+ conn.executescript(SQL_CREATE_TABLES)
+ conn.commit()
+ conn.close()
+ except Exception as e:
+ logger.error(f"❌ Token 数据库初始化失败: {e}")
+ raise
+
+ # ==================== Token CRUD 操作 ====================
+
+ async def add_token(
+ self,
+ provider: str,
+ token: str,
+ token_type: str = "user",
+ priority: int = 0,
+ validate: bool = True
+ ) -> Optional[int]:
+ """
+ 添加新 Token(可选验证)
+
+ Args:
+ provider: 提供商名称
+ token: Token 值
+ token_type: Token 类型(如果 validate=True 将被验证结果覆盖)
+ priority: 优先级
+ validate: 是否验证 Token(仅针对 zai 提供商)
+
+ Returns:
+ token_id 或 None(验证失败或已存在)
+ """
+ try:
+ # 对于 zai 提供商,强制验证 Token
+ if provider == "zai" and validate:
+ from app.utils.token_pool import ZAITokenValidator
+
+ validated_type, is_valid, error_msg = await ZAITokenValidator.validate_token(token)
+
+ # 拒绝 guest token
+ if validated_type == "guest":
+ logger.warning(f"🚫 拒绝添加匿名用户 Token: {token[:20]}... - {error_msg}")
+ return None
+
+ # 拒绝无效 token
+ if not is_valid:
+ logger.warning(f"🚫 Token 验证失败: {token[:20]}... - {error_msg}")
+ return None
+
+ # 使用验证后的类型
+ token_type = validated_type
+
+ async with self.get_connection() as conn:
+ cursor = await conn.execute("""
+ INSERT OR IGNORE INTO tokens (provider, token, token_type, priority)
+ VALUES (?, ?, ?, ?)
+ """, (provider, token, token_type, priority))
+
+ await conn.commit()
+
+ if cursor.lastrowid > 0:
+ # 同时创建统计记录
+ await conn.execute("""
+ INSERT INTO token_stats (token_id)
+ VALUES (?)
+ """, (cursor.lastrowid,))
+ await conn.commit()
+ logger.info(f"✅ 添加 Token: {provider} ({token_type}) - {token[:20]}...")
+ return cursor.lastrowid
+ else:
+ logger.warning(f"⚠️ Token 已存在: {provider} - {token[:20]}...")
+ return None
+ except Exception as e:
+ logger.error(f"❌ 添加 Token 失败: {e}")
+ return None
+
+ async def get_tokens_by_provider(
+ self,
+ provider: str,
+ enabled_only: bool = True,
+ limit: Optional[int] = None,
+ offset: int = 0,
+ ) -> List[Dict]:
+ """
+ 获取指定提供商的所有 Token
+
+ Args:
+ provider: 提供商名称
+ enabled_only: 是否只返回启用的 Token
+ """
+ try:
+ async with self.get_connection() as conn:
+ query = """
+ SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests,
+ ts.last_success_time, ts.last_failure_time
+ FROM tokens t
+ LEFT JOIN token_stats ts ON t.id = ts.token_id
+ WHERE t.provider = ?
+ """
+ params = [provider]
+
+ if enabled_only:
+ query += " AND t.is_enabled = 1"
+
+ query += " ORDER BY t.priority DESC, t.id ASC"
+
+ if limit is not None:
+ query += " LIMIT ? OFFSET ?"
+ params.extend([limit, max(0, offset)])
+
+ cursor = await conn.execute(query, params)
+ rows = await cursor.fetchall()
+
+ return [dict(row) for row in rows]
+ except Exception as e:
+ logger.error(f"❌ 查询 Token 失败: {e}")
+ return []
+
+ async def get_all_tokens(self, enabled_only: bool = False) -> List[Dict]:
+ """获取所有 Token"""
+ try:
+ async with self.get_connection() as conn:
+ query = """
+ SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests,
+ ts.last_success_time, ts.last_failure_time
+ FROM tokens t
+ LEFT JOIN token_stats ts ON t.id = ts.token_id
+ """
+
+ if enabled_only:
+ query += " WHERE t.is_enabled = 1"
+
+ query += " ORDER BY t.provider, t.priority DESC, t.id ASC"
+
+ cursor = await conn.execute(query)
+ rows = await cursor.fetchall()
+
+ return [dict(row) for row in rows]
+ except Exception as e:
+ logger.error(f"❌ 查询所有 Token 失败: {e}")
+ return []
+
+ async def update_token_status(self, token_id: int, is_enabled: bool):
+ """更新 Token 启用状态"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("""
+ UPDATE tokens SET is_enabled = ? WHERE id = ?
+ """, (is_enabled, token_id))
+ await conn.commit()
+ logger.info(f"✅ 更新 Token 状态: id={token_id}, enabled={is_enabled}")
+ except Exception as e:
+ logger.error(f"❌ 更新 Token 状态失败: {e}")
+
+ async def update_token_type(self, token_id: int, token_type: str):
+ """更新 Token 类型"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("""
+ UPDATE tokens SET token_type = ? WHERE id = ?
+ """, (token_type, token_id))
+ await conn.commit()
+ logger.info(f"✅ 更新 Token 类型: id={token_id}, type={token_type}")
+ except Exception as e:
+ logger.error(f"❌ 更新 Token 类型失败: {e}")
+
+ async def delete_token(self, token_id: int):
+ """删除 Token(级联删除统计数据)"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("DELETE FROM tokens WHERE id = ?", (token_id,))
+ await conn.commit()
+ logger.info(f"✅ 删除 Token: id={token_id}")
+ except Exception as e:
+ logger.error(f"❌ 删除 Token 失败: {e}")
+
+ async def delete_tokens_by_ids(self, token_ids: List[int]) -> int:
+ """批量删除 Token(级联删除统计数据)"""
+ if not token_ids:
+ return 0
+
+ try:
+ placeholders = ",".join("?" for _ in token_ids)
+ async with self.get_connection() as conn:
+ await conn.execute(
+ f"DELETE FROM tokens WHERE id IN ({placeholders})",
+ token_ids,
+ )
+ cursor = await conn.execute("SELECT changes()")
+ row = await cursor.fetchone()
+ await conn.commit()
+
+ deleted_count = int(row[0] if row else 0)
+ logger.info(f"✅ 批量删除 Token: {deleted_count} 个")
+ return deleted_count
+ except Exception as e:
+ logger.error(f"❌ 批量删除 Token 失败: {e}")
+ return 0
+
+ async def delete_tokens_by_provider(self, provider: str):
+ """删除指定提供商的所有 Token"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("DELETE FROM tokens WHERE provider = ?", (provider,))
+ await conn.commit()
+ logger.info(f"✅ 删除提供商所有 Token: {provider}")
+ except Exception as e:
+ logger.error(f"❌ 删除提供商 Token 失败: {e}")
+
+ # ==================== Token 统计操作 ====================
+
+ async def record_success(self, token_id: int):
+ """记录 Token 使用成功"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("""
+ UPDATE token_stats
+ SET total_requests = total_requests + 1,
+ successful_requests = successful_requests + 1,
+ last_success_time = CURRENT_TIMESTAMP
+ WHERE token_id = ?
+ """, (token_id,))
+ await conn.commit()
+ except Exception as e:
+ logger.error(f"❌ 记录成功失败: {e}")
+
+ async def record_failure(self, token_id: int):
+ """记录 Token 使用失败"""
+ try:
+ async with self.get_connection() as conn:
+ await conn.execute("""
+ UPDATE token_stats
+ SET total_requests = total_requests + 1,
+ failed_requests = failed_requests + 1,
+ last_failure_time = CURRENT_TIMESTAMP
+ WHERE token_id = ?
+ """, (token_id,))
+ await conn.commit()
+ except Exception as e:
+ logger.error(f"❌ 记录失败失败: {e}")
+
+ async def get_token_stats(self, token_id: int) -> Optional[Dict]:
+ """获取 Token 统计信息"""
+ try:
+ async with self.get_connection() as conn:
+ cursor = await conn.execute("""
+ SELECT * FROM token_stats WHERE token_id = ?
+ """, (token_id,))
+ row = await cursor.fetchone()
+ return dict(row) if row else None
+ except Exception as e:
+ logger.error(f"❌ 获取统计信息失败: {e}")
+ return None
+
+ # ==================== 批量操作 ====================
+
+ async def bulk_add_tokens(
+ self,
+ provider: str,
+ tokens: List[str],
+ token_type: str = "user",
+ validate: bool = True
+ ) -> Tuple[int, int]:
+ """
+ 批量添加 Token(可选验证)
+
+ Args:
+ provider: 提供商名称
+ tokens: Token 列表
+ token_type: Token 类型(如果 validate=True 将被覆盖)
+ validate: 是否验证 Token(仅针对 zai)
+
+ Returns:
+ (成功添加数量, 失败数量)
+ """
+ added_count = 0
+ failed_count = 0
+
+ for token in tokens:
+ if token.strip(): # 过滤空 token
+ token_id = await self.add_token(
+ provider,
+ token.strip(),
+ token_type,
+ validate=validate
+ )
+ if token_id:
+ added_count += 1
+ else:
+ failed_count += 1
+
+ logger.info(f"✅ 批量添加完成: {provider} - 成功 {added_count}/{len(tokens)},失败 {failed_count}")
+ return added_count, failed_count
+
+ async def replace_tokens(self, provider: str, tokens: List[str],
+ token_type: str = "user"):
+ """
+ 替换指定提供商的所有 Token(先删除后添加)
+ """
+ # 删除旧 Token
+ await self.delete_tokens_by_provider(provider)
+
+ # 添加新 Token
+ added_count = await self.bulk_add_tokens(provider, tokens, token_type)
+
+ logger.info(f"✅ 替换 Token 完成: {provider} - {added_count} 个")
+ return added_count
+
+ async def remove_duplicate_tokens(self, provider: Optional[str] = None) -> int:
+ """
+ 删除重复 Token,保留每个 provider/token 组合中排序靠前的一条记录。
+
+ 正常情况下唯一约束会阻止重复数据,这里主要处理历史数据或手工导入异常。
+ """
+ try:
+ tokens = (
+ await self.get_tokens_by_provider(provider, enabled_only=False)
+ if provider
+ else await self.get_all_tokens(enabled_only=False)
+ )
+
+ seen_keys: set[tuple[str, str]] = set()
+ duplicate_ids: list[int] = []
+
+ for token_record in tokens:
+ token_value = str(token_record.get("token") or "").strip()
+ token_provider = str(token_record.get("provider") or "")
+ key = (token_provider, token_value)
+
+ if key in seen_keys:
+ duplicate_ids.append(int(token_record["id"]))
+ continue
+
+ seen_keys.add(key)
+
+ deleted_count = await self.delete_tokens_by_ids(duplicate_ids)
+ if deleted_count > 0:
+ logger.info(f"✅ 已清理重复 Token: {deleted_count} 个")
+ return deleted_count
+ except Exception as e:
+ logger.error(f"❌ 清理重复 Token 失败: {e}")
+ return 0
+
+ # ==================== 实用方法 ====================
+
+ async def get_token_by_value(self, provider: str, token: str) -> Optional[Dict]:
+ """根据 Token 值查询"""
+ try:
+ async with self.get_connection() as conn:
+ cursor = await conn.execute("""
+ SELECT t.*, ts.total_requests, ts.successful_requests, ts.failed_requests
+ FROM tokens t
+ LEFT JOIN token_stats ts ON t.id = ts.token_id
+ WHERE t.provider = ? AND t.token = ?
+ """, (provider, token))
+ row = await cursor.fetchone()
+ return dict(row) if row else None
+ except Exception as e:
+ logger.error(f"❌ 查询 Token 失败: {e}")
+ return None
+
+ async def get_provider_stats(self, provider: str) -> Dict:
+ """获取提供商统计信息"""
+ try:
+ async with self.get_connection() as conn:
+ cursor = await conn.execute("""
+ SELECT
+ COUNT(*) as total_tokens,
+ SUM(CASE WHEN is_enabled = 1 THEN 1 ELSE 0 END) as enabled_tokens,
+ SUM(ts.total_requests) as total_requests,
+ SUM(ts.successful_requests) as successful_requests,
+ SUM(ts.failed_requests) as failed_requests
+ FROM tokens t
+ LEFT JOIN token_stats ts ON t.id = ts.token_id
+ WHERE t.provider = ?
+ """, (provider,))
+ row = await cursor.fetchone()
+ return dict(row) if row else {}
+ except Exception as e:
+ logger.error(f"❌ 获取提供商统计失败: {e}")
+ return {}
+
+ async def get_provider_token_counts(self, provider: str) -> Dict[str, int]:
+ """聚合提供商的 Token 数量与类型分布。"""
+ try:
+ async with self.get_connection() as conn:
+ cursor = await conn.execute(
+ """
+ SELECT
+ COUNT(*) as total_tokens,
+ SUM(CASE WHEN is_enabled = 1 THEN 1 ELSE 0 END) as enabled_tokens,
+ SUM(CASE WHEN token_type = 'user' THEN 1 ELSE 0 END) as user_tokens,
+ SUM(CASE WHEN token_type = 'guest' THEN 1 ELSE 0 END) as guest_tokens,
+ SUM(CASE WHEN token_type = 'unknown' THEN 1 ELSE 0 END) as unknown_tokens
+ FROM tokens
+ WHERE provider = ?
+ """,
+ (provider,),
+ )
+ row = await cursor.fetchone()
+
+ if not row:
+ return {
+ "total_tokens": 0,
+ "enabled_tokens": 0,
+ "user_tokens": 0,
+ "guest_tokens": 0,
+ "unknown_tokens": 0,
+ }
+
+ return {
+ "total_tokens": int(row["total_tokens"] or 0),
+ "enabled_tokens": int(row["enabled_tokens"] or 0),
+ "user_tokens": int(row["user_tokens"] or 0),
+ "guest_tokens": int(row["guest_tokens"] or 0),
+ "unknown_tokens": int(row["unknown_tokens"] or 0),
+ }
+ except Exception as e:
+ logger.error(f"❌ 获取 Token 数量统计失败: {e}")
+ return {
+ "total_tokens": 0,
+ "enabled_tokens": 0,
+ "user_tokens": 0,
+ "guest_tokens": 0,
+ "unknown_tokens": 0,
+ }
+
+ async def count_tokens_by_provider(
+ self,
+ provider: str,
+ enabled_only: bool = False,
+ ) -> int:
+ """统计提供商下的 Token 总数。"""
+ try:
+ async with self.get_connection() as conn:
+ query = "SELECT COUNT(*) AS total_count FROM tokens WHERE provider = ?"
+ params: List[object] = [provider]
+ if enabled_only:
+ query += " AND is_enabled = 1"
+
+ cursor = await conn.execute(query, params)
+ row = await cursor.fetchone()
+
+ return int(row["total_count"] or 0) if row else 0
+ except Exception as e:
+ logger.error(f"❌ 统计 Token 总数失败: {e}")
+ return 0
+
+ # ==================== Token 验证操作 ====================
+
+ async def validate_and_update_token(self, token_id: int) -> bool:
+ """
+ 验证单个 Token 并更新其类型
+
+ Args:
+ token_id: Token 数据库 ID
+
+ Returns:
+ 是否为有效的认证用户 Token
+ """
+ try:
+ # 获取 Token 信息
+ async with self.get_connection() as conn:
+ cursor = await conn.execute("""
+ SELECT provider, token FROM tokens WHERE id = ?
+ """, (token_id,))
+ row = await cursor.fetchone()
+
+ if not row:
+ logger.error(f"❌ Token ID {token_id} 不存在")
+ return False
+
+ provider = row["provider"]
+ token = row["token"]
+
+ if provider != "zai":
+ logger.info(f"⏭️ 跳过非 zai 提供商的 Token 验证: {provider}")
+ return True
+
+ # 验证 Token
+ from app.utils.token_pool import ZAITokenValidator
+
+ token_type, is_valid, error_msg = await ZAITokenValidator.validate_token(token)
+
+ # 更新 Token 类型
+ await self.update_token_type(token_id, token_type)
+
+ if not is_valid:
+ logger.warning(f"⚠️ Token 验证失败: id={token_id}, type={token_type}, error={error_msg}")
+
+ return is_valid
+
+ except Exception as e:
+ logger.error(f"❌ 验证 Token 失败: {e}")
+ return False
+
+ async def validate_tokens_detailed(self, provider: str = "zai") -> Dict[str, Any]:
+ """
+ 批量验证所有 Token,并返回详细结果。
+
+ Returns:
+ {
+ "checked": 数量,
+ "valid": 数量,
+ "guest": 数量,
+ "invalid": 数量,
+ "invalid_token_ids": [id, ...],
+ }
+ """
+ try:
+ tokens = await self.get_tokens_by_provider(provider, enabled_only=False)
+
+ if not tokens:
+ logger.warning(f"⚠️ 没有需要验证的 {provider} Token")
+ return {
+ "checked": 0,
+ "valid": 0,
+ "guest": 0,
+ "invalid": 0,
+ "invalid_token_ids": [],
+ }
+
+ logger.info(f"🔍 开始批量验证 {len(tokens)} 个 {provider} Token...")
+
+ from app.utils.token_pool import ZAITokenValidator
+
+ stats: Dict[str, Any] = {
+ "checked": len(tokens),
+ "valid": 0,
+ "guest": 0,
+ "invalid": 0,
+ "invalid_token_ids": [],
+ }
+
+ for token_record in tokens:
+ token_id = int(token_record["id"])
+ token = str(token_record["token"])
+
+ token_type, is_valid, error_msg = await ZAITokenValidator.validate_token(
+ token
+ )
+ await self.update_token_type(token_id, token_type)
+
+ if token_type == "user" and is_valid:
+ stats["valid"] += 1
+ elif token_type == "guest":
+ stats["guest"] += 1
+ stats["invalid_token_ids"].append(token_id)
+ else:
+ stats["invalid"] += 1
+ stats["invalid_token_ids"].append(token_id)
+ if error_msg:
+ logger.warning(
+ "⚠️ Token 验证失败: id={}, type={}, error={}",
+ token_id,
+ token_type,
+ error_msg,
+ )
+
+ logger.info(
+ "✅ 批量验证完成: 有效 {}, 匿名 {}, 无效 {}",
+ stats["valid"],
+ stats["guest"],
+ stats["invalid"],
+ )
+ return stats
+
+ except Exception as e:
+ logger.error(f"❌ 批量验证失败: {e}")
+ return {
+ "checked": 0,
+ "valid": 0,
+ "guest": 0,
+ "invalid": 0,
+ "invalid_token_ids": [],
+ }
+
+ async def validate_all_tokens(self, provider: str = "zai") -> Dict[str, int]:
+ """
+ 批量验证所有 Token
+
+ Args:
+ provider: 提供商名称(默认 zai)
+
+ Returns:
+ 统计结果 {"valid": 数量, "guest": 数量, "invalid": 数量}
+ """
+ stats = await self.validate_tokens_detailed(provider)
+ return {
+ "valid": int(stats.get("valid", 0) or 0),
+ "guest": int(stats.get("guest", 0) or 0),
+ "invalid": int(stats.get("invalid", 0) or 0),
+ }
+
+
# Module-level singleton
_token_dao: Optional[TokenDAO] = None


def get_token_dao() -> TokenDAO:
    """Return the shared TokenDAO, creating it on first use."""
    global _token_dao
    dao = _token_dao
    if dao is None:
        dao = _token_dao = TokenDAO()
    return dao
+
+
async def init_token_database():
    """Initialise the token database schema through the shared TokenDAO."""
    await get_token_dao().init_database()
diff --git a/app/services/token_importer.py b/app/services/token_importer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff83ecaed37196fc03bdbe33d41c46b770fab0c
--- /dev/null
+++ b/app/services/token_importer.py
@@ -0,0 +1,138 @@
+"""本地目录 token 导入服务。"""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+from app.services.token_dao import TokenDAO, get_token_dao
+from app.utils.logger import logger
+
+
@dataclass(frozen=True)
class TokenImportSummary:
    """Immutable counters describing one directory import run."""

    # Resolved directory that was scanned.
    source_dir: str
    # Number of JSON files found.
    scanned_files: int
    # Outcome counters, one per file category.
    imported_count: int
    duplicate_count: int
    invalid_json_count: int
    missing_token_count: int
    invalid_token_count: int

    @property
    def failed_count(self) -> int:
        """Total of every non-imported outcome (duplicates included)."""
        return sum(
            (
                self.duplicate_count,
                self.invalid_json_count,
                self.missing_token_count,
                self.invalid_token_count,
            )
        )
+
+
def _load_token_payload(file_path: Path) -> dict:
    """
    Read and parse one JSON token file.

    Args:
        file_path: Path of the JSON file.

    Returns:
        The parsed JSON value. The caller checks that it is a dict.

    Raises:
        ValueError: The file cannot be read or decoded, or is not valid JSON.
    """
    try:
        # utf-8-sig strips a leading BOM, which json.loads would reject.
        text = file_path.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeDecodeError) as exc:
        # Report unreadable files like bad JSON, so callers that catch
        # ValueError skip the file and keep going.
        raise ValueError(f"文件读取失败: {exc}") from exc

    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"JSON 解析失败: {exc}") from exc
+
+
async def import_tokens_from_directory(
    source_dir: str | Path,
    *,
    provider: str = "zai",
    validate: bool = True,
    dao: Optional[TokenDAO] = None,
) -> TokenImportSummary:
    """
    Import tokens from the JSON files under a local directory.

    The directory is searched recursively for ``*.json``. Each file must be
    a JSON object with at least a ``token`` field; ``email`` is used only
    in log messages.

    Args:
        source_dir: Directory to scan.
        provider: Provider the tokens are stored under.
        validate: Passed to ``TokenDAO.add_token``.
        dao: DAO to use; ``get_token_dao()`` when omitted.

    Returns:
        Per-category counters for the run. Files that cannot be read are
        counted under ``invalid_json_count``.

    Raises:
        FileNotFoundError: ``source_dir`` does not exist.
        NotADirectoryError: ``source_dir`` is not a directory.
    """
    source_path = Path(source_dir).expanduser().resolve()
    if not source_path.exists():
        raise FileNotFoundError(f"导入目录不存在: {source_path}")
    if not source_path.is_dir():
        raise NotADirectoryError(f"导入路径不是目录: {source_path}")

    token_dao = dao or get_token_dao()
    token_files = sorted(source_path.rglob("*.json"))
    seen_tokens: set[str] = set()
    imported_count = 0
    duplicate_count = 0
    invalid_json_count = 0
    missing_token_count = 0
    invalid_token_count = 0

    for file_path in token_files:
        try:
            payload = _load_token_payload(file_path)
        except ValueError as exc:
            invalid_json_count += 1
            logger.warning(f"⚠️ 跳过无效 JSON 文件: {file_path} - {exc}")
            continue
        except OSError as exc:
            # An unreadable entry (permissions, a directory named *.json,
            # a file removed mid-scan) must not abort the whole batch.
            invalid_json_count += 1
            logger.warning(f"⚠️ 跳过无法读取的文件: {file_path} - {exc}")
            continue

        if not isinstance(payload, dict):
            invalid_json_count += 1
            logger.warning(f"⚠️ 跳过非对象 JSON 文件: {file_path}")
            continue

        token = str(payload.get("token") or "").strip()
        email = str(payload.get("email") or "").strip()
        if not token:
            missing_token_count += 1
            logger.warning(f"⚠️ 文件缺少 token 字段: {file_path}")
            continue

        # Duplicate within this batch.
        if token in seen_tokens:
            duplicate_count += 1
            logger.info(f"↩️ 跳过本批次重复 Token: {file_path.name}")
            continue
        seen_tokens.add(token)

        # Duplicate of a token already in the database.
        existing = await token_dao.get_token_by_value(provider, token)
        if existing is not None:
            duplicate_count += 1
            logger.info(
                "↩️ Token 已存在,跳过导入: {} ({})",
                file_path.name,
                email or "unknown",
            )
            continue

        token_id = await token_dao.add_token(
            provider=provider,
            token=token,
            token_type="user",
            validate=validate,
        )
        if token_id is None:
            invalid_token_count += 1
            logger.warning(f"⚠️ Token 导入失败: {file_path.name} ({email or 'unknown'})")
            continue

        imported_count += 1
        logger.info(f"✅ 已导入 Token: {file_path.name} ({email or 'unknown'})")

    summary = TokenImportSummary(
        source_dir=str(source_path),
        scanned_files=len(token_files),
        imported_count=imported_count,
        duplicate_count=duplicate_count,
        invalid_json_count=invalid_json_count,
        missing_token_count=missing_token_count,
        invalid_token_count=invalid_token_count,
    )
    logger.info(
        "✅ Token 目录导入完成: "
        "scanned={}, imported={}, duplicate={}, invalid_json={}, "
        "missing_token={}, invalid_token={}",
        summary.scanned_files,
        summary.imported_count,
        summary.duplicate_count,
        summary.invalid_json_count,
        summary.missing_token_count,
        summary.invalid_token_count,
    )
    return summary
diff --git a/app/templates/base.html b/app/templates/base.html
new file mode 100644
index 0000000000000000000000000000000000000000..e7871f1a7f47b467650589c989c806617274fe42
--- /dev/null
+++ b/app/templates/base.html
@@ -0,0 +1,201 @@
+
+
+
+
+
+ {% block title %}管理后台{% endblock %} - API 控制台
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {% block extra_head %}{% endblock %}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {% block content %}{% endblock %}
+
+
+
+
+
+ {% block extra_scripts %}{% endblock %}
+
+
diff --git a/app/templates/components/recent_logs.html b/app/templates/components/recent_logs.html
new file mode 100644
index 0000000000000000000000000000000000000000..60929a0c1c187dbad29da0fadce0c13c2987dede
--- /dev/null
+++ b/app/templates/components/recent_logs.html
@@ -0,0 +1,128 @@
+
+
+
+
+
+ {% if logs %}
+
+
+
+ | 时间 |
+ 请求 |
+ 标记 |
+ 输入 / 输出 |
+ 缓存创建 / 命中 |
+ 用时 / 首字 |
+ 状态 |
+
+
+
+ {% for log in logs %}
+
+ |
+ {{ log.timestamp }}
+ |
+
+
+
+ {{ log.model }}
+
+
+ {{ log.endpoint }}
+
+
+ {% if log.error_message %}
+
+ {{ log.error_message }}
+
+ {% endif %}
+ |
+
+
+
+ {{ log.client_name }}
+
+
+ {{ log.protocol_display }}
+
+ {% if log.source_display %}
+
+ {{ log.source_display }}
+
+ {% endif %}
+ {% if log.provider_display %}
+
+ {{ log.provider_display }}
+
+ {% endif %}
+
+ |
+
+ 输入 {{ log.input_tokens }}
+ /
+ 输出 {{ log.output_tokens }}
+ |
+
+ 创建 {{ log.cache_creation_tokens }}
+ /
+ 命中 {{ log.cache_read_tokens }}
+ |
+
+ 用时 {{ log.duration_display }}
+ /
+ 首字 {{ log.first_token_display }}
+ |
+
+
+ {{ "成功" if log.success else "失败" }}
+
+
+ HTTP {{ log.status_code }}
+
+ |
+
+ {% endfor %}
+
+
+ {% else %}
+
+ {% endif %}
+
+
+
+
+ {% if page.total_items > 0 %}
+ 显示第 {{ page.start_item }} - {{ page.end_item }} 条,共 {{ page.total_items }} 条
+ {% else %}
+ 暂无日志数据
+ {% endif %}
+
+
+
+
+ 第 {{ page.current_page }} / {{ page.total_pages }} 页
+
+
+
+
+
diff --git a/app/templates/components/token_list.html b/app/templates/components/token_list.html
new file mode 100644
index 0000000000000000000000000000000000000000..50cd3e1377472a54cb8ce2ff4c12e88619223cda
--- /dev/null
+++ b/app/templates/components/token_list.html
@@ -0,0 +1,114 @@
+
+
+
+
+
+ {% if tokens %}
+
+
+
+ | ID |
+ Token |
+ 类型 |
+ 健康度 |
+ 状态 |
+ 使用统计 |
+ 创建时间 |
+ 操作 |
+
+
+
+ {% for token in tokens %}
+ {% include "components/token_row.html" %}
+ {% endfor %}
+
+
+ {% else %}
+
+
+
暂无 Token
+
点击右上角"添加 Token"按钮开始添加
+
+ {% endif %}
+
+
+
+
+ {% if page.total_items > 0 %}
+ 显示第 {{ page.start_item }} - {{ page.end_item }} 条,共 {{ page.total_items }} 个 Token
+ {% else %}
+ 暂无 Token 数据
+ {% endif %}
+
+
+
+
+ 第 {{ page.current_page }} / {{ page.total_pages }} 页
+
+
+
+
+
+
+
+
+
+
+
diff --git a/app/templates/components/token_pool.html b/app/templates/components/token_pool.html
new file mode 100644
index 0000000000000000000000000000000000000000..8f37941675962f88565d318af215271ef857b712
--- /dev/null
+++ b/app/templates/components/token_pool.html
@@ -0,0 +1,40 @@
+
+
+ {% for token in tokens %}
+
+
+ Token #{{ token.index }}
+
+ {{ token.status }}
+
+
+
+
+ {{ token.key }}
+
+
类型:
+ {% if token.token_type == 'user' %}
+ 认证用户
+ {% elif token.token_type == 'guest' %}
+ 匿名用户
+ {% else %}
+ 未知
+ {% endif %}
+
+
成功率: {{ token.success_rate }}
+
失败次数: {{ token.failure_count }}
+
最后使用: {{ token.last_used }}
+
+
+ {% endfor %}
+
+ {% if not tokens %}
+
+
+
暂无 Token 配置
+
请在配置管理页面添加 Token
+
+ {% endif %}
+
diff --git a/app/templates/components/token_row.html b/app/templates/components/token_row.html
new file mode 100644
index 0000000000000000000000000000000000000000..32df001d19c3bb1606ad27529729d0d5dc9bc2b5
--- /dev/null
+++ b/app/templates/components/token_row.html
@@ -0,0 +1,154 @@
+
+{% set success_rate = (token.successful_requests / token.total_requests * 100) if token.total_requests else 0 %}
+{% set is_healthy = (token.token_type == 'user' and token.is_enabled and (success_rate >= 50 or token.total_requests <= 3)) %}
+
+ |
+ {{ token.id }}
+ |
+
+
+
+ {{ token.token[:30] }}...
+
+
+
+ |
+
+ {% if token.token_type == 'user' %}
+
+
+ 认证用户
+
+ {% elif token.token_type == 'guest' %}
+
+
+ 匿名用户
+
+ {% else %}
+
+
+ 未知
+
+ {% endif %}
+ |
+
+
+
+ {% if is_healthy %}
+
+ {% elif token.token_type == 'guest' %}
+
+ {% elif not token.is_enabled %}
+
+ {% else %}
+
+ {% endif %}
+
+ |
+
+
+ |
+
+ {% if token.total_requests %}
+
+
+ 成功:
+ {{ token.successful_requests }}
+
+
+ 失败:
+ {{ token.failed_requests }}
+
+
+ 成功率:
+
+ {{ "%.1f"|format(success_rate) }}%
+
+
+
+
+
+ {% else %}
+ 未使用
+ {% endif %}
+ |
+
+
+ {{ token.created_at[:10] if token.created_at else 'N/A' }}
+ {{ token.created_at[11:19] if token.created_at else '' }}
+
+ |
+
+
+
+
+
+
+
+ |
+
diff --git a/app/templates/components/token_stats.html b/app/templates/components/token_stats.html
new file mode 100644
index 0000000000000000000000000000000000000000..1d1bc4e05ad86079d554b8667dcc890ebcec2ed0
--- /dev/null
+++ b/app/templates/components/token_stats.html
@@ -0,0 +1,125 @@
+
+
+
+
+
+
+
+
+
+ - Token 总数
+ - {{ stats.total_tokens }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ - 已启用
+ -
+
{{ stats.enabled_tokens }}
+ {% if stats.total_tokens > 0 %}
+
+ {{ "%.0f"|format(stats.enabled_tokens / stats.total_tokens * 100) }}%
+
+ {% endif %}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ - 认证用户
+ -
+
{{ stats.user_tokens }}
+ {% if stats.guest_tokens > 0 %}
+
+
+ {{ stats.guest_tokens }} 个匿名
+
+ {% endif %}
+
+
+
+
+
+
+
+
+
+
+
+
+ {% if stats.total_requests > 0 %}
+ {% set success_rate = (stats.successful_requests / stats.total_requests * 100) %}
+ {% if success_rate >= 80 %}
+
+ {% elif success_rate >= 50 %}
+
+ {% else %}
+
+ {% endif %}
+ {% else %}
+
+ {% endif %}
+
+
+
+ - 总成功率
+ -
+ {% if stats.total_requests > 0 %}
+ {% set success_rate = (stats.successful_requests / stats.total_requests * 100) %}
+
+ {{ "%.1f"|format(success_rate) }}%
+
+
+ {{ stats.successful_requests }} / {{ stats.total_requests }} 请求
+
+ {% else %}
+ N/A
+ 暂无请求
+ {% endif %}
+
+
+
+
+
+
+
diff --git a/app/templates/config.html b/app/templates/config.html
new file mode 100644
index 0000000000000000000000000000000000000000..6c39091e3b388c0066c68be6a24a65844ff49d03
--- /dev/null
+++ b/app/templates/config.html
@@ -0,0 +1,344 @@
+{% extends "base.html" %}
+
+{% block title %}配置管理{% endblock %}
+
+{% block extra_head %}
+
+{% endblock %}
+
+{% macro section_link(section) -%}
+
+{%- endmacro %}
+
+{% macro render_field(field) -%}
+
+ {% if field.value_type == 'bool' %}
+
+ {% else %}
+
+
+
+ {% if field.sensitive %}
+
+ {% endif %}
+
+
{{ field.description }}
+
+
+ {% endif %}
+
+
+
+ {{ field.source_label }}
+
+ {% if field.restart_required %}
+
+ 需重启
+
+ {% endif %}
+ {% if field.sensitive %}
+
+ 敏感字段
+
+ {% endif %}
+
+
+{%- endmacro %}
+
+{% block content %}
+
+
+
+
+
+
Admin Config Center
+
集中管理运行参数,并支持直接编辑 `.env` 源文件
+
+ 结构化表单适合日常操作,源文件模式适合批量调整、复制完整配置或保留注释。两种模式都会在保存后立即热重载。
+
+
+
+
+
+
+
+
+
+
+
+
+
受管字段
+
{{ overview.total_fields }}
+
{{ overview.total_sections }} 个分组
+
+
+
.env 覆写
+
{{ overview.overridden_fields }}
+
{{ overview.default_fields }} 个字段仍在使用默认值
+
+
+
敏感字段
+
{{ overview.sensitive_fields }}
+
{{ overview.restart_required_fields }} 个字段修改后建议重启
+
+
+
源文件状态
+
{{ '.env 已存在' if overview.env_exists else '.env 尚未创建' }}
+
{{ overview.env_line_count }} 行,{{ '.env.example 可用' if overview.example_exists else '缺少 .env.example' }}
+
+
+
+
+
+
+
+ {% if not overview.env_exists %}
+
+ 当前工作目录中尚未找到 `.env` 文件。你可以直接保存表单或源文件,系统会自动创建它。
+
+ {% endif %}
+
+
+
+{% endblock %}
+
+{% block extra_scripts %}
+
+{% endblock %}
diff --git a/app/templates/index.html b/app/templates/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..0292cadc62063d28d24e34246e550fdfb54f0f37
--- /dev/null
+++ b/app/templates/index.html
@@ -0,0 +1,588 @@
+{% extends "base.html" %}
+
+{% block title %}仪表盘{% endblock %}
+
+{% block content %}
+
+
+
+
+
+
Usage Dashboard
+
查看请求消耗、缓存效果、延迟表现和使用趋势
+
+ 统计来源于请求日志数据库,覆盖输入输出 Token、缓存创建与命中、成功率、平均延迟,以及最近 24 小时、7 天、30 天的使用趋势。
+
+
+
+
Last Update
+
{{ current_time }}
+
运行时间 {{ stats.uptime }}
+
+
+
+
+
+
+
+
+
+
总请求数
+
{{ stats.total_requests }}
+
+
+
+
+
+
成功 {{ stats.successful_requests }} / 失败 {{ stats.failed_requests }}
+
+
+
+
+
+
总消耗 Token 数
+
{{ stats.total_consumed_tokens_display }}
+
+
+
+
+
+
累计 {{ stats.total_consumed_tokens }} Tokens
+
+
+
+
+
+
缓存 Token
+
{{ stats.total_cache_tokens_display }}
+
+
+
+
+
+
创建 {{ stats.cache_creation_tokens }} / 命中 {{ stats.cache_read_tokens }}
+
+
+
+
+
+
成功率
+
{{ stats.success_rate }}%
+
+
+
+
+
+
图表支持切换 24 小时 / 7 天 / 30 天
+
+
+
+
+
+
输入 Token
+
{{ stats.input_tokens_display }}
+
+
+
+
+
+
累计 {{ stats.input_tokens }} Tokens
+
+
+
+
+
+
输出 Token
+
{{ stats.output_tokens_display }}
+
+
+
+
+
+
累计 {{ stats.output_tokens }} Tokens
+
+
+
+
+
+
平均延迟
+
{{ "%.2f"|format(stats.average_latency) }}s
+
+
+
+
+
+
平均首字延迟 {{ "%.2f"|format(stats.average_first_token_latency) }}s
+
+
+
+
+
+
Token 池健康度
+
{{ stats.healthy_tokens }}/{{ stats.pool_total_tokens }}
+
+
+
+
+
+
可用 {{ stats.available_tokens }} / 已启用 {{ stats.enabled_tokens }} / 认证 {{ stats.user_tokens }}
+
+
+
+
+
+
+
+
使用趋势图
+
+ 最近 7 天按天聚合的请求量、输入输出与缓存变化。
+
+
+
+ {% set trend_window_options = trend_windows if trend_windows is defined and trend_windows else [
+ {'key': '24h', 'label': '24 小时'},
+ {'key': '7d', 'label': '7 天'},
+ {'key': '30d', 'label': '30 天'}
+ ] %}
+ {% for option in trend_window_options %}
+
+ {% endfor %}
+
+
+
+ 蓝柱: 请求量
+ 紫线: 输入
+ 红线: 输出
+ 绿线: 缓存创建
+ 黄线: 缓存命中
+
+
+
+
+
+
+
+
+
+ 缓存创建 / 命中
+ 按请求次数和 Token 数量查看缓存是否真的生效。
+
+
+
缓存创建
+
{{ stats.cache_creation_requests }}
+
共创建 {{ stats.cache_creation_tokens }} Tokens
+
+
+
缓存命中
+
{{ stats.cache_hit_requests }}
+
共命中 {{ stats.cache_read_tokens }} Tokens
+
+
+
+
+
+ 输入 / 输出画像
+ 对比 Prompt 与 Completion 的消耗分布。
+ {% set usage_total = stats.input_tokens + stats.output_tokens %}
+ {% set input_ratio = (stats.input_tokens / usage_total * 100) if usage_total > 0 else 0 %}
+ {% set output_ratio = (stats.output_tokens / usage_total * 100) if usage_total > 0 else 0 %}
+
+
+
+ 输入 Token
+ {{ "%.1f"|format(input_ratio) }}%
+
+
+
{{ stats.input_tokens }} Tokens
+
+
+
+ 输出 Token
+ {{ "%.1f"|format(output_ratio) }}%
+
+
+
{{ stats.output_tokens }} Tokens
+
+
+
+
+
+
+
+
+
最近请求日志
+
+
+
+
+
+
+
+{% endblock %}
+
+{% block extra_scripts %}
+
+{% endblock %}
diff --git a/app/templates/login.html b/app/templates/login.html
new file mode 100644
index 0000000000000000000000000000000000000000..f0defb7746297faf1d4f6dddb0a694da9b494260
--- /dev/null
+++ b/app/templates/login.html
@@ -0,0 +1,143 @@
+
+
+
+
+
+ 登录 - API 控制台
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 默认密码:admin123(请在 .env 中修改 ADMIN_PASSWORD)
+
+
+
+
+
+
+
diff --git a/app/templates/logs.html b/app/templates/logs.html
new file mode 100644
index 0000000000000000000000000000000000000000..2316039dea1b6127143fc23a335a194a3c1509ef
--- /dev/null
+++ b/app/templates/logs.html
@@ -0,0 +1,59 @@
+{% extends "base.html" %}
+
+{% block title %}实时日志{% endblock %}
+
+{% block content %}
+
+
+
+
实时日志
+
滚动查看服务当前输出的最新日志
+
+
+
+
+
+
+
+
+{% endblock %}
+
+{% block extra_scripts %}
+
+{% endblock %}
diff --git a/app/templates/tokens.html b/app/templates/tokens.html
new file mode 100644
index 0000000000000000000000000000000000000000..b570e990a92d12d1b9c09ec09c709e26625ae2c4
--- /dev/null
+++ b/app/templates/tokens.html
@@ -0,0 +1,487 @@
+{% extends "base.html" %}
+
+{% block title %}Token 管理{% endblock %}
+
+{% block content %}
+
+
+
+
+
Token 管理
+
管理和维护当前服务使用的 Token
+
+
+
+
+
+
+
+
+
+
+
+
+
目录导入策略
+
+ 配置入口已迁移到配置管理页。这里仅展示当前策略,并允许立即执行一次导入。
+
+
+
+ {{ '定时已开启' if automation.import_enabled else '定时已关闭' }}
+
+
+
+ {% if automation.has_import_source_dir %}
+
+ 手动导入会复用当前配置的目录和验证逻辑,重复 Token 会自动跳过。
+
+ {% else %}
+
+ 还没有配置导入目录,无法执行手动导入。请先到配置管理页设置 `TOKEN_AUTO_IMPORT_SOURCE_DIR`。
+
+ {% endif %}
+
+
+
+
- Token 目录
+ -
+ {{ automation.import_source_dir or '未配置' }}
+
+
+
+
+
- 扫描间隔
+ - {{ automation.import_interval }} 秒
+
+
+
- 配置位置
+ - 配置管理 / Token 池策略
+
+
+
+
+
+
+
+
+
+
+
自动维护策略
+
+ 维护动作和定时间隔统一在配置管理页设置。这里仅执行当前已配置的维护策略。
+
+
+
+ {{ '定时已开启' if automation.maintenance_enabled else '定时已关闭' }}
+
+
+
+ {% if automation.has_maintenance_actions %}
+
+ 手动维护会按当前配置顺序执行去重、测活和失效清理,不再在本页单独维护另一套选项。
+
+ {% else %}
+
+ 当前没有配置任何维护动作。请先到配置管理页勾选至少一个维护动作。
+
+ {% endif %}
+
+
+
+
+
- 维护间隔
+ - {{ automation.maintenance_interval }} 秒
+
+
+
- 配置位置
+ - 配置管理 / Token 池策略
+
+
+
+
- 当前维护动作
+ -
+ {% if automation.maintenance_actions %}
+ {% for action in automation.maintenance_actions %}
+ {{ action }}
+ {% endfor %}
+ {% else %}
+ 未配置
+ {% endif %}
+
+
+
+
+
+
+
+
+
+
+
+
+ Token 列表
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Token 验证:添加时将自动验证 Token 有效性,
+ 匿名用户 Token (guest) 将被拒绝。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
支持格式:每行一个 Token,或使用逗号分隔
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 将验证当前通道的所有 Token 的有效性。
+
此操作可能需要较长时间,请耐心等待。
+
+
+
+
+
+
+
+
验证内容:
+
+ - 检查 Token 是否有效
+ - 识别 Token 类型(认证用户 / 匿名用户)
+ - 更新数据库中的 Token 类型
+ - 匿名用户 Token 将被标记为不健康
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+{% endblock %}
+
+{% block extra_scripts %}
+
+{% endblock %}
diff --git a/app/utils/__init__.py b/app/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4a7eb51af06f2d96bdd7bcce3c24fdd9b2d469
--- /dev/null
+++ b/app/utils/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from app.utils import reload_config, logger
+
+__all__ = ["reload_config", "logger"]
diff --git a/app/utils/env_file.py b/app/utils/env_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5fcc0cd9b9be69be9610afa6a4232922f0dceeb
--- /dev/null
+++ b/app/utils/env_file.py
@@ -0,0 +1,59 @@
+"""Helpers for updating .env files without dropping unrelated settings."""
+
+from __future__ import annotations
+
+import re
+from pathlib import Path
+from typing import Mapping
+
+# Captures the variable name of a `NAME=...` assignment line (leading whitespace allowed).
+# Lines that do not match (comments, blanks, `export NAME=...`) are skipped by update_env_file.
+_ENV_KEY_PATTERN = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*=")
+
+
+def _serialize_env_value(value: object) -> str:
+    """Render *value* as the right-hand side of a single-line ``.env`` assignment.
+
+    Args:
+        value: Any object. Booleans become ``true``/``false``, ``None`` becomes an
+            empty string, everything else is converted with ``str()``.
+
+    Returns:
+        The text to place after ``KEY=``. Plain values are returned bare; values
+        containing whitespace, ``#``, quotes or backslashes are quoted.
+    """
+    if isinstance(value, bool):
+        return "true" if value else "false"
+
+    text = "" if value is None else str(value)
+    if not text:
+        return ""
+
+    needs_quoting = any(char.isspace() or char in "#\"\\'" for char in text)
+    if not needs_quoting:
+        return text
+
+    # Single quotes are only safe for fully literal text. dotenv-style parsers still
+    # treat a backslash as an escape introducer inside single quotes, and a raw line
+    # break would spread one assignment over several physical lines, which the
+    # line-based updater in this module cannot handle.
+    if not any(char in text for char in "'\\\n\r"):
+        return f"'{text}'"
+
+    # Escape the backslash first so the escapes added afterwards are not doubled.
+    escaped = (
+        text.replace("\\", "\\\\")
+        .replace('"', '\\"')
+        .replace("\n", "\\n")
+        .replace("\r", "\\r")
+    )
+    return f'"{escaped}"'
+
+
+def update_env_file(
+    updates: Mapping[str, object],
+    env_path: str | Path = ".env",
+) -> None:
+    """Update selected keys inside a .env file while preserving other lines.
+
+    Every existing ``KEY=...`` line whose key is in *updates* is rewritten in place,
+    including repeated definitions of the same key, so a stale duplicate further
+    down the file cannot override the new value when the file is loaded. Keys that
+    do not occur in the file are appended at the end, separated from existing
+    content by one blank line. Comments and unrelated lines are kept as they are.
+
+    Args:
+        updates: Mapping of variable name to new value; values are serialized with
+            ``_serialize_env_value``.
+        env_path: Location of the .env file. The file is created when it does not
+            exist.
+    """
+    path = Path(env_path)
+    lines = path.read_text(encoding="utf-8").splitlines() if path.exists() else []
+    serialized = {key: _serialize_env_value(value) for key, value in updates.items()}
+    written = set()
+
+    for index, line in enumerate(lines):
+        match = _ENV_KEY_PATTERN.match(line)
+        if not match:
+            continue
+
+        key = match.group(1)
+        if key not in serialized:
+            continue
+
+        # Rewrite every occurrence, not just the first one.
+        lines[index] = f"{key}={serialized[key]}"
+        written.add(key)
+
+    missing = [key for key in serialized if key not in written]
+    if missing:
+        if lines and lines[-1].strip():
+            lines.append("")
+        lines.extend(f"{key}={serialized[key]}" for key in missing)
+
+    # Normalize to exactly one trailing newline (or an empty file).
+    content = "\n".join(lines).rstrip()
+    path.write_text(f"{content}\n" if content else "", encoding="utf-8")
diff --git a/app/utils/fe_version.py b/app/utils/fe_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..de28498b353b533457efd94259525dde47a76cc3
--- /dev/null
+++ b/app/utils/fe_version.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Utility helpers for resolving the latest X-FE-Version value from chat.z.ai.
+
+The upstream service embeds the current front-end release identifier inside
+its landing page static asset URLs (e.g. `prod-fe-1.0.107`). The helpers in
+this module fetch the landing page, extract the version string, and cache it
+with a configurable TTL so the expensive network fetch only happens when
+necessary.
+"""
+
+from __future__ import annotations
+
+import re
+import time
+from typing import Optional
+
+import httpx
+
+from app.utils.logger import get_logger
+from app.utils.user_agent import get_random_user_agent
+
+# Base URL to probe for the version string.
+FE_VERSION_SOURCE_URL = "https://chat.z.ai"
+
+# Cache TTL in seconds (default: 30 minutes).
+CACHE_TTL_SECONDS = 1800
+
+# Module-level logger obtained from the project's logging helper.
+_logger = get_logger()
+# Matches release tags of the form `prod-fe-<major>.<minor>.<patch>`.
+_version_pattern = re.compile(r"prod-fe-\d+\.\d+\.\d+")
+
+# Module-level cache: the last resolved version and the time.time() at which it was stored.
+# An empty string / 0.0 means nothing has been cached yet.
+_cached_version: str = ""
+_cached_at: float = 0.0
+
+
+def _extract_version(page_content: str) -> Optional[str]:
+    """Extract the newest ``prod-fe-X.Y.Z`` tag from the page content.
+
+    Args:
+        page_content: Raw HTML of the landing page.
+
+    Returns:
+        The matching tag with the highest numeric version, or ``None`` when the
+        content is empty or contains no tag.
+    """
+    if not page_content:
+        return None
+
+    matches = _version_pattern.findall(page_content)
+    if not matches:
+        return None
+
+    def _numeric_key(tag: str) -> tuple:
+        # "prod-fe-1.0.107" -> (1, 0, 107)
+        return tuple(int(part) for part in tag.rsplit("-", 1)[1].split("."))
+
+    # Compare numerically: a plain string max() would rank "1.0.9" above "1.0.107".
+    return max(matches, key=_numeric_key)
+
+
+
+
+def _should_use_cache(force_refresh: bool) -> bool:
+    """Tell whether the cached version may be returned without a network fetch.
+
+    The cache is usable only when no refresh is forced, a version has been stored,
+    and it was stored less than ``CACHE_TTL_SECONDS`` ago.
+    """
+    has_cached_value = bool(_cached_version) and _cached_at > 0
+    if force_refresh or not has_cached_value:
+        return False
+
+    age_seconds = time.time() - _cached_at
+    return age_seconds < CACHE_TTL_SECONDS
+
+
+def get_latest_fe_version(force_refresh: bool = False) -> str:
+    """
+    Resolve the latest X-FE-Version value from chat.z.ai.
+
+    The lookup order is:
+        1. Cached value within TTL (skipped when ``force_refresh`` is true).
+        2. Remote fetch from chat.z.ai, whose result replaces the cache.
+
+    Note: the fetch uses a synchronous HTTP client with a 10 second timeout, so a
+    cache miss blocks the calling thread for the duration of the request.
+
+    Args:
+        force_refresh: Ignore the cache and always fetch.
+
+    Returns:
+        The version tag, e.g. ``prod-fe-1.0.107``.
+
+    Raises:
+        Exception: If the page cannot be fetched or contains no version tag. For
+            transport/HTTP errors the original error is attached as ``__cause__``.
+    """
+    global _cached_version, _cached_at
+
+    if _should_use_cache(force_refresh):
+        return _cached_version
+
+    try:
+        headers = {"User-Agent": get_random_user_agent("chrome")}
+    except Exception:
+        # The UA helper is optional here; fall back to a fixed desktop Chrome UA.
+        headers = {
+            "User-Agent": (
+                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+                "AppleWebKit/537.36 (KHTML, like Gecko) "
+                "Chrome/120.0.0.0 Safari/537.36"
+            )
+        }
+
+    # Only the network call sits inside the try block, so the "not found" error
+    # below is no longer caught, logged a second time and re-wrapped.
+    try:
+        with httpx.Client(timeout=10.0, follow_redirects=True) as client:
+            response = client.get(FE_VERSION_SOURCE_URL, headers=headers)
+            response.raise_for_status()
+            page_content = response.text
+    except Exception as exc:
+        _logger.error(f"[Z.AI] Failed to fetch X-FE-Version from {FE_VERSION_SOURCE_URL}: {exc}")
+        raise Exception(f"Failed to fetch X-FE-Version: {exc}") from exc
+
+    version = _extract_version(page_content)
+    if not version:
+        _logger.error("[Z.AI] Unable to locate X-FE-Version in landing page")
+        # Same final message callers received before the restructuring.
+        raise Exception(
+            "Failed to fetch X-FE-Version: Unable to locate X-FE-Version in landing page"
+        )
+
+    if version != _cached_version:
+        _logger.info(f"[Z.AI] Detected X-FE-Version update: {version}")
+    _cached_version = version
+    _cached_at = time.time()
+    return version
+
+
+def refresh_fe_version() -> str:
+    """Fetch the version from the remote source again, ignoring any cached value."""
+    return get_latest_fe_version(True)
diff --git a/app/utils/guest_session_pool.py b/app/utils/guest_session_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc82d5d852149bbbf4e8a2b345b56c1e7c9f149
--- /dev/null
+++ b/app/utils/guest_session_pool.py
@@ -0,0 +1,646 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""匿名访客会话池。"""
+
+import asyncio
+import random
+import time
+from dataclasses import dataclass, field
+from threading import Lock
+from typing import Dict, List, Optional, Set
+
+import httpx
+
+from app.core.config import settings
+from app.utils.fe_version import get_latest_fe_version
+from app.utils.logger import logger
+from app.utils.user_agent import get_random_user_agent
+
+# Endpoint that issues a guest identity; `_create_session` sends a GET here.
+AUTH_URL = "https://chat.z.ai/api/v1/auths/"
+# Chat collection endpoint; `_delete_all_chats` sends a DELETE here with the session's bearer token.
+CHATS_URL = "https://chat.z.ai/api/v1/chats/"
+# Connection-pool limits passed to httpx.Limits in `_build_limits`.
+AUTH_HTTP_MAX_KEEPALIVE_CONNECTIONS = 20
+AUTH_HTTP_MAX_CONNECTIONS = 50
+# Session lifetime in seconds: base TTL, +/- random jitter, and the floor applied after jitter.
+GUEST_SESSION_TTL_SECONDS = 480
+GUEST_SESSION_TTL_JITTER_SECONDS = 60
+GUEST_SESSION_MIN_TTL_SECONDS = 180
+# Seconds the pool's maintenance loop sleeps between passes.
+GUEST_POOL_MAINTENANCE_INTERVAL_SECONDS = 30
+# Maximum number of chat-cleanup requests run concurrently.
+GUEST_CLEANUP_PARALLELISM = 4
+# Budget for session-creation attempts: max(missing * multiplier, missing + min attempts).
+CAPACITY_FILL_ATTEMPT_MULTIPLIER = 3
+CAPACITY_FILL_MIN_ATTEMPTS = 3
+# How many duplicate user_ids are listed in a single warning log line.
+MAX_DUPLICATE_LOG_USER_IDS = 3
+
+
+def _get_proxy_config() -> Optional[str]:
+    """Return the first configured proxy URL, or ``None`` when none is set.
+
+    Precedence: ``HTTPS_PROXY``, then ``HTTP_PROXY``, then ``SOCKS5_PROXY``.
+    """
+    return settings.HTTPS_PROXY or settings.HTTP_PROXY or settings.SOCKS5_PROXY or None
+
+
+def _build_timeout(read_timeout: float = 30.0) -> httpx.Timeout:
+    """Build the timeout used for guest-session requests.
+
+    Only the read timeout is adjustable; connect and pool waits are 5 seconds and
+    the write timeout is 10 seconds.
+    """
+    timeout_options = {
+        "connect": 5.0,
+        "read": read_timeout,
+        "write": 10.0,
+        "pool": 5.0,
+    }
+    return httpx.Timeout(**timeout_options)
+
+
+def _build_limits() -> httpx.Limits:
+    """Build the connection-pool limits used for guest-session requests."""
+    keepalive = AUTH_HTTP_MAX_KEEPALIVE_CONNECTIONS
+    total = AUTH_HTTP_MAX_CONNECTIONS
+    return httpx.Limits(max_keepalive_connections=keepalive, max_connections=total)
+
+
+def _build_async_client(read_timeout: float = 30.0) -> httpx.AsyncClient:
+    """Create an async HTTP client configured for guest-session requests.
+
+    The client follows redirects and uses the module's timeout, pool limits and
+    proxy configuration. The caller is responsible for closing it.
+    """
+    timeout = _build_timeout(read_timeout)
+    limits = _build_limits()
+    proxy = _get_proxy_config()
+    return httpx.AsyncClient(
+        timeout=timeout,
+        follow_redirects=True,
+        limits=limits,
+        proxy=proxy,
+    )
+
+
+def _build_dynamic_headers(chat_id: str = "") -> Dict[str, str]:
+    """Build browser-like request headers for guest auth and chat requests.
+
+    A browser family is picked at random (weighted 3 Chrome : 2 Edge : 1 Firefox :
+    1 Safari) and a matching User-Agent is requested from ``get_random_user_agent``.
+    Client-hint headers (``sec-ch-ua*``) are added only when the User-Agent is
+    Chromium-based, i.e. contains ``Edg/`` or ``Chrome/``; Firefox and Safari style
+    User-Agents get none, so the hints never contradict the User-Agent.
+
+    Args:
+        chat_id: When non-empty, the Referer points at that chat page instead of
+            the site root.
+
+    Returns:
+        A new header dictionary.
+
+    Raises:
+        Exception: Whatever ``get_latest_fe_version`` raises when the front-end
+            version cannot be resolved.
+    """
+    browser_type = random.choice(
+        ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"]
+    )
+    user_agent = get_random_user_agent(browser_type)
+    fe_version = get_latest_fe_version()
+
+    def _major_version(marker: str, default: str = "139") -> str:
+        """Return the major version following *marker* in the UA, or *default*."""
+        if marker not in user_agent:
+            return default
+        return user_agent.split(marker)[1].split(".")[0]
+
+    chrome_version = _major_version("Chrome/")
+
+    if "Edg/" in user_agent:
+        edge_version = _major_version("Edg/")
+        sec_ch_ua = (
+            f'"Microsoft Edge";v="{edge_version}", '
+            f'"Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
+        )
+    elif "Chrome/" in user_agent:
+        sec_ch_ua = (
+            f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", '
+            f'"Google Chrome";v="{chrome_version}"'
+        )
+    else:
+        # Not Chromium-based (Firefox, Safari): these browsers send no client hints.
+        sec_ch_ua = None
+
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json, text/event-stream",
+        "Connection": "keep-alive",
+        "Cache-Control": "no-cache",
+        "User-Agent": user_agent,
+        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
+        "X-FE-Version": fe_version,
+        "Origin": "https://chat.z.ai",
+    }
+
+    if sec_ch_ua:
+        headers["sec-ch-ua"] = sec_ch_ua
+        headers["sec-ch-ua-mobile"] = "?0"
+        headers["sec-ch-ua-platform"] = '"Windows"'
+
+    headers["Referer"] = (
+        f"https://chat.z.ai/c/{chat_id}" if chat_id else "https://chat.z.ai/"
+    )
+    return headers
+
+
+def _build_session_expiry() -> float:
+    """Return an expiry timestamp for a new session.
+
+    The lifetime is the base TTL plus a random jitter in either direction, never
+    less than the minimum TTL, so the whole pool does not expire at once.
+    """
+    jitter_span = GUEST_SESSION_TTL_JITTER_SECONDS
+    ttl_seconds = GUEST_SESSION_TTL_SECONDS + random.uniform(-jitter_span, jitter_span)
+    if ttl_seconds < GUEST_SESSION_MIN_TTL_SECONDS:
+        ttl_seconds = GUEST_SESSION_MIN_TTL_SECONDS
+    return time.time() + ttl_seconds
+
+
+@dataclass
+class GuestSession:
+    """State tracked for a single anonymous guest session."""
+
+    # Bearer token and identity of the guest.
+    token: str
+    user_id: str
+    username: str
+    # Creation time and randomized expiry, both as time.time() timestamps.
+    created_at: float = field(default_factory=time.time)
+    expires_at: float = field(default_factory=_build_session_expiry)
+    # Number of requests currently using this session.
+    active_requests: int = 0
+    # False once the session has been reported as failed.
+    valid: bool = True
+    failure_count: int = 0
+    last_failure_time: float = 0.0
+
+    @property
+    def age(self) -> float:
+        """Seconds elapsed since the session was created."""
+        now = time.time()
+        return now - self.created_at
+
+    @property
+    def is_expired(self) -> bool:
+        """True once the current time has reached ``expires_at``."""
+        return self.expires_at <= time.time()
+
+    def snapshot(self) -> Dict[str, str]:
+        """Return a copy of the session's token, user_id and username."""
+        return {name: getattr(self, name) for name in ("token", "user_id", "username")}
+
+
+class GuestSessionPool:
+    """Pool of anonymous guest sessions with least-busy acquisition and failure replacement.
+
+    Sessions are keyed by ``user_id``. ``initialize`` fills the pool and starts a
+    background task that periodically retires idle sessions which are expired or
+    invalid and creates replacements until ``pool_size`` usable sessions exist.
+    Callers pair ``acquire`` with ``release`` and call ``report_failure`` when a
+    session stops working.
+    """
+
+    def __init__(self, pool_size: int = 3):
+        """Create an empty pool; no network activity happens until ``initialize``.
+
+        Args:
+            pool_size: Target number of usable sessions; values below 1 become 1.
+        """
+        self.pool_size = max(1, pool_size)
+        # Guards self._sessions; only held for short sections that never await.
+        self._lock = Lock()
+        self._sessions: Dict[str, GuestSession] = {}
+        self._maintenance_task: Optional[asyncio.Task] = None
+        # Shared client used for chat cleanup; created lazily.
+        self._http_client: Optional[httpx.AsyncClient] = None
+        self._client_lock = asyncio.Lock()
+        # Serializes refills so concurrent callers do not overfill the pool.
+        self._capacity_lock = asyncio.Lock()
+        self._background_tasks: Set[asyncio.Task] = set()
+        self._cleanup_parallelism = GUEST_CLEANUP_PARALLELISM
+        self._maintenance_interval = GUEST_POOL_MAINTENANCE_INTERVAL_SECONDS
+
+    async def _get_http_client(self) -> httpx.AsyncClient:
+        """Return the shared HTTP client, creating it on first use."""
+        if self._http_client is not None:
+            return self._http_client
+
+        async with self._client_lock:
+            # Re-check under the lock: another coroutine may have created it meanwhile.
+            if self._http_client is None:
+                self._http_client = _build_async_client()
+            return self._http_client
+
+    async def _close_http_client(self):
+        """Detach and close the shared HTTP client, if one exists."""
+        async with self._client_lock:
+            client = self._http_client
+            self._http_client = None
+
+        if client is not None:
+            await client.aclose()
+
+    def _track_background_task(self, coro) -> asyncio.Task:
+        """Schedule *coro* as a task and keep a reference until it finishes.
+
+        Failures of the task are logged as warnings instead of being raised, so
+        cleanup work never blocks or breaks the caller's path.
+        """
+        task = asyncio.create_task(coro)
+        self._background_tasks.add(task)
+
+        def _on_done(done_task: asyncio.Task):
+            self._background_tasks.discard(done_task)
+            try:
+                done_task.result()
+            except asyncio.CancelledError:
+                pass
+            except Exception as exc:
+                logger.warning(f"⚠️ 匿名会话后台任务异常: {exc}")
+
+        task.add_done_callback(_on_done)
+        return task
+
+    async def _wait_background_tasks(self):
+        """Wait for the background tasks registered at the time of the call."""
+        pending = list(self._background_tasks)
+        if pending:
+            await asyncio.gather(*pending, return_exceptions=True)
+
+    async def _delete_sessions_concurrently(self, sessions: List[GuestSession]):
+        """Delete the chats of several sessions, at most ``_cleanup_parallelism`` at a time."""
+        if not sessions:
+            return
+
+        semaphore = asyncio.Semaphore(self._cleanup_parallelism)
+
+        async def _cleanup(session: GuestSession):
+            async with semaphore:
+                await self._delete_all_chats(session)
+
+        await asyncio.gather(*(_cleanup(session) for session in sessions))
+
+    async def _create_session(self) -> GuestSession:
+        """Request a new guest identity from the auth endpoint.
+
+        Returns:
+            A fresh ``GuestSession``; it is not stored in the pool by this method.
+
+        Raises:
+            RuntimeError: If the endpoint does not answer with HTTP 200 or the
+                response carries no token.
+        """
+        headers = _build_dynamic_headers()
+
+        # Guest auth sets cookies, so a dedicated client is used for every call;
+        # reusing one client would tie the "new" session to the previous guest identity.
+        async with _build_async_client() as auth_client:
+            response = await auth_client.get(AUTH_URL, headers=headers)
+
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"匿名会话创建失败: HTTP {response.status_code} {response.text[:200]}"
+            )
+
+        data = response.json()
+        token = str(data.get("token") or "").strip()
+        user_id = str(
+            data.get("id") or data.get("user_id") or data.get("uid") or ""
+        ).strip()
+        # Name preference: explicit name, local part of the email, generated fallback.
+        username = str(
+            data.get("name")
+            or str(data.get("email") or "").split("@")[0]
+            or f"guest-{user_id[:8] or 'session'}"
+        ).strip()
+
+        if not token:
+            raise RuntimeError(f"匿名会话创建失败: 未返回 token {data}")
+        if not user_id:
+            # No id in the response: derive a pool key from the token prefix.
+            user_id = f"guest-{token[:12]}"
+
+        logger.info(
+            f"🫥 创建匿名会话成功: user_id={user_id}, username={username or 'Guest'}"
+        )
+        return GuestSession(
+            token=token,
+            user_id=user_id,
+            username=username or "Guest",
+        )
+
+    async def _delete_all_chats(self, session: GuestSession) -> bool:
+        """Best-effort deletion of every chat owned by *session*.
+
+        Returns:
+            True when the server answered HTTP 200, False otherwise. Errors are
+            logged and never raised.
+        """
+        try:
+            # Header construction can raise (it resolves the front-end version), so
+            # it belongs inside the try block: one failed cleanup must not abort the
+            # gather in _delete_sessions_concurrently or the maintenance pass.
+            headers = _build_dynamic_headers()
+            headers.update(
+                {
+                    "Authorization": f"Bearer {session.token}",
+                    "Accept": "application/json",
+                    "Content-Type": "application/json",
+                }
+            )
+
+            client = await self._get_http_client()
+            response = await client.delete(CHATS_URL, headers=headers)
+
+            if response.status_code == 200:
+                logger.info(f"🧹 已清理匿名会话聊天记录: {session.user_id}")
+                return True
+
+            logger.warning(
+                f"⚠️ 清理匿名会话聊天记录失败: {session.user_id}, "
+                f"HTTP {response.status_code}, body={response.text[:200]}"
+            )
+        except Exception as exc:
+            logger.warning(f"⚠️ 清理匿名会话聊天记录异常: {session.user_id}, {exc}")
+
+        return False
+
+    def _list_valid_sessions(
+        self,
+        exclude_user_ids: Optional[Set[str]] = None,
+    ) -> List[GuestSession]:
+        """Return the usable sessions whose user_id is not in *exclude_user_ids*."""
+        excluded = exclude_user_ids or set()
+        with self._lock:
+            return [
+                session
+                for session in self._sessions.values()
+                if self._is_session_usable(session)
+                and session.user_id not in excluded
+            ]
+
+    def _is_session_usable(self, session: GuestSession) -> bool:
+        """A session can be handed out while it is valid and not expired."""
+        return session.valid and not session.is_expired
+
+    def _should_retire_session(self, session: GuestSession) -> bool:
+        """A session is retired once it is unusable and no request is using it."""
+        return session.active_requests == 0 and not self._is_session_usable(session)
+
+    def _can_replace_session(self, session: GuestSession) -> bool:
+        """Whether a new session with the same user_id may overwrite *session*."""
+        return self._should_retire_session(session)
+
+    def _store_session(self, session: GuestSession) -> bool:
+        """Store *session* unless its user_id is held by a session that cannot be replaced.
+
+        Returns:
+            True when the session was stored, False when it was rejected as a duplicate.
+        """
+        with self._lock:
+            existing = self._sessions.get(session.user_id)
+            if existing and not self._can_replace_session(existing):
+                return False
+            self._sessions[session.user_id] = session
+            return True
+
+    def _log_duplicate_sessions(self, action: str, user_ids: List[str]):
+        """Log one warning that summarizes the duplicate sessions that were ignored."""
+        if not user_ids:
+            return
+
+        # Only the first few ids are listed to keep the log line short.
+        sample = ", ".join(user_ids[:MAX_DUPLICATE_LOG_USER_IDS])
+        logger.warning(
+            f"⚠️ 匿名会话池{action}收到重复会话,已忽略: "
+            f"count={len(user_ids)}, user_ids={sample}"
+        )
+
+    def _register_create_results(self, action: str, results: List[object]) -> int:
+        """Store the sessions found in *results* and log the failures.
+
+        Args:
+            action: Label inserted into the log messages.
+            results: Sessions and/or exceptions, as returned by
+                ``asyncio.gather(..., return_exceptions=True)``.
+
+        Returns:
+            Number of sessions actually stored; duplicates are not counted.
+        """
+        created = 0
+        duplicate_user_ids: List[str] = []
+
+        for result in results:
+            if isinstance(result, GuestSession):
+                if self._store_session(result):
+                    created += 1
+                else:
+                    duplicate_user_ids.append(result.user_id)
+                continue
+
+            if isinstance(result, Exception):
+                logger.warning(f"⚠️ 匿名会话池{action}失败: {result}")
+
+        self._log_duplicate_sessions(action, duplicate_user_ids)
+        return created
+
+    def _get_fill_attempt_budget(self, missing_count: int) -> int:
+        """Return the maximum number of creation attempts for *missing_count* sessions.
+
+        The explicit bound keeps refill and acquire loops from spinning forever
+        when the server keeps returning duplicate identities.
+        """
+        scaled_budget = max(1, missing_count) * CAPACITY_FILL_ATTEMPT_MULTIPLIER
+        minimum_budget = max(1, missing_count) + CAPACITY_FILL_MIN_ATTEMPTS
+        return max(scaled_budget, minimum_budget)
+
+    def _pop_retired_sessions(self) -> List[GuestSession]:
+        """Remove and return every session that is currently eligible for retirement."""
+        retired_sessions: List[GuestSession] = []
+
+        with self._lock:
+            # Iterate over a copy because entries are popped during the loop.
+            for user_id, session in list(self._sessions.items()):
+                if self._should_retire_session(session):
+                    retired_sessions.append(self._sessions.pop(user_id))
+
+        return retired_sessions
+
+    async def _ensure_capacity(self):
+        """Create sessions until the pool holds ``pool_size`` usable ones.
+
+        Sessions are created in concurrent batches within a bounded attempt
+        budget. A warning is logged when the target is still not reached.
+        """
+        async with self._capacity_lock:
+            attempts_left = self._get_fill_attempt_budget(
+                self.pool_size - len(self._list_valid_sessions())
+            )
+
+            while attempts_left > 0:
+                need = self.pool_size - len(self._list_valid_sessions())
+                if need <= 0:
+                    return
+
+                batch_size = min(need, attempts_left)
+                results = await asyncio.gather(
+                    *[self._create_session() for _ in range(batch_size)],
+                    return_exceptions=True,
+                )
+                attempts_left -= batch_size
+
+                created = self._register_create_results("补齐", results)
+                if created == 0 and attempts_left == 0:
+                    break
+
+            remaining = self.pool_size - len(self._list_valid_sessions())
+            if remaining > 0:
+                logger.warning(
+                    "⚠️ 匿名会话池补齐未达到目标容量: "
+                    f"missing={remaining}, current={len(self._list_valid_sessions())}"
+                )
+
+    async def _maintenance_loop(self):
+        """Background loop: retire unusable idle sessions, then refill the pool.
+
+        Runs until cancelled; any other error is logged and the loop continues.
+        """
+        while True:
+            try:
+                await asyncio.sleep(self._maintenance_interval)
+                retired_sessions = self._pop_retired_sessions()
+                await self._delete_sessions_concurrently(retired_sessions)
+
+                await self._ensure_capacity()
+            except asyncio.CancelledError:
+                return
+            except Exception as exc:
+                logger.warning(f"⚠️ 匿名会话池后台维护异常: {exc}")
+
+    async def initialize(self):
+        """Fill the pool and start the maintenance task; a no-op once it is running.
+
+        Raises:
+            RuntimeError: If not a single session could be created and stored.
+        """
+        if self._maintenance_task:
+            return
+
+        await self._ensure_capacity()
+        created = len(self._list_valid_sessions())
+
+        if created == 0:
+            # One direct attempt so the underlying error reaches the caller.
+            fallback = await self._create_session()
+            if not self._store_session(fallback):
+                raise RuntimeError(
+                    "匿名会话池初始化失败: 无法写入唯一匿名会话"
+                )
+            created = len(self._list_valid_sessions())
+
+        logger.info(f"✅ 匿名会话池初始化完成: {created} 个会话")
+        # Re-check: a concurrent initialize() may have started the loop while this
+        # call was awaiting above; a second task would never be cancelled by close().
+        if self._maintenance_task is None:
+            self._maintenance_task = asyncio.create_task(self._maintenance_loop())
+
+    async def close(self):
+        """Stop maintenance, drop all sessions and release the HTTP client.
+
+        Chats of sessions without active requests are deleted; sessions that are
+        still in use are dropped from the pool without cleanup.
+        """
+        if self._maintenance_task:
+            self._maintenance_task.cancel()
+            try:
+                await self._maintenance_task
+            except asyncio.CancelledError:
+                pass
+            self._maintenance_task = None
+
+        with self._lock:
+            sessions = list(self._sessions.values())
+            self._sessions.clear()
+
+        await self._wait_background_tasks()
+
+        # A pending background refill may have stored new sessions while we were
+        # waiting; sweep again so they are cleaned up instead of lingering in a
+        # closed pool.
+        with self._lock:
+            sessions.extend(self._sessions.values())
+            self._sessions.clear()
+
+        idle_sessions = [
+            session for session in sessions if session.active_requests == 0
+        ]
+        await self._delete_sessions_concurrently(idle_sessions)
+        await self._close_http_client()
+
+    async def acquire(
+        self,
+        exclude_user_ids: Optional[Set[str]] = None,
+    ) -> GuestSession:
+        """Return the least busy usable session and count one more active request on it.
+
+        When no usable session is available a new one is created, within a bounded
+        number of attempts.
+
+        Args:
+            exclude_user_ids: user_ids that must not be returned.
+
+        Returns:
+            The selected session; hand it back with ``release``.
+
+        Raises:
+            RuntimeError: If no unique session could be created within the budget,
+                or if session creation itself fails.
+        """
+        excluded = exclude_user_ids or set()
+        attempts_left = self._get_fill_attempt_budget(len(excluded) + 1)
+
+        while attempts_left > 0:
+            candidates = self._list_valid_sessions(exclude_user_ids=excluded)
+            if candidates:
+                # Fewest active requests first; ties go to the oldest session.
+                session = min(
+                    candidates,
+                    key=lambda item: (item.active_requests, item.created_at),
+                )
+                with self._lock:
+                    # Re-validate under the lock before reserving the session.
+                    current = self._sessions.get(session.user_id)
+                    if (
+                        current
+                        and self._is_session_usable(current)
+                        and current.user_id not in excluded
+                    ):
+                        current.active_requests += 1
+                        return current
+
+            new_session = await self._create_session()
+            attempts_left -= 1
+            if new_session.user_id in excluded:
+                logger.warning(
+                    "⚠️ 获取匿名会话时命中排除 user_id,已忽略: "
+                    f"{new_session.user_id}"
+                )
+                continue
+
+            if not self._store_session(new_session):
+                logger.warning(
+                    "⚠️ 获取匿名会话时命中重复 user_id,已重试: "
+                    f"{new_session.user_id}"
+                )
+                continue
+
+            with self._lock:
+                current = self._sessions.get(new_session.user_id)
+                if current and self._is_session_usable(current):
+                    current.active_requests += 1
+                    return current
+
+        raise RuntimeError("匿名会话池获取失败: 未能创建唯一匿名会话")
+
+    def release(self, user_id: str):
+        """Give back one use of the session identified by *user_id*.
+
+        If the session became unusable in the meantime and is now idle, it is
+        removed from the pool; chat cleanup and a refill are scheduled as
+        background tasks. Unknown user_ids are ignored.
+        """
+        retired_session: Optional[GuestSession] = None
+
+        with self._lock:
+            session = self._sessions.get(user_id)
+            if session:
+                session.active_requests = max(0, session.active_requests - 1)
+                if self._should_retire_session(session):
+                    retired_session = self._sessions.pop(user_id)
+
+        if retired_session:
+            logger.info(f"🧹 已回收过期匿名会话: {retired_session.user_id}")
+            self._track_background_task(self._delete_all_chats(retired_session))
+            self._track_background_task(self._ensure_capacity())
+
+    async def report_failure(self, user_id: Optional[str] = None):
+        """Evict the failed session (if any) and refill the pool.
+
+        Args:
+            user_id: Session to evict. When omitted or unknown, only the refill runs.
+        """
+        session: Optional[GuestSession] = None
+
+        if user_id:
+            with self._lock:
+                session = self._sessions.pop(user_id, None)
+                if session:
+                    session.valid = False
+                    session.failure_count += 1
+                    session.last_failure_time = time.time()
+                    session.active_requests = 0
+
+        if session:
+            # Cleanup runs in the background so the caller's retry is not delayed.
+            self._track_background_task(self._delete_all_chats(session))
+            logger.warning(f"⚠️ 已淘汰匿名会话: {session.user_id}")
+
+        await self._ensure_capacity()
+
+    async def refresh_auth(self, failed_user_id: Optional[str] = None):
+        """Alias of ``report_failure`` kept for glm-demo naming compatibility."""
+        await self.report_failure(failed_user_id)
+
+    async def cleanup_idle_chats(self):
+        """Delete the chats of every usable session that has no active request."""
+        with self._lock:
+            idle_sessions = [
+                session
+                for session in self._sessions.values()
+                if self._is_session_usable(session) and session.active_requests == 0
+            ]
+
+        await self._delete_sessions_concurrently(idle_sessions)
+
+    def get_pool_status(self) -> Dict[str, int]:
+        """Return counters describing the pool.
+
+        Keys: ``total_sessions``, ``valid_sessions``, ``available_sessions`` (valid
+        and idle), ``busy_sessions`` (valid with active requests) and
+        ``expired_sessions``.
+        """
+        with self._lock:
+            sessions = list(self._sessions.values())
+
+        valid_sessions = [
+            session for session in sessions if self._is_session_usable(session)
+        ]
+        busy_sessions = [
+            session for session in valid_sessions if session.active_requests > 0
+        ]
+
+        return {
+            "total_sessions": len(sessions),
+            "valid_sessions": len(valid_sessions),
+            "available_sessions": len(
+                [session for session in valid_sessions if session.active_requests == 0]
+            ),
+            "busy_sessions": len(busy_sessions),
+            "expired_sessions": len(
+                [session for session in sessions if session.is_expired]
+            ),
+        }
+
+
+# Module-level pool instance; None until initialize_guest_session_pool() has created it.
+_guest_session_pool: Optional[GuestSessionPool] = None
+# Guards reads and writes of _guest_session_pool in the initialize/close helpers below.
+_guest_pool_lock = Lock()
+
+
+def get_guest_session_pool() -> Optional[GuestSessionPool]:
+    """Return the module-level guest session pool, or ``None`` if none exists."""
+    return _guest_session_pool
+
+
+async def initialize_guest_session_pool(
+    pool_size: int = 3,
+) -> GuestSessionPool:
+    """Create the module-level guest session pool if needed and initialize it.
+
+    Args:
+        pool_size: Size used when the pool has to be created; ignored when a pool
+            already exists.
+
+    Returns:
+        The module-level pool.
+    """
+    global _guest_session_pool
+
+    with _guest_pool_lock:
+        pool = _guest_session_pool
+        if pool is None:
+            pool = _guest_session_pool = GuestSessionPool(pool_size=pool_size)
+
+    # Initialization awaits network calls, so it runs outside the thread lock.
+    await pool.initialize()
+    return pool
+
+
+async def close_guest_session_pool():
+    """Detach the module-level guest session pool and close it, if one exists."""
+    global _guest_session_pool
+
+    # Swap the reference out under the lock, then close outside of it.
+    with _guest_pool_lock:
+        pool, _guest_session_pool = _guest_session_pool, None
+
+    if pool is not None:
+        await pool.close()
diff --git a/app/utils/logger.py b/app/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..202c5a439431ab146777408bc586e6d45feca810
--- /dev/null
+++ b/app/utils/logger.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import sys
+from pathlib import Path
+from loguru import logger
+
# Module-level handle to the configured logger; stays None until
# setup_logger() or get_logger() assigns it.
app_logger = None


def setup_logger(log_dir, log_retention_days=7, log_rotation="1 day", debug_mode=False):
    """Configure the shared loguru logger and return it.

    All existing sinks are removed first, so calling this again (for example
    on hot reload) replaces the configuration instead of stacking handlers.
    A stderr sink is always added; a rotating file sink is added only in
    debug mode.

    Parameters:
        log_dir (str): Directory for log files (only used in debug mode).
        log_retention_days (int): Number of days rotated files are retained.
        log_rotation (str): Rotation interval passed to loguru.
        debug_mode (bool): Enables DEBUG level, the verbose console format and
            the file sink.

    Returns:
        The configured loguru ``logger`` object (also stored in ``app_logger``).
    """
    global app_logger

    logger.remove()

    if debug_mode:
        log_level = "DEBUG"
        console_format = (
            "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | "
            "{name}:{function}:{line} | {message}"
        )
    else:
        log_level = "INFO"
        console_format = "{time:HH:mm:ss} | {level: <8} | {message}"

    # Console sink; its level follows debug_mode.
    logger.add(sys.stderr, level=log_level, format=console_format, colorize=True)

    if debug_mode:
        file_format = (
            "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | "
            "{name}:{function}:{line} | {message}"
        )
        try:
            log_path = Path(log_dir)
            log_path.mkdir(parents=True, exist_ok=True)

            # "{time:...}" is left in the path for loguru to expand per file.
            logger.add(
                str(log_path / "{time:YYYY-MM-DD}.log"),
                level=log_level,
                format=file_format,
                rotation=log_rotation,
                retention=f"{log_retention_days} days",
                encoding="utf-8",
                compression="zip",
                enqueue=True,
                catch=True,
            )
        except (PermissionError, OSError) as e:
            # Could not create the directory or file: fall back to console only.
            logger.warning(f"⚠️ 无法创建日志文件 ({e}),将仅使用控制台输出")

    app_logger = logger

    return logger
+
+
def get_logger():
    """Return the shared logger, installing a default stderr sink on first use.

    When nothing has configured ``app_logger`` yet, every existing loguru sink
    is removed and a single INFO-level stderr sink is added before the logger
    is stored and returned.
    """
    global app_logger
    if app_logger is not None:
        return app_logger

    default_format = (
        "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | "
        "{name}:{function}:{line} | {message}"
    )
    logger.remove()
    logger.add(sys.stderr, level="INFO", format=default_format)
    app_logger = logger
    return app_logger
+
+
if __name__ == "__main__":
    # Manual smoke test: configure the logger against a temporary directory
    # in debug mode and emit one record per level plus a traceback.
    import tempfile

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            setup_logger(temp_dir, debug_mode=True)

            logger.debug("这是一条调试日志")
            logger.info("这是一条信息日志")
            logger.warning("这是一条警告日志")
            logger.error("这是一条错误日志")
            logger.critical("这是一条严重日志")

            try:
                1 / 0
            except ZeroDivisionError:
                logger.exception("发生了除零异常")

            print("✅ 日志测试完成")

            # Remove the sinks before the temporary directory is cleaned up.
            logger.remove()

        except Exception as e:
            print(f"❌ 日志测试失败: {e}")
            logger.remove()
            raise
diff --git a/app/utils/reload_config.py b/app/utils/reload_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..36398a1b3a27317410a3b674d30cb54eeee303df
--- /dev/null
+++ b/app/utils/reload_config.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+热重载配置模块
+定义 Granian 服务器热重载时需要忽略的目录和文件模式
+"""
+
# Directory (and a few file) names the hot-reload watcher should ignore.
RELOAD_IGNORE_DIRS = [
    # Runtime output
    "logs",
    "storage",
    "__pycache__",
    # VCS, CI and editor metadata
    ".git",
    ".github",
    ".vscode",
    "deploy",
    ".idea",
    # Dependencies and migrations
    "node_modules",
    "migrations",
    ".pytest_cache",
    # Virtual environments
    ".venv",
    "venv",
    "env",
    # Tool caches
    ".mypy_cache",
    ".ruff_cache",
    # Build artifacts
    "dist",
    "build",
    # Coverage data and reports
    ".coverage",
    "htmlcov",
    # Tests
    "tests",
    # PID file
    "z-ai2api-server.pid",
    # Template directory.
    # NOTE(review): this entry uses a backslash separator, so it presumably only
    # matches on Windows-style paths -- confirm how the reloader compares entries.
    "app\\templates",
]
+
# File patterns (regular expressions) the hot-reload watcher should ignore.
RELOAD_IGNORE_PATTERNS = [
    # Log files, including numbered rotations
    r".*\.log$",
    r".*\.log\.\d+$",
    # Database files and their sidecar files
    r".*\.sqlite3.*",
    r".*\.db$",
    r".*\.db-.*$",
    # Compiled Python artifacts
    r".*\.pyc$",
    r".*\.pyo$",
    r".*\.pyd$",
    # Temporary and editor swap files
    r".*\.tmp$",
    r".*\.temp$",
    r".*\.swp$",
    r".*\.swo$",
    r".*~$",
    # OS-generated files
    r".*\.DS_Store$",
    r".*Thumbs\.db$",
    r".*\.directory$",
    # Editor settings
    r".*\.vscode.*",
    r".*\.idea.*",
    # Test and coverage output
    r".*\.coverage$",
    r".*\.pytest_cache.*",
    # Packaging artifacts
    r".*\.egg-info.*",
    r".*\.wheel$",
    r".*\.whl$",
    # Version control
    r".*\.git.*",
    r".*\.gitignore$",
    r".*\.gitkeep$",
    # Backup copies
    r".*\.bak$",
    r".*\.backup$",
    r".*\.orig$",
    # Lock and PID files
    r".*\.lock$",
    r".*\.pid$",
]
+
# Paths to watch: only the application package and the entry-point script.
RELOAD_WATCH_PATHS = ["app", "main.py"]
+
# Hot-reload settings bundled into one mapping.
RELOAD_CONFIG = dict(
    reload_ignore_dirs=RELOAD_IGNORE_DIRS,
    reload_ignore_patterns=RELOAD_IGNORE_PATTERNS,
    reload_paths=RELOAD_WATCH_PATHS,
    reload_tick=500,  # watch interval in milliseconds
)
diff --git a/app/utils/request_logging.py b/app/utils/request_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..e13dd423bf5f87313656fea6e1dda40d0e1cea58
--- /dev/null
+++ b/app/utils/request_logging.py
@@ -0,0 +1,337 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""请求日志写库与流式日志包装。"""
+
+from __future__ import annotations
+
+import json
+import time
+from typing import Any, AsyncGenerator, Dict, Optional
+
+from app.services.request_log_dao import get_request_log_dao
+from app.utils.logger import get_logger
+from app.utils.request_source import RequestSourceInfo
+
+logger = get_logger()
+
+
+def _coerce_int(value: Any) -> int:
+ try:
+ return int(value or 0)
+ except (TypeError, ValueError):
+ return 0
+
+
def _merge_usage(
    current: Dict[str, int],
    update: Dict[str, int],
    *,
    include_cache_in_total: bool,
) -> Dict[str, int]:
    """Fold a usage update into the running usage counters.

    Each per-category counter in ``update`` replaces the current value only
    when it is positive. A positive ``total_tokens`` in ``update`` is taken
    verbatim; otherwise the total is recomputed from the merged counters.

    Args:
        current: Running counters; must contain all four category keys.
        update: Newly reported counters; missing keys count as 0.
        include_cache_in_total: Whether a recomputed total also adds the two
            cache counters on top of input and output tokens.

    Returns:
        A new dict; ``current`` is not modified.
    """
    result = dict(current)

    counter_names = (
        "input_tokens",
        "output_tokens",
        "cache_creation_tokens",
        "cache_read_tokens",
    )
    for name in counter_names:
        reported = _coerce_int(update.get(name))
        if reported > 0:
            result[name] = reported

    explicit_total = _coerce_int(update.get("total_tokens"))
    if explicit_total > 0:
        result["total_tokens"] = explicit_total
    else:
        computed_total = result["input_tokens"] + result["output_tokens"]
        if include_cache_in_total:
            computed_total += (
                result["cache_creation_tokens"] + result["cache_read_tokens"]
            )
        result["total_tokens"] = computed_total

    return result
+
+
def extract_openai_usage(response: Dict[str, Any]) -> Dict[str, int]:
    """Extract usage counters from an OpenAI-compatible response payload.

    Each counter is read from the first truthy field among several alternative
    names, looked up in ``usage`` and in its ``prompt_tokens_details`` /
    ``input_token_details`` sub-objects. When ``total_tokens`` is missing or
    not positive it is computed as input + output (cache counters excluded).

    Args:
        response: Decoded payload; a missing or empty ``usage`` yields zeros.

    Returns:
        Dict with ``input_tokens``, ``output_tokens``,
        ``cache_creation_tokens``, ``cache_read_tokens`` and ``total_tokens``.
    """
    usage = response.get("usage") or {}
    prompt_details = usage.get("prompt_tokens_details") or {}
    input_details = usage.get("input_token_details") or {}

    raw_input = usage.get("prompt_tokens") or usage.get("input_tokens")
    raw_output = usage.get("completion_tokens") or usage.get("output_tokens")
    raw_cache_creation = (
        usage.get("cache_creation_input_tokens")
        or prompt_details.get("cache_creation_tokens")
        or input_details.get("cache_creation_input_tokens")
        or input_details.get("cache_creation_tokens")
    )
    raw_cache_read = (
        usage.get("cache_read_input_tokens")
        or prompt_details.get("cached_tokens")
        or prompt_details.get("cache_read_tokens")
        or input_details.get("cached_tokens")
        or input_details.get("cache_read_input_tokens")
        or input_details.get("cache_read_tokens")
    )

    counters = {
        "input_tokens": _coerce_int(raw_input),
        "output_tokens": _coerce_int(raw_output),
        "cache_creation_tokens": _coerce_int(raw_cache_creation),
        "cache_read_tokens": _coerce_int(raw_cache_read),
    }

    reported_total = _coerce_int(usage.get("total_tokens"))
    if reported_total <= 0:
        reported_total = counters["input_tokens"] + counters["output_tokens"]
    counters["total_tokens"] = reported_total

    return counters
+
+
def extract_claude_usage(response: Dict[str, Any]) -> Dict[str, int]:
    """Extract usage counters from a Claude-compatible response payload.

    Each counter is read from the first truthy field among its alternative
    names in ``usage``. When ``total_tokens`` is missing or not positive it is
    computed as input + output + both cache counters.

    Args:
        response: Decoded payload; a missing or empty ``usage`` yields zeros.

    Returns:
        Dict with ``input_tokens``, ``output_tokens``,
        ``cache_creation_tokens``, ``cache_read_tokens`` and ``total_tokens``.
    """
    usage = response.get("usage") or {}

    raw_input = usage.get("input_tokens") or usage.get("prompt_tokens")
    raw_output = usage.get("output_tokens") or usage.get("completion_tokens")
    raw_cache_creation = (
        usage.get("cache_creation_input_tokens")
        or usage.get("cache_creation_tokens")
    )
    raw_cache_read = (
        usage.get("cache_read_input_tokens")
        or usage.get("cached_tokens")
        or usage.get("cache_read_tokens")
    )

    counters = {
        "input_tokens": _coerce_int(raw_input),
        "output_tokens": _coerce_int(raw_output),
        "cache_creation_tokens": _coerce_int(raw_cache_creation),
        "cache_read_tokens": _coerce_int(raw_cache_read),
    }

    reported_total = _coerce_int(usage.get("total_tokens"))
    if reported_total <= 0:
        reported_total = sum(counters.values())
    counters["total_tokens"] = reported_total

    return counters
+
+
async def write_request_log(
    *,
    provider: str,
    model: str,
    source_info: RequestSourceInfo,
    success: bool,
    started_at: float,
    status_code: int = 200,
    first_token_time: float = 0.0,
    input_tokens: int = 0,
    output_tokens: int = 0,
    cache_creation_tokens: int = 0,
    cache_read_tokens: int = 0,
    total_tokens: Optional[int] = None,
    error_message: Optional[str] = None,
) -> None:
    """Persist one request log entry without ever failing the request.

    The duration is measured from ``started_at`` (a ``time.perf_counter()``
    reading) to now and clamped at 0. Any exception raised while obtaining
    the DAO or writing the entry is logged and swallowed.

    Args:
        provider: Provider name stored with the entry.
        model: Model name stored with the entry.
        source_info: Supplies ``endpoint``, ``source``, ``protocol`` and
            ``client_name`` for the entry.
        success: Whether the request succeeded.
        started_at: ``time.perf_counter()`` value taken when the request began.
        status_code: Status code to record.
        first_token_time: Seconds until the first output, 0.0 if unknown.
        input_tokens: Input token count.
        output_tokens: Output token count.
        cache_creation_tokens: Cache-creation token count.
        cache_read_tokens: Cache-read token count.
        total_tokens: Total token count, forwarded as given (may be ``None``).
        error_message: Error description for failed requests.
    """
    elapsed = max(0.0, time.perf_counter() - started_at)
    try:
        dao = get_request_log_dao()
        entry = {
            "provider": provider,
            "endpoint": source_info.endpoint,
            "source": source_info.source,
            "protocol": source_info.protocol,
            "client_name": source_info.client_name,
            "model": model,
            "status_code": status_code,
            "success": success,
            "duration": elapsed,
            "first_token_time": first_token_time,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cache_creation_tokens": cache_creation_tokens,
            "cache_read_tokens": cache_read_tokens,
            "total_tokens": total_tokens,
            "error_message": error_message,
        }
        await dao.add_log(**entry)
    except Exception as exc:
        # Logging is best-effort: report the failure but never propagate it.
        logger.error(f"写入请求日志失败: {exc}")
+
+
+def _openai_payload_has_output(payload: Dict[str, Any]) -> bool:
+ choice = ((payload.get("choices") or [{}])[0]) if isinstance(payload, dict) else {}
+ delta = choice.get("delta") or {}
+ return bool(
+ delta.get("content")
+ or delta.get("reasoning_content")
+ or delta.get("tool_calls")
+ )
+
+
async def wrap_openai_stream_with_logging(
    stream: AsyncGenerator[str, None],
    *,
    provider: str,
    model: str,
    source_info: RequestSourceInfo,
    started_at: float,
) -> AsyncGenerator[str, None]:
    """Wrap an OpenAI SSE stream and persist completion metadata.

    Every chunk is yielded unchanged. ``data: `` chunks are additionally
    decoded to track errors, the time to first output and token usage. One
    request log entry is written when the stream ends, whether it finished,
    raised, or was closed early.

    Args:
        stream: Source of SSE text chunks.
        provider: Provider name for the log entry.
        model: Model name for the log entry.
        source_info: Request-source metadata for the log entry.
        started_at: ``time.perf_counter()`` value taken when the request began.

    Yields:
        The chunks of ``stream``, unmodified.

    Raises:
        Exception: Whatever ``stream`` raises is re-raised after being
            recorded as a failure with status 500.
    """
    success = True
    status_code = 200
    error_message: Optional[str] = None
    first_token_time = 0.0
    usage = {
        "input_tokens": 0,
        "output_tokens": 0,
        "cache_creation_tokens": 0,
        "cache_read_tokens": 0,
        "total_tokens": 0,
    }

    try:
        async for chunk in stream:
            if chunk.startswith("data: "):
                payload_text = chunk[6:].strip()
                payload = None
                if payload_text and payload_text != "[DONE]":
                    try:
                        payload = json.loads(payload_text)
                    except json.JSONDecodeError:
                        payload = None

                if isinstance(payload, dict):
                    if "error" in payload:
                        success = False
                        error = payload.get("error")
                        if isinstance(error, dict):
                            error_message = (
                                error.get("message") or "Unknown stream error"
                            )
                            # "code" is often a symbolic string such as
                            # "rate_limit_exceeded"; int() on it used to raise
                            # and abort the stream. Non-numeric codes log as 500.
                            status_code = _coerce_int(error.get("code")) or 500
                        else:
                            # Error given as a bare string (or null).
                            error_message = (
                                str(error) if error else "Unknown stream error"
                            )
                            status_code = 500
                    else:
                        if not first_token_time and _openai_payload_has_output(
                            payload
                        ):
                            first_token_time = max(
                                0.0,
                                time.perf_counter() - started_at,
                            )
                        reported_usage = payload.get("usage")
                        if isinstance(reported_usage, dict) and reported_usage:
                            update = extract_openai_usage(payload)
                            # extract_openai_usage synthesizes a total from this
                            # chunk alone when none is reported. Drop that so the
                            # merge recomputes the total from the merged counters
                            # instead of overwriting it with a partial sum.
                            if _coerce_int(reported_usage.get("total_tokens")) <= 0:
                                update["total_tokens"] = 0
                            usage = _merge_usage(
                                usage,
                                update,
                                include_cache_in_total=False,
                            )

            yield chunk
    except Exception as exc:
        success = False
        status_code = 500
        error_message = str(exc)
        raise
    finally:
        await write_request_log(
            provider=provider,
            model=model,
            source_info=source_info,
            success=success,
            started_at=started_at,
            status_code=status_code,
            first_token_time=first_token_time,
            input_tokens=usage["input_tokens"],
            output_tokens=usage["output_tokens"],
            cache_creation_tokens=usage["cache_creation_tokens"],
            cache_read_tokens=usage["cache_read_tokens"],
            total_tokens=usage["total_tokens"],
            error_message=error_message,
        )
+
+
async def wrap_claude_stream_with_logging(
    stream: AsyncGenerator[str, None],
    *,
    provider: str,
    model: str,
    source_info: RequestSourceInfo,
    started_at: float,
    input_tokens: int,
) -> AsyncGenerator[str, None]:
    """Wrap a Claude SSE stream and persist completion metadata.

    Every chunk is yielded unchanged. ``event: `` and ``data: `` lines inside
    each chunk are inspected to track the current event name, the time to the
    first ``content_block_delta``, token usage and error events. One request
    log entry is written when the stream ends, whether it finished, raised,
    or was closed early.

    Args:
        stream: Source of SSE text chunks.
        provider: Provider name for the log entry.
        model: Model name for the log entry.
        source_info: Request-source metadata for the log entry.
        started_at: ``time.perf_counter()`` value taken when the request began.
        input_tokens: Input token count known up front; used as the initial
            input and total counters.

    Yields:
        The chunks of ``stream``, unmodified.

    Raises:
        Exception: Whatever ``stream`` raises is re-raised after being
            recorded as a failure with status 500.
    """
    success = True
    status_code = 200
    error_message: Optional[str] = None
    first_token_time = 0.0
    usage = {
        "input_tokens": input_tokens,
        "output_tokens": 0,
        "cache_creation_tokens": 0,
        "cache_read_tokens": 0,
        "total_tokens": input_tokens,
    }
    current_event: Optional[str] = None

    try:
        async for chunk in stream:
            # A chunk may be a single SSE field line or a whole
            # "event: ...\ndata: ...\n\n" frame, so handle it line by line.
            # Split on "\n" only: str.splitlines() would also break on
            # characters that may appear unescaped inside JSON strings.
            for line in chunk.split("\n"):
                if line.startswith("event: "):
                    current_event = line[7:].strip()
                    continue
                if not line.startswith("data: "):
                    continue

                try:
                    payload = json.loads(line[6:].strip())
                except json.JSONDecodeError:
                    continue
                if not isinstance(payload, dict):
                    continue

                if current_event == "content_block_delta" and not first_token_time:
                    first_token_time = max(0.0, time.perf_counter() - started_at)

                reported_usage = payload.get("usage")
                if isinstance(reported_usage, dict) and reported_usage:
                    update = extract_claude_usage(payload)
                    # extract_claude_usage synthesizes a total from this event
                    # alone when none is reported. A later update carrying only
                    # output_tokens would then replace the running total and
                    # drop the input tokens, so let the merge recompute it.
                    if _coerce_int(reported_usage.get("total_tokens")) <= 0:
                        update["total_tokens"] = 0
                    usage = _merge_usage(
                        usage,
                        update,
                        include_cache_in_total=True,
                    )
                elif current_event == "error":
                    success = False
                    status_code = 500
                    error = payload.get("error")
                    if isinstance(error, dict):
                        error_message = error.get("message") or "Claude stream error"
                    else:
                        # Error given as a bare string (or null).
                        error_message = str(error) if error else "Claude stream error"

            yield chunk
    except Exception as exc:
        success = False
        status_code = 500
        error_message = str(exc)
        raise
    finally:
        await write_request_log(
            provider=provider,
            model=model,
            source_info=source_info,
            success=success,
            started_at=started_at,
            status_code=status_code,
            first_token_time=first_token_time,
            input_tokens=usage["input_tokens"],
            output_tokens=usage["output_tokens"],
            cache_creation_tokens=usage["cache_creation_tokens"],
            cache_read_tokens=usage["cache_read_tokens"],
            total_tokens=usage["total_tokens"],
            error_message=error_message,
        )
diff --git a/app/utils/request_source.py b/app/utils/request_source.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c68c915801d53a1851994c71379350239fb052
--- /dev/null
+++ b/app/utils/request_source.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""请求来源识别辅助函数。"""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from fastapi import Request
+
+
# Model-name prefixes treated as Anthropic models (compared after casefolding).
ANTHROPIC_MODEL_PREFIXES = ("claude-", "claude.")

# Bare model aliases treated as Anthropic models (compared after casefolding).
ANTHROPIC_MODEL_ALIASES = {"sonnet", "opus", "haiku", "opusplan"}
+
+
@dataclass(frozen=True)
class RequestSourceInfo:
    """Immutable, normalized description of where a request came from."""

    # Normalized source identifier, e.g. "curl" or "openai_sdk".
    source: str
    # Wire protocol: "anthropic", "openai", "unknown", or a caller-supplied hint.
    protocol: str
    # Human-readable client label.
    client_name: str
    # Request URL path.
    endpoint: str
    # User-Agent header value with surrounding whitespace removed.
    user_agent: str
+
+
+def _normalize_source_name(value: str) -> str:
+ normalized = re.sub(r"[^a-zA-Z0-9._-]+", "_", value.strip().lower())
+ return normalized.strip("_") or "unknown"
+
+
def _looks_like_anthropic_model(model_hint: Optional[str]) -> bool:
    """Tell whether a model name looks like an Anthropic model.

    The hint is stripped and casefolded, then matched against the known
    aliases exactly or against the known prefixes. Non-string hints never
    match.
    """
    if not isinstance(model_hint, str):
        return False

    candidate = model_hint.strip().casefold()
    return candidate in ANTHROPIC_MODEL_ALIASES or candidate.startswith(
        ANTHROPIC_MODEL_PREFIXES
    )
+
+
def detect_request_source(
    request: Request,
    protocol_hint: Optional[str] = None,
    model_hint: Optional[str] = None,
) -> RequestSourceInfo:
    """Detect the request source from headers, path, and model hints.

    Resolution order:

    1. Protocol: ``protocol_hint`` if non-blank; otherwise ``"anthropic"``
       when an ``anthropic-version`` header is present or the path contains
       ``/messages``, ``"openai"`` when the path contains
       ``/chat/completions``, else ``"unknown"``.
    2. An ``x-request-source`` / ``x-client-source`` header wins outright.
    3. Otherwise the first matching User-Agent rule is used.
    4. Otherwise the source is derived from the protocol (and, for Anthropic,
       from ``model_hint``).

    Args:
        request: Incoming request; only ``headers`` and ``url.path`` are read.
        protocol_hint: Protocol name supplied by the caller, if known.
        model_hint: Requested model name, if known.

    Returns:
        The detected source metadata.
    """
    headers = request.headers
    endpoint = request.url.path
    user_agent = (headers.get("user-agent") or "").strip()
    ua_folded = user_agent.casefold()

    protocol = (protocol_hint or "").strip().lower()
    if not protocol:
        if headers.get("anthropic-version") or "/messages" in endpoint:
            protocol = "anthropic"
        elif "/chat/completions" in endpoint:
            protocol = "openai"
        else:
            protocol = "unknown"

    explicit_source = headers.get("x-request-source") or headers.get("x-client-source")
    if explicit_source:
        return RequestSourceInfo(
            source=_normalize_source_name(explicit_source),
            protocol=protocol,
            client_name=explicit_source.strip(),
            endpoint=endpoint,
            user_agent=user_agent,
        )

    # (User-Agent substrings, source id, client label); the first match wins.
    user_agent_rules = (
        (
            ("claude-code", "claude code", "claude-cli", "claude/"),
            "claude_code",
            "Claude Code",
        ),
        (("anthropic",), "anthropic_sdk", "Anthropic SDK"),
        (("openai",), "openai_sdk", "OpenAI SDK"),
        (("curl/",), "curl", "curl"),
        (
            ("python-httpx", "httpx/", "python-requests", "requests/"),
            "custom_http_client",
            "HTTP Client",
        ),
        (("mozilla/",), "browser", "Browser"),
    )

    for needles, source, client_name in user_agent_rules:
        if any(needle in ua_folded for needle in needles):
            break
    else:
        # No User-Agent rule matched: fall back to the protocol.
        if protocol == "anthropic":
            if _looks_like_anthropic_model(model_hint):
                source = "claude_family"
            else:
                source = "anthropic_compatible"
            client_name = "Claude/Anthropic Compatible"
        elif protocol == "openai":
            source = "openai_compatible"
            client_name = "OpenAI Compatible"
        else:
            source = "unknown"
            client_name = "Unknown"

    return RequestSourceInfo(
        source=source,
        protocol=protocol,
        client_name=client_name,
        endpoint=endpoint,
        user_agent=user_agent,
    )
+
+
def format_request_source(info: RequestSourceInfo) -> str:
    """Render request-source metadata as a compact ``[key=value]`` log prefix.

    The segments appear in the order source, protocol, client, endpoint; the
    user agent is not included.
    """
    segments = (
        f"[source={info.source}]",
        f"[protocol={info.protocol}]",
        f"[client={info.client_name}]",
        f"[endpoint={info.endpoint}]",
    )
    return "".join(segments)
diff --git a/app/utils/signature.py b/app/utils/signature.py
new file mode 100644
index 0000000000000000000000000000000000000000..785967a448ff2669f4e93f8a2aa66c53af85673f
--- /dev/null
+++ b/app/utils/signature.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Z.AI 签名工具模块
+"""
+
+import hmac
+import hashlib
+import base64
+from typing import Dict
+
+
def generate_signature(e: str, t: str, s: int) -> dict:
    """Generate a request signature matching the JavaScript ``zs`` function.

    Steps:
        1. Base64-encode the UTF-8 bytes of ``t``.
        2. Build the message ``"{e}|{base64(t)}|{s}"``.
        3. Derive a key: hex HMAC-SHA256 of the 5-minute window index
           (``s // 300000``, as decimal text) under the fixed secret.
        4. Sign: hex HMAC-SHA256 of the message under the derived key's hex text.

    Args:
        e: canonical metadata string, e.g. "requestId,,timestamp,,user_id,"
        t: latest user message text that feeds into the signature (may be empty)
        s: timestamp in milliseconds

    Returns:
        Dictionary with ``signature`` (hex digest) and ``timestamp`` (``str(s)``).
    """
    timestamp_text = str(s)

    # Equivalent of btoa(String.fromCharCode(...new TextEncoder().encode(t))).
    encoded_text = base64.b64encode(t.encode("utf-8")).decode("ascii")
    message = f"{e}|{encoded_text}|{timestamp_text}"

    # Math.floor(s / (5 * 60 * 1e3)): the key changes every five minutes.
    window_index = s // (5 * 60 * 1000)

    root_secret = "key-@@@@)))()((9))-xxxx&&&%%%%%"
    window_key = hmac.new(
        root_secret.encode("utf-8"),
        str(window_index).encode("utf-8"),
        hashlib.sha256,
    ).hexdigest()

    signature = hmac.new(
        window_key.encode("utf-8"),
        message.encode("utf-8"),
        hashlib.sha256,
    ).hexdigest()

    return {
        "signature": signature,
        "timestamp": timestamp_text,
    }
diff --git a/app/utils/token_pool.py b/app/utils/token_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..7681313bbae55b51030a0e723c876a329a99b363
--- /dev/null
+++ b/app/utils/token_pool.py
@@ -0,0 +1,685 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Token 池管理器 - 基于数据库的 Token 轮询和健康检查
+
+核心功能:
+1. Token 轮询机制 - 负载均衡和容错
+2. Z.AI 官方认证接口验证 - 基于 role 字段区分用户类型
+3. Token 健康度监控 - 自动禁用失败 Token
+4. 数据库集成 - 与 TokenDAO 协同工作
+"""
+
+import asyncio
+import time
+from dataclasses import dataclass
+from threading import Lock
+from typing import Dict, List, Optional, Set, Tuple
+
+import httpx
+
+from app.utils.logger import logger
+
+
+# ==================== Token 状态管理 ====================
+
+
@dataclass
class TokenStatus:
    """Runtime state of one token, kept in memory only."""

    token: str
    token_id: int  # database id, used when syncing statistics
    token_type: str = "unknown"  # "user", "guest" or "unknown"
    is_available: bool = True
    failure_count: int = 0
    last_failure_time: float = 0.0
    last_success_time: float = 0.0
    total_requests: int = 0
    successful_requests: int = 0
    db_synced_successful_requests: int = 0
    db_synced_failed_requests: int = 0

    @property
    def success_rate(self) -> float:
        """Fraction of requests that succeeded; 1.0 before any request."""
        if not self.total_requests:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def failed_requests(self) -> int:
        """Number of requests that did not succeed, never negative."""
        return max(0, self.total_requests - self.successful_requests)

    @property
    def is_healthy(self) -> bool:
        """Whether the token currently counts as healthy.

        A token is healthy only when all of the following hold:
            1. It is an authenticated user token (``token_type == "user"``);
               guest and unknown tokens are never healthy.
            2. It is available (``is_available`` is true).
            3. With at most 3 requests recorded, ``failure_count`` is 0;
               otherwise the success rate is at least 50%.
        """
        if self.token_type != "user" or not self.is_available:
            return False

        # Too few requests for a meaningful rate: rely on the failure counter.
        if self.total_requests <= 3:
            return self.failure_count == 0

        return self.success_rate >= 0.5
+
+
+# ==================== Token 验证服务 ====================
+
+
class ZAITokenValidator:
    """Validates Z.AI tokens against the official auth endpoint."""

    AUTH_URL = "https://chat.z.ai/api/v1/auths/"

    @staticmethod
    def get_headers(token: str) -> Dict[str, str]:
        """Build the browser-like request headers carrying ``token`` as a Bearer credential."""
        headers = {
            "Accept": "*/*",
            "Accept-Language": "zh-CN,zh;q=0.9",
            "Authorization": f"Bearer {token}",
            "Connection": "keep-alive",
            "Content-Type": "application/json",
            "DNT": "1",
            "Referer": "https://chat.z.ai/",
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36",
            "sec-ch-ua": '"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"',
            "sec-ch-ua-mobile": "?0",
            "sec-ch-ua-platform": '"Windows"',
        }
        return headers

    @classmethod
    async def validate_token(cls, token: str) -> Tuple[str, bool, Optional[str]]:
        """Check a token against the auth endpoint and classify it.

        Args:
            token: Token to validate.

        Returns:
            ``(token_type, is_valid, error_message)`` where
            - token_type is "user", "guest" or "unknown";
            - is_valid is True only for an authenticated user token;
            - error_message describes the failure, or is None on success.

            Network problems and unexpected exceptions are reported through
            the tuple rather than raised.
        """
        try:
            async with httpx.AsyncClient(timeout=15.0) as client:
                response = await client.get(
                    cls.AUTH_URL,
                    headers=cls.get_headers(token),
                )
                return cls._parse_auth_response(response)
        except httpx.TimeoutException:
            return ("unknown", False, "请求超时")
        except httpx.ConnectError:
            return ("unknown", False, "连接失败")
        except Exception as e:
            return ("unknown", False, f"验证异常: {e!s}")

    @staticmethod
    def _parse_auth_response(response: httpx.Response) -> Tuple[str, bool, Optional[str]]:
        """Interpret the auth endpoint response.

        Expected body shape:
            {
                "id": "...",
                "email": "user@example.com",
                "role": "user"   # or "guest"
            }

        Rules, applied in order:
            - non-200 status            -> ("unknown", False, "HTTP <code>")
            - body is not a JSON object -> invalid
            - body has "error"/"message" key -> invalid, with that text
            - role "user"               -> valid authenticated user token
            - role "guest"              -> anonymous token, rejected
            - any other role            -> invalid
        """
        if response.status_code != 200:
            return ("unknown", False, f"HTTP {response.status_code}")

        try:
            body = response.json()

            if not isinstance(body, dict):
                return ("unknown", False, "无效的响应格式")

            if "error" in body or "message" in body:
                reason = body.get("error") or body.get("message", "未知错误")
                return ("unknown", False, str(reason))

            role = body.get("role")
            if role == "user":
                return ("user", True, None)
            if role == "guest":
                return ("guest", False, "匿名用户 Token 不允许添加")
            return ("unknown", False, f"未知 role: {role}")

        except Exception as e:
            # Covers JSON decoding errors as well as anything unexpected.
            return ("unknown", False, f"解析响应失败: {e!s}")
+
+
+# ==================== Token 池管理器 ====================
+
+
+class TokenPool:
+ """Token 池管理器(数据库驱动)"""
+
+ def __init__(
+ self,
+ tokens: List[Tuple[int, str, str]], # [(token_id, token_value, token_type), ...]
+ failure_threshold: int = 3,
+ recovery_timeout: int = 1800
+ ):
+ """
+ 初始化 Token 池
+
+ Args:
+ tokens: Token 列表 [(token_id, token_value, token_type), ...]
+ failure_threshold: 失败阈值,超过此次数将标记为不可用
+ recovery_timeout: 恢复超时时间(秒),失败 Token 在此时间后重新尝试
+ """
+ self.failure_threshold = failure_threshold
+ self.recovery_timeout = recovery_timeout
+ self._lock = Lock()
+ self._current_index = 0
+
+ # 初始化 Token 状态(内存中)
+ self.token_statuses: Dict[str, TokenStatus] = {}
+ self.token_id_map: Dict[str, int] = {} # token -> token_id 映射
+
+ for token_id, token_value, token_type in tokens:
+ if token_value and token_value not in self.token_statuses:
+ self.token_statuses[token_value] = TokenStatus(
+ token=token_value,
+ token_id=token_id,
+ token_type=token_type
+ )
+ self.token_id_map[token_value] = token_id
+
+ if not self.token_statuses:
+ logger.warning("⚠️ Token 池为空,将依赖匿名模式")
+
+ def get_next_token(self, exclude_tokens: Optional[Set[str]] = None) -> Optional[str]:
+ """
+ 获取下一个可用的认证用户 Token(轮询算法)
+
+ Returns:
+ 可用的 Token 字符串,如果没有可用 Token 则返回 None
+ """
+ with self._lock:
+ if not self.token_statuses:
+ return None
+
+ excluded = exclude_tokens or set()
+
+ available_tokens = self._get_available_user_tokens()
+ if excluded:
+ available_tokens = [
+ token for token in available_tokens if token not in excluded
+ ]
+ if not available_tokens:
+ # 尝试恢复过期的失败 Token
+ self._try_recover_failed_tokens()
+ available_tokens = self._get_available_user_tokens()
+ if excluded:
+ available_tokens = [
+ token for token in available_tokens if token not in excluded
+ ]
+
+ if not available_tokens:
+ logger.warning("⚠️ 没有可用的认证用户 Token")
+ return None
+
+ # 轮询选择
+ token = available_tokens[self._current_index % len(available_tokens)]
+ self._current_index = (self._current_index + 1) % len(available_tokens)
+
+ return token
+
+ def _get_available_user_tokens(self) -> List[str]:
+ """
+ 获取当前可用的认证用户 Token 列表
+
+ 过滤条件:
+ 1. is_available = True
+ 2. token_type == "user"
+ """
+ available_user_tokens = [
+ status.token for status in self.token_statuses.values()
+ if status.is_available and status.token_type == "user"
+ ]
+
+ # 警告:如果有 guest token 但没有 user token
+ if not available_user_tokens and self.token_statuses:
+ guest_count = sum(
+ 1 for status in self.token_statuses.values()
+ if status.token_type == "guest"
+ )
+ if guest_count > 0:
+ logger.warning(f"⚠️ 检测到 {guest_count} 个匿名用户 Token,轮询机制将跳过这些 Token")
+
+ return available_user_tokens
+
+ def _try_recover_failed_tokens(self):
+ """尝试恢复失败的 Token(仅针对认证用户 Token)"""
+ current_time = time.time()
+ recovered_count = 0
+
+ for status in self.token_statuses.values():
+ # 只恢复认证用户 Token
+ if (
+ status.token_type == "user"
+ and not status.is_available
+ and current_time - status.last_failure_time > self.recovery_timeout
+ ):
+ status.is_available = True
+ status.failure_count = 0
+ recovered_count += 1
+ logger.info(f"🔄 恢复失败 Token: {status.token[:20]}...")
+
+ if recovered_count > 0:
+ logger.info(f"✅ 恢复了 {recovered_count} 个失败的 Token")
+
+ def mark_token_success(self, token: str):
+ """标记 Token 使用成功"""
+ with self._lock:
+ if token in self.token_statuses:
+ status = self.token_statuses[token]
+ status.total_requests += 1
+ status.successful_requests += 1
+ status.last_success_time = time.time()
+ status.failure_count = 0 # 重置失败计数
+
+ if not status.is_available:
+ status.is_available = True
+ logger.info(f"✅ Token 恢复可用: {token[:20]}...")
+
+ def mark_token_failure(self, token: str, error: Exception = None):
+ """标记 Token 使用失败"""
+ with self._lock:
+ if token in self.token_statuses:
+ status = self.token_statuses[token]
+ status.total_requests += 1
+ status.failure_count += 1
+ status.last_failure_time = time.time()
+
+ if status.failure_count >= self.failure_threshold:
+ status.is_available = False
+ logger.warning(f"🚫 Token 已禁用: {token[:20]}... (失败 {status.failure_count} 次)")
+
+ async def record_token_success(self, token: str, dao=None):
+ """标记成功并实时同步数据库统计。"""
+ self.mark_token_success(token)
+
+ token_id = self.get_token_id(token)
+ if token_id is None:
+ return
+
+ if dao is None:
+ from app.services.token_dao import get_token_dao
+
+ dao = get_token_dao()
+
+ try:
+ await dao.record_success(token_id)
+ except Exception as e:
+ logger.error(f"❌ 同步 Token 成功统计失败: {e}")
+ return
+
+ with self._lock:
+ if token in self.token_statuses:
+ self.token_statuses[token].db_synced_successful_requests += 1
+
+ async def record_token_failure(self, token: str, error: Exception = None, dao=None):
+ """标记失败并实时同步数据库统计。"""
+ self.mark_token_failure(token, error)
+
+ token_id = self.get_token_id(token)
+ if token_id is None:
+ return
+
+ if dao is None:
+ from app.services.token_dao import get_token_dao
+
+ dao = get_token_dao()
+
+ try:
+ await dao.record_failure(token_id)
+ except Exception as e:
+ logger.error(f"❌ 同步 Token 失败统计失败: {e}")
+ return
+
+ with self._lock:
+ if token in self.token_statuses:
+ self.token_statuses[token].db_synced_failed_requests += 1
+
+ def get_token_id(self, token: str) -> Optional[int]:
+ """获取 Token 的数据库 ID"""
+ return self.token_id_map.get(token)
+
+ def get_pool_status(self) -> Dict:
+ """获取 Token 池状态信息"""
+ with self._lock:
+ available_count = len(self._get_available_user_tokens())
+ total_count = len(self.token_statuses)
+ healthy_count = sum(1 for status in self.token_statuses.values() if status.is_healthy)
+
+ # 统计各类型 Token
+ user_count = sum(1 for s in self.token_statuses.values() if s.token_type == "user")
+ guest_count = sum(1 for s in self.token_statuses.values() if s.token_type == "guest")
+ unknown_count = sum(1 for s in self.token_statuses.values() if s.token_type == "unknown")
+
+ status_info = {
+ "total_tokens": total_count,
+ "available_tokens": available_count,
+ "unavailable_tokens": total_count - available_count,
+ "healthy_tokens": healthy_count,
+ "unhealthy_tokens": total_count - healthy_count,
+ "user_tokens": user_count,
+ "guest_tokens": guest_count,
+ "unknown_tokens": unknown_count,
+ "current_index": self._current_index,
+ "tokens": []
+ }
+
+ for token, status in self.token_statuses.items():
+ status_info["tokens"].append({
+ "token": f"{token[:10]}...{token[-10:]}",
+ "token_id": status.token_id,
+ "token_type": status.token_type,
+ "is_available": status.is_available,
+ "failure_count": status.failure_count,
+ "success_count": status.successful_requests,
+ "success_rate": f"{status.success_rate:.2%}",
+ "total_requests": status.total_requests,
+ "is_healthy": status.is_healthy,
+ "last_failure_time": status.last_failure_time,
+ "last_success_time": status.last_success_time
+ })
+
+ return status_info
+
+ def update_token_type(self, token: str, token_type: str):
+ """更新 Token 类型(用于健康检查后更新)"""
+ with self._lock:
+ if token in self.token_statuses:
+ old_type = self.token_statuses[token].token_type
+ self.token_statuses[token].token_type = token_type
+
+ if old_type != token_type:
+ logger.info(f"🔄 更新 Token 类型: {token[:20]}... {old_type} → {token_type}")
+
+ async def health_check_token(self, token: str) -> bool:
+ """
+ 异步健康检查单个 Token(使用 Z.AI 官方认证接口)
+
+ Args:
+ token: 要检查的 Token
+
+ Returns:
+ Token 是否健康(True = 有效的认证用户 Token)
+ """
+ token_type, is_valid, error_message = await ZAITokenValidator.validate_token(token)
+
+ # 更新 Token 类型
+ self.update_token_type(token, token_type)
+
+ # 更新状态
+ if is_valid:
+ await self.record_token_success(token)
+ else:
+ await self.record_token_failure(
+ token,
+ Exception(error_message or "验证失败"),
+ )
+
+ return is_valid
+
+    async def health_check_all(self):
+        """Run health_check_token for every pooled token concurrently and log a summary."""
+        if not self.token_statuses:
+            logger.warning("⚠️ Token 池为空,跳过健康检查")
+            return
+
+        total_tokens = len(self.token_statuses)
+        logger.info(f"🔍 开始 Token 池健康检查... (共 {total_tokens} 个 Token)")
+
+        # return_exceptions=True: a raised exception becomes a result item instead of propagating.
+        outcomes = await asyncio.gather(
+            *(self.health_check_token(token) for token in self.token_statuses),
+            return_exceptions=True,
+        )
+
+        healthy_count = failed_count = exception_count = 0
+        for outcome in outcomes:
+            if outcome is True:
+                healthy_count += 1
+            elif outcome is False:
+                failed_count += 1
+            elif isinstance(outcome, Exception):
+                exception_count += 1
+
+        if healthy_count == 0:
+            logger.warning(f"⚠️ 健康检查完成: 0/{total_tokens} 个 Token 健康 - 请检查 Token 配置")
+        elif failed_count > 0:
+            health_rate = (healthy_count / total_tokens) * 100
+            logger.warning(f"⚠️ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康 ({health_rate:.1f}%)")
+        else:
+            logger.info(f"✅ 健康检查完成: {healthy_count}/{total_tokens} 个 Token 健康")
+        if exception_count > 0:
+            logger.error(f"💥 {exception_count} 个 Token 检查异常")
+
+    async def sync_from_database(self, provider: str = "zai"):
+        """
+        Reconcile the in-memory pool with the enabled tokens stored in the database.
+
+        Args:
+            provider: Provider name passed to the DAO query.
+
+        Notes:
+            - Tokens no longer returned by the DAO (or typed "guest") leave the pool.
+            - Enabled tokens missing from the pool are added with a fresh TokenStatus.
+            - Tokens already pooled keep their TokenStatus object; only token_type
+              is refreshed when the database value differs.
+        """
+        from app.services.token_dao import get_token_dao
+
+        dao = get_token_dao()
+
+        # Only enabled records are requested, so a disabled token is simply absent.
+        token_records = await dao.get_tokens_by_provider(provider, enabled_only=True)
+
+        # token value -> (database id, token type); guest tokens are never pooled.
+        db_tokens = {
+            record["token"]: (record["id"], record.get("token_type", "unknown"))
+            for record in token_records
+            if record.get("token_type") != "guest"
+        }
+
+        with self._lock:
+            # 1. Drop pooled tokens that the database no longer lists as enabled.
+            tokens_to_remove = [
+                token_value
+                for token_value in self.token_statuses
+                if token_value not in db_tokens
+            ]
+
+            for token_value in tokens_to_remove:
+                del self.token_statuses[token_value]
+                # pop() instead of del: a token missing from the id map must not
+                # raise KeyError and abort the sync half-way through.
+                self.token_id_map.pop(token_value, None)
+                logger.info(f"🗑️ 从池中移除已禁用 Token: {token_value[:20]}...")
+
+            # 2. Add newly enabled tokens; 3. refresh the type of existing ones.
+            new_tokens_count = 0
+            for token_value, (token_id, token_type) in db_tokens.items():
+                status = self.token_statuses.get(token_value)
+                if status is None:
+                    self.token_statuses[token_value] = TokenStatus(
+                        token=token_value,
+                        token_id=token_id,
+                        token_type=token_type
+                    )
+                    self.token_id_map[token_value] = token_id
+                    new_tokens_count += 1
+                    logger.info(f"➕ 添加新启用 Token: {token_value[:20]}...")
+                elif status.token_type != token_type:
+                    old_type = status.token_type
+                    status.token_type = token_type
+                    logger.info(f"🔄 更新 Token 类型: {token_value[:20]}... {old_type} → {token_type}")
+
+        logger.info(
+            f"✅ Token 池同步完成: "
+            f"当前 {len(self.token_statuses)} 个 Token "
+            f"(移除 {len(tokens_to_remove)}, 新增 {new_tokens_count})"
+        )
+
+
+# ==================== 全局实例管理 ====================
+
+
+_token_pool: Optional[TokenPool] = None
+_pool_lock = Lock()
+
+
+def get_token_pool() -> Optional[TokenPool]:
+    """Return the module-level token pool, or None if it has not been initialized."""
+    return _token_pool
+
+
+async def initialize_token_pool_from_db(
+    provider: str = "zai",
+    failure_threshold: int = 3,
+    recovery_timeout: int = 1800
+) -> Optional[TokenPool]:
+    """
+    Create the global token pool from the enabled tokens stored in the database.
+
+    Args:
+        provider: Provider name passed to the DAO query.
+        failure_threshold: Forwarded to the TokenPool constructor.
+        recovery_timeout: Forwarded to the TokenPool constructor (seconds).
+
+    Returns:
+        The newly created TokenPool. A pool is created even when no token is
+        available, in which case it is empty.
+
+    Note:
+        Each call rebinds the module-level pool to a new TokenPool instance.
+    """
+    global _token_pool
+
+    from app.services.token_dao import get_token_dao
+
+    dao = get_token_dao()
+
+    # Only enabled records are requested from the DAO.
+    token_records = await dao.get_tokens_by_provider(provider, enabled_only=True)
+
+    # Shape handed to TokenPool: (database id, token value, token type).
+    # ``or []`` covers a DAO result that is None or empty.
+    loaded = [
+        (record["id"], record["token"], record.get("token_type", "unknown"))
+        for record in (token_records or [])
+    ]
+
+    # Guest tokens are not expected in the database; filter them defensively
+    # so that only user tokens reach the pool.
+    tokens = [entry for entry in loaded if entry[2] != "guest"]
+
+    # Report how many records were dropped by the guest filter.
+    guest_count = len(loaded) - len(tokens)
+    if guest_count > 0:
+        logger.warning(f"⚠️ 过滤了 {guest_count} 个匿名用户 Token")
+
+    # The pool is created unconditionally, even for an empty token list.
+    with _pool_lock:
+        _token_pool = TokenPool(tokens, failure_threshold, recovery_timeout)
+
+    if tokens:
+        logger.info(f"🔧 从数据库初始化 Token 池({provider}),共 {len(tokens)} 个 Token")
+    else:
+        logger.warning(f"⚠️ {provider} 没有有效的认证用户 Token,已创建空 Token 池")
+
+    return _token_pool
+
+
+async def sync_token_stats_to_db():
+    """
+    Flush request counters that have not been persisted yet to the database.
+
+    For every pooled token the difference between the in-memory counters and the
+    ``db_synced_*`` watermarks is written through the DAO, one record call per
+    request. Meant to be called periodically and on shutdown.
+
+    Raises:
+        Whatever the DAO raises. Writes that completed before the error are still
+        added to the watermarks, so a later retry does not count them twice.
+    """
+    pool = get_token_pool()
+    if not pool:
+        return
+
+    from app.services.token_dao import get_token_dao
+
+    dao = get_token_dao()
+
+    # Snapshot the pending deltas under the lock; DAO calls happen outside it.
+    pending_updates = []
+    with pool._lock:
+        for token, status in pool.token_statuses.items():
+            pending_success = max(0, status.successful_requests - status.db_synced_successful_requests)
+            pending_failure = max(0, status.failed_requests - status.db_synced_failed_requests)
+            if pending_success > 0 or pending_failure > 0:
+                pending_updates.append((token, status.token_id, pending_success, pending_failure))
+
+    for token, token_id, pending_success, pending_failure in pending_updates:
+        written_success = written_failure = 0
+        try:
+            for _ in range(pending_success):
+                await dao.record_success(token_id)
+                written_success += 1
+            for _ in range(pending_failure):
+                await dao.record_failure(token_id)
+                written_failure += 1
+        finally:
+            # Advance the watermarks by what was really written, even on error,
+            # so an interrupted sync is not replayed from the start next time.
+            with pool._lock:
+                status = pool.token_statuses.get(token)
+                if status is not None:
+                    status.db_synced_successful_requests += written_success
+                    status.db_synced_failed_requests += written_failure
+
+    logger.info("✅ Token 统计已同步到数据库")
diff --git a/app/utils/tool_call_handler.py b/app/utils/tool_call_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a82620f3de903f11198c8d6f5947237acb1bbe84
--- /dev/null
+++ b/app/utils/tool_call_handler.py
@@ -0,0 +1,347 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+工具调用处理模块
+"""
+
+import json
+import re
+from typing import Dict, List, Any, Optional, Tuple
+from app.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def generate_tool_prompt(tools: Optional[List[Dict[str, Any]]]) -> str:
+    """
+    Render OpenAI-style tool definitions as a Markdown prompt section.
+
+    Args:
+        tools: OpenAI-format tool definitions; only ``type == "function"`` entries are used.
+
+    Returns:
+        str: The Markdown usage instructions, or "" when there is no function
+        tool to describe.
+    """
+    if not tools:
+        return ""
+
+    tool_definitions = []
+    for tool in tools:
+        if tool.get("type") != "function":
+            continue
+
+        # "or {}" also covers keys that are present but explicitly null.
+        function_spec = tool.get("function") or {}
+        function_name = function_spec.get("name", "unknown")
+        function_description = function_spec.get("description", "")
+        parameters = function_spec.get("parameters") or {}
+
+        tool_info = [
+            f"## {function_name}",
+            f"**Purpose**: {function_description}"
+        ]
+
+        parameter_properties = parameters.get("properties") or {}
+        required_parameters = set(parameters.get("required") or [])
+
+        if parameter_properties:
+            tool_info.append("**Parameters**:")
+            for param_name, param_info in parameter_properties.items():
+                param_type = param_info.get("type", "string")
+                param_desc = param_info.get("description", "")
+                required_str = " (required)" if param_name in required_parameters else " (optional)"
+                tool_info.append(f"- `{param_name}` ({param_type}){required_str}: {param_desc}")
+
+        tool_definitions.append("\n".join(tool_info))
+
+        # Without any function tool the prompt would advertise an empty tool list.
+    if not tool_definitions:
+        return ""
+
+    prompt = (
+        "\n\n---\n"
+        "# Available Tools\n\n"
+        + "\n\n".join(tool_definitions) +
+        "\n\n"
+        "**Tool Invocation Format**:\n"
+        "To use a tool, include a JSON block with this structure:\n"
+        '{"tool_calls": [{"id": "call_ID", "type": "function", "function": {"name": "TOOL_NAME", "arguments": "JSON_STRING"}}]}\n\n'
+        "**Rules**:\n"
+        "- Use tool ONLY when user explicitly requests an action that matches a tool's purpose\n"
+        "- For normal conversation, respond naturally WITHOUT any tool calls\n"
+        "- The `arguments` must be a JSON string, not an object\n"
+        "- Multiple tools can be called by adding more items to the array\n"
+        "---\n\n"
+    )
+
+    logger.debug(f"生成工具提示词,包含 {len(tool_definitions)} 个工具定义")
+    return prompt
+
+
+def process_messages_with_tools(
+    messages: List[Dict[str, Any]],
+    tools: Optional[List[Dict[str, Any]]],
+    tool_choice: str = "auto"
+) -> List[Dict[str, Any]]:
+    """
+    Inject the tool usage prompt into a chat message list.
+
+    The prompt is appended to the first system message; when there is no
+    system message, a new one is inserted at the front. The input list and
+    its message dicts are left unmodified.
+
+    Args:
+        messages: Original chat messages.
+        tools: Tool definitions in OpenAI format.
+        tool_choice: Tool selection strategy; "none" disables injection.
+
+    Returns:
+        List[Dict]: ``messages`` itself when nothing is injected, otherwise a
+        new list that contains the augmented system message.
+    """
+    if not tools or tool_choice == "none":
+        return messages
+
+    tools_prompt = generate_tool_prompt(tools)
+    if not tools_prompt:
+        return messages
+
+    processed = []
+    injected = False
+
+    for msg in messages:
+        if injected or msg.get("role") != "system":
+            processed.append(msg)
+            continue
+
+        # Only the first system message gets the prompt, so the tool list is not repeated.
+        content = msg.get("content")
+        # Lists are flattened to their text parts; null content counts as empty, not "None".
+        content_str = "" if content is None else content_to_string(content)
+        new_msg = msg.copy()
+        new_msg["content"] = content_str + tools_prompt
+        processed.append(new_msg)
+        injected = True
+
+    if not injected:
+        # No system message at all: prepend a dedicated one.
+        processed.insert(0, {
+            "role": "system",
+            "content": f"You are a helpful assistant with access to tools.{tools_prompt}"
+        })
+
+    logger.debug(f"工具提示已注入到消息列表,共 {len(processed)} 条消息")
+    return processed
+
+
+def _normalize_tool_calls(parsed_data: Any) -> Optional[List[Dict[str, Any]]]:
+    """
+    Pull a usable ``tool_calls`` list out of a decoded JSON value.
+
+    Entries that are not JSON objects are dropped, and every ``arguments``
+    value that is not already a string is re-encoded as a JSON string.
+
+    Returns:
+        The non-empty list of tool calls, or None when there is none.
+    """
+    if not isinstance(parsed_data, dict):
+        return None
+
+    raw_calls = parsed_data.get("tool_calls")
+    if not isinstance(raw_calls, list):
+        return None
+
+    # Malformed model output may contain strings or numbers here; keeping them
+    # would raise on the ``.get`` call below.
+    tool_calls = [tc for tc in raw_calls if isinstance(tc, dict)]
+
+    for tc in tool_calls:
+        func = tc.get("function")
+        if not isinstance(func, dict) or not func.get("arguments"):
+            continue
+        if not isinstance(func["arguments"], str):
+            # The prompt asks for a JSON string; encode objects and other values.
+            func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
+
+    return tool_calls or None
+
+
+def _find_inline_tool_calls(text: str) -> Optional[List[Dict[str, Any]]]:
+    """
+    Scan free text for the first JSON object that carries ``tool_calls``.
+
+    Returns:
+        The normalized tool calls, or None when no such object exists.
+    """
+    decoder = json.JSONDecoder()
+    pos = text.find("{")
+
+    while pos != -1:
+        try:
+            parsed_data, end = decoder.raw_decode(text, pos)
+        except json.JSONDecodeError:
+            # Not valid JSON from this brace (stray "{", code snippet, ...): retry at
+            # the next brace so a tool call after or inside the span is still found.
+            pos = text.find("{", pos + 1)
+            continue
+
+        tool_calls = _normalize_tool_calls(parsed_data)
+        if tool_calls:
+            return tool_calls
+
+        # A complete JSON object without tool calls is skipped as a whole.
+        pos = text.find("{", end)
+
+    return None
+
+
+def parse_and_extract_tool_calls(content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]:
+    """
+    Extract a ``tool_calls`` JSON payload from model output.
+
+    Fenced code blocks (```json ... ``` or ``` ... ```) are tried first, then
+    JSON objects embedded directly in the text.
+
+    Args:
+        content: Text returned by the model.
+
+    Returns:
+        Tuple[Optional[List], str]: The extracted tool calls (None when nothing usable
+        was found) and the content with the tool-call JSON removed (unchanged if none).
+    """
+    if not content or not content.strip():
+        return None, content
+
+    tool_calls = None
+
+    # Method 1: JSON inside fenced code blocks.
+    json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```'
+    for json_str in re.findall(json_block_pattern, content):
+        try:
+            parsed_data = json.loads(json_str)
+        except json.JSONDecodeError:
+            continue
+        tool_calls = _normalize_tool_calls(parsed_data)
+        if tool_calls:
+            logger.debug(f"从 JSON 代码块中提取到 {len(tool_calls)} 个工具调用")
+            break
+
+    # Method 2: JSON objects written inline in the text.
+    if not tool_calls:
+        tool_calls = _find_inline_tool_calls(content)
+        if tool_calls:
+            logger.debug(f"从内联 JSON 中提取到 {len(tool_calls)} 个工具调用")
+
+    cleaned_content = remove_tool_json_content(content) if tool_calls else content
+    return tool_calls, cleaned_content
+
+
+def remove_tool_json_content(content: str) -> str:
+    """
+    Remove tool-call JSON from response content.
+
+    Args:
+        content: Original response content.
+
+    Returns:
+        str: The content without JSON objects that have a "tool_calls" key, stripped.
+    """
+    if not content:
+        return content
+
+    # Step 1: drop fenced code blocks whose JSON contains tool_calls.
+    cleaned_text = content
+
+    # Callback for blocks matching ```json ... ``` or ``` ... ```.
+    def replace_json_block(match):
+        json_content = match.group(1)
+        try:
+            parsed_data = json.loads(json_content)
+            if "tool_calls" in parsed_data:
+                return ""  # remove the whole code block
+        except json.JSONDecodeError:
+            pass
+        return match.group(0)  # keep the original text
+
+    json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```'
+    cleaned_text = re.sub(json_block_pattern, replace_json_block, cleaned_text)
+
+    # Step 2: drop inline tool JSON, located by brace balancing.
+    result = []
+    i = 0
+
+    while i < len(cleaned_text):
+        if cleaned_text[i] == '{':
+            # Look for the matching closing brace.
+            brace_count = 1
+            j = i + 1
+            in_string = False
+            escape_next = False
+
+            while j < len(cleaned_text) and brace_count > 0:
+                if escape_next:
+                    escape_next = False
+                elif cleaned_text[j] == '\\':
+                    escape_next = True
+                elif cleaned_text[j] == '"':
+                    in_string = not in_string
+                elif not in_string:
+                    if cleaned_text[j] == '{':
+                        brace_count += 1
+                    elif cleaned_text[j] == '}':
+                        brace_count -= 1
+                j += 1
+
+            if brace_count == 0:
+                # Balanced span found; check whether it is JSON with tool_calls.
+                json_candidate = cleaned_text[i:j]
+                try:
+                    parsed = json.loads(json_candidate)
+                    if "tool_calls" in parsed:
+                        # A tool call: skip the whole span.
+                        i = j
+                        continue
+                except json.JSONDecodeError:
+                    pass
+
+            # Not a tool call (or unparsable): keep this character, advance by one.
+            result.append(cleaned_text[i])
+            i += 1
+        else:
+            result.append(cleaned_text[i])
+            i += 1
+
+    cleaned_result = "".join(result).strip()
+
+    # Collapse runs of three or more newlines into one blank line.
+    cleaned_result = re.sub(r'\n{3,}', '\n\n', cleaned_result)
+
+    logger.debug(f"内容清理完成,原始长度: {len(content)}, 清理后长度: {len(cleaned_result)}")
+    return cleaned_result
+
+
+def content_to_string(content: Any) -> str:
+    """
+    Flatten message content into a single string.
+
+    Args:
+        content: Message content; a plain string or a multimodal list of parts.
+
+    Returns:
+        str: The string itself, a list's text parts joined by spaces, or ``str(content)``.
+    """
+    if isinstance(content, str):
+        return content
+    if not isinstance(content, list):
+        return str(content)
+
+    def pieces():
+        # Bare strings and the "text" of {"type": "text"} parts; other parts add nothing.
+        for part in content:
+            if isinstance(part, str):
+                yield part
+            elif isinstance(part, dict) and part.get("type") == "text":
+                yield part.get("text", "")
+
+    return " ".join(pieces())
diff --git a/app/utils/user_agent.py b/app/utils/user_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6bbbe0df929ca66cbabb881e13d403f57d1866
--- /dev/null
+++ b/app/utils/user_agent.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+用户代理工具模块
+提供动态随机用户代理生成功能
+"""
+
+import random
+from typing import Dict, Optional
+from fake_useragent import UserAgent
+
+# 全局 UserAgent 实例(单例模式)
+_user_agent_instance: Optional[UserAgent] = None
+
+
+def get_user_agent_instance() -> UserAgent:
+    """Return the shared UserAgent object, creating it on the first call."""
+    global _user_agent_instance
+    # Reuse the cached instance when there is one; otherwise build and cache it.
+    _user_agent_instance = UserAgent() if _user_agent_instance is None else _user_agent_instance
+    return _user_agent_instance
+
+
+def get_random_user_agent(browser_type: Optional[str] = None) -> str:
+    """
+    Return a User-Agent string produced by the shared UserAgent instance.
+
+    Args:
+        browser_type: 'chrome', 'firefox', 'safari' or 'edge'. None picks one of
+            these at random; any other value falls back to ``ua.random``.
+
+    Returns:
+        str: The User-Agent string.
+    """
+    ua = get_user_agent_instance()
+
+    # Without an explicit choice, draw one (weights: chrome 3, edge 2, firefox 1, safari 1).
+    if browser_type is None:
+        weighted_choices = [
+            "chrome", "chrome", "chrome",
+            "edge", "edge",
+            "firefox",
+            "safari",
+        ]
+        browser_type = random.choice(weighted_choices)
+
+    # The supported names match the UserAgent attribute of the same name.
+    known_browsers = ("chrome", "edge", "firefox", "safari")
+    if browser_type in known_browsers:
+        user_agent = getattr(ua, browser_type)
+    else:
+        # Unrecognised names fall back to a fully random User-Agent.
+        user_agent = ua.random
+    return user_agent
+
+
+# 通用 UserAgent headers 生成函数
+def get_dynamic_headers(
+    referer: Optional[str] = None,
+    origin: Optional[str] = None,
+    browser_type: Optional[str] = None,
+    additional_headers: Optional[Dict[str, str]] = None
+) -> Dict[str, str]:
+    """
+    Build browser-like request headers around a random User-Agent.
+
+    Args:
+        referer: Value for the Referer header; omitted when empty.
+        origin: Value for the Origin header; omitted when empty.
+        browser_type: Browser family forwarded to get_random_user_agent().
+        additional_headers: Extra headers merged last, overriding generated ones.
+
+    Returns:
+        Dict[str, str]: The headers, including sec-ch-ua / Sec-Fetch-* client
+        hints when the User-Agent is Chrome- or Edge-based.
+    """
+    user_agent = get_random_user_agent(browser_type)
+
+    # Headers sent regardless of the browser family.
+    headers = {
+        "User-Agent": user_agent,
+        "Accept": "application/json, text/event-stream",
+        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
+        "Accept-Encoding": "gzip, deflate, br",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "Pragma": "no-cache",
+    }
+
+    # Optional navigation headers.
+    if referer:
+        headers["Referer"] = referer
+
+    if origin:
+        headers["Origin"] = origin
+
+    # Chromium-based browsers additionally send client hints.
+    if "Chrome/" in user_agent or "Edg/" in user_agent:
+        # Major versions for sec-ch-ua; "139" is used when a token is absent.
+        chrome_version = "139"
+        edge_version = "139"
+
+        # The membership tests guarantee split()[1] exists, so no try/except
+        # is needed around the version extraction.
+        if "Chrome/" in user_agent:
+            chrome_version = user_agent.split("Chrome/")[1].split(".")[0]
+
+        # Edge advertises itself plus the Chromium version taken from "Chrome/".
+        if "Edg/" in user_agent:
+            edge_version = user_agent.split("Edg/")[1].split(".")[0]
+            sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
+        else:
+            sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'
+
+        # NOTE(review): the platform hint is fixed to Windows even when the
+        # generated User-Agent names another OS — confirm this is intended.
+        headers.update({
+            "sec-ch-ua": sec_ch_ua,
+            "sec-ch-ua-mobile": "?0",
+            "sec-ch-ua-platform": '"Windows"',
+            "Sec-Fetch-Dest": "empty",
+            "Sec-Fetch-Mode": "cors",
+            "Sec-Fetch-Site": "same-origin",
+        })
+
+    # Caller-supplied headers win over everything generated above.
+    if additional_headers:
+        headers.update(additional_headers)
+
+    return headers
+
+
diff --git a/deploy/.dockerignore b/deploy/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..5418af614bcf80c659dbbd9c383dd8b2961ba61f
--- /dev/null
+++ b/deploy/.dockerignore
@@ -0,0 +1,54 @@
+# Git files
+.git
+.gitignore
+.gitattributes
+
+# Python cache
+__pycache__
+*.py[cod]
+*$py.class
+*.so
+.Python
+
+# Virtual environments
+venv/
+env/
+ENV/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Documentation
+*.md
+!README.md
+docs/
+
+# Test files
+tests/
+pytest.ini
+.pytest_cache/
+
+# Local data (will be mounted as volumes)
+*.db
+*.sqlite
+*.sqlite3
+logs/
+data/
+
+# Build artifacts
+build/
+dist/
+*.egg-info/
+
+# Docker files in parent directory
+Dockerfile
+docker-compose.yml
+.dockerignore
+
+# Other
+.env.local
+.DS_Store
diff --git a/deploy/.env.example b/deploy/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..50cec592a3ccbcd64416a37501afccfe36a7248a
--- /dev/null
+++ b/deploy/.env.example
@@ -0,0 +1,35 @@
+# ==============================================
+# Z.AI API Server - Docker 环境变量配置示例
+# ==============================================
+
+# 管理后台密码
+ADMIN_PASSWORD=admin123
+
+# API 认证密钥 (用于验证客户端请求)
+AUTH_TOKEN=sk-your-api-key-here
+
+# 是否跳过 API Key 验证 (开发环境可设为 true)
+SKIP_AUTH_TOKEN=false
+
+# 调试日志 (生产环境建议设为 false)
+DEBUG_LOGGING=true
+
+# 匿名 Guest 会话池开关 (true: 没有可用用户 Token 时允许使用匿名会话; false: 仅使用后台导入的用户 Token)
+ANONYMOUS_MODE=false
+
+# Function Call 功能开关 (是否支持工具调用)
+TOOL_SUPPORT=true
+
+# 工具调用扫描限制 (字符数)
+SCAN_LIMIT=200000
+
+# 数据库路径 (Docker 环境使用持久化卷)
+DB_PATH=/app/data/tokens.db
+
+# Token 池配置
+TOKEN_FAILURE_THRESHOLD=3
+TOKEN_RECOVERY_TIMEOUT=300
+
+# 服务配置
+SERVICE_NAME=Z.AI_API_Server
+LISTEN_PORT=8080
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..4dd5c3796b188159627088f97545b6fdf358c811
--- /dev/null
+++ b/deploy/Dockerfile
@@ -0,0 +1,24 @@
+FROM python:3.12-slim
+
+# Set working directory
+WORKDIR /app
+
+# Create data and logs directories with proper permissions
+RUN mkdir -p /app/data /app/logs && \
+ chmod 755 /app/data /app/logs
+
+# Install dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Set environment variable for database path
+ENV DB_PATH=/app/data/tokens.db
+
+# Expose port
+EXPOSE 8080
+
+# Run the application
+CMD ["python", "main.py"]
diff --git a/deploy/NGINX_SETUP.md b/deploy/NGINX_SETUP.md
new file mode 100644
index 0000000000000000000000000000000000000000..7f45afdd9b13c085d35e4c0fe7daecdd814b77bc
--- /dev/null
+++ b/deploy/NGINX_SETUP.md
@@ -0,0 +1,278 @@
+# Nginx 反向代理部署指南
+
+本文档说明如何在 Nginx 反向代理后部署 Z.AI2API,支持自定义路径前缀。
+
+## 问题说明
+
+在使用 Nginx 反向代理时,如果需要将服务部署在自定义路径前缀下(例如 `http://domain.com/ai2api`),
+需要正确配置 `ROOT_PATH` 环境变量,否则会出现以下问题:
+
+- 后台管理页面跳转错误(缺少路径前缀)
+- API 接口请求 404(路径不完整)
+- 静态资源加载失败
+
+## 解决方案
+
+### 1. 配置环境变量
+
+在 `.env` 文件中设置 `ROOT_PATH` 变量,值为 Nginx 配置的 location 路径:
+
+```bash
+# 示例:部署在 /ai2api 路径下
+ROOT_PATH=/ai2api
+```
+
+**重要**: `ROOT_PATH` 必须与 Nginx 配置中的 `location` 路径完全一致。
+
+### 2. 配置 Nginx
+
+参考 `deploy/nginx.conf.example` 文件,选择合适的配置模板。
+
+#### 基础配置示例
+
+```nginx
+server {
+ listen 80;
+ server_name your-domain.com;
+
+ location /ai2api {
+ # 代理到后端服务
+ proxy_pass http://127.0.0.1:8080;
+
+ # 传递原始请求信息
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+
+ # SSE 流式响应支持
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+ proxy_buffering off;
+ proxy_cache off;
+
+ # 超时设置
+ proxy_read_timeout 300s;
+ }
+}
+```
+
+### 3. Docker Compose 配置
+
+如果使用 Docker 部署,需要在 `docker-compose.yml` 中添加 `ROOT_PATH` 环境变量:
+
+```yaml
+version: '3.8'
+services:
+ ai2api:
+ image: z-ai2api:latest
+ environment:
+ - ROOT_PATH=/ai2api
+ - LISTEN_PORT=8080
+ # ... 其他环境变量
+ ports:
+ - "8080:8080"
+```
+
+### 4. 重启服务
+
+```bash
+# 重载 Nginx 配置
+sudo nginx -t
+sudo systemctl reload nginx
+
+# 重启应用(Docker)
+docker-compose restart
+
+# 或重启应用(直接运行)
+# 停止服务后重新启动
+```
+
+## 访问地址
+
+配置完成后,服务访问地址如下:
+
+- **API 端点**: `http://your-domain.com/ai2api/v1/chat/completions`
+- **模型列表**: `http://your-domain.com/ai2api/v1/models`
+- **管理后台**: `http://your-domain.com/ai2api/admin/login`
+- **根路径**: `http://your-domain.com/ai2api/`
+
+## 配置示例
+
+### 示例 1: 部署在 /api 路径下
+
+**.env 配置**:
+```bash
+ROOT_PATH=/api
+```
+
+**Nginx 配置**:
+```nginx
+location /api {
+ proxy_pass http://127.0.0.1:8080;
+ # ... 其他配置
+}
+```
+
+**访问地址**: `http://domain.com/api/admin/login`
+
+### 示例 2: 部署在根路径(无前缀)
+
+**.env 配置**:
+```bash
+ROOT_PATH=
+```
+
+**Nginx 配置**:
+```nginx
+location / {
+ proxy_pass http://127.0.0.1:8080;
+ # ... 其他配置
+}
+```
+
+**访问地址**: `http://domain.com/admin/login`
+
+### 示例 3: 多级路径前缀
+
+**.env 配置**:
+```bash
+ROOT_PATH=/services/ai/chat
+```
+
+**Nginx 配置**:
+```nginx
+location /services/ai/chat {
+ proxy_pass http://127.0.0.1:8080;
+ # ... 其他配置
+}
+```
+
+**访问地址**: `http://domain.com/services/ai/chat/admin/login`
+
+## 常见问题排查
+
+### 1. 404 错误
+
+**现象**: 访问页面或 API 时返回 404
+
+**可能原因**:
+- `ROOT_PATH` 配置与 Nginx location 路径不匹配
+- Nginx 配置错误或未重载
+
+**解决方法**:
+- 检查 `.env` 中的 `ROOT_PATH` 是否与 Nginx `location` 完全一致
+- 确认 Nginx 配置无误: `sudo nginx -t`
+- 重载 Nginx: `sudo systemctl reload nginx`
+- 重启应用服务
+
+### 2. 静态资源加载失败
+
+**现象**: 管理后台页面样式错乱,控制台显示 CSS/JS 404
+
+**可能原因**:
+- `ROOT_PATH` 未配置或配置错误
+- 静态文件路径未包含前缀
+
+**解决方法**:
+- 确保 `ROOT_PATH` 正确配置并重启服务
+- 检查浏览器开发者工具中的资源请求路径
+
+### 3. 流式响应中断
+
+**现象**: SSE 流式响应提前终止或无法正常工作
+
+**可能原因**:
+- Nginx 启用了缓冲
+- 超时时间设置过短
+
+**解决方法**:
+在 Nginx 配置中添加:
+```nginx
+proxy_buffering off;
+proxy_cache off;
+proxy_read_timeout 300s;
+```
+
+### 4. CORS 错误
+
+**现象**: 浏览器控制台显示跨域请求被阻止
+
+**可能原因**:
+- Nginx 未正确传递请求头
+
+**解决方法**:
+确保 Nginx 配置中包含:
+```nginx
+proxy_set_header Host $host;
+proxy_set_header X-Forwarded-Proto $scheme;
+```
+
+## 验证配置
+
+配置完成后,可以通过以下方式验证:
+
+1. **访问健康检查端点**:
+ ```bash
+ curl http://your-domain.com/ai2api/v1/models
+ ```
+
+2. **访问管理后台**:
+ 在浏览器打开 `http://your-domain.com/ai2api/admin/login`
+
+3. **测试 API 请求**:
+ ```bash
+ curl -X POST http://your-domain.com/ai2api/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer your-api-key" \
+ -d '{
+ "model": "GLM-4.6",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "stream": false
+ }'
+ ```
+
+## 进阶配置
+
+### HTTPS 配置
+
+```nginx
+server {
+ listen 443 ssl http2;
+ server_name your-domain.com;
+
+ ssl_certificate /path/to/cert.pem;
+ ssl_certificate_key /path/to/key.pem;
+
+ location /ai2api {
+ proxy_pass http://127.0.0.1:8080;
+ proxy_set_header X-Forwarded-Proto https;
+ # ... 其他配置
+ }
+}
+```
+
+### 负载均衡
+
+```nginx
+upstream ai2api_backend {
+ server 127.0.0.1:8080;
+ server 127.0.0.1:8081;
+ server 127.0.0.1:8082;
+}
+
+server {
+ listen 80;
+ location /ai2api {
+ proxy_pass http://ai2api_backend;
+ # ... 其他配置
+ }
+}
+```
+
+## 参考资料
+
+- [FastAPI Behind a Proxy](https://fastapi.tiangolo.com/advanced/behind-a-proxy/)
+- [Nginx Proxy Module](http://nginx.org/en/docs/http/ngx_http_proxy_module.html)
+- 完整配置示例: `deploy/nginx.conf.example`
diff --git a/deploy/README_DOCKER.md b/deploy/README_DOCKER.md
new file mode 100644
index 0000000000000000000000000000000000000000..a739b4a468b23d79dcebe9c7eb9a33a3eb78c4a4
--- /dev/null
+++ b/deploy/README_DOCKER.md
@@ -0,0 +1,357 @@
+# Docker 部署文档
+
+## 快速部署
+
+### 方式一: 使用预构建镜像 (推荐)
+
+从 Docker Hub 拉取镜像:
+
+```bash
+# 拉取最新镜像
+docker pull zyphrzero/z-ai2api-python:latest
+
+# 创建数据目录
+mkdir -p data logs
+
+# 快速启动
+docker run -d \
+ --name z-ai-api-server \
+ -p 8080:8080 \
+ -e ADMIN_PASSWORD=admin123 \
+ -e AUTH_TOKEN=sk-your-api-key \
+ -e ANONYMOUS_MODE=true \
+ -e DB_PATH=/app/data/tokens.db \
+ -v $(pwd)/data:/app/data \
+ -v $(pwd)/logs:/app/logs \
+ --restart unless-stopped \
+ zyphrzero/z-ai2api-python:latest
+```
+
+**优势**:
+- ✅ 无需本地构建,节省时间
+- ✅ GitHub Actions 自动化构建,保证质量
+- ✅ 多架构支持 (amd64/arm64)
+- ✅ 镜像已优化,体积更小
+
+### 方式二: 使用本地构建
+
+适用于需要自定义修改代码的场景:
+
+```bash
+# 进入部署目录
+cd deploy
+
+# 启动服务 (会自动构建镜像)
+docker compose up -d
+
+# 查看日志
+docker compose logs -f api-server
+```
+
+服务将在 `http://localhost:8080` 启动。
+
+## 架构说明
+
+### 持久化存储
+
+容器使用卷映射实现数据持久化:
+
+```yaml
+volumes:
+ - ./data:/app/data # 数据库存储 (tokens.db)
+ - ./logs:/app/logs # 应用日志
+```
+
+**目录结构**:
+```
+deploy/
+├── data/
+│ └── tokens.db # SQLite 数据库 (自动创建)
+├── logs/ # 应用日志 (自动创建)
+├── docker-compose.yml
+├── Dockerfile
+└── README_DOCKER.md
+```
+
+### 环境变量
+
+核心配置参数 (在 `docker-compose.yml` 中设置):
+
+| 变量 | 默认值 | 说明 |
+|------|--------|------|
+| `DB_PATH` | `/app/data/tokens.db` | 数据库文件路径 |
+| `ADMIN_PASSWORD` | `admin123` | 管理后台密码 |
+| `AUTH_TOKEN` | `sk-your-api-key` | API 认证密钥 |
+| `SKIP_AUTH_TOKEN` | `false` | 跳过 API 验证 |
+| `ANONYMOUS_MODE` | `true` | 匿名访问模式 |
+| `DEBUG_LOGGING` | `true` | 调试日志开关 |
+| `TOOL_SUPPORT` | `true` | Function Call 支持 |
+
+**生产环境建议**:
+- 修改 `ADMIN_PASSWORD` 和 `AUTH_TOKEN`
+- 设置 `DEBUG_LOGGING=false`
+- 设置 `ANONYMOUS_MODE=false`
+
+## 运维操作
+
+### 服务管理
+
+```bash
+# 启动服务
+docker compose up -d
+
+# 停止服务
+docker compose down
+
+# 重启服务
+docker compose restart
+
+# 查看状态
+docker compose ps
+
+# 实时日志
+docker compose logs -f
+```
+
+### 更新应用
+
+**使用预构建镜像**:
+
+```bash
+# 停止当前容器
+docker compose down
+
+# 拉取最新镜像
+docker pull zyphrzero/z-ai2api-python:latest
+
+# 启动新版本 (数据会自动保留)
+docker compose up -d
+
+# 清理旧镜像
+docker image prune -f
+```
+
+**使用本地构建**:
+
+```bash
+# 拉取最新代码
+git pull
+
+# 重新构建并启动 (数据会保留)
+docker compose up -d --build
+
+# 清理旧镜像
+docker image prune -f
+```
+
+### 数据备份与恢复
+
+**备份**:
+```bash
+# 备份数据库
+cp ./data/tokens.db ./data/tokens.db.backup.$(date +%Y%m%d_%H%M%S)
+
+# 完整备份
+tar -czf backup_$(date +%Y%m%d_%H%M%S).tar.gz ./data ./logs
+```
+
+**恢复**:
+```bash
+# 停止服务
+docker compose down
+
+# 恢复数据库
+cp ./data/tokens.db.backup.20250116_120000 ./data/tokens.db
+
+# 启动服务
+docker compose up -d
+```
+
+### 数据库迁移
+
+如需从其他位置迁移现有数据库:
+
+```bash
+# 使用迁移脚本
+./migrate_db.sh /path/to/existing/tokens.db
+
+# 或手动复制
+cp /path/to/existing/tokens.db ./data/
+chmod 644 ./data/tokens.db
+
+# 启动服务
+docker compose up -d
+```
+
+## 故障排查
+
+### 数据库初始化失败
+
+**错误**: `unable to open database file`
+
+**原因**: 目录权限或卷映射问题
+
+**解决**:
+```bash
+# 停止容器
+docker compose down
+
+# 确保目录存在
+mkdir -p ./data ./logs
+
+# 设置权限
+chmod 755 ./data ./logs
+
+# 重新构建并启动
+docker compose up -d --build
+```
+
+### 容器无法启动
+
+**检查步骤**:
+```bash
+# 查看详细日志
+docker compose logs api-server
+
+# 检查容器状态
+docker compose ps
+
+# 验证配置文件
+docker compose config
+```
+
+### 端口冲突
+
+如端口 8080 被占用,修改 `docker-compose.yml`:
+```yaml
+ports:
+ - "8081:8080" # 映射到宿主机 8081 端口
+```
+
+### 健康检查失败
+
+```bash
+# 检查健康状态
+docker compose ps
+
+# 手动测试接口
+curl http://localhost:8080/v1/models
+
+# 进入容器排查
+docker exec -it z-ai-api-server bash
+```
+
+## API 访问
+
+| 端点 | 地址 | 说明 |
+|------|------|------|
+| API 根路径 | `http://localhost:8080` | OpenAI 兼容 API |
+| 模型列表 | `http://localhost:8080/v1/models` | 获取可用模型 |
+| 管理后台 | `http://localhost:8080/admin` | Web 管理界面 |
+| API 文档 | `http://localhost:8080/docs` | OpenAPI/Swagger 文档 |
+| 健康检查 | `http://localhost:8080/v1/models` | 服务健康状态 |
+
+## 高级配置
+
+### 自定义数据库路径
+
+修改 `docker-compose.yml` 使用外部路径:
+
+```yaml
+volumes:
+ - /opt/mydata:/app/data # 使用绝对路径
+
+environment:
+ - DB_PATH=/app/data/tokens.db
+```
+
+### 使用 .env 文件
+
+创建 `.env` 文件 (基于 `.env.example`):
+
+```bash
+cp .env.example .env
+# 编辑配置
+vim .env
+```
+
+修改 `docker-compose.yml`:
+```yaml
+services:
+ api-server:
+ env_file: .env
+```
+
+### 启用日志轮转
+
+在生产环境配置 Docker 日志驱动:
+
+```yaml
+services:
+ api-server:
+ logging:
+ driver: "json-file"
+ options:
+ max-size: "10m"
+ max-file: "3"
+```
+
+### 资源限制
+
+限制容器资源使用:
+
+```yaml
+services:
+ api-server:
+ deploy:
+ resources:
+ limits:
+ cpus: '2'
+ memory: 2G
+ reservations:
+ cpus: '0.5'
+ memory: 512M
+```
+
+## 监控与日志
+
+### 查看日志
+
+```bash
+# 实时日志
+docker compose logs -f
+
+# 最近100行
+docker compose logs --tail=100
+
+# 特定时间段
+docker compose logs --since 30m
+
+# 导出日志
+docker compose logs > app.log
+```
+
+### 容器指标
+
+```bash
+# 资源使用情况
+docker stats z-ai-api-server
+
+# 容器详情
+docker inspect z-ai-api-server
+```
+
+## 安全建议
+
+1. **修改默认密码**: 更改 `ADMIN_PASSWORD` 和 `AUTH_TOKEN`
+2. **限制网络访问**: 生产环境使用反向代理 (Nginx/Caddy)
+3. **启用 HTTPS**: 配置 SSL 证书
+4. **定期备份**: 自动化数据库备份任务
+5. **日志审计**: 定期检查 `request_logs` 表
+6. **最小权限**: 避免以 root 运行容器
+
+## 参考资料
+
+- [Docker Compose 文档](https://docs.docker.com/compose/)
+- [项目主 README](../README.md)
+- [配置示例](.env.example)
diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7125f4c9b750ee8c81f3cb36551e366327fb8aa0
--- /dev/null
+++ b/deploy/docker-compose.yml
@@ -0,0 +1,35 @@
+services:
+  api-server:
+    build:
+      context: ..
+      dockerfile: deploy/Dockerfile
+    container_name: z-ai-api-server
+    ports:
+      - "8080:8080"
+    volumes:
+      # Persistent storage for the database
+      - ./data:/app/data
+      # Persistent storage for logs (optional)
+      - ./logs:/app/logs
+    environment:
+      - ADMIN_PASSWORD=admin123
+      # Auth Configuration
+      - AUTH_TOKEN=sk-your-api-key
+      # Whether to skip API key verification
+      - SKIP_AUTH_TOKEN=false
+      # Debug logging
+      - DEBUG_LOGGING=true
+      # Anonymous mode
+      - ANONYMOUS_MODE=true
+      # Function Call feature switch
+      - TOOL_SUPPORT=true
+      # Tool call scan limit (characters)
+      - SCAN_LIMIT=200000
+      # Database path - located on the persistent volume
+      - DB_PATH=/app/data/tokens.db
+    restart: unless-stopped
+    healthcheck:  # python:3.12-slim ships no curl, so probe with the interpreter
+      test: ["CMD", "python", "-c", "import urllib.request as r; r.urlopen('http://localhost:8080/v1/models', timeout=5)"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
diff --git a/deploy/nginx.conf.example b/deploy/nginx.conf.example
new file mode 100644
index 0000000000000000000000000000000000000000..a163e32a2c17a5f7b29c7d41da1d17bcf2332e6a
--- /dev/null
+++ b/deploy/nginx.conf.example
@@ -0,0 +1,157 @@
+# Nginx reverse proxy configuration example for Z.AI2API
+# This example shows how to deploy the service behind Nginx with a custom path prefix
+
+# Example 1: Deploy at http://your-domain.com/ai2api
+server {
+ listen 80;
+ server_name your-domain.com;
+
+ # Forward requests with /ai2api prefix to the backend service
+ location /ai2api {
+        # Redirect the bare /ai2api path to /ai2api/ (optional, but recommended)
+ rewrite ^(/ai2api)$ $1/ permanent;
+
+ # Proxy to the backend service
+ proxy_pass http://127.0.0.1:8080;
+
+ # Pass original host and IP information
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+
+ # IMPORTANT: Tell the backend about the path prefix
+ # This ensures all generated URLs include the prefix
+ proxy_set_header X-Forwarded-Prefix /ai2api;
+
+ # WebSocket and SSE support (for streaming responses)
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+
+ # Disable buffering for streaming responses
+ proxy_buffering off;
+ proxy_cache off;
+
+ # Timeout settings (adjust as needed)
+ proxy_connect_timeout 60s;
+ proxy_send_timeout 300s;
+ proxy_read_timeout 300s;
+ }
+}
+
+# Example 2: Deploy at http://example.com/api/chat
+server {
+ listen 80;
+ server_name example.com;
+
+ location /api/chat {
+ # Proxy configuration
+ proxy_pass http://127.0.0.1:8080;
+
+ # Headers
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_set_header X-Forwarded-Prefix /api/chat;
+
+ # SSE/WebSocket support
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+ proxy_buffering off;
+ proxy_cache off;
+ }
+}
+
+# Example 3: Deploy with SSL (HTTPS)
+server {
+ listen 443 ssl http2;
+ server_name secure.example.com;
+
+ # SSL configuration
+ ssl_certificate /path/to/cert.pem;
+ ssl_certificate_key /path/to/key.pem;
+
+ location /ai2api {
+ proxy_pass http://127.0.0.1:8080;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto https;
+ proxy_set_header X-Forwarded-Prefix /ai2api;
+
+ # SSE/WebSocket support
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+ proxy_buffering off;
+ proxy_cache off;
+
+ # Security headers (optional)
+ add_header X-Content-Type-Options nosniff;
+ add_header X-Frame-Options DENY;
+ add_header X-XSS-Protection "1; mode=block";
+ }
+}
+
+# Example 4: Load balancing with multiple backend instances
+upstream ai2api_backend {
+ # Round-robin by default
+ server 127.0.0.1:8080;
+ server 127.0.0.1:8081;
+ server 127.0.0.1:8082;
+
+ # Or use least connections
+ # least_conn;
+
+ # Or use IP hash for session persistence
+ # ip_hash;
+}
+
+server {
+ listen 80;
+ server_name loadbalanced.example.com;
+
+ location /ai2api {
+ proxy_pass http://ai2api_backend;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_set_header X-Forwarded-Prefix /ai2api;
+
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+ proxy_buffering off;
+ proxy_cache off;
+ }
+}
+
+# Important Notes:
+#
+# 1. Set ROOT_PATH in your .env file to match the Nginx location path:
+# ROOT_PATH=/ai2api
+#
+# 2. Reload Nginx and restart the application after configuration changes:
+# sudo systemctl reload nginx
+# docker-compose restart (or restart your application)
+#
+# 3. Access URLs will include the prefix:
+# - Admin panel: http://your-domain.com/ai2api/admin/login
+# - API endpoint: http://your-domain.com/ai2api/v1/chat/completions
+# - Health check: http://your-domain.com/ai2api/v1/models
+#
+# 4. For Docker deployments, make sure to:
+# - Add ROOT_PATH to docker-compose.yml environment variables
+# - Expose the container port (8080 by default)
+#
+# 5. Common issues:
+# - 404 errors: Check that ROOT_PATH matches the Nginx location path exactly
+# - CORS errors: Verify proxy headers are set correctly
+# - Streaming not working: Ensure proxy_buffering is off
+# - Admin panel CSS/JS not loading: Confirm static files are served with the prefix
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..49392464d68af1752d4731fbfb8c7123bce61ee2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, Response
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from granian import Granian
+
+from app.admin import api as admin_api
+from app.admin import routes as admin_routes
+from app.core import claude, openai
+from app.core.config import settings
+from app.core.upstream import UpstreamClient
+from app.utils.logger import setup_logger
+from app.utils.reload_config import RELOAD_CONFIG
+
+# Setup logger
+logger = setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING)
+
+
+async def warmup_upstream_client():
+ """可选预热上游适配器,提前初始化动态依赖。"""
+ try:
+ client = UpstreamClient()
+ logger.info(
+ f"✅ 上游适配器已就绪,支持 {len(client.get_supported_models())} 个模型"
+ )
+ except Exception as exc:
+ logger.warning(f"⚠️ 上游适配器预热失败: {exc}")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Run application startup before ``yield`` and shutdown after it.
+
+    Startup initializes the token database and the request-log DAO,
+    optionally performs one directory import, builds the token pool from the
+    database, initializes the guest session pool when ``ANONYMOUS_MODE`` is
+    set, warms up the upstream client and starts the token automation
+    scheduler. Shutdown stops the scheduler and, when ``ANONYMOUS_MODE`` is
+    set, closes the guest session pool.
+    """
+    # Initialize the token database
+    from app.services.request_log_dao import init_request_log_dao
+    from app.services.token_automation import (
+        run_directory_import,
+        start_token_automation_scheduler,
+        stop_token_automation_scheduler,
+    )
+    from app.services.token_dao import init_token_database
+
+    await init_token_database()
+    init_request_log_dao()
+
+    # Only import when the feature is enabled and a non-blank directory is set
+    if (
+        settings.TOKEN_AUTO_IMPORT_ENABLED
+        and settings.TOKEN_AUTO_IMPORT_SOURCE_DIR.strip()
+    ):
+        try:
+            await run_directory_import(
+                settings.TOKEN_AUTO_IMPORT_SOURCE_DIR,
+                provider="zai",
+            )
+            logger.info("✅ 启动阶段已完成一次目录自动导入")
+        except Exception as exc:
+            # Best-effort: an import failure is logged and startup continues
+            logger.warning(f"⚠️ 启动阶段目录自动导入失败: {exc}")
+
+    # Initialize the authenticated token pool from the database
+    from app.utils.token_pool import initialize_token_pool_from_db
+
+    token_pool = await initialize_token_pool_from_db(
+        provider="zai",
+        failure_threshold=settings.TOKEN_FAILURE_THRESHOLD,
+        recovery_timeout=settings.TOKEN_RECOVERY_TIMEOUT,
+    )
+
+    # Warn (but keep starting) when there is neither a token pool nor guest mode
+    if not token_pool and not settings.ANONYMOUS_MODE:
+        logger.warning(
+            "⚠️ 未找到可用 Token 且未启用匿名模式,服务可能无法正常工作"
+        )
+
+    if settings.ANONYMOUS_MODE:
+        from app.utils.guest_session_pool import initialize_guest_session_pool
+
+        guest_pool = await initialize_guest_session_pool(
+            pool_size=settings.GUEST_POOL_SIZE,
+        )
+        guest_status = guest_pool.get_pool_status()
+        logger.info(
+            "🫥 匿名会话池已就绪: "
+            f"{guest_status.get('valid_sessions', 0)} 个可用会话"
+        )
+
+    await warmup_upstream_client()
+    await start_token_automation_scheduler()
+
+    # The application serves requests while suspended here
+    yield
+
+    logger.info("🔄 应用正在关闭...")
+
+    await stop_token_automation_scheduler()
+
+    if settings.ANONYMOUS_MODE:
+        from app.utils.guest_session_pool import close_guest_session_pool
+
+        await close_guest_session_pool()
+
+
+# Create FastAPI app with lifespan
+# root_path is used for reverse proxy path prefix (e.g., /api or /path-prefix)
+app = FastAPI(lifespan=lifespan, root_path=settings.ROOT_PATH)
+
+# Add CORS middleware
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+ allow_headers=["Content-Type", "Authorization"],
+)
+
+# Mount the static file directory used by the web UI
+try:
+    app.mount("/static", StaticFiles(directory="app/static"), name="static")
+except RuntimeError:
+    # If the static directory does not exist, create it and mount again
+    os.makedirs("app/static/css", exist_ok=True)
+    os.makedirs("app/static/js", exist_ok=True)
+    app.mount("/static", StaticFiles(directory="app/static"), name="static")
+
+# Include API routers
+app.include_router(openai.router)
+app.include_router(claude.router)
+
+# Include admin routers
+app.include_router(admin_routes.router)
+app.include_router(admin_api.router)
+
+
+@app.options("/")
+async def handle_options():
+    """Answer OPTIONS requests on the root path with an empty 200 response."""
+    return Response(status_code=200)
+
+
+@app.get("/")
+async def root():
+    """Return a static JSON greeting for GET requests on the root path."""
+    return {"message": "OpenAI Compatible API Server"}
+
+
+def run_server():
+    """Log the startup configuration and serve ``main:app`` with Granian.
+
+    The server binds to ``0.0.0.0`` on ``settings.LISTEN_PORT``. A
+    ``KeyboardInterrupt`` is logged as a normal shutdown; any other exception
+    is logged as a startup failure and the process exits with status 1.
+    """
+    service_name = settings.SERVICE_NAME
+
+    logger.info(f"🚀 启动 {service_name} 服务...")
+    logger.info(f"📡 监听地址: 0.0.0.0:{settings.LISTEN_PORT}")
+    logger.info(f"🔧 调试模式: {'开启' if settings.DEBUG_LOGGING else '关闭'}")
+    logger.info(f"🔐 匿名模式: {'开启' if settings.ANONYMOUS_MODE else '关闭'}")
+
+    try:
+        Granian(
+            "main:app",
+            interface="asgi",
+            address="0.0.0.0",
+            port=settings.LISTEN_PORT,
+            reload=False,  # Keep hot reload disabled in production
+            process_name=service_name,  # Set the process name
+            # Hot-reload configuration.
+            # NOTE(review): RELOAD_CONFIG must not contain keys passed
+            # explicitly above, or this call raises TypeError - confirm.
+            **RELOAD_CONFIG,
+        ).serve()
+    except KeyboardInterrupt:
+        logger.info("🛑 收到中断信号,正在关闭服务...")
+    except Exception as e:
+        logger.error(f"❌ 服务启动失败: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    run_server()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..a68ee5538c15fecf8723604d6bc9cb34c317dec0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,71 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "z-ai2api-python"
+version = "0.1.0"
+description = "一个为 Z.ai 提供 OpenAI 兼容接口的 Python 代理服务"
+readme = "README.md"
+# "<3.13" rather than "<=3.12": the latter rejects 3.12.1 and later patch
+# releases, because 3.12.1 compares greater than 3.12.
+requires-python = ">=3.9,<3.13"
+license = { text = "MIT" }
+authors = [{ name = "Contributors" }]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Internet :: WWW/HTTP :: HTTP Servers",
+    "Topic :: Software Development :: Libraries :: Python Modules",
+]
+dependencies = [
+    "fastapi==0.116.1",
+    "granian[reload,pname]==2.5.2",
+    "httpx[http2,socks]==0.28.1",
+    "pydantic==2.11.7",
+    "pydantic-settings==2.10.1",
+    "pydantic-core==2.33.2",
+    "typing-inspection==0.4.1",
+    "fake-useragent==2.2.0",
+    "loguru==0.7.3",
+    "psutil>=7.0.0",
+    "json-repair==0.44.1",
+    "jinja2==3.1.4",
+    "aiosqlite==0.20.0",
+    "python-multipart==0.0.12",
+    "python-dotenv==1.0.1"
+]
+
+[project.scripts]
+# Console scripts must point at a zero-argument callable; "main:app" is the
+# ASGI application object, so the launcher function is used instead.
+z-ai2api = "main:run_server"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.uv]
+dev-dependencies = [
+ "pytest>=7.0.0",
+ "pytest-asyncio>=0.21.0",
+ "requests>=2.30.0",
+ "ruff>=0.1.0",
+]
+
+[tool.ruff]
+line-length = 88
+# Matches the minimum version declared in requires-python (>=3.9)
+target-version = "py39"
+select = ["E", "F", "I", "B"]
+ignore = []
+
+[tool.ruff.isort]
+known-first-party = []
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_functions = ["test_*"]
diff --git a/tests/real_upstream_test_utils.py b/tests/real_upstream_test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67a64d74a340d0c423ec04b79a0853c10aaa2a04
--- /dev/null
+++ b/tests/real_upstream_test_utils.py
@@ -0,0 +1,65 @@
+import os
+from typing import Any
+
+import pytest
+
+from app.core import upstream as upstream_module
+from app.core.upstream import UpstreamClient, _extract_user_id_from_token
+
+REAL_AUTH_TOKEN_ENV = "REAL_AUTH_TOKEN_ENV"
+RED_2X2_PNG_DATA_URL = (
+ "data:image/png;base64,"
+ "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAEElEQVR42mP4z8AARAwQCgAf7gP9Y167WwAAAABJRU5ErkJggg=="
+)
+
+def install_real_auth(monkeypatch) -> str:
+ token = os.getenv(REAL_AUTH_TOKEN_ENV, "").strip()
+ if not token:
+ pytest.skip(f"需要设置环境变量 {REAL_AUTH_TOKEN_ENV}")
+
+ user_id = _extract_user_id_from_token(token)
+ if not user_id or user_id == "guest":
+ raise AssertionError(f"{REAL_AUTH_TOKEN_ENV} 不是可解析的认证 token")
+
+ async def fake_get_auth_info(
+ self,
+ excluded_tokens=None,
+ excluded_guest_user_ids=None,
+ ):
+ return {
+ "token": token,
+ "user_id": user_id,
+ "username": "RealUser",
+ "auth_mode": "authenticated",
+ "token_source": "env",
+ "guest_user_id": None,
+ }
+
+ monkeypatch.setattr(UpstreamClient, "get_auth_info", fake_get_auth_info)
+ monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None)
+ monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: None)
+ return token
+
+
+def install_real_anonymous(monkeypatch) -> None:
+ monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None)
+ monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: None)
+ monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True)
+
+
+def extract_content(payload: dict[str, Any]) -> str:
+ assert isinstance(payload, dict), payload
+ assert "error" not in payload, payload
+
+ choices = payload.get("choices") or []
+ assert choices, payload
+
+ message = choices[0].get("message") or {}
+ content = str(message.get("content") or "").strip()
+ assert content, payload
+ return content
+
+
+def assert_usage_present(payload: dict[str, Any]) -> None:
+ usage = payload.get("usage") or {}
+ assert int(usage.get("total_tokens") or 0) > 0, payload
diff --git a/tests/test_admin_config.py b/tests/test_admin_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b49ddc85ac4c3f32ce9bd07d5cbf842030ebcd6e
--- /dev/null
+++ b/tests/test_admin_config.py
@@ -0,0 +1,231 @@
+from types import SimpleNamespace
+from urllib.parse import urlencode
+
+import pytest
+from jinja2 import Environment, FileSystemLoader
+from starlette.requests import Request
+
+from app.admin import api as admin_api
+from app.admin.config_manager import (
+ CONFIG_FIELD_SPECS,
+ build_config_page_data,
+ save_form_config,
+ save_source_config,
+ validate_env_source,
+)
+
+
+def _build_form_payload(**overrides):
+ payload = {}
+
+ for key, field in CONFIG_FIELD_SPECS.items():
+ value = overrides[key] if key in overrides else field.default_value
+ if field.value_type == "bool":
+ if value:
+ payload[key] = "on"
+ continue
+ payload[key] = "" if value is None else str(value)
+
+ return payload
+
+
+def _make_form_request(path: str, data: dict[str, str]) -> Request:
+ body = urlencode(data, doseq=True).encode()
+ sent = False
+
+ async def receive():
+ nonlocal sent
+ if sent:
+ return {"type": "http.request", "body": b"", "more_body": False}
+ sent = True
+ return {"type": "http.request", "body": body, "more_body": False}
+
+ scope = {
+ "type": "http",
+ "http_version": "1.1",
+ "method": "POST",
+ "scheme": "http",
+ "path": path,
+ "raw_path": path.encode(),
+ "query_string": b"",
+ "headers": [
+ (
+ b"content-type",
+ b"application/x-www-form-urlencoded",
+ )
+ ],
+ "client": ("testclient", 50000),
+ "server": ("testserver", 80),
+ }
+ return Request(scope, receive)
+
+
+@pytest.mark.asyncio
+async def test_build_config_page_data_includes_sections_and_override_status(
+    tmp_path,
+):
+    """Check the overview counters and per-field metadata of the config page.
+
+    A temporary ``.env`` sets two keys, so exactly two fields are expected to
+    be reported as overridden and labelled with the ``.env`` source.
+    """
+    env_path = tmp_path / ".env"
+    example_path = tmp_path / ".env.example"
+    env_path.write_text(
+        "API_ENDPOINT=https://example.com/v1/chat\nDEBUG_LOGGING=true\n",
+        encoding="utf-8",
+    )
+    example_path.write_text("SERVICE_NAME=example\n", encoding="utf-8")
+
+    # Minimal stand-in for the settings object
+    settings_stub = SimpleNamespace(
+        API_ENDPOINT="https://example.com/v1/chat",
+        DEBUG_LOGGING=True,
+        GLM5_MODEL="GLM-5",
+        ADMIN_PASSWORD="secret",
+    )
+
+    page_data = build_config_page_data(
+        settings_obj=settings_stub,
+        env_path=env_path,
+        env_example_path=example_path,
+    )
+
+    assert page_data["overview"]["total_sections"] >= 7
+    assert page_data["overview"]["total_fields"] >= 35
+    assert page_data["overview"]["overridden_fields"] == 2
+    assert page_data["overview"]["example_exists"] is True
+
+    # Flatten all sections into a key -> field lookup
+    field_map = {
+        field["key"]: field
+        for section in page_data["sections"]
+        for field in section["fields"]
+    }
+
+    assert field_map["API_ENDPOINT"]["source_label"] == ".env"
+    assert field_map["DEBUG_LOGGING"]["source_label"] == ".env"
+    assert field_map["GLM5_MODEL"]["source_label"] == "默认值"
+    assert field_map["ADMIN_PASSWORD"]["sensitive"] is True
+
+
+@pytest.mark.asyncio
+async def test_save_form_config_preserves_unmanaged_lines_and_updates_fields(
+    tmp_path,
+):
+    """Check that saving the form rewrites managed keys and keeps other lines.
+
+    The pre-existing ``CUSTOM_FLAG`` line must survive, the submitted values
+    must appear in the file, and the reload callback must be invoked.
+    """
+    env_path = tmp_path / ".env"
+    env_path.write_text(
+        "CUSTOM_FLAG=keep\nSERVICE_NAME=old-service\n",
+        encoding="utf-8",
+    )
+
+    reloaded = False
+
+    async def fake_reload():
+        # Records that the reload callback was called
+        nonlocal reloaded
+        reloaded = True
+
+    payload = _build_form_payload(
+        SERVICE_NAME="new-service",
+        LISTEN_PORT=9090,
+        ROOT_PATH="/edge",
+        DEBUG_LOGGING=False,
+        TOKEN_AUTO_IMPORT_ENABLED=True,
+        TOKEN_AUTO_IMPORT_SOURCE_DIR="/srv/tokens",
+        HTTP_PROXY="http://127.0.0.1:7890",
+        ADMIN_PASSWORD="new-admin-password",
+    )
+
+    updates = await save_form_config(
+        payload,
+        reload_callback=fake_reload,
+        env_path=env_path,
+    )
+    content = env_path.read_text(encoding="utf-8")
+
+    assert reloaded is True
+    assert updates["SERVICE_NAME"] == "new-service"
+    assert updates["LISTEN_PORT"] == 9090
+    assert updates["TOKEN_AUTO_IMPORT_ENABLED"] is True
+    assert "CUSTOM_FLAG=keep" in content
+    assert "SERVICE_NAME=new-service" in content
+    assert "LISTEN_PORT=9090" in content
+    assert "ROOT_PATH=/edge" in content
+    assert "TOKEN_AUTO_IMPORT_ENABLED=true" in content
+    assert "TOKEN_AUTO_IMPORT_SOURCE_DIR=/srv/tokens" in content
+    assert "HTTP_PROXY=http://127.0.0.1:7890" in content
+
+
+@pytest.mark.asyncio
+async def test_save_source_config_rolls_back_file_when_reload_fails(tmp_path):
+    """Check that a failing reload propagates and the .env file is restored."""
+    env_path = tmp_path / ".env"
+    env_path.write_text("SERVICE_NAME=old-service\n", encoding="utf-8")
+
+    async def failing_reload():
+        raise RuntimeError("reload failed")
+
+    with pytest.raises(RuntimeError, match="reload failed"):
+        await save_source_config(
+            "SERVICE_NAME=new-service\nLISTEN_PORT=8081\n",
+            reload_callback=failing_reload,
+            env_path=env_path,
+        )
+
+    # The file content must equal what was there before the attempted save
+    assert env_path.read_text(encoding="utf-8") == "SERVICE_NAME=old-service\n"
+
+
+@pytest.mark.asyncio
+async def test_save_config_endpoint_returns_refresh_trigger(tmp_path, monkeypatch):
+    """Check that the save endpoint answers 200 with the refresh HX-Trigger.
+
+    The working directory is switched to ``tmp_path`` so the endpoint operates
+    on the temporary ``.env`` file, and ``reload_settings`` is stubbed out.
+    """
+    monkeypatch.chdir(tmp_path)
+    (tmp_path / ".env").write_text("SERVICE_NAME=before\n", encoding="utf-8")
+
+    async def fake_reload():
+        return None
+
+    monkeypatch.setattr(admin_api, "reload_settings", fake_reload)
+
+    request = _make_form_request(
+        "/admin/api/config/save",
+        _build_form_payload(
+            SERVICE_NAME="after",
+            LISTEN_PORT=8081,
+            DEBUG_LOGGING=True,
+        ),
+    )
+    response = await admin_api.save_config(request)
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 200
+    assert response.headers["HX-Trigger"] == "admin-config-refresh"
+    assert "保存成功" in body
+    assert "SERVICE_NAME=after" in (tmp_path / ".env").read_text(encoding="utf-8")
+
+
+@pytest.mark.asyncio
+async def test_save_config_source_endpoint_rejects_invalid_source(
+    tmp_path,
+    monkeypatch,
+):
+    """Check that an invalid source line yields 400 and leaves .env untouched."""
+    monkeypatch.chdir(tmp_path)
+    (tmp_path / ".env").write_text("SERVICE_NAME=before\n", encoding="utf-8")
+
+    async def fake_reload():
+        return None
+
+    monkeypatch.setattr(admin_api, "reload_settings", fake_reload)
+
+    # The second line has no "=" and is therefore not a KEY=VALUE pair
+    request = _make_form_request(
+        "/admin/api/config/source",
+        {"env_content": "SERVICE_NAME=after\nnot-valid-line\n"},
+    )
+    response = await admin_api.save_config_source(request)
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 400
+    assert "KEY=VALUE" in body
+    assert (tmp_path / ".env").read_text(encoding="utf-8") == "SERVICE_NAME=before\n"
+
+
+def test_validate_env_source_rejects_invalid_lines():
+    """Check that a line without KEY=VALUE form raises ValueError."""
+    with pytest.raises(ValueError, match="KEY=VALUE"):
+        validate_env_source("SERVICE_NAME=ok\nbad line\n")
+
+
+def test_config_template_compiles():
+    """Check that ``config.html`` loads from ``app/templates`` without error.
+
+    The template directory is resolved relative to the current working
+    directory.
+    """
+    env = Environment(loader=FileSystemLoader("app/templates"))
+    template = env.get_template("config.html")
+
+    assert template is not None
diff --git a/tests/test_admin_stats.py b/tests/test_admin_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0608beebe1db15392f71ebac954517ec42655d0
--- /dev/null
+++ b/tests/test_admin_stats.py
@@ -0,0 +1,491 @@
+import json
+from datetime import datetime
+from urllib.parse import urlencode
+
+import pytest
+from starlette.requests import Request
+
+from app.admin import api as admin_api
+from app.admin.stats import collect_admin_stats, format_uptime
+from app.services import token_dao as token_dao_module
+from app.services.request_log_dao import RequestLogDAO
+from app.services.token_dao import TokenDAO
+from app.utils import token_pool as token_pool_module
+from app.utils.token_pool import TokenPool, sync_token_stats_to_db
+
+
+class DummyPool:
+    """Test double exposing ``get_pool_status`` with a canned status value."""
+
+    def __init__(self, status):
+        # Stored as-is and handed back unchanged by get_pool_status()
+        self._status = status
+
+    def get_pool_status(self):
+        """Return the status object supplied at construction time."""
+        return self._status
+
+
+def _make_get_request(path: str, query: dict[str, str] | None = None) -> Request:
+ query_string = urlencode(query or {}).encode()
+
+ async def receive():
+ return {"type": "http.request", "body": b"", "more_body": False}
+
+ scope = {
+ "type": "http",
+ "http_version": "1.1",
+ "method": "GET",
+ "scheme": "http",
+ "path": path,
+ "raw_path": path.encode(),
+ "query_string": query_string,
+ "headers": [],
+ "client": ("testclient", 50000),
+ "server": ("testserver", 80),
+ }
+ return Request(scope, receive)
+
+
+@pytest.mark.asyncio
+async def test_collect_admin_stats_uses_request_logs_and_token_inventory(tmp_path):
+    """Check that admin stats combine token inventory, pool status and logs.
+
+    Four tokens are stored (two default, one guest, one unknown that is then
+    given status ``False``), four request logs are written (three for
+    provider ``"zai"``, one for ``"other"``), and the pool status comes from
+    a ``DummyPool``. The assertions pin the resulting counters for ``"zai"``.
+    """
+    db_path = tmp_path / "admin_stats.db"
+    token_dao = TokenDAO(str(db_path))
+    await token_dao.init_database()
+    request_log_dao = RequestLogDAO(str(db_path))
+
+    await token_dao.add_token("zai", "token-user-1", validate=False)
+    await token_dao.add_token("zai", "token-user-2", validate=False)
+    await token_dao.add_token(
+        "zai",
+        "token-guest-1",
+        token_type="guest",
+        validate=False,
+    )
+    unknown_token_id = await token_dao.add_token(
+        "zai",
+        "token-unknown-1",
+        token_type="unknown",
+        validate=False,
+    )
+    await token_dao.update_token_status(int(unknown_token_id), False)
+
+    # Successful request with cache-read tokens
+    await request_log_dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=0.5,
+        input_tokens=100,
+        output_tokens=40,
+        cache_read_tokens=20,
+        total_tokens=140,
+    )
+    # Failed request with cache-creation tokens
+    await request_log_dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=500,
+        success=False,
+        duration=1.2,
+        input_tokens=60,
+        output_tokens=10,
+        cache_creation_tokens=15,
+        total_tokens=70,
+        error_message="upstream failed",
+    )
+    # Successful request on the anthropic protocol, no cache tokens
+    await request_log_dao.add_log(
+        provider="zai",
+        endpoint="/v1/messages",
+        source="pytest",
+        protocol="anthropic",
+        client_name="pytest",
+        model="glm-4.5",
+        status_code=200,
+        success=True,
+        duration=0.9,
+        input_tokens=30,
+        output_tokens=20,
+        total_tokens=50,
+    )
+    # Log for a different provider; total_requests == 3 below excludes it
+    await request_log_dao.add_log(
+        provider="other",
+        endpoint="/ignored",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-ignored",
+        status_code=200,
+        success=True,
+        duration=0.1,
+    )
+
+    stats = await collect_admin_stats(
+        "zai",
+        token_dao=token_dao,
+        request_log_dao=request_log_dao,
+        token_pool=DummyPool(
+            {
+                "total_tokens": 2,
+                "available_tokens": 1,
+                "healthy_tokens": 1,
+                "unhealthy_tokens": 1,
+            }
+        ),
+    )
+
+    assert stats["total_tokens"] == 4
+    assert stats["enabled_tokens"] == 3
+    assert stats["user_tokens"] == 2
+    assert stats["guest_tokens"] == 1
+    assert stats["unknown_tokens"] == 1
+    assert stats["pool_total_tokens"] == 2
+    assert stats["available_tokens"] == 1
+    assert stats["healthy_tokens"] == 1
+    assert stats["unhealthy_tokens"] == 1
+    assert stats["total_requests"] == 3
+    assert stats["successful_requests"] == 2
+    assert stats["failed_requests"] == 1
+    assert stats["success_rate"] == pytest.approx(66.7)
+    assert stats["input_tokens"] == 190
+    assert stats["output_tokens"] == 70
+    assert stats["total_consumed_tokens"] == 260
+    assert stats["cache_creation_tokens"] == 15
+    assert stats["cache_read_tokens"] == 20
+    assert stats["total_cache_tokens"] == 35
+    assert stats["cache_creation_requests"] == 1
+    assert stats["cache_hit_requests"] == 1
+    assert stats["average_latency"] == pytest.approx(0.87, rel=1e-2)
+    assert stats["trend_window"] == "7d"
+    assert len(stats["usage_trend"]) == 7
+    assert stats["usage_trend"][-1]["total_tokens"] == 260
+    assert stats["usage_trend"][-1]["cache_total_tokens"] == 35
+
+
+@pytest.mark.asyncio
+async def test_get_model_stats_from_db_includes_recent_same_day_logs(tmp_path):
+    """Check that a just-written log shows up in the one-hour model stats."""
+    dao = RequestLogDAO(str(tmp_path / "request_logs.db"))
+
+    await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=0.25,
+        input_tokens=10,
+        output_tokens=20,
+    )
+
+    stats = await dao.get_model_stats_from_db(hours=1)
+
+    assert "glm-5" in stats
+    assert stats["glm-5"]["total"] == 1
+    assert stats["glm-5"]["success"] == 1
+    assert stats["glm-5"]["failed"] == 0
+
+
+@pytest.mark.asyncio
+async def test_request_log_dao_supports_count_and_offset_pagination(tmp_path):
+    """Check ``count_logs`` and ``limit``/``offset`` paging of recent logs.
+
+    Five logs with endpoints ending in 0..4 are written in order; the page at
+    ``offset=2`` with ``limit=2`` is expected to hold endpoints 2 and 1.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_paging.db"))
+
+    for index in range(5):
+        await dao.add_log(
+            provider="zai",
+            endpoint=f"/v1/chat/completions/{index}",
+            source="pytest",
+            protocol="openai",
+            client_name="pytest",
+            model="glm-5",
+            status_code=200,
+            success=True,
+            duration=0.1,
+        )
+
+    total_count = await dao.count_logs(provider="zai")
+    paged_logs = await dao.get_recent_logs(
+        limit=2,
+        offset=2,
+        provider="zai",
+    )
+
+    assert total_count == 5
+    assert len(paged_logs) == 2
+    assert paged_logs[0]["endpoint"] == "/v1/chat/completions/2"
+    assert paged_logs[1]["endpoint"] == "/v1/chat/completions/1"
+
+
+@pytest.mark.asyncio
+async def test_request_log_dao_returns_usage_trend_with_missing_days_filled(
+    tmp_path,
+):
+    """Check that a 7-day trend has seven entries even with a single log.
+
+    The only request is expected to be counted in the last entry.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_trend.db"))
+
+    await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=0.2,
+        input_tokens=12,
+        output_tokens=8,
+        cache_read_tokens=3,
+        total_tokens=20,
+    )
+
+    trend = await dao.get_provider_usage_trend("zai", days=7)
+
+    assert len(trend) == 7
+    assert sum(day["total_requests"] for day in trend) == 1
+    assert trend[-1]["total_tokens"] == 20
+    assert trend[-1]["cache_total_tokens"] == 3
+
+
+@pytest.mark.asyncio
+async def test_request_log_dao_returns_hourly_usage_trend_with_missing_hours(
+    tmp_path,
+):
+    """Check that the ``"24h"`` trend has 24 points with empty hours zeroed.
+
+    The log's timestamp is rewritten to a fixed instant and the same instant
+    is passed as ``now``, making the expected labels deterministic.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_hourly_trend.db"))
+    log_id = await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=0.2,
+        input_tokens=18,
+        output_tokens=7,
+        cache_creation_tokens=5,
+        cache_read_tokens=3,
+        total_tokens=25,
+    )
+
+    # Pin the stored timestamp so it does not depend on the wall clock
+    async with dao.get_connection() as conn:
+        await conn.execute(
+            "UPDATE request_logs SET timestamp = ? WHERE id = ?",
+            ("2026-03-10 12:00:00", log_id),
+        )
+        await conn.commit()
+
+    trend = await dao.get_provider_usage_trend(
+        "zai",
+        window="24h",
+        now=datetime(2026, 3, 10, 12, 0, 0),
+    )
+
+    assert len(trend) == 24
+    assert trend[-1]["label"] == "12:00"
+    assert trend[-1]["tooltip_label"] == "2026-03-10 12:00"
+    assert trend[-1]["input_tokens"] == 18
+    assert trend[-1]["output_tokens"] == 7
+    assert trend[-1]["cache_creation_tokens"] == 5
+    assert trend[-1]["cache_read_tokens"] == 3
+    assert sum(point["total_requests"] for point in trend) == 1
+    assert all(point["total_requests"] == 0 for point in trend[:-1])
+
+
+@pytest.mark.asyncio
+async def test_dashboard_usage_trend_api_returns_requested_window(
+    tmp_path,
+    monkeypatch,
+):
+    """Check that the usage-trend endpoint honours the ``window`` query value.
+
+    The DAO's trend method is wrapped so ``now`` is always a fixed instant,
+    and the admin API is pointed at this DAO.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_api_trend.db"))
+    log_id = await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=0.2,
+        input_tokens=30,
+        output_tokens=12,
+        cache_read_tokens=4,
+        total_tokens=42,
+    )
+
+    # Pin the stored timestamp to 09:00, three hours before the fixed "now"
+    async with dao.get_connection() as conn:
+        await conn.execute(
+            "UPDATE request_logs SET timestamp = ? WHERE id = ?",
+            ("2026-03-10 09:00:00", log_id),
+        )
+        await conn.commit()
+
+    async def fixed_usage_trend(provider, days=None, *, window=None, now=None):
+        # Delegates to the real implementation but ignores the caller's `now`
+        return await RequestLogDAO.get_provider_usage_trend(
+            dao,
+            provider,
+            days=days,
+            window=window,
+            now=datetime(2026, 3, 10, 12, 0, 0),
+        )
+
+    monkeypatch.setattr(dao, "get_provider_usage_trend", fixed_usage_trend)
+    monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao)
+    request = _make_get_request(
+        "/admin/api/dashboard/usage-trend",
+        {"window": "24h"},
+    )
+
+    response = await admin_api.get_dashboard_usage_trend(request)
+    payload = json.loads(response.body.decode("utf-8"))
+
+    assert response.status_code == 200
+    assert payload["window"] == "24h"
+    assert len(payload["points"]) == 24
+    # Presumably index -4 is the 09:00 bucket when the last bucket is 12:00
+    assert payload["points"][-4]["input_tokens"] == 30
+    assert payload["points"][-4]["cache_read_tokens"] == 4
+
+
+@pytest.mark.asyncio
+async def test_recent_logs_component_includes_usage_cache_and_latency_fields(
+    tmp_path,
+    monkeypatch,
+):
+    """Check that the recent-logs response body shows usage, cache and timing.
+
+    One log with distinctive token counts and durations is stored; the body
+    returned by the endpoint must contain the column labels and those values.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_recent_component.db"))
+    await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="pytest",
+        protocol="openai",
+        client_name="pytest-client",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=1.25,
+        first_token_time=0.42,
+        input_tokens=111,
+        output_tokens=22,
+        cache_creation_tokens=9,
+        cache_read_tokens=7,
+        total_tokens=133,
+    )
+
+    monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao)
+    request = _make_get_request(
+        "/admin/api/recent-logs",
+        {"page": "1", "page_size": "12"},
+    )
+
+    response = await admin_api.get_recent_logs(request)
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 200
+    assert "请求" in body
+    assert "标记" in body
+    assert "输入 / 输出" in body
+    assert "缓存创建 / 命中" in body
+    assert "用时 / 首字" in body
+    assert "111" in body
+    assert "22" in body
+    assert "9" in body
+    assert "7" in body
+    assert "1.25s" in body
+    assert "0.42s" in body
+
+
+@pytest.mark.asyncio
+async def test_recent_logs_component_deduplicates_client_and_source_labels(
+    tmp_path,
+    monkeypatch,
+):
+    """Check that redundant source and provider labels are not rendered.
+
+    The stored log has ``source="browser"`` and ``client_name="Browser"``;
+    the body must show ``Browser`` but contain neither ``>browser<`` nor
+    ``>zai<``.
+    """
+    dao = RequestLogDAO(str(tmp_path / "request_logs_recent_dedupe.db"))
+    await dao.add_log(
+        provider="zai",
+        endpoint="/v1/chat/completions",
+        source="browser",
+        protocol="openai",
+        client_name="Browser",
+        model="glm-5",
+        status_code=200,
+        success=True,
+        duration=1.0,
+    )
+
+    monkeypatch.setattr(admin_api, "get_request_log_dao", lambda: dao)
+    request = _make_get_request(
+        "/admin/api/recent-logs",
+        {"page": "1", "page_size": "12"},
+    )
+
+    response = await admin_api.get_recent_logs(request)
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 200
+    assert "Browser" in body
+    assert "OpenAI" in body
+    assert "glm-5" in body
+    assert ">browser<" not in body
+    assert ">zai<" not in body
+
+
+@pytest.mark.asyncio
+async def test_token_dao_supports_count_and_offset_pagination(tmp_path):
+    """Check token counting and ``limit``/``offset`` paging by provider.
+
+    Five tokens ``token-0`` .. ``token-4`` are added in order; the page at
+    ``offset=2`` with ``limit=2`` is expected to hold ``token-2``, ``token-3``.
+    """
+    dao = TokenDAO(str(tmp_path / "tokens_paging.db"))
+    await dao.init_database()
+
+    for index in range(5):
+        await dao.add_token("zai", f"token-{index}", validate=False)
+
+    total_count = await dao.count_tokens_by_provider("zai", enabled_only=False)
+    paged_tokens = await dao.get_tokens_by_provider(
+        "zai",
+        enabled_only=False,
+        limit=2,
+        offset=2,
+    )
+
+    assert total_count == 5
+    assert len(paged_tokens) == 2
+    assert paged_tokens[0]["token"] == "token-2"
+    assert paged_tokens[1]["token"] == "token-3"
+
+
+@pytest.mark.asyncio
+async def test_token_pool_realtime_usage_stats_sync_to_db(tmp_path, monkeypatch):
+    """Check that pool success/failure records reach the token stats table.
+
+    One success and one failure are recorded through the pool; the stored
+    stats must show them, and a later ``sync_token_stats_to_db`` call must
+    leave the counters unchanged.
+    """
+    dao = TokenDAO(str(tmp_path / "token_usage.db"))
+    await dao.init_database()
+    token_id = await dao.add_token("zai", "token-usage", validate=False)
+    assert token_id is not None
+
+    pool = TokenPool([(token_id, "token-usage", "user")])
+
+    await pool.record_token_success("token-usage", dao=dao)
+    await pool.record_token_failure("token-usage", Exception("boom"), dao=dao)
+
+    stats = await dao.get_token_stats(token_id)
+    assert stats is not None
+    assert stats["total_requests"] == 2
+    assert stats["successful_requests"] == 1
+    assert stats["failed_requests"] == 1
+
+    # Install the pool and DAO as the module-level singletons used by the sync
+    monkeypatch.setattr(token_pool_module, "_token_pool", pool)
+    monkeypatch.setattr(token_dao_module, "_token_dao", dao)
+
+    await sync_token_stats_to_db()
+
+    stats_after_sync = await dao.get_token_stats(token_id)
+    assert stats_after_sync is not None
+    assert stats_after_sync["total_requests"] == 2
+    assert stats_after_sync["successful_requests"] == 1
+    assert stats_after_sync["failed_requests"] == 1
+
+
+def test_format_uptime_formats_seconds_minutes_and_hours():
+    """Check ``format_uptime`` output for a sub-minute and a multi-hour value."""
+    assert format_uptime(59) == "59秒"
+    assert format_uptime(3661) == "1小时 1分钟 1秒"
diff --git a/tests/test_admin_tokens.py b/tests/test_admin_tokens.py
new file mode 100644
index 0000000000000000000000000000000000000000..23752b24f42390c3861e4047ace9497a5e45b720
--- /dev/null
+++ b/tests/test_admin_tokens.py
@@ -0,0 +1,154 @@
+from urllib.parse import urlencode
+
+import pytest
+from jinja2 import Environment, FileSystemLoader
+from starlette.requests import Request
+
+from app.admin import api as admin_api
+from app.core.config import settings
+from app.services.token_automation import TokenMaintenanceSummary
+from app.services.token_importer import TokenImportSummary
+
+
+def _make_form_request(path: str, data: dict[str, str] | None = None) -> Request:
+    """Build a Starlette POST Request whose body is *data*, form-encoded.
+
+    The ASGI ``receive`` callable yields the encoded body once and an empty
+    body on any later call.
+    """
+    pending = [urlencode(data or {}, doseq=True).encode()]
+
+    async def receive():
+        # Hand out the payload on the first call only.
+        body = pending.pop() if pending else b""
+        return {"type": "http.request", "body": body, "more_body": False}
+
+    form_header = (b"content-type", b"application/x-www-form-urlencoded")
+    # ASGI HTTP scope describing a form-encoded POST to *path*.
+    scope = {
+        "type": "http",
+        "http_version": "1.1",
+        "method": "POST",
+        "scheme": "http",
+        "path": path,
+        "raw_path": path.encode(),
+        "query_string": b"",
+        "headers": [form_header],
+        "client": ("testclient", 50000),
+        "server": ("testserver", 80),
+    }
+
+    return Request(scope, receive)
+
+
+@pytest.mark.asyncio
+async def test_import_directory_uses_configured_source_dir_when_form_empty(
+    tmp_path,
+    monkeypatch,
+):
+    """An empty form makes the import endpoint use the configured source directory."""
+    source_dir = tmp_path / "tokens"
+    source_dir.mkdir()
+    monkeypatch.setattr(
+        settings,
+        "TOKEN_AUTO_IMPORT_SOURCE_DIR",
+        str(source_dir),
+    )
+
+    called: dict[str, object] = {}
+    # Stub that records the arguments it receives and reports one imported token.
+    async def fake_run_directory_import(
+        source_dir_arg,
+        *,
+        provider,
+        validate,
+    ):
+        called["source_dir"] = source_dir_arg
+        called["provider"] = provider
+        called["validate"] = validate
+        return TokenImportSummary(
+            source_dir=str(source_dir),
+            scanned_files=1,
+            imported_count=1,
+            duplicate_count=0,
+            invalid_json_count=0,
+            missing_token_count=0,
+            invalid_token_count=0,
+        )
+
+    import app.services.token_automation as token_automation
+    monkeypatch.setattr(
+        token_automation,
+        "run_directory_import",
+        fake_run_directory_import,
+    )
+    # The request carries no form fields at all.
+    response = await admin_api.import_tokens_from_directory_api(
+        _make_form_request("/admin/api/tokens/import-directory"),
+    )
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 200
+    assert called["source_dir"] == str(source_dir)
+    assert called["provider"] == "zai"
+    assert called["validate"] is True
+    assert "导入成功" in body
+
+
+@pytest.mark.asyncio
+async def test_run_maintenance_uses_configured_actions_when_form_empty(
+    monkeypatch,
+):
+    """An empty form makes the maintenance endpoint take its action flags from settings."""
+    monkeypatch.setattr(settings, "TOKEN_AUTO_REMOVE_DUPLICATES", True)
+    monkeypatch.setattr(settings, "TOKEN_AUTO_HEALTH_CHECK", False)
+    monkeypatch.setattr(settings, "TOKEN_AUTO_DELETE_INVALID", True)
+
+    called: dict[str, object] = {}
+    # Stub that records the keyword arguments it receives and returns a fixed summary.
+    async def fake_run_token_maintenance(
+        *,
+        provider,
+        remove_duplicates,
+        run_health_check,
+        delete_invalid_tokens,
+    ):
+        called["provider"] = provider
+        called["remove_duplicates"] = remove_duplicates
+        called["run_health_check"] = run_health_check
+        called["delete_invalid_tokens"] = delete_invalid_tokens
+        return TokenMaintenanceSummary(
+            provider=provider,
+            checked_count=2,
+            duplicate_removed_count=1,
+            valid_count=1,
+            guest_count=0,
+            invalid_count=1,
+            deleted_invalid_count=1,
+        )
+
+    import app.services.token_automation as token_automation
+    monkeypatch.setattr(
+        token_automation,
+        "run_token_maintenance",
+        fake_run_token_maintenance,
+    )
+    # The request carries no form fields at all.
+    response = await admin_api.run_token_maintenance_api(
+        _make_form_request("/admin/api/tokens/maintenance/run"),
+    )
+    body = response.body.decode("utf-8")
+
+    assert response.status_code == 200
+    assert called["provider"] == "zai"
+    assert called["remove_duplicates"] is True
+    assert called["run_health_check"] is False
+    assert called["delete_invalid_tokens"] is True
+    assert "维护完成" in body
+
+
+def test_tokens_template_compiles():
+    """Loading tokens.html succeeds, which requires Jinja2 to parse and compile it."""
+    env = Environment(loader=FileSystemLoader("app/templates"))
+    # NOTE(review): the loader path is relative to the working directory.
+    assert env.get_template("tokens.html") is not None
diff --git a/tests/test_dependency_metadata.py b/tests/test_dependency_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..47f6130b24dcd09a099f53109060f1ea95be561c
--- /dev/null
+++ b/tests/test_dependency_metadata.py
@@ -0,0 +1,14 @@
+from pathlib import Path
+
+
+ROOT = Path(__file__).resolve().parents[1]
+
+
+def test_requirements_enable_httpx_socks_support():
+    requirements = (ROOT / "requirements.txt").read_text(encoding="utf-8")
+    assert "httpx[http2,socks]==0.28.1" in requirements  # exact pin with both extras
+
+
+def test_pyproject_enable_httpx_socks_support():
+    pyproject = (ROOT / "pyproject.toml").read_text(encoding="utf-8")
+    assert '"httpx[http2,socks]==0.28.1"' in pyproject  # same pin, as a quoted TOML string
diff --git a/tests/test_glm45_real_token.py b/tests/test_glm45_real_token.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fef459176192a4ae3af5a56c9f7e2a80c983210
--- /dev/null
+++ b/tests/test_glm45_real_token.py
@@ -0,0 +1,33 @@
+import pytest
+
+from app.core.config import settings
+from app.core.upstream import UpstreamClient
+from app.models.schemas import Message, OpenAIRequest
+from tests.real_upstream_test_utils import (
+ assert_usage_present,
+ extract_content,
+ install_real_anonymous,
+)
+
+
+@pytest.mark.asyncio
+async def test_glm45_with_real_anonymous_request(monkeypatch):
+    """A non-streaming settings.GLM45_MODEL request must answer GLM45_OK and report usage."""
+    install_real_anonymous(monkeypatch)
+    client = UpstreamClient()
+    request = OpenAIRequest(
+        model=settings.GLM45_MODEL,
+        messages=[
+            Message(
+                role="user",
+                content="请只输出字符串 GLM45_OK,不要输出任何其他内容。",
+            )
+        ],
+        stream=False,
+    )
+    # NOTE(review): presumably reaches the real upstream as an anonymous user — confirm.
+    payload = await client.chat_completion(request)
+    content = extract_content(payload)
+
+    assert "GLM45_OK" in content
+    assert_usage_present(payload)
diff --git a/tests/test_glm46v_real_token.py b/tests/test_glm46v_real_token.py
new file mode 100644
index 0000000000000000000000000000000000000000..95734ca1b70e8e8077a4f65948a16e14a3de85ec
--- /dev/null
+++ b/tests/test_glm46v_real_token.py
@@ -0,0 +1,43 @@
+import pytest
+
+from app.core.config import settings
+from app.core.upstream import UpstreamClient
+from app.models.schemas import ContentPart, ImageUrl, Message, OpenAIRequest
+from tests.real_upstream_test_utils import (
+ RED_2X2_PNG_DATA_URL,
+ assert_usage_present,
+ extract_content,
+ install_real_auth,
+)
+
+
+@pytest.mark.asyncio
+async def test_glm46v_with_real_auth_token_and_image(monkeypatch):
+    """A text-plus-image request to settings.GLM46V_MODEL must answer RED_OK with usage."""
+    install_real_auth(monkeypatch)
+    client = UpstreamClient()
+    request = OpenAIRequest(
+        model=settings.GLM46V_MODEL,
+        messages=[
+            Message(
+                role="user",
+                content=[
+                    ContentPart(
+                        type="text",
+                        text="请判断这张图片的主色调。如果它是红色,只输出 RED_OK。",
+                    ),
+                    ContentPart(
+                        type="image_url",
+                        image_url=ImageUrl(url=RED_2X2_PNG_DATA_URL),
+                    ),
+                ],
+            )
+        ],
+        stream=False,
+    )
+    # NOTE(review): presumably reaches the real upstream with a real user token — confirm.
+    payload = await client.chat_completion(request)
+    content = extract_content(payload)
+
+    assert "RED_OK" in content
+    assert_usage_present(payload)
diff --git a/tests/test_glm47_request_bootstrap.py b/tests/test_glm47_request_bootstrap.py
new file mode 100644
index 0000000000000000000000000000000000000000..3940370be4b783dec6f626cf0c955cb1ba5e518e
--- /dev/null
+++ b/tests/test_glm47_request_bootstrap.py
@@ -0,0 +1,444 @@
+from urllib.parse import parse_qs, urlparse
+
+import pytest
+
+from app.core import upstream as upstream_module
+from app.core.upstream import UpstreamClient
+from app.models.schemas import ContentPart, ImageUrl, Message, OpenAIRequest
+
+FAKE_HEADERS = {  # canned headers that the fake_headers stubs in this module copy
+    "Content-Type": "application/json",
+    "Accept": "application/json, text/event-stream",
+    "Connection": "keep-alive",
+    "Cache-Control": "no-cache",
+    "User-Agent": (
+        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+        "AppleWebKit/537.36 (KHTML, like Gecko) "
+        "Chrome/144.0.0.0 Safari/537.36"
+    ),
+    "Accept-Language": "zh-CN",
+    "X-FE-Version": "prod-fe-test",
+    "Origin": "https://chat.z.ai",
+    "Referer": "https://chat.z.ai/",
+}
+
+
+def _make_request(model: str) -> OpenAIRequest:
+    """Return a streaming request for *model* holding one fixed user message."""
+    user_turn = Message(role="user", content="请用一句话回答:你好")
+    request = OpenAIRequest(model=model, messages=[user_turn], stream=True)
+
+    return request
+
+
+async def _fake_get_auth_info(self, excluded_tokens=None, excluded_guest_user_ids=None):
+    """Stand in for ``UpstreamClient.get_auth_info`` in these tests.
+
+    Both exclusion arguments are ignored; the same authenticated-pool
+    credentials are returned on every call.
+    """
+    identity = {"token": "auth-token", "user_id": "user-123", "username": "User"}
+    origin = {"auth_mode": "authenticated", "token_source": "auth_pool"}
+    return {**identity, **origin, "guest_user_id": None}
+
+
+@pytest.mark.asyncio
+async def test_glm47_request_bootstraps_chat_and_uses_browser_signature(monkeypatch):
+    """GLM-4.7 must create an upstream chat first and build the request around its id."""
+    create_chat_calls: list[dict] = []
+    browser_type_calls: list[str | None] = []
+    # Header stub: records the requested browser type and points Referer at the chat.
+    def fake_headers(chat_id: str = "", browser_type=None):
+        browser_type_calls.append(browser_type)
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fake_create_chat(
+        self,
+        *,
+        prompt,
+        model,
+        token,
+        headers,
+        enable_thinking,
+        web_search,
+        user_message_id,
+        files,
+        feature_entries,
+        mcp_servers,
+    ):
+        create_chat_calls.append(
+            {
+                "prompt": prompt,
+                "model": model,
+                "token": token,
+                "user_agent": headers["User-Agent"],
+                "enable_thinking": enable_thinking,
+                "web_search": web_search,
+                "user_message_id": user_message_id,
+                "files": files,
+                "feature_entries": feature_entries,
+                "mcp_servers": mcp_servers,
+            }
+        )
+        return "persisted-chat-id"
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    transformed = await client.transform_request(_make_request("GLM-4.7"))
+    parsed_url = urlparse(transformed["url"])
+    query = parse_qs(parsed_url.query)
+
+    assert len(create_chat_calls) == 1
+    assert create_chat_calls[0]["prompt"] == "请用一句话回答:你好"
+    assert create_chat_calls[0]["model"] == "glm-4.7"
+    assert create_chat_calls[0]["token"] == "auth-token"
+    assert create_chat_calls[0]["user_agent"] == FAKE_HEADERS["User-Agent"]
+    assert create_chat_calls[0]["enable_thinking"] is False
+    assert create_chat_calls[0]["web_search"] is False
+    assert create_chat_calls[0]["files"] is None
+    assert create_chat_calls[0]["feature_entries"] is None
+    assert create_chat_calls[0]["mcp_servers"] is None
+    assert create_chat_calls[0]["user_message_id"]
+    assert browser_type_calls == ["chrome"]
+    assert transformed["chat_id"] == "persisted-chat-id"
+    assert transformed["headers"]["Accept"] == "*/*"
+    assert transformed["headers"]["Referer"] == "https://chat.z.ai/c/persisted-chat-id"
+    assert query["current_url"] == ["https://chat.z.ai/c/persisted-chat-id"]
+    assert query["pathname"] == ["/c/persisted-chat-id"]
+    assert query["user_agent"] == [FAKE_HEADERS["User-Agent"]]
+    assert query["timezone"] == ["Asia/Shanghai"]
+    assert transformed["body"]["chat_id"] == "persisted-chat-id"
+    assert transformed["body"]["current_user_message_id"] == (
+        create_chat_calls[0]["user_message_id"]
+    )
+    assert transformed["body"]["features"]["enable_thinking"] is False
+    assert transformed["body"]["background_tasks"] == {
+        "title_generation": True,
+        "tags_generation": True,
+    }
+    assert "session_id" not in transformed["body"]
+    assert "model_item" not in transformed["body"]
+
+
+@pytest.mark.asyncio
+async def test_glm47_thinking_defaults_to_enable_thinking(monkeypatch):
+    """GLM-4.7-Thinking maps to upstream model glm-4.7 with enable_thinking switched on."""
+    create_chat_calls: list[dict] = []
+
+    def fake_headers(chat_id: str = "", browser_type=None):
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fake_create_chat(
+        self,
+        *,
+        prompt,
+        model,
+        token,
+        headers,
+        enable_thinking,
+        web_search,
+        user_message_id,
+        files,
+        feature_entries,
+        mcp_servers,
+    ):
+        create_chat_calls.append(
+            {
+                "model": model,
+                "enable_thinking": enable_thinking,
+                "web_search": web_search,
+                "user_message_id": user_message_id,
+                "files": files,
+                "feature_entries": feature_entries,
+                "mcp_servers": mcp_servers,
+            }
+        )
+        return "thinking-chat-id"
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    transformed = await client.transform_request(_make_request("GLM-4.7-Thinking"))
+
+    assert len(create_chat_calls) == 1
+    assert create_chat_calls[0]["model"] == "glm-4.7"
+    assert create_chat_calls[0]["enable_thinking"] is True
+    assert create_chat_calls[0]["web_search"] is False
+    assert create_chat_calls[0]["files"] is None
+    assert create_chat_calls[0]["feature_entries"] is None
+    assert create_chat_calls[0]["mcp_servers"] is None
+    assert create_chat_calls[0]["user_message_id"]
+    assert transformed["body"]["features"]["enable_thinking"] is True
+    assert transformed["body"]["current_user_message_id"] == (
+        create_chat_calls[0]["user_message_id"]
+    )
+
+
+@pytest.mark.asyncio
+async def test_non_glm47_request_keeps_legacy_request_shape(monkeypatch):
+    """GLM-4.5 must not create an upstream chat and keeps session_id and model_item."""
+    def fake_headers(chat_id: str = "", browser_type=None):
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fail_create_chat(self, **kwargs):
+        raise AssertionError("GLM-4.5 不应触发 create_chat")  # "must not trigger create_chat"
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    transformed = await client.transform_request(_make_request("GLM-4.5"))
+    query = parse_qs(urlparse(transformed["url"]).query)
+
+    assert transformed["headers"]["Accept"] == "application/json"
+    assert transformed["chat_id"] != "persisted-chat-id"
+    assert "user_agent" not in query
+    assert "session_id" in transformed["body"]
+    assert transformed["body"]["model_item"]["name"] == "GLM-4.5"
+
+
+@pytest.mark.asyncio
+async def test_glm5_defaults_to_enable_thinking(monkeypatch):
+    """GLM-5 skips chat creation and defaults to enable_thinking and preview_mode on."""
+    def fake_headers(chat_id: str = "", browser_type=None):
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fail_create_chat(self, **kwargs):
+        raise AssertionError("GLM-5 不应触发 create_chat")  # "must not trigger create_chat"
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    transformed = await client.transform_request(_make_request("GLM-5"))
+    query = parse_qs(urlparse(transformed["url"]).query)
+
+    assert transformed["headers"]["Accept"] == "application/json"
+    assert transformed["body"]["model"] == "glm-5"
+    assert transformed["body"]["features"]["enable_thinking"] is True
+    assert transformed["body"]["features"]["preview_mode"] is True
+    assert "session_id" in transformed["body"]
+    assert "user_agent" not in query
+
+
+@pytest.mark.asyncio
+async def test_glm5_allows_explicitly_disabling_thinking(monkeypatch):
+    """An explicit enable_thinking=False on a GLM-5 request reaches the upstream body."""
+    def fake_headers(chat_id: str = "", browser_type=None):
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fail_create_chat(self, **kwargs):
+        raise AssertionError("GLM-5 不应触发 create_chat")  # "must not trigger create_chat"
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fail_create_chat)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    request = OpenAIRequest(
+        model="GLM-5",
+        messages=[Message(role="user", content="请用一句话回答:你好")],
+        stream=False,
+        enable_thinking=False,
+    )
+
+    transformed = await client.transform_request(request)
+
+    assert transformed["body"]["features"]["enable_thinking"] is False
+
+
+@pytest.mark.asyncio
+async def test_glm46v_uses_persisted_chat_and_visual_features(monkeypatch):
+    """GLM-4.6V with an image: upload it, create the chat, then reference the file id."""
+    create_chat_calls: list[dict] = []
+    upload_calls: list[dict] = []
+    browser_type_calls: list[str | None] = []
+
+    def fake_headers(chat_id: str = "", browser_type=None):
+        browser_type_calls.append(browser_type)
+        headers = dict(FAKE_HEADERS)
+        headers["Referer"] = (
+            f"https://chat.z.ai/c/{chat_id}"
+            if chat_id
+            else FAKE_HEADERS["Referer"]
+        )
+        return headers
+
+    async def fake_create_chat(
+        self,
+        *,
+        prompt,
+        model,
+        token,
+        headers,
+        enable_thinking,
+        web_search,
+        user_message_id,
+        files,
+        feature_entries,
+        mcp_servers,
+    ):
+        create_chat_calls.append(
+            {
+                "prompt": prompt,
+                "model": model,
+                "token": token,
+                "user_agent": headers["User-Agent"],
+                "enable_thinking": enable_thinking,
+                "web_search": web_search,
+                "user_message_id": user_message_id,
+                "files": files,
+                "feature_entries": feature_entries,
+                "mcp_servers": mcp_servers,
+            }
+        )
+        return "vision-chat-id"
+
+    async def fake_upload_image(
+        self,
+        data_url,
+        chat_id,
+        token,
+        user_id,
+        auth_mode="authenticated",
+    ):
+        upload_calls.append(
+            {
+                "data_url": data_url,
+                "chat_id": chat_id,
+                "token": token,
+                "user_id": user_id,
+                "auth_mode": auth_mode,
+            }
+        )
+        return {
+            "type": "image",
+            "file": {
+                "id": "file-id",
+                "user_id": user_id,
+                "filename": "file.png",
+                "data": {},
+                "meta": {
+                    "name": "file.png",
+                    "content_type": "image/png",
+                    "size": 4,
+                    "data": {},
+                },
+                "created_at": 1,
+                "updated_at": 1,
+            },
+            "id": "file-id",
+            "url": "/api/v1/files/file-id/content",
+            "name": "file.png",
+            "status": "uploaded",
+            "size": 4,
+            "error": "",
+            "itemId": "item-id",
+            "media": "image",
+        }
+
+    monkeypatch.setattr(UpstreamClient, "get_auth_info", _fake_get_auth_info)
+    monkeypatch.setattr(UpstreamClient, "_create_upstream_chat", fake_create_chat)
+    monkeypatch.setattr(UpstreamClient, "upload_image", fake_upload_image)
+    monkeypatch.setattr(upstream_module, "get_dynamic_headers", fake_headers)
+    client = UpstreamClient()
+    request = OpenAIRequest(
+        model="GLM-4.6V",
+        messages=[
+            Message(
+                role="user",
+                content=[
+                    ContentPart(type="text", text="请判断图片主色调"),
+                    ContentPart(
+                        type="image_url",
+                        image_url=ImageUrl(url="data:image/png;base64,AAAA"),
+                    ),
+                ],
+            )
+        ],
+        stream=False,
+    )
+
+    transformed = await client.transform_request(request)
+    query = parse_qs(urlparse(transformed["url"]).query)
+
+    assert len(create_chat_calls) == 1
+    assert create_chat_calls[0]["prompt"] == "请判断图片主色调"
+    assert create_chat_calls[0]["model"] == "glm-4.6v"
+    assert create_chat_calls[0]["token"] == "auth-token"
+    assert create_chat_calls[0]["user_agent"] == FAKE_HEADERS["User-Agent"]
+    assert create_chat_calls[0]["enable_thinking"] is True
+    assert create_chat_calls[0]["web_search"] is False
+    assert (
+        create_chat_calls[0]["feature_entries"]
+        == upstream_module.GLM46V_SELECTED_FEATURES
+    )
+    assert create_chat_calls[0]["mcp_servers"] == upstream_module.GLM46V_MCP_SERVERS
+    assert create_chat_calls[0]["user_message_id"]
+    assert create_chat_calls[0]["files"][0]["id"] == "file-id"
+    assert create_chat_calls[0]["files"][0]["ref_user_msg_id"] == (
+        create_chat_calls[0]["user_message_id"]
+    )
+    assert upload_calls == [
+        {
+            "data_url": "data:image/png;base64,AAAA",
+            "chat_id": "",
+            "token": "auth-token",
+            "user_id": "user-123",
+            "auth_mode": "authenticated",
+        }
+    ]
+    assert browser_type_calls == ["chrome"]
+    assert transformed["chat_id"] == "vision-chat-id"
+    assert transformed["headers"]["Accept"] == "*/*"
+    assert transformed["headers"]["Referer"] == "https://chat.z.ai/c/vision-chat-id"
+    assert query["current_url"] == ["https://chat.z.ai/c/vision-chat-id"]
+    assert query["pathname"] == ["/c/vision-chat-id"]
+    assert query["user_agent"] == [FAKE_HEADERS["User-Agent"]]
+    assert transformed["body"]["current_user_message_id"] == (
+        create_chat_calls[0]["user_message_id"]
+    )
+    assert transformed["body"]["features"]["enable_thinking"] is True
+    assert transformed["body"]["features"]["preview_mode"] is False
+    assert "features" not in transformed["body"]["features"]
+    assert transformed["body"]["mcp_servers"] == upstream_module.GLM46V_MCP_SERVERS
+    assert transformed["body"]["files"][0]["id"] == "file-id"
+    assert transformed["body"]["files"][0]["ref_user_msg_id"] == (
+        create_chat_calls[0]["user_message_id"]
+    )
+    assert transformed["body"]["messages"][0]["content"][1]["image_url"]["url"] == (
+        "file-id"
+    )
+    assert "session_id" not in transformed["body"]
diff --git a/tests/test_glm5_real_token.py b/tests/test_glm5_real_token.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa6de719cc0ab2ca3b7ab6c9508c463978159a4c
--- /dev/null
+++ b/tests/test_glm5_real_token.py
@@ -0,0 +1,33 @@
+import pytest
+
+from app.core.config import settings
+from app.core.upstream import UpstreamClient
+from app.models.schemas import Message, OpenAIRequest
+from tests.real_upstream_test_utils import (
+ assert_usage_present,
+ extract_content,
+ install_real_anonymous,
+)
+
+
+@pytest.mark.asyncio
+async def test_glm5_with_real_anonymous_request(monkeypatch):
+    """A non-streaming settings.GLM5_MODEL request must answer GLM5_OK and report usage."""
+    install_real_anonymous(monkeypatch)
+    client = UpstreamClient()
+    request = OpenAIRequest(
+        model=settings.GLM5_MODEL,
+        messages=[
+            Message(
+                role="user",
+                content="请只输出字符串 GLM5_OK,不要输出任何其他内容。",
+            )
+        ],
+        stream=False,
+    )
+    # NOTE(review): presumably reaches the real upstream as an anonymous user — confirm.
+    payload = await client.chat_completion(request)
+    content = extract_content(payload)
+
+    assert "GLM5_OK" in content
+    assert_usage_present(payload)
diff --git a/tests/test_guest_pool_concurrency.py b/tests/test_guest_pool_concurrency.py
new file mode 100644
index 0000000000000000000000000000000000000000..b85d40603b94a048a8bdf549adcfed653d4be544
--- /dev/null
+++ b/tests/test_guest_pool_concurrency.py
@@ -0,0 +1,222 @@
+import asyncio
+import types
+from dataclasses import dataclass, field
+from unittest.mock import AsyncMock
+
+import pytest
+
+from app.core import upstream as upstream_module
+from app.core.upstream import UpstreamClient
+from app.models.schemas import Message, OpenAIRequest
+from app.utils.guest_session_pool import GuestSession, GuestSessionPool
+
+POOL_SIZE = 8  # guest sessions in the pool for the all-success test
+REQUEST_COUNT = 64  # chat_completion calls gathered at once in that test
+REQUEST_DELAY_SECONDS = 0.03  # sleep inside each fake POST in that test
+FAILURE_POOL_SIZE = 4  # guest sessions in the pool for the failure-recovery test
+FAILURE_REQUEST_COUNT = 24  # chat_completion calls gathered at once in that test
+FAILURE_DELAY_SECONDS = 0.02  # sleep inside each fake POST in that test
+
+
+def _make_session(user_id: str, token_suffix: str) -> GuestSession:
+    """Build a GuestSession for *user_id* with token ``token-<token_suffix>``."""
+    token = f"token-{token_suffix}"
+    username = f"Guest-{user_id}"
+
+    return GuestSession(token=token, user_id=user_id, username=username)
+
+
+def _make_request() -> OpenAIRequest:
+    """Return a non-streaming GLM-4.5 request with the single user message "ping"."""
+    ping = Message(role="user", content="ping")
+    request = OpenAIRequest(model="GLM-4.5", messages=[ping], stream=False)
+
+    return request
+
+
+@dataclass
+class LoadState:
+    active_posts: int = 0  # fake POSTs currently in flight
+    peak_posts: int = 0  # highest value active_posts has reached
+    failed_once: set[str] = field(default_factory=set)  # user ids already answered with a 401
+
+
+class FakeResponse:
+    """Response stub carrying a status code and body text."""
+    def __init__(self, status_code: int, text: str):
+        self.status_code, self.text = status_code, text
+
+    @property
+    def is_success(self) -> bool:
+        return 300 > self.status_code >= 200
+
+
+def _build_fake_async_client(handler):
+    """Return a stand-in for httpx.AsyncClient whose post() awaits *handler*."""
+    class FakeAsyncClient:
+        def __init__(self, *args, **kwargs):
+            """Accept and ignore any constructor arguments."""
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False  # falsy: exceptions raised inside the block propagate
+
+        async def post(self, url, headers=None, json=None):
+            return await handler(url, headers if headers else {}, json if json else {})
+
+    return FakeAsyncClient
+
+
+async def _build_pool(monkeypatch, pool_size: int) -> GuestSessionPool:
+    """Return an initialized pool whose sessions are guest-1, guest-2, ... stubs."""
+    pool = GuestSessionPool(pool_size=pool_size)
+    counter = 0
+    async def fake_create_session() -> GuestSession:
+        nonlocal counter
+        counter += 1
+        return _make_session(f"guest-{counter}", str(counter))
+
+    monkeypatch.setattr(pool, "_create_session", fake_create_session)
+    monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None))
+    monkeypatch.setattr(pool, "_delete_all_chats", AsyncMock(return_value=True))
+    await pool.initialize()
+    await asyncio.sleep(0)  # yield to the event loop once before handing the pool out
+    return pool
+
+
+def _bind_guest_request_flow(
+    client,
+    pool: GuestSessionPool,
+    assigned_user_ids: list[str],
+):
+    """Replace *client*'s transform_request/transform_response with pool-backed fakes."""
+    async def fake_transform_request(
+        self,
+        request,
+        excluded_tokens=None,
+        excluded_guest_user_ids=None,
+    ):
+        session = await pool.acquire(exclude_user_ids=excluded_guest_user_ids)
+        assigned_user_ids.append(session.user_id)
+        return {
+            "url": f"https://upstream.test/{session.user_id}",
+            "headers": {"x-guest-user-id": session.user_id},
+            "body": {"model": request.model},
+            "token": session.token,
+            "chat_id": f"chat-{session.user_id}",
+            "model": request.model,
+            "user_id": session.user_id,
+            "auth_mode": "guest",
+            "token_source": "guest_pool",
+            "guest_user_id": session.user_id,
+        }
+
+    async def fake_transform_response(self, response, request, transformed):
+        return {
+            "ok": response.is_success,
+            "guest_user_id": transformed["guest_user_id"],
+            "status_code": response.status_code,
+        }
+    client.transform_request = types.MethodType(fake_transform_request, client)
+    client.transform_response = types.MethodType(fake_transform_response, client)
+
+
+def _patch_upstream_globals(monkeypatch, pool: GuestSessionPool, async_client_cls):
+    monkeypatch.setattr(upstream_module, "get_guest_session_pool", lambda: pool)  # serve the test pool
+    monkeypatch.setattr(upstream_module, "get_token_pool", lambda: None)  # no user-token pool
+    monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True)  # guest pool enabled
+    monkeypatch.setattr(upstream_module.httpx, "AsyncClient", async_client_cls)  # fake HTTP client
+
+
+def _build_handler(
+    delay: float,
+    state: LoadState,
+    failure_users: set[str] | None = None,
+):
+    """Build a fake POST handler that records load in *state* and fails listed users once."""
+    lock = asyncio.Lock()
+    doomed = failure_users if failure_users else set()
+
+    async def handler(url, headers, body):
+        guest = headers["x-guest-user-id"]
+        async with lock:
+            state.active_posts += 1
+            state.peak_posts = max(state.peak_posts, state.active_posts)
+        try:
+            await asyncio.sleep(delay)
+            if guest not in doomed or guest in state.failed_once:
+                return FakeResponse(200, "{}")
+            state.failed_once.add(guest)
+            return FakeResponse(401, '{"message":"expired"}')
+        finally:
+            async with lock:
+                state.active_posts -= 1
+
+    return handler
+
+
+@pytest.mark.asyncio
+async def test_guest_pool_handles_many_concurrent_requests(monkeypatch):
+    """REQUEST_COUNT concurrent calls all succeed, use every session, and leave none busy."""
+    pool = await _build_pool(monkeypatch, POOL_SIZE)
+    try:
+        assigned_user_ids: list[str] = []
+        state = LoadState()
+        client = UpstreamClient()
+        handler = _build_handler(REQUEST_DELAY_SECONDS, state)
+
+        _bind_guest_request_flow(client, pool, assigned_user_ids)
+        _patch_upstream_globals(monkeypatch, pool, _build_fake_async_client(handler))
+
+        # Start every request before awaiting any of them.
+        results = await asyncio.gather(
+            *(client.chat_completion(_make_request()) for _ in range(REQUEST_COUNT))
+        )
+        pool_status = pool.get_pool_status()
+
+        assert all(result.get("ok") is True for result in results)
+        assert len(set(assigned_user_ids)) == POOL_SIZE
+        assert state.peak_posts >= POOL_SIZE
+        assert pool_status == {
+            "total_sessions": POOL_SIZE,
+            "valid_sessions": POOL_SIZE,
+            "available_sessions": POOL_SIZE,
+            "busy_sessions": 0,
+            "expired_sessions": 0,
+        }
+    finally:
+        # Close the pool even when an assertion above fails, so it is never leaked.
+        await pool.close()
+
+
+@pytest.mark.asyncio
+async def test_guest_pool_recovers_from_failures_under_concurrency(monkeypatch):
+    """Sessions that get a 401 are replaced, and every concurrent request still succeeds."""
+    pool = await _build_pool(monkeypatch, FAILURE_POOL_SIZE)
+    try:
+        assigned_user_ids: list[str] = []
+        state = LoadState()
+        client = UpstreamClient()
+        failure_users = {"guest-1", "guest-2"}
+        handler = _build_handler(FAILURE_DELAY_SECONDS, state, failure_users)
+        _bind_guest_request_flow(client, pool, assigned_user_ids)
+        _patch_upstream_globals(monkeypatch, pool, _build_fake_async_client(handler))
+
+        pending = (
+            client.chat_completion(_make_request())
+            for _ in range(FAILURE_REQUEST_COUNT)
+        )
+        results = await asyncio.gather(*pending)
+        pool_status = pool.get_pool_status()
+        current_user_ids = set(pool._sessions)
+
+        assert all(result.get("ok") is True for result in results)
+        assert state.failed_once == failure_users
+        assert "guest-1" not in current_user_ids
+        assert "guest-2" not in current_user_ids
+        assert pool_status["busy_sessions"] == 0
+        assert pool_status["valid_sessions"] == FAILURE_POOL_SIZE
+    finally:
+        # Close the pool even when an assertion above fails, so it is never leaked.
+        await pool.close()
diff --git a/tests/test_guest_session_pool.py b/tests/test_guest_session_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..27800d9ed503699613c8f27b1d393404c97a84ea
--- /dev/null
+++ b/tests/test_guest_session_pool.py
@@ -0,0 +1,97 @@
+import asyncio
+from unittest.mock import AsyncMock, Mock
+
+import pytest
+
+from app.utils import guest_session_pool as guest_pool_module
+from app.utils.guest_session_pool import GuestSession, GuestSessionPool
+
+
+def _make_session(user_id: str, token_suffix: str) -> GuestSession:
+    """Return a session named ``Guest-<user_id>`` holding ``token-<token_suffix>``."""
+    session_token = f"token-{token_suffix}"
+    display_name = f"Guest-{user_id}"
+
+    return GuestSession(token=session_token, user_id=user_id, username=display_name)
+
+
+@pytest.mark.asyncio
+async def test_ensure_capacity_returns_when_only_duplicate_user_ids_are_created(
+    monkeypatch,
+):
+    """_ensure_capacity() must return even if every new session has the same user id."""
+    pool = GuestSessionPool(pool_size=2)
+    create_calls = 0
+
+    async def fake_create_session() -> GuestSession:
+        nonlocal create_calls
+        create_calls += 1
+        return _make_session("duplicate-user", str(create_calls))
+
+    monkeypatch.setattr(pool, "_create_session", fake_create_session)
+    # wait_for fails the test if _ensure_capacity() has not returned within 0.2 s.
+    await asyncio.wait_for(pool._ensure_capacity(), timeout=0.2)
+    assert create_calls >= 1
+    assert set(pool._sessions) == {"duplicate-user"}
+    assert len(pool._sessions) == 1
+
+
+@pytest.mark.asyncio
+async def test_initialize_logs_unique_session_count_when_results_contain_duplicates(
+    monkeypatch,
+):
+    """initialize() keeps one session per user id and logs that unique count (2 here)."""
+    pool = GuestSessionPool(pool_size=3)
+    sessions = [
+        _make_session("user-1", "1"),
+        _make_session("user-1", "2"),
+        _make_session("user-2", "3"),
+        _make_session("user-1", "4"),
+        _make_session("user-2", "5"),
+        _make_session("user-1", "6"),
+        _make_session("user-2", "7"),
+        _make_session("user-1", "8"),
+        _make_session("user-2", "9"),
+    ]
+    info_mock = Mock()
+
+    async def fake_create_session() -> GuestSession:
+        return sessions.pop(0)
+
+    monkeypatch.setattr(pool, "_create_session", fake_create_session)
+    monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None))
+    monkeypatch.setattr(guest_pool_module.logger, "info", info_mock)
+    monkeypatch.setattr(guest_pool_module.logger, "warning", Mock())
+    await pool.initialize()
+    await asyncio.sleep(0)  # yield to the event loop once before inspecting the pool
+
+    assert set(pool._sessions) == {"user-1", "user-2"}
+    assert any(
+        call.args == ("✅ 匿名会话池初始化完成: 2 个会话",)
+        for call in info_mock.call_args_list
+    )
+
+
+@pytest.mark.asyncio
+async def test_acquire_skips_duplicate_excluded_session_without_overwriting_pool(
+    monkeypatch,
+):
+    """acquire() excluding user-1 returns a new user-2 session and keeps the seeded user-1."""
+    pool = GuestSessionPool(pool_size=2)
+    existing = _make_session("user-1", "seed")
+    pool._sessions[existing.user_id] = existing
+    created_sessions = [
+        _make_session("user-1", "duplicate"),
+        _make_session("user-2", "fresh"),
+    ]
+
+    async def fake_create_session() -> GuestSession:
+        return created_sessions.pop(0)
+
+    monkeypatch.setattr(pool, "_create_session", fake_create_session)
+    acquired = await pool.acquire(exclude_user_ids={"user-1"})
+
+    assert acquired.user_id == "user-2"
+    assert acquired.active_requests == 1
+    assert set(pool._sessions) == {"user-1", "user-2"}
+    assert pool._sessions["user-1"].token == "token-seed"
diff --git a/tests/test_request_logging.py b/tests/test_request_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d46e64a936975f4f72f77b11bb1e43a9fc100b
--- /dev/null
+++ b/tests/test_request_logging.py
@@ -0,0 +1,48 @@
+from app.utils.request_logging import (
+ extract_claude_usage,
+ extract_openai_usage,
+)
+
+
+def test_extract_openai_usage_supports_cached_prompt_details():
+    """cached_tokens under prompt_tokens_details is reported as cache_read_tokens."""
+    response = {
+        "usage": {
+            "prompt_tokens": 120,
+            "completion_tokens": 45,
+            "total_tokens": 165,
+            "prompt_tokens_details": {
+                "cached_tokens": 32,
+            },
+        }
+    }
+    expected = {
+        "input_tokens": 120,
+        "output_tokens": 45,
+        "cache_creation_tokens": 0,
+        "cache_read_tokens": 32,
+        "total_tokens": 165,
+    }
+
+    assert extract_openai_usage(response) == expected
+
+
+def test_extract_claude_usage_supports_cache_token_fields():
+    """Both cache counters are surfaced; expected total is 200 + 80 + 64 + 48 = 392."""
+    response = {
+        "usage": {
+            "input_tokens": 200,
+            "output_tokens": 80,
+            "cache_creation_input_tokens": 64,
+            "cache_read_input_tokens": 48,
+        }
+    }
+    expected = {
+        "input_tokens": 200,
+        "output_tokens": 80,
+        "cache_creation_tokens": 64,
+        "cache_read_tokens": 48,
+        "total_tokens": 392,
+    }
+
+    assert extract_claude_usage(response) == expected
diff --git a/tests/test_token_automation.py b/tests/test_token_automation.py
new file mode 100644
index 0000000000000000000000000000000000000000..37445758bf571a54eb610ea77259f926ff13de0b
--- /dev/null
+++ b/tests/test_token_automation.py
@@ -0,0 +1,51 @@
+import pytest
+
+from app.services.token_automation import run_token_maintenance
+from app.services.token_dao import TokenDAO
+from app.utils.token_pool import ZAITokenValidator
+
+
+@pytest.mark.asyncio
+async def test_run_token_maintenance_deletes_invalid_tokens_after_validation(
+    tmp_path,
+    monkeypatch,
+):
+    """Maintenance with delete_invalid_tokens=True should leave only the valid user token."""
+    # Canned validator results, keyed by token value.
+    verdicts = {
+        "token-valid": ("user", True, None),
+        "token-guest": ("guest", False, "guest token"),
+        "token-invalid": ("unknown", False, "token expired"),
+    }
+
+    async def fake_validate_token(cls, token):
+        return verdicts[token]
+
+    store = TokenDAO(str(tmp_path / "tokens.db"))
+    await store.init_database()
+    for raw_token in verdicts:
+        await store.add_token("zai", raw_token, validate=False)
+
+    # Replace ZAITokenValidator.validate_token with the canned lookup above.
+    patched_validator = classmethod(fake_validate_token)
+    monkeypatch.setattr(ZAITokenValidator, "validate_token", patched_validator)
+
+    summary = await run_token_maintenance(
+        provider="zai",
+        remove_duplicates=False,
+        run_health_check=False,
+        delete_invalid_tokens=True,
+        dao=store,
+        pool=None,
+    )
+
+    assert summary.checked_count == 3
+    assert summary.valid_count == 1
+    assert summary.guest_count == 1
+    assert summary.invalid_count == 1
+    # Two deletions are expected: the guest token and the expired one.
+    assert summary.deleted_invalid_count == 2
+
+    leftover = await store.get_tokens_by_provider("zai", enabled_only=False)
+    assert [row["token"] for row in leftover] == ["token-valid"]
+    assert leftover[0]["token_type"] == "user"
diff --git a/tests/test_token_importer.py b/tests/test_token_importer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5d34c50dd84f9478cc194857a10b91fcb3218d
--- /dev/null
+++ b/tests/test_token_importer.py
@@ -0,0 +1,70 @@
+import json
+
+import pytest
+
+from app.services.token_dao import TokenDAO
+from app.services.token_importer import import_tokens_from_directory
+
+
+@pytest.mark.asyncio
+async def test_import_tokens_from_directory_handles_duplicates_and_invalid_files(
+    tmp_path,
+):
+    """Import a mixed directory and check every counter on the returned summary.
+
+    The directory holds five files: two distinct tokens, a repeat of
+    "token-alpha", a payload without a "token" key, and one unparseable file.
+    """
+    source_dir = tmp_path / "source_tokens"
+    source_dir.mkdir()
+
+    # Well-formed JSON fixtures keyed by file name; dicts keep insertion
+    # order, so the files are written in the order declared here.
+    json_fixtures = {
+        "token_valid_1.json": {
+            "email": "alpha@example.com",
+            "token": "token-alpha",
+            "token_source": "context.cookie:token",
+        },
+        "token_valid_2.json": {
+            "email": "beta@example.com",
+            "token": "token-beta",
+            "token_source": "context.cookie:token",
+        },
+        "token_duplicate.json": {
+            "email": "alpha-dup@example.com",
+            "token": "token-alpha",
+        },
+        "token_missing.json": {"email": "missing@example.com"},
+    }
+    for file_name, payload in json_fixtures.items():
+        target = source_dir / file_name
+        target.write_text(json.dumps(payload), encoding="utf-8")
+
+    # The fifth file is deliberately not valid JSON.
+    broken_file = source_dir / "token_invalid.json"
+    broken_file.write_text("{invalid json", encoding="utf-8")
+
+    # Fresh token database file under tmp_path.
+    store = TokenDAO(str(tmp_path / "tokens.db"))
+    await store.init_database()
+
+    summary = await import_tokens_from_directory(
+        source_dir,
+        provider="zai",
+        validate=False,
+        dao=store,
+    )
+
+    assert summary.scanned_files == 5
+    assert summary.imported_count == 2
+    assert summary.duplicate_count == 1
+    assert summary.missing_token_count == 1
+    assert summary.invalid_json_count == 1
+    # No file is expected to be counted as an invalid token here.
+    assert summary.invalid_token_count == 0
+
+    # Only the two distinct token values should end up stored.
+    rows = await store.get_tokens_by_provider("zai", enabled_only=False)
+    stored_values = {row["token"] for row in rows}
+    assert stored_values == {"token-alpha", "token-beta"}
diff --git a/tests/test_upstream_dual_pool.py b/tests/test_upstream_dual_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..2989a736bb67f9df7873d371b963d743f33fae3a
--- /dev/null
+++ b/tests/test_upstream_dual_pool.py
@@ -0,0 +1,407 @@
+import asyncio
+import types
+from dataclasses import dataclass
+from unittest.mock import AsyncMock
+
+import pytest
+
+from app.core import upstream as upstream_module
+from app.core.upstream import UpstreamClient
+from app.models.schemas import Message, OpenAIRequest
+from app.utils.guest_session_pool import GuestSession, GuestSessionPool
+
+AUTH_POOL_SIZE = 2  # NOTE(review): not referenced anywhere else in this test module — confirm it is still needed
+GUEST_POOL_SIZE = 2  # pool_size handed to _build_guest_pool by the dual-pool tests
+AUTH_REQUEST_COUNT = 6  # number of concurrent chat_completion calls in the auth-first test
+MIXED_REQUEST_DELAY = 0.01  # seconds the fake upstream handler sleeps before responding
+
+
+def _make_request() -> OpenAIRequest:
+    """Build the minimal chat request used by these tests."""
+    ping = Message(role="user", content="ping")
+    request_fields = {"model": "GLM-4.5", "messages": [ping], "stream": False}
+    # Non-streaming request carrying a single user message.
+    return OpenAIRequest(**request_fields)
+
+
+def _make_guest_session(user_id: str) -> GuestSession:
+    """Build a GuestSession whose token and username are derived from user_id."""
+    derived_token = f"guest-token-{user_id}"
+    display_name = f"Guest-{user_id}"
+    session = GuestSession(token=derived_token, user_id=user_id, username=display_name)
+    return session
+
+
+@dataclass
+class StubTokenPool:
+    """In-memory token pool stand-in that records reported successes and failures."""
+
+    tokens: list[str]
+
+    def __post_init__(self):
+        self.success_tokens: list[str] = []
+        self.failure_tokens: list[str] = []
+
+    def get_next_token(self, exclude_tokens=None):
+        """Return the first token not in exclude_tokens, or None when none is left."""
+        skip = exclude_tokens or set()
+        return next((tok for tok in self.tokens if tok not in skip), None)
+
+    async def record_token_failure(self, token: str, error=None, dao=None):
+        self.failure_tokens.append(token)
+
+    async def record_token_success(self, token: str, dao=None):
+        self.success_tokens.append(token)
+
+    def get_pool_status(self):
+        return dict(available_tokens=len(self.tokens))
+
+
+class FakeResponse:
+    """Response stub exposing status_code, text and a 2xx-based is_success."""
+    def __init__(self, status_code: int, text: str = "{}"):
+        self.status_code, self.text = status_code, text
+
+    @property
+    def is_success(self) -> bool:
+        return self.status_code >= 200 and self.status_code < 300
+
+
+def _build_fake_async_client(handler):
+    """Return a client class usable with "async with"; post() awaits handler(headers)."""
+    class FakeAsyncClient:
+        def __init__(self, *args, **kwargs):
+            """Accept and ignore any constructor arguments."""
+
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False  # falsy: exceptions raised in the with-body propagate
+
+        async def post(self, url, headers=None, json=None):
+            return await handler(headers if headers else {})
+    return FakeAsyncClient
+
+
+async def _build_guest_pool(
+    monkeypatch,
+    *,
+    pool_size: int,
+    user_ids: list[str],
+) -> GuestSessionPool:
+    """Return an initialized GuestSessionPool whose new sessions come from user_ids."""
+    guest_pool = GuestSessionPool(pool_size=pool_size)
+    remaining_ids = iter(user_ids)
+    async def fake_create_session() -> GuestSession:
+        return _make_guest_session(next(remaining_ids))
+
+    monkeypatch.setattr(guest_pool, "_create_session", fake_create_session)
+    for attr_name, result in (("_maintenance_loop", None), ("_delete_all_chats", True)):
+        monkeypatch.setattr(guest_pool, attr_name, AsyncMock(return_value=result))
+    await guest_pool.initialize()
+    await asyncio.sleep(0)  # yield to the event loop once before handing the pool back
+    return guest_pool
+
+
+def _patch_upstream_dependencies(
+    monkeypatch,
+    *,
+    token_pool,
+    guest_pool,
+    async_client_cls,
+):
+    """Point upstream_module at the given pools, settings and HTTP client class."""
+    patch = monkeypatch.setattr
+    settings = upstream_module.settings
+    guest_pool_size = guest_pool.pool_size if guest_pool else 1  # falsy pool -> 1
+    patch(upstream_module, "get_token_pool", lambda: token_pool)
+    patch(upstream_module, "get_guest_session_pool", lambda: guest_pool)
+    patch(upstream_module.httpx, "AsyncClient", async_client_cls)
+    patch(settings, "ANONYMOUS_MODE", True)
+    patch(settings, "GUEST_POOL_SIZE", guest_pool_size)
+
+
+def _bind_minimal_request_flow(client: UpstreamClient, captures: list[dict]):
+    """Bind fake transform_request/transform_response methods onto client."""
+    # The request fake appends a copy of every get_auth_info() result to captures.
+    passthrough_keys = ("user_id", "auth_mode", "token_source", "guest_user_id")
+
+    async def fake_transform_request(
+        self,
+        request,
+        excluded_tokens=None,
+        excluded_guest_user_ids=None,
+    ):
+        auth = await self.get_auth_info(
+            excluded_tokens=excluded_tokens,
+            excluded_guest_user_ids=excluded_guest_user_ids,
+        )
+        captures.append(dict(auth))
+        headers = {
+            "x-token": str(auth["token"]),
+            "x-token-source": str(auth["token_source"]),
+            "x-guest-user-id": str(auth.get("guest_user_id") or ""),
+        }
+        transformed = {
+            "url": "https://upstream.test/chat",
+            "headers": headers,
+            "body": {"model": request.model},
+            "token": auth["token"],
+            "chat_id": "chat-id",
+            "model": request.model,
+        }
+        # Remaining fields are copied straight from the auth info, in this order.
+        transformed.update((key, auth[key]) for key in passthrough_keys)
+        return transformed
+
+    async def fake_transform_response(self, response, request, transformed):
+        echoed_keys = ("token_source", "token", "guest_user_id")
+        return {"ok": response.is_success, **{key: transformed[key] for key in echoed_keys}}
+
+    client.transform_request = types.MethodType(fake_transform_request, client)
+    client.transform_response = types.MethodType(fake_transform_response, client)
+
+
+async def _run_chat_requests(client: UpstreamClient, count: int) -> list[dict]:
+    """Issue count chat_completion calls concurrently and return the gathered results."""
+    return await asyncio.gather(*(client.chat_completion(_make_request()) for _ in range(count)))
+
+
+@pytest.mark.asyncio
+async def test_authenticated_tokens_are_used_before_guest_pool(monkeypatch):
+    """While an auth token is available, concurrent requests never acquire a guest session."""
+    auth_pool = StubTokenPool(["auth-1"])
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=GUEST_POOL_SIZE,
+        user_ids=["guest-1", "guest-2"],
+    )
+    auth_records: list[dict] = []
+    acquire_log: list[tuple] = []
+    real_acquire = guest_pool.acquire
+
+    async def counted_acquire(*args, **kwargs):
+        acquire_log.append((args, kwargs))
+        return await real_acquire(*args, **kwargs)
+
+    async def slow_ok_handler(headers):
+        await asyncio.sleep(MIXED_REQUEST_DELAY)
+        return FakeResponse(200)
+
+    client = UpstreamClient()
+    monkeypatch.setattr(guest_pool, "acquire", counted_acquire)
+    _bind_minimal_request_flow(client, auth_records)
+    _patch_upstream_dependencies(
+        monkeypatch,
+        token_pool=auth_pool,
+        guest_pool=guest_pool,
+        async_client_cls=_build_fake_async_client(slow_ok_handler),
+    )
+
+    try:
+        responses = await _run_chat_requests(client, AUTH_REQUEST_COUNT)
+        guest_status = guest_pool.get_pool_status()
+    finally:
+        await guest_pool.close()
+
+    assert all(resp["ok"] is True for resp in responses)
+    assert all(rec["token_source"] == "auth_pool" for rec in auth_records)
+    assert acquire_log == []
+    assert auth_pool.success_tokens == ["auth-1"] * AUTH_REQUEST_COUNT
+    assert auth_pool.failure_tokens == []
+    assert guest_status["busy_sessions"] == 0
+    assert guest_status["available_sessions"] == GUEST_POOL_SIZE
+
+
+@pytest.mark.asyncio
+async def test_authenticated_401_retries_next_token_before_guest_fallback(monkeypatch):
+    """A 401 on the first auth token is retried on the next one; no guest session is acquired."""
+    auth_pool = StubTokenPool(["auth-1", "auth-2"])
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=GUEST_POOL_SIZE,
+        user_ids=["guest-1", "guest-2"],
+    )
+    auth_records: list[dict] = []
+    acquire_log: list[tuple] = []
+    real_acquire = guest_pool.acquire
+
+    async def counted_acquire(*args, **kwargs):
+        acquire_log.append((args, kwargs))
+        return await real_acquire(*args, **kwargs)
+
+    async def handler(headers):
+        # Only the first auth token is rejected by the fake upstream.
+        if headers["x-token"] != "auth-1":
+            return FakeResponse(200)
+        return FakeResponse(401, '{"message":"expired"}')
+
+    client = UpstreamClient()
+    monkeypatch.setattr(guest_pool, "acquire", counted_acquire)
+    _bind_minimal_request_flow(client, auth_records)
+    _patch_upstream_dependencies(
+        monkeypatch,
+        token_pool=auth_pool,
+        guest_pool=guest_pool,
+        async_client_cls=_build_fake_async_client(handler),
+    )
+
+    try:
+        result = await client.chat_completion(_make_request())
+    finally:
+        await guest_pool.close()
+
+    assert result["ok"] is True
+    assert [rec["token"] for rec in auth_records] == ["auth-1", "auth-2"]
+    assert [rec["token_source"] for rec in auth_records] == ["auth_pool"] * 2
+    assert auth_pool.failure_tokens == ["auth-1"]
+    assert auth_pool.success_tokens == ["auth-2"]
+    assert acquire_log == []
+
+
+@pytest.mark.asyncio
+async def test_authenticated_pool_exhaustion_falls_back_to_guest(monkeypatch):
+    """Once every auth token has been rejected, the request is served via the guest pool."""
+    auth_pool = StubTokenPool(["auth-1", "auth-2"])
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=GUEST_POOL_SIZE,
+        user_ids=["guest-1", "guest-2", "guest-3"],
+    )
+    auth_records: list[dict] = []
+
+    async def handler(headers):
+        # Every auth-pool attempt is rejected; anything else succeeds.
+        if headers["x-token-source"] != "auth_pool":
+            return FakeResponse(200)
+        return FakeResponse(401, '{"message":"expired"}')
+
+    client = UpstreamClient()
+    _bind_minimal_request_flow(client, auth_records)
+    _patch_upstream_dependencies(
+        monkeypatch,
+        token_pool=auth_pool,
+        guest_pool=guest_pool,
+        async_client_cls=_build_fake_async_client(handler),
+    )
+
+    try:
+        result = await client.chat_completion(_make_request())
+        guest_status = guest_pool.get_pool_status()
+    finally:
+        await guest_pool.close()
+
+    assert result["ok"] is True
+    # Both auth tokens are tried first, followed by a single guest attempt.
+    sources = [rec["token_source"] for rec in auth_records]
+    assert sources == ["auth_pool", "auth_pool", "guest_pool"]
+    assert auth_pool.failure_tokens == ["auth-1", "auth-2"]
+    assert auth_pool.success_tokens == []
+    assert result["guest_user_id"]
+    assert guest_status["busy_sessions"] == 0
+
+
+@pytest.mark.asyncio
+async def test_guest_retry_is_isolated_and_does_not_pollute_auth_stats(monkeypatch):
+    """A rejected guest session is retried on another guest without touching auth success stats."""
+    auth_pool = StubTokenPool(["auth-1", "auth-2"])
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=GUEST_POOL_SIZE,
+        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
+    )
+    auth_records: list[dict] = []
+
+    async def handler(headers):
+        source, guest_id = headers["x-token-source"], headers["x-guest-user-id"]
+        # Auth-pool tokens and the "guest-1" session are both rejected.
+        if source == "auth_pool" or guest_id == "guest-1":
+            return FakeResponse(401, '{"message":"expired"}')
+        return FakeResponse(200)
+
+    client = UpstreamClient()
+    _bind_minimal_request_flow(client, auth_records)
+    _patch_upstream_dependencies(
+        monkeypatch,
+        token_pool=auth_pool,
+        guest_pool=guest_pool,
+        async_client_cls=_build_fake_async_client(handler),
+    )
+
+    try:
+        result = await client.chat_completion(_make_request())
+        guest_status = guest_pool.get_pool_status()
+    finally:
+        await guest_pool.close()
+
+    # Guest user ids in the order the guest attempts were made.
+    guest_records = [rec for rec in auth_records if rec["token_source"] == "guest_pool"]
+    guest_ids = [rec["guest_user_id"] for rec in guest_records]
+
+    assert result["ok"] is True
+    # Both auth tokens were tried, and only failures were recorded for them.
+    assert [rec["token"] for rec in auth_records[:2]] == ["auth-1", "auth-2"]
+    assert auth_pool.failure_tokens == ["auth-1", "auth-2"]
+    assert auth_pool.success_tokens == []
+    # The guest retry moved off the rejected session.
+    assert guest_ids[0] == "guest-1"
+    assert guest_ids[1] != "guest-1"
+    # Afterwards nothing is busy and the pool reports GUEST_POOL_SIZE valid sessions.
+    assert guest_status["busy_sessions"] == 0
+    assert guest_status["valid_sessions"] == GUEST_POOL_SIZE
+
+
+@pytest.mark.asyncio
+async def test_cleanup_idle_chats_only_touches_idle_valid_sessions(monkeypatch):
+    """cleanup_idle_chats() is expected to skip a session that has an active request."""
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=3,
+        user_ids=["guest-1", "guest-2", "guest-3"],
+    )
+    cleaned_ids: list[str] = []
+
+    async def fake_delete_all_chats(session: GuestSession):
+        cleaned_ids.append(session.user_id)
+        return True
+
+    monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats)
+    guest_pool._sessions["guest-2"].active_requests = 1  # mark guest-2 as busy
+    try:
+        await guest_pool.cleanup_idle_chats()
+        cleaned_before_close = cleaned_ids.copy()  # snapshot taken before close() runs
+    finally:
+        await guest_pool.close()
+
+    assert cleaned_before_close == ["guest-1", "guest-3"]
+
+
+@pytest.mark.asyncio
+async def test_report_failure_only_retires_target_guest_session(monkeypatch):
+    """report_failure("guest-1") should remove only guest-1 from the pool's sessions."""
+    guest_pool = await _build_guest_pool(
+        monkeypatch,
+        pool_size=3,
+        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
+    )
+    cleaned_ids: list[str] = []
+
+    async def fake_delete_all_chats(session: GuestSession):
+        cleaned_ids.append(session.user_id)
+        return True
+
+    monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats)
+
+    try:
+        await guest_pool.report_failure("guest-1")
+        await asyncio.sleep(0)
+        live_ids = set(guest_pool._sessions)
+        cleaned_before_close = cleaned_ids.copy()
+    finally:
+        await guest_pool.close()
+
+    assert "guest-1" not in live_ids
+    # guest-4 is presumably the replacement created after guest-1 was retired — confirm.
+    assert {"guest-2", "guest-3", "guest-4"} <= live_ids
+    assert cleaned_before_close == ["guest-1"]