diff --git a/KiroProxy/.github/workflows/build.yml b/KiroProxy/.github/workflows/build.yml new file mode 100644 index 0000000000000000000000000000000000000000..5ac4d562ae5ba6a08b240e90da2a15dc3607db40 --- /dev/null +++ b/KiroProxy/.github/workflows/build.yml @@ -0,0 +1,245 @@ +name: Build Release + +on: + push: + tags: + - 'v*' + workflow_dispatch: + +permissions: + contents: write + +env: + APP_NAME: KiroProxy + +jobs: + build-linux: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get version from tag + id: version + run: | + if [[ "${{ github.ref }}" == refs/tags/* ]]; then + VERSION=${GITHUB_REF#refs/tags/v} + else + VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py) + fi + echo "VERSION=$VERSION" >> $GITHUB_OUTPUT + echo "Version: $VERSION" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pyinstaller + + - name: Build binary + run: python build.py + + - name: Install packaging tools + run: | + sudo apt-get update + sudo apt-get install -y ruby ruby-dev rubygems build-essential rpm libfuse2 + sudo gem install --no-document fpm + + - name: Create packages + run: | + mkdir -p release + VERSION=${{ steps.version.outputs.VERSION }} + + # Binary (standalone) + cp dist/KiroProxy release/KiroProxy-${VERSION}-linux-x86_64 + chmod +x release/KiroProxy-${VERSION}-linux-x86_64 + + # tar.gz + tar -czvf release/KiroProxy-${VERSION}-linux-x86_64.tar.gz -C dist KiroProxy + + # deb package + fpm -s dir -t deb \ + -n kiroproxy \ + -v ${VERSION} \ + --description "Kiro API Proxy Server" \ + --license "MIT" \ + --architecture amd64 \ + --maintainer "petehsu" \ + --url "https://github.com/petehsu/KiroProxy" \ + -p release/kiroproxy_${VERSION}_amd64.deb \ + dist/KiroProxy=/usr/local/bin/KiroProxy + + # rpm package + fpm -s dir -t rpm \ + -n kiroproxy \ + -v ${VERSION} \ + --description "Kiro API Proxy Server" \ + --license "MIT" \ + --architecture x86_64 \ + --maintainer "petehsu" \ + --url "https://github.com/petehsu/KiroProxy" \ + -p release/kiroproxy-${VERSION}-1.x86_64.rpm \ + dist/KiroProxy=/usr/local/bin/KiroProxy + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: KiroProxy-Linux + path: release/* + + build-windows: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + + - name: Get version from tag + id: version + shell: bash + run: | + if [[ "${{ github.ref }}" == refs/tags/* ]]; then + VERSION=${GITHUB_REF#refs/tags/v} + else + VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py) + fi + echo "VERSION=$VERSION" >> $GITHUB_OUTPUT + echo "Version: $VERSION" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pyinstaller + + - name: Build + run: python build.py + + - name: Create packages + shell: pwsh + run: | + $VERSION = "${{ steps.version.outputs.VERSION }}" + New-Item -ItemType Directory -Force -Path release + + # exe (standalone) + Copy-Item dist/KiroProxy.exe release/KiroProxy-${VERSION}-windows-x86_64.exe + + # zip + Compress-Archive -Path dist/KiroProxy.exe -DestinationPath release/KiroProxy-${VERSION}-windows-x86_64.zip + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: KiroProxy-Windows + path: release/* + + build-macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Get version from tag + id: version + run: | + if [[ "${{ github.ref }}" == refs/tags/* ]]; then + VERSION=${GITHUB_REF#refs/tags/v} + else + VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py || echo "1.0.0") + fi + echo "VERSION=$VERSION" >> $GITHUB_OUTPUT + echo "Version: $VERSION" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pyinstaller + + - name: Generate icon + run: | + mkdir -p assets/icon.iconset + for size in 16 32 64 128 256 512; do + sips -z $size $size assets/icon.png --out assets/icon.iconset/icon_${size}x${size}.png + done + iconutil -c icns assets/icon.iconset -o assets/icon.icns + + - name: Build + run: python build.py + + - name: Create packages + run: | + VERSION=${{ steps.version.outputs.VERSION }} + mkdir -p release + + # Binary (standalone) + cp dist/KiroProxy release/KiroProxy-${VERSION}-macos-x86_64 + chmod +x release/KiroProxy-${VERSION}-macos-x86_64 + + # zip + cd dist && zip -r ../release/KiroProxy-${VERSION}-macos-x86_64.zip KiroProxy && cd .. + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: KiroProxy-macOS + path: release/* + + release: + needs: [build-linux, build-windows, build-macos] + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + + steps: + - uses: actions/checkout@v4 + + - name: Get version from tag + id: version + run: | + VERSION=${GITHUB_REF#refs/tags/v} + echo "VERSION=$VERSION" >> $GITHUB_OUTPUT + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + + - name: List artifacts + run: find artifacts -type f + + - name: Create Release + uses: softprops/action-gh-release@v1 + with: + name: KiroProxy v${{ steps.version.outputs.VERSION }} + body: | + ## Downloads + + | Platform | File | Description | + |----------|------|-------------| + | **Linux** | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64` | Standalone binary | + | | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64.tar.gz` | Compressed archive | + | | `kiroproxy_${{ steps.version.outputs.VERSION }}_amd64.deb` | Debian/Ubuntu package | + | | `kiroproxy-${{ steps.version.outputs.VERSION }}-1.x86_64.rpm` | Fedora/RHEL/CentOS package | + | **Windows** | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.exe` | Standalone executable | + | | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.zip` | Compressed archive | + | **macOS** | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64` | Standalone binary | + | | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64.zip` | Compressed archive | + files: | + artifacts/KiroProxy-Linux/* + artifacts/KiroProxy-Windows/* + artifacts/KiroProxy-macOS/* + draft: false + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/KiroProxy/.gitignore b/KiroProxy/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ad5a9bd24fddf7691f3491945d804bff8cd2cd4a --- /dev/null +++ b/KiroProxy/.gitignore @@ -0,0 +1,54 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +venv/ +.venv/ +*.egg-info/ +.hypothesis/ +.pytest_cache/ + +# Build +build/ +dist/ +release/ +*.spec + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# HAR files (contain sensitive data) +*.har + +# Logs +*.log + +# Test files +[0-9].txt +[0-9][0-9].txt +线索*.txt + +# Temp analysis files +flows +flows_* +traffic.mitm +*.mitm +analyze_har.py +parse_*.py +*_analysis.txt +*_check.txt +hex_dump.txt +parsed_*.txt +response.txt +参考.txt + +# Other projects +Antigravity-Manager/ +cc-switch/ diff --git a/KiroProxy/CAPTURE_GUIDE.md b/KiroProxy/CAPTURE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/KiroProxy/README.md b/KiroProxy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fec684e9aa884456fbd4b21f2e83fb5c9b734548 --- /dev/null +++ b/KiroProxy/README.md @@ -0,0 +1,423 @@ +

+ Kiro Proxy +

+ +

Kiro API Proxy

+ +

+ Kiro IDE API 反向代理服务器,支持多账号轮询、Token 自动刷新、配额管理 +

+ +

+ 功能 • + 快速开始 • + CLI 配置 • + API • + 许可证 +

+ +--- + +> **⚠️ 测试说明** +> +> 本项目支持 **Claude Code**、**Codex CLI**、**Gemini CLI** 三种客户端,工具调用功能已全面支持。 + +## 功能特性 + +### 核心功能 +- **多协议支持** - OpenAI / Anthropic / Gemini 三种协议兼容 +- **完整工具调用** - 三种协议的工具调用功能全面支持 +- **图片理解** - 支持 Claude Code / Codex CLI 图片输入 +- **网络搜索** - 支持 Claude Code / Codex CLI 网络搜索工具 +- **思考功能** - 支持 Claude 的扩展思考功能(Extended Thinking) +- **多账号轮询(默认随机)** - 每次请求随机切换账号,分散压力,避免单账号 RPM 过高 +- **会话粘性(可选)** - 非 `random` 策略下,同一会话 60 秒内使用同一账号,保持上下文 +- **Web UI** - 简洁的管理界面,支持监控、日志、设置 + +### v1.7.1 新功能 +- **Windows 支持补强** - 注册表浏览器检测 + PATH 回退,兼容便携版 +- **打包资源修复** - PyInstaller 打包后可正常加载图标与内置文档 +- **Token 扫描稳定性** - Windows 路径编码处理修复 + +### v1.6.3 新功能 +- **命令行工具 (CLI)** - 无 GUI 服务器也能轻松管理 + - `python run.py accounts list` - 列出账号 + - `python run.py accounts export/import` - 导出/导入账号 + - `python run.py accounts add` - 交互式添加 Token + - `python run.py accounts scan` - 扫描本地 Token + - `python run.py login google/github` - 命令行登录 + - `python run.py login remote` - 生成远程登录链接 +- **远程登录链接** - 在有浏览器的机器上完成授权,Token 自动同步 +- **账号导入导出** - 跨机器迁移账号配置 +- **手动添加 Token** - 直接粘贴 accessToken/refreshToken + +### v1.6.2 新功能 +- **Codex CLI 完整支持** - 使用 OpenAI Responses API (`/v1/responses`) + - 完整工具调用支持(shell、file 等所有工具) + - 图片输入支持(`input_image` 类型) + - 网络搜索支持(`web_search` 工具) + - 错误代码映射(rate_limit、context_length 等) +- **Claude Code 增强** - 图片理解和网络搜索完整支持 + - 支持 Anthropic 和 OpenAI 两种图片格式 + - 支持 `web_search` / `web_search_20250305` 工具 + +### v1.6.1 新功能 +- **请求限速** - 通过限制请求频率降低账号封禁风险 + - 每账号最小请求间隔 + - 每账号每分钟最大请求数 + - 全局每分钟最大请求数 + - WebUI 设置页面可配置 +- **账号封禁检测** - 自动检测 TEMPORARILY_SUSPENDED 错误 + - 友好的错误日志输出 + - 自动禁用被封禁账号 + - 自动切换到其他可用账号 +- **统一错误处理** - 三种协议使用统一的错误分类和处理 + +### v1.6.0 功能 +- **历史消息管理** - 4 种策略处理对话长度限制,可自由组合 + - 自动截断:发送前优先保留最新上下文并摘要前文,必要时按数量/字符数截断 + - 智能摘要:用 AI 生成早期对话摘要,保留关键信息 + - 摘要缓存:历史变化不大时复用最近摘要,减少重复 LLM 调用(默认启用) + - 错误重试:遇到长度错误时自动截断重试(默认启用) + - 预估检测:预估 token 数量,超限预先截断 +- **Gemini 工具调用** - 完整支持 functionDeclarations/functionCall/functionResponse +- **设置页面** - WebUI 新增设置标签页,可配置历史消息管理策略 + +### v1.5.0 功能 +- **用量查询** - 查询账号配额使用情况,显示已用/余额/使用率 +- **多登录方式** - 支持 Google / GitHub / AWS Builder ID 三种登录方式 +- **流量监控** - 完整的 LLM 请求监控,支持搜索、过滤、导出 +- **浏览器选择** - 自动检测已安装浏览器,支持无痕模式 +- **文档中心** - 内置帮助文档,左侧目录 + 右侧 Markdown 渲染 + +### v1.4.0 功能 +- **Token 预刷新** - 后台每 5 分钟检查,提前 15 分钟自动刷新 +- **健康检查** - 每 10 分钟检测账号可用性,自动标记状态 +- **请求统计增强** - 按账号/模型统计,24 小时趋势 +- **请求重试机制** - 网络错误/5xx 自动重试,指数退避 + +## 工具调用支持 + +| 功能 | Anthropic (Claude Code) | OpenAI (Codex CLI) | Gemini | +|------|------------------------|-------------------|--------| +| 工具定义 | ✅ `tools` | ✅ `tools.function` | ✅ `functionDeclarations` | +| 工具调用响应 | ✅ `tool_use` | ✅ `tool_calls` | ✅ `functionCall` | +| 工具结果 | ✅ `tool_result` | ✅ `tool` 角色消息 | ✅ `functionResponse` | +| 强制工具调用 | ✅ `tool_choice` | ✅ `tool_choice` | ✅ `toolConfig.mode` | +| 工具数量限制 | ✅ 50 个 | ✅ 50 个 | ✅ 50 个 | +| 历史消息修复 | ✅ | ✅ | ✅ | +| 图片理解 | ✅ | ✅ | ❌ | +| 网络搜索 | ✅ | ✅ | ❌ | + +## 已知限制 + +### 对话长度限制 + +Kiro API 有输入长度限制。当对话历史过长时,会返回错误: + +``` +Input is too long. (CONTENT_LENGTH_EXCEEDS_THRESHOLD) +``` + +#### 自动处理(v1.6.0+) + +代理内置了历史消息管理功能,可在「设置」页面配置: + +- **错误重试**(默认):遇到长度错误时自动截断并重试 +- **智能摘要**:用 AI 生成早期对话摘要,保留关键信息 +- **摘要缓存**(默认):历史变化不大时复用最近摘要,减少重复 LLM 调用 +- **自动截断**:每次请求前优先保留最新上下文并摘要前文,必要时按数量/字符数截断 +- **预估检测**:预估 token 数量,超限预先截断 + +摘要缓存可通过以下配置项调整(默认值): +- `summary_cache_enabled`: `true` +- `summary_cache_min_delta_messages`: `3` +- `summary_cache_min_delta_chars`: `4000` +- `summary_cache_max_age_seconds`: `180` + +#### 手动处理 + +1. 在 Claude Code 中输入 `/clear` 清空对话历史 +2. 告诉 AI 你之前在做什么,它会读取代码文件恢复上下文 + +## 快速开始 + +### 方式一:下载预编译版本 + +从 [Releases](../../releases) 下载对应平台的安装包,解压后直接运行。 + +### 方式二:从源码运行 + +```bash +# 克隆项目 +git clone https://github.com/yourname/kiro-proxy.git +cd kiro-proxy + +# 创建虚拟环境 +python -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate + +# 安装依赖 +pip install -r requirements.txt + +# 运行 +python run.py + +# 或指定端口 +python run.py 8081 +``` + +启动后访问 http://localhost:8080 + +### 命令行工具 (CLI) + +无 GUI 服务器可使用 CLI 管理账号: + +```bash +# 账号管理 +python run.py accounts list # 列出账号 +python run.py accounts export -o acc.json # 导出账号 +python run.py accounts import acc.json # 导入账号 +python run.py accounts add # 交互式添加 Token +python run.py accounts scan --auto # 扫描并自动添加本地 Token + +# 登录 +python run.py login google # Google 登录 +python run.py login github # GitHub 登录 +python run.py login remote --host myserver.com:8080 # 生成远程登录链接 + +# 服务 +python run.py serve # 启动服务 (默认 8080) +python run.py serve -p 8081 # 指定端口 +python run.py status # 查看状态 +``` + +### 登录获取 Token + +**方式一:在线登录(推荐)** +1. 打开 Web UI,点击「在线登录」 +2. 选择登录方式:Google / GitHub / AWS Builder ID +3. 在浏览器中完成授权 +4. 账号自动添加 + +**方式二:扫描 Token** +1. 打开 Kiro IDE,使用 Google/GitHub 账号登录 +2. 登录成功后 token 自动保存到 `~/.aws/sso/cache/` +3. 在 Web UI 点击「扫描 Token」添加账号 + +## CLI 配置 + +### 模型对照表 + +| Kiro 模型 | 能力 | Claude Code | Codex | +|-----------|------|-------------|-------| +| `claude-sonnet-4` | ⭐⭐⭐ 推荐 | `claude-sonnet-4` | `gpt-4o` | +| `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强 | `claude-sonnet-4.5` | `gpt-4o` | +| `claude-haiku-4.5` | ⚡ 快速 | `claude-haiku-4.5` | `gpt-4o-mini` | + +### Claude Code 配置 + +``` +名称: Kiro Proxy +API Key: any +Base URL: http://localhost:8080 +模型: claude-sonnet-4 +``` + +### Codex 配置 + +Codex CLI 使用 OpenAI Responses API,配置如下: + +```bash +# 设置环境变量 +export OPENAI_API_KEY=any +export OPENAI_BASE_URL=http://localhost:8080/v1 + +# 运行 Codex +codex +``` + +或在 `~/.codex/config.toml` 中配置: + +```toml +[providers.openai] +api_key = "any" +base_url = "http://localhost:8080/v1" +``` + +## 思考功能支持 + +### 什么是思考功能 + +思考功能(Extended Thinking)允许 Claude 在生成回答前展示其思考过程,帮助用户理解 AI 的推理步骤。 + +### 如何使用 + +在请求中添加 `thinking`(或对应协议的 thinking 配置)即可启用: + +```json +{ + "model": "claude-sonnet-4.5", + "messages": [ + { + "role": "user", + "content": "解释一下量子计算的原理" + } + ], + "thinking": { + "thinking_type": "enabled", + "budget_tokens": 20000 + }, + "stream": true +} +``` + +OpenAI Chat Completions (`POST /v1/chat/completions`) 也支持: + +```json +{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "解释一下量子计算的原理"}], + "thinking": { "type": "enabled" }, + "stream": true +} +``` + +OpenAI Responses (`POST /v1/responses`) 也支持: + +```json +{ + "model": "gpt-4o", + "input": "解释一下量子计算的原理", + "thinking": { "type": "enabled" } +} +``` + +Gemini generateContent (`POST /v1/models/{model}:generateContent`) 也支持: + +```json +{ + "contents": [{"role": "user", "parts": [{"text": "解释一下量子计算的原理"}]}], + "generationConfig": { + "thinkingConfig": { "includeThoughts": true } + } +} +``` + +### 参数说明 + +- `thinking_type`: 思考类型,设为 `"enabled"` 启用思考功能 +- `budget_tokens`: 思考过程的 token 预算(不传则视为无限制) + +### 响应格式 + +启用思考功能后,流式响应会包含两种内容块: + +1. **思考块**(type: "thinking"):展示 AI 的思考过程 +2. **文本块**(type: "text"):最终的回答内容 + +示例响应: +``` +data: {"type":"content_block_start","index":1,"content_block":{"type":"thinking","thinking":""}} +data: {"type":"content_block_delta","index":1,"delta":{"type":"thinking_delta","thinking":"让我思考一下量子计算的原理..."}} +data: {"type":"content_block_stop","index":1} +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"量子计算是一种..."}} +data: {"type":"content_block_stop","index":0} +``` + +## API 端点 + +| 协议 | 端点 | 用途 | +|------|------|------| +| OpenAI | `POST /v1/chat/completions` | Chat Completions API | +| OpenAI | `POST /v1/responses` | Responses API (Codex CLI) | +| OpenAI | `GET /v1/models` | 模型列表 | +| Anthropic | `POST /v1/messages` | Claude Code | +| Anthropic | `POST /v1/messages/count_tokens` | Token 计数 | +| Gemini | `POST /v1/models/{model}:generateContent` | Gemini CLI | + +### 管理 API + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/accounts` | GET | 获取所有账号状态 | +| `/api/accounts/{id}` | GET | 获取账号详情 | +| `/api/accounts/{id}/usage` | GET | 获取账号用量信息 | +| `/api/accounts/{id}/refresh` | POST | 刷新账号 Token | +| `/api/accounts/{id}/restore` | POST | 恢复账号(从冷却状态) | +| `/api/accounts/refresh-all` | POST | 刷新所有即将过期的 Token | +| `/api/flows` | GET | 获取流量记录 | +| `/api/flows/stats` | GET | 获取流量统计 | +| `/api/flows/{id}` | GET | 获取流量详情 | +| `/api/quota` | GET | 获取配额状态 | +| `/api/stats` | GET | 获取统计信息 | +| `/api/health-check` | POST | 手动触发健康检查 | +| `/api/browsers` | GET | 获取可用浏览器列表 | +| `/api/docs` | GET | 获取文档列表 | +| `/api/docs/{id}` | GET | 获取文档内容 | + +## 项目结构 + +``` +kiro_proxy/ +├── main.py # FastAPI 应用入口 +├── config.py # 全局配置 +├── converters.py # 协议转换 +│ +├── core/ # 核心模块 +│ ├── account.py # 账号管理 +│ ├── state.py # 全局状态 +│ ├── persistence.py # 配置持久化 +│ ├── scheduler.py # 后台任务调度 +│ ├── stats.py # 请求统计 +│ ├── retry.py # 重试机制 +│ ├── browser.py # 浏览器检测 +│ ├── flow_monitor.py # 流量监控 +│ └── usage.py # 用量查询 +│ +├── credential/ # 凭证管理 +│ ├── types.py # KiroCredentials +│ ├── fingerprint.py # Machine ID 生成 +│ ├── quota.py # 配额管理器 +│ └── refresher.py # Token 刷新 +│ +├── auth/ # 认证模块 +│ └── device_flow.py # Device Code Flow / Social Auth +│ +├── handlers/ # API 处理器 +│ ├── anthropic.py # /v1/messages +│ ├── openai.py # /v1/chat/completions +│ ├── responses.py # /v1/responses (Codex CLI) +│ ├── gemini.py # /v1/models/{model}:generateContent +│ └── admin.py # 管理 API +│ +├── cli.py # 命令行工具 +│ +├── docs/ # 内置文档 +│ ├── 01-quickstart.md # 快速开始 +│ ├── 02-features.md # 功能特性 +│ ├── 03-faq.md # 常见问题 +│ └── 04-api.md # API 参考 +│ +└── web/ + └── html.py # Web UI (组件化单文件) +``` + +## 构建 + +```bash +# 安装构建依赖 +pip install pyinstaller + +# 构建 +python build.py +``` + +输出文件在 `dist/` 目录。 + +## 免责声明 + +本项目仅供学习研究,禁止商用。使用本项目产生的任何后果由使用者自行承担,与作者无关。 + +本项目与 Kiro / AWS / Anthropic 官方无关。 diff --git a/KiroProxy/assets/icon.iconset/icon_128x128.png b/KiroProxy/assets/icon.iconset/icon_128x128.png new file mode 100644 index 0000000000000000000000000000000000000000..184a76ee4c394ec33115d58b1dc48449b89ecc5b Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_128x128.png differ diff --git a/KiroProxy/assets/icon.iconset/icon_16x16.png b/KiroProxy/assets/icon.iconset/icon_16x16.png new file mode 100644 index 0000000000000000000000000000000000000000..243fb06658b6c5659420e22a617ca7c7d1d7a478 Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_16x16.png differ diff --git a/KiroProxy/assets/icon.iconset/icon_256x256.png b/KiroProxy/assets/icon.iconset/icon_256x256.png new file mode 100644 index 0000000000000000000000000000000000000000..15644da40e6ab20e2dc3b5e3eaa3c6a5bc20a94e Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_256x256.png differ diff --git a/KiroProxy/assets/icon.iconset/icon_32x32.png b/KiroProxy/assets/icon.iconset/icon_32x32.png new file mode 100644 index 0000000000000000000000000000000000000000..b0b210c8755aa3ee07bac1c91b994308b6f1209b Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_32x32.png differ diff --git a/KiroProxy/assets/icon.iconset/icon_512x512.png b/KiroProxy/assets/icon.iconset/icon_512x512.png new file mode 100644 index 0000000000000000000000000000000000000000..8d045f4f89649ee6f4be43550dc35bd72015aa6e Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_512x512.png differ diff --git a/KiroProxy/assets/icon.iconset/icon_64x64.png b/KiroProxy/assets/icon.iconset/icon_64x64.png new file mode 100644 index 0000000000000000000000000000000000000000..b84f3097f92281da3ddeb4187a33a7ee6c84c3a5 Binary files /dev/null and b/KiroProxy/assets/icon.iconset/icon_64x64.png differ diff --git a/KiroProxy/assets/icon.png b/KiroProxy/assets/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..15644da40e6ab20e2dc3b5e3eaa3c6a5bc20a94e Binary files /dev/null and b/KiroProxy/assets/icon.png differ diff --git a/KiroProxy/assets/icon.svg b/KiroProxy/assets/icon.svg new file mode 100644 index 0000000000000000000000000000000000000000..153c516d4ec39254844f8d454be093eb00754b14 --- /dev/null +++ b/KiroProxy/assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/KiroProxy/build.py b/KiroProxy/build.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae207409d3d46aa660f37af67d1ba7edbef2813 --- /dev/null +++ b/KiroProxy/build.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +""" +Kiro Proxy Cross-platform Build Script +Supports: Windows / macOS / Linux + +Usage: + python build.py # Build for current platform + python build.py --all # Show all platform instructions +""" + +import os +import sys +import shutil +import subprocess +from pathlib import Path + +from kiro_proxy import __version__ as VERSION + +APP_NAME = "KiroProxy" +MAIN_SCRIPT = "run.py" +ICON_DIR = Path("assets") + +def get_platform(): + if sys.platform == "win32": + return "windows" + elif sys.platform == "darwin": + return "macos" + else: + return "linux" + +def ensure_pyinstaller(): + try: + import PyInstaller + print(f"[OK] PyInstaller {PyInstaller.__version__} installed") + except ImportError: + print("[..] Installing PyInstaller...") + subprocess.run([sys.executable, "-m", "pip", "install", "pyinstaller"], check=True) + +def clean_build(): + for d in ["build", "dist", f"{APP_NAME}.spec"]: + if os.path.isdir(d): + shutil.rmtree(d) + elif os.path.isfile(d): + os.remove(d) + print("[OK] Cleaned build directories") + +def build_app(): + platform = get_platform() + print(f"\n{'='*50}") + print(f" Building {APP_NAME} v{VERSION} - {platform}") + print(f"{'='*50}\n") + + ensure_pyinstaller() + clean_build() + + args = [ + sys.executable, "-m", "PyInstaller", + "--name", APP_NAME, + "--onefile", + "--clean", + "--noconfirm", + ] + + icon_file = None + if platform == "windows" and (ICON_DIR / "icon.ico").exists(): + icon_file = ICON_DIR / "icon.ico" + elif platform == "macos" and (ICON_DIR / "icon.icns").exists(): + icon_file = ICON_DIR / "icon.icns" + elif (ICON_DIR / "icon.png").exists(): + icon_file = ICON_DIR / "icon.png" + + if icon_file: + args.extend(["--icon", str(icon_file)]) + print(f"[OK] Using icon: {icon_file}") + + # 添加资源文件打包 + if (ICON_DIR).exists(): + if platform == "windows": + args.extend(["--add-data", f"{ICON_DIR};assets"]) + else: + args.extend(["--add-data", f"{ICON_DIR}:assets"]) + print(f"[OK] Adding assets directory") + + # 添加文档文件打包 + docs_dir = Path("kiro_proxy/docs") + if docs_dir.exists(): + if platform == "windows": + args.extend(["--add-data", f"{docs_dir};kiro_proxy/docs"]) + else: + args.extend(["--add-data", f"{docs_dir}:kiro_proxy/docs"]) + print(f"[OK] Adding docs directory") + + hidden_imports = [ + "uvicorn.logging", + "uvicorn.protocols.http", + "uvicorn.protocols.http.auto", + "uvicorn.protocols.http.h11_impl", + "uvicorn.protocols.websockets", + "uvicorn.protocols.websockets.auto", + "uvicorn.lifespan", + "uvicorn.lifespan.on", + "httpx", + "httpx._transports", + "httpx._transports.default", + "anyio", + "anyio._backends", + "anyio._backends._asyncio", + ] + for imp in hidden_imports: + args.extend(["--hidden-import", imp]) + + args.append(MAIN_SCRIPT) + args = [a for a in args if a] + + print(f"[..] Running: {' '.join(args)}\n") + result = subprocess.run(args) + + if result.returncode == 0: + if platform == "windows": + output = Path("dist") / f"{APP_NAME}.exe" + else: + output = Path("dist") / APP_NAME + + if output.exists(): + size_mb = output.stat().st_size / (1024 * 1024) + print(f"\n{'='*50}") + print(f" [OK] Build successful!") + print(f" Output: {output}") + print(f" Size: {size_mb:.1f} MB") + print(f"{'='*50}") + + create_release_package(platform, output) + else: + print("[FAIL] Build failed: output file not found") + sys.exit(1) + else: + print("[FAIL] Build failed") + sys.exit(1) + +def create_release_package(platform, binary_path): + release_dir = Path("release") + release_dir.mkdir(exist_ok=True) + + if platform == "windows": + archive_name = f"{APP_NAME}-{VERSION}-Windows" + shutil.copy(binary_path, release_dir / f"{APP_NAME}.exe") + shutil.make_archive( + str(release_dir / archive_name), + "zip", + release_dir, + f"{APP_NAME}.exe" + ) + (release_dir / f"{APP_NAME}.exe").unlink() + print(f" Release: release/{archive_name}.zip") + + elif platform == "macos": + archive_name = f"{APP_NAME}-{VERSION}-macOS" + shutil.copy(binary_path, release_dir / APP_NAME) + os.chmod(release_dir / APP_NAME, 0o755) + shutil.make_archive( + str(release_dir / archive_name), + "zip", + release_dir, + APP_NAME + ) + (release_dir / APP_NAME).unlink() + print(f" Release: release/{archive_name}.zip") + + else: + archive_name = f"{APP_NAME}-{VERSION}-Linux" + shutil.copy(binary_path, release_dir / APP_NAME) + os.chmod(release_dir / APP_NAME, 0o755) + shutil.make_archive( + str(release_dir / archive_name), + "gztar", + release_dir, + APP_NAME + ) + (release_dir / APP_NAME).unlink() + print(f" Release: release/{archive_name}.tar.gz") + +def show_all_platforms(): + print(f""" +{'='*60} + Kiro Proxy Cross-platform Build Instructions +{'='*60} + +This script must run on the target platform. + +[Windows] + Run on Windows: + python build.py + + Output: release/KiroProxy-{VERSION}-Windows.zip + +[macOS] + Run on macOS: + python build.py + + Output: release/KiroProxy-{VERSION}-macOS.zip + +[Linux] + Run on Linux: + python build.py + + Output: release/KiroProxy-{VERSION}-Linux.tar.gz + +[GitHub Actions] + Push to GitHub and Actions will build all platforms. + See .github/workflows/build.yml + +{'='*60} +""") + +if __name__ == "__main__": + if "--all" in sys.argv or "-a" in sys.argv: + show_all_platforms() + else: + build_app() diff --git a/KiroProxy/examples/quota_display_example.py b/KiroProxy/examples/quota_display_example.py new file mode 100644 index 0000000000000000000000000000000000000000..b010b12a65878701ef3589b1836c75041b9d7269 --- /dev/null +++ b/KiroProxy/examples/quota_display_example.py @@ -0,0 +1,95 @@ +"""展示额度重置时间功能的示例""" +import json +from datetime import datetime + + +def generate_quota_display_example(): + """生成额度显示示例""" + + # 模拟账号的额度信息(从 API 获取) + quota_data = { + "subscription_title": "Kiro Pro", + "usage_limit": 700.0, + "current_usage": 150.0, + "balance": 550.0, + "usage_percent": 21.4, + "is_low_balance": False, + "is_exhausted": False, + "balance_status": "normal", + + # 免费试用信息 + "free_trial_limit": 500.0, + "free_trial_usage": 100.0, + "free_trial_expiry": "2026-02-13T23:59:59Z", + "trial_expiry_text": "2026-02-13", + + # 奖励信息 + "bonus_limit": 150.0, + "bonus_usage": 25.0, + "bonus_expiries": ["2026-03-01T23:59:59Z", "2026-02-28T23:59:59Z"], + "active_bonuses": 2, + + # 重置时间 + "next_reset_date": "2026-02-01T00:00:00Z", + "reset_date_text": "2026-02-01", + + # 更新时间 + "updated_at": "2分钟前", + "error": None + } + + # 生成 HTML 显示片段(类似在 Web 界面中的显示) + html_template = """ +
+
+ 已用/总额 + {current_usage:.1f} / {usage_limit:.1f} +
+
+
+
+
+ 试用: {free_trial_usage:.0f}/{free_trial_limit:.0f} + 奖励: {bonus_usage:.0f}/{bonus_limit:.0f} ({active_bonuses}个) + 更新: {updated_at} +
+
+ 🔄 重置: {reset_date_text} + 🎁 试用过期: {trial_expiry_text} +
+
+ """.format(**quota_data) + + print("=== 额度信息展示示例 ===") + print(html_template) + + # 生成卡片式展示 + card_template = """ +
+

主配额

+
{current_usage:.0f} / {usage_limit:.0f}
+
2026-02-01 重置
+
+
+

免费试用

+
{free_trial_usage:.0f} / {free_trial_limit:.0f}
+
ACTIVE
+
2026-02-13 过期
+
+
+

奖励总计

+
{bonus_usage:.0f} / {bonus_limit:.0f}
+
{active_bonuses}个生效奖励
+
+ """.format(**quota_data) + + print("\n=== 卡片式展示(如图所示)===") + print(card_template) + + # 生成 JSON 数据 + print("\n=== JSON 数据格式 ===") + print(json.dumps(quota_data, indent=2, ensure_ascii=False)) + + +if __name__ == "__main__": + generate_quota_display_example() diff --git a/KiroProxy/examples/test_quota_display.html b/KiroProxy/examples/test_quota_display.html new file mode 100644 index 0000000000000000000000000000000000000000..768e22b24ae8892519409eb32953bb48aef95c9d --- /dev/null +++ b/KiroProxy/examples/test_quota_display.html @@ -0,0 +1,118 @@ + + + + + 额度重置时间测试 + + + +

额度重置时间测试

+
+ + + + diff --git a/KiroProxy/kiro.svg b/KiroProxy/kiro.svg new file mode 100644 index 0000000000000000000000000000000000000000..e132ead12110cb2aab2d3ea69116876db87a6dff --- /dev/null +++ b/KiroProxy/kiro.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/KiroProxy/kiro_proxy/__init__.py b/KiroProxy/kiro_proxy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c631050e0f14dbf1e6a346f993c2a73f56e1989 --- /dev/null +++ b/KiroProxy/kiro_proxy/__init__.py @@ -0,0 +1,2 @@ +# Kiro API Proxy +__version__ = "1.7.1" diff --git a/KiroProxy/kiro_proxy/__main__.py b/KiroProxy/kiro_proxy/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f05ddc225577125ac018702cfe4de55a1aacd71 --- /dev/null +++ b/KiroProxy/kiro_proxy/__main__.py @@ -0,0 +1,5 @@ +from .cli import main + + +if __name__ == "__main__": + main() diff --git a/KiroProxy/kiro_proxy/auth/__init__.py b/KiroProxy/kiro_proxy/auth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5195094d359d43b661ed40f67c60f9fabdac39 --- /dev/null +++ b/KiroProxy/kiro_proxy/auth/__init__.py @@ -0,0 +1,32 @@ +"""Kiro 认证模块""" +from .device_flow import ( + start_device_flow, + poll_device_flow, + cancel_device_flow, + get_login_state, + save_credentials_to_file, + DeviceFlowState, + # Social Auth + start_social_auth, + exchange_social_auth_token, + cancel_social_auth, + get_social_auth_state, + start_callback_server, + wait_for_callback, +) + +__all__ = [ + "start_device_flow", + "poll_device_flow", + "cancel_device_flow", + "get_login_state", + "save_credentials_to_file", + "DeviceFlowState", + # Social Auth + "start_social_auth", + "exchange_social_auth_token", + "cancel_social_auth", + "get_social_auth_state", + "start_callback_server", + "wait_for_callback", +] diff --git a/KiroProxy/kiro_proxy/auth/device_flow.py b/KiroProxy/kiro_proxy/auth/device_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..ef333d6371ad93894ad09d893e7904ab4049c251 --- /dev/null +++ b/KiroProxy/kiro_proxy/auth/device_flow.py @@ -0,0 +1,603 @@ +"""Kiro Device Code Flow 登录 + +实现 AWS OIDC Device Authorization Flow: +1. 注册 OIDC 客户端 -> 获取 clientId + clientSecret +2. 发起设备授权 -> 获取 deviceCode + userCode + verificationUri +3. 用户在浏览器中输入 userCode 完成授权 +4. 轮询 Token -> 获取 accessToken + refreshToken + +Social Auth (Google/GitHub): +1. 生成 PKCE code_verifier 和 code_challenge +2. 构建登录 URL,打开浏览器 +3. 启动本地回调服务器接收授权码 +4. 用授权码交换 Token +""" +import json +import time +import httpx +import secrets +import hashlib +import base64 +import asyncio +from pathlib import Path +from dataclasses import dataclass, asdict +from typing import Optional, Tuple +from datetime import datetime, timezone + + +@dataclass +class DeviceFlowState: + """设备授权流程状态""" + client_id: str + client_secret: str + device_code: str + user_code: str + verification_uri: str + interval: int + expires_at: int + region: str + started_at: float + + +@dataclass +class SocialAuthState: + """Social Auth 登录状态""" + provider: str # Google / Github + code_verifier: str + code_challenge: str + oauth_state: str + expires_at: int + started_at: float + + +# 全局登录状态 +_login_state: Optional[DeviceFlowState] = None +_social_auth_state: Optional[SocialAuthState] = None +_callback_server = None + +# Kiro OIDC 配置 +KIRO_START_URL = "https://view.awsapps.com/start" +KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev" +KIRO_SCOPES = [ + "codewhisperer:completions", + "codewhisperer:analysis", + "codewhisperer:conversations", + "codewhisperer:transformations", + "codewhisperer:taskassist", +] + + +def get_login_state() -> Optional[dict]: + """获取当前登录状态""" + global _login_state + if _login_state is None: + return None + + # 检查是否过期 + if time.time() > _login_state.expires_at: + _login_state = None + return None + + return { + "user_code": _login_state.user_code, + "verification_uri": _login_state.verification_uri, + "expires_in": int(_login_state.expires_at - time.time()), + "interval": _login_state.interval, + } + + +async def start_device_flow(region: str = "us-east-1") -> Tuple[bool, dict]: + """ + 启动设备授权流程 + + Returns: + (success, result_or_error) + """ + global _login_state + + oidc_base = f"https://oidc.{region}.amazonaws.com" + + async with httpx.AsyncClient(timeout=30) as client: + # Step 1: 注册 OIDC 客户端 + print(f"[DeviceFlow] Step 1: 注册 OIDC 客户端...") + + reg_body = { + "clientName": "Kiro Proxy", + "clientType": "public", + "scopes": KIRO_SCOPES, + "grantTypes": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"], + "issuerUrl": KIRO_START_URL + } + + try: + reg_resp = await client.post( + f"{oidc_base}/client/register", + json=reg_body, + headers={"Content-Type": "application/json"} + ) + except Exception as e: + return False, {"error": f"注册客户端请求失败: {e}"} + + if reg_resp.status_code != 200: + return False, {"error": f"注册客户端失败: {reg_resp.text}"} + + reg_data = reg_resp.json() + client_id = reg_data.get("clientId") + client_secret = reg_data.get("clientSecret") + + if not client_id or not client_secret: + return False, {"error": "注册响应缺少 clientId 或 clientSecret"} + + print(f"[DeviceFlow] 客户端注册成功: {client_id[:20]}...") + + # Step 2: 发起设备授权 + print(f"[DeviceFlow] Step 2: 发起设备授权...") + + auth_body = { + "clientId": client_id, + "clientSecret": client_secret, + "startUrl": KIRO_START_URL + } + + try: + auth_resp = await client.post( + f"{oidc_base}/device_authorization", + json=auth_body, + headers={"Content-Type": "application/json"} + ) + except Exception as e: + return False, {"error": f"设备授权请求失败: {e}"} + + if auth_resp.status_code != 200: + return False, {"error": f"设备授权失败: {auth_resp.text}"} + + auth_data = auth_resp.json() + device_code = auth_data.get("deviceCode") + user_code = auth_data.get("userCode") + verification_uri = auth_data.get("verificationUriComplete") or auth_data.get("verificationUri") + interval = auth_data.get("interval", 5) + expires_in = auth_data.get("expiresIn", 600) + + if not device_code or not user_code or not verification_uri: + return False, {"error": "设备授权响应缺少必要字段"} + + print(f"[DeviceFlow] 设备码获取成功: {user_code}") + + # 保存状态 + _login_state = DeviceFlowState( + client_id=client_id, + client_secret=client_secret, + device_code=device_code, + user_code=user_code, + verification_uri=verification_uri, + interval=interval, + expires_at=int(time.time() + expires_in), + region=region, + started_at=time.time() + ) + + return True, { + "user_code": user_code, + "verification_uri": verification_uri, + "expires_in": expires_in, + "interval": interval, + } + + +async def poll_device_flow() -> Tuple[bool, dict]: + """ + 轮询设备授权状态 + + Returns: + (success, result_or_error) + - success=True, result={"completed": True, "credentials": {...}} 授权完成 + - success=True, result={"completed": False, "status": "pending"} 等待中 + - success=False, result={"error": "..."} 错误 + """ + global _login_state + + if _login_state is None: + return False, {"error": "没有进行中的登录"} + + # 检查是否过期 + if time.time() > _login_state.expires_at: + _login_state = None + return False, {"error": "授权已过期,请重新开始"} + + oidc_base = f"https://oidc.{_login_state.region}.amazonaws.com" + + token_body = { + "clientId": _login_state.client_id, + "clientSecret": _login_state.client_secret, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + "deviceCode": _login_state.device_code + } + + async with httpx.AsyncClient(timeout=30) as client: + try: + token_resp = await client.post( + f"{oidc_base}/token", + json=token_body, + headers={"Content-Type": "application/json"} + ) + except Exception as e: + return False, {"error": f"Token 请求失败: {e}"} + + if token_resp.status_code == 200: + # 授权成功 + token_data = token_resp.json() + + credentials = { + "accessToken": token_data.get("accessToken"), + "refreshToken": token_data.get("refreshToken"), + "expiresAt": datetime.now(timezone.utc).isoformat(), + "clientId": _login_state.client_id, + "clientSecret": _login_state.client_secret, + "region": _login_state.region, + "authMethod": "idc", + } + + # 计算过期时间 + if expires_in := token_data.get("expiresIn"): + from datetime import timedelta + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + credentials["expiresAt"] = expires_at.isoformat() + + # 清除状态 + _login_state = None + + print(f"[DeviceFlow] 授权成功!") + return True, {"completed": True, "credentials": credentials} + + # 检查错误类型 + try: + error_data = token_resp.json() + error_code = error_data.get("error", "") + except: + error_code = "" + + if error_code == "authorization_pending": + # 用户还未完成授权 + return True, {"completed": False, "status": "pending"} + elif error_code == "slow_down": + # 请求太频繁 + return True, {"completed": False, "status": "slow_down"} + elif error_code == "expired_token": + _login_state = None + return False, {"error": "授权已过期,请重新开始"} + elif error_code == "access_denied": + _login_state = None + return False, {"error": "用户拒绝授权"} + else: + return False, {"error": f"Token 请求失败: {token_resp.text}"} + + +def cancel_device_flow() -> bool: + """取消设备授权流程""" + global _login_state + if _login_state is not None: + _login_state = None + return True + return False + + +async def save_credentials_to_file(credentials: dict, name: str = "kiro-proxy-auth") -> str: + """ + 保存凭证到文件 + + 支持的字段: + - accessToken, refreshToken, profileArn, expiresAt + - clientId, clientSecret (IDC 认证) + - region, authMethod, provider + + Returns: + 保存的文件路径 + """ + from ..config import TOKEN_DIR + TOKEN_DIR.mkdir(parents=True, exist_ok=True) + + # 生成文件名 + file_path = TOKEN_DIR / f"{name}.json" + + # 如果文件已存在,合并现有数据 + existing = {} + if file_path.exists(): + try: + with open(file_path, "r") as f: + existing = json.load(f) + except Exception: + pass + + # 更新凭证(只更新非空值) + for key, value in credentials.items(): + if value is not None: + existing[key] = value + + with open(file_path, "w") as f: + json.dump(existing, f, indent=2) + + print(f"[DeviceFlow] 凭证已保存到: {file_path}") + return str(file_path) + + +# ==================== Social Auth (Google/GitHub) ==================== + +def _generate_code_verifier() -> str: + """生成 PKCE code_verifier""" + return secrets.token_urlsafe(64)[:128] + + +def _generate_code_challenge(verifier: str) -> str: + """生成 PKCE code_challenge (SHA256)""" + digest = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).rstrip(b'=').decode() + + +def _generate_oauth_state() -> str: + """生成 OAuth state""" + return secrets.token_urlsafe(32) + + +def get_social_auth_state() -> Optional[dict]: + """获取当前 Social Auth 状态""" + global _social_auth_state + if _social_auth_state is None: + return None + + if time.time() > _social_auth_state.expires_at: + _social_auth_state = None + return None + + return { + "provider": _social_auth_state.provider, + "expires_in": int(_social_auth_state.expires_at - time.time()), + } + + +async def start_social_auth(provider: str, redirect_uri: str = None) -> Tuple[bool, dict]: + """ + 启动 Social Auth 登录 (Google/GitHub) + + Args: + provider: "google" 或 "github" + redirect_uri: 回调地址,默认使用 Kiro 官方回调地址 + + Returns: + (success, result_or_error) + """ + global _social_auth_state + + # 验证 provider + provider_normalized = provider.lower() + if provider_normalized == "google": + provider_normalized = "Google" + elif provider_normalized == "github": + provider_normalized = "Github" + else: + return False, {"error": f"不支持的登录提供商: {provider}"} + + print(f"[SocialAuth] 开始 {provider_normalized} 登录流程") + + # 生成 PKCE + code_verifier = _generate_code_verifier() + code_challenge = _generate_code_challenge(code_verifier) + oauth_state = _generate_oauth_state() + + # 回调地址 - 使用 Kiro 官方的回调地址(已在 Cognito 中注册) + # 参考 Kiro-account-manager: kiro://kiro.kiroAgent/authenticate-success + if redirect_uri is None: + redirect_uri = "kiro://kiro.kiroAgent/authenticate-success" + + # 构建登录 URL (使用 /login 端点,参考 Kiro-account-manager) + from urllib.parse import quote, urlencode + + # 使用 urlencode 确保参数正确编码 + params = { + "idp": provider_normalized, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": oauth_state, + } + login_url = f"{KIRO_AUTH_ENDPOINT}/login?{urlencode(params)}" + + print(f"[SocialAuth] ========== Social Auth 登录 ==========") + print(f"[SocialAuth] Provider: {provider_normalized}") + print(f"[SocialAuth] Redirect URI: {redirect_uri}") + print(f"[SocialAuth] Code Challenge: {code_challenge[:20]}...") + print(f"[SocialAuth] State: {oauth_state}") + print(f"[SocialAuth] 登录 URL: {login_url}") + print(f"[SocialAuth] =========================================") + + # 保存状态(10 分钟过期) + _social_auth_state = SocialAuthState( + provider=provider_normalized, + code_verifier=code_verifier, + code_challenge=code_challenge, + oauth_state=oauth_state, + expires_at=int(time.time() + 600), + started_at=time.time(), + ) + + return True, { + "login_url": login_url, + "state": oauth_state, + "provider": provider_normalized, + "redirect_uri": redirect_uri, + } + + +async def exchange_social_auth_token(code: str, state: str, redirect_uri: str = None) -> Tuple[bool, dict]: + """ + 用授权码交换 Token + + 参考 Kiro-account-manager 实现: + - 端点: https://prod.us-east-1.auth.desktop.kiro.dev/oauth/token + - 请求体: {code, code_verifier, redirect_uri} + - 响应: {accessToken, refreshToken, profileArn, expiresIn} + + Args: + code: 授权码 + state: OAuth state + redirect_uri: 回调地址(需要与 start_social_auth 中使用的一致) + + Returns: + (success, result_or_error) + """ + global _social_auth_state + + if _social_auth_state is None: + return False, {"error": "没有进行中的社交登录"} + + # 验证 state + if state != _social_auth_state.oauth_state: + _social_auth_state = None + return False, {"error": "OAuth state 不匹配"} + + # 检查过期 + if time.time() > _social_auth_state.expires_at: + _social_auth_state = None + return False, {"error": "登录已过期,请重新开始"} + + print(f"[SocialAuth] 交换 Token...") + + # 回调地址 - 需要与 start_social_auth 中使用的一致 + # 使用 Kiro 官方的回调地址 + if redirect_uri is None: + redirect_uri = "kiro://kiro.kiroAgent/authenticate-success" + + # 交换 Token (参考 Kiro-account-manager 的请求格式) + token_body = { + "code": code, + "code_verifier": _social_auth_state.code_verifier, + "redirect_uri": redirect_uri, + } + + async with httpx.AsyncClient(timeout=30) as client: + try: + token_resp = await client.post( + f"{KIRO_AUTH_ENDPOINT}/oauth/token", + json=token_body, + headers={"Content-Type": "application/json"} + ) + except Exception as e: + _social_auth_state = None + return False, {"error": f"Token 请求失败: {e}"} + + if token_resp.status_code != 200: + error_text = token_resp.text + _social_auth_state = None + return False, {"error": f"Token 交换失败: {error_text}"} + + token_data = token_resp.json() + + # 解析响应 (参考 Kiro-account-manager 的响应格式) + # 响应字段: accessToken, refreshToken, profileArn, expiresIn + provider = _social_auth_state.provider + + credentials = { + "accessToken": token_data.get("accessToken") or token_data.get("access_token"), + "refreshToken": token_data.get("refreshToken") or token_data.get("refresh_token"), + "profileArn": token_data.get("profileArn"), + "expiresAt": datetime.now(timezone.utc).isoformat(), + "authMethod": "social", + "provider": provider, # 保存 provider 字段 + } + + # 计算过期时间 + expires_in = token_data.get("expiresIn") or token_data.get("expires_in") + if expires_in: + from datetime import timedelta + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + credentials["expiresAt"] = expires_at.isoformat() + + _social_auth_state = None + + print(f"[SocialAuth] {provider} 登录成功!") + return True, {"completed": True, "credentials": credentials, "provider": provider} + + +def cancel_social_auth() -> bool: + """取消 Social Auth 登录""" + global _social_auth_state + if _social_auth_state is not None: + _social_auth_state = None + return True + return False + + +# ==================== 回调服务器 ==================== + +_callback_result = None +_callback_event = None + +async def start_callback_server() -> Tuple[bool, dict]: + """启动本地回调服务器""" + global _callback_result, _callback_event + + from aiohttp import web + + _callback_result = None + _callback_event = asyncio.Event() + + async def handle_callback(request): + global _callback_result + code = request.query.get("code") + state = request.query.get("state") + error = request.query.get("error") + + if error: + _callback_result = {"error": error} + elif code and state: + _callback_result = {"code": code, "state": state} + else: + _callback_result = {"error": "缺少授权码"} + + _callback_event.set() + + # 返回成功页面 + html = """ + + 登录成功 + +

✅ 登录成功

+

您可以关闭此窗口并返回 Kiro Proxy

+ + + + """ + return web.Response(text=html, content_type="text/html") + + app = web.Application() + app.router.add_get("/kiro-social-callback", handle_callback) + + runner = web.AppRunner(app) + await runner.setup() + + try: + site = web.TCPSite(runner, "127.0.0.1", 19823) + await site.start() + print("[SocialAuth] 回调服务器已启动: http://127.0.0.1:19823") + return True, {"port": 19823} + except Exception as e: + return False, {"error": f"启动回调服务器失败: {e}"} + + +async def wait_for_callback(timeout: int = 300) -> Tuple[bool, dict]: + """等待回调""" + global _callback_result, _callback_event + + if _callback_event is None: + return False, {"error": "回调服务器未启动"} + + try: + await asyncio.wait_for(_callback_event.wait(), timeout=timeout) + + if _callback_result and "code" in _callback_result: + return True, _callback_result + elif _callback_result and "error" in _callback_result: + return False, _callback_result + else: + return False, {"error": "未收到有效回调"} + except asyncio.TimeoutError: + return False, {"error": "等待回调超时"} diff --git a/KiroProxy/kiro_proxy/cli.py b/KiroProxy/kiro_proxy/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..2c91a0d8e964f5c662f14816888a0420ee71e921 --- /dev/null +++ b/KiroProxy/kiro_proxy/cli.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +"""Kiro Proxy CLI - 轻量命令行工具""" +import argparse +import asyncio +import json +import sys +from pathlib import Path + +from . import __version__ + + +def cmd_serve(args): + """启动代理服务""" + from .main import run + run(port=args.port) + + +def cmd_accounts_list(args): + """列出所有账号""" + from .core import state + accounts = state.get_accounts_status() + if not accounts: + print("暂无账号") + return + print(f"{'ID':<10} {'名称':<20} {'状态':<10} {'请求数':<8}") + print("-" * 50) + for acc in accounts: + print(f"{acc['id']:<10} {acc['name']:<20} {acc['status']:<10} {acc['request_count']:<8}") + + +def cmd_accounts_export(args): + """导出账号配置""" + from .core import state + accounts_data = [] + for acc in state.accounts: + creds = acc.get_credentials() + if creds: + accounts_data.append({ + "name": acc.name, + "enabled": acc.enabled, + "credentials": { + "accessToken": creds.access_token, + "refreshToken": creds.refresh_token, + "expiresAt": creds.expires_at, + "region": creds.region, + "authMethod": creds.auth_method, + } + }) + + output = {"accounts": accounts_data, "version": "1.0"} + + if args.output: + Path(args.output).write_text(json.dumps(output, indent=2, ensure_ascii=False)) + print(f"已导出 {len(accounts_data)} 个账号到 {args.output}") + else: + print(json.dumps(output, indent=2, ensure_ascii=False)) + + +def cmd_accounts_import(args): + """导入账号配置""" + import uuid + from .core import state, Account + from .auth import save_credentials_to_file + + data = json.loads(Path(args.file).read_text()) + accounts_data = data.get("accounts", []) + imported = 0 + + for acc_data in accounts_data: + creds = acc_data.get("credentials", {}) + if not creds.get("accessToken"): + print(f"跳过 {acc_data.get('name', '未知')}: 缺少 accessToken") + continue + + # 保存凭证到文件 + file_path = asyncio.run(save_credentials_to_file({ + "accessToken": creds.get("accessToken"), + "refreshToken": creds.get("refreshToken"), + "expiresAt": creds.get("expiresAt"), + "region": creds.get("region", "us-east-1"), + "authMethod": creds.get("authMethod", "social"), + }, f"imported-{uuid.uuid4().hex[:8]}")) + + account = Account( + id=uuid.uuid4().hex[:8], + name=acc_data.get("name", "导入账号"), + token_path=file_path, + enabled=acc_data.get("enabled", True) + ) + state.accounts.append(account) + account.load_credentials() + imported += 1 + print(f"已导入: {account.name}") + + state._save_accounts() + print(f"\n共导入 {imported} 个账号") + + +def cmd_accounts_add(args): + """手动添加 Token""" + import uuid + from .core import state, Account + from .auth import save_credentials_to_file + + print("手动添加 Kiro 账号") + print("-" * 40) + + name = input("账号名称 [我的账号]: ").strip() or "我的账号" + print("\n请粘贴 Access Token:") + access_token = input().strip() + + if not access_token: + print("错误: Access Token 不能为空") + return + + print("\n请粘贴 Refresh Token (可选,直接回车跳过):") + refresh_token = input().strip() or None + + # 保存凭证 + file_path = asyncio.run(save_credentials_to_file({ + "accessToken": access_token, + "refreshToken": refresh_token, + "region": "us-east-1", + "authMethod": "social", + }, f"manual-{uuid.uuid4().hex[:8]}")) + + account = Account( + id=uuid.uuid4().hex[:8], + name=name, + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + state._save_accounts() + + print(f"\n✅ 账号已添加: {name} (ID: {account.id})") + + +def cmd_accounts_scan(args): + """扫描本地 Token""" + import uuid + from .core import state, Account + from .config import TOKEN_DIR + + # 扫描新目录 + found = [] + if TOKEN_DIR.exists(): + for f in TOKEN_DIR.glob("*.json"): + try: + data = json.loads(f.read_text()) + if "accessToken" in data: + already = any(a.token_path == str(f) for a in state.accounts) + found.append({"path": str(f), "name": f.stem, "already": already}) + except: + pass + + # 兼容旧目录 + sso_cache = Path.home() / ".aws/sso/cache" + if sso_cache.exists(): + for f in sso_cache.glob("*.json"): + try: + data = json.loads(f.read_text()) + if "accessToken" in data: + already = any(a.token_path == str(f) for a in state.accounts) + found.append({"path": str(f), "name": f.stem + " (旧目录)", "already": already}) + except: + pass + + if not found: + print("未找到 Token 文件") + print(f"Token 目录: {TOKEN_DIR}") + return + + print(f"找到 {len(found)} 个 Token:\n") + for i, t in enumerate(found): + status = "[已添加]" if t["already"] else "" + print(f" {i+1}. {t['name']} {status}") + + if args.auto: + # 自动添加所有未添加的 + added = 0 + for t in found: + if not t["already"]: + account = Account( + id=uuid.uuid4().hex[:8], + name=t["name"], + token_path=t["path"] + ) + state.accounts.append(account) + account.load_credentials() + added += 1 + state._save_accounts() + print(f"\n已添加 {added} 个账号") + else: + print("\n使用 --auto 自动添加所有未添加的账号") + + +def cmd_login_remote(args): + """生成远程登录链接""" + import uuid + import time + + session_id = uuid.uuid4().hex + host = args.host or "localhost:8080" + scheme = "https" if args.https else "http" + + print("远程登录链接") + print("-" * 40) + print(f"\n将以下链接发送到有浏览器的机器上完成登录:\n") + print(f" {scheme}://{host}/remote-login/{session_id}") + print(f"\n链接有效期 10 分钟") + print("\n登录完成后,在那台机器上导出账号,然后在这里导入:") + print(f" python -m kiro_proxy accounts import xxx.json") + + +def cmd_login_social(args): + """Social 登录 (Google/GitHub)""" + from .auth import start_social_auth + + provider = args.provider + print(f"启动 {provider.title()} 登录...") + + success, result = asyncio.run(start_social_auth(provider)) + if not success: + print(f"错误: {result.get('error', '未知错误')}") + return + + print(f"\n请在浏览器中打开以下链接完成授权:\n") + print(f" {result['login_url']}") + print(f"\n授权完成后,将浏览器地址栏中的完整 URL 粘贴到这里:") + callback_url = input().strip() + + if not callback_url: + print("已取消") + return + + try: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(callback_url) + params = parse_qs(parsed.query) + code = params.get("code", [None])[0] + oauth_state = params.get("state", [None])[0] + + if not code or not oauth_state: + print("错误: 无效的回调 URL") + return + + from .auth import exchange_social_auth_token + success, result = asyncio.run(exchange_social_auth_token(code, oauth_state)) + + if success and result.get("completed"): + import uuid + from .core import state, Account + from .auth import save_credentials_to_file + + credentials = result["credentials"] + file_path = asyncio.run(save_credentials_to_file( + credentials, f"cli-{provider}" + )) + + account = Account( + id=uuid.uuid4().hex[:8], + name=f"{provider.title()} 登录", + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + state._save_accounts() + + print(f"\n✅ 登录成功! 账号已添加: {account.name}") + else: + print(f"错误: {result.get('error', '登录失败')}") + except Exception as e: + print(f"错误: {e}") + + +def cmd_status(args): + """查看服务状态""" + from .core import state + stats = state.get_stats() + + print("Kiro Proxy 状态") + print("-" * 40) + print(f"运行时间: {stats['uptime_seconds']} 秒") + print(f"总请求数: {stats['total_requests']}") + print(f"错误数: {stats['total_errors']}") + print(f"错误率: {stats['error_rate']}") + print(f"账号总数: {stats['accounts_total']}") + print(f"可用账号: {stats['accounts_available']}") + print(f"冷却中: {stats['accounts_cooldown']}") + + +def main(): + parser = argparse.ArgumentParser( + prog="kiro-proxy", + description="Kiro API Proxy CLI" + ) + parser.add_argument("-v", "--version", action="version", version=__version__) + + subparsers = parser.add_subparsers(dest="command", help="命令") + + # serve + serve_parser = subparsers.add_parser("serve", help="启动代理服务") + serve_parser.add_argument("-p", "--port", type=int, default=8080, help="端口号") + serve_parser.set_defaults(func=cmd_serve) + + # status + status_parser = subparsers.add_parser("status", help="查看状态") + status_parser.set_defaults(func=cmd_status) + + # accounts + accounts_parser = subparsers.add_parser("accounts", help="账号管理") + accounts_sub = accounts_parser.add_subparsers(dest="accounts_cmd") + + # accounts list + list_parser = accounts_sub.add_parser("list", help="列出账号") + list_parser.set_defaults(func=cmd_accounts_list) + + # accounts export + export_parser = accounts_sub.add_parser("export", help="导出账号") + export_parser.add_argument("-o", "--output", help="输出文件") + export_parser.set_defaults(func=cmd_accounts_export) + + # accounts import + import_parser = accounts_sub.add_parser("import", help="导入账号") + import_parser.add_argument("file", help="JSON 文件路径") + import_parser.set_defaults(func=cmd_accounts_import) + + # accounts add + add_parser = accounts_sub.add_parser("add", help="手动添加 Token") + add_parser.set_defaults(func=cmd_accounts_add) + + # accounts scan + scan_parser = accounts_sub.add_parser("scan", help="扫描本地 Token") + scan_parser.add_argument("--auto", action="store_true", help="自动添加") + scan_parser.set_defaults(func=cmd_accounts_scan) + + # login + login_parser = subparsers.add_parser("login", help="登录") + login_sub = login_parser.add_subparsers(dest="login_cmd") + + # login remote + remote_parser = login_sub.add_parser("remote", help="生成远程登录链接") + remote_parser.add_argument("--host", help="服务器地址 (如 example.com:8080)") + remote_parser.add_argument("--https", action="store_true", help="使用 HTTPS") + remote_parser.set_defaults(func=cmd_login_remote) + + # login google + google_parser = login_sub.add_parser("google", help="Google 登录") + google_parser.set_defaults(func=cmd_login_social, provider="google") + + # login github + github_parser = login_sub.add_parser("github", help="GitHub 登录") + github_parser.set_defaults(func=cmd_login_social, provider="github") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + if args.command == "accounts" and not args.accounts_cmd: + accounts_parser.print_help() + return + + if args.command == "login" and not args.login_cmd: + login_parser.print_help() + return + + if hasattr(args, "func"): + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/KiroProxy/kiro_proxy/config.py b/KiroProxy/kiro_proxy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..63a6eff564e8357d8ee25476489ef0008cf0646c --- /dev/null +++ b/KiroProxy/kiro_proxy/config.py @@ -0,0 +1,133 @@ +"""配置模块""" +from pathlib import Path + +KIRO_API_URL = "https://q.us-east-1.amazonaws.com/generateAssistantResponse" +MODELS_URL = "https://q.us-east-1.amazonaws.com/ListAvailableModels" + +# 统一数据目录 (所有配置文件都在这里) +DATA_DIR = Path.home() / ".kiro-proxy" + +# Token 存储目录 +TOKEN_DIR = DATA_DIR / "tokens" + +# 默认 Token 路径 (兼容旧代码) +TOKEN_PATH = TOKEN_DIR / "kiro-auth-token.json" + +# 配额管理配置 +QUOTA_COOLDOWN_SECONDS = 300 # 配额超限冷却时间(秒) + +# 模型映射 +MODEL_MAPPING = { + # Claude 3.5 -> Kiro Claude 4 + "claude-3-5-sonnet-20241022": "claude-sonnet-4", + "claude-3-5-sonnet-latest": "claude-sonnet-4", + "claude-3-5-sonnet": "claude-sonnet-4", + "claude-3-5-haiku-20241022": "claude-haiku-4.5", + "claude-3-5-haiku-latest": "claude-haiku-4.5", + # Claude 3 + "claude-3-opus-20240229": "claude-sonnet-4.5", + "claude-3-opus-latest": "claude-sonnet-4.5", + "claude-3-sonnet-20240229": "claude-sonnet-4", + "claude-3-haiku-20240307": "claude-haiku-4.5", + # Claude 4 + "claude-4-sonnet": "claude-sonnet-4", + "claude-4-opus": "claude-sonnet-4.5", + # OpenAI GPT -> Claude + "gpt-4o": "claude-sonnet-4", + "gpt-4o-mini": "claude-haiku-4.5", + "gpt-4-turbo": "claude-sonnet-4", + "gpt-4": "claude-sonnet-4", + "gpt-3.5-turbo": "claude-haiku-4.5", + # OpenAI o1 -> Claude Opus + "o1": "claude-sonnet-4.5", + "o1-preview": "claude-sonnet-4.5", + "o1-mini": "claude-sonnet-4", + # Gemini -> Claude + "gemini-2.0-flash": "claude-sonnet-4", + "gemini-2.0-flash-thinking": "claude-sonnet-4.5", + "gemini-1.5-pro": "claude-sonnet-4.5", + "gemini-1.5-flash": "claude-sonnet-4", + # 别名 + "sonnet": "claude-sonnet-4", + "haiku": "claude-haiku-4.5", + "opus": "claude-sonnet-4.5", +} + +KIRO_MODELS = {"auto", "claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5"} + +def get_best_model_by_tier(tier: str, available_models: set = None) -> str: + """根据等级获取最佳可用模型(等级对等 + 智能降级)""" + if available_models is None: + available_models = KIRO_MODELS + + # 等级对等映射 + 降级路径 + TIER_PRIORITIES = { + # Opus: 最强 → 次强 → 快速 → 自动 + "opus": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"], + + # Sonnet: 高性能 → 最强 → 标准 → 快速 → 自动 + "sonnet": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"], + + # Haiku: 快速 → 标准 → 高性能 → 自动 + "haiku": ["claude-haiku-4.5", "claude-sonnet-4", "claude-sonnet-4.5", "auto"], + } + + priorities = TIER_PRIORITIES.get(tier, TIER_PRIORITIES["sonnet"]) + + # 选择第一个可用的模型 + for model in priorities: + if model in available_models: + return model + + return "auto" # 最终回退 + + +def detect_model_tier(model: str) -> str: + """智能检测模型等级""" + if not model: + return "sonnet" # 默认中等 + + model_lower = model.lower() + + # 特殊模型优先检测(避免被通用关键词误判) + if "gemini" in model_lower: + if any(keyword in model_lower for keyword in ["1.5-pro", "pro"]): + return "opus" + elif any(keyword in model_lower for keyword in ["2.0", "flash"]): + return "sonnet" # Gemini 2.0 和 flash 系列归为 sonnet + + # 等级关键词检测(优先级从高到低) + # Opus 等级 - 最强模型 + if any(keyword in model_lower for keyword in ["opus", "o1", "max", "ultra", "premium"]): + return "opus" + + # Haiku 等级 - 快速模型(需要排除 sonnet 中的 3.5) + if any(keyword in model_lower for keyword in ["haiku", "mini", "light", "fast", "turbo"]): + return "haiku" + # 特殊处理:gpt-3.5 系列属于 haiku + if "3.5" in model_lower and "sonnet" not in model_lower: + return "haiku" + + # Sonnet 等级 - 平衡模型 + if any(keyword in model_lower for keyword in ["sonnet", "4o", "4", "standard", "base"]): + return "sonnet" + + return "sonnet" # 默认中等 + + +def map_model_name(model: str, available_models: set = None) -> str: + """将外部模型名称映射到 Kiro 支持的名称(支持动态模型选择)""" + if not model: + return "auto" + + # 1. 精确匹配优先 + if model in MODEL_MAPPING: + return MODEL_MAPPING[model] + if model in KIRO_MODELS: + return model + + # 2. 智能等级检测 + 动态选择 + tier = detect_model_tier(model) + best_model = get_best_model_by_tier(tier, available_models) + + return best_model diff --git a/KiroProxy/kiro_proxy/converters/__init__.py b/KiroProxy/kiro_proxy/converters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d48cc5d46ad19e507b1485ac858619f4a663b96e --- /dev/null +++ b/KiroProxy/kiro_proxy/converters/__init__.py @@ -0,0 +1,1196 @@ +"""协议转换模块 - Anthropic/OpenAI/Gemini <-> Kiro + +增强版:参考 proxycast 实现 +- 工具数量限制(最多 50 个) +- 工具描述截断(最多 500 字符) +- 历史消息交替修复 +- OpenAI tool 角色消息处理 +- tool_choice: required 支持 +- web_search 特殊工具支持 +- tool_results 去重 +""" +import json +import hashlib +import re +from typing import List, Dict, Any, Tuple, Optional + +# 常量 +MAX_TOOLS = 50 +MAX_TOOL_DESCRIPTION_LENGTH = 500 + + +def generate_session_id(messages: list) -> str: + """基于消息内容生成会话ID""" + content = json.dumps(messages[:3], sort_keys=True) + return hashlib.sha256(content.encode()).hexdigest()[:16] + + +def extract_images_from_content(content) -> Tuple[str, List[dict]]: + """从消息内容中提取文本和图片 + + Returns: + (text_content, images_list) + """ + if isinstance(content, str): + return content, [] + + if not isinstance(content, list): + return str(content) if content else "", [] + + text_parts = [] + images = [] + + for block in content: + if isinstance(block, str): + text_parts.append(block) + elif isinstance(block, dict): + block_type = block.get("type", "") + + if block_type == "text": + text_parts.append(block.get("text", "")) + + elif block_type == "image": + # Anthropic 格式 + source = block.get("source", {}) + media_type = source.get("media_type", "image/jpeg") + data = source.get("data", "") + + fmt = "jpeg" + if "png" in media_type: + fmt = "png" + elif "gif" in media_type: + fmt = "gif" + elif "webp" in media_type: + fmt = "webp" + + if data: + images.append({ + "format": fmt, + "source": {"bytes": data} + }) + + elif block_type == "image_url": + # OpenAI 格式 + image_url = block.get("image_url", {}) + url = image_url.get("url", "") + + if url.startswith("data:"): + match = re.match(r'data:image/(\w+);base64,(.+)', url) + if match: + fmt = match.group(1) + data = match.group(2) + images.append({ + "format": fmt, + "source": {"bytes": data} + }) + + return "\n".join(text_parts), images + + +def truncate_description(desc: str, max_length: int = MAX_TOOL_DESCRIPTION_LENGTH) -> str: + """截断工具描述""" + if len(desc) <= max_length: + return desc + return desc[:max_length - 3] + "..." + + +# ==================== Anthropic 转换 ==================== + +def convert_anthropic_tools_to_kiro(tools: List[dict]) -> List[dict]: + """将 Anthropic 工具格式转换为 Kiro 格式 + + 增强: + - 限制最多 50 个工具 + - 截断过长的描述 + - 支持 web_search 特殊工具 + """ + kiro_tools = [] + function_count = 0 + + for tool in tools: + name = tool.get("name", "") + + # 特殊工具:web_search + if name in ("web_search", "web_search_20250305"): + kiro_tools.append({ + "webSearchTool": { + "type": "web_search" + } + }) + continue + + # 限制工具数量 + if function_count >= MAX_TOOLS: + continue + function_count += 1 + + description = tool.get("description", f"Tool: {name}") + description = truncate_description(description) + + input_schema = tool.get("input_schema", {"type": "object", "properties": {}}) + + kiro_tools.append({ + "toolSpecification": { + "name": name, + "description": description, + "inputSchema": { + "json": input_schema + } + } + }) + + return kiro_tools + + +def fix_history_alternation(history: List[dict], model_id: str = "claude-sonnet-4") -> List[dict]: + """修复历史记录,确保 user/assistant 严格交替,并验证 toolUses/toolResults 配对 + + Kiro API 规则: + 1. 消息必须严格交替:user -> assistant -> user -> assistant + 2. 当 assistant 有 toolUses 时,下一条 user 必须有对应的 toolResults + 3. 当 assistant 没有 toolUses 时,下一条 user 不能有 toolResults + """ + if not history: + return history + + # 深拷贝以避免修改原始数据 + import copy + history = copy.deepcopy(history) + + fixed = [] + + for i, item in enumerate(history): + is_user = "userInputMessage" in item + is_assistant = "assistantResponseMessage" in item + + if is_user: + # 检查上一条是否也是 user + if fixed and "userInputMessage" in fixed[-1]: + # 检查当前消息是否有 tool_results + user_msg = item["userInputMessage"] + ctx = user_msg.get("userInputMessageContext", {}) + has_tool_results = bool(ctx.get("toolResults")) + + if has_tool_results: + # 合并 tool_results 到上一条 user 消息 + new_results = ctx["toolResults"] + last_user = fixed[-1]["userInputMessage"] + + if "userInputMessageContext" not in last_user: + last_user["userInputMessageContext"] = {} + + last_ctx = last_user["userInputMessageContext"] + if "toolResults" in last_ctx and last_ctx["toolResults"]: + last_ctx["toolResults"].extend(new_results) + else: + last_ctx["toolResults"] = new_results + continue + else: + # 插入一个占位 assistant 消息(不带 toolUses) + fixed.append({ + "assistantResponseMessage": { + "content": "I understand." + } + }) + + # 验证 toolResults 与前一个 assistant 的 toolUses 配对 + if fixed and "assistantResponseMessage" in fixed[-1]: + last_assistant = fixed[-1]["assistantResponseMessage"] + has_tool_uses = bool(last_assistant.get("toolUses")) + + user_msg = item["userInputMessage"] + ctx = user_msg.get("userInputMessageContext", {}) + has_tool_results = bool(ctx.get("toolResults")) + + if has_tool_uses and not has_tool_results: + # assistant 有 toolUses 但 user 没有 toolResults + # 这是不允许的:不要删除 toolUses(否则会破坏后续上下文/导致 tool_use 轮次丢失) + # 改为在本条 user 前插入一个“工具结果占位” user 消息,与 toolUses 严格配对。 + placeholder_results = [] + for tu in (last_assistant.get("toolUses") or []): + tuid = "" + if isinstance(tu, dict): + tuid = tu.get("toolUseId") or "" + if tuid: + placeholder_results.append({ + "content": [{"text": ""}], + "status": "success", + "toolUseId": tuid, + }) + fixed.append({ + "userInputMessage": { + "content": "Tool results provided.", + "modelId": model_id, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": placeholder_results + } + } + }) + elif not has_tool_uses and has_tool_results: + # assistant 没有 toolUses 但 user 有 toolResults + # 这是不允许的,需要清除 user 的 toolResults + item["userInputMessage"].pop("userInputMessageContext", None) + + fixed.append(item) + + elif is_assistant: + # 检查上一条是否也是 assistant + if fixed and "assistantResponseMessage" in fixed[-1]: + # 插入一个占位 user 消息(不带 toolResults) + fixed.append({ + "userInputMessage": { + "content": "Continue", + "modelId": model_id, + "origin": "AI_EDITOR" + } + }) + + # 如果历史为空,先插入一个 user 消息 + if not fixed: + fixed.append({ + "userInputMessage": { + "content": "Continue", + "modelId": model_id, + "origin": "AI_EDITOR" + } + }) + + fixed.append(item) + + # 确保以 assistant 结尾(如果最后是 user,添加占位 assistant) + if fixed and "userInputMessage" in fixed[-1]: + # 不需要清除 toolResults,因为它是与前一个 assistant 的 toolUses 配对的 + # 占位 assistant 只是为了满足交替规则 + fixed.append({ + "assistantResponseMessage": { + "content": "I understand." + } + }) + + return fixed + + +def convert_anthropic_messages_to_kiro(messages: List[dict], system="") -> Tuple[str, List[dict], List[dict]]: + """将 Anthropic 消息格式转换为 Kiro 格式 + + Returns: + (user_content, history, tool_results) + """ + history = [] + user_content = "" + current_tool_results = [] + + def _strip_thinking(text: str) -> str: + if text is None: + return "" + if not isinstance(text, str): + text = str(text) + if not text: + return "" + cleaned = text + while True: + start = find_real_thinking_start_tag(cleaned) + if start == -1: + break + end = find_real_thinking_end_tag(cleaned, start + len("")) + if end == -1: + cleaned = cleaned[:start].rstrip() + break + before = cleaned[:start].rstrip() + after = cleaned[end + len(""):].lstrip() + if before and after: + cleaned = before + "\n" + after + else: + cleaned = before or after + return cleaned.strip() + + # 处理 system + system_text = "" + if isinstance(system, list): + for block in system: + if isinstance(block, dict) and block.get("type") == "text": + system_text += block.get("text", "") + "\n" + elif isinstance(block, str): + system_text += block + "\n" + system_text = system_text.strip() + elif isinstance(system, str): + system_text = system + + system_text = _strip_thinking(system_text) + + for i, msg in enumerate(messages): + role = msg.get("role", "") + content = msg.get("content", "") + is_last = (i == len(messages) - 1) + + # 处理 content 列表 + tool_results = [] + text_parts = [] + + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "tool_result": + tr_content = block.get("content", "") + if isinstance(tr_content, list): + tr_text_parts = [] + for tc in tr_content: + if isinstance(tc, dict) and tc.get("type") == "text": + tr_text_parts.append(tc.get("text", "")) + elif isinstance(tc, str): + tr_text_parts.append(tc) + tr_content = "\n".join(tr_text_parts) + + # 处理 is_error + status = "error" if block.get("is_error") else "success" + + tool_results.append({ + "content": [{"text": str(tr_content)}], + "status": status, + "toolUseId": block.get("tool_use_id", "") + }) + elif isinstance(block, str): + text_parts.append(block) + + content = "\n".join(text_parts) if text_parts else "" + + content = _strip_thinking(content) + + # 处理工具结果 + if tool_results: + # 去重 + seen_ids = set() + unique_results = [] + for tr in tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + tool_results = unique_results + + if is_last: + current_tool_results = tool_results + user_content = content if content else "Tool results provided." + else: + history.append({ + "userInputMessage": { + "content": content if content else "Tool results provided.", + "modelId": "claude-sonnet-4", + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": tool_results + } + } + }) + continue + + if role == "user": + if system_text and not history: + content = f"{system_text}\n\n{content}" if content else system_text + + content = _strip_thinking(content) + + if is_last: + user_content = content if content else "Continue" + else: + history.append({ + "userInputMessage": { + "content": content if content else "Continue", + "modelId": "claude-sonnet-4", + "origin": "AI_EDITOR" + } + }) + + elif role == "assistant": + tool_uses = [] + assistant_text = "" + + if isinstance(msg.get("content"), list): + text_parts = [] + for block in msg["content"]: + if isinstance(block, dict): + if block.get("type") == "tool_use": + tool_uses.append({ + "toolUseId": block.get("id", ""), + "name": block.get("name", ""), + "input": block.get("input", {}) + }) + elif block.get("type") == "text": + text_parts.append(block.get("text", "")) + assistant_text = "\n".join(text_parts) + else: + assistant_text = content if isinstance(content, str) else "" + + assistant_text = _strip_thinking(assistant_text) + + if not assistant_text and not tool_uses: + continue + + # 确保 assistant 消息有内容 + if not assistant_text: + assistant_text = "I understand." + + assistant_msg = { + "assistantResponseMessage": { + "content": assistant_text + } + } + # 只有在有 toolUses 时才添加这个字段 + if tool_uses: + assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses + + history.append(assistant_msg) + + # 修复历史交替 + history = fix_history_alternation(history) + + return user_content, history, current_tool_results + + +def convert_kiro_response_to_anthropic(result: dict, model: str, msg_id: str) -> dict: + """将 Kiro 响应转换为 Anthropic 格式""" + content = [] + text = "".join(result["content"]) + if text: + content.append({"type": "text", "text": text}) + + for tool_use in result["tool_uses"]: + content.append(tool_use) + + return { + "id": msg_id, + "type": "message", + "role": "assistant", + "content": content, + "model": model, + "stop_reason": result["stop_reason"], + "stop_sequence": None, + "usage": {"input_tokens": 100, "output_tokens": 100} + } + + +# ==================== OpenAI 转换 ==================== + +def is_tool_choice_required(tool_choice) -> bool: + """检查 tool_choice 是否为 required""" + if isinstance(tool_choice, dict): + t = tool_choice.get("type", "") + return t in ("any", "tool", "required") + elif isinstance(tool_choice, str): + return tool_choice in ("required", "any") + return False + + +def convert_openai_tools_to_kiro(tools: List[dict]) -> List[dict]: + """将 OpenAI 工具格式转换为 Kiro 格式""" + kiro_tools = [] + function_count = 0 + + for tool in tools: + tool_type = tool.get("type", "function") + + # 特殊工具 + if tool_type == "web_search": + kiro_tools.append({ + "webSearchTool": { + "type": "web_search" + } + }) + continue + + if tool_type != "function": + continue + + # 限制工具数量 + if function_count >= MAX_TOOLS: + continue + function_count += 1 + + func = tool.get("function", {}) + name = func.get("name", "") + description = func.get("description", f"Tool: {name}") + description = truncate_description(description) + parameters = func.get("parameters", {"type": "object", "properties": {}}) + + kiro_tools.append({ + "toolSpecification": { + "name": name, + "description": description, + "inputSchema": { + "json": parameters + } + } + }) + + return kiro_tools + + +def convert_openai_messages_to_kiro( + messages: List[dict], + model: str, + tools: List[dict] = None, + tool_choice = None +) -> Tuple[str, List[dict], List[dict], List[dict]]: + """将 OpenAI 消息格式转换为 Kiro 格式 + + 增强: + - 支持 tool 角色消息 + - 支持 assistant 的 tool_calls + - 支持 tool_choice: required + - 历史交替修复 + + Returns: + (user_content, history, tool_results, kiro_tools) + """ + system_content = "" + history = [] + user_content = "" + current_tool_results = [] + pending_tool_results = [] # 待处理的 tool 消息 + + # 处理 tool_choice: required + tool_instruction = "" + if is_tool_choice_required(tool_choice) and tools: + tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text. Call a tool function immediately." + + for i, msg in enumerate(messages): + role = msg.get("role", "") + content = msg.get("content", "") + is_last = (i == len(messages) - 1) + + # 提取文本内容 + if isinstance(content, list): + content = " ".join([c.get("text", "") for c in content if c.get("type") == "text"]) + if not content: + content = "" + + if role == "system": + system_content = content + tool_instruction + + elif role == "tool": + # OpenAI tool 角色消息 -> Kiro toolResults + tool_call_id = msg.get("tool_call_id", "") + pending_tool_results.append({ + "content": [{"text": str(content)}], + "status": "success", + "toolUseId": tool_call_id + }) + + elif role == "user": + # 如果有待处理的 tool results,先处理 + if pending_tool_results: + # 去重 + seen_ids = set() + unique_results = [] + for tr in pending_tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + + if is_last: + current_tool_results = unique_results + else: + history.append({ + "userInputMessage": { + "content": "Tool results provided.", + "modelId": model, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": unique_results + } + } + }) + pending_tool_results = [] + + # 合并 system prompt + if system_content and not history: + content = f"{system_content}\n\n{content}" + + if is_last: + user_content = content + else: + history.append({ + "userInputMessage": { + "content": content, + "modelId": model, + "origin": "AI_EDITOR" + } + }) + + elif role == "assistant": + # 如果有待处理的 tool results,先创建 user 消息 + if pending_tool_results: + seen_ids = set() + unique_results = [] + for tr in pending_tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + + history.append({ + "userInputMessage": { + "content": "Tool results provided.", + "modelId": model, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": unique_results + } + } + }) + pending_tool_results = [] + + # 处理 tool_calls + tool_uses = [] + tool_calls = msg.get("tool_calls", []) + for tc in tool_calls: + func = tc.get("function", {}) + args_str = func.get("arguments", "{}") + try: + args = json.loads(args_str) + except: + args = {} + + tool_uses.append({ + "toolUseId": tc.get("id", ""), + "name": func.get("name", ""), + "input": args + }) + + assistant_text = content if content else "I understand." + + assistant_msg = { + "assistantResponseMessage": { + "content": assistant_text + } + } + # 只有在有 toolUses 时才添加这个字段 + if tool_uses: + assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses + + history.append(assistant_msg) + + # 处理末尾的 tool results + if pending_tool_results: + seen_ids = set() + unique_results = [] + for tr in pending_tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + current_tool_results = unique_results + if not user_content: + user_content = "Tool results provided." + + # 如果没有用户消息 + if not user_content: + user_content = messages[-1].get("content", "") if messages else "Continue" + if isinstance(user_content, list): + user_content = " ".join([c.get("text", "") for c in user_content if c.get("type") == "text"]) + if not user_content: + user_content = "Continue" + + # 历史不包含最后一条用户消息 + if history and "userInputMessage" in history[-1]: + history = history[:-1] + + # 修复历史交替 + history = fix_history_alternation(history, model) + + # 转换工具 + kiro_tools = convert_openai_tools_to_kiro(tools) if tools else [] + + return user_content, history, current_tool_results, kiro_tools + + +def convert_kiro_response_to_openai(result: dict, model: str, msg_id: str) -> dict: + """将 Kiro 响应转换为 OpenAI 格式""" + text = "".join(result["content"]) + tool_calls = [] + + for tool_use in result.get("tool_uses", []): + if tool_use.get("type") == "tool_use": + tool_calls.append({ + "id": tool_use.get("id", ""), + "type": "function", + "function": { + "name": tool_use.get("name", ""), + "arguments": json.dumps(tool_use.get("input", {})) + } + }) + + # 映射 stop_reason + stop_reason = result.get("stop_reason", "stop") + finish_reason = "tool_calls" if tool_calls else "stop" + if stop_reason == "max_tokens": + finish_reason = "length" + + message = { + "role": "assistant", + "content": text if text else None + } + if tool_calls: + message["tool_calls"] = tool_calls + + return { + "id": msg_id, + "object": "chat.completion", + "model": model, + "choices": [{ + "index": 0, + "message": message, + "finish_reason": finish_reason + }], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200 + } + } + + +# ==================== Gemini 转换 ==================== + +def convert_gemini_tools_to_kiro(tools: List[dict]) -> List[dict]: + """将 Gemini 工具格式转换为 Kiro 格式 + + Gemini 工具格式: + { + "functionDeclarations": [ + { + "name": "get_weather", + "description": "Get weather info", + "parameters": {...} + } + ] + } + """ + kiro_tools = [] + function_count = 0 + + for tool in tools: + # Gemini 的工具定义在 functionDeclarations 中 + declarations = tool.get("functionDeclarations", []) + + for func in declarations: + # 限制工具数量 + if function_count >= MAX_TOOLS: + break + function_count += 1 + + name = func.get("name", "") + description = func.get("description", f"Tool: {name}") + description = truncate_description(description) + parameters = func.get("parameters", {"type": "object", "properties": {}}) + + kiro_tools.append({ + "toolSpecification": { + "name": name, + "description": description, + "inputSchema": { + "json": parameters + } + } + }) + + return kiro_tools + + +def convert_gemini_contents_to_kiro( + contents: List[dict], + system_instruction: dict, + model: str, + tools: List[dict] = None, + tool_config: dict = None +) -> Tuple[str, List[dict], List[dict], List[dict]]: + """将 Gemini 消息格式转换为 Kiro 格式 + + 增强: + - 支持 functionCall 和 functionResponse + - 支持 tool_config + + Returns: + (user_content, history, tool_results, kiro_tools) + """ + history = [] + user_content = "" + current_tool_results = [] + pending_tool_results = [] + + # 处理 system instruction + system_text = "" + if system_instruction: + parts = system_instruction.get("parts", []) + system_text = " ".join(p.get("text", "") for p in parts if "text" in p) + + # 处理 tool_config(类似 tool_choice) + tool_instruction = "" + if tool_config: + mode = tool_config.get("functionCallingConfig", {}).get("mode", "") + if mode in ("ANY", "REQUIRED"): + tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text." + + for i, content in enumerate(contents): + role = content.get("role", "user") + parts = content.get("parts", []) + is_last = (i == len(contents) - 1) + + # 提取文本和工具调用 + text_parts = [] + tool_calls = [] + tool_responses = [] + + for part in parts: + if "text" in part: + text_parts.append(part["text"]) + elif "functionCall" in part: + # Gemini 的工具调用 + fc = part["functionCall"] + tool_calls.append({ + "toolUseId": fc.get("name", "") + "_" + str(i), # Gemini 没有 ID,生成一个 + "name": fc.get("name", ""), + "input": fc.get("args", {}) + }) + elif "functionResponse" in part: + # Gemini 的工具响应 + fr = part["functionResponse"] + response_content = fr.get("response", {}) + if isinstance(response_content, dict): + response_text = json.dumps(response_content) + else: + response_text = str(response_content) + + tool_responses.append({ + "content": [{"text": response_text}], + "status": "success", + "toolUseId": fr.get("name", "") + "_" + str(i - 1) # 匹配上一个调用 + }) + + text = " ".join(text_parts) + + if role == "user": + # 处理待处理的 tool responses + if pending_tool_results: + seen_ids = set() + unique_results = [] + for tr in pending_tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + + history.append({ + "userInputMessage": { + "content": "Tool results provided.", + "modelId": model, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": unique_results + } + } + }) + pending_tool_results = [] + + # 处理 functionResponse(用户消息中的工具响应) + if tool_responses: + pending_tool_results.extend(tool_responses) + + # 合并 system prompt + if system_text and not history: + text = f"{system_text}{tool_instruction}\n\n{text}" + + if is_last: + user_content = text + if pending_tool_results: + current_tool_results = pending_tool_results + pending_tool_results = [] + else: + if text: + history.append({ + "userInputMessage": { + "content": text, + "modelId": model, + "origin": "AI_EDITOR" + } + }) + + elif role == "model": + # 处理待处理的 tool responses + if pending_tool_results: + seen_ids = set() + unique_results = [] + for tr in pending_tool_results: + if tr["toolUseId"] not in seen_ids: + seen_ids.add(tr["toolUseId"]) + unique_results.append(tr) + + history.append({ + "userInputMessage": { + "content": "Tool results provided.", + "modelId": model, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": unique_results + } + } + }) + pending_tool_results = [] + + assistant_text = text if text else "I understand." + + assistant_msg = { + "assistantResponseMessage": { + "content": assistant_text + } + } + # 只有在有 toolUses 时才添加这个字段 + if tool_calls: + assistant_msg["assistantResponseMessage"]["toolUses"] = tool_calls + + history.append(assistant_msg) + + # 处理末尾的 tool results + if pending_tool_results: + current_tool_results = pending_tool_results + if not user_content: + user_content = "Tool results provided." + + # 如果没有用户消息 + if not user_content: + if contents: + last_parts = contents[-1].get("parts", []) + user_content = " ".join(p.get("text", "") for p in last_parts if "text" in p) + if not user_content: + user_content = "Continue" + + # 修复历史交替 + history = fix_history_alternation(history, model) + + # 移除最后一条(当前用户消息) + if history and "userInputMessage" in history[-1]: + history = history[:-1] + + # 转换工具 + kiro_tools = convert_gemini_tools_to_kiro(tools) if tools else [] + + return user_content, history, current_tool_results, kiro_tools + + +def convert_kiro_response_to_gemini(result: dict, model: str) -> dict: + """将 Kiro 响应转换为 Gemini 格式""" + text = "".join(result.get("content", [])) + tool_uses = result.get("tool_uses", []) + + parts = [] + + # 添加文本部分 + if text: + parts.append({"text": text}) + + # 添加工具调用 + for tool_use in tool_uses: + if tool_use.get("type") == "tool_use": + parts.append({ + "functionCall": { + "name": tool_use.get("name", ""), + "args": tool_use.get("input", {}) + } + }) + + # 映射 stop_reason + stop_reason = result.get("stop_reason", "STOP") + finish_reason = "STOP" + if tool_uses: + finish_reason = "TOOL_CALLS" + elif stop_reason == "max_tokens": + finish_reason = "MAX_TOKENS" + + return { + "candidates": [{ + "content": { + "parts": parts, + "role": "model" + }, + "finishReason": finish_reason, + "index": 0 + }], + "usageMetadata": { + "promptTokenCount": 100, + "candidatesTokenCount": 100, + "totalTokenCount": 200 + } + } + + +# ==================== 思考功能支持 ==================== + +def generate_thinking_prefix(thinking_type: str = "enabled", budget_tokens: int = 20000) -> str: + """生成思考模式的前缀 XML 标签 + + Args: + thinking_type: 思考类型,通常为 "enabled" + budget_tokens: 思考的 token 预算 + + Returns: + XML 格式的思考标签字符串 + """ + if thinking_type != "enabled": + return "" + + return f"enabled\n{budget_tokens}" + + +def has_thinking_tags(text: str) -> bool: + """检查文本是否已包含思考标签 + + Args: + text: 要检查的文本 + + Returns: + 如果包含思考标签返回 True + """ + return "" in text and "" in text + + +def inject_thinking_tags_to_system(system, thinking_type: str = "enabled", budget_tokens: int = 20000): + """将思考标签注入到系统消息中 + + Args: + system: 原始系统消息 (可以是字符串或列表) + thinking_type: 思考类型 + budget_tokens: 思考的 token 预算 + + Returns: + 注入思考标签后的系统消息 (保持原始类型) + """ + # 生成思考前缀 + thinking_prefix = generate_thinking_prefix(thinking_type, budget_tokens) + + if not thinking_prefix: + return system + + # 处理 system 为列表的情况 (Anthropic API 支持 system 为 content blocks 列表) + if isinstance(system, list): + # 将列表转换为字符串 + system_text = "" + for block in system: + if isinstance(block, dict) and block.get("type") == "text": + system_text += block.get("text", "") + "\n" + elif isinstance(block, str): + system_text += block + "\n" + system_text = system_text.strip() + + if not system_text: + return thinking_prefix + + if has_thinking_tags(system_text): + return system + + # 返回字符串形式 + return f"{thinking_prefix}\n\n{system_text}" + + # 处理 system 为字符串的情况 + if not system or not str(system).strip(): + return thinking_prefix + + # 如果已经包含思考标签,不再重复注入 + if has_thinking_tags(str(system)): + return system + + # 将思考标签插入到系统消息开头 + return f"{thinking_prefix}\n\n{system}" + + +def find_real_thinking_start_tag(text: str, pos: int = 0) -> int: + """查找真正的 标签位置,忽略被引号包围的情况 + + Args: + text: 要搜索的文本 + pos: 开始搜索的位置 + + Returns: + 找到的标签位置,如果没找到返回 -1 + """ + while True: + idx = text.find("", pos) + if idx == -1: + return -1 + + # 检查是否被引号包围 + # 向前查找最近的引号 + prev_quote = max( + text.rfind("`", 0, idx), + text.rfind("'", 0, idx), + text.rfind('"', 0, idx) + ) + + # 如果有引号且引号后没有换行,说明是被包围的 + if prev_quote != -1: + # 检查引号到标签之间是否有换行 + between = text[prev_quote + 1:idx] + if "\n" not in between: + pos = idx + len("") + continue + + return idx + + +def find_real_thinking_end_tag(text: str, pos: int = 0) -> int: + """查找真正的 标签位置,忽略被引号包围的情况 + + Args: + text: 要搜索的文本 + pos: 开始搜索的位置 + + Returns: + 找到的标签位置,如果没找到返回 -1 + """ + while True: + idx = text.find("", pos) + if idx == -1: + return -1 + + # 检查是否被引号包围 + # 向前查找最近的引号 + prev_quote = max( + text.rfind("`", 0, idx), + text.rfind("'", 0, idx), + text.rfind('"', 0, idx) + ) + + # 如果有引号且引号后没有换行,说明是被包围的 + if prev_quote != -1: + # 检查引号到标签之间是否有换行 + between = text[prev_quote + 1:idx] + if "\n" not in between: + pos = idx + len("") + continue + + return idx + + +def extract_thinking_from_content(content: str) -> Tuple[str, str]: + """从内容中提取思考部分和正文部分 + + Args: + content: 原始内容 + + Returns: + (thinking_content, text_content) + """ + thinking_start = find_real_thinking_start_tag(content) + thinking_end = find_real_thinking_end_tag(content) + + if thinking_start == -1 or thinking_end == -1: + return "", content + + # 提取思考内容(去掉标签) + thinking_content = content[thinking_start + len(""):thinking_end].strip() + + # 提取正文内容(去掉思考部分) + text_content = content[:thinking_start].strip() + after_thinking = content[thinking_end + len(""):].strip() + if after_thinking: + text_content += "\n" + after_thinking + + return thinking_content, text_content diff --git a/KiroProxy/kiro_proxy/core/__init__.py b/KiroProxy/kiro_proxy/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d719f74938964f104a951600189c92f4a8e6d6bf --- /dev/null +++ b/KiroProxy/kiro_proxy/core/__init__.py @@ -0,0 +1,55 @@ +"""核心模块""" +from .state import state, ProxyState, RequestLog +from .account import Account +from .persistence import load_config, save_config, CONFIG_FILE +from .retry import RetryableRequest, is_retryable_error, RETRYABLE_STATUS_CODES +from .scheduler import scheduler +from .stats import stats_manager +from .browser import detect_browsers, open_url, get_browsers_info +from .flow_monitor import flow_monitor, FlowMonitor, LLMFlow, FlowState, TokenUsage +from .usage import get_usage_limits, get_account_usage, UsageInfo +from .history_manager import ( + HistoryManager, HistoryConfig, TruncateStrategy, + get_history_config, set_history_config, update_history_config, + is_content_length_error +) +from .error_handler import ( + ErrorType, KiroError, classify_error, is_account_suspended, + get_anthropic_error_response, format_error_log +) +from .rate_limiter import RateLimiter, RateLimitConfig, rate_limiter, get_rate_limiter + +# 新增模块 +from .quota_cache import QuotaCache, CachedQuota, get_quota_cache +from .account_selector import AccountSelector, SelectionStrategy, get_account_selector +from .quota_scheduler import QuotaScheduler, get_quota_scheduler +from .refresh_manager import ( + RefreshManager, RefreshProgress, RefreshConfig, + get_refresh_manager, reset_refresh_manager +) +from .kiro_api import kiro_api_request, get_user_info, get_user_email + +__all__ = [ + "state", "ProxyState", "RequestLog", "Account", + "load_config", "save_config", "CONFIG_FILE", + "RetryableRequest", "is_retryable_error", "RETRYABLE_STATUS_CODES", + "scheduler", "stats_manager", + "detect_browsers", "open_url", "get_browsers_info", + "flow_monitor", "FlowMonitor", "LLMFlow", "FlowState", "TokenUsage", + "get_usage_limits", "get_account_usage", "UsageInfo", + "HistoryManager", "HistoryConfig", "TruncateStrategy", + "get_history_config", "set_history_config", "update_history_config", + "is_content_length_error", + "ErrorType", "KiroError", "classify_error", "is_account_suspended", + "get_anthropic_error_response", "format_error_log", + "RateLimiter", "RateLimitConfig", "rate_limiter", "get_rate_limiter", + # 新增导出 + "QuotaCache", "CachedQuota", "get_quota_cache", + "AccountSelector", "SelectionStrategy", "get_account_selector", + "QuotaScheduler", "get_quota_scheduler", + # RefreshManager 导出 + "RefreshManager", "RefreshProgress", "RefreshConfig", + "get_refresh_manager", "reset_refresh_manager", + # Kiro API 导出 + "kiro_api_request", "get_user_info", "get_user_email", +] diff --git a/KiroProxy/kiro_proxy/core/account.py b/KiroProxy/kiro_proxy/core/account.py new file mode 100644 index 0000000000000000000000000000000000000000..574e3398037bf756011748012927c7c944f8d5cd --- /dev/null +++ b/KiroProxy/kiro_proxy/core/account.py @@ -0,0 +1,287 @@ +"""账号管理""" +import json +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +from ..credential import ( + KiroCredentials, TokenRefresher, CredentialStatus, + generate_machine_id, quota_manager +) + + +@dataclass +class Account: + """账号信息""" + id: str + name: str + token_path: str + enabled: bool = True + # 是否因额度耗尽被自动禁用(用于区分手动禁用,避免被自动启用) + auto_disabled: bool = False + request_count: int = 0 + error_count: int = 0 + last_used: Optional[float] = None + status: CredentialStatus = CredentialStatus.ACTIVE + + _credentials: Optional[KiroCredentials] = field(default=None, repr=False) + _machine_id: Optional[str] = field(default=None, repr=False) + + def is_available(self) -> bool: + """检查账号是否可用""" + if not self.enabled: + return False + if self.status in (CredentialStatus.DISABLED, CredentialStatus.UNHEALTHY, CredentialStatus.SUSPENDED): + return False + if not quota_manager.is_available(self.id): + return False + + # 检查额度是否耗尽 + from .quota_cache import get_quota_cache + quota_cache = get_quota_cache() + quota = quota_cache.get(self.id) + if quota and quota.is_exhausted: + return False + + return True + + def is_active(self) -> bool: + """检查账号是否活跃(最近60秒内使用过)""" + from .quota_scheduler import get_quota_scheduler + scheduler = get_quota_scheduler() + return scheduler.is_active(self.id) + + def get_priority_order(self) -> Optional[int]: + """获取优先级顺序(从1开始),非优先账号返回 None""" + from .account_selector import get_account_selector + selector = get_account_selector() + return selector.get_priority_order(self.id) + + def is_priority(self) -> bool: + """检查是否为优先账号""" + return self.get_priority_order() is not None + + def load_credentials(self) -> Optional[KiroCredentials]: + """加载凭证信息""" + try: + self._credentials = KiroCredentials.from_file(self.token_path) + + if self._credentials.client_id_hash and not self._credentials.client_id: + self._merge_client_credentials() + + return self._credentials + except Exception as e: + print(f"[Account] 加载凭证失败 {self.id}: {e}") + return None + + def _merge_client_credentials(self): + """合并 clientIdHash 对应的凭证文件""" + if not self._credentials or not self._credentials.client_id_hash: + return + + cache_dir = Path(self.token_path).parent + hash_file = cache_dir / f"{self._credentials.client_id_hash}.json" + + if hash_file.exists(): + try: + with open(hash_file) as f: + data = json.load(f) + if not self._credentials.client_id: + self._credentials.client_id = data.get("clientId") + if not self._credentials.client_secret: + self._credentials.client_secret = data.get("clientSecret") + except Exception: + pass + + def get_credentials(self) -> Optional[KiroCredentials]: + """获取凭证(带缓存)""" + if self._credentials is None: + self.load_credentials() + return self._credentials + + def get_token(self) -> str: + """获取 access_token""" + creds = self.get_credentials() + if creds and creds.access_token: + return creds.access_token + + try: + with open(self.token_path) as f: + return json.load(f).get("accessToken", "") + except Exception: + return "" + + def get_machine_id(self) -> str: + """获取基于此账号的 Machine ID""" + if self._machine_id: + return self._machine_id + + creds = self.get_credentials() + if creds: + self._machine_id = generate_machine_id(creds.profile_arn, creds.client_id) + else: + self._machine_id = generate_machine_id() + + return self._machine_id + + def is_token_expired(self) -> bool: + """检查 token 是否过期""" + creds = self.get_credentials() + return creds.is_expired() if creds else True + + def is_token_expiring_soon(self, minutes: int = 10) -> bool: + """检查 token 是否即将过期""" + creds = self.get_credentials() + return creds.is_expiring_soon(minutes) if creds else False + + async def refresh_token(self) -> tuple: + """刷新 token""" + creds = self.get_credentials() + if not creds: + return False, "无法加载凭证" + + refresher = TokenRefresher(creds) + success, result = await refresher.refresh() + + if success: + creds.save_to_file(self.token_path) + self._credentials = creds + self.status = CredentialStatus.ACTIVE + return True, "Token 刷新成功" + else: + self.status = CredentialStatus.UNHEALTHY + return False, result + + def mark_quota_exceeded(self, reason: str = "Rate limited"): + """标记配额超限(进入冷却并避免被继续选中) + + 429 错误自动冷却 5 分钟,无需手动配置 + """ + quota_manager.mark_exceeded(self.id, reason) + self.status = CredentialStatus.COOLDOWN + self.error_count += 1 + + def get_status_info(self) -> dict: + """获取状态信息""" + cooldown_remaining = quota_manager.get_cooldown_remaining(self.id) + creds = self.get_credentials() + + # 获取额度信息 + from .quota_cache import get_quota_cache + quota_cache = get_quota_cache() + quota = quota_cache.get(self.id) + + quota_info = None + if quota: + # 计算相对时间 + updated_ago = "" + if quota.updated_at > 0: + seconds_ago = time.time() - quota.updated_at + if seconds_ago < 60: + updated_ago = f"{int(seconds_ago)}秒前" + elif seconds_ago < 3600: + updated_ago = f"{int(seconds_ago / 60)}分钟前" + else: + updated_ago = f"{int(seconds_ago / 3600)}小时前" + + # 格式化重置时间 + reset_date_text = None + if quota.next_reset_date: + try: + # 处理时间戳格式 + if isinstance(quota.next_reset_date, (int, float)): + from datetime import datetime + reset_dt = datetime.fromtimestamp(quota.next_reset_date) + reset_date_text = reset_dt.strftime('%Y-%m-%d') + else: + # 处理 ISO 格式 + from datetime import datetime + reset_dt = datetime.fromisoformat(quota.next_reset_date.replace('Z', '+00:00')) + reset_date_text = reset_dt.strftime('%Y-%m-%d') + except: + reset_date_text = str(quota.next_reset_date) + + # 格式化免费试用过期时间 + trial_expiry_text = None + if quota.free_trial_expiry: + try: + # 处理时间戳格式 + if isinstance(quota.free_trial_expiry, (int, float)): + from datetime import datetime + expiry_dt = datetime.fromtimestamp(quota.free_trial_expiry) + trial_expiry_text = expiry_dt.strftime('%Y-%m-%d') + else: + # 处理 ISO 格式 + from datetime import datetime + expiry_dt = datetime.fromisoformat(quota.free_trial_expiry.replace('Z', '+00:00')) + trial_expiry_text = expiry_dt.strftime('%Y-%m-%d') + except: + trial_expiry_text = str(quota.free_trial_expiry) + + # 计算生效奖励数 + active_bonuses = len([e for e in (quota.bonus_expiries or []) if e]) + + quota_info = { + "balance": quota.balance, + "usage_limit": quota.usage_limit, + "current_usage": quota.current_usage, + "usage_percent": quota.usage_percent, + "is_low_balance": quota.is_low_balance, + "is_exhausted": quota.is_exhausted, # 额度是否耗尽 + "is_suspended": getattr(quota, 'is_suspended', False), # 账号是否被封禁 + "balance_status": quota.balance_status, # 额度状态: normal, low, exhausted + "subscription_title": quota.subscription_title, + "free_trial_limit": quota.free_trial_limit, + "free_trial_usage": quota.free_trial_usage, + "bonus_limit": quota.bonus_limit, + "bonus_usage": quota.bonus_usage, + "updated_at": updated_ago, + "updated_timestamp": quota.updated_at, + "error": quota.error, + # 新增重置时间字段 + "next_reset_date": quota.next_reset_date, + "reset_date_text": reset_date_text, # 格式化后的重置日期 + "free_trial_expiry": quota.free_trial_expiry, + "trial_expiry_text": trial_expiry_text, # 格式化后的试用过期日期 + "bonus_expiries": quota.bonus_expiries or [], + "active_bonuses": active_bonuses, # 生效奖励数量 + } + + # 计算最后使用时间 + last_used_ago = None + if self.last_used: + seconds_ago = time.time() - self.last_used + if seconds_ago < 60: + last_used_ago = f"{int(seconds_ago)}秒前" + elif seconds_ago < 3600: + last_used_ago = f"{int(seconds_ago / 60)}分钟前" + else: + last_used_ago = f"{int(seconds_ago / 3600)}小时前" + + return { + "id": self.id, + "name": self.name, + "enabled": self.enabled, + "status": self.status.value, + "available": self.is_available(), + "request_count": self.request_count, + "error_count": self.error_count, + "error_rate": f"{(self.error_count / max(1, self.request_count) * 100):.1f}%", + "cooldown_remaining": cooldown_remaining, + "token_expired": self.is_token_expired() if creds else None, + "token_expiring_soon": self.is_token_expiring_soon() if creds else None, + "token_expires_at": creds.expires_at if creds else None, # Token 过期时间戳 + "auth_method": creds.auth_method if creds else None, + "has_refresh_token": bool(creds and creds.refresh_token), + "idc_config_complete": bool(creds and creds.client_id and creds.client_secret) if creds and creds.auth_method == "idc" else None, + # 新增字段 + "quota": quota_info, + "is_priority": self.is_priority(), + "priority_order": self.get_priority_order(), + "is_active": self.is_active(), + "last_used": self.last_used, + "last_used_ago": last_used_ago, + # Provider 字段 (Google/Github) + "provider": creds.provider if creds else None, + } diff --git a/KiroProxy/kiro_proxy/core/account_selector.py b/KiroProxy/kiro_proxy/core/account_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..f17e107d3f83bd9b08f965e16db7a792c555ce7d --- /dev/null +++ b/KiroProxy/kiro_proxy/core/account_selector.py @@ -0,0 +1,390 @@ +"""账号选择器模块 + +实现基于剩余额度的智能账号选择策略,支持优先账号配置。 +""" +import json +import random +import time +from enum import Enum +from pathlib import Path +from typing import Optional, List, Set, TYPE_CHECKING +from threading import Lock + +if TYPE_CHECKING: + from .account import Account + from .quota_cache import QuotaCache + + +class SelectionStrategy(Enum): + """选择策略""" + LOWEST_BALANCE = "lowest_balance" # 剩余额度最少优先 + ROUND_ROBIN = "round_robin" # 轮询 + LEAST_REQUESTS = "least_requests" # 请求最少优先 + RANDOM = "random" # 随机选择(分散压力) + + +class AccountSelector: + """账号选择器 + + 根据配置的策略选择最合适的账号,支持优先账号配置。 + """ + + def __init__(self, quota_cache: 'QuotaCache', priority_file: Optional[str] = None): + """ + 初始化账号选择器 + + Args: + quota_cache: 额度缓存实例 + priority_file: 优先账号配置文件路径 + """ + self.quota_cache = quota_cache + self._priority_accounts: List[str] = [] + # 默认使用随机策略,避免单账号 RPM 过高导致封禁风险 + self._strategy = SelectionStrategy.RANDOM + self._lock = Lock() + self._round_robin_index = 0 + self._last_random_account_id: Optional[str] = None + + # 设置优先账号配置文件路径 + if priority_file: + self._priority_file = Path(priority_file) + else: + from ..config import DATA_DIR + self._priority_file = DATA_DIR / "priority.json" + + # 加载优先账号配置 + self._load_priority_config() + + @property + def strategy(self) -> SelectionStrategy: + """获取当前选择策略""" + return self._strategy + + @strategy.setter + def strategy(self, value: SelectionStrategy): + """设置选择策略""" + self._strategy = value + self._save_priority_config() + + def select(self, + available_accounts: List['Account'], + session_id: Optional[str] = None) -> Optional['Account']: + """选择最合适的账号 + + Args: + available_accounts: 可用账号列表 + session_id: 会话ID(用于会话粘性,暂未实现) + + Returns: + 选中的账号,如果没有可用账号则返回 None + """ + if not available_accounts: + return None + + with self._lock: + # 1. 首先检查优先账号 + if self._priority_accounts: + for priority_id in self._priority_accounts: + for account in available_accounts: + if account.id == priority_id and account.is_available(): + return account + + # 2. 根据策略选择 + if self._strategy == SelectionStrategy.LOWEST_BALANCE: + return self._select_lowest_balance(available_accounts) + elif self._strategy == SelectionStrategy.ROUND_ROBIN: + return self._select_round_robin(available_accounts) + elif self._strategy == SelectionStrategy.LEAST_REQUESTS: + return self._select_least_requests(available_accounts) + elif self._strategy == SelectionStrategy.RANDOM: + return self._select_random(available_accounts) + + # 默认返回第一个可用账号 + return available_accounts[0] if available_accounts else None + + def _select_lowest_balance(self, accounts: List['Account']) -> Optional['Account']: + """选择剩余额度最少的账号""" + available = [a for a in accounts if a.is_available()] + if not available: + return None + + def get_balance_and_requests(account: 'Account') -> tuple: + """获取账号的余额和请求数,用于排序""" + quota = self.quota_cache.get(account.id) + balance = quota.balance if quota and not quota.has_error() else float('inf') + return (balance, account.request_count) + + # 按余额升序,余额相同时按请求数升序 + return min(available, key=get_balance_and_requests) + + def _select_round_robin(self, accounts: List['Account']) -> Optional['Account']: + """轮询选择账号""" + available = [a for a in accounts if a.is_available()] + if not available: + return None + + self._round_robin_index = self._round_robin_index % len(available) + account = available[self._round_robin_index] + self._round_robin_index += 1 + return account + + def _select_least_requests(self, accounts: List['Account']) -> Optional['Account']: + """选择请求数最少的账号""" + available = [a for a in accounts if a.is_available()] + if not available: + return None + return min(available, key=lambda a: a.request_count) + + def _select_random(self, accounts: List['Account']) -> Optional['Account']: + """随机选择账号(分散请求压力)""" + available = [a for a in accounts if a.is_available()] + if not available: + return None + + # 尽量避免连续两次命中同一账号(在有多个可用账号时) + if self._last_random_account_id and len(available) > 1: + candidates = [a for a in available if a.id != self._last_random_account_id] + if candidates: + selected = random.choice(candidates) + else: + selected = random.choice(available) + else: + selected = random.choice(available) + + self._last_random_account_id = selected.id + return selected + + def set_priority_accounts(self, account_ids: List[str], + valid_account_ids: Optional[Set[str]] = None) -> tuple: + """设置优先账号列表(按顺序) + + Args: + account_ids: 优先账号ID列表(按顺序) + valid_account_ids: 有效账号ID集合(用于验证) + + Returns: + (success, message) + """ + with self._lock: + if not account_ids: + self._priority_accounts = [] + self._strategy = SelectionStrategy.RANDOM + self._save_priority_config() + return True, "已清除优先账号" + + # 去重(保持顺序) + unique_ids: List[str] = [] + seen: Set[str] = set() + for aid in account_ids: + if aid in seen: + continue + seen.add(aid) + unique_ids.append(aid) + + # 验证账号是否存在 + if valid_account_ids: + for aid in unique_ids: + if aid not in valid_account_ids: + return False, f"账号不存在: {aid}" + + self._priority_accounts = unique_ids + self._save_priority_config() + if len(unique_ids) == 1: + return True, f"已设置优先账号: {unique_ids[0]}" + return True, f"已设置优先账号: {', '.join(unique_ids)}" + + def set_priority_account(self, account_id: Optional[str], + valid_account_ids: Optional[Set[str]] = None) -> tuple: + """设置优先账号(单个) + + Args: + account_id: 账号ID,None 表示清除 + valid_account_ids: 有效账号ID集合(用于验证) + + Returns: + (success, message) + """ + if account_id is None: + return self.set_priority_accounts([], valid_account_ids) + return self.set_priority_accounts([account_id], valid_account_ids) + + def add_priority_account(self, account_id: str, + position: int = -1, + valid_account_ids: Optional[Set[str]] = None) -> tuple: + """添加优先账号(可指定插入位置) + + Args: + account_id: 账号ID + position: 插入位置(0-based),-1 表示追加到末尾 + valid_account_ids: 有效账号ID集合(用于验证) + + Returns: + (success, message) + """ + with self._lock: + if valid_account_ids and account_id not in valid_account_ids: + return False, f"账号不存在: {account_id}" + + if account_id in self._priority_accounts: + self._priority_accounts.remove(account_id) + + if position is None or position < 0 or position >= len(self._priority_accounts): + self._priority_accounts.append(account_id) + else: + self._priority_accounts.insert(position, account_id) + + self._save_priority_config() + return True, f"已添加优先账号: {account_id}" + + def remove_priority_account(self, account_id: str = None) -> tuple: + """移除优先账号 + + Args: + account_id: 账号ID(可选,不传则清除所有) + + Returns: + (success, message) + """ + with self._lock: + if not self._priority_accounts: + return False, "没有设置优先账号" + + if account_id: + if account_id not in self._priority_accounts: + return False, f"账号 {account_id} 不是优先账号" + + self._priority_accounts.remove(account_id) + if not self._priority_accounts: + self._strategy = SelectionStrategy.RANDOM + self._save_priority_config() + return True, f"已移除优先账号: {account_id}" + + self._priority_accounts = [] + self._strategy = SelectionStrategy.RANDOM + self._save_priority_config() + return True, "已清除优先账号" + + def reorder_priority(self, account_ids: List[str]) -> tuple: + """重新排序优先账号列表 + + Args: + account_ids: 新的优先账号顺序(必须与当前优先账号集合一致) + + Returns: + (success, message) + """ + with self._lock: + if not self._priority_accounts: + return False, "没有设置优先账号" + + if not account_ids: + return False, "账号列表不能为空" + + if len(account_ids) != len(self._priority_accounts): + return False, "账号数量不匹配" + + if len(set(account_ids)) != len(account_ids): + return False, "账号列表包含重复项" + + if set(account_ids) != set(self._priority_accounts): + return False, "账号列表与当前优先账号不匹配" + + self._priority_accounts = list(account_ids) + self._save_priority_config() + return True, "已更新优先账号顺序" + + def get_priority_account(self) -> Optional[str]: + """获取优先账号(单个)""" + with self._lock: + return self._priority_accounts[0] if self._priority_accounts else None + + def get_priority_accounts(self) -> List[str]: + """获取优先账号列表""" + with self._lock: + return list(self._priority_accounts) + + def is_priority_account(self, account_id: str) -> bool: + """检查账号是否为优先账号""" + with self._lock: + return account_id in self._priority_accounts + + def get_priority_order(self, account_id: str) -> Optional[int]: + """获取账号的优先级顺序(从1开始)""" + with self._lock: + if account_id in self._priority_accounts: + return self._priority_accounts.index(account_id) + 1 + return None + + def _load_priority_config(self) -> bool: + """从文件加载优先账号配置""" + if not self._priority_file.exists(): + return False + + try: + with open(self._priority_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + self._priority_accounts = data.get("priority_accounts", []) + strategy_str = data.get("strategy", SelectionStrategy.RANDOM.value) + try: + self._strategy = SelectionStrategy(strategy_str) + except ValueError: + self._strategy = SelectionStrategy.RANDOM + + # 兼容旧版本:历史默认策略为 lowest_balance,但无优先账号时更需要分散压力 + if not self._priority_accounts and self._strategy == SelectionStrategy.LOWEST_BALANCE: + self._strategy = SelectionStrategy.RANDOM + self._save_priority_config() + + print(f"[AccountSelector] 加载优先账号配置: {len(self._priority_accounts)} 个优先账号") + return True + + except Exception as e: + print(f"[AccountSelector] 加载优先账号配置失败: {e}") + return False + + def _save_priority_config(self) -> bool: + """保存优先账号配置到文件""" + try: + self._priority_file.parent.mkdir(parents=True, exist_ok=True) + + data = { + "version": "1.0", + "priority_accounts": self._priority_accounts, + "strategy": self._strategy.value + } + + temp_file = self._priority_file.with_suffix('.tmp') + with open(temp_file, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + temp_file.replace(self._priority_file) + + return True + + except Exception as e: + print(f"[AccountSelector] 保存优先账号配置失败: {e}") + return False + + def get_status(self) -> dict: + """获取选择器状态""" + with self._lock: + return { + "strategy": self._strategy.value, + "priority_accounts": list(self._priority_accounts), + "priority_count": len(self._priority_accounts) + } + + +# 全局选择器实例 +_account_selector: Optional[AccountSelector] = None + + +def get_account_selector(quota_cache: Optional['QuotaCache'] = None) -> AccountSelector: + """获取全局选择器实例""" + global _account_selector + if _account_selector is None: + if quota_cache is None: + from .quota_cache import get_quota_cache + quota_cache = get_quota_cache() + _account_selector = AccountSelector(quota_cache) + return _account_selector diff --git a/KiroProxy/kiro_proxy/core/browser.py b/KiroProxy/kiro_proxy/core/browser.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b7b110378cd08d62b68b4386b24f9aa09e914e --- /dev/null +++ b/KiroProxy/kiro_proxy/core/browser.py @@ -0,0 +1,186 @@ +"""浏览器检测和打开""" +import os +import shlex +import shutil +import subprocess +import platform +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class BrowserInfo: + id: str + name: str + path: str + supports_incognito: bool + incognito_arg: str = "" + + +# 浏览器配置 +BROWSER_CONFIGS = { + "chrome": { + "names": ["google-chrome", "google-chrome-stable", "chrome", "chromium", "chromium-browser"], + "display": "Chrome", + "incognito": "--incognito", + }, + "firefox": { + "names": ["firefox", "firefox-esr"], + "display": "Firefox", + "incognito": "--private-window", + }, + "edge": { + "names": ["microsoft-edge", "microsoft-edge-stable", "msedge"], + "display": "Edge", + "incognito": "--inprivate", + }, + "brave": { + "names": ["brave", "brave-browser"], + "display": "Brave", + "incognito": "--incognito", + }, + "opera": { + "names": ["opera"], + "display": "Opera", + "incognito": "--private", + }, + "vivaldi": { + "names": ["vivaldi", "vivaldi-stable"], + "display": "Vivaldi", + "incognito": "--incognito", + }, +} + + +def detect_browsers() -> List[BrowserInfo]: + """检测系统安装的浏览器""" + browsers = [] + system = platform.system().lower() + + if system == "windows": + import winreg + + def normalize_exe_path(raw: str) -> Optional[str]: + if not raw: + return None + expanded = os.path.expandvars(raw.strip()) + try: + parts = shlex.split(expanded, posix=False) + except ValueError: + parts = [expanded] + candidate = (parts[0] if parts else expanded).strip().strip('"') + if os.path.exists(candidate): + return candidate + lower = expanded.lower() + exe_idx = lower.find(".exe") + if exe_idx != -1: + candidate = expanded[:exe_idx + 4].strip().strip('"') + if os.path.exists(candidate): + return candidate + return None + + def get_reg_path(exe_name: str) -> Optional[str]: + name = f"{exe_name}.exe" + for root in (winreg.HKEY_LOCAL_MACHINE, winreg.HKEY_CURRENT_USER): + try: + with winreg.OpenKey(root, rf"SOFTWARE\Microsoft\Windows\CurrentVersion\App Paths\{name}") as key: + value, _ = winreg.QueryValueEx(key, "") + path = normalize_exe_path(value) + if path: + return path + except (FileNotFoundError, OSError, WindowsError): + pass + return None + + for browser_id, config in BROWSER_CONFIGS.items(): + path = None + for exe_name in config["names"]: + path = get_reg_path(exe_name) + if path: + break + if not path: + for exe_name in config["names"]: + path = shutil.which(exe_name) + if path: + break + if path: + browsers.append(BrowserInfo( + id=browser_id, + name=config["display"], + path=path, + supports_incognito=bool(config.get("incognito")), + incognito_arg=config.get("incognito", ""), + )) + else: + for browser_id, config in BROWSER_CONFIGS.items(): + for name in config["names"]: + path = shutil.which(name) + if path: + browsers.append(BrowserInfo( + id=browser_id, + name=config["display"], + path=path, + supports_incognito=bool(config.get("incognito")), + incognito_arg=config.get("incognito", ""), + )) + break + + # 添加默认浏览器选项 + if browsers: + browsers.insert(0, BrowserInfo( + id="default", + name="默认浏览器", + path="xdg-open" if system == "linux" else "open", + supports_incognito=False, + incognito_arg="", + )) + + return browsers + + +def open_url(url: str, browser_id: str = "default", incognito: bool = False) -> bool: + """用指定浏览器打开 URL""" + browsers = detect_browsers() + browser = next((b for b in browsers if b.id == browser_id), None) + + if not browser: + # 降级到默认 + browser = browsers[0] if browsers else None + + if not browser: + return False + + try: + if browser.id == "default": + # 使用系统默认浏览器 + system = platform.system().lower() + if system == "linux": + subprocess.Popen(["xdg-open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + elif system == "darwin": + subprocess.Popen(["open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + else: + os.startfile(url) + else: + # 使用指定浏览器 + args = [browser.path] + if incognito and browser.supports_incognito and browser.incognito_arg: + args.append(browser.incognito_arg) + args.append(url) + subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + return True + except Exception as e: + print(f"[Browser] 打开失败: {e}") + return False + + +def get_browsers_info() -> List[dict]: + """获取浏览器信息列表""" + return [ + { + "id": b.id, + "name": b.name, + "supports_incognito": b.supports_incognito, + } + for b in detect_browsers() + ] diff --git a/KiroProxy/kiro_proxy/core/error_handler.py b/KiroProxy/kiro_proxy/core/error_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..642d827ee06a25a6edf6709db5c1178e807761f3 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/error_handler.py @@ -0,0 +1,188 @@ +"""错误处理模块 - 统一的错误分类和处理 + +检测各种 Kiro API 错误类型: +- 账号封禁 (TEMPORARILY_SUSPENDED) +- 配额超限 (Rate Limit) +- 内容过长 (CONTENT_LENGTH_EXCEEDS_THRESHOLD) +- 认证失败 (Unauthorized) +- 服务不可用 (Service Unavailable) +""" +import re +from enum import Enum +from dataclasses import dataclass +from typing import Optional, Tuple + + +class ErrorType(str, Enum): + """错误类型""" + ACCOUNT_SUSPENDED = "account_suspended" # 账号被封禁 + RATE_LIMITED = "rate_limited" # 配额超限 + CONTENT_TOO_LONG = "content_too_long" # 内容过长 + AUTH_FAILED = "auth_failed" # 认证失败 + SERVICE_UNAVAILABLE = "service_unavailable" # 服务不可用 + MODEL_UNAVAILABLE = "model_unavailable" # 模型不可用 + UNKNOWN = "unknown" # 未知错误 + + +@dataclass +class KiroError: + """Kiro API 错误""" + type: ErrorType + status_code: int + message: str + user_message: str # 用户友好的消息 + should_disable_account: bool = False # 是否应该禁用账号 + should_switch_account: bool = False # 是否应该切换账号 + should_retry: bool = False # 是否应该重试 + cooldown_seconds: int = 0 # 冷却时间 + + +def classify_error(status_code: int, error_text: str) -> KiroError: + """分类 Kiro API 错误 + + Args: + status_code: HTTP 状态码 + error_text: 错误响应文本 + + Returns: + KiroError 对象 + """ + error_lower = error_text.lower() + + # 1. 账号封禁检测 (最严重) + # 检测: AccountSuspendedException, 423 状态码, temporarily_suspended, suspended + is_suspended = ( + status_code == 423 or + "accountsuspendedexception" in error_lower or + "temporarily_suspended" in error_lower or + "suspended" in error_lower + ) + + if is_suspended: + # 提取 User ID + user_id_match = re.search(r'User ID \(([^)]+)\)', error_text) + user_id = user_id_match.group(1) if user_id_match else "unknown" + + return KiroError( + type=ErrorType.ACCOUNT_SUSPENDED, + status_code=status_code, + message=error_text, + user_message=f"⚠️ 账号已被封禁 (User ID: {user_id})。请联系 AWS 支持解封: https://support.aws.amazon.com/#/contacts/kiro", + should_disable_account=True, + should_switch_account=True, + ) + + # 2. 402 Payment Required - 额度用尽(不触发冷却,仅切换账号) + if status_code == 402 or "payment required" in error_lower or "insufficient" in error_lower: + return KiroError( + type=ErrorType.RATE_LIMITED, + status_code=status_code, + message=error_text, + user_message="账号额度已用尽,已切换到其他账号", + should_switch_account=False, # 不自动切换,让上层逻辑处理 + cooldown_seconds=0, # 不触发冷却 + ) + + # 3. 配额超限检测 (仅 429 触发冷却) + if status_code == 429: + return KiroError( + type=ErrorType.RATE_LIMITED, + status_code=status_code, + message=error_text, + user_message="请求过于频繁,账号已进入冷却期", + should_switch_account=True, + cooldown_seconds=30, # 基础冷却时间,实际由 QuotaManager 动态管理 + ) + + # 4. 内容过长检测 + if "content_length_exceeds_threshold" in error_lower or ( + "too long" in error_lower and ("input" in error_lower or "content" in error_lower) + ): + return KiroError( + type=ErrorType.CONTENT_TOO_LONG, + status_code=status_code, + message=error_text, + user_message="对话历史过长,请使用 /clear 清空对话", + should_retry=True, + ) + + # 5. 认证失败检测 + if status_code == 401 or "unauthorized" in error_lower or "invalid token" in error_lower: + return KiroError( + type=ErrorType.AUTH_FAILED, + status_code=status_code, + message=error_text, + user_message="Token 已过期或无效,请刷新 Token", + should_switch_account=True, + ) + + # 6. 模型不可用检测 + if "model_temporarily_unavailable" in error_lower or "unexpectedly high load" in error_lower: + return KiroError( + type=ErrorType.MODEL_UNAVAILABLE, + status_code=status_code, + message=error_text, + user_message="模型暂时不可用,请稍后重试", + should_retry=True, + ) + + # 7. 服务不可用检测 + if status_code in (502, 503, 504) or "service unavailable" in error_lower: + return KiroError( + type=ErrorType.SERVICE_UNAVAILABLE, + status_code=status_code, + message=error_text, + user_message="服务暂时不可用,请稍后重试", + should_retry=True, + ) + + # 8. 未知错误 + return KiroError( + type=ErrorType.UNKNOWN, + status_code=status_code, + message=error_text, + user_message=f"API 错误 ({status_code})", + ) + + +def is_account_suspended(status_code: int, error_text: str) -> bool: + """检查是否为账号封禁错误""" + error = classify_error(status_code, error_text) + return error.type == ErrorType.ACCOUNT_SUSPENDED + + +def get_anthropic_error_response(error: KiroError) -> dict: + """生成 Anthropic 格式的错误响应""" + error_type_map = { + ErrorType.ACCOUNT_SUSPENDED: "authentication_error", + ErrorType.RATE_LIMITED: "rate_limit_error", + ErrorType.CONTENT_TOO_LONG: "invalid_request_error", + ErrorType.AUTH_FAILED: "authentication_error", + ErrorType.SERVICE_UNAVAILABLE: "api_error", + ErrorType.MODEL_UNAVAILABLE: "overloaded_error", + ErrorType.UNKNOWN: "api_error", + } + + return { + "type": "error", + "error": { + "type": error_type_map.get(error.type, "api_error"), + "message": error.user_message + } + } + + +def format_error_log(error: KiroError, account_id: str = None) -> str: + """格式化错误日志""" + lines = [ + f"[{error.type.value.upper()}]", + f" Status: {error.status_code}", + f" Message: {error.user_message}", + ] + if account_id: + lines.insert(1, f" Account: {account_id}") + if error.should_disable_account: + lines.append(" Action: 账号已被禁用") + elif error.should_switch_account: + lines.append(" Action: 切换到其他账号") + return "\n".join(lines) diff --git a/KiroProxy/kiro_proxy/core/flow_monitor.py b/KiroProxy/kiro_proxy/core/flow_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..c1256243d45ecca5a28225ae9244c39ba57e1d1b --- /dev/null +++ b/KiroProxy/kiro_proxy/core/flow_monitor.py @@ -0,0 +1,572 @@ +"""Flow Monitor - LLM 流量监控 + +记录完整的请求/响应数据,支持查询、过滤、导出。 +""" +import json +import time +import uuid +from pathlib import Path +from dataclasses import dataclass, field, asdict +from typing import Optional, List, Dict, Any +from datetime import datetime, timezone +from collections import deque +from enum import Enum + + +class FlowState(str, Enum): + """Flow 状态""" + PENDING = "pending" # 等待响应 + STREAMING = "streaming" # 流式传输中 + COMPLETED = "completed" # 完成 + ERROR = "error" # 错误 + + +@dataclass +class Message: + """消息""" + role: str # user/assistant/system/tool + content: Any # str 或 list + name: Optional[str] = None # tool name + tool_call_id: Optional[str] = None + + +@dataclass +class TokenUsage: + """Token 使用量""" + input_tokens: int = 0 + output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +@dataclass +class FlowRequest: + """请求数据""" + method: str + path: str + headers: Dict[str, str] + body: Dict[str, Any] + + # 解析后的字段 + model: str = "" + messages: List[Message] = field(default_factory=list) + system: str = "" + tools: List[Dict] = field(default_factory=list) + stream: bool = False + max_tokens: int = 0 + temperature: float = 1.0 + + +@dataclass +class FlowResponse: + """响应数据""" + status_code: int + headers: Dict[str, str] = field(default_factory=dict) + body: Any = None + + # 解析后的字段 + content: str = "" + tool_calls: List[Dict] = field(default_factory=list) + stop_reason: str = "" + usage: TokenUsage = field(default_factory=TokenUsage) + + # 流式响应 + chunks: List[str] = field(default_factory=list) + chunk_count: int = 0 + + +@dataclass +class FlowError: + """错误信息""" + type: str # rate_limit_error, api_error, etc. + message: str + status_code: int = 0 + raw: str = "" + + +@dataclass +class FlowTiming: + """时间信息""" + created_at: float = 0 + first_byte_at: Optional[float] = None + completed_at: Optional[float] = None + + @property + def ttfb_ms(self) -> Optional[float]: + """Time to first byte""" + if self.first_byte_at and self.created_at: + return (self.first_byte_at - self.created_at) * 1000 + return None + + @property + def duration_ms(self) -> Optional[float]: + """Total duration""" + if self.completed_at and self.created_at: + return (self.completed_at - self.created_at) * 1000 + return None + + +@dataclass +class LLMFlow: + """完整的 LLM 请求流""" + id: str + state: FlowState + + # 路由信息 + protocol: str # anthropic, openai, gemini + account_id: Optional[str] = None + account_name: Optional[str] = None + + # 请求/响应 + request: Optional[FlowRequest] = None + response: Optional[FlowResponse] = None + error: Optional[FlowError] = None + + # 时间 + timing: FlowTiming = field(default_factory=FlowTiming) + + # 元数据 + tags: List[str] = field(default_factory=list) + notes: str = "" + bookmarked: bool = False + + # 重试信息 + retry_count: int = 0 + parent_flow_id: Optional[str] = None + + def to_dict(self) -> dict: + """转换为字典""" + d = { + "id": self.id, + "state": self.state.value, + "protocol": self.protocol, + "account_id": self.account_id, + "account_name": self.account_name, + "timing": { + "created_at": self.timing.created_at, + "first_byte_at": self.timing.first_byte_at, + "completed_at": self.timing.completed_at, + "ttfb_ms": self.timing.ttfb_ms, + "duration_ms": self.timing.duration_ms, + }, + "tags": self.tags, + "notes": self.notes, + "bookmarked": self.bookmarked, + "retry_count": self.retry_count, + } + + if self.request: + d["request"] = { + "method": self.request.method, + "path": self.request.path, + "model": self.request.model, + "stream": self.request.stream, + "message_count": len(self.request.messages), + "has_tools": bool(self.request.tools), + "has_system": bool(self.request.system), + } + + if self.response: + d["response"] = { + "status_code": self.response.status_code, + "content_length": len(self.response.content), + "has_tool_calls": bool(self.response.tool_calls), + "stop_reason": self.response.stop_reason, + "chunk_count": self.response.chunk_count, + "usage": asdict(self.response.usage), + } + + if self.error: + d["error"] = asdict(self.error) + + return d + + def to_full_dict(self) -> dict: + """转换为完整字典(包含请求/响应体)""" + d = self.to_dict() + + if self.request: + d["request"]["headers"] = self.request.headers + d["request"]["body"] = self.request.body + d["request"]["messages"] = [asdict(m) if hasattr(m, '__dataclass_fields__') else m for m in self.request.messages] + d["request"]["system"] = self.request.system + d["request"]["tools"] = self.request.tools + + if self.response: + d["response"]["headers"] = self.response.headers + d["response"]["body"] = self.response.body + d["response"]["content"] = self.response.content + d["response"]["tool_calls"] = self.response.tool_calls + d["response"]["chunks"] = self.response.chunks[-10:] # 只保留最后10个chunk + + return d + + +class FlowStore: + """Flow 存储""" + + def __init__(self, max_flows: int = 500, persist_dir: Optional[Path] = None): + self.flows: deque[LLMFlow] = deque(maxlen=max_flows) + self.flow_map: Dict[str, LLMFlow] = {} + self.persist_dir = persist_dir + self.max_flows = max_flows + + # 统计 + self.total_flows = 0 + self.total_tokens_in = 0 + self.total_tokens_out = 0 + + def add(self, flow: LLMFlow): + """添加 Flow""" + # 如果队列满了,移除最旧的 + if len(self.flows) >= self.max_flows: + old = self.flows[0] + if old.id in self.flow_map: + del self.flow_map[old.id] + + self.flows.append(flow) + self.flow_map[flow.id] = flow + self.total_flows += 1 + + def get(self, flow_id: str) -> Optional[LLMFlow]: + """获取 Flow""" + return self.flow_map.get(flow_id) + + def update(self, flow_id: str, **kwargs): + """更新 Flow""" + flow = self.flow_map.get(flow_id) + if flow: + for k, v in kwargs.items(): + if hasattr(flow, k): + setattr(flow, k, v) + + def query( + self, + protocol: Optional[str] = None, + model: Optional[str] = None, + account_id: Optional[str] = None, + state: Optional[FlowState] = None, + has_error: Optional[bool] = None, + bookmarked: Optional[bool] = None, + min_duration_ms: Optional[float] = None, + max_duration_ms: Optional[float] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + search: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> List[LLMFlow]: + """查询 Flows""" + results = [] + + for flow in reversed(self.flows): + # 过滤条件 + if protocol and flow.protocol != protocol: + continue + if model and flow.request and flow.request.model != model: + continue + if account_id and flow.account_id != account_id: + continue + if state and flow.state != state: + continue + if has_error is not None: + if has_error and not flow.error: + continue + if not has_error and flow.error: + continue + if bookmarked is not None and flow.bookmarked != bookmarked: + continue + if min_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms < min_duration_ms: + continue + if max_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms > max_duration_ms: + continue + if start_time and flow.timing.created_at < start_time: + continue + if end_time and flow.timing.created_at > end_time: + continue + if search: + # 简单搜索:在内容中查找 + found = False + if flow.request and search.lower() in json.dumps(flow.request.body).lower(): + found = True + if flow.response and search.lower() in flow.response.content.lower(): + found = True + if not found: + continue + + results.append(flow) + + return results[offset:offset + limit] + + def get_stats(self) -> dict: + """获取统计信息""" + completed = [f for f in self.flows if f.state == FlowState.COMPLETED] + errors = [f for f in self.flows if f.state == FlowState.ERROR] + + # 按模型统计 + model_stats = {} + for f in self.flows: + if f.request: + model = f.request.model or "unknown" + if model not in model_stats: + model_stats[model] = {"count": 0, "errors": 0, "tokens_in": 0, "tokens_out": 0} + model_stats[model]["count"] += 1 + if f.error: + model_stats[model]["errors"] += 1 + if f.response and f.response.usage: + model_stats[model]["tokens_in"] += f.response.usage.input_tokens + model_stats[model]["tokens_out"] += f.response.usage.output_tokens + + # 计算平均延迟 + durations = [f.timing.duration_ms for f in completed if f.timing.duration_ms] + avg_duration = sum(durations) / len(durations) if durations else 0 + + return { + "total_flows": self.total_flows, + "active_flows": len(self.flows), + "completed": len(completed), + "errors": len(errors), + "error_rate": f"{len(errors) / max(1, len(self.flows)) * 100:.1f}%", + "avg_duration_ms": round(avg_duration, 2), + "total_tokens_in": self.total_tokens_in, + "total_tokens_out": self.total_tokens_out, + "by_model": model_stats, + } + + def export_jsonl(self, flows: List[LLMFlow]) -> str: + """导出为 JSONL 格式""" + lines = [] + for f in flows: + lines.append(json.dumps(f.to_full_dict(), ensure_ascii=False)) + return "\n".join(lines) + + def export_markdown(self, flow: LLMFlow) -> str: + """导出单个 Flow 为 Markdown""" + lines = [ + f"# Flow {flow.id}", + "", + f"- **Protocol**: {flow.protocol}", + f"- **State**: {flow.state.value}", + f"- **Account**: {flow.account_name or flow.account_id or 'N/A'}", + f"- **Created**: {datetime.fromtimestamp(flow.timing.created_at).isoformat()}", + ] + + if flow.timing.duration_ms: + lines.append(f"- **Duration**: {flow.timing.duration_ms:.0f}ms") + + if flow.request: + lines.extend([ + "", + "## Request", + "", + f"- **Model**: {flow.request.model}", + f"- **Stream**: {flow.request.stream}", + f"- **Messages**: {len(flow.request.messages)}", + ]) + + if flow.request.system: + lines.extend(["", "### System", "", f"```\n{flow.request.system}\n```"]) + + lines.extend(["", "### Messages", ""]) + for msg in flow.request.messages: + content = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False) + lines.append(f"**{msg.role}**: {content[:500]}{'...' if len(content) > 500 else ''}") + lines.append("") + + if flow.response: + lines.extend([ + "## Response", + "", + f"- **Status**: {flow.response.status_code}", + f"- **Stop Reason**: {flow.response.stop_reason}", + ]) + + if flow.response.usage: + lines.append(f"- **Tokens**: {flow.response.usage.input_tokens} in / {flow.response.usage.output_tokens} out") + + if flow.response.content: + lines.extend(["", "### Content", "", f"```\n{flow.response.content[:2000]}\n```"]) + + if flow.error: + lines.extend([ + "", + "## Error", + "", + f"- **Type**: {flow.error.type}", + f"- **Message**: {flow.error.message}", + ]) + + return "\n".join(lines) + + +class FlowMonitor: + """Flow 监控器""" + + def __init__(self, max_flows: int = 500): + self.store = FlowStore(max_flows=max_flows) + + def create_flow( + self, + protocol: str, + method: str, + path: str, + headers: Dict[str, str], + body: Dict[str, Any], + account_id: Optional[str] = None, + account_name: Optional[str] = None, + ) -> str: + """创建新的 Flow""" + flow_id = uuid.uuid4().hex[:12] + + # 解析请求 + request = FlowRequest( + method=method, + path=path, + headers={k: v for k, v in headers.items() if k.lower() not in ["authorization"]}, + body=body, + model=body.get("model", ""), + stream=body.get("stream", False), + system=body.get("system", ""), + tools=body.get("tools", []), + max_tokens=body.get("max_tokens", 0), + temperature=body.get("temperature", 1.0), + ) + + # 解析消息 + messages = body.get("messages", []) + for msg in messages: + request.messages.append(Message( + role=msg.get("role", "user"), + content=msg.get("content", ""), + name=msg.get("name"), + tool_call_id=msg.get("tool_call_id"), + )) + + flow = LLMFlow( + id=flow_id, + state=FlowState.PENDING, + protocol=protocol, + account_id=account_id, + account_name=account_name, + request=request, + timing=FlowTiming(created_at=time.time()), + ) + + self.store.add(flow) + return flow_id + + def start_streaming(self, flow_id: str): + """标记开始流式传输""" + flow = self.store.get(flow_id) + if flow: + flow.state = FlowState.STREAMING + flow.timing.first_byte_at = time.time() + if not flow.response: + flow.response = FlowResponse(status_code=200) + + def add_chunk(self, flow_id: str, chunk: str): + """添加流式响应块""" + flow = self.store.get(flow_id) + if flow and flow.response: + flow.response.chunks.append(chunk) + flow.response.chunk_count += 1 + flow.response.content += chunk + + def complete_flow( + self, + flow_id: str, + status_code: int, + content: str = "", + tool_calls: List[Dict] = None, + stop_reason: str = "", + usage: Optional[TokenUsage] = None, + headers: Dict[str, str] = None, + ): + """完成 Flow""" + flow = self.store.get(flow_id) + if not flow: + return + + flow.state = FlowState.COMPLETED + flow.timing.completed_at = time.time() + + if not flow.response: + flow.response = FlowResponse(status_code=status_code) + + flow.response.status_code = status_code + flow.response.content = content or flow.response.content + flow.response.tool_calls = tool_calls or [] + flow.response.stop_reason = stop_reason + flow.response.headers = headers or {} + + if usage: + flow.response.usage = usage + self.store.total_tokens_in += usage.input_tokens + self.store.total_tokens_out += usage.output_tokens + + def fail_flow(self, flow_id: str, error_type: str, message: str, status_code: int = 0, raw: str = ""): + """标记 Flow 失败""" + flow = self.store.get(flow_id) + if not flow: + return + + flow.state = FlowState.ERROR + flow.timing.completed_at = time.time() + flow.error = FlowError( + type=error_type, + message=message, + status_code=status_code, + raw=raw[:1000], # 限制长度 + ) + + def bookmark_flow(self, flow_id: str, bookmarked: bool = True): + """书签 Flow""" + flow = self.store.get(flow_id) + if flow: + flow.bookmarked = bookmarked + + def add_note(self, flow_id: str, note: str): + """添加备注""" + flow = self.store.get(flow_id) + if flow: + flow.notes = note + + def add_tag(self, flow_id: str, tag: str): + """添加标签""" + flow = self.store.get(flow_id) + if flow and tag not in flow.tags: + flow.tags.append(tag) + + def get_flow(self, flow_id: str) -> Optional[LLMFlow]: + """获取 Flow""" + return self.store.get(flow_id) + + def query(self, **kwargs) -> List[LLMFlow]: + """查询 Flows""" + return self.store.query(**kwargs) + + def get_stats(self) -> dict: + """获取统计""" + return self.store.get_stats() + + def export(self, flow_ids: List[str] = None, format: str = "jsonl") -> str: + """导出 Flows""" + if flow_ids: + flows = [self.store.get(fid) for fid in flow_ids if self.store.get(fid)] + else: + flows = list(self.store.flows) + + if format == "jsonl": + return self.store.export_jsonl(flows) + elif format == "markdown" and len(flows) == 1: + return self.store.export_markdown(flows[0]) + else: + return json.dumps([f.to_dict() for f in flows], ensure_ascii=False, indent=2) + + +# 全局实例 +flow_monitor = FlowMonitor(max_flows=500) diff --git a/KiroProxy/kiro_proxy/core/history_manager.py b/KiroProxy/kiro_proxy/core/history_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..cd48ca036f930fd93a389464ae9dfb2e78de5049 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/history_manager.py @@ -0,0 +1,829 @@ +"""历史消息管理器 - 错误触发压缩版 + +自动化管理对话历史长度,收到超限错误时智能压缩而非强硬截断: +1. 无预检测 - 不再依赖阈值,正常发送请求 +2. 错误触发 - 收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后自动压缩 +3. 智能压缩 - 保留最近消息 + 摘要早期对话,目标 20K-50K 字符 +4. 自动重试 - 压缩后自动重试请求 +""" +import json +import time +from typing import List, Dict, Any, Tuple, Optional, Callable +from dataclasses import dataclass, field +from collections import OrderedDict +from enum import Enum + + +@dataclass +class SummaryCacheEntry: + summary: str + old_history_hash: str + updated_at: float + + +class SummaryCache: + """摘要缓存""" + + def __init__(self, max_entries: int = 64): + self._entries: "OrderedDict[str, SummaryCacheEntry]" = OrderedDict() + self._max_entries = max_entries + + def get(self, key: str, old_history_hash: str, max_age: int = 300) -> Optional[str]: + entry = self._entries.get(key) + if not entry: + return None + if time.time() - entry.updated_at > max_age: + self._entries.pop(key, None) + return None + if entry.old_history_hash != old_history_hash: + return None + self._entries.move_to_end(key) + return entry.summary + + def set(self, key: str, summary: str, old_history_hash: str): + self._entries[key] = SummaryCacheEntry( + summary=summary, + old_history_hash=old_history_hash, + updated_at=time.time() + ) + self._entries.move_to_end(key) + if len(self._entries) > self._max_entries: + self._entries.popitem(last=False) + + +@dataclass +class CompressionCacheEntry: + """压缩结果缓存条目""" + compressed_history: List[dict] + original_hash: str + compressed_chars: int + updated_at: float + + +class CompressionCache: + """全局压缩结果缓存 + + 解决 Claude Code CLI 反复压缩问题: + - 客户端每次请求都发送完整原始历史 + - 缓存压缩结果,避免对相同内容重复压缩 + - 基于原始历史的 hash 匹配 + """ + + def __init__(self, max_entries: int = 32, max_age: int = 600): + self._entries: "OrderedDict[str, CompressionCacheEntry]" = OrderedDict() + self._max_entries = max_entries + self._max_age = max_age # 缓存有效期(秒),默认 10 分钟 + + def get(self, original_hash: str) -> Optional[List[dict]]: + """获取缓存的压缩结果""" + entry = self._entries.get(original_hash) + if not entry: + return None + if time.time() - entry.updated_at > self._max_age: + self._entries.pop(original_hash, None) + return None + self._entries.move_to_end(original_hash) + print(f"[CompressionCache] 命中缓存,跳过重复压缩 (原始 hash: {original_hash[:16]}...)") + return entry.compressed_history + + def set(self, original_hash: str, compressed_history: List[dict], compressed_chars: int): + """缓存压缩结果""" + self._entries[original_hash] = CompressionCacheEntry( + compressed_history=compressed_history, + original_hash=original_hash, + compressed_chars=compressed_chars, + updated_at=time.time() + ) + self._entries.move_to_end(original_hash) + if len(self._entries) > self._max_entries: + self._entries.popitem(last=False) + print(f"[CompressionCache] 缓存压缩结果 (原始 hash: {original_hash[:16]}..., 压缩后: {compressed_chars} 字符)") + + def clear(self): + """清空缓存""" + self._entries.clear() + + +# 全局压缩缓存实例 +_compression_cache = CompressionCache() + + +class TruncateStrategy(str, Enum): + """压缩策略(保留用于兼容)""" + NONE = "none" + AUTO_TRUNCATE = "auto_truncate" + SMART_SUMMARY = "smart_summary" + ERROR_RETRY = "error_retry" + PRE_ESTIMATE = "pre_estimate" + + +# 自动管理的常量(不再使用阈值触发,仅在错误后压缩) +# AUTO_COMPRESS_THRESHOLD 已废弃,不再用于预检测 +SAFE_CHAR_LIMIT = 35000 # 压缩后的目标字符数 (20K-50K 范围的中间值) +SAFE_CHAR_LIMIT_MIN = 20000 # 压缩目标下限 +SAFE_CHAR_LIMIT_MAX = 50000 # 压缩目标上限 +MIN_KEEP_MESSAGES = 6 # 最少保留的最近消息数 +MAX_KEEP_MESSAGES = 20 # 最多保留的最近消息数 +SUMMARY_MAX_LENGTH = 3000 # 摘要最大长度 + + +@dataclass +class HistoryConfig: + """历史消息配置(简化版,大部分参数自动管理)""" + # 启用的策略 + strategies: List[TruncateStrategy] = field(default_factory=lambda: [TruncateStrategy.ERROR_RETRY]) + + # 以下参数保留用于兼容,但实际使用自动值 + max_messages: int = 30 + max_chars: int = 150000 + summary_keep_recent: int = 10 + summary_threshold: int = 100000 + summary_max_length: int = 2000 + retry_max_messages: int = 20 + max_retries: int = 3 + estimate_threshold: int = 180000 + chars_per_token: float = 3.0 + summary_cache_enabled: bool = True + summary_cache_min_delta_messages: int = 3 + summary_cache_min_delta_chars: int = 4000 + summary_cache_max_age_seconds: int = 300 + add_warning_header: bool = True + + def to_dict(self) -> dict: + return { + "strategies": [s.value for s in self.strategies], + "max_messages": self.max_messages, + "max_chars": self.max_chars, + "summary_keep_recent": self.summary_keep_recent, + "summary_threshold": self.summary_threshold, + "summary_max_length": self.summary_max_length, + "retry_max_messages": self.retry_max_messages, + "max_retries": self.max_retries, + "estimate_threshold": self.estimate_threshold, + "chars_per_token": self.chars_per_token, + "summary_cache_enabled": self.summary_cache_enabled, + "summary_cache_min_delta_messages": self.summary_cache_min_delta_messages, + "summary_cache_min_delta_chars": self.summary_cache_min_delta_chars, + "summary_cache_max_age_seconds": self.summary_cache_max_age_seconds, + "add_warning_header": self.add_warning_header, + } + + @classmethod + def from_dict(cls, data: dict) -> "HistoryConfig": + strategies = [TruncateStrategy(s) for s in data.get("strategies", ["error_retry"])] + return cls( + strategies=strategies, + max_messages=data.get("max_messages", 30), + max_chars=data.get("max_chars", 150000), + summary_keep_recent=data.get("summary_keep_recent", 10), + summary_threshold=data.get("summary_threshold", 100000), + summary_max_length=data.get("summary_max_length", 2000), + retry_max_messages=data.get("retry_max_messages", 20), + max_retries=data.get("max_retries", 3), + estimate_threshold=data.get("estimate_threshold", 180000), + chars_per_token=data.get("chars_per_token", 3.0), + summary_cache_enabled=data.get("summary_cache_enabled", True), + summary_cache_min_delta_messages=data.get("summary_cache_min_delta_messages", 3), + summary_cache_min_delta_chars=data.get("summary_cache_min_delta_chars", 4000), + summary_cache_max_age_seconds=data.get("summary_cache_max_age_seconds", 300), + add_warning_header=data.get("add_warning_header", True), + ) + + +_summary_cache = SummaryCache() + + +class HistoryManager: + """历史消息管理器 - 错误触发压缩版 + + 不再依赖阈值预检测,仅在收到上下文超限错误后触发压缩。 + 压缩目标为 20K-50K 字符范围。 + """ + + def __init__(self, config: HistoryConfig = None, cache_key: Optional[str] = None): + self.config = config or HistoryConfig() + self._truncated = False + self._truncate_info = "" + self.cache_key = cache_key + self._retry_count = 0 + + @property + def was_truncated(self) -> bool: + return self._truncated + + @property + def truncate_info(self) -> str: + return self._truncate_info + + def reset(self): + self._truncated = False + self._truncate_info = "" + + def set_cache_key(self, key: Optional[str]): + self.cache_key = key + + def _hash_history(self, history: List[dict]) -> str: + """生成历史消息的简单哈希""" + return f"{len(history)}:{len(json.dumps(history, ensure_ascii=False))}" + + def estimate_tokens(self, text: str) -> int: + return int(len(text) / self.config.chars_per_token) + + def estimate_history_size(self, history: List[dict]) -> Tuple[int, int]: + char_count = len(json.dumps(history, ensure_ascii=False)) + return len(history), char_count + + def estimate_request_chars(self, history: List[dict], user_content: str = "") -> Tuple[int, int, int]: + history_chars = len(json.dumps(history, ensure_ascii=False)) + user_chars = len(user_content or "") + return history_chars, user_chars, history_chars + user_chars + + def _extract_text(self, content) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(item.get("text", "")) + elif isinstance(item, str): + texts.append(item) + return "\n".join(texts) + if isinstance(content, dict): + return content.get("text", "") or content.get("content", "") + return str(content) if content else "" + + + def _format_for_summary(self, history: List[dict]) -> str: + """格式化历史消息用于生成摘要""" + lines = [] + for msg in history: + role = "unknown" + content = "" + if "userInputMessage" in msg: + role = "user" + content = msg.get("userInputMessage", {}).get("content", "") + elif "assistantResponseMessage" in msg: + role = "assistant" + content = msg.get("assistantResponseMessage", {}).get("content", "") + else: + role = msg.get("role", "unknown") + content = self._extract_text(msg.get("content", "")) + # 截断过长的单条消息 + if len(content) > 800: + content = content[:800] + "..." + lines.append(f"[{role}]: {content}") + return "\n".join(lines) + + def _calculate_keep_count(self, history: List[dict], target_chars: int) -> int: + """计算应该保留多少条最近消息""" + if not history: + return 0 + + # 从后往前累计,找到合适的保留数量 + total = 0 + count = 0 + for msg in reversed(history): + msg_chars = len(json.dumps(msg, ensure_ascii=False)) + if total + msg_chars > target_chars and count >= MIN_KEEP_MESSAGES: + break + total += msg_chars + count += 1 + if count >= MAX_KEEP_MESSAGES: + break + + return max(MIN_KEEP_MESSAGES, min(count, len(history) - 1)) + + def _build_compressed_history( + self, + summary: str, + recent_history: List[dict], + label: str = "" + ) -> List[dict]: + """构建压缩后的历史(摘要 + 最近消息)""" + # 确保 recent_history 以 user 消息开头 + if recent_history and "assistantResponseMessage" in recent_history[0]: + recent_history = recent_history[1:] + + # 清理孤立的 toolResults + tool_use_ids = set() + for msg in recent_history: + if "assistantResponseMessage" in msg: + for tu in msg["assistantResponseMessage"].get("toolUses", []) or []: + if tu.get("toolUseId"): + tool_use_ids.add(tu["toolUseId"]) + + # 清理第一条 user 消息的 toolResults(因为前面没有对应的 toolUse) + if recent_history and "userInputMessage" in recent_history[0]: + recent_history[0]["userInputMessage"].pop("userInputMessageContext", None) + + # 过滤其他消息中孤立的 toolResults + if tool_use_ids: + for msg in recent_history: + if "userInputMessage" in msg: + ctx = msg.get("userInputMessage", {}).get("userInputMessageContext", {}) + results = ctx.get("toolResults") + if results: + filtered = [r for r in results if r.get("toolUseId") in tool_use_ids] + if filtered: + ctx["toolResults"] = filtered + else: + ctx.pop("toolResults", None) + if not ctx: + msg["userInputMessage"].pop("userInputMessageContext", None) + else: + for msg in recent_history: + if "userInputMessage" in msg: + msg["userInputMessage"].pop("userInputMessageContext", None) + + + # 获取 model_id + model_id = "claude-sonnet-4" + for msg in reversed(recent_history): + if "userInputMessage" in msg: + model_id = msg["userInputMessage"].get("modelId", model_id) + break + if "assistantResponseMessage" in msg: + model_id = msg["assistantResponseMessage"].get("modelId", model_id) + break + + # 检测消息格式 + is_kiro_format = any("userInputMessage" in h or "assistantResponseMessage" in h for h in recent_history) + + if is_kiro_format: + result = [ + { + "userInputMessage": { + "content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]", + "modelId": model_id, + "origin": "AI_EDITOR", + } + }, + { + "assistantResponseMessage": { + "content": "I understand the context from the summary. Let's continue." + } + } + ] + else: + result = [ + {"role": "user", "content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]"}, + {"role": "assistant", "content": "I understand the context from the summary. Let's continue."} + ] + + result.extend(recent_history) + + if label: + print(f"[HistoryManager] {label}: {len(recent_history)} recent + summary") + + return result + + async def _generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]: + """生成历史消息摘要""" + if not history or not api_caller: + return None + + formatted = self._format_for_summary(history) + if len(formatted) > 15000: + formatted = formatted[:15000] + "\n...(truncated)" + + prompt = f"""请简洁总结以下对话的关键信息: +1. 用户的主要目标 +2. 已完成的重要操作和决策 +3. 当前工作状态和关键上下文 + +对话历史: +{formatted} + +请用中文输出摘要,控制在 {SUMMARY_MAX_LENGTH} 字符以内,重点保留对后续对话有用的信息:""" + + try: + summary = await api_caller(prompt) + if summary and len(summary) > SUMMARY_MAX_LENGTH: + summary = summary[:SUMMARY_MAX_LENGTH] + "..." + return summary + except Exception as e: + print(f"[HistoryManager] 生成摘要失败: {e}") + return None + + + async def smart_compress( + self, + history: List[dict], + api_caller: Callable, + target_chars: int = SAFE_CHAR_LIMIT, + retry_level: int = 0 + ) -> List[dict]: + """智能压缩历史消息 + + 核心逻辑:保留最近消息 + 摘要早期对话 + 压缩目标为 20K-50K 字符范围 + + Args: + history: 历史消息 + api_caller: 用于生成摘要的 API 调用函数 + target_chars: 目标字符数 (默认 35K,范围 20K-50K) + retry_level: 重试级别(越高保留越少) + """ + if not history: + return history + + current_chars = len(json.dumps(history, ensure_ascii=False)) + + # 确保目标在 20K-50K 范围内 + target_chars = max(SAFE_CHAR_LIMIT_MIN, min(target_chars, SAFE_CHAR_LIMIT_MAX)) + + # 如果已经在目标范围内,不需要压缩 + if current_chars <= target_chars: + return history + + # 根据重试级别调整保留数量 + adjusted_target = int(target_chars * (0.85 ** retry_level)) + adjusted_target = max(SAFE_CHAR_LIMIT_MIN, adjusted_target) # 确保不低于下限 + + keep_count = self._calculate_keep_count(history, adjusted_target) + + # 确保至少保留一些消息用于摘要 + if keep_count >= len(history): + keep_count = max(MIN_KEEP_MESSAGES, len(history) - 2) + + old_history = history[:-keep_count] if keep_count < len(history) else [] + recent_history = history[-keep_count:] if keep_count > 0 else history + + if not old_history: + # 没有可摘要的历史,直接返回 + return recent_history + + # 尝试从缓存获取摘要 + cache_key = f"{self.cache_key}:{keep_count}" if self.cache_key else None + old_hash = self._hash_history(old_history) + + cached_summary = None + if cache_key and self.config.summary_cache_enabled: + cached_summary = _summary_cache.get(cache_key, old_hash, self.config.summary_cache_max_age_seconds) + + if cached_summary: + result = self._build_compressed_history(cached_summary, recent_history, "压缩(缓存)") + result_chars = len(json.dumps(result, ensure_ascii=False)) + self._truncated = True + self._truncate_info = f"智能压缩(缓存): {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符" + return result + + # 生成新摘要 + summary = await self._generate_summary(old_history, api_caller) + + if summary: + if cache_key and self.config.summary_cache_enabled: + _summary_cache.set(cache_key, summary, old_hash) + + result = self._build_compressed_history(summary, recent_history, "智能压缩") + result_chars = len(json.dumps(result, ensure_ascii=False)) + self._truncated = True + self._truncate_info = f"智能压缩: {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符 (摘要 {len(summary)} 字符)" + return result + + # 摘要失败,回退到简单截断 + self._truncated = True + result_chars = len(json.dumps(recent_history, ensure_ascii=False)) + self._truncate_info = f"摘要失败,保留最近 {len(recent_history)} 条消息, {current_chars} -> {result_chars} 字符" + return recent_history + + + def needs_compression(self, history: List[dict], user_content: str = "") -> bool: + """检查是否需要压缩 + + 注意:此方法现在始终返回 False,不再基于阈值预检测。 + 压缩仅在收到上下文超限错误后触发。 + 保留此方法是为了兼容旧 API。 + """ + # 不再基于阈值预检测,始终返回 False + # 压缩将在收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后触发 + return False + + async def pre_process_async( + self, + history: List[dict], + user_content: str = "", + api_caller: Callable = None + ) -> List[dict]: + """预处理历史消息 + + 注意:不再进行发送前自动压缩。 + 压缩仅在收到上下文超限错误后触发。 + """ + self.reset() + + if not history: + return history + + # 不再进行预压缩,直接返回原始历史 + # 压缩将在收到错误后由 handle_length_error_async 处理 + return history + + def pre_process(self, history: List[dict], user_content: str = "") -> List[dict]: + """预处理历史消息(同步版本) + + 注意:不再进行发送前自动压缩。 + 压缩仅在收到上下文超限错误后触发。 + """ + self.reset() + + if not history: + return history + + # 不再进行预压缩,直接返回原始历史 + return history + + async def handle_length_error_async( + self, + history: List[dict], + retry_count: int = 0, + api_caller: Optional[Callable] = None + ) -> Tuple[List[dict], bool]: + """处理长度超限错误(智能压缩后重试) + + 这是唯一触发压缩的入口点。当收到上下文超限错误时调用此方法。 + 压缩目标为 20K-50K 字符范围。 + + 防止无限循环: + - 追踪压缩状态,避免重复压缩相同内容 + - 压缩前检查大小,如果已经很小则不再压缩 + - 达到最大重试次数后返回清晰错误 + + Args: + history: 历史消息 + retry_count: 当前重试次数 + api_caller: API 调用函数 + + Returns: + (compressed_history, should_retry) + """ + max_retries = self.config.max_retries + + if retry_count >= max_retries: + print(f"[HistoryManager] 已达最大重试次数 ({max_retries}),建议清空对话") + self._truncate_info = f"已达最大压缩次数 ({max_retries}),请清空对话或减少消息数量" + return history, False + + if not history: + return history, False + + self.reset() + + current_chars = len(json.dumps(history, ensure_ascii=False)) + current_hash = self._hash_history(history) + + print(f"[HistoryManager] 收到上下文超限错误,当前大小: {current_chars} 字符") + + # 优先检查全局压缩缓存(解决 Claude Code CLI 反复压缩问题) + cached_result = _compression_cache.get(current_hash) + if cached_result is not None: + cached_chars = len(json.dumps(cached_result, ensure_ascii=False)) + self._truncated = True + self._truncate_info = f"使用缓存的压缩结果: {len(history)} -> {len(cached_result)} 条消息, {current_chars} -> {cached_chars} 字符" + print(f"[HistoryManager] {self._truncate_info}") + return cached_result, True + + print(f"[HistoryManager] 开始压缩...") + + # 防止无限循环:检查是否已经压缩过相同内容(实例级缓存) + instance_cache_key = f"compression:{current_hash}:{retry_count}" + if hasattr(self, '_instance_compression_cache') and instance_cache_key in self._instance_compression_cache: + print(f"[HistoryManager] 检测到重复压缩请求,跳过") + self._truncate_info = "内容已压缩到最小,无法继续压缩,请清空对话" + return history, False + + # 初始化实例级压缩缓存 + if not hasattr(self, '_instance_compression_cache'): + self._instance_compression_cache = {} + + # 根据重试次数计算目标大小 (20K-50K 范围) + # 第一次重试: 目标 35K (中间值) + # 第二次重试: 目标 25K + # 第三次重试: 目标 20K (下限) + if retry_count == 0: + target_chars = SAFE_CHAR_LIMIT # 35K + elif retry_count == 1: + target_chars = 25000 + else: + target_chars = SAFE_CHAR_LIMIT_MIN # 20K + + # 防止无限循环:如果当前大小已经小于目标,不再压缩 + if current_chars <= target_chars: + print(f"[HistoryManager] 当前大小 ({current_chars}) 已小于目标 ({target_chars}),无法继续压缩") + self._truncate_info = f"内容已压缩到 {current_chars} 字符,仍然超限,请清空对话" + return history, False + + print(f"[HistoryManager] 第 {retry_count + 1} 次重试,目标压缩到 {target_chars} 字符") + + if api_caller: + compressed = await self.smart_compress( + history, api_caller, + target_chars=target_chars, + retry_level=retry_count + ) + compressed_chars = len(json.dumps(compressed, ensure_ascii=False)) + + # 防止无限循环:检查压缩是否有效 + if compressed_chars >= current_chars * 0.95: # 压缩效果不足 5% + print(f"[HistoryManager] 压缩效果不足,无法继续压缩") + self._truncate_info = f"压缩效果不足,请清空对话或减少消息数量" + return history, False + + # 防止无限循环:检查压缩后是否仍然过大 + if compressed_chars > 50000 and retry_count >= max_retries - 1: + print(f"[HistoryManager] 压缩后仍然过大 ({compressed_chars}),建议清空对话") + self._truncate_info = f"压缩后仍有 {compressed_chars} 字符,请清空对话" + return compressed, False + + if len(compressed) < len(history): + # 保存到全局压缩缓存(解决 Claude Code CLI 反复压缩问题) + _compression_cache.set(current_hash, compressed, compressed_chars) + + # 记录实例级压缩缓存(防止同一请求内的重复压缩) + self._instance_compression_cache[instance_cache_key] = True + # 清理旧缓存(保留最近 10 条) + if len(self._instance_compression_cache) > 10: + oldest_key = next(iter(self._instance_compression_cache)) + del self._instance_compression_cache[oldest_key] + + self._truncated = True + self._truncate_info = f"错误后压缩 (第 {retry_count + 1} 次): {len(history)} -> {len(compressed)} 条消息, {current_chars} -> {compressed_chars} 字符" + print(f"[HistoryManager] {self._truncate_info}") + return compressed, True + else: + # 无 api_caller,简单截断 + keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * (0.5 ** (retry_count + 1)))) + if keep_count < len(history): + truncated = history[-keep_count:] + self._truncated = True + truncated_chars = len(json.dumps(truncated, ensure_ascii=False)) + + # 防止无限循环:检查截断是否有效 + if truncated_chars >= current_chars * 0.95: + print(f"[HistoryManager] 截断效果不足,无法继续压缩") + self._truncate_info = f"截断效果不足,请清空对话" + return history, False + + self._truncate_info = f"错误后截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息, {current_chars} -> {truncated_chars} 字符" + print(f"[HistoryManager] {self._truncate_info}") + return truncated, True + + return history, False + + + def handle_length_error(self, history: List[dict], retry_count: int = 0) -> Tuple[List[dict], bool]: + """处理长度超限错误(同步版本,简单截断)""" + max_retries = self.config.max_retries + + if retry_count >= max_retries: + return history, False + + if not history: + return history, False + + self.reset() + + # 根据重试次数逐步减少 + keep_ratio = 0.5 ** (retry_count + 1) + keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * keep_ratio)) + + if keep_count < len(history): + truncated = history[-keep_count:] + self._truncated = True + self._truncate_info = f"错误重试截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息" + return truncated, True + + return history, False + + def get_warning_header(self) -> Optional[str]: + if not self.config.add_warning_header or not self._truncated: + return None + return self._truncate_info + + # ========== 兼容旧 API ========== + + def truncate_by_count(self, history: List[dict], max_count: int) -> List[dict]: + """按消息数量截断(兼容)""" + if len(history) <= max_count: + return history + original_count = len(history) + truncated = history[-max_count:] + self._truncated = True + self._truncate_info = f"按数量截断: {original_count} -> {len(truncated)} 条消息" + return truncated + + def truncate_by_chars(self, history: List[dict], max_chars: int) -> List[dict]: + """按字符数截断(兼容)""" + total_chars = len(json.dumps(history, ensure_ascii=False)) + if total_chars <= max_chars: + return history + + original_count = len(history) + result = [] + current_chars = 0 + + for msg in reversed(history): + msg_chars = len(json.dumps(msg, ensure_ascii=False)) + if current_chars + msg_chars > max_chars and result: + break + result.insert(0, msg) + current_chars += msg_chars + + if len(result) < original_count: + self._truncated = True + self._truncate_info = f"按字符数截断: {original_count} -> {len(result)} 条消息" + + return result + + def should_pre_truncate(self, history: List[dict], user_content: str) -> bool: + """兼容旧 API""" + return self.needs_compression(history, user_content) + + def should_summarize(self, history: List[dict]) -> bool: + """兼容旧 API""" + return self.needs_compression(history) + + def should_smart_summarize(self, history: List[dict]) -> bool: + """兼容旧 API""" + return self.needs_compression(history) + + def should_auto_truncate_summarize(self, history: List[dict]) -> bool: + """兼容旧 API""" + return self.needs_compression(history) + + def should_pre_summary_for_error_retry(self, history: List[dict], user_content: str = "") -> bool: + """兼容旧 API""" + return self.needs_compression(history, user_content) + + async def compress_with_summary(self, history: List[dict], api_caller: Callable) -> List[dict]: + """兼容旧 API""" + return await self.smart_compress(history, api_caller) + + async def compress_before_auto_truncate(self, history: List[dict], api_caller: Callable) -> List[dict]: + """兼容旧 API""" + return await self.smart_compress(history, api_caller) + + async def generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]: + """兼容旧 API""" + return await self._generate_summary(history, api_caller) + + def summarize_history_structure(self, history: List[dict], max_items: int = 12) -> str: + """生成历史结构摘要(调试用)""" + if not history: + return "len=0" + + def entry_kind(msg): + if "userInputMessage" in msg: + return "U" + if "assistantResponseMessage" in msg: + return "A" + role = msg.get("role") + return "U" if role == "user" else ("A" if role == "assistant" else "?") + + kinds = [entry_kind(msg) for msg in history] + if len(kinds) <= max_items: + seq = "".join(kinds) + else: + head = max_items // 2 + tail = max_items - head + seq = f"{''.join(kinds[:head])}...{''.join(kinds[-tail:])}" + + return f"len={len(history)} seq={seq}" + + + +# ========== 全局配置 ========== + +_history_config = HistoryConfig() + + +def get_history_config() -> HistoryConfig: + """获取历史消息配置""" + return _history_config + + +def set_history_config(config: HistoryConfig): + """设置历史消息配置""" + global _history_config + _history_config = config + + +def update_history_config(data: dict): + """更新历史消息配置""" + global _history_config + _history_config = HistoryConfig.from_dict(data) + + +def is_content_length_error(status_code: int, error_text: str) -> bool: + """检查是否为内容长度超限错误""" + if "CONTENT_LENGTH_EXCEEDS_THRESHOLD" in error_text: + return True + if "Input is too long" in error_text: + return True + lowered = error_text.lower() + if "too long" in lowered and ("input" in lowered or "content" in lowered or "message" in lowered): + return True + if "context length" in lowered or "token limit" in lowered: + return True + return False diff --git a/KiroProxy/kiro_proxy/core/kiro_api.py b/KiroProxy/kiro_proxy/core/kiro_api.py new file mode 100644 index 0000000000000000000000000000000000000000..819ba21f72a4c8b7449feda828c4daf93bacc345 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/kiro_api.py @@ -0,0 +1,146 @@ +"""Kiro Web Portal API 调用模块 + +调用 Kiro 的 Web Portal API 获取用户信息,使用 CBOR 编码。 +参考: chaogei/Kiro-account-manager +""" +import uuid +import httpx +from typing import Optional, Tuple, Any, Dict + +try: + import cbor2 + HAS_CBOR = True +except ImportError: + HAS_CBOR = False + print("[KiroAPI] 警告: cbor2 未安装,部分功能不可用。请运行: pip install cbor2") + + +# Kiro Web Portal API 基础 URL +KIRO_API_BASE = "https://app.kiro.dev/service/KiroWebPortalService/operation" + + +async def kiro_api_request( + operation: str, + body: Dict[str, Any], + access_token: str, + idp: str = "Google", +) -> Tuple[bool, Any]: + """ + 调用 Kiro Web Portal API + + Args: + operation: API 操作名称,如 "GetUserUsageAndLimits" + body: 请求体(会被 CBOR 编码) + access_token: Bearer token + idp: 身份提供商 ("Google" 或 "Github") + + Returns: + (success, response_data or error_dict) + """ + if not HAS_CBOR: + return False, {"error": "cbor2 未安装"} + + if not access_token: + return False, {"error": "缺少 access token"} + + url = f"{KIRO_API_BASE}/{operation}" + + # CBOR 编码请求体 + try: + encoded_body = cbor2.dumps(body) + except Exception as e: + return False, {"error": f"CBOR 编码失败: {e}"} + + headers = { + "accept": "application/cbor", + "content-type": "application/cbor", + "smithy-protocol": "rpc-v2-cbor", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "amz-sdk-request": "attempt=1; max=1", + "x-amz-user-agent": "aws-sdk-js/1.0.0 kiro-proxy/1.0.0", + "authorization": f"Bearer {access_token}", + "cookie": f"Idp={idp}; AccessToken={access_token}", + } + + try: + async with httpx.AsyncClient(timeout=15, verify=False) as client: + response = await client.post(url, content=encoded_body, headers=headers) + + if response.status_code != 200: + return False, {"error": f"API 请求失败: {response.status_code}"} + + # CBOR 解码响应 + try: + data = cbor2.loads(response.content) + return True, data + except Exception as e: + return False, {"error": f"CBOR 解码失败: {e}"} + + except httpx.TimeoutException: + return False, {"error": "请求超时"} + except Exception as e: + return False, {"error": f"请求失败: {str(e)}"} + + +async def get_user_info( + access_token: str, + idp: str = "Google", +) -> Tuple[bool, Dict[str, Any]]: + """ + 获取用户信息(包括邮箱) + + Args: + access_token: Bearer token + idp: 身份提供商 ("Google" 或 "Github") + + Returns: + (success, user_info or error_dict) + user_info 包含: email, userId 等 + """ + success, result = await kiro_api_request( + operation="GetUserUsageAndLimits", + body={"isEmailRequired": True, "origin": "KIRO_IDE"}, + access_token=access_token, + idp=idp, + ) + + if not success: + return False, result + + # 提取用户信息 + user_info = result.get("userInfo", {}) + return True, { + "email": user_info.get("email"), + "userId": user_info.get("userId"), + "raw": result, + } + + +async def get_user_email( + access_token: str, + provider: str = "Google", +) -> Optional[str]: + """ + 获取用户邮箱地址 + + Args: + access_token: Bearer token + provider: 登录提供商 ("Google" 或 "Github") + + Returns: + 邮箱地址,失败返回 None + """ + # 标准化 provider 名称 + idp = provider + if provider and provider.lower() == "google": + idp = "Google" + elif provider and provider.lower() == "github": + idp = "Github" + + success, result = await get_user_info(access_token, idp) + + if success: + return result.get("email") + + print(f"[KiroAPI] 获取邮箱失败: {result.get('error', '未知错误')}") + return None diff --git a/KiroProxy/kiro_proxy/core/persistence.py b/KiroProxy/kiro_proxy/core/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..d46b780ab82034e0cb522d551e0359a1b76870bf --- /dev/null +++ b/KiroProxy/kiro_proxy/core/persistence.py @@ -0,0 +1,69 @@ +"""配置持久化""" +import json +from pathlib import Path +from typing import List, Dict, Any + +# 统一使用 config.py 中的 DATA_DIR +from ..config import DATA_DIR + +# 配置文件路径 +CONFIG_DIR = DATA_DIR +CONFIG_FILE = CONFIG_DIR / "config.json" + + +def ensure_config_dir(): + """确保配置目录存在""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + +def save_accounts(accounts: List[Dict[str, Any]]) -> bool: + """保存账号配置""" + try: + ensure_config_dir() + config = load_config() + config["accounts"] = accounts + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + return True + except Exception as e: + print(f"[Persistence] 保存配置失败: {e}") + return False + + +def load_accounts() -> List[Dict[str, Any]]: + """加载账号配置""" + config = load_config() + return config.get("accounts", []) + + +def load_config() -> Dict[str, Any]: + """加载完整配置""" + try: + if CONFIG_FILE.exists(): + with open(CONFIG_FILE, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + print(f"[Persistence] 加载配置失败: {e}") + return {} + + +def save_config(config: Dict[str, Any]) -> bool: + """保存完整配置""" + try: + ensure_config_dir() + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + return True + except Exception as e: + print(f"[Persistence] 保存配置失败: {e}") + return False + + +def export_config() -> Dict[str, Any]: + """导出配置(用于备份)""" + return load_config() + + +def import_config(config: Dict[str, Any]) -> bool: + """导入配置(用于恢复)""" + return save_config(config) diff --git a/KiroProxy/kiro_proxy/core/protocol_handler.py b/KiroProxy/kiro_proxy/core/protocol_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..77ab917c3b656f528369570815c59d42b383d740 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/protocol_handler.py @@ -0,0 +1,318 @@ +"""自定义协议处理器 + +在 Windows 上注册 kiro:// 协议,用于处理 OAuth 回调。 +""" +import sys +import os +import asyncio +import threading +from pathlib import Path +from typing import Optional, Callable +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse, parse_qs, urlencode +import socket + + +# 回调服务器端口 +CALLBACK_PORT = 19823 + +# 全局回调结果 +_callback_result = None +_callback_event = None +_callback_server = None +_server_thread = None + + +class CallbackHandler(BaseHTTPRequestHandler): + """处理 OAuth 回调的 HTTP 请求处理器""" + + def log_message(self, format, *args): + """禁用日志输出""" + pass + + def do_GET(self): + global _callback_result, _callback_event + + # 解析 URL + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + + # 检查是否是回调路径 + if parsed.path == '/kiro-callback' or parsed.path == '/' or 'code' in params: + code = params.get('code', [None])[0] + state = params.get('state', [None])[0] + error = params.get('error', [None])[0] + + print(f"[ProtocolHandler] 收到回调: code={code[:20] if code else None}..., state={state}, error={error}") + + if error: + _callback_result = {"error": error} + elif code and state: + _callback_result = {"code": code, "state": state} + else: + _callback_result = {"error": "缺少授权码"} + + # 触发事件 + if _callback_event: + _callback_event.set() + + # 返回成功页面 + self.send_response(200) + self.send_header('Content-type', 'text/html; charset=utf-8') + self.end_headers() + + html = """ + + + + + 登录成功 + + + +
+

✅ 登录成功

+

您可以关闭此窗口并返回 Kiro Proxy

+ +
+ + + """ + self.wfile.write(html.encode('utf-8')) + else: + self.send_response(404) + self.end_headers() + + +def is_port_available(port: int) -> bool: + """检查端口是否可用""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', port)) + return True + except OSError: + return False + + +def start_callback_server() -> tuple: + """启动回调服务器 + + Returns: + (success, port or error) + """ + global _callback_server, _callback_result, _callback_event, _server_thread + + # 如果服务器已经在运行,直接返回成功 + if _callback_server is not None and _server_thread is not None and _server_thread.is_alive(): + print(f"[ProtocolHandler] 回调服务器已在运行: http://127.0.0.1:{CALLBACK_PORT}") + return True, CALLBACK_PORT + + _callback_result = None + _callback_event = threading.Event() + + # 检查端口 + if not is_port_available(CALLBACK_PORT): + # 端口被占用,可能是之前的服务器还在运行 + print(f"[ProtocolHandler] 端口 {CALLBACK_PORT} 已被占用,尝试复用") + return True, CALLBACK_PORT + + try: + _callback_server = HTTPServer(('127.0.0.1', CALLBACK_PORT), CallbackHandler) + + # 在后台线程运行服务器 + _server_thread = threading.Thread(target=_callback_server.serve_forever, daemon=True) + _server_thread.start() + + print(f"[ProtocolHandler] 回调服务器已启动: http://127.0.0.1:{CALLBACK_PORT}") + return True, CALLBACK_PORT + except Exception as e: + return False, str(e) + + +def stop_callback_server(): + """停止回调服务器""" + global _callback_server, _server_thread + + if _callback_server: + try: + _callback_server.shutdown() + except: + pass + _callback_server = None + _server_thread = None + print("[ProtocolHandler] 回调服务器已停止") + + +def wait_for_callback(timeout: int = 300) -> tuple: + """等待回调 + + Args: + timeout: 超时时间(秒) + + Returns: + (success, result or error) + """ + global _callback_result, _callback_event + + if _callback_event is None: + return False, {"error": "回调服务器未启动"} + + # 等待回调 + if _callback_event.wait(timeout=timeout): + if _callback_result and "code" in _callback_result: + return True, _callback_result + elif _callback_result and "error" in _callback_result: + return False, _callback_result + else: + return False, {"error": "未收到有效回调"} + else: + return False, {"error": "等待回调超时"} + + +def get_callback_result() -> Optional[dict]: + """获取回调结果(非阻塞)""" + global _callback_result + return _callback_result + + +def clear_callback_result(): + """清除回调结果""" + global _callback_result, _callback_event + _callback_result = None + if _callback_event: + _callback_event.clear() + + +# Windows 协议注册 +def register_protocol_windows() -> tuple: + """在 Windows 上注册 kiro:// 协议 + + 注册后,当浏览器重定向到 kiro:// URL 时,Windows 会调用我们的脚本, + 脚本将参数重定向到本地 HTTP 服务器。 + + Returns: + (success, message) + """ + if sys.platform != 'win32': + return False, "仅支持 Windows" + + try: + import winreg + + # 获取当前 Python 解释器路径 + python_exe = sys.executable + + # 创建一个处理脚本 + script_dir = Path.home() / ".kiro-proxy" + script_dir.mkdir(parents=True, exist_ok=True) + script_path = script_dir / "protocol_redirect.pyw" + + # 写入重定向脚本 (.pyw 不显示控制台窗口) + script_content = f'''# -*- coding: utf-8 -*- +# Kiro Protocol Redirect Script +import sys +import webbrowser +from urllib.parse import urlparse, parse_qs, urlencode + +if len(sys.argv) > 1: + url = sys.argv[1] + + # 解析 kiro:// URL + # 格式: kiro://kiro.kiroAgent/authenticate-success?code=xxx&state=xxx + if url.startswith('kiro://'): + # 提取查询参数 + query_start = url.find('?') + if query_start > -1: + query_string = url[query_start + 1:] + # 重定向到本地 HTTP 服务器 + redirect_url = "http://127.0.0.1:{CALLBACK_PORT}/kiro-callback?" + query_string + webbrowser.open(redirect_url) +''' + script_path.write_text(script_content, encoding='utf-8') + + # 获取 pythonw.exe 路径(无控制台窗口) + python_dir = Path(python_exe).parent + pythonw_exe = python_dir / "pythonw.exe" + if not pythonw_exe.exists(): + pythonw_exe = python_exe # 降级使用 python.exe + + # 注册协议 + key_path = r"SOFTWARE\\Classes\\kiro" + + # 创建主键 + key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path) + winreg.SetValue(key, "", winreg.REG_SZ, "URL:Kiro Protocol") + winreg.SetValueEx(key, "URL Protocol", 0, winreg.REG_SZ, "") + winreg.CloseKey(key) + + # 创建 DefaultIcon 键 + icon_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\DefaultIcon") + winreg.SetValue(icon_key, "", winreg.REG_SZ, f"{python_exe},0") + winreg.CloseKey(icon_key) + + # 创建 shell\\open\\command 键 + cmd_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\shell\\open\\command") + cmd = f'"{pythonw_exe}" "{script_path}" "%1"' + winreg.SetValue(cmd_key, "", winreg.REG_SZ, cmd) + winreg.CloseKey(cmd_key) + + print(f"[ProtocolHandler] 已注册 kiro:// 协议") + print(f"[ProtocolHandler] 脚本路径: {script_path}") + print(f"[ProtocolHandler] 命令: {cmd}") + return True, "协议注册成功" + + except Exception as e: + import traceback + traceback.print_exc() + return False, f"注册失败: {e}" + + +def unregister_protocol_windows() -> tuple: + """取消注册 kiro:// 协议""" + if sys.platform != 'win32': + return False, "仅支持 Windows" + + try: + import winreg + + def delete_key_recursive(key, subkey): + try: + open_key = winreg.OpenKey(key, subkey, 0, winreg.KEY_ALL_ACCESS) + info = winreg.QueryInfoKey(open_key) + for i in range(info[0]): + child = winreg.EnumKey(open_key, 0) + delete_key_recursive(open_key, child) + winreg.CloseKey(open_key) + winreg.DeleteKey(key, subkey) + except WindowsError: + pass + + delete_key_recursive(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro") + + print("[ProtocolHandler] 已取消注册 kiro:// 协议") + return True, "协议取消注册成功" + + except Exception as e: + return False, f"取消注册失败: {e}" + + +def is_protocol_registered() -> bool: + """检查 kiro:// 协议是否已注册""" + if sys.platform != 'win32': + return False + + try: + import winreg + key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro") + winreg.CloseKey(key) + return True + except WindowsError: + return False + diff --git a/KiroProxy/kiro_proxy/core/quota_cache.py b/KiroProxy/kiro_proxy/core/quota_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..0f465e1167c5bb5b520d1d5bc949aec2db50b90c --- /dev/null +++ b/KiroProxy/kiro_proxy/core/quota_cache.py @@ -0,0 +1,397 @@ +"""额度缓存管理模块 + +提供账号额度信息的内存缓存和文件持久化功能。 +""" +import json +import time +import asyncio +from dataclasses import dataclass, field, asdict +from enum import Enum +from pathlib import Path +from typing import Optional, Dict, Any +from threading import Lock + + +# 默认缓存过期时间(秒) +DEFAULT_CACHE_MAX_AGE = 300 # 5分钟 + +# 低余额阈值 +LOW_BALANCE_THRESHOLD = 0.2 + + +class BalanceStatus(Enum): + """额度状态枚举 + + 用于区分账号的额度状态: + - NORMAL: 正常(剩余额度 > 20%) + - LOW: 低额度(0 < 剩余额度 <= 20%) + - EXHAUSTED: 无额度(剩余额度 <= 0) + """ + NORMAL = "normal" # 正常(>20%) + LOW = "low" # 低额度(0-20%) + EXHAUSTED = "exhausted" # 无额度(<=0) + + +@dataclass +class CachedQuota: + """缓存的额度信息""" + account_id: str + usage_limit: float = 0.0 # 总额度 + current_usage: float = 0.0 # 已用额度 + balance: float = 0.0 # 剩余额度 + usage_percent: float = 0.0 # 使用百分比 + balance_status: str = "normal" # 额度状态: normal, low, exhausted + is_low_balance: bool = False # 是否低额度(兼容旧字段) + is_exhausted: bool = False # 是否无额度 + is_suspended: bool = False # 是否被封禁 + subscription_title: str = "" # 订阅类型 + free_trial_limit: float = 0.0 # 免费试用额度 + free_trial_usage: float = 0.0 # 免费试用已用 + bonus_limit: float = 0.0 # 奖励额度 + bonus_usage: float = 0.0 # 奖励已用 + updated_at: float = 0.0 # 更新时间戳 + error: Optional[str] = None # 错误信息(如果获取失败) + + # 重置和过期时间 + next_reset_date: Optional[str] = None # 下次重置时间 + free_trial_expiry: Optional[str] = None # 免费试用过期时间 + bonus_expiries: list = None # 奖励过期时间列表 + + def __post_init__(self): + """初始化后计算额度状态""" + self._update_balance_status() + + def _update_balance_status(self) -> None: + """更新额度状态""" + if self.error is not None: + # 有错误时不更新状态 + return + + if self.balance <= 0: + self.balance_status = BalanceStatus.EXHAUSTED.value + self.is_exhausted = True + self.is_low_balance = False + elif self.usage_limit > 0: + remaining_percent = (self.balance / self.usage_limit) * 100 + if remaining_percent <= LOW_BALANCE_THRESHOLD * 100: + self.balance_status = BalanceStatus.LOW.value + self.is_low_balance = True + self.is_exhausted = False + else: + self.balance_status = BalanceStatus.NORMAL.value + self.is_low_balance = False + self.is_exhausted = False + else: + self.balance_status = BalanceStatus.NORMAL.value + self.is_low_balance = False + self.is_exhausted = False + + @classmethod + def from_usage_info(cls, account_id: str, usage_info: 'UsageInfo') -> 'CachedQuota': + """从 UsageInfo 创建 CachedQuota""" + usage_percent = (usage_info.current_usage / usage_info.usage_limit * 100) if usage_info.usage_limit > 0 else 0.0 + quota = cls( + account_id=account_id, + usage_limit=usage_info.usage_limit, + current_usage=usage_info.current_usage, + balance=usage_info.balance, + usage_percent=round(usage_percent, 2), + is_low_balance=usage_info.is_low_balance, + subscription_title=usage_info.subscription_title, + free_trial_limit=usage_info.free_trial_limit, + free_trial_usage=usage_info.free_trial_usage, + bonus_limit=usage_info.bonus_limit, + bonus_usage=usage_info.bonus_usage, + updated_at=time.time(), + error=None, + next_reset_date=usage_info.next_reset_date, + free_trial_expiry=usage_info.free_trial_expiry, + bonus_expiries=usage_info.bonus_expiries or [], + ) + # 重新计算状态以确保一致性 + quota._update_balance_status() + return quota + + @classmethod + def from_error(cls, account_id: str, error: str) -> 'CachedQuota': + """创建错误状态的缓存""" + # 检查是否为账号封禁错误 + is_suspended = ( + "temporarily_suspended" in error.lower() or + "suspended" in error.lower() or + "accountsuspendedexception" in error.lower() + ) + + quota = cls( + account_id=account_id, + updated_at=time.time(), + error=error + ) + + # 如果是封禁错误,标记为特殊状态 + if is_suspended: + quota.is_suspended = True + + return quota + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CachedQuota': + """从字典创建""" + quota = cls( + account_id=data.get("account_id", ""), + usage_limit=data.get("usage_limit", 0.0), + current_usage=data.get("current_usage", 0.0), + balance=data.get("balance", 0.0), + usage_percent=data.get("usage_percent", 0.0), + balance_status=data.get("balance_status", "normal"), + is_low_balance=data.get("is_low_balance", False), + is_exhausted=data.get("is_exhausted", False), + is_suspended=data.get("is_suspended", False), + subscription_title=data.get("subscription_title", ""), + free_trial_limit=data.get("free_trial_limit", 0.0), + free_trial_usage=data.get("free_trial_usage", 0.0), + bonus_limit=data.get("bonus_limit", 0.0), + bonus_usage=data.get("bonus_usage", 0.0), + updated_at=data.get("updated_at", 0.0), + error=data.get("error"), + next_reset_date=data.get("next_reset_date"), + free_trial_expiry=data.get("free_trial_expiry"), + bonus_expiries=data.get("bonus_expiries", []), + ) + # 重新计算状态以确保一致性 + quota._update_balance_status() + return quota + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return asdict(self) + + def has_error(self) -> bool: + """是否有错误""" + return self.error is not None + + def is_available(self) -> bool: + """额度是否可用(未耗尽且无错误)""" + return not self.is_exhausted and not self.has_error() + + def get_balance_status_enum(self) -> BalanceStatus: + """获取额度状态枚举""" + try: + return BalanceStatus(self.balance_status) + except ValueError: + return BalanceStatus.NORMAL + + +class QuotaCache: + """额度缓存管理器 + + 提供线程安全的额度缓存操作,支持内存缓存和文件持久化。 + """ + + def __init__(self, cache_file: Optional[str] = None): + """ + 初始化缓存管理器 + + Args: + cache_file: 缓存文件路径,None 则使用默认路径 + """ + self._cache: Dict[str, CachedQuota] = {} + self._lock = Lock() + self._save_lock = asyncio.Lock() + + # 设置缓存文件路径 + if cache_file: + self._cache_file = Path(cache_file) + else: + from ..config import DATA_DIR + self._cache_file = DATA_DIR / "quota_cache.json" + + # 启动时加载缓存 + self.load_from_file() + + def get(self, account_id: str) -> Optional[CachedQuota]: + """获取账号的缓存额度 + + Args: + account_id: 账号ID + + Returns: + 缓存的额度信息,不存在则返回 None + """ + with self._lock: + return self._cache.get(account_id) + + def set(self, account_id: str, quota: CachedQuota) -> None: + """设置账号的额度缓存 + + Args: + account_id: 账号ID + quota: 额度信息 + """ + with self._lock: + self._cache[account_id] = quota + + def is_stale(self, account_id: str, max_age_seconds: int = DEFAULT_CACHE_MAX_AGE) -> bool: + """检查缓存是否过期 + + Args: + account_id: 账号ID + max_age_seconds: 最大缓存时间(秒) + + Returns: + True 表示缓存过期或不存在 + """ + with self._lock: + quota = self._cache.get(account_id) + if quota is None: + return True + return (time.time() - quota.updated_at) > max_age_seconds + + def get_all(self) -> Dict[str, CachedQuota]: + """获取所有缓存 + + Returns: + 所有账号的额度缓存副本 + """ + with self._lock: + return dict(self._cache) + + def remove(self, account_id: str) -> None: + """移除账号缓存 + + Args: + account_id: 账号ID + """ + with self._lock: + self._cache.pop(account_id, None) + + def clear(self) -> None: + """清空所有缓存""" + with self._lock: + self._cache.clear() + + def load_from_file(self) -> bool: + """从文件加载缓存 + + Returns: + 是否加载成功 + """ + if not self._cache_file.exists(): + return False + + try: + with open(self._cache_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 验证版本 + version = data.get("version", "1.0") + accounts_data = data.get("accounts", {}) + + with self._lock: + self._cache.clear() + for account_id, quota_data in accounts_data.items(): + quota_data["account_id"] = account_id + self._cache[account_id] = CachedQuota.from_dict(quota_data) + + print(f"[QuotaCache] 从文件加载 {len(self._cache)} 个账号的额度缓存") + return True + + except json.JSONDecodeError as e: + print(f"[QuotaCache] 缓存文件格式错误: {e}") + return False + except Exception as e: + print(f"[QuotaCache] 加载缓存失败: {e}") + return False + + def save_to_file(self) -> bool: + """保存缓存到文件(同步版本) + + Returns: + 是否保存成功 + """ + try: + # 确保目录存在 + self._cache_file.parent.mkdir(parents=True, exist_ok=True) + + with self._lock: + accounts_data = {} + for account_id, quota in self._cache.items(): + quota_dict = quota.to_dict() + quota_dict.pop("account_id", None) # 避免重复存储 + accounts_data[account_id] = quota_dict + + data = { + "version": "1.0", + "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "accounts": accounts_data + } + + # 写入临时文件后重命名,确保原子性 + temp_file = self._cache_file.with_suffix('.tmp') + with open(temp_file, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + temp_file.replace(self._cache_file) + + return True + + except Exception as e: + print(f"[QuotaCache] 保存缓存失败: {e}") + return False + + async def save_to_file_async(self) -> bool: + """异步保存缓存到文件 + + Returns: + 是否保存成功 + """ + async with self._save_lock: + # 在线程池中执行同步保存 + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.save_to_file) + + def get_summary(self) -> Dict[str, Any]: + """获取缓存汇总信息 + + Returns: + 汇总统计信息 + """ + with self._lock: + total_balance = 0.0 + total_usage = 0.0 + total_limit = 0.0 + error_count = 0 + stale_count = 0 + + current_time = time.time() + + for quota in self._cache.values(): + if quota.has_error(): + error_count += 1 + else: + total_balance += quota.balance + total_usage += quota.current_usage + total_limit += quota.usage_limit + + if (current_time - quota.updated_at) > DEFAULT_CACHE_MAX_AGE: + stale_count += 1 + + return { + "total_accounts": len(self._cache), + "total_balance": round(total_balance, 2), + "total_usage": round(total_usage, 2), + "total_limit": round(total_limit, 2), + "error_count": error_count, + "stale_count": stale_count + } + + +# 全局缓存实例 +_quota_cache: Optional[QuotaCache] = None + + +def get_quota_cache() -> QuotaCache: + """获取全局缓存实例""" + global _quota_cache + if _quota_cache is None: + _quota_cache = QuotaCache() + return _quota_cache diff --git a/KiroProxy/kiro_proxy/core/quota_scheduler.py b/KiroProxy/kiro_proxy/core/quota_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2579cbb9d1ef9fcf98326a2f99ca894eccf3e878 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/quota_scheduler.py @@ -0,0 +1,321 @@ +"""额度更新调度器模块 + +实现启动时并发获取所有账号额度、定时更新活跃账号额度的功能。 +""" +import asyncio +import time +from typing import Optional, Set, Dict, List, TYPE_CHECKING +from threading import Lock + +if TYPE_CHECKING: + from .account import Account + +from .quota_cache import QuotaCache, CachedQuota, get_quota_cache +from .usage import get_account_usage + + +# 默认更新间隔(秒) +DEFAULT_UPDATE_INTERVAL = 60 + +# 活跃账号判定时间窗口(秒) +# 需要覆盖一次更新周期,避免低频请求时“永远错过”定时刷新 +ACTIVE_WINDOW_SECONDS = 120 + + +class QuotaScheduler: + """额度更新调度器 + + 负责启动时并发获取所有账号额度,以及定时更新活跃账号的额度。 + """ + + def __init__(self, + quota_cache: Optional[QuotaCache] = None, + update_interval: int = DEFAULT_UPDATE_INTERVAL): + """ + 初始化调度器 + + Args: + quota_cache: 额度缓存实例 + update_interval: 更新间隔(秒) + """ + self.quota_cache = quota_cache or get_quota_cache() + self.update_interval = update_interval + + self._active_accounts: Dict[str, float] = {} # account_id -> last_used_timestamp + self._lock = Lock() + self._task: Optional[asyncio.Task] = None + self._running = False + self._last_full_refresh: Optional[float] = None + self._accounts_getter = None # 获取账号列表的回调函数 + + def set_accounts_getter(self, getter): + """设置获取账号列表的回调函数 + + Args: + getter: 返回账号列表的可调用对象 + """ + self._accounts_getter = getter + + def _get_accounts(self) -> List['Account']: + """获取账号列表""" + if self._accounts_getter: + return self._accounts_getter() + return [] + + async def start(self) -> None: + """启动调度器""" + if self._running: + return + + self._running = True + print("[QuotaScheduler] 启动额度更新调度器") + + # 启动时刷新所有账号额度 + await self.refresh_all() + + # 启动定时更新任务 + self._task = asyncio.create_task(self._update_loop()) + + async def stop(self) -> None: + """停止调度器""" + self._running = False + + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + + print("[QuotaScheduler] 额度更新调度器已停止") + + async def refresh_all(self) -> Dict[str, bool]: + """刷新所有账号额度 + + Returns: + 账号ID -> 是否成功的字典 + """ + accounts = self._get_accounts() + if not accounts: + print("[QuotaScheduler] 没有账号需要刷新") + return {} + + # 刷新所有账号(包括禁用的,以便检查是否可以解禁) + print(f"[QuotaScheduler] 开始刷新 {len(accounts)} 个账号的额度...") + + # 并发获取所有账号额度 + tasks = [self._refresh_account_internal(acc) for acc in accounts] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计结果 + success_count = 0 + fail_count = 0 + result_dict = {} + + for acc, result in zip(accounts, results): + if isinstance(result, Exception): + result_dict[acc.id] = False + fail_count += 1 + else: + result_dict[acc.id] = result + if result: + success_count += 1 + else: + fail_count += 1 + + self._last_full_refresh = time.time() + + # 保存缓存 + await self.quota_cache.save_to_file_async() + + # 保存账号配置(因为可能有启用/禁用状态变化) + self._save_accounts_config() + + print(f"[QuotaScheduler] 额度刷新完成: 成功 {success_count}, 失败 {fail_count}") + return result_dict + + def _save_accounts_config(self): + """保存账号配置""" + try: + from .state import state + state._save_accounts() + except Exception as e: + print(f"[QuotaScheduler] 保存账号配置失败: {e}") + + async def refresh_account(self, account_id: str) -> bool: + """刷新单个账号额度 + + Args: + account_id: 账号ID + + Returns: + 是否成功 + """ + accounts = self._get_accounts() + account = next((acc for acc in accounts if acc.id == account_id), None) + + if not account: + print(f"[QuotaScheduler] 账号不存在: {account_id}") + return False + + success = await self._refresh_account_internal(account) + + if success: + await self.quota_cache.save_to_file_async() + self._save_accounts_config() + + return success + + async def _refresh_account_internal(self, account: 'Account') -> bool: + """内部刷新账号额度方法 + + Args: + account: 账号对象 + + Returns: + 是否成功 + """ + try: + success, result = await get_account_usage(account) + + if success: + quota = CachedQuota.from_usage_info(account.id, result) + self.quota_cache.set(account.id, quota) + + # 额度为 0 时自动禁用账号 + if quota.is_exhausted: + if account.enabled: + account.enabled = False + # 标记为自动禁用,避免与手动禁用混淆 + if hasattr(account, "auto_disabled"): + account.auto_disabled = True + print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用") + else: + # 有额度时自动解禁账号(仅对自动禁用的账号生效,避免覆盖手动禁用/封禁) + if (not account.enabled) and getattr(account, "auto_disabled", False): + account.enabled = True + account.auto_disabled = False + print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 有可用额度,自动启用") + + return True + else: + error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else str(result) + quota = CachedQuota.from_error(account.id, error_msg) + self.quota_cache.set(account.id, quota) + print(f"[QuotaScheduler] 获取账号 {account.id} 额度失败: {error_msg}") + return False + + except Exception as e: + error_msg = str(e) + quota = CachedQuota.from_error(account.id, error_msg) + self.quota_cache.set(account.id, quota) + print(f"[QuotaScheduler] 获取账号 {account.id} 额度异常: {error_msg}") + return False + + def mark_active(self, account_id: str) -> None: + """标记账号为活跃 + + Args: + account_id: 账号ID + """ + with self._lock: + self._active_accounts[account_id] = time.time() + + def is_active(self, account_id: str) -> bool: + """检查账号是否活跃 + + Args: + account_id: 账号ID + + Returns: + 是否在活跃时间窗口内 + """ + with self._lock: + last_used = self._active_accounts.get(account_id) + if last_used is None: + return False + return (time.time() - last_used) < ACTIVE_WINDOW_SECONDS + + def get_active_accounts(self) -> Set[str]: + """获取活跃账号列表 + + Returns: + 活跃账号ID集合 + """ + current_time = time.time() + with self._lock: + return { + account_id + for account_id, last_used in self._active_accounts.items() + if (current_time - last_used) < ACTIVE_WINDOW_SECONDS + } + + def cleanup_inactive(self) -> None: + """清理不活跃的账号记录""" + current_time = time.time() + with self._lock: + self._active_accounts = { + account_id: last_used + for account_id, last_used in self._active_accounts.items() + if (current_time - last_used) < ACTIVE_WINDOW_SECONDS * 2 + } + + async def _update_loop(self) -> None: + """定时更新循环""" + while self._running: + try: + await asyncio.sleep(self.update_interval) + + if not self._running: + break + + # 获取活跃账号 + active_ids = self.get_active_accounts() + + if active_ids: + print(f"[QuotaScheduler] 更新 {len(active_ids)} 个活跃账号的额度...") + + accounts = self._get_accounts() + active_accounts = [acc for acc in accounts if acc.id in active_ids] + + # 并发更新 + tasks = [self._refresh_account_internal(acc) for acc in active_accounts] + await asyncio.gather(*tasks, return_exceptions=True) + + # 保存缓存 + await self.quota_cache.save_to_file_async() + + # 清理不活跃记录 + self.cleanup_inactive() + + except asyncio.CancelledError: + break + except Exception as e: + print(f"[QuotaScheduler] 更新循环异常: {e}") + + def get_last_full_refresh(self) -> Optional[float]: + """获取最后一次全量刷新时间""" + return self._last_full_refresh + + def get_status(self) -> dict: + """获取调度器状态""" + return { + "running": self._running, + "update_interval": self.update_interval, + "active_accounts": list(self.get_active_accounts()), + "active_count": len(self.get_active_accounts()), + "last_full_refresh": self._last_full_refresh + } + + +# 全局调度器实例 +_quota_scheduler: Optional[QuotaScheduler] = None + + +def get_quota_scheduler() -> QuotaScheduler: + """获取全局调度器实例""" + global _quota_scheduler + if _quota_scheduler is None: + _quota_scheduler = QuotaScheduler() + return _quota_scheduler diff --git a/KiroProxy/kiro_proxy/core/rate_limiter.py b/KiroProxy/kiro_proxy/core/rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..27af43743265abd05f1a3f32ac9a36239204be32 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/rate_limiter.py @@ -0,0 +1,125 @@ +"""请求限速器 - 降低账号封禁风险 + +通过限制请求频率来降低被检测为异常活动的风险: +- 每账号请求间隔 +- 全局请求限制 +- 突发请求检测 + +注意:429 冷却时间已改为自动管理(固定5分钟),不再需要手动配置 +""" +import time +from dataclasses import dataclass, field +from typing import Dict, Optional +from collections import deque + + +@dataclass +class RateLimitConfig: + """限速配置""" + # 每账号最小请求间隔(秒) + min_request_interval: float = 0.5 + + # 每账号每分钟最大请求数 + max_requests_per_minute: int = 60 + + # 全局每分钟最大请求数 + global_max_requests_per_minute: int = 120 + + # 是否启用限速 + enabled: bool = False + + +@dataclass +class AccountRateState: + """账号限速状态""" + last_request_time: float = 0 + request_times: deque = field(default_factory=lambda: deque(maxlen=100)) + + def get_requests_in_window(self, window_seconds: int = 60) -> int: + """获取时间窗口内的请求数""" + now = time.time() + cutoff = now - window_seconds + return sum(1 for t in self.request_times if t > cutoff) + + +class RateLimiter: + """请求限速器""" + + def __init__(self, config: RateLimitConfig = None): + self.config = config or RateLimitConfig() + self._account_states: Dict[str, AccountRateState] = {} + self._global_requests: deque = deque(maxlen=1000) + + def _get_account_state(self, account_id: str) -> AccountRateState: + """获取账号状态""" + if account_id not in self._account_states: + self._account_states[account_id] = AccountRateState() + return self._account_states[account_id] + + def can_request(self, account_id: str) -> tuple: + """检查是否可以发送请求 + + Returns: + (can_request, wait_seconds, reason) + """ + if not self.config.enabled: + return True, 0, None + + now = time.time() + state = self._get_account_state(account_id) + + # 检查最小请求间隔 + time_since_last = now - state.last_request_time + if time_since_last < self.config.min_request_interval: + wait = self.config.min_request_interval - time_since_last + return False, wait, f"请求过快,请等待 {wait:.1f} 秒" + + # 检查每账号每分钟限制 + account_rpm = state.get_requests_in_window(60) + if account_rpm >= self.config.max_requests_per_minute: + return False, 2, f"账号请求过于频繁 ({account_rpm}/分钟)" + + # 检查全局每分钟限制 + global_rpm = sum(1 for t in self._global_requests if t > now - 60) + if global_rpm >= self.config.global_max_requests_per_minute: + return False, 1, f"全局请求过于频繁 ({global_rpm}/分钟)" + + return True, 0, None + + def record_request(self, account_id: str): + """记录请求""" + now = time.time() + state = self._get_account_state(account_id) + state.last_request_time = now + state.request_times.append(now) + self._global_requests.append(now) + + def get_stats(self) -> dict: + """获取统计信息""" + now = time.time() + return { + "enabled": self.config.enabled, + "global_rpm": sum(1 for t in self._global_requests if t > now - 60), + "accounts": { + aid: { + "rpm": state.get_requests_in_window(60), + "last_request": now - state.last_request_time if state.last_request_time else None + } + for aid, state in self._account_states.items() + } + } + + def update_config(self, **kwargs): + """更新配置""" + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + + +# 全局实例 +rate_limiter = RateLimiter() + + +def get_rate_limiter() -> RateLimiter: + """获取限速器实例""" + return rate_limiter diff --git a/KiroProxy/kiro_proxy/core/refresh_manager.py b/KiroProxy/kiro_proxy/core/refresh_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..97f049b6959e54e343f7b8a2b4684b7b703312e5 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/refresh_manager.py @@ -0,0 +1,888 @@ +"""Token 刷新管理模块 + +提供 Token 批量刷新的管理功能,包括: +- 刷新进度跟踪 +- 并发控制 +- 重试机制配置 +- 全局锁防止重复刷新 +- Token 过期检测和自动刷新 +- 指数退避重试策略 +""" +import time +import asyncio +from dataclasses import dataclass, field, asdict +from typing import Optional, Dict, Any, List, Tuple, Callable, TYPE_CHECKING +from threading import Lock + +if TYPE_CHECKING: + from .account import Account + + +@dataclass +class RefreshProgress: + """刷新进度信息 + + 用于跟踪批量 Token 刷新操作的进度状态。 + + Attributes: + total: 需要刷新的账号总数 + completed: 已完成处理的账号数(包括成功和失败) + success: 刷新成功的账号数 + failed: 刷新失败的账号数 + current_account: 当前正在处理的账号ID + status: 刷新状态 - running(进行中), completed(已完成), error(出错) + started_at: 刷新开始时间戳 + message: 状态消息,用于显示当前操作或错误信息 + """ + total: int = 0 + completed: int = 0 + success: int = 0 + failed: int = 0 + current_account: Optional[str] = None + status: str = "running" # running, completed, error + started_at: float = field(default_factory=time.time) + message: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式 + + Returns: + 包含所有进度信息的字典 + """ + return asdict(self) + + @property + def progress_percent(self) -> float: + """计算完成百分比 + + Returns: + 完成百分比(0-100) + """ + if self.total == 0: + return 0.0 + return round((self.completed / self.total) * 100, 2) + + @property + def elapsed_seconds(self) -> float: + """计算已用时间(秒) + + Returns: + 从开始到现在的秒数 + """ + return time.time() - self.started_at + + def is_running(self) -> bool: + """检查是否正在运行 + + Returns: + True 表示正在运行 + """ + return self.status == "running" + + def is_completed(self) -> bool: + """检查是否已完成 + + Returns: + True 表示已完成(成功或出错) + """ + return self.status in ("completed", "error") + + +@dataclass +class RefreshConfig: + """刷新配置 + + 控制 Token 刷新行为的配置参数。 + + Attributes: + max_retries: 单个账号刷新失败时的最大重试次数 + retry_base_delay: 重试基础延迟时间(秒),实际延迟会指数增长 + concurrency: 并发刷新的账号数量 + token_refresh_before_expiry: Token 过期前多少秒开始刷新(默认5分钟) + auto_refresh_interval: 自动刷新检查间隔(秒) + """ + max_retries: int = 3 + retry_base_delay: float = 1.0 + concurrency: int = 3 + token_refresh_before_expiry: int = 300 # 5分钟 + auto_refresh_interval: int = 60 # 1分钟 + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式 + + Returns: + 包含所有配置项的字典 + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'RefreshConfig': + """从字典创建配置实例 + + Args: + data: 配置字典 + + Returns: + RefreshConfig 实例 + """ + return cls( + max_retries=data.get("max_retries", 3), + retry_base_delay=data.get("retry_base_delay", 1.0), + concurrency=data.get("concurrency", 3), + token_refresh_before_expiry=data.get("token_refresh_before_expiry", 300), + auto_refresh_interval=data.get("auto_refresh_interval", 60) + ) + + def validate(self) -> bool: + """验证配置有效性 + + Returns: + True 表示配置有效 + + Raises: + ValueError: 配置值无效时抛出 + """ + if self.max_retries < 0: + raise ValueError("max_retries 不能为负数") + if self.retry_base_delay <= 0: + raise ValueError("retry_base_delay 必须大于0") + if self.concurrency < 1: + raise ValueError("concurrency 必须至少为1") + if self.token_refresh_before_expiry < 0: + raise ValueError("token_refresh_before_expiry 不能为负数") + if self.auto_refresh_interval < 1: + raise ValueError("auto_refresh_interval 必须至少为1秒") + return True + + +class RefreshManager: + """Token 刷新管理器 + + 管理 Token 批量刷新操作,提供: + - 全局锁机制防止重复刷新 + - 进度跟踪 + - 配置管理 + - 自动 Token 刷新定时器 + + 使用示例: + manager = get_refresh_manager() + if not manager.is_refreshing(): + # 开始刷新操作 + pass + """ + + def __init__(self, config: Optional[RefreshConfig] = None): + """初始化刷新管理器 + + Args: + config: 刷新配置,None 则使用默认配置 + """ + # 配置 + self._config = config or RefreshConfig() + + # 线程锁(用于同步访问状态) + self._lock = Lock() + + # 异步锁(用于防止并发刷新操作) + self._async_lock = asyncio.Lock() + + # 刷新状态 + self._is_refreshing: bool = False + self._progress: Optional[RefreshProgress] = None + + # 上次刷新完成时间 + self._last_refresh_time: Optional[float] = None + + # 自动刷新定时器 + self._auto_refresh_task: Optional[asyncio.Task] = None + self._auto_refresh_running: bool = False + + # 获取账号列表的回调函数 + self._accounts_getter: Optional[Callable] = None + + @property + def config(self) -> RefreshConfig: + """获取当前配置 + + Returns: + 当前的刷新配置 + """ + with self._lock: + return self._config + + def is_refreshing(self) -> bool: + """检查是否正在刷新 + + Returns: + True 表示正在进行刷新操作 + """ + with self._lock: + return self._is_refreshing + + def get_progress(self) -> Optional[RefreshProgress]: + """获取当前刷新进度 + + Returns: + 当前进度信息,如果没有进行中的刷新则返回 None + """ + with self._lock: + return self._progress + + def get_progress_dict(self) -> Optional[Dict[str, Any]]: + """获取当前刷新进度(字典格式) + + Returns: + 进度信息字典,如果没有进行中的刷新则返回 None + """ + with self._lock: + if self._progress is None: + return None + return self._progress.to_dict() + + def update_config(self, **kwargs) -> None: + """更新配置参数 + + 支持的参数: + max_retries: 最大重试次数 + retry_base_delay: 重试基础延迟 + concurrency: 并发数 + token_refresh_before_expiry: Token 过期前刷新时间 + auto_refresh_interval: 自动刷新检查间隔 + + Args: + **kwargs: 要更新的配置项 + + Raises: + ValueError: 配置值无效时抛出 + """ + with self._lock: + # 创建新配置 + new_config = RefreshConfig( + max_retries=kwargs.get("max_retries", self._config.max_retries), + retry_base_delay=kwargs.get("retry_base_delay", self._config.retry_base_delay), + concurrency=kwargs.get("concurrency", self._config.concurrency), + token_refresh_before_expiry=kwargs.get( + "token_refresh_before_expiry", + self._config.token_refresh_before_expiry + ), + auto_refresh_interval=kwargs.get( + "auto_refresh_interval", + self._config.auto_refresh_interval + ) + ) + + # 验证配置 + new_config.validate() + + # 应用新配置 + self._config = new_config + + def _start_refresh(self, total: int, message: Optional[str] = None) -> RefreshProgress: + """开始刷新操作(内部方法) + + Args: + total: 需要刷新的账号总数 + message: 初始状态消息 + + Returns: + 新创建的进度对象 + """ + with self._lock: + self._is_refreshing = True + self._progress = RefreshProgress( + total=total, + completed=0, + success=0, + failed=0, + current_account=None, + status="running", + started_at=time.time(), + message=message or "开始刷新" + ) + return self._progress + + def _update_progress( + self, + current_account: Optional[str] = None, + success: bool = False, + failed: bool = False, + message: Optional[str] = None + ) -> None: + """更新刷新进度(内部方法) + + Args: + current_account: 当前处理的账号ID + success: 是否成功完成一个账号 + failed: 是否失败一个账号 + message: 状态消息 + """ + with self._lock: + if self._progress is None: + return + + if current_account is not None: + self._progress.current_account = current_account + + if success: + self._progress.success += 1 + self._progress.completed += 1 + elif failed: + self._progress.failed += 1 + self._progress.completed += 1 + + if message is not None: + self._progress.message = message + + def _finish_refresh(self, status: str = "completed", message: Optional[str] = None) -> None: + """完成刷新操作(内部方法) + + Args: + status: 最终状态 - completed 或 error + message: 最终状态消息 + """ + with self._lock: + self._is_refreshing = False + self._last_refresh_time = time.time() + + if self._progress is not None: + self._progress.status = status + self._progress.current_account = None + if message is not None: + self._progress.message = message + elif status == "completed": + self._progress.message = ( + f"刷新完成: 成功 {self._progress.success}, " + f"失败 {self._progress.failed}" + ) + + def get_last_refresh_time(self) -> Optional[float]: + """获取上次刷新完成时间 + + Returns: + 上次刷新完成的时间戳,如果从未刷新则返回 None + """ + with self._lock: + return self._last_refresh_time + + def get_status(self) -> Dict[str, Any]: + """获取管理器状态 + + Returns: + 包含管理器状态信息的字典 + """ + with self._lock: + return { + "is_refreshing": self._is_refreshing, + "progress": self._progress.to_dict() if self._progress else None, + "last_refresh_time": self._last_refresh_time, + "config": self._config.to_dict() + } + + async def acquire_refresh_lock(self) -> bool: + """尝试获取刷新锁 + + 用于在开始刷新操作前获取异步锁,防止并发刷新。 + + Returns: + True 表示成功获取锁,False 表示已有刷新在进行 + """ + if self._async_lock.locked(): + return False + + await self._async_lock.acquire() + return True + + def release_refresh_lock(self) -> None: + """释放刷新锁 + + 在刷新操作完成后调用,释放异步锁。 + """ + if self._async_lock.locked(): + self._async_lock.release() + + def should_refresh_token(self, account: 'Account') -> bool: + """判断是否需要刷新 Token + + 检查账号的 Token 是否即将过期(过期前5分钟)或已过期。 + + Args: + account: 账号对象 + + Returns: + True 表示需要刷新 Token + """ + creds = account.get_credentials() + if creds is None: + return True # 无法获取凭证,需要刷新 + + # 检查是否已过期或即将过期 + minutes_before = self._config.token_refresh_before_expiry // 60 + return creds.is_expired() or creds.is_expiring_soon(minutes=minutes_before) + + async def refresh_token_if_needed(self, account: 'Account') -> Tuple[bool, str]: + """如果需要则刷新 Token + + 检查账号 Token 状态,如果即将过期或已过期则刷新。 + + Args: + account: 账号对象 + + Returns: + (success, message) 元组 + - success: True 表示 Token 有效(无需刷新或刷新成功) + - message: 状态消息 + """ + if not self.should_refresh_token(account): + return True, "Token 有效,无需刷新" + + print(f"[RefreshManager] 账号 {account.id} Token 即将过期,开始刷新...") + + success, result = await account.refresh_token() + + if success: + print(f"[RefreshManager] 账号 {account.id} Token 刷新成功") + return True, "Token 刷新成功" + else: + print(f"[RefreshManager] 账号 {account.id} Token 刷新失败: {result}") + return False, f"Token 刷新失败: {result}" + + async def refresh_account_with_token( + self, + account: 'Account', + get_quota_func: Optional[Callable] = None + ) -> Tuple[bool, str]: + """刷新单个账号(先刷新 Token,再获取额度) + + Args: + account: 账号对象 + get_quota_func: 获取额度的异步函数,接受 account 参数 + + Returns: + (success, message) 元组 + """ + # 1. 先刷新 Token(如果需要) + token_success, token_msg = await self.refresh_token_if_needed(account) + + if not token_success: + return False, token_msg + + # 2. 获取额度(如果提供了获取函数) + if get_quota_func: + try: + quota_success, quota_result = await get_quota_func(account) + if quota_success: + return True, "刷新成功" + else: + error_msg = quota_result.get("error", "Unknown error") if isinstance(quota_result, dict) else str(quota_result) + return False, f"获取额度失败: {error_msg}" + except Exception as e: + return False, f"获取额度异常: {str(e)}" + + return True, token_msg + + async def retry_with_backoff( + self, + func: Callable, + *args, + max_retries: Optional[int] = None, + **kwargs + ) -> Tuple[bool, Any]: + """带指数退避的重试 + + 执行异步函数,失败时使用指数退避策略重试。 + + Args: + func: 要执行的异步函数 + *args: 传递给函数的位置参数 + max_retries: 最大重试次数,None 则使用配置值 + **kwargs: 传递给函数的关键字参数 + + Returns: + (success, result) 元组 + - success: True 表示执行成功 + - result: 成功时为函数返回值,失败时为错误信息 + """ + retries = max_retries if max_retries is not None else self._config.max_retries + base_delay = self._config.retry_base_delay + + last_error = None + + for attempt in range(retries + 1): + try: + result = await func(*args, **kwargs) + + # 检查返回值格式 + if isinstance(result, tuple) and len(result) == 2: + success, data = result + if success: + return True, data + else: + last_error = data + # 检查是否是 429 错误 + if self._is_rate_limit_error(data): + delay = self._get_rate_limit_delay(attempt, base_delay) + else: + delay = base_delay * (2 ** attempt) + else: + # 函数返回非元组,视为成功 + return True, result + + except Exception as e: + last_error = str(e) + delay = base_delay * (2 ** attempt) + + # 如果还有重试机会,等待后重试 + if attempt < retries: + print(f"[RefreshManager] 第 {attempt + 1} 次尝试失败,{delay:.1f}秒后重试...") + await asyncio.sleep(delay) + + return False, last_error + + def _is_rate_limit_error(self, error: Any) -> bool: + """检查是否是限流错误(429) + + Args: + error: 错误信息 + + Returns: + True 表示是限流错误 + """ + if isinstance(error, str): + return "429" in error or "rate limit" in error.lower() or "请求过于频繁" in error + return False + + def _get_rate_limit_delay(self, attempt: int, base_delay: float) -> float: + """获取限流错误的等待时间 + + 429 错误使用更长的等待时间。 + + Args: + attempt: 当前尝试次数(从0开始) + base_delay: 基础延迟 + + Returns: + 等待时间(秒) + """ + # 429 错误使用 3 倍的基础延迟 + return base_delay * 3 * (2 ** attempt) + + async def refresh_all_with_token( + self, + accounts: List['Account'], + get_quota_func: Optional[Callable] = None, + skip_disabled: bool = True, + skip_error: bool = True + ) -> RefreshProgress: + """刷新所有账号(先刷新 Token,再获取额度) + + 使用全局锁防止并发刷新,支持进度跟踪。 + + Args: + accounts: 账号列表 + get_quota_func: 获取额度的异步函数 + skip_disabled: 是否跳过已禁用的账号 + skip_error: 是否跳过已处于错误状态的账号 + + Returns: + 刷新进度信息 + """ + # 尝试获取锁 + if not await self.acquire_refresh_lock(): + # 已有刷新在进行 + progress = self.get_progress() + if progress: + return progress + # 返回一个错误状态的进度 + return RefreshProgress( + total=0, + status="error", + message="刷新操作正在进行中" + ) + + try: + # 过滤账号 + accounts_to_refresh = [] + for acc in accounts: + if skip_disabled and not acc.enabled: + continue + if skip_error and acc.status.value in ("unhealthy", "suspended"): + continue + accounts_to_refresh.append(acc) + + total = len(accounts_to_refresh) + + # 开始刷新 + self._start_refresh(total, f"开始刷新 {total} 个账号") + + if total == 0: + self._finish_refresh("completed", "没有需要刷新的账号") + return self.get_progress() + + # 使用信号量控制并发 + semaphore = asyncio.Semaphore(self._config.concurrency) + + async def refresh_one(account: 'Account'): + async with semaphore: + self._update_progress( + current_account=account.id, + message=f"正在刷新: {account.name}" + ) + + # 使用重试机制刷新 + success, result = await self.retry_with_backoff( + self.refresh_account_with_token, + account, + get_quota_func + ) + + if success: + self._update_progress(success=True) + else: + self._update_progress(failed=True) + + return success, result + + # 并发执行 + tasks = [refresh_one(acc) for acc in accounts_to_refresh] + await asyncio.gather(*tasks, return_exceptions=True) + + # 完成 + self._finish_refresh("completed") + return self.get_progress() + + except Exception as e: + self._finish_refresh("error", f"刷新异常: {str(e)}") + return self.get_progress() + + finally: + self.release_refresh_lock() + + def _is_auth_error(self, error: Any) -> bool: + """检查是否是认证错误(401) + + Args: + error: 错误信息 + + Returns: + True 表示是认证错误 + """ + if isinstance(error, str): + return "401" in error or "unauthorized" in error.lower() or "凭证已过期" in error or "需要重新登录" in error + return False + + async def execute_with_auth_retry( + self, + account: 'Account', + func: Callable, + *args, + **kwargs + ) -> Tuple[bool, Any]: + """执行操作,遇到 401 错误时自动刷新 Token 并重试 + + Args: + account: 账号对象 + func: 要执行的异步函数 + *args: 传递给函数的位置参数 + **kwargs: 传递给函数的关键字参数 + + Returns: + (success, result) 元组 + """ + try: + result = await func(*args, **kwargs) + + # 检查返回值 + if isinstance(result, tuple) and len(result) == 2: + success, data = result + if success: + return True, data + + # 检查是否是 401 错误 + if self._is_auth_error(data): + print(f"[RefreshManager] 账号 {account.id} 遇到 401 错误,尝试刷新 Token...") + + # 刷新 Token + refresh_success, refresh_msg = await account.refresh_token() + + if refresh_success: + print(f"[RefreshManager] Token 刷新成功,重试请求...") + # 重试原请求 + retry_result = await func(*args, **kwargs) + if isinstance(retry_result, tuple) and len(retry_result) == 2: + return retry_result + return True, retry_result + else: + return False, f"Token 刷新失败: {refresh_msg}" + + return False, data + + return True, result + + except Exception as e: + error_str = str(e) + + # 检查异常是否是 401 错误 + if self._is_auth_error(error_str): + print(f"[RefreshManager] 账号 {account.id} 遇到 401 异常,尝试刷新 Token...") + + refresh_success, refresh_msg = await account.refresh_token() + + if refresh_success: + print(f"[RefreshManager] Token 刷新成功,重试请求...") + try: + retry_result = await func(*args, **kwargs) + if isinstance(retry_result, tuple) and len(retry_result) == 2: + return retry_result + return True, retry_result + except Exception as retry_e: + return False, f"重试失败: {str(retry_e)}" + else: + return False, f"Token 刷新失败: {refresh_msg}" + + return False, error_str + + def set_accounts_getter(self, getter: Callable) -> None: + """设置获取账号列表的回调函数 + + Args: + getter: 返回账号列表的可调用对象 + """ + self._accounts_getter = getter + + def _get_accounts(self) -> List['Account']: + """获取账号列表""" + if self._accounts_getter: + return self._accounts_getter() + return [] + + async def start_auto_refresh(self) -> None: + """启动自动 Token 刷新定时器 + + 定期检查所有账号的 Token 状态,自动刷新即将过期的 Token。 + 启动前会清除已存在的定时器,防止重复启动。 + """ + # 先停止已存在的定时器 + await self.stop_auto_refresh() + + self._auto_refresh_running = True + self._auto_refresh_task = asyncio.create_task(self._auto_refresh_loop()) + print(f"[RefreshManager] 自动 Token 刷新定时器已启动,检查间隔: {self._config.auto_refresh_interval}秒") + + async def stop_auto_refresh(self) -> None: + """停止自动 Token 刷新定时器""" + self._auto_refresh_running = False + + if self._auto_refresh_task: + self._auto_refresh_task.cancel() + try: + await self._auto_refresh_task + except asyncio.CancelledError: + pass + self._auto_refresh_task = None + print("[RefreshManager] 自动 Token 刷新定时器已停止") + + def is_auto_refresh_running(self) -> bool: + """检查自动刷新定时器是否在运行 + + Returns: + True 表示定时器正在运行 + """ + return self._auto_refresh_running and self._auto_refresh_task is not None + + async def _auto_refresh_loop(self) -> None: + """自动刷新循环 + + 定期检查所有账号的 Token 状态,刷新即将过期的 Token。 + 跳过已禁用或错误状态的账号,单个失败不影响其他账号。 + """ + while self._auto_refresh_running: + try: + await asyncio.sleep(self._config.auto_refresh_interval) + + if not self._auto_refresh_running: + break + + accounts = self._get_accounts() + if not accounts: + continue + + # 检查需要刷新的账号 + accounts_to_refresh = [] + for account in accounts: + # 跳过已禁用的账号 + if not account.enabled: + continue + + # 跳过错误状态的账号 + if hasattr(account, 'status') and account.status.value in ("unhealthy", "suspended", "disabled"): + continue + + # 检查是否需要刷新 Token + if self.should_refresh_token(account): + accounts_to_refresh.append(account) + + if accounts_to_refresh: + print(f"[RefreshManager] 发现 {len(accounts_to_refresh)} 个账号需要刷新 Token") + + # 逐个刷新,单个失败不影响其他 + for account in accounts_to_refresh: + try: + success, message = await self.refresh_token_if_needed(account) + if not success: + print(f"[RefreshManager] 账号 {account.id} 自动刷新失败: {message}") + except Exception as e: + print(f"[RefreshManager] 账号 {account.id} 自动刷新异常: {e}") + # 继续处理其他账号 + + except asyncio.CancelledError: + break + except Exception as e: + print(f"[RefreshManager] 自动刷新循环异常: {e}") + # 继续运行,不因异常停止 + + def get_auto_refresh_status(self) -> Dict[str, Any]: + """获取自动刷新状态 + + Returns: + 包含自动刷新状态信息的字典 + """ + return { + "running": self.is_auto_refresh_running(), + "interval": self._config.auto_refresh_interval, + "token_refresh_before_expiry": self._config.token_refresh_before_expiry + } + + +# 全局刷新管理器实例 +_refresh_manager: Optional[RefreshManager] = None +_manager_lock = Lock() + + +def get_refresh_manager() -> RefreshManager: + """获取全局刷新管理器实例 + + 使用单例模式,确保全局只有一个刷新管理器实例。 + + Returns: + 全局 RefreshManager 实例 + """ + global _refresh_manager + + if _refresh_manager is None: + with _manager_lock: + # 双重检查锁定 + if _refresh_manager is None: + _refresh_manager = RefreshManager() + + return _refresh_manager + + +def reset_refresh_manager() -> None: + """重置全局刷新管理器 + + 主要用于测试场景,重置全局实例。 + """ + global _refresh_manager + + with _manager_lock: + _refresh_manager = None diff --git a/KiroProxy/kiro_proxy/core/retry.py b/KiroProxy/kiro_proxy/core/retry.py new file mode 100644 index 0000000000000000000000000000000000000000..f83b000020f1ba25b27d777363685b81e986a455 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/retry.py @@ -0,0 +1,117 @@ +"""请求重试机制""" +import asyncio +from typing import Callable, Any, Optional, Set +from functools import wraps + +# 可重试的状态码 +RETRYABLE_STATUS_CODES: Set[int] = { + 408, # Request Timeout + 500, # Internal Server Error + 502, # Bad Gateway + 503, # Service Unavailable + 504, # Gateway Timeout +} + +# 不可重试的状态码(直接返回错误) +NON_RETRYABLE_STATUS_CODES: Set[int] = { + 400, # Bad Request + 401, # Unauthorized + 403, # Forbidden + 404, # Not Found + 422, # Unprocessable Entity +} + + +def is_retryable_error(status_code: Optional[int], error: Optional[Exception] = None) -> bool: + """判断是否为可重试的错误""" + # 网络错误可重试 + if error: + error_name = type(error).__name__.lower() + if any(kw in error_name for kw in ['timeout', 'connect', 'network', 'reset']): + return True + + # 特定状态码可重试 + if status_code and status_code in RETRYABLE_STATUS_CODES: + return True + + return False + + +def is_non_retryable_error(status_code: Optional[int]) -> bool: + """判断是否为不可重试的错误""" + return status_code in NON_RETRYABLE_STATUS_CODES if status_code else False + + +async def retry_async( + func: Callable, + max_retries: int = 2, + base_delay: float = 0.5, + max_delay: float = 5.0, + on_retry: Optional[Callable[[int, Exception], None]] = None +) -> Any: + """ + 异步重试装饰器 + + Args: + func: 要执行的异步函数 + max_retries: 最大重试次数 + base_delay: 基础延迟(秒) + max_delay: 最大延迟(秒) + on_retry: 重试时的回调函数 + """ + last_error = None + + for attempt in range(max_retries + 1): + try: + return await func() + except Exception as e: + last_error = e + + # 检查是否可重试 + status_code = getattr(e, 'status_code', None) + if is_non_retryable_error(status_code): + raise + + if attempt < max_retries and is_retryable_error(status_code, e): + # 指数退避 + delay = min(base_delay * (2 ** attempt), max_delay) + + if on_retry: + on_retry(attempt + 1, e) + else: + print(f"[Retry] 第 {attempt + 1} 次重试,延迟 {delay:.1f}s,错误: {type(e).__name__}") + + await asyncio.sleep(delay) + else: + raise + + raise last_error + + +class RetryableRequest: + """可重试的请求上下文""" + + def __init__(self, max_retries: int = 2, base_delay: float = 0.5): + self.max_retries = max_retries + self.base_delay = base_delay + self.attempt = 0 + self.last_error = None + + def should_retry(self, status_code: Optional[int] = None, error: Optional[Exception] = None) -> bool: + """判断是否应该重试""" + self.attempt += 1 + self.last_error = error + + if self.attempt > self.max_retries: + return False + + if is_non_retryable_error(status_code): + return False + + return is_retryable_error(status_code, error) + + async def wait(self): + """等待重试延迟""" + delay = min(self.base_delay * (2 ** (self.attempt - 1)), 5.0) + print(f"[Retry] 第 {self.attempt} 次重试,延迟 {delay:.1f}s") + await asyncio.sleep(delay) diff --git a/KiroProxy/kiro_proxy/core/scheduler.py b/KiroProxy/kiro_proxy/core/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..88e97bca5aff3935508861551f9476e9f6cc5e6d --- /dev/null +++ b/KiroProxy/kiro_proxy/core/scheduler.py @@ -0,0 +1,125 @@ +"""后台任务调度器""" +import asyncio +from typing import Optional +from datetime import datetime + + +class BackgroundScheduler: + """后台任务调度器 + + 负责: + - Token 过期预刷新 + - 账号健康检查 + - 统计数据更新 + """ + + def __init__(self): + self._task: Optional[asyncio.Task] = None + self._running = False + self._refresh_interval = 300 # 5 分钟检查一次 + self._health_check_interval = 600 # 10 分钟健康检查 + self._last_health_check = 0 + + async def start(self): + """启动后台任务""" + if self._running: + return + self._running = True + self._task = asyncio.create_task(self._run()) + print("[Scheduler] 后台任务已启动") + + async def stop(self): + """停止后台任务""" + self._running = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + print("[Scheduler] 后台任务已停止") + + async def _run(self): + """主循环""" + from . import state + import time + + while self._running: + try: + # Token 预刷新 + await self._refresh_expiring_tokens(state) + + # 健康检查 + now = time.time() + if now - self._last_health_check > self._health_check_interval: + await self._health_check(state) + self._last_health_check = now + + await asyncio.sleep(self._refresh_interval) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"[Scheduler] 错误: {e}") + await asyncio.sleep(60) + + async def _refresh_expiring_tokens(self, state): + """刷新即将过期的 Token""" + for acc in state.accounts: + if not acc.enabled: + continue + + # 提前 15 分钟刷新 + if acc.is_token_expiring_soon(15): + print(f"[Scheduler] Token 即将过期,预刷新: {acc.name}") + success, msg = await acc.refresh_token() + if success: + print(f"[Scheduler] Token 刷新成功: {acc.name}") + else: + print(f"[Scheduler] Token 刷新失败: {acc.name} - {msg}") + + async def _health_check(self, state): + """健康检查""" + import httpx + from ..config import MODELS_URL + from ..credential import CredentialStatus + + for acc in state.accounts: + if not acc.enabled: + continue + + try: + token = acc.get_token() + if not token: + acc.status = CredentialStatus.UNHEALTHY + continue + + headers = { + "Authorization": f"Bearer {token}", + "content-type": "application/json" + } + + async with httpx.AsyncClient(verify=False, timeout=10) as client: + resp = await client.get( + MODELS_URL, + headers=headers, + params={"origin": "AI_EDITOR"} + ) + + if resp.status_code == 200: + if acc.status == CredentialStatus.UNHEALTHY: + acc.status = CredentialStatus.ACTIVE + print(f"[HealthCheck] 账号恢复健康: {acc.name}") + elif resp.status_code == 401: + acc.status = CredentialStatus.UNHEALTHY + print(f"[HealthCheck] 账号认证失败: {acc.name}") + elif resp.status_code == 429: + # 配额超限,不改变状态 + pass + + except Exception as e: + print(f"[HealthCheck] 检查失败 {acc.name}: {e}") + + +# 全局调度器实例 +scheduler = BackgroundScheduler() diff --git a/KiroProxy/kiro_proxy/core/state.py b/KiroProxy/kiro_proxy/core/state.py new file mode 100644 index 0000000000000000000000000000000000000000..084b50ba383511eaf47b22aadeb21870150a0bc1 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/state.py @@ -0,0 +1,280 @@ +"""全局状态管理""" +import time +from collections import deque +from dataclasses import dataclass +from typing import Optional, List, Dict +from pathlib import Path + +from ..config import TOKEN_PATH +from ..credential import quota_manager, CredentialStatus +from .account import Account +from .persistence import load_accounts, save_accounts +from .quota_cache import get_quota_cache +from .account_selector import get_account_selector, SelectionStrategy +from .quota_scheduler import get_quota_scheduler + + +@dataclass +class RequestLog: + """请求日志""" + id: str + timestamp: float + method: str + path: str + model: str + account_id: Optional[str] + status: int + duration_ms: float + tokens_in: int = 0 + tokens_out: int = 0 + error: Optional[str] = None + + +class ProxyState: + """全局状态管理""" + + def __init__(self): + self.accounts: List[Account] = [] + self.request_logs: deque = deque(maxlen=1000) + self.total_requests: int = 0 + self.total_errors: int = 0 + self.session_locks: Dict[str, str] = {} + self.session_timestamps: Dict[str, float] = {} + self.start_time: float = time.time() + self._load_accounts() + + def _load_accounts(self): + """从配置文件加载账号""" + saved = load_accounts() + if saved: + for acc_data in saved: + # 验证 token 文件存在 + if Path(acc_data.get("token_path", "")).exists(): + self.accounts.append(Account( + id=acc_data["id"], + name=acc_data["name"], + token_path=acc_data["token_path"], + enabled=acc_data.get("enabled", True), + auto_disabled=acc_data.get("auto_disabled", False), + )) + print(f"[State] 从配置加载 {len(self.accounts)} 个账号") + + # 如果没有账号,尝试添加默认账号 + if not self.accounts and TOKEN_PATH.exists(): + self.accounts.append(Account( + id="default", + name="默认账号", + token_path=str(TOKEN_PATH) + )) + self._save_accounts() + + def _save_accounts(self): + """保存账号到配置文件""" + accounts_data = [ + { + "id": acc.id, + "name": acc.name, + "token_path": acc.token_path, + "enabled": acc.enabled, + "auto_disabled": getattr(acc, "auto_disabled", False), + } + for acc in self.accounts + ] + save_accounts(accounts_data) + + def get_available_account(self, session_id: Optional[str] = None) -> Optional[Account]: + """获取可用账号(支持会话粘性和智能选择)""" + quota_manager.cleanup_expired() + + selector = get_account_selector() + has_priority = bool(selector.get_priority_accounts()) + use_session_sticky = bool(session_id) and not has_priority and selector.strategy != SelectionStrategy.RANDOM + + # 会话粘性 + if use_session_sticky and session_id in self.session_locks: + account_id = self.session_locks[session_id] + ts = self.session_timestamps.get(session_id, 0) + if time.time() - ts < 60: + for acc in self.accounts: + if acc.id == account_id and acc.is_available(): + self.session_timestamps[session_id] = time.time() + return acc + + # 使用 AccountSelector 选择账号 + account = selector.select(self.accounts, session_id) + + if account and use_session_sticky: + self.session_locks[session_id] = account.id + self.session_timestamps[session_id] = time.time() + + # 标记为活跃账号,便于额度调度器定期更新 + if account: + try: + get_quota_scheduler().mark_active(account.id) + except Exception: + pass + + return account + + def mark_account_used(self, account_id: str) -> None: + """标记账号被使用""" + scheduler = get_quota_scheduler() + scheduler.mark_active(account_id) + + for acc in self.accounts: + if acc.id == account_id: + acc.last_used = time.time() + break + + def get_next_available_account(self, exclude_id: str) -> Optional[Account]: + """获取下一个可用账号(排除指定账号)""" + available = [a for a in self.accounts if a.is_available() and a.id != exclude_id] + if not available: + return None + account = min(available, key=lambda a: a.request_count) + try: + get_quota_scheduler().mark_active(account.id) + except Exception: + pass + return account + + def mark_rate_limited(self, account_id: str, duration_seconds: int = 60): + """标记账号限流""" + for acc in self.accounts: + if acc.id == account_id: + acc.mark_quota_exceeded("Rate limited") + break + + def mark_quota_exceeded(self, account_id: str, reason: str = "Quota exceeded"): + """标记账号配额超限""" + for acc in self.accounts: + if acc.id == account_id: + acc.mark_quota_exceeded(reason) + break + + async def refresh_account_token(self, account_id: str) -> tuple: + """刷新指定账号的 token""" + for acc in self.accounts: + if acc.id == account_id: + return await acc.refresh_token() + return False, "账号不存在" + + async def refresh_expiring_tokens(self) -> List[dict]: + """刷新所有即将过期的 token""" + results = [] + for acc in self.accounts: + if acc.enabled and acc.is_token_expiring_soon(10): + success, msg = await acc.refresh_token() + results.append({ + "account_id": acc.id, + "success": success, + "message": msg + }) + return results + + def add_log(self, log: RequestLog): + """添加请求日志""" + self.request_logs.append(log) + self.total_requests += 1 + if log.error: + self.total_errors += 1 + + def get_stats(self) -> dict: + """获取统计信息""" + uptime = time.time() - self.start_time + + # 获取额度汇总 + quota_cache = get_quota_cache() + quota_summary = quota_cache.get_summary() + + # 获取选择器状态 + selector = get_account_selector() + selector_status = selector.get_status() + + # 获取调度器状态 + scheduler = get_quota_scheduler() + scheduler_status = scheduler.get_status() + + return { + "uptime_seconds": int(uptime), + "total_requests": self.total_requests, + "total_errors": self.total_errors, + "error_rate": f"{(self.total_errors / max(1, self.total_requests) * 100):.1f}%", + "accounts_total": len(self.accounts), + "accounts_available": len([a for a in self.accounts if a.is_available()]), + "accounts_cooldown": len([a for a in self.accounts if a.status == CredentialStatus.COOLDOWN]), + "recent_logs": len(self.request_logs), + # 新增字段 + "quota_summary": quota_summary, + "selector": selector_status, + "scheduler": scheduler_status, + } + + def get_accounts_status(self) -> List[dict]: + """获取所有账号状态""" + return [acc.get_status_info() for acc in self.accounts] + + def get_accounts_summary(self) -> dict: + """获取账号汇总统计""" + quota_cache = get_quota_cache() + selector = get_account_selector() + scheduler = get_quota_scheduler() + + total_balance = 0.0 + total_usage = 0.0 + total_limit = 0.0 + + available_count = 0 + cooldown_count = 0 + unhealthy_count = 0 + disabled_count = 0 + + for acc in self.accounts: + if not acc.enabled: + disabled_count += 1 + elif acc.status == CredentialStatus.COOLDOWN: + cooldown_count += 1 + elif acc.status == CredentialStatus.UNHEALTHY: + unhealthy_count += 1 + elif acc.is_available(): + available_count += 1 + + quota = quota_cache.get(acc.id) + if quota and not quota.has_error(): + total_balance += quota.balance + total_usage += quota.current_usage + total_limit += quota.usage_limit + + last_refresh = scheduler.get_last_full_refresh() + last_refresh_ago = None + if last_refresh: + seconds_ago = time.time() - last_refresh + if seconds_ago < 60: + last_refresh_ago = f"{int(seconds_ago)}秒前" + elif seconds_ago < 3600: + last_refresh_ago = f"{int(seconds_ago / 60)}分钟前" + else: + last_refresh_ago = f"{int(seconds_ago / 3600)}小时前" + + return { + "total_accounts": len(self.accounts), + "available_accounts": available_count, + "cooldown_accounts": cooldown_count, + "unhealthy_accounts": unhealthy_count, + "disabled_accounts": disabled_count, + "total_balance": round(total_balance, 2), + "total_usage": round(total_usage, 2), + "total_limit": round(total_limit, 2), + "last_refresh": last_refresh_ago, + "last_refresh_timestamp": last_refresh, + "strategy": selector.strategy.value, + "priority_accounts": selector.get_priority_accounts(), + } + + def get_valid_account_ids(self) -> set: + """获取所有有效账号ID集合""" + return {acc.id for acc in self.accounts if acc.enabled} + + +# 全局状态实例 +state = ProxyState() diff --git a/KiroProxy/kiro_proxy/core/stats.py b/KiroProxy/kiro_proxy/core/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c7554a092282f3cf8c1b5e42bfe73e8fd27c6b --- /dev/null +++ b/KiroProxy/kiro_proxy/core/stats.py @@ -0,0 +1,130 @@ +"""请求统计增强""" +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List +import time + + +@dataclass +class AccountStats: + """账号统计""" + total_requests: int = 0 + total_errors: int = 0 + total_tokens_in: int = 0 + total_tokens_out: int = 0 + last_request_time: float = 0 + + def record(self, success: bool, tokens_in: int = 0, tokens_out: int = 0): + self.total_requests += 1 + if not success: + self.total_errors += 1 + self.total_tokens_in += tokens_in + self.total_tokens_out += tokens_out + self.last_request_time = time.time() + + @property + def error_rate(self) -> float: + if self.total_requests == 0: + return 0 + return self.total_errors / self.total_requests + + +@dataclass +class ModelStats: + """模型统计""" + total_requests: int = 0 + total_errors: int = 0 + total_latency_ms: float = 0 + + def record(self, success: bool, latency_ms: float): + self.total_requests += 1 + if not success: + self.total_errors += 1 + self.total_latency_ms += latency_ms + + @property + def avg_latency_ms(self) -> float: + if self.total_requests == 0: + return 0 + return self.total_latency_ms / self.total_requests + + +class StatsManager: + """统计管理器""" + + def __init__(self): + self.by_account: Dict[str, AccountStats] = defaultdict(AccountStats) + self.by_model: Dict[str, ModelStats] = defaultdict(ModelStats) + self.hourly_requests: Dict[int, int] = defaultdict(int) # hour -> count + + def record_request( + self, + account_id: str, + model: str, + success: bool, + latency_ms: float, + tokens_in: int = 0, + tokens_out: int = 0 + ): + """记录请求""" + # 按账号统计 + self.by_account[account_id].record(success, tokens_in, tokens_out) + + # 按模型统计 + self.by_model[model].record(success, latency_ms) + + # 按小时统计 + hour = int(time.time() // 3600) + self.hourly_requests[hour] += 1 + + # 清理旧数据(保留 24 小时) + self._cleanup_hourly() + + def _cleanup_hourly(self): + """清理超过 24 小时的数据""" + current_hour = int(time.time() // 3600) + cutoff = current_hour - 24 + self.hourly_requests = defaultdict( + int, + {h: c for h, c in self.hourly_requests.items() if h > cutoff} + ) + + def get_account_stats(self, account_id: str) -> dict: + """获取账号统计""" + stats = self.by_account.get(account_id, AccountStats()) + return { + "total_requests": stats.total_requests, + "total_errors": stats.total_errors, + "error_rate": f"{stats.error_rate * 100:.1f}%", + "total_tokens_in": stats.total_tokens_in, + "total_tokens_out": stats.total_tokens_out, + "last_request": stats.last_request_time + } + + def get_model_stats(self, model: str) -> dict: + """获取模型统计""" + stats = self.by_model.get(model, ModelStats()) + return { + "total_requests": stats.total_requests, + "total_errors": stats.total_errors, + "avg_latency_ms": round(stats.avg_latency_ms, 2) + } + + def get_all_stats(self) -> dict: + """获取所有统计""" + return { + "by_account": { + acc_id: self.get_account_stats(acc_id) + for acc_id in self.by_account + }, + "by_model": { + model: self.get_model_stats(model) + for model in self.by_model + }, + "hourly_requests": dict(self.hourly_requests), + "requests_last_24h": sum(self.hourly_requests.values()) + } + + +# 全局统计实例 +stats_manager = StatsManager() diff --git a/KiroProxy/kiro_proxy/core/thinking.py b/KiroProxy/kiro_proxy/core/thinking.py new file mode 100644 index 0000000000000000000000000000000000000000..d23f955134d3ae42a45e814cc2d9443e74800c0c --- /dev/null +++ b/KiroProxy/kiro_proxy/core/thinking.py @@ -0,0 +1,456 @@ +"""Thinking / Extended Thinking helpers. + +This project implements "thinking" at the proxy layer by: +1) Making a separate Kiro request to generate internal reasoning text. +2) Injecting that reasoning back into the main user prompt (hidden) to improve quality. +3) Optionally returning the reasoning to clients in protocol-appropriate formats. + +Notes: +- Kiro's upstream API doesn't expose a native "thinking budget" knob, so `budget_tokens` + is enforced only via prompt instructions (best-effort). +- If the client does not provide a budget, we treat it as "unlimited" (no prompt limit). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, AsyncIterator, Optional + +import json + +import httpx + +from ..config import KIRO_API_URL +from ..kiro_api import build_kiro_request, parse_event_stream + + +@dataclass(frozen=True) +class ThinkingConfig: + enabled: bool + budget_tokens: Optional[int] = None # None == unlimited + + +def _coerce_bool(value: Any) -> Optional[bool]: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + v = value.strip().lower() + if v in {"true", "1", "yes", "y", "on", "enabled"}: + return True + if v in {"false", "0", "no", "n", "off", "disabled"}: + return False + return None + + +def _coerce_int(value: Any) -> Optional[int]: + if value is None: + return None + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + v = value.strip() + if not v: + return None + try: + return int(v) + except ValueError: + return None + return None + + +def normalize_thinking_config(raw: Any) -> ThinkingConfig: + """Normalize multiple "thinking" shapes into a single config. + + Supported shapes (best-effort): + - None / missing: disabled + - bool: enabled/disabled + - str: "enabled"/"disabled" + - dict: + - {"type": "enabled", "budget_tokens": 20000} (Anthropic style) + - {"thinking_type": "enabled", "budget_tokens": 20000} (legacy) + - {"enabled": true, "budget_tokens": 20000} + - {"includeThoughts": true, "thinkingBudget": 20000} (Gemini-ish) + """ + if raw is None: + return ThinkingConfig(enabled=False, budget_tokens=None) + + bool_value = _coerce_bool(raw) + if bool_value is not None and not isinstance(raw, dict): + return ThinkingConfig(enabled=bool_value, budget_tokens=None) + + if isinstance(raw, dict): + mode = raw.get("type") or raw.get("thinking_type") or raw.get("mode") + enabled = None + if isinstance(mode, str): + enabled = _coerce_bool(mode) + if enabled is None: + enabled = _coerce_bool(raw.get("enabled")) + if enabled is None: + enabled = _coerce_bool(raw.get("includeThoughts") or raw.get("include_thoughts")) + if enabled is None: + enabled = False + + budget_tokens = None + for key in ( + "budget_tokens", + "budgetTokens", + "thinkingBudget", + "thinking_budget", + "max_thinking_length", + "maxThinkingLength", + ): + if key in raw: + budget_tokens = _coerce_int(raw.get(key)) + break + if budget_tokens is not None and budget_tokens <= 0: + budget_tokens = None + + return ThinkingConfig(enabled=bool(enabled), budget_tokens=budget_tokens) + + if isinstance(raw, str): + enabled = _coerce_bool(raw) + return ThinkingConfig(enabled=bool(enabled), budget_tokens=None) + + return ThinkingConfig(enabled=False, budget_tokens=None) + + +def map_openai_reasoning_effort_to_budget(effort: Any) -> Optional[int]: + """Map OpenAI-style reasoning effort into a best-effort budget. + + We keep this generous; if effort is "high", treat as unlimited. + """ + if not isinstance(effort, str): + return None + v = effort.strip().lower() + if v in {"high"}: + return None + if v in {"medium"}: + return 20000 + if v in {"low"}: + return 10000 + return None + + +def extract_thinking_config_from_openai_body(body: dict) -> tuple[ThinkingConfig, bool]: + """Extract thinking config from OpenAI ChatCompletions/Responses-style bodies.""" + if not isinstance(body, dict): + return ThinkingConfig(False, None), False + + if "thinking" in body: + return normalize_thinking_config(body.get("thinking")), True + + # OpenAI Responses API style + reasoning = body.get("reasoning") + if "reasoning" in body: + if isinstance(reasoning, dict): + effort = reasoning.get("effort") + if isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}: + return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True + cfg = normalize_thinking_config(reasoning) + return cfg, True + + effort = body.get("reasoning_effort") + if "reasoning_effort" in body and isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}: + return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True + + return ThinkingConfig(False, None), False + + +def extract_thinking_config_from_gemini_body(body: dict) -> tuple[ThinkingConfig, bool]: + """Extract thinking config from Gemini generateContent bodies (best-effort).""" + if not isinstance(body, dict): + return ThinkingConfig(False, None), False + + if "thinking" in body: + return normalize_thinking_config(body.get("thinking")), True + + if "thinkingConfig" in body: + return normalize_thinking_config(body.get("thinkingConfig")), True + + gen_cfg = body.get("generationConfig") + if isinstance(gen_cfg, dict): + if "thinkingConfig" in gen_cfg: + raw = gen_cfg.get("thinkingConfig") + cfg = normalize_thinking_config(raw) + if cfg.enabled: + return cfg, True + # Budget without explicit includeThoughts/mode: treat as enabled (client guidance exists) + if isinstance(raw, dict) and any( + k in raw for k in ("thinkingBudget", "budgetTokens", "budget_tokens", "max_thinking_length") + ): + return ThinkingConfig(True, cfg.budget_tokens), True + return cfg, True + + return ThinkingConfig(False, None), False + + +def infer_thinking_from_anthropic_messages(messages: list[dict]) -> bool: + """推断历史消息中是否包含思维链内容,用于在客户端未明确指定时自动启用思维链""" + for msg in messages or []: + content = msg.get("content") + if not isinstance(content, list): + continue + for block in content: + if isinstance(block, dict): + # 检查标准的 thinking 块 + if block.get("type") == "thinking": + return True + # 检查文本块中嵌入的 标签(assistant 消息中可能存在) + if block.get("type") == "text" and msg.get("role") == "assistant": + text = block.get("text", "") + if isinstance(text, str) and "" in text and "" in text: + return True + return False + + +def infer_thinking_from_openai_messages(messages: list[dict]) -> bool: + for msg in messages or []: + content = msg.get("content", "") + if isinstance(content, str): + if "" in content and "" in content: + return True + continue + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text", "") + if "" in text and "" in text: + return True + return False + + +def infer_thinking_from_openai_responses_input(input_data: Any) -> bool: + """Infer thinking from OpenAI Responses API `input` payloads (best-effort).""" + if isinstance(input_data, str): + return "" in input_data and "" in input_data + + if not isinstance(input_data, list): + return False + + for item in input_data: + if not isinstance(item, dict): + continue + if item.get("type") != "message": + continue + + content_list = item.get("content", []) or [] + for c in content_list: + if isinstance(c, str): + if "" in c and "" in c: + return True + continue + if not isinstance(c, dict): + continue + c_type = c.get("type") + if c_type in {"input_text", "output_text", "text"}: + text = c.get("text", "") + if isinstance(text, str) and "" in text and "" in text: + return True + return False + + +def infer_thinking_from_gemini_contents(contents: list[dict]) -> bool: + for item in contents or []: + for part in item.get("parts", []) or []: + if isinstance(part, dict) and isinstance(part.get("text"), str): + text = part["text"] + if "" in text and "" in text: + return True + return False + + +import re + +_THINKING_PATTERN = re.compile(r".*?\s*", re.DOTALL) + + +def strip_thinking_from_text(text: str) -> str: + """Remove blocks from text.""" + if not text or not isinstance(text, str): + return text + return _THINKING_PATTERN.sub("", text).strip() + + +def strip_thinking_from_history(history: list) -> list: + """Return a copy of history with blocks removed from all messages.""" + if not history: + return [] + + cleaned = [] + for msg in history: + if not isinstance(msg, dict): + cleaned.append(msg) + continue + + new_msg = msg.copy() + content = msg.get("content") + + if isinstance(content, str): + new_msg["content"] = strip_thinking_from_text(content) + elif isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + new_part = part.copy() + new_part["text"] = strip_thinking_from_text(part.get("text", "")) + new_content.append(new_part) + else: + new_content.append(part) + new_msg["content"] = new_content + + cleaned.append(new_msg) + + return cleaned + + +def format_thinking_block(thinking_content: str) -> str: + if thinking_content is None: + return "" + thinking_content = str(thinking_content).strip() + if not thinking_content: + return "" + return f"\n{thinking_content}\n" + + +def build_thinking_prompt(user_content: str, *, budget_tokens: Optional[int]) -> str: + """Build a separate prompt using Tree of Thoughts approach. + + Use multiple expert perspectives to analyze the problem deeply. + """ + if user_content is None: + user_content = "" + + budget_str = "" + if budget_tokens: + budget_str = f" Budget: {budget_tokens} tokens." + + return ( + f"Think deeply and comprehensively about this problem.{budget_str}\n\n" + "Use the following approach:\n" + "1. Break down the problem into components\n" + "2. Consider multiple perspectives and solutions\n" + "3. Evaluate trade-offs and edge cases\n" + "4. Synthesize your analysis into a coherent response\n\n" + f"{user_content}" + ) + +def build_user_prompt_with_thinking(user_content: str, thinking_content: str) -> str: + """Inject thinking into the main prompt. + + Minimal injection to avoid context pollution. + """ + if user_content is None: + user_content = "" + + thinking_block = format_thinking_block(thinking_content) + if not thinking_block: + return user_content + + return f"{thinking_block}\n\n{user_content}" + + +async def iter_aws_event_stream_text(byte_iter: AsyncIterator[bytes]) -> AsyncIterator[str]: + """Yield incremental text content from AWS event-stream chunks.""" + buffer = b"" + + async for chunk in byte_iter: + buffer += chunk + + while len(buffer) >= 12: + total_len = int.from_bytes(buffer[0:4], "big") + + if total_len <= 0: + return + if len(buffer) < total_len: + break + + headers_len = int.from_bytes(buffer[4:8], "big") + payload_start = 12 + headers_len + payload_end = total_len - 4 + + if payload_start < payload_end: + try: + payload = json.loads(buffer[payload_start:payload_end].decode("utf-8")) + content = None + if "assistantResponseEvent" in payload: + content = payload["assistantResponseEvent"].get("content") + elif "content" in payload and "toolUseId" not in payload: + content = payload.get("content") + if content: + yield content + except Exception: + pass + + buffer = buffer[total_len:] + + +async def fetch_thinking_text( + *, + headers: dict, + model: str, + user_content: str, + history: list, + images: list | None = None, + tool_results: list | None = None, + budget_tokens: Optional[int] = None, + timeout_s: float = 600.0, +) -> str: + """Non-streaming helper to get thinking content (best-effort).""" + thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens) + clean_history = strip_thinking_from_history(history) + thinking_request = build_kiro_request( + thinking_prompt, + model, + clean_history, + tools=None, + images=images, + tool_results=tool_results, + ) + + try: + async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client: + resp = await client.post(KIRO_API_URL, json=thinking_request, headers=headers) + if resp.status_code != 200: + return "" + return parse_event_stream(resp.content) + except Exception: + return "" + + +async def stream_thinking_text( + *, + headers: dict, + model: str, + user_content: str, + history: list, + images: list | None = None, + tool_results: list | None = None, + budget_tokens: Optional[int] = None, + timeout_s: float = 600.0, +) -> AsyncIterator[str]: + """Streaming helper to yield thinking content incrementally (best-effort).""" + thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens) + clean_history = strip_thinking_from_history(history) + thinking_request = build_kiro_request( + thinking_prompt, + model, + clean_history, + tools=None, + images=images, + tool_results=tool_results, + ) + + async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client: + async with client.stream( + "POST", KIRO_API_URL, json=thinking_request, headers=headers + ) as response: + if response.status_code != 200: + return + async for piece in iter_aws_event_stream_text(response.aiter_bytes()): + yield piece diff --git a/KiroProxy/kiro_proxy/core/usage.py b/KiroProxy/kiro_proxy/core/usage.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0231f49e9c9aee89a2a15cedb87fb71877f859 --- /dev/null +++ b/KiroProxy/kiro_proxy/core/usage.py @@ -0,0 +1,235 @@ +"""Kiro 用量查询服务 + +通过调用 AWS Q 的 getUsageLimits API 获取用户的用量信息。 +""" +import uuid +import httpx +from dataclasses import dataclass +from typing import Optional, Tuple + + +# API 端点 +USAGE_LIMITS_URL = "https://q.us-east-1.amazonaws.com/getUsageLimits" + +# 低余额阈值 (20%) +LOW_BALANCE_THRESHOLD = 0.2 + + +@dataclass +class UsageInfo: + """用量信息""" + subscription_title: str = "" + usage_limit: float = 0.0 + current_usage: float = 0.0 + balance: float = 0.0 + is_low_balance: bool = False + + # 详细信息 + free_trial_limit: float = 0.0 + free_trial_usage: float = 0.0 + bonus_limit: float = 0.0 + bonus_usage: float = 0.0 + + # 重置和过期时间 + next_reset_date: Optional[str] = None # 下次重置时间 + free_trial_expiry: Optional[str] = None # 免费试用过期时间 + bonus_expiries: list = None # 奖励过期时间列表 + + def __post_init__(self): + if self.bonus_expiries is None: + self.bonus_expiries = [] + + +def build_usage_api_url(auth_method: str, profile_arn: Optional[str] = None) -> str: + """构造 API 请求 URL""" + url = f"{USAGE_LIMITS_URL}?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST" + + # Social 认证需要 profileArn + if auth_method == "social" and profile_arn: + from urllib.parse import quote + url += f"&profileArn={quote(profile_arn)}" + + return url + + +def build_usage_headers( + access_token: str, + machine_id: str, + kiro_version: str = "1.0.0" +) -> dict: + """构造请求头""" + import platform + os_name = platform.system().lower() + + return { + "Authorization": f"Bearer {access_token}", + "User-Agent": f"aws-sdk-js/1.0.0 ua/2.1 os/{os_name} lang/python api/codewhispererruntime#1.0.0 m/N,E KiroIDE-{kiro_version}-{machine_id}", + "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "amz-sdk-request": "attempt=1; max=1", + "Connection": "close", + } + + +def calculate_balance(response: dict) -> UsageInfo: + """从 API 响应计算余额 + + 注意:只计算 resourceType 为 CREDIT 的额度,忽略其他类型(如 AGENTIC_REQUEST) + """ + subscription_info = response.get("subscriptionInfo", {}) + usage_breakdown_list = response.get("usageBreakdownList", []) + + total_limit = 0.0 + total_usage = 0.0 + free_trial_limit = 0.0 + free_trial_usage = 0.0 + bonus_limit = 0.0 + bonus_usage = 0.0 + + # 重置和过期时间 + next_reset_date = response.get("nextDateReset") # 下次重置时间 + free_trial_expiry = None + bonus_expiries = [] + + # 只查找 CREDIT 类型的额度 + credit_breakdown = None + for breakdown in usage_breakdown_list: + resource_type = breakdown.get("resourceType", "") + display_name = breakdown.get("displayName", "") + if resource_type == "CREDIT" or display_name == "Credits": + credit_breakdown = breakdown + break + + if credit_breakdown: + # 基本额度 (优先使用带精度的值) + total_limit = credit_breakdown.get("usageLimitWithPrecision", 0.0) or credit_breakdown.get("usageLimit", 0.0) + total_usage = credit_breakdown.get("currentUsageWithPrecision", 0.0) or credit_breakdown.get("currentUsage", 0.0) + + # 免费试用额度 (只有状态为 ACTIVE 时才计算) + free_trial = credit_breakdown.get("freeTrialInfo") + if free_trial and free_trial.get("freeTrialStatus") == "ACTIVE": + ft_limit = free_trial.get("usageLimitWithPrecision", 0.0) or free_trial.get("usageLimit", 0.0) + ft_usage = free_trial.get("currentUsageWithPrecision", 0.0) or free_trial.get("currentUsage", 0.0) + total_limit += ft_limit + total_usage += ft_usage + free_trial_limit = ft_limit + free_trial_usage = ft_usage + # 获取免费试用过期时间 + free_trial_expiry = free_trial.get("freeTrialExpiry") + + # 奖励额度 (只计算状态为 ACTIVE 的奖励) + bonuses = credit_breakdown.get("bonuses", []) + for bonus in bonuses or []: + if bonus.get("status") == "ACTIVE": + b_limit = bonus.get("usageLimitWithPrecision", 0.0) or bonus.get("usageLimit", 0.0) + b_usage = bonus.get("currentUsageWithPrecision", 0.0) or bonus.get("currentUsage", 0.0) + total_limit += b_limit + total_usage += b_usage + bonus_limit += b_limit + bonus_usage += b_usage + # 获取奖励过期时间 + expires_at = bonus.get("expiresAt") + if expires_at: + bonus_expiries.append(expires_at) + + balance = total_limit - total_usage + is_low = (balance / total_limit) < LOW_BALANCE_THRESHOLD if total_limit > 0 else False + + return UsageInfo( + subscription_title=subscription_info.get("subscriptionTitle", "Unknown"), + usage_limit=total_limit, + current_usage=total_usage, + balance=balance, + is_low_balance=is_low, + free_trial_limit=free_trial_limit, + free_trial_usage=free_trial_usage, + bonus_limit=bonus_limit, + bonus_usage=bonus_usage, + next_reset_date=next_reset_date, + free_trial_expiry=free_trial_expiry, + bonus_expiries=bonus_expiries, + ) + + +async def get_usage_limits( + access_token: str, + auth_method: str = "social", + profile_arn: Optional[str] = None, + machine_id: str = "", + kiro_version: str = "1.0.0", +) -> Tuple[bool, UsageInfo | dict]: + """ + 获取 Kiro 用量信息 + + Args: + access_token: Bearer token + auth_method: 认证方式 ("social" 或 "idc") + profile_arn: Social 认证需要的 profileArn + machine_id: 设备 ID + kiro_version: Kiro 版本号 + + Returns: + (success, UsageInfo or error_dict) + """ + if not access_token: + return False, {"error": "缺少 access token"} + + if not machine_id: + return False, {"error": "缺少 machine ID"} + + # 构造 URL 和请求头 + url = build_usage_api_url(auth_method, profile_arn) + headers = build_usage_headers(access_token, machine_id, kiro_version) + + try: + async with httpx.AsyncClient(timeout=10, verify=False) as client: + response = await client.get(url, headers=headers) + + if response.status_code != 200: + return False, {"error": f"API 请求失败: {response.status_code} - {response.text[:200]}"} + + data = response.json() + usage_info = calculate_balance(data) + return True, usage_info + + except httpx.TimeoutException: + return False, {"error": "请求超时"} + except Exception as e: + return False, {"error": f"请求失败: {str(e)}"} + + +async def get_account_usage(account) -> Tuple[bool, UsageInfo | dict]: + """ + 获取指定账号的用量信息 + + Args: + account: Account 对象 + + Returns: + (success, UsageInfo or error_dict) + """ + from ..credential import get_kiro_version + from .refresh_manager import get_refresh_manager + + creds = account.get_credentials() + if not creds: + return False, {"error": "无法获取凭证"} + + # 先刷新 Token(如即将过期/已过期),避免额度获取失败 + refresh_manager = get_refresh_manager() + if refresh_manager.should_refresh_token(account): + token_success, token_msg = await refresh_manager.refresh_token_if_needed(account) + if not token_success: + return False, {"error": f"Token 刷新失败: {token_msg}"} + + token = account.get_token() + if not token: + return False, {"error": "无法获取 token"} + + return await get_usage_limits( + access_token=token, + auth_method=creds.auth_method or "social", + profile_arn=creds.profile_arn, + machine_id=account.get_machine_id(), + kiro_version=get_kiro_version(), + ) diff --git a/KiroProxy/kiro_proxy/credential/__init__.py b/KiroProxy/kiro_proxy/credential/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66329cf29df4cf45a5c5fed0399573da7ab1e04c --- /dev/null +++ b/KiroProxy/kiro_proxy/credential/__init__.py @@ -0,0 +1,17 @@ +"""凭证管理模块""" +from .fingerprint import generate_machine_id, get_kiro_version, get_system_info +from .quota import QuotaManager, QuotaRecord, quota_manager +from .refresher import TokenRefresher +from .types import KiroCredentials, CredentialStatus + +__all__ = [ + "generate_machine_id", + "get_kiro_version", + "get_system_info", + "QuotaManager", + "QuotaRecord", + "quota_manager", + "TokenRefresher", + "KiroCredentials", + "CredentialStatus", +] diff --git a/KiroProxy/kiro_proxy/credential/fingerprint.py b/KiroProxy/kiro_proxy/credential/fingerprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c1088e9e7478dfee05e4ef77e9bc5e6d2d065967 --- /dev/null +++ b/KiroProxy/kiro_proxy/credential/fingerprint.py @@ -0,0 +1,131 @@ +"""设备指纹生成""" +import hashlib +import platform +import subprocess +import time +from pathlib import Path +from typing import Optional + + +def get_raw_machine_id() -> Optional[str]: + """获取系统原始 Machine ID""" + system = platform.system() + + try: + if system == "Darwin": + result = subprocess.run( + ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice"], + capture_output=True, text=True, timeout=5 + ) + for line in result.stdout.split("\n"): + if "IOPlatformUUID" in line: + return line.split("=")[1].strip().strip('"').lower() + + elif system == "Linux": + for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"]: + if Path(path).exists(): + return Path(path).read_text().strip().lower() + + elif system == "Windows": + result = subprocess.run( + ["wmic", "csproduct", "get", "UUID"], + capture_output=True, text=True, timeout=5, + creationflags=0x08000000 + ) + lines = [l.strip() for l in result.stdout.split("\n") if l.strip()] + if len(lines) > 1: + return lines[1].lower() + except Exception: + pass + + return None + + +def generate_machine_id( + profile_arn: Optional[str] = None, + client_id: Optional[str] = None +) -> str: + """生成基于凭证的唯一 Machine ID + + 每个凭证生成独立的 Machine ID,避免多账号共用同一指纹被检测。 + 优先级:profileArn > clientId > 系统硬件 ID + 添加时间因子:按小时变化,避免指纹完全固化。 + """ + unique_key = None + if profile_arn: + unique_key = profile_arn + elif client_id: + unique_key = client_id + else: + unique_key = get_raw_machine_id() or "KIRO_DEFAULT_MACHINE" + + hour_slot = int(time.time()) // 3600 + + hasher = hashlib.sha256() + hasher.update(unique_key.encode()) + hasher.update(hour_slot.to_bytes(8, 'little')) + + return hasher.hexdigest() + + +def get_kiro_version() -> str: + """获取 Kiro IDE 版本号 + + 优先检测本地安装的 Kiro,否则使用默认版本 (与 kiro.rs 保持一致) + """ + if platform.system() == "Darwin": + kiro_paths = [ + "/Applications/Kiro.app/Contents/Info.plist", + str(Path.home() / "Applications/Kiro.app/Contents/Info.plist"), + ] + for plist_path in kiro_paths: + try: + result = subprocess.run( + ["defaults", "read", plist_path, "CFBundleShortVersionString"], + capture_output=True, text=True, timeout=5 + ) + version = result.stdout.strip() + if version: + return version + except Exception: + pass + + # 默认版本与 kiro.rs 保持一致 + return "0.8.0" + + +def get_system_info() -> tuple: + """获取系统运行时信息 (os_name, node_version) + + node_version 与 kiro.rs 保持一致 + """ + system = platform.system() + + if system == "Darwin": + try: + result = subprocess.run( + ["sw_vers", "-productVersion"], + capture_output=True, text=True, timeout=5 + ) + version = result.stdout.strip() or "14.0" + os_name = f"macos#{version}" + except Exception: + os_name = "macos#14.0" + elif system == "Linux": + try: + result = subprocess.run( + ["uname", "-r"], + capture_output=True, text=True, timeout=5 + ) + version = result.stdout.strip() or "5.15.0" + os_name = f"linux#{version}" + except Exception: + os_name = "linux#5.15.0" + elif system == "Windows": + os_name = "windows#10.0" + else: + os_name = "other#1.0" + + # Node 版本与 kiro.rs 保持一致 + node_version = "22.11.0" + return os_name, node_version diff --git a/KiroProxy/kiro_proxy/credential/quota.py b/KiroProxy/kiro_proxy/credential/quota.py new file mode 100644 index 0000000000000000000000000000000000000000..5442de1dffa7593124d2e605fee74f3b7dfdd448 --- /dev/null +++ b/KiroProxy/kiro_proxy/credential/quota.py @@ -0,0 +1,100 @@ +"""配额管理""" +import time +from dataclasses import dataclass +from typing import Dict, Optional + + +@dataclass +class QuotaRecord: + """配额超限记录""" + credential_id: str + exceeded_at: float + cooldown_until: float + reason: str + + +class QuotaManager: + """配额管理器 + + 管理凭证的配额超限状态: + - 仅在收到 429 错误时触发冷却 + - 自动管理冷却时间:固定 5 分钟(300秒) + - 自动清理过期的冷却状态 + """ + + # 固定冷却时间(秒)- 429 错误自动冷却 5 分钟 + COOLDOWN_SECONDS = 300 + + def __init__(self): + self.exceeded_records: Dict[str, QuotaRecord] = {} + + def is_429_error(self, status_code: Optional[int]) -> bool: + """检查是否为 429 错误(仅 429 触发冷却)""" + return status_code == 429 + + def is_quota_exceeded_error(self, status_code: Optional[int], error_message: str) -> bool: + """检查是否为配额超限错误(仅用于判断是否切换账号,不触发冷却)""" + # 仅 429 才算配额超限 + return status_code == 429 + + def mark_exceeded(self, credential_id: str, reason: str) -> QuotaRecord: + """标记凭证为配额超限(仅 429 时调用) + + 自动管理冷却时间:固定 5 分钟(300秒) + """ + now = time.time() + + record = QuotaRecord( + credential_id=credential_id, + exceeded_at=now, + cooldown_until=now + self.COOLDOWN_SECONDS, + reason=reason + ) + self.exceeded_records[credential_id] = record + + print(f"[QuotaManager] 账号 {credential_id} 遇到 429 错误,自动冷却 {self.COOLDOWN_SECONDS} 秒(5分钟)") + return record + + def is_available(self, credential_id: str) -> bool: + """检查凭证是否可用""" + record = self.exceeded_records.get(credential_id) + if not record: + return True + + if time.time() >= record.cooldown_until: + del self.exceeded_records[credential_id] + return True + + return False + + def get_cooldown_remaining(self, credential_id: str) -> Optional[int]: + """获取剩余冷却时间(秒)""" + record = self.exceeded_records.get(credential_id) + if not record: + return None + + remaining = record.cooldown_until - time.time() + if remaining <= 0: + del self.exceeded_records[credential_id] + return None + + return int(remaining) + + def cleanup_expired(self) -> int: + """清理过期的冷却记录""" + now = time.time() + expired = [k for k, v in self.exceeded_records.items() if now >= v.cooldown_until] + for k in expired: + del self.exceeded_records[k] + return len(expired) + + def restore(self, credential_id: str) -> bool: + """手动恢复凭证""" + if credential_id in self.exceeded_records: + del self.exceeded_records[credential_id] + return True + return False + + +# 全局实例 - 429 自动冷却 5 分钟 +quota_manager = QuotaManager() diff --git a/KiroProxy/kiro_proxy/credential/refresher.py b/KiroProxy/kiro_proxy/credential/refresher.py new file mode 100644 index 0000000000000000000000000000000000000000..4e438e1dcd00fcf6bbba5df8853f545b3270e6b4 --- /dev/null +++ b/KiroProxy/kiro_proxy/credential/refresher.py @@ -0,0 +1,195 @@ +"""Token 刷新器""" +import httpx +from datetime import datetime, timezone, timedelta +from typing import Tuple + +from .types import KiroCredentials +from .fingerprint import generate_machine_id, get_kiro_version + + +# Kiro Auth 端点 +KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev" + + +class TokenRefresher: + """Token 刷新器""" + + def __init__(self, credentials: KiroCredentials): + self.credentials = credentials + + def get_refresh_url(self) -> str: + """获取刷新 URL""" + region = self.credentials.region or "us-east-1" + auth_method = (self.credentials.auth_method or "social").lower() + + if auth_method == "idc": + # IDC (AWS Builder ID) 使用 OIDC 端点 + return f"https://oidc.{region}.amazonaws.com/token" + else: + # Social (Google/GitHub) 使用 Kiro Auth 端点 + return f"{KIRO_AUTH_ENDPOINT}/refreshToken" + + def validate_refresh_token(self) -> Tuple[bool, str]: + """验证 refresh_token 有效性""" + refresh_token = self.credentials.refresh_token + + if not refresh_token: + return False, "缺少 refresh_token" + + if len(refresh_token.strip()) == 0: + return False, "refresh_token 为空" + + if len(refresh_token) < 100 or refresh_token.endswith("..."): + return False, f"refresh_token 已被截断(长度: {len(refresh_token)})" + + return True, "" + + def _get_machine_id(self) -> str: + """获取 Machine ID""" + return generate_machine_id( + self.credentials.profile_arn, + self.credentials.client_id + ) + + async def refresh_social_token(self) -> Tuple[bool, str]: + """ + 刷新 Social Token (Google/GitHub) + + 参考 Kiro-account-manager 实现: + - 端点: https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken + - 请求体: {"refreshToken": refresh_token} + - 响应: {accessToken, refreshToken, expiresIn} + """ + refresh_url = f"{KIRO_AUTH_ENDPOINT}/refreshToken" + + body = {"refreshToken": self.credentials.refresh_token} + headers = { + "Content-Type": "application/json", + "User-Agent": "kiro-proxy/1.0.0", + "Accept": "application/json", + } + + try: + async with httpx.AsyncClient(verify=False, timeout=30) as client: + resp = await client.post(refresh_url, json=body, headers=headers) + + if resp.status_code != 200: + error_text = resp.text + if resp.status_code == 401: + return False, "凭证已过期或无效,需要重新登录" + elif resp.status_code == 429: + return False, "请求过于频繁,请稍后重试" + else: + return False, f"刷新失败: {resp.status_code} - {error_text[:200]}" + + data = resp.json() + + new_token = data.get("accessToken") + if not new_token: + return False, "响应中没有 accessToken" + + # 更新凭证 + self.credentials.access_token = new_token + + # 更新 refreshToken(如果服务器返回了新的) + if rt := data.get("refreshToken"): + self.credentials.refresh_token = rt + + # 更新过期时间 + if expires_in := data.get("expiresIn"): + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + self.credentials.expires_at = expires_at.isoformat() + + self.credentials.last_refresh = datetime.now(timezone.utc).isoformat() + + print(f"[TokenRefresher] Social token 刷新成功,过期时间: {expires_in}s") + return True, new_token + + except Exception as e: + return False, f"刷新异常: {str(e)}" + + async def refresh_idc_token(self) -> Tuple[bool, str]: + """ + 刷新 IDC Token (AWS Builder ID) + + 使用 AWS OIDC 端点刷新 + """ + region = self.credentials.region or "us-east-1" + refresh_url = f"https://oidc.{region}.amazonaws.com/token" + + if not self.credentials.client_id or not self.credentials.client_secret: + return False, "IdC 认证缺少 client_id 或 client_secret" + + machine_id = self._get_machine_id() + kiro_version = get_kiro_version() + + body = { + "refreshToken": self.credentials.refresh_token, + "clientId": self.credentials.client_id, + "clientSecret": self.credentials.client_secret, + "grantType": "refresh_token" + } + headers = { + "Content-Type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/3.738.0 KiroIDE-{kiro_version}-{machine_id}", + "User-Agent": "node", + } + + try: + async with httpx.AsyncClient(verify=False, timeout=30) as client: + resp = await client.post(refresh_url, json=body, headers=headers) + + if resp.status_code != 200: + error_text = resp.text + if resp.status_code == 401: + return False, "凭证已过期或无效,需要重新登录" + elif resp.status_code == 429: + return False, "请求过于频繁,请稍后重试" + else: + return False, f"刷新失败: {resp.status_code} - {error_text[:200]}" + + data = resp.json() + + new_token = data.get("accessToken") or data.get("access_token") + if not new_token: + return False, "响应中没有 access_token" + + # 更新凭证 + self.credentials.access_token = new_token + + if rt := data.get("refreshToken") or data.get("refresh_token"): + self.credentials.refresh_token = rt + + if arn := data.get("profileArn"): + self.credentials.profile_arn = arn + + if expires_in := data.get("expiresIn") or data.get("expires_in"): + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + self.credentials.expires_at = expires_at.isoformat() + + self.credentials.last_refresh = datetime.now(timezone.utc).isoformat() + + print(f"[TokenRefresher] IDC token 刷新成功") + return True, new_token + + except Exception as e: + return False, f"刷新异常: {str(e)}" + + async def refresh(self) -> Tuple[bool, str]: + """ + 刷新 token,根据 authMethod 分发到正确的刷新方法 + + Returns: + (success, new_token_or_error) + """ + is_valid, error = self.validate_refresh_token() + if not is_valid: + return False, error + + auth_method = (self.credentials.auth_method or "social").lower() + + if auth_method == "idc": + return await self.refresh_idc_token() + else: + # social 或其他默认使用 social 刷新 + return await self.refresh_social_token() diff --git a/KiroProxy/kiro_proxy/credential/types.py b/KiroProxy/kiro_proxy/credential/types.py new file mode 100644 index 0000000000000000000000000000000000000000..67fd7b1819508e5d42a1756b0fbef7ede9e2dc3f --- /dev/null +++ b/KiroProxy/kiro_proxy/credential/types.py @@ -0,0 +1,121 @@ +"""凭证数据类型""" +import json +import time +from dataclasses import dataclass +from datetime import datetime, timezone, timedelta +from enum import Enum +from pathlib import Path +from typing import Optional + + +class CredentialStatus(Enum): + """凭证状态""" + ACTIVE = "active" + COOLDOWN = "cooldown" + UNHEALTHY = "unhealthy" + DISABLED = "disabled" + SUSPENDED = "suspended" # 账号被封禁 + + +@dataclass +class KiroCredentials: + """Kiro 凭证信息""" + access_token: Optional[str] = None + refresh_token: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + profile_arn: Optional[str] = None + expires_at: Optional[str] = None + region: str = "us-east-1" + auth_method: str = "social" + provider: Optional[str] = None # Google / Github (社交登录提供商) + client_id_hash: Optional[str] = None + last_refresh: Optional[str] = None + + @classmethod + def from_file(cls, path: str) -> "KiroCredentials": + """从文件加载凭证""" + with open(path) as f: + data = json.load(f) + + return cls( + access_token=data.get("accessToken"), + refresh_token=data.get("refreshToken"), + client_id=data.get("clientId"), + client_secret=data.get("clientSecret"), + profile_arn=data.get("profileArn"), + expires_at=data.get("expiresAt") or data.get("expire"), + region=data.get("region", "us-east-1"), + auth_method=data.get("authMethod", "social"), + provider=data.get("provider"), + client_id_hash=data.get("clientIdHash"), + last_refresh=data.get("lastRefresh"), + ) + + def to_dict(self) -> dict: + """转换为字典""" + result = { + "accessToken": self.access_token, + "refreshToken": self.refresh_token, + "clientId": self.client_id, + "clientSecret": self.client_secret, + "profileArn": self.profile_arn, + "expiresAt": self.expires_at, + "region": self.region, + "authMethod": self.auth_method, + "clientIdHash": self.client_id_hash, + "lastRefresh": self.last_refresh, + } + # 只有社交登录才添加 provider 字段 + if self.provider: + result["provider"] = self.provider + return result + + def save_to_file(self, path: str): + """保存凭证到文件""" + existing = {} + if Path(path).exists(): + try: + with open(path) as f: + existing = json.load(f) + except Exception: + pass + + existing.update({k: v for k, v in self.to_dict().items() if v is not None}) + + with open(path, "w") as f: + json.dump(existing, f, indent=2) + + def is_expired(self) -> bool: + """检查 token 是否已过期""" + if not self.expires_at: + return True + + try: + if "T" in self.expires_at: + expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00")) + now = datetime.now(timezone.utc) + return expires <= now + timedelta(minutes=5) + + expires_ts = int(self.expires_at) + now_ts = int(time.time()) + return now_ts >= (expires_ts - 300) + except Exception: + return True + + def is_expiring_soon(self, minutes: int = 10) -> bool: + """检查 token 是否即将过期""" + if not self.expires_at: + return False + + try: + if "T" in self.expires_at: + expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00")) + now = datetime.now(timezone.utc) + return expires < now + timedelta(minutes=minutes) + + expires_ts = int(self.expires_at) + now_ts = int(time.time()) + return now_ts >= (expires_ts - minutes * 60) + except Exception: + return False diff --git a/KiroProxy/kiro_proxy/docs/01-quickstart.md b/KiroProxy/kiro_proxy/docs/01-quickstart.md new file mode 100644 index 0000000000000000000000000000000000000000..52c0575d9bc0faad0a84d8a5bdaadf2a7c6d3a3b --- /dev/null +++ b/KiroProxy/kiro_proxy/docs/01-quickstart.md @@ -0,0 +1,143 @@ +# 快速开始 + +## 安装运行 + +### 方式一:下载预编译版本 + +从 [Releases](https://github.com/yourname/kiro-proxy/releases) 下载对应平台的安装包: + +- **Windows**: `kiro-proxy-windows.zip` +- **macOS**: `kiro-proxy-macos.zip` +- **Linux**: `kiro-proxy-linux.tar.gz` + +解压后双击运行即可。 + +### 方式二:从源码运行 + +```bash +# 克隆项目 +git clone https://github.com/yourname/kiro-proxy.git +cd kiro-proxy + +# 创建虚拟环境 +python -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate + +# 安装依赖 +pip install -r requirements.txt + +# 运行(默认端口 8080) +python run.py + +# 指定端口 +python run.py 8081 +``` + +启动成功后,访问 http://localhost:8080 打开管理界面。 + +--- + +## 获取 Kiro 账号 + +Kiro Proxy 需要 Kiro 账号的 Token 才能工作。有两种方式获取: + +### 方式一:在线登录(推荐) + +1. 打开 Web UI,点击「账号」标签页 +2. 点击「在线登录」按钮 +3. 选择登录方式: + - **Google** - 使用 Google 账号 + - **GitHub** - 使用 GitHub 账号 + - **AWS** - 使用 AWS Builder ID +4. 在弹出的浏览器中完成授权 +5. 授权成功后,账号自动添加到代理 + +### 方式二:扫描本地 Token + +如果你已经在 Kiro IDE 中登录过: + +1. 打开 Kiro IDE,确保已登录 +2. 回到 Web UI,点击「扫描 Token」 +3. 系统会扫描 `~/.aws/sso/cache/` 目录 +4. 选择要添加的 Token 文件 + +--- + +## 配置 AI 客户端 + +### Claude Code (VSCode 插件) + +这是最推荐的使用方式,工具调用功能已验证可用。 + +1. 安装 Claude Code 插件 +2. 打开设置,添加自定义 Provider: + +``` +名称: Kiro Proxy +API Provider: Anthropic +API Key: any(随便填一个) +Base URL: http://localhost:8080 +模型: claude-sonnet-4 +``` + +3. 选择 Kiro Proxy 作为当前 Provider + +### Codex CLI + +OpenAI 官方命令行工具。 + +```bash +# 安装 +npm install -g @openai/codex + +# 配置 (~/.codex/config.toml) +model = "gpt-4o" +model_provider = "kiro" + +[model_providers.kiro] +name = "Kiro Proxy" +base_url = "http://localhost:8080/v1" +``` + +### Gemini CLI + +```bash +# 设置环境变量 +export GEMINI_API_BASE=http://localhost:8080/v1 + +# 或在配置文件中设置 +base_url = "http://localhost:8080/v1" +model = "gemini-pro" +``` + +### 其他兼容客户端 + +任何支持 OpenAI 或 Anthropic API 的客户端都可以使用: + +- **Base URL**: `http://localhost:8080` 或 `http://localhost:8080/v1` +- **API Key**: 任意值(代理不验证) +- **模型**: 见下方模型对照表 + +--- + +## 模型对照表 + +Kiro 支持以下模型,你可以使用 Kiro 原生名称或映射名称: + +| Kiro 模型 | 能力 | 可用名称(任选其一) | +|-----------|------|---------------------| +| `claude-sonnet-4` | ⭐⭐⭐ 推荐,性价比最高 | `gpt-4o`, `gpt-4`, `gpt-4-turbo`, `claude-3-5-sonnet-20241022`, `claude-3-5-sonnet-latest`, `sonnet` | +| `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强,适合复杂任务 | `gemini-1.5-pro`, `o1`, `o1-preview`, `claude-3-opus-20240229`, `claude-3-opus-latest`, `claude-4-opus`, `opus` | +| `claude-haiku-4.5` | ⚡ 快速,适合简单任务 | `gpt-4o-mini`, `gpt-3.5-turbo`, `claude-3-5-haiku-20241022`, `haiku` | +| `auto` | 🤖 自动选择 | `auto` | + +### 各客户端推荐配置 + +| 客户端 | 推荐模型名 | 实际使用 | +|--------|-----------|---------| +| Claude Code | `claude-sonnet-4` 或 `claude-sonnet-4.5` | 直接使用 Kiro 模型名 | +| Codex CLI | `gpt-4o` | 映射到 claude-sonnet-4 | +| Gemini CLI | `gemini-1.5-pro` | 映射到 claude-sonnet-4.5 | +| 其他 OpenAI 客户端 | `gpt-4o` | 映射到 claude-sonnet-4 | + +> 💡 **提示**:不确定用什么模型?直接用 `claude-sonnet-4` 或 `gpt-4o`,性价比最高。 diff --git a/KiroProxy/kiro_proxy/docs/02-features.md b/KiroProxy/kiro_proxy/docs/02-features.md new file mode 100644 index 0000000000000000000000000000000000000000..42d08691b73844e4220d6d563b5e844afc9be2c1 --- /dev/null +++ b/KiroProxy/kiro_proxy/docs/02-features.md @@ -0,0 +1,225 @@ +# 功能特性 + +## 多协议支持 + +Kiro Proxy 支持三种主流 AI API 协议,可以适配不同的客户端: + +| 协议 | 端点 | 适用客户端 | +|------|------|------------| +| OpenAI | `/v1/chat/completions` | Codex CLI, ChatGPT 客户端 | +| Anthropic | `/v1/messages` | Claude Code, Claude 客户端 | +| Gemini | `/v1/models/{model}:generateContent` | Gemini CLI | + +代理会自动将请求转换为 Kiro API 格式,响应转换回对应协议格式。 + +--- + +## 工具调用支持 + +完整支持三种协议的工具调用功能: + +### Anthropic 协议(Claude Code) + +- `tools` 定义和 `tool_result` 响应完整支持 +- `tool_choice: required` 支持(通过 prompt 注入) +- `web_search` 特殊工具自动识别 +- 工具数量限制(最多 50 个) +- 描述截断(超过 500 字符自动截断) + +### OpenAI 协议(Codex CLI) + +- `tools` 定义(function 类型) +- `tool_calls` 响应处理 +- `tool` 角色消息转换 +- `tool_choice: required/any` 支持 +- 工具数量限制和描述截断 + +### Gemini 协议 + +- `functionDeclarations` 工具定义 +- `functionCall` 响应处理 +- `functionResponse` 工具结果 +- `toolConfig.functionCallingConfig.mode` 支持(ANY/REQUIRED) +- 工具数量限制和描述截断 + +### 历史消息修复 + +Kiro API 要求消息必须严格交替(user → assistant → user → assistant),代理会自动: + +- 检测并修复连续的同角色消息 +- 合并重复的 tool_results +- 插入占位消息保持交替 + +--- + +## 多账号管理 + +### 账号轮询 + +支持添加多个 Kiro 账号,代理会自动轮询使用(默认随机): + +- 每次请求随机选择一个可用账号(尽量避免连续命中同一账号) +- 自动跳过冷却中或不健康的账号 +- 分散请求压力,降低单账号 RPM 过高导致封禁风险 + +### 会话粘性(可选) + +为了保持对话上下文的连贯性,在非 `random` 策略下会启用会话粘性: + +- 同一会话 ID 在 60 秒内会使用同一账号 +- 超过 60 秒或账号不可用时才切换 +- 会话 ID 由请求内容生成;可通过 `~/.kiro-proxy/priority.json` 中的 `strategy` 调整策略 + +### 账号状态 + +每个账号有四种状态: + +| 状态 | 说明 | 颜色 | +|------|------|------| +| Active | 正常可用 | 绿色 | +| Cooldown | 触发限流,冷却中 | 黄色 | +| Unhealthy | 健康检查失败 | 红色 | +| Disabled | 手动禁用 | 灰色 | + +--- + +## Token 自动刷新 + +### 自动检测 + +- 后台每 5 分钟检查所有账号的 Token 状态 +- 检测 Token 是否即将过期(15 分钟内) + +### 自动刷新 + +- 发现即将过期的 Token 自动刷新 +- 支持 Social 认证(Google/GitHub)的 refresh_token +- 刷新失败会标记账号为不健康 + +### 手动刷新 + +- 在账号卡片点击「刷新 Token」 +- 或点击「刷新所有 Token」批量刷新 + +--- + +## 配额管理 + +### 429 自动处理 + +当 Kiro API 返回 429 (Too Many Requests) 时: + +1. 自动将该账号标记为 Cooldown 状态 +2. 设置 5 分钟冷却时间 +3. 立即切换到其他可用账号重试 +4. 冷却结束后自动恢复 + +### 手动恢复 + +如果需要提前恢复账号: + +1. 在「监控」页面查看配额状态 +2. 点击账号旁的「恢复」按钮 + +--- + +## 流量监控 + +### 请求记录 + +记录所有经过代理的 LLM 请求: + +- 请求时间、模型、账号 +- 输入/输出 Token 数量 +- 响应时间、状态码 +- 完整的请求和响应内容 + +### 搜索过滤 + +- 按协议筛选(OpenAI/Anthropic/Gemini) +- 按状态筛选(完成/错误/进行中) +- 关键词搜索 + +### 导出功能 + +- 支持导出为 JSON 格式 +- 可选择导出全部或指定记录 + +--- + +## 登录方式 + +### Google 登录 + +使用 Google 账号通过 OAuth 授权登录。 + +### GitHub 登录 + +使用 GitHub 账号通过 OAuth 授权登录。 + +### AWS Builder ID + +使用 AWS Builder ID 通过 Device Code Flow 登录: + +1. 点击 AWS 登录按钮 +2. 复制显示的授权码 +3. 在浏览器中打开授权页面 +4. 输入授权码完成登录 + +--- + +## 历史消息管理 + +### 对话长度限制 + +Kiro API 有输入长度限制,当对话历史过长时会返回 `CONTENT_LENGTH_EXCEEDS_THRESHOLD` 错误。 + +代理内置了多种策略自动处理这个问题: + +### 可用策略 + +| 策略 | 说明 | 触发时机 | +|------|------|----------| +| 自动截断 | 优先保留最新上下文并摘要前文,必要时截断 | 每次请求前 | +| 智能摘要 | 用 AI 生成早期对话摘要 | 超过阈值时 | +| 错误重试 | 遇到长度错误时截断重试 | 收到错误后 | +| 预估检测 | 预估 token 数量,超限预先截断 | 每次请求前 | + +### 配置选项 + +在「设置」页面可以配置: + +- **最大消息数** - 自动截断时保留的消息数量(默认 30) +- **最大字符数** - 自动截断时的字符数限制(默认 150000) +- **重试保留数** - 错误重试时保留的消息数(默认 20) +- **最大重试次数** - 错误重试的最大次数(默认 2) +- **摘要保留数** - 智能摘要时保留的最近消息数(默认 10) +- **摘要阈值** - 触发智能摘要的字符数阈值(默认 100000) +- **添加警告** - 截断时是否在日志中记录 + +### 推荐配置 + +- **默认**:只启用「错误重试」,遇到问题时自动处理 +- **保守**:启用「智能摘要 + 错误重试」,保留关键信息 +- **激进**:启用「自动截断 + 预估检测」,预防性截断 + +--- + +## 配置持久化 + +### 自动保存 + +账号配置自动保存到 `~/.kiro-proxy/config.json`: + +- 账号列表和状态 +- 启用/禁用设置 +- Token 文件路径 + +### 重启恢复 + +重启代理后自动加载保存的配置,无需重新添加账号。 + +### 导入导出 + +- 「导出配置」下载当前配置 +- 「导入配置」从文件恢复 diff --git a/KiroProxy/kiro_proxy/docs/03-faq.md b/KiroProxy/kiro_proxy/docs/03-faq.md new file mode 100644 index 0000000000000000000000000000000000000000..bced898327026cad3f50b2282474d522bddabbe3 --- /dev/null +++ b/KiroProxy/kiro_proxy/docs/03-faq.md @@ -0,0 +1,192 @@ +# 常见问题 + +## 连接问题 + +### 无法连接到代理 + +**症状**:客户端报错 `Connection refused` 或 `ECONNREFUSED` + +**解决方案**: + +1. 确认代理已启动 + ```bash + python run.py + # 应该看到: Kiro API Proxy v1.7.1 + # http://localhost:8080 + ``` + +2. 检查端口是否正确 + - 默认端口是 8080 + - 如果修改了端口,客户端配置也要对应修改 + +3. 检查防火墙 + - Windows: 允许 Python 通过防火墙 + - macOS: 系统偏好设置 → 安全性与隐私 → 防火墙 + +### 401 认证失败 + +**症状**:请求返回 401 Unauthorized + +**原因**:Token 已过期或无效 + +**解决方案**: + +1. 点击账号卡片的「刷新 Token」 +2. 如果刷新失败,重新登录获取新 Token +3. 检查账号状态是否为 Active + +--- + +## 请求问题 + +### 429 Too Many Requests + +**症状**:请求返回 429 错误 + +**原因**:Kiro 有请求频率限制,短时间内请求过多 + +**代理自动处理**: + +- 将该账号冷却 5 分钟 +- 自动切换到其他可用账号 +- 冷却结束后自动恢复 + +**建议**: + +- 添加多个账号分散请求 +- 避免短时间内大量请求 +- 在「监控」页面查看配额状态 + +### 对话太长 (CONTENT_LENGTH_EXCEEDS_THRESHOLD) + +**症状**:请求返回错误 `Input is too long` + +**原因**:Kiro API 有输入长度限制,这是服务端限制 + +**代理自动处理**: + +代理内置了历史消息管理功能,可以在「设置」页面配置: + +1. **自动截断** - 发送前优先保留最新上下文并摘要前文,必要时截断保留最近 N 条 +2. **错误重试** - 遇到长度错误时自动截断并重试(默认启用) +3. **预估检测** - 发送前预估 token 数量,超限则预先截断 + +推荐组合:**错误重试**(默认)或 **自动截断 + 预估检测** + +**手动解决方案**: + +1. 在 Claude Code 中输入 `/clear` 清空对话历史 +2. 清空后告诉 AI 你之前在做什么 +3. AI 会读取代码文件恢复上下文 + +**预防措施**: + +- 复杂任务分阶段完成 +- 每个阶段结束后清空对话 +- 避免在对话中粘贴大量代码 + +### 响应很慢 + +**可能原因**: + +1. 网络延迟 +2. Kiro 服务端繁忙 +3. 请求内容过长 + +**解决方案**: + +1. 在「监控」页面运行速度测试 +2. 尝试切换账号 +3. 检查网络连接 + +--- + +## Token 问题 + +### Token 过期了怎么办 + +**自动处理**:代理会自动检测并刷新 Token + +**手动处理**: + +1. 点击账号卡片的「刷新 Token」按钮 +2. 如果刷新失败,说明 refresh_token 也过期了 +3. 需要重新登录获取新 Token + +### 如何添加多个账号 + +**方法一:多次在线登录** + +1. 使用不同的 Google/GitHub 账号 +2. 每次登录后自动添加新账号 + +**方法二:扫描 Token** + +1. 在 Kiro IDE 中用不同账号登录 +2. 每次登录后点击「扫描 Token」 +3. 选择新的 Token 文件添加 + +### Token 文件在哪里 + +Token 保存在 `~/.aws/sso/cache/` 目录下: + +``` +~/.aws/sso/cache/ +├── xxxxxxxx.json # Token 文件 +├── yyyyyyyy.json # 另一个 Token +└── ... +``` + +每个 JSON 文件包含一个账号的 Token 信息。 + +--- + +## 模型问题 + +### 支持哪些模型 + +| 模型 | 能力 | 推荐场景 | +|------|------|----------| +| claude-sonnet-4 | ⭐⭐⭐ 均衡 | 日常编程,推荐 | +| claude-sonnet-4.5 | ⭐⭐⭐⭐ 更强 | 复杂任务 | +| claude-haiku-4.5 | ⚡ 快速 | 简单问答,速度优先 | + +### 模型映射关系 + +使用 OpenAI 模型名时会自动映射: + +| 请求模型 | 实际使用 | +|----------|----------| +| gpt-4o, gpt-4 | claude-sonnet-4 | +| gpt-4o-mini, gpt-3.5-turbo | claude-haiku-4.5 | +| o1, o1-preview | claude-sonnet-4.5 | + +### 支持工具调用吗 + +**支持!** Claude Code 的工具调用功能已验证可用,包括: + +- 文件读写 +- 命令执行 +- 代码搜索 +- 等等 + +--- + +## 其他问题 + +### 如何查看请求日志 + +1. 打开「日志」标签页 +2. 查看最近的请求记录 +3. 包含时间、路径、模型、状态、耗时 + +### 如何监控账号状态 + +1. 打开「监控」标签页 +2. 查看服务状态和统计信息 +3. 查看配额状态和冷却中的账号 + +### 配置文件在哪里 + +- 账号配置:`~/.kiro-proxy/config.json` +- Token 文件:`~/.aws/sso/cache/*.json` diff --git a/KiroProxy/kiro_proxy/docs/04-api.md b/KiroProxy/kiro_proxy/docs/04-api.md new file mode 100644 index 0000000000000000000000000000000000000000..e884133d670f526da630a5ca05196475eebf5e8a --- /dev/null +++ b/KiroProxy/kiro_proxy/docs/04-api.md @@ -0,0 +1,137 @@ +# API 参考 + +## 代理端点 + +### OpenAI 协议 + +#### POST /v1/chat/completions + +Chat Completions API,兼容 OpenAI 格式。 + +**请求示例:** + +```json +{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Hello!"} + ], + "stream": true +} +``` + +**模型映射:** + +| 请求模型 | 实际使用 | +|----------|----------| +| gpt-4o, gpt-4 | claude-sonnet-4 | +| gpt-4o-mini, gpt-3.5-turbo | claude-haiku-4.5 | +| o1, o1-preview | claude-sonnet-4.5 | + +#### GET /v1/models + +获取可用模型列表。 + +--- + +### Anthropic 协议 + +#### POST /v1/messages + +Messages API,兼容 Anthropic 格式。 + +**请求示例:** + +```json +{ + "model": "claude-sonnet-4", + "max_tokens": 4096, + "messages": [ + {"role": "user", "content": "Hello!"} + ] +} +``` + +#### POST /v1/messages/count_tokens + +计算消息的 Token 数量。 + +--- + +### Gemini 协议 + +#### POST /v1/models/{model}:generateContent + +Generate Content API,兼容 Gemini 格式。 + +--- + +## 管理 API + +### 状态与统计 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/status` | GET | 服务状态 | +| `/api/stats` | GET | 基础统计 | +| `/api/stats/detailed` | GET | 详细统计 | +| `/api/quota` | GET | 配额状态 | +| `/api/logs` | GET | 请求日志 | + +### 账号管理 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/accounts` | GET | 账号列表 | +| `/api/accounts` | POST | 添加账号 | +| `/api/accounts/{id}` | GET | 账号详情 | +| `/api/accounts/{id}` | DELETE | 删除账号 | +| `/api/accounts/{id}/toggle` | POST | 启用/禁用 | +| `/api/accounts/{id}/refresh` | POST | 刷新 Token | +| `/api/accounts/{id}/restore` | POST | 恢复账号 | +| `/api/accounts/{id}/usage` | GET | 用量查询 | +| `/api/accounts/refresh-all` | POST | 刷新所有 | + +### Token 操作 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/token/scan` | GET | 扫描本地 Token | +| `/api/token/add-from-scan` | POST | 从扫描添加 | +| `/api/token/refresh-check` | POST | 检查 Token 状态 | + +### 登录 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/kiro/login/start` | POST | 启动 AWS 登录 | +| `/api/kiro/login/poll` | GET | 轮询登录状态 | +| `/api/kiro/login/cancel` | POST | 取消登录 | +| `/api/kiro/social/start` | POST | 启动 Social 登录 | +| `/api/kiro/social/exchange` | POST | 交换 Token | + +### Flow 监控 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/flows` | GET | 查询 Flows | +| `/api/flows/stats` | GET | Flow 统计 | +| `/api/flows/{id}` | GET | Flow 详情 | +| `/api/flows/{id}/bookmark` | POST | 收藏 Flow | +| `/api/flows/export` | POST | 导出 Flows | + +--- + +## 配置 + +### 配置文件位置 + +- 账号配置:`~/.kiro-proxy/config.json` +- Token 缓存:`~/.aws/sso/cache/` + +### 配置导入导出 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/config/export` | GET | 导出配置 | +| `/api/config/import` | POST | 导入配置 | diff --git a/KiroProxy/kiro_proxy/docs/05-server-deploy.md b/KiroProxy/kiro_proxy/docs/05-server-deploy.md new file mode 100644 index 0000000000000000000000000000000000000000..0798f2a635d732bfbe35f2ff730e7b8b45d81fc3 --- /dev/null +++ b/KiroProxy/kiro_proxy/docs/05-server-deploy.md @@ -0,0 +1,659 @@ +# 服务器部署指南 + +本文档详细介绍如何在各种服务器环境中部署 Kiro Proxy。 + +## 目录 + +- [方式一:预编译二进制(推荐)](#方式一预编译二进制推荐) +- [方式二:从源码运行](#方式二从源码运行) +- [方式三:Docker 部署](#方式三docker-部署) +- [账号配置](#账号配置) +- [开机自启配置](#开机自启配置) +- [反向代理配置](#反向代理配置) +- [常见问题](#常见问题) + +--- + +## 方式一:预编译二进制(推荐) + +最简单的方式,不需要安装任何依赖。 + +### Linux (x86_64) + +```bash +# 下载最新版本 +wget https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-linux-x86_64 + +# 添加执行权限 +chmod +x KiroProxy-1.7.1-linux-x86_64 + +# 运行 +./KiroProxy-1.7.1-linux-x86_64 + +# 指定端口 +./KiroProxy-1.7.1-linux-x86_64 8081 +``` + +**使用 curl 下载:** + +```bash +curl -LO https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-linux-x86_64 +chmod +x KiroProxy-1.7.1-linux-x86_64 +./KiroProxy-1.7.1-linux-x86_64 +``` + +**Debian/Ubuntu 使用 deb 包:** + +```bash +wget https://github.com/petehsu/KiroProxy/releases/latest/download/kiroproxy_1.7.1_amd64.deb +sudo dpkg -i kiroproxy_1.7.1_amd64.deb + +# 运行 +KiroProxy +``` + +**Fedora/RHEL/CentOS 使用 rpm 包:** + +```bash +wget https://github.com/petehsu/KiroProxy/releases/latest/download/kiroproxy-1.7.1-1.x86_64.rpm +sudo rpm -i kiroproxy-1.7.1-1.x86_64.rpm + +# 运行 +KiroProxy +``` + +### macOS + +```bash +# Intel Mac +curl -LO https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-macos-x86_64 +chmod +x KiroProxy-1.7.1-macos-x86_64 +./KiroProxy-1.7.1-macos-x86_64 + +# 如果提示无法验证开发者,运行: +xattr -d com.apple.quarantine KiroProxy-1.7.1-macos-x86_64 +``` + +### Windows + +```powershell +# PowerShell 下载 +Invoke-WebRequest -Uri "https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-windows-x86_64.exe" -OutFile "KiroProxy.exe" + +# 运行 +.\KiroProxy.exe + +# 指定端口 +.\KiroProxy.exe 8081 +``` + +--- + +## 方式二:从源码运行 + +需要 Python 3.9+ 和 Git。 + +### 安装 Git(如果没有) + +**Ubuntu/Debian:** +```bash +sudo apt update +sudo apt install git -y +``` + +**CentOS/RHEL/Fedora:** +```bash +sudo yum install git -y +# 或 +sudo dnf install git -y +``` + +**macOS:** +```bash +# 安装 Xcode Command Line Tools +xcode-select --install +# 或使用 Homebrew +brew install git +``` + +**Windows:** +从 https://git-scm.com/download/win 下载安装 + +### 安装 Python(如果没有) + +**Ubuntu/Debian:** +```bash +sudo apt update +sudo apt install python3 python3-pip python3-venv -y +``` + +**CentOS/RHEL 8+:** +```bash +sudo dnf install python39 python39-pip -y +``` + +**Fedora:** +```bash +sudo dnf install python3 python3-pip -y +``` + +**macOS:** +```bash +brew install python@3.11 +``` + +**Windows:** +从 https://www.python.org/downloads/ 下载安装,勾选 "Add to PATH" + +### 克隆并运行 + +```bash +# 克隆项目 +git clone https://github.com/petehsu/KiroProxy.git +cd KiroProxy + +# 创建虚拟环境(推荐) +python3 -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate + +# 安装依赖 +pip install -r requirements.txt + +# 运行 +python run.py + +# 指定端口 +python run.py 8081 + +# 或使用 CLI +python run.py serve -p 8081 +``` + +### 更新到最新版本 + +```bash +cd KiroProxy +git pull origin main +pip install -r requirements.txt +``` + +--- + +## 方式三:Docker 部署 + +### 使用 Dockerfile + +创建 `Dockerfile`: + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +# 安装依赖 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制代码 +COPY . . + +# 暴露端口 +EXPOSE 8080 + +# 数据目录 +VOLUME ["/root/.config/kiro-proxy"] + +# 启动 +CMD ["python", "run.py"] +``` + +构建并运行: + +```bash +docker build -t kiro-proxy . +docker run -d -p 8080:8080 -v kiro-data:/root/.config/kiro-proxy --name kiro-proxy kiro-proxy +``` + +### Docker Compose + +创建 `docker-compose.yml`: + +```yaml +version: '3' +services: + kiro-proxy: + build: . + ports: + - "8080:8080" + volumes: + - kiro-data:/root/.config/kiro-proxy + restart: unless-stopped + +volumes: + kiro-data: +``` + +运行: + +```bash +docker-compose up -d +``` + +--- + +## 账号配置 + +服务器通常没有浏览器,有以下几种方式添加账号: + +### 方式一:远程登录链接(推荐) + +1. 在服务器上启动 KiroProxy +2. 在本地浏览器打开 `http://服务器IP:8080` +3. 点击「远程登录链接」按钮 +4. 复制生成的链接,在本地浏览器打开 +5. 完成 Google/GitHub 授权 +6. 账号自动添加到服务器 + +### 方式二:导入导出 + +**本地电脑:** +```bash +# 运行 KiroProxy 并登录 +python run.py + +# 导出账号 +python run.py accounts export -o accounts.json +``` + +**服务器:** +```bash +# 上传 accounts.json 到服务器后导入 +python run.py accounts import accounts.json + +# 或使用 curl +curl -X POST http://localhost:8080/api/accounts/import \ + -H "Content-Type: application/json" \ + -d @accounts.json +``` + +### 方式三:手动添加 Token + +1. 在本地 Kiro IDE 登录 +2. 找到 `~/.aws/sso/cache/` 目录下的 JSON 文件 +3. 复制 `accessToken` 和 `refreshToken` + +**服务器上:** +```bash +# 交互式添加 +python run.py accounts add + +# 或使用 API +curl -X POST http://localhost:8080/api/accounts/manual \ + -H "Content-Type: application/json" \ + -d '{ + "name": "我的账号", + "access_token": "eyJ...", + "refresh_token": "eyJ..." + }' +``` + +### 方式四:扫描本地 Token + +如果服务器上安装了 Kiro IDE 并已登录: + +```bash +python run.py accounts scan --auto +``` + +--- + +## 开机自启配置 + +### Linux (systemd) + +创建服务文件 `/etc/systemd/system/kiro-proxy.service`: + +```ini +[Unit] +Description=Kiro API Proxy +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=/opt/kiro-proxy +ExecStart=/opt/kiro-proxy/KiroProxy +Restart=always +RestartSec=10 + +[Install] +WantedBy=multi-user.target +``` + +**使用预编译二进制:** +```bash +# 创建目录并下载 +sudo mkdir -p /opt/kiro-proxy +sudo wget -O /opt/kiro-proxy/KiroProxy https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-linux-x86_64 +sudo chmod +x /opt/kiro-proxy/KiroProxy + +# 创建服务文件 +sudo tee /etc/systemd/system/kiro-proxy.service << 'EOF' +[Unit] +Description=Kiro API Proxy +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=/opt/kiro-proxy +ExecStart=/opt/kiro-proxy/KiroProxy +Restart=always +RestartSec=10 + +[Install] +WantedBy=multi-user.target +EOF + +# 启用并启动 +sudo systemctl daemon-reload +sudo systemctl enable kiro-proxy +sudo systemctl start kiro-proxy + +# 查看状态 +sudo systemctl status kiro-proxy + +# 查看日志 +sudo journalctl -u kiro-proxy -f +``` + +**使用源码运行:** +```bash +sudo tee /etc/systemd/system/kiro-proxy.service << 'EOF' +[Unit] +Description=Kiro API Proxy +After=network.target + +[Service] +Type=simple +User=root +WorkingDirectory=/opt/KiroProxy +ExecStart=/opt/KiroProxy/venv/bin/python run.py +Restart=always +RestartSec=10 + +[Install] +WantedBy=multi-user.target +EOF +``` + +### Linux (使用 screen/tmux) + +**screen:** +```bash +# 安装 +sudo apt install screen # Debian/Ubuntu +sudo yum install screen # CentOS + +# 创建会话并运行 +screen -S kiro +./KiroProxy + +# 按 Ctrl+A D 退出会话(程序继续运行) + +# 重新连接 +screen -r kiro +``` + +**tmux:** +```bash +# 安装 +sudo apt install tmux # Debian/Ubuntu +sudo yum install tmux # CentOS + +# 创建会话并运行 +tmux new -s kiro +./KiroProxy + +# 按 Ctrl+B D 退出会话 + +# 重新连接 +tmux attach -t kiro +``` + +### Linux (使用 nohup) + +```bash +# 后台运行 +nohup ./KiroProxy > kiro.log 2>&1 & + +# 查看日志 +tail -f kiro.log + +# 停止 +pkill -f KiroProxy +``` + +### macOS (launchd) + +创建 `~/Library/LaunchAgents/com.kiro.proxy.plist`: + +```xml + + + + + Label + com.kiro.proxy + ProgramArguments + + /usr/local/bin/KiroProxy + + RunAtLoad + + KeepAlive + + StandardOutPath + /tmp/kiro-proxy.log + StandardErrorPath + /tmp/kiro-proxy.err + + +``` + +```bash +# 加载服务 +launchctl load ~/Library/LaunchAgents/com.kiro.proxy.plist + +# 启动 +launchctl start com.kiro.proxy + +# 停止 +launchctl stop com.kiro.proxy + +# 卸载 +launchctl unload ~/Library/LaunchAgents/com.kiro.proxy.plist +``` + +### Windows (任务计划程序) + +**方法一:使用 PowerShell 创建计划任务** + +```powershell +# 创建计划任务(开机自启) +$action = New-ScheduledTaskAction -Execute "C:\KiroProxy\KiroProxy.exe" +$trigger = New-ScheduledTaskTrigger -AtStartup +$principal = New-ScheduledTaskPrincipal -UserId "SYSTEM" -LogonType ServiceAccount +Register-ScheduledTask -TaskName "KiroProxy" -Action $action -Trigger $trigger -Principal $principal + +# 立即启动 +Start-ScheduledTask -TaskName "KiroProxy" + +# 停止 +Stop-ScheduledTask -TaskName "KiroProxy" + +# 删除 +Unregister-ScheduledTask -TaskName "KiroProxy" -Confirm:$false +``` + +**方法二:使用 NSSM 创建 Windows 服务** + +1. 下载 NSSM: https://nssm.cc/download +2. 解压并运行: + +```cmd +nssm install KiroProxy C:\KiroProxy\KiroProxy.exe +nssm start KiroProxy + +# 停止 +nssm stop KiroProxy + +# 删除 +nssm remove KiroProxy confirm +``` + +**方法三:创建 VBS 启动脚本** + +创建 `start-kiro.vbs`: + +```vbscript +Set WshShell = CreateObject("WScript.Shell") +WshShell.Run "C:\KiroProxy\KiroProxy.exe", 0, False +``` + +将此文件放入启动文件夹:`shell:startup` + +--- + +## 反向代理配置 + +### Nginx + +```nginx +server { + listen 80; + server_name kiro.example.com; + + location / { + proxy_pass http://127.0.0.1:8080; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # SSE 支持 + proxy_buffering off; + proxy_cache off; + proxy_read_timeout 86400; + } +} +``` + +**启用 HTTPS(使用 Certbot):** + +```bash +sudo apt install certbot python3-certbot-nginx +sudo certbot --nginx -d kiro.example.com +``` + +### Caddy + +```caddyfile +kiro.example.com { + reverse_proxy localhost:8080 +} +``` + +Caddy 会自动申请和续期 HTTPS 证书。 + +### Apache + +```apache + + ServerName kiro.example.com + + ProxyPreserveHost On + ProxyPass / http://127.0.0.1:8080/ + ProxyPassReverse / http://127.0.0.1:8080/ + + # SSE 支持 + SetEnv proxy-sendchunked 1 + +``` + +--- + +## 常见问题 + +### 端口被占用 + +```bash +# 查看端口占用 +lsof -i :8080 # Linux/macOS +netstat -ano | findstr :8080 # Windows + +# 使用其他端口 +./KiroProxy 8081 +``` + +### 防火墙配置 + +**Ubuntu/Debian (ufw):** +```bash +sudo ufw allow 8080/tcp +``` + +**CentOS/RHEL (firewalld):** +```bash +sudo firewall-cmd --permanent --add-port=8080/tcp +sudo firewall-cmd --reload +``` + +**Windows:** +```powershell +New-NetFirewallRule -DisplayName "KiroProxy" -Direction Inbound -Port 8080 -Protocol TCP -Action Allow +``` + +### 权限问题 + +```bash +# 如果遇到权限问题 +chmod +x KiroProxy +sudo chown -R $USER:$USER /opt/kiro-proxy +``` + +### 查看日志 + +```bash +# systemd +sudo journalctl -u kiro-proxy -f + +# 直接运行时 +./KiroProxy 2>&1 | tee kiro.log +``` + +### 更新版本 + +**预编译二进制:** +```bash +# 停止服务 +sudo systemctl stop kiro-proxy + +# 下载新版本 +sudo wget -O /opt/kiro-proxy/KiroProxy https://github.com/petehsu/KiroProxy/releases/latest/download/KiroProxy-1.7.1-linux-x86_64 +sudo chmod +x /opt/kiro-proxy/KiroProxy + +# 启动服务 +sudo systemctl start kiro-proxy +``` + +**源码方式:** +```bash +cd /opt/KiroProxy +git pull origin main +pip install -r requirements.txt +sudo systemctl restart kiro-proxy +``` diff --git a/KiroProxy/kiro_proxy/handlers/__init__.py b/KiroProxy/kiro_proxy/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd0890440f98cb1bd8cee6d8cdb219b63579167 --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/__init__.py @@ -0,0 +1 @@ +# API Handlers diff --git a/KiroProxy/kiro_proxy/handlers/admin/__init__.py b/KiroProxy/kiro_proxy/handlers/admin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5168b327c5f6a4d9893887d24414bc1058e0f245 --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/admin/__init__.py @@ -0,0 +1,1879 @@ +"""管理 API 处理""" +import json +import uuid +import time +import httpx +from pathlib import Path +from datetime import datetime +from dataclasses import asdict +from typing import Optional +from fastapi import Request, HTTPException, Query + +from ...config import TOKEN_PATH, MODELS_URL +from ...core import state, Account, stats_manager, get_browsers_info, open_url, flow_monitor, get_account_usage +from ...credential import quota_manager, generate_machine_id, get_kiro_version, CredentialStatus +from ...auth import start_device_flow, poll_device_flow, cancel_device_flow, get_login_state, save_credentials_to_file +from ...auth import start_social_auth, exchange_social_auth_token, cancel_social_auth, get_social_auth_state + + +async def get_status(): + """服务状态""" + try: + # 检查是否有可用账号 + available_count = len([a for a in state.accounts if a.enabled and a.is_available()]) + return { + "ok": available_count > 0, + "available_accounts": available_count, + "total_accounts": len(state.accounts), + "stats": state.get_stats() + } + except Exception as e: + return {"ok": False, "error": str(e), "stats": state.get_stats()} + + +async def get_stats(): + """获取统计信息""" + return state.get_stats() + + +async def event_logging_batch(request: Request): + """接收事件日志批量上报(兼容客户端)""" + try: + await request.json() + except Exception: + pass + return {"ok": True} + + +async def get_logs(limit: int = Query(100, le=1000)): + """获取请求日志""" + logs = list(state.request_logs)[-limit:] + return { + "logs": [asdict(log) for log in reversed(logs)], + "total": len(state.request_logs) + } + + +async def get_accounts(): + """获取账号列表(增强版)""" + return { + "accounts": state.get_accounts_status() + } + + +async def get_account_detail(account_id: str): + """获取账号详细信息""" + for acc in state.accounts: + if acc.id == account_id: + creds = acc.get_credentials() + return { + "id": acc.id, + "name": acc.name, + "enabled": acc.enabled, + "status": acc.status.value, + "available": acc.is_available(), + "request_count": acc.request_count, + "error_count": acc.error_count, + "last_used": acc.last_used, + "token_path": acc.token_path, + "machine_id": acc.get_machine_id()[:16] + "...", + "credentials": { + "access_token": creds.access_token if creds else None, + "refresh_token": creds.refresh_token if creds else None, + "profile_arn": creds.profile_arn if creds else None, + "client_id": creds.client_id if creds else None, + "auth_method": creds.auth_method if creds else None, + "provider": creds.provider if creds else None, + "region": creds.region if creds else None, + "expires_at": creds.expires_at if creds else None, + "is_expired": acc.is_token_expired(), + "is_expiring_soon": acc.is_token_expiring_soon(), + } if creds else None, + "cooldown": { + "is_cooldown": not quota_manager.is_available(acc.id), + "remaining_seconds": quota_manager.get_cooldown_remaining(acc.id), + } + } + raise HTTPException(404, "Account not found") + + +async def add_account(request: Request): + """添加账号""" + body = await request.json() + name = body.get("name", f"账号{len(state.accounts)+1}") + token_path = body.get("token_path") + + if not token_path or not Path(token_path).exists(): + raise HTTPException(400, "Invalid token path") + + account = Account( + id=uuid.uuid4().hex[:8], + name=name, + token_path=token_path + ) + state.accounts.append(account) + + # 预加载凭证 + account.load_credentials() + + # 保存配置 + state._save_accounts() + + return {"ok": True, "account_id": account.id} + + +async def delete_account(account_id: str): + """删除账号""" + state.accounts = [a for a in state.accounts if a.id != account_id] + # 清理配额记录 + quota_manager.restore(account_id) + # 保存配置 + state._save_accounts() + return {"ok": True} + + +async def update_account(account_id: str, request: Request): + """更新账号信息 + + 支持更新: + - name: 账号名称 + - enabled: 是否启用 + - provider: 登录提供商 (Google/Github) + + 凭证相关字段(需要重新加载凭证): + - refresh_token: 刷新令牌 + - client_id: IDC 客户端 ID + - client_secret: IDC 客户端密钥 + - region: 区域 + """ + body = await request.json() + + # 查找账号 + account = None + for acc in state.accounts: + if acc.id == account_id: + account = acc + break + + if not account: + raise HTTPException(404, "账号不存在") + + updated_fields = [] + + # 更新基本信息 + if "name" in body: + new_name = body["name"].strip() + if new_name: + account.name = new_name + updated_fields.append("name") + + if "enabled" in body: + account.enabled = bool(body["enabled"]) + updated_fields.append("enabled") + + # 更新凭证相关字段 + creds = account.get_credentials() + creds_updated = False + + if creds: + if "provider" in body: + provider = body["provider"].strip() if body["provider"] else None + if provider in (None, "", "Google", "Github"): + creds.provider = provider if provider else None + creds_updated = True + updated_fields.append("provider") + + if "refresh_token" in body: + new_rt = body["refresh_token"].strip() + if new_rt and len(new_rt) > 50: + creds.refresh_token = new_rt + creds_updated = True + updated_fields.append("refresh_token") + + if "client_id" in body: + creds.client_id = body["client_id"].strip() or None + creds_updated = True + updated_fields.append("client_id") + + if "client_secret" in body: + creds.client_secret = body["client_secret"].strip() or None + creds_updated = True + updated_fields.append("client_secret") + + if "region" in body: + new_region = body["region"].strip() + if new_region: + creds.region = new_region + creds_updated = True + updated_fields.append("region") + + # 保存凭证到文件 + if creds_updated: + creds.save_to_file(account.token_path) + # 重新加载凭证 + account._credentials = None + account.load_credentials() + + # 保存账号配置 + state._save_accounts() + + return { + "ok": True, + "account_id": account_id, + "updated_fields": updated_fields, + "message": f"已更新: {', '.join(updated_fields)}" if updated_fields else "无更新" + } + + +async def toggle_account(account_id: str): + """启用/禁用账号""" + for acc in state.accounts: + if acc.id == account_id: + acc.enabled = not acc.enabled + # 手动切换时清除自动禁用标记(避免后续被自动启用覆盖手动意图) + if hasattr(acc, "auto_disabled"): + acc.auto_disabled = False + # 保存配置 + state._save_accounts() + return {"ok": True, "enabled": acc.enabled} + raise HTTPException(404, "Account not found") + + +async def refresh_account_token(account_id: str): + """刷新指定账号的 token""" + success, message = await state.refresh_account_token(account_id) + return {"ok": success, "message": message} + + +async def refresh_all_tokens(): + """刷新所有账号的 token""" + results = [] + for acc in state.accounts: + if acc.enabled: + try: + success, msg = await acc.refresh_token() + results.append({ + "account_id": acc.id, + "name": acc.name, + "success": success, + "message": msg + }) + except Exception as e: + results.append({ + "account_id": acc.id, + "name": acc.name, + "success": False, + "message": str(e) + }) + + refreshed_count = len([r for r in results if r["success"]]) + return { + "ok": True, + "results": results, + "refreshed": refreshed_count, + "total": len(results) + } + + +async def restore_account(account_id: str): + """恢复账号(从冷却状态)""" + restored = quota_manager.restore(account_id) + if restored: + for acc in state.accounts: + if acc.id == account_id: + from ...credential import CredentialStatus + acc.status = CredentialStatus.ACTIVE + break + return {"ok": restored} + + +async def speedtest(): + """测试 API 延迟""" + account = state.get_available_account() + if not account: + return {"ok": False, "error": "No available account"} + + start = time.time() + try: + token = account.get_token() + machine_id = account.get_machine_id() + kiro_version = get_kiro_version() + + headers = { + "content-type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}", + "Authorization": f"Bearer {token}", + } + async with httpx.AsyncClient(verify=False, timeout=10) as client: + resp = await client.get(MODELS_URL, headers=headers, params={"origin": "AI_EDITOR"}) + latency = (time.time() - start) * 1000 + return { + "ok": resp.status_code == 200, + "latency_ms": round(latency, 2), + "status": resp.status_code, + "account_id": account.id + } + except Exception as e: + return {"ok": False, "error": str(e), "latency_ms": (time.time() - start) * 1000} + + +async def test_account_token(account_id: str): + """测试指定账号的 Token 是否有效 + + 测试内容: + 1. Token 是否存在 + 2. Token 是否过期 + 3. 调用 Kiro API 验证 Token 有效性 + 4. 获取用户邮箱(验证 Token 权限) + + Returns: + 测试结果,包含各项检查状态 + """ + # 查找账号 + account = None + for acc in state.accounts: + if acc.id == account_id: + account = acc + break + + if not account: + return {"ok": False, "error": "账号不存在"} + + result = { + "ok": True, + "account_id": account_id, + "account_name": account.name, + "tests": {} + } + + # 1. 检查 Token 是否存在 + token = account.get_token() + result["tests"]["token_exists"] = { + "passed": bool(token), + "message": "Token 存在" if token else "Token 不存在" + } + + if not token: + result["ok"] = False + return result + + # 2. 检查 Token 是否过期 + creds = account.get_credentials() + is_expired = account.is_token_expired() + is_expiring_soon = account.is_token_expiring_soon(10) + + result["tests"]["token_expiry"] = { + "passed": not is_expired, + "message": "Token 已过期" if is_expired else ("Token 即将过期" if is_expiring_soon else "Token 有效期正常"), + "expires_at": creds.expires_at if creds else None, + "is_expiring_soon": is_expiring_soon + } + + if is_expired: + result["ok"] = False + result["tests"]["token_expiry"]["suggestion"] = "请刷新 Token" + + # 3. 调用 Kiro API 验证 Token + start = time.time() + try: + machine_id = account.get_machine_id() + kiro_version = get_kiro_version() + + headers = { + "content-type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}", + "Authorization": f"Bearer {token}", + } + async with httpx.AsyncClient(verify=False, timeout=15) as client: + resp = await client.get(MODELS_URL, headers=headers, params={"origin": "AI_EDITOR"}) + latency = (time.time() - start) * 1000 + + api_ok = resp.status_code == 200 + result["tests"]["api_call"] = { + "passed": api_ok, + "message": "API 调用成功" if api_ok else f"API 调用失败 (HTTP {resp.status_code})", + "status_code": resp.status_code, + "latency_ms": round(latency, 2) + } + + if resp.status_code == 401: + result["tests"]["api_call"]["suggestion"] = "Token 无效或已过期,请刷新 Token" + result["ok"] = False + elif resp.status_code == 429: + result["tests"]["api_call"]["suggestion"] = "请求过于频繁,请稍后再试" + elif resp.status_code == 403: + result["tests"]["api_call"]["suggestion"] = "账号可能已被封禁" + result["ok"] = False + elif not api_ok: + result["ok"] = False + + except httpx.TimeoutException: + result["tests"]["api_call"] = { + "passed": False, + "message": "API 调用超时", + "suggestion": "网络连接问题,请检查网络" + } + result["ok"] = False + except Exception as e: + result["tests"]["api_call"] = { + "passed": False, + "message": f"API 调用异常: {str(e)}", + } + result["ok"] = False + + # 4. 尝试获取用户邮箱(验证 Token 权限) + try: + email = await _get_user_email(creds) + result["tests"]["get_email"] = { + "passed": bool(email), + "message": f"获取邮箱成功: {email}" if email else "无法获取邮箱", + "email": email + } + except Exception as e: + result["tests"]["get_email"] = { + "passed": False, + "message": f"获取邮箱失败: {str(e)}" + } + + # 汇总结果 + passed_count = sum(1 for t in result["tests"].values() if t.get("passed")) + total_count = len(result["tests"]) + result["summary"] = f"{passed_count}/{total_count} 项测试通过" + + return result + + +async def scan_tokens(): + """扫描系统中的 Kiro token 文件""" + from ...config import TOKEN_DIR + + found = [] + + # 扫描新目录 + if TOKEN_DIR.exists(): + for f in TOKEN_DIR.glob("*.json"): + try: + with open(f) as fp: + data = json.load(fp) + if "accessToken" in data: + # 检查是否已添加 + already_added = any(a.token_path == str(f) for a in state.accounts) + + auth_method = data.get("authMethod", "social") + client_id_hash = data.get("clientIdHash") + + # 检查 IdC 配置完整性 + idc_complete = None + if auth_method == "idc" and client_id_hash: + hash_file = TOKEN_DIR / f"{client_id_hash}.json" + if hash_file.exists(): + try: + with open(hash_file) as hf: + hash_data = json.load(hf) + idc_complete = bool(hash_data.get("clientId") and hash_data.get("clientSecret")) + except: + idc_complete = False + else: + idc_complete = False + + found.append({ + "path": str(f), + "name": f.stem, + "expires": data.get("expiresAt"), + "auth_method": auth_method, + "region": data.get("region", "us-east-1"), + "has_refresh_token": "refreshToken" in data, + "already_added": already_added, + "idc_config_complete": idc_complete, + }) + except: + pass + + # 兼容:也扫描旧的 AWS SSO 目录 + sso_cache = Path.home() / ".aws/sso/cache" + if sso_cache.exists(): + for f in sso_cache.glob("*.json"): + try: + with open(f) as fp: + data = json.load(fp) + if "accessToken" in data: + already_added = any(a.token_path == str(f) for a in state.accounts) + auth_method = data.get("authMethod", "social") + + found.append({ + "path": str(f), + "name": f.stem + " (旧目录)", + "expires": data.get("expiresAt"), + "auth_method": auth_method, + "region": data.get("region", "us-east-1"), + "has_refresh_token": "refreshToken" in data, + "already_added": already_added, + "idc_config_complete": None, + }) + except: + pass + + return {"tokens": found} + + +async def add_from_scan(request: Request): + """从扫描结果添加账号""" + body = await request.json() + token_path = body.get("path") + name = body.get("name", "扫描账号") + + if not token_path or not Path(token_path).exists(): + raise HTTPException(400, "Token 文件不存在") + + if any(a.token_path == token_path for a in state.accounts): + raise HTTPException(400, "该账号已添加") + + try: + with open(token_path) as f: + data = json.load(f) + if "accessToken" not in data: + raise HTTPException(400, "无效的 token 文件") + except json.JSONDecodeError: + raise HTTPException(400, "无效的 JSON 文件") + + account = Account( + id=uuid.uuid4().hex[:8], + name=name, + token_path=token_path + ) + state.accounts.append(account) + + # 预加载凭证 + account.load_credentials() + + # 保存配置 + state._save_accounts() + + return {"ok": True, "account_id": account.id} + + +async def export_config(): + """导出配置""" + return { + "accounts": [ + {"name": a.name, "token_path": a.token_path, "enabled": a.enabled} + for a in state.accounts + ], + "exported_at": datetime.now().isoformat() + } + + +async def import_config(request: Request): + """导入配置""" + body = await request.json() + accounts = body.get("accounts", []) + imported = 0 + + for acc_data in accounts: + token_path = acc_data.get("token_path", "") + if Path(token_path).exists(): + if not any(a.token_path == token_path for a in state.accounts): + account = Account( + id=uuid.uuid4().hex[:8], + name=acc_data.get("name", "导入账号"), + token_path=token_path, + enabled=acc_data.get("enabled", True) + ) + state.accounts.append(account) + account.load_credentials() + imported += 1 + + # 保存配置 + state._save_accounts() + + return {"ok": True, "imported": imported} + + +async def refresh_token_check(): + """检查所有账号的 token 状态""" + results = [] + for acc in state.accounts: + creds = acc.get_credentials() + if creds: + results.append({ + "id": acc.id, + "name": acc.name, + "valid": not acc.is_token_expired(), + "expiring_soon": acc.is_token_expiring_soon(), + "expires": creds.expires_at, + "auth_method": creds.auth_method, + "has_refresh_token": bool(creds.refresh_token), + }) + else: + results.append({ + "id": acc.id, + "name": acc.name, + "valid": False, + "error": "无法加载凭证" + }) + + return {"accounts": results} + + +async def get_quota_status(): + """获取配额状态""" + return { + "cooldown_seconds": quota_manager.COOLDOWN_SECONDS, + "exceeded_count": len(quota_manager.exceeded_records), + "exceeded_credentials": [ + { + "credential_id": r.credential_id, + "exceeded_at": r.exceeded_at, + "cooldown_until": r.cooldown_until, + "remaining_seconds": max(0, int(r.cooldown_until - time.time())), + "reason": r.reason + } + for r in quota_manager.exceeded_records.values() + ] + } + + +async def get_kiro_login_url(): + """获取 Kiro 登录说明""" + from ...config import TOKEN_DIR + return { + "message": "请使用本代理的登录功能,或从 Kiro IDE 导入 token", + "instructions": [ + "1. 点击「添加」按钮,选择登录方式", + "2. 或者从 Kiro IDE 的 ~/.aws/sso/cache/ 复制 token 文件", + "3. 将 token 文件放到 ~/.kiro-proxy/tokens/ 目录", + "4. 点击「扫描」按钮自动识别" + ], + "token_dir": str(TOKEN_DIR), + "token_dir_exists": TOKEN_DIR.exists() + } + + +async def get_detailed_stats(): + """获取详细统计信息""" + basic_stats = state.get_stats() + detailed = stats_manager.get_all_stats() + + return { + **basic_stats, + "detailed": detailed + } + + +async def run_health_check(): + """手动触发健康检查""" + results = [] + + for acc in state.accounts: + if not acc.enabled: + results.append({ + "id": acc.id, + "name": acc.name, + "status": "disabled", + "healthy": False + }) + continue + + try: + token = acc.get_token() + if not token: + acc.status = CredentialStatus.UNHEALTHY + results.append({ + "id": acc.id, + "name": acc.name, + "status": "no_token", + "healthy": False + }) + continue + + headers = { + "Authorization": f"Bearer {token}", + "content-type": "application/json" + } + + async with httpx.AsyncClient(verify=False, timeout=10) as client: + resp = await client.get( + MODELS_URL, + headers=headers, + params={"origin": "AI_EDITOR"} + ) + + if resp.status_code == 200: + if acc.status == CredentialStatus.UNHEALTHY: + acc.status = CredentialStatus.ACTIVE + results.append({ + "id": acc.id, + "name": acc.name, + "status": "healthy", + "healthy": True, + "latency_ms": resp.elapsed.total_seconds() * 1000 + }) + elif resp.status_code == 401: + acc.status = CredentialStatus.UNHEALTHY + results.append({ + "id": acc.id, + "name": acc.name, + "status": "auth_failed", + "healthy": False + }) + elif resp.status_code == 429: + results.append({ + "id": acc.id, + "name": acc.name, + "status": "rate_limited", + "healthy": True # 限流不代表不健康 + }) + else: + results.append({ + "id": acc.id, + "name": acc.name, + "status": f"error_{resp.status_code}", + "healthy": False + }) + + except Exception as e: + results.append({ + "id": acc.id, + "name": acc.name, + "status": "error", + "healthy": False, + "error": str(e) + }) + + healthy_count = len([r for r in results if r["healthy"]]) + return { + "ok": True, + "total": len(results), + "healthy": healthy_count, + "unhealthy": len(results) - healthy_count, + "results": results + } + + +# ==================== Kiro 登录 API ==================== + +async def get_browsers(): + """获取可用浏览器列表""" + return {"browsers": get_browsers_info()} + + +async def start_kiro_login(request: Request): + """启动 Kiro 设备授权登录""" + body = await request.json() if request.headers.get("content-type") == "application/json" else {} + region = body.get("region", "us-east-1") + + success, result = await start_device_flow(region) + + if success: + return { + "ok": True, + "user_code": result["user_code"], + "verification_uri": result["verification_uri"], + "expires_in": result["expires_in"], + "interval": result["interval"], + } + else: + return {"ok": False, "error": result.get("error", "未知错误")} + + +async def poll_kiro_login(): + """轮询 Kiro 登录状态""" + success, result = await poll_device_flow() + + if not success: + return {"ok": False, "error": result.get("error", "未知错误")} + + if result.get("completed"): + # 授权完成,保存凭证并添加账号 + credentials = result["credentials"] + + # 保存到文件 + from ...auth.device_flow import save_credentials_to_file + file_path = await save_credentials_to_file(credentials) + + # 尝试获取邮箱作为账号名称 + account_name = "在线登录账号" + try: + from ...credential import KiroCredentials + creds = KiroCredentials( + access_token=credentials.get("accessToken"), + refresh_token=credentials.get("refreshToken"), + auth_method=credentials.get("authMethod", "idc"), + ) + email = await _get_user_email(creds) + if email: + account_name = email + except Exception as e: + print(f"[DeviceFlow] 获取邮箱失败: {e}") + + # 添加账号 + account = Account( + id=uuid.uuid4().hex[:8], + name=account_name, + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + state._save_accounts() + + return { + "ok": True, + "completed": True, + "account_id": account.id, + "message": "登录成功,账号已添加" + } + else: + return { + "ok": True, + "completed": False, + "status": result.get("status", "pending") + } + + +async def cancel_kiro_login(): + """取消 Kiro 登录""" + cancelled = cancel_device_flow() + return {"ok": cancelled} + + +async def get_kiro_login_status(): + """获取当前登录状态""" + login_state = get_login_state() + if login_state: + return { + "ok": True, + "in_progress": True, + **login_state + } + else: + return {"ok": True, "in_progress": False} + + +# ==================== Social Auth API (Google/GitHub) ==================== + +async def start_social_login(request: Request): + """启动 Social Auth 登录 (Google/GitHub)""" + body = await request.json() if request.headers.get("content-type") == "application/json" else {} + provider = body.get("provider", "google") + + success, result = await start_social_auth(provider) + + if success: + return { + "ok": True, + "provider": result["provider"], + "login_url": result["login_url"], + "state": result["state"], + } + else: + return {"ok": False, "error": result.get("error", "未知错误")} + + +async def exchange_social_token(request: Request): + """交换 Social Auth Token""" + body = await request.json() + code = body.get("code") + oauth_state = body.get("state") + + if not code or not oauth_state: + return {"ok": False, "error": "缺少 code 或 state"} + + success, result = await exchange_social_auth_token(code, oauth_state) + + if not success: + return {"ok": False, "error": result.get("error", "未知错误")} + + if result.get("completed"): + # 保存凭证并添加账号 + credentials = result["credentials"] + provider = result.get("provider", "Social") + + # 保存到文件 + from ...auth.device_flow import save_credentials_to_file + file_path = await save_credentials_to_file(credentials, f"kiro-{provider.lower()}-auth") + + # 尝试获取邮箱作为账号名称 + account_name = f"{provider} 登录账号" + try: + from ...credential import KiroCredentials + creds = KiroCredentials( + access_token=credentials.get("accessToken"), + refresh_token=credentials.get("refreshToken"), + provider=provider, + ) + email = await _get_user_email(creds) + if email: + account_name = email + except Exception as e: + print(f"[SocialAuth] 获取邮箱失败: {e}") + + # 添加账号 + account = Account( + id=uuid.uuid4().hex[:8], + name=account_name, + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + state._save_accounts() + + return { + "ok": True, + "completed": True, + "account_id": account.id, + "provider": provider, + "message": f"{provider} 登录成功,账号已添加" + } + + return {"ok": False, "error": "Token 交换失败"} + + +async def cancel_social_login(): + """取消 Social Auth 登录""" + cancelled = cancel_social_auth() + return {"ok": cancelled} + + +async def get_social_login_status(): + """获取 Social Auth 状态""" + auth_state = get_social_auth_state() + if auth_state: + return { + "ok": True, + "in_progress": True, + **auth_state + } + else: + return {"ok": True, "in_progress": False} + + +# ==================== Flow Monitor API ==================== + +async def get_flows( + protocol: str = None, + model: str = None, + account_id: str = None, + state_filter: str = None, + has_error: bool = None, + bookmarked: bool = None, + search: str = None, + limit: int = 50, + offset: int = 0, +): + """查询 Flows""" + from ...core.flow_monitor import FlowState + + state_enum = None + if state_filter: + try: + state_enum = FlowState(state_filter) + except ValueError: + pass + + flows = flow_monitor.query( + protocol=protocol, + model=model, + account_id=account_id, + state=state_enum, + has_error=has_error, + bookmarked=bookmarked, + search=search, + limit=limit, + offset=offset, + ) + + return { + "flows": [f.to_dict() for f in flows], + "total": len(flows), + } + + +async def get_flow_detail(flow_id: str): + """获取 Flow 详情""" + flow = flow_monitor.get_flow(flow_id) + if not flow: + raise HTTPException(404, "Flow not found") + return flow.to_full_dict() + + +async def get_flow_stats(): + """获取 Flow 统计""" + return flow_monitor.get_stats() + + +async def bookmark_flow(flow_id: str, request: Request): + """书签 Flow""" + body = await request.json() + bookmarked = body.get("bookmarked", True) + flow_monitor.bookmark_flow(flow_id, bookmarked) + return {"ok": True} + + +async def add_flow_note(flow_id: str, request: Request): + """添加 Flow 备注""" + body = await request.json() + note = body.get("note", "") + flow_monitor.add_note(flow_id, note) + return {"ok": True} + + +async def add_flow_tag(flow_id: str, request: Request): + """添加 Flow 标签""" + body = await request.json() + tag = body.get("tag", "") + if tag: + flow_monitor.add_tag(flow_id, tag) + return {"ok": True} + + +async def export_flows(request: Request): + """导出 Flows""" + body = await request.json() + flow_ids = body.get("flow_ids", []) + format = body.get("format", "json") + + content = flow_monitor.export(flow_ids if flow_ids else None, format) + return {"content": content, "format": format} + + +# ==================== Usage API ==================== + +async def get_account_usage_info(account_id: str): + """获取账号用量信息""" + for acc in state.accounts: + if acc.id == account_id: + success, result = await get_account_usage(acc) + if success: + return { + "ok": True, + "account_id": account_id, + "account_name": acc.name, + "usage": { + "subscription_title": result.subscription_title, + "usage_limit": result.usage_limit, + "current_usage": result.current_usage, + "balance": result.balance, + "is_low_balance": result.is_low_balance, + "free_trial_limit": result.free_trial_limit, + "free_trial_usage": result.free_trial_usage, + "bonus_limit": result.bonus_limit, + "bonus_usage": result.bonus_usage, + } + } + else: + return {"ok": False, "error": result.get("error", "查询失败")} + raise HTTPException(404, "Account not found") + + +# ==================== 账号导入导出 API ==================== + +async def export_accounts(): + """导出所有账号配置(包含 token)""" + accounts_data = [] + for acc in state.accounts: + creds = acc.get_credentials() + if creds: + accounts_data.append({ + "name": acc.name, + "enabled": acc.enabled, + "credentials": { + "accessToken": creds.access_token, + "refreshToken": creds.refresh_token, + "expiresAt": creds.expires_at, + "region": creds.region or "us-east-1", + "authMethod": creds.auth_method or "social", + "clientId": creds.client_id, + "clientSecret": creds.client_secret, + } + }) + return { + "ok": True, + "accounts": accounts_data, + "exported_at": datetime.now().isoformat(), + "version": "1.0" + } + + +async def import_accounts(request: Request): + """导入账号配置 + + 支持: + - Refresh Token 必填,Access Token 可选 + - 账号名可选(可自动获取邮箱) + - 根据 Refresh Token 去重 + """ + body = await request.json() + accounts_data = body.get("accounts", []) + imported = 0 + + for acc_data in accounts_data: + token_path = acc_data.get("token_path", "") + if Path(token_path).exists(): + if not any(a.token_path == token_path for a in state.accounts): + account = Account( + id=uuid.uuid4().hex[:8], + name=acc_data.get("name", "导入账号"), + token_path=token_path, + enabled=acc_data.get("enabled", True) + ) + state.accounts.append(account) + account.load_credentials() + imported += 1 + + # 保存配置 + state._save_accounts() + + return {"ok": True, "imported": imported} + + +async def refresh_token_check(): + """检查所有账号的 token 状态""" + results = [] + for acc in state.accounts: + creds = acc.get_credentials() + if creds: + results.append({ + "id": acc.id, + "name": acc.name, + "valid": not acc.is_token_expired(), + "expiring_soon": acc.is_token_expiring_soon(), + "expires": creds.expires_at, + "auth_method": creds.auth_method, + "has_refresh_token": bool(creds.refresh_token), + }) + else: + results.append({ + "id": acc.id, + "name": acc.name, + "valid": False, + "error": "无法加载凭证" + }) + + return {"accounts": results} + + +async def add_manual_token(request: Request): + """手动添加 Token + + 支持: + - Refresh Token 必填,Access Token 可选(可通过 Refresh Token 获取) + - 账号名可选(可自动获取邮箱作为名称) + - 根据 Refresh Token 去重 + - 支持 authMethod: social/idc + - 支持 provider: Google/Github (社交登录) + - 支持 clientId/clientSecret (IDC 认证) + """ + body = await request.json() + access_token = body.get("access_token", "").strip() + refresh_token = body.get("refresh_token", "").strip() + name = body.get("name", "").strip() + region = body.get("region", "us-east-1") + auth_method = body.get("auth_method", "social") + provider = body.get("provider", "").strip() # Google/Github + client_id = body.get("client_id", "").strip() + client_secret = body.get("client_secret", "").strip() + + # Refresh Token 必填 + if not refresh_token: + raise HTTPException(400, "缺少 refresh_token(必填)") + + # IDC 认证需要 clientId 和 clientSecret + if auth_method == "idc" and (not client_id or not client_secret): + raise HTTPException(400, "IDC 认证需要 client_id 和 client_secret") + + # 检查 Refresh Token 是否已存在(去重) + for acc in state.accounts: + creds = acc.get_credentials() + if creds and creds.refresh_token == refresh_token: + raise HTTPException(400, f"该 Refresh Token 已存在,对应账号: {acc.name} ({acc.id})") + + # 构建凭证对象 + from ...credential import KiroCredentials, TokenRefresher + + creds = KiroCredentials( + access_token=access_token if access_token else None, + refresh_token=refresh_token, + region=region, + auth_method=auth_method, + provider=provider if provider else None, + client_id=client_id if client_id else None, + client_secret=client_secret if client_secret else None, + ) + + # 如果没有 Access Token,通过 Refresh Token 获取 + if not access_token: + refresher = TokenRefresher(creds) + success, result = await refresher.refresh() + if not success: + raise HTTPException(400, f"无法通过 Refresh Token 获取 Access Token: {result}") + # refresh 成功后 creds.access_token 已更新 + + # 如果没有提供名称,尝试获取邮箱作为名称 + auto_name = None + if not name: + try: + email = await _get_user_email(creds) + if email: + auto_name = email + except Exception as e: + print(f"[AddAccount] 获取邮箱失败: {e}") + + final_name = name or auto_name or "手动添加账号" + + # 构建保存的凭证数据 + creds_data = { + "accessToken": creds.access_token, + "refreshToken": creds.refresh_token, + "expiresAt": creds.expires_at, + "region": region, + "authMethod": auth_method, + "profileArn": creds.profile_arn, + } + + # 添加 provider 字段(社交登录) + if provider: + creds_data["provider"] = provider + + # 添加 IDC 认证字段 + if client_id: + creds_data["clientId"] = client_id + if client_secret: + creds_data["clientSecret"] = client_secret + + # 保存凭证到文件 + file_path = await save_credentials_to_file(creds_data, f"manual-{uuid.uuid4().hex[:8]}") + + # 添加账号 + account = Account( + id=uuid.uuid4().hex[:8], + name=final_name, + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + state._save_accounts() + + return { + "ok": True, + "account_id": account.id, + "name": final_name, + "auto_name": auto_name is not None + } + + +async def batch_import_accounts(request: Request): + """批量导入账号 + + 接收 JSON 数组,每个元素包含: + - refresh_token: 必填 + - access_token: 可选 + - name: 可选(自动获取邮箱) + - auth_method: 可选,默认 social + - provider: 可选 (Google/Github) + - client_id, client_secret: IDC 认证需要 + - region: 可选,默认 us-east-1 + + 返回导入结果统计 + """ + body = await request.json() + accounts_data = body.get("accounts", []) + + if not accounts_data: + raise HTTPException(400, "accounts 数组为空") + + results = { + "total": len(accounts_data), + "success": 0, + "skipped": 0, + "failed": 0, + "details": [] + } + + # 获取现有 refresh_token 集合(去重) + existing_refresh_tokens = set() + for acc in state.accounts: + creds = acc.get_credentials() + if creds and creds.refresh_token: + existing_refresh_tokens.add(creds.refresh_token) + + from ...credential import KiroCredentials, TokenRefresher + + for i, acc_data in enumerate(accounts_data): + try: + refresh_token = acc_data.get("refresh_token", "").strip() + access_token = acc_data.get("access_token", "").strip() + name = acc_data.get("name", "").strip() + auth_method = acc_data.get("auth_method", "social") + provider = acc_data.get("provider", "").strip() + client_id = acc_data.get("client_id", "").strip() + client_secret = acc_data.get("client_secret", "").strip() + region = acc_data.get("region", "us-east-1") + + # 验证必填字段 + if not refresh_token: + results["failed"] += 1 + results["details"].append({"index": i, "status": "failed", "error": "缺少 refresh_token"}) + continue + + # 去重检查 + if refresh_token in existing_refresh_tokens: + results["skipped"] += 1 + results["details"].append({"index": i, "status": "skipped", "error": "refresh_token 已存在"}) + continue + + # IDC 认证验证 + if auth_method == "idc" and (not client_id or not client_secret): + results["failed"] += 1 + results["details"].append({"index": i, "status": "failed", "error": "IDC 认证需要 client_id 和 client_secret"}) + continue + + # 构建凭证 + creds = KiroCredentials( + access_token=access_token if access_token else None, + refresh_token=refresh_token, + region=region, + auth_method=auth_method, + provider=provider if provider else None, + client_id=client_id if client_id else None, + client_secret=client_secret if client_secret else None, + ) + + # 如果没有 access_token,尝试刷新获取 + if not access_token: + refresher = TokenRefresher(creds) + success, result = await refresher.refresh() + if not success: + results["failed"] += 1 + results["details"].append({"index": i, "status": "failed", "error": f"Token 刷新失败: {result}"}) + continue + + # 获取邮箱作为名称 + final_name = name + if not final_name: + try: + email = await _get_user_email(creds) + if email: + final_name = email + except Exception: + pass + final_name = final_name or f"批量导入账号 {i+1}" + + # 保存凭证 + creds_data = { + "accessToken": creds.access_token, + "refreshToken": creds.refresh_token, + "expiresAt": creds.expires_at, + "region": region, + "authMethod": auth_method, + "profileArn": creds.profile_arn, + } + if provider: + creds_data["provider"] = provider + if client_id: + creds_data["clientId"] = client_id + if client_secret: + creds_data["clientSecret"] = client_secret + + file_path = await save_credentials_to_file(creds_data, f"batch-{uuid.uuid4().hex[:8]}") + + # 添加账号 + account = Account( + id=uuid.uuid4().hex[:8], + name=final_name, + token_path=file_path + ) + state.accounts.append(account) + account.load_credentials() + + # 添加到已存在集合 + existing_refresh_tokens.add(refresh_token) + + results["success"] += 1 + results["details"].append({"index": i, "status": "success", "account_id": account.id, "name": final_name}) + + except Exception as e: + results["failed"] += 1 + results["details"].append({"index": i, "status": "failed", "error": str(e)}) + + # 保存配置 + state._save_accounts() + + return { + "ok": True, + **results + } + + +async def _get_user_email(creds: 'KiroCredentials') -> Optional[str]: + """通过 Kiro API 获取用户邮箱""" + from ...core.kiro_api import get_user_email + + if not creds.access_token: + return None + + # 获取 provider + provider = creds.provider or "Google" + + try: + email = await get_user_email(creds.access_token, provider) + if email: + print(f"[GetUserEmail] 成功获取邮箱: {email}") + return email + except Exception as e: + print(f"[GetUserEmail] 请求失败: {e}") + + return None + + +# ==================== 额度管理 API ==================== + +async def get_accounts_status_enhanced(): + """获取完整账号状态(增强版)""" + return { + "ok": True, + "summary": state.get_accounts_summary(), + "accounts": state.get_accounts_status() + } + + +async def refresh_account_quota(account_id: str): + """刷新单个账号额度""" + from ...core import get_quota_scheduler + scheduler = get_quota_scheduler() + + success = await scheduler.refresh_account(account_id) + + if success: + return {"ok": True, "message": f"账号 {account_id} 额度刷新成功"} + else: + return {"ok": False, "error": f"账号 {account_id} 额度刷新失败"} + + +async def refresh_all_quotas(): + """刷新所有账号额度""" + from ...core import get_quota_scheduler + scheduler = get_quota_scheduler() + + results = await scheduler.refresh_all() + + success_count = sum(1 for v in results.values() if v) + fail_count = len(results) - success_count + + return { + "ok": True, + "results": results, + "success_count": success_count, + "fail_count": fail_count + } + + +# ==================== 优先账号 API ==================== + +async def get_priority_accounts(): + """获取优先账号列表""" + from ...core import get_account_selector + selector = get_account_selector() + + priority_ids = selector.get_priority_accounts() + + # 获取账号详情 + priority_accounts = [] + for pid in priority_ids: + for acc in state.accounts: + if acc.id == pid: + priority_accounts.append({ + "id": acc.id, + "name": acc.name, + "enabled": acc.enabled, + "available": acc.is_available(), + "order": selector.get_priority_order(acc.id) + }) + break + + return { + "ok": True, + "priority_accounts": priority_accounts, + "strategy": selector.strategy.value + } + + +async def set_priority_account(account_id: str, request: Request): + """设置优先账号""" + from ...core import get_account_selector + selector = get_account_selector() + + body = await request.json() if request.headers.get("content-type") == "application/json" else {} + position = body.get("position", -1) + + valid_ids = state.get_valid_account_ids() + success, message = selector.add_priority_account(account_id, position, valid_ids) + + return {"ok": success, "message": message} + + +async def remove_priority_account(account_id: str): + """取消优先账号""" + from ...core import get_account_selector + selector = get_account_selector() + + success, message = selector.remove_priority_account(account_id) + + return {"ok": success, "message": message} + + +async def reorder_priority_accounts(request: Request): + """调整优先账号顺序""" + from ...core import get_account_selector + selector = get_account_selector() + + body = await request.json() + account_ids = body.get("account_ids", []) + + success, message = selector.reorder_priority(account_ids) + + return {"ok": success, "message": message} + + +# ==================== 汇总统计 API ==================== + +async def get_accounts_summary(): + """获取账号汇总统计""" + return { + "ok": True, + "summary": state.get_accounts_summary() + } + + +# ==================== 刷新进度 API ==================== + +async def get_refresh_progress(): + """获取刷新进度""" + from ...core import get_refresh_manager + manager = get_refresh_manager() + + progress = manager.get_progress_dict() + is_refreshing = manager.is_refreshing() + + if progress: + return { + "ok": True, + "is_refreshing": is_refreshing, + "progress": progress, + "progress_percent": progress.get("total", 0) and round( + (progress.get("completed", 0) / progress.get("total", 1)) * 100, 2 + ) + } + else: + return { + "ok": True, + "is_refreshing": is_refreshing, + "progress": None, + "message": "没有进行中的刷新操作" + } + + +async def refresh_all_with_progress(): + """批量刷新(带进度和锁检查) + + 使用 RefreshManager 进行批量刷新,支持: + - 全局锁防止重复刷新 + - 进度跟踪 + - 自动刷新 Token + - 重试机制 + + 注意:刷新操作在后台执行,API 立即返回,前端通过轮询获取进度。 + """ + import asyncio + from ...core import get_refresh_manager, get_account_usage + manager = get_refresh_manager() + + # 检查是否已有刷新在进行 + if manager.is_refreshing(): + progress = manager.get_progress_dict() + return { + "ok": False, + "error": "刷新操作正在进行中", + "progress": progress + } + + # 定义获取额度的函数 + async def get_quota_func(account): + """获取账号额度""" + success, result = await get_account_usage(account) + if success: + # 更新额度缓存 + from ...core import get_quota_cache + from ...core.quota_cache import CachedQuota + quota_cache = get_quota_cache() + cached_quota = CachedQuota.from_usage_info(account.id, result) + quota_cache.set(account.id, cached_quota) + + # 自动启用/禁用账号 + if cached_quota.is_exhausted: + # 额度用尽,自动禁用 + if account.enabled: + account.enabled = False + if hasattr(account, "auto_disabled"): + account.auto_disabled = True + print(f"[RefreshManager] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用") + else: + # 有额度,自动启用(仅对自动禁用的账号生效) + if (not account.enabled) and getattr(account, "auto_disabled", False): + account.enabled = True + account.auto_disabled = False + print(f"[RefreshManager] 账号 {account.id} ({account.name}) 有可用额度,自动启用") + + return True, result + else: + return False, result + + # 定义后台刷新任务 + async def background_refresh(): + """后台执行刷新""" + try: + await manager.refresh_all_with_token( + accounts=state.accounts, + get_quota_func=get_quota_func, + skip_disabled=False, # 不跳过禁用账号,以便检查是否可以解禁 + skip_error=False # 不跳过错误账号,以便检查是否已恢复 + ) + # 刷新完成后保存账号配置(因为可能有启用/禁用状态变化) + state._save_accounts() + except Exception as e: + print(f"[RefreshManager] 后台刷新异常: {e}") + + # 启动后台任务,不等待完成 + asyncio.create_task(background_refresh()) + + # 立即返回,前端通过轮询获取进度 + return { + "ok": True, + "message": "刷新任务已启动,请通过 /api/refresh/progress 获取进度" + } + + +async def get_refresh_config(): + """获取刷新配置""" + from ...core import get_refresh_manager + manager = get_refresh_manager() + + config = manager.config + return { + "ok": True, + "config": config.to_dict() + } + + +async def update_refresh_config(request: Request): + """更新刷新配置""" + from ...core import get_refresh_manager + manager = get_refresh_manager() + + body = await request.json() + + try: + # 更新配置 + manager.update_config(**body) + + return { + "ok": True, + "config": manager.config.to_dict(), + "message": "配置更新成功" + } + except ValueError as e: + return { + "ok": False, + "error": str(e) + } + + +async def get_refresh_manager_status(): + """获取刷新管理器状态""" + from ...core import get_refresh_manager + manager = get_refresh_manager() + + status = manager.get_status() + auto_refresh_status = manager.get_auto_refresh_status() + + return { + "ok": True, + "status": status, + "auto_refresh": auto_refresh_status, + "last_refresh_time": manager.get_last_refresh_time() + } + + +# ==================== 集成 RefreshManager 到现有刷新接口 ==================== + +async def refresh_account_token_with_manager(account_id: str): + """刷新指定账号的 token(集成 RefreshManager) + + 刷新前自动检查 Token 状态,使用 RefreshManager 的重试机制。 + """ + from ...core import get_refresh_manager + manager = get_refresh_manager() + + # 查找账号 + account = None + for acc in state.accounts: + if acc.id == account_id: + account = acc + break + + if not account: + return {"ok": False, "error": "账号不存在"} + + # 使用 RefreshManager 的重试机制刷新 Token + success, result = await manager.retry_with_backoff( + account.refresh_token + ) + + if success: + return {"ok": True, "message": "Token 刷新成功"} + else: + return {"ok": False, "error": f"Token 刷新失败: {result}"} + + +async def refresh_account_quota_with_token(account_id: str): + """刷新单个账号额度(先刷新 Token) + + 在获取额度前自动检查并刷新 Token(如果需要)。 + """ + from ...core import get_refresh_manager, get_account_usage, get_quota_cache + manager = get_refresh_manager() + + # 查找账号 + account = None + for acc in state.accounts: + if acc.id == account_id: + account = acc + break + + if not account: + return {"ok": False, "error": "账号不存在"} + + # 先刷新 Token(如果需要) + token_success, token_msg = await manager.refresh_token_if_needed(account) + + if not token_success: + return {"ok": False, "error": f"Token 刷新失败: {token_msg}"} + + # 获取额度 + success, result = await get_account_usage(account) + + if success: + # 更新额度缓存 + from ...core.quota_cache import CachedQuota + quota_cache = get_quota_cache() + cached_quota = CachedQuota.from_usage_info(account.id, result) + quota_cache.set(account.id, cached_quota) + + # 自动启用/禁用账号 + auto_status_changed = False + if cached_quota.is_exhausted: + # 额度用尽,自动禁用 + if account.enabled: + account.enabled = False + if hasattr(account, "auto_disabled"): + account.auto_disabled = True + auto_status_changed = True + print(f"[RefreshManager] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用") + else: + # 有额度,自动启用(仅对自动禁用的账号生效) + if (not account.enabled) and getattr(account, "auto_disabled", False): + account.enabled = True + account.auto_disabled = False + auto_status_changed = True + print(f"[RefreshManager] 账号 {account.id} ({account.name}) 有可用额度,自动启用") + + # 如果状态变化,保存配置 + if auto_status_changed: + state._save_accounts() + + return { + "ok": True, + "message": f"账号 {account_id} 额度刷新成功", + "token_refreshed": token_msg != "Token 有效,无需刷新", + "auto_enabled": auto_status_changed and account.enabled, + "auto_disabled": auto_status_changed and not account.enabled, + "usage": { + "balance": result.balance, + "current_usage": result.current_usage, + "usage_limit": result.usage_limit + } + } + else: + error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else str(result) + + # 更新额度缓存,包含错误信息(用于检测封禁) + from ...core.quota_cache import CachedQuota + quota_cache = get_quota_cache() + cached_quota = CachedQuota.from_error(account.id, error_msg) + quota_cache.set(account.id, cached_quota) + + return {"ok": False, "error": f"获取额度失败: {error_msg}"} + + +# ==================== 协议注册 API ==================== + +async def register_kiro_protocol(): + """注册 kiro:// 协议""" + from ...core.protocol_handler import ( + register_protocol_windows, + start_callback_server, + is_protocol_registered + ) + + # 启动回调服务器 + server_success, server_result = start_callback_server() + if not server_success: + return {"ok": False, "error": f"启动回调服务器失败: {server_result}"} + + # 注册协议 + reg_success, reg_msg = register_protocol_windows() + + return { + "ok": reg_success, + "message": reg_msg, + "callback_port": server_result if server_success else None, + "is_registered": is_protocol_registered() + } + + +async def unregister_kiro_protocol(): + """取消注册 kiro:// 协议""" + from ...core.protocol_handler import ( + unregister_protocol_windows, + stop_callback_server + ) + + # 停止回调服务器 + stop_callback_server() + + # 取消注册协议 + success, msg = unregister_protocol_windows() + + return {"ok": success, "message": msg} + + +async def get_protocol_status(): + """获取协议注册状态""" + from ...core.protocol_handler import is_protocol_registered, CALLBACK_PORT + + return { + "is_registered": is_protocol_registered(), + "callback_port": CALLBACK_PORT + } + + +async def get_callback_result(): + """获取回调结果""" + from ...core.protocol_handler import get_callback_result as _get_result, clear_callback_result + + result = _get_result() + if result: + # 清除结果,避免重复获取 + clear_callback_result() + return {"ok": True, "result": result} + else: + return {"ok": False, "result": None} diff --git a/KiroProxy/kiro_proxy/handlers/anthropic/__init__.py b/KiroProxy/kiro_proxy/handlers/anthropic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24c39d150dee192b3732d723c99abe20f28ba522 --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/anthropic/__init__.py @@ -0,0 +1,1592 @@ +"""Anthropic 协议处理 - /v1/messages""" +import json +import uuid +import time +import asyncio +import httpx +from fastapi import Request, HTTPException +from fastapi.responses import StreamingResponse + +from ...config import KIRO_API_URL, map_model_name +from ...core import state, RetryableRequest, is_retryable_error, stats_manager, flow_monitor, TokenUsage +from ...core.state import RequestLog +from ...core.history_manager import HistoryManager, get_history_config, is_content_length_error, TruncateStrategy +from ...core.error_handler import classify_error, ErrorType, format_error_log +from ...core.rate_limiter import get_rate_limiter +from ...credential import quota_manager +from ...kiro_api import build_headers, build_kiro_request, parse_event_stream_full, parse_event_stream, is_quota_exceeded_error +from ...core.thinking import ( + ThinkingConfig, + build_thinking_prompt, + build_user_prompt_with_thinking, + infer_thinking_from_anthropic_messages, + normalize_thinking_config, + strip_thinking_from_history, +) +from ...converters import ( + generate_session_id, + convert_anthropic_tools_to_kiro, + convert_anthropic_messages_to_kiro, + convert_kiro_response_to_anthropic, + extract_images_from_content, + inject_thinking_tags_to_system, + find_real_thinking_start_tag, + find_real_thinking_end_tag, + extract_thinking_from_content +) + +# 尝试导入 tiktoken,如果失败则使用估算方法 +try: + import tiktoken + _encoding = tiktoken.get_encoding("cl100k_base") + _USE_TIKTOKEN = True +except ImportError: + _encoding = None + _USE_TIKTOKEN = False + + +def _extract_text_from_content(content) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + parts.append(_extract_text_from_content(item)) + return "".join(parts) + if isinstance(content, dict): + if "text" in content and isinstance(content.get("text"), str): + return content["text"] + if "content" in content: + return _extract_text_from_content(content.get("content")) + return "" + + +def _estimate_tokens(text: str) -> int: + """估算/计算 token 数量 + + 优先使用 tiktoken (cl100k_base),否则使用字符估算: + - 中文字符:约 1.5 字符 = 1 token + - 其他字符:约 4 字符 = 1 token + """ + if not text: + return 0 + + if _USE_TIKTOKEN and _encoding: + return len(_encoding.encode(text)) + + # 回退到字符估算 + chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + other_chars = len(text) - chinese_chars + tokens = int(chinese_chars / 1.5) + int(other_chars / 4) + return max(1, tokens) + + +def _count_tokens_from_messages(messages, system: str = "") -> int: + total = _estimate_tokens(system) if system else 0 + for msg in messages or []: + total += _estimate_tokens(_extract_text_from_content(msg.get("content"))) + return total + + +def _estimate_output_tokens_from_text(text: str) -> int: + return _estimate_tokens(text) + + +async def _check_and_disable_if_exhausted(account): + """检查账号额度,如果为 0 则禁用账号 + + Args: + account: 账号对象 + """ + if not account: + return + + try: + from ...core.usage import get_account_usage + from ...core.quota_cache import CachedQuota, get_quota_cache + + success, result = await get_account_usage(account) + if success: + quota = CachedQuota.from_usage_info(account.id, result) + get_quota_cache().set(account.id, quota) + + if quota.is_exhausted: + account.enabled = False + if hasattr(account, "auto_disabled"): + account.auto_disabled = True + from ...core.state import state + state._save_accounts() + print(f"[Account] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用") + except Exception as e: + print(f"[Account] 检查账号 {account.id} 额度失败: {e}") + + +def _handle_kiro_error(status_code: int, error_text: str, account): + """处理 Kiro API 错误,返回 (http_status, error_type, error_message)""" + error = classify_error(status_code, error_text) + + # 打印友好的错误日志 + print(format_error_log(error, account.id if account else None)) + + # 账号封禁 - 禁用账号 + if error.should_disable_account and account: + account.enabled = False + if hasattr(account, "auto_disabled"): + account.auto_disabled = False + from ...credential import CredentialStatus + account.status = CredentialStatus.SUSPENDED + try: + from ...core import state as _state + _state._save_accounts() + except Exception: + pass + print(f"[Account] 账号 {account.id} 已被禁用 (封禁)") + + # 仅 429 状态码触发冷却(不再根据错误文本判断) + elif status_code == 429 and account: + account.mark_quota_exceeded(error.message[:100]) + + # 其他错误(非 429、非内容过长)- 异步检查额度 + elif error.type not in (ErrorType.RATE_LIMITED, ErrorType.CONTENT_TOO_LONG) and account: + import asyncio + asyncio.create_task(_check_and_disable_if_exhausted(account)) + + # 映射错误类型 + error_type_map = { + ErrorType.ACCOUNT_SUSPENDED: (403, "authentication_error"), + ErrorType.RATE_LIMITED: (429, "rate_limit_error"), + ErrorType.CONTENT_TOO_LONG: (400, "invalid_request_error"), + ErrorType.AUTH_FAILED: (401, "authentication_error"), + ErrorType.SERVICE_UNAVAILABLE: (503, "api_error"), + ErrorType.MODEL_UNAVAILABLE: (503, "overloaded_error"), + ErrorType.UNKNOWN: (500, "api_error"), + } + + http_status, err_type = error_type_map.get(error.type, (500, "api_error")) + return http_status, err_type, error.user_message, error + + +async def handle_count_tokens(request: Request): + '''Handle /v1/messages/count_tokens requests.''' + body = await request.json() + messages = body.get("messages", []) + system = body.get("system", "") + if not messages and not system: + raise HTTPException(400, "messages required") + return {"input_tokens": _count_tokens_from_messages(messages, system)} + + +async def _call_kiro_for_summary(prompt: str, account, headers: dict) -> str: + """调用 Kiro API 生成摘要(内部使用)""" + kiro_request = build_kiro_request(prompt, "claude-haiku-4.5", []) # 用快速模型生成摘要 + try: + async with httpx.AsyncClient(verify=False, timeout=60) as client: + resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers) + if resp.status_code == 200: + return parse_event_stream(resp.content) + except Exception as e: + print(f"[Summary] API 调用失败: {e}") + return "" + + +async def handle_messages(request: Request): + """处理 /v1/messages 请求""" + start_time = time.time() + log_id = uuid.uuid4().hex[:8] + + try: + hdrs = request.headers + print( + f"[Anthropic][Headers:{log_id}] accept={hdrs.get('accept')} content-type={hdrs.get('content-type')} " + f"anthropic-beta={hdrs.get('anthropic-beta')} user-agent={hdrs.get('user-agent')}" + ) + except Exception: + pass + + body = await request.json() + model = map_model_name(body.get("model", "claude-sonnet-4")) + messages = body.get("messages", []) + system = body.get("system", "") + stream = body.get("stream", False) + tools = body.get("tools", []) + + # 处理思考功能(Extended Thinking) + thinking_explicit = "thinking" in body + thinking_cfg: ThinkingConfig = ( + normalize_thinking_config(body.get("thinking")) if thinking_explicit else ThinkingConfig(False, None) + ) + if not thinking_explicit and infer_thinking_from_anthropic_messages(messages): + # Claude Code 可能只在首轮携带 thinking 配置;如果历史里已经出现 thinking block,默认继承开启。 + thinking_cfg = ThinkingConfig(True, None) + + # 启用思考模式:使用“独立请求”生成思维链,避免向主请求注入提示词污染上下文 + if thinking_cfg.enabled: + print( + f"[Anthropic] Thinking mode enabled (separate request): budget_tokens={thinking_cfg.budget_tokens if thinking_cfg.budget_tokens is not None else 'unlimited'}" + ) + + # 调试:打印原始请求的关键信息 + print( + f"[Anthropic] Request: model={body.get('model')} -> {model}, messages={len(messages)}, stream={stream}, tools={len(tools)}, thinking={'enabled' if thinking_cfg.enabled else 'disabled'}" + ) + + if not messages: + raise HTTPException(400, "messages required") + + session_id = generate_session_id(messages) + account = state.get_available_account(session_id) + + if not account: + raise HTTPException(503, "All accounts are rate limited or unavailable") + + # 创建 Flow 记录 + flow_id = flow_monitor.create_flow( + protocol="anthropic", + method="POST", + path="/v1/messages", + headers=dict(request.headers), + body=body, + account_id=account.id, + account_name=account.name, + ) + + # 检查 token 是否即将过期,尝试刷新 + if account.is_token_expiring_soon(5): + print(f"[Anthropic] Token 即将过期,尝试刷新: {account.id}") + success, msg = await account.refresh_token() + if not success: + print(f"[Anthropic] Token 刷新失败: {msg}") + + token = account.get_token() + if not token: + flow_monitor.fail_flow(flow_id, "authentication_error", f"Failed to get token for account {account.name}") + raise HTTPException(500, f"Failed to get token for account {account.name}") + + # 使用账号的动态 Machine ID(提前构建,供摘要使用) + creds = account.get_credentials() + headers = build_headers( + token, + machine_id=account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + + # 限速检查 + rate_limiter = get_rate_limiter() + can_request, wait_seconds, reason = rate_limiter.can_request(account.id) + if not can_request: + print(f"[Anthropic] 限速: {reason}") + await asyncio.sleep(wait_seconds) + + # 转换消息格式 + user_content, history, tool_results = convert_anthropic_messages_to_kiro(messages, system) + + # 历史消息预处理 + history_manager = HistoryManager(get_history_config(), cache_key=session_id) + + # 检查是否需要智能摘要或错误重试预摘要 + async def api_caller(prompt: str) -> str: + return await _call_kiro_for_summary(prompt, account, headers) + if history_manager.should_summarize(history) or history_manager.should_pre_summary_for_error_retry(history, user_content): + history = await history_manager.pre_process_async(history, user_content, api_caller) + else: + history = history_manager.pre_process(history, user_content) + + # 摘要/截断后再次修复历史交替和 toolUses/toolResults 配对 + from ...converters import fix_history_alternation + history = fix_history_alternation(history) + + if history_manager.was_truncated: + print(f"[Anthropic] {history_manager.truncate_info}") + + # 提取最后一条消息中的图片 + images = [] + if messages: + last_msg = messages[-1] + if last_msg.get("role") == "user": + _, images = extract_images_from_content(last_msg.get("content", "")) + + # 构建 Kiro 请求 + kiro_tools = convert_anthropic_tools_to_kiro(tools) if tools else None + clean_history = strip_thinking_from_history(history) + kiro_request = build_kiro_request(user_content, model, clean_history, kiro_tools, images, tool_results) + + if stream: + return await _handle_stream( + kiro_request, + headers, + account, + model, + log_id, + start_time, + session_id, + flow_id, + history, + user_content, + kiro_tools, + images, + tool_results, + history_manager, + thinking_enabled=thinking_cfg.enabled, + budget_tokens=thinking_cfg.budget_tokens, + ) + else: + return await _handle_non_stream( + kiro_request, + headers, + account, + model, + log_id, + start_time, + session_id, + flow_id, + history, + user_content, + kiro_tools, + images, + tool_results, + history_manager, + thinking_enabled=thinking_cfg.enabled, + budget_tokens=thinking_cfg.budget_tokens, + ) + + +async def _handle_stream(kiro_request, headers, account, model, log_id, start_time, session_id=None, flow_id=None, history=None, user_content="", kiro_tools=None, images=None, tool_results=None, history_manager=None, thinking_enabled=False, budget_tokens: int | None = None): + """Handle streaming responses with auto-retry on quota exceeded and network errors. + + When thinking_enabled=True, makes TWO separate API calls: + 1. First call: Generate thinking/reasoning content (streamed as thinking block) + 2. Second call: Generate actual response (streamed as text block) + """ + + async def generate(): + nonlocal kiro_request, history + current_account = account + retry_count = 0 + max_retries = 2 + full_content = "" + saw_any_chunk = False + saw_any_text = False + sent_any_event = False + content_block_index_ref = [0] + + def _next_index() -> int: + content_block_index_ref[0] += 1 + return content_block_index_ref[0] - 1 + + print(f"[Anthropic][Stream:{log_id}] start model={model} account={getattr(current_account, 'id', None)} thinking={thinking_enabled}") + + def _build_thinking_history( + base_history: list | None, + current_user_content: str, + current_tool_results: list | None, + ) -> list: + """Kiro 对 toolUses/toolResults 配对非常严格。 + + 对于 thinking 的"独立请求",最稳妥的做法是: + 完全过滤掉所有涉及 toolUses/toolResults 的历史消息, + 只保留纯文本对话历史。这样可以保证 thinking 请求永远不会因为 + 工具链结构问题触发 Kiro 400。 + """ + # Keep thinking request minimal & robust: + # Strip out ALL tool-related messages to avoid Kiro schema strictness. + hist = list(base_history or []) + if not hist: + return [] + + def _has_tool_uses(item: dict) -> bool: + if not isinstance(item, dict) or "assistantResponseMessage" not in item: + return False + arm = item.get("assistantResponseMessage") or {} + tus = arm.get("toolUses") + return isinstance(tus, list) and len(tus) > 0 + + def _has_tool_results(item: dict) -> bool: + if not isinstance(item, dict) or "userInputMessage" not in item: + return False + uim = item.get("userInputMessage") or {} + ctx = uim.get("userInputMessageContext") or {} + trs = ctx.get("toolResults") + return isinstance(trs, list) and len(trs) > 0 + + # Filter out all messages with toolUses or toolResults + clean_hist = [] + for item in hist: + if _has_tool_uses(item) or _has_tool_results(item): + continue + clean_hist.append(item) + + # After filtering, we need to ensure alternation (user -> assistant -> user -> ...) + # Simple approach: just return what we have; fix_history_alternation should handle it + # But to be extra safe, let's do a quick check + if not clean_hist: + return [] + + # Ensure history ends with assistant (Kiro expects currentMessage to be user) + if clean_hist and isinstance(clean_hist[-1], dict) and "userInputMessage" in clean_hist[-1]: + clean_hist = clean_hist[:-1] + + return clean_hist + + # 思考功能:使用“独立请求”生成思维链(不向主请求注入提示词/标签) + if thinking_enabled: + msg_id = f"msg_{log_id}" + + # 标记开始流式传输(以 message_start 为首包) + if flow_id: + flow_monitor.start_streaming(flow_id) + + sent_any_event = True + yield ( + f'event: message_start\ndata: {{"type":"message_start","message":{{"id":"{msg_id}","type":"message","role":"assistant","content":[],"model":"{model}","stop_reason":null,"stop_sequence":null,"usage":{{"input_tokens":0,"output_tokens":0}}}}}}\n\n' + ) + + # ========== 独立思考请求(thinking block) ========== + thinking_index = _next_index() + yield ( + f'event: content_block_start\ndata: {{"type":"content_block_start","index":{thinking_index},"content_block":{{"type":"thinking","thinking":""}}}}\n\n' + ) + + thinking_accumulated = "" + thinking_retry = 0 + while thinking_retry <= max_retries: + try: + thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens) + thinking_history = _build_thinking_history(history, user_content, tool_results) + thinking_request = build_kiro_request( + thinking_prompt, model, thinking_history, None, images, None + ) + + async with httpx.AsyncClient(verify=False, timeout=300) as client: + async with client.stream( + "POST", KIRO_API_URL, json=thinking_request, headers=headers + ) as response: + print(f"[Anthropic][Stream:{log_id}] upstream_thinking_status={response.status_code}") + + if response.status_code != 200: + try: + err_bytes = await response.aread() + err_str = err_bytes.decode(errors="ignore") + except Exception: + err_str = "" + + try: + last_hist = (thinking_history or [])[-4:] + last_hist_shapes = [] + last_hist_ids = [] + for h in last_hist: + if isinstance(h, dict) and "assistantResponseMessage" in h: + arm = h.get("assistantResponseMessage") or {} + tus = arm.get("toolUses") or [] + last_hist_shapes.append({ + "role": "assistant", + "has_tool_uses": bool(tus), + "tool_uses": len(tus), + }) + try: + last_hist_ids.append({ + "role": "assistant", + "toolUseIds": [tu.get("toolUseId") for tu in tus if isinstance(tu, dict) and tu.get("toolUseId")], + }) + except Exception: + last_hist_ids.append({"role": "assistant"}) + elif isinstance(h, dict) and "userInputMessage" in h: + uim = h.get("userInputMessage") or {} + ctx = uim.get("userInputMessageContext") or {} + trs = ctx.get("toolResults") or [] + last_hist_shapes.append({ + "role": "user", + "has_tool_results": bool(trs), + "tool_results": len(trs), + }) + try: + last_hist_ids.append({ + "role": "user", + "toolResultIds": [tr.get("toolUseId") for tr in trs if isinstance(tr, dict) and tr.get("toolUseId")], + }) + except Exception: + last_hist_ids.append({"role": "user"}) + else: + last_hist_shapes.append({"role": str(type(h))}) + last_hist_ids.append({"role": str(type(h))}) + + print("=== Thinking Upstream Error ===") + print(f"Status: {response.status_code}") + print(f"Body (first 800): {err_str[:800]}") + print(f"Thinking request history_len: {len(thinking_history) if thinking_history else 0}") + print(f"Thinking request tool_results: {len(tool_results) if tool_results else 0}") + print(f"Last history shapes: {json.dumps(last_hist_shapes, ensure_ascii=False)}") + print(f"Last history ids: {json.dumps(last_hist_ids, ensure_ascii=False)}") + print("===============================") + except Exception: + pass + + # 仅 429 状态码触发冷却和账号切换 + if response.status_code == 429: + current_account.mark_quota_exceeded("Rate limited (thinking)") + next_account = state.get_next_available_account(current_account.id) + if next_account and thinking_retry < max_retries: + print(f"[Thinking] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + thinking_retry += 1 + continue + break + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(response.status_code): + if thinking_retry < max_retries: + print(f"[Thinking] 服务端错误 {response.status_code},重试 {thinking_retry + 1}/{max_retries}") + thinking_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** thinking_retry)) + continue + break + + if response.status_code != 200: + break + + chunk_buffer = b"" + async for chunk in response.aiter_bytes(): + chunk_buffer += chunk + while len(chunk_buffer) >= 12: + total_len = int.from_bytes(chunk_buffer[0:4], "big") + if len(chunk_buffer) < total_len: + break + + headers_len = int.from_bytes(chunk_buffer[4:8], "big") + payload_start = 12 + headers_len + payload_end = total_len - 4 + + if payload_start < payload_end: + try: + payload = json.loads( + chunk_buffer[payload_start:payload_end].decode("utf-8") + ) + content = None + if "assistantResponseEvent" in payload: + content = payload["assistantResponseEvent"].get("content") + elif "content" in payload: + content = payload.get("content") + + if content: + thinking_accumulated += content + full_content += content + if flow_id: + flow_monitor.add_chunk(flow_id, content) + sent_any_event = True + yield ( + f'event: content_block_delta\ndata: {json.dumps({"type":"content_block_delta","index":thinking_index,"delta":{"type":"thinking_delta","thinking":content}}, separators=(",", ":"), ensure_ascii=False)}\n\n' + ) + except Exception: + pass + + chunk_buffer = chunk_buffer[total_len:] + + # thinking 请求成功计入额度/频率 + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + break + + except httpx.TimeoutException: + if thinking_retry < max_retries: + print(f"[Thinking] 请求超时,重试 {thinking_retry + 1}/{max_retries}") + thinking_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** thinking_retry)) + continue + break + except httpx.ConnectError: + if thinking_retry < max_retries: + print(f"[Thinking] 连接错误,重试 {thinking_retry + 1}/{max_retries}") + thinking_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** thinking_retry)) + continue + break + except Exception as e: + if is_retryable_error(None, e) and thinking_retry < max_retries: + print(f"[Thinking] 网络错误,重试 {thinking_retry + 1}/{max_retries}: {type(e).__name__}") + thinking_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** thinking_retry)) + continue + break + + # 结束 thinking block(即使未获取到内容也会输出空块) + sent_any_event = True + yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{thinking_index}}}\n\n' + + main_user_content = build_user_prompt_with_thinking(user_content, thinking_accumulated) + clean_history = strip_thinking_from_history(history) + kiro_request = build_kiro_request(main_user_content, model, clean_history, kiro_tools, images, tool_results) + + # ========== 主响应流式处理(text block,独立请求) ========== + main_retry = 0 + while main_retry <= max_retries: + try: + async with httpx.AsyncClient(verify=False, timeout=300) as client: + async with client.stream( + "POST", KIRO_API_URL, json=kiro_request, headers=headers + ) as response: + print(f"[Anthropic][Stream:{log_id}] upstream_status={response.status_code}") + + # 仅 429 状态码触发冷却和账号切换 + if response.status_code == 429: + current_account.mark_quota_exceeded("Rate limited (stream)") + + next_account = state.get_next_available_account(current_account.id) + if next_account and main_retry < max_retries: + print(f"[Stream] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + token = current_account.get_token() + headers["Authorization"] = f"Bearer {token}" + main_retry += 1 + continue + + if flow_id: + flow_monitor.fail_flow(flow_id, "rate_limit_error", "All accounts rate limited", 429) + yield f'data: {{"type":"error","error":{{"type":"rate_limit_error","message":"All accounts rate limited"}}}}\n\n' + return + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(response.status_code): + if main_retry < max_retries: + print(f"[Stream] 服务端错误 {response.status_code},重试 {main_retry + 1}/{max_retries}") + main_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** main_retry)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", "Server error after retries", response.status_code) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Server error after retries"}}}}\n\n' + return + + if response.status_code != 200: + error_text = await response.aread() + error_str = error_text.decode() + print(f"=== Kiro API Error ===") + print(f"Status: {response.status_code}") + print(f"Response: {error_str[:500]}") + print(f"Request model: {model}") + print(f"History len: {len(history) if history else 0}") + print(f"Tool results: {len(tool_results) if tool_results else 0}") + if response.status_code == 400: + print(f"Kiro request keys: {list(kiro_request.keys())}") + if "conversationState" in kiro_request: + cs = kiro_request["conversationState"] + print(f" conversationState keys: {list(cs.keys())}") + if "currentMessage" in cs: + cm = cs["currentMessage"] + print(f" currentMessage keys: {list(cm.keys())}") + if "userInputMessage" in cm: + uim = cm["userInputMessage"] + print(f" userInputMessage keys: {list(uim.keys())}") + content = uim.get("content", "") + print(f" content (first 200 chars): {str(content)[:200]}") + if "history" in cs: + hist = cs["history"] + print(f" history count: {len(hist) if hist else 0}") + if hist: + for i, h in enumerate(hist[:3]): + print(f" history[{i}] keys: {list(h.keys()) if isinstance(h, dict) else type(h)}") + print(f"======================") + + http_status, error_type, error_msg, error_obj = _handle_kiro_error( + response.status_code, error_str, current_account + ) + + if error_obj.should_switch_account: + next_account = state.get_next_available_account(current_account.id) + if next_account and main_retry < max_retries: + print(f"[Stream] 切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + main_retry += 1 + continue + + if error_obj.type == ErrorType.CONTENT_TOO_LONG: + history_chars, user_chars, total_chars = history_manager.estimate_request_chars( + history, main_user_content + ) + print(f"[Stream] 内容长度超限: history={history_chars} chars, user={user_chars} chars, total={total_chars} chars") + async def api_caller(prompt: str) -> str: + return await _call_kiro_for_summary(prompt, current_account, headers) + truncated_history, should_retry = await history_manager.handle_length_error_async( + history, main_retry, api_caller + ) + if should_retry: + print(f"[Stream] 内容长度超限,{history_manager.truncate_info}") + history = truncated_history + clean_history = strip_thinking_from_history(history) + kiro_request = build_kiro_request(main_user_content, model, clean_history, kiro_tools, images, tool_results) + main_retry += 1 + continue + + if flow_id: + flow_monitor.fail_flow(flow_id, error_type, error_msg, response.status_code, error_str) + yield f'data: {{"type":"error","error":{{"type":"{error_type}","message":"{error_msg}"}}}}\n\n' + return + + # text block start + text_index = _next_index() + sent_any_event = True + yield ( + f'event: content_block_start\ndata: {{"type":"content_block_start","index":{text_index},"content_block":{{"type":"text","text":""}}}}\n\n' + ) + + full_response = b"" + chunk_buffer = b"" + async for chunk in response.aiter_bytes(): + if not saw_any_chunk: + saw_any_chunk = True + print(f"[Anthropic][Stream:{log_id}] first_chunk bytes={len(chunk)}") + full_response += chunk + chunk_buffer += chunk + + try: + while len(chunk_buffer) >= 12: + total_len = int.from_bytes(chunk_buffer[0:4], "big") + + if len(chunk_buffer) < total_len: + break + + headers_len = int.from_bytes(chunk_buffer[4:8], "big") + payload_start = 12 + headers_len + payload_end = total_len - 4 + + if payload_start < payload_end: + try: + payload_data = chunk_buffer[payload_start:payload_end] + payload = json.loads(payload_data.decode("utf-8")) + content = None + if "assistantResponseEvent" in payload: + content = payload["assistantResponseEvent"].get("content") + elif "content" in payload: + content = payload.get("content") + if content: + full_content += content + saw_any_text = True + if flow_id: + flow_monitor.add_chunk(flow_id, content) + + sent_any_event = True + yield ( + f'event: content_block_delta\ndata: {json.dumps({"type":"content_block_delta","index":text_index,"delta":{"type":"text_delta","text":content}}, separators=(",", ":"), ensure_ascii=False)}\n\n' + ) + except Exception as e: + print(f"[Stream] Payload parse error: {e}") + pass + + chunk_buffer = chunk_buffer[total_len:] + except Exception as e: + print(f"[Stream] Chunk processing error: {e}") + pass + + # text block stop + sent_any_event = True + yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{text_index}}}\n\n' + + result = parse_event_stream_full(full_response) + + if result["tool_uses"]: + tool_start_index = content_block_index_ref[0] + for i, tool_use in enumerate(result["tool_uses"]): + idx = tool_start_index + i + yield f'event: content_block_start\ndata: {{"type":"content_block_start","index":{idx},"content_block":{{"type":"tool_use","id":"{tool_use["id"]}","name":"{tool_use["name"]}","input":{{}}}}}}\n\n' + partial_json = json.dumps(tool_use.get("input") or {}, ensure_ascii=False) + yield f'event: content_block_delta\ndata: {{"type":"content_block_delta","index":{idx},"delta":{{"type":"input_json_delta","partial_json":{json.dumps(partial_json, ensure_ascii=False)}}}}}\n\n' + yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{idx}}}\n\n' + + stop_reason = result["stop_reason"] + input_tokens = result.get("input_tokens", 0) + output_tokens = result.get("output_tokens", 0) + if not output_tokens and full_content: + output_tokens = _estimate_output_tokens_from_text(full_content) + yield f'event: message_delta\ndata: {{"type":"message_delta","delta":{{"stop_reason":"{stop_reason}","stop_sequence":null}},"usage":{{"input_tokens":{input_tokens},"output_tokens":{output_tokens}}}}}\n\n' + yield f'event: message_stop\ndata: {{"type":"message_stop"}}\n\n' + yield "data: [DONE]\n\n" + print( + f"[Anthropic][Stream:{log_id}] done chunks={saw_any_chunk} text={saw_any_text} sent_events={sent_any_event} " + f"input_tokens={input_tokens} output_tokens={output_tokens} stop_reason={stop_reason}" + ) + + if flow_id: + flow_monitor.complete_flow( + flow_id, + status_code=200, + content=full_content, + tool_calls=result.get("tool_uses", []), + stop_reason=stop_reason, + usage=TokenUsage( + input_tokens=result.get("input_tokens", 0), + output_tokens=result.get("output_tokens", 0), + ), + ) + + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + + duration = (time.time() - start_time) * 1000 + state.add_log(RequestLog( + id=log_id, + timestamp=time.time(), + method="POST", + path="/v1/messages", + model=model, + account_id=current_account.id if current_account else None, + status=200, + duration_ms=duration, + error=None + )) + return + + except httpx.TimeoutException: + if main_retry < max_retries: + print(f"[Stream] 请求超时,重试 {main_retry + 1}/{max_retries}") + main_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** main_retry)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "timeout_error", "Request timeout after retries", 408) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Request timeout after retries"}}}}\n\n' + return + except httpx.ConnectError: + if main_retry < max_retries: + print(f"[Stream] 连接错误,重试 {main_retry + 1}/{max_retries}") + main_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** main_retry)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "connection_error", "Connection error after retries", 502) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Connection error after retries"}}}}\n\n' + return + except Exception as e: + if is_retryable_error(None, e) and main_retry < max_retries: + print(f"[Stream] 网络错误,重试 {main_retry + 1}/{max_retries}: {type(e).__name__}") + main_retry += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** main_retry)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", str(e), 500) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"{str(e)}"}}}}\n\n' + return + + return + + thinking_processor = ThinkingStreamProcessor(thinking_enabled, index_ref=content_block_index_ref) + + while retry_count <= max_retries: + try: + async with httpx.AsyncClient(verify=False, timeout=300) as client: + async with client.stream("POST", KIRO_API_URL, json=kiro_request, headers=headers) as response: + print(f"[Anthropic][Stream:{log_id}] upstream_status={response.status_code}") + + # 仅 429 状态码触发冷却和账号切换 + if response.status_code == 429: + current_account.mark_quota_exceeded("Rate limited (stream)") + + # 尝试切换账号 + next_account = state.get_next_available_account(current_account.id) + if next_account and retry_count < max_retries: + print(f"[Stream] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + token = current_account.get_token() + headers["Authorization"] = f"Bearer {token}" + retry_count += 1 + continue + + if flow_id: + flow_monitor.fail_flow(flow_id, "rate_limit_error", "All accounts rate limited", 429) + yield f'data: {{"type":"error","error":{{"type":"rate_limit_error","message":"All accounts rate limited"}}}}\n\n' + return + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(response.status_code): + if retry_count < max_retries: + print(f"[Stream] 服务端错误 {response.status_code},重试 {retry_count + 1}/{max_retries}") + retry_count += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** retry_count)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", "Server error after retries", response.status_code) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Server error after retries"}}}}\n\n' + return + + if response.status_code != 200: + error_text = await response.aread() + error_str = error_text.decode() + print(f"=== Kiro API Error ===") + print(f"Status: {response.status_code}") + print(f"Response: {error_str[:500]}") + print(f"Request model: {model}") + print(f"History len: {len(history) if history else 0}") + print(f"Tool results: {len(tool_results) if tool_results else 0}") + # 对于 400 错误,打印更多请求细节 + if response.status_code == 400: + print(f"Kiro request keys: {list(kiro_request.keys())}") + if 'conversationState' in kiro_request: + cs = kiro_request['conversationState'] + print(f" conversationState keys: {list(cs.keys())}") + if 'currentMessage' in cs: + cm = cs['currentMessage'] + print(f" currentMessage keys: {list(cm.keys())}") + if 'userInputMessage' in cm: + uim = cm['userInputMessage'] + print(f" userInputMessage keys: {list(uim.keys())}") + content = uim.get('content', '') + print(f" content (first 200 chars): {str(content)[:200]}") + if 'history' in cs: + hist = cs['history'] + print(f" history count: {len(hist) if hist else 0}") + if hist: + for i, h in enumerate(hist[:3]): + print(f" history[{i}] keys: {list(h.keys()) if isinstance(h, dict) else type(h)}") + print(f"======================") + + # 使用统一的错误处理 + http_status, error_type, error_msg, error_obj = _handle_kiro_error( + response.status_code, error_str, current_account + ) + + # 账号封禁 - 尝试切换账号 + if error_obj.should_switch_account: + next_account = state.get_next_available_account(current_account.id) + if next_account and retry_count < max_retries: + print(f"[Stream] 切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + retry_count += 1 + continue + + # 检查是否为内容长度超限错误,尝试截断重试 + if error_obj.type == ErrorType.CONTENT_TOO_LONG: + history_chars, user_chars, total_chars = history_manager.estimate_request_chars( + history, user_content + ) + print(f"[Stream] 内容长度超限: history={history_chars} chars, user={user_chars} chars, total={total_chars} chars") + async def api_caller(prompt: str) -> str: + return await _call_kiro_for_summary(prompt, current_account, headers) + truncated_history, should_retry = await history_manager.handle_length_error_async( + history, retry_count, api_caller + ) + if should_retry: + print(f"[Stream] 内容长度超限,{history_manager.truncate_info}") + history = truncated_history + # 重新构建请求 + clean_history = strip_thinking_from_history(history) + kiro_request = build_kiro_request(user_content, model, clean_history, kiro_tools, images, tool_results) + retry_count += 1 + continue + + if flow_id: + flow_monitor.fail_flow(flow_id, error_type, error_msg, response.status_code, error_str) + yield f'data: {{"type":"error","error":{{"type":"{error_type}","message":"{error_msg}"}}}}\n\n' + return + + # 标记开始流式传输 + if flow_id: + flow_monitor.start_streaming(flow_id) + + # 正常处理响应 + msg_id = f"msg_{log_id}" + sent_any_event = True + yield f'event: message_start\ndata: {{"type":"message_start","message":{{"id":"{msg_id}","type":"message","role":"assistant","content":[],"model":"{model}","stop_reason":null,"stop_sequence":null,"usage":{{"input_tokens":0,"output_tokens":0}}}}}}\n\n' + + # ========== 主响应流式处理(单次调用,按 标签拆分) ========== + full_response = b"" + text_block_started = False + chunk_buffer = b"" + async for chunk in response.aiter_bytes(): + if not saw_any_chunk: + saw_any_chunk = True + print(f"[Anthropic][Stream:{log_id}] first_chunk bytes={len(chunk)}") + full_response += chunk + chunk_buffer += chunk + + try: + while len(chunk_buffer) >= 12: + total_len = int.from_bytes(chunk_buffer[0:4], 'big') + + # 如果缓冲区不足以容纳整个消息,等待更多数据 + if len(chunk_buffer) < total_len: + break + + headers_len = int.from_bytes(chunk_buffer[4:8], 'big') + payload_start = 12 + headers_len + payload_end = total_len - 4 + + if payload_start < payload_end: + try: + payload_data = chunk_buffer[payload_start:payload_end] + payload = json.loads(payload_data.decode('utf-8')) + content = None + if 'assistantResponseEvent' in payload: + content = payload['assistantResponseEvent'].get('content') + elif 'content' in payload: + content = payload['content'] + if content: + full_content += content + saw_any_text = True + if flow_id: + flow_monitor.add_chunk(flow_id, content) + + events = thinking_processor.process_content(content) + for event in events: + if event["type"] == "content_block_start" and event.get("content_block", {}).get("type") == "text": + text_block_started = True + if event["type"] in ["content_block_start", "content_block_delta", "content_block_stop"]: + sent_any_event = True + yield f'event: {event["type"]}\ndata: {json.dumps(event, separators=(",", ":"), ensure_ascii=False)}\n\n' + except Exception as e: + print(f"[Stream] Payload parse error: {e}") + pass + + # 移动缓冲区 + chunk_buffer = chunk_buffer[total_len:] + except Exception as e: + print(f"[Stream] Chunk processing error: {e}") + pass + + # 完成思考处理 + final_events = thinking_processor.finalize() + for event in final_events: + sent_any_event = True + yield f'event: {event["type"]}\ndata: {json.dumps(event, separators=(",", ":"), ensure_ascii=False)}\n\n' + + # 确保文本块已开始 + if not text_block_started: + idx = thinking_processor._next_index() + yield f'event: content_block_start\ndata: {{"type":"content_block_start","index":{idx},"content_block":{{"type":"text","text":""}}}}\n\n' + yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{idx}}}\n\n' + + result = parse_event_stream_full(full_response) + + if result["tool_uses"]: + tool_start_index = content_block_index_ref[0] + for i, tool_use in enumerate(result["tool_uses"]): + idx = tool_start_index + i + yield f'event: content_block_start\ndata: {{"type":"content_block_start","index":{idx},"content_block":{{"type":"tool_use","id":"{tool_use["id"]}","name":"{tool_use["name"]}","input":{{}}}}}}\n\n' + partial_json = json.dumps(tool_use.get("input") or {}, ensure_ascii=False) + yield f'event: content_block_delta\ndata: {{"type":"content_block_delta","index":{idx},"delta":{{"type":"input_json_delta","partial_json":{json.dumps(partial_json, ensure_ascii=False)}}}}}\n\n' + yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{idx}}}\n\n' + + stop_reason = result["stop_reason"] + input_tokens = result.get("input_tokens", 0) + output_tokens = result.get("output_tokens", 0) + if not output_tokens and full_content: + output_tokens = _estimate_output_tokens_from_text(full_content) + yield f'event: message_delta\ndata: {{"type":"message_delta","delta":{{"stop_reason":"{stop_reason}","stop_sequence":null}},"usage":{{"input_tokens":{input_tokens},"output_tokens":{output_tokens}}}}}\n\n' + yield f'event: message_stop\ndata: {{"type":"message_stop"}}\n\n' + yield 'data: [DONE]\n\n' + print( + f"[Anthropic][Stream:{log_id}] done chunks={saw_any_chunk} text={saw_any_text} sent_events={sent_any_event} " + f"input_tokens={input_tokens} output_tokens={output_tokens} stop_reason={stop_reason}" + ) + + # 完成 Flow + if flow_id: + flow_monitor.complete_flow( + flow_id, + status_code=200, + content=full_content, + tool_calls=result.get("tool_uses", []), + stop_reason=stop_reason, + usage=TokenUsage( + input_tokens=result.get("input_tokens", 0), + output_tokens=result.get("output_tokens", 0), + ), + ) + + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + + # 记录日志 + duration = (time.time() - start_time) * 1000 + state.add_log(RequestLog( + id=log_id, + timestamp=time.time(), + method="POST", + path="/v1/messages", + model=model, + account_id=current_account.id if current_account else None, + status=200, + duration_ms=duration, + error=None + )) + return + + except httpx.TimeoutException: + if retry_count < max_retries: + print(f"[Stream] 请求超时,重试 {retry_count + 1}/{max_retries}") + retry_count += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** retry_count)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "timeout_error", "Request timeout after retries", 408) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Request timeout after retries"}}}}\n\n' + return + except httpx.ConnectError: + if retry_count < max_retries: + print(f"[Stream] 连接错误,重试 {retry_count + 1}/{max_retries}") + retry_count += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** retry_count)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "connection_error", "Connection error after retries", 502) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Connection error after retries"}}}}\n\n' + return + except Exception as e: + # 检查是否为可重试的网络错误 + if is_retryable_error(None, e) and retry_count < max_retries: + print(f"[Stream] 网络错误,重试 {retry_count + 1}/{max_retries}: {type(e).__name__}") + retry_count += 1 + import asyncio + await asyncio.sleep(0.5 * (2 ** retry_count)) + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", str(e), 500) + yield f'data: {{"type":"error","error":{{"type":"api_error","message":"{str(e)}"}}}}\n\n' + return + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +async def _handle_non_stream(kiro_request, headers, account, model, log_id, start_time, session_id=None, flow_id=None, history=None, user_content="", kiro_tools=None, images=None, tool_results=None, history_manager=None, thinking_enabled: bool = False, budget_tokens: int | None = None): + """Handle non-streaming responses with auto-retry on quota exceeded and network errors.""" + error_msg = None + status_code = 200 + current_account = account + max_retries = 2 + retry_ctx = RetryableRequest(max_retries=2) + + thinking_content = "" + clean_history = strip_thinking_from_history(history) + if thinking_enabled: + thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens) + thinking_request = build_kiro_request(thinking_prompt, model, clean_history, None, images, tool_results) + + for retry in range(max_retries + 1): + try: + async with httpx.AsyncClient(verify=False, timeout=300) as client: + response = await client.post(KIRO_API_URL, json=thinking_request, headers=headers) + + # 仅 429 状态码触发冷却和账号切换 + if response.status_code == 429: + current_account.mark_quota_exceeded("Rate limited (thinking non-stream)") + + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[NonStream][Thinking] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + continue + break + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(response.status_code): + if retry < max_retries: + print(f"[NonStream][Thinking] 服务端错误 {response.status_code},重试 {retry + 1}/{max_retries}") + await retry_ctx.wait() + continue + break + + if response.status_code != 200: + break + + thinking_content = parse_event_stream(response.content) + + # thinking 请求成功计入额度/频率 + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + break + except Exception: + if retry < max_retries: + await retry_ctx.wait() + continue + break + + main_user_content = build_user_prompt_with_thinking(user_content, thinking_content) + kiro_request = build_kiro_request(main_user_content, model, clean_history, kiro_tools, images, tool_results) + + for retry in range(max_retries + 1): + try: + async with httpx.AsyncClient(verify=False, timeout=300) as client: + response = await client.post(KIRO_API_URL, json=kiro_request, headers=headers) + status_code = response.status_code + + # 仅 429 状态码触发冷却和账号切换 + if response.status_code == 429: + current_account.mark_quota_exceeded("Rate limited") + + # 尝试切换账号 + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[NonStream] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + token = current_account.get_token() + creds = current_account.get_credentials() + headers["Authorization"] = f"Bearer {token}" + continue + + if flow_id: + flow_monitor.fail_flow(flow_id, "rate_limit_error", "All accounts rate limited", 429) + raise HTTPException(429, "All accounts rate limited") + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(response.status_code): + if retry < max_retries: + print(f"[NonStream] 服务端错误 {response.status_code},重试 {retry + 1}/{max_retries}") + await retry_ctx.wait() + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", f"Server error after {max_retries} retries", response.status_code) + raise HTTPException(response.status_code, f"Server error after {max_retries} retries") + + if response.status_code != 200: + error_msg = response.text + print(f"[NonStream] Kiro API Error {response.status_code}: {error_msg[:500]}") + + # 使用统一的错误处理 + status, error_type, error_message, error_obj = _handle_kiro_error( + response.status_code, error_msg, current_account + ) + + # 账号封禁或配额超限 - 尝试切换账号 + if error_obj.should_switch_account: + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[NonStream] 切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + continue + + # 检查是否为内容长度超限错误,尝试截断重试 + if error_obj.type == ErrorType.CONTENT_TOO_LONG and history_manager: + history_chars, user_chars, total_chars = history_manager.estimate_request_chars( + history, main_user_content + ) + print(f"[NonStream] 内容长度超限: history={history_chars} chars, user={user_chars} chars, total={total_chars} chars") + async def api_caller(prompt: str) -> str: + return await _call_kiro_for_summary(prompt, current_account, headers) + truncated_history, should_retry = await history_manager.handle_length_error_async( + history, retry, api_caller + ) + if should_retry: + print(f"[NonStream] 内容长度超限,{history_manager.truncate_info}") + history = truncated_history + kiro_request = build_kiro_request(main_user_content, model, history, kiro_tools, images, tool_results) + continue + else: + print(f"[NonStream] 内容长度超限但未重试: retry={retry}/{max_retries}") + + if flow_id: + flow_monitor.fail_flow(flow_id, error_type, error_message, status, error_msg) + raise HTTPException(status, error_message) + + result = parse_event_stream_full(response.content) + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + + # 完成 Flow + if flow_id: + full_text = "".join(result.get("content", [])) + flow_monitor.complete_flow( + flow_id, + status_code=200, + content=full_text, + tool_calls=result.get("tool_uses", []), + stop_reason=result.get("stop_reason", ""), + usage=TokenUsage( + input_tokens=result.get("input_tokens", 0), + output_tokens=result.get("output_tokens", 0), + ), + ) + + resp = convert_kiro_response_to_anthropic(result, model, f"msg_{log_id}") + if thinking_enabled: + resp["content"].insert(0, {"type": "thinking", "thinking": thinking_content or ""}) + return resp + + except HTTPException: + raise + except httpx.TimeoutException as e: + error_msg = f"Request timeout: {e}" + status_code = 408 + if retry < max_retries: + print(f"[NonStream] 请求超时,重试 {retry + 1}/{max_retries}") + await retry_ctx.wait() + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "timeout_error", "Request timeout after retries", 408) + raise HTTPException(408, "Request timeout after retries") + except httpx.ConnectError as e: + error_msg = f"Connection error: {e}" + status_code = 502 + if retry < max_retries: + print(f"[NonStream] 连接错误,重试 {retry + 1}/{max_retries}") + await retry_ctx.wait() + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "connection_error", "Connection error after retries", 502) + raise HTTPException(502, "Connection error after retries") + except Exception as e: + error_msg = str(e) + status_code = 500 + # 检查是否为可重试的网络错误 + if is_retryable_error(None, e) and retry < max_retries: + print(f"[NonStream] 网络错误,重试 {retry + 1}/{max_retries}: {type(e).__name__}") + await retry_ctx.wait() + continue + if flow_id: + flow_monitor.fail_flow(flow_id, "api_error", str(e), 500) + raise HTTPException(500, str(e)) + finally: + if retry == max_retries or status_code == 200: + duration = (time.time() - start_time) * 1000 + state.add_log(RequestLog( + id=log_id, + timestamp=time.time(), + method="POST", + path="/v1/messages", + model=model, + account_id=current_account.id if current_account else None, + status=status_code, + duration_ms=duration, + error=error_msg + )) + # 记录统计 + stats_manager.record_request( + account_id=current_account.id if current_account else "unknown", + model=model, + success=status_code == 200, + latency_ms=duration + ) + + raise HTTPException(503, "All retries exhausted") + + +class ThinkingStreamProcessor: + """思考内容流式处理器""" + + _THINKING_START_TAG = "" + _THINKING_END_TAG = "" + + def __init__(self, thinking_enabled: bool = False, index_ref: list | None = None): + self.thinking_enabled = thinking_enabled + self.thinking_buffer = "" + self.in_thinking_block = False + self.thinking_extracted = False + self.text_buffer = "" + self._index_ref = index_ref if index_ref is not None else [0] + self._text_index = None + self._thinking_index = None + + def _next_index(self) -> int: + self._index_ref[0] += 1 + return self._index_ref[0] - 1 + + @staticmethod + def _split_incomplete_tag_tail(buffer: str, tag: str) -> tuple[str, str]: + """Split buffer into (flush, keep) parts to handle tags split across chunks. + + Keeps the longest suffix of `buffer` that could be a prefix of `tag` so the + next chunk can complete the tag. + """ + if not buffer: + return "", "" + max_suffix_len = min(len(tag) - 1, len(buffer)) + for suffix_len in range(max_suffix_len, 0, -1): + if buffer.endswith(tag[:suffix_len]): + return buffer[:-suffix_len], buffer[-suffix_len:] + return buffer, "" + + def process_content(self, content: str) -> list: + """处理新到达的内容块,返回生成的事件列表""" + events = [] + + if not self.thinking_enabled: + # 如果未启用思考模式,直接返回文本内容 + if self._text_index is None: + self._text_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._text_index, + "content_block": {"type": "text", "text": ""} + }) + events.append({ + "type": "content_block_delta", + "index": self._text_index, + "delta": {"type": "text_delta", "text": content} + }) + return events + + # 将内容添加到缓冲区 + self.text_buffer += content + + # 查找思考标签 + while self.text_buffer: + if not self.in_thinking_block: + # 查找思考开始标签 + start_idx = find_real_thinking_start_tag(self.text_buffer) + if start_idx == -1: + # 没有找到思考标签:只输出安全部分,保留可能被拆分的标签前缀 + flush_text, keep = self._split_incomplete_tag_tail( + self.text_buffer, self._THINKING_START_TAG + ) + if flush_text: + if self._text_index is None: + self._text_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._text_index, + "content_block": {"type": "text", "text": ""} + }) + events.append({ + "type": "content_block_delta", + "index": self._text_index, + "delta": {"type": "text_delta", "text": flush_text} + }) + self.text_buffer = keep + break + + # 输出思考标签之前的文本 + if start_idx > 0: + text_before = self.text_buffer[:start_idx] + if self._text_index is None: + self._text_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._text_index, + "content_block": {"type": "text", "text": ""} + }) + events.append({ + "type": "content_block_delta", + "index": self._text_index, + "delta": {"type": "text_delta", "text": text_before} + }) + + # 开始思考块 + self.in_thinking_block = True + self._thinking_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._thinking_index, + "content_block": {"type": "thinking", "thinking": ""} + }) + + # 移除已处理的内容 + self.text_buffer = self.text_buffer[start_idx + len(self._THINKING_START_TAG):] + + else: + # 在思考块内,查找结束标签 + end_idx = find_real_thinking_end_tag(self.text_buffer) + if end_idx == -1: + # 没有找到结束标签:只输出安全部分,保留可能被拆分的结束标签前缀 + flush_thinking, keep = self._split_incomplete_tag_tail( + self.text_buffer, self._THINKING_END_TAG + ) + if flush_thinking: + self.thinking_buffer += flush_thinking + events.append({ + "type": "content_block_delta", + "index": self._thinking_index, + "delta": {"type": "thinking_delta", "thinking": flush_thinking} + }) + self.text_buffer = keep + break + + # 找到结束标签,输出思考内容 + thinking_content = self.text_buffer[:end_idx] + self.thinking_buffer += thinking_content + if thinking_content: + events.append({ + "type": "content_block_delta", + "index": self._thinking_index, + "delta": {"type": "thinking_delta", "thinking": thinking_content} + }) + + # 结束思考块 + events.append({ + "type": "content_block_stop", + "index": self._thinking_index + }) + + self.in_thinking_block = False + self.thinking_extracted = True + + # 移除已处理的内容 + self.text_buffer = self.text_buffer[end_idx + len(self._THINKING_END_TAG):] + + return events + + def finalize(self) -> list: + """完成处理,返回结束事件""" + events = [] + + # 刷出残留缓冲(可能包含被拆分的标签片段或尾部文本) + if self.text_buffer: + if self.in_thinking_block: + if self._thinking_index is None: + self._thinking_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._thinking_index, + "content_block": {"type": "thinking", "thinking": ""} + }) + self.thinking_buffer += self.text_buffer + events.append({ + "type": "content_block_delta", + "index": self._thinking_index, + "delta": {"type": "thinking_delta", "thinking": self.text_buffer} + }) + else: + if self._text_index is None: + self._text_index = self._next_index() + events.append({ + "type": "content_block_start", + "index": self._text_index, + "content_block": {"type": "text", "text": ""} + }) + events.append({ + "type": "content_block_delta", + "index": self._text_index, + "delta": {"type": "text_delta", "text": self.text_buffer} + }) + self.text_buffer = "" + + if self.in_thinking_block and self._thinking_index is not None: + # 如果还在思考块内,强制结束 + events.append({ + "type": "content_block_stop", + "index": self._thinking_index + }) + self.in_thinking_block = False + if self._text_index is not None: + events.append({ + "type": "content_block_stop", + "index": self._text_index + }) + + return events diff --git a/KiroProxy/kiro_proxy/handlers/gemini.py b/KiroProxy/kiro_proxy/handlers/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..56dc7f641ab6e9f5eac7fb36372968cfb7e75c6c --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/gemini.py @@ -0,0 +1,306 @@ +"""Gemini 协议处理 - /v1/models/{model}:generateContent""" +import json +import uuid +import time +import hashlib +import asyncio +import httpx +from fastapi import Request, HTTPException + +from ..config import KIRO_API_URL, map_model_name +from ..core import state, is_retryable_error +from ..core.state import RequestLog +from ..core.history_manager import HistoryManager, get_history_config, is_content_length_error +from ..core.error_handler import classify_error, ErrorType, format_error_log +from ..core.rate_limiter import get_rate_limiter +from ..kiro_api import build_headers, build_kiro_request, parse_event_stream, parse_event_stream_full, is_quota_exceeded_error +from ..core.thinking import ( + ThinkingConfig, + build_user_prompt_with_thinking, + extract_thinking_config_from_gemini_body, + fetch_thinking_text, + format_thinking_block, + infer_thinking_from_gemini_contents, + strip_thinking_from_history, +) +from ..converters import convert_gemini_contents_to_kiro, convert_kiro_response_to_gemini, convert_gemini_tools_to_kiro + + +async def handle_generate_content(model_name: str, request: Request): + """处理 Gemini generateContent 请求""" + start_time = time.time() + log_id = uuid.uuid4().hex[:8] + + body = await request.json() + contents = body.get("contents", []) + system_instruction = body.get("systemInstruction", {}) + tools = body.get("tools", []) + tool_config = body.get("toolConfig", {}) + thinking_cfg, thinking_explicit = extract_thinking_config_from_gemini_body(body) + if not thinking_explicit and infer_thinking_from_gemini_contents(contents): + thinking_cfg = ThinkingConfig(True, None) + + model_raw = model_name.replace("models/", "") + model = map_model_name(model_raw) + + session_id = hashlib.sha256(json.dumps(contents[:3], sort_keys=True).encode()).hexdigest()[:16] + account = state.get_available_account(session_id) + + if not account: + raise HTTPException(503, "All accounts are rate limited") + + # 检查 token 是否即将过期 + if account.is_token_expiring_soon(5): + print(f"[Gemini] Token 即将过期,尝试刷新: {account.id}") + success, msg = await account.refresh_token() + if not success: + print(f"[Gemini] Token 刷新失败: {msg}") + + token = account.get_token() + if not token: + raise HTTPException(500, f"Failed to get token for account {account.name}") + + # 构建 headers(提前构建,供摘要使用) + creds = account.get_credentials() + headers = build_headers( + token, + machine_id=account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + + # 限速检查 + rate_limiter = get_rate_limiter() + can_request, wait_seconds, reason = rate_limiter.can_request(account.id) + if not can_request: + print(f"[Gemini] 限速: {reason}") + await asyncio.sleep(wait_seconds) + + # 转换消息格式 + user_content, history, tool_results, kiro_tools = convert_gemini_contents_to_kiro( + contents, system_instruction, model, tools, tool_config + ) + + # 历史消息预处理 + history_manager = HistoryManager(get_history_config(), cache_key=session_id) + + async def call_summary(prompt: str) -> str: + req = build_kiro_request(prompt, "claude-haiku-4.5", []) + try: + async with httpx.AsyncClient(verify=False, timeout=60) as client: + resp = await client.post(KIRO_API_URL, json=req, headers=headers) + if resp.status_code == 200: + return parse_event_stream(resp.content) + except Exception as e: + print(f"[Summary] API 调用失败: {e}") + return "" + + # 检查是否需要智能摘要或错误重试预摘要 + if history_manager.should_summarize(history) or history_manager.should_pre_summary_for_error_retry(history, user_content): + history = await history_manager.pre_process_async(history, user_content, call_summary) + else: + history = history_manager.pre_process(history, user_content) + + # 摘要/截断后再次修复历史交替和 toolUses/toolResults 配对 + from ..converters import fix_history_alternation + history = fix_history_alternation(history) + + if history_manager.was_truncated: + print(f"[Gemini] {history_manager.truncate_info}") + + async def call_summary(prompt: str) -> str: + req = build_kiro_request(prompt, "claude-haiku-4.5", []) + try: + async with httpx.AsyncClient(verify=False, timeout=60) as client: + resp = await client.post(KIRO_API_URL, json=req, headers=headers) + if resp.status_code == 200: + return parse_event_stream(resp.content) + except Exception as e: + print(f"[Summary] API 调用失败: {e}") + return "" + + # 构建 Kiro 请求 + thinking_content = "" + main_user_content = user_content + if thinking_cfg.enabled: + thinking_content = await fetch_thinking_text( + headers=headers, + model=model, + user_content=user_content, + history=history, + tool_results=tool_results if tool_results else None, + budget_tokens=thinking_cfg.budget_tokens, + ) + if thinking_content: + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + main_user_content = build_user_prompt_with_thinking(user_content, thinking_content) + + clean_history = strip_thinking_from_history(history) + kiro_request = build_kiro_request( + main_user_content, + model, + clean_history, + tools=kiro_tools if kiro_tools else None, + tool_results=tool_results if tool_results else None, + ) + + error_msg = None + status_code = 200 + content = "" + current_account = account + max_retries = 2 + + for retry in range(max_retries + 1): + try: + async with httpx.AsyncClient(verify=False, timeout=120) as client: + resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers) + status_code = resp.status_code + + # 仅 429 状态码触发冷却和账号切换 + if resp.status_code == 429: + current_account.mark_quota_exceeded("Rate limited") + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[Gemini] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + token = current_account.get_token() + creds = current_account.get_credentials() + headers = build_headers( + token, + machine_id=current_account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + continue + raise HTTPException(429, "All accounts rate limited") + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(resp.status_code): + if retry < max_retries: + print(f"[Gemini] 服务端错误 {resp.status_code},重试 {retry + 1}/{max_retries}") + import asyncio + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(resp.status_code, f"Server error after {max_retries} retries") + + if resp.status_code != 200: + error_msg = resp.text + + # 使用统一的错误处理 + error = classify_error(resp.status_code, error_msg) + print(format_error_log(error, current_account.id)) + + # 账号封禁 - 禁用账号 + if error.should_disable_account: + current_account.enabled = False + if hasattr(current_account, "auto_disabled"): + current_account.auto_disabled = False + from ..credential import CredentialStatus + current_account.status = CredentialStatus.SUSPENDED + try: + state._save_accounts() + except Exception: + pass + print(f"[Gemini] 账号 {current_account.id} 已被禁用 (封禁)") + + # 仅 429 状态码触发冷却 + elif resp.status_code == 429: + current_account.mark_quota_exceeded(error_msg[:100]) + + # 其他错误(非 429、非内容过长)- 异步检查额度 + elif error.type != ErrorType.CONTENT_TOO_LONG: + from .anthropic import _check_and_disable_if_exhausted + import asyncio + asyncio.create_task(_check_and_disable_if_exhausted(current_account)) + + # 尝试切换账号(仅账号封禁时) + if error.should_disable_account: + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[Gemini] 切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + continue + + # 检查是否为内容长度超限错误 + if error.type == ErrorType.CONTENT_TOO_LONG: + history_chars, user_chars, total_chars = history_manager.estimate_request_chars( + history, main_user_content + ) + print(f"[Gemini] 内容长度超限: history={history_chars} chars, user={user_chars} chars, total={total_chars} chars") + truncated_history, should_retry = await history_manager.handle_length_error_async( + history, retry, call_summary + ) + if should_retry: + print(f"[Gemini] 内容长度超限,{history_manager.truncate_info}") + history = truncated_history + kiro_request = build_kiro_request( + main_user_content, model, history, + tools=kiro_tools if kiro_tools else None, + tool_results=tool_results if tool_results else None + ) + continue + else: + print(f"[Gemini] 内容长度超限但未重试: retry={retry}/{max_retries}") + + raise HTTPException(resp.status_code, error.user_message) + + # 使用完整解析以支持工具调用 + result = parse_event_stream_full(resp.content) + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + break + + except HTTPException: + raise + except httpx.TimeoutException: + error_msg = "Request timeout" + status_code = 408 + if retry < max_retries: + print(f"[Gemini] 请求超时,重试 {retry + 1}/{max_retries}") + import asyncio + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(408, "Request timeout after retries") + except httpx.ConnectError: + error_msg = "Connection error" + status_code = 502 + if retry < max_retries: + print(f"[Gemini] 连接错误,重试 {retry + 1}/{max_retries}") + import asyncio + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(502, "Connection error after retries") + except Exception as e: + error_msg = str(e) + status_code = 500 + if is_retryable_error(None, e) and retry < max_retries: + print(f"[Gemini] 网络错误,重试 {retry + 1}/{max_retries}: {type(e).__name__}") + import asyncio + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(500, str(e)) + + # 记录日志 + duration = (time.time() - start_time) * 1000 + state.add_log(RequestLog( + id=log_id, + timestamp=time.time(), + method="POST", + path=f"/v1/models/{model_name}:generateContent", + model=model, + account_id=current_account.id if current_account else None, + status=status_code, + duration_ms=duration, + error=error_msg + )) + + # 使用转换函数生成 Gemini 格式响应 + if thinking_cfg.enabled and thinking_content: + prefix = f"{format_thinking_block(thinking_content)}\n\n" + result = dict(result) + result["content"] = [prefix] + list(result.get("content", []) or []) + return convert_kiro_response_to_gemini(result, model) diff --git a/KiroProxy/kiro_proxy/handlers/openai.py b/KiroProxy/kiro_proxy/handlers/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..543a5f77b023f400047aa3fd74adcbed5bcf4c4e --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/openai.py @@ -0,0 +1,412 @@ +"""OpenAI 协议处理 - /v1/chat/completions""" +import json +import uuid +import time +import asyncio +import httpx +from datetime import datetime +from fastapi import Request, HTTPException +from fastapi.responses import StreamingResponse + +from ..config import KIRO_API_URL, map_model_name +from ..core import state, is_retryable_error, stats_manager +from ..core.state import RequestLog +from ..core.history_manager import HistoryManager, get_history_config, is_content_length_error +from ..core.error_handler import classify_error, ErrorType, format_error_log +from ..core.rate_limiter import get_rate_limiter +from ..kiro_api import build_headers, build_kiro_request, parse_event_stream, is_quota_exceeded_error +from ..converters import generate_session_id, convert_openai_messages_to_kiro, extract_images_from_content +from ..core.thinking import ( + ThinkingConfig, + build_user_prompt_with_thinking, + extract_thinking_config_from_openai_body, + fetch_thinking_text, + format_thinking_block, + infer_thinking_from_openai_messages, + strip_thinking_from_history, +) + +# 尝试导入 tiktoken,如果失败则使用估算方法 +try: + import tiktoken + _encoding = tiktoken.get_encoding("cl100k_base") + _USE_TIKTOKEN = True + print("[TokenCounter] 使用 tiktoken (cl100k_base) 进行 token 计数") +except ImportError: + _encoding = None + _USE_TIKTOKEN = False + print("[TokenCounter] tiktoken 未安装,使用字符估算方法") + + +def _estimate_tokens(text: str) -> int: + """估算/计算 token 数量 + + 优先使用 tiktoken (cl100k_base),否则使用字符估算: + - 中文字符:约 1.5 字符 = 1 token + - 其他字符:约 4 字符 = 1 token + """ + if not text: + return 0 + + if _USE_TIKTOKEN and _encoding: + return len(_encoding.encode(text)) + + # 回退到字符估算 + chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + other_chars = len(text) - chinese_chars + tokens = int(chinese_chars / 1.5) + int(other_chars / 4) + return max(1, tokens) + + +def _estimate_input_tokens(messages: list, tools: list = None) -> int: + """估算/计算输入 token 数量""" + total = 0 + + for msg in messages or []: + content = msg.get("content", "") + if isinstance(content, str): + total += _estimate_tokens(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + total += _estimate_tokens(part.get("text", "")) + # 角色和消息结构开销 + role = msg.get("role", "") + total += _estimate_tokens(role) + 4 # 每条消息的结构开销 + + # 工具定义开销 + if tools: + tools_json = json.dumps(tools) + total += _estimate_tokens(tools_json) + + return max(1, total) + + +async def handle_chat_completions(request: Request): + """处理 /v1/chat/completions 请求""" + start_time = time.time() + log_id = uuid.uuid4().hex[:8] + + body = await request.json() + model = map_model_name(body.get("model", "claude-sonnet-4")) + messages = body.get("messages", []) + stream = body.get("stream", False) + tools = body.get("tools", None) + tool_choice = body.get("tool_choice", None) + + thinking_cfg, thinking_explicit = extract_thinking_config_from_openai_body(body) + if not thinking_explicit and infer_thinking_from_openai_messages(messages): + thinking_cfg = ThinkingConfig(True, None) + + if not messages: + raise HTTPException(400, "messages required") + + session_id = generate_session_id(messages) + account = state.get_available_account(session_id) + + if not account: + raise HTTPException(503, "All accounts are rate limited or unavailable") + + # 检查 token 是否即将过期,尝试刷新 + if account.is_token_expiring_soon(5): + print(f"[OpenAI] Token 即将过期,尝试刷新: {account.id}") + success, msg = await account.refresh_token() + if not success: + print(f"[OpenAI] Token 刷新失败: {msg}") + + token = account.get_token() + if not token: + raise HTTPException(500, f"Failed to get token for account {account.name}") + + # 使用账号的动态 Machine ID(提前构建,供摘要使用) + creds = account.get_credentials() + headers = build_headers( + token, + machine_id=account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + + # 限速检查 + rate_limiter = get_rate_limiter() + can_request, wait_seconds, reason = rate_limiter.can_request(account.id) + if not can_request: + print(f"[OpenAI] 限速: {reason}") + await asyncio.sleep(wait_seconds) + + # 使用增强的转换函数 + user_content, history, tool_results, kiro_tools = convert_openai_messages_to_kiro( + messages, model, tools, tool_choice + ) + + # 历史消息预处理 + history_manager = HistoryManager(get_history_config(), cache_key=session_id) + + async def call_summary(prompt: str) -> str: + req = build_kiro_request(prompt, "claude-haiku-4.5", []) + try: + async with httpx.AsyncClient(verify=False, timeout=60) as client: + resp = await client.post(KIRO_API_URL, json=req, headers=headers) + if resp.status_code == 200: + return parse_event_stream(resp.content) + except Exception as e: + print(f"[Summary] API 调用失败: {e}") + return "" + + # 检查是否需要智能摘要或错误重试预摘要 + if history_manager.should_summarize(history) or history_manager.should_pre_summary_for_error_retry(history, user_content): + history = await history_manager.pre_process_async(history, user_content, call_summary) + else: + history = history_manager.pre_process(history, user_content) + + # 摘要/截断后再次修复历史交替和 toolUses/toolResults 配对 + from ..converters import fix_history_alternation + history = fix_history_alternation(history) + + if history_manager.was_truncated: + print(f"[OpenAI] {history_manager.truncate_info}") + + + # 提取最后一条消息中的图片 + images = [] + if messages: + last_msg = messages[-1] + if last_msg.get("role") == "user": + _, images = extract_images_from_content(last_msg.get("content", "")) + + thinking_content = "" + if thinking_cfg.enabled: + thinking_content = await fetch_thinking_text( + headers=headers, + model=model, + user_content=user_content, + history=history, + images=images, + tool_results=tool_results if tool_results else None, + budget_tokens=thinking_cfg.budget_tokens, + ) + if thinking_content: + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + + main_user_content = build_user_prompt_with_thinking(user_content, thinking_content) + clean_history = strip_thinking_from_history(history) + + kiro_request = build_kiro_request( + main_user_content, + model, + clean_history, + images=images, + tools=kiro_tools if kiro_tools else None, + tool_results=tool_results if tool_results else None, + ) + + error_msg = None + status_code = 200 + content = "" + current_account = account + max_retries = 2 + + for retry in range(max_retries + 1): + try: + async with httpx.AsyncClient(verify=False, timeout=120) as client: + resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers) + status_code = resp.status_code + + # 仅 429 状态码触发冷却和账号切换 + if resp.status_code == 429: + current_account.mark_quota_exceeded("Rate limited") + + # 尝试切换账号 + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[OpenAI] 429 限流,切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + token = current_account.get_token() + creds = current_account.get_credentials() + headers = build_headers( + token, + machine_id=current_account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + continue + + raise HTTPException(429, "All accounts rate limited") + + # 处理可重试的服务端错误(不触发冷却,仅重试) + if is_retryable_error(resp.status_code): + if retry < max_retries: + print(f"[OpenAI] 服务端错误 {resp.status_code},重试 {retry + 1}/{max_retries}") + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(resp.status_code, f"Server error after {max_retries} retries") + + if resp.status_code != 200: + error_msg = resp.text + print(f"[OpenAI] Kiro API error {resp.status_code}: {resp.text[:500]}") + + # 使用统一的错误处理 + error = classify_error(resp.status_code, error_msg) + print(format_error_log(error, current_account.id)) + + # 账号封禁 - 禁用账号 + if error.should_disable_account: + current_account.enabled = False + if hasattr(current_account, "auto_disabled"): + current_account.auto_disabled = False + from ..credential import CredentialStatus + current_account.status = CredentialStatus.SUSPENDED + try: + state._save_accounts() + except Exception: + pass + print(f"[OpenAI] 账号 {current_account.id} 已被禁用 (封禁)") + + # 仅 429 状态码触发冷却 + elif resp.status_code == 429: + current_account.mark_quota_exceeded(error_msg[:100]) + + # 其他错误(非 429、非内容过长)- 异步检查额度 + elif error.type != ErrorType.CONTENT_TOO_LONG: + from .anthropic import _check_and_disable_if_exhausted + asyncio.create_task(_check_and_disable_if_exhausted(current_account)) + + # 尝试切换账号(仅账号封禁时) + if error.should_disable_account: + next_account = state.get_next_available_account(current_account.id) + if next_account and retry < max_retries: + print(f"[OpenAI] 切换账号: {current_account.id} -> {next_account.id}") + current_account = next_account + headers["Authorization"] = f"Bearer {current_account.get_token()}" + continue + + # 检查是否为内容长度超限错误,尝试截断重试 + if error.type == ErrorType.CONTENT_TOO_LONG: + history_chars, user_chars, total_chars = history_manager.estimate_request_chars( + history, main_user_content + ) + print(f"[OpenAI] 内容长度超限: history={history_chars} chars, user={user_chars} chars, total={total_chars} chars") + truncated_history, should_retry = await history_manager.handle_length_error_async( + history, retry, call_summary + ) + if should_retry: + print(f"[OpenAI] 内容长度超限,{history_manager.truncate_info}") + history = truncated_history + kiro_request = build_kiro_request( + main_user_content, model, history, + images=images, + tools=kiro_tools if kiro_tools else None, + tool_results=tool_results if tool_results else None + ) + continue + else: + print(f"[OpenAI] 内容长度超限但未重试: retry={retry}/{max_retries}") + + raise HTTPException(resp.status_code, error.user_message) + + content = parse_event_stream(resp.content) + current_account.request_count += 1 + current_account.last_used = time.time() + get_rate_limiter().record_request(current_account.id) + break + + except HTTPException: + raise + except httpx.TimeoutException: + error_msg = "Request timeout" + status_code = 408 + if retry < max_retries: + print(f"[OpenAI] 请求超时,重试 {retry + 1}/{max_retries}") + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(408, "Request timeout after retries") + except httpx.ConnectError: + error_msg = "Connection error" + status_code = 502 + if retry < max_retries: + print(f"[OpenAI] 连接错误,重试 {retry + 1}/{max_retries}") + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(502, "Connection error after retries") + except Exception as e: + error_msg = str(e) + status_code = 500 + # 检查是否为可重试的网络错误 + if is_retryable_error(None, e) and retry < max_retries: + print(f"[OpenAI] 网络错误,重试 {retry + 1}/{max_retries}: {type(e).__name__}") + await asyncio.sleep(0.5 * (2 ** retry)) + continue + raise HTTPException(500, str(e)) + + # 记录日志 + duration = (time.time() - start_time) * 1000 + state.add_log(RequestLog( + id=log_id, + timestamp=time.time(), + method="POST", + path="/v1/chat/completions", + model=model, + account_id=current_account.id if current_account else None, + status=status_code, + duration_ms=duration, + error=error_msg + )) + + # 记录统计 + stats_manager.record_request( + account_id=current_account.id if current_account else "unknown", + model=model, + success=status_code == 200, + latency_ms=duration + ) + + # 估算 token 数量 + prompt_tokens = _estimate_input_tokens(messages, tools) + display_content = content + if thinking_cfg.enabled and thinking_content: + display_content = f"{format_thinking_block(thinking_content)}\n\n{content}" + + completion_tokens = _estimate_tokens(display_content) + total_tokens = prompt_tokens + completion_tokens + + if stream: + async def generate(): + for chunk in [display_content[i:i+20] for i in range(0, len(display_content), 20)]: + data = { + "id": f"chatcmpl-{log_id}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}] + } + yield f"data: {json.dumps(data)}\n\n" + await asyncio.sleep(0.02) + + # 最后一个 chunk 包含 usage 信息 + end_data = { + "id": f"chatcmpl-{log_id}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens} + } + yield f"data: {json.dumps(end_data)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(generate(), media_type="text/event-stream") + + return { + "id": f"chatcmpl-{log_id}", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": display_content}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens} + } diff --git a/KiroProxy/kiro_proxy/handlers/responses.py b/KiroProxy/kiro_proxy/handlers/responses.py new file mode 100644 index 0000000000000000000000000000000000000000..93d5248db7843434a03d00e077dbd2c5270b6acc --- /dev/null +++ b/KiroProxy/kiro_proxy/handlers/responses.py @@ -0,0 +1,926 @@ +"""OpenAI Responses API 处理 - /v1/responses + +Codex CLI 使用的 API 端点,深度适配 Codex 源码 +""" +import json +import uuid +import time +import asyncio +import httpx +from fastapi import Request, HTTPException +from fastapi.responses import StreamingResponse + +from ..config import KIRO_API_URL, map_model_name +from ..core import state, is_retryable_error, stats_manager +from ..core.state import RequestLog +from ..core.history_manager import HistoryManager, get_history_config +from ..core.error_handler import classify_error, ErrorType, format_error_log +from ..core.rate_limiter import get_rate_limiter +from ..kiro_api import build_headers, build_kiro_request, parse_event_stream, parse_event_stream_full, is_quota_exceeded_error +from ..core.thinking import ( + ThinkingConfig, + build_user_prompt_with_thinking, + extract_thinking_config_from_openai_body, + fetch_thinking_text, + format_thinking_block, + infer_thinking_from_openai_responses_input, +) + + +def _convert_responses_input_to_kiro(input_data, instructions: str = None): + """将 Responses API 的 input 转换为 Kiro 格式 + + Codex 发送的 input 格式: + - message (role=user): 用户消息 + - message (role=assistant): 助手回复 + - function_call: 工具调用 + - function_call_output: 工具调用结果 + + Kiro API 期望的格式: + - history: [userInputMessage, assistantResponseMessage, ...] 交替 + - 当 assistant 有 toolUses 时,下一条 userInputMessage 必须包含对应的 toolResults + - 当前请求的 userInputMessage 只包含最新一轮的 toolResults + """ + history = [] + user_content = "" + tool_results = [] + model_id = "claude-sonnet-4" + first_user_msg_added = False + pending_images = [] + + if isinstance(input_data, str): + if instructions: + return f"{instructions}\n\n{input_data}", history, tool_results, None + return input_data, history, tool_results, None + + if not isinstance(input_data, list): + return user_content, history, tool_results, None + + # 线性处理消息,跟踪状态 + pending_user_texts = [] + pending_tool_uses = [] + pending_tool_outputs = [] + last_was_assistant_with_tools = False + + for i, item in enumerate(input_data): + item_type = item.get("type", "") + is_last = (i == len(input_data) - 1) + + if item_type == "message": + role = item.get("role", "user") + content_list = item.get("content", []) + + # 提取文本和图片 + text_parts = [] + images = [] + for c in content_list: + if isinstance(c, str): + text_parts.append(c) + elif isinstance(c, dict): + c_type = c.get("type", "") + if c_type in ("input_text", "output_text", "text"): + text_parts.append(c.get("text", "")) + elif c_type == "input_image": + image_url = c.get("image_url", "") + if image_url.startswith("data:"): + import re + match = re.match(r'data:image/(\w+);base64,(.+)', image_url) + if match: + images.append({ + "format": match.group(1), + "source": {"bytes": match.group(2)} + }) + + text = "\n".join(text_parts) if text_parts else "" + + if role == "user": + if images: + pending_images.extend(images) + pending_user_texts.append(text) + + elif role == "assistant": + # 遇到 assistant 消息,先处理之前的 user 消息 + if pending_user_texts: + combined_user = "\n\n".join(pending_user_texts) + if not first_user_msg_added and instructions: + combined_user = f"{instructions}\n\n{combined_user}" + first_user_msg_added = True + + user_msg = { + "userInputMessage": { + "content": combined_user, + "modelId": model_id, + "origin": "AI_EDITOR" + } + } + # 如果上一个 assistant 有工具调用,这个 user 消息需要带 toolResults + if pending_tool_outputs: + user_msg["userInputMessage"]["userInputMessageContext"] = { + "toolResults": pending_tool_outputs + } + pending_tool_outputs = [] + + history.append(user_msg) + pending_user_texts = [] + elif pending_tool_outputs: + # 没有 user 消息,但有工具结果,创建一个带 toolResults 的 user 消息 + user_msg = { + "userInputMessage": { + "content": "Tool execution completed.", + "modelId": model_id, + "origin": "AI_EDITOR", + "userInputMessageContext": { + "toolResults": pending_tool_outputs + } + } + } + history.append(user_msg) + pending_tool_outputs = [] + + # 添加 assistant 消息 + assistant_msg = { + "assistantResponseMessage": { + "content": text or "I understand." + } + } + if pending_tool_uses: + assistant_msg["assistantResponseMessage"]["toolUses"] = pending_tool_uses + pending_tool_uses = [] + last_was_assistant_with_tools = True + else: + # 没有 toolUses 时不添加这个字段 + last_was_assistant_with_tools = False + + history.append(assistant_msg) + + elif item_type == "function_call": + try: + args = json.loads(item.get("arguments", "{}")) if isinstance(item.get("arguments"), str) else item.get("arguments", {}) + except: + args = {} + + tool_use = { + "toolUseId": item.get("call_id", ""), + "name": item.get("name", ""), + "input": args + } + + # 如果上一条是 assistant 消息,添加 toolUses + if history and "assistantResponseMessage" in history[-1]: + if "toolUses" not in history[-1]["assistantResponseMessage"]: + history[-1]["assistantResponseMessage"]["toolUses"] = [] + history[-1]["assistantResponseMessage"]["toolUses"].append(tool_use) + last_was_assistant_with_tools = True + else: + pending_tool_uses.append(tool_use) + + elif item_type == "function_call_output": + call_id = item.get("call_id", "") + output = item.get("output", {}) + + # 跳过没有 call_id 的 tool output + if not call_id: + print(f"[Responses] Warning: function_call_output without call_id, skipping") + continue + + if isinstance(output, str): + output_str = output + status = "success" + elif isinstance(output, dict): + output_str = output.get("content", json.dumps(output)) + status = "success" if output.get("success", True) is not False else "error" + else: + output_str = str(output) + status = "success" + + pending_tool_outputs.append({ + "content": [{"text": output_str}], + "status": status, + "toolUseId": call_id + }) + + # 处理剩余的消息 + if pending_user_texts: + user_content = "\n\n".join(pending_user_texts) + if not first_user_msg_added and instructions: + user_content = f"{instructions}\n\n{user_content}" + elif pending_tool_outputs: + user_content = "Please continue based on the tool results." + + if pending_tool_outputs: + tool_results = pending_tool_outputs + + # 验证并修复 history 中的 toolUses/toolResults 配对 + # Kiro API 规则:当 assistant 有 toolUses 时,下一条 user 必须有对应的 toolResults + for i in range(len(history) - 1): + if "assistantResponseMessage" in history[i]: + assistant = history[i]["assistantResponseMessage"] + has_tool_uses = bool(assistant.get("toolUses")) + + if i + 1 < len(history) and "userInputMessage" in history[i + 1]: + user = history[i + 1]["userInputMessage"] + ctx = user.get("userInputMessageContext", {}) + has_tool_results = bool(ctx.get("toolResults")) + + # 确保配对一致 + if has_tool_uses and not has_tool_results: + # assistant 有 toolUses 但 user 没有 toolResults,清除 toolUses + print(f"[Responses] Warning: history[{i}] has toolUses but history[{i+1}] has no toolResults, removing toolUses") + assistant.pop("toolUses", None) + elif not has_tool_uses and has_tool_results: + # assistant 没有 toolUses 但 user 有 toolResults,清除 toolResults + print(f"[Responses] Warning: history[{i}] has no toolUses but history[{i+1}] has toolResults, removing toolResults") + user.pop("userInputMessageContext", None) + + # 调试日志 + print(f"[Responses] Converted: history={len(history)}, tool_results={len(tool_results)}") + for i, h in enumerate(history): + if "userInputMessage" in h: + has_tr = "toolResults" in h.get("userInputMessage", {}).get("userInputMessageContext", {}) + print(f"[Responses] history[{i}]: userInputMessage, has_toolResults={has_tr}") + elif "assistantResponseMessage" in h: + arm = h.get("assistantResponseMessage", {}) + has_tu_field = "toolUses" in arm + tu_count = len(arm.get("toolUses", []) or []) if has_tu_field else 0 + print(f"[Responses] history[{i}]: assistantResponseMessage, has_toolUses_field={has_tu_field}, toolUses_count={tu_count}") + + images = pending_images if pending_images else None + return user_content, history, tool_results, images + + +def _convert_tools_to_kiro(tools: list) -> list: + """将 Responses API 的 tools 转换为 Kiro 格式 + + Codex Responses API 工具格式: + { + "type": "function", + "name": "...", + "description": "...", + "strict": true, + "parameters": {...} + } + + 或特殊工具: + { + "type": "web_search", + "external_web_access": true/false + } + { + "type": "local_shell" + } + + Kiro API 期望的格式: + { + "toolSpecification": { + "name": "...", + "description": "...", + "inputSchema": {"json": {...}} + } + } + 或 + { + "webSearchTool": {"type": "web_search"} + } + """ + if not tools: + return None + + MAX_TOOLS = 50 # Kiro API 工具数量限制 + kiro_tools = [] + function_count = 0 + + for tool in tools: + tool_type = tool.get("type", "") + + # 特殊工具类型 + if tool_type == "web_search": + # Kiro 支持 web_search + kiro_tools.append({ + "webSearchTool": { + "type": "web_search" + } + }) + continue + elif tool_type == "local_shell": + # local_shell 是 OpenAI 原生工具,Kiro 不支持,跳过 + continue + + # 限制工具数量 + if function_count >= MAX_TOOLS: + continue + + # Responses API 格式:字段直接在工具对象上 + # Chat Completions API 格式:字段嵌套在 function 里 + if tool_type == "function": + # 检查是否是 Chat Completions 格式(有 function 嵌套) + if "function" in tool: + func = tool["function"] + name = func.get("name", "") + description = func.get("description", "")[:500] + parameters = func.get("parameters", {"type": "object", "properties": {}}) + else: + # Responses API 格式 + name = tool.get("name", "") + description = tool.get("description", "")[:500] + parameters = tool.get("parameters", {"type": "object", "properties": {}}) + elif tool_type == "custom": + # 自定义工具格式 + name = tool.get("name", "") + description = tool.get("description", "")[:500] + # custom 工具可能有不同的 schema 格式 + fmt = tool.get("format", {}) + if fmt.get("type") == "json_schema": + parameters = fmt.get("schema", {"type": "object", "properties": {}}) + else: + parameters = {"type": "object", "properties": {}} + else: + name = tool.get("name", "") + description = tool.get("description", "")[:500] + parameters = tool.get("parameters", tool.get("input_schema", {"type": "object", "properties": {}})) + + if not name: + continue + + function_count += 1 + + # 转换为 Kiro 格式 + kiro_tools.append({ + "toolSpecification": { + "name": name, + "description": description or f"Tool: {name}", + "inputSchema": { + "json": parameters + } + } + }) + + return kiro_tools if kiro_tools else None + + +async def handle_responses(request: Request): + """处理 /v1/responses 请求""" + start_time = time.time() + log_id = uuid.uuid4().hex[:12] + + body = await request.json() + model = map_model_name(body.get("model", "gpt-4o")) + input_data = body.get("input", "") + instructions = body.get("instructions", "") + stream = body.get("stream", True) + tools = body.get("tools", []) + + thinking_cfg, thinking_explicit = extract_thinking_config_from_openai_body(body) + if not thinking_explicit and infer_thinking_from_openai_responses_input(input_data): + thinking_cfg = ThinkingConfig(True, None) + + if not input_data: + raise HTTPException(400, "input required") + + import hashlib + session_str = json.dumps(input_data[:3] if isinstance(input_data, list) else str(input_data)[:100], sort_keys=True, default=str) + session_id = hashlib.sha256(session_str.encode()).hexdigest()[:16] + account = state.get_available_account(session_id) + + if not account: + raise HTTPException(503, "All accounts are rate limited or unavailable") + + if account.is_token_expiring_soon(5): + await account.refresh_token() + + token = account.get_token() + if not token: + raise HTTPException(500, f"Failed to get token for account {account.name}") + + creds = account.get_credentials() + headers = build_headers( + token, + machine_id=account.get_machine_id(), + profile_arn=creds.profile_arn if creds else None, + client_id=creds.client_id if creds else None + ) + + rate_limiter = get_rate_limiter() + can_request, wait_seconds, _ = rate_limiter.can_request(account.id) + if not can_request: + await asyncio.sleep(wait_seconds) + + user_content, history, tool_results, images = _convert_responses_input_to_kiro(input_data, instructions) + + # 修复历史消息交替 + from ..converters import fix_history_alternation + history = fix_history_alternation(history) + + history_manager = HistoryManager(get_history_config(), cache_key=session_id) + + # 对于 Responses API,强制启用自动截断(Codex CLI 的历史可能很长) + from ..core.history_manager import TruncateStrategy + if TruncateStrategy.AUTO_TRUNCATE not in history_manager.config.strategies: + history_manager.config.strategies.append(TruncateStrategy.AUTO_TRUNCATE) + + # 创建摘要 API 调用函数 + async def api_caller(prompt: str) -> str: + req = build_kiro_request(prompt, "claude-haiku-4.5", []) + try: + async with httpx.AsyncClient(verify=False, timeout=60) as client: + resp = await client.post(KIRO_API_URL, json=req, headers=headers) + if resp.status_code == 200: + return parse_event_stream(resp.content) + except Exception as e: + print(f"[Responses] Summary API 调用失败: {e}") + return "" + + # 检查是否需要智能摘要或错误重试预摘要 + if history_manager.should_summarize(history) or history_manager.should_pre_summary_for_error_retry(history, user_content): + history = await history_manager.pre_process_async(history, user_content, api_caller) + else: + history = history_manager.pre_process(history, user_content) + + # 摘要/截断后再次修复历史交替和 toolUses/toolResults 配对 + history = fix_history_alternation(history) + + if history_manager.was_truncated: + print(f"[Responses] {history_manager.truncate_info}") + + kiro_tools = _convert_tools_to_kiro(tools) + + # 调试:打印 input 结构 + if isinstance(input_data, list): + for i, item in enumerate(input_data): + item_type = item.get("type", "?") + role = item.get("role", "?") + print(f"[Responses] input[{i}]: type={item_type}, role={role}") + print(f"[Responses] history len: {len(history)}, tool_results len: {len(tool_results)}, images: {len(images) if images else 0}") + print(f"[Responses] user_content len: {len(user_content)}") + + # 验证 tool_results 与 history 的一致性 + if tool_results and history: + # 找到最后一个 assistant 消息 + last_assistant = None + last_assistant_idx = -1 + for i, msg in enumerate(reversed(history)): + if "assistantResponseMessage" in msg: + last_assistant = msg["assistantResponseMessage"] + last_assistant_idx = len(history) - 1 - i + break + + if last_assistant: + tool_use_ids = set() + for tu in last_assistant.get("toolUses", []) or []: + tu_id = tu.get("toolUseId") + if tu_id: + tool_use_ids.add(tu_id) + + print(f"[Responses] Last assistant at idx={last_assistant_idx}, toolUse_ids={tool_use_ids}") + print(f"[Responses] tool_results ids={[tr.get('toolUseId') for tr in tool_results]}") + + # 过滤 tool_results,只保留有对应 toolUse 的 + if tool_use_ids: + filtered_results = [tr for tr in tool_results if tr.get("toolUseId") in tool_use_ids] + if len(filtered_results) != len(tool_results): + print(f"[Responses] Filtered tool_results: {len(tool_results)} -> {len(filtered_results)}") + tool_results = filtered_results + else: + # 如果最后一个 assistant 没有 toolUses,清空 tool_results + print(f"[Responses] Warning: Last assistant has no toolUses, clearing tool_results") + tool_results = [] + else: + print(f"[Responses] Warning: No assistant message in history, clearing tool_results") + tool_results = [] + + # 确保所有消息都有非空的 content + for i, msg in enumerate(history): + if "userInputMessage" in msg: + uim = msg["userInputMessage"] + if not uim.get("content"): + uim["content"] = "Continue" + elif "assistantResponseMessage" in msg: + arm = msg["assistantResponseMessage"] + if not arm.get("content"): + arm["content"] = "I understand." + + kiro_request = build_kiro_request( + user_content, model, history, + tools=kiro_tools, + images=images, + tool_results=tool_results if tool_results else None + ) + + # 调试:打印完整的 Kiro 请求(使用深拷贝避免修改原始请求) + if tool_results: + import copy + # 打印请求结构(不包括 tools,因为太长) + debug_request = copy.deepcopy({ + "conversationState": { + "history_len": len(kiro_request.get("conversationState", {}).get("history", [])), + "currentMessage": kiro_request.get("conversationState", {}).get("currentMessage", {}), + } + }) + # 移除 tools 以便打印(只在 debug_request 中) + if "userInputMessageContext" in debug_request["conversationState"]["currentMessage"].get("userInputMessage", {}): + ctx = debug_request["conversationState"]["currentMessage"]["userInputMessage"]["userInputMessageContext"] + if "tools" in ctx: + ctx["tools_count"] = len(ctx["tools"]) + del ctx["tools"] + print(f"[Responses] Kiro request structure: {json.dumps(debug_request, indent=2)}") + + if stream: + return await _handle_stream(kiro_request, headers, account, model, log_id, start_time, thinking_cfg) + + # 非流式 + thinking_content = "" + if thinking_cfg.enabled: + thinking_content = await fetch_thinking_text( + headers=headers, + model=model, + user_content=user_content, + history=history, + images=images, + tool_results=tool_results if tool_results else None, + budget_tokens=thinking_cfg.budget_tokens, + ) + if thinking_content: + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + + main_user_content = build_user_prompt_with_thinking(user_content, thinking_content) + kiro_request = build_kiro_request( + main_user_content, + model, + history, + tools=kiro_tools, + images=images, + tool_results=tool_results if tool_results else None, + ) + + async with httpx.AsyncClient(verify=False, timeout=120) as client: + resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers) + if resp.status_code != 200: + raise HTTPException(resp.status_code, resp.text) + + result = parse_event_stream_full(resp.content) + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + + return _build_response(result, model, log_id, thinking_content) + + +def _build_response(result: dict, model: str, response_id: str, thinking_content: str = "") -> dict: + """构建非流式响应""" + text = "".join(result.get("content", [])) + if thinking_content: + text = f"{format_thinking_block(thinking_content)}\n\n{text}" if text else format_thinking_block(thinking_content) + output = [] + + if text: + output.append({ + "type": "message", + "id": f"msg_{response_id}", + "role": "assistant", + "content": [{"type": "output_text", "text": text, "annotations": []}] + }) + + for tool_use in result.get("tool_uses", []): + output.append({ + "type": "function_call", + "id": tool_use.get("id", f"call_{uuid.uuid4().hex[:12]}"), + "call_id": tool_use.get("id", f"call_{uuid.uuid4().hex[:12]}"), + "name": tool_use.get("name", ""), + "arguments": json.dumps(tool_use.get("input", {})) + }) + + return { + "id": f"resp_{response_id}", + "object": "response", + "created_at": int(time.time()), + "status": "completed", + "model": model, + "output": output, + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + } + + +async def _handle_stream(kiro_request, headers, account, model, log_id, start_time, thinking_cfg: ThinkingConfig): + """流式处理 - Codex 期望的 SSE 格式""" + + # 保存完整请求用于调试 + import os + debug_dir = "debug_requests" + os.makedirs(debug_dir, exist_ok=True) + debug_file = f"{debug_dir}/{log_id}_request.json" + with open(debug_file, 'w', encoding='utf-8') as f: + json.dump(kiro_request, f, indent=2, ensure_ascii=False) + print(f"[Responses] Saved request to {debug_file}") + + async def generate(): + response_id = f"resp_{log_id}" + item_id = f"msg_{log_id}" + created_at = int(time.time()) + full_content = "" + tool_uses = [] + error_occurred = False + thinking_content = "" + + # 1. response.created (先发,后续可能失败) + yield _sse( + "response.created", + { + "type": "response.created", + "response": { + "id": response_id, + "object": "response", + "created_at": created_at, + "status": "in_progress", + "model": model, + "output": [], + }, + }, + ) + + # 2. response.output_item.added (message item) + yield _sse( + "response.output_item.added", + { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": item_id, + "type": "message", + "status": "in_progress", + "role": "assistant", + "content": [], + }, + }, + ) + + # Responses API: 先发送 thinking(如果启用),再发送主回答 + if thinking_cfg.enabled: + try: + cs = kiro_request.get("conversationState", {}) + hist = cs.get("history", []) + cm = cs.get("currentMessage", {}).get("userInputMessage", {}) + original_user = cm.get("content", "") + images = cm.get("images") + ctx = cm.get("userInputMessageContext", {}) or {} + tool_results = ctx.get("toolResults") + + thinking_content = await fetch_thinking_text( + headers=headers, + model=model, + user_content=original_user, + history=hist, + images=images, + tool_results=tool_results, + budget_tokens=thinking_cfg.budget_tokens, + ) + if thinking_content: + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + + thinking_prefix = f"{format_thinking_block(thinking_content)}\n\n" + full_content += thinking_prefix + # 以 delta 的形式写入同一个 message item + for i in range(0, len(thinking_prefix), 256): + yield _sse( + "response.output_text.delta", + { + "type": "response.output_text.delta", + "item_id": item_id, + "output_index": 0, + "content_index": 0, + "delta": thinking_prefix[i : i + 256], + }, + ) + + # 将 thinking 注入主请求(强约束不泄露) + main_user = build_user_prompt_with_thinking(original_user, thinking_content) + if "content" in cm: + cm["content"] = main_user + except Exception: + thinking_content = "" + + print(f"[Responses] Request: model={model}, log_id={log_id}") + + try: + async with httpx.AsyncClient(verify=False, timeout=300) as client: + async with client.stream("POST", KIRO_API_URL, json=kiro_request, headers=headers) as response: + + if response.status_code != 200: + error_text = await response.aread() + error_msg = error_text.decode()[:500] + print(f"[Responses] Kiro error: {response.status_code} - {error_msg[:200]}") + + # 打印更多调试信息 + if response.status_code == 400: + cs = kiro_request.get("conversationState", {}) + hist = cs.get("history", []) + print(f"[Responses] 400 Debug: history_len={len(hist)}") + if hist: + # 检查每条 history 的详细结构 + for i, h in enumerate(hist[:5]): # 只打印前5条 + if "userInputMessage" in h: + uim = h["userInputMessage"] + has_ctx = "userInputMessageContext" in uim + has_tr = has_ctx and "toolResults" in uim.get("userInputMessageContext", {}) + content_len = len(uim.get("content", "")) + uim_keys = list(uim.keys()) + print(f"[Responses] hist[{i}]: user, keys={uim_keys}, content_len={content_len}, has_toolResults={has_tr}") + elif "assistantResponseMessage" in h: + arm = h["assistantResponseMessage"] + arm_keys = list(arm.keys()) + has_tu = "toolUses" in arm + tu_count = len(arm.get("toolUses", []) or []) if has_tu else 0 + content_len = len(arm.get("content", "") or "") + print(f"[Responses] hist[{i}]: assistant, keys={arm_keys}, content_len={content_len}, has_toolUses={has_tu}, toolUses_count={tu_count}") + else: + print(f"[Responses] hist[{i}]: UNKNOWN keys={list(h.keys())}") + if len(hist) > 5: + print(f"[Responses] ... ({len(hist) - 5} more)") + + # 打印 currentMessage 结构 + cm = cs.get("currentMessage", {}) + if "userInputMessage" in cm: + uim = cm["userInputMessage"] + print(f"[Responses] currentMessage: keys={list(uim.keys())}, content_len={len(uim.get('content', ''))}") + if "userInputMessageContext" in uim: + ctx = uim["userInputMessageContext"] + print(f"[Responses] context keys={list(ctx.keys())}") + if "toolResults" in ctx: + print(f"[Responses] toolResults count={len(ctx['toolResults'])}") + if "tools" in ctx: + print(f"[Responses] tools count={len(ctx['tools'])}") + + error_occurred = True + + # 映射错误代码 + error_code = "api_error" + error_lower = error_msg.lower() + if response.status_code == 429 or "rate limit" in error_lower or "throttl" in error_lower: + error_code = "rate_limit_exceeded" + elif "context" in error_lower or "too long" in error_lower or "content length" in error_lower: + error_code = "context_length_exceeded" + elif "quota" in error_lower or "insufficient" in error_lower: + error_code = "insufficient_quota" + elif response.status_code == 401 or response.status_code == 403: + error_code = "authentication_error" + + yield _sse("response.failed", { + "type": "response.failed", + "response": { + "id": response_id, + "object": "response", + "status": "failed", + "error": {"code": error_code, "message": error_msg[:200]} + } + }) + return + + # 3. 流式读取并发送 delta + full_response = b"" + async for chunk in response.aiter_bytes(): + full_response += chunk + + # 尝试解析增量内容 + content = _extract_content_from_chunk(chunk) + if content: + full_content += content + yield _sse("response.output_text.delta", { + "type": "response.output_text.delta", + "item_id": item_id, + "output_index": 0, + "content_index": 0, + "delta": content + }) + + # 解析完整响应获取工具调用 + result = parse_event_stream_full(full_response) + tool_uses = result.get("tool_uses", []) + if not full_content: + full_content = "".join(result.get("content", [])) + + account.request_count += 1 + account.last_used = time.time() + get_rate_limiter().record_request(account.id) + + except Exception as e: + error_occurred = True + yield _sse("response.failed", { + "type": "response.failed", + "response": { + "id": response_id, + "status": "failed", + "error": {"code": "internal_error", "message": str(e)[:200]} + } + }) + return + + # 4. response.output_item.done - 消息完成 + message_content = [{"type": "output_text", "text": full_content, "annotations": []}] + yield _sse("response.output_item.done", { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": item_id, + "type": "message", + "status": "completed", + "role": "assistant", + "content": message_content + } + }) + + # 构建 output 列表 + output_items = [{ + "id": item_id, + "type": "message", + "status": "completed", + "role": "assistant", + "content": message_content + }] + + # 5. 工具调用 + for i, tool_use in enumerate(tool_uses): + tool_item_id = tool_use.get("id", f"call_{uuid.uuid4().hex[:12]}") + tool_item = { + "type": "function_call", + "id": tool_item_id, + "call_id": tool_item_id, + "name": tool_use.get("name", ""), + "arguments": json.dumps(tool_use.get("input", {})) + } + + yield _sse("response.output_item.added", { + "type": "response.output_item.added", + "output_index": i + 1, + "item": tool_item + }) + + yield _sse("response.output_item.done", { + "type": "response.output_item.done", + "output_index": i + 1, + "item": tool_item + }) + + output_items.append(tool_item) + + # 6. response.completed - 必须发送! + yield _sse("response.completed", { + "type": "response.completed", + "response": { + "id": response_id, + "object": "response", + "created_at": created_at, + "status": "completed", + "model": model, + "output": output_items, + "usage": { + "input_tokens": 0, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens": 0, + "output_tokens_details": {"reasoning_tokens": 0}, + "total_tokens": 0 + } + } + }) + + return StreamingResponse(generate(), media_type="text/event-stream") + + +def _sse(event_type: str, data: dict) -> str: + """生成 SSE 格式的事件""" + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + + +def _extract_content_from_chunk(chunk: bytes) -> str: + """从 AWS event-stream chunk 中提取文本内容""" + content = "" + pos = 0 + + while pos < len(chunk): + if pos + 12 > len(chunk): + break + + total_len = int.from_bytes(chunk[pos:pos+4], 'big') + if total_len == 0 or total_len > len(chunk) - pos: + break + + headers_len = int.from_bytes(chunk[pos+4:pos+8], 'big') + payload_start = pos + 12 + headers_len + payload_end = pos + total_len - 4 + + if payload_start < payload_end: + try: + payload = json.loads(chunk[payload_start:payload_end].decode('utf-8')) + if 'assistantResponseEvent' in payload: + c = payload['assistantResponseEvent'].get('content') + if c: + content += c + elif 'content' in payload and 'toolUseId' not in payload: + content += payload['content'] + except: + pass + + pos += total_len + + return content diff --git a/KiroProxy/kiro_proxy/kiro_api.py b/KiroProxy/kiro_proxy/kiro_api.py new file mode 100644 index 0000000000000000000000000000000000000000..fc210e297bf04170aa8aa2f6936a159f171a5a6e --- /dev/null +++ b/KiroProxy/kiro_proxy/kiro_api.py @@ -0,0 +1,62 @@ +"""Kiro API 调用模块 - 兼容层 + +此文件保留用于向后兼容,实际实现已移至 providers/kiro.py。 +""" +from .providers.kiro import KiroProvider +from .credential import generate_machine_id, get_kiro_version, get_system_info, quota_manager + +# 创建默认 provider 实例 +_default_provider = KiroProvider() + + +def build_headers( + token: str, + agent_mode: str = "vibe", + machine_id: str = None, + profile_arn: str = None, + client_id: str = None +) -> dict: + """构建 Kiro API 请求头""" + if machine_id: + return _default_provider.build_headers(token, agent_mode, machine_id=machine_id) + + # 如果提供了凭证信息,生成对应的 machine_id + if profile_arn or client_id: + mid = generate_machine_id(profile_arn, client_id) + return _default_provider.build_headers(token, agent_mode, machine_id=mid) + + return _default_provider.build_headers(token, agent_mode) + + +def build_kiro_request( + user_content: str, + model: str, + history: list = None, + tools: list = None, + images: list = None, + tool_results: list = None +) -> dict: + """构建 Kiro API 请求体""" + return _default_provider.build_request( + user_content=user_content, + model=model, + history=history, + tools=tools, + images=images, + tool_results=tool_results + ) + + +def parse_event_stream(raw: bytes) -> str: + """解析 AWS event-stream 格式,返回文本内容""" + return _default_provider.parse_response_text(raw) + + +def parse_event_stream_full(raw: bytes) -> dict: + """解析 AWS event-stream 格式,返回完整结构""" + return _default_provider.parse_response(raw) + + +def is_quota_exceeded_error(status_code: int, error_text: str) -> bool: + """检查是否为配额超限错误""" + return quota_manager.is_quota_exceeded_error(status_code, error_text) diff --git a/KiroProxy/kiro_proxy/main.py b/KiroProxy/kiro_proxy/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d5be2da3973c45168fbfc13e0867b9e383a5c8 --- /dev/null +++ b/KiroProxy/kiro_proxy/main.py @@ -0,0 +1,97 @@ +import time +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware + +from . import __version__ +from .core import get_quota_scheduler, get_refresh_manager, scheduler, state +from .routers import admin, protocols, web + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await scheduler.start() + + refresh_manager = get_refresh_manager() + refresh_manager.set_accounts_getter(lambda: state.accounts) + + accounts = state.accounts + if accounts: + print(f"[Startup] 检查 {len(accounts)} 个账号的 Token 状态...") + for account in accounts: + if account.enabled and refresh_manager.should_refresh_token(account): + try: + success, msg = await refresh_manager.refresh_token_if_needed(account) + if success: + print(f"[Startup] 账号 {account.name} Token 刷新成功") + else: + print(f"[Startup] 账号 {account.name} Token 刷新失败: {msg}") + except Exception as e: + print(f"[Startup] 账号 {account.name} Token 刷新异常: {e}") + + quota_scheduler = get_quota_scheduler() + quota_scheduler.set_accounts_getter(lambda: state.accounts) + await quota_scheduler.start() + + await refresh_manager.start_auto_refresh() + + yield + + await refresh_manager.stop_auto_refresh() + await quota_scheduler.stop() + await scheduler.stop() + + +def create_app() -> FastAPI: + app = FastAPI(title="Kiro API Proxy", docs_url="/docs", redoc_url=None, lifespan=lifespan) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.middleware("http") + async def log_requests(request: Request, call_next): + start_time = time.time() + path = request.url.path + method = request.method + + body_str = "" + + print(f"[Request] {method} {path} | Body: {body_str}") + + response = await call_next(request) + + duration = (time.time() - start_time) * 1000 + print(f"[Response] {method} {path} - {response.status_code} ({duration:.2f}ms)") + return response + + app.include_router(web.router) + app.include_router(protocols.router) + app.include_router(admin.router) + + return app + + +app = create_app() + + +def run(port: int = 8080): + import uvicorn + + print(f"\n{'='*50}") + print(f" Kiro API Proxy v{__version__}") + print(f" http://localhost:{port}") + print(f"{'='*50}\n") + uvicorn.run(app, host="0.0.0.0", port=port) + + +if __name__ == "__main__": + import sys + + port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080 + run(port) diff --git a/KiroProxy/kiro_proxy/main_legacy.py b/KiroProxy/kiro_proxy/main_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..f34bd2de4a6cd12f428d82070180a7bf9768cf1b --- /dev/null +++ b/KiroProxy/kiro_proxy/main_legacy.py @@ -0,0 +1,694 @@ +"""Kiro API Proxy - 主应用""" +import json +import uuid +import httpx +import sys +import time +from pathlib import Path +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware + +from . import __version__ +from .config import MODELS_URL +from .core import state, scheduler, stats_manager, get_quota_scheduler, get_refresh_manager +from .handlers import anthropic, openai, gemini, admin +from .handlers import responses as responses_handler +from .web.html import HTML_PAGE +from .credential import generate_machine_id, get_kiro_version + + +def get_resource_path(relative_path: str) -> Path: + """获取资源文件路径,支持从打包资源读取""" + base_path = Path(sys._MEIPASS) if hasattr(sys, '_MEIPASS') else Path(__file__).parent.parent + return base_path / relative_path + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + # 启动时 + await scheduler.start() + + # 初始化刷新管理器 + refresh_manager = get_refresh_manager() + refresh_manager.set_accounts_getter(lambda: state.accounts) + + # 启动时先刷新所有需要刷新的 Token,避免首次请求使用过期 Token + accounts = state.accounts + if accounts: + print(f"[Startup] 检查 {len(accounts)} 个账号的 Token 状态...") + for account in accounts: + if account.enabled and refresh_manager.should_refresh_token(account): + try: + success, msg = await refresh_manager.refresh_token_if_needed(account) + if success: + print(f"[Startup] 账号 {account.name} Token 刷新成功") + else: + print(f"[Startup] 账号 {account.name} Token 刷新失败: {msg}") + except Exception as e: + print(f"[Startup] 账号 {account.name} Token 刷新异常: {e}") + + # 启动额度调度器(在 Token 刷新之后,避免首次额度获取使用过期 Token) + quota_scheduler = get_quota_scheduler() + quota_scheduler.set_accounts_getter(lambda: state.accounts) + await quota_scheduler.start() + + await refresh_manager.start_auto_refresh() + + yield + + # 关闭时 + await refresh_manager.stop_auto_refresh() + await quota_scheduler.stop() + await scheduler.stop() + + +app = FastAPI(title="Kiro API Proxy", docs_url="/docs", redoc_url=None, lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.middleware("http") +async def log_requests(request: Request, call_next): + start_time = time.time() + path = request.url.path + method = request.method + + body_str = "" + + print(f"[Request] {method} {path} | Body: {body_str}") + + response = await call_next(request) + + duration = (time.time() - start_time) * 1000 + print(f"[Response] {method} {path} - {response.status_code} ({duration:.2f}ms)") + return response + + +# ==================== Web UI ==================== + +@app.get("/", response_class=HTMLResponse) +async def index(): + return HTML_PAGE + + +@app.get("/assets/{path:path}") +async def serve_assets(path: str): + """提供静态资源""" + file_path = get_resource_path("assets") / path + if file_path.exists(): + content_type = "image/svg+xml" if path.endswith(".svg") else "application/octet-stream" + return StreamingResponse(open(file_path, "rb"), media_type=content_type) + raise HTTPException(status_code=404) + + +# ==================== API 端点 ==================== + +@app.get("/v1/models") +async def models(): + """获取可用模型列表""" + try: + account = state.get_available_account() + if not account: + raise Exception("No available account") + + token = account.get_token() + machine_id = account.get_machine_id() + kiro_version = get_kiro_version() + + headers = { + "content-type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "Authorization": f"Bearer {token}", + } + async with httpx.AsyncClient(verify=False, timeout=30) as client: + resp = await client.get(MODELS_URL, headers=headers, params={"origin": "AI_EDITOR"}) + if resp.status_code == 200: + data = resp.json() + return { + "object": "list", + "data": [ + { + "id": m["modelId"], + "object": "model", + "owned_by": "kiro", + "name": m["modelName"], + } + for m in data.get("models", []) + ] + } + except Exception: + pass + + # 降级返回静态列表 + return {"object": "list", "data": [ + {"id": "auto", "object": "model", "owned_by": "kiro", "name": "Auto"}, + {"id": "claude-sonnet-4.5", "object": "model", "owned_by": "kiro", "name": "Claude Sonnet 4.5"}, + {"id": "claude-sonnet-4", "object": "model", "owned_by": "kiro", "name": "Claude Sonnet 4"}, + {"id": "claude-haiku-4.5", "object": "model", "owned_by": "kiro", "name": "Claude Haiku 4.5"}, + ]} + + +# Anthropic 协议 +@app.post("/v1/messages") +async def anthropic_messages(request: Request): + print(f"[Main] Received /v1/messages request from {request.client.host}") + return await anthropic.handle_messages(request) + +@app.post("/v1/messages/count_tokens") +async def anthropic_count_tokens(request: Request): + return await anthropic.handle_count_tokens(request) + + + +@app.post("/v1/complete") +async def anthropic_complete(request: Request): + print(f"[Main] Received /v1/complete request from {request.client.host}") + # 暂时重定向或提示,或者如果需要可以实现 handle_complete + return JSONResponse( + status_code=400, + content={"error": {"type": "invalid_request_error", "message": "KiroProxy currently only supports /v1/messages. Please check if your client can be configured to use Messages API."}} + ) + + +# OpenAI 协议 +@app.post("/v1/chat/completions") +async def openai_chat(request: Request): + return await openai.handle_chat_completions(request) + + +# OpenAI Responses API (Codex CLI 新版本) +@app.post("/v1/responses") +async def openai_responses(request: Request): + return await responses_handler.handle_responses(request) + + +# Gemini 协议 +@app.post("/v1beta/models/{model_name}:generateContent") +@app.post("/v1/models/{model_name}:generateContent") +async def gemini_generate(model_name: str, request: Request): + return await gemini.handle_generate_content(model_name, request) + + +# ==================== 管理 API ==================== + +@app.get("/api/status") +async def api_status(): + return await admin.get_status() + +@app.post("/api/event_logging/batch") +async def api_event_logging_batch(request: Request): + return await admin.event_logging_batch(request) + + +@app.get("/api/stats") +async def api_stats(): + return await admin.get_stats() + + +@app.get("/api/logs") +async def api_logs(limit: int = 100): + return await admin.get_logs(limit) + + +# ==================== 账号导入导出 API ==================== + +@app.get("/api/accounts/export") +async def api_export_accounts(): + """导出所有账号配置""" + return await admin.export_accounts() + + +@app.post("/api/accounts/import") +async def api_import_accounts(request: Request): + """导入账号配置""" + return await admin.import_accounts(request) + + +@app.post("/api/accounts/manual") +async def api_add_manual_token(request: Request): + """手动添加 Token""" + return await admin.add_manual_token(request) + + +@app.post("/api/accounts/batch") +async def api_batch_import_accounts(request: Request): + """批量导入账号""" + return await admin.batch_import_accounts(request) + + +@app.post("/api/accounts/refresh-all") +async def api_refresh_all(): + """刷新所有即将过期的 token""" + return await admin.refresh_all_tokens() + + +# ==================== 额度管理 API (必须在 {account_id} 路由之前) ==================== + +@app.get("/api/accounts/status") +async def api_accounts_status_enhanced(): + """获取完整账号状态(增强版)""" + return await admin.get_accounts_status_enhanced() + + +@app.get("/api/accounts/summary") +async def api_accounts_summary(): + """获取账号汇总统计""" + return await admin.get_accounts_summary() + + +@app.post("/api/accounts/refresh-all-quotas") +async def api_refresh_all_quotas(): + """刷新所有账号额度""" + return await admin.refresh_all_quotas() + + +# ==================== 刷新进度 API ==================== + +@app.get("/api/refresh/progress") +async def api_refresh_progress(): + """获取刷新进度""" + return await admin.get_refresh_progress() + + +@app.post("/api/refresh/all") +async def api_refresh_all_with_progress(): + """批量刷新(带进度和锁检查)""" + return await admin.refresh_all_with_progress() + + +@app.get("/api/refresh/config") +async def api_get_refresh_config(): + """获取刷新配置""" + return await admin.get_refresh_config() + + +@app.put("/api/refresh/config") +async def api_update_refresh_config(request: Request): + """更新刷新配置""" + return await admin.update_refresh_config(request) + + +@app.get("/api/refresh/status") +async def api_refresh_status(): + """获取刷新管理器状态""" + return await admin.get_refresh_manager_status() + + +@app.get("/api/accounts") +async def api_accounts(): + return await admin.get_accounts() + + +@app.post("/api/accounts") +async def api_add_account(request: Request): + return await admin.add_account(request) + + +@app.delete("/api/accounts/{account_id}") +async def api_delete_account(account_id: str): + return await admin.delete_account(account_id) + + +@app.put("/api/accounts/{account_id}") +async def api_update_account(account_id: str, request: Request): + return await admin.update_account(account_id, request) + + +@app.post("/api/accounts/{account_id}/toggle") +async def api_toggle_account(account_id: str): + return await admin.toggle_account(account_id) + + +@app.post("/api/speedtest") +async def api_speedtest(): + return await admin.speedtest() + + +@app.get("/api/accounts/{account_id}/test") +async def api_test_account_token(account_id: str): + """测试指定账号的 Token 是否有效""" + return await admin.test_account_token(account_id) + + +@app.get("/api/token/scan") +async def api_scan_tokens(): + return await admin.scan_tokens() + + +@app.post("/api/token/add-from-scan") +async def api_add_from_scan(request: Request): + return await admin.add_from_scan(request) + + +@app.get("/api/config/export") +async def api_export_config(): + return await admin.export_config() + + +@app.post("/api/config/import") +async def api_import_config(request: Request): + return await admin.import_config(request) + + +@app.post("/api/token/refresh-check") +async def api_refresh_check(): + return await admin.refresh_token_check() + + +@app.post("/api/accounts/{account_id}/refresh") +async def api_refresh_account(account_id: str): + """刷新指定账号的 token(集成 RefreshManager)""" + return await admin.refresh_account_token_with_manager(account_id) + + +@app.post("/api/accounts/{account_id}/restore") +async def api_restore_account(account_id: str): + """恢复账号(从冷却状态)""" + return await admin.restore_account(account_id) + + +@app.get("/api/accounts/{account_id}/usage") +async def api_account_usage(account_id: str): + """获取账号用量信息""" + return await admin.get_account_usage_info(account_id) + + +@app.get("/api/accounts/{account_id}") +async def api_account_detail(account_id: str): + """获取账号详细信息""" + return await admin.get_account_detail(account_id) + + +@app.post("/api/accounts/{account_id}/refresh-quota") +async def api_refresh_account_quota(account_id: str): + """刷新单个账号额度(先刷新 Token)""" + return await admin.refresh_account_quota_with_token(account_id) + + +# ==================== 优先账号 API ==================== + +@app.get("/api/priority") +async def api_get_priority_accounts(): + """获取优先账号列表""" + return await admin.get_priority_accounts() + + +@app.post("/api/priority/{account_id}") +async def api_set_priority_account(account_id: str, request: Request): + """设置优先账号""" + return await admin.set_priority_account(account_id, request) + + +@app.delete("/api/priority/{account_id}") +async def api_remove_priority_account(account_id: str): + """取消优先账号""" + return await admin.remove_priority_account(account_id) + + +@app.put("/api/priority/reorder") +async def api_reorder_priority_accounts(request: Request): + """调整优先账号顺序""" + return await admin.reorder_priority_accounts(request) + + +@app.get("/api/quota") +async def api_quota_status(): + """获取配额状态""" + return await admin.get_quota_status() + + +@app.get("/api/kiro/login-url") +async def api_login_url(): + return await admin.get_kiro_login_url() + + +@app.get("/api/stats/detailed") +async def api_detailed_stats(): + """获取详细统计信息""" + return await admin.get_detailed_stats() + + +@app.post("/api/health-check") +async def api_health_check(): + """手动触发健康检查""" + return await admin.run_health_check() + + +@app.get("/api/browsers") +async def api_browsers(): + """获取可用浏览器列表""" + return await admin.get_browsers() + + +# ==================== Kiro 登录 API ==================== + +@app.post("/api/kiro/login/start") +async def api_kiro_login_start(request: Request): + """启动 Kiro 设备授权登录""" + return await admin.start_kiro_login(request) + + +@app.get("/api/kiro/login/poll") +async def api_kiro_login_poll(): + """轮询登录状态""" + return await admin.poll_kiro_login() + + +@app.post("/api/kiro/login/cancel") +async def api_kiro_login_cancel(): + """取消登录""" + return await admin.cancel_kiro_login() + + +@app.get("/api/kiro/login/status") +async def api_kiro_login_status(): + """获取登录状态""" + return await admin.get_kiro_login_status() + + +# ==================== Social Auth API (Google/GitHub) ==================== + +@app.post("/api/kiro/social/start") +async def api_social_login_start(request: Request): + """启动 Social Auth 登录""" + return await admin.start_social_login(request) + + +@app.post("/api/kiro/social/exchange") +async def api_social_token_exchange(request: Request): + """交换 Social Auth Token""" + return await admin.exchange_social_token(request) + + +@app.post("/api/kiro/social/cancel") +async def api_social_login_cancel(): + """取消 Social Auth 登录""" + return await admin.cancel_social_login() + + +@app.get("/api/kiro/social/status") +async def api_social_login_status(): + """获取 Social Auth 状态""" + return await admin.get_social_login_status() + + +# ==================== 协议注册 API ==================== + +@app.post("/api/protocol/register") +async def api_register_protocol(): + """注册 kiro:// 协议""" + return await admin.register_kiro_protocol() + + +@app.post("/api/protocol/unregister") +async def api_unregister_protocol(): + """取消注册 kiro:// 协议""" + return await admin.unregister_kiro_protocol() + + +@app.get("/api/protocol/status") +async def api_protocol_status(): + """获取协议注册状态""" + return await admin.get_protocol_status() + + +@app.get("/api/protocol/callback") +async def api_protocol_callback(): + """获取回调结果""" + return await admin.get_callback_result() + + +# ==================== Flow Monitor API ==================== + +@app.get("/api/flows") +async def api_flows( + protocol: str = None, + model: str = None, + account_id: str = None, + state: str = None, + has_error: bool = None, + bookmarked: bool = None, + search: str = None, + limit: int = 50, + offset: int = 0, +): + """查询 Flows""" + return await admin.get_flows( + protocol=protocol, + model=model, + account_id=account_id, + state_filter=state, + has_error=has_error, + bookmarked=bookmarked, + search=search, + limit=limit, + offset=offset, + ) + + +@app.get("/api/flows/stats") +async def api_flow_stats(): + """获取 Flow 统计""" + return await admin.get_flow_stats() + + +@app.get("/api/flows/{flow_id}") +async def api_flow_detail(flow_id: str): + """获取 Flow 详情""" + return await admin.get_flow_detail(flow_id) + + +@app.post("/api/flows/{flow_id}/bookmark") +async def api_bookmark_flow(flow_id: str, request: Request): + """书签 Flow""" + return await admin.bookmark_flow(flow_id, request) + + +@app.post("/api/flows/{flow_id}/note") +async def api_add_flow_note(flow_id: str, request: Request): + """添加 Flow 备注""" + return await admin.add_flow_note(flow_id, request) + + +@app.post("/api/flows/{flow_id}/tag") +async def api_add_flow_tag(flow_id: str, request: Request): + """添加 Flow 标签""" + return await admin.add_flow_tag(flow_id, request) + + +@app.post("/api/flows/export") +async def api_export_flows(request: Request): + """导出 Flows""" + return await admin.export_flows(request) + + +# ==================== 历史消息管理 API ==================== + +from .core import get_history_config, update_history_config, TruncateStrategy +from .core.rate_limiter import get_rate_limiter + +@app.get("/api/settings/history") +async def api_get_history_config(): + """获取历史消息管理配置""" + config = get_history_config() + return config.to_dict() + + +@app.post("/api/settings/history") +async def api_update_history_config(request: Request): + """更新历史消息管理配置""" + data = await request.json() + update_history_config(data) + return {"ok": True, "config": get_history_config().to_dict()} + + +# ==================== 限速配置 API ==================== + +@app.get("/api/settings/rate-limit") +async def api_get_rate_limit_config(): + """获取限速配置""" + limiter = get_rate_limiter() + return { + "enabled": limiter.config.enabled, + "min_request_interval": limiter.config.min_request_interval, + "max_requests_per_minute": limiter.config.max_requests_per_minute, + "global_max_requests_per_minute": limiter.config.global_max_requests_per_minute, + "stats": limiter.get_stats() + } + + +@app.post("/api/settings/rate-limit") +async def api_update_rate_limit_config(request: Request): + """更新限速配置""" + data = await request.json() + limiter = get_rate_limiter() + limiter.update_config(**data) + return {"ok": True, "config": { + "enabled": limiter.config.enabled, + "min_request_interval": limiter.config.min_request_interval, + "max_requests_per_minute": limiter.config.max_requests_per_minute, + "global_max_requests_per_minute": limiter.config.global_max_requests_per_minute, + }} + + +# ==================== 文档 API ==================== + +# 文档标题映射 +DOC_TITLES = { + "01-quickstart": "快速开始", + "02-features": "功能特性", + "03-faq": "常见问题", + "04-api": "API 参考", + "05-server-deploy": "服务器部署", +} + +@app.get("/api/docs") +async def api_docs_list(): + """获取文档列表""" + docs_dir = get_resource_path("kiro_proxy/docs") + docs = [] + if docs_dir.exists(): + for doc_file in sorted(docs_dir.glob("*.md")): + doc_id = doc_file.stem + title = DOC_TITLES.get(doc_id, doc_id) + docs.append({"id": doc_id, "title": title}) + return {"docs": docs} + + +@app.get("/api/docs/{doc_id}") +async def api_docs_content(doc_id: str): + """获取文档内容""" + docs_dir = get_resource_path("kiro_proxy/docs") + doc_file = docs_dir / f"{doc_id}.md" + if not doc_file.exists(): + raise HTTPException(status_code=404, detail="文档不存在") + content = doc_file.read_text(encoding="utf-8") + title = DOC_TITLES.get(doc_id, doc_id) + return {"id": doc_id, "title": title, "content": content} + + +# ==================== 启动 ==================== + +def run(port: int = 8080): + import uvicorn + print(f"\n{'='*50}") + print(f" Kiro API Proxy v{__version__}") + print(f" http://localhost:{port}") + print(f"{'='*50}\n") + uvicorn.run(app, host="0.0.0.0", port=port) + + +if __name__ == "__main__": + import sys + port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080 + run(port) diff --git a/KiroProxy/kiro_proxy/models.py b/KiroProxy/kiro_proxy/models.py new file mode 100644 index 0000000000000000000000000000000000000000..dd970e402f0c1872342c0e9c6d9e96eabe769e96 --- /dev/null +++ b/KiroProxy/kiro_proxy/models.py @@ -0,0 +1,15 @@ +"""数据模型 - 兼容层 + +此文件保留用于向后兼容,实际实现已移至 core/ 和 credential/ 模块。 +""" +from .core import state, ProxyState, Account +from .core.state import RequestLog +from .credential import CredentialStatus + +__all__ = [ + "state", + "ProxyState", + "Account", + "RequestLog", + "CredentialStatus", +] diff --git a/KiroProxy/kiro_proxy/providers/__init__.py b/KiroProxy/kiro_proxy/providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a970310ee968604b6e8da6117adfb22823754788 --- /dev/null +++ b/KiroProxy/kiro_proxy/providers/__init__.py @@ -0,0 +1,5 @@ +"""Provider 模块""" +from .base import BaseProvider +from .kiro import KiroProvider + +__all__ = ["BaseProvider", "KiroProvider"] diff --git a/KiroProxy/kiro_proxy/providers/base.py b/KiroProxy/kiro_proxy/providers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4722f0b940fe93c3de61024fdc9e238e6a8783e2 --- /dev/null +++ b/KiroProxy/kiro_proxy/providers/base.py @@ -0,0 +1,46 @@ +"""Provider 基类""" +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, Tuple + + +class BaseProvider(ABC): + """Provider 基类 + + 所有 Provider(Kiro、Gemini、Qwen 等)都应继承此类。 + """ + + @property + @abstractmethod + def name(self) -> str: + """Provider 名称""" + pass + + @property + @abstractmethod + def api_url(self) -> str: + """API 端点 URL""" + pass + + @abstractmethod + def build_headers(self, token: str, **kwargs) -> Dict[str, str]: + """构建请求头""" + pass + + @abstractmethod + def build_request(self, messages: list, model: str, **kwargs) -> Dict[str, Any]: + """构建请求体""" + pass + + @abstractmethod + def parse_response(self, raw: bytes) -> Dict[str, Any]: + """解析响应""" + pass + + @abstractmethod + async def refresh_token(self) -> Tuple[bool, str]: + """刷新 token,返回 (success, new_token_or_error)""" + pass + + def is_quota_exceeded(self, status_code: int, error_text: str) -> bool: + """检查是否为配额超限错误""" + return status_code in {429, 503, 529} diff --git a/KiroProxy/kiro_proxy/providers/kiro.py b/KiroProxy/kiro_proxy/providers/kiro.py new file mode 100644 index 0000000000000000000000000000000000000000..12ad0582ea87c69bba1dac93b5d64623c377a5bf --- /dev/null +++ b/KiroProxy/kiro_proxy/providers/kiro.py @@ -0,0 +1,227 @@ +"""Kiro Provider""" +import json +import uuid +from typing import Dict, Any, List, Optional, Tuple + +from .base import BaseProvider +from ..credential import ( + KiroCredentials, TokenRefresher, + generate_machine_id, get_kiro_version, get_system_info +) + + +class KiroProvider(BaseProvider): + """Kiro/CodeWhisperer Provider""" + + API_URL = "https://q.us-east-1.amazonaws.com/generateAssistantResponse" + MODELS_URL = "https://q.us-east-1.amazonaws.com/ListAvailableModels" + + def __init__(self, credentials: Optional[KiroCredentials] = None): + self.credentials = credentials + self._machine_id: Optional[str] = None + + @property + def name(self) -> str: + return "kiro" + + @property + def api_url(self) -> str: + return self.API_URL + + def get_machine_id(self) -> str: + """获取基于凭证的 Machine ID""" + if self._machine_id: + return self._machine_id + + if self.credentials: + self._machine_id = generate_machine_id( + self.credentials.profile_arn, + self.credentials.client_id + ) + else: + self._machine_id = generate_machine_id() + + return self._machine_id + + def build_headers( + self, + token: str, + agent_mode: str = "vibe", + **kwargs + ) -> Dict[str, str]: + """构建 Kiro API 请求头 (与 kiro.rs 保持一致)""" + machine_id = kwargs.get("machine_id") or self.get_machine_id() + kiro_version = get_kiro_version() + os_name, node_version = get_system_info() + + return { + "content-type": "application/json", + "x-amzn-codewhisperer-optout": "true", + "x-amzn-kiro-agent-mode": agent_mode, + "x-amz-user-agent": f"aws-sdk-js/1.0.27 KiroIDE-{kiro_version}-{machine_id}", + "user-agent": f"aws-sdk-js/1.0.27 ua/2.1 os/{os_name} lang/js md/nodejs#{node_version} api/codewhispererstreaming#1.0.27 m/E KiroIDE-{kiro_version}-{machine_id}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "amz-sdk-request": "attempt=1; max=3", + "Authorization": f"Bearer {token}", + "Connection": "close", + } + + def build_request( + self, + messages: list = None, + model: str = "claude-sonnet-4", + user_content: str = "", + history: List[dict] = None, + tools: List[dict] = None, + images: List[dict] = None, + tool_results: List[dict] = None, + **kwargs + ) -> Dict[str, Any]: + """构建 Kiro API 请求体""" + conversation_id = str(uuid.uuid4()) + + # 确保 content 不为空 + if not user_content: + user_content = "Continue" + + user_input_message = { + "content": user_content, + "modelId": model, + "origin": "AI_EDITOR", + } + + if images: + user_input_message["images"] = images + + # 只有在有 tools 或 tool_results 时才添加 userInputMessageContext + context = {} + if tools: + context["tools"] = tools + if tool_results: + context["toolResults"] = tool_results + + if context: + user_input_message["userInputMessageContext"] = context + + return { + "conversationState": { + "agentContinuationId": str(uuid.uuid4()), + "agentTaskType": "vibe", + "chatTriggerType": "MANUAL", + "conversationId": conversation_id, + "currentMessage": {"userInputMessage": user_input_message}, + "history": history or [] + } + } + + def parse_response(self, raw: bytes) -> Dict[str, Any]: + """解析 AWS event-stream 格式响应""" + result = { + "content": [], + "tool_uses": [], + "stop_reason": "end_turn" + } + + tool_input_buffer = {} + pos = 0 + + while pos < len(raw): + if pos + 12 > len(raw): + break + + total_len = int.from_bytes(raw[pos:pos+4], 'big') + headers_len = int.from_bytes(raw[pos+4:pos+8], 'big') + + if total_len == 0 or total_len > len(raw) - pos: + break + + header_start = pos + 12 + header_end = header_start + headers_len + headers_data = raw[header_start:header_end] + event_type = None + + try: + headers_str = headers_data.decode('utf-8', errors='ignore') + if 'toolUseEvent' in headers_str: + event_type = 'toolUseEvent' + elif 'assistantResponseEvent' in headers_str: + event_type = 'assistantResponseEvent' + except: + pass + + payload_start = pos + 12 + headers_len + payload_end = pos + total_len - 4 + + if payload_start < payload_end: + try: + payload = json.loads(raw[payload_start:payload_end].decode('utf-8')) + + if 'assistantResponseEvent' in payload: + e = payload['assistantResponseEvent'] + if 'content' in e: + result["content"].append(e['content']) + elif 'content' in payload and event_type != 'toolUseEvent': + result["content"].append(payload['content']) + + if event_type == 'toolUseEvent' or 'toolUseId' in payload: + tool_id = payload.get('toolUseId', '') + tool_name = payload.get('name', '') + tool_input = payload.get('input', '') + + if tool_id: + if tool_id not in tool_input_buffer: + tool_input_buffer[tool_id] = { + "id": tool_id, + "name": tool_name, + "input_parts": [] + } + if tool_name and not tool_input_buffer[tool_id]["name"]: + tool_input_buffer[tool_id]["name"] = tool_name + if tool_input: + tool_input_buffer[tool_id]["input_parts"].append(tool_input) + except: + pass + + pos += total_len + + # 组装工具调用 + for tool_id, tool_data in tool_input_buffer.items(): + input_str = "".join(tool_data["input_parts"]) + try: + input_json = json.loads(input_str) + except: + input_json = {"raw": input_str} + + result["tool_uses"].append({ + "type": "tool_use", + "id": tool_data["id"], + "name": tool_data["name"], + "input": input_json + }) + + if result["tool_uses"]: + result["stop_reason"] = "tool_use" + + return result + + def parse_response_text(self, raw: bytes) -> str: + """解析响应,只返回文本内容""" + result = self.parse_response(raw) + return "".join(result["content"]) or "[No response]" + + async def refresh_token(self) -> Tuple[bool, str]: + """刷新 token""" + if not self.credentials: + return False, "无凭证信息" + + refresher = TokenRefresher(self.credentials) + return await refresher.refresh() + + def is_quota_exceeded(self, status_code: int, error_text: str) -> bool: + """检查是否为配额超限错误""" + if status_code in {429, 503, 529}: + return True + + keywords = ["rate limit", "quota", "too many requests", "throttl"] + error_lower = error_text.lower() + return any(kw in error_lower for kw in keywords) diff --git a/KiroProxy/kiro_proxy/resources.py b/KiroProxy/kiro_proxy/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..8377ab3dd8b4ccdf85991f993bc38ede1e08bc83 --- /dev/null +++ b/KiroProxy/kiro_proxy/resources.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path + + +def get_resource_path(relative_path: str) -> Path: + base_path = Path(sys._MEIPASS) if hasattr(sys, "_MEIPASS") else Path(__file__).parent.parent + return base_path / relative_path diff --git a/KiroProxy/kiro_proxy/routers/__init__.py b/KiroProxy/kiro_proxy/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/KiroProxy/kiro_proxy/routers/admin.py b/KiroProxy/kiro_proxy/routers/admin.py new file mode 100644 index 0000000000000000000000000000000000000000..77ffaab32f56966814e82789c463ae6ac9da10d8 --- /dev/null +++ b/KiroProxy/kiro_proxy/routers/admin.py @@ -0,0 +1,410 @@ +from fastapi import APIRouter, Request, HTTPException + +from ..core import get_history_config, get_rate_limiter, update_history_config +from ..handlers import admin as admin_handler +from ..resources import get_resource_path + +router = APIRouter(prefix="/api") + + +@router.get("/status") +async def api_status(): + return await admin_handler.get_status() + + +@router.post("/event_logging/batch") +async def api_event_logging_batch(request: Request): + return await admin_handler.event_logging_batch(request) + + +@router.get("/stats") +async def api_stats(): + return await admin_handler.get_stats() + + +@router.get("/logs") +async def api_logs(limit: int = 100): + return await admin_handler.get_logs(limit) + + +@router.get("/accounts/export") +async def api_export_accounts(): + return await admin_handler.export_accounts() + + +@router.post("/accounts/import") +async def api_import_accounts(request: Request): + return await admin_handler.import_accounts(request) + + +@router.post("/accounts/manual") +async def api_add_manual_token(request: Request): + return await admin_handler.add_manual_token(request) + + +@router.post("/accounts/batch") +async def api_batch_import_accounts(request: Request): + return await admin_handler.batch_import_accounts(request) + + +@router.post("/accounts/refresh-all") +async def api_refresh_all(): + return await admin_handler.refresh_all_tokens() + + +@router.get("/accounts/status") +async def api_accounts_status_enhanced(): + return await admin_handler.get_accounts_status_enhanced() + + +@router.get("/accounts/summary") +async def api_accounts_summary(): + return await admin_handler.get_accounts_summary() + + +@router.post("/accounts/refresh-all-quotas") +async def api_refresh_all_quotas(): + return await admin_handler.refresh_all_quotas() + + +@router.get("/refresh/progress") +async def api_refresh_progress(): + return await admin_handler.get_refresh_progress() + + +@router.post("/refresh/all") +async def api_refresh_all_with_progress(): + return await admin_handler.refresh_all_with_progress() + + +@router.get("/refresh/config") +async def api_get_refresh_config(): + return await admin_handler.get_refresh_config() + + +@router.put("/refresh/config") +async def api_update_refresh_config(request: Request): + return await admin_handler.update_refresh_config(request) + + +@router.get("/refresh/status") +async def api_refresh_status(): + return await admin_handler.get_refresh_manager_status() + + +@router.get("/accounts") +async def api_accounts(): + return await admin_handler.get_accounts() + + +@router.post("/accounts") +async def api_add_account(request: Request): + return await admin_handler.add_account(request) + + +@router.delete("/accounts/{account_id}") +async def api_delete_account(account_id: str): + return await admin_handler.delete_account(account_id) + + +@router.put("/accounts/{account_id}") +async def api_update_account(account_id: str, request: Request): + return await admin_handler.update_account(account_id, request) + + +@router.post("/accounts/{account_id}/toggle") +async def api_toggle_account(account_id: str): + return await admin_handler.toggle_account(account_id) + + +@router.post("/speedtest") +async def api_speedtest(): + return await admin_handler.speedtest() + + +@router.get("/accounts/{account_id}/test") +async def api_test_account_token(account_id: str): + return await admin_handler.test_account_token(account_id) + + +@router.get("/token/scan") +async def api_scan_tokens(): + return await admin_handler.scan_tokens() + + +@router.post("/token/add-from-scan") +async def api_add_from_scan(request: Request): + return await admin_handler.add_from_scan(request) + + +@router.get("/config/export") +async def api_export_config(): + return await admin_handler.export_config() + + +@router.post("/config/import") +async def api_import_config(request: Request): + return await admin_handler.import_config(request) + + +@router.post("/token/refresh-check") +async def api_refresh_check(): + return await admin_handler.refresh_token_check() + + +@router.post("/accounts/{account_id}/refresh") +async def api_refresh_account(account_id: str): + return await admin_handler.refresh_account_token_with_manager(account_id) + + +@router.post("/accounts/{account_id}/restore") +async def api_restore_account(account_id: str): + return await admin_handler.restore_account(account_id) + + +@router.get("/accounts/{account_id}/usage") +async def api_account_usage(account_id: str): + return await admin_handler.get_account_usage_info(account_id) + + +@router.get("/accounts/{account_id}") +async def api_account_detail(account_id: str): + return await admin_handler.get_account_detail(account_id) + + +@router.post("/accounts/{account_id}/refresh-quota") +async def api_refresh_account_quota(account_id: str): + return await admin_handler.refresh_account_quota_with_token(account_id) + + +@router.get("/priority") +async def api_get_priority_accounts(): + return await admin_handler.get_priority_accounts() + + +@router.post("/priority/{account_id}") +async def api_set_priority_account(account_id: str, request: Request): + return await admin_handler.set_priority_account(account_id, request) + + +@router.delete("/priority/{account_id}") +async def api_remove_priority_account(account_id: str): + return await admin_handler.remove_priority_account(account_id) + + +@router.put("/priority/reorder") +async def api_reorder_priority_accounts(request: Request): + return await admin_handler.reorder_priority_accounts(request) + + +@router.get("/quota") +async def api_quota_status(): + return await admin_handler.get_quota_status() + + +@router.get("/kiro/login-url") +async def api_login_url(): + return await admin_handler.get_kiro_login_url() + + +@router.get("/stats/detailed") +async def api_detailed_stats(): + return await admin_handler.get_detailed_stats() + + +@router.post("/health-check") +async def api_health_check(): + return await admin_handler.run_health_check() + + +@router.get("/browsers") +async def api_browsers(): + return await admin_handler.get_browsers() + + +@router.post("/kiro/login/start") +async def api_kiro_login_start(request: Request): + return await admin_handler.start_kiro_login(request) + + +@router.get("/kiro/login/poll") +async def api_kiro_login_poll(): + return await admin_handler.poll_kiro_login() + + +@router.post("/kiro/login/cancel") +async def api_kiro_login_cancel(): + return await admin_handler.cancel_kiro_login() + + +@router.get("/kiro/login/status") +async def api_kiro_login_status(): + return await admin_handler.get_kiro_login_status() + + +@router.post("/kiro/social/start") +async def api_social_login_start(request: Request): + return await admin_handler.start_social_login(request) + + +@router.post("/kiro/social/exchange") +async def api_social_token_exchange(request: Request): + return await admin_handler.exchange_social_token(request) + + +@router.post("/kiro/social/cancel") +async def api_social_login_cancel(): + return await admin_handler.cancel_social_login() + + +@router.get("/kiro/social/status") +async def api_social_login_status(): + return await admin_handler.get_social_login_status() + + +@router.post("/protocol/register") +async def api_register_protocol(): + return await admin_handler.register_kiro_protocol() + + +@router.post("/protocol/unregister") +async def api_unregister_protocol(): + return await admin_handler.unregister_kiro_protocol() + + +@router.get("/protocol/status") +async def api_protocol_status(): + return await admin_handler.get_protocol_status() + + +@router.get("/protocol/callback") +async def api_protocol_callback(): + return await admin_handler.get_callback_result() + + +@router.get("/flows") +async def api_flows( + protocol: str = None, + model: str = None, + account_id: str = None, + state: str = None, + has_error: bool = None, + bookmarked: bool = None, + search: str = None, + limit: int = 50, + offset: int = 0, +): + return await admin_handler.get_flows( + protocol=protocol, + model=model, + account_id=account_id, + state_filter=state, + has_error=has_error, + bookmarked=bookmarked, + search=search, + limit=limit, + offset=offset, + ) + + +@router.get("/flows/stats") +async def api_flow_stats(): + return await admin_handler.get_flow_stats() + + +@router.get("/flows/{flow_id}") +async def api_flow_detail(flow_id: str): + return await admin_handler.get_flow_detail(flow_id) + + +@router.post("/flows/{flow_id}/bookmark") +async def api_bookmark_flow(flow_id: str, request: Request): + return await admin_handler.bookmark_flow(flow_id, request) + + +@router.post("/flows/{flow_id}/note") +async def api_add_flow_note(flow_id: str, request: Request): + return await admin_handler.add_flow_note(flow_id, request) + + +@router.post("/flows/{flow_id}/tag") +async def api_add_flow_tag(flow_id: str, request: Request): + return await admin_handler.add_flow_tag(flow_id, request) + + +@router.post("/flows/export") +async def api_export_flows(request: Request): + return await admin_handler.export_flows(request) + + +@router.get("/settings/history") +async def api_get_history_config(): + config = get_history_config() + return config.to_dict() + + +@router.post("/settings/history") +async def api_update_history_config(request: Request): + data = await request.json() + update_history_config(data) + return {"ok": True, "config": get_history_config().to_dict()} + + +@router.get("/settings/rate-limit") +async def api_get_rate_limit_config(): + limiter = get_rate_limiter() + return { + "enabled": limiter.config.enabled, + "min_request_interval": limiter.config.min_request_interval, + "max_requests_per_minute": limiter.config.max_requests_per_minute, + "global_max_requests_per_minute": limiter.config.global_max_requests_per_minute, + "stats": limiter.get_stats(), + } + + +@router.post("/settings/rate-limit") +async def api_update_rate_limit_config(request: Request): + data = await request.json() + limiter = get_rate_limiter() + limiter.update_config(**data) + return { + "ok": True, + "config": { + "enabled": limiter.config.enabled, + "min_request_interval": limiter.config.min_request_interval, + "max_requests_per_minute": limiter.config.max_requests_per_minute, + "global_max_requests_per_minute": limiter.config.global_max_requests_per_minute, + }, + } + + +DOC_TITLES = { + "01-quickstart": "快速开始", + "02-features": "功能特性", + "03-faq": "常见问题", + "04-api": "API 参考", + "05-server-deploy": "服务器部署", +} + + +@router.get("/docs") +async def api_docs_list(): + docs_dir = get_resource_path("kiro_proxy/docs") + docs = [] + if docs_dir.exists(): + for doc_file in sorted(docs_dir.glob("*.md")): + doc_id = doc_file.stem + title = DOC_TITLES.get(doc_id, doc_id) + docs.append({"id": doc_id, "title": title}) + return {"docs": docs} + + +@router.get("/docs/{doc_id}") +async def api_docs_content(doc_id: str): + docs_dir = get_resource_path("kiro_proxy/docs") + doc_file = docs_dir / f"{doc_id}.md" + if not doc_file.exists(): + raise HTTPException(status_code=404, detail="文档不存在") + content = doc_file.read_text(encoding="utf-8") + title = DOC_TITLES.get(doc_id, doc_id) + return {"id": doc_id, "title": title, "content": content} diff --git a/KiroProxy/kiro_proxy/routers/protocols.py b/KiroProxy/kiro_proxy/routers/protocols.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce3ad93f622dbcfd1b5cbd977567bc4d413646c --- /dev/null +++ b/KiroProxy/kiro_proxy/routers/protocols.py @@ -0,0 +1,111 @@ +import uuid + +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from ..config import MODELS_URL +from ..core import state +from ..credential import get_kiro_version +from ..handlers import anthropic, gemini, openai +from ..handlers import responses as responses_handler + +router = APIRouter() + + +@router.get("/v1/models") +async def models(): + try: + account = state.get_available_account() + if not account: + raise Exception("No available account") + + token = account.get_token() + machine_id = account.get_machine_id() + kiro_version = get_kiro_version() + + headers = { + "content-type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "Authorization": f"Bearer {token}", + } + async with httpx.AsyncClient(verify=False, timeout=30) as client: + resp = await client.get(MODELS_URL, headers=headers, params={"origin": "AI_EDITOR"}) + if resp.status_code == 200: + data = resp.json() + return { + "object": "list", + "data": [ + { + "id": m["modelId"], + "object": "model", + "owned_by": "kiro", + "name": m["modelName"], + } + for m in data.get("models", []) + ], + } + except Exception: + pass + + return { + "object": "list", + "data": [ + {"id": "auto", "object": "model", "owned_by": "kiro", "name": "Auto"}, + { + "id": "claude-sonnet-4.5", + "object": "model", + "owned_by": "kiro", + "name": "Claude Sonnet 4.5", + }, + {"id": "claude-sonnet-4", "object": "model", "owned_by": "kiro", "name": "Claude Sonnet 4"}, + { + "id": "claude-haiku-4.5", + "object": "model", + "owned_by": "kiro", + "name": "Claude Haiku 4.5", + }, + ], + } + + +@router.post("/v1/messages") +async def anthropic_messages(request: Request): + print(f"[Main] Received /v1/messages request from {request.client.host}") + return await anthropic.handle_messages(request) + + +@router.post("/v1/messages/count_tokens") +async def anthropic_count_tokens(request: Request): + return await anthropic.handle_count_tokens(request) + + +@router.post("/v1/complete") +async def anthropic_complete(request: Request): + print(f"[Main] Received /v1/complete request from {request.client.host}") + return JSONResponse( + status_code=400, + content={ + "error": { + "type": "invalid_request_error", + "message": "KiroProxy currently only supports /v1/messages. Please check if your client can be configured to use Messages API.", + } + }, + ) + + +@router.post("/v1/chat/completions") +async def openai_chat(request: Request): + return await openai.handle_chat_completions(request) + + +@router.post("/v1/responses") +async def openai_responses(request: Request): + return await responses_handler.handle_responses(request) + + +@router.post("/v1beta/models/{model_name}:generateContent") +@router.post("/v1/models/{model_name}:generateContent") +async def gemini_generate(model_name: str, request: Request): + return await gemini.handle_generate_content(model_name, request) diff --git a/KiroProxy/kiro_proxy/routers/web.py b/KiroProxy/kiro_proxy/routers/web.py new file mode 100644 index 0000000000000000000000000000000000000000..f83c362e9e4f99b57a74f88177ab88b8d88b2669 --- /dev/null +++ b/KiroProxy/kiro_proxy/routers/web.py @@ -0,0 +1,21 @@ +from fastapi import APIRouter, HTTPException +from fastapi.responses import HTMLResponse, StreamingResponse + +from ..resources import get_resource_path +from ..web.html import HTML_PAGE + +router = APIRouter() + + +@router.get("/", response_class=HTMLResponse) +async def index(): + return HTML_PAGE + + +@router.get("/assets/{path:path}") +async def serve_assets(path: str): + file_path = get_resource_path("assets") / path + if file_path.exists(): + content_type = "image/svg+xml" if path.endswith(".svg") else "application/octet-stream" + return StreamingResponse(open(file_path, "rb"), media_type=content_type) + raise HTTPException(status_code=404) diff --git a/KiroProxy/kiro_proxy/web/__init__.py b/KiroProxy/kiro_proxy/web/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e97628f941fb285046fbcb62dda30922d857d81e --- /dev/null +++ b/KiroProxy/kiro_proxy/web/__init__.py @@ -0,0 +1 @@ +# Web UI diff --git a/KiroProxy/kiro_proxy/web/html.py b/KiroProxy/kiro_proxy/web/html.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce0981e993e461d53d1a5fc6f1e686ca9076487 --- /dev/null +++ b/KiroProxy/kiro_proxy/web/html.py @@ -0,0 +1,3735 @@ +"""Web UI - 组件化单文件结构""" + +# ==================== CSS 样式 ==================== +CSS_BASE = ''' +* { margin: 0; padding: 0; box-sizing: border-box; } +:root { + --bg: #0a0a0a; + --card: #1a1a1a; + --border: #333; + --text: #fafafa; + --muted: #a3a3a3; + --accent: #3b82f6; + --success: #22c55e; + --error: #ef4444; + --warn: #f59e0b; + --info: #3b82f6; + --primary: #6366f1; + --secondary: #8b5cf6; +} +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + background: var(--bg); + color: var(--text); + line-height: 1.6; + min-height: 100vh; +} +.container { + max-width: 1200px; + margin: 0 auto; + padding: 1rem; + min-height: 100vh; + display: flex; + flex-direction: column; +} +''' + +CSS_LAYOUT = ''' +/* Header */ +header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 2rem; + padding: 1.5rem; + background: var(--card); + border-radius: 16px; + box-shadow: 0 4px 12px rgba(0,0,0,0.3); +} +h1 { + font-size: 1.75rem; + font-weight: 700; + display: flex; + align-items: center; + gap: 0.75rem; + background: linear-gradient(135deg, var(--primary), var(--secondary)); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + background-clip: text; +} +h1 img { + width: 32px; + height: 32px; + border-radius: 8px; +} +.status { + display: flex; + align-items: center; + gap: 1rem; + font-size: 0.875rem; + color: var(--muted); +} +.status-dot { + width: 10px; + height: 10px; + border-radius: 50%; + box-shadow: 0 0 8px currentColor; +} +.status-dot.ok { + background: var(--success); + color: var(--success); +} +.status-dot.err { + background: var(--error); + color: var(--error); +} + +/* Navigation Tabs */ +.tabs { + display: flex; + justify-content: center; + gap: 0.5rem; + margin-bottom: 2rem; + padding: 0.5rem; + background: var(--card); + border-radius: 16px; + box-shadow: 0 4px 12px rgba(0,0,0,0.3); +} +.tab { + padding: 0.75rem 1.5rem; + border: none; + background: transparent; + color: var(--muted); + cursor: pointer; + font-size: 0.875rem; + font-weight: 500; + transition: all 0.3s ease; + border-radius: 12px; + position: relative; +} +.tab:hover { + color: var(--text); + background: rgba(255,255,255,0.05); +} +.tab.active { + background: linear-gradient(135deg, var(--primary), var(--secondary)); + color: white; + box-shadow: 0 4px 12px rgba(59,130,246,0.3); +} + +/* Panels */ +.panel { + display: none; + flex: 1; +} +.panel.active { + display: block; +} + +/* Footer */ +.footer { + text-align: center; + color: var(--muted); + font-size: 0.75rem; + margin-top: 2rem; + padding: 1rem; + border-top: 1px solid var(--border); +} +''' + +CSS_COMPONENTS = ''' +/* Cards */ +.card { + background: var(--card); + border: 1px solid var(--border); + border-radius: 16px; + padding: 2rem; + margin-bottom: 1.5rem; + box-shadow: 0 4px 12px rgba(0,0,0,0.3); + transition: all 0.3s ease; +} +.card:hover { + box-shadow: 0 8px 24px rgba(0,0,0,0.4); + transform: translateY(-2px); +} +.card h3 { + font-size: 1.25rem; + font-weight: 600; + margin-bottom: 1.5rem; + display: flex; + justify-content: space-between; + align-items: center; + color: var(--text); +} + +/* Stats Grid - OXO Style */ +.stats-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); + gap: 1rem; + margin-bottom: 1.5rem; +} +.stat-item { + text-align: center; + padding: 1.5rem; + background: linear-gradient(135deg, rgba(59,130,246,0.1), rgba(139,92,246,0.1)); + border-radius: 16px; + border: 1px solid rgba(59,130,246,0.2); + transition: all 0.3s ease; +} +.stat-item:hover { + transform: translateY(-4px); + box-shadow: 0 8px 24px rgba(59,130,246,0.2); +} +.stat-value { + font-size: 2rem; + font-weight: 700; + background: linear-gradient(135deg, var(--primary), var(--secondary)); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + background-clip: text; + margin-bottom: 0.5rem; +} +.stat-label { + font-size: 0.875rem; + color: var(--muted); + font-weight: 500; +} + +/* Badges */ +.badge { + display: inline-flex; + align-items: center; + padding: 0.375rem 0.75rem; + border-radius: 12px; + font-size: 0.75rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.025em; +} +.badge.success { + background: linear-gradient(135deg, #22c55e, #16a34a); + color: white; + box-shadow: 0 2px 8px rgba(34,197,94,0.3); +} +.badge.error { + background: linear-gradient(135deg, #ef4444, #dc2626); + color: white; + box-shadow: 0 2px 8px rgba(239,68,68,0.3); +} +.badge.warn { + background: linear-gradient(135deg, #f59e0b, #d97706); + color: white; + box-shadow: 0 2px 8px rgba(245,158,11,0.3); +} +.badge.info { + background: linear-gradient(135deg, #3b82f6, #2563eb); + color: white; + box-shadow: 0 2px 8px rgba(59,130,246,0.3); +} + +/* Circular Progress */ +.progress-circle { + width: 80px; + height: 80px; + border-radius: 50%; + background: conic-gradient(var(--primary) 0deg, var(--secondary) 180deg, var(--border) 180deg); + display: flex; + align-items: center; + justify-content: center; + position: relative; +} +.progress-circle::before { + content: ''; + width: 60px; + height: 60px; + border-radius: 50%; + background: var(--card); + position: absolute; +} +.progress-text { + position: relative; + z-index: 1; + font-weight: 700; + font-size: 0.875rem; +} +''' + +CSS_FORMS = ''' +/* Buttons - OXO Style */ +button { + padding: 0.75rem 1.5rem; + background: linear-gradient(135deg, var(--primary), var(--secondary)); + color: white; + border: none; + border-radius: 12px; + cursor: pointer; + font-size: 0.875rem; + font-weight: 600; + transition: all 0.3s ease; + box-shadow: 0 4px 12px rgba(59,130,246,0.3); + text-transform: uppercase; + letter-spacing: 0.025em; +} +button:hover { + transform: translateY(-2px); + box-shadow: 0 6px 16px rgba(59,130,246,0.4); +} +button:active { + transform: translateY(0); +} +button:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; +} +button.secondary { + background: var(--card); + color: var(--text); + border: 1px solid var(--border); + box-shadow: 0 2px 8px rgba(0,0,0,0.1); +} +button.secondary:hover { + background: rgba(255,255,255,0.05); + border-color: var(--primary); +} +button.small { + padding: 0.5rem 1rem; + font-size: 0.75rem; + border-radius: 8px; +} +button.circle { + width: 48px; + height: 48px; + border-radius: 50%; + padding: 0; + display: flex; + align-items: center; + justify-content: center; +} +button.large { + padding: 1rem 2rem; + font-size: 1rem; + border-radius: 16px; +} + +/* Inputs */ +input[type="text"], +input[type="number"], +input[type="search"], +input[type="password"], +textarea { + padding: 0.75rem 1rem; + border: 1px solid var(--border); + border-radius: 12px; + background: var(--card); + color: var(--text); + font-size: 0.875rem; + transition: all 0.3s ease; + width: 100%; +} +input:hover, textarea:hover { + border-color: var(--primary); +} +input:focus, textarea:focus { + outline: none; + border-color: var(--primary); + box-shadow: 0 0 0 3px rgba(59,130,246,0.1); +} +input::placeholder, textarea::placeholder { + color: var(--muted); +} + +/* Select */ +select { + padding: 0.75rem 1rem; + border: 1px solid var(--border); + border-radius: 12px; + background: var(--card); + color: var(--text); + font-size: 0.875rem; + cursor: pointer; + appearance: none; + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23a3a3a3' d='M6 8L1 3h10z'/%3E%3C/svg%3E"); + background-repeat: no-repeat; + background-position: right 1rem center; + padding-right: 3rem; + transition: all 0.3s ease; +} +select:hover { + border-color: var(--primary); +} +select:focus { + outline: none; + border-color: var(--primary); + box-shadow: 0 0 0 3px rgba(59,130,246,0.1); +} + +/* Tables */ +table { + width: 100%; + border-collapse: collapse; + font-size: 0.875rem; + background: var(--card); + border-radius: 12px; + overflow: hidden; + box-shadow: 0 4px 12px rgba(0,0,0,0.1); +} +th, td { + padding: 1rem; + text-align: left; + border-bottom: 1px solid var(--border); +} +th { + font-weight: 600; + color: var(--muted); + background: rgba(59,130,246,0.05); +} +tr:hover { + background: rgba(255,255,255,0.02); +} + +/* Code blocks */ +pre { + background: var(--bg); + border: 1px solid var(--border); + border-radius: 12px; + padding: 1.5rem; + overflow-x: auto; + font-size: 0.8rem; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; +} +code { + background: rgba(59,130,246,0.1); + padding: 0.25rem 0.5rem; + border-radius: 6px; + font-size: 0.875em; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; +} +''' + +CSS_ACCOUNTS = ''' +.account-card { border: 1px solid var(--border); border-radius: 8px; padding: 1rem; margin-bottom: 0.75rem; background: var(--card); } +.account-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 0.75rem; } +.account-name { font-weight: 500; display: flex; align-items: center; gap: 0.5rem; } +.account-meta { display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 0.5rem; font-size: 0.8rem; color: var(--muted); } +.account-meta-item { display: flex; justify-content: space-between; padding: 0.25rem 0; } +.account-actions { display: flex; gap: 0.5rem; flex-wrap: wrap; margin-top: 0.75rem; padding-top: 0.75rem; border-top: 1px solid var(--border); } +''' + +CSS_API = ''' +.endpoint { display: flex; align-items: center; gap: 0.5rem; margin-bottom: 0.5rem; } +.method { padding: 0.25rem 0.5rem; border-radius: 4px; font-size: 0.75rem; font-weight: 600; } +.method.get { background: #dcfce7; color: #166534; } +.method.post { background: #fef3c7; color: #92400e; } +@media (prefers-color-scheme: dark) { + .method.get { background: #14532d; color: #86efac; } + .method.post { background: #78350f; color: #fde68a; } +} +.copy-btn { padding: 0.25rem 0.5rem; font-size: 0.75rem; background: var(--card); border: 1px solid var(--border); color: var(--text); } +''' + +CSS_DOCS = ''' +.docs-container { display: flex; gap: 1.5rem; min-height: 500px; } +.docs-nav { width: 200px; flex-shrink: 0; } +.docs-nav-item { display: block; padding: 0.5rem 0.75rem; margin-bottom: 0.25rem; border-radius: 6px; cursor: pointer; font-size: 0.875rem; color: var(--text); text-decoration: none; transition: background 0.2s; } +.docs-nav-item:hover { background: var(--bg); } +.docs-nav-item.active { background: var(--accent); color: var(--bg); } +.docs-content { flex: 1; min-width: 0; } +.docs-content h1 { font-size: 1.5rem; margin-bottom: 1rem; padding-bottom: 0.5rem; border-bottom: 1px solid var(--border); } +.docs-content h2 { font-size: 1.25rem; margin: 1.5rem 0 0.75rem; color: var(--text); } +.docs-content h3 { font-size: 1rem; margin: 1rem 0 0.5rem; color: var(--text); } +.docs-content h4 { font-size: 0.9rem; margin: 0.75rem 0 0.5rem; color: var(--muted); } +.docs-content p { margin: 0.5rem 0; } +.docs-content ul, .docs-content ol { margin: 0.5rem 0; padding-left: 1.5rem; } +.docs-content li { margin: 0.25rem 0; } +.docs-content code { background: var(--bg); padding: 0.2em 0.4em; border-radius: 3px; font-size: 0.9em; } +.docs-content pre { margin: 0.75rem 0; } +.docs-content pre code { background: none; padding: 0; } +.docs-content table { margin: 0.75rem 0; } +.docs-content blockquote { margin: 0.75rem 0; padding: 0.5rem 1rem; border-left: 3px solid var(--border); color: var(--muted); background: var(--bg); border-radius: 0 6px 6px 0; } +.docs-content hr { margin: 1.5rem 0; border: none; border-top: 1px solid var(--border); } +.docs-content a { color: var(--info); text-decoration: none; } +.docs-content a:hover { text-decoration: underline; } +@media (max-width: 768px) { + .docs-container { flex-direction: column; } + .docs-nav { width: 100%; display: flex; flex-wrap: wrap; gap: 0.5rem; } + .docs-nav-item { margin-bottom: 0; } +} +''' + +# ==================== UI 组件库样式 ==================== +CSS_UI_COMPONENTS = ''' +/* Modal 模态框 */ +.modal-overlay { position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0,0,0,0.5); display: flex; align-items: center; justify-content: center; z-index: 1000; opacity: 0; visibility: hidden; transition: all 0.2s; } +.modal-overlay.active { opacity: 1; visibility: visible; } +.modal { background: var(--card); border-radius: 12px; max-width: 500px; width: 90%; max-height: 90vh; overflow: hidden; transform: scale(0.9); transition: transform 0.2s; } +.modal-overlay.active .modal { transform: scale(1); } +.modal-header { padding: 1rem 1.5rem; border-bottom: 1px solid var(--border); display: flex; justify-content: space-between; align-items: center; } +.modal-header h3 { font-size: 1.1rem; font-weight: 600; } +.modal-close { background: none; border: none; font-size: 1.5rem; cursor: pointer; color: var(--muted); padding: 0; line-height: 1; } +.modal-body { padding: 1.5rem; overflow-y: auto; max-height: 60vh; } +.modal-footer { padding: 1rem 1.5rem; border-top: 1px solid var(--border); display: flex; justify-content: flex-end; gap: 0.5rem; } +.modal.danger .modal-header { background: #fee2e2; } +.modal.warning .modal-header { background: #fef3c7; } +@media (prefers-color-scheme: dark) { + .modal.danger .modal-header { background: #7f1d1d; } + .modal.warning .modal-header { background: #78350f; } +} + +/* Toast 通知 */ +.toast-container { position: fixed; top: 1rem; right: 1rem; z-index: 1100; display: flex; flex-direction: column; gap: 0.5rem; } +.toast { padding: 0.75rem 1rem; border-radius: 8px; background: var(--card); border: 1px solid var(--border); box-shadow: 0 4px 12px rgba(0,0,0,0.15); display: flex; align-items: center; gap: 0.5rem; animation: slideIn 0.3s ease; min-width: 250px; } +.toast.success { border-left: 4px solid var(--success); } +.toast.error { border-left: 4px solid var(--error); } +.toast.warning { border-left: 4px solid var(--warn); } +.toast.info { border-left: 4px solid var(--info); } +.toast-close { margin-left: auto; background: none; border: none; cursor: pointer; color: var(--muted); font-size: 1.2rem; padding: 0; } +@keyframes slideIn { from { transform: translateX(100%); opacity: 0; } to { transform: translateX(0); opacity: 1; } } + +/* Select 下拉选择 */ +.custom-select { position: relative; } +.custom-select-trigger { padding: 0.75rem 1rem; border: 1px solid var(--border); border-radius: 6px; background: var(--card); cursor: pointer; display: flex; justify-content: space-between; align-items: center; } +.custom-select-trigger::after { content: "▼"; font-size: 0.7rem; color: var(--muted); } +.custom-select-options { position: absolute; top: 100%; left: 0; right: 0; background: var(--card); border: 1px solid var(--border); border-radius: 6px; margin-top: 4px; max-height: 200px; overflow-y: auto; z-index: 100; display: none; } +.custom-select.open .custom-select-options { display: block; } +.custom-select-option { padding: 0.5rem 1rem; cursor: pointer; } +.custom-select-option:hover { background: var(--bg); } +.custom-select-option.selected { background: var(--accent); color: var(--bg); } + +/* ProgressBar 进度条 */ +.progress-bar { height: 8px; background: var(--bg); border-radius: 4px; overflow: hidden; } +.progress-bar.large { height: 12px; } +.progress-bar.small { height: 4px; } +.progress-fill { height: 100%; background: var(--info); transition: width 0.3s; } +.progress-fill.success { background: var(--success); } +.progress-fill.warning { background: var(--warn); } +.progress-fill.error { background: var(--error); } +.progress-label { display: flex; justify-content: space-between; font-size: 0.75rem; color: var(--muted); margin-top: 0.25rem; } + +/* Dropdown 下拉菜单 */ +.dropdown { position: relative; display: inline-block; } +.dropdown-menu { position: absolute; top: 100%; right: 0; background: var(--card); border: 1px solid var(--border); border-radius: 8px; min-width: 120px; box-shadow: 0 4px 12px rgba(0,0,0,0.15); z-index: 100; display: none; margin-top: 4px; overflow: hidden; } +.dropdown.open .dropdown-menu { display: block; } +.dropdown-item { padding: 0.5rem 0.75rem; cursor: pointer; display: flex; align-items: center; gap: 0.5rem; font-size: 0.875rem; white-space: nowrap; } +.dropdown-item:hover { background: var(--bg); } +.dropdown-item.danger { color: var(--error); } +.dropdown-divider { height: 1px; background: var(--border); margin: 0.25rem 0; } + +/* 账号卡片增强 */ +.account-card-enhanced { border: 1px solid var(--border); border-radius: 12px; padding: 1.25rem; margin-bottom: 1rem; background: var(--card); } +.account-card-enhanced.priority { border-color: var(--info); border-width: 2px; } +.account-card-enhanced.active { box-shadow: 0 0 0 2px var(--success); } +.account-card-header { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 1rem; } +.account-card-title { display: flex; align-items: center; gap: 0.5rem; flex-wrap: wrap; } +.account-card-badges { display: flex; gap: 0.25rem; flex-wrap: wrap; } +.account-quota-section { margin: 1rem 0; } +.quota-header { display: flex; justify-content: space-between; margin-bottom: 0.5rem; font-size: 0.875rem; } +.quota-detail { display: flex; gap: 1rem; font-size: 0.75rem; color: var(--muted); margin-top: 0.5rem; flex-wrap: wrap; } +.quota-reset-info { display: flex; gap: 1rem; flex-wrap: wrap; } +.quota-reset-info span { display: inline-flex; align-items: center; gap: 0.25rem; } +.account-stats-grid { display: grid; grid-template-columns: repeat(4, 1fr); gap: 0.5rem; margin: 1rem 0; } +.account-stat { text-align: center; padding: 0.5rem; background: var(--bg); border-radius: 6px; } +.account-stat-value { font-weight: 600; font-size: 0.9rem; } +.account-stat-label { font-size: 0.7rem; color: var(--muted); } + +/* 账号网格布局 - 动态自适应 */ +.accounts-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(260px, 1fr)); gap: 0.75rem; margin-top: 1rem; } +.account-card-compact { background: var(--card); border: 1px solid var(--border); border-radius: 10px; padding: 0.875rem; transition: all 0.2s; } +.account-card-compact:hover { border-color: var(--accent); } +.account-card-compact.priority { border-color: var(--info); border-width: 2px; } +.account-card-compact.low-balance { border-color: var(--warn); } +.account-card-compact.exhausted { border-color: var(--error); border-width: 2px; } +.account-card-compact.suspended { border-color: var(--error); border-width: 2px; background: rgba(239, 68, 68, 0.1); } +.account-card-compact.unavailable { opacity: 0.6; } +.account-card-top { display: flex; justify-content: space-between; align-items: flex-start; margin-bottom: 0.75rem; } +.account-card-info { flex: 1; min-width: 0; } +.account-card-name { font-weight: 600; font-size: 0.95rem; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; margin-bottom: 0.25rem; } +.account-card-email { font-size: 0.75rem; color: var(--muted); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } +.account-card-status { display: flex; gap: 0.25rem; flex-wrap: wrap; } +.account-card-quota { margin: 0.75rem 0; } +.account-card-quota-bar { height: 6px; background: var(--bg); border-radius: 3px; overflow: hidden; } +.account-card-quota-fill { height: 100%; transition: width 0.3s; } +.account-card-quota-text { display: flex; justify-content: space-between; font-size: 0.7rem; color: var(--muted); margin-top: 0.25rem; } +.account-card-stats { display: flex; gap: 1rem; font-size: 0.75rem; color: var(--muted); margin-bottom: 0.75rem; } +.account-card-actions { display: flex; gap: 0.5rem; flex-wrap: wrap; padding-top: 0.75rem; border-top: 1px solid var(--border); } +.account-card-actions button { flex: 1; min-width: 60px; } + +/* 紧凑汇总面板 */ +.summary-compact { display: flex; gap: 1rem; flex-wrap: wrap; align-items: center; padding: 0.75rem; background: var(--bg); border-radius: 8px; } +.summary-compact-item { display: flex; align-items: center; gap: 0.5rem; } +.summary-compact-value { font-weight: 600; font-size: 1.1rem; } +.summary-compact-label { font-size: 0.75rem; color: var(--muted); } +.summary-compact-divider { width: 1px; height: 24px; background: var(--border); } +.summary-quota-bar { flex: 1; min-width: 200px; } + +/* 全局进度条 - 批量刷新操作进度显示 */ +.global-progress-bar { position: fixed; top: 0; left: 0; right: 0; z-index: 1200; background: var(--card); border-bottom: 1px solid var(--border); box-shadow: 0 2px 8px rgba(0,0,0,0.1); transform: translateY(-100%); transition: transform 0.3s ease; } +.global-progress-bar.active { transform: translateY(0); } +.global-progress-bar-inner { max-width: 1400px; margin: 0 auto; padding: 0.75rem 1rem; } +.global-progress-bar-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 0.5rem; } +.global-progress-bar-title { font-weight: 600; font-size: 0.9rem; display: flex; align-items: center; gap: 0.5rem; } +.global-progress-bar-title .spinner { display: inline-block; width: 14px; height: 14px; border: 2px solid var(--border); border-top-color: var(--accent); border-radius: 50%; animation: spin 1s linear infinite; } +.global-progress-bar-stats { display: flex; gap: 1rem; font-size: 0.8rem; color: var(--muted); } +.global-progress-bar-stats span { display: flex; align-items: center; gap: 0.25rem; } +.global-progress-bar-stats .success { color: var(--success); } +.global-progress-bar-stats .error { color: var(--error); } +.global-progress-bar-track { height: 6px; background: var(--bg); border-radius: 3px; overflow: hidden; margin-bottom: 0.5rem; } +.global-progress-bar-fill { height: 100%; background: var(--info); transition: width 0.3s ease; border-radius: 3px; } +.global-progress-bar-fill.complete { background: var(--success); } +.global-progress-bar-current { font-size: 0.75rem; color: var(--muted); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } +.global-progress-bar-close { background: none; border: none; font-size: 1.2rem; cursor: pointer; color: var(--muted); padding: 0; margin-left: 0.5rem; } +.global-progress-bar-close:hover { color: var(--text); } + +/* 汇总面板 */ +.summary-panel { background: linear-gradient(135deg, var(--card) 0%, var(--bg) 100%); border: 1px solid var(--border); border-radius: 12px; padding: 1.5rem; margin-bottom: 1.5rem; } +.summary-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(100px, 1fr)); gap: 1rem; margin-bottom: 1rem; } +.summary-item { text-align: center; } +.summary-value { font-size: 1.75rem; font-weight: 700; } +.summary-label { font-size: 0.75rem; color: var(--muted); } +.summary-item.success .summary-value { color: var(--success); } +.summary-item.warning .summary-value { color: var(--warn); } +.summary-item.error .summary-value { color: var(--error); } +.summary-quota { margin: 1rem 0; } +.summary-info { display: flex; gap: 2rem; flex-wrap: wrap; font-size: 0.875rem; color: var(--muted); } +.summary-actions { margin-top: 1rem; display: flex; gap: 0.5rem; } +''' + +CSS_STYLES = CSS_BASE + CSS_LAYOUT + CSS_COMPONENTS + CSS_FORMS + CSS_ACCOUNTS + CSS_API + CSS_DOCS + CSS_UI_COMPONENTS + + +# ==================== HTML 模板 ==================== +# 全局进度条容器 - 显示在页面顶部 +HTML_GLOBAL_PROGRESS = ''' + +
+
+
+
+ + 正在刷新额度... +
+
+ 完成: 0/0 + 成功: 0 + 失败: 0 + +
+
+
+
+
+
准备中...
+
+
+''' + +HTML_HEADER = ''' +
+

KiroKiro API Proxy

+
+ + 检查中... + +
+
+ +
+
📚 帮助
+
📊 监控
+
👥 账号
+
🔌 API
+
⚙️ 设置
+
+''' + +HTML_HELP = ''' +
+
+
+ +
+

加载中...

+
+
+
+
+''' + +HTML_FLOWS = ''' +
+
+

Flow 统计

+
+
+
+

流量监控

+
+ + + + + +
+
+
+ +
+''' + +HTML_MONITOR = ''' +
+ +
+

🚀 服务状态

+
+
+ + +
+

📈 流量统计

+
+
+ + +
+

⚡ 配额状态

+
+
+ + +
+

🎯 速度测试

+
+ + 点击开始测试 +
+
+ + +
+

📋 请求监控

+
+ + + + + +
+
+
+ + +
+

📝 请求日志

+
+ + + + + + + + + + + + +
时间路径模型账号状态耗时
+
+
+ + + +
+''' + + +HTML_ACCOUNTS = ''' +
+ +
+
+

账号管理

+
+ + + + +
+
+ +
+
+ + + + + + + + + + + + + +
+
+''' + +HTML_LOGS = ''' +
+
+

请求日志

+ + + +
时间路径模型账号状态耗时
+
+
+''' + +HTML_API = ''' +
+
+

API 端点

+

支持 OpenAI、Anthropic、Gemini 三种协议

+

OpenAI 协议

+
POST/v1/chat/completions
+
GET/v1/models
+

Anthropic 协议

+
POST/v1/messages
+
POST/v1/messages/count_tokens
+

Gemini 协议

+
POST/v1/models/{model}:generateContent
+

Base URL

+
+ +
+
+

配置示例

+

Claude Code

+
Base URL: 
+API Key: any
+模型: claude-sonnet-4
+

Codex CLI

+
Endpoint: /v1
+API Key: any
+模型: gpt-4o
+
+
+

Claude Code 终端配置

+

Claude Code 终端版需要配置 ~/.claude/settings.json 才能跳过登录使用代理

+ +

临时生效(当前终端)

+
export ANTHROPIC_BASE_URL=""
+export ANTHROPIC_AUTH_TOKEN="sk-any"
+export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1
+ + +

永久生效(推荐,写入配置文件)

+
# 写入 Claude Code 配置文件
+mkdir -p ~/.claude
+cat > ~/.claude/settings.json << 'EOF'
+{
+  "env": {
+    "ANTHROPIC_BASE_URL": "",
+    "ANTHROPIC_AUTH_TOKEN": "sk-any",
+    "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1"
+  }
+}
+EOF
+ + +

清除配置

+
# 删除 Claude Code 配置
+rm -f ~/.claude/settings.json
+unset ANTHROPIC_BASE_URL ANTHROPIC_AUTH_TOKEN CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC
+ + +

+ 💡 使用 ANTHROPIC_AUTH_TOKEN + CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 可跳过登录 +

+
+
+

模型映射

+

支持多种模型名称,自动映射到 Kiro 模型

+ + + + + + + + +
Kiro 模型能力可用名称
claude-sonnet-4⭐⭐⭐ 推荐gpt-4o, gpt-4, claude-3-5-sonnet-*, sonnet
claude-sonnet-4.5⭐⭐⭐⭐ 更强gemini-1.5-pro, o1, o1-preview, claude-3-opus-*, opus
claude-haiku-4.5⚡ 快速gpt-4o-mini, gpt-3.5-turbo, haiku
auto🤖 自动auto
+

+ 💡 直接使用 Kiro 模型名(如 claude-sonnet-4)或任意映射名称均可 +

+
+
+''' + +HTML_SETTINGS = ''' +
+ +
+

🤖 自动化管理

+

+ 以下功能已启用自动化管理,无需手动配置: +

+
+
+
+ 🔄 + Token 与额度刷新 + 自动 +
+

+ Token 过期前自动刷新,额度信息定期更新 +

+
+
+
+ + 请求限速与 429 冷却 + 自动 +
+

+ 遇到 429 错误自动冷却 5 分钟,自动切换到其他可用账号 +

+
+
+
+ 📝 + 历史消息压缩 + 自动 +
+

+ 上下文超限时自动压缩,智能生成摘要保留关键信息 +

+
+
+
+ 🎲 + 账号负载均衡 + 自动 +
+

+ 支持随机、轮询、最少请求等多种账号选择策略,分散请求压力 +

+
+
+
+ + + + + + + + + +
+''' + +HTML_BODY = HTML_GLOBAL_PROGRESS + HTML_HEADER + HTML_HELP + HTML_MONITOR + HTML_ACCOUNTS + HTML_API + HTML_SETTINGS + + +# ==================== JavaScript ==================== +JS_UTILS = ''' +const $=s=>document.querySelector(s); +const $$=s=>document.querySelectorAll(s); + +function copy(text){ + navigator.clipboard.writeText(text).then(()=>{ + const toast=document.createElement('div'); + toast.textContent='已复制'; + toast.style.cssText='position:fixed;bottom:2rem;left:50%;transform:translateX(-50%);background:var(--accent);color:var(--bg);padding:0.5rem 1rem;border-radius:6px;font-size:0.875rem;z-index:1000'; + document.body.appendChild(toast); + setTimeout(()=>toast.remove(),1500); + }); +} + +function copyEnvTemp(){ + const url=location.origin; + copy(`export ANTHROPIC_BASE_URL="${url}" +export ANTHROPIC_AUTH_TOKEN="sk-any" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1`); +} + +function copyEnvPerm(){ + const url=location.origin; + copy(`# 写入 Claude Code 配置文件(推荐) +mkdir -p ~/.claude +cat > ~/.claude/settings.json << 'EOF' +{ + "env": { + "ANTHROPIC_BASE_URL": "${url}", + "ANTHROPIC_AUTH_TOKEN": "sk-any", + "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" + } +} +EOF +echo "配置完成,请重新打开终端运行 claude"`); +} + +function copyEnvClear(){ + copy(`# 删除 Claude Code 配置 +rm -f ~/.claude/settings.json +unset ANTHROPIC_BASE_URL ANTHROPIC_AUTH_TOKEN CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC +echo "配置已清除"`); +} + +function formatUptime(s){ + if(s<60)return s+'秒'; + if(s<3600)return Math.floor(s/60)+'分钟'; + return Math.floor(s/3600)+'小时'+Math.floor((s%3600)/60)+'分钟'; +} + +function escapeHtml(text){ + const div=document.createElement('div'); + div.textContent=text; + return div.innerHTML; +} +''' + +JS_TABS = ''' +// Tabs +$$('.tab').forEach(t=>t.onclick=()=>{ + $$('.tab').forEach(x=>x.classList.remove('active')); + $$('.panel').forEach(x=>x.classList.remove('active')); + t.classList.add('active'); + $('#'+t.dataset.tab).classList.add('active'); + + // 监控面板加载所有数据 + if(t.dataset.tab==='monitor'){ + loadStats(); + loadQuota(); + loadFlowStats(); + loadFlows(); + loadLogs(); + } + if(t.dataset.tab==='accounts'){ + loadAccounts(); + loadAccountsEnhanced(); + } +}); +''' + +JS_STATUS = ''' +// Status +async function checkStatus(){ + try{ + const r=await fetch('/api/status'); + const d=await r.json(); + $('#statusDot').className='status-dot '+(d.ok?'ok':'err'); + $('#statusText').textContent=d.ok?'已连接':'未连接'; + if(d.stats)$('#uptime').textContent='运行 '+formatUptime(d.stats.uptime_seconds); + }catch(e){ + $('#statusDot').className='status-dot err'; + $('#statusText').textContent='连接失败'; + } +} +checkStatus(); +setInterval(checkStatus,30000); + +// URLs +$('#baseUrl').textContent=location.origin; +$$('.pyUrl').forEach(e=>e.textContent=location.origin); +''' + +JS_DOCS = ''' +// 文档浏览 +let docsData = []; +let currentDoc = null; + +// 简单的 Markdown 渲染 +function renderMarkdown(text) { + return text + .replace(/```(\\w*)\\n([\\s\\S]*?)```/g, '
$2
') + .replace(/`([^`]+)`/g, '$1') + .replace(/^#### (.+)$/gm, '

$1

') + .replace(/^### (.+)$/gm, '

$1

') + .replace(/^## (.+)$/gm, '

$1

') + .replace(/^# (.+)$/gm, '

$1

') + .replace(/\\*\\*(.+?)\\*\\*/g, '$1') + .replace(/\\*(.+?)\\*/g, '$1') + .replace(/\\[([^\\]]+)\\]\\(([^)]+)\\)/g, '$1') + .replace(/^- (.+)$/gm, '
  • $1
  • ') + .replace(/(
  • .*<\\/li>\\n?)+/g, '
      $&
    ') + .replace(/^\\d+\\. (.+)$/gm, '
  • $1
  • ') + .replace(/^> (.+)$/gm, '
    $1
    ') + .replace(/^---$/gm, '
    ') + .replace(/\\|(.+)\\|/g, function(match) { + const cells = match.split('|').filter(c => c.trim()); + if (cells.every(c => /^[\\s-:]+$/.test(c))) return ''; + const tag = match.includes('---') ? 'th' : 'td'; + return '' + cells.map(c => '<' + tag + '>' + c.trim() + '').join('') + ''; + }) + .replace(/(.*<\\/tr>\\n?)+/g, '$&
    ') + .replace(/\\n\\n/g, '

    ') + .replace(/\\n/g, '
    '); +} + +async function loadDocs() { + try { + const r = await fetch('/api/docs'); + const d = await r.json(); + docsData = d.docs || []; + + // 渲染导航 + $('#docsNav').innerHTML = docsData.map((doc, i) => + '' + doc.title + '' + ).join(''); + + // 显示第一个文档 + if (docsData.length > 0) { + showDoc(docsData[0].id); + } + } catch (e) { + $('#docsContent').innerHTML = '

    加载文档失败

    '; + } +} + +async function showDoc(id) { + // 更新导航状态 + $$('.docs-nav-item').forEach(item => { + item.classList.toggle('active', item.dataset.id === id); + }); + + // 获取文档内容 + try { + const r = await fetch('/api/docs/' + id); + const d = await r.json(); + currentDoc = d; + $('#docsContent').innerHTML = renderMarkdown(d.content); + } catch (e) { + $('#docsContent').innerHTML = '

    加载文档失败

    '; + } +} + +// 页面加载时加载文档 +loadDocs(); +''' + +JS_STATS = ''' +// Stats +async function loadStats(){ + try{ + const r=await fetch('/api/stats'); + const d=await r.json(); + $('#statsGrid').innerHTML=` +
    ${d.total_requests}
    总请求
    +
    ${d.total_errors}
    错误数
    +
    ${d.error_rate}
    错误率
    +
    ${d.accounts_available}/${d.accounts_total}
    可用账号
    +
    ${d.accounts_cooldown||0}
    冷却中
    + `; + }catch(e){console.error(e)} +} + +// Quota +async function loadQuota(){ + try{ + const r=await fetch('/api/quota'); + const d=await r.json(); + if(d.exceeded_credentials&&d.exceeded_credentials.length>0){ + $('#quotaStatus').innerHTML=d.exceeded_credentials.map(c=>` +
    + 冷却中 ${c.credential_id} + 剩余 ${c.remaining_seconds}秒 + +
    + `).join(''); + }else{ + $('#quotaStatus').innerHTML='

    无冷却中的账号

    '; + } + }catch(e){console.error(e)} +} + +// Speedtest +async function runSpeedtest(){ + $('#speedtestBtn').disabled=true; + $('#speedtestResult').textContent='测试中...'; + try{ + const r=await fetch('/api/speedtest',{method:'POST'}); + const d=await r.json(); + $('#speedtestResult').textContent=d.ok?`延迟: ${d.latency_ms.toFixed(0)}ms (${d.account_id})`:'测试失败: '+d.error; + }catch(e){$('#speedtestResult').textContent='测试失败'} + $('#speedtestBtn').disabled=false; +} +''' + +JS_LOGS = ''' +// Logs +async function loadLogs(){ + try{ + const r=await fetch('/api/logs?limit=50'); + const d=await r.json(); + $('#logTable').innerHTML=(d.logs||[]).map(l=>` + + ${new Date(l.timestamp*1000).toLocaleTimeString()} + ${l.path} + ${l.model||'-'} + ${l.account_id||'-'} + ${l.status} + ${l.duration_ms.toFixed(0)}ms + + `).join(''); + }catch(e){console.error(e)} +} +''' + + +JS_ACCOUNTS = ''' +// Accounts +async function loadAccounts(){ + try{ + const r=await fetch('/api/accounts'); + const d=await r.json(); + if(!d.accounts||d.accounts.length===0){ + $('#accountList').innerHTML='

    暂无账号,请点击"扫描 Token"

    '; + return; + } + $('#accountList').innerHTML=d.accounts.map(a=>{ + const statusBadge=a.status==='active'?'success':a.status==='cooldown'?'warn':a.status==='suspended'?'error':'error'; + const statusText={active:'可用',cooldown:'冷却中',unhealthy:'不健康',disabled:'已禁用',suspended:'已封禁'}[a.status]||a.status; + const authBadge=a.auth_method==='idc'?'info':'success'; + const authText=a.auth_method==='idc'?'IdC':'Social'; + return ` + + `; + }).join(''); + }catch(e){console.error(e)} +} + +async function queryUsage(id){ + const usageDiv=$('#usage-'+id); + usageDiv.style.display='block'; + usageDiv.innerHTML='查询中...'; + try{ + const r=await fetch('/api/accounts/'+id+'/usage'); + const d=await r.json(); + if(d.ok){ + const u=d.usage; + const pct=u.usage_limit>0?((u.current_usage/u.usage_limit)*100).toFixed(1):0; + const barColor=u.is_low_balance?'var(--error)':'var(--success)'; + usageDiv.innerHTML=` +
    + ${u.subscription_title} + ${u.is_low_balance?'余额不足':'正常'} +
    +
    +
    +
    +
    +
    已用/总额: ${u.current_usage.toFixed(2)} / ${u.usage_limit.toFixed(2)}
    +
    使用率: ${pct}%
    + ${u.reset_date_text ? `
    重置时间: ${u.reset_date_text}
    ` : ''} + ${u.trial_expiry_text ? `
    试用过期: ${u.trial_expiry_text}
    ` : ''} +
    + `; + }else{ + usageDiv.innerHTML=`查询失败: ${d.error}`; + } + }catch(e){ + usageDiv.innerHTML=`查询失败: ${e.message}`; + } +} + +async function refreshToken(id){ + try{ + Toast.info('正在刷新 Token...'); + const r=await fetch('/api/accounts/'+id+'/refresh',{method:'POST'}); + const d=await r.json(); + if(d.ok) { + Toast.success('Token 刷新成功'); + } else { + Toast.error('刷新失败: '+(d.message||d.error)); + } + loadAccounts(); + loadAccountsEnhanced(); + }catch(e){Toast.error('刷新失败: '+e.message)} +} + +async function refreshAllTokens(){ + try{ + Toast.info('正在刷新所有 Token...'); + const r=await fetch('/api/accounts/refresh-all',{method:'POST'}); + const d=await r.json(); + Toast.success(`刷新完成: ${d.refreshed} 个账号`); + loadAccounts(); + loadAccountsEnhanced(); + }catch(e){Toast.error('刷新失败: '+e.message)} +} + +async function restoreAccount(id){ + try{ + Toast.info('正在恢复账号...'); + const r = await fetch('/api/accounts/'+id+'/restore',{method:'POST'}); + const d = await r.json(); + if(d.ok) { + Toast.success('账号已恢复'); + } else { + Toast.error(d.error || '恢复失败'); + } + loadAccounts(); + loadAccountsEnhanced(); + loadQuota(); + }catch(e){Toast.error('恢复失败: '+e.message)} +} + +async function viewAccountDetail(id){ + try{ + const r=await fetch('/api/accounts/'+id); + const d=await r.json(); + Modal.info('账号详情', ` +
    +

    账号名: ${d.name}

    +

    ID: ${d.id}

    +

    状态: ${d.status}

    +

    请求数: ${d.request_count}

    +

    错误数: ${d.error_count}

    +
    + `); + }catch(e){Toast.error('获取详情失败: '+e.message)} +} + +async function toggleAccount(id){ + try { + const r = await fetch('/api/accounts/'+id+'/toggle',{method:'POST'}); + const d = await r.json(); + if(d.ok) { + Toast.success(d.enabled ? '账号已启用' : '账号已禁用'); + } else { + Toast.error(d.error || '操作失败'); + } + } catch(e) { + Toast.error('操作失败: ' + e.message); + } + loadAccounts(); + loadAccountsEnhanced(); +} + +async function deleteAccount(id){ + if(confirm('确定删除此账号?')){ + try { + const r = await fetch('/api/accounts/'+id,{method:'DELETE'}); + const d = await r.json(); + if(d.ok) { + Toast.success('账号已删除'); + } else { + Toast.error(d.error || '删除失败'); + } + } catch(e) { + Toast.error('删除失败: ' + e.message); + } + loadAccounts(); + loadAccountsEnhanced(); + } +} + +function showAddAccount(){ + const path=prompt('输入 Token 文件路径:'); + if(path){ + const name=prompt('账号名称:','账号'); + fetch('/api/accounts',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({name,token_path:path}) + }).then(r=>r.json()).then(d=>{ + if(d.ok){ + Toast.success('账号添加成功'); + loadAccounts(); + loadAccountsEnhanced(); + } + else alert(d.detail||'添加失败'); + }); + } +} + +async function scanTokens(){ + try{ + const r=await fetch('/api/token/scan'); + const d=await r.json(); + const panel=$('#scanResults'); + const list=$('#scanList'); + if(d.tokens&&d.tokens.length>0){ + panel.style.display='block'; + list.innerHTML=d.tokens.map(t=>{ + const path=encodeURIComponent(t.path||''); + const name=encodeURIComponent(t.name||''); + return ` +
    +
    +
    ${t.name}
    +
    ${t.path}
    +
    + ${t.already_added?'已添加':``} +
    + `; + }).join(''); + }else{ + alert('未找到 Token 文件'); + } + }catch(e){alert('扫描失败: '+e.message)} +} + +async function addFromScan(path,name){ + try{ + const r=await fetch('/api/token/add-from-scan',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({path,name}) + }); + const d=await r.json(); + if(d.ok){ + loadAccounts(); + scanTokens(); + }else{ + alert(d.detail||'添加失败'); + } + }catch(e){alert('添加失败: '+e.message)} +} + +async function checkTokens(){ + try{ + const r=await fetch('/api/token/refresh-check',{method:'POST'}); + const d=await r.json(); + let msg='Token 状态:\\n\\n'; + (d.accounts||[]).forEach(a=>{ + const status=a.valid?'✅ 有效':'❌ 无效'; + msg+=`${a.name}: ${status}\\n`; + }); + alert(msg); + }catch(e){alert('检查失败: '+e.message)} +} + +// 手动添加 Token +function showManualAdd(){ + $('#manualAddPanel').style.display='block'; + $('#manualName').value=''; + $('#manualAccessToken').value=''; + $('#manualRefreshToken').value=''; +} + +async function submitManualToken(){ + const name=$('#manualName').value.trim(); + const accessToken=$('#manualAccessToken').value.trim(); + const refreshToken=$('#manualRefreshToken').value.trim(); + const authMethod=$('#manualAuthMethod').value; + const provider=$('#manualProvider')?.value || ''; + const clientId=$('#manualClientId')?.value?.trim() || ''; + const clientSecret=$('#manualClientSecret')?.value?.trim() || ''; + const region=$('#manualRegion')?.value?.trim() || 'us-east-1'; + + // Refresh Token 必填 + if (!refreshToken) { + Toast.error('Refresh Token 是必填项'); + return; + } + + // 验证 Refresh Token 格式 + if (refreshToken.length < 100) { + Toast.error('Refresh Token 格式不正确(太短)'); + return; + } + + // IDC 认证需要 clientId 和 clientSecret + if (authMethod === 'idc' && (!clientId || !clientSecret)) { + Toast.error('IDC 认证需要填写 Client ID 和 Client Secret'); + return; + } + + Toast.info('正在添加账号...'); + + try{ + const r=await fetchWithRetry('/api/accounts/manual',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({ + name: name || '', + access_token: accessToken, + refresh_token: refreshToken, + auth_method: authMethod, + provider: provider, + client_id: clientId, + client_secret: clientSecret, + region: region + }) + }); + const d=await r.json(); + if(d.ok){ + let msg = '添加成功'; + if (d.auto_name) { + msg += '(已自动获取邮箱作为名称)'; + } + Toast.success(msg); + $('#manualAddPanel').style.display='none'; + // 清空表单 + $('#manualName').value = ''; + $('#manualAccessToken').value = ''; + $('#manualRefreshToken').value = ''; + if ($('#manualClientId')) $('#manualClientId').value = ''; + if ($('#manualClientSecret')) $('#manualClientSecret').value = ''; + loadAccounts(); + loadAccountsEnhanced(); + }else{ + Toast.error(d.detail||'添加失败'); + } + }catch(e){Toast.error('添加失败: '+e.message)} +} + +// 切换手动添加表单字段显示 +function toggleManualFields() { + const authMethod = $('#manualAuthMethod').value; + const idcFields = $('#manualIdcFields'); + const providerField = $('#manualProviderField'); + + if (authMethod === 'idc') { + if (idcFields) idcFields.style.display = 'block'; + if (providerField) providerField.style.display = 'none'; + } else { + if (idcFields) idcFields.style.display = 'none'; + if (providerField) providerField.style.display = 'block'; + } +} + +// 导出账号 +async function exportAccounts(){ + try{ + const r=await fetch('/api/accounts/export'); + const d=await r.json(); + if(!d.ok){alert('导出失败');return;} + const blob=new Blob([JSON.stringify(d,null,2)],{type:'application/json'}); + const url=URL.createObjectURL(blob); + const a=document.createElement('a'); + a.href=url; + a.download='kiro-accounts-'+new Date().toISOString().slice(0,10)+'.json'; + a.click(); + }catch(e){alert('导出失败: '+e.message)} +} + +// 导入账号 +function importAccounts(){ + const input=document.createElement('input'); + input.type='file'; + input.accept='.json'; + input.onchange=async(e)=>{ + const file=e.target.files[0]; + if(!file)return; + try{ + const text=await file.text(); + const data=JSON.parse(text); + const r=await fetch('/api/accounts/import',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify(data) + }); + const d=await r.json(); + if(d.ok){ + alert(`导入成功: ${d.imported} 个账号`+(d.errors?.length?`\\n错误: ${d.errors.join(', ')}`:'')); + loadAccounts(); + }else{ + alert('导入失败'); + } + }catch(e){alert('导入失败: '+e.message)} + }; + input.click(); +} +''' + +JS_LOGIN = ''' +// Kiro 在线登录 +let loginPollTimer=null; + +function showLoginOptions(){ + $('#loginOptions').style.display='block'; +} + +async function startSocialLogin(provider){ + $('#loginOptions').style.display='none'; + try{ + const r=await fetch('/api/kiro/social/start',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({provider}) + }); + const d=await r.json(); + if(!d.ok){alert('启动登录失败: '+d.error);return;} + showSocialLoginPanel(d.provider, d.login_url); + }catch(e){alert('启动登录失败: '+e.message)} +} + +// 协议注册状态 +let protocolRegistered = false; +let callbackPollTimer = null; + +function showSocialLoginPanel(provider, loginUrl){ + $('#loginPanel').style.display='block'; + $('#loginContent').innerHTML=` +
    +

    ${provider} 登录

    +
    +
    +

    步骤 1:打开登录链接

    +
    + + +
    + +

    步骤 2:完成授权后粘贴回调 URL

    +

    + 授权完成后,浏览器会尝试打开 kiro:// 链接。
    + 如果提示"无法打开",请复制地址栏中的完整 URL 粘贴到下方。 +

    +
    + + + +
    +

    可选:自动回调模式

    + +
    +
    + +

    +
    + `; +} + +async function registerProtocolAndWait(provider) { + $('#loginStatus').textContent = '正在注册协议处理器...'; + $('#loginStatus').style.color = 'var(--muted)'; + + try { + const regResp = await fetch('/api/protocol/register', { method: 'POST' }); + const regData = await regResp.json(); + + if (!regData.ok) { + $('#loginStatus').textContent = '协议注册失败: ' + regData.error; + $('#loginStatus').style.color = 'var(--error)'; + return; + } + + protocolRegistered = true; + $('#loginStatus').textContent = '✅ 协议已注册,授权完成后将自动接收回调'; + $('#loginStatus').style.color = 'var(--success)'; + + // 开始轮询回调结果 + startCallbackPolling(provider); + + } catch(e) { + $('#loginStatus').textContent = '操作失败: ' + e.message; + $('#loginStatus').style.color = 'var(--error)'; + } +} + +function startCallbackPolling(provider) { + if (callbackPollTimer) clearInterval(callbackPollTimer); + + let pollCount = 0; + const maxPolls = 300; // 5分钟超时 (300 * 1秒) + + callbackPollTimer = setInterval(async () => { + pollCount++; + + if (pollCount > maxPolls) { + clearInterval(callbackPollTimer); + callbackPollTimer = null; + $('#loginStatus').textContent = '等待超时,请重试'; + $('#loginStatus').style.color = 'var(--error)'; + return; + } + + try { + const resp = await fetch('/api/protocol/callback'); + const data = await resp.json(); + + if (data.ok && data.result) { + clearInterval(callbackPollTimer); + callbackPollTimer = null; + + if (data.result.error) { + $('#loginStatus').textContent = '授权失败: ' + data.result.error; + $('#loginStatus').style.color = 'var(--error)'; + } else if (data.result.code && data.result.state) { + // 自动交换 Token + $('#loginStatus').textContent = '正在交换 Token...'; + await exchangeTokenWithCode(data.result.code, data.result.state); + } + } + } catch(e) { + console.error('轮询回调失败:', e); + } + }, 1000); +} + +async function exchangeTokenWithCode(code, state) { + try { + const r = await fetch('/api/kiro/social/exchange', { + method: 'POST', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({ code, state }) + }); + const d = await r.json(); + + if (d.ok && d.completed) { + $('#loginStatus').textContent = '✅ ' + d.message; + $('#loginStatus').style.color = 'var(--success)'; + setTimeout(() => { + $('#loginPanel').style.display = 'none'; + loadAccounts(); + loadAccountsEnhanced(); + }, 1500); + } else { + $('#loginStatus').textContent = '❌ ' + (d.error || '登录失败'); + $('#loginStatus').style.color = 'var(--error)'; + } + } catch(e) { + $('#loginStatus').textContent = '交换 Token 失败: ' + e.message; + $('#loginStatus').style.color = 'var(--error)'; + } +} + +function cancelSocialLogin(){ + if (callbackPollTimer) { + clearInterval(callbackPollTimer); + callbackPollTimer = null; + } + fetch('/api/kiro/social/cancel',{method:'POST'}); + $('#loginPanel').style.display='none'; +} + +async function handleSocialCallback(){ + const url=$('#callbackUrl').value.trim(); + if(!url){alert('请粘贴回调 URL');return;} + try{ + // 支持 kiro:// 协议的 URL 解析 + let code, state; + if(url.startsWith('kiro://')){ + // kiro://kiro.kiroAgent/authenticate-success?code=xxx&state=xxx + const queryStart = url.indexOf('?'); + if(queryStart > -1){ + const params = new URLSearchParams(url.substring(queryStart + 1)); + code = params.get('code'); + state = params.get('state'); + } + } else { + // 标准 http/https URL + const urlObj=new URL(url); + code=urlObj.searchParams.get('code'); + state=urlObj.searchParams.get('state'); + } + if(!code||!state){alert('无效的回调 URL,缺少 code 或 state 参数');return;} + $('#loginStatus').textContent='正在交换 Token...'; + const r=await fetch('/api/kiro/social/exchange',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({code,state}) + }); + const d=await r.json(); + if(d.ok&&d.completed){ + $('#loginStatus').textContent='✅ '+d.message; + $('#loginStatus').style.color='var(--success)'; + setTimeout(()=>{$('#loginPanel').style.display='none';loadAccounts();},1500); + }else{ + $('#loginStatus').textContent='❌ '+(d.error||'登录失败'); + $('#loginStatus').style.color='var(--error)'; + } + }catch(e){alert('处理回调失败: '+e.message)} +} + +async function startAwsLogin(){ + $('#loginOptions').style.display='none'; + try{ + const r=await fetch('/api/kiro/login/start',{ + method:'POST', + headers:{'Content-Type':'application/json'}, + body:JSON.stringify({}) + }); + const d=await r.json(); + if(!d.ok){alert('启动登录失败: '+d.error);return;} + showAwsLoginPanel(d); + startLoginPoll(); + }catch(e){alert('启动登录失败: '+e.message)} +} + +function showAwsLoginPanel(data){ + $('#loginPanel').style.display='block'; + $('#loginContent').innerHTML=` +
    +

    AWS Builder ID 登录

    +
    ${data.user_code}
    +

    复制上方授权码,然后打开以下链接完成授权:

    +
    + + +
    +

    授权码有效期: ${Math.floor(data.expires_in/60)} 分钟

    + +

    等待授权...

    +
    + `; +} + +function startLoginPoll(){ + if(loginPollTimer)clearInterval(loginPollTimer); + loginPollTimer=setInterval(pollLogin,3000); +} + +async function pollLogin(){ + try{ + const r=await fetch('/api/kiro/login/poll'); + const d=await r.json(); + if(!d.ok){$('#loginStatus').textContent='错误: '+d.error;stopLoginPoll();return;} + if(d.completed){ + $('#loginStatus').textContent='✅ 登录成功!'; + $('#loginStatus').style.color='var(--success)'; + stopLoginPoll(); + setTimeout(()=>{$('#loginPanel').style.display='none';loadAccounts();},1500); + } + }catch(e){$('#loginStatus').textContent='轮询失败: '+e.message} +} + +function stopLoginPoll(){ + if(loginPollTimer){clearInterval(loginPollTimer);loginPollTimer=null;} +} + +async function cancelKiroLogin(){ + stopLoginPoll(); + await fetch('/api/kiro/login/cancel',{method:'POST'}); + $('#loginPanel').style.display='none'; +} +''' + + +JS_FLOWS = ''' +// Flow Monitor +async function loadFlowStats(){ + try{ + const r=await fetch('/api/flows/stats'); + const d=await r.json(); + $('#flowStatsGrid').innerHTML=` +
    ${d.total_flows}
    总请求
    +
    ${d.completed}
    完成
    +
    ${d.errors}
    错误
    +
    ${d.error_rate}
    错误率
    +
    ${d.avg_duration_ms.toFixed(0)}ms
    平均延迟
    +
    ${d.total_tokens_in}
    输入Token
    +
    ${d.total_tokens_out}
    输出Token
    + `; + }catch(e){console.error(e)} +} + +async function loadFlows(){ + try{ + const protocol=$('#flowProtocol').value; + const state=$('#flowState').value; + const search=$('#flowSearch').value; + let url='/api/flows?limit=50'; + if(protocol)url+=`&protocol=${protocol}`; + if(state)url+=`&state=${state}`; + if(search)url+=`&search=${encodeURIComponent(search)}`; + const r=await fetch(url); + const d=await r.json(); + if(!d.flows||d.flows.length===0){ + $('#flowList').innerHTML='

    暂无请求记录

    '; + return; + } + $('#flowList').innerHTML=d.flows.map(f=>{ + const stateBadge={completed:'success',error:'error',streaming:'info',pending:'warn'}[f.state]||'info'; + const stateText={completed:'完成',error:'错误',streaming:'流式中',pending:'等待中'}[f.state]||f.state; + const time=new Date(f.timing.created_at*1000).toLocaleTimeString(); + const duration=f.timing.duration_ms?f.timing.duration_ms.toFixed(0)+'ms':'-'; + const model=f.request?.model||'-'; + const tokens=f.response?.usage?(f.response.usage.input_tokens+'/'+f.response.usage.output_tokens):'-'; + return ` +
    +
    +
    + ${stateText} + ${model} + ${f.bookmarked?'':''} +
    +
    + ${time} · ${duration} · ${tokens} tokens · ${f.protocol} +
    +
    + +
    + `; + }).join(''); + }catch(e){console.error(e)} +} + +async function viewFlow(id){ + try{ + const r=await fetch('/api/flows/'+id); + const f=await r.json(); + let html=`
    ID: ${f.id}
    协议: ${f.protocol}
    状态: ${f.state}
    时间: ${new Date(f.timing.created_at*1000).toLocaleString()}
    延迟: ${f.timing.duration_ms?f.timing.duration_ms.toFixed(0)+'ms':'N/A'}
    `; + if(f.request){ + html+=`

    请求

    模型: ${f.request.model}
    流式: ${f.request.stream?'是':'否'}
    `; + } + if(f.response){ + html+=`

    响应

    状态码: ${f.response.status_code}
    Token: ${f.response.usage?.input_tokens||0} in / ${f.response.usage?.output_tokens||0} out
    `; + } + if(f.error){ + html+=`

    错误

    类型: ${f.error.type}
    消息: ${f.error.message}
    `; + } + $('#flowDetailContent').innerHTML=html; + $('#flowDetail').style.display='block'; + }catch(e){alert('获取详情失败: '+e.message)} +} + +async function toggleBookmark(id,bookmarked){ + await fetch('/api/flows/'+id+'/bookmark',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({bookmarked})}); + loadFlows(); +} + +async function exportFlows(){ + try{ + const r=await fetch('/api/flows/export',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({format:'json'})}); + const d=await r.json(); + const blob=new Blob([d.content],{type:'application/json'}); + const url=URL.createObjectURL(blob); + const a=document.createElement('a'); + a.href=url; + a.download='flows_'+new Date().toISOString().slice(0,10)+'.json'; + a.click(); + }catch(e){alert('导出失败: '+e.message)} +} +''' + +JS_SETTINGS = ''' +// 设置页面 +// 历史消息管理(简化版,自动管理) + +async function loadHistoryConfig(){ + try{ + const r=await fetch('/api/settings/history'); + const d=await r.json(); + $('#maxRetries').value=d.max_retries||3; + $('#summaryCacheMaxAge').value=d.summary_cache_max_age_seconds||300; + $('#addWarningHeader').checked=d.add_warning_header!==false; + }catch(e){console.error('加载配置失败:',e)} +} + +async function updateHistoryConfig(){ + const config={ + strategies:['error_retry'], // 固定使用错误重试策略 + max_retries:parseInt($('#maxRetries').value)||3, + summary_cache_enabled:true, + summary_cache_max_age_seconds:parseInt($('#summaryCacheMaxAge').value)||300, + add_warning_header:$('#addWarningHeader').checked + }; + try{ + await fetch('/api/settings/history',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(config)}); + }catch(e){console.error('保存配置失败:',e)} +} + +// 刷新配置 +async function loadRefreshConfig(){ + try{ + const r=await fetch('/api/refresh/config'); + const d=await r.json(); + if(d.ok && d.config){ + const c=d.config; + $('#refreshMaxRetries').value=c.max_retries||3; + $('#refreshConcurrency').value=c.concurrency||3; + $('#refreshAutoInterval').value=c.auto_refresh_interval||60; + $('#refreshRetryDelay').value=c.retry_base_delay||1.0; + $('#refreshBeforeExpiry').value=c.token_refresh_before_expiry||300; + // 更新状态显示 + $('#refreshConfigStatus').innerHTML=` +
    + 最大重试: ${c.max_retries||3} + 并发数: ${c.concurrency||3} + 自动刷新间隔: ${c.auto_refresh_interval||60} + 提前刷新: ${c.token_refresh_before_expiry||300} +
    + `; + } + }catch(e){console.error('加载刷新配置失败:',e)} +} + +async function saveRefreshConfig(){ + const config={ + max_retries:parseInt($('#refreshMaxRetries').value)||3, + concurrency:parseInt($('#refreshConcurrency').value)||3, + auto_refresh_interval:parseInt($('#refreshAutoInterval').value)||60, + retry_base_delay:parseFloat($('#refreshRetryDelay').value)||1.0, + token_refresh_before_expiry:parseInt($('#refreshBeforeExpiry').value)||300 + }; + try{ + const r=await fetch('/api/refresh/config',{method:'PUT',headers:{'Content-Type':'application/json'},body:JSON.stringify(config)}); + const d=await r.json(); + if(d.ok){ + Toast.success('刷新配置保存成功'); + loadRefreshConfig(); + }else{ + Toast.error(d.error||'保存失败'); + } + }catch(e){ + console.error('保存刷新配置失败:',e); + Toast.error('保存刷新配置失败'); + } +} + +// 限速配置 +async function loadRateLimitConfig(){ + try{ + const r=await fetch('/api/settings/rate-limit'); + const d=await r.json(); + $('#rateLimitEnabled').checked=d.enabled; + $('#minRequestInterval').value=d.min_request_interval||0.5; + $('#maxRequestsPerMinute').value=d.max_requests_per_minute||60; + $('#globalMaxRequestsPerMinute').value=d.global_max_requests_per_minute||120; + // 更新统计 + const stats=d.stats||{}; + $('#rateLimitStats').innerHTML=` +
    + 状态: ${d.enabled?'已启用':'已禁用'} + 全局 RPM: ${stats.global_rpm||0} + 429 冷却: 自动 5 分钟 +
    + `; + }catch(e){console.error('加载限速配置失败:',e)} +} + +async function updateRateLimitConfig(){ + const config={ + enabled:$('#rateLimitEnabled').checked, + min_request_interval:parseFloat($('#minRequestInterval').value)||0.5, + max_requests_per_minute:parseInt($('#maxRequestsPerMinute').value)||60, + global_max_requests_per_minute:parseInt($('#globalMaxRequestsPerMinute').value)||120 + }; + try{ + await fetch('/api/settings/rate-limit',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(config)}); + loadRateLimitConfig(); + }catch(e){console.error('保存限速配置失败:',e)} +} + +// 还原默认配置函数 +async function resetRefreshConfig(){ + if(!confirm('确定要还原刷新配置为默认值吗?')) return; + const defaultConfig={ + max_retries:3, + concurrency:3, + auto_refresh_interval:60, + retry_base_delay:1.0, + token_refresh_before_expiry:300 + }; + try{ + const r=await fetch('/api/refresh/config',{method:'PUT',headers:{'Content-Type':'application/json'},body:JSON.stringify(defaultConfig)}); + const d=await r.json(); + if(d.ok){ + Toast.success('已还原为默认配置'); + loadRefreshConfig(); + }else{ + Toast.error(d.error||'还原失败'); + } + }catch(e){ + Toast.error('还原配置失败'); + } +} + +async function resetRateLimitConfig(){ + if(!confirm('确定要还原限速配置为默认值吗?')) return; + const defaultConfig={ + enabled:false, + min_request_interval:0.5, + max_requests_per_minute:60, + global_max_requests_per_minute:120 + }; + try{ + await fetch('/api/settings/rate-limit',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(defaultConfig)}); + Toast.success('已还原为默认配置'); + loadRateLimitConfig(); + }catch(e){ + Toast.error('还原配置失败'); + } +} + +async function resetHistoryConfig(){ + if(!confirm('确定要还原历史消息配置为默认值吗?')) return; + const defaultConfig={ + strategies:['error_retry'], + max_retries:3, + summary_cache_enabled:true, + summary_cache_max_age_seconds:300, + add_warning_header:true + }; + try{ + await fetch('/api/settings/history',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(defaultConfig)}); + Toast.success('已还原为默认配置'); + loadHistoryConfig(); + }catch(e){ + Toast.error('还原配置失败'); + } +} + +// 页面加载时加载设置 +loadHistoryConfig(); +loadRateLimitConfig(); +loadRefreshConfig(); +''' + +# ==================== UI 组件库 JavaScript ==================== +JS_UI_COMPONENTS = ''' +// ==================== Modal 模态框组件 ==================== +class Modal { + constructor(options = {}) { + this.title = options.title || ''; + this.content = options.content || ''; + this.type = options.type || 'default'; + this.confirmText = options.confirmText || '确认'; + this.cancelText = options.cancelText || '取消'; + this.onConfirm = options.onConfirm; + this.onCancel = options.onCancel; + this.showCancel = options.showCancel !== false; + this.element = null; + } + + show() { + const overlay = document.createElement('div'); + overlay.className = 'modal-overlay'; + overlay.innerHTML = ` + + `; + overlay.modal = this; + this.element = overlay; + document.body.appendChild(overlay); + + // 键盘事件 + this.keyHandler = (e) => { + if (e.key === 'Escape') this.hide(); + if (e.key === 'Enter' && !e.target.matches('textarea')) this.confirm(); + }; + document.addEventListener('keydown', this.keyHandler); + + // 点击遮罩关闭 + overlay.addEventListener('click', (e) => { + if (e.target === overlay) this.hide(); + }); + + requestAnimationFrame(() => overlay.classList.add('active')); + return this; + } + + hide() { + if (this.element) { + this.element.classList.remove('active'); + document.removeEventListener('keydown', this.keyHandler); + setTimeout(() => this.element.remove(), 200); + } + } + + confirm() { + if (this.onConfirm) this.onConfirm(); + this.hide(); + } + + cancel() { + if (this.onCancel) this.onCancel(); + this.hide(); + } + + setLoading(loading) { + const btn = this.element?.querySelector('.modal-footer button:last-child'); + if (btn) { + btn.disabled = loading; + btn.textContent = loading ? '处理中...' : this.confirmText; + } + } + + static confirm(title, message, onConfirm) { + return new Modal({ title, content: `

    ${message}

    `, onConfirm }).show(); + } + + static alert(title, message) { + return new Modal({ title, content: `

    ${message}

    `, showCancel: false }).show(); + } + + static danger(title, message, onConfirm) { + return new Modal({ title, content: `

    ${message}

    `, type: 'danger', onConfirm, confirmText: '删除' }).show(); + } +} + +// ==================== Toast 通知组件 ==================== +class Toast { + static container = null; + + static getContainer() { + if (!this.container) { + this.container = document.createElement('div'); + this.container.className = 'toast-container'; + document.body.appendChild(this.container); + } + return this.container; + } + + static show(message, type = 'info', duration = 3000) { + const toast = document.createElement('div'); + toast.className = `toast ${type}`; + toast.innerHTML = ` + ${message} + + `; + this.getContainer().appendChild(toast); + + if (duration > 0) { + setTimeout(() => toast.remove(), duration); + } + return toast; + } + + static success(message, duration) { return this.show(message, 'success', duration); } + static error(message, duration) { return this.show(message, 'error', duration); } + static warning(message, duration) { return this.show(message, 'warning', duration); } + static info(message, duration) { return this.show(message, 'info', duration); } +} + +// ==================== Dropdown 下拉菜单组件 ==================== +class Dropdown { + constructor(trigger, items) { + this.trigger = trigger; + this.items = items; + this.element = null; + this.init(); + } + + init() { + const wrapper = document.createElement('div'); + wrapper.className = 'dropdown'; + + this.trigger.parentNode.insertBefore(wrapper, this.trigger); + wrapper.appendChild(this.trigger); + + const menu = document.createElement('div'); + menu.className = 'dropdown-menu'; + menu.innerHTML = this.items.map(item => { + if (item.divider) return ''; + return ``; + }).join(''); + wrapper.appendChild(menu); + + this.element = wrapper; + + this.trigger.addEventListener('click', (e) => { + e.stopPropagation(); + this.toggle(); + }); + + menu.addEventListener('click', (e) => { + const item = e.target.closest('.dropdown-item'); + if (item) { + const action = item.dataset.action; + const itemConfig = this.items.find(i => i.action === action); + if (itemConfig?.onClick) itemConfig.onClick(); + this.close(); + } + }); + + document.addEventListener('click', () => this.close()); + } + + toggle() { + this.element.classList.toggle('open'); + } + + close() { + this.element.classList.remove('open'); + } +} + +// ==================== 进度条渲染函数 ==================== +function renderProgressBar(value, max, options = {}) { + const percent = max > 0 ? (value / max * 100) : 0; + const color = options.color || (percent > 80 ? 'error' : percent > 60 ? 'warning' : 'success'); + const size = options.size || ''; + const showLabel = options.showLabel !== false; + + return ` +
    +
    +
    + ${showLabel ? `
    ${options.leftLabel || ''}${options.rightLabel || Math.round(percent) + '%'}
    ` : ''} + `; +} + +// ==================== 账号卡片渲染函数 ==================== +function renderAccountCard(account) { + const quota = account.quota; + const isPriority = account.is_priority; + const isActive = account.is_active; + + let statusBadge = ''; + if (!account.enabled) statusBadge = '禁用'; + else if (account.cooldown_remaining > 0) statusBadge = `冷却 ${account.cooldown_remaining}s`; + else if (account.available) statusBadge = '正常'; + else statusBadge = '不可用'; + + let quotaSection = ''; + if (quota && !quota.error) { + const usedPercent = quota.usage_limit > 0 ? (quota.current_usage / quota.usage_limit * 100) : 0; + quotaSection = ` + + `; + } else if (quota?.error) { + quotaSection = ``; + } + + return ` + + `; +} + +// ==================== 汇总面板渲染函数 ==================== +function renderSummaryPanel(summary) { + const strategyLabel = { + lowest_balance: '剩余额度最少优先', + round_robin: '轮询', + least_requests: '请求最少优先', + random: '随机' + }[summary.strategy] || summary.strategy; + + return ` +
    +
    +
    ${summary.total_accounts}
    总账号
    +
    ${summary.available_accounts}
    可用
    +
    ${summary.cooldown_accounts}
    冷却中
    +
    ${summary.unhealthy_accounts + summary.disabled_accounts}
    不可用
    +
    +
    +
    + 总剩余额度 + ${summary.total_balance.toFixed(1)} +
    + ${renderProgressBar(summary.total_usage, summary.total_limit, { + size: 'large', + leftLabel: `已用 ${summary.total_usage.toFixed(0)}`, + rightLabel: `总计 ${summary.total_limit.toFixed(0)}` + })} +
    +
    + 选择策略: ${strategyLabel} + 优先账号: ${summary.priority_accounts.length > 0 ? summary.priority_accounts.join(', ') : '无'} + 最后刷新: ${summary.last_refresh || '未刷新'} +
    +
    + + +
    +
    + `; +} + +// ==================== 账号操作菜单 ==================== +let currentAccountMenu = null; + +function showAccountMenu(accountId, btn) { + if (currentAccountMenu) { + currentAccountMenu.remove(); + currentAccountMenu = null; + } + + const menu = document.createElement('div'); + menu.className = 'dropdown-menu'; + menu.style.cssText = 'display:block;position:absolute;z-index:100;'; + menu.innerHTML = ` + + + + + + `; + + const rect = btn.getBoundingClientRect(); + menu.style.top = (rect.bottom + window.scrollY) + 'px'; + menu.style.left = (rect.left + window.scrollX - 100) + 'px'; + + document.body.appendChild(menu); + currentAccountMenu = menu; + + setTimeout(() => { + document.addEventListener('click', function closeMenu() { + if (currentAccountMenu) { + currentAccountMenu.remove(); + currentAccountMenu = null; + } + document.removeEventListener('click', closeMenu); + }, { once: true }); + }, 0); +} + +// ==================== 额度管理 API 调用 ==================== +async function loadAccountsEnhanced() { + showLoading('#accountsGrid', '加载账号列表...'); + try { + const r = await fetchWithRetry('/api/accounts/status'); + const d = await r.json(); + if (d.ok) { + $('#accountsSummaryCompact').innerHTML = renderSummaryCompact(d.summary); + $('#accountsGrid').innerHTML = d.accounts.map(renderAccountCardCompact).join(''); + } else { + $('#accountsGrid').innerHTML = `

    加载失败: ${d.error || '未知错误'}

    `; + } + } catch(e) { + $('#accountsGrid').innerHTML = `

    网络错误,点击重试

    `; + Toast.error('加载账号列表失败'); + } +} + +// ==================== 紧凑汇总面板 ==================== +function renderSummaryCompact(summary) { + const usedPercent = summary.total_limit > 0 ? (summary.total_usage / summary.total_limit * 100) : 0; + const barColor = usedPercent > 80 ? 'var(--error)' : usedPercent > 60 ? 'var(--warn)' : 'var(--success)'; + return ` +
    +
    + ${summary.total_accounts} + 总账号 +
    +
    + ${summary.available_accounts} + 可用 +
    +
    + ${summary.cooldown_accounts} + 冷却 +
    +
    +
    +
    + 总额度 + ${summary.total_balance.toFixed(0)} / ${summary.total_limit.toFixed(0)} +
    +
    +
    +
    +
    +
    + ${summary.last_refresh || '未刷新'} +
    +
    + `; +} + +// ==================== 紧凑账号卡片 ==================== +function renderAccountCardCompact(account) { + const quota = account.quota; + const isPriority = account.is_priority; + const isLowBalance = quota?.is_low_balance; + const isExhausted = quota?.is_exhausted || (quota && quota.balance <= 0); // 额度耗尽 + const isSuspended = quota?.is_suspended; // 账号被封禁 + const isUnavailable = !account.available; + + let cardClass = 'account-card-compact'; + if (isPriority) cardClass += ' priority'; + if (isSuspended) cardClass += ' suspended'; // 封禁状态 + else if (isExhausted) cardClass += ' exhausted'; // 无额度状态 + else if (isLowBalance) cardClass += ' low-balance'; + if (isUnavailable) cardClass += ' unavailable'; + + // 状态徽章 + let statusBadges = ''; + if (!account.enabled) statusBadges += '禁用'; + else if (account.cooldown_remaining > 0) statusBadges += `冷却`; + else if (account.available) statusBadges += '正常'; + else statusBadges += '异常'; + + if (isPriority) statusBadges += `#${account.priority_order}`; + // Provider 徽章 (Google/Github) + if (account.provider) { + const providerIcon = account.provider === 'Google' ? '🔵' : account.provider === 'Github' ? '⚫' : ''; + statusBadges += `${providerIcon}${account.provider}`; + } + // 状态徽章:封禁(红色)> 无额度(红色)> 低额度(黄色) + if (isSuspended) statusBadges += '已封禁'; + else if (isExhausted) statusBadges += '无额度'; + else if (isLowBalance) statusBadges += '低额度'; + + // Token 过期状态徽章 + if (account.token_expired) statusBadges += 'Token过期'; + else if (account.token_expiring_soon) statusBadges += 'Token即将过期'; + + // 额度条 - 根据状态显示不同颜色 + let quotaBar = ''; + if (quota && !quota.error) { + const usedPercent = quota.usage_limit > 0 ? (quota.current_usage / quota.usage_limit * 100) : 0; + // 颜色逻辑:无额度(红色) > 低额度(黄色) > 正常(绿色) + let barColor = 'var(--success)'; + if (isExhausted) barColor = 'var(--error)'; + else if (isLowBalance) barColor = 'var(--warn)'; + else if (usedPercent > 60) barColor = 'var(--warn)'; + + quotaBar = ` + + `; + } else if (quota?.error) { + // 额度获取失败时显示重试按钮 + // 如果是封禁错误,显示封禁状态 + const errorMsg = quota.error; + const isSuspendedError = errorMsg && ( + errorMsg.toLowerCase().includes('temporarily_suspended') || + errorMsg.toLowerCase().includes('suspended') || + errorMsg.toLowerCase().includes('accountsuspendedexception') + ); + + if (isSuspendedError) { + quotaBar = ` + + `; + } else { + quotaBar = ` + + `; + } + } else { + // 未查询额度时显示查询按钮 + quotaBar = ` + + `; + } + + // Token 过期时间显示 + let tokenExpireInfo = ''; + if (account.token_expires_at) { + // expires_at 可能是 ISO 字符串或时间戳 + let expireDate; + if (typeof account.token_expires_at === 'string') { + // ISO 格式字符串 + expireDate = new Date(account.token_expires_at); + } else if (account.token_expires_at > 1000000000000) { + // 毫秒时间戳 + expireDate = new Date(account.token_expires_at); + } else { + // 秒时间戳 + expireDate = new Date(account.token_expires_at * 1000); + } + + const now = new Date(); + const diffMs = expireDate - now; + + // 检查是否为有效日期 + if (!isNaN(expireDate.getTime()) && !isNaN(diffMs)) { + const diffHours = Math.floor(diffMs / (1000 * 60 * 60)); + const diffDays = Math.floor(diffHours / 24); + + let expireText = ''; + if (diffMs < 0) { + expireText = '已过期'; + } else if (diffDays > 0) { + expireText = `${diffDays}天`; + } else if (diffHours > 0) { + expireText = `${diffHours}时`; + } else { + const diffMins = Math.floor(diffMs / (1000 * 60)); + expireText = diffMins > 0 ? `${diffMins}分` : '即将过期'; + } + tokenExpireInfo = `Token ${expireText}`; + } + } + + return ` +
    + + ${quotaBar} + + +
    + `; +} + +// ==================== 导入导出菜单 ==================== +let importExportMenu = null; + +function showImportExportMenu(btn) { + if (importExportMenu) { + importExportMenu.remove(); + importExportMenu = null; + return; + } + + const menu = document.createElement('div'); + menu.className = 'dropdown-menu'; + menu.style.cssText = 'display:block;position:absolute;z-index:100;min-width:140px;'; + menu.innerHTML = ` + + + + + `; + + const rect = btn.getBoundingClientRect(); + menu.style.top = (rect.bottom + window.scrollY + 4) + 'px'; + menu.style.left = (rect.left + window.scrollX) + 'px'; + + document.body.appendChild(menu); + importExportMenu = menu; + + setTimeout(() => { + document.addEventListener('click', function closeMenu(e) { + if (importExportMenu && !importExportMenu.contains(e.target)) { + importExportMenu.remove(); + importExportMenu = null; + } + document.removeEventListener('click', closeMenu); + }, { once: true }); + }, 10); +} + +async function refreshAllQuotas() { + // 检查是否正在刷新中 + if (GlobalProgressBar.isRefreshing) { + Toast.warning('正在刷新中,请稍候...'); + return; + } + + try { + // 先获取账号数量用于显示 + const statusR = await fetch('/api/accounts/status'); + const statusD = await statusR.json(); + const total = statusD.ok ? statusD.accounts?.length || 0 : 0; + + // 显示进度条 + GlobalProgressBar.show(total); + + // 调用新的批量刷新 API + const r = await fetch('/api/refresh/all', { method: 'POST' }); + const d = await r.json(); + + if (d.ok) { + // 开始轮询进度 + GlobalProgressBar.startPolling(); + } else { + GlobalProgressBar.hide(); + Toast.error('启动刷新失败: ' + (d.error || '未知错误')); + } + } catch(e) { + GlobalProgressBar.hide(); + Toast.error('刷新失败: ' + e.message); + } +} + +async function refreshAccountQuota(accountId) { + Toast.info('正在刷新额度...'); + try { + const r = await fetch(`/api/accounts/${accountId}/refresh-quota`, { method: 'POST' }); + const d = await r.json(); + if (d.ok) { + Toast.success('额度刷新成功'); + loadAccounts(); + loadAccountsEnhanced(); + } else { + Toast.error(d.error || '刷新失败'); + } + } catch(e) { + Toast.error('刷新失败: ' + e.message); + } +} + +// ==================== 测试账号 Token ==================== +async function testAccountToken(accountId) { + // 显示测试中的模态框 + const modal = document.createElement('div'); + modal.className = 'modal'; + modal.id = 'testTokenModal'; + modal.innerHTML = ` + + `; + document.body.appendChild(modal); + modal.style.display = 'flex'; + + try { + const r = await fetch('/api/accounts/' + accountId + '/test'); + const d = await r.json(); + + const resultDiv = document.getElementById('testTokenResult'); + if (!resultDiv) return; + + if (d.ok) { + // 测试通过 + let testsHtml = ''; + for (const [key, test] of Object.entries(d.tests || {})) { + const icon = test.passed ? '✅' : '❌'; + const color = test.passed ? 'var(--success)' : 'var(--error)'; + testsHtml += ` +
    + ${icon} +
    +
    ${test.message}
    + ${test.suggestion ? `
    ${test.suggestion}
    ` : ''} + ${test.latency_ms ? `
    延迟: ${test.latency_ms.toFixed(0)}ms
    ` : ''} + ${test.email ? `
    邮箱: ${test.email}
    ` : ''} +
    +
    + `; + } + + resultDiv.innerHTML = ` +
    + +

    Token 有效

    +

    ${d.summary}

    +
    +
    + ${testsHtml} +
    + `; + } else { + // 测试失败 + let testsHtml = ''; + for (const [key, test] of Object.entries(d.tests || {})) { + const icon = test.passed ? '✅' : '❌'; + testsHtml += ` +
    + ${icon} +
    +
    ${test.message}
    + ${test.suggestion ? `
    💡 ${test.suggestion}
    ` : ''} +
    +
    + `; + } + + resultDiv.innerHTML = ` +
    + +

    Token 无效

    +

    ${d.summary || d.error || '测试失败'}

    +
    + ${Object.keys(d.tests || {}).length > 0 ? ` +
    + ${testsHtml} +
    + ` : ''} +
    + +
    + `; + } + } catch(e) { + const resultDiv = document.getElementById('testTokenResult'); + if (resultDiv) { + resultDiv.innerHTML = ` +
    + ⚠️ +

    测试失败

    +

    ${e.message}

    +
    + `; + } + } +} + +function closeTestTokenModal() { + const modal = document.getElementById('testTokenModal'); + if (modal) modal.remove(); +} + +// ==================== 单账号额度查询 (任务 19.2) ==================== +async function refreshSingleAccountQuota(accountId) { + // 获取按钮元素,显示加载状态 + const safeId = accountId.replace(/[^a-zA-Z0-9]/g, '_'); + const btn = document.getElementById('quota-btn-' + safeId); + const card = document.getElementById('account-card-' + safeId); + + if (btn) { + btn.disabled = true; + btn.dataset.originalText = btn.textContent; + btn.textContent = '查询中...'; + } + + try { + const r = await fetch(`/api/accounts/${accountId}/refresh-quota`, { method: 'POST' }); + const d = await r.json(); + + if (d.ok) { + Toast.success('额度查询成功'); + // 刷新整个账号列表以更新显示 + loadAccounts(); + loadAccountsEnhanced(); + } else { + // 失败时显示错误信息和重试按钮 + Toast.error(d.error || '额度查询失败'); + if (btn) { + btn.textContent = '重试'; + btn.disabled = false; + btn.classList.add('error-state'); + } + // 在卡片上显示错误状态 + if (card) { + const quotaDiv = card.querySelector('.account-card-quota'); + if (quotaDiv) { + quotaDiv.innerHTML = ` + 查询失败: ${d.error || '未知错误'} + + `; + } + } + } + } catch(e) { + Toast.error('网络错误: ' + e.message); + if (btn) { + btn.textContent = '重试'; + btn.disabled = false; + } + } finally { + // 恢复按钮状态(如果没有错误) + if (btn && !btn.classList.contains('error-state')) { + btn.disabled = false; + if (btn.dataset.originalText) { + btn.textContent = btn.dataset.originalText; + } + } + if (btn) { + btn.classList.remove('error-state'); + } + } +} + +// ==================== 单账号 Token 刷新 (任务 19.2) ==================== +async function refreshSingleAccountToken(accountId) { + // 获取按钮元素,显示加载状态 + const safeId = accountId.replace(/[^a-zA-Z0-9]/g, '_'); + const btn = document.getElementById('token-btn-' + safeId); + + if (btn) { + btn.disabled = true; + btn.dataset.originalText = btn.textContent; + btn.textContent = '刷新中...'; + } + + try { + const r = await fetch(`/api/accounts/${accountId}/refresh`, { method: 'POST' }); + const d = await r.json(); + + if (d.ok) { + Toast.success('Token 刷新成功'); + // 刷新整个账号列表以更新显示 + loadAccounts(); + loadAccountsEnhanced(); + } else { + // 失败时显示错误信息 + Toast.error(d.message || d.error || 'Token 刷新失败'); + if (btn) { + btn.textContent = '重试'; + btn.disabled = false; + } + } + } catch(e) { + Toast.error('网络错误: ' + e.message); + if (btn) { + btn.textContent = '重试'; + btn.disabled = false; + } + } finally { + // 恢复按钮状态 + if (btn && btn.textContent !== '重试') { + btn.disabled = false; + if (btn.dataset.originalText) { + btn.textContent = btn.dataset.originalText; + } + } + } +} + +async function togglePriority(accountId) { + try { + // 先检查是否已是优先账号 + const r1 = await fetch('/api/priority'); + const d1 = await r1.json(); + const isPriority = d1.priority_accounts?.some(a => a.id === accountId); + + if (isPriority) { + const r = await fetch(`/api/priority/${accountId}`, { method: 'DELETE' }); + const d = await r.json(); + Toast.show(d.message, d.ok ? 'success' : 'error'); + } else { + const r = await fetch(`/api/priority/${accountId}`, { method: 'POST', headers: {'Content-Type': 'application/json'}, body: '{}' }); + const d = await r.json(); + Toast.show(d.message, d.ok ? 'success' : 'error'); + } + loadAccounts(); + loadAccountsEnhanced(); + } catch(e) { + Toast.error('操作失败: ' + e.message); + } +} + +function confirmDeleteAccount(accountId) { + Modal.danger('删除账号', `确定要删除账号 ${accountId} 吗?此操作不可恢复。`, async () => { + try { + const r = await fetch(`/api/accounts/${accountId}`, { method: 'DELETE' }); + const d = await r.json(); + if (d.ok) { + Toast.success('账号已删除'); + loadAccounts(); + loadAccountsEnhanced(); + } else { + Toast.error('删除失败'); + } + } catch(e) { + Toast.error('删除失败: ' + e.message); + } + }); +} + +// ==================== 账号编辑功能 ==================== +function showEditAccountModal(accountId, currentName) { + const modal = new Modal({ + title: '编辑账号', + content: ` +
    +
    + + +
    +
    + + +
    +
    + + +
    + +
    + `, + confirmText: '保存', + onConfirm: async () => { + const name = document.getElementById('editAccountName').value.trim(); + const provider = document.getElementById('editAccountProvider').value; + const region = document.getElementById('editAccountRegion').value.trim(); + + const updateData = {}; + if (name) updateData.name = name; + if (provider) updateData.provider = provider; + if (region) updateData.region = region; + + try { + const r = await fetch(`/api/accounts/${accountId}`, { + method: 'PUT', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify(updateData) + }); + const d = await r.json(); + if (d.ok) { + Toast.success(d.message || '账号已更新'); + loadAccounts(); + loadAccountsEnhanced(); + } else { + Toast.error(d.error || '更新失败'); + } + } catch(e) { + Toast.error('更新失败: ' + e.message); + } + } + }); + modal.show(); + + // 加载当前账号信息填充表单 + loadAccountForEdit(accountId); +} + +async function refreshTokenInModal(accountId) { + const btn = document.getElementById('refreshTokenBtn'); + if (btn) { + btn.disabled = true; + btn.textContent = '刷新中...'; + } + + try { + const r = await fetch(`/api/accounts/${accountId}/refresh`, { method: 'POST' }); + const d = await r.json(); + if (d.ok) { + Toast.success('Token 刷新成功'); + // 重新加载账号信息 + await loadAccountForEdit(accountId); + loadAccounts(); + loadAccountsEnhanced(); + } else { + Toast.error(d.message || d.error || 'Token 刷新失败'); + } + } catch(e) { + Toast.error('刷新失败: ' + e.message); + } finally { + if (btn) { + btn.disabled = false; + btn.textContent = '🔄 刷新 Token'; + } + } +} + +function copyToClipboard(text, label) { + navigator.clipboard.writeText(text).then(() => { + Toast.success(label + ' 已复制'); + }).catch(() => { + // 降级方案 + const ta = document.createElement('textarea'); + ta.value = text; + document.body.appendChild(ta); + ta.select(); + document.execCommand('copy'); + document.body.removeChild(ta); + Toast.success(label + ' 已复制'); + }); +} + +function renderTokenField(label, value, fieldId) { + if (!value) return ''; + const shortValue = value.length > 50 ? value.substring(0, 50) + '...' : value; + return ` +
    +
    + ${label}: + +
    +
    ${shortValue}
    +
    + `; +} + +async function loadAccountForEdit(accountId) { + try { + const r = await fetch(`/api/accounts/${accountId}`); + const d = await r.json(); + + const providerSelect = document.getElementById('editAccountProvider'); + const regionInput = document.getElementById('editAccountRegion'); + const tokenSection = document.getElementById('tokenInfoSection'); + const tokenDetails = document.getElementById('tokenDetails'); + + if (d.credentials) { + if (providerSelect && d.credentials.provider) { + providerSelect.value = d.credentials.provider; + } + if (regionInput && d.credentials.region) { + regionInput.value = d.credentials.region; + } + + // 显示 Token 信息 + if (tokenSection && tokenDetails) { + tokenSection.style.display = 'block'; + + let html = ''; + + // Access Token + if (d.credentials.access_token) { + html += renderTokenField('Access Token', d.credentials.access_token, 'field_access_token'); + } + + // Refresh Token + if (d.credentials.refresh_token) { + html += renderTokenField('Refresh Token', d.credentials.refresh_token, 'field_refresh_token'); + } + + // Profile ARN + if (d.credentials.profile_arn) { + html += renderTokenField('Profile ARN', d.credentials.profile_arn, 'field_profile_arn'); + } + + // Client ID + if (d.credentials.client_id) { + html += renderTokenField('Client ID', d.credentials.client_id, 'field_client_id'); + } + + // 过期时间 + if (d.credentials.expires_at) { + const expiresAt = new Date(d.credentials.expires_at); + const now = new Date(); + const diffMs = expiresAt - now; + const diffMins = Math.floor(diffMs / 60000); + let expiryText = expiresAt.toLocaleString(); + if (diffMs < 0) { + expiryText += ' (已过期)'; + } else if (diffMins < 60) { + expiryText += ' (' + diffMins + '分钟后过期)'; + } else { + expiryText += ' (' + Math.floor(diffMins/60) + '小时后过期)'; + } + html += '
    过期时间: ' + expiryText + '
    '; + } + + // Auth Method + if (d.credentials.auth_method) { + html += '
    认证方式: ' + d.credentials.auth_method + '
    '; + } + + tokenDetails.innerHTML = html || '无 Token 信息'; + } + } + } catch(e) { + console.error('加载账号信息失败:', e); + } +} + +// ==================== 自动刷新功能 (任务 10.2) ==================== +let autoRefreshTimer = null; +const AUTO_REFRESH_INTERVAL = 60000; // 60秒 + +function startAutoRefresh() { + if (autoRefreshTimer) clearInterval(autoRefreshTimer); + autoRefreshTimer = setInterval(() => { + const accountsTab = document.querySelector('.tab[data-tab="accounts"]'); + if (accountsTab && accountsTab.classList.contains('active')) { + loadAccounts(); + loadAccountsEnhanced(); + } + }, AUTO_REFRESH_INTERVAL); +} + +function stopAutoRefresh() { + if (autoRefreshTimer) { + clearInterval(autoRefreshTimer); + autoRefreshTimer = null; + } +} + +// 页面加载时启动自动刷新 +startAutoRefresh(); + +// ==================== 加载状态指示器 (任务 10.1) ==================== +function showLoading(container, message = '加载中...') { + const el = typeof container === 'string' ? document.querySelector(container) : container; + if (el) { + el.innerHTML = `
    +
    +

    ${message}

    +
    `; + } +} + +// 添加旋转动画 +if (!document.querySelector('#spinKeyframes')) { + const style = document.createElement('style'); + style.id = 'spinKeyframes'; + style.textContent = '@keyframes spin { to { transform: rotate(360deg); } }'; + document.head.appendChild(style); +} + +// ==================== 表单验证 (任务 10.3) ==================== +function validateToken(token) { + if (!token || token.trim().length === 0) { + return { valid: false, error: 'Token 不能为空' }; + } + if (token.trim().length < 20) { + return { valid: false, error: 'Token 格式不正确,长度过短' }; + } + return { valid: true }; +} + +function validateAccountName(name) { + if (!name || name.trim().length === 0) { + return { valid: true, default: '手动添加账号' }; // 名称可选 + } + if (name.length > 50) { + return { valid: false, error: '账号名称不能超过50个字符' }; + } + return { valid: true }; +} + +// ==================== 网络错误处理 (任务 10.1) ==================== +async function fetchWithRetry(url, options = {}, retries = 2) { + for (let i = 0; i <= retries; i++) { + try { + const r = await fetch(url, options); + if (!r.ok && r.status >= 500 && i < retries) { + await new Promise(resolve => setTimeout(resolve, 1000 * (i + 1))); + continue; + } + return r; + } catch (e) { + if (i === retries) throw e; + await new Promise(resolve => setTimeout(resolve, 1000 * (i + 1))); + } + } +} + +// ==================== 全局进度条组件 (任务 18.1) ==================== +const GlobalProgressBar = { + pollTimer: null, + isRefreshing: false, + + // 显示进度条 + show(total) { + this.isRefreshing = true; + const bar = $('#globalProgressBar'); + if (bar) { + bar.classList.add('active'); + // 重置显示 + $('#globalProgressTitle').textContent = '正在刷新额度...'; + $('#globalProgressCompleted').textContent = '0'; + $('#globalProgressTotal').textContent = total || '0'; + $('#globalProgressSuccess').textContent = '0'; + $('#globalProgressFailed').textContent = '0'; + $('#globalProgressFill').style.width = '0%'; + $('#globalProgressFill').classList.remove('complete'); + $('#globalProgressCurrent').textContent = '准备中...'; + $('#globalProgressClose').style.display = 'none'; + // 显示 spinner + const spinner = bar.querySelector('.spinner'); + if (spinner) spinner.style.display = 'inline-block'; + } + // 禁用刷新按钮 + this.updateRefreshButton(true); + }, + + // 更新进度 + update(progress) { + if (!progress) return; + + const completed = progress.completed || 0; + const total = progress.total || 0; + const success = progress.success || 0; + const failed = progress.failed || 0; + const current = progress.current_account || ''; + const isComplete = progress.status === 'completed' || progress.status === 'idle'; + + // 更新数字 + $('#globalProgressCompleted').textContent = completed; + $('#globalProgressTotal').textContent = total; + $('#globalProgressSuccess').textContent = success; + $('#globalProgressFailed').textContent = failed; + + // 更新进度条 + const percent = total > 0 ? (completed / total * 100) : 0; + const fill = $('#globalProgressFill'); + if (fill) { + fill.style.width = percent + '%'; + if (isComplete) { + fill.classList.add('complete'); + } + } + + // 更新当前处理的账号 + if (current) { + $('#globalProgressCurrent').textContent = '正在处理: ' + current; + } else if (isComplete) { + $('#globalProgressCurrent').textContent = `刷新完成: 成功 ${success} 个, 失败 ${failed} 个`; + } + + // 完成后的处理 + if (isComplete) { + this.isRefreshing = false; + $('#globalProgressTitle').textContent = '刷新完成'; + $('#globalProgressClose').style.display = 'inline-block'; + // 隐藏 spinner + const spinner = $('#globalProgressBar')?.querySelector('.spinner'); + if (spinner) spinner.style.display = 'none'; + // 恢复刷新按钮 + this.updateRefreshButton(false); + // 刷新账号列表 + loadAccounts(); + loadAccountsEnhanced(); + // 显示完成通知 + if (failed > 0) { + Toast.warning(`刷新完成: 成功 ${success} 个, 失败 ${failed} 个`); + } else { + Toast.success(`刷新完成: 成功 ${success} 个`); + } + // 5秒后自动关闭进度条 + setTimeout(() => this.hide(), 5000); + } + }, + + // 隐藏进度条 + hide() { + const bar = $('#globalProgressBar'); + if (bar) { + bar.classList.remove('active'); + } + this.isRefreshing = false; + this.stopPolling(); + this.updateRefreshButton(false); + }, + + // 开始轮询进度 + startPolling() { + this.stopPolling(); + this.pollTimer = setInterval(() => this.pollProgress(), 500); + }, + + // 停止轮询 + stopPolling() { + if (this.pollTimer) { + clearInterval(this.pollTimer); + this.pollTimer = null; + } + }, + + // 轮询进度 API + async pollProgress() { + try { + const r = await fetch('/api/refresh/progress'); + const d = await r.json(); + if (d.ok) { + // 传入 progress 对象,如果没有则传入整个响应(兼容) + const progress = d.progress || d; + // 添加 status 字段用于判断完成状态 + if (!d.is_refreshing && !progress.status) { + progress.status = 'completed'; + } + this.update(progress); + // 如果完成则停止轮询 + if (!d.is_refreshing || progress.status === 'completed' || progress.status === 'idle') { + this.stopPolling(); + } + } + } catch (e) { + console.error('轮询进度失败:', e); + } + }, + + // 更新刷新按钮状态 + updateRefreshButton(disabled) { + // 查找所有刷新额度按钮 + const buttons = document.querySelectorAll('button'); + buttons.forEach(btn => { + const text = btn.textContent; + const originalText = btn.dataset.originalText; + // 匹配"刷新额度"、"刷新全部额度"或已经变成"刷新中..."的按钮 + if (text.includes('刷新额度') || text.includes('刷新全部额度') || + text === '刷新中...' || + (originalText && (originalText.includes('刷新额度') || originalText.includes('刷新全部额度')))) { + btn.disabled = disabled; + if (disabled) { + if (!btn.dataset.originalText) { + btn.dataset.originalText = text; + } + btn.textContent = '刷新中...'; + } else if (btn.dataset.originalText) { + btn.textContent = btn.dataset.originalText; + delete btn.dataset.originalText; + } + } + }); + } +}; + +// ==================== 进度轮询函数 (任务 18.2) ==================== +async function pollRefreshProgress() { + return GlobalProgressBar.pollProgress(); +} +''' + +JS_SCRIPTS = JS_UTILS + JS_TABS + JS_STATUS + JS_DOCS + JS_STATS + JS_LOGS + JS_ACCOUNTS + JS_LOGIN + JS_FLOWS + JS_SETTINGS + JS_UI_COMPONENTS + + +# ==================== 组装最终 HTML ==================== +HTML_PAGE = f''' + + + + +Kiro API + + + + +
    +{HTML_BODY} + +
    + + +''' diff --git a/KiroProxy/legacy/kiro_proxy.py b/KiroProxy/legacy/kiro_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..e53631997cefb9df187e74fdc93cd916c2269f8c --- /dev/null +++ b/KiroProxy/legacy/kiro_proxy.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +""" +Kiro API 反向代理服务器 +对外暴露 OpenAI 兼容接口,内部调用 Kiro/AWS Q API +""" + +import json +import uuid +import os +import httpx +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from datetime import datetime +from pathlib import Path +import logging + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +app = FastAPI(title="Kiro API Proxy") + +# Kiro API 配置 +KIRO_API_URL = "https://q.us-east-1.amazonaws.com/generateAssistantResponse" +TOKEN_PATH = Path.home() / ".aws/sso/cache/kiro-auth-token.json" +MACHINE_ID = "fa41d5def91e29225c73f6ea8ee0941a87bd812aae5239e3dde72c3ba7603a26" + +def get_kiro_token() -> str: + """从本地文件读取 Kiro token""" + try: + with open(TOKEN_PATH) as f: + data = json.load(f) + return data.get("accessToken", "") + except Exception as e: + logger.error(f"读取 token 失败: {e}") + raise HTTPException(status_code=500, detail="无法读取 Kiro token") + +def build_kiro_headers(token: str) -> dict: + """构建 Kiro API 请求头""" + return { + "content-type": "application/json", + "x-amzn-codewhisperer-optout": "true", + "x-amzn-kiro-agent-mode": "vibe", + "x-amz-user-agent": f"aws-sdk-js/1.0.27 KiroIDE-0.8.0-{MACHINE_ID}", + "user-agent": f"aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs api/codewhispererstreaming KiroIDE-0.8.0-{MACHINE_ID}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "amz-sdk-request": "attempt=1; max=3", + "Authorization": f"Bearer {token}", + } + +def build_kiro_request(messages: list, model: str, conversation_id: str = None) -> dict: + """将 OpenAI 格式转换为 Kiro 格式""" + if not conversation_id: + conversation_id = str(uuid.uuid4()) + + # 提取最后一条用户消息 + user_content = "" + for msg in reversed(messages): + if msg.get("role") == "user": + user_content = msg.get("content", "") + break + + return { + "conversationState": { + "conversationId": conversation_id, + "currentMessage": { + "userInputMessage": { + "content": user_content, + "modelId": model.replace("kiro-", ""), # 移除前缀 + "origin": "AI_EDITOR", + "userInputMessageContext": {} + } + }, + "chatTriggerType": "MANUAL" + } + } + +def parse_kiro_response(response_data: dict) -> str: + """解析 Kiro 响应,提取 AI 回复内容""" + try: + # Kiro 响应格式可能是流式的,需要解析 + if isinstance(response_data, dict): + # 尝试多种可能的响应路径 + if "generateAssistantResponseResponse" in response_data: + resp = response_data["generateAssistantResponseResponse"] + if "assistantResponseEvent" in resp: + event = resp["assistantResponseEvent"] + if "content" in event: + return event["content"] + + # 直接返回文本内容 + if "content" in response_data: + return response_data["content"] + + if "message" in response_data: + return response_data["message"] + + return json.dumps(response_data) + except Exception as e: + logger.error(f"解析响应失败: {e}") + return str(response_data) + +def parse_event_stream(raw_content: bytes) -> str: + """解析 AWS event-stream 格式的响应""" + try: + # 尝试直接解码为 UTF-8 + try: + text = raw_content.decode('utf-8') + # 如果是纯 JSON + if text.startswith('{'): + data = json.loads(text) + return parse_kiro_response(data) + except: + pass + + # AWS event-stream 格式解析 + # 格式: [prelude (8 bytes)][headers][payload][message CRC (4 bytes)] + content_parts = [] + pos = 0 + + while pos < len(raw_content): + if pos + 12 > len(raw_content): + break + + # 读取 prelude: total_length (4 bytes) + headers_length (4 bytes) + prelude_crc (4 bytes) + total_length = int.from_bytes(raw_content[pos:pos+4], 'big') + headers_length = int.from_bytes(raw_content[pos+4:pos+8], 'big') + + if total_length == 0 or total_length > len(raw_content) - pos: + break + + # 跳过 prelude (12 bytes) 和 headers + payload_start = pos + 12 + headers_length + payload_end = pos + total_length - 4 # 减去 message CRC + + if payload_start < payload_end: + payload = raw_content[payload_start:payload_end] + try: + # 尝试解析 payload 为 JSON + payload_text = payload.decode('utf-8') + if payload_text.strip(): + payload_json = json.loads(payload_text) + + # 提取文本内容 + if "assistantResponseEvent" in payload_json: + event = payload_json["assistantResponseEvent"] + if "content" in event: + content_parts.append(event["content"]) + elif "content" in payload_json: + content_parts.append(payload_json["content"]) + elif "text" in payload_json: + content_parts.append(payload_json["text"]) + else: + logger.info(f" Event: {payload_text[:200]}") + except Exception as e: + logger.debug(f"解析 payload 失败: {e}") + + pos += total_length + + if content_parts: + return "".join(content_parts) + + # 如果解析失败,返回原始内容的十六进制表示用于调试 + return f"[无法解析响应,原始数据: {raw_content[:500].hex()}]" + + except Exception as e: + logger.error(f"解析 event-stream 失败: {e}") + return f"[解析错误: {e}]" + +@app.get("/") +async def root(): + """健康检查""" + token_exists = TOKEN_PATH.exists() + return { + "status": "ok", + "service": "Kiro API Proxy", + "token_available": token_exists, + "endpoints": { + "chat": "/v1/chat/completions", + "models": "/v1/models" + } + } + +@app.get("/v1/models") +async def list_models(): + """列出可用模型 (OpenAI 兼容)""" + return { + "object": "list", + "data": [ + {"id": "kiro-claude-sonnet-4", "object": "model", "owned_by": "kiro"}, + {"id": "kiro-claude-opus-4.5", "object": "model", "owned_by": "kiro"}, + {"id": "claude-sonnet-4", "object": "model", "owned_by": "kiro"}, + {"id": "claude-opus-4.5", "object": "model", "owned_by": "kiro"}, + ] + } + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + """OpenAI 兼容的聊天接口""" + try: + body = await request.json() + except: + raise HTTPException(status_code=400, detail="Invalid JSON") + + messages = body.get("messages", []) + model = body.get("model", "claude-sonnet-4") + stream = body.get("stream", False) + + if not messages: + raise HTTPException(status_code=400, detail="messages is required") + + # 获取 token + token = get_kiro_token() + + # 构建请求 + headers = build_kiro_headers(token) + kiro_body = build_kiro_request(messages, model) + + logger.info(f"📤 发送请求到 Kiro API, model={model}") + logger.info(f" 消息: {messages[-1].get('content', '')[:100]}...") + + try: + async with httpx.AsyncClient(timeout=60.0, verify=False) as client: + response = await client.post( + KIRO_API_URL, + headers=headers, + json=kiro_body + ) + + logger.info(f"📥 Kiro 响应状态: {response.status_code}") + logger.info(f" Content-Type: {response.headers.get('content-type')}") + + if response.status_code != 200: + logger.error(f"Kiro API 错误: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail=f"Kiro API error: {response.text}" + ) + + # 处理响应 - 可能是 event-stream 或 JSON + raw_content = response.content + logger.info(f" 响应大小: {len(raw_content)} bytes") + logger.info(f" 原始响应前200字节: {raw_content[:200]}") + + content = parse_event_stream(raw_content) + + logger.info(f" 回复: {content[:100]}...") + + # 返回 OpenAI 兼容格式 + return JSONResponse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": content + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + }) + + except httpx.RequestError as e: + logger.error(f"请求失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + logger.error(f"未知错误: {e}") + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/token/status") +async def token_status(): + """检查 token 状态""" + try: + with open(TOKEN_PATH) as f: + data = json.load(f) + expires_at = data.get("expiresAt", "unknown") + return { + "valid": True, + "expires_at": expires_at, + "path": str(TOKEN_PATH) + } + except Exception as e: + return { + "valid": False, + "error": str(e), + "path": str(TOKEN_PATH) + } + +if __name__ == "__main__": + print(""" +╔══════════════════════════════════════════════════════════════╗ +║ Kiro API 反向代理服务器 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 端口: 8000 ║ +║ OpenAI 兼容接口: http://127.0.0.1:8000/v1/chat/completions ║ +╠══════════════════════════════════════════════════════════════╣ +║ 使用方法: ║ +║ curl http://127.0.0.1:8000/v1/chat/completions \\ ║ +║ -H "Content-Type: application/json" \\ ║ +║ -d '{"model":"claude-sonnet-4","messages":[{"role":"user",║ +║ "content":"Hello"}]}' ║ +╚══════════════════════════════════════════════════════════════╝ + """) + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/KiroProxy/pyproject.toml b/KiroProxy/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..e3df50f23a05549492ab6e4b4833df2b64fa100c --- /dev/null +++ b/KiroProxy/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "kiroproxy" +dynamic = ["version"] +description = "Kiro IDE API reverse proxy server" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "MIT" } +authors = [{ name = "petehsu" }] +urls = { Homepage = "https://github.com/petehsu/KiroProxy" } +dependencies = [ + "fastapi>=0.100.0", + "uvicorn>=0.23.0", + "httpx>=0.24.0", + "requests>=2.31.0", + "tiktoken>=0.5.0", + "cbor2>=5.4.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "hypothesis>=6.0.0", +] + +[project.scripts] +kiro-proxy = "kiro_proxy.cli:main" + +[tool.setuptools.dynamic] +version = { attr = "kiro_proxy.__version__" } + +[tool.setuptools.packages.find] +include = ["kiro_proxy*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/KiroProxy/requirements.txt b/KiroProxy/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ff93ba28840f80c905482767e236ea0efef20d5 --- /dev/null +++ b/KiroProxy/requirements.txt @@ -0,0 +1,8 @@ +fastapi>=0.100.0 +uvicorn>=0.23.0 +httpx>=0.24.0 +requests>=2.31.0 +pytest>=7.0.0 +hypothesis>=6.0.0 +tiktoken>=0.5.0 +cbor2>=5.4.0 diff --git a/KiroProxy/run.py b/KiroProxy/run.py new file mode 100644 index 0000000000000000000000000000000000000000..7a796e6cb29e867e508f0e67e1293408abec81a5 --- /dev/null +++ b/KiroProxy/run.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +"""Kiro API Proxy 启动脚本""" +import sys + +if __name__ == "__main__": + # 如果有子命令参数,使用 CLI 模式 + if len(sys.argv) > 1 and sys.argv[1] in ("accounts", "login", "status", "serve"): + from kiro_proxy.cli import main + main() + else: + # 兼容旧的启动方式: python run.py [port] + port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080 + from kiro_proxy.main import run + run(port) diff --git a/KiroProxy/scripts/capture_kiro.py b/KiroProxy/scripts/capture_kiro.py new file mode 100644 index 0000000000000000000000000000000000000000..00b4ce3ef30017c2f03cd19f99bc7ac0d264d2e1 --- /dev/null +++ b/KiroProxy/scripts/capture_kiro.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Kiro IDE 请求抓取工具 + +使用 mitmproxy 抓取 Kiro IDE 发送到 AWS 的请求。 + +安装: + pip install mitmproxy + +使用方法: + 1. 运行此脚本: python capture_kiro.py + 2. 设置系统代理为 127.0.0.1:8888 + 3. 安装 mitmproxy 的 CA 证书 (访问 http://mitm.it) + 4. 启动 Kiro IDE 并使用 + 5. 查看 kiro_requests/ 目录下的抓取结果 + +或者使用 mitmproxy 命令行: + mitmproxy --mode regular@8888 -s capture_kiro.py + + 或者使用 mitmdump (无 UI): + mitmdump --mode regular@8888 -s capture_kiro.py +""" + +import json +import os +from datetime import datetime +from mitmproxy import http, ctx + +# 创建输出目录 +OUTPUT_DIR = "kiro_requests" +os.makedirs(OUTPUT_DIR, exist_ok=True) + +# 计数器 +request_count = 0 + +def request(flow: http.HTTPFlow) -> None: + """处理请求""" + global request_count + + # 只抓取 Kiro/AWS 相关请求 + if "q.us-east-1.amazonaws.com" not in flow.request.host: + return + + request_count += 1 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 保存请求 + request_data = { + "timestamp": timestamp, + "method": flow.request.method, + "url": flow.request.url, + "headers": dict(flow.request.headers), + "body": None + } + + # 解析请求体 + if flow.request.content: + try: + request_data["body"] = json.loads(flow.request.content.decode('utf-8')) + except: + request_data["body_raw"] = flow.request.content.decode('utf-8', errors='replace') + + # 保存到文件 + filename = f"{OUTPUT_DIR}/{timestamp}_{request_count:04d}_request.json" + with open(filename, 'w', encoding='utf-8') as f: + json.dump(request_data, f, indent=2, ensure_ascii=False) + + ctx.log.info(f"[Kiro] Captured request #{request_count}: {flow.request.method} {flow.request.path}") + + +def response(flow: http.HTTPFlow) -> None: + """处理响应""" + # 只处理 Kiro/AWS 相关响应 + if "q.us-east-1.amazonaws.com" not in flow.request.host: + return + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 保存响应 + response_data = { + "timestamp": timestamp, + "status_code": flow.response.status_code, + "headers": dict(flow.response.headers), + "body": None + } + + # 响应体可能是 event-stream 格式 + if flow.response.content: + try: + # 尝试解析为 JSON + response_data["body"] = json.loads(flow.response.content.decode('utf-8')) + except: + # 保存原始内容(可能是 event-stream) + response_data["body_raw_length"] = len(flow.response.content) + # 保存前 2000 字节用于调试 + response_data["body_preview"] = flow.response.content[:2000].decode('utf-8', errors='replace') + + # 保存到文件 + filename = f"{OUTPUT_DIR}/{timestamp}_{request_count:04d}_response.json" + with open(filename, 'w', encoding='utf-8') as f: + json.dump(response_data, f, indent=2, ensure_ascii=False) + + ctx.log.info(f"[Kiro] Captured response: {flow.response.status_code}") + + +# 如果直接运行此脚本 +if __name__ == "__main__": + print(""" +╔══════════════════════════════════════════════════════════════════╗ +║ Kiro IDE 请求抓取工具 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ ║ +║ 方法 1: 使用 mitmproxy (推荐) ║ +║ ─────────────────────────────────────────────────────────────── ║ +║ 1. 安装: pip install mitmproxy ║ +║ 2. 运行: mitmproxy -s capture_kiro.py ║ +║ 或: mitmdump -s capture_kiro.py ║ +║ 3. 设置 Kiro IDE 的代理为 127.0.0.1:8080 ║ +║ 4. 安装 CA 证书: 访问 http://mitm.it ║ +║ ║ +║ 方法 2: 使用 Burp Suite ║ +║ ─────────────────────────────────────────────────────────────── ║ +║ 1. 启动 Burp Suite ║ +║ 2. 设置代理监听 127.0.0.1:8080 ║ +║ 3. 导出 CA 证书并安装到系统 ║ +║ 4. 设置 Kiro IDE 的代理 ║ +║ ║ +║ 方法 3: 直接修改 Kiro IDE (最简单) ║ +║ ─────────────────────────────────────────────────────────────── ║ +║ 在 Kiro IDE 的设置中添加: ║ +║ "http.proxy": "http://127.0.0.1:8080" ║ +║ ║ +║ 或者设置环境变量: ║ +║ export HTTPS_PROXY=http://127.0.0.1:8080 ║ +║ export HTTP_PROXY=http://127.0.0.1:8080 ║ +║ export NODE_TLS_REJECT_UNAUTHORIZED=0 ║ +║ ║ +╚══════════════════════════════════════════════════════════════════╝ +""") diff --git a/KiroProxy/scripts/debug_quota_info.py b/KiroProxy/scripts/debug_quota_info.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c2d0ae06895790bbdd5205f0783c657a573fb6 --- /dev/null +++ b/KiroProxy/scripts/debug_quota_info.py @@ -0,0 +1,54 @@ +"""调试额度信息获取""" +import asyncio +import json +from kiro_proxy.core.state import ProxyState + + +def debug_quota_info(): + """调试额度信息获取""" + + # 初始化状态管理器 + state = ProxyState() + + print("=== 调试账号额度信息 ===\n") + + for account in state.accounts[:2]: # 只查看前两个账号 + print(f"账号: {account.name} ({account.id})") + + # 获取状态信息 + status = account.get_status_info() + + if "quota" in status and status["quota"]: + quota = status["quota"] + print(f" - 额度状态: {quota.get('balance_status', 'unknown')}") + print(f" - 已用/总额: {quota.get('current_usage', 0)} / {quota.get('usage_limit', 0)}") + print(f" - 剩余额度: {quota.get('balance', 0)}") + print(f" - 更新时间: {quota.get('updated_at', 'unknown')}") + + # 检查重置时间字段 + print(f" - 下次重置时间: {quota.get('next_reset_date', '未设置')}") + print(f" - 格式化重置日期: {quota.get('reset_date_text', '未设置')}") + print(f" - 免费试用过期: {quota.get('free_trial_expiry', '未设置')}") + print(f" - 格式化过期日期: {quota.get('trial_expiry_text', '未设置')}") + print(f" - 奖励过期列表: {quota.get('bonus_expiries', [])}") + print(f" - 生效奖励数: {quota.get('active_bonuses', 0)}") + else: + print(" - 无额度信息") + if status.get("quota", {}).get("error"): + print(f" - 错误: {status['quota']['error']}") + + print() + + # 模拟 API 响应 + print("=== 模拟 Web API 响应 ===\n") + + accounts_status = state.get_accounts_status() + + # 只显示第一个账号的信息 + if accounts_status: + first_account = accounts_status[0] + print(json.dumps(first_account, indent=2, ensure_ascii=False, default=str)) + + +if __name__ == "__main__": + debug_quota_info() diff --git a/KiroProxy/scripts/get_models.py b/KiroProxy/scripts/get_models.py new file mode 100644 index 0000000000000000000000000000000000000000..be99477cdb7e933457d33ab06ee372b1498cd1b2 --- /dev/null +++ b/KiroProxy/scripts/get_models.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +"""获取 Kiro 支持的模型列表""" + +import json +import uuid +import httpx +from pathlib import Path + +TOKEN_PATH = Path.home() / ".aws/sso/cache/kiro-auth-token.json" +MACHINE_ID = "fa41d5def91e29225c73f6ea8ee0941a87bd812aae5239e3dde72c3ba7603a26" +MODELS_URL = "https://q.us-east-1.amazonaws.com/ListAvailableModels" + +def get_token(): + with open(TOKEN_PATH) as f: + return json.load(f).get("accessToken", "") + +def get_models(): + token = get_token() + headers = { + "content-type": "application/json", + "x-amz-user-agent": f"aws-sdk-js/1.0.27 KiroIDE-0.8.0-{MACHINE_ID}", + "amz-sdk-invocation-id": str(uuid.uuid4()), + "Authorization": f"Bearer {token}", + } + + # 尝试不同的参数 + params = {"origin": "AI_EDITOR"} + + with httpx.Client(verify=False, timeout=30) as client: + resp = client.get(MODELS_URL, headers=headers, params=params) + print(f"Status: {resp.status_code}") + print(f"Headers: {dict(resp.headers)}") + print(f"\nRaw response ({len(resp.content)} bytes):") + + # 尝试解析 + try: + data = resp.json() + print(json.dumps(data, indent=2, ensure_ascii=False)) + except: + # 可能是 event-stream 格式 + print(resp.content[:2000]) + +if __name__ == "__main__": + get_models() diff --git a/KiroProxy/scripts/proxy_server.py b/KiroProxy/scripts/proxy_server.py new file mode 100644 index 0000000000000000000000000000000000000000..e7be7422779f0d663e84a213a462484301b54a07 --- /dev/null +++ b/KiroProxy/scripts/proxy_server.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +""" +Kiro IDE 反向代理测试服务器 +用于测试是否能成功拦截和转发 Kiro 的 API 请求 +""" + +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse +import httpx +import uvicorn +import json +import logging +from datetime import datetime + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +app = FastAPI(title="Kiro Reverse Proxy Test") + +# 原始 Kiro API 地址(如果需要转发到真实服务器) +KIRO_API_BASE = "https://api.kiro.dev" + +# 记录所有请求 +request_log = [] + +@app.middleware("http") +async def log_requests(request: Request, call_next): + """记录所有进入的请求""" + body = await request.body() + + log_entry = { + "timestamp": datetime.now().isoformat(), + "method": request.method, + "url": str(request.url), + "path": request.url.path, + "headers": dict(request.headers), + "body": body.decode('utf-8', errors='ignore')[:2000] if body else None + } + + request_log.append(log_entry) + logger.info(f"📥 {request.method} {request.url.path}") + logger.info(f" Headers: {dict(request.headers)}") + if body: + logger.info(f" Body: {body.decode('utf-8', errors='ignore')[:500]}...") + + response = await call_next(request) + return response + +@app.get("/") +async def root(): + """健康检查""" + return {"status": "ok", "message": "Kiro Proxy Server Running", "requests_logged": len(request_log)} + +@app.get("/logs") +async def get_logs(): + """查看所有记录的请求""" + return {"total": len(request_log), "requests": request_log[-50:]} + +@app.get("/clear") +async def clear_logs(): + """清空日志""" + request_log.clear() + return {"message": "Logs cleared"} + +# 模拟认证成功响应 +@app.api_route("/auth/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) +async def mock_auth(request: Request, path: str): + """模拟认证端点""" + logger.info(f"🔐 Auth request: {path}") + return JSONResponse({ + "success": True, + "token": "mock-token-for-testing", + "expires_in": 3600 + }) + +# 模拟 AI 对话端点 +@app.post("/v1/chat/completions") +async def mock_chat_completions(request: Request): + """模拟 OpenAI 兼容的聊天接口""" + body = await request.json() + logger.info(f"💬 Chat request: {json.dumps(body, ensure_ascii=False)[:500]}") + + # 返回模拟响应 + return JSONResponse({ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": "kiro-proxy-test", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "🎉 反向代理测试成功!你的请求已被成功拦截。" + }, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + }) + +# 捕获所有其他请求 +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]) +async def catch_all(request: Request, path: str): + """捕获所有其他请求并记录""" + body = await request.body() + + logger.info(f"🎯 Caught: {request.method} /{path}") + + return JSONResponse({ + "proxy_status": "intercepted", + "method": request.method, + "path": f"/{path}", + "message": "请求已被反向代理捕获", + "headers_received": dict(request.headers) + }) + +if __name__ == "__main__": + print(""" +╔══════════════════════════════════════════════════════════════╗ +║ Kiro IDE 反向代理测试服务器 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 端口: 8000 ║ +║ 查看日志: http://127.0.0.1:8000/logs ║ +║ 清空日志: http://127.0.0.1:8000/clear ║ +╠══════════════════════════════════════════════════════════════╣ +║ 使用方法: ║ +║ 1. 修改 Kiro 的 JS 源码,将 api.kiro.dev 替换为 127.0.0.1:8000 ║ +║ 2. 或者修改 /etc/hosts 添加: 127.0.0.1 api.kiro.dev ║ +║ 3. 启动 Kiro,观察此终端的日志输出 ║ +╚══════════════════════════════════════════════════════════════╝ + """) + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/KiroProxy/scripts/test_kiro_proxy.py b/KiroProxy/scripts/test_kiro_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5017437c3bc0c7b4929d090be6d4b7eda5fe9b --- /dev/null +++ b/KiroProxy/scripts/test_kiro_proxy.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +"""测试 Kiro 反向代理""" + +import requests +import json + +PROXY_URL = "http://127.0.0.1:8000" + +def test_health(): + print("1. 测试健康检查...") + r = requests.get(f"{PROXY_URL}/") + print(f" ✅ {r.json()}") + +def test_token(): + print("\n2. 检查 Token 状态...") + r = requests.get(f"{PROXY_URL}/token/status") + data = r.json() + if data.get("valid"): + print(f" ✅ Token 有效,过期时间: {data.get('expires_at')}") + else: + print(f" ❌ Token 无效: {data.get('error')}") + +def test_models(): + print("\n3. 列出可用模型...") + r = requests.get(f"{PROXY_URL}/v1/models") + models = r.json().get("data", []) + for m in models: + print(f" - {m['id']}") + +def test_chat(): + print("\n4. 测试聊天接口...") + r = requests.post( + f"{PROXY_URL}/v1/chat/completions", + json={ + "model": "claude-sonnet-4", + "messages": [ + {"role": "user", "content": "说一句话测试"} + ] + }, + timeout=60 + ) + + if r.status_code == 200: + data = r.json() + content = data["choices"][0]["message"]["content"] + print(f" ✅ AI 回复: {content[:200]}...") + else: + print(f" ❌ 错误 {r.status_code}: {r.text}") + +if __name__ == "__main__": + print("=" * 50) + print("Kiro 反向代理测试") + print("=" * 50) + + try: + test_health() + test_token() + test_models() + test_chat() + print("\n" + "=" * 50) + print("测试完成") + print("=" * 50) + except requests.exceptions.ConnectionError: + print("\n❌ 连接失败!请先启动代理服务器:") + print(" source venv/bin/activate") + print(" python kiro_proxy.py") diff --git a/KiroProxy/scripts/test_proxy.py b/KiroProxy/scripts/test_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..026b925778ead53f4faff28b4fc705c5d7a401fc --- /dev/null +++ b/KiroProxy/scripts/test_proxy.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +"""测试反向代理是否正常工作""" + +import requests +import json + +PROXY_URL = "http://127.0.0.1:8000" + +def test_health(): + """测试健康检查""" + print("1. 测试健康检查...") + r = requests.get(f"{PROXY_URL}/") + print(f" ✅ {r.json()}") + +def test_chat(): + """测试聊天接口""" + print("\n2. 测试聊天接口...") + r = requests.post( + f"{PROXY_URL}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "Hello"}] + } + ) + print(f" ✅ {r.json()['choices'][0]['message']['content']}") + +def test_catch_all(): + """测试通用捕获""" + print("\n3. 测试任意路径捕获...") + r = requests.post( + f"{PROXY_URL}/api/v1/some/kiro/endpoint", + json={"test": "data"} + ) + print(f" ✅ {r.json()['message']}") + +def test_auth(): + """测试认证端点""" + print("\n4. 测试认证端点...") + r = requests.post(f"{PROXY_URL}/auth/login") + print(f" ✅ Token: {r.json()['token']}") + +def view_logs(): + """查看日志""" + print("\n5. 查看捕获的请求日志...") + r = requests.get(f"{PROXY_URL}/logs") + data = r.json() + print(f" ✅ 共捕获 {data['total']} 个请求") + +if __name__ == "__main__": + print("=" * 50) + print("Kiro 反向代理测试") + print("=" * 50) + + try: + test_health() + test_chat() + test_catch_all() + test_auth() + view_logs() + print("\n" + "=" * 50) + print("✅ 所有测试通过!反向代理工作正常") + print("=" * 50) + except requests.exceptions.ConnectionError: + print("\n❌ 连接失败!请先启动代理服务器:") + print(" python proxy_server.py") diff --git a/KiroProxy/scripts/test_thinking.py b/KiroProxy/scripts/test_thinking.py new file mode 100644 index 0000000000000000000000000000000000000000..04e7fde34270c2cbaf3146196ee09dfcfa090f16 --- /dev/null +++ b/KiroProxy/scripts/test_thinking.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +""" +测试思考功能 +""" +import asyncio +import json +import httpx + +async def test_thinking_feature(): + """测试思考功能""" + + # 测试数据 + test_data = { + "model": "claude-sonnet-4.5", + "messages": [ + { + "role": "user", + "content": "请解释什么是递归,并给出一个简单的例子。" + } + ], + "thinking": { + "thinking_type": "enabled", + "budget_tokens": 5000 + }, + "stream": True, + "max_tokens": 1000 + } + + print("发送思考功能测试请求...") + print(f"请求内容: {json.dumps(test_data, indent=2, ensure_ascii=False)}") + print("\n" + "="*60 + "\n") + + try: + async with httpx.AsyncClient(timeout=60) as client: + async with client.stream( + "POST", + "http://localhost:8080/v1/messages", + headers={ + "Content-Type": "application/json", + "x-api-key": "any", + "anthropic-version": "2023-06-01" + }, + json=test_data + ) as response: + + if response.status_code != 200: + print(f"错误: {response.status_code}") + print(await response.aread()) + return + + print("收到响应:\n") + + thinking_content = [] + text_content = [] + current_block = None + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data_str = line[6:] + + if data_str == "[DONE]": + break + + try: + data = json.loads(data_str) + event_type = data.get("type") + + if event_type == "content_block_start": + block_type = data.get("content_block", {}).get("type") + current_block = block_type + print(f"\n[开始 {block_type} 块]") + + elif event_type == "content_block_delta": + delta = data.get("delta", {}) + + if delta.get("type") == "thinking_delta": + thinking = delta.get("thinking", "") + thinking_content.append(thinking) + print(thinking, end="", flush=True) + + elif delta.get("type") == "text_delta": + text = delta.get("text", "") + text_content.append(text) + print(text, end="", flush=True) + + elif event_type == "content_block_stop": + print(f"\n[结束块]") + current_block = None + + elif event_type == "message_stop": + print("\n\n[响应完成]") + break + + elif event_type == "error": + print("\n\n[错误]") + print(json.dumps(data.get("error", data), ensure_ascii=False, indent=2)) + break + + except json.JSONDecodeError as e: + print(f"\n解析错误: {e}") + continue + + print("\n" + "="*60) + print("\n思考内容汇总:") + print("".join(thinking_content)) + + print("\n" + "-"*40) + print("\n回答内容汇总:") + print("".join(text_content)) + + except Exception as e: + print(f"请求失败: {e}") + +if __name__ == "__main__": + asyncio.run(test_thinking_feature()) diff --git a/KiroProxy/test_smart_mapping.py b/KiroProxy/test_smart_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..e16a4f44c2e99faf4e69523ea865453af1f62be2 --- /dev/null +++ b/KiroProxy/test_smart_mapping.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +"""测试智能模型映射功能""" + +import sys +sys.path.append('.') + +from kiro_proxy.config import map_model_name, detect_model_tier, get_best_model_by_tier + +def test_tier_detection(): + """测试等级检测功能""" + print("测试模型等级检测:") + + test_cases = [ + # Opus 等级 (最强) + ("claude-4-opus", "opus"), + ("gpt-o1-preview", "opus"), + ("gemini-1.5-pro", "opus"), + ("claude-3-opus-20240229", "opus"), + ("some-premium-model", "opus"), + + # Sonnet 等级 (平衡) + ("claude-3.5-sonnet", "sonnet"), + ("gpt-4o", "sonnet"), + ("gemini-2.0-flash", "sonnet"), + ("claude-4-standard", "sonnet"), + + # Haiku 等级 (快速) + ("claude-3-haiku", "haiku"), + ("gpt-4o-mini", "haiku"), + ("gpt-3.5-turbo", "haiku"), + ("claude-haiku-fast", "haiku"), + + # 未知模型 + ("unknown-model-xyz", "sonnet"), # 默认中等 + ("", "sonnet"), # 空值默认 + ] + + for model, expected in test_cases: + result = detect_model_tier(model) + status = "OK" if result == expected else "FAIL" + print(f" {status} {model:<25} -> {result:<6} (期望: {expected})") + +def test_dynamic_mapping(): + """测试动态模型映射(等级对等 + 智能降级)""" + print("\n测试动态模型映射(等级对等策略):") + + # 模拟不同的可用模型场景 + scenarios = [ + { + "name": "全部可用", + "available": {"claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"} + }, + { + "name": "缺少4.5版本", + "available": {"claude-sonnet-4", "claude-haiku-4.5", "auto"} + }, + { + "name": "仅有Haiku", + "available": {"claude-haiku-4.5", "auto"} + }, + { + "name": "仅有Sonnet-4", + "available": {"claude-sonnet-4", "auto"} + } + ] + + test_models = [ + ("claude-4-opus", "opus"), # 应该优先选择 sonnet-4.5 + ("gpt-4o", "sonnet"), # 应该优先选择 sonnet-4.5 + ("gpt-4o-mini", "haiku"), # 应该优先选择 haiku-4.5 + ("unknown-future-model", "sonnet") # 未知模型,默认 sonnet-4.5 + ] + + for scenario in scenarios: + print(f"\n 场景: {scenario['name']}") + print(f" 可用模型: {scenario['available']}") + + for model, expected_tier in test_models: + result = map_model_name(model, scenario['available']) + tier = detect_model_tier(model) + print(f" {model:<20} ({tier:<6}) -> {result}") + +def test_tier_mapping_logic(): + """测试等级对等映射逻辑""" + print("\n测试等级对等映射逻辑:") + + # 全部可用时的期望映射 + full_available = {"claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"} + + test_cases = [ + # 格式: (输入模型, 期望等级, 期望输出模型) + ("claude-4-opus", "opus", "claude-sonnet-4.5"), # Opus -> 最强 + ("gpt-4o", "sonnet", "claude-sonnet-4.5"), # Sonnet -> 高性能 + ("gpt-4o-mini", "haiku", "claude-haiku-4.5"), # Haiku -> 快速 + ("o1-preview", "opus", "claude-sonnet-4.5"), # O1 -> 最强 + ("claude-3.5-sonnet", "sonnet", "claude-sonnet-4.5"), # Sonnet -> 高性能 + ("gpt-3.5-turbo", "haiku", "claude-haiku-4.5"), # 3.5 -> 快速 + ] + + for model, expected_tier, expected_output in test_cases: + tier = detect_model_tier(model) + result = map_model_name(model, full_available) + tier_ok = "OK" if tier == expected_tier else "FAIL" + output_ok = "OK" if result == expected_output else "FAIL" + print(f" {tier_ok}/{output_ok} {model:<20} -> {tier:<6} -> {result}") + if tier != expected_tier: + print(f" 等级检测错误: 期望 {expected_tier}, 实际 {tier}") + if result != expected_output: + print(f" 映射错误: 期望 {expected_output}, 实际 {result}") + +def test_degradation_paths(): + """测试降级路径""" + print("\n测试降级路径:") + + degradation_scenarios = [ + { + "name": "Opus降级测试", + "model": "claude-4-opus", + "scenarios": [ + ({"claude-sonnet-4.5", "auto"}, "claude-sonnet-4.5"), # 首选可用 + ({"claude-sonnet-4", "auto"}, "claude-sonnet-4"), # 降级到次强 + ({"claude-haiku-4.5", "auto"}, "claude-haiku-4.5"), # 降级到快速 + ({"auto"}, "auto"), # 最终回退 + ] + }, + { + "name": "Haiku降级测试", + "model": "gpt-4o-mini", + "scenarios": [ + ({"claude-haiku-4.5", "auto"}, "claude-haiku-4.5"), # 首选可用 + ({"claude-sonnet-4", "auto"}, "claude-sonnet-4"), # 降级到标准 + ({"claude-sonnet-4.5", "auto"}, "claude-sonnet-4.5"), # 降级到高性能 + ({"auto"}, "auto"), # 最终回退 + ] + } + ] + + for test_group in degradation_scenarios: + print(f"\n {test_group['name']}:") + model = test_group['model'] + tier = detect_model_tier(model) + + for available, expected in test_group['scenarios']: + result = map_model_name(model, available) + status = "OK" if result == expected else "FAIL" + print(f" {status} 可用:{available} -> {result} (期望:{expected})") + +def test_backward_compatibility(): + """测试向后兼容性""" + print("\n测试向后兼容性:") + + # 原有的精确映射应该仍然工作 + legacy_tests = [ + ("gpt-4o", "claude-sonnet-4"), + ("claude-3-5-sonnet-20241022", "claude-sonnet-4"), + ("o1-preview", "claude-sonnet-4.5"), + ("gemini-1.5-pro", "claude-sonnet-4.5"), + ] + + for model, expected in legacy_tests: + result = map_model_name(model) + status = "OK" if result == expected else "FAIL" + print(f" {status} {model:<25} -> {result:<20} (期望: {expected})") + +def test_edge_cases(): + """测试边界情况""" + print("\n测试边界情况:") + + edge_cases = [ + ("", "auto"), # 空字符串 + (None, "auto"), # None值 (需要修改函数处理) + ("CLAUDE-4-OPUS", "claude-sonnet-4.5"), # 大写 + ("gpt-4o-MINI-turbo", "claude-haiku-4.5"), # 混合大小写 + ("claude_sonnet_4", "claude-sonnet-4"), # 下划线 + ] + + for model, expected in edge_cases: + try: + result = map_model_name(model or "") + tier = detect_model_tier(model or "") + status = "OK" if result == expected else "FAIL" + print(f" {status} {str(model):<25} ({tier}) -> {result}") + except Exception as e: + print(f" ERROR {str(model):<25} -> 错误: {e}") + +if __name__ == "__main__": + print("KiroProxy 智能模型映射测试(等级对等策略)\n") + + test_tier_detection() + test_tier_mapping_logic() + test_degradation_paths() + test_dynamic_mapping() + test_backward_compatibility() + test_edge_cases() + + print("\n测试完成!") \ No newline at end of file diff --git a/KiroProxy/tests/__init__.py b/KiroProxy/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31408c77a5cf48146b6007cddab371e5cc16f846 --- /dev/null +++ b/KiroProxy/tests/__init__.py @@ -0,0 +1 @@ +# 测试模块 diff --git a/KiroProxy/tests/test_account_management.py b/KiroProxy/tests/test_account_management.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb44bf90880622e23fc6299603a82b792b36ee6 --- /dev/null +++ b/KiroProxy/tests/test_account_management.py @@ -0,0 +1,527 @@ +"""账号管理增强功能属性测试 + +测试覆盖: +- Property 1: OAuth URL Generation Produces Valid PKCE Parameters +- Property 2: Token Response Parsing Extracts All Required Fields +- Property 3: Account Edit Validation and Persistence +- Property 4: Import Validation Based on AuthMethod +- Property 5: Batch Import Processes All Valid Entries +- Property 6: Token Refresh Method Dispatch +- Property 7: Token Refresh Updates Credentials +- Property 8: Provider Field Persistence +- Property 9: Compression State Tracking and Caching +- Property 10: Progressive Compression Strategy +""" +import pytest +import json +import hashlib +import base64 +import secrets +from unittest.mock import Mock, AsyncMock, patch +from datetime import datetime, timezone, timedelta + + +# ==================== Property 1 & 2: Social Auth Tests ==================== + +class TestSocialAuthOAuthURL: + """Property 1: OAuth URL Generation Produces Valid PKCE Parameters""" + + def test_code_verifier_length(self): + """code_verifier 应该是 43-128 字符""" + from kiro_proxy.auth.device_flow import _generate_code_verifier + verifier = _generate_code_verifier() + assert 43 <= len(verifier) <= 128 + + def test_code_verifier_is_url_safe(self): + """code_verifier 应该只包含 URL 安全字符""" + from kiro_proxy.auth.device_flow import _generate_code_verifier + verifier = _generate_code_verifier() + # URL safe base64 字符集 + valid_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_') + assert all(c in valid_chars for c in verifier) + + def test_code_challenge_is_sha256_of_verifier(self): + """code_challenge 应该是 code_verifier 的 SHA256 哈希""" + from kiro_proxy.auth.device_flow import _generate_code_verifier, _generate_code_challenge + verifier = _generate_code_verifier() + challenge = _generate_code_challenge(verifier) + + # 手动计算验证 + expected = base64.urlsafe_b64encode( + hashlib.sha256(verifier.encode()).digest() + ).rstrip(b'=').decode() + + assert challenge == expected + + def test_oauth_state_is_unique(self): + """每次生成的 state 应该是唯一的""" + from kiro_proxy.auth.device_flow import _generate_oauth_state + states = [_generate_oauth_state() for _ in range(100)] + assert len(set(states)) == 100 + + @pytest.mark.asyncio + async def test_start_social_auth_returns_valid_url(self): + """start_social_auth 应该返回有效的登录 URL""" + from kiro_proxy.auth.device_flow import start_social_auth + + success, result = await start_social_auth("google") + + assert success + assert "login_url" in result + assert "state" in result + assert "provider" in result + assert result["provider"] == "Google" + + # 验证 URL 包含必要参数 + url = result["login_url"] + assert "idp=Google" in url + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + assert "state=" in url + assert "redirect_uri=" in url + + @pytest.mark.asyncio + async def test_start_social_auth_github(self): + """GitHub 登录应该正确设置 provider""" + from kiro_proxy.auth.device_flow import start_social_auth + + success, result = await start_social_auth("github") + + assert success + assert result["provider"] == "Github" + assert "idp=Github" in result["login_url"] + + +class TestTokenResponseParsing: + """Property 2: Token Response Parsing Extracts All Required Fields""" + + def test_credentials_from_file_extracts_all_fields(self): + """from_file 应该提取所有必要字段""" + from kiro_proxy.credential.types import KiroCredentials + import tempfile + import os + + test_data = { + "accessToken": "test_access_token", + "refreshToken": "test_refresh_token", + "profileArn": "arn:aws:test", + "expiresAt": "2025-01-10T00:00:00Z", + "region": "us-west-2", + "authMethod": "social", + "provider": "Google", + "clientId": "test_client_id", + "clientSecret": "test_client_secret", + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(test_data, f) + temp_path = f.name + + try: + creds = KiroCredentials.from_file(temp_path) + + assert creds.access_token == "test_access_token" + assert creds.refresh_token == "test_refresh_token" + assert creds.profile_arn == "arn:aws:test" + assert creds.region == "us-west-2" + assert creds.auth_method == "social" + assert creds.provider == "Google" + assert creds.client_id == "test_client_id" + assert creds.client_secret == "test_client_secret" + finally: + os.unlink(temp_path) + + def test_credentials_to_dict_includes_provider(self): + """to_dict 应该包含 provider 字段""" + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials( + access_token="test", + refresh_token="test", + provider="Github" + ) + + data = creds.to_dict() + assert data["provider"] == "Github" + + def test_credentials_to_dict_excludes_none_provider(self): + """to_dict 不应该包含 None 的 provider""" + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials( + access_token="test", + refresh_token="test", + provider=None + ) + + data = creds.to_dict() + assert "provider" not in data or data.get("provider") is None + + +# ==================== Property 6 & 7: Token Refresh Tests ==================== + +class TestTokenRefreshDispatch: + """Property 6: Token Refresh Method Dispatch""" + + def test_social_auth_uses_social_refresh(self): + """social authMethod 应该使用 refresh_social_token""" + from kiro_proxy.credential.refresher import TokenRefresher + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials( + refresh_token="test_refresh_token_" + "x" * 100, + auth_method="social" + ) + refresher = TokenRefresher(creds) + + # 验证 URL + url = refresher.get_refresh_url() + assert "auth.desktop.kiro.dev/refreshToken" in url + + def test_idc_auth_uses_oidc_refresh(self): + """idc authMethod 应该使用 OIDC 端点""" + from kiro_proxy.credential.refresher import TokenRefresher + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials( + refresh_token="test_refresh_token_" + "x" * 100, + auth_method="idc", + region="us-east-1" + ) + refresher = TokenRefresher(creds) + + url = refresher.get_refresh_url() + assert "oidc.us-east-1.amazonaws.com/token" in url + + def test_validate_refresh_token_rejects_empty(self): + """空的 refresh_token 应该被拒绝""" + from kiro_proxy.credential.refresher import TokenRefresher + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials(refresh_token="") + refresher = TokenRefresher(creds) + + valid, error = refresher.validate_refresh_token() + assert not valid + assert "为空" in error or "缺少" in error + + def test_validate_refresh_token_rejects_truncated(self): + """截断的 refresh_token 应该被拒绝""" + from kiro_proxy.credential.refresher import TokenRefresher + from kiro_proxy.credential.types import KiroCredentials + + creds = KiroCredentials(refresh_token="short_token...") + refresher = TokenRefresher(creds) + + valid, error = refresher.validate_refresh_token() + assert not valid + assert "截断" in error + + +class TestTokenRefreshUpdates: + """Property 7: Token Refresh Updates Credentials""" + + def test_credentials_save_preserves_existing_data(self): + """save_to_file 应该保留现有数据""" + from kiro_proxy.credential.types import KiroCredentials + import tempfile + import os + + # 创建初始文件 + initial_data = { + "accessToken": "old_token", + "customField": "should_be_preserved" + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(initial_data, f) + temp_path = f.name + + try: + # 更新凭证 + creds = KiroCredentials( + access_token="new_token", + refresh_token="new_refresh" + ) + creds.save_to_file(temp_path) + + # 验证 + with open(temp_path) as f: + saved_data = json.load(f) + + assert saved_data["accessToken"] == "new_token" + assert saved_data["refreshToken"] == "new_refresh" + assert saved_data["customField"] == "should_be_preserved" + finally: + os.unlink(temp_path) + + +# ==================== Property 8: Provider Field Persistence ==================== + +class TestProviderFieldPersistence: + """Property 8: Provider Field Persistence""" + + def test_provider_field_roundtrip(self): + """provider 字段应该能正确保存和加载""" + from kiro_proxy.credential.types import KiroCredentials + import tempfile + import os + + creds = KiroCredentials( + access_token="test", + refresh_token="test", + provider="Google", + auth_method="social" + ) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + temp_path = f.name + + try: + creds.save_to_file(temp_path) + loaded = KiroCredentials.from_file(temp_path) + + assert loaded.provider == "Google" + assert loaded.auth_method == "social" + finally: + os.unlink(temp_path) + + def test_provider_in_status_info(self): + """get_status_info 应该包含 provider 字段""" + from kiro_proxy.core.account import Account + from kiro_proxy.credential.types import KiroCredentials + import tempfile + import os + + # 创建测试凭证文件 + test_data = { + "accessToken": "test", + "refreshToken": "test", + "provider": "Github", + "authMethod": "social" + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(test_data, f) + temp_path = f.name + + try: + account = Account( + id="test_id", + name="Test Account", + token_path=temp_path + ) + account.load_credentials() + + status = account.get_status_info() + assert status["provider"] == "Github" + finally: + os.unlink(temp_path) + + +# ==================== Property 9 & 10: Compression Tests ==================== + +class TestCompressionStateTracking: + """Property 9: Compression State Tracking and Caching""" + + def test_hash_history_is_deterministic(self): + """相同历史应该产生相同哈希""" + from kiro_proxy.core.history_manager import HistoryManager + + manager = HistoryManager() + history = [ + {"userInputMessage": {"content": "Hello"}}, + {"assistantResponseMessage": {"content": "Hi"}} + ] + + hash1 = manager._hash_history(history) + hash2 = manager._hash_history(history) + + assert hash1 == hash2 + + def test_hash_history_changes_with_content(self): + """不同历史应该产生不同哈希""" + from kiro_proxy.core.history_manager import HistoryManager + + manager = HistoryManager() + # 使用不同长度的内容确保哈希不同 + history1 = [{"userInputMessage": {"content": "Hello"}}] + history2 = [{"userInputMessage": {"content": "Hello World, this is longer"}}] + + hash1 = manager._hash_history(history1) + hash2 = manager._hash_history(history2) + + assert hash1 != hash2 + + @pytest.mark.asyncio + async def test_compression_cache_prevents_repeated_compression(self): + """压缩缓存应该防止重复压缩""" + from kiro_proxy.core.history_manager import HistoryManager + + manager = HistoryManager() + history = [{"userInputMessage": {"content": "x" * 1000}} for _ in range(50)] + + # 第一次压缩 + result1, should_retry1 = await manager.handle_length_error_async(history, 0, None) + + # 第二次压缩相同内容 + result2, should_retry2 = await manager.handle_length_error_async(history, 0, None) + + # 第二次应该检测到重复并跳过 + # (由于缓存机制,第二次可能返回 False) + + +class TestProgressiveCompression: + """Property 10: Progressive Compression Strategy""" + + def test_max_retries_stops_compression(self): + """达到最大重试次数应该停止压缩""" + from kiro_proxy.core.history_manager import HistoryManager, HistoryConfig + + config = HistoryConfig(max_retries=3) + manager = HistoryManager(config) + history = [{"userInputMessage": {"content": "test"}}] + + # 超过最大重试次数 + result, should_retry = manager.handle_length_error(history, 5) + + assert not should_retry + + def test_small_history_not_compressed(self): + """小于目标大小的历史不应该被压缩""" + from kiro_proxy.core.history_manager import HistoryManager + + manager = HistoryManager() + history = [{"userInputMessage": {"content": "small"}}] + + result, should_retry = manager.handle_length_error(history, 0) + + # 小历史不需要压缩 + assert len(result) == len(history) + + @pytest.mark.asyncio + async def test_compression_reduces_size(self): + """压缩应该减少历史大小""" + from kiro_proxy.core.history_manager import HistoryManager + + manager = HistoryManager() + # 创建大历史 + history = [ + {"userInputMessage": {"content": f"Message {i}: " + "x" * 500}} + for i in range(100) + ] + + original_size = len(json.dumps(history)) + + # 模拟 API 调用 + async def mock_api_caller(prompt): + return "Summary of conversation" + + result, should_retry = await manager.handle_length_error_async( + history, 0, mock_api_caller + ) + + result_size = len(json.dumps(result)) + + # 压缩后应该更小 + assert result_size < original_size + + +# ==================== Property 3: Account Edit Tests ==================== + +class TestAccountEditValidation: + """Property 3: Account Edit Validation and Persistence""" + + def test_empty_name_not_updated(self): + """空名称不应该更新""" + # 这个测试需要模拟 API 调用 + pass + + def test_invalid_provider_rejected(self): + """无效的 provider 应该被拒绝""" + # 只允许 Google, Github, 或空 + valid_providers = [None, "", "Google", "Github"] + invalid_providers = ["facebook", "twitter", "invalid"] + + for p in valid_providers: + assert p in valid_providers + + for p in invalid_providers: + assert p not in valid_providers + + +# ==================== Property 4 & 5: Import Tests ==================== + +class TestImportValidation: + """Property 4: Import Validation Based on AuthMethod""" + + def test_idc_requires_client_credentials(self): + """IDC 认证应该需要 client_id 和 client_secret""" + # IDC 认证验证逻辑 + auth_method = "idc" + client_id = "" + client_secret = "" + + # 应该失败 + is_valid = not (auth_method == "idc" and (not client_id or not client_secret)) + assert not is_valid + + def test_social_does_not_require_client_credentials(self): + """Social 认证不需要 client_id 和 client_secret""" + auth_method = "social" + client_id = "" + client_secret = "" + + # 应该通过 + is_valid = not (auth_method == "idc" and (not client_id or not client_secret)) + assert is_valid + + def test_refresh_token_required(self): + """refresh_token 是必填的""" + refresh_token = "" + + is_valid = bool(refresh_token) + assert not is_valid + + +class TestBatchImport: + """Property 5: Batch Import Processes All Valid Entries""" + + def test_batch_import_skips_duplicates(self): + """批量导入应该跳过重复的 refresh_token""" + existing_tokens = {"token1", "token2"} + new_tokens = ["token1", "token3", "token4"] + + imported = [] + skipped = [] + + for token in new_tokens: + if token in existing_tokens: + skipped.append(token) + else: + imported.append(token) + existing_tokens.add(token) + + assert len(imported) == 2 + assert len(skipped) == 1 + assert "token1" in skipped + + def test_batch_import_continues_on_error(self): + """批量导入应该在单个错误后继续处理""" + accounts = [ + {"refresh_token": "valid1"}, + {"refresh_token": ""}, # 无效 + {"refresh_token": "valid2"}, + ] + + success = 0 + failed = 0 + + for acc in accounts: + if acc["refresh_token"]: + success += 1 + else: + failed += 1 + + assert success == 2 + assert failed == 1 diff --git a/KiroProxy/tests/test_account_selector.py b/KiroProxy/tests/test_account_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..12213ea5783bb7fbf654926f94e3179e72b8a122 --- /dev/null +++ b/KiroProxy/tests/test_account_selector.py @@ -0,0 +1,511 @@ +"""AccountSelector 属性测试和单元测试 + +Property 4: 最少额度优先选择 +Property 5: 优先账号选择 +Property 6: 优先账号验证 +""" +import json +import os +import time +import tempfile +from pathlib import Path +from dataclasses import dataclass +from typing import Optional, Set + +import pytest +from hypothesis import given, strategies as st, settings, assume + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.core.quota_cache import QuotaCache, CachedQuota +from kiro_proxy.core.account_selector import AccountSelector, SelectionStrategy + + +# ============== Mock Account 类 ============== + +@dataclass +class MockAccount: + """模拟账号类,用于测试""" + id: str + name: str = "" + enabled: bool = True + request_count: int = 0 + _available: bool = True + + def is_available(self) -> bool: + return self.enabled and self._available + + +# ============== 数据生成策略 ============== + +@st.composite +def account_id_strategy(draw): + """生成有效的账号ID""" + return draw(st.text( + alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'), + min_size=1, + max_size=16 + )) + + +@st.composite +def mock_account_strategy(draw, account_id: Optional[str] = None): + """生成模拟账号""" + if account_id is None: + account_id = draw(account_id_strategy()) + return MockAccount( + id=account_id, + name=f"Account {account_id}", + enabled=draw(st.booleans()), + request_count=draw(st.integers(min_value=0, max_value=10000)), + _available=draw(st.booleans()) + ) + + +@st.composite +def accounts_with_quotas_strategy(draw, min_accounts=2, max_accounts=10): + """生成账号列表和对应的额度缓存""" + num_accounts = draw(st.integers(min_value=min_accounts, max_value=max_accounts)) + + accounts = [] + quotas = {} + + for i in range(num_accounts): + account_id = f"acc_{i}" + account = MockAccount( + id=account_id, + name=f"Account {i}", + enabled=True, + request_count=draw(st.integers(min_value=0, max_value=1000)), + _available=True + ) + accounts.append(account) + + # 生成额度信息 + balance = draw(st.floats(min_value=0.0, max_value=1000.0, allow_nan=False, allow_infinity=False)) + quotas[account_id] = CachedQuota( + account_id=account_id, + usage_limit=1000.0, + current_usage=1000.0 - balance, + balance=balance, + updated_at=time.time() + ) + + return accounts, quotas + + +# ============== 默认策略:随机(避免单账号压力过大) ============== + +class TestDefaultRandomStrategy: + def test_default_strategy_is_random(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + assert selector.strategy == SelectionStrategy.RANDOM + + def test_legacy_lowest_balance_migrates_to_random_when_no_priority(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + Path(priority_file).write_text( + json.dumps( + {"version": "1.0", "priority_accounts": [], "strategy": "lowest_balance"}, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + assert selector.strategy == SelectionStrategy.RANDOM + saved = json.loads(Path(priority_file).read_text(encoding="utf-8")) + assert saved.get("strategy") == SelectionStrategy.RANDOM.value + + def test_random_strategy_avoids_consecutive_same_account(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + accounts = [ + MockAccount(id="acc_1", _available=True, enabled=True), + MockAccount(id="acc_2", _available=True, enabled=True), + MockAccount(id="acc_3", _available=True, enabled=True), + ] + + ids = [] + for _ in range(30): + selected = selector.select(accounts) + assert selected is not None + ids.append(selected.id) + + assert all(a != b for a, b in zip(ids, ids[1:])) + + +# ============== Property 4: 最少额度优先选择 ============== +# **Validates: Requirements 3.1, 3.3** + +class TestLowestBalanceSelection: + """Property 4: 最少额度优先选择测试""" + + @given(data=st.data()) + @settings(max_examples=100) + def test_selects_lowest_balance_account(self, data): + """ + Property 4: 最少额度优先选择 + *对于任意*可用账号列表(无优先账号配置时),选择器应返回剩余额度最少的账号。 + 如果存在多个相同最少额度的账号,应返回请求次数最少的账号。 + + **Validates: Requirements 3.1, 3.3** + """ + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + selector.strategy = SelectionStrategy.LOWEST_BALANCE + + # 生成账号和额度 + accounts, quotas = data.draw(accounts_with_quotas_strategy(min_accounts=2, max_accounts=5)) + + # 设置缓存 + for account_id, quota in quotas.items(): + cache.set(account_id, quota) + + # 选择账号 + selected = selector.select(accounts) + + # 验证选择了余额最少的账号 + assert selected is not None + + selected_quota = quotas[selected.id] + for account in accounts: + if account.is_available(): + other_quota = quotas[account.id] + # 选中的账号余额应该 <= 其他账号 + if other_quota.balance < selected_quota.balance: + # 如果有更低余额的账号,测试失败 + assert False, f"应该选择余额更低的账号 {account.id}" + elif other_quota.balance == selected_quota.balance: + # 余额相同时,请求数应该 <= 其他账号 + assert selected.request_count <= account.request_count + + def test_selects_lowest_balance_simple(self): + """简单场景:选择余额最少的账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + selector.strategy = SelectionStrategy.LOWEST_BALANCE + + # 创建账号 + accounts = [ + MockAccount(id="acc_1", request_count=10, _available=True, enabled=True), + MockAccount(id="acc_2", request_count=5, _available=True, enabled=True), + MockAccount(id="acc_3", request_count=20, _available=True, enabled=True), + ] + + # 设置额度:acc_2 余额最少 + cache.set("acc_1", CachedQuota(account_id="acc_1", balance=500.0, updated_at=time.time())) + cache.set("acc_2", CachedQuota(account_id="acc_2", balance=100.0, updated_at=time.time())) + cache.set("acc_3", CachedQuota(account_id="acc_3", balance=800.0, updated_at=time.time())) + + selected = selector.select(accounts) + assert selected is not None + assert selected.id == "acc_2" + + def test_same_balance_selects_least_requests(self): + """余额相同时选择请求数最少的账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + selector.strategy = SelectionStrategy.LOWEST_BALANCE + + accounts = [ + MockAccount(id="acc_1", request_count=100, _available=True, enabled=True), + MockAccount(id="acc_2", request_count=50, _available=True, enabled=True), + MockAccount(id="acc_3", request_count=200, _available=True, enabled=True), + ] + + # 所有账号余额相同 + for acc in accounts: + cache.set(acc.id, CachedQuota(account_id=acc.id, balance=500.0, updated_at=time.time())) + + selected = selector.select(accounts) + assert selected is not None + assert selected.id == "acc_2" # 请求数最少 + + +# ============== Property 5: 优先账号选择 ============== +# **Validates: Requirements 3.2, 4.3, 4.4** + +class TestPriorityAccountSelection: + """Property 5: 优先账号选择测试""" + + @given(data=st.data()) + @settings(max_examples=100) + def test_priority_account_selected_first(self, data): + """ + Property 5: 优先账号选择 + *对于任意*可用账号列表和优先账号配置,如果优先账号列表中存在可用账号, + 选择器应按优先级顺序返回第一个可用的优先账号; + 如果所有优先账号都不可用,应回退到最少额度优先策略。 + + **Validates: Requirements 3.2, 4.3, 4.4** + """ + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + # 生成账号 + num_accounts = data.draw(st.integers(min_value=3, max_value=6)) + accounts = [] + for i in range(num_accounts): + account_id = f"acc_{i}" + accounts.append(MockAccount( + id=account_id, + enabled=True, + _available=True, + request_count=data.draw(st.integers(min_value=0, max_value=100)) + )) + cache.set(account_id, CachedQuota( + account_id=account_id, + balance=data.draw(st.floats(min_value=100.0, max_value=1000.0, allow_nan=False, allow_infinity=False)), + updated_at=time.time() + )) + + # 随机选择一些账号作为优先账号 + priority_count = data.draw(st.integers(min_value=1, max_value=min(3, num_accounts))) + priority_ids = [accounts[i].id for i in range(priority_count)] + + valid_ids = {acc.id for acc in accounts} + selector.set_priority_accounts(priority_ids, valid_ids) + + # 选择账号 + selected = selector.select(accounts) + + # 验证选择了第一个可用的优先账号 + assert selected is not None + + # 找到第一个可用的优先账号 + first_available_priority = None + for pid in priority_ids: + for acc in accounts: + if acc.id == pid and acc.is_available(): + first_available_priority = acc + break + if first_available_priority: + break + + if first_available_priority: + assert selected.id == first_available_priority.id + + def test_priority_fallback_to_lowest_balance(self): + """优先账号不可用时回退到最少额度策略""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + selector.strategy = SelectionStrategy.LOWEST_BALANCE + + accounts = [ + MockAccount(id="acc_1", _available=False, enabled=True), # 优先但不可用 + MockAccount(id="acc_2", _available=True, enabled=True), + MockAccount(id="acc_3", _available=True, enabled=True), + ] + + cache.set("acc_1", CachedQuota(account_id="acc_1", balance=1000.0, updated_at=time.time())) + cache.set("acc_2", CachedQuota(account_id="acc_2", balance=200.0, updated_at=time.time())) + cache.set("acc_3", CachedQuota(account_id="acc_3", balance=500.0, updated_at=time.time())) + + # 设置 acc_1 为优先账号 + selector.set_priority_accounts(["acc_1"], {"acc_1", "acc_2", "acc_3"}) + + selected = selector.select(accounts) + + # 优先账号不可用,应该选择余额最少的 acc_2 + assert selected is not None + assert selected.id == "acc_2" + + +# ============== Property 6: 优先账号验证 ============== +# **Validates: Requirements 4.2** + +class TestPriorityAccountValidation: + """Property 6: 优先账号验证测试""" + + @given( + valid_ids=st.lists(account_id_strategy(), min_size=1, max_size=5, unique=True), + invalid_id=account_id_strategy() + ) + @settings(max_examples=100) + def test_invalid_account_rejected(self, valid_ids: list, invalid_id: str): + """ + Property 6: 优先账号验证 + *对于任意*账号ID,设置为优先账号时,如果该账号不存在或未启用, + 操作应失败并返回错误;如果账号存在且已启用,操作应成功。 + + **Validates: Requirements 4.2** + """ + # 确保 invalid_id 不在 valid_ids 中 + assume(invalid_id not in valid_ids) + + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + valid_set = set(valid_ids) + + # 测试添加无效账号 + success, msg = selector.add_priority_account(invalid_id, valid_account_ids=valid_set) + assert not success, "添加无效账号应该失败" + + # 测试添加有效账号 + valid_id = valid_ids[0] + success, msg = selector.add_priority_account(valid_id, valid_account_ids=valid_set) + assert success, "添加有效账号应该成功" + + def test_set_priority_validates_all_accounts(self): + """设置优先账号列表时验证所有账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + valid_ids = {"acc_1", "acc_2", "acc_3"} + + # 包含无效账号的列表应该失败 + success, msg = selector.set_priority_accounts( + ["acc_1", "invalid_acc"], + valid_account_ids=valid_ids + ) + assert not success + + # 全部有效的列表应该成功 + success, msg = selector.set_priority_accounts( + ["acc_1", "acc_2"], + valid_account_ids=valid_ids + ) + assert success + + +# ============== 单元测试:空账号列表和边界情况 ============== +# **Validates: Requirements 3.4** + +class TestEdgeCases: + """边界情况单元测试""" + + def test_empty_account_list(self): + """空账号列表应返回 None""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + selected = selector.select([]) + assert selected is None + + def test_all_accounts_unavailable(self): + """所有账号不可用时应返回 None""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + accounts = [ + MockAccount(id="acc_1", _available=False, enabled=True), + MockAccount(id="acc_2", _available=False, enabled=True), + ] + + selected = selector.select(accounts) + assert selected is None + + def test_remove_priority_account(self): + """移除优先账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + selector.set_priority_accounts(["acc_1", "acc_2"], None) + assert "acc_1" in selector.get_priority_accounts() + + success, _ = selector.remove_priority_account("acc_1") + assert success + assert "acc_1" not in selector.get_priority_accounts() + + # 移除不存在的账号应该失败 + success, _ = selector.remove_priority_account("acc_1") + assert not success + + def test_reorder_priority_accounts(self): + """重新排序优先账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + selector.set_priority_accounts(["acc_1", "acc_2", "acc_3"], None) + + # 正确的重排序 + success, _ = selector.reorder_priority(["acc_3", "acc_1", "acc_2"]) + assert success + assert selector.get_priority_accounts() == ["acc_3", "acc_1", "acc_2"] + + # 缺少账号的重排序应该失败 + success, _ = selector.reorder_priority(["acc_3", "acc_1"]) + assert not success + + def test_priority_order(self): + """获取优先级顺序""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + priority_file = os.path.join(tmpdir, "priority.json") + + cache = QuotaCache(cache_file=cache_file) + selector = AccountSelector(quota_cache=cache, priority_file=priority_file) + + selector.set_priority_accounts(["acc_1", "acc_2", "acc_3"], None) + + assert selector.get_priority_order("acc_1") == 1 + assert selector.get_priority_order("acc_2") == 2 + assert selector.get_priority_order("acc_3") == 3 + assert selector.get_priority_order("acc_4") is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/KiroProxy/tests/test_model_mapping.py b/KiroProxy/tests/test_model_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..625ccaee866f4f634422be91c5ec5285ca10949c --- /dev/null +++ b/KiroProxy/tests/test_model_mapping.py @@ -0,0 +1,25 @@ +import pytest + + +def test_map_model_name_downgrades_opus(): + from kiro_proxy.config import map_model_name + + assert map_model_name("claude-opus-4.5") == "claude-sonnet-4.5" + assert map_model_name("claude-3-opus-20240229") == "claude-sonnet-4.5" + assert map_model_name("claude-3-opus-latest") == "claude-sonnet-4.5" + assert map_model_name("claude-4-opus") == "claude-sonnet-4.5" + assert map_model_name("o1") == "claude-sonnet-4.5" + assert map_model_name("o1-preview") == "claude-sonnet-4.5" + assert map_model_name("opus") == "claude-sonnet-4.5" + + +@pytest.mark.asyncio +async def test_models_fallback_does_not_advertise_opus(monkeypatch): + from kiro_proxy.routers import protocols + + monkeypatch.setattr(protocols.state, "get_available_account", lambda *args, **kwargs: None) + + resp = await protocols.models() + model_ids = {m["id"] for m in resp.get("data", [])} + assert "claude-opus-4.5" not in model_ids + diff --git a/KiroProxy/tests/test_quota_cache.py b/KiroProxy/tests/test_quota_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..246181d3e4a60de3911110e724f8c1fccb9518f0 --- /dev/null +++ b/KiroProxy/tests/test_quota_cache.py @@ -0,0 +1,479 @@ +"""QuotaCache 属性测试和单元测试 + +Property 1: 缓存存储完整性 - 存储后读取应返回完整数据 +Property 2: 缓存持久化往返 - 保存后加载应产生等价状态 +""" +import os +import time +import tempfile +from pathlib import Path + +import pytest +from hypothesis import given, strategies as st, settings, assume + +# 添加项目路径 +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.core.quota_cache import ( + QuotaCache, CachedQuota, DEFAULT_CACHE_MAX_AGE +) + + +# ============== 数据生成策略 ============== + +# 固定的时间戳范围,避免 hypothesis 的 flaky 问题 +FIXED_MAX_TIMESTAMP = 2000000000.0 # 约 2033 年 + + +@st.composite +def valid_quota_strategy(draw): + """生成有效的 CachedQuota 数据""" + usage_limit = draw(st.floats(min_value=0.0, max_value=10000.0, allow_nan=False, allow_infinity=False)) + current_usage = draw(st.floats(min_value=0.0, max_value=usage_limit, allow_nan=False, allow_infinity=False)) + balance = usage_limit - current_usage + usage_percent = (current_usage / usage_limit * 100) if usage_limit > 0 else 0.0 + + free_trial_limit = draw(st.floats(min_value=0.0, max_value=1000.0, allow_nan=False, allow_infinity=False)) + free_trial_usage = draw(st.floats(min_value=0.0, max_value=free_trial_limit, allow_nan=False, allow_infinity=False)) + + bonus_limit = draw(st.floats(min_value=0.0, max_value=500.0, allow_nan=False, allow_infinity=False)) + bonus_usage = draw(st.floats(min_value=0.0, max_value=bonus_limit, allow_nan=False, allow_infinity=False)) + + return CachedQuota( + account_id=draw(st.text(alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'), min_size=1, max_size=32)), + usage_limit=usage_limit, + current_usage=current_usage, + balance=balance, + usage_percent=round(usage_percent, 2), + is_low_balance=balance < usage_limit * 0.2 if usage_limit > 0 else False, + subscription_title=draw(st.text(min_size=0, max_size=50)), + free_trial_limit=free_trial_limit, + free_trial_usage=free_trial_usage, + bonus_limit=bonus_limit, + bonus_usage=bonus_usage, + updated_at=draw(st.floats(min_value=0.0, max_value=FIXED_MAX_TIMESTAMP, allow_nan=False, allow_infinity=False)), + error=draw(st.one_of(st.none(), st.text(min_size=1, max_size=100))) + ) + + +@st.composite +def account_id_strategy(draw): + """生成有效的账号ID""" + return draw(st.text( + alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'), + min_size=1, + max_size=32 + )) + + +# ============== Property 1: 缓存存储完整性 ============== +# **Validates: Requirements 1.2, 2.3** + +class TestCacheStorageIntegrity: + """Property 1: 缓存存储完整性测试""" + + @given(quota=valid_quota_strategy()) + @settings(max_examples=100) + def test_set_then_get_returns_complete_data(self, quota: CachedQuota): + """ + Property 1: 缓存存储完整性 + *对于任意*有效的额度信息,当存储到 QuotaCache 后, + 读取该账号的缓存应返回包含所有必要字段的完整数据。 + + **Validates: Requirements 1.2, 2.3** + """ + # 使用临时文件避免影响真实缓存 + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + + # 存储 + cache.set(quota.account_id, quota) + + # 读取 + retrieved = cache.get(quota.account_id) + + # 验证完整性 + assert retrieved is not None, "缓存应该存在" + assert retrieved.account_id == quota.account_id, "account_id 应该一致" + assert retrieved.usage_limit == quota.usage_limit, "usage_limit 应该一致" + assert retrieved.current_usage == quota.current_usage, "current_usage 应该一致" + assert retrieved.balance == quota.balance, "balance 应该一致" + assert retrieved.updated_at == quota.updated_at, "updated_at 应该一致" + assert retrieved.error == quota.error, "error 应该一致" + + finally: + # 清理临时文件 + if os.path.exists(cache_file): + os.unlink(cache_file) + + @given(quotas=st.lists(valid_quota_strategy(), min_size=1, max_size=10, unique_by=lambda q: q.account_id)) + @settings(max_examples=50) + def test_multiple_accounts_stored_independently(self, quotas: list): + """多个账号的缓存应该独立存储""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + + # 存储所有账号 + for quota in quotas: + cache.set(quota.account_id, quota) + + # 验证每个账号都能正确读取 + for quota in quotas: + retrieved = cache.get(quota.account_id) + assert retrieved is not None + assert retrieved.account_id == quota.account_id + assert retrieved.balance == quota.balance + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + +# ============== Property 2: 缓存持久化往返 ============== +# **Validates: Requirements 7.1, 7.2** + +class TestCachePersistenceRoundTrip: + """Property 2: 缓存持久化往返测试""" + + @given(quotas=st.lists(valid_quota_strategy(), min_size=1, max_size=10, unique_by=lambda q: q.account_id)) + @settings(max_examples=100) + def test_save_then_load_preserves_data(self, quotas: list): + """ + Property 2: 缓存持久化往返 + *对于任意*有效的 QuotaCache 状态,保存到文件后再加载, + 应产生等价的缓存状态(所有账号的额度信息保持一致)。 + + **Validates: Requirements 7.1, 7.2** + """ + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + # 创建并填充缓存 + cache1 = QuotaCache(cache_file=cache_file) + for quota in quotas: + cache1.set(quota.account_id, quota) + + # 保存到文件 + success = cache1.save_to_file() + assert success, "保存应该成功" + + # 创建新缓存实例并加载 + cache2 = QuotaCache(cache_file=cache_file) + + # 验证数据一致性 + all_cache1 = cache1.get_all() + all_cache2 = cache2.get_all() + + assert len(all_cache1) == len(all_cache2), "账号数量应该一致" + + for account_id, quota1 in all_cache1.items(): + quota2 = all_cache2.get(account_id) + assert quota2 is not None, f"账号 {account_id} 应该存在" + assert quota1.usage_limit == quota2.usage_limit + assert quota1.current_usage == quota2.current_usage + assert quota1.balance == quota2.balance + assert quota1.updated_at == quota2.updated_at + assert quota1.error == quota2.error + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + @given(quota=valid_quota_strategy()) + @settings(max_examples=50) + def test_dict_roundtrip(self, quota: CachedQuota): + """CachedQuota 的字典序列化往返""" + # 转换为字典 + quota_dict = quota.to_dict() + + # 从字典恢复 + restored = CachedQuota.from_dict(quota_dict) + + # 验证一致性 + assert restored.account_id == quota.account_id + assert restored.usage_limit == quota.usage_limit + assert restored.current_usage == quota.current_usage + assert restored.balance == quota.balance + assert restored.updated_at == quota.updated_at + assert restored.error == quota.error + + +# ============== 单元测试:缓存过期检测 ============== +# **Validates: Requirements 7.3** + +class TestCacheExpiration: + """缓存过期检测单元测试""" + + def test_fresh_cache_not_stale(self): + """新缓存不应该过期""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + quota = CachedQuota( + account_id="test_account", + usage_limit=1000.0, + current_usage=500.0, + balance=500.0, + updated_at=time.time() # 当前时间 + ) + cache.set("test_account", quota) + + assert not cache.is_stale("test_account"), "新缓存不应该过期" + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_old_cache_is_stale(self): + """旧缓存应该过期""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + quota = CachedQuota( + account_id="test_account", + usage_limit=1000.0, + current_usage=500.0, + balance=500.0, + updated_at=time.time() - DEFAULT_CACHE_MAX_AGE - 1 # 超过过期时间 + ) + cache.set("test_account", quota) + + assert cache.is_stale("test_account"), "旧缓存应该过期" + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_nonexistent_account_is_stale(self): + """不存在的账号应该被视为过期""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + assert cache.is_stale("nonexistent"), "不存在的账号应该被视为过期" + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + +# ============== 单元测试:文件读写错误处理 ============== +# **Validates: Requirements 7.3** + +class TestFileErrorHandling: + """文件读写错误处理单元测试""" + + def test_load_nonexistent_file(self): + """加载不存在的文件应该返回 False""" + cache = QuotaCache(cache_file="/nonexistent/path/cache.json") + result = cache.load_from_file() + assert result is False + + def test_load_invalid_json(self): + """加载无效 JSON 应该返回 False""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False, mode='w') as f: + f.write("invalid json {{{") + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + # 构造函数会尝试加载,但应该处理错误 + assert len(cache.get_all()) == 0 + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_remove_account(self): + """移除账号应该正常工作""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + quota = CachedQuota( + account_id="test_account", + usage_limit=1000.0, + updated_at=time.time() + ) + cache.set("test_account", quota) + assert cache.get("test_account") is not None + + cache.remove("test_account") + assert cache.get("test_account") is None + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_clear_cache(self): + """清空缓存应该正常工作""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + cache_file = f.name + + try: + cache = QuotaCache(cache_file=cache_file) + for i in range(5): + quota = CachedQuota( + account_id=f"account_{i}", + usage_limit=1000.0, + updated_at=time.time() + ) + cache.set(f"account_{i}", quota) + + assert len(cache.get_all()) == 5 + + cache.clear() + assert len(cache.get_all()) == 0 + + finally: + if os.path.exists(cache_file): + os.unlink(cache_file) + + +# ============== 单元测试:CachedQuota 辅助方法 ============== + +class TestCachedQuotaMethods: + """CachedQuota 辅助方法测试""" + + def test_has_error(self): + """has_error 方法测试""" + quota_ok = CachedQuota(account_id="test", error=None) + quota_err = CachedQuota(account_id="test", error="Some error") + + assert not quota_ok.has_error() + assert quota_err.has_error() + + def test_is_exhausted(self): + """is_exhausted 属性测试""" + quota_ok = CachedQuota(account_id="test", balance=100.0, usage_limit=1000.0) + quota_zero = CachedQuota(account_id="test", balance=0.0, usage_limit=1000.0) + quota_negative = CachedQuota(account_id="test", balance=-10.0, usage_limit=1000.0) + quota_error = CachedQuota(account_id="test", balance=0.0, usage_limit=1000.0, error="Error") + + assert not quota_ok.is_exhausted + assert quota_zero.is_exhausted + assert quota_negative.is_exhausted + assert not quota_error.is_exhausted # 有错误时不更新状态 + + def test_balance_status(self): + """balance_status 属性测试""" + # 正常状态 (>20%) + quota_normal = CachedQuota(account_id="test", balance=500.0, usage_limit=1000.0) + assert quota_normal.balance_status == "normal" + assert not quota_normal.is_low_balance + assert not quota_normal.is_exhausted + + # 低额度状态 (0-20%) + quota_low = CachedQuota(account_id="test", balance=100.0, usage_limit=1000.0) + assert quota_low.balance_status == "low" + assert quota_low.is_low_balance + assert not quota_low.is_exhausted + + # 无额度状态 (<=0) + quota_exhausted = CachedQuota(account_id="test", balance=0.0, usage_limit=1000.0) + assert quota_exhausted.balance_status == "exhausted" + assert not quota_exhausted.is_low_balance + assert quota_exhausted.is_exhausted + + def test_is_available(self): + """is_available 方法测试""" + quota_ok = CachedQuota(account_id="test", balance=100.0, usage_limit=1000.0) + quota_exhausted = CachedQuota(account_id="test", balance=0.0, usage_limit=1000.0) + quota_error = CachedQuota(account_id="test", balance=100.0, error="Error") + + assert quota_ok.is_available() + assert not quota_exhausted.is_available() + assert not quota_error.is_available() + + def test_from_error(self): + """from_error 工厂方法测试""" + quota = CachedQuota.from_error("test_account", "Connection failed") + + assert quota.account_id == "test_account" + assert quota.error == "Connection failed" + assert quota.has_error() + assert quota.updated_at > 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +# ============== Property 10: 低额度与无额度区分 ============== +# **Validates: Requirements 5.5, 5.6** + +class TestBalanceStatusDistinction: + """Property 10: 低额度与无额度区分测试""" + + @given( + balance=st.floats(min_value=-100.0, max_value=1000.0, allow_nan=False, allow_infinity=False), + usage_limit=st.floats(min_value=100.0, max_value=1000.0, allow_nan=False, allow_infinity=False) + ) + @settings(max_examples=100) + def test_balance_status_distinction(self, balance: float, usage_limit: float): + """ + Property 10: 低额度与无额度区分 + *对于任意*账号,当剩余额度大于0但低于总额度的20%时,应标记为"低额度"状态; + 当剩余额度为0或负数时,应标记为"无额度"状态。 + + **Validates: Requirements 5.5, 5.6** + """ + quota = CachedQuota( + account_id="test_account", + balance=balance, + usage_limit=usage_limit, + updated_at=time.time() + ) + + remaining_percent = (balance / usage_limit) * 100 if usage_limit > 0 else 0 + + if balance <= 0: + # 无额度状态 + assert quota.balance_status == "exhausted", f"余额 {balance} 应该是 exhausted 状态" + assert quota.is_exhausted, f"余额 {balance} 应该标记为 is_exhausted" + assert not quota.is_low_balance, f"余额 {balance} 不应该标记为 is_low_balance" + assert not quota.is_available(), f"余额 {balance} 不应该可用" + elif remaining_percent <= 20: + # 低额度状态 + assert quota.balance_status == "low", f"余额 {balance}/{usage_limit} ({remaining_percent:.1f}%) 应该是 low 状态" + assert quota.is_low_balance, f"余额 {balance}/{usage_limit} 应该标记为 is_low_balance" + assert not quota.is_exhausted, f"余额 {balance} 不应该标记为 is_exhausted" + assert quota.is_available(), f"余额 {balance} 应该可用" + else: + # 正常状态 + assert quota.balance_status == "normal", f"余额 {balance}/{usage_limit} ({remaining_percent:.1f}%) 应该是 normal 状态" + assert not quota.is_low_balance, f"余额 {balance}/{usage_limit} 不应该标记为 is_low_balance" + assert not quota.is_exhausted, f"余额 {balance} 不应该标记为 is_exhausted" + assert quota.is_available(), f"余额 {balance} 应该可用" + + def test_boundary_values(self): + """边界值测试""" + # 正好 20% + quota_20 = CachedQuota(account_id="test", balance=200.0, usage_limit=1000.0) + assert quota_20.balance_status == "low" + + # 刚好超过 20% + quota_21 = CachedQuota(account_id="test", balance=210.0, usage_limit=1000.0) + assert quota_21.balance_status == "normal" + + # 正好 0 + quota_0 = CachedQuota(account_id="test", balance=0.0, usage_limit=1000.0) + assert quota_0.balance_status == "exhausted" + + # 负数 + quota_neg = CachedQuota(account_id="test", balance=-10.0, usage_limit=1000.0) + assert quota_neg.balance_status == "exhausted" diff --git a/KiroProxy/tests/test_quota_reset_time.py b/KiroProxy/tests/test_quota_reset_time.py new file mode 100644 index 0000000000000000000000000000000000000000..40b77eea7d8774ce63c6ff092ba5d3ec93f7aea6 --- /dev/null +++ b/KiroProxy/tests/test_quota_reset_time.py @@ -0,0 +1,167 @@ +"""测试额度重置时间功能""" +import asyncio +import json +from datetime import datetime, timezone, timedelta +from kiro_proxy.core.usage import calculate_balance, UsageInfo + + +def test_quota_reset_time(): + """测试额度重置时间解析""" + + # 模拟 API 响应数据 + mock_response = { + "subscriptionInfo": { + "subscriptionTitle": "Kiro Pro" + }, + "usageBreakdownList": [ + { + "resourceType": "CREDIT", + "displayName": "Credits", + "usageLimitWithPrecision": 50.0, + "currentUsageWithPrecision": 25.0, + "freeTrialInfo": { + "freeTrialStatus": "ACTIVE", + "usageLimitWithPrecision": 500.0, + "currentUsageWithPrecision": 100.0, + "freeTrialExpiry": "2026-02-13T23:59:59Z" + }, + "bonuses": [ + { + "status": "ACTIVE", + "usageLimitWithPrecision": 100.0, + "currentUsageWithPrecision": 0.0, + "expiresAt": "2026-03-01T23:59:59Z" + }, + { + "status": "ACTIVE", + "usageLimitWithPrecision": 50.0, + "currentUsageWithPrecision": 25.0, + "expiresAt": "2026-02-28T23:59:59Z" + } + ] + } + ], + "nextDateReset": "2026-02-01T00:00:00Z" + } + + # 解析额度信息 + usage_info = calculate_balance(mock_response) + + # 验证结果 + print("=== 额度信息解析结果 ===") + print(f"订阅类型: {usage_info.subscription_title}") + print(f"总额度: {usage_info.usage_limit}") + print(f"已用额度: {usage_info.current_usage}") + print(f"剩余额度: {usage_info.balance}") + print(f"使用率: {(usage_info.current_usage / usage_info.usage_limit * 100):.1f}%") + print(f"下次重置时间: {usage_info.next_reset_date}") + print(f"免费试用过期时间: {usage_info.free_trial_expiry}") + print(f"奖励过期时间列表: {usage_info.bonus_expiries}") + + # 验证具体数值 + assert usage_info.usage_limit == 700.0 # 50 + 500 + 100 + 50 + assert usage_info.current_usage == 150.0 # 25 + 100 + 0 + 25 + assert usage_info.balance == 550.0 + assert usage_info.free_trial_limit == 500.0 + assert usage_info.free_trial_usage == 100.0 + assert usage_info.bonus_limit == 150.0 # 100 + 50 + assert usage_info.bonus_usage == 25.0 # 0 + 25 + assert usage_info.next_reset_date == "2026-02-01T00:00:00Z" + assert usage_info.free_trial_expiry == "2026-02-13T23:59:59Z" + assert len(usage_info.bonus_expiries) == 2 + + print("\n✅ 测试通过!") + + +def test_quota_cache_with_reset_time(): + """测试额度缓存的重置时间功能""" + from kiro_proxy.core.quota_cache import CachedQuota + + # 创建 UsageInfo + usage_info = UsageInfo( + subscription_title="Kiro Pro", + usage_limit=100.0, + current_usage=50.0, + balance=50.0, + next_reset_date="2026-02-01T00:00:00Z", + free_trial_expiry="2026-02-13T23:59:59Z", + bonus_expiries=["2026-03-01T23:59:59Z", "2026-02-28T23:59:59Z"] + ) + + # 从 UsageInfo 创建 CachedQuota + cached_quota = CachedQuota.from_usage_info("test_account", usage_info) + + # 验证转换结果 + print("\n=== 缓存额度信息 ===") + print(f"账号ID: {cached_quota.account_id}") + print(f"下次重置时间: {cached_quota.next_reset_date}") + print(f"免费试用过期时间: {cached_quota.free_trial_expiry}") + print(f"奖励过期时间: {cached_quota.bonus_expiries}") + + # 转换为字典并验证 + quota_dict = cached_quota.to_dict() + assert quota_dict["next_reset_date"] == "2026-02-01T00:00:00Z" + assert quota_dict["free_trial_expiry"] == "2026-02-13T23:59:59Z" + assert len(quota_dict["bonus_expiries"]) == 2 + + # 从字典重建并验证 + rebuilt_quota = CachedQuota.from_dict(quota_dict) + assert rebuilt_quota.next_reset_date == cached_quota.next_reset_date + assert rebuilt_quota.free_trial_expiry == cached_quota.free_trial_expiry + assert rebuilt_quota.bonus_expiries == cached_quota.bonus_expiries + + print("\n✅ 缓存测试通过!") + + +def test_account_status_info(): + """测试账号状态信息中的重置时间""" + from kiro_proxy.core.account import Account + from kiro_proxy.core.quota_cache import CachedQuota, get_quota_cache + + # 创建模拟账号 + account = Account( + id="test_account", + name="测试账号", + token_path="/tmp/test_token.json" + ) + + # 创建缓存额度 + cached_quota = CachedQuota( + account_id="test_account", + usage_limit=100.0, + current_usage=50.0, + balance=50.0, + next_reset_date="2026-02-01T00:00:00Z", + free_trial_expiry="2026-02-13T23:59:59Z", + bonus_expiries=["2026-03-01T23:59:59Z"] + ) + + # 设置缓存 + quota_cache = get_quota_cache() + quota_cache.set("test_account", cached_quota) + + # 获取状态信息 + status_info = account.get_status_info() + + # 验证重置时间信息 + print("\n=== 账号状态信息 ===") + quota_info = status_info.get("quota") + if quota_info: + print(f"下次重置时间: {quota_info.get('next_reset_date')}") + print(f"格式化重置日期: {quota_info.get('reset_date_text')}") + print(f"免费试用过期时间: {quota_info.get('free_trial_expiry')}") + print(f"格式化过期日期: {quota_info.get('trial_expiry_text')}") + print(f"生效奖励数: {quota_info.get('active_bonuses')}") + + assert quota_info["reset_date_text"] == "2026-02-01" + assert quota_info["trial_expiry_text"] == "2026-02-13" + assert quota_info["active_bonuses"] == 1 + + print("\n✅ 账号状态测试通过!") + + +if __name__ == "__main__": + test_quota_reset_time() + test_quota_cache_with_reset_time() + test_account_status_info() + print("\n🎉 所有测试通过!额度重置时间功能已成功实现。") diff --git a/KiroProxy/tests/test_quota_scheduler.py b/KiroProxy/tests/test_quota_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d78c488914caa3eb2603a8a88628d7d92c0f2a --- /dev/null +++ b/KiroProxy/tests/test_quota_scheduler.py @@ -0,0 +1,273 @@ +"""QuotaScheduler 属性测试和单元测试 + +Property 3: 活跃账号判定 +Property 7: 额度耗尽检测 +Property 8: 缓存过期检测 +Property 9: 获取失败状态标记 +""" +import os +import time +import tempfile +from pathlib import Path +from dataclasses import dataclass +from typing import Optional + +import pytest +from hypothesis import given, strategies as st, settings + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.core.quota_cache import QuotaCache, CachedQuota, DEFAULT_CACHE_MAX_AGE +from kiro_proxy.core.quota_scheduler import ( + QuotaScheduler, ACTIVE_WINDOW_SECONDS +) + + +# ============== Property 3: 活跃账号判定 ============== +# **Validates: Requirements 2.2** + +class TestActiveAccountDetermination: + """Property 3: 活跃账号判定测试""" + + @given( + account_id=st.text(min_size=1, max_size=16, alphabet=st.characters(whitelist_categories=('L', 'N'))), + seconds_ago=st.floats(min_value=0.0, max_value=300.0, allow_nan=False, allow_infinity=False) + ) + @settings(max_examples=100) + def test_active_account_determination(self, account_id: str, seconds_ago: float): + """ + Property 3: 活跃账号判定 + *对于任意*账号和任意时间戳,如果账号的最后使用时间在当前时间60秒内, + 则该账号应被判定为活跃账号;否则应被判定为非活跃账号。 + + **Validates: Requirements 2.2** + """ + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + scheduler = QuotaScheduler(quota_cache=cache) + + # 模拟在 seconds_ago 秒前使用账号 + scheduler._active_accounts[account_id] = time.time() - seconds_ago + + is_active = scheduler.is_active(account_id) + + # 验证活跃判定 + if seconds_ago < ACTIVE_WINDOW_SECONDS: + assert is_active, f"账号在 {seconds_ago:.1f} 秒前使用,应该是活跃的" + else: + assert not is_active, f"账号在 {seconds_ago:.1f} 秒前使用,不应该是活跃的" + + def test_mark_active_updates_timestamp(self): + """标记活跃应该更新时间戳""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + scheduler = QuotaScheduler(quota_cache=cache) + + # 初始不活跃 + assert not scheduler.is_active("test_account") + + # 标记活跃 + scheduler.mark_active("test_account") + + # 现在应该活跃 + assert scheduler.is_active("test_account") + + def test_get_active_accounts(self): + """获取活跃账号列表""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + scheduler = QuotaScheduler(quota_cache=cache) + + # 设置一些账号 + scheduler._active_accounts["active_1"] = time.time() + scheduler._active_accounts["active_2"] = time.time() - 30 + scheduler._active_accounts["inactive"] = time.time() - 120 + + active = scheduler.get_active_accounts() + + assert "active_1" in active + assert "active_2" in active + assert "inactive" not in active + + +# ============== Property 7: 额度耗尽检测 ============== +# **Validates: Requirements 2.4** + +class TestQuotaExhaustion: + """Property 7: 额度耗尽检测测试""" + + @given(balance=st.floats(min_value=-100.0, max_value=100.0, allow_nan=False, allow_infinity=False)) + @settings(max_examples=100) + def test_quota_exhaustion_detection(self, balance: float): + """ + Property 7: 额度耗尽检测 + *对于任意*账号,当其剩余额度为0或负数时,该账号应被标记为不可用状态。 + + **Validates: Requirements 2.4** + """ + quota = CachedQuota( + account_id="test_account", + usage_limit=1000.0, + current_usage=1000.0 - balance, + balance=balance, + updated_at=time.time() + ) + + # is_exhausted 现在是属性而不是方法 + is_exhausted = quota.is_exhausted + + if balance <= 0: + assert is_exhausted, f"余额 {balance} 应该被判定为耗尽" + else: + assert not is_exhausted, f"余额 {balance} 不应该被判定为耗尽" + + def test_error_quota_not_exhausted(self): + """有错误的额度不应该被判定为耗尽""" + quota = CachedQuota( + account_id="test_account", + balance=0.0, + usage_limit=1000.0, + error="Connection failed" + ) + + # 有错误时不更新状态 + assert not quota.is_exhausted + + +# ============== Property 8: 缓存过期检测 ============== +# **Validates: Requirements 7.3** + +class TestCacheStaleDetection: + """Property 8: 缓存过期检测测试""" + + @given( + age_seconds=st.floats(min_value=0.0, max_value=1000.0, allow_nan=False, allow_infinity=False), + max_age=st.integers(min_value=60, max_value=600) + ) + @settings(max_examples=100, deadline=None) + def test_cache_stale_detection(self, age_seconds: float, max_age: int): + """ + Property 8: 缓存过期检测 + *对于任意*缓存记录和过期阈值(默认5分钟),如果缓存的更新时间距当前时间超过阈值, + 则该缓存应被判定为过期。 + + **Validates: Requirements 7.3** + """ + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + + quota = CachedQuota( + account_id="test_account", + balance=500.0, + updated_at=time.time() - age_seconds + ) + cache.set("test_account", quota) + + is_stale = cache.is_stale("test_account", max_age_seconds=max_age) + + if age_seconds > max_age: + assert is_stale, f"缓存年龄 {age_seconds:.1f}s 超过阈值 {max_age}s,应该过期" + else: + assert not is_stale, f"缓存年龄 {age_seconds:.1f}s 未超过阈值 {max_age}s,不应该过期" + + def test_default_max_age(self): + """默认过期时间为5分钟""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + + # 4分钟前的缓存不应该过期 + quota1 = CachedQuota( + account_id="fresh", + balance=500.0, + updated_at=time.time() - 240 + ) + cache.set("fresh", quota1) + assert not cache.is_stale("fresh") + + # 6分钟前的缓存应该过期 + quota2 = CachedQuota( + account_id="stale", + balance=500.0, + updated_at=time.time() - 360 + ) + cache.set("stale", quota2) + assert cache.is_stale("stale") + + +# ============== Property 9: 获取失败状态标记 ============== +# **Validates: Requirements 1.3** + +class TestFetchFailureMarking: + """Property 9: 获取失败状态标记测试""" + + @given(error_msg=st.text(min_size=1, max_size=100)) + @settings(max_examples=50) + def test_error_marking(self, error_msg: str): + """ + Property 9: 获取失败状态标记 + *对于任意*账号,当额度获取失败时,该账号的缓存应包含错误信息, + 且账号状态应被标记为额度未知。 + + **Validates: Requirements 1.3** + """ + quota = CachedQuota.from_error("test_account", error_msg) + + assert quota.has_error(), "应该有错误标记" + assert quota.error == error_msg, "错误信息应该一致" + assert quota.account_id == "test_account" + assert quota.updated_at > 0, "应该有更新时间" + + def test_error_quota_fields(self): + """错误状态的额度字段应该为默认值""" + quota = CachedQuota.from_error("test_account", "Connection timeout") + + assert quota.usage_limit == 0.0 + assert quota.current_usage == 0.0 + assert quota.balance == 0.0 + assert quota.has_error() + + +# ============== 单元测试:调度器状态 ============== + +class TestSchedulerStatus: + """调度器状态测试""" + + def test_initial_status(self): + """初始状态""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + scheduler = QuotaScheduler(quota_cache=cache) + + status = scheduler.get_status() + + assert status["running"] is False + assert status["update_interval"] == 60 + assert status["active_count"] == 0 + assert status["last_full_refresh"] is None + + def test_cleanup_inactive(self): + """清理不活跃账号""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_file = os.path.join(tmpdir, "quota_cache.json") + cache = QuotaCache(cache_file=cache_file) + scheduler = QuotaScheduler(quota_cache=cache) + + # 添加一些账号 + scheduler._active_accounts["recent"] = time.time() + scheduler._active_accounts["old"] = time.time() - (ACTIVE_WINDOW_SECONDS * 2 + 1) # 超过 2 * ACTIVE_WINDOW + + scheduler.cleanup_inactive() + + assert "recent" in scheduler._active_accounts + assert "old" not in scheduler._active_accounts + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/KiroProxy/tests/test_refresh_manager.py b/KiroProxy/tests/test_refresh_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa1d96a3a0dd142a06f1e9d8ebfabbbe039e575 --- /dev/null +++ b/KiroProxy/tests/test_refresh_manager.py @@ -0,0 +1,689 @@ +"""RefreshManager 属性测试和单元测试 + +Property 11: Token 过期检测与自动刷新 +Property 12: 刷新锁互斥 +Property 13: 异常后锁释放 +Property 15: 重试次数限制 +Property 16: 指数退避延迟 +Property 17: 429 错误特殊处理 +Property 20: 401 错误自动重试 +""" +import os +import time +import asyncio +import tempfile +from pathlib import Path +from dataclasses import dataclass +from typing import Optional +from datetime import datetime, timezone, timedelta +from unittest.mock import Mock, AsyncMock, patch + +import pytest +from hypothesis import given, strategies as st, settings + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.core.refresh_manager import ( + RefreshManager, RefreshConfig, RefreshProgress, + get_refresh_manager, reset_refresh_manager +) + + +# ============== 辅助类和函数 ============== + +@dataclass +class MockCredentials: + """模拟凭证""" + expires_at: Optional[str] = None + + def is_expired(self) -> bool: + if not self.expires_at: + return True + try: + expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00")) + now = datetime.now(timezone.utc) + return expires <= now + timedelta(minutes=5) + except Exception: + return True + + def is_expiring_soon(self, minutes: int = 10) -> bool: + if not self.expires_at: + return False + try: + expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00")) + now = datetime.now(timezone.utc) + return expires < now + timedelta(minutes=minutes) + except Exception: + return False + + +class MockAccount: + """模拟账号""" + def __init__(self, account_id: str, name: str = None, enabled: bool = True): + self.id = account_id + self.name = name or account_id + self.enabled = enabled + self.status = Mock() + self.status.value = "active" + self._credentials = None + self._refresh_result = (True, "Token 刷新成功") + + def get_credentials(self): + return self._credentials + + def set_credentials(self, creds): + self._credentials = creds + + def set_refresh_result(self, success: bool, message: str): + self._refresh_result = (success, message) + + async def refresh_token(self): + return self._refresh_result + + +# ============== Property 11: Token 过期检测与自动刷新 ============== +# **Validates: Requirements 12.1, 12.2, 17.2** + +class TestTokenExpirationDetection: + """Property 11: Token 过期检测与自动刷新测试""" + + @given( + minutes_until_expiry=st.integers(min_value=-60, max_value=60) + ) + @settings(max_examples=100) + def test_token_expiration_detection(self, minutes_until_expiry: int): + """ + Property 11: Token 过期检测与自动刷新 + *对于任意*账号和当前时间,如果 Token 过期时间距当前时间小于5分钟, + 则该账号应被判定为需要刷新 Token。 + + **Validates: Requirements 12.1, 12.2, 17.2** + """ + manager = RefreshManager() + account = MockAccount("test_account") + + # 设置过期时间 + expires_at = (datetime.now(timezone.utc) + timedelta(minutes=minutes_until_expiry)).isoformat() + creds = MockCredentials(expires_at=expires_at) + account.set_credentials(creds) + + should_refresh = manager.should_refresh_token(account) + + # 默认配置是过期前5分钟刷新 + if minutes_until_expiry <= 5: + assert should_refresh, f"Token 将在 {minutes_until_expiry} 分钟后过期,应该需要刷新" + else: + assert not should_refresh, f"Token 将在 {minutes_until_expiry} 分钟后过期,不应该需要刷新" + + def test_no_credentials_needs_refresh(self): + """无凭证时应该需要刷新""" + manager = RefreshManager() + account = MockAccount("test_account") + account.set_credentials(None) + + assert manager.should_refresh_token(account) + + @pytest.mark.asyncio + async def test_refresh_token_if_needed_valid(self): + """Token 有效时不刷新""" + manager = RefreshManager() + account = MockAccount("test_account") + + # 设置1小时后过期 + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + creds = MockCredentials(expires_at=expires_at) + account.set_credentials(creds) + + success, message = await manager.refresh_token_if_needed(account) + + assert success + assert "无需刷新" in message + + @pytest.mark.asyncio + async def test_refresh_token_if_needed_expired(self): + """Token 过期时自动刷新""" + manager = RefreshManager() + account = MockAccount("test_account") + + # 设置已过期 + expires_at = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + creds = MockCredentials(expires_at=expires_at) + account.set_credentials(creds) + account.set_refresh_result(True, "刷新成功") + + success, message = await manager.refresh_token_if_needed(account) + + assert success + assert "刷新成功" in message + + +# ============== Property 12: 刷新锁互斥 ============== +# **Validates: Requirements 14.1, 14.2** + +class TestRefreshLockMutex: + """Property 12: 刷新锁互斥测试""" + + @pytest.mark.asyncio + async def test_concurrent_refresh_blocked(self): + """ + Property 12: 刷新锁互斥 + *对于任意*两个并发的批量刷新请求,系统应只允许一个请求执行, + 另一个请求应被拒绝并返回当前进度信息。 + + **Validates: Requirements 14.1, 14.2** + """ + manager = RefreshManager() + + # 第一个请求获取锁 + acquired1 = await manager.acquire_refresh_lock() + assert acquired1, "第一个请求应该成功获取锁" + + # 第二个请求应该被拒绝 + acquired2 = await manager.acquire_refresh_lock() + assert not acquired2, "第二个请求应该被拒绝" + + # 释放锁 + manager.release_refresh_lock() + + # 现在应该可以获取锁 + acquired3 = await manager.acquire_refresh_lock() + assert acquired3, "锁释放后应该可以获取" + manager.release_refresh_lock() + + @pytest.mark.asyncio + async def test_is_refreshing_status(self): + """刷新状态正确反映""" + manager = RefreshManager() + + assert not manager.is_refreshing() + + # 模拟开始刷新 + manager._start_refresh(5, "测试刷新") + + assert manager.is_refreshing() + + # 完成刷新 + manager._finish_refresh("completed") + + assert not manager.is_refreshing() + + +# ============== Property 13: 异常后锁释放 ============== +# **Validates: Requirements 14.5** + +class TestLockReleaseAfterException: + """Property 13: 异常后锁释放测试""" + + @pytest.mark.asyncio + async def test_lock_released_after_exception(self): + """ + Property 13: 异常后锁释放 + *对于任意*刷新操作,如果操作异常终止,系统应自动释放锁。 + + **Validates: Requirements 14.5** + """ + manager = RefreshManager() + + # 创建会抛出异常的账号 + account = MockAccount("test_account") + + async def failing_quota_func(acc): + raise Exception("模拟异常") + + # 执行刷新(应该捕获异常并释放锁) + result = await manager.refresh_all_with_token( + [account], + get_quota_func=failing_quota_func + ) + + # 锁应该已释放 + assert not manager._async_lock.locked(), "异常后锁应该被释放" + + # 状态应该是 error 或 completed + assert result.status in ("error", "completed") + + +# ============== Property 15: 重试次数限制 ============== +# **Validates: Requirements 15.1, 15.2, 15.5** + +class TestRetryLimit: + """Property 15: 重试次数限制测试""" + + @given(max_retries=st.integers(min_value=0, max_value=5)) + @settings(max_examples=20, deadline=None) + @pytest.mark.asyncio + async def test_retry_count_limit(self, max_retries: int): + """ + Property 15: 重试次数限制 + *对于任意*失败的刷新操作和配置的最大重试次数 N, + 系统应最多重试 N 次。 + + **Validates: Requirements 15.1, 15.2, 15.5** + """ + config = RefreshConfig(max_retries=max_retries, retry_base_delay=0.01) + manager = RefreshManager(config=config) + + call_count = 0 + + async def always_fail(): + nonlocal call_count + call_count += 1 + return False, "总是失败" + + success, result = await manager.retry_with_backoff(always_fail) + + # 应该调用 max_retries + 1 次(初始 + 重试) + expected_calls = max_retries + 1 + assert call_count == expected_calls, f"应该调用 {expected_calls} 次,实际调用 {call_count} 次" + assert not success, "应该最终失败" + + +# ============== Property 16: 指数退避延迟 ============== +# **Validates: Requirements 15.3** + +class TestExponentialBackoff: + """Property 16: 指数退避延迟测试""" + + @given( + attempt=st.integers(min_value=0, max_value=5), + base_delay=st.floats(min_value=0.1, max_value=2.0, allow_nan=False, allow_infinity=False) + ) + @settings(max_examples=50) + def test_exponential_backoff_delay(self, attempt: int, base_delay: float): + """ + Property 16: 指数退避延迟 + *对于任意*重试操作和重试次数 i,第 i 次重试前的等待时间应为 base_delay * 2^i 秒。 + + **Validates: Requirements 15.3** + """ + config = RefreshConfig(retry_base_delay=base_delay) + manager = RefreshManager(config=config) + + # 计算预期延迟 + expected_delay = base_delay * (2 ** attempt) + + # 验证计算逻辑(这里我们验证公式) + actual_delay = config.retry_base_delay * (2 ** attempt) + + assert abs(actual_delay - expected_delay) < 0.001, \ + f"第 {attempt} 次重试延迟应为 {expected_delay:.3f}s,实际为 {actual_delay:.3f}s" + + +# ============== Property 17: 429 错误特殊处理 ============== +# **Validates: Requirements 15.7** + +class TestRateLimitHandling: + """Property 17: 429 错误特殊处理测试""" + + def test_rate_limit_error_detection(self): + """ + Property 17: 429 错误特殊处理 + *对于任意*返回 429 限流错误的请求,系统应识别为限流错误。 + + **Validates: Requirements 15.7** + """ + manager = RefreshManager() + + # 测试各种 429 错误格式 + assert manager._is_rate_limit_error("HTTP 429 Too Many Requests") + assert manager._is_rate_limit_error("Rate limit exceeded") + assert manager._is_rate_limit_error("请求过于频繁,请稍后重试") + + # 非限流错误 + assert not manager._is_rate_limit_error("HTTP 500 Internal Server Error") + assert not manager._is_rate_limit_error("Connection timeout") + + @given(attempt=st.integers(min_value=0, max_value=5)) + @settings(max_examples=20) + def test_rate_limit_longer_delay(self, attempt: int): + """429 错误应使用更长的等待时间""" + config = RefreshConfig(retry_base_delay=1.0) + manager = RefreshManager(config=config) + + normal_delay = config.retry_base_delay * (2 ** attempt) + rate_limit_delay = manager._get_rate_limit_delay(attempt, config.retry_base_delay) + + # 429 延迟应该是普通延迟的 3 倍 + assert rate_limit_delay == normal_delay * 3, \ + f"429 延迟应为普通延迟的 3 倍" + + +# ============== Property 20: 401 错误自动重试 ============== +# **Validates: Requirements 12.6** + +class TestAuthErrorRetry: + """Property 20: 401 错误自动重试测试""" + + def test_auth_error_detection(self): + """ + Property 20: 401 错误自动重试 + 系统应识别 401 认证错误。 + + **Validates: Requirements 12.6** + """ + manager = RefreshManager() + + # 测试各种 401 错误格式 + assert manager._is_auth_error("HTTP 401 Unauthorized") + assert manager._is_auth_error("凭证已过期或无效,需要重新登录") + assert manager._is_auth_error("Unauthorized access") + + # 非认证错误 + assert not manager._is_auth_error("HTTP 500 Internal Server Error") + assert not manager._is_auth_error("Connection timeout") + + @pytest.mark.asyncio + async def test_auth_error_triggers_token_refresh(self): + """401 错误应触发 Token 刷新并重试""" + manager = RefreshManager() + account = MockAccount("test_account") + account.set_refresh_result(True, "刷新成功") + + call_count = 0 + + async def fail_then_succeed(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return False, "HTTP 401 Unauthorized" + return True, "成功" + + success, result = await manager.execute_with_auth_retry( + account, + fail_then_succeed + ) + + assert success, "重试后应该成功" + assert call_count == 2, "应该调用两次(失败 + 重试)" + + +# ============== 单元测试:配置管理 ============== + +class TestConfigManagement: + """配置管理测试""" + + def test_default_config(self): + """默认配置值""" + config = RefreshConfig() + + assert config.max_retries == 3 + assert config.retry_base_delay == 1.0 + assert config.concurrency == 3 + assert config.token_refresh_before_expiry == 300 + assert config.auto_refresh_interval == 60 + + def test_config_validation(self): + """配置验证""" + # 有效配置 + config = RefreshConfig(max_retries=5, concurrency=10) + assert config.validate() + + # 无效配置 + with pytest.raises(ValueError): + RefreshConfig(max_retries=-1).validate() + + with pytest.raises(ValueError): + RefreshConfig(retry_base_delay=0).validate() + + with pytest.raises(ValueError): + RefreshConfig(concurrency=0).validate() + + def test_update_config(self): + """更新配置""" + manager = RefreshManager() + + manager.update_config(max_retries=5, concurrency=10) + + assert manager.config.max_retries == 5 + assert manager.config.concurrency == 10 + # 其他值保持不变 + assert manager.config.retry_base_delay == 1.0 + + +# ============== 单元测试:进度跟踪 ============== + +class TestProgressTracking: + """进度跟踪测试""" + + def test_progress_creation(self): + """进度创建""" + progress = RefreshProgress( + total=10, + completed=5, + success=4, + failed=1 + ) + + assert progress.progress_percent == 50.0 + assert progress.is_running() + assert not progress.is_completed() + + def test_progress_to_dict(self): + """进度转字典""" + progress = RefreshProgress(total=10) + d = progress.to_dict() + + assert "total" in d + assert "completed" in d + assert "status" in d + + def test_manager_progress_tracking(self): + """管理器进度跟踪""" + manager = RefreshManager() + + # 开始刷新 + manager._start_refresh(5, "测试") + + progress = manager.get_progress() + assert progress is not None + assert progress.total == 5 + assert progress.status == "running" + + # 更新进度 + manager._update_progress(current_account="acc_1", success=True) + + progress = manager.get_progress() + assert progress.completed == 1 + assert progress.success == 1 + + # 完成 + manager._finish_refresh("completed") + + progress = manager.get_progress() + assert progress.status == "completed" + + +# ============== 单元测试:全局实例 ============== + +class TestGlobalInstance: + """全局实例测试""" + + def test_singleton_pattern(self): + """单例模式""" + reset_refresh_manager() + + manager1 = get_refresh_manager() + manager2 = get_refresh_manager() + + assert manager1 is manager2 + + def test_reset_manager(self): + """重置管理器""" + manager1 = get_refresh_manager() + reset_refresh_manager() + manager2 = get_refresh_manager() + + assert manager1 is not manager2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +# ============== Property 14: 跳过错误状态账号 ============== +# **Validates: Requirements 14.6, 17.4** + +class TestSkipErrorAccounts: + """Property 14: 跳过错误状态账号测试""" + + @pytest.mark.asyncio + async def test_skip_disabled_accounts(self): + """ + Property 14: 跳过错误状态账号 + *对于任意*批量刷新操作,已禁用的账号应被跳过。 + + **Validates: Requirements 14.6, 17.4** + """ + manager = RefreshManager() + + # 创建账号列表 + enabled_account = MockAccount("enabled", enabled=True) + disabled_account = MockAccount("disabled", enabled=False) + + call_count = 0 + + async def track_quota_func(acc): + nonlocal call_count + call_count += 1 + return True, "成功" + + result = await manager.refresh_all_with_token( + [enabled_account, disabled_account], + get_quota_func=track_quota_func, + skip_disabled=True + ) + + # 只有启用的账号被处理 + assert result.total == 1, "只应处理启用的账号" + + @pytest.mark.asyncio + async def test_skip_unhealthy_accounts(self): + """跳过不健康状态的账号""" + manager = RefreshManager() + + healthy_account = MockAccount("healthy") + healthy_account.status.value = "active" + + unhealthy_account = MockAccount("unhealthy") + unhealthy_account.status.value = "unhealthy" + + result = await manager.refresh_all_with_token( + [healthy_account, unhealthy_account], + skip_error=True + ) + + assert result.total == 1, "只应处理健康的账号" + + +# ============== Property 18: 定时器唯一性 ============== +# **Validates: Requirements 17.6** + +class TestTimerUniqueness: + """Property 18: 定时器唯一性测试""" + + @pytest.mark.asyncio + async def test_single_timer_running(self): + """ + Property 18: 定时器唯一性 + *对于任意*时刻,系统中应最多只有一个自动刷新定时器在运行。 + + **Validates: Requirements 17.6** + """ + config = RefreshConfig(auto_refresh_interval=1) + manager = RefreshManager(config=config) + + # 启动第一个定时器 + await manager.start_auto_refresh() + task1 = manager._auto_refresh_task + + assert manager.is_auto_refresh_running() + + # 再次启动应该替换旧定时器 + await manager.start_auto_refresh() + task2 = manager._auto_refresh_task + + # 应该是不同的任务(旧的被取消) + assert task1 is not task2 or task1.cancelled() + assert manager.is_auto_refresh_running() + + # 清理 + await manager.stop_auto_refresh() + assert not manager.is_auto_refresh_running() + + @pytest.mark.asyncio + async def test_stop_clears_timer(self): + """停止应该清除定时器""" + config = RefreshConfig(auto_refresh_interval=1) + manager = RefreshManager(config=config) + + await manager.start_auto_refresh() + assert manager.is_auto_refresh_running() + + await manager.stop_auto_refresh() + assert not manager.is_auto_refresh_running() + assert manager._auto_refresh_task is None + + +# ============== Property 19: 刷新失败隔离 ============== +# **Validates: Requirements 17.5** + +class TestRefreshFailureIsolation: + """Property 19: 刷新失败隔离测试""" + + @pytest.mark.asyncio + async def test_single_failure_does_not_affect_others(self): + """ + Property 19: 刷新失败隔离 + *对于任意*批量刷新操作,单个账号的刷新失败不应影响其他账号的刷新。 + + **Validates: Requirements 17.5** + """ + # 使用无重试配置 + config = RefreshConfig(max_retries=0) + manager = RefreshManager(config=config) + + # 创建账号 + account1 = MockAccount("acc1") + account2 = MockAccount("acc2") + account3 = MockAccount("acc3") + + processed_accounts = set() + + async def track_and_fail_second(acc): + processed_accounts.add(acc.id) + if acc.id == "acc2": + return False, "模拟失败" + return True, "成功" + + result = await manager.refresh_all_with_token( + [account1, account2, account3], + get_quota_func=track_and_fail_second + ) + + # 所有账号都应该被处理 + assert len(processed_accounts) == 3, "所有账号都应该被尝试处理" + assert "acc1" in processed_accounts + assert "acc2" in processed_accounts + assert "acc3" in processed_accounts + + # 结果应该反映成功和失败 + assert result.success == 2 + assert result.failed == 1 + + +# ============== 自动刷新状态测试 ============== + +class TestAutoRefreshStatus: + """自动刷新状态测试""" + + def test_auto_refresh_status(self): + """获取自动刷新状态""" + config = RefreshConfig(auto_refresh_interval=30, token_refresh_before_expiry=600) + manager = RefreshManager(config=config) + + status = manager.get_auto_refresh_status() + + assert status["running"] is False + assert status["interval"] == 30 + assert status["token_refresh_before_expiry"] == 600 diff --git a/KiroProxy/tests/test_thinking_config.py b/KiroProxy/tests/test_thinking_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd7fced3183c4885fb2c9b3483d1a81fe6679b7 --- /dev/null +++ b/KiroProxy/tests/test_thinking_config.py @@ -0,0 +1,114 @@ +from pathlib import Path +import sys + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.core.thinking import ( + ThinkingConfig, + build_thinking_prompt, + build_user_prompt_with_thinking, + extract_thinking_config_from_gemini_body, + extract_thinking_config_from_openai_body, + infer_thinking_from_anthropic_messages, + infer_thinking_from_gemini_contents, + infer_thinking_from_openai_messages, + infer_thinking_from_openai_responses_input, + normalize_thinking_config, +) + + +def test_normalize_thinking_config_defaults_to_disabled_unlimited(): + cfg = normalize_thinking_config(None) + assert cfg == ThinkingConfig(False, None) + + +@pytest.mark.parametrize( + "raw,expected", + [ + (True, ThinkingConfig(True, None)), + ("enabled", ThinkingConfig(True, None)), + ({"type": "enabled"}, ThinkingConfig(True, None)), + ({"thinking_type": "enabled", "budget_tokens": 20000}, ThinkingConfig(True, 20000)), + ({"enabled": True, "budget_tokens": 0}, ThinkingConfig(True, None)), + ({"includeThoughts": True, "thinkingBudget": 1234}, ThinkingConfig(True, 1234)), + ({"type": "disabled", "budget_tokens": 9999}, ThinkingConfig(False, 9999)), + ], +) +def test_normalize_thinking_config_variants(raw, expected): + assert normalize_thinking_config(raw) == expected + + +def test_extract_thinking_config_from_openai_body(): + cfg, explicit = extract_thinking_config_from_openai_body({}) + assert cfg == ThinkingConfig(False, None) + assert explicit is False + + cfg, explicit = extract_thinking_config_from_openai_body({"thinking": {"type": "enabled"}}) + assert cfg.enabled is True + assert explicit is True + + cfg, explicit = extract_thinking_config_from_openai_body({"reasoning_effort": "high"}) + assert cfg.enabled is True + assert cfg.budget_tokens is None + assert explicit is True + + cfg, explicit = extract_thinking_config_from_openai_body({"reasoning": {"effort": "medium"}}) + assert cfg == ThinkingConfig(True, 20000) + assert explicit is True + + +def test_extract_thinking_config_from_gemini_body(): + cfg, explicit = extract_thinking_config_from_gemini_body({}) + assert cfg == ThinkingConfig(False, None) + assert explicit is False + + cfg, explicit = extract_thinking_config_from_gemini_body( + {"generationConfig": {"thinkingConfig": {"includeThoughts": True, "thinkingBudget": 1234}}} + ) + assert cfg == ThinkingConfig(True, 1234) + assert explicit is True + + +def test_infer_thinking_from_payloads(): + assert ( + infer_thinking_from_anthropic_messages( + [{"role": "assistant", "content": [{"type": "thinking", "thinking": "x"}]}] + ) + is True + ) + + assert infer_thinking_from_openai_messages( + [{"role": "assistant", "content": "AAA\nBBB"}] + ) + + assert infer_thinking_from_openai_responses_input( + [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "AAABBB"}], + } + ] + ) + + assert infer_thinking_from_gemini_contents( + [{"role": "model", "parts": [{"text": "AAA\nBBB"}]}] + ) + + +def test_thinking_prompts_include_ultrathink_and_budget_hint(): + p1 = build_thinking_prompt("hi", budget_tokens=None) + assert "ULTRATHINK" in p1 + assert "within" not in p1.lower() + + p2 = build_thinking_prompt("hi", budget_tokens=123) + assert "ULTRATHINK" in p2 + assert "123" in p2 + + +def test_build_user_prompt_with_thinking_wraps_and_forbids_disclosure(): + prompt = build_user_prompt_with_thinking("hello", "secret reasoning") + assert "" in prompt and "" in prompt + assert "Do NOT reveal" in prompt diff --git a/KiroProxy/tests/test_thinking_stream_processor.py b/KiroProxy/tests/test_thinking_stream_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c55ef506302e0f2f9e892eb8c406bf8fb581895e --- /dev/null +++ b/KiroProxy/tests/test_thinking_stream_processor.py @@ -0,0 +1,67 @@ +"""ThinkingStreamProcessor 单元测试 + +覆盖 标签在流式分片中被拆分的场景,避免思维链泄露到 text 输出。 +""" + +from pathlib import Path +import sys + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from kiro_proxy.handlers.anthropic import ThinkingStreamProcessor + + +def _collect_events(chunks: list[str]) -> list[dict]: + processor = ThinkingStreamProcessor(thinking_enabled=True) + events: list[dict] = [] + for chunk in chunks: + events.extend(processor.process_content(chunk)) + events.extend(processor.finalize()) + return events + + +def _extract_text(events: list[dict]) -> str: + return "".join( + e["delta"]["text"] + for e in events + if e.get("type") == "content_block_delta" + and e.get("delta", {}).get("type") == "text_delta" + ) + + +def _extract_thinking(events: list[dict]) -> str: + return "".join( + e["delta"]["thinking"] + for e in events + if e.get("type") == "content_block_delta" + and e.get("delta", {}).get("type") == "thinking_delta" + ) + + +@pytest.mark.parametrize( + "chunks,expected_thinking,expected_text", + [ + # 起始标签被拆分 + (["AAABBB"], "AAA", "BBB"), + # 结束标签被拆分 + (["AAABBB"], "AAA", "BBB"), + # 起始/结束标签都可能被拆分(跨多个分片) + (["AAABBB"], "AAA", "BBB"), + # 无 thinking 标签:文本应保持原样 + (["Hello AAA"], "AAA", ""), + ], +) +def test_thinking_stream_processor_chunk_splitting(chunks, expected_thinking, expected_text): + events = _collect_events(chunks) + assert _extract_thinking(events) == expected_thinking + assert _extract_text(events) == expected_text + + # 思考标签不应出现在 text 输出中 + text = _extract_text(events) + assert "" not in text + assert "" not in text + diff --git a/run.bat b/run.bat new file mode 100644 index 0000000000000000000000000000000000000000..794e7e87ade1a628b421f277a304cba7e63a90dd --- /dev/null +++ b/run.bat @@ -0,0 +1,3 @@ +cd KiroProxy +start http://127.0.0.1:6696 +python run.py 6696 \ No newline at end of file