KiroProxy User commited on
Commit
d3cadd5
·
0 Parent(s):

Initial commit: KiroProxy project

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. KiroProxy/.github/workflows/build.yml +245 -0
  2. KiroProxy/.gitignore +54 -0
  3. KiroProxy/CAPTURE_GUIDE.md +0 -0
  4. KiroProxy/README.md +423 -0
  5. KiroProxy/assets/icon.iconset/icon_128x128.png +0 -0
  6. KiroProxy/assets/icon.iconset/icon_16x16.png +0 -0
  7. KiroProxy/assets/icon.iconset/icon_256x256.png +0 -0
  8. KiroProxy/assets/icon.iconset/icon_32x32.png +0 -0
  9. KiroProxy/assets/icon.iconset/icon_512x512.png +0 -0
  10. KiroProxy/assets/icon.iconset/icon_64x64.png +0 -0
  11. KiroProxy/assets/icon.png +0 -0
  12. KiroProxy/assets/icon.svg +1 -0
  13. KiroProxy/build.py +219 -0
  14. KiroProxy/examples/quota_display_example.py +95 -0
  15. KiroProxy/examples/test_quota_display.html +118 -0
  16. KiroProxy/kiro.svg +1 -0
  17. KiroProxy/kiro_proxy/__init__.py +2 -0
  18. KiroProxy/kiro_proxy/__main__.py +5 -0
  19. KiroProxy/kiro_proxy/auth/__init__.py +32 -0
  20. KiroProxy/kiro_proxy/auth/device_flow.py +603 -0
  21. KiroProxy/kiro_proxy/cli.py +375 -0
  22. KiroProxy/kiro_proxy/config.py +133 -0
  23. KiroProxy/kiro_proxy/converters/__init__.py +1196 -0
  24. KiroProxy/kiro_proxy/core/__init__.py +55 -0
  25. KiroProxy/kiro_proxy/core/account.py +287 -0
  26. KiroProxy/kiro_proxy/core/account_selector.py +390 -0
  27. KiroProxy/kiro_proxy/core/browser.py +186 -0
  28. KiroProxy/kiro_proxy/core/error_handler.py +188 -0
  29. KiroProxy/kiro_proxy/core/flow_monitor.py +572 -0
  30. KiroProxy/kiro_proxy/core/history_manager.py +829 -0
  31. KiroProxy/kiro_proxy/core/kiro_api.py +146 -0
  32. KiroProxy/kiro_proxy/core/persistence.py +69 -0
  33. KiroProxy/kiro_proxy/core/protocol_handler.py +318 -0
  34. KiroProxy/kiro_proxy/core/quota_cache.py +397 -0
  35. KiroProxy/kiro_proxy/core/quota_scheduler.py +321 -0
  36. KiroProxy/kiro_proxy/core/rate_limiter.py +125 -0
  37. KiroProxy/kiro_proxy/core/refresh_manager.py +888 -0
  38. KiroProxy/kiro_proxy/core/retry.py +117 -0
  39. KiroProxy/kiro_proxy/core/scheduler.py +125 -0
  40. KiroProxy/kiro_proxy/core/state.py +280 -0
  41. KiroProxy/kiro_proxy/core/stats.py +130 -0
  42. KiroProxy/kiro_proxy/core/thinking.py +456 -0
  43. KiroProxy/kiro_proxy/core/usage.py +235 -0
  44. KiroProxy/kiro_proxy/credential/__init__.py +17 -0
  45. KiroProxy/kiro_proxy/credential/fingerprint.py +131 -0
  46. KiroProxy/kiro_proxy/credential/quota.py +100 -0
  47. KiroProxy/kiro_proxy/credential/refresher.py +195 -0
  48. KiroProxy/kiro_proxy/credential/types.py +121 -0
  49. KiroProxy/kiro_proxy/docs/01-quickstart.md +143 -0
  50. KiroProxy/kiro_proxy/docs/02-features.md +225 -0
KiroProxy/.github/workflows/build.yml ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Build Release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*'
7
+ workflow_dispatch:
8
+
9
+ permissions:
10
+ contents: write
11
+
12
+ env:
13
+ APP_NAME: KiroProxy
14
+
15
+ jobs:
16
+ build-linux:
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - uses: actions/checkout@v4
20
+
21
+ - name: Get version from tag
22
+ id: version
23
+ run: |
24
+ if [[ "${{ github.ref }}" == refs/tags/* ]]; then
25
+ VERSION=${GITHUB_REF#refs/tags/v}
26
+ else
27
+ VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py)
28
+ fi
29
+ echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
30
+ echo "Version: $VERSION"
31
+
32
+ - name: Set up Python
33
+ uses: actions/setup-python@v5
34
+ with:
35
+ python-version: '3.11'
36
+
37
+ - name: Install dependencies
38
+ run: |
39
+ python -m pip install --upgrade pip
40
+ pip install -r requirements.txt
41
+ pip install pyinstaller
42
+
43
+ - name: Build binary
44
+ run: python build.py
45
+
46
+ - name: Install packaging tools
47
+ run: |
48
+ sudo apt-get update
49
+ sudo apt-get install -y ruby ruby-dev rubygems build-essential rpm libfuse2
50
+ sudo gem install --no-document fpm
51
+
52
+ - name: Create packages
53
+ run: |
54
+ mkdir -p release
55
+ VERSION=${{ steps.version.outputs.VERSION }}
56
+
57
+ # Binary (standalone)
58
+ cp dist/KiroProxy release/KiroProxy-${VERSION}-linux-x86_64
59
+ chmod +x release/KiroProxy-${VERSION}-linux-x86_64
60
+
61
+ # tar.gz
62
+ tar -czvf release/KiroProxy-${VERSION}-linux-x86_64.tar.gz -C dist KiroProxy
63
+
64
+ # deb package
65
+ fpm -s dir -t deb \
66
+ -n kiroproxy \
67
+ -v ${VERSION} \
68
+ --description "Kiro API Proxy Server" \
69
+ --license "MIT" \
70
+ --architecture amd64 \
71
+ --maintainer "petehsu" \
72
+ --url "https://github.com/petehsu/KiroProxy" \
73
+ -p release/kiroproxy_${VERSION}_amd64.deb \
74
+ dist/KiroProxy=/usr/local/bin/KiroProxy
75
+
76
+ # rpm package
77
+ fpm -s dir -t rpm \
78
+ -n kiroproxy \
79
+ -v ${VERSION} \
80
+ --description "Kiro API Proxy Server" \
81
+ --license "MIT" \
82
+ --architecture x86_64 \
83
+ --maintainer "petehsu" \
84
+ --url "https://github.com/petehsu/KiroProxy" \
85
+ -p release/kiroproxy-${VERSION}-1.x86_64.rpm \
86
+ dist/KiroProxy=/usr/local/bin/KiroProxy
87
+
88
+ - name: Upload artifacts
89
+ uses: actions/upload-artifact@v4
90
+ with:
91
+ name: KiroProxy-Linux
92
+ path: release/*
93
+
94
+ build-windows:
95
+ runs-on: windows-latest
96
+ steps:
97
+ - uses: actions/checkout@v4
98
+
99
+ - name: Get version from tag
100
+ id: version
101
+ shell: bash
102
+ run: |
103
+ if [[ "${{ github.ref }}" == refs/tags/* ]]; then
104
+ VERSION=${GITHUB_REF#refs/tags/v}
105
+ else
106
+ VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py)
107
+ fi
108
+ echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
109
+ echo "Version: $VERSION"
110
+
111
+ - name: Set up Python
112
+ uses: actions/setup-python@v5
113
+ with:
114
+ python-version: '3.11'
115
+
116
+ - name: Install dependencies
117
+ run: |
118
+ python -m pip install --upgrade pip
119
+ pip install -r requirements.txt
120
+ pip install pyinstaller
121
+
122
+ - name: Build
123
+ run: python build.py
124
+
125
+ - name: Create packages
126
+ shell: pwsh
127
+ run: |
128
+ $VERSION = "${{ steps.version.outputs.VERSION }}"
129
+ New-Item -ItemType Directory -Force -Path release
130
+
131
+ # exe (standalone)
132
+ Copy-Item dist/KiroProxy.exe release/KiroProxy-${VERSION}-windows-x86_64.exe
133
+
134
+ # zip
135
+ Compress-Archive -Path dist/KiroProxy.exe -DestinationPath release/KiroProxy-${VERSION}-windows-x86_64.zip
136
+
137
+ - name: Upload artifacts
138
+ uses: actions/upload-artifact@v4
139
+ with:
140
+ name: KiroProxy-Windows
141
+ path: release/*
142
+
143
+ build-macos:
144
+ runs-on: macos-latest
145
+ steps:
146
+ - uses: actions/checkout@v4
147
+
148
+ - name: Get version from tag
149
+ id: version
150
+ run: |
151
+ if [[ "${{ github.ref }}" == refs/tags/* ]]; then
152
+ VERSION=${GITHUB_REF#refs/tags/v}
153
+ else
154
+ VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py || echo "1.0.0")
155
+ fi
156
+ echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
157
+ echo "Version: $VERSION"
158
+
159
+ - name: Set up Python
160
+ uses: actions/setup-python@v5
161
+ with:
162
+ python-version: '3.11'
163
+
164
+ - name: Install dependencies
165
+ run: |
166
+ python -m pip install --upgrade pip
167
+ pip install -r requirements.txt
168
+ pip install pyinstaller
169
+
170
+ - name: Generate icon
171
+ run: |
172
+ mkdir -p assets/icon.iconset
173
+ for size in 16 32 64 128 256 512; do
174
+ sips -z $size $size assets/icon.png --out assets/icon.iconset/icon_${size}x${size}.png
175
+ done
176
+ iconutil -c icns assets/icon.iconset -o assets/icon.icns
177
+
178
+ - name: Build
179
+ run: python build.py
180
+
181
+ - name: Create packages
182
+ run: |
183
+ VERSION=${{ steps.version.outputs.VERSION }}
184
+ mkdir -p release
185
+
186
+ # Binary (standalone)
187
+ cp dist/KiroProxy release/KiroProxy-${VERSION}-macos-x86_64
188
+ chmod +x release/KiroProxy-${VERSION}-macos-x86_64
189
+
190
+ # zip
191
+ cd dist && zip -r ../release/KiroProxy-${VERSION}-macos-x86_64.zip KiroProxy && cd ..
192
+
193
+ - name: Upload artifacts
194
+ uses: actions/upload-artifact@v4
195
+ with:
196
+ name: KiroProxy-macOS
197
+ path: release/*
198
+
199
+ release:
200
+ needs: [build-linux, build-windows, build-macos]
201
+ runs-on: ubuntu-latest
202
+ if: startsWith(github.ref, 'refs/tags/')
203
+
204
+ steps:
205
+ - uses: actions/checkout@v4
206
+
207
+ - name: Get version from tag
208
+ id: version
209
+ run: |
210
+ VERSION=${GITHUB_REF#refs/tags/v}
211
+ echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
212
+
213
+ - name: Download all artifacts
214
+ uses: actions/download-artifact@v4
215
+ with:
216
+ path: artifacts
217
+
218
+ - name: List artifacts
219
+ run: find artifacts -type f
220
+
221
+ - name: Create Release
222
+ uses: softprops/action-gh-release@v1
223
+ with:
224
+ name: KiroProxy v${{ steps.version.outputs.VERSION }}
225
+ body: |
226
+ ## Downloads
227
+
228
+ | Platform | File | Description |
229
+ |----------|------|-------------|
230
+ | **Linux** | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64` | Standalone binary |
231
+ | | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64.tar.gz` | Compressed archive |
232
+ | | `kiroproxy_${{ steps.version.outputs.VERSION }}_amd64.deb` | Debian/Ubuntu package |
233
+ | | `kiroproxy-${{ steps.version.outputs.VERSION }}-1.x86_64.rpm` | Fedora/RHEL/CentOS package |
234
+ | **Windows** | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.exe` | Standalone executable |
235
+ | | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.zip` | Compressed archive |
236
+ | **macOS** | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64` | Standalone binary |
237
+ | | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64.zip` | Compressed archive |
238
+ files: |
239
+ artifacts/KiroProxy-Linux/*
240
+ artifacts/KiroProxy-Windows/*
241
+ artifacts/KiroProxy-macOS/*
242
+ draft: false
243
+ prerelease: false
244
+ env:
245
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
KiroProxy/.gitignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ venv/
6
+ .venv/
7
+ *.egg-info/
8
+ .hypothesis/
9
+ .pytest_cache/
10
+
11
+ # Build
12
+ build/
13
+ dist/
14
+ release/
15
+ *.spec
16
+
17
+ # IDE
18
+ .idea/
19
+ .vscode/
20
+ *.swp
21
+ *.swo
22
+
23
+ # OS
24
+ .DS_Store
25
+ Thumbs.db
26
+
27
+ # HAR files (contain sensitive data)
28
+ *.har
29
+
30
+ # Logs
31
+ *.log
32
+
33
+ # Test files
34
+ [0-9].txt
35
+ [0-9][0-9].txt
36
+ 线索*.txt
37
+
38
+ # Temp analysis files
39
+ flows
40
+ flows_*
41
+ traffic.mitm
42
+ *.mitm
43
+ analyze_har.py
44
+ parse_*.py
45
+ *_analysis.txt
46
+ *_check.txt
47
+ hex_dump.txt
48
+ parsed_*.txt
49
+ response.txt
50
+ 参考.txt
51
+
52
+ # Other projects
53
+ Antigravity-Manager/
54
+ cc-switch/
KiroProxy/CAPTURE_GUIDE.md ADDED
File without changes
KiroProxy/README.md ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="assets/icon.svg" width="80" height="96" alt="Kiro Proxy">
3
+ </p>
4
+
5
+ <h1 align="center">Kiro API Proxy</h1>
6
+
7
+ <p align="center">
8
+ Kiro IDE API 反向代理服务器,支持多账号轮询、Token 自动刷新、配额管理
9
+ </p>
10
+
11
+ <p align="center">
12
+ <a href="#功能特性">功能</a> •
13
+ <a href="#快速开始">快速开始</a> •
14
+ <a href="#cli-配置">CLI 配置</a> •
15
+ <a href="#api-端点">API</a> •
16
+ <a href="#许可证">许可证</a>
17
+ </p>
18
+
19
+ ---
20
+
21
+ > **⚠️ 测试说明**
22
+ >
23
+ > 本项目支持 **Claude Code**、**Codex CLI**、**Gemini CLI** 三种客户端,工具调用功能已全面支持。
24
+
25
+ ## 功能特性
26
+
27
+ ### 核心功能
28
+ - **多协议支持** - OpenAI / Anthropic / Gemini 三种协议兼容
29
+ - **完整工具调用** - 三种协议的工具调用功能全面支持
30
+ - **图片理解** - 支持 Claude Code / Codex CLI 图片输入
31
+ - **网络搜索** - 支持 Claude Code / Codex CLI 网络搜索工具
32
+ - **思考功能** - 支持 Claude 的扩展思考功能(Extended Thinking)
33
+ - **多账号轮询(默认随机)** - 每次请求随机切换账号,分散压力,避免单账号 RPM 过高
34
+ - **会话粘性(可选)** - 非 `random` 策略下,同一会话 60 秒内使用同一账号,保持上下文
35
+ - **Web UI** - 简洁的管理界面,支持监控、日志、设置
36
+
37
+ ### v1.7.1 新功能
38
+ - **Windows 支持补强** - 注册表浏览器检测 + PATH 回退,兼容便携版
39
+ - **打包资源修复** - PyInstaller 打包后可正常加载图标与内置文档
40
+ - **Token 扫描稳定性** - Windows 路径编码处理修复
41
+
42
+ ### v1.6.3 新功能
43
+ - **命令行工具 (CLI)** - 无 GUI 服务器也能轻松管理
44
+ - `python run.py accounts list` - 列出账号
45
+ - `python run.py accounts export/import` - 导出/导入账号
46
+ - `python run.py accounts add` - 交互式添加 Token
47
+ - `python run.py accounts scan` - 扫描本地 Token
48
+ - `python run.py login google/github` - 命令行登录
49
+ - `python run.py login remote` - 生成远程登录链接
50
+ - **远程登录链接** - 在有浏览器的机器上完成授权,Token 自动同步
51
+ - **账号导入导出** - 跨机器迁移账号配置
52
+ - **手动添加 Token** - 直接粘贴 accessToken/refreshToken
53
+
54
+ ### v1.6.2 新功能
55
+ - **Codex CLI 完整支持** - 使用 OpenAI Responses API (`/v1/responses`)
56
+ - 完整工具调用支持(shell、file 等所有工具)
57
+ - 图片输入支持(`input_image` 类型)
58
+ - 网络搜索支持(`web_search` 工具)
59
+ - 错误代码映射(rate_limit、context_length 等)
60
+ - **Claude Code 增强** - 图片理解和网络搜索完整支持
61
+ - 支持 Anthropic 和 OpenAI 两种图片格式
62
+ - 支持 `web_search` / `web_search_20250305` 工具
63
+
64
+ ### v1.6.1 新功能
65
+ - **请求限速** - 通过限制请求频率降低账号封禁风险
66
+ - 每账号最小请求间隔
67
+ - 每账号每分钟最大请求数
68
+ - 全局每分钟最大请求数
69
+ - WebUI 设置页面可配置
70
+ - **账号封禁检测** - 自动检测 TEMPORARILY_SUSPENDED 错误
71
+ - 友好的错误日志输出
72
+ - 自动禁用被封禁账号
73
+ - 自动切换到其他可用账号
74
+ - **统一错误处理** - 三种协议使用统一的错误分类和处理
75
+
76
+ ### v1.6.0 功能
77
+ - **历史消息管理** - 4 种策略处理对话长度限制,可自由组合
78
+ - 自动截断:发送前优先保留最新上下文并摘要前文,必要时按数量/字符数截断
79
+ - 智能摘要:用 AI 生成早期对话摘要,保留关键信息
80
+ - 摘要缓存:历史变化不大时复用最近摘要,减少重复 LLM 调用(默认启用)
81
+ - 错误重试:遇到长度错误时自动截断重试(默认启用)
82
+ - 预估检测:预估 token 数量,超限预先截断
83
+ - **Gemini 工具调用** - 完整支持 functionDeclarations/functionCall/functionResponse
84
+ - **设置页面** - WebUI 新增设置标签页,可配置历史消息管理策略
85
+
86
+ ### v1.5.0 功能
87
+ - **用量查询** - 查询账号配额使用情况,显示已用/余额/使用率
88
+ - **多登录方式** - 支持 Google / GitHub / AWS Builder ID 三种登录方式
89
+ - **流量监控** - 完整的 LLM 请求监控,支持搜索、过滤、导出
90
+ - **浏览器选择** - 自动检测已安装浏览器,支持无痕模式
91
+ - **文档中心** - 内置帮助文档,左侧目录 + 右侧 Markdown 渲染
92
+
93
+ ### v1.4.0 功能
94
+ - **Token 预刷新** - 后台每 5 分钟检查,提前 15 分钟自动刷新
95
+ - **健康检查** - 每 10 分钟检测账号可用性,自动标记状态
96
+ - **请求统计增强** - 按账号/模型统计,24 小时趋势
97
+ - **请求重试机制** - 网络错误/5xx 自动重试,指数退避
98
+
99
+ ## 工具调用支持
100
+
101
+ | 功能 | Anthropic (Claude Code) | OpenAI (Codex CLI) | Gemini |
102
+ |------|------------------------|-------------------|--------|
103
+ | 工具定义 | ✅ `tools` | ✅ `tools.function` | ✅ `functionDeclarations` |
104
+ | 工具调用响应 | ✅ `tool_use` | ✅ `tool_calls` | ✅ `functionCall` |
105
+ | 工具结果 | ✅ `tool_result` | �� `tool` 角色消息 | ✅ `functionResponse` |
106
+ | 强制工具调用 | ✅ `tool_choice` | ✅ `tool_choice` | ✅ `toolConfig.mode` |
107
+ | 工具数量限制 | ✅ 50 个 | ✅ 50 个 | ✅ 50 个 |
108
+ | 历史消息修复 | ✅ | ✅ | ✅ |
109
+ | 图片理解 | ✅ | ✅ | ❌ |
110
+ | 网络搜索 | ✅ | ✅ | ❌ |
111
+
112
+ ## 已知限制
113
+
114
+ ### 对话长度限制
115
+
116
+ Kiro API 有输入长度限制。当对话历史过长时,会返回错误:
117
+
118
+ ```
119
+ Input is too long. (CONTENT_LENGTH_EXCEEDS_THRESHOLD)
120
+ ```
121
+
122
+ #### 自动处理(v1.6.0+)
123
+
124
+ 代理内置了历史消息管理功能,可在「设置」页面配置:
125
+
126
+ - **错误重试**(默认):遇到长度错误时自动截断并重试
127
+ - **智能摘要**:用 AI 生成早期对话摘要,保留关键信息
128
+ - **摘要缓存**(默认):历史变化不大时复用最近摘要,减少重复 LLM 调用
129
+ - **自动截断**:每次请求前优先保留最新上下文并摘要前文,必要时按数量/字符数截断
130
+ - **预估检测**:预估 token 数量,超限预先截断
131
+
132
+ 摘要缓存可通过以下配置项调整(默认值):
133
+ - `summary_cache_enabled`: `true`
134
+ - `summary_cache_min_delta_messages`: `3`
135
+ - `summary_cache_min_delta_chars`: `4000`
136
+ - `summary_cache_max_age_seconds`: `180`
137
+
138
+ #### 手动处理
139
+
140
+ 1. 在 Claude Code 中输入 `/clear` 清空对话历史
141
+ 2. 告诉 AI 你之前在做什么,它会读取代码文件恢复上下文
142
+
143
+ ## 快速开始
144
+
145
+ ### 方式一:下载预编译版本
146
+
147
+ 从 [Releases](../../releases) 下载对应平台的安装包,解压后直接运行。
148
+
149
+ ### 方式二:从源码运行
150
+
151
+ ```bash
152
+ # 克隆项目
153
+ git clone https://github.com/yourname/kiro-proxy.git
154
+ cd kiro-proxy
155
+
156
+ # 创建虚拟环境
157
+ python -m venv venv
158
+ source venv/bin/activate # Windows: venv\Scripts\activate
159
+
160
+ # 安装依赖
161
+ pip install -r requirements.txt
162
+
163
+ # 运行
164
+ python run.py
165
+
166
+ # 或指定端口
167
+ python run.py 8081
168
+ ```
169
+
170
+ 启动后访问 http://localhost:8080
171
+
172
+ ### 命令行工具 (CLI)
173
+
174
+ 无 GUI 服务器可使用 CLI 管理账号:
175
+
176
+ ```bash
177
+ # 账号管理
178
+ python run.py accounts list # 列出账号
179
+ python run.py accounts export -o acc.json # 导出账号
180
+ python run.py accounts import acc.json # 导入账号
181
+ python run.py accounts add # 交互式添加 Token
182
+ python run.py accounts scan --auto # 扫描并自动添加本地 Token
183
+
184
+ # 登录
185
+ python run.py login google # Google 登录
186
+ python run.py login github # GitHub 登录
187
+ python run.py login remote --host myserver.com:8080 # 生成远程登录链接
188
+
189
+ # 服务
190
+ python run.py serve # 启动服务 (默认 8080)
191
+ python run.py serve -p 8081 # 指定端口
192
+ python run.py status # 查看状态
193
+ ```
194
+
195
+ ### 登录获取 Token
196
+
197
+ **方式一:在线登录(推荐)**
198
+ 1. 打开 Web UI,点击「在线登录」
199
+ 2. 选择登录方式:Google / GitHub / AWS Builder ID
200
+ 3. 在浏览器中完成授权
201
+ 4. 账号自动添加
202
+
203
+ **方式二:扫描 Token**
204
+ 1. 打开 Kiro IDE,使用 Google/GitHub 账号登录
205
+ 2. 登录成功后 token 自动保存到 `~/.aws/sso/cache/`
206
+ 3. 在 Web UI 点击「扫描 Token」添加账号
207
+
208
+ ## CLI 配置
209
+
210
+ ### 模型对照表
211
+
212
+ | Kiro 模型 | 能力 | Claude Code | Codex |
213
+ |-----------|------|-------------|-------|
214
+ | `claude-sonnet-4` | ⭐⭐⭐ 推荐 | `claude-sonnet-4` | `gpt-4o` |
215
+ | `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强 | `claude-sonnet-4.5` | `gpt-4o` |
216
+ | `claude-haiku-4.5` | ⚡ 快速 | `claude-haiku-4.5` | `gpt-4o-mini` |
217
+
218
+ ### Claude Code 配置
219
+
220
+ ```
221
+ 名称: Kiro Proxy
222
+ API Key: any
223
+ Base URL: http://localhost:8080
224
+ 模型: claude-sonnet-4
225
+ ```
226
+
227
+ ### Codex 配置
228
+
229
+ Codex CLI 使用 OpenAI Responses API,配置如下:
230
+
231
+ ```bash
232
+ # 设置环境变量
233
+ export OPENAI_API_KEY=any
234
+ export OPENAI_BASE_URL=http://localhost:8080/v1
235
+
236
+ # 运行 Codex
237
+ codex
238
+ ```
239
+
240
+ 或在 `~/.codex/config.toml` 中配置:
241
+
242
+ ```toml
243
+ [providers.openai]
244
+ api_key = "any"
245
+ base_url = "http://localhost:8080/v1"
246
+ ```
247
+
248
+ ## 思考功能支持
249
+
250
+ ### 什么是思考功能
251
+
252
+ 思考功能(Extended Thinking)允许 Claude 在生成回答前展示其思考过程,帮助用户理解 AI 的推理步骤。
253
+
254
+ ### 如何使用
255
+
256
+ 在请求中添加 `thinking`(或对应协议的 thinking 配置)即可启用:
257
+
258
+ ```json
259
+ {
260
+ "model": "claude-sonnet-4.5",
261
+ "messages": [
262
+ {
263
+ "role": "user",
264
+ "content": "解释一下量子计算的原理"
265
+ }
266
+ ],
267
+ "thinking": {
268
+ "thinking_type": "enabled",
269
+ "budget_tokens": 20000
270
+ },
271
+ "stream": true
272
+ }
273
+ ```
274
+
275
+ OpenAI Chat Completions (`POST /v1/chat/completions`) 也支持:
276
+
277
+ ```json
278
+ {
279
+ "model": "gpt-4o",
280
+ "messages": [{"role": "user", "content": "解释一下量子计算的原理"}],
281
+ "thinking": { "type": "enabled" },
282
+ "stream": true
283
+ }
284
+ ```
285
+
286
+ OpenAI Responses (`POST /v1/responses`) 也支持:
287
+
288
+ ```json
289
+ {
290
+ "model": "gpt-4o",
291
+ "input": "解释一下量子计算的原理",
292
+ "thinking": { "type": "enabled" }
293
+ }
294
+ ```
295
+
296
+ Gemini generateContent (`POST /v1/models/{model}:generateContent`) 也支持:
297
+
298
+ ```json
299
+ {
300
+ "contents": [{"role": "user", "parts": [{"text": "解释一下量子计算的原理"}]}],
301
+ "generationConfig": {
302
+ "thinkingConfig": { "includeThoughts": true }
303
+ }
304
+ }
305
+ ```
306
+
307
+ ### 参数说明
308
+
309
+ - `thinking_type`: 思考类型,设为 `"enabled"` 启用思考功能
310
+ - `budget_tokens`: 思考过程的 token 预算(不传则视为无限制)
311
+
312
+ ### 响应格式
313
+
314
+ 启用思考功能后,流式响应会包含两种内容块:
315
+
316
+ 1. **思考块**(type: "thinking"):展示 AI 的思考过程
317
+ 2. **文本块**(type: "text"):最终的回答内容
318
+
319
+ 示例响应:
320
+ ```
321
+ data: {"type":"content_block_start","index":1,"content_block":{"type":"thinking","thinking":""}}
322
+ data: {"type":"content_block_delta","index":1,"delta":{"type":"thinking_delta","thinking":"让我思考一下量子计算的原理..."}}
323
+ data: {"type":"content_block_stop","index":1}
324
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
325
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"量子计算是一种..."}}
326
+ data: {"type":"content_block_stop","index":0}
327
+ ```
328
+
329
+ ## API 端点
330
+
331
+ | 协议 | 端点 | 用途 |
332
+ |------|------|------|
333
+ | OpenAI | `POST /v1/chat/completions` | Chat Completions API |
334
+ | OpenAI | `POST /v1/responses` | Responses API (Codex CLI) |
335
+ | OpenAI | `GET /v1/models` | 模型列表 |
336
+ | Anthropic | `POST /v1/messages` | Claude Code |
337
+ | Anthropic | `POST /v1/messages/count_tokens` | Token 计数 |
338
+ | Gemini | `POST /v1/models/{model}:generateContent` | Gemini CLI |
339
+
340
+ ### 管理 API
341
+
342
+ | 端点 | 方法 | 说明 |
343
+ |------|------|------|
344
+ | `/api/accounts` | GET | 获取所有账号状态 |
345
+ | `/api/accounts/{id}` | GET | 获取账号详情 |
346
+ | `/api/accounts/{id}/usage` | GET | 获取账号用量信息 |
347
+ | `/api/accounts/{id}/refresh` | POST | 刷新账号 Token |
348
+ | `/api/accounts/{id}/restore` | POST | 恢复账号(从冷却状态) |
349
+ | `/api/accounts/refresh-all` | POST | 刷新所有即将过期的 Token |
350
+ | `/api/flows` | GET | 获取流量记录 |
351
+ | `/api/flows/stats` | GET | 获取流量统计 |
352
+ | `/api/flows/{id}` | GET | 获取流量详情 |
353
+ | `/api/quota` | GET | 获取配额状态 |
354
+ | `/api/stats` | GET | 获取统计信息 |
355
+ | `/api/health-check` | POST | 手动触发健康检查 |
356
+ | `/api/browsers` | GET | 获取可用浏览器列表 |
357
+ | `/api/docs` | GET | 获取文档列表 |
358
+ | `/api/docs/{id}` | GET | 获取文档内容 |
359
+
360
+ ## 项目结构
361
+
362
+ ```
363
+ kiro_proxy/
364
+ ├── main.py # FastAPI 应用入口
365
+ ├── config.py # 全局配置
366
+ ├── converters.py # 协议转换
367
+
368
+ ├── core/ # 核心模块
369
+ │ ├── account.py # 账号管理
370
+ │ ├── state.py # 全局状态
371
+ │ ├── persistence.py # 配置持久化
372
+ │ ├── scheduler.py # 后台任务调度
373
+ │ ├── stats.py # 请求统计
374
+ │ ├── retry.py # 重试机制
375
+ │ ├── browser.py # 浏览器检测
376
+ │ ├── flow_monitor.py # 流量监控
377
+ │ └── usage.py # 用量查询
378
+
379
+ ├── credential/ # 凭证管理
380
+ │ ├── types.py # KiroCredentials
381
+ │ ├── fingerprint.py # Machine ID 生成
382
+ │ ├── quota.py # 配额管理器
383
+ │ └── refresher.py # Token 刷新
384
+
385
+ ├── auth/ # 认证模块
386
+ │ └── device_flow.py # Device Code Flow / Social Auth
387
+
388
+ ├── handlers/ # API 处理器
389
+ │ ├── anthropic.py # /v1/messages
390
+ │ ├── openai.py # /v1/chat/completions
391
+ │ ├── responses.py # /v1/responses (Codex CLI)
392
+ │ ├── gemini.py # /v1/models/{model}:generateContent
393
+ │ └── admin.py # 管理 API
394
+
395
+ ├── cli.py # 命令行工具
396
+
397
+ ├── docs/ # 内置文档
398
+ │ ├── 01-quickstart.md # 快速开始
399
+ │ ├── 02-features.md # 功能特性
400
+ │ ├── 03-faq.md # 常见问题
401
+ │ └── 04-api.md # API 参考
402
+
403
+ └── web/
404
+ └── html.py # Web UI (组件化单文件)
405
+ ```
406
+
407
+ ## 构建
408
+
409
+ ```bash
410
+ # 安装构建依赖
411
+ pip install pyinstaller
412
+
413
+ # 构建
414
+ python build.py
415
+ ```
416
+
417
+ 输出文件在 `dist/` 目录。
418
+
419
+ ## 免责声明
420
+
421
+ 本项目仅供学习研究,禁止商用。使用本项目产生的任何后果由使用者自行承担,与作者无关。
422
+
423
+ 本项目与 Kiro / AWS / Anthropic 官方无关。
KiroProxy/assets/icon.iconset/icon_128x128.png ADDED
KiroProxy/assets/icon.iconset/icon_16x16.png ADDED
KiroProxy/assets/icon.iconset/icon_256x256.png ADDED
KiroProxy/assets/icon.iconset/icon_32x32.png ADDED
KiroProxy/assets/icon.iconset/icon_512x512.png ADDED
KiroProxy/assets/icon.iconset/icon_64x64.png ADDED
KiroProxy/assets/icon.png ADDED
KiroProxy/assets/icon.svg ADDED
KiroProxy/build.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Kiro Proxy Cross-platform Build Script
4
+ Supports: Windows / macOS / Linux
5
+
6
+ Usage:
7
+ python build.py # Build for current platform
8
+ python build.py --all # Show all platform instructions
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import shutil
14
+ import subprocess
15
+ from pathlib import Path
16
+
17
+ from kiro_proxy import __version__ as VERSION
18
+
19
+ APP_NAME = "KiroProxy"
20
+ MAIN_SCRIPT = "run.py"
21
+ ICON_DIR = Path("assets")
22
+
23
+ def get_platform():
24
+ if sys.platform == "win32":
25
+ return "windows"
26
+ elif sys.platform == "darwin":
27
+ return "macos"
28
+ else:
29
+ return "linux"
30
+
31
+ def ensure_pyinstaller():
32
+ try:
33
+ import PyInstaller
34
+ print(f"[OK] PyInstaller {PyInstaller.__version__} installed")
35
+ except ImportError:
36
+ print("[..] Installing PyInstaller...")
37
+ subprocess.run([sys.executable, "-m", "pip", "install", "pyinstaller"], check=True)
38
+
39
+ def clean_build():
40
+ for d in ["build", "dist", f"{APP_NAME}.spec"]:
41
+ if os.path.isdir(d):
42
+ shutil.rmtree(d)
43
+ elif os.path.isfile(d):
44
+ os.remove(d)
45
+ print("[OK] Cleaned build directories")
46
+
47
+ def build_app():
48
+ platform = get_platform()
49
+ print(f"\n{'='*50}")
50
+ print(f" Building {APP_NAME} v{VERSION} - {platform}")
51
+ print(f"{'='*50}\n")
52
+
53
+ ensure_pyinstaller()
54
+ clean_build()
55
+
56
+ args = [
57
+ sys.executable, "-m", "PyInstaller",
58
+ "--name", APP_NAME,
59
+ "--onefile",
60
+ "--clean",
61
+ "--noconfirm",
62
+ ]
63
+
64
+ icon_file = None
65
+ if platform == "windows" and (ICON_DIR / "icon.ico").exists():
66
+ icon_file = ICON_DIR / "icon.ico"
67
+ elif platform == "macos" and (ICON_DIR / "icon.icns").exists():
68
+ icon_file = ICON_DIR / "icon.icns"
69
+ elif (ICON_DIR / "icon.png").exists():
70
+ icon_file = ICON_DIR / "icon.png"
71
+
72
+ if icon_file:
73
+ args.extend(["--icon", str(icon_file)])
74
+ print(f"[OK] Using icon: {icon_file}")
75
+
76
+ # 添加资源文件打包
77
+ if (ICON_DIR).exists():
78
+ if platform == "windows":
79
+ args.extend(["--add-data", f"{ICON_DIR};assets"])
80
+ else:
81
+ args.extend(["--add-data", f"{ICON_DIR}:assets"])
82
+ print(f"[OK] Adding assets directory")
83
+
84
+ # 添加文档文件打包
85
+ docs_dir = Path("kiro_proxy/docs")
86
+ if docs_dir.exists():
87
+ if platform == "windows":
88
+ args.extend(["--add-data", f"{docs_dir};kiro_proxy/docs"])
89
+ else:
90
+ args.extend(["--add-data", f"{docs_dir}:kiro_proxy/docs"])
91
+ print(f"[OK] Adding docs directory")
92
+
93
+ hidden_imports = [
94
+ "uvicorn.logging",
95
+ "uvicorn.protocols.http",
96
+ "uvicorn.protocols.http.auto",
97
+ "uvicorn.protocols.http.h11_impl",
98
+ "uvicorn.protocols.websockets",
99
+ "uvicorn.protocols.websockets.auto",
100
+ "uvicorn.lifespan",
101
+ "uvicorn.lifespan.on",
102
+ "httpx",
103
+ "httpx._transports",
104
+ "httpx._transports.default",
105
+ "anyio",
106
+ "anyio._backends",
107
+ "anyio._backends._asyncio",
108
+ ]
109
+ for imp in hidden_imports:
110
+ args.extend(["--hidden-import", imp])
111
+
112
+ args.append(MAIN_SCRIPT)
113
+ args = [a for a in args if a]
114
+
115
+ print(f"[..] Running: {' '.join(args)}\n")
116
+ result = subprocess.run(args)
117
+
118
+ if result.returncode == 0:
119
+ if platform == "windows":
120
+ output = Path("dist") / f"{APP_NAME}.exe"
121
+ else:
122
+ output = Path("dist") / APP_NAME
123
+
124
+ if output.exists():
125
+ size_mb = output.stat().st_size / (1024 * 1024)
126
+ print(f"\n{'='*50}")
127
+ print(f" [OK] Build successful!")
128
+ print(f" Output: {output}")
129
+ print(f" Size: {size_mb:.1f} MB")
130
+ print(f"{'='*50}")
131
+
132
+ create_release_package(platform, output)
133
+ else:
134
+ print("[FAIL] Build failed: output file not found")
135
+ sys.exit(1)
136
+ else:
137
+ print("[FAIL] Build failed")
138
+ sys.exit(1)
139
+
140
+ def create_release_package(platform, binary_path):
141
+ release_dir = Path("release")
142
+ release_dir.mkdir(exist_ok=True)
143
+
144
+ if platform == "windows":
145
+ archive_name = f"{APP_NAME}-{VERSION}-Windows"
146
+ shutil.copy(binary_path, release_dir / f"{APP_NAME}.exe")
147
+ shutil.make_archive(
148
+ str(release_dir / archive_name),
149
+ "zip",
150
+ release_dir,
151
+ f"{APP_NAME}.exe"
152
+ )
153
+ (release_dir / f"{APP_NAME}.exe").unlink()
154
+ print(f" Release: release/{archive_name}.zip")
155
+
156
+ elif platform == "macos":
157
+ archive_name = f"{APP_NAME}-{VERSION}-macOS"
158
+ shutil.copy(binary_path, release_dir / APP_NAME)
159
+ os.chmod(release_dir / APP_NAME, 0o755)
160
+ shutil.make_archive(
161
+ str(release_dir / archive_name),
162
+ "zip",
163
+ release_dir,
164
+ APP_NAME
165
+ )
166
+ (release_dir / APP_NAME).unlink()
167
+ print(f" Release: release/{archive_name}.zip")
168
+
169
+ else:
170
+ archive_name = f"{APP_NAME}-{VERSION}-Linux"
171
+ shutil.copy(binary_path, release_dir / APP_NAME)
172
+ os.chmod(release_dir / APP_NAME, 0o755)
173
+ shutil.make_archive(
174
+ str(release_dir / archive_name),
175
+ "gztar",
176
+ release_dir,
177
+ APP_NAME
178
+ )
179
+ (release_dir / APP_NAME).unlink()
180
+ print(f" Release: release/{archive_name}.tar.gz")
181
+
182
+ def show_all_platforms():
183
+ print(f"""
184
+ {'='*60}
185
+ Kiro Proxy Cross-platform Build Instructions
186
+ {'='*60}
187
+
188
+ This script must run on the target platform.
189
+
190
+ [Windows]
191
+ Run on Windows:
192
+ python build.py
193
+
194
+ Output: release/KiroProxy-{VERSION}-Windows.zip
195
+
196
+ [macOS]
197
+ Run on macOS:
198
+ python build.py
199
+
200
+ Output: release/KiroProxy-{VERSION}-macOS.zip
201
+
202
+ [Linux]
203
+ Run on Linux:
204
+ python build.py
205
+
206
+ Output: release/KiroProxy-{VERSION}-Linux.tar.gz
207
+
208
+ [GitHub Actions]
209
+ Push to GitHub and Actions will build all platforms.
210
+ See .github/workflows/build.yml
211
+
212
+ {'='*60}
213
+ """)
214
+
215
+ if __name__ == "__main__":
216
+ if "--all" in sys.argv or "-a" in sys.argv:
217
+ show_all_platforms()
218
+ else:
219
+ build_app()
KiroProxy/examples/quota_display_example.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """展示额度重置时间功能的示例"""
2
+ import json
3
+ from datetime import datetime
4
+
5
+
6
+ def generate_quota_display_example():
7
+ """生成额度显示示例"""
8
+
9
+ # 模拟账号的额度信息(从 API 获取)
10
+ quota_data = {
11
+ "subscription_title": "Kiro Pro",
12
+ "usage_limit": 700.0,
13
+ "current_usage": 150.0,
14
+ "balance": 550.0,
15
+ "usage_percent": 21.4,
16
+ "is_low_balance": False,
17
+ "is_exhausted": False,
18
+ "balance_status": "normal",
19
+
20
+ # 免费试用信息
21
+ "free_trial_limit": 500.0,
22
+ "free_trial_usage": 100.0,
23
+ "free_trial_expiry": "2026-02-13T23:59:59Z",
24
+ "trial_expiry_text": "2026-02-13",
25
+
26
+ # 奖励信息
27
+ "bonus_limit": 150.0,
28
+ "bonus_usage": 25.0,
29
+ "bonus_expiries": ["2026-03-01T23:59:59Z", "2026-02-28T23:59:59Z"],
30
+ "active_bonuses": 2,
31
+
32
+ # 重置时间
33
+ "next_reset_date": "2026-02-01T00:00:00Z",
34
+ "reset_date_text": "2026-02-01",
35
+
36
+ # 更新时间
37
+ "updated_at": "2分钟前",
38
+ "error": None
39
+ }
40
+
41
+ # 生成 HTML 显示片段(类似在 Web 界面中的显示)
42
+ html_template = """
43
+ <div class="account-quota-section">
44
+ <div class="quota-header">
45
+ <span>已用/总额</span>
46
+ <span>{current_usage:.1f} / {usage_limit:.1f}</span>
47
+ </div>
48
+ <div class="progress-bar">
49
+ <div class="progress-fill" style="width: {usage_percent:.1f}%"></div>
50
+ </div>
51
+ <div class="quota-detail">
52
+ <span>试用: {free_trial_usage:.0f}/{free_trial_limit:.0f}</span>
53
+ <span>奖励: {bonus_usage:.0f}/{bonus_limit:.0f} ({active_bonuses}个)</span>
54
+ <span>更新: {updated_at}</span>
55
+ </div>
56
+ <div class="quota-reset-info">
57
+ <span>🔄 重置: {reset_date_text}</span>
58
+ <span>🎁 试用过期: {trial_expiry_text}</span>
59
+ </div>
60
+ </div>
61
+ """.format(**quota_data)
62
+
63
+ print("=== 额度信息展示示例 ===")
64
+ print(html_template)
65
+
66
+ # 生成卡片式展示
67
+ card_template = """
68
+ <div class="quota-card">
69
+ <h3>主配额</h3>
70
+ <div class="quota-amount">{current_usage:.0f} / {usage_limit:.0f}</div>
71
+ <div class="quota-reset">2026-02-01 重置</div>
72
+ </div>
73
+ <div class="quota-card">
74
+ <h3>免费试用</h3>
75
+ <div class="quota-amount">{free_trial_usage:.0f} / {free_trial_limit:.0f}</div>
76
+ <div class="quota-expiry">ACTIVE</div>
77
+ <div class="quota-reset">2026-02-13 过期</div>
78
+ </div>
79
+ <div class="quota-card">
80
+ <h3>奖励总计</h3>
81
+ <div class="quota-amount">{bonus_usage:.0f} / {bonus_limit:.0f}</div>
82
+ <div class="quota-expiry">{active_bonuses}个生效奖励</div>
83
+ </div>
84
+ """.format(**quota_data)
85
+
86
+ print("\n=== 卡片式展示(如图所示)===")
87
+ print(card_template)
88
+
89
+ # 生成 JSON 数据
90
+ print("\n=== JSON 数据格式 ===")
91
+ print(json.dumps(quota_data, indent=2, ensure_ascii=False))
92
+
93
+
94
+ if __name__ == "__main__":
95
+ generate_quota_display_example()
KiroProxy/examples/test_quota_display.html ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>额度重置时间测试</title>
6
+ <style>
7
+ body {
8
+ font-family: Arial, sans-serif;
9
+ padding: 20px;
10
+ background: #f5f5f5;
11
+ }
12
+ .account-card {
13
+ background: white;
14
+ border-radius: 10px;
15
+ padding: 20px;
16
+ margin-bottom: 20px;
17
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
18
+ }
19
+ .quota-header {
20
+ display: flex;
21
+ justify-content: space-between;
22
+ margin-bottom: 10px;
23
+ font-weight: bold;
24
+ }
25
+ .progress-bar {
26
+ background: #e0e0e0;
27
+ border-radius: 4px;
28
+ height: 10px;
29
+ margin-bottom: 10px;
30
+ overflow: hidden;
31
+ }
32
+ .progress-fill {
33
+ background: #4CAF50;
34
+ height: 100%;
35
+ transition: width 0.3s;
36
+ }
37
+ .quota-detail {
38
+ display: flex;
39
+ gap: 20px;
40
+ font-size: 0.9em;
41
+ color: #666;
42
+ margin-bottom: 10px;
43
+ }
44
+ .quota-reset-info {
45
+ display: flex;
46
+ gap: 20px;
47
+ font-size: 0.8em;
48
+ color: #888;
49
+ }
50
+ .badge {
51
+ padding: 2px 8px;
52
+ border-radius: 4px;
53
+ font-size: 0.8em;
54
+ }
55
+ .badge.success { background: #4CAF50; color: white; }
56
+ .badge.error { background: #f44336; color: white; }
57
+ </style>
58
+ </head>
59
+ <body>
60
+ <h1>额度重置时间测试</h1>
61
+ <div id="accountsContainer"></div>
62
+
63
+ <script>
64
+ async function loadAccounts() {
65
+ try {
66
+ const response = await fetch('http://localhost:8080/api/accounts/status');
67
+ const data = await response.json();
68
+
69
+ const container = document.getElementById('accountsContainer');
70
+ container.innerHTML = '';
71
+
72
+ data.accounts.forEach(account => {
73
+ const quota = account.quota;
74
+ if (!quota) return;
75
+
76
+ const usedPercent = quota.usage_limit > 0 ? (quota.current_usage / quota.usage_limit * 100) : 0;
77
+ const isExhausted = quota.is_exhausted;
78
+
79
+ const card = document.createElement('div');
80
+ card.className = 'account-card';
81
+ card.innerHTML = `
82
+ <h3>${account.name} <span class="badge ${isExhausted ? 'error' : 'success'}">${isExhausted ? '额度耗尽' : '正常'}</span></h3>
83
+ <div class="quota-header">
84
+ <span>已用/总额</span>
85
+ <span>${quota.current_usage.toFixed(1)} / ${quota.usage_limit.toFixed(1)}</span>
86
+ </div>
87
+ <div class="progress-bar">
88
+ <div class="progress-fill" style="width: ${usedPercent}%"></div>
89
+ </div>
90
+ <div class="quota-detail">
91
+ <span>试用: ${quota.free_trial_usage.toFixed(0)}/${quota.free_trial_limit.toFixed(0)}</span>
92
+ <span>奖励: ${quota.bonus_usage.toFixed(0)}/${quota.bonus_limit.toFixed(0)} (${quota.active_bonuses}个)</span>
93
+ <span>更新: ${quota.updated_at || '未知'}</span>
94
+ </div>
95
+ ${quota.reset_date_text || quota.trial_expiry_text ? `
96
+ <div class="quota-reset-info">
97
+ ${quota.reset_date_text ? `<span>🔄 重置: ${quota.reset_date_text}</span>` : ''}
98
+ ${quota.trial_expiry_text ? `<span>🎁 试用过期: ${quota.trial_expiry_text}</span>` : ''}
99
+ </div>
100
+ ` : ''}
101
+ `;
102
+
103
+ container.appendChild(card);
104
+ });
105
+ } catch (error) {
106
+ console.error('加载失败:', error);
107
+ document.getElementById('accountsContainer').innerHTML = '<p>加载失败,请确保服务器正在运行</p>';
108
+ }
109
+ }
110
+
111
+ // 页面加载时获取数据
112
+ loadAccounts();
113
+
114
+ // 每30秒刷新一次
115
+ setInterval(loadAccounts, 30000);
116
+ </script>
117
+ </body>
118
+ </html>
KiroProxy/kiro.svg ADDED
KiroProxy/kiro_proxy/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Kiro API Proxy
2
+ __version__ = "1.7.1"
KiroProxy/kiro_proxy/__main__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .cli import main
2
+
3
+
4
+ if __name__ == "__main__":
5
+ main()
KiroProxy/kiro_proxy/auth/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kiro 认证模块"""
2
+ from .device_flow import (
3
+ start_device_flow,
4
+ poll_device_flow,
5
+ cancel_device_flow,
6
+ get_login_state,
7
+ save_credentials_to_file,
8
+ DeviceFlowState,
9
+ # Social Auth
10
+ start_social_auth,
11
+ exchange_social_auth_token,
12
+ cancel_social_auth,
13
+ get_social_auth_state,
14
+ start_callback_server,
15
+ wait_for_callback,
16
+ )
17
+
18
+ __all__ = [
19
+ "start_device_flow",
20
+ "poll_device_flow",
21
+ "cancel_device_flow",
22
+ "get_login_state",
23
+ "save_credentials_to_file",
24
+ "DeviceFlowState",
25
+ # Social Auth
26
+ "start_social_auth",
27
+ "exchange_social_auth_token",
28
+ "cancel_social_auth",
29
+ "get_social_auth_state",
30
+ "start_callback_server",
31
+ "wait_for_callback",
32
+ ]
KiroProxy/kiro_proxy/auth/device_flow.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kiro Device Code Flow 登录
2
+
3
+ 实现 AWS OIDC Device Authorization Flow:
4
+ 1. 注册 OIDC 客户端 -> 获取 clientId + clientSecret
5
+ 2. 发起设备授权 -> 获取 deviceCode + userCode + verificationUri
6
+ 3. 用户在浏览器中输入 userCode 完成授权
7
+ 4. 轮询 Token -> 获取 accessToken + refreshToken
8
+
9
+ Social Auth (Google/GitHub):
10
+ 1. 生成 PKCE code_verifier 和 code_challenge
11
+ 2. 构建登录 URL,打开浏览器
12
+ 3. 启动本地回调服务器接收授权码
13
+ 4. 用授权码交换 Token
14
+ """
15
+ import json
16
+ import time
17
+ import httpx
18
+ import secrets
19
+ import hashlib
20
+ import base64
21
+ import asyncio
22
+ from pathlib import Path
23
+ from dataclasses import dataclass, asdict
24
+ from typing import Optional, Tuple
25
+ from datetime import datetime, timezone
26
+
27
+
28
+ @dataclass
29
+ class DeviceFlowState:
30
+ """设备授权流程状态"""
31
+ client_id: str
32
+ client_secret: str
33
+ device_code: str
34
+ user_code: str
35
+ verification_uri: str
36
+ interval: int
37
+ expires_at: int
38
+ region: str
39
+ started_at: float
40
+
41
+
42
+ @dataclass
43
+ class SocialAuthState:
44
+ """Social Auth 登录状态"""
45
+ provider: str # Google / Github
46
+ code_verifier: str
47
+ code_challenge: str
48
+ oauth_state: str
49
+ expires_at: int
50
+ started_at: float
51
+
52
+
53
+ # 全局登录状态
54
+ _login_state: Optional[DeviceFlowState] = None
55
+ _social_auth_state: Optional[SocialAuthState] = None
56
+ _callback_server = None
57
+
58
+ # Kiro OIDC 配置
59
+ KIRO_START_URL = "https://view.awsapps.com/start"
60
+ KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev"
61
+ KIRO_SCOPES = [
62
+ "codewhisperer:completions",
63
+ "codewhisperer:analysis",
64
+ "codewhisperer:conversations",
65
+ "codewhisperer:transformations",
66
+ "codewhisperer:taskassist",
67
+ ]
68
+
69
+
70
+ def get_login_state() -> Optional[dict]:
71
+ """获取当前登录状态"""
72
+ global _login_state
73
+ if _login_state is None:
74
+ return None
75
+
76
+ # 检查是否过期
77
+ if time.time() > _login_state.expires_at:
78
+ _login_state = None
79
+ return None
80
+
81
+ return {
82
+ "user_code": _login_state.user_code,
83
+ "verification_uri": _login_state.verification_uri,
84
+ "expires_in": int(_login_state.expires_at - time.time()),
85
+ "interval": _login_state.interval,
86
+ }
87
+
88
+
89
+ async def start_device_flow(region: str = "us-east-1") -> Tuple[bool, dict]:
90
+ """
91
+ 启动设备授权流程
92
+
93
+ Returns:
94
+ (success, result_or_error)
95
+ """
96
+ global _login_state
97
+
98
+ oidc_base = f"https://oidc.{region}.amazonaws.com"
99
+
100
+ async with httpx.AsyncClient(timeout=30) as client:
101
+ # Step 1: 注册 OIDC 客户端
102
+ print(f"[DeviceFlow] Step 1: 注册 OIDC 客户端...")
103
+
104
+ reg_body = {
105
+ "clientName": "Kiro Proxy",
106
+ "clientType": "public",
107
+ "scopes": KIRO_SCOPES,
108
+ "grantTypes": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
109
+ "issuerUrl": KIRO_START_URL
110
+ }
111
+
112
+ try:
113
+ reg_resp = await client.post(
114
+ f"{oidc_base}/client/register",
115
+ json=reg_body,
116
+ headers={"Content-Type": "application/json"}
117
+ )
118
+ except Exception as e:
119
+ return False, {"error": f"注册客户端请求失败: {e}"}
120
+
121
+ if reg_resp.status_code != 200:
122
+ return False, {"error": f"注册客户端失败: {reg_resp.text}"}
123
+
124
+ reg_data = reg_resp.json()
125
+ client_id = reg_data.get("clientId")
126
+ client_secret = reg_data.get("clientSecret")
127
+
128
+ if not client_id or not client_secret:
129
+ return False, {"error": "注册响应缺少 clientId 或 clientSecret"}
130
+
131
+ print(f"[DeviceFlow] 客户端注册成功: {client_id[:20]}...")
132
+
133
+ # Step 2: 发起设备授权
134
+ print(f"[DeviceFlow] Step 2: 发起设备授权...")
135
+
136
+ auth_body = {
137
+ "clientId": client_id,
138
+ "clientSecret": client_secret,
139
+ "startUrl": KIRO_START_URL
140
+ }
141
+
142
+ try:
143
+ auth_resp = await client.post(
144
+ f"{oidc_base}/device_authorization",
145
+ json=auth_body,
146
+ headers={"Content-Type": "application/json"}
147
+ )
148
+ except Exception as e:
149
+ return False, {"error": f"设备授权请求失败: {e}"}
150
+
151
+ if auth_resp.status_code != 200:
152
+ return False, {"error": f"设备授权失败: {auth_resp.text}"}
153
+
154
+ auth_data = auth_resp.json()
155
+ device_code = auth_data.get("deviceCode")
156
+ user_code = auth_data.get("userCode")
157
+ verification_uri = auth_data.get("verificationUriComplete") or auth_data.get("verificationUri")
158
+ interval = auth_data.get("interval", 5)
159
+ expires_in = auth_data.get("expiresIn", 600)
160
+
161
+ if not device_code or not user_code or not verification_uri:
162
+ return False, {"error": "设备授权响应缺少必要字��"}
163
+
164
+ print(f"[DeviceFlow] 设备码获取成功: {user_code}")
165
+
166
+ # 保存状态
167
+ _login_state = DeviceFlowState(
168
+ client_id=client_id,
169
+ client_secret=client_secret,
170
+ device_code=device_code,
171
+ user_code=user_code,
172
+ verification_uri=verification_uri,
173
+ interval=interval,
174
+ expires_at=int(time.time() + expires_in),
175
+ region=region,
176
+ started_at=time.time()
177
+ )
178
+
179
+ return True, {
180
+ "user_code": user_code,
181
+ "verification_uri": verification_uri,
182
+ "expires_in": expires_in,
183
+ "interval": interval,
184
+ }
185
+
186
+
187
+ async def poll_device_flow() -> Tuple[bool, dict]:
188
+ """
189
+ 轮询设备授权状态
190
+
191
+ Returns:
192
+ (success, result_or_error)
193
+ - success=True, result={"completed": True, "credentials": {...}} 授权完成
194
+ - success=True, result={"completed": False, "status": "pending"} 等待中
195
+ - success=False, result={"error": "..."} 错误
196
+ """
197
+ global _login_state
198
+
199
+ if _login_state is None:
200
+ return False, {"error": "没有进行中的登录"}
201
+
202
+ # 检查是否过期
203
+ if time.time() > _login_state.expires_at:
204
+ _login_state = None
205
+ return False, {"error": "授权已过期,请重新开始"}
206
+
207
+ oidc_base = f"https://oidc.{_login_state.region}.amazonaws.com"
208
+
209
+ token_body = {
210
+ "clientId": _login_state.client_id,
211
+ "clientSecret": _login_state.client_secret,
212
+ "grantType": "urn:ietf:params:oauth:grant-type:device_code",
213
+ "deviceCode": _login_state.device_code
214
+ }
215
+
216
+ async with httpx.AsyncClient(timeout=30) as client:
217
+ try:
218
+ token_resp = await client.post(
219
+ f"{oidc_base}/token",
220
+ json=token_body,
221
+ headers={"Content-Type": "application/json"}
222
+ )
223
+ except Exception as e:
224
+ return False, {"error": f"Token 请求失败: {e}"}
225
+
226
+ if token_resp.status_code == 200:
227
+ # 授权成功
228
+ token_data = token_resp.json()
229
+
230
+ credentials = {
231
+ "accessToken": token_data.get("accessToken"),
232
+ "refreshToken": token_data.get("refreshToken"),
233
+ "expiresAt": datetime.now(timezone.utc).isoformat(),
234
+ "clientId": _login_state.client_id,
235
+ "clientSecret": _login_state.client_secret,
236
+ "region": _login_state.region,
237
+ "authMethod": "idc",
238
+ }
239
+
240
+ # 计算过期时间
241
+ if expires_in := token_data.get("expiresIn"):
242
+ from datetime import timedelta
243
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
244
+ credentials["expiresAt"] = expires_at.isoformat()
245
+
246
+ # 清除状态
247
+ _login_state = None
248
+
249
+ print(f"[DeviceFlow] 授权成功!")
250
+ return True, {"completed": True, "credentials": credentials}
251
+
252
+ # 检查错误类型
253
+ try:
254
+ error_data = token_resp.json()
255
+ error_code = error_data.get("error", "")
256
+ except:
257
+ error_code = ""
258
+
259
+ if error_code == "authorization_pending":
260
+ # 用户还未完成授权
261
+ return True, {"completed": False, "status": "pending"}
262
+ elif error_code == "slow_down":
263
+ # 请求太频繁
264
+ return True, {"completed": False, "status": "slow_down"}
265
+ elif error_code == "expired_token":
266
+ _login_state = None
267
+ return False, {"error": "授权已过期,请重新开始"}
268
+ elif error_code == "access_denied":
269
+ _login_state = None
270
+ return False, {"error": "用户拒绝授权"}
271
+ else:
272
+ return False, {"error": f"Token 请求失败: {token_resp.text}"}
273
+
274
+
275
+ def cancel_device_flow() -> bool:
276
+ """取消设备授权流程"""
277
+ global _login_state
278
+ if _login_state is not None:
279
+ _login_state = None
280
+ return True
281
+ return False
282
+
283
+
284
+ async def save_credentials_to_file(credentials: dict, name: str = "kiro-proxy-auth") -> str:
285
+ """
286
+ 保存凭证到文件
287
+
288
+ 支持的字段:
289
+ - accessToken, refreshToken, profileArn, expiresAt
290
+ - clientId, clientSecret (IDC 认证)
291
+ - region, authMethod, provider
292
+
293
+ Returns:
294
+ 保存的文件路径
295
+ """
296
+ from ..config import TOKEN_DIR
297
+ TOKEN_DIR.mkdir(parents=True, exist_ok=True)
298
+
299
+ # 生成文件名
300
+ file_path = TOKEN_DIR / f"{name}.json"
301
+
302
+ # 如果文件已存在,合并现有数据
303
+ existing = {}
304
+ if file_path.exists():
305
+ try:
306
+ with open(file_path, "r") as f:
307
+ existing = json.load(f)
308
+ except Exception:
309
+ pass
310
+
311
+ # 更新凭证(只更新非空值)
312
+ for key, value in credentials.items():
313
+ if value is not None:
314
+ existing[key] = value
315
+
316
+ with open(file_path, "w") as f:
317
+ json.dump(existing, f, indent=2)
318
+
319
+ print(f"[DeviceFlow] 凭证已保存到: {file_path}")
320
+ return str(file_path)
321
+
322
+
323
+ # ==================== Social Auth (Google/GitHub) ====================
324
+
325
+ def _generate_code_verifier() -> str:
326
+ """生成 PKCE code_verifier"""
327
+ return secrets.token_urlsafe(64)[:128]
328
+
329
+
330
+ def _generate_code_challenge(verifier: str) -> str:
331
+ """生成 PKCE code_challenge (SHA256)"""
332
+ digest = hashlib.sha256(verifier.encode()).digest()
333
+ return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
334
+
335
+
336
+ def _generate_oauth_state() -> str:
337
+ """生成 OAuth state"""
338
+ return secrets.token_urlsafe(32)
339
+
340
+
341
+ def get_social_auth_state() -> Optional[dict]:
342
+ """获取当前 Social Auth 状态"""
343
+ global _social_auth_state
344
+ if _social_auth_state is None:
345
+ return None
346
+
347
+ if time.time() > _social_auth_state.expires_at:
348
+ _social_auth_state = None
349
+ return None
350
+
351
+ return {
352
+ "provider": _social_auth_state.provider,
353
+ "expires_in": int(_social_auth_state.expires_at - time.time()),
354
+ }
355
+
356
+
357
+ async def start_social_auth(provider: str, redirect_uri: str = None) -> Tuple[bool, dict]:
358
+ """
359
+ 启动 Social Auth 登录 (Google/GitHub)
360
+
361
+ Args:
362
+ provider: "google" 或 "github"
363
+ redirect_uri: 回调地址,默认使用 Kiro 官方回调地址
364
+
365
+ Returns:
366
+ (success, result_or_error)
367
+ """
368
+ global _social_auth_state
369
+
370
+ # 验证 provider
371
+ provider_normalized = provider.lower()
372
+ if provider_normalized == "google":
373
+ provider_normalized = "Google"
374
+ elif provider_normalized == "github":
375
+ provider_normalized = "Github"
376
+ else:
377
+ return False, {"error": f"不支持的登录提供商: {provider}"}
378
+
379
+ print(f"[SocialAuth] 开始 {provider_normalized} 登录流程")
380
+
381
+ # 生成 PKCE
382
+ code_verifier = _generate_code_verifier()
383
+ code_challenge = _generate_code_challenge(code_verifier)
384
+ oauth_state = _generate_oauth_state()
385
+
386
+ # 回调地址 - 使用 Kiro 官方的回调地址(已在 Cognito 中注册)
387
+ # 参考 Kiro-account-manager: kiro://kiro.kiroAgent/authenticate-success
388
+ if redirect_uri is None:
389
+ redirect_uri = "kiro://kiro.kiroAgent/authenticate-success"
390
+
391
+ # 构建登录 URL (使用 /login 端点,参考 Kiro-account-manager)
392
+ from urllib.parse import quote, urlencode
393
+
394
+ # 使用 urlencode 确保参数正确编码
395
+ params = {
396
+ "idp": provider_normalized,
397
+ "redirect_uri": redirect_uri,
398
+ "code_challenge": code_challenge,
399
+ "code_challenge_method": "S256",
400
+ "state": oauth_state,
401
+ }
402
+ login_url = f"{KIRO_AUTH_ENDPOINT}/login?{urlencode(params)}"
403
+
404
+ print(f"[SocialAuth] ========== Social Auth 登录 ==========")
405
+ print(f"[SocialAuth] Provider: {provider_normalized}")
406
+ print(f"[SocialAuth] Redirect URI: {redirect_uri}")
407
+ print(f"[SocialAuth] Code Challenge: {code_challenge[:20]}...")
408
+ print(f"[SocialAuth] State: {oauth_state}")
409
+ print(f"[SocialAuth] 登录 URL: {login_url}")
410
+ print(f"[SocialAuth] =========================================")
411
+
412
+ # 保存状态(10 分钟过期)
413
+ _social_auth_state = SocialAuthState(
414
+ provider=provider_normalized,
415
+ code_verifier=code_verifier,
416
+ code_challenge=code_challenge,
417
+ oauth_state=oauth_state,
418
+ expires_at=int(time.time() + 600),
419
+ started_at=time.time(),
420
+ )
421
+
422
+ return True, {
423
+ "login_url": login_url,
424
+ "state": oauth_state,
425
+ "provider": provider_normalized,
426
+ "redirect_uri": redirect_uri,
427
+ }
428
+
429
+
430
+ async def exchange_social_auth_token(code: str, state: str, redirect_uri: str = None) -> Tuple[bool, dict]:
431
+ """
432
+ 用授权码交换 Token
433
+
434
+ 参考 Kiro-account-manager 实现:
435
+ - 端点: https://prod.us-east-1.auth.desktop.kiro.dev/oauth/token
436
+ - 请求体: {code, code_verifier, redirect_uri}
437
+ - 响应: {accessToken, refreshToken, profileArn, expiresIn}
438
+
439
+ Args:
440
+ code: 授权码
441
+ state: OAuth state
442
+ redirect_uri: 回调地址(需要与 start_social_auth 中使用的一致)
443
+
444
+ Returns:
445
+ (success, result_or_error)
446
+ """
447
+ global _social_auth_state
448
+
449
+ if _social_auth_state is None:
450
+ return False, {"error": "没有进行中的社交登录"}
451
+
452
+ # 验证 state
453
+ if state != _social_auth_state.oauth_state:
454
+ _social_auth_state = None
455
+ return False, {"error": "OAuth state 不匹配"}
456
+
457
+ # 检查过期
458
+ if time.time() > _social_auth_state.expires_at:
459
+ _social_auth_state = None
460
+ return False, {"error": "登录已过期,请重新开始"}
461
+
462
+ print(f"[SocialAuth] 交换 Token...")
463
+
464
+ # 回调地址 - 需要与 start_social_auth 中使��的一致
465
+ # 使用 Kiro 官方的回调地址
466
+ if redirect_uri is None:
467
+ redirect_uri = "kiro://kiro.kiroAgent/authenticate-success"
468
+
469
+ # 交换 Token (参考 Kiro-account-manager 的请求格式)
470
+ token_body = {
471
+ "code": code,
472
+ "code_verifier": _social_auth_state.code_verifier,
473
+ "redirect_uri": redirect_uri,
474
+ }
475
+
476
+ async with httpx.AsyncClient(timeout=30) as client:
477
+ try:
478
+ token_resp = await client.post(
479
+ f"{KIRO_AUTH_ENDPOINT}/oauth/token",
480
+ json=token_body,
481
+ headers={"Content-Type": "application/json"}
482
+ )
483
+ except Exception as e:
484
+ _social_auth_state = None
485
+ return False, {"error": f"Token 请求失败: {e}"}
486
+
487
+ if token_resp.status_code != 200:
488
+ error_text = token_resp.text
489
+ _social_auth_state = None
490
+ return False, {"error": f"Token 交换失败: {error_text}"}
491
+
492
+ token_data = token_resp.json()
493
+
494
+ # 解析响应 (参考 Kiro-account-manager 的响应格式)
495
+ # 响应字段: accessToken, refreshToken, profileArn, expiresIn
496
+ provider = _social_auth_state.provider
497
+
498
+ credentials = {
499
+ "accessToken": token_data.get("accessToken") or token_data.get("access_token"),
500
+ "refreshToken": token_data.get("refreshToken") or token_data.get("refresh_token"),
501
+ "profileArn": token_data.get("profileArn"),
502
+ "expiresAt": datetime.now(timezone.utc).isoformat(),
503
+ "authMethod": "social",
504
+ "provider": provider, # 保存 provider 字段
505
+ }
506
+
507
+ # 计算过期时间
508
+ expires_in = token_data.get("expiresIn") or token_data.get("expires_in")
509
+ if expires_in:
510
+ from datetime import timedelta
511
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
512
+ credentials["expiresAt"] = expires_at.isoformat()
513
+
514
+ _social_auth_state = None
515
+
516
+ print(f"[SocialAuth] {provider} 登录成功!")
517
+ return True, {"completed": True, "credentials": credentials, "provider": provider}
518
+
519
+
520
+ def cancel_social_auth() -> bool:
521
+ """取消 Social Auth 登录"""
522
+ global _social_auth_state
523
+ if _social_auth_state is not None:
524
+ _social_auth_state = None
525
+ return True
526
+ return False
527
+
528
+
529
+ # ==================== 回调服务器 ====================
530
+
531
+ _callback_result = None
532
+ _callback_event = None
533
+
534
+ async def start_callback_server() -> Tuple[bool, dict]:
535
+ """启动本地回调服务器"""
536
+ global _callback_result, _callback_event
537
+
538
+ from aiohttp import web
539
+
540
+ _callback_result = None
541
+ _callback_event = asyncio.Event()
542
+
543
+ async def handle_callback(request):
544
+ global _callback_result
545
+ code = request.query.get("code")
546
+ state = request.query.get("state")
547
+ error = request.query.get("error")
548
+
549
+ if error:
550
+ _callback_result = {"error": error}
551
+ elif code and state:
552
+ _callback_result = {"code": code, "state": state}
553
+ else:
554
+ _callback_result = {"error": "缺少授权码"}
555
+
556
+ _callback_event.set()
557
+
558
+ # 返回成功页面
559
+ html = """
560
+ <html>
561
+ <head><title>登录成功</title></head>
562
+ <body style="font-family:sans-serif;text-align:center;padding:50px">
563
+ <h1>✅ 登录成功</h1>
564
+ <p>您可以关闭此窗口并返回 Kiro Proxy</p>
565
+ <script>setTimeout(()=>window.close(),2000)</script>
566
+ </body>
567
+ </html>
568
+ """
569
+ return web.Response(text=html, content_type="text/html")
570
+
571
+ app = web.Application()
572
+ app.router.add_get("/kiro-social-callback", handle_callback)
573
+
574
+ runner = web.AppRunner(app)
575
+ await runner.setup()
576
+
577
+ try:
578
+ site = web.TCPSite(runner, "127.0.0.1", 19823)
579
+ await site.start()
580
+ print("[SocialAuth] 回调服务器已启动: http://127.0.0.1:19823")
581
+ return True, {"port": 19823}
582
+ except Exception as e:
583
+ return False, {"error": f"启动回调服务器失败: {e}"}
584
+
585
+
586
+ async def wait_for_callback(timeout: int = 300) -> Tuple[bool, dict]:
587
+ """等待回调"""
588
+ global _callback_result, _callback_event
589
+
590
+ if _callback_event is None:
591
+ return False, {"error": "回调服务器未启动"}
592
+
593
+ try:
594
+ await asyncio.wait_for(_callback_event.wait(), timeout=timeout)
595
+
596
+ if _callback_result and "code" in _callback_result:
597
+ return True, _callback_result
598
+ elif _callback_result and "error" in _callback_result:
599
+ return False, _callback_result
600
+ else:
601
+ return False, {"error": "未收到有效回调"}
602
+ except asyncio.TimeoutError:
603
+ return False, {"error": "等待回调超时"}
KiroProxy/kiro_proxy/cli.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Kiro Proxy CLI - 轻量命令行工具"""
3
+ import argparse
4
+ import asyncio
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ from . import __version__
10
+
11
+
12
+ def cmd_serve(args):
13
+ """启动代理服务"""
14
+ from .main import run
15
+ run(port=args.port)
16
+
17
+
18
+ def cmd_accounts_list(args):
19
+ """列出所有账号"""
20
+ from .core import state
21
+ accounts = state.get_accounts_status()
22
+ if not accounts:
23
+ print("暂无账号")
24
+ return
25
+ print(f"{'ID':<10} {'名称':<20} {'状态':<10} {'请求数':<8}")
26
+ print("-" * 50)
27
+ for acc in accounts:
28
+ print(f"{acc['id']:<10} {acc['name']:<20} {acc['status']:<10} {acc['request_count']:<8}")
29
+
30
+
31
+ def cmd_accounts_export(args):
32
+ """导出账号配置"""
33
+ from .core import state
34
+ accounts_data = []
35
+ for acc in state.accounts:
36
+ creds = acc.get_credentials()
37
+ if creds:
38
+ accounts_data.append({
39
+ "name": acc.name,
40
+ "enabled": acc.enabled,
41
+ "credentials": {
42
+ "accessToken": creds.access_token,
43
+ "refreshToken": creds.refresh_token,
44
+ "expiresAt": creds.expires_at,
45
+ "region": creds.region,
46
+ "authMethod": creds.auth_method,
47
+ }
48
+ })
49
+
50
+ output = {"accounts": accounts_data, "version": "1.0"}
51
+
52
+ if args.output:
53
+ Path(args.output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
54
+ print(f"已导出 {len(accounts_data)} 个账号到 {args.output}")
55
+ else:
56
+ print(json.dumps(output, indent=2, ensure_ascii=False))
57
+
58
+
59
+ def cmd_accounts_import(args):
60
+ """导入账号配置"""
61
+ import uuid
62
+ from .core import state, Account
63
+ from .auth import save_credentials_to_file
64
+
65
+ data = json.loads(Path(args.file).read_text())
66
+ accounts_data = data.get("accounts", [])
67
+ imported = 0
68
+
69
+ for acc_data in accounts_data:
70
+ creds = acc_data.get("credentials", {})
71
+ if not creds.get("accessToken"):
72
+ print(f"跳过 {acc_data.get('name', '未知')}: 缺少 accessToken")
73
+ continue
74
+
75
+ # 保存凭证到文件
76
+ file_path = asyncio.run(save_credentials_to_file({
77
+ "accessToken": creds.get("accessToken"),
78
+ "refreshToken": creds.get("refreshToken"),
79
+ "expiresAt": creds.get("expiresAt"),
80
+ "region": creds.get("region", "us-east-1"),
81
+ "authMethod": creds.get("authMethod", "social"),
82
+ }, f"imported-{uuid.uuid4().hex[:8]}"))
83
+
84
+ account = Account(
85
+ id=uuid.uuid4().hex[:8],
86
+ name=acc_data.get("name", "导入账号"),
87
+ token_path=file_path,
88
+ enabled=acc_data.get("enabled", True)
89
+ )
90
+ state.accounts.append(account)
91
+ account.load_credentials()
92
+ imported += 1
93
+ print(f"已导入: {account.name}")
94
+
95
+ state._save_accounts()
96
+ print(f"\n共导入 {imported} 个账号")
97
+
98
+
99
+ def cmd_accounts_add(args):
100
+ """手动添加 Token"""
101
+ import uuid
102
+ from .core import state, Account
103
+ from .auth import save_credentials_to_file
104
+
105
+ print("手动添加 Kiro 账号")
106
+ print("-" * 40)
107
+
108
+ name = input("账号名称 [我的账号]: ").strip() or "我的账号"
109
+ print("\n请粘贴 Access Token:")
110
+ access_token = input().strip()
111
+
112
+ if not access_token:
113
+ print("错误: Access Token 不能为空")
114
+ return
115
+
116
+ print("\n请粘贴 Refresh Token (可选,直接回车跳过):")
117
+ refresh_token = input().strip() or None
118
+
119
+ # 保存凭证
120
+ file_path = asyncio.run(save_credentials_to_file({
121
+ "accessToken": access_token,
122
+ "refreshToken": refresh_token,
123
+ "region": "us-east-1",
124
+ "authMethod": "social",
125
+ }, f"manual-{uuid.uuid4().hex[:8]}"))
126
+
127
+ account = Account(
128
+ id=uuid.uuid4().hex[:8],
129
+ name=name,
130
+ token_path=file_path
131
+ )
132
+ state.accounts.append(account)
133
+ account.load_credentials()
134
+ state._save_accounts()
135
+
136
+ print(f"\n✅ 账号已添加: {name} (ID: {account.id})")
137
+
138
+
139
+ def cmd_accounts_scan(args):
140
+ """扫描本地 Token"""
141
+ import uuid
142
+ from .core import state, Account
143
+ from .config import TOKEN_DIR
144
+
145
+ # 扫描新目录
146
+ found = []
147
+ if TOKEN_DIR.exists():
148
+ for f in TOKEN_DIR.glob("*.json"):
149
+ try:
150
+ data = json.loads(f.read_text())
151
+ if "accessToken" in data:
152
+ already = any(a.token_path == str(f) for a in state.accounts)
153
+ found.append({"path": str(f), "name": f.stem, "already": already})
154
+ except:
155
+ pass
156
+
157
+ # 兼容旧目录
158
+ sso_cache = Path.home() / ".aws/sso/cache"
159
+ if sso_cache.exists():
160
+ for f in sso_cache.glob("*.json"):
161
+ try:
162
+ data = json.loads(f.read_text())
163
+ if "accessToken" in data:
164
+ already = any(a.token_path == str(f) for a in state.accounts)
165
+ found.append({"path": str(f), "name": f.stem + " (旧目录)", "already": already})
166
+ except:
167
+ pass
168
+
169
+ if not found:
170
+ print("未找到 Token 文件")
171
+ print(f"Token 目录: {TOKEN_DIR}")
172
+ return
173
+
174
+ print(f"找到 {len(found)} 个 Token:\n")
175
+ for i, t in enumerate(found):
176
+ status = "[已添加]" if t["already"] else ""
177
+ print(f" {i+1}. {t['name']} {status}")
178
+
179
+ if args.auto:
180
+ # 自动添加所有未添加的
181
+ added = 0
182
+ for t in found:
183
+ if not t["already"]:
184
+ account = Account(
185
+ id=uuid.uuid4().hex[:8],
186
+ name=t["name"],
187
+ token_path=t["path"]
188
+ )
189
+ state.accounts.append(account)
190
+ account.load_credentials()
191
+ added += 1
192
+ state._save_accounts()
193
+ print(f"\n已添加 {added} 个账号")
194
+ else:
195
+ print("\n使用 --auto 自动添加所有未添加的账号")
196
+
197
+
198
+ def cmd_login_remote(args):
199
+ """生成远程登录链接"""
200
+ import uuid
201
+ import time
202
+
203
+ session_id = uuid.uuid4().hex
204
+ host = args.host or "localhost:8080"
205
+ scheme = "https" if args.https else "http"
206
+
207
+ print("远程登录链接")
208
+ print("-" * 40)
209
+ print(f"\n将以下链接发送到有浏览器的机器上完成登录:\n")
210
+ print(f" {scheme}://{host}/remote-login/{session_id}")
211
+ print(f"\n链接有效期 10 分钟")
212
+ print("\n登录完成后,在那台机器上导出账号,然后在这里导入:")
213
+ print(f" python -m kiro_proxy accounts import xxx.json")
214
+
215
+
216
+ def cmd_login_social(args):
217
+ """Social 登录 (Google/GitHub)"""
218
+ from .auth import start_social_auth
219
+
220
+ provider = args.provider
221
+ print(f"启动 {provider.title()} 登录...")
222
+
223
+ success, result = asyncio.run(start_social_auth(provider))
224
+ if not success:
225
+ print(f"错误: {result.get('error', '未知错误')}")
226
+ return
227
+
228
+ print(f"\n请在浏览器中打开以下链接完成授权:\n")
229
+ print(f" {result['login_url']}")
230
+ print(f"\n授权完成后,将浏览器地址栏中的完整 URL 粘贴到这里:")
231
+ callback_url = input().strip()
232
+
233
+ if not callback_url:
234
+ print("已取消")
235
+ return
236
+
237
+ try:
238
+ from urllib.parse import urlparse, parse_qs
239
+ parsed = urlparse(callback_url)
240
+ params = parse_qs(parsed.query)
241
+ code = params.get("code", [None])[0]
242
+ oauth_state = params.get("state", [None])[0]
243
+
244
+ if not code or not oauth_state:
245
+ print("错误: 无效的回调 URL")
246
+ return
247
+
248
+ from .auth import exchange_social_auth_token
249
+ success, result = asyncio.run(exchange_social_auth_token(code, oauth_state))
250
+
251
+ if success and result.get("completed"):
252
+ import uuid
253
+ from .core import state, Account
254
+ from .auth import save_credentials_to_file
255
+
256
+ credentials = result["credentials"]
257
+ file_path = asyncio.run(save_credentials_to_file(
258
+ credentials, f"cli-{provider}"
259
+ ))
260
+
261
+ account = Account(
262
+ id=uuid.uuid4().hex[:8],
263
+ name=f"{provider.title()} 登录",
264
+ token_path=file_path
265
+ )
266
+ state.accounts.append(account)
267
+ account.load_credentials()
268
+ state._save_accounts()
269
+
270
+ print(f"\n✅ 登录成功! 账号已添加: {account.name}")
271
+ else:
272
+ print(f"错误: {result.get('error', '登录失败')}")
273
+ except Exception as e:
274
+ print(f"错误: {e}")
275
+
276
+
277
+ def cmd_status(args):
278
+ """查看服务状态"""
279
+ from .core import state
280
+ stats = state.get_stats()
281
+
282
+ print("Kiro Proxy 状态")
283
+ print("-" * 40)
284
+ print(f"运行时间: {stats['uptime_seconds']} 秒")
285
+ print(f"总请求数: {stats['total_requests']}")
286
+ print(f"错误数: {stats['total_errors']}")
287
+ print(f"错误率: {stats['error_rate']}")
288
+ print(f"账号总数: {stats['accounts_total']}")
289
+ print(f"可用账号: {stats['accounts_available']}")
290
+ print(f"冷却中: {stats['accounts_cooldown']}")
291
+
292
+
293
+ def main():
294
+ parser = argparse.ArgumentParser(
295
+ prog="kiro-proxy",
296
+ description="Kiro API Proxy CLI"
297
+ )
298
+ parser.add_argument("-v", "--version", action="version", version=__version__)
299
+
300
+ subparsers = parser.add_subparsers(dest="command", help="命令")
301
+
302
+ # serve
303
+ serve_parser = subparsers.add_parser("serve", help="启动代理服务")
304
+ serve_parser.add_argument("-p", "--port", type=int, default=8080, help="端口号")
305
+ serve_parser.set_defaults(func=cmd_serve)
306
+
307
+ # status
308
+ status_parser = subparsers.add_parser("status", help="查看状态")
309
+ status_parser.set_defaults(func=cmd_status)
310
+
311
+ # accounts
312
+ accounts_parser = subparsers.add_parser("accounts", help="账号管理")
313
+ accounts_sub = accounts_parser.add_subparsers(dest="accounts_cmd")
314
+
315
+ # accounts list
316
+ list_parser = accounts_sub.add_parser("list", help="列出账号")
317
+ list_parser.set_defaults(func=cmd_accounts_list)
318
+
319
+ # accounts export
320
+ export_parser = accounts_sub.add_parser("export", help="导出账号")
321
+ export_parser.add_argument("-o", "--output", help="输出文件")
322
+ export_parser.set_defaults(func=cmd_accounts_export)
323
+
324
+ # accounts import
325
+ import_parser = accounts_sub.add_parser("import", help="导入账号")
326
+ import_parser.add_argument("file", help="JSON 文件路径")
327
+ import_parser.set_defaults(func=cmd_accounts_import)
328
+
329
+ # accounts add
330
+ add_parser = accounts_sub.add_parser("add", help="手动添加 Token")
331
+ add_parser.set_defaults(func=cmd_accounts_add)
332
+
333
+ # accounts scan
334
+ scan_parser = accounts_sub.add_parser("scan", help="扫描本地 Token")
335
+ scan_parser.add_argument("--auto", action="store_true", help="自动添加")
336
+ scan_parser.set_defaults(func=cmd_accounts_scan)
337
+
338
+ # login
339
+ login_parser = subparsers.add_parser("login", help="登录")
340
+ login_sub = login_parser.add_subparsers(dest="login_cmd")
341
+
342
+ # login remote
343
+ remote_parser = login_sub.add_parser("remote", help="生成远程登录链接")
344
+ remote_parser.add_argument("--host", help="服务器地址 (如 example.com:8080)")
345
+ remote_parser.add_argument("--https", action="store_true", help="使用 HTTPS")
346
+ remote_parser.set_defaults(func=cmd_login_remote)
347
+
348
+ # login google
349
+ google_parser = login_sub.add_parser("google", help="Google 登录")
350
+ google_parser.set_defaults(func=cmd_login_social, provider="google")
351
+
352
+ # login github
353
+ github_parser = login_sub.add_parser("github", help="GitHub 登录")
354
+ github_parser.set_defaults(func=cmd_login_social, provider="github")
355
+
356
+ args = parser.parse_args()
357
+
358
+ if not args.command:
359
+ parser.print_help()
360
+ return
361
+
362
+ if args.command == "accounts" and not args.accounts_cmd:
363
+ accounts_parser.print_help()
364
+ return
365
+
366
+ if args.command == "login" and not args.login_cmd:
367
+ login_parser.print_help()
368
+ return
369
+
370
+ if hasattr(args, "func"):
371
+ args.func(args)
372
+
373
+
374
+ if __name__ == "__main__":
375
+ main()
KiroProxy/kiro_proxy/config.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """配置模块"""
2
+ from pathlib import Path
3
+
4
+ KIRO_API_URL = "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
5
+ MODELS_URL = "https://q.us-east-1.amazonaws.com/ListAvailableModels"
6
+
7
+ # 统一数据目录 (所有配置文件都在这里)
8
+ DATA_DIR = Path.home() / ".kiro-proxy"
9
+
10
+ # Token 存储目录
11
+ TOKEN_DIR = DATA_DIR / "tokens"
12
+
13
+ # 默认 Token 路径 (兼容旧代码)
14
+ TOKEN_PATH = TOKEN_DIR / "kiro-auth-token.json"
15
+
16
+ # 配额管理配置
17
+ QUOTA_COOLDOWN_SECONDS = 300 # 配额超限冷却时间(秒)
18
+
19
+ # 模型映射
20
+ MODEL_MAPPING = {
21
+ # Claude 3.5 -> Kiro Claude 4
22
+ "claude-3-5-sonnet-20241022": "claude-sonnet-4",
23
+ "claude-3-5-sonnet-latest": "claude-sonnet-4",
24
+ "claude-3-5-sonnet": "claude-sonnet-4",
25
+ "claude-3-5-haiku-20241022": "claude-haiku-4.5",
26
+ "claude-3-5-haiku-latest": "claude-haiku-4.5",
27
+ # Claude 3
28
+ "claude-3-opus-20240229": "claude-sonnet-4.5",
29
+ "claude-3-opus-latest": "claude-sonnet-4.5",
30
+ "claude-3-sonnet-20240229": "claude-sonnet-4",
31
+ "claude-3-haiku-20240307": "claude-haiku-4.5",
32
+ # Claude 4
33
+ "claude-4-sonnet": "claude-sonnet-4",
34
+ "claude-4-opus": "claude-sonnet-4.5",
35
+ # OpenAI GPT -> Claude
36
+ "gpt-4o": "claude-sonnet-4",
37
+ "gpt-4o-mini": "claude-haiku-4.5",
38
+ "gpt-4-turbo": "claude-sonnet-4",
39
+ "gpt-4": "claude-sonnet-4",
40
+ "gpt-3.5-turbo": "claude-haiku-4.5",
41
+ # OpenAI o1 -> Claude Opus
42
+ "o1": "claude-sonnet-4.5",
43
+ "o1-preview": "claude-sonnet-4.5",
44
+ "o1-mini": "claude-sonnet-4",
45
+ # Gemini -> Claude
46
+ "gemini-2.0-flash": "claude-sonnet-4",
47
+ "gemini-2.0-flash-thinking": "claude-sonnet-4.5",
48
+ "gemini-1.5-pro": "claude-sonnet-4.5",
49
+ "gemini-1.5-flash": "claude-sonnet-4",
50
+ # 别名
51
+ "sonnet": "claude-sonnet-4",
52
+ "haiku": "claude-haiku-4.5",
53
+ "opus": "claude-sonnet-4.5",
54
+ }
55
+
56
+ KIRO_MODELS = {"auto", "claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5"}
57
+
58
+ def get_best_model_by_tier(tier: str, available_models: set = None) -> str:
59
+ """根据等级获取最佳可用模型(等级对等 + 智能降级)"""
60
+ if available_models is None:
61
+ available_models = KIRO_MODELS
62
+
63
+ # 等级对等映射 + 降级路径
64
+ TIER_PRIORITIES = {
65
+ # Opus: 最强 → 次强 → 快速 → 自动
66
+ "opus": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"],
67
+
68
+ # Sonnet: 高性能 → 最强 → 标准 → 快速 → 自动
69
+ "sonnet": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"],
70
+
71
+ # Haiku: 快速 → 标准 → 高性能 → 自动
72
+ "haiku": ["claude-haiku-4.5", "claude-sonnet-4", "claude-sonnet-4.5", "auto"],
73
+ }
74
+
75
+ priorities = TIER_PRIORITIES.get(tier, TIER_PRIORITIES["sonnet"])
76
+
77
+ # 选择第一个可用的模型
78
+ for model in priorities:
79
+ if model in available_models:
80
+ return model
81
+
82
+ return "auto" # 最终回退
83
+
84
+
85
+ def detect_model_tier(model: str) -> str:
86
+ """智能检测模型等级"""
87
+ if not model:
88
+ return "sonnet" # 默认中等
89
+
90
+ model_lower = model.lower()
91
+
92
+ # 特殊模型优先检测(避免被通用关键词误判)
93
+ if "gemini" in model_lower:
94
+ if any(keyword in model_lower for keyword in ["1.5-pro", "pro"]):
95
+ return "opus"
96
+ elif any(keyword in model_lower for keyword in ["2.0", "flash"]):
97
+ return "sonnet" # Gemini 2.0 和 flash 系列归为 sonnet
98
+
99
+ # 等级关键词检测(优先级从高到低)
100
+ # Opus 等级 - 最强模型
101
+ if any(keyword in model_lower for keyword in ["opus", "o1", "max", "ultra", "premium"]):
102
+ return "opus"
103
+
104
+ # Haiku 等级 - 快速模型(需要排除 sonnet 中的 3.5)
105
+ if any(keyword in model_lower for keyword in ["haiku", "mini", "light", "fast", "turbo"]):
106
+ return "haiku"
107
+ # 特殊处理:gpt-3.5 系列属于 haiku
108
+ if "3.5" in model_lower and "sonnet" not in model_lower:
109
+ return "haiku"
110
+
111
+ # Sonnet 等级 - 平衡模型
112
+ if any(keyword in model_lower for keyword in ["sonnet", "4o", "4", "standard", "base"]):
113
+ return "sonnet"
114
+
115
+ return "sonnet" # 默认中等
116
+
117
+
118
+ def map_model_name(model: str, available_models: set = None) -> str:
119
+ """将外部模型名称映射到 Kiro 支持的名称(支持动态模型选择)"""
120
+ if not model:
121
+ return "auto"
122
+
123
+ # 1. 精确匹配优先
124
+ if model in MODEL_MAPPING:
125
+ return MODEL_MAPPING[model]
126
+ if model in KIRO_MODELS:
127
+ return model
128
+
129
+ # 2. 智能等级检测 + 动态选择
130
+ tier = detect_model_tier(model)
131
+ best_model = get_best_model_by_tier(tier, available_models)
132
+
133
+ return best_model
KiroProxy/kiro_proxy/converters/__init__.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """协议转换模块 - Anthropic/OpenAI/Gemini <-> Kiro
2
+
3
+ 增强版:参考 proxycast 实现
4
+ - 工具数量限制(最多 50 个)
5
+ - 工具描述截断(最多 500 字符)
6
+ - 历史消息交替修复
7
+ - OpenAI tool 角色消息处理
8
+ - tool_choice: required 支持
9
+ - web_search 特殊工具支持
10
+ - tool_results 去重
11
+ """
12
+ import json
13
+ import hashlib
14
+ import re
15
+ from typing import List, Dict, Any, Tuple, Optional
16
+
17
+ # 常量
18
+ MAX_TOOLS = 50
19
+ MAX_TOOL_DESCRIPTION_LENGTH = 500
20
+
21
+
22
+ def generate_session_id(messages: list) -> str:
23
+ """基于消息内容生成会话ID"""
24
+ content = json.dumps(messages[:3], sort_keys=True)
25
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
26
+
27
+
28
+ def extract_images_from_content(content) -> Tuple[str, List[dict]]:
29
+ """从消息内容中提取文本和图片
30
+
31
+ Returns:
32
+ (text_content, images_list)
33
+ """
34
+ if isinstance(content, str):
35
+ return content, []
36
+
37
+ if not isinstance(content, list):
38
+ return str(content) if content else "", []
39
+
40
+ text_parts = []
41
+ images = []
42
+
43
+ for block in content:
44
+ if isinstance(block, str):
45
+ text_parts.append(block)
46
+ elif isinstance(block, dict):
47
+ block_type = block.get("type", "")
48
+
49
+ if block_type == "text":
50
+ text_parts.append(block.get("text", ""))
51
+
52
+ elif block_type == "image":
53
+ # Anthropic 格式
54
+ source = block.get("source", {})
55
+ media_type = source.get("media_type", "image/jpeg")
56
+ data = source.get("data", "")
57
+
58
+ fmt = "jpeg"
59
+ if "png" in media_type:
60
+ fmt = "png"
61
+ elif "gif" in media_type:
62
+ fmt = "gif"
63
+ elif "webp" in media_type:
64
+ fmt = "webp"
65
+
66
+ if data:
67
+ images.append({
68
+ "format": fmt,
69
+ "source": {"bytes": data}
70
+ })
71
+
72
+ elif block_type == "image_url":
73
+ # OpenAI 格式
74
+ image_url = block.get("image_url", {})
75
+ url = image_url.get("url", "")
76
+
77
+ if url.startswith("data:"):
78
+ match = re.match(r'data:image/(\w+);base64,(.+)', url)
79
+ if match:
80
+ fmt = match.group(1)
81
+ data = match.group(2)
82
+ images.append({
83
+ "format": fmt,
84
+ "source": {"bytes": data}
85
+ })
86
+
87
+ return "\n".join(text_parts), images
88
+
89
+
90
+ def truncate_description(desc: str, max_length: int = MAX_TOOL_DESCRIPTION_LENGTH) -> str:
91
+ """截断工具描述"""
92
+ if len(desc) <= max_length:
93
+ return desc
94
+ return desc[:max_length - 3] + "..."
95
+
96
+
97
+ # ==================== Anthropic 转换 ====================
98
+
99
+ def convert_anthropic_tools_to_kiro(tools: List[dict]) -> List[dict]:
100
+ """将 Anthropic 工具格式转换为 Kiro 格式
101
+
102
+ 增强:
103
+ - 限制最多 50 个工具
104
+ - 截断过长的描述
105
+ - 支持 web_search 特殊工具
106
+ """
107
+ kiro_tools = []
108
+ function_count = 0
109
+
110
+ for tool in tools:
111
+ name = tool.get("name", "")
112
+
113
+ # 特殊工具:web_search
114
+ if name in ("web_search", "web_search_20250305"):
115
+ kiro_tools.append({
116
+ "webSearchTool": {
117
+ "type": "web_search"
118
+ }
119
+ })
120
+ continue
121
+
122
+ # 限制工具数量
123
+ if function_count >= MAX_TOOLS:
124
+ continue
125
+ function_count += 1
126
+
127
+ description = tool.get("description", f"Tool: {name}")
128
+ description = truncate_description(description)
129
+
130
+ input_schema = tool.get("input_schema", {"type": "object", "properties": {}})
131
+
132
+ kiro_tools.append({
133
+ "toolSpecification": {
134
+ "name": name,
135
+ "description": description,
136
+ "inputSchema": {
137
+ "json": input_schema
138
+ }
139
+ }
140
+ })
141
+
142
+ return kiro_tools
143
+
144
+
145
+ def fix_history_alternation(history: List[dict], model_id: str = "claude-sonnet-4") -> List[dict]:
146
+ """修复历史记录,确保 user/assistant 严格交替,并验证 toolUses/toolResults 配对
147
+
148
+ Kiro API 规则:
149
+ 1. 消息必须严格交替:user -> assistant -> user -> assistant
150
+ 2. 当 assistant 有 toolUses 时,下一条 user 必须有对应的 toolResults
151
+ 3. 当 assistant 没有 toolUses 时,下一条 user 不能有 toolResults
152
+ """
153
+ if not history:
154
+ return history
155
+
156
+ # 深拷贝以避免修改原始数据
157
+ import copy
158
+ history = copy.deepcopy(history)
159
+
160
+ fixed = []
161
+
162
+ for i, item in enumerate(history):
163
+ is_user = "userInputMessage" in item
164
+ is_assistant = "assistantResponseMessage" in item
165
+
166
+ if is_user:
167
+ # 检查上一条是否也是 user
168
+ if fixed and "userInputMessage" in fixed[-1]:
169
+ # 检查当前消息是否有 tool_results
170
+ user_msg = item["userInputMessage"]
171
+ ctx = user_msg.get("userInputMessageContext", {})
172
+ has_tool_results = bool(ctx.get("toolResults"))
173
+
174
+ if has_tool_results:
175
+ # 合并 tool_results 到上一条 user 消息
176
+ new_results = ctx["toolResults"]
177
+ last_user = fixed[-1]["userInputMessage"]
178
+
179
+ if "userInputMessageContext" not in last_user:
180
+ last_user["userInputMessageContext"] = {}
181
+
182
+ last_ctx = last_user["userInputMessageContext"]
183
+ if "toolResults" in last_ctx and last_ctx["toolResults"]:
184
+ last_ctx["toolResults"].extend(new_results)
185
+ else:
186
+ last_ctx["toolResults"] = new_results
187
+ continue
188
+ else:
189
+ # 插入一个占位 assistant 消息(不带 toolUses)
190
+ fixed.append({
191
+ "assistantResponseMessage": {
192
+ "content": "I understand."
193
+ }
194
+ })
195
+
196
+ # 验证 toolResults 与前一个 assistant 的 toolUses 配对
197
+ if fixed and "assistantResponseMessage" in fixed[-1]:
198
+ last_assistant = fixed[-1]["assistantResponseMessage"]
199
+ has_tool_uses = bool(last_assistant.get("toolUses"))
200
+
201
+ user_msg = item["userInputMessage"]
202
+ ctx = user_msg.get("userInputMessageContext", {})
203
+ has_tool_results = bool(ctx.get("toolResults"))
204
+
205
+ if has_tool_uses and not has_tool_results:
206
+ # assistant 有 toolUses 但 user 没有 toolResults
207
+ # 这是不允许的:不要删除 toolUses(否则会破坏后续上下文/导致 tool_use 轮次丢失)
208
+ # 改为在本条 user 前插入一个“工具结果占位” user 消息,与 toolUses 严格配对。
209
+ placeholder_results = []
210
+ for tu in (last_assistant.get("toolUses") or []):
211
+ tuid = ""
212
+ if isinstance(tu, dict):
213
+ tuid = tu.get("toolUseId") or ""
214
+ if tuid:
215
+ placeholder_results.append({
216
+ "content": [{"text": ""}],
217
+ "status": "success",
218
+ "toolUseId": tuid,
219
+ })
220
+ fixed.append({
221
+ "userInputMessage": {
222
+ "content": "Tool results provided.",
223
+ "modelId": model_id,
224
+ "origin": "AI_EDITOR",
225
+ "userInputMessageContext": {
226
+ "toolResults": placeholder_results
227
+ }
228
+ }
229
+ })
230
+ elif not has_tool_uses and has_tool_results:
231
+ # assistant 没有 toolUses 但 user 有 toolResults
232
+ # 这是不允许的,需要清除 user 的 toolResults
233
+ item["userInputMessage"].pop("userInputMessageContext", None)
234
+
235
+ fixed.append(item)
236
+
237
+ elif is_assistant:
238
+ # 检查上一条是否也是 assistant
239
+ if fixed and "assistantResponseMessage" in fixed[-1]:
240
+ # 插入一个占位 user 消息(不带 toolResults)
241
+ fixed.append({
242
+ "userInputMessage": {
243
+ "content": "Continue",
244
+ "modelId": model_id,
245
+ "origin": "AI_EDITOR"
246
+ }
247
+ })
248
+
249
+ # 如果历史为空,先插入一个 user 消息
250
+ if not fixed:
251
+ fixed.append({
252
+ "userInputMessage": {
253
+ "content": "Continue",
254
+ "modelId": model_id,
255
+ "origin": "AI_EDITOR"
256
+ }
257
+ })
258
+
259
+ fixed.append(item)
260
+
261
+ # 确保以 assistant 结尾(如果最后是 user,添加占位 assistant)
262
+ if fixed and "userInputMessage" in fixed[-1]:
263
+ # 不需要清除 toolResults,因为它是与前一个 assistant 的 toolUses 配对的
264
+ # 占位 assistant 只是为了满足交替规则
265
+ fixed.append({
266
+ "assistantResponseMessage": {
267
+ "content": "I understand."
268
+ }
269
+ })
270
+
271
+ return fixed
272
+
273
+
274
+ def convert_anthropic_messages_to_kiro(messages: List[dict], system="") -> Tuple[str, List[dict], List[dict]]:
275
+ """将 Anthropic 消息格式转换为 Kiro 格式
276
+
277
+ Returns:
278
+ (user_content, history, tool_results)
279
+ """
280
+ history = []
281
+ user_content = ""
282
+ current_tool_results = []
283
+
284
+ def _strip_thinking(text: str) -> str:
285
+ if text is None:
286
+ return ""
287
+ if not isinstance(text, str):
288
+ text = str(text)
289
+ if not text:
290
+ return ""
291
+ cleaned = text
292
+ while True:
293
+ start = find_real_thinking_start_tag(cleaned)
294
+ if start == -1:
295
+ break
296
+ end = find_real_thinking_end_tag(cleaned, start + len("<thinking>"))
297
+ if end == -1:
298
+ cleaned = cleaned[:start].rstrip()
299
+ break
300
+ before = cleaned[:start].rstrip()
301
+ after = cleaned[end + len("</thinking>"):].lstrip()
302
+ if before and after:
303
+ cleaned = before + "\n" + after
304
+ else:
305
+ cleaned = before or after
306
+ return cleaned.strip()
307
+
308
+ # 处理 system
309
+ system_text = ""
310
+ if isinstance(system, list):
311
+ for block in system:
312
+ if isinstance(block, dict) and block.get("type") == "text":
313
+ system_text += block.get("text", "") + "\n"
314
+ elif isinstance(block, str):
315
+ system_text += block + "\n"
316
+ system_text = system_text.strip()
317
+ elif isinstance(system, str):
318
+ system_text = system
319
+
320
+ system_text = _strip_thinking(system_text)
321
+
322
+ for i, msg in enumerate(messages):
323
+ role = msg.get("role", "")
324
+ content = msg.get("content", "")
325
+ is_last = (i == len(messages) - 1)
326
+
327
+ # 处理 content 列表
328
+ tool_results = []
329
+ text_parts = []
330
+
331
+ if isinstance(content, list):
332
+ for block in content:
333
+ if isinstance(block, dict):
334
+ if block.get("type") == "text":
335
+ text_parts.append(block.get("text", ""))
336
+ elif block.get("type") == "tool_result":
337
+ tr_content = block.get("content", "")
338
+ if isinstance(tr_content, list):
339
+ tr_text_parts = []
340
+ for tc in tr_content:
341
+ if isinstance(tc, dict) and tc.get("type") == "text":
342
+ tr_text_parts.append(tc.get("text", ""))
343
+ elif isinstance(tc, str):
344
+ tr_text_parts.append(tc)
345
+ tr_content = "\n".join(tr_text_parts)
346
+
347
+ # 处理 is_error
348
+ status = "error" if block.get("is_error") else "success"
349
+
350
+ tool_results.append({
351
+ "content": [{"text": str(tr_content)}],
352
+ "status": status,
353
+ "toolUseId": block.get("tool_use_id", "")
354
+ })
355
+ elif isinstance(block, str):
356
+ text_parts.append(block)
357
+
358
+ content = "\n".join(text_parts) if text_parts else ""
359
+
360
+ content = _strip_thinking(content)
361
+
362
+ # 处理工具结果
363
+ if tool_results:
364
+ # 去重
365
+ seen_ids = set()
366
+ unique_results = []
367
+ for tr in tool_results:
368
+ if tr["toolUseId"] not in seen_ids:
369
+ seen_ids.add(tr["toolUseId"])
370
+ unique_results.append(tr)
371
+ tool_results = unique_results
372
+
373
+ if is_last:
374
+ current_tool_results = tool_results
375
+ user_content = content if content else "Tool results provided."
376
+ else:
377
+ history.append({
378
+ "userInputMessage": {
379
+ "content": content if content else "Tool results provided.",
380
+ "modelId": "claude-sonnet-4",
381
+ "origin": "AI_EDITOR",
382
+ "userInputMessageContext": {
383
+ "toolResults": tool_results
384
+ }
385
+ }
386
+ })
387
+ continue
388
+
389
+ if role == "user":
390
+ if system_text and not history:
391
+ content = f"{system_text}\n\n{content}" if content else system_text
392
+
393
+ content = _strip_thinking(content)
394
+
395
+ if is_last:
396
+ user_content = content if content else "Continue"
397
+ else:
398
+ history.append({
399
+ "userInputMessage": {
400
+ "content": content if content else "Continue",
401
+ "modelId": "claude-sonnet-4",
402
+ "origin": "AI_EDITOR"
403
+ }
404
+ })
405
+
406
+ elif role == "assistant":
407
+ tool_uses = []
408
+ assistant_text = ""
409
+
410
+ if isinstance(msg.get("content"), list):
411
+ text_parts = []
412
+ for block in msg["content"]:
413
+ if isinstance(block, dict):
414
+ if block.get("type") == "tool_use":
415
+ tool_uses.append({
416
+ "toolUseId": block.get("id", ""),
417
+ "name": block.get("name", ""),
418
+ "input": block.get("input", {})
419
+ })
420
+ elif block.get("type") == "text":
421
+ text_parts.append(block.get("text", ""))
422
+ assistant_text = "\n".join(text_parts)
423
+ else:
424
+ assistant_text = content if isinstance(content, str) else ""
425
+
426
+ assistant_text = _strip_thinking(assistant_text)
427
+
428
+ if not assistant_text and not tool_uses:
429
+ continue
430
+
431
+ # 确保 assistant 消息有内容
432
+ if not assistant_text:
433
+ assistant_text = "I understand."
434
+
435
+ assistant_msg = {
436
+ "assistantResponseMessage": {
437
+ "content": assistant_text
438
+ }
439
+ }
440
+ # 只有在有 toolUses 时才添加这个字段
441
+ if tool_uses:
442
+ assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses
443
+
444
+ history.append(assistant_msg)
445
+
446
+ # 修复历史交替
447
+ history = fix_history_alternation(history)
448
+
449
+ return user_content, history, current_tool_results
450
+
451
+
452
+ def convert_kiro_response_to_anthropic(result: dict, model: str, msg_id: str) -> dict:
453
+ """将 Kiro 响应转换为 Anthropic 格式"""
454
+ content = []
455
+ text = "".join(result["content"])
456
+ if text:
457
+ content.append({"type": "text", "text": text})
458
+
459
+ for tool_use in result["tool_uses"]:
460
+ content.append(tool_use)
461
+
462
+ return {
463
+ "id": msg_id,
464
+ "type": "message",
465
+ "role": "assistant",
466
+ "content": content,
467
+ "model": model,
468
+ "stop_reason": result["stop_reason"],
469
+ "stop_sequence": None,
470
+ "usage": {"input_tokens": 100, "output_tokens": 100}
471
+ }
472
+
473
+
474
+ # ==================== OpenAI 转换 ====================
475
+
476
+ def is_tool_choice_required(tool_choice) -> bool:
477
+ """检查 tool_choice 是否为 required"""
478
+ if isinstance(tool_choice, dict):
479
+ t = tool_choice.get("type", "")
480
+ return t in ("any", "tool", "required")
481
+ elif isinstance(tool_choice, str):
482
+ return tool_choice in ("required", "any")
483
+ return False
484
+
485
+
486
+ def convert_openai_tools_to_kiro(tools: List[dict]) -> List[dict]:
487
+ """将 OpenAI 工具格式转换为 Kiro 格式"""
488
+ kiro_tools = []
489
+ function_count = 0
490
+
491
+ for tool in tools:
492
+ tool_type = tool.get("type", "function")
493
+
494
+ # 特殊工具
495
+ if tool_type == "web_search":
496
+ kiro_tools.append({
497
+ "webSearchTool": {
498
+ "type": "web_search"
499
+ }
500
+ })
501
+ continue
502
+
503
+ if tool_type != "function":
504
+ continue
505
+
506
+ # 限制工具数量
507
+ if function_count >= MAX_TOOLS:
508
+ continue
509
+ function_count += 1
510
+
511
+ func = tool.get("function", {})
512
+ name = func.get("name", "")
513
+ description = func.get("description", f"Tool: {name}")
514
+ description = truncate_description(description)
515
+ parameters = func.get("parameters", {"type": "object", "properties": {}})
516
+
517
+ kiro_tools.append({
518
+ "toolSpecification": {
519
+ "name": name,
520
+ "description": description,
521
+ "inputSchema": {
522
+ "json": parameters
523
+ }
524
+ }
525
+ })
526
+
527
+ return kiro_tools
528
+
529
+
530
+ def convert_openai_messages_to_kiro(
531
+ messages: List[dict],
532
+ model: str,
533
+ tools: List[dict] = None,
534
+ tool_choice = None
535
+ ) -> Tuple[str, List[dict], List[dict], List[dict]]:
536
+ """将 OpenAI 消息格式转换为 Kiro 格式
537
+
538
+ 增强:
539
+ - 支持 tool 角色消息
540
+ - 支持 assistant 的 tool_calls
541
+ - 支持 tool_choice: required
542
+ - 历史交替修复
543
+
544
+ Returns:
545
+ (user_content, history, tool_results, kiro_tools)
546
+ """
547
+ system_content = ""
548
+ history = []
549
+ user_content = ""
550
+ current_tool_results = []
551
+ pending_tool_results = [] # 待处理的 tool 消息
552
+
553
+ # 处理 tool_choice: required
554
+ tool_instruction = ""
555
+ if is_tool_choice_required(tool_choice) and tools:
556
+ tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text. Call a tool function immediately."
557
+
558
+ for i, msg in enumerate(messages):
559
+ role = msg.get("role", "")
560
+ content = msg.get("content", "")
561
+ is_last = (i == len(messages) - 1)
562
+
563
+ # 提取文本内容
564
+ if isinstance(content, list):
565
+ content = " ".join([c.get("text", "") for c in content if c.get("type") == "text"])
566
+ if not content:
567
+ content = ""
568
+
569
+ if role == "system":
570
+ system_content = content + tool_instruction
571
+
572
+ elif role == "tool":
573
+ # OpenAI tool 角色消息 -> Kiro toolResults
574
+ tool_call_id = msg.get("tool_call_id", "")
575
+ pending_tool_results.append({
576
+ "content": [{"text": str(content)}],
577
+ "status": "success",
578
+ "toolUseId": tool_call_id
579
+ })
580
+
581
+ elif role == "user":
582
+ # 如果有待处理的 tool results,先处理
583
+ if pending_tool_results:
584
+ # 去重
585
+ seen_ids = set()
586
+ unique_results = []
587
+ for tr in pending_tool_results:
588
+ if tr["toolUseId"] not in seen_ids:
589
+ seen_ids.add(tr["toolUseId"])
590
+ unique_results.append(tr)
591
+
592
+ if is_last:
593
+ current_tool_results = unique_results
594
+ else:
595
+ history.append({
596
+ "userInputMessage": {
597
+ "content": "Tool results provided.",
598
+ "modelId": model,
599
+ "origin": "AI_EDITOR",
600
+ "userInputMessageContext": {
601
+ "toolResults": unique_results
602
+ }
603
+ }
604
+ })
605
+ pending_tool_results = []
606
+
607
+ # 合并 system prompt
608
+ if system_content and not history:
609
+ content = f"{system_content}\n\n{content}"
610
+
611
+ if is_last:
612
+ user_content = content
613
+ else:
614
+ history.append({
615
+ "userInputMessage": {
616
+ "content": content,
617
+ "modelId": model,
618
+ "origin": "AI_EDITOR"
619
+ }
620
+ })
621
+
622
+ elif role == "assistant":
623
+ # 如果有待处理的 tool results,先创建 user 消息
624
+ if pending_tool_results:
625
+ seen_ids = set()
626
+ unique_results = []
627
+ for tr in pending_tool_results:
628
+ if tr["toolUseId"] not in seen_ids:
629
+ seen_ids.add(tr["toolUseId"])
630
+ unique_results.append(tr)
631
+
632
+ history.append({
633
+ "userInputMessage": {
634
+ "content": "Tool results provided.",
635
+ "modelId": model,
636
+ "origin": "AI_EDITOR",
637
+ "userInputMessageContext": {
638
+ "toolResults": unique_results
639
+ }
640
+ }
641
+ })
642
+ pending_tool_results = []
643
+
644
+ # 处理 tool_calls
645
+ tool_uses = []
646
+ tool_calls = msg.get("tool_calls", [])
647
+ for tc in tool_calls:
648
+ func = tc.get("function", {})
649
+ args_str = func.get("arguments", "{}")
650
+ try:
651
+ args = json.loads(args_str)
652
+ except:
653
+ args = {}
654
+
655
+ tool_uses.append({
656
+ "toolUseId": tc.get("id", ""),
657
+ "name": func.get("name", ""),
658
+ "input": args
659
+ })
660
+
661
+ assistant_text = content if content else "I understand."
662
+
663
+ assistant_msg = {
664
+ "assistantResponseMessage": {
665
+ "content": assistant_text
666
+ }
667
+ }
668
+ # 只有在有 toolUses 时才添加这个字段
669
+ if tool_uses:
670
+ assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses
671
+
672
+ history.append(assistant_msg)
673
+
674
+ # 处理末尾的 tool results
675
+ if pending_tool_results:
676
+ seen_ids = set()
677
+ unique_results = []
678
+ for tr in pending_tool_results:
679
+ if tr["toolUseId"] not in seen_ids:
680
+ seen_ids.add(tr["toolUseId"])
681
+ unique_results.append(tr)
682
+ current_tool_results = unique_results
683
+ if not user_content:
684
+ user_content = "Tool results provided."
685
+
686
+ # 如果没有用户消息
687
+ if not user_content:
688
+ user_content = messages[-1].get("content", "") if messages else "Continue"
689
+ if isinstance(user_content, list):
690
+ user_content = " ".join([c.get("text", "") for c in user_content if c.get("type") == "text"])
691
+ if not user_content:
692
+ user_content = "Continue"
693
+
694
+ # 历史不包含最后一条用户消息
695
+ if history and "userInputMessage" in history[-1]:
696
+ history = history[:-1]
697
+
698
+ # 修复历史交替
699
+ history = fix_history_alternation(history, model)
700
+
701
+ # 转换工具
702
+ kiro_tools = convert_openai_tools_to_kiro(tools) if tools else []
703
+
704
+ return user_content, history, current_tool_results, kiro_tools
705
+
706
+
707
+ def convert_kiro_response_to_openai(result: dict, model: str, msg_id: str) -> dict:
708
+ """将 Kiro 响应转换为 OpenAI 格式"""
709
+ text = "".join(result["content"])
710
+ tool_calls = []
711
+
712
+ for tool_use in result.get("tool_uses", []):
713
+ if tool_use.get("type") == "tool_use":
714
+ tool_calls.append({
715
+ "id": tool_use.get("id", ""),
716
+ "type": "function",
717
+ "function": {
718
+ "name": tool_use.get("name", ""),
719
+ "arguments": json.dumps(tool_use.get("input", {}))
720
+ }
721
+ })
722
+
723
+ # 映射 stop_reason
724
+ stop_reason = result.get("stop_reason", "stop")
725
+ finish_reason = "tool_calls" if tool_calls else "stop"
726
+ if stop_reason == "max_tokens":
727
+ finish_reason = "length"
728
+
729
+ message = {
730
+ "role": "assistant",
731
+ "content": text if text else None
732
+ }
733
+ if tool_calls:
734
+ message["tool_calls"] = tool_calls
735
+
736
+ return {
737
+ "id": msg_id,
738
+ "object": "chat.completion",
739
+ "model": model,
740
+ "choices": [{
741
+ "index": 0,
742
+ "message": message,
743
+ "finish_reason": finish_reason
744
+ }],
745
+ "usage": {
746
+ "prompt_tokens": 100,
747
+ "completion_tokens": 100,
748
+ "total_tokens": 200
749
+ }
750
+ }
751
+
752
+
753
+ # ==================== Gemini 转换 ====================
754
+
755
+ def convert_gemini_tools_to_kiro(tools: List[dict]) -> List[dict]:
756
+ """将 Gemini 工具格式转换为 Kiro 格式
757
+
758
+ Gemini 工具格式:
759
+ {
760
+ "functionDeclarations": [
761
+ {
762
+ "name": "get_weather",
763
+ "description": "Get weather info",
764
+ "parameters": {...}
765
+ }
766
+ ]
767
+ }
768
+ """
769
+ kiro_tools = []
770
+ function_count = 0
771
+
772
+ for tool in tools:
773
+ # Gemini 的工具定义在 functionDeclarations 中
774
+ declarations = tool.get("functionDeclarations", [])
775
+
776
+ for func in declarations:
777
+ # 限制工具数量
778
+ if function_count >= MAX_TOOLS:
779
+ break
780
+ function_count += 1
781
+
782
+ name = func.get("name", "")
783
+ description = func.get("description", f"Tool: {name}")
784
+ description = truncate_description(description)
785
+ parameters = func.get("parameters", {"type": "object", "properties": {}})
786
+
787
+ kiro_tools.append({
788
+ "toolSpecification": {
789
+ "name": name,
790
+ "description": description,
791
+ "inputSchema": {
792
+ "json": parameters
793
+ }
794
+ }
795
+ })
796
+
797
+ return kiro_tools
798
+
799
+
800
+ def convert_gemini_contents_to_kiro(
801
+ contents: List[dict],
802
+ system_instruction: dict,
803
+ model: str,
804
+ tools: List[dict] = None,
805
+ tool_config: dict = None
806
+ ) -> Tuple[str, List[dict], List[dict], List[dict]]:
807
+ """将 Gemini 消息格式转换为 Kiro 格式
808
+
809
+ 增强:
810
+ - 支持 functionCall 和 functionResponse
811
+ - 支持 tool_config
812
+
813
+ Returns:
814
+ (user_content, history, tool_results, kiro_tools)
815
+ """
816
+ history = []
817
+ user_content = ""
818
+ current_tool_results = []
819
+ pending_tool_results = []
820
+
821
+ # 处理 system instruction
822
+ system_text = ""
823
+ if system_instruction:
824
+ parts = system_instruction.get("parts", [])
825
+ system_text = " ".join(p.get("text", "") for p in parts if "text" in p)
826
+
827
+ # 处理 tool_config(类似 tool_choice)
828
+ tool_instruction = ""
829
+ if tool_config:
830
+ mode = tool_config.get("functionCallingConfig", {}).get("mode", "")
831
+ if mode in ("ANY", "REQUIRED"):
832
+ tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text."
833
+
834
+ for i, content in enumerate(contents):
835
+ role = content.get("role", "user")
836
+ parts = content.get("parts", [])
837
+ is_last = (i == len(contents) - 1)
838
+
839
+ # 提取文本和工具调用
840
+ text_parts = []
841
+ tool_calls = []
842
+ tool_responses = []
843
+
844
+ for part in parts:
845
+ if "text" in part:
846
+ text_parts.append(part["text"])
847
+ elif "functionCall" in part:
848
+ # Gemini 的工具调用
849
+ fc = part["functionCall"]
850
+ tool_calls.append({
851
+ "toolUseId": fc.get("name", "") + "_" + str(i), # Gemini 没有 ID,生成一个
852
+ "name": fc.get("name", ""),
853
+ "input": fc.get("args", {})
854
+ })
855
+ elif "functionResponse" in part:
856
+ # Gemini 的工具响应
857
+ fr = part["functionResponse"]
858
+ response_content = fr.get("response", {})
859
+ if isinstance(response_content, dict):
860
+ response_text = json.dumps(response_content)
861
+ else:
862
+ response_text = str(response_content)
863
+
864
+ tool_responses.append({
865
+ "content": [{"text": response_text}],
866
+ "status": "success",
867
+ "toolUseId": fr.get("name", "") + "_" + str(i - 1) # 匹配上一个调用
868
+ })
869
+
870
+ text = " ".join(text_parts)
871
+
872
+ if role == "user":
873
+ # 处理待处理的 tool responses
874
+ if pending_tool_results:
875
+ seen_ids = set()
876
+ unique_results = []
877
+ for tr in pending_tool_results:
878
+ if tr["toolUseId"] not in seen_ids:
879
+ seen_ids.add(tr["toolUseId"])
880
+ unique_results.append(tr)
881
+
882
+ history.append({
883
+ "userInputMessage": {
884
+ "content": "Tool results provided.",
885
+ "modelId": model,
886
+ "origin": "AI_EDITOR",
887
+ "userInputMessageContext": {
888
+ "toolResults": unique_results
889
+ }
890
+ }
891
+ })
892
+ pending_tool_results = []
893
+
894
+ # 处理 functionResponse(用户消息中的工具响应)
895
+ if tool_responses:
896
+ pending_tool_results.extend(tool_responses)
897
+
898
+ # 合并 system prompt
899
+ if system_text and not history:
900
+ text = f"{system_text}{tool_instruction}\n\n{text}"
901
+
902
+ if is_last:
903
+ user_content = text
904
+ if pending_tool_results:
905
+ current_tool_results = pending_tool_results
906
+ pending_tool_results = []
907
+ else:
908
+ if text:
909
+ history.append({
910
+ "userInputMessage": {
911
+ "content": text,
912
+ "modelId": model,
913
+ "origin": "AI_EDITOR"
914
+ }
915
+ })
916
+
917
+ elif role == "model":
918
+ # 处理待处理的 tool responses
919
+ if pending_tool_results:
920
+ seen_ids = set()
921
+ unique_results = []
922
+ for tr in pending_tool_results:
923
+ if tr["toolUseId"] not in seen_ids:
924
+ seen_ids.add(tr["toolUseId"])
925
+ unique_results.append(tr)
926
+
927
+ history.append({
928
+ "userInputMessage": {
929
+ "content": "Tool results provided.",
930
+ "modelId": model,
931
+ "origin": "AI_EDITOR",
932
+ "userInputMessageContext": {
933
+ "toolResults": unique_results
934
+ }
935
+ }
936
+ })
937
+ pending_tool_results = []
938
+
939
+ assistant_text = text if text else "I understand."
940
+
941
+ assistant_msg = {
942
+ "assistantResponseMessage": {
943
+ "content": assistant_text
944
+ }
945
+ }
946
+ # 只有在有 toolUses 时才添加这个字段
947
+ if tool_calls:
948
+ assistant_msg["assistantResponseMessage"]["toolUses"] = tool_calls
949
+
950
+ history.append(assistant_msg)
951
+
952
+ # 处理末尾的 tool results
953
+ if pending_tool_results:
954
+ current_tool_results = pending_tool_results
955
+ if not user_content:
956
+ user_content = "Tool results provided."
957
+
958
+ # 如果没有用户消息
959
+ if not user_content:
960
+ if contents:
961
+ last_parts = contents[-1].get("parts", [])
962
+ user_content = " ".join(p.get("text", "") for p in last_parts if "text" in p)
963
+ if not user_content:
964
+ user_content = "Continue"
965
+
966
+ # 修复历史交替
967
+ history = fix_history_alternation(history, model)
968
+
969
+ # 移除最后一条(当前用户消息)
970
+ if history and "userInputMessage" in history[-1]:
971
+ history = history[:-1]
972
+
973
+ # 转换工具
974
+ kiro_tools = convert_gemini_tools_to_kiro(tools) if tools else []
975
+
976
+ return user_content, history, current_tool_results, kiro_tools
977
+
978
+
979
+ def convert_kiro_response_to_gemini(result: dict, model: str) -> dict:
980
+ """将 Kiro 响应转换为 Gemini 格式"""
981
+ text = "".join(result.get("content", []))
982
+ tool_uses = result.get("tool_uses", [])
983
+
984
+ parts = []
985
+
986
+ # 添加文本部分
987
+ if text:
988
+ parts.append({"text": text})
989
+
990
+ # 添加工具调用
991
+ for tool_use in tool_uses:
992
+ if tool_use.get("type") == "tool_use":
993
+ parts.append({
994
+ "functionCall": {
995
+ "name": tool_use.get("name", ""),
996
+ "args": tool_use.get("input", {})
997
+ }
998
+ })
999
+
1000
+ # 映射 stop_reason
1001
+ stop_reason = result.get("stop_reason", "STOP")
1002
+ finish_reason = "STOP"
1003
+ if tool_uses:
1004
+ finish_reason = "TOOL_CALLS"
1005
+ elif stop_reason == "max_tokens":
1006
+ finish_reason = "MAX_TOKENS"
1007
+
1008
+ return {
1009
+ "candidates": [{
1010
+ "content": {
1011
+ "parts": parts,
1012
+ "role": "model"
1013
+ },
1014
+ "finishReason": finish_reason,
1015
+ "index": 0
1016
+ }],
1017
+ "usageMetadata": {
1018
+ "promptTokenCount": 100,
1019
+ "candidatesTokenCount": 100,
1020
+ "totalTokenCount": 200
1021
+ }
1022
+ }
1023
+
1024
+
1025
+ # ==================== 思考功能支持 ====================
1026
+
1027
+ def generate_thinking_prefix(thinking_type: str = "enabled", budget_tokens: int = 20000) -> str:
1028
+ """生成思考模式的前缀 XML 标签
1029
+
1030
+ Args:
1031
+ thinking_type: 思考类型,通常为 "enabled"
1032
+ budget_tokens: 思考的 token 预算
1033
+
1034
+ Returns:
1035
+ XML 格式的思考标签字符串
1036
+ """
1037
+ if thinking_type != "enabled":
1038
+ return ""
1039
+
1040
+ return f"<thinking_mode>enabled</thinking_mode>\n<max_thinking_length>{budget_tokens}</max_thinking_length>"
1041
+
1042
+
1043
+ def has_thinking_tags(text: str) -> bool:
1044
+ """检查文本是否已包含思考标签
1045
+
1046
+ Args:
1047
+ text: 要检查的文本
1048
+
1049
+ Returns:
1050
+ 如果包含思考标签返回 True
1051
+ """
1052
+ return "<thinking_mode>" in text and "</thinking_mode>" in text
1053
+
1054
+
1055
+ def inject_thinking_tags_to_system(system, thinking_type: str = "enabled", budget_tokens: int = 20000):
1056
+ """将思考标签注入到系统消息中
1057
+
1058
+ Args:
1059
+ system: 原始系统消息 (可以是字符串或列表)
1060
+ thinking_type: 思考类型
1061
+ budget_tokens: 思考的 token 预算
1062
+
1063
+ Returns:
1064
+ 注入思考标签后的系统消息 (保持原始类型)
1065
+ """
1066
+ # 生成思考前缀
1067
+ thinking_prefix = generate_thinking_prefix(thinking_type, budget_tokens)
1068
+
1069
+ if not thinking_prefix:
1070
+ return system
1071
+
1072
+ # 处理 system 为列表的情况 (Anthropic API 支持 system 为 content blocks 列表)
1073
+ if isinstance(system, list):
1074
+ # 将列表转换为字符串
1075
+ system_text = ""
1076
+ for block in system:
1077
+ if isinstance(block, dict) and block.get("type") == "text":
1078
+ system_text += block.get("text", "") + "\n"
1079
+ elif isinstance(block, str):
1080
+ system_text += block + "\n"
1081
+ system_text = system_text.strip()
1082
+
1083
+ if not system_text:
1084
+ return thinking_prefix
1085
+
1086
+ if has_thinking_tags(system_text):
1087
+ return system
1088
+
1089
+ # 返回字符串形式
1090
+ return f"{thinking_prefix}\n\n{system_text}"
1091
+
1092
+ # 处理 system 为字符串的情况
1093
+ if not system or not str(system).strip():
1094
+ return thinking_prefix
1095
+
1096
+ # 如果已经包含思考标签,不再重复注入
1097
+ if has_thinking_tags(str(system)):
1098
+ return system
1099
+
1100
+ # 将思考标签插入到系统消息开头
1101
+ return f"{thinking_prefix}\n\n{system}"
1102
+
1103
+
1104
+ def find_real_thinking_start_tag(text: str, pos: int = 0) -> int:
1105
+ """查找真正的 <thinking> 标签位置,忽略被引号包围的情况
1106
+
1107
+ Args:
1108
+ text: 要搜索的文本
1109
+ pos: 开始搜索的位置
1110
+
1111
+ Returns:
1112
+ 找到的标签位置,如果没找到返回 -1
1113
+ """
1114
+ while True:
1115
+ idx = text.find("<thinking>", pos)
1116
+ if idx == -1:
1117
+ return -1
1118
+
1119
+ # 检查是否被引号包围
1120
+ # 向前查找最近的引号
1121
+ prev_quote = max(
1122
+ text.rfind("`", 0, idx),
1123
+ text.rfind("'", 0, idx),
1124
+ text.rfind('"', 0, idx)
1125
+ )
1126
+
1127
+ # 如果有引号且引号后没有换行,说明是被包围的
1128
+ if prev_quote != -1:
1129
+ # 检查引号到标签之间是否有换行
1130
+ between = text[prev_quote + 1:idx]
1131
+ if "\n" not in between:
1132
+ pos = idx + len("<thinking>")
1133
+ continue
1134
+
1135
+ return idx
1136
+
1137
+
1138
+ def find_real_thinking_end_tag(text: str, pos: int = 0) -> int:
1139
+ """查找真正的 </thinking> 标签位置,忽略被引号包围的情况
1140
+
1141
+ Args:
1142
+ text: 要搜索的文本
1143
+ pos: 开始搜索的位置
1144
+
1145
+ Returns:
1146
+ 找到的标签位置,如果没找到返回 -1
1147
+ """
1148
+ while True:
1149
+ idx = text.find("</thinking>", pos)
1150
+ if idx == -1:
1151
+ return -1
1152
+
1153
+ # 检查是否被引号包围
1154
+ # 向前查找最近的引号
1155
+ prev_quote = max(
1156
+ text.rfind("`", 0, idx),
1157
+ text.rfind("'", 0, idx),
1158
+ text.rfind('"', 0, idx)
1159
+ )
1160
+
1161
+ # 如果有引号且引号后没有换行,说明是被包围的
1162
+ if prev_quote != -1:
1163
+ # 检查引号到标签之间是否有换行
1164
+ between = text[prev_quote + 1:idx]
1165
+ if "\n" not in between:
1166
+ pos = idx + len("</thinking>")
1167
+ continue
1168
+
1169
+ return idx
1170
+
1171
+
1172
+ def extract_thinking_from_content(content: str) -> Tuple[str, str]:
1173
+ """从内容中提取思考部分和正文部分
1174
+
1175
+ Args:
1176
+ content: 原始内容
1177
+
1178
+ Returns:
1179
+ (thinking_content, text_content)
1180
+ """
1181
+ thinking_start = find_real_thinking_start_tag(content)
1182
+ thinking_end = find_real_thinking_end_tag(content)
1183
+
1184
+ if thinking_start == -1 or thinking_end == -1:
1185
+ return "", content
1186
+
1187
+ # 提取思考内容(去掉标签)
1188
+ thinking_content = content[thinking_start + len("<thinking>"):thinking_end].strip()
1189
+
1190
+ # 提取正文内容(去掉思考部分)
1191
+ text_content = content[:thinking_start].strip()
1192
+ after_thinking = content[thinking_end + len("</thinking>"):].strip()
1193
+ if after_thinking:
1194
+ text_content += "\n" + after_thinking
1195
+
1196
+ return thinking_content, text_content
KiroProxy/kiro_proxy/core/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """核心模块"""
2
+ from .state import state, ProxyState, RequestLog
3
+ from .account import Account
4
+ from .persistence import load_config, save_config, CONFIG_FILE
5
+ from .retry import RetryableRequest, is_retryable_error, RETRYABLE_STATUS_CODES
6
+ from .scheduler import scheduler
7
+ from .stats import stats_manager
8
+ from .browser import detect_browsers, open_url, get_browsers_info
9
+ from .flow_monitor import flow_monitor, FlowMonitor, LLMFlow, FlowState, TokenUsage
10
+ from .usage import get_usage_limits, get_account_usage, UsageInfo
11
+ from .history_manager import (
12
+ HistoryManager, HistoryConfig, TruncateStrategy,
13
+ get_history_config, set_history_config, update_history_config,
14
+ is_content_length_error
15
+ )
16
+ from .error_handler import (
17
+ ErrorType, KiroError, classify_error, is_account_suspended,
18
+ get_anthropic_error_response, format_error_log
19
+ )
20
+ from .rate_limiter import RateLimiter, RateLimitConfig, rate_limiter, get_rate_limiter
21
+
22
+ # 新增模块
23
+ from .quota_cache import QuotaCache, CachedQuota, get_quota_cache
24
+ from .account_selector import AccountSelector, SelectionStrategy, get_account_selector
25
+ from .quota_scheduler import QuotaScheduler, get_quota_scheduler
26
+ from .refresh_manager import (
27
+ RefreshManager, RefreshProgress, RefreshConfig,
28
+ get_refresh_manager, reset_refresh_manager
29
+ )
30
+ from .kiro_api import kiro_api_request, get_user_info, get_user_email
31
+
32
+ __all__ = [
33
+ "state", "ProxyState", "RequestLog", "Account",
34
+ "load_config", "save_config", "CONFIG_FILE",
35
+ "RetryableRequest", "is_retryable_error", "RETRYABLE_STATUS_CODES",
36
+ "scheduler", "stats_manager",
37
+ "detect_browsers", "open_url", "get_browsers_info",
38
+ "flow_monitor", "FlowMonitor", "LLMFlow", "FlowState", "TokenUsage",
39
+ "get_usage_limits", "get_account_usage", "UsageInfo",
40
+ "HistoryManager", "HistoryConfig", "TruncateStrategy",
41
+ "get_history_config", "set_history_config", "update_history_config",
42
+ "is_content_length_error",
43
+ "ErrorType", "KiroError", "classify_error", "is_account_suspended",
44
+ "get_anthropic_error_response", "format_error_log",
45
+ "RateLimiter", "RateLimitConfig", "rate_limiter", "get_rate_limiter",
46
+ # 新增导出
47
+ "QuotaCache", "CachedQuota", "get_quota_cache",
48
+ "AccountSelector", "SelectionStrategy", "get_account_selector",
49
+ "QuotaScheduler", "get_quota_scheduler",
50
+ # RefreshManager 导出
51
+ "RefreshManager", "RefreshProgress", "RefreshConfig",
52
+ "get_refresh_manager", "reset_refresh_manager",
53
+ # Kiro API 导出
54
+ "kiro_api_request", "get_user_info", "get_user_email",
55
+ ]
KiroProxy/kiro_proxy/core/account.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """账号管理"""
2
+ import json
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from ..credential import (
9
+ KiroCredentials, TokenRefresher, CredentialStatus,
10
+ generate_machine_id, quota_manager
11
+ )
12
+
13
+
14
+ @dataclass
15
+ class Account:
16
+ """账号信息"""
17
+ id: str
18
+ name: str
19
+ token_path: str
20
+ enabled: bool = True
21
+ # 是否因额度耗尽被自动禁用(用于区分手动禁用,避免被自动启用)
22
+ auto_disabled: bool = False
23
+ request_count: int = 0
24
+ error_count: int = 0
25
+ last_used: Optional[float] = None
26
+ status: CredentialStatus = CredentialStatus.ACTIVE
27
+
28
+ _credentials: Optional[KiroCredentials] = field(default=None, repr=False)
29
+ _machine_id: Optional[str] = field(default=None, repr=False)
30
+
31
+ def is_available(self) -> bool:
32
+ """检查账号是否可用"""
33
+ if not self.enabled:
34
+ return False
35
+ if self.status in (CredentialStatus.DISABLED, CredentialStatus.UNHEALTHY, CredentialStatus.SUSPENDED):
36
+ return False
37
+ if not quota_manager.is_available(self.id):
38
+ return False
39
+
40
+ # 检查额度是否耗尽
41
+ from .quota_cache import get_quota_cache
42
+ quota_cache = get_quota_cache()
43
+ quota = quota_cache.get(self.id)
44
+ if quota and quota.is_exhausted:
45
+ return False
46
+
47
+ return True
48
+
49
+ def is_active(self) -> bool:
50
+ """检查账号是否活跃(最近60秒内使用过)"""
51
+ from .quota_scheduler import get_quota_scheduler
52
+ scheduler = get_quota_scheduler()
53
+ return scheduler.is_active(self.id)
54
+
55
+ def get_priority_order(self) -> Optional[int]:
56
+ """获取优先级顺序(从1开始),非优先账号返回 None"""
57
+ from .account_selector import get_account_selector
58
+ selector = get_account_selector()
59
+ return selector.get_priority_order(self.id)
60
+
61
+ def is_priority(self) -> bool:
62
+ """检查是否为优先账号"""
63
+ return self.get_priority_order() is not None
64
+
65
+ def load_credentials(self) -> Optional[KiroCredentials]:
66
+ """加载凭证信息"""
67
+ try:
68
+ self._credentials = KiroCredentials.from_file(self.token_path)
69
+
70
+ if self._credentials.client_id_hash and not self._credentials.client_id:
71
+ self._merge_client_credentials()
72
+
73
+ return self._credentials
74
+ except Exception as e:
75
+ print(f"[Account] 加载凭证失败 {self.id}: {e}")
76
+ return None
77
+
78
+ def _merge_client_credentials(self):
79
+ """合并 clientIdHash 对应的凭证文件"""
80
+ if not self._credentials or not self._credentials.client_id_hash:
81
+ return
82
+
83
+ cache_dir = Path(self.token_path).parent
84
+ hash_file = cache_dir / f"{self._credentials.client_id_hash}.json"
85
+
86
+ if hash_file.exists():
87
+ try:
88
+ with open(hash_file) as f:
89
+ data = json.load(f)
90
+ if not self._credentials.client_id:
91
+ self._credentials.client_id = data.get("clientId")
92
+ if not self._credentials.client_secret:
93
+ self._credentials.client_secret = data.get("clientSecret")
94
+ except Exception:
95
+ pass
96
+
97
+ def get_credentials(self) -> Optional[KiroCredentials]:
98
+ """获取凭证(带缓存)"""
99
+ if self._credentials is None:
100
+ self.load_credentials()
101
+ return self._credentials
102
+
103
+ def get_token(self) -> str:
104
+ """获取 access_token"""
105
+ creds = self.get_credentials()
106
+ if creds and creds.access_token:
107
+ return creds.access_token
108
+
109
+ try:
110
+ with open(self.token_path) as f:
111
+ return json.load(f).get("accessToken", "")
112
+ except Exception:
113
+ return ""
114
+
115
+ def get_machine_id(self) -> str:
116
+ """获取基于此账号的 Machine ID"""
117
+ if self._machine_id:
118
+ return self._machine_id
119
+
120
+ creds = self.get_credentials()
121
+ if creds:
122
+ self._machine_id = generate_machine_id(creds.profile_arn, creds.client_id)
123
+ else:
124
+ self._machine_id = generate_machine_id()
125
+
126
+ return self._machine_id
127
+
128
+ def is_token_expired(self) -> bool:
129
+ """检查 token 是否过期"""
130
+ creds = self.get_credentials()
131
+ return creds.is_expired() if creds else True
132
+
133
+ def is_token_expiring_soon(self, minutes: int = 10) -> bool:
134
+ """检查 token 是否即将过期"""
135
+ creds = self.get_credentials()
136
+ return creds.is_expiring_soon(minutes) if creds else False
137
+
138
+ async def refresh_token(self) -> tuple:
139
+ """刷新 token"""
140
+ creds = self.get_credentials()
141
+ if not creds:
142
+ return False, "无法加载凭证"
143
+
144
+ refresher = TokenRefresher(creds)
145
+ success, result = await refresher.refresh()
146
+
147
+ if success:
148
+ creds.save_to_file(self.token_path)
149
+ self._credentials = creds
150
+ self.status = CredentialStatus.ACTIVE
151
+ return True, "Token 刷新成功"
152
+ else:
153
+ self.status = CredentialStatus.UNHEALTHY
154
+ return False, result
155
+
156
+ def mark_quota_exceeded(self, reason: str = "Rate limited"):
157
+ """标记配额超限(进入冷却并避免被继续选中)
158
+
159
+ 429 错误自动冷却 5 分钟,无需手动配置
160
+ """
161
+ quota_manager.mark_exceeded(self.id, reason)
162
+ self.status = CredentialStatus.COOLDOWN
163
+ self.error_count += 1
164
+
165
+ def get_status_info(self) -> dict:
166
+ """获取状态信息"""
167
+ cooldown_remaining = quota_manager.get_cooldown_remaining(self.id)
168
+ creds = self.get_credentials()
169
+
170
+ # 获取额度信息
171
+ from .quota_cache import get_quota_cache
172
+ quota_cache = get_quota_cache()
173
+ quota = quota_cache.get(self.id)
174
+
175
+ quota_info = None
176
+ if quota:
177
+ # 计算相对时间
178
+ updated_ago = ""
179
+ if quota.updated_at > 0:
180
+ seconds_ago = time.time() - quota.updated_at
181
+ if seconds_ago < 60:
182
+ updated_ago = f"{int(seconds_ago)}秒前"
183
+ elif seconds_ago < 3600:
184
+ updated_ago = f"{int(seconds_ago / 60)}分钟前"
185
+ else:
186
+ updated_ago = f"{int(seconds_ago / 3600)}小时前"
187
+
188
+ # 格式化重置时间
189
+ reset_date_text = None
190
+ if quota.next_reset_date:
191
+ try:
192
+ # 处理时间戳格式
193
+ if isinstance(quota.next_reset_date, (int, float)):
194
+ from datetime import datetime
195
+ reset_dt = datetime.fromtimestamp(quota.next_reset_date)
196
+ reset_date_text = reset_dt.strftime('%Y-%m-%d')
197
+ else:
198
+ # 处理 ISO 格式
199
+ from datetime import datetime
200
+ reset_dt = datetime.fromisoformat(quota.next_reset_date.replace('Z', '+00:00'))
201
+ reset_date_text = reset_dt.strftime('%Y-%m-%d')
202
+ except:
203
+ reset_date_text = str(quota.next_reset_date)
204
+
205
+ # 格式化免费试用过期时间
206
+ trial_expiry_text = None
207
+ if quota.free_trial_expiry:
208
+ try:
209
+ # 处理时间戳格式
210
+ if isinstance(quota.free_trial_expiry, (int, float)):
211
+ from datetime import datetime
212
+ expiry_dt = datetime.fromtimestamp(quota.free_trial_expiry)
213
+ trial_expiry_text = expiry_dt.strftime('%Y-%m-%d')
214
+ else:
215
+ # 处理 ISO 格式
216
+ from datetime import datetime
217
+ expiry_dt = datetime.fromisoformat(quota.free_trial_expiry.replace('Z', '+00:00'))
218
+ trial_expiry_text = expiry_dt.strftime('%Y-%m-%d')
219
+ except:
220
+ trial_expiry_text = str(quota.free_trial_expiry)
221
+
222
+ # 计算生效奖励数
223
+ active_bonuses = len([e for e in (quota.bonus_expiries or []) if e])
224
+
225
+ quota_info = {
226
+ "balance": quota.balance,
227
+ "usage_limit": quota.usage_limit,
228
+ "current_usage": quota.current_usage,
229
+ "usage_percent": quota.usage_percent,
230
+ "is_low_balance": quota.is_low_balance,
231
+ "is_exhausted": quota.is_exhausted, # 额度是否耗尽
232
+ "is_suspended": getattr(quota, 'is_suspended', False), # 账号是否被封禁
233
+ "balance_status": quota.balance_status, # 额度状态: normal, low, exhausted
234
+ "subscription_title": quota.subscription_title,
235
+ "free_trial_limit": quota.free_trial_limit,
236
+ "free_trial_usage": quota.free_trial_usage,
237
+ "bonus_limit": quota.bonus_limit,
238
+ "bonus_usage": quota.bonus_usage,
239
+ "updated_at": updated_ago,
240
+ "updated_timestamp": quota.updated_at,
241
+ "error": quota.error,
242
+ # 新增重置时间字段
243
+ "next_reset_date": quota.next_reset_date,
244
+ "reset_date_text": reset_date_text, # 格式化后的重置日期
245
+ "free_trial_expiry": quota.free_trial_expiry,
246
+ "trial_expiry_text": trial_expiry_text, # 格式化后的试用过期日期
247
+ "bonus_expiries": quota.bonus_expiries or [],
248
+ "active_bonuses": active_bonuses, # 生效奖励数量
249
+ }
250
+
251
+ # 计算最后使用时间
252
+ last_used_ago = None
253
+ if self.last_used:
254
+ seconds_ago = time.time() - self.last_used
255
+ if seconds_ago < 60:
256
+ last_used_ago = f"{int(seconds_ago)}秒前"
257
+ elif seconds_ago < 3600:
258
+ last_used_ago = f"{int(seconds_ago / 60)}分钟前"
259
+ else:
260
+ last_used_ago = f"{int(seconds_ago / 3600)}小时前"
261
+
262
+ return {
263
+ "id": self.id,
264
+ "name": self.name,
265
+ "enabled": self.enabled,
266
+ "status": self.status.value,
267
+ "available": self.is_available(),
268
+ "request_count": self.request_count,
269
+ "error_count": self.error_count,
270
+ "error_rate": f"{(self.error_count / max(1, self.request_count) * 100):.1f}%",
271
+ "cooldown_remaining": cooldown_remaining,
272
+ "token_expired": self.is_token_expired() if creds else None,
273
+ "token_expiring_soon": self.is_token_expiring_soon() if creds else None,
274
+ "token_expires_at": creds.expires_at if creds else None, # Token 过期时间戳
275
+ "auth_method": creds.auth_method if creds else None,
276
+ "has_refresh_token": bool(creds and creds.refresh_token),
277
+ "idc_config_complete": bool(creds and creds.client_id and creds.client_secret) if creds and creds.auth_method == "idc" else None,
278
+ # 新增字段
279
+ "quota": quota_info,
280
+ "is_priority": self.is_priority(),
281
+ "priority_order": self.get_priority_order(),
282
+ "is_active": self.is_active(),
283
+ "last_used": self.last_used,
284
+ "last_used_ago": last_used_ago,
285
+ # Provider 字段 (Google/Github)
286
+ "provider": creds.provider if creds else None,
287
+ }
KiroProxy/kiro_proxy/core/account_selector.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """账号选择器模块
2
+
3
+ 实现基于剩余额度的智能账号选择策略,支持优先账号配置。
4
+ """
5
+ import json
6
+ import random
7
+ import time
8
+ from enum import Enum
9
+ from pathlib import Path
10
+ from typing import Optional, List, Set, TYPE_CHECKING
11
+ from threading import Lock
12
+
13
+ if TYPE_CHECKING:
14
+ from .account import Account
15
+ from .quota_cache import QuotaCache
16
+
17
+
18
+ class SelectionStrategy(Enum):
19
+ """选择策略"""
20
+ LOWEST_BALANCE = "lowest_balance" # 剩余额度最少优先
21
+ ROUND_ROBIN = "round_robin" # 轮询
22
+ LEAST_REQUESTS = "least_requests" # 请求最少优先
23
+ RANDOM = "random" # 随机选择(分散压力)
24
+
25
+
26
+ class AccountSelector:
27
+ """账号选择器
28
+
29
+ 根据配置的策略选择最合适的账号,支持优先账号配置。
30
+ """
31
+
32
+ def __init__(self, quota_cache: 'QuotaCache', priority_file: Optional[str] = None):
33
+ """
34
+ 初始化账号选择器
35
+
36
+ Args:
37
+ quota_cache: 额度缓存实例
38
+ priority_file: 优先账号配置文件路径
39
+ """
40
+ self.quota_cache = quota_cache
41
+ self._priority_accounts: List[str] = []
42
+ # 默认使用随机策略,避免单账号 RPM 过高导致封禁风险
43
+ self._strategy = SelectionStrategy.RANDOM
44
+ self._lock = Lock()
45
+ self._round_robin_index = 0
46
+ self._last_random_account_id: Optional[str] = None
47
+
48
+ # 设置优先账号配置文件路径
49
+ if priority_file:
50
+ self._priority_file = Path(priority_file)
51
+ else:
52
+ from ..config import DATA_DIR
53
+ self._priority_file = DATA_DIR / "priority.json"
54
+
55
+ # 加载优先账号配置
56
+ self._load_priority_config()
57
+
58
+ @property
59
+ def strategy(self) -> SelectionStrategy:
60
+ """获取当前选择策略"""
61
+ return self._strategy
62
+
63
+ @strategy.setter
64
+ def strategy(self, value: SelectionStrategy):
65
+ """设置选择策略"""
66
+ self._strategy = value
67
+ self._save_priority_config()
68
+
69
+ def select(self,
70
+ available_accounts: List['Account'],
71
+ session_id: Optional[str] = None) -> Optional['Account']:
72
+ """选择最合适的账号
73
+
74
+ Args:
75
+ available_accounts: 可用账号列表
76
+ session_id: 会话ID(用于会话粘性,暂未实现)
77
+
78
+ Returns:
79
+ 选中的账号,如果没有可用账号则返回 None
80
+ """
81
+ if not available_accounts:
82
+ return None
83
+
84
+ with self._lock:
85
+ # 1. 首先检查优先账号
86
+ if self._priority_accounts:
87
+ for priority_id in self._priority_accounts:
88
+ for account in available_accounts:
89
+ if account.id == priority_id and account.is_available():
90
+ return account
91
+
92
+ # 2. 根据策略选择
93
+ if self._strategy == SelectionStrategy.LOWEST_BALANCE:
94
+ return self._select_lowest_balance(available_accounts)
95
+ elif self._strategy == SelectionStrategy.ROUND_ROBIN:
96
+ return self._select_round_robin(available_accounts)
97
+ elif self._strategy == SelectionStrategy.LEAST_REQUESTS:
98
+ return self._select_least_requests(available_accounts)
99
+ elif self._strategy == SelectionStrategy.RANDOM:
100
+ return self._select_random(available_accounts)
101
+
102
+ # 默认返回第一个可用账号
103
+ return available_accounts[0] if available_accounts else None
104
+
105
+ def _select_lowest_balance(self, accounts: List['Account']) -> Optional['Account']:
106
+ """选择剩余额度最少的账号"""
107
+ available = [a for a in accounts if a.is_available()]
108
+ if not available:
109
+ return None
110
+
111
+ def get_balance_and_requests(account: 'Account') -> tuple:
112
+ """获取账号的余额和请求数,用于排序"""
113
+ quota = self.quota_cache.get(account.id)
114
+ balance = quota.balance if quota and not quota.has_error() else float('inf')
115
+ return (balance, account.request_count)
116
+
117
+ # 按余额升序,余额相同时按请求数升序
118
+ return min(available, key=get_balance_and_requests)
119
+
120
+ def _select_round_robin(self, accounts: List['Account']) -> Optional['Account']:
121
+ """轮询选择账号"""
122
+ available = [a for a in accounts if a.is_available()]
123
+ if not available:
124
+ return None
125
+
126
+ self._round_robin_index = self._round_robin_index % len(available)
127
+ account = available[self._round_robin_index]
128
+ self._round_robin_index += 1
129
+ return account
130
+
131
+ def _select_least_requests(self, accounts: List['Account']) -> Optional['Account']:
132
+ """选择请求数最少的账号"""
133
+ available = [a for a in accounts if a.is_available()]
134
+ if not available:
135
+ return None
136
+ return min(available, key=lambda a: a.request_count)
137
+
138
+ def _select_random(self, accounts: List['Account']) -> Optional['Account']:
139
+ """随机选择账号(分散请求压力)"""
140
+ available = [a for a in accounts if a.is_available()]
141
+ if not available:
142
+ return None
143
+
144
+ # 尽量避免连续两次命中同一账号(在有多个可用账号时)
145
+ if self._last_random_account_id and len(available) > 1:
146
+ candidates = [a for a in available if a.id != self._last_random_account_id]
147
+ if candidates:
148
+ selected = random.choice(candidates)
149
+ else:
150
+ selected = random.choice(available)
151
+ else:
152
+ selected = random.choice(available)
153
+
154
+ self._last_random_account_id = selected.id
155
+ return selected
156
+
157
+ def set_priority_accounts(self, account_ids: List[str],
158
+ valid_account_ids: Optional[Set[str]] = None) -> tuple:
159
+ """设置优先账号列表(按顺序)
160
+
161
+ Args:
162
+ account_ids: 优先账号ID列表(按顺序)
163
+ valid_account_ids: 有效账号ID集合(用于验证)
164
+
165
+ Returns:
166
+ (success, message)
167
+ """
168
+ with self._lock:
169
+ if not account_ids:
170
+ self._priority_accounts = []
171
+ self._strategy = SelectionStrategy.RANDOM
172
+ self._save_priority_config()
173
+ return True, "已清除优先账号"
174
+
175
+ # 去重(保持顺序)
176
+ unique_ids: List[str] = []
177
+ seen: Set[str] = set()
178
+ for aid in account_ids:
179
+ if aid in seen:
180
+ continue
181
+ seen.add(aid)
182
+ unique_ids.append(aid)
183
+
184
+ # 验证账号是否存在
185
+ if valid_account_ids:
186
+ for aid in unique_ids:
187
+ if aid not in valid_account_ids:
188
+ return False, f"账号不存在: {aid}"
189
+
190
+ self._priority_accounts = unique_ids
191
+ self._save_priority_config()
192
+ if len(unique_ids) == 1:
193
+ return True, f"已设置优先账号: {unique_ids[0]}"
194
+ return True, f"已设置优先账号: {', '.join(unique_ids)}"
195
+
196
+ def set_priority_account(self, account_id: Optional[str],
197
+ valid_account_ids: Optional[Set[str]] = None) -> tuple:
198
+ """设置优先账号(单个)
199
+
200
+ Args:
201
+ account_id: 账号ID,None 表示清除
202
+ valid_account_ids: 有效账号ID集合(用于验证)
203
+
204
+ Returns:
205
+ (success, message)
206
+ """
207
+ if account_id is None:
208
+ return self.set_priority_accounts([], valid_account_ids)
209
+ return self.set_priority_accounts([account_id], valid_account_ids)
210
+
211
+ def add_priority_account(self, account_id: str,
212
+ position: int = -1,
213
+ valid_account_ids: Optional[Set[str]] = None) -> tuple:
214
+ """添加优先账号(可指定插入位置)
215
+
216
+ Args:
217
+ account_id: 账号ID
218
+ position: 插入位置(0-based),-1 表示追加到末尾
219
+ valid_account_ids: 有效账号ID集合(用于验证)
220
+
221
+ Returns:
222
+ (success, message)
223
+ """
224
+ with self._lock:
225
+ if valid_account_ids and account_id not in valid_account_ids:
226
+ return False, f"账号不存在: {account_id}"
227
+
228
+ if account_id in self._priority_accounts:
229
+ self._priority_accounts.remove(account_id)
230
+
231
+ if position is None or position < 0 or position >= len(self._priority_accounts):
232
+ self._priority_accounts.append(account_id)
233
+ else:
234
+ self._priority_accounts.insert(position, account_id)
235
+
236
+ self._save_priority_config()
237
+ return True, f"已添加优先账号: {account_id}"
238
+
239
+ def remove_priority_account(self, account_id: str = None) -> tuple:
240
+ """移除优先账号
241
+
242
+ Args:
243
+ account_id: 账号ID(可选,不传则清除所有)
244
+
245
+ Returns:
246
+ (success, message)
247
+ """
248
+ with self._lock:
249
+ if not self._priority_accounts:
250
+ return False, "没有设置优先账号"
251
+
252
+ if account_id:
253
+ if account_id not in self._priority_accounts:
254
+ return False, f"账号 {account_id} 不是优先账号"
255
+
256
+ self._priority_accounts.remove(account_id)
257
+ if not self._priority_accounts:
258
+ self._strategy = SelectionStrategy.RANDOM
259
+ self._save_priority_config()
260
+ return True, f"已移除优先账号: {account_id}"
261
+
262
+ self._priority_accounts = []
263
+ self._strategy = SelectionStrategy.RANDOM
264
+ self._save_priority_config()
265
+ return True, "已清除优先账号"
266
+
267
+ def reorder_priority(self, account_ids: List[str]) -> tuple:
268
+ """重新排序优先账号列表
269
+
270
+ Args:
271
+ account_ids: 新的优先账号顺序(必须与当前优先账号集合一致)
272
+
273
+ Returns:
274
+ (success, message)
275
+ """
276
+ with self._lock:
277
+ if not self._priority_accounts:
278
+ return False, "没有设置优先账号"
279
+
280
+ if not account_ids:
281
+ return False, "账号列表不能为空"
282
+
283
+ if len(account_ids) != len(self._priority_accounts):
284
+ return False, "账号数量不匹配"
285
+
286
+ if len(set(account_ids)) != len(account_ids):
287
+ return False, "账号列表包含重复项"
288
+
289
+ if set(account_ids) != set(self._priority_accounts):
290
+ return False, "账号列表与当前优先账号不匹配"
291
+
292
+ self._priority_accounts = list(account_ids)
293
+ self._save_priority_config()
294
+ return True, "已更新优先账号顺序"
295
+
296
+ def get_priority_account(self) -> Optional[str]:
297
+ """获取优先账号(单个)"""
298
+ with self._lock:
299
+ return self._priority_accounts[0] if self._priority_accounts else None
300
+
301
+ def get_priority_accounts(self) -> List[str]:
302
+ """获取优先账号列表"""
303
+ with self._lock:
304
+ return list(self._priority_accounts)
305
+
306
+ def is_priority_account(self, account_id: str) -> bool:
307
+ """检查账号是否为优先账号"""
308
+ with self._lock:
309
+ return account_id in self._priority_accounts
310
+
311
+ def get_priority_order(self, account_id: str) -> Optional[int]:
312
+ """获取账号的优先级顺序(从1开始)"""
313
+ with self._lock:
314
+ if account_id in self._priority_accounts:
315
+ return self._priority_accounts.index(account_id) + 1
316
+ return None
317
+
318
+ def _load_priority_config(self) -> bool:
319
+ """从文件加载优先账号配置"""
320
+ if not self._priority_file.exists():
321
+ return False
322
+
323
+ try:
324
+ with open(self._priority_file, 'r', encoding='utf-8') as f:
325
+ data = json.load(f)
326
+
327
+ self._priority_accounts = data.get("priority_accounts", [])
328
+ strategy_str = data.get("strategy", SelectionStrategy.RANDOM.value)
329
+ try:
330
+ self._strategy = SelectionStrategy(strategy_str)
331
+ except ValueError:
332
+ self._strategy = SelectionStrategy.RANDOM
333
+
334
+ # 兼容旧版本:历史默认策略为 lowest_balance,但无优先账号时更需要分散压力
335
+ if not self._priority_accounts and self._strategy == SelectionStrategy.LOWEST_BALANCE:
336
+ self._strategy = SelectionStrategy.RANDOM
337
+ self._save_priority_config()
338
+
339
+ print(f"[AccountSelector] 加载优先账号配置: {len(self._priority_accounts)} 个优先账号")
340
+ return True
341
+
342
+ except Exception as e:
343
+ print(f"[AccountSelector] 加载优先账号配置失败: {e}")
344
+ return False
345
+
346
+ def _save_priority_config(self) -> bool:
347
+ """保存优先账号配置到文件"""
348
+ try:
349
+ self._priority_file.parent.mkdir(parents=True, exist_ok=True)
350
+
351
+ data = {
352
+ "version": "1.0",
353
+ "priority_accounts": self._priority_accounts,
354
+ "strategy": self._strategy.value
355
+ }
356
+
357
+ temp_file = self._priority_file.with_suffix('.tmp')
358
+ with open(temp_file, 'w', encoding='utf-8') as f:
359
+ json.dump(data, f, indent=2, ensure_ascii=False)
360
+ temp_file.replace(self._priority_file)
361
+
362
+ return True
363
+
364
+ except Exception as e:
365
+ print(f"[AccountSelector] 保存优先账号配置失败: {e}")
366
+ return False
367
+
368
+ def get_status(self) -> dict:
369
+ """获取选择器状态"""
370
+ with self._lock:
371
+ return {
372
+ "strategy": self._strategy.value,
373
+ "priority_accounts": list(self._priority_accounts),
374
+ "priority_count": len(self._priority_accounts)
375
+ }
376
+
377
+
378
+ # 全局选择器实例
379
+ _account_selector: Optional[AccountSelector] = None
380
+
381
+
382
+ def get_account_selector(quota_cache: Optional['QuotaCache'] = None) -> AccountSelector:
383
+ """获取全局选择器实例"""
384
+ global _account_selector
385
+ if _account_selector is None:
386
+ if quota_cache is None:
387
+ from .quota_cache import get_quota_cache
388
+ quota_cache = get_quota_cache()
389
+ _account_selector = AccountSelector(quota_cache)
390
+ return _account_selector
KiroProxy/kiro_proxy/core/browser.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """浏览器检测和打开"""
2
+ import os
3
+ import shlex
4
+ import shutil
5
+ import subprocess
6
+ import platform
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional
9
+
10
+
11
+ @dataclass
12
+ class BrowserInfo:
13
+ id: str
14
+ name: str
15
+ path: str
16
+ supports_incognito: bool
17
+ incognito_arg: str = ""
18
+
19
+
20
+ # 浏览器配置
21
+ BROWSER_CONFIGS = {
22
+ "chrome": {
23
+ "names": ["google-chrome", "google-chrome-stable", "chrome", "chromium", "chromium-browser"],
24
+ "display": "Chrome",
25
+ "incognito": "--incognito",
26
+ },
27
+ "firefox": {
28
+ "names": ["firefox", "firefox-esr"],
29
+ "display": "Firefox",
30
+ "incognito": "--private-window",
31
+ },
32
+ "edge": {
33
+ "names": ["microsoft-edge", "microsoft-edge-stable", "msedge"],
34
+ "display": "Edge",
35
+ "incognito": "--inprivate",
36
+ },
37
+ "brave": {
38
+ "names": ["brave", "brave-browser"],
39
+ "display": "Brave",
40
+ "incognito": "--incognito",
41
+ },
42
+ "opera": {
43
+ "names": ["opera"],
44
+ "display": "Opera",
45
+ "incognito": "--private",
46
+ },
47
+ "vivaldi": {
48
+ "names": ["vivaldi", "vivaldi-stable"],
49
+ "display": "Vivaldi",
50
+ "incognito": "--incognito",
51
+ },
52
+ }
53
+
54
+
55
+ def detect_browsers() -> List[BrowserInfo]:
56
+ """检测系统安装的浏览器"""
57
+ browsers = []
58
+ system = platform.system().lower()
59
+
60
+ if system == "windows":
61
+ import winreg
62
+
63
+ def normalize_exe_path(raw: str) -> Optional[str]:
64
+ if not raw:
65
+ return None
66
+ expanded = os.path.expandvars(raw.strip())
67
+ try:
68
+ parts = shlex.split(expanded, posix=False)
69
+ except ValueError:
70
+ parts = [expanded]
71
+ candidate = (parts[0] if parts else expanded).strip().strip('"')
72
+ if os.path.exists(candidate):
73
+ return candidate
74
+ lower = expanded.lower()
75
+ exe_idx = lower.find(".exe")
76
+ if exe_idx != -1:
77
+ candidate = expanded[:exe_idx + 4].strip().strip('"')
78
+ if os.path.exists(candidate):
79
+ return candidate
80
+ return None
81
+
82
+ def get_reg_path(exe_name: str) -> Optional[str]:
83
+ name = f"{exe_name}.exe"
84
+ for root in (winreg.HKEY_LOCAL_MACHINE, winreg.HKEY_CURRENT_USER):
85
+ try:
86
+ with winreg.OpenKey(root, rf"SOFTWARE\Microsoft\Windows\CurrentVersion\App Paths\{name}") as key:
87
+ value, _ = winreg.QueryValueEx(key, "")
88
+ path = normalize_exe_path(value)
89
+ if path:
90
+ return path
91
+ except (FileNotFoundError, OSError, WindowsError):
92
+ pass
93
+ return None
94
+
95
+ for browser_id, config in BROWSER_CONFIGS.items():
96
+ path = None
97
+ for exe_name in config["names"]:
98
+ path = get_reg_path(exe_name)
99
+ if path:
100
+ break
101
+ if not path:
102
+ for exe_name in config["names"]:
103
+ path = shutil.which(exe_name)
104
+ if path:
105
+ break
106
+ if path:
107
+ browsers.append(BrowserInfo(
108
+ id=browser_id,
109
+ name=config["display"],
110
+ path=path,
111
+ supports_incognito=bool(config.get("incognito")),
112
+ incognito_arg=config.get("incognito", ""),
113
+ ))
114
+ else:
115
+ for browser_id, config in BROWSER_CONFIGS.items():
116
+ for name in config["names"]:
117
+ path = shutil.which(name)
118
+ if path:
119
+ browsers.append(BrowserInfo(
120
+ id=browser_id,
121
+ name=config["display"],
122
+ path=path,
123
+ supports_incognito=bool(config.get("incognito")),
124
+ incognito_arg=config.get("incognito", ""),
125
+ ))
126
+ break
127
+
128
+ # 添加默认浏览器选项
129
+ if browsers:
130
+ browsers.insert(0, BrowserInfo(
131
+ id="default",
132
+ name="默认浏览器",
133
+ path="xdg-open" if system == "linux" else "open",
134
+ supports_incognito=False,
135
+ incognito_arg="",
136
+ ))
137
+
138
+ return browsers
139
+
140
+
141
+ def open_url(url: str, browser_id: str = "default", incognito: bool = False) -> bool:
142
+ """用指定浏览器打开 URL"""
143
+ browsers = detect_browsers()
144
+ browser = next((b for b in browsers if b.id == browser_id), None)
145
+
146
+ if not browser:
147
+ # 降级到默认
148
+ browser = browsers[0] if browsers else None
149
+
150
+ if not browser:
151
+ return False
152
+
153
+ try:
154
+ if browser.id == "default":
155
+ # 使用系统默认浏览器
156
+ system = platform.system().lower()
157
+ if system == "linux":
158
+ subprocess.Popen(["xdg-open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
159
+ elif system == "darwin":
160
+ subprocess.Popen(["open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
161
+ else:
162
+ os.startfile(url)
163
+ else:
164
+ # 使用指定浏览器
165
+ args = [browser.path]
166
+ if incognito and browser.supports_incognito and browser.incognito_arg:
167
+ args.append(browser.incognito_arg)
168
+ args.append(url)
169
+ subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
170
+
171
+ return True
172
+ except Exception as e:
173
+ print(f"[Browser] 打开失败: {e}")
174
+ return False
175
+
176
+
177
+ def get_browsers_info() -> List[dict]:
178
+ """获取浏览器信息列表"""
179
+ return [
180
+ {
181
+ "id": b.id,
182
+ "name": b.name,
183
+ "supports_incognito": b.supports_incognito,
184
+ }
185
+ for b in detect_browsers()
186
+ ]
KiroProxy/kiro_proxy/core/error_handler.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """错误处理模块 - 统一的错误分类和处理
2
+
3
+ 检测各种 Kiro API 错误类型:
4
+ - 账号封禁 (TEMPORARILY_SUSPENDED)
5
+ - 配额超限 (Rate Limit)
6
+ - 内容过长 (CONTENT_LENGTH_EXCEEDS_THRESHOLD)
7
+ - 认证失败 (Unauthorized)
8
+ - 服务不可用 (Service Unavailable)
9
+ """
10
+ import re
11
+ from enum import Enum
12
+ from dataclasses import dataclass
13
+ from typing import Optional, Tuple
14
+
15
+
16
+ class ErrorType(str, Enum):
17
+ """错误类型"""
18
+ ACCOUNT_SUSPENDED = "account_suspended" # 账号被封禁
19
+ RATE_LIMITED = "rate_limited" # 配额超限
20
+ CONTENT_TOO_LONG = "content_too_long" # 内容过长
21
+ AUTH_FAILED = "auth_failed" # 认证失败
22
+ SERVICE_UNAVAILABLE = "service_unavailable" # 服务不可用
23
+ MODEL_UNAVAILABLE = "model_unavailable" # 模型不可用
24
+ UNKNOWN = "unknown" # 未知错误
25
+
26
+
27
+ @dataclass
28
+ class KiroError:
29
+ """Kiro API 错误"""
30
+ type: ErrorType
31
+ status_code: int
32
+ message: str
33
+ user_message: str # 用户友好的消息
34
+ should_disable_account: bool = False # 是否应该禁用账号
35
+ should_switch_account: bool = False # 是否应该切换账号
36
+ should_retry: bool = False # 是否应该重试
37
+ cooldown_seconds: int = 0 # 冷却时间
38
+
39
+
40
+ def classify_error(status_code: int, error_text: str) -> KiroError:
41
+ """分类 Kiro API 错误
42
+
43
+ Args:
44
+ status_code: HTTP 状态码
45
+ error_text: 错误响应文本
46
+
47
+ Returns:
48
+ KiroError 对象
49
+ """
50
+ error_lower = error_text.lower()
51
+
52
+ # 1. 账号封禁检测 (最严重)
53
+ # 检测: AccountSuspendedException, 423 状态码, temporarily_suspended, suspended
54
+ is_suspended = (
55
+ status_code == 423 or
56
+ "accountsuspendedexception" in error_lower or
57
+ "temporarily_suspended" in error_lower or
58
+ "suspended" in error_lower
59
+ )
60
+
61
+ if is_suspended:
62
+ # 提取 User ID
63
+ user_id_match = re.search(r'User ID \(([^)]+)\)', error_text)
64
+ user_id = user_id_match.group(1) if user_id_match else "unknown"
65
+
66
+ return KiroError(
67
+ type=ErrorType.ACCOUNT_SUSPENDED,
68
+ status_code=status_code,
69
+ message=error_text,
70
+ user_message=f"⚠️ 账号已被封禁 (User ID: {user_id})。请联系 AWS 支持解封: https://support.aws.amazon.com/#/contacts/kiro",
71
+ should_disable_account=True,
72
+ should_switch_account=True,
73
+ )
74
+
75
+ # 2. 402 Payment Required - 额度用尽(不触发冷却,仅切换账号)
76
+ if status_code == 402 or "payment required" in error_lower or "insufficient" in error_lower:
77
+ return KiroError(
78
+ type=ErrorType.RATE_LIMITED,
79
+ status_code=status_code,
80
+ message=error_text,
81
+ user_message="账号额度已用尽,已切换到其他账号",
82
+ should_switch_account=False, # 不自动切换,让上层逻辑处理
83
+ cooldown_seconds=0, # 不触发冷却
84
+ )
85
+
86
+ # 3. 配额超限检测 (仅 429 触发冷却)
87
+ if status_code == 429:
88
+ return KiroError(
89
+ type=ErrorType.RATE_LIMITED,
90
+ status_code=status_code,
91
+ message=error_text,
92
+ user_message="请求过于频繁,账号已进入冷却期",
93
+ should_switch_account=True,
94
+ cooldown_seconds=30, # 基础冷却时间,实际由 QuotaManager 动态管理
95
+ )
96
+
97
+ # 4. 内容过长检测
98
+ if "content_length_exceeds_threshold" in error_lower or (
99
+ "too long" in error_lower and ("input" in error_lower or "content" in error_lower)
100
+ ):
101
+ return KiroError(
102
+ type=ErrorType.CONTENT_TOO_LONG,
103
+ status_code=status_code,
104
+ message=error_text,
105
+ user_message="对话历史过长,请使用 /clear 清空对话",
106
+ should_retry=True,
107
+ )
108
+
109
+ # 5. 认证失败检测
110
+ if status_code == 401 or "unauthorized" in error_lower or "invalid token" in error_lower:
111
+ return KiroError(
112
+ type=ErrorType.AUTH_FAILED,
113
+ status_code=status_code,
114
+ message=error_text,
115
+ user_message="Token 已过期或无效,请刷新 Token",
116
+ should_switch_account=True,
117
+ )
118
+
119
+ # 6. 模型不可用检测
120
+ if "model_temporarily_unavailable" in error_lower or "unexpectedly high load" in error_lower:
121
+ return KiroError(
122
+ type=ErrorType.MODEL_UNAVAILABLE,
123
+ status_code=status_code,
124
+ message=error_text,
125
+ user_message="模型暂时不可用,请稍后重试",
126
+ should_retry=True,
127
+ )
128
+
129
+ # 7. 服务不可用检测
130
+ if status_code in (502, 503, 504) or "service unavailable" in error_lower:
131
+ return KiroError(
132
+ type=ErrorType.SERVICE_UNAVAILABLE,
133
+ status_code=status_code,
134
+ message=error_text,
135
+ user_message="服务暂时不可用,请稍后重试",
136
+ should_retry=True,
137
+ )
138
+
139
+ # 8. 未知错误
140
+ return KiroError(
141
+ type=ErrorType.UNKNOWN,
142
+ status_code=status_code,
143
+ message=error_text,
144
+ user_message=f"API 错误 ({status_code})",
145
+ )
146
+
147
+
148
+ def is_account_suspended(status_code: int, error_text: str) -> bool:
149
+ """检查是否为账号封禁错误"""
150
+ error = classify_error(status_code, error_text)
151
+ return error.type == ErrorType.ACCOUNT_SUSPENDED
152
+
153
+
154
+ def get_anthropic_error_response(error: KiroError) -> dict:
155
+ """生成 Anthropic 格式的错误响应"""
156
+ error_type_map = {
157
+ ErrorType.ACCOUNT_SUSPENDED: "authentication_error",
158
+ ErrorType.RATE_LIMITED: "rate_limit_error",
159
+ ErrorType.CONTENT_TOO_LONG: "invalid_request_error",
160
+ ErrorType.AUTH_FAILED: "authentication_error",
161
+ ErrorType.SERVICE_UNAVAILABLE: "api_error",
162
+ ErrorType.MODEL_UNAVAILABLE: "overloaded_error",
163
+ ErrorType.UNKNOWN: "api_error",
164
+ }
165
+
166
+ return {
167
+ "type": "error",
168
+ "error": {
169
+ "type": error_type_map.get(error.type, "api_error"),
170
+ "message": error.user_message
171
+ }
172
+ }
173
+
174
+
175
+ def format_error_log(error: KiroError, account_id: str = None) -> str:
176
+ """格式化错误日志"""
177
+ lines = [
178
+ f"[{error.type.value.upper()}]",
179
+ f" Status: {error.status_code}",
180
+ f" Message: {error.user_message}",
181
+ ]
182
+ if account_id:
183
+ lines.insert(1, f" Account: {account_id}")
184
+ if error.should_disable_account:
185
+ lines.append(" Action: 账号已被禁用")
186
+ elif error.should_switch_account:
187
+ lines.append(" Action: 切换到其他账号")
188
+ return "\n".join(lines)
KiroProxy/kiro_proxy/core/flow_monitor.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow Monitor - LLM 流量监控
2
+
3
+ 记录完整的请求/响应数据,支持查询、过滤、导出。
4
+ """
5
+ import json
6
+ import time
7
+ import uuid
8
+ from pathlib import Path
9
+ from dataclasses import dataclass, field, asdict
10
+ from typing import Optional, List, Dict, Any
11
+ from datetime import datetime, timezone
12
+ from collections import deque
13
+ from enum import Enum
14
+
15
+
16
+ class FlowState(str, Enum):
17
+ """Flow 状态"""
18
+ PENDING = "pending" # 等待响应
19
+ STREAMING = "streaming" # 流式传输中
20
+ COMPLETED = "completed" # 完成
21
+ ERROR = "error" # 错误
22
+
23
+
24
+ @dataclass
25
+ class Message:
26
+ """消息"""
27
+ role: str # user/assistant/system/tool
28
+ content: Any # str 或 list
29
+ name: Optional[str] = None # tool name
30
+ tool_call_id: Optional[str] = None
31
+
32
+
33
+ @dataclass
34
+ class TokenUsage:
35
+ """Token 使用量"""
36
+ input_tokens: int = 0
37
+ output_tokens: int = 0
38
+ cache_read_tokens: int = 0
39
+ cache_write_tokens: int = 0
40
+
41
+ @property
42
+ def total_tokens(self) -> int:
43
+ return self.input_tokens + self.output_tokens
44
+
45
+
46
+ @dataclass
47
+ class FlowRequest:
48
+ """请求数据"""
49
+ method: str
50
+ path: str
51
+ headers: Dict[str, str]
52
+ body: Dict[str, Any]
53
+
54
+ # 解析后的字段
55
+ model: str = ""
56
+ messages: List[Message] = field(default_factory=list)
57
+ system: str = ""
58
+ tools: List[Dict] = field(default_factory=list)
59
+ stream: bool = False
60
+ max_tokens: int = 0
61
+ temperature: float = 1.0
62
+
63
+
64
+ @dataclass
65
+ class FlowResponse:
66
+ """响应数据"""
67
+ status_code: int
68
+ headers: Dict[str, str] = field(default_factory=dict)
69
+ body: Any = None
70
+
71
+ # 解析后的字段
72
+ content: str = ""
73
+ tool_calls: List[Dict] = field(default_factory=list)
74
+ stop_reason: str = ""
75
+ usage: TokenUsage = field(default_factory=TokenUsage)
76
+
77
+ # 流式响应
78
+ chunks: List[str] = field(default_factory=list)
79
+ chunk_count: int = 0
80
+
81
+
82
+ @dataclass
83
+ class FlowError:
84
+ """错误信息"""
85
+ type: str # rate_limit_error, api_error, etc.
86
+ message: str
87
+ status_code: int = 0
88
+ raw: str = ""
89
+
90
+
91
+ @dataclass
92
+ class FlowTiming:
93
+ """时间信息"""
94
+ created_at: float = 0
95
+ first_byte_at: Optional[float] = None
96
+ completed_at: Optional[float] = None
97
+
98
+ @property
99
+ def ttfb_ms(self) -> Optional[float]:
100
+ """Time to first byte"""
101
+ if self.first_byte_at and self.created_at:
102
+ return (self.first_byte_at - self.created_at) * 1000
103
+ return None
104
+
105
+ @property
106
+ def duration_ms(self) -> Optional[float]:
107
+ """Total duration"""
108
+ if self.completed_at and self.created_at:
109
+ return (self.completed_at - self.created_at) * 1000
110
+ return None
111
+
112
+
113
+ @dataclass
114
+ class LLMFlow:
115
+ """完整的 LLM 请求流"""
116
+ id: str
117
+ state: FlowState
118
+
119
+ # 路由信息
120
+ protocol: str # anthropic, openai, gemini
121
+ account_id: Optional[str] = None
122
+ account_name: Optional[str] = None
123
+
124
+ # 请求/响应
125
+ request: Optional[FlowRequest] = None
126
+ response: Optional[FlowResponse] = None
127
+ error: Optional[FlowError] = None
128
+
129
+ # 时间
130
+ timing: FlowTiming = field(default_factory=FlowTiming)
131
+
132
+ # 元数据
133
+ tags: List[str] = field(default_factory=list)
134
+ notes: str = ""
135
+ bookmarked: bool = False
136
+
137
+ # 重试信息
138
+ retry_count: int = 0
139
+ parent_flow_id: Optional[str] = None
140
+
141
+ def to_dict(self) -> dict:
142
+ """转换为字典"""
143
+ d = {
144
+ "id": self.id,
145
+ "state": self.state.value,
146
+ "protocol": self.protocol,
147
+ "account_id": self.account_id,
148
+ "account_name": self.account_name,
149
+ "timing": {
150
+ "created_at": self.timing.created_at,
151
+ "first_byte_at": self.timing.first_byte_at,
152
+ "completed_at": self.timing.completed_at,
153
+ "ttfb_ms": self.timing.ttfb_ms,
154
+ "duration_ms": self.timing.duration_ms,
155
+ },
156
+ "tags": self.tags,
157
+ "notes": self.notes,
158
+ "bookmarked": self.bookmarked,
159
+ "retry_count": self.retry_count,
160
+ }
161
+
162
+ if self.request:
163
+ d["request"] = {
164
+ "method": self.request.method,
165
+ "path": self.request.path,
166
+ "model": self.request.model,
167
+ "stream": self.request.stream,
168
+ "message_count": len(self.request.messages),
169
+ "has_tools": bool(self.request.tools),
170
+ "has_system": bool(self.request.system),
171
+ }
172
+
173
+ if self.response:
174
+ d["response"] = {
175
+ "status_code": self.response.status_code,
176
+ "content_length": len(self.response.content),
177
+ "has_tool_calls": bool(self.response.tool_calls),
178
+ "stop_reason": self.response.stop_reason,
179
+ "chunk_count": self.response.chunk_count,
180
+ "usage": asdict(self.response.usage),
181
+ }
182
+
183
+ if self.error:
184
+ d["error"] = asdict(self.error)
185
+
186
+ return d
187
+
188
+ def to_full_dict(self) -> dict:
189
+ """转换为完整字典(包含请求/响应体)"""
190
+ d = self.to_dict()
191
+
192
+ if self.request:
193
+ d["request"]["headers"] = self.request.headers
194
+ d["request"]["body"] = self.request.body
195
+ d["request"]["messages"] = [asdict(m) if hasattr(m, '__dataclass_fields__') else m for m in self.request.messages]
196
+ d["request"]["system"] = self.request.system
197
+ d["request"]["tools"] = self.request.tools
198
+
199
+ if self.response:
200
+ d["response"]["headers"] = self.response.headers
201
+ d["response"]["body"] = self.response.body
202
+ d["response"]["content"] = self.response.content
203
+ d["response"]["tool_calls"] = self.response.tool_calls
204
+ d["response"]["chunks"] = self.response.chunks[-10:] # 只保留最后10个chunk
205
+
206
+ return d
207
+
208
+
209
+ class FlowStore:
210
+ """Flow 存储"""
211
+
212
+ def __init__(self, max_flows: int = 500, persist_dir: Optional[Path] = None):
213
+ self.flows: deque[LLMFlow] = deque(maxlen=max_flows)
214
+ self.flow_map: Dict[str, LLMFlow] = {}
215
+ self.persist_dir = persist_dir
216
+ self.max_flows = max_flows
217
+
218
+ # 统计
219
+ self.total_flows = 0
220
+ self.total_tokens_in = 0
221
+ self.total_tokens_out = 0
222
+
223
+ def add(self, flow: LLMFlow):
224
+ """添加 Flow"""
225
+ # 如果队列满了,移除最旧的
226
+ if len(self.flows) >= self.max_flows:
227
+ old = self.flows[0]
228
+ if old.id in self.flow_map:
229
+ del self.flow_map[old.id]
230
+
231
+ self.flows.append(flow)
232
+ self.flow_map[flow.id] = flow
233
+ self.total_flows += 1
234
+
235
+ def get(self, flow_id: str) -> Optional[LLMFlow]:
236
+ """获取 Flow"""
237
+ return self.flow_map.get(flow_id)
238
+
239
+ def update(self, flow_id: str, **kwargs):
240
+ """更新 Flow"""
241
+ flow = self.flow_map.get(flow_id)
242
+ if flow:
243
+ for k, v in kwargs.items():
244
+ if hasattr(flow, k):
245
+ setattr(flow, k, v)
246
+
247
+ def query(
248
+ self,
249
+ protocol: Optional[str] = None,
250
+ model: Optional[str] = None,
251
+ account_id: Optional[str] = None,
252
+ state: Optional[FlowState] = None,
253
+ has_error: Optional[bool] = None,
254
+ bookmarked: Optional[bool] = None,
255
+ min_duration_ms: Optional[float] = None,
256
+ max_duration_ms: Optional[float] = None,
257
+ start_time: Optional[float] = None,
258
+ end_time: Optional[float] = None,
259
+ search: Optional[str] = None,
260
+ limit: int = 100,
261
+ offset: int = 0,
262
+ ) -> List[LLMFlow]:
263
+ """查询 Flows"""
264
+ results = []
265
+
266
+ for flow in reversed(self.flows):
267
+ # 过滤条件
268
+ if protocol and flow.protocol != protocol:
269
+ continue
270
+ if model and flow.request and flow.request.model != model:
271
+ continue
272
+ if account_id and flow.account_id != account_id:
273
+ continue
274
+ if state and flow.state != state:
275
+ continue
276
+ if has_error is not None:
277
+ if has_error and not flow.error:
278
+ continue
279
+ if not has_error and flow.error:
280
+ continue
281
+ if bookmarked is not None and flow.bookmarked != bookmarked:
282
+ continue
283
+ if min_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms < min_duration_ms:
284
+ continue
285
+ if max_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms > max_duration_ms:
286
+ continue
287
+ if start_time and flow.timing.created_at < start_time:
288
+ continue
289
+ if end_time and flow.timing.created_at > end_time:
290
+ continue
291
+ if search:
292
+ # 简单搜索:在内容中查找
293
+ found = False
294
+ if flow.request and search.lower() in json.dumps(flow.request.body).lower():
295
+ found = True
296
+ if flow.response and search.lower() in flow.response.content.lower():
297
+ found = True
298
+ if not found:
299
+ continue
300
+
301
+ results.append(flow)
302
+
303
+ return results[offset:offset + limit]
304
+
305
+ def get_stats(self) -> dict:
306
+ """获取统计信息"""
307
+ completed = [f for f in self.flows if f.state == FlowState.COMPLETED]
308
+ errors = [f for f in self.flows if f.state == FlowState.ERROR]
309
+
310
+ # 按模型统计
311
+ model_stats = {}
312
+ for f in self.flows:
313
+ if f.request:
314
+ model = f.request.model or "unknown"
315
+ if model not in model_stats:
316
+ model_stats[model] = {"count": 0, "errors": 0, "tokens_in": 0, "tokens_out": 0}
317
+ model_stats[model]["count"] += 1
318
+ if f.error:
319
+ model_stats[model]["errors"] += 1
320
+ if f.response and f.response.usage:
321
+ model_stats[model]["tokens_in"] += f.response.usage.input_tokens
322
+ model_stats[model]["tokens_out"] += f.response.usage.output_tokens
323
+
324
+ # 计算平均延迟
325
+ durations = [f.timing.duration_ms for f in completed if f.timing.duration_ms]
326
+ avg_duration = sum(durations) / len(durations) if durations else 0
327
+
328
+ return {
329
+ "total_flows": self.total_flows,
330
+ "active_flows": len(self.flows),
331
+ "completed": len(completed),
332
+ "errors": len(errors),
333
+ "error_rate": f"{len(errors) / max(1, len(self.flows)) * 100:.1f}%",
334
+ "avg_duration_ms": round(avg_duration, 2),
335
+ "total_tokens_in": self.total_tokens_in,
336
+ "total_tokens_out": self.total_tokens_out,
337
+ "by_model": model_stats,
338
+ }
339
+
340
+ def export_jsonl(self, flows: List[LLMFlow]) -> str:
341
+ """导出为 JSONL 格式"""
342
+ lines = []
343
+ for f in flows:
344
+ lines.append(json.dumps(f.to_full_dict(), ensure_ascii=False))
345
+ return "\n".join(lines)
346
+
347
+ def export_markdown(self, flow: LLMFlow) -> str:
348
+ """导出单个 Flow 为 Markdown"""
349
+ lines = [
350
+ f"# Flow {flow.id}",
351
+ "",
352
+ f"- **Protocol**: {flow.protocol}",
353
+ f"- **State**: {flow.state.value}",
354
+ f"- **Account**: {flow.account_name or flow.account_id or 'N/A'}",
355
+ f"- **Created**: {datetime.fromtimestamp(flow.timing.created_at).isoformat()}",
356
+ ]
357
+
358
+ if flow.timing.duration_ms:
359
+ lines.append(f"- **Duration**: {flow.timing.duration_ms:.0f}ms")
360
+
361
+ if flow.request:
362
+ lines.extend([
363
+ "",
364
+ "## Request",
365
+ "",
366
+ f"- **Model**: {flow.request.model}",
367
+ f"- **Stream**: {flow.request.stream}",
368
+ f"- **Messages**: {len(flow.request.messages)}",
369
+ ])
370
+
371
+ if flow.request.system:
372
+ lines.extend(["", "### System", "", f"```\n{flow.request.system}\n```"])
373
+
374
+ lines.extend(["", "### Messages", ""])
375
+ for msg in flow.request.messages:
376
+ content = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False)
377
+ lines.append(f"**{msg.role}**: {content[:500]}{'...' if len(content) > 500 else ''}")
378
+ lines.append("")
379
+
380
+ if flow.response:
381
+ lines.extend([
382
+ "## Response",
383
+ "",
384
+ f"- **Status**: {flow.response.status_code}",
385
+ f"- **Stop Reason**: {flow.response.stop_reason}",
386
+ ])
387
+
388
+ if flow.response.usage:
389
+ lines.append(f"- **Tokens**: {flow.response.usage.input_tokens} in / {flow.response.usage.output_tokens} out")
390
+
391
+ if flow.response.content:
392
+ lines.extend(["", "### Content", "", f"```\n{flow.response.content[:2000]}\n```"])
393
+
394
+ if flow.error:
395
+ lines.extend([
396
+ "",
397
+ "## Error",
398
+ "",
399
+ f"- **Type**: {flow.error.type}",
400
+ f"- **Message**: {flow.error.message}",
401
+ ])
402
+
403
+ return "\n".join(lines)
404
+
405
+
406
+ class FlowMonitor:
407
+ """Flow 监控器"""
408
+
409
+ def __init__(self, max_flows: int = 500):
410
+ self.store = FlowStore(max_flows=max_flows)
411
+
412
+ def create_flow(
413
+ self,
414
+ protocol: str,
415
+ method: str,
416
+ path: str,
417
+ headers: Dict[str, str],
418
+ body: Dict[str, Any],
419
+ account_id: Optional[str] = None,
420
+ account_name: Optional[str] = None,
421
+ ) -> str:
422
+ """创建新的 Flow"""
423
+ flow_id = uuid.uuid4().hex[:12]
424
+
425
+ # 解析请求
426
+ request = FlowRequest(
427
+ method=method,
428
+ path=path,
429
+ headers={k: v for k, v in headers.items() if k.lower() not in ["authorization"]},
430
+ body=body,
431
+ model=body.get("model", ""),
432
+ stream=body.get("stream", False),
433
+ system=body.get("system", ""),
434
+ tools=body.get("tools", []),
435
+ max_tokens=body.get("max_tokens", 0),
436
+ temperature=body.get("temperature", 1.0),
437
+ )
438
+
439
+ # 解析消息
440
+ messages = body.get("messages", [])
441
+ for msg in messages:
442
+ request.messages.append(Message(
443
+ role=msg.get("role", "user"),
444
+ content=msg.get("content", ""),
445
+ name=msg.get("name"),
446
+ tool_call_id=msg.get("tool_call_id"),
447
+ ))
448
+
449
+ flow = LLMFlow(
450
+ id=flow_id,
451
+ state=FlowState.PENDING,
452
+ protocol=protocol,
453
+ account_id=account_id,
454
+ account_name=account_name,
455
+ request=request,
456
+ timing=FlowTiming(created_at=time.time()),
457
+ )
458
+
459
+ self.store.add(flow)
460
+ return flow_id
461
+
462
+ def start_streaming(self, flow_id: str):
463
+ """标记开始流式传输"""
464
+ flow = self.store.get(flow_id)
465
+ if flow:
466
+ flow.state = FlowState.STREAMING
467
+ flow.timing.first_byte_at = time.time()
468
+ if not flow.response:
469
+ flow.response = FlowResponse(status_code=200)
470
+
471
+ def add_chunk(self, flow_id: str, chunk: str):
472
+ """添加流式响应块"""
473
+ flow = self.store.get(flow_id)
474
+ if flow and flow.response:
475
+ flow.response.chunks.append(chunk)
476
+ flow.response.chunk_count += 1
477
+ flow.response.content += chunk
478
+
479
+ def complete_flow(
480
+ self,
481
+ flow_id: str,
482
+ status_code: int,
483
+ content: str = "",
484
+ tool_calls: List[Dict] = None,
485
+ stop_reason: str = "",
486
+ usage: Optional[TokenUsage] = None,
487
+ headers: Dict[str, str] = None,
488
+ ):
489
+ """完成 Flow"""
490
+ flow = self.store.get(flow_id)
491
+ if not flow:
492
+ return
493
+
494
+ flow.state = FlowState.COMPLETED
495
+ flow.timing.completed_at = time.time()
496
+
497
+ if not flow.response:
498
+ flow.response = FlowResponse(status_code=status_code)
499
+
500
+ flow.response.status_code = status_code
501
+ flow.response.content = content or flow.response.content
502
+ flow.response.tool_calls = tool_calls or []
503
+ flow.response.stop_reason = stop_reason
504
+ flow.response.headers = headers or {}
505
+
506
+ if usage:
507
+ flow.response.usage = usage
508
+ self.store.total_tokens_in += usage.input_tokens
509
+ self.store.total_tokens_out += usage.output_tokens
510
+
511
+ def fail_flow(self, flow_id: str, error_type: str, message: str, status_code: int = 0, raw: str = ""):
512
+ """标记 Flow 失败"""
513
+ flow = self.store.get(flow_id)
514
+ if not flow:
515
+ return
516
+
517
+ flow.state = FlowState.ERROR
518
+ flow.timing.completed_at = time.time()
519
+ flow.error = FlowError(
520
+ type=error_type,
521
+ message=message,
522
+ status_code=status_code,
523
+ raw=raw[:1000], # 限制长度
524
+ )
525
+
526
+ def bookmark_flow(self, flow_id: str, bookmarked: bool = True):
527
+ """书签 Flow"""
528
+ flow = self.store.get(flow_id)
529
+ if flow:
530
+ flow.bookmarked = bookmarked
531
+
532
+ def add_note(self, flow_id: str, note: str):
533
+ """添加备注"""
534
+ flow = self.store.get(flow_id)
535
+ if flow:
536
+ flow.notes = note
537
+
538
+ def add_tag(self, flow_id: str, tag: str):
539
+ """添加标签"""
540
+ flow = self.store.get(flow_id)
541
+ if flow and tag not in flow.tags:
542
+ flow.tags.append(tag)
543
+
544
+ def get_flow(self, flow_id: str) -> Optional[LLMFlow]:
545
+ """获取 Flow"""
546
+ return self.store.get(flow_id)
547
+
548
+ def query(self, **kwargs) -> List[LLMFlow]:
549
+ """查询 Flows"""
550
+ return self.store.query(**kwargs)
551
+
552
+ def get_stats(self) -> dict:
553
+ """获取统计"""
554
+ return self.store.get_stats()
555
+
556
+ def export(self, flow_ids: List[str] = None, format: str = "jsonl") -> str:
557
+ """导出 Flows"""
558
+ if flow_ids:
559
+ flows = [self.store.get(fid) for fid in flow_ids if self.store.get(fid)]
560
+ else:
561
+ flows = list(self.store.flows)
562
+
563
+ if format == "jsonl":
564
+ return self.store.export_jsonl(flows)
565
+ elif format == "markdown" and len(flows) == 1:
566
+ return self.store.export_markdown(flows[0])
567
+ else:
568
+ return json.dumps([f.to_dict() for f in flows], ensure_ascii=False, indent=2)
569
+
570
+
571
+ # 全局实例
572
+ flow_monitor = FlowMonitor(max_flows=500)
KiroProxy/kiro_proxy/core/history_manager.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """历史消息管理器 - 错误触发压缩版
2
+
3
+ 自动化管理对话历史长度,收到超限错误时智能压缩而非强硬截断:
4
+ 1. 无预检测 - 不再依赖阈值,正常发送请求
5
+ 2. 错误触发 - 收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后自动压缩
6
+ 3. 智能压缩 - 保留最近消息 + 摘要早期对话,目标 20K-50K 字符
7
+ 4. 自动重试 - 压缩后自动重试请求
8
+ """
9
+ import json
10
+ import time
11
+ from typing import List, Dict, Any, Tuple, Optional, Callable
12
+ from dataclasses import dataclass, field
13
+ from collections import OrderedDict
14
+ from enum import Enum
15
+
16
+
17
+ @dataclass
18
+ class SummaryCacheEntry:
19
+ summary: str
20
+ old_history_hash: str
21
+ updated_at: float
22
+
23
+
24
+ class SummaryCache:
25
+ """摘要缓存"""
26
+
27
+ def __init__(self, max_entries: int = 64):
28
+ self._entries: "OrderedDict[str, SummaryCacheEntry]" = OrderedDict()
29
+ self._max_entries = max_entries
30
+
31
+ def get(self, key: str, old_history_hash: str, max_age: int = 300) -> Optional[str]:
32
+ entry = self._entries.get(key)
33
+ if not entry:
34
+ return None
35
+ if time.time() - entry.updated_at > max_age:
36
+ self._entries.pop(key, None)
37
+ return None
38
+ if entry.old_history_hash != old_history_hash:
39
+ return None
40
+ self._entries.move_to_end(key)
41
+ return entry.summary
42
+
43
+ def set(self, key: str, summary: str, old_history_hash: str):
44
+ self._entries[key] = SummaryCacheEntry(
45
+ summary=summary,
46
+ old_history_hash=old_history_hash,
47
+ updated_at=time.time()
48
+ )
49
+ self._entries.move_to_end(key)
50
+ if len(self._entries) > self._max_entries:
51
+ self._entries.popitem(last=False)
52
+
53
+
54
+ @dataclass
55
+ class CompressionCacheEntry:
56
+ """压缩结果缓存条目"""
57
+ compressed_history: List[dict]
58
+ original_hash: str
59
+ compressed_chars: int
60
+ updated_at: float
61
+
62
+
63
+ class CompressionCache:
64
+ """全局压缩结果缓存
65
+
66
+ 解决 Claude Code CLI 反复压缩问题:
67
+ - 客户端每次请求都发送完整原始历史
68
+ - 缓存压缩结果,避免对相同内容重复压缩
69
+ - 基于原始历史的 hash 匹配
70
+ """
71
+
72
+ def __init__(self, max_entries: int = 32, max_age: int = 600):
73
+ self._entries: "OrderedDict[str, CompressionCacheEntry]" = OrderedDict()
74
+ self._max_entries = max_entries
75
+ self._max_age = max_age # 缓存有效期(秒),默认 10 分钟
76
+
77
+ def get(self, original_hash: str) -> Optional[List[dict]]:
78
+ """获取缓存的压缩结果"""
79
+ entry = self._entries.get(original_hash)
80
+ if not entry:
81
+ return None
82
+ if time.time() - entry.updated_at > self._max_age:
83
+ self._entries.pop(original_hash, None)
84
+ return None
85
+ self._entries.move_to_end(original_hash)
86
+ print(f"[CompressionCache] 命中缓存,跳过重复压缩 (原始 hash: {original_hash[:16]}...)")
87
+ return entry.compressed_history
88
+
89
+ def set(self, original_hash: str, compressed_history: List[dict], compressed_chars: int):
90
+ """缓存压缩结果"""
91
+ self._entries[original_hash] = CompressionCacheEntry(
92
+ compressed_history=compressed_history,
93
+ original_hash=original_hash,
94
+ compressed_chars=compressed_chars,
95
+ updated_at=time.time()
96
+ )
97
+ self._entries.move_to_end(original_hash)
98
+ if len(self._entries) > self._max_entries:
99
+ self._entries.popitem(last=False)
100
+ print(f"[CompressionCache] 缓存压缩结果 (原始 hash: {original_hash[:16]}..., 压缩后: {compressed_chars} 字符)")
101
+
102
+ def clear(self):
103
+ """清空缓存"""
104
+ self._entries.clear()
105
+
106
+
107
+ # 全局压缩缓存实例
108
+ _compression_cache = CompressionCache()
109
+
110
+
111
+ class TruncateStrategy(str, Enum):
112
+ """压缩策略(保留用于兼容)"""
113
+ NONE = "none"
114
+ AUTO_TRUNCATE = "auto_truncate"
115
+ SMART_SUMMARY = "smart_summary"
116
+ ERROR_RETRY = "error_retry"
117
+ PRE_ESTIMATE = "pre_estimate"
118
+
119
+
120
+ # 自动管理的常量(不再使用阈值触发,仅在错误后压缩)
121
+ # AUTO_COMPRESS_THRESHOLD 已废弃,不再用于预检测
122
+ SAFE_CHAR_LIMIT = 35000 # 压缩后的目标字符数 (20K-50K 范围的中间值)
123
+ SAFE_CHAR_LIMIT_MIN = 20000 # 压缩目标下限
124
+ SAFE_CHAR_LIMIT_MAX = 50000 # 压缩目标上限
125
+ MIN_KEEP_MESSAGES = 6 # 最少保留的最近消息数
126
+ MAX_KEEP_MESSAGES = 20 # 最多保留的最近消息数
127
+ SUMMARY_MAX_LENGTH = 3000 # 摘要最大长度
128
+
129
+
130
+ @dataclass
131
+ class HistoryConfig:
132
+ """历史消息配置(简化版,大部分参数自动管理)"""
133
+ # 启用的策略
134
+ strategies: List[TruncateStrategy] = field(default_factory=lambda: [TruncateStrategy.ERROR_RETRY])
135
+
136
+ # 以下参数保留用于兼容,但实际使用自动值
137
+ max_messages: int = 30
138
+ max_chars: int = 150000
139
+ summary_keep_recent: int = 10
140
+ summary_threshold: int = 100000
141
+ summary_max_length: int = 2000
142
+ retry_max_messages: int = 20
143
+ max_retries: int = 3
144
+ estimate_threshold: int = 180000
145
+ chars_per_token: float = 3.0
146
+ summary_cache_enabled: bool = True
147
+ summary_cache_min_delta_messages: int = 3
148
+ summary_cache_min_delta_chars: int = 4000
149
+ summary_cache_max_age_seconds: int = 300
150
+ add_warning_header: bool = True
151
+
152
+ def to_dict(self) -> dict:
153
+ return {
154
+ "strategies": [s.value for s in self.strategies],
155
+ "max_messages": self.max_messages,
156
+ "max_chars": self.max_chars,
157
+ "summary_keep_recent": self.summary_keep_recent,
158
+ "summary_threshold": self.summary_threshold,
159
+ "summary_max_length": self.summary_max_length,
160
+ "retry_max_messages": self.retry_max_messages,
161
+ "max_retries": self.max_retries,
162
+ "estimate_threshold": self.estimate_threshold,
163
+ "chars_per_token": self.chars_per_token,
164
+ "summary_cache_enabled": self.summary_cache_enabled,
165
+ "summary_cache_min_delta_messages": self.summary_cache_min_delta_messages,
166
+ "summary_cache_min_delta_chars": self.summary_cache_min_delta_chars,
167
+ "summary_cache_max_age_seconds": self.summary_cache_max_age_seconds,
168
+ "add_warning_header": self.add_warning_header,
169
+ }
170
+
171
+ @classmethod
172
+ def from_dict(cls, data: dict) -> "HistoryConfig":
173
+ strategies = [TruncateStrategy(s) for s in data.get("strategies", ["error_retry"])]
174
+ return cls(
175
+ strategies=strategies,
176
+ max_messages=data.get("max_messages", 30),
177
+ max_chars=data.get("max_chars", 150000),
178
+ summary_keep_recent=data.get("summary_keep_recent", 10),
179
+ summary_threshold=data.get("summary_threshold", 100000),
180
+ summary_max_length=data.get("summary_max_length", 2000),
181
+ retry_max_messages=data.get("retry_max_messages", 20),
182
+ max_retries=data.get("max_retries", 3),
183
+ estimate_threshold=data.get("estimate_threshold", 180000),
184
+ chars_per_token=data.get("chars_per_token", 3.0),
185
+ summary_cache_enabled=data.get("summary_cache_enabled", True),
186
+ summary_cache_min_delta_messages=data.get("summary_cache_min_delta_messages", 3),
187
+ summary_cache_min_delta_chars=data.get("summary_cache_min_delta_chars", 4000),
188
+ summary_cache_max_age_seconds=data.get("summary_cache_max_age_seconds", 300),
189
+ add_warning_header=data.get("add_warning_header", True),
190
+ )
191
+
192
+
193
+ _summary_cache = SummaryCache()
194
+
195
+
196
+ class HistoryManager:
197
+ """历史消息管理器 - 错误触发压缩版
198
+
199
+ 不再依赖阈值预检测,仅在收到上下文超限错误后触发压缩。
200
+ 压缩目标为 20K-50K 字符范围。
201
+ """
202
+
203
+ def __init__(self, config: HistoryConfig = None, cache_key: Optional[str] = None):
204
+ self.config = config or HistoryConfig()
205
+ self._truncated = False
206
+ self._truncate_info = ""
207
+ self.cache_key = cache_key
208
+ self._retry_count = 0
209
+
210
+ @property
211
+ def was_truncated(self) -> bool:
212
+ return self._truncated
213
+
214
+ @property
215
+ def truncate_info(self) -> str:
216
+ return self._truncate_info
217
+
218
+ def reset(self):
219
+ self._truncated = False
220
+ self._truncate_info = ""
221
+
222
+ def set_cache_key(self, key: Optional[str]):
223
+ self.cache_key = key
224
+
225
+ def _hash_history(self, history: List[dict]) -> str:
226
+ """生成历史消息的简单哈希"""
227
+ return f"{len(history)}:{len(json.dumps(history, ensure_ascii=False))}"
228
+
229
+ def estimate_tokens(self, text: str) -> int:
230
+ return int(len(text) / self.config.chars_per_token)
231
+
232
+ def estimate_history_size(self, history: List[dict]) -> Tuple[int, int]:
233
+ char_count = len(json.dumps(history, ensure_ascii=False))
234
+ return len(history), char_count
235
+
236
+ def estimate_request_chars(self, history: List[dict], user_content: str = "") -> Tuple[int, int, int]:
237
+ history_chars = len(json.dumps(history, ensure_ascii=False))
238
+ user_chars = len(user_content or "")
239
+ return history_chars, user_chars, history_chars + user_chars
240
+
241
+ def _extract_text(self, content) -> str:
242
+ if isinstance(content, str):
243
+ return content
244
+ if isinstance(content, list):
245
+ texts = []
246
+ for item in content:
247
+ if isinstance(item, dict) and item.get("type") == "text":
248
+ texts.append(item.get("text", ""))
249
+ elif isinstance(item, str):
250
+ texts.append(item)
251
+ return "\n".join(texts)
252
+ if isinstance(content, dict):
253
+ return content.get("text", "") or content.get("content", "")
254
+ return str(content) if content else ""
255
+
256
+
257
+ def _format_for_summary(self, history: List[dict]) -> str:
258
+ """格式化历史消息用于生成摘要"""
259
+ lines = []
260
+ for msg in history:
261
+ role = "unknown"
262
+ content = ""
263
+ if "userInputMessage" in msg:
264
+ role = "user"
265
+ content = msg.get("userInputMessage", {}).get("content", "")
266
+ elif "assistantResponseMessage" in msg:
267
+ role = "assistant"
268
+ content = msg.get("assistantResponseMessage", {}).get("content", "")
269
+ else:
270
+ role = msg.get("role", "unknown")
271
+ content = self._extract_text(msg.get("content", ""))
272
+ # 截断过长的单条消息
273
+ if len(content) > 800:
274
+ content = content[:800] + "..."
275
+ lines.append(f"[{role}]: {content}")
276
+ return "\n".join(lines)
277
+
278
+ def _calculate_keep_count(self, history: List[dict], target_chars: int) -> int:
279
+ """计算应该保留多少条最近消息"""
280
+ if not history:
281
+ return 0
282
+
283
+ # 从后往前累计,找到合适的保留数量
284
+ total = 0
285
+ count = 0
286
+ for msg in reversed(history):
287
+ msg_chars = len(json.dumps(msg, ensure_ascii=False))
288
+ if total + msg_chars > target_chars and count >= MIN_KEEP_MESSAGES:
289
+ break
290
+ total += msg_chars
291
+ count += 1
292
+ if count >= MAX_KEEP_MESSAGES:
293
+ break
294
+
295
+ return max(MIN_KEEP_MESSAGES, min(count, len(history) - 1))
296
+
297
+ def _build_compressed_history(
298
+ self,
299
+ summary: str,
300
+ recent_history: List[dict],
301
+ label: str = ""
302
+ ) -> List[dict]:
303
+ """构建压缩后的历史(摘要 + 最近消息)"""
304
+ # 确保 recent_history 以 user 消息开头
305
+ if recent_history and "assistantResponseMessage" in recent_history[0]:
306
+ recent_history = recent_history[1:]
307
+
308
+ # 清理孤立的 toolResults
309
+ tool_use_ids = set()
310
+ for msg in recent_history:
311
+ if "assistantResponseMessage" in msg:
312
+ for tu in msg["assistantResponseMessage"].get("toolUses", []) or []:
313
+ if tu.get("toolUseId"):
314
+ tool_use_ids.add(tu["toolUseId"])
315
+
316
+ # 清理第一条 user 消息的 toolResults(因为前面没有对应的 toolUse)
317
+ if recent_history and "userInputMessage" in recent_history[0]:
318
+ recent_history[0]["userInputMessage"].pop("userInputMessageContext", None)
319
+
320
+ # 过滤其他消息中孤立的 toolResults
321
+ if tool_use_ids:
322
+ for msg in recent_history:
323
+ if "userInputMessage" in msg:
324
+ ctx = msg.get("userInputMessage", {}).get("userInputMessageContext", {})
325
+ results = ctx.get("toolResults")
326
+ if results:
327
+ filtered = [r for r in results if r.get("toolUseId") in tool_use_ids]
328
+ if filtered:
329
+ ctx["toolResults"] = filtered
330
+ else:
331
+ ctx.pop("toolResults", None)
332
+ if not ctx:
333
+ msg["userInputMessage"].pop("userInputMessageContext", None)
334
+ else:
335
+ for msg in recent_history:
336
+ if "userInputMessage" in msg:
337
+ msg["userInputMessage"].pop("userInputMessageContext", None)
338
+
339
+
340
+ # 获取 model_id
341
+ model_id = "claude-sonnet-4"
342
+ for msg in reversed(recent_history):
343
+ if "userInputMessage" in msg:
344
+ model_id = msg["userInputMessage"].get("modelId", model_id)
345
+ break
346
+ if "assistantResponseMessage" in msg:
347
+ model_id = msg["assistantResponseMessage"].get("modelId", model_id)
348
+ break
349
+
350
+ # 检测消息格式
351
+ is_kiro_format = any("userInputMessage" in h or "assistantResponseMessage" in h for h in recent_history)
352
+
353
+ if is_kiro_format:
354
+ result = [
355
+ {
356
+ "userInputMessage": {
357
+ "content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]",
358
+ "modelId": model_id,
359
+ "origin": "AI_EDITOR",
360
+ }
361
+ },
362
+ {
363
+ "assistantResponseMessage": {
364
+ "content": "I understand the context from the summary. Let's continue."
365
+ }
366
+ }
367
+ ]
368
+ else:
369
+ result = [
370
+ {"role": "user", "content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]"},
371
+ {"role": "assistant", "content": "I understand the context from the summary. Let's continue."}
372
+ ]
373
+
374
+ result.extend(recent_history)
375
+
376
+ if label:
377
+ print(f"[HistoryManager] {label}: {len(recent_history)} recent + summary")
378
+
379
+ return result
380
+
381
+ async def _generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]:
382
+ """生成历史消息摘要"""
383
+ if not history or not api_caller:
384
+ return None
385
+
386
+ formatted = self._format_for_summary(history)
387
+ if len(formatted) > 15000:
388
+ formatted = formatted[:15000] + "\n...(truncated)"
389
+
390
+ prompt = f"""请简洁总结以下对话的关键信息:
391
+ 1. 用户的主要目标
392
+ 2. 已完成的重要操作和决策
393
+ 3. 当前工作状态和关键上下文
394
+
395
+ 对话历史:
396
+ {formatted}
397
+
398
+ 请用中文输出摘要,控制在 {SUMMARY_MAX_LENGTH} 字符以内,重点保留对后续对话有用的信息:"""
399
+
400
+ try:
401
+ summary = await api_caller(prompt)
402
+ if summary and len(summary) > SUMMARY_MAX_LENGTH:
403
+ summary = summary[:SUMMARY_MAX_LENGTH] + "..."
404
+ return summary
405
+ except Exception as e:
406
+ print(f"[HistoryManager] 生成摘要失败: {e}")
407
+ return None
408
+
409
+
410
+ async def smart_compress(
411
+ self,
412
+ history: List[dict],
413
+ api_caller: Callable,
414
+ target_chars: int = SAFE_CHAR_LIMIT,
415
+ retry_level: int = 0
416
+ ) -> List[dict]:
417
+ """智能压缩历史消息
418
+
419
+ 核心逻辑:保留最近消息 + 摘要早期对话
420
+ 压缩目标为 20K-50K 字符范围
421
+
422
+ Args:
423
+ history: 历史消息
424
+ api_caller: 用于生成摘要的 API 调用函数
425
+ target_chars: 目标字符数 (默认 35K,范围 20K-50K)
426
+ retry_level: 重试级别(越高保留越少)
427
+ """
428
+ if not history:
429
+ return history
430
+
431
+ current_chars = len(json.dumps(history, ensure_ascii=False))
432
+
433
+ # 确保目标在 20K-50K 范围内
434
+ target_chars = max(SAFE_CHAR_LIMIT_MIN, min(target_chars, SAFE_CHAR_LIMIT_MAX))
435
+
436
+ # 如果已经在目标范围内,不需要压缩
437
+ if current_chars <= target_chars:
438
+ return history
439
+
440
+ # 根据重试级别调整保留数量
441
+ adjusted_target = int(target_chars * (0.85 ** retry_level))
442
+ adjusted_target = max(SAFE_CHAR_LIMIT_MIN, adjusted_target) # 确保不低于下限
443
+
444
+ keep_count = self._calculate_keep_count(history, adjusted_target)
445
+
446
+ # 确保至少保留一些消息用于摘要
447
+ if keep_count >= len(history):
448
+ keep_count = max(MIN_KEEP_MESSAGES, len(history) - 2)
449
+
450
+ old_history = history[:-keep_count] if keep_count < len(history) else []
451
+ recent_history = history[-keep_count:] if keep_count > 0 else history
452
+
453
+ if not old_history:
454
+ # 没有可摘要的历史,直接返回
455
+ return recent_history
456
+
457
+ # 尝试从缓存获取摘要
458
+ cache_key = f"{self.cache_key}:{keep_count}" if self.cache_key else None
459
+ old_hash = self._hash_history(old_history)
460
+
461
+ cached_summary = None
462
+ if cache_key and self.config.summary_cache_enabled:
463
+ cached_summary = _summary_cache.get(cache_key, old_hash, self.config.summary_cache_max_age_seconds)
464
+
465
+ if cached_summary:
466
+ result = self._build_compressed_history(cached_summary, recent_history, "压缩(缓存)")
467
+ result_chars = len(json.dumps(result, ensure_ascii=False))
468
+ self._truncated = True
469
+ self._truncate_info = f"智能压缩(缓存): {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符"
470
+ return result
471
+
472
+ # 生成新摘要
473
+ summary = await self._generate_summary(old_history, api_caller)
474
+
475
+ if summary:
476
+ if cache_key and self.config.summary_cache_enabled:
477
+ _summary_cache.set(cache_key, summary, old_hash)
478
+
479
+ result = self._build_compressed_history(summary, recent_history, "智能压缩")
480
+ result_chars = len(json.dumps(result, ensure_ascii=False))
481
+ self._truncated = True
482
+ self._truncate_info = f"智能压缩: {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符 (摘要 {len(summary)} 字符)"
483
+ return result
484
+
485
+ # 摘要失败,回退到简单截断
486
+ self._truncated = True
487
+ result_chars = len(json.dumps(recent_history, ensure_ascii=False))
488
+ self._truncate_info = f"摘要失败,保留最近 {len(recent_history)} 条消息, {current_chars} -> {result_chars} 字符"
489
+ return recent_history
490
+
491
+
492
+ def needs_compression(self, history: List[dict], user_content: str = "") -> bool:
493
+ """检查是否需要压缩
494
+
495
+ 注意:此方法现在始终返回 False,不再基于阈值预检测。
496
+ 压缩仅在收到上下文超限错误后触发。
497
+ 保留此方法是为了兼容旧 API。
498
+ """
499
+ # 不再基于阈��预检测,始终返回 False
500
+ # 压缩将在收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后触发
501
+ return False
502
+
503
+ async def pre_process_async(
504
+ self,
505
+ history: List[dict],
506
+ user_content: str = "",
507
+ api_caller: Callable = None
508
+ ) -> List[dict]:
509
+ """预处理历史消息
510
+
511
+ 注意:不再进行发送前自动压缩。
512
+ 压缩仅在收到上下文超限错误后触发。
513
+ """
514
+ self.reset()
515
+
516
+ if not history:
517
+ return history
518
+
519
+ # 不再进行预压缩,直接返回原始历史
520
+ # 压缩将在收到错误后由 handle_length_error_async 处理
521
+ return history
522
+
523
+ def pre_process(self, history: List[dict], user_content: str = "") -> List[dict]:
524
+ """预处理历史消息(同步版本)
525
+
526
+ 注意:不再进行发送前自动压缩。
527
+ 压缩仅在收到上下文超限错误后触发。
528
+ """
529
+ self.reset()
530
+
531
+ if not history:
532
+ return history
533
+
534
+ # 不再进行预压缩,直接返回原始历史
535
+ return history
536
+
537
+ async def handle_length_error_async(
538
+ self,
539
+ history: List[dict],
540
+ retry_count: int = 0,
541
+ api_caller: Optional[Callable] = None
542
+ ) -> Tuple[List[dict], bool]:
543
+ """处理长度超限错误(智能压缩后重试)
544
+
545
+ 这是唯一触发压缩的入口点。当收到上下文超限错误时调用此方法。
546
+ 压缩目标为 20K-50K 字符范围。
547
+
548
+ 防止无限循环:
549
+ - 追踪压缩状态,避免重复压缩相同内容
550
+ - 压缩前检查大小,如果已经很小则不再压缩
551
+ - 达到最大重试次数后返回清晰错误
552
+
553
+ Args:
554
+ history: 历史消息
555
+ retry_count: 当前重试次数
556
+ api_caller: API 调用函数
557
+
558
+ Returns:
559
+ (compressed_history, should_retry)
560
+ """
561
+ max_retries = self.config.max_retries
562
+
563
+ if retry_count >= max_retries:
564
+ print(f"[HistoryManager] 已达最大重试次数 ({max_retries}),建议清空对话")
565
+ self._truncate_info = f"已达最大压缩次数 ({max_retries}),请清空对话或减少消息数量"
566
+ return history, False
567
+
568
+ if not history:
569
+ return history, False
570
+
571
+ self.reset()
572
+
573
+ current_chars = len(json.dumps(history, ensure_ascii=False))
574
+ current_hash = self._hash_history(history)
575
+
576
+ print(f"[HistoryManager] 收到上下文超限错误,当前大小: {current_chars} 字符")
577
+
578
+ # 优先检查全局压缩缓存(解决 Claude Code CLI 反复压缩问题)
579
+ cached_result = _compression_cache.get(current_hash)
580
+ if cached_result is not None:
581
+ cached_chars = len(json.dumps(cached_result, ensure_ascii=False))
582
+ self._truncated = True
583
+ self._truncate_info = f"使用缓存的压缩结果: {len(history)} -> {len(cached_result)} 条消息, {current_chars} -> {cached_chars} 字符"
584
+ print(f"[HistoryManager] {self._truncate_info}")
585
+ return cached_result, True
586
+
587
+ print(f"[HistoryManager] 开始压缩...")
588
+
589
+ # 防止无限循环:检查是否已经压缩过相同内容(实例级缓存)
590
+ instance_cache_key = f"compression:{current_hash}:{retry_count}"
591
+ if hasattr(self, '_instance_compression_cache') and instance_cache_key in self._instance_compression_cache:
592
+ print(f"[HistoryManager] 检测到重复压缩请求,跳过")
593
+ self._truncate_info = "内容已压缩到最小,无法继续压缩,请清空对话"
594
+ return history, False
595
+
596
+ # 初始化实例级压缩缓存
597
+ if not hasattr(self, '_instance_compression_cache'):
598
+ self._instance_compression_cache = {}
599
+
600
+ # 根据重试次数计算目标大小 (20K-50K 范围)
601
+ # 第一次重试: 目标 35K (中间值)
602
+ # 第二次重试: 目标 25K
603
+ # 第三次重试: 目标 20K (下限)
604
+ if retry_count == 0:
605
+ target_chars = SAFE_CHAR_LIMIT # 35K
606
+ elif retry_count == 1:
607
+ target_chars = 25000
608
+ else:
609
+ target_chars = SAFE_CHAR_LIMIT_MIN # 20K
610
+
611
+ # 防止无限循环:如果当前大小已经小于目标,不再压缩
612
+ if current_chars <= target_chars:
613
+ print(f"[HistoryManager] 当前大小 ({current_chars}) 已小于目标 ({target_chars}),无法继续压缩")
614
+ self._truncate_info = f"内容已压缩到 {current_chars} 字符,仍然超限,请清空对话"
615
+ return history, False
616
+
617
+ print(f"[HistoryManager] 第 {retry_count + 1} 次重试,目标压缩到 {target_chars} 字符")
618
+
619
+ if api_caller:
620
+ compressed = await self.smart_compress(
621
+ history, api_caller,
622
+ target_chars=target_chars,
623
+ retry_level=retry_count
624
+ )
625
+ compressed_chars = len(json.dumps(compressed, ensure_ascii=False))
626
+
627
+ # 防止无限循环:检查压缩是否有效
628
+ if compressed_chars >= current_chars * 0.95: # 压缩效果不足 5%
629
+ print(f"[HistoryManager] 压缩效果不足,无法继续压缩")
630
+ self._truncate_info = f"压缩效果不足,请清空对话或减少消息数量"
631
+ return history, False
632
+
633
+ # 防止无限循环:检查压缩后是否仍然过大
634
+ if compressed_chars > 50000 and retry_count >= max_retries - 1:
635
+ print(f"[HistoryManager] 压缩后仍然过大 ({compressed_chars}),建议清空对话")
636
+ self._truncate_info = f"压缩后仍有 {compressed_chars} 字符,请清空对话"
637
+ return compressed, False
638
+
639
+ if len(compressed) < len(history):
640
+ # 保存到全局压缩缓存(解决 Claude Code CLI 反复压缩问题)
641
+ _compression_cache.set(current_hash, compressed, compressed_chars)
642
+
643
+ # 记录实例级压缩缓存(防止同一请求内的重复压缩)
644
+ self._instance_compression_cache[instance_cache_key] = True
645
+ # 清理旧缓存(保留最近 10 条)
646
+ if len(self._instance_compression_cache) > 10:
647
+ oldest_key = next(iter(self._instance_compression_cache))
648
+ del self._instance_compression_cache[oldest_key]
649
+
650
+ self._truncated = True
651
+ self._truncate_info = f"错误后压缩 (第 {retry_count + 1} 次): {len(history)} -> {len(compressed)} 条消息, {current_chars} -> {compressed_chars} 字符"
652
+ print(f"[HistoryManager] {self._truncate_info}")
653
+ return compressed, True
654
+ else:
655
+ # 无 api_caller,简单截断
656
+ keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * (0.5 ** (retry_count + 1))))
657
+ if keep_count < len(history):
658
+ truncated = history[-keep_count:]
659
+ self._truncated = True
660
+ truncated_chars = len(json.dumps(truncated, ensure_ascii=False))
661
+
662
+ # 防止无限循环:检查截断是否有效
663
+ if truncated_chars >= current_chars * 0.95:
664
+ print(f"[HistoryManager] 截断效果不足,无法继续压缩")
665
+ self._truncate_info = f"截断效果不足,请清空对话"
666
+ return history, False
667
+
668
+ self._truncate_info = f"错误后截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息, {current_chars} -> {truncated_chars} 字符"
669
+ print(f"[HistoryManager] {self._truncate_info}")
670
+ return truncated, True
671
+
672
+ return history, False
673
+
674
+
675
+ def handle_length_error(self, history: List[dict], retry_count: int = 0) -> Tuple[List[dict], bool]:
676
+ """处理长度超限错误(同步版本,简单截断)"""
677
+ max_retries = self.config.max_retries
678
+
679
+ if retry_count >= max_retries:
680
+ return history, False
681
+
682
+ if not history:
683
+ return history, False
684
+
685
+ self.reset()
686
+
687
+ # 根据重试次数逐步减少
688
+ keep_ratio = 0.5 ** (retry_count + 1)
689
+ keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * keep_ratio))
690
+
691
+ if keep_count < len(history):
692
+ truncated = history[-keep_count:]
693
+ self._truncated = True
694
+ self._truncate_info = f"错误重试截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息"
695
+ return truncated, True
696
+
697
+ return history, False
698
+
699
+ def get_warning_header(self) -> Optional[str]:
700
+ if not self.config.add_warning_header or not self._truncated:
701
+ return None
702
+ return self._truncate_info
703
+
704
+ # ========== 兼容旧 API ==========
705
+
706
+ def truncate_by_count(self, history: List[dict], max_count: int) -> List[dict]:
707
+ """按消息数量截断(兼容)"""
708
+ if len(history) <= max_count:
709
+ return history
710
+ original_count = len(history)
711
+ truncated = history[-max_count:]
712
+ self._truncated = True
713
+ self._truncate_info = f"按数量截断: {original_count} -> {len(truncated)} 条消息"
714
+ return truncated
715
+
716
+ def truncate_by_chars(self, history: List[dict], max_chars: int) -> List[dict]:
717
+ """按字符数截断(兼容)"""
718
+ total_chars = len(json.dumps(history, ensure_ascii=False))
719
+ if total_chars <= max_chars:
720
+ return history
721
+
722
+ original_count = len(history)
723
+ result = []
724
+ current_chars = 0
725
+
726
+ for msg in reversed(history):
727
+ msg_chars = len(json.dumps(msg, ensure_ascii=False))
728
+ if current_chars + msg_chars > max_chars and result:
729
+ break
730
+ result.insert(0, msg)
731
+ current_chars += msg_chars
732
+
733
+ if len(result) < original_count:
734
+ self._truncated = True
735
+ self._truncate_info = f"按字符数截断: {original_count} -> {len(result)} 条消息"
736
+
737
+ return result
738
+
739
+ def should_pre_truncate(self, history: List[dict], user_content: str) -> bool:
740
+ """兼容旧 API"""
741
+ return self.needs_compression(history, user_content)
742
+
743
+ def should_summarize(self, history: List[dict]) -> bool:
744
+ """兼容旧 API"""
745
+ return self.needs_compression(history)
746
+
747
+ def should_smart_summarize(self, history: List[dict]) -> bool:
748
+ """兼容旧 API"""
749
+ return self.needs_compression(history)
750
+
751
+ def should_auto_truncate_summarize(self, history: List[dict]) -> bool:
752
+ """兼容旧 API"""
753
+ return self.needs_compression(history)
754
+
755
+ def should_pre_summary_for_error_retry(self, history: List[dict], user_content: str = "") -> bool:
756
+ """兼容旧 API"""
757
+ return self.needs_compression(history, user_content)
758
+
759
+ async def compress_with_summary(self, history: List[dict], api_caller: Callable) -> List[dict]:
760
+ """兼容旧 API"""
761
+ return await self.smart_compress(history, api_caller)
762
+
763
+ async def compress_before_auto_truncate(self, history: List[dict], api_caller: Callable) -> List[dict]:
764
+ """兼容旧 API"""
765
+ return await self.smart_compress(history, api_caller)
766
+
767
+ async def generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]:
768
+ """兼容旧 API"""
769
+ return await self._generate_summary(history, api_caller)
770
+
771
+ def summarize_history_structure(self, history: List[dict], max_items: int = 12) -> str:
772
+ """生成历史结构摘要(调试用)"""
773
+ if not history:
774
+ return "len=0"
775
+
776
+ def entry_kind(msg):
777
+ if "userInputMessage" in msg:
778
+ return "U"
779
+ if "assistantResponseMessage" in msg:
780
+ return "A"
781
+ role = msg.get("role")
782
+ return "U" if role == "user" else ("A" if role == "assistant" else "?")
783
+
784
+ kinds = [entry_kind(msg) for msg in history]
785
+ if len(kinds) <= max_items:
786
+ seq = "".join(kinds)
787
+ else:
788
+ head = max_items // 2
789
+ tail = max_items - head
790
+ seq = f"{''.join(kinds[:head])}...{''.join(kinds[-tail:])}"
791
+
792
+ return f"len={len(history)} seq={seq}"
793
+
794
+
795
+
796
+ # ========== 全局配置 ==========
797
+
798
+ _history_config = HistoryConfig()
799
+
800
+
801
+ def get_history_config() -> HistoryConfig:
802
+ """获取历史消息配置"""
803
+ return _history_config
804
+
805
+
806
+ def set_history_config(config: HistoryConfig):
807
+ """设置历史消息配置"""
808
+ global _history_config
809
+ _history_config = config
810
+
811
+
812
+ def update_history_config(data: dict):
813
+ """更新历史消息配置"""
814
+ global _history_config
815
+ _history_config = HistoryConfig.from_dict(data)
816
+
817
+
818
+ def is_content_length_error(status_code: int, error_text: str) -> bool:
819
+ """检查是否为内容长度超限错误"""
820
+ if "CONTENT_LENGTH_EXCEEDS_THRESHOLD" in error_text:
821
+ return True
822
+ if "Input is too long" in error_text:
823
+ return True
824
+ lowered = error_text.lower()
825
+ if "too long" in lowered and ("input" in lowered or "content" in lowered or "message" in lowered):
826
+ return True
827
+ if "context length" in lowered or "token limit" in lowered:
828
+ return True
829
+ return False
KiroProxy/kiro_proxy/core/kiro_api.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kiro Web Portal API 调用模块
2
+
3
+ 调用 Kiro 的 Web Portal API 获取用户信息,使用 CBOR 编码。
4
+ 参考: chaogei/Kiro-account-manager
5
+ """
6
+ import uuid
7
+ import httpx
8
+ from typing import Optional, Tuple, Any, Dict
9
+
10
+ try:
11
+ import cbor2
12
+ HAS_CBOR = True
13
+ except ImportError:
14
+ HAS_CBOR = False
15
+ print("[KiroAPI] 警告: cbor2 未安装,部分功能不可用。请运行: pip install cbor2")
16
+
17
+
18
+ # Kiro Web Portal API 基础 URL
19
+ KIRO_API_BASE = "https://app.kiro.dev/service/KiroWebPortalService/operation"
20
+
21
+
22
+ async def kiro_api_request(
23
+ operation: str,
24
+ body: Dict[str, Any],
25
+ access_token: str,
26
+ idp: str = "Google",
27
+ ) -> Tuple[bool, Any]:
28
+ """
29
+ 调用 Kiro Web Portal API
30
+
31
+ Args:
32
+ operation: API 操作名称,如 "GetUserUsageAndLimits"
33
+ body: 请求体(会被 CBOR 编码)
34
+ access_token: Bearer token
35
+ idp: 身份提供商 ("Google" 或 "Github")
36
+
37
+ Returns:
38
+ (success, response_data or error_dict)
39
+ """
40
+ if not HAS_CBOR:
41
+ return False, {"error": "cbor2 未安装"}
42
+
43
+ if not access_token:
44
+ return False, {"error": "缺少 access token"}
45
+
46
+ url = f"{KIRO_API_BASE}/{operation}"
47
+
48
+ # CBOR 编码请求体
49
+ try:
50
+ encoded_body = cbor2.dumps(body)
51
+ except Exception as e:
52
+ return False, {"error": f"CBOR 编码失败: {e}"}
53
+
54
+ headers = {
55
+ "accept": "application/cbor",
56
+ "content-type": "application/cbor",
57
+ "smithy-protocol": "rpc-v2-cbor",
58
+ "amz-sdk-invocation-id": str(uuid.uuid4()),
59
+ "amz-sdk-request": "attempt=1; max=1",
60
+ "x-amz-user-agent": "aws-sdk-js/1.0.0 kiro-proxy/1.0.0",
61
+ "authorization": f"Bearer {access_token}",
62
+ "cookie": f"Idp={idp}; AccessToken={access_token}",
63
+ }
64
+
65
+ try:
66
+ async with httpx.AsyncClient(timeout=15, verify=False) as client:
67
+ response = await client.post(url, content=encoded_body, headers=headers)
68
+
69
+ if response.status_code != 200:
70
+ return False, {"error": f"API 请求失败: {response.status_code}"}
71
+
72
+ # CBOR 解码响应
73
+ try:
74
+ data = cbor2.loads(response.content)
75
+ return True, data
76
+ except Exception as e:
77
+ return False, {"error": f"CBOR 解码失败: {e}"}
78
+
79
+ except httpx.TimeoutException:
80
+ return False, {"error": "请求超时"}
81
+ except Exception as e:
82
+ return False, {"error": f"请求失败: {str(e)}"}
83
+
84
+
85
+ async def get_user_info(
86
+ access_token: str,
87
+ idp: str = "Google",
88
+ ) -> Tuple[bool, Dict[str, Any]]:
89
+ """
90
+ 获取用户信息(包括邮箱)
91
+
92
+ Args:
93
+ access_token: Bearer token
94
+ idp: 身份提供商 ("Google" 或 "Github")
95
+
96
+ Returns:
97
+ (success, user_info or error_dict)
98
+ user_info 包含: email, userId 等
99
+ """
100
+ success, result = await kiro_api_request(
101
+ operation="GetUserUsageAndLimits",
102
+ body={"isEmailRequired": True, "origin": "KIRO_IDE"},
103
+ access_token=access_token,
104
+ idp=idp,
105
+ )
106
+
107
+ if not success:
108
+ return False, result
109
+
110
+ # 提取用户信息
111
+ user_info = result.get("userInfo", {})
112
+ return True, {
113
+ "email": user_info.get("email"),
114
+ "userId": user_info.get("userId"),
115
+ "raw": result,
116
+ }
117
+
118
+
119
+ async def get_user_email(
120
+ access_token: str,
121
+ provider: str = "Google",
122
+ ) -> Optional[str]:
123
+ """
124
+ 获取用户邮箱地址
125
+
126
+ Args:
127
+ access_token: Bearer token
128
+ provider: 登录提供商 ("Google" 或 "Github")
129
+
130
+ Returns:
131
+ 邮箱地址,失败返回 None
132
+ """
133
+ # 标准化 provider 名称
134
+ idp = provider
135
+ if provider and provider.lower() == "google":
136
+ idp = "Google"
137
+ elif provider and provider.lower() == "github":
138
+ idp = "Github"
139
+
140
+ success, result = await get_user_info(access_token, idp)
141
+
142
+ if success:
143
+ return result.get("email")
144
+
145
+ print(f"[KiroAPI] 获取邮箱失败: {result.get('error', '未知错误')}")
146
+ return None
KiroProxy/kiro_proxy/core/persistence.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """配置持久化"""
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List, Dict, Any
5
+
6
+ # 统一使用 config.py 中的 DATA_DIR
7
+ from ..config import DATA_DIR
8
+
9
+ # 配置文件路径
10
+ CONFIG_DIR = DATA_DIR
11
+ CONFIG_FILE = CONFIG_DIR / "config.json"
12
+
13
+
14
+ def ensure_config_dir():
15
+ """确保配置目录存在"""
16
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
17
+
18
+
19
+ def save_accounts(accounts: List[Dict[str, Any]]) -> bool:
20
+ """保存账号配置"""
21
+ try:
22
+ ensure_config_dir()
23
+ config = load_config()
24
+ config["accounts"] = accounts
25
+ with open(CONFIG_FILE, "w", encoding="utf-8") as f:
26
+ json.dump(config, f, indent=2, ensure_ascii=False)
27
+ return True
28
+ except Exception as e:
29
+ print(f"[Persistence] 保存配置失败: {e}")
30
+ return False
31
+
32
+
33
+ def load_accounts() -> List[Dict[str, Any]]:
34
+ """加载账号配置"""
35
+ config = load_config()
36
+ return config.get("accounts", [])
37
+
38
+
39
+ def load_config() -> Dict[str, Any]:
40
+ """加载完整配置"""
41
+ try:
42
+ if CONFIG_FILE.exists():
43
+ with open(CONFIG_FILE, "r", encoding="utf-8") as f:
44
+ return json.load(f)
45
+ except Exception as e:
46
+ print(f"[Persistence] 加载配置失败: {e}")
47
+ return {}
48
+
49
+
50
+ def save_config(config: Dict[str, Any]) -> bool:
51
+ """保存完整配置"""
52
+ try:
53
+ ensure_config_dir()
54
+ with open(CONFIG_FILE, "w", encoding="utf-8") as f:
55
+ json.dump(config, f, indent=2, ensure_ascii=False)
56
+ return True
57
+ except Exception as e:
58
+ print(f"[Persistence] 保存配置失败: {e}")
59
+ return False
60
+
61
+
62
+ def export_config() -> Dict[str, Any]:
63
+ """导出配置(用于备份)"""
64
+ return load_config()
65
+
66
+
67
+ def import_config(config: Dict[str, Any]) -> bool:
68
+ """导入配置(用于恢复)"""
69
+ return save_config(config)
KiroProxy/kiro_proxy/core/protocol_handler.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """自定义协议处理器
2
+
3
+ 在 Windows 上注册 kiro:// 协议,用于处理 OAuth 回调。
4
+ """
5
+ import sys
6
+ import os
7
+ import asyncio
8
+ import threading
9
+ from pathlib import Path
10
+ from typing import Optional, Callable
11
+ from http.server import HTTPServer, BaseHTTPRequestHandler
12
+ from urllib.parse import urlparse, parse_qs, urlencode
13
+ import socket
14
+
15
+
16
+ # 回调服务器端口
17
+ CALLBACK_PORT = 19823
18
+
19
+ # 全局回调结果
20
+ _callback_result = None
21
+ _callback_event = None
22
+ _callback_server = None
23
+ _server_thread = None
24
+
25
+
26
+ class CallbackHandler(BaseHTTPRequestHandler):
27
+ """处理 OAuth 回调的 HTTP 请求处理器"""
28
+
29
+ def log_message(self, format, *args):
30
+ """禁用日志输出"""
31
+ pass
32
+
33
+ def do_GET(self):
34
+ global _callback_result, _callback_event
35
+
36
+ # 解析 URL
37
+ parsed = urlparse(self.path)
38
+ params = parse_qs(parsed.query)
39
+
40
+ # 检查是否是回调路径
41
+ if parsed.path == '/kiro-callback' or parsed.path == '/' or 'code' in params:
42
+ code = params.get('code', [None])[0]
43
+ state = params.get('state', [None])[0]
44
+ error = params.get('error', [None])[0]
45
+
46
+ print(f"[ProtocolHandler] 收到回调: code={code[:20] if code else None}..., state={state}, error={error}")
47
+
48
+ if error:
49
+ _callback_result = {"error": error}
50
+ elif code and state:
51
+ _callback_result = {"code": code, "state": state}
52
+ else:
53
+ _callback_result = {"error": "缺少授权码"}
54
+
55
+ # 触发事件
56
+ if _callback_event:
57
+ _callback_event.set()
58
+
59
+ # 返回成功页面
60
+ self.send_response(200)
61
+ self.send_header('Content-type', 'text/html; charset=utf-8')
62
+ self.end_headers()
63
+
64
+ html = """
65
+ <!DOCTYPE html>
66
+ <html>
67
+ <head>
68
+ <meta charset="utf-8">
69
+ <title>登录成功</title>
70
+ <style>
71
+ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
72
+ display: flex; justify-content: center; align-items: center; height: 100vh;
73
+ margin: 0; background: #1a1a2e; color: #fff; }
74
+ .container { text-align: center; padding: 2rem; }
75
+ h1 { color: #4ade80; margin-bottom: 1rem; }
76
+ p { color: #9ca3af; }
77
+ </style>
78
+ </head>
79
+ <body>
80
+ <div class="container">
81
+ <h1>✅ 登录成功</h1>
82
+ <p>您可以关闭此窗口并返回 Kiro Proxy</p>
83
+ <script>setTimeout(function(){window.close();}, 3000);</script>
84
+ </div>
85
+ </body>
86
+ </html>
87
+ """
88
+ self.wfile.write(html.encode('utf-8'))
89
+ else:
90
+ self.send_response(404)
91
+ self.end_headers()
92
+
93
+
94
+ def is_port_available(port: int) -> bool:
95
+ """检查端口是否可用"""
96
+ try:
97
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
98
+ s.bind(('127.0.0.1', port))
99
+ return True
100
+ except OSError:
101
+ return False
102
+
103
+
104
+ def start_callback_server() -> tuple:
105
+ """启动回调服务器
106
+
107
+ Returns:
108
+ (success, port or error)
109
+ """
110
+ global _callback_server, _callback_result, _callback_event, _server_thread
111
+
112
+ # 如果服务器已经在运行,直接返回成功
113
+ if _callback_server is not None and _server_thread is not None and _server_thread.is_alive():
114
+ print(f"[ProtocolHandler] 回调服务器已在运行: http://127.0.0.1:{CALLBACK_PORT}")
115
+ return True, CALLBACK_PORT
116
+
117
+ _callback_result = None
118
+ _callback_event = threading.Event()
119
+
120
+ # 检查端口
121
+ if not is_port_available(CALLBACK_PORT):
122
+ # 端口被占用,可能是之前的服务器还在运行
123
+ print(f"[ProtocolHandler] 端口 {CALLBACK_PORT} 已被占用,尝试复用")
124
+ return True, CALLBACK_PORT
125
+
126
+ try:
127
+ _callback_server = HTTPServer(('127.0.0.1', CALLBACK_PORT), CallbackHandler)
128
+
129
+ # 在后台线程运行服务器
130
+ _server_thread = threading.Thread(target=_callback_server.serve_forever, daemon=True)
131
+ _server_thread.start()
132
+
133
+ print(f"[ProtocolHandler] 回调服务器已启动: http://127.0.0.1:{CALLBACK_PORT}")
134
+ return True, CALLBACK_PORT
135
+ except Exception as e:
136
+ return False, str(e)
137
+
138
+
139
+ def stop_callback_server():
140
+ """停止回调服务器"""
141
+ global _callback_server, _server_thread
142
+
143
+ if _callback_server:
144
+ try:
145
+ _callback_server.shutdown()
146
+ except:
147
+ pass
148
+ _callback_server = None
149
+ _server_thread = None
150
+ print("[ProtocolHandler] 回调服务���已停止")
151
+
152
+
153
+ def wait_for_callback(timeout: int = 300) -> tuple:
154
+ """等待回调
155
+
156
+ Args:
157
+ timeout: 超时时间(秒)
158
+
159
+ Returns:
160
+ (success, result or error)
161
+ """
162
+ global _callback_result, _callback_event
163
+
164
+ if _callback_event is None:
165
+ return False, {"error": "回调服务器未启动"}
166
+
167
+ # 等待回调
168
+ if _callback_event.wait(timeout=timeout):
169
+ if _callback_result and "code" in _callback_result:
170
+ return True, _callback_result
171
+ elif _callback_result and "error" in _callback_result:
172
+ return False, _callback_result
173
+ else:
174
+ return False, {"error": "未收到有效回调"}
175
+ else:
176
+ return False, {"error": "等待回调超时"}
177
+
178
+
179
+ def get_callback_result() -> Optional[dict]:
180
+ """获取回调结果(非阻塞)"""
181
+ global _callback_result
182
+ return _callback_result
183
+
184
+
185
+ def clear_callback_result():
186
+ """清除回调结果"""
187
+ global _callback_result, _callback_event
188
+ _callback_result = None
189
+ if _callback_event:
190
+ _callback_event.clear()
191
+
192
+
193
+ # Windows 协议注册
194
+ def register_protocol_windows() -> tuple:
195
+ """在 Windows 上注册 kiro:// 协议
196
+
197
+ 注册后,当浏览器重定向到 kiro:// URL 时,Windows 会调用我们的脚本,
198
+ 脚本将参数重定向到本地 HTTP 服务器。
199
+
200
+ Returns:
201
+ (success, message)
202
+ """
203
+ if sys.platform != 'win32':
204
+ return False, "仅支持 Windows"
205
+
206
+ try:
207
+ import winreg
208
+
209
+ # 获取当前 Python 解释器路径
210
+ python_exe = sys.executable
211
+
212
+ # 创建一个处理脚本
213
+ script_dir = Path.home() / ".kiro-proxy"
214
+ script_dir.mkdir(parents=True, exist_ok=True)
215
+ script_path = script_dir / "protocol_redirect.pyw"
216
+
217
+ # 写入重定向脚本 (.pyw 不显示控制台窗口)
218
+ script_content = f'''# -*- coding: utf-8 -*-
219
+ # Kiro Protocol Redirect Script
220
+ import sys
221
+ import webbrowser
222
+ from urllib.parse import urlparse, parse_qs, urlencode
223
+
224
+ if len(sys.argv) > 1:
225
+ url = sys.argv[1]
226
+
227
+ # 解析 kiro:// URL
228
+ # 格式: kiro://kiro.kiroAgent/authenticate-success?code=xxx&state=xxx
229
+ if url.startswith('kiro://'):
230
+ # 提取查询参数
231
+ query_start = url.find('?')
232
+ if query_start > -1:
233
+ query_string = url[query_start + 1:]
234
+ # 重定向到本地 HTTP 服务器
235
+ redirect_url = "http://127.0.0.1:{CALLBACK_PORT}/kiro-callback?" + query_string
236
+ webbrowser.open(redirect_url)
237
+ '''
238
+ script_path.write_text(script_content, encoding='utf-8')
239
+
240
+ # 获取 pythonw.exe 路径(无控制台窗口)
241
+ python_dir = Path(python_exe).parent
242
+ pythonw_exe = python_dir / "pythonw.exe"
243
+ if not pythonw_exe.exists():
244
+ pythonw_exe = python_exe # 降级使用 python.exe
245
+
246
+ # 注册协议
247
+ key_path = r"SOFTWARE\\Classes\\kiro"
248
+
249
+ # 创建主键
250
+ key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path)
251
+ winreg.SetValue(key, "", winreg.REG_SZ, "URL:Kiro Protocol")
252
+ winreg.SetValueEx(key, "URL Protocol", 0, winreg.REG_SZ, "")
253
+ winreg.CloseKey(key)
254
+
255
+ # 创建 DefaultIcon 键
256
+ icon_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\DefaultIcon")
257
+ winreg.SetValue(icon_key, "", winreg.REG_SZ, f"{python_exe},0")
258
+ winreg.CloseKey(icon_key)
259
+
260
+ # 创建 shell\\open\\command 键
261
+ cmd_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\shell\\open\\command")
262
+ cmd = f'"{pythonw_exe}" "{script_path}" "%1"'
263
+ winreg.SetValue(cmd_key, "", winreg.REG_SZ, cmd)
264
+ winreg.CloseKey(cmd_key)
265
+
266
+ print(f"[ProtocolHandler] 已注册 kiro:// 协议")
267
+ print(f"[ProtocolHandler] 脚本路径: {script_path}")
268
+ print(f"[ProtocolHandler] 命令: {cmd}")
269
+ return True, "协议注册成功"
270
+
271
+ except Exception as e:
272
+ import traceback
273
+ traceback.print_exc()
274
+ return False, f"注册失败: {e}"
275
+
276
+
277
+ def unregister_protocol_windows() -> tuple:
278
+ """取消注册 kiro:// 协议"""
279
+ if sys.platform != 'win32':
280
+ return False, "仅支持 Windows"
281
+
282
+ try:
283
+ import winreg
284
+
285
+ def delete_key_recursive(key, subkey):
286
+ try:
287
+ open_key = winreg.OpenKey(key, subkey, 0, winreg.KEY_ALL_ACCESS)
288
+ info = winreg.QueryInfoKey(open_key)
289
+ for i in range(info[0]):
290
+ child = winreg.EnumKey(open_key, 0)
291
+ delete_key_recursive(open_key, child)
292
+ winreg.CloseKey(open_key)
293
+ winreg.DeleteKey(key, subkey)
294
+ except WindowsError:
295
+ pass
296
+
297
+ delete_key_recursive(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro")
298
+
299
+ print("[ProtocolHandler] 已取消注册 kiro:// 协议")
300
+ return True, "协议取消注册成功"
301
+
302
+ except Exception as e:
303
+ return False, f"取消注册失败: {e}"
304
+
305
+
306
+ def is_protocol_registered() -> bool:
307
+ """检查 kiro:// 协议是否已注册"""
308
+ if sys.platform != 'win32':
309
+ return False
310
+
311
+ try:
312
+ import winreg
313
+ key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro")
314
+ winreg.CloseKey(key)
315
+ return True
316
+ except WindowsError:
317
+ return False
318
+
KiroProxy/kiro_proxy/core/quota_cache.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """额度缓存管理模块
2
+
3
+ 提供账号额度信息的内存缓存和文件持久化功能。
4
+ """
5
+ import json
6
+ import time
7
+ import asyncio
8
+ from dataclasses import dataclass, field, asdict
9
+ from enum import Enum
10
+ from pathlib import Path
11
+ from typing import Optional, Dict, Any
12
+ from threading import Lock
13
+
14
+
15
+ # 默认缓存过期时间(秒)
16
+ DEFAULT_CACHE_MAX_AGE = 300 # 5分钟
17
+
18
+ # 低余额阈值
19
+ LOW_BALANCE_THRESHOLD = 0.2
20
+
21
+
22
+ class BalanceStatus(Enum):
23
+ """额度状态枚举
24
+
25
+ 用于区分账号的额度状态:
26
+ - NORMAL: 正常(剩余额度 > 20%)
27
+ - LOW: 低额度(0 < 剩余额度 <= 20%)
28
+ - EXHAUSTED: 无额度(剩余额度 <= 0)
29
+ """
30
+ NORMAL = "normal" # 正常(>20%)
31
+ LOW = "low" # 低额度(0-20%)
32
+ EXHAUSTED = "exhausted" # 无额度(<=0)
33
+
34
+
35
+ @dataclass
36
+ class CachedQuota:
37
+ """缓存的额度信息"""
38
+ account_id: str
39
+ usage_limit: float = 0.0 # 总额度
40
+ current_usage: float = 0.0 # 已用额度
41
+ balance: float = 0.0 # 剩余额度
42
+ usage_percent: float = 0.0 # 使用百分比
43
+ balance_status: str = "normal" # 额度状态: normal, low, exhausted
44
+ is_low_balance: bool = False # 是否低额度(兼容旧字段)
45
+ is_exhausted: bool = False # 是否无额度
46
+ is_suspended: bool = False # 是否被封禁
47
+ subscription_title: str = "" # 订阅类型
48
+ free_trial_limit: float = 0.0 # 免费试用额度
49
+ free_trial_usage: float = 0.0 # 免费试用已用
50
+ bonus_limit: float = 0.0 # 奖励额度
51
+ bonus_usage: float = 0.0 # 奖励已用
52
+ updated_at: float = 0.0 # 更新时间戳
53
+ error: Optional[str] = None # 错误信息(如果获取失败)
54
+
55
+ # 重置和过期时间
56
+ next_reset_date: Optional[str] = None # 下次重置时间
57
+ free_trial_expiry: Optional[str] = None # 免费试用过期时间
58
+ bonus_expiries: list = None # 奖励过期时间列表
59
+
60
+ def __post_init__(self):
61
+ """初始化后计算额度状态"""
62
+ self._update_balance_status()
63
+
64
+ def _update_balance_status(self) -> None:
65
+ """更新额度状态"""
66
+ if self.error is not None:
67
+ # 有错误时不更新状态
68
+ return
69
+
70
+ if self.balance <= 0:
71
+ self.balance_status = BalanceStatus.EXHAUSTED.value
72
+ self.is_exhausted = True
73
+ self.is_low_balance = False
74
+ elif self.usage_limit > 0:
75
+ remaining_percent = (self.balance / self.usage_limit) * 100
76
+ if remaining_percent <= LOW_BALANCE_THRESHOLD * 100:
77
+ self.balance_status = BalanceStatus.LOW.value
78
+ self.is_low_balance = True
79
+ self.is_exhausted = False
80
+ else:
81
+ self.balance_status = BalanceStatus.NORMAL.value
82
+ self.is_low_balance = False
83
+ self.is_exhausted = False
84
+ else:
85
+ self.balance_status = BalanceStatus.NORMAL.value
86
+ self.is_low_balance = False
87
+ self.is_exhausted = False
88
+
89
+ @classmethod
90
+ def from_usage_info(cls, account_id: str, usage_info: 'UsageInfo') -> 'CachedQuota':
91
+ """从 UsageInfo 创建 CachedQuota"""
92
+ usage_percent = (usage_info.current_usage / usage_info.usage_limit * 100) if usage_info.usage_limit > 0 else 0.0
93
+ quota = cls(
94
+ account_id=account_id,
95
+ usage_limit=usage_info.usage_limit,
96
+ current_usage=usage_info.current_usage,
97
+ balance=usage_info.balance,
98
+ usage_percent=round(usage_percent, 2),
99
+ is_low_balance=usage_info.is_low_balance,
100
+ subscription_title=usage_info.subscription_title,
101
+ free_trial_limit=usage_info.free_trial_limit,
102
+ free_trial_usage=usage_info.free_trial_usage,
103
+ bonus_limit=usage_info.bonus_limit,
104
+ bonus_usage=usage_info.bonus_usage,
105
+ updated_at=time.time(),
106
+ error=None,
107
+ next_reset_date=usage_info.next_reset_date,
108
+ free_trial_expiry=usage_info.free_trial_expiry,
109
+ bonus_expiries=usage_info.bonus_expiries or [],
110
+ )
111
+ # 重新计算状态以确保一致性
112
+ quota._update_balance_status()
113
+ return quota
114
+
115
+ @classmethod
116
+ def from_error(cls, account_id: str, error: str) -> 'CachedQuota':
117
+ """创建错误状态的缓存"""
118
+ # 检查是否为账号封禁错误
119
+ is_suspended = (
120
+ "temporarily_suspended" in error.lower() or
121
+ "suspended" in error.lower() or
122
+ "accountsuspendedexception" in error.lower()
123
+ )
124
+
125
+ quota = cls(
126
+ account_id=account_id,
127
+ updated_at=time.time(),
128
+ error=error
129
+ )
130
+
131
+ # 如果是封禁错误,标记为特殊状态
132
+ if is_suspended:
133
+ quota.is_suspended = True
134
+
135
+ return quota
136
+
137
+ @classmethod
138
+ def from_dict(cls, data: Dict[str, Any]) -> 'CachedQuota':
139
+ """从字典创建"""
140
+ quota = cls(
141
+ account_id=data.get("account_id", ""),
142
+ usage_limit=data.get("usage_limit", 0.0),
143
+ current_usage=data.get("current_usage", 0.0),
144
+ balance=data.get("balance", 0.0),
145
+ usage_percent=data.get("usage_percent", 0.0),
146
+ balance_status=data.get("balance_status", "normal"),
147
+ is_low_balance=data.get("is_low_balance", False),
148
+ is_exhausted=data.get("is_exhausted", False),
149
+ is_suspended=data.get("is_suspended", False),
150
+ subscription_title=data.get("subscription_title", ""),
151
+ free_trial_limit=data.get("free_trial_limit", 0.0),
152
+ free_trial_usage=data.get("free_trial_usage", 0.0),
153
+ bonus_limit=data.get("bonus_limit", 0.0),
154
+ bonus_usage=data.get("bonus_usage", 0.0),
155
+ updated_at=data.get("updated_at", 0.0),
156
+ error=data.get("error"),
157
+ next_reset_date=data.get("next_reset_date"),
158
+ free_trial_expiry=data.get("free_trial_expiry"),
159
+ bonus_expiries=data.get("bonus_expiries", []),
160
+ )
161
+ # 重新计算状态以确保一致性
162
+ quota._update_balance_status()
163
+ return quota
164
+
165
+ def to_dict(self) -> Dict[str, Any]:
166
+ """转换为字典"""
167
+ return asdict(self)
168
+
169
+ def has_error(self) -> bool:
170
+ """是否有错误"""
171
+ return self.error is not None
172
+
173
+ def is_available(self) -> bool:
174
+ """额度是否可用(未耗尽且无错误)"""
175
+ return not self.is_exhausted and not self.has_error()
176
+
177
+ def get_balance_status_enum(self) -> BalanceStatus:
178
+ """获取额度状态枚举"""
179
+ try:
180
+ return BalanceStatus(self.balance_status)
181
+ except ValueError:
182
+ return BalanceStatus.NORMAL
183
+
184
+
185
+ class QuotaCache:
186
+ """额度缓存管理器
187
+
188
+ 提供线程安全的额度缓存操作,支持内存缓存和文件持久化。
189
+ """
190
+
191
+ def __init__(self, cache_file: Optional[str] = None):
192
+ """
193
+ 初始化缓存管理器
194
+
195
+ Args:
196
+ cache_file: 缓存文件路径,None 则使用默认路径
197
+ """
198
+ self._cache: Dict[str, CachedQuota] = {}
199
+ self._lock = Lock()
200
+ self._save_lock = asyncio.Lock()
201
+
202
+ # 设置缓存文件路径
203
+ if cache_file:
204
+ self._cache_file = Path(cache_file)
205
+ else:
206
+ from ..config import DATA_DIR
207
+ self._cache_file = DATA_DIR / "quota_cache.json"
208
+
209
+ # 启动时加载缓存
210
+ self.load_from_file()
211
+
212
+ def get(self, account_id: str) -> Optional[CachedQuota]:
213
+ """获取账号的缓存额度
214
+
215
+ Args:
216
+ account_id: 账号ID
217
+
218
+ Returns:
219
+ 缓存的额度信息,不存在则返回 None
220
+ """
221
+ with self._lock:
222
+ return self._cache.get(account_id)
223
+
224
+ def set(self, account_id: str, quota: CachedQuota) -> None:
225
+ """设置账号的额度缓存
226
+
227
+ Args:
228
+ account_id: 账号ID
229
+ quota: 额度信息
230
+ """
231
+ with self._lock:
232
+ self._cache[account_id] = quota
233
+
234
+ def is_stale(self, account_id: str, max_age_seconds: int = DEFAULT_CACHE_MAX_AGE) -> bool:
235
+ """检查缓存是否过期
236
+
237
+ Args:
238
+ account_id: 账号ID
239
+ max_age_seconds: 最大缓存时间(秒)
240
+
241
+ Returns:
242
+ True 表示缓存过期或不存在
243
+ """
244
+ with self._lock:
245
+ quota = self._cache.get(account_id)
246
+ if quota is None:
247
+ return True
248
+ return (time.time() - quota.updated_at) > max_age_seconds
249
+
250
+ def get_all(self) -> Dict[str, CachedQuota]:
251
+ """获取所有缓存
252
+
253
+ Returns:
254
+ 所有账号的额度缓存副本
255
+ """
256
+ with self._lock:
257
+ return dict(self._cache)
258
+
259
+ def remove(self, account_id: str) -> None:
260
+ """移除账号缓存
261
+
262
+ Args:
263
+ account_id: 账号ID
264
+ """
265
+ with self._lock:
266
+ self._cache.pop(account_id, None)
267
+
268
+ def clear(self) -> None:
269
+ """清空所有缓存"""
270
+ with self._lock:
271
+ self._cache.clear()
272
+
273
+ def load_from_file(self) -> bool:
274
+ """从文件加载缓存
275
+
276
+ Returns:
277
+ 是否加载成功
278
+ """
279
+ if not self._cache_file.exists():
280
+ return False
281
+
282
+ try:
283
+ with open(self._cache_file, 'r', encoding='utf-8') as f:
284
+ data = json.load(f)
285
+
286
+ # 验证版本
287
+ version = data.get("version", "1.0")
288
+ accounts_data = data.get("accounts", {})
289
+
290
+ with self._lock:
291
+ self._cache.clear()
292
+ for account_id, quota_data in accounts_data.items():
293
+ quota_data["account_id"] = account_id
294
+ self._cache[account_id] = CachedQuota.from_dict(quota_data)
295
+
296
+ print(f"[QuotaCache] 从文件加载 {len(self._cache)} 个账号的额度缓存")
297
+ return True
298
+
299
+ except json.JSONDecodeError as e:
300
+ print(f"[QuotaCache] 缓存文件格式错误: {e}")
301
+ return False
302
+ except Exception as e:
303
+ print(f"[QuotaCache] 加载缓存失败: {e}")
304
+ return False
305
+
306
+ def save_to_file(self) -> bool:
307
+ """保存缓存到文件(同步版本)
308
+
309
+ Returns:
310
+ 是否保存成功
311
+ """
312
+ try:
313
+ # 确保目录存在
314
+ self._cache_file.parent.mkdir(parents=True, exist_ok=True)
315
+
316
+ with self._lock:
317
+ accounts_data = {}
318
+ for account_id, quota in self._cache.items():
319
+ quota_dict = quota.to_dict()
320
+ quota_dict.pop("account_id", None) # 避免重复存储
321
+ accounts_data[account_id] = quota_dict
322
+
323
+ data = {
324
+ "version": "1.0",
325
+ "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
326
+ "accounts": accounts_data
327
+ }
328
+
329
+ # 写入临时文件后重命名,确保原子性
330
+ temp_file = self._cache_file.with_suffix('.tmp')
331
+ with open(temp_file, 'w', encoding='utf-8') as f:
332
+ json.dump(data, f, indent=2, ensure_ascii=False)
333
+ temp_file.replace(self._cache_file)
334
+
335
+ return True
336
+
337
+ except Exception as e:
338
+ print(f"[QuotaCache] 保存缓存失败: {e}")
339
+ return False
340
+
341
+ async def save_to_file_async(self) -> bool:
342
+ """异步保存缓存到文件
343
+
344
+ Returns:
345
+ 是否保存成功
346
+ """
347
+ async with self._save_lock:
348
+ # 在线程池中执行同步保存
349
+ loop = asyncio.get_event_loop()
350
+ return await loop.run_in_executor(None, self.save_to_file)
351
+
352
+ def get_summary(self) -> Dict[str, Any]:
353
+ """获取缓存汇总信息
354
+
355
+ Returns:
356
+ 汇总统计信息
357
+ """
358
+ with self._lock:
359
+ total_balance = 0.0
360
+ total_usage = 0.0
361
+ total_limit = 0.0
362
+ error_count = 0
363
+ stale_count = 0
364
+
365
+ current_time = time.time()
366
+
367
+ for quota in self._cache.values():
368
+ if quota.has_error():
369
+ error_count += 1
370
+ else:
371
+ total_balance += quota.balance
372
+ total_usage += quota.current_usage
373
+ total_limit += quota.usage_limit
374
+
375
+ if (current_time - quota.updated_at) > DEFAULT_CACHE_MAX_AGE:
376
+ stale_count += 1
377
+
378
+ return {
379
+ "total_accounts": len(self._cache),
380
+ "total_balance": round(total_balance, 2),
381
+ "total_usage": round(total_usage, 2),
382
+ "total_limit": round(total_limit, 2),
383
+ "error_count": error_count,
384
+ "stale_count": stale_count
385
+ }
386
+
387
+
388
+ # 全局缓存实例
389
+ _quota_cache: Optional[QuotaCache] = None
390
+
391
+
392
+ def get_quota_cache() -> QuotaCache:
393
+ """获取全局缓存实例"""
394
+ global _quota_cache
395
+ if _quota_cache is None:
396
+ _quota_cache = QuotaCache()
397
+ return _quota_cache
KiroProxy/kiro_proxy/core/quota_scheduler.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """额度更新调度器模块
2
+
3
+ 实现启动时并发获取所有账号额度、定时更新活跃账号额度的功能。
4
+ """
5
+ import asyncio
6
+ import time
7
+ from typing import Optional, Set, Dict, List, TYPE_CHECKING
8
+ from threading import Lock
9
+
10
+ if TYPE_CHECKING:
11
+ from .account import Account
12
+
13
+ from .quota_cache import QuotaCache, CachedQuota, get_quota_cache
14
+ from .usage import get_account_usage
15
+
16
+
17
+ # 默认更新间隔(秒)
18
+ DEFAULT_UPDATE_INTERVAL = 60
19
+
20
+ # 活跃账号判定时间窗口(秒)
21
+ # 需要覆盖一次更新周期,避免低频请求时“永远错过”定时刷新
22
+ ACTIVE_WINDOW_SECONDS = 120
23
+
24
+
25
+ class QuotaScheduler:
26
+ """额度更新调度器
27
+
28
+ 负责启动时并发获取所有账号额度,以及定时更新活跃账号的额度。
29
+ """
30
+
31
+ def __init__(self,
32
+ quota_cache: Optional[QuotaCache] = None,
33
+ update_interval: int = DEFAULT_UPDATE_INTERVAL):
34
+ """
35
+ 初始化调度器
36
+
37
+ Args:
38
+ quota_cache: 额度缓存实例
39
+ update_interval: 更新间隔(秒)
40
+ """
41
+ self.quota_cache = quota_cache or get_quota_cache()
42
+ self.update_interval = update_interval
43
+
44
+ self._active_accounts: Dict[str, float] = {} # account_id -> last_used_timestamp
45
+ self._lock = Lock()
46
+ self._task: Optional[asyncio.Task] = None
47
+ self._running = False
48
+ self._last_full_refresh: Optional[float] = None
49
+ self._accounts_getter = None # 获取账号列表的回调函数
50
+
51
+ def set_accounts_getter(self, getter):
52
+ """设置获取账号列表的回调函数
53
+
54
+ Args:
55
+ getter: 返回账号列表的可调用对象
56
+ """
57
+ self._accounts_getter = getter
58
+
59
+ def _get_accounts(self) -> List['Account']:
60
+ """获取账号列表"""
61
+ if self._accounts_getter:
62
+ return self._accounts_getter()
63
+ return []
64
+
65
+ async def start(self) -> None:
66
+ """启动调度器"""
67
+ if self._running:
68
+ return
69
+
70
+ self._running = True
71
+ print("[QuotaScheduler] 启动额度更新调度器")
72
+
73
+ # 启动时刷新所有账号额度
74
+ await self.refresh_all()
75
+
76
+ # 启动定时更新任务
77
+ self._task = asyncio.create_task(self._update_loop())
78
+
79
+ async def stop(self) -> None:
80
+ """停止调度器"""
81
+ self._running = False
82
+
83
+ if self._task:
84
+ self._task.cancel()
85
+ try:
86
+ await self._task
87
+ except asyncio.CancelledError:
88
+ pass
89
+ self._task = None
90
+
91
+ print("[QuotaScheduler] 额度更新调度器已停止")
92
+
93
+ async def refresh_all(self) -> Dict[str, bool]:
94
+ """刷新所有账号额度
95
+
96
+ Returns:
97
+ 账号ID -> 是否成功的字典
98
+ """
99
+ accounts = self._get_accounts()
100
+ if not accounts:
101
+ print("[QuotaScheduler] 没有账号需要刷新")
102
+ return {}
103
+
104
+ # 刷新所有账号(包括禁用的,以便检查是否可以解禁)
105
+ print(f"[QuotaScheduler] 开始刷新 {len(accounts)} 个账号的额度...")
106
+
107
+ # 并发获取所有账号额度
108
+ tasks = [self._refresh_account_internal(acc) for acc in accounts]
109
+ results = await asyncio.gather(*tasks, return_exceptions=True)
110
+
111
+ # 统计结果
112
+ success_count = 0
113
+ fail_count = 0
114
+ result_dict = {}
115
+
116
+ for acc, result in zip(accounts, results):
117
+ if isinstance(result, Exception):
118
+ result_dict[acc.id] = False
119
+ fail_count += 1
120
+ else:
121
+ result_dict[acc.id] = result
122
+ if result:
123
+ success_count += 1
124
+ else:
125
+ fail_count += 1
126
+
127
+ self._last_full_refresh = time.time()
128
+
129
+ # 保存缓存
130
+ await self.quota_cache.save_to_file_async()
131
+
132
+ # 保存账号配置(因为可能有启用/禁用状态变化)
133
+ self._save_accounts_config()
134
+
135
+ print(f"[QuotaScheduler] 额度刷新完成: 成功 {success_count}, 失败 {fail_count}")
136
+ return result_dict
137
+
138
+ def _save_accounts_config(self):
139
+ """保存账号配置"""
140
+ try:
141
+ from .state import state
142
+ state._save_accounts()
143
+ except Exception as e:
144
+ print(f"[QuotaScheduler] 保存账号配置失败: {e}")
145
+
146
+ async def refresh_account(self, account_id: str) -> bool:
147
+ """刷新单个账号额度
148
+
149
+ Args:
150
+ account_id: 账号ID
151
+
152
+ Returns:
153
+ 是否成功
154
+ """
155
+ accounts = self._get_accounts()
156
+ account = next((acc for acc in accounts if acc.id == account_id), None)
157
+
158
+ if not account:
159
+ print(f"[QuotaScheduler] 账号不存在: {account_id}")
160
+ return False
161
+
162
+ success = await self._refresh_account_internal(account)
163
+
164
+ if success:
165
+ await self.quota_cache.save_to_file_async()
166
+ self._save_accounts_config()
167
+
168
+ return success
169
+
170
+ async def _refresh_account_internal(self, account: 'Account') -> bool:
171
+ """内部刷新账号额度方法
172
+
173
+ Args:
174
+ account: 账号对象
175
+
176
+ Returns:
177
+ 是否成功
178
+ """
179
+ try:
180
+ success, result = await get_account_usage(account)
181
+
182
+ if success:
183
+ quota = CachedQuota.from_usage_info(account.id, result)
184
+ self.quota_cache.set(account.id, quota)
185
+
186
+ # 额度为 0 时自动禁用账号
187
+ if quota.is_exhausted:
188
+ if account.enabled:
189
+ account.enabled = False
190
+ # 标记为自动禁用,避免与手动禁用混淆
191
+ if hasattr(account, "auto_disabled"):
192
+ account.auto_disabled = True
193
+ print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用")
194
+ else:
195
+ # 有额度时自动解禁账号(仅对自动禁用的账号生效,避免覆盖手动禁用/封禁)
196
+ if (not account.enabled) and getattr(account, "auto_disabled", False):
197
+ account.enabled = True
198
+ account.auto_disabled = False
199
+ print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 有可用额度,自动启用")
200
+
201
+ return True
202
+ else:
203
+ error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else str(result)
204
+ quota = CachedQuota.from_error(account.id, error_msg)
205
+ self.quota_cache.set(account.id, quota)
206
+ print(f"[QuotaScheduler] 获取账号 {account.id} 额度失败: {error_msg}")
207
+ return False
208
+
209
+ except Exception as e:
210
+ error_msg = str(e)
211
+ quota = CachedQuota.from_error(account.id, error_msg)
212
+ self.quota_cache.set(account.id, quota)
213
+ print(f"[QuotaScheduler] 获取账号 {account.id} 额度异常: {error_msg}")
214
+ return False
215
+
216
+ def mark_active(self, account_id: str) -> None:
217
+ """标记账号为活跃
218
+
219
+ Args:
220
+ account_id: 账号ID
221
+ """
222
+ with self._lock:
223
+ self._active_accounts[account_id] = time.time()
224
+
225
+ def is_active(self, account_id: str) -> bool:
226
+ """检查账号是否活跃
227
+
228
+ Args:
229
+ account_id: 账号ID
230
+
231
+ Returns:
232
+ 是否在活跃时间窗口内
233
+ """
234
+ with self._lock:
235
+ last_used = self._active_accounts.get(account_id)
236
+ if last_used is None:
237
+ return False
238
+ return (time.time() - last_used) < ACTIVE_WINDOW_SECONDS
239
+
240
+ def get_active_accounts(self) -> Set[str]:
241
+ """获取活跃账号列表
242
+
243
+ Returns:
244
+ 活跃账号ID集合
245
+ """
246
+ current_time = time.time()
247
+ with self._lock:
248
+ return {
249
+ account_id
250
+ for account_id, last_used in self._active_accounts.items()
251
+ if (current_time - last_used) < ACTIVE_WINDOW_SECONDS
252
+ }
253
+
254
+ def cleanup_inactive(self) -> None:
255
+ """清理不活跃的账号记录"""
256
+ current_time = time.time()
257
+ with self._lock:
258
+ self._active_accounts = {
259
+ account_id: last_used
260
+ for account_id, last_used in self._active_accounts.items()
261
+ if (current_time - last_used) < ACTIVE_WINDOW_SECONDS * 2
262
+ }
263
+
264
+ async def _update_loop(self) -> None:
265
+ """定时更新循环"""
266
+ while self._running:
267
+ try:
268
+ await asyncio.sleep(self.update_interval)
269
+
270
+ if not self._running:
271
+ break
272
+
273
+ # 获取活跃账号
274
+ active_ids = self.get_active_accounts()
275
+
276
+ if active_ids:
277
+ print(f"[QuotaScheduler] 更新 {len(active_ids)} 个活跃账号的额度...")
278
+
279
+ accounts = self._get_accounts()
280
+ active_accounts = [acc for acc in accounts if acc.id in active_ids]
281
+
282
+ # 并发更新
283
+ tasks = [self._refresh_account_internal(acc) for acc in active_accounts]
284
+ await asyncio.gather(*tasks, return_exceptions=True)
285
+
286
+ # 保存缓存
287
+ await self.quota_cache.save_to_file_async()
288
+
289
+ # 清理不活跃记录
290
+ self.cleanup_inactive()
291
+
292
+ except asyncio.CancelledError:
293
+ break
294
+ except Exception as e:
295
+ print(f"[QuotaScheduler] 更新循环异常: {e}")
296
+
297
+ def get_last_full_refresh(self) -> Optional[float]:
298
+ """获取最后一次全量刷新时间"""
299
+ return self._last_full_refresh
300
+
301
+ def get_status(self) -> dict:
302
+ """获取调度器状态"""
303
+ return {
304
+ "running": self._running,
305
+ "update_interval": self.update_interval,
306
+ "active_accounts": list(self.get_active_accounts()),
307
+ "active_count": len(self.get_active_accounts()),
308
+ "last_full_refresh": self._last_full_refresh
309
+ }
310
+
311
+
312
+ # 全局调度器实例
313
+ _quota_scheduler: Optional[QuotaScheduler] = None
314
+
315
+
316
+ def get_quota_scheduler() -> QuotaScheduler:
317
+ """获取全局调度器实例"""
318
+ global _quota_scheduler
319
+ if _quota_scheduler is None:
320
+ _quota_scheduler = QuotaScheduler()
321
+ return _quota_scheduler
KiroProxy/kiro_proxy/core/rate_limiter.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """请求限速器 - 降低账号封禁风险
2
+
3
+ 通过限制请求频率来降低被检测为异常活动的风险:
4
+ - 每账号请求间隔
5
+ - 全局请求限制
6
+ - 突发请求检测
7
+
8
+ 注意:429 冷却时间已改为自动管理(固定5分钟),不再需要手动配置
9
+ """
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from typing import Dict, Optional
13
+ from collections import deque
14
+
15
+
16
+ @dataclass
17
+ class RateLimitConfig:
18
+ """限速配置"""
19
+ # 每账号最小请求间隔(秒)
20
+ min_request_interval: float = 0.5
21
+
22
+ # 每账号每分钟最大请求数
23
+ max_requests_per_minute: int = 60
24
+
25
+ # 全局每分钟最大请求数
26
+ global_max_requests_per_minute: int = 120
27
+
28
+ # 是否启用限速
29
+ enabled: bool = False
30
+
31
+
32
+ @dataclass
33
+ class AccountRateState:
34
+ """账号限速状态"""
35
+ last_request_time: float = 0
36
+ request_times: deque = field(default_factory=lambda: deque(maxlen=100))
37
+
38
+ def get_requests_in_window(self, window_seconds: int = 60) -> int:
39
+ """获取时间窗口内的请求数"""
40
+ now = time.time()
41
+ cutoff = now - window_seconds
42
+ return sum(1 for t in self.request_times if t > cutoff)
43
+
44
+
45
+ class RateLimiter:
46
+ """请求限速器"""
47
+
48
+ def __init__(self, config: RateLimitConfig = None):
49
+ self.config = config or RateLimitConfig()
50
+ self._account_states: Dict[str, AccountRateState] = {}
51
+ self._global_requests: deque = deque(maxlen=1000)
52
+
53
+ def _get_account_state(self, account_id: str) -> AccountRateState:
54
+ """获取账号状态"""
55
+ if account_id not in self._account_states:
56
+ self._account_states[account_id] = AccountRateState()
57
+ return self._account_states[account_id]
58
+
59
+ def can_request(self, account_id: str) -> tuple:
60
+ """检查是否可以发送请求
61
+
62
+ Returns:
63
+ (can_request, wait_seconds, reason)
64
+ """
65
+ if not self.config.enabled:
66
+ return True, 0, None
67
+
68
+ now = time.time()
69
+ state = self._get_account_state(account_id)
70
+
71
+ # 检查最小请求间隔
72
+ time_since_last = now - state.last_request_time
73
+ if time_since_last < self.config.min_request_interval:
74
+ wait = self.config.min_request_interval - time_since_last
75
+ return False, wait, f"请求过快,请等待 {wait:.1f} 秒"
76
+
77
+ # 检查每账号每分钟限制
78
+ account_rpm = state.get_requests_in_window(60)
79
+ if account_rpm >= self.config.max_requests_per_minute:
80
+ return False, 2, f"账号请求过于频繁 ({account_rpm}/分钟)"
81
+
82
+ # 检查全局每分钟限制
83
+ global_rpm = sum(1 for t in self._global_requests if t > now - 60)
84
+ if global_rpm >= self.config.global_max_requests_per_minute:
85
+ return False, 1, f"全局请求过于频繁 ({global_rpm}/分钟)"
86
+
87
+ return True, 0, None
88
+
89
+ def record_request(self, account_id: str):
90
+ """记录请求"""
91
+ now = time.time()
92
+ state = self._get_account_state(account_id)
93
+ state.last_request_time = now
94
+ state.request_times.append(now)
95
+ self._global_requests.append(now)
96
+
97
+ def get_stats(self) -> dict:
98
+ """获取统计信息"""
99
+ now = time.time()
100
+ return {
101
+ "enabled": self.config.enabled,
102
+ "global_rpm": sum(1 for t in self._global_requests if t > now - 60),
103
+ "accounts": {
104
+ aid: {
105
+ "rpm": state.get_requests_in_window(60),
106
+ "last_request": now - state.last_request_time if state.last_request_time else None
107
+ }
108
+ for aid, state in self._account_states.items()
109
+ }
110
+ }
111
+
112
+ def update_config(self, **kwargs):
113
+ """更新配置"""
114
+ for key, value in kwargs.items():
115
+ if hasattr(self.config, key):
116
+ setattr(self.config, key, value)
117
+
118
+
119
+ # 全局实例
120
+ rate_limiter = RateLimiter()
121
+
122
+
123
+ def get_rate_limiter() -> RateLimiter:
124
+ """获取限速器实例"""
125
+ return rate_limiter
KiroProxy/kiro_proxy/core/refresh_manager.py ADDED
@@ -0,0 +1,888 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 刷新管理模块
2
+
3
+ 提供 Token 批量刷新的管理功能,包括:
4
+ - 刷新进度跟踪
5
+ - 并发控制
6
+ - 重试机制配置
7
+ - 全局锁防止重复刷新
8
+ - Token 过期检测和自动刷新
9
+ - 指数退避重试策略
10
+ """
11
+ import time
12
+ import asyncio
13
+ from dataclasses import dataclass, field, asdict
14
+ from typing import Optional, Dict, Any, List, Tuple, Callable, TYPE_CHECKING
15
+ from threading import Lock
16
+
17
+ if TYPE_CHECKING:
18
+ from .account import Account
19
+
20
+
21
+ @dataclass
22
+ class RefreshProgress:
23
+ """刷新进度信息
24
+
25
+ 用于跟踪批量 Token 刷新操作的进度状态。
26
+
27
+ Attributes:
28
+ total: 需要刷新的账号总数
29
+ completed: 已完成处理的账号数(包括成功和失败)
30
+ success: 刷新成功的账号数
31
+ failed: 刷新失败的账号数
32
+ current_account: 当前正在处理的账号ID
33
+ status: 刷新状态 - running(进行中), completed(已完成), error(出错)
34
+ started_at: 刷新开始时间戳
35
+ message: 状态消息,用于显示当前操作或错误信息
36
+ """
37
+ total: int = 0
38
+ completed: int = 0
39
+ success: int = 0
40
+ failed: int = 0
41
+ current_account: Optional[str] = None
42
+ status: str = "running" # running, completed, error
43
+ started_at: float = field(default_factory=time.time)
44
+ message: Optional[str] = None
45
+
46
+ def to_dict(self) -> Dict[str, Any]:
47
+ """转换为字典格式
48
+
49
+ Returns:
50
+ 包含所有进度信息的字典
51
+ """
52
+ return asdict(self)
53
+
54
+ @property
55
+ def progress_percent(self) -> float:
56
+ """计算完成百分比
57
+
58
+ Returns:
59
+ 完成百分比(0-100)
60
+ """
61
+ if self.total == 0:
62
+ return 0.0
63
+ return round((self.completed / self.total) * 100, 2)
64
+
65
+ @property
66
+ def elapsed_seconds(self) -> float:
67
+ """计算已用时间(秒)
68
+
69
+ Returns:
70
+ 从开始到现在的秒数
71
+ """
72
+ return time.time() - self.started_at
73
+
74
+ def is_running(self) -> bool:
75
+ """检查是否正在运行
76
+
77
+ Returns:
78
+ True 表示正在运行
79
+ """
80
+ return self.status == "running"
81
+
82
+ def is_completed(self) -> bool:
83
+ """检查是否已完成
84
+
85
+ Returns:
86
+ True 表示已完成(成功或出错)
87
+ """
88
+ return self.status in ("completed", "error")
89
+
90
+
91
+ @dataclass
92
+ class RefreshConfig:
93
+ """刷新配置
94
+
95
+ 控制 Token 刷新行为的配置参数。
96
+
97
+ Attributes:
98
+ max_retries: 单个账号刷新失败时的最大重试次数
99
+ retry_base_delay: 重试基础延迟时间(秒),实际延迟会指数增长
100
+ concurrency: 并发刷新的账号数量
101
+ token_refresh_before_expiry: Token 过期前多少秒开始刷新(默认5分钟)
102
+ auto_refresh_interval: 自动刷新检查间隔(秒)
103
+ """
104
+ max_retries: int = 3
105
+ retry_base_delay: float = 1.0
106
+ concurrency: int = 3
107
+ token_refresh_before_expiry: int = 300 # 5分钟
108
+ auto_refresh_interval: int = 60 # 1分钟
109
+
110
+ def to_dict(self) -> Dict[str, Any]:
111
+ """转换为字典格式
112
+
113
+ Returns:
114
+ 包含所有配置项的字典
115
+ """
116
+ return asdict(self)
117
+
118
+ @classmethod
119
+ def from_dict(cls, data: Dict[str, Any]) -> 'RefreshConfig':
120
+ """从字典创建配置实例
121
+
122
+ Args:
123
+ data: 配置字典
124
+
125
+ Returns:
126
+ RefreshConfig 实例
127
+ """
128
+ return cls(
129
+ max_retries=data.get("max_retries", 3),
130
+ retry_base_delay=data.get("retry_base_delay", 1.0),
131
+ concurrency=data.get("concurrency", 3),
132
+ token_refresh_before_expiry=data.get("token_refresh_before_expiry", 300),
133
+ auto_refresh_interval=data.get("auto_refresh_interval", 60)
134
+ )
135
+
136
+ def validate(self) -> bool:
137
+ """验证配置有效性
138
+
139
+ Returns:
140
+ True 表示配置有效
141
+
142
+ Raises:
143
+ ValueError: 配置值无效时抛出
144
+ """
145
+ if self.max_retries < 0:
146
+ raise ValueError("max_retries 不能为负数")
147
+ if self.retry_base_delay <= 0:
148
+ raise ValueError("retry_base_delay 必须大于0")
149
+ if self.concurrency < 1:
150
+ raise ValueError("concurrency 必须至少为1")
151
+ if self.token_refresh_before_expiry < 0:
152
+ raise ValueError("token_refresh_before_expiry 不能为负数")
153
+ if self.auto_refresh_interval < 1:
154
+ raise ValueError("auto_refresh_interval 必须至少为1秒")
155
+ return True
156
+
157
+
158
+ class RefreshManager:
159
+ """Token 刷新管理器
160
+
161
+ 管理 Token 批量刷新操作,提供:
162
+ - 全局锁机制防止重复刷新
163
+ - 进度跟踪
164
+ - 配置管理
165
+ - 自动 Token 刷新定时器
166
+
167
+ 使用示例:
168
+ manager = get_refresh_manager()
169
+ if not manager.is_refreshing():
170
+ # 开始刷新操作
171
+ pass
172
+ """
173
+
174
+ def __init__(self, config: Optional[RefreshConfig] = None):
175
+ """初始化刷新管理器
176
+
177
+ Args:
178
+ config: 刷新配置,None 则使用默认配置
179
+ """
180
+ # 配置
181
+ self._config = config or RefreshConfig()
182
+
183
+ # 线程锁(用于同步访问状态)
184
+ self._lock = Lock()
185
+
186
+ # 异步锁(用于防止并发刷新操作)
187
+ self._async_lock = asyncio.Lock()
188
+
189
+ # 刷新状态
190
+ self._is_refreshing: bool = False
191
+ self._progress: Optional[RefreshProgress] = None
192
+
193
+ # 上次刷新完成时间
194
+ self._last_refresh_time: Optional[float] = None
195
+
196
+ # 自动刷新定时器
197
+ self._auto_refresh_task: Optional[asyncio.Task] = None
198
+ self._auto_refresh_running: bool = False
199
+
200
+ # 获取账号列表的回调函数
201
+ self._accounts_getter: Optional[Callable] = None
202
+
203
+ @property
204
+ def config(self) -> RefreshConfig:
205
+ """获取当前配置
206
+
207
+ Returns:
208
+ 当前的刷新配置
209
+ """
210
+ with self._lock:
211
+ return self._config
212
+
213
+ def is_refreshing(self) -> bool:
214
+ """检查是否正在刷新
215
+
216
+ Returns:
217
+ True 表示正在进行刷新操作
218
+ """
219
+ with self._lock:
220
+ return self._is_refreshing
221
+
222
+ def get_progress(self) -> Optional[RefreshProgress]:
223
+ """获取当前刷新进度
224
+
225
+ Returns:
226
+ 当前进度信息,如果没有进行中的刷新则返回 None
227
+ """
228
+ with self._lock:
229
+ return self._progress
230
+
231
+ def get_progress_dict(self) -> Optional[Dict[str, Any]]:
232
+ """获取当前刷新进度(字典格式)
233
+
234
+ Returns:
235
+ 进度信息字典,如果没有进行中的刷新则返回 None
236
+ """
237
+ with self._lock:
238
+ if self._progress is None:
239
+ return None
240
+ return self._progress.to_dict()
241
+
242
+ def update_config(self, **kwargs) -> None:
243
+ """更新配置参数
244
+
245
+ 支持的参数:
246
+ max_retries: 最大重试次数
247
+ retry_base_delay: 重试基础延迟
248
+ concurrency: 并发数
249
+ token_refresh_before_expiry: Token 过期前刷新时间
250
+ auto_refresh_interval: 自动刷新检查间隔
251
+
252
+ Args:
253
+ **kwargs: 要更新的配置项
254
+
255
+ Raises:
256
+ ValueError: 配置值无效时抛出
257
+ """
258
+ with self._lock:
259
+ # 创建新配置
260
+ new_config = RefreshConfig(
261
+ max_retries=kwargs.get("max_retries", self._config.max_retries),
262
+ retry_base_delay=kwargs.get("retry_base_delay", self._config.retry_base_delay),
263
+ concurrency=kwargs.get("concurrency", self._config.concurrency),
264
+ token_refresh_before_expiry=kwargs.get(
265
+ "token_refresh_before_expiry",
266
+ self._config.token_refresh_before_expiry
267
+ ),
268
+ auto_refresh_interval=kwargs.get(
269
+ "auto_refresh_interval",
270
+ self._config.auto_refresh_interval
271
+ )
272
+ )
273
+
274
+ # 验证配置
275
+ new_config.validate()
276
+
277
+ # 应用新配置
278
+ self._config = new_config
279
+
280
+ def _start_refresh(self, total: int, message: Optional[str] = None) -> RefreshProgress:
281
+ """开始刷新操作(内部方法)
282
+
283
+ Args:
284
+ total: 需要刷新的账号总数
285
+ message: 初始状态消息
286
+
287
+ Returns:
288
+ 新创建的进度对象
289
+ """
290
+ with self._lock:
291
+ self._is_refreshing = True
292
+ self._progress = RefreshProgress(
293
+ total=total,
294
+ completed=0,
295
+ success=0,
296
+ failed=0,
297
+ current_account=None,
298
+ status="running",
299
+ started_at=time.time(),
300
+ message=message or "开始刷新"
301
+ )
302
+ return self._progress
303
+
304
+ def _update_progress(
305
+ self,
306
+ current_account: Optional[str] = None,
307
+ success: bool = False,
308
+ failed: bool = False,
309
+ message: Optional[str] = None
310
+ ) -> None:
311
+ """更新刷新进度(内部方法)
312
+
313
+ Args:
314
+ current_account: 当前处理的账号ID
315
+ success: 是否成功完成一个账号
316
+ failed: 是否失败一个账号
317
+ message: 状态消息
318
+ """
319
+ with self._lock:
320
+ if self._progress is None:
321
+ return
322
+
323
+ if current_account is not None:
324
+ self._progress.current_account = current_account
325
+
326
+ if success:
327
+ self._progress.success += 1
328
+ self._progress.completed += 1
329
+ elif failed:
330
+ self._progress.failed += 1
331
+ self._progress.completed += 1
332
+
333
+ if message is not None:
334
+ self._progress.message = message
335
+
336
+ def _finish_refresh(self, status: str = "completed", message: Optional[str] = None) -> None:
337
+ """完成刷新操作(内部方法)
338
+
339
+ Args:
340
+ status: 最终状态 - completed 或 error
341
+ message: 最终状态消息
342
+ """
343
+ with self._lock:
344
+ self._is_refreshing = False
345
+ self._last_refresh_time = time.time()
346
+
347
+ if self._progress is not None:
348
+ self._progress.status = status
349
+ self._progress.current_account = None
350
+ if message is not None:
351
+ self._progress.message = message
352
+ elif status == "completed":
353
+ self._progress.message = (
354
+ f"刷新完成: 成功 {self._progress.success}, "
355
+ f"失败 {self._progress.failed}"
356
+ )
357
+
358
+ def get_last_refresh_time(self) -> Optional[float]:
359
+ """获取上次刷新完成时间
360
+
361
+ Returns:
362
+ 上次刷新完成的时间戳,如果从未刷新则返回 None
363
+ """
364
+ with self._lock:
365
+ return self._last_refresh_time
366
+
367
+ def get_status(self) -> Dict[str, Any]:
368
+ """获取管理器状态
369
+
370
+ Returns:
371
+ 包含管理器状态信息的字典
372
+ """
373
+ with self._lock:
374
+ return {
375
+ "is_refreshing": self._is_refreshing,
376
+ "progress": self._progress.to_dict() if self._progress else None,
377
+ "last_refresh_time": self._last_refresh_time,
378
+ "config": self._config.to_dict()
379
+ }
380
+
381
+ async def acquire_refresh_lock(self) -> bool:
382
+ """尝试获取刷新锁
383
+
384
+ 用于在开始刷新操作前获取异步锁,防止并发刷新。
385
+
386
+ Returns:
387
+ True 表示成功获取锁,False 表示已有刷新在进行
388
+ """
389
+ if self._async_lock.locked():
390
+ return False
391
+
392
+ await self._async_lock.acquire()
393
+ return True
394
+
395
+ def release_refresh_lock(self) -> None:
396
+ """释放刷新锁
397
+
398
+ 在刷新操作完成后调用,释放异步锁。
399
+ """
400
+ if self._async_lock.locked():
401
+ self._async_lock.release()
402
+
403
+ def should_refresh_token(self, account: 'Account') -> bool:
404
+ """判断是否需要刷新 Token
405
+
406
+ 检查账号的 Token 是否即将过期(过期前5分钟)或已过期。
407
+
408
+ Args:
409
+ account: 账号对象
410
+
411
+ Returns:
412
+ True 表示需要刷新 Token
413
+ """
414
+ creds = account.get_credentials()
415
+ if creds is None:
416
+ return True # 无法获取凭证,需要刷新
417
+
418
+ # 检查是否已过期或即将过期
419
+ minutes_before = self._config.token_refresh_before_expiry // 60
420
+ return creds.is_expired() or creds.is_expiring_soon(minutes=minutes_before)
421
+
422
+ async def refresh_token_if_needed(self, account: 'Account') -> Tuple[bool, str]:
423
+ """如果需要则刷新 Token
424
+
425
+ 检查账号 Token 状态,如果即将过期或已过期则刷新。
426
+
427
+ Args:
428
+ account: 账号对象
429
+
430
+ Returns:
431
+ (success, message) 元组
432
+ - success: True 表示 Token 有效(无需刷新或刷新成功)
433
+ - message: 状态消息
434
+ """
435
+ if not self.should_refresh_token(account):
436
+ return True, "Token 有效,无需刷新"
437
+
438
+ print(f"[RefreshManager] 账号 {account.id} Token 即将过期,开始刷新...")
439
+
440
+ success, result = await account.refresh_token()
441
+
442
+ if success:
443
+ print(f"[RefreshManager] 账号 {account.id} Token 刷新成功")
444
+ return True, "Token 刷新成功"
445
+ else:
446
+ print(f"[RefreshManager] 账号 {account.id} Token 刷新失败: {result}")
447
+ return False, f"Token 刷新失败: {result}"
448
+
449
+ async def refresh_account_with_token(
450
+ self,
451
+ account: 'Account',
452
+ get_quota_func: Optional[Callable] = None
453
+ ) -> Tuple[bool, str]:
454
+ """刷新单个账号(先刷新 Token,再获取额度)
455
+
456
+ Args:
457
+ account: 账号对象
458
+ get_quota_func: 获取额度的异步函数,接受 account 参数
459
+
460
+ Returns:
461
+ (success, message) 元组
462
+ """
463
+ # 1. 先刷新 Token(如果需要)
464
+ token_success, token_msg = await self.refresh_token_if_needed(account)
465
+
466
+ if not token_success:
467
+ return False, token_msg
468
+
469
+ # 2. 获取额度(如果提供了获取函数)
470
+ if get_quota_func:
471
+ try:
472
+ quota_success, quota_result = await get_quota_func(account)
473
+ if quota_success:
474
+ return True, "刷新成功"
475
+ else:
476
+ error_msg = quota_result.get("error", "Unknown error") if isinstance(quota_result, dict) else str(quota_result)
477
+ return False, f"获取额度失败: {error_msg}"
478
+ except Exception as e:
479
+ return False, f"获取额度异常: {str(e)}"
480
+
481
+ return True, token_msg
482
+
483
+ async def retry_with_backoff(
484
+ self,
485
+ func: Callable,
486
+ *args,
487
+ max_retries: Optional[int] = None,
488
+ **kwargs
489
+ ) -> Tuple[bool, Any]:
490
+ """带指数退避的重试
491
+
492
+ 执行异步函数,失败时使用指数退避策略重试。
493
+
494
+ Args:
495
+ func: 要执行的异步函数
496
+ *args: 传递给函数的位置参数
497
+ max_retries: 最大重试次数,None 则使用配置值
498
+ **kwargs: 传递给函数的关键字参数
499
+
500
+ Returns:
501
+ (success, result) 元组
502
+ - success: True 表示执行成功
503
+ - result: 成功时为函数返回值,失败时为错误信息
504
+ """
505
+ retries = max_retries if max_retries is not None else self._config.max_retries
506
+ base_delay = self._config.retry_base_delay
507
+
508
+ last_error = None
509
+
510
+ for attempt in range(retries + 1):
511
+ try:
512
+ result = await func(*args, **kwargs)
513
+
514
+ # 检查返回值格式
515
+ if isinstance(result, tuple) and len(result) == 2:
516
+ success, data = result
517
+ if success:
518
+ return True, data
519
+ else:
520
+ last_error = data
521
+ # 检查是否是 429 错误
522
+ if self._is_rate_limit_error(data):
523
+ delay = self._get_rate_limit_delay(attempt, base_delay)
524
+ else:
525
+ delay = base_delay * (2 ** attempt)
526
+ else:
527
+ # 函数返回非元组,视为成功
528
+ return True, result
529
+
530
+ except Exception as e:
531
+ last_error = str(e)
532
+ delay = base_delay * (2 ** attempt)
533
+
534
+ # 如果还有重试机会,等待后重试
535
+ if attempt < retries:
536
+ print(f"[RefreshManager] 第 {attempt + 1} 次尝试失败,{delay:.1f}秒后重试...")
537
+ await asyncio.sleep(delay)
538
+
539
+ return False, last_error
540
+
541
+ def _is_rate_limit_error(self, error: Any) -> bool:
542
+ """检查是否是限流错误(429)
543
+
544
+ Args:
545
+ error: 错误信息
546
+
547
+ Returns:
548
+ True 表示是限流错误
549
+ """
550
+ if isinstance(error, str):
551
+ return "429" in error or "rate limit" in error.lower() or "请求过于频繁" in error
552
+ return False
553
+
554
+ def _get_rate_limit_delay(self, attempt: int, base_delay: float) -> float:
555
+ """获取限流错误的等待时间
556
+
557
+ 429 错误使用更长的等待时间。
558
+
559
+ Args:
560
+ attempt: 当前尝试次数(从0开始)
561
+ base_delay: 基础延迟
562
+
563
+ Returns:
564
+ 等待时间(秒)
565
+ """
566
+ # 429 错误使用 3 倍的基础延迟
567
+ return base_delay * 3 * (2 ** attempt)
568
+
569
+ async def refresh_all_with_token(
570
+ self,
571
+ accounts: List['Account'],
572
+ get_quota_func: Optional[Callable] = None,
573
+ skip_disabled: bool = True,
574
+ skip_error: bool = True
575
+ ) -> RefreshProgress:
576
+ """刷新所有账号(先刷新 Token,再获取额度)
577
+
578
+ 使用全局锁防止并发刷新,支持进度跟踪。
579
+
580
+ Args:
581
+ accounts: 账号列表
582
+ get_quota_func: 获取额度的异步函数
583
+ skip_disabled: 是否跳过已禁用的账号
584
+ skip_error: 是否跳过已处于错误状态的账号
585
+
586
+ Returns:
587
+ 刷新进度信息
588
+ """
589
+ # 尝试获取锁
590
+ if not await self.acquire_refresh_lock():
591
+ # 已有刷新在进行
592
+ progress = self.get_progress()
593
+ if progress:
594
+ return progress
595
+ # 返回一个错误状态的进度
596
+ return RefreshProgress(
597
+ total=0,
598
+ status="error",
599
+ message="刷新操作正在进行中"
600
+ )
601
+
602
+ try:
603
+ # 过滤账号
604
+ accounts_to_refresh = []
605
+ for acc in accounts:
606
+ if skip_disabled and not acc.enabled:
607
+ continue
608
+ if skip_error and acc.status.value in ("unhealthy", "suspended"):
609
+ continue
610
+ accounts_to_refresh.append(acc)
611
+
612
+ total = len(accounts_to_refresh)
613
+
614
+ # 开始刷新
615
+ self._start_refresh(total, f"开始刷新 {total} 个账号")
616
+
617
+ if total == 0:
618
+ self._finish_refresh("completed", "没有需要刷新的账号")
619
+ return self.get_progress()
620
+
621
+ # 使用信号量控制并发
622
+ semaphore = asyncio.Semaphore(self._config.concurrency)
623
+
624
+ async def refresh_one(account: 'Account'):
625
+ async with semaphore:
626
+ self._update_progress(
627
+ current_account=account.id,
628
+ message=f"正在刷新: {account.name}"
629
+ )
630
+
631
+ # 使用重试机制刷新
632
+ success, result = await self.retry_with_backoff(
633
+ self.refresh_account_with_token,
634
+ account,
635
+ get_quota_func
636
+ )
637
+
638
+ if success:
639
+ self._update_progress(success=True)
640
+ else:
641
+ self._update_progress(failed=True)
642
+
643
+ return success, result
644
+
645
+ # 并发执行
646
+ tasks = [refresh_one(acc) for acc in accounts_to_refresh]
647
+ await asyncio.gather(*tasks, return_exceptions=True)
648
+
649
+ # 完成
650
+ self._finish_refresh("completed")
651
+ return self.get_progress()
652
+
653
+ except Exception as e:
654
+ self._finish_refresh("error", f"刷新异常: {str(e)}")
655
+ return self.get_progress()
656
+
657
+ finally:
658
+ self.release_refresh_lock()
659
+
660
+ def _is_auth_error(self, error: Any) -> bool:
661
+ """检查是否是认证错误(401)
662
+
663
+ Args:
664
+ error: 错误信息
665
+
666
+ Returns:
667
+ True 表示是认证错误
668
+ """
669
+ if isinstance(error, str):
670
+ return "401" in error or "unauthorized" in error.lower() or "凭证已过期" in error or "需要重新登录" in error
671
+ return False
672
+
673
+ async def execute_with_auth_retry(
674
+ self,
675
+ account: 'Account',
676
+ func: Callable,
677
+ *args,
678
+ **kwargs
679
+ ) -> Tuple[bool, Any]:
680
+ """执行操作,遇到 401 错误时自动刷新 Token 并重试
681
+
682
+ Args:
683
+ account: 账号对象
684
+ func: 要执行的异步函数
685
+ *args: 传递给函数的位置参数
686
+ **kwargs: 传递给函数的关键字参数
687
+
688
+ Returns:
689
+ (success, result) 元组
690
+ """
691
+ try:
692
+ result = await func(*args, **kwargs)
693
+
694
+ # 检查返回值
695
+ if isinstance(result, tuple) and len(result) == 2:
696
+ success, data = result
697
+ if success:
698
+ return True, data
699
+
700
+ # 检查是否是 401 错误
701
+ if self._is_auth_error(data):
702
+ print(f"[RefreshManager] 账号 {account.id} 遇到 401 错误,尝试刷新 Token...")
703
+
704
+ # 刷新 Token
705
+ refresh_success, refresh_msg = await account.refresh_token()
706
+
707
+ if refresh_success:
708
+ print(f"[RefreshManager] Token 刷新成功,重试请求...")
709
+ # 重试原请求
710
+ retry_result = await func(*args, **kwargs)
711
+ if isinstance(retry_result, tuple) and len(retry_result) == 2:
712
+ return retry_result
713
+ return True, retry_result
714
+ else:
715
+ return False, f"Token 刷新失败: {refresh_msg}"
716
+
717
+ return False, data
718
+
719
+ return True, result
720
+
721
+ except Exception as e:
722
+ error_str = str(e)
723
+
724
+ # 检查异常是否是 401 错误
725
+ if self._is_auth_error(error_str):
726
+ print(f"[RefreshManager] 账号 {account.id} 遇到 401 异常,尝试刷新 Token...")
727
+
728
+ refresh_success, refresh_msg = await account.refresh_token()
729
+
730
+ if refresh_success:
731
+ print(f"[RefreshManager] Token 刷新成功,重试请求...")
732
+ try:
733
+ retry_result = await func(*args, **kwargs)
734
+ if isinstance(retry_result, tuple) and len(retry_result) == 2:
735
+ return retry_result
736
+ return True, retry_result
737
+ except Exception as retry_e:
738
+ return False, f"重试失败: {str(retry_e)}"
739
+ else:
740
+ return False, f"Token 刷新失败: {refresh_msg}"
741
+
742
+ return False, error_str
743
+
744
+ def set_accounts_getter(self, getter: Callable) -> None:
745
+ """设置获取账号列表的回调函数
746
+
747
+ Args:
748
+ getter: 返回账号列表的可调用对象
749
+ """
750
+ self._accounts_getter = getter
751
+
752
+ def _get_accounts(self) -> List['Account']:
753
+ """获取账号列表"""
754
+ if self._accounts_getter:
755
+ return self._accounts_getter()
756
+ return []
757
+
758
+ async def start_auto_refresh(self) -> None:
759
+ """启动自动 Token 刷新定时器
760
+
761
+ 定期检查所有账号的 Token 状态,自动刷新即将过期的 Token。
762
+ 启动前会清除已存在的定时器,防止重复启动。
763
+ """
764
+ # 先停止已存在的定时器
765
+ await self.stop_auto_refresh()
766
+
767
+ self._auto_refresh_running = True
768
+ self._auto_refresh_task = asyncio.create_task(self._auto_refresh_loop())
769
+ print(f"[RefreshManager] 自动 Token 刷新定时器已启动,检查间隔: {self._config.auto_refresh_interval}秒")
770
+
771
+ async def stop_auto_refresh(self) -> None:
772
+ """停止自动 Token 刷新定时器"""
773
+ self._auto_refresh_running = False
774
+
775
+ if self._auto_refresh_task:
776
+ self._auto_refresh_task.cancel()
777
+ try:
778
+ await self._auto_refresh_task
779
+ except asyncio.CancelledError:
780
+ pass
781
+ self._auto_refresh_task = None
782
+ print("[RefreshManager] 自动 Token 刷新定时器已停止")
783
+
784
+ def is_auto_refresh_running(self) -> bool:
785
+ """检查自动刷新定时器是否在运行
786
+
787
+ Returns:
788
+ True 表示定时器正在运行
789
+ """
790
+ return self._auto_refresh_running and self._auto_refresh_task is not None
791
+
792
+ async def _auto_refresh_loop(self) -> None:
793
+ """自动刷新循环
794
+
795
+ 定期检查所有账号的 Token 状态,刷新即将过期的 Token。
796
+ 跳过已禁用或错误状态的账号,单个失败不影响其他账号。
797
+ """
798
+ while self._auto_refresh_running:
799
+ try:
800
+ await asyncio.sleep(self._config.auto_refresh_interval)
801
+
802
+ if not self._auto_refresh_running:
803
+ break
804
+
805
+ accounts = self._get_accounts()
806
+ if not accounts:
807
+ continue
808
+
809
+ # 检查需要刷新的账号
810
+ accounts_to_refresh = []
811
+ for account in accounts:
812
+ # 跳过已禁用的账号
813
+ if not account.enabled:
814
+ continue
815
+
816
+ # 跳过错误状态的账号
817
+ if hasattr(account, 'status') and account.status.value in ("unhealthy", "suspended", "disabled"):
818
+ continue
819
+
820
+ # 检查是否需要刷新 Token
821
+ if self.should_refresh_token(account):
822
+ accounts_to_refresh.append(account)
823
+
824
+ if accounts_to_refresh:
825
+ print(f"[RefreshManager] 发现 {len(accounts_to_refresh)} 个账号需要刷新 Token")
826
+
827
+ # 逐个刷新,单个失败不影响其他
828
+ for account in accounts_to_refresh:
829
+ try:
830
+ success, message = await self.refresh_token_if_needed(account)
831
+ if not success:
832
+ print(f"[RefreshManager] 账号 {account.id} 自动刷新失败: {message}")
833
+ except Exception as e:
834
+ print(f"[RefreshManager] 账号 {account.id} 自动刷新异常: {e}")
835
+ # 继续处理其他账号
836
+
837
+ except asyncio.CancelledError:
838
+ break
839
+ except Exception as e:
840
+ print(f"[RefreshManager] 自动刷新循环异常: {e}")
841
+ # 继续运行,不因异常停止
842
+
843
+ def get_auto_refresh_status(self) -> Dict[str, Any]:
844
+ """获取自动刷新状态
845
+
846
+ Returns:
847
+ 包含自动刷新状态信息的字典
848
+ """
849
+ return {
850
+ "running": self.is_auto_refresh_running(),
851
+ "interval": self._config.auto_refresh_interval,
852
+ "token_refresh_before_expiry": self._config.token_refresh_before_expiry
853
+ }
854
+
855
+
856
+ # 全局刷新管理器实例
857
+ _refresh_manager: Optional[RefreshManager] = None
858
+ _manager_lock = Lock()
859
+
860
+
861
+ def get_refresh_manager() -> RefreshManager:
862
+ """获取全局刷新管理器实例
863
+
864
+ 使用单例模式,确保全局只有一个刷新管理器实例。
865
+
866
+ Returns:
867
+ 全局 RefreshManager 实例
868
+ """
869
+ global _refresh_manager
870
+
871
+ if _refresh_manager is None:
872
+ with _manager_lock:
873
+ # 双重检查锁定
874
+ if _refresh_manager is None:
875
+ _refresh_manager = RefreshManager()
876
+
877
+ return _refresh_manager
878
+
879
+
880
+ def reset_refresh_manager() -> None:
881
+ """重置全局刷新管理器
882
+
883
+ 主要用于测试场景,重置全局实例。
884
+ """
885
+ global _refresh_manager
886
+
887
+ with _manager_lock:
888
+ _refresh_manager = None
KiroProxy/kiro_proxy/core/retry.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """请求重试机制"""
2
+ import asyncio
3
+ from typing import Callable, Any, Optional, Set
4
+ from functools import wraps
5
+
6
+ # 可重试的状态码
7
+ RETRYABLE_STATUS_CODES: Set[int] = {
8
+ 408, # Request Timeout
9
+ 500, # Internal Server Error
10
+ 502, # Bad Gateway
11
+ 503, # Service Unavailable
12
+ 504, # Gateway Timeout
13
+ }
14
+
15
+ # 不可重试的状态码(直接返回错误)
16
+ NON_RETRYABLE_STATUS_CODES: Set[int] = {
17
+ 400, # Bad Request
18
+ 401, # Unauthorized
19
+ 403, # Forbidden
20
+ 404, # Not Found
21
+ 422, # Unprocessable Entity
22
+ }
23
+
24
+
25
+ def is_retryable_error(status_code: Optional[int], error: Optional[Exception] = None) -> bool:
26
+ """判断是否为可重试的错误"""
27
+ # 网络错误可重试
28
+ if error:
29
+ error_name = type(error).__name__.lower()
30
+ if any(kw in error_name for kw in ['timeout', 'connect', 'network', 'reset']):
31
+ return True
32
+
33
+ # 特定状态码可重试
34
+ if status_code and status_code in RETRYABLE_STATUS_CODES:
35
+ return True
36
+
37
+ return False
38
+
39
+
40
+ def is_non_retryable_error(status_code: Optional[int]) -> bool:
41
+ """判断是否为不可重试的错误"""
42
+ return status_code in NON_RETRYABLE_STATUS_CODES if status_code else False
43
+
44
+
45
+ async def retry_async(
46
+ func: Callable,
47
+ max_retries: int = 2,
48
+ base_delay: float = 0.5,
49
+ max_delay: float = 5.0,
50
+ on_retry: Optional[Callable[[int, Exception], None]] = None
51
+ ) -> Any:
52
+ """
53
+ 异步重试装饰器
54
+
55
+ Args:
56
+ func: 要执行的异步函数
57
+ max_retries: 最大重试次数
58
+ base_delay: 基础延迟(秒)
59
+ max_delay: 最大延迟(秒)
60
+ on_retry: 重试时的回调函数
61
+ """
62
+ last_error = None
63
+
64
+ for attempt in range(max_retries + 1):
65
+ try:
66
+ return await func()
67
+ except Exception as e:
68
+ last_error = e
69
+
70
+ # 检查是否可重试
71
+ status_code = getattr(e, 'status_code', None)
72
+ if is_non_retryable_error(status_code):
73
+ raise
74
+
75
+ if attempt < max_retries and is_retryable_error(status_code, e):
76
+ # 指数退避
77
+ delay = min(base_delay * (2 ** attempt), max_delay)
78
+
79
+ if on_retry:
80
+ on_retry(attempt + 1, e)
81
+ else:
82
+ print(f"[Retry] 第 {attempt + 1} 次重试,延迟 {delay:.1f}s,错误: {type(e).__name__}")
83
+
84
+ await asyncio.sleep(delay)
85
+ else:
86
+ raise
87
+
88
+ raise last_error
89
+
90
+
91
+ class RetryableRequest:
92
+ """可重试的请求上下文"""
93
+
94
+ def __init__(self, max_retries: int = 2, base_delay: float = 0.5):
95
+ self.max_retries = max_retries
96
+ self.base_delay = base_delay
97
+ self.attempt = 0
98
+ self.last_error = None
99
+
100
+ def should_retry(self, status_code: Optional[int] = None, error: Optional[Exception] = None) -> bool:
101
+ """判断是否应该重试"""
102
+ self.attempt += 1
103
+ self.last_error = error
104
+
105
+ if self.attempt > self.max_retries:
106
+ return False
107
+
108
+ if is_non_retryable_error(status_code):
109
+ return False
110
+
111
+ return is_retryable_error(status_code, error)
112
+
113
+ async def wait(self):
114
+ """等待重试延迟"""
115
+ delay = min(self.base_delay * (2 ** (self.attempt - 1)), 5.0)
116
+ print(f"[Retry] 第 {self.attempt} 次重试,延迟 {delay:.1f}s")
117
+ await asyncio.sleep(delay)
KiroProxy/kiro_proxy/core/scheduler.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """后台任务调度器"""
2
+ import asyncio
3
+ from typing import Optional
4
+ from datetime import datetime
5
+
6
+
7
+ class BackgroundScheduler:
8
+ """后台任务调度器
9
+
10
+ 负责:
11
+ - Token 过期预刷新
12
+ - 账号健康检查
13
+ - 统计数据更新
14
+ """
15
+
16
+ def __init__(self):
17
+ self._task: Optional[asyncio.Task] = None
18
+ self._running = False
19
+ self._refresh_interval = 300 # 5 分钟检查一次
20
+ self._health_check_interval = 600 # 10 分钟健康检查
21
+ self._last_health_check = 0
22
+
23
+ async def start(self):
24
+ """启动后台任务"""
25
+ if self._running:
26
+ return
27
+ self._running = True
28
+ self._task = asyncio.create_task(self._run())
29
+ print("[Scheduler] 后台任务已启动")
30
+
31
+ async def stop(self):
32
+ """停止后台任务"""
33
+ self._running = False
34
+ if self._task:
35
+ self._task.cancel()
36
+ try:
37
+ await self._task
38
+ except asyncio.CancelledError:
39
+ pass
40
+ print("[Scheduler] 后台任务已停止")
41
+
42
+ async def _run(self):
43
+ """主循环"""
44
+ from . import state
45
+ import time
46
+
47
+ while self._running:
48
+ try:
49
+ # Token 预刷新
50
+ await self._refresh_expiring_tokens(state)
51
+
52
+ # 健康检查
53
+ now = time.time()
54
+ if now - self._last_health_check > self._health_check_interval:
55
+ await self._health_check(state)
56
+ self._last_health_check = now
57
+
58
+ await asyncio.sleep(self._refresh_interval)
59
+
60
+ except asyncio.CancelledError:
61
+ break
62
+ except Exception as e:
63
+ print(f"[Scheduler] 错误: {e}")
64
+ await asyncio.sleep(60)
65
+
66
+ async def _refresh_expiring_tokens(self, state):
67
+ """刷新即将过期的 Token"""
68
+ for acc in state.accounts:
69
+ if not acc.enabled:
70
+ continue
71
+
72
+ # 提前 15 分钟刷新
73
+ if acc.is_token_expiring_soon(15):
74
+ print(f"[Scheduler] Token 即将过期,预刷新: {acc.name}")
75
+ success, msg = await acc.refresh_token()
76
+ if success:
77
+ print(f"[Scheduler] Token 刷新成功: {acc.name}")
78
+ else:
79
+ print(f"[Scheduler] Token 刷新失败: {acc.name} - {msg}")
80
+
81
+ async def _health_check(self, state):
82
+ """健康检查"""
83
+ import httpx
84
+ from ..config import MODELS_URL
85
+ from ..credential import CredentialStatus
86
+
87
+ for acc in state.accounts:
88
+ if not acc.enabled:
89
+ continue
90
+
91
+ try:
92
+ token = acc.get_token()
93
+ if not token:
94
+ acc.status = CredentialStatus.UNHEALTHY
95
+ continue
96
+
97
+ headers = {
98
+ "Authorization": f"Bearer {token}",
99
+ "content-type": "application/json"
100
+ }
101
+
102
+ async with httpx.AsyncClient(verify=False, timeout=10) as client:
103
+ resp = await client.get(
104
+ MODELS_URL,
105
+ headers=headers,
106
+ params={"origin": "AI_EDITOR"}
107
+ )
108
+
109
+ if resp.status_code == 200:
110
+ if acc.status == CredentialStatus.UNHEALTHY:
111
+ acc.status = CredentialStatus.ACTIVE
112
+ print(f"[HealthCheck] 账号恢复健康: {acc.name}")
113
+ elif resp.status_code == 401:
114
+ acc.status = CredentialStatus.UNHEALTHY
115
+ print(f"[HealthCheck] 账号认证失败: {acc.name}")
116
+ elif resp.status_code == 429:
117
+ # 配额超限,不改变状态
118
+ pass
119
+
120
+ except Exception as e:
121
+ print(f"[HealthCheck] 检查失败 {acc.name}: {e}")
122
+
123
+
124
+ # 全局调度器实例
125
+ scheduler = BackgroundScheduler()
KiroProxy/kiro_proxy/core/state.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """全局状态管理"""
2
+ import time
3
+ from collections import deque
4
+ from dataclasses import dataclass
5
+ from typing import Optional, List, Dict
6
+ from pathlib import Path
7
+
8
+ from ..config import TOKEN_PATH
9
+ from ..credential import quota_manager, CredentialStatus
10
+ from .account import Account
11
+ from .persistence import load_accounts, save_accounts
12
+ from .quota_cache import get_quota_cache
13
+ from .account_selector import get_account_selector, SelectionStrategy
14
+ from .quota_scheduler import get_quota_scheduler
15
+
16
+
17
+ @dataclass
18
+ class RequestLog:
19
+ """请求日志"""
20
+ id: str
21
+ timestamp: float
22
+ method: str
23
+ path: str
24
+ model: str
25
+ account_id: Optional[str]
26
+ status: int
27
+ duration_ms: float
28
+ tokens_in: int = 0
29
+ tokens_out: int = 0
30
+ error: Optional[str] = None
31
+
32
+
33
+ class ProxyState:
34
+ """全局状态管理"""
35
+
36
+ def __init__(self):
37
+ self.accounts: List[Account] = []
38
+ self.request_logs: deque = deque(maxlen=1000)
39
+ self.total_requests: int = 0
40
+ self.total_errors: int = 0
41
+ self.session_locks: Dict[str, str] = {}
42
+ self.session_timestamps: Dict[str, float] = {}
43
+ self.start_time: float = time.time()
44
+ self._load_accounts()
45
+
46
+ def _load_accounts(self):
47
+ """从配置文件加载账号"""
48
+ saved = load_accounts()
49
+ if saved:
50
+ for acc_data in saved:
51
+ # 验证 token 文件存在
52
+ if Path(acc_data.get("token_path", "")).exists():
53
+ self.accounts.append(Account(
54
+ id=acc_data["id"],
55
+ name=acc_data["name"],
56
+ token_path=acc_data["token_path"],
57
+ enabled=acc_data.get("enabled", True),
58
+ auto_disabled=acc_data.get("auto_disabled", False),
59
+ ))
60
+ print(f"[State] 从配置加载 {len(self.accounts)} 个账号")
61
+
62
+ # 如果没有账号,尝试添加默认账号
63
+ if not self.accounts and TOKEN_PATH.exists():
64
+ self.accounts.append(Account(
65
+ id="default",
66
+ name="默认账号",
67
+ token_path=str(TOKEN_PATH)
68
+ ))
69
+ self._save_accounts()
70
+
71
+ def _save_accounts(self):
72
+ """保存账号到配置文件"""
73
+ accounts_data = [
74
+ {
75
+ "id": acc.id,
76
+ "name": acc.name,
77
+ "token_path": acc.token_path,
78
+ "enabled": acc.enabled,
79
+ "auto_disabled": getattr(acc, "auto_disabled", False),
80
+ }
81
+ for acc in self.accounts
82
+ ]
83
+ save_accounts(accounts_data)
84
+
85
+ def get_available_account(self, session_id: Optional[str] = None) -> Optional[Account]:
86
+ """获取可用账号(支持会话粘性和智能选择)"""
87
+ quota_manager.cleanup_expired()
88
+
89
+ selector = get_account_selector()
90
+ has_priority = bool(selector.get_priority_accounts())
91
+ use_session_sticky = bool(session_id) and not has_priority and selector.strategy != SelectionStrategy.RANDOM
92
+
93
+ # 会话粘性
94
+ if use_session_sticky and session_id in self.session_locks:
95
+ account_id = self.session_locks[session_id]
96
+ ts = self.session_timestamps.get(session_id, 0)
97
+ if time.time() - ts < 60:
98
+ for acc in self.accounts:
99
+ if acc.id == account_id and acc.is_available():
100
+ self.session_timestamps[session_id] = time.time()
101
+ return acc
102
+
103
+ # 使用 AccountSelector 选择账号
104
+ account = selector.select(self.accounts, session_id)
105
+
106
+ if account and use_session_sticky:
107
+ self.session_locks[session_id] = account.id
108
+ self.session_timestamps[session_id] = time.time()
109
+
110
+ # 标记为活跃账号,便于额度调度器定期更新
111
+ if account:
112
+ try:
113
+ get_quota_scheduler().mark_active(account.id)
114
+ except Exception:
115
+ pass
116
+
117
+ return account
118
+
119
+ def mark_account_used(self, account_id: str) -> None:
120
+ """标记账号被使用"""
121
+ scheduler = get_quota_scheduler()
122
+ scheduler.mark_active(account_id)
123
+
124
+ for acc in self.accounts:
125
+ if acc.id == account_id:
126
+ acc.last_used = time.time()
127
+ break
128
+
129
+ def get_next_available_account(self, exclude_id: str) -> Optional[Account]:
130
+ """获取下一个可用账号(排除指定账号)"""
131
+ available = [a for a in self.accounts if a.is_available() and a.id != exclude_id]
132
+ if not available:
133
+ return None
134
+ account = min(available, key=lambda a: a.request_count)
135
+ try:
136
+ get_quota_scheduler().mark_active(account.id)
137
+ except Exception:
138
+ pass
139
+ return account
140
+
141
+ def mark_rate_limited(self, account_id: str, duration_seconds: int = 60):
142
+ """标记账号限流"""
143
+ for acc in self.accounts:
144
+ if acc.id == account_id:
145
+ acc.mark_quota_exceeded("Rate limited")
146
+ break
147
+
148
+ def mark_quota_exceeded(self, account_id: str, reason: str = "Quota exceeded"):
149
+ """标记账号配额超限"""
150
+ for acc in self.accounts:
151
+ if acc.id == account_id:
152
+ acc.mark_quota_exceeded(reason)
153
+ break
154
+
155
+ async def refresh_account_token(self, account_id: str) -> tuple:
156
+ """刷新指定账号的 token"""
157
+ for acc in self.accounts:
158
+ if acc.id == account_id:
159
+ return await acc.refresh_token()
160
+ return False, "账号不存在"
161
+
162
+ async def refresh_expiring_tokens(self) -> List[dict]:
163
+ """刷新所有即将过期的 token"""
164
+ results = []
165
+ for acc in self.accounts:
166
+ if acc.enabled and acc.is_token_expiring_soon(10):
167
+ success, msg = await acc.refresh_token()
168
+ results.append({
169
+ "account_id": acc.id,
170
+ "success": success,
171
+ "message": msg
172
+ })
173
+ return results
174
+
175
+ def add_log(self, log: RequestLog):
176
+ """添加请求日志"""
177
+ self.request_logs.append(log)
178
+ self.total_requests += 1
179
+ if log.error:
180
+ self.total_errors += 1
181
+
182
+ def get_stats(self) -> dict:
183
+ """获取统计信息"""
184
+ uptime = time.time() - self.start_time
185
+
186
+ # 获取额度汇总
187
+ quota_cache = get_quota_cache()
188
+ quota_summary = quota_cache.get_summary()
189
+
190
+ # 获取选择器状态
191
+ selector = get_account_selector()
192
+ selector_status = selector.get_status()
193
+
194
+ # 获取调度器状态
195
+ scheduler = get_quota_scheduler()
196
+ scheduler_status = scheduler.get_status()
197
+
198
+ return {
199
+ "uptime_seconds": int(uptime),
200
+ "total_requests": self.total_requests,
201
+ "total_errors": self.total_errors,
202
+ "error_rate": f"{(self.total_errors / max(1, self.total_requests) * 100):.1f}%",
203
+ "accounts_total": len(self.accounts),
204
+ "accounts_available": len([a for a in self.accounts if a.is_available()]),
205
+ "accounts_cooldown": len([a for a in self.accounts if a.status == CredentialStatus.COOLDOWN]),
206
+ "recent_logs": len(self.request_logs),
207
+ # 新增字段
208
+ "quota_summary": quota_summary,
209
+ "selector": selector_status,
210
+ "scheduler": scheduler_status,
211
+ }
212
+
213
+ def get_accounts_status(self) -> List[dict]:
214
+ """获取所有账号状态"""
215
+ return [acc.get_status_info() for acc in self.accounts]
216
+
217
+ def get_accounts_summary(self) -> dict:
218
+ """获取账号汇总统计"""
219
+ quota_cache = get_quota_cache()
220
+ selector = get_account_selector()
221
+ scheduler = get_quota_scheduler()
222
+
223
+ total_balance = 0.0
224
+ total_usage = 0.0
225
+ total_limit = 0.0
226
+
227
+ available_count = 0
228
+ cooldown_count = 0
229
+ unhealthy_count = 0
230
+ disabled_count = 0
231
+
232
+ for acc in self.accounts:
233
+ if not acc.enabled:
234
+ disabled_count += 1
235
+ elif acc.status == CredentialStatus.COOLDOWN:
236
+ cooldown_count += 1
237
+ elif acc.status == CredentialStatus.UNHEALTHY:
238
+ unhealthy_count += 1
239
+ elif acc.is_available():
240
+ available_count += 1
241
+
242
+ quota = quota_cache.get(acc.id)
243
+ if quota and not quota.has_error():
244
+ total_balance += quota.balance
245
+ total_usage += quota.current_usage
246
+ total_limit += quota.usage_limit
247
+
248
+ last_refresh = scheduler.get_last_full_refresh()
249
+ last_refresh_ago = None
250
+ if last_refresh:
251
+ seconds_ago = time.time() - last_refresh
252
+ if seconds_ago < 60:
253
+ last_refresh_ago = f"{int(seconds_ago)}秒前"
254
+ elif seconds_ago < 3600:
255
+ last_refresh_ago = f"{int(seconds_ago / 60)}分钟前"
256
+ else:
257
+ last_refresh_ago = f"{int(seconds_ago / 3600)}小时前"
258
+
259
+ return {
260
+ "total_accounts": len(self.accounts),
261
+ "available_accounts": available_count,
262
+ "cooldown_accounts": cooldown_count,
263
+ "unhealthy_accounts": unhealthy_count,
264
+ "disabled_accounts": disabled_count,
265
+ "total_balance": round(total_balance, 2),
266
+ "total_usage": round(total_usage, 2),
267
+ "total_limit": round(total_limit, 2),
268
+ "last_refresh": last_refresh_ago,
269
+ "last_refresh_timestamp": last_refresh,
270
+ "strategy": selector.strategy.value,
271
+ "priority_accounts": selector.get_priority_accounts(),
272
+ }
273
+
274
+ def get_valid_account_ids(self) -> set:
275
+ """获取所有有效账号ID集合"""
276
+ return {acc.id for acc in self.accounts if acc.enabled}
277
+
278
+
279
+ # 全局状态实例
280
+ state = ProxyState()
KiroProxy/kiro_proxy/core/stats.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """请求统计增强"""
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List
5
+ import time
6
+
7
+
8
+ @dataclass
9
+ class AccountStats:
10
+ """账号统计"""
11
+ total_requests: int = 0
12
+ total_errors: int = 0
13
+ total_tokens_in: int = 0
14
+ total_tokens_out: int = 0
15
+ last_request_time: float = 0
16
+
17
+ def record(self, success: bool, tokens_in: int = 0, tokens_out: int = 0):
18
+ self.total_requests += 1
19
+ if not success:
20
+ self.total_errors += 1
21
+ self.total_tokens_in += tokens_in
22
+ self.total_tokens_out += tokens_out
23
+ self.last_request_time = time.time()
24
+
25
+ @property
26
+ def error_rate(self) -> float:
27
+ if self.total_requests == 0:
28
+ return 0
29
+ return self.total_errors / self.total_requests
30
+
31
+
32
+ @dataclass
33
+ class ModelStats:
34
+ """模型统计"""
35
+ total_requests: int = 0
36
+ total_errors: int = 0
37
+ total_latency_ms: float = 0
38
+
39
+ def record(self, success: bool, latency_ms: float):
40
+ self.total_requests += 1
41
+ if not success:
42
+ self.total_errors += 1
43
+ self.total_latency_ms += latency_ms
44
+
45
+ @property
46
+ def avg_latency_ms(self) -> float:
47
+ if self.total_requests == 0:
48
+ return 0
49
+ return self.total_latency_ms / self.total_requests
50
+
51
+
52
+ class StatsManager:
53
+ """统计管理器"""
54
+
55
+ def __init__(self):
56
+ self.by_account: Dict[str, AccountStats] = defaultdict(AccountStats)
57
+ self.by_model: Dict[str, ModelStats] = defaultdict(ModelStats)
58
+ self.hourly_requests: Dict[int, int] = defaultdict(int) # hour -> count
59
+
60
+ def record_request(
61
+ self,
62
+ account_id: str,
63
+ model: str,
64
+ success: bool,
65
+ latency_ms: float,
66
+ tokens_in: int = 0,
67
+ tokens_out: int = 0
68
+ ):
69
+ """记录请求"""
70
+ # 按账号统计
71
+ self.by_account[account_id].record(success, tokens_in, tokens_out)
72
+
73
+ # 按模型统计
74
+ self.by_model[model].record(success, latency_ms)
75
+
76
+ # 按小时统计
77
+ hour = int(time.time() // 3600)
78
+ self.hourly_requests[hour] += 1
79
+
80
+ # 清理旧数据(保留 24 小时)
81
+ self._cleanup_hourly()
82
+
83
+ def _cleanup_hourly(self):
84
+ """清理超过 24 小时的数据"""
85
+ current_hour = int(time.time() // 3600)
86
+ cutoff = current_hour - 24
87
+ self.hourly_requests = defaultdict(
88
+ int,
89
+ {h: c for h, c in self.hourly_requests.items() if h > cutoff}
90
+ )
91
+
92
+ def get_account_stats(self, account_id: str) -> dict:
93
+ """获取账号统计"""
94
+ stats = self.by_account.get(account_id, AccountStats())
95
+ return {
96
+ "total_requests": stats.total_requests,
97
+ "total_errors": stats.total_errors,
98
+ "error_rate": f"{stats.error_rate * 100:.1f}%",
99
+ "total_tokens_in": stats.total_tokens_in,
100
+ "total_tokens_out": stats.total_tokens_out,
101
+ "last_request": stats.last_request_time
102
+ }
103
+
104
+ def get_model_stats(self, model: str) -> dict:
105
+ """获取模型统计"""
106
+ stats = self.by_model.get(model, ModelStats())
107
+ return {
108
+ "total_requests": stats.total_requests,
109
+ "total_errors": stats.total_errors,
110
+ "avg_latency_ms": round(stats.avg_latency_ms, 2)
111
+ }
112
+
113
+ def get_all_stats(self) -> dict:
114
+ """获取所有统计"""
115
+ return {
116
+ "by_account": {
117
+ acc_id: self.get_account_stats(acc_id)
118
+ for acc_id in self.by_account
119
+ },
120
+ "by_model": {
121
+ model: self.get_model_stats(model)
122
+ for model in self.by_model
123
+ },
124
+ "hourly_requests": dict(self.hourly_requests),
125
+ "requests_last_24h": sum(self.hourly_requests.values())
126
+ }
127
+
128
+
129
+ # 全局统计实例
130
+ stats_manager = StatsManager()
KiroProxy/kiro_proxy/core/thinking.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Thinking / Extended Thinking helpers.
2
+
3
+ This project implements "thinking" at the proxy layer by:
4
+ 1) Making a separate Kiro request to generate internal reasoning text.
5
+ 2) Injecting that reasoning back into the main user prompt (hidden) to improve quality.
6
+ 3) Optionally returning the reasoning to clients in protocol-appropriate formats.
7
+
8
+ Notes:
9
+ - Kiro's upstream API doesn't expose a native "thinking budget" knob, so `budget_tokens`
10
+ is enforced only via prompt instructions (best-effort).
11
+ - If the client does not provide a budget, we treat it as "unlimited" (no prompt limit).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, AsyncIterator, Optional
18
+
19
+ import json
20
+
21
+ import httpx
22
+
23
+ from ..config import KIRO_API_URL
24
+ from ..kiro_api import build_kiro_request, parse_event_stream
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class ThinkingConfig:
29
+ enabled: bool
30
+ budget_tokens: Optional[int] = None # None == unlimited
31
+
32
+
33
+ def _coerce_bool(value: Any) -> Optional[bool]:
34
+ if isinstance(value, bool):
35
+ return value
36
+ if isinstance(value, (int, float)):
37
+ return bool(value)
38
+ if isinstance(value, str):
39
+ v = value.strip().lower()
40
+ if v in {"true", "1", "yes", "y", "on", "enabled"}:
41
+ return True
42
+ if v in {"false", "0", "no", "n", "off", "disabled"}:
43
+ return False
44
+ return None
45
+
46
+
47
+ def _coerce_int(value: Any) -> Optional[int]:
48
+ if value is None:
49
+ return None
50
+ if isinstance(value, bool):
51
+ return None
52
+ if isinstance(value, int):
53
+ return value
54
+ if isinstance(value, float):
55
+ return int(value)
56
+ if isinstance(value, str):
57
+ v = value.strip()
58
+ if not v:
59
+ return None
60
+ try:
61
+ return int(v)
62
+ except ValueError:
63
+ return None
64
+ return None
65
+
66
+
67
+ def normalize_thinking_config(raw: Any) -> ThinkingConfig:
68
+ """Normalize multiple "thinking" shapes into a single config.
69
+
70
+ Supported shapes (best-effort):
71
+ - None / missing: disabled
72
+ - bool: enabled/disabled
73
+ - str: "enabled"/"disabled"
74
+ - dict:
75
+ - {"type": "enabled", "budget_tokens": 20000} (Anthropic style)
76
+ - {"thinking_type": "enabled", "budget_tokens": 20000} (legacy)
77
+ - {"enabled": true, "budget_tokens": 20000}
78
+ - {"includeThoughts": true, "thinkingBudget": 20000} (Gemini-ish)
79
+ """
80
+ if raw is None:
81
+ return ThinkingConfig(enabled=False, budget_tokens=None)
82
+
83
+ bool_value = _coerce_bool(raw)
84
+ if bool_value is not None and not isinstance(raw, dict):
85
+ return ThinkingConfig(enabled=bool_value, budget_tokens=None)
86
+
87
+ if isinstance(raw, dict):
88
+ mode = raw.get("type") or raw.get("thinking_type") or raw.get("mode")
89
+ enabled = None
90
+ if isinstance(mode, str):
91
+ enabled = _coerce_bool(mode)
92
+ if enabled is None:
93
+ enabled = _coerce_bool(raw.get("enabled"))
94
+ if enabled is None:
95
+ enabled = _coerce_bool(raw.get("includeThoughts") or raw.get("include_thoughts"))
96
+ if enabled is None:
97
+ enabled = False
98
+
99
+ budget_tokens = None
100
+ for key in (
101
+ "budget_tokens",
102
+ "budgetTokens",
103
+ "thinkingBudget",
104
+ "thinking_budget",
105
+ "max_thinking_length",
106
+ "maxThinkingLength",
107
+ ):
108
+ if key in raw:
109
+ budget_tokens = _coerce_int(raw.get(key))
110
+ break
111
+ if budget_tokens is not None and budget_tokens <= 0:
112
+ budget_tokens = None
113
+
114
+ return ThinkingConfig(enabled=bool(enabled), budget_tokens=budget_tokens)
115
+
116
+ if isinstance(raw, str):
117
+ enabled = _coerce_bool(raw)
118
+ return ThinkingConfig(enabled=bool(enabled), budget_tokens=None)
119
+
120
+ return ThinkingConfig(enabled=False, budget_tokens=None)
121
+
122
+
123
+ def map_openai_reasoning_effort_to_budget(effort: Any) -> Optional[int]:
124
+ """Map OpenAI-style reasoning effort into a best-effort budget.
125
+
126
+ We keep this generous; if effort is "high", treat as unlimited.
127
+ """
128
+ if not isinstance(effort, str):
129
+ return None
130
+ v = effort.strip().lower()
131
+ if v in {"high"}:
132
+ return None
133
+ if v in {"medium"}:
134
+ return 20000
135
+ if v in {"low"}:
136
+ return 10000
137
+ return None
138
+
139
+
140
+ def extract_thinking_config_from_openai_body(body: dict) -> tuple[ThinkingConfig, bool]:
141
+ """Extract thinking config from OpenAI ChatCompletions/Responses-style bodies."""
142
+ if not isinstance(body, dict):
143
+ return ThinkingConfig(False, None), False
144
+
145
+ if "thinking" in body:
146
+ return normalize_thinking_config(body.get("thinking")), True
147
+
148
+ # OpenAI Responses API style
149
+ reasoning = body.get("reasoning")
150
+ if "reasoning" in body:
151
+ if isinstance(reasoning, dict):
152
+ effort = reasoning.get("effort")
153
+ if isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}:
154
+ return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True
155
+ cfg = normalize_thinking_config(reasoning)
156
+ return cfg, True
157
+
158
+ effort = body.get("reasoning_effort")
159
+ if "reasoning_effort" in body and isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}:
160
+ return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True
161
+
162
+ return ThinkingConfig(False, None), False
163
+
164
+
165
+ def extract_thinking_config_from_gemini_body(body: dict) -> tuple[ThinkingConfig, bool]:
166
+ """Extract thinking config from Gemini generateContent bodies (best-effort)."""
167
+ if not isinstance(body, dict):
168
+ return ThinkingConfig(False, None), False
169
+
170
+ if "thinking" in body:
171
+ return normalize_thinking_config(body.get("thinking")), True
172
+
173
+ if "thinkingConfig" in body:
174
+ return normalize_thinking_config(body.get("thinkingConfig")), True
175
+
176
+ gen_cfg = body.get("generationConfig")
177
+ if isinstance(gen_cfg, dict):
178
+ if "thinkingConfig" in gen_cfg:
179
+ raw = gen_cfg.get("thinkingConfig")
180
+ cfg = normalize_thinking_config(raw)
181
+ if cfg.enabled:
182
+ return cfg, True
183
+ # Budget without explicit includeThoughts/mode: treat as enabled (client guidance exists)
184
+ if isinstance(raw, dict) and any(
185
+ k in raw for k in ("thinkingBudget", "budgetTokens", "budget_tokens", "max_thinking_length")
186
+ ):
187
+ return ThinkingConfig(True, cfg.budget_tokens), True
188
+ return cfg, True
189
+
190
+ return ThinkingConfig(False, None), False
191
+
192
+
193
+ def infer_thinking_from_anthropic_messages(messages: list[dict]) -> bool:
194
+ """推断历史消息中是否包含思维链内容,用于在客户端未明确指定时自动启用思维链"""
195
+ for msg in messages or []:
196
+ content = msg.get("content")
197
+ if not isinstance(content, list):
198
+ continue
199
+ for block in content:
200
+ if isinstance(block, dict):
201
+ # 检查标准的 thinking 块
202
+ if block.get("type") == "thinking":
203
+ return True
204
+ # 检查文本块中嵌入的 <thinking> 标签(assistant 消息中可能存在)
205
+ if block.get("type") == "text" and msg.get("role") == "assistant":
206
+ text = block.get("text", "")
207
+ if isinstance(text, str) and "<thinking>" in text and "</thinking>" in text:
208
+ return True
209
+ return False
210
+
211
+
212
+ def infer_thinking_from_openai_messages(messages: list[dict]) -> bool:
213
+ for msg in messages or []:
214
+ content = msg.get("content", "")
215
+ if isinstance(content, str):
216
+ if "<thinking>" in content and "</thinking>" in content:
217
+ return True
218
+ continue
219
+ if isinstance(content, list):
220
+ for part in content:
221
+ if isinstance(part, dict) and part.get("type") == "text":
222
+ text = part.get("text", "")
223
+ if "<thinking>" in text and "</thinking>" in text:
224
+ return True
225
+ return False
226
+
227
+
228
+ def infer_thinking_from_openai_responses_input(input_data: Any) -> bool:
229
+ """Infer thinking from OpenAI Responses API `input` payloads (best-effort)."""
230
+ if isinstance(input_data, str):
231
+ return "<thinking>" in input_data and "</thinking>" in input_data
232
+
233
+ if not isinstance(input_data, list):
234
+ return False
235
+
236
+ for item in input_data:
237
+ if not isinstance(item, dict):
238
+ continue
239
+ if item.get("type") != "message":
240
+ continue
241
+
242
+ content_list = item.get("content", []) or []
243
+ for c in content_list:
244
+ if isinstance(c, str):
245
+ if "<thinking>" in c and "</thinking>" in c:
246
+ return True
247
+ continue
248
+ if not isinstance(c, dict):
249
+ continue
250
+ c_type = c.get("type")
251
+ if c_type in {"input_text", "output_text", "text"}:
252
+ text = c.get("text", "")
253
+ if isinstance(text, str) and "<thinking>" in text and "</thinking>" in text:
254
+ return True
255
+ return False
256
+
257
+
258
+ def infer_thinking_from_gemini_contents(contents: list[dict]) -> bool:
259
+ for item in contents or []:
260
+ for part in item.get("parts", []) or []:
261
+ if isinstance(part, dict) and isinstance(part.get("text"), str):
262
+ text = part["text"]
263
+ if "<thinking>" in text and "</thinking>" in text:
264
+ return True
265
+ return False
266
+
267
+
268
+ import re
269
+
270
+ _THINKING_PATTERN = re.compile(r"<thinking>.*?</thinking>\s*", re.DOTALL)
271
+
272
+
273
+ def strip_thinking_from_text(text: str) -> str:
274
+ """Remove <thinking> blocks from text."""
275
+ if not text or not isinstance(text, str):
276
+ return text
277
+ return _THINKING_PATTERN.sub("", text).strip()
278
+
279
+
280
+ def strip_thinking_from_history(history: list) -> list:
281
+ """Return a copy of history with <thinking> blocks removed from all messages."""
282
+ if not history:
283
+ return []
284
+
285
+ cleaned = []
286
+ for msg in history:
287
+ if not isinstance(msg, dict):
288
+ cleaned.append(msg)
289
+ continue
290
+
291
+ new_msg = msg.copy()
292
+ content = msg.get("content")
293
+
294
+ if isinstance(content, str):
295
+ new_msg["content"] = strip_thinking_from_text(content)
296
+ elif isinstance(content, list):
297
+ new_content = []
298
+ for part in content:
299
+ if isinstance(part, dict) and part.get("type") == "text":
300
+ new_part = part.copy()
301
+ new_part["text"] = strip_thinking_from_text(part.get("text", ""))
302
+ new_content.append(new_part)
303
+ else:
304
+ new_content.append(part)
305
+ new_msg["content"] = new_content
306
+
307
+ cleaned.append(new_msg)
308
+
309
+ return cleaned
310
+
311
+
312
+ def format_thinking_block(thinking_content: str) -> str:
313
+ if thinking_content is None:
314
+ return ""
315
+ thinking_content = str(thinking_content).strip()
316
+ if not thinking_content:
317
+ return ""
318
+ return f"<thinking>\n{thinking_content}\n</thinking>"
319
+
320
+
321
+ def build_thinking_prompt(user_content: str, *, budget_tokens: Optional[int]) -> str:
322
+ """Build a separate prompt using Tree of Thoughts approach.
323
+
324
+ Use multiple expert perspectives to analyze the problem deeply.
325
+ """
326
+ if user_content is None:
327
+ user_content = ""
328
+
329
+ budget_str = ""
330
+ if budget_tokens:
331
+ budget_str = f" Budget: {budget_tokens} tokens."
332
+
333
+ return (
334
+ f"Think deeply and comprehensively about this problem.{budget_str}\n\n"
335
+ "Use the following approach:\n"
336
+ "1. Break down the problem into components\n"
337
+ "2. Consider multiple perspectives and solutions\n"
338
+ "3. Evaluate trade-offs and edge cases\n"
339
+ "4. Synthesize your analysis into a coherent response\n\n"
340
+ f"{user_content}"
341
+ )
342
+
343
+ def build_user_prompt_with_thinking(user_content: str, thinking_content: str) -> str:
344
+ """Inject thinking into the main prompt.
345
+
346
+ Minimal injection to avoid context pollution.
347
+ """
348
+ if user_content is None:
349
+ user_content = ""
350
+
351
+ thinking_block = format_thinking_block(thinking_content)
352
+ if not thinking_block:
353
+ return user_content
354
+
355
+ return f"{thinking_block}\n\n{user_content}"
356
+
357
+
358
+ async def iter_aws_event_stream_text(byte_iter: AsyncIterator[bytes]) -> AsyncIterator[str]:
359
+ """Yield incremental text content from AWS event-stream chunks."""
360
+ buffer = b""
361
+
362
+ async for chunk in byte_iter:
363
+ buffer += chunk
364
+
365
+ while len(buffer) >= 12:
366
+ total_len = int.from_bytes(buffer[0:4], "big")
367
+
368
+ if total_len <= 0:
369
+ return
370
+ if len(buffer) < total_len:
371
+ break
372
+
373
+ headers_len = int.from_bytes(buffer[4:8], "big")
374
+ payload_start = 12 + headers_len
375
+ payload_end = total_len - 4
376
+
377
+ if payload_start < payload_end:
378
+ try:
379
+ payload = json.loads(buffer[payload_start:payload_end].decode("utf-8"))
380
+ content = None
381
+ if "assistantResponseEvent" in payload:
382
+ content = payload["assistantResponseEvent"].get("content")
383
+ elif "content" in payload and "toolUseId" not in payload:
384
+ content = payload.get("content")
385
+ if content:
386
+ yield content
387
+ except Exception:
388
+ pass
389
+
390
+ buffer = buffer[total_len:]
391
+
392
+
393
+ async def fetch_thinking_text(
394
+ *,
395
+ headers: dict,
396
+ model: str,
397
+ user_content: str,
398
+ history: list,
399
+ images: list | None = None,
400
+ tool_results: list | None = None,
401
+ budget_tokens: Optional[int] = None,
402
+ timeout_s: float = 600.0,
403
+ ) -> str:
404
+ """Non-streaming helper to get thinking content (best-effort)."""
405
+ thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens)
406
+ clean_history = strip_thinking_from_history(history)
407
+ thinking_request = build_kiro_request(
408
+ thinking_prompt,
409
+ model,
410
+ clean_history,
411
+ tools=None,
412
+ images=images,
413
+ tool_results=tool_results,
414
+ )
415
+
416
+ try:
417
+ async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client:
418
+ resp = await client.post(KIRO_API_URL, json=thinking_request, headers=headers)
419
+ if resp.status_code != 200:
420
+ return ""
421
+ return parse_event_stream(resp.content)
422
+ except Exception:
423
+ return ""
424
+
425
+
426
+ async def stream_thinking_text(
427
+ *,
428
+ headers: dict,
429
+ model: str,
430
+ user_content: str,
431
+ history: list,
432
+ images: list | None = None,
433
+ tool_results: list | None = None,
434
+ budget_tokens: Optional[int] = None,
435
+ timeout_s: float = 600.0,
436
+ ) -> AsyncIterator[str]:
437
+ """Streaming helper to yield thinking content incrementally (best-effort)."""
438
+ thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens)
439
+ clean_history = strip_thinking_from_history(history)
440
+ thinking_request = build_kiro_request(
441
+ thinking_prompt,
442
+ model,
443
+ clean_history,
444
+ tools=None,
445
+ images=images,
446
+ tool_results=tool_results,
447
+ )
448
+
449
+ async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client:
450
+ async with client.stream(
451
+ "POST", KIRO_API_URL, json=thinking_request, headers=headers
452
+ ) as response:
453
+ if response.status_code != 200:
454
+ return
455
+ async for piece in iter_aws_event_stream_text(response.aiter_bytes()):
456
+ yield piece
KiroProxy/kiro_proxy/core/usage.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kiro 用量查询服务
2
+
3
+ 通过调用 AWS Q 的 getUsageLimits API 获取用户的用量信息。
4
+ """
5
+ import uuid
6
+ import httpx
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Tuple
9
+
10
+
11
+ # API 端点
12
+ USAGE_LIMITS_URL = "https://q.us-east-1.amazonaws.com/getUsageLimits"
13
+
14
+ # 低余额阈值 (20%)
15
+ LOW_BALANCE_THRESHOLD = 0.2
16
+
17
+
18
+ @dataclass
19
+ class UsageInfo:
20
+ """用量信息"""
21
+ subscription_title: str = ""
22
+ usage_limit: float = 0.0
23
+ current_usage: float = 0.0
24
+ balance: float = 0.0
25
+ is_low_balance: bool = False
26
+
27
+ # 详细信息
28
+ free_trial_limit: float = 0.0
29
+ free_trial_usage: float = 0.0
30
+ bonus_limit: float = 0.0
31
+ bonus_usage: float = 0.0
32
+
33
+ # 重置和过期时间
34
+ next_reset_date: Optional[str] = None # 下次重置时间
35
+ free_trial_expiry: Optional[str] = None # 免费试用过期时间
36
+ bonus_expiries: list = None # 奖励过期时间列表
37
+
38
+ def __post_init__(self):
39
+ if self.bonus_expiries is None:
40
+ self.bonus_expiries = []
41
+
42
+
43
+ def build_usage_api_url(auth_method: str, profile_arn: Optional[str] = None) -> str:
44
+ """构造 API 请求 URL"""
45
+ url = f"{USAGE_LIMITS_URL}?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST"
46
+
47
+ # Social 认证需要 profileArn
48
+ if auth_method == "social" and profile_arn:
49
+ from urllib.parse import quote
50
+ url += f"&profileArn={quote(profile_arn)}"
51
+
52
+ return url
53
+
54
+
55
+ def build_usage_headers(
56
+ access_token: str,
57
+ machine_id: str,
58
+ kiro_version: str = "1.0.0"
59
+ ) -> dict:
60
+ """构造请求头"""
61
+ import platform
62
+ os_name = platform.system().lower()
63
+
64
+ return {
65
+ "Authorization": f"Bearer {access_token}",
66
+ "User-Agent": f"aws-sdk-js/1.0.0 ua/2.1 os/{os_name} lang/python api/codewhispererruntime#1.0.0 m/N,E KiroIDE-{kiro_version}-{machine_id}",
67
+ "x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}",
68
+ "amz-sdk-invocation-id": str(uuid.uuid4()),
69
+ "amz-sdk-request": "attempt=1; max=1",
70
+ "Connection": "close",
71
+ }
72
+
73
+
74
+ def calculate_balance(response: dict) -> UsageInfo:
75
+ """从 API 响应计算余额
76
+
77
+ 注意:只计算 resourceType 为 CREDIT 的额度,忽略其他类型(如 AGENTIC_REQUEST)
78
+ """
79
+ subscription_info = response.get("subscriptionInfo", {})
80
+ usage_breakdown_list = response.get("usageBreakdownList", [])
81
+
82
+ total_limit = 0.0
83
+ total_usage = 0.0
84
+ free_trial_limit = 0.0
85
+ free_trial_usage = 0.0
86
+ bonus_limit = 0.0
87
+ bonus_usage = 0.0
88
+
89
+ # 重置和过期时间
90
+ next_reset_date = response.get("nextDateReset") # 下次重置时间
91
+ free_trial_expiry = None
92
+ bonus_expiries = []
93
+
94
+ # 只查找 CREDIT 类型的额度
95
+ credit_breakdown = None
96
+ for breakdown in usage_breakdown_list:
97
+ resource_type = breakdown.get("resourceType", "")
98
+ display_name = breakdown.get("displayName", "")
99
+ if resource_type == "CREDIT" or display_name == "Credits":
100
+ credit_breakdown = breakdown
101
+ break
102
+
103
+ if credit_breakdown:
104
+ # 基本额度 (优先使用带精度的值)
105
+ total_limit = credit_breakdown.get("usageLimitWithPrecision", 0.0) or credit_breakdown.get("usageLimit", 0.0)
106
+ total_usage = credit_breakdown.get("currentUsageWithPrecision", 0.0) or credit_breakdown.get("currentUsage", 0.0)
107
+
108
+ # 免费试用额度 (只有状态为 ACTIVE 时才计算)
109
+ free_trial = credit_breakdown.get("freeTrialInfo")
110
+ if free_trial and free_trial.get("freeTrialStatus") == "ACTIVE":
111
+ ft_limit = free_trial.get("usageLimitWithPrecision", 0.0) or free_trial.get("usageLimit", 0.0)
112
+ ft_usage = free_trial.get("currentUsageWithPrecision", 0.0) or free_trial.get("currentUsage", 0.0)
113
+ total_limit += ft_limit
114
+ total_usage += ft_usage
115
+ free_trial_limit = ft_limit
116
+ free_trial_usage = ft_usage
117
+ # 获取免费试用过期时间
118
+ free_trial_expiry = free_trial.get("freeTrialExpiry")
119
+
120
+ # 奖励额度 (只计算状态为 ACTIVE 的奖励)
121
+ bonuses = credit_breakdown.get("bonuses", [])
122
+ for bonus in bonuses or []:
123
+ if bonus.get("status") == "ACTIVE":
124
+ b_limit = bonus.get("usageLimitWithPrecision", 0.0) or bonus.get("usageLimit", 0.0)
125
+ b_usage = bonus.get("currentUsageWithPrecision", 0.0) or bonus.get("currentUsage", 0.0)
126
+ total_limit += b_limit
127
+ total_usage += b_usage
128
+ bonus_limit += b_limit
129
+ bonus_usage += b_usage
130
+ # 获取奖励过期时间
131
+ expires_at = bonus.get("expiresAt")
132
+ if expires_at:
133
+ bonus_expiries.append(expires_at)
134
+
135
+ balance = total_limit - total_usage
136
+ is_low = (balance / total_limit) < LOW_BALANCE_THRESHOLD if total_limit > 0 else False
137
+
138
+ return UsageInfo(
139
+ subscription_title=subscription_info.get("subscriptionTitle", "Unknown"),
140
+ usage_limit=total_limit,
141
+ current_usage=total_usage,
142
+ balance=balance,
143
+ is_low_balance=is_low,
144
+ free_trial_limit=free_trial_limit,
145
+ free_trial_usage=free_trial_usage,
146
+ bonus_limit=bonus_limit,
147
+ bonus_usage=bonus_usage,
148
+ next_reset_date=next_reset_date,
149
+ free_trial_expiry=free_trial_expiry,
150
+ bonus_expiries=bonus_expiries,
151
+ )
152
+
153
+
154
+ async def get_usage_limits(
155
+ access_token: str,
156
+ auth_method: str = "social",
157
+ profile_arn: Optional[str] = None,
158
+ machine_id: str = "",
159
+ kiro_version: str = "1.0.0",
160
+ ) -> Tuple[bool, UsageInfo | dict]:
161
+ """
162
+ 获取 Kiro 用量信息
163
+
164
+ Args:
165
+ access_token: Bearer token
166
+ auth_method: 认证方式 ("social" 或 "idc")
167
+ profile_arn: Social 认证需要的 profileArn
168
+ machine_id: 设备 ID
169
+ kiro_version: Kiro 版本号
170
+
171
+ Returns:
172
+ (success, UsageInfo or error_dict)
173
+ """
174
+ if not access_token:
175
+ return False, {"error": "缺少 access token"}
176
+
177
+ if not machine_id:
178
+ return False, {"error": "缺少 machine ID"}
179
+
180
+ # 构造 URL 和请求头
181
+ url = build_usage_api_url(auth_method, profile_arn)
182
+ headers = build_usage_headers(access_token, machine_id, kiro_version)
183
+
184
+ try:
185
+ async with httpx.AsyncClient(timeout=10, verify=False) as client:
186
+ response = await client.get(url, headers=headers)
187
+
188
+ if response.status_code != 200:
189
+ return False, {"error": f"API 请求失败: {response.status_code} - {response.text[:200]}"}
190
+
191
+ data = response.json()
192
+ usage_info = calculate_balance(data)
193
+ return True, usage_info
194
+
195
+ except httpx.TimeoutException:
196
+ return False, {"error": "请求超时"}
197
+ except Exception as e:
198
+ return False, {"error": f"请求失败: {str(e)}"}
199
+
200
+
201
+ async def get_account_usage(account) -> Tuple[bool, UsageInfo | dict]:
202
+ """
203
+ 获取指定账号的用量信息
204
+
205
+ Args:
206
+ account: Account 对象
207
+
208
+ Returns:
209
+ (success, UsageInfo or error_dict)
210
+ """
211
+ from ..credential import get_kiro_version
212
+ from .refresh_manager import get_refresh_manager
213
+
214
+ creds = account.get_credentials()
215
+ if not creds:
216
+ return False, {"error": "无法获取凭证"}
217
+
218
+ # 先刷新 Token(如即将过期/已过期),避免额度获取失败
219
+ refresh_manager = get_refresh_manager()
220
+ if refresh_manager.should_refresh_token(account):
221
+ token_success, token_msg = await refresh_manager.refresh_token_if_needed(account)
222
+ if not token_success:
223
+ return False, {"error": f"Token 刷新失败: {token_msg}"}
224
+
225
+ token = account.get_token()
226
+ if not token:
227
+ return False, {"error": "无法获取 token"}
228
+
229
+ return await get_usage_limits(
230
+ access_token=token,
231
+ auth_method=creds.auth_method or "social",
232
+ profile_arn=creds.profile_arn,
233
+ machine_id=account.get_machine_id(),
234
+ kiro_version=get_kiro_version(),
235
+ )
KiroProxy/kiro_proxy/credential/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """凭证管理模块"""
2
+ from .fingerprint import generate_machine_id, get_kiro_version, get_system_info
3
+ from .quota import QuotaManager, QuotaRecord, quota_manager
4
+ from .refresher import TokenRefresher
5
+ from .types import KiroCredentials, CredentialStatus
6
+
7
+ __all__ = [
8
+ "generate_machine_id",
9
+ "get_kiro_version",
10
+ "get_system_info",
11
+ "QuotaManager",
12
+ "QuotaRecord",
13
+ "quota_manager",
14
+ "TokenRefresher",
15
+ "KiroCredentials",
16
+ "CredentialStatus",
17
+ ]
KiroProxy/kiro_proxy/credential/fingerprint.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """设备指纹生成"""
2
+ import hashlib
3
+ import platform
4
+ import subprocess
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+
10
+ def get_raw_machine_id() -> Optional[str]:
11
+ """获取系统原始 Machine ID"""
12
+ system = platform.system()
13
+
14
+ try:
15
+ if system == "Darwin":
16
+ result = subprocess.run(
17
+ ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice"],
18
+ capture_output=True, text=True, timeout=5
19
+ )
20
+ for line in result.stdout.split("\n"):
21
+ if "IOPlatformUUID" in line:
22
+ return line.split("=")[1].strip().strip('"').lower()
23
+
24
+ elif system == "Linux":
25
+ for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"]:
26
+ if Path(path).exists():
27
+ return Path(path).read_text().strip().lower()
28
+
29
+ elif system == "Windows":
30
+ result = subprocess.run(
31
+ ["wmic", "csproduct", "get", "UUID"],
32
+ capture_output=True, text=True, timeout=5,
33
+ creationflags=0x08000000
34
+ )
35
+ lines = [l.strip() for l in result.stdout.split("\n") if l.strip()]
36
+ if len(lines) > 1:
37
+ return lines[1].lower()
38
+ except Exception:
39
+ pass
40
+
41
+ return None
42
+
43
+
44
+ def generate_machine_id(
45
+ profile_arn: Optional[str] = None,
46
+ client_id: Optional[str] = None
47
+ ) -> str:
48
+ """生成基于凭证的唯一 Machine ID
49
+
50
+ 每个凭证生成独立的 Machine ID,避免多账号共用同一指纹被检测。
51
+ 优先级:profileArn > clientId > 系统硬件 ID
52
+ 添加时间因子:按小时变化,避免指纹完全固化。
53
+ """
54
+ unique_key = None
55
+ if profile_arn:
56
+ unique_key = profile_arn
57
+ elif client_id:
58
+ unique_key = client_id
59
+ else:
60
+ unique_key = get_raw_machine_id() or "KIRO_DEFAULT_MACHINE"
61
+
62
+ hour_slot = int(time.time()) // 3600
63
+
64
+ hasher = hashlib.sha256()
65
+ hasher.update(unique_key.encode())
66
+ hasher.update(hour_slot.to_bytes(8, 'little'))
67
+
68
+ return hasher.hexdigest()
69
+
70
+
71
+ def get_kiro_version() -> str:
72
+ """获取 Kiro IDE 版本号
73
+
74
+ 优先检测本地安装的 Kiro,否则使用默认版本 (与 kiro.rs 保持一致)
75
+ """
76
+ if platform.system() == "Darwin":
77
+ kiro_paths = [
78
+ "/Applications/Kiro.app/Contents/Info.plist",
79
+ str(Path.home() / "Applications/Kiro.app/Contents/Info.plist"),
80
+ ]
81
+ for plist_path in kiro_paths:
82
+ try:
83
+ result = subprocess.run(
84
+ ["defaults", "read", plist_path, "CFBundleShortVersionString"],
85
+ capture_output=True, text=True, timeout=5
86
+ )
87
+ version = result.stdout.strip()
88
+ if version:
89
+ return version
90
+ except Exception:
91
+ pass
92
+
93
+ # 默认版本与 kiro.rs 保持一致
94
+ return "0.8.0"
95
+
96
+
97
+ def get_system_info() -> tuple:
98
+ """获取系统运行时信息 (os_name, node_version)
99
+
100
+ node_version 与 kiro.rs 保持一致
101
+ """
102
+ system = platform.system()
103
+
104
+ if system == "Darwin":
105
+ try:
106
+ result = subprocess.run(
107
+ ["sw_vers", "-productVersion"],
108
+ capture_output=True, text=True, timeout=5
109
+ )
110
+ version = result.stdout.strip() or "14.0"
111
+ os_name = f"macos#{version}"
112
+ except Exception:
113
+ os_name = "macos#14.0"
114
+ elif system == "Linux":
115
+ try:
116
+ result = subprocess.run(
117
+ ["uname", "-r"],
118
+ capture_output=True, text=True, timeout=5
119
+ )
120
+ version = result.stdout.strip() or "5.15.0"
121
+ os_name = f"linux#{version}"
122
+ except Exception:
123
+ os_name = "linux#5.15.0"
124
+ elif system == "Windows":
125
+ os_name = "windows#10.0"
126
+ else:
127
+ os_name = "other#1.0"
128
+
129
+ # Node 版本与 kiro.rs 保持一致
130
+ node_version = "22.11.0"
131
+ return os_name, node_version
KiroProxy/kiro_proxy/credential/quota.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """配额管理"""
2
+ import time
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional
5
+
6
+
7
+ @dataclass
8
+ class QuotaRecord:
9
+ """配额超限记录"""
10
+ credential_id: str
11
+ exceeded_at: float
12
+ cooldown_until: float
13
+ reason: str
14
+
15
+
16
+ class QuotaManager:
17
+ """配额管理器
18
+
19
+ 管理凭证的配额超限状态:
20
+ - 仅在收到 429 错误时触发冷却
21
+ - 自动管理冷却时间:固定 5 分钟(300秒)
22
+ - 自动清理过期的冷却状态
23
+ """
24
+
25
+ # 固定冷却时间(秒)- 429 错误自动冷却 5 分钟
26
+ COOLDOWN_SECONDS = 300
27
+
28
+ def __init__(self):
29
+ self.exceeded_records: Dict[str, QuotaRecord] = {}
30
+
31
+ def is_429_error(self, status_code: Optional[int]) -> bool:
32
+ """检查是否为 429 错误(仅 429 触发冷却)"""
33
+ return status_code == 429
34
+
35
+ def is_quota_exceeded_error(self, status_code: Optional[int], error_message: str) -> bool:
36
+ """检查是否为配额超限错误(仅用于判断是否切换账号,不触发冷却)"""
37
+ # 仅 429 才算配额超限
38
+ return status_code == 429
39
+
40
+ def mark_exceeded(self, credential_id: str, reason: str) -> QuotaRecord:
41
+ """标记凭证为配额超限(仅 429 时调用)
42
+
43
+ 自动管理冷却时间:固定 5 分钟(300秒)
44
+ """
45
+ now = time.time()
46
+
47
+ record = QuotaRecord(
48
+ credential_id=credential_id,
49
+ exceeded_at=now,
50
+ cooldown_until=now + self.COOLDOWN_SECONDS,
51
+ reason=reason
52
+ )
53
+ self.exceeded_records[credential_id] = record
54
+
55
+ print(f"[QuotaManager] 账号 {credential_id} 遇到 429 错误,自动冷却 {self.COOLDOWN_SECONDS} 秒(5分钟)")
56
+ return record
57
+
58
+ def is_available(self, credential_id: str) -> bool:
59
+ """检查凭证是否可用"""
60
+ record = self.exceeded_records.get(credential_id)
61
+ if not record:
62
+ return True
63
+
64
+ if time.time() >= record.cooldown_until:
65
+ del self.exceeded_records[credential_id]
66
+ return True
67
+
68
+ return False
69
+
70
+ def get_cooldown_remaining(self, credential_id: str) -> Optional[int]:
71
+ """获取剩余冷却时间(秒)"""
72
+ record = self.exceeded_records.get(credential_id)
73
+ if not record:
74
+ return None
75
+
76
+ remaining = record.cooldown_until - time.time()
77
+ if remaining <= 0:
78
+ del self.exceeded_records[credential_id]
79
+ return None
80
+
81
+ return int(remaining)
82
+
83
+ def cleanup_expired(self) -> int:
84
+ """清理过期的冷却记录"""
85
+ now = time.time()
86
+ expired = [k for k, v in self.exceeded_records.items() if now >= v.cooldown_until]
87
+ for k in expired:
88
+ del self.exceeded_records[k]
89
+ return len(expired)
90
+
91
+ def restore(self, credential_id: str) -> bool:
92
+ """手动恢复凭证"""
93
+ if credential_id in self.exceeded_records:
94
+ del self.exceeded_records[credential_id]
95
+ return True
96
+ return False
97
+
98
+
99
+ # 全局实例 - 429 自动冷却 5 分钟
100
+ quota_manager = QuotaManager()
KiroProxy/kiro_proxy/credential/refresher.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token 刷新器"""
2
+ import httpx
3
+ from datetime import datetime, timezone, timedelta
4
+ from typing import Tuple
5
+
6
+ from .types import KiroCredentials
7
+ from .fingerprint import generate_machine_id, get_kiro_version
8
+
9
+
10
+ # Kiro Auth 端点
11
+ KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev"
12
+
13
+
14
+ class TokenRefresher:
15
+ """Token 刷新器"""
16
+
17
+ def __init__(self, credentials: KiroCredentials):
18
+ self.credentials = credentials
19
+
20
+ def get_refresh_url(self) -> str:
21
+ """获取刷新 URL"""
22
+ region = self.credentials.region or "us-east-1"
23
+ auth_method = (self.credentials.auth_method or "social").lower()
24
+
25
+ if auth_method == "idc":
26
+ # IDC (AWS Builder ID) 使用 OIDC 端点
27
+ return f"https://oidc.{region}.amazonaws.com/token"
28
+ else:
29
+ # Social (Google/GitHub) 使用 Kiro Auth 端点
30
+ return f"{KIRO_AUTH_ENDPOINT}/refreshToken"
31
+
32
+ def validate_refresh_token(self) -> Tuple[bool, str]:
33
+ """验证 refresh_token 有效性"""
34
+ refresh_token = self.credentials.refresh_token
35
+
36
+ if not refresh_token:
37
+ return False, "缺少 refresh_token"
38
+
39
+ if len(refresh_token.strip()) == 0:
40
+ return False, "refresh_token 为空"
41
+
42
+ if len(refresh_token) < 100 or refresh_token.endswith("..."):
43
+ return False, f"refresh_token 已被截断(长度: {len(refresh_token)})"
44
+
45
+ return True, ""
46
+
47
+ def _get_machine_id(self) -> str:
48
+ """获取 Machine ID"""
49
+ return generate_machine_id(
50
+ self.credentials.profile_arn,
51
+ self.credentials.client_id
52
+ )
53
+
54
+ async def refresh_social_token(self) -> Tuple[bool, str]:
55
+ """
56
+ 刷新 Social Token (Google/GitHub)
57
+
58
+ 参考 Kiro-account-manager 实现:
59
+ - 端点: https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken
60
+ - 请求体: {"refreshToken": refresh_token}
61
+ - 响应: {accessToken, refreshToken, expiresIn}
62
+ """
63
+ refresh_url = f"{KIRO_AUTH_ENDPOINT}/refreshToken"
64
+
65
+ body = {"refreshToken": self.credentials.refresh_token}
66
+ headers = {
67
+ "Content-Type": "application/json",
68
+ "User-Agent": "kiro-proxy/1.0.0",
69
+ "Accept": "application/json",
70
+ }
71
+
72
+ try:
73
+ async with httpx.AsyncClient(verify=False, timeout=30) as client:
74
+ resp = await client.post(refresh_url, json=body, headers=headers)
75
+
76
+ if resp.status_code != 200:
77
+ error_text = resp.text
78
+ if resp.status_code == 401:
79
+ return False, "凭证已过期或无效,需要重新登录"
80
+ elif resp.status_code == 429:
81
+ return False, "请求过于频繁,请稍后重试"
82
+ else:
83
+ return False, f"刷新失败: {resp.status_code} - {error_text[:200]}"
84
+
85
+ data = resp.json()
86
+
87
+ new_token = data.get("accessToken")
88
+ if not new_token:
89
+ return False, "响应中没有 accessToken"
90
+
91
+ # 更新凭证
92
+ self.credentials.access_token = new_token
93
+
94
+ # 更新 refreshToken(如果服务器返回了新的)
95
+ if rt := data.get("refreshToken"):
96
+ self.credentials.refresh_token = rt
97
+
98
+ # 更新过期时间
99
+ if expires_in := data.get("expiresIn"):
100
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
101
+ self.credentials.expires_at = expires_at.isoformat()
102
+
103
+ self.credentials.last_refresh = datetime.now(timezone.utc).isoformat()
104
+
105
+ print(f"[TokenRefresher] Social token 刷新成功,过期时间: {expires_in}s")
106
+ return True, new_token
107
+
108
+ except Exception as e:
109
+ return False, f"刷新异常: {str(e)}"
110
+
111
+ async def refresh_idc_token(self) -> Tuple[bool, str]:
112
+ """
113
+ 刷新 IDC Token (AWS Builder ID)
114
+
115
+ 使用 AWS OIDC 端点刷新
116
+ """
117
+ region = self.credentials.region or "us-east-1"
118
+ refresh_url = f"https://oidc.{region}.amazonaws.com/token"
119
+
120
+ if not self.credentials.client_id or not self.credentials.client_secret:
121
+ return False, "IdC 认证缺少 client_id 或 client_secret"
122
+
123
+ machine_id = self._get_machine_id()
124
+ kiro_version = get_kiro_version()
125
+
126
+ body = {
127
+ "refreshToken": self.credentials.refresh_token,
128
+ "clientId": self.credentials.client_id,
129
+ "clientSecret": self.credentials.client_secret,
130
+ "grantType": "refresh_token"
131
+ }
132
+ headers = {
133
+ "Content-Type": "application/json",
134
+ "x-amz-user-agent": f"aws-sdk-js/3.738.0 KiroIDE-{kiro_version}-{machine_id}",
135
+ "User-Agent": "node",
136
+ }
137
+
138
+ try:
139
+ async with httpx.AsyncClient(verify=False, timeout=30) as client:
140
+ resp = await client.post(refresh_url, json=body, headers=headers)
141
+
142
+ if resp.status_code != 200:
143
+ error_text = resp.text
144
+ if resp.status_code == 401:
145
+ return False, "凭证已过期或无效,需要重新登录"
146
+ elif resp.status_code == 429:
147
+ return False, "请求过于频繁,请稍后重试"
148
+ else:
149
+ return False, f"刷新失败: {resp.status_code} - {error_text[:200]}"
150
+
151
+ data = resp.json()
152
+
153
+ new_token = data.get("accessToken") or data.get("access_token")
154
+ if not new_token:
155
+ return False, "响应中没有 access_token"
156
+
157
+ # 更新凭证
158
+ self.credentials.access_token = new_token
159
+
160
+ if rt := data.get("refreshToken") or data.get("refresh_token"):
161
+ self.credentials.refresh_token = rt
162
+
163
+ if arn := data.get("profileArn"):
164
+ self.credentials.profile_arn = arn
165
+
166
+ if expires_in := data.get("expiresIn") or data.get("expires_in"):
167
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
168
+ self.credentials.expires_at = expires_at.isoformat()
169
+
170
+ self.credentials.last_refresh = datetime.now(timezone.utc).isoformat()
171
+
172
+ print(f"[TokenRefresher] IDC token 刷新成功")
173
+ return True, new_token
174
+
175
+ except Exception as e:
176
+ return False, f"刷新异常: {str(e)}"
177
+
178
+ async def refresh(self) -> Tuple[bool, str]:
179
+ """
180
+ 刷新 token,根据 authMethod 分发到正确的刷新方法
181
+
182
+ Returns:
183
+ (success, new_token_or_error)
184
+ """
185
+ is_valid, error = self.validate_refresh_token()
186
+ if not is_valid:
187
+ return False, error
188
+
189
+ auth_method = (self.credentials.auth_method or "social").lower()
190
+
191
+ if auth_method == "idc":
192
+ return await self.refresh_idc_token()
193
+ else:
194
+ # social 或其他默认使用 social 刷新
195
+ return await self.refresh_social_token()
KiroProxy/kiro_proxy/credential/types.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """凭证数据类型"""
2
+ import json
3
+ import time
4
+ from dataclasses import dataclass
5
+ from datetime import datetime, timezone, timedelta
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+
11
+ class CredentialStatus(Enum):
12
+ """凭证状态"""
13
+ ACTIVE = "active"
14
+ COOLDOWN = "cooldown"
15
+ UNHEALTHY = "unhealthy"
16
+ DISABLED = "disabled"
17
+ SUSPENDED = "suspended" # 账号被封禁
18
+
19
+
20
+ @dataclass
21
+ class KiroCredentials:
22
+ """Kiro 凭证信息"""
23
+ access_token: Optional[str] = None
24
+ refresh_token: Optional[str] = None
25
+ client_id: Optional[str] = None
26
+ client_secret: Optional[str] = None
27
+ profile_arn: Optional[str] = None
28
+ expires_at: Optional[str] = None
29
+ region: str = "us-east-1"
30
+ auth_method: str = "social"
31
+ provider: Optional[str] = None # Google / Github (社交登录提供商)
32
+ client_id_hash: Optional[str] = None
33
+ last_refresh: Optional[str] = None
34
+
35
+ @classmethod
36
+ def from_file(cls, path: str) -> "KiroCredentials":
37
+ """从文件加载凭证"""
38
+ with open(path) as f:
39
+ data = json.load(f)
40
+
41
+ return cls(
42
+ access_token=data.get("accessToken"),
43
+ refresh_token=data.get("refreshToken"),
44
+ client_id=data.get("clientId"),
45
+ client_secret=data.get("clientSecret"),
46
+ profile_arn=data.get("profileArn"),
47
+ expires_at=data.get("expiresAt") or data.get("expire"),
48
+ region=data.get("region", "us-east-1"),
49
+ auth_method=data.get("authMethod", "social"),
50
+ provider=data.get("provider"),
51
+ client_id_hash=data.get("clientIdHash"),
52
+ last_refresh=data.get("lastRefresh"),
53
+ )
54
+
55
+ def to_dict(self) -> dict:
56
+ """转换为字典"""
57
+ result = {
58
+ "accessToken": self.access_token,
59
+ "refreshToken": self.refresh_token,
60
+ "clientId": self.client_id,
61
+ "clientSecret": self.client_secret,
62
+ "profileArn": self.profile_arn,
63
+ "expiresAt": self.expires_at,
64
+ "region": self.region,
65
+ "authMethod": self.auth_method,
66
+ "clientIdHash": self.client_id_hash,
67
+ "lastRefresh": self.last_refresh,
68
+ }
69
+ # 只有社交登录才添加 provider 字段
70
+ if self.provider:
71
+ result["provider"] = self.provider
72
+ return result
73
+
74
+ def save_to_file(self, path: str):
75
+ """保存凭证到文件"""
76
+ existing = {}
77
+ if Path(path).exists():
78
+ try:
79
+ with open(path) as f:
80
+ existing = json.load(f)
81
+ except Exception:
82
+ pass
83
+
84
+ existing.update({k: v for k, v in self.to_dict().items() if v is not None})
85
+
86
+ with open(path, "w") as f:
87
+ json.dump(existing, f, indent=2)
88
+
89
+ def is_expired(self) -> bool:
90
+ """检查 token 是否已过期"""
91
+ if not self.expires_at:
92
+ return True
93
+
94
+ try:
95
+ if "T" in self.expires_at:
96
+ expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
97
+ now = datetime.now(timezone.utc)
98
+ return expires <= now + timedelta(minutes=5)
99
+
100
+ expires_ts = int(self.expires_at)
101
+ now_ts = int(time.time())
102
+ return now_ts >= (expires_ts - 300)
103
+ except Exception:
104
+ return True
105
+
106
+ def is_expiring_soon(self, minutes: int = 10) -> bool:
107
+ """检查 token 是否即将过期"""
108
+ if not self.expires_at:
109
+ return False
110
+
111
+ try:
112
+ if "T" in self.expires_at:
113
+ expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
114
+ now = datetime.now(timezone.utc)
115
+ return expires < now + timedelta(minutes=minutes)
116
+
117
+ expires_ts = int(self.expires_at)
118
+ now_ts = int(time.time())
119
+ return now_ts >= (expires_ts - minutes * 60)
120
+ except Exception:
121
+ return False
KiroProxy/kiro_proxy/docs/01-quickstart.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 快速开始
2
+
3
+ ## 安装运行
4
+
5
+ ### 方式一:下载预编译版本
6
+
7
+ 从 [Releases](https://github.com/yourname/kiro-proxy/releases) 下载对应平台的安装包:
8
+
9
+ - **Windows**: `kiro-proxy-windows.zip`
10
+ - **macOS**: `kiro-proxy-macos.zip`
11
+ - **Linux**: `kiro-proxy-linux.tar.gz`
12
+
13
+ 解压后双击运行即可。
14
+
15
+ ### 方式二:从源码运行
16
+
17
+ ```bash
18
+ # 克隆项目
19
+ git clone https://github.com/yourname/kiro-proxy.git
20
+ cd kiro-proxy
21
+
22
+ # 创建虚拟环境
23
+ python -m venv venv
24
+ source venv/bin/activate # Windows: venv\Scripts\activate
25
+
26
+ # 安装依赖
27
+ pip install -r requirements.txt
28
+
29
+ # 运行(默认端口 8080)
30
+ python run.py
31
+
32
+ # 指定端口
33
+ python run.py 8081
34
+ ```
35
+
36
+ 启动成功后,访问 http://localhost:8080 打开管理界面。
37
+
38
+ ---
39
+
40
+ ## 获取 Kiro 账号
41
+
42
+ Kiro Proxy 需要 Kiro 账号的 Token 才能工作。有两种方式获取:
43
+
44
+ ### 方式一:在线登录(推荐)
45
+
46
+ 1. 打开 Web UI,点击「账号」标签页
47
+ 2. 点击「在线登录」按钮
48
+ 3. 选择登录方式:
49
+ - **Google** - 使用 Google 账号
50
+ - **GitHub** - 使用 GitHub 账号
51
+ - **AWS** - 使用 AWS Builder ID
52
+ 4. 在弹出的浏览器中完成授权
53
+ 5. 授权成功后,账号自动添加到代理
54
+
55
+ ### 方式二:扫描本地 Token
56
+
57
+ 如果你已经在 Kiro IDE 中登录过:
58
+
59
+ 1. 打开 Kiro IDE,确保已登录
60
+ 2. 回到 Web UI,点击「扫描 Token」
61
+ 3. 系统会扫描 `~/.aws/sso/cache/` 目录
62
+ 4. 选择要添加的 Token 文件
63
+
64
+ ---
65
+
66
+ ## 配置 AI 客户端
67
+
68
+ ### Claude Code (VSCode 插件)
69
+
70
+ 这是最推荐的使用方式,工具调用功能已验证可用。
71
+
72
+ 1. 安装 Claude Code 插件
73
+ 2. 打开设置,添加自定义 Provider:
74
+
75
+ ```
76
+ 名称: Kiro Proxy
77
+ API Provider: Anthropic
78
+ API Key: any(随便填一个)
79
+ Base URL: http://localhost:8080
80
+ 模型: claude-sonnet-4
81
+ ```
82
+
83
+ 3. 选择 Kiro Proxy 作为当前 Provider
84
+
85
+ ### Codex CLI
86
+
87
+ OpenAI 官方命令行工具。
88
+
89
+ ```bash
90
+ # 安装
91
+ npm install -g @openai/codex
92
+
93
+ # 配置 (~/.codex/config.toml)
94
+ model = "gpt-4o"
95
+ model_provider = "kiro"
96
+
97
+ [model_providers.kiro]
98
+ name = "Kiro Proxy"
99
+ base_url = "http://localhost:8080/v1"
100
+ ```
101
+
102
+ ### Gemini CLI
103
+
104
+ ```bash
105
+ # 设置环境变量
106
+ export GEMINI_API_BASE=http://localhost:8080/v1
107
+
108
+ # 或在配置文件中设置
109
+ base_url = "http://localhost:8080/v1"
110
+ model = "gemini-pro"
111
+ ```
112
+
113
+ ### 其他兼容客户端
114
+
115
+ 任何支持 OpenAI 或 Anthropic API 的客户端都可以使用:
116
+
117
+ - **Base URL**: `http://localhost:8080` 或 `http://localhost:8080/v1`
118
+ - **API Key**: 任意值(代理不验证)
119
+ - **模型**: 见下方模型对照表
120
+
121
+ ---
122
+
123
+ ## 模型对照表
124
+
125
+ Kiro 支持以下模型,你可以使用 Kiro 原生名称或映射名称:
126
+
127
+ | Kiro 模型 | 能力 | 可用名称(任选其一) |
128
+ |-----------|------|---------------------|
129
+ | `claude-sonnet-4` | ⭐⭐⭐ 推荐,性价比最高 | `gpt-4o`, `gpt-4`, `gpt-4-turbo`, `claude-3-5-sonnet-20241022`, `claude-3-5-sonnet-latest`, `sonnet` |
130
+ | `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强,适合复杂任务 | `gemini-1.5-pro`, `o1`, `o1-preview`, `claude-3-opus-20240229`, `claude-3-opus-latest`, `claude-4-opus`, `opus` |
131
+ | `claude-haiku-4.5` | ⚡ 快速,适合简单任务 | `gpt-4o-mini`, `gpt-3.5-turbo`, `claude-3-5-haiku-20241022`, `haiku` |
132
+ | `auto` | 🤖 自动选择 | `auto` |
133
+
134
+ ### 各客户端推荐配置
135
+
136
+ | 客户端 | 推荐模型名 | 实际使用 |
137
+ |--------|-----------|---------|
138
+ | Claude Code | `claude-sonnet-4` 或 `claude-sonnet-4.5` | 直接使用 Kiro 模型名 |
139
+ | Codex CLI | `gpt-4o` | 映射到 claude-sonnet-4 |
140
+ | Gemini CLI | `gemini-1.5-pro` | 映射到 claude-sonnet-4.5 |
141
+ | 其他 OpenAI 客户端 | `gpt-4o` | 映射到 claude-sonnet-4 |
142
+
143
+ > 💡 **提示**:不确定用什么模型?直接用 `claude-sonnet-4` 或 `gpt-4o`,性价比最高。
KiroProxy/kiro_proxy/docs/02-features.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 功能特性
2
+
3
+ ## 多协议支持
4
+
5
+ Kiro Proxy 支持三种主流 AI API 协议,可以适配不同的客户端:
6
+
7
+ | 协议 | 端点 | 适用客户端 |
8
+ |------|------|------------|
9
+ | OpenAI | `/v1/chat/completions` | Codex CLI, ChatGPT 客户端 |
10
+ | Anthropic | `/v1/messages` | Claude Code, Claude 客户端 |
11
+ | Gemini | `/v1/models/{model}:generateContent` | Gemini CLI |
12
+
13
+ 代理会自动将请求转换为 Kiro API 格式,响应转换回对应协议格式。
14
+
15
+ ---
16
+
17
+ ## 工具调用支持
18
+
19
+ 完整支持三种协议的工具调用功能:
20
+
21
+ ### Anthropic 协议(Claude Code)
22
+
23
+ - `tools` 定义和 `tool_result` 响应完整支持
24
+ - `tool_choice: required` 支持(通过 prompt 注入)
25
+ - `web_search` 特殊工具自动识别
26
+ - 工具数量限制(最多 50 个)
27
+ - 描述截断(超过 500 字符自动截断)
28
+
29
+ ### OpenAI 协议(Codex CLI)
30
+
31
+ - `tools` 定义(function 类型)
32
+ - `tool_calls` 响应处理
33
+ - `tool` 角色消息转换
34
+ - `tool_choice: required/any` 支持
35
+ - 工具数量限制和描述截断
36
+
37
+ ### Gemini 协议
38
+
39
+ - `functionDeclarations` 工具定义
40
+ - `functionCall` 响应处理
41
+ - `functionResponse` 工具结果
42
+ - `toolConfig.functionCallingConfig.mode` 支持(ANY/REQUIRED)
43
+ - 工具数量限制和描述截断
44
+
45
+ ### 历史消息修复
46
+
47
+ Kiro API 要求消息必须严格交替(user → assistant → user → assistant),代理会自动:
48
+
49
+ - 检测并修复连续的同角色消息
50
+ - 合并重复的 tool_results
51
+ - 插入占位消息保持交替
52
+
53
+ ---
54
+
55
+ ## 多账号管理
56
+
57
+ ### 账号轮询
58
+
59
+ 支持添加多个 Kiro 账号,代理会自动轮询使用(默认随机):
60
+
61
+ - 每次请求随机选择一个可用账号(尽量避免连续命中同一账号)
62
+ - 自动跳过冷却中或不健康的账号
63
+ - 分散请求压力,降低单账号 RPM 过高导致封禁风险
64
+
65
+ ### 会话粘性(可选)
66
+
67
+ 为了保持对话上下文的连贯性,在非 `random` 策略下会启用会话粘性:
68
+
69
+ - 同一会话 ID 在 60 秒内会使用同一账号
70
+ - 超过 60 秒或账号不可用时才切换
71
+ - 会话 ID 由请求内容生成;可通过 `~/.kiro-proxy/priority.json` 中的 `strategy` 调整策略
72
+
73
+ ### 账号状态
74
+
75
+ 每个账号有四种状态:
76
+
77
+ | 状态 | 说明 | 颜色 |
78
+ |------|------|------|
79
+ | Active | 正常可用 | 绿色 |
80
+ | Cooldown | 触发限流,冷却中 | 黄色 |
81
+ | Unhealthy | 健康检查失败 | 红色 |
82
+ | Disabled | 手动禁用 | 灰色 |
83
+
84
+ ---
85
+
86
+ ## Token 自动刷新
87
+
88
+ ### 自动检测
89
+
90
+ - 后台每 5 分钟检查所有账号的 Token 状态
91
+ - 检测 Token 是否即将过期(15 分钟内)
92
+
93
+ ### 自动刷新
94
+
95
+ - 发现即将过期的 Token 自动刷新
96
+ - 支持 Social 认证(Google/GitHub)的 refresh_token
97
+ - 刷新失败会标记账号为不健康
98
+
99
+ ### 手动刷新
100
+
101
+ - 在账号卡片点击「刷新 Token」
102
+ - 或点击「刷新所有 Token」批量刷新
103
+
104
+ ---
105
+
106
+ ## 配额管理
107
+
108
+ ### 429 自动处理
109
+
110
+ 当 Kiro API 返回 429 (Too Many Requests) 时:
111
+
112
+ 1. 自动将该账号标记为 Cooldown 状态
113
+ 2. 设置 5 分钟冷却时间
114
+ 3. 立即切换到其他可用账号重试
115
+ 4. 冷却结束后自动恢复
116
+
117
+ ### 手动恢复
118
+
119
+ 如果需要提前恢复账号:
120
+
121
+ 1. 在「监控」页面查看配额状态
122
+ 2. 点击账号旁的「恢复」按钮
123
+
124
+ ---
125
+
126
+ ## 流量监控
127
+
128
+ ### 请求记录
129
+
130
+ 记录所有经过代理的 LLM 请求:
131
+
132
+ - 请求时间、模型、账号
133
+ - 输入/输出 Token 数量
134
+ - 响应时间、状态码
135
+ - 完整的请求和响应内容
136
+
137
+ ### 搜索过滤
138
+
139
+ - 按协议筛选(OpenAI/Anthropic/Gemini)
140
+ - 按状态筛选(完成/错误/进行中)
141
+ - 关键词搜索
142
+
143
+ ### 导出功能
144
+
145
+ - 支持导出为 JSON 格式
146
+ - 可选择导出全部或指定记录
147
+
148
+ ---
149
+
150
+ ## 登录方式
151
+
152
+ ### Google 登录
153
+
154
+ 使用 Google 账号通过 OAuth 授权登录。
155
+
156
+ ### GitHub 登录
157
+
158
+ 使用 GitHub 账号通过 OAuth 授权登录。
159
+
160
+ ### AWS Builder ID
161
+
162
+ 使用 AWS Builder ID 通过 Device Code Flow 登录:
163
+
164
+ 1. 点击 AWS 登录按钮
165
+ 2. 复制显示的授权码
166
+ 3. 在浏览器中打开授权页面
167
+ 4. 输入授权码完成登录
168
+
169
+ ---
170
+
171
+ ## 历史消息管理
172
+
173
+ ### 对话长度限制
174
+
175
+ Kiro API 有输入长度限制,当对话历史过长时会返回 `CONTENT_LENGTH_EXCEEDS_THRESHOLD` 错误。
176
+
177
+ 代理内置了多种策略自动处理这个问题:
178
+
179
+ ### 可用策略
180
+
181
+ | 策略 | 说明 | 触发时机 |
182
+ |------|------|----------|
183
+ | 自动截断 | 优先保留最新上下文并摘要前文,必要时截断 | 每次请求前 |
184
+ | 智能摘要 | 用 AI 生成早期对话摘要 | 超过阈值时 |
185
+ | 错误重试 | 遇到长度错误时截断重试 | 收到错误后 |
186
+ | 预估检测 | 预估 token 数量,超限预先截断 | 每次请求前 |
187
+
188
+ ### 配置选项
189
+
190
+ 在「设置」页面可以配置:
191
+
192
+ - **最大消息数** - 自动截断时保留的消息数量(默认 30)
193
+ - **最大字符数** - 自动截断时的字符数限制(默认 150000)
194
+ - **重试保留数** - 错误重试时保留的消息数(默认 20)
195
+ - **最大重试次数** - 错误重试的最大次数(默认 2)
196
+ - **摘要保留数** - 智能摘要时保留的最近消息数(默认 10)
197
+ - **摘要阈值** - 触发智能摘要的字符数阈值(默认 100000)
198
+ - **添加警告** - 截断时是否在日志中记录
199
+
200
+ ### 推荐配置
201
+
202
+ - **默认**:只启用「错误重试」,遇到问题时自动处理
203
+ - **保守**:启用「智能摘要 + 错误重试」,保留关键信息
204
+ - **激进**:启用「自动截断 + 预估检测」,预防性截断
205
+
206
+ ---
207
+
208
+ ## 配置持久化
209
+
210
+ ### 自动保存
211
+
212
+ 账号配置自动保存到 `~/.kiro-proxy/config.json`:
213
+
214
+ - 账号列表和状态
215
+ - 启用/禁用设置
216
+ - Token 文件路径
217
+
218
+ ### 重启恢复
219
+
220
+ 重启代理后自动加载保存的配置,无需重新添加账号。
221
+
222
+ ### 导入导出
223
+
224
+ - 「导出配置」下载当前配置
225
+ - 「导入配置」从文件恢复