diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..a934a149e4ab4a268b04208c5b6f1d10b22c6be4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,70 @@ +# Git files +.git +.gitignore +.github + +# Python cache +__pycache__ +*.py[cod] +*$py.class +*.so +.Python +*.egg-info/ +dist/ +build/ + +# Virtual environments +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Documentation +*.md +!README.md +CONTRIBUTING.md +CHANGELOG.md +SECURITY.md + +# Test files +test_*.py +tests/ +.pytest_cache/ +.coverage +htmlcov/ +*.coverage + +# Development files +.editorconfig +.pre-commit-config.yaml +Makefile +setup-dev.sh +requirements-dev.txt + +# Logs +*.log +log.txt + +# Credentials (never include) +creds/ +*.json +!package.json +*.toml +!pyproject.toml +.env +.env.* + +# Temporary files +*.tmp +*.bak +tmp/ diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..462c14ecc29fcc851a0a2cd60f94370645e5808e --- /dev/null +++ b/.editorconfig @@ -0,0 +1,42 @@ +# EditorConfig helps maintain consistent coding styles across editors +# https://editorconfig.org + +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.{py,pyi}] +indent_style = space +indent_size = 4 +max_line_length = 100 + +[*.{yml,yaml}] +indent_style = space +indent_size = 2 + +[*.{json,toml}] +indent_style = space +indent_size = 2 + +[*.{md,markdown}] +trim_trailing_whitespace = false +max_line_length = off + +[*.{sh,bat,ps1}] +indent_style = space +indent_size = 2 + +[Makefile] +indent_style = tab + +[*.js] +indent_style = space +indent_size = 2 + +[*.html] +indent_style = space +indent_size = 2 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..5620277d883fea18658c28152bd596b4bf5bb55b --- /dev/null +++ b/.env.example @@ -0,0 +1,230 @@ +# ================================================================ +# GCLI2API 环境变量配置示例文件 +# 复制此文件为 .env 并根据需要修改配置值 +# ================================================================ + +# ================================================================ +# 服务器配置 +# ================================================================ + +# 服务器监听地址 +# 默认: 0.0.0.0 (监听所有网络接口) +HOST=0.0.0.0 + +# 服务器端口 +# 默认: 7861 +PORT=7861 + +# ================================================================ +# 密码配置 (支持分离密码) +# ================================================================ + +# 聊天API访问密码 (用于OpenAI和Gemini API端点认证) +# 默认: 继承通用密码或 pwd +API_PASSWORD=your_api_password + +# 控制面板访问密码 (用于Web界面登录认证) +# 默认: 继承通用密码或 pwd +PANEL_PASSWORD=your_panel_password + +# 通用访问密码 (兼容性保留) +# 设置后会覆盖上述两个专用密码,优先级最高 +# 如果只想使用一个密码,设置此项即可 +# 默认: pwd +PASSWORD=pwd + +# ================================================================ +# 存储配置 +# ================================================================ + +# 存储后端优先级: MongoDB > 本地sqlite文件存储 +# 系统会自动选择可用的最高优先级存储后端 + +# MongoDB 分布式存储模式配置 (第二优先级) +# 设置 MONGODB_URI 后自动启用 MongoDB 模式,不再使用本地文件存储 + +# MongoDB 连接字符串 (设置后启用 MongoDB 分布式存储模式) +# 本地 MongoDB: mongodb://localhost:27017 +# 带认证: mongodb://admin:password@localhost:27017/admin +# MongoDB Atlas: mongodb+srv://username:password@cluster.mongodb.net +# 副本集: mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=rs0 +# 默认: 无 (使用本地文件存储) +MONGODB_URI=mongodb://localhost:27017 + +# MongoDB 数据库名称 (仅在启用 MongoDB 模式时有效) +# 默认: gcli2api +MONGODB_DATABASE=gcli2api + +# ================================================================ +# Google API 配置 +# ================================================================ + +# 凭证文件目录 (仅在文件存储模式下使用) +# 默认: ./creds +CREDENTIALS_DIR=./creds + +# 是否自动从环境变量加载凭证 +# 默认: false +AUTO_LOAD_ENV_CREDS=false + +# Google 凭证环境变量配置 (可选,通过 GCLI_CREDS_* 环境变量提供凭证) +# 支持编号格式和项目名格式 +# GCLI_CREDS_1={"client_id":"your-client-id","client_secret":"your-secret","refresh_token":"your-token","token_uri":"https://oauth2.googleapis.com/token","project_id":"your-project"} +# GCLI_CREDS_2={"client_id":"...","project_id":"..."} +# GCLI_CREDS_myproject={"client_id":"...","project_id":"myproject",...} + +# ================================================================ +# 凭证轮换配置 +# ================================================================ + +# 每个凭证使用多少次调用后轮换到下一个 +# 默认: 100 +CALLS_PER_ROTATION=100 + +# 代理配置 (可选) +# 支持 http, https, socks5 代理 +# 格式: http://proxy:port, https://proxy:port, socks5://proxy:port +PROXY=http://localhost:7890 + +# Google API 代理 URL 配置 (可选) + +# Google Code Assist API 端点 +# 默认: https://cloudcode-pa.googleapis.com +CODE_ASSIST_ENDPOINT=https://cloudcode-pa.googleapis.com +# 用于Google OAuth2认证的代理URL +# 默认: https://oauth2.googleapis.com +OAUTH_PROXY_URL=https://oauth2.googleapis.com + +# 用于Google APIs调用的代理URL +# 默认: https://www.googleapis.com +GOOGLEAPIS_PROXY_URL=https://www.googleapis.com + +# 用于Google Cloud Resource Manager API的URL +# 默认: https://cloudresourcemanager.googleapis.com +RESOURCE_MANAGER_API_URL=https://cloudresourcemanager.googleapis.com + +# 用于Google Cloud Service Usage API的URL +# 默认: https://serviceusage.googleapis.com +SERVICE_USAGE_API_URL=https://serviceusage.googleapis.com + +# 用于Google Antigravity API的URL (反重力模式) +# 默认: https://daily-cloudcode-pa.sandbox.googleapis.com +ANTIGRAVITY_API_URL=https://daily-cloudcode-pa.sandbox.googleapis.com + +# ================================================================ +# 错误处理和重试配置 +# ================================================================ + +# 是否启用自动封禁功能 +# 当凭证返回特定错误码时自动禁用该凭证 +# 默认: false +AUTO_BAN=false + +# 自动封禁的错误码列表 (逗号分隔) +# 默认: 400,403 +AUTO_BAN_ERROR_CODES=400,403 + +# 是否启用 429 错误重试 +# 默认: true +RETRY_429_ENABLED=true + +# 429 错误最大重试次数 +# 默认: 5 +RETRY_429_MAX_RETRIES=5 + +# 429 错误重试间隔 (秒) +# 默认: 1 +RETRY_429_INTERVAL=1 + +# ================================================================ +# 日志配置 +# ================================================================ + +# 日志级别 +# 可选值: debug, info, warning, error, critical +# 默认: info +LOG_LEVEL=info + +# 日志文件路径 +# 默认: log.txt +LOG_FILE=log.txt + +# ================================================================ +# 高级功能配置 +# ================================================================ + +# 流式抗截断最大尝试次数 +# 用于 "流式抗截断/" 前缀的模型 +# 默认: 3 +ANTI_TRUNCATION_MAX_ATTEMPTS=3 + +# ================================================================ +# 环境变量使用说明 +# ================================================================ + +# 1. 存储模式配置 (按优先级自动选择): +# - Redis 分布式模式 (最高优先级): 设置 REDIS_URI,数据存储在 Redis 数据库,性能最佳 +# - MongoDB 分布式模式 (第二优先级): 设置 MONGODB_URI,数据存储在 MongoDB 数据库 +# - 文件存储模式 (默认): 不设置上述 URI,数据存储在本地 creds/ 目录 +# - 自动切换: 系统根据可用的存储配置自动选择最高优先级的存储后端 + +# 2. 凭证配置方式 (三选一): +# a) 将 Google 凭证 JSON 文件放在 CREDENTIALS_DIR 目录中 (仅文件模式) +# b) 设置 AUTO_LOAD_ENV_CREDS=true,通过 GOOGLE_CREDENTIALS 等环境变量直接提供 +# c) 通过面板导入 + +# 3. 密码配置优先级: +# a) PASSWORD 环境变量 (最高优先级,设置后覆盖其他密码) +# b) API_PASSWORD / PANEL_PASSWORD 环境变量 (专用密码) +# c) config.toml 文件中的密码配置 +# d) 默认值 "pwd" +# +# 4. 通用配置优先级: +# 环境变量 > config.toml 文件 > 默认值 + +# 5. 布尔值环境变量: +# true/1/yes/on 表示启用 +# false/0/no/off 表示禁用 + +# 6. 模型功能说明: +# - 基础模型: gemini-2.5-pro, gemini-2.5-flash 等 +# - 功能前缀: +# * "假流式/" - 使用假流式传输 +# * "流式抗截断/" - 启用流式抗截断功能 +# - 功能后缀: +# * "-maxthinking" - 最大思考预算 +# * "-nothinking" - 禁用思考模式 +# * "-search" - 启用 Google 搜索 + +# 7. 示例模型名称: +# - gemini-2.5-pro +# - 假流式/gemini-2.5-pro-maxthinking +# - 流式抗截断/gemini-2.5-flash-search + +# ================================================================ +# 配置文件说明 +# ================================================================ + +# 除了环境变量,你还可以使用 TOML 配置文件进行配置 +# 配置文件位置: {CREDENTIALS_DIR}/config.toml +# +# # 密码配置 +# api_password = "your_api_password" # 聊天API密码 +# panel_password = "your_panel_password" # 控制面板密码 +# password = "your_common_password" # 通用密码 (覆盖上述两个) +# +# # 基础配置 +# calls_per_rotation = 100 +# +# [retry] +# retry_429_enabled = true +# retry_429_max_retries = 5 +# retry_429_interval = 1 +# +# [logging] +# log_level = "info" +# log_file = "log.txt" +# +# [auto_ban] +# auto_ban_enabled = false +# auto_ban_error_codes = [400, 403] diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..5bceba01e8b21ee62a65f00e9f00177aa3713d91 --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +max-line-length = 100 +extend-ignore = E203, W503, E501 +exclude = + .git, + __pycache__, + .venv, + venv, + gcli, + build, + dist, + .eggs, + *.egg-info diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c05a05aa8a9b8a8f1796fa00640685760fcb0244 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +docs/qq群.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc40808e83ba8cf46964756cc6bf6cf0ee516bcb --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,81 @@ +name: Docker Build and Publish + +on: + workflow_run: + workflows: ["Update Version File"] + types: + - completed + branches: + - master + - main + push: + tags: + - 'v*' + pull_request: + branches: + - master + - main + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-push: + runs-on: ubuntu-latest + # 只在 workflow_run 成功时运行,或者非 workflow_run 触发时运行 + if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + # workflow_run 触发时需要获取最新的代码(包括 version.txt 的更新) + ref: ${{ github.event_name == 'workflow_run' && github.event.workflow_run.head_branch || github.ref }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=ref,event=pr + type=raw,value=latest,enable={{is_default_branch}} + type=sha,prefix={{branch}}- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + build-args: | + BUILD_DATE=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.created'] }} + VERSION=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} + REVISION=${{ github.sha }} \ No newline at end of file diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml new file mode 100644 index 0000000000000000000000000000000000000000..93b018ca576fb727cd3d12f74d893db1ce29129a --- /dev/null +++ b/.github/workflows/update-version.yml @@ -0,0 +1,51 @@ +name: Update Version File + +on: + push: + branches: + - master + - main + +jobs: + update-version: + runs-on: ubuntu-latest + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Update version.txt + run: | + # 获取最新commit信息 + FULL_HASH=$(git log -1 --format=%H) + SHORT_HASH=$(git log -1 --format=%h) + MESSAGE=$(git log -1 --format=%s) + DATE=$(git log -1 --format=%ci) + + # 写入version.txt + echo "full_hash=$FULL_HASH" > version.txt + echo "short_hash=$SHORT_HASH" >> version.txt + echo "message=$MESSAGE" >> version.txt + echo "date=$DATE" >> version.txt + + echo "Version file updated:" + cat version.txt + + - name: Commit version.txt if changed + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + + # 检查是否有变化 + if git diff --quiet version.txt; then + echo "No changes to version.txt" + else + git add version.txt + git commit -m "chore: update version.txt [skip ci]" + git push + fi diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d1a9d3b9262585fb2d9daf2a7d754300b5be3c6d --- /dev/null +++ b/.gitignore @@ -0,0 +1,98 @@ +# Credential files - should never be committed +*.json +!package.json +!package-lock.json +!tsconfig.json +*.toml +!pyproject.toml +creds/ +CLAUDE.md +GEMINI.md +.kiro +# Environment configuration +.env + +# Python +uv.lock +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.vscode/ +.idea/ +.claude/ +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log +log.txt + +# Temporary files +*.tmp +*.temp +*.bak + +tools/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d546a90fe4c8db744da8d33f127c82b3926f3fa8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-json + - id: check-toml + - id: check-merge-conflict + - id: detect-private-key + + - repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + args: [--line-length=100] + language_version: python3.12 + + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: [--max-line-length=100, --extend-ignore=E203,W503] + additional_dependencies: [flake8-docstrings] + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile=black, --line-length=100] diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000000000000000000000000000000000..e4fba2183587225f216eeada4c78dfab6b2e65f5 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..2e237c1dbe43b760b1e2aa2cac6546b69eb7c735 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,169 @@ +# Contributing to gcli2api + +First off, thank you for considering contributing to gcli2api! It's people like you that make gcli2api such a great tool. + +## Code of Conduct + +This project is intended for personal learning and research purposes only. By participating, you are expected to uphold this code and respect the CNC-1.0 license restrictions on commercial use. + +## How Can I Contribute? + +### Reporting Bugs + +Before creating bug reports, please check the existing issues to avoid duplicates. When you create a bug report, include as many details as possible: + +* **Use a clear and descriptive title** +* **Describe the exact steps to reproduce the problem** +* **Provide specific examples** - Include code snippets, configuration files, or log outputs +* **Describe the behavior you observed** and what you expected to see +* **Include environment details**: OS, Python version, Docker version (if applicable) + +### Suggesting Enhancements + +Enhancement suggestions are tracked as GitHub issues. When creating an enhancement suggestion, include: + +* **Use a clear and descriptive title** +* **Provide a detailed description** of the suggested enhancement +* **Explain why this enhancement would be useful** +* **List any alternative solutions** you've considered + +### Pull Requests + +1. Fork the repo and create your branch from `master` +2. If you've added code that should be tested, add tests +3. If you've changed APIs, update the documentation +4. Ensure the test suite passes +5. Make sure your code follows the existing style +6. Write a clear commit message + +## Development Setup + +### Prerequisites + +* Python 3.12 or higher +* pip or uv package manager + +### Setting Up Your Development Environment + +```bash +# Clone your fork +git clone https://github.com/YOUR_USERNAME/gcli2api.git +cd gcli2api + +# Install development dependencies +make install-dev +# or +pip install -e ".[dev]" + +# Copy environment example +cp .env.example .env +# Edit .env with your configuration +``` + +### Development Workflow + +```bash +# Run tests +make test + +# Format code +make format + +# Run linters +make lint + +# Run the application locally +make run +``` + +### Testing + +We use pytest for testing. All new features should include appropriate tests. + +```bash +# Run all tests +make test + +# Run with coverage +make test-cov + +# Run specific test file +python -m pytest test_tool_calling.py -v +``` + +### Code Style + +* We use [Black](https://black.readthedocs.io/) for code formatting (line length: 100) +* We use [flake8](https://flake8.pycqa.org/) for linting +* We use [mypy](http://mypy-lang.org/) for type checking (optional, but encouraged) + +```bash +# Format your code before committing +make format + +# Check if code is properly formatted +make format-check + +# Run linters +make lint +``` + +## Project Structure + +``` +gcli2api/ +├── src/ # Main source code +│ ├── auth.py # Authentication and OAuth +│ ├── credential_manager.py # Credential rotation +│ ├── openai_router.py # OpenAI-compatible endpoints +│ ├── gemini_router.py # Gemini native endpoints +│ ├── openai_transfer.py # Format conversion +│ ├── storage/ # Storage backends (Redis, MongoDB, Postgres, File) +│ └── ... +├── front/ # Frontend static files +├── tests/ # Test directory (to be created) +├── test_*.py # Test files (root level) +├── web.py # Main application entry point +├── config.py # Configuration management +└── requirements.txt # Production dependencies +``` + +## Coding Guidelines + +### Python Style + +* Follow PEP 8 guidelines +* Use type hints where appropriate +* Write docstrings for classes and functions +* Keep functions focused and concise + +### Commit Messages + +* Use the present tense ("Add feature" not "Added feature") +* Use the imperative mood ("Move cursor to..." not "Moves cursor to...") +* Limit the first line to 72 characters or less +* Reference issues and pull requests liberally after the first line + +### Documentation + +* Update the README.md if you change functionality +* Comment your code where necessary +* Update the .env.example if you add new configuration options + +## License + +By contributing to gcli2api, you agree that your contributions will be licensed under the CNC-1.0 license. This is a strict anti-commercial license - see [LICENSE](LICENSE) for details. + +### Important License Restrictions + +* ❌ No commercial use +* ❌ No use by companies with revenue > $1M USD +* ❌ No use by VC-backed or publicly traded companies +* ✅ Personal learning, research, and educational use only +* ✅ Open source integration (must follow same license) + +## Questions? + +Feel free to open an issue with your question or reach out to the maintainers. + +Thank you for contributing! 🎉 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..1c86c3aaebe124d147285185dc588c9f46c9f42f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +# Multi-stage build for gcli2api +FROM python:3.13-slim as base + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + TZ=Asia/Shanghai + +# Install tzdata and set timezone +RUN apt-get update && \ + apt-get install -y --no-install-recommends tzdata && \ + ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy only requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Expose port +EXPOSE 7861 + +# Default command +CMD ["python", "web.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2317eb517b7aa0e90bc5b81db74b82ff885336e5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,83 @@ +Cooperative Non-Commercial License (CNC-1.0) + +Copyright (c) 2024 gcli2api contributors + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of this software and associated documentation files (the +"Software"), to use, copy, modify, merge, publish, distribute, and/or +sublicense the Software, subject to the following conditions: + +TERMS AND CONDITIONS: + +1. NON-COMMERCIAL USE ONLY + The Software may only be used for non-commercial purposes. Commercial use + is strictly prohibited without explicit written permission from the + copyright holders. + +2. DEFINITION OF COMMERCIAL USE + "Commercial use" includes but is not limited to: + a) Using the Software to provide paid services or products + b) Integrating the Software into commercial products or services + c) Using the Software in any business operation that generates revenue + d) Offering the Software as part of a paid subscription or service + e) Using the Software to compete with the original project commercially + +3. COPYLEFT REQUIREMENT + Any derivative works, modifications, or substantial portions of the Software + must be licensed under the same or substantially similar terms. This ensures + that all derivatives remain non-commercial and freely available. + +4. SOURCE CODE AVAILABILITY + If you distribute the Software or any derivative works, you must make the + complete source code available under the same license terms at no charge. + +5. ATTRIBUTION REQUIREMENT + You must retain all copyright notices, license notices, and attribution + statements in all copies or substantial portions of the Software. + +6. ANTI-CORPORATE CLAUSE + This Software may not be used by corporations with annual revenue exceeding + $1 million USD, venture capital backed companies, or publicly traded + companies without explicit written permission from the copyright holders. + +7. EDUCATIONAL AND RESEARCH EXEMPTION + Use by educational institutions, non-profit research organizations, and + individual researchers for educational or research purposes is explicitly + permitted and encouraged. + +8. MODIFICATION AND CONTRIBUTION + Modifications and contributions to the Software are welcomed and encouraged, + provided they comply with these license terms. Contributors grant the same + license to their contributions. + +9. PATENT GRANT + Each contributor grants you a non-exclusive, worldwide, royalty-free patent + license to make, have made, use, offer to sell, sell, import, and otherwise + transfer the Work for non-commercial purposes only. + +10. TERMINATION + This license automatically terminates if you violate any of its terms. + Upon termination, you must cease all use and distribution of the Software + and destroy all copies in your possession. + +11. LIABILITY DISCLAIMER + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +12. JURISDICTION + This license shall be governed by and construed in accordance with the laws + of the jurisdiction where the copyright holder resides. + +SUMMARY: +This license allows free use, modification, and distribution of the Software +for non-commercial purposes only. It explicitly prohibits commercial use and +ensures that all derivatives remain freely available under the same terms. +The license promotes cooperative development while preventing commercial +exploitation of the community's work. + +For commercial licensing inquiries, please contact the copyright holders. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..8f37465d7010bfd7a32b9cab03ebddc06e8481d5 --- /dev/null +++ b/Makefile @@ -0,0 +1,64 @@ +.PHONY: help install install-dev test lint format clean run docker-build docker-run docker-compose-up docker-compose-down + +help: + @echo "gcli2api - Development Commands" + @echo "" + @echo "Available commands:" + @echo " make install - Install production dependencies" + @echo " make install-dev - Install development dependencies" + @echo " make test - Run tests" + @echo " make test-cov - Run tests with coverage report" + @echo " make lint - Run linters (flake8, mypy)" + @echo " make format - Format code with black" + @echo " make format-check - Check code formatting without making changes" + @echo " make clean - Clean build artifacts and cache" + @echo " make run - Run the application" + @echo " make docker-build - Build Docker image" + @echo " make docker-run - Run Docker container" + @echo " make docker-compose-up - Start services with docker-compose" + @echo " make docker-compose-down - Stop services with docker-compose" + +install: + pip install -r requirements.txt + +install-dev: + pip install -e ".[dev]" + pip install -r requirements-dev.txt + +test: + python -m pytest -v + +test-cov: + python -m pytest --cov=src --cov-report=term-missing --cov-report=html + +lint: + python -m flake8 src/ web.py config.py log.py --max-line-length=100 --extend-ignore=E203,W503 + python -m mypy src/ --ignore-missing-imports + +format: + python -m black src/ web.py config.py log.py test_*.py + +format-check: + python -m black --check src/ web.py config.py log.py test_*.py + +clean: + find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete + find . -type f -name "*.pyo" -delete + find . -type f -name "*.log" -delete + rm -rf .pytest_cache .mypy_cache .coverage htmlcov/ build/ dist/ *.egg-info + +run: + python web.py + +docker-build: + docker build -t gcli2api:latest . + +docker-run: + docker run -d --name gcli2api --network host -e PASSWORD=pwd -e PORT=7861 -v $$(pwd)/data/creds:/app/creds gcli2api:latest + +docker-compose-up: + docker-compose up -d + +docker-compose-down: + docker-compose down diff --git a/README.md b/README.md index 95431ee7e22f86f2a43c42e634124893ba1662b2..ad86873cd594fc2542932956cbb57df879ad860f 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,15 @@ --- -title: 2api -emoji: 🌍 -colorFrom: red -colorTo: indigo +title: "2api" +emoji: "🚀" +colorFrom: blue +colorTo: green sdk: docker -pinned: false +app_port: 7861 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +### 🚀 一键部署 +[![Deploy with HFSpaceDeploy](https://img.shields.io/badge/Deploy_with-HFSpaceDeploy-green?style=social&logo=rocket)](https://github.com/kfcx/HFSpaceDeploy) + +本项目由[HFSpaceDeploy](https://github.com/kfcx/HFSpaceDeploy)一键部署 + + diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..11114c4df446287037b1a4fa0ae3239a8c8a70a9 --- /dev/null +++ b/config.py @@ -0,0 +1,441 @@ +""" +Configuration constants for the Geminicli2api proxy server. +Centralizes all configuration to avoid duplication across modules. + +- 启动时加载一次配置到内存 +- 修改配置时调用 reload_config() 重新从数据库加载 +""" + +import os +from typing import Any, Optional + +# 全局配置缓存 +_config_cache: dict[str, Any] = {} +_config_initialized = False + +# Client Configuration + +# 需要自动封禁的错误码 (默认值,可通过环境变量或配置覆盖) +AUTO_BAN_ERROR_CODES = [403] + +# ====================== 环境变量映射表 ====================== +# 统一维护环境变量名和配置键名的映射关系 +# 格式: "环境变量名": "配置键名" +ENV_MAPPINGS = { + "CODE_ASSIST_ENDPOINT": "code_assist_endpoint", + "CREDENTIALS_DIR": "credentials_dir", + "PROXY": "proxy", + "OAUTH_PROXY_URL": "oauth_proxy_url", + "GOOGLEAPIS_PROXY_URL": "googleapis_proxy_url", + "RESOURCE_MANAGER_API_URL": "resource_manager_api_url", + "SERVICE_USAGE_API_URL": "service_usage_api_url", + "ANTIGRAVITY_API_URL": "antigravity_api_url", + "AUTO_BAN": "auto_ban_enabled", + "AUTO_BAN_ERROR_CODES": "auto_ban_error_codes", + "RETRY_429_MAX_RETRIES": "retry_429_max_retries", + "RETRY_429_ENABLED": "retry_429_enabled", + "RETRY_429_INTERVAL": "retry_429_interval", + "ANTI_TRUNCATION_MAX_ATTEMPTS": "anti_truncation_max_attempts", + "COMPATIBILITY_MODE": "compatibility_mode_enabled", + "RETURN_THOUGHTS_TO_FRONTEND": "return_thoughts_to_frontend", + "ANTIGRAVITY_STREAM2NOSTREAM": "antigravity_stream2nostream", + "HOST": "host", + "PORT": "port", + "API_PASSWORD": "api_password", + "PANEL_PASSWORD": "panel_password", + "PASSWORD": "password", +} + + +# ====================== 配置系统 ====================== + +async def init_config(): + """初始化配置缓存(启动时调用一次)""" + global _config_cache, _config_initialized + + if _config_initialized: + return + + try: + from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() + _config_cache = await storage_adapter.get_all_config() + _config_initialized = True + except Exception: + # 初始化失败时使用空缓存 + _config_cache = {} + _config_initialized = True + + +async def reload_config(): + """重新加载配置(修改配置后调用)""" + global _config_cache, _config_initialized + + try: + from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() + + # 如果后端支持 reload_config_cache,调用它 + if hasattr(storage_adapter._backend, 'reload_config_cache'): + await storage_adapter._backend.reload_config_cache() + + # 重新加载配置缓存 + _config_cache = await storage_adapter.get_all_config() + _config_initialized = True + except Exception: + pass + + +def _get_cached_config(key: str, default: Any = None) -> Any: + """从内存缓存获取配置(同步)""" + return _config_cache.get(key, default) + + +async def get_config_value(key: str, default: Any = None, env_var: Optional[str] = None) -> Any: + """Get configuration value with priority: ENV > Storage > default.""" + # 确保配置已初始化 + if not _config_initialized: + await init_config() + + # Priority 1: Environment variable + if env_var and os.getenv(env_var): + return os.getenv(env_var) + + # Priority 2: Memory cache + value = _get_cached_config(key) + if value is not None: + return value + + return default + + +# Configuration getters - all async +async def get_proxy_config(): + """Get proxy configuration.""" + proxy_url = await get_config_value("proxy", env_var="PROXY") + return proxy_url if proxy_url else None + + +async def get_auto_ban_enabled() -> bool: + """Get auto ban enabled setting.""" + env_value = os.getenv("AUTO_BAN") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("auto_ban_enabled", False)) + + +async def get_auto_ban_error_codes() -> list: + """ + Get auto ban error codes. + + Environment variable: AUTO_BAN_ERROR_CODES (comma-separated, e.g., "400,403") + Database config key: auto_ban_error_codes + Default: [400, 403] + """ + env_value = os.getenv("AUTO_BAN_ERROR_CODES") + if env_value: + try: + return [int(code.strip()) for code in env_value.split(",") if code.strip()] + except ValueError: + pass + + codes = await get_config_value("auto_ban_error_codes") + if codes and isinstance(codes, list): + return codes + return AUTO_BAN_ERROR_CODES + + +async def get_retry_429_max_retries() -> int: + """Get max retries for 429 errors.""" + env_value = os.getenv("RETRY_429_MAX_RETRIES") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("retry_429_max_retries", 5)) + + +async def get_retry_429_enabled() -> bool: + """Get 429 retry enabled setting.""" + env_value = os.getenv("RETRY_429_ENABLED") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("retry_429_enabled", True)) + + +async def get_retry_429_interval() -> float: + """Get 429 retry interval in seconds.""" + env_value = os.getenv("RETRY_429_INTERVAL") + if env_value: + try: + return float(env_value) + except ValueError: + pass + + return float(await get_config_value("retry_429_interval", 0.1)) + + +async def get_anti_truncation_max_attempts() -> int: + """ + Get maximum attempts for anti-truncation continuation. + + Environment variable: ANTI_TRUNCATION_MAX_ATTEMPTS + Database config key: anti_truncation_max_attempts + Default: 3 + """ + env_value = os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("anti_truncation_max_attempts", 3)) + + +# Server Configuration +async def get_server_host() -> str: + """ + Get server host setting. + + Environment variable: HOST + Database config key: host + Default: 0.0.0.0 + """ + return str(await get_config_value("host", "0.0.0.0", "HOST")) + + +async def get_server_port() -> int: + """ + Get server port setting. + + Environment variable: PORT + Database config key: port + Default: 7861 + """ + env_value = os.getenv("PORT") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("port", 7861)) + + +async def get_api_password() -> str: + """ + Get API password setting for chat endpoints. + + Environment variable: API_PASSWORD + Database config key: api_password + Default: Uses PASSWORD env var for compatibility, otherwise 'pwd' + """ + # 优先使用 API_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性 + api_password = await get_config_value("api_password", None, "API_PASSWORD") + if api_password is not None: + return str(api_password) + + # 兼容性:使用通用密码 + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_panel_password() -> str: + """ + Get panel password setting for web interface. + + Environment variable: PANEL_PASSWORD + Database config key: panel_password + Default: Uses PASSWORD env var for compatibility, otherwise 'pwd' + """ + # 优先使用 PANEL_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性 + panel_password = await get_config_value("panel_password", None, "PANEL_PASSWORD") + if panel_password is not None: + return str(panel_password) + + # 兼容性:使用通用密码 + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_server_password() -> str: + """ + Get server password setting (deprecated, use get_api_password or get_panel_password). + + Environment variable: PASSWORD + Database config key: password + Default: pwd + """ + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_credentials_dir() -> str: + """ + Get credentials directory setting. + + Environment variable: CREDENTIALS_DIR + Database config key: credentials_dir + Default: ./creds + """ + return str(await get_config_value("credentials_dir", "./creds", "CREDENTIALS_DIR")) + + +async def get_code_assist_endpoint() -> str: + """ + Get Code Assist endpoint setting. + + Environment variable: CODE_ASSIST_ENDPOINT + Database config key: code_assist_endpoint + Default: https://cloudcode-pa.googleapis.com + """ + return str( + await get_config_value( + "code_assist_endpoint", "https://cloudcode-pa.googleapis.com", "CODE_ASSIST_ENDPOINT" + ) + ) + + +async def get_compatibility_mode_enabled() -> bool: + """ + Get compatibility mode setting. + + 兼容性模式:启用后所有system消息全部转换成user,停用system_instructions。 + 该选项可能会降低模型理解能力,但是能避免流式空回的情况。 + + Environment variable: COMPATIBILITY_MODE + Database config key: compatibility_mode_enabled + Default: False + """ + env_value = os.getenv("COMPATIBILITY_MODE") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("compatibility_mode_enabled", False)) + + +async def get_return_thoughts_to_frontend() -> bool: + """ + Get return thoughts to frontend setting. + + 控制是否将思维链返回到前端。 + 启用后,思维链会在响应中返回;禁用后,思维链会在响应中被过滤掉。 + + Environment variable: RETURN_THOUGHTS_TO_FRONTEND + Database config key: return_thoughts_to_frontend + Default: True + """ + env_value = os.getenv("RETURN_THOUGHTS_TO_FRONTEND") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("return_thoughts_to_frontend", True)) + + +async def get_antigravity_stream2nostream() -> bool: + """ + Get use stream for non-stream setting. + + 控制antigravity非流式请求是否使用流式API并收集为完整响应。 + 启用后,非流式请求将在后端使用流式API,然后收集所有块后再返回完整响应。 + + Environment variable: ANTIGRAVITY_STREAM2NOSTREAM + Database config key: antigravity_stream2nostream + Default: True + """ + env_value = os.getenv("ANTIGRAVITY_STREAM2NOSTREAM") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("antigravity_stream2nostream", True)) + + +async def get_oauth_proxy_url() -> str: + """ + Get OAuth proxy URL setting. + + 用于Google OAuth2认证的代理URL。 + + Environment variable: OAUTH_PROXY_URL + Database config key: oauth_proxy_url + Default: https://oauth2.googleapis.com + """ + return str( + await get_config_value( + "oauth_proxy_url", "https://oauth2.googleapis.com", "OAUTH_PROXY_URL" + ) + ) + + +async def get_googleapis_proxy_url() -> str: + """ + Get Google APIs proxy URL setting. + + 用于Google APIs调用的代理URL。 + + Environment variable: GOOGLEAPIS_PROXY_URL + Database config key: googleapis_proxy_url + Default: https://www.googleapis.com + """ + return str( + await get_config_value( + "googleapis_proxy_url", "https://www.googleapis.com", "GOOGLEAPIS_PROXY_URL" + ) + ) + + +async def get_resource_manager_api_url() -> str: + """ + Get Google Cloud Resource Manager API URL setting. + + 用于Google Cloud Resource Manager API的URL。 + + Environment variable: RESOURCE_MANAGER_API_URL + Database config key: resource_manager_api_url + Default: https://cloudresourcemanager.googleapis.com + """ + return str( + await get_config_value( + "resource_manager_api_url", + "https://cloudresourcemanager.googleapis.com", + "RESOURCE_MANAGER_API_URL", + ) + ) + + +async def get_service_usage_api_url() -> str: + """ + Get Google Cloud Service Usage API URL setting. + + 用于Google Cloud Service Usage API的URL。 + + Environment variable: SERVICE_USAGE_API_URL + Database config key: service_usage_api_url + Default: https://serviceusage.googleapis.com + """ + return str( + await get_config_value( + "service_usage_api_url", "https://serviceusage.googleapis.com", "SERVICE_USAGE_API_URL" + ) + ) + + +async def get_antigravity_api_url() -> str: + """ + Get Antigravity API URL setting. + + 用于Google Antigravity API的URL。 + + Environment variable: ANTIGRAVITY_API_URL + Database config key: antigravity_api_url + Default: https://daily-cloudcode-pa.sandbox.googleapis.com + """ + return str( + await get_config_value( + "antigravity_api_url", + "https://daily-cloudcode-pa.sandbox.googleapis.com", + "ANTIGRAVITY_API_URL", + ) + ) diff --git a/darwin-install.sh b/darwin-install.sh new file mode 100644 index 0000000000000000000000000000000000000000..91681f26e2f240ad6495802aa0364b8a6768289f --- /dev/null +++ b/darwin-install.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# macOS 安装脚本 (支持 Intel 和 Apple Silicon) + +# 确保 Homebrew 已安装 +if ! command -v brew &> /dev/null; then + echo "未检测到 Homebrew,开始安装..." + /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" + + # 检测 Homebrew 安装路径并设置环境变量 + if [[ -f "/opt/homebrew/bin/brew" ]]; then + # Apple Silicon Mac + eval "$(/opt/homebrew/bin/brew shellenv)" + elif [[ -f "/usr/local/bin/brew" ]]; then + # Intel Mac + eval "$(/usr/local/bin/brew shellenv)" + fi +fi + +# 更新 brew 并安装 git +brew update +brew install git + +# 安装 uv (Python 环境管理工具) +curl -Ls https://astral.sh/uv/install.sh | sh + +# 确保 uv 在 PATH 中 +export PATH="$HOME/.local/bin:$PATH" + +# 克隆或进入项目目录 +if [ -f "./web.py" ]; then + # 已经在目标目录 + : +elif [ -f "./gcli2api/web.py" ]; then + cd ./gcli2api +else + git clone https://github.com/su-kaka/gcli2api.git + cd ./gcli2api +fi + +# 拉取最新代码 +git pull + +# 创建并同步虚拟环境 +uv sync + +# 激活虚拟环境 +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +else + echo "❌ 未找到虚拟环境,请检查 uv 是否安装成功" + exit 1 +fi + +# 启动项目 +python3 web.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..a68f73708d39ec89e1d18fe0ca0c035611d0c7fa --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,85 @@ +version: '3.8' + +services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # Password configuration (choose one) + # Option 1: Use common password + - PASSWORD=${PASSWORD:-pwd} + # Option 2: Use separate passwords (uncomment if needed) + # - API_PASSWORD=${API_PASSWORD:-your_api_password} + # - PANEL_PASSWORD=${PANEL_PASSWORD:-your_panel_password} + + # Server configuration + - PORT=${PORT:-7861} + - HOST=${HOST:-0.0.0.0} + + # Optional: Google credentials from environment + # - GOOGLE_CREDENTIALS=${GOOGLE_CREDENTIALS} + + # Optional: Logging configuration + # - LOG_LEVEL=${LOG_LEVEL:-info} + + # Optional: Redis configuration (for distributed storage) + # - REDIS_URI=${REDIS_URI} + # - REDIS_DATABASE=${REDIS_DATABASE:-0} + + # Optional: MongoDB configuration (for distributed storage) + # - MONGODB_URI=${MONGODB_URI} + # - MONGODB_DATABASE=${MONGODB_DATABASE:-gcli2api} + + # Optional: PostgreSQL configuration (for distributed storage) + # - POSTGRES_DSN=${POSTGRES_DSN} + + # Optional: Proxy configuration + # - PROXY=${PROXY} + volumes: + - ./data/creds:/app/creds + healthcheck: + test: ["CMD-SHELL", "python -c \"import sys, urllib.request, os; port = os.environ.get('PORT', '7861'); req = urllib.request.Request(f'http://localhost:{port}/v1/models', headers={'Authorization': 'Bearer ' + os.environ.get('PASSWORD', 'pwd')}); sys.exit(0 if urllib.request.urlopen(req, timeout=5).getcode() == 200 else 1)\""] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + +# Example with Redis for distributed storage +# redis: +# image: redis:7-alpine +# container_name: gcli2api-redis +# restart: unless-stopped +# ports: +# - "6379:6379" +# volumes: +# - redis_data:/data +# command: redis-server --appendonly yes +# healthcheck: +# test: ["CMD", "redis-cli", "ping"] +# interval: 10s +# timeout: 3s +# retries: 3 + +# Example with MongoDB for distributed storage +# mongodb: +# image: mongo:7 +# container_name: gcli2api-mongodb +# restart: unless-stopped +# environment: +# MONGO_INITDB_ROOT_USERNAME: ${MONGO_USER:-admin} +# MONGO_INITDB_ROOT_PASSWORD: ${MONGO_PASSWORD:-password} +# ports: +# - "27017:27017" +# volumes: +# - mongodb_data:/data/db +# healthcheck: +# test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] +# interval: 10s +# timeout: 5s +# retries: 3 + +#volumes: +# redis_data: +# mongodb_data: diff --git a/docs/README_EN.md b/docs/README_EN.md new file mode 100644 index 0000000000000000000000000000000000000000..413fda2b98059a30018ca20a65a861adeca815fa --- /dev/null +++ b/docs/README_EN.md @@ -0,0 +1,958 @@ +# GeminiCLI to API + +**Convert GeminiCLI and Antigravity to OpenAI, GEMINI, and Claude API Compatible Interfaces** + +[![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/) +[![License: CNC-1.0](https://img.shields.io/badge/License-CNC--1.0-red.svg)](../LICENSE) +[![Docker](https://img.shields.io/badge/docker-available-blue.svg)](https://github.com/su-kaka/gcli2api/pkgs/container/gcli2api) + +[中文](../README.md) | English + +## 🚀 Quick Deploy + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/97VMEF?referralCode=su-kaka) +--- + +## ⚠️ License Declaration + +**This project is licensed under the Cooperative Non-Commercial License (CNC-1.0)** + +This is a strict anti-commercial open source license. Please refer to the [LICENSE](../LICENSE) file for details. + +### ✅ Permitted Uses: +- Personal learning, research, and educational purposes +- Non-profit organization use +- Open source project integration (must comply with the same license) +- Academic research and publication + +### ❌ Prohibited Uses: +- Any form of commercial use +- Enterprise use with annual revenue exceeding $1 million +- Venture capital-backed or publicly traded companies +- Providing paid services or products +- Commercial competitive use + +--- + +## Core Features + +### 🔄 API Endpoints and Format Support + +**Multi-endpoint Multi-format Support** +- **OpenAI Compatible Endpoints**: `/v1/chat/completions` and `/v1/models` + - Supports standard OpenAI format (messages structure) + - Supports Gemini native format (contents structure) + - Automatic format detection and conversion, no manual switching required + - Supports multimodal input (text + images) +- **Gemini Native Endpoints**: `/v1/models/{model}:generateContent` and `streamGenerateContent` + - Supports complete Gemini native API specifications + - Multiple authentication methods: Bearer Token, x-goog-api-key header, URL parameter key +- **Claude Format Compatibility**: Full support for Claude API format + - Endpoint: `/v1/messages` (follows Claude API specification) + - Supports Claude standard messages format + - Supports system parameter and Claude-specific features + - Automatically converts to backend-supported format +- **Antigravity API Support**: Supports OpenAI, Gemini, and Claude formats + - OpenAI format endpoint: `/antigravity/v1/chat/completions` + - Gemini format endpoint: `/antigravity/v1/models/{model}:generateContent` and `streamGenerateContent` + - Claude format endpoint: `/antigravity/v1/messages` + - Supports all Antigravity models (Claude, Gemini, etc.) + - Automatic model name mapping and thinking mode detection + +### 🔐 Authentication and Security Management + +**Flexible Password Management** +- **Separate Password Support**: API password (chat endpoints) and control panel password can be set independently +- **Multiple Authentication Methods**: Supports Authorization Bearer, x-goog-api-key header, URL parameters, etc. +- **JWT Token Authentication**: Control panel supports JWT token authentication +- **User Email Retrieval**: Automatically retrieves and displays Google account email addresses + +### 📊 Intelligent Credential Management System + +**Advanced Credential Management** +- Multiple Google OAuth credential automatic rotation +- Enhanced stability through redundant authentication +- Load balancing and concurrent request support +- Automatic failure detection and credential disabling +- Credential usage statistics and quota management +- Support for manual enable/disable credential files +- Batch credential file operations (enable, disable, delete) + +**Credential Status Monitoring** +- Real-time credential health checks +- Error code tracking (429, 403, 500, etc.) +- Automatic banning mechanism (configurable) +- Credential rotation strategy (based on call count) +- Usage statistics and quota monitoring + +### 🌊 Streaming and Response Processing + +**Multiple Streaming Support** +- True real-time streaming responses +- Fake streaming mode (for compatibility) +- Streaming anti-truncation feature (prevents answer truncation) +- Asynchronous task management and timeout handling + +**Response Optimization** +- Thinking chain content separation +- Reasoning process (reasoning_content) handling +- Multi-turn conversation context management +- Compatibility mode (converts system messages to user messages) + +### 🎛️ Web Management Console + +**Full-featured Web Interface** +- OAuth authentication flow management (supports GCLI and Antigravity dual modes) +- Credential file upload, download, and management +- Real-time log viewing (WebSocket) +- System configuration management +- Usage statistics and monitoring dashboard +- Mobile-friendly interface + +**Batch Operation Support** +- ZIP file batch credential upload (GCLI and Antigravity) +- Batch enable/disable/delete credentials +- Batch user email retrieval +- Batch configuration management +- Unified batch upload interface for all credential types + +### 📈 Usage Statistics and Monitoring + +**Detailed Usage Statistics** +- Call count statistics by credential file +- Gemini 2.5 Pro model specific statistics +- Daily quota management (UTC+7 reset) +- Aggregated statistics and analysis +- Custom daily limit configuration + +**Real-time Monitoring** +- WebSocket real-time log streams +- System status monitoring +- Credential health status +- API call success rate statistics + +### 🔧 Advanced Configuration and Customization + +**Network and Proxy Configuration** +- HTTP/HTTPS proxy support +- Proxy endpoint configuration (OAuth, Google APIs, metadata service) +- Timeout and retry configuration +- Network error handling and recovery + +**Performance and Stability Configuration** +- 429 error automatic retry (configurable interval and attempts) +- Anti-truncation maximum retry attempts +- Credential rotation strategy +- Concurrent request management + +**Logging and Debugging** +- Multi-level logging system (DEBUG, INFO, WARNING, ERROR) +- Log file management +- Real-time log streams +- Log download and clearing + +### 🔄 Environment Variables and Configuration Management + +**Flexible Configuration Methods** +- TOML configuration file support +- Environment variable configuration +- Hot configuration updates (partial configuration items) +- Configuration locking (environment variable priority) + +**Environment Variable Credential Support** +- `GCLI_CREDS_*` format environment variable import +- Automatic loading of environment variable credentials +- Base64 encoded credential support +- Docker container friendly + +## Supported Models + +All models have 1M context window capacity. Each credential file provides 1000 request quota. + +### 🤖 Base Models +- `gemini-2.5-pro` +- `gemini-3-pro-preview` + +### 🧠 Thinking Models +- `gemini-2.5-pro-maxthinking`: Maximum thinking budget mode +- `gemini-2.5-pro-nothinking`: No thinking mode +- Supports custom thinking budget configuration +- Automatic separation of thinking content and final answers + +### 🔍 Search-Enhanced Models +- `gemini-2.5-pro-search`: Model with integrated search functionality + +### 🌊 Special Feature Variants +- **Fake Streaming Mode**: Add `-假流式` suffix to any model name + - Example: `gemini-2.5-pro-假流式` + - For scenarios requiring streaming responses but server doesn't support true streaming +- **Streaming Anti-truncation Mode**: Add `流式抗截断/` prefix to model name + - Example: `流式抗截断/gemini-2.5-pro` + - Automatically detects response truncation and retries to ensure complete answers + +### 🔧 Automatic Model Feature Detection +- System automatically recognizes feature identifiers in model names +- Transparently handles feature mode transitions +- Supports feature combination usage + +--- + +## Installation Guide + +### Termux Environment + +**Initial Installation** +```bash +curl -o termux-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/termux-install.sh" && chmod +x termux-install.sh && ./termux-install.sh +``` + +**Restart Service** +```bash +cd gcli2api +bash termux-start.sh +``` + +### Windows Environment + +**Initial Installation** +```powershell +iex (iwr "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.ps1" -UseBasicParsing).Content +``` + +**Restart Service** +Double-click to execute `start.bat` + +### Linux Environment + +**Initial Installation** +```bash +curl -o install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.sh" && chmod +x install.sh && ./install.sh +``` + +**Restart Service** +```bash +cd gcli2api +bash start.sh +``` + +### Docker Environment + +**Docker Run Command** +```bash +# Using universal password +docker run -d --name gcli2api --network host -e PASSWORD=pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest + +# Using separate passwords +docker run -d --name gcli2api --network host -e API_PASSWORD=api_pwd -e PANEL_PASSWORD=panel_pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose Run Command** +1. Save the following content as `docker-compose.yml` file: + ```yaml + version: '3.8' + + services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # Using universal password (recommended for simple deployment) + - PASSWORD=pwd + - PORT=7861 + # Or use separate passwords (recommended for production) + # - API_PASSWORD=your_api_password + # - PANEL_PASSWORD=your_panel_password + volumes: + - ./data/creds:/app/creds + healthcheck: + test: ["CMD-SHELL", "python -c \"import sys, urllib.request, os; port = os.environ.get('PORT', '7861'); req = urllib.request.Request(f'http://localhost:{port}/v1/models', headers={'Authorization': 'Bearer ' + os.environ.get('PASSWORD', 'pwd')}); sys.exit(0 if urllib.request.urlopen(req, timeout=5).getcode() == 200 else 1)\""] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + ``` +2. Start the service: + ```bash + docker-compose up -d + ``` + +--- + +## ⚠️ Important Notes + +- The current OAuth authentication process **only supports localhost access**, meaning authentication must be completed through `http://127.0.0.1:7861/auth` (default port 7861, modifiable via PORT environment variable). +- **For deployment on cloud servers or other remote environments, please first run the service locally and complete OAuth authentication to obtain the generated json credential files (located in the `./geminicli/creds` directory), then upload these files via the auth panel.** +- **Please strictly comply with usage restrictions, only for personal learning and non-commercial purposes** + +--- + +## Configuration Instructions + +1. Visit `http://127.0.0.1:7861/auth` (default port, modifiable via PORT environment variable) +2. Complete OAuth authentication flow (default password: `pwd`, modifiable via environment variables) + - **GCLI Mode**: For obtaining Google Cloud Gemini API credentials + - **Antigravity Mode**: For obtaining Google Antigravity API credentials +3. Configure client: + +**OpenAI Compatible Client:** + - **Endpoint Address**: `http://127.0.0.1:7861/v1` + - **API Key**: `pwd` (default value, modifiable via API_PASSWORD or PASSWORD environment variables) + +**Gemini Native Client:** + - **Endpoint Address**: `http://127.0.0.1:7861` + - **Authentication Methods**: + - `Authorization: Bearer your_api_password` + - `x-goog-api-key: your_api_password` + - URL parameter: `?key=your_api_password` + +### 🌟 Dual Authentication Mode Support + +**GCLI Authentication Mode** +- Standard Google Cloud Gemini API authentication +- Supports OAuth2.0 authentication flow +- Automatically enables required Google Cloud APIs + +**Antigravity Authentication Mode** +- Dedicated authentication for Google Antigravity API +- Independent credential management system +- Supports batch upload and management +- Completely isolated from GCLI credentials + +**Unified Management Interface** +- Manage both credential types in the "Batch Upload" tab +- Upper section: GCLI credential batch upload (blue theme) +- Lower section: Antigravity credential batch upload (green theme) +- Separate credential management tabs for each type + +## 💾 Data Storage Mode + +### 🌟 Storage Backend Support + +gcli2api supports two storage backends: **Local SQLite (Default)** and **MongoDB (Cloud Distributed Storage)** + +### 📁 Local SQLite Storage (Default) + +**Default Storage Method** +- No configuration required, works out of the box +- Data is stored in a local SQLite database +- Suitable for single-machine deployment and personal use +- Automatically creates and manages database files + +### 🍃 MongoDB Cloud Storage Mode + +**Cloud Distributed Storage Solution** + +When multi-instance deployment or cloud storage is needed, MongoDB storage mode can be enabled. + +### ⚙️ Enable MongoDB Mode + +**Step 1: Configure MongoDB Connection** +```bash +# Local MongoDB +export MONGODB_URI="mongodb://localhost:27017" + +# MongoDB Atlas cloud service +export MONGODB_URI="mongodb+srv://username:password@cluster.mongodb.net" + +# MongoDB with authentication +export MONGODB_URI="mongodb://admin:password@localhost:27017/admin" + +# Optional: Custom database name (default: gcli2api) +export MONGODB_DATABASE="my_gcli_db" +``` + +**Step 2: Start Application** +```bash +# Application will automatically detect MongoDB configuration and use MongoDB storage +python web.py +``` + +**Docker Environment using MongoDB** +```bash +# Single MongoDB deployment +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb://mongodb:27017" \ + -e API_PASSWORD=your_password \ + --network your_network \ + ghcr.io/su-kaka/gcli2api:latest + +# Using MongoDB Atlas +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb+srv://user:pass@cluster.mongodb.net/gcli2api" \ + -e API_PASSWORD=your_password \ + -p 7861:7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose Example** +```yaml +version: '3.8' + +services: + mongodb: + image: mongo:7 + container_name: gcli2api-mongodb + restart: unless-stopped + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password123 + volumes: + - mongodb_data:/data/db + ports: + - "27017:27017" + + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + depends_on: + - mongodb + environment: + - MONGODB_URI=mongodb://admin:password123@mongodb:27017/admin + - MONGODB_DATABASE=gcli2api + - API_PASSWORD=your_api_password + - PORT=7861 + ports: + - "7861:7861" + +volumes: + mongodb_data: +``` + +### 🛠️ Troubleshooting + +**Common Issue Solutions** + +```bash +# Check MongoDB connection +python mongodb_setup.py check + +# View detailed status information +python mongodb_setup.py status + +# Verify data migration results +python -c " +import asyncio +from src.storage_adapter import get_storage_adapter + +async def test(): + storage = await get_storage_adapter() + info = await storage.get_backend_info() + print(f'Current mode: {info[\"backend_type\"]}') + if info['backend_type'] == 'mongodb': + print(f'Database: {info.get(\"database_name\", \"Unknown\")}') + +asyncio.run(test()) +" +``` + +**Migration Failure Handling** +```bash +# If migration is interrupted, re-run +python mongodb_setup.py migrate + +# To rollback to local SQLite mode, remove MONGODB_URI environment variable +unset MONGODB_URI +# Then export data from MongoDB +python mongodb_setup.py export +``` + +### 🔧 Advanced Configuration + +**MongoDB Connection Optimization** +```bash +# Connection pool and timeout configuration +export MONGODB_URI="mongodb://localhost:27017?maxPoolSize=10&serverSelectionTimeoutMS=5000" + +# Replica set configuration +export MONGODB_URI="mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=myReplicaSet" + +# Read-write separation configuration +export MONGODB_URI="mongodb://localhost:27017/gcli2api?readPreference=secondaryPreferred" +``` + +## 🏗️ Technical Architecture + +### Core Module Description + +**Authentication and Credential Management** (`src/auth.py`, `src/credential_manager.py`) +- OAuth 2.0 authentication flow management +- Multi-credential file status management and rotation +- Automatic failure detection and recovery +- JWT token generation and validation + +**API Routing and Conversion** (`src/openai_router.py`, `src/gemini_router.py`, `src/openai_transfer.py`) +- OpenAI and Gemini format bidirectional conversion +- Multimodal input processing (text+images) +- Thinking chain content separation and processing +- Streaming response management + +**Network and Proxy** (`src/httpx_client.py`, `src/google_chat_api.py`) +- Unified HTTP client management +- Proxy configuration and hot update support +- Timeout and retry strategies +- Asynchronous request pool management + +**State Management** (`src/state_manager.py`, `src/usage_stats.py`) +- Atomic state operations +- Usage statistics and quota management +- File locking and concurrency safety +- Data persistence (TOML format) + +**Task Management** (`src/task_manager.py`) +- Global asynchronous task lifecycle management +- Resource cleanup and memory management +- Graceful shutdown and exception handling + +**Web Console** (`src/web_routes.py`) +- RESTful API endpoints +- WebSocket real-time communication +- Mobile device adaptation detection +- Batch operation support + +### Advanced Feature Implementation + +**Streaming Anti-truncation Mechanism** (`src/anti_truncation.py`) +- Response truncation pattern detection +- Automatic retry and state recovery +- Context connection management + +**Format Detection and Conversion** (`src/format_detector.py`) +- Automatic request format detection (OpenAI vs Gemini) +- Seamless format conversion +- Parameter mapping and validation + +**User Agent Simulation** (`src/utils.py`) +- GeminiCLI format user agent generation +- Platform detection and client metadata +- API compatibility guarantee + +### Environment Variable Configuration + +**Basic Configuration** +- `PORT`: Service port (default: 7861) +- `HOST`: Server listen address (default: 0.0.0.0) + +**Password Configuration** +- `API_PASSWORD`: Chat API access password (default: inherits PASSWORD or pwd) +- `PANEL_PASSWORD`: Control panel access password (default: inherits PASSWORD or pwd) +- `PASSWORD`: Universal password, overrides the above two when set (default: pwd) + +**Performance and Stability Configuration** +- `CALLS_PER_ROTATION`: Number of calls before each credential rotation (default: 10) +- `RETRY_429_ENABLED`: Enable 429 error automatic retry (default: true) +- `RETRY_429_MAX_RETRIES`: Maximum retry attempts for 429 errors (default: 3) +- `RETRY_429_INTERVAL`: Retry interval for 429 errors, in seconds (default: 1.0) +- `ANTI_TRUNCATION_MAX_ATTEMPTS`: Maximum retry attempts for anti-truncation (default: 3) + +**Network and Proxy Configuration** +- `PROXY`: HTTP/HTTPS proxy address (format: `http://host:port`) +- `OAUTH_PROXY_URL`: OAuth authentication proxy endpoint +- `GOOGLEAPIS_PROXY_URL`: Google APIs proxy endpoint +- `METADATA_SERVICE_URL`: Metadata service proxy endpoint + +**Automation Configuration** +- `AUTO_BAN`: Enable automatic credential banning (default: true) +- `AUTO_LOAD_ENV_CREDS`: Automatically load environment variable credentials at startup (default: false) + +**Compatibility Configuration** +- `COMPATIBILITY_MODE`: Enable compatibility mode, converts system messages to user messages (default: false) + +**Logging Configuration** +- `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO) +- `LOG_FILE`: Log file path (default: gcli2api.log) + +**Storage Configuration** + +**SQLite Configuration (Default)** +- No configuration required, automatically uses local SQLite database +- Database files are automatically created in the project directory + +**MongoDB Configuration (Optional Cloud Storage)** +- `MONGODB_URI`: MongoDB connection string (enables MongoDB mode when set) +- `MONGODB_DATABASE`: MongoDB database name (default: gcli2api) + +**Docker Usage Example** +```bash +# Using universal password +docker run -d --name gcli2api \ + -e PASSWORD=mypassword \ + -e PORT=11451 \ + -e GOOGLE_CREDENTIALS="$(cat credential.json | base64 -w 0)" \ + ghcr.io/su-kaka/gcli2api:latest + +# Using separate passwords +docker run -d --name gcli2api \ + -e API_PASSWORD=my_api_password \ + -e PANEL_PASSWORD=my_panel_password \ + -e PORT=11451 \ + -e GOOGLE_CREDENTIALS="$(cat credential.json | base64 -w 0)" \ + ghcr.io/su-kaka/gcli2api:latest +``` + +Note: When credential environment variables are set, the system will prioritize using credentials from environment variables and ignore files in the `creds` directory. + +### API Usage Methods + +This service supports multiple complete sets of API endpoints: + +#### 1. OpenAI Compatible Endpoints (GCLI) + +**Endpoint:** `/v1/chat/completions` +**Authentication:** `Authorization: Bearer your_api_password` + +Supports two request formats with automatic detection and processing: + +**OpenAI Format:** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "stream": true +} +``` + +**Gemini Native Format:** +```json +{ + "model": "gemini-2.5-pro", + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ], + "systemInstruction": {"parts": [{"text": "You are a helpful assistant"}]}, + "generationConfig": { + "temperature": 0.7 + } +} +``` + +#### 2. Gemini Native Endpoints (GCLI) + +**Non-streaming Endpoint:** `/v1/models/{model}:generateContent` +**Streaming Endpoint:** `/v1/models/{model}:streamGenerateContent` +**Model List:** `/v1/models` + +**Authentication Methods (choose one):** +- `Authorization: Bearer your_api_password` +- `x-goog-api-key: your_api_password` +- URL parameter: `?key=your_api_password` + +**Request Examples:** +```bash +# Using x-goog-api-key header +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:generateContent" \ + -H "x-goog-api-key: your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' + +# Using URL parameter +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:streamGenerateContent?key=your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' +``` + +#### 3. Claude API Format Endpoints + +**Endpoint:** `/v1/messages` +**Authentication:** `x-api-key: your_api_password` or `Authorization: Bearer your_api_password` + +**Request Example:** +```bash +curl -X POST "http://127.0.0.1:7861/v1/messages" \ + -H "x-api-key: your_api_password" \ + -H "anthropic-version: 2023-06-01" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, Claude!"} + ] + }' +``` + +**Support for system parameter:** +```json +{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "system": "You are a helpful assistant", + "messages": [ + {"role": "user", "content": "Hello"} + ] +} +``` + +**Notes:** +- Fully compatible with Claude API format specification +- Automatically converts to Gemini format for backend calls +- Supports all Claude standard parameters +- Response format follows Claude API specification + +#### 4. Antigravity API Endpoints + +**Supports three formats: OpenAI, Gemini, and Claude** + +##### Antigravity OpenAI Format Endpoints + +**Endpoint:** `/antigravity/v1/chat/completions` +**Authentication:** `Authorization: Bearer your_api_password` + +**Request Example:** +```bash +curl -X POST "http://127.0.0.1:7861/antigravity/v1/chat/completions" \ + -H "Authorization: Bearer your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "claude-sonnet-4-5", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true + }' +``` + +##### Antigravity Gemini Format Endpoints + +**Non-streaming Endpoint:** `/antigravity/v1/models/{model}:generateContent` +**Streaming Endpoint:** `/antigravity/v1/models/{model}:streamGenerateContent` + +**Authentication Methods (choose one):** +- `Authorization: Bearer your_api_password` +- `x-goog-api-key: your_api_password` +- URL parameter: `?key=your_api_password` + +**Request Examples:** +```bash +# Gemini format non-streaming request +curl -X POST "http://127.0.0.1:7861/antigravity/v1/models/claude-sonnet-4-5:generateContent" \ + -H "x-goog-api-key: your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ], + "generationConfig": { + "temperature": 0.7 + } + }' + +# Gemini format streaming request +curl -X POST "http://127.0.0.1:7861/antigravity/v1/models/gemini-2.5-flash:streamGenerateContent?key=your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' +``` + +##### Antigravity Claude Format Endpoints + +**Endpoint:** `/antigravity/v1/messages` +**Authentication:** `x-api-key: your_api_password` + +**Request Example:** +```bash +curl -X POST "http://127.0.0.1:7861/antigravity/v1/messages" \ + -H "x-api-key: your_api_password" \ + -H "anthropic-version: 2023-06-01" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "claude-sonnet-4-5", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }' +``` + +**Supported Antigravity Models:** +- Claude series: `claude-sonnet-4-5`, `claude-opus-4-5`, etc. +- Gemini series: `gemini-2.5-flash`, `gemini-2.5-pro`, etc. +- Automatically supports thinking models + +**Gemini Native Example:** +```python +from io import BytesIO +from PIL import Image +from google.genai import Client +from google.genai.types import HttpOptions +from google.genai import types +# The client gets the API key from the environment variable `GEMINI_API_KEY`. + +client = Client( + api_key="pwd", + http_options=HttpOptions(base_url="http://127.0.0.1:7861"), + ) + +prompt = ( + """ + Draw a cat + """ +) + +response = client.models.generate_content( + model="gemini-2.5-flash-image", + contents=[prompt], + config=types.GenerateContentConfig( + image_config=types.ImageConfig( + aspect_ratio="16:9", + ) + ) +) +for part in response.candidates[0].content.parts: + if part.text is not None: + print(part.text) + elif part.inline_data is not None: + image = Image.open(BytesIO(part.inline_data.data)) + image.save("generated_image.png") + +``` + +**Notes:** +- OpenAI endpoints return OpenAI-compatible format +- Gemini endpoints return Gemini native format +- Claude endpoints return Claude-compatible format +- All endpoints use the same API password + +## 📋 Complete API Reference + +### Web Console API + +**Authentication Endpoints** +- `POST /auth/login` - User login +- `POST /auth/start` - Start GCLI OAuth authentication +- `POST /auth/antigravity/start` - Start Antigravity OAuth authentication +- `POST /auth/callback` - Handle OAuth callback +- `GET /auth/status/{project_id}` - Check authentication status +- `GET /auth/antigravity/credentials` - Get Antigravity credentials + +**GCLI Credential Management Endpoints** +- `GET /creds/status` - Get all GCLI credential statuses +- `POST /creds/action` - Single GCLI credential operation (enable/disable/delete) +- `POST /creds/batch-action` - Batch GCLI credential operations +- `POST /auth/upload` - Batch upload GCLI credential files (supports ZIP) +- `GET /creds/download/{filename}` - Download GCLI credential file +- `GET /creds/download-all` - Package download all GCLI credentials +- `POST /creds/fetch-email/{filename}` - Get GCLI user email +- `POST /creds/refresh-all-emails` - Batch refresh GCLI user emails + +**Antigravity Credential Management Endpoints** +- `GET /antigravity/creds/status` - Get all Antigravity credential statuses +- `POST /antigravity/creds/action` - Single Antigravity credential operation (enable/disable/delete) +- `POST /antigravity/creds/batch-action` - Batch Antigravity credential operations +- `POST /antigravity/auth/upload` - Batch upload Antigravity credential files (supports ZIP) +- `GET /antigravity/creds/download/{filename}` - Download Antigravity credential file +- `GET /antigravity/creds/download-all` - Package download all Antigravity credentials +- `POST /antigravity/creds/fetch-email/{filename}` - Get Antigravity user email +- `POST /antigravity/creds/refresh-all-emails` - Batch refresh Antigravity user emails + +**Configuration Management Endpoints** +- `GET /config/get` - Get current configuration +- `POST /config/save` - Save configuration + +**Environment Variable Credential Endpoints** +- `POST /auth/load-env-creds` - Load environment variable credentials +- `DELETE /auth/env-creds` - Clear environment variable credentials +- `GET /auth/env-creds-status` - Get environment variable credential status + +**Log Management Endpoints** +- `POST /auth/logs/clear` - Clear logs +- `GET /auth/logs/download` - Download log file +- `WebSocket /auth/logs/stream` - Real-time log stream + +**Usage Statistics Endpoints** +- `GET /usage/stats` - Get usage statistics +- `GET /usage/aggregated` - Get aggregated statistics +- `POST /usage/update-limits` - Update usage limits +- `POST /usage/reset` - Reset usage statistics + +### Chat API Features + +**Multimodal Support** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": { + "url": "..." + } + } + ] + } + ] +} +``` + +**Thinking Mode Support** +```json +{ + "model": "gemini-2.5-pro-maxthinking", + "messages": [ + {"role": "user", "content": "Complex math problem"} + ] +} +``` + +Response will include separated thinking content: +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "Final answer", + "reasoning_content": "Detailed thought process..." + } + }] +} +``` + +**Streaming Anti-truncation Usage** +```json +{ + "model": "流式抗截断/gemini-2.5-pro", + "messages": [ + {"role": "user", "content": "Write a long article"} + ], + "stream": true +} +``` + +**Compatibility Mode** +```bash +# Enable compatibility mode +export COMPATIBILITY_MODE=true +``` +In this mode, all `system` messages are converted to `user` messages, improving compatibility with certain clients. + +--- + +## License and Disclaimer + +This project is for learning and research purposes only. Using this project indicates that you agree to: +- Not use this project for any commercial purposes +- Bear all risks and responsibilities of using this project +- Comply with relevant terms of service and legal regulations + +The project authors are not responsible for any direct or indirect losses arising from the use of this project. diff --git "a/docs/qq\347\276\244.jpg" "b/docs/qq\347\276\244.jpg" new file mode 100644 index 0000000000000000000000000000000000000000..de803d000bc808c3abe1241ddbe75a1a0509917a --- /dev/null +++ "b/docs/qq\347\276\244.jpg" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58a7061b633748a639326e26636f93ef34d92b73b2c513743218f0f0d2c15247 +size 193283 diff --git a/front/common.js b/front/common.js new file mode 100644 index 0000000000000000000000000000000000000000..0988f7d722092073be133c311709d0fa203f2070 --- /dev/null +++ b/front/common.js @@ -0,0 +1,2669 @@ +// ===================================================================== +// GCLI2API 控制面板公共JavaScript模块 +// ===================================================================== + +// ===================================================================== +// 全局状态管理 +// ===================================================================== +const AppState = { + // 认证相关 + authToken: '', + authInProgress: false, + currentProjectId: '', + + // Antigravity认证 + antigravityAuthState: null, + antigravityAuthInProgress: false, + + // 凭证管理 + creds: createCredsManager('normal'), + antigravityCreds: createCredsManager('antigravity'), + + // 文件上传 + uploadFiles: createUploadManager('normal'), + antigravityUploadFiles: createUploadManager('antigravity'), + + // 配置管理 + currentConfig: {}, + envLockedFields: new Set(), + + // 日志管理 + logWebSocket: null, + allLogs: [], + filteredLogs: [], + currentLogFilter: 'all', + + // 使用统计 + usageStatsData: {}, + + // 冷却倒计时 + cooldownTimerInterval: null +}; + +// ===================================================================== +// 凭证管理器工厂 +// ===================================================================== +function createCredsManager(type) { + const modeParam = type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + + return { + type: type, + data: {}, + filteredData: {}, + currentPage: 1, + pageSize: 20, + selectedFiles: new Set(), + totalCount: 0, + currentStatusFilter: 'all', + currentErrorCodeFilter: 'all', + currentCooldownFilter: 'all', + statsData: { total: 0, normal: 0, disabled: 0 }, + + // API端点 + getEndpoint: (action) => { + const endpoints = { + status: `./creds/status`, + action: `./creds/action`, + batchAction: `./creds/batch-action`, + download: `./creds/download`, + downloadAll: `./creds/download-all`, + detail: `./creds/detail`, + fetchEmail: `./creds/fetch-email`, + refreshAllEmails: `./creds/refresh-all-emails`, + deduplicate: `./creds/deduplicate-by-email`, + verifyProject: `./creds/verify-project`, + quota: `./creds/quota` + }; + return endpoints[action] || ''; + }, + + // 获取mode参数 + getModeParam: () => modeParam, + + // DOM元素ID前缀 + getElementId: (suffix) => { + // 普通凭证的ID首字母小写,如 credsLoading + // Antigravity的ID是 antigravity + 首字母大写,如 antigravityCredsLoading + if (type === 'antigravity') { + return 'antigravity' + suffix.charAt(0).toUpperCase() + suffix.slice(1); + } + return suffix.charAt(0).toLowerCase() + suffix.slice(1); + }, + + // 刷新凭证列表 + async refresh() { + const loading = document.getElementById(this.getElementId('CredsLoading')); + const list = document.getElementById(this.getElementId('CredsList')); + + try { + loading.style.display = 'block'; + list.innerHTML = ''; + + const offset = (this.currentPage - 1) * this.pageSize; + const errorCodeFilter = this.currentErrorCodeFilter || 'all'; + const cooldownFilter = this.currentCooldownFilter || 'all'; + const response = await fetch( + `${this.getEndpoint('status')}?offset=${offset}&limit=${this.pageSize}&status_filter=${this.currentStatusFilter}&error_code_filter=${errorCodeFilter}&cooldown_filter=${cooldownFilter}&${this.getModeParam()}`, + { headers: getAuthHeaders() } + ); + + const data = await response.json(); + + if (response.ok) { + this.data = {}; + data.items.forEach(item => { + this.data[item.filename] = { + filename: item.filename, + status: { + disabled: item.disabled, + error_codes: item.error_codes || [], + last_success: item.last_success, + }, + user_email: item.user_email, + model_cooldowns: item.model_cooldowns || {} + }; + }); + + this.totalCount = data.total; + // 使用后端返回的全局统计数据 + if (data.stats) { + this.statsData = data.stats; + } else { + // 兼容旧版本后端 + this.calculateStats(); + } + this.updateStatsDisplay(); + this.filteredData = this.data; + this.renderList(); + this.updatePagination(); + + let msg = `已加载 ${data.total} 个${type === 'antigravity' ? 'Antigravity' : ''}凭证文件`; + if (this.currentStatusFilter !== 'all') { + msg += ` (筛选: ${this.currentStatusFilter === 'enabled' ? '仅启用' : '仅禁用'})`; + } + showStatus(msg, 'success'); + } else { + showStatus(`加载失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } + }, + + // 计算统计数据(仅用于兼容旧版本后端) + calculateStats() { + this.statsData = { total: this.totalCount, normal: 0, disabled: 0 }; + Object.values(this.data).forEach(credInfo => { + if (credInfo.status.disabled) { + this.statsData.disabled++; + } else { + this.statsData.normal++; + } + }); + }, + + // 更新统计显示 + updateStatsDisplay() { + document.getElementById(this.getElementId('StatTotal')).textContent = this.statsData.total; + document.getElementById(this.getElementId('StatNormal')).textContent = this.statsData.normal; + document.getElementById(this.getElementId('StatDisabled')).textContent = this.statsData.disabled; + }, + + // 渲染凭证列表 + renderList() { + const list = document.getElementById(this.getElementId('CredsList')); + list.innerHTML = ''; + + const entries = Object.entries(this.filteredData); + + if (entries.length === 0) { + const msg = this.totalCount === 0 ? '暂无凭证文件' : '当前筛选条件下暂无数据'; + list.innerHTML = `

${msg}

`; + document.getElementById(this.getElementId('PaginationContainer')).style.display = 'none'; + return; + } + + entries.forEach(([, credInfo]) => { + list.appendChild(createCredCard(credInfo, this)); + }); + + document.getElementById(this.getElementId('PaginationContainer')).style.display = + this.getTotalPages() > 1 ? 'flex' : 'none'; + this.updateBatchControls(); + }, + + // 获取总页数 + getTotalPages() { + return Math.ceil(this.totalCount / this.pageSize); + }, + + // 更新分页信息 + updatePagination() { + const totalPages = this.getTotalPages(); + const startItem = (this.currentPage - 1) * this.pageSize + 1; + const endItem = Math.min(this.currentPage * this.pageSize, this.totalCount); + + document.getElementById(this.getElementId('PaginationInfo')).textContent = + `第 ${this.currentPage} 页,共 ${totalPages} 页 (显示 ${startItem}-${endItem},共 ${this.totalCount} 项)`; + + document.getElementById(this.getElementId('PrevPageBtn')).disabled = this.currentPage <= 1; + document.getElementById(this.getElementId('NextPageBtn')).disabled = this.currentPage >= totalPages; + }, + + // 切换页面 + changePage(direction) { + const newPage = this.currentPage + direction; + if (newPage >= 1 && newPage <= this.getTotalPages()) { + this.currentPage = newPage; + this.refresh(); + } + }, + + // 改变每页大小 + changePageSize() { + this.pageSize = parseInt(document.getElementById(this.getElementId('PageSizeSelect')).value); + this.currentPage = 1; + this.refresh(); + }, + + // 应用状态筛选 + applyStatusFilter() { + this.currentStatusFilter = document.getElementById(this.getElementId('StatusFilter')).value; + const errorCodeFilterEl = document.getElementById(this.getElementId('ErrorCodeFilter')); + const cooldownFilterEl = document.getElementById(this.getElementId('CooldownFilter')); + this.currentErrorCodeFilter = errorCodeFilterEl ? errorCodeFilterEl.value : 'all'; + this.currentCooldownFilter = cooldownFilterEl ? cooldownFilterEl.value : 'all'; + this.currentPage = 1; + this.refresh(); + }, + + // 更新批量控件 + updateBatchControls() { + const selectedCount = this.selectedFiles.size; + document.getElementById(this.getElementId('SelectedCount')).textContent = `已选择 ${selectedCount} 项`; + + const batchBtns = ['Enable', 'Disable', 'Delete', 'Verify'].map(action => + document.getElementById(this.getElementId(`Batch${action}Btn`)) + ); + batchBtns.forEach(btn => btn && (btn.disabled = selectedCount === 0)); + + const selectAllCheckbox = document.getElementById(this.getElementId('SelectAllCheckbox')); + if (!selectAllCheckbox) return; + + const checkboxes = document.querySelectorAll(`.${this.getElementId('file-checkbox')}`); + const currentPageSelectedCount = Array.from(checkboxes) + .filter(cb => this.selectedFiles.has(cb.getAttribute('data-filename'))).length; + + if (currentPageSelectedCount === 0) { + selectAllCheckbox.indeterminate = false; + selectAllCheckbox.checked = false; + } else if (currentPageSelectedCount === checkboxes.length) { + selectAllCheckbox.indeterminate = false; + selectAllCheckbox.checked = true; + } else { + selectAllCheckbox.indeterminate = true; + } + + checkboxes.forEach(cb => { + cb.checked = this.selectedFiles.has(cb.getAttribute('data-filename')); + }); + }, + + // 凭证操作 + async action(filename, action) { + try { + const response = await fetch(`${this.getEndpoint('action')}?${this.getModeParam()}`, { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ filename, action }) + }); + + const data = await response.json(); + + if (response.ok) { + showStatus(data.message || `操作成功: ${action}`, 'success'); + await this.refresh(); + } else { + showStatus(`操作失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } + }, + + // 批量操作 + async batchAction(action) { + const selectedFiles = Array.from(this.selectedFiles); + + if (selectedFiles.length === 0) { + showStatus('请先选择要操作的文件', 'error'); + return; + } + + const actionNames = { enable: '启用', disable: '禁用', delete: '删除' }; + const confirmMsg = action === 'delete' + ? `确定要删除选中的 ${selectedFiles.length} 个文件吗?\n注意:此操作不可恢复!` + : `确定要${actionNames[action]}选中的 ${selectedFiles.length} 个文件吗?`; + + if (!confirm(confirmMsg)) return; + + try { + showStatus(`正在执行批量${actionNames[action]}操作...`, 'info'); + + const response = await fetch(`${this.getEndpoint('batchAction')}?${this.getModeParam()}`, { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ action, filenames: selectedFiles }) + }); + + const data = await response.json(); + + if (response.ok) { + const successCount = data.success_count || data.succeeded; + showStatus(`批量操作完成:成功处理 ${successCount}/${selectedFiles.length} 个文件`, 'success'); + this.selectedFiles.clear(); + this.updateBatchControls(); + await this.refresh(); + } else { + showStatus(`批量操作失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`批量操作网络错误: ${error.message}`, 'error'); + } + } + }; +} + +// ===================================================================== +// 文件上传管理器工厂 +// ===================================================================== +function createUploadManager(type) { + const modeParam = type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + const endpoint = `./auth/upload?${modeParam}`; + + return { + type: type, + selectedFiles: [], + + getElementId: (suffix) => { + // 普通上传的ID首字母小写,如 fileList + // Antigravity的ID是 antigravity + 首字母大写,如 antigravityFileList + if (type === 'antigravity') { + return 'antigravity' + suffix.charAt(0).toUpperCase() + suffix.slice(1); + } + return suffix.charAt(0).toLowerCase() + suffix.slice(1); + }, + + handleFileSelect(event) { + this.addFiles(Array.from(event.target.files)); + }, + + addFiles(files) { + files.forEach(file => { + const isValid = file.type === 'application/json' || file.name.endsWith('.json') || + file.type === 'application/zip' || file.name.endsWith('.zip'); + + if (isValid) { + if (!this.selectedFiles.find(f => f.name === file.name && f.size === file.size)) { + this.selectedFiles.push(file); + } + } else { + showStatus(`文件 ${file.name} 格式不支持,只支持JSON和ZIP文件`, 'error'); + } + }); + this.updateFileList(); + }, + + updateFileList() { + const list = document.getElementById(this.getElementId('FileList')); + const section = document.getElementById(this.getElementId('FileListSection')); + + if (!list || !section) { + console.warn('File list elements not found:', this.getElementId('FileList')); + return; + } + + if (this.selectedFiles.length === 0) { + section.classList.add('hidden'); + return; + } + + section.classList.remove('hidden'); + list.innerHTML = ''; + + this.selectedFiles.forEach((file, index) => { + const isZip = file.name.endsWith('.zip'); + const fileIcon = isZip ? '📦' : '📄'; + const fileType = isZip ? ' (ZIP压缩包)' : ' (JSON文件)'; + + const fileItem = document.createElement('div'); + fileItem.className = 'file-item'; + fileItem.innerHTML = ` +
+ ${fileIcon} ${file.name} + (${formatFileSize(file.size)}${fileType}) +
+ + `; + list.appendChild(fileItem); + }); + }, + + removeFile(index) { + this.selectedFiles.splice(index, 1); + this.updateFileList(); + }, + + clearFiles() { + this.selectedFiles = []; + this.updateFileList(); + }, + + async upload() { + if (this.selectedFiles.length === 0) { + showStatus('请选择要上传的文件', 'error'); + return; + } + + const progressSection = document.getElementById(this.getElementId('UploadProgressSection')); + const progressFill = document.getElementById(this.getElementId('ProgressFill')); + const progressText = document.getElementById(this.getElementId('ProgressText')); + + progressSection.classList.remove('hidden'); + + const formData = new FormData(); + this.selectedFiles.forEach(file => formData.append('files', file)); + + if (this.selectedFiles.some(f => f.name.endsWith('.zip'))) { + showStatus('正在上传并解压ZIP文件...', 'info'); + } + + try { + const xhr = new XMLHttpRequest(); + xhr.timeout = 300000; // 5分钟 + + xhr.upload.onprogress = (event) => { + if (event.lengthComputable) { + const percent = (event.loaded / event.total) * 100; + progressFill.style.width = percent + '%'; + progressText.textContent = Math.round(percent) + '%'; + } + }; + + xhr.onload = () => { + if (xhr.status === 200) { + try { + const data = JSON.parse(xhr.responseText); + showStatus(`成功上传 ${data.uploaded_count} 个${type === 'antigravity' ? 'Antigravity' : ''}文件`, 'success'); + this.clearFiles(); + progressSection.classList.add('hidden'); + } catch (e) { + showStatus('上传失败: 服务器响应格式错误', 'error'); + } + } else { + try { + const error = JSON.parse(xhr.responseText); + showStatus(`上传失败: ${error.detail || error.error || '未知错误'}`, 'error'); + } catch (e) { + showStatus(`上传失败: HTTP ${xhr.status}`, 'error'); + } + } + }; + + xhr.onerror = () => { + showStatus(`上传失败:连接中断 - 可能原因:文件过多(${this.selectedFiles.length}个)或网络不稳定。建议分批上传。`, 'error'); + progressSection.classList.add('hidden'); + }; + + xhr.ontimeout = () => { + showStatus('上传失败:请求超时 - 文件处理时间过长,请减少文件数量或检查网络连接', 'error'); + progressSection.classList.add('hidden'); + }; + + xhr.open('POST', endpoint); + xhr.setRequestHeader('Authorization', `Bearer ${AppState.authToken}`); + xhr.send(formData); + } catch (error) { + showStatus(`上传失败: ${error.message}`, 'error'); + } + } + }; +} + +// ===================================================================== +// 工具函数 +// ===================================================================== +function showStatus(message, type = 'info') { + const statusSection = document.getElementById('statusSection'); + if (statusSection) { + // 清除之前的定时器 + if (window._statusTimeout) { + clearTimeout(window._statusTimeout); + } + + // 创建新的 toast + statusSection.innerHTML = `
${message}
`; + const statusDiv = statusSection.querySelector('.status'); + + // 强制重绘以触发动画 + statusDiv.offsetHeight; + statusDiv.classList.add('show'); + + // 3秒后淡出并移除 + window._statusTimeout = setTimeout(() => { + statusDiv.classList.add('fade-out'); + setTimeout(() => { + statusSection.innerHTML = ''; + }, 300); // 等待淡出动画完成 + }, 3000); + } else { + alert(message); + } +} + +function getAuthHeaders() { + return { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${AppState.authToken}` + }; +} + +function formatFileSize(bytes) { + if (bytes < 1024) return bytes + ' B'; + if (bytes < 1024 * 1024) return Math.round(bytes / 1024) + ' KB'; + return Math.round(bytes / (1024 * 1024)) + ' MB'; +} + +function formatCooldownTime(remainingSeconds) { + const hours = Math.floor(remainingSeconds / 3600); + const minutes = Math.floor((remainingSeconds % 3600) / 60); + const seconds = remainingSeconds % 60; + + if (hours > 0) return `${hours}h ${minutes}m ${seconds}s`; + if (minutes > 0) return `${minutes}m ${seconds}s`; + return `${seconds}s`; +} + +// ===================================================================== +// 凭证卡片创建(通用) +// ===================================================================== +function createCredCard(credInfo, manager) { + const div = document.createElement('div'); + const { status, filename } = credInfo; + const managerType = manager.type; + + // 卡片样式 + div.className = status.disabled ? 'cred-card disabled' : 'cred-card'; + + // 状态徽章 + let statusBadges = ''; + statusBadges += status.disabled + ? '已禁用' + : '已启用'; + + if (status.error_codes && status.error_codes.length > 0) { + statusBadges += `错误码: ${status.error_codes.join(', ')}`; + const autoBan = status.error_codes.filter(c => c === 400 || c === 403); + if (autoBan.length > 0 && status.disabled) { + statusBadges += 'AUTO_BAN'; + } + } else { + statusBadges += '无错误'; + } + + // 模型级冷却状态 + if (credInfo.model_cooldowns && Object.keys(credInfo.model_cooldowns).length > 0) { + const currentTime = Date.now() / 1000; + const activeCooldowns = Object.entries(credInfo.model_cooldowns) + .filter(([, until]) => until > currentTime) + .map(([model, until]) => { + const remaining = Math.max(0, Math.floor(until - currentTime)); + const shortModel = model.replace('gemini-', '').replace('-exp', '') + .replace('2.0-', '2-').replace('1.5-', '1.5-'); + return { + model: shortModel, + time: formatCooldownTime(remaining).replace(/s$/, '').replace(/ /g, ''), + fullModel: model + }; + }); + + if (activeCooldowns.length > 0) { + activeCooldowns.slice(0, 2).forEach(item => { + statusBadges += `🔧 ${item.model}: ${item.time}`; + }); + if (activeCooldowns.length > 2) { + const remaining = activeCooldowns.length - 2; + const remainingModels = activeCooldowns.slice(2).map(i => `${i.fullModel}: ${i.time}`).join('\n'); + statusBadges += `+${remaining}`; + } + } + } + + // 路径ID + const pathId = (managerType === 'antigravity' ? 'ag_' : '') + btoa(encodeURIComponent(filename)).replace(/[+/=]/g, '_'); + + // 操作按钮 + const actionButtons = ` + ${status.disabled + ? `` + : `` + } + + + + ${managerType === 'antigravity' ? `` : ''} + + + `; + + // 邮箱信息 + const emailInfo = credInfo.user_email + ? `
${credInfo.user_email}
` + : '
未获取邮箱
'; + + const checkboxClass = manager.getElementId('file-checkbox'); + + div.innerHTML = ` +
+
+ +
+
${filename}
+ ${emailInfo} +
+
+
${statusBadges}
+
+
${actionButtons}
+
+
点击"查看内容"按钮加载文件详情...
+
+ ${managerType === 'antigravity' ? ` + + ` : ''} + `; + + // 添加事件监听 + div.querySelectorAll('[data-filename][data-action]').forEach(button => { + button.addEventListener('click', function () { + const fn = this.getAttribute('data-filename'); + const action = this.getAttribute('data-action'); + if (action === 'delete') { + if (confirm(`确定要删除${managerType === 'antigravity' ? ' Antigravity ' : ''}凭证文件吗?\n${fn}`)) { + manager.action(fn, action); + } + } else { + manager.action(fn, action); + } + }); + }); + + return div; +} + +// ===================================================================== +// 凭证详情切换 +// ===================================================================== +async function toggleCredDetails(pathId) { + await toggleCredDetailsCommon(pathId, AppState.creds); +} + +async function toggleAntigravityCredDetails(pathId) { + await toggleCredDetailsCommon(pathId, AppState.antigravityCreds); +} + +async function toggleCredDetailsCommon(pathId, manager) { + const details = document.getElementById('details-' + pathId); + if (!details) return; + + const isShowing = details.classList.toggle('show'); + + if (isShowing) { + const contentDiv = details.querySelector('.cred-content'); + const filename = contentDiv.getAttribute('data-filename'); + const loaded = contentDiv.getAttribute('data-loaded'); + + if (loaded === 'false' && filename) { + contentDiv.textContent = '正在加载文件内容...'; + + try { + const modeParam = manager.type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + const endpoint = `./creds/detail/${encodeURIComponent(filename)}?${modeParam}`; + + const response = await fetch(endpoint, { headers: getAuthHeaders() }); + + const data = await response.json(); + if (response.ok && data.content) { + contentDiv.textContent = JSON.stringify(data.content, null, 2); + contentDiv.setAttribute('data-loaded', 'true'); + } else { + contentDiv.textContent = '无法加载文件内容: ' + (data.error || data.detail || '未知错误'); + } + } catch (error) { + contentDiv.textContent = '加载文件内容失败: ' + error.message; + } + } + } +} + +// ===================================================================== +// 登录相关函数 +// ===================================================================== +async function login() { + const password = document.getElementById('loginPassword').value; + + if (!password) { + showStatus('请输入密码', 'error'); + return; + } + + try { + const response = await fetch('./auth/login', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ password }) + }); + + const data = await response.json(); + + if (response.ok) { + AppState.authToken = data.token; + localStorage.setItem('gcli2api_auth_token', AppState.authToken); + document.getElementById('loginSection').classList.add('hidden'); + document.getElementById('mainSection').classList.remove('hidden'); + showStatus('登录成功', 'success'); + // 显示面板后初始化滑块 + requestAnimationFrame(() => initTabSlider()); + } else { + showStatus(`登录失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function autoLogin() { + const savedToken = localStorage.getItem('gcli2api_auth_token'); + if (!savedToken) return false; + + AppState.authToken = savedToken; + + try { + const response = await fetch('./config/get', { + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${AppState.authToken}` + } + }); + + if (response.ok) { + document.getElementById('loginSection').classList.add('hidden'); + document.getElementById('mainSection').classList.remove('hidden'); + showStatus('自动登录成功', 'success'); + // 显示面板后初始化滑块 + requestAnimationFrame(() => initTabSlider()); + return true; + } else if (response.status === 401) { + localStorage.removeItem('gcli2api_auth_token'); + AppState.authToken = ''; + return false; + } + return false; + } catch (error) { + return false; + } +} + +function logout() { + localStorage.removeItem('gcli2api_auth_token'); + AppState.authToken = ''; + document.getElementById('loginSection').classList.remove('hidden'); + document.getElementById('mainSection').classList.add('hidden'); + showStatus('已退出登录', 'info'); + const passwordInput = document.getElementById('loginPassword'); + if (passwordInput) passwordInput.value = ''; +} + +function handlePasswordEnter(event) { + if (event.key === 'Enter') login(); +} + +// ===================================================================== +// 标签页切换 +// ===================================================================== + +// 更新滑块位置 +function updateTabSlider(targetTab, animate = true) { + const slider = document.querySelector('.tab-slider'); + const tabs = document.querySelector('.tabs'); + if (!slider || !tabs || !targetTab) return; + + // 获取按钮位置和容器宽度 + const tabLeft = targetTab.offsetLeft; + const tabWidth = targetTab.offsetWidth; + const tabsWidth = tabs.scrollWidth; + + // 使用 left 和 right 同时控制,确保动画同步 + const rightValue = tabsWidth - tabLeft - tabWidth; + + if (animate) { + slider.style.left = `${tabLeft}px`; + slider.style.right = `${rightValue}px`; + } else { + // 首次加载时不使用动画 + slider.style.transition = 'none'; + slider.style.left = `${tabLeft}px`; + slider.style.right = `${rightValue}px`; + // 强制重绘后恢复过渡 + slider.offsetHeight; + slider.style.transition = ''; + } +} + +// 初始化滑块位置 +function initTabSlider() { + const activeTab = document.querySelector('.tab.active'); + if (activeTab) { + updateTabSlider(activeTab, false); + } +} + +// 页面加载和窗口大小变化时初始化滑块 +document.addEventListener('DOMContentLoaded', initTabSlider); +window.addEventListener('resize', () => { + const activeTab = document.querySelector('.tab.active'); + if (activeTab) updateTabSlider(activeTab, false); +}); + +function switchTab(tabName) { + // 获取当前活动的内容区域 + const currentContent = document.querySelector('.tab-content.active'); + const targetContent = document.getElementById(tabName + 'Tab'); + + // 如果点击的是当前标签页,不做任何操作 + if (currentContent === targetContent) return; + + // 找到目标标签按钮 + const targetTab = event && event.target ? event.target : + document.querySelector(`.tab[onclick*="'${tabName}'"]`); + + // 移除所有标签页的active状态 + document.querySelectorAll('.tab').forEach(tab => tab.classList.remove('active')); + + // 添加当前点击标签的active状态 + if (targetTab) { + targetTab.classList.add('active'); + // 更新滑块位置(带动画) + updateTabSlider(targetTab, true); + } + + // 淡出当前内容 + if (currentContent) { + // 设置淡出过渡 + currentContent.style.transition = 'opacity 0.18s ease-out, transform 0.18s ease-out'; + currentContent.style.opacity = '0'; + currentContent.style.transform = 'translateX(-12px)'; + + setTimeout(() => { + currentContent.classList.remove('active'); + currentContent.style.transition = ''; + currentContent.style.opacity = ''; + currentContent.style.transform = ''; + + // 淡入新内容 + if (targetContent) { + // 先设置初始状态(在添加 active 类之前) + targetContent.style.opacity = '0'; + targetContent.style.transform = 'translateX(12px)'; + targetContent.style.transition = 'none'; // 暂时禁用过渡 + + // 添加 active 类使元素可见 + targetContent.classList.add('active'); + + // 使用双重 requestAnimationFrame 确保浏览器完成重绘 + requestAnimationFrame(() => { + requestAnimationFrame(() => { + // 启用过渡并应用最终状态 + targetContent.style.transition = 'opacity 0.25s ease-out, transform 0.25s ease-out'; + targetContent.style.opacity = '1'; + targetContent.style.transform = 'translateX(0)'; + + // 清理内联样式并执行数据加载 + setTimeout(() => { + targetContent.style.transition = ''; + targetContent.style.opacity = ''; + targetContent.style.transform = ''; + + // 动画完成后触发数据加载 + triggerTabDataLoad(tabName); + }, 260); + }); + }); + } + }, 180); + } else { + // 如果没有当前内容(首次加载),直接显示目标内容 + if (targetContent) { + targetContent.classList.add('active'); + // 直接触发数据加载 + triggerTabDataLoad(tabName); + } + } +} + +// 标签页数据加载(从动画中分离出来) +function triggerTabDataLoad(tabName) { + if (tabName === 'manage') AppState.creds.refresh(); + if (tabName === 'antigravity-manage') AppState.antigravityCreds.refresh(); + if (tabName === 'config') loadConfig(); + if (tabName === 'logs') connectWebSocket(); +} + + +// ===================================================================== +// OAuth认证相关函数 +// ===================================================================== +async function startAuth() { + const projectId = document.getElementById('projectId').value.trim(); + AppState.currentProjectId = projectId || null; + + const btn = document.getElementById('getAuthBtn'); + btn.disabled = true; + btn.textContent = '正在获取认证链接...'; + + try { + const requestBody = projectId ? { project_id: projectId } : {}; + showStatus(projectId ? '使用指定的项目ID生成认证链接...' : '将尝试自动检测项目ID,正在生成认证链接...', 'info'); + + const response = await fetch('./auth/start', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify(requestBody) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('authUrl').href = data.auth_url; + document.getElementById('authUrl').textContent = data.auth_url; + document.getElementById('authUrlSection').classList.remove('hidden'); + + const msg = data.auto_project_detection + ? '认证链接已生成(将在认证完成后自动检测项目ID),请点击链接完成授权' + : `认证链接已生成(项目ID: ${data.detected_project_id}),请点击链接完成授权`; + showStatus(msg, 'info'); + AppState.authInProgress = true; + } else { + showStatus(`错误: ${data.error || '获取认证链接失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取认证链接'; + } +} + +async function getCredentials() { + if (!AppState.authInProgress) { + showStatus('请先获取认证链接并完成授权', 'error'); + return; + } + + const btn = document.getElementById('getCredsBtn'); + btn.disabled = true; + btn.textContent = '等待OAuth回调中...'; + + try { + showStatus('正在等待OAuth回调,这可能需要一些时间...', 'info'); + + const requestBody = AppState.currentProjectId ? { project_id: AppState.currentProjectId } : {}; + + const response = await fetch('./auth/callback', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify(requestBody) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('credentialsContent').textContent = JSON.stringify(data.credentials, null, 2); + + const msg = data.auto_detected_project + ? `✅ 认证成功!项目ID已自动检测为: ${data.credentials.project_id},文件已保存到: ${data.file_path}` + : `✅ 认证成功!文件已保存到: ${data.file_path}`; + showStatus(msg, 'success'); + + document.getElementById('credentialsSection').classList.remove('hidden'); + AppState.authInProgress = false; + } else if (data.requires_project_selection && data.available_projects) { + let projectOptions = "请选择一个项目:\n\n"; + data.available_projects.forEach((project, index) => { + projectOptions += `${index + 1}. ${project.name} (${project.project_id})\n`; + }); + projectOptions += `\n请输入序号 (1-${data.available_projects.length}):`; + + const selection = prompt(projectOptions); + const projectIndex = parseInt(selection) - 1; + + if (projectIndex >= 0 && projectIndex < data.available_projects.length) { + AppState.currentProjectId = data.available_projects[projectIndex].project_id; + btn.textContent = '重新尝试获取认证文件'; + showStatus(`使用选择的项目重新尝试...`, 'info'); + setTimeout(() => getCredentials(), 1000); + return; + } else { + showStatus('无效的选择,请重新开始认证', 'error'); + } + } else if (data.requires_manual_project_id) { + const userProjectId = prompt('无法自动检测项目ID,请手动输入您的Google Cloud项目ID:'); + if (userProjectId && userProjectId.trim()) { + AppState.currentProjectId = userProjectId.trim(); + btn.textContent = '重新尝试获取认证文件'; + showStatus('使用手动输入的项目ID重新尝试...', 'info'); + setTimeout(() => getCredentials(), 1000); + return; + } else { + showStatus('需要项目ID才能完成认证,请重新开始并输入正确的项目ID', 'error'); + } + } else { + showStatus(`❌ 错误: ${data.error || '获取认证文件失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取认证文件'; + } +} + +// ===================================================================== +// Antigravity 认证相关函数 +// ===================================================================== +async function startAntigravityAuth() { + const btn = document.getElementById('getAntigravityAuthBtn'); + btn.disabled = true; + btn.textContent = '生成认证链接中...'; + + try { + showStatus('正在生成 Antigravity 认证链接...', 'info'); + + const response = await fetch('./auth/start', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ mode: 'antigravity' }) + }); + + const data = await response.json(); + + if (response.ok) { + AppState.antigravityAuthState = data.state; + AppState.antigravityAuthInProgress = true; + + const authUrlLink = document.getElementById('antigravityAuthUrl'); + authUrlLink.href = data.auth_url; + authUrlLink.textContent = data.auth_url; + document.getElementById('antigravityAuthUrlSection').classList.remove('hidden'); + + showStatus('✅ Antigravity 认证链接已生成!请点击链接完成授权', 'success'); + } else { + showStatus(`❌ 错误: ${data.error || '生成认证链接失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取 Antigravity 认证链接'; + } +} + +async function getAntigravityCredentials() { + if (!AppState.antigravityAuthInProgress) { + showStatus('请先获取 Antigravity 认证链接并完成授权', 'error'); + return; + } + + const btn = document.getElementById('getAntigravityCredsBtn'); + btn.disabled = true; + btn.textContent = '等待OAuth回调中...'; + + try { + showStatus('正在等待 Antigravity OAuth回调...', 'info'); + + const response = await fetch('./auth/callback', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ mode: 'antigravity' }) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('antigravityCredsContent').textContent = JSON.stringify(data.credentials, null, 2); + document.getElementById('antigravityCredsSection').classList.remove('hidden'); + AppState.antigravityAuthInProgress = false; + showStatus(`✅ Antigravity 认证成功!文件已保存到: ${data.file_path}`, 'success'); + } else { + showStatus(`❌ 错误: ${data.error || '获取认证文件失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取 Antigravity 凭证'; + } +} + +function downloadAntigravityCredentials() { + const content = document.getElementById('antigravityCredsContent').textContent; + const blob = new Blob([content], { type: 'application/json' }); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `antigravity-credential-${Date.now()}.json`; + a.click(); + window.URL.revokeObjectURL(url); +} + +// ===================================================================== +// 回调URL处理 +// ===================================================================== +function toggleProjectIdSection() { + const section = document.getElementById('projectIdSection'); + const icon = document.getElementById('projectIdToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(90deg)'; + icon.textContent = '▼'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▶'; + } +} + +function toggleCallbackUrlSection() { + const section = document.getElementById('callbackUrlSection'); + const icon = document.getElementById('callbackUrlToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(180deg)'; + icon.textContent = '▲'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▼'; + } +} + +function toggleAntigravityCallbackUrlSection() { + const section = document.getElementById('antigravityCallbackUrlSection'); + const icon = document.getElementById('antigravityCallbackUrlToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(180deg)'; + icon.textContent = '▲'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▼'; + } +} + +async function processCallbackUrl() { + const callbackUrl = document.getElementById('callbackUrlInput').value.trim(); + + if (!callbackUrl) { + showStatus('请输入回调URL', 'error'); + return; + } + + if (!callbackUrl.startsWith('http://') && !callbackUrl.startsWith('https://')) { + showStatus('请输入有效的URL(以http://或https://开头)', 'error'); + return; + } + + if (!callbackUrl.includes('code=') || !callbackUrl.includes('state=')) { + showStatus('❌ 这不是有效的回调URL!请确保:\n1. 已完成Google OAuth授权\n2. 复制的是浏览器地址栏的完整URL\n3. URL包含code和state参数', 'error'); + return; + } + + showStatus('正在从回调URL获取凭证...', 'info'); + + try { + const projectId = document.getElementById('projectId')?.value.trim() || null; + + const response = await fetch('./auth/callback-url', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ callback_url: callbackUrl, project_id: projectId }) + }); + + const result = await response.json(); + + if (result.credentials) { + showStatus(result.message || '从回调URL获取凭证成功!', 'success'); + document.getElementById('credentialsContent').innerHTML = '
' + JSON.stringify(result.credentials, null, 2) + '
'; + document.getElementById('credentialsSection').classList.remove('hidden'); + } else if (result.requires_manual_project_id) { + showStatus('需要手动指定项目ID,请在高级选项中填入Google Cloud项目ID后重试', 'error'); + } else if (result.requires_project_selection) { + let msg = '
可用项目:
'; + result.available_projects.forEach(p => { + msg += `• ${p.name} (ID: ${p.project_id})
`; + }); + showStatus('检测到多个项目,请在高级选项中指定项目ID:' + msg, 'error'); + } else { + showStatus(result.error || '从回调URL获取凭证失败', 'error'); + } + + document.getElementById('callbackUrlInput').value = ''; + } catch (error) { + showStatus(`从回调URL获取凭证失败: ${error.message}`, 'error'); + } +} + +async function processAntigravityCallbackUrl() { + const callbackUrl = document.getElementById('antigravityCallbackUrlInput').value.trim(); + + if (!callbackUrl) { + showStatus('请输入回调URL', 'error'); + return; + } + + if (!callbackUrl.startsWith('http://') && !callbackUrl.startsWith('https://')) { + showStatus('请输入有效的URL(以http://或https://开头)', 'error'); + return; + } + + if (!callbackUrl.includes('code=') || !callbackUrl.includes('state=')) { + showStatus('❌ 这不是有效的回调URL!请确保包含code和state参数', 'error'); + return; + } + + showStatus('正在从回调URL获取 Antigravity 凭证...', 'info'); + + try { + const response = await fetch('./auth/callback-url', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ callback_url: callbackUrl, mode: 'antigravity' }) + }); + + const result = await response.json(); + + if (result.credentials) { + showStatus(result.message || '从回调URL获取 Antigravity 凭证成功!', 'success'); + document.getElementById('antigravityCredsContent').textContent = JSON.stringify(result.credentials, null, 2); + document.getElementById('antigravityCredsSection').classList.remove('hidden'); + } else { + showStatus(result.error || '从回调URL获取 Antigravity 凭证失败', 'error'); + } + + document.getElementById('antigravityCallbackUrlInput').value = ''; + } catch (error) { + showStatus(`从回调URL获取 Antigravity 凭证失败: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 全局兼容函数(供HTML调用) +// ===================================================================== +// 普通凭证管理 +function refreshCredsStatus() { AppState.creds.refresh(); } +function applyStatusFilter() { AppState.creds.applyStatusFilter(); } +function changePage(direction) { AppState.creds.changePage(direction); } +function changePageSize() { AppState.creds.changePageSize(); } +function toggleFileSelection(filename) { + if (AppState.creds.selectedFiles.has(filename)) { + AppState.creds.selectedFiles.delete(filename); + } else { + AppState.creds.selectedFiles.add(filename); + } + AppState.creds.updateBatchControls(); +} +function toggleSelectAll() { + const checkbox = document.getElementById('selectAllCheckbox'); + const checkboxes = document.querySelectorAll('.file-checkbox'); + + if (checkbox.checked) { + checkboxes.forEach(cb => AppState.creds.selectedFiles.add(cb.getAttribute('data-filename'))); + } else { + AppState.creds.selectedFiles.clear(); + } + checkboxes.forEach(cb => cb.checked = checkbox.checked); + AppState.creds.updateBatchControls(); +} +function batchAction(action) { AppState.creds.batchAction(action); } +function downloadCred(filename) { + fetch(`./creds/download/${filename}`, { headers: { 'Authorization': `Bearer ${AppState.authToken}` } }) + .then(r => r.ok ? r.blob() : Promise.reject()) + .then(blob => { + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + showStatus(`已下载文件: ${filename}`, 'success'); + }) + .catch(() => showStatus(`下载失败: ${filename}`, 'error')); +} +async function downloadAllCreds() { + try { + const response = await fetch('./creds/download-all', { + headers: { 'Authorization': `Bearer ${AppState.authToken}` } + }); + if (response.ok) { + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'credentials.zip'; + a.click(); + window.URL.revokeObjectURL(url); + showStatus('已下载所有凭证文件', 'success'); + } + } catch (error) { + showStatus(`打包下载失败: ${error.message}`, 'error'); + } +} + +// Antigravity凭证管理 +function refreshAntigravityCredsList() { AppState.antigravityCreds.refresh(); } +function applyAntigravityStatusFilter() { AppState.antigravityCreds.applyStatusFilter(); } +function changeAntigravityPage(direction) { AppState.antigravityCreds.changePage(direction); } +function changeAntigravityPageSize() { AppState.antigravityCreds.changePageSize(); } +function toggleAntigravityFileSelection(filename) { + if (AppState.antigravityCreds.selectedFiles.has(filename)) { + AppState.antigravityCreds.selectedFiles.delete(filename); + } else { + AppState.antigravityCreds.selectedFiles.add(filename); + } + AppState.antigravityCreds.updateBatchControls(); +} +function toggleSelectAllAntigravity() { + const checkbox = document.getElementById('selectAllAntigravityCheckbox'); + const checkboxes = document.querySelectorAll('.antigravityFile-checkbox'); + + if (checkbox.checked) { + checkboxes.forEach(cb => AppState.antigravityCreds.selectedFiles.add(cb.getAttribute('data-filename'))); + } else { + AppState.antigravityCreds.selectedFiles.clear(); + } + checkboxes.forEach(cb => cb.checked = checkbox.checked); + AppState.antigravityCreds.updateBatchControls(); +} +function batchAntigravityAction(action) { AppState.antigravityCreds.batchAction(action); } +function downloadAntigravityCred(filename) { + fetch(`./creds/download/${filename}?mode=antigravity`, { headers: getAuthHeaders() }) + .then(r => r.ok ? r.blob() : Promise.reject()) + .then(blob => { + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + showStatus(`✅ 已下载: ${filename}`, 'success'); + }) + .catch(() => showStatus(`下载失败: ${filename}`, 'error')); +} +function deleteAntigravityCred(filename) { + if (confirm(`确定要删除 ${filename} 吗?`)) { + AppState.antigravityCreds.action(filename, 'delete'); + } +} +async function downloadAllAntigravityCreds() { + try { + const response = await fetch('./creds/download-all?mode=antigravity', { headers: getAuthHeaders() }); + if (response.ok) { + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `antigravity_credentials_${Date.now()}.zip`; + a.click(); + window.URL.revokeObjectURL(url); + showStatus('✅ 所有Antigravity凭证已打包下载', 'success'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// 文件上传 +function handleFileSelect(event) { AppState.uploadFiles.handleFileSelect(event); } +function removeFile(index) { AppState.uploadFiles.removeFile(index); } +function clearFiles() { AppState.uploadFiles.clearFiles(); } +function uploadFiles() { AppState.uploadFiles.upload(); } + +function handleAntigravityFileSelect(event) { AppState.antigravityUploadFiles.handleFileSelect(event); } +function handleAntigravityFileDrop(event) { + event.preventDefault(); + event.currentTarget.style.borderColor = '#007bff'; + event.currentTarget.style.backgroundColor = '#f8f9fa'; + AppState.antigravityUploadFiles.addFiles(Array.from(event.dataTransfer.files)); +} +function removeAntigravityFile(index) { AppState.antigravityUploadFiles.removeFile(index); } +function clearAntigravityFiles() { AppState.antigravityUploadFiles.clearFiles(); } +function uploadAntigravityFiles() { AppState.antigravityUploadFiles.upload(); } + +// 邮箱相关 +// 辅助函数:根据文件名更新卡片中的邮箱显示 +function updateEmailDisplay(filename, email, managerType = 'normal') { + // 查找对应的凭证卡片 + const containerId = managerType === 'antigravity' ? 'antigravityCredsList' : 'credsList'; + const container = document.getElementById(containerId); + if (!container) return false; + + // 通过 data-filename 找到对应的复选框,再找到其父卡片 + const checkbox = container.querySelector(`input[data-filename="${filename}"]`); + if (!checkbox) return false; + + // 找到对应的 cred-card 元素 + const card = checkbox.closest('.cred-card'); + if (!card) return false; + + // 找到邮箱显示元素 + const emailDiv = card.querySelector('.cred-email'); + if (emailDiv) { + emailDiv.textContent = email; + emailDiv.style.color = '#666'; + emailDiv.style.fontStyle = 'normal'; + return true; + } + return false; +} + +async function fetchUserEmail(filename) { + try { + showStatus('正在获取用户邮箱...', 'info'); + const response = await fetch(`./creds/fetch-email/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok && data.user_email) { + showStatus(`成功获取邮箱: ${data.user_email}`, 'success'); + // 直接更新卡片中的邮箱显示,不刷新整个列表 + updateEmailDisplay(filename, data.user_email, 'normal'); + } else { + showStatus(data.message || '无法获取用户邮箱', 'error'); + } + } catch (error) { + showStatus(`获取邮箱失败: ${error.message}`, 'error'); + } +} + +async function fetchAntigravityUserEmail(filename) { + try { + showStatus('正在获取用户邮箱...', 'info'); + const response = await fetch(`./creds/fetch-email/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok && data.user_email) { + showStatus(`成功获取邮箱: ${data.user_email}`, 'success'); + // 直接更新卡片中的邮箱显示,不刷新整个列表 + updateEmailDisplay(filename, data.user_email, 'antigravity'); + } else { + showStatus(data.message || '无法获取用户邮箱', 'error'); + } + } catch (error) { + showStatus(`获取邮箱失败: ${error.message}`, 'error'); + } +} + +async function verifyProjectId(filename) { + try { + // 显示加载状态 + showStatus('🔍 正在检验Project ID,请稍候...', 'info'); + + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时显示绿色成功消息和Project ID + const successMsg = `✅ 检验成功!\n文件: ${filename}\nProject ID: ${data.project_id}\n\n${data.message}`; + showStatus(successMsg.replace(/\n/g, '
'), 'success'); + + // 弹出成功提示 + alert(`✅ 检验成功!\n\n文件: ${filename}\nProject ID: ${data.project_id}\n\n${data.message}`); + + await AppState.creds.refresh(); + } else { + // 失败时显示红色错误消息 + const errorMsg = data.message || '检验失败'; + showStatus(`❌ ${errorMsg}`, 'error'); + alert(`❌ 检验失败\n\n${errorMsg}`); + } + } catch (error) { + const errorMsg = `检验失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + alert(`❌ ${errorMsg}`); + } +} + +async function verifyAntigravityProjectId(filename) { + try { + // 显示加载状态 + showStatus('🔍 正在检验Antigravity Project ID,请稍候...', 'info'); + + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时显示绿色成功消息和Project ID + const successMsg = `✅ 检验成功!\n文件: ${filename}\nProject ID: ${data.project_id}\n\n${data.message}`; + showStatus(successMsg.replace(/\n/g, '
'), 'success'); + + // 弹出成功提示 + alert(`✅ Antigravity检验成功!\n\n文件: ${filename}\nProject ID: ${data.project_id}\n\n${data.message}`); + + await AppState.antigravityCreds.refresh(); + } else { + // 失败时显示红色错误消息 + const errorMsg = data.message || '检验失败'; + showStatus(`❌ ${errorMsg}`, 'error'); + alert(`❌ 检验失败\n\n${errorMsg}`); + } + } catch (error) { + const errorMsg = `检验失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + alert(`❌ ${errorMsg}`); + } +} + +async function toggleAntigravityQuotaDetails(pathId) { + const quotaDetails = document.getElementById('quota-' + pathId); + if (!quotaDetails) return; + + // 切换显示状态 + const isShowing = quotaDetails.style.display === 'block'; + + if (isShowing) { + // 收起 + quotaDetails.style.display = 'none'; + } else { + // 展开 + quotaDetails.style.display = 'block'; + + const contentDiv = quotaDetails.querySelector('.cred-quota-content'); + const filename = contentDiv.getAttribute('data-filename'); + const loaded = contentDiv.getAttribute('data-loaded'); + + // 如果还没加载过,则加载数据 + if (loaded === 'false' && filename) { + contentDiv.innerHTML = '
📊 正在加载额度信息...
'; + + try { + const response = await fetch(`./creds/quota/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'GET', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时渲染美化的额度信息 + const models = data.models || {}; + + if (Object.keys(models).length === 0) { + contentDiv.innerHTML = ` +
+
📊
+
暂无额度信息
+
+ `; + } else { + let quotaHTML = ` +
+

+ 📊 + 额度信息详情 +

+
文件: ${filename}
+
+
+ `; + + for (const [modelName, quotaData] of Object.entries(models)) { + // 后端返回的是剩余比例 (0-1),不是绝对数量 + const remainingFraction = quotaData.remaining || 0; + const resetTime = quotaData.resetTime || 'N/A'; + + // 计算已使用百分比(1 - 剩余比例) + const usedPercentage = Math.round((1 - remainingFraction) * 100); + const remainingPercentage = Math.round(remainingFraction * 100); + + // 根据使用情况选择颜色 + let percentageColor = '#28a745'; // 绿色:使用少 + if (usedPercentage >= 90) percentageColor = '#dc3545'; // 红色:使用多 + else if (usedPercentage >= 70) percentageColor = '#ffc107'; // 黄色:使用较多 + else if (usedPercentage >= 50) percentageColor = '#17a2b8'; // 蓝色:使用中等 + + quotaHTML += ` +
+
+
+ ${modelName} +
+
+ ${remainingPercentage}% +
+
+
+
+
+
+ ${resetTime !== 'N/A' ? '🔄 ' + resetTime : ''} +
+
+ `; + } + + quotaHTML += '
'; + contentDiv.innerHTML = quotaHTML; + } + + contentDiv.setAttribute('data-loaded', 'true'); + showStatus('✅ 成功加载额度信息', 'success'); + } else { + // 失败时显示错误 + const errorMsg = data.error || '获取额度信息失败'; + contentDiv.innerHTML = ` +
+
+
获取额度信息失败
+
${errorMsg}
+
+ `; + showStatus(`❌ ${errorMsg}`, 'error'); + } + } catch (error) { + contentDiv.innerHTML = ` +
+
+
网络错误
+
${error.message}
+
+ `; + showStatus(`❌ 获取额度信息失败: ${error.message}`, 'error'); + } + } + } +} + +async function batchVerifyProjectIds() { + const selectedFiles = Array.from(AppState.creds.selectedFiles); + if (selectedFiles.length === 0) { + showStatus('❌ 请先选择要检验的凭证', 'error'); + alert('请先选择要检验的凭证'); + return; + } + + if (!confirm(`确定要批量检验 ${selectedFiles.length} 个凭证的Project ID吗?\n\n将并行检验以加快速度。`)) { + return; + } + + showStatus(`🔍 正在并行检验 ${selectedFiles.length} 个凭证,请稍候...`, 'info'); + + // 并行执行所有检验请求 + const promises = selectedFiles.map(async (filename) => { + try { + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + return { success: true, filename, projectId: data.project_id, message: data.message }; + } else { + return { success: false, filename, error: data.message || '失败' }; + } + } catch (error) { + return { success: false, filename, error: error.message }; + } + }); + + // 等待所有请求完成 + const results = await Promise.all(promises); + + // 统计结果 + let successCount = 0; + let failCount = 0; + const resultMessages = []; + + results.forEach(result => { + if (result.success) { + successCount++; + resultMessages.push(`✅ ${result.filename}: ${result.projectId}`); + } else { + failCount++; + resultMessages.push(`❌ ${result.filename}: ${result.error}`); + } + }); + + await AppState.creds.refresh(); + + const summary = `批量检验完成!\n\n成功: ${successCount} 个\n失败: ${failCount} 个\n总计: ${selectedFiles.length} 个\n\n详细结果:\n${resultMessages.join('\n')}`; + + if (failCount === 0) { + showStatus(`✅ 全部检验成功!成功检验 ${successCount}/${selectedFiles.length} 个凭证`, 'success'); + } else if (successCount === 0) { + showStatus(`❌ 全部检验失败!失败 ${failCount}/${selectedFiles.length} 个凭证`, 'error'); + } else { + showStatus(`⚠️ 批量检验完成:成功 ${successCount}/${selectedFiles.length} 个,失败 ${failCount} 个`, 'info'); + } + + console.log(summary); + alert(summary); +} + +async function batchVerifyAntigravityProjectIds() { + const selectedFiles = Array.from(AppState.antigravityCreds.selectedFiles); + if (selectedFiles.length === 0) { + showStatus('❌ 请先选择要检验的Antigravity凭证', 'error'); + alert('请先选择要检验的Antigravity凭证'); + return; + } + + if (!confirm(`确定要批量检验 ${selectedFiles.length} 个Antigravity凭证的Project ID吗?\n\n将并行检验以加快速度。`)) { + return; + } + + showStatus(`🔍 正在并行检验 ${selectedFiles.length} 个Antigravity凭证,请稍候...`, 'info'); + + // 并行执行所有检验请求 + const promises = selectedFiles.map(async (filename) => { + try { + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + return { success: true, filename, projectId: data.project_id, message: data.message }; + } else { + return { success: false, filename, error: data.message || '失败' }; + } + } catch (error) { + return { success: false, filename, error: error.message }; + } + }); + + // 等待所有请求完成 + const results = await Promise.all(promises); + + // 统计结果 + let successCount = 0; + let failCount = 0; + const resultMessages = []; + + results.forEach(result => { + if (result.success) { + successCount++; + resultMessages.push(`✅ ${result.filename}: ${result.projectId}`); + } else { + failCount++; + resultMessages.push(`❌ ${result.filename}: ${result.error}`); + } + }); + + await AppState.antigravityCreds.refresh(); + + const summary = `Antigravity批量检验完成!\n\n成功: ${successCount} 个\n失败: ${failCount} 个\n总计: ${selectedFiles.length} 个\n\n详细结果:\n${resultMessages.join('\n')}`; + + if (failCount === 0) { + showStatus(`✅ 全部检验成功!成功检验 ${successCount}/${selectedFiles.length} 个Antigravity凭证`, 'success'); + } else if (successCount === 0) { + showStatus(`❌ 全部检验失败!失败 ${failCount}/${selectedFiles.length} 个Antigravity凭证`, 'error'); + } else { + showStatus(`⚠️ 批量检验完成:成功 ${successCount}/${selectedFiles.length} 个,失败 ${failCount} 个`, 'info'); + } + + console.log(summary); + alert(summary); +} + + +async function refreshAllEmails() { + if (!confirm('确定要刷新所有凭证的用户邮箱吗?这可能需要一些时间。')) return; + + try { + showStatus('正在刷新所有用户邮箱...', 'info'); + const response = await fetch('./creds/refresh-all-emails', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + showStatus(`邮箱刷新完成:成功获取 ${data.success_count}/${data.total_count} 个邮箱地址`, 'success'); + await AppState.creds.refresh(); + } else { + showStatus(data.message || '邮箱刷新失败', 'error'); + } + } catch (error) { + showStatus(`邮箱刷新网络错误: ${error.message}`, 'error'); + } +} + +async function refreshAllAntigravityEmails() { + if (!confirm('确定要刷新所有Antigravity凭证的用户邮箱吗?这可能需要一些时间。')) return; + + try { + showStatus('正在刷新所有用户邮箱...', 'info'); + const response = await fetch('./creds/refresh-all-emails?mode=antigravity', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + showStatus(`邮箱刷新完成:成功获取 ${data.success_count}/${data.total_count} 个邮箱地址`, 'success'); + await AppState.antigravityCreds.refresh(); + } else { + showStatus(data.message || '邮箱刷新失败', 'error'); + } + } catch (error) { + showStatus(`邮箱刷新网络错误: ${error.message}`, 'error'); + } +} + +async function deduplicateByEmail() { + if (!confirm('确定要对凭证进行凭证一键去重吗?\n\n相同邮箱的凭证只保留一个,其他将被删除。\n此操作不可撤销!')) return; + + try { + showStatus('正在进行凭证一键去重...', 'info'); + const response = await fetch('./creds/deduplicate-by-email', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + const msg = `去重完成:删除 ${data.deleted_count} 个重复凭证,保留 ${data.kept_count} 个凭证(${data.unique_emails_count} 个唯一邮箱)`; + showStatus(msg, 'success'); + await AppState.creds.refresh(); + + // 显示详细信息 + if (data.duplicate_groups && data.duplicate_groups.length > 0) { + let details = '去重详情:\n\n'; + data.duplicate_groups.forEach(group => { + details += `邮箱: ${group.email}\n保留: ${group.kept_file}\n删除: ${group.deleted_files.join(', ')}\n\n`; + }); + console.log(details); + } + } else { + showStatus(data.message || '去重失败', 'error'); + } + } catch (error) { + showStatus(`去重网络错误: ${error.message}`, 'error'); + } +} + +async function deduplicateAntigravityByEmail() { + if (!confirm('确定要对Antigravity凭证进行凭证一键去重吗?\n\n相同邮箱的凭证只保留一个,其他将被删除。\n此操作不可撤销!')) return; + + try { + showStatus('正在进行凭证一键去重...', 'info'); + const response = await fetch('./creds/deduplicate-by-email?mode=antigravity', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + const msg = `去重完成:删除 ${data.deleted_count} 个重复凭证,保留 ${data.kept_count} 个凭证(${data.unique_emails_count} 个唯一邮箱)`; + showStatus(msg, 'success'); + await AppState.antigravityCreds.refresh(); + + // 显示详细信息 + if (data.duplicate_groups && data.duplicate_groups.length > 0) { + let details = '去重详情:\n\n'; + data.duplicate_groups.forEach(group => { + details += `邮箱: ${group.email}\n保留: ${group.kept_file}\n删除: ${group.deleted_files.join(', ')}\n\n`; + }); + console.log(details); + } + } else { + showStatus(data.message || '去重失败', 'error'); + } + } catch (error) { + showStatus(`去重网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// WebSocket日志相关 +// ===================================================================== +function connectWebSocket() { + if (AppState.logWebSocket && AppState.logWebSocket.readyState === WebSocket.OPEN) { + showStatus('WebSocket已经连接', 'info'); + return; + } + + try { + const wsPath = new URL('./auth/logs/stream', window.location.href).href; + const wsUrl = wsPath.replace(/^http/, 'ws'); + + document.getElementById('connectionStatusText').textContent = '连接中...'; + document.getElementById('logConnectionStatus').className = 'status info'; + + AppState.logWebSocket = new WebSocket(wsUrl); + + AppState.logWebSocket.onopen = () => { + document.getElementById('connectionStatusText').textContent = '已连接'; + document.getElementById('logConnectionStatus').className = 'status success'; + showStatus('日志流连接成功', 'success'); + clearLogsDisplay(); + }; + + AppState.logWebSocket.onmessage = (event) => { + const logLine = event.data; + if (logLine.trim()) { + AppState.allLogs.push(logLine); + if (AppState.allLogs.length > 1000) { + AppState.allLogs = AppState.allLogs.slice(-1000); + } + filterLogs(); + if (document.getElementById('autoScroll').checked) { + const logContainer = document.getElementById('logContainer'); + logContainer.scrollTop = logContainer.scrollHeight; + } + } + }; + + AppState.logWebSocket.onclose = () => { + document.getElementById('connectionStatusText').textContent = '连接断开'; + document.getElementById('logConnectionStatus').className = 'status error'; + showStatus('日志流连接断开', 'info'); + }; + + AppState.logWebSocket.onerror = (error) => { + document.getElementById('connectionStatusText').textContent = '连接错误'; + document.getElementById('logConnectionStatus').className = 'status error'; + showStatus('日志流连接错误: ' + error, 'error'); + }; + } catch (error) { + showStatus('创建WebSocket连接失败: ' + error.message, 'error'); + document.getElementById('connectionStatusText').textContent = '连接失败'; + document.getElementById('logConnectionStatus').className = 'status error'; + } +} + +function disconnectWebSocket() { + if (AppState.logWebSocket) { + AppState.logWebSocket.close(); + AppState.logWebSocket = null; + document.getElementById('connectionStatusText').textContent = '未连接'; + document.getElementById('logConnectionStatus').className = 'status info'; + showStatus('日志流连接已断开', 'info'); + } +} + +function clearLogsDisplay() { + AppState.allLogs = []; + AppState.filteredLogs = []; + document.getElementById('logContent').textContent = '日志已清空,等待新日志...'; +} + +async function downloadLogs() { + try { + const response = await fetch('./auth/logs/download', { headers: getAuthHeaders() }); + + if (response.ok) { + const contentDisposition = response.headers.get('Content-Disposition'); + let filename = 'gcli2api_logs.txt'; + if (contentDisposition) { + const match = contentDisposition.match(/filename=(.+)/); + if (match) filename = match[1]; + } + + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + + showStatus(`日志文件下载成功: ${filename}`, 'success'); + } else { + const data = await response.json(); + showStatus(`下载日志失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`下载日志时网络错误: ${error.message}`, 'error'); + } +} + +async function clearLogs() { + try { + const response = await fetch('./auth/logs/clear', { + method: 'POST', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + clearLogsDisplay(); + showStatus(data.message, 'success'); + } else { + showStatus(`清空日志失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + clearLogsDisplay(); + showStatus(`清空日志时网络错误: ${error.message}`, 'error'); + } +} + +function filterLogs() { + const filter = document.getElementById('logLevelFilter').value; + AppState.currentLogFilter = filter; + + if (filter === 'all') { + AppState.filteredLogs = [...AppState.allLogs]; + } else { + AppState.filteredLogs = AppState.allLogs.filter(log => log.toUpperCase().includes(filter)); + } + + displayLogs(); +} + +function displayLogs() { + const logContent = document.getElementById('logContent'); + if (AppState.filteredLogs.length === 0) { + logContent.textContent = AppState.currentLogFilter === 'all' ? + '暂无日志...' : `暂无${AppState.currentLogFilter}级别的日志...`; + } else { + logContent.textContent = AppState.filteredLogs.join('\n'); + } +} + +// ===================================================================== +// 环境变量凭证管理 +// ===================================================================== +async function checkEnvCredsStatus() { + const loading = document.getElementById('envStatusLoading'); + const content = document.getElementById('envStatusContent'); + + try { + loading.style.display = 'block'; + content.classList.add('hidden'); + + const response = await fetch('./auth/env-creds-status', { headers: getAuthHeaders() }); + const data = await response.json(); + + if (response.ok) { + const envVarsList = document.getElementById('envVarsList'); + envVarsList.textContent = Object.keys(data.available_env_vars).length > 0 + ? Object.keys(data.available_env_vars).join(', ') + : '未找到GCLI_CREDS_*环境变量'; + + const autoLoadStatus = document.getElementById('autoLoadStatus'); + autoLoadStatus.textContent = data.auto_load_enabled ? '✅ 已启用' : '❌ 未启用'; + autoLoadStatus.style.color = data.auto_load_enabled ? '#28a745' : '#dc3545'; + + document.getElementById('envFilesCount').textContent = `${data.existing_env_files_count} 个文件`; + + const envFilesList = document.getElementById('envFilesList'); + envFilesList.textContent = data.existing_env_files.length > 0 + ? data.existing_env_files.join(', ') + : '无'; + + content.classList.remove('hidden'); + showStatus('环境变量状态检查完成', 'success'); + } else { + showStatus(`获取环境变量状态失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +async function loadEnvCredentials() { + try { + showStatus('正在从环境变量导入凭证...', 'info'); + + const response = await fetch('./auth/load-env-creds', { + method: 'POST', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + if (data.loaded_count > 0) { + showStatus(`✅ 成功导入 ${data.loaded_count}/${data.total_count} 个凭证文件`, 'success'); + setTimeout(() => checkEnvCredsStatus(), 1000); + } else { + showStatus(`⚠️ ${data.message}`, 'info'); + } + } else { + showStatus(`导入失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function clearEnvCredentials() { + if (!confirm('确定要清除所有从环境变量导入的凭证文件吗?\n这将删除所有文件名以 "env-" 开头的认证文件。')) { + return; + } + + try { + showStatus('正在清除环境变量凭证文件...', 'info'); + + const response = await fetch('./auth/env-creds', { + method: 'DELETE', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + showStatus(`✅ 成功删除 ${data.deleted_count} 个环境变量凭证文件`, 'success'); + setTimeout(() => checkEnvCredsStatus(), 1000); + } else { + showStatus(`清除失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 配置管理 +// ===================================================================== +async function loadConfig() { + const loading = document.getElementById('configLoading'); + const form = document.getElementById('configForm'); + + try { + loading.style.display = 'block'; + form.classList.add('hidden'); + + const response = await fetch('./config/get', { headers: getAuthHeaders() }); + const data = await response.json(); + + if (response.ok) { + AppState.currentConfig = data.config; + AppState.envLockedFields = new Set(data.env_locked || []); + + populateConfigForm(); + form.classList.remove('hidden'); + showStatus('配置加载成功', 'success'); + } else { + showStatus(`加载配置失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +function populateConfigForm() { + const c = AppState.currentConfig; + + setConfigField('host', c.host || '0.0.0.0'); + setConfigField('port', c.port || 7861); + setConfigField('configApiPassword', c.api_password || ''); + setConfigField('configPanelPassword', c.panel_password || ''); + setConfigField('configPassword', c.password || 'pwd'); + setConfigField('credentialsDir', c.credentials_dir || ''); + setConfigField('proxy', c.proxy || ''); + setConfigField('codeAssistEndpoint', c.code_assist_endpoint || ''); + setConfigField('oauthProxyUrl', c.oauth_proxy_url || ''); + setConfigField('googleapisProxyUrl', c.googleapis_proxy_url || ''); + setConfigField('resourceManagerApiUrl', c.resource_manager_api_url || ''); + setConfigField('serviceUsageApiUrl', c.service_usage_api_url || ''); + setConfigField('antigravityApiUrl', c.antigravity_api_url || ''); + + document.getElementById('autoBanEnabled').checked = Boolean(c.auto_ban_enabled); + setConfigField('autoBanErrorCodes', (c.auto_ban_error_codes || []).join(',')); + setConfigField('callsPerRotation', c.calls_per_rotation || 10); + + document.getElementById('retry429Enabled').checked = Boolean(c.retry_429_enabled); + setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20); + setConfigField('retry429Interval', c.retry_429_interval || 0.1); + + document.getElementById('compatibilityModeEnabled').checked = Boolean(c.compatibility_mode_enabled); + document.getElementById('returnThoughtsToFrontend').checked = Boolean(c.return_thoughts_to_frontend !== false); + document.getElementById('antigravityStream2nostream').checked = Boolean(c.antigravity_stream2nostream !== false); + + setConfigField('antiTruncationMaxAttempts', c.anti_truncation_max_attempts || 3); +} + +function setConfigField(fieldId, value) { + const field = document.getElementById(fieldId); + if (field) { + field.value = value; + const configKey = fieldId.replace(/([A-Z])/g, '_$1').toLowerCase(); + if (AppState.envLockedFields.has(configKey)) { + field.disabled = true; + field.classList.add('env-locked'); + } else { + field.disabled = false; + field.classList.remove('env-locked'); + } + } +} + +async function saveConfig() { + try { + const getValue = (id, def = '') => document.getElementById(id)?.value.trim() || def; + const getInt = (id, def = 0) => parseInt(document.getElementById(id)?.value) || def; + const getFloat = (id, def = 0.0) => parseFloat(document.getElementById(id)?.value) || def; + const getChecked = (id, def = false) => document.getElementById(id)?.checked || def; + + const config = { + host: getValue('host', '0.0.0.0'), + port: getInt('port', 7861), + api_password: getValue('configApiPassword'), + panel_password: getValue('configPanelPassword'), + password: getValue('configPassword', 'pwd'), + code_assist_endpoint: getValue('codeAssistEndpoint'), + credentials_dir: getValue('credentialsDir'), + proxy: getValue('proxy'), + oauth_proxy_url: getValue('oauthProxyUrl'), + googleapis_proxy_url: getValue('googleapisProxyUrl'), + resource_manager_api_url: getValue('resourceManagerApiUrl'), + service_usage_api_url: getValue('serviceUsageApiUrl'), + antigravity_api_url: getValue('antigravityApiUrl'), + auto_ban_enabled: getChecked('autoBanEnabled'), + auto_ban_error_codes: getValue('autoBanErrorCodes').split(',') + .map(c => parseInt(c.trim())).filter(c => !isNaN(c)), + calls_per_rotation: getInt('callsPerRotation', 10), + retry_429_enabled: getChecked('retry429Enabled'), + retry_429_max_retries: getInt('retry429MaxRetries', 20), + retry_429_interval: getFloat('retry429Interval', 0.1), + compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), + return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), + antigravity_stream2nostream: getChecked('antigravityStream2nostream'), + anti_truncation_max_attempts: getInt('antiTruncationMaxAttempts', 3) + }; + + const response = await fetch('./config/save', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ config }) + }); + + const data = await response.json(); + + if (response.ok) { + let message = '配置保存成功'; + + if (data.hot_updated && data.hot_updated.length > 0) { + message += `,以下配置已立即生效: ${data.hot_updated.join(', ')}`; + } + + if (data.restart_required && data.restart_required.length > 0) { + message += `\n⚠️ 重启提醒: ${data.restart_notice}`; + showStatus(message, 'info'); + } else { + showStatus(message, 'success'); + } + + setTimeout(() => loadConfig(), 1000); + } else { + showStatus(`保存配置失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// 镜像网址配置 +const mirrorUrls = { + codeAssistEndpoint: 'https://gcli-api.sukaka.top/cloudcode-pa', + oauthProxyUrl: 'https://gcli-api.sukaka.top/oauth2', + googleapisProxyUrl: 'https://gcli-api.sukaka.top/googleapis', + resourceManagerApiUrl: 'https://gcli-api.sukaka.top/cloudresourcemanager', + serviceUsageApiUrl: 'https://gcli-api.sukaka.top/serviceusage', + antigravityApiUrl: 'https://gcli-api.sukaka.top/daily-cloudcode-pa' +}; + +const officialUrls = { + codeAssistEndpoint: 'https://cloudcode-pa.googleapis.com', + oauthProxyUrl: 'https://oauth2.googleapis.com', + googleapisProxyUrl: 'https://www.googleapis.com', + resourceManagerApiUrl: 'https://cloudresourcemanager.googleapis.com', + serviceUsageApiUrl: 'https://serviceusage.googleapis.com', + antigravityApiUrl: 'https://daily-cloudcode-pa.sandbox.googleapis.com' +}; + +function useMirrorUrls() { + if (confirm('确定要将所有端点配置为镜像网址吗?')) { + for (const [fieldId, url] of Object.entries(mirrorUrls)) { + const field = document.getElementById(fieldId); + if (field && !field.disabled) field.value = url; + } + showStatus('✅ 已切换到镜像网址配置,记得点击"保存配置"按钮保存设置', 'success'); + } +} + +function restoreOfficialUrls() { + if (confirm('确定要将所有端点配置为官方地址吗?')) { + for (const [fieldId, url] of Object.entries(officialUrls)) { + const field = document.getElementById(fieldId); + if (field && !field.disabled) field.value = url; + } + showStatus('✅ 已切换到官方端点配置,记得点击"保存配置"按钮保存设置', 'success'); + } +} + +// ===================================================================== +// 使用统计 +// ===================================================================== +async function refreshUsageStats() { + const loading = document.getElementById('usageLoading'); + const list = document.getElementById('usageList'); + + try { + loading.style.display = 'block'; + list.innerHTML = ''; + + const [statsResponse, aggregatedResponse] = await Promise.all([ + fetch('./usage/stats', { headers: getAuthHeaders() }), + fetch('./usage/aggregated', { headers: getAuthHeaders() }) + ]); + + if (statsResponse.status === 401 || aggregatedResponse.status === 401) { + showStatus('认证失败,请重新登录', 'error'); + setTimeout(() => location.reload(), 1500); + return; + } + + const statsData = await statsResponse.json(); + const aggregatedData = await aggregatedResponse.json(); + + if (statsResponse.ok && aggregatedResponse.ok) { + AppState.usageStatsData = statsData.success ? statsData.data : statsData; + + const aggData = aggregatedData.success ? aggregatedData.data : aggregatedData; + document.getElementById('totalApiCalls').textContent = aggData.total_calls_24h || 0; + document.getElementById('totalFiles').textContent = aggData.total_files || 0; + document.getElementById('avgCallsPerFile').textContent = (aggData.avg_calls_per_file || 0).toFixed(1); + + renderUsageList(); + + showStatus(`已加载 ${aggData.total_files || Object.keys(AppState.usageStatsData).length} 个文件的使用统计`, 'success'); + } else { + const errorMsg = statsData.detail || aggregatedData.detail || '加载使用统计失败'; + showStatus(`错误: ${errorMsg}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +function renderUsageList() { + const list = document.getElementById('usageList'); + list.innerHTML = ''; + + if (Object.keys(AppState.usageStatsData).length === 0) { + list.innerHTML = '

暂无使用统计数据

'; + return; + } + + for (const [filename, stats] of Object.entries(AppState.usageStatsData)) { + const card = document.createElement('div'); + card.className = 'usage-card'; + + const calls24h = stats.calls_24h || 0; + + card.innerHTML = ` +
+
${filename}
+
+
+
+ 24小时内调用次数 + ${calls24h} +
+
+
+ +
+ `; + + list.appendChild(card); + } +} + +async function resetSingleUsageStats(filename) { + if (!confirm(`确定要重置 ${filename} 的使用统计吗?`)) return; + + try { + const response = await fetch('./usage/reset', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ filename }) + }); + + const data = await response.json(); + + if (response.ok && data.success) { + showStatus(data.message, 'success'); + await refreshUsageStats(); + } else { + showStatus(`重置失败: ${data.message || data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function resetAllUsageStats() { + if (!confirm('确定要重置所有文件的使用统计吗?此操作不可恢复!')) return; + + try { + const response = await fetch('./usage/reset', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({}) + }); + + const data = await response.json(); + + if (response.ok && data.success) { + showStatus(data.message, 'success'); + await refreshUsageStats(); + } else { + showStatus(`重置失败: ${data.message || data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 冷却倒计时自动更新 +// ===================================================================== +function startCooldownTimer() { + if (AppState.cooldownTimerInterval) { + clearInterval(AppState.cooldownTimerInterval); + } + + AppState.cooldownTimerInterval = setInterval(() => { + updateCooldownDisplays(); + }, 1000); +} + +function stopCooldownTimer() { + if (AppState.cooldownTimerInterval) { + clearInterval(AppState.cooldownTimerInterval); + AppState.cooldownTimerInterval = null; + } +} + +function updateCooldownDisplays() { + let needsRefresh = false; + + // 检查模型级冷却是否过期 + for (const credInfo of Object.values(AppState.creds.data)) { + if (credInfo.model_cooldowns && Object.keys(credInfo.model_cooldowns).length > 0) { + const currentTime = Date.now() / 1000; + const hasExpiredCooldowns = Object.entries(credInfo.model_cooldowns).some(([, until]) => until <= currentTime); + + if (hasExpiredCooldowns) { + needsRefresh = true; + break; + } + } + } + + if (needsRefresh) { + AppState.creds.renderList(); + return; + } + + // 更新模型级冷却的显示 + document.querySelectorAll('.cooldown-badge').forEach(badge => { + const card = badge.closest('.cred-card'); + const filenameEl = card?.querySelector('.cred-filename'); + if (!filenameEl) return; + + const filename = filenameEl.textContent; + const credInfo = Object.values(AppState.creds.data).find(c => c.filename === filename); + + if (credInfo && credInfo.model_cooldowns) { + const currentTime = Date.now() / 1000; + const titleMatch = badge.getAttribute('title')?.match(/模型: (.+)/); + if (titleMatch) { + const model = titleMatch[1]; + const cooldownUntil = credInfo.model_cooldowns[model]; + if (cooldownUntil) { + const remaining = Math.max(0, Math.floor(cooldownUntil - currentTime)); + if (remaining > 0) { + const shortModel = model.replace('gemini-', '').replace('-exp', '') + .replace('2.0-', '2-').replace('1.5-', '1.5-'); + const timeDisplay = formatCooldownTime(remaining).replace(/s$/, '').replace(/ /g, ''); + badge.innerHTML = `🔧 ${shortModel}: ${timeDisplay}`; + } + } + } + } + }); +} + +// ===================================================================== +// 版本信息管理 +// ===================================================================== + +// 获取并显示版本信息(不检查更新) +async function fetchAndDisplayVersion() { + try { + const response = await fetch('./version/info'); + const data = await response.json(); + + const versionText = document.getElementById('versionText'); + + if (data.success) { + // 只显示版本号 + versionText.textContent = `v${data.version}`; + versionText.title = `完整版本: ${data.full_hash}\n提交信息: ${data.message}\n提交时间: ${data.date}`; + versionText.style.cursor = 'help'; + } else { + versionText.textContent = '未知版本'; + versionText.title = data.error || '无法获取版本信息'; + } + } catch (error) { + console.error('获取版本信息失败:', error); + const versionText = document.getElementById('versionText'); + if (versionText) { + versionText.textContent = '版本信息获取失败'; + } + } +} + +// 检查更新 +async function checkForUpdates() { + const checkBtn = document.getElementById('checkUpdateBtn'); + if (!checkBtn) return; + + const originalText = checkBtn.textContent; + + try { + // 显示检查中状态 + checkBtn.textContent = '检查中...'; + checkBtn.disabled = true; + + // 调用API检查更新 + const response = await fetch('./version/info?check_update=true'); + const data = await response.json(); + + if (data.success) { + if (data.check_update === false) { + // 检查更新失败 + showStatus(`检查更新失败: ${data.update_error || '未知错误'}`, 'error'); + } else if (data.has_update === true) { + // 有更新 + const updateMsg = `发现新版本!\n当前: v${data.version}\n最新: v${data.latest_version}\n\n更新内容: ${data.latest_message || '无'}`; + showStatus(updateMsg.replace(/\n/g, ' '), 'warning'); + + // 更新按钮样式 + checkBtn.style.backgroundColor = '#ffc107'; + checkBtn.textContent = '有新版本'; + + setTimeout(() => { + checkBtn.style.backgroundColor = '#17a2b8'; + checkBtn.textContent = originalText; + }, 5000); + } else if (data.has_update === false) { + // 已是最新 + showStatus('已是最新版本!', 'success'); + + checkBtn.style.backgroundColor = '#28a745'; + checkBtn.textContent = '已是最新'; + + setTimeout(() => { + checkBtn.style.backgroundColor = '#17a2b8'; + checkBtn.textContent = originalText; + }, 3000); + } else { + // 无法确定 + showStatus('无法确定是否有更新', 'info'); + } + } else { + showStatus(`检查更新失败: ${data.error}`, 'error'); + } + } catch (error) { + console.error('检查更新失败:', error); + showStatus(`检查更新失败: ${error.message}`, 'error'); + } finally { + checkBtn.disabled = false; + if (checkBtn.textContent === '检查中...') { + checkBtn.textContent = originalText; + } + } +} + +// ===================================================================== +// 页面初始化 +// ===================================================================== +window.onload = async function () { + const autoLoginSuccess = await autoLogin(); + + if (!autoLoginSuccess) { + showStatus('请输入密码登录', 'info'); + } else { + // 登录成功后获取版本信息 + await fetchAndDisplayVersion(); + } + + startCooldownTimer(); + + const antigravityAuthBtn = document.getElementById('getAntigravityAuthBtn'); + if (antigravityAuthBtn) { + antigravityAuthBtn.addEventListener('click', startAntigravityAuth); + } +}; + +// 拖拽功能 - 初始化 +document.addEventListener('DOMContentLoaded', function () { + const uploadArea = document.getElementById('uploadArea'); + + if (uploadArea) { + uploadArea.addEventListener('dragover', (event) => { + event.preventDefault(); + uploadArea.classList.add('dragover'); + }); + + uploadArea.addEventListener('dragleave', (event) => { + event.preventDefault(); + uploadArea.classList.remove('dragover'); + }); + + uploadArea.addEventListener('drop', (event) => { + event.preventDefault(); + uploadArea.classList.remove('dragover'); + AppState.uploadFiles.addFiles(Array.from(event.dataTransfer.files)); + }); + } +}); diff --git a/front/control_panel.html b/front/control_panel.html new file mode 100644 index 0000000000000000000000000000000000000000..8b0781939521f35ee1db3202f26ad33b4e77362e --- /dev/null +++ b/front/control_panel.html @@ -0,0 +1,2092 @@ + + + + + + + GCLI2API 控制面板 + + + + +
+ + +
+

GCLI2API 管理面板

+

请输入访问密码:

+ +
+ +
+ + + +
+ + + + + + + \ No newline at end of file diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html new file mode 100644 index 0000000000000000000000000000000000000000..d83b320095a88f08bb1734104144d76fd4cb4e3f --- /dev/null +++ b/front/control_panel_mobile.html @@ -0,0 +1,1822 @@ + + + + + + + GCLI2API 移动端控制面板 + + + + +
+ + +
+

GCLI2API 移动端控制面板

+

请输入访问密码:

+ +

+ +
+ + + +
+ + + + + + \ No newline at end of file diff --git a/install.ps1 b/install.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..3c81816c6ec1ea5fb1aeb51a2995ee6be9b44807 --- /dev/null +++ b/install.ps1 @@ -0,0 +1,35 @@ +# 检测是否为管理员 +$IsElevated = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()). + IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator) + +# Skip Scoop install if already present to avoid stopping the script +if (Get-Command scoop -ErrorAction SilentlyContinue) { + Write-Host "Scoop is already installed. Skipping installation." +} else { + Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser -Force + if ($IsElevated) { + # 管理员:使用官方一行命令并传入 -RunAsAdmin + Invoke-Expression "& {$(Invoke-RestMethod get.scoop.sh)} -RunAsAdmin" + } else { + # 普通用户安装 + Invoke-WebRequest -useb get.scoop.sh | Invoke-Expression + } +} + +scoop install git uv +if (Test-Path -LiteralPath "./web.py") { + # Already in target directory; skip clone and cd +} +elseif (Test-Path -LiteralPath "./gcli2api/web.py") { + Set-Location ./gcli2api +} +else { + git clone https://github.com/su-kaka/gcli2api.git + Set-Location ./gcli2api +} +# Create relocatable virtual environment to ensure portability +$env:UV_VENV_CLEAR = "1" +uv venv --relocatable +uv sync +.venv/Scripts/activate.ps1 +python web.py \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5372a1352b947e4741b91cc852dd8a2c59c82ff --- /dev/null +++ b/install.sh @@ -0,0 +1,302 @@ +#!/bin/bash +set -e # Exit on error +set -u # Exit on undefined variable +set -o pipefail # Exit on pipe failure + +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Logging functions +log_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" >&2 +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_debug() { + echo -e "${BLUE}[DEBUG]${NC} $1" +} + +# Cleanup function for error handling +cleanup() { + local exit_code=$? + if [ $exit_code -ne 0 ]; then + log_error "Installation failed with exit code $exit_code" + fi + exit $exit_code +} + +trap cleanup EXIT + +# Detect OS and distribution +detect_os() { + log_info "Detecting operating system..." + + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + if [ -f /etc/os-release ]; then + . /etc/os-release + OS_NAME=$ID + OS_VERSION=$VERSION_ID + log_info "Detected: $NAME $VERSION_ID" + elif [ -f /etc/lsb-release ]; then + . /etc/lsb-release + OS_NAME=$DISTRIB_ID + OS_VERSION=$DISTRIB_RELEASE + log_info "Detected: $DISTRIB_ID $DISTRIB_RELEASE" + else + OS_NAME="linux" + OS_VERSION="unknown" + log_warn "Could not determine specific Linux distribution" + fi + elif [[ "$OSTYPE" == "darwin"* ]]; then + OS_NAME="macos" + OS_VERSION=$(sw_vers -productVersion) + log_info "Detected: macOS $OS_VERSION" + elif [[ "$OSTYPE" == "freebsd"* ]]; then + OS_NAME="freebsd" + OS_VERSION=$(freebsd-version) + log_info "Detected: FreeBSD $OS_VERSION" + else + log_error "Unsupported operating system: $OSTYPE" + exit 1 + fi +} + +# Check for root privileges (only for Linux package managers that need it) +check_root_if_needed() { + if [[ "$OS_NAME" == "ubuntu" ]] || [[ "$OS_NAME" == "debian" ]] || [[ "$OS_NAME" == "linuxmint" ]] || [[ "$OS_NAME" == "kali" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for apt. Please run with sudo." + exit 1 + fi + elif [[ "$OS_NAME" == "fedora" ]] || [[ "$OS_NAME" == "rhel" ]] || [[ "$OS_NAME" == "centos" ]] || [[ "$OS_NAME" == "rocky" ]] || [[ "$OS_NAME" == "almalinux" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for dnf/yum. Please run with sudo." + exit 1 + fi + elif [[ "$OS_NAME" == "arch" ]] || [[ "$OS_NAME" == "manjaro" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for pacman. Please run with sudo." + exit 1 + fi + fi +} + +# Update package manager +update_packages() { + log_info "Updating package manager..." + + case "$OS_NAME" in + ubuntu|debian|linuxmint|kali|pop) + if ! apt update; then + log_error "Failed to update apt package lists" + exit 1 + fi + ;; + fedora|rhel|centos|rocky|almalinux) + if command -v dnf &> /dev/null; then + if ! dnf check-update; then + # dnf check-update returns 100 if updates are available, which is not an error + if [ $? -ne 100 ]; then + log_warn "dnf check-update returned non-standard exit code" + fi + fi + else + if ! yum check-update; then + if [ $? -ne 100 ]; then + log_warn "yum check-update returned non-standard exit code" + fi + fi + fi + ;; + arch|manjaro) + if ! pacman -Syu; then + log_error "Failed to update pacman database" + exit 1 + fi + ;; + macos) + if command -v brew &> /dev/null; then + log_info "Updating Homebrew..." + brew update + else + log_warn "Homebrew not installed. Skipping package manager update." + fi + ;; + *) + log_warn "Unknown package manager for $OS_NAME. Skipping update." + ;; + esac +} + +# Install git based on OS +install_git() { + if ! command -v git &> /dev/null; then + log_info "Installing git..." + + case "$OS_NAME" in + ubuntu|debian|linuxmint|kali|pop) + if ! apt install git -y; then + log_error "Failed to install git" + exit 1 + fi + ;; + fedora|rhel|centos|rocky|almalinux) + if command -v dnf &> /dev/null; then + if ! dnf install git -y; then + log_error "Failed to install git" + exit 1 + fi + else + if ! yum install git -y; then + log_error "Failed to install git" + exit 1 + fi + fi + ;; + arch|manjaro) + if ! pacman -S git --noconfirm; then + log_error "Failed to install git" + exit 1 + fi + ;; + macos) + if command -v brew &> /dev/null; then + if ! brew install git; then + log_error "Failed to install git" + exit 1 + fi + else + log_error "Homebrew is required for macOS. Install from https://brew.sh/" + exit 1 + fi + ;; + *) + log_error "Don't know how to install git on $OS_NAME" + exit 1 + ;; + esac + else + log_info "Git is already installed ($(git --version))" + fi +} + +# Detect OS first +detect_os + +# Check root if needed +check_root_if_needed + +log_info "Starting installation process..." + +# Update package lists +update_packages + +# Install git +install_git + +# Install uv if not present +if ! command -v uv &> /dev/null; then + log_info "Installing uv package manager..." + if ! curl -Ls https://astral.sh/uv/install.sh | sh; then + log_error "Failed to install uv" + exit 1 + fi + + # Source environment + if [ -f "$HOME/.local/bin/env" ]; then + source "$HOME/.local/bin/env" + elif [ -f "$HOME/.cargo/env" ]; then + source "$HOME/.cargo/env" + fi + + # Verify uv installation + if ! command -v uv &> /dev/null; then + log_error "uv installation failed - command not found after install" + exit 1 + fi +else + log_info "uv is already installed" +fi + +# Determine working directory +log_info "Checking project directory..." +if [ -f "./web.py" ]; then + log_info "Already in target directory" +elif [ -f "./gcli2api/web.py" ]; then + log_info "Changing to gcli2api directory" + cd ./gcli2api || exit 1 +else + log_info "Cloning repository..." + if [ -d "./gcli2api" ]; then + log_warn "gcli2api directory exists but web.py not found. Removing and re-cloning..." + rm -rf ./gcli2api + fi + + if ! git clone https://github.com/su-kaka/gcli2api.git; then + log_error "Failed to clone repository" + exit 1 + fi + + cd ./gcli2api || exit 1 +fi + +# Update repository if it's a git repo +if [ -d ".git" ]; then + log_info "Updating repository..." + if ! git pull; then + log_warn "Git pull failed, continuing anyway..." + fi +else + log_warn "Not a git repository, skipping update" +fi + +# Create relocatable virtual environment to ensure portability +log_info "Creating relocatable virtual environment..." +export UV_VENV_CLEAR=1 +if ! uv venv --relocatable; then + log_error "Failed to create virtual environment" + exit 1 +fi + +# Sync dependencies +log_info "Syncing dependencies with uv..." +if ! uv sync; then + log_error "Failed to sync dependencies" + exit 1 +fi + +# Activate virtual environment +log_info "Activating virtual environment..." +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +else + log_error "Virtual environment not found at .venv/bin/activate" + exit 1 +fi + +# Verify Python is available +if ! command -v python3 &> /dev/null; then + log_error "python3 not found in virtual environment" + exit 1 +fi + +# Check if web.py exists +if [ ! -f "web.py" ]; then + log_error "web.py not found in current directory" + exit 1 +fi + +# Start the application +log_info "Starting application..." +python3 web.py \ No newline at end of file diff --git a/log.py b/log.py new file mode 100644 index 0000000000000000000000000000000000000000..9305b7b8c8a8eed6c2b01868f3e65f792f81a3e7 --- /dev/null +++ b/log.py @@ -0,0 +1,179 @@ +""" +日志模块 - 使用环境变量配置 +""" + +import os +import sys +import threading +from datetime import datetime + +# 日志级别定义 +LOG_LEVELS = {"debug": 0, "info": 1, "warning": 2, "error": 3, "critical": 4} + +# 线程锁,用于文件写入同步 +_file_lock = threading.Lock() + +# 文件写入状态标志 +_file_writing_disabled = False +_disable_reason = None + + +def _get_current_log_level(): + """获取当前日志级别""" + level = os.getenv("LOG_LEVEL", "info").lower() + return LOG_LEVELS.get(level, LOG_LEVELS["info"]) + + +def _get_log_file_path(): + """获取日志文件路径""" + return os.getenv("LOG_FILE", "log.txt") + + +def _clear_log_file(): + """清空日志文件(在启动时调用)""" + global _file_writing_disabled, _disable_reason + + try: + log_file = _get_log_file_path() + with _file_lock: + with open(log_file, "w", encoding="utf-8") as f: + f.write("") # 清空文件 + except (PermissionError, OSError, IOError) as e: + # 检测只读文件系统或权限问题,禁用文件写入 + _file_writing_disabled = True + _disable_reason = str(e) + print( + f"Warning: File system appears to be read-only or permission denied. " + f"Disabling log file writing: {e}", + file=sys.stderr, + ) + print("Log messages will continue to display in console only.", file=sys.stderr) + except Exception as e: + # 其他异常仍然输出警告但不禁用写入(可能是临时问题) + print(f"Warning: Failed to clear log file: {e}", file=sys.stderr) + + +def _write_to_file(message: str): + """线程安全地写入日志文件""" + global _file_writing_disabled, _disable_reason + + # 如果文件写入已被禁用,直接返回 + if _file_writing_disabled: + return + + try: + log_file = _get_log_file_path() + with _file_lock: + with open(log_file, "a", encoding="utf-8") as f: + f.write(message + "\n") + f.flush() # 强制刷新到磁盘,确保实时写入 + except (PermissionError, OSError, IOError) as e: + # 检测只读文件系统或权限问题,禁用文件写入 + _file_writing_disabled = True + _disable_reason = str(e) + print( + f"Warning: File system appears to be read-only or permission denied. " + f"Disabling log file writing: {e}", + file=sys.stderr, + ) + print("Log messages will continue to display in console only.", file=sys.stderr) + except Exception as e: + # 其他异常仍然输出警告但不禁用写入(可能是临时问题) + print(f"Warning: Failed to write to log file: {e}", file=sys.stderr) + + +def _log(level: str, message: str): + """ + 内部日志函数 + """ + level = level.lower() + if level not in LOG_LEVELS: + print(f"Warning: Unknown log level '{level}'", file=sys.stderr) + return + + # 检查日志级别 + current_level = _get_current_log_level() + if LOG_LEVELS[level] < current_level: + return + + # 截断日志消息到最多500个字符 + #if len(message) > 500: + #message = message[:500] + "..." + + # 格式化日志消息 + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + entry = f"[{timestamp}] [{level.upper()}] {message}" + + # 输出到控制台 + if level in ("error", "critical"): + print(entry, file=sys.stderr) + else: + print(entry) + + # 实时写入文件 + _write_to_file(entry) + + +def set_log_level(level: str): + """设置日志级别提示""" + level = level.lower() + if level not in LOG_LEVELS: + print(f"Warning: Unknown log level '{level}'. Valid levels: {', '.join(LOG_LEVELS.keys())}") + return False + + print(f"Note: To set log level '{level}', please set LOG_LEVEL environment variable") + return True + + +class Logger: + """支持 log('info', 'msg') 和 log.info('msg') 两种调用方式""" + + def __call__(self, level: str, message: str): + """支持 log('info', 'message') 调用方式""" + _log(level, message) + + def debug(self, message: str): + """记录调试信息""" + _log("debug", message) + + def info(self, message: str): + """记录一般信息""" + _log("info", message) + + def warning(self, message: str): + """记录警告信息""" + _log("warning", message) + + def error(self, message: str): + """记录错误信息""" + _log("error", message) + + def critical(self, message: str): + """记录严重错误信息""" + _log("critical", message) + + def get_current_level(self) -> str: + """获取当前日志级别名称""" + current_level = _get_current_log_level() + for name, value in LOG_LEVELS.items(): + if value == current_level: + return name + return "info" + + def get_log_file(self) -> str: + """获取当前日志文件路径""" + return _get_log_file_path() + + +# 导出全局日志实例 +log = Logger() + +# 导出的公共接口 +__all__ = ["log", "set_log_level", "LOG_LEVELS"] + +# 在模块加载时清空日志文件 +_clear_log_file() + +# 使用说明: +# 1. 设置日志级别: export LOG_LEVEL=debug (或在.env文件中设置) +# 2. 设置日志文件: export LOG_FILE=log.txt (或在.env文件中设置) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..403e7028a9c35004fd08cd59398dd6a0f91d2c50 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,102 @@ +[project] +name = "gcli2api" +version = "0.1.0" +description = "Convert GeminiCLI to OpenAI and Gemini API interfaces" +readme = "README.md" +requires-python = ">=3.12" +license = {text = "CNC-1.0"} +authors = [ + {name = "su-kaka"} +] +keywords = ["gemini", "openai", "api", "converter", "cli"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: Other/Proprietary License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "aiofiles>=24.1.0", + "fastapi>=0.116.1", + "httpx[socks]>=0.28.1", + "hypercorn>=0.17.3", + "motor>=3.7.1", + "oauthlib>=3.3.1", + "pydantic>=2.11.7", + "pyjwt>=2.10.1", + "python-dotenv>=1.1.1", + "python-multipart>=0.0.20", + "pypinyin>=0.51.0", + "aiosqlite>=0.20.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.1.0", + "black>=24.0.0", + "flake8>=7.0.0", + "mypy>=1.8.0", + "pre-commit>=3.6.0", +] + +[tool.pytest.ini_options] +minversion = "8.0" +testpaths = ["."] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +asyncio_mode = "auto" +addopts = [ + "-v", + "--strict-markers", +] + +[tool.black] +line-length = 100 +target-version = ["py312"] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +ignore_missing_imports = true +exclude = [ + "build", + "dist", +] + +[tool.coverage.run] +source = ["src"] +omit = [ + "*/tests/*", + "*/test_*.py", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..806eee2bea04262bdbab80988fe7992273daffe3 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,20 @@ +# Development dependencies for gcli2api +# Install with: pip install -r requirements-dev.txt + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 + +# Code formatting and linting +black>=24.0.0 +flake8>=7.0.0 +isort>=5.13.0 +mypy>=1.8.0 + +# Pre-commit hooks +pre-commit>=3.6.0 + +# Security scanning +safety>=3.0.0 +bandit>=1.7.5 diff --git a/requirements-termux.txt b/requirements-termux.txt new file mode 100644 index 0000000000000000000000000000000000000000..f265e54ba6b0665a1e93ff1a4cca18a7eacc9697 --- /dev/null +++ b/requirements-termux.txt @@ -0,0 +1,12 @@ +fastapi +httpx[socks] +pydantic==1.10.22 +python-dotenv +hypercorn +aiofiles +python-multipart +PyJWT +oauthlib +motor +pypinyin +aiosqlite \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2945c0c96269f6884e53acb1ef7abd3a307ff462 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +fastapi>=0.116.1 +httpx[socks]>=0.28.1 +pydantic>=2.11.7 +python-dotenv>=1.1.1 +hypercorn>=0.17.3 +aiofiles>=24.1.0 +python-multipart>=0.0.20 +PyJWT>=2.10.1 +oauthlib>=3.3.1 +motor>=3.7.1 +aiosqlite>=0.20.0 +pypinyin>=0.51.0 diff --git a/runtime.txt b/runtime.txt new file mode 100644 index 0000000000000000000000000000000000000000..0e1ecfb2198313d2bb0ce7183b23ef6b065fa1ac --- /dev/null +++ b/runtime.txt @@ -0,0 +1 @@ +python-3.12.7 \ No newline at end of file diff --git a/setup-dev.sh b/setup-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..4916d847a95958bf202d3de63fbf47a1dfdfca11 --- /dev/null +++ b/setup-dev.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# Development setup script for gcli2api +# This script sets up the development environment + +set -e + +echo "==========================================" +echo "gcli2api Development Setup" +echo "==========================================" +echo + +# Check Python version +echo "Checking Python version..." +python_version=$(python --version 2>&1 | awk '{print $2}') +required_version="3.12" + +if ! python -c "import sys; exit(0 if sys.version_info >= (3, 12) else 1)"; then + echo "❌ Error: Python 3.12 or higher is required. Found: $python_version" + exit 1 +fi +echo "✅ Python $python_version" +echo + +# Create virtual environment if it doesn't exist +if [ ! -d "venv" ]; then + echo "Creating virtual environment..." + python -m venv venv + echo "✅ Virtual environment created" +else + echo "✅ Virtual environment already exists" +fi +echo + +# Activate virtual environment +echo "Activating virtual environment..." +source venv/bin/activate +echo "✅ Virtual environment activated" +echo + +# Upgrade pip +echo "Upgrading pip..." +pip install --upgrade pip -q +echo "✅ pip upgraded" +echo + +# Install production dependencies +echo "Installing production dependencies..." +pip install -r requirements.txt -q +echo "✅ Production dependencies installed" +echo + +# Install development dependencies +echo "Installing development dependencies..." +pip install -r requirements-dev.txt -q +echo "✅ Development dependencies installed" +echo + +# Copy .env.example to .env if it doesn't exist +if [ ! -f ".env" ]; then + echo "Creating .env file from .env.example..." + cp .env.example .env + echo "✅ .env file created" + echo "⚠️ Please edit .env file with your configuration" +else + echo "✅ .env file already exists" +fi +echo + +# Install pre-commit hooks +echo "Installing pre-commit hooks..." +pre-commit install +echo "✅ Pre-commit hooks installed" +echo + +echo "==========================================" +echo "✅ Development setup complete!" +echo "==========================================" +echo +echo "Next steps:" +echo " 1. Edit .env with your configuration" +echo " 2. Run 'make test' to verify setup" +echo " 3. Run 'make run' to start the application" +echo +echo "Available commands:" +echo " make help - Show all available commands" +echo " make test - Run tests" +echo " make lint - Run linters" +echo " make format - Format code" +echo " make run - Run the application" +echo +echo "To activate the virtual environment in the future:" +echo " source venv/bin/activate" +echo diff --git a/src/api/Response_example.txt b/src/api/Response_example.txt new file mode 100644 index 0000000000000000000000000000000000000000..456ecccec35ede0c18172ce80a9f403aa62ac3ad --- /dev/null +++ b/src/api/Response_example.txt @@ -0,0 +1,210 @@ +================================================================================ +GeminiCli API 测试 +================================================================================ + +================================================================================ +【测试1】流式请求 (stream_request with native=False) +================================================================================ +请求体: { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, tell me a joke in one sentence." + } + ] + } + ] + } +} + +流式响应数据 (每个chunk): +-------------------------------------------------------------------------------- +[2026-01-10 09:55:29] [INFO] SQLite storage initialized at ./creds\credentials.db +[2026-01-10 09:55:29] [INFO] Using SQLite storage backend +[2026-01-10 09:55:31] [INFO] Token刷新成 功并已保存: my-project-9-481103-1765596755.json (mode=geminicli) +[2026-01-10 09:55:34] [INFO] [DB] 准备commit,总更新行数=1 +[2026-01-10 09:55:34] [INFO] [DB] commit 完成 +[2026-01-10 09:55:34] [INFO] [DB] update_credential_state 结束: success=True, updated_count=1 + +Chunk #1: + 类型: str + 长度: 626 + 内容预览: 'data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text": "Why did the scarecrow win an award? Because he was outstanding in his field."}]},"finishReason": "STOP"}],"usageMeta' + 解析后的JSON: { + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field." + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 51, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 24 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:29.168589Z", + "responseId": "kbFhaY2lCr-ZseMPqMiDmAU" + }, + "traceId": "55650653afd3c738" +} + +Chunk #2: + 类型: str + 长度: 0 + 内容预览: '' +E:\projects\gcli2api\src\api\geminicli.py:491: RuntimeWarning: coroutine 'get_auto_ban_error_codes' was never awaited + async for chunk in stream_request(body=test_body, native=False): +RuntimeWarning: Enable tracemalloc to get the object allocation traceback + +总共收到 2 个chunk + +================================================================================ +【测试2】非流式请求 (non_stream_request) +================================================================================ +请求体: { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, tell me a joke in one sentence." + } + ] + } + ] + } +} + +[2026-01-10 09:55:35] [INFO] Token刷新成 功并已保存: gen-lang-client-0194852792-1767296759.json (mode=geminicli) +[2026-01-10 09:55:38] [INFO] [DB] 准备commit,总更新行数=1 +[2026-01-10 09:55:38] [INFO] [DB] commit 完成 +[2026-01-10 09:55:38] [INFO] [DB] update_credential_state 结束: success=True, updated_count=1 +E:\projects\gcli2api\src\api\geminicli.py:530: RuntimeWarning: coroutine 'get_auto_ban_error_codes' was never awaited + response = await non_stream_request(body=test_body) +RuntimeWarning: Enable tracemalloc to get the object allocation traceback +非流式响应数据: +-------------------------------------------------------------------------------- +状态码: 200 +Content-Type: application/json; charset=UTF-8 + +响应头: {'server': 'openresty', 'date': 'Sat, 10 Jan 2026 01:55:34 GMT', 'content-type': 'application/json; charset=UTF-8', 'transfer-encoding': 'chunked', 'connection': 'keep-alive', 'x-cloudaicompanion-trace-id': 'bf3a5eb6636774d2', 'vary': 'Origin, X-Origin, Referer', 'content-encoding': 'gzip', 'x-xss-protection': '0', 'x-frame-options': 'SAMEORIGIN', 'x-content-type-options': 'nosniff', 'server-timing': 'gfet4t7; dur=1377', 'alt-svc': 'h3=":443"; ma=2592000,h3-29=":443"; ma=2592000', 'access-control-allow-origin': '*', 'access-control-allow-methods': 'GET, POST, PUT, DELETE, PATCH, OPTIONS', 'access-control-allow-headers': 'Content-Type, Authorization, X-Requested-With', 'cache-control': 'no-cache', 'content-length': '969'} + +响应内容 (原始): +{ + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field!" + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.54438119776108684 + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 47, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 20 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:33.450396Z", + "responseId": "lbFhady-G7yi694PmLOP4As" + }, + "traceId": "bf3a5eb6636774d2" +} + + +响应内容 (格式化JSON): +{ + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field!" + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.5443811977610868 + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 47, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 20 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:33.450396Z", + "responseId": "lbFhady-G7yi694PmLOP4As" + }, + "traceId": "bf3a5eb6636774d2" +} + +================================================================================ +测试完成 +================================================================================ \ No newline at end of file diff --git a/src/api/antigravity.py b/src/api/antigravity.py new file mode 100644 index 0000000000000000000000000000000000000000..0960942445dd317fc1dc084ba79237c0f069a8a6 --- /dev/null +++ b/src/api/antigravity.py @@ -0,0 +1,694 @@ +""" +Antigravity API Client - Handles communication with Google's Antigravity API +处理与 Google Antigravity API 的通信 +""" + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from fastapi import Response +from config import ( + get_antigravity_api_url, + get_antigravity_stream2nostream, + get_auto_ban_error_codes, +) +from log import log + +from src.credential_manager import CredentialManager +from src.httpx_client import stream_post_async, post_async +from src.models import Model, model_to_dict +from src.utils import ANTIGRAVITY_USER_AGENT + +# 导入共同的基础功能 +from src.api.utils import ( + handle_error_with_retry, + get_retry_config, + record_api_call_success, + record_api_call_error, + parse_and_log_cooldown, + collect_streaming_response, +) + +# ==================== 全局凭证管理器 ==================== + +# 全局凭证管理器实例(单例模式) +_credential_manager: Optional[CredentialManager] = None + + +async def _get_credential_manager() -> CredentialManager: + """ + 获取全局凭证管理器实例 + + Returns: + CredentialManager实例 + """ + global _credential_manager + if not _credential_manager: + _credential_manager = CredentialManager() + await _credential_manager.initialize() + return _credential_manager + + +# ==================== 辅助函数 ==================== + +def build_antigravity_headers(access_token: str, model_name: str = "") -> Dict[str, str]: + """ + 构建 Antigravity API 请求头 + + Args: + access_token: 访问令牌 + model_name: 模型名称,用于判断 request_type + + Returns: + 请求头字典 + """ + headers = { + 'User-Agent': ANTIGRAVITY_USER_AGENT, + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json', + 'Accept-Encoding': 'gzip', + 'requestId': f"req-{uuid.uuid4()}" + } + + # 根据模型名称判断 request_type + if model_name: + request_type = "image_gen" if "image" in model_name.lower() else "agent" + headers['requestType'] = request_type + + return headers + + +# ==================== 新的流式和非流式请求函数 ==================== + +async def stream_request( + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, +): + """ + 流式请求函数 + + Args: + body: 请求体 + native: 是否返回原生bytes流,False则返回str流 + headers: 额外的请求头 + + Yields: + Response对象(错误时)或 bytes流/str流(成功时) + """ + # 获取凭证管理器 + credential_manager = await _get_credential_manager() + + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_key=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + log.error("[ANTIGRAVITY STREAM] 当前无可用凭证") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + + if not access_token: + log.error(f"[ANTIGRAVITY STREAM] No access token in credential: {current_file}") + yield Response( + content=json.dumps({"error": "凭证中没有访问令牌"}), + status_code=500, + media_type="application/json" + ) + return + + # 2. 构建URL和请求头 + antigravity_url = await get_antigravity_api_url() + target_url = f"{antigravity_url}/v1internal:streamGenerateContent?alt=sse" + + auth_headers = build_antigravity_headers(access_token, model_name) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + # 3. 调用stream_post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + + # 内部函数:获取新凭证并更新headers + async def refresh_credential(): + nonlocal current_file, access_token, auth_headers + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_key=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + if not access_token: + return None + auth_headers = build_antigravity_headers(access_token, model_name) + if headers: + auth_headers.update(headers) + return True + + for attempt in range(max_retries + 1): + success_recorded = False # 标记是否已记录成功 + need_retry = False # 标记是否需要重试 + + try: + async for chunk in stream_post_async( + url=target_url, + body=body, + native=native, + headers=auth_headers + ): + # 判断是否是Response对象 + if isinstance(chunk, Response): + status_code = chunk.status_code + last_error_response = chunk # 记录最后一次错误 + + # 如果错误码是429或者不在禁用码当中,做好记录后进行重试 + if status_code == 429 or status_code not in DISABLE_ERROR_CODES: + # 解析错误响应内容 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + log.warning(f"[ANTIGRAVITY STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500]}") + except Exception: + log.warning(f"[ANTIGRAVITY STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}") + + # 记录错误 + cooldown_until = None + if status_code == 429: + # 尝试解析冷却时间 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + cooldown_until = await parse_and_log_cooldown(error_body, mode="antigravity") + except Exception: + pass + + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_key=model_name + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="antigravity" + ) + + if should_retry and attempt < max_retries: + need_retry = True + break # 跳出内层循环,准备重试 + else: + # 不重试,直接返回原始错误 + log.error(f"[ANTIGRAVITY STREAM] 达到最大重试次数或不应重试,返回原始错误") + yield chunk + return + else: + # 错误码在禁用码当中,直接返回,无需重试 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + log.error(f"[ANTIGRAVITY STREAM] 流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500]}") + except Exception: + log.error(f"[ANTIGRAVITY STREAM] 流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="antigravity", model_key=model_name + ) + yield chunk + return + else: + # 不是Response,说明是真流,直接yield返回 + # 只在第一个chunk时记录成功 + if not success_recorded: + await record_api_call_success( + credential_manager, current_file, mode="antigravity", model_key=model_name + ) + success_recorded = True + log.info(f"[ANTIGRAVITY STREAM] 开始接收流式响应,模型: {model_name}") + + # 记录原始chunk内容(用于调试) + if isinstance(chunk, bytes): + log.debug(f"[ANTIGRAVITY STREAM RAW] chunk(bytes): {chunk}") + else: + log.debug(f"[ANTIGRAVITY STREAM RAW] chunk(str): {chunk}") + + yield chunk + + # 流式请求完成,检查结果 + if success_recorded: + log.info(f"[ANTIGRAVITY STREAM] 流式响应完成,模型: {model_name}") + return + elif not need_retry: + # 没有收到任何数据(空回复),需要重试 + log.warning(f"[ANTIGRAVITY STREAM] 收到空回复,无任何内容,凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, 200, + None, mode="antigravity", model_key=model_name + ) + + if attempt < max_retries: + need_retry = True + else: + log.error(f"[ANTIGRAVITY STREAM] 空回复达到最大重试次数") + yield Response( + content=json.dumps({"error": "服务返回空回复"}), + status_code=500, + media_type="application/json" + ) + return + + # 统一处理重试 + if need_retry: + log.info(f"[ANTIGRAVITY STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + + if not await refresh_credential(): + log.error("[ANTIGRAVITY STREAM] 重试时无可用凭证或令牌") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + continue # 重试 + + except Exception as e: + log.error(f"[ANTIGRAVITY STREAM] 流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[ANTIGRAVITY STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有) + log.error(f"[ANTIGRAVITY STREAM] 所有重试均失败,最后异常: {e}") + yield last_error_response + + +async def non_stream_request( + body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, +) -> Response: + """ + 非流式请求函数 + + Args: + body: 请求体 + headers: 额外的请求头 + + Returns: + Response对象 + """ + # 检查是否启用流式收集模式 + if await get_antigravity_stream2nostream(): + log.info("[ANTIGRAVITY] 使用流式收集模式实现非流式请求") + + # 调用stream_request获取流 + stream = stream_request(body=body, native=False, headers=headers) + + # 收集流式响应 + # stream_request是一个异步生成器,可能yield Response(错误)或流数据 + # collect_streaming_response会自动处理这两种情况 + return await collect_streaming_response(stream) + + # 否则使用传统非流式模式 + log.info("[ANTIGRAVITY] 使用传统非流式模式") + + # 获取凭证管理器 + credential_manager = await _get_credential_manager() + + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_key=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + log.error("[ANTIGRAVITY] 当前无可用凭证") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + + if not access_token: + log.error(f"[ANTIGRAVITY] No access token in credential: {current_file}") + return Response( + content=json.dumps({"error": "凭证中没有访问令牌"}), + status_code=500, + media_type="application/json" + ) + + # 2. 构建URL和请求头 + antigravity_url = await get_antigravity_api_url() + target_url = f"{antigravity_url}/v1internal:generateContent" + + auth_headers = build_antigravity_headers(access_token, model_name) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + # 3. 调用post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + + # 内部函数:获取新凭证并更新headers + async def refresh_credential(): + nonlocal current_file, access_token, auth_headers + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_key=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + if not access_token: + return None + auth_headers = build_antigravity_headers(access_token, model_name) + if headers: + auth_headers.update(headers) + return True + + for attempt in range(max_retries + 1): + need_retry = False # 标记是否需要重试 + + try: + response = await post_async( + url=target_url, + json=body, + headers=auth_headers, + timeout=300.0 + ) + + status_code = response.status_code + + # 成功 + if status_code == 200: + # 检查是否为空回复 + if not response.content or len(response.content) == 0: + log.warning(f"[ANTIGRAVITY] 收到200响应但内容为空,凭证: {current_file}") + + # 记录错误 + await record_api_call_error( + credential_manager, current_file, 200, + None, mode="antigravity", model_key=model_name + ) + + if attempt < max_retries: + need_retry = True + else: + log.error(f"[ANTIGRAVITY] 空回复达到最大重试次数") + return Response( + content=json.dumps({"error": "服务返回空回复"}), + status_code=500, + media_type="application/json" + ) + else: + # 正常响应 + await record_api_call_success( + credential_manager, current_file, mode="antigravity", model_key=model_name + ) + return Response( + content=response.content, + status_code=200, + headers=dict(response.headers) + ) + + # 失败 - 记录最后一次错误 + if status_code != 200: + last_error_response = Response( + content=response.content, + status_code=status_code, + headers=dict(response.headers) + ) + + # 判断是否需要重试 + if status_code == 429 or status_code not in DISABLE_ERROR_CODES: + try: + error_text = response.text + log.warning(f"[ANTIGRAVITY] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500]}") + except Exception: + log.warning(f"[ANTIGRAVITY] 非流式请求失败 (status={status_code}), 凭证: {current_file}") + + # 记录错误 + cooldown_until = None + if status_code == 429: + # 尝试解析冷却时间 + try: + error_text = response.text + cooldown_until = await parse_and_log_cooldown(error_text, mode="antigravity") + except Exception: + pass + + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_key=model_name + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="antigravity" + ) + + if should_retry and attempt < max_retries: + need_retry = True + else: + # 不重试,直接返回原始错误 + log.error(f"[ANTIGRAVITY] 达到最大重试次数或不应重试,返回原始错误") + return last_error_response + else: + # 错误码在禁用码当中,直接返回,无需重试 + try: + error_text = response.text + log.error(f"[ANTIGRAVITY] 非流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500]}") + except Exception: + log.error(f"[ANTIGRAVITY] 非流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="antigravity", model_key=model_name + ) + return last_error_response + + # 统一处理重试 + if need_retry: + log.info(f"[ANTIGRAVITY] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + + if not await refresh_credential(): + log.error("[ANTIGRAVITY] 重试时无可用凭证或令牌") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + continue # 重试 + + except Exception as e: + log.error(f"[ANTIGRAVITY] 非流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[ANTIGRAVITY] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有) + log.error(f"[ANTIGRAVITY] 所有重试均失败,最后异常: {e}") + return last_error_response + + # 所有重试都失败,返回最后一次的原始错误 + log.error("[ANTIGRAVITY] 所有重试均失败") + return last_error_response + + +# ==================== 模型和配额查询 ==================== + +async def fetch_available_models() -> List[Dict[str, Any]]: + """ + 获取可用模型列表,返回符合 OpenAI API 规范的格式 + + Returns: + 模型列表,格式为字典列表(用于兼容现有代码) + + Raises: + 返回空列表如果获取失败 + """ + # 获取凭证管理器和可用凭证 + credential_manager = await _get_credential_manager() + cred_result = await credential_manager.get_valid_credential(mode="antigravity") + if not cred_result: + log.error("[ANTIGRAVITY] No valid credentials available for fetching models") + return [] + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + + if not access_token: + log.error(f"[ANTIGRAVITY] No access token in credential: {current_file}") + return [] + + # 构建请求头 + headers = build_antigravity_headers(access_token) + + try: + # 使用 POST 请求获取模型列表 + antigravity_url = await get_antigravity_api_url() + + response = await post_async( + url=f"{antigravity_url}/v1internal:fetchAvailableModels", + json={}, # 空的请求体 + headers=headers + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[ANTIGRAVITY] Raw models response: {json.dumps(data, ensure_ascii=False)[:500]}") + + # 转换为 OpenAI 格式的模型列表,使用 Model 类 + model_list = [] + current_timestamp = int(datetime.now(timezone.utc).timestamp()) + + if 'models' in data and isinstance(data['models'], dict): + # 遍历模型字典 + for model_id in data['models'].keys(): + model = Model( + id=model_id, + object='model', + created=current_timestamp, + owned_by='google' + ) + model_list.append(model_to_dict(model)) + + # 添加额外的 claude-opus-4-5 模型 + claude_opus_model = Model( + id='claude-opus-4-5', + object='model', + created=current_timestamp, + owned_by='google' + ) + model_list.append(model_to_dict(claude_opus_model)) + + log.info(f"[ANTIGRAVITY] Fetched {len(model_list)} available models") + return model_list + else: + log.error(f"[ANTIGRAVITY] Failed to fetch models ({response.status_code}): {response.text[:500]}") + return [] + + except Exception as e: + import traceback + log.error(f"[ANTIGRAVITY] Failed to fetch models: {e}") + log.error(f"[ANTIGRAVITY] Traceback: {traceback.format_exc()}") + return [] + + +async def fetch_quota_info(access_token: str) -> Dict[str, Any]: + """ + 获取指定凭证的额度信息 + + Args: + access_token: Antigravity 访问令牌 + + Returns: + 包含额度信息的字典,格式为: + { + "success": True/False, + "models": { + "model_name": { + "remaining": 0.95, + "resetTime": "12-20 10:30", + "resetTimeRaw": "2025-12-20T02:30:00Z" + } + }, + "error": "错误信息" (仅在失败时) + } + """ + + headers = build_antigravity_headers(access_token) + + try: + antigravity_url = await get_antigravity_api_url() + + response = await post_async( + url=f"{antigravity_url}/v1internal:fetchAvailableModels", + json={}, + headers=headers, + timeout=30.0 + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[ANTIGRAVITY QUOTA] Raw response: {json.dumps(data, ensure_ascii=False)[:500]}") + + quota_info = {} + + if 'models' in data and isinstance(data['models'], dict): + for model_id, model_data in data['models'].items(): + if isinstance(model_data, dict) and 'quotaInfo' in model_data: + quota = model_data['quotaInfo'] + remaining = quota.get('remainingFraction', 0) + reset_time_raw = quota.get('resetTime', '') + + # 转换为北京时间 + reset_time_beijing = 'N/A' + if reset_time_raw: + try: + utc_date = datetime.fromisoformat(reset_time_raw.replace('Z', '+00:00')) + # 转换为北京时间 (UTC+8) + from datetime import timedelta + beijing_date = utc_date + timedelta(hours=8) + reset_time_beijing = beijing_date.strftime('%m-%d %H:%M') + except Exception as e: + log.warning(f"[ANTIGRAVITY QUOTA] Failed to parse reset time: {e}") + + quota_info[model_id] = { + "remaining": remaining, + "resetTime": reset_time_beijing, + "resetTimeRaw": reset_time_raw + } + + return { + "success": True, + "models": quota_info + } + else: + log.error(f"[ANTIGRAVITY QUOTA] Failed to fetch quota ({response.status_code}): {response.text[:500]}") + return { + "success": False, + "error": f"API返回错误: {response.status_code}" + } + + except Exception as e: + import traceback + log.error(f"[ANTIGRAVITY QUOTA] Failed to fetch quota: {e}") + log.error(f"[ANTIGRAVITY QUOTA] Traceback: {traceback.format_exc()}") + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/src/api/geminicli.py b/src/api/geminicli.py new file mode 100644 index 0000000000000000000000000000000000000000..21dcce3afddca25623c5057a3420450de57b9fa7 --- /dev/null +++ b/src/api/geminicli.py @@ -0,0 +1,597 @@ +""" +GeminiCli API Client - Handles all communication with GeminiCli API. +This module is used by both OpenAI compatibility layer and native Gemini endpoints. +GeminiCli API 客户端 - 处理与 GeminiCli API 的所有通信 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径(用于直接运行测试) +if __name__ == "__main__": + project_root = Path(__file__).resolve().parent.parent.parent + if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +import asyncio +import json +from typing import Any, Dict, Optional + +from fastapi import Response +from config import get_code_assist_endpoint, get_auto_ban_error_codes +from src.api.utils import get_model_group +from log import log + +from src.credential_manager import CredentialManager +from src.httpx_client import stream_post_async, post_async + +# 导入共同的基础功能 +from src.api.utils import ( + handle_error_with_retry, + get_retry_config, + record_api_call_success, + record_api_call_error, + parse_and_log_cooldown, +) +from src.utils import GEMINICLI_USER_AGENT + +# ==================== 全局凭证管理器 ==================== + +# 全局凭证管理器实例(单例模式) +_credential_manager: Optional[CredentialManager] = None + + +async def _get_credential_manager() -> CredentialManager: + """ + 获取全局凭证管理器实例 + + Returns: + CredentialManager实例 + """ + global _credential_manager + if not _credential_manager: + _credential_manager = CredentialManager() + await _credential_manager.initialize() + return _credential_manager + + +# ==================== 请求准备 ==================== + +async def prepare_request_headers_and_payload( + payload: dict, credential_data: dict, target_url: str +): + """ + 从凭证数据准备请求头和最终payload + + Args: + payload: 原始请求payload + credential_data: 凭证数据字典 + target_url: 目标URL + + Returns: + 元组: (headers, final_payload, target_url) + + Raises: + Exception: 如果凭证中缺少必要字段 + """ + token = credential_data.get("token") or credential_data.get("access_token", "") + if not token: + raise Exception("凭证中没有找到有效的访问令牌(token或access_token字段)") + + source_request = payload.get("request", {}) + + # 内部API使用Bearer Token和项目ID + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "User-Agent": GEMINICLI_USER_AGENT, + } + project_id = credential_data.get("project_id", "") + if not project_id: + raise Exception("项目ID不存在于凭证数据中") + final_payload = { + "model": payload.get("model"), + "project": project_id, + "request": source_request, + } + + return headers, final_payload, target_url + + +# ==================== 新的流式和非流式请求函数 ==================== + +async def stream_request( + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, +): + """ + 流式请求函数 + + Args: + body: 请求体 + native: 是否返回原生bytes流,False则返回str流 + headers: 额外的请求头 + + Yields: + Response对象(错误时)或 bytes流/str流(成功时) + """ + # 获取凭证管理器 + credential_manager = await _get_credential_manager() + + model_name = body.get("model", "") + model_group = get_model_group(model_name) + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_key=model_group + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + + current_file, credential_data = cred_result + + # 2. 构建URL和请求头 + try: + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:streamGenerateContent?alt=sse" + ) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + except Exception as e: + log.error(f"准备请求失败: {e}") + yield Response( + content=json.dumps({"error": f"准备请求失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + return + + # 3. 调用stream_post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + + for attempt in range(max_retries + 1): + success_recorded = False # 标记是否已记录成功 + + try: + async for chunk in stream_post_async( + url=target_url, + body=final_payload, + native=native, + headers=auth_headers + ): + # 判断是否是Response对象 + if isinstance(chunk, Response): + status_code = chunk.status_code + last_error_response = chunk # 记录最后一次错误 + + # 如果错误码是429或者不在禁用码当中,做好记录后进行重试 + if status_code == 429 or status_code not in DISABLE_ERROR_CODES: + # 解析错误响应内容 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + log.warning(f"流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500]}") + except Exception: + log.warning(f"流式请求失败 (status={status_code}), 凭证: {current_file}") + + # 记录错误 + cooldown_until = None + if status_code == 429: + # 尝试解析冷却时间 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + cooldown_until = await parse_and_log_cooldown(error_body, mode="geminicli") + except Exception: + pass + + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_key=model_group + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="geminicli" + ) + + if should_retry and attempt < max_retries: + # 重新获取凭证并重试 + log.info(f"[STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + + # 获取新凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_key=model_group + ) + if not cred_result: + log.error("[STREAM] 重试时无可用凭证") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + + current_file, credential_data = cred_result + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:streamGenerateContent?alt=sse" + ) + if headers: + auth_headers.update(headers) + break # 跳出内层循环,重新请求 + else: + # 不重试,直接返回原始错误 + log.error(f"[STREAM] 达到最大重试次数或不应重试,返回原始错误") + yield chunk + return + else: + # 错误码在禁用码当中,直接返回,无需重试 + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + log.error(f"流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500]}") + except Exception: + log.error(f"流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_key=model_group + ) + yield chunk + return + else: + # 不是Response,说明是真流,直接yield返回 + # 只在第一个chunk时记录成功 + if not success_recorded: + await record_api_call_success( + credential_manager, current_file, mode="geminicli", model_key=model_group + ) + success_recorded = True + + yield chunk + + # 流式请求成功完成,退出重试循环 + return + + except Exception as e: + log.error(f"流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有)或500错误 + log.error(f"[STREAM] 所有重试均失败,最后异常: {e}") + yield last_error_response + + +async def non_stream_request( + body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, +) -> Response: + """ + 非流式请求函数 + + Args: + body: 请求体 + native: 保留参数以保持接口一致性(实际未使用) + headers: 额外的请求头 + + Returns: + Response对象 + """ + # 获取凭证管理器 + credential_manager = await _get_credential_manager() + + model_name = body.get("model", "") + model_group = get_model_group(model_name) + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_key=model_group + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + + current_file, credential_data = cred_result + + # 2. 构建URL和请求头 + try: + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:generateContent" + ) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + except Exception as e: + log.error(f"准备请求失败: {e}") + return Response( + content=json.dumps({"error": f"准备请求失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + # 3. 调用post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + + for attempt in range(max_retries + 1): + try: + response = await post_async( + url=target_url, + json=final_payload, + headers=auth_headers, + timeout=300.0 + ) + + status_code = response.status_code + + # 成功 + if status_code == 200: + await record_api_call_success( + credential_manager, current_file, mode="geminicli", model_key=model_group + ) + # 创建响应头,移除压缩相关的header避免重复解压 + response_headers = dict(response.headers) + response_headers.pop('content-encoding', None) + response_headers.pop('content-length', None) + + return Response( + content=response.content, + status_code=200, + headers=response_headers + ) + + # 失败 - 记录最后一次错误 + # 创建响应头,移除压缩相关的header避免重复解压 + error_headers = dict(response.headers) + error_headers.pop('content-encoding', None) + error_headers.pop('content-length', None) + + last_error_response = Response( + content=response.content, + status_code=status_code, + headers=error_headers + ) + + # 判断是否需要重试 + if status_code == 429 or status_code not in DISABLE_ERROR_CODES: + try: + error_text = response.text + log.warning(f"非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500]}") + except Exception: + log.warning(f"非流式请求失败 (status={status_code}), 凭证: {current_file}") + + # 记录错误 + cooldown_until = None + if status_code == 429: + # 尝试解析冷却时间 + try: + error_text = response.text + cooldown_until = await parse_and_log_cooldown(error_text, mode="geminicli") + except Exception: + pass + + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_key=model_group + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="geminicli" + ) + + if should_retry and attempt < max_retries: + # 重新获取凭证并重试 + log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + + # 获取新凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_key=model_group + ) + if not cred_result: + log.error("[NON-STREAM] 重试时无可用凭证") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + + current_file, credential_data = cred_result + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:generateContent" + ) + if headers: + auth_headers.update(headers) + continue # 重试 + else: + # 不重试,直接返回原始错误 + log.error(f"[NON-STREAM] 达到最大重试次数或不应重试,返回原始错误") + return last_error_response + else: + # 错误码在禁用码当中,直接返回,无需重试 + try: + error_text = response.text + log.error(f"非流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500]}") + except Exception: + log.error(f"非流式请求失败,禁用错误码 (status={status_code}), 凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_key=model_group + ) + return last_error_response + + except Exception as e: + log.error(f"非流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有)或500错误 + log.error(f"[NON-STREAM] 所有重试均失败,最后异常: {e}") + if last_error_response: + return last_error_response + else: + return Response( + content=json.dumps({"error": f"请求异常: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + # 所有重试都失败,返回最后一次的原始错误 + log.error("[NON-STREAM] 所有重试均失败") + return last_error_response + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示API返回的流式和非流式数据格式 + 运行方式: python src/api/geminicli.py + """ + print("=" * 80) + print("GeminiCli API 测试") + print("=" * 80) + + # 测试请求体 + test_body = { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + } + + async def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试1】流式请求 (stream_request with native=False)") + print("=" * 80) + print(f"请求体: {json.dumps(test_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + chunk_count = 0 + async for chunk in stream_request(body=test_body, native=False): + chunk_count += 1 + if isinstance(chunk, Response): + # 错误响应 + print(f"\n❌ 错误响应:") + print(f" 状态码: {chunk.status_code}") + print(f" Content-Type: {chunk.headers.get('content-type', 'N/A')}") + try: + content = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + print(f" 内容: {content}") + except Exception as e: + print(f" 内容解析失败: {e}") + else: + # 正常的流式数据块 (str类型) + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk) if hasattr(chunk, '__len__') else 'N/A'}") + print(f" 内容预览: {repr(chunk[:200] if len(chunk) > 200 else chunk)}") + + # 如果是SSE格式,尝试解析 + if isinstance(chunk, str) and chunk.startswith("data: "): + try: + data_line = chunk.strip() + if data_line.startswith("data: "): + json_str = data_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析尝试失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + async def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (non_stream_request)") + print("=" * 80) + print(f"请求体: {json.dumps(test_body, indent=2, ensure_ascii=False)}\n") + + response = await non_stream_request(body=test_body) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + print(f"\n响应头: {dict(response.headers)}\n") + + try: + content = response.body.decode('utf-8') if isinstance(response.body, bytes) else str(response.body) + print(f"响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = json.loads(content) + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + async def main(): + """主测试函数""" + try: + # 测试流式请求 + await test_stream_request() + + # 测试非流式请求 + await test_non_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() + + # 运行测试 + asyncio.run(main()) diff --git a/src/api/utils.py b/src/api/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..491ac2537add16757e945b24a62695b04e89ee26 --- /dev/null +++ b/src/api/utils.py @@ -0,0 +1,497 @@ +""" +Base API Client - 共用的 API 客户端基础功能 +提供错误处理、自动封禁、重试逻辑等共同功能 +""" + +import asyncio +import json +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +from fastapi import Response + +from config import ( + get_auto_ban_enabled, + get_auto_ban_error_codes, + get_retry_429_enabled, + get_retry_429_interval, + get_retry_429_max_retries, +) +from log import log +from src.credential_manager import CredentialManager + + +# ==================== 错误检查与处理 ==================== + +async def check_should_auto_ban(status_code: int) -> bool: + """ + 检查是否应该触发自动封禁 + + Args: + status_code: HTTP状态码 + + Returns: + bool: 是否应该触发自动封禁 + """ + return ( + await get_auto_ban_enabled() + and status_code in await get_auto_ban_error_codes() + ) + + +async def handle_auto_ban( + credential_manager: CredentialManager, + status_code: int, + credential_name: str, + mode: str = "geminicli" +) -> None: + """ + 处理自动封禁:直接禁用凭证 + + Args: + credential_manager: 凭证管理器实例 + status_code: HTTP状态码 + credential_name: 凭证名称 + mode: 模式(geminicli 或 antigravity) + """ + if credential_manager and credential_name: + log.warning( + f"[{mode.upper()} AUTO_BAN] Status {status_code} triggers auto-ban for credential: {credential_name}" + ) + await credential_manager.set_cred_disabled( + credential_name, True, mode=mode + ) + + +async def handle_error_with_retry( + credential_manager: CredentialManager, + status_code: int, + credential_name: str, + retry_enabled: bool, + attempt: int, + max_retries: int, + retry_interval: float, + mode: str = "geminicli" +) -> bool: + """ + 统一处理错误和重试逻辑 + + 仅在以下情况下进行自动重试: + 1. 429错误(速率限制) + 2. 导致凭证封禁的错误(AUTO_BAN_ERROR_CODES配置) + + Args: + credential_manager: 凭证管理器实例 + status_code: HTTP状态码 + credential_name: 凭证名称 + retry_enabled: 是否启用重试 + attempt: 当前重试次数 + max_retries: 最大重试次数 + retry_interval: 重试间隔 + mode: 模式(geminicli 或 antigravity) + + Returns: + bool: True表示需要继续重试,False表示不需要重试 + """ + # 优先检查自动封禁 + should_auto_ban = await check_should_auto_ban(status_code) + + if should_auto_ban: + # 触发自动封禁 + await handle_auto_ban(credential_manager, status_code, credential_name, mode) + + # 自动封禁后,仍然尝试重试(会在下次循环中自动获取新凭证) + if retry_enabled and attempt < max_retries: + log.info( + f"[{mode.upper()} RETRY] Retrying with next credential after auto-ban " + f"(status {status_code}, attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(retry_interval) + return True + return False + + # 如果不触发自动封禁,仅对429错误进行重试 + if status_code == 429 and retry_enabled and attempt < max_retries: + log.info( + f"[{mode.upper()} RETRY] 429 rate limit encountered, retrying " + f"(attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(retry_interval) + return True + + # 其他错误不进行重试 + return False + + +# ==================== 重试配置获取 ==================== + +async def get_retry_config() -> Dict[str, Any]: + """ + 获取重试配置 + + Returns: + 包含重试配置的字典 + """ + return { + "retry_enabled": await get_retry_429_enabled(), + "max_retries": await get_retry_429_max_retries(), + "retry_interval": await get_retry_429_interval(), + } + + +# ==================== API调用结果记录 ==================== + +async def record_api_call_success( + credential_manager: CredentialManager, + credential_name: str, + mode: str = "geminicli", + model_key: Optional[str] = None +) -> None: + """ + 记录API调用成功 + + Args: + credential_manager: 凭证管理器实例 + credential_name: 凭证名称 + mode: 模式(geminicli 或 antigravity) + model_key: 模型键(用于模型级CD) + """ + if credential_manager and credential_name: + await credential_manager.record_api_call_result( + credential_name, True, mode=mode, model_key=model_key + ) + + +async def record_api_call_error( + credential_manager: CredentialManager, + credential_name: str, + status_code: int, + cooldown_until: Optional[float] = None, + mode: str = "geminicli", + model_key: Optional[str] = None +) -> None: + """ + 记录API调用错误 + + Args: + credential_manager: 凭证管理器实例 + credential_name: 凭证名称 + status_code: HTTP状态码 + cooldown_until: 冷却截止时间(Unix时间戳) + mode: 模式(geminicli 或 antigravity) + model_key: 模型键(用于模型级CD) + """ + if credential_manager and credential_name: + await credential_manager.record_api_call_result( + credential_name, + False, + status_code, + cooldown_until=cooldown_until, + mode=mode, + model_key=model_key + ) + + +# ==================== 429错误处理 ==================== + +async def parse_and_log_cooldown( + error_text: str, + mode: str = "geminicli" +) -> Optional[float]: + """ + 解析并记录冷却时间 + + Args: + error_text: 错误响应文本 + mode: 模式(geminicli 或 antigravity) + + Returns: + 冷却截止时间(Unix时间戳),如果解析失败则返回None + """ + try: + error_data = json.loads(error_text) + cooldown_until = parse_quota_reset_timestamp(error_data) + if cooldown_until: + log.info( + f"[{mode.upper()}] 检测到quota冷却时间: " + f"{datetime.fromtimestamp(cooldown_until, timezone.utc).isoformat()}" + ) + return cooldown_until + except Exception as parse_err: + log.debug(f"[{mode.upper()}] Failed to parse cooldown time: {parse_err}") + return None + + +# ==================== 流式响应收集 ==================== + +async def collect_streaming_response(stream_generator) -> Response: + """ + 将Gemini流式响应收集为一条完整的非流式响应 + + Args: + stream_generator: 流式响应生成器,产生 "data: {json}" 格式的行或Response对象 + + Returns: + Response: 合并后的完整响应对象 + + Example: + >>> async for line in stream_generator: + ... # line format: "data: {...}" or Response object + >>> response = await collect_streaming_response(stream_generator) + """ + # 初始化响应结构 + merged_response = { + "response": { + "candidates": [{ + "content": { + "parts": [], + "role": "model" + }, + "finishReason": None, + "safetyRatings": [], + "citationMetadata": None + }], + "usageMetadata": { + "promptTokenCount": 0, + "candidatesTokenCount": 0, + "totalTokenCount": 0 + } + } + } + + collected_text = [] # 用于收集文本内容 + collected_thought_text = [] # 用于收集思维链内容 + collected_other_parts = [] # 用于收集其他类型的parts(图片、文件等) + has_data = False + line_count = 0 + + log.debug("[STREAM COLLECTOR] Starting to collect streaming response") + + try: + async for line in stream_generator: + line_count += 1 + + # 如果收到的是Response对象(错误),直接返回 + if isinstance(line, Response): + log.debug(f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}") + return line + + # 处理 bytes 类型 + if isinstance(line, bytes): + line_str = line.decode('utf-8', errors='ignore') + log.debug(f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}") + elif isinstance(line, str): + line_str = line + log.debug(f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}") + else: + log.debug(f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}") + continue + + # 解析流式数据行 + if not line_str.startswith("data: "): + log.debug(f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}") + continue + + raw = line_str[6:].strip() + if raw == "[DONE]": + log.debug("[STREAM COLLECTOR] Received [DONE] marker") + break + + try: + log.debug(f"[STREAM COLLECTOR] Parsing JSON: {raw[:200]}") + chunk = json.loads(raw) + has_data = True + log.debug(f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}") + + # 提取响应对象 + response_obj = chunk.get("response", {}) + if not response_obj: + log.debug("[STREAM COLLECTOR] No 'response' key in chunk, trying direct access") + response_obj = chunk # 尝试直接使用chunk + + candidates = response_obj.get("candidates", []) + log.debug(f"[STREAM COLLECTOR] Found {len(candidates)} candidates") + if not candidates: + log.debug(f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}") + continue + + candidate = candidates[0] + + # 收集文本内容 + content = candidate.get("content", {}) + parts = content.get("parts", []) + log.debug(f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate") + + for part in parts: + if not isinstance(part, dict): + continue + + # 处理文本内容 + text = part.get("text", "") + if text: + # 区分普通文本和思维链 + if part.get("thought", False): + collected_thought_text.append(text) + log.debug(f"[STREAM COLLECTOR] Collected thought text: {text[:100]}") + else: + collected_text.append(text) + log.debug(f"[STREAM COLLECTOR] Collected regular text: {text[:100]}") + # 处理非文本内容(图片、文件等) + elif "inlineData" in part or "fileData" in part or "executableCode" in part or "codeExecutionResult" in part: + collected_other_parts.append(part) + log.debug(f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}") + + # 收集其他信息(使用最后一个块的值) + if candidate.get("finishReason"): + merged_response["response"]["candidates"][0]["finishReason"] = candidate["finishReason"] + + if candidate.get("safetyRatings"): + merged_response["response"]["candidates"][0]["safetyRatings"] = candidate["safetyRatings"] + + if candidate.get("citationMetadata"): + merged_response["response"]["candidates"][0]["citationMetadata"] = candidate["citationMetadata"] + + # 更新使用元数据 + usage = response_obj.get("usageMetadata", {}) + if usage: + merged_response["response"]["usageMetadata"].update(usage) + + except json.JSONDecodeError as e: + log.debug(f"[STREAM COLLECTOR] Failed to parse JSON chunk: {e}") + continue + except Exception as e: + log.debug(f"[STREAM COLLECTOR] Error processing chunk: {e}") + continue + + except Exception as e: + log.error(f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}") + return Response( + content=json.dumps({"error": f"收集流式响应失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + log.debug(f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}") + + # 如果没有收集到任何数据,返回错误 + if not has_data: + log.error(f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines") + return Response( + content=json.dumps({"error": "No data collected from stream"}), + status_code=500, + media_type="application/json" + ) + + # 组装最终的parts + final_parts = [] + + # 先添加思维链内容(如果有) + if collected_thought_text: + final_parts.append({ + "text": "".join(collected_thought_text), + "thought": True + }) + + # 再添加普通文本内容 + if collected_text: + final_parts.append({ + "text": "".join(collected_text) + }) + + # 添加其他类型的parts(图片、文件等) + final_parts.extend(collected_other_parts) + + # 如果没有任何内容,添加空文本 + if not final_parts: + final_parts.append({"text": ""}) + + merged_response["response"]["candidates"][0]["content"]["parts"] = final_parts + + log.info(f"[STREAM COLLECTOR] Collected {len(collected_text)} text chunks, {len(collected_thought_text)} thought chunks, and {len(collected_other_parts)} other parts") + + # 去掉嵌套的 "response" 包装(Antigravity格式 -> 标准Gemini格式) + if "response" in merged_response and "candidates" not in merged_response: + log.debug(f"[STREAM COLLECTOR] 展开response包装") + merged_response = merged_response["response"] + + # 返回纯JSON格式 + return Response( + content=json.dumps(merged_response, ensure_ascii=False).encode('utf-8'), + status_code=200, + headers={}, + media_type="application/json" + ) + + +def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]: + """ + 从Google API错误响应中提取quota重置时间戳 + + Args: + error_response: Google API返回的错误响应字典 + + Returns: + Unix时间戳(秒),如果无法解析则返回None + + 示例错误响应: + { + "error": { + "code": 429, + "message": "You have exhausted your capacity...", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "QUOTA_EXHAUSTED", + "metadata": { + "quotaResetTimeStamp": "2025-11-30T14:57:24Z", + "quotaResetDelay": "13h19m1.20964964s" + } + } + ] + } + } + """ + try: + details = error_response.get("error", {}).get("details", []) + + for detail in details: + if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": + reset_timestamp_str = detail.get("metadata", {}).get("quotaResetTimeStamp") + + if reset_timestamp_str: + if reset_timestamp_str.endswith("Z"): + reset_timestamp_str = reset_timestamp_str.replace("Z", "+00:00") + + reset_dt = datetime.fromisoformat(reset_timestamp_str) + if reset_dt.tzinfo is None: + reset_dt = reset_dt.replace(tzinfo=timezone.utc) + + return reset_dt.astimezone(timezone.utc).timestamp() + + return None + + except Exception: + return None + +def get_model_group(model_name: str) -> str: + """ + 获取模型组,用于 GCLI CD 机制。 + + Args: + model_name: 模型名称 + + Returns: + "pro" 或 "flash" + + 说明: + - pro 组: gemini-2.5-pro, gemini-3-pro-preview 共享额度 + - flash 组: gemini-2.5-flash 单独额度 + """ + + # 判断模型组 + if "flash" in model_name.lower(): + return "flash" + else: + # pro 模型(包括 gemini-2.5-pro 和 gemini-3-pro-preview) + return "pro" \ No newline at end of file diff --git a/src/auth.py b/src/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2d0f4292bcaa50badf567e5be4117cc0229699 --- /dev/null +++ b/src/auth.py @@ -0,0 +1,1242 @@ +""" +认证API模块 +""" + +import asyncio +import json +import secrets +import socket +import threading +import time +import uuid +from datetime import timezone +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Dict, List, Optional +from urllib.parse import parse_qs, urlparse + +from config import get_config_value, get_antigravity_api_url, get_code_assist_endpoint +from log import log + +from .google_oauth_api import ( + Credentials, + Flow, + enable_required_apis, + fetch_project_id, + get_user_projects, + select_default_project, +) +from .storage_adapter import get_storage_adapter +from .utils import ( + ANTIGRAVITY_CLIENT_ID, + ANTIGRAVITY_CLIENT_SECRET, + ANTIGRAVITY_SCOPES, + ANTIGRAVITY_USER_AGENT, + CALLBACK_HOST, + CLIENT_ID, + CLIENT_SECRET, + SCOPES, + GEMINICLI_USER_AGENT, + TOKEN_URL, +) + + +async def get_callback_port(): + """获取OAuth回调端口""" + return int(await get_config_value("oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT")) + + +def _prepare_credentials_data(credentials: Credentials, project_id: str, mode: str = "geminicli") -> Dict[str, Any]: + """准备凭证数据字典(统一函数)""" + if mode == "antigravity": + creds_data = { + "client_id": ANTIGRAVITY_CLIENT_ID, + "client_secret": ANTIGRAVITY_CLIENT_SECRET, + "token": credentials.access_token, + "refresh_token": credentials.refresh_token, + "scopes": ANTIGRAVITY_SCOPES, + "token_uri": TOKEN_URL, + "project_id": project_id, + } + else: + creds_data = { + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "token": credentials.access_token, + "refresh_token": credentials.refresh_token, + "scopes": SCOPES, + "token_uri": TOKEN_URL, + "project_id": project_id, + } + + if credentials.expires_at: + if credentials.expires_at.tzinfo is None: + expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc) + else: + expiry_utc = credentials.expires_at + creds_data["expiry"] = expiry_utc.isoformat() + + return creds_data + + +def _generate_random_project_id() -> str: + """生成随机project_id(antigravity模式使用)""" + random_id = uuid.uuid4().hex[:8] + return f"projects/random-{random_id}/locations/global" + + +def _cleanup_auth_flow_server(state: str): + """清理认证流程的服务器资源""" + if state in auth_flows: + flow_data_to_clean = auth_flows[state] + try: + if flow_data_to_clean.get("server"): + server = flow_data_to_clean["server"] + port = flow_data_to_clean.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.debug(f"关闭服务器时出错: {e}") + del auth_flows[state] + + +class _OAuthLibPatcher: + """oauthlib参数验证补丁的上下文管理器""" + def __init__(self): + import oauthlib.oauth2.rfc6749.parameters + self.module = oauthlib.oauth2.rfc6749.parameters + self.original_validate = None + + def __enter__(self): + self.original_validate = self.module.validate_token_parameters + + def patched_validate(params): + try: + return self.original_validate(params) + except Warning: + pass + + self.module.validate_token_parameters = patched_validate + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.original_validate: + self.module.validate_token_parameters = self.original_validate + + +# 全局状态管理 - 严格限制大小 +auth_flows = {} # 存储进行中的认证流程 +MAX_AUTH_FLOWS = 20 # 严格限制最大认证流程数 + + +def cleanup_auth_flows_for_memory(): + """清理认证流程以释放内存""" + global auth_flows + cleanup_expired_flows() + # 如果还是太多,强制清理一些旧的流程 + if len(auth_flows) > 10: + # 按创建时间排序,保留最新的10个 + sorted_flows = sorted( + auth_flows.items(), key=lambda x: x[1].get("created_at", 0), reverse=True + ) + new_auth_flows = dict(sorted_flows[:10]) + + # 清理被移除的流程 + for state, flow_data in auth_flows.items(): + if state not in new_auth_flows: + try: + if flow_data.get("server"): + server = flow_data["server"] + port = flow_data.get("callback_port") + async_shutdown_server(server, port) + except Exception: + pass + flow_data.clear() + + auth_flows = new_auth_flows + log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程") + + return len(auth_flows) + + +async def find_available_port(start_port: int = None) -> int: + """动态查找可用端口""" + if start_port is None: + start_port = await get_callback_port() + + # 首先尝试默认端口 + for port in range(start_port, start_port + 100): # 尝试100个端口 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", port)) + log.info(f"找到可用端口: {port}") + return port + except OSError: + continue + + # 如果都不可用,让系统自动分配端口 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("0.0.0.0", 0)) + port = s.getsockname()[1] + log.info(f"系统分配可用端口: {port}") + return port + except OSError as e: + log.error(f"无法找到可用端口: {e}") + raise RuntimeError("无法找到可用端口") + + +def create_callback_server(port: int) -> HTTPServer: + """创建指定端口的回调服务器,优化快速关闭""" + try: + # 服务器监听0.0.0.0 + server = HTTPServer(("0.0.0.0", port), AuthCallbackHandler) + + # 设置socket选项以支持快速关闭 + server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # 设置较短的超时时间 + server.timeout = 1.0 + + log.info(f"创建OAuth回调服务器,监听端口: {port}") + return server + except OSError as e: + log.error(f"创建端口{port}的服务器失败: {e}") + raise + + +class AuthCallbackHandler(BaseHTTPRequestHandler): + """OAuth回调处理器""" + + def do_GET(self): + query_components = parse_qs(urlparse(self.path).query) + code = query_components.get("code", [None])[0] + state = query_components.get("state", [None])[0] + + log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}") + + if code and state and state in auth_flows: + # 更新流程状态 + auth_flows[state]["code"] = code + auth_flows[state]["completed"] = True + + log.info(f"OAuth回调成功处理: state={state}") + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + # 成功页面 + self.wfile.write( + b"

OAuth authentication successful!

You can close this window. Please return to the original page and click 'Get Credentials' button.

" + ) + else: + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"

Authentication failed.

Please try again.

") + + def log_message(self, format, *args): + # 减少日志噪音 + pass + + +async def create_auth_url( + project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """创建认证URL,支持动态端口分配""" + try: + # 动态分配端口 + callback_port = await find_available_port() + callback_url = f"http://{CALLBACK_HOST}:{callback_port}" + + # 立即启动回调服务器 + try: + callback_server = create_callback_server(callback_port) + # 在后台线程中运行服务器 + server_thread = threading.Thread( + target=callback_server.serve_forever, + daemon=True, + name=f"OAuth-Server-{callback_port}", + ) + server_thread.start() + log.info(f"OAuth回调服务器已启动,端口: {callback_port}") + except Exception as e: + log.error(f"启动回调服务器失败: {e}") + return { + "success": False, + "error": f"无法启动OAuth回调服务器,端口{callback_port}: {str(e)}", + } + + # 创建OAuth流程 + # 根据模式选择配置 + if mode == "antigravity": + client_id = ANTIGRAVITY_CLIENT_ID + client_secret = ANTIGRAVITY_CLIENT_SECRET + scopes = ANTIGRAVITY_SCOPES + else: + client_id = CLIENT_ID + client_secret = CLIENT_SECRET + scopes = SCOPES + + flow = Flow( + client_id=client_id, + client_secret=client_secret, + scopes=scopes, + redirect_uri=callback_url, + ) + + # 生成状态标识符,包含用户会话信息 + if user_session: + state = f"{user_session}_{str(uuid.uuid4())}" + else: + state = str(uuid.uuid4()) + + # 生成认证URL + auth_url = flow.get_auth_url(state=state) + + # 严格控制认证流程数量 - 超过限制时立即清理最旧的 + if len(auth_flows) >= MAX_AUTH_FLOWS: + # 清理最旧的认证流程 + oldest_state = min(auth_flows.keys(), key=lambda k: auth_flows[k].get("created_at", 0)) + try: + # 清理服务器资源 + old_flow = auth_flows[oldest_state] + if old_flow.get("server"): + server = old_flow["server"] + port = old_flow.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}") + + del auth_flows[oldest_state] + log.debug(f"Removed oldest auth flow: {oldest_state}") + + # 保存流程状态 + auth_flows[state] = { + "flow": flow, + "project_id": project_id, # 可能为None,稍后在回调时确定 + "user_session": user_session, + "callback_port": callback_port, # 存储分配的端口 + "callback_url": callback_url, # 存储完整回调URL + "server": callback_server, # 存储服务器实例 + "server_thread": server_thread, # 存储服务器线程 + "code": None, + "completed": False, + "created_at": time.time(), + "auto_project_detection": project_id is None, # 标记是否需要自动检测项目ID + "mode": mode, # 凭证模式 + } + + # 清理过期的流程(30分钟) + cleanup_expired_flows() + + log.info(f"OAuth流程已创建: state={state}, project_id={project_id}") + log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}") + log.info(f"为此认证流程分配的端口: {callback_port}") + + return { + "auth_url": auth_url, + "state": state, + "callback_port": callback_port, + "success": True, + "auto_project_detection": project_id is None, + "detected_project_id": project_id, + } + + except Exception as e: + log.error(f"创建认证URL失败: {e}") + return {"success": False, "error": str(e)} + + +def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]: + """同步等待OAuth回调完成,使用对应流程的专用服务器""" + if state not in auth_flows: + log.error(f"未找到状态为 {state} 的认证流程") + return None + + flow_data = auth_flows[state] + callback_port = flow_data["callback_port"] + + # 服务器已经在create_auth_url时启动了,这里只需要等待 + log.info(f"等待OAuth回调完成,端口: {callback_port}") + + # 等待回调完成 + start_time = time.time() + while time.time() - start_time < timeout: + if flow_data.get("code"): + log.info("OAuth回调成功完成") + return flow_data["code"] + time.sleep(0.5) # 每0.5秒检查一次 + + # 刷新flow_data引用 + if state in auth_flows: + flow_data = auth_flows[state] + + log.warning(f"等待OAuth回调超时 ({timeout}秒)") + return None + + +async def complete_auth_flow( + project_id: Optional[str] = None, user_session: str = None +) -> Dict[str, Any]: + """完成认证流程并保存凭证,支持自动检测项目ID""" + try: + # 查找对应的认证流程 + state = None + flow_data = None + + # 如果指定了project_id,先尝试匹配指定的项目 + if project_id: + for s, data in auth_flows.items(): + if data["project_id"] == project_id: + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + break + # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的 + elif not state: + state = s + flow_data = data + + # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程 + if not state: + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + break + # 使用第一个找到的需要自动检测的流程 + elif not state: + state = s + flow_data = data + + if not state or not flow_data: + return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + + if not project_id: + project_id = flow_data.get("project_id") + if not project_id: + return { + "success": False, + "error": "缺少项目ID,请指定项目ID", + "requires_manual_project_id": True, + } + + flow = flow_data["flow"] + + # 如果还没有授权码,需要等待回调 + if not flow_data.get("code"): + log.info(f"等待用户完成OAuth授权 (state: {state})") + auth_code = wait_for_callback_sync(state) + + if not auth_code: + return { + "success": False, + "error": "未接收到授权回调,请确保完成了浏览器中的OAuth认证", + } + + # 更新流程数据 + auth_flows[state]["code"] = auth_code + auth_flows[state]["completed"] = True + else: + auth_code = flow_data["code"] + + # 使用认证代码获取凭证 + with _OAuthLibPatcher(): + try: + credentials = await flow.exchange_code(auth_code) + # credentials 已经在 exchange_code 中获得 + + # 如果需要自动检测项目ID且没有提供项目ID + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("尝试通过API获取用户项目列表...") + log.info(f"使用的token: {credentials.access_token[:20]}...") + log.info(f"Token过期时间: {credentials.expires_at}") + user_projects = await get_user_projects(credentials) + + if user_projects: + # 如果只有一个项目,自动使用 + if len(user_projects) == 1: + # Google API returns projectId in camelCase + project_id = user_projects[0].get("projectId") + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择唯一项目: {project_id}") + # 如果有多个项目,尝试选择默认项目 + else: + project_id = await select_default_project(user_projects) + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择默认项目: {project_id}") + else: + # 返回项目列表让用户选择 + return { + "success": False, + "error": "请从以下项目中选择一个", + "requires_project_selection": True, + "available_projects": [ + { + # Google API returns projectId in camelCase + "project_id": p.get("projectId"), + "name": p.get("displayName") or p.get("projectId"), + "projectNumber": p.get("projectNumber"), + } + for p in user_projects + ], + } + else: + # 如果无法获取项目列表,提示手动输入 + return { + "success": False, + "error": "无法获取您的项目列表,请手动指定项目ID", + "requires_manual_project_id": True, + } + + # 如果仍然没有项目ID,返回错误 + if not project_id: + return { + "success": False, + "error": "缺少项目ID,请指定项目ID", + "requires_manual_project_id": True, + } + + # 保存凭证 + saved_filename = await save_credentials(credentials, project_id) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": flow_data.get("auto_project_detection", False), + } + + except Exception as e: + log.error(f"获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def asyncio_complete_auth_flow( + project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """异步完成认证流程,支持自动检测项目ID""" + try: + log.info( + f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}" + ) + + # 查找对应的认证流程 + state = None + flow_data = None + + log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}") + + # 如果指定了project_id,先尝试匹配指定的项目 + if project_id: + log.info(f"尝试匹配指定的项目ID: {project_id}") + for s, data in auth_flows.items(): + if data["project_id"] == project_id: + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + log.info(f"找到匹配的用户会话: {s}") + break + # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的 + elif not state: + state = s + flow_data = data + log.info(f"找到匹配的项目ID: {s}") + + # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程 + if not state: + log.info("没有找到指定项目的流程,查找自动检测流程") + # 首先尝试找到已完成的流程(有授权码的) + completed_flows = [] + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + if user_session and data.get("user_session") == user_session: + if data.get("code"): # 优先选择已完成的 + completed_flows.append((s, data, data.get("created_at", 0))) + + # 如果有已完成的流程,选择最新的 + if completed_flows: + completed_flows.sort(key=lambda x: x[2], reverse=True) # 按时间倒序 + state, flow_data, _ = completed_flows[0] + log.info(f"找到已完成的最新认证流程: {state}") + else: + # 如果没有已完成的,找最新的未完成流程 + pending_flows = [] + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + if user_session and data.get("user_session") == user_session: + pending_flows.append((s, data, data.get("created_at", 0))) + elif not user_session: + pending_flows.append((s, data, data.get("created_at", 0))) + + if pending_flows: + pending_flows.sort(key=lambda x: x[2], reverse=True) # 按时间倒序 + state, flow_data, _ = pending_flows[0] + log.info(f"找到最新的待完成认证流程: {state}") + + if not state or not flow_data: + log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}") + log.debug(f"当前所有flow_data: {list(auth_flows.keys())}") + return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + + log.info(f"找到认证流程: state={state}") + log.info( + f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}" + ) + log.info(f"传入的project_id参数: {project_id}") + + # 如果需要自动检测项目ID且没有提供项目ID + log.info( + f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}" + ) + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("跳过自动检测项目ID,进入等待阶段") + elif not project_id: + log.info("进入project_id检查分支") + project_id = flow_data.get("project_id") + if not project_id: + log.error("缺少项目ID,返回错误") + return { + "success": False, + "error": "缺少项目ID,请指定项目ID", + "requires_manual_project_id": True, + } + else: + log.info(f"使用提供的项目ID: {project_id}") + + # 检查是否已经有授权码 + log.info("开始检查OAuth授权码...") + log.info(f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}") + log.info(f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}") + max_wait_time = 60 # 最多等待60秒 + wait_interval = 1 # 每秒检查一次 + waited = 0 + + while waited < max_wait_time: + if flow_data.get("code"): + log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)") + break + + # 每5秒输出一次提示 + if waited % 5 == 0 and waited > 0: + log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)") + log.debug(f"当前state: {state}, flow_data keys: {list(flow_data.keys())}") + + # 异步等待 + await asyncio.sleep(wait_interval) + waited += wait_interval + + # 刷新flow_data引用,因为可能被回调更新了 + if state in auth_flows: + flow_data = auth_flows[state] + + if not flow_data.get("code"): + log.error(f"等待OAuth回调超时,等待了{waited}秒") + return { + "success": False, + "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面", + } + + flow = flow_data["flow"] + auth_code = flow_data["code"] + + log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}") + + # 使用认证代码获取凭证 + with _OAuthLibPatcher(): + try: + log.info("调用flow.exchange_code...") + credentials = await flow.exchange_code(auth_code) + log.info( + f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}..." + ) + + log.info( + f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}" + ) + + # 检查凭证模式 + cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + if cred_mode == "antigravity": + log.info("Antigravity模式:从API获取project_id...") + # 使用API获取project_id + antigravity_url = await get_antigravity_api_url() + project_id = await fetch_project_id( + credentials.access_token, + ANTIGRAVITY_USER_AGENT, + antigravity_url + ) + if project_id: + log.info(f"成功从API获取project_id: {project_id}") + else: + log.warning("无法从API获取project_id,回退到随机生成") + project_id = _generate_random_project_id() + log.info(f"生成的随机project_id: {project_id}") + + # 保存antigravity凭证 + saved_filename = await save_credentials(credentials, project_id, mode="antigravity") + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("Antigravity OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": False, + "mode": "antigravity", + } + + # 如果需要自动检测项目ID且没有提供项目ID(标准模式) + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("标准模式:从API获取project_id...") + # 使用API获取project_id(使用标准模式的User-Agent) + code_assist_url = await get_code_assist_endpoint() + project_id = await fetch_project_id( + credentials.access_token, + GEMINICLI_USER_AGENT, + code_assist_url + ) + if project_id: + flow_data["project_id"] = project_id + log.info(f"成功从API获取project_id: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + else: + log.warning("无法从API获取project_id,回退到项目列表获取方式") + # 回退到原来的项目列表获取方式 + user_projects = await get_user_projects(credentials) + + if user_projects: + # 如果只有一个项目,自动使用 + if len(user_projects) == 1: + # Google API returns projectId in camelCase + project_id = user_projects[0].get("projectId") + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择唯一项目: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + # 如果有多个项目,尝试选择默认项目 + else: + project_id = await select_default_project(user_projects) + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择默认项目: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + else: + # 返回项目列表让用户选择 + return { + "success": False, + "error": "请从以下项目中选择一个", + "requires_project_selection": True, + "available_projects": [ + { + # Google API returns projectId in camelCase + "project_id": p.get("projectId"), + "name": p.get("displayName") or p.get("projectId"), + "projectNumber": p.get("projectNumber"), + } + for p in user_projects + ], + } + else: + # 如果无法获取项目列表,提示手动输入 + return { + "success": False, + "error": "无法获取您的项目列表,请手动指定项目ID", + "requires_manual_project_id": True, + } + elif project_id: + # 如果已经有项目ID(手动提供或环境检测),也尝试启用API服务 + log.info("正在为已提供的项目ID自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + + # 如果仍然没有项目ID,返回错误 + if not project_id: + return { + "success": False, + "error": "缺少项目ID,请指定项目ID", + "requires_manual_project_id": True, + } + + # 保存凭证 + saved_filename = await save_credentials(credentials, project_id) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": flow_data.get("auto_project_detection", False), + } + + except Exception as e: + log.error(f"获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"异步完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def complete_auth_flow_from_callback_url( + callback_url: str, project_id: Optional[str] = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """从回调URL直接完成认证流程,无需启动本地服务器""" + try: + log.info(f"开始从回调URL完成认证: {callback_url}") + + # 解析回调URL + parsed_url = urlparse(callback_url) + query_params = parse_qs(parsed_url.query) + + # 验证必要参数 + if "state" not in query_params or "code" not in query_params: + return {"success": False, "error": "回调URL缺少必要参数 (state 或 code)"} + + state = query_params["state"][0] + code = query_params["code"][0] + + log.info(f"从URL解析到: state={state}, code=xxx...") + + # 检查是否有对应的认证流程 + if state not in auth_flows: + return { + "success": False, + "error": f"未找到对应的认证流程,请先启动认证 (state: {state})", + } + + flow_data = auth_flows[state] + flow = flow_data["flow"] + + # 构造回调URL(使用flow中存储的redirect_uri) + redirect_uri = flow.redirect_uri + log.info(f"使用redirect_uri: {redirect_uri}") + + try: + # 使用authorization code获取token + credentials = await flow.exchange_code(code) + log.info("成功获取访问令牌") + + # 检查凭证模式 + cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + if cred_mode == "antigravity": + log.info("Antigravity模式(从回调URL):从API获取project_id...") + # 使用API获取project_id + antigravity_url = await get_antigravity_api_url() + project_id = await fetch_project_id( + credentials.access_token, + ANTIGRAVITY_USER_AGENT, + antigravity_url + ) + if project_id: + log.info(f"成功从API获取project_id: {project_id}") + else: + log.warning("无法从API获取project_id,回退到随机生成") + project_id = _generate_random_project_id() + log.info(f"生成的随机project_id: {project_id}") + + # 保存antigravity凭证 + saved_filename = await save_credentials(credentials, project_id, mode="antigravity") + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("从回调URL完成Antigravity OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": False, + "mode": "antigravity", + } + + # 标准模式的项目ID处理逻辑 + detected_project_id = None + auto_detected = False + + if not project_id: + # 尝试使用fetch_project_id自动获取项目ID + try: + log.info("标准模式:从API获取project_id...") + code_assist_url = await get_code_assist_endpoint() + detected_project_id = await fetch_project_id( + credentials.access_token, + GEMINICLI_USER_AGENT, + code_assist_url + ) + if detected_project_id: + auto_detected = True + log.info(f"成功从API获取project_id: {detected_project_id}") + else: + log.warning("无法从API获取project_id,回退到项目列表获取方式") + # 回退到原来的项目列表获取方式 + projects = await get_user_projects(credentials) + if projects: + if len(projects) == 1: + # 只有一个项目,自动使用 + # Google API returns projectId in camelCase + detected_project_id = projects[0]["projectId"] + auto_detected = True + log.info(f"自动检测到唯一项目ID: {detected_project_id}") + else: + # 多个项目,自动选择第一个 + # Google API returns projectId in camelCase + detected_project_id = projects[0]["projectId"] + auto_detected = True + log.info( + f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}" + ) + log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}") + else: + # 没有项目访问权限 + return { + "success": False, + "error": "未检测到可访问的项目,请检查权限或手动指定项目ID", + "requires_manual_project_id": True, + } + except Exception as e: + log.warning(f"自动检测项目ID失败: {e}") + return { + "success": False, + "error": f"自动检测项目ID失败: {str(e)},请手动指定项目ID", + "requires_manual_project_id": True, + } + else: + detected_project_id = project_id + + # 启用必需的API服务 + if detected_project_id: + try: + log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...") + await enable_required_apis(credentials, detected_project_id) + except Exception as e: + log.warning(f"启用API服务失败: {e}") + + # 保存凭证 + saved_filename = await save_credentials(credentials, detected_project_id) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, detected_project_id, mode="geminicli") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("从回调URL完成OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": auto_detected, + } + + except Exception as e: + log.error(f"从回调URL获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"从回调URL完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def save_credentials(creds: Credentials, project_id: str, mode: str = "geminicli") -> str: + """通过统一存储系统保存凭证""" + # 生成文件名(使用project_id和时间戳) + timestamp = int(time.time()) + + # antigravity模式使用特殊前缀 + if mode == "antigravity": + filename = f"ag_{project_id}-{timestamp}.json" + else: + filename = f"{project_id}-{timestamp}.json" + + # 准备凭证数据 + creds_data = _prepare_credentials_data(creds, project_id, mode) + + # 通过存储适配器保存 + storage_adapter = await get_storage_adapter() + success = await storage_adapter.store_credential(filename, creds_data, mode=mode) + + if success: + # 创建默认状态记录 + try: + default_state = { + "error_codes": [], + "disabled": False, + "last_success": time.time(), + "user_email": None, + } + await storage_adapter.update_credential_state(filename, default_state, mode=mode) + log.info(f"凭证和状态已保存到: {filename} (mode={mode})") + except Exception as e: + log.warning(f"创建默认状态记录失败 {filename}: {e}") + + return filename + else: + raise Exception(f"保存凭证失败: {filename}") + + +def async_shutdown_server(server, port): + """异步关闭OAuth回调服务器,避免阻塞主流程""" + + def shutdown_server_async(): + try: + # 设置一个标志来跟踪关闭状态 + shutdown_completed = threading.Event() + + def do_shutdown(): + try: + server.shutdown() + server.server_close() + shutdown_completed.set() + log.info(f"已关闭端口 {port} 的OAuth回调服务器") + except Exception as e: + shutdown_completed.set() + log.debug(f"关闭服务器时出错: {e}") + + # 在单独线程中执行关闭操作 + shutdown_worker = threading.Thread(target=do_shutdown, daemon=True) + shutdown_worker.start() + + # 等待最多5秒,如果超时就放弃等待 + if shutdown_completed.wait(timeout=5): + log.debug(f"端口 {port} 服务器关闭完成") + else: + log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程") + + except Exception as e: + log.debug(f"异步关闭服务器时出错: {e}") + + # 在后台线程中关闭服务器,不阻塞主流程 + shutdown_thread = threading.Thread(target=shutdown_server_async, daemon=True) + shutdown_thread.start() + log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器") + + +def cleanup_expired_flows(): + """清理过期的认证流程""" + current_time = time.time() + EXPIRY_TIME = 600 # 10分钟过期 + + # 直接遍历删除,避免创建额外列表 + states_to_remove = [ + state + for state, flow_data in auth_flows.items() + if current_time - flow_data["created_at"] > EXPIRY_TIME + ] + + # 批量清理,提高效率 + cleaned_count = 0 + for state in states_to_remove: + flow_data = auth_flows.get(state) + if flow_data: + # 快速关闭可能存在的服务器 + try: + if flow_data.get("server"): + server = flow_data["server"] + port = flow_data.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}") + + # 显式清理流程数据,释放内存 + flow_data.clear() + del auth_flows[state] + cleaned_count += 1 + + if cleaned_count > 0: + log.info(f"清理了 {cleaned_count} 个过期的认证流程") + + # 更积极的垃圾回收触发条件 + if len(auth_flows) > 20: # 降低阈值 + import gc + + gc.collect() + log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}") + + +def get_auth_status(project_id: str) -> Dict[str, Any]: + """获取认证状态""" + for state, flow_data in auth_flows.items(): + if flow_data["project_id"] == project_id: + return { + "status": "completed" if flow_data["completed"] else "pending", + "state": state, + "created_at": flow_data["created_at"], + } + + return {"status": "not_found"} + + +# 鉴权功能 - 使用更小的数据结构 +auth_tokens = {} # 存储有效的认证令牌 +TOKEN_EXPIRY = 3600 # 1小时令牌过期时间 + + +async def verify_password(password: str) -> bool: + """验证密码(面板登录使用)""" + from config import get_panel_password + + correct_password = await get_panel_password() + return password == correct_password + + +def generate_auth_token() -> str: + """生成认证令牌""" + # 清理过期令牌 + cleanup_expired_tokens() + + token = secrets.token_urlsafe(32) + # 只存储创建时间 + auth_tokens[token] = time.time() + return token + + +def verify_auth_token(token: str) -> bool: + """验证认证令牌""" + if not token or token not in auth_tokens: + return False + + created_at = auth_tokens[token] + + # 检查令牌是否过期 (使用更短的过期时间) + if time.time() - created_at > TOKEN_EXPIRY: + del auth_tokens[token] + return False + + return True + + +def cleanup_expired_tokens(): + """清理过期的认证令牌""" + current_time = time.time() + expired_tokens = [ + token + for token, created_at in auth_tokens.items() + if current_time - created_at > TOKEN_EXPIRY + ] + + for token in expired_tokens: + del auth_tokens[token] + + if expired_tokens: + log.debug(f"清理了 {len(expired_tokens)} 个过期的认证令牌") + + +def invalidate_auth_token(token: str): + """使认证令牌失效""" + if token in auth_tokens: + del auth_tokens[token] + + +# 文件验证和处理功能 - 使用统一存储系统 +def validate_credential_content(content: str) -> Dict[str, Any]: + """验证凭证内容格式""" + try: + creds_data = json.loads(content) + + # 检查必要字段 + required_fields = ["client_id", "client_secret", "refresh_token", "token_uri"] + missing_fields = [field for field in required_fields if field not in creds_data] + + if missing_fields: + return {"valid": False, "error": f'缺少必要字段: {", ".join(missing_fields)}'} + + # 检查project_id + if "project_id" not in creds_data: + log.warning("认证文件缺少project_id字段") + + return {"valid": True, "data": creds_data} + + except json.JSONDecodeError as e: + return {"valid": False, "error": f"JSON格式错误: {str(e)}"} + except Exception as e: + return {"valid": False, "error": f"文件验证失败: {str(e)}"} + + +async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]: + """通过统一存储系统保存上传的凭证""" + try: + # 验证内容格式 + validation = validate_credential_content(content) + if not validation["valid"]: + return {"success": False, "error": validation["error"]} + + creds_data = validation["data"] + + # 生成文件名 + project_id = creds_data.get("project_id", "unknown") + timestamp = int(time.time()) + + # 从原文件名中提取有用信息 + import os + + base_name = os.path.splitext(original_filename)[0] + filename = f"{base_name}-{timestamp}.json" + + # 通过存储适配器保存 + storage_adapter = await get_storage_adapter() + success = await storage_adapter.store_credential(filename, creds_data) + + if success: + log.info(f"凭证文件已上传保存: {filename}") + return {"success": True, "file_path": filename, "project_id": project_id} + else: + return {"success": False, "error": "保存到存储系统失败"} + + except Exception as e: + log.error(f"保存上传文件失败: {e}") + return {"success": False, "error": str(e)} + + +async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str, Any]: + """批量上传凭证文件到统一存储系统""" + results = [] + success_count = 0 + + for file_data in files_data: + filename = file_data.get("filename", "unknown.json") + content = file_data.get("content", "") + + result = await save_uploaded_credential(content, filename) + result["filename"] = filename + results.append(result) + + if result["success"]: + success_count += 1 + + return {"uploaded_count": success_count, "total_count": len(files_data), "results": results} diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..d757f16ff5a789c756becaddbc9c110869aac908 --- /dev/null +++ b/src/converter/anthropic2gemini.py @@ -0,0 +1,931 @@ +""" +Anthropic 到 Gemini 格式转换器 + +提供请求体、响应和流式转换的完整功能。 +""" +from __future__ import annotations + +import json +import os +import uuid +from typing import Any, AsyncIterator, Dict, List, Optional + +from log import log +from src.converter.utils import merge_system_messages + +from src.converter.thoughtSignature_fix import ( + encode_tool_id_with_signature, + decode_tool_id_and_signature +) + +DEFAULT_TEMPERATURE = 0.4 +_DEBUG_TRUE = {"1", "true", "yes", "on"} + + +# ============================================================================ +# 请求验证和提取 +# ============================================================================ + + +def _anthropic_debug_enabled() -> bool: + """检查是否启用 Anthropic 调试模式""" + return str(os.getenv("ANTHROPIC_DEBUG", "true")).strip().lower() in _DEBUG_TRUE + + +def _is_non_whitespace_text(value: Any) -> bool: + """ + 判断文本是否包含"非空白"内容。 + + 说明:下游(Antigravity/Claude 兼容层)会对纯 text 内容块做校验: + - text 不能为空字符串 + - text 不能仅由空白字符(空格/换行/制表等)组成 + """ + if value is None: + return False + try: + return bool(str(value).strip()) + except Exception: + return False + + +def _remove_nulls_for_tool_input(value: Any) -> Any: + """ + 递归移除 dict/list 中值为 null/None 的字段/元素。 + + 背景:Roo/Kilo 在 Anthropic native tool 路径下,若收到 tool_use.input 中包含 null, + 可能会把 null 当作真实入参执行(例如"在 null 中搜索")。 + """ + if isinstance(value, dict): + cleaned: Dict[str, Any] = {} + for k, v in value.items(): + if v is None: + continue + cleaned[k] = _remove_nulls_for_tool_input(v) + return cleaned + + if isinstance(value, list): + cleaned_list = [] + for item in value: + if item is None: + continue + cleaned_list.append(_remove_nulls_for_tool_input(item)) + return cleaned_list + + return value + +# ============================================================================ +# 2. JSON Schema 清理 +# ============================================================================ + +def clean_json_schema(schema: Any) -> Any: + """ + 清理 JSON Schema,移除下游不支持的字段,并把验证要求追加到 description。 + """ + if not isinstance(schema, dict): + return schema + + # 下游不支持的字段 + unsupported_keys = { + "$schema", "$id", "$ref", "$defs", "definitions", "title", + "example", "examples", "readOnly", "writeOnly", "default", + "exclusiveMaximum", "exclusiveMinimum", "oneOf", "anyOf", "allOf", + "const", "additionalItems", "contains", "patternProperties", + "dependencies", "propertyNames", "if", "then", "else", + "contentEncoding", "contentMediaType", + } + + validation_fields = { + "minLength": "minLength", + "maxLength": "maxLength", + "minimum": "minimum", + "maximum": "maximum", + "minItems": "minItems", + "maxItems": "maxItems", + } + fields_to_remove = {"additionalProperties"} + + validations: List[str] = [] + for field, label in validation_fields.items(): + if field in schema: + validations.append(f"{label}: {schema[field]}") + + cleaned: Dict[str, Any] = {} + for key, value in schema.items(): + if key in unsupported_keys or key in fields_to_remove or key in validation_fields: + continue + + if key == "type" and isinstance(value, list): + # type: ["string", "null"] -> type: "string", nullable: true + has_null = any( + isinstance(t, str) and t.strip() and t.strip().lower() == "null" for t in value + ) + non_null_types = [ + t.strip() + for t in value + if isinstance(t, str) and t.strip() and t.strip().lower() != "null" + ] + + cleaned[key] = non_null_types[0] if non_null_types else "string" + if has_null: + cleaned["nullable"] = True + continue + + if key == "description" and validations: + cleaned[key] = f"{value} ({', '.join(validations)})" + elif isinstance(value, dict): + cleaned[key] = clean_json_schema(value) + elif isinstance(value, list): + cleaned[key] = [clean_json_schema(item) if isinstance(item, dict) else item for item in value] + else: + cleaned[key] = value + + if validations and "description" not in cleaned: + cleaned["description"] = f"Validation: {', '.join(validations)}" + + # 如果有 properties 但没有显式 type,则补齐为 object + if "properties" in cleaned and "type" not in cleaned: + cleaned["type"] = "object" + + return cleaned + + +# ============================================================================ +# 4. Tools 转换 +# ============================================================================ + +def convert_tools(anthropic_tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]: + """ + 将 Anthropic tools[] 转换为下游 tools(functionDeclarations)结构。 + """ + if not anthropic_tools: + return None + + gemini_tools: List[Dict[str, Any]] = [] + for tool in anthropic_tools: + name = tool.get("name", "nameless_function") + description = tool.get("description", "") + input_schema = tool.get("input_schema", {}) or {} + parameters = clean_json_schema(input_schema) + + gemini_tools.append( + { + "functionDeclarations": [ + { + "name": name, + "description": description, + "parameters": parameters, + } + ] + } + ) + + return gemini_tools or None + + +# ============================================================================ +# 5. Messages 转换 +# ============================================================================ + +def _extract_tool_result_output(content: Any) -> str: + """从 tool_result.content 中提取输出字符串""" + if isinstance(content, list): + if not content: + return "" + first = content[0] + if isinstance(first, dict) and first.get("type") == "text": + return str(first.get("text", "")) + return str(first) + if content is None: + return "" + return str(content) + + +def convert_messages_to_contents( + messages: List[Dict[str, Any]], + *, + include_thinking: bool = True +) -> List[Dict[str, Any]]: + """ + 将 Anthropic messages[] 转换为下游 contents[](role: user/model, parts: [])。 + + Args: + messages: Anthropic 格式的消息列表 + include_thinking: 是否包含 thinking 块 + """ + contents: List[Dict[str, Any]] = [] + + # 第一遍:构建 tool_use_id -> name 的映射 + tool_use_names: Dict[str, str] = {} + for msg in messages: + raw_content = msg.get("content", "") + if isinstance(raw_content, list): + for item in raw_content: + if isinstance(item, dict) and item.get("type") == "tool_use": + tool_id = item.get("id") + tool_name = item.get("name") + if tool_id and tool_name: + tool_use_names[str(tool_id)] = tool_name + + for msg in messages: + role = msg.get("role", "user") + + # system 消息已经由 merge_system_messages 处理,这里跳过 + if role == "system": + continue + + gemini_role = "model" if role == "assistant" else "user" + raw_content = msg.get("content", "") + + parts: List[Dict[str, Any]] = [] + if isinstance(raw_content, str): + if _is_non_whitespace_text(raw_content): + parts = [{"text": str(raw_content)}] + elif isinstance(raw_content, list): + for item in raw_content: + if not isinstance(item, dict): + if _is_non_whitespace_text(item): + parts.append({"text": str(item)}) + continue + + item_type = item.get("type") + if item_type == "thinking": + if not include_thinking: + continue + + thinking_text = item.get("thinking", "") + if thinking_text is None: + thinking_text = "" + + part: Dict[str, Any] = { + "text": str(thinking_text), + "thought": True, + } + + # 如果有 signature 则添加 + signature = item.get("signature") + if signature: + part["thoughtSignature"] = signature + + parts.append(part) + elif item_type == "redacted_thinking": + if not include_thinking: + continue + + thinking_text = item.get("thinking") + if thinking_text is None: + thinking_text = item.get("data", "") + + part_dict: Dict[str, Any] = { + "text": str(thinking_text or ""), + "thought": True, + } + + # 如果有 signature 则添加 + signature = item.get("signature") + if signature: + part_dict["thoughtSignature"] = signature + + parts.append(part_dict) + elif item_type == "text": + text = item.get("text", "") + if _is_non_whitespace_text(text): + parts.append({"text": str(text)}) + elif item_type == "image": + source = item.get("source", {}) or {} + if source.get("type") == "base64": + parts.append( + { + "inlineData": { + "mimeType": source.get("media_type", "image/png"), + "data": source.get("data", ""), + } + } + ) + elif item_type == "tool_use": + encoded_id = item.get("id") or "" + original_id, signature = decode_tool_id_and_signature(encoded_id) + + fc_part: Dict[str, Any] = { + "functionCall": { + "id": original_id, + "name": item.get("name"), + "args": item.get("input", {}) or {}, + } + } + + # 如果提取到签名则添加 + if signature: + fc_part["thoughtSignature"] = signature + + parts.append(fc_part) + elif item_type == "tool_result": + output = _extract_tool_result_output(item.get("content")) + encoded_tool_use_id = item.get("tool_use_id") or "" + # 解码获取原始ID(functionResponse不需要签名) + original_tool_use_id, _ = decode_tool_id_and_signature(encoded_tool_use_id) + + # 从 tool_result 获取 name,如果没有则从映射中查找 + func_name = item.get("name") + if not func_name and encoded_tool_use_id: + # 使用编码ID查找,因为映射中存储的是编码ID + func_name = tool_use_names.get(str(encoded_tool_use_id)) + if not func_name: + func_name = "unknown_function" + parts.append( + { + "functionResponse": { + "id": original_tool_use_id, # 使用解码后的ID以匹配functionCall + "name": func_name, + "response": {"output": output}, + } + } + ) + else: + parts.append({"text": json.dumps(item, ensure_ascii=False)}) + else: + if _is_non_whitespace_text(raw_content): + parts = [{"text": str(raw_content)}] + + if not parts: + continue + + contents.append({"role": gemini_role, "parts": parts}) + + return contents + + +def reorganize_tool_messages(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 重新组织消息,满足 tool_use/tool_result 约束。 + """ + tool_results: Dict[str, Dict[str, Any]] = {} + + for msg in contents: + for part in msg.get("parts", []) or []: + if isinstance(part, dict) and "functionResponse" in part: + tool_id = (part.get("functionResponse") or {}).get("id") + if tool_id: + tool_results[str(tool_id)] = part + + flattened: List[Dict[str, Any]] = [] + for msg in contents: + role = msg.get("role") + for part in msg.get("parts", []) or []: + flattened.append({"role": role, "parts": [part]}) + + new_contents: List[Dict[str, Any]] = [] + i = 0 + while i < len(flattened): + msg = flattened[i] + part = msg["parts"][0] + + if isinstance(part, dict) and "functionResponse" in part: + i += 1 + continue + + if isinstance(part, dict) and "functionCall" in part: + tool_id = (part.get("functionCall") or {}).get("id") + new_contents.append({"role": "model", "parts": [part]}) + + if tool_id is not None and str(tool_id) in tool_results: + new_contents.append({"role": "user", "parts": [tool_results[str(tool_id)]]}) + + i += 1 + continue + + new_contents.append(msg) + i += 1 + + return new_contents + + +# ============================================================================ +# 7. Generation Config 构建 +# ============================================================================ + +def build_generation_config(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据 Anthropic Messages 请求构造下游 generationConfig。 + + Returns: + generation_config: 生成配置字典 + """ + config: Dict[str, Any] = { + "topP": 1, + "candidateCount": 1, + "stopSequences": [ + "<|user|>", + "<|bot|>", + "<|context_request|>", + "<|endoftext|>", + "<|end_of_turn|>", + ], + } + + temperature = payload.get("temperature", None) + config["temperature"] = DEFAULT_TEMPERATURE if temperature is None else temperature + + top_p = payload.get("top_p", None) + if top_p is not None: + config["topP"] = top_p + + top_k = payload.get("top_k", None) + if top_k is not None: + config["topK"] = top_k + + max_tokens = payload.get("max_tokens") + if max_tokens is not None: + config["maxOutputTokens"] = max_tokens + + stop_sequences = payload.get("stop_sequences") + if isinstance(stop_sequences, list) and stop_sequences: + config["stopSequences"] = config["stopSequences"] + [str(s) for s in stop_sequences] + + return config + + +# ============================================================================ +# 8. 主要转换函数 +# ============================================================================ + +async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 将 Anthropic 格式请求体转换为 Gemini 格式请求体 + + 注意: 此函数只负责基础转换,不包含 normalize_gemini_request 中的处理 + (如 thinking config 自动设置、search tools、参数范围限制等) + + Args: + payload: Anthropic 格式的请求体字典 + + Returns: + Gemini 格式的请求体字典,包含: + - contents: 转换后的消息内容 + - generationConfig: 生成配置 + - systemInstruction: 系统指令 (如果有) + - tools: 工具定义 (如果有) + """ + # 处理连续的system消息(兼容性模式) + payload = await merge_system_messages(payload) + + # 提取和转换基础信息 + messages = payload.get("messages") or [] + if not isinstance(messages, list): + messages = [] + + # 构建生成配置 + generation_config = build_generation_config(payload) + + # 转换消息内容(始终包含thinking块,由响应端处理) + contents = convert_messages_to_contents(messages, include_thinking=True) + contents = reorganize_tool_messages(contents) + + # 转换工具 + tools = convert_tools(payload.get("tools")) + + # 构建基础请求数据 + gemini_request = { + "contents": contents, + "generationConfig": generation_config, + } + + # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 + if "systemInstruction" in payload: + gemini_request["systemInstruction"] = payload["systemInstruction"] + + if tools: + gemini_request["tools"] = tools + + return gemini_request + + +def gemini_to_anthropic_response( + gemini_response: Dict[str, Any], + model: str, + status_code: int = 200 +) -> Dict[str, Any]: + """ + 将 Gemini 格式非流式响应转换为 Anthropic 格式非流式响应 + + 注意: 如果收到的不是 200 开头的响应体,不做任何处理,直接转发 + + Args: + gemini_response: Gemini 格式的响应体字典 + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Returns: + Anthropic 格式的响应体字典,或原始响应 (如果状态码不是 2xx) + """ + # 非 2xx 状态码直接返回原始响应 + if not (200 <= status_code < 300): + return gemini_response + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_response: + response_data = gemini_response["response"] + else: + response_data = gemini_response + + # 提取候选结果 + candidate = response_data.get("candidates", [{}])[0] or {} + parts = candidate.get("content", {}).get("parts", []) or [] + + # 获取 usage metadata + usage_metadata = {} + if "usageMetadata" in response_data: + usage_metadata = response_data["usageMetadata"] + elif "usageMetadata" in candidate: + usage_metadata = candidate["usageMetadata"] + + # 转换内容块 + content = [] + has_tool_use = False + + for part in parts: + if not isinstance(part, dict): + continue + + # 处理 thinking 块 + if part.get("thought") is True: + block: Dict[str, Any] = {"type": "thinking", "thinking": part.get("text", "")} + signature = part.get("thoughtSignature") + if signature: + block["signature"] = signature + content.append(block) + continue + + # 处理文本块 + if "text" in part: + content.append({"type": "text", "text": part.get("text", "")}) + continue + + # 处理工具调用 + if "functionCall" in part: + has_tool_use = True + fc = part.get("functionCall", {}) or {} + original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" + signature = part.get("thoughtSignature") + encoded_id = encode_tool_id_with_signature(original_id, signature) + content.append( + { + "type": "tool_use", + "id": encoded_id, + "name": fc.get("name") or "", + "input": _remove_nulls_for_tool_input(fc.get("args", {}) or {}), + } + ) + continue + + # 处理图片 + if "inlineData" in part: + inline = part.get("inlineData", {}) or {} + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": inline.get("mimeType", "image/png"), + "data": inline.get("data", ""), + }, + } + ) + continue + + # 确定停止原因 + finish_reason = candidate.get("finishReason") + stop_reason = "tool_use" if has_tool_use else "end_turn" + if finish_reason == "MAX_TOKENS" and not has_tool_use: + stop_reason = "max_tokens" + + # 提取 token 使用情况 + input_tokens = usage_metadata.get("promptTokenCount", 0) if isinstance(usage_metadata, dict) else 0 + output_tokens = usage_metadata.get("candidatesTokenCount", 0) if isinstance(usage_metadata, dict) else 0 + + # 构建 Anthropic 响应 + message_id = f"msg_{uuid.uuid4().hex}" + + return { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stop_reason, + "stop_sequence": None, + "usage": { + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + }, + } + + +async def gemini_stream_to_anthropic_stream( + gemini_stream: AsyncIterator[bytes], + model: str, + status_code: int = 200 +) -> AsyncIterator[bytes]: + """ + 将 Gemini 格式流式响应转换为 Anthropic SSE 格式流式响应 + + 注意: 如果收到的不是 200 开头的响应体,不做任何处理,直接转发 + + Args: + gemini_stream: Gemini 格式的流式响应 (bytes 迭代器) + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Yields: + Anthropic SSE 格式的响应块 (bytes) + """ + # 非 2xx 状态码直接转发原始流 + if not (200 <= status_code < 300): + async for chunk in gemini_stream: + yield chunk + return + + # 初始化状态 + message_id = f"msg_{uuid.uuid4().hex}" + message_start_sent = False + current_block_type: Optional[str] = None + current_block_index = -1 + current_thinking_signature: Optional[str] = None + has_tool_use = False + input_tokens = 0 + output_tokens = 0 + finish_reason: Optional[str] = None + + def _sse_event(event: str, data: Dict[str, Any]) -> bytes: + """生成 SSE 事件""" + payload = json.dumps(data, ensure_ascii=False, separators=(",", ":")) + return f"event: {event}\ndata: {payload}\n\n".encode("utf-8") + + def _close_block() -> Optional[bytes]: + """关闭当前内容块""" + nonlocal current_block_type + if current_block_type is None: + return None + event = _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_block_index}, + ) + current_block_type = None + return event + + # 处理流式数据 + try: + async for chunk in gemini_stream: + # 记录接收到的原始chunk + log.debug(f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}") + + # 解析 Gemini 流式块 + if not chunk or not chunk.startswith(b"data: "): + log.debug(f"[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)") + continue + + raw = chunk[6:].strip() + if raw == b"[DONE]": + log.debug(f"[GEMINI_TO_ANTHROPIC] Received [DONE] marker") + break + + log.debug(f"[GEMINI_TO_ANTHROPIC] Parsing JSON: {raw[:200]}") + + try: + data = json.loads(raw.decode('utf-8', errors='ignore')) + log.debug(f"[GEMINI_TO_ANTHROPIC] Parsed data: {json.dumps(data, ensure_ascii=False)[:300]}") + except Exception as e: + log.warning(f"[GEMINI_TO_ANTHROPIC] JSON parse error: {e}") + continue + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in data: + response = data["response"] + else: + response = data + + candidate = (response.get("candidates", []) or [{}])[0] or {} + parts = (candidate.get("content", {}) or {}).get("parts", []) or [] + + # 更新 usage metadata + if "usageMetadata" in response: + usage = response["usageMetadata"] + if isinstance(usage, dict): + if "promptTokenCount" in usage: + input_tokens = int(usage.get("promptTokenCount", 0) or 0) + if "candidatesTokenCount" in usage: + output_tokens = int(usage.get("candidatesTokenCount", 0) or 0) + + # 发送 message_start(仅一次) + if not message_start_sent: + message_start_sent = True + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + }, + ) + + # 处理各种 parts + for part in parts: + if not isinstance(part, dict): + continue + + # 处理 thinking 块 + if part.get("thought") is True: + if current_block_type != "thinking": + close_evt = _close_block() + if close_evt: + yield close_evt + + current_block_index += 1 + current_block_type = "thinking" + signature = part.get("thoughtSignature") + current_thinking_signature = signature + + block: Dict[str, Any] = {"type": "thinking", "thinking": ""} + if signature: + block["signature"] = signature + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": block, + }, + ) + + thinking_text = part.get("text", "") + if thinking_text: + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "thinking_delta", "thinking": thinking_text}, + }, + ) + continue + + # 处理文本块 + if "text" in part: + text = part.get("text", "") + if isinstance(text, str) and not text.strip(): + continue + + if current_block_type != "text": + close_evt = _close_block() + if close_evt: + yield close_evt + + current_block_index += 1 + current_block_type = "text" + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + + if text: + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "text_delta", "text": text}, + }, + ) + continue + + # 处理工具调用 + if "functionCall" in part: + close_evt = _close_block() + if close_evt: + yield close_evt + + has_tool_use = True + fc = part.get("functionCall", {}) or {} + original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" + signature = part.get("thoughtSignature") + tool_id = encode_tool_id_with_signature(original_id, signature) + tool_name = fc.get("name") or "" + tool_args = _remove_nulls_for_tool_input(fc.get("args", {}) or {}) + + if _anthropic_debug_enabled(): + log.info( + f"[ANTHROPIC][tool_use] 处理工具调用: name={tool_name}, " + f"id={tool_id}, has_signature={signature is not None}" + ) + + current_block_index += 1 + # 注意:工具调用不设置 current_block_type,因为它是独立完整的块 + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": { + "type": "tool_use", + "id": tool_id, + "name": tool_name, + "input": {}, + }, + }, + ) + + input_json = json.dumps(tool_args, ensure_ascii=False, separators=(",", ":")) + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "input_json_delta", "partial_json": input_json}, + }, + ) + + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_block_index}, + ) + # 工具调用块已完全关闭,current_block_type 保持为 None + + if _anthropic_debug_enabled(): + log.info(f"[ANTHROPIC][tool_use] 工具调用块已关闭: index={current_block_index}") + + continue + + # 检查是否结束 + if candidate.get("finishReason"): + finish_reason = candidate.get("finishReason") + break + + # 关闭最后的内容块 + close_evt = _close_block() + if close_evt: + yield close_evt + + # 确定停止原因 + stop_reason = "tool_use" if has_tool_use else "end_turn" + if finish_reason == "MAX_TOKENS" and not has_tool_use: + stop_reason = "max_tokens" + + if _anthropic_debug_enabled(): + log.info( + f"[ANTHROPIC][stream_end] 流式结束: stop_reason={stop_reason}, " + f"has_tool_use={has_tool_use}, finish_reason={finish_reason}, " + f"input_tokens={input_tokens}, output_tokens={output_tokens}" + ) + + # 发送 message_delta 和 message_stop + yield _sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": None}, + "usage": { + "output_tokens": output_tokens, + }, + }, + ) + + yield _sse_event("message_stop", {"type": "message_stop"}) + + except Exception as e: + log.error(f"[ANTHROPIC] 流式转换失败: {e}") + # 发送错误事件 + if not message_start_sent: + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + }, + ) + yield _sse_event( + "error", + {"type": "error", "error": {"type": "api_error", "message": str(e)}}, + ) \ No newline at end of file diff --git a/src/converter/anti_truncation.py b/src/converter/anti_truncation.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7c78db3a221f9bbdc1baf8f57a3623d7f82f3c --- /dev/null +++ b/src/converter/anti_truncation.py @@ -0,0 +1,699 @@ +""" +Anti-Truncation Module - Ensures complete streaming output +保持一个流式请求内完整输出的反截断模块 +""" + +import io +import json +import re +from typing import Any, AsyncGenerator, Dict, List, Tuple + +from fastapi.responses import StreamingResponse + +from log import log + +# 反截断配置 +DONE_MARKER = "[done]" +CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。 + +重要提醒: +1. 不要重复前面已经输出的内容 +2. 直接继续输出,无需任何前言或解释 +3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER} +4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 + +现在请继续输出:""" + +# 正则替换配置 +REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [ + ( + "age_pattern", # 替换规则名称 + r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?", # 正则模式 + "", # 替换文本 + ), + # 可在此处添加更多替换规则 + # ("rule_name", r"pattern", "replacement"), +] + + +def apply_regex_replacements(text: str) -> str: + """ + 对文本应用正则替换规则 + + Args: + text: 要处理的文本 + + Returns: + 处理后的文本 + """ + if not text: + return text + + processed_text = text + replacement_count = 0 + + for rule_name, pattern, replacement in REGEX_REPLACEMENTS: + try: + # 编译正则表达式,使用IGNORECASE标志 + regex = re.compile(pattern, re.IGNORECASE) + + # 执行替换 + new_text, count = regex.subn(replacement, processed_text) + + if count > 0: + log.debug(f"Regex replacement '{rule_name}': {count} matches replaced") + processed_text = new_text + replacement_count += count + + except re.error as e: + log.error(f"Invalid regex pattern in rule '{rule_name}': {e}") + continue + + if replacement_count > 0: + log.info(f"Applied {replacement_count} regex replacements to text") + + return processed_text + + +def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 对请求payload中的文本内容应用正则替换 + + Args: + payload: 请求payload + + Returns: + 应用替换后的payload + """ + if not REGEX_REPLACEMENTS: + return payload + + modified_payload = payload.copy() + request_data = modified_payload.get("request", {}) + + # 处理contents中的文本 + contents = request_data.get("contents", []) + if contents: + new_contents = [] + for content in contents: + if isinstance(content, dict): + new_content = content.copy() + parts = new_content.get("parts", []) + if parts: + new_parts = [] + for part in parts: + if isinstance(part, dict) and "text" in part: + new_part = part.copy() + new_part["text"] = apply_regex_replacements(part["text"]) + new_parts.append(new_part) + else: + new_parts.append(part) + new_content["parts"] = new_parts + new_contents.append(new_content) + else: + new_contents.append(content) + + request_data["contents"] = new_contents + modified_payload["request"] = request_data + log.debug("Applied regex replacements to request contents") + + return modified_payload + + +def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 对请求payload应用反截断处理和正则替换 + 在systemInstruction中添加提醒,要求模型在结束时输出DONE_MARKER标记 + + Args: + payload: 原始请求payload + + Returns: + 添加了反截断指令并应用了正则替换的payload + """ + # 首先应用正则替换 + modified_payload = apply_regex_replacements_to_payload(payload) + request_data = modified_payload.get("request", {}) + + # 获取或创建systemInstruction + system_instruction = request_data.get("systemInstruction", {}) + if not system_instruction: + system_instruction = {"parts": []} + elif "parts" not in system_instruction: + system_instruction["parts"] = [] + + # 添加反截断指令 + anti_truncation_instruction = { + "text": f"""严格执行以下输出结束规则: + +1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER} +2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 +3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的 +4. 如果你的回答被截断,系统会要求你继续输出剩余内容 +5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束 + +示例格式: +``` +你的回答内容... +更多回答内容... +{DONE_MARKER} +``` + +注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。 + +这个规则对于确保输出完整性极其重要,请严格遵守。""" + } + + # 检查是否已经包含反截断指令 + has_done_instruction = any( + part.get("text", "").find(DONE_MARKER) != -1 + for part in system_instruction["parts"] + if isinstance(part, dict) + ) + + if not has_done_instruction: + system_instruction["parts"].append(anti_truncation_instruction) + request_data["systemInstruction"] = system_instruction + modified_payload["request"] = request_data + + log.debug("Applied anti-truncation instruction to request") + + return modified_payload + + +class AntiTruncationStreamProcessor: + """反截断流式处理器""" + + def __init__( + self, + original_request_func, + payload: Dict[str, Any], + max_attempts: int = 3, + ): + self.original_request_func = original_request_func + self.base_payload = payload.copy() + self.max_attempts = max_attempts + # 使用 StringIO 避免字符串拼接的内存问题 + self.collected_content = io.StringIO() + self.current_attempt = 0 + + def _get_collected_text(self) -> str: + """获取收集的文本内容""" + return self.collected_content.getvalue() + + def _append_content(self, content: str): + """追加内容到收集器""" + if content: + self.collected_content.write(content) + + def _clear_content(self): + """清空收集的内容,释放内存""" + self.collected_content.close() + self.collected_content = io.StringIO() + + async def process_stream(self) -> AsyncGenerator[bytes, None]: + """处理流式响应,检测并处理截断""" + + while self.current_attempt < self.max_attempts: + self.current_attempt += 1 + + # 构建当前请求payload + current_payload = self._build_current_payload() + + log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}") + + # 发送请求 + try: + response = await self.original_request_func(current_payload) + + if not isinstance(response, StreamingResponse): + # 非流式响应,直接处理 + yield await self._handle_non_streaming_response(response) + return + + # 处理流式响应(按行处理) + chunk_buffer = io.StringIO() # 使用 StringIO 缓存当前轮次的chunk + found_done_marker = False + + async for line in response.body_iterator: + if not line: + yield line + continue + + # 处理 bytes 类型的流式数据 + if isinstance(line, bytes): + # 解码 bytes 为字符串 + line_str = line.decode('utf-8', errors='ignore').strip() + else: + line_str = str(line).strip() + + # 跳过空行 + if not line_str: + yield line + continue + + # 处理 SSE 格式的数据行 + if line_str.startswith("data: "): + payload_str = line_str[6:] # 去掉 "data: " 前缀 + + # 检查是否是 [DONE] 标记 + if payload_str.strip() == "[DONE]": + if found_done_marker: + log.info("Anti-truncation: Found [done] marker, output complete") + yield line + # 清理内存 + chunk_buffer.close() + self._clear_content() + return + else: + log.warning("Anti-truncation: Stream ended without [done] marker") + # 不发送[DONE],准备继续 + break + + # 尝试解析 JSON 数据 + try: + data = json.loads(payload_str) + content = self._extract_content_from_chunk(data) + + log.debug(f"Anti-truncation: Extracted content: {repr(content[:100] if content else '')}") + + if content: + chunk_buffer.write(content) + + # 检查是否包含done标记 + has_marker = self._check_done_marker_in_chunk_content(content) + log.debug(f"Anti-truncation: Check done marker result: {has_marker}, DONE_MARKER='{DONE_MARKER}'") + if has_marker: + found_done_marker = True + log.debug(f"Anti-truncation: Found [done] marker in chunk, content: {content[:200]}") + + # 清理行中的[done]标记后再发送 + cleaned_line = self._remove_done_marker_from_line(line, line_str, data) + yield cleaned_line + + except (json.JSONDecodeError, ValueError): + # 无法解析的行,直接传递 + yield line + continue + else: + # 非 data: 开头的行,直接传递 + yield line + + # 更新收集的内容 - 使用 StringIO 高效处理 + chunk_text = chunk_buffer.getvalue() + if chunk_text: + self._append_content(chunk_text) + chunk_buffer.close() + + log.debug(f"Anti-truncation: After processing stream, found_done_marker={found_done_marker}") + + # 如果找到了done标记,结束 + if found_done_marker: + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + # 只有在单个chunk中没有找到done标记时,才检查累积内容(防止done标记跨chunk出现) + if not found_done_marker: + accumulated_text = self._get_collected_text() + if self._check_done_marker_in_text(accumulated_text): + log.info("Anti-truncation: Found [done] marker in accumulated content") + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + # 如果没找到done标记且不是最后一次尝试,准备续传 + if self.current_attempt < self.max_attempts: + accumulated_text = self._get_collected_text() + total_length = len(accumulated_text) + log.info( + f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})" + ) + if total_length > 100: + log.debug( + f"Anti-truncation: Current collected content ends with: ...{accumulated_text[-100:]}" + ) + # 在下一次循环中会继续 + continue + else: + # 最后一次尝试,直接结束 + log.warning("Anti-truncation: Max attempts reached, ending stream") + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + except Exception as e: + log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}") + if self.current_attempt >= self.max_attempts: + # 发送错误chunk + error_chunk = { + "error": { + "message": f"Anti-truncation failed: {str(e)}", + "type": "api_error", + "code": 500, + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + yield b"data: [DONE]\n\n" + return + # 否则继续下一次尝试 + + # 如果所有尝试都失败了 + log.error("Anti-truncation: All attempts failed") + # 清理内存 + self._clear_content() + yield b"data: [DONE]\n\n" + + def _build_current_payload(self) -> Dict[str, Any]: + """构建当前请求的payload""" + if self.current_attempt == 1: + # 第一次请求,使用原始payload(已经包含反截断指令) + return self.base_payload + + # 后续请求,添加续传指令 + continuation_payload = self.base_payload.copy() + request_data = continuation_payload.get("request", {}) + + # 获取原始对话内容 + contents = request_data.get("contents", []) + new_contents = contents.copy() + + # 如果有收集到的内容,添加到对话中 + accumulated_text = self._get_collected_text() + if accumulated_text: + new_contents.append({"role": "model", "parts": [{"text": accumulated_text}]}) + + # 构建具体的续写指令,包含前面的内容摘要 + content_summary = "" + if accumulated_text: + if len(accumulated_text) > 200: + content_summary = f'\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n"...{accumulated_text[-100:]}"' + else: + content_summary = f'\n\n前面你已经输出的内容是:\n"{accumulated_text}"' + + detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}""" + + # 添加继续指令 + continuation_message = {"role": "user", "parts": [{"text": detailed_continuation_prompt}]} + new_contents.append(continuation_message) + + request_data["contents"] = new_contents + continuation_payload["request"] = request_data + + return continuation_payload + + def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str: + """从chunk数据中提取文本内容""" + content = "" + + # 先尝试解包 response 字段(Gemini API 格式) + if "response" in data: + data = data["response"] + + # 处理 Gemini 格式 + if "candidates" in data: + for candidate in data["candidates"]: + if "content" in candidate: + parts = candidate["content"].get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # 处理 OpenAI 流式格式(choices/delta) + elif "choices" in data: + for choice in data["choices"]: + if "delta" in choice and "content" in choice["delta"]: + delta_content = choice["delta"]["content"] + if delta_content: + content += delta_content + + return content + + async def _handle_non_streaming_response(self, response) -> bytes: + """处理非流式响应 - 使用循环代替递归避免栈溢出""" + # 使用循环代替递归 + while True: + try: + # 特殊处理:如果返回的是StreamingResponse,需要读取其body_iterator + if isinstance(response, StreamingResponse): + log.error("Anti-truncation: Received StreamingResponse in non-streaming handler - this should not happen") + # 尝试读取流式响应的内容 + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + content = b"".join(chunks).decode() if chunks else "" + # 提取响应内容 + elif hasattr(response, "body"): + content = ( + response.body.decode() if isinstance(response.body, bytes) else response.body + ) + elif hasattr(response, "content"): + content = ( + response.content.decode() + if isinstance(response.content, bytes) + else response.content + ) + else: + log.error(f"Anti-truncation: Unknown response type: {type(response)}") + content = str(response) + + # 验证内容不为空 + if not content or not content.strip(): + log.error("Anti-truncation: Received empty response content") + return json.dumps( + { + "error": { + "message": "Empty response from server", + "type": "api_error", + "code": 500, + } + } + ).encode() + + # 尝试解析 JSON + try: + response_data = json.loads(content) + except json.JSONDecodeError as json_err: + log.error(f"Anti-truncation: Failed to parse JSON response: {json_err}, content: {content[:200]}") + # 如果不是 JSON,直接返回原始内容 + return content.encode() if isinstance(content, str) else content + + # 检查是否包含done标记 + text_content = self._extract_content_from_response(response_data) + has_done_marker = self._check_done_marker_in_text(text_content) + + if has_done_marker or self.current_attempt >= self.max_attempts: + # 找到done标记或达到最大尝试次数,返回结果 + return content.encode() if isinstance(content, str) else content + + # 需要继续,收集内容并构建下一个请求 + if text_content: + self._append_content(text_content) + + log.info("Anti-truncation: Non-streaming response needs continuation") + + # 增加尝试次数 + self.current_attempt += 1 + + # 构建续传payload并发送下一个请求 + next_payload = self._build_current_payload() + response = await self.original_request_func(next_payload) + + # 继续循环处理下一个响应 + + except Exception as e: + log.error(f"Anti-truncation non-streaming error: {str(e)}") + return json.dumps( + { + "error": { + "message": f"Anti-truncation failed: {str(e)}", + "type": "api_error", + "code": 500, + } + } + ).encode() + + def _check_done_marker_in_text(self, text: str) -> bool: + """检测文本中是否包含DONE_MARKER(只检测指定标记)""" + if not text: + return False + + # 只要文本中出现DONE_MARKER即可 + return DONE_MARKER in text + + def _check_done_marker_in_chunk_content(self, content: str) -> bool: + """检查单个chunk内容中是否包含done标记""" + return self._check_done_marker_in_text(content) + + def _extract_content_from_response(self, data: Dict[str, Any]) -> str: + """从响应数据中提取文本内容""" + content = "" + + # 先尝试解包 response 字段(Gemini API 格式) + if "response" in data: + data = data["response"] + + # 处理Gemini格式 + if "candidates" in data: + for candidate in data["candidates"]: + if "content" in candidate: + parts = candidate["content"].get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # 处理OpenAI格式 + elif "choices" in data: + for choice in data["choices"]: + if "message" in choice and "content" in choice["message"]: + content += choice["message"]["content"] + + return content + + def _remove_done_marker_from_line(self, line: bytes, line_str: str, data: Dict[str, Any]) -> bytes: + """从行中移除[done]标记""" + try: + # 首先检查是否真的包含[done]标记 + if "[done]" not in line_str.lower(): + return line # 没有[done]标记,直接返回原始行 + + log.info(f"Anti-truncation: Attempting to remove [done] marker from line") + log.debug(f"Anti-truncation: Original line (first 200 chars): {line_str[:200]}") + + # 编译正则表达式,匹配[done]标记(忽略大小写,包括可能的空白字符) + done_pattern = re.compile(r"\s*\[done\]\s*", re.IGNORECASE) + + # 检查是否有 response 包裹层 + has_response_wrapper = "response" in data + log.debug(f"Anti-truncation: has_response_wrapper={has_response_wrapper}, data keys={list(data.keys())}") + if has_response_wrapper: + # 需要保留外层的 response 字段 + inner_data = data["response"] + else: + inner_data = data + + log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") + + log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") + + # 处理Gemini格式 + if "candidates" in inner_data: + log.info(f"Anti-truncation: Processing Gemini format to remove [done] marker") + modified_inner = inner_data.copy() + modified_inner["candidates"] = [] + + for i, candidate in enumerate(inner_data["candidates"]): + modified_candidate = candidate.copy() + # 只在最后一个candidate中清理[done]标记 + is_last_candidate = i == len(inner_data["candidates"]) - 1 + + if "content" in candidate: + modified_content = candidate["content"].copy() + if "parts" in modified_content: + modified_parts = [] + for part in modified_content["parts"]: + if "text" in part and isinstance(part["text"], str): + modified_part = part.copy() + original_text = part["text"] + # 只在最后一个candidate中清理[done]标记 + if is_last_candidate: + modified_part["text"] = done_pattern.sub("", part["text"]) + if "[done]" in original_text.lower(): + log.debug(f"Anti-truncation: Removed [done] from text: '{original_text[:100]}' -> '{modified_part['text'][:100]}'") + modified_parts.append(modified_part) + else: + modified_parts.append(part) + modified_content["parts"] = modified_parts + modified_candidate["content"] = modified_content + modified_inner["candidates"].append(modified_candidate) + + # 如果有 response 包裹层,需要重新包装 + if has_response_wrapper: + modified_data = data.copy() + modified_data["response"] = modified_inner + else: + modified_data = modified_inner + + # 重新编码为行格式 - SSE格式需要两个换行符 + json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) + result = f"data: {json_str}\n\n".encode("utf-8") + log.debug(f"Anti-truncation: Modified line (first 200 chars): {result.decode('utf-8', errors='ignore')[:200]}") + return result + + # 处理OpenAI格式 + elif "choices" in inner_data: + modified_inner = inner_data.copy() + modified_inner["choices"] = [] + + for choice in inner_data["choices"]: + modified_choice = choice.copy() + if "delta" in choice and "content" in choice["delta"]: + modified_delta = choice["delta"].copy() + modified_delta["content"] = done_pattern.sub("", choice["delta"]["content"]) + modified_choice["delta"] = modified_delta + elif "message" in choice and "content" in choice["message"]: + modified_message = choice["message"].copy() + modified_message["content"] = done_pattern.sub("", choice["message"]["content"]) + modified_choice["message"] = modified_message + modified_inner["choices"].append(modified_choice) + + # 如果有 response 包裹层,需要重新包装 + if has_response_wrapper: + modified_data = data.copy() + modified_data["response"] = modified_inner + else: + modified_data = modified_inner + + # 重新编码为行格式 - SSE格式需要两个换行符 + json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) + return f"data: {json_str}\n\n".encode("utf-8") + + # 如果没有找到支持的格式,返回原始行 + return line + + except Exception as e: + log.warning(f"Failed to remove [done] marker from line: {str(e)}") + return line + + +async def apply_anti_truncation_to_stream( + request_func, payload: Dict[str, Any], max_attempts: int = 3 +) -> StreamingResponse: + """ + 对流式请求应用反截断处理 + + Args: + request_func: 原始请求函数 + payload: 请求payload + max_attempts: 最大续传尝试次数 + + Returns: + 处理后的StreamingResponse + """ + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(payload) + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + lambda p: request_func(p), anti_truncation_payload, max_attempts + ) + + # 返回包装后的流式响应 + return StreamingResponse(processor.process_stream(), media_type="text/event-stream") + + +def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool: + """ + 检查请求是否启用了反截断功能 + + Args: + request_data: 请求数据 + + Returns: + 是否启用反截断 + """ + return request_data.get("enable_anti_truncation", False) \ No newline at end of file diff --git a/src/converter/fake_stream.py b/src/converter/fake_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb0ba1062df7874a7ba3976730e1a51cca47a54 --- /dev/null +++ b/src/converter/fake_stream.py @@ -0,0 +1,537 @@ +from typing import Any, Dict, List, Tuple +import json +from src.converter.utils import extract_content_and_reasoning +from log import log +from src.converter.openai2gemini import _convert_usage_metadata + +def safe_get_nested(obj: Any, *keys: str, default: Any = None) -> Any: + """安全获取嵌套字典值 + + Args: + obj: 字典对象 + *keys: 嵌套键路径 + default: 默认值 + + Returns: + 获取到的值或默认值 + """ + for key in keys: + if not isinstance(obj, dict): + return default + obj = obj.get(key, default) + if obj is default: + return default + return obj + +def parse_response_for_fake_stream(response_data: Dict[str, Any]) -> tuple: + """从完整响应中提取内容和推理内容(用于假流式) + + Args: + response_data: Gemini API 响应数据 + + Returns: + (content, reasoning_content, finish_reason, images): 内容、推理内容、结束原因和图片数据的元组 + """ + import json + + # 处理GeminiCLI的response包装格式 + if "response" in response_data and "candidates" not in response_data: + log.debug(f"[FAKE_STREAM] Unwrapping response field") + response_data = response_data["response"] + + candidates = response_data.get("candidates", []) + log.debug(f"[FAKE_STREAM] Found {len(candidates)} candidates") + if not candidates: + return "", "", "STOP", [] + + candidate = candidates[0] + finish_reason = candidate.get("finishReason", "STOP") + parts = safe_get_nested(candidate, "content", "parts", default=[]) + log.debug(f"[FAKE_STREAM] Extracted {len(parts)} parts: {json.dumps(parts, ensure_ascii=False)}") + content, reasoning_content, images = extract_content_and_reasoning(parts) + log.debug(f"[FAKE_STREAM] Content length: {len(content)}, Reasoning length: {len(reasoning_content)}, Images count: {len(images)}") + + return content, reasoning_content, finish_reason, images + +def extract_fake_stream_content(response: Any) -> Tuple[str, str, Dict[str, int]]: + """ + 从 Gemini 非流式响应中提取内容,用于假流式处理 + + Args: + response: Gemini API 响应对象 + + Returns: + (content, reasoning_content, usage) 元组 + """ + from src.converter.utils import extract_content_and_reasoning + + # 解析响应体 + if hasattr(response, "body"): + body_str = ( + response.body.decode() + if isinstance(response.body, bytes) + else str(response.body) + ) + elif hasattr(response, "content"): + body_str = ( + response.content.decode() + if isinstance(response.content, bytes) + else str(response.content) + ) + else: + body_str = str(response) + + try: + response_data = json.loads(body_str) + + # GeminiCLI 返回的格式是 {"response": {...}, "traceId": "..."} + # 需要先提取 response 字段 + if "response" in response_data: + gemini_response = response_data["response"] + else: + gemini_response = response_data + + # 从Gemini响应中提取内容,使用思维链分离逻辑 + content = "" + reasoning_content = "" + images = [] + if "candidates" in gemini_response and gemini_response["candidates"]: + # Gemini格式响应 - 使用思维链分离 + candidate = gemini_response["candidates"][0] + if "content" in candidate and "parts" in candidate["content"]: + parts = candidate["content"]["parts"] + content, reasoning_content, images = extract_content_and_reasoning(parts) + elif "choices" in gemini_response and gemini_response["choices"]: + # OpenAI格式响应 + content = gemini_response["choices"][0].get("message", {}).get("content", "") + + # 如果没有正常内容但有思维内容,给出警告 + if not content and reasoning_content: + log.warning("Fake stream response contains only thinking content") + content = "[模型正在思考中,请稍后再试或重新提问]" + + # 如果完全没有内容,提供默认回复 + if not content: + log.warning(f"No content found in response: {gemini_response}") + content = "[响应为空,请重新尝试]" + + # 转换usageMetadata为OpenAI格式 + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + return content, reasoning_content, usage + + except json.JSONDecodeError: + # 如果不是JSON,直接返回原始文本 + return body_str, "", None + +def _build_candidate(parts: List[Dict[str, Any]], finish_reason: str = "STOP") -> Dict[str, Any]: + """构建标准候选响应结构 + + Args: + parts: parts 列表 + finish_reason: 结束原因 + + Returns: + 候选响应字典 + """ + return { + "candidates": [{ + "content": {"parts": parts, "role": "model"}, + "finishReason": finish_reason, + "index": 0, + }] + } + +def create_openai_heartbeat_chunk() -> Dict[str, Any]: + """ + 创建 OpenAI 格式的心跳块(用于假流式) + + Returns: + 心跳响应块字典 + """ + return { + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ] + } + +def build_gemini_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建假流式响应的数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容 + finish_reason: 结束原因 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + 响应数据块列表 + """ + if images is None: + images = [] + + log.debug(f"[build_gemini_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + return [_build_candidate([{"text": default_text}], finish_reason)] + + # 分块发送主要内容 + first_chunk = True + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + is_last_chunk = (i + chunk_size >= len(content)) and not reasoning_content + chunk_finish_reason = finish_reason if is_last_chunk else None + + # 如果是第一个chunk且有图片,将图片包含在parts中 + parts = [] + if first_chunk and images: + # 在Gemini格式中,需要将image_url格式转换为inlineData格式 + for img in images: + if img.get("type") == "image_url": + url = img.get("image_url", {}).get("url", "") + # 解析 data URL: data:{mime_type};base64,{data} + if url.startswith("data:"): + parts_of_url = url.split(";base64,") + if len(parts_of_url) == 2: + mime_type = parts_of_url[0].replace("data:", "") + base64_data = parts_of_url[1] + parts.append({ + "inlineData": { + "mimeType": mime_type, + "data": base64_data + } + }) + first_chunk = False + + parts.append({"text": chunk_text}) + chunk_data = _build_candidate(parts, chunk_finish_reason) + log.debug(f"[build_gemini_fake_stream_chunks] Generated chunk: {chunk_data}") + chunks.append(chunk_data) + + # 如果有推理内容,分块发送 + if reasoning_content: + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + is_last_chunk = i + chunk_size >= len(reasoning_content) + chunk_finish_reason = finish_reason if is_last_chunk else None + chunks.append(_build_candidate([{"text": chunk_text, "thought": True}], chunk_finish_reason)) + + log.debug(f"[build_gemini_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks + + +def create_gemini_heartbeat_chunk() -> Dict[str, Any]: + """创建 Gemini 格式的心跳数据块 + + Returns: + 心跳数据块 + """ + chunk = _build_candidate([{"text": ""}]) + chunk["candidates"][0]["finishReason"] = None + return chunk + + +def build_openai_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, model: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建 OpenAI 格式的假流式响应数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容 + finish_reason: 结束原因(如 "STOP", "MAX_TOKENS") + model: 模型名称 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + OpenAI 格式的响应数据块列表 + """ + import time + import uuid + + if images is None: + images = [] + + log.debug(f"[build_openai_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + # 映射 Gemini finish_reason 到 OpenAI 格式 + openai_finish_reason = None + if finish_reason == "STOP": + openai_finish_reason = "stop" + elif finish_reason == "MAX_TOKENS": + openai_finish_reason = "length" + elif finish_reason in ["SAFETY", "RECITATION"]: + openai_finish_reason = "content_filter" + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + return [{ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": default_text}, + "finish_reason": openai_finish_reason, + }] + }] + + # 分块发送主要内容 + first_chunk = True + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + is_last_chunk = (i + chunk_size >= len(content)) and not reasoning_content + chunk_finish = openai_finish_reason if is_last_chunk else None + + delta_content = {} + + # 如果是第一个chunk且有图片,构建包含图片的content数组 + if first_chunk and images: + delta_content["content"] = images + [{"type": "text", "text": chunk_text}] + first_chunk = False + else: + delta_content["content"] = chunk_text + + chunk_data = { + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": delta_content, + "finish_reason": chunk_finish, + }] + } + log.debug(f"[build_openai_fake_stream_chunks] Generated chunk: {chunk_data}") + chunks.append(chunk_data) + + # 如果有推理内容,分块发送(使用 reasoning_content 字段) + if reasoning_content: + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + is_last_chunk = i + chunk_size >= len(reasoning_content) + chunk_finish = openai_finish_reason if is_last_chunk else None + + chunks.append({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": {"reasoning_content": chunk_text}, + "finish_reason": chunk_finish, + }] + }) + + log.debug(f"[build_openai_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks + + +def create_anthropic_heartbeat_chunk() -> Dict[str, Any]: + """ + 创建 Anthropic 格式的心跳块(用于假流式) + + Returns: + 心跳响应块字典 + """ + return { + "type": "ping" + } + + +def build_anthropic_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, model: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建 Anthropic 格式的假流式响应数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容(thinking content) + finish_reason: 结束原因(如 "STOP", "MAX_TOKENS") + model: 模型名称 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + Anthropic SSE 格式的响应数据块列表 + """ + import uuid + + if images is None: + images = [] + + log.debug(f"[build_anthropic_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + message_id = f"msg_{uuid.uuid4().hex}" + + # 映射 Gemini finish_reason 到 Anthropic 格式 + anthropic_stop_reason = "end_turn" + if finish_reason == "MAX_TOKENS": + anthropic_stop_reason = "max_tokens" + elif finish_reason in ["SAFETY", "RECITATION"]: + anthropic_stop_reason = "end_turn" + + # 1. 发送 message_start 事件 + chunks.append({ + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0} + } + }) + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + + # content_block_start + chunks.append({ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""} + }) + + # content_block_delta + chunks.append({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": default_text} + }) + + # content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": 0 + }) + + # message_delta + chunks.append({ + "type": "message_delta", + "delta": {"stop_reason": anthropic_stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": 0} + }) + + # message_stop + chunks.append({ + "type": "message_stop" + }) + + return chunks + + block_index = 0 + + # 2. 如果有推理内容,先发送 thinking 块 + if reasoning_content: + # thinking content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": {"type": "thinking", "thinking": ""} + }) + + # 分块发送推理内容 + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + chunks.append({ + "type": "content_block_delta", + "index": block_index, + "delta": {"type": "thinking_delta", "thinking": chunk_text} + }) + + # thinking content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + block_index += 1 + + # 3. 如果有图片,发送图片块 + if images: + for img in images: + if img.get("type") == "image_url": + url = img.get("image_url", {}).get("url", "") + # 解析 data URL: data:{mime_type};base64,{data} + if url.startswith("data:"): + parts_of_url = url.split(";base64,") + if len(parts_of_url) == 2: + mime_type = parts_of_url[0].replace("data:", "") + base64_data = parts_of_url[1] + + # image content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": base64_data + } + } + }) + + # image content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + block_index += 1 + + # 4. 发送主要内容(text 块) + # text content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": {"type": "text", "text": ""} + }) + + # 分块发送主要内容 + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + chunks.append({ + "type": "content_block_delta", + "index": block_index, + "delta": {"type": "text_delta", "text": chunk_text} + }) + + # text content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + # 5. 发送 message_delta + chunks.append({ + "type": "message_delta", + "delta": {"stop_reason": anthropic_stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": len(content) + len(reasoning_content)} + }) + + # 6. 发送 message_stop + chunks.append({ + "type": "message_stop" + }) + + log.debug(f"[build_anthropic_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks \ No newline at end of file diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb2f24292bf45dc773818ef92c01231e054eaeb --- /dev/null +++ b/src/converter/gemini_fix.py @@ -0,0 +1,418 @@ +""" +Gemini Format Utilities - 统一的 Gemini 格式处理和转换工具 +提供对 Gemini API 请求体和响应的标准化处理 +──────────────────────────────────────────────────────────────── +""" + +from typing import Any, Dict, List, Optional + +from log import log + +# ==================== Gemini API 配置 ==================== + +# Gemini API 不支持的 JSON Schema 字段集合 +# 参考: github.com/googleapis/python-genai/issues/699, #388, #460, #1122, #264, #4551 +UNSUPPORTED_SCHEMA_KEYS = { + '$schema', '$id', '$ref', '$defs', 'definitions', + 'example', 'examples', 'readOnly', 'writeOnly', 'default', + 'exclusiveMaximum', 'exclusiveMinimum', + 'oneOf', 'anyOf', 'allOf', 'const', + 'additionalItems', 'contains', 'patternProperties', 'dependencies', + 'propertyNames', 'if', 'then', 'else', + 'contentEncoding', 'contentMediaType', + 'additionalProperties', 'minLength', 'maxLength', + 'minItems', 'maxItems', 'uniqueItems' +} + + + +def clean_tools_for_gemini(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]: + """ + 清理工具定义,移除 Gemini API 不支持的 JSON Schema 字段 + + Gemini API 只支持有限的 OpenAPI 3.0 Schema 属性: + - 支持: type, description, enum, items, properties, required, nullable, format + - 不支持: $schema, $id, $ref, $defs, title, examples, default, readOnly, + exclusiveMaximum, exclusiveMinimum, oneOf, anyOf, allOf, const 等 + + Args: + tools: 工具定义列表 + + Returns: + 清理后的工具定义列表 + """ + if not tools: + return tools + + def clean_schema(obj: Any) -> Any: + """递归清理 schema 对象""" + if isinstance(obj, dict): + cleaned = {} + for key, value in obj.items(): + if key in UNSUPPORTED_SCHEMA_KEYS: + continue + cleaned[key] = clean_schema(value) + # 确保有 type 字段(如果有 properties 但没有 type) + if "properties" in cleaned and "type" not in cleaned: + cleaned["type"] = "object" + return cleaned + elif isinstance(obj, list): + return [clean_schema(item) for item in obj] + else: + return obj + + # 清理每个工具的参数 + cleaned_tools = [] + for tool in tools: + if not isinstance(tool, dict): + cleaned_tools.append(tool) + continue + + cleaned_tool = tool.copy() + + # 清理 functionDeclarations + if "functionDeclarations" in cleaned_tool: + cleaned_declarations = [] + for func_decl in cleaned_tool["functionDeclarations"]: + if not isinstance(func_decl, dict): + cleaned_declarations.append(func_decl) + continue + + cleaned_decl = func_decl.copy() + if "parameters" in cleaned_decl: + cleaned_decl["parameters"] = clean_schema(cleaned_decl["parameters"]) + cleaned_declarations.append(cleaned_decl) + + cleaned_tool["functionDeclarations"] = cleaned_declarations + + cleaned_tools.append(cleaned_tool) + + return cleaned_tools + +def prepare_image_generation_request( + request_body: Dict[str, Any], + model: str +) -> Dict[str, Any]: + """ + 图像生成模型请求体后处理 + + Args: + request_body: 原始请求体 + model: 模型名称 + + Returns: + 处理后的请求体 + """ + request_body = request_body.copy() + model_lower = model.lower() + + # 解析分辨率 + image_size = "4K" if "-4k" in model_lower else "2K" if "-2k" in model_lower else None + + # 解析比例 + aspect_ratio = None + for suffix, ratio in [ + ("-21x9", "21:9"), ("-16x9", "16:9"), ("-9x16", "9:16"), + ("-4x3", "4:3"), ("-3x4", "3:4"), ("-1x1", "1:1") + ]: + if suffix in model_lower: + aspect_ratio = ratio + break + + # 构建 imageConfig + image_config = {} + if aspect_ratio: + image_config["aspectRatio"] = aspect_ratio + if image_size: + image_config["imageSize"] = image_size + + request_body["model"] = "gemini-3-pro-image" # 统一使用基础模型名 + request_body["generationConfig"] = { + "candidateCount": 1, + "imageConfig": image_config + } + + # 移除不需要的字段 + for key in ("systemInstruction", "tools", "toolConfig"): + request_body.pop(key, None) + + return request_body + + +# ==================== 模型特性辅助函数 ==================== + +def get_base_model_name(model_name: str) -> str: + """移除模型名称中的后缀,返回基础模型名""" + # 按照从长到短的顺序排列,避免 -think 先于 -maxthinking 被匹配 + suffixes = ["-maxthinking", "-nothinking", "-search", "-think"] + result = model_name + changed = True + # 持续循环直到没有任何后缀可以移除 + while changed: + changed = False + for suffix in suffixes: + if result.endswith(suffix): + result = result[:-len(suffix)] + changed = True + # 不使用 break,继续检查是否还有其他后缀 + return result + + +def get_thinking_settings(model_name: str) -> tuple[Optional[int], bool]: + """ + 根据模型名称获取思考配置 + + Returns: + (thinking_budget, include_thoughts): 思考预算和是否包含思考内容 + """ + base_model = get_base_model_name(model_name) + + if "-nothinking" in model_name: + # nothinking 模式: 限制思考,pro模型仍包含thoughts + return 128, "pro" in base_model + elif "-maxthinking" in model_name: + # maxthinking 模式: 最大思考预算 + budget = 24576 if "flash" in base_model else 32768 + return budget, True + else: + # 默认模式: 不设置thinking budget + return None, True + + +def is_search_model(model_name: str) -> bool: + """检查是否为搜索模型""" + return "-search" in model_name + + +# ==================== 统一的 Gemini 请求后处理 ==================== + +def is_thinking_model(model_name: str) -> bool: + """检查是否为思考模型 (包含 -thinking 或 pro)""" + return "-thinking" in model_name or "pro" in model_name.lower() + + +def check_last_assistant_has_thinking(contents: List[Dict[str, Any]]) -> bool: + """ + 检查最后一个 assistant 消息是否以 thinking 块开始 + + 根据 Claude API 要求:当启用 thinking 时,最后一个 assistant 消息必须以 thinking 块开始 + + Args: + contents: Gemini 格式的 contents 数组 + + Returns: + 如果最后一个 assistant 消息以 thinking 块开始则返回 True,否则返回 False + """ + if not contents: + return True # 没有 contents,允许启用 thinking + + # 从后往前找最后一个 assistant (model) 消息 + last_assistant_content = None + for content in reversed(contents): + if isinstance(content, dict) and content.get("role") == "model": + last_assistant_content = content + break + + if not last_assistant_content: + return True # 没有 assistant 消息,允许启用 thinking + + # 检查第一个 part 是否是 thinking 块 + parts = last_assistant_content.get("parts", []) + if not parts: + return False # 有 assistant 消息但没有 parts,不允许 thinking + + first_part = parts[0] + if not isinstance(first_part, dict): + return False + + # 检查是否是 thinking 块(有 thought 字段且为 True) + return first_part.get("thought") is True + + +async def normalize_gemini_request( + request: Dict[str, Any], + mode: str = "geminicli" +) -> Dict[str, Any]: + """ + 规范化 Gemini 请求 + + 处理逻辑: + 1. 模型特性处理 (thinking config, search tools) + 2. 字段名转换 (system_instructions -> systemInstruction) + 3. 参数范围限制 (maxOutputTokens, topK) + 4. 工具清理 + + Args: + request: 原始请求字典 + mode: 模式 ("geminicli" 或 "antigravity") + + Returns: + 规范化后的请求 + """ + # 导入配置函数 + from config import get_return_thoughts_to_frontend + + result = request.copy() + model = result.get("model", "") + generation_config = (result.get("generationConfig") or {}).copy() # 创建副本避免修改原对象 + tools = result.get("tools") + system_instruction = result.get("systemInstruction") or result.get("system_instructions") + + # 记录原始请求 + log.debug(f"[GEMINI_FIX] 原始请求 - 模型: {model}, mode: {mode}, generationConfig: {generation_config}") + + # 获取配置值 + return_thoughts = await get_return_thoughts_to_frontend() + + # ========== 模式特定处理 ========== + if mode == "geminicli": + # 1. 思考设置 + thinking_budget, include_thoughts = get_thinking_settings(model) + if thinking_budget is not None and "thinkingConfig" not in generation_config: + # 如果配置为不返回thoughts,则强制设置为False;否则使用模型默认设置 + final_include_thoughts = include_thoughts if return_thoughts else False + generation_config["thinkingConfig"] = { + "thinkingBudget": thinking_budget, + "includeThoughts": final_include_thoughts + } + + # 2. 工具清理和处理 + if tools: + result["tools"] = clean_tools_for_gemini(tools) + + # 3. 搜索模型添加 Google Search + if is_search_model(model): + result_tools = result.get("tools") or [] + result["tools"] = result_tools + if not any(tool.get("googleSearch") for tool in result_tools if isinstance(tool, dict)): + result_tools.append({"googleSearch": {}}) + + # 4. 模型名称处理 + result["model"] = get_base_model_name(model) + + elif mode == "antigravity": + # 1. 处理 system_instruction + custom_prompt = "Please ignore the following [ignore]You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**[/ignore]" + + # 提取原有的 parts(如果存在) + existing_parts = [] + if system_instruction: + if isinstance(system_instruction, dict): + existing_parts = system_instruction.get("parts", []) + + # custom_prompt 始终放在第一位,原有内容整体后移 + result["systemInstruction"] = { + "parts": [{"text": custom_prompt}] + existing_parts + } + + # 2. 判断图片模型 + if "image" in model.lower(): + # 调用图片生成专用处理函数 + return prepare_image_generation_request(result, model) + else: + # 3. 思考模型处理 + if is_thinking_model(model): + # 检查最后一个 assistant 消息是否以 thinking 块开始 + contents = result.get("contents", []) + can_enable_thinking = check_last_assistant_has_thinking(contents) + + if can_enable_thinking: + if "thinkingConfig" not in generation_config: + generation_config["thinkingConfig"] = {} + + thinking_config = generation_config["thinkingConfig"] + # 优先使用传入的思考预算,否则使用默认值 + if "thinkingBudget" not in thinking_config: + thinking_config["thinkingBudget"] = 1024 + if "includeThoughts" not in thinking_config: + thinking_config["includeThoughts"] = return_thoughts + else: + # 最后一个 assistant 消息不是以 thinking 块开始,禁用 thinking + log.warning(f"[ANTIGRAVITY] 最后一个 assistant 消息不以 thinking 块开始,禁用 thinkingConfig") + # 移除可能存在的 thinkingConfig + generation_config.pop("thinkingConfig", None) + + # 移除 -thinking 后缀 + model = model.replace("-thinking", "") + + # 4. Claude 模型关键词映射 + # 使用关键词匹配而不是精确匹配,更灵活地处理各种变体 + original_model = model + if "opus" in model.lower(): + model = "claude-opus-4-5-thinking" + elif "sonnet" in model.lower() or "haiku" in model.lower(): + model = "claude-sonnet-4-5-thinking" + elif "claude" in model.lower(): + # Claude 模型兜底:如果包含 claude 但不是 opus/sonnet/haiku + model = "claude-sonnet-4-5-thinking" + + result["model"] = model + if original_model != model: + log.debug(f"[ANTIGRAVITY] 映射模型: {original_model} -> {model}") + + # ========== 公共处理 ========== + # 1. 字段名转换 + if "system_instructions" in result: + result["systemInstruction"] = result.pop("system_instructions") + + # 2. 参数范围限制 + if generation_config: + max_tokens = generation_config.get("maxOutputTokens") + if max_tokens is not None: + generation_config["maxOutputTokens"] = 64000 + + top_k = generation_config.get("topK") + if top_k is not None: + generation_config["topK"] = 64 + + # 3. 工具清理 + if tools: + result["tools"] = clean_tools_for_gemini(tools) + + # 4. 清理空的 parts 和未知字段(修复 400 错误:required oneof field 'data' must have one initialized field) + # 同时移除不支持的字段如 cache_control + if "contents" in result: + # 定义 part 中允许的字段集合 + ALLOWED_PART_KEYS = { + "text", "inlineData", "fileData", "functionCall", "functionResponse", + "thought", "thoughtSignature" # thinking 相关字段 + } + + cleaned_contents = [] + for content in result["contents"]: + if isinstance(content, dict) and "parts" in content: + # 过滤掉空的或无效的 parts,并移除未知字段 + valid_parts = [] + for part in content["parts"]: + if not isinstance(part, dict): + continue + + # 移除不支持的字段(如 cache_control) + cleaned_part = {k: v for k, v in part.items() if k in ALLOWED_PART_KEYS} + + # 检查 part 是否有有效的数据字段 + has_valid_data = any( + key in cleaned_part and cleaned_part[key] + for key in ["text", "inlineData", "fileData", "functionCall", "functionResponse"] + ) + if has_valid_data: + valid_parts.append(cleaned_part) + else: + log.warning(f"[GEMINI_FIX] 移除空的或无效的 part: {part}") + + # 只添加有有效 parts 的 content + if valid_parts: + cleaned_content = content.copy() + cleaned_content["parts"] = valid_parts + cleaned_contents.append(cleaned_content) + else: + log.warning(f"[GEMINI_FIX] 跳过没有有效 parts 的 content: {content.get('role')}") + else: + cleaned_contents.append(content) + + result["contents"] = cleaned_contents + + if generation_config: + result["generationConfig"] = generation_config + + return result \ No newline at end of file diff --git a/src/converter/openai2gemini.py b/src/converter/openai2gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fcad49a2490c6b4b238f07fbf2671325d6bf3b --- /dev/null +++ b/src/converter/openai2gemini.py @@ -0,0 +1,930 @@ +""" +OpenAI Transfer Module - Handles conversion between OpenAI and Gemini API formats +被openai-router调用,负责OpenAI格式与Gemini格式的双向转换 +""" + +import json +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple, Union + +from pypinyin import Style, lazy_pinyin + +from src.converter.thoughtSignature_fix import ( + encode_tool_id_with_signature, + decode_tool_id_and_signature, +) +from src.converter.utils import merge_system_messages + +from log import log + +def _convert_usage_metadata(usage_metadata: Dict[str, Any]) -> Dict[str, int]: + """ + 将Gemini的usageMetadata转换为OpenAI格式的usage字段 + + Args: + usage_metadata: Gemini API的usageMetadata字段 + + Returns: + OpenAI格式的usage字典,如果没有usage数据则返回None + """ + if not usage_metadata: + return None + + return { + "prompt_tokens": usage_metadata.get("promptTokenCount", 0), + "completion_tokens": usage_metadata.get("candidatesTokenCount", 0), + "total_tokens": usage_metadata.get("totalTokenCount", 0), + } + + +def _build_message_with_reasoning(role: str, content: str, reasoning_content: str) -> dict: + """构建包含可选推理内容的消息对象""" + message = {"role": role, "content": content} + + # 如果有thinking tokens,添加reasoning_content + if reasoning_content: + message["reasoning_content"] = reasoning_content + + return message + + +def _map_finish_reason(gemini_reason: str) -> str: + """ + 将Gemini结束原因映射到OpenAI结束原因 + + Args: + gemini_reason: 来自Gemini API的结束原因 + + Returns: + OpenAI兼容的结束原因 + """ + if gemini_reason == "STOP": + return "stop" + elif gemini_reason == "MAX_TOKENS": + return "length" + elif gemini_reason in ["SAFETY", "RECITATION"]: + return "content_filter" + else: + return None + + +# ==================== Tool Conversion Functions ==================== + + +def _normalize_function_name(name: str) -> str: + """ + 规范化函数名以符合 Gemini API 要求 + + 规则: + - 必须以字母或下划线开头 + - 只能包含 a-z, A-Z, 0-9, 下划线, 点, 短横线 + - 最大长度 64 个字符 + + 转换策略: + - 中文字符转换为拼音 + - 如果以非字母/下划线开头,添加 "_" 前缀 + - 将非法字符(空格、@、#等)替换为下划线 + - 连续的下划线合并为一个 + - 如果超过 64 个字符,截断 + + Args: + name: 原始函数名 + + Returns: + 规范化后的函数名 + """ + import re + + if not name: + return "_unnamed_function" + + # 第零步:检测并转换中文字符为拼音 + # 检查是否包含中文字符 + if re.search(r"[\u4e00-\u9fff]", name): + try: + + # 将中文转换为拼音,用下划线连接多音字 + parts = [] + for char in name: + if "\u4e00" <= char <= "\u9fff": + # 中文字符,转换为拼音 + pinyin = lazy_pinyin(char, style=Style.NORMAL) + parts.append("".join(pinyin)) + else: + # 非中文字符,保持不变 + parts.append(char) + normalized = "".join(parts) + except ImportError: + log.warning("pypinyin not installed, cannot convert Chinese characters to pinyin") + normalized = name + else: + normalized = name + + # 第一步:将非法字符替换为下划线 + # 保留:a-z, A-Z, 0-9, 下划线, 点, 短横线 + normalized = re.sub(r"[^a-zA-Z0-9_.\-]", "_", normalized) + + # 第二步:如果以非字母/下划线开头,处理首字符 + prefix_added = False + if normalized and not (normalized[0].isalpha() or normalized[0] == "_"): + if normalized[0] in ".-": + # 点和短横线在开头位置替换为下划线(它们在中间是合法的) + normalized = "_" + normalized[1:] + else: + # 其他字符(如数字)添加下划线前缀 + normalized = "_" + normalized + prefix_added = True + + # 第三步:合并连续的下划线 + normalized = re.sub(r"_+", "_", normalized) + + # 第四步:移除首尾的下划线 + # 如果原本就是下划线开头,或者我们添加了前缀,则保留开头的下划线 + if name.startswith("_") or prefix_added: + # 只移除尾部的下划线 + normalized = normalized.rstrip("_") + else: + # 移除首尾的下划线 + normalized = normalized.strip("_") + + # 第五步:确保不为空 + if not normalized: + normalized = "_unnamed_function" + + # 第六步:截断到 64 个字符 + if len(normalized) > 64: + normalized = normalized[:64] + + return normalized + + +def _clean_schema_for_gemini(schema: Any) -> Any: + """ + 清理 JSON Schema,移除 Gemini 不支持的字段 + + Gemini API 只支持有限的 OpenAPI 3.0 Schema 属性: + - 支持: type, description, enum, items, properties, required, nullable, format + - 不支持: $schema, $id, $ref, $defs, title, examples, default, readOnly, + exclusiveMaximum, exclusiveMinimum, oneOf, anyOf, allOf, const 等 + + Args: + schema: JSON Schema 对象(字典、列表或其他值) + + Returns: + 清理后的 schema + """ + if not isinstance(schema, dict): + return schema + + # Gemini 不支持的字段 + unsupported_keys = { + "$schema", + "$id", + "$ref", + "$defs", + "definitions", + "example", + "examples", + "readOnly", + "writeOnly", + "default", + "exclusiveMaximum", + "exclusiveMinimum", + "oneOf", + "anyOf", + "allOf", + "const", + "additionalItems", + "contains", + "patternProperties", + "dependencies", + "propertyNames", + "if", + "then", + "else", + "contentEncoding", + "contentMediaType", + } + + cleaned = {} + for key, value in schema.items(): + if key in unsupported_keys: + continue + if isinstance(value, dict): + cleaned[key] = _clean_schema_for_gemini(value) + elif isinstance(value, list): + cleaned[key] = [ + _clean_schema_for_gemini(item) if isinstance(item, dict) else item for item in value + ] + else: + cleaned[key] = value + + # 确保有 type 字段(如果有 properties 但没有 type) + if "properties" in cleaned and "type" not in cleaned: + cleaned["type"] = "object" + + return cleaned + + +def convert_openai_tools_to_gemini(openai_tools: List) -> List[Dict[str, Any]]: + """ + 将 OpenAI tools 格式转换为 Gemini functionDeclarations 格式 + + Args: + openai_tools: OpenAI 格式的工具列表(可能是字典或 Pydantic 模型) + + Returns: + Gemini 格式的工具列表 + """ + if not openai_tools: + return [] + + function_declarations = [] + + for tool in openai_tools: + if tool.get("type") != "function": + log.warning(f"Skipping non-function tool type: {tool.get('type')}") + continue + + function = tool.get("function") + if not function: + log.warning("Tool missing 'function' field") + continue + + # 获取并规范化函数名 + original_name = function.get("name") + if not original_name: + log.warning("Tool missing 'name' field, using default") + original_name = "_unnamed_function" + + normalized_name = _normalize_function_name(original_name) + + # 如果名称被修改了,记录日志 + if normalized_name != original_name: + log.debug(f"Function name normalized: '{original_name}' -> '{normalized_name}'") + + # 构建 Gemini function declaration + declaration = { + "name": normalized_name, + "description": function.get("description", ""), + } + + # 添加参数(如果有)- 清理不支持的 schema 字段 + if "parameters" in function: + cleaned_params = _clean_schema_for_gemini(function["parameters"]) + if cleaned_params: + declaration["parameters"] = cleaned_params + + function_declarations.append(declaration) + + if not function_declarations: + return [] + + # Gemini 格式:工具数组中包含 functionDeclarations + return [{"functionDeclarations": function_declarations}] + + +def convert_tool_choice_to_tool_config(tool_choice: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """ + 将 OpenAI tool_choice 转换为 Gemini toolConfig + + Args: + tool_choice: OpenAI 格式的 tool_choice + + Returns: + Gemini 格式的 toolConfig + """ + if isinstance(tool_choice, str): + if tool_choice == "auto": + return {"functionCallingConfig": {"mode": "AUTO"}} + elif tool_choice == "none": + return {"functionCallingConfig": {"mode": "NONE"}} + elif tool_choice == "required": + return {"functionCallingConfig": {"mode": "ANY"}} + elif isinstance(tool_choice, dict): + # {"type": "function", "function": {"name": "my_function"}} + if tool_choice.get("type") == "function": + function_name = tool_choice.get("function", {}).get("name") + if function_name: + return { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [function_name], + } + } + + # 默认返回 AUTO 模式 + return {"functionCallingConfig": {"mode": "AUTO"}} + + +def convert_tool_message_to_function_response(message, all_messages: List = None) -> Dict[str, Any]: + """ + 将 OpenAI 的 tool role 消息转换为 Gemini functionResponse + + Args: + message: OpenAI 格式的工具消息 + all_messages: 所有消息的列表,用于查找 tool_call_id 对应的函数名 + + Returns: + Gemini 格式的 functionResponse part + """ + # 获取 name 字段 + name = getattr(message, "name", None) + encoded_tool_call_id = getattr(message, "tool_call_id", None) or "" + + # 解码获取原始ID(functionResponse不需要签名) + original_tool_call_id, _ = decode_tool_id_and_signature(encoded_tool_call_id) + + # 如果没有 name,尝试从 all_messages 中查找对应的 tool_call_id + # 注意:使用编码ID查找,因为存储的是编码ID + if not name and encoded_tool_call_id and all_messages: + for msg in all_messages: + if getattr(msg, "role", None) == "assistant" and hasattr(msg, "tool_calls") and msg.tool_calls: + for tool_call in msg.tool_calls: + if getattr(tool_call, "id", None) == encoded_tool_call_id: + func = getattr(tool_call, "function", None) + if func: + name = getattr(func, "name", None) + break + if name: + break + + # 最终兜底:如果仍然没有 name,使用默认值 + if not name: + name = "unknown_function" + log.warning(f"Tool message missing function name, using default: {name}") + + try: + # 尝试将 content 解析为 JSON + response_data = ( + json.loads(message.content) if isinstance(message.content, str) else message.content + ) + except (json.JSONDecodeError, TypeError): + # 如果不是有效的 JSON,包装为对象 + response_data = {"result": str(message.content)} + + return {"functionResponse": {"id": original_tool_call_id, "name": name, "response": response_data}} + + +def extract_tool_calls_from_parts( + parts: List[Dict[str, Any]], is_streaming: bool = False +) -> Tuple[List[Dict[str, Any]], str]: + """ + 从 Gemini response parts 中提取工具调用和文本内容 + + Args: + parts: Gemini response 的 parts 数组 + is_streaming: 是否为流式响应(流式响应需要添加 index 字段) + + Returns: + (tool_calls, text_content) 元组 + """ + tool_calls = [] + text_content = "" + + for idx, part in enumerate(parts): + # 检查是否是函数调用 + if "functionCall" in part: + function_call = part["functionCall"] + # 获取原始ID或生成新ID + original_id = function_call.get("id") or f"call_{uuid.uuid4().hex[:24]}" + # 将thoughtSignature编码到ID中以便往返保留 + signature = part.get("thoughtSignature") + encoded_id = encode_tool_id_with_signature(original_id, signature) + + tool_call = { + "id": encoded_id, + "type": "function", + "function": { + "name": function_call.get("name", "nameless_function"), + "arguments": json.dumps(function_call.get("args", {})), + }, + } + # 流式响应需要 index 字段 + if is_streaming: + tool_call["index"] = idx + tool_calls.append(tool_call) + + # 提取文本内容(排除 thinking tokens) + elif "text" in part and not part.get("thought", False): + text_content += part["text"] + + return tool_calls, text_content + + +def extract_images_from_content(content: Any) -> Dict[str, Any]: + """ + 从 OpenAI content 中提取文本和图片 + + Args: + content: OpenAI 消息的 content 字段(可能是字符串或列表) + + Returns: + 包含 text 和 images 的字典 + """ + result = {"text": "", "images": []} + + if isinstance(content, str): + result["text"] = content + elif isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + result["text"] += item.get("text", "") + elif item.get("type") == "image_url": + image_url = item.get("image_url", {}).get("url", "") + # 解析  格式 + if image_url.startswith("data:image/"): + import re + match = re.match(r"^data:image/(\w+);base64,(.+)$", image_url) + if match: + mime_type = match.group(1) + base64_data = match.group(2) + result["images"].append({ + "inlineData": { + "mimeType": f"image/{mime_type}", + "data": base64_data + } + }) + + return result + +async def convert_openai_to_gemini_request(openai_request: Dict[str, Any]) -> Dict[str, Any]: + """ + 将 OpenAI 格式请求体转换为 Gemini 格式请求体 + + 注意: 此函数只负责基础转换,不包含 normalize_gemini_request 中的处理 + (如 thinking config, search tools, 参数范围限制等) + + Args: + openai_request: OpenAI 格式的请求体字典,包含: + - messages: 消息列表 + - temperature, top_p, max_tokens, stop 等生成参数 + - tools, tool_choice (可选) + - response_format (可选) + + Returns: + Gemini 格式的请求体字典,包含: + - contents: 转换后的消息内容 + - generationConfig: 生成配置 + - systemInstruction: 系统指令 (如果有) + - tools, toolConfig (如果有) + """ + # 处理连续的system消息(兼容性模式) + openai_request = await merge_system_messages(openai_request) + + contents = [] + + # 提取消息列表 + messages = openai_request.get("messages", []) + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + + # 处理工具消息(tool role) + if role == "tool": + tool_call_id = message.get("tool_call_id", "") + func_name = message.get("name") + + # 如果没有name,尝试从消息列表中查找 + if not func_name and tool_call_id: + for msg in messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + if tc.get("id") == tool_call_id: + func_name = tc.get("function", {}).get("name") + break + if func_name: + break + + if not func_name: + func_name = "unknown_function" + + # 解析响应数据 + try: + response_data = json.loads(content) if isinstance(content, str) else content + except (json.JSONDecodeError, TypeError): + response_data = {"result": str(content)} + + contents.append({ + "role": "user", + "parts": [{ + "functionResponse": { + "id": tool_call_id, + "name": func_name, + "response": response_data + } + }] + }) + continue + + # system 消息已经由 merge_system_messages 处理,这里跳过 + if role == "system": + continue + + # 将OpenAI角色映射到Gemini角色 + if role == "assistant": + role = "model" + + # 检查是否有tool_calls + tool_calls = message.get("tool_calls") + if tool_calls: + parts = [] + + # 如果有文本内容,先添加文本 + if content: + parts.append({"text": content}) + + # 添加每个工具调用 + for tool_call in tool_calls: + try: + args = ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ) + + # 解码工具ID和thoughtSignature + encoded_id = tool_call.get("id", "") + original_id, signature = decode_tool_id_and_signature(encoded_id) + + # 构建functionCall part + function_call_part = { + "functionCall": { + "id": original_id, + "name": tool_call["function"]["name"], + "args": args + } + } + + # 如果有thoughtSignature,添加到part中 + if signature: + function_call_part["thoughtSignature"] = signature + + parts.append(function_call_part) + except (json.JSONDecodeError, KeyError) as e: + log.error(f"Failed to parse tool call: {e}") + continue + + if parts: + contents.append({"role": role, "parts": parts}) + continue + + # 处理普通内容 + if isinstance(content, list): + parts = [] + for part in content: + if part.get("type") == "text": + parts.append({"text": part.get("text", "")}) + elif part.get("type") == "image_url": + image_url = part.get("image_url", {}).get("url") + if image_url: + try: + mime_type, base64_data = image_url.split(";") + _, mime_type = mime_type.split(":") + _, base64_data = base64_data.split(",") + parts.append({ + "inlineData": { + "mimeType": mime_type, + "data": base64_data, + } + }) + except ValueError: + continue + if parts: + contents.append({"role": role, "parts": parts}) + elif content: + contents.append({"role": role, "parts": [{"text": content}]}) + + # 构建生成配置 + generation_config = {} + if "temperature" in openai_request: + generation_config["temperature"] = openai_request["temperature"] + if "top_p" in openai_request: + generation_config["topP"] = openai_request["top_p"] + if "max_tokens" in openai_request: + generation_config["maxOutputTokens"] = openai_request["max_tokens"] + if "stop" in openai_request: + stop = openai_request["stop"] + generation_config["stopSequences"] = [stop] if isinstance(stop, str) else stop + if "frequency_penalty" in openai_request: + generation_config["frequencyPenalty"] = openai_request["frequency_penalty"] + if "presence_penalty" in openai_request: + generation_config["presencePenalty"] = openai_request["presence_penalty"] + if "n" in openai_request: + generation_config["candidateCount"] = openai_request["n"] + if "seed" in openai_request: + generation_config["seed"] = openai_request["seed"] + if "response_format" in openai_request and openai_request["response_format"]: + if openai_request["response_format"].get("type") == "json_object": + generation_config["responseMimeType"] = "application/json" + + # 如果contents为空,添加默认用户消息 + if not contents: + contents.append({"role": "user", "parts": [{"text": "请根据系统指令回答。"}]}) + + # 构建基础请求 + gemini_request = { + "contents": contents, + "generationConfig": generation_config + } + + # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 + if "systemInstruction" in openai_request: + gemini_request["systemInstruction"] = openai_request["systemInstruction"] + + # 处理工具 + if "tools" in openai_request and openai_request["tools"]: + gemini_request["tools"] = convert_openai_tools_to_gemini(openai_request["tools"]) + + # 处理tool_choice + if "tool_choice" in openai_request and openai_request["tool_choice"]: + gemini_request["toolConfig"] = convert_tool_choice_to_tool_config(openai_request["tool_choice"]) + + return gemini_request + + +def convert_gemini_to_openai_response( + gemini_response: Union[Dict[str, Any], Any], + model: str, + status_code: int = 200 +) -> Dict[str, Any]: + """ + 将 Gemini 格式非流式响应转换为 OpenAI 格式非流式响应 + + 注意: 如果收到的不是 200 开头的响应,不做任何处理,直接转发原始响应 + + Args: + gemini_response: Gemini 格式的响应体 (字典或响应对象) + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Returns: + OpenAI 格式的响应体字典,或原始响应 (如果状态码不是 2xx) + """ + # 非 2xx 状态码直接返回原始响应 + if not (200 <= status_code < 300): + if isinstance(gemini_response, dict): + return gemini_response + else: + # 如果是响应对象,尝试解析为字典 + try: + if hasattr(gemini_response, "json"): + return gemini_response.json() + elif hasattr(gemini_response, "body"): + body = gemini_response.body + if isinstance(body, bytes): + return json.loads(body.decode()) + return json.loads(str(body)) + else: + return {"error": str(gemini_response)} + except: + return {"error": str(gemini_response)} + + # 确保是字典格式 + if not isinstance(gemini_response, dict): + try: + if hasattr(gemini_response, "json"): + gemini_response = gemini_response.json() + elif hasattr(gemini_response, "body"): + body = gemini_response.body + if isinstance(body, bytes): + gemini_response = json.loads(body.decode()) + else: + gemini_response = json.loads(str(body)) + else: + gemini_response = json.loads(str(gemini_response)) + except: + return {"error": "Invalid response format"} + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_response: + gemini_response = gemini_response["response"] + + # 转换为 OpenAI 格式 + choices = [] + + for candidate in gemini_response.get("candidates", []): + role = candidate.get("content", {}).get("role", "assistant") + + # 将Gemini角色映射回OpenAI角色 + if role == "model": + role = "assistant" + + # 提取并分离thinking tokens和常规内容 + parts = candidate.get("content", {}).get("parts", []) + + # 提取工具调用和文本内容 + tool_calls, text_content = extract_tool_calls_from_parts(parts) + + # 提取图片数据 + images = [] + for part in parts: + if "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + images.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_data}" + } + }) + + # 提取 reasoning content + reasoning_content = "" + for part in parts: + if part.get("thought", False) and "text" in part: + reasoning_content += part["text"] + + # 构建消息对象 + message = {"role": role} + + # 如果有工具调用 + if tool_calls: + message["tool_calls"] = tool_calls + message["content"] = text_content if text_content else None + finish_reason = "tool_calls" + # 如果有图片 + elif images: + content_list = [] + if text_content: + content_list.append({"type": "text", "text": text_content}) + content_list.extend(images) + message["content"] = content_list + finish_reason = _map_finish_reason(candidate.get("finishReason")) + else: + message["content"] = text_content + finish_reason = _map_finish_reason(candidate.get("finishReason")) + + # 添加 reasoning content (如果有) + if reasoning_content: + message["reasoning_content"] = reasoning_content + + choices.append({ + "index": candidate.get("index", 0), + "message": message, + "finish_reason": finish_reason, + }) + + # 转换 usageMetadata + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + response_data = { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": choices, + } + + if usage: + response_data["usage"] = usage + + return response_data + + +def convert_gemini_to_openai_stream( + gemini_stream_chunk: str, + model: str, + response_id: str, + status_code: int = 200 +) -> Optional[str]: + """ + 将 Gemini 格式流式响应块转换为 OpenAI SSE 格式流式响应 + + 注意: 如果收到的不是 200 开头的响应,不做任何处理,直接转发原始内容 + + Args: + gemini_stream_chunk: Gemini 格式的流式响应块 (字符串,通常是 "data: {json}" 格式) + model: 模型名称 + response_id: 此流式响应的一致ID + status_code: HTTP 状态码 (默认 200) + + Returns: + OpenAI SSE 格式的响应字符串 (如 "data: {json}\n\n"), + 或原始内容 (如果状态码不是 2xx), + 或 None (如果解析失败) + """ + # 非 2xx 状态码直接返回原始内容 + if not (200 <= status_code < 300): + return gemini_stream_chunk + + # 解析 Gemini 流式块 + try: + # 去除 "data: " 前缀 + if isinstance(gemini_stream_chunk, bytes): + if gemini_stream_chunk.startswith(b"data: "): + payload_str = gemini_stream_chunk[len(b"data: "):].strip().decode("utf-8") + else: + payload_str = gemini_stream_chunk.strip().decode("utf-8") + else: + if gemini_stream_chunk.startswith("data: "): + payload_str = gemini_stream_chunk[len("data: "):].strip() + else: + payload_str = gemini_stream_chunk.strip() + + # 跳过空块 + if not payload_str: + return None + + # 解析 JSON + gemini_chunk = json.loads(payload_str) + except (json.JSONDecodeError, UnicodeDecodeError): + # 解析失败,跳过此块 + return None + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_chunk: + gemini_response = gemini_chunk["response"] + else: + gemini_response = gemini_chunk + + # 转换为 OpenAI 流式格式 + choices = [] + + for candidate in gemini_response.get("candidates", []): + role = candidate.get("content", {}).get("role", "assistant") + + # 将Gemini角色映射回OpenAI角色 + if role == "model": + role = "assistant" + + # 提取并分离thinking tokens和常规内容 + parts = candidate.get("content", {}).get("parts", []) + + # 提取工具调用和文本内容 (流式需要 index) + tool_calls, text_content = extract_tool_calls_from_parts(parts, is_streaming=True) + + # 提取图片数据 + images = [] + for part in parts: + if "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + images.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_data}" + } + }) + + # 提取 reasoning content + reasoning_content = "" + for part in parts: + if part.get("thought", False) and "text" in part: + reasoning_content += part["text"] + + # 构建 delta 对象 + delta = {} + + if tool_calls: + delta["tool_calls"] = tool_calls + if text_content: + delta["content"] = text_content + elif images: + # 流式响应中的图片: 以 markdown 格式返回 + markdown_images = [f"![Generated Image]({img['image_url']['url']})" for img in images] + if text_content: + delta["content"] = text_content + "\n\n" + "\n\n".join(markdown_images) + else: + delta["content"] = "\n\n".join(markdown_images) + elif text_content: + delta["content"] = text_content + + if reasoning_content: + delta["reasoning_content"] = reasoning_content + + finish_reason = _map_finish_reason(candidate.get("finishReason")) + if finish_reason and tool_calls: + finish_reason = "tool_calls" + + choices.append({ + "index": candidate.get("index", 0), + "delta": delta, + "finish_reason": finish_reason, + }) + + # 转换 usageMetadata (只在流结束时存在) + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + # 构建 OpenAI 流式响应 + response_data = { + "id": response_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": choices, + } + + # 只在有 usage 数据且有 finish_reason 时添加 usage + if usage: + has_finish_reason = any(choice.get("finish_reason") for choice in choices) + if has_finish_reason: + response_data["usage"] = usage + + # 转换为 SSE 格式: "data: {json}\n\n" + return f"data: {json.dumps(response_data)}\n\n" diff --git a/src/converter/thoughtSignature_fix.py b/src/converter/thoughtSignature_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..caff9bfc41e86ba26d895cc8ff8447a1424b59a6 --- /dev/null +++ b/src/converter/thoughtSignature_fix.py @@ -0,0 +1,56 @@ +""" +thoughtSignature 处理公共模块 + +提供统一的 thoughtSignature 编码/解码功能,用于在工具调用ID中保留签名信息。 +这使得签名能够在客户端往返传输中保留,即使客户端会删除自定义字段。 +""" + +from typing import Optional, Tuple + +# 在工具调用ID中嵌入thoughtSignature的分隔符 +# 这使得签名能够在客户端往返传输中保留,即使客户端会删除自定义字段 +THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + + +def encode_tool_id_with_signature(tool_id: str, signature: Optional[str]) -> str: + """ + 将 thoughtSignature 编码到工具调用ID中,以便往返保留。 + + Args: + tool_id: 原始工具调用ID + signature: thoughtSignature(可选) + + Returns: + 编码后的工具调用ID + + Examples: + >>> encode_tool_id_with_signature("call_123", "abc") + 'call_123__thought__abc' + >>> encode_tool_id_with_signature("call_123", None) + 'call_123' + """ + if not signature: + return tool_id + return f"{tool_id}{THOUGHT_SIGNATURE_SEPARATOR}{signature}" + + +def decode_tool_id_and_signature(encoded_id: str) -> Tuple[str, Optional[str]]: + """ + 从编码的ID中提取原始工具ID和thoughtSignature。 + + Args: + encoded_id: 编码的工具调用ID + + Returns: + (原始工具ID, thoughtSignature) 元组 + + Examples: + >>> decode_tool_id_and_signature("call_123__thought__abc") + ('call_123', 'abc') + >>> decode_tool_id_and_signature("call_123") + ('call_123', None) + """ + if not encoded_id or THOUGHT_SIGNATURE_SEPARATOR not in encoded_id: + return encoded_id, None + parts = encoded_id.split(THOUGHT_SIGNATURE_SEPARATOR, 1) + return parts[0], parts[1] if len(parts) == 2 else None diff --git a/src/converter/utils.py b/src/converter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99469ea0717dedd5268a47016ec9e03430d251e1 --- /dev/null +++ b/src/converter/utils.py @@ -0,0 +1,231 @@ +from typing import Any, Dict + + +def extract_content_and_reasoning(parts: list) -> tuple: + """从Gemini响应部件中提取内容和推理内容 + + Args: + parts: Gemini 响应中的 parts 列表 + + Returns: + (content, reasoning_content, images): 文本内容、推理内容和图片数据的元组 + - content: 文本内容字符串 + - reasoning_content: 推理内容字符串 + - images: 图片数据列表,每个元素格式为: + { + "type": "image_url", + "image_url": { + "url": "data:{mime_type};base64,{base64_data}" + } + } + """ + content = "" + reasoning_content = "" + images = [] + + for part in parts: + # 提取文本内容 + text = part.get("text", "") + if text: + if part.get("thought", False): + reasoning_content += text + else: + content += text + + # 提取图片数据 + if "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + images.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_data}" + } + }) + + return content, reasoning_content, images + + +async def merge_system_messages(request_body: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据兼容性模式处理请求体中的system消息 + + - 兼容性模式关闭(False):将连续的system消息合并为systemInstruction + - 兼容性模式开启(True):将所有system消息转换为user消息 + + Args: + request_body: OpenAI或Claude格式的请求体,包含messages字段 + + Returns: + 处理后的请求体 + + Example (兼容性模式关闭): + 输入: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "You are an expert in Python."}, + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "systemInstruction": { + "parts": [ + {"text": "You are a helpful assistant."}, + {"text": "You are an expert in Python."} + ] + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + + Example (兼容性模式开启): + 输入: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "messages": [ + {"role": "user", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"} + ] + } + + Example (Anthropic格式,兼容性模式关闭): + 输入: + { + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "systemInstruction": { + "parts": [ + {"text": "You are a helpful assistant."} + ] + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + """ + from config import get_compatibility_mode_enabled + + compatibility_mode = await get_compatibility_mode_enabled() + + # 处理 Anthropic 格式的顶层 system 参数 + # Anthropic API 规范: system 是顶层参数,不在 messages 中 + system_content = request_body.get("system") + if system_content and "systemInstruction" not in request_body: + system_parts = [] + + if isinstance(system_content, str): + if system_content.strip(): + system_parts.append({"text": system_content}) + elif isinstance(system_content, list): + # system 可以是包含多个块的列表 + for item in system_content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text", "").strip(): + system_parts.append({"text": item["text"]}) + elif isinstance(item, str) and item.strip(): + system_parts.append({"text": item}) + + if system_parts: + if compatibility_mode: + # 兼容性模式:将 system 转换为 user 消息插入到 messages 开头 + user_system_message = { + "role": "user", + "content": system_content if isinstance(system_content, str) else + "\n".join(part["text"] for part in system_parts) + } + messages = request_body.get("messages", []) + request_body = request_body.copy() + request_body["messages"] = [user_system_message] + messages + else: + # 非兼容性模式:添加为 systemInstruction + request_body = request_body.copy() + request_body["systemInstruction"] = {"parts": system_parts} + + messages = request_body.get("messages", []) + if not messages: + return request_body + + compatibility_mode = await get_compatibility_mode_enabled() + + if compatibility_mode: + # 兼容性模式开启:将所有system消息转换为user消息 + converted_messages = [] + for message in messages: + if message.get("role") == "system": + # 创建新的消息对象,将role改为user + converted_message = message.copy() + converted_message["role"] = "user" + converted_messages.append(converted_message) + else: + converted_messages.append(message) + + result = request_body.copy() + result["messages"] = converted_messages + return result + else: + # 兼容性模式关闭:提取连续的system消息合并为systemInstruction + system_parts = [] + + # 如果已经从顶层 system 参数创建了 systemInstruction,获取现有的 parts + if "systemInstruction" in request_body: + existing_instruction = request_body.get("systemInstruction", {}) + if isinstance(existing_instruction, dict): + system_parts = existing_instruction.get("parts", []).copy() + + remaining_messages = [] + collecting_system = True + + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + + if role == "system" and collecting_system: + # 提取system消息的文本内容 + if isinstance(content, str): + if content.strip(): + system_parts.append({"text": content}) + elif isinstance(content, list): + # 处理列表格式的content + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text", "").strip(): + system_parts.append({"text": item["text"]}) + elif isinstance(item, str) and item.strip(): + system_parts.append({"text": item}) + else: + # 遇到非system消息,停止收集 + collecting_system = False + remaining_messages.append(message) + + # 如果没有找到任何system消息(包括顶层参数和messages中的),返回原始请求体 + if not system_parts: + return request_body + + # 构建新的请求体 + result = request_body.copy() + + # 添加或更新systemInstruction + result["systemInstruction"] = {"parts": system_parts} + + # 更新messages列表(移除已处理的system消息) + result["messages"] = remaining_messages + + return result \ No newline at end of file diff --git a/src/credential_manager.py b/src/credential_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ab86732cf32e0fe1ab019abb5b0d551757658e42 --- /dev/null +++ b/src/credential_manager.py @@ -0,0 +1,521 @@ +""" +凭证管理器 +""" + +import asyncio +import time +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple + +from log import log + +from .google_oauth_api import Credentials +from .storage_adapter import get_storage_adapter + +class CredentialManager: + """ + 统一凭证管理器 + 所有存储操作通过storage_adapter进行 + """ + + def __init__(self): + # 核心状态 + self._initialized = False + self._storage_adapter = None + + # 并发控制(简化) + self._operation_lock = asyncio.Lock() + + async def _ensure_initialized(self): + """确保管理器已初始化(内部使用)""" + if not self._initialized or self._storage_adapter is None: + await self.initialize() + + async def initialize(self): + """初始化凭证管理器""" + async with self._operation_lock: + if self._initialized and self._storage_adapter is not None: + return + + # 初始化统一存储适配器 + self._storage_adapter = await get_storage_adapter() + self._initialized = True + + async def close(self): + """清理资源""" + log.debug("Closing credential manager...") + self._initialized = False + log.debug("Credential manager closed") + + async def get_valid_credential( + self, mode: str = "geminicli", model_key: Optional[str] = None + ) -> Optional[Tuple[str, Dict[str, Any]]]: + """ + 获取有效的凭证 - 随机负载均衡版 + 每次随机选择一个可用的凭证(未禁用、未冷却) + 如果刷新失败会自动禁用失效凭证并重试获取下一个可用凭证 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_key: 模型键,用于模型级冷却检查 + - antigravity: 模型名称(如 "gemini-2.0-flash-exp") + - gcli: "pro" 或 "flash" + """ + await self._ensure_initialized() + + # 最多重试3次 + max_retries = 3 + for attempt in range(max_retries): + result = await self._storage_adapter._backend.get_next_available_credential( + mode=mode, model_key=model_key + ) + + # 如果没有可用凭证,直接返回None + if not result: + if attempt == 0: + log.warning(f"没有可用凭证 (mode={mode}, model_key={model_key})") + return None + + filename, credential_data = result + + # Token 刷新检查 + if await self._should_refresh_token(credential_data): + log.debug(f"Token需要刷新 - 文件: {filename} (mode={mode})") + refreshed_data = await self._refresh_token(credential_data, filename, mode=mode) + if refreshed_data: + # 刷新成功,返回凭证 + credential_data = refreshed_data + log.debug(f"Token刷新成功: {filename} (mode={mode})") + return filename, credential_data + else: + # 刷新失败(_refresh_token内部已自动禁用失效凭证) + log.warning(f"Token刷新失败,尝试获取下一个凭证: {filename} (mode={mode}, attempt={attempt+1}/{max_retries})") + # 继续循环,尝试获取下一个可用凭证 + continue + else: + # Token有效,直接返回 + return filename, credential_data + + # 重试次数用尽 + log.error(f"重试{max_retries}次后仍无可用凭证 (mode={mode}, model_key={model_key})") + return None + + async def add_credential(self, credential_name: str, credential_data: Dict[str, Any]): + """ + 新增或更新一个凭证 + 存储层会自动处理轮换顺序 + """ + await self._ensure_initialized() + async with self._operation_lock: + await self._storage_adapter.store_credential(credential_name, credential_data) + log.info(f"Credential added/updated: {credential_name}") + + async def add_antigravity_credential(self, credential_name: str, credential_data: Dict[str, Any]): + """ + 新增或更新一个Antigravity凭证 + 存储层会自动处理轮换顺序 + """ + await self._ensure_initialized() + async with self._operation_lock: + await self._storage_adapter.store_credential(credential_name, credential_data, mode="antigravity") + log.info(f"Antigravity credential added/updated: {credential_name}") + + async def remove_credential(self, credential_name: str, mode: str = "geminicli") -> bool: + """删除一个凭证""" + await self._ensure_initialized() + async with self._operation_lock: + try: + await self._storage_adapter.delete_credential(credential_name, mode=mode) + log.info(f"Credential removed: {credential_name} (mode={mode})") + return True + except Exception as e: + log.error(f"Error removing credential {credential_name}: {e}") + return False + + async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any], mode: str = "geminicli"): + """更新凭证状态""" + log.debug(f"[CredMgr] update_credential_state 开始: credential_name={credential_name}, state_updates={state_updates}, mode={mode}") + log.debug(f"[CredMgr] 调用 _ensure_initialized...") + await self._ensure_initialized() + log.debug(f"[CredMgr] _ensure_initialized 完成") + try: + log.debug(f"[CredMgr] 调用 storage_adapter.update_credential_state...") + success = await self._storage_adapter.update_credential_state( + credential_name, state_updates, mode=mode + ) + log.debug(f"[CredMgr] storage_adapter.update_credential_state 返回: {success}") + if success: + log.debug(f"Updated credential state: {credential_name} (mode={mode})") + else: + log.warning(f"Failed to update credential state: {credential_name} (mode={mode})") + return success + except Exception as e: + log.error(f"Error updating credential state {credential_name}: {e}", exc_info=True) + return False + + async def set_cred_disabled(self, credential_name: str, disabled: bool, mode: str = "geminicli"): + """设置凭证的启用/禁用状态""" + try: + log.info(f"[CredMgr] set_cred_disabled 开始: credential_name={credential_name}, disabled={disabled}, mode={mode}") + success = await self.update_credential_state( + credential_name, {"disabled": disabled}, mode=mode + ) + log.info(f"[CredMgr] update_credential_state 返回: success={success}") + if success: + action = "disabled" if disabled else "enabled" + log.info(f"Credential {action}: {credential_name} (mode={mode})") + else: + log.warning(f"[CredMgr] 设置禁用状态失败: credential_name={credential_name}, disabled={disabled}") + return success + except Exception as e: + log.error(f"Error setting credential disabled state {credential_name}: {e}") + return False + + async def get_creds_status(self) -> Dict[str, Dict[str, Any]]: + """获取所有凭证的状态""" + await self._ensure_initialized() + try: + return await self._storage_adapter.get_all_credential_states() + except Exception as e: + log.error(f"Error getting credential statuses: {e}") + return {} + + async def get_creds_summary(self) -> List[Dict[str, Any]]: + """ + 获取所有凭证的摘要信息(轻量级,不包含完整凭证数据) + 优先使用后端的高性能查询 + """ + await self._ensure_initialized() + try: + # 如果后端支持高性能摘要查询,直接使用 + if hasattr(self._storage_adapter._backend, 'get_credentials_summary'): + return await self._storage_adapter._backend.get_credentials_summary() + + # 否则回退到传统方式 + all_states = await self._storage_adapter.get_all_credential_states() + summaries = [] + + import time + current_time = time.time() + + for filename, state in all_states.items(): + summaries.append({ + "filename": filename, + "disabled": state.get("disabled", False), + "error_codes": state.get("error_codes", []), + "last_success": state.get("last_success", current_time), + "user_email": state.get("user_email"), + "model_cooldowns": state.get("model_cooldowns", {}), + }) + + return summaries + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return [] + + async def get_or_fetch_user_email(self, credential_name: str, mode: str = "geminicli") -> Optional[str]: + """获取或获取用户邮箱地址""" + try: + # 确保已初始化 + await self._ensure_initialized() + + # 从状态中获取缓存的邮箱 + state = await self._storage_adapter.get_credential_state(credential_name, mode=mode) + cached_email = state.get("user_email") if state else None + + if cached_email: + return cached_email + + # 如果没有缓存,从凭证数据获取 + credential_data = await self._storage_adapter.get_credential(credential_name, mode=mode) + if not credential_data: + return None + + # 创建凭证对象并自动刷新 token + from .google_oauth_api import Credentials, get_user_email + + credentials = Credentials.from_dict(credential_data) + if not credentials: + return None + + # 自动刷新 token(如果需要) + token_refreshed = await credentials.refresh_if_needed() + + # 如果 token 被刷新了,更新存储 + if token_refreshed: + log.info(f"Token已自动刷新: {credential_name} (mode={mode})") + updated_data = credentials.to_dict() + await self._storage_adapter.store_credential(credential_name, updated_data, mode=mode) + + # 获取邮箱 + email = await get_user_email(credentials) + + if email: + # 缓存邮箱地址 + await self._storage_adapter.update_credential_state( + credential_name, {"user_email": email}, mode=mode + ) + return email + + return None + + except Exception as e: + log.error(f"Error fetching user email for {credential_name}: {e}") + return None + + async def record_api_call_result( + self, + credential_name: str, + success: bool, + error_code: Optional[int] = None, + cooldown_until: Optional[float] = None, + mode: str = "geminicli", + model_key: Optional[str] = None + ): + """ + 记录API调用结果 + + Args: + credential_name: 凭证名称 + success: 是否成功 + error_code: 错误码(如果失败) + cooldown_until: 冷却截止时间戳(Unix时间戳,针对429 QUOTA_EXHAUSTED) + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_key: 模型键(用于设置模型级冷却) + """ + await self._ensure_initialized() + try: + state_updates = {} + + if success: + state_updates["last_success"] = time.time() + # 清除错误码 + state_updates["error_codes"] = [] + + # 如果提供了 model_key,清除该模型的冷却 + if model_key: + if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): + await self._storage_adapter._backend.set_model_cooldown( + credential_name, model_key, None, mode=mode + ) + + elif error_code: + # 记录错误码 + current_state = await self._storage_adapter.get_credential_state(credential_name, mode=mode) + error_codes = current_state.get("error_codes", []) + + if error_code not in error_codes: + error_codes.append(error_code) + # 限制错误码列表长度 + if len(error_codes) > 10: + error_codes = error_codes[-10:] + + state_updates["error_codes"] = error_codes + + # 如果提供了冷却时间和模型键,设置模型级冷却 + if cooldown_until is not None and model_key: + if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): + await self._storage_adapter._backend.set_model_cooldown( + credential_name, model_key, cooldown_until, mode=mode + ) + log.info( + f"设置模型级冷却: {credential_name}, model_key={model_key}, " + f"冷却至: {datetime.fromtimestamp(cooldown_until, timezone.utc).isoformat()}" + ) + + if state_updates: + await self.update_credential_state(credential_name, state_updates, mode=mode) + + except Exception as e: + log.error(f"Error recording API call result for {credential_name}: {e}") + + async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: + """检查token是否需要刷新""" + try: + # 如果没有access_token或过期时间,需要刷新 + if not credential_data.get("access_token") and not credential_data.get("token"): + log.debug("没有access_token,需要刷新") + return True + + expiry_str = credential_data.get("expiry") + if not expiry_str: + log.debug("没有过期时间,需要刷新") + return True + + # 解析过期时间 + try: + if isinstance(expiry_str, str): + if "+" in expiry_str: + file_expiry = datetime.fromisoformat(expiry_str) + elif expiry_str.endswith("Z"): + file_expiry = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + else: + file_expiry = datetime.fromisoformat(expiry_str) + else: + log.debug("过期时间格式无效,需要刷新") + return True + + # 确保时区信息 + if file_expiry.tzinfo is None: + file_expiry = file_expiry.replace(tzinfo=timezone.utc) + + # 检查是否还有至少5分钟有效期 + now = datetime.now(timezone.utc) + time_left = (file_expiry - now).total_seconds() + + log.debug( + f"Token时间检查: " + f"当前UTC时间={now.isoformat()}, " + f"过期时间={file_expiry.isoformat()}, " + f"剩余时间={int(time_left/60)}分{int(time_left%60)}秒" + ) + + if time_left > 300: # 5分钟缓冲 + return False + else: + log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新") + return True + + except Exception as e: + log.warning(f"解析过期时间失败: {e},需要刷新") + return True + + except Exception as e: + log.error(f"检查token过期时出错: {e}") + return True + + async def _refresh_token( + self, credential_data: Dict[str, Any], filename: str, mode: str = "geminicli" + ) -> Optional[Dict[str, Any]]: + """刷新token并更新存储""" + await self._ensure_initialized() + try: + # 创建Credentials对象 + creds = Credentials.from_dict(credential_data) + + # 检查是否可以刷新 + if not creds.refresh_token: + log.error(f"没有refresh_token,无法刷新: {filename} (mode={mode})") + # 自动禁用没有refresh_token的凭证 + try: + await self.update_credential_state(filename, {"disabled": True}, mode=mode) + log.warning(f"凭证已自动禁用(缺少refresh_token): {filename}") + except Exception as e: + log.error(f"禁用凭证失败 {filename}: {e}") + return None + + # 刷新token + log.debug(f"正在刷新token: {filename} (mode={mode})") + await creds.refresh() + + # 更新凭证数据 + if creds.access_token: + credential_data["access_token"] = creds.access_token + # 保持兼容性 + credential_data["token"] = creds.access_token + + if creds.expires_at: + credential_data["expiry"] = creds.expires_at.isoformat() + + # 保存到存储 + await self._storage_adapter.store_credential(filename, credential_data, mode=mode) + log.info(f"Token刷新成功并已保存: {filename} (mode={mode})") + + return credential_data + + except Exception as e: + error_msg = str(e) + log.error(f"Token刷新失败 {filename} (mode={mode}): {error_msg}") + + # 尝试提取HTTP状态码(TokenError可能携带status_code属性) + status_code = None + if hasattr(e, 'status_code'): + status_code = e.status_code + + # 检查是否是凭证永久失效的错误(只有明确的400/403等才判定为永久失效) + is_permanent_failure = self._is_permanent_refresh_failure(error_msg, status_code) + + if is_permanent_failure: + log.warning(f"检测到凭证永久失效 (HTTP {status_code}): {filename}") + # 记录失效状态 + if status_code: + await self.record_api_call_result(filename, False, status_code, mode=mode) + else: + await self.record_api_call_result(filename, False, 400, mode=mode) + + # 禁用失效凭证 + try: + # 直接禁用该凭证(随机选择机制会自动跳过它) + disabled_ok = await self.update_credential_state(filename, {"disabled": True}, mode=mode) + if disabled_ok: + log.warning(f"永久失效凭证已禁用: {filename}") + else: + log.warning("永久失效凭证禁用失败,将由上层逻辑继续处理") + except Exception as e2: + log.error(f"禁用永久失效凭证时出错 {filename}: {e2}") + else: + # 网络错误或其他临时性错误,不封禁凭证 + log.warning(f"Token刷新失败但非永久性错误 (HTTP {status_code}),不封禁凭证: {filename}") + + return None + + def _is_permanent_refresh_failure(self, error_msg: str, status_code: Optional[int] = None) -> bool: + """ + 判断是否是凭证永久失效的错误 + + Args: + error_msg: 错误信息 + status_code: HTTP状态码(如果有) + + Returns: + True表示凭证永久失效应封禁,False表示临时错误不应封禁 + """ + # 优先使用HTTP状态码判断 + if status_code is not None: + # 400/401/403 明确表示凭证有问题,应该封禁 + if status_code in [400, 401, 403]: + log.debug(f"检测到客户端错误状态码 {status_code},判定为永久失效") + return True + # 500/502/503/504 是服务器错误,不应封禁凭证 + elif status_code in [500, 502, 503, 504]: + log.debug(f"检测到服务器错误状态码 {status_code},不应封禁凭证") + return False + # 429 (限流) 不应封禁凭证 + elif status_code == 429: + log.debug("检测到限流错误 429,不应封禁凭证") + return False + + # 如果没有状态码,回退到错误信息匹配(谨慎判断) + # 只有明确的凭证失效错误才判定为永久失效 + permanent_error_patterns = [ + "invalid_grant", + "refresh_token_expired", + "invalid_refresh_token", + "unauthorized_client", + "access_denied", + ] + + error_msg_lower = error_msg.lower() + for pattern in permanent_error_patterns: + if pattern.lower() in error_msg_lower: + log.debug(f"错误信息匹配到永久失效模式: {pattern}") + return True + + # 默认认为是临时错误(如网络问题),不应封禁凭证 + log.debug("未匹配到明确的永久失效模式,判定为临时错误") + return False + +# 全局实例管理(保持兼容性) +_credential_manager: Optional[CredentialManager] = None + + +async def get_credential_manager() -> CredentialManager: + """获取全局凭证管理器实例""" + global _credential_manager + + if _credential_manager is None: + _credential_manager = CredentialManager() + await _credential_manager.initialize() + + return _credential_manager diff --git a/src/google_oauth_api.py b/src/google_oauth_api.py new file mode 100644 index 0000000000000000000000000000000000000000..af0b2404b8169daa86ade8721415c64af4c3a196 --- /dev/null +++ b/src/google_oauth_api.py @@ -0,0 +1,781 @@ +""" +Google OAuth2 认证模块 +""" + +import time +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional +from urllib.parse import urlencode + +import jwt + +from config import ( + get_googleapis_proxy_url, + get_oauth_proxy_url, + get_resource_manager_api_url, + get_service_usage_api_url, +) +from log import log + +from .httpx_client import get_async, post_async + + +class TokenError(Exception): + """Token相关错误""" + + pass + + +class Credentials: + """凭证类""" + + def __init__( + self, + access_token: str, + refresh_token: str = None, + client_id: str = None, + client_secret: str = None, + expires_at: datetime = None, + project_id: str = None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = client_id + self.client_secret = client_secret + self.expires_at = expires_at + self.project_id = project_id + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + + def is_expired(self) -> bool: + """检查token是否过期""" + if not self.expires_at: + return True + + # 提前3分钟认为过期 + buffer = timedelta(minutes=3) + return (self.expires_at - buffer) <= datetime.now(timezone.utc) + + async def refresh_if_needed(self) -> bool: + """如果需要则刷新token""" + if not self.is_expired(): + return False + + if not self.refresh_token: + raise TokenError("需要刷新令牌但未提供") + + await self.refresh() + return True + + async def refresh(self): + """刷新访问令牌""" + if not self.refresh_token: + raise TokenError("无刷新令牌") + + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + } + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + + token_data = response.json() + self.access_token = token_data["access_token"] + + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + current_utc = datetime.now(timezone.utc) + self.expires_at = current_utc + timedelta(seconds=expires_in) + log.debug( + f"Token刷新: 当前UTC时间={current_utc.isoformat()}, " + f"有效期={expires_in}秒, " + f"过期时间={self.expires_at.isoformat()}" + ) + + if "refresh_token" in token_data: + self.refresh_token = token_data["refresh_token"] + + log.debug(f"Token刷新成功,过期时间: {self.expires_at}") + + except Exception as e: + error_msg = str(e) + status_code = None + if hasattr(e, 'response') and hasattr(e.response, 'status_code'): + status_code = e.response.status_code + error_msg = f"Token刷新失败 (HTTP {status_code}): {error_msg}" + else: + error_msg = f"Token刷新失败: {error_msg}" + + log.error(error_msg) + token_error = TokenError(error_msg) + token_error.status_code = status_code + raise token_error + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Credentials": + """从字典创建凭证""" + # 处理过期时间 + expires_at = None + if "expiry" in data and data["expiry"]: + try: + expiry_str = data["expiry"] + if isinstance(expiry_str, str): + if expiry_str.endswith("Z"): + expires_at = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + elif "+" in expiry_str: + expires_at = datetime.fromisoformat(expiry_str) + else: + expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc) + except ValueError: + log.warning(f"无法解析过期时间: {expiry_str}") + + return cls( + access_token=data.get("token") or data.get("access_token", ""), + refresh_token=data.get("refresh_token"), + client_id=data.get("client_id"), + client_secret=data.get("client_secret"), + expires_at=expires_at, + project_id=data.get("project_id"), + ) + + def to_dict(self) -> Dict[str, Any]: + """转为字典""" + result = { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "client_id": self.client_id, + "client_secret": self.client_secret, + "project_id": self.project_id, + } + + if self.expires_at: + result["expiry"] = self.expires_at.isoformat() + + return result + + +class Flow: + """OAuth流程类""" + + def __init__( + self, client_id: str, client_secret: str, scopes: List[str], redirect_uri: str = None + ): + self.client_id = client_id + self.client_secret = client_secret + self.scopes = scopes + self.redirect_uri = redirect_uri + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + self.auth_endpoint = "https://accounts.google.com/o/oauth2/auth" + + self.credentials: Optional[Credentials] = None + + def get_auth_url(self, state: str = None, **kwargs) -> str: + """生成授权URL""" + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "response_type": "code", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + if state: + params["state"] = state + + params.update(kwargs) + return f"{self.auth_endpoint}?{urlencode(params)}" + + async def exchange_code(self, code: str) -> Credentials: + """用授权码换取token""" + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": self.redirect_uri, + "code": code, + "grant_type": "authorization_code", + } + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + response.raise_for_status() + + token_data = response.json() + + # 计算过期时间 + expires_at = None + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + # 创建凭证对象 + self.credentials = Credentials( + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + client_id=self.client_id, + client_secret=self.client_secret, + expires_at=expires_at, + ) + + return self.credentials + + except Exception as e: + error_msg = f"获取token失败: {str(e)}" + log.error(error_msg) + raise TokenError(error_msg) + + +class ServiceAccount: + """Service Account类""" + + def __init__( + self, email: str, private_key: str, project_id: str = None, scopes: List[str] = None + ): + self.email = email + self.private_key = private_key + self.project_id = project_id + self.scopes = scopes or [] + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + + self.access_token: Optional[str] = None + self.expires_at: Optional[datetime] = None + + def is_expired(self) -> bool: + """检查token是否过期""" + if not self.expires_at: + return True + + buffer = timedelta(minutes=3) + return (self.expires_at - buffer) <= datetime.now(timezone.utc) + + def create_jwt(self) -> str: + """创建JWT令牌""" + now = int(time.time()) + + payload = { + "iss": self.email, + "scope": " ".join(self.scopes) if self.scopes else "", + "aud": self.token_endpoint, + "exp": now + 3600, + "iat": now, + } + + return jwt.encode(payload, self.private_key, algorithm="RS256") + + async def get_access_token(self) -> str: + """获取访问令牌""" + if not self.is_expired() and self.access_token: + return self.access_token + + assertion = self.create_jwt() + + data = {"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": assertion} + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + response.raise_for_status() + + token_data = response.json() + self.access_token = token_data["access_token"] + + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + return self.access_token + + except Exception as e: + error_msg = f"Service Account获取token失败: {str(e)}" + log.error(error_msg) + raise TokenError(error_msg) + + @classmethod + def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> "ServiceAccount": + """从字典创建Service Account凭证""" + return cls( + email=data["client_email"], + private_key=data["private_key"], + project_id=data.get("project_id"), + scopes=scopes, + ) + + +# 工具函数 +async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]: + """获取用户信息""" + await credentials.refresh_if_needed() + + try: + googleapis_base_url = await get_googleapis_proxy_url() + userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo" + response = await get_async( + userinfo_url, headers={"Authorization": f"Bearer {credentials.access_token}"} + ) + response.raise_for_status() + return response.json() + except Exception as e: + log.error(f"获取用户信息失败: {e}") + return None + + +async def get_user_email(credentials: Credentials) -> Optional[str]: + """获取用户邮箱地址""" + try: + # 确保凭证有效 + await credentials.refresh_if_needed() + + # 调用Google userinfo API获取邮箱 + user_info = await get_user_info(credentials) + if user_info: + email = user_info.get("email") + if email: + log.info(f"成功获取邮箱地址: {email}") + return email + else: + log.warning(f"userinfo响应中没有邮箱信息: {user_info}") + return None + else: + log.warning("获取用户信息失败") + return None + + except Exception as e: + log.error(f"获取用户邮箱失败: {e}") + return None + + +async def fetch_user_email_from_file(cred_data: Dict[str, Any]) -> Optional[str]: + """从凭证数据获取用户邮箱地址(支持统一存储)""" + try: + # 直接从凭证数据创建凭证对象 + credentials = Credentials.from_dict(cred_data) + if not credentials or not credentials.access_token: + log.warning("无法从凭证数据创建凭证对象或获取访问令牌") + return None + + # 获取邮箱 + return await get_user_email(credentials) + + except Exception as e: + log.error(f"从凭证数据获取用户邮箱失败: {e}") + return None + + +async def validate_token(token: str) -> Optional[Dict[str, Any]]: + """验证访问令牌""" + try: + oauth_base_url = await get_oauth_proxy_url() + tokeninfo_url = f"{oauth_base_url.rstrip('/')}/tokeninfo?access_token={token}" + + response = await get_async(tokeninfo_url) + response.raise_for_status() + return response.json() + except Exception as e: + log.error(f"验证令牌失败: {e}") + return None + + +async def enable_required_apis(credentials: Credentials, project_id: str) -> bool: + """自动启用必需的API服务""" + try: + # 确保凭证有效 + if credentials.is_expired() and credentials.refresh_token: + await credentials.refresh() + + headers = { + "Authorization": f"Bearer {credentials.access_token}", + "Content-Type": "application/json", + "User-Agent": "geminicli-oauth/1.0", + } + + # 需要启用的服务列表 + required_services = [ + "geminicloudassist.googleapis.com", # Gemini Cloud Assist API + "cloudaicompanion.googleapis.com", # Gemini for Google Cloud API + ] + + for service in required_services: + log.info(f"正在检查并启用服务: {service}") + + # 检查服务是否已启用 + service_usage_base_url = await get_service_usage_api_url() + check_url = ( + f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}" + ) + try: + check_response = await get_async(check_url, headers=headers) + if check_response.status_code == 200: + service_data = check_response.json() + if service_data.get("state") == "ENABLED": + log.info(f"服务 {service} 已启用") + continue + except Exception as e: + log.debug(f"检查服务状态失败,将尝试启用: {e}") + + # 启用服务 + enable_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}:enable" + try: + enable_response = await post_async(enable_url, headers=headers, json={}) + + if enable_response.status_code in [200, 201]: + log.info(f"✅ 成功启用服务: {service}") + elif enable_response.status_code == 400: + error_data = enable_response.json() + if "already enabled" in error_data.get("error", {}).get("message", "").lower(): + log.info(f"✅ 服务 {service} 已经启用") + else: + log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}") + else: + log.warning( + f"⚠️ 启用服务 {service} 失败: {enable_response.status_code} - {enable_response.text}" + ) + + except Exception as e: + log.warning(f"⚠️ 启用服务 {service} 时发生异常: {e}") + + return True + + except Exception as e: + log.error(f"启用API服务时发生错误: {e}") + return False + + +async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: + """获取用户可访问的Google Cloud项目列表""" + try: + # 确保凭证有效 + if credentials.is_expired() and credentials.refresh_token: + await credentials.refresh() + + headers = { + "Authorization": f"Bearer {credentials.access_token}", + "User-Agent": "geminicli-oauth/1.0", + } + + # 使用Resource Manager API的正确域名和端点 + resource_manager_base_url = await get_resource_manager_api_url() + url = f"{resource_manager_base_url.rstrip('/')}/v1/projects" + log.info(f"正在调用API: {url}") + response = await get_async(url, headers=headers) + + log.info(f"API响应状态码: {response.status_code}") + if response.status_code != 200: + log.error(f"API响应内容: {response.text}") + + if response.status_code == 200: + data = response.json() + projects = data.get("projects", []) + # 只返回活跃的项目 + active_projects = [ + project for project in projects if project.get("lifecycleState") == "ACTIVE" + ] + log.info(f"获取到 {len(active_projects)} 个活跃项目") + return active_projects + else: + log.warning(f"获取项目列表失败: {response.status_code} - {response.text}") + return [] + + except Exception as e: + log.error(f"获取用户项目列表失败: {e}") + return [] + + +async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str]: + """从项目列表中选择默认项目""" + if not projects: + return None + + # 策略1:查找显示名称或项目ID包含"default"的项目 + for project in projects: + display_name = project.get("displayName", "").lower() + # Google API returns projectId in camelCase + project_id = project.get("projectId", "") + if "default" in display_name or "default" in project_id.lower(): + log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})") + return project_id + + # 策略2:选择第一个项目 + first_project = projects[0] + # Google API returns projectId in camelCase + project_id = first_project.get("projectId", "") + log.info( + f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})" + ) + return project_id + + +async def fetch_project_id( + access_token: str, + user_agent: str, + api_base_url: str +) -> Optional[str]: + """ + 从 API 获取 project_id,如果 loadCodeAssist 失败则回退到 onboardUser + + Args: + access_token: Google OAuth access token + user_agent: User-Agent header + api_base_url: API base URL (e.g., antigravity or code assist endpoint) + + Returns: + project_id 字符串,如果获取失败返回 None + """ + headers = { + 'User-Agent': user_agent, + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json', + 'Accept-Encoding': 'gzip' + } + + # 步骤 1: 尝试 loadCodeAssist + try: + project_id = await _try_load_code_assist(api_base_url, headers) + if project_id: + return project_id + + log.warning("[fetch_project_id] loadCodeAssist did not return project_id, falling back to onboardUser") + + except Exception as e: + log.warning(f"[fetch_project_id] loadCodeAssist failed: {type(e).__name__}: {e}") + log.warning("[fetch_project_id] Falling back to onboardUser") + + # 步骤 2: 回退到 onboardUser + try: + project_id = await _try_onboard_user(api_base_url, headers) + if project_id: + return project_id + + log.error("[fetch_project_id] Failed to get project_id from both loadCodeAssist and onboardUser") + return None + + except Exception as e: + log.error(f"[fetch_project_id] onboardUser failed: {type(e).__name__}: {e}") + import traceback + log.debug(f"[fetch_project_id] Traceback: {traceback.format_exc()}") + return None + + +async def _try_load_code_assist( + api_base_url: str, + headers: dict +) -> Optional[str]: + """ + 尝试通过 loadCodeAssist 获取 project_id + + Returns: + project_id 或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" + request_body = { + "metadata": { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI" + } + } + + log.debug(f"[loadCodeAssist] Fetching project_id from: {request_url}") + log.debug(f"[loadCodeAssist] Request body: {request_body}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + log.debug(f"[loadCodeAssist] Response status: {response.status_code}") + + if response.status_code == 200: + response_text = response.text + log.debug(f"[loadCodeAssist] Response body: {response_text}") + + data = response.json() + log.debug(f"[loadCodeAssist] Response JSON keys: {list(data.keys())}") + + # 检查是否有 currentTier(表示用户已激活) + current_tier = data.get("currentTier") + if current_tier: + log.info("[loadCodeAssist] User is already activated") + + # 使用服务器返回的 project_id + project_id = data.get("cloudaicompanionProject") + if project_id: + log.info(f"[loadCodeAssist] Successfully fetched project_id: {project_id}") + return project_id + + log.warning("[loadCodeAssist] No project_id in response") + return None + else: + log.info("[loadCodeAssist] User not activated yet (no currentTier)") + return None + else: + log.warning(f"[loadCodeAssist] Failed: HTTP {response.status_code}") + log.warning(f"[loadCodeAssist] Response body: {response.text[:500]}") + raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") + + +async def _try_onboard_user( + api_base_url: str, + headers: dict +) -> Optional[str]: + """ + 尝试通过 onboardUser 获取 project_id(长时间运行操作,需要轮询) + + Returns: + project_id 或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:onboardUser" + + # 首先需要获取用户的 tier 信息 + tier_id = await _get_onboard_tier(api_base_url, headers) + if not tier_id: + log.error("[onboardUser] Failed to determine user tier") + return None + + log.info(f"[onboardUser] User tier: {tier_id}") + + # 构造 onboardUser 请求 + # 注意:FREE tier 不应该包含 cloudaicompanionProject + request_body = { + "tierId": tier_id, + "metadata": { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI" + } + } + + log.debug(f"[onboardUser] Request URL: {request_url}") + log.debug(f"[onboardUser] Request body: {request_body}") + + # onboardUser 是长时间运行操作,需要轮询 + # 最多等待 10 秒(5 次 * 2 秒) + max_attempts = 5 + attempt = 0 + + while attempt < max_attempts: + attempt += 1 + log.debug(f"[onboardUser] Polling attempt {attempt}/{max_attempts}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + log.debug(f"[onboardUser] Response status: {response.status_code}") + + if response.status_code == 200: + data = response.json() + log.debug(f"[onboardUser] Response data: {data}") + + # 检查长时间运行操作是否完成 + if data.get("done"): + log.info("[onboardUser] Operation completed") + + # 从响应中提取 project_id + response_data = data.get("response", {}) + project_obj = response_data.get("cloudaicompanionProject", {}) + + if isinstance(project_obj, dict): + project_id = project_obj.get("id") + elif isinstance(project_obj, str): + project_id = project_obj + else: + project_id = None + + if project_id: + log.info(f"[onboardUser] Successfully fetched project_id: {project_id}") + return project_id + else: + log.warning("[onboardUser] Operation completed but no project_id in response") + return None + else: + log.debug("[onboardUser] Operation still in progress, waiting 2 seconds...") + await asyncio.sleep(2) + else: + log.warning(f"[onboardUser] Failed: HTTP {response.status_code}") + log.warning(f"[onboardUser] Response body: {response.text[:500]}") + raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") + + log.error("[onboardUser] Timeout: Operation did not complete within 10 seconds") + return None + + +async def _get_onboard_tier( + api_base_url: str, + headers: dict +) -> Optional[str]: + """ + 从 loadCodeAssist 响应中获取用户应该注册的 tier + + Returns: + tier_id (如 "FREE", "STANDARD", "LEGACY") 或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" + request_body = { + "metadata": { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI" + } + } + + log.debug(f"[_get_onboard_tier] Fetching tier info from: {request_url}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[_get_onboard_tier] Response data: {data}") + + # 查找默认的 tier + allowed_tiers = data.get("allowedTiers", []) + for tier in allowed_tiers: + if tier.get("isDefault"): + tier_id = tier.get("id") + log.info(f"[_get_onboard_tier] Found default tier: {tier_id}") + return tier_id + + # 如果没有默认 tier,使用 LEGACY 作为回退 + log.warning("[_get_onboard_tier] No default tier found, using LEGACY") + return "LEGACY" + else: + log.error(f"[_get_onboard_tier] Failed to fetch tier info: HTTP {response.status_code}") + return None + + diff --git a/src/httpx_client.py b/src/httpx_client.py new file mode 100644 index 0000000000000000000000000000000000000000..45e29e26daf6928e03ddd2f143362bec4067b7f9 --- /dev/null +++ b/src/httpx_client.py @@ -0,0 +1,108 @@ +""" +通用的HTTP客户端模块 +为所有需要使用httpx的模块提供统一的客户端配置和方法 +保持通用性,不与特定业务逻辑耦合 +""" + +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, Optional + +import httpx + +from config import get_proxy_config +from log import log + + +class HttpxClientManager: + """通用HTTP客户端管理器""" + + async def get_client_kwargs(self, timeout: float = 30.0, **kwargs) -> Dict[str, Any]: + """获取httpx客户端的通用配置参数""" + client_kwargs = {"timeout": timeout, **kwargs} + + # 动态读取代理配置,支持热更新 + current_proxy_config = await get_proxy_config() + if current_proxy_config: + client_kwargs["proxy"] = current_proxy_config + + return client_kwargs + + @asynccontextmanager + async def get_client( + self, timeout: float = 30.0, **kwargs + ) -> AsyncGenerator[httpx.AsyncClient, None]: + """获取配置好的异步HTTP客户端""" + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + + async with httpx.AsyncClient(**client_kwargs) as client: + yield client + + @asynccontextmanager + async def get_streaming_client( + self, timeout: float = None, **kwargs + ) -> AsyncGenerator[httpx.AsyncClient, None]: + """获取用于流式请求的HTTP客户端(无超时限制)""" + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + + # 创建独立的客户端实例用于流式处理 + client = httpx.AsyncClient(**client_kwargs) + try: + yield client + finally: + # 确保无论发生什么都关闭客户端 + try: + await client.aclose() + except Exception as e: + log.warning(f"Error closing streaming client: {e}") + + +# 全局HTTP客户端管理器实例 +http_client = HttpxClientManager() + + +# 通用的异步方法 +async def get_async( + url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, **kwargs +) -> httpx.Response: + """通用异步GET请求""" + async with http_client.get_client(timeout=timeout, **kwargs) as client: + return await client.get(url, headers=headers) + + +async def post_async( + url: str, + data: Any = None, + json: Any = None, + headers: Optional[Dict[str, str]] = None, + timeout: float = 30.0, + **kwargs, +) -> httpx.Response: + """通用异步POST请求""" + async with http_client.get_client(timeout=timeout, **kwargs) as client: + return await client.post(url, data=data, json=json, headers=headers) + + +async def stream_post_async( + url: str, + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, + **kwargs, +): + """流式异步POST请求""" + async with http_client.get_streaming_client(**kwargs) as client: + async with client.stream("POST", url, json=body, headers=headers) as r: + # 错误直接返回 + if r.status_code != 200: + from fastapi import Response + yield Response(await r.aread(), r.status_code, dict(r.headers)) + return + + # 如果native=True,直接返回bytes流 + if native: + async for chunk in r.aiter_bytes(): + yield chunk + else: + # 通过aiter_lines转化成str流返回 + async for line in r.aiter_lines(): + yield line diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..37b6ca119115c427c5cf06bf09bdfb97334df6ec --- /dev/null +++ b/src/models.py @@ -0,0 +1,376 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + + +# Pydantic v1/v2 兼容性辅助函数 +def model_to_dict(model: BaseModel) -> Dict[str, Any]: + """ + 兼容 Pydantic v1 和 v2 的模型转字典方法 + - v1: model.dict() + - v2: model.model_dump() + """ + if hasattr(model, 'model_dump'): + # Pydantic v2 + return model.model_dump() + else: + # Pydantic v1 + return model.dict() + + +# Common Models +class Model(BaseModel): + id: str + object: str = "model" + created: Optional[int] = None + owned_by: Optional[str] = "google" + + +class ModelList(BaseModel): + object: str = "list" + data: List[Model] + + +# OpenAI Models +class OpenAIToolFunction(BaseModel): + name: str + arguments: str # JSON string + + +class OpenAIToolCall(BaseModel): + id: str + type: str = "function" + function: OpenAIToolFunction + + +class OpenAITool(BaseModel): + type: str = "function" + function: Dict[str, Any] + + +class OpenAIChatMessage(BaseModel): + role: str + content: Union[str, List[Dict[str, Any]], None] = None + reasoning_content: Optional[str] = None + name: Optional[str] = None + tool_calls: Optional[List[OpenAIToolCall]] = None + tool_call_id: Optional[str] = None # for role="tool" + + +class OpenAIChatCompletionRequest(BaseModel): + model: str + messages: List[OpenAIChatMessage] + stream: bool = False + temperature: Optional[float] = Field(None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(None, ge=1) + stop: Optional[Union[str, List[str]]] = None + frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + n: Optional[int] = Field(1, ge=1, le=128) + seed: Optional[int] = None + response_format: Optional[Dict[str, Any]] = None + top_k: Optional[int] = Field(None, ge=1) + tools: Optional[List[OpenAITool]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + class Config: + extra = "allow" # Allow additional fields not explicitly defined + + +# 通用的聊天完成请求模型(兼容OpenAI和其他格式) +ChatCompletionRequest = OpenAIChatCompletionRequest + + +class OpenAIChatCompletionChoice(BaseModel): + index: int + message: OpenAIChatMessage + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class OpenAIChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[OpenAIChatCompletionChoice] + usage: Optional[Dict[str, int]] = None + system_fingerprint: Optional[str] = None + + +class OpenAIDelta(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + + +class OpenAIChatCompletionStreamChoice(BaseModel): + index: int + delta: OpenAIDelta + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class OpenAIChatCompletionStreamResponse(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int + model: str + choices: List[OpenAIChatCompletionStreamChoice] + system_fingerprint: Optional[str] = None + + +# Gemini Models +class GeminiPart(BaseModel): + text: Optional[str] = None + inlineData: Optional[Dict[str, Any]] = None + fileData: Optional[Dict[str, Any]] = None + thought: Optional[bool] = False + + +class GeminiContent(BaseModel): + role: str + parts: List[GeminiPart] + + +class GeminiSystemInstruction(BaseModel): + parts: List[GeminiPart] + + +class GeminiImageConfig(BaseModel): + """图片生成配置""" + aspect_ratio: Optional[str] = None # "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9" + image_size: Optional[str] = None # "1K", "2K", "4K" + + +class GeminiGenerationConfig(BaseModel): + temperature: Optional[float] = Field(None, ge=0.0, le=2.0) + topP: Optional[float] = Field(None, ge=0.0, le=1.0) + topK: Optional[int] = Field(None, ge=1) + maxOutputTokens: Optional[int] = Field(None, ge=1) + stopSequences: Optional[List[str]] = None + responseMimeType: Optional[str] = None + responseSchema: Optional[Dict[str, Any]] = None + candidateCount: Optional[int] = Field(None, ge=1, le=8) + seed: Optional[int] = None + frequencyPenalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + presencePenalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + thinkingConfig: Optional[Dict[str, Any]] = None + # 图片生成相关参数 + response_modalities: Optional[List[str]] = None # ["TEXT", "IMAGE"] + image_config: Optional[GeminiImageConfig] = None + + +class GeminiSafetySetting(BaseModel): + category: str + threshold: str + + +class GeminiRequest(BaseModel): + contents: List[GeminiContent] + systemInstruction: Optional[GeminiSystemInstruction] = None + generationConfig: Optional[GeminiGenerationConfig] = None + safetySettings: Optional[List[GeminiSafetySetting]] = None + tools: Optional[List[Dict[str, Any]]] = None + toolConfig: Optional[Dict[str, Any]] = None + cachedContent: Optional[str] = None + + class Config: + extra = "allow" # 允许透传未定义的字段 + + +class GeminiCandidate(BaseModel): + content: GeminiContent + finishReason: Optional[str] = None + index: int = 0 + safetyRatings: Optional[List[Dict[str, Any]]] = None + citationMetadata: Optional[Dict[str, Any]] = None + tokenCount: Optional[int] = None + + +class GeminiUsageMetadata(BaseModel): + promptTokenCount: Optional[int] = None + candidatesTokenCount: Optional[int] = None + totalTokenCount: Optional[int] = None + + +class GeminiResponse(BaseModel): + candidates: List[GeminiCandidate] + usageMetadata: Optional[GeminiUsageMetadata] = None + modelVersion: Optional[str] = None + + +# Claude Models +class ClaudeContentBlock(BaseModel): + type: str # "text", "image", "tool_use", "tool_result" + text: Optional[str] = None + source: Optional[Dict[str, Any]] = None # for image type + id: Optional[str] = None # for tool_use + name: Optional[str] = None # for tool_use + input: Optional[Dict[str, Any]] = None # for tool_use + tool_use_id: Optional[str] = None # for tool_result + content: Optional[Union[str, List[Dict[str, Any]]]] = None # for tool_result + + +class ClaudeMessage(BaseModel): + role: str # "user" or "assistant" + content: Union[str, List[ClaudeContentBlock]] + + +class ClaudeTool(BaseModel): + name: str + description: Optional[str] = None + input_schema: Dict[str, Any] + + +class ClaudeMetadata(BaseModel): + user_id: Optional[str] = None + + +class ClaudeRequest(BaseModel): + model: str + messages: List[ClaudeMessage] + max_tokens: int = Field(..., ge=1) + system: Optional[Union[str, List[Dict[str, Any]]]] = None + temperature: Optional[float] = Field(None, ge=0.0, le=1.0) + top_p: Optional[float] = Field(None, ge=0.0, le=1.0) + top_k: Optional[int] = Field(None, ge=1) + stop_sequences: Optional[List[str]] = None + stream: bool = False + metadata: Optional[ClaudeMetadata] = None + tools: Optional[List[ClaudeTool]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + class Config: + extra = "allow" + + +class ClaudeUsage(BaseModel): + input_tokens: int + output_tokens: int + + +class ClaudeResponse(BaseModel): + id: str + type: str = "message" + role: str = "assistant" + content: List[ClaudeContentBlock] + model: str + stop_reason: Optional[str] = None + stop_sequence: Optional[str] = None + usage: ClaudeUsage + + +class ClaudeStreamEvent(BaseModel): + type: str # "message_start", "content_block_start", "content_block_delta", "content_block_stop", "message_delta", "message_stop" + message: Optional[ClaudeResponse] = None + index: Optional[int] = None + content_block: Optional[ClaudeContentBlock] = None + delta: Optional[Dict[str, Any]] = None + usage: Optional[ClaudeUsage] = None + + class Config: + extra = "allow" + + +# Error Models +class APIError(BaseModel): + message: str + type: str = "api_error" + code: Optional[int] = None + + +class ErrorResponse(BaseModel): + error: APIError + + +# Control Panel Models +class SystemStatus(BaseModel): + status: str + timestamp: str + credentials: Dict[str, int] + config: Dict[str, Any] + current_credential: str + + +class CredentialInfo(BaseModel): + filename: str + project_id: Optional[str] = None + status: Dict[str, Any] + size: Optional[int] = None + modified_time: Optional[str] = None + error: Optional[str] = None + + +class LogEntry(BaseModel): + timestamp: str + level: str + message: str + module: Optional[str] = None + + +class ConfigValue(BaseModel): + key: str + value: Any + env_locked: bool = False + description: Optional[str] = None + + +# Authentication Models +class AuthRequest(BaseModel): + project_id: Optional[str] = None + user_session: Optional[str] = None + + +class AuthResponse(BaseModel): + success: bool + auth_url: Optional[str] = None + state: Optional[str] = None + error: Optional[str] = None + credentials: Optional[Dict[str, Any]] = None + file_path: Optional[str] = None + requires_manual_project_id: Optional[bool] = None + requires_project_selection: Optional[bool] = None + available_projects: Optional[List[Dict[str, str]]] = None + + +class CredentialStatus(BaseModel): + disabled: bool = False + error_codes: List[int] = [] + last_success: Optional[str] = None + + +# Web Routes Models +class LoginRequest(BaseModel): + password: str + + +class AuthStartRequest(BaseModel): + project_id: Optional[str] = None # 现在是可选的 + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class AuthCallbackRequest(BaseModel): + project_id: Optional[str] = None # 现在是可选的 + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class AuthCallbackUrlRequest(BaseModel): + callback_url: str # OAuth回调完整URL + project_id: Optional[str] = None # 可选的项目ID + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class CredFileActionRequest(BaseModel): + filename: str + action: str # enable, disable, delete + + +class CredFileBatchActionRequest(BaseModel): + action: str # "enable", "disable", "delete" + filenames: List[str] # 批量操作的文件名列表 + + +class ConfigSaveRequest(BaseModel): + config: dict diff --git a/src/router/antigravity/anthropic.py b/src/router/antigravity/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..3089864554435c6cdd9b466de015721da6bfd0f0 --- /dev/null +++ b/src/router/antigravity/anthropic.py @@ -0,0 +1,566 @@ +""" +Anthropic Router - Handles Anthropic/Claude format API requests via Antigravity +通过Antigravity处理Anthropic/Claude格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_anthropic_fake_stream_chunks, + create_anthropic_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import ClaudeRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1/messages") +async def messages( + claude_request: ClaudeRequest, + _token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic/Claude格式的消息请求(流式和非流式) + + Args: + claude_request: Anthropic/Claude格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[ANTIGRAVITY-ANTHROPIC] Request for model: {claude_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(claude_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="anthropic"): + response = create_health_check_response(format="anthropic") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(claude_request.model) + use_anti_truncation = is_anti_truncation_model(claude_request.model) + real_model = get_base_model_from_feature_model(claude_request.model) + + # 获取流式标志 + is_streaming = claude_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.anthropic2gemini import anthropic_to_gemini_request + gemini_dict = await anthropic_to_gemini_request(normalized_dict) + + # anthropic_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_response = gemini_to_anthropic_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=anthropic_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + # 发送心跳 + heartbeat = create_anthropic_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="anthropic_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + error_data, + real_model, + response.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"Anthropic fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"Anthropic extracted content: {content}") + log.debug(f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Anthropic extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_anthropic_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "type": "error", + "error": { + "type": "api_error", + "message": str(e) + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.antigravity import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 包装以确保是bytes流 + async def bytes_wrapper(): + async for chunk in processor.process_stream(): + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 直接将整个流传递给转换器 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + bytes_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.antigravity import stream_request + from fastapi import Response + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + + # 包装流式生成器以处理错误响应 + async def gemini_chunk_wrapper(): + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 错误响应,不进行转换,直接传递 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + try: + gemini_error = json.loads(error_content.decode('utf-8')) + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + return + else: + # 确保是bytes类型 + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 使用转换器处理整个流 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + gemini_chunk_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Anthropic路由的流式和非流式响应 + 运行方式: python src/router/antigravity/anthropic.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Anthropic Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Anthropic格式) + test_request_body = { + "model": "gemini-2.5-flash", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /antigravity/v1/messages)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/messages", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /antigravity/v1/messages)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/messages", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("event: ") or chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /antigravity/v1/messages with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/messages", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: ") or line.startswith("event: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + elif event_line.startswith("data: "): + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + event_type = json_data.get("type", "unknown") + print(f" 事件 #{event_idx}: type={event_type}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/antigravity/gemini.py b/src/router/antigravity/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..4640f4dc3962f90f5c7386d725b8092ba58c02c7 --- /dev/null +++ b/src/router/antigravity/gemini.py @@ -0,0 +1,690 @@ +""" +Gemini Router - Handles native Gemini format API requests (Antigravity backend) +处理原生Gemini格式请求的路由模块(Antigravity后端) +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Path, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + authenticate_gemini_flexible, + is_fake_streaming_model +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_gemini_fake_stream_chunks, + create_gemini_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import GeminiRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1beta/models/{model:path}:generateContent") +@router.post("/antigravity/v1/models/{model:path}:generateContent") +async def generate_content( + gemini_request: "GeminiRequest", + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的内容生成请求(非流式) + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[ANTIGRAVITY] Non-streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="gemini"): + response = create_health_check_response(format="gemini") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + normalized_dict = await normalize_gemini_request(normalized_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_dict.pop("model"), + "request": normalized_dict + } + + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 直接返回响应(response已经是FastAPI Response对象) + # 保持 Gemini 原生的 inlineData 格式,不进行 Markdown 转换 + return response + +@router.post("/antigravity/v1beta/models/{model:path}:streamGenerateContent") +@router.post("/antigravity/v1/models/{model:path}:streamGenerateContent") +async def stream_generate_content( + gemini_request: GeminiRequest, + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的流式内容生成请求 + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[ANTIGRAVITY] Streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(model) + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 发送心跳 + heartbeat = create_gemini_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="gemini_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 以SSE格式返回错误 + yield f"data: {json.dumps(error_data)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + response_data = json.loads(response_body) + log.debug(f"Gemini fake stream response data: {response_data}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in response_data: + log.error(f"Fake streaming got error in response body: {response_data['error']}") + yield f"data: {json.dumps(response_data)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data) + + log.debug(f"Gemini extracted content: {content}") + log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Gemini extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield original response") + # 直接yield原始响应,不进行包装 + yield f"data: {response_body}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.converter.anti_truncation import apply_anti_truncation + from src.api.antigravity import stream_request + + # 先进行基础标准化 + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model") if "model" in normalized_req else real_model, + "request": normalized_req + } + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 迭代 process_stream() 生成器,并展开 response 包装 + async for chunk in processor.process_stream(): + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[ANTIGRAVITY-ANTI-TRUNCATION] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + else: + # 其他类型,直接返回 + yield chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.antigravity import stream_request + from fastapi import Response + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 所有流式请求都使用非 native 模式(SSE格式)并展开 response 包装 + log.debug(f"[ANTIGRAVITY] 使用非native模式,将展开response包装") + stream_gen = stream_request(body=api_request, native=False) + + # 展开 response 包装 + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + # 以SSE格式返回错误 + yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + return + + # 处理SSE格式的chunk + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[ANTIGRAVITY] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + +@router.post("/antigravity/v1beta/models/{model:path}:countTokens") +@router.post("/antigravity/v1/models/{model:path}:countTokens") +async def count_tokens( + request: Request = None, + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 模拟Gemini格式的token计数 + + 使用简单的启发式方法:大约4字符=1token + """ + + try: + request_data = await request.json() + except Exception as e: + log.error(f"Failed to parse JSON request: {e}") + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + # 简单的token计数模拟 - 基于文本长度估算 + total_tokens = 0 + + # 如果有contents字段 + if "contents" in request_data: + for content in request_data["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + # 简单估算:大约4字符=1token + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 如果有generateContentRequest字段 + elif "generateContentRequest" in request_data: + gen_request = request_data["generateContentRequest"] + if "contents" in gen_request: + for content in gen_request["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 返回Gemini格式的响应 + return JSONResponse(content={"totalTokens": total_tokens}) + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Gemini路由的流式和非流式响应 + 运行方式: python src/router/antigravity/gemini.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Gemini Router (Antigravity Backend) 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Gemini格式) + test_request_body = { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + + # 测试API密钥(模拟) + test_api_key = "pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:generateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/models/gemini-2.5-flash:generateContent", + json=test_request_body, + params={"key": test_api_key} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试3】流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试4】假流式请求 (POST /antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取text内容 + text = json_data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "") + finish_reason = json_data.get("candidates", [{}])[0].get("finishReason") + print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + def test_anti_truncation_stream_request(): + """测试流式抗截断请求""" + print("\n" + "=" * 80) + print("【测试5】流式抗截断请求 (POST /antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式抗截断响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + # 测试流式抗截断请求 + test_anti_truncation_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/antigravity/model_list.py b/src/router/antigravity/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..f27c81822dbb26b1dc2d6939704ecd193b9f4b25 --- /dev/null +++ b/src/router/antigravity/model_list.py @@ -0,0 +1,105 @@ +""" +Antigravity Model List Router - Handles model list requests +Antigravity 模型列表路由 - 处理模型列表请求 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 第三方库 +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + authenticate_flexible +) + +# 本地模块 - API +from src.api.antigravity import fetch_available_models + +# 本地模块 - 基础路由工具 +from src.router.base_router import create_gemini_model_list, create_openai_model_list +from src.models import model_to_dict +from log import log + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== 辅助函数 ==================== + +async def get_antigravity_models_with_features(): + """ + 获取 Antigravity 模型列表并添加功能前缀 + + Returns: + 带有功能前缀的模型列表 + """ + # 从 API 获取基础模型列表 + base_models_data = await fetch_available_models() + + if not base_models_data: + log.warning("[ANTIGRAVITY MODEL LIST] 无法获取模型列表,返回空列表") + return [] + + # 提取模型 ID + base_model_ids = [model['id'] for model in base_models_data if 'id' in model] + + # 添加功能前缀 + models = [] + for base_model in base_model_ids: + # 基础模型 + models.append(base_model) + + # 假流式模型 (前缀格式) + models.append(f"假流式/{base_model}") + + # 流式抗截断模型 (仅在流式传输时有效,前缀格式) + models.append(f"流式抗截断/{base_model}") + + log.info(f"[ANTIGRAVITY MODEL LIST] 生成了 {len(models)} 个模型(包含功能前缀)") + return models + + +# ==================== API 路由 ==================== + +@router.get("/antigravity/v1beta/models") +async def list_gemini_models(token: str = Depends(authenticate_flexible)): + """ + 返回 Gemini 格式的模型列表 + + 从 src.api.antigravity.fetch_available_models 动态获取模型列表 + 并添加假流式和流式抗截断前缀 + """ + models = await get_antigravity_models_with_features() + log.info("[ANTIGRAVITY MODEL LIST] 返回 Gemini 格式") + return JSONResponse(content=create_gemini_model_list( + models, + base_name_extractor=get_base_model_from_feature_model + )) + + +@router.get("/antigravity/v1/models") +async def list_openai_models(token: str = Depends(authenticate_flexible)): + """ + 返回 OpenAI 格式的模型列表 + + 从 src.api.antigravity.fetch_available_models 动态获取模型列表 + 并添加假流式和流式抗截断前缀 + """ + models = await get_antigravity_models_with_features() + log.info("[ANTIGRAVITY MODEL LIST] 返回 OpenAI 格式") + model_list = create_openai_model_list(models, owned_by="google") + return JSONResponse(content={ + "object": "list", + "data": [model_to_dict(model) for model in model_list.data] + }) diff --git a/src/router/antigravity/openai.py b/src/router/antigravity/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8b017daccf67518a37910d1e1beb6be56d0a7d --- /dev/null +++ b/src/router/antigravity/openai.py @@ -0,0 +1,615 @@ +""" +OpenAI Router - Handles OpenAI format API requests via Antigravity +通过Antigravity处理OpenAI格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_openai_fake_stream_chunks, + create_openai_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import OpenAIChatCompletionRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1/chat/completions") +async def chat_completions( + openai_request: OpenAIChatCompletionRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理OpenAI格式的聊天完成请求(流式和非流式) + + Args: + openai_request: OpenAI格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[ANTIGRAVITY-OPENAI] Request for model: {openai_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(openai_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="openai"): + response = create_health_check_response(format="openai") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(openai_request.model) + use_anti_truncation = is_anti_truncation_model(openai_request.model) + real_model = get_base_model_from_feature_model(openai_request.model) + + # 获取流式标志 + is_streaming = openai_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.openai2gemini import convert_openai_to_gemini_request + gemini_dict = await convert_openai_to_gemini_request(normalized_dict) + + # convert_openai_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_response = convert_gemini_to_openai_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=openai_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + # 发送心跳 + heartbeat = create_openai_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="openai_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + error_data, + real_model, + response.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"OpenAI fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"OpenAI extracted content: {content}") + log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"OpenAI extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "id": "error", + "object": "chat.completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": real_model, + "choices": [{ + "index": 0, + "delta": {"content": f"Error: {str(e)}"}, + "finish_reason": "error" + }] + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.antigravity import stream_request + from src.converter.anti_truncation import apply_anti_truncation + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 转换为 OpenAI 格式 + import uuid + response_id = str(uuid.uuid4()) + + # 直接迭代 process_stream() 生成器,并转换为 OpenAI 格式 + async for chunk in processor.process_stream(): + if not chunk: + continue + + # 解析 Gemini SSE 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析 "data: {...}" 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.antigravity import stream_request + from fastapi import Response + import uuid + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + + response_id = str(uuid.uuid4()) + + # yield所有数据,处理可能的错误Response + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + try: + gemini_error = json.loads(error_content.decode('utf-8')) + # 转换为 OpenAI 格式错误 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + return + else: + # 正常的bytes数据,转换为 OpenAI 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析并转换 Gemini chunk 为 OpenAI 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示OpenAI路由的流式和非流式响应 + 运行方式: python src/router/antigravity/openai.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("OpenAI Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (OpenAI格式) + test_request_body = { + "model": "gemini-2.5-flash", + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /antigravity/v1/chat/completions)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/chat/completions", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /antigravity/v1/chat/completions)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/chat/completions", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /antigravity/v1/chat/completions with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/chat/completions", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取content内容 + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "") + finish_reason = json_data.get("choices", [{}])[0].get("finish_reason") + print(f" 事件 #{event_idx}: content={repr(content[:50])}{'...' if len(content) > 50 else ''}, finish_reason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/base_router.py b/src/router/base_router.py new file mode 100644 index 0000000000000000000000000000000000000000..e17b648189eb144adaa165cc94ab4aad7702c146 --- /dev/null +++ b/src/router/base_router.py @@ -0,0 +1,106 @@ +""" +Base Router - 共用的路由基础功能 +提供模型列表处理、通用响应等共同功能 +""" + +from typing import List, Optional + +from src.models import Model, ModelList + + +# ==================== 模型列表处理 ==================== + +def expand_models_with_features( + base_models: List[str], + features: Optional[List[str]] = None +) -> List[str]: + """ + 使用特性前缀扩展模型列表 + + Args: + base_models: 基础模型列表 + features: 特性前缀列表,如 ["流式抗截断", "假流式"] + + Returns: + 扩展后的模型列表(包含原始模型和特性变体) + """ + if not features: + return base_models.copy() + + expanded = [] + for model in base_models: + # 添加原始模型 + expanded.append(model) + + # 添加特性变体 + for feature in features: + expanded.append(f"{feature}/{model}") + + return expanded + + +def create_openai_model_list( + model_ids: List[str], + owned_by: str = "google" +) -> ModelList: + """ + 创建OpenAI格式的模型列表 + + Args: + model_ids: 模型ID列表 + owned_by: 模型所有者 + + Returns: + ModelList对象 + """ + from datetime import datetime, timezone + current_timestamp = int(datetime.now(timezone.utc).timestamp()) + + models = [ + Model( + id=model_id, + object='model', + created=current_timestamp, + owned_by=owned_by + ) + for model_id in model_ids + ] + + return ModelList(data=models) + + +def create_gemini_model_list( + model_ids: List[str], + base_name_extractor=None +) -> dict: + """ + 创建Gemini格式的模型列表 + + Args: + model_ids: 模型ID列表 + base_name_extractor: 可选的基础模型名提取函数 + + Returns: + 包含模型列表的字典 + """ + gemini_models = [] + + for model_id in model_ids: + base_model = model_id + if base_name_extractor: + try: + base_model = base_name_extractor(model_id) + except Exception: + pass + + model_info = { + "name": f"models/{model_id}", + "baseModelId": base_model, + "version": "001", + "displayName": model_id, + "description": f"Gemini {base_model} model", + "supportedGenerationMethods": ["generateContent", "streamGenerateContent"], + } + gemini_models.append(model_info) + + return {"models": gemini_models} \ No newline at end of file diff --git a/src/router/geminicli/anthropic.py b/src/router/geminicli/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..c99958b574a72ac302dd0ac74023aec53b5d961f --- /dev/null +++ b/src/router/geminicli/anthropic.py @@ -0,0 +1,566 @@ +""" +Anthropic Router - Handles Anthropic/Claude format API requests via GeminiCLI +通过GeminiCLI处理Anthropic/Claude格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_anthropic_fake_stream_chunks, + create_anthropic_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import ClaudeRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1/messages") +async def messages( + claude_request: ClaudeRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic/Claude格式的消息请求(流式和非流式) + + Args: + claude_request: Anthropic/Claude格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[GEMINICLI-ANTHROPIC] Request for model: {claude_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(claude_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="anthropic"): + response = create_health_check_response(format="anthropic") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(claude_request.model) + use_anti_truncation = is_anti_truncation_model(claude_request.model) + real_model = get_base_model_from_feature_model(claude_request.model) + + # 获取流式标志 + is_streaming = claude_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.anthropic2gemini import anthropic_to_gemini_request + gemini_dict = await anthropic_to_gemini_request(normalized_dict) + + # anthropic_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_response = gemini_to_anthropic_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=anthropic_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + # 发送心跳 + heartbeat = create_anthropic_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="anthropic_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + error_data, + real_model, + response.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"Anthropic fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"Anthropic extracted content: {content}") + log.debug(f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Anthropic extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_anthropic_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "type": "error", + "error": { + "type": "api_error", + "message": str(e) + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.geminicli import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 包装以确保是bytes流 + async def bytes_wrapper(): + async for chunk in processor.process_stream(): + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 直接将整个流传递给转换器 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + bytes_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.geminicli import stream_request + from fastapi import Response + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + + # 包装流式生成器以处理错误响应 + async def gemini_chunk_wrapper(): + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 错误响应,不进行转换,直接传递 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + try: + gemini_error = json.loads(error_content.decode('utf-8')) + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + return + else: + # 确保是bytes类型 + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 使用转换器处理整个流 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + gemini_chunk_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Anthropic路由的流式和非流式响应 + 运行方式: python src/router/geminicli/anthropic.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Anthropic Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Anthropic格式) + test_request_body = { + "model": "gemini-2.5-flash", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /v1/messages)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/messages", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /v1/messages)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/messages", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("event: ") or chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /v1/messages with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/messages", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: ") or line.startswith("event: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + elif event_line.startswith("data: "): + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + event_type = json_data.get("type", "unknown") + print(f" 事件 #{event_idx}: type={event_type}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/geminicli/gemini.py b/src/router/geminicli/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d889101216bff48f10c105dd906003bb6487be --- /dev/null +++ b/src/router/geminicli/gemini.py @@ -0,0 +1,690 @@ +""" +Gemini Router - Handles native Gemini format API requests +处理原生Gemini格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Path, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + authenticate_gemini_flexible, + is_fake_streaming_model +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_gemini_fake_stream_chunks, + create_gemini_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import GeminiRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1beta/models/{model:path}:generateContent") +@router.post("/v1/models/{model:path}:generateContent") +async def generate_content( + gemini_request: "GeminiRequest", + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的内容生成请求(非流式) + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[GEMINICLI] Non-streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="gemini"): + response = create_health_check_response(format="gemini") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + normalized_dict = await normalize_gemini_request(normalized_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_dict.pop("model"), + "request": normalized_dict + } + + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 直接返回响应(response已经是FastAPI Response对象) + return response + +@router.post("/v1beta/models/{model:path}:streamGenerateContent") +@router.post("/v1/models/{model:path}:streamGenerateContent") +async def stream_generate_content( + gemini_request: GeminiRequest, + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的流式内容生成请求 + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[GEMINICLI] Streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(model) + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 发送心跳 + heartbeat = create_gemini_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="gemini_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 以SSE格式返回错误 + yield f"data: {json.dumps(error_data)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + response_data = json.loads(response_body) + log.debug(f"Gemini fake stream response data: {response_data}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in response_data: + log.error(f"Fake streaming got error in response body: {response_data['error']}") + yield f"data: {json.dumps(response_data)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data) + + log.debug(f"Gemini extracted content: {content}") + log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Gemini extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield original response") + # 直接yield原始响应,不进行包装 + yield f"data: {response_body}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.converter.anti_truncation import apply_anti_truncation + from src.api.geminicli import stream_request + + # 先进行基础标准化 + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model") if "model" in normalized_req else real_model, + "request": normalized_req + } + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 迭代 process_stream() 生成器,并展开 response 包装 + async for chunk in processor.process_stream(): + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[GEMINICLI-ANTI-TRUNCATION] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + else: + # 其他类型,直接返回 + yield chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.geminicli import stream_request + from fastapi import Response + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 所有流式请求都使用非 native 模式(SSE格式)并展开 response 包装 + log.debug(f"[GEMINICLI] 使用非native模式,将展开response包装") + stream_gen = stream_request(body=api_request, native=False) + + # 展开 response 包装 + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + # 以SSE格式返回错误 + yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + return + + # 处理SSE格式的chunk + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[GEMINICLI] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + +@router.post("/v1beta/models/{model:path}:countTokens") +@router.post("/v1/models/{model:path}:countTokens") +async def count_tokens( + request: Request = None, + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 模拟Gemini格式的token计数 + + 使用简单的启发式方法:大约4字符=1token + """ + + try: + request_data = await request.json() + except Exception as e: + log.error(f"Failed to parse JSON request: {e}") + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + # 简单的token计数模拟 - 基于文本长度估算 + total_tokens = 0 + + # 如果有contents字段 + if "contents" in request_data: + for content in request_data["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + # 简单估算:大约4字符=1token + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 如果有generateContentRequest字段 + elif "generateContentRequest" in request_data: + gen_request = request_data["generateContentRequest"] + if "contents" in gen_request: + for content in gen_request["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 返回Gemini格式的响应 + return JSONResponse(content={"totalTokens": total_tokens}) + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Gemini路由的流式和非流式响应 + 运行方式: python src/router/geminicli/gemini.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Gemini Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Gemini格式) + test_request_body = { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + + # 测试API密钥(模拟) + test_api_key = "pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (POST /v1/models/gemini-2.5-flash:generateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/models/gemini-2.5-flash:generateContent", + json=test_request_body, + params={"key": test_api_key} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试3】流式请求 (POST /v1/models/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试4】假流式请求 (POST /v1/models/假流式/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/假流式/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取text内容 + text = json_data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "") + finish_reason = json_data.get("candidates", [{}])[0].get("finishReason") + print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + def test_anti_truncation_stream_request(): + """测试流式抗截断请求""" + print("\n" + "=" * 80) + print("【测试5】流式抗截断请求 (POST /v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式抗截断响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + # 测试流式抗截断请求 + test_anti_truncation_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() + diff --git a/src/router/geminicli/model_list.py b/src/router/geminicli/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..333a987a8f2d2b99155436ad2bc35349bd2e325f --- /dev/null +++ b/src/router/geminicli/model_list.py @@ -0,0 +1,66 @@ +""" +Gemini CLI Model List Router - Handles model list requests +Gemini CLI 模型列表路由 - 处理模型列表请求 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 第三方库 +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse + +# 本地模块 - 工具和认证 +from src.utils import ( + get_available_models, + get_base_model_from_feature_model, + authenticate_flexible +) + +# 本地模块 - 基础路由工具 +from src.router.base_router import create_gemini_model_list, create_openai_model_list +from src.models import model_to_dict +from log import log + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.get("/v1beta/models") +async def list_gemini_models(token: str = Depends(authenticate_flexible)): + """ + 返回 Gemini 格式的模型列表 + + 使用 create_gemini_model_list 工具函数创建标准格式 + """ + models = get_available_models("gemini") + log.info("[GEMINICLI MODEL LIST] 返回 Gemini 格式") + return JSONResponse(content=create_gemini_model_list( + models, + base_name_extractor=get_base_model_from_feature_model + )) + + +@router.get("/v1/models") +async def list_openai_models(token: str = Depends(authenticate_flexible)): + """ + 返回 OpenAI 格式的模型列表 + + 使用 create_openai_model_list 工具函数创建标准格式 + """ + models = get_available_models("gemini") + log.info("[GEMINICLI MODEL LIST] 返回 OpenAI 格式") + model_list = create_openai_model_list(models, owned_by="google") + return JSONResponse(content={ + "object": "list", + "data": [model_to_dict(model) for model in model_list.data] + }) diff --git a/src/router/geminicli/openai.py b/src/router/geminicli/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..df4fa18959ced8adbb3004123e3075a84f5c5710 --- /dev/null +++ b/src/router/geminicli/openai.py @@ -0,0 +1,615 @@ +""" +OpenAI Router - Handles OpenAI format API requests via GeminiCLI +通过GeminiCLI处理OpenAI格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_openai_fake_stream_chunks, + create_openai_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response + +# 本地模块 - 数据模型 +from src.models import OpenAIChatCompletionRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1/chat/completions") +async def chat_completions( + openai_request: OpenAIChatCompletionRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理OpenAI格式的聊天完成请求(流式和非流式) + + Args: + openai_request: OpenAI格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[GEMINICLI-OPENAI] Request for model: {openai_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(openai_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="openai"): + response = create_health_check_response(format="openai") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(openai_request.model) + use_anti_truncation = is_anti_truncation_model(openai_request.model) + real_model = get_base_model_from_feature_model(openai_request.model) + + # 获取流式标志 + is_streaming = openai_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.openai2gemini import convert_openai_to_gemini_request + gemini_dict = await convert_openai_to_gemini_request(normalized_dict) + + # convert_openai_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_response = convert_gemini_to_openai_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=openai_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + # 发送心跳 + heartbeat = create_openai_heartbeat_chunk() + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 异步发送实际请求 + async def get_response(): + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + return response + + # 创建请求任务 + response_task = create_managed_task(get_response(), name="openai_fake_stream_request") + + try: + # 每3秒发送一次心跳,直到收到响应 + while not response_task.done(): + await asyncio.sleep(3.0) + if not response_task.done(): + yield f"data: {json.dumps(heartbeat)}\n\n".encode() + + # 获取响应结果 + response = await response_task + + except asyncio.CancelledError: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + raise + except Exception as e: + response_task.cancel() + try: + await response_task + except asyncio.CancelledError: + pass + log.error(f"Fake streaming request failed: {e}") + raise + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + # 错误响应 - 提取错误信息并以SSE格式返回 + log.error(f"Fake streaming got error response: status={response.status_code}") + + if hasattr(response, "body"): + error_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + error_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + error_body = str(response) + + try: + error_data = json.loads(error_body) + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + error_data, + real_model, + response.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + except Exception: + # 如果无法解析为JSON,包装成错误对象 + yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"OpenAI fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"OpenAI extracted content: {content}") + log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"OpenAI extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "id": "error", + "object": "chat.completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": real_model, + "choices": [{ + "index": 0, + "delta": {"content": f"Error: {str(e)}"}, + "finish_reason": "error" + }] + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.geminicli import stream_request + from src.converter.anti_truncation import apply_anti_truncation + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + # 定义流式请求函数(返回 StreamingResponse) + async def stream_request_wrapper(payload): + # stream_request 返回异步生成器,需要包装成 StreamingResponse + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts + ) + + # 转换为 OpenAI 格式 + import uuid + response_id = str(uuid.uuid4()) + + # 直接迭代 process_stream() 生成器,并转换为 OpenAI 格式 + async for chunk in processor.process_stream(): + if not chunk: + continue + + # 解析 Gemini SSE 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析 "data: {...}" 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.geminicli import stream_request + from fastapi import Response + import uuid + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + + response_id = str(uuid.uuid4()) + + # yield所有数据,处理可能的错误Response + async for chunk in stream_gen: + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + try: + gemini_error = json.loads(error_content.decode('utf-8')) + # 转换为 OpenAI 格式错误 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + return + else: + # 正常的bytes数据,转换为 OpenAI 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析并转换 Gemini chunk 为 OpenAI 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + else: + return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示OpenAI路由的流式和非流式响应 + 运行方式: python src/router/geminicli/openai.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("OpenAI Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (OpenAI格式) + test_request_body = { + "model": "gemini-2.5-flash", + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /v1/chat/completions)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/chat/completions", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /v1/chat/completions)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/chat/completions", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /v1/chat/completions with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/chat/completions", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取content内容 + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "") + finish_reason = json_data.get("choices", [{}])[0].get("finish_reason") + print(f" 事件 #{event_idx}: content={repr(content[:50])}{'...' if len(content) > 50 else ''}, finish_reason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/hi_check.py b/src/router/hi_check.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ec7a2928235624e3f1217f78dc26a9090e6909 --- /dev/null +++ b/src/router/hi_check.py @@ -0,0 +1,131 @@ +""" +统一的健康检查(Hi消息)处理模块 + +提供对OpenAI、Gemini和Anthropic格式的Hi消息的解析和响应 +""" +import time +from typing import Any, Dict, List + + +# ==================== Hi消息检测 ==================== + +def is_health_check_request(request_data: dict, format: str = "openai") -> bool: + """ + 检查是否是健康检查请求(Hi消息) + + Args: + request_data: 请求数据 + format: 请求格式("openai"、"gemini" 或 "anthropic") + + Returns: + 是否是健康检查请求 + """ + if format == "openai": + # OpenAI格式健康检查: {"messages": [{"role": "user", "content": "Hi"}]} + messages = request_data.get("messages", []) + if len(messages) == 1: + msg = messages[0] + if msg.get("role") == "user" and msg.get("content") == "Hi": + return True + + elif format == "gemini": + # Gemini格式健康检查: {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]} + contents = request_data.get("contents", []) + if len(contents) == 1: + content = contents[0] + if (content.get("role") == "user" and + content.get("parts", [{}])[0].get("text") == "Hi"): + return True + + elif format == "anthropic": + # Anthropic格式健康检查: {"messages": [{"role": "user", "content": "Hi"}]} + messages = request_data.get("messages", []) + if (len(messages) == 1 + and messages[0].get("role") == "user" + and messages[0].get("content") == "Hi"): + return True + + return False + + +def is_health_check_message(messages: List[Dict[str, Any]]) -> bool: + """ + 直接检查消息列表是否为健康检查消息(Anthropic专用) + + 这是一个便捷函数,用于已经提取出消息列表的场景。 + + Args: + messages: 消息列表 + + Returns: + 是否为健康检查消息 + """ + return ( + len(messages) == 1 + and messages[0].get("role") == "user" + and messages[0].get("content") == "Hi" + ) + + +# ==================== Hi消息响应生成 ==================== + +def create_health_check_response(format: str = "openai", **kwargs) -> dict: + """ + 创建健康检查响应 + + Args: + format: 响应格式("openai"、"gemini" 或 "anthropic") + **kwargs: 格式特定的额外参数 + - model: 模型名称(anthropic格式需要) + - message_id: 消息ID(anthropic格式需要) + + Returns: + 健康检查响应字典 + """ + if format == "openai": + # OpenAI格式响应 + return { + "id": "healthcheck", + "object": "chat.completion", + "created": int(time.time()), + "model": "healthcheck", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "API is working" + }, + "finish_reason": "stop" + }] + } + + elif format == "gemini": + # Gemini格式响应 + return { + "candidates": [{ + "content": { + "parts": [{"text": "gcli2api工作中"}], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + }] + } + + elif format == "anthropic": + # Anthropic格式响应 + model = kwargs.get("model", "claude-unknown") + message_id = kwargs.get("message_id", "msg_healthcheck") + return { + "id": message_id, + "type": "message", + "role": "assistant", + "model": str(model), + "content": [{"type": "text", "text": "antigravity Anthropic Messages 正常工作中"}], + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + } + + # 未知格式返回空字典 + return {} diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1184a2f2f8de5bb9897499d910d9862068b9682d --- /dev/null +++ b/src/storage/mongodb_manager.py @@ -0,0 +1,877 @@ +""" +MongoDB 存储管理器 +""" + +import os +import time +import re +from typing import Any, Dict, List, Optional + +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase + +from log import log + + +class MongoDBManager: + """MongoDB 数据库管理器""" + + # 状态字段常量 + STATE_FIELDS = { + "error_codes", + "disabled", + "last_success", + "user_email", + "model_cooldowns", + } + + def __init__(self): + self._client: Optional[AsyncIOMotorClient] = None + self._db: Optional[AsyncIOMotorDatabase] = None + self._initialized = False + + # 内存配置缓存 - 初始化时加载一次 + self._config_cache: Dict[str, Any] = {} + self._config_loaded = False + + async def initialize(self) -> None: + """初始化 MongoDB 连接""" + if self._initialized: + return + + try: + mongodb_uri = os.getenv("MONGODB_URI") + if not mongodb_uri: + raise ValueError("MONGODB_URI environment variable not set") + + database_name = os.getenv("MONGODB_DATABASE", "gcli2api") + + self._client = AsyncIOMotorClient(mongodb_uri) + self._db = self._client[database_name] + + # 测试连接 + await self._db.command("ping") + + # 创建索引 + await self._create_indexes() + + # 加载配置到内存 + await self._load_config_cache() + + self._initialized = True + log.info(f"MongoDB storage initialized (database: {database_name})") + + except Exception as e: + log.error(f"Error initializing MongoDB: {e}") + raise + + async def _create_indexes(self): + """创建索引""" + credentials_collection = self._db["credentials"] + antigravity_credentials_collection = self._db["antigravity_credentials"] + + # 创建普通凭证索引 + await credentials_collection.create_index("filename", unique=True) + await credentials_collection.create_index("disabled") + await credentials_collection.create_index("rotation_order") + + # 复合索引 + await credentials_collection.create_index([("disabled", 1), ("rotation_order", 1)]) + + # 如果经常按错误码筛选,可以添加此索引 + await credentials_collection.create_index("error_codes") + + # 创建 Antigravity 凭证索引 + await antigravity_credentials_collection.create_index("filename", unique=True) + await antigravity_credentials_collection.create_index("disabled") + await antigravity_credentials_collection.create_index("rotation_order") + + # 复合索引 + await antigravity_credentials_collection.create_index([("disabled", 1), ("rotation_order", 1)]) + + # 如果经常按错误码筛选,可以添加此索引 + await antigravity_credentials_collection.create_index("error_codes") + + log.debug("MongoDB indexes created") + + async def _load_config_cache(self): + """加载配置到内存缓存(仅在初始化时调用一次)""" + if self._config_loaded: + return + + try: + config_collection = self._db["config"] + cursor = config_collection.find({}) + + async for doc in cursor: + self._config_cache[doc["key"]] = doc.get("value") + + self._config_loaded = True + log.debug(f"Loaded {len(self._config_cache)} config items into cache") + + except Exception as e: + log.error(f"Error loading config cache: {e}") + self._config_cache = {} + + async def close(self) -> None: + """关闭 MongoDB 连接""" + if self._client: + self._client.close() + self._client = None + self._db = None + self._initialized = False + log.debug("MongoDB storage closed") + + def _ensure_initialized(self): + """确保已初始化""" + if not self._initialized: + raise RuntimeError("MongoDB manager not initialized") + + def _get_collection_name(self, mode: str) -> str: + """根据 mode 获取对应的集合名""" + if mode == "antigravity": + return "antigravity_credentials" + elif mode == "geminicli": + return "credentials" + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + + # ============ SQL 方法 ============ + + async def get_next_available_credential( + self, mode: str = "geminicli", model_key: Optional[str] = None + ) -> Optional[tuple[str, Dict[str, Any]]]: + """ + 随机获取一个可用凭证(负载均衡) + - 未禁用 + - 如果提供了 model_key,还会检查模型级冷却 + - 随机选择 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_key: 模型键(用于模型级冷却检查,antigravity 用模型名,gcli 用 pro/flash) + + Note: + - 对于 antigravity: model_key 是具体模型名(如 "gemini-2.0-flash-exp") + - 对于 gcli: model_key 是 "pro" 或 "flash" + - 使用聚合管道在数据库层面过滤冷却状态,性能更优 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + current_time = time.time() + + # 构建聚合管道 + pipeline = [ + # 第一步: 筛选未禁用的凭证 + {"$match": {"disabled": False}}, + ] + + # 如果提供了 model_key,添加冷却检查 + if model_key: + pipeline.extend([ + # 第二步: 添加冷却状态字段 + { + "$addFields": { + "is_available": { + "$or": [ + # model_cooldowns 中没有该 model_key + {"$not": {"$ifNull": [f"$model_cooldowns.{model_key}", False]}}, + # 或者冷却时间已过期 + {"$lte": [f"$model_cooldowns.{model_key}", current_time]} + ] + } + } + }, + # 第三步: 只保留可用的凭证 + {"$match": {"is_available": True}}, + ]) + + # 第四步: 随机抽取一个 + pipeline.append({"$sample": {"size": 1}}) + + # 第五步: 只投影需要的字段 + pipeline.append({ + "$project": { + "filename": 1, + "credential_data": 1, + "_id": 0 + } + }) + + # 执行聚合 + docs = await collection.aggregate(pipeline).to_list(length=1) + + if docs: + doc = docs[0] + return doc["filename"], doc.get("credential_data") + + return None + + except Exception as e: + log.error(f"Error getting next available credential (mode={mode}, model_key={model_key}): {e}") + return None + + async def get_available_credentials_list(self, mode: str = "geminicli") -> List[str]: + """ + 获取所有可用凭证列表 + - 未禁用 + - 按轮换顺序排序 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + pipeline = [ + {"$match": {"disabled": False}}, + {"$sort": {"rotation_order": 1}}, + {"$project": {"filename": 1, "_id": 0}} + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + return [doc["filename"] for doc in docs] + + except Exception as e: + log.error(f"Error getting available credentials list (mode={mode}): {e}") + return [] + + # ============ StorageBackend 协议方法 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储或更新凭证""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + current_ts = time.time() + + # 使用 upsert + $setOnInsert + # 如果文档存在,只更新 credential_data 和 updated_at + # 如果文档不存在,设置所有默认字段 + + # 先尝试更新现有文档 + result = await collection.update_one( + {"filename": filename}, + { + "$set": { + "credential_data": credential_data, + "updated_at": current_ts, + } + } + ) + + # 如果没有匹配到(新凭证),需要插入 + if result.matched_count == 0: + # 获取下一个 rotation_order + pipeline = [ + {"$group": {"_id": None, "max_order": {"$max": "$rotation_order"}}}, + {"$project": {"_id": 0, "next_order": {"$add": ["$max_order", 1]}}} + ] + + result_list = await collection.aggregate(pipeline).to_list(length=1) + next_order = result_list[0]["next_order"] if result_list else 0 + + # 插入新凭证(使用 insert_one,因为我们已经确认不存在) + try: + await collection.insert_one({ + "filename": filename, + "credential_data": credential_data, + "disabled": False, + "error_codes": [], + "last_success": current_ts, + "user_email": None, + "model_cooldowns": {}, + "rotation_order": next_order, + "call_count": 0, + "created_at": current_ts, + "updated_at": current_ts, + }) + except Exception as insert_error: + # 处理并发插入导致的重复键错误 + if "duplicate key" in str(insert_error).lower(): + # 重试更新 + await collection.update_one( + {"filename": filename}, + {"$set": {"credential_data": credential_data, "updated_at": current_ts}} + ) + else: + raise + + log.debug(f"Stored credential: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error storing credential {filename}: {e}") + return False + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 首先尝试精确匹配,只投影需要的字段 + doc = await collection.find_one( + {"filename": filename}, + {"credential_data": 1, "_id": 0} + ) + if doc: + return doc.get("credential_data") + + # 如果精确匹配失败,尝试使用basename匹配(处理包含路径的旧数据) + # 直接使用 $regex 结尾匹配,移除重复的 $or 条件 + regex_pattern = re.escape(filename) + doc = await collection.find_one( + {"filename": {"$regex": f".*{regex_pattern}$"}}, + {"credential_data": 1, "_id": 0} + ) + + if doc: + return doc.get("credential_data") + + return None + + except Exception as e: + log.error(f"Error getting credential {filename}: {e}") + return None + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用聚合管道 + pipeline = [ + {"$sort": {"rotation_order": 1}}, + {"$project": {"filename": 1, "_id": 0}} + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + return [doc["filename"] for doc in docs] + + except Exception as e: + log.error(f"Error listing credentials: {e}") + return [] + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 首先尝试精确匹配删除 + result = await collection.delete_one({"filename": filename}) + deleted_count = result.deleted_count + + # 如果精确匹配没有删除任何记录,尝试basename匹配 + if deleted_count == 0: + regex_pattern = re.escape(filename) + result = await collection.delete_one({ + "filename": {"$regex": f".*{regex_pattern}$"} + }) + deleted_count = result.deleted_count + + if deleted_count > 0: + log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") + return True + else: + log.warning(f"No credential found to delete: {filename} (mode={mode})") + return False + + except Exception as e: + log.error(f"Error deleting credential {filename}: {e}") + return False + + async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + """ + 获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据) + 用于去重操作 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 email_groups(邮箱分组)、duplicate_count(重复数量)、no_email_count(无邮箱数量)的字典 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用聚合管道,只查询 filename 和 user_email 字段 + pipeline = [ + { + "$project": { + "filename": 1, + "user_email": 1, + "_id": 0 + } + }, + { + "$sort": {"filename": 1} + } + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + + # 按邮箱分组 + email_to_files = {} + no_email_files = [] + + for doc in docs: + filename = doc.get("filename") + user_email = doc.get("user_email") + + if user_email: + if user_email not in email_to_files: + email_to_files[user_email] = [] + email_to_files[user_email].append(filename) + else: + no_email_files.append(filename) + + # 找出重复的邮箱组 + duplicate_groups = [] + total_duplicate_count = 0 + + for email, files in email_to_files.items(): + if len(files) > 1: + # 保留第一个文件,其他为重复 + duplicate_groups.append({ + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + }) + total_duplicate_count += len(files) - 1 + + return { + "email_groups": email_to_files, + "duplicate_groups": duplicate_groups, + "duplicate_count": total_duplicate_count, + "no_email_files": no_email_files, + "no_email_count": len(no_email_files), + "unique_email_count": len(email_to_files), + "total_count": len(docs), + } + + except Exception as e: + log.error(f"Error getting duplicate credentials by email: {e}") + return { + "email_groups": {}, + "duplicate_groups": [], + "duplicate_count": 0, + "no_email_files": [], + "no_email_count": 0, + "unique_email_count": 0, + "total_count": 0, + } + + async def update_credential_state( + self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli" + ) -> bool: + """更新凭证状态,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 过滤只更新状态字段 + valid_updates = { + k: v for k, v in state_updates.items() if k in self.STATE_FIELDS + } + + if not valid_updates: + return True + + valid_updates["updated_at"] = time.time() + + # 首先尝试精确匹配更新 + result = await collection.update_one( + {"filename": filename}, {"$set": valid_updates} + ) + updated_count = result.modified_count + result.matched_count + + # 如果精确匹配没有更新任何记录,尝试basename匹配 + if updated_count == 0: + regex_pattern = re.escape(filename) + result = await collection.update_one( + {"filename": {"$regex": f".*{regex_pattern}$"}}, + {"$set": valid_updates} + ) + updated_count = result.modified_count + result.matched_count + + return updated_count > 0 + + except Exception as e: + log.error(f"Error updating credential state {filename}: {e}") + return False + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 首先尝试精确匹配 + doc = await collection.find_one({"filename": filename}) + + if doc: + return { + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", time.time()), + "user_email": doc.get("user_email"), + "model_cooldowns": doc.get("model_cooldowns", {}), + } + + # 如果精确匹配失败,尝试basename匹配 + regex_pattern = re.escape(filename) + doc = await collection.find_one({ + "filename": {"$regex": f".*{regex_pattern}$"} + }) + + if doc: + return { + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", time.time()), + "user_email": doc.get("user_email"), + "model_cooldowns": doc.get("model_cooldowns", {}), + } + + # 返回默认状态 + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + } + + except Exception as e: + log.error(f"Error getting credential state {filename}: {e}") + return {} + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用投影只获取需要的字段 + cursor = collection.find( + {}, + projection={ + "filename": 1, + "disabled": 1, + "error_codes": 1, + "last_success": 1, + "user_email": 1, + "model_cooldowns": 1, + "_id": 0 + } + ) + + states = {} + current_time = time.time() + + async for doc in cursor: + filename = doc["filename"] + model_cooldowns = doc.get("model_cooldowns", {}) + + # 自动过滤掉已过期的模型CD + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + states[filename] = { + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", time.time()), + "user_email": doc.get("user_email"), + "model_cooldowns": model_cooldowns, + } + + return states + + except Exception as e: + log.error(f"Error getting all credential states: {e}") + return {} + + async def get_credentials_summary( + self, + offset: int = 0, + limit: Optional[int] = None, + status_filter: str = "all", + mode: str = "geminicli", + error_code_filter: Optional[str] = None, + cooldown_filter: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选 + + Args: + offset: 跳过的记录数(默认0) + limit: 返回的最大记录数(None表示返回所有) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + mode: 凭证模式 ("geminicli" 或 "antigravity") + error_code_filter: 错误码筛选(格式如"400"或"403",筛选包含该错误码的凭证) + cooldown_filter: 冷却状态筛选("in_cooldown"=冷却中, "no_cooldown"=未冷却) + + Returns: + 包含 items(凭证列表)、total(总数)、offset、limit 的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择集合名 + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 构建查询条件 + query = {} + if status_filter == "enabled": + query["disabled"] = False + elif status_filter == "disabled": + query["disabled"] = True + + # 错误码筛选 - 兼容存储为数字或字符串的情况 + if error_code_filter and str(error_code_filter).strip().lower() != "all": + filter_value = str(error_code_filter).strip() + query_values = [filter_value] + try: + query_values.append(int(filter_value)) + except ValueError: + pass + query["error_codes"] = {"$in": query_values} + + # 计算全局统计数据(不受筛选条件影响) + global_stats = {"total": 0, "normal": 0, "disabled": 0} + stats_pipeline = [ + { + "$group": { + "_id": "$disabled", + "count": {"$sum": 1} + } + } + ] + + stats_result = await collection.aggregate(stats_pipeline).to_list(length=10) + for item in stats_result: + count = item["count"] + global_stats["total"] += count + if item["_id"]: + global_stats["disabled"] = count + else: + global_stats["normal"] = count + + # 获取所有匹配的文档(用于冷却筛选,因为需要在Python中判断) + cursor = collection.find( + query, + projection={ + "filename": 1, + "disabled": 1, + "error_codes": 1, + "last_success": 1, + "user_email": 1, + "rotation_order": 1, + "model_cooldowns": 1, + "_id": 0 + } + ).sort("rotation_order", 1) + + all_summaries = [] + current_time = time.time() + + async for doc in cursor: + model_cooldowns = doc.get("model_cooldowns", {}) + + # 自动过滤掉已过期的模型CD + active_cooldowns = {} + if model_cooldowns: + active_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + summary = { + "filename": doc["filename"], + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", current_time), + "user_email": doc.get("user_email"), + "rotation_order": doc.get("rotation_order", 0), + "model_cooldowns": active_cooldowns, + } + + # 应用冷却筛选 + if cooldown_filter == "in_cooldown": + # 只保留有冷却的凭证 + if active_cooldowns: + all_summaries.append(summary) + elif cooldown_filter == "no_cooldown": + # 只保留没有冷却的凭证 + if not active_cooldowns: + all_summaries.append(summary) + else: + # 不筛选冷却状态 + all_summaries.append(summary) + + # 应用分页 + total_count = len(all_summaries) + if limit is not None: + summaries = all_summaries[offset:offset + limit] + else: + summaries = all_summaries[offset:] + + return { + "items": summaries, + "total": total_count, + "offset": offset, + "limit": limit, + "stats": global_stats, + } + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return { + "items": [], + "total": 0, + "offset": offset, + "limit": limit, + "stats": {"total": 0, "normal": 0, "disabled": 0}, + } + + # ============ 配置管理(内存缓存)============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置(写入数据库 + 更新内存缓存)""" + self._ensure_initialized() + + try: + config_collection = self._db["config"] + await config_collection.update_one( + {"key": key}, + {"$set": {"value": value, "updated_at": time.time()}}, + upsert=True, + ) + + # 更新内存缓存 + self._config_cache[key] = value + return True + + except Exception as e: + log.error(f"Error setting config {key}: {e}") + return False + + async def reload_config_cache(self): + """重新加载配置缓存(在批量修改配置后调用)""" + self._ensure_initialized() + self._config_loaded = False + await self._load_config_cache() + log.info("Config cache reloaded from database") + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.get(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.copy() + + async def delete_config(self, key: str) -> bool: + """删除配置""" + self._ensure_initialized() + + try: + config_collection = self._db["config"] + result = await config_collection.delete_one({"key": key}) + + # 从内存缓存移除 + self._config_cache.pop(key, None) + return result.deleted_count > 0 + + except Exception as e: + log.error(f"Error deleting config {key}: {e}") + return False + + # ============ 模型级冷却管理 ============ + + async def set_model_cooldown( + self, + filename: str, + model_key: str, + cooldown_until: Optional[float], + mode: str = "geminicli" + ) -> bool: + """ + 设置特定模型的冷却时间 + + Args: + filename: 凭证文件名 + model_key: 模型键(antigravity 用模型名,gcli 用 pro/flash) + cooldown_until: 冷却截止时间戳(None 表示清除冷却) + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 是否成功 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用原子操作直接更新,避免竞态条件 + if cooldown_until is None: + # 删除指定模型的冷却 + result = await collection.update_one( + {"filename": filename}, + { + "$unset": {f"model_cooldowns.{model_key}": ""}, + "$set": {"updated_at": time.time()} + } + ) + else: + # 设置冷却时间 + result = await collection.update_one( + {"filename": filename}, + { + "$set": { + f"model_cooldowns.{model_key}": cooldown_until, + "updated_at": time.time() + } + } + ) + + if result.matched_count == 0: + log.warning(f"Credential {filename} not found") + return False + + log.debug(f"Set model cooldown: {filename}, model_key={model_key}, cooldown_until={cooldown_until}") + return True + + except Exception as e: + log.error(f"Error setting model cooldown for {filename}: {e}") + return False diff --git a/src/storage/sqlite_manager.py b/src/storage/sqlite_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e21a640008f5e96c49ac662f57277b0a0e7d6712 --- /dev/null +++ b/src/storage/sqlite_manager.py @@ -0,0 +1,1007 @@ +""" +SQLite 存储管理器 +""" + +import asyncio +import json +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import aiosqlite + +from log import log + + +class SQLiteManager: + """SQLite 数据库管理器""" + + # 状态字段常量 + STATE_FIELDS = { + "error_codes", + "disabled", + "last_success", + "user_email", + "model_cooldowns", + } + + # 所有必需的列定义(用于自动校验和修复) + REQUIRED_COLUMNS = { + "credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("last_success", "REAL"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "REAL DEFAULT (unixepoch())"), + ("updated_at", "REAL DEFAULT (unixepoch())") + ], + "antigravity_credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("last_success", "REAL"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "REAL DEFAULT (unixepoch())"), + ("updated_at", "REAL DEFAULT (unixepoch())") + ] + } + + def __init__(self): + self._db_path = None + self._credentials_dir = None + self._initialized = False + self._lock = asyncio.Lock() + + # 内存配置缓存 - 初始化时加载一次 + self._config_cache: Dict[str, Any] = {} + self._config_loaded = False + + async def initialize(self) -> None: + """初始化 SQLite 数据库""" + if self._initialized: + return + + async with self._lock: + if self._initialized: + return + + try: + # 获取凭证目录 + self._credentials_dir = os.getenv("CREDENTIALS_DIR", "./creds") + self._db_path = os.path.join(self._credentials_dir, "credentials.db") + + # 确保目录存在 + os.makedirs(self._credentials_dir, exist_ok=True) + + # 创建数据库和表 + async with aiosqlite.connect(self._db_path) as db: + # 启用 WAL 模式(提升并发性能) + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA foreign_keys=ON") + + # 检查并自动修复数据库结构 + await self._ensure_schema_compatibility(db) + + # 创建表 + await self._create_tables(db) + + await db.commit() + + # 加载配置到内存 + await self._load_config_cache() + + self._initialized = True + log.info(f"SQLite storage initialized at {self._db_path}") + + except Exception as e: + log.error(f"Error initializing SQLite: {e}") + raise + + async def _ensure_schema_compatibility(self, db: aiosqlite.Connection) -> None: + """ + 确保数据库结构兼容,自动修复缺失的列 + """ + try: + # 检查每个表 + for table_name, columns in self.REQUIRED_COLUMNS.items(): + # 检查表是否存在 + async with db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) as cursor: + if not await cursor.fetchone(): + log.debug(f"Table {table_name} does not exist, will be created") + continue + + # 获取现有列 + async with db.execute(f"PRAGMA table_info({table_name})") as cursor: + existing_columns = {row[1] for row in await cursor.fetchall()} + + # 添加缺失的列 + added_count = 0 + for col_name, col_def in columns: + if col_name not in existing_columns: + try: + await db.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}") + log.info(f"Added missing column {table_name}.{col_name}") + added_count += 1 + except Exception as e: + log.error(f"Failed to add column {table_name}.{col_name}: {e}") + + if added_count > 0: + log.info(f"Table {table_name}: added {added_count} missing column(s)") + + except Exception as e: + log.error(f"Error ensuring schema compatibility: {e}") + # 不抛出异常,允许继续初始化 + + async def _create_tables(self, db: aiosqlite.Connection): + """创建数据库表和索引""" + # 凭证表 + await db.execute(""" + CREATE TABLE IF NOT EXISTS credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + -- 状态字段 + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + last_success REAL, + user_email TEXT, + + -- 模型级 CD 支持 (JSON: {model_key: cooldown_timestamp}) + model_cooldowns TEXT DEFAULT '{}', + + -- 轮换相关 + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + -- 时间戳 + created_at REAL DEFAULT (unixepoch()), + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + # Antigravity 凭证表(结构相同但独立存储) + await db.execute(""" + CREATE TABLE IF NOT EXISTS antigravity_credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + -- 状态字段 + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + last_success REAL, + user_email TEXT, + + -- 模型级 CD 支持 (JSON: {model_name: cooldown_timestamp}) + model_cooldowns TEXT DEFAULT '{}', + + -- 轮换相关 + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + -- 时间戳 + created_at REAL DEFAULT (unixepoch()), + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + # 创建索引 - 普通凭证表 + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_disabled + ON credentials(disabled) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_rotation_order + ON credentials(rotation_order) + """) + + # 创建索引 - Antigravity 凭证表 + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_disabled + ON antigravity_credentials(disabled) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_rotation_order + ON antigravity_credentials(rotation_order) + """) + + # 配置表 + await db.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + log.debug("SQLite tables and indexes created") + + async def _load_config_cache(self): + """加载配置到内存缓存(仅在初始化时调用一次)""" + if self._config_loaded: + return + + try: + async with aiosqlite.connect(self._db_path) as db: + async with db.execute("SELECT key, value FROM config") as cursor: + rows = await cursor.fetchall() + + for key, value in rows: + try: + self._config_cache[key] = json.loads(value) + except json.JSONDecodeError: + self._config_cache[key] = value + + self._config_loaded = True + log.debug(f"Loaded {len(self._config_cache)} config items into cache") + + except Exception as e: + log.error(f"Error loading config cache: {e}") + self._config_cache = {} + + async def close(self) -> None: + """关闭数据库连接""" + self._initialized = False + log.debug("SQLite storage closed") + + def _ensure_initialized(self): + """确保已初始化""" + if not self._initialized: + raise RuntimeError("SQLite manager not initialized") + + def _get_table_name(self, mode: str) -> str: + """根据 mode 获取对应的表名""" + if mode == "antigravity": + return "antigravity_credentials" + elif mode == "geminicli": + return "credentials" + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + + # ============ SQL 方法 ============ + + async def get_next_available_credential( + self, mode: str = "geminicli", model_key: Optional[str] = None + ) -> Optional[Tuple[str, Dict[str, Any]]]: + """ + 随机获取一个可用凭证(负载均衡) + - 未禁用 + - 如果提供了 model_key,还会检查模型级冷却 + - 随机选择 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_key: 模型键(用于模型级冷却检查,antigravity 用模型名,gcli 用 pro/flash) + + Note: + - 对于 antigravity: model_key 是具体模型名(如 "gemini-2.0-flash-exp") + - 对于 gcli: model_key 是 "pro" 或 "flash" + """ + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + current_time = time.time() + + # 获取所有候选凭证(未禁用) + async with db.execute(f""" + SELECT filename, credential_data, model_cooldowns + FROM {table_name} + WHERE disabled = 0 + ORDER BY RANDOM() + """) as cursor: + rows = await cursor.fetchall() + + # 如果没有提供 model_key,使用第一个可用凭证 + if not model_key: + if rows: + filename, credential_json, _ = rows[0] + credential_data = json.loads(credential_json) + return filename, credential_data + return None + + # 如果提供了 model_key,检查模型级冷却 + for filename, credential_json, model_cooldowns_json in rows: + model_cooldowns = json.loads(model_cooldowns_json or '{}') + + # 检查该模型是否在冷却中 + model_cooldown = model_cooldowns.get(model_key) + if model_cooldown is None or current_time >= model_cooldown: + # 该模型未冷却或冷却已过期 + credential_data = json.loads(credential_json) + return filename, credential_data + + return None + + except Exception as e: + log.error(f"Error getting next available credential (mode={mode}, model_key={model_key}): {e}") + return None + + async def get_available_credentials_list(self) -> List[str]: + """ + 获取所有可用凭证列表 + - 未禁用 + - 按轮换顺序排序 + """ + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + async with db.execute(""" + SELECT filename + FROM credentials + WHERE disabled = 0 + ORDER BY rotation_order ASC + """) as cursor: + rows = await cursor.fetchall() + return [row[0] for row in rows] + + except Exception as e: + log.error(f"Error getting available credentials list: {e}") + return [] + + # ============ StorageBackend 协议方法 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储或更新凭证""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 检查凭证是否存在 + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, + rotation_order, call_count + FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + existing = await cursor.fetchone() + + if existing: + # 更新现有凭证(保留状态) + await db.execute(f""" + UPDATE {table_name} + SET credential_data = ?, + updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(credential_data), filename)) + else: + # 插入新凭证 + async with db.execute(f""" + SELECT COALESCE(MAX(rotation_order), -1) + 1 FROM {table_name} + """) as cursor: + row = await cursor.fetchone() + next_order = row[0] + + await db.execute(f""" + INSERT INTO {table_name} + (filename, credential_data, rotation_order, last_success) + VALUES (?, ?, ?, ?) + """, (filename, json.dumps(credential_data), next_order, time.time())) + + await db.commit() + log.debug(f"Stored credential: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error storing credential {filename}: {e}") + return False + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 首先尝试精确匹配 + async with db.execute(f""" + SELECT credential_data FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + if row: + return json.loads(row[0]) + + # 如果精确匹配失败,尝试使用basename匹配(处理包含路径的旧数据) + async with db.execute(f""" + SELECT credential_data FROM {table_name} + WHERE filename LIKE '%' || ? OR filename = ? + """, (filename, filename)) as cursor: + rows = await cursor.fetchall() + # 优先返回完全匹配的,否则返回basename匹配的第一个 + for row in rows: + return json.loads(row[0]) + + return None + + except Exception as e: + log.error(f"Error getting credential {filename}: {e}") + return None + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名(包括禁用的)""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + async with db.execute(f""" + SELECT filename FROM {table_name} ORDER BY rotation_order + """) as cursor: + rows = await cursor.fetchall() + return [row[0] for row in rows] + + except Exception as e: + log.error(f"Error listing credentials: {e}") + return [] + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 首先尝试精确匹配删除 + result = await db.execute(f""" + DELETE FROM {table_name} WHERE filename = ? + """, (filename,)) + deleted_count = result.rowcount + + # 如果精确匹配没有删除任何记录,尝试basename匹配 + if deleted_count == 0: + result = await db.execute(f""" + DELETE FROM {table_name} WHERE filename LIKE '%' || ? + """, (filename,)) + deleted_count = result.rowcount + + await db.commit() + + if deleted_count > 0: + log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") + return True + else: + log.warning(f"No credential found to delete: {filename} (mode={mode})") + return False + + except Exception as e: + log.error(f"Error deleting credential {filename}: {e}") + return False + + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + log.debug(f"[DB] update_credential_state 开始: filename={filename}, state_updates={state_updates}, mode={mode}, table={table_name}") + + # 构建动态 SQL + set_clauses = [] + values = [] + + for key, value in state_updates.items(): + if key in self.STATE_FIELDS: + if key == "error_codes": + set_clauses.append(f"{key} = ?") + values.append(json.dumps(value)) + elif key == "model_cooldowns": + set_clauses.append(f"{key} = ?") + values.append(json.dumps(value)) + else: + set_clauses.append(f"{key} = ?") + values.append(value) + + if not set_clauses: + log.info(f"[DB] 没有需要更新的状态字段") + return True + + set_clauses.append("updated_at = unixepoch()") + values.append(filename) + + log.debug(f"[DB] SQL参数: set_clauses={set_clauses}, values={values}") + + async with aiosqlite.connect(self._db_path) as db: + # 首先尝试精确匹配更新 + sql_exact = f""" + UPDATE {table_name} + SET {', '.join(set_clauses)} + WHERE filename = ? + """ + log.debug(f"[DB] 执行精确匹配SQL: {sql_exact}") + log.debug(f"[DB] SQL参数值: {values}") + + result = await db.execute(sql_exact, values) + updated_count = result.rowcount + log.debug(f"[DB] 精确匹配 rowcount={updated_count}") + + # 如果精确匹配没有更新任何记录,尝试basename匹配 + if updated_count == 0: + sql_basename = f""" + UPDATE {table_name} + SET {', '.join(set_clauses)} + WHERE filename LIKE '%' || ? + """ + log.debug(f"[DB] 精确匹配失败,尝试basename匹配SQL: {sql_basename}") + result = await db.execute(sql_basename, values) + updated_count = result.rowcount + log.info(f"[DB] basename匹配 rowcount={updated_count}") + + # 提交前检查 + log.debug(f"[DB] 准备commit,总更新行数={updated_count}") + await db.commit() + log.debug(f"[DB] commit完成") + + success = updated_count > 0 + log.debug(f"[DB] update_credential_state 结束: success={success}, updated_count={updated_count}") + return success + + except Exception as e: + log.error(f"[DB] Error updating credential state {filename}: {e}", exc_info=True) + return False + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态,支持basename匹配以兼容旧数据""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 首先尝试精确匹配 + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns + FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if row: + error_codes_json = row[1] or '[]' + model_cooldowns_json = row[4] or '{}' + return { + "disabled": bool(row[0]), + "error_codes": json.loads(error_codes_json), + "last_success": row[2] or time.time(), + "user_email": row[3], + "model_cooldowns": json.loads(model_cooldowns_json), + } + + # 如果精确匹配失败,尝试basename匹配 + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns + FROM {table_name} WHERE filename LIKE '%' || ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if row: + error_codes_json = row[1] or '[]' + model_cooldowns_json = row[4] or '{}' + return { + "disabled": bool(row[0]), + "error_codes": json.loads(error_codes_json), + "last_success": row[2] or time.time(), + "user_email": row[3], + "model_cooldowns": json.loads(model_cooldowns_json), + } + + # 返回默认状态 + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + } + + except Exception as e: + log.error(f"Error getting credential state {filename}: {e}") + return {} + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + async with db.execute(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, model_cooldowns + FROM {table_name} + """) as cursor: + rows = await cursor.fetchall() + + states = {} + current_time = time.time() + + for row in rows: + filename = row[0] + error_codes_json = row[2] or '[]' + model_cooldowns_json = row[5] or '{}' + model_cooldowns = json.loads(model_cooldowns_json) + + # 自动过滤掉已过期的模型CD + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + states[filename] = { + "disabled": bool(row[1]), + "error_codes": json.loads(error_codes_json), + "last_success": row[3] or time.time(), + "user_email": row[4], + "model_cooldowns": model_cooldowns, + } + + return states + + except Exception as e: + log.error(f"Error getting all credential states: {e}") + return {} + + async def get_credentials_summary( + self, + offset: int = 0, + limit: Optional[int] = None, + status_filter: str = "all", + mode: str = "geminicli", + error_code_filter: Optional[str] = None, + cooldown_filter: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选 + + Args: + offset: 跳过的记录数(默认0) + limit: 返回的最大记录数(None表示返回所有) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + mode: 凭证模式 ("geminicli" 或 "antigravity") + error_code_filter: 错误码筛选(格式如"400"或"403",筛选包含该错误码的凭证) + cooldown_filter: 冷却状态筛选("in_cooldown"=冷却中, "no_cooldown"=未冷却) + + Returns: + 包含 items(凭证列表)、total(总数)、offset、limit 的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择表名 + table_name = self._get_table_name(mode) + + async with aiosqlite.connect(self._db_path) as db: + # 先计算全局统计数据(不受筛选条件影响) + global_stats = {"total": 0, "normal": 0, "disabled": 0} + async with db.execute(f""" + SELECT disabled, COUNT(*) FROM {table_name} GROUP BY disabled + """) as stats_cursor: + stats_rows = await stats_cursor.fetchall() + for disabled, count in stats_rows: + global_stats["total"] += count + if disabled: + global_stats["disabled"] = count + else: + global_stats["normal"] = count + + # 构建WHERE子句 + where_clauses = [] + count_params = [] + + if status_filter == "enabled": + where_clauses.append("disabled = 0") + elif status_filter == "disabled": + where_clauses.append("disabled = 1") + + filter_value = None + filter_int = None + if error_code_filter and str(error_code_filter).strip().lower() != "all": + filter_value = str(error_code_filter).strip() + try: + filter_int = int(filter_value) + except ValueError: + filter_int = None + + # 构建WHERE子句 + where_clause = "" + if where_clauses: + where_clause = "WHERE " + " AND ".join(where_clauses) + + # 先获取所有数据(用于冷却筛选,因为需要在Python中判断) + all_query = f""" + SELECT filename, disabled, error_codes, last_success, + user_email, rotation_order, model_cooldowns + FROM {table_name} + {where_clause} + ORDER BY rotation_order + """ + + async with db.execute(all_query, count_params) as cursor: + all_rows = await cursor.fetchall() + + current_time = time.time() + all_summaries = [] + + for row in all_rows: + filename = row[0] + error_codes_json = row[2] or '[]' + model_cooldowns_json = row[6] or '{}' + model_cooldowns = json.loads(model_cooldowns_json) + + # 自动过滤掉已过期的模型CD + active_cooldowns = {} + if model_cooldowns: + active_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + error_codes = json.loads(error_codes_json) + if filter_value: + match = False + for code in error_codes: + if code == filter_value or code == filter_int: + match = True + break + if isinstance(code, str) and filter_int is not None: + try: + if int(code) == filter_int: + match = True + break + except ValueError: + pass + if not match: + continue + + summary = { + "filename": filename, + "disabled": bool(row[1]), + "error_codes": error_codes, + "last_success": row[3] or current_time, + "user_email": row[4], + "rotation_order": row[5], + "model_cooldowns": active_cooldowns, + } + + # 应用冷却筛选 + if cooldown_filter == "in_cooldown": + # 只保留有冷却的凭证 + if active_cooldowns: + all_summaries.append(summary) + elif cooldown_filter == "no_cooldown": + # 只保留没有冷却的凭证 + if not active_cooldowns: + all_summaries.append(summary) + else: + # 不筛选冷却状态 + all_summaries.append(summary) + + # 应用分页 + total_count = len(all_summaries) + if limit is not None: + summaries = all_summaries[offset:offset + limit] + else: + summaries = all_summaries[offset:] + + return { + "items": summaries, + "total": total_count, + "offset": offset, + "limit": limit, + "stats": global_stats, + } + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return { + "items": [], + "total": 0, + "offset": offset, + "limit": limit, + "stats": {"total": 0, "normal": 0, "disabled": 0}, + } + + async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + """ + 获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据) + 用于去重操作 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 email_groups(邮箱分组)、duplicate_count(重复数量)、no_email_count(无邮箱数量)的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择表名 + table_name = self._get_table_name(mode) + + async with aiosqlite.connect(self._db_path) as db: + # 查询所有凭证的文件名和邮箱(不加载完整凭证数据) + query = f""" + SELECT filename, user_email + FROM {table_name} + ORDER BY filename + """ + + async with db.execute(query) as cursor: + rows = await cursor.fetchall() + + # 按邮箱分组 + email_to_files = {} + no_email_files = [] + + for filename, user_email in rows: + if user_email: + if user_email not in email_to_files: + email_to_files[user_email] = [] + email_to_files[user_email].append(filename) + else: + no_email_files.append(filename) + + # 找出重复的邮箱组 + duplicate_groups = [] + total_duplicate_count = 0 + + for email, files in email_to_files.items(): + if len(files) > 1: + # 保留第一个文件,其他为重复 + duplicate_groups.append({ + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + }) + total_duplicate_count += len(files) - 1 + + return { + "email_groups": email_to_files, + "duplicate_groups": duplicate_groups, + "duplicate_count": total_duplicate_count, + "no_email_files": no_email_files, + "no_email_count": len(no_email_files), + "unique_email_count": len(email_to_files), + "total_count": len(rows), + } + + except Exception as e: + log.error(f"Error getting duplicate credentials by email: {e}") + return { + "email_groups": {}, + "duplicate_groups": [], + "duplicate_count": 0, + "no_email_files": [], + "no_email_count": 0, + "unique_email_count": 0, + "total_count": 0, + } + + # ============ 配置管理(内存缓存)============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置(写入数据库 + 更新内存缓存)""" + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + await db.execute(""" + INSERT INTO config (key, value, updated_at) + VALUES (?, ?, unixepoch()) + ON CONFLICT(key) DO UPDATE SET + value = excluded.value, + updated_at = excluded.updated_at + """, (key, json.dumps(value))) + await db.commit() + + # 更新内存缓存 + self._config_cache[key] = value + return True + + except Exception as e: + log.error(f"Error setting config {key}: {e}") + return False + + async def reload_config_cache(self): + """重新加载配置缓存(在批量修改配置后调用)""" + self._ensure_initialized() + self._config_loaded = False + await self._load_config_cache() + log.info("Config cache reloaded from database") + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.get(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.copy() + + async def delete_config(self, key: str) -> bool: + """删除配置""" + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + await db.execute("DELETE FROM config WHERE key = ?", (key,)) + await db.commit() + + # 从内存缓存移除 + self._config_cache.pop(key, None) + return True + + except Exception as e: + log.error(f"Error deleting config {key}: {e}") + return False + + # ============ 模型级冷却管理 ============ + + async def set_model_cooldown( + self, + filename: str, + model_key: str, + cooldown_until: Optional[float], + mode: str = "geminicli" + ) -> bool: + """ + 设置特定模型的冷却时间 + + Args: + filename: 凭证文件名 + model_key: 模型键(antigravity 用模型名,gcli 用 pro/flash) + cooldown_until: 冷却截止时间戳(None 表示清除冷却) + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 是否成功 + """ + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 获取当前的 model_cooldowns + async with db.execute(f""" + SELECT model_cooldowns FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if not row: + log.warning(f"Credential {filename} not found") + return False + + model_cooldowns = json.loads(row[0] or '{}') + + # 更新或删除指定模型的冷却时间 + if cooldown_until is None: + model_cooldowns.pop(model_key, None) + else: + model_cooldowns[model_key] = cooldown_until + + # 写回数据库 + await db.execute(f""" + UPDATE {table_name} + SET model_cooldowns = ?, + updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(model_cooldowns), filename)) + await db.commit() + + log.debug(f"Set model cooldown: {filename}, model_key={model_key}, cooldown_until={cooldown_until}") + return True + + except Exception as e: + log.error(f"Error setting model cooldown for {filename}: {e}") + return False diff --git a/src/storage_adapter.py b/src/storage_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..9be85a566aa2b9f164b7868311cac7c0ca822d98 --- /dev/null +++ b/src/storage_adapter.py @@ -0,0 +1,313 @@ +""" +存储适配器,提供统一的接口来处理 SQLite 和 MongoDB 存储。 +根据配置自动选择存储后端: +- 默认使用 SQLite(本地文件存储) +- 如果设置了 MONGODB_URI 环境变量,则使用 MongoDB +""" + +import asyncio +import json +import os +from typing import Any, Dict, List, Optional, Protocol + +from log import log + + +class StorageBackend(Protocol): + """存储后端协议""" + + async def initialize(self) -> None: + """初始化存储后端""" + ... + + async def close(self) -> None: + """关闭存储后端""" + ... + + # 凭证管理 + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储凭证数据""" + ... + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + ... + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + ... + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + ... + + # 状态管理 + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + ... + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态""" + ... + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + ... + + # 配置管理 + async def set_config(self, key: str, value: Any) -> bool: + """设置配置项""" + ... + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置项""" + ... + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置""" + ... + + async def delete_config(self, key: str) -> bool: + """删除配置项""" + ... + + +class StorageAdapter: + """存储适配器,根据配置选择存储后端""" + + def __init__(self): + self._backend: Optional["StorageBackend"] = None + self._initialized = False + self._lock = asyncio.Lock() + + async def initialize(self) -> None: + """初始化存储适配器""" + async with self._lock: + if self._initialized: + return + + # 按优先级检查存储后端:SQLite > MongoDB + mongodb_uri = os.getenv("MONGODB_URI", "") + + # 优先使用 SQLite(默认启用,无需环境变量) + if not mongodb_uri: + try: + from .storage.sqlite_manager import SQLiteManager + + self._backend = SQLiteManager() + await self._backend.initialize() + log.info("Using SQLite storage backend") + except Exception as e: + log.error(f"Failed to initialize SQLite backend: {e}") + raise RuntimeError("No storage backend available") from e + else: + # 使用 MongoDB + try: + from .storage.mongodb_manager import MongoDBManager + + self._backend = MongoDBManager() + await self._backend.initialize() + log.info("Using MongoDB storage backend") + except Exception as e: + log.error(f"Failed to initialize MongoDB backend: {e}") + # 尝试降级到 SQLite + log.info("Falling back to SQLite storage backend") + try: + from .storage.sqlite_manager import SQLiteManager + + self._backend = SQLiteManager() + await self._backend.initialize() + log.info("Using SQLite storage backend (fallback)") + except Exception as e2: + log.error(f"Failed to initialize SQLite backend: {e2}") + raise RuntimeError("No storage backend available") from e2 + + self._initialized = True + + async def close(self) -> None: + """关闭存储适配器""" + if self._backend: + await self._backend.close() + self._backend = None + self._initialized = False + + def _ensure_initialized(self): + """确保存储适配器已初始化""" + if not self._initialized or not self._backend: + raise RuntimeError("Storage adapter not initialized") + + # ============ 凭证管理 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储凭证数据""" + self._ensure_initialized() + return await self._backend.store_credential(filename, credential_data, mode) + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + self._ensure_initialized() + return await self._backend.get_credential(filename, mode) + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + self._ensure_initialized() + return await self._backend.list_credentials(mode) + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + self._ensure_initialized() + return await self._backend.delete_credential(filename, mode) + + # ============ 状态管理 ============ + + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + self._ensure_initialized() + return await self._backend.update_credential_state(filename, state_updates, mode) + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态""" + self._ensure_initialized() + return await self._backend.get_credential_state(filename, mode) + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + self._ensure_initialized() + return await self._backend.get_all_credential_states(mode) + + # ============ 配置管理 ============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置项""" + self._ensure_initialized() + return await self._backend.set_config(key, value) + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置项""" + self._ensure_initialized() + return await self._backend.get_config(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置""" + self._ensure_initialized() + return await self._backend.get_all_config() + + async def delete_config(self, key: str) -> bool: + """删除配置项""" + self._ensure_initialized() + return await self._backend.delete_config(key) + + # ============ 工具方法 ============ + + async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool: + """将凭证导出为JSON文件""" + self._ensure_initialized() + if hasattr(self._backend, "export_credential_to_json"): + return await self._backend.export_credential_to_json(filename, output_path) + # MongoDB后端的fallback实现 + credential_data = await self.get_credential(filename) + if credential_data is None: + return False + + if output_path is None: + output_path = f"{filename}.json" + + import aiofiles + + try: + async with aiofiles.open(output_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(credential_data, indent=2, ensure_ascii=False)) + return True + except Exception: + return False + + async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool: + """从JSON文件导入凭证""" + self._ensure_initialized() + if hasattr(self._backend, "import_credential_from_json"): + return await self._backend.import_credential_from_json(json_path, filename) + # MongoDB后端的fallback实现 + try: + import aiofiles + + async with aiofiles.open(json_path, "r", encoding="utf-8") as f: + content = await f.read() + + credential_data = json.loads(content) + + if filename is None: + filename = os.path.basename(json_path) + + return await self.store_credential(filename, credential_data) + except Exception: + return False + + def get_backend_type(self) -> str: + """获取当前存储后端类型""" + if not self._backend: + return "none" + + # 检查后端类型 + backend_class_name = self._backend.__class__.__name__ + if "SQLite" in backend_class_name or "sqlite" in backend_class_name.lower(): + return "sqlite" + elif "MongoDB" in backend_class_name or "mongo" in backend_class_name.lower(): + return "mongodb" + else: + return "unknown" + + async def get_backend_info(self) -> Dict[str, Any]: + """获取存储后端信息""" + self._ensure_initialized() + + backend_type = self.get_backend_type() + info = {"backend_type": backend_type, "initialized": self._initialized} + + # 获取底层存储信息 + if hasattr(self._backend, "get_database_info"): + try: + db_info = await self._backend.get_database_info() + info.update(db_info) + except Exception as e: + info["database_error"] = str(e) + else: + backend_type = self.get_backend_type() + if backend_type == "sqlite": + info.update( + { + "database_path": getattr(self._backend, "_db_path", None), + "credentials_dir": getattr(self._backend, "_credentials_dir", None), + } + ) + elif backend_type == "mongodb": + info.update( + { + "database_name": getattr(self._backend, "_db", {}).name if hasattr(self._backend, "_db") else None, + } + ) + + return info + + +# 全局存储适配器实例 +_storage_adapter: Optional[StorageAdapter] = None + + +async def get_storage_adapter() -> StorageAdapter: + """获取全局存储适配器实例""" + global _storage_adapter + + if _storage_adapter is None: + _storage_adapter = StorageAdapter() + await _storage_adapter.initialize() + + return _storage_adapter + + +async def close_storage_adapter(): + """关闭全局存储适配器""" + global _storage_adapter + + if _storage_adapter: + await _storage_adapter.close() + _storage_adapter = None diff --git a/src/task_manager.py b/src/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0343ab1918c8e871e706b7a63676c536fb42af --- /dev/null +++ b/src/task_manager.py @@ -0,0 +1,143 @@ +""" +Global task lifecycle management module +管理应用程序中所有异步任务的生命周期,确保正确清理 +""" + +import asyncio +import weakref +from typing import Any, Dict, Set + +from log import log + + +class TaskManager: + """全局异步任务管理器 - 单例模式""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._tasks: Set[asyncio.Task] = set() + self._resources: Set[Any] = set() # 需要关闭的资源 + self._shutdown_event = asyncio.Event() + self._initialized = True + log.debug("TaskManager initialized") + + def register_task(self, task: asyncio.Task, description: str = None) -> asyncio.Task: + """注册一个任务供生命周期管理""" + self._tasks.add(task) + task.add_done_callback(lambda t: self._tasks.discard(t)) + + if description: + task.set_name(description) + + log.debug(f"Registered task: {task.get_name() or 'unnamed'}") + return task + + def create_task(self, coro, *, name: str = None) -> asyncio.Task: + """创建并注册一个任务""" + task = asyncio.create_task(coro, name=name) + return self.register_task(task, name) + + def register_resource(self, resource: Any) -> Any: + """注册一个需要清理的资源(如HTTP客户端、文件句柄等)""" + # 使用弱引用避免循环引用 + self._resources.add(weakref.ref(resource)) + log.debug(f"Registered resource: {type(resource).__name__}") + return resource + + async def shutdown(self, timeout: float = 30.0): + """关闭所有任务和资源""" + log.info("TaskManager shutdown initiated") + + # 设置关闭标志 + self._shutdown_event.set() + + # 取消所有未完成的任务 + cancelled_count = 0 + for task in list(self._tasks): + if not task.done(): + task.cancel() + cancelled_count += 1 + + if cancelled_count > 0: + log.info(f"Cancelled {cancelled_count} pending tasks") + + # 等待所有任务完成(包括取消) + if self._tasks: + try: + await asyncio.wait_for( + asyncio.gather(*self._tasks, return_exceptions=True), timeout=timeout + ) + except asyncio.TimeoutError: + log.warning(f"Some tasks did not complete within {timeout}s timeout") + + # 清理资源 - 改进弱引用处理 + cleaned_resources = 0 + failed_resources = 0 + for resource_ref in list(self._resources): + resource = resource_ref() + if resource is not None: + try: + if hasattr(resource, "close"): + if asyncio.iscoroutinefunction(resource.close): + await resource.close() + else: + resource.close() + elif hasattr(resource, "aclose"): + await resource.aclose() + cleaned_resources += 1 + except Exception as e: + log.warning(f"Failed to close resource {type(resource).__name__}: {e}") + failed_resources += 1 + # 如果弱引用已失效,资源已经被自动回收,无需操作 + + if cleaned_resources > 0: + log.info(f"Cleaned up {cleaned_resources} resources") + if failed_resources > 0: + log.warning(f"Failed to clean {failed_resources} resources") + + self._tasks.clear() + self._resources.clear() + log.info("TaskManager shutdown completed") + + @property + def is_shutdown(self) -> bool: + """检查是否已经开始关闭""" + return self._shutdown_event.is_set() + + def get_stats(self) -> Dict[str, int]: + """获取任务管理统计信息""" + return { + "active_tasks": len(self._tasks), + "registered_resources": len(self._resources), + "is_shutdown": self.is_shutdown, + } + + +# 全局任务管理器实例 +task_manager = TaskManager() + + +def create_managed_task(coro, *, name: str = None) -> asyncio.Task: + """创建一个被管理的异步任务的便捷函数""" + return task_manager.create_task(coro, name=name) + + +def register_resource(resource: Any) -> Any: + """注册资源的便捷函数""" + return task_manager.register_resource(resource) + + +async def shutdown_all_tasks(timeout: float = 30.0): + """关闭所有任务的便捷函数""" + await task_manager.shutdown(timeout) diff --git a/src/token_estimator.py b/src/token_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ed82cf28ce0d5805f6d3debe0531e16eb3db5290 --- /dev/null +++ b/src/token_estimator.py @@ -0,0 +1,30 @@ +"""简单的 token 估算,不追求精确""" +from __future__ import annotations + +from typing import Any, Dict + + +def estimate_input_tokens(payload: Dict[str, Any]) -> int: + """粗略估算 token 数:字符数 / 4 + 图片固定值""" + total_chars = 0 + image_count = 0 + + # 统计所有文本字符 + def count_str(obj: Any) -> None: + nonlocal total_chars, image_count + if isinstance(obj, str): + total_chars += len(obj) + elif isinstance(obj, dict): + # 检测图片 + if obj.get("type") == "image" or "inlineData" in obj: + image_count += 1 + for v in obj.values(): + count_str(v) + elif isinstance(obj, list): + for item in obj: + count_str(item) + + count_str(payload) + + # 粗略估算:字符数/4 + 每张图片300 tokens + return max(1, total_chars // 4 + image_count * 300) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83ca21dd91740533f669278a7387fb49fa76b0ec --- /dev/null +++ b/src/utils.py @@ -0,0 +1,261 @@ +from datetime import datetime, timezone +from typing import List, Optional + +from config import get_api_password, get_panel_password +from fastapi import Depends, HTTPException, Header, Query, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from log import log + +# HTTP Bearer security scheme +security = HTTPBearer() + +# ====================== OAuth Configuration ====================== + +GEMINICLI_USER_AGENT = "GeminiCLI/0.1.5 (Windows; AMD64)" + +ANTIGRAVITY_USER_AGENT = "antigravity/1.11.3 windows/amd64" + +# OAuth Configuration - 标准模式 +CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" +CLIENT_SECRET = "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" +SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", +] + +# Antigravity OAuth Configuration +ANTIGRAVITY_CLIENT_ID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" +ANTIGRAVITY_CLIENT_SECRET = "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" +ANTIGRAVITY_SCOPES = [ + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/userinfo.profile', + 'https://www.googleapis.com/auth/cclog', + 'https://www.googleapis.com/auth/experimentsandconfigs' +] + +# 统一的 Token URL(两种模式相同) +TOKEN_URL = "https://oauth2.googleapis.com/token" + +# 回调服务器配置 +CALLBACK_HOST = "localhost" + +# ====================== Model Configuration ====================== + +# Default Safety Settings for Google API +DEFAULT_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_HATE", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_JAILBREAK", "threshold": "BLOCK_NONE"}, +] + +# Model name lists for different features +BASE_MODELS = [ + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-flash-preview" +] + + +# ====================== Model Helper Functions ====================== + +def is_fake_streaming_model(model_name: str) -> bool: + """Check if model name indicates fake streaming should be used.""" + return model_name.startswith("假流式/") + + +def is_anti_truncation_model(model_name: str) -> bool: + """Check if model name indicates anti-truncation should be used.""" + return model_name.startswith("流式抗截断/") + + +def get_base_model_from_feature_model(model_name: str) -> str: + """Get base model name from feature model name.""" + # Remove feature prefixes + for prefix in ["假流式/", "流式抗截断/"]: + if model_name.startswith(prefix): + return model_name[len(prefix) :] + return model_name + + +def get_available_models(router_type: str = "openai") -> List[str]: + """ + Get available models with feature prefixes. + + Args: + router_type: "openai" or "gemini" + + Returns: + List of model names with feature prefixes + """ + models = [] + + for base_model in BASE_MODELS: + # 基础模型 + models.append(base_model) + + # 假流式模型 (前缀格式) + models.append(f"假流式/{base_model}") + + # 流式抗截断模型 (仅在流式传输时有效,前缀格式) + models.append(f"流式抗截断/{base_model}") + + # 支持thinking模式后缀与功能前缀组合 + # 新增: 支持多后缀组合 (thinking + search) + thinking_suffixes = ["-maxthinking", "-nothinking"] + search_suffix = "-search" + + # 1. 单独的 thinking 后缀 + for thinking_suffix in thinking_suffixes: + models.append(f"{base_model}{thinking_suffix}") + models.append(f"假流式/{base_model}{thinking_suffix}") + models.append(f"流式抗截断/{base_model}{thinking_suffix}") + + # 2. 单独的 search 后缀 + models.append(f"{base_model}{search_suffix}") + models.append(f"假流式/{base_model}{search_suffix}") + models.append(f"流式抗截断/{base_model}{search_suffix}") + + # 3. thinking + search 组合后缀 + for thinking_suffix in thinking_suffixes: + combined_suffix = f"{thinking_suffix}{search_suffix}" + models.append(f"{base_model}{combined_suffix}") + models.append(f"假流式/{base_model}{combined_suffix}") + models.append(f"流式抗截断/{base_model}{combined_suffix}") + + return models + + +# ====================== Authentication Functions ====================== + +async def authenticate_flexible( + request: Request, + authorization: Optional[str] = Header(None), + x_api_key: Optional[str] = Header(None, alias="x-api-key"), + access_token: Optional[str] = Header(None, alias="access_token"), + x_goog_api_key: Optional[str] = Header(None, alias="x-goog-api-key"), + key: Optional[str] = Query(None) +) -> str: + """ + 统一的灵活认证函数,支持多种认证方式 + + 此函数可以直接用作 FastAPI 的 Depends 依赖 + + 支持的认证方式: + - URL 参数: key + - HTTP 头部: Authorization (Bearer token) + - HTTP 头部: x-api-key + - HTTP 头部: access_token + - HTTP 头部: x-goog-api-key + + Args: + request: FastAPI Request 对象 + authorization: Authorization 头部值(自动注入) + x_api_key: x-api-key 头部值(自动注入) + access_token: access_token 头部值(自动注入) + x_goog_api_key: x-goog-api-key 头部值(自动注入) + key: URL 参数 key(自动注入) + + Returns: + 验证通过的token + + Raises: + HTTPException: 认证失败时抛出异常 + + 使用示例: + @router.post("/endpoint") + async def endpoint(token: str = Depends(authenticate_flexible)): + # token 已验证通过 + pass + """ + password = await get_api_password() + token = None + auth_method = None + + # 1. 尝试从 URL 参数 key 获取(Google 官方标准方式) + if key: + token = key + auth_method = "URL parameter 'key'" + + # 2. 尝试从 x-goog-api-key 头部获取(Google API 标准方式) + elif x_goog_api_key: + token = x_goog_api_key + auth_method = "x-goog-api-key header" + + # 3. 尝试从 x-api-key 头部获取 + elif x_api_key: + token = x_api_key + auth_method = "x-api-key header" + + # 4. 尝试从 access_token 头部获取 + elif access_token: + token = access_token + auth_method = "access_token header" + + # 5. 尝试从 Authorization 头部获取 + elif authorization: + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication scheme. Use 'Bearer '", + headers={"WWW-Authenticate": "Bearer"}, + ) + token = authorization[7:] # 移除 "Bearer " 前缀 + auth_method = "Authorization Bearer header" + + # 检查是否提供了任何认证凭据 + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication credentials. Use 'key' URL parameter, 'x-goog-api-key', 'x-api-key', 'access_token' header, or 'Authorization: Bearer '", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 验证 token + if token != password: + log.error(f"Authentication failed using {auth_method}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="密码错误" + ) + + log.debug(f"Authentication successful using {auth_method}") + return token + + +# 为了保持向后兼容,保留旧函数名作为别名 +authenticate_bearer = authenticate_flexible +authenticate_gemini_flexible = authenticate_flexible + + +# ====================== Panel Authentication Functions ====================== + +async def verify_panel_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str: + """ + 简化的控制面板密码验证函数 + + 直接验证Bearer token是否等于控制面板密码 + + Args: + credentials: HTTPAuthorizationCredentials 自动注入 + + Returns: + 验证通过的token + + Raises: + HTTPException: 密码错误时抛出401异常 + """ + + password = await get_panel_password() + if credentials.credentials != password: + raise HTTPException(status_code=401, detail="密码错误") + return credentials.credentials diff --git a/src/web_routes.py b/src/web_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..67b34dbabee2d41aa680cbfac542f1414c0fc451 --- /dev/null +++ b/src/web_routes.py @@ -0,0 +1,1958 @@ +""" +Web路由模块 - 处理认证相关的HTTP请求和控制面板功能 +用于与上级web.py集成 +""" + +import asyncio +import datetime +import io +import json +import os +import time +import zipfile +from collections import deque +from typing import List + +from fastapi import ( + APIRouter, + Depends, + File, + HTTPException, + Request, + UploadFile, + WebSocket, + WebSocketDisconnect, +) +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response +from starlette.websockets import WebSocketState + +import config +from log import log + +from src.auth import ( + asyncio_complete_auth_flow, + complete_auth_flow_from_callback_url, + create_auth_url, + get_auth_status, + verify_password, +) +from src.credential_manager import CredentialManager +from .models import ( + LoginRequest, + AuthStartRequest, + AuthCallbackRequest, + AuthCallbackUrlRequest, + CredFileActionRequest, + CredFileBatchActionRequest, + ConfigSaveRequest, +) +from src.storage_adapter import get_storage_adapter +from src.utils import verify_panel_token, GEMINICLI_USER_AGENT, ANTIGRAVITY_USER_AGENT +from src.api.antigravity import fetch_quota_info +from src.google_oauth_api import Credentials, fetch_project_id +from config import get_code_assist_endpoint, get_antigravity_api_url + +# 创建路由器 +router = APIRouter() + +# 创建credential manager实例(延迟初始化,在首次使用时自动初始化) +credential_manager = CredentialManager() + +# WebSocket连接管理 + + +class ConnectionManager: + def __init__(self, max_connections: int = 3): # 进一步降低最大连接数 + # 使用双端队列严格限制内存使用 + self.active_connections: deque = deque(maxlen=max_connections) + self.max_connections = max_connections + self._last_cleanup = 0 + self._cleanup_interval = 120 # 120秒清理一次死连接 + + async def connect(self, websocket: WebSocket): + # 自动清理死连接 + self._auto_cleanup() + + # 限制最大连接数,防止内存无限增长 + if len(self.active_connections) >= self.max_connections: + await websocket.close(code=1008, reason="Too many connections") + return False + + await websocket.accept() + self.active_connections.append(websocket) + log.debug(f"WebSocket连接建立,当前连接数: {len(self.active_connections)}") + return True + + def disconnect(self, websocket: WebSocket): + # 使用更高效的方式移除连接 + try: + self.active_connections.remove(websocket) + except ValueError: + pass # 连接已不存在 + log.debug(f"WebSocket连接断开,当前连接数: {len(self.active_connections)}") + + async def send_personal_message(self, message: str, websocket: WebSocket): + try: + await websocket.send_text(message) + except Exception: + self.disconnect(websocket) + + async def broadcast(self, message: str): + # 使用更高效的方式处理广播,避免索引操作 + dead_connections = [] + for conn in self.active_connections: + try: + await conn.send_text(message) + except Exception: + dead_connections.append(conn) + + # 批量移除死连接 + for dead_conn in dead_connections: + self.disconnect(dead_conn) + + def _auto_cleanup(self): + """自动清理死连接""" + current_time = time.time() + if current_time - self._last_cleanup > self._cleanup_interval: + self.cleanup_dead_connections() + self._last_cleanup = current_time + + def cleanup_dead_connections(self): + """清理已断开的连接""" + original_count = len(self.active_connections) + # 使用列表推导式过滤活跃连接,更高效 + alive_connections = deque( + [ + conn + for conn in self.active_connections + if hasattr(conn, "client_state") + and conn.client_state != WebSocketState.DISCONNECTED + ], + maxlen=self.max_connections, + ) + + self.active_connections = alive_connections + cleaned = original_count - len(self.active_connections) + if cleaned > 0: + log.debug(f"清理了 {cleaned} 个死连接,剩余连接数: {len(self.active_connections)}") + + +manager = ConnectionManager() + + +async def ensure_credential_manager_initialized(): + """确保credential manager已初始化""" + if not credential_manager._initialized: + await credential_manager.initialize() + + +async def get_credential_manager(): + """获取全局凭证管理器实例(已废弃,直接使用模块级的 credential_manager)""" + global credential_manager + # 确保已初始化(在首次使用时自动初始化) + await credential_manager._ensure_initialized() + return credential_manager + + +def is_mobile_user_agent(user_agent: str) -> bool: + """检测是否为移动设备用户代理""" + if not user_agent: + return False + + user_agent_lower = user_agent.lower() + mobile_keywords = [ + "mobile", + "android", + "iphone", + "ipad", + "ipod", + "blackberry", + "windows phone", + "samsung", + "htc", + "motorola", + "nokia", + "palm", + "webos", + "opera mini", + "opera mobi", + "fennec", + "minimo", + "symbian", + "psp", + "nintendo", + "tablet", + ] + + return any(keyword in user_agent_lower for keyword in mobile_keywords) + + +@router.get("/", response_class=HTMLResponse) +async def serve_control_panel(request: Request): + """提供统一控制面板""" + try: + user_agent = request.headers.get("user-agent", "") + is_mobile = is_mobile_user_agent(user_agent) + + if is_mobile: + html_file_path = "front/control_panel_mobile.html" + else: + html_file_path = "front/control_panel.html" + + with open(html_file_path, "r", encoding="utf-8") as f: + html_content = f.read() + return HTMLResponse(content=html_content) + + except Exception as e: + log.error(f"加载控制面板页面失败: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") + + +@router.post("/auth/login") +async def login(request: LoginRequest): + """用户登录(简化版:直接返回密码作为token)""" + try: + if await verify_password(request.password): + # 直接使用密码作为token,简化认证流程 + return JSONResponse(content={"token": request.password, "message": "登录成功"}) + else: + raise HTTPException(status_code=401, detail="密码错误") + except HTTPException: + raise + except Exception as e: + log.error(f"登录失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/auth/start") +async def start_auth(request: AuthStartRequest, token: str = Depends(verify_panel_token)): + """开始认证流程,支持自动检测项目ID""" + try: + # 如果没有提供项目ID,尝试自动检测 + project_id = request.project_id + if not project_id: + log.info("用户未提供项目ID,后续将使用自动检测...") + + # 使用认证令牌作为用户会话标识 + user_session = token if token else None + result = await create_auth_url( + project_id, user_session, mode=request.mode + ) + + if result["success"]: + return JSONResponse( + content={ + "auth_url": result["auth_url"], + "state": result["state"], + "auto_project_detection": result.get("auto_project_detection", False), + "detected_project_id": result.get("detected_project_id"), + } + ) + else: + raise HTTPException(status_code=500, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"开始认证流程失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/auth/callback") +async def auth_callback(request: AuthCallbackRequest, token: str = Depends(verify_panel_token)): + """处理认证回调,支持自动检测项目ID""" + try: + # 项目ID现在是可选的,在回调处理中进行自动检测 + project_id = request.project_id + + # 使用认证令牌作为用户会话标识 + user_session = token if token else None + # 异步等待OAuth回调完成 + result = await asyncio_complete_auth_flow( + project_id, user_session, mode=request.mode + ) + + if result["success"]: + # 单项目认证成功 + return JSONResponse( + content={ + "credentials": result["credentials"], + "file_path": result["file_path"], + "message": "认证成功,凭证已保存", + "auto_detected_project": result.get("auto_detected_project", False), + } + ) + else: + # 如果需要手动项目ID或项目选择,在响应中标明 + if result.get("requires_manual_project_id"): + # 使用JSON响应 + return JSONResponse( + status_code=400, + content={"error": result["error"], "requires_manual_project_id": True}, + ) + elif result.get("requires_project_selection"): + # 返回项目列表供用户选择 + return JSONResponse( + status_code=400, + content={ + "error": result["error"], + "requires_project_selection": True, + "available_projects": result["available_projects"], + }, + ) + else: + raise HTTPException(status_code=400, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"处理认证回调失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/auth/callback-url") +async def auth_callback_url(request: AuthCallbackUrlRequest, token: str = Depends(verify_panel_token)): + """从回调URL直接完成认证""" + try: + # 验证URL格式 + if not request.callback_url or not request.callback_url.startswith(("http://", "https://")): + raise HTTPException(status_code=400, detail="请提供有效的回调URL") + + # 从回调URL完成认证 + result = await complete_auth_flow_from_callback_url( + request.callback_url, request.project_id, mode=request.mode + ) + + if result["success"]: + # 单项目认证成功 + return JSONResponse( + content={ + "credentials": result["credentials"], + "file_path": result["file_path"], + "message": "从回调URL认证成功,凭证已保存", + "auto_detected_project": result.get("auto_detected_project", False), + } + ) + else: + # 处理各种错误情况 + if result.get("requires_manual_project_id"): + return JSONResponse( + status_code=400, + content={"error": result["error"], "requires_manual_project_id": True}, + ) + elif result.get("requires_project_selection"): + return JSONResponse( + status_code=400, + content={ + "error": result["error"], + "requires_project_selection": True, + "available_projects": result["available_projects"], + }, + ) + else: + raise HTTPException(status_code=400, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"从回调URL处理认证失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/auth/status/{project_id}") +async def check_auth_status(project_id: str, token: str = Depends(verify_panel_token)): + """检查认证状态""" + try: + if not project_id: + raise HTTPException(status_code=400, detail="Project ID 不能为空") + + status = get_auth_status(project_id) + return JSONResponse(content=status) + + except Exception as e: + log.error(f"检查认证状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# 工具函数 (Helper Functions) +# ============================================================================= + + +def validate_mode(mode: str = "geminicli") -> str: + """ + 验证 mode 参数 + + Args: + mode: 模式字符串 ("geminicli" 或 "antigravity") + + Returns: + str: 验证后的 mode 字符串 + + Raises: + HTTPException: 如果 mode 参数无效 + """ + if mode not in ["geminicli", "antigravity"]: + raise HTTPException( + status_code=400, + detail=f"无效的 mode 参数: {mode},只支持 'geminicli' 或 'antigravity'" + ) + return mode + + +def get_env_locked_keys() -> set: + """获取被环境变量锁定的配置键集合""" + env_locked_keys = set() + + # 使用 config.py 中统一维护的映射表 + for env_key, config_key in config.ENV_MAPPINGS.items(): + if os.getenv(env_key): + env_locked_keys.add(config_key) + + return env_locked_keys + + +async def extract_json_files_from_zip(zip_file: UploadFile) -> List[dict]: + """从ZIP文件中提取JSON文件""" + try: + # 读取ZIP文件内容 + zip_content = await zip_file.read() + + # 不限制ZIP文件大小,只在处理时控制文件数量 + + files_data = [] + + with zipfile.ZipFile(io.BytesIO(zip_content), "r") as zip_ref: + # 获取ZIP中的所有文件 + file_list = zip_ref.namelist() + json_files = [ + f for f in file_list if f.endswith(".json") and not f.startswith("__MACOSX/") + ] + + if not json_files: + raise HTTPException(status_code=400, detail="ZIP文件中没有找到JSON文件") + + log.info(f"从ZIP文件 {zip_file.filename} 中找到 {len(json_files)} 个JSON文件") + + for json_filename in json_files: + try: + # 读取JSON文件内容 + with zip_ref.open(json_filename) as json_file: + content = json_file.read() + + try: + content_str = content.decode("utf-8") + except UnicodeDecodeError: + log.warning(f"跳过编码错误的文件: {json_filename}") + continue + + # 使用原始文件名(去掉路径) + filename = os.path.basename(json_filename) + files_data.append({"filename": filename, "content": content_str}) + + except Exception as e: + log.warning(f"处理ZIP中的文件 {json_filename} 时出错: {e}") + continue + + log.info(f"成功从ZIP文件中提取 {len(files_data)} 个有效的JSON文件") + return files_data + + except zipfile.BadZipFile: + raise HTTPException(status_code=400, detail="无效的ZIP文件格式") + except Exception as e: + log.error(f"处理ZIP文件失败: {e}") + raise HTTPException(status_code=500, detail=f"处理ZIP文件失败: {str(e)}") + + +async def upload_credentials_common( + files: List[UploadFile], mode: str = "geminicli" +) -> JSONResponse: + """批量上传凭证文件的通用函数""" + mode = validate_mode(mode) + + if not files: + raise HTTPException(status_code=400, detail="请选择要上传的文件") + + # 检查文件数量限制 + if len(files) > 100: + raise HTTPException( + status_code=400, detail=f"文件数量过多,最多支持100个文件,当前:{len(files)}个" + ) + + files_data = [] + for file in files: + # 检查文件类型:支持JSON和ZIP + if file.filename.endswith(".zip"): + zip_files_data = await extract_json_files_from_zip(file) + files_data.extend(zip_files_data) + log.info(f"从ZIP文件 {file.filename} 中提取了 {len(zip_files_data)} 个JSON文件") + + elif file.filename.endswith(".json"): + # 处理单个JSON文件 - 流式读取 + content_chunks = [] + while True: + chunk = await file.read(8192) + if not chunk: + break + content_chunks.append(chunk) + + content = b"".join(content_chunks) + try: + content_str = content.decode("utf-8") + except UnicodeDecodeError: + raise HTTPException( + status_code=400, detail=f"文件 {file.filename} 编码格式不支持" + ) + + files_data.append({"filename": file.filename, "content": content_str}) + else: + raise HTTPException( + status_code=400, detail=f"文件 {file.filename} 格式不支持,只支持JSON和ZIP文件" + ) + + + + batch_size = 1000 + all_results = [] + total_success = 0 + + for i in range(0, len(files_data), batch_size): + batch_files = files_data[i : i + batch_size] + + async def process_single_file(file_data): + try: + filename = file_data["filename"] + # 确保文件名只保存basename,避免路径问题 + filename = os.path.basename(filename) + content_str = file_data["content"] + credential_data = json.loads(content_str) + + # 根据凭证类型调用不同的添加方法 + if mode == "antigravity": + await credential_manager.add_antigravity_credential(filename, credential_data) + else: + await credential_manager.add_credential(filename, credential_data) + + log.debug(f"成功上传 {mode} 凭证文件: {filename}") + return {"filename": filename, "status": "success", "message": "上传成功"} + + except json.JSONDecodeError as e: + return { + "filename": file_data["filename"], + "status": "error", + "message": f"JSON格式错误: {str(e)}", + } + except Exception as e: + return { + "filename": file_data["filename"], + "status": "error", + "message": f"处理失败: {str(e)}", + } + + log.info(f"开始并发处理 {len(batch_files)} 个 {mode} 文件...") + concurrent_tasks = [process_single_file(file_data) for file_data in batch_files] + batch_results = await asyncio.gather(*concurrent_tasks, return_exceptions=True) + + processed_results = [] + batch_uploaded_count = 0 + for result in batch_results: + if isinstance(result, Exception): + processed_results.append( + { + "filename": "unknown", + "status": "error", + "message": f"处理异常: {str(result)}", + } + ) + else: + processed_results.append(result) + if result["status"] == "success": + batch_uploaded_count += 1 + + all_results.extend(processed_results) + total_success += batch_uploaded_count + + batch_num = (i // batch_size) + 1 + total_batches = (len(files_data) + batch_size - 1) // batch_size + log.info( + f"批次 {batch_num}/{total_batches} 完成: 成功 " + f"{batch_uploaded_count}/{len(batch_files)} 个 {mode} 文件" + ) + + if total_success > 0: + return JSONResponse( + content={ + "uploaded_count": total_success, + "total_count": len(files_data), + "results": all_results, + "message": f"批量上传完成: 成功 {total_success}/{len(files_data)} 个 {mode} 文件", + } + ) + else: + raise HTTPException(status_code=400, detail=f"没有 {mode} 文件上传成功") + + +async def get_creds_status_common( + offset: int, limit: int, status_filter: str, mode: str = "geminicli", + error_code_filter: str = None, cooldown_filter: str = None +) -> JSONResponse: + """获取凭证文件状态的通用函数""" + mode = validate_mode(mode) + # 验证分页参数 + if offset < 0: + raise HTTPException(status_code=400, detail="offset 必须大于等于 0") + if limit not in [20, 50, 100, 200, 500, 1000]: + raise HTTPException(status_code=400, detail="limit 只能是 20、50、100、200、500 或 1000") + if status_filter not in ["all", "enabled", "disabled"]: + raise HTTPException(status_code=400, detail="status_filter 只能是 all、enabled 或 disabled") + if cooldown_filter and cooldown_filter not in ["all", "in_cooldown", "no_cooldown"]: + raise HTTPException(status_code=400, detail="cooldown_filter 只能是 all、in_cooldown 或 no_cooldown") + + + + storage_adapter = await get_storage_adapter() + backend_info = await storage_adapter.get_backend_info() + backend_type = backend_info.get("backend_type", "unknown") + + # 优先使用高性能的分页摘要查询 + if hasattr(storage_adapter._backend, 'get_credentials_summary'): + result = await storage_adapter._backend.get_credentials_summary( + offset=offset, + limit=limit, + status_filter=status_filter, + mode=mode, + error_code_filter=error_code_filter if error_code_filter and error_code_filter != "all" else None, + cooldown_filter=cooldown_filter if cooldown_filter and cooldown_filter != "all" else None + ) + + creds_list = [] + for summary in result["items"]: + cred_info = { + "filename": os.path.basename(summary["filename"]), + "user_email": summary["user_email"], + "disabled": summary["disabled"], + "error_codes": summary["error_codes"], + "last_success": summary["last_success"], + "backend_type": backend_type, + "model_cooldowns": summary.get("model_cooldowns", {}), + } + + creds_list.append(cred_info) + + return JSONResponse(content={ + "items": creds_list, + "total": result["total"], + "offset": offset, + "limit": limit, + "has_more": (offset + limit) < result["total"], + "stats": result.get("stats", {"total": 0, "normal": 0, "disabled": 0}), + }) + + # 回退到传统方式(MongoDB/其他后端) + all_credentials = await storage_adapter.list_credentials(mode=mode) + all_states = await storage_adapter.get_all_credential_states(mode=mode) + + # 应用状态筛选 + filtered_credentials = [] + for filename in all_credentials: + file_status = all_states.get(filename, {"disabled": False}) + is_disabled = file_status.get("disabled", False) + + if status_filter == "all": + filtered_credentials.append(filename) + elif status_filter == "enabled" and not is_disabled: + filtered_credentials.append(filename) + elif status_filter == "disabled" and is_disabled: + filtered_credentials.append(filename) + + total_count = len(filtered_credentials) + paginated_credentials = filtered_credentials[offset:offset + limit] + + creds_list = [] + for filename in paginated_credentials: + file_status = all_states.get(filename, { + "error_codes": [], + "disabled": False, + "last_success": time.time(), + "user_email": None, + }) + + cred_info = { + "filename": os.path.basename(filename), + "user_email": file_status.get("user_email"), + "disabled": file_status.get("disabled", False), + "error_codes": file_status.get("error_codes", []), + "last_success": file_status.get("last_success", time.time()), + "backend_type": backend_type, + "model_cooldowns": file_status.get("model_cooldowns", {}), + } + + creds_list.append(cred_info) + + return JSONResponse(content={ + "items": creds_list, + "total": total_count, + "offset": offset, + "limit": limit, + "has_more": (offset + limit) < total_count, + }) + + +async def download_all_creds_common(mode: str = "geminicli") -> Response: + """打包下载所有凭证文件的通用函数""" + mode = validate_mode(mode) + zip_filename = "antigravity_credentials.zip" if mode == "antigravity" else "credentials.zip" + + storage_adapter = await get_storage_adapter() + credential_filenames = await storage_adapter.list_credentials(mode=mode) + + if not credential_filenames: + raise HTTPException(status_code=404, detail=f"没有找到 {mode} 凭证文件") + + log.info(f"开始打包 {len(credential_filenames)} 个 {mode} 凭证文件...") + + zip_buffer = io.BytesIO() + + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + success_count = 0 + for idx, filename in enumerate(credential_filenames, 1): + try: + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if credential_data: + content = json.dumps(credential_data, ensure_ascii=False, indent=2) + zip_file.writestr(os.path.basename(filename), content) + success_count += 1 + + if idx % 10 == 0: + log.debug(f"打包进度: {idx}/{len(credential_filenames)}") + + except Exception as e: + log.warning(f"处理 {mode} 凭证文件 {filename} 时出错: {e}") + continue + + log.info(f"打包完成: 成功 {success_count}/{len(credential_filenames)} 个文件") + + zip_buffer.seek(0) + return Response( + content=zip_buffer.getvalue(), + media_type="application/zip", + headers={"Content-Disposition": f"attachment; filename={zip_filename}"}, + ) + + +async def fetch_user_email_common(filename: str, mode: str = "geminicli") -> JSONResponse: + """获取指定凭证文件用户邮箱的通用函数""" + mode = validate_mode(mode) + + filename_only = os.path.basename(filename) + if not filename_only.endswith(".json"): + raise HTTPException(status_code=404, detail="无效的文件名") + + storage_adapter = await get_storage_adapter() + credential_data = await storage_adapter.get_credential(filename_only, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证文件不存在") + + email = await credential_manager.get_or_fetch_user_email(filename_only, mode=mode) + + if email: + return JSONResponse( + content={ + "filename": filename_only, + "user_email": email, + "message": "成功获取用户邮箱", + } + ) + else: + return JSONResponse( + content={ + "filename": filename_only, + "user_email": None, + "message": "无法获取用户邮箱,可能凭证已过期或权限不足", + }, + status_code=400, + ) + + +async def refresh_all_user_emails_common(mode: str = "geminicli") -> JSONResponse: + """刷新所有凭证文件用户邮箱的通用函数 - 只为没有邮箱的凭证获取 + + 利用 get_all_credential_states 批量获取状态 + """ + mode = validate_mode(mode) + + storage_adapter = await get_storage_adapter() + + # 一次性批量获取所有凭证的状态 + all_states = await storage_adapter.get_all_credential_states(mode=mode) + + results = [] + success_count = 0 + skipped_count = 0 + + # 在内存中筛选出需要获取邮箱的凭证 + for filename, state in all_states.items(): + try: + cached_email = state.get("user_email") + + if cached_email: + # 已有邮箱,跳过获取 + skipped_count += 1 + results.append({ + "filename": os.path.basename(filename), + "user_email": cached_email, + "success": True, + "skipped": True, + }) + continue + + # 没有邮箱,尝试获取 + email = await credential_manager.get_or_fetch_user_email(filename, mode=mode) + if email: + success_count += 1 + results.append({ + "filename": os.path.basename(filename), + "user_email": email, + "success": True, + }) + else: + results.append({ + "filename": os.path.basename(filename), + "user_email": None, + "success": False, + "error": "无法获取邮箱", + }) + except Exception as e: + results.append({ + "filename": os.path.basename(filename), + "user_email": None, + "success": False, + "error": str(e), + }) + + total_count = len(all_states) + return JSONResponse( + content={ + "success_count": success_count, + "total_count": total_count, + "skipped_count": skipped_count, + "results": results, + "message": f"成功获取 {success_count}/{total_count} 个邮箱地址,跳过 {skipped_count} 个已有邮箱的凭证", + } + ) + + +async def deduplicate_credentials_by_email_common(mode: str = "geminicli") -> JSONResponse: + """批量去重凭证文件的通用函数 - 删除邮箱相同的凭证(只保留一个)""" + mode = validate_mode(mode) + storage_adapter = await get_storage_adapter() + + try: + duplicate_info = await storage_adapter._backend.get_duplicate_credentials_by_email( + mode=mode + ) + + duplicate_groups = duplicate_info.get("duplicate_groups", []) + no_email_files = duplicate_info.get("no_email_files", []) + total_count = duplicate_info.get("total_count", 0) + + if not duplicate_groups: + return JSONResponse( + content={ + "deleted_count": 0, + "kept_count": total_count, + "total_count": total_count, + "unique_emails_count": duplicate_info.get("unique_email_count", 0), + "no_email_count": len(no_email_files), + "duplicate_groups": [], + "delete_errors": [], + "message": "没有发现重复的凭证(相同邮箱)", + } + ) + + # 执行删除操作 + deleted_count = 0 + delete_errors = [] + result_duplicate_groups = [] + + for group in duplicate_groups: + email = group["email"] + kept_file = group["kept_file"] + duplicate_files = group["duplicate_files"] + + deleted_files_in_group = [] + for filename in duplicate_files: + try: + success = await credential_manager.remove_credential(filename, mode=mode) + if success: + deleted_count += 1 + deleted_files_in_group.append(os.path.basename(filename)) + log.info(f"去重删除凭证: {filename} (邮箱: {email}) (mode={mode})") + else: + delete_errors.append(f"{os.path.basename(filename)}: 删除失败") + except Exception as e: + delete_errors.append(f"{os.path.basename(filename)}: {str(e)}") + log.error(f"去重删除凭证 {filename} 时出错: {e}") + + result_duplicate_groups.append({ + "email": email, + "kept_file": os.path.basename(kept_file), + "deleted_files": deleted_files_in_group, + "duplicate_count": len(deleted_files_in_group), + }) + + kept_count = total_count - deleted_count + + return JSONResponse( + content={ + "deleted_count": deleted_count, + "kept_count": kept_count, + "total_count": total_count, + "unique_emails_count": duplicate_info.get("unique_email_count", 0), + "no_email_count": len(no_email_files), + "duplicate_groups": result_duplicate_groups, + "delete_errors": delete_errors, + "message": f"去重完成:删除 {deleted_count} 个重复凭证,保留 {kept_count} 个凭证({duplicate_info.get('unique_email_count', 0)} 个唯一邮箱)", + } + ) + + except Exception as e: + log.error(f"批量去重凭证时出错: {e}") + return JSONResponse( + status_code=500, + content={ + "deleted_count": 0, + "kept_count": 0, + "total_count": 0, + "message": f"去重操作失败: {str(e)}", + } + ) + + +# ============================================================================= +# 路由处理函数 (Route Handlers) +# ============================================================================= + + +@router.post("/auth/upload") +async def upload_credentials( + files: List[UploadFile] = File(...), + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量上传认证文件""" + try: + mode = validate_mode(mode) + return await upload_credentials_common(files, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"批量上传失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/creds/status") +async def get_creds_status( + token: str = Depends(verify_panel_token), + offset: int = 0, + limit: int = 50, + status_filter: str = "all", + error_code_filter: str = "all", + cooldown_filter: str = "all", + mode: str = "geminicli" +): + """ + 获取凭证文件的状态(轻量级摘要,不包含完整凭证数据,支持分页和状态筛选) + + Args: + offset: 跳过的记录数(默认0) + limit: 每页返回的记录数(默认50,可选:20, 50, 100, 200, 500, 1000) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + error_code_filter: 错误码筛选(all=全部, 或具体错误码如"400", "403") + cooldown_filter: 冷却状态筛选(all=全部, in_cooldown=冷却中, no_cooldown=未冷却) + mode: 凭证模式(geminicli 或 antigravity) + + Returns: + 包含凭证列表、总数、分页信息的响应 + """ + try: + mode = validate_mode(mode) + return await get_creds_status_common( + offset, limit, status_filter, mode=mode, + error_code_filter=error_code_filter, + cooldown_filter=cooldown_filter + ) + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/creds/detail/{filename}") +async def get_cred_detail( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 按需获取单个凭证的详细数据(包含完整凭证内容) + 用于用户查看/编辑凭证详情 + """ + try: + mode = validate_mode(mode) + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + + storage_adapter = await get_storage_adapter() + backend_info = await storage_adapter.get_backend_info() + backend_type = backend_info.get("backend_type", "unknown") + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 获取状态信息 + file_status = await storage_adapter.get_credential_state(filename, mode=mode) + if not file_status: + file_status = { + "error_codes": [], + "disabled": False, + "last_success": time.time(), + "user_email": None, + } + + result = { + "status": file_status, + "content": credential_data, + "filename": os.path.basename(filename), + "backend_type": backend_type, + "user_email": file_status.get("user_email"), + "model_cooldowns": file_status.get("model_cooldowns", {}), + } + + if backend_type == "file" and os.path.exists(filename): + result.update({ + "size": os.path.getsize(filename), + "modified_time": os.path.getmtime(filename), + }) + + return JSONResponse(content=result) + + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证详情失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/creds/action") +async def creds_action( + request: CredFileActionRequest, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """对凭证文件执行操作(启用/禁用/删除)""" + try: + mode = validate_mode(mode) + + log.info(f"Received request: {request}") + + filename = request.filename + action = request.action + + log.info(f"Performing action '{action}' on file: {filename} (mode={mode})") + + # 验证文件名 + if not filename.endswith(".json"): + log.error(f"无效的文件名: {filename}(不是.json文件)") + raise HTTPException(status_code=400, detail=f"无效的文件名: {filename}") + + # 获取存储适配器 + storage_adapter = await get_storage_adapter() + + # 对于删除操作,不需要检查凭证数据是否完整,只需检查条目是否存在 + # 对于其他操作,需要确保凭证数据存在且完整 + if action != "delete": + # 检查凭证数据是否存在 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + log.error(f"凭证未找到: {filename} (mode={mode})") + raise HTTPException(status_code=404, detail="凭证文件不存在") + + if action == "enable": + log.info(f"Web请求: 启用文件 {filename} (mode={mode})") + result = await credential_manager.set_cred_disabled(filename, False, mode=mode) + log.info(f"[WebRoute] set_cred_disabled 返回结果: {result}") + if result: + log.info(f"Web请求: 文件 {filename} 已成功启用 (mode={mode})") + return JSONResponse(content={"message": f"已启用凭证文件 {os.path.basename(filename)}"}) + else: + log.error(f"Web请求: 文件 {filename} 启用失败 (mode={mode})") + raise HTTPException(status_code=500, detail="启用凭证失败,可能凭证不存在") + + elif action == "disable": + log.info(f"Web请求: 禁用文件 {filename} (mode={mode})") + result = await credential_manager.set_cred_disabled(filename, True, mode=mode) + log.info(f"[WebRoute] set_cred_disabled 返回结果: {result}") + if result: + log.info(f"Web请求: 文件 {filename} 已成功禁用 (mode={mode})") + return JSONResponse(content={"message": f"已禁用凭证文件 {os.path.basename(filename)}"}) + else: + log.error(f"Web请求: 文件 {filename} 禁用失败 (mode={mode})") + raise HTTPException(status_code=500, detail="禁用凭证失败,可能凭证不存在") + + elif action == "delete": + try: + # 使用 CredentialManager 删除凭证(包含队列/状态同步) + success = await credential_manager.remove_credential(filename, mode=mode) + if success: + log.info(f"通过管理器成功删除凭证: {filename} (mode={mode})") + return JSONResponse( + content={"message": f"已删除凭证文件 {os.path.basename(filename)}"} + ) + else: + raise HTTPException(status_code=500, detail="删除凭证失败") + except Exception as e: + log.error(f"删除凭证 {filename} 时出错: {e}") + raise HTTPException(status_code=500, detail=f"删除文件失败: {str(e)}") + + else: + raise HTTPException(status_code=400, detail="无效的操作类型") + + except HTTPException: + raise + except Exception as e: + log.error(f"凭证文件操作失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/creds/batch-action") +async def creds_batch_action( + request: CredFileBatchActionRequest, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量对凭证文件执行操作(启用/禁用/删除)""" + try: + mode = validate_mode(mode) + + action = request.action + filenames = request.filenames + + if not filenames: + raise HTTPException(status_code=400, detail="文件名列表不能为空") + + log.info(f"对 {len(filenames)} 个文件执行批量操作 '{action}'") + + success_count = 0 + errors = [] + + storage_adapter = await get_storage_adapter() + + for filename in filenames: + try: + # 验证文件名安全性 + if not filename.endswith(".json"): + errors.append(f"{filename}: 无效的文件类型") + continue + + # 对于删除操作,不需要检查凭证数据完整性 + # 对于其他操作,需要确保凭证数据存在 + if action != "delete": + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + errors.append(f"{filename}: 凭证不存在") + continue + + # 执行相应操作 + if action == "enable": + await credential_manager.set_cred_disabled(filename, False, mode=mode) + success_count += 1 + + elif action == "disable": + await credential_manager.set_cred_disabled(filename, True, mode=mode) + success_count += 1 + + elif action == "delete": + try: + delete_success = await credential_manager.remove_credential(filename, mode=mode) + if delete_success: + success_count += 1 + log.info(f"成功删除批量中的凭证: {filename}") + else: + errors.append(f"{filename}: 删除失败") + continue + except Exception as e: + errors.append(f"{filename}: 删除文件失败 - {str(e)}") + continue + else: + errors.append(f"{filename}: 无效的操作类型") + continue + + except Exception as e: + log.error(f"处理 {filename} 时出错: {e}") + errors.append(f"{filename}: 处理失败 - {str(e)}") + continue + + # 构建返回消息 + result_message = f"批量操作完成:成功处理 {success_count}/{len(filenames)} 个文件" + if errors: + result_message += "\n错误详情:\n" + "\n".join(errors) + + response_data = { + "success_count": success_count, + "total_count": len(filenames), + "errors": errors, + "message": result_message, + } + + return JSONResponse(content=response_data) + + except HTTPException: + raise + except Exception as e: + log.error(f"批量凭证文件操作失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/creds/download/{filename}") +async def download_cred_file( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """下载单个凭证文件""" + try: + mode = validate_mode(mode) + # 验证文件名安全性 + if not filename.endswith(".json"): + raise HTTPException(status_code=404, detail="无效的文件名") + + # 获取存储适配器 + storage_adapter = await get_storage_adapter() + + # 从存储系统获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="文件不存在") + + # 转换为JSON字符串 + content = json.dumps(credential_data, ensure_ascii=False, indent=2) + + from fastapi.responses import Response + + return Response( + content=content, + media_type="application/json", + headers={"Content-Disposition": f"attachment; filename={filename}"}, + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"下载凭证文件失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/creds/fetch-email/{filename}") +async def fetch_user_email( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """获取指定凭证文件的用户邮箱地址""" + try: + mode = validate_mode(mode) + return await fetch_user_email_common(filename, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"获取用户邮箱失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/creds/refresh-all-emails") +async def refresh_all_user_emails( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """刷新所有凭证文件的用户邮箱地址""" + try: + mode = validate_mode(mode) + return await refresh_all_user_emails_common(mode=mode) + except Exception as e: + log.error(f"批量获取用户邮箱失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/creds/deduplicate-by-email") +async def deduplicate_credentials_by_email( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量去重凭证文件 - 删除邮箱相同的凭证(只保留一个)""" + try: + mode = validate_mode(mode) + return await deduplicate_credentials_by_email_common(mode=mode) + except Exception as e: + log.error(f"批量去重凭证失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/creds/download-all") +async def download_all_creds( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 打包下载所有凭证文件(流式处理,按需加载每个凭证数据) + 只在实际下载时才加载完整凭证内容,最大化性能 + """ + try: + mode = validate_mode(mode) + return await download_all_creds_common(mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"打包下载失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/config/get") +async def get_config(token: str = Depends(verify_panel_token)): + """获取当前配置""" + try: + + + # 读取当前配置(包括环境变量和TOML文件中的配置) + current_config = {} + + # 基础配置 + current_config["code_assist_endpoint"] = await config.get_code_assist_endpoint() + current_config["credentials_dir"] = await config.get_credentials_dir() + current_config["proxy"] = await config.get_proxy_config() or "" + + # 代理端点配置 + current_config["oauth_proxy_url"] = await config.get_oauth_proxy_url() + current_config["googleapis_proxy_url"] = await config.get_googleapis_proxy_url() + current_config["resource_manager_api_url"] = await config.get_resource_manager_api_url() + current_config["service_usage_api_url"] = await config.get_service_usage_api_url() + current_config["antigravity_api_url"] = await config.get_antigravity_api_url() + + # 自动封禁配置 + current_config["auto_ban_enabled"] = await config.get_auto_ban_enabled() + current_config["auto_ban_error_codes"] = await config.get_auto_ban_error_codes() + + # 429重试配置 + current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries() + current_config["retry_429_enabled"] = await config.get_retry_429_enabled() + current_config["retry_429_interval"] = await config.get_retry_429_interval() + + # 抗截断配置 + current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts() + + # 兼容性配置 + current_config["compatibility_mode_enabled"] = await config.get_compatibility_mode_enabled() + + # 思维链返回配置 + current_config["return_thoughts_to_frontend"] = await config.get_return_thoughts_to_frontend() + + # Antigravity流式转非流式配置 + current_config["antigravity_stream2nostream"] = await config.get_antigravity_stream2nostream() + + # 服务器配置 + current_config["host"] = await config.get_server_host() + current_config["port"] = await config.get_server_port() + current_config["api_password"] = await config.get_api_password() + current_config["panel_password"] = await config.get_panel_password() + current_config["password"] = await config.get_server_password() + + # 从存储系统读取配置 + storage_adapter = await get_storage_adapter() + storage_config = await storage_adapter.get_all_config() + + # 获取环境变量锁定的配置键 + env_locked_keys = get_env_locked_keys() + + # 合并存储系统配置(不覆盖环境变量) + for key, value in storage_config.items(): + if key not in env_locked_keys: + current_config[key] = value + + return JSONResponse(content={"config": current_config, "env_locked": list(env_locked_keys)}) + + except Exception as e: + log.error(f"获取配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/config/save") +async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_panel_token)): + """保存配置""" + try: + + new_config = request.config + + log.debug(f"收到的配置数据: {list(new_config.keys())}") + log.debug(f"收到的password值: {new_config.get('password', 'NOT_FOUND')}") + + # 验证配置项 + if "retry_429_max_retries" in new_config: + if ( + not isinstance(new_config["retry_429_max_retries"], int) + or new_config["retry_429_max_retries"] < 0 + ): + raise HTTPException(status_code=400, detail="最大429重试次数必须是大于等于0的整数") + + if "retry_429_enabled" in new_config: + if not isinstance(new_config["retry_429_enabled"], bool): + raise HTTPException(status_code=400, detail="429重试开关必须是布尔值") + + # 验证新的配置项 + if "retry_429_interval" in new_config: + try: + interval = float(new_config["retry_429_interval"]) + if interval < 0.01 or interval > 10: + raise HTTPException(status_code=400, detail="429重试间隔必须在0.01-10秒之间") + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="429重试间隔必须是有效的数字") + + if "anti_truncation_max_attempts" in new_config: + if ( + not isinstance(new_config["anti_truncation_max_attempts"], int) + or new_config["anti_truncation_max_attempts"] < 1 + or new_config["anti_truncation_max_attempts"] > 10 + ): + raise HTTPException( + status_code=400, detail="抗截断最大重试次数必须是1-10之间的整数" + ) + + if "compatibility_mode_enabled" in new_config: + if not isinstance(new_config["compatibility_mode_enabled"], bool): + raise HTTPException(status_code=400, detail="兼容性模式开关必须是布尔值") + + if "return_thoughts_to_frontend" in new_config: + if not isinstance(new_config["return_thoughts_to_frontend"], bool): + raise HTTPException(status_code=400, detail="思维链返回开关必须是布尔值") + + if "antigravity_stream2nostream" in new_config: + if not isinstance(new_config["antigravity_stream2nostream"], bool): + raise HTTPException(status_code=400, detail="Antigravity流式转非流式开关必须是布尔值") + + # 验证服务器配置 + if "host" in new_config: + if not isinstance(new_config["host"], str) or not new_config["host"].strip(): + raise HTTPException(status_code=400, detail="服务器主机地址不能为空") + + if "port" in new_config: + if ( + not isinstance(new_config["port"], int) + or new_config["port"] < 1 + or new_config["port"] > 65535 + ): + raise HTTPException(status_code=400, detail="端口号必须是1-65535之间的整数") + + if "api_password" in new_config: + if not isinstance(new_config["api_password"], str): + raise HTTPException(status_code=400, detail="API访问密码必须是字符串") + + if "panel_password" in new_config: + if not isinstance(new_config["panel_password"], str): + raise HTTPException(status_code=400, detail="控制面板密码必须是字符串") + + if "password" in new_config: + if not isinstance(new_config["password"], str): + raise HTTPException(status_code=400, detail="访问密码必须是字符串") + + # 获取环境变量锁定的配置键 + env_locked_keys = get_env_locked_keys() + + # 直接使用存储适配器保存配置 + storage_adapter = await get_storage_adapter() + for key, value in new_config.items(): + if key not in env_locked_keys: + await storage_adapter.set_config(key, value) + if key in ("password", "api_password", "panel_password"): + log.debug(f"设置{key}字段为: {value}") + + # 重新加载配置缓存(关键!) + await config.reload_config() + + # 验证保存后的结果 + test_api_password = await config.get_api_password() + test_panel_password = await config.get_panel_password() + test_password = await config.get_server_password() + log.debug(f"保存后立即读取的API密码: {test_api_password}") + log.debug(f"保存后立即读取的面板密码: {test_panel_password}") + log.debug(f"保存后立即读取的通用密码: {test_password}") + + # 构建响应消息 + response_data = { + "message": "配置保存成功", + "saved_config": {k: v for k, v in new_config.items() if k not in env_locked_keys}, + } + + return JSONResponse(content=response_data) + + except HTTPException: + raise + except Exception as e: + log.error(f"保存配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# 实时日志WebSocket (Real-time Logs WebSocket) +# ============================================================================= + + +@router.post("/auth/logs/clear") +async def clear_logs(token: str = Depends(verify_panel_token)): + """清空日志文件""" + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 检查日志文件是否存在 + if os.path.exists(log_file_path): + try: + # 清空文件内容(保留文件),确保以UTF-8编码写入 + with open(log_file_path, "w", encoding="utf-8", newline="") as f: + f.write("") + f.flush() # 强制刷新到磁盘 + log.info(f"日志文件已清空: {log_file_path}") + + # 通知所有WebSocket连接日志已清空 + await manager.broadcast("--- 日志文件已清空 ---") + + return JSONResponse( + content={"message": f"日志文件已清空: {os.path.basename(log_file_path)}"} + ) + except Exception as e: + log.error(f"清空日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}") + else: + return JSONResponse(content={"message": "日志文件不存在"}) + + except Exception as e: + log.error(f"清空日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}") + + +@router.get("/auth/logs/download") +async def download_logs(token: str = Depends(verify_panel_token)): + """下载日志文件""" + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 检查日志文件是否存在 + if not os.path.exists(log_file_path): + raise HTTPException(status_code=404, detail="日志文件不存在") + + # 检查文件是否为空 + file_size = os.path.getsize(log_file_path) + if file_size == 0: + raise HTTPException(status_code=404, detail="日志文件为空") + + # 生成文件名(包含时间戳) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"gcli2api_logs_{timestamp}.txt" + + log.info(f"下载日志文件: {log_file_path}") + + return FileResponse( + path=log_file_path, + filename=filename, + media_type="text/plain", + headers={"Content-Disposition": f"attachment; filename={filename}"}, + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"下载日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"下载日志文件失败: {str(e)}") + + +@router.websocket("/auth/logs/stream") +async def websocket_logs(websocket: WebSocket): + """WebSocket端点,用于实时日志流""" + # 检查连接数限制 + if not await manager.connect(websocket): + return + + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 发送初始日志(限制为最后50行,减少内存占用) + if os.path.exists(log_file_path): + try: + with open(log_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + # 只发送最后50行,减少初始内存消耗 + for line in lines[-50:]: + if line.strip(): + await websocket.send_text(line.strip()) + except Exception as e: + await websocket.send_text(f"Error reading log file: {e}") + + # 监控日志文件变化 + last_size = os.path.getsize(log_file_path) if os.path.exists(log_file_path) else 0 + max_read_size = 8192 # 限制单次读取大小为8KB,防止大量日志造成内存激增 + check_interval = 2 # 增加检查间隔,减少CPU和I/O开销 + + # 创建后台任务监听客户端断开 + # 即使没有日志更新,receive_text() 也能即时感知断开 + async def listen_for_disconnect(): + try: + while True: + await websocket.receive_text() + except Exception: + pass + + listener_task = asyncio.create_task(listen_for_disconnect()) + + try: + while websocket.client_state == WebSocketState.CONNECTED: + # 使用 asyncio.wait 同时等待定时器和断开信号 + # timeout=check_interval 替代了 asyncio.sleep + done, pending = await asyncio.wait( + [listener_task], + timeout=check_interval, + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果监听任务结束(通常是因为连接断开),则退出循环 + if listener_task in done: + break + + if os.path.exists(log_file_path): + current_size = os.path.getsize(log_file_path) + if current_size > last_size: + # 限制读取大小,防止单次读取过多内容 + read_size = min(current_size - last_size, max_read_size) + + try: + with open(log_file_path, "r", encoding="utf-8", errors="replace") as f: + f.seek(last_size) + new_content = f.read(read_size) + + # 处理编码错误的情况 + if not new_content: + last_size = current_size + continue + + # 分行发送,避免发送不完整的行 + lines = new_content.splitlines(keepends=True) + if lines: + # 如果最后一行没有换行符,保留到下次处理 + if not lines[-1].endswith("\n") and len(lines) > 1: + # 除了最后一行,其他都发送 + for line in lines[:-1]: + if line.strip(): + await websocket.send_text(line.rstrip()) + # 更新位置,但要退回最后一行的字节数 + last_size += len(new_content.encode("utf-8")) - len( + lines[-1].encode("utf-8") + ) + else: + # 所有行都发送 + for line in lines: + if line.strip(): + await websocket.send_text(line.rstrip()) + last_size += len(new_content.encode("utf-8")) + except UnicodeDecodeError as e: + # 遇到编码错误时,跳过这部分内容 + log.warning(f"WebSocket日志读取编码错误: {e}, 跳过部分内容") + last_size = current_size + except Exception as e: + await websocket.send_text(f"Error reading new content: {e}") + # 发生其他错误时,重置文件位置 + last_size = current_size + + # 如果文件被截断(如清空日志),重置位置 + elif current_size < last_size: + last_size = 0 + await websocket.send_text("--- 日志已清空 ---") + + finally: + # 确保清理监听任务 + if not listener_task.done(): + listener_task.cancel() + try: + await listener_task + except asyncio.CancelledError: + pass + + except WebSocketDisconnect: + pass + except Exception as e: + log.error(f"WebSocket logs error: {e}") + finally: + manager.disconnect(websocket) + + +async def verify_credential_project_common(filename: str, mode: str = "geminicli") -> JSONResponse: + """验证并重新获取凭证的project id的通用函数""" + mode = validate_mode(mode) + + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 创建凭证对象 + credentials = Credentials.from_dict(credential_data) + + # 确保token有效(自动刷新) + token_refreshed = await credentials.refresh_if_needed() + + # 如果token被刷新了,更新存储 + if token_refreshed: + log.info(f"Token已自动刷新: {filename} (mode={mode})") + credential_data = credentials.to_dict() + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 获取API端点和对应的User-Agent + if mode == "antigravity": + api_base_url = await get_antigravity_api_url() + user_agent = ANTIGRAVITY_USER_AGENT + else: + api_base_url = await get_code_assist_endpoint() + user_agent = GEMINICLI_USER_AGENT + + # 重新获取project id + project_id = await fetch_project_id( + access_token=credentials.access_token, + user_agent=user_agent, + api_base_url=api_base_url + ) + + if project_id: + # 更新凭证数据中的project_id + credential_data["project_id"] = project_id + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 检验成功后自动解除禁用状态并清除错误码 + await storage_adapter.update_credential_state(filename, { + "disabled": False, + "error_codes": [] + }, mode=mode) + + log.info(f"检验 {mode} 凭证成功: {filename} - Project ID: {project_id} - 已解除禁用并清除错误码") + + return JSONResponse(content={ + "success": True, + "filename": filename, + "project_id": project_id, + "message": "检验成功!Project ID已更新,已解除禁用状态并清除错误码,403错误应该已恢复" + }) + else: + return JSONResponse( + status_code=400, + content={ + "success": False, + "filename": filename, + "message": "检验失败:无法获取Project ID,请检查凭证是否有效" + } + ) + + +@router.post("/creds/verify-project/{filename}") +async def verify_credential_project( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 检验凭证的project id,重新获取project id + 检验成功可以使403错误恢复 + """ + try: + mode = validate_mode(mode) + return await verify_credential_project_common(filename, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"检验凭证Project ID失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"检验失败: {str(e)}") + + +@router.get("/creds/quota/{filename}") +async def get_credential_quota( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "antigravity" +): + """ + 获取指定凭证的额度信息(仅支持 antigravity 模式) + """ + try: + mode = validate_mode(mode) + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 使用 Credentials 对象自动处理 token 刷新 + from .google_oauth_api import Credentials + + creds = Credentials.from_dict(credential_data) + + # 自动刷新 token(如果需要) + await creds.refresh_if_needed() + + # 如果 token 被刷新了,更新存储 + updated_data = creds.to_dict() + if updated_data != credential_data: + log.info(f"Token已自动刷新: {filename}") + await storage_adapter.store_credential(filename, updated_data, mode=mode) + credential_data = updated_data + + # 获取访问令牌 + access_token = credential_data.get("access_token") or credential_data.get("token") + if not access_token: + raise HTTPException(status_code=400, detail="凭证中没有访问令牌") + + # 获取额度信息 + quota_info = await fetch_quota_info(access_token) + + if quota_info.get("success"): + return JSONResponse(content={ + "success": True, + "filename": filename, + "models": quota_info.get("models", {}) + }) + else: + return JSONResponse( + status_code=400, + content={ + "success": False, + "filename": filename, + "error": quota_info.get("error", "未知错误") + } + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证额度失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"获取额度失败: {str(e)}") + + +@router.get("/version/info") +async def get_version_info(check_update: bool = False): + """ + 获取当前版本信息 - 从version.txt读取 + 可选参数 check_update: 是否检查GitHub上的最新版本 + """ + try: + # 获取项目根目录 + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + version_file = os.path.join(project_root, "version.txt") + + # 读取version.txt + if not os.path.exists(version_file): + return JSONResponse({ + "success": False, + "error": "version.txt文件不存在" + }) + + version_data = {} + with open(version_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + version_data[key] = value + + # 检查必要字段 + if 'short_hash' not in version_data: + return JSONResponse({ + "success": False, + "error": "version.txt格式错误" + }) + + response_data = { + "success": True, + "version": version_data.get('short_hash', 'unknown'), + "full_hash": version_data.get('full_hash', ''), + "message": version_data.get('message', ''), + "date": version_data.get('date', '') + } + + # 如果需要检查更新 + if check_update: + try: + from src.httpx_client import get_async + + # 直接获取GitHub上的version.txt文件 + github_version_url = "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/version.txt" + + # 使用统一的httpx客户端 + resp = await get_async(github_version_url, timeout=10.0) + + if resp.status_code == 200: + # 解析远程version.txt + remote_version_data = {} + for line in resp.text.strip().split('\n'): + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + remote_version_data[key] = value + + latest_hash = remote_version_data.get('full_hash', '') + latest_short_hash = remote_version_data.get('short_hash', '') + current_hash = version_data.get('full_hash', '') + + has_update = (current_hash != latest_hash) if current_hash and latest_hash else None + + response_data['check_update'] = True + response_data['has_update'] = has_update + response_data['latest_version'] = latest_short_hash + response_data['latest_hash'] = latest_hash + response_data['latest_message'] = remote_version_data.get('message', '') + response_data['latest_date'] = remote_version_data.get('date', '') + else: + # GitHub获取失败,但不影响基本版本信息 + response_data['check_update'] = False + response_data['update_error'] = f"GitHub返回错误: {resp.status_code}" + + except Exception as e: + log.debug(f"检查更新失败: {e}") + response_data['check_update'] = False + response_data['update_error'] = str(e) + + return JSONResponse(response_data) + + except Exception as e: + log.error(f"获取版本信息失败: {e}") + return JSONResponse({ + "success": False, + "error": str(e) + }) + + + + diff --git a/start.bat b/start.bat new file mode 100644 index 0000000000000000000000000000000000000000..adb1ee9f9bf1c9de1e34c61cb5426ba63bccf586 --- /dev/null +++ b/start.bat @@ -0,0 +1,7 @@ +git fetch --all +for /f "delims=" %%b in ('git rev-parse --abbrev-ref HEAD') do set branch=%%b +git reset --hard origin/%branch% +uv sync +call .venv\Scripts\activate.bat +python web.py +pause \ No newline at end of file diff --git a/start.sh b/start.sh new file mode 100644 index 0000000000000000000000000000000000000000..1b4a2f05f0941acf257ab84d15e88506544ae1e1 --- /dev/null +++ b/start.sh @@ -0,0 +1,6 @@ +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) +uv sync +source .venv/bin/activate +python web.py \ No newline at end of file diff --git a/termux-install.sh b/termux-install.sh new file mode 100644 index 0000000000000000000000000000000000000000..4a9d8af5452802c191fe364996a03ee9b94c2590 --- /dev/null +++ b/termux-install.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# 避免交互式提示 +export DEBIAN_FRONTEND=noninteractive + +if [ "$(whoami)" = "root" ]; then + echo "检测到root用户,正在退出..." + exit +fi + +echo "检查Termux镜像源配置..." + +# 检查当前镜像源是否已经是Cloudflare镜像 +target_mirror="https://packages-cf.termux.dev/apt/termux-main" +fallback_mirror="https://packages.termux.dev/apt/termux-main" +if [ -f "$PREFIX/etc/apt/sources.list" ] && grep -q "$target_mirror" "$PREFIX/etc/apt/sources.list"; then + echo "✅ 镜像源已经配置为Cloudflare镜像,跳过修改" +else + echo "正在设置Termux镜像为Cloudflare镜像..." + + # 备份原始sources.list文件 + if [ -f "$PREFIX/etc/apt/sources.list" ]; then + echo "备份原始sources.list文件..." + cp "$PREFIX/etc/apt/sources.list" "$PREFIX/etc/apt/sources.list.backup.$(date +%s)" + fi + + # 写入新的镜像源 + echo "写入新的镜像源配置..." + cat > "$PREFIX/etc/apt/sources.list" << 'EOF' +# Cloudflare镜像源 +deb https://packages-cf.termux.dev/apt/termux-main stable main +EOF + + echo "✅ 镜像源已更新为: $target_mirror" +fi + +ensure_dpkg_ready() { + echo "检查并修复 dpkg/apt 状态..." + # 等待可能存在的 apt/dpkg 进程结束 + if pgrep -f "apt|dpkg" >/dev/null 2>&1; then + echo "检测到 apt/dpkg 正在运行,等待其结束..." + while pgrep -f "apt|dpkg" >/dev/null 2>&1; do sleep 1; done + fi + # 清理可能残留的锁(若无进程) + for f in "$PREFIX/var/lib/dpkg/lock" \ + "$PREFIX/var/lib/apt/lists/lock" \ + "$PREFIX/var/cache/apt/archives/lock"; do + [ -e "$f" ] && rm -f "$f" + done + # 尝试继续未完成的配置 + dpkg --configure -a || true +} + + +# 更新包列表并检查错误 +echo "正在更新包列表..." +ensure_dpkg_ready +apt_output=$(apt update 2>&1) +if [ $? -ne 0 ]; then + if echo "$apt_output" | grep -qi "is not signed"; then + echo "⚠️ 检测到仓库未签名,尝试切换到官方镜像并修复 keyring..." + # 切换到官方镜像 + sed -i "s#${target_mirror}#${fallback_mirror}#g" "$PREFIX/etc/apt/sources.list" || true + # 清理列表与锁 + rm -rf "$PREFIX/var/lib/apt/lists/"* || true + rm -f "$PREFIX/var/lib/dpkg/lock" "$PREFIX/var/lib/apt/lists/lock" "$PREFIX/var/cache/apt/archives/lock" || true + # 重新安装 termux-keyring(若已安装则强制重装) + apt-get install --reinstall -y termux-keyring || true + # 再次更新 + ensure_dpkg_ready + apt update + else + echo "apt update 失败,错误信息:" + echo "$apt_output" | head -20 + exit 1 + fi +else + echo "$apt_output" +fi + +echo "✅ Termux镜像设置完成!" +echo "📁 原始配置已备份到: $PREFIX/etc/apt/sources.list.backup.*" +echo "🔄 如需恢复原始镜像,可以运行:" +echo " cp \$PREFIX/etc/apt/sources.list.backup.* \$PREFIX/etc/apt/sources.list && apt update" + +# 检查是否需要更新包管理器和安装软件 +need_update=false +packages_to_install="" + +# 检查 uv 是否已安装 +if ! command -v uv &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install uv" +fi + +# 检查 python 是否已安装 +if ! command -v python &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install python" +fi + +# 检查 nodejs 是否已安装 +if ! command -v node &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install nodejs" +fi + +# 检查 git 是否已安装 +if ! command -v git &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install git" +fi + +# 如果需要安装软件,则更新包管理器并安装 +if [ "$need_update" = true ]; then + echo "正在更新包管理器..." + ensure_dpkg_ready + echo "正在安装缺失的软件包: $packages_to_install" + pkg install $packages_to_install -y +else + echo "所需软件包已全部安装,跳过更新和安装步骤" +fi + +# 检查 pm2 是否已安装 +if ! command -v pm2 &> /dev/null; then + echo "正在安装 pm2..." + npm install pm2 -g +else + echo "pm2 已安装,跳过安装" +fi + +# 项目目录处理逻辑 +if [ -f "./web.py" ]; then + # Already in target directory; skip clone and cd + echo "已在目标目录中,跳过克隆操作" +elif [ -f "./gcli2api/web.py" ]; then + echo "进入已存在的 gcli2api 目录" + cd ./gcli2api +else + echo "克隆项目仓库..." + git clone https://github.com/su-kaka/gcli2api.git + cd ./gcli2api +fi + +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) + +echo "初始化 uv 环境..." +uv init + +echo "安装 Python 依赖..." +uv add -r requirements-termux.txt + +echo "激活虚拟环境并启动服务..." +source .venv/bin/activate +pm2 start .venv/bin/python --name web -- web.py +cd .. + +echo "✅ 安装完成!服务已启动。" \ No newline at end of file diff --git a/termux-start.sh b/termux-start.sh new file mode 100644 index 0000000000000000000000000000000000000000..d4c172f3a250bd5caf1a39f609508e29920b6509 --- /dev/null +++ b/termux-start.sh @@ -0,0 +1,6 @@ +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) +uv add -r requirements-termux.txt +source .venv/bin/activate +pm2 start .venv/bin/python --name web -- web.py \ No newline at end of file diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..b43d38ec6b5ca49850cdbe28a5f30bfec2a3895c --- /dev/null +++ b/version.txt @@ -0,0 +1,4 @@ +full_hash=1dbd7bfb11d5216f19733f3ad918ed27ef5c015d +short_hash=1dbd7bf +message=Merge branch 'master' of https://github.com/su-kaka/gcli2api +date=2026-01-11 11:11:24 +0800 diff --git a/web.py b/web.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b3d45c5c0009b47582b23031e60c8de938b052 --- /dev/null +++ b/web.py @@ -0,0 +1,190 @@ +""" +Main Web Integration - Integrates all routers and modules +集合router并开启主服务 +""" + +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +from config import get_server_host, get_server_port +from log import log + +# Import managers and utilities +from src.credential_manager import CredentialManager + +# Import all routers +from src.router.antigravity.openai import router as antigravity_openai_router +from src.router.antigravity.gemini import router as antigravity_gemini_router +from src.router.antigravity.anthropic import router as antigravity_anthropic_router +from src.router.antigravity.model_list import router as antigravity_model_list_router +from src.router.geminicli.openai import router as openai_router +from src.router.geminicli.gemini import router as gemini_router +from src.router.geminicli.anthropic import router as geminicli_anthropic_router +from src.router.geminicli.model_list import router as geminicli_model_list_router +from src.task_manager import shutdown_all_tasks +from src.web_routes import router as web_router + +# 全局凭证管理器 +global_credential_manager = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + global global_credential_manager + + log.info("启动 GCLI2API 主服务") + + # 初始化配置缓存(优先执行) + try: + import config + await config.init_config() + log.info("配置缓存初始化成功") + except Exception as e: + log.error(f"配置缓存初始化失败: {e}") + + # 初始化全局凭证管理器 + try: + global_credential_manager = CredentialManager() + await global_credential_manager.initialize() + log.info("凭证管理器初始化成功") + except Exception as e: + log.error(f"凭证管理器初始化失败: {e}") + global_credential_manager = None + + # OAuth回调服务器将在需要时按需启动 + + yield + + # 清理资源 + log.info("开始关闭 GCLI2API 主服务") + + # 首先关闭所有异步任务 + try: + await shutdown_all_tasks(timeout=10.0) + log.info("所有异步任务已关闭") + except Exception as e: + log.error(f"关闭异步任务时出错: {e}") + + # 然后关闭凭证管理器 + if global_credential_manager: + try: + await global_credential_manager.close() + log.info("凭证管理器已关闭") + except Exception as e: + log.error(f"关闭凭证管理器时出错: {e}") + + log.info("GCLI2API 主服务已停止") + + +# 创建FastAPI应用 +app = FastAPI( + title="GCLI2API", + description="Gemini API proxy with OpenAI compatibility", + version="2.0.0", + lifespan=lifespan, +) + +# CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 挂载路由器 +# OpenAI兼容路由 - 处理OpenAI格式请求 +app.include_router(openai_router, prefix="", tags=["OpenAI Compatible API"]) + +# Gemini原生路由 - 处理Gemini格式请求 +app.include_router(gemini_router, prefix="", tags=["Gemini Native API"]) + +# Geminicli模型列表路由 - 处理Gemini格式的模型列表请求 +app.include_router(geminicli_model_list_router, prefix="", tags=["Geminicli Model List"]) + +# Antigravity路由 - 处理OpenAI格式请求并转换为Antigravity API +app.include_router(antigravity_openai_router, prefix="", tags=["Antigravity OpenAI API"]) + +# Antigravity路由 - 处理Gemini格式请求并转换为Antigravity API +app.include_router(antigravity_gemini_router, prefix="", tags=["Antigravity Gemini API"]) + +# Antigravity模型列表路由 - 处理Gemini格式的模型列表请求 +app.include_router(antigravity_model_list_router, prefix="", tags=["Antigravity Model List"]) + +# Antigravity Anthropic Messages 路由 - Anthropic Messages 格式兼容 +app.include_router(antigravity_anthropic_router, prefix="", tags=["Antigravity Anthropic Messages"]) + +# Geminicli Anthropic Messages 路由 - Anthropic Messages 格式兼容 (Geminicli) +app.include_router(geminicli_anthropic_router, prefix="", tags=["Geminicli Anthropic Messages"]) + +# Web路由 - 包含认证、凭证管理和控制面板功能 +app.include_router(web_router, prefix="", tags=["Web Interface"]) + +# 静态文件路由 - 服务docs目录下的文件(如捐赠图片) +app.mount("/docs", StaticFiles(directory="docs"), name="docs") + +# 静态文件路由 - 服务front目录下的文件(HTML、JS、CSS等) +app.mount("/front", StaticFiles(directory="front"), name="front") + + +# 保活接口(仅响应 HEAD) +@app.head("/keepalive") +async def keepalive() -> Response: + return Response(status_code=200) + + +def get_credential_manager(): + """获取全局凭证管理器实例""" + return global_credential_manager + + +# 导出给其他模块使用 +__all__ = ["app", "get_credential_manager"] + + +async def main(): + """异步主启动函数""" + from hypercorn.asyncio import serve + from hypercorn.config import Config + + # 日志系统现在直接使用环境变量,无需初始化 + # 从环境变量或配置获取端口和主机 + port = await get_server_port() + host = await get_server_host() + + log.info("=" * 60) + log.info("启动 GCLI2API") + log.info("=" * 60) + log.info(f"控制面板: http://127.0.0.1:{port}") + log.info("=" * 60) + log.info("API端点:") + log.info(f" Geminicli (OpenAI格式): http://127.0.0.1:{port}/v1") + log.info(f" Geminicli (Claude格式): http://127.0.0.1:{port}/v1") + log.info(f" Geminicli (Gemini格式): http://127.0.0.1:{port}") + + log.info(f" Antigravity (OpenAI格式): http://127.0.0.1:{port}/antigravity/v1") + log.info(f" Antigravity (Claude格式): http://127.0.0.1:{port}/antigravity/v1") + log.info(f" Antigravity (Gemini格式): http://127.0.0.1:{port}/antigravity") + + # 配置hypercorn + config = Config() + config.bind = [f"{host}:{port}"] + config.accesslog = "-" + config.errorlog = "-" + config.loglevel = "INFO" + + # 设置连接超时 + config.keep_alive_timeout = 600 # 10分钟 + config.read_timeout = 600 # 10分钟读取超时 + + await serve(app, config) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/zeabur.yaml b/zeabur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..961526ec3e86c4bb84379735e30a67d429c02362 --- /dev/null +++ b/zeabur.yaml @@ -0,0 +1,50 @@ +apiVersion: zeabur.com/v1 +kind: Template +metadata: + name: gcli2api +spec: + description: "将 GeminiCLI 转换为 OpenAI 和 GEMINI API 接口" + tags: + - ai + - api + - gemini + - openai + variables: + - key: PASSWORD + type: STRING + name: API密码 + description: 用于访问API和控制面板的密码 + - key: DOMAIN + type: DOMAIN + name: 域名 + description: 服务访问域名 + readme: | + # GeminiCLI to API + + 将 GeminiCLI 转换为 OpenAI 和 GEMINI API 接口 + + 部署后请访问您的域名进行配置。 + services: + - name: gcli2api + template: PREBUILT + spec: + id: gcli2api + name: gcli2api + source: + image: ghcr.io/su-kaka/gcli2api:latest + ports: + - id: web + port: 7861 + type: HTTP + env: + PASSWORD: + default: ${PASSWORD} + PORT: + default: "7861" + HOST: + default: "0.0.0.0" + DOMAIN: + default: ${DOMAIN} + volumes: + - id: creds + dir: /app/creds \ No newline at end of file