KiroProxy User commited on
Commit ·
d3cadd5
0
Parent(s):
Initial commit: KiroProxy project
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- KiroProxy/.github/workflows/build.yml +245 -0
- KiroProxy/.gitignore +54 -0
- KiroProxy/CAPTURE_GUIDE.md +0 -0
- KiroProxy/README.md +423 -0
- KiroProxy/assets/icon.iconset/icon_128x128.png +0 -0
- KiroProxy/assets/icon.iconset/icon_16x16.png +0 -0
- KiroProxy/assets/icon.iconset/icon_256x256.png +0 -0
- KiroProxy/assets/icon.iconset/icon_32x32.png +0 -0
- KiroProxy/assets/icon.iconset/icon_512x512.png +0 -0
- KiroProxy/assets/icon.iconset/icon_64x64.png +0 -0
- KiroProxy/assets/icon.png +0 -0
- KiroProxy/assets/icon.svg +1 -0
- KiroProxy/build.py +219 -0
- KiroProxy/examples/quota_display_example.py +95 -0
- KiroProxy/examples/test_quota_display.html +118 -0
- KiroProxy/kiro.svg +1 -0
- KiroProxy/kiro_proxy/__init__.py +2 -0
- KiroProxy/kiro_proxy/__main__.py +5 -0
- KiroProxy/kiro_proxy/auth/__init__.py +32 -0
- KiroProxy/kiro_proxy/auth/device_flow.py +603 -0
- KiroProxy/kiro_proxy/cli.py +375 -0
- KiroProxy/kiro_proxy/config.py +133 -0
- KiroProxy/kiro_proxy/converters/__init__.py +1196 -0
- KiroProxy/kiro_proxy/core/__init__.py +55 -0
- KiroProxy/kiro_proxy/core/account.py +287 -0
- KiroProxy/kiro_proxy/core/account_selector.py +390 -0
- KiroProxy/kiro_proxy/core/browser.py +186 -0
- KiroProxy/kiro_proxy/core/error_handler.py +188 -0
- KiroProxy/kiro_proxy/core/flow_monitor.py +572 -0
- KiroProxy/kiro_proxy/core/history_manager.py +829 -0
- KiroProxy/kiro_proxy/core/kiro_api.py +146 -0
- KiroProxy/kiro_proxy/core/persistence.py +69 -0
- KiroProxy/kiro_proxy/core/protocol_handler.py +318 -0
- KiroProxy/kiro_proxy/core/quota_cache.py +397 -0
- KiroProxy/kiro_proxy/core/quota_scheduler.py +321 -0
- KiroProxy/kiro_proxy/core/rate_limiter.py +125 -0
- KiroProxy/kiro_proxy/core/refresh_manager.py +888 -0
- KiroProxy/kiro_proxy/core/retry.py +117 -0
- KiroProxy/kiro_proxy/core/scheduler.py +125 -0
- KiroProxy/kiro_proxy/core/state.py +280 -0
- KiroProxy/kiro_proxy/core/stats.py +130 -0
- KiroProxy/kiro_proxy/core/thinking.py +456 -0
- KiroProxy/kiro_proxy/core/usage.py +235 -0
- KiroProxy/kiro_proxy/credential/__init__.py +17 -0
- KiroProxy/kiro_proxy/credential/fingerprint.py +131 -0
- KiroProxy/kiro_proxy/credential/quota.py +100 -0
- KiroProxy/kiro_proxy/credential/refresher.py +195 -0
- KiroProxy/kiro_proxy/credential/types.py +121 -0
- KiroProxy/kiro_proxy/docs/01-quickstart.md +143 -0
- KiroProxy/kiro_proxy/docs/02-features.md +225 -0
KiroProxy/.github/workflows/build.yml
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Build Release
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
tags:
|
| 6 |
+
- 'v*'
|
| 7 |
+
workflow_dispatch:
|
| 8 |
+
|
| 9 |
+
permissions:
|
| 10 |
+
contents: write
|
| 11 |
+
|
| 12 |
+
env:
|
| 13 |
+
APP_NAME: KiroProxy
|
| 14 |
+
|
| 15 |
+
jobs:
|
| 16 |
+
build-linux:
|
| 17 |
+
runs-on: ubuntu-latest
|
| 18 |
+
steps:
|
| 19 |
+
- uses: actions/checkout@v4
|
| 20 |
+
|
| 21 |
+
- name: Get version from tag
|
| 22 |
+
id: version
|
| 23 |
+
run: |
|
| 24 |
+
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
|
| 25 |
+
VERSION=${GITHUB_REF#refs/tags/v}
|
| 26 |
+
else
|
| 27 |
+
VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py)
|
| 28 |
+
fi
|
| 29 |
+
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
| 30 |
+
echo "Version: $VERSION"
|
| 31 |
+
|
| 32 |
+
- name: Set up Python
|
| 33 |
+
uses: actions/setup-python@v5
|
| 34 |
+
with:
|
| 35 |
+
python-version: '3.11'
|
| 36 |
+
|
| 37 |
+
- name: Install dependencies
|
| 38 |
+
run: |
|
| 39 |
+
python -m pip install --upgrade pip
|
| 40 |
+
pip install -r requirements.txt
|
| 41 |
+
pip install pyinstaller
|
| 42 |
+
|
| 43 |
+
- name: Build binary
|
| 44 |
+
run: python build.py
|
| 45 |
+
|
| 46 |
+
- name: Install packaging tools
|
| 47 |
+
run: |
|
| 48 |
+
sudo apt-get update
|
| 49 |
+
sudo apt-get install -y ruby ruby-dev rubygems build-essential rpm libfuse2
|
| 50 |
+
sudo gem install --no-document fpm
|
| 51 |
+
|
| 52 |
+
- name: Create packages
|
| 53 |
+
run: |
|
| 54 |
+
mkdir -p release
|
| 55 |
+
VERSION=${{ steps.version.outputs.VERSION }}
|
| 56 |
+
|
| 57 |
+
# Binary (standalone)
|
| 58 |
+
cp dist/KiroProxy release/KiroProxy-${VERSION}-linux-x86_64
|
| 59 |
+
chmod +x release/KiroProxy-${VERSION}-linux-x86_64
|
| 60 |
+
|
| 61 |
+
# tar.gz
|
| 62 |
+
tar -czvf release/KiroProxy-${VERSION}-linux-x86_64.tar.gz -C dist KiroProxy
|
| 63 |
+
|
| 64 |
+
# deb package
|
| 65 |
+
fpm -s dir -t deb \
|
| 66 |
+
-n kiroproxy \
|
| 67 |
+
-v ${VERSION} \
|
| 68 |
+
--description "Kiro API Proxy Server" \
|
| 69 |
+
--license "MIT" \
|
| 70 |
+
--architecture amd64 \
|
| 71 |
+
--maintainer "petehsu" \
|
| 72 |
+
--url "https://github.com/petehsu/KiroProxy" \
|
| 73 |
+
-p release/kiroproxy_${VERSION}_amd64.deb \
|
| 74 |
+
dist/KiroProxy=/usr/local/bin/KiroProxy
|
| 75 |
+
|
| 76 |
+
# rpm package
|
| 77 |
+
fpm -s dir -t rpm \
|
| 78 |
+
-n kiroproxy \
|
| 79 |
+
-v ${VERSION} \
|
| 80 |
+
--description "Kiro API Proxy Server" \
|
| 81 |
+
--license "MIT" \
|
| 82 |
+
--architecture x86_64 \
|
| 83 |
+
--maintainer "petehsu" \
|
| 84 |
+
--url "https://github.com/petehsu/KiroProxy" \
|
| 85 |
+
-p release/kiroproxy-${VERSION}-1.x86_64.rpm \
|
| 86 |
+
dist/KiroProxy=/usr/local/bin/KiroProxy
|
| 87 |
+
|
| 88 |
+
- name: Upload artifacts
|
| 89 |
+
uses: actions/upload-artifact@v4
|
| 90 |
+
with:
|
| 91 |
+
name: KiroProxy-Linux
|
| 92 |
+
path: release/*
|
| 93 |
+
|
| 94 |
+
build-windows:
|
| 95 |
+
runs-on: windows-latest
|
| 96 |
+
steps:
|
| 97 |
+
- uses: actions/checkout@v4
|
| 98 |
+
|
| 99 |
+
- name: Get version from tag
|
| 100 |
+
id: version
|
| 101 |
+
shell: bash
|
| 102 |
+
run: |
|
| 103 |
+
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
|
| 104 |
+
VERSION=${GITHUB_REF#refs/tags/v}
|
| 105 |
+
else
|
| 106 |
+
VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py)
|
| 107 |
+
fi
|
| 108 |
+
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
| 109 |
+
echo "Version: $VERSION"
|
| 110 |
+
|
| 111 |
+
- name: Set up Python
|
| 112 |
+
uses: actions/setup-python@v5
|
| 113 |
+
with:
|
| 114 |
+
python-version: '3.11'
|
| 115 |
+
|
| 116 |
+
- name: Install dependencies
|
| 117 |
+
run: |
|
| 118 |
+
python -m pip install --upgrade pip
|
| 119 |
+
pip install -r requirements.txt
|
| 120 |
+
pip install pyinstaller
|
| 121 |
+
|
| 122 |
+
- name: Build
|
| 123 |
+
run: python build.py
|
| 124 |
+
|
| 125 |
+
- name: Create packages
|
| 126 |
+
shell: pwsh
|
| 127 |
+
run: |
|
| 128 |
+
$VERSION = "${{ steps.version.outputs.VERSION }}"
|
| 129 |
+
New-Item -ItemType Directory -Force -Path release
|
| 130 |
+
|
| 131 |
+
# exe (standalone)
|
| 132 |
+
Copy-Item dist/KiroProxy.exe release/KiroProxy-${VERSION}-windows-x86_64.exe
|
| 133 |
+
|
| 134 |
+
# zip
|
| 135 |
+
Compress-Archive -Path dist/KiroProxy.exe -DestinationPath release/KiroProxy-${VERSION}-windows-x86_64.zip
|
| 136 |
+
|
| 137 |
+
- name: Upload artifacts
|
| 138 |
+
uses: actions/upload-artifact@v4
|
| 139 |
+
with:
|
| 140 |
+
name: KiroProxy-Windows
|
| 141 |
+
path: release/*
|
| 142 |
+
|
| 143 |
+
build-macos:
|
| 144 |
+
runs-on: macos-latest
|
| 145 |
+
steps:
|
| 146 |
+
- uses: actions/checkout@v4
|
| 147 |
+
|
| 148 |
+
- name: Get version from tag
|
| 149 |
+
id: version
|
| 150 |
+
run: |
|
| 151 |
+
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
|
| 152 |
+
VERSION=${GITHUB_REF#refs/tags/v}
|
| 153 |
+
else
|
| 154 |
+
VERSION=$(grep -oP '__version__ = "\K[^"]+' kiro_proxy/__init__.py || echo "1.0.0")
|
| 155 |
+
fi
|
| 156 |
+
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
| 157 |
+
echo "Version: $VERSION"
|
| 158 |
+
|
| 159 |
+
- name: Set up Python
|
| 160 |
+
uses: actions/setup-python@v5
|
| 161 |
+
with:
|
| 162 |
+
python-version: '3.11'
|
| 163 |
+
|
| 164 |
+
- name: Install dependencies
|
| 165 |
+
run: |
|
| 166 |
+
python -m pip install --upgrade pip
|
| 167 |
+
pip install -r requirements.txt
|
| 168 |
+
pip install pyinstaller
|
| 169 |
+
|
| 170 |
+
- name: Generate icon
|
| 171 |
+
run: |
|
| 172 |
+
mkdir -p assets/icon.iconset
|
| 173 |
+
for size in 16 32 64 128 256 512; do
|
| 174 |
+
sips -z $size $size assets/icon.png --out assets/icon.iconset/icon_${size}x${size}.png
|
| 175 |
+
done
|
| 176 |
+
iconutil -c icns assets/icon.iconset -o assets/icon.icns
|
| 177 |
+
|
| 178 |
+
- name: Build
|
| 179 |
+
run: python build.py
|
| 180 |
+
|
| 181 |
+
- name: Create packages
|
| 182 |
+
run: |
|
| 183 |
+
VERSION=${{ steps.version.outputs.VERSION }}
|
| 184 |
+
mkdir -p release
|
| 185 |
+
|
| 186 |
+
# Binary (standalone)
|
| 187 |
+
cp dist/KiroProxy release/KiroProxy-${VERSION}-macos-x86_64
|
| 188 |
+
chmod +x release/KiroProxy-${VERSION}-macos-x86_64
|
| 189 |
+
|
| 190 |
+
# zip
|
| 191 |
+
cd dist && zip -r ../release/KiroProxy-${VERSION}-macos-x86_64.zip KiroProxy && cd ..
|
| 192 |
+
|
| 193 |
+
- name: Upload artifacts
|
| 194 |
+
uses: actions/upload-artifact@v4
|
| 195 |
+
with:
|
| 196 |
+
name: KiroProxy-macOS
|
| 197 |
+
path: release/*
|
| 198 |
+
|
| 199 |
+
release:
|
| 200 |
+
needs: [build-linux, build-windows, build-macos]
|
| 201 |
+
runs-on: ubuntu-latest
|
| 202 |
+
if: startsWith(github.ref, 'refs/tags/')
|
| 203 |
+
|
| 204 |
+
steps:
|
| 205 |
+
- uses: actions/checkout@v4
|
| 206 |
+
|
| 207 |
+
- name: Get version from tag
|
| 208 |
+
id: version
|
| 209 |
+
run: |
|
| 210 |
+
VERSION=${GITHUB_REF#refs/tags/v}
|
| 211 |
+
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
| 212 |
+
|
| 213 |
+
- name: Download all artifacts
|
| 214 |
+
uses: actions/download-artifact@v4
|
| 215 |
+
with:
|
| 216 |
+
path: artifacts
|
| 217 |
+
|
| 218 |
+
- name: List artifacts
|
| 219 |
+
run: find artifacts -type f
|
| 220 |
+
|
| 221 |
+
- name: Create Release
|
| 222 |
+
uses: softprops/action-gh-release@v1
|
| 223 |
+
with:
|
| 224 |
+
name: KiroProxy v${{ steps.version.outputs.VERSION }}
|
| 225 |
+
body: |
|
| 226 |
+
## Downloads
|
| 227 |
+
|
| 228 |
+
| Platform | File | Description |
|
| 229 |
+
|----------|------|-------------|
|
| 230 |
+
| **Linux** | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64` | Standalone binary |
|
| 231 |
+
| | `KiroProxy-${{ steps.version.outputs.VERSION }}-linux-x86_64.tar.gz` | Compressed archive |
|
| 232 |
+
| | `kiroproxy_${{ steps.version.outputs.VERSION }}_amd64.deb` | Debian/Ubuntu package |
|
| 233 |
+
| | `kiroproxy-${{ steps.version.outputs.VERSION }}-1.x86_64.rpm` | Fedora/RHEL/CentOS package |
|
| 234 |
+
| **Windows** | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.exe` | Standalone executable |
|
| 235 |
+
| | `KiroProxy-${{ steps.version.outputs.VERSION }}-windows-x86_64.zip` | Compressed archive |
|
| 236 |
+
| **macOS** | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64` | Standalone binary |
|
| 237 |
+
| | `KiroProxy-${{ steps.version.outputs.VERSION }}-macos-x86_64.zip` | Compressed archive |
|
| 238 |
+
files: |
|
| 239 |
+
artifacts/KiroProxy-Linux/*
|
| 240 |
+
artifacts/KiroProxy-Windows/*
|
| 241 |
+
artifacts/KiroProxy-macOS/*
|
| 242 |
+
draft: false
|
| 243 |
+
prerelease: false
|
| 244 |
+
env:
|
| 245 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
KiroProxy/.gitignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
venv/
|
| 6 |
+
.venv/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
.hypothesis/
|
| 9 |
+
.pytest_cache/
|
| 10 |
+
|
| 11 |
+
# Build
|
| 12 |
+
build/
|
| 13 |
+
dist/
|
| 14 |
+
release/
|
| 15 |
+
*.spec
|
| 16 |
+
|
| 17 |
+
# IDE
|
| 18 |
+
.idea/
|
| 19 |
+
.vscode/
|
| 20 |
+
*.swp
|
| 21 |
+
*.swo
|
| 22 |
+
|
| 23 |
+
# OS
|
| 24 |
+
.DS_Store
|
| 25 |
+
Thumbs.db
|
| 26 |
+
|
| 27 |
+
# HAR files (contain sensitive data)
|
| 28 |
+
*.har
|
| 29 |
+
|
| 30 |
+
# Logs
|
| 31 |
+
*.log
|
| 32 |
+
|
| 33 |
+
# Test files
|
| 34 |
+
[0-9].txt
|
| 35 |
+
[0-9][0-9].txt
|
| 36 |
+
线索*.txt
|
| 37 |
+
|
| 38 |
+
# Temp analysis files
|
| 39 |
+
flows
|
| 40 |
+
flows_*
|
| 41 |
+
traffic.mitm
|
| 42 |
+
*.mitm
|
| 43 |
+
analyze_har.py
|
| 44 |
+
parse_*.py
|
| 45 |
+
*_analysis.txt
|
| 46 |
+
*_check.txt
|
| 47 |
+
hex_dump.txt
|
| 48 |
+
parsed_*.txt
|
| 49 |
+
response.txt
|
| 50 |
+
参考.txt
|
| 51 |
+
|
| 52 |
+
# Other projects
|
| 53 |
+
Antigravity-Manager/
|
| 54 |
+
cc-switch/
|
KiroProxy/CAPTURE_GUIDE.md
ADDED
|
File without changes
|
KiroProxy/README.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="assets/icon.svg" width="80" height="96" alt="Kiro Proxy">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<h1 align="center">Kiro API Proxy</h1>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
Kiro IDE API 反向代理服务器,支持多账号轮询、Token 自动刷新、配额管理
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<a href="#功能特性">功能</a> •
|
| 13 |
+
<a href="#快速开始">快速开始</a> •
|
| 14 |
+
<a href="#cli-配置">CLI 配置</a> •
|
| 15 |
+
<a href="#api-端点">API</a> •
|
| 16 |
+
<a href="#许可证">许可证</a>
|
| 17 |
+
</p>
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
> **⚠️ 测试说明**
|
| 22 |
+
>
|
| 23 |
+
> 本项目支持 **Claude Code**、**Codex CLI**、**Gemini CLI** 三种客户端,工具调用功能已全面支持。
|
| 24 |
+
|
| 25 |
+
## 功能特性
|
| 26 |
+
|
| 27 |
+
### 核心功能
|
| 28 |
+
- **多协议支持** - OpenAI / Anthropic / Gemini 三种协议兼容
|
| 29 |
+
- **完整工具调用** - 三种协议的工具调用功能全面支持
|
| 30 |
+
- **图片理解** - 支持 Claude Code / Codex CLI 图片输入
|
| 31 |
+
- **网络搜索** - 支持 Claude Code / Codex CLI 网络搜索工具
|
| 32 |
+
- **思考功能** - 支持 Claude 的扩展思考功能(Extended Thinking)
|
| 33 |
+
- **多账号轮询(默认随机)** - 每次请求随机切换账号,分散压力,避免单账号 RPM 过高
|
| 34 |
+
- **会话粘性(可选)** - 非 `random` 策略下,同一会话 60 秒内使用同一账号,保持上下文
|
| 35 |
+
- **Web UI** - 简洁的管理界面,支持监控、日志、设置
|
| 36 |
+
|
| 37 |
+
### v1.7.1 新功能
|
| 38 |
+
- **Windows 支持补强** - 注册表浏览器检测 + PATH 回退,兼容便携版
|
| 39 |
+
- **打包资源修复** - PyInstaller 打包后可正常加载图标与内置文档
|
| 40 |
+
- **Token 扫描稳定性** - Windows 路径编码处理修复
|
| 41 |
+
|
| 42 |
+
### v1.6.3 新功能
|
| 43 |
+
- **命令行工具 (CLI)** - 无 GUI 服务器也能轻松管理
|
| 44 |
+
- `python run.py accounts list` - 列出账号
|
| 45 |
+
- `python run.py accounts export/import` - 导出/导入账号
|
| 46 |
+
- `python run.py accounts add` - 交互式添加 Token
|
| 47 |
+
- `python run.py accounts scan` - 扫描本地 Token
|
| 48 |
+
- `python run.py login google/github` - 命令行登录
|
| 49 |
+
- `python run.py login remote` - 生成远程登录链接
|
| 50 |
+
- **远程登录链接** - 在有浏览器的机器上完成授权,Token 自动同步
|
| 51 |
+
- **账号导入导出** - 跨机器迁移账号配置
|
| 52 |
+
- **手动添加 Token** - 直接粘贴 accessToken/refreshToken
|
| 53 |
+
|
| 54 |
+
### v1.6.2 新功能
|
| 55 |
+
- **Codex CLI 完整支持** - 使用 OpenAI Responses API (`/v1/responses`)
|
| 56 |
+
- 完整工具调用支持(shell、file 等所有工具)
|
| 57 |
+
- 图片输入支持(`input_image` 类型)
|
| 58 |
+
- 网络搜索支持(`web_search` 工具)
|
| 59 |
+
- 错误代码映射(rate_limit、context_length 等)
|
| 60 |
+
- **Claude Code 增强** - 图片理解和网络搜索完整支持
|
| 61 |
+
- 支持 Anthropic 和 OpenAI 两种图片格式
|
| 62 |
+
- 支持 `web_search` / `web_search_20250305` 工具
|
| 63 |
+
|
| 64 |
+
### v1.6.1 新功能
|
| 65 |
+
- **请求限速** - 通过限制请求频率降低账号封禁风险
|
| 66 |
+
- 每账号最小请求间隔
|
| 67 |
+
- 每账号每分钟最大请求数
|
| 68 |
+
- 全局每分钟最大请求数
|
| 69 |
+
- WebUI 设置页面可配置
|
| 70 |
+
- **账号封禁检测** - 自动检测 TEMPORARILY_SUSPENDED 错误
|
| 71 |
+
- 友好的错误日志输出
|
| 72 |
+
- 自动禁用被封禁账号
|
| 73 |
+
- 自动切换到其他可用账号
|
| 74 |
+
- **统一错误处理** - 三种协议使用统一的错误分类和处理
|
| 75 |
+
|
| 76 |
+
### v1.6.0 功能
|
| 77 |
+
- **历史消息管理** - 4 种策略处理对话长度限制,可自由组合
|
| 78 |
+
- 自动截断:发送前优先保留最新上下文并摘要前文,必要时按数量/字符数截断
|
| 79 |
+
- 智能摘要:用 AI 生成早期对话摘要,保留关键信息
|
| 80 |
+
- 摘要缓存:历史变化不大时复用最近摘要,减少重复 LLM 调用(默认启用)
|
| 81 |
+
- 错误重试:遇到长度错误时自动截断重试(默认启用)
|
| 82 |
+
- 预估检测:预估 token 数量,超限预先截断
|
| 83 |
+
- **Gemini 工具调用** - 完整支持 functionDeclarations/functionCall/functionResponse
|
| 84 |
+
- **设置页面** - WebUI 新增设置标签页,可配置历史消息管理策略
|
| 85 |
+
|
| 86 |
+
### v1.5.0 功能
|
| 87 |
+
- **用量查询** - 查询账号配额使用情况,显示已用/余额/使用率
|
| 88 |
+
- **多登录方式** - 支持 Google / GitHub / AWS Builder ID 三种登录方式
|
| 89 |
+
- **流量监控** - 完整的 LLM 请求监控,支持搜索、过滤、导出
|
| 90 |
+
- **浏览器选择** - 自动检测已安装浏览器,支持无痕模式
|
| 91 |
+
- **文档中心** - 内置帮助文档,左侧目录 + 右侧 Markdown 渲染
|
| 92 |
+
|
| 93 |
+
### v1.4.0 功能
|
| 94 |
+
- **Token 预刷新** - 后台每 5 分钟检查,提前 15 分钟自动刷新
|
| 95 |
+
- **健康检查** - 每 10 分钟检测账号可用性,自动标记状态
|
| 96 |
+
- **请求统计增强** - 按账号/模型统计,24 小时趋势
|
| 97 |
+
- **请求重试机制** - 网络错误/5xx 自动重试,指数退避
|
| 98 |
+
|
| 99 |
+
## 工具调用支持
|
| 100 |
+
|
| 101 |
+
| 功能 | Anthropic (Claude Code) | OpenAI (Codex CLI) | Gemini |
|
| 102 |
+
|------|------------------------|-------------------|--------|
|
| 103 |
+
| 工具定义 | ✅ `tools` | ✅ `tools.function` | ✅ `functionDeclarations` |
|
| 104 |
+
| 工具调用响应 | ✅ `tool_use` | ✅ `tool_calls` | ✅ `functionCall` |
|
| 105 |
+
| 工具结果 | ✅ `tool_result` | �� `tool` 角色消息 | ✅ `functionResponse` |
|
| 106 |
+
| 强制工具调用 | ✅ `tool_choice` | ✅ `tool_choice` | ✅ `toolConfig.mode` |
|
| 107 |
+
| 工具数量限制 | ✅ 50 个 | ✅ 50 个 | ✅ 50 个 |
|
| 108 |
+
| 历史消息修复 | ✅ | ✅ | ✅ |
|
| 109 |
+
| 图片理解 | ✅ | ✅ | ❌ |
|
| 110 |
+
| 网络搜索 | ✅ | ✅ | ❌ |
|
| 111 |
+
|
| 112 |
+
## 已知限制
|
| 113 |
+
|
| 114 |
+
### 对话长度限制
|
| 115 |
+
|
| 116 |
+
Kiro API 有输入长度限制。当对话历史过长时,会返回错误:
|
| 117 |
+
|
| 118 |
+
```
|
| 119 |
+
Input is too long. (CONTENT_LENGTH_EXCEEDS_THRESHOLD)
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
#### 自动处理(v1.6.0+)
|
| 123 |
+
|
| 124 |
+
代理内置了历史消息管理功能,可在「设置」页面配置:
|
| 125 |
+
|
| 126 |
+
- **错误重试**(默认):遇到长度错误时自动截断并重试
|
| 127 |
+
- **智能摘要**:用 AI 生成早期对话摘要,保留关键信息
|
| 128 |
+
- **摘要缓存**(默认):历史变化不大时复用最近摘要,减少重复 LLM 调用
|
| 129 |
+
- **自动截断**:每次请求前优先保留最新上下文并摘要前文,必要时按数量/字符数截断
|
| 130 |
+
- **预估检测**:预估 token 数量,超限预先截断
|
| 131 |
+
|
| 132 |
+
摘要缓存可通过以下配置项调整(默认值):
|
| 133 |
+
- `summary_cache_enabled`: `true`
|
| 134 |
+
- `summary_cache_min_delta_messages`: `3`
|
| 135 |
+
- `summary_cache_min_delta_chars`: `4000`
|
| 136 |
+
- `summary_cache_max_age_seconds`: `180`
|
| 137 |
+
|
| 138 |
+
#### 手动处理
|
| 139 |
+
|
| 140 |
+
1. 在 Claude Code 中输入 `/clear` 清空对话历史
|
| 141 |
+
2. 告诉 AI 你之前在做什么,它会读取代码文件恢复上下文
|
| 142 |
+
|
| 143 |
+
## 快速开始
|
| 144 |
+
|
| 145 |
+
### 方式一:下载预编译版本
|
| 146 |
+
|
| 147 |
+
从 [Releases](../../releases) 下载对应平台的安装包,解压后直接运行。
|
| 148 |
+
|
| 149 |
+
### 方式二:从源码运行
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
# 克隆项目
|
| 153 |
+
git clone https://github.com/yourname/kiro-proxy.git
|
| 154 |
+
cd kiro-proxy
|
| 155 |
+
|
| 156 |
+
# 创建虚拟环境
|
| 157 |
+
python -m venv venv
|
| 158 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 159 |
+
|
| 160 |
+
# 安装依赖
|
| 161 |
+
pip install -r requirements.txt
|
| 162 |
+
|
| 163 |
+
# 运行
|
| 164 |
+
python run.py
|
| 165 |
+
|
| 166 |
+
# 或指定端口
|
| 167 |
+
python run.py 8081
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
启动后访问 http://localhost:8080
|
| 171 |
+
|
| 172 |
+
### 命令行工具 (CLI)
|
| 173 |
+
|
| 174 |
+
无 GUI 服务器可使用 CLI 管理账号:
|
| 175 |
+
|
| 176 |
+
```bash
|
| 177 |
+
# 账号管理
|
| 178 |
+
python run.py accounts list # 列出账号
|
| 179 |
+
python run.py accounts export -o acc.json # 导出账号
|
| 180 |
+
python run.py accounts import acc.json # 导入账号
|
| 181 |
+
python run.py accounts add # 交互式添加 Token
|
| 182 |
+
python run.py accounts scan --auto # 扫描并自动添加本地 Token
|
| 183 |
+
|
| 184 |
+
# 登录
|
| 185 |
+
python run.py login google # Google 登录
|
| 186 |
+
python run.py login github # GitHub 登录
|
| 187 |
+
python run.py login remote --host myserver.com:8080 # 生成远程登录链接
|
| 188 |
+
|
| 189 |
+
# 服务
|
| 190 |
+
python run.py serve # 启动服务 (默认 8080)
|
| 191 |
+
python run.py serve -p 8081 # 指定端口
|
| 192 |
+
python run.py status # 查看状态
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
### 登录获取 Token
|
| 196 |
+
|
| 197 |
+
**方式一:在线登录(推荐)**
|
| 198 |
+
1. 打开 Web UI,点击「在线登录」
|
| 199 |
+
2. 选择登录方式:Google / GitHub / AWS Builder ID
|
| 200 |
+
3. 在浏览器中完成授权
|
| 201 |
+
4. 账号自动添加
|
| 202 |
+
|
| 203 |
+
**方式二:扫描 Token**
|
| 204 |
+
1. 打开 Kiro IDE,使用 Google/GitHub 账号登录
|
| 205 |
+
2. 登录成功后 token 自动保存到 `~/.aws/sso/cache/`
|
| 206 |
+
3. 在 Web UI 点击「扫描 Token」添加账号
|
| 207 |
+
|
| 208 |
+
## CLI 配置
|
| 209 |
+
|
| 210 |
+
### 模型对照表
|
| 211 |
+
|
| 212 |
+
| Kiro 模型 | 能力 | Claude Code | Codex |
|
| 213 |
+
|-----------|------|-------------|-------|
|
| 214 |
+
| `claude-sonnet-4` | ⭐⭐⭐ 推荐 | `claude-sonnet-4` | `gpt-4o` |
|
| 215 |
+
| `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强 | `claude-sonnet-4.5` | `gpt-4o` |
|
| 216 |
+
| `claude-haiku-4.5` | ⚡ 快速 | `claude-haiku-4.5` | `gpt-4o-mini` |
|
| 217 |
+
|
| 218 |
+
### Claude Code 配置
|
| 219 |
+
|
| 220 |
+
```
|
| 221 |
+
名称: Kiro Proxy
|
| 222 |
+
API Key: any
|
| 223 |
+
Base URL: http://localhost:8080
|
| 224 |
+
模型: claude-sonnet-4
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
### Codex 配置
|
| 228 |
+
|
| 229 |
+
Codex CLI 使用 OpenAI Responses API,配置如下:
|
| 230 |
+
|
| 231 |
+
```bash
|
| 232 |
+
# 设置环境变量
|
| 233 |
+
export OPENAI_API_KEY=any
|
| 234 |
+
export OPENAI_BASE_URL=http://localhost:8080/v1
|
| 235 |
+
|
| 236 |
+
# 运行 Codex
|
| 237 |
+
codex
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
或在 `~/.codex/config.toml` 中配置:
|
| 241 |
+
|
| 242 |
+
```toml
|
| 243 |
+
[providers.openai]
|
| 244 |
+
api_key = "any"
|
| 245 |
+
base_url = "http://localhost:8080/v1"
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
## 思考功能支持
|
| 249 |
+
|
| 250 |
+
### 什么是思考功能
|
| 251 |
+
|
| 252 |
+
思考功能(Extended Thinking)允许 Claude 在生成回答前展示其思考过程,帮助用户理解 AI 的推理步骤。
|
| 253 |
+
|
| 254 |
+
### 如何使用
|
| 255 |
+
|
| 256 |
+
在请求中添加 `thinking`(或对应协议的 thinking 配置)即可启用:
|
| 257 |
+
|
| 258 |
+
```json
|
| 259 |
+
{
|
| 260 |
+
"model": "claude-sonnet-4.5",
|
| 261 |
+
"messages": [
|
| 262 |
+
{
|
| 263 |
+
"role": "user",
|
| 264 |
+
"content": "解释一下量子计算的原理"
|
| 265 |
+
}
|
| 266 |
+
],
|
| 267 |
+
"thinking": {
|
| 268 |
+
"thinking_type": "enabled",
|
| 269 |
+
"budget_tokens": 20000
|
| 270 |
+
},
|
| 271 |
+
"stream": true
|
| 272 |
+
}
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
OpenAI Chat Completions (`POST /v1/chat/completions`) 也支持:
|
| 276 |
+
|
| 277 |
+
```json
|
| 278 |
+
{
|
| 279 |
+
"model": "gpt-4o",
|
| 280 |
+
"messages": [{"role": "user", "content": "解释一下量子计算的原理"}],
|
| 281 |
+
"thinking": { "type": "enabled" },
|
| 282 |
+
"stream": true
|
| 283 |
+
}
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
OpenAI Responses (`POST /v1/responses`) 也支持:
|
| 287 |
+
|
| 288 |
+
```json
|
| 289 |
+
{
|
| 290 |
+
"model": "gpt-4o",
|
| 291 |
+
"input": "解释一下量子计算的原理",
|
| 292 |
+
"thinking": { "type": "enabled" }
|
| 293 |
+
}
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
Gemini generateContent (`POST /v1/models/{model}:generateContent`) 也支持:
|
| 297 |
+
|
| 298 |
+
```json
|
| 299 |
+
{
|
| 300 |
+
"contents": [{"role": "user", "parts": [{"text": "解释一下量子计算的原理"}]}],
|
| 301 |
+
"generationConfig": {
|
| 302 |
+
"thinkingConfig": { "includeThoughts": true }
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
```
|
| 306 |
+
|
| 307 |
+
### 参数说明
|
| 308 |
+
|
| 309 |
+
- `thinking_type`: 思考类型,设为 `"enabled"` 启用思考功能
|
| 310 |
+
- `budget_tokens`: 思考过程的 token 预算(不传则视为无限制)
|
| 311 |
+
|
| 312 |
+
### 响应格式
|
| 313 |
+
|
| 314 |
+
启用思考功能后,流式响应会包含两种内容块:
|
| 315 |
+
|
| 316 |
+
1. **思考块**(type: "thinking"):展示 AI 的思考过程
|
| 317 |
+
2. **文本块**(type: "text"):最终的回答内容
|
| 318 |
+
|
| 319 |
+
示例响应:
|
| 320 |
+
```
|
| 321 |
+
data: {"type":"content_block_start","index":1,"content_block":{"type":"thinking","thinking":""}}
|
| 322 |
+
data: {"type":"content_block_delta","index":1,"delta":{"type":"thinking_delta","thinking":"让我思考一下量子计算的原理..."}}
|
| 323 |
+
data: {"type":"content_block_stop","index":1}
|
| 324 |
+
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
| 325 |
+
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"量子计算是一种..."}}
|
| 326 |
+
data: {"type":"content_block_stop","index":0}
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
## API 端点
|
| 330 |
+
|
| 331 |
+
| 协议 | 端点 | 用途 |
|
| 332 |
+
|------|------|------|
|
| 333 |
+
| OpenAI | `POST /v1/chat/completions` | Chat Completions API |
|
| 334 |
+
| OpenAI | `POST /v1/responses` | Responses API (Codex CLI) |
|
| 335 |
+
| OpenAI | `GET /v1/models` | 模型列表 |
|
| 336 |
+
| Anthropic | `POST /v1/messages` | Claude Code |
|
| 337 |
+
| Anthropic | `POST /v1/messages/count_tokens` | Token 计数 |
|
| 338 |
+
| Gemini | `POST /v1/models/{model}:generateContent` | Gemini CLI |
|
| 339 |
+
|
| 340 |
+
### 管理 API
|
| 341 |
+
|
| 342 |
+
| 端点 | 方法 | 说明 |
|
| 343 |
+
|------|------|------|
|
| 344 |
+
| `/api/accounts` | GET | 获取所有账号状态 |
|
| 345 |
+
| `/api/accounts/{id}` | GET | 获取账号详情 |
|
| 346 |
+
| `/api/accounts/{id}/usage` | GET | 获取账号用量信息 |
|
| 347 |
+
| `/api/accounts/{id}/refresh` | POST | 刷新账号 Token |
|
| 348 |
+
| `/api/accounts/{id}/restore` | POST | 恢复账号(从冷却状态) |
|
| 349 |
+
| `/api/accounts/refresh-all` | POST | 刷新所有即将过期的 Token |
|
| 350 |
+
| `/api/flows` | GET | 获取流量记录 |
|
| 351 |
+
| `/api/flows/stats` | GET | 获取流量统计 |
|
| 352 |
+
| `/api/flows/{id}` | GET | 获取流量详情 |
|
| 353 |
+
| `/api/quota` | GET | 获取配额状态 |
|
| 354 |
+
| `/api/stats` | GET | 获取统计信息 |
|
| 355 |
+
| `/api/health-check` | POST | 手动触发健康检查 |
|
| 356 |
+
| `/api/browsers` | GET | 获取可用浏览器列表 |
|
| 357 |
+
| `/api/docs` | GET | 获取文档列表 |
|
| 358 |
+
| `/api/docs/{id}` | GET | 获取文档内容 |
|
| 359 |
+
|
| 360 |
+
## 项目结构
|
| 361 |
+
|
| 362 |
+
```
|
| 363 |
+
kiro_proxy/
|
| 364 |
+
├── main.py # FastAPI 应用入口
|
| 365 |
+
├── config.py # 全局配置
|
| 366 |
+
├── converters.py # 协议转换
|
| 367 |
+
│
|
| 368 |
+
├── core/ # 核心模块
|
| 369 |
+
│ ├── account.py # 账号管理
|
| 370 |
+
│ ├── state.py # 全局状态
|
| 371 |
+
│ ├── persistence.py # 配置持久化
|
| 372 |
+
│ ├── scheduler.py # 后台任务调度
|
| 373 |
+
│ ├── stats.py # 请求统计
|
| 374 |
+
│ ├── retry.py # 重试机制
|
| 375 |
+
│ ├── browser.py # 浏览器检测
|
| 376 |
+
│ ├── flow_monitor.py # 流量监控
|
| 377 |
+
│ └── usage.py # 用量查询
|
| 378 |
+
│
|
| 379 |
+
├── credential/ # 凭证管理
|
| 380 |
+
│ ├── types.py # KiroCredentials
|
| 381 |
+
│ ├── fingerprint.py # Machine ID 生成
|
| 382 |
+
│ ├── quota.py # 配额管理器
|
| 383 |
+
│ └── refresher.py # Token 刷新
|
| 384 |
+
│
|
| 385 |
+
├── auth/ # 认证模块
|
| 386 |
+
│ └── device_flow.py # Device Code Flow / Social Auth
|
| 387 |
+
│
|
| 388 |
+
├── handlers/ # API 处理器
|
| 389 |
+
│ ├── anthropic.py # /v1/messages
|
| 390 |
+
│ ├── openai.py # /v1/chat/completions
|
| 391 |
+
│ ├── responses.py # /v1/responses (Codex CLI)
|
| 392 |
+
│ ├── gemini.py # /v1/models/{model}:generateContent
|
| 393 |
+
│ └── admin.py # 管理 API
|
| 394 |
+
│
|
| 395 |
+
├── cli.py # 命令行工具
|
| 396 |
+
│
|
| 397 |
+
├── docs/ # 内置文档
|
| 398 |
+
│ ├── 01-quickstart.md # 快速开始
|
| 399 |
+
│ ├── 02-features.md # 功能特性
|
| 400 |
+
│ ├── 03-faq.md # 常见问题
|
| 401 |
+
│ └── 04-api.md # API 参考
|
| 402 |
+
│
|
| 403 |
+
└── web/
|
| 404 |
+
└── html.py # Web UI (组件化单文件)
|
| 405 |
+
```
|
| 406 |
+
|
| 407 |
+
## 构建
|
| 408 |
+
|
| 409 |
+
```bash
|
| 410 |
+
# 安装构建依赖
|
| 411 |
+
pip install pyinstaller
|
| 412 |
+
|
| 413 |
+
# 构建
|
| 414 |
+
python build.py
|
| 415 |
+
```
|
| 416 |
+
|
| 417 |
+
输出文件在 `dist/` 目录。
|
| 418 |
+
|
| 419 |
+
## 免责声明
|
| 420 |
+
|
| 421 |
+
本项目仅供学习研究,禁止商用。使用本项目产生的任何后果由使用者自行承担,与作者无关。
|
| 422 |
+
|
| 423 |
+
本项目与 Kiro / AWS / Anthropic 官方无关。
|
KiroProxy/assets/icon.iconset/icon_128x128.png
ADDED
|
|
KiroProxy/assets/icon.iconset/icon_16x16.png
ADDED
|
|
KiroProxy/assets/icon.iconset/icon_256x256.png
ADDED
|
|
KiroProxy/assets/icon.iconset/icon_32x32.png
ADDED
|
|
KiroProxy/assets/icon.iconset/icon_512x512.png
ADDED
|
|
KiroProxy/assets/icon.iconset/icon_64x64.png
ADDED
|
|
KiroProxy/assets/icon.png
ADDED
|
|
KiroProxy/assets/icon.svg
ADDED
|
|
KiroProxy/build.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Kiro Proxy Cross-platform Build Script
|
| 4 |
+
Supports: Windows / macOS / Linux
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python build.py # Build for current platform
|
| 8 |
+
python build.py --all # Show all platform instructions
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import shutil
|
| 14 |
+
import subprocess
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from kiro_proxy import __version__ as VERSION
|
| 18 |
+
|
| 19 |
+
APP_NAME = "KiroProxy"
|
| 20 |
+
MAIN_SCRIPT = "run.py"
|
| 21 |
+
ICON_DIR = Path("assets")
|
| 22 |
+
|
| 23 |
+
def get_platform():
|
| 24 |
+
if sys.platform == "win32":
|
| 25 |
+
return "windows"
|
| 26 |
+
elif sys.platform == "darwin":
|
| 27 |
+
return "macos"
|
| 28 |
+
else:
|
| 29 |
+
return "linux"
|
| 30 |
+
|
| 31 |
+
def ensure_pyinstaller():
|
| 32 |
+
try:
|
| 33 |
+
import PyInstaller
|
| 34 |
+
print(f"[OK] PyInstaller {PyInstaller.__version__} installed")
|
| 35 |
+
except ImportError:
|
| 36 |
+
print("[..] Installing PyInstaller...")
|
| 37 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "pyinstaller"], check=True)
|
| 38 |
+
|
| 39 |
+
def clean_build():
|
| 40 |
+
for d in ["build", "dist", f"{APP_NAME}.spec"]:
|
| 41 |
+
if os.path.isdir(d):
|
| 42 |
+
shutil.rmtree(d)
|
| 43 |
+
elif os.path.isfile(d):
|
| 44 |
+
os.remove(d)
|
| 45 |
+
print("[OK] Cleaned build directories")
|
| 46 |
+
|
| 47 |
+
def build_app():
|
| 48 |
+
platform = get_platform()
|
| 49 |
+
print(f"\n{'='*50}")
|
| 50 |
+
print(f" Building {APP_NAME} v{VERSION} - {platform}")
|
| 51 |
+
print(f"{'='*50}\n")
|
| 52 |
+
|
| 53 |
+
ensure_pyinstaller()
|
| 54 |
+
clean_build()
|
| 55 |
+
|
| 56 |
+
args = [
|
| 57 |
+
sys.executable, "-m", "PyInstaller",
|
| 58 |
+
"--name", APP_NAME,
|
| 59 |
+
"--onefile",
|
| 60 |
+
"--clean",
|
| 61 |
+
"--noconfirm",
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
icon_file = None
|
| 65 |
+
if platform == "windows" and (ICON_DIR / "icon.ico").exists():
|
| 66 |
+
icon_file = ICON_DIR / "icon.ico"
|
| 67 |
+
elif platform == "macos" and (ICON_DIR / "icon.icns").exists():
|
| 68 |
+
icon_file = ICON_DIR / "icon.icns"
|
| 69 |
+
elif (ICON_DIR / "icon.png").exists():
|
| 70 |
+
icon_file = ICON_DIR / "icon.png"
|
| 71 |
+
|
| 72 |
+
if icon_file:
|
| 73 |
+
args.extend(["--icon", str(icon_file)])
|
| 74 |
+
print(f"[OK] Using icon: {icon_file}")
|
| 75 |
+
|
| 76 |
+
# 添加资源文件打包
|
| 77 |
+
if (ICON_DIR).exists():
|
| 78 |
+
if platform == "windows":
|
| 79 |
+
args.extend(["--add-data", f"{ICON_DIR};assets"])
|
| 80 |
+
else:
|
| 81 |
+
args.extend(["--add-data", f"{ICON_DIR}:assets"])
|
| 82 |
+
print(f"[OK] Adding assets directory")
|
| 83 |
+
|
| 84 |
+
# 添加文档文件打包
|
| 85 |
+
docs_dir = Path("kiro_proxy/docs")
|
| 86 |
+
if docs_dir.exists():
|
| 87 |
+
if platform == "windows":
|
| 88 |
+
args.extend(["--add-data", f"{docs_dir};kiro_proxy/docs"])
|
| 89 |
+
else:
|
| 90 |
+
args.extend(["--add-data", f"{docs_dir}:kiro_proxy/docs"])
|
| 91 |
+
print(f"[OK] Adding docs directory")
|
| 92 |
+
|
| 93 |
+
hidden_imports = [
|
| 94 |
+
"uvicorn.logging",
|
| 95 |
+
"uvicorn.protocols.http",
|
| 96 |
+
"uvicorn.protocols.http.auto",
|
| 97 |
+
"uvicorn.protocols.http.h11_impl",
|
| 98 |
+
"uvicorn.protocols.websockets",
|
| 99 |
+
"uvicorn.protocols.websockets.auto",
|
| 100 |
+
"uvicorn.lifespan",
|
| 101 |
+
"uvicorn.lifespan.on",
|
| 102 |
+
"httpx",
|
| 103 |
+
"httpx._transports",
|
| 104 |
+
"httpx._transports.default",
|
| 105 |
+
"anyio",
|
| 106 |
+
"anyio._backends",
|
| 107 |
+
"anyio._backends._asyncio",
|
| 108 |
+
]
|
| 109 |
+
for imp in hidden_imports:
|
| 110 |
+
args.extend(["--hidden-import", imp])
|
| 111 |
+
|
| 112 |
+
args.append(MAIN_SCRIPT)
|
| 113 |
+
args = [a for a in args if a]
|
| 114 |
+
|
| 115 |
+
print(f"[..] Running: {' '.join(args)}\n")
|
| 116 |
+
result = subprocess.run(args)
|
| 117 |
+
|
| 118 |
+
if result.returncode == 0:
|
| 119 |
+
if platform == "windows":
|
| 120 |
+
output = Path("dist") / f"{APP_NAME}.exe"
|
| 121 |
+
else:
|
| 122 |
+
output = Path("dist") / APP_NAME
|
| 123 |
+
|
| 124 |
+
if output.exists():
|
| 125 |
+
size_mb = output.stat().st_size / (1024 * 1024)
|
| 126 |
+
print(f"\n{'='*50}")
|
| 127 |
+
print(f" [OK] Build successful!")
|
| 128 |
+
print(f" Output: {output}")
|
| 129 |
+
print(f" Size: {size_mb:.1f} MB")
|
| 130 |
+
print(f"{'='*50}")
|
| 131 |
+
|
| 132 |
+
create_release_package(platform, output)
|
| 133 |
+
else:
|
| 134 |
+
print("[FAIL] Build failed: output file not found")
|
| 135 |
+
sys.exit(1)
|
| 136 |
+
else:
|
| 137 |
+
print("[FAIL] Build failed")
|
| 138 |
+
sys.exit(1)
|
| 139 |
+
|
| 140 |
+
def create_release_package(platform, binary_path):
|
| 141 |
+
release_dir = Path("release")
|
| 142 |
+
release_dir.mkdir(exist_ok=True)
|
| 143 |
+
|
| 144 |
+
if platform == "windows":
|
| 145 |
+
archive_name = f"{APP_NAME}-{VERSION}-Windows"
|
| 146 |
+
shutil.copy(binary_path, release_dir / f"{APP_NAME}.exe")
|
| 147 |
+
shutil.make_archive(
|
| 148 |
+
str(release_dir / archive_name),
|
| 149 |
+
"zip",
|
| 150 |
+
release_dir,
|
| 151 |
+
f"{APP_NAME}.exe"
|
| 152 |
+
)
|
| 153 |
+
(release_dir / f"{APP_NAME}.exe").unlink()
|
| 154 |
+
print(f" Release: release/{archive_name}.zip")
|
| 155 |
+
|
| 156 |
+
elif platform == "macos":
|
| 157 |
+
archive_name = f"{APP_NAME}-{VERSION}-macOS"
|
| 158 |
+
shutil.copy(binary_path, release_dir / APP_NAME)
|
| 159 |
+
os.chmod(release_dir / APP_NAME, 0o755)
|
| 160 |
+
shutil.make_archive(
|
| 161 |
+
str(release_dir / archive_name),
|
| 162 |
+
"zip",
|
| 163 |
+
release_dir,
|
| 164 |
+
APP_NAME
|
| 165 |
+
)
|
| 166 |
+
(release_dir / APP_NAME).unlink()
|
| 167 |
+
print(f" Release: release/{archive_name}.zip")
|
| 168 |
+
|
| 169 |
+
else:
|
| 170 |
+
archive_name = f"{APP_NAME}-{VERSION}-Linux"
|
| 171 |
+
shutil.copy(binary_path, release_dir / APP_NAME)
|
| 172 |
+
os.chmod(release_dir / APP_NAME, 0o755)
|
| 173 |
+
shutil.make_archive(
|
| 174 |
+
str(release_dir / archive_name),
|
| 175 |
+
"gztar",
|
| 176 |
+
release_dir,
|
| 177 |
+
APP_NAME
|
| 178 |
+
)
|
| 179 |
+
(release_dir / APP_NAME).unlink()
|
| 180 |
+
print(f" Release: release/{archive_name}.tar.gz")
|
| 181 |
+
|
| 182 |
+
def show_all_platforms():
|
| 183 |
+
print(f"""
|
| 184 |
+
{'='*60}
|
| 185 |
+
Kiro Proxy Cross-platform Build Instructions
|
| 186 |
+
{'='*60}
|
| 187 |
+
|
| 188 |
+
This script must run on the target platform.
|
| 189 |
+
|
| 190 |
+
[Windows]
|
| 191 |
+
Run on Windows:
|
| 192 |
+
python build.py
|
| 193 |
+
|
| 194 |
+
Output: release/KiroProxy-{VERSION}-Windows.zip
|
| 195 |
+
|
| 196 |
+
[macOS]
|
| 197 |
+
Run on macOS:
|
| 198 |
+
python build.py
|
| 199 |
+
|
| 200 |
+
Output: release/KiroProxy-{VERSION}-macOS.zip
|
| 201 |
+
|
| 202 |
+
[Linux]
|
| 203 |
+
Run on Linux:
|
| 204 |
+
python build.py
|
| 205 |
+
|
| 206 |
+
Output: release/KiroProxy-{VERSION}-Linux.tar.gz
|
| 207 |
+
|
| 208 |
+
[GitHub Actions]
|
| 209 |
+
Push to GitHub and Actions will build all platforms.
|
| 210 |
+
See .github/workflows/build.yml
|
| 211 |
+
|
| 212 |
+
{'='*60}
|
| 213 |
+
""")
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
|
| 216 |
+
if "--all" in sys.argv or "-a" in sys.argv:
|
| 217 |
+
show_all_platforms()
|
| 218 |
+
else:
|
| 219 |
+
build_app()
|
KiroProxy/examples/quota_display_example.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""展示额度重置时间功能的示例"""
|
| 2 |
+
import json
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def generate_quota_display_example():
|
| 7 |
+
"""生成额度显示示例"""
|
| 8 |
+
|
| 9 |
+
# 模拟账号的额度信息(从 API 获取)
|
| 10 |
+
quota_data = {
|
| 11 |
+
"subscription_title": "Kiro Pro",
|
| 12 |
+
"usage_limit": 700.0,
|
| 13 |
+
"current_usage": 150.0,
|
| 14 |
+
"balance": 550.0,
|
| 15 |
+
"usage_percent": 21.4,
|
| 16 |
+
"is_low_balance": False,
|
| 17 |
+
"is_exhausted": False,
|
| 18 |
+
"balance_status": "normal",
|
| 19 |
+
|
| 20 |
+
# 免费试用信息
|
| 21 |
+
"free_trial_limit": 500.0,
|
| 22 |
+
"free_trial_usage": 100.0,
|
| 23 |
+
"free_trial_expiry": "2026-02-13T23:59:59Z",
|
| 24 |
+
"trial_expiry_text": "2026-02-13",
|
| 25 |
+
|
| 26 |
+
# 奖励信息
|
| 27 |
+
"bonus_limit": 150.0,
|
| 28 |
+
"bonus_usage": 25.0,
|
| 29 |
+
"bonus_expiries": ["2026-03-01T23:59:59Z", "2026-02-28T23:59:59Z"],
|
| 30 |
+
"active_bonuses": 2,
|
| 31 |
+
|
| 32 |
+
# 重置时间
|
| 33 |
+
"next_reset_date": "2026-02-01T00:00:00Z",
|
| 34 |
+
"reset_date_text": "2026-02-01",
|
| 35 |
+
|
| 36 |
+
# 更新时间
|
| 37 |
+
"updated_at": "2分钟前",
|
| 38 |
+
"error": None
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# 生成 HTML 显示片段(类似在 Web 界面中的显示)
|
| 42 |
+
html_template = """
|
| 43 |
+
<div class="account-quota-section">
|
| 44 |
+
<div class="quota-header">
|
| 45 |
+
<span>已用/总额</span>
|
| 46 |
+
<span>{current_usage:.1f} / {usage_limit:.1f}</span>
|
| 47 |
+
</div>
|
| 48 |
+
<div class="progress-bar">
|
| 49 |
+
<div class="progress-fill" style="width: {usage_percent:.1f}%"></div>
|
| 50 |
+
</div>
|
| 51 |
+
<div class="quota-detail">
|
| 52 |
+
<span>试用: {free_trial_usage:.0f}/{free_trial_limit:.0f}</span>
|
| 53 |
+
<span>奖励: {bonus_usage:.0f}/{bonus_limit:.0f} ({active_bonuses}个)</span>
|
| 54 |
+
<span>更新: {updated_at}</span>
|
| 55 |
+
</div>
|
| 56 |
+
<div class="quota-reset-info">
|
| 57 |
+
<span>🔄 重置: {reset_date_text}</span>
|
| 58 |
+
<span>🎁 试用过期: {trial_expiry_text}</span>
|
| 59 |
+
</div>
|
| 60 |
+
</div>
|
| 61 |
+
""".format(**quota_data)
|
| 62 |
+
|
| 63 |
+
print("=== 额度信息展示示例 ===")
|
| 64 |
+
print(html_template)
|
| 65 |
+
|
| 66 |
+
# 生成卡片式展示
|
| 67 |
+
card_template = """
|
| 68 |
+
<div class="quota-card">
|
| 69 |
+
<h3>主配额</h3>
|
| 70 |
+
<div class="quota-amount">{current_usage:.0f} / {usage_limit:.0f}</div>
|
| 71 |
+
<div class="quota-reset">2026-02-01 重置</div>
|
| 72 |
+
</div>
|
| 73 |
+
<div class="quota-card">
|
| 74 |
+
<h3>免费试用</h3>
|
| 75 |
+
<div class="quota-amount">{free_trial_usage:.0f} / {free_trial_limit:.0f}</div>
|
| 76 |
+
<div class="quota-expiry">ACTIVE</div>
|
| 77 |
+
<div class="quota-reset">2026-02-13 过期</div>
|
| 78 |
+
</div>
|
| 79 |
+
<div class="quota-card">
|
| 80 |
+
<h3>奖励总计</h3>
|
| 81 |
+
<div class="quota-amount">{bonus_usage:.0f} / {bonus_limit:.0f}</div>
|
| 82 |
+
<div class="quota-expiry">{active_bonuses}个生效奖励</div>
|
| 83 |
+
</div>
|
| 84 |
+
""".format(**quota_data)
|
| 85 |
+
|
| 86 |
+
print("\n=== 卡片式展示(如图所示)===")
|
| 87 |
+
print(card_template)
|
| 88 |
+
|
| 89 |
+
# 生成 JSON 数据
|
| 90 |
+
print("\n=== JSON 数据格式 ===")
|
| 91 |
+
print(json.dumps(quota_data, indent=2, ensure_ascii=False))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
|
| 95 |
+
generate_quota_display_example()
|
KiroProxy/examples/test_quota_display.html
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html>
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<title>额度重置时间测试</title>
|
| 6 |
+
<style>
|
| 7 |
+
body {
|
| 8 |
+
font-family: Arial, sans-serif;
|
| 9 |
+
padding: 20px;
|
| 10 |
+
background: #f5f5f5;
|
| 11 |
+
}
|
| 12 |
+
.account-card {
|
| 13 |
+
background: white;
|
| 14 |
+
border-radius: 10px;
|
| 15 |
+
padding: 20px;
|
| 16 |
+
margin-bottom: 20px;
|
| 17 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
| 18 |
+
}
|
| 19 |
+
.quota-header {
|
| 20 |
+
display: flex;
|
| 21 |
+
justify-content: space-between;
|
| 22 |
+
margin-bottom: 10px;
|
| 23 |
+
font-weight: bold;
|
| 24 |
+
}
|
| 25 |
+
.progress-bar {
|
| 26 |
+
background: #e0e0e0;
|
| 27 |
+
border-radius: 4px;
|
| 28 |
+
height: 10px;
|
| 29 |
+
margin-bottom: 10px;
|
| 30 |
+
overflow: hidden;
|
| 31 |
+
}
|
| 32 |
+
.progress-fill {
|
| 33 |
+
background: #4CAF50;
|
| 34 |
+
height: 100%;
|
| 35 |
+
transition: width 0.3s;
|
| 36 |
+
}
|
| 37 |
+
.quota-detail {
|
| 38 |
+
display: flex;
|
| 39 |
+
gap: 20px;
|
| 40 |
+
font-size: 0.9em;
|
| 41 |
+
color: #666;
|
| 42 |
+
margin-bottom: 10px;
|
| 43 |
+
}
|
| 44 |
+
.quota-reset-info {
|
| 45 |
+
display: flex;
|
| 46 |
+
gap: 20px;
|
| 47 |
+
font-size: 0.8em;
|
| 48 |
+
color: #888;
|
| 49 |
+
}
|
| 50 |
+
.badge {
|
| 51 |
+
padding: 2px 8px;
|
| 52 |
+
border-radius: 4px;
|
| 53 |
+
font-size: 0.8em;
|
| 54 |
+
}
|
| 55 |
+
.badge.success { background: #4CAF50; color: white; }
|
| 56 |
+
.badge.error { background: #f44336; color: white; }
|
| 57 |
+
</style>
|
| 58 |
+
</head>
|
| 59 |
+
<body>
|
| 60 |
+
<h1>额度重置时间测试</h1>
|
| 61 |
+
<div id="accountsContainer"></div>
|
| 62 |
+
|
| 63 |
+
<script>
|
| 64 |
+
async function loadAccounts() {
|
| 65 |
+
try {
|
| 66 |
+
const response = await fetch('http://localhost:8080/api/accounts/status');
|
| 67 |
+
const data = await response.json();
|
| 68 |
+
|
| 69 |
+
const container = document.getElementById('accountsContainer');
|
| 70 |
+
container.innerHTML = '';
|
| 71 |
+
|
| 72 |
+
data.accounts.forEach(account => {
|
| 73 |
+
const quota = account.quota;
|
| 74 |
+
if (!quota) return;
|
| 75 |
+
|
| 76 |
+
const usedPercent = quota.usage_limit > 0 ? (quota.current_usage / quota.usage_limit * 100) : 0;
|
| 77 |
+
const isExhausted = quota.is_exhausted;
|
| 78 |
+
|
| 79 |
+
const card = document.createElement('div');
|
| 80 |
+
card.className = 'account-card';
|
| 81 |
+
card.innerHTML = `
|
| 82 |
+
<h3>${account.name} <span class="badge ${isExhausted ? 'error' : 'success'}">${isExhausted ? '额度耗尽' : '正常'}</span></h3>
|
| 83 |
+
<div class="quota-header">
|
| 84 |
+
<span>已用/总额</span>
|
| 85 |
+
<span>${quota.current_usage.toFixed(1)} / ${quota.usage_limit.toFixed(1)}</span>
|
| 86 |
+
</div>
|
| 87 |
+
<div class="progress-bar">
|
| 88 |
+
<div class="progress-fill" style="width: ${usedPercent}%"></div>
|
| 89 |
+
</div>
|
| 90 |
+
<div class="quota-detail">
|
| 91 |
+
<span>试用: ${quota.free_trial_usage.toFixed(0)}/${quota.free_trial_limit.toFixed(0)}</span>
|
| 92 |
+
<span>奖励: ${quota.bonus_usage.toFixed(0)}/${quota.bonus_limit.toFixed(0)} (${quota.active_bonuses}个)</span>
|
| 93 |
+
<span>更新: ${quota.updated_at || '未知'}</span>
|
| 94 |
+
</div>
|
| 95 |
+
${quota.reset_date_text || quota.trial_expiry_text ? `
|
| 96 |
+
<div class="quota-reset-info">
|
| 97 |
+
${quota.reset_date_text ? `<span>🔄 重置: ${quota.reset_date_text}</span>` : ''}
|
| 98 |
+
${quota.trial_expiry_text ? `<span>🎁 试用过期: ${quota.trial_expiry_text}</span>` : ''}
|
| 99 |
+
</div>
|
| 100 |
+
` : ''}
|
| 101 |
+
`;
|
| 102 |
+
|
| 103 |
+
container.appendChild(card);
|
| 104 |
+
});
|
| 105 |
+
} catch (error) {
|
| 106 |
+
console.error('加载失败:', error);
|
| 107 |
+
document.getElementById('accountsContainer').innerHTML = '<p>加载失败,请确保服务器正在运行</p>';
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// 页面加载时获取数据
|
| 112 |
+
loadAccounts();
|
| 113 |
+
|
| 114 |
+
// 每30秒刷新一次
|
| 115 |
+
setInterval(loadAccounts, 30000);
|
| 116 |
+
</script>
|
| 117 |
+
</body>
|
| 118 |
+
</html>
|
KiroProxy/kiro.svg
ADDED
|
|
KiroProxy/kiro_proxy/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kiro API Proxy
|
| 2 |
+
__version__ = "1.7.1"
|
KiroProxy/kiro_proxy/__main__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .cli import main
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
main()
|
KiroProxy/kiro_proxy/auth/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Kiro 认证模块"""
|
| 2 |
+
from .device_flow import (
|
| 3 |
+
start_device_flow,
|
| 4 |
+
poll_device_flow,
|
| 5 |
+
cancel_device_flow,
|
| 6 |
+
get_login_state,
|
| 7 |
+
save_credentials_to_file,
|
| 8 |
+
DeviceFlowState,
|
| 9 |
+
# Social Auth
|
| 10 |
+
start_social_auth,
|
| 11 |
+
exchange_social_auth_token,
|
| 12 |
+
cancel_social_auth,
|
| 13 |
+
get_social_auth_state,
|
| 14 |
+
start_callback_server,
|
| 15 |
+
wait_for_callback,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"start_device_flow",
|
| 20 |
+
"poll_device_flow",
|
| 21 |
+
"cancel_device_flow",
|
| 22 |
+
"get_login_state",
|
| 23 |
+
"save_credentials_to_file",
|
| 24 |
+
"DeviceFlowState",
|
| 25 |
+
# Social Auth
|
| 26 |
+
"start_social_auth",
|
| 27 |
+
"exchange_social_auth_token",
|
| 28 |
+
"cancel_social_auth",
|
| 29 |
+
"get_social_auth_state",
|
| 30 |
+
"start_callback_server",
|
| 31 |
+
"wait_for_callback",
|
| 32 |
+
]
|
KiroProxy/kiro_proxy/auth/device_flow.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Kiro Device Code Flow 登录
|
| 2 |
+
|
| 3 |
+
实现 AWS OIDC Device Authorization Flow:
|
| 4 |
+
1. 注册 OIDC 客户端 -> 获取 clientId + clientSecret
|
| 5 |
+
2. 发起设备授权 -> 获取 deviceCode + userCode + verificationUri
|
| 6 |
+
3. 用户在浏览器中输入 userCode 完成授权
|
| 7 |
+
4. 轮询 Token -> 获取 accessToken + refreshToken
|
| 8 |
+
|
| 9 |
+
Social Auth (Google/GitHub):
|
| 10 |
+
1. 生成 PKCE code_verifier 和 code_challenge
|
| 11 |
+
2. 构建登录 URL,打开浏览器
|
| 12 |
+
3. 启动本地回调服务器接收授权码
|
| 13 |
+
4. 用授权码交换 Token
|
| 14 |
+
"""
|
| 15 |
+
import json
|
| 16 |
+
import time
|
| 17 |
+
import httpx
|
| 18 |
+
import secrets
|
| 19 |
+
import hashlib
|
| 20 |
+
import base64
|
| 21 |
+
import asyncio
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from dataclasses import dataclass, asdict
|
| 24 |
+
from typing import Optional, Tuple
|
| 25 |
+
from datetime import datetime, timezone
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class DeviceFlowState:
|
| 30 |
+
"""设备授权流程状态"""
|
| 31 |
+
client_id: str
|
| 32 |
+
client_secret: str
|
| 33 |
+
device_code: str
|
| 34 |
+
user_code: str
|
| 35 |
+
verification_uri: str
|
| 36 |
+
interval: int
|
| 37 |
+
expires_at: int
|
| 38 |
+
region: str
|
| 39 |
+
started_at: float
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class SocialAuthState:
|
| 44 |
+
"""Social Auth 登录状态"""
|
| 45 |
+
provider: str # Google / Github
|
| 46 |
+
code_verifier: str
|
| 47 |
+
code_challenge: str
|
| 48 |
+
oauth_state: str
|
| 49 |
+
expires_at: int
|
| 50 |
+
started_at: float
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# 全局登录状态
|
| 54 |
+
_login_state: Optional[DeviceFlowState] = None
|
| 55 |
+
_social_auth_state: Optional[SocialAuthState] = None
|
| 56 |
+
_callback_server = None
|
| 57 |
+
|
| 58 |
+
# Kiro OIDC 配置
|
| 59 |
+
KIRO_START_URL = "https://view.awsapps.com/start"
|
| 60 |
+
KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
| 61 |
+
KIRO_SCOPES = [
|
| 62 |
+
"codewhisperer:completions",
|
| 63 |
+
"codewhisperer:analysis",
|
| 64 |
+
"codewhisperer:conversations",
|
| 65 |
+
"codewhisperer:transformations",
|
| 66 |
+
"codewhisperer:taskassist",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_login_state() -> Optional[dict]:
|
| 71 |
+
"""获取当前登录状态"""
|
| 72 |
+
global _login_state
|
| 73 |
+
if _login_state is None:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
# 检查是否过期
|
| 77 |
+
if time.time() > _login_state.expires_at:
|
| 78 |
+
_login_state = None
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
return {
|
| 82 |
+
"user_code": _login_state.user_code,
|
| 83 |
+
"verification_uri": _login_state.verification_uri,
|
| 84 |
+
"expires_in": int(_login_state.expires_at - time.time()),
|
| 85 |
+
"interval": _login_state.interval,
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
async def start_device_flow(region: str = "us-east-1") -> Tuple[bool, dict]:
|
| 90 |
+
"""
|
| 91 |
+
启动设备授权流程
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
(success, result_or_error)
|
| 95 |
+
"""
|
| 96 |
+
global _login_state
|
| 97 |
+
|
| 98 |
+
oidc_base = f"https://oidc.{region}.amazonaws.com"
|
| 99 |
+
|
| 100 |
+
async with httpx.AsyncClient(timeout=30) as client:
|
| 101 |
+
# Step 1: 注册 OIDC 客户端
|
| 102 |
+
print(f"[DeviceFlow] Step 1: 注册 OIDC 客户端...")
|
| 103 |
+
|
| 104 |
+
reg_body = {
|
| 105 |
+
"clientName": "Kiro Proxy",
|
| 106 |
+
"clientType": "public",
|
| 107 |
+
"scopes": KIRO_SCOPES,
|
| 108 |
+
"grantTypes": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
|
| 109 |
+
"issuerUrl": KIRO_START_URL
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
reg_resp = await client.post(
|
| 114 |
+
f"{oidc_base}/client/register",
|
| 115 |
+
json=reg_body,
|
| 116 |
+
headers={"Content-Type": "application/json"}
|
| 117 |
+
)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
return False, {"error": f"注册客户端请求失败: {e}"}
|
| 120 |
+
|
| 121 |
+
if reg_resp.status_code != 200:
|
| 122 |
+
return False, {"error": f"注册客户端失败: {reg_resp.text}"}
|
| 123 |
+
|
| 124 |
+
reg_data = reg_resp.json()
|
| 125 |
+
client_id = reg_data.get("clientId")
|
| 126 |
+
client_secret = reg_data.get("clientSecret")
|
| 127 |
+
|
| 128 |
+
if not client_id or not client_secret:
|
| 129 |
+
return False, {"error": "注册响应缺少 clientId 或 clientSecret"}
|
| 130 |
+
|
| 131 |
+
print(f"[DeviceFlow] 客户端注册成功: {client_id[:20]}...")
|
| 132 |
+
|
| 133 |
+
# Step 2: 发起设备授权
|
| 134 |
+
print(f"[DeviceFlow] Step 2: 发起设备授权...")
|
| 135 |
+
|
| 136 |
+
auth_body = {
|
| 137 |
+
"clientId": client_id,
|
| 138 |
+
"clientSecret": client_secret,
|
| 139 |
+
"startUrl": KIRO_START_URL
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
auth_resp = await client.post(
|
| 144 |
+
f"{oidc_base}/device_authorization",
|
| 145 |
+
json=auth_body,
|
| 146 |
+
headers={"Content-Type": "application/json"}
|
| 147 |
+
)
|
| 148 |
+
except Exception as e:
|
| 149 |
+
return False, {"error": f"设备授权请求失败: {e}"}
|
| 150 |
+
|
| 151 |
+
if auth_resp.status_code != 200:
|
| 152 |
+
return False, {"error": f"设备授权失败: {auth_resp.text}"}
|
| 153 |
+
|
| 154 |
+
auth_data = auth_resp.json()
|
| 155 |
+
device_code = auth_data.get("deviceCode")
|
| 156 |
+
user_code = auth_data.get("userCode")
|
| 157 |
+
verification_uri = auth_data.get("verificationUriComplete") or auth_data.get("verificationUri")
|
| 158 |
+
interval = auth_data.get("interval", 5)
|
| 159 |
+
expires_in = auth_data.get("expiresIn", 600)
|
| 160 |
+
|
| 161 |
+
if not device_code or not user_code or not verification_uri:
|
| 162 |
+
return False, {"error": "设备授权响应缺少必要字��"}
|
| 163 |
+
|
| 164 |
+
print(f"[DeviceFlow] 设备码获取成功: {user_code}")
|
| 165 |
+
|
| 166 |
+
# 保存状态
|
| 167 |
+
_login_state = DeviceFlowState(
|
| 168 |
+
client_id=client_id,
|
| 169 |
+
client_secret=client_secret,
|
| 170 |
+
device_code=device_code,
|
| 171 |
+
user_code=user_code,
|
| 172 |
+
verification_uri=verification_uri,
|
| 173 |
+
interval=interval,
|
| 174 |
+
expires_at=int(time.time() + expires_in),
|
| 175 |
+
region=region,
|
| 176 |
+
started_at=time.time()
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
return True, {
|
| 180 |
+
"user_code": user_code,
|
| 181 |
+
"verification_uri": verification_uri,
|
| 182 |
+
"expires_in": expires_in,
|
| 183 |
+
"interval": interval,
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
async def poll_device_flow() -> Tuple[bool, dict]:
|
| 188 |
+
"""
|
| 189 |
+
轮询设备授权状态
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
(success, result_or_error)
|
| 193 |
+
- success=True, result={"completed": True, "credentials": {...}} 授权完成
|
| 194 |
+
- success=True, result={"completed": False, "status": "pending"} 等待中
|
| 195 |
+
- success=False, result={"error": "..."} 错误
|
| 196 |
+
"""
|
| 197 |
+
global _login_state
|
| 198 |
+
|
| 199 |
+
if _login_state is None:
|
| 200 |
+
return False, {"error": "没有进行中的登录"}
|
| 201 |
+
|
| 202 |
+
# 检查是否过期
|
| 203 |
+
if time.time() > _login_state.expires_at:
|
| 204 |
+
_login_state = None
|
| 205 |
+
return False, {"error": "授权已过期,请重新开始"}
|
| 206 |
+
|
| 207 |
+
oidc_base = f"https://oidc.{_login_state.region}.amazonaws.com"
|
| 208 |
+
|
| 209 |
+
token_body = {
|
| 210 |
+
"clientId": _login_state.client_id,
|
| 211 |
+
"clientSecret": _login_state.client_secret,
|
| 212 |
+
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
| 213 |
+
"deviceCode": _login_state.device_code
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
async with httpx.AsyncClient(timeout=30) as client:
|
| 217 |
+
try:
|
| 218 |
+
token_resp = await client.post(
|
| 219 |
+
f"{oidc_base}/token",
|
| 220 |
+
json=token_body,
|
| 221 |
+
headers={"Content-Type": "application/json"}
|
| 222 |
+
)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
return False, {"error": f"Token 请求失败: {e}"}
|
| 225 |
+
|
| 226 |
+
if token_resp.status_code == 200:
|
| 227 |
+
# 授权成功
|
| 228 |
+
token_data = token_resp.json()
|
| 229 |
+
|
| 230 |
+
credentials = {
|
| 231 |
+
"accessToken": token_data.get("accessToken"),
|
| 232 |
+
"refreshToken": token_data.get("refreshToken"),
|
| 233 |
+
"expiresAt": datetime.now(timezone.utc).isoformat(),
|
| 234 |
+
"clientId": _login_state.client_id,
|
| 235 |
+
"clientSecret": _login_state.client_secret,
|
| 236 |
+
"region": _login_state.region,
|
| 237 |
+
"authMethod": "idc",
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
# 计算过期时间
|
| 241 |
+
if expires_in := token_data.get("expiresIn"):
|
| 242 |
+
from datetime import timedelta
|
| 243 |
+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 244 |
+
credentials["expiresAt"] = expires_at.isoformat()
|
| 245 |
+
|
| 246 |
+
# 清除状态
|
| 247 |
+
_login_state = None
|
| 248 |
+
|
| 249 |
+
print(f"[DeviceFlow] 授权成功!")
|
| 250 |
+
return True, {"completed": True, "credentials": credentials}
|
| 251 |
+
|
| 252 |
+
# 检查错误类型
|
| 253 |
+
try:
|
| 254 |
+
error_data = token_resp.json()
|
| 255 |
+
error_code = error_data.get("error", "")
|
| 256 |
+
except:
|
| 257 |
+
error_code = ""
|
| 258 |
+
|
| 259 |
+
if error_code == "authorization_pending":
|
| 260 |
+
# 用户还未完成授权
|
| 261 |
+
return True, {"completed": False, "status": "pending"}
|
| 262 |
+
elif error_code == "slow_down":
|
| 263 |
+
# 请求太频繁
|
| 264 |
+
return True, {"completed": False, "status": "slow_down"}
|
| 265 |
+
elif error_code == "expired_token":
|
| 266 |
+
_login_state = None
|
| 267 |
+
return False, {"error": "授权已过期,请重新开始"}
|
| 268 |
+
elif error_code == "access_denied":
|
| 269 |
+
_login_state = None
|
| 270 |
+
return False, {"error": "用户拒绝授权"}
|
| 271 |
+
else:
|
| 272 |
+
return False, {"error": f"Token 请求失败: {token_resp.text}"}
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def cancel_device_flow() -> bool:
|
| 276 |
+
"""取消设备授权流程"""
|
| 277 |
+
global _login_state
|
| 278 |
+
if _login_state is not None:
|
| 279 |
+
_login_state = None
|
| 280 |
+
return True
|
| 281 |
+
return False
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
async def save_credentials_to_file(credentials: dict, name: str = "kiro-proxy-auth") -> str:
|
| 285 |
+
"""
|
| 286 |
+
保存凭证到文件
|
| 287 |
+
|
| 288 |
+
支持的字段:
|
| 289 |
+
- accessToken, refreshToken, profileArn, expiresAt
|
| 290 |
+
- clientId, clientSecret (IDC 认证)
|
| 291 |
+
- region, authMethod, provider
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
保存的文件路径
|
| 295 |
+
"""
|
| 296 |
+
from ..config import TOKEN_DIR
|
| 297 |
+
TOKEN_DIR.mkdir(parents=True, exist_ok=True)
|
| 298 |
+
|
| 299 |
+
# 生成文件名
|
| 300 |
+
file_path = TOKEN_DIR / f"{name}.json"
|
| 301 |
+
|
| 302 |
+
# 如果文件已存在,合并现有数据
|
| 303 |
+
existing = {}
|
| 304 |
+
if file_path.exists():
|
| 305 |
+
try:
|
| 306 |
+
with open(file_path, "r") as f:
|
| 307 |
+
existing = json.load(f)
|
| 308 |
+
except Exception:
|
| 309 |
+
pass
|
| 310 |
+
|
| 311 |
+
# 更新凭证(只更新非空值)
|
| 312 |
+
for key, value in credentials.items():
|
| 313 |
+
if value is not None:
|
| 314 |
+
existing[key] = value
|
| 315 |
+
|
| 316 |
+
with open(file_path, "w") as f:
|
| 317 |
+
json.dump(existing, f, indent=2)
|
| 318 |
+
|
| 319 |
+
print(f"[DeviceFlow] 凭证已保存到: {file_path}")
|
| 320 |
+
return str(file_path)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ==================== Social Auth (Google/GitHub) ====================
|
| 324 |
+
|
| 325 |
+
def _generate_code_verifier() -> str:
|
| 326 |
+
"""生成 PKCE code_verifier"""
|
| 327 |
+
return secrets.token_urlsafe(64)[:128]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def _generate_code_challenge(verifier: str) -> str:
|
| 331 |
+
"""生成 PKCE code_challenge (SHA256)"""
|
| 332 |
+
digest = hashlib.sha256(verifier.encode()).digest()
|
| 333 |
+
return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _generate_oauth_state() -> str:
|
| 337 |
+
"""生成 OAuth state"""
|
| 338 |
+
return secrets.token_urlsafe(32)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def get_social_auth_state() -> Optional[dict]:
|
| 342 |
+
"""获取当前 Social Auth 状态"""
|
| 343 |
+
global _social_auth_state
|
| 344 |
+
if _social_auth_state is None:
|
| 345 |
+
return None
|
| 346 |
+
|
| 347 |
+
if time.time() > _social_auth_state.expires_at:
|
| 348 |
+
_social_auth_state = None
|
| 349 |
+
return None
|
| 350 |
+
|
| 351 |
+
return {
|
| 352 |
+
"provider": _social_auth_state.provider,
|
| 353 |
+
"expires_in": int(_social_auth_state.expires_at - time.time()),
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
async def start_social_auth(provider: str, redirect_uri: str = None) -> Tuple[bool, dict]:
|
| 358 |
+
"""
|
| 359 |
+
启动 Social Auth 登录 (Google/GitHub)
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
provider: "google" 或 "github"
|
| 363 |
+
redirect_uri: 回调地址,默认使用 Kiro 官方回调地址
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
(success, result_or_error)
|
| 367 |
+
"""
|
| 368 |
+
global _social_auth_state
|
| 369 |
+
|
| 370 |
+
# 验证 provider
|
| 371 |
+
provider_normalized = provider.lower()
|
| 372 |
+
if provider_normalized == "google":
|
| 373 |
+
provider_normalized = "Google"
|
| 374 |
+
elif provider_normalized == "github":
|
| 375 |
+
provider_normalized = "Github"
|
| 376 |
+
else:
|
| 377 |
+
return False, {"error": f"不支持的登录提供商: {provider}"}
|
| 378 |
+
|
| 379 |
+
print(f"[SocialAuth] 开始 {provider_normalized} 登录流程")
|
| 380 |
+
|
| 381 |
+
# 生成 PKCE
|
| 382 |
+
code_verifier = _generate_code_verifier()
|
| 383 |
+
code_challenge = _generate_code_challenge(code_verifier)
|
| 384 |
+
oauth_state = _generate_oauth_state()
|
| 385 |
+
|
| 386 |
+
# 回调地址 - 使用 Kiro 官方的回调地址(已在 Cognito 中注册)
|
| 387 |
+
# 参考 Kiro-account-manager: kiro://kiro.kiroAgent/authenticate-success
|
| 388 |
+
if redirect_uri is None:
|
| 389 |
+
redirect_uri = "kiro://kiro.kiroAgent/authenticate-success"
|
| 390 |
+
|
| 391 |
+
# 构建登录 URL (使用 /login 端点,参考 Kiro-account-manager)
|
| 392 |
+
from urllib.parse import quote, urlencode
|
| 393 |
+
|
| 394 |
+
# 使用 urlencode 确保参数正确编码
|
| 395 |
+
params = {
|
| 396 |
+
"idp": provider_normalized,
|
| 397 |
+
"redirect_uri": redirect_uri,
|
| 398 |
+
"code_challenge": code_challenge,
|
| 399 |
+
"code_challenge_method": "S256",
|
| 400 |
+
"state": oauth_state,
|
| 401 |
+
}
|
| 402 |
+
login_url = f"{KIRO_AUTH_ENDPOINT}/login?{urlencode(params)}"
|
| 403 |
+
|
| 404 |
+
print(f"[SocialAuth] ========== Social Auth 登录 ==========")
|
| 405 |
+
print(f"[SocialAuth] Provider: {provider_normalized}")
|
| 406 |
+
print(f"[SocialAuth] Redirect URI: {redirect_uri}")
|
| 407 |
+
print(f"[SocialAuth] Code Challenge: {code_challenge[:20]}...")
|
| 408 |
+
print(f"[SocialAuth] State: {oauth_state}")
|
| 409 |
+
print(f"[SocialAuth] 登录 URL: {login_url}")
|
| 410 |
+
print(f"[SocialAuth] =========================================")
|
| 411 |
+
|
| 412 |
+
# 保存状态(10 分钟过期)
|
| 413 |
+
_social_auth_state = SocialAuthState(
|
| 414 |
+
provider=provider_normalized,
|
| 415 |
+
code_verifier=code_verifier,
|
| 416 |
+
code_challenge=code_challenge,
|
| 417 |
+
oauth_state=oauth_state,
|
| 418 |
+
expires_at=int(time.time() + 600),
|
| 419 |
+
started_at=time.time(),
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
return True, {
|
| 423 |
+
"login_url": login_url,
|
| 424 |
+
"state": oauth_state,
|
| 425 |
+
"provider": provider_normalized,
|
| 426 |
+
"redirect_uri": redirect_uri,
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
async def exchange_social_auth_token(code: str, state: str, redirect_uri: str = None) -> Tuple[bool, dict]:
|
| 431 |
+
"""
|
| 432 |
+
用授权码交换 Token
|
| 433 |
+
|
| 434 |
+
参考 Kiro-account-manager 实现:
|
| 435 |
+
- 端点: https://prod.us-east-1.auth.desktop.kiro.dev/oauth/token
|
| 436 |
+
- 请求体: {code, code_verifier, redirect_uri}
|
| 437 |
+
- 响应: {accessToken, refreshToken, profileArn, expiresIn}
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
code: 授权码
|
| 441 |
+
state: OAuth state
|
| 442 |
+
redirect_uri: 回调地址(需要与 start_social_auth 中使用的一致)
|
| 443 |
+
|
| 444 |
+
Returns:
|
| 445 |
+
(success, result_or_error)
|
| 446 |
+
"""
|
| 447 |
+
global _social_auth_state
|
| 448 |
+
|
| 449 |
+
if _social_auth_state is None:
|
| 450 |
+
return False, {"error": "没有进行中的社交登录"}
|
| 451 |
+
|
| 452 |
+
# 验证 state
|
| 453 |
+
if state != _social_auth_state.oauth_state:
|
| 454 |
+
_social_auth_state = None
|
| 455 |
+
return False, {"error": "OAuth state 不匹配"}
|
| 456 |
+
|
| 457 |
+
# 检查过期
|
| 458 |
+
if time.time() > _social_auth_state.expires_at:
|
| 459 |
+
_social_auth_state = None
|
| 460 |
+
return False, {"error": "登录已过期,请重新开始"}
|
| 461 |
+
|
| 462 |
+
print(f"[SocialAuth] 交换 Token...")
|
| 463 |
+
|
| 464 |
+
# 回调地址 - 需要与 start_social_auth 中使��的一致
|
| 465 |
+
# 使用 Kiro 官方的回调地址
|
| 466 |
+
if redirect_uri is None:
|
| 467 |
+
redirect_uri = "kiro://kiro.kiroAgent/authenticate-success"
|
| 468 |
+
|
| 469 |
+
# 交换 Token (参考 Kiro-account-manager 的请求格式)
|
| 470 |
+
token_body = {
|
| 471 |
+
"code": code,
|
| 472 |
+
"code_verifier": _social_auth_state.code_verifier,
|
| 473 |
+
"redirect_uri": redirect_uri,
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
async with httpx.AsyncClient(timeout=30) as client:
|
| 477 |
+
try:
|
| 478 |
+
token_resp = await client.post(
|
| 479 |
+
f"{KIRO_AUTH_ENDPOINT}/oauth/token",
|
| 480 |
+
json=token_body,
|
| 481 |
+
headers={"Content-Type": "application/json"}
|
| 482 |
+
)
|
| 483 |
+
except Exception as e:
|
| 484 |
+
_social_auth_state = None
|
| 485 |
+
return False, {"error": f"Token 请求失败: {e}"}
|
| 486 |
+
|
| 487 |
+
if token_resp.status_code != 200:
|
| 488 |
+
error_text = token_resp.text
|
| 489 |
+
_social_auth_state = None
|
| 490 |
+
return False, {"error": f"Token 交换失败: {error_text}"}
|
| 491 |
+
|
| 492 |
+
token_data = token_resp.json()
|
| 493 |
+
|
| 494 |
+
# 解析响应 (参考 Kiro-account-manager 的响应格式)
|
| 495 |
+
# 响应字段: accessToken, refreshToken, profileArn, expiresIn
|
| 496 |
+
provider = _social_auth_state.provider
|
| 497 |
+
|
| 498 |
+
credentials = {
|
| 499 |
+
"accessToken": token_data.get("accessToken") or token_data.get("access_token"),
|
| 500 |
+
"refreshToken": token_data.get("refreshToken") or token_data.get("refresh_token"),
|
| 501 |
+
"profileArn": token_data.get("profileArn"),
|
| 502 |
+
"expiresAt": datetime.now(timezone.utc).isoformat(),
|
| 503 |
+
"authMethod": "social",
|
| 504 |
+
"provider": provider, # 保存 provider 字段
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
# 计算过期时间
|
| 508 |
+
expires_in = token_data.get("expiresIn") or token_data.get("expires_in")
|
| 509 |
+
if expires_in:
|
| 510 |
+
from datetime import timedelta
|
| 511 |
+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 512 |
+
credentials["expiresAt"] = expires_at.isoformat()
|
| 513 |
+
|
| 514 |
+
_social_auth_state = None
|
| 515 |
+
|
| 516 |
+
print(f"[SocialAuth] {provider} 登录成功!")
|
| 517 |
+
return True, {"completed": True, "credentials": credentials, "provider": provider}
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def cancel_social_auth() -> bool:
|
| 521 |
+
"""取消 Social Auth 登录"""
|
| 522 |
+
global _social_auth_state
|
| 523 |
+
if _social_auth_state is not None:
|
| 524 |
+
_social_auth_state = None
|
| 525 |
+
return True
|
| 526 |
+
return False
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
# ==================== 回调服务器 ====================
|
| 530 |
+
|
| 531 |
+
_callback_result = None
|
| 532 |
+
_callback_event = None
|
| 533 |
+
|
| 534 |
+
async def start_callback_server() -> Tuple[bool, dict]:
|
| 535 |
+
"""启动本地回调服务器"""
|
| 536 |
+
global _callback_result, _callback_event
|
| 537 |
+
|
| 538 |
+
from aiohttp import web
|
| 539 |
+
|
| 540 |
+
_callback_result = None
|
| 541 |
+
_callback_event = asyncio.Event()
|
| 542 |
+
|
| 543 |
+
async def handle_callback(request):
|
| 544 |
+
global _callback_result
|
| 545 |
+
code = request.query.get("code")
|
| 546 |
+
state = request.query.get("state")
|
| 547 |
+
error = request.query.get("error")
|
| 548 |
+
|
| 549 |
+
if error:
|
| 550 |
+
_callback_result = {"error": error}
|
| 551 |
+
elif code and state:
|
| 552 |
+
_callback_result = {"code": code, "state": state}
|
| 553 |
+
else:
|
| 554 |
+
_callback_result = {"error": "缺少授权码"}
|
| 555 |
+
|
| 556 |
+
_callback_event.set()
|
| 557 |
+
|
| 558 |
+
# 返回成功页面
|
| 559 |
+
html = """
|
| 560 |
+
<html>
|
| 561 |
+
<head><title>登录成功</title></head>
|
| 562 |
+
<body style="font-family:sans-serif;text-align:center;padding:50px">
|
| 563 |
+
<h1>✅ 登录成功</h1>
|
| 564 |
+
<p>您可以关闭此窗口并返回 Kiro Proxy</p>
|
| 565 |
+
<script>setTimeout(()=>window.close(),2000)</script>
|
| 566 |
+
</body>
|
| 567 |
+
</html>
|
| 568 |
+
"""
|
| 569 |
+
return web.Response(text=html, content_type="text/html")
|
| 570 |
+
|
| 571 |
+
app = web.Application()
|
| 572 |
+
app.router.add_get("/kiro-social-callback", handle_callback)
|
| 573 |
+
|
| 574 |
+
runner = web.AppRunner(app)
|
| 575 |
+
await runner.setup()
|
| 576 |
+
|
| 577 |
+
try:
|
| 578 |
+
site = web.TCPSite(runner, "127.0.0.1", 19823)
|
| 579 |
+
await site.start()
|
| 580 |
+
print("[SocialAuth] 回调服务器已启动: http://127.0.0.1:19823")
|
| 581 |
+
return True, {"port": 19823}
|
| 582 |
+
except Exception as e:
|
| 583 |
+
return False, {"error": f"启动回调服务器失败: {e}"}
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
async def wait_for_callback(timeout: int = 300) -> Tuple[bool, dict]:
|
| 587 |
+
"""等待回调"""
|
| 588 |
+
global _callback_result, _callback_event
|
| 589 |
+
|
| 590 |
+
if _callback_event is None:
|
| 591 |
+
return False, {"error": "回调服务器未启动"}
|
| 592 |
+
|
| 593 |
+
try:
|
| 594 |
+
await asyncio.wait_for(_callback_event.wait(), timeout=timeout)
|
| 595 |
+
|
| 596 |
+
if _callback_result and "code" in _callback_result:
|
| 597 |
+
return True, _callback_result
|
| 598 |
+
elif _callback_result and "error" in _callback_result:
|
| 599 |
+
return False, _callback_result
|
| 600 |
+
else:
|
| 601 |
+
return False, {"error": "未收到有效回调"}
|
| 602 |
+
except asyncio.TimeoutError:
|
| 603 |
+
return False, {"error": "等待回调超时"}
|
KiroProxy/kiro_proxy/cli.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Kiro Proxy CLI - 轻量命令行工具"""
|
| 3 |
+
import argparse
|
| 4 |
+
import asyncio
|
| 5 |
+
import json
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from . import __version__
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def cmd_serve(args):
|
| 13 |
+
"""启动代理服务"""
|
| 14 |
+
from .main import run
|
| 15 |
+
run(port=args.port)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def cmd_accounts_list(args):
|
| 19 |
+
"""列出所有账号"""
|
| 20 |
+
from .core import state
|
| 21 |
+
accounts = state.get_accounts_status()
|
| 22 |
+
if not accounts:
|
| 23 |
+
print("暂无账号")
|
| 24 |
+
return
|
| 25 |
+
print(f"{'ID':<10} {'名称':<20} {'状态':<10} {'请求数':<8}")
|
| 26 |
+
print("-" * 50)
|
| 27 |
+
for acc in accounts:
|
| 28 |
+
print(f"{acc['id']:<10} {acc['name']:<20} {acc['status']:<10} {acc['request_count']:<8}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def cmd_accounts_export(args):
|
| 32 |
+
"""导出账号配置"""
|
| 33 |
+
from .core import state
|
| 34 |
+
accounts_data = []
|
| 35 |
+
for acc in state.accounts:
|
| 36 |
+
creds = acc.get_credentials()
|
| 37 |
+
if creds:
|
| 38 |
+
accounts_data.append({
|
| 39 |
+
"name": acc.name,
|
| 40 |
+
"enabled": acc.enabled,
|
| 41 |
+
"credentials": {
|
| 42 |
+
"accessToken": creds.access_token,
|
| 43 |
+
"refreshToken": creds.refresh_token,
|
| 44 |
+
"expiresAt": creds.expires_at,
|
| 45 |
+
"region": creds.region,
|
| 46 |
+
"authMethod": creds.auth_method,
|
| 47 |
+
}
|
| 48 |
+
})
|
| 49 |
+
|
| 50 |
+
output = {"accounts": accounts_data, "version": "1.0"}
|
| 51 |
+
|
| 52 |
+
if args.output:
|
| 53 |
+
Path(args.output).write_text(json.dumps(output, indent=2, ensure_ascii=False))
|
| 54 |
+
print(f"已导出 {len(accounts_data)} 个账号到 {args.output}")
|
| 55 |
+
else:
|
| 56 |
+
print(json.dumps(output, indent=2, ensure_ascii=False))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def cmd_accounts_import(args):
|
| 60 |
+
"""导入账号配置"""
|
| 61 |
+
import uuid
|
| 62 |
+
from .core import state, Account
|
| 63 |
+
from .auth import save_credentials_to_file
|
| 64 |
+
|
| 65 |
+
data = json.loads(Path(args.file).read_text())
|
| 66 |
+
accounts_data = data.get("accounts", [])
|
| 67 |
+
imported = 0
|
| 68 |
+
|
| 69 |
+
for acc_data in accounts_data:
|
| 70 |
+
creds = acc_data.get("credentials", {})
|
| 71 |
+
if not creds.get("accessToken"):
|
| 72 |
+
print(f"跳过 {acc_data.get('name', '未知')}: 缺少 accessToken")
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
# 保存凭证到文件
|
| 76 |
+
file_path = asyncio.run(save_credentials_to_file({
|
| 77 |
+
"accessToken": creds.get("accessToken"),
|
| 78 |
+
"refreshToken": creds.get("refreshToken"),
|
| 79 |
+
"expiresAt": creds.get("expiresAt"),
|
| 80 |
+
"region": creds.get("region", "us-east-1"),
|
| 81 |
+
"authMethod": creds.get("authMethod", "social"),
|
| 82 |
+
}, f"imported-{uuid.uuid4().hex[:8]}"))
|
| 83 |
+
|
| 84 |
+
account = Account(
|
| 85 |
+
id=uuid.uuid4().hex[:8],
|
| 86 |
+
name=acc_data.get("name", "导入账号"),
|
| 87 |
+
token_path=file_path,
|
| 88 |
+
enabled=acc_data.get("enabled", True)
|
| 89 |
+
)
|
| 90 |
+
state.accounts.append(account)
|
| 91 |
+
account.load_credentials()
|
| 92 |
+
imported += 1
|
| 93 |
+
print(f"已导入: {account.name}")
|
| 94 |
+
|
| 95 |
+
state._save_accounts()
|
| 96 |
+
print(f"\n共导入 {imported} 个账号")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def cmd_accounts_add(args):
|
| 100 |
+
"""手动添加 Token"""
|
| 101 |
+
import uuid
|
| 102 |
+
from .core import state, Account
|
| 103 |
+
from .auth import save_credentials_to_file
|
| 104 |
+
|
| 105 |
+
print("手动添加 Kiro 账号")
|
| 106 |
+
print("-" * 40)
|
| 107 |
+
|
| 108 |
+
name = input("账号名称 [我的账号]: ").strip() or "我的账号"
|
| 109 |
+
print("\n请粘贴 Access Token:")
|
| 110 |
+
access_token = input().strip()
|
| 111 |
+
|
| 112 |
+
if not access_token:
|
| 113 |
+
print("错误: Access Token 不能为空")
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
print("\n请粘贴 Refresh Token (可选,直接回车跳过):")
|
| 117 |
+
refresh_token = input().strip() or None
|
| 118 |
+
|
| 119 |
+
# 保存凭证
|
| 120 |
+
file_path = asyncio.run(save_credentials_to_file({
|
| 121 |
+
"accessToken": access_token,
|
| 122 |
+
"refreshToken": refresh_token,
|
| 123 |
+
"region": "us-east-1",
|
| 124 |
+
"authMethod": "social",
|
| 125 |
+
}, f"manual-{uuid.uuid4().hex[:8]}"))
|
| 126 |
+
|
| 127 |
+
account = Account(
|
| 128 |
+
id=uuid.uuid4().hex[:8],
|
| 129 |
+
name=name,
|
| 130 |
+
token_path=file_path
|
| 131 |
+
)
|
| 132 |
+
state.accounts.append(account)
|
| 133 |
+
account.load_credentials()
|
| 134 |
+
state._save_accounts()
|
| 135 |
+
|
| 136 |
+
print(f"\n✅ 账号已添加: {name} (ID: {account.id})")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def cmd_accounts_scan(args):
|
| 140 |
+
"""扫描本地 Token"""
|
| 141 |
+
import uuid
|
| 142 |
+
from .core import state, Account
|
| 143 |
+
from .config import TOKEN_DIR
|
| 144 |
+
|
| 145 |
+
# 扫描新目录
|
| 146 |
+
found = []
|
| 147 |
+
if TOKEN_DIR.exists():
|
| 148 |
+
for f in TOKEN_DIR.glob("*.json"):
|
| 149 |
+
try:
|
| 150 |
+
data = json.loads(f.read_text())
|
| 151 |
+
if "accessToken" in data:
|
| 152 |
+
already = any(a.token_path == str(f) for a in state.accounts)
|
| 153 |
+
found.append({"path": str(f), "name": f.stem, "already": already})
|
| 154 |
+
except:
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
# 兼容旧目录
|
| 158 |
+
sso_cache = Path.home() / ".aws/sso/cache"
|
| 159 |
+
if sso_cache.exists():
|
| 160 |
+
for f in sso_cache.glob("*.json"):
|
| 161 |
+
try:
|
| 162 |
+
data = json.loads(f.read_text())
|
| 163 |
+
if "accessToken" in data:
|
| 164 |
+
already = any(a.token_path == str(f) for a in state.accounts)
|
| 165 |
+
found.append({"path": str(f), "name": f.stem + " (旧目录)", "already": already})
|
| 166 |
+
except:
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
if not found:
|
| 170 |
+
print("未找到 Token 文件")
|
| 171 |
+
print(f"Token 目录: {TOKEN_DIR}")
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
print(f"找到 {len(found)} 个 Token:\n")
|
| 175 |
+
for i, t in enumerate(found):
|
| 176 |
+
status = "[已添加]" if t["already"] else ""
|
| 177 |
+
print(f" {i+1}. {t['name']} {status}")
|
| 178 |
+
|
| 179 |
+
if args.auto:
|
| 180 |
+
# 自动添加所有未添加的
|
| 181 |
+
added = 0
|
| 182 |
+
for t in found:
|
| 183 |
+
if not t["already"]:
|
| 184 |
+
account = Account(
|
| 185 |
+
id=uuid.uuid4().hex[:8],
|
| 186 |
+
name=t["name"],
|
| 187 |
+
token_path=t["path"]
|
| 188 |
+
)
|
| 189 |
+
state.accounts.append(account)
|
| 190 |
+
account.load_credentials()
|
| 191 |
+
added += 1
|
| 192 |
+
state._save_accounts()
|
| 193 |
+
print(f"\n已添加 {added} 个账号")
|
| 194 |
+
else:
|
| 195 |
+
print("\n使用 --auto 自动添加所有未添加的账号")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def cmd_login_remote(args):
|
| 199 |
+
"""生成远程登录链接"""
|
| 200 |
+
import uuid
|
| 201 |
+
import time
|
| 202 |
+
|
| 203 |
+
session_id = uuid.uuid4().hex
|
| 204 |
+
host = args.host or "localhost:8080"
|
| 205 |
+
scheme = "https" if args.https else "http"
|
| 206 |
+
|
| 207 |
+
print("远程登录链接")
|
| 208 |
+
print("-" * 40)
|
| 209 |
+
print(f"\n将以下链接发送到有浏览器的机器上完成登录:\n")
|
| 210 |
+
print(f" {scheme}://{host}/remote-login/{session_id}")
|
| 211 |
+
print(f"\n链接有效期 10 分钟")
|
| 212 |
+
print("\n登录完成后,在那台机器上导出账号,然后在这里导入:")
|
| 213 |
+
print(f" python -m kiro_proxy accounts import xxx.json")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def cmd_login_social(args):
|
| 217 |
+
"""Social 登录 (Google/GitHub)"""
|
| 218 |
+
from .auth import start_social_auth
|
| 219 |
+
|
| 220 |
+
provider = args.provider
|
| 221 |
+
print(f"启动 {provider.title()} 登录...")
|
| 222 |
+
|
| 223 |
+
success, result = asyncio.run(start_social_auth(provider))
|
| 224 |
+
if not success:
|
| 225 |
+
print(f"错误: {result.get('error', '未知错误')}")
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
print(f"\n请在浏览器中打开以下链接完成授权:\n")
|
| 229 |
+
print(f" {result['login_url']}")
|
| 230 |
+
print(f"\n授权完成后,将浏览器地址栏中的完整 URL 粘贴到这里:")
|
| 231 |
+
callback_url = input().strip()
|
| 232 |
+
|
| 233 |
+
if not callback_url:
|
| 234 |
+
print("已取消")
|
| 235 |
+
return
|
| 236 |
+
|
| 237 |
+
try:
|
| 238 |
+
from urllib.parse import urlparse, parse_qs
|
| 239 |
+
parsed = urlparse(callback_url)
|
| 240 |
+
params = parse_qs(parsed.query)
|
| 241 |
+
code = params.get("code", [None])[0]
|
| 242 |
+
oauth_state = params.get("state", [None])[0]
|
| 243 |
+
|
| 244 |
+
if not code or not oauth_state:
|
| 245 |
+
print("错误: 无效的回调 URL")
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
from .auth import exchange_social_auth_token
|
| 249 |
+
success, result = asyncio.run(exchange_social_auth_token(code, oauth_state))
|
| 250 |
+
|
| 251 |
+
if success and result.get("completed"):
|
| 252 |
+
import uuid
|
| 253 |
+
from .core import state, Account
|
| 254 |
+
from .auth import save_credentials_to_file
|
| 255 |
+
|
| 256 |
+
credentials = result["credentials"]
|
| 257 |
+
file_path = asyncio.run(save_credentials_to_file(
|
| 258 |
+
credentials, f"cli-{provider}"
|
| 259 |
+
))
|
| 260 |
+
|
| 261 |
+
account = Account(
|
| 262 |
+
id=uuid.uuid4().hex[:8],
|
| 263 |
+
name=f"{provider.title()} 登录",
|
| 264 |
+
token_path=file_path
|
| 265 |
+
)
|
| 266 |
+
state.accounts.append(account)
|
| 267 |
+
account.load_credentials()
|
| 268 |
+
state._save_accounts()
|
| 269 |
+
|
| 270 |
+
print(f"\n✅ 登录成功! 账号已添加: {account.name}")
|
| 271 |
+
else:
|
| 272 |
+
print(f"错误: {result.get('error', '登录失败')}")
|
| 273 |
+
except Exception as e:
|
| 274 |
+
print(f"错误: {e}")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def cmd_status(args):
|
| 278 |
+
"""查看服务状态"""
|
| 279 |
+
from .core import state
|
| 280 |
+
stats = state.get_stats()
|
| 281 |
+
|
| 282 |
+
print("Kiro Proxy 状态")
|
| 283 |
+
print("-" * 40)
|
| 284 |
+
print(f"运行时间: {stats['uptime_seconds']} 秒")
|
| 285 |
+
print(f"总请求数: {stats['total_requests']}")
|
| 286 |
+
print(f"错误数: {stats['total_errors']}")
|
| 287 |
+
print(f"错误率: {stats['error_rate']}")
|
| 288 |
+
print(f"账号总数: {stats['accounts_total']}")
|
| 289 |
+
print(f"可用账号: {stats['accounts_available']}")
|
| 290 |
+
print(f"冷却中: {stats['accounts_cooldown']}")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def main():
|
| 294 |
+
parser = argparse.ArgumentParser(
|
| 295 |
+
prog="kiro-proxy",
|
| 296 |
+
description="Kiro API Proxy CLI"
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument("-v", "--version", action="version", version=__version__)
|
| 299 |
+
|
| 300 |
+
subparsers = parser.add_subparsers(dest="command", help="命令")
|
| 301 |
+
|
| 302 |
+
# serve
|
| 303 |
+
serve_parser = subparsers.add_parser("serve", help="启动代理服务")
|
| 304 |
+
serve_parser.add_argument("-p", "--port", type=int, default=8080, help="端口号")
|
| 305 |
+
serve_parser.set_defaults(func=cmd_serve)
|
| 306 |
+
|
| 307 |
+
# status
|
| 308 |
+
status_parser = subparsers.add_parser("status", help="查看状态")
|
| 309 |
+
status_parser.set_defaults(func=cmd_status)
|
| 310 |
+
|
| 311 |
+
# accounts
|
| 312 |
+
accounts_parser = subparsers.add_parser("accounts", help="账号管理")
|
| 313 |
+
accounts_sub = accounts_parser.add_subparsers(dest="accounts_cmd")
|
| 314 |
+
|
| 315 |
+
# accounts list
|
| 316 |
+
list_parser = accounts_sub.add_parser("list", help="列出账号")
|
| 317 |
+
list_parser.set_defaults(func=cmd_accounts_list)
|
| 318 |
+
|
| 319 |
+
# accounts export
|
| 320 |
+
export_parser = accounts_sub.add_parser("export", help="导出账号")
|
| 321 |
+
export_parser.add_argument("-o", "--output", help="输出文件")
|
| 322 |
+
export_parser.set_defaults(func=cmd_accounts_export)
|
| 323 |
+
|
| 324 |
+
# accounts import
|
| 325 |
+
import_parser = accounts_sub.add_parser("import", help="导入账号")
|
| 326 |
+
import_parser.add_argument("file", help="JSON 文件路径")
|
| 327 |
+
import_parser.set_defaults(func=cmd_accounts_import)
|
| 328 |
+
|
| 329 |
+
# accounts add
|
| 330 |
+
add_parser = accounts_sub.add_parser("add", help="手动添加 Token")
|
| 331 |
+
add_parser.set_defaults(func=cmd_accounts_add)
|
| 332 |
+
|
| 333 |
+
# accounts scan
|
| 334 |
+
scan_parser = accounts_sub.add_parser("scan", help="扫描本地 Token")
|
| 335 |
+
scan_parser.add_argument("--auto", action="store_true", help="自动添加")
|
| 336 |
+
scan_parser.set_defaults(func=cmd_accounts_scan)
|
| 337 |
+
|
| 338 |
+
# login
|
| 339 |
+
login_parser = subparsers.add_parser("login", help="登录")
|
| 340 |
+
login_sub = login_parser.add_subparsers(dest="login_cmd")
|
| 341 |
+
|
| 342 |
+
# login remote
|
| 343 |
+
remote_parser = login_sub.add_parser("remote", help="生成远程登录链接")
|
| 344 |
+
remote_parser.add_argument("--host", help="服务器地址 (如 example.com:8080)")
|
| 345 |
+
remote_parser.add_argument("--https", action="store_true", help="使用 HTTPS")
|
| 346 |
+
remote_parser.set_defaults(func=cmd_login_remote)
|
| 347 |
+
|
| 348 |
+
# login google
|
| 349 |
+
google_parser = login_sub.add_parser("google", help="Google 登录")
|
| 350 |
+
google_parser.set_defaults(func=cmd_login_social, provider="google")
|
| 351 |
+
|
| 352 |
+
# login github
|
| 353 |
+
github_parser = login_sub.add_parser("github", help="GitHub 登录")
|
| 354 |
+
github_parser.set_defaults(func=cmd_login_social, provider="github")
|
| 355 |
+
|
| 356 |
+
args = parser.parse_args()
|
| 357 |
+
|
| 358 |
+
if not args.command:
|
| 359 |
+
parser.print_help()
|
| 360 |
+
return
|
| 361 |
+
|
| 362 |
+
if args.command == "accounts" and not args.accounts_cmd:
|
| 363 |
+
accounts_parser.print_help()
|
| 364 |
+
return
|
| 365 |
+
|
| 366 |
+
if args.command == "login" and not args.login_cmd:
|
| 367 |
+
login_parser.print_help()
|
| 368 |
+
return
|
| 369 |
+
|
| 370 |
+
if hasattr(args, "func"):
|
| 371 |
+
args.func(args)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
|
| 375 |
+
main()
|
KiroProxy/kiro_proxy/config.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""配置模块"""
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
KIRO_API_URL = "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
|
| 5 |
+
MODELS_URL = "https://q.us-east-1.amazonaws.com/ListAvailableModels"
|
| 6 |
+
|
| 7 |
+
# 统一数据目录 (所有配置文件都在这里)
|
| 8 |
+
DATA_DIR = Path.home() / ".kiro-proxy"
|
| 9 |
+
|
| 10 |
+
# Token 存储目录
|
| 11 |
+
TOKEN_DIR = DATA_DIR / "tokens"
|
| 12 |
+
|
| 13 |
+
# 默认 Token 路径 (兼容旧代码)
|
| 14 |
+
TOKEN_PATH = TOKEN_DIR / "kiro-auth-token.json"
|
| 15 |
+
|
| 16 |
+
# 配额管理配置
|
| 17 |
+
QUOTA_COOLDOWN_SECONDS = 300 # 配额超限冷却时间(秒)
|
| 18 |
+
|
| 19 |
+
# 模型映射
|
| 20 |
+
MODEL_MAPPING = {
|
| 21 |
+
# Claude 3.5 -> Kiro Claude 4
|
| 22 |
+
"claude-3-5-sonnet-20241022": "claude-sonnet-4",
|
| 23 |
+
"claude-3-5-sonnet-latest": "claude-sonnet-4",
|
| 24 |
+
"claude-3-5-sonnet": "claude-sonnet-4",
|
| 25 |
+
"claude-3-5-haiku-20241022": "claude-haiku-4.5",
|
| 26 |
+
"claude-3-5-haiku-latest": "claude-haiku-4.5",
|
| 27 |
+
# Claude 3
|
| 28 |
+
"claude-3-opus-20240229": "claude-sonnet-4.5",
|
| 29 |
+
"claude-3-opus-latest": "claude-sonnet-4.5",
|
| 30 |
+
"claude-3-sonnet-20240229": "claude-sonnet-4",
|
| 31 |
+
"claude-3-haiku-20240307": "claude-haiku-4.5",
|
| 32 |
+
# Claude 4
|
| 33 |
+
"claude-4-sonnet": "claude-sonnet-4",
|
| 34 |
+
"claude-4-opus": "claude-sonnet-4.5",
|
| 35 |
+
# OpenAI GPT -> Claude
|
| 36 |
+
"gpt-4o": "claude-sonnet-4",
|
| 37 |
+
"gpt-4o-mini": "claude-haiku-4.5",
|
| 38 |
+
"gpt-4-turbo": "claude-sonnet-4",
|
| 39 |
+
"gpt-4": "claude-sonnet-4",
|
| 40 |
+
"gpt-3.5-turbo": "claude-haiku-4.5",
|
| 41 |
+
# OpenAI o1 -> Claude Opus
|
| 42 |
+
"o1": "claude-sonnet-4.5",
|
| 43 |
+
"o1-preview": "claude-sonnet-4.5",
|
| 44 |
+
"o1-mini": "claude-sonnet-4",
|
| 45 |
+
# Gemini -> Claude
|
| 46 |
+
"gemini-2.0-flash": "claude-sonnet-4",
|
| 47 |
+
"gemini-2.0-flash-thinking": "claude-sonnet-4.5",
|
| 48 |
+
"gemini-1.5-pro": "claude-sonnet-4.5",
|
| 49 |
+
"gemini-1.5-flash": "claude-sonnet-4",
|
| 50 |
+
# 别名
|
| 51 |
+
"sonnet": "claude-sonnet-4",
|
| 52 |
+
"haiku": "claude-haiku-4.5",
|
| 53 |
+
"opus": "claude-sonnet-4.5",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
KIRO_MODELS = {"auto", "claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5"}
|
| 57 |
+
|
| 58 |
+
def get_best_model_by_tier(tier: str, available_models: set = None) -> str:
|
| 59 |
+
"""根据等级获取最佳可用模型(等级对等 + 智能降级)"""
|
| 60 |
+
if available_models is None:
|
| 61 |
+
available_models = KIRO_MODELS
|
| 62 |
+
|
| 63 |
+
# 等级对等映射 + 降级路径
|
| 64 |
+
TIER_PRIORITIES = {
|
| 65 |
+
# Opus: 最强 → 次强 → 快速 → 自动
|
| 66 |
+
"opus": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"],
|
| 67 |
+
|
| 68 |
+
# Sonnet: 高性能 → 最强 → 标准 → 快速 → 自动
|
| 69 |
+
"sonnet": ["claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "auto"],
|
| 70 |
+
|
| 71 |
+
# Haiku: 快速 → 标准 → 高性能 → 自动
|
| 72 |
+
"haiku": ["claude-haiku-4.5", "claude-sonnet-4", "claude-sonnet-4.5", "auto"],
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
priorities = TIER_PRIORITIES.get(tier, TIER_PRIORITIES["sonnet"])
|
| 76 |
+
|
| 77 |
+
# 选择第一个可用的模型
|
| 78 |
+
for model in priorities:
|
| 79 |
+
if model in available_models:
|
| 80 |
+
return model
|
| 81 |
+
|
| 82 |
+
return "auto" # 最终回退
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def detect_model_tier(model: str) -> str:
|
| 86 |
+
"""智能检测模型等级"""
|
| 87 |
+
if not model:
|
| 88 |
+
return "sonnet" # 默认中等
|
| 89 |
+
|
| 90 |
+
model_lower = model.lower()
|
| 91 |
+
|
| 92 |
+
# 特殊模型优先检测(避免被通用关键词误判)
|
| 93 |
+
if "gemini" in model_lower:
|
| 94 |
+
if any(keyword in model_lower for keyword in ["1.5-pro", "pro"]):
|
| 95 |
+
return "opus"
|
| 96 |
+
elif any(keyword in model_lower for keyword in ["2.0", "flash"]):
|
| 97 |
+
return "sonnet" # Gemini 2.0 和 flash 系列归为 sonnet
|
| 98 |
+
|
| 99 |
+
# 等级关键词检测(优先级从高到低)
|
| 100 |
+
# Opus 等级 - 最强模型
|
| 101 |
+
if any(keyword in model_lower for keyword in ["opus", "o1", "max", "ultra", "premium"]):
|
| 102 |
+
return "opus"
|
| 103 |
+
|
| 104 |
+
# Haiku 等级 - 快速模型(需要排除 sonnet 中的 3.5)
|
| 105 |
+
if any(keyword in model_lower for keyword in ["haiku", "mini", "light", "fast", "turbo"]):
|
| 106 |
+
return "haiku"
|
| 107 |
+
# 特殊处理:gpt-3.5 系列属于 haiku
|
| 108 |
+
if "3.5" in model_lower and "sonnet" not in model_lower:
|
| 109 |
+
return "haiku"
|
| 110 |
+
|
| 111 |
+
# Sonnet 等级 - 平衡模型
|
| 112 |
+
if any(keyword in model_lower for keyword in ["sonnet", "4o", "4", "standard", "base"]):
|
| 113 |
+
return "sonnet"
|
| 114 |
+
|
| 115 |
+
return "sonnet" # 默认中等
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def map_model_name(model: str, available_models: set = None) -> str:
|
| 119 |
+
"""将外部模型名称映射到 Kiro 支持的名称(支持动态模型选择)"""
|
| 120 |
+
if not model:
|
| 121 |
+
return "auto"
|
| 122 |
+
|
| 123 |
+
# 1. 精确匹配优先
|
| 124 |
+
if model in MODEL_MAPPING:
|
| 125 |
+
return MODEL_MAPPING[model]
|
| 126 |
+
if model in KIRO_MODELS:
|
| 127 |
+
return model
|
| 128 |
+
|
| 129 |
+
# 2. 智能等级检测 + 动态选择
|
| 130 |
+
tier = detect_model_tier(model)
|
| 131 |
+
best_model = get_best_model_by_tier(tier, available_models)
|
| 132 |
+
|
| 133 |
+
return best_model
|
KiroProxy/kiro_proxy/converters/__init__.py
ADDED
|
@@ -0,0 +1,1196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""协议转换模块 - Anthropic/OpenAI/Gemini <-> Kiro
|
| 2 |
+
|
| 3 |
+
增强版:参考 proxycast 实现
|
| 4 |
+
- 工具数量限制(最多 50 个)
|
| 5 |
+
- 工具描述截断(最多 500 字符)
|
| 6 |
+
- 历史消息交替修复
|
| 7 |
+
- OpenAI tool 角色消息处理
|
| 8 |
+
- tool_choice: required 支持
|
| 9 |
+
- web_search 特殊工具支持
|
| 10 |
+
- tool_results 去重
|
| 11 |
+
"""
|
| 12 |
+
import json
|
| 13 |
+
import hashlib
|
| 14 |
+
import re
|
| 15 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 16 |
+
|
| 17 |
+
# 常量
|
| 18 |
+
MAX_TOOLS = 50
|
| 19 |
+
MAX_TOOL_DESCRIPTION_LENGTH = 500
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def generate_session_id(messages: list) -> str:
|
| 23 |
+
"""基于消息内容生成会话ID"""
|
| 24 |
+
content = json.dumps(messages[:3], sort_keys=True)
|
| 25 |
+
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def extract_images_from_content(content) -> Tuple[str, List[dict]]:
|
| 29 |
+
"""从消息内容中提取文本和图片
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
(text_content, images_list)
|
| 33 |
+
"""
|
| 34 |
+
if isinstance(content, str):
|
| 35 |
+
return content, []
|
| 36 |
+
|
| 37 |
+
if not isinstance(content, list):
|
| 38 |
+
return str(content) if content else "", []
|
| 39 |
+
|
| 40 |
+
text_parts = []
|
| 41 |
+
images = []
|
| 42 |
+
|
| 43 |
+
for block in content:
|
| 44 |
+
if isinstance(block, str):
|
| 45 |
+
text_parts.append(block)
|
| 46 |
+
elif isinstance(block, dict):
|
| 47 |
+
block_type = block.get("type", "")
|
| 48 |
+
|
| 49 |
+
if block_type == "text":
|
| 50 |
+
text_parts.append(block.get("text", ""))
|
| 51 |
+
|
| 52 |
+
elif block_type == "image":
|
| 53 |
+
# Anthropic 格式
|
| 54 |
+
source = block.get("source", {})
|
| 55 |
+
media_type = source.get("media_type", "image/jpeg")
|
| 56 |
+
data = source.get("data", "")
|
| 57 |
+
|
| 58 |
+
fmt = "jpeg"
|
| 59 |
+
if "png" in media_type:
|
| 60 |
+
fmt = "png"
|
| 61 |
+
elif "gif" in media_type:
|
| 62 |
+
fmt = "gif"
|
| 63 |
+
elif "webp" in media_type:
|
| 64 |
+
fmt = "webp"
|
| 65 |
+
|
| 66 |
+
if data:
|
| 67 |
+
images.append({
|
| 68 |
+
"format": fmt,
|
| 69 |
+
"source": {"bytes": data}
|
| 70 |
+
})
|
| 71 |
+
|
| 72 |
+
elif block_type == "image_url":
|
| 73 |
+
# OpenAI 格式
|
| 74 |
+
image_url = block.get("image_url", {})
|
| 75 |
+
url = image_url.get("url", "")
|
| 76 |
+
|
| 77 |
+
if url.startswith("data:"):
|
| 78 |
+
match = re.match(r'data:image/(\w+);base64,(.+)', url)
|
| 79 |
+
if match:
|
| 80 |
+
fmt = match.group(1)
|
| 81 |
+
data = match.group(2)
|
| 82 |
+
images.append({
|
| 83 |
+
"format": fmt,
|
| 84 |
+
"source": {"bytes": data}
|
| 85 |
+
})
|
| 86 |
+
|
| 87 |
+
return "\n".join(text_parts), images
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def truncate_description(desc: str, max_length: int = MAX_TOOL_DESCRIPTION_LENGTH) -> str:
|
| 91 |
+
"""截断工具描述"""
|
| 92 |
+
if len(desc) <= max_length:
|
| 93 |
+
return desc
|
| 94 |
+
return desc[:max_length - 3] + "..."
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ==================== Anthropic 转换 ====================
|
| 98 |
+
|
| 99 |
+
def convert_anthropic_tools_to_kiro(tools: List[dict]) -> List[dict]:
|
| 100 |
+
"""将 Anthropic 工具格式转换为 Kiro 格式
|
| 101 |
+
|
| 102 |
+
增强:
|
| 103 |
+
- 限制最多 50 个工具
|
| 104 |
+
- 截断过长的描述
|
| 105 |
+
- 支持 web_search 特殊工具
|
| 106 |
+
"""
|
| 107 |
+
kiro_tools = []
|
| 108 |
+
function_count = 0
|
| 109 |
+
|
| 110 |
+
for tool in tools:
|
| 111 |
+
name = tool.get("name", "")
|
| 112 |
+
|
| 113 |
+
# 特殊工具:web_search
|
| 114 |
+
if name in ("web_search", "web_search_20250305"):
|
| 115 |
+
kiro_tools.append({
|
| 116 |
+
"webSearchTool": {
|
| 117 |
+
"type": "web_search"
|
| 118 |
+
}
|
| 119 |
+
})
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
# 限制工具数量
|
| 123 |
+
if function_count >= MAX_TOOLS:
|
| 124 |
+
continue
|
| 125 |
+
function_count += 1
|
| 126 |
+
|
| 127 |
+
description = tool.get("description", f"Tool: {name}")
|
| 128 |
+
description = truncate_description(description)
|
| 129 |
+
|
| 130 |
+
input_schema = tool.get("input_schema", {"type": "object", "properties": {}})
|
| 131 |
+
|
| 132 |
+
kiro_tools.append({
|
| 133 |
+
"toolSpecification": {
|
| 134 |
+
"name": name,
|
| 135 |
+
"description": description,
|
| 136 |
+
"inputSchema": {
|
| 137 |
+
"json": input_schema
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
})
|
| 141 |
+
|
| 142 |
+
return kiro_tools
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def fix_history_alternation(history: List[dict], model_id: str = "claude-sonnet-4") -> List[dict]:
|
| 146 |
+
"""修复历史记录,确保 user/assistant 严格交替,并验证 toolUses/toolResults 配对
|
| 147 |
+
|
| 148 |
+
Kiro API 规则:
|
| 149 |
+
1. 消息必须严格交替:user -> assistant -> user -> assistant
|
| 150 |
+
2. 当 assistant 有 toolUses 时,下一条 user 必须有对应的 toolResults
|
| 151 |
+
3. 当 assistant 没有 toolUses 时,下一条 user 不能有 toolResults
|
| 152 |
+
"""
|
| 153 |
+
if not history:
|
| 154 |
+
return history
|
| 155 |
+
|
| 156 |
+
# 深拷贝以避免修改原始数据
|
| 157 |
+
import copy
|
| 158 |
+
history = copy.deepcopy(history)
|
| 159 |
+
|
| 160 |
+
fixed = []
|
| 161 |
+
|
| 162 |
+
for i, item in enumerate(history):
|
| 163 |
+
is_user = "userInputMessage" in item
|
| 164 |
+
is_assistant = "assistantResponseMessage" in item
|
| 165 |
+
|
| 166 |
+
if is_user:
|
| 167 |
+
# 检查上一条是否也是 user
|
| 168 |
+
if fixed and "userInputMessage" in fixed[-1]:
|
| 169 |
+
# 检查当前消息是否有 tool_results
|
| 170 |
+
user_msg = item["userInputMessage"]
|
| 171 |
+
ctx = user_msg.get("userInputMessageContext", {})
|
| 172 |
+
has_tool_results = bool(ctx.get("toolResults"))
|
| 173 |
+
|
| 174 |
+
if has_tool_results:
|
| 175 |
+
# 合并 tool_results 到上一条 user 消息
|
| 176 |
+
new_results = ctx["toolResults"]
|
| 177 |
+
last_user = fixed[-1]["userInputMessage"]
|
| 178 |
+
|
| 179 |
+
if "userInputMessageContext" not in last_user:
|
| 180 |
+
last_user["userInputMessageContext"] = {}
|
| 181 |
+
|
| 182 |
+
last_ctx = last_user["userInputMessageContext"]
|
| 183 |
+
if "toolResults" in last_ctx and last_ctx["toolResults"]:
|
| 184 |
+
last_ctx["toolResults"].extend(new_results)
|
| 185 |
+
else:
|
| 186 |
+
last_ctx["toolResults"] = new_results
|
| 187 |
+
continue
|
| 188 |
+
else:
|
| 189 |
+
# 插入一个占位 assistant 消息(不带 toolUses)
|
| 190 |
+
fixed.append({
|
| 191 |
+
"assistantResponseMessage": {
|
| 192 |
+
"content": "I understand."
|
| 193 |
+
}
|
| 194 |
+
})
|
| 195 |
+
|
| 196 |
+
# 验证 toolResults 与前一个 assistant 的 toolUses 配对
|
| 197 |
+
if fixed and "assistantResponseMessage" in fixed[-1]:
|
| 198 |
+
last_assistant = fixed[-1]["assistantResponseMessage"]
|
| 199 |
+
has_tool_uses = bool(last_assistant.get("toolUses"))
|
| 200 |
+
|
| 201 |
+
user_msg = item["userInputMessage"]
|
| 202 |
+
ctx = user_msg.get("userInputMessageContext", {})
|
| 203 |
+
has_tool_results = bool(ctx.get("toolResults"))
|
| 204 |
+
|
| 205 |
+
if has_tool_uses and not has_tool_results:
|
| 206 |
+
# assistant 有 toolUses 但 user 没有 toolResults
|
| 207 |
+
# 这是不允许的:不要删除 toolUses(否则会破坏后续上下文/导致 tool_use 轮次丢失)
|
| 208 |
+
# 改为在本条 user 前插入一个“工具结果占位” user 消息,与 toolUses 严格配对。
|
| 209 |
+
placeholder_results = []
|
| 210 |
+
for tu in (last_assistant.get("toolUses") or []):
|
| 211 |
+
tuid = ""
|
| 212 |
+
if isinstance(tu, dict):
|
| 213 |
+
tuid = tu.get("toolUseId") or ""
|
| 214 |
+
if tuid:
|
| 215 |
+
placeholder_results.append({
|
| 216 |
+
"content": [{"text": ""}],
|
| 217 |
+
"status": "success",
|
| 218 |
+
"toolUseId": tuid,
|
| 219 |
+
})
|
| 220 |
+
fixed.append({
|
| 221 |
+
"userInputMessage": {
|
| 222 |
+
"content": "Tool results provided.",
|
| 223 |
+
"modelId": model_id,
|
| 224 |
+
"origin": "AI_EDITOR",
|
| 225 |
+
"userInputMessageContext": {
|
| 226 |
+
"toolResults": placeholder_results
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
})
|
| 230 |
+
elif not has_tool_uses and has_tool_results:
|
| 231 |
+
# assistant 没有 toolUses 但 user 有 toolResults
|
| 232 |
+
# 这是不允许的,需要清除 user 的 toolResults
|
| 233 |
+
item["userInputMessage"].pop("userInputMessageContext", None)
|
| 234 |
+
|
| 235 |
+
fixed.append(item)
|
| 236 |
+
|
| 237 |
+
elif is_assistant:
|
| 238 |
+
# 检查上一条是否也是 assistant
|
| 239 |
+
if fixed and "assistantResponseMessage" in fixed[-1]:
|
| 240 |
+
# 插入一个占位 user 消息(不带 toolResults)
|
| 241 |
+
fixed.append({
|
| 242 |
+
"userInputMessage": {
|
| 243 |
+
"content": "Continue",
|
| 244 |
+
"modelId": model_id,
|
| 245 |
+
"origin": "AI_EDITOR"
|
| 246 |
+
}
|
| 247 |
+
})
|
| 248 |
+
|
| 249 |
+
# 如果历史为空,先插入一个 user 消息
|
| 250 |
+
if not fixed:
|
| 251 |
+
fixed.append({
|
| 252 |
+
"userInputMessage": {
|
| 253 |
+
"content": "Continue",
|
| 254 |
+
"modelId": model_id,
|
| 255 |
+
"origin": "AI_EDITOR"
|
| 256 |
+
}
|
| 257 |
+
})
|
| 258 |
+
|
| 259 |
+
fixed.append(item)
|
| 260 |
+
|
| 261 |
+
# 确保以 assistant 结尾(如果最后是 user,添加占位 assistant)
|
| 262 |
+
if fixed and "userInputMessage" in fixed[-1]:
|
| 263 |
+
# 不需要清除 toolResults,因为它是与前一个 assistant 的 toolUses 配对的
|
| 264 |
+
# 占位 assistant 只是为了满足交替规则
|
| 265 |
+
fixed.append({
|
| 266 |
+
"assistantResponseMessage": {
|
| 267 |
+
"content": "I understand."
|
| 268 |
+
}
|
| 269 |
+
})
|
| 270 |
+
|
| 271 |
+
return fixed
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def convert_anthropic_messages_to_kiro(messages: List[dict], system="") -> Tuple[str, List[dict], List[dict]]:
|
| 275 |
+
"""将 Anthropic 消息格式转换为 Kiro 格式
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
(user_content, history, tool_results)
|
| 279 |
+
"""
|
| 280 |
+
history = []
|
| 281 |
+
user_content = ""
|
| 282 |
+
current_tool_results = []
|
| 283 |
+
|
| 284 |
+
def _strip_thinking(text: str) -> str:
|
| 285 |
+
if text is None:
|
| 286 |
+
return ""
|
| 287 |
+
if not isinstance(text, str):
|
| 288 |
+
text = str(text)
|
| 289 |
+
if not text:
|
| 290 |
+
return ""
|
| 291 |
+
cleaned = text
|
| 292 |
+
while True:
|
| 293 |
+
start = find_real_thinking_start_tag(cleaned)
|
| 294 |
+
if start == -1:
|
| 295 |
+
break
|
| 296 |
+
end = find_real_thinking_end_tag(cleaned, start + len("<thinking>"))
|
| 297 |
+
if end == -1:
|
| 298 |
+
cleaned = cleaned[:start].rstrip()
|
| 299 |
+
break
|
| 300 |
+
before = cleaned[:start].rstrip()
|
| 301 |
+
after = cleaned[end + len("</thinking>"):].lstrip()
|
| 302 |
+
if before and after:
|
| 303 |
+
cleaned = before + "\n" + after
|
| 304 |
+
else:
|
| 305 |
+
cleaned = before or after
|
| 306 |
+
return cleaned.strip()
|
| 307 |
+
|
| 308 |
+
# 处理 system
|
| 309 |
+
system_text = ""
|
| 310 |
+
if isinstance(system, list):
|
| 311 |
+
for block in system:
|
| 312 |
+
if isinstance(block, dict) and block.get("type") == "text":
|
| 313 |
+
system_text += block.get("text", "") + "\n"
|
| 314 |
+
elif isinstance(block, str):
|
| 315 |
+
system_text += block + "\n"
|
| 316 |
+
system_text = system_text.strip()
|
| 317 |
+
elif isinstance(system, str):
|
| 318 |
+
system_text = system
|
| 319 |
+
|
| 320 |
+
system_text = _strip_thinking(system_text)
|
| 321 |
+
|
| 322 |
+
for i, msg in enumerate(messages):
|
| 323 |
+
role = msg.get("role", "")
|
| 324 |
+
content = msg.get("content", "")
|
| 325 |
+
is_last = (i == len(messages) - 1)
|
| 326 |
+
|
| 327 |
+
# 处理 content 列表
|
| 328 |
+
tool_results = []
|
| 329 |
+
text_parts = []
|
| 330 |
+
|
| 331 |
+
if isinstance(content, list):
|
| 332 |
+
for block in content:
|
| 333 |
+
if isinstance(block, dict):
|
| 334 |
+
if block.get("type") == "text":
|
| 335 |
+
text_parts.append(block.get("text", ""))
|
| 336 |
+
elif block.get("type") == "tool_result":
|
| 337 |
+
tr_content = block.get("content", "")
|
| 338 |
+
if isinstance(tr_content, list):
|
| 339 |
+
tr_text_parts = []
|
| 340 |
+
for tc in tr_content:
|
| 341 |
+
if isinstance(tc, dict) and tc.get("type") == "text":
|
| 342 |
+
tr_text_parts.append(tc.get("text", ""))
|
| 343 |
+
elif isinstance(tc, str):
|
| 344 |
+
tr_text_parts.append(tc)
|
| 345 |
+
tr_content = "\n".join(tr_text_parts)
|
| 346 |
+
|
| 347 |
+
# 处理 is_error
|
| 348 |
+
status = "error" if block.get("is_error") else "success"
|
| 349 |
+
|
| 350 |
+
tool_results.append({
|
| 351 |
+
"content": [{"text": str(tr_content)}],
|
| 352 |
+
"status": status,
|
| 353 |
+
"toolUseId": block.get("tool_use_id", "")
|
| 354 |
+
})
|
| 355 |
+
elif isinstance(block, str):
|
| 356 |
+
text_parts.append(block)
|
| 357 |
+
|
| 358 |
+
content = "\n".join(text_parts) if text_parts else ""
|
| 359 |
+
|
| 360 |
+
content = _strip_thinking(content)
|
| 361 |
+
|
| 362 |
+
# 处理工具结果
|
| 363 |
+
if tool_results:
|
| 364 |
+
# 去重
|
| 365 |
+
seen_ids = set()
|
| 366 |
+
unique_results = []
|
| 367 |
+
for tr in tool_results:
|
| 368 |
+
if tr["toolUseId"] not in seen_ids:
|
| 369 |
+
seen_ids.add(tr["toolUseId"])
|
| 370 |
+
unique_results.append(tr)
|
| 371 |
+
tool_results = unique_results
|
| 372 |
+
|
| 373 |
+
if is_last:
|
| 374 |
+
current_tool_results = tool_results
|
| 375 |
+
user_content = content if content else "Tool results provided."
|
| 376 |
+
else:
|
| 377 |
+
history.append({
|
| 378 |
+
"userInputMessage": {
|
| 379 |
+
"content": content if content else "Tool results provided.",
|
| 380 |
+
"modelId": "claude-sonnet-4",
|
| 381 |
+
"origin": "AI_EDITOR",
|
| 382 |
+
"userInputMessageContext": {
|
| 383 |
+
"toolResults": tool_results
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
})
|
| 387 |
+
continue
|
| 388 |
+
|
| 389 |
+
if role == "user":
|
| 390 |
+
if system_text and not history:
|
| 391 |
+
content = f"{system_text}\n\n{content}" if content else system_text
|
| 392 |
+
|
| 393 |
+
content = _strip_thinking(content)
|
| 394 |
+
|
| 395 |
+
if is_last:
|
| 396 |
+
user_content = content if content else "Continue"
|
| 397 |
+
else:
|
| 398 |
+
history.append({
|
| 399 |
+
"userInputMessage": {
|
| 400 |
+
"content": content if content else "Continue",
|
| 401 |
+
"modelId": "claude-sonnet-4",
|
| 402 |
+
"origin": "AI_EDITOR"
|
| 403 |
+
}
|
| 404 |
+
})
|
| 405 |
+
|
| 406 |
+
elif role == "assistant":
|
| 407 |
+
tool_uses = []
|
| 408 |
+
assistant_text = ""
|
| 409 |
+
|
| 410 |
+
if isinstance(msg.get("content"), list):
|
| 411 |
+
text_parts = []
|
| 412 |
+
for block in msg["content"]:
|
| 413 |
+
if isinstance(block, dict):
|
| 414 |
+
if block.get("type") == "tool_use":
|
| 415 |
+
tool_uses.append({
|
| 416 |
+
"toolUseId": block.get("id", ""),
|
| 417 |
+
"name": block.get("name", ""),
|
| 418 |
+
"input": block.get("input", {})
|
| 419 |
+
})
|
| 420 |
+
elif block.get("type") == "text":
|
| 421 |
+
text_parts.append(block.get("text", ""))
|
| 422 |
+
assistant_text = "\n".join(text_parts)
|
| 423 |
+
else:
|
| 424 |
+
assistant_text = content if isinstance(content, str) else ""
|
| 425 |
+
|
| 426 |
+
assistant_text = _strip_thinking(assistant_text)
|
| 427 |
+
|
| 428 |
+
if not assistant_text and not tool_uses:
|
| 429 |
+
continue
|
| 430 |
+
|
| 431 |
+
# 确保 assistant 消息有内容
|
| 432 |
+
if not assistant_text:
|
| 433 |
+
assistant_text = "I understand."
|
| 434 |
+
|
| 435 |
+
assistant_msg = {
|
| 436 |
+
"assistantResponseMessage": {
|
| 437 |
+
"content": assistant_text
|
| 438 |
+
}
|
| 439 |
+
}
|
| 440 |
+
# 只有在有 toolUses 时才添加这个字段
|
| 441 |
+
if tool_uses:
|
| 442 |
+
assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses
|
| 443 |
+
|
| 444 |
+
history.append(assistant_msg)
|
| 445 |
+
|
| 446 |
+
# 修复历史交替
|
| 447 |
+
history = fix_history_alternation(history)
|
| 448 |
+
|
| 449 |
+
return user_content, history, current_tool_results
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def convert_kiro_response_to_anthropic(result: dict, model: str, msg_id: str) -> dict:
|
| 453 |
+
"""将 Kiro 响应转换为 Anthropic 格式"""
|
| 454 |
+
content = []
|
| 455 |
+
text = "".join(result["content"])
|
| 456 |
+
if text:
|
| 457 |
+
content.append({"type": "text", "text": text})
|
| 458 |
+
|
| 459 |
+
for tool_use in result["tool_uses"]:
|
| 460 |
+
content.append(tool_use)
|
| 461 |
+
|
| 462 |
+
return {
|
| 463 |
+
"id": msg_id,
|
| 464 |
+
"type": "message",
|
| 465 |
+
"role": "assistant",
|
| 466 |
+
"content": content,
|
| 467 |
+
"model": model,
|
| 468 |
+
"stop_reason": result["stop_reason"],
|
| 469 |
+
"stop_sequence": None,
|
| 470 |
+
"usage": {"input_tokens": 100, "output_tokens": 100}
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# ==================== OpenAI 转换 ====================
|
| 475 |
+
|
| 476 |
+
def is_tool_choice_required(tool_choice) -> bool:
|
| 477 |
+
"""检查 tool_choice 是否为 required"""
|
| 478 |
+
if isinstance(tool_choice, dict):
|
| 479 |
+
t = tool_choice.get("type", "")
|
| 480 |
+
return t in ("any", "tool", "required")
|
| 481 |
+
elif isinstance(tool_choice, str):
|
| 482 |
+
return tool_choice in ("required", "any")
|
| 483 |
+
return False
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def convert_openai_tools_to_kiro(tools: List[dict]) -> List[dict]:
|
| 487 |
+
"""将 OpenAI 工具格式转换为 Kiro 格式"""
|
| 488 |
+
kiro_tools = []
|
| 489 |
+
function_count = 0
|
| 490 |
+
|
| 491 |
+
for tool in tools:
|
| 492 |
+
tool_type = tool.get("type", "function")
|
| 493 |
+
|
| 494 |
+
# 特殊工具
|
| 495 |
+
if tool_type == "web_search":
|
| 496 |
+
kiro_tools.append({
|
| 497 |
+
"webSearchTool": {
|
| 498 |
+
"type": "web_search"
|
| 499 |
+
}
|
| 500 |
+
})
|
| 501 |
+
continue
|
| 502 |
+
|
| 503 |
+
if tool_type != "function":
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
# 限制工具数量
|
| 507 |
+
if function_count >= MAX_TOOLS:
|
| 508 |
+
continue
|
| 509 |
+
function_count += 1
|
| 510 |
+
|
| 511 |
+
func = tool.get("function", {})
|
| 512 |
+
name = func.get("name", "")
|
| 513 |
+
description = func.get("description", f"Tool: {name}")
|
| 514 |
+
description = truncate_description(description)
|
| 515 |
+
parameters = func.get("parameters", {"type": "object", "properties": {}})
|
| 516 |
+
|
| 517 |
+
kiro_tools.append({
|
| 518 |
+
"toolSpecification": {
|
| 519 |
+
"name": name,
|
| 520 |
+
"description": description,
|
| 521 |
+
"inputSchema": {
|
| 522 |
+
"json": parameters
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
})
|
| 526 |
+
|
| 527 |
+
return kiro_tools
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def convert_openai_messages_to_kiro(
|
| 531 |
+
messages: List[dict],
|
| 532 |
+
model: str,
|
| 533 |
+
tools: List[dict] = None,
|
| 534 |
+
tool_choice = None
|
| 535 |
+
) -> Tuple[str, List[dict], List[dict], List[dict]]:
|
| 536 |
+
"""将 OpenAI 消息格式转换为 Kiro 格式
|
| 537 |
+
|
| 538 |
+
增强:
|
| 539 |
+
- 支持 tool 角色消息
|
| 540 |
+
- 支持 assistant 的 tool_calls
|
| 541 |
+
- 支持 tool_choice: required
|
| 542 |
+
- 历史交替修复
|
| 543 |
+
|
| 544 |
+
Returns:
|
| 545 |
+
(user_content, history, tool_results, kiro_tools)
|
| 546 |
+
"""
|
| 547 |
+
system_content = ""
|
| 548 |
+
history = []
|
| 549 |
+
user_content = ""
|
| 550 |
+
current_tool_results = []
|
| 551 |
+
pending_tool_results = [] # 待处理的 tool 消息
|
| 552 |
+
|
| 553 |
+
# 处理 tool_choice: required
|
| 554 |
+
tool_instruction = ""
|
| 555 |
+
if is_tool_choice_required(tool_choice) and tools:
|
| 556 |
+
tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text. Call a tool function immediately."
|
| 557 |
+
|
| 558 |
+
for i, msg in enumerate(messages):
|
| 559 |
+
role = msg.get("role", "")
|
| 560 |
+
content = msg.get("content", "")
|
| 561 |
+
is_last = (i == len(messages) - 1)
|
| 562 |
+
|
| 563 |
+
# 提取文本内容
|
| 564 |
+
if isinstance(content, list):
|
| 565 |
+
content = " ".join([c.get("text", "") for c in content if c.get("type") == "text"])
|
| 566 |
+
if not content:
|
| 567 |
+
content = ""
|
| 568 |
+
|
| 569 |
+
if role == "system":
|
| 570 |
+
system_content = content + tool_instruction
|
| 571 |
+
|
| 572 |
+
elif role == "tool":
|
| 573 |
+
# OpenAI tool 角色消息 -> Kiro toolResults
|
| 574 |
+
tool_call_id = msg.get("tool_call_id", "")
|
| 575 |
+
pending_tool_results.append({
|
| 576 |
+
"content": [{"text": str(content)}],
|
| 577 |
+
"status": "success",
|
| 578 |
+
"toolUseId": tool_call_id
|
| 579 |
+
})
|
| 580 |
+
|
| 581 |
+
elif role == "user":
|
| 582 |
+
# 如果有待处理的 tool results,先处理
|
| 583 |
+
if pending_tool_results:
|
| 584 |
+
# 去重
|
| 585 |
+
seen_ids = set()
|
| 586 |
+
unique_results = []
|
| 587 |
+
for tr in pending_tool_results:
|
| 588 |
+
if tr["toolUseId"] not in seen_ids:
|
| 589 |
+
seen_ids.add(tr["toolUseId"])
|
| 590 |
+
unique_results.append(tr)
|
| 591 |
+
|
| 592 |
+
if is_last:
|
| 593 |
+
current_tool_results = unique_results
|
| 594 |
+
else:
|
| 595 |
+
history.append({
|
| 596 |
+
"userInputMessage": {
|
| 597 |
+
"content": "Tool results provided.",
|
| 598 |
+
"modelId": model,
|
| 599 |
+
"origin": "AI_EDITOR",
|
| 600 |
+
"userInputMessageContext": {
|
| 601 |
+
"toolResults": unique_results
|
| 602 |
+
}
|
| 603 |
+
}
|
| 604 |
+
})
|
| 605 |
+
pending_tool_results = []
|
| 606 |
+
|
| 607 |
+
# 合并 system prompt
|
| 608 |
+
if system_content and not history:
|
| 609 |
+
content = f"{system_content}\n\n{content}"
|
| 610 |
+
|
| 611 |
+
if is_last:
|
| 612 |
+
user_content = content
|
| 613 |
+
else:
|
| 614 |
+
history.append({
|
| 615 |
+
"userInputMessage": {
|
| 616 |
+
"content": content,
|
| 617 |
+
"modelId": model,
|
| 618 |
+
"origin": "AI_EDITOR"
|
| 619 |
+
}
|
| 620 |
+
})
|
| 621 |
+
|
| 622 |
+
elif role == "assistant":
|
| 623 |
+
# 如果有待处理的 tool results,先创建 user 消息
|
| 624 |
+
if pending_tool_results:
|
| 625 |
+
seen_ids = set()
|
| 626 |
+
unique_results = []
|
| 627 |
+
for tr in pending_tool_results:
|
| 628 |
+
if tr["toolUseId"] not in seen_ids:
|
| 629 |
+
seen_ids.add(tr["toolUseId"])
|
| 630 |
+
unique_results.append(tr)
|
| 631 |
+
|
| 632 |
+
history.append({
|
| 633 |
+
"userInputMessage": {
|
| 634 |
+
"content": "Tool results provided.",
|
| 635 |
+
"modelId": model,
|
| 636 |
+
"origin": "AI_EDITOR",
|
| 637 |
+
"userInputMessageContext": {
|
| 638 |
+
"toolResults": unique_results
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
})
|
| 642 |
+
pending_tool_results = []
|
| 643 |
+
|
| 644 |
+
# 处理 tool_calls
|
| 645 |
+
tool_uses = []
|
| 646 |
+
tool_calls = msg.get("tool_calls", [])
|
| 647 |
+
for tc in tool_calls:
|
| 648 |
+
func = tc.get("function", {})
|
| 649 |
+
args_str = func.get("arguments", "{}")
|
| 650 |
+
try:
|
| 651 |
+
args = json.loads(args_str)
|
| 652 |
+
except:
|
| 653 |
+
args = {}
|
| 654 |
+
|
| 655 |
+
tool_uses.append({
|
| 656 |
+
"toolUseId": tc.get("id", ""),
|
| 657 |
+
"name": func.get("name", ""),
|
| 658 |
+
"input": args
|
| 659 |
+
})
|
| 660 |
+
|
| 661 |
+
assistant_text = content if content else "I understand."
|
| 662 |
+
|
| 663 |
+
assistant_msg = {
|
| 664 |
+
"assistantResponseMessage": {
|
| 665 |
+
"content": assistant_text
|
| 666 |
+
}
|
| 667 |
+
}
|
| 668 |
+
# 只有在有 toolUses 时才添加这个字段
|
| 669 |
+
if tool_uses:
|
| 670 |
+
assistant_msg["assistantResponseMessage"]["toolUses"] = tool_uses
|
| 671 |
+
|
| 672 |
+
history.append(assistant_msg)
|
| 673 |
+
|
| 674 |
+
# 处理末尾的 tool results
|
| 675 |
+
if pending_tool_results:
|
| 676 |
+
seen_ids = set()
|
| 677 |
+
unique_results = []
|
| 678 |
+
for tr in pending_tool_results:
|
| 679 |
+
if tr["toolUseId"] not in seen_ids:
|
| 680 |
+
seen_ids.add(tr["toolUseId"])
|
| 681 |
+
unique_results.append(tr)
|
| 682 |
+
current_tool_results = unique_results
|
| 683 |
+
if not user_content:
|
| 684 |
+
user_content = "Tool results provided."
|
| 685 |
+
|
| 686 |
+
# 如果没有用户消息
|
| 687 |
+
if not user_content:
|
| 688 |
+
user_content = messages[-1].get("content", "") if messages else "Continue"
|
| 689 |
+
if isinstance(user_content, list):
|
| 690 |
+
user_content = " ".join([c.get("text", "") for c in user_content if c.get("type") == "text"])
|
| 691 |
+
if not user_content:
|
| 692 |
+
user_content = "Continue"
|
| 693 |
+
|
| 694 |
+
# 历史不包含最后一条用户消息
|
| 695 |
+
if history and "userInputMessage" in history[-1]:
|
| 696 |
+
history = history[:-1]
|
| 697 |
+
|
| 698 |
+
# 修复历史交替
|
| 699 |
+
history = fix_history_alternation(history, model)
|
| 700 |
+
|
| 701 |
+
# 转换工具
|
| 702 |
+
kiro_tools = convert_openai_tools_to_kiro(tools) if tools else []
|
| 703 |
+
|
| 704 |
+
return user_content, history, current_tool_results, kiro_tools
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def convert_kiro_response_to_openai(result: dict, model: str, msg_id: str) -> dict:
|
| 708 |
+
"""将 Kiro 响应转换为 OpenAI 格式"""
|
| 709 |
+
text = "".join(result["content"])
|
| 710 |
+
tool_calls = []
|
| 711 |
+
|
| 712 |
+
for tool_use in result.get("tool_uses", []):
|
| 713 |
+
if tool_use.get("type") == "tool_use":
|
| 714 |
+
tool_calls.append({
|
| 715 |
+
"id": tool_use.get("id", ""),
|
| 716 |
+
"type": "function",
|
| 717 |
+
"function": {
|
| 718 |
+
"name": tool_use.get("name", ""),
|
| 719 |
+
"arguments": json.dumps(tool_use.get("input", {}))
|
| 720 |
+
}
|
| 721 |
+
})
|
| 722 |
+
|
| 723 |
+
# 映射 stop_reason
|
| 724 |
+
stop_reason = result.get("stop_reason", "stop")
|
| 725 |
+
finish_reason = "tool_calls" if tool_calls else "stop"
|
| 726 |
+
if stop_reason == "max_tokens":
|
| 727 |
+
finish_reason = "length"
|
| 728 |
+
|
| 729 |
+
message = {
|
| 730 |
+
"role": "assistant",
|
| 731 |
+
"content": text if text else None
|
| 732 |
+
}
|
| 733 |
+
if tool_calls:
|
| 734 |
+
message["tool_calls"] = tool_calls
|
| 735 |
+
|
| 736 |
+
return {
|
| 737 |
+
"id": msg_id,
|
| 738 |
+
"object": "chat.completion",
|
| 739 |
+
"model": model,
|
| 740 |
+
"choices": [{
|
| 741 |
+
"index": 0,
|
| 742 |
+
"message": message,
|
| 743 |
+
"finish_reason": finish_reason
|
| 744 |
+
}],
|
| 745 |
+
"usage": {
|
| 746 |
+
"prompt_tokens": 100,
|
| 747 |
+
"completion_tokens": 100,
|
| 748 |
+
"total_tokens": 200
|
| 749 |
+
}
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# ==================== Gemini 转换 ====================
|
| 754 |
+
|
| 755 |
+
def convert_gemini_tools_to_kiro(tools: List[dict]) -> List[dict]:
|
| 756 |
+
"""将 Gemini 工具格式转换为 Kiro 格式
|
| 757 |
+
|
| 758 |
+
Gemini 工具格式:
|
| 759 |
+
{
|
| 760 |
+
"functionDeclarations": [
|
| 761 |
+
{
|
| 762 |
+
"name": "get_weather",
|
| 763 |
+
"description": "Get weather info",
|
| 764 |
+
"parameters": {...}
|
| 765 |
+
}
|
| 766 |
+
]
|
| 767 |
+
}
|
| 768 |
+
"""
|
| 769 |
+
kiro_tools = []
|
| 770 |
+
function_count = 0
|
| 771 |
+
|
| 772 |
+
for tool in tools:
|
| 773 |
+
# Gemini 的工具定义在 functionDeclarations 中
|
| 774 |
+
declarations = tool.get("functionDeclarations", [])
|
| 775 |
+
|
| 776 |
+
for func in declarations:
|
| 777 |
+
# 限制工具数量
|
| 778 |
+
if function_count >= MAX_TOOLS:
|
| 779 |
+
break
|
| 780 |
+
function_count += 1
|
| 781 |
+
|
| 782 |
+
name = func.get("name", "")
|
| 783 |
+
description = func.get("description", f"Tool: {name}")
|
| 784 |
+
description = truncate_description(description)
|
| 785 |
+
parameters = func.get("parameters", {"type": "object", "properties": {}})
|
| 786 |
+
|
| 787 |
+
kiro_tools.append({
|
| 788 |
+
"toolSpecification": {
|
| 789 |
+
"name": name,
|
| 790 |
+
"description": description,
|
| 791 |
+
"inputSchema": {
|
| 792 |
+
"json": parameters
|
| 793 |
+
}
|
| 794 |
+
}
|
| 795 |
+
})
|
| 796 |
+
|
| 797 |
+
return kiro_tools
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
def convert_gemini_contents_to_kiro(
|
| 801 |
+
contents: List[dict],
|
| 802 |
+
system_instruction: dict,
|
| 803 |
+
model: str,
|
| 804 |
+
tools: List[dict] = None,
|
| 805 |
+
tool_config: dict = None
|
| 806 |
+
) -> Tuple[str, List[dict], List[dict], List[dict]]:
|
| 807 |
+
"""将 Gemini 消息格式转换为 Kiro 格式
|
| 808 |
+
|
| 809 |
+
增强:
|
| 810 |
+
- 支持 functionCall 和 functionResponse
|
| 811 |
+
- 支持 tool_config
|
| 812 |
+
|
| 813 |
+
Returns:
|
| 814 |
+
(user_content, history, tool_results, kiro_tools)
|
| 815 |
+
"""
|
| 816 |
+
history = []
|
| 817 |
+
user_content = ""
|
| 818 |
+
current_tool_results = []
|
| 819 |
+
pending_tool_results = []
|
| 820 |
+
|
| 821 |
+
# 处理 system instruction
|
| 822 |
+
system_text = ""
|
| 823 |
+
if system_instruction:
|
| 824 |
+
parts = system_instruction.get("parts", [])
|
| 825 |
+
system_text = " ".join(p.get("text", "") for p in parts if "text" in p)
|
| 826 |
+
|
| 827 |
+
# 处理 tool_config(类似 tool_choice)
|
| 828 |
+
tool_instruction = ""
|
| 829 |
+
if tool_config:
|
| 830 |
+
mode = tool_config.get("functionCallingConfig", {}).get("mode", "")
|
| 831 |
+
if mode in ("ANY", "REQUIRED"):
|
| 832 |
+
tool_instruction = "\n\n[CRITICAL INSTRUCTION] You MUST use one of the provided tools to respond. Do NOT respond with plain text."
|
| 833 |
+
|
| 834 |
+
for i, content in enumerate(contents):
|
| 835 |
+
role = content.get("role", "user")
|
| 836 |
+
parts = content.get("parts", [])
|
| 837 |
+
is_last = (i == len(contents) - 1)
|
| 838 |
+
|
| 839 |
+
# 提取文本和工具调用
|
| 840 |
+
text_parts = []
|
| 841 |
+
tool_calls = []
|
| 842 |
+
tool_responses = []
|
| 843 |
+
|
| 844 |
+
for part in parts:
|
| 845 |
+
if "text" in part:
|
| 846 |
+
text_parts.append(part["text"])
|
| 847 |
+
elif "functionCall" in part:
|
| 848 |
+
# Gemini 的工具调用
|
| 849 |
+
fc = part["functionCall"]
|
| 850 |
+
tool_calls.append({
|
| 851 |
+
"toolUseId": fc.get("name", "") + "_" + str(i), # Gemini 没有 ID,生成一个
|
| 852 |
+
"name": fc.get("name", ""),
|
| 853 |
+
"input": fc.get("args", {})
|
| 854 |
+
})
|
| 855 |
+
elif "functionResponse" in part:
|
| 856 |
+
# Gemini 的工具响应
|
| 857 |
+
fr = part["functionResponse"]
|
| 858 |
+
response_content = fr.get("response", {})
|
| 859 |
+
if isinstance(response_content, dict):
|
| 860 |
+
response_text = json.dumps(response_content)
|
| 861 |
+
else:
|
| 862 |
+
response_text = str(response_content)
|
| 863 |
+
|
| 864 |
+
tool_responses.append({
|
| 865 |
+
"content": [{"text": response_text}],
|
| 866 |
+
"status": "success",
|
| 867 |
+
"toolUseId": fr.get("name", "") + "_" + str(i - 1) # 匹配上一个调用
|
| 868 |
+
})
|
| 869 |
+
|
| 870 |
+
text = " ".join(text_parts)
|
| 871 |
+
|
| 872 |
+
if role == "user":
|
| 873 |
+
# 处理待处理的 tool responses
|
| 874 |
+
if pending_tool_results:
|
| 875 |
+
seen_ids = set()
|
| 876 |
+
unique_results = []
|
| 877 |
+
for tr in pending_tool_results:
|
| 878 |
+
if tr["toolUseId"] not in seen_ids:
|
| 879 |
+
seen_ids.add(tr["toolUseId"])
|
| 880 |
+
unique_results.append(tr)
|
| 881 |
+
|
| 882 |
+
history.append({
|
| 883 |
+
"userInputMessage": {
|
| 884 |
+
"content": "Tool results provided.",
|
| 885 |
+
"modelId": model,
|
| 886 |
+
"origin": "AI_EDITOR",
|
| 887 |
+
"userInputMessageContext": {
|
| 888 |
+
"toolResults": unique_results
|
| 889 |
+
}
|
| 890 |
+
}
|
| 891 |
+
})
|
| 892 |
+
pending_tool_results = []
|
| 893 |
+
|
| 894 |
+
# 处理 functionResponse(用户消息中的工具响应)
|
| 895 |
+
if tool_responses:
|
| 896 |
+
pending_tool_results.extend(tool_responses)
|
| 897 |
+
|
| 898 |
+
# 合并 system prompt
|
| 899 |
+
if system_text and not history:
|
| 900 |
+
text = f"{system_text}{tool_instruction}\n\n{text}"
|
| 901 |
+
|
| 902 |
+
if is_last:
|
| 903 |
+
user_content = text
|
| 904 |
+
if pending_tool_results:
|
| 905 |
+
current_tool_results = pending_tool_results
|
| 906 |
+
pending_tool_results = []
|
| 907 |
+
else:
|
| 908 |
+
if text:
|
| 909 |
+
history.append({
|
| 910 |
+
"userInputMessage": {
|
| 911 |
+
"content": text,
|
| 912 |
+
"modelId": model,
|
| 913 |
+
"origin": "AI_EDITOR"
|
| 914 |
+
}
|
| 915 |
+
})
|
| 916 |
+
|
| 917 |
+
elif role == "model":
|
| 918 |
+
# 处理待处理的 tool responses
|
| 919 |
+
if pending_tool_results:
|
| 920 |
+
seen_ids = set()
|
| 921 |
+
unique_results = []
|
| 922 |
+
for tr in pending_tool_results:
|
| 923 |
+
if tr["toolUseId"] not in seen_ids:
|
| 924 |
+
seen_ids.add(tr["toolUseId"])
|
| 925 |
+
unique_results.append(tr)
|
| 926 |
+
|
| 927 |
+
history.append({
|
| 928 |
+
"userInputMessage": {
|
| 929 |
+
"content": "Tool results provided.",
|
| 930 |
+
"modelId": model,
|
| 931 |
+
"origin": "AI_EDITOR",
|
| 932 |
+
"userInputMessageContext": {
|
| 933 |
+
"toolResults": unique_results
|
| 934 |
+
}
|
| 935 |
+
}
|
| 936 |
+
})
|
| 937 |
+
pending_tool_results = []
|
| 938 |
+
|
| 939 |
+
assistant_text = text if text else "I understand."
|
| 940 |
+
|
| 941 |
+
assistant_msg = {
|
| 942 |
+
"assistantResponseMessage": {
|
| 943 |
+
"content": assistant_text
|
| 944 |
+
}
|
| 945 |
+
}
|
| 946 |
+
# 只有在有 toolUses 时才添加这个字段
|
| 947 |
+
if tool_calls:
|
| 948 |
+
assistant_msg["assistantResponseMessage"]["toolUses"] = tool_calls
|
| 949 |
+
|
| 950 |
+
history.append(assistant_msg)
|
| 951 |
+
|
| 952 |
+
# 处理末尾的 tool results
|
| 953 |
+
if pending_tool_results:
|
| 954 |
+
current_tool_results = pending_tool_results
|
| 955 |
+
if not user_content:
|
| 956 |
+
user_content = "Tool results provided."
|
| 957 |
+
|
| 958 |
+
# 如果没有用户消息
|
| 959 |
+
if not user_content:
|
| 960 |
+
if contents:
|
| 961 |
+
last_parts = contents[-1].get("parts", [])
|
| 962 |
+
user_content = " ".join(p.get("text", "") for p in last_parts if "text" in p)
|
| 963 |
+
if not user_content:
|
| 964 |
+
user_content = "Continue"
|
| 965 |
+
|
| 966 |
+
# 修复历史交替
|
| 967 |
+
history = fix_history_alternation(history, model)
|
| 968 |
+
|
| 969 |
+
# 移除最后一条(当前用户消息)
|
| 970 |
+
if history and "userInputMessage" in history[-1]:
|
| 971 |
+
history = history[:-1]
|
| 972 |
+
|
| 973 |
+
# 转换工具
|
| 974 |
+
kiro_tools = convert_gemini_tools_to_kiro(tools) if tools else []
|
| 975 |
+
|
| 976 |
+
return user_content, history, current_tool_results, kiro_tools
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
def convert_kiro_response_to_gemini(result: dict, model: str) -> dict:
|
| 980 |
+
"""将 Kiro 响应转换为 Gemini 格式"""
|
| 981 |
+
text = "".join(result.get("content", []))
|
| 982 |
+
tool_uses = result.get("tool_uses", [])
|
| 983 |
+
|
| 984 |
+
parts = []
|
| 985 |
+
|
| 986 |
+
# 添加文本部分
|
| 987 |
+
if text:
|
| 988 |
+
parts.append({"text": text})
|
| 989 |
+
|
| 990 |
+
# 添加工具调用
|
| 991 |
+
for tool_use in tool_uses:
|
| 992 |
+
if tool_use.get("type") == "tool_use":
|
| 993 |
+
parts.append({
|
| 994 |
+
"functionCall": {
|
| 995 |
+
"name": tool_use.get("name", ""),
|
| 996 |
+
"args": tool_use.get("input", {})
|
| 997 |
+
}
|
| 998 |
+
})
|
| 999 |
+
|
| 1000 |
+
# 映射 stop_reason
|
| 1001 |
+
stop_reason = result.get("stop_reason", "STOP")
|
| 1002 |
+
finish_reason = "STOP"
|
| 1003 |
+
if tool_uses:
|
| 1004 |
+
finish_reason = "TOOL_CALLS"
|
| 1005 |
+
elif stop_reason == "max_tokens":
|
| 1006 |
+
finish_reason = "MAX_TOKENS"
|
| 1007 |
+
|
| 1008 |
+
return {
|
| 1009 |
+
"candidates": [{
|
| 1010 |
+
"content": {
|
| 1011 |
+
"parts": parts,
|
| 1012 |
+
"role": "model"
|
| 1013 |
+
},
|
| 1014 |
+
"finishReason": finish_reason,
|
| 1015 |
+
"index": 0
|
| 1016 |
+
}],
|
| 1017 |
+
"usageMetadata": {
|
| 1018 |
+
"promptTokenCount": 100,
|
| 1019 |
+
"candidatesTokenCount": 100,
|
| 1020 |
+
"totalTokenCount": 200
|
| 1021 |
+
}
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
# ==================== 思考功能支持 ====================
|
| 1026 |
+
|
| 1027 |
+
def generate_thinking_prefix(thinking_type: str = "enabled", budget_tokens: int = 20000) -> str:
|
| 1028 |
+
"""生成思考模式的前缀 XML 标签
|
| 1029 |
+
|
| 1030 |
+
Args:
|
| 1031 |
+
thinking_type: 思考类型,通常为 "enabled"
|
| 1032 |
+
budget_tokens: 思考的 token 预算
|
| 1033 |
+
|
| 1034 |
+
Returns:
|
| 1035 |
+
XML 格式的思考标签字符串
|
| 1036 |
+
"""
|
| 1037 |
+
if thinking_type != "enabled":
|
| 1038 |
+
return ""
|
| 1039 |
+
|
| 1040 |
+
return f"<thinking_mode>enabled</thinking_mode>\n<max_thinking_length>{budget_tokens}</max_thinking_length>"
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def has_thinking_tags(text: str) -> bool:
|
| 1044 |
+
"""检查文本是否已包含思考标签
|
| 1045 |
+
|
| 1046 |
+
Args:
|
| 1047 |
+
text: 要检查的文本
|
| 1048 |
+
|
| 1049 |
+
Returns:
|
| 1050 |
+
如果包含思考标签返回 True
|
| 1051 |
+
"""
|
| 1052 |
+
return "<thinking_mode>" in text and "</thinking_mode>" in text
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
def inject_thinking_tags_to_system(system, thinking_type: str = "enabled", budget_tokens: int = 20000):
|
| 1056 |
+
"""将思考标签注入到系统消息中
|
| 1057 |
+
|
| 1058 |
+
Args:
|
| 1059 |
+
system: 原始系统消息 (可以是字符串或列表)
|
| 1060 |
+
thinking_type: 思考类型
|
| 1061 |
+
budget_tokens: 思考的 token 预算
|
| 1062 |
+
|
| 1063 |
+
Returns:
|
| 1064 |
+
注入思考标签后的系统消息 (保持原始类型)
|
| 1065 |
+
"""
|
| 1066 |
+
# 生成思考前缀
|
| 1067 |
+
thinking_prefix = generate_thinking_prefix(thinking_type, budget_tokens)
|
| 1068 |
+
|
| 1069 |
+
if not thinking_prefix:
|
| 1070 |
+
return system
|
| 1071 |
+
|
| 1072 |
+
# 处理 system 为列表的情况 (Anthropic API 支持 system 为 content blocks 列表)
|
| 1073 |
+
if isinstance(system, list):
|
| 1074 |
+
# 将列表转换为字符串
|
| 1075 |
+
system_text = ""
|
| 1076 |
+
for block in system:
|
| 1077 |
+
if isinstance(block, dict) and block.get("type") == "text":
|
| 1078 |
+
system_text += block.get("text", "") + "\n"
|
| 1079 |
+
elif isinstance(block, str):
|
| 1080 |
+
system_text += block + "\n"
|
| 1081 |
+
system_text = system_text.strip()
|
| 1082 |
+
|
| 1083 |
+
if not system_text:
|
| 1084 |
+
return thinking_prefix
|
| 1085 |
+
|
| 1086 |
+
if has_thinking_tags(system_text):
|
| 1087 |
+
return system
|
| 1088 |
+
|
| 1089 |
+
# 返回字符串形式
|
| 1090 |
+
return f"{thinking_prefix}\n\n{system_text}"
|
| 1091 |
+
|
| 1092 |
+
# 处理 system 为字符串的情况
|
| 1093 |
+
if not system or not str(system).strip():
|
| 1094 |
+
return thinking_prefix
|
| 1095 |
+
|
| 1096 |
+
# 如果已经包含思考标签,不再重复注入
|
| 1097 |
+
if has_thinking_tags(str(system)):
|
| 1098 |
+
return system
|
| 1099 |
+
|
| 1100 |
+
# 将思考标签插入到系统消息开头
|
| 1101 |
+
return f"{thinking_prefix}\n\n{system}"
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
def find_real_thinking_start_tag(text: str, pos: int = 0) -> int:
|
| 1105 |
+
"""查找真正的 <thinking> 标签位置,忽略被引号包围的情况
|
| 1106 |
+
|
| 1107 |
+
Args:
|
| 1108 |
+
text: 要搜索的文本
|
| 1109 |
+
pos: 开始搜索的位置
|
| 1110 |
+
|
| 1111 |
+
Returns:
|
| 1112 |
+
找到的标签位置,如果没找到返回 -1
|
| 1113 |
+
"""
|
| 1114 |
+
while True:
|
| 1115 |
+
idx = text.find("<thinking>", pos)
|
| 1116 |
+
if idx == -1:
|
| 1117 |
+
return -1
|
| 1118 |
+
|
| 1119 |
+
# 检查是否被引号包围
|
| 1120 |
+
# 向前查找最近的引号
|
| 1121 |
+
prev_quote = max(
|
| 1122 |
+
text.rfind("`", 0, idx),
|
| 1123 |
+
text.rfind("'", 0, idx),
|
| 1124 |
+
text.rfind('"', 0, idx)
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
# 如果有引号且引号后没有换行,说明是被包围的
|
| 1128 |
+
if prev_quote != -1:
|
| 1129 |
+
# 检查引号到标签之间是否有换行
|
| 1130 |
+
between = text[prev_quote + 1:idx]
|
| 1131 |
+
if "\n" not in between:
|
| 1132 |
+
pos = idx + len("<thinking>")
|
| 1133 |
+
continue
|
| 1134 |
+
|
| 1135 |
+
return idx
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
def find_real_thinking_end_tag(text: str, pos: int = 0) -> int:
|
| 1139 |
+
"""查找真正的 </thinking> 标签位置,忽略被引号包围的情况
|
| 1140 |
+
|
| 1141 |
+
Args:
|
| 1142 |
+
text: 要搜索的文本
|
| 1143 |
+
pos: 开始搜索的位置
|
| 1144 |
+
|
| 1145 |
+
Returns:
|
| 1146 |
+
找到的标签位置,如果没找到返回 -1
|
| 1147 |
+
"""
|
| 1148 |
+
while True:
|
| 1149 |
+
idx = text.find("</thinking>", pos)
|
| 1150 |
+
if idx == -1:
|
| 1151 |
+
return -1
|
| 1152 |
+
|
| 1153 |
+
# 检查是否被引号包围
|
| 1154 |
+
# 向前查找最近的引号
|
| 1155 |
+
prev_quote = max(
|
| 1156 |
+
text.rfind("`", 0, idx),
|
| 1157 |
+
text.rfind("'", 0, idx),
|
| 1158 |
+
text.rfind('"', 0, idx)
|
| 1159 |
+
)
|
| 1160 |
+
|
| 1161 |
+
# 如果有引号且引号后没有换行,说明是被包围的
|
| 1162 |
+
if prev_quote != -1:
|
| 1163 |
+
# 检查引号到标签之间是否有换行
|
| 1164 |
+
between = text[prev_quote + 1:idx]
|
| 1165 |
+
if "\n" not in between:
|
| 1166 |
+
pos = idx + len("</thinking>")
|
| 1167 |
+
continue
|
| 1168 |
+
|
| 1169 |
+
return idx
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
def extract_thinking_from_content(content: str) -> Tuple[str, str]:
|
| 1173 |
+
"""从内容中提取思考部分和正文部分
|
| 1174 |
+
|
| 1175 |
+
Args:
|
| 1176 |
+
content: 原始内容
|
| 1177 |
+
|
| 1178 |
+
Returns:
|
| 1179 |
+
(thinking_content, text_content)
|
| 1180 |
+
"""
|
| 1181 |
+
thinking_start = find_real_thinking_start_tag(content)
|
| 1182 |
+
thinking_end = find_real_thinking_end_tag(content)
|
| 1183 |
+
|
| 1184 |
+
if thinking_start == -1 or thinking_end == -1:
|
| 1185 |
+
return "", content
|
| 1186 |
+
|
| 1187 |
+
# 提取思考内容(去掉标签)
|
| 1188 |
+
thinking_content = content[thinking_start + len("<thinking>"):thinking_end].strip()
|
| 1189 |
+
|
| 1190 |
+
# 提取正文内容(去掉思考部分)
|
| 1191 |
+
text_content = content[:thinking_start].strip()
|
| 1192 |
+
after_thinking = content[thinking_end + len("</thinking>"):].strip()
|
| 1193 |
+
if after_thinking:
|
| 1194 |
+
text_content += "\n" + after_thinking
|
| 1195 |
+
|
| 1196 |
+
return thinking_content, text_content
|
KiroProxy/kiro_proxy/core/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""核心模块"""
|
| 2 |
+
from .state import state, ProxyState, RequestLog
|
| 3 |
+
from .account import Account
|
| 4 |
+
from .persistence import load_config, save_config, CONFIG_FILE
|
| 5 |
+
from .retry import RetryableRequest, is_retryable_error, RETRYABLE_STATUS_CODES
|
| 6 |
+
from .scheduler import scheduler
|
| 7 |
+
from .stats import stats_manager
|
| 8 |
+
from .browser import detect_browsers, open_url, get_browsers_info
|
| 9 |
+
from .flow_monitor import flow_monitor, FlowMonitor, LLMFlow, FlowState, TokenUsage
|
| 10 |
+
from .usage import get_usage_limits, get_account_usage, UsageInfo
|
| 11 |
+
from .history_manager import (
|
| 12 |
+
HistoryManager, HistoryConfig, TruncateStrategy,
|
| 13 |
+
get_history_config, set_history_config, update_history_config,
|
| 14 |
+
is_content_length_error
|
| 15 |
+
)
|
| 16 |
+
from .error_handler import (
|
| 17 |
+
ErrorType, KiroError, classify_error, is_account_suspended,
|
| 18 |
+
get_anthropic_error_response, format_error_log
|
| 19 |
+
)
|
| 20 |
+
from .rate_limiter import RateLimiter, RateLimitConfig, rate_limiter, get_rate_limiter
|
| 21 |
+
|
| 22 |
+
# 新增模块
|
| 23 |
+
from .quota_cache import QuotaCache, CachedQuota, get_quota_cache
|
| 24 |
+
from .account_selector import AccountSelector, SelectionStrategy, get_account_selector
|
| 25 |
+
from .quota_scheduler import QuotaScheduler, get_quota_scheduler
|
| 26 |
+
from .refresh_manager import (
|
| 27 |
+
RefreshManager, RefreshProgress, RefreshConfig,
|
| 28 |
+
get_refresh_manager, reset_refresh_manager
|
| 29 |
+
)
|
| 30 |
+
from .kiro_api import kiro_api_request, get_user_info, get_user_email
|
| 31 |
+
|
| 32 |
+
__all__ = [
|
| 33 |
+
"state", "ProxyState", "RequestLog", "Account",
|
| 34 |
+
"load_config", "save_config", "CONFIG_FILE",
|
| 35 |
+
"RetryableRequest", "is_retryable_error", "RETRYABLE_STATUS_CODES",
|
| 36 |
+
"scheduler", "stats_manager",
|
| 37 |
+
"detect_browsers", "open_url", "get_browsers_info",
|
| 38 |
+
"flow_monitor", "FlowMonitor", "LLMFlow", "FlowState", "TokenUsage",
|
| 39 |
+
"get_usage_limits", "get_account_usage", "UsageInfo",
|
| 40 |
+
"HistoryManager", "HistoryConfig", "TruncateStrategy",
|
| 41 |
+
"get_history_config", "set_history_config", "update_history_config",
|
| 42 |
+
"is_content_length_error",
|
| 43 |
+
"ErrorType", "KiroError", "classify_error", "is_account_suspended",
|
| 44 |
+
"get_anthropic_error_response", "format_error_log",
|
| 45 |
+
"RateLimiter", "RateLimitConfig", "rate_limiter", "get_rate_limiter",
|
| 46 |
+
# 新增导出
|
| 47 |
+
"QuotaCache", "CachedQuota", "get_quota_cache",
|
| 48 |
+
"AccountSelector", "SelectionStrategy", "get_account_selector",
|
| 49 |
+
"QuotaScheduler", "get_quota_scheduler",
|
| 50 |
+
# RefreshManager 导出
|
| 51 |
+
"RefreshManager", "RefreshProgress", "RefreshConfig",
|
| 52 |
+
"get_refresh_manager", "reset_refresh_manager",
|
| 53 |
+
# Kiro API 导出
|
| 54 |
+
"kiro_api_request", "get_user_info", "get_user_email",
|
| 55 |
+
]
|
KiroProxy/kiro_proxy/core/account.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""账号管理"""
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from ..credential import (
|
| 9 |
+
KiroCredentials, TokenRefresher, CredentialStatus,
|
| 10 |
+
generate_machine_id, quota_manager
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class Account:
|
| 16 |
+
"""账号信息"""
|
| 17 |
+
id: str
|
| 18 |
+
name: str
|
| 19 |
+
token_path: str
|
| 20 |
+
enabled: bool = True
|
| 21 |
+
# 是否因额度耗尽被自动禁用(用于区分手动禁用,避免被自动启用)
|
| 22 |
+
auto_disabled: bool = False
|
| 23 |
+
request_count: int = 0
|
| 24 |
+
error_count: int = 0
|
| 25 |
+
last_used: Optional[float] = None
|
| 26 |
+
status: CredentialStatus = CredentialStatus.ACTIVE
|
| 27 |
+
|
| 28 |
+
_credentials: Optional[KiroCredentials] = field(default=None, repr=False)
|
| 29 |
+
_machine_id: Optional[str] = field(default=None, repr=False)
|
| 30 |
+
|
| 31 |
+
def is_available(self) -> bool:
|
| 32 |
+
"""检查账号是否可用"""
|
| 33 |
+
if not self.enabled:
|
| 34 |
+
return False
|
| 35 |
+
if self.status in (CredentialStatus.DISABLED, CredentialStatus.UNHEALTHY, CredentialStatus.SUSPENDED):
|
| 36 |
+
return False
|
| 37 |
+
if not quota_manager.is_available(self.id):
|
| 38 |
+
return False
|
| 39 |
+
|
| 40 |
+
# 检查额度是否耗尽
|
| 41 |
+
from .quota_cache import get_quota_cache
|
| 42 |
+
quota_cache = get_quota_cache()
|
| 43 |
+
quota = quota_cache.get(self.id)
|
| 44 |
+
if quota and quota.is_exhausted:
|
| 45 |
+
return False
|
| 46 |
+
|
| 47 |
+
return True
|
| 48 |
+
|
| 49 |
+
def is_active(self) -> bool:
|
| 50 |
+
"""检查账号是否活跃(最近60秒内使用过)"""
|
| 51 |
+
from .quota_scheduler import get_quota_scheduler
|
| 52 |
+
scheduler = get_quota_scheduler()
|
| 53 |
+
return scheduler.is_active(self.id)
|
| 54 |
+
|
| 55 |
+
def get_priority_order(self) -> Optional[int]:
|
| 56 |
+
"""获取优先级顺序(从1开始),非优先账号返回 None"""
|
| 57 |
+
from .account_selector import get_account_selector
|
| 58 |
+
selector = get_account_selector()
|
| 59 |
+
return selector.get_priority_order(self.id)
|
| 60 |
+
|
| 61 |
+
def is_priority(self) -> bool:
|
| 62 |
+
"""检查是否为优先账号"""
|
| 63 |
+
return self.get_priority_order() is not None
|
| 64 |
+
|
| 65 |
+
def load_credentials(self) -> Optional[KiroCredentials]:
|
| 66 |
+
"""加载凭证信息"""
|
| 67 |
+
try:
|
| 68 |
+
self._credentials = KiroCredentials.from_file(self.token_path)
|
| 69 |
+
|
| 70 |
+
if self._credentials.client_id_hash and not self._credentials.client_id:
|
| 71 |
+
self._merge_client_credentials()
|
| 72 |
+
|
| 73 |
+
return self._credentials
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"[Account] 加载凭证失败 {self.id}: {e}")
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
def _merge_client_credentials(self):
|
| 79 |
+
"""合并 clientIdHash 对应的凭证文件"""
|
| 80 |
+
if not self._credentials or not self._credentials.client_id_hash:
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
cache_dir = Path(self.token_path).parent
|
| 84 |
+
hash_file = cache_dir / f"{self._credentials.client_id_hash}.json"
|
| 85 |
+
|
| 86 |
+
if hash_file.exists():
|
| 87 |
+
try:
|
| 88 |
+
with open(hash_file) as f:
|
| 89 |
+
data = json.load(f)
|
| 90 |
+
if not self._credentials.client_id:
|
| 91 |
+
self._credentials.client_id = data.get("clientId")
|
| 92 |
+
if not self._credentials.client_secret:
|
| 93 |
+
self._credentials.client_secret = data.get("clientSecret")
|
| 94 |
+
except Exception:
|
| 95 |
+
pass
|
| 96 |
+
|
| 97 |
+
def get_credentials(self) -> Optional[KiroCredentials]:
|
| 98 |
+
"""获取凭证(带缓存)"""
|
| 99 |
+
if self._credentials is None:
|
| 100 |
+
self.load_credentials()
|
| 101 |
+
return self._credentials
|
| 102 |
+
|
| 103 |
+
def get_token(self) -> str:
|
| 104 |
+
"""获取 access_token"""
|
| 105 |
+
creds = self.get_credentials()
|
| 106 |
+
if creds and creds.access_token:
|
| 107 |
+
return creds.access_token
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
with open(self.token_path) as f:
|
| 111 |
+
return json.load(f).get("accessToken", "")
|
| 112 |
+
except Exception:
|
| 113 |
+
return ""
|
| 114 |
+
|
| 115 |
+
def get_machine_id(self) -> str:
|
| 116 |
+
"""获取基于此账号的 Machine ID"""
|
| 117 |
+
if self._machine_id:
|
| 118 |
+
return self._machine_id
|
| 119 |
+
|
| 120 |
+
creds = self.get_credentials()
|
| 121 |
+
if creds:
|
| 122 |
+
self._machine_id = generate_machine_id(creds.profile_arn, creds.client_id)
|
| 123 |
+
else:
|
| 124 |
+
self._machine_id = generate_machine_id()
|
| 125 |
+
|
| 126 |
+
return self._machine_id
|
| 127 |
+
|
| 128 |
+
def is_token_expired(self) -> bool:
|
| 129 |
+
"""检查 token 是否过期"""
|
| 130 |
+
creds = self.get_credentials()
|
| 131 |
+
return creds.is_expired() if creds else True
|
| 132 |
+
|
| 133 |
+
def is_token_expiring_soon(self, minutes: int = 10) -> bool:
|
| 134 |
+
"""检查 token 是否即将过期"""
|
| 135 |
+
creds = self.get_credentials()
|
| 136 |
+
return creds.is_expiring_soon(minutes) if creds else False
|
| 137 |
+
|
| 138 |
+
async def refresh_token(self) -> tuple:
|
| 139 |
+
"""刷新 token"""
|
| 140 |
+
creds = self.get_credentials()
|
| 141 |
+
if not creds:
|
| 142 |
+
return False, "无法加载凭证"
|
| 143 |
+
|
| 144 |
+
refresher = TokenRefresher(creds)
|
| 145 |
+
success, result = await refresher.refresh()
|
| 146 |
+
|
| 147 |
+
if success:
|
| 148 |
+
creds.save_to_file(self.token_path)
|
| 149 |
+
self._credentials = creds
|
| 150 |
+
self.status = CredentialStatus.ACTIVE
|
| 151 |
+
return True, "Token 刷新成功"
|
| 152 |
+
else:
|
| 153 |
+
self.status = CredentialStatus.UNHEALTHY
|
| 154 |
+
return False, result
|
| 155 |
+
|
| 156 |
+
def mark_quota_exceeded(self, reason: str = "Rate limited"):
|
| 157 |
+
"""标记配额超限(进入冷却并避免被继续选中)
|
| 158 |
+
|
| 159 |
+
429 错误自动冷却 5 分钟,无需手动配置
|
| 160 |
+
"""
|
| 161 |
+
quota_manager.mark_exceeded(self.id, reason)
|
| 162 |
+
self.status = CredentialStatus.COOLDOWN
|
| 163 |
+
self.error_count += 1
|
| 164 |
+
|
| 165 |
+
def get_status_info(self) -> dict:
|
| 166 |
+
"""获取状态信息"""
|
| 167 |
+
cooldown_remaining = quota_manager.get_cooldown_remaining(self.id)
|
| 168 |
+
creds = self.get_credentials()
|
| 169 |
+
|
| 170 |
+
# 获取额度信息
|
| 171 |
+
from .quota_cache import get_quota_cache
|
| 172 |
+
quota_cache = get_quota_cache()
|
| 173 |
+
quota = quota_cache.get(self.id)
|
| 174 |
+
|
| 175 |
+
quota_info = None
|
| 176 |
+
if quota:
|
| 177 |
+
# 计算相对时间
|
| 178 |
+
updated_ago = ""
|
| 179 |
+
if quota.updated_at > 0:
|
| 180 |
+
seconds_ago = time.time() - quota.updated_at
|
| 181 |
+
if seconds_ago < 60:
|
| 182 |
+
updated_ago = f"{int(seconds_ago)}秒前"
|
| 183 |
+
elif seconds_ago < 3600:
|
| 184 |
+
updated_ago = f"{int(seconds_ago / 60)}分钟前"
|
| 185 |
+
else:
|
| 186 |
+
updated_ago = f"{int(seconds_ago / 3600)}小时前"
|
| 187 |
+
|
| 188 |
+
# 格式化重置时间
|
| 189 |
+
reset_date_text = None
|
| 190 |
+
if quota.next_reset_date:
|
| 191 |
+
try:
|
| 192 |
+
# 处理时间戳格式
|
| 193 |
+
if isinstance(quota.next_reset_date, (int, float)):
|
| 194 |
+
from datetime import datetime
|
| 195 |
+
reset_dt = datetime.fromtimestamp(quota.next_reset_date)
|
| 196 |
+
reset_date_text = reset_dt.strftime('%Y-%m-%d')
|
| 197 |
+
else:
|
| 198 |
+
# 处理 ISO 格式
|
| 199 |
+
from datetime import datetime
|
| 200 |
+
reset_dt = datetime.fromisoformat(quota.next_reset_date.replace('Z', '+00:00'))
|
| 201 |
+
reset_date_text = reset_dt.strftime('%Y-%m-%d')
|
| 202 |
+
except:
|
| 203 |
+
reset_date_text = str(quota.next_reset_date)
|
| 204 |
+
|
| 205 |
+
# 格式化免费试用过期时间
|
| 206 |
+
trial_expiry_text = None
|
| 207 |
+
if quota.free_trial_expiry:
|
| 208 |
+
try:
|
| 209 |
+
# 处理时间戳格式
|
| 210 |
+
if isinstance(quota.free_trial_expiry, (int, float)):
|
| 211 |
+
from datetime import datetime
|
| 212 |
+
expiry_dt = datetime.fromtimestamp(quota.free_trial_expiry)
|
| 213 |
+
trial_expiry_text = expiry_dt.strftime('%Y-%m-%d')
|
| 214 |
+
else:
|
| 215 |
+
# 处理 ISO 格式
|
| 216 |
+
from datetime import datetime
|
| 217 |
+
expiry_dt = datetime.fromisoformat(quota.free_trial_expiry.replace('Z', '+00:00'))
|
| 218 |
+
trial_expiry_text = expiry_dt.strftime('%Y-%m-%d')
|
| 219 |
+
except:
|
| 220 |
+
trial_expiry_text = str(quota.free_trial_expiry)
|
| 221 |
+
|
| 222 |
+
# 计算生效奖励数
|
| 223 |
+
active_bonuses = len([e for e in (quota.bonus_expiries or []) if e])
|
| 224 |
+
|
| 225 |
+
quota_info = {
|
| 226 |
+
"balance": quota.balance,
|
| 227 |
+
"usage_limit": quota.usage_limit,
|
| 228 |
+
"current_usage": quota.current_usage,
|
| 229 |
+
"usage_percent": quota.usage_percent,
|
| 230 |
+
"is_low_balance": quota.is_low_balance,
|
| 231 |
+
"is_exhausted": quota.is_exhausted, # 额度是否耗尽
|
| 232 |
+
"is_suspended": getattr(quota, 'is_suspended', False), # 账号是否被封禁
|
| 233 |
+
"balance_status": quota.balance_status, # 额度状态: normal, low, exhausted
|
| 234 |
+
"subscription_title": quota.subscription_title,
|
| 235 |
+
"free_trial_limit": quota.free_trial_limit,
|
| 236 |
+
"free_trial_usage": quota.free_trial_usage,
|
| 237 |
+
"bonus_limit": quota.bonus_limit,
|
| 238 |
+
"bonus_usage": quota.bonus_usage,
|
| 239 |
+
"updated_at": updated_ago,
|
| 240 |
+
"updated_timestamp": quota.updated_at,
|
| 241 |
+
"error": quota.error,
|
| 242 |
+
# 新增重置时间字段
|
| 243 |
+
"next_reset_date": quota.next_reset_date,
|
| 244 |
+
"reset_date_text": reset_date_text, # 格式化后的重置日期
|
| 245 |
+
"free_trial_expiry": quota.free_trial_expiry,
|
| 246 |
+
"trial_expiry_text": trial_expiry_text, # 格式化后的试用过期日期
|
| 247 |
+
"bonus_expiries": quota.bonus_expiries or [],
|
| 248 |
+
"active_bonuses": active_bonuses, # 生效奖励数量
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
# 计算最后使用时间
|
| 252 |
+
last_used_ago = None
|
| 253 |
+
if self.last_used:
|
| 254 |
+
seconds_ago = time.time() - self.last_used
|
| 255 |
+
if seconds_ago < 60:
|
| 256 |
+
last_used_ago = f"{int(seconds_ago)}秒前"
|
| 257 |
+
elif seconds_ago < 3600:
|
| 258 |
+
last_used_ago = f"{int(seconds_ago / 60)}分钟前"
|
| 259 |
+
else:
|
| 260 |
+
last_used_ago = f"{int(seconds_ago / 3600)}小时前"
|
| 261 |
+
|
| 262 |
+
return {
|
| 263 |
+
"id": self.id,
|
| 264 |
+
"name": self.name,
|
| 265 |
+
"enabled": self.enabled,
|
| 266 |
+
"status": self.status.value,
|
| 267 |
+
"available": self.is_available(),
|
| 268 |
+
"request_count": self.request_count,
|
| 269 |
+
"error_count": self.error_count,
|
| 270 |
+
"error_rate": f"{(self.error_count / max(1, self.request_count) * 100):.1f}%",
|
| 271 |
+
"cooldown_remaining": cooldown_remaining,
|
| 272 |
+
"token_expired": self.is_token_expired() if creds else None,
|
| 273 |
+
"token_expiring_soon": self.is_token_expiring_soon() if creds else None,
|
| 274 |
+
"token_expires_at": creds.expires_at if creds else None, # Token 过期时间戳
|
| 275 |
+
"auth_method": creds.auth_method if creds else None,
|
| 276 |
+
"has_refresh_token": bool(creds and creds.refresh_token),
|
| 277 |
+
"idc_config_complete": bool(creds and creds.client_id and creds.client_secret) if creds and creds.auth_method == "idc" else None,
|
| 278 |
+
# 新增字段
|
| 279 |
+
"quota": quota_info,
|
| 280 |
+
"is_priority": self.is_priority(),
|
| 281 |
+
"priority_order": self.get_priority_order(),
|
| 282 |
+
"is_active": self.is_active(),
|
| 283 |
+
"last_used": self.last_used,
|
| 284 |
+
"last_used_ago": last_used_ago,
|
| 285 |
+
# Provider 字段 (Google/Github)
|
| 286 |
+
"provider": creds.provider if creds else None,
|
| 287 |
+
}
|
KiroProxy/kiro_proxy/core/account_selector.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""账号选择器模块
|
| 2 |
+
|
| 3 |
+
实现基于剩余额度的智能账号选择策略,支持优先账号配置。
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, List, Set, TYPE_CHECKING
|
| 11 |
+
from threading import Lock
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from .account import Account
|
| 15 |
+
from .quota_cache import QuotaCache
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SelectionStrategy(Enum):
|
| 19 |
+
"""选择策略"""
|
| 20 |
+
LOWEST_BALANCE = "lowest_balance" # 剩余额度最少优先
|
| 21 |
+
ROUND_ROBIN = "round_robin" # 轮询
|
| 22 |
+
LEAST_REQUESTS = "least_requests" # 请求最少优先
|
| 23 |
+
RANDOM = "random" # 随机选择(分散压力)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AccountSelector:
|
| 27 |
+
"""账号选择器
|
| 28 |
+
|
| 29 |
+
根据配置的策略选择最合适的账号,支持优先账号配置。
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, quota_cache: 'QuotaCache', priority_file: Optional[str] = None):
|
| 33 |
+
"""
|
| 34 |
+
初始化账号选择器
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
quota_cache: 额度缓存实例
|
| 38 |
+
priority_file: 优先账号配置文件路径
|
| 39 |
+
"""
|
| 40 |
+
self.quota_cache = quota_cache
|
| 41 |
+
self._priority_accounts: List[str] = []
|
| 42 |
+
# 默认使用随机策略,避免单账号 RPM 过高导致封禁风险
|
| 43 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 44 |
+
self._lock = Lock()
|
| 45 |
+
self._round_robin_index = 0
|
| 46 |
+
self._last_random_account_id: Optional[str] = None
|
| 47 |
+
|
| 48 |
+
# 设置优先账号配置文件路径
|
| 49 |
+
if priority_file:
|
| 50 |
+
self._priority_file = Path(priority_file)
|
| 51 |
+
else:
|
| 52 |
+
from ..config import DATA_DIR
|
| 53 |
+
self._priority_file = DATA_DIR / "priority.json"
|
| 54 |
+
|
| 55 |
+
# 加载优先账号配置
|
| 56 |
+
self._load_priority_config()
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def strategy(self) -> SelectionStrategy:
|
| 60 |
+
"""获取当前选择策略"""
|
| 61 |
+
return self._strategy
|
| 62 |
+
|
| 63 |
+
@strategy.setter
|
| 64 |
+
def strategy(self, value: SelectionStrategy):
|
| 65 |
+
"""设置选择策略"""
|
| 66 |
+
self._strategy = value
|
| 67 |
+
self._save_priority_config()
|
| 68 |
+
|
| 69 |
+
def select(self,
|
| 70 |
+
available_accounts: List['Account'],
|
| 71 |
+
session_id: Optional[str] = None) -> Optional['Account']:
|
| 72 |
+
"""选择最合适的账号
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
available_accounts: 可用账号列表
|
| 76 |
+
session_id: 会话ID(用于会话粘性,暂未实现)
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
选中的账号,如果没有可用账号则返回 None
|
| 80 |
+
"""
|
| 81 |
+
if not available_accounts:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
with self._lock:
|
| 85 |
+
# 1. 首先检查优先账号
|
| 86 |
+
if self._priority_accounts:
|
| 87 |
+
for priority_id in self._priority_accounts:
|
| 88 |
+
for account in available_accounts:
|
| 89 |
+
if account.id == priority_id and account.is_available():
|
| 90 |
+
return account
|
| 91 |
+
|
| 92 |
+
# 2. 根据策略选择
|
| 93 |
+
if self._strategy == SelectionStrategy.LOWEST_BALANCE:
|
| 94 |
+
return self._select_lowest_balance(available_accounts)
|
| 95 |
+
elif self._strategy == SelectionStrategy.ROUND_ROBIN:
|
| 96 |
+
return self._select_round_robin(available_accounts)
|
| 97 |
+
elif self._strategy == SelectionStrategy.LEAST_REQUESTS:
|
| 98 |
+
return self._select_least_requests(available_accounts)
|
| 99 |
+
elif self._strategy == SelectionStrategy.RANDOM:
|
| 100 |
+
return self._select_random(available_accounts)
|
| 101 |
+
|
| 102 |
+
# 默认返回第一个可用账号
|
| 103 |
+
return available_accounts[0] if available_accounts else None
|
| 104 |
+
|
| 105 |
+
def _select_lowest_balance(self, accounts: List['Account']) -> Optional['Account']:
|
| 106 |
+
"""选择剩余额度最少的账号"""
|
| 107 |
+
available = [a for a in accounts if a.is_available()]
|
| 108 |
+
if not available:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
def get_balance_and_requests(account: 'Account') -> tuple:
|
| 112 |
+
"""获取账号的余额和请求数,用于排序"""
|
| 113 |
+
quota = self.quota_cache.get(account.id)
|
| 114 |
+
balance = quota.balance if quota and not quota.has_error() else float('inf')
|
| 115 |
+
return (balance, account.request_count)
|
| 116 |
+
|
| 117 |
+
# 按余额升序,余额相同时按请求数升序
|
| 118 |
+
return min(available, key=get_balance_and_requests)
|
| 119 |
+
|
| 120 |
+
def _select_round_robin(self, accounts: List['Account']) -> Optional['Account']:
|
| 121 |
+
"""轮询选择账号"""
|
| 122 |
+
available = [a for a in accounts if a.is_available()]
|
| 123 |
+
if not available:
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
self._round_robin_index = self._round_robin_index % len(available)
|
| 127 |
+
account = available[self._round_robin_index]
|
| 128 |
+
self._round_robin_index += 1
|
| 129 |
+
return account
|
| 130 |
+
|
| 131 |
+
def _select_least_requests(self, accounts: List['Account']) -> Optional['Account']:
|
| 132 |
+
"""选择请求数最少的账号"""
|
| 133 |
+
available = [a for a in accounts if a.is_available()]
|
| 134 |
+
if not available:
|
| 135 |
+
return None
|
| 136 |
+
return min(available, key=lambda a: a.request_count)
|
| 137 |
+
|
| 138 |
+
def _select_random(self, accounts: List['Account']) -> Optional['Account']:
|
| 139 |
+
"""随机选择账号(分散请求压力)"""
|
| 140 |
+
available = [a for a in accounts if a.is_available()]
|
| 141 |
+
if not available:
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
# 尽量避免连续两次命中同一账号(在有多个可用账号时)
|
| 145 |
+
if self._last_random_account_id and len(available) > 1:
|
| 146 |
+
candidates = [a for a in available if a.id != self._last_random_account_id]
|
| 147 |
+
if candidates:
|
| 148 |
+
selected = random.choice(candidates)
|
| 149 |
+
else:
|
| 150 |
+
selected = random.choice(available)
|
| 151 |
+
else:
|
| 152 |
+
selected = random.choice(available)
|
| 153 |
+
|
| 154 |
+
self._last_random_account_id = selected.id
|
| 155 |
+
return selected
|
| 156 |
+
|
| 157 |
+
def set_priority_accounts(self, account_ids: List[str],
|
| 158 |
+
valid_account_ids: Optional[Set[str]] = None) -> tuple:
|
| 159 |
+
"""设置优先账号列表(按顺序)
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
account_ids: 优先账号ID列表(按顺序)
|
| 163 |
+
valid_account_ids: 有效账号ID集合(用于验证)
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
(success, message)
|
| 167 |
+
"""
|
| 168 |
+
with self._lock:
|
| 169 |
+
if not account_ids:
|
| 170 |
+
self._priority_accounts = []
|
| 171 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 172 |
+
self._save_priority_config()
|
| 173 |
+
return True, "已清除优先账号"
|
| 174 |
+
|
| 175 |
+
# 去重(保持顺序)
|
| 176 |
+
unique_ids: List[str] = []
|
| 177 |
+
seen: Set[str] = set()
|
| 178 |
+
for aid in account_ids:
|
| 179 |
+
if aid in seen:
|
| 180 |
+
continue
|
| 181 |
+
seen.add(aid)
|
| 182 |
+
unique_ids.append(aid)
|
| 183 |
+
|
| 184 |
+
# 验证账号是否存在
|
| 185 |
+
if valid_account_ids:
|
| 186 |
+
for aid in unique_ids:
|
| 187 |
+
if aid not in valid_account_ids:
|
| 188 |
+
return False, f"账号不存在: {aid}"
|
| 189 |
+
|
| 190 |
+
self._priority_accounts = unique_ids
|
| 191 |
+
self._save_priority_config()
|
| 192 |
+
if len(unique_ids) == 1:
|
| 193 |
+
return True, f"已设置优先账号: {unique_ids[0]}"
|
| 194 |
+
return True, f"已设置优先账号: {', '.join(unique_ids)}"
|
| 195 |
+
|
| 196 |
+
def set_priority_account(self, account_id: Optional[str],
|
| 197 |
+
valid_account_ids: Optional[Set[str]] = None) -> tuple:
|
| 198 |
+
"""设置优先账号(单个)
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
account_id: 账号ID,None 表示清除
|
| 202 |
+
valid_account_ids: 有效账号ID集合(用于验证)
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
(success, message)
|
| 206 |
+
"""
|
| 207 |
+
if account_id is None:
|
| 208 |
+
return self.set_priority_accounts([], valid_account_ids)
|
| 209 |
+
return self.set_priority_accounts([account_id], valid_account_ids)
|
| 210 |
+
|
| 211 |
+
def add_priority_account(self, account_id: str,
|
| 212 |
+
position: int = -1,
|
| 213 |
+
valid_account_ids: Optional[Set[str]] = None) -> tuple:
|
| 214 |
+
"""添加优先账号(可指定插入位置)
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
account_id: 账号ID
|
| 218 |
+
position: 插入位置(0-based),-1 表示追加到末尾
|
| 219 |
+
valid_account_ids: 有效账号ID集合(用于验证)
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
(success, message)
|
| 223 |
+
"""
|
| 224 |
+
with self._lock:
|
| 225 |
+
if valid_account_ids and account_id not in valid_account_ids:
|
| 226 |
+
return False, f"账号不存在: {account_id}"
|
| 227 |
+
|
| 228 |
+
if account_id in self._priority_accounts:
|
| 229 |
+
self._priority_accounts.remove(account_id)
|
| 230 |
+
|
| 231 |
+
if position is None or position < 0 or position >= len(self._priority_accounts):
|
| 232 |
+
self._priority_accounts.append(account_id)
|
| 233 |
+
else:
|
| 234 |
+
self._priority_accounts.insert(position, account_id)
|
| 235 |
+
|
| 236 |
+
self._save_priority_config()
|
| 237 |
+
return True, f"已添加优先账号: {account_id}"
|
| 238 |
+
|
| 239 |
+
def remove_priority_account(self, account_id: str = None) -> tuple:
|
| 240 |
+
"""移除优先账号
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
account_id: 账号ID(可选,不传则清除所有)
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
(success, message)
|
| 247 |
+
"""
|
| 248 |
+
with self._lock:
|
| 249 |
+
if not self._priority_accounts:
|
| 250 |
+
return False, "没有设置优先账号"
|
| 251 |
+
|
| 252 |
+
if account_id:
|
| 253 |
+
if account_id not in self._priority_accounts:
|
| 254 |
+
return False, f"账号 {account_id} 不是优先账号"
|
| 255 |
+
|
| 256 |
+
self._priority_accounts.remove(account_id)
|
| 257 |
+
if not self._priority_accounts:
|
| 258 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 259 |
+
self._save_priority_config()
|
| 260 |
+
return True, f"已移除优先账号: {account_id}"
|
| 261 |
+
|
| 262 |
+
self._priority_accounts = []
|
| 263 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 264 |
+
self._save_priority_config()
|
| 265 |
+
return True, "已清除优先账号"
|
| 266 |
+
|
| 267 |
+
def reorder_priority(self, account_ids: List[str]) -> tuple:
|
| 268 |
+
"""重新排序优先账号列表
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
account_ids: 新的优先账号顺序(必须与当前优先账号集合一致)
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
(success, message)
|
| 275 |
+
"""
|
| 276 |
+
with self._lock:
|
| 277 |
+
if not self._priority_accounts:
|
| 278 |
+
return False, "没有设置优先账号"
|
| 279 |
+
|
| 280 |
+
if not account_ids:
|
| 281 |
+
return False, "账号列表不能为空"
|
| 282 |
+
|
| 283 |
+
if len(account_ids) != len(self._priority_accounts):
|
| 284 |
+
return False, "账号数量不匹配"
|
| 285 |
+
|
| 286 |
+
if len(set(account_ids)) != len(account_ids):
|
| 287 |
+
return False, "账号列表包含重复项"
|
| 288 |
+
|
| 289 |
+
if set(account_ids) != set(self._priority_accounts):
|
| 290 |
+
return False, "账号列表与当前优先账号不匹配"
|
| 291 |
+
|
| 292 |
+
self._priority_accounts = list(account_ids)
|
| 293 |
+
self._save_priority_config()
|
| 294 |
+
return True, "已更新优先账号顺序"
|
| 295 |
+
|
| 296 |
+
def get_priority_account(self) -> Optional[str]:
|
| 297 |
+
"""获取优先账号(单个)"""
|
| 298 |
+
with self._lock:
|
| 299 |
+
return self._priority_accounts[0] if self._priority_accounts else None
|
| 300 |
+
|
| 301 |
+
def get_priority_accounts(self) -> List[str]:
|
| 302 |
+
"""获取优先账号列表"""
|
| 303 |
+
with self._lock:
|
| 304 |
+
return list(self._priority_accounts)
|
| 305 |
+
|
| 306 |
+
def is_priority_account(self, account_id: str) -> bool:
|
| 307 |
+
"""检查账号是否为优先账号"""
|
| 308 |
+
with self._lock:
|
| 309 |
+
return account_id in self._priority_accounts
|
| 310 |
+
|
| 311 |
+
def get_priority_order(self, account_id: str) -> Optional[int]:
|
| 312 |
+
"""获取账号的优先级顺序(从1开始)"""
|
| 313 |
+
with self._lock:
|
| 314 |
+
if account_id in self._priority_accounts:
|
| 315 |
+
return self._priority_accounts.index(account_id) + 1
|
| 316 |
+
return None
|
| 317 |
+
|
| 318 |
+
def _load_priority_config(self) -> bool:
|
| 319 |
+
"""从文件加载优先账号配置"""
|
| 320 |
+
if not self._priority_file.exists():
|
| 321 |
+
return False
|
| 322 |
+
|
| 323 |
+
try:
|
| 324 |
+
with open(self._priority_file, 'r', encoding='utf-8') as f:
|
| 325 |
+
data = json.load(f)
|
| 326 |
+
|
| 327 |
+
self._priority_accounts = data.get("priority_accounts", [])
|
| 328 |
+
strategy_str = data.get("strategy", SelectionStrategy.RANDOM.value)
|
| 329 |
+
try:
|
| 330 |
+
self._strategy = SelectionStrategy(strategy_str)
|
| 331 |
+
except ValueError:
|
| 332 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 333 |
+
|
| 334 |
+
# 兼容旧版本:历史默认策略为 lowest_balance,但无优先账号时更需要分散压力
|
| 335 |
+
if not self._priority_accounts and self._strategy == SelectionStrategy.LOWEST_BALANCE:
|
| 336 |
+
self._strategy = SelectionStrategy.RANDOM
|
| 337 |
+
self._save_priority_config()
|
| 338 |
+
|
| 339 |
+
print(f"[AccountSelector] 加载优先账号配置: {len(self._priority_accounts)} 个优先账号")
|
| 340 |
+
return True
|
| 341 |
+
|
| 342 |
+
except Exception as e:
|
| 343 |
+
print(f"[AccountSelector] 加载优先账号配置失败: {e}")
|
| 344 |
+
return False
|
| 345 |
+
|
| 346 |
+
def _save_priority_config(self) -> bool:
|
| 347 |
+
"""保存优先账号配置到文件"""
|
| 348 |
+
try:
|
| 349 |
+
self._priority_file.parent.mkdir(parents=True, exist_ok=True)
|
| 350 |
+
|
| 351 |
+
data = {
|
| 352 |
+
"version": "1.0",
|
| 353 |
+
"priority_accounts": self._priority_accounts,
|
| 354 |
+
"strategy": self._strategy.value
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
temp_file = self._priority_file.with_suffix('.tmp')
|
| 358 |
+
with open(temp_file, 'w', encoding='utf-8') as f:
|
| 359 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
| 360 |
+
temp_file.replace(self._priority_file)
|
| 361 |
+
|
| 362 |
+
return True
|
| 363 |
+
|
| 364 |
+
except Exception as e:
|
| 365 |
+
print(f"[AccountSelector] 保存优先账号配置失败: {e}")
|
| 366 |
+
return False
|
| 367 |
+
|
| 368 |
+
def get_status(self) -> dict:
|
| 369 |
+
"""获取选择器状态"""
|
| 370 |
+
with self._lock:
|
| 371 |
+
return {
|
| 372 |
+
"strategy": self._strategy.value,
|
| 373 |
+
"priority_accounts": list(self._priority_accounts),
|
| 374 |
+
"priority_count": len(self._priority_accounts)
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# 全局选择器实例
|
| 379 |
+
_account_selector: Optional[AccountSelector] = None
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def get_account_selector(quota_cache: Optional['QuotaCache'] = None) -> AccountSelector:
|
| 383 |
+
"""获取全局选择器实例"""
|
| 384 |
+
global _account_selector
|
| 385 |
+
if _account_selector is None:
|
| 386 |
+
if quota_cache is None:
|
| 387 |
+
from .quota_cache import get_quota_cache
|
| 388 |
+
quota_cache = get_quota_cache()
|
| 389 |
+
_account_selector = AccountSelector(quota_cache)
|
| 390 |
+
return _account_selector
|
KiroProxy/kiro_proxy/core/browser.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""浏览器检测和打开"""
|
| 2 |
+
import os
|
| 3 |
+
import shlex
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
import platform
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class BrowserInfo:
|
| 13 |
+
id: str
|
| 14 |
+
name: str
|
| 15 |
+
path: str
|
| 16 |
+
supports_incognito: bool
|
| 17 |
+
incognito_arg: str = ""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# 浏览器配置
|
| 21 |
+
BROWSER_CONFIGS = {
|
| 22 |
+
"chrome": {
|
| 23 |
+
"names": ["google-chrome", "google-chrome-stable", "chrome", "chromium", "chromium-browser"],
|
| 24 |
+
"display": "Chrome",
|
| 25 |
+
"incognito": "--incognito",
|
| 26 |
+
},
|
| 27 |
+
"firefox": {
|
| 28 |
+
"names": ["firefox", "firefox-esr"],
|
| 29 |
+
"display": "Firefox",
|
| 30 |
+
"incognito": "--private-window",
|
| 31 |
+
},
|
| 32 |
+
"edge": {
|
| 33 |
+
"names": ["microsoft-edge", "microsoft-edge-stable", "msedge"],
|
| 34 |
+
"display": "Edge",
|
| 35 |
+
"incognito": "--inprivate",
|
| 36 |
+
},
|
| 37 |
+
"brave": {
|
| 38 |
+
"names": ["brave", "brave-browser"],
|
| 39 |
+
"display": "Brave",
|
| 40 |
+
"incognito": "--incognito",
|
| 41 |
+
},
|
| 42 |
+
"opera": {
|
| 43 |
+
"names": ["opera"],
|
| 44 |
+
"display": "Opera",
|
| 45 |
+
"incognito": "--private",
|
| 46 |
+
},
|
| 47 |
+
"vivaldi": {
|
| 48 |
+
"names": ["vivaldi", "vivaldi-stable"],
|
| 49 |
+
"display": "Vivaldi",
|
| 50 |
+
"incognito": "--incognito",
|
| 51 |
+
},
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def detect_browsers() -> List[BrowserInfo]:
|
| 56 |
+
"""检测系统安装的浏览器"""
|
| 57 |
+
browsers = []
|
| 58 |
+
system = platform.system().lower()
|
| 59 |
+
|
| 60 |
+
if system == "windows":
|
| 61 |
+
import winreg
|
| 62 |
+
|
| 63 |
+
def normalize_exe_path(raw: str) -> Optional[str]:
|
| 64 |
+
if not raw:
|
| 65 |
+
return None
|
| 66 |
+
expanded = os.path.expandvars(raw.strip())
|
| 67 |
+
try:
|
| 68 |
+
parts = shlex.split(expanded, posix=False)
|
| 69 |
+
except ValueError:
|
| 70 |
+
parts = [expanded]
|
| 71 |
+
candidate = (parts[0] if parts else expanded).strip().strip('"')
|
| 72 |
+
if os.path.exists(candidate):
|
| 73 |
+
return candidate
|
| 74 |
+
lower = expanded.lower()
|
| 75 |
+
exe_idx = lower.find(".exe")
|
| 76 |
+
if exe_idx != -1:
|
| 77 |
+
candidate = expanded[:exe_idx + 4].strip().strip('"')
|
| 78 |
+
if os.path.exists(candidate):
|
| 79 |
+
return candidate
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
def get_reg_path(exe_name: str) -> Optional[str]:
|
| 83 |
+
name = f"{exe_name}.exe"
|
| 84 |
+
for root in (winreg.HKEY_LOCAL_MACHINE, winreg.HKEY_CURRENT_USER):
|
| 85 |
+
try:
|
| 86 |
+
with winreg.OpenKey(root, rf"SOFTWARE\Microsoft\Windows\CurrentVersion\App Paths\{name}") as key:
|
| 87 |
+
value, _ = winreg.QueryValueEx(key, "")
|
| 88 |
+
path = normalize_exe_path(value)
|
| 89 |
+
if path:
|
| 90 |
+
return path
|
| 91 |
+
except (FileNotFoundError, OSError, WindowsError):
|
| 92 |
+
pass
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
for browser_id, config in BROWSER_CONFIGS.items():
|
| 96 |
+
path = None
|
| 97 |
+
for exe_name in config["names"]:
|
| 98 |
+
path = get_reg_path(exe_name)
|
| 99 |
+
if path:
|
| 100 |
+
break
|
| 101 |
+
if not path:
|
| 102 |
+
for exe_name in config["names"]:
|
| 103 |
+
path = shutil.which(exe_name)
|
| 104 |
+
if path:
|
| 105 |
+
break
|
| 106 |
+
if path:
|
| 107 |
+
browsers.append(BrowserInfo(
|
| 108 |
+
id=browser_id,
|
| 109 |
+
name=config["display"],
|
| 110 |
+
path=path,
|
| 111 |
+
supports_incognito=bool(config.get("incognito")),
|
| 112 |
+
incognito_arg=config.get("incognito", ""),
|
| 113 |
+
))
|
| 114 |
+
else:
|
| 115 |
+
for browser_id, config in BROWSER_CONFIGS.items():
|
| 116 |
+
for name in config["names"]:
|
| 117 |
+
path = shutil.which(name)
|
| 118 |
+
if path:
|
| 119 |
+
browsers.append(BrowserInfo(
|
| 120 |
+
id=browser_id,
|
| 121 |
+
name=config["display"],
|
| 122 |
+
path=path,
|
| 123 |
+
supports_incognito=bool(config.get("incognito")),
|
| 124 |
+
incognito_arg=config.get("incognito", ""),
|
| 125 |
+
))
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
# 添加默认浏览器选项
|
| 129 |
+
if browsers:
|
| 130 |
+
browsers.insert(0, BrowserInfo(
|
| 131 |
+
id="default",
|
| 132 |
+
name="默认浏览器",
|
| 133 |
+
path="xdg-open" if system == "linux" else "open",
|
| 134 |
+
supports_incognito=False,
|
| 135 |
+
incognito_arg="",
|
| 136 |
+
))
|
| 137 |
+
|
| 138 |
+
return browsers
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def open_url(url: str, browser_id: str = "default", incognito: bool = False) -> bool:
|
| 142 |
+
"""用指定浏览器打开 URL"""
|
| 143 |
+
browsers = detect_browsers()
|
| 144 |
+
browser = next((b for b in browsers if b.id == browser_id), None)
|
| 145 |
+
|
| 146 |
+
if not browser:
|
| 147 |
+
# 降级到默认
|
| 148 |
+
browser = browsers[0] if browsers else None
|
| 149 |
+
|
| 150 |
+
if not browser:
|
| 151 |
+
return False
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
if browser.id == "default":
|
| 155 |
+
# 使用系统默认浏览器
|
| 156 |
+
system = platform.system().lower()
|
| 157 |
+
if system == "linux":
|
| 158 |
+
subprocess.Popen(["xdg-open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 159 |
+
elif system == "darwin":
|
| 160 |
+
subprocess.Popen(["open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 161 |
+
else:
|
| 162 |
+
os.startfile(url)
|
| 163 |
+
else:
|
| 164 |
+
# 使用指定浏览器
|
| 165 |
+
args = [browser.path]
|
| 166 |
+
if incognito and browser.supports_incognito and browser.incognito_arg:
|
| 167 |
+
args.append(browser.incognito_arg)
|
| 168 |
+
args.append(url)
|
| 169 |
+
subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 170 |
+
|
| 171 |
+
return True
|
| 172 |
+
except Exception as e:
|
| 173 |
+
print(f"[Browser] 打开失败: {e}")
|
| 174 |
+
return False
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_browsers_info() -> List[dict]:
|
| 178 |
+
"""获取浏览器信息列表"""
|
| 179 |
+
return [
|
| 180 |
+
{
|
| 181 |
+
"id": b.id,
|
| 182 |
+
"name": b.name,
|
| 183 |
+
"supports_incognito": b.supports_incognito,
|
| 184 |
+
}
|
| 185 |
+
for b in detect_browsers()
|
| 186 |
+
]
|
KiroProxy/kiro_proxy/core/error_handler.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""错误处理模块 - 统一的错误分类和处理
|
| 2 |
+
|
| 3 |
+
检测各种 Kiro API 错误类型:
|
| 4 |
+
- 账号封禁 (TEMPORARILY_SUSPENDED)
|
| 5 |
+
- 配额超限 (Rate Limit)
|
| 6 |
+
- 内容过长 (CONTENT_LENGTH_EXCEEDS_THRESHOLD)
|
| 7 |
+
- 认证失败 (Unauthorized)
|
| 8 |
+
- 服务不可用 (Service Unavailable)
|
| 9 |
+
"""
|
| 10 |
+
import re
|
| 11 |
+
from enum import Enum
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Optional, Tuple
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ErrorType(str, Enum):
|
| 17 |
+
"""错误类型"""
|
| 18 |
+
ACCOUNT_SUSPENDED = "account_suspended" # 账号被封禁
|
| 19 |
+
RATE_LIMITED = "rate_limited" # 配额超限
|
| 20 |
+
CONTENT_TOO_LONG = "content_too_long" # 内容过长
|
| 21 |
+
AUTH_FAILED = "auth_failed" # 认证失败
|
| 22 |
+
SERVICE_UNAVAILABLE = "service_unavailable" # 服务不可用
|
| 23 |
+
MODEL_UNAVAILABLE = "model_unavailable" # 模型不可用
|
| 24 |
+
UNKNOWN = "unknown" # 未知错误
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class KiroError:
|
| 29 |
+
"""Kiro API 错误"""
|
| 30 |
+
type: ErrorType
|
| 31 |
+
status_code: int
|
| 32 |
+
message: str
|
| 33 |
+
user_message: str # 用户友好的消息
|
| 34 |
+
should_disable_account: bool = False # 是否应该禁用账号
|
| 35 |
+
should_switch_account: bool = False # 是否应该切换账号
|
| 36 |
+
should_retry: bool = False # 是否应该重试
|
| 37 |
+
cooldown_seconds: int = 0 # 冷却时间
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def classify_error(status_code: int, error_text: str) -> KiroError:
|
| 41 |
+
"""分类 Kiro API 错误
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
status_code: HTTP 状态码
|
| 45 |
+
error_text: 错误响应文本
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
KiroError 对象
|
| 49 |
+
"""
|
| 50 |
+
error_lower = error_text.lower()
|
| 51 |
+
|
| 52 |
+
# 1. 账号封禁检测 (最严重)
|
| 53 |
+
# 检测: AccountSuspendedException, 423 状态码, temporarily_suspended, suspended
|
| 54 |
+
is_suspended = (
|
| 55 |
+
status_code == 423 or
|
| 56 |
+
"accountsuspendedexception" in error_lower or
|
| 57 |
+
"temporarily_suspended" in error_lower or
|
| 58 |
+
"suspended" in error_lower
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if is_suspended:
|
| 62 |
+
# 提取 User ID
|
| 63 |
+
user_id_match = re.search(r'User ID \(([^)]+)\)', error_text)
|
| 64 |
+
user_id = user_id_match.group(1) if user_id_match else "unknown"
|
| 65 |
+
|
| 66 |
+
return KiroError(
|
| 67 |
+
type=ErrorType.ACCOUNT_SUSPENDED,
|
| 68 |
+
status_code=status_code,
|
| 69 |
+
message=error_text,
|
| 70 |
+
user_message=f"⚠️ 账号已被封禁 (User ID: {user_id})。请联系 AWS 支持解封: https://support.aws.amazon.com/#/contacts/kiro",
|
| 71 |
+
should_disable_account=True,
|
| 72 |
+
should_switch_account=True,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# 2. 402 Payment Required - 额度用尽(不触发冷却,仅切换账号)
|
| 76 |
+
if status_code == 402 or "payment required" in error_lower or "insufficient" in error_lower:
|
| 77 |
+
return KiroError(
|
| 78 |
+
type=ErrorType.RATE_LIMITED,
|
| 79 |
+
status_code=status_code,
|
| 80 |
+
message=error_text,
|
| 81 |
+
user_message="账号额度已用尽,已切换到其他账号",
|
| 82 |
+
should_switch_account=False, # 不自动切换,让上层逻辑处理
|
| 83 |
+
cooldown_seconds=0, # 不触发冷却
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# 3. 配额超限检测 (仅 429 触发冷却)
|
| 87 |
+
if status_code == 429:
|
| 88 |
+
return KiroError(
|
| 89 |
+
type=ErrorType.RATE_LIMITED,
|
| 90 |
+
status_code=status_code,
|
| 91 |
+
message=error_text,
|
| 92 |
+
user_message="请求过于频繁,账号已进入冷却期",
|
| 93 |
+
should_switch_account=True,
|
| 94 |
+
cooldown_seconds=30, # 基础冷却时间,实际由 QuotaManager 动态管理
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# 4. 内容过长检测
|
| 98 |
+
if "content_length_exceeds_threshold" in error_lower or (
|
| 99 |
+
"too long" in error_lower and ("input" in error_lower or "content" in error_lower)
|
| 100 |
+
):
|
| 101 |
+
return KiroError(
|
| 102 |
+
type=ErrorType.CONTENT_TOO_LONG,
|
| 103 |
+
status_code=status_code,
|
| 104 |
+
message=error_text,
|
| 105 |
+
user_message="对话历史过长,请使用 /clear 清空对话",
|
| 106 |
+
should_retry=True,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# 5. 认证失败检测
|
| 110 |
+
if status_code == 401 or "unauthorized" in error_lower or "invalid token" in error_lower:
|
| 111 |
+
return KiroError(
|
| 112 |
+
type=ErrorType.AUTH_FAILED,
|
| 113 |
+
status_code=status_code,
|
| 114 |
+
message=error_text,
|
| 115 |
+
user_message="Token 已过期或无效,请刷新 Token",
|
| 116 |
+
should_switch_account=True,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# 6. 模型不可用检测
|
| 120 |
+
if "model_temporarily_unavailable" in error_lower or "unexpectedly high load" in error_lower:
|
| 121 |
+
return KiroError(
|
| 122 |
+
type=ErrorType.MODEL_UNAVAILABLE,
|
| 123 |
+
status_code=status_code,
|
| 124 |
+
message=error_text,
|
| 125 |
+
user_message="模型暂时不可用,请稍后重试",
|
| 126 |
+
should_retry=True,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# 7. 服务不可用检测
|
| 130 |
+
if status_code in (502, 503, 504) or "service unavailable" in error_lower:
|
| 131 |
+
return KiroError(
|
| 132 |
+
type=ErrorType.SERVICE_UNAVAILABLE,
|
| 133 |
+
status_code=status_code,
|
| 134 |
+
message=error_text,
|
| 135 |
+
user_message="服务暂时不可用,请稍后重试",
|
| 136 |
+
should_retry=True,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# 8. 未知错误
|
| 140 |
+
return KiroError(
|
| 141 |
+
type=ErrorType.UNKNOWN,
|
| 142 |
+
status_code=status_code,
|
| 143 |
+
message=error_text,
|
| 144 |
+
user_message=f"API 错误 ({status_code})",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def is_account_suspended(status_code: int, error_text: str) -> bool:
|
| 149 |
+
"""检查是否为账号封禁错误"""
|
| 150 |
+
error = classify_error(status_code, error_text)
|
| 151 |
+
return error.type == ErrorType.ACCOUNT_SUSPENDED
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_anthropic_error_response(error: KiroError) -> dict:
|
| 155 |
+
"""生成 Anthropic 格式的错误响应"""
|
| 156 |
+
error_type_map = {
|
| 157 |
+
ErrorType.ACCOUNT_SUSPENDED: "authentication_error",
|
| 158 |
+
ErrorType.RATE_LIMITED: "rate_limit_error",
|
| 159 |
+
ErrorType.CONTENT_TOO_LONG: "invalid_request_error",
|
| 160 |
+
ErrorType.AUTH_FAILED: "authentication_error",
|
| 161 |
+
ErrorType.SERVICE_UNAVAILABLE: "api_error",
|
| 162 |
+
ErrorType.MODEL_UNAVAILABLE: "overloaded_error",
|
| 163 |
+
ErrorType.UNKNOWN: "api_error",
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
return {
|
| 167 |
+
"type": "error",
|
| 168 |
+
"error": {
|
| 169 |
+
"type": error_type_map.get(error.type, "api_error"),
|
| 170 |
+
"message": error.user_message
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def format_error_log(error: KiroError, account_id: str = None) -> str:
|
| 176 |
+
"""格式化错误日志"""
|
| 177 |
+
lines = [
|
| 178 |
+
f"[{error.type.value.upper()}]",
|
| 179 |
+
f" Status: {error.status_code}",
|
| 180 |
+
f" Message: {error.user_message}",
|
| 181 |
+
]
|
| 182 |
+
if account_id:
|
| 183 |
+
lines.insert(1, f" Account: {account_id}")
|
| 184 |
+
if error.should_disable_account:
|
| 185 |
+
lines.append(" Action: 账号已被禁用")
|
| 186 |
+
elif error.should_switch_account:
|
| 187 |
+
lines.append(" Action: 切换到其他账号")
|
| 188 |
+
return "\n".join(lines)
|
KiroProxy/kiro_proxy/core/flow_monitor.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow Monitor - LLM 流量监控
|
| 2 |
+
|
| 3 |
+
记录完整的请求/响应数据,支持查询、过滤、导出。
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from dataclasses import dataclass, field, asdict
|
| 10 |
+
from typing import Optional, List, Dict, Any
|
| 11 |
+
from datetime import datetime, timezone
|
| 12 |
+
from collections import deque
|
| 13 |
+
from enum import Enum
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FlowState(str, Enum):
|
| 17 |
+
"""Flow 状态"""
|
| 18 |
+
PENDING = "pending" # 等待响应
|
| 19 |
+
STREAMING = "streaming" # 流式传输中
|
| 20 |
+
COMPLETED = "completed" # 完成
|
| 21 |
+
ERROR = "error" # 错误
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class Message:
|
| 26 |
+
"""消息"""
|
| 27 |
+
role: str # user/assistant/system/tool
|
| 28 |
+
content: Any # str 或 list
|
| 29 |
+
name: Optional[str] = None # tool name
|
| 30 |
+
tool_call_id: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class TokenUsage:
|
| 35 |
+
"""Token 使用量"""
|
| 36 |
+
input_tokens: int = 0
|
| 37 |
+
output_tokens: int = 0
|
| 38 |
+
cache_read_tokens: int = 0
|
| 39 |
+
cache_write_tokens: int = 0
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def total_tokens(self) -> int:
|
| 43 |
+
return self.input_tokens + self.output_tokens
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class FlowRequest:
|
| 48 |
+
"""请求数据"""
|
| 49 |
+
method: str
|
| 50 |
+
path: str
|
| 51 |
+
headers: Dict[str, str]
|
| 52 |
+
body: Dict[str, Any]
|
| 53 |
+
|
| 54 |
+
# 解析后的字段
|
| 55 |
+
model: str = ""
|
| 56 |
+
messages: List[Message] = field(default_factory=list)
|
| 57 |
+
system: str = ""
|
| 58 |
+
tools: List[Dict] = field(default_factory=list)
|
| 59 |
+
stream: bool = False
|
| 60 |
+
max_tokens: int = 0
|
| 61 |
+
temperature: float = 1.0
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class FlowResponse:
|
| 66 |
+
"""响应数据"""
|
| 67 |
+
status_code: int
|
| 68 |
+
headers: Dict[str, str] = field(default_factory=dict)
|
| 69 |
+
body: Any = None
|
| 70 |
+
|
| 71 |
+
# 解析后的字段
|
| 72 |
+
content: str = ""
|
| 73 |
+
tool_calls: List[Dict] = field(default_factory=list)
|
| 74 |
+
stop_reason: str = ""
|
| 75 |
+
usage: TokenUsage = field(default_factory=TokenUsage)
|
| 76 |
+
|
| 77 |
+
# 流式响应
|
| 78 |
+
chunks: List[str] = field(default_factory=list)
|
| 79 |
+
chunk_count: int = 0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
|
| 83 |
+
class FlowError:
|
| 84 |
+
"""错误信息"""
|
| 85 |
+
type: str # rate_limit_error, api_error, etc.
|
| 86 |
+
message: str
|
| 87 |
+
status_code: int = 0
|
| 88 |
+
raw: str = ""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
|
| 92 |
+
class FlowTiming:
|
| 93 |
+
"""时间信息"""
|
| 94 |
+
created_at: float = 0
|
| 95 |
+
first_byte_at: Optional[float] = None
|
| 96 |
+
completed_at: Optional[float] = None
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def ttfb_ms(self) -> Optional[float]:
|
| 100 |
+
"""Time to first byte"""
|
| 101 |
+
if self.first_byte_at and self.created_at:
|
| 102 |
+
return (self.first_byte_at - self.created_at) * 1000
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
@property
|
| 106 |
+
def duration_ms(self) -> Optional[float]:
|
| 107 |
+
"""Total duration"""
|
| 108 |
+
if self.completed_at and self.created_at:
|
| 109 |
+
return (self.completed_at - self.created_at) * 1000
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass
|
| 114 |
+
class LLMFlow:
|
| 115 |
+
"""完整的 LLM 请求流"""
|
| 116 |
+
id: str
|
| 117 |
+
state: FlowState
|
| 118 |
+
|
| 119 |
+
# 路由信息
|
| 120 |
+
protocol: str # anthropic, openai, gemini
|
| 121 |
+
account_id: Optional[str] = None
|
| 122 |
+
account_name: Optional[str] = None
|
| 123 |
+
|
| 124 |
+
# 请求/响应
|
| 125 |
+
request: Optional[FlowRequest] = None
|
| 126 |
+
response: Optional[FlowResponse] = None
|
| 127 |
+
error: Optional[FlowError] = None
|
| 128 |
+
|
| 129 |
+
# 时间
|
| 130 |
+
timing: FlowTiming = field(default_factory=FlowTiming)
|
| 131 |
+
|
| 132 |
+
# 元数据
|
| 133 |
+
tags: List[str] = field(default_factory=list)
|
| 134 |
+
notes: str = ""
|
| 135 |
+
bookmarked: bool = False
|
| 136 |
+
|
| 137 |
+
# 重试信息
|
| 138 |
+
retry_count: int = 0
|
| 139 |
+
parent_flow_id: Optional[str] = None
|
| 140 |
+
|
| 141 |
+
def to_dict(self) -> dict:
|
| 142 |
+
"""转换为字典"""
|
| 143 |
+
d = {
|
| 144 |
+
"id": self.id,
|
| 145 |
+
"state": self.state.value,
|
| 146 |
+
"protocol": self.protocol,
|
| 147 |
+
"account_id": self.account_id,
|
| 148 |
+
"account_name": self.account_name,
|
| 149 |
+
"timing": {
|
| 150 |
+
"created_at": self.timing.created_at,
|
| 151 |
+
"first_byte_at": self.timing.first_byte_at,
|
| 152 |
+
"completed_at": self.timing.completed_at,
|
| 153 |
+
"ttfb_ms": self.timing.ttfb_ms,
|
| 154 |
+
"duration_ms": self.timing.duration_ms,
|
| 155 |
+
},
|
| 156 |
+
"tags": self.tags,
|
| 157 |
+
"notes": self.notes,
|
| 158 |
+
"bookmarked": self.bookmarked,
|
| 159 |
+
"retry_count": self.retry_count,
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if self.request:
|
| 163 |
+
d["request"] = {
|
| 164 |
+
"method": self.request.method,
|
| 165 |
+
"path": self.request.path,
|
| 166 |
+
"model": self.request.model,
|
| 167 |
+
"stream": self.request.stream,
|
| 168 |
+
"message_count": len(self.request.messages),
|
| 169 |
+
"has_tools": bool(self.request.tools),
|
| 170 |
+
"has_system": bool(self.request.system),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if self.response:
|
| 174 |
+
d["response"] = {
|
| 175 |
+
"status_code": self.response.status_code,
|
| 176 |
+
"content_length": len(self.response.content),
|
| 177 |
+
"has_tool_calls": bool(self.response.tool_calls),
|
| 178 |
+
"stop_reason": self.response.stop_reason,
|
| 179 |
+
"chunk_count": self.response.chunk_count,
|
| 180 |
+
"usage": asdict(self.response.usage),
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
if self.error:
|
| 184 |
+
d["error"] = asdict(self.error)
|
| 185 |
+
|
| 186 |
+
return d
|
| 187 |
+
|
| 188 |
+
def to_full_dict(self) -> dict:
|
| 189 |
+
"""转换为完整字典(包含请求/响应体)"""
|
| 190 |
+
d = self.to_dict()
|
| 191 |
+
|
| 192 |
+
if self.request:
|
| 193 |
+
d["request"]["headers"] = self.request.headers
|
| 194 |
+
d["request"]["body"] = self.request.body
|
| 195 |
+
d["request"]["messages"] = [asdict(m) if hasattr(m, '__dataclass_fields__') else m for m in self.request.messages]
|
| 196 |
+
d["request"]["system"] = self.request.system
|
| 197 |
+
d["request"]["tools"] = self.request.tools
|
| 198 |
+
|
| 199 |
+
if self.response:
|
| 200 |
+
d["response"]["headers"] = self.response.headers
|
| 201 |
+
d["response"]["body"] = self.response.body
|
| 202 |
+
d["response"]["content"] = self.response.content
|
| 203 |
+
d["response"]["tool_calls"] = self.response.tool_calls
|
| 204 |
+
d["response"]["chunks"] = self.response.chunks[-10:] # 只保留最后10个chunk
|
| 205 |
+
|
| 206 |
+
return d
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class FlowStore:
|
| 210 |
+
"""Flow 存储"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, max_flows: int = 500, persist_dir: Optional[Path] = None):
|
| 213 |
+
self.flows: deque[LLMFlow] = deque(maxlen=max_flows)
|
| 214 |
+
self.flow_map: Dict[str, LLMFlow] = {}
|
| 215 |
+
self.persist_dir = persist_dir
|
| 216 |
+
self.max_flows = max_flows
|
| 217 |
+
|
| 218 |
+
# 统计
|
| 219 |
+
self.total_flows = 0
|
| 220 |
+
self.total_tokens_in = 0
|
| 221 |
+
self.total_tokens_out = 0
|
| 222 |
+
|
| 223 |
+
def add(self, flow: LLMFlow):
|
| 224 |
+
"""添加 Flow"""
|
| 225 |
+
# 如果队列满了,移除最旧的
|
| 226 |
+
if len(self.flows) >= self.max_flows:
|
| 227 |
+
old = self.flows[0]
|
| 228 |
+
if old.id in self.flow_map:
|
| 229 |
+
del self.flow_map[old.id]
|
| 230 |
+
|
| 231 |
+
self.flows.append(flow)
|
| 232 |
+
self.flow_map[flow.id] = flow
|
| 233 |
+
self.total_flows += 1
|
| 234 |
+
|
| 235 |
+
def get(self, flow_id: str) -> Optional[LLMFlow]:
|
| 236 |
+
"""获取 Flow"""
|
| 237 |
+
return self.flow_map.get(flow_id)
|
| 238 |
+
|
| 239 |
+
def update(self, flow_id: str, **kwargs):
|
| 240 |
+
"""更新 Flow"""
|
| 241 |
+
flow = self.flow_map.get(flow_id)
|
| 242 |
+
if flow:
|
| 243 |
+
for k, v in kwargs.items():
|
| 244 |
+
if hasattr(flow, k):
|
| 245 |
+
setattr(flow, k, v)
|
| 246 |
+
|
| 247 |
+
def query(
|
| 248 |
+
self,
|
| 249 |
+
protocol: Optional[str] = None,
|
| 250 |
+
model: Optional[str] = None,
|
| 251 |
+
account_id: Optional[str] = None,
|
| 252 |
+
state: Optional[FlowState] = None,
|
| 253 |
+
has_error: Optional[bool] = None,
|
| 254 |
+
bookmarked: Optional[bool] = None,
|
| 255 |
+
min_duration_ms: Optional[float] = None,
|
| 256 |
+
max_duration_ms: Optional[float] = None,
|
| 257 |
+
start_time: Optional[float] = None,
|
| 258 |
+
end_time: Optional[float] = None,
|
| 259 |
+
search: Optional[str] = None,
|
| 260 |
+
limit: int = 100,
|
| 261 |
+
offset: int = 0,
|
| 262 |
+
) -> List[LLMFlow]:
|
| 263 |
+
"""查询 Flows"""
|
| 264 |
+
results = []
|
| 265 |
+
|
| 266 |
+
for flow in reversed(self.flows):
|
| 267 |
+
# 过滤条件
|
| 268 |
+
if protocol and flow.protocol != protocol:
|
| 269 |
+
continue
|
| 270 |
+
if model and flow.request and flow.request.model != model:
|
| 271 |
+
continue
|
| 272 |
+
if account_id and flow.account_id != account_id:
|
| 273 |
+
continue
|
| 274 |
+
if state and flow.state != state:
|
| 275 |
+
continue
|
| 276 |
+
if has_error is not None:
|
| 277 |
+
if has_error and not flow.error:
|
| 278 |
+
continue
|
| 279 |
+
if not has_error and flow.error:
|
| 280 |
+
continue
|
| 281 |
+
if bookmarked is not None and flow.bookmarked != bookmarked:
|
| 282 |
+
continue
|
| 283 |
+
if min_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms < min_duration_ms:
|
| 284 |
+
continue
|
| 285 |
+
if max_duration_ms and flow.timing.duration_ms and flow.timing.duration_ms > max_duration_ms:
|
| 286 |
+
continue
|
| 287 |
+
if start_time and flow.timing.created_at < start_time:
|
| 288 |
+
continue
|
| 289 |
+
if end_time and flow.timing.created_at > end_time:
|
| 290 |
+
continue
|
| 291 |
+
if search:
|
| 292 |
+
# 简单搜索:在内容中查找
|
| 293 |
+
found = False
|
| 294 |
+
if flow.request and search.lower() in json.dumps(flow.request.body).lower():
|
| 295 |
+
found = True
|
| 296 |
+
if flow.response and search.lower() in flow.response.content.lower():
|
| 297 |
+
found = True
|
| 298 |
+
if not found:
|
| 299 |
+
continue
|
| 300 |
+
|
| 301 |
+
results.append(flow)
|
| 302 |
+
|
| 303 |
+
return results[offset:offset + limit]
|
| 304 |
+
|
| 305 |
+
def get_stats(self) -> dict:
|
| 306 |
+
"""获取统计信息"""
|
| 307 |
+
completed = [f for f in self.flows if f.state == FlowState.COMPLETED]
|
| 308 |
+
errors = [f for f in self.flows if f.state == FlowState.ERROR]
|
| 309 |
+
|
| 310 |
+
# 按模型统计
|
| 311 |
+
model_stats = {}
|
| 312 |
+
for f in self.flows:
|
| 313 |
+
if f.request:
|
| 314 |
+
model = f.request.model or "unknown"
|
| 315 |
+
if model not in model_stats:
|
| 316 |
+
model_stats[model] = {"count": 0, "errors": 0, "tokens_in": 0, "tokens_out": 0}
|
| 317 |
+
model_stats[model]["count"] += 1
|
| 318 |
+
if f.error:
|
| 319 |
+
model_stats[model]["errors"] += 1
|
| 320 |
+
if f.response and f.response.usage:
|
| 321 |
+
model_stats[model]["tokens_in"] += f.response.usage.input_tokens
|
| 322 |
+
model_stats[model]["tokens_out"] += f.response.usage.output_tokens
|
| 323 |
+
|
| 324 |
+
# 计算平均延迟
|
| 325 |
+
durations = [f.timing.duration_ms for f in completed if f.timing.duration_ms]
|
| 326 |
+
avg_duration = sum(durations) / len(durations) if durations else 0
|
| 327 |
+
|
| 328 |
+
return {
|
| 329 |
+
"total_flows": self.total_flows,
|
| 330 |
+
"active_flows": len(self.flows),
|
| 331 |
+
"completed": len(completed),
|
| 332 |
+
"errors": len(errors),
|
| 333 |
+
"error_rate": f"{len(errors) / max(1, len(self.flows)) * 100:.1f}%",
|
| 334 |
+
"avg_duration_ms": round(avg_duration, 2),
|
| 335 |
+
"total_tokens_in": self.total_tokens_in,
|
| 336 |
+
"total_tokens_out": self.total_tokens_out,
|
| 337 |
+
"by_model": model_stats,
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
def export_jsonl(self, flows: List[LLMFlow]) -> str:
|
| 341 |
+
"""导出为 JSONL 格式"""
|
| 342 |
+
lines = []
|
| 343 |
+
for f in flows:
|
| 344 |
+
lines.append(json.dumps(f.to_full_dict(), ensure_ascii=False))
|
| 345 |
+
return "\n".join(lines)
|
| 346 |
+
|
| 347 |
+
def export_markdown(self, flow: LLMFlow) -> str:
|
| 348 |
+
"""导出单个 Flow 为 Markdown"""
|
| 349 |
+
lines = [
|
| 350 |
+
f"# Flow {flow.id}",
|
| 351 |
+
"",
|
| 352 |
+
f"- **Protocol**: {flow.protocol}",
|
| 353 |
+
f"- **State**: {flow.state.value}",
|
| 354 |
+
f"- **Account**: {flow.account_name or flow.account_id or 'N/A'}",
|
| 355 |
+
f"- **Created**: {datetime.fromtimestamp(flow.timing.created_at).isoformat()}",
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
if flow.timing.duration_ms:
|
| 359 |
+
lines.append(f"- **Duration**: {flow.timing.duration_ms:.0f}ms")
|
| 360 |
+
|
| 361 |
+
if flow.request:
|
| 362 |
+
lines.extend([
|
| 363 |
+
"",
|
| 364 |
+
"## Request",
|
| 365 |
+
"",
|
| 366 |
+
f"- **Model**: {flow.request.model}",
|
| 367 |
+
f"- **Stream**: {flow.request.stream}",
|
| 368 |
+
f"- **Messages**: {len(flow.request.messages)}",
|
| 369 |
+
])
|
| 370 |
+
|
| 371 |
+
if flow.request.system:
|
| 372 |
+
lines.extend(["", "### System", "", f"```\n{flow.request.system}\n```"])
|
| 373 |
+
|
| 374 |
+
lines.extend(["", "### Messages", ""])
|
| 375 |
+
for msg in flow.request.messages:
|
| 376 |
+
content = msg.content if isinstance(msg.content, str) else json.dumps(msg.content, ensure_ascii=False)
|
| 377 |
+
lines.append(f"**{msg.role}**: {content[:500]}{'...' if len(content) > 500 else ''}")
|
| 378 |
+
lines.append("")
|
| 379 |
+
|
| 380 |
+
if flow.response:
|
| 381 |
+
lines.extend([
|
| 382 |
+
"## Response",
|
| 383 |
+
"",
|
| 384 |
+
f"- **Status**: {flow.response.status_code}",
|
| 385 |
+
f"- **Stop Reason**: {flow.response.stop_reason}",
|
| 386 |
+
])
|
| 387 |
+
|
| 388 |
+
if flow.response.usage:
|
| 389 |
+
lines.append(f"- **Tokens**: {flow.response.usage.input_tokens} in / {flow.response.usage.output_tokens} out")
|
| 390 |
+
|
| 391 |
+
if flow.response.content:
|
| 392 |
+
lines.extend(["", "### Content", "", f"```\n{flow.response.content[:2000]}\n```"])
|
| 393 |
+
|
| 394 |
+
if flow.error:
|
| 395 |
+
lines.extend([
|
| 396 |
+
"",
|
| 397 |
+
"## Error",
|
| 398 |
+
"",
|
| 399 |
+
f"- **Type**: {flow.error.type}",
|
| 400 |
+
f"- **Message**: {flow.error.message}",
|
| 401 |
+
])
|
| 402 |
+
|
| 403 |
+
return "\n".join(lines)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class FlowMonitor:
|
| 407 |
+
"""Flow 监控器"""
|
| 408 |
+
|
| 409 |
+
def __init__(self, max_flows: int = 500):
|
| 410 |
+
self.store = FlowStore(max_flows=max_flows)
|
| 411 |
+
|
| 412 |
+
def create_flow(
|
| 413 |
+
self,
|
| 414 |
+
protocol: str,
|
| 415 |
+
method: str,
|
| 416 |
+
path: str,
|
| 417 |
+
headers: Dict[str, str],
|
| 418 |
+
body: Dict[str, Any],
|
| 419 |
+
account_id: Optional[str] = None,
|
| 420 |
+
account_name: Optional[str] = None,
|
| 421 |
+
) -> str:
|
| 422 |
+
"""创建新的 Flow"""
|
| 423 |
+
flow_id = uuid.uuid4().hex[:12]
|
| 424 |
+
|
| 425 |
+
# 解析请求
|
| 426 |
+
request = FlowRequest(
|
| 427 |
+
method=method,
|
| 428 |
+
path=path,
|
| 429 |
+
headers={k: v for k, v in headers.items() if k.lower() not in ["authorization"]},
|
| 430 |
+
body=body,
|
| 431 |
+
model=body.get("model", ""),
|
| 432 |
+
stream=body.get("stream", False),
|
| 433 |
+
system=body.get("system", ""),
|
| 434 |
+
tools=body.get("tools", []),
|
| 435 |
+
max_tokens=body.get("max_tokens", 0),
|
| 436 |
+
temperature=body.get("temperature", 1.0),
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# 解析消息
|
| 440 |
+
messages = body.get("messages", [])
|
| 441 |
+
for msg in messages:
|
| 442 |
+
request.messages.append(Message(
|
| 443 |
+
role=msg.get("role", "user"),
|
| 444 |
+
content=msg.get("content", ""),
|
| 445 |
+
name=msg.get("name"),
|
| 446 |
+
tool_call_id=msg.get("tool_call_id"),
|
| 447 |
+
))
|
| 448 |
+
|
| 449 |
+
flow = LLMFlow(
|
| 450 |
+
id=flow_id,
|
| 451 |
+
state=FlowState.PENDING,
|
| 452 |
+
protocol=protocol,
|
| 453 |
+
account_id=account_id,
|
| 454 |
+
account_name=account_name,
|
| 455 |
+
request=request,
|
| 456 |
+
timing=FlowTiming(created_at=time.time()),
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
self.store.add(flow)
|
| 460 |
+
return flow_id
|
| 461 |
+
|
| 462 |
+
def start_streaming(self, flow_id: str):
|
| 463 |
+
"""标记开始流式传输"""
|
| 464 |
+
flow = self.store.get(flow_id)
|
| 465 |
+
if flow:
|
| 466 |
+
flow.state = FlowState.STREAMING
|
| 467 |
+
flow.timing.first_byte_at = time.time()
|
| 468 |
+
if not flow.response:
|
| 469 |
+
flow.response = FlowResponse(status_code=200)
|
| 470 |
+
|
| 471 |
+
def add_chunk(self, flow_id: str, chunk: str):
|
| 472 |
+
"""添加流式响应块"""
|
| 473 |
+
flow = self.store.get(flow_id)
|
| 474 |
+
if flow and flow.response:
|
| 475 |
+
flow.response.chunks.append(chunk)
|
| 476 |
+
flow.response.chunk_count += 1
|
| 477 |
+
flow.response.content += chunk
|
| 478 |
+
|
| 479 |
+
def complete_flow(
|
| 480 |
+
self,
|
| 481 |
+
flow_id: str,
|
| 482 |
+
status_code: int,
|
| 483 |
+
content: str = "",
|
| 484 |
+
tool_calls: List[Dict] = None,
|
| 485 |
+
stop_reason: str = "",
|
| 486 |
+
usage: Optional[TokenUsage] = None,
|
| 487 |
+
headers: Dict[str, str] = None,
|
| 488 |
+
):
|
| 489 |
+
"""完成 Flow"""
|
| 490 |
+
flow = self.store.get(flow_id)
|
| 491 |
+
if not flow:
|
| 492 |
+
return
|
| 493 |
+
|
| 494 |
+
flow.state = FlowState.COMPLETED
|
| 495 |
+
flow.timing.completed_at = time.time()
|
| 496 |
+
|
| 497 |
+
if not flow.response:
|
| 498 |
+
flow.response = FlowResponse(status_code=status_code)
|
| 499 |
+
|
| 500 |
+
flow.response.status_code = status_code
|
| 501 |
+
flow.response.content = content or flow.response.content
|
| 502 |
+
flow.response.tool_calls = tool_calls or []
|
| 503 |
+
flow.response.stop_reason = stop_reason
|
| 504 |
+
flow.response.headers = headers or {}
|
| 505 |
+
|
| 506 |
+
if usage:
|
| 507 |
+
flow.response.usage = usage
|
| 508 |
+
self.store.total_tokens_in += usage.input_tokens
|
| 509 |
+
self.store.total_tokens_out += usage.output_tokens
|
| 510 |
+
|
| 511 |
+
def fail_flow(self, flow_id: str, error_type: str, message: str, status_code: int = 0, raw: str = ""):
|
| 512 |
+
"""标记 Flow 失败"""
|
| 513 |
+
flow = self.store.get(flow_id)
|
| 514 |
+
if not flow:
|
| 515 |
+
return
|
| 516 |
+
|
| 517 |
+
flow.state = FlowState.ERROR
|
| 518 |
+
flow.timing.completed_at = time.time()
|
| 519 |
+
flow.error = FlowError(
|
| 520 |
+
type=error_type,
|
| 521 |
+
message=message,
|
| 522 |
+
status_code=status_code,
|
| 523 |
+
raw=raw[:1000], # 限制长度
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def bookmark_flow(self, flow_id: str, bookmarked: bool = True):
|
| 527 |
+
"""书签 Flow"""
|
| 528 |
+
flow = self.store.get(flow_id)
|
| 529 |
+
if flow:
|
| 530 |
+
flow.bookmarked = bookmarked
|
| 531 |
+
|
| 532 |
+
def add_note(self, flow_id: str, note: str):
|
| 533 |
+
"""添加备注"""
|
| 534 |
+
flow = self.store.get(flow_id)
|
| 535 |
+
if flow:
|
| 536 |
+
flow.notes = note
|
| 537 |
+
|
| 538 |
+
def add_tag(self, flow_id: str, tag: str):
|
| 539 |
+
"""添加标签"""
|
| 540 |
+
flow = self.store.get(flow_id)
|
| 541 |
+
if flow and tag not in flow.tags:
|
| 542 |
+
flow.tags.append(tag)
|
| 543 |
+
|
| 544 |
+
def get_flow(self, flow_id: str) -> Optional[LLMFlow]:
|
| 545 |
+
"""获取 Flow"""
|
| 546 |
+
return self.store.get(flow_id)
|
| 547 |
+
|
| 548 |
+
def query(self, **kwargs) -> List[LLMFlow]:
|
| 549 |
+
"""查询 Flows"""
|
| 550 |
+
return self.store.query(**kwargs)
|
| 551 |
+
|
| 552 |
+
def get_stats(self) -> dict:
|
| 553 |
+
"""获取统计"""
|
| 554 |
+
return self.store.get_stats()
|
| 555 |
+
|
| 556 |
+
def export(self, flow_ids: List[str] = None, format: str = "jsonl") -> str:
|
| 557 |
+
"""导出 Flows"""
|
| 558 |
+
if flow_ids:
|
| 559 |
+
flows = [self.store.get(fid) for fid in flow_ids if self.store.get(fid)]
|
| 560 |
+
else:
|
| 561 |
+
flows = list(self.store.flows)
|
| 562 |
+
|
| 563 |
+
if format == "jsonl":
|
| 564 |
+
return self.store.export_jsonl(flows)
|
| 565 |
+
elif format == "markdown" and len(flows) == 1:
|
| 566 |
+
return self.store.export_markdown(flows[0])
|
| 567 |
+
else:
|
| 568 |
+
return json.dumps([f.to_dict() for f in flows], ensure_ascii=False, indent=2)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# 全局实例
|
| 572 |
+
flow_monitor = FlowMonitor(max_flows=500)
|
KiroProxy/kiro_proxy/core/history_manager.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""历史消息管理器 - 错误触发压缩版
|
| 2 |
+
|
| 3 |
+
自动化管理对话历史长度,收到超限错误时智能压缩而非强硬截断:
|
| 4 |
+
1. 无预检测 - 不再依赖阈值,正常发送请求
|
| 5 |
+
2. 错误触发 - 收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后自动压缩
|
| 6 |
+
3. 智能压缩 - 保留最近消息 + 摘要早期对话,目标 20K-50K 字符
|
| 7 |
+
4. 自动重试 - 压缩后自动重试请求
|
| 8 |
+
"""
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
from typing import List, Dict, Any, Tuple, Optional, Callable
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from enum import Enum
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class SummaryCacheEntry:
|
| 19 |
+
summary: str
|
| 20 |
+
old_history_hash: str
|
| 21 |
+
updated_at: float
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SummaryCache:
|
| 25 |
+
"""摘要缓存"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, max_entries: int = 64):
|
| 28 |
+
self._entries: "OrderedDict[str, SummaryCacheEntry]" = OrderedDict()
|
| 29 |
+
self._max_entries = max_entries
|
| 30 |
+
|
| 31 |
+
def get(self, key: str, old_history_hash: str, max_age: int = 300) -> Optional[str]:
|
| 32 |
+
entry = self._entries.get(key)
|
| 33 |
+
if not entry:
|
| 34 |
+
return None
|
| 35 |
+
if time.time() - entry.updated_at > max_age:
|
| 36 |
+
self._entries.pop(key, None)
|
| 37 |
+
return None
|
| 38 |
+
if entry.old_history_hash != old_history_hash:
|
| 39 |
+
return None
|
| 40 |
+
self._entries.move_to_end(key)
|
| 41 |
+
return entry.summary
|
| 42 |
+
|
| 43 |
+
def set(self, key: str, summary: str, old_history_hash: str):
|
| 44 |
+
self._entries[key] = SummaryCacheEntry(
|
| 45 |
+
summary=summary,
|
| 46 |
+
old_history_hash=old_history_hash,
|
| 47 |
+
updated_at=time.time()
|
| 48 |
+
)
|
| 49 |
+
self._entries.move_to_end(key)
|
| 50 |
+
if len(self._entries) > self._max_entries:
|
| 51 |
+
self._entries.popitem(last=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class CompressionCacheEntry:
|
| 56 |
+
"""压缩结果缓存条目"""
|
| 57 |
+
compressed_history: List[dict]
|
| 58 |
+
original_hash: str
|
| 59 |
+
compressed_chars: int
|
| 60 |
+
updated_at: float
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CompressionCache:
|
| 64 |
+
"""全局压缩结果缓存
|
| 65 |
+
|
| 66 |
+
解决 Claude Code CLI 反复压缩问题:
|
| 67 |
+
- 客户端每次请求都发送完整原始历史
|
| 68 |
+
- 缓存压缩结果,避免对相同内容重复压缩
|
| 69 |
+
- 基于原始历史的 hash 匹配
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, max_entries: int = 32, max_age: int = 600):
|
| 73 |
+
self._entries: "OrderedDict[str, CompressionCacheEntry]" = OrderedDict()
|
| 74 |
+
self._max_entries = max_entries
|
| 75 |
+
self._max_age = max_age # 缓存有效期(秒),默认 10 分钟
|
| 76 |
+
|
| 77 |
+
def get(self, original_hash: str) -> Optional[List[dict]]:
|
| 78 |
+
"""获取缓存的压缩结果"""
|
| 79 |
+
entry = self._entries.get(original_hash)
|
| 80 |
+
if not entry:
|
| 81 |
+
return None
|
| 82 |
+
if time.time() - entry.updated_at > self._max_age:
|
| 83 |
+
self._entries.pop(original_hash, None)
|
| 84 |
+
return None
|
| 85 |
+
self._entries.move_to_end(original_hash)
|
| 86 |
+
print(f"[CompressionCache] 命中缓存,跳过重复压缩 (原始 hash: {original_hash[:16]}...)")
|
| 87 |
+
return entry.compressed_history
|
| 88 |
+
|
| 89 |
+
def set(self, original_hash: str, compressed_history: List[dict], compressed_chars: int):
|
| 90 |
+
"""缓存压缩结果"""
|
| 91 |
+
self._entries[original_hash] = CompressionCacheEntry(
|
| 92 |
+
compressed_history=compressed_history,
|
| 93 |
+
original_hash=original_hash,
|
| 94 |
+
compressed_chars=compressed_chars,
|
| 95 |
+
updated_at=time.time()
|
| 96 |
+
)
|
| 97 |
+
self._entries.move_to_end(original_hash)
|
| 98 |
+
if len(self._entries) > self._max_entries:
|
| 99 |
+
self._entries.popitem(last=False)
|
| 100 |
+
print(f"[CompressionCache] 缓存压缩结果 (原始 hash: {original_hash[:16]}..., 压缩后: {compressed_chars} 字符)")
|
| 101 |
+
|
| 102 |
+
def clear(self):
|
| 103 |
+
"""清空缓存"""
|
| 104 |
+
self._entries.clear()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# 全局压缩缓存实例
|
| 108 |
+
_compression_cache = CompressionCache()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TruncateStrategy(str, Enum):
|
| 112 |
+
"""压缩策略(保留用于兼容)"""
|
| 113 |
+
NONE = "none"
|
| 114 |
+
AUTO_TRUNCATE = "auto_truncate"
|
| 115 |
+
SMART_SUMMARY = "smart_summary"
|
| 116 |
+
ERROR_RETRY = "error_retry"
|
| 117 |
+
PRE_ESTIMATE = "pre_estimate"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# 自动管理的常量(不再使用阈值触发,仅在错误后压缩)
|
| 121 |
+
# AUTO_COMPRESS_THRESHOLD 已废弃,不再用于预检测
|
| 122 |
+
SAFE_CHAR_LIMIT = 35000 # 压缩后的目标字符数 (20K-50K 范围的中间值)
|
| 123 |
+
SAFE_CHAR_LIMIT_MIN = 20000 # 压缩目标下限
|
| 124 |
+
SAFE_CHAR_LIMIT_MAX = 50000 # 压缩目标上限
|
| 125 |
+
MIN_KEEP_MESSAGES = 6 # 最少保留的最近消息数
|
| 126 |
+
MAX_KEEP_MESSAGES = 20 # 最多保留的最近消息数
|
| 127 |
+
SUMMARY_MAX_LENGTH = 3000 # 摘要最大长度
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@dataclass
|
| 131 |
+
class HistoryConfig:
|
| 132 |
+
"""历史消息配置(简化版,大部分参数自动管理)"""
|
| 133 |
+
# 启用的策略
|
| 134 |
+
strategies: List[TruncateStrategy] = field(default_factory=lambda: [TruncateStrategy.ERROR_RETRY])
|
| 135 |
+
|
| 136 |
+
# 以下参数保留用于兼容,但实际使用自动值
|
| 137 |
+
max_messages: int = 30
|
| 138 |
+
max_chars: int = 150000
|
| 139 |
+
summary_keep_recent: int = 10
|
| 140 |
+
summary_threshold: int = 100000
|
| 141 |
+
summary_max_length: int = 2000
|
| 142 |
+
retry_max_messages: int = 20
|
| 143 |
+
max_retries: int = 3
|
| 144 |
+
estimate_threshold: int = 180000
|
| 145 |
+
chars_per_token: float = 3.0
|
| 146 |
+
summary_cache_enabled: bool = True
|
| 147 |
+
summary_cache_min_delta_messages: int = 3
|
| 148 |
+
summary_cache_min_delta_chars: int = 4000
|
| 149 |
+
summary_cache_max_age_seconds: int = 300
|
| 150 |
+
add_warning_header: bool = True
|
| 151 |
+
|
| 152 |
+
def to_dict(self) -> dict:
|
| 153 |
+
return {
|
| 154 |
+
"strategies": [s.value for s in self.strategies],
|
| 155 |
+
"max_messages": self.max_messages,
|
| 156 |
+
"max_chars": self.max_chars,
|
| 157 |
+
"summary_keep_recent": self.summary_keep_recent,
|
| 158 |
+
"summary_threshold": self.summary_threshold,
|
| 159 |
+
"summary_max_length": self.summary_max_length,
|
| 160 |
+
"retry_max_messages": self.retry_max_messages,
|
| 161 |
+
"max_retries": self.max_retries,
|
| 162 |
+
"estimate_threshold": self.estimate_threshold,
|
| 163 |
+
"chars_per_token": self.chars_per_token,
|
| 164 |
+
"summary_cache_enabled": self.summary_cache_enabled,
|
| 165 |
+
"summary_cache_min_delta_messages": self.summary_cache_min_delta_messages,
|
| 166 |
+
"summary_cache_min_delta_chars": self.summary_cache_min_delta_chars,
|
| 167 |
+
"summary_cache_max_age_seconds": self.summary_cache_max_age_seconds,
|
| 168 |
+
"add_warning_header": self.add_warning_header,
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def from_dict(cls, data: dict) -> "HistoryConfig":
|
| 173 |
+
strategies = [TruncateStrategy(s) for s in data.get("strategies", ["error_retry"])]
|
| 174 |
+
return cls(
|
| 175 |
+
strategies=strategies,
|
| 176 |
+
max_messages=data.get("max_messages", 30),
|
| 177 |
+
max_chars=data.get("max_chars", 150000),
|
| 178 |
+
summary_keep_recent=data.get("summary_keep_recent", 10),
|
| 179 |
+
summary_threshold=data.get("summary_threshold", 100000),
|
| 180 |
+
summary_max_length=data.get("summary_max_length", 2000),
|
| 181 |
+
retry_max_messages=data.get("retry_max_messages", 20),
|
| 182 |
+
max_retries=data.get("max_retries", 3),
|
| 183 |
+
estimate_threshold=data.get("estimate_threshold", 180000),
|
| 184 |
+
chars_per_token=data.get("chars_per_token", 3.0),
|
| 185 |
+
summary_cache_enabled=data.get("summary_cache_enabled", True),
|
| 186 |
+
summary_cache_min_delta_messages=data.get("summary_cache_min_delta_messages", 3),
|
| 187 |
+
summary_cache_min_delta_chars=data.get("summary_cache_min_delta_chars", 4000),
|
| 188 |
+
summary_cache_max_age_seconds=data.get("summary_cache_max_age_seconds", 300),
|
| 189 |
+
add_warning_header=data.get("add_warning_header", True),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
_summary_cache = SummaryCache()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class HistoryManager:
|
| 197 |
+
"""历史消息管理器 - 错误触发压缩版
|
| 198 |
+
|
| 199 |
+
不再依赖阈值预检测,仅在收到上下文超限错误后触发压缩。
|
| 200 |
+
压缩目标为 20K-50K 字符范围。
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
def __init__(self, config: HistoryConfig = None, cache_key: Optional[str] = None):
|
| 204 |
+
self.config = config or HistoryConfig()
|
| 205 |
+
self._truncated = False
|
| 206 |
+
self._truncate_info = ""
|
| 207 |
+
self.cache_key = cache_key
|
| 208 |
+
self._retry_count = 0
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def was_truncated(self) -> bool:
|
| 212 |
+
return self._truncated
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def truncate_info(self) -> str:
|
| 216 |
+
return self._truncate_info
|
| 217 |
+
|
| 218 |
+
def reset(self):
|
| 219 |
+
self._truncated = False
|
| 220 |
+
self._truncate_info = ""
|
| 221 |
+
|
| 222 |
+
def set_cache_key(self, key: Optional[str]):
|
| 223 |
+
self.cache_key = key
|
| 224 |
+
|
| 225 |
+
def _hash_history(self, history: List[dict]) -> str:
|
| 226 |
+
"""生成历史消息的简单哈希"""
|
| 227 |
+
return f"{len(history)}:{len(json.dumps(history, ensure_ascii=False))}"
|
| 228 |
+
|
| 229 |
+
def estimate_tokens(self, text: str) -> int:
|
| 230 |
+
return int(len(text) / self.config.chars_per_token)
|
| 231 |
+
|
| 232 |
+
def estimate_history_size(self, history: List[dict]) -> Tuple[int, int]:
|
| 233 |
+
char_count = len(json.dumps(history, ensure_ascii=False))
|
| 234 |
+
return len(history), char_count
|
| 235 |
+
|
| 236 |
+
def estimate_request_chars(self, history: List[dict], user_content: str = "") -> Tuple[int, int, int]:
|
| 237 |
+
history_chars = len(json.dumps(history, ensure_ascii=False))
|
| 238 |
+
user_chars = len(user_content or "")
|
| 239 |
+
return history_chars, user_chars, history_chars + user_chars
|
| 240 |
+
|
| 241 |
+
def _extract_text(self, content) -> str:
|
| 242 |
+
if isinstance(content, str):
|
| 243 |
+
return content
|
| 244 |
+
if isinstance(content, list):
|
| 245 |
+
texts = []
|
| 246 |
+
for item in content:
|
| 247 |
+
if isinstance(item, dict) and item.get("type") == "text":
|
| 248 |
+
texts.append(item.get("text", ""))
|
| 249 |
+
elif isinstance(item, str):
|
| 250 |
+
texts.append(item)
|
| 251 |
+
return "\n".join(texts)
|
| 252 |
+
if isinstance(content, dict):
|
| 253 |
+
return content.get("text", "") or content.get("content", "")
|
| 254 |
+
return str(content) if content else ""
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _format_for_summary(self, history: List[dict]) -> str:
|
| 258 |
+
"""格式化历史消息用于生成摘要"""
|
| 259 |
+
lines = []
|
| 260 |
+
for msg in history:
|
| 261 |
+
role = "unknown"
|
| 262 |
+
content = ""
|
| 263 |
+
if "userInputMessage" in msg:
|
| 264 |
+
role = "user"
|
| 265 |
+
content = msg.get("userInputMessage", {}).get("content", "")
|
| 266 |
+
elif "assistantResponseMessage" in msg:
|
| 267 |
+
role = "assistant"
|
| 268 |
+
content = msg.get("assistantResponseMessage", {}).get("content", "")
|
| 269 |
+
else:
|
| 270 |
+
role = msg.get("role", "unknown")
|
| 271 |
+
content = self._extract_text(msg.get("content", ""))
|
| 272 |
+
# 截断过长的单条消息
|
| 273 |
+
if len(content) > 800:
|
| 274 |
+
content = content[:800] + "..."
|
| 275 |
+
lines.append(f"[{role}]: {content}")
|
| 276 |
+
return "\n".join(lines)
|
| 277 |
+
|
| 278 |
+
def _calculate_keep_count(self, history: List[dict], target_chars: int) -> int:
|
| 279 |
+
"""计算应该保留多少条最近消息"""
|
| 280 |
+
if not history:
|
| 281 |
+
return 0
|
| 282 |
+
|
| 283 |
+
# 从后往前累计,找到合适的保留数量
|
| 284 |
+
total = 0
|
| 285 |
+
count = 0
|
| 286 |
+
for msg in reversed(history):
|
| 287 |
+
msg_chars = len(json.dumps(msg, ensure_ascii=False))
|
| 288 |
+
if total + msg_chars > target_chars and count >= MIN_KEEP_MESSAGES:
|
| 289 |
+
break
|
| 290 |
+
total += msg_chars
|
| 291 |
+
count += 1
|
| 292 |
+
if count >= MAX_KEEP_MESSAGES:
|
| 293 |
+
break
|
| 294 |
+
|
| 295 |
+
return max(MIN_KEEP_MESSAGES, min(count, len(history) - 1))
|
| 296 |
+
|
| 297 |
+
def _build_compressed_history(
|
| 298 |
+
self,
|
| 299 |
+
summary: str,
|
| 300 |
+
recent_history: List[dict],
|
| 301 |
+
label: str = ""
|
| 302 |
+
) -> List[dict]:
|
| 303 |
+
"""构建压缩后的历史(摘要 + 最近消息)"""
|
| 304 |
+
# 确保 recent_history 以 user 消息开头
|
| 305 |
+
if recent_history and "assistantResponseMessage" in recent_history[0]:
|
| 306 |
+
recent_history = recent_history[1:]
|
| 307 |
+
|
| 308 |
+
# 清理孤立的 toolResults
|
| 309 |
+
tool_use_ids = set()
|
| 310 |
+
for msg in recent_history:
|
| 311 |
+
if "assistantResponseMessage" in msg:
|
| 312 |
+
for tu in msg["assistantResponseMessage"].get("toolUses", []) or []:
|
| 313 |
+
if tu.get("toolUseId"):
|
| 314 |
+
tool_use_ids.add(tu["toolUseId"])
|
| 315 |
+
|
| 316 |
+
# 清理第一条 user 消息的 toolResults(因为前面没有对应的 toolUse)
|
| 317 |
+
if recent_history and "userInputMessage" in recent_history[0]:
|
| 318 |
+
recent_history[0]["userInputMessage"].pop("userInputMessageContext", None)
|
| 319 |
+
|
| 320 |
+
# 过滤其他消息中孤立的 toolResults
|
| 321 |
+
if tool_use_ids:
|
| 322 |
+
for msg in recent_history:
|
| 323 |
+
if "userInputMessage" in msg:
|
| 324 |
+
ctx = msg.get("userInputMessage", {}).get("userInputMessageContext", {})
|
| 325 |
+
results = ctx.get("toolResults")
|
| 326 |
+
if results:
|
| 327 |
+
filtered = [r for r in results if r.get("toolUseId") in tool_use_ids]
|
| 328 |
+
if filtered:
|
| 329 |
+
ctx["toolResults"] = filtered
|
| 330 |
+
else:
|
| 331 |
+
ctx.pop("toolResults", None)
|
| 332 |
+
if not ctx:
|
| 333 |
+
msg["userInputMessage"].pop("userInputMessageContext", None)
|
| 334 |
+
else:
|
| 335 |
+
for msg in recent_history:
|
| 336 |
+
if "userInputMessage" in msg:
|
| 337 |
+
msg["userInputMessage"].pop("userInputMessageContext", None)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# 获取 model_id
|
| 341 |
+
model_id = "claude-sonnet-4"
|
| 342 |
+
for msg in reversed(recent_history):
|
| 343 |
+
if "userInputMessage" in msg:
|
| 344 |
+
model_id = msg["userInputMessage"].get("modelId", model_id)
|
| 345 |
+
break
|
| 346 |
+
if "assistantResponseMessage" in msg:
|
| 347 |
+
model_id = msg["assistantResponseMessage"].get("modelId", model_id)
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
# 检测消息格式
|
| 351 |
+
is_kiro_format = any("userInputMessage" in h or "assistantResponseMessage" in h for h in recent_history)
|
| 352 |
+
|
| 353 |
+
if is_kiro_format:
|
| 354 |
+
result = [
|
| 355 |
+
{
|
| 356 |
+
"userInputMessage": {
|
| 357 |
+
"content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]",
|
| 358 |
+
"modelId": model_id,
|
| 359 |
+
"origin": "AI_EDITOR",
|
| 360 |
+
}
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"assistantResponseMessage": {
|
| 364 |
+
"content": "I understand the context from the summary. Let's continue."
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
]
|
| 368 |
+
else:
|
| 369 |
+
result = [
|
| 370 |
+
{"role": "user", "content": f"[Earlier conversation summary]\n{summary}\n\n[Continuing from recent context...]"},
|
| 371 |
+
{"role": "assistant", "content": "I understand the context from the summary. Let's continue."}
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
result.extend(recent_history)
|
| 375 |
+
|
| 376 |
+
if label:
|
| 377 |
+
print(f"[HistoryManager] {label}: {len(recent_history)} recent + summary")
|
| 378 |
+
|
| 379 |
+
return result
|
| 380 |
+
|
| 381 |
+
async def _generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]:
|
| 382 |
+
"""生成历史消息摘要"""
|
| 383 |
+
if not history or not api_caller:
|
| 384 |
+
return None
|
| 385 |
+
|
| 386 |
+
formatted = self._format_for_summary(history)
|
| 387 |
+
if len(formatted) > 15000:
|
| 388 |
+
formatted = formatted[:15000] + "\n...(truncated)"
|
| 389 |
+
|
| 390 |
+
prompt = f"""请简洁总结以下对话的关键信息:
|
| 391 |
+
1. 用户的主要目标
|
| 392 |
+
2. 已完成的重要操作和决策
|
| 393 |
+
3. 当前工作状态和关键上下文
|
| 394 |
+
|
| 395 |
+
对话历史:
|
| 396 |
+
{formatted}
|
| 397 |
+
|
| 398 |
+
请用中文输出摘要,控制在 {SUMMARY_MAX_LENGTH} 字符以内,重点保留对后续对话有用的信息:"""
|
| 399 |
+
|
| 400 |
+
try:
|
| 401 |
+
summary = await api_caller(prompt)
|
| 402 |
+
if summary and len(summary) > SUMMARY_MAX_LENGTH:
|
| 403 |
+
summary = summary[:SUMMARY_MAX_LENGTH] + "..."
|
| 404 |
+
return summary
|
| 405 |
+
except Exception as e:
|
| 406 |
+
print(f"[HistoryManager] 生成摘要失败: {e}")
|
| 407 |
+
return None
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
async def smart_compress(
|
| 411 |
+
self,
|
| 412 |
+
history: List[dict],
|
| 413 |
+
api_caller: Callable,
|
| 414 |
+
target_chars: int = SAFE_CHAR_LIMIT,
|
| 415 |
+
retry_level: int = 0
|
| 416 |
+
) -> List[dict]:
|
| 417 |
+
"""智能压缩历史消息
|
| 418 |
+
|
| 419 |
+
核心逻辑:保留最近消息 + 摘要早期对话
|
| 420 |
+
压缩目标为 20K-50K 字符范围
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
history: 历史消息
|
| 424 |
+
api_caller: 用于生成摘要的 API 调用函数
|
| 425 |
+
target_chars: 目标字符数 (默认 35K,范围 20K-50K)
|
| 426 |
+
retry_level: 重试级别(越高保留越少)
|
| 427 |
+
"""
|
| 428 |
+
if not history:
|
| 429 |
+
return history
|
| 430 |
+
|
| 431 |
+
current_chars = len(json.dumps(history, ensure_ascii=False))
|
| 432 |
+
|
| 433 |
+
# 确保目标在 20K-50K 范围内
|
| 434 |
+
target_chars = max(SAFE_CHAR_LIMIT_MIN, min(target_chars, SAFE_CHAR_LIMIT_MAX))
|
| 435 |
+
|
| 436 |
+
# 如果已经在目标范围内,不需要压缩
|
| 437 |
+
if current_chars <= target_chars:
|
| 438 |
+
return history
|
| 439 |
+
|
| 440 |
+
# 根据重试级别调整保留数量
|
| 441 |
+
adjusted_target = int(target_chars * (0.85 ** retry_level))
|
| 442 |
+
adjusted_target = max(SAFE_CHAR_LIMIT_MIN, adjusted_target) # 确保不低于下限
|
| 443 |
+
|
| 444 |
+
keep_count = self._calculate_keep_count(history, adjusted_target)
|
| 445 |
+
|
| 446 |
+
# 确保至少保留一些消息用于摘要
|
| 447 |
+
if keep_count >= len(history):
|
| 448 |
+
keep_count = max(MIN_KEEP_MESSAGES, len(history) - 2)
|
| 449 |
+
|
| 450 |
+
old_history = history[:-keep_count] if keep_count < len(history) else []
|
| 451 |
+
recent_history = history[-keep_count:] if keep_count > 0 else history
|
| 452 |
+
|
| 453 |
+
if not old_history:
|
| 454 |
+
# 没有可摘要的历史,直接返回
|
| 455 |
+
return recent_history
|
| 456 |
+
|
| 457 |
+
# 尝试从缓存获取摘要
|
| 458 |
+
cache_key = f"{self.cache_key}:{keep_count}" if self.cache_key else None
|
| 459 |
+
old_hash = self._hash_history(old_history)
|
| 460 |
+
|
| 461 |
+
cached_summary = None
|
| 462 |
+
if cache_key and self.config.summary_cache_enabled:
|
| 463 |
+
cached_summary = _summary_cache.get(cache_key, old_hash, self.config.summary_cache_max_age_seconds)
|
| 464 |
+
|
| 465 |
+
if cached_summary:
|
| 466 |
+
result = self._build_compressed_history(cached_summary, recent_history, "压缩(缓存)")
|
| 467 |
+
result_chars = len(json.dumps(result, ensure_ascii=False))
|
| 468 |
+
self._truncated = True
|
| 469 |
+
self._truncate_info = f"智能压缩(缓存): {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符"
|
| 470 |
+
return result
|
| 471 |
+
|
| 472 |
+
# 生成新摘要
|
| 473 |
+
summary = await self._generate_summary(old_history, api_caller)
|
| 474 |
+
|
| 475 |
+
if summary:
|
| 476 |
+
if cache_key and self.config.summary_cache_enabled:
|
| 477 |
+
_summary_cache.set(cache_key, summary, old_hash)
|
| 478 |
+
|
| 479 |
+
result = self._build_compressed_history(summary, recent_history, "智能压缩")
|
| 480 |
+
result_chars = len(json.dumps(result, ensure_ascii=False))
|
| 481 |
+
self._truncated = True
|
| 482 |
+
self._truncate_info = f"智能压缩: {len(history)} -> {len(result)} 条消息, {current_chars} -> {result_chars} 字符 (摘要 {len(summary)} 字符)"
|
| 483 |
+
return result
|
| 484 |
+
|
| 485 |
+
# 摘要失败,回退到简单截断
|
| 486 |
+
self._truncated = True
|
| 487 |
+
result_chars = len(json.dumps(recent_history, ensure_ascii=False))
|
| 488 |
+
self._truncate_info = f"摘要失败,保留最近 {len(recent_history)} 条消息, {current_chars} -> {result_chars} 字符"
|
| 489 |
+
return recent_history
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def needs_compression(self, history: List[dict], user_content: str = "") -> bool:
|
| 493 |
+
"""检查是否需要压缩
|
| 494 |
+
|
| 495 |
+
注意:此方法现在始终返回 False,不再基于阈值预检测。
|
| 496 |
+
压缩仅在收到上下文超限错误后触发。
|
| 497 |
+
保留此方法是为了兼容旧 API。
|
| 498 |
+
"""
|
| 499 |
+
# 不再基于阈��预检测,始终返回 False
|
| 500 |
+
# 压缩将在收到 CONTENT_LENGTH_EXCEEDS_THRESHOLD 错误后触发
|
| 501 |
+
return False
|
| 502 |
+
|
| 503 |
+
async def pre_process_async(
|
| 504 |
+
self,
|
| 505 |
+
history: List[dict],
|
| 506 |
+
user_content: str = "",
|
| 507 |
+
api_caller: Callable = None
|
| 508 |
+
) -> List[dict]:
|
| 509 |
+
"""预处理历史消息
|
| 510 |
+
|
| 511 |
+
注意:不再进行发送前自动压缩。
|
| 512 |
+
压缩仅在收到上下文超限错误后触发。
|
| 513 |
+
"""
|
| 514 |
+
self.reset()
|
| 515 |
+
|
| 516 |
+
if not history:
|
| 517 |
+
return history
|
| 518 |
+
|
| 519 |
+
# 不再进行预压缩,直接返回原始历史
|
| 520 |
+
# 压缩将在收到错误后由 handle_length_error_async 处理
|
| 521 |
+
return history
|
| 522 |
+
|
| 523 |
+
def pre_process(self, history: List[dict], user_content: str = "") -> List[dict]:
|
| 524 |
+
"""预处理历史消息(同步版本)
|
| 525 |
+
|
| 526 |
+
注意:不再进行发送前自动压缩。
|
| 527 |
+
压缩仅在收到上下文超限错误后触发。
|
| 528 |
+
"""
|
| 529 |
+
self.reset()
|
| 530 |
+
|
| 531 |
+
if not history:
|
| 532 |
+
return history
|
| 533 |
+
|
| 534 |
+
# 不再进行预压缩,直接返回原始历史
|
| 535 |
+
return history
|
| 536 |
+
|
| 537 |
+
async def handle_length_error_async(
|
| 538 |
+
self,
|
| 539 |
+
history: List[dict],
|
| 540 |
+
retry_count: int = 0,
|
| 541 |
+
api_caller: Optional[Callable] = None
|
| 542 |
+
) -> Tuple[List[dict], bool]:
|
| 543 |
+
"""处理长度超限错误(智能压缩后重试)
|
| 544 |
+
|
| 545 |
+
这是唯一触发压缩的入口点。当收到上下文超限错误时调用此方法。
|
| 546 |
+
压缩目标为 20K-50K 字符范围。
|
| 547 |
+
|
| 548 |
+
防止无限循环:
|
| 549 |
+
- 追踪压缩状态,避免重复压缩相同内容
|
| 550 |
+
- 压缩前检查大小,如果已经很小则不再压缩
|
| 551 |
+
- 达到最大重试次数后返回清晰错误
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
history: 历史消息
|
| 555 |
+
retry_count: 当前重试次数
|
| 556 |
+
api_caller: API 调用函数
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
(compressed_history, should_retry)
|
| 560 |
+
"""
|
| 561 |
+
max_retries = self.config.max_retries
|
| 562 |
+
|
| 563 |
+
if retry_count >= max_retries:
|
| 564 |
+
print(f"[HistoryManager] 已达最大重试次数 ({max_retries}),建议清空对话")
|
| 565 |
+
self._truncate_info = f"已达最大压缩次数 ({max_retries}),请清空对话或减少消息数量"
|
| 566 |
+
return history, False
|
| 567 |
+
|
| 568 |
+
if not history:
|
| 569 |
+
return history, False
|
| 570 |
+
|
| 571 |
+
self.reset()
|
| 572 |
+
|
| 573 |
+
current_chars = len(json.dumps(history, ensure_ascii=False))
|
| 574 |
+
current_hash = self._hash_history(history)
|
| 575 |
+
|
| 576 |
+
print(f"[HistoryManager] 收到上下文超限错误,当前大小: {current_chars} 字符")
|
| 577 |
+
|
| 578 |
+
# 优先检查全局压缩缓存(解决 Claude Code CLI 反复压缩问题)
|
| 579 |
+
cached_result = _compression_cache.get(current_hash)
|
| 580 |
+
if cached_result is not None:
|
| 581 |
+
cached_chars = len(json.dumps(cached_result, ensure_ascii=False))
|
| 582 |
+
self._truncated = True
|
| 583 |
+
self._truncate_info = f"使用缓存的压缩结果: {len(history)} -> {len(cached_result)} 条消息, {current_chars} -> {cached_chars} 字符"
|
| 584 |
+
print(f"[HistoryManager] {self._truncate_info}")
|
| 585 |
+
return cached_result, True
|
| 586 |
+
|
| 587 |
+
print(f"[HistoryManager] 开始压缩...")
|
| 588 |
+
|
| 589 |
+
# 防止无限循环:检查是否已经压缩过相同内容(实例级缓存)
|
| 590 |
+
instance_cache_key = f"compression:{current_hash}:{retry_count}"
|
| 591 |
+
if hasattr(self, '_instance_compression_cache') and instance_cache_key in self._instance_compression_cache:
|
| 592 |
+
print(f"[HistoryManager] 检测到重复压缩请求,跳过")
|
| 593 |
+
self._truncate_info = "内容已压缩到最小,无法继续压缩,请清空对话"
|
| 594 |
+
return history, False
|
| 595 |
+
|
| 596 |
+
# 初始化实例级压缩缓存
|
| 597 |
+
if not hasattr(self, '_instance_compression_cache'):
|
| 598 |
+
self._instance_compression_cache = {}
|
| 599 |
+
|
| 600 |
+
# 根据重试次数计算目标大小 (20K-50K 范围)
|
| 601 |
+
# 第一次重试: 目标 35K (中间值)
|
| 602 |
+
# 第二次重试: 目标 25K
|
| 603 |
+
# 第三次重试: 目标 20K (下限)
|
| 604 |
+
if retry_count == 0:
|
| 605 |
+
target_chars = SAFE_CHAR_LIMIT # 35K
|
| 606 |
+
elif retry_count == 1:
|
| 607 |
+
target_chars = 25000
|
| 608 |
+
else:
|
| 609 |
+
target_chars = SAFE_CHAR_LIMIT_MIN # 20K
|
| 610 |
+
|
| 611 |
+
# 防止无限循环:如果当前大小已经小于目标,不再压缩
|
| 612 |
+
if current_chars <= target_chars:
|
| 613 |
+
print(f"[HistoryManager] 当前大小 ({current_chars}) 已小于目标 ({target_chars}),无法继续压缩")
|
| 614 |
+
self._truncate_info = f"内容已压缩到 {current_chars} 字符,仍然超限,请清空对话"
|
| 615 |
+
return history, False
|
| 616 |
+
|
| 617 |
+
print(f"[HistoryManager] 第 {retry_count + 1} 次重试,目标压缩到 {target_chars} 字符")
|
| 618 |
+
|
| 619 |
+
if api_caller:
|
| 620 |
+
compressed = await self.smart_compress(
|
| 621 |
+
history, api_caller,
|
| 622 |
+
target_chars=target_chars,
|
| 623 |
+
retry_level=retry_count
|
| 624 |
+
)
|
| 625 |
+
compressed_chars = len(json.dumps(compressed, ensure_ascii=False))
|
| 626 |
+
|
| 627 |
+
# 防止无限循环:检查压缩是否有效
|
| 628 |
+
if compressed_chars >= current_chars * 0.95: # 压缩效果不足 5%
|
| 629 |
+
print(f"[HistoryManager] 压缩效果不足,无法继续压缩")
|
| 630 |
+
self._truncate_info = f"压缩效果不足,请清空对话或减少消息数量"
|
| 631 |
+
return history, False
|
| 632 |
+
|
| 633 |
+
# 防止无限循环:检查压缩后是否仍然过大
|
| 634 |
+
if compressed_chars > 50000 and retry_count >= max_retries - 1:
|
| 635 |
+
print(f"[HistoryManager] 压缩后仍然过大 ({compressed_chars}),建议清空对话")
|
| 636 |
+
self._truncate_info = f"压缩后仍有 {compressed_chars} 字符,请清空对话"
|
| 637 |
+
return compressed, False
|
| 638 |
+
|
| 639 |
+
if len(compressed) < len(history):
|
| 640 |
+
# 保存到全局压缩缓存(解决 Claude Code CLI 反复压缩问题)
|
| 641 |
+
_compression_cache.set(current_hash, compressed, compressed_chars)
|
| 642 |
+
|
| 643 |
+
# 记录实例级压缩缓存(防止同一请求内的重复压缩)
|
| 644 |
+
self._instance_compression_cache[instance_cache_key] = True
|
| 645 |
+
# 清理旧缓存(保留最近 10 条)
|
| 646 |
+
if len(self._instance_compression_cache) > 10:
|
| 647 |
+
oldest_key = next(iter(self._instance_compression_cache))
|
| 648 |
+
del self._instance_compression_cache[oldest_key]
|
| 649 |
+
|
| 650 |
+
self._truncated = True
|
| 651 |
+
self._truncate_info = f"错误后压缩 (第 {retry_count + 1} 次): {len(history)} -> {len(compressed)} 条消息, {current_chars} -> {compressed_chars} 字符"
|
| 652 |
+
print(f"[HistoryManager] {self._truncate_info}")
|
| 653 |
+
return compressed, True
|
| 654 |
+
else:
|
| 655 |
+
# 无 api_caller,简单截断
|
| 656 |
+
keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * (0.5 ** (retry_count + 1))))
|
| 657 |
+
if keep_count < len(history):
|
| 658 |
+
truncated = history[-keep_count:]
|
| 659 |
+
self._truncated = True
|
| 660 |
+
truncated_chars = len(json.dumps(truncated, ensure_ascii=False))
|
| 661 |
+
|
| 662 |
+
# 防止无限循环:检查截断是否有效
|
| 663 |
+
if truncated_chars >= current_chars * 0.95:
|
| 664 |
+
print(f"[HistoryManager] 截断效果不足,无法继续压缩")
|
| 665 |
+
self._truncate_info = f"截断效果不足,请清空对话"
|
| 666 |
+
return history, False
|
| 667 |
+
|
| 668 |
+
self._truncate_info = f"错误后截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息, {current_chars} -> {truncated_chars} 字符"
|
| 669 |
+
print(f"[HistoryManager] {self._truncate_info}")
|
| 670 |
+
return truncated, True
|
| 671 |
+
|
| 672 |
+
return history, False
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def handle_length_error(self, history: List[dict], retry_count: int = 0) -> Tuple[List[dict], bool]:
|
| 676 |
+
"""处理长度超限错误(同步版本,简单截断)"""
|
| 677 |
+
max_retries = self.config.max_retries
|
| 678 |
+
|
| 679 |
+
if retry_count >= max_retries:
|
| 680 |
+
return history, False
|
| 681 |
+
|
| 682 |
+
if not history:
|
| 683 |
+
return history, False
|
| 684 |
+
|
| 685 |
+
self.reset()
|
| 686 |
+
|
| 687 |
+
# 根据重试次数逐步减少
|
| 688 |
+
keep_ratio = 0.5 ** (retry_count + 1)
|
| 689 |
+
keep_count = max(MIN_KEEP_MESSAGES, int(len(history) * keep_ratio))
|
| 690 |
+
|
| 691 |
+
if keep_count < len(history):
|
| 692 |
+
truncated = history[-keep_count:]
|
| 693 |
+
self._truncated = True
|
| 694 |
+
self._truncate_info = f"错误重试截断 (第 {retry_count + 1} 次): {len(history)} -> {len(truncated)} 条消息"
|
| 695 |
+
return truncated, True
|
| 696 |
+
|
| 697 |
+
return history, False
|
| 698 |
+
|
| 699 |
+
def get_warning_header(self) -> Optional[str]:
|
| 700 |
+
if not self.config.add_warning_header or not self._truncated:
|
| 701 |
+
return None
|
| 702 |
+
return self._truncate_info
|
| 703 |
+
|
| 704 |
+
# ========== 兼容旧 API ==========
|
| 705 |
+
|
| 706 |
+
def truncate_by_count(self, history: List[dict], max_count: int) -> List[dict]:
|
| 707 |
+
"""按消息数量截断(兼容)"""
|
| 708 |
+
if len(history) <= max_count:
|
| 709 |
+
return history
|
| 710 |
+
original_count = len(history)
|
| 711 |
+
truncated = history[-max_count:]
|
| 712 |
+
self._truncated = True
|
| 713 |
+
self._truncate_info = f"按数量截断: {original_count} -> {len(truncated)} 条消息"
|
| 714 |
+
return truncated
|
| 715 |
+
|
| 716 |
+
def truncate_by_chars(self, history: List[dict], max_chars: int) -> List[dict]:
|
| 717 |
+
"""按字符数截断(兼容)"""
|
| 718 |
+
total_chars = len(json.dumps(history, ensure_ascii=False))
|
| 719 |
+
if total_chars <= max_chars:
|
| 720 |
+
return history
|
| 721 |
+
|
| 722 |
+
original_count = len(history)
|
| 723 |
+
result = []
|
| 724 |
+
current_chars = 0
|
| 725 |
+
|
| 726 |
+
for msg in reversed(history):
|
| 727 |
+
msg_chars = len(json.dumps(msg, ensure_ascii=False))
|
| 728 |
+
if current_chars + msg_chars > max_chars and result:
|
| 729 |
+
break
|
| 730 |
+
result.insert(0, msg)
|
| 731 |
+
current_chars += msg_chars
|
| 732 |
+
|
| 733 |
+
if len(result) < original_count:
|
| 734 |
+
self._truncated = True
|
| 735 |
+
self._truncate_info = f"按字符数截断: {original_count} -> {len(result)} 条消息"
|
| 736 |
+
|
| 737 |
+
return result
|
| 738 |
+
|
| 739 |
+
def should_pre_truncate(self, history: List[dict], user_content: str) -> bool:
|
| 740 |
+
"""兼容旧 API"""
|
| 741 |
+
return self.needs_compression(history, user_content)
|
| 742 |
+
|
| 743 |
+
def should_summarize(self, history: List[dict]) -> bool:
|
| 744 |
+
"""兼容旧 API"""
|
| 745 |
+
return self.needs_compression(history)
|
| 746 |
+
|
| 747 |
+
def should_smart_summarize(self, history: List[dict]) -> bool:
|
| 748 |
+
"""兼容旧 API"""
|
| 749 |
+
return self.needs_compression(history)
|
| 750 |
+
|
| 751 |
+
def should_auto_truncate_summarize(self, history: List[dict]) -> bool:
|
| 752 |
+
"""兼容旧 API"""
|
| 753 |
+
return self.needs_compression(history)
|
| 754 |
+
|
| 755 |
+
def should_pre_summary_for_error_retry(self, history: List[dict], user_content: str = "") -> bool:
|
| 756 |
+
"""兼容旧 API"""
|
| 757 |
+
return self.needs_compression(history, user_content)
|
| 758 |
+
|
| 759 |
+
async def compress_with_summary(self, history: List[dict], api_caller: Callable) -> List[dict]:
|
| 760 |
+
"""兼容旧 API"""
|
| 761 |
+
return await self.smart_compress(history, api_caller)
|
| 762 |
+
|
| 763 |
+
async def compress_before_auto_truncate(self, history: List[dict], api_caller: Callable) -> List[dict]:
|
| 764 |
+
"""兼容旧 API"""
|
| 765 |
+
return await self.smart_compress(history, api_caller)
|
| 766 |
+
|
| 767 |
+
async def generate_summary(self, history: List[dict], api_caller: Callable) -> Optional[str]:
|
| 768 |
+
"""兼容旧 API"""
|
| 769 |
+
return await self._generate_summary(history, api_caller)
|
| 770 |
+
|
| 771 |
+
def summarize_history_structure(self, history: List[dict], max_items: int = 12) -> str:
|
| 772 |
+
"""生成历史结构摘要(调试用)"""
|
| 773 |
+
if not history:
|
| 774 |
+
return "len=0"
|
| 775 |
+
|
| 776 |
+
def entry_kind(msg):
|
| 777 |
+
if "userInputMessage" in msg:
|
| 778 |
+
return "U"
|
| 779 |
+
if "assistantResponseMessage" in msg:
|
| 780 |
+
return "A"
|
| 781 |
+
role = msg.get("role")
|
| 782 |
+
return "U" if role == "user" else ("A" if role == "assistant" else "?")
|
| 783 |
+
|
| 784 |
+
kinds = [entry_kind(msg) for msg in history]
|
| 785 |
+
if len(kinds) <= max_items:
|
| 786 |
+
seq = "".join(kinds)
|
| 787 |
+
else:
|
| 788 |
+
head = max_items // 2
|
| 789 |
+
tail = max_items - head
|
| 790 |
+
seq = f"{''.join(kinds[:head])}...{''.join(kinds[-tail:])}"
|
| 791 |
+
|
| 792 |
+
return f"len={len(history)} seq={seq}"
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
# ========== 全局配置 ==========
|
| 797 |
+
|
| 798 |
+
_history_config = HistoryConfig()
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
def get_history_config() -> HistoryConfig:
|
| 802 |
+
"""获取历史消息配置"""
|
| 803 |
+
return _history_config
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def set_history_config(config: HistoryConfig):
|
| 807 |
+
"""设置历史消息配置"""
|
| 808 |
+
global _history_config
|
| 809 |
+
_history_config = config
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def update_history_config(data: dict):
|
| 813 |
+
"""更新历史消息配置"""
|
| 814 |
+
global _history_config
|
| 815 |
+
_history_config = HistoryConfig.from_dict(data)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def is_content_length_error(status_code: int, error_text: str) -> bool:
|
| 819 |
+
"""检查是否为内容长度超限错误"""
|
| 820 |
+
if "CONTENT_LENGTH_EXCEEDS_THRESHOLD" in error_text:
|
| 821 |
+
return True
|
| 822 |
+
if "Input is too long" in error_text:
|
| 823 |
+
return True
|
| 824 |
+
lowered = error_text.lower()
|
| 825 |
+
if "too long" in lowered and ("input" in lowered or "content" in lowered or "message" in lowered):
|
| 826 |
+
return True
|
| 827 |
+
if "context length" in lowered or "token limit" in lowered:
|
| 828 |
+
return True
|
| 829 |
+
return False
|
KiroProxy/kiro_proxy/core/kiro_api.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Kiro Web Portal API 调用模块
|
| 2 |
+
|
| 3 |
+
调用 Kiro 的 Web Portal API 获取用户信息,使用 CBOR 编码。
|
| 4 |
+
参考: chaogei/Kiro-account-manager
|
| 5 |
+
"""
|
| 6 |
+
import uuid
|
| 7 |
+
import httpx
|
| 8 |
+
from typing import Optional, Tuple, Any, Dict
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import cbor2
|
| 12 |
+
HAS_CBOR = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
HAS_CBOR = False
|
| 15 |
+
print("[KiroAPI] 警告: cbor2 未安装,部分功能不可用。请运行: pip install cbor2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Kiro Web Portal API 基础 URL
|
| 19 |
+
KIRO_API_BASE = "https://app.kiro.dev/service/KiroWebPortalService/operation"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def kiro_api_request(
|
| 23 |
+
operation: str,
|
| 24 |
+
body: Dict[str, Any],
|
| 25 |
+
access_token: str,
|
| 26 |
+
idp: str = "Google",
|
| 27 |
+
) -> Tuple[bool, Any]:
|
| 28 |
+
"""
|
| 29 |
+
调用 Kiro Web Portal API
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
operation: API 操作名称,如 "GetUserUsageAndLimits"
|
| 33 |
+
body: 请求体(会被 CBOR 编码)
|
| 34 |
+
access_token: Bearer token
|
| 35 |
+
idp: 身份提供商 ("Google" 或 "Github")
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
(success, response_data or error_dict)
|
| 39 |
+
"""
|
| 40 |
+
if not HAS_CBOR:
|
| 41 |
+
return False, {"error": "cbor2 未安装"}
|
| 42 |
+
|
| 43 |
+
if not access_token:
|
| 44 |
+
return False, {"error": "缺少 access token"}
|
| 45 |
+
|
| 46 |
+
url = f"{KIRO_API_BASE}/{operation}"
|
| 47 |
+
|
| 48 |
+
# CBOR 编码请求体
|
| 49 |
+
try:
|
| 50 |
+
encoded_body = cbor2.dumps(body)
|
| 51 |
+
except Exception as e:
|
| 52 |
+
return False, {"error": f"CBOR 编码失败: {e}"}
|
| 53 |
+
|
| 54 |
+
headers = {
|
| 55 |
+
"accept": "application/cbor",
|
| 56 |
+
"content-type": "application/cbor",
|
| 57 |
+
"smithy-protocol": "rpc-v2-cbor",
|
| 58 |
+
"amz-sdk-invocation-id": str(uuid.uuid4()),
|
| 59 |
+
"amz-sdk-request": "attempt=1; max=1",
|
| 60 |
+
"x-amz-user-agent": "aws-sdk-js/1.0.0 kiro-proxy/1.0.0",
|
| 61 |
+
"authorization": f"Bearer {access_token}",
|
| 62 |
+
"cookie": f"Idp={idp}; AccessToken={access_token}",
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
async with httpx.AsyncClient(timeout=15, verify=False) as client:
|
| 67 |
+
response = await client.post(url, content=encoded_body, headers=headers)
|
| 68 |
+
|
| 69 |
+
if response.status_code != 200:
|
| 70 |
+
return False, {"error": f"API 请求失败: {response.status_code}"}
|
| 71 |
+
|
| 72 |
+
# CBOR 解码响应
|
| 73 |
+
try:
|
| 74 |
+
data = cbor2.loads(response.content)
|
| 75 |
+
return True, data
|
| 76 |
+
except Exception as e:
|
| 77 |
+
return False, {"error": f"CBOR 解码失败: {e}"}
|
| 78 |
+
|
| 79 |
+
except httpx.TimeoutException:
|
| 80 |
+
return False, {"error": "请求超时"}
|
| 81 |
+
except Exception as e:
|
| 82 |
+
return False, {"error": f"请求失败: {str(e)}"}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
async def get_user_info(
|
| 86 |
+
access_token: str,
|
| 87 |
+
idp: str = "Google",
|
| 88 |
+
) -> Tuple[bool, Dict[str, Any]]:
|
| 89 |
+
"""
|
| 90 |
+
获取用户信息(包括邮箱)
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
access_token: Bearer token
|
| 94 |
+
idp: 身份提供商 ("Google" 或 "Github")
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
(success, user_info or error_dict)
|
| 98 |
+
user_info 包含: email, userId 等
|
| 99 |
+
"""
|
| 100 |
+
success, result = await kiro_api_request(
|
| 101 |
+
operation="GetUserUsageAndLimits",
|
| 102 |
+
body={"isEmailRequired": True, "origin": "KIRO_IDE"},
|
| 103 |
+
access_token=access_token,
|
| 104 |
+
idp=idp,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
if not success:
|
| 108 |
+
return False, result
|
| 109 |
+
|
| 110 |
+
# 提取用户信息
|
| 111 |
+
user_info = result.get("userInfo", {})
|
| 112 |
+
return True, {
|
| 113 |
+
"email": user_info.get("email"),
|
| 114 |
+
"userId": user_info.get("userId"),
|
| 115 |
+
"raw": result,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
async def get_user_email(
|
| 120 |
+
access_token: str,
|
| 121 |
+
provider: str = "Google",
|
| 122 |
+
) -> Optional[str]:
|
| 123 |
+
"""
|
| 124 |
+
获取用户邮箱地址
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
access_token: Bearer token
|
| 128 |
+
provider: 登录提供商 ("Google" 或 "Github")
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
邮箱地址,失败返回 None
|
| 132 |
+
"""
|
| 133 |
+
# 标准化 provider 名称
|
| 134 |
+
idp = provider
|
| 135 |
+
if provider and provider.lower() == "google":
|
| 136 |
+
idp = "Google"
|
| 137 |
+
elif provider and provider.lower() == "github":
|
| 138 |
+
idp = "Github"
|
| 139 |
+
|
| 140 |
+
success, result = await get_user_info(access_token, idp)
|
| 141 |
+
|
| 142 |
+
if success:
|
| 143 |
+
return result.get("email")
|
| 144 |
+
|
| 145 |
+
print(f"[KiroAPI] 获取邮箱失败: {result.get('error', '未知错误')}")
|
| 146 |
+
return None
|
KiroProxy/kiro_proxy/core/persistence.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""配置持久化"""
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
# 统一使用 config.py 中的 DATA_DIR
|
| 7 |
+
from ..config import DATA_DIR
|
| 8 |
+
|
| 9 |
+
# 配置文件路径
|
| 10 |
+
CONFIG_DIR = DATA_DIR
|
| 11 |
+
CONFIG_FILE = CONFIG_DIR / "config.json"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def ensure_config_dir():
|
| 15 |
+
"""确保配置目录存在"""
|
| 16 |
+
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def save_accounts(accounts: List[Dict[str, Any]]) -> bool:
|
| 20 |
+
"""保存账号配置"""
|
| 21 |
+
try:
|
| 22 |
+
ensure_config_dir()
|
| 23 |
+
config = load_config()
|
| 24 |
+
config["accounts"] = accounts
|
| 25 |
+
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
| 26 |
+
json.dump(config, f, indent=2, ensure_ascii=False)
|
| 27 |
+
return True
|
| 28 |
+
except Exception as e:
|
| 29 |
+
print(f"[Persistence] 保存配置失败: {e}")
|
| 30 |
+
return False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_accounts() -> List[Dict[str, Any]]:
|
| 34 |
+
"""加载账号配置"""
|
| 35 |
+
config = load_config()
|
| 36 |
+
return config.get("accounts", [])
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_config() -> Dict[str, Any]:
|
| 40 |
+
"""加载完整配置"""
|
| 41 |
+
try:
|
| 42 |
+
if CONFIG_FILE.exists():
|
| 43 |
+
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
|
| 44 |
+
return json.load(f)
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"[Persistence] 加载配置失败: {e}")
|
| 47 |
+
return {}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def save_config(config: Dict[str, Any]) -> bool:
|
| 51 |
+
"""保存完整配置"""
|
| 52 |
+
try:
|
| 53 |
+
ensure_config_dir()
|
| 54 |
+
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
| 55 |
+
json.dump(config, f, indent=2, ensure_ascii=False)
|
| 56 |
+
return True
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"[Persistence] 保存配置失败: {e}")
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def export_config() -> Dict[str, Any]:
|
| 63 |
+
"""导出配置(用于备份)"""
|
| 64 |
+
return load_config()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def import_config(config: Dict[str, Any]) -> bool:
|
| 68 |
+
"""导入配置(用于恢复)"""
|
| 69 |
+
return save_config(config)
|
KiroProxy/kiro_proxy/core/protocol_handler.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""自定义协议处理器
|
| 2 |
+
|
| 3 |
+
在 Windows 上注册 kiro:// 协议,用于处理 OAuth 回调。
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import asyncio
|
| 8 |
+
import threading
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, Callable
|
| 11 |
+
from http.server import HTTPServer, BaseHTTPRequestHandler
|
| 12 |
+
from urllib.parse import urlparse, parse_qs, urlencode
|
| 13 |
+
import socket
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# 回调服务器端口
|
| 17 |
+
CALLBACK_PORT = 19823
|
| 18 |
+
|
| 19 |
+
# 全局回调结果
|
| 20 |
+
_callback_result = None
|
| 21 |
+
_callback_event = None
|
| 22 |
+
_callback_server = None
|
| 23 |
+
_server_thread = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CallbackHandler(BaseHTTPRequestHandler):
|
| 27 |
+
"""处理 OAuth 回调的 HTTP 请求处理器"""
|
| 28 |
+
|
| 29 |
+
def log_message(self, format, *args):
|
| 30 |
+
"""禁用日志输出"""
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
def do_GET(self):
|
| 34 |
+
global _callback_result, _callback_event
|
| 35 |
+
|
| 36 |
+
# 解析 URL
|
| 37 |
+
parsed = urlparse(self.path)
|
| 38 |
+
params = parse_qs(parsed.query)
|
| 39 |
+
|
| 40 |
+
# 检查是否是回调路径
|
| 41 |
+
if parsed.path == '/kiro-callback' or parsed.path == '/' or 'code' in params:
|
| 42 |
+
code = params.get('code', [None])[0]
|
| 43 |
+
state = params.get('state', [None])[0]
|
| 44 |
+
error = params.get('error', [None])[0]
|
| 45 |
+
|
| 46 |
+
print(f"[ProtocolHandler] 收到回调: code={code[:20] if code else None}..., state={state}, error={error}")
|
| 47 |
+
|
| 48 |
+
if error:
|
| 49 |
+
_callback_result = {"error": error}
|
| 50 |
+
elif code and state:
|
| 51 |
+
_callback_result = {"code": code, "state": state}
|
| 52 |
+
else:
|
| 53 |
+
_callback_result = {"error": "缺少授权码"}
|
| 54 |
+
|
| 55 |
+
# 触发事件
|
| 56 |
+
if _callback_event:
|
| 57 |
+
_callback_event.set()
|
| 58 |
+
|
| 59 |
+
# 返回成功页面
|
| 60 |
+
self.send_response(200)
|
| 61 |
+
self.send_header('Content-type', 'text/html; charset=utf-8')
|
| 62 |
+
self.end_headers()
|
| 63 |
+
|
| 64 |
+
html = """
|
| 65 |
+
<!DOCTYPE html>
|
| 66 |
+
<html>
|
| 67 |
+
<head>
|
| 68 |
+
<meta charset="utf-8">
|
| 69 |
+
<title>登录成功</title>
|
| 70 |
+
<style>
|
| 71 |
+
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 72 |
+
display: flex; justify-content: center; align-items: center; height: 100vh;
|
| 73 |
+
margin: 0; background: #1a1a2e; color: #fff; }
|
| 74 |
+
.container { text-align: center; padding: 2rem; }
|
| 75 |
+
h1 { color: #4ade80; margin-bottom: 1rem; }
|
| 76 |
+
p { color: #9ca3af; }
|
| 77 |
+
</style>
|
| 78 |
+
</head>
|
| 79 |
+
<body>
|
| 80 |
+
<div class="container">
|
| 81 |
+
<h1>✅ 登录成功</h1>
|
| 82 |
+
<p>您可以关闭此窗口并返回 Kiro Proxy</p>
|
| 83 |
+
<script>setTimeout(function(){window.close();}, 3000);</script>
|
| 84 |
+
</div>
|
| 85 |
+
</body>
|
| 86 |
+
</html>
|
| 87 |
+
"""
|
| 88 |
+
self.wfile.write(html.encode('utf-8'))
|
| 89 |
+
else:
|
| 90 |
+
self.send_response(404)
|
| 91 |
+
self.end_headers()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def is_port_available(port: int) -> bool:
|
| 95 |
+
"""检查端口是否可用"""
|
| 96 |
+
try:
|
| 97 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 98 |
+
s.bind(('127.0.0.1', port))
|
| 99 |
+
return True
|
| 100 |
+
except OSError:
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def start_callback_server() -> tuple:
|
| 105 |
+
"""启动回调服务器
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
(success, port or error)
|
| 109 |
+
"""
|
| 110 |
+
global _callback_server, _callback_result, _callback_event, _server_thread
|
| 111 |
+
|
| 112 |
+
# 如果服务器已经在运行,直接返回成功
|
| 113 |
+
if _callback_server is not None and _server_thread is not None and _server_thread.is_alive():
|
| 114 |
+
print(f"[ProtocolHandler] 回调服务器已在运行: http://127.0.0.1:{CALLBACK_PORT}")
|
| 115 |
+
return True, CALLBACK_PORT
|
| 116 |
+
|
| 117 |
+
_callback_result = None
|
| 118 |
+
_callback_event = threading.Event()
|
| 119 |
+
|
| 120 |
+
# 检查端口
|
| 121 |
+
if not is_port_available(CALLBACK_PORT):
|
| 122 |
+
# 端口被占用,可能是之前的服务器还在运行
|
| 123 |
+
print(f"[ProtocolHandler] 端口 {CALLBACK_PORT} 已被占用,尝试复用")
|
| 124 |
+
return True, CALLBACK_PORT
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
_callback_server = HTTPServer(('127.0.0.1', CALLBACK_PORT), CallbackHandler)
|
| 128 |
+
|
| 129 |
+
# 在后台线程运行服务器
|
| 130 |
+
_server_thread = threading.Thread(target=_callback_server.serve_forever, daemon=True)
|
| 131 |
+
_server_thread.start()
|
| 132 |
+
|
| 133 |
+
print(f"[ProtocolHandler] 回调服务器已启动: http://127.0.0.1:{CALLBACK_PORT}")
|
| 134 |
+
return True, CALLBACK_PORT
|
| 135 |
+
except Exception as e:
|
| 136 |
+
return False, str(e)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def stop_callback_server():
|
| 140 |
+
"""停止回调服务器"""
|
| 141 |
+
global _callback_server, _server_thread
|
| 142 |
+
|
| 143 |
+
if _callback_server:
|
| 144 |
+
try:
|
| 145 |
+
_callback_server.shutdown()
|
| 146 |
+
except:
|
| 147 |
+
pass
|
| 148 |
+
_callback_server = None
|
| 149 |
+
_server_thread = None
|
| 150 |
+
print("[ProtocolHandler] 回调服务���已停止")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def wait_for_callback(timeout: int = 300) -> tuple:
|
| 154 |
+
"""等待回调
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
timeout: 超时时间(秒)
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
(success, result or error)
|
| 161 |
+
"""
|
| 162 |
+
global _callback_result, _callback_event
|
| 163 |
+
|
| 164 |
+
if _callback_event is None:
|
| 165 |
+
return False, {"error": "回调服务器未启动"}
|
| 166 |
+
|
| 167 |
+
# 等待回调
|
| 168 |
+
if _callback_event.wait(timeout=timeout):
|
| 169 |
+
if _callback_result and "code" in _callback_result:
|
| 170 |
+
return True, _callback_result
|
| 171 |
+
elif _callback_result and "error" in _callback_result:
|
| 172 |
+
return False, _callback_result
|
| 173 |
+
else:
|
| 174 |
+
return False, {"error": "未收到有效回调"}
|
| 175 |
+
else:
|
| 176 |
+
return False, {"error": "等待回调超时"}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_callback_result() -> Optional[dict]:
|
| 180 |
+
"""获取回调结果(非阻塞)"""
|
| 181 |
+
global _callback_result
|
| 182 |
+
return _callback_result
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def clear_callback_result():
|
| 186 |
+
"""清除回调结果"""
|
| 187 |
+
global _callback_result, _callback_event
|
| 188 |
+
_callback_result = None
|
| 189 |
+
if _callback_event:
|
| 190 |
+
_callback_event.clear()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Windows 协议注册
|
| 194 |
+
def register_protocol_windows() -> tuple:
|
| 195 |
+
"""在 Windows 上注册 kiro:// 协议
|
| 196 |
+
|
| 197 |
+
注册后,当浏览器重定向到 kiro:// URL 时,Windows 会调用我们的脚本,
|
| 198 |
+
脚本将参数重定向到本地 HTTP 服务器。
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
(success, message)
|
| 202 |
+
"""
|
| 203 |
+
if sys.platform != 'win32':
|
| 204 |
+
return False, "仅支持 Windows"
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
import winreg
|
| 208 |
+
|
| 209 |
+
# 获取当前 Python 解释器路径
|
| 210 |
+
python_exe = sys.executable
|
| 211 |
+
|
| 212 |
+
# 创建一个处理脚本
|
| 213 |
+
script_dir = Path.home() / ".kiro-proxy"
|
| 214 |
+
script_dir.mkdir(parents=True, exist_ok=True)
|
| 215 |
+
script_path = script_dir / "protocol_redirect.pyw"
|
| 216 |
+
|
| 217 |
+
# 写入重定向脚本 (.pyw 不显示控制台窗口)
|
| 218 |
+
script_content = f'''# -*- coding: utf-8 -*-
|
| 219 |
+
# Kiro Protocol Redirect Script
|
| 220 |
+
import sys
|
| 221 |
+
import webbrowser
|
| 222 |
+
from urllib.parse import urlparse, parse_qs, urlencode
|
| 223 |
+
|
| 224 |
+
if len(sys.argv) > 1:
|
| 225 |
+
url = sys.argv[1]
|
| 226 |
+
|
| 227 |
+
# 解析 kiro:// URL
|
| 228 |
+
# 格式: kiro://kiro.kiroAgent/authenticate-success?code=xxx&state=xxx
|
| 229 |
+
if url.startswith('kiro://'):
|
| 230 |
+
# 提取查询参数
|
| 231 |
+
query_start = url.find('?')
|
| 232 |
+
if query_start > -1:
|
| 233 |
+
query_string = url[query_start + 1:]
|
| 234 |
+
# 重定向到本地 HTTP 服务器
|
| 235 |
+
redirect_url = "http://127.0.0.1:{CALLBACK_PORT}/kiro-callback?" + query_string
|
| 236 |
+
webbrowser.open(redirect_url)
|
| 237 |
+
'''
|
| 238 |
+
script_path.write_text(script_content, encoding='utf-8')
|
| 239 |
+
|
| 240 |
+
# 获取 pythonw.exe 路径(无控制台窗口)
|
| 241 |
+
python_dir = Path(python_exe).parent
|
| 242 |
+
pythonw_exe = python_dir / "pythonw.exe"
|
| 243 |
+
if not pythonw_exe.exists():
|
| 244 |
+
pythonw_exe = python_exe # 降级使用 python.exe
|
| 245 |
+
|
| 246 |
+
# 注册协议
|
| 247 |
+
key_path = r"SOFTWARE\\Classes\\kiro"
|
| 248 |
+
|
| 249 |
+
# 创建主键
|
| 250 |
+
key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path)
|
| 251 |
+
winreg.SetValue(key, "", winreg.REG_SZ, "URL:Kiro Protocol")
|
| 252 |
+
winreg.SetValueEx(key, "URL Protocol", 0, winreg.REG_SZ, "")
|
| 253 |
+
winreg.CloseKey(key)
|
| 254 |
+
|
| 255 |
+
# 创建 DefaultIcon 键
|
| 256 |
+
icon_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\DefaultIcon")
|
| 257 |
+
winreg.SetValue(icon_key, "", winreg.REG_SZ, f"{python_exe},0")
|
| 258 |
+
winreg.CloseKey(icon_key)
|
| 259 |
+
|
| 260 |
+
# 创建 shell\\open\\command 键
|
| 261 |
+
cmd_key = winreg.CreateKey(winreg.HKEY_CURRENT_USER, key_path + r"\\shell\\open\\command")
|
| 262 |
+
cmd = f'"{pythonw_exe}" "{script_path}" "%1"'
|
| 263 |
+
winreg.SetValue(cmd_key, "", winreg.REG_SZ, cmd)
|
| 264 |
+
winreg.CloseKey(cmd_key)
|
| 265 |
+
|
| 266 |
+
print(f"[ProtocolHandler] 已注册 kiro:// 协议")
|
| 267 |
+
print(f"[ProtocolHandler] 脚本路径: {script_path}")
|
| 268 |
+
print(f"[ProtocolHandler] 命令: {cmd}")
|
| 269 |
+
return True, "协议注册成功"
|
| 270 |
+
|
| 271 |
+
except Exception as e:
|
| 272 |
+
import traceback
|
| 273 |
+
traceback.print_exc()
|
| 274 |
+
return False, f"注册失败: {e}"
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def unregister_protocol_windows() -> tuple:
|
| 278 |
+
"""取消注册 kiro:// 协议"""
|
| 279 |
+
if sys.platform != 'win32':
|
| 280 |
+
return False, "仅支持 Windows"
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
import winreg
|
| 284 |
+
|
| 285 |
+
def delete_key_recursive(key, subkey):
|
| 286 |
+
try:
|
| 287 |
+
open_key = winreg.OpenKey(key, subkey, 0, winreg.KEY_ALL_ACCESS)
|
| 288 |
+
info = winreg.QueryInfoKey(open_key)
|
| 289 |
+
for i in range(info[0]):
|
| 290 |
+
child = winreg.EnumKey(open_key, 0)
|
| 291 |
+
delete_key_recursive(open_key, child)
|
| 292 |
+
winreg.CloseKey(open_key)
|
| 293 |
+
winreg.DeleteKey(key, subkey)
|
| 294 |
+
except WindowsError:
|
| 295 |
+
pass
|
| 296 |
+
|
| 297 |
+
delete_key_recursive(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro")
|
| 298 |
+
|
| 299 |
+
print("[ProtocolHandler] 已取消注册 kiro:// 协议")
|
| 300 |
+
return True, "协议取消注册成功"
|
| 301 |
+
|
| 302 |
+
except Exception as e:
|
| 303 |
+
return False, f"取消注册失败: {e}"
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def is_protocol_registered() -> bool:
|
| 307 |
+
"""检查 kiro:// 协议是否已注册"""
|
| 308 |
+
if sys.platform != 'win32':
|
| 309 |
+
return False
|
| 310 |
+
|
| 311 |
+
try:
|
| 312 |
+
import winreg
|
| 313 |
+
key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"SOFTWARE\\Classes\\kiro")
|
| 314 |
+
winreg.CloseKey(key)
|
| 315 |
+
return True
|
| 316 |
+
except WindowsError:
|
| 317 |
+
return False
|
| 318 |
+
|
KiroProxy/kiro_proxy/core/quota_cache.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""额度缓存管理模块
|
| 2 |
+
|
| 3 |
+
提供账号额度信息的内存缓存和文件持久化功能。
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import asyncio
|
| 8 |
+
from dataclasses import dataclass, field, asdict
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Dict, Any
|
| 12 |
+
from threading import Lock
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# 默认缓存过期时间(秒)
|
| 16 |
+
DEFAULT_CACHE_MAX_AGE = 300 # 5分钟
|
| 17 |
+
|
| 18 |
+
# 低余额阈值
|
| 19 |
+
LOW_BALANCE_THRESHOLD = 0.2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BalanceStatus(Enum):
|
| 23 |
+
"""额度状态枚举
|
| 24 |
+
|
| 25 |
+
用于区分账号的额度状态:
|
| 26 |
+
- NORMAL: 正常(剩余额度 > 20%)
|
| 27 |
+
- LOW: 低额度(0 < 剩余额度 <= 20%)
|
| 28 |
+
- EXHAUSTED: 无额度(剩余额度 <= 0)
|
| 29 |
+
"""
|
| 30 |
+
NORMAL = "normal" # 正常(>20%)
|
| 31 |
+
LOW = "low" # 低额度(0-20%)
|
| 32 |
+
EXHAUSTED = "exhausted" # 无额度(<=0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class CachedQuota:
|
| 37 |
+
"""缓存的额度信息"""
|
| 38 |
+
account_id: str
|
| 39 |
+
usage_limit: float = 0.0 # 总额度
|
| 40 |
+
current_usage: float = 0.0 # 已用额度
|
| 41 |
+
balance: float = 0.0 # 剩余额度
|
| 42 |
+
usage_percent: float = 0.0 # 使用百分比
|
| 43 |
+
balance_status: str = "normal" # 额度状态: normal, low, exhausted
|
| 44 |
+
is_low_balance: bool = False # 是否低额度(兼容旧字段)
|
| 45 |
+
is_exhausted: bool = False # 是否无额度
|
| 46 |
+
is_suspended: bool = False # 是否被封禁
|
| 47 |
+
subscription_title: str = "" # 订阅类型
|
| 48 |
+
free_trial_limit: float = 0.0 # 免费试用额度
|
| 49 |
+
free_trial_usage: float = 0.0 # 免费试用已用
|
| 50 |
+
bonus_limit: float = 0.0 # 奖励额度
|
| 51 |
+
bonus_usage: float = 0.0 # 奖励已用
|
| 52 |
+
updated_at: float = 0.0 # 更新时间戳
|
| 53 |
+
error: Optional[str] = None # 错误信息(如果获取失败)
|
| 54 |
+
|
| 55 |
+
# 重置和过期时间
|
| 56 |
+
next_reset_date: Optional[str] = None # 下次重置时间
|
| 57 |
+
free_trial_expiry: Optional[str] = None # 免费试用过期时间
|
| 58 |
+
bonus_expiries: list = None # 奖励过期时间列表
|
| 59 |
+
|
| 60 |
+
def __post_init__(self):
|
| 61 |
+
"""初始化后计算额度状态"""
|
| 62 |
+
self._update_balance_status()
|
| 63 |
+
|
| 64 |
+
def _update_balance_status(self) -> None:
|
| 65 |
+
"""更新额度状态"""
|
| 66 |
+
if self.error is not None:
|
| 67 |
+
# 有错误时不更新状态
|
| 68 |
+
return
|
| 69 |
+
|
| 70 |
+
if self.balance <= 0:
|
| 71 |
+
self.balance_status = BalanceStatus.EXHAUSTED.value
|
| 72 |
+
self.is_exhausted = True
|
| 73 |
+
self.is_low_balance = False
|
| 74 |
+
elif self.usage_limit > 0:
|
| 75 |
+
remaining_percent = (self.balance / self.usage_limit) * 100
|
| 76 |
+
if remaining_percent <= LOW_BALANCE_THRESHOLD * 100:
|
| 77 |
+
self.balance_status = BalanceStatus.LOW.value
|
| 78 |
+
self.is_low_balance = True
|
| 79 |
+
self.is_exhausted = False
|
| 80 |
+
else:
|
| 81 |
+
self.balance_status = BalanceStatus.NORMAL.value
|
| 82 |
+
self.is_low_balance = False
|
| 83 |
+
self.is_exhausted = False
|
| 84 |
+
else:
|
| 85 |
+
self.balance_status = BalanceStatus.NORMAL.value
|
| 86 |
+
self.is_low_balance = False
|
| 87 |
+
self.is_exhausted = False
|
| 88 |
+
|
| 89 |
+
@classmethod
|
| 90 |
+
def from_usage_info(cls, account_id: str, usage_info: 'UsageInfo') -> 'CachedQuota':
|
| 91 |
+
"""从 UsageInfo 创建 CachedQuota"""
|
| 92 |
+
usage_percent = (usage_info.current_usage / usage_info.usage_limit * 100) if usage_info.usage_limit > 0 else 0.0
|
| 93 |
+
quota = cls(
|
| 94 |
+
account_id=account_id,
|
| 95 |
+
usage_limit=usage_info.usage_limit,
|
| 96 |
+
current_usage=usage_info.current_usage,
|
| 97 |
+
balance=usage_info.balance,
|
| 98 |
+
usage_percent=round(usage_percent, 2),
|
| 99 |
+
is_low_balance=usage_info.is_low_balance,
|
| 100 |
+
subscription_title=usage_info.subscription_title,
|
| 101 |
+
free_trial_limit=usage_info.free_trial_limit,
|
| 102 |
+
free_trial_usage=usage_info.free_trial_usage,
|
| 103 |
+
bonus_limit=usage_info.bonus_limit,
|
| 104 |
+
bonus_usage=usage_info.bonus_usage,
|
| 105 |
+
updated_at=time.time(),
|
| 106 |
+
error=None,
|
| 107 |
+
next_reset_date=usage_info.next_reset_date,
|
| 108 |
+
free_trial_expiry=usage_info.free_trial_expiry,
|
| 109 |
+
bonus_expiries=usage_info.bonus_expiries or [],
|
| 110 |
+
)
|
| 111 |
+
# 重新计算状态以确保一致性
|
| 112 |
+
quota._update_balance_status()
|
| 113 |
+
return quota
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
def from_error(cls, account_id: str, error: str) -> 'CachedQuota':
|
| 117 |
+
"""创建错误状态的缓存"""
|
| 118 |
+
# 检查是否为账号封禁错误
|
| 119 |
+
is_suspended = (
|
| 120 |
+
"temporarily_suspended" in error.lower() or
|
| 121 |
+
"suspended" in error.lower() or
|
| 122 |
+
"accountsuspendedexception" in error.lower()
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
quota = cls(
|
| 126 |
+
account_id=account_id,
|
| 127 |
+
updated_at=time.time(),
|
| 128 |
+
error=error
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# 如果是封禁错误,标记为特殊状态
|
| 132 |
+
if is_suspended:
|
| 133 |
+
quota.is_suspended = True
|
| 134 |
+
|
| 135 |
+
return quota
|
| 136 |
+
|
| 137 |
+
@classmethod
|
| 138 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'CachedQuota':
|
| 139 |
+
"""从字典创建"""
|
| 140 |
+
quota = cls(
|
| 141 |
+
account_id=data.get("account_id", ""),
|
| 142 |
+
usage_limit=data.get("usage_limit", 0.0),
|
| 143 |
+
current_usage=data.get("current_usage", 0.0),
|
| 144 |
+
balance=data.get("balance", 0.0),
|
| 145 |
+
usage_percent=data.get("usage_percent", 0.0),
|
| 146 |
+
balance_status=data.get("balance_status", "normal"),
|
| 147 |
+
is_low_balance=data.get("is_low_balance", False),
|
| 148 |
+
is_exhausted=data.get("is_exhausted", False),
|
| 149 |
+
is_suspended=data.get("is_suspended", False),
|
| 150 |
+
subscription_title=data.get("subscription_title", ""),
|
| 151 |
+
free_trial_limit=data.get("free_trial_limit", 0.0),
|
| 152 |
+
free_trial_usage=data.get("free_trial_usage", 0.0),
|
| 153 |
+
bonus_limit=data.get("bonus_limit", 0.0),
|
| 154 |
+
bonus_usage=data.get("bonus_usage", 0.0),
|
| 155 |
+
updated_at=data.get("updated_at", 0.0),
|
| 156 |
+
error=data.get("error"),
|
| 157 |
+
next_reset_date=data.get("next_reset_date"),
|
| 158 |
+
free_trial_expiry=data.get("free_trial_expiry"),
|
| 159 |
+
bonus_expiries=data.get("bonus_expiries", []),
|
| 160 |
+
)
|
| 161 |
+
# 重新计算状态以确保一致性
|
| 162 |
+
quota._update_balance_status()
|
| 163 |
+
return quota
|
| 164 |
+
|
| 165 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 166 |
+
"""转换为字典"""
|
| 167 |
+
return asdict(self)
|
| 168 |
+
|
| 169 |
+
def has_error(self) -> bool:
|
| 170 |
+
"""是否有错误"""
|
| 171 |
+
return self.error is not None
|
| 172 |
+
|
| 173 |
+
def is_available(self) -> bool:
|
| 174 |
+
"""额度是否可用(未耗尽且无错误)"""
|
| 175 |
+
return not self.is_exhausted and not self.has_error()
|
| 176 |
+
|
| 177 |
+
def get_balance_status_enum(self) -> BalanceStatus:
|
| 178 |
+
"""获取额度状态枚举"""
|
| 179 |
+
try:
|
| 180 |
+
return BalanceStatus(self.balance_status)
|
| 181 |
+
except ValueError:
|
| 182 |
+
return BalanceStatus.NORMAL
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class QuotaCache:
|
| 186 |
+
"""额度缓存管理器
|
| 187 |
+
|
| 188 |
+
提供线程安全的额度缓存操作,支持内存缓存和文件持久化。
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
def __init__(self, cache_file: Optional[str] = None):
|
| 192 |
+
"""
|
| 193 |
+
初始化缓存管理器
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
cache_file: 缓存文件路径,None 则使用默认路径
|
| 197 |
+
"""
|
| 198 |
+
self._cache: Dict[str, CachedQuota] = {}
|
| 199 |
+
self._lock = Lock()
|
| 200 |
+
self._save_lock = asyncio.Lock()
|
| 201 |
+
|
| 202 |
+
# 设置缓存文件路径
|
| 203 |
+
if cache_file:
|
| 204 |
+
self._cache_file = Path(cache_file)
|
| 205 |
+
else:
|
| 206 |
+
from ..config import DATA_DIR
|
| 207 |
+
self._cache_file = DATA_DIR / "quota_cache.json"
|
| 208 |
+
|
| 209 |
+
# 启动时加载缓存
|
| 210 |
+
self.load_from_file()
|
| 211 |
+
|
| 212 |
+
def get(self, account_id: str) -> Optional[CachedQuota]:
|
| 213 |
+
"""获取账号的缓存额度
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
account_id: 账号ID
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
缓存的额度信息,不存在则返回 None
|
| 220 |
+
"""
|
| 221 |
+
with self._lock:
|
| 222 |
+
return self._cache.get(account_id)
|
| 223 |
+
|
| 224 |
+
def set(self, account_id: str, quota: CachedQuota) -> None:
|
| 225 |
+
"""设置账号的额度缓存
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
account_id: 账号ID
|
| 229 |
+
quota: 额度信息
|
| 230 |
+
"""
|
| 231 |
+
with self._lock:
|
| 232 |
+
self._cache[account_id] = quota
|
| 233 |
+
|
| 234 |
+
def is_stale(self, account_id: str, max_age_seconds: int = DEFAULT_CACHE_MAX_AGE) -> bool:
|
| 235 |
+
"""检查缓存是否过期
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
account_id: 账号ID
|
| 239 |
+
max_age_seconds: 最大缓存时间(秒)
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
True 表示缓存过期或不存在
|
| 243 |
+
"""
|
| 244 |
+
with self._lock:
|
| 245 |
+
quota = self._cache.get(account_id)
|
| 246 |
+
if quota is None:
|
| 247 |
+
return True
|
| 248 |
+
return (time.time() - quota.updated_at) > max_age_seconds
|
| 249 |
+
|
| 250 |
+
def get_all(self) -> Dict[str, CachedQuota]:
|
| 251 |
+
"""获取所有缓存
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
所有账号的额度缓存副本
|
| 255 |
+
"""
|
| 256 |
+
with self._lock:
|
| 257 |
+
return dict(self._cache)
|
| 258 |
+
|
| 259 |
+
def remove(self, account_id: str) -> None:
|
| 260 |
+
"""移除账号缓存
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
account_id: 账号ID
|
| 264 |
+
"""
|
| 265 |
+
with self._lock:
|
| 266 |
+
self._cache.pop(account_id, None)
|
| 267 |
+
|
| 268 |
+
def clear(self) -> None:
|
| 269 |
+
"""清空所有缓存"""
|
| 270 |
+
with self._lock:
|
| 271 |
+
self._cache.clear()
|
| 272 |
+
|
| 273 |
+
def load_from_file(self) -> bool:
|
| 274 |
+
"""从文件加载缓存
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
是否加载成功
|
| 278 |
+
"""
|
| 279 |
+
if not self._cache_file.exists():
|
| 280 |
+
return False
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
with open(self._cache_file, 'r', encoding='utf-8') as f:
|
| 284 |
+
data = json.load(f)
|
| 285 |
+
|
| 286 |
+
# 验证版本
|
| 287 |
+
version = data.get("version", "1.0")
|
| 288 |
+
accounts_data = data.get("accounts", {})
|
| 289 |
+
|
| 290 |
+
with self._lock:
|
| 291 |
+
self._cache.clear()
|
| 292 |
+
for account_id, quota_data in accounts_data.items():
|
| 293 |
+
quota_data["account_id"] = account_id
|
| 294 |
+
self._cache[account_id] = CachedQuota.from_dict(quota_data)
|
| 295 |
+
|
| 296 |
+
print(f"[QuotaCache] 从文件加载 {len(self._cache)} 个账号的额度缓存")
|
| 297 |
+
return True
|
| 298 |
+
|
| 299 |
+
except json.JSONDecodeError as e:
|
| 300 |
+
print(f"[QuotaCache] 缓存文件格式错误: {e}")
|
| 301 |
+
return False
|
| 302 |
+
except Exception as e:
|
| 303 |
+
print(f"[QuotaCache] 加载缓存失败: {e}")
|
| 304 |
+
return False
|
| 305 |
+
|
| 306 |
+
def save_to_file(self) -> bool:
|
| 307 |
+
"""保存缓存到文件(同步版本)
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
是否保存成功
|
| 311 |
+
"""
|
| 312 |
+
try:
|
| 313 |
+
# 确保目录存在
|
| 314 |
+
self._cache_file.parent.mkdir(parents=True, exist_ok=True)
|
| 315 |
+
|
| 316 |
+
with self._lock:
|
| 317 |
+
accounts_data = {}
|
| 318 |
+
for account_id, quota in self._cache.items():
|
| 319 |
+
quota_dict = quota.to_dict()
|
| 320 |
+
quota_dict.pop("account_id", None) # 避免重复存储
|
| 321 |
+
accounts_data[account_id] = quota_dict
|
| 322 |
+
|
| 323 |
+
data = {
|
| 324 |
+
"version": "1.0",
|
| 325 |
+
"updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
| 326 |
+
"accounts": accounts_data
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
# 写入临时文件后重命名,确保原子性
|
| 330 |
+
temp_file = self._cache_file.with_suffix('.tmp')
|
| 331 |
+
with open(temp_file, 'w', encoding='utf-8') as f:
|
| 332 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
| 333 |
+
temp_file.replace(self._cache_file)
|
| 334 |
+
|
| 335 |
+
return True
|
| 336 |
+
|
| 337 |
+
except Exception as e:
|
| 338 |
+
print(f"[QuotaCache] 保存缓存失败: {e}")
|
| 339 |
+
return False
|
| 340 |
+
|
| 341 |
+
async def save_to_file_async(self) -> bool:
|
| 342 |
+
"""异步保存缓存到文件
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
是否保存成功
|
| 346 |
+
"""
|
| 347 |
+
async with self._save_lock:
|
| 348 |
+
# 在线程池中执行同步保存
|
| 349 |
+
loop = asyncio.get_event_loop()
|
| 350 |
+
return await loop.run_in_executor(None, self.save_to_file)
|
| 351 |
+
|
| 352 |
+
def get_summary(self) -> Dict[str, Any]:
|
| 353 |
+
"""获取缓存汇总信息
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
汇总统计信息
|
| 357 |
+
"""
|
| 358 |
+
with self._lock:
|
| 359 |
+
total_balance = 0.0
|
| 360 |
+
total_usage = 0.0
|
| 361 |
+
total_limit = 0.0
|
| 362 |
+
error_count = 0
|
| 363 |
+
stale_count = 0
|
| 364 |
+
|
| 365 |
+
current_time = time.time()
|
| 366 |
+
|
| 367 |
+
for quota in self._cache.values():
|
| 368 |
+
if quota.has_error():
|
| 369 |
+
error_count += 1
|
| 370 |
+
else:
|
| 371 |
+
total_balance += quota.balance
|
| 372 |
+
total_usage += quota.current_usage
|
| 373 |
+
total_limit += quota.usage_limit
|
| 374 |
+
|
| 375 |
+
if (current_time - quota.updated_at) > DEFAULT_CACHE_MAX_AGE:
|
| 376 |
+
stale_count += 1
|
| 377 |
+
|
| 378 |
+
return {
|
| 379 |
+
"total_accounts": len(self._cache),
|
| 380 |
+
"total_balance": round(total_balance, 2),
|
| 381 |
+
"total_usage": round(total_usage, 2),
|
| 382 |
+
"total_limit": round(total_limit, 2),
|
| 383 |
+
"error_count": error_count,
|
| 384 |
+
"stale_count": stale_count
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# 全局缓存实例
|
| 389 |
+
_quota_cache: Optional[QuotaCache] = None
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def get_quota_cache() -> QuotaCache:
|
| 393 |
+
"""获取全局缓存实例"""
|
| 394 |
+
global _quota_cache
|
| 395 |
+
if _quota_cache is None:
|
| 396 |
+
_quota_cache = QuotaCache()
|
| 397 |
+
return _quota_cache
|
KiroProxy/kiro_proxy/core/quota_scheduler.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""额度更新调度器模块
|
| 2 |
+
|
| 3 |
+
实现启动时并发获取所有账号额度、定时更新活跃账号额度的功能。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
from typing import Optional, Set, Dict, List, TYPE_CHECKING
|
| 8 |
+
from threading import Lock
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from .account import Account
|
| 12 |
+
|
| 13 |
+
from .quota_cache import QuotaCache, CachedQuota, get_quota_cache
|
| 14 |
+
from .usage import get_account_usage
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# 默认更新间隔(秒)
|
| 18 |
+
DEFAULT_UPDATE_INTERVAL = 60
|
| 19 |
+
|
| 20 |
+
# 活跃账号判定时间窗口(秒)
|
| 21 |
+
# 需要覆盖一次更新周期,避免低频请求时“永远错过”定时刷新
|
| 22 |
+
ACTIVE_WINDOW_SECONDS = 120
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class QuotaScheduler:
|
| 26 |
+
"""额度更新调度器
|
| 27 |
+
|
| 28 |
+
负责启动时并发获取所有账号额度,以及定时更新活跃账号的额度。
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
quota_cache: Optional[QuotaCache] = None,
|
| 33 |
+
update_interval: int = DEFAULT_UPDATE_INTERVAL):
|
| 34 |
+
"""
|
| 35 |
+
初始化调度器
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
quota_cache: 额度缓存实例
|
| 39 |
+
update_interval: 更新间隔(秒)
|
| 40 |
+
"""
|
| 41 |
+
self.quota_cache = quota_cache or get_quota_cache()
|
| 42 |
+
self.update_interval = update_interval
|
| 43 |
+
|
| 44 |
+
self._active_accounts: Dict[str, float] = {} # account_id -> last_used_timestamp
|
| 45 |
+
self._lock = Lock()
|
| 46 |
+
self._task: Optional[asyncio.Task] = None
|
| 47 |
+
self._running = False
|
| 48 |
+
self._last_full_refresh: Optional[float] = None
|
| 49 |
+
self._accounts_getter = None # 获取账号列表的回调函数
|
| 50 |
+
|
| 51 |
+
def set_accounts_getter(self, getter):
|
| 52 |
+
"""设置获取账号列表的回调函数
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
getter: 返回账号列表的可调用对象
|
| 56 |
+
"""
|
| 57 |
+
self._accounts_getter = getter
|
| 58 |
+
|
| 59 |
+
def _get_accounts(self) -> List['Account']:
|
| 60 |
+
"""获取账号列表"""
|
| 61 |
+
if self._accounts_getter:
|
| 62 |
+
return self._accounts_getter()
|
| 63 |
+
return []
|
| 64 |
+
|
| 65 |
+
async def start(self) -> None:
|
| 66 |
+
"""启动调度器"""
|
| 67 |
+
if self._running:
|
| 68 |
+
return
|
| 69 |
+
|
| 70 |
+
self._running = True
|
| 71 |
+
print("[QuotaScheduler] 启动额度更新调度器")
|
| 72 |
+
|
| 73 |
+
# 启动时刷新所有账号额度
|
| 74 |
+
await self.refresh_all()
|
| 75 |
+
|
| 76 |
+
# 启动定时更新任务
|
| 77 |
+
self._task = asyncio.create_task(self._update_loop())
|
| 78 |
+
|
| 79 |
+
async def stop(self) -> None:
|
| 80 |
+
"""停止调度器"""
|
| 81 |
+
self._running = False
|
| 82 |
+
|
| 83 |
+
if self._task:
|
| 84 |
+
self._task.cancel()
|
| 85 |
+
try:
|
| 86 |
+
await self._task
|
| 87 |
+
except asyncio.CancelledError:
|
| 88 |
+
pass
|
| 89 |
+
self._task = None
|
| 90 |
+
|
| 91 |
+
print("[QuotaScheduler] 额度更新调度器已停止")
|
| 92 |
+
|
| 93 |
+
async def refresh_all(self) -> Dict[str, bool]:
|
| 94 |
+
"""刷新所有账号额度
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
账号ID -> 是否成功的字典
|
| 98 |
+
"""
|
| 99 |
+
accounts = self._get_accounts()
|
| 100 |
+
if not accounts:
|
| 101 |
+
print("[QuotaScheduler] 没有账号需要刷新")
|
| 102 |
+
return {}
|
| 103 |
+
|
| 104 |
+
# 刷新所有账号(包括禁用的,以便检查是否可以解禁)
|
| 105 |
+
print(f"[QuotaScheduler] 开始刷新 {len(accounts)} 个账号的额度...")
|
| 106 |
+
|
| 107 |
+
# 并发获取所有账号额度
|
| 108 |
+
tasks = [self._refresh_account_internal(acc) for acc in accounts]
|
| 109 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 110 |
+
|
| 111 |
+
# 统计结果
|
| 112 |
+
success_count = 0
|
| 113 |
+
fail_count = 0
|
| 114 |
+
result_dict = {}
|
| 115 |
+
|
| 116 |
+
for acc, result in zip(accounts, results):
|
| 117 |
+
if isinstance(result, Exception):
|
| 118 |
+
result_dict[acc.id] = False
|
| 119 |
+
fail_count += 1
|
| 120 |
+
else:
|
| 121 |
+
result_dict[acc.id] = result
|
| 122 |
+
if result:
|
| 123 |
+
success_count += 1
|
| 124 |
+
else:
|
| 125 |
+
fail_count += 1
|
| 126 |
+
|
| 127 |
+
self._last_full_refresh = time.time()
|
| 128 |
+
|
| 129 |
+
# 保存缓存
|
| 130 |
+
await self.quota_cache.save_to_file_async()
|
| 131 |
+
|
| 132 |
+
# 保存账号配置(因为可能有启用/禁用状态变化)
|
| 133 |
+
self._save_accounts_config()
|
| 134 |
+
|
| 135 |
+
print(f"[QuotaScheduler] 额度刷新完成: 成功 {success_count}, 失败 {fail_count}")
|
| 136 |
+
return result_dict
|
| 137 |
+
|
| 138 |
+
def _save_accounts_config(self):
|
| 139 |
+
"""保存账号配置"""
|
| 140 |
+
try:
|
| 141 |
+
from .state import state
|
| 142 |
+
state._save_accounts()
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"[QuotaScheduler] 保存账号配置失败: {e}")
|
| 145 |
+
|
| 146 |
+
async def refresh_account(self, account_id: str) -> bool:
|
| 147 |
+
"""刷新单个账号额度
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
account_id: 账号ID
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
是否成功
|
| 154 |
+
"""
|
| 155 |
+
accounts = self._get_accounts()
|
| 156 |
+
account = next((acc for acc in accounts if acc.id == account_id), None)
|
| 157 |
+
|
| 158 |
+
if not account:
|
| 159 |
+
print(f"[QuotaScheduler] 账号不存在: {account_id}")
|
| 160 |
+
return False
|
| 161 |
+
|
| 162 |
+
success = await self._refresh_account_internal(account)
|
| 163 |
+
|
| 164 |
+
if success:
|
| 165 |
+
await self.quota_cache.save_to_file_async()
|
| 166 |
+
self._save_accounts_config()
|
| 167 |
+
|
| 168 |
+
return success
|
| 169 |
+
|
| 170 |
+
async def _refresh_account_internal(self, account: 'Account') -> bool:
|
| 171 |
+
"""内部刷新账号额度方法
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
account: 账号对象
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
是否成功
|
| 178 |
+
"""
|
| 179 |
+
try:
|
| 180 |
+
success, result = await get_account_usage(account)
|
| 181 |
+
|
| 182 |
+
if success:
|
| 183 |
+
quota = CachedQuota.from_usage_info(account.id, result)
|
| 184 |
+
self.quota_cache.set(account.id, quota)
|
| 185 |
+
|
| 186 |
+
# 额度为 0 时自动禁用账号
|
| 187 |
+
if quota.is_exhausted:
|
| 188 |
+
if account.enabled:
|
| 189 |
+
account.enabled = False
|
| 190 |
+
# 标记为自动禁用,避免与手动禁用混淆
|
| 191 |
+
if hasattr(account, "auto_disabled"):
|
| 192 |
+
account.auto_disabled = True
|
| 193 |
+
print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 额度已用尽,自动禁用")
|
| 194 |
+
else:
|
| 195 |
+
# 有额度时自动解禁账号(仅对自动禁用的账号生效,避免覆盖手动禁用/封禁)
|
| 196 |
+
if (not account.enabled) and getattr(account, "auto_disabled", False):
|
| 197 |
+
account.enabled = True
|
| 198 |
+
account.auto_disabled = False
|
| 199 |
+
print(f"[QuotaScheduler] 账号 {account.id} ({account.name}) 有可用额度,自动启用")
|
| 200 |
+
|
| 201 |
+
return True
|
| 202 |
+
else:
|
| 203 |
+
error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else str(result)
|
| 204 |
+
quota = CachedQuota.from_error(account.id, error_msg)
|
| 205 |
+
self.quota_cache.set(account.id, quota)
|
| 206 |
+
print(f"[QuotaScheduler] 获取账号 {account.id} 额度失败: {error_msg}")
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
error_msg = str(e)
|
| 211 |
+
quota = CachedQuota.from_error(account.id, error_msg)
|
| 212 |
+
self.quota_cache.set(account.id, quota)
|
| 213 |
+
print(f"[QuotaScheduler] 获取账号 {account.id} 额度异常: {error_msg}")
|
| 214 |
+
return False
|
| 215 |
+
|
| 216 |
+
def mark_active(self, account_id: str) -> None:
|
| 217 |
+
"""标记账号为活跃
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
account_id: 账号ID
|
| 221 |
+
"""
|
| 222 |
+
with self._lock:
|
| 223 |
+
self._active_accounts[account_id] = time.time()
|
| 224 |
+
|
| 225 |
+
def is_active(self, account_id: str) -> bool:
|
| 226 |
+
"""检查账号是否活跃
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
account_id: 账号ID
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
是否在活跃时间窗口内
|
| 233 |
+
"""
|
| 234 |
+
with self._lock:
|
| 235 |
+
last_used = self._active_accounts.get(account_id)
|
| 236 |
+
if last_used is None:
|
| 237 |
+
return False
|
| 238 |
+
return (time.time() - last_used) < ACTIVE_WINDOW_SECONDS
|
| 239 |
+
|
| 240 |
+
def get_active_accounts(self) -> Set[str]:
|
| 241 |
+
"""获取活跃账号列表
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
活跃账号ID集合
|
| 245 |
+
"""
|
| 246 |
+
current_time = time.time()
|
| 247 |
+
with self._lock:
|
| 248 |
+
return {
|
| 249 |
+
account_id
|
| 250 |
+
for account_id, last_used in self._active_accounts.items()
|
| 251 |
+
if (current_time - last_used) < ACTIVE_WINDOW_SECONDS
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
def cleanup_inactive(self) -> None:
|
| 255 |
+
"""清理不活跃的账号记录"""
|
| 256 |
+
current_time = time.time()
|
| 257 |
+
with self._lock:
|
| 258 |
+
self._active_accounts = {
|
| 259 |
+
account_id: last_used
|
| 260 |
+
for account_id, last_used in self._active_accounts.items()
|
| 261 |
+
if (current_time - last_used) < ACTIVE_WINDOW_SECONDS * 2
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
async def _update_loop(self) -> None:
|
| 265 |
+
"""定时更新循环"""
|
| 266 |
+
while self._running:
|
| 267 |
+
try:
|
| 268 |
+
await asyncio.sleep(self.update_interval)
|
| 269 |
+
|
| 270 |
+
if not self._running:
|
| 271 |
+
break
|
| 272 |
+
|
| 273 |
+
# 获取活跃账号
|
| 274 |
+
active_ids = self.get_active_accounts()
|
| 275 |
+
|
| 276 |
+
if active_ids:
|
| 277 |
+
print(f"[QuotaScheduler] 更新 {len(active_ids)} 个活跃账号的额度...")
|
| 278 |
+
|
| 279 |
+
accounts = self._get_accounts()
|
| 280 |
+
active_accounts = [acc for acc in accounts if acc.id in active_ids]
|
| 281 |
+
|
| 282 |
+
# 并发更新
|
| 283 |
+
tasks = [self._refresh_account_internal(acc) for acc in active_accounts]
|
| 284 |
+
await asyncio.gather(*tasks, return_exceptions=True)
|
| 285 |
+
|
| 286 |
+
# 保存缓存
|
| 287 |
+
await self.quota_cache.save_to_file_async()
|
| 288 |
+
|
| 289 |
+
# 清理不活跃记录
|
| 290 |
+
self.cleanup_inactive()
|
| 291 |
+
|
| 292 |
+
except asyncio.CancelledError:
|
| 293 |
+
break
|
| 294 |
+
except Exception as e:
|
| 295 |
+
print(f"[QuotaScheduler] 更新循环异常: {e}")
|
| 296 |
+
|
| 297 |
+
def get_last_full_refresh(self) -> Optional[float]:
|
| 298 |
+
"""获取最后一次全量刷新时间"""
|
| 299 |
+
return self._last_full_refresh
|
| 300 |
+
|
| 301 |
+
def get_status(self) -> dict:
|
| 302 |
+
"""获取调度器状态"""
|
| 303 |
+
return {
|
| 304 |
+
"running": self._running,
|
| 305 |
+
"update_interval": self.update_interval,
|
| 306 |
+
"active_accounts": list(self.get_active_accounts()),
|
| 307 |
+
"active_count": len(self.get_active_accounts()),
|
| 308 |
+
"last_full_refresh": self._last_full_refresh
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# 全局调度器实例
|
| 313 |
+
_quota_scheduler: Optional[QuotaScheduler] = None
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def get_quota_scheduler() -> QuotaScheduler:
|
| 317 |
+
"""获取全局调度器实例"""
|
| 318 |
+
global _quota_scheduler
|
| 319 |
+
if _quota_scheduler is None:
|
| 320 |
+
_quota_scheduler = QuotaScheduler()
|
| 321 |
+
return _quota_scheduler
|
KiroProxy/kiro_proxy/core/rate_limiter.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""请求限速器 - 降低账号封禁风险
|
| 2 |
+
|
| 3 |
+
通过限制请求频率来降低被检测为异常活动的风险:
|
| 4 |
+
- 每账号请求间隔
|
| 5 |
+
- 全局请求限制
|
| 6 |
+
- 突发请求检测
|
| 7 |
+
|
| 8 |
+
注意:429 冷却时间已改为自动管理(固定5分钟),不再需要手动配置
|
| 9 |
+
"""
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Dict, Optional
|
| 13 |
+
from collections import deque
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class RateLimitConfig:
|
| 18 |
+
"""限速配置"""
|
| 19 |
+
# 每账号最小请求间隔(秒)
|
| 20 |
+
min_request_interval: float = 0.5
|
| 21 |
+
|
| 22 |
+
# 每账号每分钟最大请求数
|
| 23 |
+
max_requests_per_minute: int = 60
|
| 24 |
+
|
| 25 |
+
# 全局每分钟最大请求数
|
| 26 |
+
global_max_requests_per_minute: int = 120
|
| 27 |
+
|
| 28 |
+
# 是否启用限速
|
| 29 |
+
enabled: bool = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class AccountRateState:
|
| 34 |
+
"""账号限速状态"""
|
| 35 |
+
last_request_time: float = 0
|
| 36 |
+
request_times: deque = field(default_factory=lambda: deque(maxlen=100))
|
| 37 |
+
|
| 38 |
+
def get_requests_in_window(self, window_seconds: int = 60) -> int:
|
| 39 |
+
"""获取时间窗口内的请求数"""
|
| 40 |
+
now = time.time()
|
| 41 |
+
cutoff = now - window_seconds
|
| 42 |
+
return sum(1 for t in self.request_times if t > cutoff)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class RateLimiter:
|
| 46 |
+
"""请求限速器"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: RateLimitConfig = None):
|
| 49 |
+
self.config = config or RateLimitConfig()
|
| 50 |
+
self._account_states: Dict[str, AccountRateState] = {}
|
| 51 |
+
self._global_requests: deque = deque(maxlen=1000)
|
| 52 |
+
|
| 53 |
+
def _get_account_state(self, account_id: str) -> AccountRateState:
|
| 54 |
+
"""获取账号状态"""
|
| 55 |
+
if account_id not in self._account_states:
|
| 56 |
+
self._account_states[account_id] = AccountRateState()
|
| 57 |
+
return self._account_states[account_id]
|
| 58 |
+
|
| 59 |
+
def can_request(self, account_id: str) -> tuple:
|
| 60 |
+
"""检查是否可以发送请求
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
(can_request, wait_seconds, reason)
|
| 64 |
+
"""
|
| 65 |
+
if not self.config.enabled:
|
| 66 |
+
return True, 0, None
|
| 67 |
+
|
| 68 |
+
now = time.time()
|
| 69 |
+
state = self._get_account_state(account_id)
|
| 70 |
+
|
| 71 |
+
# 检查最小请求间隔
|
| 72 |
+
time_since_last = now - state.last_request_time
|
| 73 |
+
if time_since_last < self.config.min_request_interval:
|
| 74 |
+
wait = self.config.min_request_interval - time_since_last
|
| 75 |
+
return False, wait, f"请求过快,请等待 {wait:.1f} 秒"
|
| 76 |
+
|
| 77 |
+
# 检查每账号每分钟限制
|
| 78 |
+
account_rpm = state.get_requests_in_window(60)
|
| 79 |
+
if account_rpm >= self.config.max_requests_per_minute:
|
| 80 |
+
return False, 2, f"账号请求过于频繁 ({account_rpm}/分钟)"
|
| 81 |
+
|
| 82 |
+
# 检查全局每分钟限制
|
| 83 |
+
global_rpm = sum(1 for t in self._global_requests if t > now - 60)
|
| 84 |
+
if global_rpm >= self.config.global_max_requests_per_minute:
|
| 85 |
+
return False, 1, f"全局请求过于频繁 ({global_rpm}/分钟)"
|
| 86 |
+
|
| 87 |
+
return True, 0, None
|
| 88 |
+
|
| 89 |
+
def record_request(self, account_id: str):
|
| 90 |
+
"""记录请求"""
|
| 91 |
+
now = time.time()
|
| 92 |
+
state = self._get_account_state(account_id)
|
| 93 |
+
state.last_request_time = now
|
| 94 |
+
state.request_times.append(now)
|
| 95 |
+
self._global_requests.append(now)
|
| 96 |
+
|
| 97 |
+
def get_stats(self) -> dict:
|
| 98 |
+
"""获取统计信息"""
|
| 99 |
+
now = time.time()
|
| 100 |
+
return {
|
| 101 |
+
"enabled": self.config.enabled,
|
| 102 |
+
"global_rpm": sum(1 for t in self._global_requests if t > now - 60),
|
| 103 |
+
"accounts": {
|
| 104 |
+
aid: {
|
| 105 |
+
"rpm": state.get_requests_in_window(60),
|
| 106 |
+
"last_request": now - state.last_request_time if state.last_request_time else None
|
| 107 |
+
}
|
| 108 |
+
for aid, state in self._account_states.items()
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def update_config(self, **kwargs):
|
| 113 |
+
"""更新配置"""
|
| 114 |
+
for key, value in kwargs.items():
|
| 115 |
+
if hasattr(self.config, key):
|
| 116 |
+
setattr(self.config, key, value)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# 全局实例
|
| 120 |
+
rate_limiter = RateLimiter()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_rate_limiter() -> RateLimiter:
|
| 124 |
+
"""获取限速器实例"""
|
| 125 |
+
return rate_limiter
|
KiroProxy/kiro_proxy/core/refresh_manager.py
ADDED
|
@@ -0,0 +1,888 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 刷新管理模块
|
| 2 |
+
|
| 3 |
+
提供 Token 批量刷新的管理功能,包括:
|
| 4 |
+
- 刷新进度跟踪
|
| 5 |
+
- 并发控制
|
| 6 |
+
- 重试机制配置
|
| 7 |
+
- 全局锁防止重复刷新
|
| 8 |
+
- Token 过期检测和自动刷新
|
| 9 |
+
- 指数退避重试策略
|
| 10 |
+
"""
|
| 11 |
+
import time
|
| 12 |
+
import asyncio
|
| 13 |
+
from dataclasses import dataclass, field, asdict
|
| 14 |
+
from typing import Optional, Dict, Any, List, Tuple, Callable, TYPE_CHECKING
|
| 15 |
+
from threading import Lock
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from .account import Account
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class RefreshProgress:
|
| 23 |
+
"""刷新进度信息
|
| 24 |
+
|
| 25 |
+
用于跟踪批量 Token 刷新操作的进度状态。
|
| 26 |
+
|
| 27 |
+
Attributes:
|
| 28 |
+
total: 需要刷新的账号总数
|
| 29 |
+
completed: 已完成处理的账号数(包括成功和失败)
|
| 30 |
+
success: 刷新成功的账号数
|
| 31 |
+
failed: 刷新失败的账号数
|
| 32 |
+
current_account: 当前正在处理的账号ID
|
| 33 |
+
status: 刷新状态 - running(进行中), completed(已完成), error(出错)
|
| 34 |
+
started_at: 刷新开始时间戳
|
| 35 |
+
message: 状态消息,用于显示当前操作或错误信息
|
| 36 |
+
"""
|
| 37 |
+
total: int = 0
|
| 38 |
+
completed: int = 0
|
| 39 |
+
success: int = 0
|
| 40 |
+
failed: int = 0
|
| 41 |
+
current_account: Optional[str] = None
|
| 42 |
+
status: str = "running" # running, completed, error
|
| 43 |
+
started_at: float = field(default_factory=time.time)
|
| 44 |
+
message: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 47 |
+
"""转换为字典格式
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
包含所有进度信息的字典
|
| 51 |
+
"""
|
| 52 |
+
return asdict(self)
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def progress_percent(self) -> float:
|
| 56 |
+
"""计算完成百分比
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
完成百分比(0-100)
|
| 60 |
+
"""
|
| 61 |
+
if self.total == 0:
|
| 62 |
+
return 0.0
|
| 63 |
+
return round((self.completed / self.total) * 100, 2)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def elapsed_seconds(self) -> float:
|
| 67 |
+
"""计算已用时间(秒)
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
从开始到现在的秒数
|
| 71 |
+
"""
|
| 72 |
+
return time.time() - self.started_at
|
| 73 |
+
|
| 74 |
+
def is_running(self) -> bool:
|
| 75 |
+
"""检查是否正在运行
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
True 表示正在运行
|
| 79 |
+
"""
|
| 80 |
+
return self.status == "running"
|
| 81 |
+
|
| 82 |
+
def is_completed(self) -> bool:
|
| 83 |
+
"""检查是否已完成
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
True 表示已完成(成功或出错)
|
| 87 |
+
"""
|
| 88 |
+
return self.status in ("completed", "error")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
|
| 92 |
+
class RefreshConfig:
|
| 93 |
+
"""刷新配置
|
| 94 |
+
|
| 95 |
+
控制 Token 刷新行为的配置参数。
|
| 96 |
+
|
| 97 |
+
Attributes:
|
| 98 |
+
max_retries: 单个账号刷新失败时的最大重试次数
|
| 99 |
+
retry_base_delay: 重试基础延迟时间(秒),实际延迟会指数增长
|
| 100 |
+
concurrency: 并发刷新的账号数量
|
| 101 |
+
token_refresh_before_expiry: Token 过期前多少秒开始刷新(默认5分钟)
|
| 102 |
+
auto_refresh_interval: 自动刷新检查间隔(秒)
|
| 103 |
+
"""
|
| 104 |
+
max_retries: int = 3
|
| 105 |
+
retry_base_delay: float = 1.0
|
| 106 |
+
concurrency: int = 3
|
| 107 |
+
token_refresh_before_expiry: int = 300 # 5分钟
|
| 108 |
+
auto_refresh_interval: int = 60 # 1分钟
|
| 109 |
+
|
| 110 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 111 |
+
"""转换为字典格式
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
包含所有配置项的字典
|
| 115 |
+
"""
|
| 116 |
+
return asdict(self)
|
| 117 |
+
|
| 118 |
+
@classmethod
|
| 119 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'RefreshConfig':
|
| 120 |
+
"""从字典创建配置实例
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
data: 配置字典
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
RefreshConfig 实例
|
| 127 |
+
"""
|
| 128 |
+
return cls(
|
| 129 |
+
max_retries=data.get("max_retries", 3),
|
| 130 |
+
retry_base_delay=data.get("retry_base_delay", 1.0),
|
| 131 |
+
concurrency=data.get("concurrency", 3),
|
| 132 |
+
token_refresh_before_expiry=data.get("token_refresh_before_expiry", 300),
|
| 133 |
+
auto_refresh_interval=data.get("auto_refresh_interval", 60)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def validate(self) -> bool:
|
| 137 |
+
"""验证配置有效性
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
True 表示配置有效
|
| 141 |
+
|
| 142 |
+
Raises:
|
| 143 |
+
ValueError: 配置值无效时抛出
|
| 144 |
+
"""
|
| 145 |
+
if self.max_retries < 0:
|
| 146 |
+
raise ValueError("max_retries 不能为负数")
|
| 147 |
+
if self.retry_base_delay <= 0:
|
| 148 |
+
raise ValueError("retry_base_delay 必须大于0")
|
| 149 |
+
if self.concurrency < 1:
|
| 150 |
+
raise ValueError("concurrency 必须至少为1")
|
| 151 |
+
if self.token_refresh_before_expiry < 0:
|
| 152 |
+
raise ValueError("token_refresh_before_expiry 不能为负数")
|
| 153 |
+
if self.auto_refresh_interval < 1:
|
| 154 |
+
raise ValueError("auto_refresh_interval 必须至少为1秒")
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class RefreshManager:
|
| 159 |
+
"""Token 刷新管理器
|
| 160 |
+
|
| 161 |
+
管理 Token 批量刷新操作,提供:
|
| 162 |
+
- 全局锁机制防止重复刷新
|
| 163 |
+
- 进度跟踪
|
| 164 |
+
- 配置管理
|
| 165 |
+
- 自动 Token 刷新定时器
|
| 166 |
+
|
| 167 |
+
使用示例:
|
| 168 |
+
manager = get_refresh_manager()
|
| 169 |
+
if not manager.is_refreshing():
|
| 170 |
+
# 开始刷新操作
|
| 171 |
+
pass
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, config: Optional[RefreshConfig] = None):
|
| 175 |
+
"""初始化刷新管理器
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
config: 刷新配置,None 则使用默认配置
|
| 179 |
+
"""
|
| 180 |
+
# 配置
|
| 181 |
+
self._config = config or RefreshConfig()
|
| 182 |
+
|
| 183 |
+
# 线程锁(用于同步访问状态)
|
| 184 |
+
self._lock = Lock()
|
| 185 |
+
|
| 186 |
+
# 异步锁(用于防止并发刷新操作)
|
| 187 |
+
self._async_lock = asyncio.Lock()
|
| 188 |
+
|
| 189 |
+
# 刷新状态
|
| 190 |
+
self._is_refreshing: bool = False
|
| 191 |
+
self._progress: Optional[RefreshProgress] = None
|
| 192 |
+
|
| 193 |
+
# 上次刷新完成时间
|
| 194 |
+
self._last_refresh_time: Optional[float] = None
|
| 195 |
+
|
| 196 |
+
# 自动刷新定时器
|
| 197 |
+
self._auto_refresh_task: Optional[asyncio.Task] = None
|
| 198 |
+
self._auto_refresh_running: bool = False
|
| 199 |
+
|
| 200 |
+
# 获取账号列表的回调函数
|
| 201 |
+
self._accounts_getter: Optional[Callable] = None
|
| 202 |
+
|
| 203 |
+
@property
|
| 204 |
+
def config(self) -> RefreshConfig:
|
| 205 |
+
"""获取当前配置
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
当前的刷新配置
|
| 209 |
+
"""
|
| 210 |
+
with self._lock:
|
| 211 |
+
return self._config
|
| 212 |
+
|
| 213 |
+
def is_refreshing(self) -> bool:
|
| 214 |
+
"""检查是否正在刷新
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
True 表示正在进行刷新操作
|
| 218 |
+
"""
|
| 219 |
+
with self._lock:
|
| 220 |
+
return self._is_refreshing
|
| 221 |
+
|
| 222 |
+
def get_progress(self) -> Optional[RefreshProgress]:
|
| 223 |
+
"""获取当前刷新进度
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
当前进度信息,如果没有进行中的刷新则返回 None
|
| 227 |
+
"""
|
| 228 |
+
with self._lock:
|
| 229 |
+
return self._progress
|
| 230 |
+
|
| 231 |
+
def get_progress_dict(self) -> Optional[Dict[str, Any]]:
|
| 232 |
+
"""获取当前刷新进度(字典格式)
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
进度信息字典,如果没有进行中的刷新则返回 None
|
| 236 |
+
"""
|
| 237 |
+
with self._lock:
|
| 238 |
+
if self._progress is None:
|
| 239 |
+
return None
|
| 240 |
+
return self._progress.to_dict()
|
| 241 |
+
|
| 242 |
+
def update_config(self, **kwargs) -> None:
|
| 243 |
+
"""更新配置参数
|
| 244 |
+
|
| 245 |
+
支持的参数:
|
| 246 |
+
max_retries: 最大重试次数
|
| 247 |
+
retry_base_delay: 重试基础延迟
|
| 248 |
+
concurrency: 并发数
|
| 249 |
+
token_refresh_before_expiry: Token 过期前刷新时间
|
| 250 |
+
auto_refresh_interval: 自动刷新检查间隔
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
**kwargs: 要更新的配置项
|
| 254 |
+
|
| 255 |
+
Raises:
|
| 256 |
+
ValueError: 配置值无效时抛出
|
| 257 |
+
"""
|
| 258 |
+
with self._lock:
|
| 259 |
+
# 创建新配置
|
| 260 |
+
new_config = RefreshConfig(
|
| 261 |
+
max_retries=kwargs.get("max_retries", self._config.max_retries),
|
| 262 |
+
retry_base_delay=kwargs.get("retry_base_delay", self._config.retry_base_delay),
|
| 263 |
+
concurrency=kwargs.get("concurrency", self._config.concurrency),
|
| 264 |
+
token_refresh_before_expiry=kwargs.get(
|
| 265 |
+
"token_refresh_before_expiry",
|
| 266 |
+
self._config.token_refresh_before_expiry
|
| 267 |
+
),
|
| 268 |
+
auto_refresh_interval=kwargs.get(
|
| 269 |
+
"auto_refresh_interval",
|
| 270 |
+
self._config.auto_refresh_interval
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# 验证配置
|
| 275 |
+
new_config.validate()
|
| 276 |
+
|
| 277 |
+
# 应用新配置
|
| 278 |
+
self._config = new_config
|
| 279 |
+
|
| 280 |
+
def _start_refresh(self, total: int, message: Optional[str] = None) -> RefreshProgress:
|
| 281 |
+
"""开始刷新操作(内部方法)
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
total: 需要刷新的账号总数
|
| 285 |
+
message: 初始状态消息
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
新创建的进度对象
|
| 289 |
+
"""
|
| 290 |
+
with self._lock:
|
| 291 |
+
self._is_refreshing = True
|
| 292 |
+
self._progress = RefreshProgress(
|
| 293 |
+
total=total,
|
| 294 |
+
completed=0,
|
| 295 |
+
success=0,
|
| 296 |
+
failed=0,
|
| 297 |
+
current_account=None,
|
| 298 |
+
status="running",
|
| 299 |
+
started_at=time.time(),
|
| 300 |
+
message=message or "开始刷新"
|
| 301 |
+
)
|
| 302 |
+
return self._progress
|
| 303 |
+
|
| 304 |
+
def _update_progress(
|
| 305 |
+
self,
|
| 306 |
+
current_account: Optional[str] = None,
|
| 307 |
+
success: bool = False,
|
| 308 |
+
failed: bool = False,
|
| 309 |
+
message: Optional[str] = None
|
| 310 |
+
) -> None:
|
| 311 |
+
"""更新刷新进度(内部方法)
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
current_account: 当前处理的账号ID
|
| 315 |
+
success: 是否成功完成一个账号
|
| 316 |
+
failed: 是否失败一个账号
|
| 317 |
+
message: 状态消息
|
| 318 |
+
"""
|
| 319 |
+
with self._lock:
|
| 320 |
+
if self._progress is None:
|
| 321 |
+
return
|
| 322 |
+
|
| 323 |
+
if current_account is not None:
|
| 324 |
+
self._progress.current_account = current_account
|
| 325 |
+
|
| 326 |
+
if success:
|
| 327 |
+
self._progress.success += 1
|
| 328 |
+
self._progress.completed += 1
|
| 329 |
+
elif failed:
|
| 330 |
+
self._progress.failed += 1
|
| 331 |
+
self._progress.completed += 1
|
| 332 |
+
|
| 333 |
+
if message is not None:
|
| 334 |
+
self._progress.message = message
|
| 335 |
+
|
| 336 |
+
def _finish_refresh(self, status: str = "completed", message: Optional[str] = None) -> None:
|
| 337 |
+
"""完成刷新操作(内部方法)
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
status: 最终状态 - completed 或 error
|
| 341 |
+
message: 最终状态消息
|
| 342 |
+
"""
|
| 343 |
+
with self._lock:
|
| 344 |
+
self._is_refreshing = False
|
| 345 |
+
self._last_refresh_time = time.time()
|
| 346 |
+
|
| 347 |
+
if self._progress is not None:
|
| 348 |
+
self._progress.status = status
|
| 349 |
+
self._progress.current_account = None
|
| 350 |
+
if message is not None:
|
| 351 |
+
self._progress.message = message
|
| 352 |
+
elif status == "completed":
|
| 353 |
+
self._progress.message = (
|
| 354 |
+
f"刷新完成: 成功 {self._progress.success}, "
|
| 355 |
+
f"失败 {self._progress.failed}"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def get_last_refresh_time(self) -> Optional[float]:
|
| 359 |
+
"""获取上次刷新完成时间
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
上次刷新完成的时间戳,如果从未刷新则返回 None
|
| 363 |
+
"""
|
| 364 |
+
with self._lock:
|
| 365 |
+
return self._last_refresh_time
|
| 366 |
+
|
| 367 |
+
def get_status(self) -> Dict[str, Any]:
|
| 368 |
+
"""获取管理器状态
|
| 369 |
+
|
| 370 |
+
Returns:
|
| 371 |
+
包含管理器状态信息的字典
|
| 372 |
+
"""
|
| 373 |
+
with self._lock:
|
| 374 |
+
return {
|
| 375 |
+
"is_refreshing": self._is_refreshing,
|
| 376 |
+
"progress": self._progress.to_dict() if self._progress else None,
|
| 377 |
+
"last_refresh_time": self._last_refresh_time,
|
| 378 |
+
"config": self._config.to_dict()
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
async def acquire_refresh_lock(self) -> bool:
|
| 382 |
+
"""尝试获取刷新锁
|
| 383 |
+
|
| 384 |
+
用于在开始刷新操作前获取异步锁,防止并发刷新。
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
True 表示成功获取锁,False 表示已有刷新在进行
|
| 388 |
+
"""
|
| 389 |
+
if self._async_lock.locked():
|
| 390 |
+
return False
|
| 391 |
+
|
| 392 |
+
await self._async_lock.acquire()
|
| 393 |
+
return True
|
| 394 |
+
|
| 395 |
+
def release_refresh_lock(self) -> None:
|
| 396 |
+
"""释放刷新锁
|
| 397 |
+
|
| 398 |
+
在刷新操作完成后调用,释放异步锁。
|
| 399 |
+
"""
|
| 400 |
+
if self._async_lock.locked():
|
| 401 |
+
self._async_lock.release()
|
| 402 |
+
|
| 403 |
+
def should_refresh_token(self, account: 'Account') -> bool:
|
| 404 |
+
"""判断是否需要刷新 Token
|
| 405 |
+
|
| 406 |
+
检查账号的 Token 是否即将过期(过期前5分钟)或已过期。
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
account: 账号对象
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
True 表示需要刷新 Token
|
| 413 |
+
"""
|
| 414 |
+
creds = account.get_credentials()
|
| 415 |
+
if creds is None:
|
| 416 |
+
return True # 无法获取凭证,需要刷新
|
| 417 |
+
|
| 418 |
+
# 检查是否已过期或即将过期
|
| 419 |
+
minutes_before = self._config.token_refresh_before_expiry // 60
|
| 420 |
+
return creds.is_expired() or creds.is_expiring_soon(minutes=minutes_before)
|
| 421 |
+
|
| 422 |
+
async def refresh_token_if_needed(self, account: 'Account') -> Tuple[bool, str]:
|
| 423 |
+
"""如果需要则刷新 Token
|
| 424 |
+
|
| 425 |
+
检查账号 Token 状态,如果即将过期或已过期则刷新。
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
account: 账号对象
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
(success, message) 元组
|
| 432 |
+
- success: True 表示 Token 有效(无需刷新或刷新成功)
|
| 433 |
+
- message: 状态消息
|
| 434 |
+
"""
|
| 435 |
+
if not self.should_refresh_token(account):
|
| 436 |
+
return True, "Token 有效,无需刷新"
|
| 437 |
+
|
| 438 |
+
print(f"[RefreshManager] 账号 {account.id} Token 即将过期,开始刷新...")
|
| 439 |
+
|
| 440 |
+
success, result = await account.refresh_token()
|
| 441 |
+
|
| 442 |
+
if success:
|
| 443 |
+
print(f"[RefreshManager] 账号 {account.id} Token 刷新成功")
|
| 444 |
+
return True, "Token 刷新成功"
|
| 445 |
+
else:
|
| 446 |
+
print(f"[RefreshManager] 账号 {account.id} Token 刷新失败: {result}")
|
| 447 |
+
return False, f"Token 刷新失败: {result}"
|
| 448 |
+
|
| 449 |
+
async def refresh_account_with_token(
|
| 450 |
+
self,
|
| 451 |
+
account: 'Account',
|
| 452 |
+
get_quota_func: Optional[Callable] = None
|
| 453 |
+
) -> Tuple[bool, str]:
|
| 454 |
+
"""刷新单个账号(先刷新 Token,再获取额度)
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
account: 账号对象
|
| 458 |
+
get_quota_func: 获取额度的异步函数,接受 account 参数
|
| 459 |
+
|
| 460 |
+
Returns:
|
| 461 |
+
(success, message) 元组
|
| 462 |
+
"""
|
| 463 |
+
# 1. 先刷新 Token(如果需要)
|
| 464 |
+
token_success, token_msg = await self.refresh_token_if_needed(account)
|
| 465 |
+
|
| 466 |
+
if not token_success:
|
| 467 |
+
return False, token_msg
|
| 468 |
+
|
| 469 |
+
# 2. 获取额度(如果提供了获取函数)
|
| 470 |
+
if get_quota_func:
|
| 471 |
+
try:
|
| 472 |
+
quota_success, quota_result = await get_quota_func(account)
|
| 473 |
+
if quota_success:
|
| 474 |
+
return True, "刷新成功"
|
| 475 |
+
else:
|
| 476 |
+
error_msg = quota_result.get("error", "Unknown error") if isinstance(quota_result, dict) else str(quota_result)
|
| 477 |
+
return False, f"获取额度失败: {error_msg}"
|
| 478 |
+
except Exception as e:
|
| 479 |
+
return False, f"获取额度异常: {str(e)}"
|
| 480 |
+
|
| 481 |
+
return True, token_msg
|
| 482 |
+
|
| 483 |
+
async def retry_with_backoff(
|
| 484 |
+
self,
|
| 485 |
+
func: Callable,
|
| 486 |
+
*args,
|
| 487 |
+
max_retries: Optional[int] = None,
|
| 488 |
+
**kwargs
|
| 489 |
+
) -> Tuple[bool, Any]:
|
| 490 |
+
"""带指数退避的重试
|
| 491 |
+
|
| 492 |
+
执行异步函数,失败时使用指数退避策略重试。
|
| 493 |
+
|
| 494 |
+
Args:
|
| 495 |
+
func: 要执行的异步函数
|
| 496 |
+
*args: 传递给函数的位置参数
|
| 497 |
+
max_retries: 最大重试次数,None 则使用配置值
|
| 498 |
+
**kwargs: 传递给函数的关键字参数
|
| 499 |
+
|
| 500 |
+
Returns:
|
| 501 |
+
(success, result) 元组
|
| 502 |
+
- success: True 表示执行成功
|
| 503 |
+
- result: 成功时为函数返回值,失败时为错误信息
|
| 504 |
+
"""
|
| 505 |
+
retries = max_retries if max_retries is not None else self._config.max_retries
|
| 506 |
+
base_delay = self._config.retry_base_delay
|
| 507 |
+
|
| 508 |
+
last_error = None
|
| 509 |
+
|
| 510 |
+
for attempt in range(retries + 1):
|
| 511 |
+
try:
|
| 512 |
+
result = await func(*args, **kwargs)
|
| 513 |
+
|
| 514 |
+
# 检查返回值格式
|
| 515 |
+
if isinstance(result, tuple) and len(result) == 2:
|
| 516 |
+
success, data = result
|
| 517 |
+
if success:
|
| 518 |
+
return True, data
|
| 519 |
+
else:
|
| 520 |
+
last_error = data
|
| 521 |
+
# 检查是否是 429 错误
|
| 522 |
+
if self._is_rate_limit_error(data):
|
| 523 |
+
delay = self._get_rate_limit_delay(attempt, base_delay)
|
| 524 |
+
else:
|
| 525 |
+
delay = base_delay * (2 ** attempt)
|
| 526 |
+
else:
|
| 527 |
+
# 函数返回非元组,视为成功
|
| 528 |
+
return True, result
|
| 529 |
+
|
| 530 |
+
except Exception as e:
|
| 531 |
+
last_error = str(e)
|
| 532 |
+
delay = base_delay * (2 ** attempt)
|
| 533 |
+
|
| 534 |
+
# 如果还有重试机会,等待后重试
|
| 535 |
+
if attempt < retries:
|
| 536 |
+
print(f"[RefreshManager] 第 {attempt + 1} 次尝试失败,{delay:.1f}秒后重试...")
|
| 537 |
+
await asyncio.sleep(delay)
|
| 538 |
+
|
| 539 |
+
return False, last_error
|
| 540 |
+
|
| 541 |
+
def _is_rate_limit_error(self, error: Any) -> bool:
|
| 542 |
+
"""检查是否是限流错误(429)
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
error: 错误信息
|
| 546 |
+
|
| 547 |
+
Returns:
|
| 548 |
+
True 表示是限流错误
|
| 549 |
+
"""
|
| 550 |
+
if isinstance(error, str):
|
| 551 |
+
return "429" in error or "rate limit" in error.lower() or "请求过于频繁" in error
|
| 552 |
+
return False
|
| 553 |
+
|
| 554 |
+
def _get_rate_limit_delay(self, attempt: int, base_delay: float) -> float:
|
| 555 |
+
"""获取限流错误的等待时间
|
| 556 |
+
|
| 557 |
+
429 错误使用更长的等待时间。
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
attempt: 当前尝试次数(从0开始)
|
| 561 |
+
base_delay: 基础延迟
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
等待时间(秒)
|
| 565 |
+
"""
|
| 566 |
+
# 429 错误使用 3 倍的基础延迟
|
| 567 |
+
return base_delay * 3 * (2 ** attempt)
|
| 568 |
+
|
| 569 |
+
async def refresh_all_with_token(
|
| 570 |
+
self,
|
| 571 |
+
accounts: List['Account'],
|
| 572 |
+
get_quota_func: Optional[Callable] = None,
|
| 573 |
+
skip_disabled: bool = True,
|
| 574 |
+
skip_error: bool = True
|
| 575 |
+
) -> RefreshProgress:
|
| 576 |
+
"""刷新所有账号(先刷新 Token,再获取额度)
|
| 577 |
+
|
| 578 |
+
使用全局锁防止并发刷新,支持进度跟踪。
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
accounts: 账号列表
|
| 582 |
+
get_quota_func: 获取额度的异步函数
|
| 583 |
+
skip_disabled: 是否跳过已禁用的账号
|
| 584 |
+
skip_error: 是否跳过已处于错误状态的账号
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
刷新进度信息
|
| 588 |
+
"""
|
| 589 |
+
# 尝试获取锁
|
| 590 |
+
if not await self.acquire_refresh_lock():
|
| 591 |
+
# 已有刷新在进行
|
| 592 |
+
progress = self.get_progress()
|
| 593 |
+
if progress:
|
| 594 |
+
return progress
|
| 595 |
+
# 返回一个错误状态的进度
|
| 596 |
+
return RefreshProgress(
|
| 597 |
+
total=0,
|
| 598 |
+
status="error",
|
| 599 |
+
message="刷新操作正在进行中"
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
try:
|
| 603 |
+
# 过滤账号
|
| 604 |
+
accounts_to_refresh = []
|
| 605 |
+
for acc in accounts:
|
| 606 |
+
if skip_disabled and not acc.enabled:
|
| 607 |
+
continue
|
| 608 |
+
if skip_error and acc.status.value in ("unhealthy", "suspended"):
|
| 609 |
+
continue
|
| 610 |
+
accounts_to_refresh.append(acc)
|
| 611 |
+
|
| 612 |
+
total = len(accounts_to_refresh)
|
| 613 |
+
|
| 614 |
+
# 开始刷新
|
| 615 |
+
self._start_refresh(total, f"开始刷新 {total} 个账号")
|
| 616 |
+
|
| 617 |
+
if total == 0:
|
| 618 |
+
self._finish_refresh("completed", "没有需要刷新的账号")
|
| 619 |
+
return self.get_progress()
|
| 620 |
+
|
| 621 |
+
# 使用信号量控制并发
|
| 622 |
+
semaphore = asyncio.Semaphore(self._config.concurrency)
|
| 623 |
+
|
| 624 |
+
async def refresh_one(account: 'Account'):
|
| 625 |
+
async with semaphore:
|
| 626 |
+
self._update_progress(
|
| 627 |
+
current_account=account.id,
|
| 628 |
+
message=f"正在刷新: {account.name}"
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# 使用重试机制刷新
|
| 632 |
+
success, result = await self.retry_with_backoff(
|
| 633 |
+
self.refresh_account_with_token,
|
| 634 |
+
account,
|
| 635 |
+
get_quota_func
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
if success:
|
| 639 |
+
self._update_progress(success=True)
|
| 640 |
+
else:
|
| 641 |
+
self._update_progress(failed=True)
|
| 642 |
+
|
| 643 |
+
return success, result
|
| 644 |
+
|
| 645 |
+
# 并发执行
|
| 646 |
+
tasks = [refresh_one(acc) for acc in accounts_to_refresh]
|
| 647 |
+
await asyncio.gather(*tasks, return_exceptions=True)
|
| 648 |
+
|
| 649 |
+
# 完成
|
| 650 |
+
self._finish_refresh("completed")
|
| 651 |
+
return self.get_progress()
|
| 652 |
+
|
| 653 |
+
except Exception as e:
|
| 654 |
+
self._finish_refresh("error", f"刷新异常: {str(e)}")
|
| 655 |
+
return self.get_progress()
|
| 656 |
+
|
| 657 |
+
finally:
|
| 658 |
+
self.release_refresh_lock()
|
| 659 |
+
|
| 660 |
+
def _is_auth_error(self, error: Any) -> bool:
|
| 661 |
+
"""检查是否是认证错误(401)
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
error: 错误信息
|
| 665 |
+
|
| 666 |
+
Returns:
|
| 667 |
+
True 表示是认证错误
|
| 668 |
+
"""
|
| 669 |
+
if isinstance(error, str):
|
| 670 |
+
return "401" in error or "unauthorized" in error.lower() or "凭证已过期" in error or "需要重新登录" in error
|
| 671 |
+
return False
|
| 672 |
+
|
| 673 |
+
async def execute_with_auth_retry(
|
| 674 |
+
self,
|
| 675 |
+
account: 'Account',
|
| 676 |
+
func: Callable,
|
| 677 |
+
*args,
|
| 678 |
+
**kwargs
|
| 679 |
+
) -> Tuple[bool, Any]:
|
| 680 |
+
"""执行操作,遇到 401 错误时自动刷新 Token 并重试
|
| 681 |
+
|
| 682 |
+
Args:
|
| 683 |
+
account: 账号对象
|
| 684 |
+
func: 要执行的异步函数
|
| 685 |
+
*args: 传递给函数的位置参数
|
| 686 |
+
**kwargs: 传递给函数的关键字参数
|
| 687 |
+
|
| 688 |
+
Returns:
|
| 689 |
+
(success, result) 元组
|
| 690 |
+
"""
|
| 691 |
+
try:
|
| 692 |
+
result = await func(*args, **kwargs)
|
| 693 |
+
|
| 694 |
+
# 检查返回值
|
| 695 |
+
if isinstance(result, tuple) and len(result) == 2:
|
| 696 |
+
success, data = result
|
| 697 |
+
if success:
|
| 698 |
+
return True, data
|
| 699 |
+
|
| 700 |
+
# 检查是否是 401 错误
|
| 701 |
+
if self._is_auth_error(data):
|
| 702 |
+
print(f"[RefreshManager] 账号 {account.id} 遇到 401 错误,尝试刷新 Token...")
|
| 703 |
+
|
| 704 |
+
# 刷新 Token
|
| 705 |
+
refresh_success, refresh_msg = await account.refresh_token()
|
| 706 |
+
|
| 707 |
+
if refresh_success:
|
| 708 |
+
print(f"[RefreshManager] Token 刷新成功,重试请求...")
|
| 709 |
+
# 重试原请求
|
| 710 |
+
retry_result = await func(*args, **kwargs)
|
| 711 |
+
if isinstance(retry_result, tuple) and len(retry_result) == 2:
|
| 712 |
+
return retry_result
|
| 713 |
+
return True, retry_result
|
| 714 |
+
else:
|
| 715 |
+
return False, f"Token 刷新失败: {refresh_msg}"
|
| 716 |
+
|
| 717 |
+
return False, data
|
| 718 |
+
|
| 719 |
+
return True, result
|
| 720 |
+
|
| 721 |
+
except Exception as e:
|
| 722 |
+
error_str = str(e)
|
| 723 |
+
|
| 724 |
+
# 检查异常是否是 401 错误
|
| 725 |
+
if self._is_auth_error(error_str):
|
| 726 |
+
print(f"[RefreshManager] 账号 {account.id} 遇到 401 异常,尝试刷新 Token...")
|
| 727 |
+
|
| 728 |
+
refresh_success, refresh_msg = await account.refresh_token()
|
| 729 |
+
|
| 730 |
+
if refresh_success:
|
| 731 |
+
print(f"[RefreshManager] Token 刷新成功,重试请求...")
|
| 732 |
+
try:
|
| 733 |
+
retry_result = await func(*args, **kwargs)
|
| 734 |
+
if isinstance(retry_result, tuple) and len(retry_result) == 2:
|
| 735 |
+
return retry_result
|
| 736 |
+
return True, retry_result
|
| 737 |
+
except Exception as retry_e:
|
| 738 |
+
return False, f"重试失败: {str(retry_e)}"
|
| 739 |
+
else:
|
| 740 |
+
return False, f"Token 刷新失败: {refresh_msg}"
|
| 741 |
+
|
| 742 |
+
return False, error_str
|
| 743 |
+
|
| 744 |
+
def set_accounts_getter(self, getter: Callable) -> None:
|
| 745 |
+
"""设置获取账号列表的回调函数
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
getter: 返回账号列表的可调用对象
|
| 749 |
+
"""
|
| 750 |
+
self._accounts_getter = getter
|
| 751 |
+
|
| 752 |
+
def _get_accounts(self) -> List['Account']:
|
| 753 |
+
"""获取账号列表"""
|
| 754 |
+
if self._accounts_getter:
|
| 755 |
+
return self._accounts_getter()
|
| 756 |
+
return []
|
| 757 |
+
|
| 758 |
+
async def start_auto_refresh(self) -> None:
|
| 759 |
+
"""启动自动 Token 刷新定时器
|
| 760 |
+
|
| 761 |
+
定期检查所有账号的 Token 状态,自动刷新即将过期的 Token。
|
| 762 |
+
启动前会清除已存在的定时器,防止重复启动。
|
| 763 |
+
"""
|
| 764 |
+
# 先停止已存在的定时器
|
| 765 |
+
await self.stop_auto_refresh()
|
| 766 |
+
|
| 767 |
+
self._auto_refresh_running = True
|
| 768 |
+
self._auto_refresh_task = asyncio.create_task(self._auto_refresh_loop())
|
| 769 |
+
print(f"[RefreshManager] 自动 Token 刷新定时器已启动,检查间隔: {self._config.auto_refresh_interval}秒")
|
| 770 |
+
|
| 771 |
+
async def stop_auto_refresh(self) -> None:
|
| 772 |
+
"""停止自动 Token 刷新定时器"""
|
| 773 |
+
self._auto_refresh_running = False
|
| 774 |
+
|
| 775 |
+
if self._auto_refresh_task:
|
| 776 |
+
self._auto_refresh_task.cancel()
|
| 777 |
+
try:
|
| 778 |
+
await self._auto_refresh_task
|
| 779 |
+
except asyncio.CancelledError:
|
| 780 |
+
pass
|
| 781 |
+
self._auto_refresh_task = None
|
| 782 |
+
print("[RefreshManager] 自动 Token 刷新定时器已停止")
|
| 783 |
+
|
| 784 |
+
def is_auto_refresh_running(self) -> bool:
|
| 785 |
+
"""检查自动刷新定时器是否在运行
|
| 786 |
+
|
| 787 |
+
Returns:
|
| 788 |
+
True 表示定时器正在运行
|
| 789 |
+
"""
|
| 790 |
+
return self._auto_refresh_running and self._auto_refresh_task is not None
|
| 791 |
+
|
| 792 |
+
async def _auto_refresh_loop(self) -> None:
|
| 793 |
+
"""自动刷新循环
|
| 794 |
+
|
| 795 |
+
定期检查所有账号的 Token 状态,刷新即将过期的 Token。
|
| 796 |
+
跳过已禁用或错误状态的账号,单个失败不影响其他账号。
|
| 797 |
+
"""
|
| 798 |
+
while self._auto_refresh_running:
|
| 799 |
+
try:
|
| 800 |
+
await asyncio.sleep(self._config.auto_refresh_interval)
|
| 801 |
+
|
| 802 |
+
if not self._auto_refresh_running:
|
| 803 |
+
break
|
| 804 |
+
|
| 805 |
+
accounts = self._get_accounts()
|
| 806 |
+
if not accounts:
|
| 807 |
+
continue
|
| 808 |
+
|
| 809 |
+
# 检查需要刷新的账号
|
| 810 |
+
accounts_to_refresh = []
|
| 811 |
+
for account in accounts:
|
| 812 |
+
# 跳过已禁用的账号
|
| 813 |
+
if not account.enabled:
|
| 814 |
+
continue
|
| 815 |
+
|
| 816 |
+
# 跳过错误状态的账号
|
| 817 |
+
if hasattr(account, 'status') and account.status.value in ("unhealthy", "suspended", "disabled"):
|
| 818 |
+
continue
|
| 819 |
+
|
| 820 |
+
# 检查是否需要刷新 Token
|
| 821 |
+
if self.should_refresh_token(account):
|
| 822 |
+
accounts_to_refresh.append(account)
|
| 823 |
+
|
| 824 |
+
if accounts_to_refresh:
|
| 825 |
+
print(f"[RefreshManager] 发现 {len(accounts_to_refresh)} 个账号需要刷新 Token")
|
| 826 |
+
|
| 827 |
+
# 逐个刷新,单个失败不影响其他
|
| 828 |
+
for account in accounts_to_refresh:
|
| 829 |
+
try:
|
| 830 |
+
success, message = await self.refresh_token_if_needed(account)
|
| 831 |
+
if not success:
|
| 832 |
+
print(f"[RefreshManager] 账号 {account.id} 自动刷新失败: {message}")
|
| 833 |
+
except Exception as e:
|
| 834 |
+
print(f"[RefreshManager] 账号 {account.id} 自动刷新异常: {e}")
|
| 835 |
+
# 继续处理其他账号
|
| 836 |
+
|
| 837 |
+
except asyncio.CancelledError:
|
| 838 |
+
break
|
| 839 |
+
except Exception as e:
|
| 840 |
+
print(f"[RefreshManager] 自动刷新循环异常: {e}")
|
| 841 |
+
# 继续运行,不因异常停止
|
| 842 |
+
|
| 843 |
+
def get_auto_refresh_status(self) -> Dict[str, Any]:
|
| 844 |
+
"""获取自动刷新状态
|
| 845 |
+
|
| 846 |
+
Returns:
|
| 847 |
+
包含自动刷新状态信息的字典
|
| 848 |
+
"""
|
| 849 |
+
return {
|
| 850 |
+
"running": self.is_auto_refresh_running(),
|
| 851 |
+
"interval": self._config.auto_refresh_interval,
|
| 852 |
+
"token_refresh_before_expiry": self._config.token_refresh_before_expiry
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
# 全局刷新管理器实例
|
| 857 |
+
_refresh_manager: Optional[RefreshManager] = None
|
| 858 |
+
_manager_lock = Lock()
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def get_refresh_manager() -> RefreshManager:
|
| 862 |
+
"""获取全局刷新管理器实例
|
| 863 |
+
|
| 864 |
+
使用单例模式,确保全局只有一个刷新管理器实例。
|
| 865 |
+
|
| 866 |
+
Returns:
|
| 867 |
+
全局 RefreshManager 实例
|
| 868 |
+
"""
|
| 869 |
+
global _refresh_manager
|
| 870 |
+
|
| 871 |
+
if _refresh_manager is None:
|
| 872 |
+
with _manager_lock:
|
| 873 |
+
# 双重检查锁定
|
| 874 |
+
if _refresh_manager is None:
|
| 875 |
+
_refresh_manager = RefreshManager()
|
| 876 |
+
|
| 877 |
+
return _refresh_manager
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
def reset_refresh_manager() -> None:
|
| 881 |
+
"""重置全局刷新管理器
|
| 882 |
+
|
| 883 |
+
主要用于测试场景,重置全局实例。
|
| 884 |
+
"""
|
| 885 |
+
global _refresh_manager
|
| 886 |
+
|
| 887 |
+
with _manager_lock:
|
| 888 |
+
_refresh_manager = None
|
KiroProxy/kiro_proxy/core/retry.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""请求重试机制"""
|
| 2 |
+
import asyncio
|
| 3 |
+
from typing import Callable, Any, Optional, Set
|
| 4 |
+
from functools import wraps
|
| 5 |
+
|
| 6 |
+
# 可重试的状态码
|
| 7 |
+
RETRYABLE_STATUS_CODES: Set[int] = {
|
| 8 |
+
408, # Request Timeout
|
| 9 |
+
500, # Internal Server Error
|
| 10 |
+
502, # Bad Gateway
|
| 11 |
+
503, # Service Unavailable
|
| 12 |
+
504, # Gateway Timeout
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
# 不可重试的状态码(直接返回错误)
|
| 16 |
+
NON_RETRYABLE_STATUS_CODES: Set[int] = {
|
| 17 |
+
400, # Bad Request
|
| 18 |
+
401, # Unauthorized
|
| 19 |
+
403, # Forbidden
|
| 20 |
+
404, # Not Found
|
| 21 |
+
422, # Unprocessable Entity
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_retryable_error(status_code: Optional[int], error: Optional[Exception] = None) -> bool:
|
| 26 |
+
"""判断是否为可重试的错误"""
|
| 27 |
+
# 网络错误可重试
|
| 28 |
+
if error:
|
| 29 |
+
error_name = type(error).__name__.lower()
|
| 30 |
+
if any(kw in error_name for kw in ['timeout', 'connect', 'network', 'reset']):
|
| 31 |
+
return True
|
| 32 |
+
|
| 33 |
+
# 特定状态码可重试
|
| 34 |
+
if status_code and status_code in RETRYABLE_STATUS_CODES:
|
| 35 |
+
return True
|
| 36 |
+
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def is_non_retryable_error(status_code: Optional[int]) -> bool:
|
| 41 |
+
"""判断是否为不可重试的错误"""
|
| 42 |
+
return status_code in NON_RETRYABLE_STATUS_CODES if status_code else False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
async def retry_async(
|
| 46 |
+
func: Callable,
|
| 47 |
+
max_retries: int = 2,
|
| 48 |
+
base_delay: float = 0.5,
|
| 49 |
+
max_delay: float = 5.0,
|
| 50 |
+
on_retry: Optional[Callable[[int, Exception], None]] = None
|
| 51 |
+
) -> Any:
|
| 52 |
+
"""
|
| 53 |
+
异步重试装饰器
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
func: 要执行的异步函数
|
| 57 |
+
max_retries: 最大重试次数
|
| 58 |
+
base_delay: 基础延迟(秒)
|
| 59 |
+
max_delay: 最大延迟(秒)
|
| 60 |
+
on_retry: 重试时的回调函数
|
| 61 |
+
"""
|
| 62 |
+
last_error = None
|
| 63 |
+
|
| 64 |
+
for attempt in range(max_retries + 1):
|
| 65 |
+
try:
|
| 66 |
+
return await func()
|
| 67 |
+
except Exception as e:
|
| 68 |
+
last_error = e
|
| 69 |
+
|
| 70 |
+
# 检查是否可重试
|
| 71 |
+
status_code = getattr(e, 'status_code', None)
|
| 72 |
+
if is_non_retryable_error(status_code):
|
| 73 |
+
raise
|
| 74 |
+
|
| 75 |
+
if attempt < max_retries and is_retryable_error(status_code, e):
|
| 76 |
+
# 指数退避
|
| 77 |
+
delay = min(base_delay * (2 ** attempt), max_delay)
|
| 78 |
+
|
| 79 |
+
if on_retry:
|
| 80 |
+
on_retry(attempt + 1, e)
|
| 81 |
+
else:
|
| 82 |
+
print(f"[Retry] 第 {attempt + 1} 次重试,延迟 {delay:.1f}s,错误: {type(e).__name__}")
|
| 83 |
+
|
| 84 |
+
await asyncio.sleep(delay)
|
| 85 |
+
else:
|
| 86 |
+
raise
|
| 87 |
+
|
| 88 |
+
raise last_error
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class RetryableRequest:
|
| 92 |
+
"""可重试的请求上下文"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, max_retries: int = 2, base_delay: float = 0.5):
|
| 95 |
+
self.max_retries = max_retries
|
| 96 |
+
self.base_delay = base_delay
|
| 97 |
+
self.attempt = 0
|
| 98 |
+
self.last_error = None
|
| 99 |
+
|
| 100 |
+
def should_retry(self, status_code: Optional[int] = None, error: Optional[Exception] = None) -> bool:
|
| 101 |
+
"""判断是否应该重试"""
|
| 102 |
+
self.attempt += 1
|
| 103 |
+
self.last_error = error
|
| 104 |
+
|
| 105 |
+
if self.attempt > self.max_retries:
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
if is_non_retryable_error(status_code):
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
return is_retryable_error(status_code, error)
|
| 112 |
+
|
| 113 |
+
async def wait(self):
|
| 114 |
+
"""等待重试延迟"""
|
| 115 |
+
delay = min(self.base_delay * (2 ** (self.attempt - 1)), 5.0)
|
| 116 |
+
print(f"[Retry] 第 {self.attempt} 次重试,延迟 {delay:.1f}s")
|
| 117 |
+
await asyncio.sleep(delay)
|
KiroProxy/kiro_proxy/core/scheduler.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""后台任务调度器"""
|
| 2 |
+
import asyncio
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BackgroundScheduler:
|
| 8 |
+
"""后台任务调度器
|
| 9 |
+
|
| 10 |
+
负责:
|
| 11 |
+
- Token 过期预刷新
|
| 12 |
+
- 账号健康检查
|
| 13 |
+
- 统计数据更新
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self._task: Optional[asyncio.Task] = None
|
| 18 |
+
self._running = False
|
| 19 |
+
self._refresh_interval = 300 # 5 分钟检查一次
|
| 20 |
+
self._health_check_interval = 600 # 10 分钟健康检查
|
| 21 |
+
self._last_health_check = 0
|
| 22 |
+
|
| 23 |
+
async def start(self):
|
| 24 |
+
"""启动后台任务"""
|
| 25 |
+
if self._running:
|
| 26 |
+
return
|
| 27 |
+
self._running = True
|
| 28 |
+
self._task = asyncio.create_task(self._run())
|
| 29 |
+
print("[Scheduler] 后台任务已启动")
|
| 30 |
+
|
| 31 |
+
async def stop(self):
|
| 32 |
+
"""停止后台任务"""
|
| 33 |
+
self._running = False
|
| 34 |
+
if self._task:
|
| 35 |
+
self._task.cancel()
|
| 36 |
+
try:
|
| 37 |
+
await self._task
|
| 38 |
+
except asyncio.CancelledError:
|
| 39 |
+
pass
|
| 40 |
+
print("[Scheduler] 后台任务已停止")
|
| 41 |
+
|
| 42 |
+
async def _run(self):
|
| 43 |
+
"""主循环"""
|
| 44 |
+
from . import state
|
| 45 |
+
import time
|
| 46 |
+
|
| 47 |
+
while self._running:
|
| 48 |
+
try:
|
| 49 |
+
# Token 预刷新
|
| 50 |
+
await self._refresh_expiring_tokens(state)
|
| 51 |
+
|
| 52 |
+
# 健康检查
|
| 53 |
+
now = time.time()
|
| 54 |
+
if now - self._last_health_check > self._health_check_interval:
|
| 55 |
+
await self._health_check(state)
|
| 56 |
+
self._last_health_check = now
|
| 57 |
+
|
| 58 |
+
await asyncio.sleep(self._refresh_interval)
|
| 59 |
+
|
| 60 |
+
except asyncio.CancelledError:
|
| 61 |
+
break
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"[Scheduler] 错误: {e}")
|
| 64 |
+
await asyncio.sleep(60)
|
| 65 |
+
|
| 66 |
+
async def _refresh_expiring_tokens(self, state):
|
| 67 |
+
"""刷新即将过期的 Token"""
|
| 68 |
+
for acc in state.accounts:
|
| 69 |
+
if not acc.enabled:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
# 提前 15 分钟刷新
|
| 73 |
+
if acc.is_token_expiring_soon(15):
|
| 74 |
+
print(f"[Scheduler] Token 即将过期,预刷新: {acc.name}")
|
| 75 |
+
success, msg = await acc.refresh_token()
|
| 76 |
+
if success:
|
| 77 |
+
print(f"[Scheduler] Token 刷新成功: {acc.name}")
|
| 78 |
+
else:
|
| 79 |
+
print(f"[Scheduler] Token 刷新失败: {acc.name} - {msg}")
|
| 80 |
+
|
| 81 |
+
async def _health_check(self, state):
|
| 82 |
+
"""健康检查"""
|
| 83 |
+
import httpx
|
| 84 |
+
from ..config import MODELS_URL
|
| 85 |
+
from ..credential import CredentialStatus
|
| 86 |
+
|
| 87 |
+
for acc in state.accounts:
|
| 88 |
+
if not acc.enabled:
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
token = acc.get_token()
|
| 93 |
+
if not token:
|
| 94 |
+
acc.status = CredentialStatus.UNHEALTHY
|
| 95 |
+
continue
|
| 96 |
+
|
| 97 |
+
headers = {
|
| 98 |
+
"Authorization": f"Bearer {token}",
|
| 99 |
+
"content-type": "application/json"
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
async with httpx.AsyncClient(verify=False, timeout=10) as client:
|
| 103 |
+
resp = await client.get(
|
| 104 |
+
MODELS_URL,
|
| 105 |
+
headers=headers,
|
| 106 |
+
params={"origin": "AI_EDITOR"}
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if resp.status_code == 200:
|
| 110 |
+
if acc.status == CredentialStatus.UNHEALTHY:
|
| 111 |
+
acc.status = CredentialStatus.ACTIVE
|
| 112 |
+
print(f"[HealthCheck] 账号恢复健康: {acc.name}")
|
| 113 |
+
elif resp.status_code == 401:
|
| 114 |
+
acc.status = CredentialStatus.UNHEALTHY
|
| 115 |
+
print(f"[HealthCheck] 账号认证失败: {acc.name}")
|
| 116 |
+
elif resp.status_code == 429:
|
| 117 |
+
# 配额超限,不改变状态
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"[HealthCheck] 检查失败 {acc.name}: {e}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# 全局调度器实例
|
| 125 |
+
scheduler = BackgroundScheduler()
|
KiroProxy/kiro_proxy/core/state.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""全局状态管理"""
|
| 2 |
+
import time
|
| 3 |
+
from collections import deque
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, List, Dict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from ..config import TOKEN_PATH
|
| 9 |
+
from ..credential import quota_manager, CredentialStatus
|
| 10 |
+
from .account import Account
|
| 11 |
+
from .persistence import load_accounts, save_accounts
|
| 12 |
+
from .quota_cache import get_quota_cache
|
| 13 |
+
from .account_selector import get_account_selector, SelectionStrategy
|
| 14 |
+
from .quota_scheduler import get_quota_scheduler
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class RequestLog:
|
| 19 |
+
"""请求日志"""
|
| 20 |
+
id: str
|
| 21 |
+
timestamp: float
|
| 22 |
+
method: str
|
| 23 |
+
path: str
|
| 24 |
+
model: str
|
| 25 |
+
account_id: Optional[str]
|
| 26 |
+
status: int
|
| 27 |
+
duration_ms: float
|
| 28 |
+
tokens_in: int = 0
|
| 29 |
+
tokens_out: int = 0
|
| 30 |
+
error: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ProxyState:
|
| 34 |
+
"""全局状态管理"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.accounts: List[Account] = []
|
| 38 |
+
self.request_logs: deque = deque(maxlen=1000)
|
| 39 |
+
self.total_requests: int = 0
|
| 40 |
+
self.total_errors: int = 0
|
| 41 |
+
self.session_locks: Dict[str, str] = {}
|
| 42 |
+
self.session_timestamps: Dict[str, float] = {}
|
| 43 |
+
self.start_time: float = time.time()
|
| 44 |
+
self._load_accounts()
|
| 45 |
+
|
| 46 |
+
def _load_accounts(self):
|
| 47 |
+
"""从配置文件加载账号"""
|
| 48 |
+
saved = load_accounts()
|
| 49 |
+
if saved:
|
| 50 |
+
for acc_data in saved:
|
| 51 |
+
# 验证 token 文件存在
|
| 52 |
+
if Path(acc_data.get("token_path", "")).exists():
|
| 53 |
+
self.accounts.append(Account(
|
| 54 |
+
id=acc_data["id"],
|
| 55 |
+
name=acc_data["name"],
|
| 56 |
+
token_path=acc_data["token_path"],
|
| 57 |
+
enabled=acc_data.get("enabled", True),
|
| 58 |
+
auto_disabled=acc_data.get("auto_disabled", False),
|
| 59 |
+
))
|
| 60 |
+
print(f"[State] 从配置加载 {len(self.accounts)} 个账号")
|
| 61 |
+
|
| 62 |
+
# 如果没有账号,尝试添加默认账号
|
| 63 |
+
if not self.accounts and TOKEN_PATH.exists():
|
| 64 |
+
self.accounts.append(Account(
|
| 65 |
+
id="default",
|
| 66 |
+
name="默认账号",
|
| 67 |
+
token_path=str(TOKEN_PATH)
|
| 68 |
+
))
|
| 69 |
+
self._save_accounts()
|
| 70 |
+
|
| 71 |
+
def _save_accounts(self):
|
| 72 |
+
"""保存账号到配置文件"""
|
| 73 |
+
accounts_data = [
|
| 74 |
+
{
|
| 75 |
+
"id": acc.id,
|
| 76 |
+
"name": acc.name,
|
| 77 |
+
"token_path": acc.token_path,
|
| 78 |
+
"enabled": acc.enabled,
|
| 79 |
+
"auto_disabled": getattr(acc, "auto_disabled", False),
|
| 80 |
+
}
|
| 81 |
+
for acc in self.accounts
|
| 82 |
+
]
|
| 83 |
+
save_accounts(accounts_data)
|
| 84 |
+
|
| 85 |
+
def get_available_account(self, session_id: Optional[str] = None) -> Optional[Account]:
|
| 86 |
+
"""获取可用账号(支持会话粘性和智能选择)"""
|
| 87 |
+
quota_manager.cleanup_expired()
|
| 88 |
+
|
| 89 |
+
selector = get_account_selector()
|
| 90 |
+
has_priority = bool(selector.get_priority_accounts())
|
| 91 |
+
use_session_sticky = bool(session_id) and not has_priority and selector.strategy != SelectionStrategy.RANDOM
|
| 92 |
+
|
| 93 |
+
# 会话粘性
|
| 94 |
+
if use_session_sticky and session_id in self.session_locks:
|
| 95 |
+
account_id = self.session_locks[session_id]
|
| 96 |
+
ts = self.session_timestamps.get(session_id, 0)
|
| 97 |
+
if time.time() - ts < 60:
|
| 98 |
+
for acc in self.accounts:
|
| 99 |
+
if acc.id == account_id and acc.is_available():
|
| 100 |
+
self.session_timestamps[session_id] = time.time()
|
| 101 |
+
return acc
|
| 102 |
+
|
| 103 |
+
# 使用 AccountSelector 选择账号
|
| 104 |
+
account = selector.select(self.accounts, session_id)
|
| 105 |
+
|
| 106 |
+
if account and use_session_sticky:
|
| 107 |
+
self.session_locks[session_id] = account.id
|
| 108 |
+
self.session_timestamps[session_id] = time.time()
|
| 109 |
+
|
| 110 |
+
# 标记为活跃账号,便于额度调度器定期更新
|
| 111 |
+
if account:
|
| 112 |
+
try:
|
| 113 |
+
get_quota_scheduler().mark_active(account.id)
|
| 114 |
+
except Exception:
|
| 115 |
+
pass
|
| 116 |
+
|
| 117 |
+
return account
|
| 118 |
+
|
| 119 |
+
def mark_account_used(self, account_id: str) -> None:
|
| 120 |
+
"""标记账号被使用"""
|
| 121 |
+
scheduler = get_quota_scheduler()
|
| 122 |
+
scheduler.mark_active(account_id)
|
| 123 |
+
|
| 124 |
+
for acc in self.accounts:
|
| 125 |
+
if acc.id == account_id:
|
| 126 |
+
acc.last_used = time.time()
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
def get_next_available_account(self, exclude_id: str) -> Optional[Account]:
|
| 130 |
+
"""获取下一个可用账号(排除指定账号)"""
|
| 131 |
+
available = [a for a in self.accounts if a.is_available() and a.id != exclude_id]
|
| 132 |
+
if not available:
|
| 133 |
+
return None
|
| 134 |
+
account = min(available, key=lambda a: a.request_count)
|
| 135 |
+
try:
|
| 136 |
+
get_quota_scheduler().mark_active(account.id)
|
| 137 |
+
except Exception:
|
| 138 |
+
pass
|
| 139 |
+
return account
|
| 140 |
+
|
| 141 |
+
def mark_rate_limited(self, account_id: str, duration_seconds: int = 60):
|
| 142 |
+
"""标记账号限流"""
|
| 143 |
+
for acc in self.accounts:
|
| 144 |
+
if acc.id == account_id:
|
| 145 |
+
acc.mark_quota_exceeded("Rate limited")
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
def mark_quota_exceeded(self, account_id: str, reason: str = "Quota exceeded"):
|
| 149 |
+
"""标记账号配额超限"""
|
| 150 |
+
for acc in self.accounts:
|
| 151 |
+
if acc.id == account_id:
|
| 152 |
+
acc.mark_quota_exceeded(reason)
|
| 153 |
+
break
|
| 154 |
+
|
| 155 |
+
async def refresh_account_token(self, account_id: str) -> tuple:
|
| 156 |
+
"""刷新指定账号的 token"""
|
| 157 |
+
for acc in self.accounts:
|
| 158 |
+
if acc.id == account_id:
|
| 159 |
+
return await acc.refresh_token()
|
| 160 |
+
return False, "账号不存在"
|
| 161 |
+
|
| 162 |
+
async def refresh_expiring_tokens(self) -> List[dict]:
|
| 163 |
+
"""刷新所有即将过期的 token"""
|
| 164 |
+
results = []
|
| 165 |
+
for acc in self.accounts:
|
| 166 |
+
if acc.enabled and acc.is_token_expiring_soon(10):
|
| 167 |
+
success, msg = await acc.refresh_token()
|
| 168 |
+
results.append({
|
| 169 |
+
"account_id": acc.id,
|
| 170 |
+
"success": success,
|
| 171 |
+
"message": msg
|
| 172 |
+
})
|
| 173 |
+
return results
|
| 174 |
+
|
| 175 |
+
def add_log(self, log: RequestLog):
|
| 176 |
+
"""添加请求日志"""
|
| 177 |
+
self.request_logs.append(log)
|
| 178 |
+
self.total_requests += 1
|
| 179 |
+
if log.error:
|
| 180 |
+
self.total_errors += 1
|
| 181 |
+
|
| 182 |
+
def get_stats(self) -> dict:
|
| 183 |
+
"""获取统计信息"""
|
| 184 |
+
uptime = time.time() - self.start_time
|
| 185 |
+
|
| 186 |
+
# 获取额度汇总
|
| 187 |
+
quota_cache = get_quota_cache()
|
| 188 |
+
quota_summary = quota_cache.get_summary()
|
| 189 |
+
|
| 190 |
+
# 获取选择器状态
|
| 191 |
+
selector = get_account_selector()
|
| 192 |
+
selector_status = selector.get_status()
|
| 193 |
+
|
| 194 |
+
# 获取调度器状态
|
| 195 |
+
scheduler = get_quota_scheduler()
|
| 196 |
+
scheduler_status = scheduler.get_status()
|
| 197 |
+
|
| 198 |
+
return {
|
| 199 |
+
"uptime_seconds": int(uptime),
|
| 200 |
+
"total_requests": self.total_requests,
|
| 201 |
+
"total_errors": self.total_errors,
|
| 202 |
+
"error_rate": f"{(self.total_errors / max(1, self.total_requests) * 100):.1f}%",
|
| 203 |
+
"accounts_total": len(self.accounts),
|
| 204 |
+
"accounts_available": len([a for a in self.accounts if a.is_available()]),
|
| 205 |
+
"accounts_cooldown": len([a for a in self.accounts if a.status == CredentialStatus.COOLDOWN]),
|
| 206 |
+
"recent_logs": len(self.request_logs),
|
| 207 |
+
# 新增字段
|
| 208 |
+
"quota_summary": quota_summary,
|
| 209 |
+
"selector": selector_status,
|
| 210 |
+
"scheduler": scheduler_status,
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
def get_accounts_status(self) -> List[dict]:
|
| 214 |
+
"""获取所有账号状态"""
|
| 215 |
+
return [acc.get_status_info() for acc in self.accounts]
|
| 216 |
+
|
| 217 |
+
def get_accounts_summary(self) -> dict:
|
| 218 |
+
"""获取账号汇总统计"""
|
| 219 |
+
quota_cache = get_quota_cache()
|
| 220 |
+
selector = get_account_selector()
|
| 221 |
+
scheduler = get_quota_scheduler()
|
| 222 |
+
|
| 223 |
+
total_balance = 0.0
|
| 224 |
+
total_usage = 0.0
|
| 225 |
+
total_limit = 0.0
|
| 226 |
+
|
| 227 |
+
available_count = 0
|
| 228 |
+
cooldown_count = 0
|
| 229 |
+
unhealthy_count = 0
|
| 230 |
+
disabled_count = 0
|
| 231 |
+
|
| 232 |
+
for acc in self.accounts:
|
| 233 |
+
if not acc.enabled:
|
| 234 |
+
disabled_count += 1
|
| 235 |
+
elif acc.status == CredentialStatus.COOLDOWN:
|
| 236 |
+
cooldown_count += 1
|
| 237 |
+
elif acc.status == CredentialStatus.UNHEALTHY:
|
| 238 |
+
unhealthy_count += 1
|
| 239 |
+
elif acc.is_available():
|
| 240 |
+
available_count += 1
|
| 241 |
+
|
| 242 |
+
quota = quota_cache.get(acc.id)
|
| 243 |
+
if quota and not quota.has_error():
|
| 244 |
+
total_balance += quota.balance
|
| 245 |
+
total_usage += quota.current_usage
|
| 246 |
+
total_limit += quota.usage_limit
|
| 247 |
+
|
| 248 |
+
last_refresh = scheduler.get_last_full_refresh()
|
| 249 |
+
last_refresh_ago = None
|
| 250 |
+
if last_refresh:
|
| 251 |
+
seconds_ago = time.time() - last_refresh
|
| 252 |
+
if seconds_ago < 60:
|
| 253 |
+
last_refresh_ago = f"{int(seconds_ago)}秒前"
|
| 254 |
+
elif seconds_ago < 3600:
|
| 255 |
+
last_refresh_ago = f"{int(seconds_ago / 60)}分钟前"
|
| 256 |
+
else:
|
| 257 |
+
last_refresh_ago = f"{int(seconds_ago / 3600)}小时前"
|
| 258 |
+
|
| 259 |
+
return {
|
| 260 |
+
"total_accounts": len(self.accounts),
|
| 261 |
+
"available_accounts": available_count,
|
| 262 |
+
"cooldown_accounts": cooldown_count,
|
| 263 |
+
"unhealthy_accounts": unhealthy_count,
|
| 264 |
+
"disabled_accounts": disabled_count,
|
| 265 |
+
"total_balance": round(total_balance, 2),
|
| 266 |
+
"total_usage": round(total_usage, 2),
|
| 267 |
+
"total_limit": round(total_limit, 2),
|
| 268 |
+
"last_refresh": last_refresh_ago,
|
| 269 |
+
"last_refresh_timestamp": last_refresh,
|
| 270 |
+
"strategy": selector.strategy.value,
|
| 271 |
+
"priority_accounts": selector.get_priority_accounts(),
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
def get_valid_account_ids(self) -> set:
|
| 275 |
+
"""获取所有有效账号ID集合"""
|
| 276 |
+
return {acc.id for acc in self.accounts if acc.enabled}
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# 全局状态实例
|
| 280 |
+
state = ProxyState()
|
KiroProxy/kiro_proxy/core/stats.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""请求统计增强"""
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class AccountStats:
|
| 10 |
+
"""账号统计"""
|
| 11 |
+
total_requests: int = 0
|
| 12 |
+
total_errors: int = 0
|
| 13 |
+
total_tokens_in: int = 0
|
| 14 |
+
total_tokens_out: int = 0
|
| 15 |
+
last_request_time: float = 0
|
| 16 |
+
|
| 17 |
+
def record(self, success: bool, tokens_in: int = 0, tokens_out: int = 0):
|
| 18 |
+
self.total_requests += 1
|
| 19 |
+
if not success:
|
| 20 |
+
self.total_errors += 1
|
| 21 |
+
self.total_tokens_in += tokens_in
|
| 22 |
+
self.total_tokens_out += tokens_out
|
| 23 |
+
self.last_request_time = time.time()
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def error_rate(self) -> float:
|
| 27 |
+
if self.total_requests == 0:
|
| 28 |
+
return 0
|
| 29 |
+
return self.total_errors / self.total_requests
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class ModelStats:
|
| 34 |
+
"""模型统计"""
|
| 35 |
+
total_requests: int = 0
|
| 36 |
+
total_errors: int = 0
|
| 37 |
+
total_latency_ms: float = 0
|
| 38 |
+
|
| 39 |
+
def record(self, success: bool, latency_ms: float):
|
| 40 |
+
self.total_requests += 1
|
| 41 |
+
if not success:
|
| 42 |
+
self.total_errors += 1
|
| 43 |
+
self.total_latency_ms += latency_ms
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def avg_latency_ms(self) -> float:
|
| 47 |
+
if self.total_requests == 0:
|
| 48 |
+
return 0
|
| 49 |
+
return self.total_latency_ms / self.total_requests
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class StatsManager:
|
| 53 |
+
"""统计管理器"""
|
| 54 |
+
|
| 55 |
+
def __init__(self):
|
| 56 |
+
self.by_account: Dict[str, AccountStats] = defaultdict(AccountStats)
|
| 57 |
+
self.by_model: Dict[str, ModelStats] = defaultdict(ModelStats)
|
| 58 |
+
self.hourly_requests: Dict[int, int] = defaultdict(int) # hour -> count
|
| 59 |
+
|
| 60 |
+
def record_request(
|
| 61 |
+
self,
|
| 62 |
+
account_id: str,
|
| 63 |
+
model: str,
|
| 64 |
+
success: bool,
|
| 65 |
+
latency_ms: float,
|
| 66 |
+
tokens_in: int = 0,
|
| 67 |
+
tokens_out: int = 0
|
| 68 |
+
):
|
| 69 |
+
"""记录请求"""
|
| 70 |
+
# 按账号统计
|
| 71 |
+
self.by_account[account_id].record(success, tokens_in, tokens_out)
|
| 72 |
+
|
| 73 |
+
# 按模型统计
|
| 74 |
+
self.by_model[model].record(success, latency_ms)
|
| 75 |
+
|
| 76 |
+
# 按小时统计
|
| 77 |
+
hour = int(time.time() // 3600)
|
| 78 |
+
self.hourly_requests[hour] += 1
|
| 79 |
+
|
| 80 |
+
# 清理旧数据(保留 24 小时)
|
| 81 |
+
self._cleanup_hourly()
|
| 82 |
+
|
| 83 |
+
def _cleanup_hourly(self):
|
| 84 |
+
"""清理超过 24 小时的数据"""
|
| 85 |
+
current_hour = int(time.time() // 3600)
|
| 86 |
+
cutoff = current_hour - 24
|
| 87 |
+
self.hourly_requests = defaultdict(
|
| 88 |
+
int,
|
| 89 |
+
{h: c for h, c in self.hourly_requests.items() if h > cutoff}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def get_account_stats(self, account_id: str) -> dict:
|
| 93 |
+
"""获取账号统计"""
|
| 94 |
+
stats = self.by_account.get(account_id, AccountStats())
|
| 95 |
+
return {
|
| 96 |
+
"total_requests": stats.total_requests,
|
| 97 |
+
"total_errors": stats.total_errors,
|
| 98 |
+
"error_rate": f"{stats.error_rate * 100:.1f}%",
|
| 99 |
+
"total_tokens_in": stats.total_tokens_in,
|
| 100 |
+
"total_tokens_out": stats.total_tokens_out,
|
| 101 |
+
"last_request": stats.last_request_time
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
def get_model_stats(self, model: str) -> dict:
|
| 105 |
+
"""获取模型统计"""
|
| 106 |
+
stats = self.by_model.get(model, ModelStats())
|
| 107 |
+
return {
|
| 108 |
+
"total_requests": stats.total_requests,
|
| 109 |
+
"total_errors": stats.total_errors,
|
| 110 |
+
"avg_latency_ms": round(stats.avg_latency_ms, 2)
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
def get_all_stats(self) -> dict:
|
| 114 |
+
"""获取所有统计"""
|
| 115 |
+
return {
|
| 116 |
+
"by_account": {
|
| 117 |
+
acc_id: self.get_account_stats(acc_id)
|
| 118 |
+
for acc_id in self.by_account
|
| 119 |
+
},
|
| 120 |
+
"by_model": {
|
| 121 |
+
model: self.get_model_stats(model)
|
| 122 |
+
for model in self.by_model
|
| 123 |
+
},
|
| 124 |
+
"hourly_requests": dict(self.hourly_requests),
|
| 125 |
+
"requests_last_24h": sum(self.hourly_requests.values())
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# 全局统计实例
|
| 130 |
+
stats_manager = StatsManager()
|
KiroProxy/kiro_proxy/core/thinking.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Thinking / Extended Thinking helpers.
|
| 2 |
+
|
| 3 |
+
This project implements "thinking" at the proxy layer by:
|
| 4 |
+
1) Making a separate Kiro request to generate internal reasoning text.
|
| 5 |
+
2) Injecting that reasoning back into the main user prompt (hidden) to improve quality.
|
| 6 |
+
3) Optionally returning the reasoning to clients in protocol-appropriate formats.
|
| 7 |
+
|
| 8 |
+
Notes:
|
| 9 |
+
- Kiro's upstream API doesn't expose a native "thinking budget" knob, so `budget_tokens`
|
| 10 |
+
is enforced only via prompt instructions (best-effort).
|
| 11 |
+
- If the client does not provide a budget, we treat it as "unlimited" (no prompt limit).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Any, AsyncIterator, Optional
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
|
| 21 |
+
import httpx
|
| 22 |
+
|
| 23 |
+
from ..config import KIRO_API_URL
|
| 24 |
+
from ..kiro_api import build_kiro_request, parse_event_stream
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class ThinkingConfig:
|
| 29 |
+
enabled: bool
|
| 30 |
+
budget_tokens: Optional[int] = None # None == unlimited
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _coerce_bool(value: Any) -> Optional[bool]:
|
| 34 |
+
if isinstance(value, bool):
|
| 35 |
+
return value
|
| 36 |
+
if isinstance(value, (int, float)):
|
| 37 |
+
return bool(value)
|
| 38 |
+
if isinstance(value, str):
|
| 39 |
+
v = value.strip().lower()
|
| 40 |
+
if v in {"true", "1", "yes", "y", "on", "enabled"}:
|
| 41 |
+
return True
|
| 42 |
+
if v in {"false", "0", "no", "n", "off", "disabled"}:
|
| 43 |
+
return False
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _coerce_int(value: Any) -> Optional[int]:
|
| 48 |
+
if value is None:
|
| 49 |
+
return None
|
| 50 |
+
if isinstance(value, bool):
|
| 51 |
+
return None
|
| 52 |
+
if isinstance(value, int):
|
| 53 |
+
return value
|
| 54 |
+
if isinstance(value, float):
|
| 55 |
+
return int(value)
|
| 56 |
+
if isinstance(value, str):
|
| 57 |
+
v = value.strip()
|
| 58 |
+
if not v:
|
| 59 |
+
return None
|
| 60 |
+
try:
|
| 61 |
+
return int(v)
|
| 62 |
+
except ValueError:
|
| 63 |
+
return None
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_thinking_config(raw: Any) -> ThinkingConfig:
|
| 68 |
+
"""Normalize multiple "thinking" shapes into a single config.
|
| 69 |
+
|
| 70 |
+
Supported shapes (best-effort):
|
| 71 |
+
- None / missing: disabled
|
| 72 |
+
- bool: enabled/disabled
|
| 73 |
+
- str: "enabled"/"disabled"
|
| 74 |
+
- dict:
|
| 75 |
+
- {"type": "enabled", "budget_tokens": 20000} (Anthropic style)
|
| 76 |
+
- {"thinking_type": "enabled", "budget_tokens": 20000} (legacy)
|
| 77 |
+
- {"enabled": true, "budget_tokens": 20000}
|
| 78 |
+
- {"includeThoughts": true, "thinkingBudget": 20000} (Gemini-ish)
|
| 79 |
+
"""
|
| 80 |
+
if raw is None:
|
| 81 |
+
return ThinkingConfig(enabled=False, budget_tokens=None)
|
| 82 |
+
|
| 83 |
+
bool_value = _coerce_bool(raw)
|
| 84 |
+
if bool_value is not None and not isinstance(raw, dict):
|
| 85 |
+
return ThinkingConfig(enabled=bool_value, budget_tokens=None)
|
| 86 |
+
|
| 87 |
+
if isinstance(raw, dict):
|
| 88 |
+
mode = raw.get("type") or raw.get("thinking_type") or raw.get("mode")
|
| 89 |
+
enabled = None
|
| 90 |
+
if isinstance(mode, str):
|
| 91 |
+
enabled = _coerce_bool(mode)
|
| 92 |
+
if enabled is None:
|
| 93 |
+
enabled = _coerce_bool(raw.get("enabled"))
|
| 94 |
+
if enabled is None:
|
| 95 |
+
enabled = _coerce_bool(raw.get("includeThoughts") or raw.get("include_thoughts"))
|
| 96 |
+
if enabled is None:
|
| 97 |
+
enabled = False
|
| 98 |
+
|
| 99 |
+
budget_tokens = None
|
| 100 |
+
for key in (
|
| 101 |
+
"budget_tokens",
|
| 102 |
+
"budgetTokens",
|
| 103 |
+
"thinkingBudget",
|
| 104 |
+
"thinking_budget",
|
| 105 |
+
"max_thinking_length",
|
| 106 |
+
"maxThinkingLength",
|
| 107 |
+
):
|
| 108 |
+
if key in raw:
|
| 109 |
+
budget_tokens = _coerce_int(raw.get(key))
|
| 110 |
+
break
|
| 111 |
+
if budget_tokens is not None and budget_tokens <= 0:
|
| 112 |
+
budget_tokens = None
|
| 113 |
+
|
| 114 |
+
return ThinkingConfig(enabled=bool(enabled), budget_tokens=budget_tokens)
|
| 115 |
+
|
| 116 |
+
if isinstance(raw, str):
|
| 117 |
+
enabled = _coerce_bool(raw)
|
| 118 |
+
return ThinkingConfig(enabled=bool(enabled), budget_tokens=None)
|
| 119 |
+
|
| 120 |
+
return ThinkingConfig(enabled=False, budget_tokens=None)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def map_openai_reasoning_effort_to_budget(effort: Any) -> Optional[int]:
|
| 124 |
+
"""Map OpenAI-style reasoning effort into a best-effort budget.
|
| 125 |
+
|
| 126 |
+
We keep this generous; if effort is "high", treat as unlimited.
|
| 127 |
+
"""
|
| 128 |
+
if not isinstance(effort, str):
|
| 129 |
+
return None
|
| 130 |
+
v = effort.strip().lower()
|
| 131 |
+
if v in {"high"}:
|
| 132 |
+
return None
|
| 133 |
+
if v in {"medium"}:
|
| 134 |
+
return 20000
|
| 135 |
+
if v in {"low"}:
|
| 136 |
+
return 10000
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def extract_thinking_config_from_openai_body(body: dict) -> tuple[ThinkingConfig, bool]:
|
| 141 |
+
"""Extract thinking config from OpenAI ChatCompletions/Responses-style bodies."""
|
| 142 |
+
if not isinstance(body, dict):
|
| 143 |
+
return ThinkingConfig(False, None), False
|
| 144 |
+
|
| 145 |
+
if "thinking" in body:
|
| 146 |
+
return normalize_thinking_config(body.get("thinking")), True
|
| 147 |
+
|
| 148 |
+
# OpenAI Responses API style
|
| 149 |
+
reasoning = body.get("reasoning")
|
| 150 |
+
if "reasoning" in body:
|
| 151 |
+
if isinstance(reasoning, dict):
|
| 152 |
+
effort = reasoning.get("effort")
|
| 153 |
+
if isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}:
|
| 154 |
+
return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True
|
| 155 |
+
cfg = normalize_thinking_config(reasoning)
|
| 156 |
+
return cfg, True
|
| 157 |
+
|
| 158 |
+
effort = body.get("reasoning_effort")
|
| 159 |
+
if "reasoning_effort" in body and isinstance(effort, str) and effort.strip().lower() in {"low", "medium", "high"}:
|
| 160 |
+
return ThinkingConfig(True, map_openai_reasoning_effort_to_budget(effort)), True
|
| 161 |
+
|
| 162 |
+
return ThinkingConfig(False, None), False
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def extract_thinking_config_from_gemini_body(body: dict) -> tuple[ThinkingConfig, bool]:
|
| 166 |
+
"""Extract thinking config from Gemini generateContent bodies (best-effort)."""
|
| 167 |
+
if not isinstance(body, dict):
|
| 168 |
+
return ThinkingConfig(False, None), False
|
| 169 |
+
|
| 170 |
+
if "thinking" in body:
|
| 171 |
+
return normalize_thinking_config(body.get("thinking")), True
|
| 172 |
+
|
| 173 |
+
if "thinkingConfig" in body:
|
| 174 |
+
return normalize_thinking_config(body.get("thinkingConfig")), True
|
| 175 |
+
|
| 176 |
+
gen_cfg = body.get("generationConfig")
|
| 177 |
+
if isinstance(gen_cfg, dict):
|
| 178 |
+
if "thinkingConfig" in gen_cfg:
|
| 179 |
+
raw = gen_cfg.get("thinkingConfig")
|
| 180 |
+
cfg = normalize_thinking_config(raw)
|
| 181 |
+
if cfg.enabled:
|
| 182 |
+
return cfg, True
|
| 183 |
+
# Budget without explicit includeThoughts/mode: treat as enabled (client guidance exists)
|
| 184 |
+
if isinstance(raw, dict) and any(
|
| 185 |
+
k in raw for k in ("thinkingBudget", "budgetTokens", "budget_tokens", "max_thinking_length")
|
| 186 |
+
):
|
| 187 |
+
return ThinkingConfig(True, cfg.budget_tokens), True
|
| 188 |
+
return cfg, True
|
| 189 |
+
|
| 190 |
+
return ThinkingConfig(False, None), False
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def infer_thinking_from_anthropic_messages(messages: list[dict]) -> bool:
|
| 194 |
+
"""推断历史消息中是否包含思维链内容,用于在客户端未明确指定时自动启用思维链"""
|
| 195 |
+
for msg in messages or []:
|
| 196 |
+
content = msg.get("content")
|
| 197 |
+
if not isinstance(content, list):
|
| 198 |
+
continue
|
| 199 |
+
for block in content:
|
| 200 |
+
if isinstance(block, dict):
|
| 201 |
+
# 检查标准的 thinking 块
|
| 202 |
+
if block.get("type") == "thinking":
|
| 203 |
+
return True
|
| 204 |
+
# 检查文本块中嵌入的 <thinking> 标签(assistant 消息中可能存在)
|
| 205 |
+
if block.get("type") == "text" and msg.get("role") == "assistant":
|
| 206 |
+
text = block.get("text", "")
|
| 207 |
+
if isinstance(text, str) and "<thinking>" in text and "</thinking>" in text:
|
| 208 |
+
return True
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def infer_thinking_from_openai_messages(messages: list[dict]) -> bool:
|
| 213 |
+
for msg in messages or []:
|
| 214 |
+
content = msg.get("content", "")
|
| 215 |
+
if isinstance(content, str):
|
| 216 |
+
if "<thinking>" in content and "</thinking>" in content:
|
| 217 |
+
return True
|
| 218 |
+
continue
|
| 219 |
+
if isinstance(content, list):
|
| 220 |
+
for part in content:
|
| 221 |
+
if isinstance(part, dict) and part.get("type") == "text":
|
| 222 |
+
text = part.get("text", "")
|
| 223 |
+
if "<thinking>" in text and "</thinking>" in text:
|
| 224 |
+
return True
|
| 225 |
+
return False
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def infer_thinking_from_openai_responses_input(input_data: Any) -> bool:
|
| 229 |
+
"""Infer thinking from OpenAI Responses API `input` payloads (best-effort)."""
|
| 230 |
+
if isinstance(input_data, str):
|
| 231 |
+
return "<thinking>" in input_data and "</thinking>" in input_data
|
| 232 |
+
|
| 233 |
+
if not isinstance(input_data, list):
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
for item in input_data:
|
| 237 |
+
if not isinstance(item, dict):
|
| 238 |
+
continue
|
| 239 |
+
if item.get("type") != "message":
|
| 240 |
+
continue
|
| 241 |
+
|
| 242 |
+
content_list = item.get("content", []) or []
|
| 243 |
+
for c in content_list:
|
| 244 |
+
if isinstance(c, str):
|
| 245 |
+
if "<thinking>" in c and "</thinking>" in c:
|
| 246 |
+
return True
|
| 247 |
+
continue
|
| 248 |
+
if not isinstance(c, dict):
|
| 249 |
+
continue
|
| 250 |
+
c_type = c.get("type")
|
| 251 |
+
if c_type in {"input_text", "output_text", "text"}:
|
| 252 |
+
text = c.get("text", "")
|
| 253 |
+
if isinstance(text, str) and "<thinking>" in text and "</thinking>" in text:
|
| 254 |
+
return True
|
| 255 |
+
return False
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def infer_thinking_from_gemini_contents(contents: list[dict]) -> bool:
|
| 259 |
+
for item in contents or []:
|
| 260 |
+
for part in item.get("parts", []) or []:
|
| 261 |
+
if isinstance(part, dict) and isinstance(part.get("text"), str):
|
| 262 |
+
text = part["text"]
|
| 263 |
+
if "<thinking>" in text and "</thinking>" in text:
|
| 264 |
+
return True
|
| 265 |
+
return False
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
import re
|
| 269 |
+
|
| 270 |
+
_THINKING_PATTERN = re.compile(r"<thinking>.*?</thinking>\s*", re.DOTALL)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def strip_thinking_from_text(text: str) -> str:
|
| 274 |
+
"""Remove <thinking> blocks from text."""
|
| 275 |
+
if not text or not isinstance(text, str):
|
| 276 |
+
return text
|
| 277 |
+
return _THINKING_PATTERN.sub("", text).strip()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def strip_thinking_from_history(history: list) -> list:
|
| 281 |
+
"""Return a copy of history with <thinking> blocks removed from all messages."""
|
| 282 |
+
if not history:
|
| 283 |
+
return []
|
| 284 |
+
|
| 285 |
+
cleaned = []
|
| 286 |
+
for msg in history:
|
| 287 |
+
if not isinstance(msg, dict):
|
| 288 |
+
cleaned.append(msg)
|
| 289 |
+
continue
|
| 290 |
+
|
| 291 |
+
new_msg = msg.copy()
|
| 292 |
+
content = msg.get("content")
|
| 293 |
+
|
| 294 |
+
if isinstance(content, str):
|
| 295 |
+
new_msg["content"] = strip_thinking_from_text(content)
|
| 296 |
+
elif isinstance(content, list):
|
| 297 |
+
new_content = []
|
| 298 |
+
for part in content:
|
| 299 |
+
if isinstance(part, dict) and part.get("type") == "text":
|
| 300 |
+
new_part = part.copy()
|
| 301 |
+
new_part["text"] = strip_thinking_from_text(part.get("text", ""))
|
| 302 |
+
new_content.append(new_part)
|
| 303 |
+
else:
|
| 304 |
+
new_content.append(part)
|
| 305 |
+
new_msg["content"] = new_content
|
| 306 |
+
|
| 307 |
+
cleaned.append(new_msg)
|
| 308 |
+
|
| 309 |
+
return cleaned
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def format_thinking_block(thinking_content: str) -> str:
|
| 313 |
+
if thinking_content is None:
|
| 314 |
+
return ""
|
| 315 |
+
thinking_content = str(thinking_content).strip()
|
| 316 |
+
if not thinking_content:
|
| 317 |
+
return ""
|
| 318 |
+
return f"<thinking>\n{thinking_content}\n</thinking>"
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def build_thinking_prompt(user_content: str, *, budget_tokens: Optional[int]) -> str:
|
| 322 |
+
"""Build a separate prompt using Tree of Thoughts approach.
|
| 323 |
+
|
| 324 |
+
Use multiple expert perspectives to analyze the problem deeply.
|
| 325 |
+
"""
|
| 326 |
+
if user_content is None:
|
| 327 |
+
user_content = ""
|
| 328 |
+
|
| 329 |
+
budget_str = ""
|
| 330 |
+
if budget_tokens:
|
| 331 |
+
budget_str = f" Budget: {budget_tokens} tokens."
|
| 332 |
+
|
| 333 |
+
return (
|
| 334 |
+
f"Think deeply and comprehensively about this problem.{budget_str}\n\n"
|
| 335 |
+
"Use the following approach:\n"
|
| 336 |
+
"1. Break down the problem into components\n"
|
| 337 |
+
"2. Consider multiple perspectives and solutions\n"
|
| 338 |
+
"3. Evaluate trade-offs and edge cases\n"
|
| 339 |
+
"4. Synthesize your analysis into a coherent response\n\n"
|
| 340 |
+
f"{user_content}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
def build_user_prompt_with_thinking(user_content: str, thinking_content: str) -> str:
|
| 344 |
+
"""Inject thinking into the main prompt.
|
| 345 |
+
|
| 346 |
+
Minimal injection to avoid context pollution.
|
| 347 |
+
"""
|
| 348 |
+
if user_content is None:
|
| 349 |
+
user_content = ""
|
| 350 |
+
|
| 351 |
+
thinking_block = format_thinking_block(thinking_content)
|
| 352 |
+
if not thinking_block:
|
| 353 |
+
return user_content
|
| 354 |
+
|
| 355 |
+
return f"{thinking_block}\n\n{user_content}"
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
async def iter_aws_event_stream_text(byte_iter: AsyncIterator[bytes]) -> AsyncIterator[str]:
|
| 359 |
+
"""Yield incremental text content from AWS event-stream chunks."""
|
| 360 |
+
buffer = b""
|
| 361 |
+
|
| 362 |
+
async for chunk in byte_iter:
|
| 363 |
+
buffer += chunk
|
| 364 |
+
|
| 365 |
+
while len(buffer) >= 12:
|
| 366 |
+
total_len = int.from_bytes(buffer[0:4], "big")
|
| 367 |
+
|
| 368 |
+
if total_len <= 0:
|
| 369 |
+
return
|
| 370 |
+
if len(buffer) < total_len:
|
| 371 |
+
break
|
| 372 |
+
|
| 373 |
+
headers_len = int.from_bytes(buffer[4:8], "big")
|
| 374 |
+
payload_start = 12 + headers_len
|
| 375 |
+
payload_end = total_len - 4
|
| 376 |
+
|
| 377 |
+
if payload_start < payload_end:
|
| 378 |
+
try:
|
| 379 |
+
payload = json.loads(buffer[payload_start:payload_end].decode("utf-8"))
|
| 380 |
+
content = None
|
| 381 |
+
if "assistantResponseEvent" in payload:
|
| 382 |
+
content = payload["assistantResponseEvent"].get("content")
|
| 383 |
+
elif "content" in payload and "toolUseId" not in payload:
|
| 384 |
+
content = payload.get("content")
|
| 385 |
+
if content:
|
| 386 |
+
yield content
|
| 387 |
+
except Exception:
|
| 388 |
+
pass
|
| 389 |
+
|
| 390 |
+
buffer = buffer[total_len:]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
async def fetch_thinking_text(
|
| 394 |
+
*,
|
| 395 |
+
headers: dict,
|
| 396 |
+
model: str,
|
| 397 |
+
user_content: str,
|
| 398 |
+
history: list,
|
| 399 |
+
images: list | None = None,
|
| 400 |
+
tool_results: list | None = None,
|
| 401 |
+
budget_tokens: Optional[int] = None,
|
| 402 |
+
timeout_s: float = 600.0,
|
| 403 |
+
) -> str:
|
| 404 |
+
"""Non-streaming helper to get thinking content (best-effort)."""
|
| 405 |
+
thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens)
|
| 406 |
+
clean_history = strip_thinking_from_history(history)
|
| 407 |
+
thinking_request = build_kiro_request(
|
| 408 |
+
thinking_prompt,
|
| 409 |
+
model,
|
| 410 |
+
clean_history,
|
| 411 |
+
tools=None,
|
| 412 |
+
images=images,
|
| 413 |
+
tool_results=tool_results,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
try:
|
| 417 |
+
async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client:
|
| 418 |
+
resp = await client.post(KIRO_API_URL, json=thinking_request, headers=headers)
|
| 419 |
+
if resp.status_code != 200:
|
| 420 |
+
return ""
|
| 421 |
+
return parse_event_stream(resp.content)
|
| 422 |
+
except Exception:
|
| 423 |
+
return ""
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
async def stream_thinking_text(
|
| 427 |
+
*,
|
| 428 |
+
headers: dict,
|
| 429 |
+
model: str,
|
| 430 |
+
user_content: str,
|
| 431 |
+
history: list,
|
| 432 |
+
images: list | None = None,
|
| 433 |
+
tool_results: list | None = None,
|
| 434 |
+
budget_tokens: Optional[int] = None,
|
| 435 |
+
timeout_s: float = 600.0,
|
| 436 |
+
) -> AsyncIterator[str]:
|
| 437 |
+
"""Streaming helper to yield thinking content incrementally (best-effort)."""
|
| 438 |
+
thinking_prompt = build_thinking_prompt(user_content, budget_tokens=budget_tokens)
|
| 439 |
+
clean_history = strip_thinking_from_history(history)
|
| 440 |
+
thinking_request = build_kiro_request(
|
| 441 |
+
thinking_prompt,
|
| 442 |
+
model,
|
| 443 |
+
clean_history,
|
| 444 |
+
tools=None,
|
| 445 |
+
images=images,
|
| 446 |
+
tool_results=tool_results,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
async with httpx.AsyncClient(verify=False, timeout=timeout_s) as client:
|
| 450 |
+
async with client.stream(
|
| 451 |
+
"POST", KIRO_API_URL, json=thinking_request, headers=headers
|
| 452 |
+
) as response:
|
| 453 |
+
if response.status_code != 200:
|
| 454 |
+
return
|
| 455 |
+
async for piece in iter_aws_event_stream_text(response.aiter_bytes()):
|
| 456 |
+
yield piece
|
KiroProxy/kiro_proxy/core/usage.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Kiro 用量查询服务
|
| 2 |
+
|
| 3 |
+
通过调用 AWS Q 的 getUsageLimits API 获取用户的用量信息。
|
| 4 |
+
"""
|
| 5 |
+
import uuid
|
| 6 |
+
import httpx
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# API 端点
|
| 12 |
+
USAGE_LIMITS_URL = "https://q.us-east-1.amazonaws.com/getUsageLimits"
|
| 13 |
+
|
| 14 |
+
# 低余额阈值 (20%)
|
| 15 |
+
LOW_BALANCE_THRESHOLD = 0.2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class UsageInfo:
|
| 20 |
+
"""用量信息"""
|
| 21 |
+
subscription_title: str = ""
|
| 22 |
+
usage_limit: float = 0.0
|
| 23 |
+
current_usage: float = 0.0
|
| 24 |
+
balance: float = 0.0
|
| 25 |
+
is_low_balance: bool = False
|
| 26 |
+
|
| 27 |
+
# 详细信息
|
| 28 |
+
free_trial_limit: float = 0.0
|
| 29 |
+
free_trial_usage: float = 0.0
|
| 30 |
+
bonus_limit: float = 0.0
|
| 31 |
+
bonus_usage: float = 0.0
|
| 32 |
+
|
| 33 |
+
# 重置和过期时间
|
| 34 |
+
next_reset_date: Optional[str] = None # 下次重置时间
|
| 35 |
+
free_trial_expiry: Optional[str] = None # 免费试用过期时间
|
| 36 |
+
bonus_expiries: list = None # 奖励过期时间列表
|
| 37 |
+
|
| 38 |
+
def __post_init__(self):
|
| 39 |
+
if self.bonus_expiries is None:
|
| 40 |
+
self.bonus_expiries = []
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_usage_api_url(auth_method: str, profile_arn: Optional[str] = None) -> str:
|
| 44 |
+
"""构造 API 请求 URL"""
|
| 45 |
+
url = f"{USAGE_LIMITS_URL}?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST"
|
| 46 |
+
|
| 47 |
+
# Social 认证需要 profileArn
|
| 48 |
+
if auth_method == "social" and profile_arn:
|
| 49 |
+
from urllib.parse import quote
|
| 50 |
+
url += f"&profileArn={quote(profile_arn)}"
|
| 51 |
+
|
| 52 |
+
return url
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def build_usage_headers(
|
| 56 |
+
access_token: str,
|
| 57 |
+
machine_id: str,
|
| 58 |
+
kiro_version: str = "1.0.0"
|
| 59 |
+
) -> dict:
|
| 60 |
+
"""构造请求头"""
|
| 61 |
+
import platform
|
| 62 |
+
os_name = platform.system().lower()
|
| 63 |
+
|
| 64 |
+
return {
|
| 65 |
+
"Authorization": f"Bearer {access_token}",
|
| 66 |
+
"User-Agent": f"aws-sdk-js/1.0.0 ua/2.1 os/{os_name} lang/python api/codewhispererruntime#1.0.0 m/N,E KiroIDE-{kiro_version}-{machine_id}",
|
| 67 |
+
"x-amz-user-agent": f"aws-sdk-js/1.0.0 KiroIDE-{kiro_version}-{machine_id}",
|
| 68 |
+
"amz-sdk-invocation-id": str(uuid.uuid4()),
|
| 69 |
+
"amz-sdk-request": "attempt=1; max=1",
|
| 70 |
+
"Connection": "close",
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def calculate_balance(response: dict) -> UsageInfo:
|
| 75 |
+
"""从 API 响应计算余额
|
| 76 |
+
|
| 77 |
+
注意:只计算 resourceType 为 CREDIT 的额度,忽略其他类型(如 AGENTIC_REQUEST)
|
| 78 |
+
"""
|
| 79 |
+
subscription_info = response.get("subscriptionInfo", {})
|
| 80 |
+
usage_breakdown_list = response.get("usageBreakdownList", [])
|
| 81 |
+
|
| 82 |
+
total_limit = 0.0
|
| 83 |
+
total_usage = 0.0
|
| 84 |
+
free_trial_limit = 0.0
|
| 85 |
+
free_trial_usage = 0.0
|
| 86 |
+
bonus_limit = 0.0
|
| 87 |
+
bonus_usage = 0.0
|
| 88 |
+
|
| 89 |
+
# 重置和过期时间
|
| 90 |
+
next_reset_date = response.get("nextDateReset") # 下次重置时间
|
| 91 |
+
free_trial_expiry = None
|
| 92 |
+
bonus_expiries = []
|
| 93 |
+
|
| 94 |
+
# 只查找 CREDIT 类型的额度
|
| 95 |
+
credit_breakdown = None
|
| 96 |
+
for breakdown in usage_breakdown_list:
|
| 97 |
+
resource_type = breakdown.get("resourceType", "")
|
| 98 |
+
display_name = breakdown.get("displayName", "")
|
| 99 |
+
if resource_type == "CREDIT" or display_name == "Credits":
|
| 100 |
+
credit_breakdown = breakdown
|
| 101 |
+
break
|
| 102 |
+
|
| 103 |
+
if credit_breakdown:
|
| 104 |
+
# 基本额度 (优先使用带精度的值)
|
| 105 |
+
total_limit = credit_breakdown.get("usageLimitWithPrecision", 0.0) or credit_breakdown.get("usageLimit", 0.0)
|
| 106 |
+
total_usage = credit_breakdown.get("currentUsageWithPrecision", 0.0) or credit_breakdown.get("currentUsage", 0.0)
|
| 107 |
+
|
| 108 |
+
# 免费试用额度 (只有状态为 ACTIVE 时才计算)
|
| 109 |
+
free_trial = credit_breakdown.get("freeTrialInfo")
|
| 110 |
+
if free_trial and free_trial.get("freeTrialStatus") == "ACTIVE":
|
| 111 |
+
ft_limit = free_trial.get("usageLimitWithPrecision", 0.0) or free_trial.get("usageLimit", 0.0)
|
| 112 |
+
ft_usage = free_trial.get("currentUsageWithPrecision", 0.0) or free_trial.get("currentUsage", 0.0)
|
| 113 |
+
total_limit += ft_limit
|
| 114 |
+
total_usage += ft_usage
|
| 115 |
+
free_trial_limit = ft_limit
|
| 116 |
+
free_trial_usage = ft_usage
|
| 117 |
+
# 获取免费试用过期时间
|
| 118 |
+
free_trial_expiry = free_trial.get("freeTrialExpiry")
|
| 119 |
+
|
| 120 |
+
# 奖励额度 (只计算状态为 ACTIVE 的奖励)
|
| 121 |
+
bonuses = credit_breakdown.get("bonuses", [])
|
| 122 |
+
for bonus in bonuses or []:
|
| 123 |
+
if bonus.get("status") == "ACTIVE":
|
| 124 |
+
b_limit = bonus.get("usageLimitWithPrecision", 0.0) or bonus.get("usageLimit", 0.0)
|
| 125 |
+
b_usage = bonus.get("currentUsageWithPrecision", 0.0) or bonus.get("currentUsage", 0.0)
|
| 126 |
+
total_limit += b_limit
|
| 127 |
+
total_usage += b_usage
|
| 128 |
+
bonus_limit += b_limit
|
| 129 |
+
bonus_usage += b_usage
|
| 130 |
+
# 获取奖励过期时间
|
| 131 |
+
expires_at = bonus.get("expiresAt")
|
| 132 |
+
if expires_at:
|
| 133 |
+
bonus_expiries.append(expires_at)
|
| 134 |
+
|
| 135 |
+
balance = total_limit - total_usage
|
| 136 |
+
is_low = (balance / total_limit) < LOW_BALANCE_THRESHOLD if total_limit > 0 else False
|
| 137 |
+
|
| 138 |
+
return UsageInfo(
|
| 139 |
+
subscription_title=subscription_info.get("subscriptionTitle", "Unknown"),
|
| 140 |
+
usage_limit=total_limit,
|
| 141 |
+
current_usage=total_usage,
|
| 142 |
+
balance=balance,
|
| 143 |
+
is_low_balance=is_low,
|
| 144 |
+
free_trial_limit=free_trial_limit,
|
| 145 |
+
free_trial_usage=free_trial_usage,
|
| 146 |
+
bonus_limit=bonus_limit,
|
| 147 |
+
bonus_usage=bonus_usage,
|
| 148 |
+
next_reset_date=next_reset_date,
|
| 149 |
+
free_trial_expiry=free_trial_expiry,
|
| 150 |
+
bonus_expiries=bonus_expiries,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
async def get_usage_limits(
|
| 155 |
+
access_token: str,
|
| 156 |
+
auth_method: str = "social",
|
| 157 |
+
profile_arn: Optional[str] = None,
|
| 158 |
+
machine_id: str = "",
|
| 159 |
+
kiro_version: str = "1.0.0",
|
| 160 |
+
) -> Tuple[bool, UsageInfo | dict]:
|
| 161 |
+
"""
|
| 162 |
+
获取 Kiro 用量信息
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
access_token: Bearer token
|
| 166 |
+
auth_method: 认证方式 ("social" 或 "idc")
|
| 167 |
+
profile_arn: Social 认证需要的 profileArn
|
| 168 |
+
machine_id: 设备 ID
|
| 169 |
+
kiro_version: Kiro 版本号
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
(success, UsageInfo or error_dict)
|
| 173 |
+
"""
|
| 174 |
+
if not access_token:
|
| 175 |
+
return False, {"error": "缺少 access token"}
|
| 176 |
+
|
| 177 |
+
if not machine_id:
|
| 178 |
+
return False, {"error": "缺少 machine ID"}
|
| 179 |
+
|
| 180 |
+
# 构造 URL 和请求头
|
| 181 |
+
url = build_usage_api_url(auth_method, profile_arn)
|
| 182 |
+
headers = build_usage_headers(access_token, machine_id, kiro_version)
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
async with httpx.AsyncClient(timeout=10, verify=False) as client:
|
| 186 |
+
response = await client.get(url, headers=headers)
|
| 187 |
+
|
| 188 |
+
if response.status_code != 200:
|
| 189 |
+
return False, {"error": f"API 请求失败: {response.status_code} - {response.text[:200]}"}
|
| 190 |
+
|
| 191 |
+
data = response.json()
|
| 192 |
+
usage_info = calculate_balance(data)
|
| 193 |
+
return True, usage_info
|
| 194 |
+
|
| 195 |
+
except httpx.TimeoutException:
|
| 196 |
+
return False, {"error": "请求超时"}
|
| 197 |
+
except Exception as e:
|
| 198 |
+
return False, {"error": f"请求失败: {str(e)}"}
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
async def get_account_usage(account) -> Tuple[bool, UsageInfo | dict]:
|
| 202 |
+
"""
|
| 203 |
+
获取指定账号的用量信息
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
account: Account 对象
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
(success, UsageInfo or error_dict)
|
| 210 |
+
"""
|
| 211 |
+
from ..credential import get_kiro_version
|
| 212 |
+
from .refresh_manager import get_refresh_manager
|
| 213 |
+
|
| 214 |
+
creds = account.get_credentials()
|
| 215 |
+
if not creds:
|
| 216 |
+
return False, {"error": "无法获取凭证"}
|
| 217 |
+
|
| 218 |
+
# 先刷新 Token(如即将过期/已过期),避免额度获取失败
|
| 219 |
+
refresh_manager = get_refresh_manager()
|
| 220 |
+
if refresh_manager.should_refresh_token(account):
|
| 221 |
+
token_success, token_msg = await refresh_manager.refresh_token_if_needed(account)
|
| 222 |
+
if not token_success:
|
| 223 |
+
return False, {"error": f"Token 刷新失败: {token_msg}"}
|
| 224 |
+
|
| 225 |
+
token = account.get_token()
|
| 226 |
+
if not token:
|
| 227 |
+
return False, {"error": "无法获取 token"}
|
| 228 |
+
|
| 229 |
+
return await get_usage_limits(
|
| 230 |
+
access_token=token,
|
| 231 |
+
auth_method=creds.auth_method or "social",
|
| 232 |
+
profile_arn=creds.profile_arn,
|
| 233 |
+
machine_id=account.get_machine_id(),
|
| 234 |
+
kiro_version=get_kiro_version(),
|
| 235 |
+
)
|
KiroProxy/kiro_proxy/credential/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""凭证管理模块"""
|
| 2 |
+
from .fingerprint import generate_machine_id, get_kiro_version, get_system_info
|
| 3 |
+
from .quota import QuotaManager, QuotaRecord, quota_manager
|
| 4 |
+
from .refresher import TokenRefresher
|
| 5 |
+
from .types import KiroCredentials, CredentialStatus
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"generate_machine_id",
|
| 9 |
+
"get_kiro_version",
|
| 10 |
+
"get_system_info",
|
| 11 |
+
"QuotaManager",
|
| 12 |
+
"QuotaRecord",
|
| 13 |
+
"quota_manager",
|
| 14 |
+
"TokenRefresher",
|
| 15 |
+
"KiroCredentials",
|
| 16 |
+
"CredentialStatus",
|
| 17 |
+
]
|
KiroProxy/kiro_proxy/credential/fingerprint.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""设备指纹生成"""
|
| 2 |
+
import hashlib
|
| 3 |
+
import platform
|
| 4 |
+
import subprocess
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_raw_machine_id() -> Optional[str]:
|
| 11 |
+
"""获取系统原始 Machine ID"""
|
| 12 |
+
system = platform.system()
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
if system == "Darwin":
|
| 16 |
+
result = subprocess.run(
|
| 17 |
+
["ioreg", "-rd1", "-c", "IOPlatformExpertDevice"],
|
| 18 |
+
capture_output=True, text=True, timeout=5
|
| 19 |
+
)
|
| 20 |
+
for line in result.stdout.split("\n"):
|
| 21 |
+
if "IOPlatformUUID" in line:
|
| 22 |
+
return line.split("=")[1].strip().strip('"').lower()
|
| 23 |
+
|
| 24 |
+
elif system == "Linux":
|
| 25 |
+
for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"]:
|
| 26 |
+
if Path(path).exists():
|
| 27 |
+
return Path(path).read_text().strip().lower()
|
| 28 |
+
|
| 29 |
+
elif system == "Windows":
|
| 30 |
+
result = subprocess.run(
|
| 31 |
+
["wmic", "csproduct", "get", "UUID"],
|
| 32 |
+
capture_output=True, text=True, timeout=5,
|
| 33 |
+
creationflags=0x08000000
|
| 34 |
+
)
|
| 35 |
+
lines = [l.strip() for l in result.stdout.split("\n") if l.strip()]
|
| 36 |
+
if len(lines) > 1:
|
| 37 |
+
return lines[1].lower()
|
| 38 |
+
except Exception:
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def generate_machine_id(
|
| 45 |
+
profile_arn: Optional[str] = None,
|
| 46 |
+
client_id: Optional[str] = None
|
| 47 |
+
) -> str:
|
| 48 |
+
"""生成基于凭证的唯一 Machine ID
|
| 49 |
+
|
| 50 |
+
每个凭证生成独立的 Machine ID,避免多账号共用同一指纹被检测。
|
| 51 |
+
优先级:profileArn > clientId > 系统硬件 ID
|
| 52 |
+
添加时间因子:按小时变化,避免指纹完全固化。
|
| 53 |
+
"""
|
| 54 |
+
unique_key = None
|
| 55 |
+
if profile_arn:
|
| 56 |
+
unique_key = profile_arn
|
| 57 |
+
elif client_id:
|
| 58 |
+
unique_key = client_id
|
| 59 |
+
else:
|
| 60 |
+
unique_key = get_raw_machine_id() or "KIRO_DEFAULT_MACHINE"
|
| 61 |
+
|
| 62 |
+
hour_slot = int(time.time()) // 3600
|
| 63 |
+
|
| 64 |
+
hasher = hashlib.sha256()
|
| 65 |
+
hasher.update(unique_key.encode())
|
| 66 |
+
hasher.update(hour_slot.to_bytes(8, 'little'))
|
| 67 |
+
|
| 68 |
+
return hasher.hexdigest()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_kiro_version() -> str:
|
| 72 |
+
"""获取 Kiro IDE 版本号
|
| 73 |
+
|
| 74 |
+
优先检测本地安装的 Kiro,否则使用默认版本 (与 kiro.rs 保持一致)
|
| 75 |
+
"""
|
| 76 |
+
if platform.system() == "Darwin":
|
| 77 |
+
kiro_paths = [
|
| 78 |
+
"/Applications/Kiro.app/Contents/Info.plist",
|
| 79 |
+
str(Path.home() / "Applications/Kiro.app/Contents/Info.plist"),
|
| 80 |
+
]
|
| 81 |
+
for plist_path in kiro_paths:
|
| 82 |
+
try:
|
| 83 |
+
result = subprocess.run(
|
| 84 |
+
["defaults", "read", plist_path, "CFBundleShortVersionString"],
|
| 85 |
+
capture_output=True, text=True, timeout=5
|
| 86 |
+
)
|
| 87 |
+
version = result.stdout.strip()
|
| 88 |
+
if version:
|
| 89 |
+
return version
|
| 90 |
+
except Exception:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
# 默认版本与 kiro.rs 保持一致
|
| 94 |
+
return "0.8.0"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_system_info() -> tuple:
|
| 98 |
+
"""获取系统运行时信息 (os_name, node_version)
|
| 99 |
+
|
| 100 |
+
node_version 与 kiro.rs 保持一致
|
| 101 |
+
"""
|
| 102 |
+
system = platform.system()
|
| 103 |
+
|
| 104 |
+
if system == "Darwin":
|
| 105 |
+
try:
|
| 106 |
+
result = subprocess.run(
|
| 107 |
+
["sw_vers", "-productVersion"],
|
| 108 |
+
capture_output=True, text=True, timeout=5
|
| 109 |
+
)
|
| 110 |
+
version = result.stdout.strip() or "14.0"
|
| 111 |
+
os_name = f"macos#{version}"
|
| 112 |
+
except Exception:
|
| 113 |
+
os_name = "macos#14.0"
|
| 114 |
+
elif system == "Linux":
|
| 115 |
+
try:
|
| 116 |
+
result = subprocess.run(
|
| 117 |
+
["uname", "-r"],
|
| 118 |
+
capture_output=True, text=True, timeout=5
|
| 119 |
+
)
|
| 120 |
+
version = result.stdout.strip() or "5.15.0"
|
| 121 |
+
os_name = f"linux#{version}"
|
| 122 |
+
except Exception:
|
| 123 |
+
os_name = "linux#5.15.0"
|
| 124 |
+
elif system == "Windows":
|
| 125 |
+
os_name = "windows#10.0"
|
| 126 |
+
else:
|
| 127 |
+
os_name = "other#1.0"
|
| 128 |
+
|
| 129 |
+
# Node 版本与 kiro.rs 保持一致
|
| 130 |
+
node_version = "22.11.0"
|
| 131 |
+
return os_name, node_version
|
KiroProxy/kiro_proxy/credential/quota.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""配额管理"""
|
| 2 |
+
import time
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class QuotaRecord:
|
| 9 |
+
"""配额超限记录"""
|
| 10 |
+
credential_id: str
|
| 11 |
+
exceeded_at: float
|
| 12 |
+
cooldown_until: float
|
| 13 |
+
reason: str
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class QuotaManager:
|
| 17 |
+
"""配额管理器
|
| 18 |
+
|
| 19 |
+
管理凭证的配额超限状态:
|
| 20 |
+
- 仅在收到 429 错误时触发冷却
|
| 21 |
+
- 自动管理冷却时间:固定 5 分钟(300秒)
|
| 22 |
+
- 自动清理过期的冷却状态
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
# 固定冷却时间(秒)- 429 错误自动冷却 5 分钟
|
| 26 |
+
COOLDOWN_SECONDS = 300
|
| 27 |
+
|
| 28 |
+
def __init__(self):
|
| 29 |
+
self.exceeded_records: Dict[str, QuotaRecord] = {}
|
| 30 |
+
|
| 31 |
+
def is_429_error(self, status_code: Optional[int]) -> bool:
|
| 32 |
+
"""检查是否为 429 错误(仅 429 触发冷却)"""
|
| 33 |
+
return status_code == 429
|
| 34 |
+
|
| 35 |
+
def is_quota_exceeded_error(self, status_code: Optional[int], error_message: str) -> bool:
|
| 36 |
+
"""检查是否为配额超限错误(仅用于判断是否切换账号,不触发冷却)"""
|
| 37 |
+
# 仅 429 才算配额超限
|
| 38 |
+
return status_code == 429
|
| 39 |
+
|
| 40 |
+
def mark_exceeded(self, credential_id: str, reason: str) -> QuotaRecord:
|
| 41 |
+
"""标记凭证为配额超限(仅 429 时调用)
|
| 42 |
+
|
| 43 |
+
自动管理冷却时间:固定 5 分钟(300秒)
|
| 44 |
+
"""
|
| 45 |
+
now = time.time()
|
| 46 |
+
|
| 47 |
+
record = QuotaRecord(
|
| 48 |
+
credential_id=credential_id,
|
| 49 |
+
exceeded_at=now,
|
| 50 |
+
cooldown_until=now + self.COOLDOWN_SECONDS,
|
| 51 |
+
reason=reason
|
| 52 |
+
)
|
| 53 |
+
self.exceeded_records[credential_id] = record
|
| 54 |
+
|
| 55 |
+
print(f"[QuotaManager] 账号 {credential_id} 遇到 429 错误,自动冷却 {self.COOLDOWN_SECONDS} 秒(5分钟)")
|
| 56 |
+
return record
|
| 57 |
+
|
| 58 |
+
def is_available(self, credential_id: str) -> bool:
|
| 59 |
+
"""检查凭证是否可用"""
|
| 60 |
+
record = self.exceeded_records.get(credential_id)
|
| 61 |
+
if not record:
|
| 62 |
+
return True
|
| 63 |
+
|
| 64 |
+
if time.time() >= record.cooldown_until:
|
| 65 |
+
del self.exceeded_records[credential_id]
|
| 66 |
+
return True
|
| 67 |
+
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
def get_cooldown_remaining(self, credential_id: str) -> Optional[int]:
|
| 71 |
+
"""获取剩余冷却时间(秒)"""
|
| 72 |
+
record = self.exceeded_records.get(credential_id)
|
| 73 |
+
if not record:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
remaining = record.cooldown_until - time.time()
|
| 77 |
+
if remaining <= 0:
|
| 78 |
+
del self.exceeded_records[credential_id]
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
return int(remaining)
|
| 82 |
+
|
| 83 |
+
def cleanup_expired(self) -> int:
|
| 84 |
+
"""清理过期的冷却记录"""
|
| 85 |
+
now = time.time()
|
| 86 |
+
expired = [k for k, v in self.exceeded_records.items() if now >= v.cooldown_until]
|
| 87 |
+
for k in expired:
|
| 88 |
+
del self.exceeded_records[k]
|
| 89 |
+
return len(expired)
|
| 90 |
+
|
| 91 |
+
def restore(self, credential_id: str) -> bool:
|
| 92 |
+
"""手动恢复凭证"""
|
| 93 |
+
if credential_id in self.exceeded_records:
|
| 94 |
+
del self.exceeded_records[credential_id]
|
| 95 |
+
return True
|
| 96 |
+
return False
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# 全局实例 - 429 自动冷却 5 分钟
|
| 100 |
+
quota_manager = QuotaManager()
|
KiroProxy/kiro_proxy/credential/refresher.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token 刷新器"""
|
| 2 |
+
import httpx
|
| 3 |
+
from datetime import datetime, timezone, timedelta
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
from .types import KiroCredentials
|
| 7 |
+
from .fingerprint import generate_machine_id, get_kiro_version
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Kiro Auth 端点
|
| 11 |
+
KIRO_AUTH_ENDPOINT = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TokenRefresher:
|
| 15 |
+
"""Token 刷新器"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, credentials: KiroCredentials):
|
| 18 |
+
self.credentials = credentials
|
| 19 |
+
|
| 20 |
+
def get_refresh_url(self) -> str:
|
| 21 |
+
"""获取刷新 URL"""
|
| 22 |
+
region = self.credentials.region or "us-east-1"
|
| 23 |
+
auth_method = (self.credentials.auth_method or "social").lower()
|
| 24 |
+
|
| 25 |
+
if auth_method == "idc":
|
| 26 |
+
# IDC (AWS Builder ID) 使用 OIDC 端点
|
| 27 |
+
return f"https://oidc.{region}.amazonaws.com/token"
|
| 28 |
+
else:
|
| 29 |
+
# Social (Google/GitHub) 使用 Kiro Auth 端点
|
| 30 |
+
return f"{KIRO_AUTH_ENDPOINT}/refreshToken"
|
| 31 |
+
|
| 32 |
+
def validate_refresh_token(self) -> Tuple[bool, str]:
|
| 33 |
+
"""验证 refresh_token 有效性"""
|
| 34 |
+
refresh_token = self.credentials.refresh_token
|
| 35 |
+
|
| 36 |
+
if not refresh_token:
|
| 37 |
+
return False, "缺少 refresh_token"
|
| 38 |
+
|
| 39 |
+
if len(refresh_token.strip()) == 0:
|
| 40 |
+
return False, "refresh_token 为空"
|
| 41 |
+
|
| 42 |
+
if len(refresh_token) < 100 or refresh_token.endswith("..."):
|
| 43 |
+
return False, f"refresh_token 已被截断(长度: {len(refresh_token)})"
|
| 44 |
+
|
| 45 |
+
return True, ""
|
| 46 |
+
|
| 47 |
+
def _get_machine_id(self) -> str:
|
| 48 |
+
"""获取 Machine ID"""
|
| 49 |
+
return generate_machine_id(
|
| 50 |
+
self.credentials.profile_arn,
|
| 51 |
+
self.credentials.client_id
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
async def refresh_social_token(self) -> Tuple[bool, str]:
|
| 55 |
+
"""
|
| 56 |
+
刷新 Social Token (Google/GitHub)
|
| 57 |
+
|
| 58 |
+
参考 Kiro-account-manager 实现:
|
| 59 |
+
- 端点: https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken
|
| 60 |
+
- 请求体: {"refreshToken": refresh_token}
|
| 61 |
+
- 响应: {accessToken, refreshToken, expiresIn}
|
| 62 |
+
"""
|
| 63 |
+
refresh_url = f"{KIRO_AUTH_ENDPOINT}/refreshToken"
|
| 64 |
+
|
| 65 |
+
body = {"refreshToken": self.credentials.refresh_token}
|
| 66 |
+
headers = {
|
| 67 |
+
"Content-Type": "application/json",
|
| 68 |
+
"User-Agent": "kiro-proxy/1.0.0",
|
| 69 |
+
"Accept": "application/json",
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
async with httpx.AsyncClient(verify=False, timeout=30) as client:
|
| 74 |
+
resp = await client.post(refresh_url, json=body, headers=headers)
|
| 75 |
+
|
| 76 |
+
if resp.status_code != 200:
|
| 77 |
+
error_text = resp.text
|
| 78 |
+
if resp.status_code == 401:
|
| 79 |
+
return False, "凭证已过期或无效,需要重新登录"
|
| 80 |
+
elif resp.status_code == 429:
|
| 81 |
+
return False, "请求过于频繁,请稍后重试"
|
| 82 |
+
else:
|
| 83 |
+
return False, f"刷新失败: {resp.status_code} - {error_text[:200]}"
|
| 84 |
+
|
| 85 |
+
data = resp.json()
|
| 86 |
+
|
| 87 |
+
new_token = data.get("accessToken")
|
| 88 |
+
if not new_token:
|
| 89 |
+
return False, "响应中没有 accessToken"
|
| 90 |
+
|
| 91 |
+
# 更新凭证
|
| 92 |
+
self.credentials.access_token = new_token
|
| 93 |
+
|
| 94 |
+
# 更新 refreshToken(如果服务器返回了新的)
|
| 95 |
+
if rt := data.get("refreshToken"):
|
| 96 |
+
self.credentials.refresh_token = rt
|
| 97 |
+
|
| 98 |
+
# 更新过期时间
|
| 99 |
+
if expires_in := data.get("expiresIn"):
|
| 100 |
+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 101 |
+
self.credentials.expires_at = expires_at.isoformat()
|
| 102 |
+
|
| 103 |
+
self.credentials.last_refresh = datetime.now(timezone.utc).isoformat()
|
| 104 |
+
|
| 105 |
+
print(f"[TokenRefresher] Social token 刷新成功,过期时间: {expires_in}s")
|
| 106 |
+
return True, new_token
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
return False, f"刷新异常: {str(e)}"
|
| 110 |
+
|
| 111 |
+
async def refresh_idc_token(self) -> Tuple[bool, str]:
|
| 112 |
+
"""
|
| 113 |
+
刷新 IDC Token (AWS Builder ID)
|
| 114 |
+
|
| 115 |
+
使用 AWS OIDC 端点刷新
|
| 116 |
+
"""
|
| 117 |
+
region = self.credentials.region or "us-east-1"
|
| 118 |
+
refresh_url = f"https://oidc.{region}.amazonaws.com/token"
|
| 119 |
+
|
| 120 |
+
if not self.credentials.client_id or not self.credentials.client_secret:
|
| 121 |
+
return False, "IdC 认证缺少 client_id 或 client_secret"
|
| 122 |
+
|
| 123 |
+
machine_id = self._get_machine_id()
|
| 124 |
+
kiro_version = get_kiro_version()
|
| 125 |
+
|
| 126 |
+
body = {
|
| 127 |
+
"refreshToken": self.credentials.refresh_token,
|
| 128 |
+
"clientId": self.credentials.client_id,
|
| 129 |
+
"clientSecret": self.credentials.client_secret,
|
| 130 |
+
"grantType": "refresh_token"
|
| 131 |
+
}
|
| 132 |
+
headers = {
|
| 133 |
+
"Content-Type": "application/json",
|
| 134 |
+
"x-amz-user-agent": f"aws-sdk-js/3.738.0 KiroIDE-{kiro_version}-{machine_id}",
|
| 135 |
+
"User-Agent": "node",
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
async with httpx.AsyncClient(verify=False, timeout=30) as client:
|
| 140 |
+
resp = await client.post(refresh_url, json=body, headers=headers)
|
| 141 |
+
|
| 142 |
+
if resp.status_code != 200:
|
| 143 |
+
error_text = resp.text
|
| 144 |
+
if resp.status_code == 401:
|
| 145 |
+
return False, "凭证已过期或无效,需要重新登录"
|
| 146 |
+
elif resp.status_code == 429:
|
| 147 |
+
return False, "请求过于频繁,请稍后重试"
|
| 148 |
+
else:
|
| 149 |
+
return False, f"刷新失败: {resp.status_code} - {error_text[:200]}"
|
| 150 |
+
|
| 151 |
+
data = resp.json()
|
| 152 |
+
|
| 153 |
+
new_token = data.get("accessToken") or data.get("access_token")
|
| 154 |
+
if not new_token:
|
| 155 |
+
return False, "响应中没有 access_token"
|
| 156 |
+
|
| 157 |
+
# 更新凭证
|
| 158 |
+
self.credentials.access_token = new_token
|
| 159 |
+
|
| 160 |
+
if rt := data.get("refreshToken") or data.get("refresh_token"):
|
| 161 |
+
self.credentials.refresh_token = rt
|
| 162 |
+
|
| 163 |
+
if arn := data.get("profileArn"):
|
| 164 |
+
self.credentials.profile_arn = arn
|
| 165 |
+
|
| 166 |
+
if expires_in := data.get("expiresIn") or data.get("expires_in"):
|
| 167 |
+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 168 |
+
self.credentials.expires_at = expires_at.isoformat()
|
| 169 |
+
|
| 170 |
+
self.credentials.last_refresh = datetime.now(timezone.utc).isoformat()
|
| 171 |
+
|
| 172 |
+
print(f"[TokenRefresher] IDC token 刷新成功")
|
| 173 |
+
return True, new_token
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
return False, f"刷新异常: {str(e)}"
|
| 177 |
+
|
| 178 |
+
async def refresh(self) -> Tuple[bool, str]:
|
| 179 |
+
"""
|
| 180 |
+
刷新 token,根据 authMethod 分发到正确的刷新方法
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
(success, new_token_or_error)
|
| 184 |
+
"""
|
| 185 |
+
is_valid, error = self.validate_refresh_token()
|
| 186 |
+
if not is_valid:
|
| 187 |
+
return False, error
|
| 188 |
+
|
| 189 |
+
auth_method = (self.credentials.auth_method or "social").lower()
|
| 190 |
+
|
| 191 |
+
if auth_method == "idc":
|
| 192 |
+
return await self.refresh_idc_token()
|
| 193 |
+
else:
|
| 194 |
+
# social 或其他默认使用 social 刷新
|
| 195 |
+
return await self.refresh_social_token()
|
KiroProxy/kiro_proxy/credential/types.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""凭证数据类型"""
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from datetime import datetime, timezone, timedelta
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CredentialStatus(Enum):
|
| 12 |
+
"""凭证状态"""
|
| 13 |
+
ACTIVE = "active"
|
| 14 |
+
COOLDOWN = "cooldown"
|
| 15 |
+
UNHEALTHY = "unhealthy"
|
| 16 |
+
DISABLED = "disabled"
|
| 17 |
+
SUSPENDED = "suspended" # 账号被封禁
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class KiroCredentials:
|
| 22 |
+
"""Kiro 凭证信息"""
|
| 23 |
+
access_token: Optional[str] = None
|
| 24 |
+
refresh_token: Optional[str] = None
|
| 25 |
+
client_id: Optional[str] = None
|
| 26 |
+
client_secret: Optional[str] = None
|
| 27 |
+
profile_arn: Optional[str] = None
|
| 28 |
+
expires_at: Optional[str] = None
|
| 29 |
+
region: str = "us-east-1"
|
| 30 |
+
auth_method: str = "social"
|
| 31 |
+
provider: Optional[str] = None # Google / Github (社交登录提供商)
|
| 32 |
+
client_id_hash: Optional[str] = None
|
| 33 |
+
last_refresh: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def from_file(cls, path: str) -> "KiroCredentials":
|
| 37 |
+
"""从文件加载凭证"""
|
| 38 |
+
with open(path) as f:
|
| 39 |
+
data = json.load(f)
|
| 40 |
+
|
| 41 |
+
return cls(
|
| 42 |
+
access_token=data.get("accessToken"),
|
| 43 |
+
refresh_token=data.get("refreshToken"),
|
| 44 |
+
client_id=data.get("clientId"),
|
| 45 |
+
client_secret=data.get("clientSecret"),
|
| 46 |
+
profile_arn=data.get("profileArn"),
|
| 47 |
+
expires_at=data.get("expiresAt") or data.get("expire"),
|
| 48 |
+
region=data.get("region", "us-east-1"),
|
| 49 |
+
auth_method=data.get("authMethod", "social"),
|
| 50 |
+
provider=data.get("provider"),
|
| 51 |
+
client_id_hash=data.get("clientIdHash"),
|
| 52 |
+
last_refresh=data.get("lastRefresh"),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def to_dict(self) -> dict:
|
| 56 |
+
"""转换为字典"""
|
| 57 |
+
result = {
|
| 58 |
+
"accessToken": self.access_token,
|
| 59 |
+
"refreshToken": self.refresh_token,
|
| 60 |
+
"clientId": self.client_id,
|
| 61 |
+
"clientSecret": self.client_secret,
|
| 62 |
+
"profileArn": self.profile_arn,
|
| 63 |
+
"expiresAt": self.expires_at,
|
| 64 |
+
"region": self.region,
|
| 65 |
+
"authMethod": self.auth_method,
|
| 66 |
+
"clientIdHash": self.client_id_hash,
|
| 67 |
+
"lastRefresh": self.last_refresh,
|
| 68 |
+
}
|
| 69 |
+
# 只有社交登录才添加 provider 字段
|
| 70 |
+
if self.provider:
|
| 71 |
+
result["provider"] = self.provider
|
| 72 |
+
return result
|
| 73 |
+
|
| 74 |
+
def save_to_file(self, path: str):
|
| 75 |
+
"""保存凭证到文件"""
|
| 76 |
+
existing = {}
|
| 77 |
+
if Path(path).exists():
|
| 78 |
+
try:
|
| 79 |
+
with open(path) as f:
|
| 80 |
+
existing = json.load(f)
|
| 81 |
+
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
existing.update({k: v for k, v in self.to_dict().items() if v is not None})
|
| 85 |
+
|
| 86 |
+
with open(path, "w") as f:
|
| 87 |
+
json.dump(existing, f, indent=2)
|
| 88 |
+
|
| 89 |
+
def is_expired(self) -> bool:
|
| 90 |
+
"""检查 token 是否已过期"""
|
| 91 |
+
if not self.expires_at:
|
| 92 |
+
return True
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
if "T" in self.expires_at:
|
| 96 |
+
expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
|
| 97 |
+
now = datetime.now(timezone.utc)
|
| 98 |
+
return expires <= now + timedelta(minutes=5)
|
| 99 |
+
|
| 100 |
+
expires_ts = int(self.expires_at)
|
| 101 |
+
now_ts = int(time.time())
|
| 102 |
+
return now_ts >= (expires_ts - 300)
|
| 103 |
+
except Exception:
|
| 104 |
+
return True
|
| 105 |
+
|
| 106 |
+
def is_expiring_soon(self, minutes: int = 10) -> bool:
|
| 107 |
+
"""检查 token 是否即将过期"""
|
| 108 |
+
if not self.expires_at:
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
if "T" in self.expires_at:
|
| 113 |
+
expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00"))
|
| 114 |
+
now = datetime.now(timezone.utc)
|
| 115 |
+
return expires < now + timedelta(minutes=minutes)
|
| 116 |
+
|
| 117 |
+
expires_ts = int(self.expires_at)
|
| 118 |
+
now_ts = int(time.time())
|
| 119 |
+
return now_ts >= (expires_ts - minutes * 60)
|
| 120 |
+
except Exception:
|
| 121 |
+
return False
|
KiroProxy/kiro_proxy/docs/01-quickstart.md
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 快速开始
|
| 2 |
+
|
| 3 |
+
## 安装运行
|
| 4 |
+
|
| 5 |
+
### 方式一:下载预编译版本
|
| 6 |
+
|
| 7 |
+
从 [Releases](https://github.com/yourname/kiro-proxy/releases) 下载对应平台的安装包:
|
| 8 |
+
|
| 9 |
+
- **Windows**: `kiro-proxy-windows.zip`
|
| 10 |
+
- **macOS**: `kiro-proxy-macos.zip`
|
| 11 |
+
- **Linux**: `kiro-proxy-linux.tar.gz`
|
| 12 |
+
|
| 13 |
+
解压后双击运行即可。
|
| 14 |
+
|
| 15 |
+
### 方式二:从源码运行
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
# 克隆项目
|
| 19 |
+
git clone https://github.com/yourname/kiro-proxy.git
|
| 20 |
+
cd kiro-proxy
|
| 21 |
+
|
| 22 |
+
# 创建虚拟环境
|
| 23 |
+
python -m venv venv
|
| 24 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 25 |
+
|
| 26 |
+
# 安装依赖
|
| 27 |
+
pip install -r requirements.txt
|
| 28 |
+
|
| 29 |
+
# 运行(默认端口 8080)
|
| 30 |
+
python run.py
|
| 31 |
+
|
| 32 |
+
# 指定端口
|
| 33 |
+
python run.py 8081
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
启动成功后,访问 http://localhost:8080 打开管理界面。
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## 获取 Kiro 账号
|
| 41 |
+
|
| 42 |
+
Kiro Proxy 需要 Kiro 账号的 Token 才能工作。有两种方式获取:
|
| 43 |
+
|
| 44 |
+
### 方式一:在线登录(推荐)
|
| 45 |
+
|
| 46 |
+
1. 打开 Web UI,点击「账号」标签页
|
| 47 |
+
2. 点击「在线登录」按钮
|
| 48 |
+
3. 选择登录方式:
|
| 49 |
+
- **Google** - 使用 Google 账号
|
| 50 |
+
- **GitHub** - 使用 GitHub 账号
|
| 51 |
+
- **AWS** - 使用 AWS Builder ID
|
| 52 |
+
4. 在弹出的浏览器中完成授权
|
| 53 |
+
5. 授权成功后,账号自动添加到代理
|
| 54 |
+
|
| 55 |
+
### 方式二:扫描本地 Token
|
| 56 |
+
|
| 57 |
+
如果你已经在 Kiro IDE 中登录过:
|
| 58 |
+
|
| 59 |
+
1. 打开 Kiro IDE,确保已登录
|
| 60 |
+
2. 回到 Web UI,点击「扫描 Token」
|
| 61 |
+
3. 系统会扫描 `~/.aws/sso/cache/` 目录
|
| 62 |
+
4. 选择要添加的 Token 文件
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
## 配置 AI 客户端
|
| 67 |
+
|
| 68 |
+
### Claude Code (VSCode 插件)
|
| 69 |
+
|
| 70 |
+
这是最推荐的使用方式,工具调用功能已验证可用。
|
| 71 |
+
|
| 72 |
+
1. 安装 Claude Code 插件
|
| 73 |
+
2. 打开设置,添加自定义 Provider:
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
名称: Kiro Proxy
|
| 77 |
+
API Provider: Anthropic
|
| 78 |
+
API Key: any(随便填一个)
|
| 79 |
+
Base URL: http://localhost:8080
|
| 80 |
+
模型: claude-sonnet-4
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
3. 选择 Kiro Proxy 作为当前 Provider
|
| 84 |
+
|
| 85 |
+
### Codex CLI
|
| 86 |
+
|
| 87 |
+
OpenAI 官方命令行工具。
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
# 安装
|
| 91 |
+
npm install -g @openai/codex
|
| 92 |
+
|
| 93 |
+
# 配置 (~/.codex/config.toml)
|
| 94 |
+
model = "gpt-4o"
|
| 95 |
+
model_provider = "kiro"
|
| 96 |
+
|
| 97 |
+
[model_providers.kiro]
|
| 98 |
+
name = "Kiro Proxy"
|
| 99 |
+
base_url = "http://localhost:8080/v1"
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Gemini CLI
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
# 设置环境变量
|
| 106 |
+
export GEMINI_API_BASE=http://localhost:8080/v1
|
| 107 |
+
|
| 108 |
+
# 或在配置文件中设置
|
| 109 |
+
base_url = "http://localhost:8080/v1"
|
| 110 |
+
model = "gemini-pro"
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### 其他兼容客户端
|
| 114 |
+
|
| 115 |
+
任何支持 OpenAI 或 Anthropic API 的客户端都可以使用:
|
| 116 |
+
|
| 117 |
+
- **Base URL**: `http://localhost:8080` 或 `http://localhost:8080/v1`
|
| 118 |
+
- **API Key**: 任意值(代理不验证)
|
| 119 |
+
- **模型**: 见下方模型对照表
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## 模型对照表
|
| 124 |
+
|
| 125 |
+
Kiro 支持以下模型,你可以使用 Kiro 原生名称或映射名称:
|
| 126 |
+
|
| 127 |
+
| Kiro 模型 | 能力 | 可用名称(任选其一) |
|
| 128 |
+
|-----------|------|---------------------|
|
| 129 |
+
| `claude-sonnet-4` | ⭐⭐⭐ 推荐,性价比最高 | `gpt-4o`, `gpt-4`, `gpt-4-turbo`, `claude-3-5-sonnet-20241022`, `claude-3-5-sonnet-latest`, `sonnet` |
|
| 130 |
+
| `claude-sonnet-4.5` | ⭐⭐⭐⭐ 更强,适合复杂任务 | `gemini-1.5-pro`, `o1`, `o1-preview`, `claude-3-opus-20240229`, `claude-3-opus-latest`, `claude-4-opus`, `opus` |
|
| 131 |
+
| `claude-haiku-4.5` | ⚡ 快速,适合简单任务 | `gpt-4o-mini`, `gpt-3.5-turbo`, `claude-3-5-haiku-20241022`, `haiku` |
|
| 132 |
+
| `auto` | 🤖 自动选择 | `auto` |
|
| 133 |
+
|
| 134 |
+
### 各客户端推荐配置
|
| 135 |
+
|
| 136 |
+
| 客户端 | 推荐模型名 | 实际使用 |
|
| 137 |
+
|--------|-----------|---------|
|
| 138 |
+
| Claude Code | `claude-sonnet-4` 或 `claude-sonnet-4.5` | 直接使用 Kiro 模型名 |
|
| 139 |
+
| Codex CLI | `gpt-4o` | 映射到 claude-sonnet-4 |
|
| 140 |
+
| Gemini CLI | `gemini-1.5-pro` | 映射到 claude-sonnet-4.5 |
|
| 141 |
+
| 其他 OpenAI 客户端 | `gpt-4o` | 映射到 claude-sonnet-4 |
|
| 142 |
+
|
| 143 |
+
> 💡 **提示**:不确定用什么模型?直接用 `claude-sonnet-4` 或 `gpt-4o`,性价比最高。
|
KiroProxy/kiro_proxy/docs/02-features.md
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 功能特性
|
| 2 |
+
|
| 3 |
+
## 多协议支持
|
| 4 |
+
|
| 5 |
+
Kiro Proxy 支持三种主流 AI API 协议,可以适配不同的客户端:
|
| 6 |
+
|
| 7 |
+
| 协议 | 端点 | 适用客户端 |
|
| 8 |
+
|------|------|------------|
|
| 9 |
+
| OpenAI | `/v1/chat/completions` | Codex CLI, ChatGPT 客户端 |
|
| 10 |
+
| Anthropic | `/v1/messages` | Claude Code, Claude 客户端 |
|
| 11 |
+
| Gemini | `/v1/models/{model}:generateContent` | Gemini CLI |
|
| 12 |
+
|
| 13 |
+
代理会自动将请求转换为 Kiro API 格式,响应转换回对应协议格式。
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 工具调用支持
|
| 18 |
+
|
| 19 |
+
完整支持三种协议的工具调用功能:
|
| 20 |
+
|
| 21 |
+
### Anthropic 协议(Claude Code)
|
| 22 |
+
|
| 23 |
+
- `tools` 定义和 `tool_result` 响应完整支持
|
| 24 |
+
- `tool_choice: required` 支持(通过 prompt 注入)
|
| 25 |
+
- `web_search` 特殊工具自动识别
|
| 26 |
+
- 工具数量限制(最多 50 个)
|
| 27 |
+
- 描述截断(超过 500 字符自动截断)
|
| 28 |
+
|
| 29 |
+
### OpenAI 协议(Codex CLI)
|
| 30 |
+
|
| 31 |
+
- `tools` 定义(function 类型)
|
| 32 |
+
- `tool_calls` 响应处理
|
| 33 |
+
- `tool` 角色消息转换
|
| 34 |
+
- `tool_choice: required/any` 支持
|
| 35 |
+
- 工具数量限制和描述截断
|
| 36 |
+
|
| 37 |
+
### Gemini 协议
|
| 38 |
+
|
| 39 |
+
- `functionDeclarations` 工具定义
|
| 40 |
+
- `functionCall` 响应处理
|
| 41 |
+
- `functionResponse` 工具结果
|
| 42 |
+
- `toolConfig.functionCallingConfig.mode` 支持(ANY/REQUIRED)
|
| 43 |
+
- 工具数量限制和描述截断
|
| 44 |
+
|
| 45 |
+
### 历史消息修复
|
| 46 |
+
|
| 47 |
+
Kiro API 要求消息必须严格交替(user → assistant → user → assistant),代理会自动:
|
| 48 |
+
|
| 49 |
+
- 检测并修复连续的同角色消息
|
| 50 |
+
- 合并重复的 tool_results
|
| 51 |
+
- 插入占位消息保持交替
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## 多账号管理
|
| 56 |
+
|
| 57 |
+
### 账号轮询
|
| 58 |
+
|
| 59 |
+
支持添加多个 Kiro 账号,代理会自动轮询使用(默认随机):
|
| 60 |
+
|
| 61 |
+
- 每次请求随机选择一个可用账号(尽量避免连续命中同一账号)
|
| 62 |
+
- 自动跳过冷却中或不健康的账号
|
| 63 |
+
- 分散请求压力,降低单账号 RPM 过高导致封禁风险
|
| 64 |
+
|
| 65 |
+
### 会话粘性(可选)
|
| 66 |
+
|
| 67 |
+
为了保持对话上下文的连贯性,在非 `random` 策略下会启用会话粘性:
|
| 68 |
+
|
| 69 |
+
- 同一会话 ID 在 60 秒内会使用同一账号
|
| 70 |
+
- 超过 60 秒或账号不可用时才切换
|
| 71 |
+
- 会话 ID 由请求内容生成;可通过 `~/.kiro-proxy/priority.json` 中的 `strategy` 调整策略
|
| 72 |
+
|
| 73 |
+
### 账号状态
|
| 74 |
+
|
| 75 |
+
每个账号有四种状态:
|
| 76 |
+
|
| 77 |
+
| 状态 | 说明 | 颜色 |
|
| 78 |
+
|------|------|------|
|
| 79 |
+
| Active | 正常可用 | 绿色 |
|
| 80 |
+
| Cooldown | 触发限流,冷却中 | 黄色 |
|
| 81 |
+
| Unhealthy | 健康检查失败 | 红色 |
|
| 82 |
+
| Disabled | 手动禁用 | 灰色 |
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## Token 自动刷新
|
| 87 |
+
|
| 88 |
+
### 自动检测
|
| 89 |
+
|
| 90 |
+
- 后台每 5 分钟检查所有账号的 Token 状态
|
| 91 |
+
- 检测 Token 是否即将过期(15 分钟内)
|
| 92 |
+
|
| 93 |
+
### 自动刷新
|
| 94 |
+
|
| 95 |
+
- 发现即将过期的 Token 自动刷新
|
| 96 |
+
- 支持 Social 认证(Google/GitHub)的 refresh_token
|
| 97 |
+
- 刷新失败会标记账号为不健康
|
| 98 |
+
|
| 99 |
+
### 手动刷新
|
| 100 |
+
|
| 101 |
+
- 在账号卡片点击「刷新 Token」
|
| 102 |
+
- 或点击「刷新所有 Token」批量刷新
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## 配额管理
|
| 107 |
+
|
| 108 |
+
### 429 自动处理
|
| 109 |
+
|
| 110 |
+
当 Kiro API 返回 429 (Too Many Requests) 时:
|
| 111 |
+
|
| 112 |
+
1. 自动将该账号标记为 Cooldown 状态
|
| 113 |
+
2. 设置 5 分钟冷却时间
|
| 114 |
+
3. 立即切换到其他可用账号重试
|
| 115 |
+
4. 冷却结束后自动恢复
|
| 116 |
+
|
| 117 |
+
### 手动恢复
|
| 118 |
+
|
| 119 |
+
如果需要提前恢复账号:
|
| 120 |
+
|
| 121 |
+
1. 在「监控」页面查看配额状态
|
| 122 |
+
2. 点击账号旁的「恢复」按钮
|
| 123 |
+
|
| 124 |
+
---
|
| 125 |
+
|
| 126 |
+
## 流量监控
|
| 127 |
+
|
| 128 |
+
### 请求记录
|
| 129 |
+
|
| 130 |
+
记录所有经过代理的 LLM 请求:
|
| 131 |
+
|
| 132 |
+
- 请求时间、模型、账号
|
| 133 |
+
- 输入/输出 Token 数量
|
| 134 |
+
- 响应时间、状态码
|
| 135 |
+
- 完整的请求和响应内容
|
| 136 |
+
|
| 137 |
+
### 搜索过滤
|
| 138 |
+
|
| 139 |
+
- 按协议筛选(OpenAI/Anthropic/Gemini)
|
| 140 |
+
- 按状态筛选(完成/错误/进行中)
|
| 141 |
+
- 关键词搜索
|
| 142 |
+
|
| 143 |
+
### 导出功能
|
| 144 |
+
|
| 145 |
+
- 支持导出为 JSON 格式
|
| 146 |
+
- 可选择导出全部或指定记录
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## 登录方式
|
| 151 |
+
|
| 152 |
+
### Google 登录
|
| 153 |
+
|
| 154 |
+
使用 Google 账号通过 OAuth 授权登录。
|
| 155 |
+
|
| 156 |
+
### GitHub 登录
|
| 157 |
+
|
| 158 |
+
使用 GitHub 账号通过 OAuth 授权登录。
|
| 159 |
+
|
| 160 |
+
### AWS Builder ID
|
| 161 |
+
|
| 162 |
+
使用 AWS Builder ID 通过 Device Code Flow 登录:
|
| 163 |
+
|
| 164 |
+
1. 点击 AWS 登录按钮
|
| 165 |
+
2. 复制显示的授权码
|
| 166 |
+
3. 在浏览器中打开授权页面
|
| 167 |
+
4. 输入授权码完成登录
|
| 168 |
+
|
| 169 |
+
---
|
| 170 |
+
|
| 171 |
+
## 历史消息管理
|
| 172 |
+
|
| 173 |
+
### 对话长度限制
|
| 174 |
+
|
| 175 |
+
Kiro API 有输入长度限制,当对话历史过长时会返回 `CONTENT_LENGTH_EXCEEDS_THRESHOLD` 错误。
|
| 176 |
+
|
| 177 |
+
代理内置了多种策略自动处理这个问题:
|
| 178 |
+
|
| 179 |
+
### 可用策略
|
| 180 |
+
|
| 181 |
+
| 策略 | 说明 | 触发时机 |
|
| 182 |
+
|------|------|----------|
|
| 183 |
+
| 自动截断 | 优先保留最新上下文并摘要前文,必要时截断 | 每次请求前 |
|
| 184 |
+
| 智能摘要 | 用 AI 生成早期对话摘要 | 超过阈值时 |
|
| 185 |
+
| 错误重试 | 遇到长度错误时截断重试 | 收到错误后 |
|
| 186 |
+
| 预估检测 | 预估 token 数量,超限预先截断 | 每次请求前 |
|
| 187 |
+
|
| 188 |
+
### 配置选项
|
| 189 |
+
|
| 190 |
+
在「设置」页面可以配置:
|
| 191 |
+
|
| 192 |
+
- **最大消息数** - 自动截断时保留的消息数量(默认 30)
|
| 193 |
+
- **最大字符数** - 自动截断时的字符数限制(默认 150000)
|
| 194 |
+
- **重试保留数** - 错误重试时保留的消息数(默认 20)
|
| 195 |
+
- **最大重试次数** - 错误重试的最大次数(默认 2)
|
| 196 |
+
- **摘要保留数** - 智能摘要时保留的最近消息数(默认 10)
|
| 197 |
+
- **摘要阈值** - 触发智能摘要的字符数阈值(默认 100000)
|
| 198 |
+
- **添加警告** - 截断时是否在日志中记录
|
| 199 |
+
|
| 200 |
+
### 推荐配置
|
| 201 |
+
|
| 202 |
+
- **默认**:只启用「错误重试」,遇到问题时自动处理
|
| 203 |
+
- **保守**:启用「智能摘要 + 错误重试」,保留关键信息
|
| 204 |
+
- **激进**:启用「自动截断 + 预估检测」,预防性截断
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
## 配置持久化
|
| 209 |
+
|
| 210 |
+
### 自动保存
|
| 211 |
+
|
| 212 |
+
账号配置自动保存到 `~/.kiro-proxy/config.json`:
|
| 213 |
+
|
| 214 |
+
- 账号列表和状态
|
| 215 |
+
- 启用/禁用设置
|
| 216 |
+
- Token 文件路径
|
| 217 |
+
|
| 218 |
+
### 重启恢复
|
| 219 |
+
|
| 220 |
+
重启代理后自动加载保存的配置,无需重新添加账号。
|
| 221 |
+
|
| 222 |
+
### 导入导出
|
| 223 |
+
|
| 224 |
+
- 「导出配置」下载当前配置
|
| 225 |
+
- 「导入配置」从文件恢复
|