diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..1290db9c2a41aeecd5afcc5dcfa0b6686b9eebc2 --- /dev/null +++ b/.env.example @@ -0,0 +1,175 @@ +# ================================================================ +# GCLI2API 环境变量配置示例文件 +# 复制此文件为 .env 并根据需要修改配置值 +# ================================================================ + +# ================================================================ +# 服务器配置 +# ================================================================ + +# 服务器监听地址 +# 默认: 0.0.0.0 (监听所有网络接口) +HOST=0.0.0.0 + +# 服务器端口 +# 默认: 7861 +PORT=7861 + +# ================================================================ +# 密码配置 (支持分离密码) +# ================================================================ + +# 聊天API访问密码 (用于OpenAI和Gemini API端点认证) +# 默认: 继承通用密码或 pwd +API_PASSWORD=your_api_password + +# 控制面板访问密码 (用于Web界面登录认证) +# 默认: 继承通用密码或 pwd +PANEL_PASSWORD=your_panel_password + +# 通用访问密码 (兼容性保留) +# 设置后会覆盖上述两个专用密码,优先级最高 +# 如果只想使用一个密码,设置此项即可 +# 默认: pwd +PASSWORD=pwd + +# ================================================================ +# 存储配置 +# ================================================================ + +# 存储后端优先级: PostgreSQL > MongoDB > 本地sqlite文件存储 +# 系统会自动选择可用的最高优先级存储后端 + +# PostgreSQL 分布式存储模式配置 (最高优先级) +# 设置 POSTGRESQL_URI 后自动启用 PostgreSQL 模式 +# 本地 PostgreSQL: postgresql://user:password@localhost:5432/gcli2api +# 带 SSL: postgresql://user:password@host:5432/gcli2api?sslmode=require +# 默认: 无 (不启用 PostgreSQL 存储) +POSTGRESQL_URI=postgresql://user:password@localhost:5432/gcli2api + +# MongoDB 分布式存储模式配置 (第二优先级) +# 设置 MONGODB_URI 后自动启用 MongoDB 模式,不再使用本地文件存储 + +# Redis 缓存存储配置 +# 设置 REDIS_URL 后自动启用 Redis 模式,性能最佳,可大幅降低 MongoDB 的读写压力 +# 本地 Redis: redis://127.0.0.1:6379/0 +# 带密码: redis://:password@127.0.0.1:6379/0 +# 默认: 无 (不启用 Redis 缓存) +REDIS_URL=redis://127.0.0.1:6379/0 + +# MongoDB 连接字符串 (设置后启用 MongoDB 分布式存储模式) +# 本地 MongoDB: mongodb://localhost:27017 +# 带认证: mongodb://admin:password@localhost:27017/admin +# MongoDB Atlas: mongodb+srv://username:password@cluster.mongodb.net +# 副本集: mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=rs0 +# 默认: 无 (使用本地文件存储) +MONGODB_URI=mongodb://localhost:27017 + +# MongoDB 数据库名称 (仅在启用 MongoDB 模式时有效) +# 默认: gcli2api +MONGODB_DATABASE=gcli2api + +# ================================================================ +# Google API 配置 +# ================================================================ + +# 凭证文件目录 (仅在文件存储模式下使用) +# 默认: ./creds +CREDENTIALS_DIR=./creds + + +# 代理配置 (可选) +# 支持 http, https, socks5 代理 +# 格式: http://proxy:port, https://proxy:port, socks5://proxy:port +PROXY=http://localhost:7890 + +# Google API 代理 URL 配置 (可选) + +# Google Code Assist API 端点 +# 默认: https://cloudcode-pa.googleapis.com +CODE_ASSIST_ENDPOINT=https://cloudcode-pa.googleapis.com +# 用于Google OAuth2认证的代理URL +# 默认: https://oauth2.googleapis.com +OAUTH_PROXY_URL=https://oauth2.googleapis.com + +# 用于Google APIs调用的代理URL +# 默认: https://www.googleapis.com +GOOGLEAPIS_PROXY_URL=https://www.googleapis.com + +# 用于Google Cloud Resource Manager API的URL +# 默认: https://cloudresourcemanager.googleapis.com +RESOURCE_MANAGER_API_URL=https://cloudresourcemanager.googleapis.com + +# 用于Google Cloud Service Usage API的URL +# 默认: https://serviceusage.googleapis.com +SERVICE_USAGE_API_URL=https://serviceusage.googleapis.com + +# ================================================================ +# 错误处理和重试配置 +# ================================================================ + +# 是否启用自动封禁功能 +# 当凭证返回特定错误码时自动禁用该凭证 +# 默认: false +AUTO_BAN=false + +# 自动封禁的错误码列表 (逗号分隔) +# 默认: 400,403 +AUTO_BAN_ERROR_CODES=403 + +# 是否启用 429 错误重试 +# 默认: true +RETRY_429_ENABLED=true + +# 429 错误最大重试次数 +# 默认: 5 +RETRY_429_MAX_RETRIES=5 + +# 429 错误重试间隔 (秒) +# 默认: 1 +RETRY_429_INTERVAL=1 + +# ================================================================ +# 日志配置 +# ================================================================ + +# 日志级别 +# 可选值: debug, info, warning, error, critical +# 默认: info +LOG_LEVEL=info + +# 日志文件路径 +# 默认: log.txt +LOG_FILE=log.txt + +# ================================================================ +# 高级功能配置 +# ================================================================ + +# 流式抗截断最大尝试次数 +# 用于 "流式抗截断/" 前缀的模型 +# 默认: 3 +ANTI_TRUNCATION_MAX_ATTEMPTS=3 + +# ================================================================ +# 环境变量使用说明 +# ================================================================ + +# 1. 存储模式配置 (按优先级自动选择): +# - PostgreSQL 分布式模式 (最高优先级): 设置 POSTGRESQL_URI,数据存储在 PostgreSQL 数据库 +# - Redis 缓存: 同时设置 REDIS_URI和 MONGODB_URI时,数据缓存在 Redis 数据库,持久化在MONGODB,性能最佳 +# - MongoDB 分布式模式: 设置 MONGODB_URI,数据存储在 MongoDB 数据库 +# - 文件存储模式 (默认): 不设置上述 URI,数据存储在本地 creds/ 目录 +# - 自动切换: 系统根据可用的存储配置自动选择最高优先级的存储后端 + +# 2. 密码配置优先级: +# a) PASSWORD 环境变量 (最高优先级,设置后覆盖其他密码) +# b) API_PASSWORD / PANEL_PASSWORD 环境变量 (专用密码) +# c) 默认值 "pwd" +# +# 3. 通用配置优先级: +# 环境变量 > 默认值 + +# 4. 布尔值环境变量: +# true/1/yes/on 表示启用 +# false/0/no/off 表示禁用 diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..61f97efc9f3c6172e807210b7a6576d008c65a23 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,92 @@ +name: Bug 报告 +description: 报告项目使用中遇到的问题 +title: "[Bug]: " +labels: ["bug", "待处理"] +body: + - type: markdown + attributes: + value: | + ## 感谢你的反馈! + 请填写以下信息以帮助我们更快定位问题。 + + - type: checkboxes + id: checklist + attributes: + label: 提交前确认 + options: + - label: 我已经搜索过现有的 issues,确认这不是重复问题 + required: true + - label: 我已经阅读过项目文档 + required: true + + - type: dropdown + id: latest-version + attributes: + label: 是否是最新版 + description: 请确认你使用的是否是最新版本 + options: + - 是,使用最新版 + - 否,使用旧版本 + validations: + required: true + + - type: input + id: channel + attributes: + label: 调用的是哪个渠道 + description: 例如 geminicli 或者 antigravity + placeholder: "例如: geminicli" + validations: + required: true + + - type: input + id: model + attributes: + label: 调用的是哪个模型 + description: 例如 gemini-2.5-flash + placeholder: "例如: gemini-2.5-flash" + validations: + required: true + + - type: dropdown + id: format + attributes: + label: 调用的是哪个格式 + description: 选择你使用的 API 格式 + options: + - gemini 格式 + - openai 格式 + - claude 格式 + - 其他格式 + validations: + required: true + + - type: textarea + id: error-content + attributes: + label: 具体报错内容 + description: 请粘贴完整的错误信息或截图 + placeholder: | + 请在这里粘贴完整的错误日志或堆栈信息 + render: shell + validations: + required: true + + - type: textarea + id: error-description + attributes: + label: 错误描述 + description: 详细描述问题的发生场景、预期行为和实际行为 + placeholder: | + 1. 我在做什么操作时遇到了这个问题 + 2. 我期望的结果是... + 3. 但实际上发生了... + validations: + required: true + + - type: textarea + id: additional-context + attributes: + label: 补充信息(可选) + description: 其他任何有助于解决问题的信息 + placeholder: 例如:操作系统、Python 版本、相关配置等 \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..ceaef7e8914dfd70fd470aff137cf443d927cbb3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: 使用问题讨论 + url: https://github.com/su-kaka/gcli2api/issues + about: 如果是使用方面的问题,请在 issues 中提问 + - name: 项目文档 + url: https://github.com/su-kaka/gcli2api + about: 查看完整文档和使用指南 \ No newline at end of file diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc40808e83ba8cf46964756cc6bf6cf0ee516bcb --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,81 @@ +name: Docker Build and Publish + +on: + workflow_run: + workflows: ["Update Version File"] + types: + - completed + branches: + - master + - main + push: + tags: + - 'v*' + pull_request: + branches: + - master + - main + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-push: + runs-on: ubuntu-latest + # 只在 workflow_run 成功时运行,或者非 workflow_run 触发时运行 + if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + # workflow_run 触发时需要获取最新的代码(包括 version.txt 的更新) + ref: ${{ github.event_name == 'workflow_run' && github.event.workflow_run.head_branch || github.ref }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=ref,event=pr + type=raw,value=latest,enable={{is_default_branch}} + type=sha,prefix={{branch}}- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + build-args: | + BUILD_DATE=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.created'] }} + VERSION=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} + REVISION=${{ github.sha }} \ No newline at end of file diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml new file mode 100644 index 0000000000000000000000000000000000000000..93b018ca576fb727cd3d12f74d893db1ce29129a --- /dev/null +++ b/.github/workflows/update-version.yml @@ -0,0 +1,51 @@ +name: Update Version File + +on: + push: + branches: + - master + - main + +jobs: + update-version: + runs-on: ubuntu-latest + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Update version.txt + run: | + # 获取最新commit信息 + FULL_HASH=$(git log -1 --format=%H) + SHORT_HASH=$(git log -1 --format=%h) + MESSAGE=$(git log -1 --format=%s) + DATE=$(git log -1 --format=%ci) + + # 写入version.txt + echo "full_hash=$FULL_HASH" > version.txt + echo "short_hash=$SHORT_HASH" >> version.txt + echo "message=$MESSAGE" >> version.txt + echo "date=$DATE" >> version.txt + + echo "Version file updated:" + cat version.txt + + - name: Commit version.txt if changed + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + + # 检查是否有变化 + if git diff --quiet version.txt; then + echo "No changes to version.txt" + else + git add version.txt + git commit -m "chore: update version.txt [skip ci]" + git push + fi diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c1fab8d850b5af68cbb41d50159527d6234a578f --- /dev/null +++ b/.gitignore @@ -0,0 +1,99 @@ +# Credential files - should never be committed +*.json +!package.json +!package-lock.json +!tsconfig.json +*.toml +!pyproject.toml +creds/ +CLAUDE.md +GEMINI.md +.kiro +# Environment configuration +.env + +# Python +uv.lock +.python-version +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.vscode/ +.idea/ +.claude/ +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log +log.txt + +# Temporary files +*.tmp +*.temp +*.bak + +tools/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..2e237c1dbe43b760b1e2aa2cac6546b69eb7c735 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,169 @@ +# Contributing to gcli2api + +First off, thank you for considering contributing to gcli2api! It's people like you that make gcli2api such a great tool. + +## Code of Conduct + +This project is intended for personal learning and research purposes only. By participating, you are expected to uphold this code and respect the CNC-1.0 license restrictions on commercial use. + +## How Can I Contribute? + +### Reporting Bugs + +Before creating bug reports, please check the existing issues to avoid duplicates. When you create a bug report, include as many details as possible: + +* **Use a clear and descriptive title** +* **Describe the exact steps to reproduce the problem** +* **Provide specific examples** - Include code snippets, configuration files, or log outputs +* **Describe the behavior you observed** and what you expected to see +* **Include environment details**: OS, Python version, Docker version (if applicable) + +### Suggesting Enhancements + +Enhancement suggestions are tracked as GitHub issues. When creating an enhancement suggestion, include: + +* **Use a clear and descriptive title** +* **Provide a detailed description** of the suggested enhancement +* **Explain why this enhancement would be useful** +* **List any alternative solutions** you've considered + +### Pull Requests + +1. Fork the repo and create your branch from `master` +2. If you've added code that should be tested, add tests +3. If you've changed APIs, update the documentation +4. Ensure the test suite passes +5. Make sure your code follows the existing style +6. Write a clear commit message + +## Development Setup + +### Prerequisites + +* Python 3.12 or higher +* pip or uv package manager + +### Setting Up Your Development Environment + +```bash +# Clone your fork +git clone https://github.com/YOUR_USERNAME/gcli2api.git +cd gcli2api + +# Install development dependencies +make install-dev +# or +pip install -e ".[dev]" + +# Copy environment example +cp .env.example .env +# Edit .env with your configuration +``` + +### Development Workflow + +```bash +# Run tests +make test + +# Format code +make format + +# Run linters +make lint + +# Run the application locally +make run +``` + +### Testing + +We use pytest for testing. All new features should include appropriate tests. + +```bash +# Run all tests +make test + +# Run with coverage +make test-cov + +# Run specific test file +python -m pytest test_tool_calling.py -v +``` + +### Code Style + +* We use [Black](https://black.readthedocs.io/) for code formatting (line length: 100) +* We use [flake8](https://flake8.pycqa.org/) for linting +* We use [mypy](http://mypy-lang.org/) for type checking (optional, but encouraged) + +```bash +# Format your code before committing +make format + +# Check if code is properly formatted +make format-check + +# Run linters +make lint +``` + +## Project Structure + +``` +gcli2api/ +├── src/ # Main source code +│ ├── auth.py # Authentication and OAuth +│ ├── credential_manager.py # Credential rotation +│ ├── openai_router.py # OpenAI-compatible endpoints +│ ├── gemini_router.py # Gemini native endpoints +│ ├── openai_transfer.py # Format conversion +│ ├── storage/ # Storage backends (Redis, MongoDB, Postgres, File) +│ └── ... +├── front/ # Frontend static files +├── tests/ # Test directory (to be created) +├── test_*.py # Test files (root level) +├── web.py # Main application entry point +├── config.py # Configuration management +└── requirements.txt # Production dependencies +``` + +## Coding Guidelines + +### Python Style + +* Follow PEP 8 guidelines +* Use type hints where appropriate +* Write docstrings for classes and functions +* Keep functions focused and concise + +### Commit Messages + +* Use the present tense ("Add feature" not "Added feature") +* Use the imperative mood ("Move cursor to..." not "Moves cursor to...") +* Limit the first line to 72 characters or less +* Reference issues and pull requests liberally after the first line + +### Documentation + +* Update the README.md if you change functionality +* Comment your code where necessary +* Update the .env.example if you add new configuration options + +## License + +By contributing to gcli2api, you agree that your contributions will be licensed under the CNC-1.0 license. This is a strict anti-commercial license - see [LICENSE](LICENSE) for details. + +### Important License Restrictions + +* ❌ No commercial use +* ❌ No use by companies with revenue > $1M USD +* ❌ No use by VC-backed or publicly traded companies +* ✅ Personal learning, research, and educational use only +* ✅ Open source integration (must follow same license) + +## Questions? + +Feel free to open an issue with your question or reach out to the maintainers. + +Thank you for contributing! 🎉 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..1c86c3aaebe124d147285185dc588c9f46c9f42f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +# Multi-stage build for gcli2api +FROM python:3.13-slim as base + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + TZ=Asia/Shanghai + +# Install tzdata and set timezone +RUN apt-get update && \ + apt-get install -y --no-install-recommends tzdata && \ + ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy only requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Expose port +EXPOSE 7861 + +# Default command +CMD ["python", "web.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2317eb517b7aa0e90bc5b81db74b82ff885336e5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,83 @@ +Cooperative Non-Commercial License (CNC-1.0) + +Copyright (c) 2024 gcli2api contributors + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of this software and associated documentation files (the +"Software"), to use, copy, modify, merge, publish, distribute, and/or +sublicense the Software, subject to the following conditions: + +TERMS AND CONDITIONS: + +1. NON-COMMERCIAL USE ONLY + The Software may only be used for non-commercial purposes. Commercial use + is strictly prohibited without explicit written permission from the + copyright holders. + +2. DEFINITION OF COMMERCIAL USE + "Commercial use" includes but is not limited to: + a) Using the Software to provide paid services or products + b) Integrating the Software into commercial products or services + c) Using the Software in any business operation that generates revenue + d) Offering the Software as part of a paid subscription or service + e) Using the Software to compete with the original project commercially + +3. COPYLEFT REQUIREMENT + Any derivative works, modifications, or substantial portions of the Software + must be licensed under the same or substantially similar terms. This ensures + that all derivatives remain non-commercial and freely available. + +4. SOURCE CODE AVAILABILITY + If you distribute the Software or any derivative works, you must make the + complete source code available under the same license terms at no charge. + +5. ATTRIBUTION REQUIREMENT + You must retain all copyright notices, license notices, and attribution + statements in all copies or substantial portions of the Software. + +6. ANTI-CORPORATE CLAUSE + This Software may not be used by corporations with annual revenue exceeding + $1 million USD, venture capital backed companies, or publicly traded + companies without explicit written permission from the copyright holders. + +7. EDUCATIONAL AND RESEARCH EXEMPTION + Use by educational institutions, non-profit research organizations, and + individual researchers for educational or research purposes is explicitly + permitted and encouraged. + +8. MODIFICATION AND CONTRIBUTION + Modifications and contributions to the Software are welcomed and encouraged, + provided they comply with these license terms. Contributors grant the same + license to their contributions. + +9. PATENT GRANT + Each contributor grants you a non-exclusive, worldwide, royalty-free patent + license to make, have made, use, offer to sell, sell, import, and otherwise + transfer the Work for non-commercial purposes only. + +10. TERMINATION + This license automatically terminates if you violate any of its terms. + Upon termination, you must cease all use and distribution of the Software + and destroy all copies in your possession. + +11. LIABILITY DISCLAIMER + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +12. JURISDICTION + This license shall be governed by and construed in accordance with the laws + of the jurisdiction where the copyright holder resides. + +SUMMARY: +This license allows free use, modification, and distribution of the Software +for non-commercial purposes only. It explicitly prohibits commercial use and +ensures that all derivatives remain freely available under the same terms. +The license promotes cooperative development while preventing commercial +exploitation of the community's work. + +For commercial licensing inquiries, please contact the copyright holders. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f676fe5697525b2376a44a5c905b4193b0a9d1c9 --- /dev/null +++ b/README.md @@ -0,0 +1,768 @@ +--- +title: gcli2api +colorFrom: blue +colorTo: green +sdk: docker +app_port: 7861 +pinned: false +--- +# GeminiCLI to API + +**灏?GeminiCLI 鍜?Antigravity 杞崲涓?OpenAI 銆丟EMINI 鍜?Claude API 鍏煎鎺ュ彛** + +[![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/) +[![License: CNC-1.0](https://img.shields.io/badge/License-CNC--1.0-red.svg)](LICENSE) +[![Docker](https://img.shields.io/badge/docker-available-blue.svg)](https://github.com/su-kaka/gcli2api/pkgs/container/gcli2api) + +[English](docs/README_EN.md) | 涓枃 | [鏃ユ湰瑾瀅(docs/README_JA.md) + +## 馃殌 蹇€熼儴缃? + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/97VMEF?referralCode=sukaka) +[![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/su-kaka/gcli2api) +--- + +## 鈿狅笍 璁稿彲璇佸0鏄? + +**鏈」鐩噰鐢?Cooperative Non-Commercial License (CNC-1.0)** + +杩欐槸涓€涓弽鍟嗕笟鍖栫殑涓ユ牸寮€婧愬崗璁紝璇︽儏璇锋煡鐪?[LICENSE](LICENSE) 鏂囦欢銆? + +### 鉁?鍏佽鐨勭敤閫旓細 +- 涓汉瀛︿範銆佺爺绌躲€佹暀鑲茬敤閫? +- 闈炶惀鍒╃粍缁囦娇鐢? +- 寮€婧愰」鐩泦鎴愶紙闇€閬靛惊鐩稿悓鍗忚锛? +- 瀛︽湳鐮旂┒鍜岃鏂囧彂琛? + +### 鉂?绂佹鐨勭敤閫旓細 +- 浠讳綍褰㈠紡鐨勫晢涓氫娇鐢? +- 骞存敹鍏ヨ秴杩?00涓囩編鍏冪殑浼佷笟浣跨敤 +- 椋庢姇鏀寔鎴栧叕寮€浜ゆ槗鐨勫叕鍙镐娇鐢? +- 鎻愪緵浠樿垂鏈嶅姟鎴栦骇鍝? +- 鍟嗕笟绔炰簤鐢ㄩ€? + +## 鏍稿績鍔熻兘 + +### 馃攧 API 绔偣鍜屾牸寮忔敮鎸? + +**澶氱鐐瑰鏍煎紡鏀寔** +- **OpenAI 鍏煎绔偣**锛歚/v1/chat/completions` 鍜?`/v1/models` + - 鏀寔鏍囧噯 OpenAI 鏍煎紡锛坢essages 缁撴瀯锛? + - 鏀寔 Gemini 鍘熺敓鏍煎紡锛坈ontents 缁撴瀯锛? + - 鑷姩鏍煎紡妫€娴嬪拰杞崲锛屾棤闇€鎵嬪姩鍒囨崲 + - 鏀寔澶氭ā鎬佽緭鍏ワ紙鏂囨湰 + 鍥惧儚锛? +- **Gemini 鍘熺敓绔偣**锛歚/v1/models/{model}:generateContent` 鍜?`streamGenerateContent` + - 鏀寔瀹屾暣鐨?Gemini 鍘熺敓 API 瑙勮寖 + - 澶氱璁よ瘉鏂瑰紡锛欱earer Token銆亁-goog-api-key 澶撮儴銆乁RL 鍙傛暟 key +- **Claude 鏍煎紡鍏煎**锛氬畬鏁存敮鎸?Claude API 鏍煎紡 + - 绔偣锛歚/v1/messages`锛堥伒寰?Claude API 瑙勮寖锛? + - 鏀寔 Claude 鏍囧噯鐨?messages 鏍煎紡 + - 鏀寔 system 鍙傛暟鍜?Claude 鐗规湁鍔熻兘 + - 鑷姩杞崲涓哄悗绔敮鎸佺殑鏍煎紡 +- **Antigravity API 鏀寔**锛氬悓鏃舵敮鎸?OpenAI銆丟emini 鍜?Claude 鏍煎紡 + - OpenAI 鏍煎紡绔偣锛歚/antigravity/v1/chat/completions` + - Gemini 鏍煎紡绔偣锛歚/antigravity/v1/models/{model}:generateContent` 鍜?`streamGenerateContent` + - Claude 鏍煎紡绔偣锛歚/antigravity/v1/messages` + - 鏀寔鎵€鏈?Antigravity 妯″瀷锛圕laude銆丟emini 绛夛級 + - 鑷姩妯″瀷鍚嶇О鏄犲皠鍜屾€濈淮妯″紡妫€娴? + +### 馃攼 璁よ瘉鍜屽畨鍏ㄧ鐞? + +**鐏垫椿鐨勫瘑鐮佺鐞?* +- **鍒嗙瀵嗙爜鏀寔**锛欰PI 瀵嗙爜锛堣亰澶╃鐐癸級鍜屾帶鍒堕潰鏉垮瘑鐮佸彲鐙珛璁剧疆 +- **澶氱璁よ瘉鏂瑰紡**锛氭敮鎸?Authorization Bearer銆亁-goog-api-key 澶撮儴銆乁RL 鍙傛暟绛? +- **JWT Token 璁よ瘉**锛氭帶鍒堕潰鏉挎敮鎸?JWT 浠ょ墝璁よ瘉 +- **鐢ㄦ埛閭鑾峰彇**锛氳嚜鍔ㄨ幏鍙栧拰鏄剧ず Google 璐︽埛閭鍦板潃 + +### 馃搳 鏅鸿兘鍑瘉绠$悊绯荤粺 + +**楂樼骇鍑瘉绠$悊** +- 澶氫釜 Google OAuth 鍑瘉鑷姩杞崲 +- 閫氳繃鍐椾綑璁よ瘉澧炲己绋冲畾鎬? +- 璐熻浇鍧囪 涓庡苟鍙戣姹傛敮鎸? +- 鑷姩鏁呴殰妫€娴嬪拰鍑瘉绂佺敤 +- 鍑瘉浣跨敤缁熻鍜岄厤棰濈鐞? +- 鏀寔鎵嬪姩鍚敤/绂佺敤鍑瘉鏂囦欢 +- 鎵归噺鍑瘉鏂囦欢鎿嶄綔锛堝惎鐢ㄣ€佺鐢ㄣ€佸垹闄わ級 + +**鍑瘉鐘舵€佺洃鎺?* +- 瀹炴椂鍑瘉鍋ュ悍妫€鏌? +- 閿欒鐮佽拷韪紙429銆?03銆?00 绛夛級 +- 鑷姩灏佺鏈哄埗锛堝彲閰嶇疆锛? + +### 馃寠 娴佸紡浼犺緭鍜屽搷搴斿鐞? + +**澶氱娴佸紡鏀寔** +- 鐪熸鐨勫疄鏃舵祦寮忓搷搴? +- 鍋囨祦寮忔ā寮忥紙鐢ㄤ簬鍏煎鎬э級 +- 娴佸紡鎶楁埅鏂姛鑳斤紙闃叉鍥炵瓟琚埅鏂級 +- 寮傛浠诲姟绠$悊鍜岃秴鏃跺鐞? + +**鍝嶅簲浼樺寲** +- 鎬濈淮閾撅紙Thinking锛夊唴瀹瑰垎绂? +- 鎺ㄧ悊杩囩▼锛坮easoning_content锛夊鐞? +- 澶氳疆瀵硅瘽涓婁笅鏂囩鐞? +- 鍏煎鎬фā寮忥紙灏?system 娑堟伅杞崲涓?user 娑堟伅锛? + +### 馃帥锔?Web 绠$悊鎺у埗鍙? + +**鍏ㄥ姛鑳?Web 鐣岄潰** +- OAuth 璁よ瘉娴佺▼绠$悊锛堟敮鎸?GCLI 鍜?Antigravity 鍙屾ā寮忥級 +- 鍑瘉鏂囦欢涓婁紶銆佷笅杞姐€佺鐞? +- 瀹炴椂鏃ュ織鏌ョ湅锛圵ebSocket锛? +- 绯荤粺閰嶇疆绠$悊 +- 浣跨敤缁熻鍜岀洃鎺ч潰鏉? +- 绉诲姩绔€傞厤鐣岄潰 + +**鎵归噺鎿嶄綔鏀寔** +- ZIP 鏂囦欢鎵归噺涓婁紶鍑瘉锛圙CLI 鍜?Antigravity锛? +- 鎵归噺鍚敤/绂佺敤/鍒犻櫎鍑瘉 +- 鎵归噺鑾峰彇鐢ㄦ埛閭 +- 鎵归噺閰嶇疆绠$悊 +- 缁熶竴鎵归噺涓婁紶鐣岄潰绠$悊鎵€鏈夊嚟璇佺被鍨? + +### 馃搱 浣跨敤鐩戞帶 + +**瀹炴椂鐩戞帶** +- WebSocket 瀹炴椂鏃ュ織娴? +- 绯荤粺鐘舵€佺洃鎺? +- 鍑瘉鍋ュ悍鐘舵€? + +### 馃敡 楂樼骇閰嶇疆鍜岃嚜瀹氫箟 + +**缃戠粶鍜屼唬鐞嗛厤缃?* +- HTTP/HTTPS 浠g悊鏀寔 +- 浠g悊绔偣閰嶇疆锛圤Auth銆丟oogle APIs銆佸厓鏁版嵁鏈嶅姟锛? +- 瓒呮椂鍜岄噸璇曢厤缃? +- 缃戠粶閿欒澶勭悊鍜屾仮澶? + +**鎬ц兘鍜岀ǔ瀹氭€ч厤缃?* +- 429 閿欒鑷姩閲嶈瘯锛堝彲閰嶇疆闂撮殧鍜屾鏁帮級 +- 鎶楁埅鏂渶澶ч噸璇曟鏁? + +**鏃ュ織鍜岃皟璇?* +- 澶氱骇鏃ュ織绯荤粺锛圖EBUG銆両NFO銆乄ARNING銆丒RROR锛? +- 鏃ュ織鏂囦欢绠$悊 +- 瀹炴椂鏃ュ織娴? +- 鏃ュ織涓嬭浇鍜屾竻绌? + +### 馃攧 鐜鍙橀噺鍜岄厤缃鐞? + +**鐏垫椿鐨勯厤缃柟寮?* +- 鐜鍙橀噺閰嶇疆 +- 鐑厤缃洿鏂帮紙閮ㄥ垎閰嶇疆椤癸級 +- 閰嶇疆閿佸畾锛堢幆澧冨彉閲忎紭鍏堢骇锛? + +## 鏀寔鐨勬ā鍨? + +鎵€鏈夋ā鍨嬪潎鍏峰 1M 涓婁笅鏂囩獥鍙e閲忋€傛瘡涓嚟璇佹枃浠舵彁渚?1000 娆¤姹傞搴︺€? + +### 馃 鍩虹妯″瀷 +- `gemini-2.5-pro` +- `gemini-3-pro-preview` +- `gemini-3.1-pro-preview` + +### 馃 鎬濈淮妯″瀷锛圱hinking Models锛? +- `gemini-2.5-pro-high`锛氭€濊€冩ā寮? +- `gemini-2.5-pro-low`锛氫綆鎬濊€冩ā寮? +- 鏀寔鑷畾涔夋€濊€冮绠楅厤缃? +- 鑷姩鍒嗙鎬濈淮鍐呭鍜屾渶缁堝洖绛? + +### 馃攳 鎼滅储澧炲己妯″瀷 +- `gemini-2.5-pro-search`锛氶泦鎴愭悳绱㈠姛鑳界殑妯″瀷 + +### 馃柤锔?鍥惧儚鐢熸垚妯″瀷锛圓ntigravity锛? +- `gemini-3.1-flash-image`锛氬熀纭€鍥惧儚鐢熸垚妯″瀷 +- **鍒嗚鲸鐜囧悗缂€**锛? + - `-2k`锛?K 鍒嗚鲸鐜? + - `-4k`锛?K 楂樻竻鍒嗚鲸鐜? +- **姣斾緥鍚庣紑**锛? + - `-1x1`锛氭鏂瑰舰锛堝ご鍍忥級 + - `-16x9`锛氭í灞忥紙鐢佃剳澹佺焊锛? + - `-9x16`锛氱珫灞忥紙鎵嬫満澹佺焊锛? + - `-21x9`锛氳秴瀹藉睆锛堝甫楸煎睆锛? + - `-4x3`锛氫紶缁熸樉绀哄櫒 + - `-3x4`锛氱珫鐗堟捣鎶? +- **缁勫悎浣跨敤绀轰緥**锛? + - `gemini-3.1-flash-image-4k-16x9`锛?K 妯睆 + - `gemini-3.1-flash-image-2k-9x16`锛?K 绔栧睆 +- 涓嶆寚瀹氭瘮渚嬫椂锛孉PI 鑷姩鍐冲畾妯珫姣斾緥 + +### 馃寠 鐗规畩鍔熻兘鍙樹綋 +- **鍋囨祦寮忔ā寮?*锛氬湪浠讳綍妯″瀷鍚嶇О鍚庢坊鍔?`-鍋囨祦寮廯 鍚庣紑 + - 渚嬶細`gemini-2.5-pro-鍋囨祦寮廯 + - 鐢ㄤ簬闇€瑕佹祦寮忓搷搴斾絾鏈嶅姟绔笉鏀寔鐪熸祦寮忕殑鍦烘櫙 +- **娴佸紡鎶楁埅鏂ā寮?*锛氬湪妯″瀷鍚嶇О鍓嶆坊鍔?`娴佸紡鎶楁埅鏂?` 鍓嶇紑 + - 渚嬶細`娴佸紡鎶楁埅鏂?gemini-2.5-pro` + - 鑷姩妫€娴嬪搷搴旀埅鏂苟閲嶈瘯锛岀‘淇濆畬鏁村洖绛? + +### 馃敡 妯″瀷鍔熻兘鑷姩妫€娴? +- 绯荤粺鑷姩璇嗗埆妯″瀷鍚嶇О涓殑鍔熻兘鏍囪瘑 +- 閫忔槑鍦板鐞嗗姛鑳芥ā寮忚浆鎹? +- 鏀寔鍔熻兘缁勫悎浣跨敤 + + +--- + +## 瀹夎鎸囧崡 + +### Termux 鐜 + +**鍒濆瀹夎** +```bash +curl -o termux-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/termux-install.sh" && chmod +x termux-install.sh && ./termux-install.sh +``` + +**閲嶅惎鏈嶅姟** +```bash +cd gcli2api +bash termux-start.sh +``` + +### Windows 鐜 + +**鍒濆瀹夎** +```powershell +iex (iwr "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.ps1" -UseBasicParsing).Content +``` + +**閲嶅惎鏈嶅姟** +鍙屽嚮鎵ц `start.bat` + +### Linux 鐜 + +**鍒濆瀹夎** +```bash +curl -o install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.sh" && chmod +x install.sh && ./install.sh +``` + +**閲嶅惎鏈嶅姟** +```bash +cd gcli2api +bash start.sh +``` + +### macOS 鐜 + +**鍒濆瀹夎** +```bash +curl -o darwin-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/darwin-install.sh" && chmod +x darwin-install.sh && ./darwin-install.sh +``` + +**閲嶅惎鏈嶅姟** +```bash +cd gcli2api +bash start.sh +``` + +### Docker 鐜 + +**Docker 杩愯鍛戒护** +```bash +# 浣跨敤閫氱敤瀵嗙爜 +docker run -d --name gcli2api --network host -e PASSWORD=pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest + +# 浣跨敤鍒嗙瀵嗙爜 +docker run -d --name gcli2api --network host -e API_PASSWORD=api_pwd -e PANEL_PASSWORD=panel_pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Mac** +```bash +# 浣跨敤閫氱敤瀵嗙爜 +docker run -d \ + --name gcli2api \ + -p 7861:7861 \ + -p 8080:8080 \ + -e PASSWORD=pwd \ + -e PORT=7861 \ + -v "$(pwd)/data/creds":/app/creds \ + ghcr.io/su-kaka/gcli2api:latest +``` + +```bash +# 浣跨敤鍒嗙瀵嗙爜 +docker run -d \ +--name gcli2api \ +-p 7861:7861 \ +-p 8080:8080 \ +-e API_PASSWORD=api_pwd \ +-e PANEL_PASSWORD=panel_pwd \ +-e PORT=7861 \ +-v $(pwd)/data/creds:/app/creds \ +ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose 杩愯鍛戒护** +1. 灏嗕互涓嬪唴瀹逛繚瀛樹负 `docker-compose.yml` 鏂囦欢锛? + ```yaml + version: '3.8' + + services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # 浣跨敤閫氱敤瀵嗙爜锛堟帹鑽愮敤浜庣畝鍗曢儴缃诧級 + - PASSWORD=pwd + - PORT=7861 + # 鎴栦娇鐢ㄥ垎绂诲瘑鐮侊紙鎺ㄨ崘鐢ㄤ簬鐢熶骇鐜锛? + # - API_PASSWORD=your_api_password + # - PANEL_PASSWORD=your_panel_password + volumes: + - ./data/creds:/app/creds + healthcheck: + test: ["CMD-SHELL", "python -c \"import sys, urllib.request, os; port = os.environ.get('PORT', '7861'); req = urllib.request.Request(f'http://localhost:{port}/v1/models', headers={'Authorization': 'Bearer ' + os.environ.get('PASSWORD', 'pwd')}); sys.exit(0 if urllib.request.urlopen(req, timeout=5).getcode() == 200 else 1)\""] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + ``` +2. 鍚姩鏈嶅姟锛? + ```bash + docker-compose up -d + ``` + +--- + +## 閰嶇疆璇存槑 + +1. 璁块棶 `http://127.0.0.1:7861` 锛堥粯璁ょ鍙o紝鍙€氳繃 PORT 鐜鍙橀噺淇敼锛? +2. 瀹屾垚 OAuth 璁よ瘉娴佺▼锛堥粯璁ゅ瘑鐮侊細`pwd`锛屽彲閫氳繃鐜鍙橀噺淇敼锛? + - **GCLI 妯″紡**锛氱敤浜庤幏鍙?Google Cloud Gemini API 鍑瘉 + - **Antigravity 妯″紡**锛氱敤浜庤幏鍙?Google Antigravity API 鍑瘉 +3. 閰嶇疆瀹㈡埛绔細 + +**OpenAI 鍏煎瀹㈡埛绔細** + - **绔偣鍦板潃**锛歚http://127.0.0.1:7861/v1` + - **API 瀵嗛挜**锛歚pwd`锛堥粯璁ゅ€硷紝鍙€氳繃 API_PASSWORD 鎴?PASSWORD 鐜鍙橀噺淇敼锛? + +**Gemini 鍘熺敓瀹㈡埛绔細** + - **绔偣鍦板潃**锛歚http://127.0.0.1:7861` + - **璁よ瘉鏂瑰紡**锛? + - `Authorization: Bearer your_api_password` + - `x-goog-api-key: your_api_password` + - URL 鍙傛暟锛歚?key=your_api_password` + +### 馃専 鍙岃璇佹ā寮忔敮鎸? + +**GCLI 璁よ瘉妯″紡** +- 鏍囧噯鐨?Google Cloud Gemini API 璁よ瘉 +- 鏀寔 OAuth2.0 璁よ瘉娴佺▼ +- 鑷姩鍚敤蹇呴渶鐨?Google Cloud API + +**Antigravity 璁よ瘉妯″紡** +- Google Antigravity API 涓撶敤璁よ瘉 +- 鐙珛鐨勫嚟璇佺鐞嗙郴缁? +- 鏀寔鎵归噺涓婁紶鍜岀鐞? +- 涓?GCLI 鍑瘉瀹屽叏闅旂 + +**缁熶竴绠$悊鐣岄潰** +- 鍦?鎵归噺涓婁紶"鏍囩椤典腑鍙竴娆℃€х鐞嗕袱绉嶅嚟璇? +- 涓婂崐閮ㄥ垎锛欸CLI 鍑瘉鎵归噺涓婁紶锛堣摑鑹蹭富棰橈級 +- 涓嬪崐閮ㄥ垎锛欰ntigravity 鍑瘉鎵归噺涓婁紶锛堢豢鑹蹭富棰橈級 +- 鍚勮嚜鐙珛鐨勫嚟璇佺鐞嗘爣绛鹃〉 + +## 馃捑 鏁版嵁瀛樺偍妯″紡 + +### 馃専 瀛樺偍鍚庣鏀寔 + +gcli2api 鏀寔涓ょ瀛樺偍鍚庣锛?*鏈湴 SQLite锛堥粯璁わ級** 鍜?**MongoDB锛堜簯绔垎甯冨紡瀛樺偍锛?* + +### 馃搧 鏈湴 SQLite 瀛樺偍锛堥粯璁わ級 + +**榛樿瀛樺偍鏂瑰紡** +- 鏃犻渶閰嶇疆锛屽紑绠卞嵆鐢? +- 鏁版嵁瀛樺偍鍦ㄦ湰鍦?SQLite 鏁版嵁搴撲腑 +- 閫傚悎鍗曟満閮ㄧ讲鍜屼釜浜轰娇鐢? +- 鑷姩鍒涘缓鍜岀鐞嗘暟鎹簱鏂囦欢 + +### 馃崈 MongoDB 浜戠瀛樺偍妯″紡 + +**浜戠鍒嗗竷寮忓瓨鍌ㄦ柟妗?* + +褰撻渶瑕佸瀹炰緥閮ㄧ讲鎴栦簯绔瓨鍌ㄦ椂锛屽彲浠ュ惎鐢?MongoDB 瀛樺偍妯″紡銆? + +### 鈿欙笍 鍚敤 MongoDB 妯″紡 + +**姝ラ 1: 閰嶇疆 MongoDB 杩炴帴** +```bash +# 鏈湴 MongoDB +export MONGODB_URI="mongodb://localhost:27017" + +# MongoDB Atlas 浜戞湇鍔? +export MONGODB_URI="mongodb+srv://username:password@cluster.mongodb.net" + +# 甯﹁璇佺殑 MongoDB +export MONGODB_URI="mongodb://admin:password@localhost:27017/admin" + +# 鍙€夛細鑷畾涔夋暟鎹簱鍚嶇О锛堥粯璁? gcli2api锛? +export MONGODB_DATABASE="my_gcli_db" +``` + +**姝ラ 2: 鍚姩搴旂敤** +```bash +# 搴旂敤浼氳嚜鍔ㄦ娴?MongoDB 閰嶇疆骞朵娇鐢?MongoDB 瀛樺偍 +python web.py +``` + +**Docker 鐜浣跨敤 MongoDB** +```bash +# 鍗曟満 MongoDB 閮ㄧ讲 +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb://mongodb:27017" \ + -e API_PASSWORD=your_password \ + --network your_network \ + ghcr.io/su-kaka/gcli2api:latest + +# 浣跨敤 MongoDB Atlas +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb+srv://user:pass@cluster.mongodb.net/gcli2api" \ + -e API_PASSWORD=your_password \ + -p 7861:7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose 绀轰緥** +```yaml +version: '3.8' + +services: + mongodb: + image: mongo:7 + container_name: gcli2api-mongodb + restart: unless-stopped + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password123 + volumes: + - mongodb_data:/data/db + ports: + - "27017:27017" + + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + depends_on: + - mongodb + environment: + - MONGODB_URI=mongodb://admin:password123@mongodb:27017/admin + - MONGODB_DATABASE=gcli2api + - API_PASSWORD=your_api_password + - PORT=7861 + ports: + - "7861:7861" + +volumes: + mongodb_data: +``` + + +### 馃敡 楂樼骇閰嶇疆 + +**MongoDB 杩炴帴浼樺寲** +```bash +# 杩炴帴姹犲拰瓒呮椂閰嶇疆 +export MONGODB_URI="mongodb://localhost:27017?maxPoolSize=10&serverSelectionTimeoutMS=5000" + +# 鍓湰闆嗛厤缃? +export MONGODB_URI="mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=myReplicaSet" + +# 璇诲啓鍒嗙閰嶇疆 +export MONGODB_URI="mongodb://localhost:27017/gcli2api?readPreference=secondaryPreferred" +``` + +### 鐜鍙橀噺閰嶇疆 + +**鍩虹閰嶇疆** +- `PORT`: 鏈嶅姟绔彛锛堥粯璁わ細7861锛? +- `HOST`: 鏈嶅姟鍣ㄧ洃鍚湴鍧€锛堥粯璁わ細0.0.0.0锛? + +**瀵嗙爜閰嶇疆** +- `API_PASSWORD`: 鑱婂ぉ API 璁块棶瀵嗙爜锛堥粯璁わ細缁ф壙 PASSWORD 鎴?pwd锛? +- `PANEL_PASSWORD`: 鎺у埗闈㈡澘璁块棶瀵嗙爜锛堥粯璁わ細缁ф壙 PASSWORD 鎴?pwd锛? +- `PASSWORD`: 閫氱敤瀵嗙爜锛岃缃悗瑕嗙洊涓婅堪涓や釜锛堥粯璁わ細pwd锛? + +**鎬ц兘鍜岀ǔ瀹氭€ч厤缃?* +- `RETRY_429_ENABLED`: 鍚敤 429 閿欒鑷姩閲嶈瘯锛堥粯璁わ細true锛? +- `RETRY_429_MAX_RETRIES`: 429 閿欒鏈€澶ч噸璇曟鏁帮紙榛樿锛?锛? +- `RETRY_429_INTERVAL`: 429 閿欒閲嶈瘯闂撮殧锛岀锛堥粯璁わ細1.0锛? +- `ANTI_TRUNCATION_MAX_ATTEMPTS`: 鎶楁埅鏂渶澶ч噸璇曟鏁帮紙榛樿锛?锛? + +**缃戠粶鍜屼唬鐞嗛厤缃?* +- `PROXY`: HTTP/HTTPS 浠g悊鍦板潃锛堟牸寮忥細`http://host:port`锛? +- `OAUTH_PROXY_URL`: OAuth 璁よ瘉浠g悊绔偣 +- `GOOGLEAPIS_PROXY_URL`: Google APIs 浠g悊绔偣 +- `METADATA_SERVICE_URL`: 鍏冩暟鎹湇鍔′唬鐞嗙鐐? + +**鑷姩鍖栭厤缃?* +- `AUTO_BAN`: 鍚敤鍑瘉鑷姩灏佺锛堥粯璁わ細true锛? +- `AUTO_LOAD_ENV_CREDS`: 鍚姩鏃惰嚜鍔ㄥ姞杞界幆澧冨彉閲忓嚟璇侊紙榛樿锛歠alse锛? + +**鍏煎鎬ч厤缃?* +- `COMPATIBILITY_MODE`: 鍚敤鍏煎鎬фā寮忥紝灏?system 娑堟伅杞负 user 娑堟伅锛堥粯璁わ細false锛? + +**鏃ュ織閰嶇疆** +- `LOG_LEVEL`: 鏃ュ織绾у埆锛圖EBUG/INFO/WARNING/ERROR锛岄粯璁わ細INFO锛? +- `LOG_FILE`: 鏃ュ織鏂囦欢璺緞锛堥粯璁わ細log.txt锛? + +**瀛樺偍閰嶇疆** + +**SQLite 閰嶇疆锛堥粯璁わ級** +- 鏃犻渶閰嶇疆锛岃嚜鍔ㄤ娇鐢ㄦ湰鍦?SQLite 鏁版嵁搴? +- 鏁版嵁搴撴枃浠惰嚜鍔ㄥ垱寤哄湪椤圭洰鐩綍 + +**MongoDB 閰嶇疆锛堝彲閫変簯绔瓨鍌級** +- `MONGODB_URI`: MongoDB 杩炴帴瀛楃涓诧紙璁剧疆鍚庡惎鐢?MongoDB 妯″紡锛? +- `MONGODB_DATABASE`: MongoDB 鏁版嵁搴撳悕绉帮紙榛樿锛歡cli2api锛? + +**Docker 浣跨敤绀轰緥** +```bash +# 浣跨敤閫氱敤瀵嗙爜 +docker run -d --name gcli2api \ + -e PASSWORD=mypassword \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest + +# 浣跨敤鍒嗙瀵嗙爜 +docker run -d --name gcli2api \ + -e API_PASSWORD=my_api_password \ + -e PANEL_PASSWORD=my_panel_password \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +娉ㄦ剰锛氬綋璁剧疆浜嗗嚟璇佺幆澧冨彉閲忔椂锛岀郴缁熷皢浼樺厛浣跨敤鐜鍙橀噺涓殑鍑瘉锛屽拷鐣?`creds` 鐩綍涓殑鏂囦欢銆? + +### API 浣跨敤鏂瑰紡 + +鏈湇鍔℃敮鎸佷笁濂楀畬鏁寸殑 API 绔偣锛? + +#### 1. OpenAI 鍏煎绔偣锛圙CLI锛? + +**绔偣锛?* `/v1/chat/completions` +**璁よ瘉锛?* `Authorization: Bearer your_api_password` + +鏀寔涓ょ璇锋眰鏍煎紡锛屼細鑷姩妫€娴嬪苟澶勭悊锛? + +**OpenAI 鏍煎紡锛?* +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "stream": true +} +``` + +**Gemini 鍘熺敓鏍煎紡锛?* +```json +{ + "model": "gemini-2.5-pro", + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ], + "systemInstruction": {"parts": [{"text": "You are a helpful assistant"}]}, + "generationConfig": { + "temperature": 0.7 + } +} +``` + +#### 2. Gemini 鍘熺敓绔偣锛圙CLI锛? + +**闈炴祦寮忕鐐癸細** `/v1/models/{model}:generateContent` +**娴佸紡绔偣锛?* `/v1/models/{model}:streamGenerateContent` +**妯″瀷鍒楄〃锛?* `/v1/models` + +**璁よ瘉鏂瑰紡锛堜换閫変竴绉嶏級锛?* +- `Authorization: Bearer your_api_password` +- `x-goog-api-key: your_api_password` +- URL 鍙傛暟锛歚?key=your_api_password` + +**璇锋眰绀轰緥锛?* +```bash +# 浣跨敤 x-goog-api-key 澶撮儴 +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:generateContent" \ + -H "x-goog-api-key: your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' + +# 浣跨敤 URL 鍙傛暟 +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:streamGenerateContent?key=your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' +``` + +#### 3. Claude API 鏍煎紡绔偣 + +**绔偣锛?* `/v1/messages` +**璁よ瘉锛?* `x-api-key: your_api_password` 鎴?`Authorization: Bearer your_api_password` + +**璇锋眰绀轰緥锛?* +```bash +curl -X POST "http://127.0.0.1:7861/v1/messages" \ + -H "x-api-key: your_api_password" \ + -H "anthropic-version: 2023-06-01" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, Claude!"} + ] + }' +``` + +**鏀寔 system 鍙傛暟锛?* +```json +{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "system": "You are a helpful assistant", + "messages": [ + {"role": "user", "content": "Hello"} + ] +} +``` + +**璇存槑锛?* +- 瀹屽叏鍏煎 Claude API 鏍煎紡瑙勮寖 +- 鑷姩杞崲涓?Gemini 鏍煎紡璋冪敤鍚庣 +- 鏀寔 Claude 鐨勬墍鏈夋爣鍑嗗弬鏁? +- 鍝嶅簲鏍煎紡绗﹀悎 Claude API 瑙勮寖 + +## 馃搵 瀹屾暣 API 鍙傝€? + +### Web 鎺у埗鍙?API + +**璁よ瘉绔偣** +- `POST /auth/login` - 鐢ㄦ埛鐧诲綍 +- `POST /auth/start` - 寮€濮?OAuth 璁よ瘉锛堟敮鎸?GCLI 鍜?Antigravity 妯″紡锛? +- `POST /auth/callback` - 澶勭悊 OAuth 鍥炶皟 +- `POST /auth/callback-url` - 浠庡洖璋?URL 鐩存帴瀹屾垚璁よ瘉 +- `GET /auth/status/{project_id}` - 妫€鏌ヨ璇佺姸鎬? + +**鍑瘉绠$悊绔偣**锛堟敮鎸?`mode=geminicli` 鎴?`mode=antigravity` 鍙傛暟锛? +- `POST /creds/upload` - 鎵归噺涓婁紶鍑瘉鏂囦欢锛堟敮鎸?JSON 鍜?ZIP锛? +- `GET /creds/status` - 鑾峰彇鍑瘉鐘舵€佸垪琛紙鏀寔鍒嗛〉鍜岀瓫閫夛級 +- `GET /creds/detail/{filename}` - 鑾峰彇鍗曚釜鍑瘉璇︽儏 +- `POST /creds/action` - 鍗曚釜鍑瘉鎿嶄綔锛堝惎鐢?绂佺敤/鍒犻櫎锛? +- `POST /creds/batch-action` - 鎵归噺鍑瘉鎿嶄綔 +- `GET /creds/download/{filename}` - 涓嬭浇鍗曚釜鍑瘉鏂囦欢 +- `GET /creds/download-all` - 鎵撳寘涓嬭浇鎵€鏈夊嚟璇? +- `POST /creds/fetch-email/{filename}` - 鑾峰彇鐢ㄦ埛閭 +- `POST /creds/refresh-all-emails` - 鎵归噺鍒锋柊鐢ㄦ埛閭 +- `POST /creds/deduplicate-by-email` - 鎸夐偖绠卞幓閲嶅嚟璇? +- `POST /creds/verify-project/{filename}` - 妫€楠屽嚟璇?Project ID +- `GET /creds/quota/{filename}` - 鑾峰彇鍑瘉棰濆害淇℃伅锛堜粎 Antigravity锛? + +**閰嶇疆绠$悊绔偣** +- `GET /config/get` - 鑾峰彇褰撳墠閰嶇疆 +- `POST /config/save` - 淇濆瓨閰嶇疆 + +**鏃ュ織绠$悊绔偣** +- `POST /logs/clear` - 娓呯┖鏃ュ織 +- `GET /logs/download` - 涓嬭浇鏃ュ織鏂囦欢 +- `WebSocket /logs/stream` - 瀹炴椂鏃ュ織娴? + +**鐗堟湰淇℃伅绔偣** +- `GET /version/info` - 鑾峰彇鐗堟湰淇℃伅锛堝彲閫?`check_update=true` 鍙傛暟妫€鏌ユ洿鏂帮級 + +### 鑱婂ぉ API 鍔熻兘鐗规€? + +**澶氭ā鎬佹敮鎸?* +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "鎻忚堪杩欏紶鍥剧墖"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABA..." + } + } + ] + } + ] +} +``` + +**鎬濈淮妯″紡鏀寔** +```json +{ + "model": "gemini-2.5-pro-maxthinking", + "messages": [ + {"role": "user", "content": "澶嶆潅鏁板闂"} + ] +} +``` + +鍝嶅簲灏嗗寘鍚垎绂荤殑鎬濈淮鍐呭锛? +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "鏈€缁堢瓟妗?, + "reasoning_content": "璇︾粏鐨勬€濊€冭繃绋?.." + } + }] +} +``` + +**娴佸紡鎶楁埅鏂娇鐢?* +```json +{ + "model": "娴佸紡鎶楁埅鏂?gemini-2.5-pro", + "messages": [ + {"role": "user", "content": "鍐欎竴绡囬暱鏂囩珷"} + ], + "stream": true +} +``` + +**鍏煎鎬фā寮?* +```bash +# 鍚敤鍏煎鎬фā寮? +export COMPATIBILITY_MODE=true +``` +姝ゆā寮忎笅锛屾墍鏈?`system` 娑堟伅浼氳浆鎹负 `user` 娑堟伅锛屾彁楂樹笌鏌愪簺瀹㈡埛绔殑鍏煎鎬с€? + +--- + +## 馃挰 浜ゆ祦缇? + +娆㈣繋鍔犲叆 QQ 缇や氦娴佽璁猴紒 + +**QQ 缇ゅ彿锛?083250744** + +QQ缇や簩缁寸爜 + +--- + +## 璁稿彲璇佷笌鍏嶈矗澹版槑 + +鏈」鐩粎渚涘涔犲拰鐮旂┒鐢ㄩ€斻€備娇鐢ㄦ湰椤圭洰琛ㄧず鎮ㄥ悓鎰忥細 +- 涓嶅皢鏈」鐩敤浜庝换浣曞晢涓氱敤閫? +- 鎵挎媴浣跨敤鏈」鐩殑鎵€鏈夐闄╁拰璐d换 +- 閬靛畧鐩稿叧鐨勬湇鍔℃潯娆惧拰娉曞緥娉曡 + +椤圭洰浣滆€呭鍥犱娇鐢ㄦ湰椤圭洰鑰屼骇鐢熺殑浠讳綍鐩存帴鎴栭棿鎺ユ崯澶变笉鎵挎媴璐d换銆? diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7d40554db2a9af4e3a33e2fe612011c206fc8f --- /dev/null +++ b/config.py @@ -0,0 +1,457 @@ +""" +Configuration constants for the Geminicli2api proxy server. +Centralizes all configuration to avoid duplication across modules. + +- 启动时加载一次配置到内存 +- 修改配置时调用 reload_config() 重新从数据库加载 +""" + +import os +from typing import Any, Optional + +# 全局配置缓存 +_config_cache: dict[str, Any] = {} +_config_initialized = False + +# Client Configuration + +# 需要自动封禁的错误码 (默认值,可通过环境变量或配置覆盖) +AUTO_BAN_ERROR_CODES = [403] + +# ====================== 环境变量映射表 ====================== +# 统一维护环境变量名和配置键名的映射关系 +# 格式: "环境变量名": "配置键名" +ENV_MAPPINGS = { + "CODE_ASSIST_ENDPOINT": "code_assist_endpoint", + "CREDENTIALS_DIR": "credentials_dir", + "PROXY": "proxy", + "OAUTH_PROXY_URL": "oauth_proxy_url", + "GOOGLEAPIS_PROXY_URL": "googleapis_proxy_url", + "RESOURCE_MANAGER_API_URL": "resource_manager_api_url", + "SERVICE_USAGE_API_URL": "service_usage_api_url", + "AUTO_BAN": "auto_ban_enabled", + "AUTO_BAN_ERROR_CODES": "auto_ban_error_codes", + "RETRY_429_MAX_RETRIES": "retry_429_max_retries", + "RETRY_429_ENABLED": "retry_429_enabled", + "RETRY_429_INTERVAL": "retry_429_interval", + "ANTI_TRUNCATION_MAX_ATTEMPTS": "anti_truncation_max_attempts", + "COMPATIBILITY_MODE": "compatibility_mode_enabled", + "RETURN_THOUGHTS_TO_FRONTEND": "return_thoughts_to_frontend", + "ANTIGRAVITY_STREAM2NOSTREAM": "antigravity_stream2nostream", + "HOST": "host", + "PORT": "port", + "API_PASSWORD": "api_password", + "PANEL_PASSWORD": "panel_password", + "PASSWORD": "password", + "KEEPALIVE_URL": "keepalive_url", + "KEEPALIVE_INTERVAL": "keepalive_interval", +} + + +# ====================== 配置系统 ====================== + +async def init_config(): + """初始化配置缓存(启动时调用一次)""" + global _config_cache, _config_initialized + + if _config_initialized: + return + + try: + from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() + _config_cache = await storage_adapter.get_all_config() + _config_initialized = True + except Exception: + # 初始化失败时使用空缓存 + _config_cache = {} + _config_initialized = True + + +async def reload_config(): + """重新加载配置(修改配置后调用)""" + global _config_cache, _config_initialized + + try: + from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() + + # 如果后端支持 reload_config_cache,调用它 + if hasattr(storage_adapter._backend, 'reload_config_cache'): + await storage_adapter._backend.reload_config_cache() + + # 重新加载配置缓存 + _config_cache = await storage_adapter.get_all_config() + _config_initialized = True + except Exception: + pass + + +def _get_cached_config(key: str, default: Any = None) -> Any: + """从内存缓存获取配置(同步)""" + return _config_cache.get(key, default) + + +async def get_config_value(key: str, default: Any = None, env_var: Optional[str] = None) -> Any: + """Get configuration value with priority: ENV > Storage > default.""" + # 确保配置已初始化 + if not _config_initialized: + await init_config() + + # Priority 1: Environment variable + if env_var and os.getenv(env_var): + return os.getenv(env_var) + + # Priority 2: Memory cache + value = _get_cached_config(key) + if value is not None: + return value + + return default + + +# Configuration getters - all async +async def get_proxy_config(): + """Get proxy configuration.""" + proxy_url = await get_config_value("proxy", env_var="PROXY") + return proxy_url if proxy_url else None + + +async def get_auto_ban_enabled() -> bool: + """Get auto ban enabled setting.""" + env_value = os.getenv("AUTO_BAN") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("auto_ban_enabled", False)) + + +async def get_auto_ban_error_codes() -> list: + """ + Get auto ban error codes. + + Environment variable: AUTO_BAN_ERROR_CODES (comma-separated, e.g., "400,403") + Database config key: auto_ban_error_codes + Default: [400, 403] + """ + env_value = os.getenv("AUTO_BAN_ERROR_CODES") + if env_value: + try: + return [int(code.strip()) for code in env_value.split(",") if code.strip()] + except ValueError: + pass + + codes = await get_config_value("auto_ban_error_codes") + if codes and isinstance(codes, list): + return codes + return AUTO_BAN_ERROR_CODES + + +async def get_retry_429_max_retries() -> int: + """Get max retries for 429 errors.""" + env_value = os.getenv("RETRY_429_MAX_RETRIES") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("retry_429_max_retries", 5)) + + +async def get_retry_429_enabled() -> bool: + """Get 429 retry enabled setting.""" + env_value = os.getenv("RETRY_429_ENABLED") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("retry_429_enabled", True)) + + +async def get_retry_429_interval() -> float: + """Get 429 retry interval in seconds.""" + env_value = os.getenv("RETRY_429_INTERVAL") + if env_value: + try: + return float(env_value) + except ValueError: + pass + + return float(await get_config_value("retry_429_interval", 1)) + + +async def get_anti_truncation_max_attempts() -> int: + """ + Get maximum attempts for anti-truncation continuation. + + Environment variable: ANTI_TRUNCATION_MAX_ATTEMPTS + Database config key: anti_truncation_max_attempts + Default: 3 + """ + env_value = os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("anti_truncation_max_attempts", 3)) + + +# Server Configuration +async def get_server_host() -> str: + """ + Get server host setting. + + Environment variable: HOST + Database config key: host + Default: 0.0.0.0 + """ + return str(await get_config_value("host", "0.0.0.0", "HOST")) + + +async def get_server_port() -> int: + """ + Get server port setting. + + Environment variable: PORT + Database config key: port + Default: 7861 + """ + env_value = os.getenv("PORT") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("port", 7861)) + + +async def get_api_password() -> str: + """ + Get API password setting for chat endpoints. + + Environment variable: API_PASSWORD + Database config key: api_password + Default: Uses PASSWORD env var for compatibility, otherwise 'pwd' + """ + # 优先使用 API_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性 + api_password = await get_config_value("api_password", None, "API_PASSWORD") + if api_password is not None: + return str(api_password) + + # 兼容性:使用通用密码 + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_panel_password() -> str: + """ + Get panel password setting for web interface. + + Environment variable: PANEL_PASSWORD + Database config key: panel_password + Default: Uses PASSWORD env var for compatibility, otherwise 'pwd' + """ + # 优先使用 PANEL_PASSWORD,如果没有则使用通用 PASSWORD 保证兼容性 + panel_password = await get_config_value("panel_password", None, "PANEL_PASSWORD") + if panel_password is not None: + return str(panel_password) + + # 兼容性:使用通用密码 + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_server_password() -> str: + """ + Get server password setting (deprecated, use get_api_password or get_panel_password). + + Environment variable: PASSWORD + Database config key: password + Default: pwd + """ + return str(await get_config_value("password", "pwd", "PASSWORD")) + + +async def get_credentials_dir() -> str: + """ + Get credentials directory setting. + + Environment variable: CREDENTIALS_DIR + Database config key: credentials_dir + Default: ./creds + """ + return str(await get_config_value("credentials_dir", "./creds", "CREDENTIALS_DIR")) + + +async def get_code_assist_endpoint() -> str: + """ + Get Code Assist endpoint setting. + + Environment variable: CODE_ASSIST_ENDPOINT + Database config key: code_assist_endpoint + Default: https://cloudcode-pa.googleapis.com + """ + return str( + await get_config_value( + "code_assist_endpoint", "https://cloudcode-pa.googleapis.com", "CODE_ASSIST_ENDPOINT" + ) + ) + + +async def get_compatibility_mode_enabled() -> bool: + """ + Get compatibility mode setting. + + 兼容性模式:启用后所有system消息全部转换成user,停用system_instructions。 + 该选项可能会降低模型理解能力,但是能避免流式空回的情况。 + + Environment variable: COMPATIBILITY_MODE + Database config key: compatibility_mode_enabled + Default: False + """ + env_value = os.getenv("COMPATIBILITY_MODE") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("compatibility_mode_enabled", False)) + + +async def get_return_thoughts_to_frontend() -> bool: + """ + Get return thoughts to frontend setting. + + 控制是否将思维链返回到前端。 + 启用后,思维链会在响应中返回;禁用后,思维链会在响应中被过滤掉。 + + Environment variable: RETURN_THOUGHTS_TO_FRONTEND + Database config key: return_thoughts_to_frontend + Default: True + """ + env_value = os.getenv("RETURN_THOUGHTS_TO_FRONTEND") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("return_thoughts_to_frontend", True)) + + +async def get_antigravity_stream2nostream() -> bool: + """ + Get use stream for non-stream setting. + + 控制antigravity非流式请求是否使用流式API并收集为完整响应。 + 启用后,非流式请求将在后端使用流式API,然后收集所有块后再返回完整响应。 + + Environment variable: ANTIGRAVITY_STREAM2NOSTREAM + Database config key: antigravity_stream2nostream + Default: True + """ + env_value = os.getenv("ANTIGRAVITY_STREAM2NOSTREAM") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("antigravity_stream2nostream", True)) + + +async def get_oauth_proxy_url() -> str: + """ + Get OAuth proxy URL setting. + + 用于Google OAuth2认证的代理URL。 + + Environment variable: OAUTH_PROXY_URL + Database config key: oauth_proxy_url + Default: https://oauth2.googleapis.com + """ + return str( + await get_config_value( + "oauth_proxy_url", "https://oauth2.googleapis.com", "OAUTH_PROXY_URL" + ) + ) + + +async def get_googleapis_proxy_url() -> str: + """ + Get Google APIs proxy URL setting. + + 用于Google APIs调用的代理URL。 + + Environment variable: GOOGLEAPIS_PROXY_URL + Database config key: googleapis_proxy_url + Default: https://www.googleapis.com + """ + return str( + await get_config_value( + "googleapis_proxy_url", "https://www.googleapis.com", "GOOGLEAPIS_PROXY_URL" + ) + ) + + +async def get_resource_manager_api_url() -> str: + """ + Get Google Cloud Resource Manager API URL setting. + + 用于Google Cloud Resource Manager API的URL。 + + Environment variable: RESOURCE_MANAGER_API_URL + Database config key: resource_manager_api_url + Default: https://cloudresourcemanager.googleapis.com + """ + return str( + await get_config_value( + "resource_manager_api_url", + "https://cloudresourcemanager.googleapis.com", + "RESOURCE_MANAGER_API_URL", + ) + ) + + +async def get_service_usage_api_url() -> str: + """ + Get Google Cloud Service Usage API URL setting. + + 用于Google Cloud Service Usage API的URL。 + + Environment variable: SERVICE_USAGE_API_URL + Database config key: service_usage_api_url + Default: https://serviceusage.googleapis.com + """ + return str( + await get_config_value( + "service_usage_api_url", "https://serviceusage.googleapis.com", "SERVICE_USAGE_API_URL" + ) + ) + + +async def get_keepalive_url() -> str: + """ + Get keep-alive URL setting. + + 配置后保活服务会定期向该URL发送GET请求。 + 留空表示禁用保活服务。 + + Environment variable: KEEPALIVE_URL + Database config key: keepalive_url + Default: "" (disabled) + """ + return str(await get_config_value("keepalive_url", "", "KEEPALIVE_URL")) + + +async def get_keepalive_interval() -> int: + """ + Get keep-alive interval in seconds. + + 保活请求发送间隔(秒)。 + + Environment variable: KEEPALIVE_INTERVAL + Database config key: keepalive_interval + Default: 60 + """ + env_value = os.getenv("KEEPALIVE_INTERVAL") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("keepalive_interval", 60)) diff --git a/darwin-install.sh b/darwin-install.sh new file mode 100644 index 0000000000000000000000000000000000000000..91681f26e2f240ad6495802aa0364b8a6768289f --- /dev/null +++ b/darwin-install.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# macOS 安装脚本 (支持 Intel 和 Apple Silicon) + +# 确保 Homebrew 已安装 +if ! command -v brew &> /dev/null; then + echo "未检测到 Homebrew,开始安装..." + /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" + + # 检测 Homebrew 安装路径并设置环境变量 + if [[ -f "/opt/homebrew/bin/brew" ]]; then + # Apple Silicon Mac + eval "$(/opt/homebrew/bin/brew shellenv)" + elif [[ -f "/usr/local/bin/brew" ]]; then + # Intel Mac + eval "$(/usr/local/bin/brew shellenv)" + fi +fi + +# 更新 brew 并安装 git +brew update +brew install git + +# 安装 uv (Python 环境管理工具) +curl -Ls https://astral.sh/uv/install.sh | sh + +# 确保 uv 在 PATH 中 +export PATH="$HOME/.local/bin:$PATH" + +# 克隆或进入项目目录 +if [ -f "./web.py" ]; then + # 已经在目标目录 + : +elif [ -f "./gcli2api/web.py" ]; then + cd ./gcli2api +else + git clone https://github.com/su-kaka/gcli2api.git + cd ./gcli2api +fi + +# 拉取最新代码 +git pull + +# 创建并同步虚拟环境 +uv sync + +# 激活虚拟环境 +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +else + echo "❌ 未找到虚拟环境,请检查 uv 是否安装成功" + exit 1 +fi + +# 启动项目 +python3 web.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..d68b8eef43e5ab6e87f3629790785c3099bdf96f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,79 @@ +version: '3.8' + +services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # Password configuration (choose one) + # Option 1: Use common password + - PASSWORD=${PASSWORD:-pwd} + # Option 2: Use separate passwords (uncomment if needed) + # - API_PASSWORD=${API_PASSWORD:-your_api_password} + # - PANEL_PASSWORD=${PANEL_PASSWORD:-your_panel_password} + + # Server configuration + - PORT=${PORT:-7861} + - HOST=${HOST:-0.0.0.0} + + # Optional: Google credentials from environment + # - GOOGLE_CREDENTIALS=${GOOGLE_CREDENTIALS} + + # Optional: Logging configuration + # - LOG_LEVEL=${LOG_LEVEL:-info} + + # Optional: Redis configuration (for distributed storage) + # - REDIS_URI=${REDIS_URI} + # - REDIS_DATABASE=${REDIS_DATABASE:-0} + + # Optional: MongoDB configuration (for distributed storage) + # - MONGODB_URI=${MONGODB_URI} + # - MONGODB_DATABASE=${MONGODB_DATABASE:-gcli2api} + + # Optional: PostgreSQL configuration (for distributed storage) + # - POSTGRES_DSN=${POSTGRES_DSN} + + # Optional: Proxy configuration + # - PROXY=${PROXY} + volumes: + - ./data/creds:/app/creds + +# Example with Redis for distributed storage +# redis: +# image: redis:7-alpine +# container_name: gcli2api-redis +# restart: unless-stopped +# ports: +# - "6379:6379" +# volumes: +# - redis_data:/data +# command: redis-server --appendonly yes +# healthcheck: +# test: ["CMD", "redis-cli", "ping"] +# interval: 10s +# timeout: 3s +# retries: 3 + +# Example with MongoDB for distributed storage +# mongodb: +# image: mongo:7 +# container_name: gcli2api-mongodb +# restart: unless-stopped +# environment: +# MONGO_INITDB_ROOT_USERNAME: ${MONGO_USER:-admin} +# MONGO_INITDB_ROOT_PASSWORD: ${MONGO_PASSWORD:-password} +# ports: +# - "27017:27017" +# volumes: +# - mongodb_data:/data/db +# healthcheck: +# test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] +# interval: 10s +# timeout: 5s +# retries: 3 + +#volumes: +# redis_data: +# mongodb_data: diff --git a/docs/README_EN.md b/docs/README_EN.md new file mode 100644 index 0000000000000000000000000000000000000000..19787ea59c0ae1f33fe383dab55d14c76c001b37 --- /dev/null +++ b/docs/README_EN.md @@ -0,0 +1,760 @@ +# GeminiCLI to API + +**Convert GeminiCLI and Antigravity to OpenAI, GEMINI, and Claude API Compatible Interfaces** + +[![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/) +[![License: CNC-1.0](https://img.shields.io/badge/License-CNC--1.0-red.svg)](../LICENSE) +[![Docker](https://img.shields.io/badge/docker-available-blue.svg)](https://github.com/su-kaka/gcli2api/pkgs/container/gcli2api) + +[中文](../README.md) | English | [日本語](./README_JA.md) + +## 🚀 Quick Deploy + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/97VMEF?referralCode=sukaka) +[![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/su-kaka/gcli2api) +--- + +## ⚠️ License Declaration + +**This project is licensed under the Cooperative Non-Commercial License (CNC-1.0)** + +This is a strict anti-commercial open source license. Please refer to the [LICENSE](../LICENSE) file for details. + +### ✅ Permitted Uses: +- Personal learning, research, and educational purposes +- Non-profit organization use +- Open source project integration (must comply with the same license) +- Academic research and publication + +### ❌ Prohibited Uses: +- Any form of commercial use +- Enterprise use with annual revenue exceeding $1 million +- Venture capital-backed or publicly traded companies +- Providing paid services or products +- Commercial competitive use + +## Core Features + +### 🔄 API Endpoints and Format Support + +**Multi-endpoint Multi-format Support** +- **OpenAI Compatible Endpoints**: `/v1/chat/completions` and `/v1/models` + - Supports standard OpenAI format (messages structure) + - Supports Gemini native format (contents structure) + - Automatic format detection and conversion, no manual switching required + - Supports multimodal input (text + images) +- **Gemini Native Endpoints**: `/v1/models/{model}:generateContent` and `streamGenerateContent` + - Supports complete Gemini native API specifications + - Multiple authentication methods: Bearer Token, x-goog-api-key header, URL parameter key +- **Claude Format Compatibility**: Full support for Claude API format + - Endpoint: `/v1/messages` (follows Claude API specification) + - Supports Claude standard messages format + - Supports system parameter and Claude-specific features + - Automatically converts to backend-supported format +- **Antigravity API Support**: Supports OpenAI, Gemini, and Claude formats + - OpenAI format endpoint: `/antigravity/v1/chat/completions` + - Gemini format endpoint: `/antigravity/v1/models/{model}:generateContent` and `streamGenerateContent` + - Claude format endpoint: `/antigravity/v1/messages` + - Supports all Antigravity models (Claude, Gemini, etc.) + - Automatic model name mapping and thinking mode detection + +### 🔐 Authentication and Security Management + +**Flexible Password Management** +- **Separate Password Support**: API password (chat endpoints) and control panel password can be set independently +- **Multiple Authentication Methods**: Supports Authorization Bearer, x-goog-api-key header, URL parameters, etc. +- **JWT Token Authentication**: Control panel supports JWT token authentication +- **User Email Retrieval**: Automatically retrieves and displays Google account email addresses + +### 📊 Intelligent Credential Management System + +**Advanced Credential Management** +- Multiple Google OAuth credential automatic rotation +- Enhanced stability through redundant authentication +- Load balancing and concurrent request support +- Automatic failure detection and credential disabling +- Credential usage statistics and quota management +- Support for manual enable/disable credential files +- Batch credential file operations (enable, disable, delete) + +**Credential Status Monitoring** +- Real-time credential health checks +- Error code tracking (429, 403, 500, etc.) +- Automatic banning mechanism (configurable) + +### 🌊 Streaming and Response Processing + +**Multiple Streaming Support** +- True real-time streaming responses +- Fake streaming mode (for compatibility) +- Streaming anti-truncation feature (prevents answer truncation) +- Asynchronous task management and timeout handling + +**Response Optimization** +- Thinking chain content separation +- Reasoning process (reasoning_content) handling +- Multi-turn conversation context management +- Compatibility mode (converts system messages to user messages) + +### 🎛️ Web Management Console + +**Full-featured Web Interface** +- OAuth authentication flow management (supports GCLI and Antigravity dual modes) +- Credential file upload, download, and management +- Real-time log viewing (WebSocket) +- System configuration management +- Usage statistics and monitoring dashboard +- Mobile-friendly interface + +**Batch Operation Support** +- ZIP file batch credential upload (GCLI and Antigravity) +- Batch enable/disable/delete credentials +- Batch user email retrieval +- Batch configuration management +- Unified batch upload interface for all credential types + +### 📈 Usage Monitoring + +**Real-time Monitoring** +- WebSocket real-time log streams +- System status monitoring +- Credential health status + +### 🔧 Advanced Configuration and Customization + +**Network and Proxy Configuration** +- HTTP/HTTPS proxy support +- Proxy endpoint configuration (OAuth, Google APIs, metadata service) +- Timeout and retry configuration +- Network error handling and recovery + +**Performance and Stability Configuration** +- 429 error automatic retry (configurable interval and attempts) +- Anti-truncation maximum retry attempts + +**Logging and Debugging** +- Multi-level logging system (DEBUG, INFO, WARNING, ERROR) +- Log file management +- Real-time log streams +- Log download and clearing + +### 🔄 Environment Variables and Configuration Management + +**Flexible Configuration Methods** +- Environment variable configuration +- Hot configuration updates (partial configuration items) +- Configuration locking (environment variable priority) + +## Supported Models + +All models have 1M context window capacity. Each credential file provides 1000 request quota. + +### 🤖 Base Models +- `gemini-2.5-pro` +- `gemini-3-pro-preview` +- `gemini-3.1-pro-preview` + +### 🧠 Thinking Models +- `gemini-2.5-pro-high`: Thinking mode +- `gemini-2.5-pro-low`: Low thinking mode +- Supports custom thinking budget configuration +- Automatic separation of thinking content and final answers + +### 🔍 Search-Enhanced Models +- `gemini-2.5-pro-search`: Model with integrated search functionality + +### 🖼️ Image Generation Models (Antigravity) +- `gemini-3.1-flash-image`: Base image generation model +- **Resolution Suffixes**: + - `-2k`: 2K resolution + - `-4k`: 4K HD resolution +- **Aspect Ratio Suffixes**: + - `-1x1`: Square (avatar) + - `-16x9`: Landscape (desktop wallpaper) + - `-9x16`: Portrait (mobile wallpaper) + - `-21x9`: Ultra-wide (ultrawide monitor) + - `-4x3`: Traditional display + - `-3x4`: Portrait poster +- **Combination Examples**: + - `gemini-3.1-flash-image-4k-16x9`: 4K landscape + - `gemini-3.1-flash-image-2k-9x16`: 2K portrait +- When no ratio is specified, the API automatically decides the aspect ratio + +### 🌊 Special Feature Variants +- **Fake Streaming Mode**: Add `-假流式` suffix to any model name + - Example: `gemini-2.5-pro-假流式` + - For scenarios requiring streaming responses but server doesn't support true streaming +- **Streaming Anti-truncation Mode**: Add `流式抗截断/` prefix to model name + - Example: `流式抗截断/gemini-2.5-pro` + - Automatically detects response truncation and retries to ensure complete answers + +### 🔧 Automatic Model Feature Detection +- System automatically recognizes feature identifiers in model names +- Transparently handles feature mode transitions +- Supports feature combination usage + + +--- + +## Installation Guide + +### Termux Environment + +**Initial Installation** +```bash +curl -o termux-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/termux-install.sh" && chmod +x termux-install.sh && ./termux-install.sh +``` + +**Restart Service** +```bash +cd gcli2api +bash termux-start.sh +``` + +### Windows Environment + +**Initial Installation** +```powershell +iex (iwr "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.ps1" -UseBasicParsing).Content +``` + +**Restart Service** +Double-click to execute `start.bat` + +### Linux Environment + +**Initial Installation** +```bash +curl -o install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.sh" && chmod +x install.sh && ./install.sh +``` + +**Restart Service** +```bash +cd gcli2api +bash start.sh +``` + +### macOS Environment + +**Initial Installation** +```bash +curl -o darwin-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/darwin-install.sh" && chmod +x darwin-install.sh && ./darwin-install.sh +``` + +**Restart Service** +```bash +cd gcli2api +bash start.sh +``` + +### Docker Environment + +**Docker Run Command** +```bash +# Using universal password +docker run -d --name gcli2api --network host -e PASSWORD=pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest + +# Using separate passwords +docker run -d --name gcli2api --network host -e API_PASSWORD=api_pwd -e PANEL_PASSWORD=panel_pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Mac** +```bash +# Using universal password +docker run -d \ + --name gcli2api \ + -p 7861:7861 \ + -p 8080:8080 \ + -e PASSWORD=pwd \ + -e PORT=7861 \ + -v "$(pwd)/data/creds":/app/creds \ + ghcr.io/su-kaka/gcli2api:latest +``` + +```bash +# Using separate passwords +docker run -d \ +--name gcli2api \ +-p 7861:7861 \ +-p 8080:8080 \ +-e API_PASSWORD=api_pwd \ +-e PANEL_PASSWORD=panel_pwd \ +-e PORT=7861 \ +-v $(pwd)/data/creds:/app/creds \ +ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose Run Command** +1. Save the following content as `docker-compose.yml` file: + ```yaml + version: '3.8' + + services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # Using universal password (recommended for simple deployment) + - PASSWORD=pwd + - PORT=7861 + # Or use separate passwords (recommended for production) + # - API_PASSWORD=your_api_password + # - PANEL_PASSWORD=your_panel_password + volumes: + - ./data/creds:/app/creds + healthcheck: + test: ["CMD-SHELL", "python -c \"import sys, urllib.request, os; port = os.environ.get('PORT', '7861'); req = urllib.request.Request(f'http://localhost:{port}/v1/models', headers={'Authorization': 'Bearer ' + os.environ.get('PASSWORD', 'pwd')}); sys.exit(0 if urllib.request.urlopen(req, timeout=5).getcode() == 200 else 1)\""] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + ``` +2. Start the service: + ```bash + docker-compose up -d + ``` + +--- + +## Configuration Instructions + +1. Visit `http://127.0.0.1:7861` (default port, modifiable via PORT environment variable) +2. Complete OAuth authentication flow (default password: `pwd`, modifiable via environment variables) + - **GCLI Mode**: For obtaining Google Cloud Gemini API credentials + - **Antigravity Mode**: For obtaining Google Antigravity API credentials +3. Configure client: + +**OpenAI Compatible Client:** + - **Endpoint Address**: `http://127.0.0.1:7861/v1` + - **API Key**: `pwd` (default value, modifiable via API_PASSWORD or PASSWORD environment variables) + +**Gemini Native Client:** + - **Endpoint Address**: `http://127.0.0.1:7861` + - **Authentication Methods**: + - `Authorization: Bearer your_api_password` + - `x-goog-api-key: your_api_password` + - URL parameter: `?key=your_api_password` + +### 🌟 Dual Authentication Mode Support + +**GCLI Authentication Mode** +- Standard Google Cloud Gemini API authentication +- Supports OAuth2.0 authentication flow +- Automatically enables required Google Cloud APIs + +**Antigravity Authentication Mode** +- Dedicated authentication for Google Antigravity API +- Independent credential management system +- Supports batch upload and management +- Completely isolated from GCLI credentials + +**Unified Management Interface** +- Manage both credential types in the "Batch Upload" tab +- Upper section: GCLI credential batch upload (blue theme) +- Lower section: Antigravity credential batch upload (green theme) +- Separate credential management tabs for each type + +## 💾 Data Storage Mode + +### 🌟 Storage Backend Support + +gcli2api supports two storage backends: **Local SQLite (Default)** and **MongoDB (Cloud Distributed Storage)** + +### 📁 Local SQLite Storage (Default) + +**Default Storage Method** +- No configuration required, works out of the box +- Data is stored in a local SQLite database +- Suitable for single-machine deployment and personal use +- Automatically creates and manages database files + +### 🍃 MongoDB Cloud Storage Mode + +**Cloud Distributed Storage Solution** + +When multi-instance deployment or cloud storage is needed, MongoDB storage mode can be enabled. + +### ⚙️ Enable MongoDB Mode + +**Step 1: Configure MongoDB Connection** +```bash +# Local MongoDB +export MONGODB_URI="mongodb://localhost:27017" + +# MongoDB Atlas cloud service +export MONGODB_URI="mongodb+srv://username:password@cluster.mongodb.net" + +# MongoDB with authentication +export MONGODB_URI="mongodb://admin:password@localhost:27017/admin" + +# Optional: Custom database name (default: gcli2api) +export MONGODB_DATABASE="my_gcli_db" +``` + +**Step 2: Start Application** +```bash +# Application will automatically detect MongoDB configuration and use MongoDB storage +python web.py +``` + +**Docker Environment using MongoDB** +```bash +# Single MongoDB deployment +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb://mongodb:27017" \ + -e API_PASSWORD=your_password \ + --network your_network \ + ghcr.io/su-kaka/gcli2api:latest + +# Using MongoDB Atlas +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb+srv://user:pass@cluster.mongodb.net/gcli2api" \ + -e API_PASSWORD=your_password \ + -p 7861:7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose Example** +```yaml +version: '3.8' + +services: + mongodb: + image: mongo:7 + container_name: gcli2api-mongodb + restart: unless-stopped + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password123 + volumes: + - mongodb_data:/data/db + ports: + - "27017:27017" + + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + depends_on: + - mongodb + environment: + - MONGODB_URI=mongodb://admin:password123@mongodb:27017/admin + - MONGODB_DATABASE=gcli2api + - API_PASSWORD=your_api_password + - PORT=7861 + ports: + - "7861:7861" + +volumes: + mongodb_data: +``` + + +### 🔧 Advanced Configuration + +**MongoDB Connection Optimization** +```bash +# Connection pool and timeout configuration +export MONGODB_URI="mongodb://localhost:27017?maxPoolSize=10&serverSelectionTimeoutMS=5000" + +# Replica set configuration +export MONGODB_URI="mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=myReplicaSet" + +# Read-write separation configuration +export MONGODB_URI="mongodb://localhost:27017/gcli2api?readPreference=secondaryPreferred" +``` + +### Environment Variable Configuration + +**Basic Configuration** +- `PORT`: Service port (default: 7861) +- `HOST`: Server listen address (default: 0.0.0.0) + +**Password Configuration** +- `API_PASSWORD`: Chat API access password (default: inherits PASSWORD or pwd) +- `PANEL_PASSWORD`: Control panel access password (default: inherits PASSWORD or pwd) +- `PASSWORD`: Universal password, overrides the above two when set (default: pwd) + +**Performance and Stability Configuration** +- `RETRY_429_ENABLED`: Enable 429 error automatic retry (default: true) +- `RETRY_429_MAX_RETRIES`: Maximum retry attempts for 429 errors (default: 3) +- `RETRY_429_INTERVAL`: Retry interval for 429 errors, in seconds (default: 1.0) +- `ANTI_TRUNCATION_MAX_ATTEMPTS`: Maximum retry attempts for anti-truncation (default: 3) + +**Network and Proxy Configuration** +- `PROXY`: HTTP/HTTPS proxy address (format: `http://host:port`) +- `OAUTH_PROXY_URL`: OAuth authentication proxy endpoint +- `GOOGLEAPIS_PROXY_URL`: Google APIs proxy endpoint +- `METADATA_SERVICE_URL`: Metadata service proxy endpoint + +**Automation Configuration** +- `AUTO_BAN`: Enable automatic credential banning (default: true) +- `AUTO_LOAD_ENV_CREDS`: Automatically load environment variable credentials at startup (default: false) + +**Compatibility Configuration** +- `COMPATIBILITY_MODE`: Enable compatibility mode, converts system messages to user messages (default: false) + +**Logging Configuration** +- `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO) +- `LOG_FILE`: Log file path (default: log.txt) + +**Storage Configuration** + +**SQLite Configuration (Default)** +- No configuration required, automatically uses local SQLite database +- Database files are automatically created in the project directory + +**MongoDB Configuration (Optional Cloud Storage)** +- `MONGODB_URI`: MongoDB connection string (enables MongoDB mode when set) +- `MONGODB_DATABASE`: MongoDB database name (default: gcli2api) + +**Docker Usage Example** +```bash +# Using universal password +docker run -d --name gcli2api \ + -e PASSWORD=mypassword \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest + +# Using separate passwords +docker run -d --name gcli2api \ + -e API_PASSWORD=my_api_password \ + -e PANEL_PASSWORD=my_panel_password \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +Note: When credential environment variables are set, the system will prioritize using credentials from environment variables and ignore files in the `creds` directory. + +### API Usage Methods + +This service supports multiple complete sets of API endpoints: + +#### 1. OpenAI Compatible Endpoints (GCLI) + +**Endpoint:** `/v1/chat/completions` +**Authentication:** `Authorization: Bearer your_api_password` + +Supports two request formats with automatic detection and processing: + +**OpenAI Format:** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "stream": true +} +``` + +**Gemini Native Format:** +```json +{ + "model": "gemini-2.5-pro", + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ], + "systemInstruction": {"parts": [{"text": "You are a helpful assistant"}]}, + "generationConfig": { + "temperature": 0.7 + } +} +``` + +#### 2. Gemini Native Endpoints (GCLI) + +**Non-streaming Endpoint:** `/v1/models/{model}:generateContent` +**Streaming Endpoint:** `/v1/models/{model}:streamGenerateContent` +**Model List:** `/v1/models` + +**Authentication Methods (choose one):** +- `Authorization: Bearer your_api_password` +- `x-goog-api-key: your_api_password` +- URL parameter: `?key=your_api_password` + +**Request Examples:** +```bash +# Using x-goog-api-key header +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:generateContent" \ + -H "x-goog-api-key: your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' + +# Using URL parameter +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:streamGenerateContent?key=your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' +``` + +#### 3. Claude API Format Endpoints + +**Endpoint:** `/v1/messages` +**Authentication:** `x-api-key: your_api_password` or `Authorization: Bearer your_api_password` + +**Request Example:** +```bash +curl -X POST "http://127.0.0.1:7861/v1/messages" \ + -H "x-api-key: your_api_password" \ + -H "anthropic-version: 2023-06-01" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, Claude!"} + ] + }' +``` + +**Support for system parameter:** +```json +{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "system": "You are a helpful assistant", + "messages": [ + {"role": "user", "content": "Hello"} + ] +} +``` + +**Notes:** +- Fully compatible with Claude API format specification +- Automatically converts to Gemini format for backend calls +- Supports all Claude standard parameters +- Response format follows Claude API specification + +## 📋 Complete API Reference + +### Web Console API + +**Authentication Endpoints** +- `POST /auth/login` - User login +- `POST /auth/start` - Start OAuth authentication (supports GCLI and Antigravity modes) +- `POST /auth/callback` - Handle OAuth callback +- `POST /auth/callback-url` - Complete authentication directly from callback URL +- `GET /auth/status/{project_id}` - Check authentication status + +**Credential Management Endpoints** (supports `mode=geminicli` or `mode=antigravity` parameter) +- `POST /creds/upload` - Batch upload credential files (supports JSON and ZIP) +- `GET /creds/status` - Get credential status list (supports pagination and filtering) +- `GET /creds/detail/{filename}` - Get single credential details +- `POST /creds/action` - Single credential operation (enable/disable/delete) +- `POST /creds/batch-action` - Batch credential operations +- `GET /creds/download/{filename}` - Download single credential file +- `GET /creds/download-all` - Package download all credentials +- `POST /creds/fetch-email/{filename}` - Get user email +- `POST /creds/refresh-all-emails` - Batch refresh user emails +- `POST /creds/deduplicate-by-email` - Deduplicate credentials by email +- `POST /creds/verify-project/{filename}` - Verify credential Project ID +- `GET /creds/quota/{filename}` - Get credential quota information (Antigravity only) + +**Configuration Management Endpoints** +- `GET /config/get` - Get current configuration +- `POST /config/save` - Save configuration + +**Log Management Endpoints** +- `POST /logs/clear` - Clear logs +- `GET /logs/download` - Download log file +- `WebSocket /logs/stream` - Real-time log stream + +**Version Information Endpoints** +- `GET /version/info` - Get version information (optional `check_update=true` parameter to check for updates) + +### Chat API Features + +**Multimodal Support** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABA..." + } + } + ] + } + ] +} +``` + +**Thinking Mode Support** +```json +{ + "model": "gemini-2.5-pro-high", + "messages": [ + {"role": "user", "content": "Complex math problem"} + ] +} +``` + +Response will include separated thinking content: +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "Final answer", + "reasoning_content": "Detailed thought process..." + } + }] +} +``` + +**Streaming Anti-truncation Usage** +```json +{ + "model": "流式抗截断/gemini-2.5-pro", + "messages": [ + {"role": "user", "content": "Write a long article"} + ], + "stream": true +} +``` + +**Compatibility Mode** +```bash +# Enable compatibility mode +export COMPATIBILITY_MODE=true +``` +In this mode, all `system` messages are converted to `user` messages, improving compatibility with certain clients. + +--- + +## 💬 Community + +Welcome to join the QQ group for discussion! + +**QQ Group: 1083250744** + +QQ Group QR Code + +--- + +## License and Disclaimer + +This project is for learning and research purposes only. Using this project indicates that you agree to: +- Not use this project for any commercial purposes +- Bear all risks and responsibilities of using this project +- Comply with relevant terms of service and legal regulations + +The project authors are not responsible for any direct or indirect losses arising from the use of this project. diff --git a/docs/README_JA.md b/docs/README_JA.md new file mode 100644 index 0000000000000000000000000000000000000000..f4fb1a3ad5afb4ec7f77157a7d939c626957b67d --- /dev/null +++ b/docs/README_JA.md @@ -0,0 +1,760 @@ +# GeminiCLI to API + +**GeminiCLIおよびAntigravityをOpenAI、GEMINI、Claude API互換インターフェースに変換** + +[![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/) +[![License: CNC-1.0](https://img.shields.io/badge/License-CNC--1.0-red.svg)](../LICENSE) +[![Docker](https://img.shields.io/badge/docker-available-blue.svg)](https://github.com/su-kaka/gcli2api/pkgs/container/gcli2api) + +[中文](../README.md) | [English](README_EN.md) | 日本語 + +## 🚀 クイックデプロイ + +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/97VMEF?referralCode=sukaka) +[![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/su-kaka/gcli2api) +--- + +## ⚠️ ライセンスについて + +**本プロジェクトはCooperative Non-Commercial License (CNC-1.0) の下でライセンスされています** + +これは厳格な非商用オープンソースライセンスです。詳細は [LICENSE](../LICENSE) ファイルをご参照ください。 + +### ✅ 許可される用途: +- 個人の学習、研究、教育目的 +- 非営利団体での利用 +- オープンソースプロジェクトへの統合(同一ライセンスの遵守が必要) +- 学術研究および論文発表 + +### ❌ 禁止される用途: +- あらゆる形態の商用利用 +- 年間売上が100万ドルを超える企業での利用 +- ベンチャーキャピタルの出資を受けた企業または上場企業 +- 有料サービスまたは製品の提供 +- 商業的な競合利用 + +## コア機能 + +### 🔄 APIエンドポイントとフォーマット対応 + +**マルチエンドポイント・マルチフォーマット対応** +- **OpenAI互換エンドポイント**: `/v1/chat/completions` および `/v1/models` + - 標準OpenAIフォーマット(messages構造)に対応 + - Geminiネイティブフォーマット(contents構造)に対応 + - フォーマットの自動検出・変換、手動切替不要 + - マルチモーダル入力に対応(テキスト+画像) +- **Geminiネイティブエンドポイント**: `/v1/models/{model}:generateContent` および `streamGenerateContent` + - Geminiネイティブ API仕様に完全対応 + - 複数の認証方式: Bearer Token、x-goog-api-keyヘッダー、URLパラメータkey +- **Claudeフォーマット互換**: Claude APIフォーマットに完全対応 + - エンドポイント: `/v1/messages`(Claude API仕様に準拠) + - Claude標準messagesフォーマットに対応 + - systemパラメータおよびClaude固有機能に対応 + - バックエンド対応フォーマットへの自動変換 +- **Antigravity API対応**: OpenAI、Gemini、Claudeフォーマットに対応 + - OpenAIフォーマットエンドポイント: `/antigravity/v1/chat/completions` + - Geminiフォーマットエンドポイント: `/antigravity/v1/models/{model}:generateContent` および `streamGenerateContent` + - Claudeフォーマットエンドポイント: `/antigravity/v1/messages` + - 全Antigravityモデルに対応(Claude、Geminiなど) + - モデル名の自動マッピングおよびThinkingモード検出 + +### 🔐 認証とセキュリティ管理 + +**柔軟なパスワード管理** +- **個別パスワード対応**: APIパスワード(チャットエンドポイント)とコントロールパネルパスワードを個別に設定可能 +- **複数の認証方式**: Authorization Bearer、x-goog-api-keyヘッダー、URLパラメータなどに対応 +- **JWTトークン認証**: コントロールパネルはJWTトークン認証に対応 +- **ユーザーメール取得**: Googleアカウントのメールアドレスを自動取得・表示 + +### 📊 インテリジェントなクレデンシャル管理システム + +**高度なクレデンシャル管理** +- 複数のGoogle OAuthクレデンシャルの自動ローテーション +- 冗長認証による安定性の向上 +- ロードバランシングと同時リクエスト対応 +- 自動障害検出とクレデンシャル無効化 +- クレデンシャル使用統計とクォータ管理 +- クレデンシャルファイルの手動有効化/無効化に対応 +- クレデンシャルファイルの一括操作(有効化、無効化、削除) + +**クレデンシャルステータス監視** +- リアルタイムのクレデンシャルヘルスチェック +- エラーコードの追跡(429、403、500など) +- 自動BAN機能(設定可能) + +### 🌊 ストリーミングとレスポンス処理 + +**複数のストリーミング対応** +- リアルタイムストリーミングレスポンス +- 疑似ストリーミングモード(互換性向上用) +- ストリーミング途切れ防止機能(回答の途切れを防止) +- 非同期タスク管理とタイムアウト処理 + +**レスポンス最適化** +- 思考チェーン内容の分離 +- 推論プロセス(reasoning_content)の処理 +- マルチターン会話のコンテキスト管理 +- 互換モード(systemメッセージをuserメッセージに変換) + +### 🎛️ Web管理コンソール + +**フル機能のWebインターフェース** +- OAuth認証フロー管理(GCLIおよびAntigravityデュアルモード対応) +- クレデンシャルファイルのアップロード、ダウンロード、管理 +- リアルタイムログ表示(WebSocket) +- システム設定管理 +- 使用統計と監視ダッシュボード +- モバイル対応インターフェース + +**一括操作対応** +- ZIPファイルによるクレデンシャル一括アップロード(GCLIおよびAntigravity) +- クレデンシャルの一括有効化/無効化/削除 +- ユーザーメールの一括取得 +- 設定の一括管理 +- 全クレデンシャルタイプ統合一括アップロードインターフェース + +### 📈 使用状況モニタリング + +**リアルタイム監視** +- WebSocketリアルタイムログストリーム +- システムステータス監視 +- クレデンシャルヘルスステータス + +### 🔧 高度な設定とカスタマイズ + +**ネットワークとプロキシ設定** +- HTTP/HTTPSプロキシ対応 +- プロキシエンドポイント設定(OAuth、Google APIs、メタデータサービス) +- タイムアウトとリトライ設定 +- ネットワークエラー処理とリカバリ + +**パフォーマンスと安定性の設定** +- 429エラーの自動リトライ(間隔と回数を設定可能) +- 途切れ防止の最大リトライ回数 + +**ログとデバッグ** +- マルチレベルログシステム(DEBUG、INFO、WARNING、ERROR) +- ログファイル管理 +- リアルタイムログストリーム +- ログのダウンロードとクリア + +### 🔄 環境変数と設定管理 + +**柔軟な設定方法** +- 環境変数による設定 +- ホット設定更新(一部設定項目) +- 設定ロック(環境変数優先) + +## 対応モデル + +全モデルが100万トークンのコンテキストウィンドウに対応。各クレデンシャルファイルで1000リクエストのクォータを提供。 + +### 🤖 基本モデル +- `gemini-2.5-pro` +- `gemini-3-pro-preview` +- `gemini-3.1-pro-preview` + +### 🧠 Thinkingモデル +- `gemini-2.5-pro-high`: Thinkingモード +- `gemini-2.5-pro-low`: 低Thinkingモード +- カスタムThinkingバジェット設定に対応 +- 思考内容と最終回答の自動分離 + +### 🔍 検索拡張モデル +- `gemini-2.5-pro-search`: 検索機能統合モデル + +### 🖼️ 画像生成モデル(Antigravity) +- `gemini-3.1-flash-image`: 基本画像生成モデル +- **解像度サフィックス**: + - `-2k`: 2K解像度 + - `-4k`: 4K HD解像度 +- **アスペクト比サフィックス**: + - `-1x1`: 正方形(アバター) + - `-16x9`: 横長(デスクトップ壁紙) + - `-9x16`: 縦長(モバイル壁紙) + - `-21x9`: ウルトラワイド(ウルトラワイドモニター) + - `-4x3`: 従来のディスプレイ + - `-3x4`: 縦型ポスター +- **組み合わせ例**: + - `gemini-3.1-flash-image-4k-16x9`: 4K横長 + - `gemini-3.1-flash-image-2k-9x16`: 2K縦長 +- 比率未指定時はAPIが自動的にアスペクト比を決定 + +### 🌊 特殊機能バリアント +- **疑似ストリーミングモード**: 任意のモデル名に `-假流式` サフィックスを追加 + - 例: `gemini-2.5-pro-假流式` + - ストリーミングレスポンスが必要だがサーバーが真のストリーミングに対応していない場合に使用 +- **ストリーミング途切れ防止モード**: モデル名に `流式抗截断/` プレフィックスを追加 + - 例: `流式抗截断/gemini-2.5-pro` + - レスポンスの途切れを自動検出しリトライして完全な回答を保証 + +### 🔧 モデル機能の自動検出 +- システムがモデル名内の機能識別子を自動認識 +- 機能モード切替を透過的に処理 +- 機能の組み合わせ使用に対応 + + +--- + +## インストールガイド + +### Termux環境 + +**初期インストール** +```bash +curl -o termux-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/termux-install.sh" && chmod +x termux-install.sh && ./termux-install.sh +``` + +**サービス再起動** +```bash +cd gcli2api +bash termux-start.sh +``` + +### Windows環境 + +**初期インストール** +```powershell +iex (iwr "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.ps1" -UseBasicParsing).Content +``` + +**サービス再起動** +`start.bat` をダブルクリックして実行 + +### Linux環境 + +**初期インストール** +```bash +curl -o install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/install.sh" && chmod +x install.sh && ./install.sh +``` + +**サービス再起動** +```bash +cd gcli2api +bash start.sh +``` + +### macOS環境 + +**初期インストール** +```bash +curl -o darwin-install.sh "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/darwin-install.sh" && chmod +x darwin-install.sh && ./darwin-install.sh +``` + +**サービス再起動** +```bash +cd gcli2api +bash start.sh +``` + +### Docker環境 + +**Docker Runコマンド** +```bash +# 共通パスワードを使用 +docker run -d --name gcli2api --network host -e PASSWORD=pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest + +# 個別パスワードを使用 +docker run -d --name gcli2api --network host -e API_PASSWORD=api_pwd -e PANEL_PASSWORD=panel_pwd -e PORT=7861 -v $(pwd)/data/creds:/app/creds ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Mac** +```bash +# 共通パスワードを使用 +docker run -d \ + --name gcli2api \ + -p 7861:7861 \ + -p 8080:8080 \ + -e PASSWORD=pwd \ + -e PORT=7861 \ + -v "$(pwd)/data/creds":/app/creds \ + ghcr.io/su-kaka/gcli2api:latest +``` + +```bash +# 個別パスワードを使用 +docker run -d \ +--name gcli2api \ +-p 7861:7861 \ +-p 8080:8080 \ +-e API_PASSWORD=api_pwd \ +-e PANEL_PASSWORD=panel_pwd \ +-e PORT=7861 \ +-v $(pwd)/data/creds:/app/creds \ +ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Compose Runコマンド** +1. 以下の内容を `docker-compose.yml` ファイルとして保存: + ```yaml + version: '3.8' + + services: + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + network_mode: host + environment: + # 共通パスワードを使用(シンプルなデプロイに推奨) + - PASSWORD=pwd + - PORT=7861 + # または個別パスワードを使用(本番環境に推奨) + # - API_PASSWORD=your_api_password + # - PANEL_PASSWORD=your_panel_password + volumes: + - ./data/creds:/app/creds + healthcheck: + test: ["CMD-SHELL", "python -c \"import sys, urllib.request, os; port = os.environ.get('PORT', '7861'); req = urllib.request.Request(f'http://localhost:{port}/v1/models', headers={'Authorization': 'Bearer ' + os.environ.get('PASSWORD', 'pwd')}); sys.exit(0 if urllib.request.urlopen(req, timeout=5).getcode() == 200 else 1)\""] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + ``` +2. サービスを起動: + ```bash + docker-compose up -d + ``` + +--- + +## 設定手順 + +1. `http://127.0.0.1:7861` にアクセス(デフォルトポート、PORT環境変数で変更可能) +2. OAuth認証フローを完了(デフォルトパスワード: `pwd`、環境変数で変更可能) + - **GCLIモード**: Google Cloud Gemini APIクレデンシャルの取得用 + - **Antigravityモード**: Google Antigravity APIクレデンシャルの取得用 +3. クライアントを設定: + +**OpenAI互換クライアント:** + - **エンドポイントアドレス**: `http://127.0.0.1:7861/v1` + - **APIキー**: `pwd`(デフォルト値、API_PASSWORDまたはPASSWORD環境変数で変更可能) + +**Geminiネイティブクライアント:** + - **エンドポイントアドレス**: `http://127.0.0.1:7861` + - **認証方式**: + - `Authorization: Bearer your_api_password` + - `x-goog-api-key: your_api_password` + - URLパラメータ: `?key=your_api_password` + +### 🌟 デュアル認証モード対応 + +**GCLI認証モード** +- 標準Google Cloud Gemini API認証 +- OAuth2.0認証フローに対応 +- 必要なGoogle Cloud APIを自動的に有効化 + +**Antigravity認証モード** +- Google Antigravity API専用認証 +- 独立したクレデンシャル管理システム +- 一括アップロードと管理に対応 +- GCLIクレデンシャルとは完全に分離 + +**統合管理インターフェース** +- 「一括アップロード」タブで両方のクレデンシャルタイプを管理 +- 上部セクション: GCLIクレデンシャル一括アップロード(青テーマ) +- 下部セクション: Antigravityクレデンシャル一括アップロード(緑テーマ) +- 各タイプ別のクレデンシャル管理タブ + +## 💾 データストレージモード + +### 🌟 ストレージバックエンド対応 + +gcli2apiは2つのストレージバックエンドに対応: **ローカルSQLite(デフォルト)** と **MongoDB(クラウド分散ストレージ)** + +### 📁 ローカルSQLiteストレージ(デフォルト) + +**デフォルトストレージ方式** +- 設定不要、すぐに利用可能 +- データはローカルSQLiteデータベースに保存 +- 単一マシンデプロイおよび個人利用に最適 +- データベースファイルの自動作成・管理 + +### 🍃 MongoDBクラウドストレージモード + +**クラウド分散ストレージソリューション** + +マルチインスタンスデプロイやクラウドストレージが必要な場合、MongoDBストレージモードを有効にできます。 + +### ⚙️ MongoDBモードの有効化 + +**ステップ1: MongoDB接続の設定** +```bash +# ローカルMongoDB +export MONGODB_URI="mongodb://localhost:27017" + +# MongoDB Atlasクラウドサービス +export MONGODB_URI="mongodb+srv://username:password@cluster.mongodb.net" + +# 認証付きMongoDB +export MONGODB_URI="mongodb://admin:password@localhost:27017/admin" + +# オプション: カスタムデータベース名(デフォルト: gcli2api) +export MONGODB_DATABASE="my_gcli_db" +``` + +**ステップ2: アプリケーションの起動** +```bash +# アプリケーションがMongoDB設定を自動検出し、MongoDBストレージを使用します +python web.py +``` + +**Docker環境でのMongoDB使用** +```bash +# 単一MongoDBデプロイ +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb://mongodb:27017" \ + -e API_PASSWORD=your_password \ + --network your_network \ + ghcr.io/su-kaka/gcli2api:latest + +# MongoDB Atlasの使用 +docker run -d --name gcli2api \ + -e MONGODB_URI="mongodb+srv://user:pass@cluster.mongodb.net/gcli2api" \ + -e API_PASSWORD=your_password \ + -p 7861:7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +**Docker Composeの例** +```yaml +version: '3.8' + +services: + mongodb: + image: mongo:7 + container_name: gcli2api-mongodb + restart: unless-stopped + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password123 + volumes: + - mongodb_data:/data/db + ports: + - "27017:27017" + + gcli2api: + image: ghcr.io/su-kaka/gcli2api:latest + container_name: gcli2api + restart: unless-stopped + depends_on: + - mongodb + environment: + - MONGODB_URI=mongodb://admin:password123@mongodb:27017/admin + - MONGODB_DATABASE=gcli2api + - API_PASSWORD=your_api_password + - PORT=7861 + ports: + - "7861:7861" + +volumes: + mongodb_data: +``` + + +### 🔧 高度な設定 + +**MongoDB接続の最適化** +```bash +# コネクションプールとタイムアウト設定 +export MONGODB_URI="mongodb://localhost:27017?maxPoolSize=10&serverSelectionTimeoutMS=5000" + +# レプリカセット設定 +export MONGODB_URI="mongodb://host1:27017,host2:27017,host3:27017/gcli2api?replicaSet=myReplicaSet" + +# リード・ライト分離設定 +export MONGODB_URI="mongodb://localhost:27017/gcli2api?readPreference=secondaryPreferred" +``` + +### 環境変数設定 + +**基本設定** +- `PORT`: サービスポート(デフォルト: 7861) +- `HOST`: サーバーリッスンアドレス(デフォルト: 0.0.0.0) + +**パスワード設定** +- `API_PASSWORD`: チャットAPIアクセスパスワード(デフォルト: PASSWORDまたはpwdを継承) +- `PANEL_PASSWORD`: コントロールパネルアクセスパスワード(デフォルト: PASSWORDまたはpwdを継承) +- `PASSWORD`: 共通パスワード、設定時に上記2つを上書き(デフォルト: pwd) + +**パフォーマンスと安定性の設定** +- `RETRY_429_ENABLED`: 429エラー自動リトライの有効化(デフォルト: true) +- `RETRY_429_MAX_RETRIES`: 429エラーの最大リトライ回数(デフォルト: 3) +- `RETRY_429_INTERVAL`: 429エラーのリトライ間隔、秒単位(デフォルト: 1.0) +- `ANTI_TRUNCATION_MAX_ATTEMPTS`: 途切れ防止の最大リトライ回数(デフォルト: 3) + +**ネットワークとプロキシ設定** +- `PROXY`: HTTP/HTTPSプロキシアドレス(形式: `http://host:port`) +- `OAUTH_PROXY_URL`: OAuth認証プロキシエンドポイント +- `GOOGLEAPIS_PROXY_URL`: Google APIsプロキシエンドポイント +- `METADATA_SERVICE_URL`: メタデータサービスプロキシエンドポイント + +**自動化設定** +- `AUTO_BAN`: クレデンシャル自動BANの有効化(デフォルト: true) +- `AUTO_LOAD_ENV_CREDS`: 起動時に環境変数クレデンシャルを自動ロード(デフォルト: false) + +**互換性設定** +- `COMPATIBILITY_MODE`: 互換モードの有効化、systemメッセージをuserメッセージに変換(デフォルト: false) + +**ログ設定** +- `LOG_LEVEL`: ログレベル(DEBUG/INFO/WARNING/ERROR、デフォルト: INFO) +- `LOG_FILE`: ログファイルパス(デフォルト: log.txt) + +**ストレージ設定** + +**SQLite設定(デフォルト)** +- 設定不要、自動的にローカルSQLiteデータベースを使用 +- データベースファイルはプロジェクトディレクトリに自動作成 + +**MongoDB設定(オプションのクラウドストレージ)** +- `MONGODB_URI`: MongoDB接続文字列(設定時にMongoDBモードを有効化) +- `MONGODB_DATABASE`: MongoDBデータベース名(デフォルト: gcli2api) + +**Docker使用例** +```bash +# 共通パスワードを使用 +docker run -d --name gcli2api \ + -e PASSWORD=mypassword \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest + +# 個別パスワードを使用 +docker run -d --name gcli2api \ + -e API_PASSWORD=my_api_password \ + -e PANEL_PASSWORD=my_panel_password \ + -e PORT=7861 \ + ghcr.io/su-kaka/gcli2api:latest +``` + +注意: クレデンシャル環境変数が設定されている場合、システムは環境変数のクレデンシャルを優先的に使用し、`creds` ディレクトリ内のファイルを無視します。 + +### API使用方法 + +本サービスは複数の完全なAPIエンドポイントセットに対応しています: + +#### 1. OpenAI互換エンドポイント(GCLI) + +**エンドポイント:** `/v1/chat/completions` +**認証:** `Authorization: Bearer your_api_password` + +2つのリクエストフォーマットに対応し、自動検出・処理を行います: + +**OpenAIフォーマット:** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "stream": true +} +``` + +**Geminiネイティブフォーマット:** +```json +{ + "model": "gemini-2.5-pro", + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ], + "systemInstruction": {"parts": [{"text": "You are a helpful assistant"}]}, + "generationConfig": { + "temperature": 0.7 + } +} +``` + +#### 2. Geminiネイティブエンドポイント(GCLI) + +**非ストリーミングエンドポイント:** `/v1/models/{model}:generateContent` +**ストリーミングエンドポイント:** `/v1/models/{model}:streamGenerateContent` +**モデル一覧:** `/v1/models` + +**認証方式(いずれか1つを選択):** +- `Authorization: Bearer your_api_password` +- `x-goog-api-key: your_api_password` +- URLパラメータ: `?key=your_api_password` + +**リクエスト例:** +```bash +# x-goog-api-keyヘッダーを使用 +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:generateContent" \ + -H "x-goog-api-key: your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' + +# URLパラメータを使用 +curl -X POST "http://127.0.0.1:7861/v1/models/gemini-2.5-pro:streamGenerateContent?key=your_api_password" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }' +``` + +#### 3. Claude APIフォーマットエンドポイント + +**エンドポイント:** `/v1/messages` +**認証:** `x-api-key: your_api_password` または `Authorization: Bearer your_api_password` + +**リクエスト例:** +```bash +curl -X POST "http://127.0.0.1:7861/v1/messages" \ + -H "x-api-key: your_api_password" \ + -H "anthropic-version: 2023-06-01" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, Claude!"} + ] + }' +``` + +**systemパラメータの対応:** +```json +{ + "model": "gemini-2.5-pro", + "max_tokens": 1024, + "system": "You are a helpful assistant", + "messages": [ + {"role": "user", "content": "Hello"} + ] +} +``` + +**注意事項:** +- Claude APIフォーマット仕様に完全互換 +- バックエンド呼び出し時にGeminiフォーマットへ自動変換 +- すべてのClaude標準パラメータに対応 +- レスポンスフォーマットはClaude API仕様に準拠 + +## 📋 完全なAPIリファレンス + +### Webコンソール API + +**認証エンドポイント** +- `POST /auth/login` - ユーザーログイン +- `POST /auth/start` - OAuth認証の開始(GCLIおよびAntigravityモード対応) +- `POST /auth/callback` - OAuthコールバックの処理 +- `POST /auth/callback-url` - コールバックURLから直接認証を完了 +- `GET /auth/status/{project_id}` - 認証ステータスの確認 + +**クレデンシャル管理エンドポイント**(`mode=geminicli` または `mode=antigravity` パラメータ対応) +- `POST /creds/upload` - クレデンシャルファイルの一括アップロード(JSONおよびZIP対応) +- `GET /creds/status` - クレデンシャルステータス一覧の取得(ページネーションとフィルタリング対応) +- `GET /creds/detail/{filename}` - 単一クレデンシャルの詳細取得 +- `POST /creds/action` - 単一クレデンシャル操作(有効化/無効化/削除) +- `POST /creds/batch-action` - クレデンシャルの一括操作 +- `GET /creds/download/{filename}` - 単一クレデンシャルファイルのダウンロード +- `GET /creds/download-all` - 全クレデンシャルの一括ダウンロード +- `POST /creds/fetch-email/{filename}` - ユーザーメールの取得 +- `POST /creds/refresh-all-emails` - ユーザーメールの一括更新 +- `POST /creds/deduplicate-by-email` - メールによるクレデンシャルの重複排除 +- `POST /creds/verify-project/{filename}` - クレデンシャルのProject ID検証 +- `GET /creds/quota/{filename}` - クレデンシャルのクォータ情報取得(Antigravityのみ) + +**設定管理エンドポイント** +- `GET /config/get` - 現在の設定の取得 +- `POST /config/save` - 設定の保存 + +**ログ管理エンドポイント** +- `POST /logs/clear` - ログのクリア +- `GET /logs/download` - ログファイルのダウンロード +- `WebSocket /logs/stream` - リアルタイムログストリーム + +**バージョン情報エンドポイント** +- `GET /version/info` - バージョン情報の取得(オプション `check_update=true` パラメータで更新確認) + +### チャットAPI機能 + +**マルチモーダル対応** +```json +{ + "model": "gemini-2.5-pro", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "この画像を説明してください"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABA..." + } + } + ] + } + ] +} +``` + +**Thinkingモード対応** +```json +{ + "model": "gemini-2.5-pro-high", + "messages": [ + {"role": "user", "content": "複雑な数学の問題"} + ] +} +``` + +レスポンスには分離された思考内容が含まれます: +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "最終回答", + "reasoning_content": "詳細な思考プロセス..." + } + }] +} +``` + +**ストリーミング途切れ防止の使用方法** +```json +{ + "model": "流式抗截断/gemini-2.5-pro", + "messages": [ + {"role": "user", "content": "長い記事を書いてください"} + ], + "stream": true +} +``` + +**互換モード** +```bash +# 互換モードを有効化 +export COMPATIBILITY_MODE=true +``` +このモードでは、すべての `system` メッセージが `user` メッセージに変換され、特定のクライアントとの互換性が向上します。 + +--- + +## 💬 コミュニティ + +QQグループへの参加をお待ちしています! + +**QQグループ: 1083250744** + +QQグループQRコード + +--- + +## ライセンスと免責事項 + +本プロジェクトは学習および研究目的のみに使用できます。本プロジェクトの使用は、以下に同意したことを意味します: +- 本プロジェクトをいかなる商用目的にも使用しないこと +- 本プロジェクトの使用に伴うすべてのリスクと責任を負うこと +- 関連するサービス利用規約および法的規制を遵守すること + +プロジェクトの作者は、本プロジェクトの使用から生じるいかなる直接的または間接的な損害についても責任を負いません。 diff --git a/front/common.js b/front/common.js new file mode 100644 index 0000000000000000000000000000000000000000..50cbab82a69d05d5cb12ed30edc4d3fa32474b82 --- /dev/null +++ b/front/common.js @@ -0,0 +1,3219 @@ +// ===================================================================== +// GCLI2API 控制面板公共JavaScript模块 +// ===================================================================== + +// ===================================================================== +// 全局状态管理 +// ===================================================================== +const AppState = { + // 认证相关 + authToken: '', + authInProgress: false, + currentProjectId: '', + + // Antigravity认证 + antigravityAuthState: null, + antigravityAuthInProgress: false, + + // 凭证管理 + creds: createCredsManager('normal'), + antigravityCreds: createCredsManager('antigravity'), + + // 文件上传 + uploadFiles: createUploadManager('normal'), + antigravityUploadFiles: createUploadManager('antigravity'), + + // 配置管理 + currentConfig: {}, + envLockedFields: new Set(), + + // 日志管理 + logWebSocket: null, + allLogs: [], + filteredLogs: [], + currentLogFilter: 'all', + + // 使用统计 + usageStatsData: {}, + + // 冷却倒计时 + cooldownTimerInterval: null +}; + +// ===================================================================== +// 凭证管理器工厂 +// ===================================================================== +function createCredsManager(type) { + const modeParam = type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + + return { + type: type, + data: {}, + filteredData: {}, + currentPage: 1, + pageSize: 20, + selectedFiles: new Set(), + totalCount: 0, + currentStatusFilter: 'all', + currentErrorCodeFilter: 'all', + currentCooldownFilter: 'all', + currentPreviewFilter: 'all', + currentTierFilter: 'all', + statsData: { total: 0, normal: 0, disabled: 0 }, + + // API端点 + getEndpoint: (action) => { + const endpoints = { + status: `./creds/status`, + action: `./creds/action`, + batchAction: `./creds/batch-action`, + download: `./creds/download`, + downloadAll: `./creds/download-all`, + detail: `./creds/detail`, + fetchEmail: `./creds/fetch-email`, + refreshAllEmails: `./creds/refresh-all-emails`, + deduplicate: `./creds/deduplicate-by-email`, + verifyProject: `./creds/verify-project`, + quota: `./creds/quota` + }; + return endpoints[action] || ''; + }, + + // 获取mode参数 + getModeParam: () => modeParam, + + // DOM元素ID前缀 + getElementId: (suffix) => { + // 普通凭证的ID首字母小写,如 credsLoading + // Antigravity的ID是 antigravity + 首字母大写,如 antigravityCredsLoading + if (type === 'antigravity') { + return 'antigravity' + suffix.charAt(0).toUpperCase() + suffix.slice(1); + } + return suffix.charAt(0).toLowerCase() + suffix.slice(1); + }, + + // 刷新凭证列表 + async refresh() { + const loading = document.getElementById(this.getElementId('CredsLoading')); + const list = document.getElementById(this.getElementId('CredsList')); + + try { + loading.style.display = 'block'; + list.innerHTML = ''; + + const offset = (this.currentPage - 1) * this.pageSize; + const errorCodeFilter = this.currentErrorCodeFilter || 'all'; + const cooldownFilter = this.currentCooldownFilter || 'all'; + const previewFilter = this.currentPreviewFilter || 'all'; + const tierFilter = this.currentTierFilter || 'all'; + const response = await fetch( + `${this.getEndpoint('status')}?offset=${offset}&limit=${this.pageSize}&status_filter=${this.currentStatusFilter}&error_code_filter=${errorCodeFilter}&cooldown_filter=${cooldownFilter}&preview_filter=${previewFilter}&tier_filter=${tierFilter}&${this.getModeParam()}`, + { headers: getAuthHeaders() } + ); + + const data = await response.json(); + + if (response.ok) { + this.data = {}; + data.items.forEach(item => { + this.data[item.filename] = { + filename: item.filename, + status: { + disabled: item.disabled, + error_codes: item.error_codes || [], + last_success: item.last_success, + }, + user_email: item.user_email, + model_cooldowns: item.model_cooldowns || {}, + preview: item.preview, + tier: item.tier || 'pro', + enable_credit: !!item.enable_credit + }; + }); + + this.totalCount = data.total; + // 使用后端返回的全局统计数据 + if (data.stats) { + this.statsData = data.stats; + } else { + // 兼容旧版本后端 + this.calculateStats(); + } + this.updateStatsDisplay(); + this.filteredData = this.data; + this.renderList(); + this.updatePagination(); + + let msg = `已加载 ${data.total} 个${type === 'antigravity' ? 'Antigravity' : ''}凭证文件`; + if (this.currentStatusFilter !== 'all') { + msg += ` (筛选: ${this.currentStatusFilter === 'enabled' ? '仅启用' : '仅禁用'})`; + } + showStatus(msg, 'success'); + } else { + showStatus(`加载失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } + }, + + // 计算统计数据(仅用于兼容旧版本后端) + calculateStats() { + this.statsData = { total: this.totalCount, normal: 0, disabled: 0 }; + Object.values(this.data).forEach(credInfo => { + if (credInfo.status.disabled) { + this.statsData.disabled++; + } else { + this.statsData.normal++; + } + }); + }, + + // 更新统计显示 + updateStatsDisplay() { + document.getElementById(this.getElementId('StatTotal')).textContent = this.statsData.total; + document.getElementById(this.getElementId('StatNormal')).textContent = this.statsData.normal; + document.getElementById(this.getElementId('StatDisabled')).textContent = this.statsData.disabled; + }, + + // 渲染凭证列表 + renderList() { + const list = document.getElementById(this.getElementId('CredsList')); + list.innerHTML = ''; + + const entries = Object.entries(this.filteredData); + + if (entries.length === 0) { + const msg = this.totalCount === 0 ? '暂无凭证文件' : '当前筛选条件下暂无数据'; + list.innerHTML = `

${msg}

`; + document.getElementById(this.getElementId('PaginationContainer')).style.display = 'none'; + return; + } + + entries.forEach(([, credInfo]) => { + list.appendChild(createCredCard(credInfo, this)); + }); + + document.getElementById(this.getElementId('PaginationContainer')).style.display = + this.getTotalPages() > 1 ? 'flex' : 'none'; + this.updateBatchControls(); + }, + + // 获取总页数 + getTotalPages() { + return Math.ceil(this.totalCount / this.pageSize); + }, + + // 更新分页信息 + updatePagination() { + const totalPages = this.getTotalPages(); + const startItem = (this.currentPage - 1) * this.pageSize + 1; + const endItem = Math.min(this.currentPage * this.pageSize, this.totalCount); + + document.getElementById(this.getElementId('PaginationInfo')).textContent = + `第 ${this.currentPage} 页,共 ${totalPages} 页 (显示 ${startItem}-${endItem},共 ${this.totalCount} 项)`; + + document.getElementById(this.getElementId('PrevPageBtn')).disabled = this.currentPage <= 1; + document.getElementById(this.getElementId('NextPageBtn')).disabled = this.currentPage >= totalPages; + }, + + // 切换页面 + changePage(direction) { + const newPage = this.currentPage + direction; + if (newPage >= 1 && newPage <= this.getTotalPages()) { + this.currentPage = newPage; + this.refresh(); + } + }, + + // 改变每页大小 + changePageSize() { + this.pageSize = parseInt(document.getElementById(this.getElementId('PageSizeSelect')).value); + this.currentPage = 1; + this.refresh(); + }, + + // 应用状态筛选 + applyStatusFilter() { + this.currentStatusFilter = document.getElementById(this.getElementId('StatusFilter')).value; + const errorCodeFilterEl = document.getElementById(this.getElementId('ErrorCodeFilter')); + const cooldownFilterEl = document.getElementById(this.getElementId('CooldownFilter')); + const previewFilterEl = document.getElementById(this.getElementId('PreviewFilter')); + const tierFilterEl = document.getElementById(this.getElementId('TierFilter')); + this.currentErrorCodeFilter = errorCodeFilterEl ? errorCodeFilterEl.value : 'all'; + this.currentCooldownFilter = cooldownFilterEl ? cooldownFilterEl.value : 'all'; + this.currentPreviewFilter = previewFilterEl ? previewFilterEl.value : 'all'; + this.currentTierFilter = tierFilterEl ? tierFilterEl.value : 'all'; + this.currentPage = 1; + this.refresh(); + }, + + // 更新批量控件 + updateBatchControls() { + const selectedCount = this.selectedFiles.size; + document.getElementById(this.getElementId('SelectedCount')).textContent = `已选择 ${selectedCount} 项`; + + const batchBtnNames = ['Enable', 'Disable', 'Delete', 'Verify', 'Preview']; + if (this.type === 'antigravity') { + batchBtnNames.push('EnableCredit'); + batchBtnNames.push('DisableCredit'); + } + const batchBtns = batchBtnNames.map(action => + document.getElementById(this.getElementId(`Batch${action}Btn`)) + ); + batchBtns.forEach(btn => btn && (btn.disabled = selectedCount === 0)); + + const selectAllCheckbox = document.getElementById(this.getElementId('SelectAllCheckbox')); + if (!selectAllCheckbox) return; + + const checkboxes = document.querySelectorAll(`.${this.getElementId('file-checkbox')}`); + const currentPageSelectedCount = Array.from(checkboxes) + .filter(cb => this.selectedFiles.has(cb.getAttribute('data-filename'))).length; + + if (currentPageSelectedCount === 0) { + selectAllCheckbox.indeterminate = false; + selectAllCheckbox.checked = false; + } else if (currentPageSelectedCount === checkboxes.length) { + selectAllCheckbox.indeterminate = false; + selectAllCheckbox.checked = true; + } else { + selectAllCheckbox.indeterminate = true; + } + + checkboxes.forEach(cb => { + cb.checked = this.selectedFiles.has(cb.getAttribute('data-filename')); + }); + }, + + // 凭证操作 + async action(filename, action) { + try { + const response = await fetch(`${this.getEndpoint('action')}?${this.getModeParam()}`, { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ filename, action }) + }); + + const data = await response.json(); + + if (response.ok) { + showStatus(data.message || `操作成功: ${action}`, 'success'); + await this.refresh(); + } else { + showStatus(`操作失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } + }, + + // 批量操作 + async batchAction(action) { + const selectedFiles = Array.from(this.selectedFiles); + + if (selectedFiles.length === 0) { + showStatus('请先选择要操作的文件', 'error'); + return; + } + + const actionNames = { + enable: '启用', + disable: '禁用', + delete: '删除', + enable_credit: '开启积分', + disable_credit: '关闭积分' + }; + const actionLabel = actionNames[action] || action; + const confirmMsg = action === 'delete' + ? `确定要删除选中的 ${selectedFiles.length} 个文件吗?\n注意:此操作不可恢复!` + : `确定要${actionLabel}选中的 ${selectedFiles.length} 个文件吗?`; + + if (!confirm(confirmMsg)) return; + + try { + showStatus(`正在执行批量${actionLabel}操作...`, 'info'); + + const response = await fetch(`${this.getEndpoint('batchAction')}?${this.getModeParam()}`, { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ action, filenames: selectedFiles }) + }); + + const data = await response.json(); + + if (response.ok) { + const successCount = data.success_count || data.succeeded; + showStatus(`批量操作完成:成功处理 ${successCount}/${selectedFiles.length} 个文件`, 'success'); + this.selectedFiles.clear(); + this.updateBatchControls(); + await this.refresh(); + } else { + showStatus(`批量操作失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`批量操作网络错误: ${error.message}`, 'error'); + } + } + }; +} + +// ===================================================================== +// 文件上传管理器工厂 +// ===================================================================== +function createUploadManager(type) { + const modeParam = type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + const endpoint = `./creds/upload?${modeParam}`; + + return { + type: type, + selectedFiles: [], + + getElementId: (suffix) => { + // 普通上传的ID首字母小写,如 fileList + // Antigravity的ID是 antigravity + 首字母大写,如 antigravityFileList + if (type === 'antigravity') { + return 'antigravity' + suffix.charAt(0).toUpperCase() + suffix.slice(1); + } + return suffix.charAt(0).toLowerCase() + suffix.slice(1); + }, + + handleFileSelect(event) { + this.addFiles(Array.from(event.target.files)); + }, + + addFiles(files) { + files.forEach(file => { + const isValid = file.type === 'application/json' || file.name.endsWith('.json') || + file.type === 'application/zip' || file.name.endsWith('.zip'); + + if (isValid) { + if (!this.selectedFiles.find(f => f.name === file.name && f.size === file.size)) { + this.selectedFiles.push(file); + } + } else { + showStatus(`文件 ${file.name} 格式不支持,只支持JSON和ZIP文件`, 'error'); + } + }); + this.updateFileList(); + }, + + updateFileList() { + const list = document.getElementById(this.getElementId('FileList')); + const section = document.getElementById(this.getElementId('FileListSection')); + + if (!list || !section) { + console.warn('File list elements not found:', this.getElementId('FileList')); + return; + } + + if (this.selectedFiles.length === 0) { + section.classList.add('hidden'); + return; + } + + section.classList.remove('hidden'); + list.innerHTML = ''; + + this.selectedFiles.forEach((file, index) => { + const isZip = file.name.endsWith('.zip'); + const fileIcon = isZip ? '📦' : '📄'; + const fileType = isZip ? ' (ZIP压缩包)' : ' (JSON文件)'; + + const fileItem = document.createElement('div'); + fileItem.className = 'file-item'; + fileItem.innerHTML = ` +
+ ${fileIcon} ${file.name} + (${formatFileSize(file.size)}${fileType}) +
+ + `; + list.appendChild(fileItem); + }); + }, + + removeFile(index) { + this.selectedFiles.splice(index, 1); + this.updateFileList(); + }, + + clearFiles() { + this.selectedFiles = []; + this.updateFileList(); + }, + + async upload() { + if (this.selectedFiles.length === 0) { + showStatus('请选择要上传的文件', 'error'); + return; + } + + const progressSection = document.getElementById(this.getElementId('UploadProgressSection')); + const progressFill = document.getElementById(this.getElementId('ProgressFill')); + const progressText = document.getElementById(this.getElementId('ProgressText')); + + progressSection.classList.remove('hidden'); + + const formData = new FormData(); + this.selectedFiles.forEach(file => formData.append('files', file)); + + if (this.selectedFiles.some(f => f.name.endsWith('.zip'))) { + showStatus('正在上传并解压ZIP文件...', 'info'); + } + + try { + const xhr = new XMLHttpRequest(); + xhr.timeout = 300000; // 5分钟 + + xhr.upload.onprogress = (event) => { + if (event.lengthComputable) { + const percent = (event.loaded / event.total) * 100; + progressFill.style.width = percent + '%'; + progressText.textContent = Math.round(percent) + '%'; + } + }; + + xhr.onload = () => { + if (xhr.status === 200) { + try { + const data = JSON.parse(xhr.responseText); + showStatus(`成功上传 ${data.uploaded_count} 个${type === 'antigravity' ? 'Antigravity' : ''}文件`, 'success'); + this.clearFiles(); + progressSection.classList.add('hidden'); + } catch (e) { + showStatus('上传失败: 服务器响应格式错误', 'error'); + } + } else { + try { + const error = JSON.parse(xhr.responseText); + showStatus(`上传失败: ${error.detail || error.error || '未知错误'}`, 'error'); + } catch (e) { + showStatus(`上传失败: HTTP ${xhr.status}`, 'error'); + } + } + }; + + xhr.onerror = () => { + showStatus(`上传失败:连接中断 - 可能原因:文件过多(${this.selectedFiles.length}个)或网络不稳定。建议分批上传。`, 'error'); + progressSection.classList.add('hidden'); + }; + + xhr.ontimeout = () => { + showStatus('上传失败:请求超时 - 文件处理时间过长,请减少文件数量或检查网络连接', 'error'); + progressSection.classList.add('hidden'); + }; + + xhr.open('POST', endpoint); + xhr.setRequestHeader('Authorization', `Bearer ${AppState.authToken}`); + xhr.send(formData); + } catch (error) { + showStatus(`上传失败: ${error.message}`, 'error'); + } + } + }; +} + +// ===================================================================== +// 工具函数 +// ===================================================================== +function showStatus(message, type = 'info') { + const statusSection = document.getElementById('statusSection'); + if (statusSection) { + // 清除之前的定时器 + if (window._statusTimeout) { + clearTimeout(window._statusTimeout); + } + + // 创建新的 toast + statusSection.innerHTML = `
${message}
`; + const statusDiv = statusSection.querySelector('.status'); + + // 强制重绘以触发动画 + statusDiv.offsetHeight; + statusDiv.classList.add('show'); + + // 3秒后淡出并移除 + window._statusTimeout = setTimeout(() => { + statusDiv.classList.add('fade-out'); + setTimeout(() => { + statusSection.innerHTML = ''; + }, 300); // 等待淡出动画完成 + }, 3000); + } else { + showMessageModal('提示', message, 'info'); + } +} + +// 将文本中的链接转换为可点击的HTML链接 +function linkifyText(text) { + if (!text) return text; + + // 匹配 http://, https:// 和 www. 开头的链接,排除常见的标点符号 + const urlPattern = /(https?:\/\/[^\s"'<>()[\]{}]+)|(www\.[^\s"'<>()[\]{}]+)/gi; + + return text.replace(urlPattern, function(url) { + let href = url; + // 如果是 www. 开头,添加 https:// + if (url.startsWith('www.')) { + href = 'https://' + url; + } + + return `${url}`; + }); +} + +// 显示增强的消息模态框(支持链接高亮) +function showMessageModal(title, message, type = 'info') { + // 创建模态框 + const modal = document.createElement('div'); + modal.className = 'message-modal-overlay'; + modal.innerHTML = ` +
+
+

${title}

+ +
+
+ ${linkifyText(message).replace(/\n/g, '
')} +
+ +
+ `; + + // 添加到页面 + document.body.appendChild(modal); + + // 点击遮罩层关闭 + modal.addEventListener('click', function(e) { + if (e.target === modal) { + modal.remove(); + } + }); + + // ESC 键关闭 + const escHandler = function(e) { + if (e.key === 'Escape') { + modal.remove(); + document.removeEventListener('keydown', escHandler); + } + }; + document.addEventListener('keydown', escHandler); +} + +function getAuthHeaders() { + return { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${AppState.authToken}` + }; +} + +function formatFileSize(bytes) { + if (bytes < 1024) return bytes + ' B'; + if (bytes < 1024 * 1024) return Math.round(bytes / 1024) + ' KB'; + return Math.round(bytes / (1024 * 1024)) + ' MB'; +} + +function formatCooldownTime(remainingSeconds) { + const hours = Math.floor(remainingSeconds / 3600); + const minutes = Math.floor((remainingSeconds % 3600) / 60); + const seconds = remainingSeconds % 60; + + if (hours > 0) return `${hours}h ${minutes}m ${seconds}s`; + if (minutes > 0) return `${minutes}m ${seconds}s`; + return `${seconds}s`; +} + +// ===================================================================== +// 凭证卡片创建(通用) +// ===================================================================== +function createCredCard(credInfo, manager) { + const div = document.createElement('div'); + const { status, filename } = credInfo; + const managerType = manager.type; + + // 卡片样式 + div.className = status.disabled ? 'cred-card disabled' : 'cred-card'; + + // 状态徽章 + let statusBadges = ''; + statusBadges += status.disabled + ? '已禁用' + : '已启用'; + + if (status.error_codes && status.error_codes.length > 0) { + statusBadges += `错误码: ${status.error_codes.join(', ')}`; + const autoBan = status.error_codes.filter(c => c === 400 || c === 403); + if (autoBan.length > 0 && status.disabled) { + statusBadges += 'AUTO_BAN'; + } + } else { + statusBadges += '无错误'; + } + + // Preview状态显示 (仅对geminicli模式显示) + if (managerType !== 'antigravity' && credInfo.preview !== undefined) { + if (credInfo.preview) { + statusBadges += 'Preview: ON'; + } else { + statusBadges += 'Preview: OFF'; + } + } + + // tier 状态显示 (geminicli 和 antigravity 都显示) + const tier = (credInfo.tier || 'pro').toString().toLowerCase(); + const tierLabel = tier.toUpperCase(); + const tierColor = tier === 'ultra' ? '#ff9800' : (tier === 'free' ? '#607d8b' : '#2e7d32'); + statusBadges += `Tier: ${tierLabel}`; + + // Credit 状态显示(仅 antigravity) + if (managerType === 'antigravity') { + if (credInfo.enable_credit) { + statusBadges += 'Credit: ON'; + } else { + statusBadges += 'Credit: OFF'; + } + } + + // 模型级冷却状态 + if (credInfo.model_cooldowns && Object.keys(credInfo.model_cooldowns).length > 0) { + const currentTime = Date.now() / 1000; + const activeCooldowns = Object.entries(credInfo.model_cooldowns) + .filter(([, until]) => until > currentTime) + .map(([model, until]) => { + const remaining = Math.max(0, Math.floor(until - currentTime)); + const shortModel = model.replace('gemini-', '').replace('-exp', '') + .replace('2.0-', '2-').replace('1.5-', '1.5-'); + return { + model: shortModel, + time: formatCooldownTime(remaining).replace(/s$/, '').replace(/ /g, ''), + fullModel: model + }; + }); + + if (activeCooldowns.length > 0) { + activeCooldowns.slice(0, 2).forEach(item => { + statusBadges += `⏰ ${item.model}: ${item.time}`; + }); + if (activeCooldowns.length > 2) { + const remaining = activeCooldowns.length - 2; + const remainingModels = activeCooldowns.slice(2).map(i => `${i.fullModel}: ${i.time}`).join('\n'); + statusBadges += `+${remaining}`; + } + } + } + + // 路径ID + const pathId = (managerType === 'antigravity' ? 'ag_' : '') + btoa(encodeURIComponent(filename)).replace(/[+/=]/g, '_'); + + // 操作按钮 + const actionButtons = ` + ${status.disabled + ? `` + : `` + } + + + + ${managerType === 'antigravity' ? `` : ''} + ${managerType === 'antigravity' ? (credInfo.enable_credit + ? `` + : `` + ) : ''} + ${managerType !== 'antigravity' ? `` : ''} + + + + + `; + + // 邮箱信息 + const emailInfo = credInfo.user_email + ? `
${credInfo.user_email}
` + : '
未获取邮箱
'; + + const checkboxClass = manager.getElementId('file-checkbox'); + + div.innerHTML = ` +
+
+ +
+
${filename}
+ ${emailInfo} +
+
+
${statusBadges}
+
+
${actionButtons}
+
+
点击"查看内容"按钮加载文件详情...
+
+
+
点击"查看报错"按钮加载报错信息...
+
+ ${managerType === 'antigravity' ? ` + + ` : ''} + `; + + // 添加事件监听 + div.querySelectorAll('[data-filename][data-action]').forEach(button => { + button.addEventListener('click', function () { + const fn = this.getAttribute('data-filename'); + const action = this.getAttribute('data-action'); + if (action === 'delete') { + if (confirm(`确定要删除${managerType === 'antigravity' ? ' Antigravity ' : ''}凭证文件吗?\n${fn}`)) { + manager.action(fn, action); + } + } else { + manager.action(fn, action); + } + }); + }); + + return div; +} + +// ===================================================================== +// 凭证详情切换 +// ===================================================================== +async function toggleCredDetails(pathId) { + await toggleCredDetailsCommon(pathId, AppState.creds); +} + +async function toggleAntigravityCredDetails(pathId) { + await toggleCredDetailsCommon(pathId, AppState.antigravityCreds); +} + +async function toggleCredDetailsCommon(pathId, manager) { + const details = document.getElementById('details-' + pathId); + if (!details) return; + + const isShowing = details.classList.toggle('show'); + + if (isShowing) { + const contentDiv = details.querySelector('.cred-content'); + const filename = contentDiv.getAttribute('data-filename'); + const loaded = contentDiv.getAttribute('data-loaded'); + + if (loaded === 'false' && filename) { + contentDiv.textContent = '正在加载文件内容...'; + + try { + const modeParam = manager.type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + const endpoint = `./creds/detail/${encodeURIComponent(filename)}?${modeParam}`; + + const response = await fetch(endpoint, { headers: getAuthHeaders() }); + + const data = await response.json(); + if (response.ok && data.content) { + contentDiv.textContent = JSON.stringify(data.content, null, 2); + contentDiv.setAttribute('data-loaded', 'true'); + } else { + contentDiv.textContent = '无法加载文件内容: ' + (data.error || data.detail || '未知错误'); + } + } catch (error) { + contentDiv.textContent = '加载文件内容失败: ' + error.message; + } + } + } +} + +// ===================================================================== +// 登录相关函数 +// ===================================================================== +async function login() { + const password = document.getElementById('loginPassword').value; + + if (!password) { + showStatus('请输入密码', 'error'); + return; + } + + try { + const response = await fetch('./auth/login', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ password }) + }); + + const data = await response.json(); + + if (response.ok) { + AppState.authToken = data.token; + localStorage.setItem('gcli2api_auth_token', AppState.authToken); + document.getElementById('loginSection').classList.add('hidden'); + document.getElementById('mainSection').classList.remove('hidden'); + showStatus('登录成功', 'success'); + // 显示面板后初始化滑块 + requestAnimationFrame(() => initTabSlider()); + } else { + showStatus(`登录失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function autoLogin() { + const savedToken = localStorage.getItem('gcli2api_auth_token'); + if (!savedToken) return false; + + AppState.authToken = savedToken; + + try { + const response = await fetch('./config/get', { + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${AppState.authToken}` + } + }); + + if (response.ok) { + document.getElementById('loginSection').classList.add('hidden'); + document.getElementById('mainSection').classList.remove('hidden'); + showStatus('自动登录成功', 'success'); + // 显示面板后初始化滑块 + requestAnimationFrame(() => initTabSlider()); + return true; + } else if (response.status === 401) { + localStorage.removeItem('gcli2api_auth_token'); + AppState.authToken = ''; + return false; + } + return false; + } catch (error) { + return false; + } +} + +function logout() { + localStorage.removeItem('gcli2api_auth_token'); + AppState.authToken = ''; + document.getElementById('loginSection').classList.remove('hidden'); + document.getElementById('mainSection').classList.add('hidden'); + showStatus('已退出登录', 'info'); + const passwordInput = document.getElementById('loginPassword'); + if (passwordInput) passwordInput.value = ''; +} + +function handlePasswordEnter(event) { + if (event.key === 'Enter') login(); +} + +// ===================================================================== +// 标签页切换 +// ===================================================================== + +// 更新滑块位置 +function updateTabSlider(targetTab, animate = true) { + const slider = document.querySelector('.tab-slider'); + const tabs = document.querySelector('.tabs'); + if (!slider || !tabs || !targetTab) return; + + // 获取按钮位置和容器宽度 + const tabLeft = targetTab.offsetLeft; + const tabWidth = targetTab.offsetWidth; + const tabsWidth = tabs.scrollWidth; + + // 使用 left 和 right 同时控制,确保动画同步 + const rightValue = tabsWidth - tabLeft - tabWidth; + + if (animate) { + slider.style.left = `${tabLeft}px`; + slider.style.right = `${rightValue}px`; + } else { + // 首次加载时不使用动画 + slider.style.transition = 'none'; + slider.style.left = `${tabLeft}px`; + slider.style.right = `${rightValue}px`; + // 强制重绘后恢复过渡 + slider.offsetHeight; + slider.style.transition = ''; + } +} + +// 初始化滑块位置 +function initTabSlider() { + const activeTab = document.querySelector('.tab.active'); + if (activeTab) { + updateTabSlider(activeTab, false); + } +} + +// 页面加载和窗口大小变化时初始化滑块 +document.addEventListener('DOMContentLoaded', initTabSlider); +window.addEventListener('resize', () => { + const activeTab = document.querySelector('.tab.active'); + if (activeTab) updateTabSlider(activeTab, false); +}); + +function switchTab(tabName) { + // 获取当前活动的内容区域 + const currentContent = document.querySelector('.tab-content.active'); + const targetContent = document.getElementById(tabName + 'Tab'); + + // 如果点击的是当前标签页,不做任何操作 + if (currentContent === targetContent) return; + + // 找到目标标签按钮 + const targetTab = event && event.target ? event.target : + document.querySelector(`.tab[onclick*="'${tabName}'"]`); + + // 移除所有标签页的active状态 + document.querySelectorAll('.tab').forEach(tab => tab.classList.remove('active')); + + // 添加当前点击标签的active状态 + if (targetTab) { + targetTab.classList.add('active'); + // 更新滑块位置(带动画) + updateTabSlider(targetTab, true); + } + + // 淡出当前内容 + if (currentContent) { + // 设置淡出过渡 + currentContent.style.transition = 'opacity 0.18s ease-out, transform 0.18s ease-out'; + currentContent.style.opacity = '0'; + currentContent.style.transform = 'translateX(-12px)'; + + setTimeout(() => { + currentContent.classList.remove('active'); + currentContent.style.transition = ''; + currentContent.style.opacity = ''; + currentContent.style.transform = ''; + + // 淡入新内容 + if (targetContent) { + // 先设置初始状态(在添加 active 类之前) + targetContent.style.opacity = '0'; + targetContent.style.transform = 'translateX(12px)'; + targetContent.style.transition = 'none'; // 暂时禁用过渡 + + // 添加 active 类使元素可见 + targetContent.classList.add('active'); + + // 使用双重 requestAnimationFrame 确保浏览器完成重绘 + requestAnimationFrame(() => { + requestAnimationFrame(() => { + // 启用过渡并应用最终状态 + targetContent.style.transition = 'opacity 0.25s ease-out, transform 0.25s ease-out'; + targetContent.style.opacity = '1'; + targetContent.style.transform = 'translateX(0)'; + + // 清理内联样式并执行数据加载 + setTimeout(() => { + targetContent.style.transition = ''; + targetContent.style.opacity = ''; + targetContent.style.transform = ''; + + // 动画完成后触发数据加载 + triggerTabDataLoad(tabName); + }, 260); + }); + }); + } + }, 180); + } else { + // 如果没有当前内容(首次加载),直接显示目标内容 + if (targetContent) { + targetContent.classList.add('active'); + // 直接触发数据加载 + triggerTabDataLoad(tabName); + } + } +} + +// 标签页数据加载(从动画中分离出来) +function triggerTabDataLoad(tabName) { + if (tabName === 'manage') AppState.creds.refresh(); + if (tabName === 'antigravity-manage') AppState.antigravityCreds.refresh(); + if (tabName === 'config') loadConfig(); + if (tabName === 'logs') connectWebSocket(); +} + + +// ===================================================================== +// OAuth认证相关函数 +// ===================================================================== +async function startAuth() { + const projectId = document.getElementById('projectId').value.trim(); + AppState.currentProjectId = projectId || null; + + const btn = document.getElementById('getAuthBtn'); + btn.disabled = true; + btn.textContent = '正在获取认证链接...'; + + try { + const requestBody = projectId ? { project_id: projectId } : {}; + showStatus(projectId ? '使用指定的项目ID生成认证链接...' : '将尝试自动检测项目ID,正在生成认证链接...', 'info'); + + const response = await fetch('./auth/start', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify(requestBody) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('authUrl').href = data.auth_url; + document.getElementById('authUrl').textContent = data.auth_url; + document.getElementById('authUrlSection').classList.remove('hidden'); + + const msg = data.auto_project_detection + ? '认证链接已生成(将在认证完成后自动检测项目ID),请点击链接完成授权' + : `认证链接已生成(项目ID: ${data.detected_project_id}),请点击链接完成授权`; + showStatus(msg, 'info'); + AppState.authInProgress = true; + } else { + showStatus(`错误: ${data.error || '获取认证链接失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取认证链接'; + } +} + +async function getCredentials() { + if (!AppState.authInProgress) { + showStatus('请先获取认证链接并完成授权', 'error'); + return; + } + + const btn = document.getElementById('getCredsBtn'); + btn.disabled = true; + btn.textContent = '等待OAuth回调中...'; + + try { + showStatus('正在等待OAuth回调,这可能需要一些时间...', 'info'); + + const requestBody = AppState.currentProjectId ? { project_id: AppState.currentProjectId } : {}; + + const response = await fetch('./auth/callback', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify(requestBody) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('credentialsContent').textContent = JSON.stringify(data.credentials, null, 2); + + const msg = data.auto_detected_project + ? `✅ 认证成功!项目ID已自动检测为: ${data.credentials.project_id},文件已保存到: ${data.file_path}` + : `✅ 认证成功!文件已保存到: ${data.file_path}`; + showStatus(msg, 'success'); + + document.getElementById('credentialsSection').classList.remove('hidden'); + AppState.authInProgress = false; + } else if (data.requires_project_selection && data.available_projects) { + let projectOptions = "请选择一个项目:\n\n"; + data.available_projects.forEach((project, index) => { + projectOptions += `${index + 1}. ${project.name} (${project.project_id})\n`; + }); + projectOptions += `\n请输入序号 (1-${data.available_projects.length}):`; + + const selection = prompt(projectOptions); + const projectIndex = parseInt(selection) - 1; + + if (projectIndex >= 0 && projectIndex < data.available_projects.length) { + AppState.currentProjectId = data.available_projects[projectIndex].project_id; + btn.textContent = '重新尝试获取认证文件'; + showStatus(`使用选择的项目重新尝试...`, 'info'); + setTimeout(() => getCredentials(), 1000); + return; + } else { + showStatus('无效的选择,请重新开始认证', 'error'); + } + } else if (data.requires_manual_project_id) { + const userProjectId = prompt('无法自动检测项目ID,请手动输入您的Google Cloud项目ID:'); + if (userProjectId && userProjectId.trim()) { + AppState.currentProjectId = userProjectId.trim(); + btn.textContent = '重新尝试获取认证文件'; + showStatus('使用手动输入的项目ID重新尝试...', 'info'); + setTimeout(() => getCredentials(), 1000); + return; + } else { + showStatus('需要项目ID才能完成认证,请重新开始并输入正确的项目ID', 'error'); + } + } else { + showStatus(`❌ 错误: ${data.error || '获取认证文件失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取认证文件'; + } +} + +// ===================================================================== +// Antigravity 认证相关函数 +// ===================================================================== +async function startAntigravityAuth() { + const btn = document.getElementById('getAntigravityAuthBtn'); + btn.disabled = true; + btn.textContent = '生成认证链接中...'; + + try { + showStatus('正在生成 Antigravity 认证链接...', 'info'); + + const response = await fetch('./auth/start', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ mode: 'antigravity' }) + }); + + const data = await response.json(); + + if (response.ok) { + AppState.antigravityAuthState = data.state; + AppState.antigravityAuthInProgress = true; + + const authUrlLink = document.getElementById('antigravityAuthUrl'); + authUrlLink.href = data.auth_url; + authUrlLink.textContent = data.auth_url; + document.getElementById('antigravityAuthUrlSection').classList.remove('hidden'); + + showStatus('✅ Antigravity 认证链接已生成!请点击链接完成授权', 'success'); + } else { + showStatus(`❌ 错误: ${data.error || '生成认证链接失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取 Antigravity 认证链接'; + } +} + +async function getAntigravityCredentials() { + if (!AppState.antigravityAuthInProgress) { + showStatus('请先获取 Antigravity 认证链接并完成授权', 'error'); + return; + } + + const btn = document.getElementById('getAntigravityCredsBtn'); + btn.disabled = true; + btn.textContent = '等待OAuth回调中...'; + + try { + showStatus('正在等待 Antigravity OAuth回调...', 'info'); + + const response = await fetch('./auth/callback', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ mode: 'antigravity' }) + }); + + const data = await response.json(); + + if (response.ok) { + document.getElementById('antigravityCredsContent').textContent = JSON.stringify(data.credentials, null, 2); + document.getElementById('antigravityCredsSection').classList.remove('hidden'); + AppState.antigravityAuthInProgress = false; + showStatus(`✅ Antigravity 认证成功!文件已保存到: ${data.file_path}`, 'success'); + } else { + showStatus(`❌ 错误: ${data.error || '获取认证文件失败'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + btn.disabled = false; + btn.textContent = '获取 Antigravity 凭证'; + } +} + +function downloadAntigravityCredentials() { + const content = document.getElementById('antigravityCredsContent').textContent; + const blob = new Blob([content], { type: 'application/json' }); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `antigravity-credential-${Date.now()}.json`; + a.click(); + window.URL.revokeObjectURL(url); +} + +// ===================================================================== +// 回调URL处理 +// ===================================================================== +function toggleProjectIdSection() { + const section = document.getElementById('projectIdSection'); + const icon = document.getElementById('projectIdToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(90deg)'; + icon.textContent = '▼'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▶'; + } +} + +function toggleCallbackUrlSection() { + const section = document.getElementById('callbackUrlSection'); + const icon = document.getElementById('callbackUrlToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(180deg)'; + icon.textContent = '▲'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▼'; + } +} + +function toggleAntigravityCallbackUrlSection() { + const section = document.getElementById('antigravityCallbackUrlSection'); + const icon = document.getElementById('antigravityCallbackUrlToggleIcon'); + + if (section.style.display === 'none') { + section.style.display = 'block'; + icon.style.transform = 'rotate(180deg)'; + icon.textContent = '▲'; + } else { + section.style.display = 'none'; + icon.style.transform = 'rotate(0deg)'; + icon.textContent = '▼'; + } +} + +async function processCallbackUrl() { + const callbackUrl = document.getElementById('callbackUrlInput').value.trim(); + + if (!callbackUrl) { + showStatus('请输入回调URL', 'error'); + return; + } + + if (!callbackUrl.startsWith('http://') && !callbackUrl.startsWith('https://')) { + showStatus('请输入有效的URL(以http://或https://开头)', 'error'); + return; + } + + if (!callbackUrl.includes('code=') || !callbackUrl.includes('state=')) { + showStatus('❌ 这不是有效的回调URL!请确保:\n1. 已完成Google OAuth授权\n2. 复制的是浏览器地址栏的完整URL\n3. URL包含code和state参数', 'error'); + return; + } + + showStatus('正在从回调URL获取凭证...', 'info'); + + try { + const projectId = document.getElementById('projectId')?.value.trim() || null; + + const response = await fetch('./auth/callback-url', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ callback_url: callbackUrl, project_id: projectId }) + }); + + const result = await response.json(); + + if (result.credentials) { + showStatus(result.message || '从回调URL获取凭证成功!', 'success'); + document.getElementById('credentialsContent').innerHTML = '
' + JSON.stringify(result.credentials, null, 2) + '
'; + document.getElementById('credentialsSection').classList.remove('hidden'); + } else if (result.requires_manual_project_id) { + showStatus('需要手动指定项目ID,请在高级选项中填入Google Cloud项目ID后重试', 'error'); + } else if (result.requires_project_selection) { + let msg = '
可用项目:
'; + result.available_projects.forEach(p => { + msg += `• ${p.name} (ID: ${p.project_id})
`; + }); + showStatus('检测到多个项目,请在高级选项中指定项目ID:' + msg, 'error'); + } else { + showStatus(result.error || '从回调URL获取凭证失败', 'error'); + } + + document.getElementById('callbackUrlInput').value = ''; + } catch (error) { + showStatus(`从回调URL获取凭证失败: ${error.message}`, 'error'); + } +} + +async function processAntigravityCallbackUrl() { + const callbackUrl = document.getElementById('antigravityCallbackUrlInput').value.trim(); + + if (!callbackUrl) { + showStatus('请输入回调URL', 'error'); + return; + } + + if (!callbackUrl.startsWith('http://') && !callbackUrl.startsWith('https://')) { + showStatus('请输入有效的URL(以http://或https://开头)', 'error'); + return; + } + + if (!callbackUrl.includes('code=') || !callbackUrl.includes('state=')) { + showStatus('❌ 这不是有效的回调URL!请确保包含code和state参数', 'error'); + return; + } + + showStatus('正在从回调URL获取 Antigravity 凭证...', 'info'); + + try { + const response = await fetch('./auth/callback-url', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ callback_url: callbackUrl, mode: 'antigravity' }) + }); + + const result = await response.json(); + + if (result.credentials) { + showStatus(result.message || '从回调URL获取 Antigravity 凭证成功!', 'success'); + document.getElementById('antigravityCredsContent').textContent = JSON.stringify(result.credentials, null, 2); + document.getElementById('antigravityCredsSection').classList.remove('hidden'); + } else { + showStatus(result.error || '从回调URL获取 Antigravity 凭证失败', 'error'); + } + + document.getElementById('antigravityCallbackUrlInput').value = ''; + } catch (error) { + showStatus(`从回调URL获取 Antigravity 凭证失败: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 全局兼容函数(供HTML调用) +// ===================================================================== +// 普通凭证管理 +function refreshCredsStatus() { AppState.creds.refresh(); } +function applyStatusFilter() { AppState.creds.applyStatusFilter(); } +function changePage(direction) { AppState.creds.changePage(direction); } +function changePageSize() { AppState.creds.changePageSize(); } +function toggleFileSelection(filename) { + if (AppState.creds.selectedFiles.has(filename)) { + AppState.creds.selectedFiles.delete(filename); + } else { + AppState.creds.selectedFiles.add(filename); + } + AppState.creds.updateBatchControls(); +} +function toggleSelectAll() { + const checkbox = document.getElementById('selectAllCheckbox'); + const checkboxes = document.querySelectorAll('.file-checkbox'); + + if (checkbox.checked) { + checkboxes.forEach(cb => AppState.creds.selectedFiles.add(cb.getAttribute('data-filename'))); + } else { + AppState.creds.selectedFiles.clear(); + } + checkboxes.forEach(cb => cb.checked = checkbox.checked); + AppState.creds.updateBatchControls(); +} +function batchAction(action) { AppState.creds.batchAction(action); } +function downloadCred(filename) { + fetch(`./creds/download/${filename}`, { headers: { 'Authorization': `Bearer ${AppState.authToken}` } }) + .then(r => r.ok ? r.blob() : Promise.reject()) + .then(blob => { + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + showStatus(`已下载文件: ${filename}`, 'success'); + }) + .catch(() => showStatus(`下载失败: ${filename}`, 'error')); +} +async function downloadAllCreds() { + try { + const response = await fetch('./creds/download-all', { + headers: { 'Authorization': `Bearer ${AppState.authToken}` } + }); + if (response.ok) { + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'credentials.zip'; + a.click(); + window.URL.revokeObjectURL(url); + showStatus('已下载所有凭证文件', 'success'); + } + } catch (error) { + showStatus(`打包下载失败: ${error.message}`, 'error'); + } +} + +// Antigravity凭证管理 +function refreshAntigravityCredsList() { AppState.antigravityCreds.refresh(); } +function applyAntigravityStatusFilter() { AppState.antigravityCreds.applyStatusFilter(); } +function changeAntigravityPage(direction) { AppState.antigravityCreds.changePage(direction); } +function changeAntigravityPageSize() { AppState.antigravityCreds.changePageSize(); } +function toggleAntigravityFileSelection(filename) { + if (AppState.antigravityCreds.selectedFiles.has(filename)) { + AppState.antigravityCreds.selectedFiles.delete(filename); + } else { + AppState.antigravityCreds.selectedFiles.add(filename); + } + AppState.antigravityCreds.updateBatchControls(); +} +function toggleSelectAllAntigravity() { + const checkbox = document.getElementById('selectAllAntigravityCheckbox'); + const checkboxes = document.querySelectorAll('.antigravityFile-checkbox'); + + if (checkbox.checked) { + checkboxes.forEach(cb => AppState.antigravityCreds.selectedFiles.add(cb.getAttribute('data-filename'))); + } else { + AppState.antigravityCreds.selectedFiles.clear(); + } + checkboxes.forEach(cb => cb.checked = checkbox.checked); + AppState.antigravityCreds.updateBatchControls(); +} +function batchAntigravityAction(action) { AppState.antigravityCreds.batchAction(action); } +function downloadAntigravityCred(filename) { + fetch(`./creds/download/${filename}?mode=antigravity`, { headers: getAuthHeaders() }) + .then(r => r.ok ? r.blob() : Promise.reject()) + .then(blob => { + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + showStatus(`✅ 已下载: ${filename}`, 'success'); + }) + .catch(() => showStatus(`下载失败: ${filename}`, 'error')); +} +function deleteAntigravityCred(filename) { + if (confirm(`确定要删除 ${filename} 吗?`)) { + AppState.antigravityCreds.action(filename, 'delete'); + } +} +async function downloadAllAntigravityCreds() { + try { + const response = await fetch('./creds/download-all?mode=antigravity', { headers: getAuthHeaders() }); + if (response.ok) { + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `antigravity_credentials_${Date.now()}.zip`; + a.click(); + window.URL.revokeObjectURL(url); + showStatus('✅ 所有Antigravity凭证已打包下载', 'success'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// 文件上传 +function handleFileSelect(event) { AppState.uploadFiles.handleFileSelect(event); } +function removeFile(index) { AppState.uploadFiles.removeFile(index); } +function clearFiles() { AppState.uploadFiles.clearFiles(); } +function uploadFiles() { AppState.uploadFiles.upload(); } + +function handleAntigravityFileSelect(event) { AppState.antigravityUploadFiles.handleFileSelect(event); } +function handleAntigravityFileDrop(event) { + event.preventDefault(); + event.currentTarget.style.borderColor = '#007bff'; + event.currentTarget.style.backgroundColor = '#f8f9fa'; + AppState.antigravityUploadFiles.addFiles(Array.from(event.dataTransfer.files)); +} +function removeAntigravityFile(index) { AppState.antigravityUploadFiles.removeFile(index); } +function clearAntigravityFiles() { AppState.antigravityUploadFiles.clearFiles(); } +function uploadAntigravityFiles() { AppState.antigravityUploadFiles.upload(); } + +// 邮箱相关 +// 辅助函数:根据文件名更新卡片中的邮箱显示 +function updateEmailDisplay(filename, email, managerType = 'normal') { + // 查找对应的凭证卡片 + const containerId = managerType === 'antigravity' ? 'antigravityCredsList' : 'credsList'; + const container = document.getElementById(containerId); + if (!container) return false; + + // 通过 data-filename 找到对应的复选框,再找到其父卡片 + const checkbox = container.querySelector(`input[data-filename="${filename}"]`); + if (!checkbox) return false; + + // 找到对应的 cred-card 元素 + const card = checkbox.closest('.cred-card'); + if (!card) return false; + + // 找到邮箱显示元素 + const emailDiv = card.querySelector('.cred-email'); + if (emailDiv) { + emailDiv.textContent = email; + emailDiv.style.color = '#666'; + emailDiv.style.fontStyle = 'normal'; + return true; + } + return false; +} + +async function fetchUserEmail(filename) { + try { + showStatus('正在获取用户邮箱...', 'info'); + const response = await fetch(`./creds/fetch-email/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok && data.user_email) { + showStatus(`成功获取邮箱: ${data.user_email}`, 'success'); + // 直接更新卡片中的邮箱显示,不刷新整个列表 + updateEmailDisplay(filename, data.user_email, 'normal'); + } else { + showStatus(data.message || '无法获取用户邮箱', 'error'); + } + } catch (error) { + showStatus(`获取邮箱失败: ${error.message}`, 'error'); + } +} + +async function fetchAntigravityUserEmail(filename) { + try { + showStatus('正在获取用户邮箱...', 'info'); + const response = await fetch(`./creds/fetch-email/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok && data.user_email) { + showStatus(`成功获取邮箱: ${data.user_email}`, 'success'); + // 直接更新卡片中的邮箱显示,不刷新整个列表 + updateEmailDisplay(filename, data.user_email, 'antigravity'); + } else { + showStatus(data.message || '无法获取用户邮箱', 'error'); + } + } catch (error) { + showStatus(`获取邮箱失败: ${error.message}`, 'error'); + } +} + +async function verifyProjectId(filename) { + try { + // 显示加载状态 + showStatus('🔍 正在检验Project ID,请稍候...', 'info'); + + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时显示绿色成功消息和Project ID + const tierLine = data.subscription_tier ? `\nTier: ${data.subscription_tier}` : ''; + const creditLine = data.credit_amount !== undefined && data.credit_amount !== null + ? `\n积分: ${data.credit_amount}` + : ''; + const successMsg = `✅ 检验成功!\n文件: ${filename}\nProject ID: ${data.project_id}${tierLine}${creditLine}\n\n${data.message}`; + showStatus(successMsg.replace(/\n/g, '
'), 'success'); + + // 弹出成功提示 + showMessageModal('检验成功', `✅ 检验成功!\n\n文件: ${filename}\nProject ID: ${data.project_id}${tierLine}${creditLine}\n\n${data.message}`, 'success'); + + await AppState.creds.refresh(); + } else { + // 失败时显示红色错误消息 + const errorMsg = data.message || '检验失败'; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('检验失败', `❌ 检验失败\n\n${errorMsg}`, 'error'); + } + } catch (error) { + const errorMsg = `检验失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('检验失败', `❌ ${errorMsg}`, 'error'); + } +} + +async function verifyAntigravityProjectId(filename) { + try { + // 显示加载状态 + showStatus('🔍 正在检验Antigravity Project ID,请稍候...', 'info'); + + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时显示绿色成功消息和Project ID + const tierLine = data.subscription_tier ? `\nTier: ${data.subscription_tier}` : ''; + const creditLine = data.credit_amount !== undefined && data.credit_amount !== null + ? `\n积分: ${data.credit_amount}` + : ''; + const successMsg = `✅ 检验成功!\n文件: ${filename}\nProject ID: ${data.project_id}${tierLine}${creditLine}\n\n${data.message}`; + showStatus(successMsg.replace(/\n/g, '
'), 'success'); + + // 弹出成功提示 + showMessageModal('检验成功', `✅ Antigravity检验成功!\n\n文件: ${filename}\nProject ID: ${data.project_id}${tierLine}${creditLine}\n\n${data.message}`, 'success'); + + await AppState.antigravityCreds.refresh(); + } else { + // 失败时显示红色错误消息 + const errorMsg = data.message || '检验失败'; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('检验失败', `❌ 检验失败\n\n${errorMsg}`, 'error'); + } + } catch (error) { + const errorMsg = `检验失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('检验失败', `❌ ${errorMsg}`, 'error'); + } +} + +async function testCredential(filename) { + try { + // 显示加载状态 + showStatus('🧪 正在测试凭证,请稍候...', 'info'); + + const response = await fetch(`./creds/test/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + + // 解析JSON响应 + const data = await response.json(); + + if (response.status === 200) { + // 凭证可用 + const successMsg = `✅ 测试成功!\n文件: ${filename}\n状态: ${data.message || '凭证可用'} (${data.status_code || 200})`; + showStatus('✅ 测试成功!', 'success'); + showMessageModal('测试成功', successMsg, 'success'); + await AppState.creds.refresh(); + } + else { + // 其他错误 - 显示完整错误信息 + let errorDetails = `❌ 测试失败\n文件: ${filename}\n`; + + // 如果有完整的错误响应,添加到详情中 + if (data.error) { + try { + // 尝试格式化JSON错误 + const errorObj = JSON.parse(data.error); + errorDetails += `\n错误详情:\n${JSON.stringify(errorObj, null, 2)}`; + } catch { + // 如果不是JSON,直接显示文本 + errorDetails += `\n错误详情:\n${data.error}`; + } + } else { + errorDetails += `错误码: ${data.status_code || response.status}`; + } + + showStatus(`❌ 测试失败 - ${data.message || '错误码: ' + (data.status_code || response.status)}`, 'error'); + showMessageModal('测试失败', errorDetails, 'error'); + } + } catch (error) { + const errorMsg = `测试失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('测试失败', `❌ ${errorMsg}`, 'error'); + } +} + +async function testAntigravityCredential(filename) { + try { + // 显示加载状态 + showStatus('🧪 正在测试Antigravity凭证,请稍候...', 'info'); + + const response = await fetch(`./creds/test/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + + // 解析JSON响应 + const data = await response.json(); + + if (response.status === 200) { + // 凭证可用 + const successMsg = `✅ 测试成功!\n文件: ${filename}\n状态: ${data.message || 'Antigravity凭证可用'} (${data.status_code || 200})`; + showStatus('✅ 测试成功!', 'success'); + showMessageModal('测试成功', successMsg, 'success'); + await AppState.antigravityCreds.refresh(); + } + else { + // 其他错误 - 显示完整错误信息 + let errorDetails = `❌ 测试失败\n文件: ${filename}\n`; + + // 如果有完整的错误响应,添加到详情中 + if (data.error) { + try { + // 尝试格式化JSON错误 + const errorObj = JSON.parse(data.error); + errorDetails += `\n错误详情:\n${JSON.stringify(errorObj, null, 2)}`; + } catch { + // 如果不是JSON,直接显示文本 + errorDetails += `\n错误详情:\n${data.error}`; + } + } else { + errorDetails += `错误码: ${data.status_code || response.status}`; + } + + showStatus(`❌ 测试失败 - ${data.message || '错误码: ' + (data.status_code || response.status)}`, 'error'); + showMessageModal('测试失败', errorDetails, 'error'); + } + } catch (error) { + const errorMsg = `测试失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('测试失败', `❌ ${errorMsg}`, 'error'); + } +} + +async function configurePreviewChannel(filename) { + try { + // 显示加载状态 + showStatus('🔧 正在配置Preview通道,请稍候...', 'info'); + + const response = await fetch(`./creds/configure-preview/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok && data.success) { + // 配置成功 + const successMsg = `✅ 配置成功!\n文件: ${filename}\n状态: ${data.message}`; + showStatus(successMsg.replace(/\n/g, '
'), 'success'); + showMessageModal('Preview通道配置成功', `✅ Preview通道配置成功!\n\n文件: ${filename}\n\n${data.message}\n\nSetting ID: ${data.setting_id || 'N/A'}\nBinding ID: ${data.binding_id || 'N/A'}`, 'success'); + + // 刷新凭证列表 + await AppState.creds.refresh(); + } else { + // 配置失败 + const errorMsg = data.message || '配置失败'; + const errorDetail = data.error || ''; + const step = data.step || ''; + + let alertMsg = `❌ Preview通道配置失败\n\n文件: ${filename}\n\n${errorMsg}`; + if (step) { + alertMsg += `\n失败步骤: ${step}`; + } + if (errorDetail) { + alertMsg += `\n\n错误详情: ${errorDetail}`; + } + + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('Preview通道配置失败', alertMsg, 'error'); + } + } catch (error) { + const errorMsg = `配置Preview通道失败: ${error.message}`; + showStatus(`❌ ${errorMsg}`, 'error'); + showMessageModal('配置Preview通道失败', `❌ ${errorMsg}`, 'error'); + } +} + +async function toggleAntigravityQuotaDetails(pathId) { + const quotaDetails = document.getElementById('quota-' + pathId); + if (!quotaDetails) return; + + // 切换显示状态 + const isShowing = quotaDetails.style.display === 'block'; + + if (isShowing) { + // 收起 + quotaDetails.style.display = 'none'; + } else { + // 展开 + quotaDetails.style.display = 'block'; + + const contentDiv = quotaDetails.querySelector('.cred-quota-content'); + const filename = contentDiv.getAttribute('data-filename'); + + // 每次展开都重新加载数据 + if (filename) { + contentDiv.innerHTML = '
📊 正在加载额度信息...
'; + + try { + const response = await fetch(`./creds/quota/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'GET', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + // 成功时渲染美化的额度信息 + const models = data.models || {}; + + if (Object.keys(models).length === 0) { + contentDiv.innerHTML = ` +
+
📊
+
暂无额度信息
+
+ `; + } else { + let quotaHTML = ` +
+

+ 📊 + 额度信息详情 +

+
文件: ${filename}
+
+
+ `; + + for (const [modelName, quotaData] of Object.entries(models)) { + // 后端返回的是剩余比例 (0-1),不是绝对数量 + const remainingFraction = quotaData.remaining || 0; + const resetTime = quotaData.resetTime || 'N/A'; + + // 计算已使用百分比(1 - 剩余比例) + const usedPercentage = Math.round((1 - remainingFraction) * 100); + const remainingPercentage = Math.round(remainingFraction * 100); + + // 根据使用情况选择颜色 + let percentageColor = '#28a745'; // 绿色:使用少 + if (usedPercentage >= 90) percentageColor = '#dc3545'; // 红色:使用多 + else if (usedPercentage >= 70) percentageColor = '#ffc107'; // 黄色:使用较多 + else if (usedPercentage >= 50) percentageColor = '#17a2b8'; // 蓝色:使用中等 + + quotaHTML += ` +
+
+
+ ${modelName} +
+
+ ${remainingPercentage}% +
+
+
+
+
+
+ ${resetTime !== 'N/A' ? '🔄 ' + resetTime : ''} +
+
+ `; + } + + quotaHTML += '
'; + contentDiv.innerHTML = quotaHTML; + } + + showStatus('✅ 成功加载额度信息', 'success'); + } else { + // 失败时显示错误 + const errorMsg = data.error || '获取额度信息失败'; + contentDiv.innerHTML = ` +
+
+
获取额度信息失败
+
${errorMsg}
+
+ `; + showStatus(`❌ ${errorMsg}`, 'error'); + } + } catch (error) { + contentDiv.innerHTML = ` +
+
+
网络错误
+
${error.message}
+
+ `; + showStatus(`❌ 获取额度信息失败: ${error.message}`, 'error'); + } + } + } +} + +// ===================================================================== +// 查看报错详情 +// ===================================================================== +async function toggleErrorDetails(pathId) { + await toggleErrorDetailsCommon(pathId, AppState.creds); +} + +async function toggleAntigravityErrorDetails(pathId) { + await toggleErrorDetailsCommon(pathId, AppState.antigravityCreds); +} + +async function toggleErrorDetailsCommon(pathId, manager) { + const errorDetails = document.getElementById('errors-' + pathId); + if (!errorDetails) return; + + // 切换显示状态 + const isShowing = errorDetails.classList.toggle('show'); + + if (isShowing) { + const contentDiv = errorDetails.querySelector('.cred-content'); + const filename = contentDiv.getAttribute('data-filename'); + + // 每次展开都重新加载数据 + if (filename) { + contentDiv.innerHTML = '
⏳ 正在加载报错信息...
'; + + try { + const modeParam = manager.type === 'antigravity' ? 'mode=antigravity' : 'mode=geminicli'; + const response = await fetch(`./creds/errors/${encodeURIComponent(filename)}?${modeParam}`, { + method: 'GET', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok) { + const errorCodes = data.error_codes || []; + const errorMessages = data.error_messages || {}; + + if (errorCodes.length === 0) { + contentDiv.innerHTML = ` +
+
+
无报错记录
+
该凭证运行正常
+
+ `; + } else { + let errorHTML = ''; + + // 遍历所有错误码,从 errorMessages 对象中获取对应消息 + errorCodes.forEach((errorCode) => { + const messageStr = errorMessages[errorCode] || '无详细信息'; + + // 提取核心错误消息和详细信息 + let displayMsg = messageStr; + let detailsHtml = ''; + + try { + // 尝试解析 JSON 格式的 message + const parsedMsg = JSON.parse(messageStr); + if (parsedMsg.error) { + // 显示核心错误信息 + if (parsedMsg.error.message) { + displayMsg = parsedMsg.error.message; + } + + // 如果有 details 字段,也显示出来 + if (parsedMsg.error.details && Array.isArray(parsedMsg.error.details)) { + detailsHtml = '
'; + detailsHtml += '
详细信息:
'; + + parsedMsg.error.details.forEach((detail, idx) => { + detailsHtml += '
'; + + // 显示 @type + if (detail['@type']) { + const highlightedType = highlightHttpLinks(escapeHtml(detail['@type'])); + detailsHtml += `
类型: ${highlightedType}
`; + } + + // 显示 reason + if (detail.reason) { + detailsHtml += `
原因: ${escapeHtml(detail.reason)}
`; + } + + // 显示 metadata(如 quotaResetTimeStamp) + if (detail.metadata) { + detailsHtml += '
'; + for (const [key, value] of Object.entries(detail.metadata)) { + const highlightedValue = highlightHttpLinks(escapeHtml(String(value))); + detailsHtml += `
${escapeHtml(key)}: ${highlightedValue}
`; + } + detailsHtml += '
'; + } + + detailsHtml += '
'; + }); + + detailsHtml += '
'; + } + + // 如果有 status 字段,也显示 + if (parsedMsg.error.status) { + if (!detailsHtml) { + detailsHtml = '
'; + } + detailsHtml += `
状态: ${escapeHtml(parsedMsg.error.status)}
`; + if (!parsedMsg.error.details) { + detailsHtml += '
'; + } + } + } + } catch (e) { + // 如果不是 JSON 格式,直接使用原始消息 + } + + // 对消息中的HTTP链接进行高亮处理 + const highlightedMsg = highlightHttpLinks(escapeHtml(displayMsg)); + + errorHTML += ` +
+
错误码: ${errorCode}
+
+ ${highlightedMsg} +
+ ${detailsHtml} +
+ `; + }); + + contentDiv.innerHTML = errorHTML; + } + + showStatus('✅ 成功加载报错信息', 'success'); + } else { + // 失败时显示错误 + const errorMsg = data.detail || data.error || '获取报错信息失败'; + contentDiv.innerHTML = ` +
+
+
加载失败
+
${errorMsg}
+
+ `; + showStatus(`❌ 获取报错信息失败: ${errorMsg}`, 'error'); + } + } catch (error) { + contentDiv.innerHTML = ` +
+
+
网络错误
+
${error.message}
+
+ `; + showStatus(`❌ 获取报错信息失败: ${error.message}`, 'error'); + } + } + } +} + +// HTML转义函数 +function escapeHtml(text) { + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} + +// 高亮HTTP链接函数 +function highlightHttpLinks(text) { + // 匹配 http:// 或 https:// 开头的URL + const urlRegex = /(https?:\/\/[^\s<>"]+)/gi; + return text.replace(urlRegex, function(url) { + return `${url}`; + }); +} + +async function batchVerifyProjectIds() { + const selectedFiles = Array.from(AppState.creds.selectedFiles); + if (selectedFiles.length === 0) { + showStatus('❌ 请先选择要检验的凭证', 'error'); + showMessageModal('提示', '请先选择要检验的凭证', 'error'); + return; + } + + if (!confirm(`确定要批量检验 ${selectedFiles.length} 个凭证的Project ID吗?\n\n将并行检验以加快速度。`)) { + return; + } + + showStatus(`🔍 正在并行检验 ${selectedFiles.length} 个凭证,请稍候...`, 'info'); + + // 并行执行所有检验请求 + const promises = selectedFiles.map(async (filename) => { + try { + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + return { + success: true, + filename, + projectId: data.project_id, + creditAmount: data.credit_amount, + message: data.message + }; + } else { + return { success: false, filename, error: data.message || '失败' }; + } + } catch (error) { + return { success: false, filename, error: error.message }; + } + }); + + // 等待所有请求完成 + const results = await Promise.all(promises); + + // 统计结果 + let successCount = 0; + let failCount = 0; + const resultMessages = []; + + results.forEach(result => { + if (result.success) { + successCount++; + const creditSuffix = result.creditAmount !== undefined && result.creditAmount !== null + ? ` (积分: ${result.creditAmount})` + : ''; + resultMessages.push(`✅ ${result.filename}: ${result.projectId}${creditSuffix}`); + } else { + failCount++; + resultMessages.push(`❌ ${result.filename}: ${result.error}`); + } + }); + + await AppState.creds.refresh(); + + const summary = `批量检验完成!\n\n成功: ${successCount} 个\n失败: ${failCount} 个\n总计: ${selectedFiles.length} 个\n\n详细结果:\n${resultMessages.join('\n')}`; + + if (failCount === 0) { + showStatus(`✅ 全部检验成功!成功检验 ${successCount}/${selectedFiles.length} 个凭证`, 'success'); + showMessageModal('批量检验完成', summary, 'success'); + } else if (successCount === 0) { + showStatus(`❌ 全部检验失败!失败 ${failCount}/${selectedFiles.length} 个凭证`, 'error'); + showMessageModal('批量检验完成', summary, 'error'); + } else { + showStatus(`⚠️ 批量检验完成:成功 ${successCount}/${selectedFiles.length} 个,失败 ${failCount} 个`, 'info'); + showMessageModal('批量检验完成', summary, 'info'); + } + + console.log(summary); +} + +async function batchVerifyAntigravityProjectIds() { + const selectedFiles = Array.from(AppState.antigravityCreds.selectedFiles); + if (selectedFiles.length === 0) { + showStatus('❌ 请先选择要检验的Antigravity凭证', 'error'); + showMessageModal('提示', '请先选择要检验的Antigravity凭证', 'error'); + return; + } + + if (!confirm(`确定要批量检验 ${selectedFiles.length} 个Antigravity凭证的Project ID吗?\n\n将并行检验以加快速度。`)) { + return; + } + + showStatus(`🔍 正在并行检验 ${selectedFiles.length} 个Antigravity凭证,请稍候...`, 'info'); + + // 并行执行所有检验请求 + const promises = selectedFiles.map(async (filename) => { + try { + const response = await fetch(`./creds/verify-project/${encodeURIComponent(filename)}?mode=antigravity`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + return { + success: true, + filename, + projectId: data.project_id, + creditAmount: data.credit_amount, + message: data.message + }; + } else { + return { success: false, filename, error: data.message || '失败' }; + } + } catch (error) { + return { success: false, filename, error: error.message }; + } + }); + + // 等待所有请求完成 + const results = await Promise.all(promises); + + // 统计结果 + let successCount = 0; + let failCount = 0; + const resultMessages = []; + + results.forEach(result => { + if (result.success) { + successCount++; + const creditSuffix = result.creditAmount !== undefined && result.creditAmount !== null + ? ` (积分: ${result.creditAmount})` + : ''; + resultMessages.push(`✅ ${result.filename}: ${result.projectId}${creditSuffix}`); + } else { + failCount++; + resultMessages.push(`❌ ${result.filename}: ${result.error}`); + } + }); + + await AppState.antigravityCreds.refresh(); + + const summary = `Antigravity批量检验完成!\n\n成功: ${successCount} 个\n失败: ${failCount} 个\n总计: ${selectedFiles.length} 个\n\n详细结果:\n${resultMessages.join('\n')}`; + + if (failCount === 0) { + showStatus(`✅ 全部检验成功!成功检验 ${successCount}/${selectedFiles.length} 个Antigravity凭证`, 'success'); + showMessageModal('Antigravity批量检验完成', summary, 'success'); + } else if (successCount === 0) { + showStatus(`❌ 全部检验失败!失败 ${failCount}/${selectedFiles.length} 个Antigravity凭证`, 'error'); + showMessageModal('Antigravity批量检验完成', summary, 'error'); + } else { + showStatus(`⚠️ 批量检验完成:成功 ${successCount}/${selectedFiles.length} 个,失败 ${failCount} 个`, 'info'); + showMessageModal('Antigravity批量检验完成', summary, 'info'); + } + + console.log(summary); +} + +async function batchConfigurePreview() { + const selectedFiles = Array.from(AppState.creds.selectedFiles); + if (selectedFiles.length === 0) { + showStatus('❌ 请先选择要配置Preview的凭证', 'error'); + showMessageModal('提示', '请先选择要配置Preview的凭证', 'error'); + return; + } + + if (!confirm(`确定要为 ${selectedFiles.length} 个凭证批量设置Preview通道吗?\n\n将并行配置以加快速度。`)) { + return; + } + + showStatus(`🔧 正在为 ${selectedFiles.length} 个凭证配置Preview通道,请稍候...`, 'info'); + + // 并行执行所有配置请求 + const promises = selectedFiles.map(async (filename) => { + try { + const response = await fetch(`./creds/configure-preview/${encodeURIComponent(filename)}`, { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + + if (response.ok && data.success) { + return { + success: true, + filename, + message: data.message, + setting_id: data.setting_id, + binding_id: data.binding_id + }; + } else { + return { + success: false, + filename, + error: data.message || '配置失败', + step: data.step, + errorDetail: data.error + }; + } + } catch (error) { + return { success: false, filename, error: error.message }; + } + }); + + // 等待所有请求完成 + const results = await Promise.all(promises); + + // 统计结果 + let successCount = 0; + let failCount = 0; + const resultMessages = []; + + results.forEach(result => { + if (result.success) { + successCount++; + resultMessages.push(`✅ ${result.filename}: ${result.message || '配置成功'}`); + } else { + failCount++; + const errorMsg = result.step ? `${result.error} (步骤: ${result.step})` : result.error; + resultMessages.push(`❌ ${result.filename}: ${errorMsg}`); + } + }); + + await AppState.creds.refresh(); + + const summary = `批量配置Preview通道完成!\n\n成功: ${successCount} 个\n失败: ${failCount} 个\n总计: ${selectedFiles.length} 个\n\n详细结果:\n${resultMessages.join('\n')}`; + + if (failCount === 0) { + showStatus(`✅ 全部配置成功!成功配置 ${successCount}/${selectedFiles.length} 个凭证的Preview通道`, 'success'); + showMessageModal('批量配置Preview通道完成', summary, 'success'); + } else if (successCount === 0) { + showStatus(`❌ 全部配置失败!失败 ${failCount}/${selectedFiles.length} 个凭证`, 'error'); + showMessageModal('批量配置Preview通道完成', summary, 'error'); + } else { + showStatus(`⚠️ 批量配置完成:成功 ${successCount}/${selectedFiles.length} 个,失败 ${failCount} 个`, 'info'); + showMessageModal('批量配置Preview通道完成', summary, 'info'); + } + + console.log(summary); +} + + +async function refreshAllEmails() { + if (!confirm('确定要刷新所有凭证的用户邮箱吗?这可能需要一些时间。')) return; + + try { + showStatus('正在刷新所有用户邮箱...', 'info'); + const response = await fetch('./creds/refresh-all-emails', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + showStatus(`邮箱刷新完成:成功获取 ${data.success_count}/${data.total_count} 个邮箱地址`, 'success'); + await AppState.creds.refresh(); + } else { + showStatus(data.message || '邮箱刷新失败', 'error'); + } + } catch (error) { + showStatus(`邮箱刷新网络错误: ${error.message}`, 'error'); + } +} + +async function refreshAllAntigravityEmails() { + if (!confirm('确定要刷新所有Antigravity凭证的用户邮箱吗?这可能需要一些时间。')) return; + + try { + showStatus('正在刷新所有用户邮箱...', 'info'); + const response = await fetch('./creds/refresh-all-emails?mode=antigravity', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + showStatus(`邮箱刷新完成:成功获取 ${data.success_count}/${data.total_count} 个邮箱地址`, 'success'); + await AppState.antigravityCreds.refresh(); + } else { + showStatus(data.message || '邮箱刷新失败', 'error'); + } + } catch (error) { + showStatus(`邮箱刷新网络错误: ${error.message}`, 'error'); + } +} + +async function deduplicateByEmail() { + if (!confirm('确定要对凭证进行凭证一键去重吗?\n\n相同邮箱的凭证只保留一个,其他将被删除。\n此操作不可撤销!')) return; + + try { + showStatus('正在进行凭证一键去重...', 'info'); + const response = await fetch('./creds/deduplicate-by-email', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + const msg = `去重完成:删除 ${data.deleted_count} 个重复凭证,保留 ${data.kept_count} 个凭证(${data.unique_emails_count} 个唯一邮箱)`; + showStatus(msg, 'success'); + await AppState.creds.refresh(); + + // 显示详细信息 + if (data.duplicate_groups && data.duplicate_groups.length > 0) { + let details = '去重详情:\n\n'; + data.duplicate_groups.forEach(group => { + details += `邮箱: ${group.email}\n保留: ${group.kept_file}\n删除: ${group.deleted_files.join(', ')}\n\n`; + }); + console.log(details); + } + } else { + showStatus(data.message || '去重失败', 'error'); + } + } catch (error) { + showStatus(`去重网络错误: ${error.message}`, 'error'); + } +} + +async function deduplicateAntigravityByEmail() { + if (!confirm('确定要对Antigravity凭证进行凭证一键去重吗?\n\n相同邮箱的凭证只保留一个,其他将被删除。\n此操作不可撤销!')) return; + + try { + showStatus('正在进行凭证一键去重...', 'info'); + const response = await fetch('./creds/deduplicate-by-email?mode=antigravity', { + method: 'POST', + headers: getAuthHeaders() + }); + const data = await response.json(); + if (response.ok) { + const msg = `去重完成:删除 ${data.deleted_count} 个重复凭证,保留 ${data.kept_count} 个凭证(${data.unique_emails_count} 个唯一邮箱)`; + showStatus(msg, 'success'); + await AppState.antigravityCreds.refresh(); + + // 显示详细信息 + if (data.duplicate_groups && data.duplicate_groups.length > 0) { + let details = '去重详情:\n\n'; + data.duplicate_groups.forEach(group => { + details += `邮箱: ${group.email}\n保留: ${group.kept_file}\n删除: ${group.deleted_files.join(', ')}\n\n`; + }); + console.log(details); + } + } else { + showStatus(data.message || '去重失败', 'error'); + } + } catch (error) { + showStatus(`去重网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// WebSocket日志相关 +// ===================================================================== +function connectWebSocket() { + if (AppState.logWebSocket && AppState.logWebSocket.readyState === WebSocket.OPEN) { + showStatus('WebSocket已经连接', 'info'); + return; + } + + try { + const wsPath = new URL('./logs/stream', window.location.href).href; + const wsUrl = wsPath.replace(/^http/, 'ws'); + + // 添加 token 认证参数 + const wsUrlWithAuth = `${wsUrl}?token=${encodeURIComponent(AppState.authToken)}`; + + document.getElementById('connectionStatusText').textContent = '连接中...'; + document.getElementById('logConnectionStatus').className = 'status info'; + + AppState.logWebSocket = new WebSocket(wsUrlWithAuth); + + AppState.logWebSocket.onopen = () => { + document.getElementById('connectionStatusText').textContent = '已连接'; + document.getElementById('logConnectionStatus').className = 'status success'; + showStatus('日志流连接成功', 'success'); + clearLogsDisplay(); + }; + + AppState.logWebSocket.onmessage = (event) => { + const logLine = event.data; + if (logLine.trim()) { + AppState.allLogs.push(logLine); + if (AppState.allLogs.length > 1000) { + AppState.allLogs = AppState.allLogs.slice(-1000); + } + filterLogs(); + if (document.getElementById('autoScroll').checked) { + const logContainer = document.getElementById('logContainer'); + logContainer.scrollTop = logContainer.scrollHeight; + } + } + }; + + AppState.logWebSocket.onclose = () => { + document.getElementById('connectionStatusText').textContent = '连接断开'; + document.getElementById('logConnectionStatus').className = 'status error'; + showStatus('日志流连接断开', 'info'); + }; + + AppState.logWebSocket.onerror = (error) => { + document.getElementById('connectionStatusText').textContent = '连接错误'; + document.getElementById('logConnectionStatus').className = 'status error'; + showStatus('日志流连接错误: ' + error, 'error'); + }; + } catch (error) { + showStatus('创建WebSocket连接失败: ' + error.message, 'error'); + document.getElementById('connectionStatusText').textContent = '连接失败'; + document.getElementById('logConnectionStatus').className = 'status error'; + } +} + +function disconnectWebSocket() { + if (AppState.logWebSocket) { + AppState.logWebSocket.close(); + AppState.logWebSocket = null; + document.getElementById('connectionStatusText').textContent = '未连接'; + document.getElementById('logConnectionStatus').className = 'status info'; + showStatus('日志流连接已断开', 'info'); + } +} + +function clearLogsDisplay() { + AppState.allLogs = []; + AppState.filteredLogs = []; + document.getElementById('logContent').textContent = '日志已清空,等待新日志...'; +} + +async function downloadLogs() { + try { + const response = await fetch('./logs/download', { headers: getAuthHeaders() }); + + if (response.ok) { + const contentDisposition = response.headers.get('Content-Disposition'); + let filename = 'gcli2api_logs.txt'; + if (contentDisposition) { + const match = contentDisposition.match(/filename=(.+)/); + if (match) filename = match[1]; + } + + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + window.URL.revokeObjectURL(url); + + showStatus(`日志文件下载成功: ${filename}`, 'success'); + } else { + const data = await response.json(); + showStatus(`下载日志失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`下载日志时网络错误: ${error.message}`, 'error'); + } +} + +async function clearLogs() { + try { + const response = await fetch('./logs/clear', { + method: 'POST', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + clearLogsDisplay(); + showStatus(data.message, 'success'); + } else { + showStatus(`清空日志失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + clearLogsDisplay(); + showStatus(`清空日志时网络错误: ${error.message}`, 'error'); + } +} + +function filterLogs() { + const filter = document.getElementById('logLevelFilter').value; + AppState.currentLogFilter = filter; + + if (filter === 'all') { + AppState.filteredLogs = [...AppState.allLogs]; + } else { + AppState.filteredLogs = AppState.allLogs.filter(log => log.toUpperCase().includes(filter)); + } + + displayLogs(); +} + +function displayLogs() { + const logContent = document.getElementById('logContent'); + if (AppState.filteredLogs.length === 0) { + logContent.textContent = AppState.currentLogFilter === 'all' ? + '暂无日志...' : `暂无${AppState.currentLogFilter}级别的日志...`; + } else { + logContent.textContent = AppState.filteredLogs.join('\n'); + } +} + +// ===================================================================== +// 环境变量凭证管理 +// ===================================================================== +async function checkEnvCredsStatus() { + const loading = document.getElementById('envStatusLoading'); + const content = document.getElementById('envStatusContent'); + + try { + loading.style.display = 'block'; + content.classList.add('hidden'); + + const response = await fetch('./auth/env-creds-status', { headers: getAuthHeaders() }); + const data = await response.json(); + + if (response.ok) { + const envVarsList = document.getElementById('envVarsList'); + envVarsList.textContent = Object.keys(data.available_env_vars).length > 0 + ? Object.keys(data.available_env_vars).join(', ') + : '未找到GCLI_CREDS_*环境变量'; + + const autoLoadStatus = document.getElementById('autoLoadStatus'); + autoLoadStatus.textContent = data.auto_load_enabled ? '✅ 已启用' : '❌ 未启用'; + autoLoadStatus.style.color = data.auto_load_enabled ? '#28a745' : '#dc3545'; + + document.getElementById('envFilesCount').textContent = `${data.existing_env_files_count} 个文件`; + + const envFilesList = document.getElementById('envFilesList'); + envFilesList.textContent = data.existing_env_files.length > 0 + ? data.existing_env_files.join(', ') + : '无'; + + content.classList.remove('hidden'); + showStatus('环境变量状态检查完成', 'success'); + } else { + showStatus(`获取环境变量状态失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +async function loadEnvCredentials() { + try { + showStatus('正在从环境变量导入凭证...', 'info'); + + const response = await fetch('./auth/load-env-creds', { + method: 'POST', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + if (data.loaded_count > 0) { + showStatus(`✅ 成功导入 ${data.loaded_count}/${data.total_count} 个凭证文件`, 'success'); + setTimeout(() => checkEnvCredsStatus(), 1000); + } else { + showStatus(`⚠️ ${data.message}`, 'info'); + } + } else { + showStatus(`导入失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function clearEnvCredentials() { + if (!confirm('确定要清除所有从环境变量导入的凭证文件吗?\n这将删除所有文件名以 "env-" 开头的认证文件。')) { + return; + } + + try { + showStatus('正在清除环境变量凭证文件...', 'info'); + + const response = await fetch('./auth/env-creds', { + method: 'DELETE', + headers: getAuthHeaders() + }); + + const data = await response.json(); + + if (response.ok) { + showStatus(`✅ 成功删除 ${data.deleted_count} 个环境变量凭证文件`, 'success'); + setTimeout(() => checkEnvCredsStatus(), 1000); + } else { + showStatus(`清除失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 配置管理 +// ===================================================================== +async function loadConfig() { + const loading = document.getElementById('configLoading'); + const form = document.getElementById('configForm'); + + try { + loading.style.display = 'block'; + form.classList.add('hidden'); + + const response = await fetch('./config/get', { headers: getAuthHeaders() }); + const data = await response.json(); + + if (response.ok) { + AppState.currentConfig = data.config; + AppState.envLockedFields = new Set(data.env_locked || []); + + populateConfigForm(); + form.classList.remove('hidden'); + showStatus('配置加载成功', 'success'); + } else { + showStatus(`加载配置失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +function populateConfigForm() { + const c = AppState.currentConfig; + + setConfigField('host', c.host || '0.0.0.0'); + setConfigField('port', c.port || 7861); + setConfigField('configApiPassword', c.api_password || ''); + setConfigField('configPanelPassword', c.panel_password || ''); + setConfigField('configPassword', c.password || 'pwd'); + setConfigField('credentialsDir', c.credentials_dir || ''); + setConfigField('proxy', c.proxy || ''); + setConfigField('codeAssistEndpoint', c.code_assist_endpoint || ''); + setConfigField('oauthProxyUrl', c.oauth_proxy_url || ''); + setConfigField('googleapisProxyUrl', c.googleapis_proxy_url || ''); + setConfigField('resourceManagerApiUrl', c.resource_manager_api_url || ''); + setConfigField('serviceUsageApiUrl', c.service_usage_api_url || ''); + + document.getElementById('autoBanEnabled').checked = Boolean(c.auto_ban_enabled); + setConfigField('autoBanErrorCodes', (c.auto_ban_error_codes || []).join(',')); + setConfigField('callsPerRotation', c.calls_per_rotation || 10); + + document.getElementById('retry429Enabled').checked = Boolean(c.retry_429_enabled); + setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20); + setConfigField('retry429Interval', c.retry_429_interval || 0.1); + + document.getElementById('compatibilityModeEnabled').checked = Boolean(c.compatibility_mode_enabled); + document.getElementById('returnThoughtsToFrontend').checked = Boolean(c.return_thoughts_to_frontend !== false); + document.getElementById('antigravityStream2nostream').checked = Boolean(c.antigravity_stream2nostream !== false); + + setConfigField('antiTruncationMaxAttempts', c.anti_truncation_max_attempts || 3); + + setConfigField('keepaliveUrl', c.keepalive_url || ''); + setConfigField('keepaliveInterval', c.keepalive_interval || 60); +} + +function setConfigField(fieldId, value) { + const field = document.getElementById(fieldId); + if (field) { + field.value = value; + const configKey = fieldId.replace(/([A-Z])/g, '_$1').toLowerCase(); + if (AppState.envLockedFields.has(configKey)) { + field.disabled = true; + field.classList.add('env-locked'); + } else { + field.disabled = false; + field.classList.remove('env-locked'); + } + } +} + +async function saveConfig() { + try { + const getValue = (id, def = '') => document.getElementById(id)?.value.trim() || def; + const getInt = (id, def = 0) => parseInt(document.getElementById(id)?.value) || def; + const getFloat = (id, def = 0.0) => parseFloat(document.getElementById(id)?.value) || def; + const getChecked = (id, def = false) => document.getElementById(id)?.checked || def; + + const config = { + host: getValue('host', '0.0.0.0'), + port: getInt('port', 7861), + api_password: getValue('configApiPassword'), + panel_password: getValue('configPanelPassword'), + password: getValue('configPassword', 'pwd'), + code_assist_endpoint: getValue('codeAssistEndpoint'), + credentials_dir: getValue('credentialsDir'), + proxy: getValue('proxy'), + oauth_proxy_url: getValue('oauthProxyUrl'), + googleapis_proxy_url: getValue('googleapisProxyUrl'), + resource_manager_api_url: getValue('resourceManagerApiUrl'), + service_usage_api_url: getValue('serviceUsageApiUrl'), + auto_ban_enabled: getChecked('autoBanEnabled'), + auto_ban_error_codes: getValue('autoBanErrorCodes').split(',') + .map(c => parseInt(c.trim())).filter(c => !isNaN(c)), + calls_per_rotation: getInt('callsPerRotation', 10), + retry_429_enabled: getChecked('retry429Enabled'), + retry_429_max_retries: getInt('retry429MaxRetries', 20), + retry_429_interval: getFloat('retry429Interval', 0.1), + compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), + return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), + antigravity_stream2nostream: getChecked('antigravityStream2nostream'), + anti_truncation_max_attempts: getInt('antiTruncationMaxAttempts', 3), + keepalive_url: getValue('keepaliveUrl'), + keepalive_interval: getInt('keepaliveInterval', 60) + }; + + const response = await fetch('./config/save', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ config }) + }); + + const data = await response.json(); + + if (response.ok) { + let message = '配置保存成功'; + + if (data.hot_updated && data.hot_updated.length > 0) { + message += `,以下配置已立即生效: ${data.hot_updated.join(', ')}`; + } + + if (data.restart_required && data.restart_required.length > 0) { + message += `\n⚠️ 重启提醒: ${data.restart_notice}`; + showStatus(message, 'info'); + } else { + showStatus(message, 'success'); + } + + setTimeout(() => loadConfig(), 1000); + } else { + showStatus(`保存配置失败: ${data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// 镜像网址配置 +const mirrorUrls = { + codeAssistEndpoint: 'https://gcli-api.sukaka.top/cloudcode-pa', + oauthProxyUrl: 'https://gcli-api.sukaka.top/oauth2', + googleapisProxyUrl: 'https://gcli-api.sukaka.top/googleapis', + resourceManagerApiUrl: 'https://gcli-api.sukaka.top/cloudresourcemanager', + serviceUsageApiUrl: 'https://gcli-api.sukaka.top/serviceusage' +}; + +const officialUrls = { + codeAssistEndpoint: 'https://cloudcode-pa.googleapis.com', + oauthProxyUrl: 'https://oauth2.googleapis.com', + googleapisProxyUrl: 'https://www.googleapis.com', + resourceManagerApiUrl: 'https://cloudresourcemanager.googleapis.com', + serviceUsageApiUrl: 'https://serviceusage.googleapis.com' +}; + +function useMirrorUrls() { + if (confirm('确定要将所有端点配置为镜像网址吗?')) { + for (const [fieldId, url] of Object.entries(mirrorUrls)) { + const field = document.getElementById(fieldId); + if (field && !field.disabled) field.value = url; + } + showStatus('✅ 已切换到镜像网址配置,记得点击"保存配置"按钮保存设置', 'success'); + } +} + +function restoreOfficialUrls() { + if (confirm('确定要将所有端点配置为官方地址吗?')) { + for (const [fieldId, url] of Object.entries(officialUrls)) { + const field = document.getElementById(fieldId); + if (field && !field.disabled) field.value = url; + } + showStatus('✅ 已切换到官方端点配置,记得点击"保存配置"按钮保存设置', 'success'); + } +} + +// ===================================================================== +// 使用统计 +// ===================================================================== +async function refreshUsageStats() { + const loading = document.getElementById('usageLoading'); + const list = document.getElementById('usageList'); + + try { + loading.style.display = 'block'; + list.innerHTML = ''; + + const [statsResponse, aggregatedResponse] = await Promise.all([ + fetch('./usage/stats', { headers: getAuthHeaders() }), + fetch('./usage/aggregated', { headers: getAuthHeaders() }) + ]); + + if (statsResponse.status === 401 || aggregatedResponse.status === 401) { + showStatus('认证失败,请重新登录', 'error'); + setTimeout(() => location.reload(), 1500); + return; + } + + const statsData = await statsResponse.json(); + const aggregatedData = await aggregatedResponse.json(); + + if (statsResponse.ok && aggregatedResponse.ok) { + AppState.usageStatsData = statsData.success ? statsData.data : statsData; + + const aggData = aggregatedData.success ? aggregatedData.data : aggregatedData; + document.getElementById('totalApiCalls').textContent = aggData.total_calls_24h || 0; + document.getElementById('totalFiles').textContent = aggData.total_files || 0; + document.getElementById('avgCallsPerFile').textContent = (aggData.avg_calls_per_file || 0).toFixed(1); + + renderUsageList(); + + showStatus(`已加载 ${aggData.total_files || Object.keys(AppState.usageStatsData).length} 个文件的使用统计`, 'success'); + } else { + const errorMsg = statsData.detail || aggregatedData.detail || '加载使用统计失败'; + showStatus(`错误: ${errorMsg}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } finally { + loading.style.display = 'none'; + } +} + +function renderUsageList() { + const list = document.getElementById('usageList'); + list.innerHTML = ''; + + if (Object.keys(AppState.usageStatsData).length === 0) { + list.innerHTML = '

暂无使用统计数据

'; + return; + } + + for (const [filename, stats] of Object.entries(AppState.usageStatsData)) { + const card = document.createElement('div'); + card.className = 'usage-card'; + + const calls24h = stats.calls_24h || 0; + + card.innerHTML = ` +
+
${filename}
+
+
+
+ 24小时内调用次数 + ${calls24h} +
+
+
+ +
+ `; + + list.appendChild(card); + } +} + +async function resetSingleUsageStats(filename) { + if (!confirm(`确定要重置 ${filename} 的使用统计吗?`)) return; + + try { + const response = await fetch('./usage/reset', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ filename }) + }); + + const data = await response.json(); + + if (response.ok && data.success) { + showStatus(data.message, 'success'); + await refreshUsageStats(); + } else { + showStatus(`重置失败: ${data.message || data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +async function resetAllUsageStats() { + if (!confirm('确定要重置所有文件的使用统计吗?此操作不可恢复!')) return; + + try { + const response = await fetch('./usage/reset', { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({}) + }); + + const data = await response.json(); + + if (response.ok && data.success) { + showStatus(data.message, 'success'); + await refreshUsageStats(); + } else { + showStatus(`重置失败: ${data.message || data.detail || data.error || '未知错误'}`, 'error'); + } + } catch (error) { + showStatus(`网络错误: ${error.message}`, 'error'); + } +} + +// ===================================================================== +// 冷却倒计时自动更新 +// ===================================================================== +function startCooldownTimer() { + if (AppState.cooldownTimerInterval) { + clearInterval(AppState.cooldownTimerInterval); + } + + AppState.cooldownTimerInterval = setInterval(() => { + updateCooldownDisplays(); + }, 1000); +} + +function stopCooldownTimer() { + if (AppState.cooldownTimerInterval) { + clearInterval(AppState.cooldownTimerInterval); + AppState.cooldownTimerInterval = null; + } +} + +function updateCooldownDisplays() { + let needsRefresh = false; + + // 检查模型级冷却是否过期 + for (const credInfo of Object.values(AppState.creds.data)) { + if (credInfo.model_cooldowns && Object.keys(credInfo.model_cooldowns).length > 0) { + const currentTime = Date.now() / 1000; + const hasExpiredCooldowns = Object.entries(credInfo.model_cooldowns).some(([, until]) => until <= currentTime); + + if (hasExpiredCooldowns) { + needsRefresh = true; + break; + } + } + } + + if (needsRefresh) { + AppState.creds.renderList(); + return; + } + + // 更新模型级冷却的显示 + document.querySelectorAll('.cooldown-badge').forEach(badge => { + const card = badge.closest('.cred-card'); + const filenameEl = card?.querySelector('.cred-filename'); + if (!filenameEl) return; + + const filename = filenameEl.textContent; + const credInfo = Object.values(AppState.creds.data).find(c => c.filename === filename); + + if (credInfo && credInfo.model_cooldowns) { + const currentTime = Date.now() / 1000; + const titleMatch = badge.getAttribute('title')?.match(/模型: (.+)/); + if (titleMatch) { + const model = titleMatch[1]; + const cooldownUntil = credInfo.model_cooldowns[model]; + if (cooldownUntil) { + const remaining = Math.max(0, Math.floor(cooldownUntil - currentTime)); + if (remaining > 0) { + const shortModel = model.replace('gemini-', '').replace('-exp', '') + .replace('2.0-', '2-').replace('1.5-', '1.5-'); + const timeDisplay = formatCooldownTime(remaining).replace(/s$/, '').replace(/ /g, ''); + badge.innerHTML = `⏰ ${shortModel}: ${timeDisplay}`; + } + } + } + } + }); +} + +// ===================================================================== +// 版本信息管理 +// ===================================================================== + +// 获取并显示版本信息(不检查更新) +async function fetchAndDisplayVersion() { + try { + const response = await fetch('./version/info'); + const data = await response.json(); + + const versionText = document.getElementById('versionText'); + + if (data.success) { + // 只显示版本号 + versionText.textContent = `v${data.version}`; + versionText.title = `完整版本: ${data.full_hash}\n提交信息: ${data.message}\n提交时间: ${data.date}`; + versionText.style.cursor = 'help'; + } else { + versionText.textContent = '未知版本'; + versionText.title = data.error || '无法获取版本信息'; + } + } catch (error) { + console.error('获取版本信息失败:', error); + const versionText = document.getElementById('versionText'); + if (versionText) { + versionText.textContent = '版本信息获取失败'; + } + } +} + +// 检查更新 +async function checkForUpdates() { + const checkBtn = document.getElementById('checkUpdateBtn'); + if (!checkBtn) return; + + const originalText = checkBtn.textContent; + + try { + // 显示检查中状态 + checkBtn.textContent = '检查中...'; + checkBtn.disabled = true; + + // 调用API检查更新 + const response = await fetch('./version/info?check_update=true'); + const data = await response.json(); + + if (data.success) { + if (data.check_update === false) { + // 检查更新失败 + showStatus(`检查更新失败: ${data.update_error || '未知错误'}`, 'error'); + } else if (data.has_update === true) { + // 有更新 + const updateMsg = `发现新版本!\n当前: v${data.version}\n最新: v${data.latest_version}\n\n更新内容: ${data.latest_message || '无'}`; + showStatus(updateMsg.replace(/\n/g, ' '), 'warning'); + + // 更新按钮样式 + checkBtn.style.backgroundColor = '#ffc107'; + checkBtn.textContent = '有新版本'; + + setTimeout(() => { + checkBtn.style.backgroundColor = '#17a2b8'; + checkBtn.textContent = originalText; + }, 5000); + } else if (data.has_update === false) { + // 已是最新 + showStatus('已是最新版本!', 'success'); + + checkBtn.style.backgroundColor = '#28a745'; + checkBtn.textContent = '已是最新'; + + setTimeout(() => { + checkBtn.style.backgroundColor = '#17a2b8'; + checkBtn.textContent = originalText; + }, 3000); + } else { + // 无法确定 + showStatus('无法确定是否有更新', 'info'); + } + } else { + showStatus(`检查更新失败: ${data.error}`, 'error'); + } + } catch (error) { + console.error('检查更新失败:', error); + showStatus(`检查更新失败: ${error.message}`, 'error'); + } finally { + checkBtn.disabled = false; + if (checkBtn.textContent === '检查中...') { + checkBtn.textContent = originalText; + } + } +} + +// ===================================================================== +// 页面初始化 +// ===================================================================== +window.onload = async function () { + const autoLoginSuccess = await autoLogin(); + + if (!autoLoginSuccess) { + showStatus('请输入密码登录', 'info'); + } else { + // 登录成功后获取版本信息 + await fetchAndDisplayVersion(); + } + + startCooldownTimer(); + + const antigravityAuthBtn = document.getElementById('getAntigravityAuthBtn'); + if (antigravityAuthBtn) { + antigravityAuthBtn.addEventListener('click', startAntigravityAuth); + } +}; + +// 拖拽功能 - 初始化 +document.addEventListener('DOMContentLoaded', function () { + const uploadArea = document.getElementById('uploadArea'); + + if (uploadArea) { + uploadArea.addEventListener('dragover', (event) => { + event.preventDefault(); + uploadArea.classList.add('dragover'); + }); + + uploadArea.addEventListener('dragleave', (event) => { + event.preventDefault(); + uploadArea.classList.remove('dragover'); + }); + + uploadArea.addEventListener('drop', (event) => { + event.preventDefault(); + uploadArea.classList.remove('dragover'); + AppState.uploadFiles.addFiles(Array.from(event.dataTransfer.files)); + }); + } +}); + +function autoSetKeepaliveUrl() { + const url = `${window.location.protocol}//${window.location.host}`; + document.getElementById('keepaliveUrl').value = url; +} diff --git a/front/control_panel.html b/front/control_panel.html new file mode 100644 index 0000000000000000000000000000000000000000..b37c9050d6c8d43c51d19b22cb2c066a4170afc3 --- /dev/null +++ b/front/control_panel.html @@ -0,0 +1,2392 @@ + + + + + + + GCLI2API 控制面板 + + + + +
+ + +
+

GCLI2API 管理面板

+

请输入访问密码:

+ +
+ +
+ + + +
+ + + + + + + \ No newline at end of file diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html new file mode 100644 index 0000000000000000000000000000000000000000..77462122e16d792934f3cfd71b83423e8b826109 --- /dev/null +++ b/front/control_panel_mobile.html @@ -0,0 +1,2115 @@ + + + + + + + GCLI2API 移动端控制面板 + + + + +
+ + +
+

GCLI2API 移动端控制面板

+

请输入访问密码:

+ +

+ +
+ + + +
+ + + + + + \ No newline at end of file diff --git a/install.ps1 b/install.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..3c81816c6ec1ea5fb1aeb51a2995ee6be9b44807 --- /dev/null +++ b/install.ps1 @@ -0,0 +1,35 @@ +# 检测是否为管理员 +$IsElevated = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()). + IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator) + +# Skip Scoop install if already present to avoid stopping the script +if (Get-Command scoop -ErrorAction SilentlyContinue) { + Write-Host "Scoop is already installed. Skipping installation." +} else { + Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser -Force + if ($IsElevated) { + # 管理员:使用官方一行命令并传入 -RunAsAdmin + Invoke-Expression "& {$(Invoke-RestMethod get.scoop.sh)} -RunAsAdmin" + } else { + # 普通用户安装 + Invoke-WebRequest -useb get.scoop.sh | Invoke-Expression + } +} + +scoop install git uv +if (Test-Path -LiteralPath "./web.py") { + # Already in target directory; skip clone and cd +} +elseif (Test-Path -LiteralPath "./gcli2api/web.py") { + Set-Location ./gcli2api +} +else { + git clone https://github.com/su-kaka/gcli2api.git + Set-Location ./gcli2api +} +# Create relocatable virtual environment to ensure portability +$env:UV_VENV_CLEAR = "1" +uv venv --relocatable +uv sync +.venv/Scripts/activate.ps1 +python web.py \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5372a1352b947e4741b91cc852dd8a2c59c82ff --- /dev/null +++ b/install.sh @@ -0,0 +1,302 @@ +#!/bin/bash +set -e # Exit on error +set -u # Exit on undefined variable +set -o pipefail # Exit on pipe failure + +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Logging functions +log_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" >&2 +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_debug() { + echo -e "${BLUE}[DEBUG]${NC} $1" +} + +# Cleanup function for error handling +cleanup() { + local exit_code=$? + if [ $exit_code -ne 0 ]; then + log_error "Installation failed with exit code $exit_code" + fi + exit $exit_code +} + +trap cleanup EXIT + +# Detect OS and distribution +detect_os() { + log_info "Detecting operating system..." + + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + if [ -f /etc/os-release ]; then + . /etc/os-release + OS_NAME=$ID + OS_VERSION=$VERSION_ID + log_info "Detected: $NAME $VERSION_ID" + elif [ -f /etc/lsb-release ]; then + . /etc/lsb-release + OS_NAME=$DISTRIB_ID + OS_VERSION=$DISTRIB_RELEASE + log_info "Detected: $DISTRIB_ID $DISTRIB_RELEASE" + else + OS_NAME="linux" + OS_VERSION="unknown" + log_warn "Could not determine specific Linux distribution" + fi + elif [[ "$OSTYPE" == "darwin"* ]]; then + OS_NAME="macos" + OS_VERSION=$(sw_vers -productVersion) + log_info "Detected: macOS $OS_VERSION" + elif [[ "$OSTYPE" == "freebsd"* ]]; then + OS_NAME="freebsd" + OS_VERSION=$(freebsd-version) + log_info "Detected: FreeBSD $OS_VERSION" + else + log_error "Unsupported operating system: $OSTYPE" + exit 1 + fi +} + +# Check for root privileges (only for Linux package managers that need it) +check_root_if_needed() { + if [[ "$OS_NAME" == "ubuntu" ]] || [[ "$OS_NAME" == "debian" ]] || [[ "$OS_NAME" == "linuxmint" ]] || [[ "$OS_NAME" == "kali" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for apt. Please run with sudo." + exit 1 + fi + elif [[ "$OS_NAME" == "fedora" ]] || [[ "$OS_NAME" == "rhel" ]] || [[ "$OS_NAME" == "centos" ]] || [[ "$OS_NAME" == "rocky" ]] || [[ "$OS_NAME" == "almalinux" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for dnf/yum. Please run with sudo." + exit 1 + fi + elif [[ "$OS_NAME" == "arch" ]] || [[ "$OS_NAME" == "manjaro" ]]; then + if [ "$EUID" -ne 0 ]; then + log_error "This script requires root privileges for pacman. Please run with sudo." + exit 1 + fi + fi +} + +# Update package manager +update_packages() { + log_info "Updating package manager..." + + case "$OS_NAME" in + ubuntu|debian|linuxmint|kali|pop) + if ! apt update; then + log_error "Failed to update apt package lists" + exit 1 + fi + ;; + fedora|rhel|centos|rocky|almalinux) + if command -v dnf &> /dev/null; then + if ! dnf check-update; then + # dnf check-update returns 100 if updates are available, which is not an error + if [ $? -ne 100 ]; then + log_warn "dnf check-update returned non-standard exit code" + fi + fi + else + if ! yum check-update; then + if [ $? -ne 100 ]; then + log_warn "yum check-update returned non-standard exit code" + fi + fi + fi + ;; + arch|manjaro) + if ! pacman -Syu; then + log_error "Failed to update pacman database" + exit 1 + fi + ;; + macos) + if command -v brew &> /dev/null; then + log_info "Updating Homebrew..." + brew update + else + log_warn "Homebrew not installed. Skipping package manager update." + fi + ;; + *) + log_warn "Unknown package manager for $OS_NAME. Skipping update." + ;; + esac +} + +# Install git based on OS +install_git() { + if ! command -v git &> /dev/null; then + log_info "Installing git..." + + case "$OS_NAME" in + ubuntu|debian|linuxmint|kali|pop) + if ! apt install git -y; then + log_error "Failed to install git" + exit 1 + fi + ;; + fedora|rhel|centos|rocky|almalinux) + if command -v dnf &> /dev/null; then + if ! dnf install git -y; then + log_error "Failed to install git" + exit 1 + fi + else + if ! yum install git -y; then + log_error "Failed to install git" + exit 1 + fi + fi + ;; + arch|manjaro) + if ! pacman -S git --noconfirm; then + log_error "Failed to install git" + exit 1 + fi + ;; + macos) + if command -v brew &> /dev/null; then + if ! brew install git; then + log_error "Failed to install git" + exit 1 + fi + else + log_error "Homebrew is required for macOS. Install from https://brew.sh/" + exit 1 + fi + ;; + *) + log_error "Don't know how to install git on $OS_NAME" + exit 1 + ;; + esac + else + log_info "Git is already installed ($(git --version))" + fi +} + +# Detect OS first +detect_os + +# Check root if needed +check_root_if_needed + +log_info "Starting installation process..." + +# Update package lists +update_packages + +# Install git +install_git + +# Install uv if not present +if ! command -v uv &> /dev/null; then + log_info "Installing uv package manager..." + if ! curl -Ls https://astral.sh/uv/install.sh | sh; then + log_error "Failed to install uv" + exit 1 + fi + + # Source environment + if [ -f "$HOME/.local/bin/env" ]; then + source "$HOME/.local/bin/env" + elif [ -f "$HOME/.cargo/env" ]; then + source "$HOME/.cargo/env" + fi + + # Verify uv installation + if ! command -v uv &> /dev/null; then + log_error "uv installation failed - command not found after install" + exit 1 + fi +else + log_info "uv is already installed" +fi + +# Determine working directory +log_info "Checking project directory..." +if [ -f "./web.py" ]; then + log_info "Already in target directory" +elif [ -f "./gcli2api/web.py" ]; then + log_info "Changing to gcli2api directory" + cd ./gcli2api || exit 1 +else + log_info "Cloning repository..." + if [ -d "./gcli2api" ]; then + log_warn "gcli2api directory exists but web.py not found. Removing and re-cloning..." + rm -rf ./gcli2api + fi + + if ! git clone https://github.com/su-kaka/gcli2api.git; then + log_error "Failed to clone repository" + exit 1 + fi + + cd ./gcli2api || exit 1 +fi + +# Update repository if it's a git repo +if [ -d ".git" ]; then + log_info "Updating repository..." + if ! git pull; then + log_warn "Git pull failed, continuing anyway..." + fi +else + log_warn "Not a git repository, skipping update" +fi + +# Create relocatable virtual environment to ensure portability +log_info "Creating relocatable virtual environment..." +export UV_VENV_CLEAR=1 +if ! uv venv --relocatable; then + log_error "Failed to create virtual environment" + exit 1 +fi + +# Sync dependencies +log_info "Syncing dependencies with uv..." +if ! uv sync; then + log_error "Failed to sync dependencies" + exit 1 +fi + +# Activate virtual environment +log_info "Activating virtual environment..." +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +else + log_error "Virtual environment not found at .venv/bin/activate" + exit 1 +fi + +# Verify Python is available +if ! command -v python3 &> /dev/null; then + log_error "python3 not found in virtual environment" + exit 1 +fi + +# Check if web.py exists +if [ ! -f "web.py" ]; then + log_error "web.py not found in current directory" + exit 1 +fi + +# Start the application +log_info "Starting application..." +python3 web.py \ No newline at end of file diff --git a/log.py b/log.py new file mode 100644 index 0000000000000000000000000000000000000000..ea944a741093cc8ee0e465a2d4c5c5091566819e --- /dev/null +++ b/log.py @@ -0,0 +1,327 @@ +""" +日志模块 - 使用环境变量配置 +""" + +import os +import sys +import threading +from datetime import datetime +from collections import deque +import atexit + +# 日志级别定义 +LOG_LEVELS = {"debug": 0, "info": 1, "warning": 2, "error": 3, "critical": 4} + +# 文件写入状态标志(仅由 writer 线程修改,无需锁保护) +_file_writing_disabled = False +_disable_reason = None + +# 全局文件句柄(仅由 writer 线程访问,无需文件锁) +_log_file_handle = None + +# ----------------------------------------------------------------- +# 高性能无锁队列:用 deque + Condition 替代 Queue +# deque.append / deque.popleft 在 CPython 中受 GIL 保护,是原子操作, +# 不需要额外的 Lock 做入队保护,只用 Condition 做"有数据"通知。 +# ----------------------------------------------------------------- +_log_deque: deque = deque() +_deque_condition = threading.Condition(threading.Lock()) +_writer_thread = None +_writer_running = False + +# ----------------------------------------------------------------- +# 缓存日志级别,避免每次都读 os.getenv(高并发热路径) +# ----------------------------------------------------------------- +_cached_log_level: int = LOG_LEVELS["info"] +_cached_log_file: str = "log.txt" +# ENABLE_LOG=0/false/no/off 时彻底关闭日志 +_log_enabled: bool = True + + +def _refresh_config(): + """从环境变量刷新缓存配置(模块加载时及需要时调用)""" + global _cached_log_level, _cached_log_file, _log_enabled + level = os.getenv("LOG_LEVEL", "info").lower() + _cached_log_level = LOG_LEVELS.get(level, LOG_LEVELS["info"]) + _cached_log_file = os.getenv("LOG_FILE", "log.txt") + _log_enabled = os.getenv("ENABLE_LOG", "1").strip().lower() not in ("0", "false", "no", "off") + + +def _get_current_log_level() -> int: + return _cached_log_level + + +def _get_log_file_path() -> str: + return _cached_log_file + + +# ----------------------------------------------------------------- +# 文件句柄管理(仅在 writer 线程内调用,不需要 _file_lock) +# ----------------------------------------------------------------- + +def _close_log_file(): + global _log_file_handle + if _log_file_handle is not None: + try: + _log_file_handle.flush() + _log_file_handle.close() + except Exception: + pass + finally: + _log_file_handle = None + + +def _open_log_file(mode: str = "a") -> bool: + global _log_file_handle, _file_writing_disabled, _disable_reason + _close_log_file() + try: + # 使用较大缓冲区(64 KB),由 writer 线程定期 flush,减少系统调用 + _log_file_handle = open(_cached_log_file, mode, encoding="utf-8", buffering=65536) + return True + except (PermissionError, OSError, IOError) as e: + _file_writing_disabled = True + _disable_reason = str(e) + print(f"Warning: Cannot open log file, disabling file writing: {e}", file=sys.stderr) + print("Log messages will continue to display in console only.", file=sys.stderr) + return False + except Exception as e: + print(f"Warning: Failed to open log file: {e}", file=sys.stderr) + return False + + +def _clear_log_file(): + """清空日志文件(启动时调用,此时 writer 线程尚未启动,直接操作安全)""" + global _file_writing_disabled, _disable_reason + try: + with open(_cached_log_file, "w", encoding="utf-8") as f: + pass # 覆盖清空 + _open_log_file("a") + except (PermissionError, OSError, IOError) as e: + _file_writing_disabled = True + _disable_reason = str(e) + print( + f"Warning: File system appears to be read-only or permission denied. " + f"Disabling log file writing: {e}", + file=sys.stderr, + ) + print("Log messages will continue to display in console only.", file=sys.stderr) + except Exception as e: + print(f"Warning: Failed to clear log file: {e}", file=sys.stderr) + + +# ----------------------------------------------------------------- +# Writer 线程:批量从 deque 取出并写入,减少系统调用次数 +# ----------------------------------------------------------------- +_BATCH_SIZE = 1000 # 单次最多批量写入条数 +_FLUSH_INTERVAL = 2 # 秒:无新消息时强制 flush 周期 + + +def _log_writer_worker(): + global _writer_running + + last_flush_time = 0.0 + + while True: + # 等待数据或超时 + with _deque_condition: + if not _log_deque and _writer_running: + _deque_condition.wait(timeout=_FLUSH_INTERVAL) + + # 批量取出 + batch = [] + for _ in range(_BATCH_SIZE): + if _log_deque: + batch.append(_log_deque.popleft()) + else: + break + + if batch and not _file_writing_disabled: + # 一次 write 调用搞定整批,最大化减少系统调用 + chunk = "\n".join(batch) + "\n" + try: + if _log_file_handle is None: + _open_log_file("a") + if _log_file_handle is not None: + _log_file_handle.write(chunk) + except Exception as e: + print(f"Warning: Failed to write log batch: {e}", file=sys.stderr) + _close_log_file() + try: + _open_log_file("a") + except Exception: + pass + + # 定时 flush + now = _now_ts() + if now - last_flush_time >= _FLUSH_INTERVAL: + if _log_file_handle is not None: + try: + _log_file_handle.flush() + except Exception: + pass + last_flush_time = now + + # 退出条件:已停止 + deque 已清空 + if not _writer_running and not _log_deque: + break + + # 最终 flush & close + if _log_file_handle is not None: + try: + _log_file_handle.flush() + except Exception: + pass + _close_log_file() + + +def _now_ts() -> float: + import time + return time.monotonic() + + +def _start_writer_thread(): + global _writer_thread, _writer_running + + if _writer_thread is None or not _writer_thread.is_alive(): + _writer_running = True + _writer_thread = threading.Thread(target=_log_writer_worker, daemon=True, name="LogWriter") + _writer_thread.start() + + +def _stop_writer_thread(): + global _writer_running + + _writer_running = False + # 唤醒 writer 线程让它能感知退出信号 + with _deque_condition: + _deque_condition.notify_all() + + if _writer_thread and _writer_thread.is_alive(): + _writer_thread.join(timeout=3.0) + + +# ----------------------------------------------------------------- +# 入队(热路径,极轻量) +# ----------------------------------------------------------------- +_MAX_QUEUE_SIZE = 5000 # 防止极端情况内存无限增长 + + +def _write_to_file(message: str): + if _file_writing_disabled: + return + # deque.append 在 CPython 受 GIL 保护,无需额外锁 + if len(_log_deque) >= _MAX_QUEUE_SIZE: + return # 过载保护:丢弃而非阻塞 + _log_deque.append(message) + # 非阻塞通知 writer(acquire 失败直接跳过,不影响主线程) + if _deque_condition.acquire(blocking=False): + try: + _deque_condition.notify() + finally: + _deque_condition.release() + + +# ----------------------------------------------------------------- +# 核心日志函数(热路径) +# ----------------------------------------------------------------- + +def _log(level: str, message: str): + # 最快短路:日志整体已禁用时直接返回,零开销 + if not _log_enabled: + return + + level = level.lower() + level_val = LOG_LEVELS.get(level) + if level_val is None: + print(f"Warning: Unknown log level '{level}'", file=sys.stderr) + return + + # 热路径:直接与缓存值比较,无函数调用开销 + if level_val < _cached_log_level: + return + + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + entry = f"[{timestamp}] [{level.upper()}] {message}" + + if level in ("error", "critical"): + print(entry, file=sys.stderr) + else: + print(entry) + + _write_to_file(entry) + + +def set_log_level(level: str): + """动态设置日志级别(同时更新缓存)""" + global _cached_log_level + level = level.lower() + if level not in LOG_LEVELS: + print(f"Warning: Unknown log level '{level}'. Valid levels: {', '.join(LOG_LEVELS.keys())}") + return False + _cached_log_level = LOG_LEVELS[level] + return True + + +class Logger: + """支持 log('info', 'msg') 和 log.info('msg') 两种调用方式""" + + def __call__(self, level: str, message: str): + _log(level, message) + + def debug(self, message: str): + _log("debug", message) + + def info(self, message: str): + _log("info", message) + + def warning(self, message: str): + _log("warning", message) + + def error(self, message: str): + _log("error", message) + + def critical(self, message: str): + _log("critical", message) + + def get_current_level(self) -> str: + current_level = _get_current_log_level() + for name, value in LOG_LEVELS.items(): + if value == current_level: + return name + return "info" + + def get_log_file(self) -> str: + return _get_log_file_path() + + def close(self): + """手动关闭(优雅退出用)""" + _stop_writer_thread() + + def get_queue_size(self) -> int: + return len(_log_deque) + + +# 导出全局日志实例 +log = Logger() + +# 导出的公共接口 +__all__ = ["log", "set_log_level", "LOG_LEVELS"] + +# 模块加载时:读取配置缓存 → 清空日志文件 → 启动 writer 线程 +_refresh_config() +if _log_enabled: + _clear_log_file() + _start_writer_thread() + +# 注册退出清理 +atexit.register(_stop_writer_thread) + +# 使用说明: +# 1. 设置日志级别: export LOG_LEVEL=debug (或在 .env 中设置) +# 2. 设置日志文件: export LOG_FILE=log.txt (或在 .env 中设置) +# 3. 日志级别已缓存,热路径零 os.getenv 调用 +# 4. 写入线程批量处理(最多 200 条/次),64 KB 缓冲区,每 0.5 s flush 一次 +# 5. 队列上限 5000 条,超出时丢弃新日志(过载保护,不阻塞主线程) +# 6. 动态调整级别:set_log_level('debug') 立即生效 +# 7. 彻底关闭日志(最高性能):export ENABLE_LOG=0 (或 false/no/off) +# 关闭后不会启动 writer 线程、不写文件、不打印控制台,_log 直接 return diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..37746149416dc858e3286c10996271c05112e212 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,103 @@ +[project] +name = "gcli2api" +version = "0.1.0" +description = "Convert GeminiCLI to OpenAI and Gemini API interfaces" +readme = "README.md" +requires-python = ">=3.12" +license = {text = "CNC-1.0"} +authors = [ + {name = "su-kaka"} +] +keywords = ["gemini", "openai", "api", "converter", "cli"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: Other/Proprietary License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "aiofiles>=24.1.0", + "fastapi>=0.116.1", + "httpx[socks]>=0.28.1", + "hypercorn>=0.17.3", + "motor>=3.7.1", + "oauthlib>=3.3.1", + "pydantic>=2.11.7", + "pyjwt>=2.10.1", + "python-dotenv>=1.1.1", + "python-multipart>=0.0.20", + "pypinyin>=0.51.0", + "aiosqlite>=0.20.0", + "redis>=7.2.0", + "asyncpg>=0.31.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.1.0", + "black>=24.0.0", + "flake8>=7.0.0", + "mypy>=1.8.0", + "pre-commit>=3.6.0", +] + +[tool.pytest.ini_options] +minversion = "8.0" +testpaths = ["."] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +asyncio_mode = "auto" +addopts = [ + "-v", + "--strict-markers", +] + +[tool.black] +line-length = 100 +target-version = ["py313"] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.mypy] +python_version = "3.13" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +ignore_missing_imports = true +exclude = [ + "build", + "dist", +] + +[tool.coverage.run] +source = ["src"] +omit = [ + "*/tests/*", + "*/test_*.py", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/render.yaml b/render.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7eabfe61b71296ae5ed3e6f231e738859dbd00a4 --- /dev/null +++ b/render.yaml @@ -0,0 +1,20 @@ +services: + - type: web + name: gcli2api + runtime: docker + dockerfilePath: ./Dockerfile + dockerContext: . + plan: free + region: singapore + healthCheckPath: / + + envVars: + # ========== 必填:访问密码 ========== + - key: PASSWORD + sync: false # 部署时手动填写,不同步到代码库 + + # ========== 服务器配置 ========== + - key: HOST + value: 0.0.0.0 + - key: PORT + value: "10000" # Render 要求 Web 服务监听 10000 端口 diff --git a/requirements-termux.txt b/requirements-termux.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa9bd0ea916dc7e2b9d55d1c4dac234a96cd3b1f --- /dev/null +++ b/requirements-termux.txt @@ -0,0 +1,14 @@ +fastapi +httpx[socks] +pydantic==1.10.22 +python-dotenv +hypercorn +aiofiles +python-multipart +PyJWT +oauthlib +motor +pypinyin +aiosqlite +redis +asyncpg \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fcc07407949409bf590c8e866bdd0dc9b4e55db8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +fastapi>=0.116.1 +httpx[socks]>=0.28.1 +pydantic>=2.11.7 +python-dotenv>=1.1.1 +hypercorn>=0.17.3 +aiofiles>=24.1.0 +python-multipart>=0.0.20 +PyJWT>=2.10.1 +oauthlib>=3.3.1 +motor>=3.7.1 +aiosqlite>=0.20.0 +pypinyin>=0.51.0 +redis>=4.2.0 +asyncpg diff --git a/src/api/Response_example.txt b/src/api/Response_example.txt new file mode 100644 index 0000000000000000000000000000000000000000..456ecccec35ede0c18172ce80a9f403aa62ac3ad --- /dev/null +++ b/src/api/Response_example.txt @@ -0,0 +1,210 @@ +================================================================================ +GeminiCli API 测试 +================================================================================ + +================================================================================ +【测试1】流式请求 (stream_request with native=False) +================================================================================ +请求体: { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, tell me a joke in one sentence." + } + ] + } + ] + } +} + +流式响应数据 (每个chunk): +-------------------------------------------------------------------------------- +[2026-01-10 09:55:29] [INFO] SQLite storage initialized at ./creds\credentials.db +[2026-01-10 09:55:29] [INFO] Using SQLite storage backend +[2026-01-10 09:55:31] [INFO] Token刷新成 功并已保存: my-project-9-481103-1765596755.json (mode=geminicli) +[2026-01-10 09:55:34] [INFO] [DB] 准备commit,总更新行数=1 +[2026-01-10 09:55:34] [INFO] [DB] commit 完成 +[2026-01-10 09:55:34] [INFO] [DB] update_credential_state 结束: success=True, updated_count=1 + +Chunk #1: + 类型: str + 长度: 626 + 内容预览: 'data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text": "Why did the scarecrow win an award? Because he was outstanding in his field."}]},"finishReason": "STOP"}],"usageMeta' + 解析后的JSON: { + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field." + } + ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 51, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 24 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:29.168589Z", + "responseId": "kbFhaY2lCr-ZseMPqMiDmAU" + }, + "traceId": "55650653afd3c738" +} + +Chunk #2: + 类型: str + 长度: 0 + 内容预览: '' +E:\projects\gcli2api\src\api\geminicli.py:491: RuntimeWarning: coroutine 'get_auto_ban_error_codes' was never awaited + async for chunk in stream_request(body=test_body, native=False): +RuntimeWarning: Enable tracemalloc to get the object allocation traceback + +总共收到 2 个chunk + +================================================================================ +【测试2】非流式请求 (non_stream_request) +================================================================================ +请求体: { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, tell me a joke in one sentence." + } + ] + } + ] + } +} + +[2026-01-10 09:55:35] [INFO] Token刷新成 功并已保存: gen-lang-client-0194852792-1767296759.json (mode=geminicli) +[2026-01-10 09:55:38] [INFO] [DB] 准备commit,总更新行数=1 +[2026-01-10 09:55:38] [INFO] [DB] commit 完成 +[2026-01-10 09:55:38] [INFO] [DB] update_credential_state 结束: success=True, updated_count=1 +E:\projects\gcli2api\src\api\geminicli.py:530: RuntimeWarning: coroutine 'get_auto_ban_error_codes' was never awaited + response = await non_stream_request(body=test_body) +RuntimeWarning: Enable tracemalloc to get the object allocation traceback +非流式响应数据: +-------------------------------------------------------------------------------- +状态码: 200 +Content-Type: application/json; charset=UTF-8 + +响应头: {'server': 'openresty', 'date': 'Sat, 10 Jan 2026 01:55:34 GMT', 'content-type': 'application/json; charset=UTF-8', 'transfer-encoding': 'chunked', 'connection': 'keep-alive', 'x-cloudaicompanion-trace-id': 'bf3a5eb6636774d2', 'vary': 'Origin, X-Origin, Referer', 'content-encoding': 'gzip', 'x-xss-protection': '0', 'x-frame-options': 'SAMEORIGIN', 'x-content-type-options': 'nosniff', 'server-timing': 'gfet4t7; dur=1377', 'alt-svc': 'h3=":443"; ma=2592000,h3-29=":443"; ma=2592000', 'access-control-allow-origin': '*', 'access-control-allow-methods': 'GET, POST, PUT, DELETE, PATCH, OPTIONS', 'access-control-allow-headers': 'Content-Type, Authorization, X-Requested-With', 'cache-control': 'no-cache', 'content-length': '969'} + +响应内容 (原始): +{ + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field!" + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.54438119776108684 + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 47, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 20 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:33.450396Z", + "responseId": "lbFhady-G7yi694PmLOP4As" + }, + "traceId": "bf3a5eb6636774d2" +} + + +响应内容 (格式化JSON): +{ + "response": { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Why did the scarecrow win an award? Because he was outstanding in his field!" + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.5443811977610868 + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 17, + "totalTokenCount": 47, + "trafficType": "PROVISIONED_THROUGHPUT", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 10 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 17 + } + ], + "thoughtsTokenCount": 20 + }, + "modelVersion": "gemini-2.5-flash", + "createTime": "2026-01-10T01:55:33.450396Z", + "responseId": "lbFhady-G7yi694PmLOP4As" + }, + "traceId": "bf3a5eb6636774d2" +} + +================================================================================ +测试完成 +================================================================================ \ No newline at end of file diff --git a/src/api/antigravity.py b/src/api/antigravity.py new file mode 100644 index 0000000000000000000000000000000000000000..f23ebccc0393cdd83fceacdf16dcef63448985a6 --- /dev/null +++ b/src/api/antigravity.py @@ -0,0 +1,845 @@ +""" +Antigravity API Client - Handles communication with Google's Antigravity API +处理与 Google Antigravity API 的通信 +""" + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Callable, Tuple + +from fastapi import Response +from config import ( + get_code_assist_endpoint, + get_antigravity_stream2nostream, + get_auto_ban_error_codes, +) +from log import log + +from src.credential_manager import credential_manager +from src.httpx_client import stream_post_async, post_async +from src.models import Model, model_to_dict +from src.utils import ANTIGRAVITY_USER_AGENT + +# 导入共同的基础功能 +from src.api.utils import ( + handle_error_with_retry, + get_retry_config, + record_api_call_success, + record_api_call_error, + parse_and_log_cooldown, + collect_streaming_response, +) + +# ==================== 全局凭证管理器 ==================== + +# 使用全局单例 credential_manager,自动初始化 + + +# ==================== 辅助函数 ==================== + +def build_antigravity_headers(access_token: str, model_name: str = "") -> Dict[str, str]: + """ + 构建 Antigravity API 请求头 + + Args: + access_token: 访问令牌 + model_name: 模型名称,用于判断 request_type + + Returns: + 请求头字典 + """ + headers = { + 'User-Agent': ANTIGRAVITY_USER_AGENT, + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json', + 'Accept-Encoding': 'gzip', + 'requestId': f"req-{uuid.uuid4()}" + } + + # 根据模型名称判断 request_type + if model_name: + # 先判断是否是图片模型 + if "image" in model_name.lower(): + request_type = "image_gen" + headers['requestType'] = request_type + else: + request_type = "agent" + headers['requestType'] = request_type + + return headers + + +def _is_retryable_status(status_code: int, disable_error_codes: List[int]) -> bool: + """统一判断是否属于可重试状态码。""" + return status_code in (429, 503) or status_code in disable_error_codes + + +async def _switch_credential_for_retry( + *, + next_cred_task: Optional[asyncio.Task], + retry_interval: float, + refresh_credential_fast: Callable[[], Any], + apply_cred_result: Callable[[Tuple[str, Dict[str, Any]]], bool], + log_prefix: str, +) -> Tuple[bool, Optional[asyncio.Task]]: + """优先使用预热凭证,失败后退回同步刷新。""" + if next_cred_task is not None: + try: + cred_result = await next_cred_task + next_cred_task = None + if cred_result and apply_cred_result(cred_result): + await asyncio.sleep(retry_interval) + return True, next_cred_task + except Exception as e: + log.warning(f"{log_prefix} 预热凭证任务失败: {e}") + next_cred_task = None + + await asyncio.sleep(retry_interval) + if await refresh_credential_fast(): + return True, next_cred_task + + return False, next_cred_task + + +# ==================== 新的流式和非流式请求函数 ==================== + +async def stream_request( + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, +): + """ + 流式请求函数 + + Args: + body: 请求体 + native: 是否返回原生bytes流,False则返回str流 + headers: 额外的请求头 + + Yields: + Response对象(错误时)或 bytes流/str流(成功时) + """ + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + log.error("[ANTIGRAVITY STREAM] 当前无可用凭证") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + + if not access_token: + log.error(f"[ANTIGRAVITY STREAM] No access token in credential: {current_file}") + yield Response( + content=json.dumps({"error": "凭证中没有访问令牌"}), + status_code=500, + media_type="application/json" + ) + return + + # 2. 构建URL和请求头 + antigravity_url = await get_code_assist_endpoint() + target_url = f"{antigravity_url}/v1internal:streamGenerateContent?alt=sse" + + auth_headers = build_antigravity_headers(access_token, model_name) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + # 构建包含project的payload + final_payload = { + "model": body.get("model"), + "project": project_id, + "request": body.get("request", {}), + } + + # 仅当凭证明确开启积分消耗时注入 enabledCreditTypes + def apply_enabled_credit_types(cred_data: Dict[str, Any]) -> None: + if cred_data.get("enable_credit") is True: + final_payload["enabledCreditTypes"] = ["GOOGLE_ONE_AI"] + else: + final_payload.pop("enabledCreditTypes", None) + + apply_enabled_credit_types(credential_data) + + # 3. 调用stream_post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + next_cred_task = None # 预热的下一个凭证任务 + + # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) + async def refresh_credential_fast(): + nonlocal current_file, access_token, auth_headers, project_id, final_payload + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + if not access_token: + return None + # 只更新token和project_id,不重建整个headers和payload + auth_headers["Authorization"] = f"Bearer {access_token}" + final_payload["project"] = project_id + apply_enabled_credit_types(credential_data) + return True + + def apply_cred_result(cred_result: Tuple[str, Dict[str, Any]]) -> bool: + nonlocal current_file, access_token, project_id, auth_headers, final_payload + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + if not access_token or not project_id: + return False + auth_headers["Authorization"] = f"Bearer {access_token}" + final_payload["project"] = project_id + apply_enabled_credit_types(credential_data) + return True + + for attempt in range(max_retries + 1): + success_recorded = False # 标记是否已记录成功 + need_retry = False # 标记是否需要重试 + + try: + async for chunk in stream_post_async( + url=target_url, + body=final_payload, + native=native, + headers=auth_headers + ): + # 判断是否是Response对象 + if isinstance(chunk, Response): + status_code = chunk.status_code + last_error_response = chunk # 记录最后一次错误 + + # 缓存错误解析结果,避免重复decode + error_body = None + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + except Exception: + error_body = "" + + # 如果错误码是429、503或者在禁用码当中,做好记录后进行重试 + if _is_retryable_status(status_code, DISABLE_ERROR_CODES): + log.warning(f"[ANTIGRAVITY STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + + # 解析冷却时间 + cooldown_until = None + if (status_code == 429 or status_code == 503) and error_body: + try: + cooldown_until = await parse_and_log_cooldown(error_body, mode="antigravity") + except Exception: + pass + + # 预热下一个凭证 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + ) + + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_body + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="antigravity" + ) + + if should_retry and attempt < max_retries: + need_retry = True + break # 跳出内层循环,准备重试 + else: + # 不重试,直接返回原始错误 + log.error(f"[ANTIGRAVITY STREAM] 达到最大重试次数或不应重试,返回原始错误") + yield chunk + return + else: + # 错误码不在禁用码当中,直接返回,无需重试 + log.error(f"[ANTIGRAVITY STREAM] 流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="antigravity", model_name=model_name, + error_message=error_body + ) + yield chunk + return + else: + # 不是Response,说明是真流,直接yield返回 + # 只在第一个chunk时记录成功 + if not success_recorded: + await record_api_call_success( + credential_manager, current_file, mode="antigravity", model_name=model_name + ) + success_recorded = True + log.debug(f"[ANTIGRAVITY STREAM] 开始接收流式响应,模型: {model_name}") + + # 记录原始chunk内容(用于调试) + if isinstance(chunk, bytes): + log.debug(f"[ANTIGRAVITY STREAM RAW] chunk(bytes): {chunk}") + else: + log.debug(f"[ANTIGRAVITY STREAM RAW] chunk(str): {chunk}") + + yield chunk + + # 流式请求完成,检查结果 + if success_recorded: + log.debug(f"[ANTIGRAVITY STREAM] 流式响应完成,模型: {model_name}") + return + elif not need_retry: + # 没有收到任何数据(空回复),需要重试 + log.warning(f"[ANTIGRAVITY STREAM] 收到空回复,无任何内容,凭证: {current_file}") + await record_api_call_error( + credential_manager, current_file, 200, + None, mode="antigravity", model_name=model_name, + error_message="Empty response from API" + ) + + if attempt < max_retries: + need_retry = True + else: + log.error(f"[ANTIGRAVITY STREAM] 空回复达到最大重试次数") + yield Response( + content=json.dumps({"error": "服务返回空回复"}), + status_code=500, + media_type="application/json" + ) + return + + # 统一处理重试 + if need_retry: + log.info(f"[ANTIGRAVITY STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + + switched, next_cred_task = await _switch_credential_for_retry( + next_cred_task=next_cred_task, + retry_interval=retry_interval, + refresh_credential_fast=refresh_credential_fast, + apply_cred_result=apply_cred_result, + log_prefix="[ANTIGRAVITY STREAM]", + ) + if not switched: + log.error("[ANTIGRAVITY STREAM] 重试时无可用凭证或令牌") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + continue # 重试 + + except Exception as e: + log.error(f"[ANTIGRAVITY STREAM] 流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[ANTIGRAVITY STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有) + log.error(f"[ANTIGRAVITY STREAM] 所有重试均失败,最后异常: {e}") + if last_error_response: + yield last_error_response + else: + # 如果没有记录到错误响应,返回500错误 + yield Response( + content=json.dumps({"error": f"流式请求异常: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + return + + # 所有重试均已耗尽(for循环正常结束),返回最后记录的错误 + log.error("[ANTIGRAVITY STREAM] 所有重试均失败") + if last_error_response: + yield last_error_response + else: + yield Response( + content=json.dumps({"error": "请求失败,所有重试均已耗尽"}), + status_code=429, + media_type="application/json" + ) + + +async def non_stream_request( + body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, +) -> Response: + """ + 非流式请求函数 + + Args: + body: 请求体 + headers: 额外的请求头 + + Returns: + Response对象 + """ + # 检查是否启用流式收集模式 + if await get_antigravity_stream2nostream(): + log.debug("[ANTIGRAVITY] 使用流式收集模式实现非流式请求") + + # 调用stream_request获取流 + stream = stream_request(body=body, native=False, headers=headers) + + # 收集流式响应 + # stream_request是一个异步生成器,可能yield Response(错误)或流数据 + # collect_streaming_response会自动处理这两种情况 + return await collect_streaming_response(stream) + + # 否则使用传统非流式模式 + log.debug("[ANTIGRAVITY] 使用传统非流式模式") + + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + log.error("[ANTIGRAVITY] 当前无可用凭证") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + + if not access_token: + log.error(f"[ANTIGRAVITY] No access token in credential: {current_file}") + return Response( + content=json.dumps({"error": "凭证中没有访问令牌"}), + status_code=500, + media_type="application/json" + ) + + # 2. 构建URL和请求头 + antigravity_url = await get_code_assist_endpoint() + target_url = f"{antigravity_url}/v1internal:generateContent" + + auth_headers = build_antigravity_headers(access_token, model_name) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + # 构建包含project的payload + final_payload = { + "model": body.get("model"), + "project": project_id, + "request": body.get("request", {}), + } + + # 仅当凭证明确开启积分消耗时注入 enabledCreditTypes + def apply_enabled_credit_types(cred_data: Dict[str, Any]) -> None: + if cred_data.get("enable_credit") is True: + final_payload["enabledCreditTypes"] = ["GOOGLE_ONE_AI"] + else: + final_payload.pop("enabledCreditTypes", None) + + apply_enabled_credit_types(credential_data) + + # 3. 调用post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + next_cred_task = None # 预热的下一个凭证任务 + + # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) + async def refresh_credential_fast(): + nonlocal current_file, access_token, auth_headers, project_id, final_payload + cred_result = await credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + if not access_token: + return None + # 只更新token和project_id,不重建整个headers和payload + auth_headers["Authorization"] = f"Bearer {access_token}" + final_payload["project"] = project_id + apply_enabled_credit_types(credential_data) + return True + + def apply_cred_result(cred_result: Tuple[str, Dict[str, Any]]) -> bool: + nonlocal current_file, access_token, project_id, auth_headers, final_payload + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + if not access_token or not project_id: + return False + auth_headers["Authorization"] = f"Bearer {access_token}" + final_payload["project"] = project_id + apply_enabled_credit_types(credential_data) + return True + + for attempt in range(max_retries + 1): + need_retry = False # 标记是否需要重试 + + try: + response = await post_async( + url=target_url, + json=final_payload, + headers=auth_headers, + timeout=300.0 + ) + + status_code = response.status_code + + # 成功 + if status_code == 200: + # 检查是否为空回复 + if not response.content or len(response.content) == 0: + log.warning(f"[ANTIGRAVITY] 收到200响应但内容为空,凭证: {current_file}") + + # 记录错误 + await record_api_call_error( + credential_manager, current_file, 200, + None, mode="antigravity", model_name=model_name, + error_message="Empty response from API" + ) + + if attempt < max_retries: + need_retry = True + else: + log.error(f"[ANTIGRAVITY] 空回复达到最大重试次数") + return Response( + content=json.dumps({"error": "服务返回空回复"}), + status_code=500, + media_type="application/json" + ) + else: + # 正常响应 + await record_api_call_success( + credential_manager, current_file, mode="antigravity", model_name=model_name + ) + return Response( + content=response.content, + status_code=200, + headers=dict(response.headers) + ) + + # 失败 - 记录最后一次错误 + if status_code != 200: + last_error_response = Response( + content=response.content, + status_code=status_code, + headers=dict(response.headers) + ) + + # 判断是否需要重试 + # 缓存错误文本,避免重复解析 + error_text = "" + try: + error_text = response.text + except Exception: + pass + + if _is_retryable_status(status_code, DISABLE_ERROR_CODES): + log.warning(f"[ANTIGRAVITY] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + + # 解析冷却时间 + cooldown_until = None + if (status_code == 429 or status_code == 503) and error_text: + try: + cooldown_until = await parse_and_log_cooldown(error_text, mode="antigravity") + except Exception: + pass + + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + ) + + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_text + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="antigravity" + ) + + if should_retry and attempt < max_retries: + need_retry = True + else: + # 不重试,直接返回原始错误 + log.error(f"[ANTIGRAVITY] 达到最大重试次数或不应重试,返回原始错误") + return last_error_response + else: + # 错误码不在禁用码当中,直接返回,无需重试 + log.error(f"[ANTIGRAVITY] 非流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="antigravity", model_name=model_name, + error_message=error_text + ) + return last_error_response + + # 统一处理重试 + if need_retry: + log.info(f"[ANTIGRAVITY] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + + switched, next_cred_task = await _switch_credential_for_retry( + next_cred_task=next_cred_task, + retry_interval=retry_interval, + refresh_credential_fast=refresh_credential_fast, + apply_cred_result=apply_cred_result, + log_prefix="[ANTIGRAVITY]", + ) + if not switched: + log.error("[ANTIGRAVITY] 重试时无可用凭证或令牌") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + continue # 重试 + + except Exception as e: + log.error(f"[ANTIGRAVITY] 非流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[ANTIGRAVITY] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有)或500错误 + log.error(f"[ANTIGRAVITY] 所有重试均失败,最后异常: {e}") + if last_error_response: + return last_error_response + else: + return Response( + content=json.dumps({"error": f"非流式请求异常: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + # 所有重试都失败,返回最后一次的原始错误(如果有)或500错误 + log.error("[ANTIGRAVITY] 所有重试均失败") + if last_error_response: + return last_error_response + else: + return Response( + content=json.dumps({"error": "所有重试均失败"}), + status_code=500, + media_type="application/json" + ) + + +# ==================== 模型和配额查询 ==================== + +async def fetch_available_models() -> List[Dict[str, Any]]: + """ + 获取可用模型列表,返回符合 OpenAI API 规范的格式 + + Returns: + 模型列表,格式为字典列表(用于兼容现有代码) + + Raises: + 返回空列表如果获取失败 + """ + # 获取凭证管理器和可用凭证 + cred_result = await credential_manager.get_valid_credential(mode="antigravity") + if not cred_result: + log.error("[ANTIGRAVITY] No valid credentials available for fetching models") + return [] + + current_file, credential_data = cred_result + access_token = credential_data.get("access_token") or credential_data.get("token") + + if not access_token: + log.error(f"[ANTIGRAVITY] No access token in credential: {current_file}") + return [] + + # 构建请求头 + headers = build_antigravity_headers(access_token) + + try: + # 使用 POST 请求获取模型列表 + antigravity_url = await get_code_assist_endpoint() + + response = await post_async( + url=f"{antigravity_url}/v1internal:fetchAvailableModels", + json={}, # 空的请求体 + headers=headers + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[ANTIGRAVITY] Raw models response: {json.dumps(data, ensure_ascii=False)[:500]}") + + # 转换为 OpenAI 格式的模型列表,使用 Model 类 + model_list = [] + current_timestamp = int(datetime.now(timezone.utc).timestamp()) + + if 'models' in data and isinstance(data['models'], dict): + # 遍历模型字典 + for model_id in data['models'].keys(): + model = Model( + id=model_id, + object='model', + created=current_timestamp, + owned_by='google' + ) + model_list.append(model_to_dict(model)) + # 添加额外的 claude-sonnet-4-6-thinking 模型 + if "claude-sonnet-4-6" in data.get('models', {}): + model = Model( + id='claude-sonnet-4-6-thinking', + object='model', + created=current_timestamp, + owned_by='google' + ) + model_list.append(model_to_dict(model)) + # 添加额外的 claude-opus-4-6 模型 + if "claude-opus-4-6-thinking" in data.get('models', {}): + claude_opus_model = Model( + id='claude-opus-4-6', + object='model', + created=current_timestamp, + owned_by='google' + ) + model_list.append(model_to_dict(claude_opus_model)) + + log.info(f"[ANTIGRAVITY] Fetched {len(model_list)} available models") + return model_list + else: + log.error(f"[ANTIGRAVITY] Failed to fetch models ({response.status_code}): {response.text[:500]}") + return [] + + except Exception as e: + import traceback + log.error(f"[ANTIGRAVITY] Failed to fetch models: {e}") + log.error(f"[ANTIGRAVITY] Traceback: {traceback.format_exc()}") + return [] + + +async def fetch_quota_info(access_token: str) -> Dict[str, Any]: + """ + 获取指定凭证的额度信息 + + Args: + access_token: Antigravity 访问令牌 + + Returns: + 包含额度信息的字典,格式为: + { + "success": True/False, + "models": { + "model_name": { + "remaining": 0.95, + "resetTime": "12-20 10:30", + "resetTimeRaw": "2025-12-20T02:30:00Z" + } + }, + "error": "错误信息" (仅在失败时) + } + """ + + headers = build_antigravity_headers(access_token) + + try: + antigravity_url = await get_code_assist_endpoint() + + response = await post_async( + url=f"{antigravity_url}/v1internal:fetchAvailableModels", + json={}, + headers=headers, + timeout=30.0 + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[ANTIGRAVITY QUOTA] Raw response: {json.dumps(data, ensure_ascii=False)[:500]}") + + quota_info = {} + + if 'models' in data and isinstance(data['models'], dict): + for model_id, model_data in data['models'].items(): + if isinstance(model_data, dict) and 'quotaInfo' in model_data: + quota = model_data['quotaInfo'] + remaining = quota.get('remainingFraction', 0) + reset_time_raw = quota.get('resetTime', '') + + # 转换为北京时间 + reset_time_beijing = 'N/A' + if reset_time_raw: + try: + utc_date = datetime.fromisoformat(reset_time_raw.replace('Z', '+00:00')) + # 转换为北京时间 (UTC+8) + from datetime import timedelta + beijing_date = utc_date + timedelta(hours=8) + reset_time_beijing = beijing_date.strftime('%m-%d %H:%M') + except Exception as e: + log.warning(f"[ANTIGRAVITY QUOTA] Failed to parse reset time: {e}") + + quota_info[model_id] = { + "remaining": remaining, + "resetTime": reset_time_beijing, + "resetTimeRaw": reset_time_raw + } + + return { + "success": True, + "models": quota_info + } + else: + log.error(f"[ANTIGRAVITY QUOTA] Failed to fetch quota ({response.status_code}): {response.text[:500]}") + return { + "success": False, + "error": f"API返回错误: {response.status_code}" + } + + except Exception as e: + import traceback + log.error(f"[ANTIGRAVITY QUOTA] Failed to fetch quota: {e}") + log.error(f"[ANTIGRAVITY QUOTA] Traceback: {traceback.format_exc()}") + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/src/api/geminicli.py b/src/api/geminicli.py new file mode 100644 index 0000000000000000000000000000000000000000..09f70db10e83feb6ff5450e6cceb60c0b1990f02 --- /dev/null +++ b/src/api/geminicli.py @@ -0,0 +1,808 @@ +""" +GeminiCli API Client - Handles all communication with GeminiCli API. +This module is used by both OpenAI compatibility layer and native Gemini endpoints. +GeminiCli API 客户端 - 处理与 GeminiCli API 的所有通信 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径(用于直接运行测试) +if __name__ == "__main__": + project_root = Path(__file__).resolve().parent.parent.parent + if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +import asyncio +import json +from typing import Any, Dict, Optional, Callable, Tuple + +from fastapi import Response +from config import get_code_assist_endpoint, get_auto_ban_error_codes +from log import log + +from src.credential_manager import credential_manager +from src.httpx_client import stream_post_async, post_async + +# 导入共同的基础功能 +from src.api.utils import ( + handle_error_with_retry, + get_retry_config, + record_api_call_success, + record_api_call_error, + parse_and_log_cooldown, +) +from src.utils import get_geminicli_user_agent + +# ==================== 全局凭证管理器 ==================== + +# 使用全局单例 credential_manager,自动初始化 + + +# ==================== 请求准备 ==================== + +async def prepare_request_headers_and_payload( + payload: dict, credential_data: dict, target_url: str +): + """ + 从凭证数据准备请求头和最终payload + + Args: + payload: 原始请求payload + credential_data: 凭证数据字典 + target_url: 目标URL + + Returns: + 元组: (headers, final_payload, target_url) + + Raises: + Exception: 如果凭证中缺少必要字段 + """ + token = credential_data.get("token") or credential_data.get("access_token", "") + if not token: + raise Exception("凭证中没有找到有效的访问令牌(token或access_token字段)") + + source_request = payload.get("request", {}) + + # 内部API使用Bearer Token和项目ID + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "User-Agent": get_geminicli_user_agent(payload.get("model", "")), + } + project_id = credential_data.get("project_id", "") + if not project_id: + raise Exception("项目ID不存在于凭证数据中") + final_payload = { + "model": payload.get("model"), + "project": project_id, + "request": source_request, + } + + return headers, final_payload, target_url + + +def _is_retryable_status(status_code: int, disable_error_codes: list[int]) -> bool: + """统一判断是否属于可重试状态码。""" + return status_code in (429, 503) or status_code in disable_error_codes + + +async def _switch_credential_for_retry( + *, + next_cred_task: Optional[asyncio.Task], + retry_interval: float, + refresh_credential_fast: Callable[[], Any], + apply_cred_result: Callable[[Tuple[str, Dict[str, Any]]], bool], + log_prefix: str, +) -> Tuple[bool, Optional[asyncio.Task]]: + """优先使用预热凭证,失败后退回同步刷新。""" + if next_cred_task is not None: + try: + cred_result = await next_cred_task + next_cred_task = None + if cred_result and apply_cred_result(cred_result): + await asyncio.sleep(retry_interval) + return True, next_cred_task + except Exception as e: + log.warning(f"{log_prefix} 预热凭证任务失败: {e}") + next_cred_task = None + + await asyncio.sleep(retry_interval) + if await refresh_credential_fast(): + return True, next_cred_task + + return False, next_cred_task + + +# ==================== 新的流式和非流式请求函数 ==================== + +async def stream_request( + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, +): + """ + 流式请求函数 + + Args: + body: 请求体 + native: 是否返回原生bytes流,False则返回str流 + headers: 额外的请求头 + + Yields: + Response对象(错误时)或 bytes流/str流(成功时) + """ + # 获取有效凭证 + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + + current_file, credential_data = cred_result + + # 2. 构建URL和请求头 + try: + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:streamGenerateContent?alt=sse" + ) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + except Exception as e: + log.error(f"准备请求失败: {e}") + yield Response( + content=json.dumps({"error": f"准备请求失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + return + + # 3. 调用stream_post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + next_cred_task = None # 预热的下一个凭证任务 + + # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) + async def refresh_credential_fast(): + nonlocal current_file, credential_data, auth_headers, final_payload + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + try: + # 只更新token和project_id,不重建整个headers和payload + token = credential_data.get("token") or credential_data.get("access_token", "") + project_id = credential_data.get("project_id", "") + if not token or not project_id: + return None + + # 直接更新现有的headers和payload + auth_headers["Authorization"] = f"Bearer {token}" + final_payload["project"] = project_id + return True + except Exception: + return None + + def apply_cred_result(cred_result: Tuple[str, Dict[str, Any]]) -> bool: + nonlocal current_file, credential_data, auth_headers, final_payload + current_file, credential_data = cred_result + token = credential_data.get("token") or credential_data.get("access_token", "") + project_id = credential_data.get("project_id", "") + if not token or not project_id: + return False + auth_headers["Authorization"] = f"Bearer {token}" + final_payload["project"] = project_id + return True + + for attempt in range(max_retries + 1): + success_recorded = False # 标记是否已记录成功 + need_retry = False # 标记是否需要重试 + + try: + async for chunk in stream_post_async( + url=target_url, + body=final_payload, + native=native, + headers=auth_headers + ): + # 判断是否是Response对象 + if isinstance(chunk, Response): + status_code = chunk.status_code + last_error_response = chunk # 记录最后一次错误 + + # 缓存错误解析结果,避免重复decode + error_body = None + try: + error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + except Exception: + error_body = "" + + # 如果错误码是429、503或者在禁用码当中,做好记录后进行重试 + if _is_retryable_status(status_code, DISABLE_ERROR_CODES): + log.warning(f"[GEMINICLI STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + + # 解析冷却时间 + cooldown_until = None + if (status_code == 429 or status_code == 503) and error_body: + try: + cooldown_until = await parse_and_log_cooldown(error_body, mode="geminicli") + except Exception: + pass + + # 预热下一个凭证 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_body + ) + + # 检查是否应该重试 + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="geminicli" + ) + + if should_retry and attempt < max_retries: + need_retry = True + break # 跳出内层循环,准备重试 + else: + # 不重试,直接返回原始错误 + log.error(f"[GEMINICLI STREAM] 达到最大重试次数或不应重试,返回原始错误") + yield chunk + return + elif status_code == 404 and "preview" in model_name.lower(): + # 特殊处理:preview模型返回404,说明该凭证不支持preview模型 + log.warning(f"[GEMINICLI STREAM] Preview模型404错误,凭证不支持preview: {current_file}") + + # 将该凭证的preview状态设置为False + try: + await credential_manager.update_credential_state( + current_file, {"preview": False}, mode="geminicli" + ) + log.info(f"[GEMINICLI STREAM] 已将凭证 {current_file} 的preview状态设置为False") + except Exception as e: + log.error(f"[GEMINICLI STREAM] 更新凭证preview状态失败: {e}") + + # 记录404错误 + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_name=model_name, + error_message=error_body + ) + + # 预热下一个凭证(会自动跳过preview=False的凭证) + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + + # 触发重试 + if attempt < max_retries: + need_retry = True + break + else: + log.error(f"[GEMINICLI STREAM] 达到最大重试次数,返回404错误") + yield chunk + return + else: + # 错误码不在禁用码当中,直接返回,无需重试 + log.error(f"[GEMINICLI STREAM] 流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_name=model_name, + error_message=error_body + ) + yield chunk + return + else: + # 不是Response,说明是真流,直接yield返回 + # 只在第一个chunk时记录成功 + if not success_recorded: + await record_api_call_success( + credential_manager, current_file, mode="geminicli", model_name=model_name + ) + success_recorded = True + log.debug(f"[GEMINICLI STREAM] 开始接收流式响应,模型: {model_name}") + + yield chunk + + # 流式请求完成,检查结果 + if success_recorded: + log.debug(f"[GEMINICLI STREAM] 流式响应完成,模型: {model_name}") + return + + # 统一处理重试 + if need_retry: + # 如果已经是最后一次尝试,不再重试,直接返回错误 + if attempt >= max_retries: + log.error(f"[GEMINICLI STREAM] 达到最大重试次数,返回错误") + if last_error_response: + yield last_error_response + else: + yield Response( + content=json.dumps({"error": "请求失败,所有重试均已耗尽"}), + status_code=429, + media_type="application/json" + ) + return + + log.info(f"[GEMINICLI STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + + switched, next_cred_task = await _switch_credential_for_retry( + next_cred_task=next_cred_task, + retry_interval=retry_interval, + refresh_credential_fast=refresh_credential_fast, + apply_cred_result=apply_cred_result, + log_prefix="[GEMINICLI STREAM]", + ) + if not switched: + log.error("[GEMINICLI STREAM] 重试时无可用凭证或刷新失败") + yield Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + return + continue # 重试 + + except Exception as e: + log.error(f"[GEMINICLI STREAM] 流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[GEMINICLI STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有) + log.error(f"[GEMINICLI STREAM] 所有重试均失败,最后异常: {e}") + if last_error_response: + yield last_error_response + else: + # 如果没有记录到错误响应,返回500错误 + yield Response( + content=json.dumps({"error": f"流式请求异常: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + return + + # 所有重试均已耗尽(for循环正常结束),返回最后记录的错误 + log.error("[GEMINICLI STREAM] 所有重试均失败") + if last_error_response: + yield last_error_response + else: + yield Response( + content=json.dumps({"error": "请求失败,所有重试均已耗尽"}), + status_code=429, + media_type="application/json" + ) + + +async def non_stream_request( + body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, +) -> Response: + """ + 非流式请求函数 + + Args: + body: 请求体 + native: 保留参数以保持接口一致性(实际未使用) + headers: 额外的请求头 + + Returns: + Response对象 + """ + # 获取有效凭证 + model_name = body.get("model", "") + + # 1. 获取有效凭证 + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + + if not cred_result: + # 如果返回值是None,直接返回错误500 + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + + current_file, credential_data = cred_result + + # 2. 构建URL和请求头 + try: + auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( + body, credential_data, + f"{await get_code_assist_endpoint()}/v1internal:generateContent" + ) + + # 合并自定义headers + if headers: + auth_headers.update(headers) + + except Exception as e: + log.error(f"准备请求失败: {e}") + return Response( + content=json.dumps({"error": f"准备请求失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + # 3. 调用post_async进行请求 + retry_config = await get_retry_config() + max_retries = retry_config["max_retries"] + retry_interval = retry_config["retry_interval"] + + DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 + last_error_response = None # 记录最后一次的错误响应 + next_cred_task = None # 预热的下一个凭证任务 + + # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) + async def refresh_credential_fast(): + nonlocal current_file, credential_data, auth_headers, final_payload + cred_result = await credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + if not cred_result: + return None + current_file, credential_data = cred_result + try: + # 只更新token和project_id,不重建整个headers和payload + token = credential_data.get("token") or credential_data.get("access_token", "") + project_id = credential_data.get("project_id", "") + if not token or not project_id: + return None + + # 直接更新现有的headers和payload + auth_headers["Authorization"] = f"Bearer {token}" + final_payload["project"] = project_id + return True + except Exception: + return None + + def apply_cred_result(cred_result: Tuple[str, Dict[str, Any]]) -> bool: + nonlocal current_file, credential_data, auth_headers, final_payload + current_file, credential_data = cred_result + token = credential_data.get("token") or credential_data.get("access_token", "") + project_id = credential_data.get("project_id", "") + if not token or not project_id: + return False + auth_headers["Authorization"] = f"Bearer {token}" + final_payload["project"] = project_id + return True + + for attempt in range(max_retries + 1): + try: + response = await post_async( + url=target_url, + json=final_payload, + headers=auth_headers, + timeout=300.0 + ) + + status_code = response.status_code + + # 成功 + if status_code == 200: + await record_api_call_success( + credential_manager, current_file, mode="geminicli", model_name=model_name + ) + # 创建响应头,移除压缩相关的header避免重复解压 + response_headers = dict(response.headers) + response_headers.pop('content-encoding', None) + response_headers.pop('content-length', None) + + return Response( + content=response.content, + status_code=200, + headers=response_headers + ) + + # 失败 - 记录最后一次错误 + # 创建响应头,移除压缩相关的header避免重复解压 + error_headers = dict(response.headers) + error_headers.pop('content-encoding', None) + error_headers.pop('content-length', None) + + last_error_response = Response( + content=response.content, + status_code=status_code, + headers=error_headers + ) + + # 判断是否需要重试 + # 缓存错误文本,避免重复解析 + error_text = "" + try: + error_text = response.text + except Exception: + pass + + # 统一处理所有需要重试的错误码(429、503、禁用码) + if _is_retryable_status(status_code, DISABLE_ERROR_CODES): + log.warning(f"[NON-STREAM] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + + # 解析冷却时间 + cooldown_until = None + if (status_code == 429 or status_code == 503) and error_text: + try: + cooldown_until = await parse_and_log_cooldown(error_text, mode="geminicli") + except Exception: + pass + + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_text + ) + + # 检查是否应该重试(会自动处理禁用逻辑) + should_retry = await handle_error_with_retry( + credential_manager, status_code, current_file, + retry_config["retry_enabled"], attempt, max_retries, retry_interval, + mode="geminicli" + ) + + if should_retry and attempt < max_retries: + # 重新获取凭证并重试 + log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + + switched, next_cred_task = await _switch_credential_for_retry( + next_cred_task=next_cred_task, + retry_interval=retry_interval, + refresh_credential_fast=refresh_credential_fast, + apply_cred_result=apply_cred_result, + log_prefix="[NON-STREAM]", + ) + if not switched: + log.error("[NON-STREAM] 重试时无可用凭证或刷新失败") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + continue # 重试 + else: + # 不重试,直接返回原始错误 + log.error(f"[NON-STREAM] 达到最大重试次数或不应重试,返回原始错误") + return last_error_response + elif status_code == 404 and "preview" in model_name.lower(): + # 特殊处理:preview模型返回404,说明该凭证不支持preview模型 + log.warning(f"[NON-STREAM] Preview模型404错误,凭证不支持preview: {current_file}") + + # 将该凭证的preview状态设置为False + try: + await credential_manager.update_credential_state( + current_file, {"preview": False}, mode="geminicli" + ) + log.info(f"[NON-STREAM] 已将凭证 {current_file} 的preview状态设置为False") + except Exception as e: + log.error(f"[NON-STREAM] 更新凭证preview状态失败: {e}") + + # 记录404错误 + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_name=model_name, + error_message=error_text + ) + + # 预热下一个凭证(会自动跳过preview=False的凭证) + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + + # 触发重试 + if attempt < max_retries: + log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + + switched, next_cred_task = await _switch_credential_for_retry( + next_cred_task=next_cred_task, + retry_interval=retry_interval, + refresh_credential_fast=refresh_credential_fast, + apply_cred_result=apply_cred_result, + log_prefix="[NON-STREAM]", + ) + if not switched: + log.error("[NON-STREAM] 重试时无可用凭证或刷新失败") + return Response( + content=json.dumps({"error": "当前无可用凭证"}), + status_code=500, + media_type="application/json" + ) + continue # 重试 + else: + log.error(f"[NON-STREAM] 达到最大重试次数,返回404错误") + return last_error_response + else: + # 错误码不在重试范围内,直接返回 + log.error(f"[NON-STREAM] 非流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + await record_api_call_error( + credential_manager, current_file, status_code, + None, mode="geminicli", model_name=model_name, + error_message=error_text + ) + return last_error_response + + except Exception as e: + log.error(f"非流式请求异常: {e}, 凭证: {current_file}") + if attempt < max_retries: + log.info(f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") + await asyncio.sleep(retry_interval) + continue + else: + # 所有重试都失败,返回最后一次的错误(如果有)或500错误 + log.error(f"[NON-STREAM] 所有重试均失败,最后异常: {e}") + if last_error_response: + return last_error_response + else: + return Response( + content=json.dumps({"error": f"请求异常: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + # 所有重试都失败,返回最后一次的原始错误 + log.error("[NON-STREAM] 所有重试均失败") + return last_error_response + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示API返回的流式和非流式数据格式 + 运行方式: python src/api/geminicli.py + """ + print("=" * 80) + print("GeminiCli API 测试") + print("=" * 80) + + # 测试请求体 + test_body = { + "model": "gemini-2.5-flash", + "request": { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + } + + async def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试1】流式请求 (stream_request with native=False)") + print("=" * 80) + print(f"请求体: {json.dumps(test_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + chunk_count = 0 + async for chunk in stream_request(body=test_body, native=False): + chunk_count += 1 + if isinstance(chunk, Response): + # 错误响应 + print(f"\n❌ 错误响应:") + print(f" 状态码: {chunk.status_code}") + print(f" Content-Type: {chunk.headers.get('content-type', 'N/A')}") + try: + content = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + print(f" 内容: {content}") + except Exception as e: + print(f" 内容解析失败: {e}") + else: + # 正常的流式数据块 (str类型) + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk) if hasattr(chunk, '__len__') else 'N/A'}") + print(f" 内容预览: {repr(chunk[:200] if len(chunk) > 200 else chunk)}") + + # 如果是SSE格式,尝试解析 + if isinstance(chunk, str) and chunk.startswith("data: "): + try: + data_line = chunk.strip() + if data_line.startswith("data: "): + json_str = data_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析尝试失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + async def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (non_stream_request)") + print("=" * 80) + print(f"请求体: {json.dumps(test_body, indent=2, ensure_ascii=False)}\n") + + response = await non_stream_request(body=test_body) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + print(f"\n响应头: {dict(response.headers)}\n") + + try: + content = response.body.decode('utf-8') if isinstance(response.body, bytes) else str(response.body) + print(f"响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = json.loads(content) + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + async def main(): + """主测试函数""" + try: + # 测试流式请求 + await test_stream_request() + + # 测试非流式请求 + await test_non_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() + + # 运行测试 + asyncio.run(main()) diff --git a/src/api/utils.py b/src/api/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1638d77609c8cc4b7616beb8944c01b551433976 --- /dev/null +++ b/src/api/utils.py @@ -0,0 +1,505 @@ +""" +Base API Client - 共用的 API 客户端基础功能 +提供错误处理、自动封禁、重试逻辑等共同功能 +""" + +import asyncio +import json +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +from fastapi import Response + +from config import ( + get_auto_ban_enabled, + get_auto_ban_error_codes, + get_retry_429_enabled, + get_retry_429_interval, + get_retry_429_max_retries, +) +from log import log +from src.credential_manager import CredentialManager + + +# ==================== 错误检查与处理 ==================== + +async def check_should_auto_ban(status_code: int) -> bool: + """ + 检查是否应该触发自动封禁 + + Args: + status_code: HTTP状态码 + + Returns: + bool: 是否应该触发自动封禁 + """ + return ( + await get_auto_ban_enabled() + and status_code in await get_auto_ban_error_codes() + ) + + +async def handle_auto_ban( + credential_manager: CredentialManager, + status_code: int, + credential_name: str, + mode: str = "geminicli" +) -> None: + """ + 处理自动封禁:直接禁用凭证 + + Args: + credential_manager: 凭证管理器实例 + status_code: HTTP状态码 + credential_name: 凭证名称 + mode: 模式(geminicli 或 antigravity) + """ + if credential_manager and credential_name: + log.warning( + f"[{mode.upper()} AUTO_BAN] Status {status_code} triggers auto-ban for credential: {credential_name}" + ) + await credential_manager.set_cred_disabled( + credential_name, True, mode=mode + ) + + +async def handle_error_with_retry( + credential_manager: CredentialManager, + status_code: int, + credential_name: str, + retry_enabled: bool, + attempt: int, + max_retries: int, + retry_interval: float, + mode: str = "geminicli" +) -> bool: + """ + 统一处理错误和重试逻辑 + + 仅在以下情况下进行自动重试: + 1. 429错误(速率限制) + 2. 503错误(服务不可用) + 3. 导致凭证封禁的错误(AUTO_BAN_ERROR_CODES配置) + + Args: + credential_manager: 凭证管理器实例 + status_code: HTTP状态码 + credential_name: 凭证名称 + retry_enabled: 是否启用重试 + attempt: 当前重试次数 + max_retries: 最大重试次数 + retry_interval: 重试间隔 + mode: 模式(geminicli 或 antigravity) + + Returns: + bool: True表示需要继续重试,False表示不需要重试 + """ + # 优先检查自动封禁 + should_auto_ban = await check_should_auto_ban(status_code) + + if should_auto_ban: + # 触发自动封禁 + await handle_auto_ban(credential_manager, status_code, credential_name, mode) + + # 自动封禁后,仍然尝试重试(会在下次循环中自动获取新凭证) + if retry_enabled and attempt < max_retries: + log.info( + f"[{mode.upper()} RETRY] Retrying with next credential after auto-ban " + f"(status {status_code}, attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(retry_interval) + return True + return False + + # 如果不触发自动封禁,仅对429和503错误进行重试 + if (status_code == 429 or status_code == 503) and retry_enabled and attempt < max_retries: + log.info( + f"[{mode.upper()} RETRY] {status_code} error encountered, retrying " + f"(attempt {attempt + 1}/{max_retries})" + ) + await asyncio.sleep(retry_interval) + return True + + # 其他错误不进行重试 + return False + + +# ==================== 重试配置获取 ==================== + +async def get_retry_config() -> Dict[str, Any]: + """ + 获取重试配置 + + Returns: + 包含重试配置的字典 + """ + return { + "retry_enabled": await get_retry_429_enabled(), + "max_retries": await get_retry_429_max_retries(), + "retry_interval": await get_retry_429_interval(), + } + + +# ==================== API调用结果记录 ==================== + +async def record_api_call_success( + credential_manager: CredentialManager, + credential_name: str, + mode: str = "geminicli", + model_name: Optional[str] = None +) -> None: + """ + 记录API调用成功 + + Args: + credential_manager: 凭证管理器实例 + credential_name: 凭证名称 + mode: 模式(geminicli 或 antigravity) + model_name: 模型名称(用于模型级CD) + """ + if credential_manager and credential_name: + await credential_manager.record_api_call_result( + credential_name, True, mode=mode, model_name=model_name + ) + + +async def record_api_call_error( + credential_manager: CredentialManager, + credential_name: str, + status_code: int, + cooldown_until: Optional[float] = None, + mode: str = "geminicli", + model_name: Optional[str] = None, + error_message: Optional[str] = None +) -> None: + """ + 记录API调用错误 + + Args: + credential_manager: 凭证管理器实例 + credential_name: 凭证名称 + status_code: HTTP状态码 + cooldown_until: 冷却截止时间(Unix时间戳) + mode: 模式(geminicli 或 antigravity) + model_name: 模型名称(用于模型级CD) + error_message: 错误信息(可选) + """ + if credential_manager and credential_name: + await credential_manager.record_api_call_result( + credential_name, + False, + status_code, + cooldown_until=cooldown_until, + mode=mode, + model_name=model_name, + error_message=error_message + ) + + +# ==================== 429错误处理 ==================== + +async def parse_and_log_cooldown( + error_text: str, + mode: str = "geminicli" +) -> Optional[float]: + """ + 解析并记录冷却时间 + + Args: + error_text: 错误响应文本 + mode: 模式(geminicli 或 antigravity) + + Returns: + 冷却截止时间(Unix时间戳),如果解析失败则返回None + """ + try: + error_data = json.loads(error_text) + cooldown_until = parse_quota_reset_timestamp(error_data) + if cooldown_until: + log.info( + f"[{mode.upper()}] 检测到quota冷却时间: " + f"{datetime.fromtimestamp(cooldown_until, timezone.utc).isoformat()}" + ) + return cooldown_until + except Exception as parse_err: + log.debug(f"[{mode.upper()}] Failed to parse cooldown time: {parse_err}") + return None + + +# ==================== 流式响应收集 ==================== + +async def collect_streaming_response(stream_generator) -> Response: + """ + 将Gemini流式响应收集为一条完整的非流式响应 + + Args: + stream_generator: 流式响应生成器,产生 "data: {json}" 格式的行或Response对象 + + Returns: + Response: 合并后的完整响应对象 + + Example: + >>> async for line in stream_generator: + ... # line format: "data: {...}" or Response object + >>> response = await collect_streaming_response(stream_generator) + """ + # 初始化响应结构 + merged_response = { + "response": { + "candidates": [{ + "content": { + "parts": [], + "role": "model" + }, + "finishReason": None, + "safetyRatings": [], + "citationMetadata": None + }], + "usageMetadata": { + "promptTokenCount": 0, + "candidatesTokenCount": 0, + "totalTokenCount": 0 + } + } + } + + collected_text = [] # 用于收集文本内容 + collected_thought_text = [] # 用于收集思维链内容 + collected_other_parts = [] # 用于收集其他类型的parts(图片、文件、工具调用等) + collected_tool_parts_count = 0 # 记录工具调用相关part数量 + has_data = False + line_count = 0 + + log.debug("[STREAM COLLECTOR] Starting to collect streaming response") + + try: + async for line in stream_generator: + line_count += 1 + + # 如果收到的是Response对象(错误),直接返回 + if isinstance(line, Response): + log.debug(f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}") + return line + + # 处理 bytes 类型 + if isinstance(line, bytes): + line_str = line.decode('utf-8', errors='ignore') + log.debug(f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}") + elif isinstance(line, str): + line_str = line + log.debug(f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}") + else: + log.debug(f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}") + continue + + # 解析流式数据行 + if not line_str.startswith("data: "): + log.debug(f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}") + continue + + raw = line_str[6:].strip() + if raw == "[DONE]": + log.debug("[STREAM COLLECTOR] Received [DONE] marker") + break + + try: + log.debug(f"[STREAM COLLECTOR] Parsing JSON: {raw[:200]}") + chunk = json.loads(raw) + has_data = True + log.debug(f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}") + + # 提取响应对象 + response_obj = chunk.get("response", {}) + if not response_obj: + log.debug("[STREAM COLLECTOR] No 'response' key in chunk, trying direct access") + response_obj = chunk # 尝试直接使用chunk + + candidates = response_obj.get("candidates", []) + log.debug(f"[STREAM COLLECTOR] Found {len(candidates)} candidates") + if not candidates: + log.debug(f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}") + continue + + candidate = candidates[0] + + # 收集文本内容 + content = candidate.get("content", {}) + parts = content.get("parts", []) + log.debug(f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate") + + for part in parts: + if not isinstance(part, dict): + continue + + # 优先保留工具调用相关 part(functionCall / functionResponse) + # 避免在 stream2nostream 模式下工具调用丢失 + if "functionCall" in part or "functionResponse" in part or "function_call" in part: + collected_other_parts.append(part) + collected_tool_parts_count += 1 + log.debug(f"[STREAM COLLECTOR] Collected tool part: {list(part.keys())}") + continue + + # 处理文本内容 + text = part.get("text", "") + if text: + # 区分普通文本和思维链 + if part.get("thought", False): + collected_thought_text.append(text) + log.debug(f"[STREAM COLLECTOR] Collected thought text: {text[:100]}") + else: + collected_text.append(text) + log.debug(f"[STREAM COLLECTOR] Collected regular text: {text[:100]}") + # 处理非文本内容(图片、文件等) + elif "inlineData" in part or "fileData" in part or "executableCode" in part or "codeExecutionResult" in part: + collected_other_parts.append(part) + log.debug(f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}") + + # 收集其他信息(使用最后一个块的值) + if candidate.get("finishReason"): + merged_response["response"]["candidates"][0]["finishReason"] = candidate["finishReason"] + + if candidate.get("safetyRatings"): + merged_response["response"]["candidates"][0]["safetyRatings"] = candidate["safetyRatings"] + + if candidate.get("citationMetadata"): + merged_response["response"]["candidates"][0]["citationMetadata"] = candidate["citationMetadata"] + + # 更新使用元数据 + usage = response_obj.get("usageMetadata", {}) + if usage: + merged_response["response"]["usageMetadata"].update(usage) + + except json.JSONDecodeError as e: + log.debug(f"[STREAM COLLECTOR] Failed to parse JSON chunk: {e}") + continue + except Exception as e: + log.debug(f"[STREAM COLLECTOR] Error processing chunk: {e}") + continue + + except Exception as e: + log.error(f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}") + return Response( + content=json.dumps({"error": f"收集流式响应失败: {str(e)}"}), + status_code=500, + media_type="application/json" + ) + + log.debug(f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}") + + # 如果没有收集到任何数据,返回错误 + if not has_data: + log.error(f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines") + return Response( + content=json.dumps({"error": "No data collected from stream"}), + status_code=500, + media_type="application/json" + ) + + # 组装最终的parts + final_parts = [] + + # 先添加思维链内容(如果有) + if collected_thought_text: + final_parts.append({ + "text": "".join(collected_thought_text), + "thought": True + }) + + # 再添加普通文本内容 + if collected_text: + final_parts.append({ + "text": "".join(collected_text) + }) + + # 添加其他类型的parts(图片、文件等) + final_parts.extend(collected_other_parts) + + # 如果没有任何内容,添加空文本 + if not final_parts: + final_parts.append({"text": ""}) + + merged_response["response"]["candidates"][0]["content"]["parts"] = final_parts + + log.info( + f"[STREAM COLLECTOR] Collected {len(collected_text)} text chunks, " + f"{len(collected_thought_text)} thought chunks, {len(collected_other_parts)} other parts " + f"(tool parts: {collected_tool_parts_count})" + ) + + # 去掉嵌套的 "response" 包装(Antigravity格式 -> 标准Gemini格式) + if "response" in merged_response and "candidates" not in merged_response: + log.debug(f"[STREAM COLLECTOR] 展开response包装") + merged_response = merged_response["response"] + + # 返回纯JSON格式 + return Response( + content=json.dumps(merged_response, ensure_ascii=False).encode('utf-8'), + status_code=200, + headers={}, + media_type="application/json" + ) + + +RESOURCE_EXHAUSTED_COOLDOWN_HOURS = 4 # RESOURCE_EXHAUSTED 错误的默认冷却时间(小时) + + +def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]: + """ + 从Google API错误响应中提取quota重置时间戳 + + Args: + error_response: Google API返回的错误响应字典 + + Returns: + Unix时间戳(秒),如果无法解析则返回None + + 示例错误响应: + { + "error": { + "code": 429, + "message": "You have exhausted your capacity...", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "QUOTA_EXHAUSTED", + "metadata": { + "quotaResetTimeStamp": "2025-11-30T14:57:24Z", + "quotaResetDelay": "13h19m1.20964964s" + } + } + ] + } + } + """ + try: + error_obj = error_response.get("error", {}) + details = error_obj.get("details", []) + + for detail in details: + if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": + reset_timestamp_str = detail.get("metadata", {}).get("quotaResetTimeStamp") + + if reset_timestamp_str: + if reset_timestamp_str.endswith("Z"): + reset_timestamp_str = reset_timestamp_str.replace("Z", "+00:00") + + reset_dt = datetime.fromisoformat(reset_timestamp_str) + if reset_dt.tzinfo is None: + reset_dt = reset_dt.replace(tzinfo=timezone.utc) + + return reset_dt.astimezone(timezone.utc).timestamp() + + # 如果是 RESOURCE_EXHAUSTED 错误且消息完全匹配,设置默认4小时冷却时间 + if ( + error_obj.get("status") == "RESOURCE_EXHAUSTED" + and error_obj.get("message") == "Resource has been exhausted (e.g. check quota)." + ): + import time + cooldown_until = time.time() + RESOURCE_EXHAUSTED_COOLDOWN_HOURS * 3600 + return cooldown_until + + return None + + except Exception: + return None diff --git a/src/auth.py b/src/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..4902060191d28b5586055eb0268e12539979bf7d --- /dev/null +++ b/src/auth.py @@ -0,0 +1,1089 @@ +""" +认证API模块 +""" + +import asyncio +import socket +import threading +import time +import uuid +from datetime import timezone +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Dict, Optional +from urllib.parse import parse_qs, urlparse + +from config import get_config_value, get_code_assist_endpoint +from log import log + +from .google_oauth_api import ( + Credentials, + Flow, + enable_required_apis, + fetch_project_id_and_tier, + get_user_projects, + select_default_project, +) +from .storage_adapter import get_storage_adapter +from .utils import ( + ANTIGRAVITY_CLIENT_ID, + ANTIGRAVITY_CLIENT_SECRET, + ANTIGRAVITY_SCOPES, + ANTIGRAVITY_USER_AGENT, + CALLBACK_HOST, + CLIENT_ID, + CLIENT_SECRET, + SCOPES, + GEMINICLI_USER_AGENT, + TOKEN_URL, +) + + +async def get_callback_port(): + """获取OAuth回调端口""" + return int(await get_config_value("oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT")) + + +def _prepare_credentials_data(credentials: Credentials, project_id: str, mode: str = "geminicli", subscription_tier: str = None) -> Dict[str, Any]: + """准备凭证数据字典(统一函数)""" + if mode == "antigravity": + creds_data = { + "client_id": ANTIGRAVITY_CLIENT_ID, + "client_secret": ANTIGRAVITY_CLIENT_SECRET, + "token": credentials.access_token, + "refresh_token": credentials.refresh_token, + "scopes": ANTIGRAVITY_SCOPES, + "token_uri": TOKEN_URL, + "project_id": project_id, + } + else: + creds_data = { + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "token": credentials.access_token, + "refresh_token": credentials.refresh_token, + "scopes": SCOPES, + "token_uri": TOKEN_URL, + "project_id": project_id, + } + + if credentials.expires_at: + if credentials.expires_at.tzinfo is None: + expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc) + else: + expiry_utc = credentials.expires_at + creds_data["expiry"] = expiry_utc.isoformat() + + return creds_data + + +def _cleanup_auth_flow_server(state: str): + """清理认证流程的服务器资源""" + if state in auth_flows: + flow_data_to_clean = auth_flows[state] + try: + if flow_data_to_clean.get("server"): + server = flow_data_to_clean["server"] + port = flow_data_to_clean.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.debug(f"关闭服务器时出错: {e}") + del auth_flows[state] + + +class _OAuthLibPatcher: + """oauthlib参数验证补丁的上下文管理器""" + def __init__(self): + import oauthlib.oauth2.rfc6749.parameters + self.module = oauthlib.oauth2.rfc6749.parameters + self.original_validate = None + + def __enter__(self): + self.original_validate = self.module.validate_token_parameters + + def patched_validate(params): + try: + return self.original_validate(params) + except Warning: + pass + + self.module.validate_token_parameters = patched_validate + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.original_validate: + self.module.validate_token_parameters = self.original_validate + + +# 全局状态管理 - 严格限制大小 +auth_flows = {} # 存储进行中的认证流程 +MAX_AUTH_FLOWS = 20 # 严格限制最大认证流程数 +DEFAULT_PROJECT_ID = "gemini-pro-1751713012-07fc4dfd" + + +def cleanup_auth_flows_for_memory(): + """清理认证流程以释放内存""" + global auth_flows + cleanup_expired_flows() + # 如果还是太多,强制清理一些旧的流程 + if len(auth_flows) > 10: + # 按创建时间排序,保留最新的10个 + sorted_flows = sorted( + auth_flows.items(), key=lambda x: x[1].get("created_at", 0), reverse=True + ) + new_auth_flows = dict(sorted_flows[:10]) + + # 清理被移除的流程 + for state, flow_data in auth_flows.items(): + if state not in new_auth_flows: + try: + if flow_data.get("server"): + server = flow_data["server"] + port = flow_data.get("callback_port") + async_shutdown_server(server, port) + except Exception: + pass + flow_data.clear() + + auth_flows = new_auth_flows + log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程") + + return len(auth_flows) + + +async def find_available_port(start_port: int = None) -> int: + """动态查找可用端口""" + if start_port is None: + start_port = await get_callback_port() + + # 首先尝试默认端口 + for port in range(start_port, start_port + 100): # 尝试100个端口 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", port)) + log.info(f"找到可用端口: {port}") + return port + except OSError: + continue + + # 如果都不可用,让系统自动分配端口 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("0.0.0.0", 0)) + port = s.getsockname()[1] + log.info(f"系统分配可用端口: {port}") + return port + except OSError as e: + log.error(f"无法找到可用端口: {e}") + raise RuntimeError("无法找到可用端口") + + +def create_callback_server(port: int) -> HTTPServer: + """创建指定端口的回调服务器,优化快速关闭""" + try: + # 服务器监听0.0.0.0 + server = HTTPServer(("0.0.0.0", port), AuthCallbackHandler) + + # 设置socket选项以支持快速关闭 + server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # 设置较短的超时时间 + server.timeout = 1.0 + + log.info(f"创建OAuth回调服务器,监听端口: {port}") + return server + except OSError as e: + log.error(f"创建端口{port}的服务器失败: {e}") + raise + + +class AuthCallbackHandler(BaseHTTPRequestHandler): + """OAuth回调处理器""" + + def do_GET(self): + query_components = parse_qs(urlparse(self.path).query) + code = query_components.get("code", [None])[0] + state = query_components.get("state", [None])[0] + + log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}") + + if code and state and state in auth_flows: + # 更新流程状态 + auth_flows[state]["code"] = code + auth_flows[state]["completed"] = True + + log.info(f"OAuth回调成功处理: state={state}") + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + # 成功页面 + self.wfile.write( + b"

OAuth authentication successful!

You can close this window. Please return to the original page and click 'Get Credentials' button.

" + ) + else: + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"

Authentication failed.

Please try again.

") + + def log_message(self, format, *args): + # 减少日志噪音 + pass + + +async def create_auth_url( + project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """创建认证URL,支持动态端口分配""" + try: + # 动态分配端口 + callback_port = await find_available_port() + callback_url = f"http://{CALLBACK_HOST}:{callback_port}" + + # 立即启动回调服务器 + try: + callback_server = create_callback_server(callback_port) + # 在后台线程中运行服务器 + server_thread = threading.Thread( + target=callback_server.serve_forever, + daemon=True, + name=f"OAuth-Server-{callback_port}", + ) + server_thread.start() + log.info(f"OAuth回调服务器已启动,端口: {callback_port}") + except Exception as e: + log.error(f"启动回调服务器失败: {e}") + return { + "success": False, + "error": f"无法启动OAuth回调服务器,端口{callback_port}: {str(e)}", + } + + # 创建OAuth流程 + # 根据模式选择配置 + if mode == "antigravity": + client_id = ANTIGRAVITY_CLIENT_ID + client_secret = ANTIGRAVITY_CLIENT_SECRET + scopes = ANTIGRAVITY_SCOPES + else: + client_id = CLIENT_ID + client_secret = CLIENT_SECRET + scopes = SCOPES + + flow = Flow( + client_id=client_id, + client_secret=client_secret, + scopes=scopes, + redirect_uri=callback_url, + ) + + # 生成状态标识符,包含用户会话信息 + if user_session: + state = f"{user_session}_{str(uuid.uuid4())}" + else: + state = str(uuid.uuid4()) + + # 生成认证URL + auth_url = flow.get_auth_url(state=state) + + # 严格控制认证流程数量 - 超过限制时立即清理最旧的 + if len(auth_flows) >= MAX_AUTH_FLOWS: + # 清理最旧的认证流程 + oldest_state = min(auth_flows.keys(), key=lambda k: auth_flows[k].get("created_at", 0)) + try: + # 清理服务器资源 + old_flow = auth_flows[oldest_state] + if old_flow.get("server"): + server = old_flow["server"] + port = old_flow.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}") + + del auth_flows[oldest_state] + log.debug(f"Removed oldest auth flow: {oldest_state}") + + # 保存流程状态 + auth_flows[state] = { + "flow": flow, + "project_id": project_id, # 可能为None,稍后在回调时确定 + "user_session": user_session, + "callback_port": callback_port, # 存储分配的端口 + "callback_url": callback_url, # 存储完整回调URL + "server": callback_server, # 存储服务器实例 + "server_thread": server_thread, # 存储服务器线程 + "code": None, + "completed": False, + "created_at": time.time(), + "auto_project_detection": project_id is None, # 标记是否需要自动检测项目ID + "mode": mode, # 凭证模式 + } + + # 清理过期的流程(30分钟) + cleanup_expired_flows() + + log.info(f"OAuth流程已创建: state={state}, project_id={project_id}") + log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}") + log.info(f"为此认证流程分配的端口: {callback_port}") + + return { + "auth_url": auth_url, + "state": state, + "callback_port": callback_port, + "success": True, + "auto_project_detection": project_id is None, + "detected_project_id": project_id, + } + + except Exception as e: + log.error(f"创建认证URL失败: {e}") + return {"success": False, "error": str(e)} + + +def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]: + """同步等待OAuth回调完成,使用对应流程的专用服务器""" + if state not in auth_flows: + log.error(f"未找到状态为 {state} 的认证流程") + return None + + flow_data = auth_flows[state] + callback_port = flow_data["callback_port"] + + # 服务器已经在create_auth_url时启动了,这里只需要等待 + log.info(f"等待OAuth回调完成,端口: {callback_port}") + + # 等待回调完成 + start_time = time.time() + while time.time() - start_time < timeout: + if flow_data.get("code"): + log.info("OAuth回调成功完成") + return flow_data["code"] + time.sleep(0.5) # 每0.5秒检查一次 + + # 刷新flow_data引用 + if state in auth_flows: + flow_data = auth_flows[state] + + log.warning(f"等待OAuth回调超时 ({timeout}秒)") + return None + + +async def complete_auth_flow( + project_id: Optional[str] = None, user_session: str = None +) -> Dict[str, Any]: + """完成认证流程并保存凭证,支持自动检测项目ID""" + try: + # 查找对应的认证流程 + state = None + flow_data = None + + # 如果指定了project_id,先尝试匹配指定的项目 + if project_id: + for s, data in auth_flows.items(): + if data["project_id"] == project_id: + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + break + # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的 + elif not state: + state = s + flow_data = data + + # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程 + if not state: + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + break + # 使用第一个找到的需要自动检测的流程 + elif not state: + state = s + flow_data = data + + if not state or not flow_data: + return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + + if not project_id: + project_id = flow_data.get("project_id") + if not project_id: + project_id = DEFAULT_PROJECT_ID + log.warning(f"未获取到project_id,使用默认project_id: {project_id}") + + flow = flow_data["flow"] + + # 如果还没有授权码,需要等待回调 + if not flow_data.get("code"): + log.info(f"等待用户完成OAuth授权 (state: {state})") + auth_code = wait_for_callback_sync(state) + + if not auth_code: + return { + "success": False, + "error": "未接收到授权回调,请确保完成了浏览器中的OAuth认证", + } + + # 更新流程数据 + auth_flows[state]["code"] = auth_code + auth_flows[state]["completed"] = True + else: + auth_code = flow_data["code"] + + # 使用认证代码获取凭证 + with _OAuthLibPatcher(): + try: + credentials = await flow.exchange_code(auth_code) + # credentials 已经在 exchange_code 中获得 + + # 如果需要自动检测项目ID且没有提供项目ID + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("尝试通过API获取用户项目列表...") + log.info(f"使用的token: {credentials.access_token[:20]}...") + log.info(f"Token过期时间: {credentials.expires_at}") + user_projects = await get_user_projects(credentials) + + if user_projects: + # 如果只有一个项目,自动使用 + if len(user_projects) == 1: + # Google API returns projectId in camelCase + project_id = user_projects[0].get("projectId") + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择唯一项目: {project_id}") + # 如果有多个项目,尝试选择默认项目 + else: + project_id = await select_default_project(user_projects) + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择默认项目: {project_id}") + else: + # 返回项目列表让用户选择 + return { + "success": False, + "error": "请从以下项目中选择一个", + "requires_project_selection": True, + "available_projects": [ + { + # Google API returns projectId in camelCase + "project_id": p.get("projectId"), + "name": p.get("displayName") or p.get("projectId"), + "projectNumber": p.get("projectNumber"), + } + for p in user_projects + ], + } + else: + # 如果无法获取项目列表,使用默认project_id + project_id = DEFAULT_PROJECT_ID + flow_data["project_id"] = project_id + log.warning(f"无法获取项目列表,使用默认project_id: {project_id}") + + # 如果仍然没有项目ID,返回错误 + if not project_id: + project_id = DEFAULT_PROJECT_ID + flow_data["project_id"] = project_id + log.warning(f"仍未获取到project_id,使用默认project_id: {project_id}") + + # 保存凭证 + saved_filename = await save_credentials(credentials, project_id) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": flow_data.get("auto_project_detection", False), + } + + except Exception as e: + log.error(f"获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def asyncio_complete_auth_flow( + project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """异步完成认证流程,支持自动检测项目ID""" + try: + log.info( + f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}" + ) + + # 查找对应的认证流程 + state = None + flow_data = None + + log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}") + + # 如果指定了project_id,先尝试匹配指定的项目 + if project_id: + log.info(f"尝试匹配指定的项目ID: {project_id}") + for s, data in auth_flows.items(): + if data["project_id"] == project_id: + # 如果指定了用户会话,优先匹配相同会话的流程 + if user_session and data.get("user_session") == user_session: + state = s + flow_data = data + log.info(f"找到匹配的用户会话: {s}") + break + # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的 + elif not state: + state = s + flow_data = data + log.info(f"找到匹配的项目ID: {s}") + + # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程 + if not state: + log.info("没有找到指定项目的流程,查找自动检测流程") + # 首先尝试找到已完成的流程(有授权码的) + completed_flows = [] + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + if user_session and data.get("user_session") == user_session: + if data.get("code"): # 优先选择已完成的 + completed_flows.append((s, data, data.get("created_at", 0))) + + # 如果有已完成的流程,选择最新的 + if completed_flows: + completed_flows.sort(key=lambda x: x[2], reverse=True) # 按时间倒序 + state, flow_data, _ = completed_flows[0] + log.info(f"找到已完成的最新认证流程: {state}") + else: + # 如果没有已完成的,找最新的未完成流程 + pending_flows = [] + for s, data in auth_flows.items(): + if data.get("auto_project_detection", False): + if user_session and data.get("user_session") == user_session: + pending_flows.append((s, data, data.get("created_at", 0))) + elif not user_session: + pending_flows.append((s, data, data.get("created_at", 0))) + + if pending_flows: + pending_flows.sort(key=lambda x: x[2], reverse=True) # 按时间倒序 + state, flow_data, _ = pending_flows[0] + log.info(f"找到最新的待完成认证流程: {state}") + + if not state or not flow_data: + log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}") + log.debug(f"当前所有flow_data: {list(auth_flows.keys())}") + return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + + log.info(f"找到认证流程: state={state}") + log.info( + f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}" + ) + log.info(f"传入的project_id参数: {project_id}") + + # 如果需要自动检测项目ID且没有提供项目ID + log.info( + f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}" + ) + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("跳过自动检测项目ID,进入等待阶段") + elif not project_id: + log.info("进入project_id检查分支") + project_id = flow_data.get("project_id") + if not project_id: + project_id = DEFAULT_PROJECT_ID + flow_data["project_id"] = project_id + log.warning(f"缺少项目ID,使用默认project_id: {project_id}") + else: + log.info(f"使用提供的项目ID: {project_id}") + + # 检查是否已经有授权码 + log.info("开始检查OAuth授权码...") + log.info(f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}") + log.info(f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}") + max_wait_time = 60 # 最多等待60秒 + wait_interval = 1 # 每秒检查一次 + waited = 0 + + while waited < max_wait_time: + if flow_data.get("code"): + log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)") + break + + # 每5秒输出一次提示 + if waited % 5 == 0 and waited > 0: + log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)") + log.debug(f"当前state: {state}, flow_data keys: {list(flow_data.keys())}") + + # 异步等待 + await asyncio.sleep(wait_interval) + waited += wait_interval + + # 刷新flow_data引用,因为可能被回调更新了 + if state in auth_flows: + flow_data = auth_flows[state] + + if not flow_data.get("code"): + log.error(f"等待OAuth回调超时,等待了{waited}秒") + return { + "success": False, + "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面", + } + + flow = flow_data["flow"] + auth_code = flow_data["code"] + + log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}") + + # 使用认证代码获取凭证 + with _OAuthLibPatcher(): + try: + log.info("调用flow.exchange_code...") + credentials = await flow.exchange_code(auth_code) + log.info( + f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}..." + ) + + log.info( + f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}" + ) + + # 检查凭证模式 + cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + if cred_mode == "antigravity": + log.info("Antigravity模式:从API获取project_id...") + # 使用API获取project_id + antigravity_url = await get_code_assist_endpoint() + project_id, subscription_tier = await fetch_project_id_and_tier( + credentials.access_token, + ANTIGRAVITY_USER_AGENT, + antigravity_url + ) + if project_id: + log.info(f"成功从API获取project_id: {project_id}, tier: {subscription_tier}") + else: + project_id = DEFAULT_PROJECT_ID + log.warning(f"无法从API获取project_id,使用默认project_id: {project_id}") + + # 保存antigravity凭证 + saved_filename = await save_credentials(credentials, project_id, mode="antigravity", subscription_tier=subscription_tier) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity", subscription_tier=subscription_tier) + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("Antigravity OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": False, + "mode": "antigravity", + } + + # 如果需要自动检测项目ID且没有提供项目ID(标准模式) + if flow_data.get("auto_project_detection", False) and not project_id: + log.info("标准模式:从API获取project_id...") + # 使用API获取project_id(使用标准模式的User-Agent) + code_assist_url = await get_code_assist_endpoint() + project_id, subscription_tier = await fetch_project_id_and_tier( + credentials.access_token, + GEMINICLI_USER_AGENT, + code_assist_url + ) + if project_id: + flow_data["project_id"] = project_id + log.info(f"成功从API获取project_id: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + else: + log.warning("无法从API获取project_id,回退到项目列表获取方式") + # 回退到原来的项目列表获取方式 + user_projects = await get_user_projects(credentials) + + if user_projects: + # 如果只有一个项目,自动使用 + if len(user_projects) == 1: + # Google API returns projectId in camelCase + project_id = user_projects[0].get("projectId") + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择唯一项目: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + # 如果有多个项目,尝试选择默认项目 + else: + project_id = await select_default_project(user_projects) + if project_id: + flow_data["project_id"] = project_id + log.info(f"自动选择默认项目: {project_id}") + # 自动启用必需的API服务 + log.info("正在自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + else: + # 返回项目列表让用户选择 + return { + "success": False, + "error": "请从以下项目中选择一个", + "requires_project_selection": True, + "available_projects": [ + { + # Google API returns projectId in camelCase + "project_id": p.get("projectId"), + "name": p.get("displayName") or p.get("projectId"), + "projectNumber": p.get("projectNumber"), + } + for p in user_projects + ], + } + else: + # 如果无法获取项目列表,使用默认project_id + project_id = DEFAULT_PROJECT_ID + flow_data["project_id"] = project_id + log.warning(f"无法获取项目列表,使用默认project_id: {project_id}") + elif project_id: + # 如果已经有项目ID(手动提供或环境检测),也尝试启用API服务 + log.info("正在为已提供的项目ID自动启用必需的API服务...") + await enable_required_apis(credentials, project_id) + + # 如果仍然没有项目ID,返回错误 + if not project_id: + project_id = DEFAULT_PROJECT_ID + flow_data["project_id"] = project_id + log.warning(f"仍未获取到project_id,使用默认project_id: {project_id}") + + # 保存凭证 + saved_filename = await save_credentials(credentials, project_id) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": flow_data.get("auto_project_detection", False), + } + + except Exception as e: + log.error(f"获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"异步完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def complete_auth_flow_from_callback_url( + callback_url: str, project_id: Optional[str] = None, mode: str = "geminicli" +) -> Dict[str, Any]: + """从回调URL直接完成认证流程,无需启动本地服务器""" + try: + log.info(f"开始从回调URL完成认证: {callback_url}") + + # 解析回调URL + parsed_url = urlparse(callback_url) + query_params = parse_qs(parsed_url.query) + + # 验证必要参数 + if "state" not in query_params or "code" not in query_params: + return {"success": False, "error": "回调URL缺少必要参数 (state 或 code)"} + + state = query_params["state"][0] + code = query_params["code"][0] + + log.info(f"从URL解析到: state={state}, code=xxx...") + + # 检查是否有对应的认证流程 + if state not in auth_flows: + return { + "success": False, + "error": f"未找到对应的认证流程,请先启动认证 (state: {state})", + } + + flow_data = auth_flows[state] + flow = flow_data["flow"] + + # 构造回调URL(使用flow中存储的redirect_uri) + redirect_uri = flow.redirect_uri + log.info(f"使用redirect_uri: {redirect_uri}") + + try: + # 使用authorization code获取token + credentials = await flow.exchange_code(code) + log.info("成功获取访问令牌") + + # 检查凭证模式 + cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + if cred_mode == "antigravity": + log.info("Antigravity模式(从回调URL):从API获取project_id...") + # 使用API获取project_id + antigravity_url = await get_code_assist_endpoint() + project_id, subscription_tier = await fetch_project_id_and_tier( + credentials.access_token, + ANTIGRAVITY_USER_AGENT, + antigravity_url + ) + if project_id: + log.info(f"成功从API获取project_id: {project_id}, tier: {subscription_tier}") + else: + project_id = DEFAULT_PROJECT_ID + log.warning(f"无法从API获取project_id,使用默认project_id: {project_id}") + + # 保存antigravity凭证 + saved_filename = await save_credentials(credentials, project_id, mode="antigravity", subscription_tier=subscription_tier) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity", subscription_tier=subscription_tier) + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("从回调URL完成Antigravity OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": False, + "mode": "antigravity", + } + + # 标准模式的项目ID处理逻辑 + detected_project_id = None + auto_detected = False + subscription_tier = None + + if not project_id: + # 尝试使用fetch_project_id_and_tier自动获取项目ID + try: + log.info("标准模式:从API获取project_id...") + code_assist_url = await get_code_assist_endpoint() + detected_project_id, subscription_tier = await fetch_project_id_and_tier( + credentials.access_token, + GEMINICLI_USER_AGENT, + code_assist_url + ) + if detected_project_id: + auto_detected = True + log.info(f"成功从API获取project_id: {detected_project_id}, tier: {subscription_tier}") + else: + log.warning("无法从API获取project_id,回退到项目列表获取方式") + # 回退到原来的项目列表获取方式 + projects = await get_user_projects(credentials) + if projects: + if len(projects) == 1: + # 只有一个项目,自动使用 + # Google API returns projectId in camelCase + detected_project_id = projects[0]["projectId"] + auto_detected = True + log.info(f"自动检测到唯一项目ID: {detected_project_id}") + else: + # 多个项目,自动选择第一个 + # Google API returns projectId in camelCase + detected_project_id = projects[0]["projectId"] + auto_detected = True + log.info( + f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}" + ) + log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}") + else: + # 没有项目访问权限,使用默认project_id + detected_project_id = DEFAULT_PROJECT_ID + auto_detected = False + log.warning(f"未检测到可访问项目,使用默认project_id: {detected_project_id}") + except Exception as e: + log.warning(f"自动检测项目ID失败: {e},使用默认project_id") + detected_project_id = DEFAULT_PROJECT_ID + auto_detected = False + else: + detected_project_id = project_id + + # 启用必需的API服务 + if detected_project_id: + try: + log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...") + await enable_required_apis(credentials, detected_project_id) + except Exception as e: + log.warning(f"启用API服务失败: {e}") + + # 保存凭证 + saved_filename = await save_credentials(credentials, detected_project_id, subscription_tier=subscription_tier) + + # 准备返回的凭证数据 + creds_data = _prepare_credentials_data(credentials, detected_project_id, mode="geminicli", subscription_tier=subscription_tier) + + # 清理使用过的流程 + _cleanup_auth_flow_server(state) + + log.info("从回调URL完成OAuth认证成功,凭证已保存") + return { + "success": True, + "credentials": creds_data, + "file_path": saved_filename, + "auto_detected_project": auto_detected, + } + + except Exception as e: + log.error(f"从回调URL获取凭证失败: {e}") + return {"success": False, "error": f"获取凭证失败: {str(e)}"} + + except Exception as e: + log.error(f"从回调URL完成认证流程失败: {e}") + return {"success": False, "error": str(e)} + + +async def save_credentials(creds: Credentials, project_id: str, mode: str = "geminicli", subscription_tier: str = None) -> str: + """通过统一存储系统保存凭证""" + # 生成文件名(使用project_id和时间戳) + timestamp = int(time.time()) + + # antigravity模式使用特殊前缀 + if mode == "antigravity": + filename = f"ag_{project_id}-{timestamp}.json" + else: + filename = f"{project_id}-{timestamp}.json" + + # 准备凭证数据 + creds_data = _prepare_credentials_data(creds, project_id, mode, subscription_tier) + + # 通过存储适配器保存 + storage_adapter = await get_storage_adapter() + success = await storage_adapter.store_credential(filename, creds_data, mode=mode) + + if success: + # 创建默认状态记录 + try: + default_state = { + "error_codes": [], + "disabled": False, + "last_success": time.time(), + "user_email": None, + "tier": subscription_tier, + } + await storage_adapter.update_credential_state(filename, default_state, mode=mode) + log.info(f"凭证和状态已保存到: {filename} (mode={mode})") + except Exception as e: + log.warning(f"创建默认状态记录失败 {filename}: {e}") + + return filename + else: + raise Exception(f"保存凭证失败: {filename}") + + +def async_shutdown_server(server, port): + """异步关闭OAuth回调服务器,避免阻塞主流程""" + + def shutdown_server_async(): + try: + # 设置一个标志来跟踪关闭状态 + shutdown_completed = threading.Event() + + def do_shutdown(): + try: + server.shutdown() + server.server_close() + shutdown_completed.set() + log.info(f"已关闭端口 {port} 的OAuth回调服务器") + except Exception as e: + shutdown_completed.set() + log.debug(f"关闭服务器时出错: {e}") + + # 在单独线程中执行关闭操作 + shutdown_worker = threading.Thread(target=do_shutdown, daemon=True) + shutdown_worker.start() + + # 等待最多5秒,如果超时就放弃等待 + if shutdown_completed.wait(timeout=5): + log.debug(f"端口 {port} 服务器关闭完成") + else: + log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程") + + except Exception as e: + log.debug(f"异步关闭服务器时出错: {e}") + + # 在后台线程中关闭服务器,不阻塞主流程 + shutdown_thread = threading.Thread(target=shutdown_server_async, daemon=True) + shutdown_thread.start() + log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器") + + +def cleanup_expired_flows(): + """清理过期的认证流程""" + current_time = time.time() + EXPIRY_TIME = 600 # 10分钟过期 + + # 直接遍历删除,避免创建额外列表 + states_to_remove = [ + state + for state, flow_data in auth_flows.items() + if current_time - flow_data["created_at"] > EXPIRY_TIME + ] + + # 批量清理,提高效率 + cleaned_count = 0 + for state in states_to_remove: + flow_data = auth_flows.get(state) + if flow_data: + # 快速关闭可能存在的服务器 + try: + if flow_data.get("server"): + server = flow_data["server"] + port = flow_data.get("callback_port") + async_shutdown_server(server, port) + except Exception as e: + log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}") + + # 显式清理流程数据,释放内存 + flow_data.clear() + del auth_flows[state] + cleaned_count += 1 + + if cleaned_count > 0: + log.info(f"清理了 {cleaned_count} 个过期的认证流程") + + # 更积极的垃圾回收触发条件 + if len(auth_flows) > 20: # 降低阈值 + import gc + + gc.collect() + log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}") + + +def get_auth_status(project_id: str) -> Dict[str, Any]: + """获取认证状态""" + for state, flow_data in auth_flows.items(): + if flow_data["project_id"] == project_id: + return { + "status": "completed" if flow_data["completed"] else "pending", + "state": state, + "created_at": flow_data["created_at"], + } + + return {"status": "not_found"} + + +# 鉴权功能 - 使用更小的数据结构 +auth_tokens = {} # 存储有效的认证令牌 +TOKEN_EXPIRY = 3600 # 1小时令牌过期时间 + + +async def verify_password(password: str) -> bool: + """验证密码(面板登录使用)""" + from config import get_panel_password + + correct_password = await get_panel_password() + return password == correct_password diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..65c40b5bbac895fc60dd55236ac8cf85071532a7 --- /dev/null +++ b/src/converter/anthropic2gemini.py @@ -0,0 +1,1260 @@ +""" +Anthropic 到 Gemini 格式转换器 + +提供请求体、响应和流式转换的完整功能。 +""" +from __future__ import annotations + +import json +import os +import uuid +from typing import Any, AsyncIterator, Dict, List, Optional + +from fastapi import Response +from log import log +from src.converter.utils import merge_system_messages + +from src.converter.thoughtSignature_fix import ( + encode_tool_id_with_signature, + decode_tool_id_and_signature +) + +DEFAULT_TEMPERATURE = 0.4 +_DEBUG_TRUE = {"1", "true", "yes", "on"} + +# ============================================================================ +# Thinking 块验证和清理 +# ============================================================================ + +# 最小有效签名长度 +MIN_SIGNATURE_LENGTH = 10 + + +def has_valid_thoughtsignature(block: Dict[str, Any]) -> bool: + """ + 检查 thinking 块是否有有效签名 + + Args: + block: content block 字典 + + Returns: + bool: 是否有有效签名 + """ + if not isinstance(block, dict): + return True + + block_type = block.get("type") + if block_type not in ("thinking", "redacted_thinking"): + return True # 非 thinking 块默认有效 + + thinking = block.get("thinking", "") + thoughtsignature = block.get("thoughtSignature") + + # 空 thinking + 任意 thoughtsignature = 有效 (trailing signature case) + if not thinking and thoughtsignature is not None: + return True + + # 有内容 + 足够长度的 thoughtsignature = 有效 + if thoughtsignature and isinstance(thoughtsignature, str) and len(thoughtsignature) >= MIN_SIGNATURE_LENGTH: + return True + + return False + + +def sanitize_thinking_block(block: Dict[str, Any]) -> Dict[str, Any]: + """ + 清理 thinking 块,只保留必要字段(移除 cache_control 等) + + Args: + block: content block 字典 + + Returns: + 清理后的 block 字典 + """ + if not isinstance(block, dict): + return block + + block_type = block.get("type") + if block_type not in ("thinking", "redacted_thinking"): + return block + + # 重建块,移除额外字段 + sanitized: Dict[str, Any] = { + "type": block_type, + "thinking": block.get("thinking", "") + } + + thoughtsignature = block.get("thoughtSignature") + if thoughtsignature: + sanitized["thoughtSignature"] = thoughtsignature + + return sanitized + + +def remove_trailing_unsigned_thinking(blocks: List[Dict[str, Any]]) -> None: + """ + 移除尾部的无签名 thinking 块 + + Args: + blocks: content blocks 列表 (会被修改) + """ + if not blocks: + return + + # 从后向前扫描 + end_index = len(blocks) + for i in range(len(blocks) - 1, -1, -1): + block = blocks[i] + if not isinstance(block, dict): + break + + block_type = block.get("type") + if block_type in ("thinking", "redacted_thinking"): + if not has_valid_thoughtsignature(block): + end_index = i + else: + break # 遇到有效签名的 thinking 块,停止 + else: + break # 遇到非 thinking 块,停止 + + if end_index < len(blocks): + removed = len(blocks) - end_index + del blocks[end_index:] + log.debug(f"Removed {removed} trailing unsigned thinking block(s)") + + +def filter_invalid_thinking_blocks(messages: List[Dict[str, Any]]) -> None: + """ + 过滤消息中的无效 thinking 块,并清理所有 thinking 块的额外字段(如 cache_control) + + Args: + messages: Anthropic messages 列表 (会被修改) + """ + total_filtered = 0 + + for msg in messages: + # 只处理 assistant 和 model 消息 + role = msg.get("role", "") + if role not in ("assistant", "model"): + continue + + content = msg.get("content") + if not isinstance(content, list): + continue + + original_len = len(content) + new_blocks: List[Dict[str, Any]] = [] + + for block in content: + if not isinstance(block, dict): + new_blocks.append(block) + continue + + block_type = block.get("type") + if block_type not in ("thinking", "redacted_thinking"): + new_blocks.append(block) + continue + + # 所有 thinking 块都需要清理(移除 cache_control 等额外字段) + # 检查 thinking 块的有效性 + if has_valid_thoughtsignature(block): + # 有效签名,清理后保留 + new_blocks.append(sanitize_thinking_block(block)) + else: + # 无效签名,将内容转换为 text 块 + thinking_text = block.get("thinking", "") + if thinking_text and str(thinking_text).strip(): + log.info( + f"[Claude-Handler] Converting thinking block with invalid thoughtSignature to text. " + f"Content length: {len(thinking_text)} chars" + ) + new_blocks.append({"type": "text", "text": thinking_text}) + else: + log.debug("[Claude-Handler] Dropping empty thinking block with invalid thoughtSignature") + + msg["content"] = new_blocks + filtered_count = original_len - len(new_blocks) + total_filtered += filtered_count + + # 如果过滤后为空,添加一个空文本块以保持消息有效 + if not new_blocks: + msg["content"] = [{"type": "text", "text": ""}] + + if total_filtered > 0: + log.debug(f"Filtered {total_filtered} invalid thinking block(s) from history") + + +# ============================================================================ +# 请求验证和提取 +# ============================================================================ + + +def _anthropic_debug_enabled() -> bool: + """检查是否启用 Anthropic 调试模式""" + return str(os.getenv("ANTHROPIC_DEBUG", "true")).strip().lower() in _DEBUG_TRUE + + +def _is_non_whitespace_text(value: Any) -> bool: + """ + 判断文本是否包含"非空白"内容。 + + 说明:下游(Antigravity/Claude 兼容层)会对纯 text 内容块做校验: + - text 不能为空字符串 + - text 不能仅由空白字符(空格/换行/制表等)组成 + """ + if value is None: + return False + try: + return bool(str(value).strip()) + except Exception: + return False + + +def _remove_nulls_for_tool_input(value: Any) -> Any: + """ + 递归移除 dict/list 中值为 null/None 的字段/元素。 + + 背景:Roo/Kilo 在 Anthropic native tool 路径下,若收到 tool_use.input 中包含 null, + 可能会把 null 当作真实入参执行(例如"在 null 中搜索")。 + """ + if isinstance(value, dict): + cleaned: Dict[str, Any] = {} + for k, v in value.items(): + if v is None: + continue + cleaned[k] = _remove_nulls_for_tool_input(v) + return cleaned + + if isinstance(value, list): + cleaned_list = [] + for item in value: + if item is None: + continue + cleaned_list.append(_remove_nulls_for_tool_input(item)) + return cleaned_list + + return value + +# ============================================================================ +# 2. JSON Schema 清理 +# ============================================================================ + +def clean_json_schema(schema: Any) -> Any: + """ + 清理 JSON Schema,移除下游不支持的字段,并把验证要求追加到 description。 + """ + if not isinstance(schema, dict): + return schema + + # 下游不支持的字段 + unsupported_keys = { + "$schema", "$id", "$ref", "$defs", "definitions", "title", + "example", "examples", "readOnly", "writeOnly", "default", + "exclusiveMaximum", "exclusiveMinimum", "oneOf", "anyOf", "allOf", + "const", "additionalItems", "contains", "patternProperties", + "dependencies", "propertyNames", "if", "then", "else", + "contentEncoding", "contentMediaType", + } + + validation_fields = { + "minLength": "minLength", + "maxLength": "maxLength", + "minimum": "minimum", + "maximum": "maximum", + "minItems": "minItems", + "maxItems": "maxItems", + } + fields_to_remove = {"additionalProperties"} + + validations: List[str] = [] + for field, label in validation_fields.items(): + if field in schema: + validations.append(f"{label}: {schema[field]}") + + cleaned: Dict[str, Any] = {} + for key, value in schema.items(): + if key in unsupported_keys or key in fields_to_remove or key in validation_fields: + continue + + if key == "type" and isinstance(value, list): + # type: ["string", "null"] -> type: "string", nullable: true + has_null = any( + isinstance(t, str) and t.strip() and t.strip().lower() == "null" for t in value + ) + non_null_types = [ + t.strip() + for t in value + if isinstance(t, str) and t.strip() and t.strip().lower() != "null" + ] + + cleaned[key] = non_null_types[0] if non_null_types else "string" + if has_null: + cleaned["nullable"] = True + continue + + if key == "description" and validations: + cleaned[key] = f"{value} ({', '.join(validations)})" + elif isinstance(value, dict): + cleaned[key] = clean_json_schema(value) + elif isinstance(value, list): + cleaned[key] = [clean_json_schema(item) if isinstance(item, dict) else item for item in value] + else: + cleaned[key] = value + + if validations and "description" not in cleaned: + cleaned["description"] = f"Validation: {', '.join(validations)}" + + # 如果有 properties 但没有显式 type,则补齐为 object + if "properties" in cleaned and "type" not in cleaned: + cleaned["type"] = "object" + + return cleaned + + +# ============================================================================ +# 4. Tools 转换 +# ============================================================================ + +def convert_tools(anthropic_tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]: + """ + 将 Anthropic tools[] 转换为下游 tools(functionDeclarations)结构。 + """ + if not anthropic_tools: + return None + + gemini_tools: List[Dict[str, Any]] = [] + for tool in anthropic_tools: + name = tool.get("name", "nameless_function") + description = tool.get("description", "") + input_schema = tool.get("input_schema", {}) or {} + parameters = clean_json_schema(input_schema) + + gemini_tools.append( + { + "functionDeclarations": [ + { + "name": name, + "description": description, + "parameters": parameters, + } + ] + } + ) + + return gemini_tools or None + + +# ============================================================================ +# 5. Messages 转换 +# ============================================================================ + +def _extract_tool_result_output(content: Any) -> str: + """从 tool_result.content 中提取输出字符串""" + if isinstance(content, list): + if not content: + return "" + first = content[0] + if isinstance(first, dict) and first.get("type") == "text": + return str(first.get("text", "")) + return str(first) + if content is None: + return "" + return str(content) + + +def convert_messages_to_contents( + messages: List[Dict[str, Any]], + *, + include_thinking: bool = True +) -> List[Dict[str, Any]]: + """ + 将 Anthropic messages[] 转换为下游 contents[](role: user/model, parts: [])。 + + Args: + messages: Anthropic 格式的消息列表 + include_thinking: 是否包含 thinking 块 + """ + contents: List[Dict[str, Any]] = [] + + # 第一遍:构建 tool_use_id -> (name, thoughtsignature) 的映射 + # 注意:存储的是编码后的 ID(可能包含签名) + tool_use_info: Dict[str, tuple[str, Optional[str]]] = {} + for msg in messages: + raw_content = msg.get("content", "") + if isinstance(raw_content, list): + for item in raw_content: + if isinstance(item, dict) and item.get("type") == "tool_use": + encoded_tool_id = item.get("id") + tool_name = item.get("name") + if encoded_tool_id and tool_name: + # 解码获取原始ID和签名 + original_id, thoughtsignature = decode_tool_id_and_signature(encoded_tool_id) + # 存储映射:编码ID -> (name, thoughtsignature) + tool_use_info[str(encoded_tool_id)] = (tool_name, thoughtsignature) + + for msg in messages: + role = msg.get("role", "user") + + # system 消息已经由 merge_system_messages 处理,这里跳过 + if role == "system": + continue + + # 支持 'assistant' 和 'model' 角色(Google history usage) + gemini_role = "model" if role in ("assistant", "model") else "user" + raw_content = msg.get("content", "") + + parts: List[Dict[str, Any]] = [] + if isinstance(raw_content, str): + if _is_non_whitespace_text(raw_content): + parts = [{"text": str(raw_content)}] + elif isinstance(raw_content, list): + for item in raw_content: + if not isinstance(item, dict): + if _is_non_whitespace_text(item): + parts.append({"text": str(item)}) + continue + + item_type = item.get("type") + if item_type == "thinking": + if not include_thinking: + continue + + thinking_text = item.get("thinking", "") + if thinking_text is None: + thinking_text = "" + + part: Dict[str, Any] = { + "text": str(thinking_text), + "thought": True, + } + + # 如果有 thoughtsignature 则添加 + thoughtsignature = item.get("thoughtSignature") + if thoughtsignature: + part["thoughtSignature"] = thoughtsignature + + parts.append(part) + elif item_type == "redacted_thinking": + if not include_thinking: + continue + + thinking_text = item.get("thinking") + if thinking_text is None: + thinking_text = item.get("data", "") + + part_dict: Dict[str, Any] = { + "text": str(thinking_text or ""), + "thought": True, + } + + # 如果有 thoughtsignature 则添加 + thoughtsignature = item.get("thoughtSignature") + if thoughtsignature: + part_dict["thoughtSignature"] = thoughtsignature + + parts.append(part_dict) + elif item_type == "text": + text = item.get("text", "") + if _is_non_whitespace_text(text): + parts.append({"text": str(text)}) + elif item_type == "image": + source = item.get("source", {}) or {} + if source.get("type") == "base64": + parts.append( + { + "inlineData": { + "mimeType": source.get("media_type", "image/png"), + "data": source.get("data", ""), + } + } + ) + elif item_type == "tool_use": + encoded_id = item.get("id") or "" + original_id, thoughtsignature = decode_tool_id_and_signature(encoded_id) + + fc_part: Dict[str, Any] = { + "functionCall": { + "id": original_id, # 使用原始ID,不带签名 + "name": item.get("name"), + "args": item.get("input", {}) or {}, + } + } + + # 如果提取到签名则添加,否则使用占位符以满足 Gemini API 要求 + if thoughtsignature: + fc_part["thoughtSignature"] = thoughtsignature + else: + fc_part["thoughtSignature"] = "skip_thought_signature_validator" + + parts.append(fc_part) + elif item_type == "tool_result": + output = _extract_tool_result_output(item.get("content")) + encoded_tool_use_id = item.get("tool_use_id") or "" + + # 解码获取原始ID(functionResponse不需要签名) + original_tool_use_id, _ = decode_tool_id_and_signature(encoded_tool_use_id) + + # 从 tool_result 获取 name,如果没有则从映射中查找 + func_name = item.get("name") + if not func_name and encoded_tool_use_id: + # 使用编码ID查找映射 + tool_info = tool_use_info.get(str(encoded_tool_use_id)) + if tool_info: + func_name = tool_info[0] # 获取 name + if not func_name: + func_name = "unknown_function" + + parts.append( + { + "functionResponse": { + "id": original_tool_use_id, # 使用解码后的原始ID以匹配functionCall + "name": func_name, + "response": {"output": output}, + } + } + ) + else: + parts.append({"text": json.dumps(item, ensure_ascii=False)}) + else: + if _is_non_whitespace_text(raw_content): + parts = [{"text": str(raw_content)}] + + if not parts: + continue + + contents.append({"role": gemini_role, "parts": parts}) + + return contents + + +def reorganize_tool_messages(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 重新组织消息,满足 tool_use/tool_result 约束。 + """ + tool_results: Dict[str, Dict[str, Any]] = {} + + for msg in contents: + for part in msg.get("parts", []) or []: + if isinstance(part, dict) and "functionResponse" in part: + tool_id = (part.get("functionResponse") or {}).get("id") + if tool_id: + tool_results[str(tool_id)] = part + + flattened: List[Dict[str, Any]] = [] + for msg in contents: + role = msg.get("role") + for part in msg.get("parts", []) or []: + flattened.append({"role": role, "parts": [part]}) + + new_contents: List[Dict[str, Any]] = [] + i = 0 + while i < len(flattened): + msg = flattened[i] + part = msg["parts"][0] + + if isinstance(part, dict) and "functionResponse" in part: + i += 1 + continue + + if isinstance(part, dict) and "functionCall" in part: + tool_id = (part.get("functionCall") or {}).get("id") + new_contents.append({"role": "model", "parts": [part]}) + + if tool_id is not None and str(tool_id) in tool_results: + new_contents.append({"role": "user", "parts": [tool_results[str(tool_id)]]}) + + i += 1 + continue + + new_contents.append(msg) + i += 1 + + return new_contents + + +# ============================================================================ +# 7. Tool Choice 转换 +# ============================================================================ + +def convert_tool_choice_to_tool_config(tool_choice: Any) -> Optional[Dict[str, Any]]: + """ + 将 Anthropic tool_choice 转换为 Gemini toolConfig + + Args: + tool_choice: Anthropic 格式的 tool_choice + - {"type": "auto"}: 模型自动决定是否使用工具 + - {"type": "any"}: 模型必须使用工具 + - {"type": "tool", "name": "tool_name"}: 模型必须使用指定工具 + + Returns: + Gemini 格式的 toolConfig,如果无效则返回 None + """ + if not tool_choice: + return None + + if isinstance(tool_choice, dict): + choice_type = tool_choice.get("type") + + if choice_type == "auto": + return {"functionCallingConfig": {"mode": "AUTO"}} + elif choice_type == "any": + return {"functionCallingConfig": {"mode": "ANY"}} + elif choice_type == "tool": + tool_name = tool_choice.get("name") + if tool_name: + return { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [tool_name], + } + } + + # 无效或不支持的 tool_choice,返回 None + return None + + +# ============================================================================ +# 8. Generation Config 构建 +# ============================================================================ + +def build_generation_config(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据 Anthropic Messages 请求构造下游 generationConfig。 + + Returns: + generation_config: 生成配置字典 + """ + config: Dict[str, Any] = { + "topP": 1, + "candidateCount": 1, + "stopSequences": [ + "<|user|>", + "<|bot|>", + "<|context_request|>", + "<|endoftext|>", + "<|end_of_turn|>", + ], + } + + temperature = payload.get("temperature", None) + config["temperature"] = DEFAULT_TEMPERATURE if temperature is None else temperature + + top_p = payload.get("top_p", None) + if top_p is not None: + config["topP"] = top_p + + top_k = payload.get("top_k", None) + if top_k is not None: + config["topK"] = top_k + + max_tokens = payload.get("max_tokens") + if max_tokens is not None: + config["maxOutputTokens"] = max_tokens + + # 处理 extended thinking 参数 (plan mode) + thinking = payload.get("thinking") + is_plan_mode = False + if thinking and isinstance(thinking, dict): + thinking_type = thinking.get("type") + budget_tokens = thinking.get("budget_tokens") + + # 如果启用了 extended thinking,设置 thinkingConfig + if thinking_type == "enabled": + is_plan_mode = True + thinking_config: Dict[str, Any] = {} + + # 设置思考预算,默认使用较大的值以支持计划模式 + if budget_tokens is not None: + thinking_config["thinkingBudget"] = budget_tokens + else: + # 默认给一个较大的思考预算以支持完整的计划生成 + thinking_config["thinkingBudget"] = 48000 + + # 始终包含思考内容,这样才能看到计划 + thinking_config["includeThoughts"] = True + + config["thinkingConfig"] = thinking_config + log.info(f"[ANTHROPIC2GEMINI] Extended thinking enabled with budget: {thinking_config['thinkingBudget']}") + elif thinking_type == "disabled": + # 明确禁用思考模式 + config["thinkingConfig"] = { + "includeThoughts": False + } + log.info("[ANTHROPIC2GEMINI] Extended thinking explicitly disabled") + + stop_sequences = payload.get("stop_sequences") + if isinstance(stop_sequences, list) and stop_sequences: + config["stopSequences"] = config["stopSequences"] + [str(s) for s in stop_sequences] + elif is_plan_mode: + # Plan mode 时清空默认 stop sequences,避免过早停止 + # 默认的 stop sequences 可能会导致模型在生成计划时过早停止 + config["stopSequences"] = [] + log.info("[ANTHROPIC2GEMINI] Plan mode: cleared default stop sequences to prevent premature stopping") + + # 如果不是 plan mode 且没有自定义 stop_sequences,保持默认值 + # (默认值已经在 config 初始化时设置) + + return config + + +# ============================================================================ +# 8. 主要转换函数 +# ============================================================================ + +async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 将 Anthropic 格式请求体转换为 Gemini 格式请求体 + + 注意: 此函数只负责基础转换,不包含 normalize_gemini_request 中的处理 + (如 thinking config 自动设置、search tools、参数范围限制等) + + Args: + payload: Anthropic 格式的请求体字典 + + Returns: + Gemini 格式的请求体字典,包含: + - contents: 转换后的消息内容 + - generationConfig: 生成配置 + - systemInstruction: 系统指令 (如果有) + - tools: 工具定义 (如果有) + - toolConfig: 工具调用配置 (如果有 tool_choice) + """ + # 处理连续的system消息(兼容性模式) + payload = await merge_system_messages(payload) + + # 提取和转换基础信息 + messages = payload.get("messages") or [] + if not isinstance(messages, list): + messages = [] + + # [CRITICAL FIX] 过滤并修复 Thinking 块签名 + # 在转换前先过滤无效的 thinking 块 + filter_invalid_thinking_blocks(messages) + + # 构建生成配置 + generation_config = build_generation_config(payload) + + # 转换消息内容(始终包含thinking块,由响应端处理) + contents = convert_messages_to_contents(messages, include_thinking=True) + + # [CRITICAL FIX] 移除尾部无签名的 thinking 块 + # 对真实请求应用额外的清理 + for content in contents: + role = content.get("role", "") + if role == "model": # 只处理 model/assistant 消息 + parts = content.get("parts", []) + if isinstance(parts, list): + remove_trailing_unsigned_thinking(parts) + + contents = reorganize_tool_messages(contents) + + # 转换工具 + tools = convert_tools(payload.get("tools")) + + # 转换 tool_choice + tool_config = convert_tool_choice_to_tool_config(payload.get("tool_choice")) + + # 构建基础请求数据 + gemini_request = { + "contents": contents, + "generationConfig": generation_config, + } + + # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 + if "systemInstruction" in payload: + gemini_request["systemInstruction"] = payload["systemInstruction"] + + if tools: + gemini_request["tools"] = tools + + # 添加 toolConfig(如果有 tool_choice) + if tool_config: + gemini_request["toolConfig"] = tool_config + + return gemini_request + + +def gemini_to_anthropic_response( + gemini_response: Dict[str, Any], + model: str, + status_code: int = 200 +) -> Dict[str, Any]: + """ + 将 Gemini 格式非流式响应转换为 Anthropic 格式非流式响应 + + 注意: 如果收到的不是 200 开头的响应体,不做任何处理,直接转发 + + Args: + gemini_response: Gemini 格式的响应体字典 + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Returns: + Anthropic 格式的响应体字典,或原始响应 (如果状态码不是 2xx) + """ + # 非 2xx 状态码直接返回原始响应 + if not (200 <= status_code < 300): + return gemini_response + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_response: + response_data = gemini_response["response"] + else: + response_data = gemini_response + + # 提取候选结果 + candidate = response_data.get("candidates", [{}])[0] or {} + parts = candidate.get("content", {}).get("parts", []) or [] + + # 获取 usage metadata + usage_metadata = {} + if "usageMetadata" in response_data: + usage_metadata = response_data["usageMetadata"] + elif "usageMetadata" in candidate: + usage_metadata = candidate["usageMetadata"] + + # 转换内容块 + content = [] + has_tool_use = False + + for part in parts: + if not isinstance(part, dict): + continue + + # 处理 thinking 块 + if part.get("thought") is True: + thinking_text = part.get("text", "") + if thinking_text is None: + thinking_text = "" + + block: Dict[str, Any] = {"type": "thinking", "thinking": str(thinking_text)} + + # 如果有 thoughtsignature 则添加 + thoughtsignature = part.get("thoughtSignature") + if thoughtsignature: + block["thoughtSignature"] = thoughtsignature + + content.append(block) + continue + + # 处理文本块 + if "text" in part: + content.append({"type": "text", "text": part.get("text", "")}) + continue + + # 处理工具调用 + if "functionCall" in part: + has_tool_use = True + fc = part.get("functionCall", {}) or {} + original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" + thoughtsignature = part.get("thoughtSignature") + + # 对工具调用ID进行签名编码 + encoded_id = encode_tool_id_with_signature(original_id, thoughtsignature) + content.append( + { + "type": "tool_use", + "id": encoded_id, + "name": fc.get("name") or "", + "input": _remove_nulls_for_tool_input(fc.get("args", {}) or {}), + } + ) + continue + + # 处理图片 + if "inlineData" in part: + inline = part.get("inlineData", {}) or {} + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": inline.get("mimeType", "image/png"), + "data": inline.get("data", ""), + }, + } + ) + continue + + # 确定停止原因 + finish_reason = candidate.get("finishReason") + + # 只有在正常停止(STOP)且有工具调用时才设为 tool_use + # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_use 导致循环 + if has_tool_use and finish_reason == "STOP": + stop_reason = "tool_use" + elif finish_reason == "MAX_TOKENS": + stop_reason = "max_tokens" + else: + # 其他情况(SAFETY、RECITATION 等)默认为 end_turn + stop_reason = "end_turn" + + # 提取 token 使用情况 + input_tokens = usage_metadata.get("promptTokenCount", 0) if isinstance(usage_metadata, dict) else 0 + output_tokens = usage_metadata.get("candidatesTokenCount", 0) if isinstance(usage_metadata, dict) else 0 + + # 构建 Anthropic 响应 + message_id = f"msg_{uuid.uuid4().hex}" + + return { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stop_reason, + "stop_sequence": None, + "usage": { + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + }, + } + + +async def gemini_stream_to_anthropic_stream( + gemini_stream: AsyncIterator[bytes], + model: str, + status_code: int = 200 +) -> AsyncIterator[bytes]: + """ + 将 Gemini 格式流式响应转换为 Anthropic SSE 格式流式响应 + + 注意: 如果收到的不是 200 开头的响应体,不做任何处理,直接转发 + + Args: + gemini_stream: Gemini 格式的流式响应 (bytes 迭代器) + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Yields: + Anthropic SSE 格式的响应块 (bytes) + """ + # 非 2xx 状态码直接转发原始流 + if not (200 <= status_code < 300): + async for chunk in gemini_stream: + yield chunk + return + + # 初始化状态 + message_id = f"msg_{uuid.uuid4().hex}" + message_start_sent = False + current_block_type: Optional[str] = None + current_block_index = -1 + current_thinking_signature: Optional[str] = None + has_tool_use = False + input_tokens = 0 + output_tokens = 0 + finish_reason: Optional[str] = None + + def _sse_event(event: str, data: Dict[str, Any]) -> bytes: + """生成 SSE 事件""" + payload = json.dumps(data, ensure_ascii=False, separators=(",", ":")) + return f"event: {event}\ndata: {payload}\n\n".encode("utf-8") + + def _close_block() -> Optional[bytes]: + """关闭当前内容块""" + nonlocal current_block_type + if current_block_type is None: + return None + event = _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_block_index}, + ) + current_block_type = None + return event + + # 处理流式数据 + try: + async for chunk in gemini_stream: + # 检查是否是 Response 对象(错误情况) + if isinstance(chunk, Response): + log.warning(f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},直接转发错误") + # 直接转发错误响应内容,不做格式转换 + error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + yield error_content + return + + # 记录接收到的原始chunk + log.debug(f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}") + + # 解析 Gemini 流式块 + if not chunk or not chunk.startswith(b"data: "): + log.debug(f"[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)") + continue + + raw = chunk[6:].strip() + if raw == b"[DONE]": + log.debug(f"[GEMINI_TO_ANTHROPIC] Received [DONE] marker") + break + + log.debug(f"[GEMINI_TO_ANTHROPIC] Parsing JSON: {raw[:200]}") + + try: + data = json.loads(raw.decode('utf-8', errors='ignore')) + log.debug(f"[GEMINI_TO_ANTHROPIC] Parsed data: {json.dumps(data, ensure_ascii=False)[:300]}") + except Exception as e: + log.warning(f"[GEMINI_TO_ANTHROPIC] JSON parse error: {e}") + continue + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in data: + response = data["response"] + else: + response = data + + candidate = (response.get("candidates", []) or [{}])[0] or {} + parts = (candidate.get("content", {}) or {}).get("parts", []) or [] + + # 更新 usage metadata + if "usageMetadata" in response: + usage = response["usageMetadata"] + if isinstance(usage, dict): + if "promptTokenCount" in usage: + input_tokens = int(usage.get("promptTokenCount", 0) or 0) + if "candidatesTokenCount" in usage: + output_tokens = int(usage.get("candidatesTokenCount", 0) or 0) + + # 发送 message_start(仅一次) + if not message_start_sent: + message_start_sent = True + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, + }, + }, + ) + + # 处理各种 parts + for part in parts: + if not isinstance(part, dict): + continue + + # 处理 thinking 块 + if part.get("thought") is True: + thinking_text = part.get("text", "") + thoughtsignature = part.get("thoughtSignature") + + # 检查是否需要关闭上一个块并开启新的 thinking 块 + if current_block_type != "thinking": + close_evt = _close_block() + if close_evt: + yield close_evt + + current_block_index += 1 + current_block_type = "thinking" + current_thinking_signature = thoughtsignature + + block: Dict[str, Any] = {"type": "thinking", "thinking": ""} + if thoughtsignature: + block["thoughtSignature"] = thoughtsignature + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": block, + }, + ) + elif thoughtsignature and thoughtsignature != current_thinking_signature: + # 签名变化,需要开启新的 thinking 块 + close_evt = _close_block() + if close_evt: + yield close_evt + + current_block_index += 1 + current_block_type = "thinking" + current_thinking_signature = thoughtsignature + + block_new: Dict[str, Any] = {"type": "thinking", "thinking": ""} + if thoughtsignature: + block_new["thoughtSignature"] = thoughtsignature + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": block_new, + }, + ) + + # 发送 thinking 文本增量 + if thinking_text: + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "thinking_delta", "thinking": thinking_text}, + }, + ) + continue + + # 处理文本块 + if "text" in part: + text = part.get("text", "") + if isinstance(text, str) and not text.strip(): + continue + + if current_block_type != "text": + close_evt = _close_block() + if close_evt: + yield close_evt + + current_block_index += 1 + current_block_type = "text" + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + + if text: + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "text_delta", "text": text}, + }, + ) + continue + + # 处理工具调用 + if "functionCall" in part: + close_evt = _close_block() + if close_evt: + yield close_evt + + has_tool_use = True + fc = part.get("functionCall", {}) or {} + original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" + thoughtsignature = part.get("thoughtSignature") + tool_id = encode_tool_id_with_signature(original_id, thoughtsignature) + tool_name = fc.get("name") or "" + tool_args = _remove_nulls_for_tool_input(fc.get("args", {}) or {}) + + if _anthropic_debug_enabled(): + log.info( + f"[ANTHROPIC][tool_use] 处理工具调用: name={tool_name}, " + f"id={tool_id}, has_signature={thoughtsignature is not None}" + ) + + current_block_index += 1 + # 注意:工具调用不设置 current_block_type,因为它是独立完整的块 + + yield _sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": current_block_index, + "content_block": { + "type": "tool_use", + "id": tool_id, + "name": tool_name, + "input": {}, + }, + }, + ) + + input_json = json.dumps(tool_args, ensure_ascii=False, separators=(",", ":")) + yield _sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": current_block_index, + "delta": {"type": "input_json_delta", "partial_json": input_json}, + }, + ) + + yield _sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": current_block_index}, + ) + # 工具调用块已完全关闭,current_block_type 保持为 None + + if _anthropic_debug_enabled(): + log.info(f"[ANTHROPIC][tool_use] 工具调用块已关闭: index={current_block_index}") + + continue + + # 检查是否结束 + if candidate.get("finishReason"): + finish_reason = candidate.get("finishReason") + break + + # 关闭最后的内容块 + close_evt = _close_block() + if close_evt: + yield close_evt + + # 确定停止原因 + # 只有在正常停止(STOP)且有工具调用时才设为 tool_use + # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_use 导致循环 + if has_tool_use and finish_reason == "STOP": + stop_reason = "tool_use" + elif finish_reason == "MAX_TOKENS": + stop_reason = "max_tokens" + else: + # 其他情况(SAFETY、RECITATION 等)默认为 end_turn + stop_reason = "end_turn" + + if _anthropic_debug_enabled(): + log.info( + f"[ANTHROPIC][stream_end] 流式结束: stop_reason={stop_reason}, " + f"has_tool_use={has_tool_use}, finish_reason={finish_reason}, " + f"input_tokens={input_tokens}, output_tokens={output_tokens}" + ) + + # 发送 message_delta 和 message_stop + yield _sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": None}, + "usage": { + "output_tokens": output_tokens, + }, + }, + ) + + yield _sse_event("message_stop", {"type": "message_stop"}) + + except Exception as e: + log.error(f"[ANTHROPIC] 流式转换失败: {e}") + # 发送错误事件 + if not message_start_sent: + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, + }, + }, + ) + yield _sse_event( + "error", + {"type": "error", "error": {"type": "api_error", "message": str(e)}}, + ) \ No newline at end of file diff --git a/src/converter/anti_truncation.py b/src/converter/anti_truncation.py new file mode 100644 index 0000000000000000000000000000000000000000..578d1cdf101583861cb26147109ddd802d3fe9c9 --- /dev/null +++ b/src/converter/anti_truncation.py @@ -0,0 +1,731 @@ +""" +Anti-Truncation Module - Ensures complete streaming output +保持一个流式请求内完整输出的反截断模块 +""" + +import io +import json +import re +from typing import Any, AsyncGenerator, Dict, List, Tuple + +from fastapi.responses import StreamingResponse + +from log import log + +# 反截断配置 +DONE_MARKER = "[done]" +CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。 + +重要提醒: +1. 不要重复前面已经输出的内容 +2. 直接继续输出,无需任何前言或解释 +3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER} +4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 + +现在请继续输出:""" + +# 正则替换配置 +REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [ + ( + "age_pattern", # 替换规则名称 + r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?", # 正则模式 + "", # 替换文本 + ), + # 可在此处添加更多替换规则 + # ("rule_name", r"pattern", "replacement"), +] + + +def apply_regex_replacements(text: str) -> str: + """ + 对文本应用正则替换规则 + + Args: + text: 要处理的文本 + + Returns: + 处理后的文本 + """ + if not text: + return text + + processed_text = text + replacement_count = 0 + + for rule_name, pattern, replacement in REGEX_REPLACEMENTS: + try: + # 编译正则表达式,使用IGNORECASE标志 + regex = re.compile(pattern, re.IGNORECASE) + + # 执行替换 + new_text, count = regex.subn(replacement, processed_text) + + if count > 0: + log.debug(f"Regex replacement '{rule_name}': {count} matches replaced") + processed_text = new_text + replacement_count += count + + except re.error as e: + log.error(f"Invalid regex pattern in rule '{rule_name}': {e}") + continue + + if replacement_count > 0: + log.info(f"Applied {replacement_count} regex replacements to text") + + return processed_text + + +def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 对请求payload中的文本内容应用正则替换 + + Args: + payload: 请求payload + + Returns: + 应用替换后的payload + """ + if not REGEX_REPLACEMENTS: + return payload + + modified_payload = payload.copy() + request_data = modified_payload.get("request", {}) + + # 处理contents中的文本 + contents = request_data.get("contents", []) + if contents: + new_contents = [] + for content in contents: + if isinstance(content, dict): + new_content = content.copy() + parts = new_content.get("parts", []) + if parts: + new_parts = [] + for part in parts: + if isinstance(part, dict) and "text" in part: + new_part = part.copy() + new_part["text"] = apply_regex_replacements(part["text"]) + new_parts.append(new_part) + else: + new_parts.append(part) + new_content["parts"] = new_parts + new_contents.append(new_content) + else: + new_contents.append(content) + + request_data["contents"] = new_contents + modified_payload["request"] = request_data + log.debug("Applied regex replacements to request contents") + + return modified_payload + + +def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + 对请求payload应用反截断处理和正则替换 + 在systemInstruction中添加提醒,要求模型在结束时输出DONE_MARKER标记 + + Args: + payload: 原始请求payload + + Returns: + 添加了反截断指令并应用了正则替换的payload + """ + # 首先应用正则替换 + modified_payload = apply_regex_replacements_to_payload(payload) + request_data = modified_payload.get("request", {}) + + # 获取或创建systemInstruction + system_instruction = request_data.get("systemInstruction", {}) + if not system_instruction: + system_instruction = {"parts": []} + elif "parts" not in system_instruction: + system_instruction["parts"] = [] + + # 添加反截断指令 + anti_truncation_instruction = { + "text": f"""严格执行以下输出结束规则: + +1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER} +2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记 +3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的 +4. 如果你的回答被截断,系统会要求你继续输出剩余内容 +5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束 + +示例格式: +``` +你的回答内容... +更多回答内容... +{DONE_MARKER} +``` + +注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。 + +这个规则对于确保输出完整性极其重要,请严格遵守。""" + } + + # 检查是否已经包含反截断指令 + has_done_instruction = any( + part.get("text", "").find(DONE_MARKER) != -1 + for part in system_instruction["parts"] + if isinstance(part, dict) + ) + + if not has_done_instruction: + system_instruction["parts"].append(anti_truncation_instruction) + request_data["systemInstruction"] = system_instruction + modified_payload["request"] = request_data + + log.debug("Applied anti-truncation instruction to request") + + return modified_payload + + +class AntiTruncationStreamProcessor: + """反截断流式处理器""" + + def __init__( + self, + original_request_func, + payload: Dict[str, Any], + max_attempts: int = 3, + enable_prefill_mode: bool = False, + ): + self.original_request_func = original_request_func + self.base_payload = payload.copy() + self.max_attempts = max_attempts + self.enable_prefill_mode = enable_prefill_mode + # 使用 StringIO 避免字符串拼接的内存问题 + self.collected_content = io.StringIO() + self.current_attempt = 0 + + def _get_collected_text(self) -> str: + """获取收集的文本内容""" + return self.collected_content.getvalue() + + def _append_content(self, content: str): + """追加内容到收集器""" + if content: + self.collected_content.write(content) + + def _clear_content(self): + """清空收集的内容,释放内存""" + self.collected_content.close() + self.collected_content = io.StringIO() + + async def process_stream(self) -> AsyncGenerator[bytes, None]: + """处理流式响应,检测并处理截断""" + + while self.current_attempt < self.max_attempts: + self.current_attempt += 1 + + # 构建当前请求payload + current_payload = self._build_current_payload() + + log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}") + + # 发送请求 + try: + response = await self.original_request_func(current_payload) + + if not isinstance(response, StreamingResponse): + # 非流式响应,直接处理 + yield await self._handle_non_streaming_response(response) + return + + # 处理流式响应(按行处理) + chunk_buffer = io.StringIO() # 使用 StringIO 缓存当前轮次的chunk + found_done_marker = False + + async for line in response.body_iterator: + if not line: + yield line + continue + + # 处理上游生成器 yield 出 Response 对象的情况(错误响应) + from fastapi import Response as FastAPIResponse + if isinstance(line, FastAPIResponse): + log.error(f"Anti-truncation: Received Response object from stream (status={line.status_code}), treating as error") + error_chunk = { + "error": { + "message": line.body.decode('utf-8', errors='ignore') if hasattr(line, 'body') and line.body else "Upstream error", + "type": "api_error", + "code": line.status_code, + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + yield b"data: [DONE]\n\n" + return + + # 处理 bytes 类型的流式数据 + if isinstance(line, bytes): + # 解码 bytes 为字符串 + line_str = line.decode('utf-8', errors='ignore').strip() + else: + line_str = str(line).strip() + + # 跳过空行 + if not line_str: + yield line + continue + + # 处理 SSE 格式的数据行 + if line_str.startswith("data: "): + payload_str = line_str[6:] # 去掉 "data: " 前缀 + + # 检查是否是 [DONE] 标记 + if payload_str.strip() == "[DONE]": + if found_done_marker: + log.info("Anti-truncation: Found [done] marker, output complete") + yield line + # 清理内存 + chunk_buffer.close() + self._clear_content() + return + else: + log.warning("Anti-truncation: Stream ended without [done] marker") + # 不发送[DONE],准备继续 + break + + # 尝试解析 JSON 数据 + try: + data = json.loads(payload_str) + content = self._extract_content_from_chunk(data) + + log.debug(f"Anti-truncation: Extracted content: {repr(content[:100] if content else '')}") + + if content: + chunk_buffer.write(content) + + # 检查是否包含done标记 + has_marker = self._check_done_marker_in_chunk_content(content) + log.debug(f"Anti-truncation: Check done marker result: {has_marker}, DONE_MARKER='{DONE_MARKER}'") + if has_marker: + found_done_marker = True + log.debug(f"Anti-truncation: Found [done] marker in chunk, content: {content[:200]}") + + # 清理行中的[done]标记后再发送 + cleaned_line = self._remove_done_marker_from_line(line, line_str, data) + yield cleaned_line + + except (json.JSONDecodeError, ValueError): + # 无法解析的行,直接传递 + yield line + continue + else: + # 非 data: 开头的行,直接传递 + yield line + + # 更新收集的内容 - 使用 StringIO 高效处理 + chunk_text = chunk_buffer.getvalue() + if chunk_text: + self._append_content(chunk_text) + chunk_buffer.close() + + log.debug(f"Anti-truncation: After processing stream, found_done_marker={found_done_marker}") + + # 如果找到了done标记,结束 + if found_done_marker: + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + # 只有在单个chunk中没有找到done标记时,才检查累积内容(防止done标记跨chunk出现) + if not found_done_marker: + accumulated_text = self._get_collected_text() + if self._check_done_marker_in_text(accumulated_text): + log.info("Anti-truncation: Found [done] marker in accumulated content") + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + # 如果没找到done标记且不是最后一次尝试,准备续传 + if self.current_attempt < self.max_attempts: + accumulated_text = self._get_collected_text() + total_length = len(accumulated_text) + log.info( + f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})" + ) + if total_length > 100: + log.debug( + f"Anti-truncation: Current collected content ends with: ...{accumulated_text[-100:]}" + ) + # 在下一次循环中会继续 + continue + else: + # 最后一次尝试,直接结束 + log.warning("Anti-truncation: Max attempts reached, ending stream") + # 立即清理内容释放内存 + self._clear_content() + yield b"data: [DONE]\n\n" + return + + except Exception as e: + log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}") + if self.current_attempt >= self.max_attempts: + # 发送错误chunk + error_chunk = { + "error": { + "message": f"Anti-truncation failed: {str(e)}", + "type": "api_error", + "code": 500, + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + yield b"data: [DONE]\n\n" + return + # 否则继续下一次尝试 + + # 如果所有尝试都失败了 + log.error("Anti-truncation: All attempts failed") + # 清理内存 + self._clear_content() + yield b"data: [DONE]\n\n" + + def _build_current_payload(self) -> Dict[str, Any]: + """构建当前请求的payload""" + if self.current_attempt == 1: + # 第一次请求,使用原始payload(已经包含反截断指令) + return self.base_payload + + # 后续请求,添加续传指令 + continuation_payload = self.base_payload.copy() + request_data = continuation_payload.get("request", {}) + + # 获取原始对话内容 + contents = request_data.get("contents", []) + new_contents = contents.copy() + + # 如果有收集到的内容,添加到对话中 + accumulated_text = self._get_collected_text() + if accumulated_text: + new_contents.append({"role": "model", "parts": [{"text": accumulated_text}]}) + + # 预填充模式:直接用拼接内容作为末尾 model 预填充,不再增加 user 续写指令 + if self.enable_prefill_mode: + log.debug("Anti-truncation: Using prefill continuation mode (no user continuation prompt)") + request_data["contents"] = new_contents + continuation_payload["request"] = request_data + return continuation_payload + + # 构建具体的续写指令,包含前面的内容摘要 + content_summary = "" + if accumulated_text: + if len(accumulated_text) > 200: + content_summary = f'\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n"...{accumulated_text[-100:]}"' + else: + content_summary = f'\n\n前面你已经输出的内容是:\n"{accumulated_text}"' + + detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}""" + + # 添加继续指令 + continuation_message = {"role": "user", "parts": [{"text": detailed_continuation_prompt}]} + new_contents.append(continuation_message) + + request_data["contents"] = new_contents + continuation_payload["request"] = request_data + + return continuation_payload + + def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str: + """从chunk数据中提取文本内容""" + content = "" + + # 先尝试解包 response 字段(Gemini API 格式) + if "response" in data: + data = data["response"] + + # 处理 Gemini 格式 + if "candidates" in data: + for candidate in data["candidates"]: + if "content" in candidate: + parts = candidate["content"].get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # 处理 OpenAI 流式格式(choices/delta) + elif "choices" in data: + for choice in data["choices"]: + if "delta" in choice and "content" in choice["delta"]: + delta_content = choice["delta"]["content"] + if delta_content: + content += delta_content + + return content + + async def _handle_non_streaming_response(self, response) -> bytes: + """处理非流式响应 - 使用循环代替递归避免栈溢出""" + # 使用循环代替递归 + while True: + try: + # 特殊处理:如果返回的是StreamingResponse,需要读取其body_iterator + if isinstance(response, StreamingResponse): + log.error("Anti-truncation: Received StreamingResponse in non-streaming handler - this should not happen") + # 尝试读取流式响应的内容 + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + content = b"".join(chunks).decode() if chunks else "" + # 提取响应内容 + elif hasattr(response, "body"): + content = ( + response.body.decode() if isinstance(response.body, bytes) else response.body + ) + elif hasattr(response, "content"): + content = ( + response.content.decode() + if isinstance(response.content, bytes) + else response.content + ) + else: + log.error(f"Anti-truncation: Unknown response type: {type(response)}") + content = str(response) + + # 验证内容不为空 + if not content or not content.strip(): + log.error("Anti-truncation: Received empty response content") + return json.dumps( + { + "error": { + "message": "Empty response from server", + "type": "api_error", + "code": 500, + } + } + ).encode() + + # 尝试解析 JSON + try: + response_data = json.loads(content) + except json.JSONDecodeError as json_err: + log.error(f"Anti-truncation: Failed to parse JSON response: {json_err}, content: {content[:200]}") + # 如果不是 JSON,直接返回原始内容 + return content.encode() if isinstance(content, str) else content + + # 检查是否包含done标记 + text_content = self._extract_content_from_response(response_data) + has_done_marker = self._check_done_marker_in_text(text_content) + + if has_done_marker or self.current_attempt >= self.max_attempts: + # 找到done标记或达到最大尝试次数,返回结果 + return content.encode() if isinstance(content, str) else content + + # 需要继续,收集内容并构建下一个请求 + if text_content: + self._append_content(text_content) + + log.info("Anti-truncation: Non-streaming response needs continuation") + + # 增加尝试次数 + self.current_attempt += 1 + + # 构建续传payload并发送下一个请求 + next_payload = self._build_current_payload() + response = await self.original_request_func(next_payload) + + # 继续循环处理下一个响应 + + except Exception as e: + log.error(f"Anti-truncation non-streaming error: {str(e)}") + return json.dumps( + { + "error": { + "message": f"Anti-truncation failed: {str(e)}", + "type": "api_error", + "code": 500, + } + } + ).encode() + + def _check_done_marker_in_text(self, text: str) -> bool: + """检测文本中是否包含DONE_MARKER(只检测指定标记)""" + if not text: + return False + + # 只要文本中出现DONE_MARKER即可 + return DONE_MARKER in text + + def _check_done_marker_in_chunk_content(self, content: str) -> bool: + """检查单个chunk内容中是否包含done标记""" + return self._check_done_marker_in_text(content) + + def _extract_content_from_response(self, data: Dict[str, Any]) -> str: + """从响应数据中提取文本内容""" + content = "" + + # 先尝试解包 response 字段(Gemini API 格式) + if "response" in data: + data = data["response"] + + # 处理Gemini格式 + if "candidates" in data: + for candidate in data["candidates"]: + if "content" in candidate: + parts = candidate["content"].get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # 处理OpenAI格式 + elif "choices" in data: + for choice in data["choices"]: + if "message" in choice and "content" in choice["message"]: + content += choice["message"]["content"] + + return content + + def _remove_done_marker_from_line(self, line: bytes, line_str: str, data: Dict[str, Any]) -> bytes: + """从行中移除[done]标记""" + try: + # 首先检查是否真的包含[done]标记 + if "[done]" not in line_str.lower(): + return line # 没有[done]标记,直接返回原始行 + + log.info(f"Anti-truncation: Attempting to remove [done] marker from line") + log.debug(f"Anti-truncation: Original line (first 200 chars): {line_str[:200]}") + + # 编译正则表达式,匹配[done]标记(忽略大小写,包括可能的空白字符) + done_pattern = re.compile(r"\s*\[done\]\s*", re.IGNORECASE) + + # 检查是否有 response 包裹层 + has_response_wrapper = "response" in data + log.debug(f"Anti-truncation: has_response_wrapper={has_response_wrapper}, data keys={list(data.keys())}") + if has_response_wrapper: + # 需要保留外层的 response 字段 + inner_data = data["response"] + else: + inner_data = data + + log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") + + log.debug(f"Anti-truncation: inner_data keys={list(inner_data.keys())}") + + # 处理Gemini格式 + if "candidates" in inner_data: + log.info(f"Anti-truncation: Processing Gemini format to remove [done] marker") + modified_inner = inner_data.copy() + modified_inner["candidates"] = [] + + for i, candidate in enumerate(inner_data["candidates"]): + modified_candidate = candidate.copy() + # 只在最后一个candidate中清理[done]标记 + is_last_candidate = i == len(inner_data["candidates"]) - 1 + + if "content" in candidate: + modified_content = candidate["content"].copy() + if "parts" in modified_content: + modified_parts = [] + for part in modified_content["parts"]: + if "text" in part and isinstance(part["text"], str): + modified_part = part.copy() + original_text = part["text"] + # 只在最后一个candidate中清理[done]标记 + if is_last_candidate: + modified_part["text"] = done_pattern.sub("", part["text"]) + if "[done]" in original_text.lower(): + log.debug(f"Anti-truncation: Removed [done] from text: '{original_text[:100]}' -> '{modified_part['text'][:100]}'") + modified_parts.append(modified_part) + else: + modified_parts.append(part) + modified_content["parts"] = modified_parts + modified_candidate["content"] = modified_content + modified_inner["candidates"].append(modified_candidate) + + # 如果有 response 包裹层,需要重新包装 + if has_response_wrapper: + modified_data = data.copy() + modified_data["response"] = modified_inner + else: + modified_data = modified_inner + + # 重新编码为行格式 - SSE格式需要两个换行符 + json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) + result = f"data: {json_str}\n\n".encode("utf-8") + log.debug(f"Anti-truncation: Modified line (first 200 chars): {result.decode('utf-8', errors='ignore')[:200]}") + return result + + # 处理OpenAI格式 + elif "choices" in inner_data: + modified_inner = inner_data.copy() + modified_inner["choices"] = [] + + for choice in inner_data["choices"]: + modified_choice = choice.copy() + if "delta" in choice and "content" in choice["delta"]: + modified_delta = choice["delta"].copy() + modified_delta["content"] = done_pattern.sub("", choice["delta"]["content"]) + modified_choice["delta"] = modified_delta + elif "message" in choice and "content" in choice["message"]: + modified_message = choice["message"].copy() + modified_message["content"] = done_pattern.sub("", choice["message"]["content"]) + modified_choice["message"] = modified_message + modified_inner["choices"].append(modified_choice) + + # 如果有 response 包裹层,需要重新包装 + if has_response_wrapper: + modified_data = data.copy() + modified_data["response"] = modified_inner + else: + modified_data = modified_inner + + # 重新编码为行格式 - SSE格式需要两个换行符 + json_str = json.dumps(modified_data, separators=(",", ":"), ensure_ascii=False) + return f"data: {json_str}\n\n".encode("utf-8") + + # 如果没有找到支持的格式,返回原始行 + return line + + except Exception as e: + log.warning(f"Failed to remove [done] marker from line: {str(e)}") + return line + + +async def apply_anti_truncation_to_stream( + request_func, + payload: Dict[str, Any], + max_attempts: int = 3, + enable_prefill_mode: bool = False, +) -> StreamingResponse: + """ + 对流式请求应用反截断处理 + + Args: + request_func: 原始请求函数 + payload: 请求payload + max_attempts: 最大续传尝试次数 + enable_prefill_mode: 是否启用预填充模式。启用后续传请求不再添加 user 续写指令, + 而是将已收集内容作为末尾 model 内容进行预填充 + + Returns: + 处理后的StreamingResponse + """ + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(payload) + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + lambda p: request_func(p), + anti_truncation_payload, + max_attempts, + enable_prefill_mode, + ) + + # 返回包装后的流式响应 + return StreamingResponse(processor.process_stream(), media_type="text/event-stream") + + +def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool: + """ + 检查请求是否启用了反截断功能 + + Args: + request_data: 请求数据 + + Returns: + 是否启用反截断 + """ + return request_data.get("enable_anti_truncation", False) \ No newline at end of file diff --git a/src/converter/fake_stream.py b/src/converter/fake_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb0ba1062df7874a7ba3976730e1a51cca47a54 --- /dev/null +++ b/src/converter/fake_stream.py @@ -0,0 +1,537 @@ +from typing import Any, Dict, List, Tuple +import json +from src.converter.utils import extract_content_and_reasoning +from log import log +from src.converter.openai2gemini import _convert_usage_metadata + +def safe_get_nested(obj: Any, *keys: str, default: Any = None) -> Any: + """安全获取嵌套字典值 + + Args: + obj: 字典对象 + *keys: 嵌套键路径 + default: 默认值 + + Returns: + 获取到的值或默认值 + """ + for key in keys: + if not isinstance(obj, dict): + return default + obj = obj.get(key, default) + if obj is default: + return default + return obj + +def parse_response_for_fake_stream(response_data: Dict[str, Any]) -> tuple: + """从完整响应中提取内容和推理内容(用于假流式) + + Args: + response_data: Gemini API 响应数据 + + Returns: + (content, reasoning_content, finish_reason, images): 内容、推理内容、结束原因和图片数据的元组 + """ + import json + + # 处理GeminiCLI的response包装格式 + if "response" in response_data and "candidates" not in response_data: + log.debug(f"[FAKE_STREAM] Unwrapping response field") + response_data = response_data["response"] + + candidates = response_data.get("candidates", []) + log.debug(f"[FAKE_STREAM] Found {len(candidates)} candidates") + if not candidates: + return "", "", "STOP", [] + + candidate = candidates[0] + finish_reason = candidate.get("finishReason", "STOP") + parts = safe_get_nested(candidate, "content", "parts", default=[]) + log.debug(f"[FAKE_STREAM] Extracted {len(parts)} parts: {json.dumps(parts, ensure_ascii=False)}") + content, reasoning_content, images = extract_content_and_reasoning(parts) + log.debug(f"[FAKE_STREAM] Content length: {len(content)}, Reasoning length: {len(reasoning_content)}, Images count: {len(images)}") + + return content, reasoning_content, finish_reason, images + +def extract_fake_stream_content(response: Any) -> Tuple[str, str, Dict[str, int]]: + """ + 从 Gemini 非流式响应中提取内容,用于假流式处理 + + Args: + response: Gemini API 响应对象 + + Returns: + (content, reasoning_content, usage) 元组 + """ + from src.converter.utils import extract_content_and_reasoning + + # 解析响应体 + if hasattr(response, "body"): + body_str = ( + response.body.decode() + if isinstance(response.body, bytes) + else str(response.body) + ) + elif hasattr(response, "content"): + body_str = ( + response.content.decode() + if isinstance(response.content, bytes) + else str(response.content) + ) + else: + body_str = str(response) + + try: + response_data = json.loads(body_str) + + # GeminiCLI 返回的格式是 {"response": {...}, "traceId": "..."} + # 需要先提取 response 字段 + if "response" in response_data: + gemini_response = response_data["response"] + else: + gemini_response = response_data + + # 从Gemini响应中提取内容,使用思维链分离逻辑 + content = "" + reasoning_content = "" + images = [] + if "candidates" in gemini_response and gemini_response["candidates"]: + # Gemini格式响应 - 使用思维链分离 + candidate = gemini_response["candidates"][0] + if "content" in candidate and "parts" in candidate["content"]: + parts = candidate["content"]["parts"] + content, reasoning_content, images = extract_content_and_reasoning(parts) + elif "choices" in gemini_response and gemini_response["choices"]: + # OpenAI格式响应 + content = gemini_response["choices"][0].get("message", {}).get("content", "") + + # 如果没有正常内容但有思维内容,给出警告 + if not content and reasoning_content: + log.warning("Fake stream response contains only thinking content") + content = "[模型正在思考中,请稍后再试或重新提问]" + + # 如果完全没有内容,提供默认回复 + if not content: + log.warning(f"No content found in response: {gemini_response}") + content = "[响应为空,请重新尝试]" + + # 转换usageMetadata为OpenAI格式 + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + return content, reasoning_content, usage + + except json.JSONDecodeError: + # 如果不是JSON,直接返回原始文本 + return body_str, "", None + +def _build_candidate(parts: List[Dict[str, Any]], finish_reason: str = "STOP") -> Dict[str, Any]: + """构建标准候选响应结构 + + Args: + parts: parts 列表 + finish_reason: 结束原因 + + Returns: + 候选响应字典 + """ + return { + "candidates": [{ + "content": {"parts": parts, "role": "model"}, + "finishReason": finish_reason, + "index": 0, + }] + } + +def create_openai_heartbeat_chunk() -> Dict[str, Any]: + """ + 创建 OpenAI 格式的心跳块(用于假流式) + + Returns: + 心跳响应块字典 + """ + return { + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ] + } + +def build_gemini_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建假流式响应的数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容 + finish_reason: 结束原因 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + 响应数据块列表 + """ + if images is None: + images = [] + + log.debug(f"[build_gemini_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + return [_build_candidate([{"text": default_text}], finish_reason)] + + # 分块发送主要内容 + first_chunk = True + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + is_last_chunk = (i + chunk_size >= len(content)) and not reasoning_content + chunk_finish_reason = finish_reason if is_last_chunk else None + + # 如果是第一个chunk且有图片,将图片包含在parts中 + parts = [] + if first_chunk and images: + # 在Gemini格式中,需要将image_url格式转换为inlineData格式 + for img in images: + if img.get("type") == "image_url": + url = img.get("image_url", {}).get("url", "") + # 解析 data URL: data:{mime_type};base64,{data} + if url.startswith("data:"): + parts_of_url = url.split(";base64,") + if len(parts_of_url) == 2: + mime_type = parts_of_url[0].replace("data:", "") + base64_data = parts_of_url[1] + parts.append({ + "inlineData": { + "mimeType": mime_type, + "data": base64_data + } + }) + first_chunk = False + + parts.append({"text": chunk_text}) + chunk_data = _build_candidate(parts, chunk_finish_reason) + log.debug(f"[build_gemini_fake_stream_chunks] Generated chunk: {chunk_data}") + chunks.append(chunk_data) + + # 如果有推理内容,分块发送 + if reasoning_content: + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + is_last_chunk = i + chunk_size >= len(reasoning_content) + chunk_finish_reason = finish_reason if is_last_chunk else None + chunks.append(_build_candidate([{"text": chunk_text, "thought": True}], chunk_finish_reason)) + + log.debug(f"[build_gemini_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks + + +def create_gemini_heartbeat_chunk() -> Dict[str, Any]: + """创建 Gemini 格式的心跳数据块 + + Returns: + 心跳数据块 + """ + chunk = _build_candidate([{"text": ""}]) + chunk["candidates"][0]["finishReason"] = None + return chunk + + +def build_openai_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, model: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建 OpenAI 格式的假流式响应数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容 + finish_reason: 结束原因(如 "STOP", "MAX_TOKENS") + model: 模型名称 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + OpenAI 格式的响应数据块列表 + """ + import time + import uuid + + if images is None: + images = [] + + log.debug(f"[build_openai_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + # 映射 Gemini finish_reason 到 OpenAI 格式 + openai_finish_reason = None + if finish_reason == "STOP": + openai_finish_reason = "stop" + elif finish_reason == "MAX_TOKENS": + openai_finish_reason = "length" + elif finish_reason in ["SAFETY", "RECITATION"]: + openai_finish_reason = "content_filter" + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + return [{ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": default_text}, + "finish_reason": openai_finish_reason, + }] + }] + + # 分块发送主要内容 + first_chunk = True + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + is_last_chunk = (i + chunk_size >= len(content)) and not reasoning_content + chunk_finish = openai_finish_reason if is_last_chunk else None + + delta_content = {} + + # 如果是第一个chunk且有图片,构建包含图片的content数组 + if first_chunk and images: + delta_content["content"] = images + [{"type": "text", "text": chunk_text}] + first_chunk = False + else: + delta_content["content"] = chunk_text + + chunk_data = { + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": delta_content, + "finish_reason": chunk_finish, + }] + } + log.debug(f"[build_openai_fake_stream_chunks] Generated chunk: {chunk_data}") + chunks.append(chunk_data) + + # 如果有推理内容,分块发送(使用 reasoning_content 字段) + if reasoning_content: + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + is_last_chunk = i + chunk_size >= len(reasoning_content) + chunk_finish = openai_finish_reason if is_last_chunk else None + + chunks.append({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": {"reasoning_content": chunk_text}, + "finish_reason": chunk_finish, + }] + }) + + log.debug(f"[build_openai_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks + + +def create_anthropic_heartbeat_chunk() -> Dict[str, Any]: + """ + 创建 Anthropic 格式的心跳块(用于假流式) + + Returns: + 心跳响应块字典 + """ + return { + "type": "ping" + } + + +def build_anthropic_fake_stream_chunks(content: str, reasoning_content: str, finish_reason: str, model: str, images: List[Dict[str, Any]] = None, chunk_size: int = 50) -> List[Dict[str, Any]]: + """构建 Anthropic 格式的假流式响应数据块 + + Args: + content: 主要内容 + reasoning_content: 推理内容(thinking content) + finish_reason: 结束原因(如 "STOP", "MAX_TOKENS") + model: 模型名称 + images: 图片数据列表(可选) + chunk_size: 每个chunk的字符数(默认50) + + Returns: + Anthropic SSE 格式的响应数据块列表 + """ + import uuid + + if images is None: + images = [] + + log.debug(f"[build_anthropic_fake_stream_chunks] Input - content: {repr(content)}, reasoning: {repr(reasoning_content)}, finish_reason: {finish_reason}, images count: {len(images)}") + chunks = [] + message_id = f"msg_{uuid.uuid4().hex}" + + # 映射 Gemini finish_reason 到 Anthropic 格式 + anthropic_stop_reason = "end_turn" + if finish_reason == "MAX_TOKENS": + anthropic_stop_reason = "max_tokens" + elif finish_reason in ["SAFETY", "RECITATION"]: + anthropic_stop_reason = "end_turn" + + # 1. 发送 message_start 事件 + chunks.append({ + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0} + } + }) + + # 如果没有正常内容但有思维内容,提供默认回复 + if not content: + default_text = "[模型正在思考中,请稍后再试或重新提问]" if reasoning_content else "[响应为空,请重新尝试]" + + # content_block_start + chunks.append({ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""} + }) + + # content_block_delta + chunks.append({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": default_text} + }) + + # content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": 0 + }) + + # message_delta + chunks.append({ + "type": "message_delta", + "delta": {"stop_reason": anthropic_stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": 0} + }) + + # message_stop + chunks.append({ + "type": "message_stop" + }) + + return chunks + + block_index = 0 + + # 2. 如果有推理内容,先发送 thinking 块 + if reasoning_content: + # thinking content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": {"type": "thinking", "thinking": ""} + }) + + # 分块发送推理内容 + for i in range(0, len(reasoning_content), chunk_size): + chunk_text = reasoning_content[i:i + chunk_size] + chunks.append({ + "type": "content_block_delta", + "index": block_index, + "delta": {"type": "thinking_delta", "thinking": chunk_text} + }) + + # thinking content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + block_index += 1 + + # 3. 如果有图片,发送图片块 + if images: + for img in images: + if img.get("type") == "image_url": + url = img.get("image_url", {}).get("url", "") + # 解析 data URL: data:{mime_type};base64,{data} + if url.startswith("data:"): + parts_of_url = url.split(";base64,") + if len(parts_of_url) == 2: + mime_type = parts_of_url[0].replace("data:", "") + base64_data = parts_of_url[1] + + # image content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": base64_data + } + } + }) + + # image content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + block_index += 1 + + # 4. 发送主要内容(text 块) + # text content_block_start + chunks.append({ + "type": "content_block_start", + "index": block_index, + "content_block": {"type": "text", "text": ""} + }) + + # 分块发送主要内容 + for i in range(0, len(content), chunk_size): + chunk_text = content[i:i + chunk_size] + chunks.append({ + "type": "content_block_delta", + "index": block_index, + "delta": {"type": "text_delta", "text": chunk_text} + }) + + # text content_block_stop + chunks.append({ + "type": "content_block_stop", + "index": block_index + }) + + # 5. 发送 message_delta + chunks.append({ + "type": "message_delta", + "delta": {"stop_reason": anthropic_stop_reason, "stop_sequence": None}, + "usage": {"output_tokens": len(content) + len(reasoning_content)} + }) + + # 6. 发送 message_stop + chunks.append({ + "type": "message_stop" + }) + + log.debug(f"[build_anthropic_fake_stream_chunks] Total chunks generated: {len(chunks)}") + return chunks \ No newline at end of file diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..6feaf44e5e30c38bf187234689c1e4131119dd2a --- /dev/null +++ b/src/converter/gemini_fix.py @@ -0,0 +1,472 @@ +""" +Gemini Format Utilities - 统一的 Gemini 格式处理和转换工具 +提供对 Gemini API 请求体和响应的标准化处理 +──────────────────────────────────────────────────────────────── +""" +from math import e +from typing import Any, Dict, Optional + +from log import log + +# ==================== Gemini API 配置 ==================== + +# ====================== Model Configuration ====================== + +# Default Safety Settings for Google API +DEFAULT_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_HATE", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_IMAGE_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_JAILBREAK", "threshold": "BLOCK_NONE"}, +] + +LITE_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, +] + +def prepare_image_generation_request( + request_body: Dict[str, Any], + model: str +) -> Dict[str, Any]: + """ + 图像生成模型请求体后处理 + + Args: + request_body: 原始请求体 + model: 模型名称 + + Returns: + 处理后的请求体 + """ + request_body = request_body.copy() + model_lower = model.lower() + + # 解析分辨率 + image_size = "4K" if "-4k" in model_lower else "2K" if "-2k" in model_lower else None + + # 解析比例 + aspect_ratio = None + for suffix, ratio in [ + ("-21x9", "21:9"), ("-16x9", "16:9"), ("-9x16", "9:16"), + ("-4x3", "4:3"), ("-3x4", "3:4"), ("-1x1", "1:1") + ]: + if suffix in model_lower: + aspect_ratio = ratio + break + + # 构建 imageConfig + image_config = {} + if aspect_ratio: + image_config["aspectRatio"] = aspect_ratio + if image_size: + image_config["imageSize"] = image_size + + request_body["model"] = "gemini-3.1-flash-image" # 统一使用基础模型名 + request_body["generationConfig"] = { + "candidateCount": 1, + "imageConfig": image_config + } + + # 移除不需要的字段 + for key in ("systemInstruction", "tools", "toolConfig"): + request_body.pop(key, None) + + return request_body + + +# ==================== 模型特性辅助函数 ==================== + +def get_base_model_name(model_name: str) -> str: + """移除模型名称中的后缀,返回基础模型名""" + # 按照从长到短的顺序排列,避免短后缀先于长后缀被匹配 + suffixes = [ + "-maxthinking", "-nothinking", # 兼容旧模式 + "-minimal", "-medium", "-search", "-think", # 中等长度后缀 + "-high", "-max", "-low" # 短后缀 + ] + result = model_name + changed = True + # 持续循环直到没有任何后缀可以移除 + while changed: + changed = False + for suffix in suffixes: + if result.endswith(suffix): + result = result[:-len(suffix)] + changed = True + # 不使用 break,继续检查是否还有其他后缀 + return result + + +def get_thinking_settings(model_name: str) -> tuple[Optional[int], Optional[str]]: + """ + 根据模型名称获取思考配置 + + 支持两种模式: + 1. CLI 模式思考预算 (Gemini 2.5 系列): -max, -high, -medium, -low, -minimal + 2. CLI 模式思考等级 (Gemini 3 Preview 系列): -high, -medium, -low, -minimal (仅 3-flash) + 3. 兼容旧模式: -maxthinking, -nothinking (不返回给用户) + + Returns: + (thinking_budget, thinking_level): 思考预算和思考等级 + """ + base_model = get_base_model_name(model_name) + + # ========== 兼容旧模式 (不返回给用户) ========== + if "-nothinking" in model_name: + # nothinking 模式: 限制思考 + if "flash" in base_model: + return 0, None + return 128, None + elif "-maxthinking" in model_name: + # maxthinking 模式: 最大思考预算 + budget = 24576 if "flash" in base_model else 32768 + if "gemini-3" in base_model: + # Gemini 3 系列不支持 thinkingBudget,返回 high 等级 + return None, "high" + else: + return budget, None + + # ========== 新 CLI 模式: 基于思考预算/等级 ========== + + # Gemini 3 Preview 系列: 使用 thinkingLevel + if "gemini-3" in base_model: + if "-high" in model_name: + return None, "high" + elif "-medium" in model_name: + # 仅 3-flash-preview 支持 medium + if "flash" in base_model: + return None, "medium" + # pro 系列不支持 medium,返回 Default + return None, None + elif "-low" in model_name: + return None, "low" + elif "-minimal" in model_name: + return None, None + else: + # Default: 不设置 thinking 配置 + return None, None + + # Gemini 2.5 系列: 使用 thinkingBudget + elif "gemini-2.5" in base_model: + if "-max" in model_name: + # 2.5-flash-max: 24576, 2.5-pro-max: 32768 + budget = 24576 if "flash" in base_model else 32768 + return budget, None + elif "-high" in model_name: + # 2.5-flash-high: 16000, 2.5-pro-high: 16000 + return 16000, None + elif "-medium" in model_name: + # 2.5-flash-medium: 8192, 2.5-pro-medium: 8192 + return 8192, None + elif "-low" in model_name: + # 2.5-flash-low: 1024, 2.5-pro-low: 1024 + return 1024, None + elif "-minimal" in model_name: + # 2.5-flash-minimal: 0, 2.5-pro-minimal: 128 + budget = 0 if "flash" in base_model else 128 + return budget, None + else: + # Default: 不设置 thinking budget + return None, None + + # 其他模型: 不设置 thinking 配置 + return None, None + + +def is_search_model(model_name: str) -> bool: + """检查是否为搜索模型""" + return "-search" in model_name + + +# ==================== 统一的 Gemini 请求后处理 ==================== + +def is_thinking_model(model_name: str) -> bool: + """检查是否为思考模型 (包含 -thinking 或 pro)""" + return "think" in model_name or "pro" in model_name.lower() + + +async def normalize_gemini_request( + request: Dict[str, Any], + mode: str = "geminicli" +) -> Dict[str, Any]: + """ + 规范化 Gemini 请求 + + 处理逻辑: + 1. 模型特性处理 (thinking config, search tools) + 3. 参数范围限制 (maxOutputTokens, topK) + 4. 工具清理 + + Args: + request: 原始请求字典 + mode: 模式 ("geminicli" 或 "antigravity") + + Returns: + 规范化后的请求 + """ + # 导入配置函数 + from config import get_return_thoughts_to_frontend + + result = request.copy() + model = result.get("model", "") + generation_config = (result.get("generationConfig") or {}).copy() # 创建副本避免修改原对象 + tools = result.get("tools") + system_instruction = result.get("systemInstruction") or result.get("system_instructions") + + # 记录原始请求 + log.debug(f"[GEMINI_FIX] 原始请求 - 模型: {model}, mode: {mode}, generationConfig: {generation_config}") + + # 获取配置值 + return_thoughts = await get_return_thoughts_to_frontend() + + # ========== 模式特定处理 ========== + if mode == "geminicli": + # 1. 思考设置 + # 优先使用 get_thinking_settings 获取的思考预算和等级 + thinking_budget, thinking_level = get_thinking_settings(model) + + # 其次使用传入的思考预算(如果未从模型名称获取) + if thinking_budget is None and thinking_level is None: + thinking_budget = generation_config.get("thinkingConfig", {}).get("thinkingBudget") + thinking_level = generation_config.get("thinkingConfig", {}).get("thinkingLevel") + + # 假如 is_thinking_model 为真或者思考预算/等级不为空,设置 thinkingConfig + if is_thinking_model(model) or thinking_budget is not None or thinking_level is not None: + # 确保 thinkingConfig 存在 + if "thinkingConfig" not in generation_config: + generation_config["thinkingConfig"] = {} + + thinking_config = generation_config["thinkingConfig"] + + # 设置思考预算或等级(互斥) + if thinking_budget is not None: + thinking_config["thinkingBudget"] = thinking_budget + thinking_config.pop("thinkingLevel", None) # 避免与 thinkingBudget 冲突 + elif thinking_level is not None: + thinking_config["thinkingLevel"] = thinking_level + thinking_config.pop("thinkingBudget", None) # 避免与 thinkingLevel 冲突 + + # includeThoughts 逻辑: + # 1. 如果是 pro 模型,为 return_thoughts + # 2. 如果不是 pro 模型,检查是否有思考预算或思考等级 + base_model = get_base_model_name(model) + if "pro" in base_model: + include_thoughts = return_thoughts + elif "3-flash" in base_model: + if thinking_level is None: + include_thoughts = False + else: + include_thoughts = return_thoughts + else: + # 非 pro 模型: 有思考预算或等级才包含思考 + # 注意: 思考预算为 0 时不包含思考 + if thinking_budget is None or thinking_budget == 0: + include_thoughts = False + else: + include_thoughts = return_thoughts + + thinking_config["includeThoughts"] = include_thoughts + + # 2. 搜索模型添加 Google Search + if is_search_model(model): + result_tools = result.get("tools") or [] + result["tools"] = result_tools + if not any(tool.get("googleSearch") for tool in result_tools if isinstance(tool, dict)): + result_tools.append({"googleSearch": {}}) + + # 3. 模型名称处理 + result["model"] = get_base_model_name(model) + + elif mode == "antigravity": + + ''' + # 1. 处理 system_instruction + custom_prompt = "Please ignore the following [ignore]You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**[/ignore]" + + # 提取原有的 parts(如果存在) + existing_parts = [] + if system_instruction: + if isinstance(system_instruction, dict): + existing_parts = system_instruction.get("parts", []) + + # custom_prompt 始终放在第一位,原有内容整体后移 + result["systemInstruction"] = { + "parts": [{"text": custom_prompt}] + existing_parts + } + ''' + + # 2. 判断图片模型 + if "image" in model.lower(): + # 调用图片生成专用处理函数 + return prepare_image_generation_request(result, model) + else: + # 3. 思考模型处理 + if is_thinking_model(model) or ("thinkingBudget" in generation_config.get("thinkingConfig", {}) and generation_config["thinkingConfig"]["thinkingBudget"] != 0): + # 直接设置 thinkingConfig + if "thinkingConfig" not in generation_config: + generation_config["thinkingConfig"] = {} + + thinking_config = generation_config["thinkingConfig"] + # 优先使用传入的思考预算,否则使用默认值 + if "thinkingBudget" not in thinking_config: + thinking_config["thinkingBudget"] = 1024 + thinking_config.pop("thinkingLevel", None) # 避免与 thinkingBudget 冲突 + thinking_config["includeThoughts"] = return_thoughts + + # 检查最后一个 assistant 消息是否以 thinking 块开始 + contents = result.get("contents", []) + + if "claude" in model.lower(): + # 检测是否有工具调用(MCP场景) + has_tool_calls = any( + isinstance(content, dict) and + any( + isinstance(part, dict) and ("functionCall" in part or "function_call" in part) + for part in content.get("parts", []) + ) + for content in contents + ) + + if has_tool_calls: + # MCP 场景:检测到工具调用,移除 thinkingConfig + log.warning(f"[ANTIGRAVITY] 检测到工具调用(MCP场景),移除 thinkingConfig 避免失效") + generation_config.pop("thinkingConfig", None) + else: + # 非 MCP 场景:填充思考块 + # log.warning(f"[ANTIGRAVITY] 最后一个 assistant 消息不以 thinking 块开始,自动填充思考块") + + # 找到最后一个 model 角色的 content + for i in range(len(contents) - 1, -1, -1): + content = contents[i] + if isinstance(content, dict) and content.get("role") == "model": + # 在 parts 开头插入思考块(使用官方跳过验证的虚拟签名) + parts = content.get("parts", []) + thinking_part = { + "text": "...", + # "thought": True, # 标记为思考块 + "thoughtSignature": "skip_thought_signature_validator" # 官方文档推荐的虚拟签名 + } + # 如果第一个 part 不是 thinking,则插入 + if not parts or not (isinstance(parts[0], dict) and ("thought" in parts[0] or "thoughtSignature" in parts[0])): + content["parts"] = [thinking_part] + parts + log.debug(f"[ANTIGRAVITY] 已在最后一个 assistant 消息开头插入思考块(含跳过验证签名)") + break + + # 移除 -thinking 后缀 + model = model.replace("-thinking", "") + + # 4. Claude 模型关键词映射 + # 使用关键词匹配而不是精确匹配,更灵活地处理各种变体 + original_model = model + if "opus" in model.lower(): + model = "claude-opus-4-6-thinking" + elif "sonnet" in model.lower(): + model = "claude-sonnet-4-6" + elif "haiku" in model.lower(): + model = "gemini-2.5-flash" + elif "claude" in model.lower(): + # Claude 模型兜底:如果包含 claude 但不是 opus/sonnet/haiku + model = "claude-sonnet-4-6" + + result["model"] = model + if original_model != model: + log.debug(f"[ANTIGRAVITY] 映射模型: {original_model} -> {model}") + + # 5. 模型特殊处理:循环移除末尾的 model 消息,保证以用户消息结尾 + # 因为该模型不支持预填充 + if "claude-opus-4-6-thinking" in model.lower() or "claude-sonnet-4-6" in model.lower(): + contents = result.get("contents", []) + removed_count = 0 + while contents and isinstance(contents[-1], dict) and contents[-1].get("role") == "model": + contents.pop() + removed_count += 1 + if removed_count > 0: + log.warning(f"[ANTIGRAVITY] {model} 不支持预填充,移除了 {removed_count} 条末尾 model 消息") + result["contents"] = contents + + # 6. 移除 antigravity 模式不支持的字段 + generation_config.pop("presencePenalty", None) + generation_config.pop("frequencyPenalty", None) + generation_config.pop("stopSequences", None) + + # ========== 公共处理 ========== + + # 1. 安全设置覆盖 + if "lite" in model.lower(): + result["safetySettings"] = LITE_SAFETY_SETTINGS + else: + result["safetySettings"] = DEFAULT_SAFETY_SETTINGS + + # 2. 参数范围限制 + if generation_config: + # 强制设置 maxOutputTokens 为 64000 + generation_config["maxOutputTokens"] = 64000 + # 强制设置 topK 为 64 + generation_config["topK"] = 64 + + if "contents" in result: + cleaned_contents = [] + for content in result["contents"]: + if isinstance(content, dict) and "parts" in content: + # 过滤掉空的或无效的 parts + valid_parts = [] + for part in content["parts"]: + if not isinstance(part, dict): + continue + + # 检查 part 是否有有效的非空值 + # 过滤掉空字典或所有值都为空的 part + has_valid_value = any( + value not in (None, "", {}, []) + for key, value in part.items() + if key != "thought" # thought 字段可以为空 + ) + + if has_valid_value: + part = part.copy() + + # 修复 text 字段:确保是字符串而不是列表 + if "text" in part: + text_value = part["text"] + if isinstance(text_value, list): + # 如果是列表,合并为字符串 + log.warning(f"[GEMINI_FIX] text 字段是列表,自动合并: {text_value}") + part["text"] = " ".join(str(t) for t in text_value if t) + elif isinstance(text_value, str): + # 清理尾随空格 + part["text"] = text_value.rstrip() + else: + # 其他类型转为字符串 + log.warning(f"[GEMINI_FIX] text 字段类型异常 ({type(text_value)}), 转为字符串: {text_value}") + part["text"] = str(text_value) + + valid_parts.append(part) + else: + log.warning(f"[GEMINI_FIX] 移除空的或无效的 part: {part}") + + # 只添加有有效 parts 的 content + if valid_parts: + cleaned_content = content.copy() + cleaned_content["parts"] = valid_parts + cleaned_contents.append(cleaned_content) + else: + log.warning(f"[GEMINI_FIX] 跳过没有有效 parts 的 content: {content.get('role')}") + else: + cleaned_contents.append(content) + + result["contents"] = cleaned_contents + + if generation_config: + result["generationConfig"] = generation_config + + return result \ No newline at end of file diff --git a/src/converter/openai2gemini.py b/src/converter/openai2gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..debd6fb934b7d49ed3a647009191ad1d6dfc0875 --- /dev/null +++ b/src/converter/openai2gemini.py @@ -0,0 +1,1533 @@ +""" +OpenAI Transfer Module - Handles conversion between OpenAI and Gemini API formats +被openai-router调用,负责OpenAI格式与Gemini格式的双向转换 +""" + +import json +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple, Union + +from pypinyin import Style, lazy_pinyin + +from src.converter.thoughtSignature_fix import ( + encode_tool_id_with_signature, + decode_tool_id_and_signature, +) +from src.converter.utils import merge_system_messages + +from log import log + +def _convert_usage_metadata(usage_metadata: Dict[str, Any]) -> Dict[str, int]: + """ + 将Gemini的usageMetadata转换为OpenAI格式的usage字段 + + Args: + usage_metadata: Gemini API的usageMetadata字段 + + Returns: + OpenAI格式的usage字典,如果没有usage数据则返回None + """ + if not usage_metadata: + return None + + return { + "prompt_tokens": usage_metadata.get("promptTokenCount", 0), + "completion_tokens": usage_metadata.get("candidatesTokenCount", 0), + "total_tokens": usage_metadata.get("totalTokenCount", 0), + } + + +def _build_message_with_reasoning(role: str, content: str, reasoning_content: str) -> dict: + """构建包含可选推理内容的消息对象""" + message = {"role": role, "content": content} + + # 如果有thinking tokens,添加reasoning_content + if reasoning_content: + message["reasoning_content"] = reasoning_content + + return message + + +def _map_finish_reason(gemini_reason: str) -> str: + """ + 将Gemini结束原因映射到OpenAI结束原因 + + Args: + gemini_reason: 来自Gemini API的结束原因 + + Returns: + OpenAI兼容的结束原因 + """ + if gemini_reason == "STOP": + return "stop" + elif gemini_reason == "MAX_TOKENS": + return "length" + elif gemini_reason in ["SAFETY", "RECITATION"]: + return "content_filter" + else: + # 对于 None 或未知的 finishReason,返回 "stop" 作为默认值 + # 避免返回 None 导致 MCP 客户端误判为响应未完成而循环调用 + return "stop" + + +# ==================== Tool Conversion Functions ==================== + + +def _normalize_function_name(name: str) -> str: + """ + 规范化函数名以符合 Gemini API 要求 + + 规则: + - 必须以字母或下划线开头 + - 只能包含 a-z, A-Z, 0-9, 下划线, 英文句点, 英文短划线 + - 最大长度 64 个字符 + + 转换策略: + 1. 中文字符转换为拼音 + 2. 将非法字符替换为下划线 + 3. 如果以非字母/下划线开头,添加下划线前缀 + 4. 截断到 64 个字符 + + Args: + name: 原始函数名 + + Returns: + 规范化后的函数名 + """ + import re + + if not name: + return "_unnamed_function" + + # 步骤1:转换中文字符为拼音 + if re.search(r"[\u4e00-\u9fff]", name): + try: + parts = [] + for char in name: + if "\u4e00" <= char <= "\u9fff": + # 中文字符转换为拼音 + pinyin = lazy_pinyin(char, style=Style.NORMAL) + parts.append("".join(pinyin)) + else: + parts.append(char) + normalized = "".join(parts) + except ImportError: + log.warning("pypinyin not installed, cannot convert Chinese characters to pinyin") + normalized = name + else: + normalized = name + + # 步骤2:将非法字符替换为下划线 + # 合法字符:a-z, A-Z, 0-9, _, ., - + normalized = re.sub(r"[^a-zA-Z0-9_.\-]", "_", normalized) + + # 步骤3:确保以字母或下划线开头 + if normalized and not (normalized[0].isalpha() or normalized[0] == "_"): + # 以数字、点或短横线开头,添加下划线前缀 + normalized = "_" + normalized + + # 步骤4:截断到 64 个字符 + if len(normalized) > 64: + normalized = normalized[:64] + + # 步骤5:确保不为空 + if not normalized: + normalized = "_unnamed_function" + + return normalized + + +def _resolve_ref(ref: str, root_schema: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + 解析 $ref 引用 + + Args: + ref: 引用路径,如 "#/definitions/MyType" + root_schema: 根 schema 对象 + + Returns: + 解析后的 schema,如果失败返回 None + """ + if not ref.startswith('#/'): + return None + + path = ref[2:].split('/') + current = root_schema + + for segment in path: + if isinstance(current, dict) and segment in current: + current = current[segment] + else: + return None + + return current if isinstance(current, dict) else None + + +def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] = None, visited: Optional[set] = None) -> Any: + """ + 清理 JSON Schema,转换为 Claude API 支持的格式(符合 JSON Schema draft 2020-12) + + 处理逻辑: + 1. 解析 $ref 引用 + 2. 合并 allOf 中的 schema + 3. 转换 anyOf 为更兼容的格式 + 4. 保持标准 JSON Schema 类型(不转换为大写) + 5. 处理 array 的 items + 6. 清理 Claude 不支持的字段 + + Args: + schema: JSON Schema 对象 + root_schema: 根 schema(用于解析 $ref) + visited: 已访问的对象集合(防止循环引用) + + Returns: + 清理后的 schema + """ + # 非字典类型直接返回 + if not isinstance(schema, dict): + return schema + + # 初始化 + if root_schema is None: + root_schema = schema + if visited is None: + visited = set() + + # 防止循环引用 + schema_id = id(schema) + if schema_id in visited: + return schema + visited.add(schema_id) + + # 创建副本避免修改原对象 + result = {} + + # 1. 处理 $ref + if "$ref" in schema: + resolved = _resolve_ref(schema["$ref"], root_schema) + if resolved: + import copy + result = copy.deepcopy(resolved) + for key, value in schema.items(): + if key != "$ref": + result[key] = value + schema = result + result = {} + + # 2. 处理 allOf(合并所有 schema) + if "allOf" in schema: + all_of_schemas = schema["allOf"] + for item in all_of_schemas: + cleaned_item = _clean_schema_for_claude(item, root_schema, visited) + + if "properties" in cleaned_item: + if "properties" not in result: + result["properties"] = {} + result["properties"].update(cleaned_item["properties"]) + + if "required" in cleaned_item: + if "required" not in result: + result["required"] = [] + result["required"].extend(cleaned_item["required"]) + + for key, value in cleaned_item.items(): + if key not in ["properties", "required"]: + result[key] = value + + for key, value in schema.items(): + if key not in ["allOf", "properties", "required"]: + result[key] = value + elif key in ["properties", "required"] and key not in result: + result[key] = value + else: + result = dict(schema) + + # 3. 处理 type 数组(如 ["string", "null"]) + if "type" in result: + type_value = result["type"] + if isinstance(type_value, list): + # Claude 支持 type 数组,保持不变 + pass + + # 4. 处理 array 的 items + if result.get("type") == "array": + if "items" not in result: + result["items"] = {} + elif isinstance(result["items"], list): + # Tuple 定义,检查是否所有元素类型相同 + tuple_items = result["items"] + first_type = tuple_items[0].get("type") if tuple_items else None + is_homogeneous = all(item.get("type") == first_type for item in tuple_items) + + if is_homogeneous and first_type: + result["items"] = _clean_schema_for_claude(tuple_items[0], root_schema, visited) + else: + # 异质元组,使用 anyOf 表示 + result["items"] = { + "anyOf": [_clean_schema_for_claude(item, root_schema, visited) for item in tuple_items] + } + else: + result["items"] = _clean_schema_for_claude(result["items"], root_schema, visited) + + # 5. 处理 anyOf(保持 anyOf,递归清理) + if "anyOf" in result: + result["anyOf"] = [_clean_schema_for_claude(item, root_schema, visited) for item in result["anyOf"]] + + # 6. 清理 Claude 不支持的字段(根据 JSON Schema 2020-12) + # Claude API 对某些字段比较严格,移除可能导致问题的字段 + unsupported_keys = { + "title", "$schema", "strict", + "additionalItems", # 废弃字段,使用 items 替代 + "exclusiveMaximum", "exclusiveMinimum", # 在 2020-12 中这些应该是数值而非布尔值 + "$defs", "definitions", # 移除 definitions 相关字段避免冲突 + "example", "examples", "readOnly", "writeOnly", + "const", # const 可能导致问题 + "contentEncoding", "contentMediaType", + "oneOf", # oneOf 可能导致问题,用 anyOf 替代 + "patternProperties", "dependencies", "propertyNames", # Google API 不支持 + } + + for key in list(result.keys()): + if key in unsupported_keys: + del result[key] + + # 递归处理 additionalProperties(如果存在) + if "additionalProperties" in result and isinstance(result["additionalProperties"], dict): + result["additionalProperties"] = _clean_schema_for_claude(result["additionalProperties"], root_schema, visited) + + # 7. 递归处理 properties + if "properties" in result: + cleaned_props = {} + for prop_name, prop_schema in result["properties"].items(): + cleaned_props[prop_name] = _clean_schema_for_claude(prop_schema, root_schema, visited) + result["properties"] = cleaned_props + + # 8. 确保有 type 字段(如果有 properties 但没有 type) + if "properties" in result and "type" not in result: + result["type"] = "object" + + # 9. 去重 required 数组 + if "required" in result and isinstance(result["required"], list): + result["required"] = list(dict.fromkeys(result["required"])) + + return result + + +def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] = None, visited: Optional[set] = None) -> Any: + """ + 清理 JSON Schema,转换为 Gemini 支持的格式 + + 参考 worker.mjs 的 transformOpenApiSchemaToGemini 实现 + + 处理逻辑: + 1. 解析 $ref 引用 + 2. 合并 allOf 中的 schema + 3. 转换 anyOf 为 enum(如果可能) + 4. 类型映射(string -> STRING) + 5. 处理 ARRAY 的 items(包括 Tuple) + 6. 将 default 值移到 description + 7. 清理不支持的字段 + + Args: + schema: JSON Schema 对象 + root_schema: 根 schema(用于解析 $ref) + visited: 已访问的对象集合(防止循环引用) + + Returns: + 清理后的 schema + """ + # 非字典类型直接返回 + if not isinstance(schema, dict): + return schema + + # 初始化 + if root_schema is None: + root_schema = schema + if visited is None: + visited = set() + + # 防止循环引用 + schema_id = id(schema) + if schema_id in visited: + return schema + visited.add(schema_id) + + # 创建副本避免修改原对象 + result = {} + + # 1. 处理 $ref + if "$ref" in schema: + resolved = _resolve_ref(schema["$ref"], root_schema) + if resolved: + # 检测循环引用:若 resolved 已在 visited 中,直接返回占位符 + resolved_id = id(resolved) + if resolved_id in visited: + return {"type": "OBJECT", "description": "(circular reference)"} + # 将 resolved 的 id 加入 visited,防止后续递归时重复处理 + visited.add(resolved_id) + # 合并解析后的 schema 和当前 schema(浅拷贝,避免 deepcopy 爆栈) + merged = dict(resolved) + # 当前 schema 的其他字段会覆盖解析后的字段 + for key, value in schema.items(): + if key != "$ref": + merged[key] = value + schema = merged + result = {} + + # 2. 处理 allOf(合并所有 schema) + if "allOf" in schema: + all_of_schemas = schema["allOf"] + for item in all_of_schemas: + cleaned_item = _clean_schema_for_gemini(item, root_schema, visited) + + # 合并 properties + if "properties" in cleaned_item: + if "properties" not in result: + result["properties"] = {} + result["properties"].update(cleaned_item["properties"]) + + # 合并 required + if "required" in cleaned_item: + if "required" not in result: + result["required"] = [] + result["required"].extend(cleaned_item["required"]) + + # 合并其他字段(简单覆盖) + for key, value in cleaned_item.items(): + if key not in ["properties", "required"]: + result[key] = value + + # 复制其他字段 + for key, value in schema.items(): + if key not in ["allOf", "properties", "required"]: + result[key] = value + elif key in ["properties", "required"] and key not in result: + result[key] = value + else: + # 复制所有字段 + result = dict(schema) + + # 3. 类型映射(转换为大写) + # 注意:Gemini API 的 type 字段必须是字符串,不能是数组 + if "type" in result: + type_value = result["type"] + + # 如果 type 是列表,提取主要类型(非 null) + if isinstance(type_value, list): + primary_type = next((t for t in type_value if t != "null"), None) + type_value = primary_type if primary_type else "STRING" # 默认为 STRING + + # 类型映射 + type_map = { + "string": "STRING", + "number": "NUMBER", + "integer": "INTEGER", + "boolean": "BOOLEAN", + "array": "ARRAY", + "object": "OBJECT", + } + + if isinstance(type_value, str) and type_value.lower() in type_map: + # 确保 result["type"] 是字符串而不是列表 + result["type"] = type_map[type_value.lower()] + else: + # 未知类型,删除该字段 + del result["type"] + + # 4. 处理 ARRAY 的 items + if result.get("type") == "ARRAY": + if "items" not in result: + # 没有 items,默认允许任意类型 + result["items"] = {} + elif isinstance(result["items"], list): + # Tuple 定义(items 是数组) + tuple_items = result["items"] + + # 提取类型信息用于 description + tuple_types = [item.get("type", "any") for item in tuple_items] + tuple_desc = f"(Tuple: [{', '.join(tuple_types)}])" + + original_desc = result.get("description", "") + result["description"] = f"{original_desc} {tuple_desc}".strip() + + # 检查是否所有元素类型相同 + first_type = tuple_items[0].get("type") if tuple_items else None + is_homogeneous = all(item.get("type") == first_type for item in tuple_items) + + if is_homogeneous and first_type: + # 同质元组,转换为 List + result["items"] = _clean_schema_for_gemini(tuple_items[0], root_schema, visited) + else: + # 异质元组,Gemini 不支持,设为 {} + result["items"] = {} + else: + # 递归处理 items + result["items"] = _clean_schema_for_gemini(result["items"], root_schema, visited) + + # 5. 处理 anyOf(尝试转换为 enum) + if "anyOf" in result: + any_of_schemas = result["anyOf"] + + # 递归处理每个 schema + cleaned_any_of = [_clean_schema_for_gemini(item, root_schema, visited) for item in any_of_schemas] + + # 尝试提取 enum + if all("const" in item for item in cleaned_any_of): + enum_values = [ + str(item["const"]) + for item in cleaned_any_of + if item.get("const") not in ["", None] + ] + if enum_values: + result["type"] = "STRING" + result["enum"] = enum_values + elif "type" not in result: + # 如果不是 enum,尝试取第一个有效的类型定义 + first_valid = next((item for item in cleaned_any_of if item.get("type") or item.get("enum")), None) + if first_valid: + result.update(first_valid) + + # 删除 anyOf + del result["anyOf"] + + # 6. 将 default 值移到 description + if "default" in result: + default_value = result["default"] + original_desc = result.get("description", "") + result["description"] = f"{original_desc} (Default: {json.dumps(default_value)})".strip() + del result["default"] + + # 7. 清理不支持的字段 + unsupported_keys = { + "title", "$schema", "$ref", "strict", "exclusiveMaximum", + "exclusiveMinimum", "additionalProperties", "oneOf", "allOf", + "$defs", "definitions", "example", "examples", "readOnly", + "writeOnly", "const", "additionalItems", "contains", + "patternProperties", "dependencies", "propertyNames", + "if", "then", "else", "contentEncoding", "contentMediaType" + } + + for key in list(result.keys()): + if key in unsupported_keys: + del result[key] + + # 8. 递归处理 properties + if "properties" in result: + cleaned_props = {} + for prop_name, prop_schema in result["properties"].items(): + cleaned_props[prop_name] = _clean_schema_for_gemini(prop_schema, root_schema, visited) + result["properties"] = cleaned_props + + # 9. 确保有 type 字段(如果有 properties 但没有 type) + if "properties" in result and "type" not in result: + result["type"] = "OBJECT" + + # 10. 去重 required 数组 + if "required" in result and isinstance(result["required"], list): + result["required"] = list(dict.fromkeys(result["required"])) # 保持顺序去重 + + return result + + +def fix_tool_call_args_types( + args: Dict[str, Any], + parameters_schema: Dict[str, Any] +) -> Dict[str, Any]: + """ + 根据工具的参数 schema 修正函数调用参数的类型 + + 例如:将字符串 "5" 转换为数字 5,根据 schema 中的 type 定义 + + Args: + args: 函数调用的参数字典 + parameters_schema: 工具定义中的 parameters schema + + Returns: + 类型修正后的参数字典 + """ + if not args or not parameters_schema: + return args + + properties = parameters_schema.get("properties", {}) + if not properties: + return args + + fixed_args = {} + for key, value in args.items(): + if key not in properties: + # 参数不在 schema 中,保持原样 + fixed_args[key] = value + continue + + param_schema = properties[key] + param_type = param_schema.get("type") + + # 根据 schema 中的类型修正参数值 + if param_type == "number" or param_type == "integer": + # 如果值是字符串,尝试转换为数字 + if isinstance(value, str): + try: + if param_type == "integer": + fixed_args[key] = int(value) + else: + # 尝试转换为 float,如果是整数则保持为 int + num_value = float(value) + fixed_args[key] = int(num_value) if num_value.is_integer() else num_value + log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} ({param_type})") + except (ValueError, AttributeError): + # 转换失败,保持原样 + fixed_args[key] = value + log.warning(f"[OPENAI2GEMINI] 无法将参数 {key} 的值 '{value}' 转换为 {param_type}") + else: + fixed_args[key] = value + elif param_type == "boolean": + # 如果值是字符串,转换为布尔值 + if isinstance(value, str): + if value.lower() in ("true", "1", "yes"): + fixed_args[key] = True + elif value.lower() in ("false", "0", "no"): + fixed_args[key] = False + else: + fixed_args[key] = value + if fixed_args[key] != value: + log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} (boolean)") + else: + fixed_args[key] = value + elif param_type == "string": + # 如果值不是字符串,转换为字符串 + if not isinstance(value, str): + fixed_args[key] = str(value) + log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} {value} -> '{fixed_args[key]}' (string)") + else: + fixed_args[key] = value + else: + # 其他类型(array, object 等)保持原样 + fixed_args[key] = value + + return fixed_args + + +def convert_openai_tools_to_gemini(openai_tools: List, model: str = "") -> List[Dict[str, Any]]: + """ + 将 OpenAI tools 格式转换为 Gemini functionDeclarations 格式 + + Args: + openai_tools: OpenAI 格式的工具列表(可能是字典或 Pydantic 模型) + model: 模型名称(用于判断是否为 Claude 模型) + + Returns: + Gemini 格式的工具列表 + """ + if not openai_tools: + return [] + + # 判断是否为 Claude 模型 + is_claude_model = "claude" in model.lower() + + function_declarations = [] + + for tool in openai_tools: + if tool.get("type") != "function": + log.warning(f"Skipping non-function tool type: {tool.get('type')}") + continue + + function = tool.get("function") + if not function: + log.warning("Tool missing 'function' field") + continue + + # 获取并规范化函数名 + original_name = function.get("name") + if not original_name: + log.warning("Tool missing 'name' field, using default") + original_name = "_unnamed_function" + + normalized_name = _normalize_function_name(original_name) + + # 如果名称被修改了,记录日志 + if normalized_name != original_name: + log.debug(f"Function name normalized: '{original_name}' -> '{normalized_name}'") + + # 构建 Gemini function declaration + declaration = { + "name": normalized_name, + "description": function.get("description", ""), + } + + # 添加参数(如果有)- 根据模型选择不同的清理函数 + if "parameters" in function: + if is_claude_model: + cleaned_params = _clean_schema_for_claude(function["parameters"]) + log.debug(f"[OPENAI2GEMINI] Using Claude schema cleaning for tool: {normalized_name}") + else: + cleaned_params = _clean_schema_for_gemini(function["parameters"]) + + if cleaned_params: + declaration["parameters"] = cleaned_params + + function_declarations.append(declaration) + + if not function_declarations: + return [] + + # Gemini 格式:工具数组中包含 functionDeclarations + return [{"functionDeclarations": function_declarations}] + + +def convert_tool_choice_to_tool_config(tool_choice: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """ + 将 OpenAI tool_choice 转换为 Gemini toolConfig + + Args: + tool_choice: OpenAI 格式的 tool_choice + + Returns: + Gemini 格式的 toolConfig + """ + if isinstance(tool_choice, str): + if tool_choice == "auto": + return {"functionCallingConfig": {"mode": "AUTO"}} + elif tool_choice == "none": + return {"functionCallingConfig": {"mode": "NONE"}} + elif tool_choice == "required": + return {"functionCallingConfig": {"mode": "ANY"}} + elif isinstance(tool_choice, dict): + # {"type": "function", "function": {"name": "my_function"}} + if tool_choice.get("type") == "function": + function_name = tool_choice.get("function", {}).get("name") + if function_name: + return { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [function_name], + } + } + + # 默认返回 AUTO 模式 + return {"functionCallingConfig": {"mode": "AUTO"}} + + +def convert_tool_message_to_function_response(message, all_messages: List = None) -> Dict[str, Any]: + """ + 将 OpenAI 的 tool role 消息转换为 Gemini functionResponse + + Args: + message: OpenAI 格式的工具消息 + all_messages: 所有消息的列表,用于查找 tool_call_id 对应的函数名 + + Returns: + Gemini 格式的 functionResponse part + """ + # 获取 name 字段 + name = getattr(message, "name", None) + encoded_tool_call_id = getattr(message, "tool_call_id", None) or "" + + # 解码获取原始ID(functionResponse不需要签名) + original_tool_call_id, _ = decode_tool_id_and_signature(encoded_tool_call_id) + + # 如果没有 name,尝试从 all_messages 中查找对应的 tool_call_id + # 注意:使用编码ID查找,因为存储的是编码ID + if not name and encoded_tool_call_id and all_messages: + for msg in all_messages: + if getattr(msg, "role", None) == "assistant" and hasattr(msg, "tool_calls") and msg.tool_calls: + for tool_call in msg.tool_calls: + if getattr(tool_call, "id", None) == encoded_tool_call_id: + func = getattr(tool_call, "function", None) + if func: + name = getattr(func, "name", None) + break + if name: + break + + # 最终兜底:如果仍然没有 name,使用默认值 + if not name: + name = "unknown_function" + log.warning(f"Tool message missing function name, using default: {name}") + + try: + # 尝试将 content 解析为 JSON + response_data = ( + json.loads(message.content) if isinstance(message.content, str) else message.content + ) + except (json.JSONDecodeError, TypeError): + # 如果不是有效的 JSON,包装为对象 + response_data = {"result": str(message.content)} + + # 确保 response_data 是字典类型(Gemini API 要求 response 必须是对象) + if not isinstance(response_data, dict): + response_data = {"result": response_data} + + return {"functionResponse": {"id": original_tool_call_id, "name": name, "response": response_data}} + + +def _reverse_transform_value(value: Any) -> Any: + """ + 将值转换回原始类型(Gemini 可能将所有值转为字符串) + + 仅处理 Gemini 在工具参数中常见的布尔/空值字符串化情况, + 不再对数字字符串做启发式转换,避免把 schema 声明为 string + 的参数错误还原成 integer。 + + 参考 worker.mjs 的 reverseTransformValue + + Args: + value: 要转换的值 + + Returns: + 转换后的值 + """ + if not isinstance(value, str): + return value + + # 布尔值 + if value == 'true': + return True + if value == 'false': + return False + + # null + if value == 'null': + return None + + # 其他情况保持字符串 + return value + + +def _reverse_transform_args(args: Any) -> Any: + """ + 递归转换函数参数,将字符串转回原始类型 + + 参考 worker.mjs 的 reverseTransformArgs + + Args: + args: 函数参数(可能是字典、列表或其他类型) + + Returns: + 转换后的参数 + """ + if not isinstance(args, (dict, list)): + return args + + if isinstance(args, list): + return [_reverse_transform_args(item) for item in args] + + # 处理字典 + result = {} + for key, value in args.items(): + if isinstance(value, (dict, list)): + result[key] = _reverse_transform_args(value) + else: + result[key] = _reverse_transform_value(value) + + return result + + +def extract_tool_calls_from_parts( + parts: List[Dict[str, Any]], is_streaming: bool = False +) -> Tuple[List[Dict[str, Any]], str]: + """ + 从 Gemini response parts 中提取工具调用和文本内容 + + Args: + parts: Gemini response 的 parts 数组 + is_streaming: 是否为流式响应(流式响应需要添加 index 字段) + + Returns: + (tool_calls, text_content) 元组 + """ + tool_calls = [] + text_content = "" + + for idx, part in enumerate(parts): + # 检查是否是函数调用 + if "functionCall" in part: + function_call = part["functionCall"] + # 获取原始ID或生成新ID + original_id = function_call.get("id") or f"call_{uuid.uuid4().hex[:24]}" + # 将thoughtSignature编码到ID中以便往返保留 + signature = part.get("thoughtSignature") + encoded_id = encode_tool_id_with_signature(original_id, signature) + + # 获取参数并转换类型 + args = function_call.get("args", {}) + # 将字符串类型的值转回原始类型 + args = _reverse_transform_args(args) + + tool_call = { + "id": encoded_id, + "type": "function", + "function": { + "name": function_call.get("name", "nameless_function"), + "arguments": json.dumps(args), + }, + } + # 流式响应需要 index 字段 + if is_streaming: + tool_call["index"] = idx + tool_calls.append(tool_call) + + # 提取文本内容(排除 thinking tokens) + elif "text" in part and not part.get("thought", False): + text_content += part["text"] + + return tool_calls, text_content + + +def extract_images_from_content(content: Any) -> Dict[str, Any]: + """ + 从 OpenAI content 中提取文本和图片 + + Args: + content: OpenAI 消息的 content 字段(可能是字符串或列表) + + Returns: + 包含 text 和 images 的字典 + """ + result = {"text": "", "images": []} + + if isinstance(content, str): + result["text"] = content + elif isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + result["text"] += item.get("text", "") + elif item.get("type") == "image_url": + image_url = item.get("image_url", {}).get("url", "") + # 解析 data:image/png;base64,xxx 格式 + if image_url.startswith("data:image/"): + import re + match = re.match(r"^data:image/(\w+);base64,(.+)$", image_url) + if match: + mime_type = match.group(1) + base64_data = match.group(2) + result["images"].append({ + "inlineData": { + "mimeType": f"image/{mime_type}", + "data": base64_data + } + }) + + return result + +async def convert_openai_to_gemini_request(openai_request: Dict[str, Any]) -> Dict[str, Any]: + """ + 将 OpenAI 格式请求体转换为 Gemini 格式请求体 + + 注意: 此函数只负责基础转换,不包含 normalize_gemini_request 中的处理 + (如 thinking config, search tools, 参数范围限制等) + + Args: + openai_request: OpenAI 格式的请求体字典,包含: + - messages: 消息列表 + - temperature, top_p, max_tokens, stop 等生成参数 + - tools, tool_choice (可选) + - response_format (可选) + + Returns: + Gemini 格式的请求体字典,包含: + - contents: 转换后的消息内容 + - generationConfig: 生成配置 + - systemInstruction: 系统指令 (如果有) + - tools, toolConfig (如果有) + """ + # 处理连续的system消息(兼容性模式) + openai_request = await merge_system_messages(openai_request) + + contents = [] + + # 提取消息列表 + messages = openai_request.get("messages", []) + + # 构建 tool_call_id -> (name, original_id, signature) 的映射 + tool_call_mapping = {} + for msg in messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + encoded_id = tc.get("id", "") + func_name = tc.get("function", {}).get("name") or "" + if encoded_id: + # 解码获取原始ID和签名 + original_id, signature = decode_tool_id_and_signature(encoded_id) + tool_call_mapping[encoded_id] = (func_name, original_id, signature) + + # 构建工具名称到参数 schema 的映射(用于类型修正) + tool_schemas = {} + if "tools" in openai_request and openai_request["tools"]: + for tool in openai_request["tools"]: + if tool.get("type") == "function": + function = tool.get("function", {}) + func_name = function.get("name") + if func_name: + tool_schemas[func_name] = function.get("parameters", {}) + + # 用于累积连续的 tool message 的 functionResponse parts + pending_tool_parts = [] + + def flush_pending_tool_parts(): + """将累积的 tool parts 作为单个 contents 条目追加""" + nonlocal pending_tool_parts + if pending_tool_parts: + contents.append({ + "role": "user", + "parts": pending_tool_parts + }) + pending_tool_parts = [] + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + + # 处理工具消息(tool role)- 累积到 pending_tool_parts + if role == "tool": + tool_call_id = message.get("tool_call_id", "") + func_name = message.get("name") + + # 使用映射表查找 + if tool_call_id in tool_call_mapping: + func_name, original_id, _ = tool_call_mapping[tool_call_id] + else: + # 如果没有name,尝试从消息列表中查找 + if not func_name and tool_call_id: + for msg in messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + if tc.get("id") == tool_call_id: + func_name = tc.get("function", {}).get("name") + break + if func_name: + break + + # 解码 tool_call_id 获取原始 ID + original_id, _ = decode_tool_id_and_signature(tool_call_id) + + # 最终兜底:确保 func_name 不为空 + if not func_name: + func_name = "unknown_function" + log.warning(f"Tool message missing function name for tool_call_id={tool_call_id}, using default: {func_name}") + + # 解析响应数据 + try: + response_data = json.loads(content) if isinstance(content, str) else content + except (json.JSONDecodeError, TypeError): + response_data = {"result": str(content)} + + # 确保 response_data 是字典类型(Gemini API 要求 response 必须是对象) + if not isinstance(response_data, dict): + response_data = {"result": response_data} + + # 累积 functionResponse part(不立即追加到 contents) + pending_tool_parts.append({ + "functionResponse": { + "id": original_id, + "name": func_name, + "response": response_data + } + }) + continue + + # 遇到非 tool 消息时,先 flush 累积的 tool parts + flush_pending_tool_parts() + + # system 消息已经由 merge_system_messages 处理,这里跳过 + if role == "system": + continue + + # 将OpenAI角色映射到Gemini角色 + if role == "assistant": + role = "model" + + # 检查是否有tool_calls + tool_calls = message.get("tool_calls") + if tool_calls: + parts = [] + + # 如果有文本内容,先添加文本 + if content: + parts.append({"text": content}) + + # 添加每个工具调用 + for tool_call in tool_calls: + try: + args = ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ) + + # 根据工具的 schema 修正参数类型 + func_name = tool_call["function"]["name"] + if func_name in tool_schemas: + args = fix_tool_call_args_types(args, tool_schemas[func_name]) + + # 解码工具ID和thoughtSignature + encoded_id = tool_call.get("id", "") + original_id, signature = decode_tool_id_and_signature(encoded_id) + + # 构建functionCall part + function_call_part = { + "functionCall": { + "id": original_id, + "name": func_name, + "args": args + } + } + + # 如果有thoughtSignature则添加,否则使用占位符以满足 Gemini API 要求 + if signature: + function_call_part["thoughtSignature"] = signature + else: + function_call_part["thoughtSignature"] = "skip_thought_signature_validator" + + parts.append(function_call_part) + except (json.JSONDecodeError, KeyError) as e: + log.error(f"Failed to parse tool call: {e}") + continue + + if parts: + contents.append({"role": role, "parts": parts}) + continue + + # 处理普通内容 + if isinstance(content, list): + parts = [] + for part in content: + if part.get("type") == "text": + parts.append({"text": part.get("text", "")}) + elif part.get("type") == "image_url": + image_url = part.get("image_url", {}).get("url") + if image_url: + try: + mime_type, base64_data = image_url.split(";") + _, mime_type = mime_type.split(":") + _, base64_data = base64_data.split(",") + parts.append({ + "inlineData": { + "mimeType": mime_type, + "data": base64_data, + } + }) + except ValueError: + continue + if parts: + contents.append({"role": role, "parts": parts}) + elif content: + contents.append({"role": role, "parts": [{"text": content}]}) + + # 循环结束后,flush 剩余的 tool parts(如果消息列表以 tool 消息结尾) + flush_pending_tool_parts() + + # 构建生成配置 + generation_config = {} + model = openai_request.get("model", "") + + # 基础参数映射 + if "temperature" in openai_request: + generation_config["temperature"] = openai_request["temperature"] + if "top_p" in openai_request: + generation_config["topP"] = openai_request["top_p"] + if "top_k" in openai_request: + generation_config["topK"] = openai_request["top_k"] + if "max_tokens" in openai_request or "max_completion_tokens" in openai_request: + # max_completion_tokens 优先于 max_tokens + max_tokens = openai_request.get("max_completion_tokens") or openai_request.get("max_tokens") + generation_config["maxOutputTokens"] = max_tokens + if "stop" in openai_request: + stop = openai_request["stop"] + generation_config["stopSequences"] = [stop] if isinstance(stop, str) else stop + if "frequency_penalty" in openai_request: + generation_config["frequencyPenalty"] = openai_request["frequency_penalty"] + if "presence_penalty" in openai_request: + generation_config["presencePenalty"] = openai_request["presence_penalty"] + if "n" in openai_request: + generation_config["candidateCount"] = openai_request["n"] + if "seed" in openai_request: + generation_config["seed"] = openai_request["seed"] + + # 处理 response_format + if "response_format" in openai_request and openai_request["response_format"]: + response_format = openai_request["response_format"] + format_type = response_format.get("type") + + if format_type == "json_schema": + # JSON Schema 模式 + if "json_schema" in response_format and "schema" in response_format["json_schema"]: + schema = response_format["json_schema"]["schema"] + # 清理 schema + generation_config["responseSchema"] = _clean_schema_for_gemini(schema) + generation_config["responseMimeType"] = "application/json" + elif format_type == "json_object": + # JSON Object 模式 + generation_config["responseMimeType"] = "application/json" + elif format_type == "text": + # Text 模式 + generation_config["responseMimeType"] = "text/plain" + + # 如果contents为空,添加默认用户消息 + if not contents: + contents.append({"role": "user", "parts": [{"text": "请根据系统指令回答。"}]}) + + # 构建基础请求 + gemini_request = { + "contents": contents, + "generationConfig": generation_config + } + + # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 + if "systemInstruction" in openai_request: + gemini_request["systemInstruction"] = openai_request["systemInstruction"] + + # 处理工具 - 传递 model 参数以便根据模型类型选择清理策略 + model = openai_request.get("model", "") + if "tools" in openai_request and openai_request["tools"]: + gemini_request["tools"] = convert_openai_tools_to_gemini(openai_request["tools"], model) + + # 处理tool_choice + if "tool_choice" in openai_request and openai_request["tool_choice"]: + gemini_request["toolConfig"] = convert_tool_choice_to_tool_config(openai_request["tool_choice"]) + + return gemini_request + + +def convert_gemini_to_openai_response( + gemini_response: Union[Dict[str, Any], Any], + model: str, + status_code: int = 200 +) -> Dict[str, Any]: + """ + 将 Gemini 格式非流式响应转换为 OpenAI 格式非流式响应 + + 注意: 如果收到的不是 200 开头的响应,不做任何处理,直接转发原始响应 + + Args: + gemini_response: Gemini 格式的响应体 (字典或响应对象) + model: 模型名称 + status_code: HTTP 状态码 (默认 200) + + Returns: + OpenAI 格式的响应体字典,或原始响应 (如果状态码不是 2xx) + """ + # 非 2xx 状态码直接返回原始响应 + if not (200 <= status_code < 300): + if isinstance(gemini_response, dict): + return gemini_response + else: + # 如果是响应对象,尝试解析为字典 + try: + if hasattr(gemini_response, "json"): + return gemini_response.json() + elif hasattr(gemini_response, "body"): + body = gemini_response.body + if isinstance(body, bytes): + return json.loads(body.decode()) + return json.loads(str(body)) + else: + return {"error": str(gemini_response)} + except Exception: + return {"error": str(gemini_response)} + + # 确保是字典格式 + if not isinstance(gemini_response, dict): + try: + if hasattr(gemini_response, "json"): + gemini_response = gemini_response.json() + elif hasattr(gemini_response, "body"): + body = gemini_response.body + if isinstance(body, bytes): + gemini_response = json.loads(body.decode()) + else: + gemini_response = json.loads(str(body)) + else: + gemini_response = json.loads(str(gemini_response)) + except Exception: + return {"error": "Invalid response format"} + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_response: + gemini_response = gemini_response["response"] + + # 转换为 OpenAI 格式 + choices = [] + + for candidate in gemini_response.get("candidates", []): + role = candidate.get("content", {}).get("role", "assistant") + + # 将Gemini角色映射回OpenAI角色 + if role == "model": + role = "assistant" + + # 提取并分离thinking tokens和常规内容 + parts = candidate.get("content", {}).get("parts", []) + + # 提取工具调用和文本内容 + tool_calls, text_content = extract_tool_calls_from_parts(parts) + + # 提取多种类型的内容 + content_parts = [] + reasoning_parts = [] + + for part in parts: + # 处理 executableCode(代码生成) + if "executableCode" in part: + exec_code = part["executableCode"] + lang = exec_code.get("language", "python").lower() + code = exec_code.get("code", "") + # 添加代码块(前后加换行符确保 Markdown 渲染正确) + content_parts.append(f"\n```{lang}\n{code}\n```\n") + + # 处理 codeExecutionResult(代码执行结果) + elif "codeExecutionResult" in part: + result = part["codeExecutionResult"] + outcome = result.get("outcome") + output = result.get("output", "") + + if output: + label = "output" if outcome == "OUTCOME_OK" else "error" + content_parts.append(f"\n```{label}\n{output}\n```\n") + + # 处理 thought(思考内容) + elif part.get("thought", False) and "text" in part: + reasoning_parts.append(part["text"]) + + # 处理普通文本(非思考内容) + elif "text" in part and not part.get("thought", False): + # 这部分已经在 extract_tool_calls_from_parts 中处理 + pass + + # 处理 inlineData(图片) + elif "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + # 使用 Markdown 格式 + content_parts.append(f"![gemini-generated-content](data:{mime_type};base64,{base64_data})") + + # 合并所有内容部分 + if content_parts: + # 使用双换行符连接各部分,确保块之间有间距 + additional_content = "\n\n".join(content_parts) + if text_content: + text_content = text_content + "\n\n" + additional_content + else: + text_content = additional_content + + # 合并 reasoning content + reasoning_content = "\n\n".join(reasoning_parts) if reasoning_parts else "" + + # 构建消息对象 + message = {"role": role} + + # 获取 Gemini 的 finishReason + gemini_finish_reason = candidate.get("finishReason") + + # 如果有工具调用 + if tool_calls: + message["tool_calls"] = tool_calls + message["content"] = text_content if text_content else None + # 只有在正常停止(STOP)时才设为 tool_calls,其他情况保持原始 finish_reason + # 这样可以避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_calls 导致循环 + if gemini_finish_reason == "STOP": + finish_reason = "tool_calls" + else: + finish_reason = _map_finish_reason(gemini_finish_reason) + else: + message["content"] = text_content + finish_reason = _map_finish_reason(gemini_finish_reason) + + # 添加 reasoning content (如果有) + if reasoning_content: + message["reasoning_content"] = reasoning_content + + choices.append({ + "index": candidate.get("index", 0), + "message": message, + "finish_reason": finish_reason, + }) + + # 转换 usageMetadata + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + response_data = { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": choices, + } + + if usage: + response_data["usage"] = usage + + return response_data + + +def convert_gemini_to_openai_stream( + gemini_stream_chunk: str, + model: str, + response_id: str, + status_code: int = 200 +) -> Optional[str]: + """ + 将 Gemini 格式流式响应块转换为 OpenAI SSE 格式流式响应 + + 注意: 如果收到的不是 200 开头的响应,不做任何处理,直接转发原始内容 + + Args: + gemini_stream_chunk: Gemini 格式的流式响应块 (字符串,通常是 "data: {json}" 格式) + model: 模型名称 + response_id: 此流式响应的一致ID + status_code: HTTP 状态码 (默认 200) + + Returns: + OpenAI SSE 格式的响应字符串 (如 "data: {json}\n\n"), + 或原始内容 (如果状态码不是 2xx), + 或 None (如果解析失败) + """ + # 非 2xx 状态码直接返回原始内容 + if not (200 <= status_code < 300): + return gemini_stream_chunk + + # 解析 Gemini 流式块 + try: + # 去除 "data: " 前缀 + if isinstance(gemini_stream_chunk, bytes): + if gemini_stream_chunk.startswith(b"data: "): + payload_str = gemini_stream_chunk[len(b"data: "):].strip().decode("utf-8") + else: + payload_str = gemini_stream_chunk.strip().decode("utf-8") + else: + if gemini_stream_chunk.startswith("data: "): + payload_str = gemini_stream_chunk[len("data: "):].strip() + else: + payload_str = gemini_stream_chunk.strip() + + # 跳过空块 + if not payload_str: + return None + + # 解析 JSON + gemini_chunk = json.loads(payload_str) + except (json.JSONDecodeError, UnicodeDecodeError): + # 解析失败,跳过此块 + return None + + # 处理 GeminiCLI 的 response 包装格式 + if "response" in gemini_chunk: + gemini_response = gemini_chunk["response"] + else: + gemini_response = gemini_chunk + + # 转换为 OpenAI 流式格式 + choices = [] + + for candidate in gemini_response.get("candidates", []): + role = candidate.get("content", {}).get("role", "assistant") + + # 将Gemini角色映射回OpenAI角色 + if role == "model": + role = "assistant" + + # 提取并分离thinking tokens和常规内容 + parts = candidate.get("content", {}).get("parts", []) + + # 提取工具调用和文本内容 (流式需要 index) + tool_calls, text_content = extract_tool_calls_from_parts(parts, is_streaming=True) + + # 提取多种类型的内容 + content_parts = [] + reasoning_parts = [] + + for part in parts: + # 处理 executableCode(代码生成) + if "executableCode" in part: + exec_code = part["executableCode"] + lang = exec_code.get("language", "python").lower() + code = exec_code.get("code", "") + content_parts.append(f"\n```{lang}\n{code}\n```\n") + + # 处理 codeExecutionResult(代码执行结果) + elif "codeExecutionResult" in part: + result = part["codeExecutionResult"] + outcome = result.get("outcome") + output = result.get("output", "") + + if output: + label = "output" if outcome == "OUTCOME_OK" else "error" + content_parts.append(f"\n```{label}\n{output}\n```\n") + + # 处理 thought(思考内容) + elif part.get("thought", False) and "text" in part: + reasoning_parts.append(part["text"]) + + # 处理普通文本(非思考内容) + elif "text" in part and not part.get("thought", False): + # 这部分已经在 extract_tool_calls_from_parts 中处理 + pass + + # 处理 inlineData(图片) + elif "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + content_parts.append(f"![gemini-generated-content](data:{mime_type};base64,{base64_data})") + + # 合并所有内容部分 + if content_parts: + additional_content = "\n\n".join(content_parts) + if text_content: + text_content = text_content + "\n\n" + additional_content + else: + text_content = additional_content + + # 合并 reasoning content + reasoning_content = "\n\n".join(reasoning_parts) if reasoning_parts else "" + + # 构建 delta 对象 + delta = {} + + if tool_calls: + delta["tool_calls"] = tool_calls + if text_content: + delta["content"] = text_content + elif text_content: + delta["content"] = text_content + + if reasoning_content: + delta["reasoning_content"] = reasoning_content + + # 获取 Gemini 的 finishReason + gemini_finish_reason = candidate.get("finishReason") + finish_reason = _map_finish_reason(gemini_finish_reason) + + # 只有在正常停止(STOP)且有工具调用时才设为 tool_calls + # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_calls 导致循环 + if tool_calls and gemini_finish_reason == "STOP": + finish_reason = "tool_calls" + + choices.append({ + "index": candidate.get("index", 0), + "delta": delta, + "finish_reason": finish_reason, + }) + + # 转换 usageMetadata (只在流结束时存在) + usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) + + # 构建 OpenAI 流式响应 + response_data = { + "id": response_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": choices, + } + + # 只在有 usage 数据且有 finish_reason 时添加 usage + if usage: + has_finish_reason = any(choice.get("finish_reason") for choice in choices) + if has_finish_reason: + response_data["usage"] = usage + + # 转换为 SSE 格式: "data: {json}\n\n" + return f"data: {json.dumps(response_data)}\n\n" diff --git a/src/converter/thoughtSignature_fix.py b/src/converter/thoughtSignature_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..caff9bfc41e86ba26d895cc8ff8447a1424b59a6 --- /dev/null +++ b/src/converter/thoughtSignature_fix.py @@ -0,0 +1,56 @@ +""" +thoughtSignature 处理公共模块 + +提供统一的 thoughtSignature 编码/解码功能,用于在工具调用ID中保留签名信息。 +这使得签名能够在客户端往返传输中保留,即使客户端会删除自定义字段。 +""" + +from typing import Optional, Tuple + +# 在工具调用ID中嵌入thoughtSignature的分隔符 +# 这使得签名能够在客户端往返传输中保留,即使客户端会删除自定义字段 +THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + + +def encode_tool_id_with_signature(tool_id: str, signature: Optional[str]) -> str: + """ + 将 thoughtSignature 编码到工具调用ID中,以便往返保留。 + + Args: + tool_id: 原始工具调用ID + signature: thoughtSignature(可选) + + Returns: + 编码后的工具调用ID + + Examples: + >>> encode_tool_id_with_signature("call_123", "abc") + 'call_123__thought__abc' + >>> encode_tool_id_with_signature("call_123", None) + 'call_123' + """ + if not signature: + return tool_id + return f"{tool_id}{THOUGHT_SIGNATURE_SEPARATOR}{signature}" + + +def decode_tool_id_and_signature(encoded_id: str) -> Tuple[str, Optional[str]]: + """ + 从编码的ID中提取原始工具ID和thoughtSignature。 + + Args: + encoded_id: 编码的工具调用ID + + Returns: + (原始工具ID, thoughtSignature) 元组 + + Examples: + >>> decode_tool_id_and_signature("call_123__thought__abc") + ('call_123', 'abc') + >>> decode_tool_id_and_signature("call_123") + ('call_123', None) + """ + if not encoded_id or THOUGHT_SIGNATURE_SEPARATOR not in encoded_id: + return encoded_id, None + parts = encoded_id.split(THOUGHT_SIGNATURE_SEPARATOR, 1) + return parts[0], parts[1] if len(parts) == 2 else None diff --git a/src/converter/utils.py b/src/converter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d313adc0a4f19d3c1c1403c9ad096814d5e072b --- /dev/null +++ b/src/converter/utils.py @@ -0,0 +1,237 @@ +from typing import Any, Dict + + +def extract_content_and_reasoning(parts: list) -> tuple: + """从Gemini响应部件中提取内容和推理内容 + + Args: + parts: Gemini 响应中的 parts 列表 + + Returns: + (content, reasoning_content, images): 文本内容、推理内容和图片数据的元组 + - content: 文本内容字符串 + - reasoning_content: 推理内容字符串 + - images: 图片数据列表,每个元素格式为: + { + "type": "image_url", + "image_url": { + "url": "data:{mime_type};base64,{base64_data}" + } + } + """ + content = "" + reasoning_content = "" + images = [] + + for part in parts: + # 提取文本内容 + text = part.get("text", "") + if text: + if part.get("thought", False): + reasoning_content += text + else: + content += text + + # 提取图片数据 + if "inlineData" in part: + inline_data = part["inlineData"] + mime_type = inline_data.get("mimeType", "image/png") + base64_data = inline_data.get("data", "") + images.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_data}" + } + }) + + return content, reasoning_content, images + + +async def merge_system_messages(request_body: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据兼容性模式处理请求体中的system消息 + + - 兼容性模式关闭(False):将连续的system消息合并为systemInstruction + - 兼容性模式开启(True):将所有system消息转换为user消息 + + Args: + request_body: OpenAI或Claude格式的请求体,包含messages字段 + + Returns: + 处理后的请求体 + + Example (兼容性模式关闭): + 输入: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "You are an expert in Python."}, + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "systemInstruction": { + "parts": [ + {"text": "You are a helpful assistant."}, + {"text": "You are an expert in Python."} + ] + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + + Example (兼容性模式开启): + 输入: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "messages": [ + {"role": "user", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"} + ] + } + + Example (Anthropic格式,兼容性模式关闭): + 输入: + { + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + + 输出: + { + "systemInstruction": { + "parts": [ + {"text": "You are a helpful assistant."} + ] + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + """ + from config import get_compatibility_mode_enabled + + compatibility_mode = await get_compatibility_mode_enabled() + + # 处理 Anthropic 格式的顶层 system 参数 + # Anthropic API 规范: system 是顶层参数,不在 messages 中 + system_content = request_body.get("system") + if system_content: + system_parts = [] + + if isinstance(system_content, str): + if system_content.strip(): + system_parts.append({"text": system_content}) + elif isinstance(system_content, list): + # system 可以是包含多个块的列表 + for item in system_content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text", "").strip(): + system_parts.append({"text": item["text"]}) + elif isinstance(item, str) and item.strip(): + system_parts.append({"text": item}) + + if system_parts: + if compatibility_mode: + # 兼容性模式:将 system 转换为 user 消息插入到 messages 开头 + user_system_message = { + "role": "user", + "content": system_content if isinstance(system_content, str) else + "\n".join(part["text"] for part in system_parts) + } + messages = request_body.get("messages", []) + request_body = request_body.copy() + request_body["messages"] = [user_system_message] + messages + else: + # 非兼容性模式:添加为 systemInstruction + request_body = request_body.copy() + request_body["systemInstruction"] = {"parts": system_parts} + + messages = request_body.get("messages", []) + if not messages: + return request_body + + compatibility_mode = await get_compatibility_mode_enabled() + + if compatibility_mode: + # 兼容性模式开启:将所有system消息转换为user消息 + converted_messages = [] + for message in messages: + if message.get("role") == "system": + # 创建新的消息对象,将role改为user + converted_message = message.copy() + converted_message["role"] = "user" + converted_messages.append(converted_message) + else: + converted_messages.append(message) + + result = request_body.copy() + result["messages"] = converted_messages + return result + else: + # 兼容性模式关闭:提取连续的system消息合并为systemInstruction + system_parts = [] + + # 如果已经从顶层 system 参数创建了 systemInstruction,获取现有的 parts + if "systemInstruction" in request_body: + existing_instruction = request_body.get("systemInstruction", {}) + if isinstance(existing_instruction, dict): + system_parts = existing_instruction.get("parts", []).copy() + + remaining_messages = [] + collecting_system = True + + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + + if role == "system" and collecting_system: + # 提取system消息的文本内容 + if isinstance(content, str): + if content.strip(): + system_parts.append({"text": content}) + elif isinstance(content, list): + # 处理列表格式的content + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text", "").strip(): + system_parts.append({"text": item["text"]}) + elif isinstance(item, str) and item.strip(): + system_parts.append({"text": item}) + else: + # 遇到非system消息,停止收集 + collecting_system = False + if role == "system": + # 将后续的system消息转换为user消息 + converted_message = message.copy() + converted_message["role"] = "user" + remaining_messages.append(converted_message) + else: + remaining_messages.append(message) + + # 如果没有找到任何system消息(包括顶层参数和messages中的),返回原始请求体 + if not system_parts: + return request_body + + # 构建新的请求体 + result = request_body.copy() + + # 添加或更新systemInstruction + result["systemInstruction"] = {"parts": system_parts} + + # 更新messages列表(移除已处理的system消息) + result["messages"] = remaining_messages + + return result \ No newline at end of file diff --git a/src/credential_manager.py b/src/credential_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b99feb7757a85325637bcd3ef886a58f7f6c607a --- /dev/null +++ b/src/credential_manager.py @@ -0,0 +1,510 @@ +""" +凭证管理器 +""" + +import asyncio +import time +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple + +from log import log + +from src.google_oauth_api import Credentials +from src.storage_adapter import get_storage_adapter + +class CredentialManager: + """ + 统一凭证管理器 + 所有存储操作通过storage_adapter进行 + """ + + def __init__(self): + # 核心状态 + self._initialized = False + self._storage_adapter = None + + # 并发控制(简化) + # 后端数据库自行处理并发,credential_manager 不再使用本地锁 + + async def _ensure_initialized(self): + """确保管理器已初始化(内部使用)""" + if not self._initialized or self._storage_adapter is None: + await self.initialize() + + async def initialize(self): + """初始化凭证管理器""" + if self._initialized and self._storage_adapter is not None: + return + + # 初始化统一存储适配器 + self._storage_adapter = await get_storage_adapter() + self._initialized = True + + async def close(self): + """清理资源""" + log.debug("Closing credential manager...") + self._initialized = False + log.debug("Credential manager closed") + + async def get_valid_credential( + self, mode: str = "geminicli", model_name: Optional[str] = None + ) -> Optional[Tuple[str, Dict[str, Any]]]: + """ + 获取有效的凭证 - 随机负载均衡版 + 每次随机选择一个可用的凭证(未禁用、未冷却、符合preview要求) + 如果刷新失败会自动禁用失效凭证并重试获取下一个可用凭证 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_name: 完整模型名,用于模型级冷却检查和preview筛选 + - geminicli: 完整模型名 + - 包含 "preview" 的模型只能使用 preview=True 的凭证 + - 不包含 "preview" 的模型优先使用 preview=False 的凭证 + - antigravity: 完整模型名(如 "gemini-2.0-flash-exp") + """ + await self._ensure_initialized() + + # 最多重试3次 + max_retries = 3 + for attempt in range(max_retries): + result = await self._storage_adapter._backend.get_next_available_credential( + mode=mode, model_name=model_name + ) + + # 如果没有可用凭证,直接返回None + if not result: + if attempt == 0: + log.warning(f"没有可用凭证 (mode={mode}, model_name={model_name})") + return None + + filename, credential_data = result + + # Token 刷新检查 + if await self._should_refresh_token(credential_data): + log.debug(f"Token需要刷新 - 文件: {filename} (mode={mode})") + refreshed_data = await self._refresh_token(credential_data, filename, mode=mode) + if refreshed_data: + # 刷新成功,返回凭证 + credential_data = refreshed_data + log.debug(f"Token刷新成功: {filename} (mode={mode})") + return filename, credential_data + else: + # 刷新失败(_refresh_token内部已自动禁用失效凭证) + log.warning(f"Token刷新失败,尝试获取下一个凭证: {filename} (mode={mode}, attempt={attempt+1}/{max_retries})") + # 继续循环,尝试获取下一个可用凭证 + continue + else: + # Token有效,直接返回 + return filename, credential_data + + # 重试次数用尽 + log.error(f"重试{max_retries}次后仍无可用凭证 (mode={mode}, model_name={model_name})") + return None + + async def add_credential(self, credential_name: str, credential_data: Dict[str, Any]): + """ + 新增或更新一个凭证 + 存储层会自动处理轮换顺序 + """ + await self._ensure_initialized() + await self._storage_adapter.store_credential(credential_name, credential_data) + log.info(f"Credential added/updated: {credential_name}") + + async def add_antigravity_credential(self, credential_name: str, credential_data: Dict[str, Any]): + """ + 新增或更新一个Antigravity凭证 + 存储层会自动处理轮换顺序 + """ + await self._ensure_initialized() + await self._storage_adapter.store_credential(credential_name, credential_data, mode="antigravity") + log.info(f"Antigravity credential added/updated: {credential_name}") + + async def remove_credential(self, credential_name: str, mode: str = "geminicli") -> bool: + """删除一个凭证""" + await self._ensure_initialized() + try: + await self._storage_adapter.delete_credential(credential_name, mode=mode) + log.info(f"Credential removed: {credential_name} (mode={mode})") + return True + except Exception as e: + log.error(f"Error removing credential {credential_name}: {e}") + return False + + async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any], mode: str = "geminicli"): + """更新凭证状态""" + log.debug(f"[CredMgr] update_credential_state 开始: credential_name={credential_name}, state_updates={state_updates}, mode={mode}") + log.debug(f"[CredMgr] 调用 _ensure_initialized...") + await self._ensure_initialized() + log.debug(f"[CredMgr] _ensure_initialized 完成") + try: + log.debug(f"[CredMgr] 调用 storage_adapter.update_credential_state...") + success = await self._storage_adapter.update_credential_state( + credential_name, state_updates, mode=mode + ) + log.debug(f"[CredMgr] storage_adapter.update_credential_state 返回: {success}") + if success: + log.debug(f"Updated credential state: {credential_name} (mode={mode})") + else: + log.warning(f"Failed to update credential state: {credential_name} (mode={mode})") + return success + except Exception as e: + log.error(f"Error updating credential state {credential_name}: {e}") + return False + + async def set_cred_disabled(self, credential_name: str, disabled: bool, mode: str = "geminicli"): + """设置凭证的启用/禁用状态""" + try: + log.info(f"[CredMgr] set_cred_disabled 开始: credential_name={credential_name}, disabled={disabled}, mode={mode}") + success = await self.update_credential_state( + credential_name, {"disabled": disabled}, mode=mode + ) + log.info(f"[CredMgr] update_credential_state 返回: success={success}") + if success: + action = "disabled" if disabled else "enabled" + log.info(f"Credential {action}: {credential_name} (mode={mode})") + else: + log.warning(f"[CredMgr] 设置禁用状态失败: credential_name={credential_name}, disabled={disabled}") + return success + except Exception as e: + log.error(f"Error setting credential disabled state {credential_name}: {e}") + return False + + async def get_creds_status(self) -> Dict[str, Dict[str, Any]]: + """获取所有凭证的状态""" + await self._ensure_initialized() + try: + return await self._storage_adapter.get_all_credential_states() + except Exception as e: + log.error(f"Error getting credential statuses: {e}") + return {} + + async def get_creds_summary(self) -> List[Dict[str, Any]]: + """ + 获取所有凭证的摘要信息(轻量级,不包含完整凭证数据) + 使用后端的高性能查询 + """ + await self._ensure_initialized() + try: + return await self._storage_adapter._backend.get_credentials_summary() + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return [] + + async def get_or_fetch_user_email(self, credential_name: str, mode: str = "geminicli") -> Optional[str]: + """获取或获取用户邮箱地址""" + try: + # 确保已初始化 + await self._ensure_initialized() + + # 从状态中获取缓存的邮箱 + state = await self._storage_adapter.get_credential_state(credential_name, mode=mode) + cached_email = state.get("user_email") if state else None + + if cached_email: + return cached_email + + # 如果没有缓存,从凭证数据获取 + credential_data = await self._storage_adapter.get_credential(credential_name, mode=mode) + if not credential_data: + return None + + # 创建凭证对象并自动刷新 token + from .google_oauth_api import Credentials, get_user_email + + credentials = Credentials.from_dict(credential_data) + if not credentials: + return None + + # 自动刷新 token(如果需要) + token_refreshed = await credentials.refresh_if_needed() + + # 如果 token 被刷新了,更新存储 + if token_refreshed: + log.info(f"Token已自动刷新: {credential_name} (mode={mode})") + updated_data = credentials.to_dict() + await self._storage_adapter.store_credential(credential_name, updated_data, mode=mode) + + # 获取邮箱 + email = await get_user_email(credentials) + + if email: + # 缓存邮箱地址 + await self._storage_adapter.update_credential_state( + credential_name, {"user_email": email}, mode=mode + ) + return email + + return None + + except Exception as e: + log.error(f"Error fetching user email for {credential_name}: {e}") + return None + + async def record_api_call_result( + self, + credential_name: str, + success: bool, + error_code: Optional[int] = None, + cooldown_until: Optional[float] = None, + mode: str = "geminicli", + model_name: Optional[str] = None, + error_message: Optional[str] = None + ): + """ + 记录API调用结果 + + Args: + credential_name: 凭证名称 + success: 是否成功 + error_code: 错误码(如果失败) + cooldown_until: 冷却截止时间戳(Unix时间戳,针对429 QUOTA_EXHAUSTED) + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_name: 模型名(用于设置模型级冷却) + error_message: 错误信息(如果失败) + """ + await self._ensure_initialized() + try: + if success: + # 条件写入:仅当凭证有错误状态或模型冷却时才写 DB,零内存缓存 + # fire-and-forget,不阻塞请求链路 + asyncio.create_task( + self._storage_adapter._backend.record_success( + credential_name, model_name=model_name, mode=mode + ) + ) + + elif error_code: + # 记录错误码和错误信息 + error_messages = {} + if error_message: + error_messages[str(error_code)] = error_message + + state_updates = { + "error_codes": [error_code], + "error_messages": error_messages, + } + + await self.update_credential_state(credential_name, state_updates, mode=mode) + + # 设置模型级冷却 + if cooldown_until is not None and model_name: + if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): + await self._storage_adapter._backend.set_model_cooldown( + credential_name, model_name, cooldown_until, mode=mode + ) + log.info( + f"设置模型级冷却: {credential_name}, model_name={model_name}, " + f"冷却至: {datetime.fromtimestamp(cooldown_until, timezone.utc).isoformat()}" + ) + + except Exception as e: + log.error(f"Error recording API call result for {credential_name}: {e}") + + async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: + """检查token是否需要刷新""" + try: + # 如果没有access_token或过期时间,需要刷新 + if not credential_data.get("access_token") and not credential_data.get("token"): + log.debug("没有access_token,需要刷新") + return True + + expiry_str = credential_data.get("expiry") + if not expiry_str: + log.debug("没有过期时间,需要刷新") + return True + + # 解析过期时间 + try: + if isinstance(expiry_str, str): + if "+" in expiry_str: + file_expiry = datetime.fromisoformat(expiry_str) + elif expiry_str.endswith("Z"): + file_expiry = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + else: + file_expiry = datetime.fromisoformat(expiry_str) + else: + log.debug("过期时间格式无效,需要刷新") + return True + + # 确保时区信息 + if file_expiry.tzinfo is None: + file_expiry = file_expiry.replace(tzinfo=timezone.utc) + + # 检查是否还有至少5分钟有效期 + now = datetime.now(timezone.utc) + time_left = (file_expiry - now).total_seconds() + + log.debug( + f"Token时间检查: " + f"当前UTC时间={now.isoformat()}, " + f"过期时间={file_expiry.isoformat()}, " + f"剩余时间={int(time_left/60)}分{int(time_left%60)}秒" + ) + + if time_left > 300: # 5分钟缓冲 + return False + else: + log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新") + return True + + except Exception as e: + log.warning(f"解析过期时间失败: {e},需要刷新") + return True + + except Exception as e: + log.error(f"检查token过期时出错: {e}") + return True + + async def _refresh_token( + self, credential_data: Dict[str, Any], filename: str, mode: str = "geminicli" + ) -> Optional[Dict[str, Any]]: + """刷新token并更新存储""" + await self._ensure_initialized() + try: + # 创建Credentials对象 + creds = Credentials.from_dict(credential_data) + + # 检查是否可以刷新 + if not creds.refresh_token: + log.error(f"没有refresh_token,无法刷新: {filename} (mode={mode})") + # 自动禁用没有refresh_token的凭证 + try: + await self.update_credential_state(filename, {"disabled": True}, mode=mode) + log.warning(f"凭证已自动禁用(缺少refresh_token): {filename}") + except Exception as e: + log.error(f"禁用凭证失败 {filename}: {e}") + return None + + # 刷新token + log.debug(f"正在刷新token: {filename} (mode={mode})") + await creds.refresh() + + # 更新凭证数据 + if creds.access_token: + credential_data["access_token"] = creds.access_token + # 保持兼容性 + credential_data["token"] = creds.access_token + + if creds.expires_at: + credential_data["expiry"] = creds.expires_at.isoformat() + + # 保存到存储 + await self._storage_adapter.store_credential(filename, credential_data, mode=mode) + log.info(f"Token刷新成功并已保存: {filename} (mode={mode})") + + return credential_data + + except Exception as e: + error_msg = str(e) + log.error(f"Token刷新失败 {filename} (mode={mode}): {error_msg}") + + # 尝试提取HTTP状态码(TokenError可能携带status_code属性) + status_code = None + if hasattr(e, 'status_code'): + status_code = e.status_code + + # 检查是否是凭证永久失效的错误(只有明确的400/403等才判定为永久失效) + is_permanent_failure = self._is_permanent_refresh_failure(error_msg, status_code) + + if is_permanent_failure: + log.warning(f"检测到凭证永久失效 (HTTP {status_code}): {filename}") + # 记录失效状态 + if status_code: + await self.record_api_call_result(filename, False, status_code, mode=mode) + else: + await self.record_api_call_result(filename, False, 400, mode=mode) + + # 禁用失效凭证 + try: + # 直接禁用该凭证(随机选择机制会自动跳过它) + disabled_ok = await self.update_credential_state(filename, {"disabled": True}, mode=mode) + if disabled_ok: + log.warning(f"永久失效凭证已禁用: {filename}") + else: + log.warning("永久失效凭证禁用失败,将由上层逻辑继续处理") + except Exception as e2: + log.error(f"禁用永久失效凭证时出错 {filename}: {e2}") + else: + # 网络错误或其他临时性错误,不封禁凭证 + log.warning(f"Token刷新失败但非永久性错误 (HTTP {status_code}),不封禁凭证: {filename}") + + return None + + def _is_permanent_refresh_failure(self, error_msg: str, status_code: Optional[int] = None) -> bool: + """ + 判断是否是凭证永久失效的错误 + + Args: + error_msg: 错误信息 + status_code: HTTP状态码(如果有) + + Returns: + True表示凭证永久失效应封禁,False表示临时错误不应封禁 + """ + # 优先使用HTTP状态码判断 + if status_code is not None: + # 400/401/403 明确表示凭证有问题,应该封禁 + if status_code in [400, 401, 403]: + log.debug(f"检测到客户端错误状态码 {status_code},判定为永久失效") + return True + # 500/502/503/504 是服务器错误,不应封禁凭证 + elif status_code in [500, 502, 503, 504]: + log.debug(f"检测到服务器错误状态码 {status_code},不应封禁凭证") + return False + # 429 (限流) 不应封禁凭证 + elif status_code == 429: + log.debug("检测到限流错误 429,不应封禁凭证") + return False + + # 如果没有状态码,回退到错误信息匹配(谨慎判断) + # 只有明确的凭证失效错误才判定为永久失效 + permanent_error_patterns = [ + "invalid_grant", + "refresh_token_expired", + "invalid_refresh_token", + "unauthorized_client", + "access_denied", + ] + + error_msg_lower = error_msg.lower() + for pattern in permanent_error_patterns: + if pattern.lower() in error_msg_lower: + log.debug(f"错误信息匹配到永久失效模式: {pattern}") + return True + + # 默认认为是临时错误(如网络问题),不应封禁凭证 + log.debug("未匹配到明确的永久失效模式,判定为临时错误") + return False + +class _CredentialManagerSingleton: + """单例包装器,支持懒加载和自动初始化""" + + _instance: Optional[CredentialManager] = None + _lock = None + + def __init__(self): + self._manager = None + + async def _get_or_create(self) -> CredentialManager: + """获取或创建单例实例(线程安全)""" + if self._instance is None: + # 简单的实例创建(异步环境下一般不需要复杂的锁) + if self._instance is None: + self._instance = CredentialManager() + await self._instance.initialize() + log.debug("CredentialManager singleton initialized") + + return self._instance + + def __getattr__(self, name): + """代理所有方法调用到真实的 CredentialManager 实例""" + async def _async_wrapper(*args, **kwargs): + manager = await self._get_or_create() + method = getattr(manager, name) + return await method(*args, **kwargs) + + return _async_wrapper + + +# 全局单例实例 - 直接导入即可使用 +credential_manager = _CredentialManagerSingleton() diff --git a/src/google_oauth_api.py b/src/google_oauth_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f2681ddfe36bcbd7eab7f8353fd7921683bbb89d --- /dev/null +++ b/src/google_oauth_api.py @@ -0,0 +1,852 @@ +""" +Google OAuth2 认证模块 +""" + +import time +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlencode + +import jwt + +from config import ( + get_googleapis_proxy_url, + get_oauth_proxy_url, + get_resource_manager_api_url, + get_service_usage_api_url, +) +from log import log + +from src.httpx_client import get_async, post_async + + +class TokenError(Exception): + """Token相关错误""" + + pass + + +class Credentials: + """凭证类""" + + def __init__( + self, + access_token: str, + refresh_token: str = None, + client_id: str = None, + client_secret: str = None, + expires_at: datetime = None, + project_id: str = None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = client_id + self.client_secret = client_secret + self.expires_at = expires_at + self.project_id = project_id + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + + def is_expired(self) -> bool: + """检查token是否过期""" + if not self.expires_at: + return True + + # 提前3分钟认为过期 + buffer = timedelta(minutes=3) + return (self.expires_at - buffer) <= datetime.now(timezone.utc) + + async def refresh_if_needed(self) -> bool: + """如果需要则刷新token""" + if not self.is_expired(): + return False + + if not self.refresh_token: + raise TokenError("需要刷新令牌但未提供") + + await self.refresh() + return True + + async def refresh(self): + """刷新访问令牌""" + if not self.refresh_token: + raise TokenError("无刷新令牌") + + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + } + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + + token_data = response.json() + self.access_token = token_data["access_token"] + + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + current_utc = datetime.now(timezone.utc) + self.expires_at = current_utc + timedelta(seconds=expires_in) + log.debug( + f"Token刷新: 当前UTC时间={current_utc.isoformat()}, " + f"有效期={expires_in}秒, " + f"过期时间={self.expires_at.isoformat()}" + ) + + if "refresh_token" in token_data: + self.refresh_token = token_data["refresh_token"] + + log.debug(f"Token刷新成功,过期时间: {self.expires_at}") + + except Exception as e: + error_msg = str(e) + status_code = None + if hasattr(e, 'response') and hasattr(e.response, 'status_code'): + status_code = e.response.status_code + error_msg = f"Token刷新失败 (HTTP {status_code}): {error_msg}" + else: + error_msg = f"Token刷新失败: {error_msg}" + + log.error(error_msg) + token_error = TokenError(error_msg) + token_error.status_code = status_code + raise token_error + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Credentials": + """从字典创建凭证""" + # 处理过期时间 + expires_at = None + if "expiry" in data and data["expiry"]: + try: + expiry_str = data["expiry"] + if isinstance(expiry_str, str): + if expiry_str.endswith("Z"): + expires_at = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + elif "+" in expiry_str: + expires_at = datetime.fromisoformat(expiry_str) + else: + expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc) + except ValueError: + log.warning(f"无法解析过期时间: {expiry_str}") + + return cls( + access_token=data.get("token") or data.get("access_token", ""), + refresh_token=data.get("refresh_token"), + client_id=data.get("client_id"), + client_secret=data.get("client_secret"), + expires_at=expires_at, + project_id=data.get("project_id"), + ) + + def to_dict(self) -> Dict[str, Any]: + """转为字典""" + result = { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "client_id": self.client_id, + "client_secret": self.client_secret, + "project_id": self.project_id, + } + + if self.expires_at: + result["expiry"] = self.expires_at.isoformat() + + return result + + +class Flow: + """OAuth流程类""" + + def __init__( + self, client_id: str, client_secret: str, scopes: List[str], redirect_uri: str = None + ): + self.client_id = client_id + self.client_secret = client_secret + self.scopes = scopes + self.redirect_uri = redirect_uri + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + self.auth_endpoint = "https://accounts.google.com/o/oauth2/auth" + + self.credentials: Optional[Credentials] = None + + def get_auth_url(self, state: str = None, **kwargs) -> str: + """生成授权URL""" + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "response_type": "code", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + if state: + params["state"] = state + + params.update(kwargs) + return f"{self.auth_endpoint}?{urlencode(params)}" + + async def exchange_code(self, code: str) -> Credentials: + """用授权码换取token""" + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": self.redirect_uri, + "code": code, + "grant_type": "authorization_code", + } + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + response.raise_for_status() + + token_data = response.json() + + # 计算过期时间 + expires_at = None + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + # 创建凭证对象 + self.credentials = Credentials( + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + client_id=self.client_id, + client_secret=self.client_secret, + expires_at=expires_at, + ) + + return self.credentials + + except Exception as e: + error_msg = f"获取token失败: {str(e)}" + log.error(error_msg) + raise TokenError(error_msg) + + +class ServiceAccount: + """Service Account类""" + + def __init__( + self, email: str, private_key: str, project_id: str = None, scopes: List[str] = None + ): + self.email = email + self.private_key = private_key + self.project_id = project_id + self.scopes = scopes or [] + + # 反代配置将在使用时异步获取 + self.oauth_base_url = None + self.token_endpoint = None + + self.access_token: Optional[str] = None + self.expires_at: Optional[datetime] = None + + def is_expired(self) -> bool: + """检查token是否过期""" + if not self.expires_at: + return True + + buffer = timedelta(minutes=3) + return (self.expires_at - buffer) <= datetime.now(timezone.utc) + + def create_jwt(self) -> str: + """创建JWT令牌""" + now = int(time.time()) + + payload = { + "iss": self.email, + "scope": " ".join(self.scopes) if self.scopes else "", + "aud": self.token_endpoint, + "exp": now + 3600, + "iat": now, + } + + return jwt.encode(payload, self.private_key, algorithm="RS256") + + async def get_access_token(self) -> str: + """获取访问令牌""" + if not self.is_expired() and self.access_token: + return self.access_token + + assertion = self.create_jwt() + + data = {"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": assertion} + + try: + oauth_base_url = await get_oauth_proxy_url() + token_url = f"{oauth_base_url.rstrip('/')}/token" + response = await post_async( + token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + response.raise_for_status() + + token_data = response.json() + self.access_token = token_data["access_token"] + + if "expires_in" in token_data: + expires_in = int(token_data["expires_in"]) + self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + return self.access_token + + except Exception as e: + error_msg = f"Service Account获取token失败: {str(e)}" + log.error(error_msg) + raise TokenError(error_msg) + + @classmethod + def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> "ServiceAccount": + """从字典创建Service Account凭证""" + return cls( + email=data["client_email"], + private_key=data["private_key"], + project_id=data.get("project_id"), + scopes=scopes, + ) + + +# 工具函数 +async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]: + """获取用户信息""" + await credentials.refresh_if_needed() + + try: + googleapis_base_url = await get_googleapis_proxy_url() + userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo" + response = await get_async( + userinfo_url, headers={"Authorization": f"Bearer {credentials.access_token}"} + ) + response.raise_for_status() + return response.json() + except Exception as e: + log.error(f"获取用户信息失败: {e}") + return None + + +async def get_user_email(credentials: Credentials) -> Optional[str]: + """获取用户邮箱地址""" + try: + # 确保凭证有效 + await credentials.refresh_if_needed() + + # 调用Google userinfo API获取邮箱 + user_info = await get_user_info(credentials) + if user_info: + email = user_info.get("email") + if email: + log.info(f"成功获取邮箱地址: {email}") + return email + else: + log.warning(f"userinfo响应中没有邮箱信息: {user_info}") + return None + else: + log.warning("获取用户信息失败") + return None + + except Exception as e: + log.error(f"获取用户邮箱失败: {e}") + return None + + +async def fetch_user_email_from_file(cred_data: Dict[str, Any]) -> Optional[str]: + """从凭证数据获取用户邮箱地址(支持统一存储)""" + try: + # 直接从凭证数据创建凭证对象 + credentials = Credentials.from_dict(cred_data) + if not credentials or not credentials.access_token: + log.warning("无法从凭证数据创建凭证对象或获取访问令牌") + return None + + # 获取邮箱 + return await get_user_email(credentials) + + except Exception as e: + log.error(f"从凭证数据获取用户邮箱失败: {e}") + return None + + +async def validate_token(token: str) -> Optional[Dict[str, Any]]: + """验证访问令牌""" + try: + oauth_base_url = await get_oauth_proxy_url() + tokeninfo_url = f"{oauth_base_url.rstrip('/')}/tokeninfo?access_token={token}" + + response = await get_async(tokeninfo_url) + response.raise_for_status() + return response.json() + except Exception as e: + log.error(f"验证令牌失败: {e}") + return None + + +async def enable_required_apis(credentials: Credentials, project_id: str) -> bool: + """自动启用必需的API服务""" + try: + # 确保凭证有效 + if credentials.is_expired() and credentials.refresh_token: + await credentials.refresh() + + headers = { + "Authorization": f"Bearer {credentials.access_token}", + "Content-Type": "application/json", + "User-Agent": "geminicli-oauth/1.0", + } + + # 需要启用的服务列表 + required_services = [ + "geminicloudassist.googleapis.com", # Gemini Cloud Assist API + "cloudaicompanion.googleapis.com", # Gemini for Google Cloud API + ] + + for service in required_services: + log.info(f"正在检查并启用服务: {service}") + + # 检查服务是否已启用 + service_usage_base_url = await get_service_usage_api_url() + check_url = ( + f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}" + ) + try: + check_response = await get_async(check_url, headers=headers) + if check_response.status_code == 200: + service_data = check_response.json() + if service_data.get("state") == "ENABLED": + log.info(f"服务 {service} 已启用") + continue + except Exception as e: + log.debug(f"检查服务状态失败,将尝试启用: {e}") + + # 启用服务 + enable_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}:enable" + try: + enable_response = await post_async(enable_url, headers=headers, json={}) + + if enable_response.status_code in [200, 201]: + log.info(f"✅ 成功启用服务: {service}") + elif enable_response.status_code == 400: + error_data = enable_response.json() + if "already enabled" in error_data.get("error", {}).get("message", "").lower(): + log.info(f"✅ 服务 {service} 已经启用") + else: + log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}") + else: + log.warning( + f"⚠️ 启用服务 {service} 失败: {enable_response.status_code} - {enable_response.text}" + ) + + except Exception as e: + log.warning(f"⚠️ 启用服务 {service} 时发生异常: {e}") + + return True + + except Exception as e: + log.error(f"启用API服务时发生错误: {e}") + return False + + +async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: + """获取用户可访问的Google Cloud项目列表""" + try: + # 确保凭证有效 + if credentials.is_expired() and credentials.refresh_token: + await credentials.refresh() + + headers = { + "Authorization": f"Bearer {credentials.access_token}", + "User-Agent": "geminicli-oauth/1.0", + } + + # 使用Resource Manager API的正确域名和端点 + resource_manager_base_url = await get_resource_manager_api_url() + url = f"{resource_manager_base_url.rstrip('/')}/v1/projects" + log.info(f"正在调用API: {url}") + response = await get_async(url, headers=headers) + + log.info(f"API响应状态码: {response.status_code}") + if response.status_code != 200: + log.error(f"API响应内容: {response.text}") + + if response.status_code == 200: + data = response.json() + projects = data.get("projects", []) + # 只返回活跃的项目 + active_projects = [ + project for project in projects if project.get("lifecycleState") == "ACTIVE" + ] + log.info(f"获取到 {len(active_projects)} 个活跃项目") + return active_projects + else: + log.warning(f"获取项目列表失败: {response.status_code} - {response.text}") + return [] + + except Exception as e: + log.error(f"获取用户项目列表失败: {e}") + return [] + + +async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str]: + """从项目列表中选择默认项目""" + if not projects: + return None + + # 策略1:查找显示名称或项目ID包含"default"的项目 + for project in projects: + display_name = project.get("displayName", "").lower() + # Google API returns projectId in camelCase + project_id = project.get("projectId", "") + if "default" in display_name or "default" in project_id.lower(): + log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})") + return project_id + + # 策略2:选择第一个项目 + first_project = projects[0] + # Google API returns projectId in camelCase + project_id = first_project.get("projectId", "") + log.info( + f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})" + ) + return project_id + + +async def fetch_project_id_and_tier( + access_token: str, + user_agent: str, + api_base_url: str, + include_credits: bool = False, +) -> Tuple[Optional[str], Optional[str]] | Tuple[Optional[str], Optional[str], Optional[int]]: + """ + 从 API 获取 project_id 和订阅等级 + + Args: + access_token: Google OAuth access token + user_agent: User-Agent header + api_base_url: API base URL (e.g., antigravity or code assist endpoint) + + Returns: + 默认返回 (project_id, subscription_tier) + 当 include_credits=True 时返回 (project_id, subscription_tier, credit_amount) + subscription_tier 可能是 "FREE", "PRO", "ULTRA" 或 None + credit_amount 为积分数量(整数)或 None + """ + headers = { + 'User-Agent': user_agent, + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json', + 'Accept-Encoding': 'gzip' + } + + def _map_raw_tier(raw_tier: Optional[str]) -> Optional[str]: + """将 loadCodeAssist 返回的 raw tier 映射为统一 tier。""" + if not raw_tier: + return None + + tier_mapping = { + "g1-ultra-tier": "ultra", + "ws-ai-ultra-business-tier": "ultra", + "g1-pro-tier": "pro", + "helium-tier": "pro", + "standard-tier": "pro", + "free-tier": "free", + } + + return tier_mapping.get(raw_tier.lower(), "pro") + + subscription_tier = None + credit_amount: Optional[int] = None + + # 步骤 1: 尝试 loadCodeAssist + try: + project_id, raw_tier, raw_credit_amount = await _try_load_code_assist(api_base_url, headers) + subscription_tier = _map_raw_tier(raw_tier) + + if raw_credit_amount is not None: + try: + credit_amount = int(raw_credit_amount) + log.info( + f"[fetch_project_id_and_tier] Found credit_amount: {credit_amount}" + ) + except (TypeError, ValueError): + log.warning( + f"[fetch_project_id_and_tier] Invalid credit_amount: {raw_credit_amount}" + ) + + if raw_tier: + log.info( + f"[fetch_project_id_and_tier] Raw tier '{raw_tier}' mapped to '{subscription_tier}'" + ) + + if project_id: + if include_credits: + return project_id, subscription_tier, credit_amount + return project_id, subscription_tier + + log.warning("[fetch_project_id_and_tier] loadCodeAssist did not return project_id, falling back to onboardUser") + + except Exception as e: + log.warning(f"[fetch_project_id_and_tier] loadCodeAssist failed: {type(e).__name__}: {e}") + log.warning("[fetch_project_id_and_tier] Falling back to onboardUser") + + # 步骤 2: 回退到 onboardUser + try: + project_id = await _try_onboard_user(api_base_url, headers) + if project_id: + if include_credits: + return project_id, subscription_tier, credit_amount + return project_id, subscription_tier + + log.error("[fetch_project_id_and_tier] Failed to get project_id from both loadCodeAssist and onboardUser") + if include_credits: + return None, subscription_tier, credit_amount + return None, subscription_tier + + except Exception as e: + log.error(f"[fetch_project_id_and_tier] onboardUser failed: {type(e).__name__}: {e}") + import traceback + log.debug(f"[fetch_project_id_and_tier] Traceback: {traceback.format_exc()}") + if include_credits: + return None, subscription_tier, credit_amount + return None, subscription_tier + + +async def _try_load_code_assist( + api_base_url: str, + headers: dict +) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """ + 尝试通过 loadCodeAssist 获取 project_id 和订阅等级 + + Returns: + (project_id, subscription_tier, credit_amount) 元组 + subscription_tier 可能是 "FREE", "PRO", "ULTRA" 或 None + credit_amount 为字符串格式积分或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" + request_body = { + "metadata": { + "ideType": "ANTIGRAVITY" + } + } + + log.debug(f"[loadCodeAssist] Fetching project_id from: {request_url}") + log.debug(f"[loadCodeAssist] Request body: {request_body}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + log.debug(f"[loadCodeAssist] Response status: {response.status_code}") + + if response.status_code == 200: + response_text = response.text + log.debug(f"[loadCodeAssist] Response body: {response_text}") + + data = response.json() + log.debug(f"[loadCodeAssist] Response JSON keys: {list(data.keys())}") + + # 提取订阅等级 - 优先使用 paidTier(更准确反映实际权益) + paid_tier = data.get("paidTier", {}) + current_tier = data.get("currentTier", {}) + available_credits = paid_tier.get("availableCredits", []) if isinstance(paid_tier, dict) else [] + + # paidTier.id 优先,然后是 currentTier.id + subscription_tier = None + if isinstance(paid_tier, dict) and paid_tier.get("id"): + subscription_tier = paid_tier.get("id") + log.info(f"[loadCodeAssist] Found paidTier: {subscription_tier}") + elif isinstance(current_tier, dict) and current_tier.get("id"): + subscription_tier = current_tier.get("id") + log.info(f"[loadCodeAssist] Found currentTier: {subscription_tier}") + + # 提取积分数量(如果返回了 availableCredits) + credit_amount = None + if isinstance(available_credits, list) and available_credits: + first_credit = available_credits[0] + if isinstance(first_credit, dict): + credit_amount = first_credit.get("creditAmount") + if credit_amount is not None: + log.info(f"[loadCodeAssist] Found creditAmount: {credit_amount}") + + # 检查是否有 currentTier(表示用户已激活) + if current_tier: + log.info("[loadCodeAssist] User is already activated") + + # 使用服务器返回的 project_id + project_id = data.get("cloudaicompanionProject") + if project_id: + log.info(f"[loadCodeAssist] Successfully fetched project_id: {project_id}, tier: {subscription_tier}") + return project_id, subscription_tier, credit_amount + + log.warning("[loadCodeAssist] No project_id in response") + return None, subscription_tier, credit_amount + else: + log.info("[loadCodeAssist] User not activated yet (no currentTier)") + return None, None, credit_amount + else: + log.warning(f"[loadCodeAssist] Failed: HTTP {response.status_code}") + log.warning(f"[loadCodeAssist] Response body: {response.text[:500]}") + raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") + + +async def _try_onboard_user( + api_base_url: str, + headers: dict +) -> Optional[str]: + """ + 尝试通过 onboardUser 获取 project_id(长时间运行操作,需要轮询) + + Returns: + project_id 或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:onboardUser" + + # 首先需要获取用户的 tier 信息 + tier_id = await _get_onboard_tier(api_base_url, headers) + if not tier_id: + log.error("[onboardUser] Failed to determine user tier") + return None + + log.info(f"[onboardUser] User tier: {tier_id}") + + # 构造 onboardUser 请求 + # 注意:FREE tier 不应该包含 cloudaicompanionProject + request_body = { + "tierId": tier_id, + "metadata": { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI" + } + } + + log.debug(f"[onboardUser] Request URL: {request_url}") + log.debug(f"[onboardUser] Request body: {request_body}") + + # onboardUser 是长时间运行操作,需要轮询 + # 最多等待 10 秒(5 次 * 2 秒) + max_attempts = 5 + attempt = 0 + + while attempt < max_attempts: + attempt += 1 + log.debug(f"[onboardUser] Polling attempt {attempt}/{max_attempts}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + log.debug(f"[onboardUser] Response status: {response.status_code}") + + if response.status_code == 200: + data = response.json() + log.debug(f"[onboardUser] Response data: {data}") + + # 检查长时间运行操作是否完成 + if data.get("done"): + log.info("[onboardUser] Operation completed") + + # 从响应中提取 project_id + response_data = data.get("response", {}) + project_obj = response_data.get("cloudaicompanionProject", {}) + + if isinstance(project_obj, dict): + project_id = project_obj.get("id") + elif isinstance(project_obj, str): + project_id = project_obj + else: + project_id = None + + if project_id: + log.info(f"[onboardUser] Successfully fetched project_id: {project_id}") + return project_id + else: + log.warning("[onboardUser] Operation completed but no project_id in response") + return None + else: + log.debug("[onboardUser] Operation still in progress, waiting 2 seconds...") + await asyncio.sleep(2) + else: + log.warning(f"[onboardUser] Failed: HTTP {response.status_code}") + log.warning(f"[onboardUser] Response body: {response.text[:500]}") + raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") + + log.error("[onboardUser] Timeout: Operation did not complete within 10 seconds") + return None + + +async def _get_onboard_tier( + api_base_url: str, + headers: dict +) -> Optional[str]: + """ + 从 loadCodeAssist 响应中获取用户应该注册的 tier + + Returns: + tier_id (如 "FREE", "STANDARD", "LEGACY") 或 None + """ + request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" + request_body = { + "metadata": { + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI" + } + } + + log.debug(f"[_get_onboard_tier] Fetching tier info from: {request_url}") + + response = await post_async( + request_url, + json=request_body, + headers=headers, + timeout=30.0, + ) + + if response.status_code == 200: + data = response.json() + log.debug(f"[_get_onboard_tier] Response data: {data}") + + # 查找默认的 tier + allowed_tiers = data.get("allowedTiers", []) + for tier in allowed_tiers: + if tier.get("isDefault"): + tier_id = tier.get("id") + log.info(f"[_get_onboard_tier] Found default tier: {tier_id}") + return tier_id + + # 如果没有默认 tier,使用 LEGACY 作为回退 + log.warning("[_get_onboard_tier] No default tier found, using LEGACY") + return "LEGACY" + else: + log.error(f"[_get_onboard_tier] Failed to fetch tier info: HTTP {response.status_code}") + return None + + diff --git a/src/httpx_client.py b/src/httpx_client.py new file mode 100644 index 0000000000000000000000000000000000000000..d04007e2e09e1113f2ef989d204a42b50323664a --- /dev/null +++ b/src/httpx_client.py @@ -0,0 +1,121 @@ +""" +通用的HTTP客户端模块 +为所有需要使用httpx的模块提供统一的客户端配置和方法 +保持通用性,不与特定业务逻辑耦合 +""" + +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, Optional + +import httpx + +from config import get_proxy_config +from log import log + + +class HttpxClientManager: + """通用HTTP客户端管理器""" + + async def get_client_kwargs(self, timeout: float = 30.0, **kwargs) -> Dict[str, Any]: + """获取httpx客户端的通用配置参数""" + client_kwargs = {"timeout": timeout, **kwargs} + + # 动态读取代理配置,支持热更新 + current_proxy_config = await get_proxy_config() + if current_proxy_config: + client_kwargs["proxy"] = current_proxy_config + + return client_kwargs + + @asynccontextmanager + async def get_client( + self, timeout: float = 30.0, **kwargs + ) -> AsyncGenerator[httpx.AsyncClient, None]: + """获取配置好的异步HTTP客户端""" + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + + async with httpx.AsyncClient(**client_kwargs) as client: + yield client + + @asynccontextmanager + async def get_streaming_client( + self, timeout: float = None, **kwargs + ) -> AsyncGenerator[httpx.AsyncClient, None]: + """获取用于流式请求的HTTP客户端(无超时限制)""" + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + + # 创建独立的客户端实例用于流式处理 + client = httpx.AsyncClient(**client_kwargs) + try: + yield client + finally: + # 确保无论发生什么都关闭客户端 + try: + await client.aclose() + except Exception as e: + log.warning(f"Error closing streaming client: {e}") + + +# 全局HTTP客户端管理器实例 +http_client = HttpxClientManager() + + +# 通用的异步方法 +async def get_async( + url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, **kwargs +) -> httpx.Response: + """通用异步GET请求""" + async with http_client.get_client(timeout=timeout, **kwargs) as client: + return await client.get(url, headers=headers) + + +async def post_async( + url: str, + data: Any = None, + json: Any = None, + headers: Optional[Dict[str, str]] = None, + timeout: float = 900.0, + **kwargs, +) -> httpx.Response: + """通用异步POST请求""" + async with http_client.get_client(timeout=timeout, **kwargs) as client: + return await client.post(url, data=data, json=json, headers=headers) + + +# 调试用:设为 True 时所有流式请求都返回 429 +_MOCK_STREAM_429 = False + +async def stream_post_async( + url: str, + body: Dict[str, Any], + native: bool = False, + headers: Optional[Dict[str, str]] = None, + **kwargs, +): + """流式异步POST请求""" + if _MOCK_STREAM_429: + from fastapi import Response + import json + log.warning(f"[MOCK] stream_post_async: 返回模拟429错误") + yield Response( + content=json.dumps({"error": {"code": 429, "message": "mock rate limit", "status": "RESOURCE_EXHAUSTED"}}), + status_code=429, + ) + return + + async with http_client.get_streaming_client(**kwargs) as client: + async with client.stream("POST", url, json=body, headers=headers) as r: + # 错误直接返回 + if r.status_code != 200: + from fastapi import Response + yield Response(await r.aread(), r.status_code, dict(r.headers)) + return + + # 如果native=True,直接返回bytes流 + if native: + async for chunk in r.aiter_bytes(): + yield chunk + else: + # 通过aiter_lines转化成str流返回 + async for line in r.aiter_lines(): + yield line diff --git a/src/keeplive.py b/src/keeplive.py new file mode 100644 index 0000000000000000000000000000000000000000..0844d70f77581dbf46f0b8a7ea8420b4f1fb83cb --- /dev/null +++ b/src/keeplive.py @@ -0,0 +1,88 @@ +""" +保活服务模块 +定期向配置的URL发送GET请求,保持服务在线 +未配置保活URL时不启动任何任务,零资源占用 +""" + +import asyncio +from typing import Optional + +from config import get_keepalive_interval, get_keepalive_url +from log import log +from src.httpx_client import get_async + + +class KeepAliveService: + """保活服务:定期向指定URL发送GET请求""" + + def __init__(self): + self._task: Optional[asyncio.Task] = None + + async def _run(self, url: str, interval: int): + """保活循环,读取到有效URL才会被调用""" + log.info(f"[KeepAlive] 保活任务启动,URL={url},间隔={interval}s") + while True: + try: + response = await get_async(url, timeout=30.0) + log.info(f"[KeepAlive] GET {url} -> {response.status_code}") + except asyncio.CancelledError: + raise + except Exception as e: + log.warning(f"[KeepAlive] GET {url} 失败: {e}") + + try: + await asyncio.sleep(interval) + except asyncio.CancelledError: + raise + + async def start(self): + """ + 启动保活服务。 + 仅当配置了有效的保活URL时才创建后台任务,否则零开销。 + """ + if self._task and not self._task.done(): + # 已有任务在运行,不重复启动 + return + + url = await get_keepalive_url() + interval = await get_keepalive_interval() + + if not url or not url.strip(): + log.debug("[KeepAlive] 未配置保活URL,保活服务不启动") + return + + if interval <= 0: + log.warning(f"[KeepAlive] 保活间隔无效({interval}),保活服务不启动") + return + + self._task = asyncio.create_task( + self._run(url.strip(), interval), name="keepalive_service" + ) + + async def stop(self): + """停止保活服务""" + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + log.info("[KeepAlive] 保活服务已停止") + self._task = None + + async def restart(self): + """ + 重启保活服务。 + 配置变更时调用,会停止旧任务并根据最新配置决定是否启动新任务。 + """ + await self.stop() + await self.start() + + @property + def is_running(self) -> bool: + """当前保活任务是否在运行""" + return self._task is not None and not self._task.done() + + +# 全局保活服务实例 +keepalive_service = KeepAliveService() diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1e947a258038fe6036b30d5ce168506fdf2638d3 --- /dev/null +++ b/src/models.py @@ -0,0 +1,379 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + + +# Pydantic v1/v2 兼容性辅助函数 +def model_to_dict(model: BaseModel) -> Dict[str, Any]: + """ + 兼容 Pydantic v1 和 v2 的模型转字典方法,排除 None 值 + - v1: model.dict(exclude_none=True) + - v2: model.model_dump(exclude_none=True) + """ + if hasattr(model, 'model_dump'): + # Pydantic v2 + return model.model_dump(exclude_none=True) + else: + # Pydantic v1 + return model.dict(exclude_none=True) + + +# Common Models +class Model(BaseModel): + id: str + object: str = "model" + created: Optional[int] = None + owned_by: Optional[str] = "google" + + +class ModelList(BaseModel): + object: str = "list" + data: List[Model] + + +# OpenAI Models +class OpenAIToolFunction(BaseModel): + name: str + arguments: str # JSON string + + +class OpenAIToolCall(BaseModel): + id: str + type: str = "function" + function: OpenAIToolFunction + + +class OpenAITool(BaseModel): + type: str = "function" + function: Dict[str, Any] + + +class OpenAIChatMessage(BaseModel): + role: str + content: Union[str, List[Dict[str, Any]], None] = None + reasoning_content: Optional[str] = None + name: Optional[str] = None + tool_calls: Optional[List[OpenAIToolCall]] = None + tool_call_id: Optional[str] = None # for role="tool" + + +class OpenAIChatCompletionRequest(BaseModel): + model: str + messages: List[OpenAIChatMessage] + stream: bool = False + temperature: Optional[float] = Field(None, ge=0.0, le=2.0) + top_p: Optional[float] = Field(None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(None, ge=1) + stop: Optional[Union[str, List[str]]] = None + frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + n: Optional[int] = Field(1, ge=1, le=128) + seed: Optional[int] = None + response_format: Optional[Dict[str, Any]] = None + top_k: Optional[int] = Field(None, ge=1) + tools: Optional[List[OpenAITool]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + class Config: + extra = "allow" # Allow additional fields not explicitly defined + + +# 通用的聊天完成请求模型(兼容OpenAI和其他格式) +ChatCompletionRequest = OpenAIChatCompletionRequest + + +class OpenAIChatCompletionChoice(BaseModel): + index: int + message: OpenAIChatMessage + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class OpenAIChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[OpenAIChatCompletionChoice] + usage: Optional[Dict[str, int]] = None + system_fingerprint: Optional[str] = None + + +class OpenAIDelta(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + + +class OpenAIChatCompletionStreamChoice(BaseModel): + index: int + delta: OpenAIDelta + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class OpenAIChatCompletionStreamResponse(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int + model: str + choices: List[OpenAIChatCompletionStreamChoice] + system_fingerprint: Optional[str] = None + + +# Gemini Models +class GeminiPart(BaseModel): + text: Optional[str] = None + inlineData: Optional[Dict[str, Any]] = None + fileData: Optional[Dict[str, Any]] = None + thought: Optional[bool] = None # 改为 None,避免序列化时包含 False + + class Config: + extra = "allow" # 允许额外字段(如 functionCall, functionResponse) + + +class GeminiContent(BaseModel): + role: str + parts: List[GeminiPart] + + +class GeminiSystemInstruction(BaseModel): + parts: List[GeminiPart] + + +class GeminiImageConfig(BaseModel): + """图片生成配置""" + aspect_ratio: Optional[str] = None # "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9" + image_size: Optional[str] = None # "1K", "2K", "4K" + + +class GeminiGenerationConfig(BaseModel): + temperature: Optional[float] = Field(None, ge=0.0, le=2.0) + topP: Optional[float] = Field(None, ge=0.0, le=1.0) + topK: Optional[int] = Field(None, ge=1) + maxOutputTokens: Optional[int] = Field(None, ge=1) + stopSequences: Optional[List[str]] = None + responseMimeType: Optional[str] = None + responseSchema: Optional[Dict[str, Any]] = None + candidateCount: Optional[int] = Field(None, ge=1, le=8) + seed: Optional[int] = None + frequencyPenalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + presencePenalty: Optional[float] = Field(None, ge=-2.0, le=2.0) + thinkingConfig: Optional[Dict[str, Any]] = None + # 图片生成相关参数 + response_modalities: Optional[List[str]] = None # ["TEXT", "IMAGE"] + image_config: Optional[GeminiImageConfig] = None + + +class GeminiSafetySetting(BaseModel): + category: str + threshold: str + + +class GeminiRequest(BaseModel): + contents: List[GeminiContent] + systemInstruction: Optional[GeminiSystemInstruction] = None + generationConfig: Optional[GeminiGenerationConfig] = None + safetySettings: Optional[List[GeminiSafetySetting]] = None + tools: Optional[List[Dict[str, Any]]] = None + toolConfig: Optional[Dict[str, Any]] = None + cachedContent: Optional[str] = None + + class Config: + extra = "allow" # 允许透传未定义的字段 + + +class GeminiCandidate(BaseModel): + content: GeminiContent + finishReason: Optional[str] = None + index: int = 0 + safetyRatings: Optional[List[Dict[str, Any]]] = None + citationMetadata: Optional[Dict[str, Any]] = None + tokenCount: Optional[int] = None + + +class GeminiUsageMetadata(BaseModel): + promptTokenCount: Optional[int] = None + candidatesTokenCount: Optional[int] = None + totalTokenCount: Optional[int] = None + + +class GeminiResponse(BaseModel): + candidates: List[GeminiCandidate] + usageMetadata: Optional[GeminiUsageMetadata] = None + modelVersion: Optional[str] = None + + +# Claude Models +class ClaudeContentBlock(BaseModel): + type: str # "text", "image", "tool_use", "tool_result" + text: Optional[str] = None + source: Optional[Dict[str, Any]] = None # for image type + id: Optional[str] = None # for tool_use + name: Optional[str] = None # for tool_use + input: Optional[Dict[str, Any]] = None # for tool_use + tool_use_id: Optional[str] = None # for tool_result + content: Optional[Union[str, List[Dict[str, Any]]]] = None # for tool_result + + +class ClaudeMessage(BaseModel): + role: str # "user" or "assistant" + content: Union[str, List[ClaudeContentBlock]] + + +class ClaudeTool(BaseModel): + name: str + description: Optional[str] = None + input_schema: Dict[str, Any] + + +class ClaudeMetadata(BaseModel): + user_id: Optional[str] = None + + +class ClaudeRequest(BaseModel): + model: str + messages: List[ClaudeMessage] + max_tokens: int = Field(..., ge=1) + system: Optional[Union[str, List[Dict[str, Any]]]] = None + temperature: Optional[float] = Field(None, ge=0.0, le=1.0) + top_p: Optional[float] = Field(None, ge=0.0, le=1.0) + top_k: Optional[int] = Field(None, ge=1) + stop_sequences: Optional[List[str]] = None + stream: bool = False + metadata: Optional[ClaudeMetadata] = None + tools: Optional[List[ClaudeTool]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + class Config: + extra = "allow" + + +class ClaudeUsage(BaseModel): + input_tokens: int + output_tokens: int + + +class ClaudeResponse(BaseModel): + id: str + type: str = "message" + role: str = "assistant" + content: List[ClaudeContentBlock] + model: str + stop_reason: Optional[str] = None + stop_sequence: Optional[str] = None + usage: ClaudeUsage + + +class ClaudeStreamEvent(BaseModel): + type: str # "message_start", "content_block_start", "content_block_delta", "content_block_stop", "message_delta", "message_stop" + message: Optional[ClaudeResponse] = None + index: Optional[int] = None + content_block: Optional[ClaudeContentBlock] = None + delta: Optional[Dict[str, Any]] = None + usage: Optional[ClaudeUsage] = None + + class Config: + extra = "allow" + + +# Error Models +class APIError(BaseModel): + message: str + type: str = "api_error" + code: Optional[int] = None + + +class ErrorResponse(BaseModel): + error: APIError + + +# Control Panel Models +class SystemStatus(BaseModel): + status: str + timestamp: str + credentials: Dict[str, int] + config: Dict[str, Any] + current_credential: str + + +class CredentialInfo(BaseModel): + filename: str + project_id: Optional[str] = None + status: Dict[str, Any] + size: Optional[int] = None + modified_time: Optional[str] = None + error: Optional[str] = None + + +class LogEntry(BaseModel): + timestamp: str + level: str + message: str + module: Optional[str] = None + + +class ConfigValue(BaseModel): + key: str + value: Any + env_locked: bool = False + description: Optional[str] = None + + +# Authentication Models +class AuthRequest(BaseModel): + project_id: Optional[str] = None + user_session: Optional[str] = None + + +class AuthResponse(BaseModel): + success: bool + auth_url: Optional[str] = None + state: Optional[str] = None + error: Optional[str] = None + credentials: Optional[Dict[str, Any]] = None + file_path: Optional[str] = None + requires_manual_project_id: Optional[bool] = None + requires_project_selection: Optional[bool] = None + available_projects: Optional[List[Dict[str, str]]] = None + + +class CredentialStatus(BaseModel): + disabled: bool = False + error_codes: List[int] = [] + last_success: Optional[str] = None + + +# Web Routes Models +class LoginRequest(BaseModel): + password: str + + +class AuthStartRequest(BaseModel): + project_id: Optional[str] = None # 现在是可选的 + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class AuthCallbackRequest(BaseModel): + project_id: Optional[str] = None # 现在是可选的 + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class AuthCallbackUrlRequest(BaseModel): + callback_url: str # OAuth回调完整URL + project_id: Optional[str] = None # 可选的项目ID + mode: Optional[str] = "geminicli" # 凭证模式: geminicli 或 antigravity + + +class CredFileActionRequest(BaseModel): + filename: str + action: str # enable, disable, delete + + +class CredFileBatchActionRequest(BaseModel): + action: str # "enable", "disable", "delete" + filenames: List[str] # 批量操作的文件名列表 + + +class ConfigSaveRequest(BaseModel): + config: dict diff --git a/src/panel/__init__.py b/src/panel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07fda27851f384a842e5abaf7ac3e7f41c0bed68 --- /dev/null +++ b/src/panel/__init__.py @@ -0,0 +1,37 @@ +""" +Panel模块 - 整合所有控制面板路由 +""" + +from fastapi import APIRouter + +from . import auth, creds, config_routes, logs, version, root + + +def create_router() -> APIRouter: + """创建并返回整合所有子路由的主路由器""" + router = APIRouter() + + # 包含所有子路由 + router.include_router(root.router) + router.include_router(auth.router) + router.include_router(creds.router) + router.include_router(config_routes.router) + router.include_router(logs.router) + router.include_router(version.router) + + return router + + +# 导出主路由器 +router = create_router() + +# 导出常用工具 +from .utils import ConnectionManager, is_mobile_user_agent, validate_mode, get_env_locked_keys + +__all__ = [ + "router", + "ConnectionManager", + "is_mobile_user_agent", + "validate_mode", + "get_env_locked_keys", +] diff --git a/src/panel/auth.py b/src/panel/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..21adff032191e62ccf2dc491820cb4f77b75b4c5 --- /dev/null +++ b/src/panel/auth.py @@ -0,0 +1,192 @@ +""" +认证路由模块 - 处理 /auth/* 相关的HTTP请求 +""" + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse + +from log import log +from src.auth import ( + asyncio_complete_auth_flow, + complete_auth_flow_from_callback_url, + create_auth_url, + get_auth_status, + verify_password, +) +from src.models import ( + LoginRequest, + AuthStartRequest, + AuthCallbackRequest, + AuthCallbackUrlRequest, +) +from src.utils import verify_panel_token + + +# 创建路由器 +router = APIRouter(prefix="/auth", tags=["auth"]) + + +@router.post("/login") +async def login(request: LoginRequest): + """用户登录(简化版:直接返回密码作为token)""" + try: + if await verify_password(request.password): + # 直接使用密码作为token,简化认证流程 + return JSONResponse(content={"token": request.password, "message": "登录成功"}) + else: + raise HTTPException(status_code=401, detail="密码错误") + except HTTPException: + raise + except Exception as e: + log.error(f"登录失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/start") +async def start_auth(request: AuthStartRequest, token: str = Depends(verify_panel_token)): + """开始认证流程,支持自动检测项目ID""" + try: + # 如果没有提供项目ID,尝试自动检测 + project_id = request.project_id + if not project_id: + log.info("用户未提供项目ID,后续将使用自动检测...") + + # 使用认证令牌作为用户会话标识 + user_session = token if token else None + result = await create_auth_url( + project_id, user_session, mode=request.mode + ) + + if result["success"]: + return JSONResponse( + content={ + "auth_url": result["auth_url"], + "state": result["state"], + "auto_project_detection": result.get("auto_project_detection", False), + "detected_project_id": result.get("detected_project_id"), + } + ) + else: + raise HTTPException(status_code=500, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"开始认证流程失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/callback") +async def auth_callback(request: AuthCallbackRequest, token: str = Depends(verify_panel_token)): + """处理认证回调,支持自动检测项目ID""" + try: + # 项目ID现在是可选的,在回调处理中进行自动检测 + project_id = request.project_id + + # 使用认证令牌作为用户会话标识 + user_session = token if token else None + # 异步等待OAuth回调完成 + result = await asyncio_complete_auth_flow( + project_id, user_session, mode=request.mode + ) + + if result["success"]: + # 单项目认证成功 + return JSONResponse( + content={ + "credentials": result["credentials"], + "file_path": result["file_path"], + "message": "认证成功,凭证已保存", + "auto_detected_project": result.get("auto_detected_project", False), + } + ) + else: + # 如果需要手动项目ID或项目选择,在响应中标明 + if result.get("requires_manual_project_id"): + # 使用JSON响应 + return JSONResponse( + status_code=400, + content={"error": result["error"], "requires_manual_project_id": True}, + ) + elif result.get("requires_project_selection"): + # 返回项目列表供用户选择 + return JSONResponse( + status_code=400, + content={ + "error": result["error"], + "requires_project_selection": True, + "available_projects": result["available_projects"], + }, + ) + else: + raise HTTPException(status_code=400, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"处理认证回调失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/callback-url") +async def auth_callback_url(request: AuthCallbackUrlRequest, token: str = Depends(verify_panel_token)): + """从回调URL直接完成认证""" + try: + # 验证URL格式 + if not request.callback_url or not request.callback_url.startswith(("http://", "https://")): + raise HTTPException(status_code=400, detail="请提供有效的回调URL") + + # 从回调URL完成认证 + result = await complete_auth_flow_from_callback_url( + request.callback_url, request.project_id, mode=request.mode + ) + + if result["success"]: + # 单项目认证成功 + return JSONResponse( + content={ + "credentials": result["credentials"], + "file_path": result["file_path"], + "message": "从回调URL认证成功,凭证已保存", + "auto_detected_project": result.get("auto_detected_project", False), + } + ) + else: + # 处理各种错误情况 + if result.get("requires_manual_project_id"): + return JSONResponse( + status_code=400, + content={"error": result["error"], "requires_manual_project_id": True}, + ) + elif result.get("requires_project_selection"): + return JSONResponse( + status_code=400, + content={ + "error": result["error"], + "requires_project_selection": True, + "available_projects": result["available_projects"], + }, + ) + else: + raise HTTPException(status_code=400, detail=result["error"]) + + except HTTPException: + raise + except Exception as e: + log.error(f"从回调URL处理认证失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/status/{project_id}") +async def check_auth_status(project_id: str, token: str = Depends(verify_panel_token)): + """检查认证状态""" + try: + if not project_id: + raise HTTPException(status_code=400, detail="Project ID 不能为空") + + status = get_auth_status(project_id) + return JSONResponse(content=status) + + except Exception as e: + log.error(f"检查认证状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/panel/config_routes.py b/src/panel/config_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..0702fde281faebd17ee65d8f07c33ad2123a989e --- /dev/null +++ b/src/panel/config_routes.py @@ -0,0 +1,224 @@ +""" +配置路由模块 - 处理 /config/* 相关的HTTP请求 +""" + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse + +import config +from log import log +from src.keeplive import keepalive_service +from src.models import ConfigSaveRequest +from src.storage_adapter import get_storage_adapter +from src.utils import verify_panel_token +from .utils import get_env_locked_keys + + +# 创建路由器 +router = APIRouter(prefix="/config", tags=["config"]) + + +@router.get("/get") +async def get_config(token: str = Depends(verify_panel_token)): + """获取当前配置""" + try: + + + # 读取当前配置(包括环境变量和TOML文件中的配置) + current_config = {} + + # 基础配置 + current_config["code_assist_endpoint"] = await config.get_code_assist_endpoint() + current_config["credentials_dir"] = await config.get_credentials_dir() + current_config["proxy"] = await config.get_proxy_config() or "" + + # 代理端点配置 + current_config["oauth_proxy_url"] = await config.get_oauth_proxy_url() + current_config["googleapis_proxy_url"] = await config.get_googleapis_proxy_url() + current_config["resource_manager_api_url"] = await config.get_resource_manager_api_url() + current_config["service_usage_api_url"] = await config.get_service_usage_api_url() + + # 自动封禁配置 + current_config["auto_ban_enabled"] = await config.get_auto_ban_enabled() + current_config["auto_ban_error_codes"] = await config.get_auto_ban_error_codes() + + # 429重试配置 + current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries() + current_config["retry_429_enabled"] = await config.get_retry_429_enabled() + current_config["retry_429_interval"] = await config.get_retry_429_interval() + # 抗截断配置 + current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts() + + # 兼容性配置 + current_config["compatibility_mode_enabled"] = await config.get_compatibility_mode_enabled() + + # 思维链返回配置 + current_config["return_thoughts_to_frontend"] = await config.get_return_thoughts_to_frontend() + + # Antigravity流式转非流式配置 + current_config["antigravity_stream2nostream"] = await config.get_antigravity_stream2nostream() + + # 保活配置 + current_config["keepalive_url"] = await config.get_keepalive_url() + current_config["keepalive_interval"] = await config.get_keepalive_interval() + + # 服务器配置 + current_config["host"] = await config.get_server_host() + current_config["port"] = await config.get_server_port() + current_config["api_password"] = await config.get_api_password() + current_config["panel_password"] = await config.get_panel_password() + current_config["password"] = await config.get_server_password() + + # 从存储系统读取配置 + storage_adapter = await get_storage_adapter() + storage_config = await storage_adapter.get_all_config() + + # 获取环境变量锁定的配置键 + env_locked_keys = get_env_locked_keys() + + # 合并存储系统配置(不覆盖环境变量) + for key, value in storage_config.items(): + if key not in env_locked_keys: + current_config[key] = value + + return JSONResponse(content={"config": current_config, "env_locked": list(env_locked_keys)}) + + except Exception as e: + log.error(f"获取配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/save") +async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_panel_token)): + """保存配置""" + try: + + new_config = request.config + + log.debug(f"收到的配置数据: {list(new_config.keys())}") + log.debug(f"收到的password值: {new_config.get('password', 'NOT_FOUND')}") + + # 验证配置项 + if "retry_429_max_retries" in new_config: + if ( + not isinstance(new_config["retry_429_max_retries"], int) + or new_config["retry_429_max_retries"] < 0 + ): + raise HTTPException(status_code=400, detail="最大429重试次数必须是大于等于0的整数") + + if "retry_429_enabled" in new_config: + if not isinstance(new_config["retry_429_enabled"], bool): + raise HTTPException(status_code=400, detail="429重试开关必须是布尔值") + + # 验证新的配置项 + if "retry_429_interval" in new_config: + try: + interval = float(new_config["retry_429_interval"]) + if interval < 0.01 or interval > 10: + raise HTTPException(status_code=400, detail="429重试间隔必须在0.01-10秒之间") + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="429重试间隔必须是有效的数字") + + if "anti_truncation_max_attempts" in new_config: + if ( + not isinstance(new_config["anti_truncation_max_attempts"], int) + or new_config["anti_truncation_max_attempts"] < 1 + or new_config["anti_truncation_max_attempts"] > 10 + ): + raise HTTPException( + status_code=400, detail="抗截断最大重试次数必须是1-10之间的整数" + ) + + if "compatibility_mode_enabled" in new_config: + if not isinstance(new_config["compatibility_mode_enabled"], bool): + raise HTTPException(status_code=400, detail="兼容性模式开关必须是布尔值") + + if "return_thoughts_to_frontend" in new_config: + if not isinstance(new_config["return_thoughts_to_frontend"], bool): + raise HTTPException(status_code=400, detail="思维链返回开关必须是布尔值") + + if "antigravity_stream2nostream" in new_config: + if not isinstance(new_config["antigravity_stream2nostream"], bool): + raise HTTPException(status_code=400, detail="Antigravity流式转非流式开关必须是布尔值") + + # 验证保活配置 + if "keepalive_url" in new_config: + if not isinstance(new_config["keepalive_url"], str): + raise HTTPException(status_code=400, detail="保活URL必须是字符串") + + if "keepalive_interval" in new_config: + try: + interval = int(new_config["keepalive_interval"]) + if interval < 5 or interval > 86400: + raise HTTPException(status_code=400, detail="保活间隔必须在 5-86400 秒之间") + new_config["keepalive_interval"] = interval + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="保活间隔必须是有效整数") + # 验证服务器配置 + if "host" in new_config: + if not isinstance(new_config["host"], str) or not new_config["host"].strip(): + raise HTTPException(status_code=400, detail="服务器主机地址不能为空") + + if "port" in new_config: + if ( + not isinstance(new_config["port"], int) + or new_config["port"] < 1 + or new_config["port"] > 65535 + ): + raise HTTPException(status_code=400, detail="端口号必须是1-65535之间的整数") + + if "api_password" in new_config: + if not isinstance(new_config["api_password"], str): + raise HTTPException(status_code=400, detail="API访问密码必须是字符串") + + if "panel_password" in new_config: + if not isinstance(new_config["panel_password"], str): + raise HTTPException(status_code=400, detail="控制面板密码必须是字符串") + + if "password" in new_config: + if not isinstance(new_config["password"], str): + raise HTTPException(status_code=400, detail="访问密码必须是字符串") + + # 获取环境变量锁定的配置键 + env_locked_keys = get_env_locked_keys() + + # 直接使用存储适配器保存配置 + storage_adapter = await get_storage_adapter() + for key, value in new_config.items(): + if key not in env_locked_keys: + await storage_adapter.set_config(key, value) + if key in ("password", "api_password", "panel_password"): + log.debug(f"设置{key}字段为: {value}") + + # 重新加载配置缓存(关键!) + await config.reload_config() + + # 如果保活相关配置发生变化,立即重启保活服务 + keepalive_keys = {"keepalive_url", "keepalive_interval"} + if keepalive_keys & set(new_config.keys()): + try: + await keepalive_service.restart() + except Exception as e: + log.warning(f"重启保活服务失败: {e}") + + # 验证保存后的结果 + test_api_password = await config.get_api_password() + test_panel_password = await config.get_panel_password() + test_password = await config.get_server_password() + log.debug(f"保存后立即读取的API密码: {test_api_password}") + log.debug(f"保存后立即读取的面板密码: {test_panel_password}") + log.debug(f"保存后立即读取的通用密码: {test_password}") + + # 构建响应消息 + response_data = { + "message": "配置保存成功", + "saved_config": {k: v for k, v in new_config.items() if k not in env_locked_keys}, + } + + return JSONResponse(content=response_data) + + except HTTPException: + raise + except Exception as e: + log.error(f"保存配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/panel/creds.py b/src/panel/creds.py new file mode 100644 index 0000000000000000000000000000000000000000..bbadcdfa7df8a1ff34465947b0627d541760e9c1 --- /dev/null +++ b/src/panel/creds.py @@ -0,0 +1,1585 @@ +""" +凭证管理路由模块 - 处理 /creds/* 相关的HTTP请求 +""" + +import asyncio +import io +import json +import os +import time +import zipfile +from typing import Any, List + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Response +from fastapi.responses import JSONResponse + +from log import log +from src.credential_manager import credential_manager +from src.models import ( + CredFileActionRequest, + CredFileBatchActionRequest +) +from src.storage_adapter import get_storage_adapter +from src.utils import verify_panel_token, GEMINICLI_USER_AGENT, ANTIGRAVITY_USER_AGENT +from src.api.antigravity import fetch_quota_info +from src.google_oauth_api import Credentials, fetch_project_id_and_tier +from config import get_code_assist_endpoint +from .utils import validate_mode + + +# 创建路由器 +router = APIRouter(prefix="/creds", tags=["credentials"]) + + +# ============================================================================= +# 工具函数 (Helper Functions) +# ============================================================================= + + +async def extract_json_files_from_zip(zip_file: UploadFile) -> List[dict]: + """从ZIP文件中提取JSON文件""" + try: + # 读取ZIP文件内容 + zip_content = await zip_file.read() + + # 不限制ZIP文件大小,只在处理时控制文件数量 + + files_data = [] + + with zipfile.ZipFile(io.BytesIO(zip_content), "r") as zip_ref: + # 获取ZIP中的所有文件 + file_list = zip_ref.namelist() + json_files = [ + f for f in file_list if f.endswith(".json") and not f.startswith("__MACOSX/") + ] + + if not json_files: + raise HTTPException(status_code=400, detail="ZIP文件中没有找到JSON文件") + + log.info(f"从ZIP文件 {zip_file.filename} 中找到 {len(json_files)} 个JSON文件") + + for json_filename in json_files: + try: + # 读取JSON文件内容 + with zip_ref.open(json_filename) as json_file: + content = json_file.read() + + try: + content_str = content.decode("utf-8") + except UnicodeDecodeError: + log.warning(f"跳过编码错误的文件: {json_filename}") + continue + + # 使用原始文件名(去掉路径) + filename = os.path.basename(json_filename) + files_data.append({"filename": filename, "content": content_str}) + + except Exception as e: + log.warning(f"处理ZIP中的文件 {json_filename} 时出错: {e}") + continue + + log.info(f"成功从ZIP文件中提取 {len(files_data)} 个有效的JSON文件") + return files_data + + except zipfile.BadZipFile: + raise HTTPException(status_code=400, detail="无效的ZIP文件格式") + except Exception as e: + log.error(f"处理ZIP文件失败: {e}") + raise HTTPException(status_code=500, detail=f"处理ZIP文件失败: {str(e)}") + + +async def clear_all_model_cooldowns_for_credential( + storage_adapter: Any, + filename: str, + mode: str, +) -> None: + """清空指定凭证的所有模型冷却(后端支持时执行)。""" + try: + cleared = await storage_adapter._backend.clear_all_model_cooldowns(filename, mode=mode) + if not cleared: + log.warning(f"清空模型CD失败或凭证不存在: {filename} (mode={mode})") + except Exception as e: + log.warning(f"清空模型CD时出错: {filename} (mode={mode}), error={e}") + + +async def upload_credentials_common( + files: List[UploadFile], mode: str = "geminicli" +) -> JSONResponse: + """批量上传凭证文件的通用函数""" + mode = validate_mode(mode) + + if not files: + raise HTTPException(status_code=400, detail="请选择要上传的文件") + + # 检查文件数量限制 + if len(files) > 100: + raise HTTPException( + status_code=400, detail=f"文件数量过多,最多支持100个文件,当前:{len(files)}个" + ) + + files_data = [] + for file in files: + # 检查文件类型:支持JSON和ZIP + if file.filename.endswith(".zip"): + zip_files_data = await extract_json_files_from_zip(file) + files_data.extend(zip_files_data) + log.info(f"从ZIP文件 {file.filename} 中提取了 {len(zip_files_data)} 个JSON文件") + + elif file.filename.endswith(".json"): + # 处理单个JSON文件 - 流式读取 + content_chunks = [] + while True: + chunk = await file.read(8192) + if not chunk: + break + content_chunks.append(chunk) + + content = b"".join(content_chunks) + try: + content_str = content.decode("utf-8") + except UnicodeDecodeError: + raise HTTPException( + status_code=400, detail=f"文件 {file.filename} 编码格式不支持" + ) + + files_data.append({"filename": file.filename, "content": content_str}) + else: + raise HTTPException( + status_code=400, detail=f"文件 {file.filename} 格式不支持,只支持JSON和ZIP文件" + ) + + + + batch_size = 1000 + all_results = [] + total_success = 0 + + for i in range(0, len(files_data), batch_size): + batch_files = files_data[i : i + batch_size] + + async def process_single_file(file_data): + try: + filename = file_data["filename"] + # 确保文件名只保存basename,避免路径问题 + filename = os.path.basename(filename) + content_str = file_data["content"] + credential_data = json.loads(content_str) + + # 根据凭证类型调用不同的添加方法 + if mode == "antigravity": + await credential_manager.add_antigravity_credential(filename, credential_data) + else: + await credential_manager.add_credential(filename, credential_data) + + log.debug(f"成功上传 {mode} 凭证文件: {filename}") + return {"filename": filename, "status": "success", "message": "上传成功"} + + except json.JSONDecodeError as e: + return { + "filename": file_data["filename"], + "status": "error", + "message": f"JSON格式错误: {str(e)}", + } + except Exception as e: + return { + "filename": file_data["filename"], + "status": "error", + "message": f"处理失败: {str(e)}", + } + + log.info(f"开始并发处理 {len(batch_files)} 个 {mode} 文件...") + concurrent_tasks = [process_single_file(file_data) for file_data in batch_files] + batch_results = await asyncio.gather(*concurrent_tasks, return_exceptions=True) + + processed_results = [] + batch_uploaded_count = 0 + for result in batch_results: + if isinstance(result, Exception): + processed_results.append( + { + "filename": "unknown", + "status": "error", + "message": f"处理异常: {str(result)}", + } + ) + else: + processed_results.append(result) + if result["status"] == "success": + batch_uploaded_count += 1 + + all_results.extend(processed_results) + total_success += batch_uploaded_count + + batch_num = (i // batch_size) + 1 + total_batches = (len(files_data) + batch_size - 1) // batch_size + log.info( + f"批次 {batch_num}/{total_batches} 完成: 成功 " + f"{batch_uploaded_count}/{len(batch_files)} 个 {mode} 文件" + ) + + if total_success > 0: + return JSONResponse( + content={ + "uploaded_count": total_success, + "total_count": len(files_data), + "results": all_results, + "message": f"批量上传完成: 成功 {total_success}/{len(files_data)} 个 {mode} 文件", + } + ) + else: + raise HTTPException(status_code=400, detail=f"没有 {mode} 文件上传成功") + + +async def get_creds_status_common( + offset: int, limit: int, status_filter: str, mode: str = "geminicli", + error_code_filter: str = None, cooldown_filter: str = None, preview_filter: str = None, tier_filter: str = None +) -> JSONResponse: + """获取凭证文件状态的通用函数""" + mode = validate_mode(mode) + # 验证分页参数 + if offset < 0: + raise HTTPException(status_code=400, detail="offset 必须大于等于 0") + if limit not in [20, 50, 100, 200, 500, 1000]: + raise HTTPException(status_code=400, detail="limit 只能是 20、50、100、200、500 或 1000") + if status_filter not in ["all", "enabled", "disabled"]: + raise HTTPException(status_code=400, detail="status_filter 只能是 all、enabled 或 disabled") + if cooldown_filter and cooldown_filter not in ["all", "in_cooldown", "no_cooldown"]: + raise HTTPException(status_code=400, detail="cooldown_filter 只能是 all、in_cooldown 或 no_cooldown") + if preview_filter and preview_filter not in ["all", "preview", "no_preview"]: + raise HTTPException(status_code=400, detail="preview_filter 只能是 all、preview 或 no_preview") + if tier_filter and tier_filter not in ["all", "free", "pro", "ultra"]: + raise HTTPException(status_code=400, detail="tier_filter 只能是 all、free、pro 或 ultra") + + + + storage_adapter = await get_storage_adapter() + backend_info = await storage_adapter.get_backend_info() + backend_type = backend_info.get("backend_type", "unknown") + + # 使用高性能的分页摘要查询 + result = await storage_adapter._backend.get_credentials_summary( + offset=offset, + limit=limit, + status_filter=status_filter, + mode=mode, + error_code_filter=error_code_filter if error_code_filter and error_code_filter != "all" else None, + cooldown_filter=cooldown_filter if cooldown_filter and cooldown_filter != "all" else None, + preview_filter=preview_filter if preview_filter and preview_filter != "all" else None, + tier_filter=tier_filter if tier_filter and tier_filter != "all" else None + ) + + creds_list = [] + for summary in result["items"]: + cred_info = { + "filename": os.path.basename(summary["filename"]), + "user_email": summary["user_email"], + "disabled": summary["disabled"], + "error_codes": summary["error_codes"], + "last_success": summary["last_success"], + "backend_type": backend_type, + "model_cooldowns": summary.get("model_cooldowns", {}), + "tier": summary.get("tier", "pro"), + } + + if mode == "geminicli": + cred_info["preview"] = summary.get("preview", True) + else: + cred_info["enable_credit"] = summary.get("enable_credit", False) + + creds_list.append(cred_info) + + return JSONResponse(content={ + "items": creds_list, + "total": result["total"], + "offset": offset, + "limit": limit, + "has_more": (offset + limit) < result["total"], + "stats": result.get("stats", {"total": 0, "normal": 0, "disabled": 0}), + }) + + +async def download_all_creds_common(mode: str = "geminicli") -> Response: + """打包下载所有凭证文件的通用函数""" + mode = validate_mode(mode) + zip_filename = "antigravity_credentials.zip" if mode == "antigravity" else "credentials.zip" + + storage_adapter = await get_storage_adapter() + credential_filenames = await storage_adapter.list_credentials(mode=mode) + + if not credential_filenames: + raise HTTPException(status_code=404, detail=f"没有找到 {mode} 凭证文件") + + log.info(f"开始打包 {len(credential_filenames)} 个 {mode} 凭证文件...") + + zip_buffer = io.BytesIO() + + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + success_count = 0 + for idx, filename in enumerate(credential_filenames, 1): + try: + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if credential_data: + content = json.dumps(credential_data, ensure_ascii=False, indent=2) + zip_file.writestr(os.path.basename(filename), content) + success_count += 1 + + if idx % 10 == 0: + log.debug(f"打包进度: {idx}/{len(credential_filenames)}") + + except Exception as e: + log.warning(f"处理 {mode} 凭证文件 {filename} 时出错: {e}") + continue + + log.info(f"打包完成: 成功 {success_count}/{len(credential_filenames)} 个文件") + + zip_buffer.seek(0) + return Response( + content=zip_buffer.getvalue(), + media_type="application/zip", + headers={"Content-Disposition": f"attachment; filename={zip_filename}"}, + ) + + +async def fetch_user_email_common(filename: str, mode: str = "geminicli") -> JSONResponse: + """获取指定凭证文件用户邮箱的通用函数""" + mode = validate_mode(mode) + + filename_only = os.path.basename(filename) + if not filename_only.endswith(".json"): + raise HTTPException(status_code=404, detail="无效的文件名") + + storage_adapter = await get_storage_adapter() + credential_data = await storage_adapter.get_credential(filename_only, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证文件不存在") + + email = await credential_manager.get_or_fetch_user_email(filename_only, mode=mode) + + if email: + return JSONResponse( + content={ + "filename": filename_only, + "user_email": email, + "message": "成功获取用户邮箱", + } + ) + else: + return JSONResponse( + content={ + "filename": filename_only, + "user_email": None, + "message": "无法获取用户邮箱,可能凭证已过期或权限不足", + }, + status_code=400, + ) + + +async def refresh_all_user_emails_common(mode: str = "geminicli") -> JSONResponse: + """刷新所有凭证文件用户邮箱的通用函数 - 只为没有邮箱的凭证获取 + + 利用 get_all_credential_states 批量获取状态 + """ + mode = validate_mode(mode) + + storage_adapter = await get_storage_adapter() + + # 一次性批量获取所有凭证的状态 + all_states = await storage_adapter.get_all_credential_states(mode=mode) + + results = [] + success_count = 0 + skipped_count = 0 + + # 在内存中筛选出需要获取邮箱的凭证 + for filename, state in all_states.items(): + try: + cached_email = state.get("user_email") + + if cached_email: + # 已有邮箱,跳过获取 + skipped_count += 1 + results.append({ + "filename": os.path.basename(filename), + "user_email": cached_email, + "success": True, + "skipped": True, + }) + continue + + # 没有邮箱,尝试获取 + email = await credential_manager.get_or_fetch_user_email(filename, mode=mode) + if email: + success_count += 1 + results.append({ + "filename": os.path.basename(filename), + "user_email": email, + "success": True, + }) + else: + results.append({ + "filename": os.path.basename(filename), + "user_email": None, + "success": False, + "error": "无法获取邮箱", + }) + except Exception as e: + results.append({ + "filename": os.path.basename(filename), + "user_email": None, + "success": False, + "error": str(e), + }) + + total_count = len(all_states) + return JSONResponse( + content={ + "success_count": success_count, + "total_count": total_count, + "skipped_count": skipped_count, + "results": results, + "message": f"成功获取 {success_count}/{total_count} 个邮箱地址,跳过 {skipped_count} 个已有邮箱的凭证", + } + ) + + +async def deduplicate_credentials_by_email_common(mode: str = "geminicli") -> JSONResponse: + """批量去重凭证文件的通用函数 - 删除邮箱相同的凭证(只保留一个)""" + mode = validate_mode(mode) + storage_adapter = await get_storage_adapter() + + try: + duplicate_info = await storage_adapter._backend.get_duplicate_credentials_by_email( + mode=mode + ) + + duplicate_groups = duplicate_info.get("duplicate_groups", []) + no_email_files = duplicate_info.get("no_email_files", []) + total_count = duplicate_info.get("total_count", 0) + + if not duplicate_groups: + return JSONResponse( + content={ + "deleted_count": 0, + "kept_count": total_count, + "total_count": total_count, + "unique_emails_count": duplicate_info.get("unique_email_count", 0), + "no_email_count": len(no_email_files), + "duplicate_groups": [], + "delete_errors": [], + "message": "没有发现重复的凭证(相同邮箱)", + } + ) + + # 执行删除操作 + deleted_count = 0 + delete_errors = [] + result_duplicate_groups = [] + + for group in duplicate_groups: + email = group["email"] + kept_file = group["kept_file"] + duplicate_files = group["duplicate_files"] + + deleted_files_in_group = [] + for filename in duplicate_files: + try: + success = await credential_manager.remove_credential(filename, mode=mode) + if success: + deleted_count += 1 + deleted_files_in_group.append(os.path.basename(filename)) + log.info(f"去重删除凭证: {filename} (邮箱: {email}) (mode={mode})") + else: + delete_errors.append(f"{os.path.basename(filename)}: 删除失败") + except Exception as e: + delete_errors.append(f"{os.path.basename(filename)}: {str(e)}") + log.error(f"去重删除凭证 {filename} 时出错: {e}") + + result_duplicate_groups.append({ + "email": email, + "kept_file": os.path.basename(kept_file), + "deleted_files": deleted_files_in_group, + "duplicate_count": len(deleted_files_in_group), + }) + + kept_count = total_count - deleted_count + + return JSONResponse( + content={ + "deleted_count": deleted_count, + "kept_count": kept_count, + "total_count": total_count, + "unique_emails_count": duplicate_info.get("unique_email_count", 0), + "no_email_count": len(no_email_files), + "duplicate_groups": result_duplicate_groups, + "delete_errors": delete_errors, + "message": f"去重完成:删除 {deleted_count} 个重复凭证,保留 {kept_count} 个凭证({duplicate_info.get('unique_email_count', 0)} 个唯一邮箱)", + } + ) + + except Exception as e: + log.error(f"批量去重凭证时出错: {e}") + return JSONResponse( + status_code=500, + content={ + "deleted_count": 0, + "kept_count": 0, + "total_count": 0, + "message": f"去重操作失败: {str(e)}", + } + ) + + +async def verify_credential_project_common(filename: str, mode: str = "geminicli") -> JSONResponse: + """验证并重新获取凭证的project id的通用函数""" + mode = validate_mode(mode) + + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 创建凭证对象 + credentials = Credentials.from_dict(credential_data) + + # 确保token有效(自动刷新) + token_refreshed = await credentials.refresh_if_needed() + + # 如果token被刷新了,更新存储 + if token_refreshed: + log.info(f"Token已自动刷新: {filename} (mode={mode})") + credential_data = credentials.to_dict() + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 获取API端点和对应的User-Agent + if mode == "antigravity": + api_base_url = await get_code_assist_endpoint() + user_agent = ANTIGRAVITY_USER_AGENT + else: + api_base_url = await get_code_assist_endpoint() + user_agent = GEMINICLI_USER_AGENT + + # 重新获取project id(仅 antigravity 模式请求积分) + if mode == "antigravity": + project_id, subscription_tier, credit_amount = await fetch_project_id_and_tier( + access_token=credentials.access_token, + user_agent=user_agent, + api_base_url=api_base_url, + include_credits=True, + ) + else: + project_id, subscription_tier = await fetch_project_id_and_tier( + access_token=credentials.access_token, + user_agent=user_agent, + api_base_url=api_base_url, + ) + credit_amount = None + + if project_id: + credential_data["project_id"] = project_id + + if project_id or subscription_tier: + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 检验成功后自动解除禁用状态并清除错误码 + state_update = { + "disabled": False, + "error_codes": [] + } + + # 同步更新状态表中的 tier 字段 + state_update["tier"] = subscription_tier + + # 如果是 geminicli 模式,直接设置 preview=True + if mode == "geminicli": + state_update["preview"] = True + + await storage_adapter.update_credential_state(filename, state_update, mode=mode) + + log.info(f"检验 {mode} 凭证成功: {filename} - Project ID: {project_id}, Tier: {subscription_tier} - 已解除禁用并清除错误码") + + response_data = { + "success": True, + "filename": filename, + "project_id": project_id, + "subscription_tier": subscription_tier, + "message": "检验成功!Project ID已更新,已解除禁用状态并清除错误码,403错误应该已恢复" + } + + if mode == "antigravity" and credit_amount is not None: + response_data["credit_amount"] = credit_amount + + return JSONResponse(content=response_data) + else: + return JSONResponse( + status_code=400, + content={ + "success": False, + "filename": filename, + "message": "检验失败:无法获取Project ID,请检查凭证是否有效" + } + ) + + +# ============================================================================= +# 路由处理函数 (Route Handlers) +# ============================================================================= + + +@router.post("/upload") +async def upload_credentials( + files: List[UploadFile] = File(...), + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量上传凭证文件""" + try: + mode = validate_mode(mode) + return await upload_credentials_common(files, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"批量上传失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/status") +async def get_creds_status( + token: str = Depends(verify_panel_token), + offset: int = 0, + limit: int = 50, + status_filter: str = "all", + error_code_filter: str = "all", + cooldown_filter: str = "all", + preview_filter: str = "all", + tier_filter: str = "all", + mode: str = "geminicli" +): + """ + 获取凭证文件的状态(轻量级摘要,不包含完整凭证数据,支持分页和状态筛选) + + Args: + offset: 跳过的记录数(默认0) + limit: 每页返回的记录数(默认50,可选:20, 50, 100, 200, 500, 1000) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + error_code_filter: 错误码筛选(all=全部, 或具体错误码如"400", "403") + cooldown_filter: 冷却状态筛选(all=全部, in_cooldown=冷却中, no_cooldown=未冷却) + preview_filter: Preview筛选(all=全部, preview=支持preview, no_preview=不支持preview,仅geminicli模式有效) + tier_filter: tier筛选(all=全部, free/pro/ultra) + mode: 凭证模式(geminicli 或 antigravity) + + Returns: + 包含凭证列表、总数、分页信息的响应 + """ + try: + mode = validate_mode(mode) + return await get_creds_status_common( + offset, limit, status_filter, mode=mode, + error_code_filter=error_code_filter, + cooldown_filter=cooldown_filter, + preview_filter=preview_filter, + tier_filter=tier_filter + ) + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/detail/{filename}") +async def get_cred_detail( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 按需获取单个凭证的详细数据(包含完整凭证内容) + 用于用户查看/编辑凭证详情 + """ + try: + mode = validate_mode(mode) + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + + storage_adapter = await get_storage_adapter() + backend_info = await storage_adapter.get_backend_info() + backend_type = backend_info.get("backend_type", "unknown") + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 获取状态信息 + file_status = await storage_adapter.get_credential_state(filename, mode=mode) + if not file_status: + file_status = { + "error_codes": [], + "disabled": False, + "last_success": time.time(), + "user_email": None, + } + + result = { + "status": file_status, + "content": credential_data, + "filename": os.path.basename(filename), + "backend_type": backend_type, + "user_email": file_status.get("user_email"), + "model_cooldowns": file_status.get("model_cooldowns", {}), + } + + if mode == "geminicli": + result["preview"] = file_status.get("preview", True) + else: + result["enable_credit"] = file_status.get("enable_credit", False) + + if backend_type == "file" and os.path.exists(filename): + result.update({ + "size": os.path.getsize(filename), + "modified_time": os.path.getmtime(filename), + }) + + return JSONResponse(content=result) + + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证详情失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/action") +async def creds_action( + request: CredFileActionRequest, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """对凭证文件执行操作(启用/禁用/删除/enable_credit开关)""" + try: + mode = validate_mode(mode) + + log.info(f"Received request: {request}") + + filename = request.filename + action = request.action + + log.info(f"Performing action '{action}' on file: {filename} (mode={mode})") + + # 验证文件名 + if not filename.endswith(".json"): + log.error(f"无效的文件名: {filename}(不是.json文件)") + raise HTTPException(status_code=400, detail=f"无效的文件名: {filename}") + + # 获取存储适配器 + storage_adapter = await get_storage_adapter() + + # 对于删除操作,不需要检查凭证数据是否完整,只需检查条目是否存在 + # 对于其他操作,需要确保凭证数据存在且完整 + if action != "delete": + # 检查凭证数据是否存在 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + log.error(f"凭证未找到: {filename} (mode={mode})") + raise HTTPException(status_code=404, detail="凭证文件不存在") + + if action == "enable": + log.info(f"Web请求: 启用文件 {filename} (mode={mode})") + result = await credential_manager.set_cred_disabled(filename, False, mode=mode) + log.info(f"[WebRoute] set_cred_disabled 返回结果: {result}") + if result: + log.info(f"Web请求: 文件 {filename} 已成功启用 (mode={mode})") + return JSONResponse(content={"message": f"已启用凭证文件 {os.path.basename(filename)}"}) + else: + log.error(f"Web请求: 文件 {filename} 启用失败 (mode={mode})") + raise HTTPException(status_code=500, detail="启用凭证失败,可能凭证不存在") + + elif action == "disable": + log.info(f"Web请求: 禁用文件 {filename} (mode={mode})") + result = await credential_manager.set_cred_disabled(filename, True, mode=mode) + log.info(f"[WebRoute] set_cred_disabled 返回结果: {result}") + if result: + log.info(f"Web请求: 文件 {filename} 已成功禁用 (mode={mode})") + return JSONResponse(content={"message": f"已禁用凭证文件 {os.path.basename(filename)}"}) + else: + log.error(f"Web请求: 文件 {filename} 禁用失败 (mode={mode})") + raise HTTPException(status_code=500, detail="禁用凭证失败,可能凭证不存在") + + elif action == "delete": + try: + # 使用 CredentialManager 删除凭证(包含队列/状态同步) + success = await credential_manager.remove_credential(filename, mode=mode) + if success: + log.info(f"通过管理器成功删除凭证: {filename} (mode={mode})") + return JSONResponse( + content={"message": f"已删除凭证文件 {os.path.basename(filename)}"} + ) + else: + raise HTTPException(status_code=500, detail="删除凭证失败") + except Exception as e: + log.error(f"删除凭证 {filename} 时出错: {e}") + raise HTTPException(status_code=500, detail=f"删除文件失败: {str(e)}") + + elif action == "enable_credit": + if mode != "antigravity": + raise HTTPException(status_code=400, detail="enable_credit 仅支持 antigravity 模式") + updated = await storage_adapter.update_credential_state( + filename, {"enable_credit": True}, mode=mode + ) + if updated: + await clear_all_model_cooldowns_for_credential(storage_adapter, filename, mode) + return JSONResponse(content={"message": f"已开启凭证信用额度模式 {os.path.basename(filename)}"}) + raise HTTPException(status_code=500, detail="开启信用额度模式失败,可能凭证不存在") + + elif action == "disable_credit": + if mode != "antigravity": + raise HTTPException(status_code=400, detail="disable_credit 仅支持 antigravity 模式") + updated = await storage_adapter.update_credential_state( + filename, {"enable_credit": False}, mode=mode + ) + if updated: + await clear_all_model_cooldowns_for_credential(storage_adapter, filename, mode) + return JSONResponse(content={"message": f"已关闭凭证信用额度模式 {os.path.basename(filename)}"}) + raise HTTPException(status_code=500, detail="关闭信用额度模式失败,可能凭证不存在") + + else: + raise HTTPException(status_code=400, detail="无效的操作类型") + + except HTTPException: + raise + except Exception as e: + log.error(f"凭证文件操作失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/batch-action") +async def creds_batch_action( + request: CredFileBatchActionRequest, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量对凭证文件执行操作(启用/禁用/删除/enable_credit开关)""" + try: + mode = validate_mode(mode) + + action = request.action + filenames = request.filenames + + if not filenames: + raise HTTPException(status_code=400, detail="文件名列表不能为空") + + log.info(f"对 {len(filenames)} 个文件执行批量操作 '{action}'") + + success_count = 0 + errors = [] + + storage_adapter = await get_storage_adapter() + + for filename in filenames: + try: + # 验证文件名安全性 + if not filename.endswith(".json"): + errors.append(f"{filename}: 无效的文件类型") + continue + + # 对于删除操作,不需要检查凭证数据完整性 + # 对于其他操作,需要确保凭证数据存在 + if action != "delete": + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + errors.append(f"{filename}: 凭证不存在") + continue + + # 执行相应操作 + if action == "enable": + await credential_manager.set_cred_disabled(filename, False, mode=mode) + success_count += 1 + + elif action == "disable": + await credential_manager.set_cred_disabled(filename, True, mode=mode) + success_count += 1 + + elif action == "delete": + try: + delete_success = await credential_manager.remove_credential(filename, mode=mode) + if delete_success: + success_count += 1 + log.info(f"成功删除批量中的凭证: {filename}") + else: + errors.append(f"{filename}: 删除失败") + continue + except Exception as e: + errors.append(f"{filename}: 删除文件失败 - {str(e)}") + continue + elif action == "enable_credit": + if mode != "antigravity": + errors.append(f"{filename}: enable_credit 仅支持 antigravity 模式") + continue + updated = await storage_adapter.update_credential_state( + filename, {"enable_credit": True}, mode=mode + ) + if updated: + await clear_all_model_cooldowns_for_credential(storage_adapter, filename, mode) + success_count += 1 + else: + errors.append(f"{filename}: 开启信用额度模式失败") + continue + elif action == "disable_credit": + if mode != "antigravity": + errors.append(f"{filename}: disable_credit 仅支持 antigravity 模式") + continue + updated = await storage_adapter.update_credential_state( + filename, {"enable_credit": False}, mode=mode + ) + if updated: + await clear_all_model_cooldowns_for_credential(storage_adapter, filename, mode) + success_count += 1 + else: + errors.append(f"{filename}: 关闭信用额度模式失败") + continue + else: + errors.append(f"{filename}: 无效的操作类型") + continue + + except Exception as e: + log.error(f"处理 {filename} 时出错: {e}") + errors.append(f"{filename}: 处理失败 - {str(e)}") + continue + + # 构建返回消息 + result_message = f"批量操作完成:成功处理 {success_count}/{len(filenames)} 个文件" + if errors: + result_message += "\n错误详情:\n" + "\n".join(errors) + + response_data = { + "success_count": success_count, + "total_count": len(filenames), + "errors": errors, + "message": result_message, + } + + return JSONResponse(content=response_data) + + except HTTPException: + raise + except Exception as e: + log.error(f"批量凭证文件操作失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/download/{filename}") +async def download_cred_file( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """下载单个凭证文件""" + try: + mode = validate_mode(mode) + # 验证文件名安全性 + if not filename.endswith(".json"): + raise HTTPException(status_code=404, detail="无效的文件名") + + # 获取存储适配器 + storage_adapter = await get_storage_adapter() + + # 从存储系统获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="文件不存在") + + # 转换为JSON字符串 + content = json.dumps(credential_data, ensure_ascii=False, indent=2) + + from fastapi.responses import Response + + return Response( + content=content, + media_type="application/json", + headers={"Content-Disposition": f"attachment; filename={filename}"}, + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"下载凭证文件失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/fetch-email/{filename}") +async def fetch_user_email( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """获取指定凭证文件的用户邮箱地址""" + try: + mode = validate_mode(mode) + return await fetch_user_email_common(filename, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"获取用户邮箱失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/refresh-all-emails") +async def refresh_all_user_emails( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """刷新所有凭证文件的用户邮箱地址""" + try: + mode = validate_mode(mode) + return await refresh_all_user_emails_common(mode=mode) + except Exception as e: + log.error(f"批量获取用户邮箱失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/deduplicate-by-email") +async def deduplicate_credentials_by_email( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """批量去重凭证文件 - 删除邮箱相同的凭证(只保留一个)""" + try: + mode = validate_mode(mode) + return await deduplicate_credentials_by_email_common(mode=mode) + except Exception as e: + log.error(f"批量去重凭证失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/download-all") +async def download_all_creds( + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 打包下载所有凭证文件(流式处理,按需加载每个凭证数据) + 只在实际下载时才加载完整凭证内容,最大化性能 + """ + try: + mode = validate_mode(mode) + return await download_all_creds_common(mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"打包下载失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/verify-project/{filename}") +async def verify_credential_project( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 检验凭证的project id,重新获取project id + 检验成功可以使403错误恢复 + """ + try: + mode = validate_mode(mode) + return await verify_credential_project_common(filename, mode=mode) + except HTTPException: + raise + except Exception as e: + log.error(f"检验凭证Project ID失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"检验失败: {str(e)}") + + +@router.get("/errors/{filename}") +async def get_credential_errors( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 获取指定凭证的错误信息(包含 error_codes 和 error_messages) + + Args: + filename: 凭证文件名 + mode: 凭证模式(geminicli 或 antigravity) + + Returns: + 包含 error_codes 和 error_messages 的 JSON 响应 + """ + try: + mode = validate_mode(mode) + + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + storage_adapter = await get_storage_adapter() + + # 检查后端是否支持 get_credential_errors 方法 + if not hasattr(storage_adapter._backend, 'get_credential_errors'): + raise HTTPException( + status_code=501, + detail="当前存储后端不支持获取错误信息" + ) + + # 获取错误信息 + error_info = await storage_adapter._backend.get_credential_errors(filename, mode=mode) + + return JSONResponse(content=error_info) + + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证错误信息失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/quota/{filename}") +async def get_credential_quota( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "antigravity" +): + """ + 获取指定凭证的额度信息(仅支持 antigravity 模式) + """ + try: + mode = validate_mode(mode) + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 使用 Credentials 对象自动处理 token 刷新 + from src.google_oauth_api import Credentials + + creds = Credentials.from_dict(credential_data) + + # 自动刷新 token(如果需要) + await creds.refresh_if_needed() + + # 如果 token 被刷新了,更新存储 + updated_data = creds.to_dict() + if updated_data != credential_data: + log.info(f"Token已自动刷新: {filename}") + await storage_adapter.store_credential(filename, updated_data, mode=mode) + credential_data = updated_data + + # 获取访问令牌 + access_token = credential_data.get("access_token") or credential_data.get("token") + if not access_token: + raise HTTPException(status_code=400, detail="凭证中没有访问令牌") + + # 获取额度信息 + quota_info = await fetch_quota_info(access_token) + + if quota_info.get("success"): + return JSONResponse(content={ + "success": True, + "filename": filename, + "models": quota_info.get("models", {}) + }) + else: + return JSONResponse( + status_code=400, + content={ + "success": False, + "filename": filename, + "error": quota_info.get("error", "未知错误") + } + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"获取凭证额度失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"获取额度失败: {str(e)}") + + +@router.post("/configure-preview/{filename}") +async def configure_preview_channel( + filename: str, + token: str = Depends(verify_panel_token), + mode: str = "geminicli" +): + """ + 为 geminicli 凭证配置 preview 通道 + + 通过调用 Google Cloud API 设置 release_channel 为 EXPERIMENTAL + + Args: + filename: 凭证文件名 + mode: 凭证模式(仅支持 geminicli) + + Returns: + 配置结果信息 + """ + try: + mode = validate_mode(mode) + + # 只支持 geminicli 模式 + if mode != "geminicli": + raise HTTPException( + status_code=400, + detail="配置 preview 通道仅支持 geminicli 模式" + ) + + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 创建凭证对象并刷新 token(如果需要) + credentials = Credentials.from_dict(credential_data) + token_refreshed = await credentials.refresh_if_needed() + + if token_refreshed: + log.info(f"Token已自动刷新: {filename}") + credential_data = credentials.to_dict() + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 获取 access_token 和 project_id + access_token = credential_data.get("access_token") or credential_data.get("token") + project_id = credential_data.get("project_id", "") + + if not access_token: + raise HTTPException(status_code=400, detail="凭证中没有访问令牌") + if not project_id: + raise HTTPException(status_code=400, detail="凭证中没有项目ID") + + # 调用 Google Cloud API 配置 preview 通道 + # 根据文档,需要两个步骤: + # 1. 创建 Release Channel Setting (EXPERIMENTAL) + # 2. 创建 Setting Binding (绑定到目标项目) + from src.httpx_client import post_async + import uuid + + # 生成唯一的 ID + setting_id = f"preview-setting-{uuid.uuid4().hex[:8]}" + binding_id = f"preview-binding-{uuid.uuid4().hex[:8]}" + + base_url = f"https://cloudaicompanion.googleapis.com/v1/projects/{project_id}/locations/global" + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json" + } + + log.info(f"开始配置 preview 通道: {filename} (project_id={project_id})") + + # 步骤 1: 创建 Release Channel Setting + setting_url = f"{base_url}/releaseChannelSettings" + setting_response = await post_async( + url=setting_url, + json={"release_channel": "EXPERIMENTAL"}, + headers=headers, + params={"release_channel_setting_id": setting_id}, + timeout=30.0 + ) + + setting_status = setting_response.status_code + + if setting_status == 200 or setting_status == 201: + log.info(f"步骤 1/2: Release Channel Setting 创建成功 (setting_id={setting_id})") + elif setting_status == 409: + # Setting 已存在,继续下一步 + log.info(f"步骤 1/2: Release Channel Setting 已存在") + else: + # 步骤 1 失败 + error_text = setting_response.text if hasattr(setting_response, 'text') else "" + log.error(f"步骤 1/2 失败: {filename} - Status: {setting_status}, Error: {error_text}") + + return JSONResponse( + status_code=setting_status, + content={ + "success": False, + "filename": filename, + "preview": False, + "message": f"创建 Release Channel Setting 失败: HTTP {setting_status}", + "error": error_text, + "step": "create_setting" + } + ) + + # 步骤 2: 创建 Setting Binding (绑定到当前项目) + binding_url = f"{base_url}/releaseChannelSettings/{setting_id}/settingBindings" + binding_response = await post_async( + url=binding_url, + json={ + "target": f"projects/{project_id}", + "product": "GEMINI_CODE_ASSIST" + }, + headers=headers, + params={"setting_binding_id": binding_id}, + timeout=30.0 + ) + + binding_status = binding_response.status_code + + if binding_status == 200 or binding_status == 201: + await storage_adapter.update_credential_state(filename, { + "preview": True + }, mode=mode) + + log.info(f"步骤 2/2: Setting Binding 创建成功 - Preview 通道配置完成: {filename}") + + return JSONResponse(content={ + "success": True, + "filename": filename, + "preview": True, + "message": "Preview 通道配置成功,已将 preview 属性设置为 true", + "setting_id": setting_id, + "binding_id": binding_id + }) + elif binding_status == 409: + # Binding 已存在,说明已经配置过了 + await storage_adapter.update_credential_state(filename, { + "preview": True + }, mode=mode) + + log.info(f"步骤 2/2: Setting Binding 已存在 - Preview 通道已配置: {filename}") + + return JSONResponse(content={ + "success": True, + "filename": filename, + "preview": True, + "message": "Preview 通道配置已存在,已将 preview 属性设置为 true" + }) + else: + # 步骤 2 失败 + error_text = binding_response.text if hasattr(binding_response, 'text') else "" + log.error(f"步骤 2/2 失败: {filename} - Status: {binding_status}, Error: {error_text}") + + return JSONResponse( + status_code=binding_status, + content={ + "success": False, + "filename": filename, + "preview": False, + "message": f"创建 Setting Binding 失败: HTTP {binding_status}", + "error": error_text, + "step": "create_binding" + } + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"配置 preview 通道失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"配置失败: {str(e)}") + + +@router.post("/test/{filename}") +async def test_credential( + filename: str, + mode: str = "geminicli", + _token: str = Depends(verify_panel_token) +): + """ + 测试指定凭证是否可用 + + Args: + filename: 凭证文件名 + mode: 凭证模式(geminicli 或 antigravity) + + Returns: + 返回状态码: + - 200: 凭证可用 + - 429: 凭证被限流但有效 + - 其他: 凭证失败(返回实际错误码) + """ + try: + mode = validate_mode(mode) + + # 验证文件名 + if not filename.endswith(".json"): + raise HTTPException(status_code=400, detail="无效的文件名") + + storage_adapter = await get_storage_adapter() + + # 获取凭证数据 + credential_data = await storage_adapter.get_credential(filename, mode=mode) + if not credential_data: + raise HTTPException(status_code=404, detail="凭证不存在") + + # 创建凭证对象并尝试刷新 token(如果需要) + credentials = Credentials.from_dict(credential_data) + token_refreshed = await credentials.refresh_if_needed() + + # 如果 token 被刷新了,更新存储 + if token_refreshed: + log.info(f"Token已自动刷新: {filename} (mode={mode})") + credential_data = credentials.to_dict() + await storage_adapter.store_credential(filename, credential_data, mode=mode) + + # 获取访问令牌 + access_token = credential_data.get("access_token") or credential_data.get("token") + if not access_token: + raise HTTPException(status_code=400, detail="凭证中没有访问令牌") + + # 根据模式构造测试请求 + from src.httpx_client import post_async + + # 获取 project_id + project_id = credential_data.get("project_id", "") + if not project_id: + raise HTTPException(status_code=400, detail="凭证中没有项目ID") + + # 根据模式选择 API 端点和请求头 + # 对于 geminicli 模式,使用两次测试:gemini-2.5-flash 和 gemini-3-flash-preview + # 对于 antigravity 模式,只使用 gemini-2.5-flash + test_model = "gemini-2.5-flash" + + if mode == "antigravity": + api_base_url = await get_code_assist_endpoint() + from src.api.antigravity import build_antigravity_headers + headers = build_antigravity_headers(access_token, test_model) + else: + api_base_url = await get_code_assist_endpoint() + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "User-Agent": GEMINICLI_USER_AGENT, + } + + # 第一次测试:使用 gemini-2.5-flash + response = await post_async( + url=f"{api_base_url}/v1internal:generateContent", + json={ + "model": test_model, + "project": project_id, + "request": { + "contents": [{"role": "user", "parts": [{"text": "hi"}]}], + "generationConfig": {"maxOutputTokens": 1} + } + }, + headers=headers, + timeout=30.0 + ) + + # 返回实际的状态码和详细信息 + status_code = response.status_code + + if status_code == 200 or status_code == 429: + log.info(f"凭证测试成功: {filename} (mode={mode}, model={test_model}, status={status_code})") + # 测试成功时清除错误状态 + if status_code == 200: + await storage_adapter.update_credential_state(filename, { + "error_codes": [], + "error_messages": {} + }, mode=mode) + + # 如果是 geminicli 模式且第一次测试成功,继续测试 gemini-3-flash-preview + if mode == "geminicli": + preview_model = "gemini-3-flash-preview" + log.info(f"开始测试 preview 模型: {filename} (model={preview_model})") + + try: + preview_response = await post_async( + url=f"{api_base_url}/v1internal:generateContent", + json={ + "model": preview_model, + "project": project_id, + "request": { + "contents": [{"role": "user", "parts": [{"text": "hi"}]}], + "generationConfig": {"maxOutputTokens": 1} + } + }, + headers=headers, + timeout=30.0 + ) + + preview_status = preview_response.status_code + + if preview_status == 200 or preview_status == 429: + # preview 模型测试成功,设置 preview=True + log.info(f"Preview 模型测试成功: {filename} (status={preview_status})") + await storage_adapter.update_credential_state(filename, { + "preview": True + }, mode=mode) + elif preview_status == 404: + # preview 模型返回 404,说明不支持,设置 preview=False + log.warning(f"Preview 模型不支持: {filename} (status=404)") + await storage_adapter.update_credential_state(filename, { + "preview": False + }, mode=mode) + else: + # 其他错误,保持默认 preview 状态 + log.warning(f"Preview 模型测试失败: {filename} (status={preview_status})") + except Exception as e: + log.error(f"Preview 模型测试异常: {filename} - {e}") + + # 返回成功响应 + return JSONResponse( + status_code=status_code, + content={ + "success": True, + "status_code": status_code, + "message": "测试成功", + "filename": filename + } + ) + else: + log.warning(f"凭证测试失败: {filename} (mode={mode}, status={status_code})") + # 测试失败时保存错误码和错误消息(覆盖模式,只保存最新的一个错误) + try: + error_text = response.text if hasattr(response, 'text') else "" + + # 打印详细错误内容到日志 + log.error(f"凭证测试错误详情 - 文件: {filename}, 模式: {mode}, 状态码: {status_code}, 错误内容: {error_text}") + + # 使用覆盖模式保存错误(与 credential_manager 保持一致) + error_codes = [status_code] + error_messages = {str(status_code): error_text if error_text else f"HTTP {status_code}"} + + # 更新状态 + await storage_adapter.update_credential_state(filename, { + "error_codes": error_codes, + "error_messages": error_messages + }, mode=mode) + + log.info(f"已保存测试错误信息: {filename} - 错误码 {status_code}") + except Exception as e: + log.error(f"保存测试错误信息失败: {e}") + + # 返回错误响应,包含完整的错误信息 + error_text = response.text if hasattr(response, 'text') else "" + + return JSONResponse( + status_code=status_code, + content={ + "success": False, + "status_code": status_code, + "message": f"测试失败: HTTP {status_code}", + "error": error_text, + "filename": filename + } + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"测试凭证失败 {filename}: {e}") + raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}") diff --git a/src/panel/logs.py b/src/panel/logs.py new file mode 100644 index 0000000000000000000000000000000000000000..4dcf1d1223a40c80fcf5a746cbd93f61434c08f3 --- /dev/null +++ b/src/panel/logs.py @@ -0,0 +1,237 @@ +""" +日志路由模块 - 处理 /logs/* 相关的HTTP请求和WebSocket连接 +""" + +import asyncio +import datetime +import os + +from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, JSONResponse +from starlette.websockets import WebSocketState + +import config +from log import log +from src.utils import verify_panel_token +from .utils import ConnectionManager + + +# 创建路由器 +router = APIRouter(prefix="/logs", tags=["logs"]) + +# WebSocket连接管理器 +manager = ConnectionManager() + + +@router.post("/clear") +async def clear_logs(token: str = Depends(verify_panel_token)): + """清空日志文件""" + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 检查日志文件是否存在 + if os.path.exists(log_file_path): + try: + # 清空文件内容(保留文件),确保以UTF-8编码写入 + # 使用 with 确保文件正确关闭 + with open(log_file_path, "w", encoding="utf-8") as f: + f.write("") + f.flush() # 强制刷新到磁盘 + # with 退出时会自动关闭文件 + log.info(f"日志文件已清空: {log_file_path}") + + # 通知所有WebSocket连接日志已清空 + await manager.broadcast("--- 日志文件已清空 ---") + + return JSONResponse( + content={"message": f"日志文件已清空: {os.path.basename(log_file_path)}"} + ) + except Exception as e: + log.error(f"清空日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}") + else: + return JSONResponse(content={"message": "日志文件不存在"}) + + except Exception as e: + log.error(f"清空日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}") + + +@router.get("/download") +async def download_logs(token: str = Depends(verify_panel_token)): + """下载日志文件""" + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 检查日志文件是否存在 + if not os.path.exists(log_file_path): + raise HTTPException(status_code=404, detail="日志文件不存在") + + # 检查文件是否为空 + file_size = os.path.getsize(log_file_path) + if file_size == 0: + raise HTTPException(status_code=404, detail="日志文件为空") + + # 生成文件名(包含时间戳) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"gcli2api_logs_{timestamp}.txt" + + log.info(f"下载日志文件: {log_file_path}") + + return FileResponse( + path=log_file_path, + filename=filename, + media_type="text/plain", + headers={"Content-Disposition": f"attachment; filename={filename}"}, + ) + + except HTTPException: + raise + except Exception as e: + log.error(f"下载日志文件失败: {e}") + raise HTTPException(status_code=500, detail=f"下载日志文件失败: {str(e)}") + + +@router.websocket("/stream") +async def websocket_logs(websocket: WebSocket): + """WebSocket端点,用于实时日志流""" + # WebSocket 认证: 从查询参数获取 token + token = websocket.query_params.get("token") + + if not token: + await websocket.close(code=403, reason="Missing authentication token") + log.warning("WebSocket连接被拒绝: 缺少认证token") + return + + # 验证 token + try: + panel_password = await config.get_panel_password() + if token != panel_password: + await websocket.close(code=403, reason="Invalid authentication token") + log.warning("WebSocket连接被拒绝: token验证失败") + return + except Exception as e: + await websocket.close(code=1011, reason="Authentication error") + log.error(f"WebSocket认证过程出错: {e}") + return + + # 检查连接数限制 + if not await manager.connect(websocket): + return + + try: + # 直接使用环境变量获取日志文件路径 + log_file_path = os.getenv("LOG_FILE", "log.txt") + + # 发送初始日志(限制为最后50行,减少内存占用) + if os.path.exists(log_file_path): + try: + # 使用 with 确保文件正确关闭 + with open(log_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + # 只发送最后50行,减少初始内存消耗 + for line in lines[-50:]: + if line.strip(): + await websocket.send_text(line.strip()) + except Exception as e: + await websocket.send_text(f"Error reading log file: {e}") + log.error(f"WebSocket初始日志读取错误: {e}") + + # 监控日志文件变化 + last_size = os.path.getsize(log_file_path) if os.path.exists(log_file_path) else 0 + max_read_size = 8192 # 限制单次读取大小为8KB,防止大量日志造成内存激增 + check_interval = 2 # 增加检查间隔,减少CPU和I/O开销 + + # 创建后台任务监听客户端断开 + # 即使没有日志更新,receive_text() 也能即时感知断开 + async def listen_for_disconnect(): + try: + while True: + await websocket.receive_text() + except Exception: + pass + + listener_task = asyncio.create_task(listen_for_disconnect()) + + try: + while websocket.client_state == WebSocketState.CONNECTED: + # 使用 asyncio.wait 同时等待定时器和断开信号 + # timeout=check_interval 替代了 asyncio.sleep + done, pending = await asyncio.wait( + [listener_task], + timeout=check_interval, + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果监听任务结束(通常是因为连接断开),则退出循环 + if listener_task in done: + break + + if os.path.exists(log_file_path): + current_size = os.path.getsize(log_file_path) + if current_size > last_size: + # 限制读取大小,防止单次读取过多内容 + read_size = min(current_size - last_size, max_read_size) + + try: + # 使用 with 确保文件正确关闭,即使发生异常 + with open(log_file_path, "r", encoding="utf-8", errors="replace") as f: + f.seek(last_size) + new_content = f.read(read_size) + # with 退出时自动关闭文件句柄 + + # 处理编码错误的情况 + if not new_content: + last_size = current_size + continue + + # 分行发送,避免发送不完整的行 + lines = new_content.splitlines(keepends=True) + if lines: + # 如果最后一行没有换行符,保留到下次处理 + if not lines[-1].endswith("\n") and len(lines) > 1: + # 除了最后一行,其他都发送 + for line in lines[:-1]: + if line.strip(): + await websocket.send_text(line.rstrip()) + # 更新位置,但要退回最后一行的字节数 + last_size += len(new_content.encode("utf-8")) - len( + lines[-1].encode("utf-8") + ) + else: + # 所有行都发送 + for line in lines: + if line.strip(): + await websocket.send_text(line.rstrip()) + last_size += len(new_content.encode("utf-8")) + except UnicodeDecodeError as e: + # 遇到编码错误时,跳过这部分内容 + log.warning(f"WebSocket日志读取编码错误: {e}, 跳过部分内容") + last_size = current_size + except Exception as e: + await websocket.send_text(f"Error reading new content: {e}") + # 发生其他错误时,重置文件位置 + last_size = current_size + + # 如果文件被截断(如清空日志),重置位置 + elif current_size < last_size: + last_size = 0 + await websocket.send_text("--- 日志已清空 ---") + + finally: + # 确保清理监听任务 + if not listener_task.done(): + listener_task.cancel() + try: + await listener_task + except asyncio.CancelledError: + pass + + except WebSocketDisconnect: + pass + except Exception as e: + log.error(f"WebSocket logs error: {e}") + finally: + manager.disconnect(websocket) diff --git a/src/panel/root.py b/src/panel/root.py new file mode 100644 index 0000000000000000000000000000000000000000..b2139578d369a5c8e7c3fe0109e475bcc44b6343 --- /dev/null +++ b/src/panel/root.py @@ -0,0 +1,34 @@ +""" +根路由模块 - 处理控制面板主页 +""" + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse + +from log import log +from .utils import is_mobile_user_agent + + +# 创建路由器 +router = APIRouter(tags=["root"]) + + +@router.get("/", response_class=HTMLResponse) +async def serve_control_panel(request: Request): + """提供统一控制面板""" + try: + user_agent = request.headers.get("user-agent", "") + is_mobile = is_mobile_user_agent(user_agent) + + if is_mobile: + html_file_path = "front/control_panel_mobile.html" + else: + html_file_path = "front/control_panel.html" + + with open(html_file_path, "r", encoding="utf-8") as f: + html_content = f.read() + return HTMLResponse(content=html_content) + + except Exception as e: + log.error(f"加载控制面板页面失败: {e}") + raise HTTPException(status_code=500, detail="服务器内部错误") diff --git a/src/panel/utils.py b/src/panel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f75e47d3b620b13fc8129fb383ca1a7ae00d086d --- /dev/null +++ b/src/panel/utils.py @@ -0,0 +1,166 @@ +""" +共享工具模块 - 包含WebSocket连接管理、工具函数等 +""" + +import os +import time +from collections import deque +from typing import Set + +from fastapi import HTTPException, WebSocket +from starlette.websockets import WebSocketState + +import config +from log import log + + +# ============================================================================= +# WebSocket连接管理 +# ============================================================================= + + +class ConnectionManager: + def __init__(self, max_connections: int = 3): # 进一步降低最大连接数 + # 使用双端队列严格限制内存使用 + self.active_connections: deque = deque(maxlen=max_connections) + self.max_connections = max_connections + self._last_cleanup = 0 + self._cleanup_interval = 120 # 120秒清理一次死连接 + + async def connect(self, websocket: WebSocket): + # 自动清理死连接 + self._auto_cleanup() + + # 限制最大连接数,防止内存无限增长 + if len(self.active_connections) >= self.max_connections: + await websocket.close(code=1008, reason="Too many connections") + return False + + await websocket.accept() + self.active_connections.append(websocket) + log.debug(f"WebSocket连接建立,当前连接数: {len(self.active_connections)}") + return True + + def disconnect(self, websocket: WebSocket): + # 使用更高效的方式移除连接 + try: + self.active_connections.remove(websocket) + except ValueError: + pass # 连接已不存在 + log.debug(f"WebSocket连接断开,当前连接数: {len(self.active_connections)}") + + async def send_personal_message(self, message: str, websocket: WebSocket): + try: + await websocket.send_text(message) + except Exception: + self.disconnect(websocket) + + async def broadcast(self, message: str): + # 使用更高效的方式处理广播,避免索引操作 + dead_connections = [] + for conn in self.active_connections: + try: + await conn.send_text(message) + except Exception: + dead_connections.append(conn) + + # 批量移除死连接 + for dead_conn in dead_connections: + self.disconnect(dead_conn) + + def _auto_cleanup(self): + """自动清理死连接""" + current_time = time.time() + if current_time - self._last_cleanup > self._cleanup_interval: + self.cleanup_dead_connections() + self._last_cleanup = current_time + + def cleanup_dead_connections(self): + """清理已断开的连接""" + original_count = len(self.active_connections) + # 使用列表推导式过滤活跃连接,更高效 + alive_connections = deque( + [ + conn + for conn in self.active_connections + if hasattr(conn, "client_state") + and conn.client_state != WebSocketState.DISCONNECTED + ], + maxlen=self.max_connections, + ) + + self.active_connections = alive_connections + cleaned = original_count - len(self.active_connections) + if cleaned > 0: + log.debug(f"清理了 {cleaned} 个死连接,剩余连接数: {len(self.active_connections)}") + + +# ============================================================================= +# 工具函数 +# ============================================================================= + + +def is_mobile_user_agent(user_agent: str) -> bool: + """检测是否为移动设备用户代理""" + if not user_agent: + return False + + user_agent_lower = user_agent.lower() + mobile_keywords = [ + "mobile", + "android", + "iphone", + "ipad", + "ipod", + "blackberry", + "windows phone", + "samsung", + "htc", + "motorola", + "nokia", + "palm", + "webos", + "opera mini", + "opera mobi", + "fennec", + "minimo", + "symbian", + "psp", + "nintendo", + "tablet", + ] + + return any(keyword in user_agent_lower for keyword in mobile_keywords) + + +def validate_mode(mode: str = "geminicli") -> str: + """ + 验证 mode 参数 + + Args: + mode: 模式字符串 ("geminicli" 或 "antigravity") + + Returns: + str: 验证后的 mode 字符串 + + Raises: + HTTPException: 如果 mode 参数无效 + """ + if mode not in ["geminicli", "antigravity"]: + raise HTTPException( + status_code=400, + detail=f"无效的 mode 参数: {mode},只支持 'geminicli' 或 'antigravity'" + ) + return mode + + +def get_env_locked_keys() -> Set: + """获取被环境变量锁定的配置键集合""" + env_locked_keys = set() + + # 使用 config.py 中统一维护的映射表 + for env_key, config_key in config.ENV_MAPPINGS.items(): + if os.getenv(env_key): + env_locked_keys.add(config_key) + + return env_locked_keys diff --git a/src/panel/version.py b/src/panel/version.py new file mode 100644 index 0000000000000000000000000000000000000000..88c87264096dd1a92342f16445fcba173c46ae28 --- /dev/null +++ b/src/panel/version.py @@ -0,0 +1,107 @@ +""" +版本信息路由模块 - 处理 /version/* 相关的HTTP请求 +""" + +import os + +from fastapi import APIRouter +from fastapi.responses import JSONResponse + +from log import log + + +# 创建路由器 +router = APIRouter(prefix="/version", tags=["version"]) + + +@router.get("/info") +async def get_version_info(check_update: bool = False): + """ + 获取当前版本信息 - 从version.txt读取 + 可选参数 check_update: 是否检查GitHub上的最新版本 + """ + try: + # 获取项目根目录 + project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + version_file = os.path.join(project_root, "version.txt") + + # 读取version.txt + if not os.path.exists(version_file): + return JSONResponse({ + "success": False, + "error": "version.txt文件不存在" + }) + + version_data = {} + with open(version_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + version_data[key] = value + + # 检查必要字段 + if 'short_hash' not in version_data: + return JSONResponse({ + "success": False, + "error": "version.txt格式错误" + }) + + response_data = { + "success": True, + "version": version_data.get('short_hash', 'unknown'), + "full_hash": version_data.get('full_hash', ''), + "message": version_data.get('message', ''), + "date": version_data.get('date', '') + } + + # 如果需要检查更新 + if check_update: + try: + from src.httpx_client import get_async + + # 直接获取GitHub上的version.txt文件 + github_version_url = "https://raw.githubusercontent.com/su-kaka/gcli2api/refs/heads/master/version.txt" + + # 使用统一的httpx客户端 + resp = await get_async(github_version_url, timeout=10.0) + + if resp.status_code == 200: + # 解析远程version.txt + remote_version_data = {} + for line in resp.text.strip().split('\n'): + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + remote_version_data[key] = value + + latest_hash = remote_version_data.get('full_hash', '') + latest_short_hash = remote_version_data.get('short_hash', '') + current_hash = version_data.get('full_hash', '') + + has_update = (current_hash != latest_hash) if current_hash and latest_hash else None + + response_data['check_update'] = True + response_data['has_update'] = has_update + response_data['latest_version'] = latest_short_hash + response_data['latest_hash'] = latest_hash + response_data['latest_message'] = remote_version_data.get('message', '') + response_data['latest_date'] = remote_version_data.get('date', '') + else: + # GitHub获取失败,但不影响基本版本信息 + response_data['check_update'] = False + response_data['update_error'] = f"GitHub返回错误: {resp.status_code}" + + except Exception as e: + log.debug(f"检查更新失败: {e}") + response_data['check_update'] = False + response_data['update_error'] = str(e) + + return JSONResponse(response_data) + + except Exception as e: + log.error(f"获取版本信息失败: {e}") + return JSONResponse({ + "success": False, + "error": str(e) + }) diff --git a/src/router/antigravity/anthropic.py b/src/router/antigravity/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..e608b8a60fddc83e75c22a9e48ae3d972966056d --- /dev/null +++ b/src/router/antigravity/anthropic.py @@ -0,0 +1,614 @@ +""" +Anthropic Router - Handles Anthropic/Claude format API requests via Antigravity +通过Antigravity处理Anthropic/Claude格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts, get_api_password +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_anthropic_fake_stream_chunks, + create_anthropic_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import ClaudeRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + +# 本地模块 - Token估算 +from src.token_estimator import estimate_input_tokens + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1/messages") +async def messages( + claude_request: ClaudeRequest, + _token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic/Claude格式的消息请求(流式和非流式) + + Args: + claude_request: Anthropic/Claude格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[ANTIGRAVITY-ANTHROPIC] Request for model: {claude_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(claude_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="anthropic"): + response = create_health_check_response(format="anthropic") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(claude_request.model) + use_anti_truncation = is_anti_truncation_model(claude_request.model) + real_model = get_base_model_from_feature_model(claude_request.model) + + # 获取流式标志 + is_streaming = claude_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.anthropic2gemini import anthropic_to_gemini_request + gemini_dict = await anthropic_to_gemini_request(normalized_dict) + + # anthropic_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_response = gemini_to_anthropic_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=anthropic_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.api.antigravity import non_stream_request + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"Anthropic fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"Anthropic extracted content: {content}") + log.debug(f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Anthropic extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_anthropic_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "type": "error", + "error": { + "type": "api_error", + "message": str(e) + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.antigravity import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + from fastapi import Response + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()), + ) + + # 包装以确保是bytes流 + async def bytes_wrapper(): + async for chunk in processor.process_stream(): + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 直接将整个流传递给转换器 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + bytes_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.antigravity import stream_request + from fastapi import Response + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + # 包装流式生成器以处理错误响应 + async def gemini_chunk_wrapper(): + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 错误响应,不进行转换,直接传递 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + gemini_error = json.loads(error_content.decode('utf-8')) + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + else: + # 确保是bytes类型 + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 使用转换器处理整个流 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + gemini_chunk_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + + +@router.post("/antigravity/v1/messages/count_tokens") +async def count_tokens( + request: Request, + _token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic格式的token计数请求 + + Args: + request: FastAPI请求对象 + _token: Bearer认证令牌(由Depends验证) + + Returns: + JSONResponse: 包含input_tokens的响应 + """ + try: + payload = await request.json() + except Exception as e: + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": f"JSON 解析失败: {str(e)}"}} + ) + + if not isinstance(payload, dict): + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": "请求体必须为 JSON object"}} + ) + + if not payload.get("model") or not isinstance(payload.get("messages"), list): + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": "缺少必填字段:model / messages"}} + ) + + try: + client_host = request.client.host if request.client else "unknown" + client_port = request.client.port if request.client else "unknown" + except Exception: + client_host = "unknown" + client_port = "unknown" + + thinking_present = "thinking" in payload + thinking_value = payload.get("thinking") + thinking_summary = None + if thinking_present: + if isinstance(thinking_value, dict): + thinking_summary = { + "type": thinking_value.get("type"), + "budget_tokens": thinking_value.get("budget_tokens"), + } + else: + thinking_summary = thinking_value + + user_agent = request.headers.get("user-agent", "") + log.info( + f"[ANTIGRAVITY-ANTHROPIC] /messages/count_tokens 收到请求: client={client_host}:{client_port}, " + f"model={payload.get('model')}, messages={len(payload.get('messages') or [])}, " + f"thinking_present={thinking_present}, thinking={thinking_summary}, ua={user_agent}" + ) + + # 简单估算 + input_tokens = 0 + try: + input_tokens = estimate_input_tokens(payload) + except Exception as e: + log.error(f"[ANTIGRAVITY-ANTHROPIC] token 估算失败: {e}") + + return JSONResponse(content={"input_tokens": input_tokens}) + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Anthropic路由的流式和非流式响应 + 运行方式: python src/router/antigravity/anthropic.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Anthropic Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Anthropic格式) + test_request_body = { + "model": "gemini-2.5-flash", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /antigravity/v1/messages)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/messages", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /antigravity/v1/messages)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/messages", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("event: ") or chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /antigravity/v1/messages with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/messages", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: ") or line.startswith("event: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + elif event_line.startswith("data: "): + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + event_type = json_data.get("type", "unknown") + print(f" 事件 #{event_idx}: type={event_type}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/antigravity/gemini.py b/src/router/antigravity/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..8398463c83d18acf29c2a1f592755756772d9589 --- /dev/null +++ b/src/router/antigravity/gemini.py @@ -0,0 +1,685 @@ +""" +Gemini Router - Handles native Gemini format API requests (Antigravity backend) +处理原生Gemini格式请求的路由模块(Antigravity后端) +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Path, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + authenticate_gemini_flexible, + is_fake_streaming_model +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_gemini_fake_stream_chunks, + create_gemini_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import GeminiRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1beta/models/{model:path}:generateContent") +@router.post("/antigravity/v1/models/{model:path}:generateContent") +async def generate_content( + gemini_request: "GeminiRequest", + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的内容生成请求(非流式) + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[ANTIGRAVITY] Non-streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="gemini"): + response = create_health_check_response(format="gemini") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + normalized_dict = await normalize_gemini_request(normalized_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_dict.pop("model"), + "request": normalized_dict + } + + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 解包装响应:Antigravity API 可能返回的格式有额外的 response 包装层 + # 需要提取并返回标准 Gemini 格式 + # 保持 Gemini 原生的 inlineData 格式,不进行 Markdown 转换 + try: + if response.status_code == 200: + response_data = json.loads(response.body if hasattr(response, 'body') else response.content) + # 如果有 response 包装,解包装它 + if "response" in response_data: + unwrapped_data = response_data["response"] + return JSONResponse(content=unwrapped_data) + # 错误响应或没有 response 字段,直接返回 + return response + except Exception as e: + log.warning(f"Failed to unwrap response: {e}, returning original response") + return response + +@router.post("/antigravity/v1beta/models/{model:path}:streamGenerateContent") +@router.post("/antigravity/v1/models/{model:path}:streamGenerateContent") +async def stream_generate_content( + gemini_request: GeminiRequest, + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的流式内容生成请求 + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[ANTIGRAVITY] Streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(model) + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.antigravity import non_stream_request + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + response_data = json.loads(response_body) + log.debug(f"Gemini fake stream response data: {response_data}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in response_data: + log.error(f"Fake streaming got error in response body: {response_data['error']}") + yield f"data: {json.dumps(response_data)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data) + + log.debug(f"Gemini extracted content: {content}") + log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Gemini extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield original response") + # 直接yield原始响应,不进行包装 + yield f"data: {response_body}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.converter.anti_truncation import apply_anti_truncation + from src.api.antigravity import stream_request + from fastapi import Response + + # 先进行基础标准化 + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model") if "model" in normalized_req else real_model, + "request": normalized_req + } + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()), + ) + + # 迭代 process_stream() 生成器,并展开 response 包装 + async for chunk in processor.process_stream(): + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[ANTIGRAVITY-ANTI-TRUNCATION] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + else: + # 其他类型,直接返回 + yield chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.antigravity import stream_request + from fastapi import Response + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 所有流式请求都使用非 native 模式(SSE格式)并展开 response 包装 + log.debug(f"[ANTIGRAVITY] 使用非native模式,将展开response包装") + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + # 展开 response 包装 + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + except Exception: + error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}} + log.error(f"[ANTIGRAVITY STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}") + yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + + # 处理SSE格式的chunk + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[ANTIGRAVITY] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + +@router.post("/antigravity/v1beta/models/{model:path}:countTokens") +@router.post("/antigravity/v1/models/{model:path}:countTokens") +async def count_tokens( + request: Request = None, + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 模拟Gemini格式的token计数 + + 使用简单的启发式方法:大约4字符=1token + """ + + try: + request_data = await request.json() + except Exception as e: + log.error(f"Failed to parse JSON request: {e}") + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + # 简单的token计数模拟 - 基于文本长度估算 + total_tokens = 0 + + # 如果有contents字段 + if "contents" in request_data: + for content in request_data["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + # 简单估算:大约4字符=1token + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 如果有generateContentRequest字段 + elif "generateContentRequest" in request_data: + gen_request = request_data["generateContentRequest"] + if "contents" in gen_request: + for content in gen_request["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 返回Gemini格式的响应 + return JSONResponse(content={"totalTokens": total_tokens}) + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Gemini路由的流式和非流式响应 + 运行方式: python src/router/antigravity/gemini.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Gemini Router (Antigravity Backend) 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Gemini格式) + test_request_body = { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + + # 测试API密钥(模拟) + test_api_key = "pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:generateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/models/gemini-2.5-flash:generateContent", + json=test_request_body, + params={"key": test_api_key} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试3】流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试4】假流式请求 (POST /antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取text内容 + text = json_data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "") + finish_reason = json_data.get("candidates", [{}])[0].get("finishReason") + print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + def test_anti_truncation_stream_request(): + """测试流式抗截断请求""" + print("\n" + "=" * 80) + print("【测试5】流式抗截断请求 (POST /antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式抗截断响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + # 测试流式抗截断请求 + test_anti_truncation_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/antigravity/model_list.py b/src/router/antigravity/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..f27c81822dbb26b1dc2d6939704ecd193b9f4b25 --- /dev/null +++ b/src/router/antigravity/model_list.py @@ -0,0 +1,105 @@ +""" +Antigravity Model List Router - Handles model list requests +Antigravity 模型列表路由 - 处理模型列表请求 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 第三方库 +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + authenticate_flexible +) + +# 本地模块 - API +from src.api.antigravity import fetch_available_models + +# 本地模块 - 基础路由工具 +from src.router.base_router import create_gemini_model_list, create_openai_model_list +from src.models import model_to_dict +from log import log + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== 辅助函数 ==================== + +async def get_antigravity_models_with_features(): + """ + 获取 Antigravity 模型列表并添加功能前缀 + + Returns: + 带有功能前缀的模型列表 + """ + # 从 API 获取基础模型列表 + base_models_data = await fetch_available_models() + + if not base_models_data: + log.warning("[ANTIGRAVITY MODEL LIST] 无法获取模型列表,返回空列表") + return [] + + # 提取模型 ID + base_model_ids = [model['id'] for model in base_models_data if 'id' in model] + + # 添加功能前缀 + models = [] + for base_model in base_model_ids: + # 基础模型 + models.append(base_model) + + # 假流式模型 (前缀格式) + models.append(f"假流式/{base_model}") + + # 流式抗截断模型 (仅在流式传输时有效,前缀格式) + models.append(f"流式抗截断/{base_model}") + + log.info(f"[ANTIGRAVITY MODEL LIST] 生成了 {len(models)} 个模型(包含功能前缀)") + return models + + +# ==================== API 路由 ==================== + +@router.get("/antigravity/v1beta/models") +async def list_gemini_models(token: str = Depends(authenticate_flexible)): + """ + 返回 Gemini 格式的模型列表 + + 从 src.api.antigravity.fetch_available_models 动态获取模型列表 + 并添加假流式和流式抗截断前缀 + """ + models = await get_antigravity_models_with_features() + log.info("[ANTIGRAVITY MODEL LIST] 返回 Gemini 格式") + return JSONResponse(content=create_gemini_model_list( + models, + base_name_extractor=get_base_model_from_feature_model + )) + + +@router.get("/antigravity/v1/models") +async def list_openai_models(token: str = Depends(authenticate_flexible)): + """ + 返回 OpenAI 格式的模型列表 + + 从 src.api.antigravity.fetch_available_models 动态获取模型列表 + 并添加假流式和流式抗截断前缀 + """ + models = await get_antigravity_models_with_features() + log.info("[ANTIGRAVITY MODEL LIST] 返回 OpenAI 格式") + model_list = create_openai_model_list(models, owned_by="google") + return JSONResponse(content={ + "object": "list", + "data": [model_to_dict(model) for model in model_list.data] + }) diff --git a/src/router/antigravity/openai.py b/src/router/antigravity/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..f914bd8735e7d0c7c3b618f6d180394f07ffdab6 --- /dev/null +++ b/src/router/antigravity/openai.py @@ -0,0 +1,590 @@ +""" +OpenAI Router - Handles OpenAI format API requests via Antigravity +通过Antigravity处理OpenAI格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_openai_fake_stream_chunks, + create_openai_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import OpenAIChatCompletionRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/antigravity/v1/chat/completions") +async def chat_completions( + openai_request: OpenAIChatCompletionRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理OpenAI格式的聊天完成请求(流式和非流式) + + Args: + openai_request: OpenAI格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[ANTIGRAVITY-OPENAI] Request for model: {openai_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(openai_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="openai"): + response = create_health_check_response(format="openai") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(openai_request.model) + use_anti_truncation = is_anti_truncation_model(openai_request.model) + real_model = get_base_model_from_feature_model(openai_request.model) + + # 获取流式标志 + is_streaming = openai_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.openai2gemini import convert_openai_to_gemini_request + gemini_dict = await convert_openai_to_gemini_request(normalized_dict) + + # convert_openai_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 antigravity 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.antigravity import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_response = convert_gemini_to_openai_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=openai_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.api.antigravity import non_stream_request + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"OpenAI fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"OpenAI extracted content: {content}") + log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"OpenAI extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "id": "error", + "object": "chat.completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": real_model, + "choices": [{ + "index": 0, + "delta": {"content": f"Error: {str(e)}"}, + "finish_reason": "error" + }] + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.antigravity import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from fastapi import Response + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()), + ) + + # 转换为 OpenAI 格式 + import uuid + response_id = str(uuid.uuid4()) + + # 直接迭代 process_stream() 生成器,并转换为 OpenAI 格式 + async for chunk in processor.process_stream(): + if not chunk: + continue + + # 解析 Gemini SSE 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析 "data: {...}" 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.antigravity import stream_request + from fastapi import Response + import uuid + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + response_id = str(uuid.uuid4()) + + # yield所有数据,处理可能的错误Response + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + gemini_error = json.loads(error_content.decode('utf-8')) + # 转换为 OpenAI 格式错误 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + else: + # 正常的bytes数据,转换为 OpenAI 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析并转换 Gemini chunk 为 OpenAI 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示OpenAI路由的流式和非流式响应 + 运行方式: python src/router/antigravity/openai.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("OpenAI Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (OpenAI格式) + test_request_body = { + "model": "gemini-2.5-flash", + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /antigravity/v1/chat/completions)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/antigravity/v1/chat/completions", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /antigravity/v1/chat/completions)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/chat/completions", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /antigravity/v1/chat/completions with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/antigravity/v1/chat/completions", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取content内容 + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "") + finish_reason = json_data.get("choices", [{}])[0].get("finish_reason") + print(f" 事件 #{event_idx}: content={repr(content[:50])}{'...' if len(content) > 50 else ''}, finish_reason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/base_router.py b/src/router/base_router.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d2b37d76d003a34510b7454eb5254bfa0a6cf8 --- /dev/null +++ b/src/router/base_router.py @@ -0,0 +1,74 @@ +""" +Base Router - 共用的路由基础功能 +提供模型列表处理、通用响应等共同功能 +""" + +from typing import List + +from src.models import Model, ModelList + +def create_openai_model_list( + model_ids: List[str], + owned_by: str = "google" +) -> ModelList: + """ + 创建OpenAI格式的模型列表 + + Args: + model_ids: 模型ID列表 + owned_by: 模型所有者 + + Returns: + ModelList对象 + """ + from datetime import datetime, timezone + current_timestamp = int(datetime.now(timezone.utc).timestamp()) + + models = [ + Model( + id=model_id, + object='model', + created=current_timestamp, + owned_by=owned_by + ) + for model_id in model_ids + ] + + return ModelList(data=models) + + +def create_gemini_model_list( + model_ids: List[str], + base_name_extractor=None +) -> dict: + """ + 创建Gemini格式的模型列表 + + Args: + model_ids: 模型ID列表 + base_name_extractor: 可选的基础模型名提取函数 + + Returns: + 包含模型列表的字典 + """ + gemini_models = [] + + for model_id in model_ids: + base_model = model_id + if base_name_extractor: + try: + base_model = base_name_extractor(model_id) + except Exception: + pass + + model_info = { + "name": f"models/{model_id}", + "baseModelId": base_model, + "version": "001", + "displayName": model_id, + "description": f"Gemini {base_model} model", + "supportedGenerationMethods": ["generateContent", "streamGenerateContent"], + } + gemini_models.append(model_info) + + return {"models": gemini_models} \ No newline at end of file diff --git a/src/router/geminicli/anthropic.py b/src/router/geminicli/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f9e4dc2bb0bb1e9c709060e6c10d1a3b6ab8d7 --- /dev/null +++ b/src/router/geminicli/anthropic.py @@ -0,0 +1,614 @@ +""" +Anthropic Router - Handles Anthropic/Claude format API requests via GeminiCLI +通过GeminiCLI处理Anthropic/Claude格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_anthropic_fake_stream_chunks, + create_anthropic_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import ClaudeRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + +# 本地模块 - Token估算 +from src.token_estimator import estimate_input_tokens + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1/messages") +async def messages( + claude_request: ClaudeRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic/Claude格式的消息请求(流式和非流式) + + Args: + claude_request: Anthropic/Claude格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[GEMINICLI-ANTHROPIC] Request for model: {claude_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(claude_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="anthropic"): + response = create_health_check_response(format="anthropic") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(claude_request.model) + use_anti_truncation = is_anti_truncation_model(claude_request.model) + real_model = get_base_model_from_feature_model(claude_request.model) + + # 获取流式标志 + is_streaming = claude_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.anthropic2gemini import anthropic_to_gemini_request + gemini_dict = await anthropic_to_gemini_request(normalized_dict) + + # anthropic_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_response = gemini_to_anthropic_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=anthropic_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.api.geminicli import non_stream_request + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"Anthropic fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 Anthropic 格式 + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"Anthropic extracted content: {content}") + log.debug(f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Anthropic extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_anthropic_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "type": "error", + "error": { + "type": "api_error", + "message": str(e) + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.geminicli import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + from fastapi import Response + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=True, + ) + + # 包装以确保是bytes流 + async def bytes_wrapper(): + async for chunk in processor.process_stream(): + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 直接将整个流传递给转换器 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + bytes_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.geminicli import stream_request + from fastapi import Response + from src.converter.anthropic2gemini import gemini_stream_to_anthropic_stream + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + # 包装流式生成器以处理错误响应 + async def gemini_chunk_wrapper(): + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 错误响应,不进行转换,直接传递 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + gemini_error = json.loads(error_content.decode('utf-8')) + from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + else: + # 确保是bytes类型 + if isinstance(chunk, str): + yield chunk.encode('utf-8') + else: + yield chunk + + # 使用转换器处理整个流 + async for anthropic_chunk in gemini_stream_to_anthropic_stream( + gemini_chunk_wrapper(), + real_model, + 200 + ): + if anthropic_chunk: + yield anthropic_chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + + +@router.post("/v1/messages/count_tokens") +async def count_tokens( + request: Request, + _token: str = Depends(authenticate_bearer) +): + """ + 处理Anthropic格式的token计数请求 + + Args: + request: FastAPI请求对象 + _token: Bearer认证令牌(由Depends验证) + + Returns: + JSONResponse: 包含input_tokens的响应 + """ + try: + payload = await request.json() + except Exception as e: + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": f"JSON 解析失败: {str(e)}"}} + ) + + if not isinstance(payload, dict): + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": "请求体必须为 JSON object"}} + ) + + if not payload.get("model") or not isinstance(payload.get("messages"), list): + return JSONResponse( + status_code=400, + content={"type": "error", "error": {"type": "invalid_request_error", "message": "缺少必填字段:model / messages"}} + ) + + try: + client_host = request.client.host if request.client else "unknown" + client_port = request.client.port if request.client else "unknown" + except Exception: + client_host = "unknown" + client_port = "unknown" + + thinking_present = "thinking" in payload + thinking_value = payload.get("thinking") + thinking_summary = None + if thinking_present: + if isinstance(thinking_value, dict): + thinking_summary = { + "type": thinking_value.get("type"), + "budget_tokens": thinking_value.get("budget_tokens"), + } + else: + thinking_summary = thinking_value + + user_agent = request.headers.get("user-agent", "") + log.info( + f"[GEMINICLI-ANTHROPIC] /messages/count_tokens 收到请求: client={client_host}:{client_port}, " + f"model={payload.get('model')}, messages={len(payload.get('messages') or [])}, " + f"thinking_present={thinking_present}, thinking={thinking_summary}, ua={user_agent}" + ) + + # 简单估算 + input_tokens = 0 + try: + input_tokens = estimate_input_tokens(payload) + except Exception as e: + log.error(f"[GEMINICLI-ANTHROPIC] token 估算失败: {e}") + + return JSONResponse(content={"input_tokens": input_tokens}) + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Anthropic路由的流式和非流式响应 + 运行方式: python src/router/geminicli/anthropic.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Anthropic Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Anthropic格式) + test_request_body = { + "model": "gemini-2.5-flash", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /v1/messages)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/messages", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /v1/messages)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/messages", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("event: ") or chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /v1/messages with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/messages", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: ") or line.startswith("event: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + elif event_line.startswith("data: "): + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + event_type = json_data.get("type", "unknown") + print(f" 事件 #{event_idx}: type={event_type}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/geminicli/gemini.py b/src/router/geminicli/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..cd21de0330ebbc6703a02ab15c4e4185bb8ffdae --- /dev/null +++ b/src/router/geminicli/gemini.py @@ -0,0 +1,683 @@ +""" +Gemini Router - Handles native Gemini format API requests +处理原生Gemini格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException, Path, Request +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + authenticate_gemini_flexible, + is_fake_streaming_model +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_gemini_fake_stream_chunks, + create_gemini_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import GeminiRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1beta/models/{model:path}:generateContent") +@router.post("/v1/models/{model:path}:generateContent") +async def generate_content( + gemini_request: "GeminiRequest", + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的内容生成请求(非流式) + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[GEMINICLI] Non-streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="gemini"): + response = create_health_check_response(format="gemini") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + normalized_dict = await normalize_gemini_request(normalized_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_dict.pop("model"), + "request": normalized_dict + } + + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 解包装响应:GeminiCli API 返回的格式有额外的 response 包装层 + # 需要提取 response.response 并返回标准 Gemini 格式 + try: + if response.status_code == 200: + response_data = json.loads(response.body if hasattr(response, 'body') else response.content) + # 如果有 response 包装,解包装它 + if "response" in response_data: + unwrapped_data = response_data["response"] + return JSONResponse(content=unwrapped_data) + # 错误响应或没有 response 字段,直接返回 + return response + except Exception as e: + log.warning(f"Failed to unwrap response: {e}, returning original response") + return response + +@router.post("/v1beta/models/{model:path}:streamGenerateContent") +@router.post("/v1/models/{model:path}:streamGenerateContent") +async def stream_generate_content( + gemini_request: GeminiRequest, + model: str = Path(..., description="Model name"), + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 处理Gemini格式的流式内容生成请求 + + Args: + gemini_request: Gemini格式的请求体 + model: 模型名称 + api_key: API 密钥 + """ + log.debug(f"[GEMINICLI] Streaming request for model: {model}") + + # 转换为字典 + normalized_dict = model_to_dict(gemini_request) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(model) + use_anti_truncation = is_anti_truncation_model(model) + real_model = get_base_model_from_feature_model(model) + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.geminicli import non_stream_request + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + response_data = json.loads(response_body) + log.debug(f"Gemini fake stream response data: {response_data}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in response_data: + log.error(f"Fake streaming got error in response body: {response_data['error']}") + yield f"data: {json.dumps(response_data)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data) + + log.debug(f"Gemini extracted content: {content}") + log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"Gemini extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield original response") + # 直接yield原始响应,不进行包装 + yield f"data: {response_body}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.converter.anti_truncation import apply_anti_truncation + from src.api.geminicli import stream_request + from fastapi import Response + + # 先进行基础标准化 + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model") if "model" in normalized_req else real_model, + "request": normalized_req + } + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=True, + ) + + # 迭代 process_stream() 生成器,并展开 response 包装 + async for chunk in processor.process_stream(): + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[GEMINICLI-ANTI-TRUNCATION] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + else: + # 其他类型,直接返回 + yield chunk + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.converter.gemini_fix import normalize_gemini_request + from src.api.geminicli import stream_request + from fastapi import Response + + normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": normalized_req.pop("model"), + "request": normalized_req + } + + # 所有流式请求都使用非 native 模式(SSE格式)并展开 response 包装 + log.debug(f"[GEMINICLI] 使用非native模式,将展开response包装") + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + # 展开 response 包装 + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + except Exception: + error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}} + log.error(f"[GEMINICLI STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}") + # 以SSE格式返回错误,并以[DONE]结束 + yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + + # 处理SSE格式的chunk + if isinstance(chunk, (str, bytes)): + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 解析并展开 response 包装 + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + + # 跳过 [DONE] 标记 + if json_str == "[DONE]": + yield chunk + continue + + try: + # 解析JSON + data = json.loads(json_str) + + # 展开 response 包装 + if "response" in data and "candidates" not in data: + log.debug(f"[GEMINICLI] 展开response包装") + unwrapped_data = data["response"] + # 重新构建SSE格式 + yield f"data: {json.dumps(unwrapped_data, ensure_ascii=False)}\n\n".encode('utf-8') + else: + # 已经是展开的格式,直接返回 + yield chunk + except json.JSONDecodeError: + # JSON解析失败,直接返回原始chunk + yield chunk + else: + # 不是SSE格式,直接返回 + yield chunk + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + +@router.post("/v1beta/models/{model:path}:countTokens") +@router.post("/v1/models/{model:path}:countTokens") +async def count_tokens( + request: Request = None, + api_key: str = Depends(authenticate_gemini_flexible), +): + """ + 模拟Gemini格式的token计数 + + 使用简单的启发式方法:大约4字符=1token + """ + + try: + request_data = await request.json() + except Exception as e: + log.error(f"Failed to parse JSON request: {e}") + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + # 简单的token计数模拟 - 基于文本长度估算 + total_tokens = 0 + + # 如果有contents字段 + if "contents" in request_data: + for content in request_data["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + # 简单估算:大约4字符=1token + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 如果有generateContentRequest字段 + elif "generateContentRequest" in request_data: + gen_request = request_data["generateContentRequest"] + if "contents" in gen_request: + for content in gen_request["contents"]: + if "parts" in content: + for part in content["parts"]: + if "text" in part: + text_length = len(part["text"]) + total_tokens += max(1, text_length // 4) + + # 返回Gemini格式的响应 + return JSONResponse(content={"totalTokens": total_tokens}) + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示Gemini路由的流式和非流式响应 + 运行方式: python src/router/geminicli/gemini.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("Gemini Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (Gemini格式) + test_request_body = { + "contents": [ + { + "role": "user", + "parts": [{"text": "Hello, tell me a joke in one sentence."}] + } + ] + } + + # 测试API密钥(模拟) + test_api_key = "pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试2】非流式请求 (POST /v1/models/gemini-2.5-flash:generateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/models/gemini-2.5-flash:generateContent", + json=test_request_body, + params={"key": test_api_key} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试3】流式请求 (POST /v1/models/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试4】假流式请求 (POST /v1/models/假流式/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/假流式/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取text内容 + text = json_data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "") + finish_reason = json_data.get("candidates", [{}])[0].get("finishReason") + print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + def test_anti_truncation_stream_request(): + """测试流式抗截断请求""" + print("\n" + "=" * 80) + print("【测试5】流式抗截断请求 (POST /v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式抗截断响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent", + json=test_request_body, + params={"key": test_api_key} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + # 测试流式抗截断请求 + test_anti_truncation_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/geminicli/model_list.py b/src/router/geminicli/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..333a987a8f2d2b99155436ad2bc35349bd2e325f --- /dev/null +++ b/src/router/geminicli/model_list.py @@ -0,0 +1,66 @@ +""" +Gemini CLI Model List Router - Handles model list requests +Gemini CLI 模型列表路由 - 处理模型列表请求 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 第三方库 +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse + +# 本地模块 - 工具和认证 +from src.utils import ( + get_available_models, + get_base_model_from_feature_model, + authenticate_flexible +) + +# 本地模块 - 基础路由工具 +from src.router.base_router import create_gemini_model_list, create_openai_model_list +from src.models import model_to_dict +from log import log + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.get("/v1beta/models") +async def list_gemini_models(token: str = Depends(authenticate_flexible)): + """ + 返回 Gemini 格式的模型列表 + + 使用 create_gemini_model_list 工具函数创建标准格式 + """ + models = get_available_models("gemini") + log.info("[GEMINICLI MODEL LIST] 返回 Gemini 格式") + return JSONResponse(content=create_gemini_model_list( + models, + base_name_extractor=get_base_model_from_feature_model + )) + + +@router.get("/v1/models") +async def list_openai_models(token: str = Depends(authenticate_flexible)): + """ + 返回 OpenAI 格式的模型列表 + + 使用 create_openai_model_list 工具函数创建标准格式 + """ + models = get_available_models("gemini") + log.info("[GEMINICLI MODEL LIST] 返回 OpenAI 格式") + model_list = create_openai_model_list(models, owned_by="google") + return JSONResponse(content={ + "object": "list", + "data": [model_to_dict(model) for model in model_list.data] + }) diff --git a/src/router/geminicli/openai.py b/src/router/geminicli/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..353b725005c16269000616bbab69b2be80d88a4d --- /dev/null +++ b/src/router/geminicli/openai.py @@ -0,0 +1,590 @@ +""" +OpenAI Router - Handles OpenAI format API requests via GeminiCLI +通过GeminiCLI处理OpenAI格式请求的路由模块 +""" + +import sys +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).resolve().parent.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 标准库 +import asyncio +import json + +# 第三方库 +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse + +# 本地模块 - 配置和日志 +from config import get_anti_truncation_max_attempts +from log import log + +# 本地模块 - 工具和认证 +from src.utils import ( + get_base_model_from_feature_model, + is_anti_truncation_model, + is_fake_streaming_model, + authenticate_bearer, +) + +# 本地模块 - 转换器(假流式需要) +from src.converter.fake_stream import ( + parse_response_for_fake_stream, + build_openai_fake_stream_chunks, + create_openai_heartbeat_chunk, +) + +# 本地模块 - 基础路由工具 +from src.router.hi_check import is_health_check_request, create_health_check_response +from src.router.stream_passthrough import ( + build_streaming_response_or_error, + prepend_async_item, + read_first_async_item, +) + +# 本地模块 - 数据模型 +from src.models import OpenAIChatCompletionRequest, model_to_dict + +# 本地模块 - 任务管理 +from src.task_manager import create_managed_task + + +# ==================== 路由器初始化 ==================== + +router = APIRouter() + + +# ==================== API 路由 ==================== + +@router.post("/v1/chat/completions") +async def chat_completions( + openai_request: OpenAIChatCompletionRequest, + token: str = Depends(authenticate_bearer) +): + """ + 处理OpenAI格式的聊天完成请求(流式和非流式) + + Args: + openai_request: OpenAI格式的请求体 + token: Bearer认证令牌 + """ + log.debug(f"[GEMINICLI-OPENAI] Request for model: {openai_request.model}") + + # 转换为字典 + normalized_dict = model_to_dict(openai_request) + + # 健康检查 + if is_health_check_request(normalized_dict, format="openai"): + response = create_health_check_response(format="openai") + return JSONResponse(content=response) + + # 处理模型名称和功能检测 + use_fake_streaming = is_fake_streaming_model(openai_request.model) + use_anti_truncation = is_anti_truncation_model(openai_request.model) + real_model = get_base_model_from_feature_model(openai_request.model) + + # 获取流式标志 + is_streaming = openai_request.stream + + # 对于抗截断模型的非流式请求,给出警告 + if use_anti_truncation and not is_streaming: + log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置") + + # 更新模型名为真实模型名 + normalized_dict["model"] = real_model + + # 转换为 Gemini 格式 (使用 converter) + from src.converter.openai2gemini import convert_openai_to_gemini_request + gemini_dict = await convert_openai_to_gemini_request(normalized_dict) + + # convert_openai_to_gemini_request 不包含 model 字段,需要手动添加 + gemini_dict["model"] = real_model + + # 规范化 Gemini 请求 (使用 geminicli 模式) + from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="geminicli") + + # 准备API请求格式 - 提取model并将其他字段放入request中 + api_request = { + "model": gemini_dict.pop("model"), + "request": gemini_dict + } + + # ========== 非流式请求 ========== + if not is_streaming: + # 调用 API 层的非流式请求 + from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + status_code = getattr(response, "status_code", 200) + + # 提取响应体 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + except Exception as e: + log.error(f"Failed to parse Gemini response: {e}") + raise HTTPException(status_code=500, detail="Response parsing failed") + + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_response = convert_gemini_to_openai_response( + gemini_response, + real_model, + status_code + ) + + return JSONResponse(content=openai_response, status_code=status_code) + + # ========== 流式请求 ========== + + # ========== 假流式生成器 ========== + async def fake_stream_generator(): + from src.api.geminicli import non_stream_request + + response = await non_stream_request(body=api_request) + + # 检查响应状态码 + if hasattr(response, "status_code") and response.status_code != 200: + log.error(f"Fake streaming got error response: status={response.status_code}") + yield response + return + + # 处理成功响应 - 提取响应内容 + if hasattr(response, "body"): + response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content"): + response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + else: + response_body = str(response) + + try: + gemini_response = json.loads(response_body) + log.debug(f"OpenAI fake stream Gemini response: {gemini_response}") + + # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) + if "error" in gemini_response: + log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + # 转换错误为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_response, + real_model, + 200 + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode() + yield "data: [DONE]\n\n".encode() + return + + # 使用统一的解析函数 + content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + + log.debug(f"OpenAI extracted content: {content}") + log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug(f"OpenAI extracted images count: {len(images)}") + + # 构建响应块 + chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + for idx, chunk in enumerate(chunks): + chunk_json = json.dumps(chunk) + log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + yield f"data: {chunk_json}\n\n".encode() + + except Exception as e: + log.error(f"Response parsing failed: {e}, directly yield error") + # 构建错误响应 + error_chunk = { + "id": "error", + "object": "chat.completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": real_model, + "choices": [{ + "index": 0, + "delta": {"content": f"Error: {str(e)}"}, + "finish_reason": "error" + }] + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + + yield "data: [DONE]\n\n".encode() + + # ========== 流式抗截断生成器 ========== + async def anti_truncation_generator(): + from src.converter.anti_truncation import AntiTruncationStreamProcessor + from src.api.geminicli import stream_request + from src.converter.anti_truncation import apply_anti_truncation + from fastapi import Response + + max_attempts = await get_anti_truncation_max_attempts() + + # 首先对payload应用反截断指令 + anti_truncation_payload = apply_anti_truncation(api_request) + + first_attempt_stream = stream_request(body=anti_truncation_payload, native=False) + try: + first_chunk = await read_first_async_item(first_attempt_stream) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + first_attempt_pending = True + + async def stream_request_wrapper(payload): + nonlocal first_attempt_pending + + if first_attempt_pending: + first_attempt_pending = False + stream_gen = prepend_async_item(first_chunk, first_attempt_stream) + else: + stream_gen = stream_request(body=payload, native=False) + + return StreamingResponse(stream_gen, media_type="text/event-stream") + + # 创建反截断处理器 + processor = AntiTruncationStreamProcessor( + stream_request_wrapper, + anti_truncation_payload, + max_attempts, + enable_prefill_mode=True, + ) + + # 转换为 OpenAI 格式 + import uuid + response_id = str(uuid.uuid4()) + + # 直接迭代 process_stream() 生成器,并转换为 OpenAI 格式 + async for chunk in processor.process_stream(): + if not chunk: + continue + + # 解析 Gemini SSE 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析 "data: {...}" 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 普通流式生成器 ========== + async def normal_stream_generator(): + from src.api.geminicli import stream_request + from fastapi import Response + import uuid + + # 调用 API 层的流式请求(不使用 native 模式) + stream_gen = stream_request(body=api_request, native=False) + try: + first_chunk = await read_first_async_item(stream_gen) + except StopAsyncIteration: + return + + if isinstance(first_chunk, Response): + yield first_chunk + return + + response_id = str(uuid.uuid4()) + + # yield所有数据,处理可能的错误Response + async for chunk in prepend_async_item(first_chunk, stream_gen): + # 检查是否是Response对象(错误情况) + if isinstance(chunk, Response): + # 将Response转换为SSE格式的错误消息 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + gemini_error = json.loads(error_content.decode('utf-8')) + # 转换为 OpenAI 格式错误 + from src.converter.openai2gemini import convert_gemini_to_openai_response + openai_error = convert_gemini_to_openai_response( + gemini_error, + real_model, + chunk.status_code + ) + yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') + except Exception: + yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" + return + else: + # 正常的bytes数据,转换为 OpenAI 格式 + chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk + + # 跳过空行 + if not chunk_str.strip(): + continue + + # 处理 [DONE] 标记 + if chunk_str.strip() == "data: [DONE]": + yield "data: [DONE]\n\n".encode('utf-8') + return + + # 解析并转换 Gemini chunk 为 OpenAI 格式 + if chunk_str.startswith("data: "): + try: + # 转换为 OpenAI 格式 + from src.converter.openai2gemini import convert_gemini_to_openai_stream + openai_chunk_str = convert_gemini_to_openai_stream( + chunk_str, + real_model, + response_id + ) + + if openai_chunk_str: + yield openai_chunk_str.encode('utf-8') + + except Exception as e: + log.error(f"Failed to convert chunk: {e}") + continue + + # 发送结束标记 + yield "data: [DONE]\n\n".encode('utf-8') + + # ========== 根据模式选择生成器 ========== + if use_fake_streaming: + return await build_streaming_response_or_error(fake_stream_generator()) + elif use_anti_truncation: + log.info("启用流式抗截断功能") + return await build_streaming_response_or_error(anti_truncation_generator()) + else: + return await build_streaming_response_or_error(normal_stream_generator()) + + +# ==================== 测试代码 ==================== + +if __name__ == "__main__": + """ + 测试代码:演示OpenAI路由的流式和非流式响应 + 运行方式: python src/router/geminicli/openai.py + """ + + from fastapi.testclient import TestClient + from fastapi import FastAPI + + print("=" * 80) + print("OpenAI Router 测试") + print("=" * 80) + + # 创建测试应用 + app = FastAPI() + app.include_router(router) + + # 测试客户端 + client = TestClient(app) + + # 测试请求体 (OpenAI格式) + test_request_body = { + "model": "gemini-2.5-flash", + "messages": [ + {"role": "user", "content": "Hello, tell me a joke in one sentence."} + ] + } + + # 测试Bearer令牌(模拟) + test_token = "Bearer pwd" + + def test_non_stream_request(): + """测试非流式请求""" + print("\n" + "=" * 80) + print("【测试1】非流式请求 (POST /v1/chat/completions)") + print("=" * 80) + print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + + response = client.post( + "/v1/chat/completions", + json=test_request_body, + headers={"Authorization": test_token} + ) + + print("非流式响应数据:") + print("-" * 80) + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}") + + try: + content = response.text + print(f"\n响应内容 (原始):\n{content}\n") + + # 尝试解析JSON + try: + json_data = response.json() + print(f"响应内容 (格式化JSON):") + print(json.dumps(json_data, indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print("(非JSON格式)") + except Exception as e: + print(f"内容解析失败: {e}") + + def test_stream_request(): + """测试流式请求""" + print("\n" + "=" * 80) + print("【测试2】流式请求 (POST /v1/chat/completions)") + print("=" * 80) + + stream_request_body = test_request_body.copy() + stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/chat/completions", + json=stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + print(f"\nChunk #{chunk_count}:") + print(f" 类型: {type(chunk).__name__}") + print(f" 长度: {len(chunk)}") + + # 解码chunk + try: + chunk_str = chunk.decode('utf-8') + print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + + # 如果是SSE格式,尝试解析每一行 + if chunk_str.startswith("data: "): + # 按行分割,处理每个SSE事件 + for line in chunk_str.strip().split('\n'): + line = line.strip() + if not line: + continue + + if line == "data: [DONE]": + print(f" => 流结束标记") + elif line.startswith("data: "): + try: + json_str = line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + except Exception as e: + print(f" SSE解析失败: {e}") + except Exception as e: + print(f" 解码失败: {e}") + + print(f"\n总共收到 {chunk_count} 个chunk") + + def test_fake_stream_request(): + """测试假流式请求""" + print("\n" + "=" * 80) + print("【测试3】假流式请求 (POST /v1/chat/completions with 假流式 prefix)") + print("=" * 80) + + fake_stream_request_body = test_request_body.copy() + fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" + fake_stream_request_body["stream"] = True + + print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + + print("假流式响应数据 (每个chunk):") + print("-" * 80) + + with client.stream( + "POST", + "/v1/chat/completions", + json=fake_stream_request_body, + headers={"Authorization": test_token} + ) as response: + print(f"状态码: {response.status_code}") + print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") + + chunk_count = 0 + for chunk in response.iter_bytes(): + if chunk: + chunk_count += 1 + chunk_str = chunk.decode('utf-8') + + print(f"\nChunk #{chunk_count}:") + print(f" 长度: {len(chunk_str)} 字节") + + # 解析chunk中的所有SSE事件 + events = [] + for line in chunk_str.split('\n'): + line = line.strip() + if line.startswith("data: "): + events.append(line) + + print(f" 包含 {len(events)} 个SSE事件") + + # 显示每个事件 + for event_idx, event_line in enumerate(events, 1): + if event_line == "data: [DONE]": + print(f" 事件 #{event_idx}: [DONE]") + else: + try: + json_str = event_line[6:] # 去掉 "data: " 前缀 + json_data = json.loads(json_str) + # 提取content内容 + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "") + finish_reason = json_data.get("choices", [{}])[0].get("finish_reason") + print(f" 事件 #{event_idx}: content={repr(content[:50])}{'...' if len(content) > 50 else ''}, finish_reason={finish_reason}") + except Exception as e: + print(f" 事件 #{event_idx}: 解析失败 - {e}") + + print(f"\n总共收到 {chunk_count} 个HTTP chunk") + + # 运行测试 + try: + # 测试非流式请求 + test_non_stream_request() + + # 测试流式请求 + test_stream_request() + + # 测试假流式请求 + test_fake_stream_request() + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试过程中出现异常: {e}") + import traceback + traceback.print_exc() diff --git a/src/router/hi_check.py b/src/router/hi_check.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ec7a2928235624e3f1217f78dc26a9090e6909 --- /dev/null +++ b/src/router/hi_check.py @@ -0,0 +1,131 @@ +""" +统一的健康检查(Hi消息)处理模块 + +提供对OpenAI、Gemini和Anthropic格式的Hi消息的解析和响应 +""" +import time +from typing import Any, Dict, List + + +# ==================== Hi消息检测 ==================== + +def is_health_check_request(request_data: dict, format: str = "openai") -> bool: + """ + 检查是否是健康检查请求(Hi消息) + + Args: + request_data: 请求数据 + format: 请求格式("openai"、"gemini" 或 "anthropic") + + Returns: + 是否是健康检查请求 + """ + if format == "openai": + # OpenAI格式健康检查: {"messages": [{"role": "user", "content": "Hi"}]} + messages = request_data.get("messages", []) + if len(messages) == 1: + msg = messages[0] + if msg.get("role") == "user" and msg.get("content") == "Hi": + return True + + elif format == "gemini": + # Gemini格式健康检查: {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]} + contents = request_data.get("contents", []) + if len(contents) == 1: + content = contents[0] + if (content.get("role") == "user" and + content.get("parts", [{}])[0].get("text") == "Hi"): + return True + + elif format == "anthropic": + # Anthropic格式健康检查: {"messages": [{"role": "user", "content": "Hi"}]} + messages = request_data.get("messages", []) + if (len(messages) == 1 + and messages[0].get("role") == "user" + and messages[0].get("content") == "Hi"): + return True + + return False + + +def is_health_check_message(messages: List[Dict[str, Any]]) -> bool: + """ + 直接检查消息列表是否为健康检查消息(Anthropic专用) + + 这是一个便捷函数,用于已经提取出消息列表的场景。 + + Args: + messages: 消息列表 + + Returns: + 是否为健康检查消息 + """ + return ( + len(messages) == 1 + and messages[0].get("role") == "user" + and messages[0].get("content") == "Hi" + ) + + +# ==================== Hi消息响应生成 ==================== + +def create_health_check_response(format: str = "openai", **kwargs) -> dict: + """ + 创建健康检查响应 + + Args: + format: 响应格式("openai"、"gemini" 或 "anthropic") + **kwargs: 格式特定的额外参数 + - model: 模型名称(anthropic格式需要) + - message_id: 消息ID(anthropic格式需要) + + Returns: + 健康检查响应字典 + """ + if format == "openai": + # OpenAI格式响应 + return { + "id": "healthcheck", + "object": "chat.completion", + "created": int(time.time()), + "model": "healthcheck", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "API is working" + }, + "finish_reason": "stop" + }] + } + + elif format == "gemini": + # Gemini格式响应 + return { + "candidates": [{ + "content": { + "parts": [{"text": "gcli2api工作中"}], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + }] + } + + elif format == "anthropic": + # Anthropic格式响应 + model = kwargs.get("model", "claude-unknown") + message_id = kwargs.get("message_id", "msg_healthcheck") + return { + "id": message_id, + "type": "message", + "role": "assistant", + "model": str(model), + "content": [{"type": "text", "text": "antigravity Anthropic Messages 正常工作中"}], + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + } + + # 未知格式返回空字典 + return {} diff --git a/src/router/stream_passthrough.py b/src/router/stream_passthrough.py new file mode 100644 index 0000000000000000000000000000000000000000..95747107ae14b351569a1d52a84b85f0ac7bb696 --- /dev/null +++ b/src/router/stream_passthrough.py @@ -0,0 +1,38 @@ +from typing import Any, AsyncIterator + +from fastapi import Response +from fastapi.responses import StreamingResponse + + +async def prepend_async_item(first_item: Any, iterator: AsyncIterator[Any]): + """Yield a prefetched item before continuing the original iterator.""" + yield first_item + async for item in iterator: + yield item + + +async def read_first_async_item(iterator: AsyncIterator[Any]) -> Any: + """Python 3.9-compatible async equivalent of built-in anext().""" + return await iterator.__anext__() + + +async def build_streaming_response_or_error( + iterator: AsyncIterator[Any], + media_type: str = "text/event-stream", +): + """ + Prefetch the first async item so router code can return an upstream error + response directly before FastAPI commits a 200 streaming response. + """ + try: + first_item = await read_first_async_item(iterator) + except StopAsyncIteration: + return Response(status_code=204) + + if isinstance(first_item, Response): + return first_item + + return StreamingResponse( + prepend_async_item(first_item, iterator), + media_type=media_type, + ) diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..43c57d3a176a51e0dd1a8b53fc33e5ad6eeaa975 --- /dev/null +++ b/src/storage/mongodb_manager.py @@ -0,0 +1,1530 @@ +""" +MongoDB 存储管理器 +""" + +import json +import os +import random +import time +from typing import Any, Dict, List, Optional + +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase + +from log import log + + +class MongoDBManager: + """MongoDB 数据库管理器""" + + # 状态字段常量 + STATE_FIELDS = { + "error_codes", + "error_messages", + "disabled", + "last_success", + "user_email", + "model_cooldowns", + "preview", + "tier", + "enable_credit", + } + + @staticmethod + def _escape_model_name(model_name: str) -> str: + """ + 转义模型名中的点号,避免 MongoDB 将其解释为嵌套结构 + + Args: + model_name: 原始模型名 (如 "gemini-2.5-flash") + + Returns: + 转义后的模型名 (如 "gemini-2-5-flash") + """ + return model_name.replace(".", "-") + + def __init__(self): + self._client: Optional[AsyncIOMotorClient] = None + self._db: Optional[AsyncIOMotorDatabase] = None + self._initialized = False + + # 内存配置缓存 - 初始化时加载一次 + self._config_cache: Dict[str, Any] = {} + self._config_loaded = False + + # Redis 缓存(仅当 REDIS_URL 环境变量存在时启用) + self._redis = None + self._redis_enabled: bool = False + + async def initialize(self) -> None: + """初始化 MongoDB 连接""" + if self._initialized: + return + + try: + mongodb_uri = os.getenv("MONGODB_URI") + if not mongodb_uri: + raise ValueError("MONGODB_URI environment variable not set") + + database_name = os.getenv("MONGODB_DATABASE", "gcli2api") + + self._client = AsyncIOMotorClient(mongodb_uri) + self._db = self._client[database_name] + + # 测试连接 + await self._db.command("ping") + + # 创建索引 + await self._create_indexes() + + # 加载配置到内存 + await self._load_config_cache() + + self._initialized = True + log.info(f"MongoDB storage initialized (database: {database_name})") + + # 尝试初始化 Redis(可选) + await self._init_redis() + + except Exception as e: + log.error(f"Error initializing MongoDB: {e}") + raise + + async def _create_indexes(self): + """ + 创建索引 + """ + from pymongo import IndexModel, ASCENDING + + credentials_collection = self._db["credentials"] + antigravity_credentials_collection = self._db["antigravity_credentials"] + + # ===== Geminicli 凭证索引 ===== + geminicli_indexes = [ + # 唯一索引 - 用于所有按文件名的精确查询 + IndexModel([("filename", ASCENDING)], unique=True, name="idx_filename_unique"), + + # 复合索引 - 用于 get_next_available_credential 和 get_available_credentials_list + # 查询模式: {disabled: False} + sort by rotation_order + IndexModel( + [("disabled", ASCENDING), ("rotation_order", ASCENDING)], + name="idx_disabled_rotation" + ), + + # 单字段索引 - 用于 get_credentials_summary 的错误筛选 + IndexModel([("error_codes", ASCENDING)], name="idx_error_codes"), + + # 单字段索引 - 用于 get_duplicate_credentials_by_email 的去重查询 + IndexModel([("user_email", ASCENDING)], name="idx_user_email"), + ] + + # ===== Antigravity 凭证索引 ===== + antigravity_indexes = [ + # 唯一索引 + IndexModel([("filename", ASCENDING)], unique=True, name="idx_filename_unique"), + + # 复合索引 - 查询模式: {disabled: False} + sort by rotation_order + # 查询模式: {disabled: False} + 可选 sort by rotation_order + IndexModel( + [("disabled", ASCENDING), ("rotation_order", ASCENDING)], + name="idx_disabled_rotation" + ), + + # 单字段索引 - 错误筛选 + IndexModel([("error_codes", ASCENDING)], name="idx_error_codes"), + + # 单字段索引 - 去重查询 + IndexModel([("user_email", ASCENDING)], name="idx_user_email"), + ] + + # 并行创建新索引 + try: + await credentials_collection.create_indexes(geminicli_indexes) + await antigravity_credentials_collection.create_indexes(antigravity_indexes) + log.debug("MongoDB indexes created successfully") + except Exception as e: + # 如果索引已存在,忽略错误 + if "already exists" not in str(e).lower(): + log.warning(f"Index creation warning: {e}") + + async def _load_config_cache(self): + """加载配置到内存缓存(仅在初始化时调用一次)""" + if self._config_loaded: + return + + try: + config_collection = self._db["config"] + cursor = config_collection.find({}) + + async for doc in cursor: + self._config_cache[doc["key"]] = doc.get("value") + + self._config_loaded = True + log.debug(f"Loaded {len(self._config_cache)} config items into cache") + + except Exception as e: + log.error(f"Error loading config cache: {e}") + self._config_cache = {} + + # ============ Redis 缓存(可选,仅当 REDIS_URL 存在时启用)============ + + async def _init_redis(self) -> None: + """初始化 Redis 连接并重建凭证池缓存(若 REDIS_URL 存在)""" + redis_url = os.getenv("REDIS_URL") + if not redis_url: + return + + try: + import redis.asyncio as aioredis # type: ignore + except ImportError: + log.warning("redis package not installed, Redis cache disabled. Run: pip install redis") + return + + try: + self._redis = aioredis.from_url(redis_url, decode_responses=True) + await self._redis.ping() + self._redis_enabled = True + log.info("Redis connected, rebuilding credential pool cache...") + + # 并行重建两个 mode 的缓存及配置缓存 + import asyncio + await asyncio.gather( + self._rebuild_redis_cache("geminicli"), + self._rebuild_redis_cache("antigravity"), + self._load_config_to_redis(), + ) + log.info("Redis credential pool cache ready") + except Exception as e: + log.warning(f"Redis init failed, falling back to MongoDB-only mode: {e}") + self._redis = None + self._redis_enabled = False + + # ---- Redis key 工具 ---- + + def _rk_avail(self, mode: str) -> str: + """所有未禁用凭证的 Redis Set key""" + return f"gcli:avail:{mode}" + + def _rk_tier(self, mode: str, tier: str) -> str: + """按 tier 分桶的未禁用凭证 Redis Set key""" + return f"gcli:tier:{mode}:{tier}" + + def _rk_preview(self, mode: str) -> str: + """preview=True 凭证的 Redis Set key""" + return f"gcli:preview:{mode}" + + def _rk_cd(self, mode: str, filename: str, escaped_model: str) -> str: + """模型冷却 Redis key(带 TTL)""" + return f"gcli:cd:{mode}:{filename}:{escaped_model}" + + # ---- Redis 缓存维护 ---- + + async def _rebuild_redis_cache(self, mode: str) -> None: + """ + 从 MongoDB 重建指定 mode 的 Redis 凭证池缓存。 + + 使用临时 key + RENAME 原子替换 + """ + if not self._redis: + return + try: + collection = self._db[self._get_collection_name(mode)] + # 同时投影 model_cooldowns、tier、preview,以便重建缓存 + projection: Dict[str, Any] = {"filename": 1, "disabled": 1, "model_cooldowns": 1, "tier": 1, "preview": 1, "_id": 0} + + avail: List[str] = [] + tier_buckets: Dict[str, List[str]] = {} # tier -> [filename, ...] + preview_members: List[str] = [] + cooldown_entries: List[tuple] = [] # (cd_key, ttl_seconds, value) + current_time = time.time() + + async for doc in collection.find({}, projection=projection): + if not doc.get("disabled", False): + filename = doc["filename"] + avail.append(filename) + + # 按 tier 分桶 + tier = doc.get("tier") or "pro" + tier_buckets.setdefault(tier, []).append(filename) + + # preview 分桶(仅 geminicli) + if mode == "geminicli" and doc.get("preview", True): + preview_members.append(filename) + + # 收集未过期的模型冷却,重建 Redis TTL Key + model_cooldowns = doc.get("model_cooldowns") or {} + for escaped_model, cooldown_until in model_cooldowns.items(): + if isinstance(cooldown_until, (int, float)) and cooldown_until > current_time: + ttl = int(cooldown_until - current_time) + if ttl > 0: + cd_key = self._rk_cd(mode, filename, escaped_model) + cooldown_entries.append((cd_key, ttl, str(cooldown_until))) + + tmp_avail = self._rk_avail(mode) + ":tmp" + + pipe = self._redis.pipeline() + # 先写临时 key(此时正式 key 仍完整可用) + pipe.delete(tmp_avail) + if avail: + pipe.sadd(tmp_avail, *avail) + await pipe.execute() + + # RENAME 是原子操作:瞬间切换,不存在空窗 + pipe2 = self._redis.pipeline() + if avail: + pipe2.rename(tmp_avail, self._rk_avail(mode)) + else: + pipe2.delete(self._rk_avail(mode)) + pipe2.delete(tmp_avail) + await pipe2.execute() + + # 重建 tier 分桶 Set(原子替换) + all_tiers = ("free", "pro", "ultra") + pipe3 = self._redis.pipeline() + for tier in all_tiers: + tier_key = self._rk_tier(mode, tier) + tmp_tier_key = tier_key + ":tmp" + pipe3.delete(tmp_tier_key) + members = tier_buckets.get(tier, []) + if members: + pipe3.sadd(tmp_tier_key, *members) + await pipe3.execute() + + pipe4 = self._redis.pipeline() + for tier in all_tiers: + tier_key = self._rk_tier(mode, tier) + tmp_tier_key = tier_key + ":tmp" + members = tier_buckets.get(tier, []) + if members: + pipe4.rename(tmp_tier_key, tier_key) + else: + pipe4.delete(tier_key) + pipe4.delete(tmp_tier_key) + await pipe4.execute() + + # 重建 preview 分桶(仅 geminicli) + preview_key = self._rk_preview(mode) + tmp_preview_key = preview_key + ":tmp" + pipe5 = self._redis.pipeline() + pipe5.delete(tmp_preview_key) + if preview_members: + pipe5.sadd(tmp_preview_key, *preview_members) + await pipe5.execute() + pipe6 = self._redis.pipeline() + if preview_members: + pipe6.rename(tmp_preview_key, preview_key) + else: + pipe6.delete(preview_key) + pipe6.delete(tmp_preview_key) + await pipe6.execute() + + # 批量恢复未过期的模型冷却 TTL Key + if cooldown_entries: + pipe7 = self._redis.pipeline() + for cd_key, ttl, value in cooldown_entries: + pipe7.setex(cd_key, ttl, value) + await pipe7.execute() + + log.debug( + f"Redis cache rebuilt [{mode}]: {len(avail)} avail, " + f"tiers={{{', '.join(f'{t}:{len(tier_buckets.get(t, []))}' for t in all_tiers)}}}, " + f"preview={len(preview_members)}, " + f"{len(cooldown_entries)} cooldown key(s) restored" + ) + except Exception as e: + log.warning(f"Redis rebuild cache error [{mode}]: {e}") + + async def _redis_add_cred(self, mode: str, filename: str, tier: str = "pro", preview: bool = True) -> None: + """将凭证加入 Redis 可用池及对应 tier 分桶、preview 分桶""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + pipe.sadd(self._rk_avail(mode), filename) + pipe.sadd(self._rk_tier(mode, tier), filename) + if mode == "geminicli" and preview: + pipe.sadd(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis add_cred error: {e}") + + async def _redis_remove_cred(self, mode: str, filename: str, tier: Optional[str] = None) -> None: + """从 Redis 所有池中移除凭证""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + pipe.srem(self._rk_avail(mode), filename) + if tier: + pipe.srem(self._rk_tier(mode, tier), filename) + else: + # tier 未知时从所有分桶中移除 + for t in ("free", "pro", "ultra"): + pipe.srem(self._rk_tier(mode, t), filename) + pipe.srem(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis remove_cred error: {e}") + + async def _redis_sync_cred(self, mode: str, filename: str, disabled: bool, tier: str = "pro", preview: bool = True) -> None: + """根据最新状态同步单个凭证在 Redis 中的集合成员""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + if disabled: + pipe.srem(self._rk_avail(mode), filename) + for t in ("free", "pro", "ultra"): + pipe.srem(self._rk_tier(mode, t), filename) + pipe.srem(self._rk_preview(mode), filename) + else: + pipe.sadd(self._rk_avail(mode), filename) + pipe.sadd(self._rk_tier(mode, tier), filename) + if mode == "geminicli" and preview: + pipe.sadd(self._rk_preview(mode), filename) + else: + pipe.srem(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis sync_cred error: {e}") + + async def _get_next_available_from_redis( + self, mode: str, model_name: Optional[str], exclude_free_tier: bool = False, preview_only: bool = False + ) -> Optional[tuple]: + """ + Redis 快速路径:随机取候选凭证,跳过冷却中的,返回 (filename, credential_data)。 + 失败或池为空时返回 None,由调用方降级到 MongoDB。 + """ + try: + # 选择候选池优先级:preview_only > exclude_free_tier > 全量池 + if preview_only and exclude_free_tier: + # preview 且非 free:preview ∩ (pro ∪ ultra) + preview_set = await self._redis.smembers(self._rk_preview(mode)) + pro_members = await self._redis.smembers(self._rk_tier(mode, "pro")) + ultra_members = await self._redis.smembers(self._rk_tier(mode, "ultra")) + non_free = pro_members | ultra_members + all_candidates = list(preview_set & non_free) + if not all_candidates: + log.debug(f"[Redis MISS] mode={mode} preview+non-free: no candidates, fallback to MongoDB") + return None + sample_size = min(len(all_candidates), 10) + candidates = random.sample(all_candidates, sample_size) + elif preview_only: + preview_key = self._rk_preview(mode) + preview_size = await self._redis.scard(preview_key) + if preview_size == 0: + log.debug(f"[Redis MISS] mode={mode} preview_only: pool empty, fallback to MongoDB") + return None + sample_size = min(preview_size, 10) + candidates = await self._redis.srandmember(preview_key, sample_size) + if not candidates: + return None + elif exclude_free_tier: + pro_members = await self._redis.smembers(self._rk_tier(mode, "pro")) + ultra_members = await self._redis.smembers(self._rk_tier(mode, "ultra")) + all_candidates = list(pro_members | ultra_members) + if not all_candidates: + log.debug(f"[Redis MISS] mode={mode} exclude_free: no non-free creds, fallback to MongoDB") + return None + sample_size = min(len(all_candidates), 10) + candidates = random.sample(all_candidates, sample_size) + else: + pool_key = self._rk_avail(mode) + pool_size = await self._redis.scard(pool_key) + if pool_size == 0: + log.debug(f"[Redis MISS] mode={mode} pool_key={pool_key}: pool empty, fallback to MongoDB") + return None + sample_size = min(pool_size, 10) + candidates = await self._redis.srandmember(pool_key, sample_size) + if not candidates: + return None + + # 过滤冷却中的凭证 + if model_name: + escaped = self._escape_model_name(model_name) + for filename in candidates: + cd_key = self._rk_cd(mode, filename, escaped) + if not await self._redis.exists(cd_key): + credential_data = await self.get_credential(filename, mode) + if mode == "antigravity": + state = await self.get_credential_state(filename, mode) + credential_data = credential_data or {} + credential_data["enable_credit"] = bool(state.get("enable_credit", False)) + log.debug(f"[Redis HIT] mode={mode} model={model_name} -> {filename}") + return filename, credential_data + # 所有候选都在冷却中,降级到 MongoDB + log.debug(f"[Redis MISS] mode={mode} model={model_name}: all {len(candidates)} candidates in cooldown, fallback to MongoDB") + return None + else: + filename = candidates[0] + credential_data = await self.get_credential(filename, mode) + if mode == "antigravity": + state = await self.get_credential_state(filename, mode) + credential_data = credential_data or {} + credential_data["enable_credit"] = bool(state.get("enable_credit", False)) + log.debug(f"[Redis HIT] mode={mode} -> {filename}") + return filename, credential_data + except Exception as e: + log.warning(f"Redis get_next_available error: {e}") + return None + + async def close(self) -> None: + """关闭 MongoDB 连接""" + if self._redis: + await self._redis.aclose() + self._redis = None + self._redis_enabled = False + if self._client: + self._client.close() + self._client = None + self._db = None + self._initialized = False + log.debug("MongoDB storage closed") + + def _ensure_initialized(self): + """确保已初始化""" + if not self._initialized: + raise RuntimeError("MongoDB manager not initialized") + + def _get_collection_name(self, mode: str) -> str: + """根据 mode 获取对应的集合名""" + if mode == "antigravity": + return "antigravity_credentials" + elif mode == "geminicli": + return "credentials" + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + + # ============ SQL 方法 ============ + + async def get_next_available_credential( + self, mode: str = "geminicli", model_name: Optional[str] = None + ) -> Optional[tuple[str, Dict[str, Any]]]: + """ + 随机获取一个可用凭证(负载均衡) + - 未禁用 + - 如果提供了 model_name,还会检查模型级冷却 + - 随机选择 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_name: 完整模型名(如 "gemini-2.0-flash-exp") + + Note: + - 开启 Redis 时:利用 Redis Set 随机选凭证 + TTL key 判断冷却 + - 未开启 Redis 时:使用 count + random skip + limit(1) + """ + self._ensure_initialized() + + # Redis 快速路径:根据模型名派生过滤标志,直接在 Redis 分桶中筛选 + if self._redis_enabled: + model_lower = model_name.lower() if model_name else "" + exclude_free = mode == "geminicli" and "pro" in model_lower + preview_only = mode == "geminicli" and "preview" in model_lower + result = await self._get_next_available_from_redis( + mode, model_name, exclude_free_tier=exclude_free, preview_only=preview_only + ) + if result is not None: + return result + # result 为 None:池为空或所有候选都冷却中,降级到 MongoDB 以扩大样本空间 + log.debug(f"[MongoDB fallback] mode={mode} model={model_name}") + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + current_time = time.time() + + # 构建普通查询(避免 $sample 聚合导致全集合扫描) + match_query: Dict[str, Any] = {"disabled": False} + + # pro 模型只允许非 free tier 凭证 + if mode == "geminicli" and model_name and "pro" in model_name.lower(): + match_query["tier"] = {"$ne": "free"} + + # preview 模型只允许 preview=True 的凭证 + if mode == "geminicli" and model_name and "preview" in model_name.lower(): + match_query["preview"] = True + + # 冷却检查:直接用 MongoDB 查询表达,无需 $addFields + if model_name: + escaped_model_name = self._escape_model_name(model_name) + field = f"model_cooldowns.{escaped_model_name}" + match_query["$or"] = [ + {field: {"$exists": False}}, + {field: {"$lte": current_time}}, + ] + + # 统计符合条件的凭证总数(走索引,极快) + count = await collection.count_documents(match_query) + if count == 0: + return None + + # 随机偏移 + limit(1),替代 $sample,避免全集合随机排序 + skip_n = random.randint(0, count - 1) + projection = {"filename": 1, "credential_data": 1, "enable_credit": 1, "_id": 0} + docs = await collection.find(match_query, projection).skip(skip_n).limit(1).to_list(1) + + if docs: + doc = docs[0] + credential_data = doc.get("credential_data") or {} + if mode == "antigravity": + credential_data["enable_credit"] = bool(doc.get("enable_credit", False)) + return doc["filename"], credential_data + + return None + + except Exception as e: + log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}") + return None + + async def get_available_credentials_list(self, mode: str = "geminicli") -> List[str]: + """ + 获取所有可用凭证列表 + - 未禁用 + - 按轮换顺序排序 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + pipeline = [ + {"$match": {"disabled": False}}, + {"$sort": {"rotation_order": 1}}, + {"$project": {"filename": 1, "_id": 0}} + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + return [doc["filename"] for doc in docs] + + except Exception as e: + log.error(f"Error getting available credentials list (mode={mode}): {e}") + return [] + + # ============ StorageBackend 协议方法 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储或更新凭证""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + current_ts = time.time() + + # 使用 upsert + $setOnInsert + # 如果文档存在,只更新 credential_data 和 updated_at + # 如果文档不存在,设置所有默认字段 + + # 先尝试更新现有文档 + result = await collection.update_one( + {"filename": filename}, + { + "$set": { + "credential_data": credential_data, + "updated_at": current_ts, + } + } + ) + + # 如果没有匹配到(新凭证),需要插入 + if result.matched_count == 0: + # 获取下一个 rotation_order + pipeline = [ + {"$group": {"_id": None, "max_order": {"$max": "$rotation_order"}}}, + {"$project": {"_id": 0, "next_order": {"$add": ["$max_order", 1]}}} + ] + + result_list = await collection.aggregate(pipeline).to_list(length=1) + next_order = result_list[0]["next_order"] if result_list else 0 + + # 插入新凭证(使用 insert_one,因为我们已经确认不存在) + try: + new_credential = { + "filename": filename, + "credential_data": credential_data, + "disabled": False, + "error_codes": [], + "error_messages": [], + "last_success": current_ts, + "user_email": None, + "model_cooldowns": {}, + "preview": True, + "tier": "pro", + "rotation_order": next_order, + "call_count": 0, + "created_at": current_ts, + "updated_at": current_ts, + } + + if mode == "antigravity": + new_credential["enable_credit"] = False + + await collection.insert_one(new_credential) + # 新凭证插入成功,添加到 Redis 可用池 + await self._redis_add_cred(mode, filename) + except Exception as insert_error: + # 处理并发插入导致的重复键错误 + if "duplicate key" in str(insert_error).lower(): + # 重试更新(已存在的凭证,无需更新 Redis) + await collection.update_one( + {"filename": filename}, + {"$set": {"credential_data": credential_data, "updated_at": current_ts}} + ) + else: + raise + + log.debug(f"Stored credential: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error storing credential {filename}: {e}") + return False + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 精确匹配,只投影需要的字段 + doc = await collection.find_one( + {"filename": filename}, + {"credential_data": 1, "_id": 0} + ) + if doc: + return doc.get("credential_data") + + return None + + except Exception as e: + log.error(f"Error getting credential {filename}: {e}") + return None + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用聚合管道 + pipeline = [ + {"$sort": {"rotation_order": 1}}, + {"$project": {"filename": 1, "_id": 0}} + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + return [doc["filename"] for doc in docs] + + except Exception as e: + log.error(f"Error listing credentials: {e}") + return [] + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 精确匹配删除 + result = await collection.delete_one({"filename": filename}) + deleted_count = result.deleted_count + + if deleted_count > 0: + # 从 Redis 池中移除 + await self._redis_remove_cred(mode, filename) + log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") + return True + else: + log.warning(f"No credential found to delete: {filename} (mode={mode})") + return False + + except Exception as e: + log.error(f"Error deleting credential {filename}: {e}") + return False + + async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + """ + 获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据) + 用于去重操作 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 email_groups(邮箱分组)、duplicate_count(重复数量)、no_email_count(无邮箱数量)的字典 + """ + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用聚合管道,只查询 filename 和 user_email 字段 + pipeline = [ + { + "$project": { + "filename": 1, + "user_email": 1, + "_id": 0 + } + }, + { + "$sort": {"filename": 1} + } + ] + + docs = await collection.aggregate(pipeline).to_list(length=None) + + # 按邮箱分组 + email_to_files = {} + no_email_files = [] + + for doc in docs: + filename = doc.get("filename") + user_email = doc.get("user_email") + + if user_email: + if user_email not in email_to_files: + email_to_files[user_email] = [] + email_to_files[user_email].append(filename) + else: + no_email_files.append(filename) + + # 找出重复的邮箱组 + duplicate_groups = [] + total_duplicate_count = 0 + + for email, files in email_to_files.items(): + if len(files) > 1: + # 保留第一个文件,其他为重复 + duplicate_groups.append({ + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + }) + total_duplicate_count += len(files) - 1 + + return { + "email_groups": email_to_files, + "duplicate_groups": duplicate_groups, + "duplicate_count": total_duplicate_count, + "no_email_files": no_email_files, + "no_email_count": len(no_email_files), + "unique_email_count": len(email_to_files), + "total_count": len(docs), + } + + except Exception as e: + log.error(f"Error getting duplicate credentials by email: {e}") + return { + "email_groups": {}, + "duplicate_groups": [], + "duplicate_count": 0, + "no_email_files": [], + "no_email_count": 0, + "unique_email_count": 0, + "total_count": 0, + } + + async def update_credential_state( + self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli" + ) -> bool: + """更新凭证状态""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 过滤只更新状态字段 + valid_updates = { + k: v for k, v in state_updates.items() if k in self.STATE_FIELDS + } + + if mode != "antigravity": + valid_updates.pop("enable_credit", None) + + if not valid_updates: + return True + + valid_updates["updated_at"] = time.time() + + # 精确匹配更新 + result = await collection.update_one( + {"filename": filename}, {"$set": valid_updates} + ) + updated_count = result.modified_count + result.matched_count + + # 如果 disabled 发生变化,同步 Redis 池成员关系 + if self._redis_enabled and "disabled" in valid_updates: + if valid_updates["disabled"]: + # 直接禁用:从集合中移除 + await self._redis_remove_cred(mode, filename) + else: + # 重新启用:需要读取当前 tier/preview 以正确放入分桶 + doc = await collection.find_one( + {"filename": filename}, + projection={"tier": 1, "preview": 1, "_id": 0}, + ) + tier_val = (doc or {}).get("tier", "pro") or "pro" + preview_val = (doc or {}).get("preview", True) + await self._redis_sync_cred(mode, filename, disabled=False, tier=tier_val, preview=preview_val) + elif self._redis_enabled and ("tier" in valid_updates or "preview" in valid_updates): + # tier 或 preview 更新:重新同步分桶(只在凭证未禁用时) + doc = await collection.find_one( + {"filename": filename}, + projection={"disabled": 1, "tier": 1, "preview": 1, "_id": 0}, + ) + if doc and not doc.get("disabled", False): + tier_val = doc.get("tier", "pro") or "pro" + preview_val = doc.get("preview", True) + await self._redis_sync_cred(mode, filename, disabled=False, tier=tier_val, preview=preview_val) + + return updated_count > 0 + + except Exception as e: + log.error(f"Error updating credential state {filename}: {e}") + return False + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态(不包含error_messages)""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + current_time = time.time() + + # 精确匹配 + doc = await collection.find_one({"filename": filename}) + + if doc: + model_cooldowns = doc.get("model_cooldowns", {}) + # 过滤掉损坏的数据(dict类型)和过期的冷却 + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if isinstance(v, (int, float)) and v > current_time + } + + state = { + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", current_time), + "user_email": doc.get("user_email"), + "model_cooldowns": model_cooldowns, + "preview": doc.get("preview", True), + "tier": doc.get("tier", "pro"), + } + if mode == "antigravity": + state["enable_credit"] = doc.get("enable_credit", False) + return state + + # 返回默认状态 + default_state = { + "disabled": False, + "error_codes": [], + "last_success": current_time, + "user_email": None, + "model_cooldowns": {}, + "preview": True, + "tier": "pro", + } + if mode == "antigravity": + default_state["enable_credit"] = False + return default_state + + except Exception as e: + log.error(f"Error getting credential state {filename}: {e}") + return {} + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态(不包含error_messages)""" + self._ensure_initialized() + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 使用投影只获取需要的字段(不包含error_messages) + projection = { + "filename": 1, + "disabled": 1, + "error_codes": 1, + "last_success": 1, + "user_email": 1, + "model_cooldowns": 1, + "preview": 1, + "tier": 1, + "enable_credit": 1, + "_id": 0 + } + + cursor = collection.find({}, projection=projection) + + states = {} + current_time = time.time() + + async for doc in cursor: + filename = doc["filename"] + model_cooldowns = doc.get("model_cooldowns", {}) + + # 自动过滤掉已过期的模型CD + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if isinstance(v, (int, float)) and v > current_time + } + + state = { + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", time.time()), + "user_email": doc.get("user_email"), + "model_cooldowns": model_cooldowns, + "preview": doc.get("preview", True), + "tier": doc.get("tier", "pro"), + } + if mode == "antigravity": + state["enable_credit"] = doc.get("enable_credit", False) + states[filename] = state + + return states + + except Exception as e: + log.error(f"Error getting all credential states: {e}") + return {} + + async def get_credentials_summary( + self, + offset: int = 0, + limit: Optional[int] = None, + status_filter: str = "all", + mode: str = "geminicli", + error_code_filter: Optional[str] = None, + cooldown_filter: Optional[str] = None, + preview_filter: Optional[str] = None, + tier_filter: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选 + + Args: + offset: 跳过的记录数(默认0) + limit: 返回的最大记录数(None表示返回所有) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + mode: 凭证模式 ("geminicli" 或 "antigravity") + error_code_filter: 错误码筛选(格式如"400"或"403",筛选包含该错误码的凭证) + cooldown_filter: 冷却状态筛选("in_cooldown"=冷却中, "no_cooldown"=未冷却) + preview_filter: Preview筛选("preview"=支持preview, "no_preview"=不支持preview,仅geminicli模式有效) + tier_filter: tier筛选("free", "pro", "ultra") + + Returns: + 包含 items(凭证列表)、total(总数)、offset、limit 的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择集合名 + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 构建查询条件 + query = {} + if status_filter == "enabled": + query["disabled"] = False + elif status_filter == "disabled": + query["disabled"] = True + + # 错误码筛选 - 兼容存储为数字或字符串的情况 + if error_code_filter and str(error_code_filter).strip().lower() != "all": + filter_value = str(error_code_filter).strip() + query_values = [filter_value] + try: + query_values.append(int(filter_value)) + except ValueError: + pass + query["error_codes"] = {"$in": query_values} + + # 计算全局统计数据(不受筛选条件影响) + global_stats = {"total": 0, "normal": 0, "disabled": 0} + stats_pipeline = [ + { + "$group": { + "_id": "$disabled", + "count": {"$sum": 1} + } + } + ] + + stats_result = await collection.aggregate(stats_pipeline).to_list(length=10) + for item in stats_result: + count = item["count"] + global_stats["total"] += count + if item["_id"]: + global_stats["disabled"] = count + else: + global_stats["normal"] = count + + # 获取所有匹配的文档(用于冷却筛选,因为需要在Python中判断) + projection = { + "filename": 1, + "disabled": 1, + "error_codes": 1, + "last_success": 1, + "user_email": 1, + "rotation_order": 1, + "model_cooldowns": 1, + "preview": 1, + "tier": 1, + "enable_credit": 1, + "_id": 0 + } + + cursor = collection.find(query, projection=projection).sort("rotation_order", 1) + + all_summaries = [] + current_time = time.time() + + async for doc in cursor: + model_cooldowns = doc.get("model_cooldowns", {}) + + # 自动过滤掉已过期的模型CD + active_cooldowns = {} + if model_cooldowns: + active_cooldowns = { + k: v for k, v in model_cooldowns.items() + if isinstance(v, (int, float)) and v > current_time + } + + summary = { + "filename": doc["filename"], + "disabled": doc.get("disabled", False), + "error_codes": doc.get("error_codes", []), + "last_success": doc.get("last_success", current_time), + "user_email": doc.get("user_email"), + "rotation_order": doc.get("rotation_order", 0), + "model_cooldowns": active_cooldowns, + "preview": doc.get("preview", True), + "tier": doc.get("tier", "pro"), + } + + if mode == "antigravity": + summary["enable_credit"] = bool(doc.get("enable_credit", False)) + + if mode == "geminicli" and preview_filter: + preview_value = summary.get("preview", True) + if preview_filter == "preview" and not preview_value: + continue + if preview_filter == "no_preview" and preview_value: + continue + + # 应用tier筛选 + if tier_filter and tier_filter in ("free", "pro", "ultra"): + if summary["tier"] != tier_filter: + continue + + # 应用冷却筛选 + if cooldown_filter == "in_cooldown": + # 只保留有冷却的凭证 + if active_cooldowns: + all_summaries.append(summary) + elif cooldown_filter == "no_cooldown": + # 只保留没有冷却的凭证 + if not active_cooldowns: + all_summaries.append(summary) + else: + # 不筛选冷却状态 + all_summaries.append(summary) + + # 应用分页 + total_count = len(all_summaries) + if limit is not None: + summaries = all_summaries[offset:offset + limit] + else: + summaries = all_summaries[offset:] + + return { + "items": summaries, + "total": total_count, + "offset": offset, + "limit": limit, + "stats": global_stats, + } + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return { + "items": [], + "total": 0, + "offset": offset, + "limit": limit, + "stats": {"total": 0, "normal": 0, "disabled": 0}, + } + + # ============ 配置管理(内存缓存 + 可选 Redis)============ + + def _rk_config(self, key: str) -> str: + """配置项的 Redis key""" + return f"gcli:config:{key}" + + def _rk_config_all(self) -> str: + """所有配置的 Redis Hash key""" + return "gcli:config" + + async def _load_config_to_redis(self) -> None: + """将所有配置从 MongoDB 同步到 Redis Hash""" + if not self._redis_enabled: + return + try: + config_collection = self._db["config"] + cursor = config_collection.find({}) + mapping = {} + async for doc in cursor: + mapping[doc["key"]] = json.dumps(doc.get("value")) + pipe = self._redis.pipeline() + pipe.delete(self._rk_config_all()) + if mapping: + pipe.hset(self._rk_config_all(), mapping=mapping) + await pipe.execute() + log.debug(f"Synced {len(mapping)} config items to Redis") + except Exception as e: + log.warning(f"Failed to sync config to Redis: {e}") + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置(写入数据库;Redis 启用时写 Redis,否则更新内存缓存)""" + self._ensure_initialized() + + try: + config_collection = self._db["config"] + await config_collection.update_one( + {"key": key}, + {"$set": {"value": value, "updated_at": time.time()}}, + upsert=True, + ) + + if self._redis_enabled: + try: + await self._redis.hset(self._rk_config_all(), key, json.dumps(value)) + except Exception as e: + log.warning(f"Redis config set error for key={key}: {e}") + else: + self._config_cache[key] = value + + return True + + except Exception as e: + log.error(f"Error setting config {key}: {e}") + return False + + async def reload_config_cache(self): + """重新加载配置缓存(在批量修改配置后调用)""" + self._ensure_initialized() + if self._redis_enabled: + await self._load_config_to_redis() + else: + self._config_loaded = False + await self._load_config_cache() + log.info("Config cache reloaded from database") + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置(Redis 启用时从 Redis 读取,否则从内存缓存)""" + self._ensure_initialized() + + if self._redis_enabled: + try: + raw = await self._redis.hget(self._rk_config_all(), key) + if raw is not None: + return json.loads(raw) + return default + except Exception as e: + log.warning(f"Redis config get error for key={key}: {e}") + return default + + return self._config_cache.get(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置(Redis 启用时从 Redis 读取,否则从内存缓存)""" + self._ensure_initialized() + + if self._redis_enabled: + try: + raw_map = await self._redis.hgetall(self._rk_config_all()) + return {k: json.loads(v) for k, v in raw_map.items()} + except Exception as e: + log.warning(f"Redis config getall error: {e}") + return {} + + return self._config_cache.copy() + + async def delete_config(self, key: str) -> bool: + """删除配置""" + self._ensure_initialized() + + try: + config_collection = self._db["config"] + result = await config_collection.delete_one({"key": key}) + + if self._redis_enabled: + try: + await self._redis.hdel(self._rk_config_all(), key) + except Exception as e: + log.warning(f"Redis config delete error for key={key}: {e}") + else: + self._config_cache.pop(key, None) + + return result.deleted_count > 0 + + except Exception as e: + log.error(f"Error deleting config {key}: {e}") + return False + + async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """ + 专门获取凭证的错误信息(包含 error_codes 和 error_messages) + + Args: + filename: 凭证文件名 + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 error_codes 和 error_messages 的字典 + """ + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 精确匹配 + doc = await collection.find_one( + {"filename": filename}, + {"error_codes": 1, "error_messages": 1, "_id": 0} + ) + + if doc: + return { + "filename": filename, + "error_codes": doc.get("error_codes", []), + "error_messages": doc.get("error_messages", []), + } + + # 凭证不存在,返回空错误信息 + return { + "filename": filename, + "error_codes": [], + "error_messages": [], + } + + except Exception as e: + log.error(f"Error getting credential errors {filename}: {e}") + return { + "filename": filename, + "error_codes": [], + "error_messages": [], + "error": str(e) + } + + # ============ 模型级冷却管理 ============ + + async def set_model_cooldown( + self, + filename: str, + model_name: str, + cooldown_until: Optional[float], + mode: str = "geminicli" + ) -> bool: + """ + 设置特定模型的冷却时间 + + Args: + filename: 凭证文件名 + model_name: 模型名(完整模型名,如 "gemini-2.0-flash-exp") + cooldown_until: 冷却截止时间戳(None 表示清除冷却) + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 是否成功 + """ + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + # 转义模型名中的点号 + escaped_model_name = self._escape_model_name(model_name) + + # 使用原子操作直接更新,避免竞态条件 + if cooldown_until is None: + # 删除指定模型的冷却 + result = await collection.update_one( + {"filename": filename}, + { + "$unset": {f"model_cooldowns.{escaped_model_name}": ""}, + "$set": {"updated_at": time.time()} + } + ) + else: + # 设置冷却时间 + result = await collection.update_one( + {"filename": filename}, + { + "$set": { + f"model_cooldowns.{escaped_model_name}": cooldown_until, + "updated_at": time.time() + } + } + ) + + if result.matched_count == 0: + log.warning(f"Credential {filename} not found") + return False + + # 同步写入 Redis TTL key + if self._redis_enabled: + cd_key = self._rk_cd(mode, filename, escaped_model_name) + if cooldown_until is None: + await self._redis.delete(cd_key) + else: + ttl = int(cooldown_until - time.time()) + if ttl > 0: + await self._redis.setex(cd_key, ttl, str(cooldown_until)) + else: + # 冷却已经过期,确保清除 + await self._redis.delete(cd_key) + + log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}") + return True + + except Exception as e: + log.error(f"Error setting model cooldown for {filename}: {e}") + return False + + async def clear_all_model_cooldowns( + self, + filename: str, + mode: str = "geminicli" + ) -> bool: + """ + 清除某个凭证的所有模型冷却时间 + + Args: + filename: 凭证文件名 + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 是否成功 + """ + self._ensure_initialized() + + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + + doc = await collection.find_one( + {"filename": filename}, + {"model_cooldowns": 1, "_id": 0} + ) + if not doc: + log.warning(f"Credential {filename} not found") + return False + + model_cooldowns = doc.get("model_cooldowns") or {} + + await collection.update_one( + {"filename": filename}, + { + "$set": { + "model_cooldowns": {}, + "updated_at": time.time(), + } + } + ) + + if self._redis_enabled and isinstance(model_cooldowns, dict) and model_cooldowns: + redis_keys = [self._rk_cd(mode, filename, escaped_model) for escaped_model in model_cooldowns.keys()] + await self._redis.delete(*redis_keys) + + log.debug(f"Cleared all model cooldowns: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error clearing all model cooldowns for {filename}: {e}") + return False + + async def record_success( + self, + filename: str, + model_name: Optional[str] = None, + mode: str = "geminicli" + ) -> None: + """ + 成功调用后的条件写入: + - 只有当前 error_codes 非空时才清除错误并写 last_success + - 只有当前存在该模型的冷却键时才清除 + 通过 MongoDB 服务端条件匹配实现 + """ + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + now = time.time() + + # 条件写入:只有 error_codes 非空时才触发,避免无意义的写 IO + await collection.update_one( + {"filename": filename, "error_codes": {"$ne": []}}, + {"$set": { + "last_success": now, + "error_codes": [], + "error_messages": {}, + "updated_at": now, + }} + ) + + # 条件删除模型冷却:只有该键存在时才写入 + if model_name: + escaped = self._escape_model_name(model_name) + await collection.update_one( + {"filename": filename, f"model_cooldowns.{escaped}": {"$exists": True}}, + {"$unset": {f"model_cooldowns.{escaped}": ""}, "$set": {"updated_at": now}} + ) + # 同步删除 Redis 冷却 key + if self._redis_enabled: + await self._redis.delete(self._rk_cd(mode, filename, escaped)) + + except Exception as e: + log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file diff --git a/src/storage/psql_manager.py b/src/storage/psql_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7fa372ed1764e288d86a0b4906ace89e8ef7e2 --- /dev/null +++ b/src/storage/psql_manager.py @@ -0,0 +1,1025 @@ +""" +PostgreSQL 存储管理器 +""" + +import asyncio +import json +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import asyncpg + +from log import log + + +class PSQLManager: + """PostgreSQL 数据库管理器""" + + # 状态字段常量 + STATE_FIELDS = { + "error_codes", + "error_messages", + "disabled", + "last_success", + "user_email", + "model_cooldowns", + "preview", + "tier", + "enable_credit", + } + + def __init__(self): + self._dsn: Optional[str] = None + self._pool: Optional[asyncpg.Pool] = None + self._initialized = False + self._lock = asyncio.Lock() + + # 内存配置缓存 + self._config_cache: Dict[str, Any] = {} + self._config_loaded = False + + async def initialize(self) -> None: + """初始化 PostgreSQL 数据库""" + if self._initialized: + return + + async with self._lock: + if self._initialized: + return + + try: + self._dsn = os.getenv("POSTGRESQL_URI", "") + if not self._dsn: + raise RuntimeError("POSTGRESQL_URI environment variable is not set") + + self._pool = await asyncpg.create_pool(self._dsn, min_size=2, max_size=10) + + async with self._pool.acquire() as conn: + await self._create_tables(conn) + await self._ensure_schema_compatibility(conn) + + await self._load_config_cache() + + self._initialized = True + log.info("PostgreSQL storage initialized") + + except Exception as e: + log.error(f"Error initializing PostgreSQL: {e}") + if self._pool: + await self._pool.close() + self._pool = None + raise + + async def _create_tables(self, conn: asyncpg.Connection) -> None: + """创建数据库表和索引""" + await conn.execute(""" + CREATE TABLE IF NOT EXISTS credentials ( + id SERIAL PRIMARY KEY, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + error_messages TEXT DEFAULT '[]', + last_success DOUBLE PRECISION, + user_email TEXT, + + model_cooldowns TEXT DEFAULT '{}', + preview INTEGER DEFAULT 1, + tier TEXT DEFAULT 'pro', + + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + created_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()) + ) + """) + + await conn.execute(""" + CREATE TABLE IF NOT EXISTS antigravity_credentials ( + id SERIAL PRIMARY KEY, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + error_messages TEXT DEFAULT '[]', + last_success DOUBLE PRECISION, + user_email TEXT, + + model_cooldowns TEXT DEFAULT '{}', + tier TEXT DEFAULT 'pro', + enable_credit INTEGER DEFAULT 0, + + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + created_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()) + ) + """) + + await conn.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW()) + ) + """) + + # 索引 + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_disabled ON credentials(disabled) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_rotation_order ON credentials(rotation_order) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_disabled ON antigravity_credentials(disabled) + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_rotation_order ON antigravity_credentials(rotation_order) + """) + + log.debug("PostgreSQL tables and indexes created") + + async def _ensure_schema_compatibility(self, conn: asyncpg.Connection) -> None: + """确保数据库结构兼容,自动修复缺失的列""" + required_columns = { + "credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("error_messages", "TEXT DEFAULT '[]'"), + ("last_success", "DOUBLE PRECISION"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("preview", "INTEGER DEFAULT 1"), + ("tier", "TEXT DEFAULT 'pro'"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"), + ("updated_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"), + ], + "antigravity_credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("error_messages", "TEXT DEFAULT '[]'"), + ("last_success", "DOUBLE PRECISION"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("tier", "TEXT DEFAULT 'pro'"), + ("enable_credit", "INTEGER DEFAULT 0"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"), + ("updated_at", "DOUBLE PRECISION DEFAULT EXTRACT(EPOCH FROM NOW())"), + ], + } + + try: + for table_name, columns in required_columns.items(): + rows = await conn.fetch(""" + SELECT column_name FROM information_schema.columns + WHERE table_name = $1 + """, table_name) + existing = {r["column_name"] for r in rows} + + for col_name, col_def in columns: + if col_name not in existing: + try: + await conn.execute( + f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}" + ) + log.info(f"Added missing column {table_name}.{col_name}") + except Exception as e: + log.error(f"Failed to add column {table_name}.{col_name}: {e}") + except Exception as e: + log.error(f"Error ensuring schema compatibility: {e}") + + async def _load_config_cache(self) -> None: + """加载配置到内存缓存""" + if self._config_loaded: + return + + try: + async with self._pool.acquire() as conn: + rows = await conn.fetch("SELECT key, value FROM config") + + for row in rows: + try: + self._config_cache[row["key"]] = json.loads(row["value"]) + except json.JSONDecodeError: + self._config_cache[row["key"]] = row["value"] + + self._config_loaded = True + log.debug(f"Loaded {len(self._config_cache)} config items into cache") + + except Exception as e: + log.error(f"Error loading config cache: {e}") + self._config_cache = {} + + async def close(self) -> None: + """关闭数据库连接池""" + if self._pool: + await self._pool.close() + self._pool = None + self._initialized = False + log.debug("PostgreSQL storage closed") + + def _ensure_initialized(self) -> None: + if not self._initialized or not self._pool: + raise RuntimeError("PostgreSQL manager not initialized") + + def _get_table_name(self, mode: str) -> str: + if mode == "antigravity": + return "antigravity_credentials" + elif mode == "geminicli": + return "credentials" + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + + # ============ 凭证查询方法 ============ + + async def get_next_available_credential( + self, mode: str = "geminicli", model_name: Optional[str] = None + ) -> Optional[Tuple[str, Dict[str, Any]]]: + """随机获取一个可用凭证(负载均衡)""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + current_time = time.time() + + async with self._pool.acquire() as conn: + if mode == "geminicli": + tier_clause = "" + if model_name and "pro" in model_name.lower(): + tier_clause = "AND (tier IS NULL OR tier != 'free')" + + rows = await conn.fetch(f""" + SELECT filename, credential_data, model_cooldowns, preview + FROM {table_name} + WHERE disabled = 0 {tier_clause} + ORDER BY RANDOM() + """) + + if not model_name: + if rows: + return rows[0]["filename"], json.loads(rows[0]["credential_data"]) + return None + + is_preview_model = "preview" in model_name.lower() + non_preview_creds = [] + preview_creds = [] + + for row in rows: + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + cd = model_cooldowns.get(model_name) + if cd is None or current_time >= cd: + if row["preview"]: + preview_creds.append((row["filename"], row["credential_data"])) + else: + non_preview_creds.append((row["filename"], row["credential_data"])) + + if is_preview_model: + if preview_creds: + return preview_creds[0][0], json.loads(preview_creds[0][1]) + else: + if non_preview_creds: + return non_preview_creds[0][0], json.loads(non_preview_creds[0][1]) + elif preview_creds: + return preview_creds[0][0], json.loads(preview_creds[0][1]) + + return None + else: + rows = await conn.fetch(f""" + SELECT filename, credential_data, model_cooldowns, enable_credit + FROM {table_name} + WHERE disabled = 0 + ORDER BY RANDOM() + """) + + if not model_name: + if rows: + credential_data = json.loads(rows[0]["credential_data"]) + credential_data["enable_credit"] = bool(rows[0]["enable_credit"]) + return rows[0]["filename"], credential_data + return None + + for row in rows: + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + cd = model_cooldowns.get(model_name) + if cd is None or current_time >= cd: + credential_data = json.loads(row["credential_data"]) + credential_data["enable_credit"] = bool(row["enable_credit"]) + return row["filename"], credential_data + + return None + + except Exception as e: + log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}") + return None + + async def get_available_credentials_list(self) -> List[str]: + """获取所有可用凭证列表""" + self._ensure_initialized() + + try: + async with self._pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT filename FROM credentials + WHERE disabled = 0 + ORDER BY rotation_order ASC + """) + return [r["filename"] for r in rows] + except Exception as e: + log.error(f"Error getting available credentials list: {e}") + return [] + + # ============ StorageBackend 协议方法 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储或更新凭证""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + existing = await conn.fetchrow( + f"SELECT rotation_order FROM {table_name} WHERE filename = $1", filename + ) + + if existing: + await conn.execute( + f""" + UPDATE {table_name} + SET credential_data = $1, + updated_at = EXTRACT(EPOCH FROM NOW()) + WHERE filename = $2 + """, + json.dumps(credential_data), filename + ) + else: + row = await conn.fetchrow( + f"SELECT COALESCE(MAX(rotation_order), -1) + 1 AS next_order FROM {table_name}" + ) + next_order = row["next_order"] + await conn.execute( + f""" + INSERT INTO {table_name} + (filename, credential_data, rotation_order, last_success) + VALUES ($1, $2, $3, $4) + """, + filename, json.dumps(credential_data), next_order, time.time() + ) + + log.debug(f"Stored credential: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error storing credential {filename}: {e}") + return False + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + row = await conn.fetchrow( + f"SELECT credential_data FROM {table_name} WHERE filename = $1", filename + ) + if row: + return json.loads(row["credential_data"]) + return None + except Exception as e: + log.error(f"Error getting credential {filename}: {e}") + return None + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名(包括禁用的)""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f"SELECT filename FROM {table_name} ORDER BY rotation_order" + ) + return [r["filename"] for r in rows] + except Exception as e: + log.error(f"Error listing credentials: {e}") + return [] + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + result = await conn.execute( + f"DELETE FROM {table_name} WHERE filename = $1", filename + ) + # asyncpg returns "DELETE N" + deleted_count = int(result.split()[-1]) + + if deleted_count > 0: + log.debug(f"Deleted credential: {filename} (mode={mode})") + return True + else: + log.warning(f"No credential found to delete: {filename} (mode={mode})") + return False + + except Exception as e: + log.error(f"Error deleting credential {filename}: {e}") + return False + + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + log.debug(f"[DB] update_credential_state: filename={filename}, updates={state_updates}, mode={mode}") + + set_clauses = [] + values = [] + idx = 1 + + for key, value in state_updates.items(): + if key in self.STATE_FIELDS: + if key == "enable_credit" and mode != "antigravity": + continue + if key in ("error_codes", "error_messages", "model_cooldowns"): + set_clauses.append(f"{key} = ${idx}") + values.append(json.dumps(value)) + else: + set_clauses.append(f"{key} = ${idx}") + values.append(value) + idx += 1 + + if not set_clauses: + return True + + set_clauses.append(f"updated_at = EXTRACT(EPOCH FROM NOW())") + values.append(filename) + + sql = f""" + UPDATE {table_name} + SET {', '.join(set_clauses)} + WHERE filename = ${idx} + """ + + async with self._pool.acquire() as conn: + result = await conn.execute(sql, *values) + updated_count = int(result.split()[-1]) + + return updated_count > 0 + + except Exception as e: + log.error(f"[DB] Error updating credential state {filename}: {e}") + return False + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + if mode == "geminicli": + row = await conn.fetchrow(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns, preview, tier + FROM {table_name} WHERE filename = $1 + """, filename) + + if row: + return { + "disabled": bool(row["disabled"]), + "error_codes": json.loads(row["error_codes"] or "[]"), + "last_success": row["last_success"] or time.time(), + "user_email": row["user_email"], + "model_cooldowns": json.loads(row["model_cooldowns"] or "{}"), + "preview": bool(row["preview"]) if row["preview"] is not None else True, + "tier": row["tier"] if row["tier"] is not None else "pro", + } + + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + "preview": True, + "tier": "pro", + } + else: + row = await conn.fetchrow(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns, tier, enable_credit + FROM {table_name} WHERE filename = $1 + """, filename) + + if row: + return { + "disabled": bool(row["disabled"]), + "error_codes": json.loads(row["error_codes"] or "[]"), + "last_success": row["last_success"] or time.time(), + "user_email": row["user_email"], + "model_cooldowns": json.loads(row["model_cooldowns"] or "{}"), + "tier": row["tier"] if row["tier"] is not None else "pro", + "enable_credit": bool(row["enable_credit"]) if row["enable_credit"] is not None else False, + } + + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + "tier": "pro", + "enable_credit": False, + } + + except Exception as e: + log.error(f"Error getting credential state {filename}: {e}") + return {} + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + current_time = time.time() + + async with self._pool.acquire() as conn: + if mode == "geminicli": + rows = await conn.fetch(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, model_cooldowns, preview, tier + FROM {table_name} + """) + + states = {} + for row in rows: + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + if model_cooldowns: + model_cooldowns = {k: v for k, v in model_cooldowns.items() if v > current_time} + + states[row["filename"]] = { + "disabled": bool(row["disabled"]), + "error_codes": json.loads(row["error_codes"] or "[]"), + "last_success": row["last_success"] or current_time, + "user_email": row["user_email"], + "model_cooldowns": model_cooldowns, + "preview": bool(row["preview"]) if row["preview"] is not None else True, + "tier": row["tier"] if row["tier"] is not None else "pro", + } + return states + else: + rows = await conn.fetch(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, model_cooldowns, tier, enable_credit + FROM {table_name} + """) + + states = {} + for row in rows: + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + if model_cooldowns: + model_cooldowns = {k: v for k, v in model_cooldowns.items() if v > current_time} + + states[row["filename"]] = { + "disabled": bool(row["disabled"]), + "error_codes": json.loads(row["error_codes"] or "[]"), + "last_success": row["last_success"] or current_time, + "user_email": row["user_email"], + "model_cooldowns": model_cooldowns, + "tier": row["tier"] if row["tier"] is not None else "pro", + "enable_credit": bool(row["enable_credit"]) if row["enable_credit"] is not None else False, + } + return states + + except Exception as e: + log.error(f"Error getting all credential states: {e}") + return {} + + async def get_credentials_summary( + self, + offset: int = 0, + limit: Optional[int] = None, + status_filter: str = "all", + mode: str = "geminicli", + error_code_filter: Optional[str] = None, + cooldown_filter: Optional[str] = None, + preview_filter: Optional[str] = None, + tier_filter: Optional[str] = None + ) -> Dict[str, Any]: + """获取凭证的摘要信息,支持分页和状态筛选""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + current_time = time.time() + + async with self._pool.acquire() as conn: + # 全局统计 + stats_rows = await conn.fetch( + f"SELECT disabled, COUNT(*) AS cnt FROM {table_name} GROUP BY disabled" + ) + global_stats = {"total": 0, "normal": 0, "disabled": 0} + for r in stats_rows: + global_stats["total"] += r["cnt"] + if r["disabled"]: + global_stats["disabled"] = r["cnt"] + else: + global_stats["normal"] = r["cnt"] + + # WHERE 子句 + where_clauses = [] + if status_filter == "enabled": + where_clauses.append("disabled = 0") + elif status_filter == "disabled": + where_clauses.append("disabled = 1") + + where_clause = ("WHERE " + " AND ".join(where_clauses)) if where_clauses else "" + + # 查询 + if mode == "geminicli": + all_rows = await conn.fetch(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, rotation_order, model_cooldowns, preview, tier + FROM {table_name} + {where_clause} + ORDER BY rotation_order + """) + else: + all_rows = await conn.fetch(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, rotation_order, model_cooldowns, tier, enable_credit + FROM {table_name} + {where_clause} + ORDER BY rotation_order + """) + + # 错误码筛选 + filter_value = None + filter_int = None + if error_code_filter and str(error_code_filter).strip().lower() != "all": + filter_value = str(error_code_filter).strip() + try: + filter_int = int(filter_value) + except ValueError: + filter_int = None + + all_summaries = [] + for row in all_rows: + error_codes_json = row["error_codes"] or "[]" + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + active_cooldowns = {k: v for k, v in model_cooldowns.items() if v > current_time} + error_codes = json.loads(error_codes_json) + + if filter_value: + match = False + for code in error_codes: + if code == filter_value or code == filter_int: + match = True + break + if isinstance(code, str) and filter_int is not None: + try: + if int(code) == filter_int: + match = True + break + except ValueError: + pass + if not match: + continue + + summary = { + "filename": row["filename"], + "disabled": bool(row["disabled"]), + "error_codes": error_codes, + "last_success": row["last_success"] or current_time, + "user_email": row["user_email"], + "rotation_order": row["rotation_order"], + "model_cooldowns": active_cooldowns, + "tier": row["tier"] if row["tier"] is not None else "pro", + } + + if mode == "geminicli": + summary["preview"] = bool(row["preview"]) if row["preview"] is not None else True + + if preview_filter: + preview_value = summary.get("preview", True) + if preview_filter == "preview" and not preview_value: + continue + elif preview_filter == "no_preview" and preview_value: + continue + else: + summary["enable_credit"] = bool(row["enable_credit"]) if row["enable_credit"] is not None else False + + if tier_filter and tier_filter in ("free", "pro", "ultra"): + if summary["tier"] != tier_filter: + continue + + if cooldown_filter == "in_cooldown": + if active_cooldowns: + all_summaries.append(summary) + elif cooldown_filter == "no_cooldown": + if not active_cooldowns: + all_summaries.append(summary) + else: + all_summaries.append(summary) + + total_count = len(all_summaries) + if limit is not None: + summaries = all_summaries[offset:offset + limit] + else: + summaries = all_summaries[offset:] + + return { + "items": summaries, + "total": total_count, + "offset": offset, + "limit": limit, + "stats": global_stats, + } + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return { + "items": [], + "total": 0, + "offset": offset, + "limit": limit, + "stats": {"total": 0, "normal": 0, "disabled": 0}, + } + + async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + """获取按邮箱分组的重复凭证信息""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f"SELECT filename, user_email FROM {table_name} ORDER BY filename" + ) + + email_to_files: Dict[str, List[str]] = {} + no_email_files: List[str] = [] + + for row in rows: + if row["user_email"]: + email_to_files.setdefault(row["user_email"], []).append(row["filename"]) + else: + no_email_files.append(row["filename"]) + + duplicate_groups = [] + total_duplicate_count = 0 + for email, files in email_to_files.items(): + if len(files) > 1: + duplicate_groups.append({ + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + }) + total_duplicate_count += len(files) - 1 + + return { + "email_groups": email_to_files, + "duplicate_groups": duplicate_groups, + "duplicate_count": total_duplicate_count, + "no_email_files": no_email_files, + "no_email_count": len(no_email_files), + "unique_email_count": len(email_to_files), + "total_count": len(rows), + } + + except Exception as e: + log.error(f"Error getting duplicate credentials by email: {e}") + return { + "email_groups": {}, + "duplicate_groups": [], + "duplicate_count": 0, + "no_email_files": [], + "no_email_count": 0, + "unique_email_count": 0, + "total_count": 0, + } + + # ============ 配置管理(内存缓存)============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置(写入数据库 + 更新内存缓存)""" + self._ensure_initialized() + + try: + async with self._pool.acquire() as conn: + await conn.execute(""" + INSERT INTO config (key, value, updated_at) + VALUES ($1, $2, EXTRACT(EPOCH FROM NOW())) + ON CONFLICT (key) DO UPDATE + SET value = EXCLUDED.value, + updated_at = EXCLUDED.updated_at + """, key, json.dumps(value)) + + self._config_cache[key] = value + return True + + except Exception as e: + log.error(f"Error setting config {key}: {e}") + return False + + async def reload_config_cache(self) -> None: + """重新加载配置缓存""" + self._ensure_initialized() + self._config_loaded = False + await self._load_config_cache() + log.info("Config cache reloaded from database") + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.get(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.copy() + + async def delete_config(self, key: str) -> bool: + """删除配置""" + self._ensure_initialized() + + try: + async with self._pool.acquire() as conn: + await conn.execute("DELETE FROM config WHERE key = $1", key) + + self._config_cache.pop(key, None) + return True + + except Exception as e: + log.error(f"Error deleting config {key}: {e}") + return False + + async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证的错误信息""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + row = await conn.fetchrow( + f"SELECT error_codes, error_messages FROM {table_name} WHERE filename = $1", + filename + ) + + if row: + return { + "filename": filename, + "error_codes": json.loads(row["error_codes"] or "[]"), + "error_messages": json.loads(row["error_messages"] or "[]"), + } + + return {"filename": filename, "error_codes": [], "error_messages": []} + + except Exception as e: + log.error(f"Error getting credential errors {filename}: {e}") + return {"filename": filename, "error_codes": [], "error_messages": [], "error": str(e)} + + # ============ 模型级冷却管理 ============ + + async def set_model_cooldown( + self, + filename: str, + model_name: str, + cooldown_until: Optional[float], + mode: str = "geminicli" + ) -> bool: + """设置特定模型的冷却时间""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + row = await conn.fetchrow( + f"SELECT model_cooldowns FROM {table_name} WHERE filename = $1", filename + ) + + if not row: + log.warning(f"Credential {filename} not found") + return False + + model_cooldowns = json.loads(row["model_cooldowns"] or "{}") + + if cooldown_until is None: + model_cooldowns.pop(model_name, None) + else: + model_cooldowns[model_name] = cooldown_until + + await conn.execute( + f""" + UPDATE {table_name} + SET model_cooldowns = $1, + updated_at = EXTRACT(EPOCH FROM NOW()) + WHERE filename = $2 + """, + json.dumps(model_cooldowns), filename + ) + + log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}") + return True + + except Exception as e: + log.error(f"Error setting model cooldown for {filename}: {e}") + return False + + async def clear_all_model_cooldowns( + self, + filename: str, + mode: str = "geminicli" + ) -> bool: + """清除某个凭证的所有模型冷却时间""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + result = await conn.execute( + f""" + UPDATE {table_name} + SET model_cooldowns = '{{}}', + updated_at = EXTRACT(EPOCH FROM NOW()) + WHERE filename = $1 + """, + filename, + ) + updated_count = int(result.split()[-1]) + + if updated_count == 0: + log.warning(f"Credential {filename} not found") + return False + + log.debug(f"Cleared all model cooldowns: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error clearing all model cooldowns for {filename}: {e}") + return False + + async def record_success( + self, + filename: str, + model_name: Optional[str] = None, + mode: str = "geminicli" + ) -> None: + """成功调用后的条件写入""" + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with self._pool.acquire() as conn: + await conn.execute(f""" + UPDATE {table_name} + SET last_success = EXTRACT(EPOCH FROM NOW()), + error_codes = '[]', + error_messages = '{{}}', + updated_at = EXTRACT(EPOCH FROM NOW()) + WHERE filename = $1 + AND (error_codes IS NOT NULL AND error_codes != '[]' AND error_codes != '') + """, filename) + + if model_name: + row = await conn.fetchrow( + f"SELECT model_cooldowns FROM {table_name} WHERE filename = $1", filename + ) + if row: + cooldowns = json.loads(row["model_cooldowns"] or "{}") + if model_name in cooldowns: + cooldowns.pop(model_name) + await conn.execute( + f""" + UPDATE {table_name} + SET model_cooldowns = $1, updated_at = EXTRACT(EPOCH FROM NOW()) + WHERE filename = $2 + """, + json.dumps(cooldowns), filename + ) + + except Exception as e: + log.error(f"Error recording success for {filename}: {e}") diff --git a/src/storage/sqlite_manager.py b/src/storage/sqlite_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..95a107487f573c87b8d376eb8393c88fa277ccbd --- /dev/null +++ b/src/storage/sqlite_manager.py @@ -0,0 +1,1365 @@ +""" +SQLite 存储管理器 +""" + +import asyncio +import json +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import aiosqlite + +from log import log + + +class SQLiteManager: + """SQLite 数据库管理器""" + + # 状态字段常量 + STATE_FIELDS = { + "error_codes", + "error_messages", + "disabled", + "last_success", + "user_email", + "model_cooldowns", + "preview", + "tier", + "enable_credit", + } + + # 所有必需的列定义(用于自动校验和修复) + REQUIRED_COLUMNS = { + "credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("error_messages", "TEXT DEFAULT '[]'"), + ("last_success", "REAL"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("preview", "INTEGER DEFAULT 1"), + ("tier", "TEXT DEFAULT 'pro'"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "REAL DEFAULT (unixepoch())"), + ("updated_at", "REAL DEFAULT (unixepoch())") + ], + "antigravity_credentials": [ + ("disabled", "INTEGER DEFAULT 0"), + ("error_codes", "TEXT DEFAULT '[]'"), + ("error_messages", "TEXT DEFAULT '[]'"), + ("last_success", "REAL"), + ("user_email", "TEXT"), + ("model_cooldowns", "TEXT DEFAULT '{}'"), + ("tier", "TEXT DEFAULT 'pro'"), + ("enable_credit", "INTEGER DEFAULT 0"), + ("rotation_order", "INTEGER DEFAULT 0"), + ("call_count", "INTEGER DEFAULT 0"), + ("created_at", "REAL DEFAULT (unixepoch())"), + ("updated_at", "REAL DEFAULT (unixepoch())") + ] + } + + def __init__(self): + self._db_path = None + self._credentials_dir = None + self._initialized = False + self._lock = asyncio.Lock() + + # 内存配置缓存 - 初始化时加载一次 + self._config_cache: Dict[str, Any] = {} + self._config_loaded = False + + async def initialize(self) -> None: + """初始化 SQLite 数据库""" + if self._initialized: + return + + async with self._lock: + if self._initialized: + return + + try: + # 获取凭证目录 + self._credentials_dir = os.getenv("CREDENTIALS_DIR", "./creds") + self._db_path = os.path.join(self._credentials_dir, "credentials.db") + + # 确保目录存在 + os.makedirs(self._credentials_dir, exist_ok=True) + + # 创建数据库和表 + async with aiosqlite.connect(self._db_path) as db: + # 启用 WAL 模式(提升并发性能) + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA foreign_keys=ON") + + # 检查并自动修复数据库结构 + await self._ensure_schema_compatibility(db) + + # 创建表 + await self._create_tables(db) + + # 修复可能包含路径的凭证文件名 + await self._repair_credential_filenames(db) + + await db.commit() + + # 加载配置到内存 + await self._load_config_cache() + + self._initialized = True + log.info(f"SQLite storage initialized at {self._db_path}") + + except Exception as e: + log.error(f"Error initializing SQLite: {e}") + raise + + async def _ensure_schema_compatibility(self, db: aiosqlite.Connection) -> None: + """ + 确保数据库结构兼容,自动修复缺失的列 + """ + try: + # 检查每个表 + for table_name, columns in self.REQUIRED_COLUMNS.items(): + # 检查表是否存在 + async with db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) as cursor: + if not await cursor.fetchone(): + log.debug(f"Table {table_name} does not exist, will be created") + continue + + # 获取现有列 + async with db.execute(f"PRAGMA table_info({table_name})") as cursor: + existing_columns = {row[1] for row in await cursor.fetchall()} + + # 添加缺失的列 + added_count = 0 + for col_name, col_def in columns: + if col_name not in existing_columns: + try: + await db.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}") + log.info(f"Added missing column {table_name}.{col_name}") + added_count += 1 + except Exception as e: + log.error(f"Failed to add column {table_name}.{col_name}: {e}") + + if added_count > 0: + log.info(f"Table {table_name}: added {added_count} missing column(s)") + + except Exception as e: + log.error(f"Error ensuring schema compatibility: {e}") + # 不抛出异常,允许继续初始化 + + async def _create_tables(self, db: aiosqlite.Connection): + """创建数据库表和索引""" + # 凭证表 + await db.execute(""" + CREATE TABLE IF NOT EXISTS credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + -- 状态字段 + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + error_messages TEXT DEFAULT '[]', + last_success REAL, + user_email TEXT, + + -- 模型级 CD 支持 (JSON: {model_name: cooldown_timestamp}) + model_cooldowns TEXT DEFAULT '{}', + + -- preview 状态 (只对 geminicli 有效,默认为 true) + preview INTEGER DEFAULT 1, + + -- tier 状态 (只对 geminicli 有效,默认为 pro) + tier TEXT DEFAULT 'pro', + + -- 轮换相关 + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + -- 时间戳 + created_at REAL DEFAULT (unixepoch()), + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + # Antigravity 凭证表(结构相同但独立存储) + await db.execute(""" + CREATE TABLE IF NOT EXISTS antigravity_credentials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filename TEXT UNIQUE NOT NULL, + credential_data TEXT NOT NULL, + + -- 状态字段 + disabled INTEGER DEFAULT 0, + error_codes TEXT DEFAULT '[]', + error_messages TEXT DEFAULT '[]', + last_success REAL, + user_email TEXT, + + -- 模型级 CD 支持 (JSON: {model_name: cooldown_timestamp}) + model_cooldowns TEXT DEFAULT '{}', + + -- tier 状态 (默认为 pro) + tier TEXT DEFAULT 'pro', + + -- 是否启用信用额度模式(仅 antigravity,有效值 0/1) + enable_credit INTEGER DEFAULT 0, + + -- 轮换相关 + rotation_order INTEGER DEFAULT 0, + call_count INTEGER DEFAULT 0, + + -- 时间戳 + created_at REAL DEFAULT (unixepoch()), + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + # 创建索引 - 普通凭证表 + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_disabled + ON credentials(disabled) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_rotation_order + ON credentials(rotation_order) + """) + + # 创建索引 - Antigravity 凭证表 + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_disabled + ON antigravity_credentials(disabled) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_ag_rotation_order + ON antigravity_credentials(rotation_order) + """) + + # 配置表 + await db.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at REAL DEFAULT (unixepoch()) + ) + """) + + log.debug("SQLite tables and indexes created") + + async def _repair_credential_filenames(self, db: aiosqlite.Connection): + """ + 修复凭证数据库中可能包含路径的文件名,确保所有文件名都是 basename + """ + try: + repaired_count = 0 + + # 修复 credentials 表 + async with db.execute("SELECT filename FROM credentials") as cursor: + rows = await cursor.fetchall() + for (filename,) in rows: + basename = os.path.basename(filename) + if basename != filename: + # 检查是否会产生冲突 + async with db.execute( + "SELECT COUNT(*) FROM credentials WHERE filename = ?", + (basename,) + ) as check_cursor: + count = (await check_cursor.fetchone())[0] + + if count == 0: + # 无冲突,直接更新 + await db.execute( + "UPDATE credentials SET filename = ? WHERE filename = ?", + (basename, filename) + ) + repaired_count += 1 + log.info(f"Repaired credential filename: {filename} -> {basename}") + else: + # 有冲突,删除带路径的旧记录(保留 basename 的记录) + await db.execute( + "DELETE FROM credentials WHERE filename = ?", + (filename,) + ) + repaired_count += 1 + log.warning(f"Removed duplicate credential with path: {filename} (kept {basename})") + + # 修复 antigravity_credentials 表 + async with db.execute("SELECT filename FROM antigravity_credentials") as cursor: + rows = await cursor.fetchall() + for (filename,) in rows: + basename = os.path.basename(filename) + if basename != filename: + # 检查是否会产生冲突 + async with db.execute( + "SELECT COUNT(*) FROM antigravity_credentials WHERE filename = ?", + (basename,) + ) as check_cursor: + count = (await check_cursor.fetchone())[0] + + if count == 0: + # 无冲突,直接更新 + await db.execute( + "UPDATE antigravity_credentials SET filename = ? WHERE filename = ?", + (basename, filename) + ) + repaired_count += 1 + log.info(f"Repaired antigravity credential filename: {filename} -> {basename}") + else: + # 有冲突,删除带路径的旧记录(保留 basename 的记录) + await db.execute( + "DELETE FROM antigravity_credentials WHERE filename = ?", + (filename,) + ) + repaired_count += 1 + log.warning(f"Removed duplicate antigravity credential with path: {filename} (kept {basename})") + + if repaired_count > 0: + log.info(f"Repaired {repaired_count} credential filename(s)") + else: + log.debug("No credential filenames need repair") + + except Exception as e: + log.error(f"Error repairing credential filenames: {e}") + # 不抛出异常,允许继续初始化 + + async def _load_config_cache(self): + """加载配置到内存缓存(仅在初始化时调用一次)""" + if self._config_loaded: + return + + try: + async with aiosqlite.connect(self._db_path) as db: + async with db.execute("SELECT key, value FROM config") as cursor: + rows = await cursor.fetchall() + + for key, value in rows: + try: + self._config_cache[key] = json.loads(value) + except json.JSONDecodeError: + self._config_cache[key] = value + + self._config_loaded = True + log.debug(f"Loaded {len(self._config_cache)} config items into cache") + + except Exception as e: + log.error(f"Error loading config cache: {e}") + self._config_cache = {} + + async def close(self) -> None: + """关闭数据库连接""" + self._initialized = False + log.debug("SQLite storage closed") + + def _ensure_initialized(self): + """确保已初始化""" + if not self._initialized: + raise RuntimeError("SQLite manager not initialized") + + def _get_table_name(self, mode: str) -> str: + """根据 mode 获取对应的表名""" + if mode == "antigravity": + return "antigravity_credentials" + elif mode == "geminicli": + return "credentials" + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + + # ============ SQL 方法 ============ + + async def get_next_available_credential( + self, mode: str = "geminicli", model_name: Optional[str] = None + ) -> Optional[Tuple[str, Dict[str, Any]]]: + """ + 随机获取一个可用凭证(负载均衡) + - 未禁用 + - 如果提供了 model_name,还会检查模型级冷却和preview状态 + - 随机选择 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + model_name: 完整模型名(如 "gemini-2.0-flash-exp", "gemini-3-flash-preview") + """ + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + current_time = time.time() + + if mode == "geminicli": + tier_clause = "" + if model_name and "pro" in model_name.lower(): + tier_clause = "AND (tier IS NULL OR tier != 'free')" + + async with db.execute(f""" + SELECT filename, credential_data, model_cooldowns, preview + FROM {table_name} + WHERE disabled = 0 {tier_clause} + ORDER BY RANDOM() + """) as cursor: + rows = await cursor.fetchall() + + if not model_name: + if rows: + filename, credential_json, _, _ = rows[0] + credential_data = json.loads(credential_json) + return filename, credential_data + return None + + is_preview_model = "preview" in model_name.lower() + non_preview_creds = [] + preview_creds = [] + + for filename, credential_json, model_cooldowns_json, preview in rows: + model_cooldowns = json.loads(model_cooldowns_json or '{}') + model_cooldown = model_cooldowns.get(model_name) + if model_cooldown is None or current_time >= model_cooldown: + if preview: + preview_creds.append((filename, credential_json)) + else: + non_preview_creds.append((filename, credential_json)) + + if is_preview_model: + if preview_creds: + filename, credential_json = preview_creds[0] + credential_data = json.loads(credential_json) + return filename, credential_data + else: + if non_preview_creds: + filename, credential_json = non_preview_creds[0] + credential_data = json.loads(credential_json) + return filename, credential_data + elif preview_creds: + filename, credential_json = preview_creds[0] + credential_data = json.loads(credential_json) + return filename, credential_data + + return None + else: + async with db.execute(f""" + SELECT filename, credential_data, model_cooldowns, enable_credit + FROM {table_name} + WHERE disabled = 0 + ORDER BY RANDOM() + """) as cursor: + rows = await cursor.fetchall() + + if not model_name: + if rows: + filename, credential_json, _, enable_credit = rows[0] + credential_data = json.loads(credential_json) + credential_data["enable_credit"] = bool(enable_credit) + return filename, credential_data + return None + + for filename, credential_json, model_cooldowns_json, enable_credit in rows: + model_cooldowns = json.loads(model_cooldowns_json or '{}') + model_cooldown = model_cooldowns.get(model_name) + if model_cooldown is None or current_time >= model_cooldown: + credential_data = json.loads(credential_json) + credential_data["enable_credit"] = bool(enable_credit) + return filename, credential_data + + return None + + except Exception as e: + log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}") + return None + + async def get_available_credentials_list(self) -> List[str]: + """ + 获取所有可用凭证列表 + - 未禁用 + - 按轮换顺序排序 + """ + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + async with db.execute(""" + SELECT filename + FROM credentials + WHERE disabled = 0 + ORDER BY rotation_order ASC + """) as cursor: + rows = await cursor.fetchall() + return [row[0] for row in rows] + + except Exception as e: + log.error(f"Error getting available credentials list: {e}") + return [] + + # ============ StorageBackend 协议方法 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储或更新凭证""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 检查凭证是否存在 + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, + rotation_order, call_count + FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + existing = await cursor.fetchone() + + if existing: + # 更新现有凭证(保留状态) + await db.execute(f""" + UPDATE {table_name} + SET credential_data = ?, + updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(credential_data), filename)) + else: + # 插入新凭证 + async with db.execute(f""" + SELECT COALESCE(MAX(rotation_order), -1) + 1 FROM {table_name} + """) as cursor: + row = await cursor.fetchone() + next_order = row[0] + + await db.execute(f""" + INSERT INTO {table_name} + (filename, credential_data, rotation_order, last_success) + VALUES (?, ?, ?, ?) + """, (filename, json.dumps(credential_data), next_order, time.time())) + + await db.commit() + log.debug(f"Stored credential: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error storing credential {filename}: {e}") + return False + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 精确匹配 + async with db.execute(f""" + SELECT credential_data FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + if row: + return json.loads(row[0]) + + return None + + except Exception as e: + log.error(f"Error getting credential {filename}: {e}") + return None + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名(包括禁用的)""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + async with db.execute(f""" + SELECT filename FROM {table_name} ORDER BY rotation_order + """) as cursor: + rows = await cursor.fetchall() + return [row[0] for row in rows] + + except Exception as e: + log.error(f"Error listing credentials: {e}") + return [] + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 精确匹配删除 + result = await db.execute(f""" + DELETE FROM {table_name} WHERE filename = ? + """, (filename,)) + deleted_count = result.rowcount + + await db.commit() + + if deleted_count > 0: + log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") + return True + else: + log.warning(f"No credential found to delete: {filename} (mode={mode})") + return False + + except Exception as e: + log.error(f"Error deleting credential {filename}: {e}") + return False + + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + log.debug(f"[DB] update_credential_state 开始: filename={filename}, state_updates={state_updates}, mode={mode}, table={table_name}") + + # 构建动态 SQL + set_clauses = [] + values = [] + + for key, value in state_updates.items(): + if key in self.STATE_FIELDS: + if key == "enable_credit" and mode != "antigravity": + continue + if key in ("error_codes", "error_messages", "model_cooldowns"): + # JSON 字段需要序列化 + set_clauses.append(f"{key} = ?") + values.append(json.dumps(value)) + else: + set_clauses.append(f"{key} = ?") + values.append(value) + + if not set_clauses: + log.info(f"[DB] 没有需要更新的状态字段") + return True + + set_clauses.append("updated_at = unixepoch()") + values.append(filename) + + log.debug(f"[DB] SQL参数: set_clauses={set_clauses}, values={values}") + + async with aiosqlite.connect(self._db_path) as db: + # 精确匹配更新 + sql_exact = f""" + UPDATE {table_name} + SET {', '.join(set_clauses)} + WHERE filename = ? + """ + log.debug(f"[DB] 执行精确匹配SQL: {sql_exact}") + log.debug(f"[DB] SQL参数值: {values}") + + result = await db.execute(sql_exact, values) + updated_count = result.rowcount + log.debug(f"[DB] 精确匹配 rowcount={updated_count}") + + # 提交前检查 + log.debug(f"[DB] 准备commit,总更新行数={updated_count}") + await db.commit() + log.debug(f"[DB] commit完成") + + success = updated_count > 0 + log.debug(f"[DB] update_credential_state 结束: success={success}, updated_count={updated_count}") + return success + + except Exception as e: + log.error(f"[DB] Error updating credential state {filename}: {e}") + return False + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态(不包含error_messages)""" + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 精确匹配 + if mode == "geminicli": + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns, preview, tier + FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if row: + error_codes_json = row[1] or '[]' + model_cooldowns_json = row[4] or '{}' + return { + "disabled": bool(row[0]), + "error_codes": json.loads(error_codes_json), + "last_success": row[2] or time.time(), + "user_email": row[3], + "model_cooldowns": json.loads(model_cooldowns_json), + "preview": bool(row[5]) if row[5] is not None else True, + "tier": row[6] if row[6] is not None else "pro", + } + + # 返回默认状态 + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + "preview": True, + "tier": "pro", + } + else: + # antigravity 模式 + async with db.execute(f""" + SELECT disabled, error_codes, last_success, user_email, model_cooldowns, tier, enable_credit + FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if row: + error_codes_json = row[1] or '[]' + model_cooldowns_json = row[4] or '{}' + return { + "disabled": bool(row[0]), + "error_codes": json.loads(error_codes_json), + "last_success": row[2] or time.time(), + "user_email": row[3], + "model_cooldowns": json.loads(model_cooldowns_json), + "tier": row[5] if row[5] is not None else "pro", + "enable_credit": bool(row[6]) if row[6] is not None else False, + } + + # 返回默认状态 + return { + "disabled": False, + "error_codes": [], + "last_success": time.time(), + "user_email": None, + "model_cooldowns": {}, + "tier": "pro", + "enable_credit": False, + } + + except Exception as e: + log.error(f"Error getting credential state {filename}: {e}") + return {} + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态(不包含error_messages)""" + self._ensure_initialized() + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + if mode == "geminicli": + async with db.execute(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, model_cooldowns, preview, tier + FROM {table_name} + """) as cursor: + rows = await cursor.fetchall() + + states = {} + current_time = time.time() + + for row in rows: + filename = row[0] + error_codes_json = row[2] or '[]' + model_cooldowns_json = row[5] or '{}' + model_cooldowns = json.loads(model_cooldowns_json) + + # 自动过滤掉已过期的模型CD + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + states[filename] = { + "disabled": bool(row[1]), + "error_codes": json.loads(error_codes_json), + "last_success": row[3] or time.time(), + "user_email": row[4], + "model_cooldowns": model_cooldowns, + "preview": bool(row[6]) if row[6] is not None else True, + "tier": row[7] if row[7] is not None else "pro", + } + + return states + else: + # antigravity 模式 + async with db.execute(f""" + SELECT filename, disabled, error_codes, last_success, + user_email, model_cooldowns, tier, enable_credit + FROM {table_name} + """) as cursor: + rows = await cursor.fetchall() + + states = {} + current_time = time.time() + + for row in rows: + filename = row[0] + error_codes_json = row[2] or '[]' + model_cooldowns_json = row[5] or '{}' + model_cooldowns = json.loads(model_cooldowns_json) + + # 自动过滤掉已过期的模型CD + if model_cooldowns: + model_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + states[filename] = { + "disabled": bool(row[1]), + "error_codes": json.loads(error_codes_json), + "last_success": row[3] or time.time(), + "user_email": row[4], + "model_cooldowns": model_cooldowns, + "tier": row[6] if row[6] is not None else "pro", + "enable_credit": bool(row[7]) if row[7] is not None else False, + } + + return states + + except Exception as e: + log.error(f"Error getting all credential states: {e}") + return {} + + async def get_credentials_summary( + self, + offset: int = 0, + limit: Optional[int] = None, + status_filter: str = "all", + mode: str = "geminicli", + error_code_filter: Optional[str] = None, + cooldown_filter: Optional[str] = None, + preview_filter: Optional[str] = None, + tier_filter: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选 + + Args: + offset: 跳过的记录数(默认0) + limit: 返回的最大记录数(None表示返回所有) + status_filter: 状态筛选(all=全部, enabled=仅启用, disabled=仅禁用) + mode: 凭证模式 ("geminicli" 或 "antigravity") + error_code_filter: 错误码筛选(格式如"400"或"403",筛选包含该错误码的凭证) + cooldown_filter: 冷却状态筛选("in_cooldown"=冷却中, "no_cooldown"=未冷却) + preview_filter: Preview筛选("preview"=支持preview, "no_preview"=不支持preview,仅geminicli模式有效) + tier_filter: tier筛选("free", "pro", "ultra") + + Returns: + 包含 items(凭证列表)、total(总数)、offset、limit 的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择表名 + table_name = self._get_table_name(mode) + + async with aiosqlite.connect(self._db_path) as db: + # 先计算全局统计数据(不受筛选条件影响) + global_stats = {"total": 0, "normal": 0, "disabled": 0} + async with db.execute(f""" + SELECT disabled, COUNT(*) FROM {table_name} GROUP BY disabled + """) as stats_cursor: + stats_rows = await stats_cursor.fetchall() + for disabled, count in stats_rows: + global_stats["total"] += count + if disabled: + global_stats["disabled"] = count + else: + global_stats["normal"] = count + + # 构建WHERE子句 + where_clauses = [] + count_params = [] + + if status_filter == "enabled": + where_clauses.append("disabled = 0") + elif status_filter == "disabled": + where_clauses.append("disabled = 1") + + filter_value = None + filter_int = None + if error_code_filter and str(error_code_filter).strip().lower() != "all": + filter_value = str(error_code_filter).strip() + try: + filter_int = int(filter_value) + except ValueError: + filter_int = None + + # 构建WHERE子句 + where_clause = "" + if where_clauses: + where_clause = "WHERE " + " AND ".join(where_clauses) + + # 先获取所有数据(用于冷却筛选,因为需要在Python中判断) + if mode == "geminicli": + all_query = f""" + SELECT filename, disabled, error_codes, last_success, + user_email, rotation_order, model_cooldowns, preview, tier + FROM {table_name} + {where_clause} + ORDER BY rotation_order + """ + else: + all_query = f""" + SELECT filename, disabled, error_codes, last_success, + user_email, rotation_order, model_cooldowns, tier, enable_credit + FROM {table_name} + {where_clause} + ORDER BY rotation_order + """ + + async with db.execute(all_query, count_params) as cursor: + all_rows = await cursor.fetchall() + + current_time = time.time() + all_summaries = [] + + for row in all_rows: + filename = row[0] + error_codes_json = row[2] or '[]' + model_cooldowns_json = row[6] or '{}' + model_cooldowns = json.loads(model_cooldowns_json) + + # 自动过滤掉已过期的模型CD + active_cooldowns = {} + if model_cooldowns: + active_cooldowns = { + k: v for k, v in model_cooldowns.items() + if v > current_time + } + + error_codes = json.loads(error_codes_json) + if filter_value: + match = False + for code in error_codes: + if code == filter_value or code == filter_int: + match = True + break + if isinstance(code, str) and filter_int is not None: + try: + if int(code) == filter_int: + match = True + break + except ValueError: + pass + if not match: + continue + + summary = { + "filename": filename, + "disabled": bool(row[1]), + "error_codes": error_codes, + "last_success": row[3] or current_time, + "user_email": row[4], + "rotation_order": row[5], + "model_cooldowns": active_cooldowns, + "tier": row[8] if mode == "geminicli" and row[8] is not None else ( + row[7] if mode != "geminicli" and row[7] is not None else "pro" + ), + } + + if mode != "geminicli": + summary["enable_credit"] = bool(row[8]) if row[8] is not None else False + + if mode == "geminicli": + summary["preview"] = bool(row[7]) if row[7] is not None else True + + if preview_filter: + preview_value = summary.get("preview", True) + if preview_filter == "preview" and not preview_value: + continue + elif preview_filter == "no_preview" and preview_value: + continue + + # 应用tier筛选 + if tier_filter and tier_filter in ("free", "pro", "ultra"): + if summary["tier"] != tier_filter: + continue + + # 应用冷却筛选 + if cooldown_filter == "in_cooldown": + # 只保留有冷却的凭证 + if active_cooldowns: + all_summaries.append(summary) + elif cooldown_filter == "no_cooldown": + # 只保留没有冷却的凭证 + if not active_cooldowns: + all_summaries.append(summary) + else: + # 不筛选冷却状态 + all_summaries.append(summary) + + # 应用分页 + total_count = len(all_summaries) + if limit is not None: + summaries = all_summaries[offset:offset + limit] + else: + summaries = all_summaries[offset:] + + return { + "items": summaries, + "total": total_count, + "offset": offset, + "limit": limit, + "stats": global_stats, + } + + except Exception as e: + log.error(f"Error getting credentials summary: {e}") + return { + "items": [], + "total": 0, + "offset": offset, + "limit": limit, + "stats": {"total": 0, "normal": 0, "disabled": 0}, + } + + async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + """ + 获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据) + 用于去重操作 + + Args: + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 email_groups(邮箱分组)、duplicate_count(重复数量)、no_email_count(无邮箱数量)的字典 + """ + self._ensure_initialized() + + try: + # 根据 mode 选择表名 + table_name = self._get_table_name(mode) + + async with aiosqlite.connect(self._db_path) as db: + # 查询所有凭证的文件名和邮箱(不加载完整凭证数据) + query = f""" + SELECT filename, user_email + FROM {table_name} + ORDER BY filename + """ + + async with db.execute(query) as cursor: + rows = await cursor.fetchall() + + # 按邮箱分组 + email_to_files = {} + no_email_files = [] + + for filename, user_email in rows: + if user_email: + if user_email not in email_to_files: + email_to_files[user_email] = [] + email_to_files[user_email].append(filename) + else: + no_email_files.append(filename) + + # 找出重复的邮箱组 + duplicate_groups = [] + total_duplicate_count = 0 + + for email, files in email_to_files.items(): + if len(files) > 1: + # 保留第一个文件,其他为重复 + duplicate_groups.append({ + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + }) + total_duplicate_count += len(files) - 1 + + return { + "email_groups": email_to_files, + "duplicate_groups": duplicate_groups, + "duplicate_count": total_duplicate_count, + "no_email_files": no_email_files, + "no_email_count": len(no_email_files), + "unique_email_count": len(email_to_files), + "total_count": len(rows), + } + + except Exception as e: + log.error(f"Error getting duplicate credentials by email: {e}") + return { + "email_groups": {}, + "duplicate_groups": [], + "duplicate_count": 0, + "no_email_files": [], + "no_email_count": 0, + "unique_email_count": 0, + "total_count": 0, + } + + # ============ 配置管理(内存缓存)============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置(写入数据库 + 更新内存缓存)""" + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + await db.execute(""" + INSERT INTO config (key, value, updated_at) + VALUES (?, ?, unixepoch()) + ON CONFLICT(key) DO UPDATE SET + value = excluded.value, + updated_at = excluded.updated_at + """, (key, json.dumps(value))) + await db.commit() + + # 更新内存缓存 + self._config_cache[key] = value + return True + + except Exception as e: + log.error(f"Error setting config {key}: {e}") + return False + + async def reload_config_cache(self): + """重新加载配置缓存(在批量修改配置后调用)""" + self._ensure_initialized() + self._config_loaded = False + await self._load_config_cache() + log.info("Config cache reloaded from database") + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.get(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置(从内存缓存)""" + self._ensure_initialized() + return self._config_cache.copy() + + async def delete_config(self, key: str) -> bool: + """删除配置""" + self._ensure_initialized() + + try: + async with aiosqlite.connect(self._db_path) as db: + await db.execute("DELETE FROM config WHERE key = ?", (key,)) + await db.commit() + + # 从内存缓存移除 + self._config_cache.pop(key, None) + return True + + except Exception as e: + log.error(f"Error deleting config {key}: {e}") + return False + + async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """ + 专门获取凭证的错误信息(包含 error_codes 和 error_messages) + + Args: + filename: 凭证文件名 + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 包含 error_codes 和 error_messages 的字典 + """ + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 精确匹配 + async with db.execute(f""" + SELECT error_codes, error_messages FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if row: + error_codes_json = row[0] or '[]' + error_messages_json = row[1] or '[]' + return { + "filename": filename, + "error_codes": json.loads(error_codes_json), + "error_messages": json.loads(error_messages_json), + } + + # 凭证不存在,返回空错误信息 + return { + "filename": filename, + "error_codes": [], + "error_messages": [], + } + + except Exception as e: + log.error(f"Error getting credential errors {filename}: {e}") + return { + "filename": filename, + "error_codes": [], + "error_messages": [], + "error": str(e) + } + + # ============ 模型级冷却管理 ============ + + async def set_model_cooldown( + self, + filename: str, + model_name: str, + cooldown_until: Optional[float], + mode: str = "geminicli" + ) -> bool: + """ + 设置特定模型的冷却时间 + + Args: + filename: 凭证文件名 + model_name: 模型名(完整模型名,如 "gemini-2.0-flash-exp") + cooldown_until: 冷却截止时间戳(None 表示清除冷却) + mode: 凭证模式 ("geminicli" 或 "antigravity") + + Returns: + 是否成功 + """ + self._ensure_initialized() + + # 统一使用 basename 处理文件名 + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 获取当前的 model_cooldowns + async with db.execute(f""" + SELECT model_cooldowns FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + + if not row: + log.warning(f"Credential {filename} not found") + return False + + model_cooldowns = json.loads(row[0] or '{}') + + # 更新或删除指定模型的冷却时间 + if cooldown_until is None: + model_cooldowns.pop(model_name, None) + else: + model_cooldowns[model_name] = cooldown_until + + # 写回数据库 + await db.execute(f""" + UPDATE {table_name} + SET model_cooldowns = ?, + updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(model_cooldowns), filename)) + await db.commit() + + log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}") + return True + + except Exception as e: + log.error(f"Error setting model cooldown for {filename}: {e}") + return False + + async def clear_all_model_cooldowns( + self, + filename: str, + mode: str = "geminicli" + ) -> bool: + """清除某个凭证的所有模型冷却时间""" + self._ensure_initialized() + + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + result = await db.execute(f""" + UPDATE {table_name} + SET model_cooldowns = '{{}}', + updated_at = unixepoch() + WHERE filename = ? + """, (filename,)) + updated_count = result.rowcount + await db.commit() + + if updated_count == 0: + log.warning(f"Credential {filename} not found") + return False + + log.debug(f"Cleared all model cooldowns: {filename} (mode={mode})") + return True + + except Exception as e: + log.error(f"Error clearing all model cooldowns for {filename}: {e}") + return False + + async def record_success( + self, + filename: str, + model_name: Optional[str] = None, + mode: str = "geminicli" + ) -> None: + """ + 成功调用后的条件写入: + - 只有当前 error_codes 非空时才清除错误并写 last_success + - 只有当前存在该模型的冷却键时才清除 + 通过 SQL WHERE 条件匹配实现 + """ + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 条件写入:只有 error_codes 非空时才触发 + await db.execute(f""" + UPDATE {table_name} + SET last_success = unixepoch(), + error_codes = '[]', + error_messages = '{{}}', + updated_at = unixepoch() + WHERE filename = ? + AND (error_codes IS NOT NULL AND error_codes != '[]' AND error_codes != '') + """, (filename,)) + + # 条件删除模型冷却:只有模型键存在时才写入 + if model_name: + async with db.execute(f""" + SELECT model_cooldowns FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + if row: + cooldowns = json.loads(row[0] or '{}') + if model_name in cooldowns: + cooldowns.pop(model_name) + await db.execute(f""" + UPDATE {table_name} + SET model_cooldowns = ?, updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(cooldowns), filename)) + + await db.commit() + + except Exception as e: + log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file diff --git a/src/storage_adapter.py b/src/storage_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..134a0807bda837e51dc395004d0a5354b52af66e --- /dev/null +++ b/src/storage_adapter.py @@ -0,0 +1,343 @@ +""" +存储适配器,提供统一的接口来处理 SQLite 和 MongoDB 存储。 +根据配置自动选择存储后端: +- 默认使用 SQLite(本地文件存储) +- 如果设置了 MONGODB_URI 环境变量,则使用 MongoDB +""" + +import asyncio +import json +import os +from typing import Any, Dict, List, Optional, Protocol + +from log import log + + +class StorageBackend(Protocol): + """存储后端协议""" + + async def initialize(self) -> None: + """初始化存储后端""" + ... + + async def close(self) -> None: + """关闭存储后端""" + ... + + # 凭证管理 + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储凭证数据""" + ... + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + ... + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + ... + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + ... + + # 状态管理 + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + ... + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态""" + ... + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + ... + + # 配置管理 + async def set_config(self, key: str, value: Any) -> bool: + """设置配置项""" + ... + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置项""" + ... + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置""" + ... + + async def delete_config(self, key: str) -> bool: + """删除配置项""" + ... + + +class StorageAdapter: + """存储适配器,根据配置选择存储后端""" + + def __init__(self): + self._backend: Optional["StorageBackend"] = None + self._initialized = False + self._lock = asyncio.Lock() + + async def initialize(self) -> None: + """初始化存储适配器""" + async with self._lock: + if self._initialized: + return + + # 按优先级检查存储后端:PostgreSQL > MongoDB > SQLite + postgresql_uri = os.getenv("POSTGRESQL_URI", "") + mongodb_uri = os.getenv("MONGODB_URI", "") + + if postgresql_uri: + # 使用 PostgreSQL + try: + from .storage.psql_manager import PSQLManager + + self._backend = PSQLManager() + await self._backend.initialize() + log.info("Using PostgreSQL storage backend") + except Exception as e: + log.error(f"Failed to initialize PostgreSQL backend: {e}") + # 尝试降级到 SQLite + log.info("Falling back to SQLite storage backend") + try: + from .storage.sqlite_manager import SQLiteManager + + self._backend = SQLiteManager() + await self._backend.initialize() + log.info("Using SQLite storage backend (fallback)") + except Exception as e2: + log.error(f"Failed to initialize SQLite backend: {e2}") + raise RuntimeError("No storage backend available") from e2 + elif not mongodb_uri: + # 优先使用 SQLite(默认启用,无需环境变量) + try: + from .storage.sqlite_manager import SQLiteManager + + self._backend = SQLiteManager() + await self._backend.initialize() + log.info("Using SQLite storage backend") + except Exception as e: + log.error(f"Failed to initialize SQLite backend: {e}") + raise RuntimeError("No storage backend available") from e + else: + # 使用 MongoDB + try: + from .storage.mongodb_manager import MongoDBManager + + self._backend = MongoDBManager() + await self._backend.initialize() + log.info("Using MongoDB storage backend") + except Exception as e: + log.error(f"Failed to initialize MongoDB backend: {e}") + # 尝试降级到 SQLite + log.info("Falling back to SQLite storage backend") + try: + from .storage.sqlite_manager import SQLiteManager + + self._backend = SQLiteManager() + await self._backend.initialize() + log.info("Using SQLite storage backend (fallback)") + except Exception as e2: + log.error(f"Failed to initialize SQLite backend: {e2}") + raise RuntimeError("No storage backend available") from e2 + + self._initialized = True + + async def close(self) -> None: + """关闭存储适配器""" + if self._backend: + await self._backend.close() + self._backend = None + self._initialized = False + + def _ensure_initialized(self): + """确保存储适配器已初始化""" + if not self._initialized or not self._backend: + raise RuntimeError("Storage adapter not initialized") + + # ============ 凭证管理 ============ + + async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + """存储凭证数据""" + self._ensure_initialized() + return await self._backend.store_credential(filename, credential_data, mode) + + async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + """获取凭证数据""" + self._ensure_initialized() + return await self._backend.get_credential(filename, mode) + + async def list_credentials(self, mode: str = "geminicli") -> List[str]: + """列出所有凭证文件名""" + self._ensure_initialized() + return await self._backend.list_credentials(mode) + + async def delete_credential(self, filename: str, mode: str = "geminicli") -> bool: + """删除凭证""" + self._ensure_initialized() + return await self._backend.delete_credential(filename, mode) + + # ============ 状态管理 ============ + + async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + """更新凭证状态""" + self._ensure_initialized() + return await self._backend.update_credential_state(filename, state_updates, mode) + + async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + """获取凭证状态""" + self._ensure_initialized() + return await self._backend.get_credential_state(filename, mode) + + async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + """获取所有凭证状态""" + self._ensure_initialized() + return await self._backend.get_all_credential_states(mode) + + # ============ 配置管理 ============ + + async def set_config(self, key: str, value: Any) -> bool: + """设置配置项""" + self._ensure_initialized() + return await self._backend.set_config(key, value) + + async def get_config(self, key: str, default: Any = None) -> Any: + """获取配置项""" + self._ensure_initialized() + return await self._backend.get_config(key, default) + + async def get_all_config(self) -> Dict[str, Any]: + """获取所有配置""" + self._ensure_initialized() + return await self._backend.get_all_config() + + async def delete_config(self, key: str) -> bool: + """删除配置项""" + self._ensure_initialized() + return await self._backend.delete_config(key) + + # ============ 工具方法 ============ + + async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool: + """将凭证导出为JSON文件""" + self._ensure_initialized() + if hasattr(self._backend, "export_credential_to_json"): + return await self._backend.export_credential_to_json(filename, output_path) + # MongoDB后端的fallback实现 + credential_data = await self.get_credential(filename) + if credential_data is None: + return False + + if output_path is None: + output_path = f"{filename}.json" + + import aiofiles + + try: + async with aiofiles.open(output_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(credential_data, indent=2, ensure_ascii=False)) + return True + except Exception: + return False + + async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool: + """从JSON文件导入凭证""" + self._ensure_initialized() + if hasattr(self._backend, "import_credential_from_json"): + return await self._backend.import_credential_from_json(json_path, filename) + # MongoDB后端的fallback实现 + try: + import aiofiles + + async with aiofiles.open(json_path, "r", encoding="utf-8") as f: + content = await f.read() + + credential_data = json.loads(content) + + if filename is None: + filename = os.path.basename(json_path) + + return await self.store_credential(filename, credential_data) + except Exception: + return False + + def get_backend_type(self) -> str: + """获取当前存储后端类型""" + if not self._backend: + return "none" + + # 检查后端类型 + backend_class_name = self._backend.__class__.__name__ + if "SQLite" in backend_class_name or "sqlite" in backend_class_name.lower(): + return "sqlite" + elif "MongoDB" in backend_class_name or "mongo" in backend_class_name.lower(): + return "mongodb" + elif "PSQL" in backend_class_name or "Postgres" in backend_class_name or "psql" in backend_class_name.lower(): + return "postgresql" + else: + return "unknown" + + async def get_backend_info(self) -> Dict[str, Any]: + """获取存储后端信息""" + self._ensure_initialized() + + backend_type = self.get_backend_type() + info = {"backend_type": backend_type, "initialized": self._initialized} + + # 获取底层存储信息 + if hasattr(self._backend, "get_database_info"): + try: + db_info = await self._backend.get_database_info() + info.update(db_info) + except Exception as e: + info["database_error"] = str(e) + else: + backend_type = self.get_backend_type() + if backend_type == "sqlite": + info.update( + { + "database_path": getattr(self._backend, "_db_path", None), + "credentials_dir": getattr(self._backend, "_credentials_dir", None), + } + ) + elif backend_type == "mongodb": + info.update( + { + "database_name": getattr(self._backend, "_db", {}).name if hasattr(self._backend, "_db") else None, + } + ) + elif backend_type == "postgresql": + info.update( + { + "dsn": getattr(self._backend, "_dsn", None), + } + ) + + return info + + +# 全局存储适配器实例 +_storage_adapter: Optional[StorageAdapter] = None + + +async def get_storage_adapter() -> StorageAdapter: + """获取全局存储适配器实例""" + global _storage_adapter + + if _storage_adapter is None: + _storage_adapter = StorageAdapter() + await _storage_adapter.initialize() + + return _storage_adapter + + +async def close_storage_adapter(): + """关闭全局存储适配器""" + global _storage_adapter + + if _storage_adapter: + await _storage_adapter.close() + _storage_adapter = None diff --git a/src/task_manager.py b/src/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0343ab1918c8e871e706b7a63676c536fb42af --- /dev/null +++ b/src/task_manager.py @@ -0,0 +1,143 @@ +""" +Global task lifecycle management module +管理应用程序中所有异步任务的生命周期,确保正确清理 +""" + +import asyncio +import weakref +from typing import Any, Dict, Set + +from log import log + + +class TaskManager: + """全局异步任务管理器 - 单例模式""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._tasks: Set[asyncio.Task] = set() + self._resources: Set[Any] = set() # 需要关闭的资源 + self._shutdown_event = asyncio.Event() + self._initialized = True + log.debug("TaskManager initialized") + + def register_task(self, task: asyncio.Task, description: str = None) -> asyncio.Task: + """注册一个任务供生命周期管理""" + self._tasks.add(task) + task.add_done_callback(lambda t: self._tasks.discard(t)) + + if description: + task.set_name(description) + + log.debug(f"Registered task: {task.get_name() or 'unnamed'}") + return task + + def create_task(self, coro, *, name: str = None) -> asyncio.Task: + """创建并注册一个任务""" + task = asyncio.create_task(coro, name=name) + return self.register_task(task, name) + + def register_resource(self, resource: Any) -> Any: + """注册一个需要清理的资源(如HTTP客户端、文件句柄等)""" + # 使用弱引用避免循环引用 + self._resources.add(weakref.ref(resource)) + log.debug(f"Registered resource: {type(resource).__name__}") + return resource + + async def shutdown(self, timeout: float = 30.0): + """关闭所有任务和资源""" + log.info("TaskManager shutdown initiated") + + # 设置关闭标志 + self._shutdown_event.set() + + # 取消所有未完成的任务 + cancelled_count = 0 + for task in list(self._tasks): + if not task.done(): + task.cancel() + cancelled_count += 1 + + if cancelled_count > 0: + log.info(f"Cancelled {cancelled_count} pending tasks") + + # 等待所有任务完成(包括取消) + if self._tasks: + try: + await asyncio.wait_for( + asyncio.gather(*self._tasks, return_exceptions=True), timeout=timeout + ) + except asyncio.TimeoutError: + log.warning(f"Some tasks did not complete within {timeout}s timeout") + + # 清理资源 - 改进弱引用处理 + cleaned_resources = 0 + failed_resources = 0 + for resource_ref in list(self._resources): + resource = resource_ref() + if resource is not None: + try: + if hasattr(resource, "close"): + if asyncio.iscoroutinefunction(resource.close): + await resource.close() + else: + resource.close() + elif hasattr(resource, "aclose"): + await resource.aclose() + cleaned_resources += 1 + except Exception as e: + log.warning(f"Failed to close resource {type(resource).__name__}: {e}") + failed_resources += 1 + # 如果弱引用已失效,资源已经被自动回收,无需操作 + + if cleaned_resources > 0: + log.info(f"Cleaned up {cleaned_resources} resources") + if failed_resources > 0: + log.warning(f"Failed to clean {failed_resources} resources") + + self._tasks.clear() + self._resources.clear() + log.info("TaskManager shutdown completed") + + @property + def is_shutdown(self) -> bool: + """检查是否已经开始关闭""" + return self._shutdown_event.is_set() + + def get_stats(self) -> Dict[str, int]: + """获取任务管理统计信息""" + return { + "active_tasks": len(self._tasks), + "registered_resources": len(self._resources), + "is_shutdown": self.is_shutdown, + } + + +# 全局任务管理器实例 +task_manager = TaskManager() + + +def create_managed_task(coro, *, name: str = None) -> asyncio.Task: + """创建一个被管理的异步任务的便捷函数""" + return task_manager.create_task(coro, name=name) + + +def register_resource(resource: Any) -> Any: + """注册资源的便捷函数""" + return task_manager.register_resource(resource) + + +async def shutdown_all_tasks(timeout: float = 30.0): + """关闭所有任务的便捷函数""" + await task_manager.shutdown(timeout) diff --git a/src/token_estimator.py b/src/token_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ed82cf28ce0d5805f6d3debe0531e16eb3db5290 --- /dev/null +++ b/src/token_estimator.py @@ -0,0 +1,30 @@ +"""简单的 token 估算,不追求精确""" +from __future__ import annotations + +from typing import Any, Dict + + +def estimate_input_tokens(payload: Dict[str, Any]) -> int: + """粗略估算 token 数:字符数 / 4 + 图片固定值""" + total_chars = 0 + image_count = 0 + + # 统计所有文本字符 + def count_str(obj: Any) -> None: + nonlocal total_chars, image_count + if isinstance(obj, str): + total_chars += len(obj) + elif isinstance(obj, dict): + # 检测图片 + if obj.get("type") == "image" or "inlineData" in obj: + image_count += 1 + for v in obj.values(): + count_str(v) + elif isinstance(obj, list): + for item in obj: + count_str(item) + + count_str(payload) + + # 粗略估算:字符数/4 + 每张图片300 tokens + return max(1, total_chars // 4 + image_count * 300) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d81b0643368b8da59f1ff2374f5cb97d2e2714 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,284 @@ +from typing import List, Optional + +from config import get_api_password, get_panel_password +from fastapi import Depends, HTTPException, Header, Query, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from log import log + +# HTTP Bearer security scheme +security = HTTPBearer() + +# ====================== OAuth Configuration ====================== + +_GEMINICLI_VERSION = "0.35.2" +_GEMINICLI_PLATFORM = "win32" +_GEMINICLI_ARCH = "x64" +_GEMINICLI_SURFACE = "cloud-shell" + +def get_geminicli_user_agent(model: str = "") -> str: + """生成动态 User-Agent: GeminiCLI/{version}/{model} ({platform}; {arch}; {surface})""" + if model: + return f"GeminiCLI/{_GEMINICLI_VERSION}/{model} ({_GEMINICLI_PLATFORM}; {_GEMINICLI_ARCH}; {_GEMINICLI_SURFACE})" + return f"GeminiCLI/{_GEMINICLI_VERSION} ({_GEMINICLI_PLATFORM}; {_GEMINICLI_ARCH}; {_GEMINICLI_SURFACE})" + +# 静态常量 +GEMINICLI_USER_AGENT = get_geminicli_user_agent() + +ANTIGRAVITY_USER_AGENT = "antigravity/1.22.2 windows/amd64" + +# OAuth Configuration - 标准模式 +CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" +CLIENT_SECRET = "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" +SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", +] + +# Antigravity OAuth Configuration +ANTIGRAVITY_CLIENT_ID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" +ANTIGRAVITY_CLIENT_SECRET = "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" +ANTIGRAVITY_SCOPES = [ + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/userinfo.profile', + 'https://www.googleapis.com/auth/cclog', + 'https://www.googleapis.com/auth/experimentsandconfigs' +] + +# 统一的 Token URL(两种模式相同) +TOKEN_URL = "https://oauth2.googleapis.com/token" + +# 回调服务器配置 +CALLBACK_HOST = "localhost" + +# Model name lists for different features +BASE_MODELS = [ + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-flash-preview", + "gemini-3.1-pro-preview" +] + + +# ====================== Model Helper Functions ====================== + +def is_fake_streaming_model(model_name: str) -> bool: + """Check if model name indicates fake streaming should be used.""" + return model_name.startswith("假流式/") + + +def is_anti_truncation_model(model_name: str) -> bool: + """Check if model name indicates anti-truncation should be used.""" + return model_name.startswith("流式抗截断/") + + +def get_base_model_from_feature_model(model_name: str) -> str: + """Get base model name from feature model name.""" + # Remove feature prefixes + for prefix in ["假流式/", "流式抗截断/"]: + if model_name.startswith(prefix): + return model_name[len(prefix) :] + return model_name + + +def get_available_models(router_type: str = "openai") -> List[str]: + """ + Get available models with feature prefixes. + + Args: + router_type: "openai" or "gemini" + + Returns: + List of model names with feature prefixes + """ + models = [] + + for base_model in BASE_MODELS: + # 基础模型 + models.append(base_model) + + # 假流式模型 (前缀格式) + models.append(f"假流式/{base_model}") + + # 流式抗截断模型 (仅在流式传输时有效,前缀格式) + models.append(f"流式抗截断/{base_model}") + + # 定义思考后缀(根据模型系列不同) + thinking_suffixes = [] + + # Gemini 2.5 系列: 使用思考预算后缀 + if "gemini-2.5" in base_model: + thinking_suffixes = ["-max", "-high", "-medium", "-low", "-minimal"] + # Gemini 3 系列: 使用思考等级后缀 + elif "gemini-3" in base_model: + if "flash" in base_model: + # 3-flash-preview: 支持 high/medium/low/minimal + thinking_suffixes = ["-high", "-medium", "-low", "-minimal"] + elif "pro" in base_model: + # 3-pro-preview: 支持 high/low + thinking_suffixes = ["-low"] + + search_suffix = "-search" + + # 1. 单独的 thinking 后缀 + for thinking_suffix in thinking_suffixes: + models.append(f"{base_model}{thinking_suffix}") + models.append(f"假流式/{base_model}{thinking_suffix}") + models.append(f"流式抗截断/{base_model}{thinking_suffix}") + + # 2. 单独的 search 后缀 + models.append(f"{base_model}{search_suffix}") + models.append(f"假流式/{base_model}{search_suffix}") + models.append(f"流式抗截断/{base_model}{search_suffix}") + + # 3. thinking + search 组合后缀 + for thinking_suffix in thinking_suffixes: + combined_suffix = f"{thinking_suffix}{search_suffix}" + models.append(f"{base_model}{combined_suffix}") + models.append(f"假流式/{base_model}{combined_suffix}") + models.append(f"流式抗截断/{base_model}{combined_suffix}") + + return models + + +# ====================== Authentication Functions ====================== + +async def authenticate_flexible( + request: Request, + authorization: Optional[str] = Header(None), + x_api_key: Optional[str] = Header(None, alias="x-api-key"), + access_token: Optional[str] = Header(None, alias="access_token"), + x_goog_api_key: Optional[str] = Header(None, alias="x-goog-api-key"), + x_anthropic_auth_token: Optional[str] = Header(None, alias="x-anthropic-auth-token"), + anthropic_auth_token: Optional[str] = Header(None, alias="anthropic-auth-token"), + key: Optional[str] = Query(None) +) -> str: + """ + 统一的灵活认证函数,支持多种认证方式 + + 此函数可以直接用作 FastAPI 的 Depends 依赖 + + 支持的认证方式: + - URL 参数: key + - HTTP 头部: Authorization (Bearer token) + - HTTP 头部: x-api-key + - HTTP 头部: access_token + - HTTP 头部: x-goog-api-key + - HTTP 头部: x-anthropic-auth-token + - HTTP 头部: anthropic-auth-token + + Args: + request: FastAPI Request 对象 + authorization: Authorization 头部值(自动注入) + x_api_key: x-api-key 头部值(自动注入) + access_token: access_token 头部值(自动注入) + x_goog_api_key: x-goog-api-key 头部值(自动注入) + x_anthropic_auth_token: x-anthropic-auth-token 头部值(自动注入) + anthropic_auth_token: anthropic-auth-token 头部值(自动注入) + key: URL 参数 key(自动注入) + + Returns: + 验证通过的token + + Raises: + HTTPException: 认证失败时抛出异常 + + 使用示例: + @router.post("/endpoint") + async def endpoint(token: str = Depends(authenticate_flexible)): + # token 已验证通过 + pass + """ + password = await get_api_password() + token = None + auth_method = None + + # 1. 尝试从 URL 参数 key 获取(Google 官方标准方式) + if key: + token = key + auth_method = "URL parameter 'key'" + + # 2. 尝试从 x-goog-api-key 头部获取(Google API 标准方式) + elif x_goog_api_key: + token = x_goog_api_key + auth_method = "x-goog-api-key header" + + # 3. 尝试从 x-anthropic-auth-token 头部获取(Anthropic 标准方式) + elif x_anthropic_auth_token: + token = x_anthropic_auth_token + auth_method = "x-anthropic-auth-token header" + + # 4. 尝试从 anthropic-auth-token 头部获取(Anthropic 替代方式) + elif anthropic_auth_token: + token = anthropic_auth_token + auth_method = "anthropic-auth-token header" + + # 5. 尝试从 x-api-key 头部获取 + elif x_api_key: + token = x_api_key + auth_method = "x-api-key header" + + # 6. 尝试从 access_token 头部获取 + elif access_token: + token = access_token + auth_method = "access_token header" + + # 7. 尝试从 Authorization 头部获取 + elif authorization: + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication scheme. Use 'Bearer '", + headers={"WWW-Authenticate": "Bearer"}, + ) + token = authorization[7:] # 移除 "Bearer " 前缀 + auth_method = "Authorization Bearer header" + + # 检查是否提供了任何认证凭据 + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication credentials. Use 'key' URL parameter, 'x-goog-api-key', 'x-anthropic-auth-token', 'anthropic-auth-token', 'x-api-key', 'access_token' header, or 'Authorization: Bearer '", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 验证 token + if token != password: + log.debug(f"Authentication failed using {auth_method}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="密码错误" + ) + + log.debug(f"Authentication successful using {auth_method}") + return token + + +# 为了保持向后兼容,保留旧函数名作为别名 +authenticate_bearer = authenticate_flexible +authenticate_gemini_flexible = authenticate_flexible + + +# ====================== Panel Authentication Functions ====================== + +async def verify_panel_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str: + """ + 简化的控制面板密码验证函数 + + 直接验证Bearer token是否等于控制面板密码 + + Args: + credentials: HTTPAuthorizationCredentials 自动注入 + + Returns: + 验证通过的token + + Raises: + HTTPException: 密码错误时抛出401异常 + """ + + password = await get_panel_password() + if credentials.credentials != password: + raise HTTPException(status_code=401, detail="密码错误") + return credentials.credentials diff --git a/start.bat b/start.bat new file mode 100644 index 0000000000000000000000000000000000000000..adb1ee9f9bf1c9de1e34c61cb5426ba63bccf586 --- /dev/null +++ b/start.bat @@ -0,0 +1,7 @@ +git fetch --all +for /f "delims=" %%b in ('git rev-parse --abbrev-ref HEAD') do set branch=%%b +git reset --hard origin/%branch% +uv sync +call .venv\Scripts\activate.bat +python web.py +pause \ No newline at end of file diff --git a/start.sh b/start.sh new file mode 100644 index 0000000000000000000000000000000000000000..1b4a2f05f0941acf257ab84d15e88506544ae1e1 --- /dev/null +++ b/start.sh @@ -0,0 +1,6 @@ +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) +uv sync +source .venv/bin/activate +python web.py \ No newline at end of file diff --git a/termux-install.sh b/termux-install.sh new file mode 100644 index 0000000000000000000000000000000000000000..579607761ec8f63658c6d55530268bd484adec38 --- /dev/null +++ b/termux-install.sh @@ -0,0 +1,169 @@ +#!/bin/bash + +# 避免交互式提示 +export DEBIAN_FRONTEND=noninteractive + +if [ "$(whoami)" = "root" ]; then + echo "检测到root用户,正在退出..." + exit +fi + +echo "检查Termux镜像源配置..." + +# 检查当前镜像源是否已经是Cloudflare镜像 +target_mirror="https://packages-cf.termux.dev/apt/termux-main" +fallback_mirror="https://packages.termux.dev/apt/termux-main" +if [ -f "$PREFIX/etc/apt/sources.list" ] && grep -q "$target_mirror" "$PREFIX/etc/apt/sources.list"; then + echo "✅ 镜像源已经配置为Cloudflare镜像,跳过修改" +else + echo "正在设置Termux镜像为Cloudflare镜像..." + + # 备份原始sources.list文件 + if [ -f "$PREFIX/etc/apt/sources.list" ]; then + echo "备份原始sources.list文件..." + cp "$PREFIX/etc/apt/sources.list" "$PREFIX/etc/apt/sources.list.backup.$(date +%s)" + fi + + # 写入新的镜像源 + echo "写入新的镜像源配置..." + cat > "$PREFIX/etc/apt/sources.list" << 'EOF' +# Cloudflare镜像源 +deb https://packages-cf.termux.dev/apt/termux-main stable main +EOF + + echo "✅ 镜像源已更新为: $target_mirror" +fi + +ensure_dpkg_ready() { + echo "检查并修复 dpkg/apt 状态..." + # 等待可能存在的 apt/dpkg 进程结束 + if pgrep -f "apt|dpkg" >/dev/null 2>&1; then + echo "检测到 apt/dpkg 正在运行,等待其结束..." + while pgrep -f "apt|dpkg" >/dev/null 2>&1; do sleep 1; done + fi + # 清理可能残留的锁(若无进程) + for f in "$PREFIX/var/lib/dpkg/lock" \ + "$PREFIX/var/lib/apt/lists/lock" \ + "$PREFIX/var/cache/apt/archives/lock"; do + [ -e "$f" ] && rm -f "$f" + done + # 尝试继续未完成的配置 + dpkg --configure -a || true +} + +# 更新包列表并检查错误 +echo "正在更新包列表..." +ensure_dpkg_ready +apt_output=$(apt update 2>&1) +if [ $? -ne 0 ]; then + if echo "$apt_output" | grep -qi "is not signed"; then + echo "⚠️ 检测到仓库未签名,尝试切换到官方镜像并修复 keyring..." + # 切换到官方镜像 + sed -i "s#${target_mirror}#${fallback_mirror}#g" "$PREFIX/etc/apt/sources.list" || true + # 清理列表与锁 + rm -rf "$PREFIX/var/lib/apt/lists/"* || true + rm -f "$PREFIX/var/lib/dpkg/lock" "$PREFIX/var/lib/apt/lists/lock" "$PREFIX/var/cache/apt/archives/lock" || true + # 重新安装 termux-keyring(若已安装则强制重装) + apt-get install --reinstall -y termux-keyring || true + # 再次更新 + ensure_dpkg_ready + apt update + else + echo "apt update 失败,错误信息:" + echo "$apt_output" | head -20 + exit 1 + fi +else + echo "$apt_output" +fi + +echo "✅ Termux镜像设置完成!" +echo "📁 原始配置已备份到: $PREFIX/etc/apt/sources.list.backup.*" +echo "🔄 如需恢复原始镜像,可以运行:" +echo " cp \$PREFIX/etc/apt/sources.list.backup.* \$PREFIX/etc/apt/sources.list && apt update" + +# 检查是否需要更新包管理器和安装软件 +need_update=false +packages_to_install="" + +# 检查 uv 是否已安装 +if ! command -v uv &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install uv" +fi + +# 检查 python 是否已安装 +if ! command -v python &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install python" +fi + +# 检查 nodejs 是否已安装 +if ! command -v node &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install nodejs" +fi + +# 检查 git 是否已安装 +if ! command -v git &> /dev/null; then + need_update=true + packages_to_install="$packages_to_install git" +fi + +# 如果需要安装软件,则更新包管理器并安装 +if [ "$need_update" = true ]; then + echo "正在更新包管理器..." + ensure_dpkg_ready + echo "正在安装缺失的软件包: $packages_to_install" + pkg install $packages_to_install -y +else + echo "所需软件包已全部安装,跳过更新和安装步骤" +fi + +# 检查 pm2 是否已安装 +if ! command -v pm2 &> /dev/null; then + echo "正在安装 pm2..." + npm install pm2 -g +else + echo "pm2 已安装,跳过安装" +fi + +# 项目目录处理逻辑 +if [ -f "./web.py" ]; then + # Already in target directory; skip clone and cd + echo "已在目标目录中,跳过克隆操作" +elif [ -f "./gcli2api/web.py" ]; then + echo "进入已存在的 gcli2api 目录" + cd ./gcli2api +else + echo "克隆项目仓库..." + git clone https://github.com/su-kaka/gcli2api.git + cd ./gcli2api +fi + +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) + +# 只在不存在时创建 +if [ ! -d ".venv" ]; then + echo "创建虚拟环境..." + PYTHON_VERSION=$(python -V 2>&1 | grep -oP '\d+\.\d+' | head -1) + if [ -n "$PYTHON_VERSION" ]; then + echo "检测到 Python $PYTHON_VERSION,固定版本..." + uv python pin "$PYTHON_VERSION" + fi + rm pyproject.toml + uv init + uv venv +else + echo "虚拟环境已存在,跳过创建" +fi + +echo "安装 Python 依赖..." +uv add -r requirements-termux.txt + +echo "激活虚拟环境并启动服务..." +source .venv/bin/activate +pm2 start .venv/bin/python --name web -- web.py +cd .. \ No newline at end of file diff --git a/termux-start.sh b/termux-start.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed22db2fd4a65df6c0905cfb533eab698f616c74 --- /dev/null +++ b/termux-start.sh @@ -0,0 +1,12 @@ +echo "强制同步项目代码,忽略本地修改..." +git fetch --all +git reset --hard origin/$(git rev-parse --abbrev-ref HEAD) +echo "创建虚拟环境..." +PYTHON_VERSION=$(python -V 2>&1 | grep -oP '\d+\.\d+' | head -1) +if [ -n "$PYTHON_VERSION" ]; then + echo "检测到 Python $PYTHON_VERSION,固定版本..." + uv python pin "$PYTHON_VERSION" +fi +uv add -r requirements-termux.txt +source .venv/bin/activate +pm2 start .venv/bin/python --name gcli2api -- web.py diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cdc7b95f596f641262afe44c5a163de73b593fd --- /dev/null +++ b/version.txt @@ -0,0 +1,4 @@ +full_hash=a89888e697a2ee75213c68c3dd6297b3341d101a +short_hash=a89888e +message=优化流式抗截断 +date=2026-04-12 00:26:45 +0800 diff --git a/web.py b/web.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8a68b5738b4e3b725dafe5dedaec00803a81a5 --- /dev/null +++ b/web.py @@ -0,0 +1,216 @@ +""" +Main Web Integration - Integrates all routers and modules +集合router并开启主服务 +""" + +import asyncio +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +from config import get_server_host, get_server_port +from log import log + +# Import managers and utilities +from src.credential_manager import credential_manager + +# Import all routers +from src.router.antigravity.openai import router as antigravity_openai_router +from src.router.antigravity.gemini import router as antigravity_gemini_router +from src.router.antigravity.anthropic import router as antigravity_anthropic_router +from src.router.antigravity.model_list import router as antigravity_model_list_router +from src.router.geminicli.openai import router as geminicli_openai_router +from src.router.geminicli.gemini import router as geminicli_gemini_router +from src.router.geminicli.anthropic import router as geminicli_anthropic_router +from src.router.geminicli.model_list import router as geminicli_model_list_router +from src.task_manager import shutdown_all_tasks +from src.panel import router as panel_router +from src.keeplive import keepalive_service + +# 全局凭证管理器 +global_credential_manager = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + global global_credential_manager + + log.info("启动 GCLI2API 主服务") + + # 初始化配置缓存(优先执行) + try: + import config + await config.init_config() + log.info("配置缓存初始化成功") + except Exception as e: + log.error(f"配置缓存初始化失败: {e}") + + # 初始化全局凭证管理器(通过单例工厂) + try: + # credential_manager 会在第一次调用时自动初始化 + # 这里预先触发初始化以便在启动时检测错误 + await credential_manager._get_or_create() + log.info("凭证管理器初始化成功") + except Exception as e: + log.error(f"凭证管理器初始化失败: {e}") + global_credential_manager = None + + # OAuth回调服务器将在需要时按需启动 + + # 启动保活服务(未配置URL时自动跳过,零开销) + try: + await keepalive_service.start() + except Exception as e: + log.error(f"保活服务启动失败: {e}") + + yield + + # 清理资源 + log.info("开始关闭 GCLI2API 主服务") + + # 停止保活服务 + try: + await keepalive_service.stop() + except Exception as e: + log.error(f"关闭保活服务时出错: {e}") + + # 首先关闭所有异步任务 + try: + await shutdown_all_tasks(timeout=10.0) + log.info("所有异步任务已关闭") + except Exception as e: + log.error(f"关闭异步任务时出错: {e}") + + # 然后关闭凭证管理器 + if global_credential_manager: + try: + await global_credential_manager.close() + log.info("凭证管理器已关闭") + except Exception as e: + log.error(f"关闭凭证管理器时出错: {e}") + + log.info("GCLI2API 主服务已停止") + + +# 创建FastAPI应用 +app = FastAPI( + title="GCLI2API", + description="Gemini API proxy with OpenAI compatibility", + version="2.0.0", + lifespan=lifespan, +) + +# CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 挂载路由器 +# OpenAI兼容路由 - 处理OpenAI格式请求 +app.include_router(geminicli_openai_router, prefix="", tags=["Geminicli OpenAI API"]) + +# Gemini原生路由 - 处理Gemini格式请求 +app.include_router(geminicli_gemini_router, prefix="", tags=["Geminicli Gemini API"]) + +# Geminicli模型列表路由 - 处理Gemini格式的模型列表请求 +app.include_router(geminicli_model_list_router, prefix="", tags=["Geminicli Model List"]) + +# Antigravity路由 - 处理OpenAI格式请求并转换为Antigravity API +app.include_router(antigravity_openai_router, prefix="", tags=["Antigravity OpenAI API"]) + +# Antigravity路由 - 处理Gemini格式请求并转换为Antigravity API +app.include_router(antigravity_gemini_router, prefix="", tags=["Antigravity Gemini API"]) + +# Antigravity模型列表路由 - 处理Gemini格式的模型列表请求 +app.include_router(antigravity_model_list_router, prefix="", tags=["Antigravity Model List"]) + +# Antigravity Anthropic Messages 路由 - Anthropic Messages 格式兼容 +app.include_router(antigravity_anthropic_router, prefix="", tags=["Antigravity Anthropic Messages"]) + +# Geminicli Anthropic Messages 路由 - Anthropic Messages 格式兼容 (Geminicli) +app.include_router(geminicli_anthropic_router, prefix="", tags=["Geminicli Anthropic Messages"]) + +# Panel路由 - 包含认证、凭证管理和控制面板功能 +app.include_router(panel_router, prefix="", tags=["Panel Interface"]) + +# 静态文件路由 - 服务docs目录下的文件 +app.mount("/docs", StaticFiles(directory="docs"), name="docs") + +# 静态文件路由 - 服务front目录下的文件(HTML、JS、CSS等) +app.mount("/front", StaticFiles(directory="front"), name="front") + + +# 保活接口(仅响应 HEAD) +@app.head("/keepalive") +async def keepalive() -> Response: + return Response(status_code=200) + +def main(): + """主启动函数""" + from hypercorn.asyncio import serve + from hypercorn.config import Config + from hypercorn.run import run + + workers = int(os.environ.get("WORKERS", 1)) + + async def _run(): + port = await get_server_port() + host = await get_server_host() + + log.info("=" * 60) + log.info("启动 GCLI2API") + log.info("=" * 60) + log.info(f"控制面板: http://127.0.0.1:{port}") + if workers > 1: + log.info(f"Worker 数量: {workers}") + log.info("=" * 60) + + config = Config() + config.bind = [f"{host}:{port}"] + config.accesslog = "-" + config.errorlog = "-" + config.loglevel = "INFO" + + # 设置连接超时 + config.keep_alive_timeout = 900 + config.read_timeout = 900 + + await serve(app, config) + + if workers == 1: + asyncio.run(_run()) + else: + # 多 worker 模式下 hypercorn run 自行管理进程,先同步获取配置 + port = int(os.environ.get("PORT", 7861)) + host = os.environ.get("HOST", "0.0.0.0") + + log.info("=" * 60) + log.info("启动 GCLI2API") + log.info("=" * 60) + log.info(f"控制面板: http://127.0.0.1:{port}") + log.info(f"Worker 数量: {workers}") + log.info("=" * 60) + + config = Config() + config.bind = [f"{host}:{port}"] + config.accesslog = "-" + config.errorlog = "-" + config.loglevel = "INFO" + config.workers = workers + config.application_path = "web:app" + config.keep_alive_timeout = 900 + config.read_timeout = 900 + + run(config) + + +if __name__ == "__main__": + main() diff --git a/zeabur.yaml b/zeabur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..961526ec3e86c4bb84379735e30a67d429c02362 --- /dev/null +++ b/zeabur.yaml @@ -0,0 +1,50 @@ +apiVersion: zeabur.com/v1 +kind: Template +metadata: + name: gcli2api +spec: + description: "将 GeminiCLI 转换为 OpenAI 和 GEMINI API 接口" + tags: + - ai + - api + - gemini + - openai + variables: + - key: PASSWORD + type: STRING + name: API密码 + description: 用于访问API和控制面板的密码 + - key: DOMAIN + type: DOMAIN + name: 域名 + description: 服务访问域名 + readme: | + # GeminiCLI to API + + 将 GeminiCLI 转换为 OpenAI 和 GEMINI API 接口 + + 部署后请访问您的域名进行配置。 + services: + - name: gcli2api + template: PREBUILT + spec: + id: gcli2api + name: gcli2api + source: + image: ghcr.io/su-kaka/gcli2api:latest + ports: + - id: web + port: 7861 + type: HTTP + env: + PASSWORD: + default: ${PASSWORD} + PORT: + default: "7861" + HOST: + default: "0.0.0.0" + DOMAIN: + default: ${DOMAIN} + volumes: + - id: creds + dir: /app/creds \ No newline at end of file