diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 311d8d5c9225f59b701f82c422de896b12ee33a9..0000000000000000000000000000000000000000 --- a/.dockerignore +++ /dev/null @@ -1,54 +0,0 @@ -# Rust build artifacts -target/ -Cargo.lock - -# Node.js dependencies and build artifacts -admin-ui/node_modules/ -admin-ui/dist/ -admin-ui/pnpm-lock.yaml -admin-ui/tsconfig.tsbuildinfo -admin-ui/.vite/ - -# Version control -.git/ -.gitignore - -# IDE and editor files -.idea/ -.vscode/ -*.swp -*.swo -*~ - -# Claude/AI -.claude/ -CLAUDE.md -AGENTS.md - -# CI/CD -.github/ - -# Documentation and examples -*.md -README.md -config.example.json -credentials.example.*.json - -# Development and test files -src/test.rs -src/debug.rs -test.json -tools/ - -# OS-specific files -.DS_Store -Thumbs.db - -# Local configuration (keep templates only) -config.json -credentials.json - -# Docker files -Dockerfile -.dockerignore -docker-compose*.yml \ No newline at end of file diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml deleted file mode 100644 index 95c7538542d3a04ed83d00c16fabae4e0e0ce733..0000000000000000000000000000000000000000 --- a/.github/workflows/build.yaml +++ /dev/null @@ -1,85 +0,0 @@ -name: Build Artifacts - -on: - push: - tags: - - 'v*' - workflow_dispatch: - inputs: - version: - description: 'Version label for artifacts (e.g., 2025.12.1)' - required: true - default: '2026.1.1' - -permissions: - contents: read - -jobs: - build: - strategy: - fail-fast: false - matrix: - include: - - platform: macos-latest - target: aarch64-apple-darwin - name: macOS-arm64 - - platform: macos-latest - target: x86_64-apple-darwin - name: macOS-x64 - - platform: windows-latest - target: x86_64-pc-windows-msvc - name: Windows-x64 - - platform: ubuntu-22.04 - target: x86_64-unknown-linux-gnu - name: Linux-x64 - - platform: ubuntu-22.04-arm - target: aarch64-unknown-linux-gnu - name: Linux-arm64 - - runs-on: ${{ matrix.platform }} - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - - - name: Setup pnpm - uses: pnpm/action-setup@v4 - with: - version: 9 - - - name: Install admin-ui dependencies - working-directory: admin-ui - run: pnpm install - - - name: Build admin-ui - working-directory: admin-ui - run: pnpm build - - - name: Setup Rust - uses: dtolnay/rust-toolchain@stable - with: - targets: ${{ matrix.target }} - - - name: Setup Rust cache - uses: Swatinem/rust-cache@v2 - with: - shared-key: "rust-cache-${{ matrix.target }}" - cache-on-failure: true - - - name: Build app - run: cargo build --release --target ${{ matrix.target }} - - - name: Upload build artifacts - uses: actions/upload-artifact@v4 - with: - name: kiro-rs-${{ github.event.inputs.version || github.ref_name }}-${{ matrix.name }} - if-no-files-found: error - compression-level: 6 - path: | - target/${{ matrix.target }}/release/kiro-rs - target/${{ matrix.target }}/release/kiro-rs.exe \ No newline at end of file diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml deleted file mode 100644 index 6d917a1d5be70dd919056e6a9f07e3b47b8d3e46..0000000000000000000000000000000000000000 --- a/.github/workflows/docker-build.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: Build and Push Docker Images - -on: - push: - tags: - - 'v*' - workflow_dispatch: - inputs: - version: - description: 'Version tag for Docker images (e.g., 2025.12.1)' - required: true - default: '2026.1.1' - -permissions: - contents: read - packages: write - -jobs: - build: - runs-on: ${{ matrix.runner }} - - strategy: - fail-fast: false - matrix: - include: - - platform: linux/amd64 - runner: ubuntu-latest - arch: amd64 - - platform: linux/arm64 - runner: ubuntu-22.04-arm - arch: arm64 - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ github.token }} - - - name: Build and push - uses: docker/build-push-action@v6 - with: - context: . - platforms: ${{ matrix.platform }} - cache-from: type=gha - cache-to: type=gha,mode=max - push: true - provenance: false - tags: ghcr.io/${{ github.repository_owner }}/kiro-rs:${{ github.event.inputs.version || github.ref_name }}-${{ matrix.arch }} - labels: | - org.opencontainers.image.source=https://github.com/${{ github.repository }} - org.opencontainers.image.description=Kiro.rs Docker Image - - manifest: - needs: build - runs-on: ubuntu-latest - steps: - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ github.token }} - - - name: Create and push multi-arch manifest - run: | - VERSION="${{ github.event.inputs.version || github.ref_name }}" - IMAGE="ghcr.io/${{ github.repository_owner }}/kiro-rs" - - # Create manifest for version tag - docker manifest create ${IMAGE}:${VERSION} \ - ${IMAGE}:${VERSION}-amd64 \ - ${IMAGE}:${VERSION}-arm64 - docker manifest push ${IMAGE}:${VERSION} - - # Create manifest for latest tag - docker manifest create ${IMAGE}:latest \ - ${IMAGE}:${VERSION}-amd64 \ - ${IMAGE}:${VERSION}-arm64 - docker manifest push ${IMAGE}:latest diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml deleted file mode 100644 index 6df0b7e6fa3802c2c564e1766cf709ddd0a9b241..0000000000000000000000000000000000000000 --- a/.github/workflows/docker-build.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: Build and Push Docker Image - -on: - push: - branches: - - master - paths-ignore: - - '**.md' - - '.gitignore' - workflow_dispatch: - -env: - REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} - -jobs: - build-and-push: - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ${{ env.REGISTRY }} - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract metadata - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - tags: | - type=raw,value=latest,enable={{is_default_branch}} - type=sha,prefix={{branch}}- - - - name: Build and push Docker image - uses: docker/build-push-action@v5 - with: - context: . - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max - platforms: linux/amd64 diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 07b2905153645ab142c4b9374e23a449c9f38cb3..0000000000000000000000000000000000000000 --- a/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -/target -/CLAUDE.md -/AGENTS.md -/config.json -/credentials.json -/.idea -/test.json -/Cargo.lock -/admin-ui/node_modules/ -/admin-ui/dist/ -/admin-ui/pnpm-lock.yaml -/admin-ui/tsconfig.tsbuildinfo diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index 053eca8f56ca48c69d7cea12fc2df29d6d2341fb..0000000000000000000000000000000000000000 --- a/Cargo.toml +++ /dev/null @@ -1,34 +0,0 @@ -[package] -name = "kiro-rs" -version = "2026.1.5" -edition = "2024" - -[profile.release] -lto = true -strip = true - -[dependencies] -axum = "0.8" -tokio = { version = "1.0", features = ["full"] } -reqwest = { version = "0.12", features = ["stream", "json", "socks"] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -anyhow = "1.0" -http = "1.0" -futures = "0.3" -chrono = { version = "0.4", features = ["serde"] } -uuid = { version = "1.10", features = ["v1", "v4", "fast-rng"] } -fastrand = "2" -sha2 = "0.10" -hex = "0.4" -crc = "3" # CRC32C 计算 -bytes = "1" # 高效的字节缓冲区 -tower-http = { version = "0.6", features = ["cors"] } -clap = { version = "4.5", features = ["derive"] } -urlencoding = "2" -parking_lot = "0.12" # 高性能同步原语 -subtle = "2.6" # 常量时间比较(防止时序攻击) -rust-embed = "8" # 嵌入静态文件 -mime_guess = "2" # MIME 类型推断 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 8ee54c3f80be743a188ef3b8a8a2847b2b7c00df..0000000000000000000000000000000000000000 --- a/Dockerfile +++ /dev/null @@ -1,42 +0,0 @@ -FROM node:22-alpine AS frontend-builder - -WORKDIR /app/admin-ui -COPY admin-ui/package.json ./ -RUN npm install -g pnpm && pnpm install -COPY admin-ui ./ -RUN pnpm build - -FROM rust:1.85-alpine AS builder - -RUN apk add --no-cache musl-dev openssl-dev openssl-libs-static - -WORKDIR /app -COPY Cargo.toml Cargo.lock* ./ -COPY src ./src -COPY --from=frontend-builder /app/admin-ui/dist /app/admin-ui/dist - -RUN cargo build --release - -FROM alpine:3.21 - -RUN apk add --no-cache ca-certificates - -# 创建非 root 用户 (HuggingFace Spaces 要求) -RUN adduser -D -u 1000 appuser - -WORKDIR /app - -COPY --from=builder /app/target/release/kiro-rs /app/kiro-rs -COPY entrypoint.sh /app/entrypoint.sh - -# 创建配置目录并设置权限 -RUN mkdir -p /app/config && \ - chown -R appuser:appuser /app && \ - chmod +x /app/entrypoint.sh - -USER appuser - -# HuggingFace Spaces 只支持端口 7860 -EXPOSE 7860 - -ENTRYPOINT ["/app/entrypoint.sh"] diff --git a/admin-ui/index.html b/admin-ui/index.html deleted file mode 100644 index fa27c3ef561d10044930666513ad7ffe26671cd2..0000000000000000000000000000000000000000 --- a/admin-ui/index.html +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - Kiro Admin - - -
- - - diff --git a/admin-ui/package.json b/admin-ui/package.json deleted file mode 100644 index 8e14a46f35e7e0b8e30bb0471ff96a631596b79c..0000000000000000000000000000000000000000 --- a/admin-ui/package.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "name": "kiro-admin-ui", - "version": "1.0.0", - "type": "module", - "scripts": { - "dev": "vite", - "build": "tsc -b && vite build", - "preview": "vite preview" - }, - "dependencies": { - "react": "^18.3.1", - "react-dom": "^18.3.1", - "@tanstack/react-query": "^5.60.0", - "axios": "^1.7.0", - "clsx": "^2.1.1", - "tailwind-merge": "^2.5.0", - "class-variance-authority": "^0.7.0", - "@radix-ui/react-slot": "^1.1.0", - "@radix-ui/react-switch": "^1.1.0", - "@radix-ui/react-dialog": "^1.1.0", - "@radix-ui/react-dropdown-menu": "^2.1.0", - "@radix-ui/react-toast": "^1.2.0", - "@radix-ui/react-tooltip": "^1.1.0", - "lucide-react": "^0.460.0", - "sonner": "^1.7.0" - }, - "devDependencies": { - "@types/react": "^18.3.12", - "@types/react-dom": "^18.3.1", - "@vitejs/plugin-react-swc": "^3.7.0", - "autoprefixer": "^10.4.20", - "postcss": "^8.4.47", - "tailwindcss": "^3.4.14", - "typescript": "^5.6.3", - "vite": "^5.4.0" - } -} diff --git a/admin-ui/postcss.config.js b/admin-ui/postcss.config.js deleted file mode 100644 index 2e7af2b7f1a6f391da1631d93968a9d487ba977d..0000000000000000000000000000000000000000 --- a/admin-ui/postcss.config.js +++ /dev/null @@ -1,6 +0,0 @@ -export default { - plugins: { - tailwindcss: {}, - autoprefixer: {}, - }, -} diff --git a/admin-ui/public/vite.svg b/admin-ui/public/vite.svg deleted file mode 100644 index 6a4109910f7762f0470eacc43118d1d613951d13..0000000000000000000000000000000000000000 --- a/admin-ui/public/vite.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/admin-ui/src/App.tsx b/admin-ui/src/App.tsx deleted file mode 100644 index f2fab45510cb4b5b73fb4492bca1c6d1871cbd50..0000000000000000000000000000000000000000 --- a/admin-ui/src/App.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import { useState, useEffect } from 'react' -import { storage } from '@/lib/storage' -import { LoginPage } from '@/components/login-page' -import { Dashboard } from '@/components/dashboard' -import { Toaster } from '@/components/ui/sonner' - -function App() { - const [isLoggedIn, setIsLoggedIn] = useState(false) - - useEffect(() => { - // 检查是否已经有保存的 API Key - if (storage.getApiKey()) { - setIsLoggedIn(true) - } - }, []) - - const handleLogin = () => { - setIsLoggedIn(true) - } - - const handleLogout = () => { - setIsLoggedIn(false) - } - - return ( - <> - {isLoggedIn ? ( - - ) : ( - - )} - - - ) -} - -export default App diff --git a/admin-ui/src/api/credentials.ts b/admin-ui/src/api/credentials.ts deleted file mode 100644 index 14591029bcbdcbd2ea4cc45f92ca84268eb85533..0000000000000000000000000000000000000000 --- a/admin-ui/src/api/credentials.ts +++ /dev/null @@ -1,86 +0,0 @@ -import axios from 'axios' -import { storage } from '@/lib/storage' -import type { - CredentialsStatusResponse, - BalanceResponse, - SuccessResponse, - SetDisabledRequest, - SetPriorityRequest, - AddCredentialRequest, - AddCredentialResponse, -} from '@/types/api' - -// 创建 axios 实例 -const api = axios.create({ - baseURL: '/api/admin', - headers: { - 'Content-Type': 'application/json', - }, -}) - -// 请求拦截器添加 API Key -api.interceptors.request.use((config) => { - const apiKey = storage.getApiKey() - if (apiKey) { - config.headers['x-api-key'] = apiKey - } - return config -}) - -// 获取所有凭据状态 -export async function getCredentials(): Promise { - const { data } = await api.get('/credentials') - return data -} - -// 设置凭据禁用状态 -export async function setCredentialDisabled( - id: number, - disabled: boolean -): Promise { - const { data } = await api.post( - `/credentials/${id}/disabled`, - { disabled } as SetDisabledRequest - ) - return data -} - -// 设置凭据优先级 -export async function setCredentialPriority( - id: number, - priority: number -): Promise { - const { data } = await api.post( - `/credentials/${id}/priority`, - { priority } as SetPriorityRequest - ) - return data -} - -// 重置失败计数 -export async function resetCredentialFailure( - id: number -): Promise { - const { data } = await api.post(`/credentials/${id}/reset`) - return data -} - -// 获取凭据余额 -export async function getCredentialBalance(id: number): Promise { - const { data } = await api.get(`/credentials/${id}/balance`) - return data -} - -// 添加新凭据 -export async function addCredential( - req: AddCredentialRequest -): Promise { - const { data } = await api.post('/credentials', req) - return data -} - -// 删除凭据 -export async function deleteCredential(id: number): Promise { - const { data } = await api.delete(`/credentials/${id}`) - return data -} diff --git a/admin-ui/src/components/add-credential-dialog.tsx b/admin-ui/src/components/add-credential-dialog.tsx deleted file mode 100644 index 1ea498f39f5e63cbb426cc278be8a189aa0d5285..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/add-credential-dialog.tsx +++ /dev/null @@ -1,186 +0,0 @@ -import { useState } from 'react' -import { toast } from 'sonner' -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, - DialogFooter, -} from '@/components/ui/dialog' -import { Button } from '@/components/ui/button' -import { Input } from '@/components/ui/input' -import { useAddCredential } from '@/hooks/use-credentials' -import { extractErrorMessage } from '@/lib/utils' - -interface AddCredentialDialogProps { - open: boolean - onOpenChange: (open: boolean) => void -} - -type AuthMethod = 'social' | 'idc' | 'builder-id' - -export function AddCredentialDialog({ open, onOpenChange }: AddCredentialDialogProps) { - const [refreshToken, setRefreshToken] = useState('') - const [authMethod, setAuthMethod] = useState('social') - const [clientId, setClientId] = useState('') - const [clientSecret, setClientSecret] = useState('') - const [priority, setPriority] = useState('0') - - const { mutate, isPending } = useAddCredential() - - const resetForm = () => { - setRefreshToken('') - setAuthMethod('social') - setClientId('') - setClientSecret('') - setPriority('0') - } - - const handleSubmit = (e: React.FormEvent) => { - e.preventDefault() - - // 验证必填字段 - if (!refreshToken.trim()) { - toast.error('请输入 Refresh Token') - return - } - - // IdC/Builder-ID 需要额外字段 - if ((authMethod === 'idc' || authMethod === 'builder-id') && - (!clientId.trim() || !clientSecret.trim())) { - toast.error('IdC/Builder-ID 认证需要填写 Client ID 和 Client Secret') - return - } - - mutate( - { - refreshToken: refreshToken.trim(), - authMethod, - clientId: clientId.trim() || undefined, - clientSecret: clientSecret.trim() || undefined, - priority: parseInt(priority) || 0, - }, - { - onSuccess: (data) => { - toast.success(data.message) - onOpenChange(false) - resetForm() - }, - onError: (error: unknown) => { - toast.error(`添加失败: ${extractErrorMessage(error)}`) - }, - } - ) - } - - return ( - - - - 添加凭据 - - -
-
- {/* Refresh Token */} -
- - setRefreshToken(e.target.value)} - disabled={isPending} - /> -
- - {/* 认证方式 */} -
- - -
- - {/* IdC/Builder-ID 额外字段 */} - {(authMethod === 'idc' || authMethod === 'builder-id') && ( - <> -
- - setClientId(e.target.value)} - disabled={isPending} - /> -
-
- - setClientSecret(e.target.value)} - disabled={isPending} - /> -
- - )} - - {/* 优先级 */} -
- - setPriority(e.target.value)} - disabled={isPending} - /> -

- 数字越小优先级越高,默认为 0 -

-
-
- - - - - -
-
-
- ) -} diff --git a/admin-ui/src/components/balance-dialog.tsx b/admin-ui/src/components/balance-dialog.tsx deleted file mode 100644 index a7f9ec905dd9a2fde735629cabf44d52e2520f6b..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/balance-dialog.tsx +++ /dev/null @@ -1,104 +0,0 @@ -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, -} from '@/components/ui/dialog' -import { Progress } from '@/components/ui/progress' -import { useCredentialBalance } from '@/hooks/use-credentials' -import { parseError } from '@/lib/utils' - -interface BalanceDialogProps { - credentialId: number | null - open: boolean - onOpenChange: (open: boolean) => void -} - -export function BalanceDialog({ credentialId, open, onOpenChange }: BalanceDialogProps) { - const { data: balance, isLoading, error } = useCredentialBalance(credentialId) - - const formatDate = (timestamp: number | null) => { - if (!timestamp) return '未知' - return new Date(timestamp * 1000).toLocaleString('zh-CN') - } - - const formatNumber = (num: number) => { - return num.toLocaleString('zh-CN', { minimumFractionDigits: 2, maximumFractionDigits: 2 }) - } - - return ( - - - - - 凭据 #{credentialId} 余额信息 - - - - {isLoading && ( -
-
-
- )} - - {error && (() => { - const parsed = parseError(error) - return ( -
-
- - - - {parsed.title} -
- {parsed.detail && ( -
- {parsed.detail} -
- )} -
- ) - })()} - - {balance && ( -
- {/* 订阅类型 */} -
- - {balance.subscriptionTitle || '未知订阅类型'} - -
- - {/* 使用进度 */} -
-
- 已使用: ${formatNumber(balance.currentUsage)} - 限额: ${formatNumber(balance.usageLimit)} -
- -
- {balance.usagePercentage.toFixed(1)}% 已使用 -
-
- - {/* 详细信息 */} -
-
- 剩余额度: - - ${formatNumber(balance.remaining)} - -
-
- 下次重置: - - {formatDate(balance.nextResetAt)} - -
-
-
- )} -
-
- ) -} diff --git a/admin-ui/src/components/credential-card.tsx b/admin-ui/src/components/credential-card.tsx deleted file mode 100644 index a3e4ea04e8cd655da8b05b7b811ad10dff6346f4..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/credential-card.tsx +++ /dev/null @@ -1,298 +0,0 @@ -import { useState } from 'react' -import { toast } from 'sonner' -import { RefreshCw, ChevronUp, ChevronDown, Wallet, Trash2 } from 'lucide-react' -import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' -import { Button } from '@/components/ui/button' -import { Badge } from '@/components/ui/badge' -import { Switch } from '@/components/ui/switch' -import { Input } from '@/components/ui/input' -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, -} from '@/components/ui/dialog' -import type { CredentialStatusItem } from '@/types/api' -import { - useSetDisabled, - useSetPriority, - useResetFailure, - useDeleteCredential, -} from '@/hooks/use-credentials' - -interface CredentialCardProps { - credential: CredentialStatusItem - onViewBalance: (id: number) => void -} - -export function CredentialCard({ credential, onViewBalance }: CredentialCardProps) { - const [editingPriority, setEditingPriority] = useState(false) - const [priorityValue, setPriorityValue] = useState(String(credential.priority)) - const [showDeleteDialog, setShowDeleteDialog] = useState(false) - - const setDisabled = useSetDisabled() - const setPriority = useSetPriority() - const resetFailure = useResetFailure() - const deleteCredential = useDeleteCredential() - - const handleToggleDisabled = () => { - setDisabled.mutate( - { id: credential.id, disabled: !credential.disabled }, - { - onSuccess: (res) => { - toast.success(res.message) - }, - onError: (err) => { - toast.error('操作失败: ' + (err as Error).message) - }, - } - ) - } - - const handlePriorityChange = () => { - const newPriority = parseInt(priorityValue, 10) - if (isNaN(newPriority) || newPriority < 0) { - toast.error('优先级必须是非负整数') - return - } - setPriority.mutate( - { id: credential.id, priority: newPriority }, - { - onSuccess: (res) => { - toast.success(res.message) - setEditingPriority(false) - }, - onError: (err) => { - toast.error('操作失败: ' + (err as Error).message) - }, - } - ) - } - - const handleReset = () => { - resetFailure.mutate(credential.id, { - onSuccess: (res) => { - toast.success(res.message) - }, - onError: (err) => { - toast.error('操作失败: ' + (err as Error).message) - }, - }) - } - - const handleDelete = () => { - deleteCredential.mutate(credential.id, { - onSuccess: (res) => { - toast.success(res.message) - setShowDeleteDialog(false) - }, - onError: (err) => { - toast.error('删除失败: ' + (err as Error).message) - }, - }) - } - - const formatExpiry = (expiresAt: string | null) => { - if (!expiresAt) return '未知' - const date = new Date(expiresAt) - const now = new Date() - const diff = date.getTime() - now.getTime() - if (diff < 0) return '已过期' - const minutes = Math.floor(diff / 60000) - if (minutes < 60) return `${minutes} 分钟` - const hours = Math.floor(minutes / 60) - if (hours < 24) return `${hours} 小时` - return `${Math.floor(hours / 24)} 天` - } - - return ( - <> - - -
- - 凭据 #{credential.id} - {credential.isCurrent && ( - 当前 - )} - {credential.disabled && ( - 已禁用 - )} - -
- 启用 - -
-
-
- - {/* 信息网格 */} -
-
- 优先级: - {editingPriority ? ( -
- setPriorityValue(e.target.value)} - className="w-16 h-7 text-sm" - min="0" - /> - - -
- ) : ( - setEditingPriority(true)} - > - {credential.priority} - (点击编辑) - - )} -
-
- 失败次数: - 0 ? 'text-red-500 font-medium' : ''}> - {credential.failureCount} - -
-
- 认证方式: - {credential.authMethod || '未知'} -
-
- Token 有效期: - {formatExpiry(credential.expiresAt)} -
- {credential.hasProfileArn && ( -
- 有 Profile ARN -
- )} -
- - {/* 操作按钮 */} -
- - - - - -
-
-
- - {/* 删除确认对话框 */} - - - - 确认删除凭据 - - 您确定要删除凭据 #{credential.id} 吗?此操作无法撤销。 - - - - - - - - - - ) -} diff --git a/admin-ui/src/components/dashboard.tsx b/admin-ui/src/components/dashboard.tsx deleted file mode 100644 index 0898b5e799faa82a7db1c0433ac6c4b20a1ecd66..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/dashboard.tsx +++ /dev/null @@ -1,186 +0,0 @@ -import { useState } from 'react' -import { RefreshCw, LogOut, Moon, Sun, Server, Plus } from 'lucide-react' -import { useQueryClient } from '@tanstack/react-query' -import { toast } from 'sonner' -import { storage } from '@/lib/storage' -import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' -import { Button } from '@/components/ui/button' -import { Badge } from '@/components/ui/badge' -import { CredentialCard } from '@/components/credential-card' -import { BalanceDialog } from '@/components/balance-dialog' -import { AddCredentialDialog } from '@/components/add-credential-dialog' -import { useCredentials } from '@/hooks/use-credentials' - -interface DashboardProps { - onLogout: () => void -} - -export function Dashboard({ onLogout }: DashboardProps) { - const [selectedCredentialId, setSelectedCredentialId] = useState(null) - const [balanceDialogOpen, setBalanceDialogOpen] = useState(false) - const [addDialogOpen, setAddDialogOpen] = useState(false) - const [darkMode, setDarkMode] = useState(() => { - if (typeof window !== 'undefined') { - return document.documentElement.classList.contains('dark') - } - return false - }) - - const queryClient = useQueryClient() - const { data, isLoading, error, refetch } = useCredentials() - - const toggleDarkMode = () => { - setDarkMode(!darkMode) - document.documentElement.classList.toggle('dark') - } - - const handleViewBalance = (id: number) => { - setSelectedCredentialId(id) - setBalanceDialogOpen(true) - } - - const handleRefresh = () => { - refetch() - toast.success('已刷新凭据列表') - } - - const handleLogout = () => { - storage.removeApiKey() - queryClient.clear() - onLogout() - } - - if (isLoading) { - return ( -
-
-
-

加载中...

-
-
- ) - } - - if (error) { - return ( -
- - -
加载失败
-

{(error as Error).message}

-
- - -
-
-
-
- ) - } - - return ( -
- {/* 顶部导航 */} -
-
-
- - Kiro Admin -
-
- - - -
-
-
- - {/* 主内容 */} -
- {/* 统计卡片 */} -
- - - - 凭据总数 - - - -
{data?.total || 0}
-
-
- - - - 可用凭据 - - - -
{data?.available || 0}
-
-
- - - - 当前活跃 - - - -
- #{data?.currentId || '-'} - 活跃 -
-
-
-
- - {/* 凭据列表 */} -
-
-

凭据管理

- -
- {data?.credentials.length === 0 ? ( - - - 暂无凭据 - - - ) : ( -
- {data?.credentials.map((credential) => ( - - ))} -
- )} -
-
- - {/* 余额对话框 */} - - - {/* 添加凭据对话框 */} - -
- ) -} diff --git a/admin-ui/src/components/login-page.tsx b/admin-ui/src/components/login-page.tsx deleted file mode 100644 index c89f65c39835ac981aba1258e28e77e152695581..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/login-page.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { useState, useEffect } from 'react' -import { KeyRound } from 'lucide-react' -import { storage } from '@/lib/storage' -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' -import { Input } from '@/components/ui/input' -import { Button } from '@/components/ui/button' - -interface LoginPageProps { - onLogin: (apiKey: string) => void -} - -export function LoginPage({ onLogin }: LoginPageProps) { - const [apiKey, setApiKey] = useState('') - - useEffect(() => { - // 从 storage 读取保存的 API Key - const savedKey = storage.getApiKey() - if (savedKey) { - setApiKey(savedKey) - } - }, []) - - const handleSubmit = (e: React.FormEvent) => { - e.preventDefault() - if (apiKey.trim()) { - storage.setApiKey(apiKey.trim()) - onLogin(apiKey.trim()) - } - } - - return ( -
- - -
- -
- Kiro Admin - - 请输入 Admin API Key 以访问管理面板 - -
- -
-
- setApiKey(e.target.value)} - className="text-center" - /> -
- -
-
-
-
- ) -} diff --git a/admin-ui/src/components/ui/badge.tsx b/admin-ui/src/components/ui/badge.tsx deleted file mode 100644 index baa444c11155c9ea3562fb900b3c67d62e7b147f..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/badge.tsx +++ /dev/null @@ -1,39 +0,0 @@ -import * as React from 'react' -import { cva, type VariantProps } from 'class-variance-authority' -import { cn } from '@/lib/utils' - -const badgeVariants = cva( - 'inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2', - { - variants: { - variant: { - default: - 'border-transparent bg-primary text-primary-foreground hover:bg-primary/80', - secondary: - 'border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80', - destructive: - 'border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80', - outline: 'text-foreground', - success: - 'border-transparent bg-green-500 text-white hover:bg-green-500/80', - warning: - 'border-transparent bg-yellow-500 text-white hover:bg-yellow-500/80', - }, - }, - defaultVariants: { - variant: 'default', - }, - } -) - -export interface BadgeProps - extends React.HTMLAttributes, - VariantProps {} - -function Badge({ className, variant, ...props }: BadgeProps) { - return ( -
- ) -} - -export { Badge, badgeVariants } diff --git a/admin-ui/src/components/ui/button.tsx b/admin-ui/src/components/ui/button.tsx deleted file mode 100644 index 640e0f6ec37b5d09449cd24cb22eb107819574ad..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/button.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import * as React from 'react' -import { Slot } from '@radix-ui/react-slot' -import { cva, type VariantProps } from 'class-variance-authority' -import { cn } from '@/lib/utils' - -const buttonVariants = cva( - 'inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0', - { - variants: { - variant: { - default: 'bg-primary text-primary-foreground hover:bg-primary/90', - destructive: - 'bg-destructive text-destructive-foreground hover:bg-destructive/90', - outline: - 'border border-input bg-background hover:bg-accent hover:text-accent-foreground', - secondary: - 'bg-secondary text-secondary-foreground hover:bg-secondary/80', - ghost: 'hover:bg-accent hover:text-accent-foreground', - link: 'text-primary underline-offset-4 hover:underline', - }, - size: { - default: 'h-10 px-4 py-2', - sm: 'h-9 rounded-md px-3', - lg: 'h-11 rounded-md px-8', - icon: 'h-10 w-10', - }, - }, - defaultVariants: { - variant: 'default', - size: 'default', - }, - } -) - -export interface ButtonProps - extends React.ButtonHTMLAttributes, - VariantProps { - asChild?: boolean -} - -const Button = React.forwardRef( - ({ className, variant, size, asChild = false, ...props }, ref) => { - const Comp = asChild ? Slot : 'button' - return ( - - ) - } -) -Button.displayName = 'Button' - -export { Button, buttonVariants } diff --git a/admin-ui/src/components/ui/card.tsx b/admin-ui/src/components/ui/card.tsx deleted file mode 100644 index 18e786a24b5c6130963b1724ebdd8d44f7b0567f..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/card.tsx +++ /dev/null @@ -1,78 +0,0 @@ -import * as React from 'react' -import { cn } from '@/lib/utils' - -const Card = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)) -Card.displayName = 'Card' - -const CardHeader = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)) -CardHeader.displayName = 'CardHeader' - -const CardTitle = React.forwardRef< - HTMLParagraphElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -

-)) -CardTitle.displayName = 'CardTitle' - -const CardDescription = React.forwardRef< - HTMLParagraphElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -

-)) -CardDescription.displayName = 'CardDescription' - -const CardContent = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -

-)) -CardContent.displayName = 'CardContent' - -const CardFooter = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)) -CardFooter.displayName = 'CardFooter' - -export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent } diff --git a/admin-ui/src/components/ui/dialog.tsx b/admin-ui/src/components/ui/dialog.tsx deleted file mode 100644 index ba67aa862159add7275c9b9100ba24b02fcda186..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/dialog.tsx +++ /dev/null @@ -1,119 +0,0 @@ -import * as React from 'react' -import * as DialogPrimitive from '@radix-ui/react-dialog' -import { X } from 'lucide-react' -import { cn } from '@/lib/utils' - -const Dialog = DialogPrimitive.Root - -const DialogTrigger = DialogPrimitive.Trigger - -const DialogPortal = DialogPrimitive.Portal - -const DialogClose = DialogPrimitive.Close - -const DialogOverlay = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)) -DialogOverlay.displayName = DialogPrimitive.Overlay.displayName - -const DialogContent = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, children, ...props }, ref) => ( - - - - {children} - - - 关闭 - - - -)) -DialogContent.displayName = DialogPrimitive.Content.displayName - -const DialogHeader = ({ - className, - ...props -}: React.HTMLAttributes) => ( -
-) -DialogHeader.displayName = 'DialogHeader' - -const DialogFooter = ({ - className, - ...props -}: React.HTMLAttributes) => ( -
-) -DialogFooter.displayName = 'DialogFooter' - -const DialogTitle = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)) -DialogTitle.displayName = DialogPrimitive.Title.displayName - -const DialogDescription = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)) -DialogDescription.displayName = DialogPrimitive.Description.displayName - -export { - Dialog, - DialogPortal, - DialogOverlay, - DialogClose, - DialogTrigger, - DialogContent, - DialogHeader, - DialogFooter, - DialogTitle, - DialogDescription, -} diff --git a/admin-ui/src/components/ui/input.tsx b/admin-ui/src/components/ui/input.tsx deleted file mode 100644 index d32ad81fb1d1ac66263698220ff19528cdc03d60..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/input.tsx +++ /dev/null @@ -1,24 +0,0 @@ -import * as React from 'react' -import { cn } from '@/lib/utils' - -export interface InputProps - extends React.InputHTMLAttributes {} - -const Input = React.forwardRef( - ({ className, type, ...props }, ref) => { - return ( - - ) - } -) -Input.displayName = 'Input' - -export { Input } diff --git a/admin-ui/src/components/ui/progress.tsx b/admin-ui/src/components/ui/progress.tsx deleted file mode 100644 index b400823c409b248dd7433546dd1b8fd6542c3fac..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/progress.tsx +++ /dev/null @@ -1,35 +0,0 @@ -import * as React from 'react' -import { cn } from '@/lib/utils' - -interface ProgressProps extends React.HTMLAttributes { - value?: number - max?: number -} - -const Progress = React.forwardRef( - ({ className, value = 0, max = 100, ...props }, ref) => { - const percentage = Math.min(Math.max((value / max) * 100, 0), 100) - - return ( -
-
80 ? 'bg-red-500' : percentage > 60 ? 'bg-yellow-500' : 'bg-green-500' - )} - style={{ width: `${percentage}%` }} - /> -
- ) - } -) -Progress.displayName = 'Progress' - -export { Progress } diff --git a/admin-ui/src/components/ui/sonner.tsx b/admin-ui/src/components/ui/sonner.tsx deleted file mode 100644 index f5d591c7902618bbbb6de455ea2113187b55515f..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/sonner.tsx +++ /dev/null @@ -1,25 +0,0 @@ -import { Toaster as Sonner } from 'sonner' - -type ToasterProps = React.ComponentProps - -const Toaster = ({ ...props }: ToasterProps) => { - return ( - - ) -} - -export { Toaster } diff --git a/admin-ui/src/components/ui/switch.tsx b/admin-ui/src/components/ui/switch.tsx deleted file mode 100644 index f2d6513519d95e611b2b0fe01627abfeeeb6a795..0000000000000000000000000000000000000000 --- a/admin-ui/src/components/ui/switch.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import * as React from 'react' -import * as SwitchPrimitives from '@radix-ui/react-switch' -import { cn } from '@/lib/utils' - -const Switch = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - - - -)) -Switch.displayName = SwitchPrimitives.Root.displayName - -export { Switch } diff --git a/admin-ui/src/hooks/use-credentials.ts b/admin-ui/src/hooks/use-credentials.ts deleted file mode 100644 index 8c8e9afc4c71ec7306fb8ee4bbedbb89644bbc24..0000000000000000000000000000000000000000 --- a/admin-ui/src/hooks/use-credentials.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query' -import { - getCredentials, - setCredentialDisabled, - setCredentialPriority, - resetCredentialFailure, - getCredentialBalance, - addCredential, - deleteCredential, -} from '@/api/credentials' -import type { AddCredentialRequest } from '@/types/api' - -// 查询凭据列表 -export function useCredentials() { - return useQuery({ - queryKey: ['credentials'], - queryFn: getCredentials, - refetchInterval: 30000, // 每 30 秒刷新一次 - }) -} - -// 查询凭据余额 -export function useCredentialBalance(id: number | null) { - return useQuery({ - queryKey: ['credential-balance', id], - queryFn: () => getCredentialBalance(id!), - enabled: id !== null, - retry: false, // 余额查询失败时不重试(避免重复请求被封禁的账号) - }) -} - -// 设置禁用状态 -export function useSetDisabled() { - const queryClient = useQueryClient() - return useMutation({ - mutationFn: ({ id, disabled }: { id: number; disabled: boolean }) => - setCredentialDisabled(id, disabled), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['credentials'] }) - }, - }) -} - -// 设置优先级 -export function useSetPriority() { - const queryClient = useQueryClient() - return useMutation({ - mutationFn: ({ id, priority }: { id: number; priority: number }) => - setCredentialPriority(id, priority), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['credentials'] }) - }, - }) -} - -// 重置失败计数 -export function useResetFailure() { - const queryClient = useQueryClient() - return useMutation({ - mutationFn: (id: number) => resetCredentialFailure(id), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['credentials'] }) - }, - }) -} - -// 添加新凭据 -export function useAddCredential() { - const queryClient = useQueryClient() - return useMutation({ - mutationFn: (req: AddCredentialRequest) => addCredential(req), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['credentials'] }) - }, - }) -} - -// 删除凭据 -export function useDeleteCredential() { - const queryClient = useQueryClient() - return useMutation({ - mutationFn: (id: number) => deleteCredential(id), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['credentials'] }) - }, - }) -} diff --git a/admin-ui/src/index.css b/admin-ui/src/index.css deleted file mode 100644 index 01678720f31557bfe2b34127911f6836f925eaac..0000000000000000000000000000000000000000 --- a/admin-ui/src/index.css +++ /dev/null @@ -1,60 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; - -@layer base { - :root { - --background: 0 0% 100%; - --foreground: 222.2 84% 4.9%; - --card: 0 0% 100%; - --card-foreground: 222.2 84% 4.9%; - --popover: 0 0% 100%; - --popover-foreground: 222.2 84% 4.9%; - --primary: 222.2 47.4% 11.2%; - --primary-foreground: 210 40% 98%; - --secondary: 210 40% 96.1%; - --secondary-foreground: 222.2 47.4% 11.2%; - --muted: 210 40% 96.1%; - --muted-foreground: 215.4 16.3% 46.9%; - --accent: 210 40% 96.1%; - --accent-foreground: 222.2 47.4% 11.2%; - --destructive: 0 84.2% 60.2%; - --destructive-foreground: 210 40% 98%; - --border: 214.3 31.8% 91.4%; - --input: 214.3 31.8% 91.4%; - --ring: 222.2 84% 4.9%; - --radius: 0.5rem; - } - - .dark { - --background: 222.2 84% 4.9%; - --foreground: 210 40% 98%; - --card: 222.2 84% 4.9%; - --card-foreground: 210 40% 98%; - --popover: 222.2 84% 4.9%; - --popover-foreground: 210 40% 98%; - --primary: 210 40% 98%; - --primary-foreground: 222.2 47.4% 11.2%; - --secondary: 217.2 32.6% 17.5%; - --secondary-foreground: 210 40% 98%; - --muted: 217.2 32.6% 17.5%; - --muted-foreground: 215 20.2% 65.1%; - --accent: 217.2 32.6% 17.5%; - --accent-foreground: 210 40% 98%; - --destructive: 0 62.8% 30.6%; - --destructive-foreground: 210 40% 98%; - --border: 217.2 32.6% 17.5%; - --input: 217.2 32.6% 17.5%; - --ring: 212.7 26.8% 83.9%; - } -} - -@layer base { - * { - @apply border-border; - } - body { - @apply bg-background text-foreground; - font-feature-settings: "rlig" 1, "calt" 1; - } -} diff --git a/admin-ui/src/lib/storage.ts b/admin-ui/src/lib/storage.ts deleted file mode 100644 index b61b741c46b126977e79a1c39e03b79833759986..0000000000000000000000000000000000000000 --- a/admin-ui/src/lib/storage.ts +++ /dev/null @@ -1,7 +0,0 @@ -const API_KEY_STORAGE_KEY = 'adminApiKey' - -export const storage = { - getApiKey: () => localStorage.getItem(API_KEY_STORAGE_KEY), - setApiKey: (key: string) => localStorage.setItem(API_KEY_STORAGE_KEY, key), - removeApiKey: () => localStorage.removeItem(API_KEY_STORAGE_KEY), -} diff --git a/admin-ui/src/lib/utils.ts b/admin-ui/src/lib/utils.ts deleted file mode 100644 index 0be83dd12f2f2a59a26043f78985096b821fd8c9..0000000000000000000000000000000000000000 --- a/admin-ui/src/lib/utils.ts +++ /dev/null @@ -1,106 +0,0 @@ -import { clsx, type ClassValue } from 'clsx' -import { twMerge } from 'tailwind-merge' - -export function cn(...inputs: ClassValue[]) { - return twMerge(clsx(inputs)) -} - -/** - * 解析后端错误响应,提取用户友好的错误信息 - */ -export interface ParsedError { - /** 简短的错误标题 */ - title: string - /** 详细的错误描述 */ - detail?: string - /** 错误类型 */ - type?: string -} - -/** - * 从错误对象中提取错误消息 - * 支持 Axios 错误和普通 Error 对象 - */ -export function extractErrorMessage(error: unknown): string { - const parsed = parseError(error) - return parsed.title -} - -/** - * 解析错误,返回结构化的错误信息 - */ -export function parseError(error: unknown): ParsedError { - if (!error || typeof error !== 'object') { - return { title: '未知错误' } - } - - const axiosError = error as Record - const response = axiosError.response as Record | undefined - const data = response?.data as Record | undefined - const errorObj = data?.error as Record | undefined - - // 尝试从后端错误响应中提取信息 - if (errorObj && typeof errorObj.message === 'string') { - const message = errorObj.message - const type = typeof errorObj.type === 'string' ? errorObj.type : undefined - - // 解析嵌套的错误信息(如:上游服务错误: 权限不足: 403 {...}) - const parsed = parseNestedErrorMessage(message) - - return { - title: parsed.title, - detail: parsed.detail, - type, - } - } - - // 回退到 Error.message - if ('message' in axiosError && typeof axiosError.message === 'string') { - return { title: axiosError.message } - } - - return { title: '未知错误' } -} - -/** - * 解析嵌套的错误消息 - * 例如:"上游服务错误: 权限不足,无法获取使用额度: 403 Forbidden {...}" - */ -function parseNestedErrorMessage(message: string): { title: string; detail?: string } { - // 尝试提取 HTTP 状态码(如 403、502 等) - const statusMatch = message.match(/(\d{3})\s+\w+/) - const statusCode = statusMatch ? statusMatch[1] : null - - // 尝试提取 JSON 中的 message 字段 - const jsonMatch = message.match(/\{[^{}]*"message"\s*:\s*"([^"]+)"[^{}]*\}/) - if (jsonMatch) { - const innerMessage = jsonMatch[1] - // 提取主要错误原因(去掉前缀) - const parts = message.split(':').map(s => s.trim()) - const mainReason = parts.length > 1 ? parts[1].split(':')[0] : parts[0] - - // 在 title 中包含状态码 - const title = statusCode - ? `${mainReason || '服务错误'} (${statusCode})` - : (mainReason || '服务错误') - - return { - title, - detail: innerMessage, - } - } - - // 尝试按冒号分割,提取主要信息 - const colonParts = message.split(':') - if (colonParts.length >= 2) { - const mainPart = colonParts[1].trim().split(':')[0].trim() - const title = statusCode ? `${mainPart} (${statusCode})` : mainPart - - return { - title, - detail: colonParts.slice(2).join(':').trim() || undefined, - } - } - - return { title: message } -} diff --git a/admin-ui/src/main.tsx b/admin-ui/src/main.tsx deleted file mode 100644 index 17ef62af4b05e9859405893d25919d3d6cc2ecf3..0000000000000000000000000000000000000000 --- a/admin-ui/src/main.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import React from 'react' -import ReactDOM from 'react-dom/client' -import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import App from './App' -import './index.css' - -const queryClient = new QueryClient({ - defaultOptions: { - queries: { - staleTime: 5000, - refetchOnWindowFocus: false, - }, - }, -}) - -ReactDOM.createRoot(document.getElementById('root')!).render( - - - - - , -) diff --git a/admin-ui/src/types/api.ts b/admin-ui/src/types/api.ts deleted file mode 100644 index 05a77346544586fe569a78e7ba8d4aba90516bbd..0000000000000000000000000000000000000000 --- a/admin-ui/src/types/api.ts +++ /dev/null @@ -1,69 +0,0 @@ -// 凭据状态响应 -export interface CredentialsStatusResponse { - total: number - available: number - currentId: number - credentials: CredentialStatusItem[] -} - -// 单个凭据状态 -export interface CredentialStatusItem { - id: number - priority: number - disabled: boolean - failureCount: number - isCurrent: boolean - expiresAt: string | null - authMethod: string | null - hasProfileArn: boolean -} - -// 余额响应 -export interface BalanceResponse { - id: number - subscriptionTitle: string | null - currentUsage: number - usageLimit: number - remaining: number - usagePercentage: number - nextResetAt: number | null -} - -// 成功响应 -export interface SuccessResponse { - success: boolean - message: string -} - -// 错误响应 -export interface AdminErrorResponse { - error: { - type: string - message: string - } -} - -// 请求类型 -export interface SetDisabledRequest { - disabled: boolean -} - -export interface SetPriorityRequest { - priority: number -} - -// 添加凭据请求 -export interface AddCredentialRequest { - refreshToken: string - authMethod?: 'social' | 'idc' | 'builder-id' - clientId?: string - clientSecret?: string - priority?: number -} - -// 添加凭据响应 -export interface AddCredentialResponse { - success: boolean - message: string - credentialId: number -} diff --git a/admin-ui/tailwind.config.js b/admin-ui/tailwind.config.js deleted file mode 100644 index 2e36b3bd0a1fd448a2c746716c76baf3bdb8edbc..0000000000000000000000000000000000000000 --- a/admin-ui/tailwind.config.js +++ /dev/null @@ -1,53 +0,0 @@ -/** @type {import('tailwindcss').Config} */ -export default { - darkMode: 'class', - content: [ - './index.html', - './src/**/*.{js,ts,jsx,tsx}', - ], - theme: { - extend: { - colors: { - border: 'hsl(var(--border))', - input: 'hsl(var(--input))', - ring: 'hsl(var(--ring))', - background: 'hsl(var(--background))', - foreground: 'hsl(var(--foreground))', - primary: { - DEFAULT: 'hsl(var(--primary))', - foreground: 'hsl(var(--primary-foreground))', - }, - secondary: { - DEFAULT: 'hsl(var(--secondary))', - foreground: 'hsl(var(--secondary-foreground))', - }, - destructive: { - DEFAULT: 'hsl(var(--destructive))', - foreground: 'hsl(var(--destructive-foreground))', - }, - muted: { - DEFAULT: 'hsl(var(--muted))', - foreground: 'hsl(var(--muted-foreground))', - }, - accent: { - DEFAULT: 'hsl(var(--accent))', - foreground: 'hsl(var(--accent-foreground))', - }, - popover: { - DEFAULT: 'hsl(var(--popover))', - foreground: 'hsl(var(--popover-foreground))', - }, - card: { - DEFAULT: 'hsl(var(--card))', - foreground: 'hsl(var(--card-foreground))', - }, - }, - borderRadius: { - lg: 'var(--radius)', - md: 'calc(var(--radius) - 2px)', - sm: 'calc(var(--radius) - 4px)', - }, - }, - }, - plugins: [], -} diff --git a/admin-ui/tsconfig.json b/admin-ui/tsconfig.json deleted file mode 100644 index 5e1feb437651831f728f7188d3b8ce2c0e3eb8f7..0000000000000000000000000000000000000000 --- a/admin-ui/tsconfig.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "compilerOptions": { - "target": "ES2020", - "useDefineForClassFields": true, - "lib": ["ES2020", "DOM", "DOM.Iterable"], - "module": "ESNext", - "skipLibCheck": true, - "moduleResolution": "bundler", - "allowImportingTsExtensions": true, - "isolatedModules": true, - "moduleDetection": "force", - "noEmit": true, - "jsx": "react-jsx", - "strict": true, - "noUnusedLocals": true, - "noUnusedParameters": true, - "noFallthroughCasesInSwitch": true, - "baseUrl": ".", - "paths": { - "@/*": ["./src/*"] - } - }, - "include": ["src"] -} diff --git a/admin-ui/vite.config.ts b/admin-ui/vite.config.ts deleted file mode 100644 index 5c8e864871a88c8ded8cb0af1198be8eb5fae61c..0000000000000000000000000000000000000000 --- a/admin-ui/vite.config.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { defineConfig } from 'vite' -import react from '@vitejs/plugin-react-swc' -import path from 'path' - -export default defineConfig({ - plugins: [react()], - base: '/admin/', - resolve: { - alias: { - '@': path.resolve(__dirname, './src'), - }, - }, - server: { - proxy: { - '/api': { - target: 'http://localhost:8080', - changeOrigin: true, - }, - }, - }, - build: { - outDir: 'dist', - emptyOutDir: true, - }, -}) diff --git a/config.example.json b/config.example.json deleted file mode 100644 index ab3544a9294f893db5a11ea2a090d1fce0bc0750..0000000000000000000000000000000000000000 --- a/config.example.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "host": "127.0.0.1", - "port": 8990, - "apiKey": "sk-kiro-rs-qazWSXedcRFV123456", - "region": "us-east-1", - "adminApiKey": "sk-admin-your-secret-key" -} \ No newline at end of file diff --git a/credentials.example.idc.json b/credentials.example.idc.json deleted file mode 100644 index 09f7cf135b37778be45fdd4d8d46407ac1babdf1..0000000000000000000000000000000000000000 --- a/credentials.example.idc.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "refreshToken": "xxxxxxxxxxxxxxxxxxxx", - "expiresAt": "2025-12-31T02:32:45.144Z", - "authMethod": "idc", - "clientId": "xxxxxxxxx", - "clientSecret": "xxxxxxxxx", - "region": "us-east-2", - "machineId": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" -} diff --git a/credentials.example.multiple.json b/credentials.example.multiple.json deleted file mode 100644 index 0a57a002a4b48726cbfe8ef00f4a966af6868e60..0000000000000000000000000000000000000000 --- a/credentials.example.multiple.json +++ /dev/null @@ -1,19 +0,0 @@ -[ - { - "refreshToken": "xxxxxxxxxxxxxxxxxxxx", - "expiresAt": "2025-12-31T02:32:45.144Z", - "authMethod": "social", - "machineId": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - "priority": 0 - }, - { - "refreshToken": "yyyyyyyyyyyyyyyyyyyy", - "expiresAt": "2025-12-31T02:32:45.144Z", - "authMethod": "idc", - "clientId": "xxxxxxxxx", - "clientSecret": "xxxxxxxxx", - "region": "us-east-2", - "machineId": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", - "priority": 1 - } -] diff --git a/credentials.example.social.json b/credentials.example.social.json deleted file mode 100644 index 259897f2c7f5bb734a45bbfc754dbaea97a1834c..0000000000000000000000000000000000000000 --- a/credentials.example.social.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "refreshToken": "xxxxxxxxxxxxxxxxxxxx", - "expiresAt": "2025-12-31T02:32:45.144Z", - "authMethod": "social", - "machineId": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" -} diff --git a/entrypoint.sh b/entrypoint.sh deleted file mode 100644 index cc806f0ffecb596a804a36616c9c0978aeec7207..0000000000000000000000000000000000000000 --- a/entrypoint.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/sh - -# 从环境变量生成 config.json -if [ -n "${ADMIN_API_KEY}" ]; then - # 如果设置了 ADMIN_API_KEY,包含在配置中 - cat > /app/config/config.json << EOF -{ - "host": "0.0.0.0", - "port": 7860, - "apiKey": "${API_KEY:-sk-kiro-rs-default}", - "region": "${REGION:-us-east-1}", - "adminApiKey": "${ADMIN_API_KEY}" -} -EOF -else - # 否则不包含 adminApiKey - cat > /app/config/config.json << EOF -{ - "host": "0.0.0.0", - "port": 7860, - "apiKey": "${API_KEY:-sk-kiro-rs-default}", - "region": "${REGION:-us-east-1}" -} -EOF -fi - -# 从环境变量生成 credentials.json -# 支持两种模式: -# 1. 多凭据模式:通过 CREDENTIALS_JSON 环境变量传入完整的 JSON 数组 -# 2. 单凭据模式:通过单独的环境变量(向后兼容) - -if [ -n "${CREDENTIALS_JSON}" ]; then - # 多凭据模式:直接使用 CREDENTIALS_JSON - echo "${CREDENTIALS_JSON}" > /app/config/credentials.json - echo "Using multi-credential mode from CREDENTIALS_JSON" -else - # 单凭据模式 - if [ "${AUTH_METHOD}" = "idc" ]; then - cat > /app/config/credentials.json << EOF -{ - "refreshToken": "${REFRESH_TOKEN}", - "expiresAt": "${EXPIRES_AT:-2020-01-01T00:00:00.000Z}", - "authMethod": "idc", - "clientId": "${CLIENT_ID}", - "clientSecret": "${CLIENT_SECRET}" -} -EOF - else - cat > /app/config/credentials.json << EOF -{ - "refreshToken": "${REFRESH_TOKEN}", - "expiresAt": "${EXPIRES_AT:-2020-01-01T00:00:00.000Z}", - "authMethod": "${AUTH_METHOD:-social}" -} -EOF - fi - echo "Using single-credential mode" -fi - -echo "Starting kiro-rs..." -exec /app/kiro-rs -c /app/config/config.json --credentials /app/config/credentials.json diff --git a/src/admin/error.rs b/src/admin/error.rs deleted file mode 100644 index e1f921954ebc94a168d2ff8b00a23d23de369b64..0000000000000000000000000000000000000000 --- a/src/admin/error.rs +++ /dev/null @@ -1,64 +0,0 @@ -//! Admin API 错误类型定义 - -use std::fmt; - -use axum::http::StatusCode; - -use super::types::AdminErrorResponse; - -/// Admin 服务错误类型 -#[derive(Debug)] -pub enum AdminServiceError { - /// 凭据不存在 - NotFound { id: u64 }, - - /// 上游服务调用失败(网络、API 错误等) - UpstreamError(String), - - /// 内部状态错误 - InternalError(String), - - /// 凭据无效(验证失败) - InvalidCredential(String), -} - -impl fmt::Display for AdminServiceError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AdminServiceError::NotFound { id } => { - write!(f, "凭据不存在: {}", id) - } - AdminServiceError::UpstreamError(msg) => write!(f, "上游服务错误: {}", msg), - AdminServiceError::InternalError(msg) => write!(f, "内部错误: {}", msg), - AdminServiceError::InvalidCredential(msg) => write!(f, "凭据无效: {}", msg), - } - } -} - -impl std::error::Error for AdminServiceError {} - -impl AdminServiceError { - /// 获取对应的 HTTP 状态码 - pub fn status_code(&self) -> StatusCode { - match self { - AdminServiceError::NotFound { .. } => StatusCode::NOT_FOUND, - AdminServiceError::UpstreamError(_) => StatusCode::BAD_GATEWAY, - AdminServiceError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, - AdminServiceError::InvalidCredential(_) => StatusCode::BAD_REQUEST, - } - } - - /// 转换为 API 错误响应 - pub fn into_response(self) -> AdminErrorResponse { - match &self { - AdminServiceError::NotFound { .. } => AdminErrorResponse::not_found(self.to_string()), - AdminServiceError::UpstreamError(_) => AdminErrorResponse::api_error(self.to_string()), - AdminServiceError::InternalError(_) => { - AdminErrorResponse::internal_error(self.to_string()) - } - AdminServiceError::InvalidCredential(_) => { - AdminErrorResponse::invalid_request(self.to_string()) - } - } - } -} diff --git a/src/admin/handlers.rs b/src/admin/handlers.rs deleted file mode 100644 index 39190b115042016aec39b799067a1e88bf504920..0000000000000000000000000000000000000000 --- a/src/admin/handlers.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Admin API HTTP 处理器 - -use axum::{ - Json, - extract::{Path, State}, - response::IntoResponse, -}; - -use super::{ - middleware::AdminState, - types::{AddCredentialRequest, SetDisabledRequest, SetPriorityRequest, SuccessResponse}, -}; - -/// GET /api/admin/credentials -/// 获取所有凭据状态 -pub async fn get_all_credentials(State(state): State) -> impl IntoResponse { - let response = state.service.get_all_credentials(); - Json(response) -} - -/// POST /api/admin/credentials/:id/disabled -/// 设置凭据禁用状态 -pub async fn set_credential_disabled( - State(state): State, - Path(id): Path, - Json(payload): Json, -) -> impl IntoResponse { - match state.service.set_disabled(id, payload.disabled) { - Ok(_) => { - let action = if payload.disabled { "禁用" } else { "启用" }; - Json(SuccessResponse::new(format!("凭据 #{} 已{}", id, action))).into_response() - } - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} - -/// POST /api/admin/credentials/:id/priority -/// 设置凭据优先级 -pub async fn set_credential_priority( - State(state): State, - Path(id): Path, - Json(payload): Json, -) -> impl IntoResponse { - match state.service.set_priority(id, payload.priority) { - Ok(_) => Json(SuccessResponse::new(format!( - "凭据 #{} 优先级已设置为 {}", - id, payload.priority - ))) - .into_response(), - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} - -/// POST /api/admin/credentials/:id/reset -/// 重置失败计数并重新启用 -pub async fn reset_failure_count( - State(state): State, - Path(id): Path, -) -> impl IntoResponse { - match state.service.reset_and_enable(id) { - Ok(_) => Json(SuccessResponse::new(format!( - "凭据 #{} 失败计数已重置并重新启用", - id - ))) - .into_response(), - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} - -/// GET /api/admin/credentials/:id/balance -/// 获取指定凭据的余额 -pub async fn get_credential_balance( - State(state): State, - Path(id): Path, -) -> impl IntoResponse { - match state.service.get_balance(id).await { - Ok(response) => Json(response).into_response(), - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} - -/// POST /api/admin/credentials -/// 添加新凭据 -pub async fn add_credential( - State(state): State, - Json(payload): Json, -) -> impl IntoResponse { - match state.service.add_credential(payload).await { - Ok(response) => Json(response).into_response(), - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} - -/// DELETE /api/admin/credentials/:id -/// 删除凭据 -pub async fn delete_credential( - State(state): State, - Path(id): Path, -) -> impl IntoResponse { - match state.service.delete_credential(id) { - Ok(_) => Json(SuccessResponse::new(format!("凭据 #{} 已删除", id))).into_response(), - Err(e) => (e.status_code(), Json(e.into_response())).into_response(), - } -} diff --git a/src/admin/middleware.rs b/src/admin/middleware.rs deleted file mode 100644 index af11af917580e5d748c00f265d860a14e74b02a2..0000000000000000000000000000000000000000 --- a/src/admin/middleware.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Admin API 中间件 - -use std::sync::Arc; - -use axum::{ - body::Body, - extract::State, - http::{Request, StatusCode}, - middleware::Next, - response::{IntoResponse, Json, Response}, -}; - -use super::service::AdminService; -use super::types::AdminErrorResponse; -use crate::common::auth; - -/// Admin API 共享状态 -#[derive(Clone)] -pub struct AdminState { - /// Admin API 密钥 - pub admin_api_key: String, - /// Admin 服务 - pub service: Arc, -} - -impl AdminState { - pub fn new(admin_api_key: impl Into, service: AdminService) -> Self { - Self { - admin_api_key: admin_api_key.into(), - service: Arc::new(service), - } - } -} - -/// Admin API 认证中间件 -pub async fn admin_auth_middleware( - State(state): State, - request: Request, - next: Next, -) -> Response { - let api_key = auth::extract_api_key(&request); - - match api_key { - Some(key) if auth::constant_time_eq(&key, &state.admin_api_key) => next.run(request).await, - _ => { - let error = AdminErrorResponse::authentication_error(); - (StatusCode::UNAUTHORIZED, Json(error)).into_response() - } - } -} diff --git a/src/admin/mod.rs b/src/admin/mod.rs deleted file mode 100644 index 21321f85b4b82acf1eabf1e59a52951ac097f0a3..0000000000000000000000000000000000000000 --- a/src/admin/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Admin API 模块 -//! -//! 提供凭据管理和监控功能的 HTTP API -//! -//! # 功能 -//! - 查询所有凭据状态 -//! - 启用/禁用凭据 -//! - 修改凭据优先级 -//! - 重置失败计数 -//! - 查询凭据余额 -//! -//! # 使用 -//! ```ignore -//! let admin_service = AdminService::new(token_manager.clone()); -//! let admin_state = AdminState::new(admin_api_key, admin_service); -//! let admin_router = create_admin_router(admin_state); -//! ``` - -mod error; -mod handlers; -mod middleware; -mod router; -mod service; -pub mod types; - -pub use middleware::AdminState; -pub use router::create_admin_router; -pub use service::AdminService; diff --git a/src/admin/router.rs b/src/admin/router.rs deleted file mode 100644 index c833dc5260efad64e2bcefa36bd34da4e201d231..0000000000000000000000000000000000000000 --- a/src/admin/router.rs +++ /dev/null @@ -1,47 +0,0 @@ -//! Admin API 路由配置 - -use axum::{ - Router, middleware, - routing::{delete, get, post}, -}; - -use super::{ - handlers::{ - add_credential, delete_credential, get_all_credentials, get_credential_balance, - reset_failure_count, set_credential_disabled, set_credential_priority, - }, - middleware::{AdminState, admin_auth_middleware}, -}; - -/// 创建 Admin API 路由 -/// -/// # 端点 -/// - `GET /credentials` - 获取所有凭据状态 -/// - `POST /credentials` - 添加新凭据 -/// - `DELETE /credentials/:id` - 删除凭据 -/// - `POST /credentials/:id/disabled` - 设置凭据禁用状态 -/// - `POST /credentials/:id/priority` - 设置凭据优先级 -/// - `POST /credentials/:id/reset` - 重置失败计数 -/// - `GET /credentials/:id/balance` - 获取凭据余额 -/// -/// # 认证 -/// 需要 Admin API Key 认证,支持: -/// - `x-api-key` header -/// - `Authorization: Bearer ` header -pub fn create_admin_router(state: AdminState) -> Router { - Router::new() - .route( - "/credentials", - get(get_all_credentials).post(add_credential), - ) - .route("/credentials/{id}", delete(delete_credential)) - .route("/credentials/{id}/disabled", post(set_credential_disabled)) - .route("/credentials/{id}/priority", post(set_credential_priority)) - .route("/credentials/{id}/reset", post(reset_failure_count)) - .route("/credentials/{id}/balance", get(get_credential_balance)) - .layer(middleware::from_fn_with_state( - state.clone(), - admin_auth_middleware, - )) - .with_state(state) -} diff --git a/src/admin/service.rs b/src/admin/service.rs deleted file mode 100644 index f6affa5a0b72b3fd66241321289bc076cfc30c43..0000000000000000000000000000000000000000 --- a/src/admin/service.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! Admin API 业务逻辑服务 - -use std::sync::Arc; - -use crate::kiro::model::credentials::KiroCredentials; -use crate::kiro::token_manager::MultiTokenManager; - -use super::error::AdminServiceError; -use super::types::{ - AddCredentialRequest, AddCredentialResponse, BalanceResponse, CredentialStatusItem, - CredentialsStatusResponse, -}; - -/// Admin 服务 -/// -/// 封装所有 Admin API 的业务逻辑 -pub struct AdminService { - token_manager: Arc, -} - -impl AdminService { - pub fn new(token_manager: Arc) -> Self { - Self { token_manager } - } - - /// 获取所有凭据状态 - pub fn get_all_credentials(&self) -> CredentialsStatusResponse { - let snapshot = self.token_manager.snapshot(); - - let mut credentials: Vec = snapshot - .entries - .into_iter() - .map(|entry| CredentialStatusItem { - id: entry.id, - priority: entry.priority, - disabled: entry.disabled, - failure_count: entry.failure_count, - is_current: entry.id == snapshot.current_id, - expires_at: entry.expires_at, - auth_method: entry.auth_method, - has_profile_arn: entry.has_profile_arn, - }) - .collect(); - - // 按优先级排序(数字越小优先级越高) - credentials.sort_by_key(|c| c.priority); - - CredentialsStatusResponse { - total: snapshot.total, - available: snapshot.available, - current_id: snapshot.current_id, - credentials, - } - } - - /// 设置凭据禁用状态 - pub fn set_disabled(&self, id: u64, disabled: bool) -> Result<(), AdminServiceError> { - // 先获取当前凭据 ID,用于判断是否需要切换 - let snapshot = self.token_manager.snapshot(); - let current_id = snapshot.current_id; - - self.token_manager - .set_disabled(id, disabled) - .map_err(|e| self.classify_error(e, id))?; - - // 只有禁用的是当前凭据时才尝试切换到下一个 - if disabled && id == current_id { - let _ = self.token_manager.switch_to_next(); - } - Ok(()) - } - - /// 设置凭据优先级 - pub fn set_priority(&self, id: u64, priority: u32) -> Result<(), AdminServiceError> { - self.token_manager - .set_priority(id, priority) - .map_err(|e| self.classify_error(e, id)) - } - - /// 重置失败计数并重新启用 - pub fn reset_and_enable(&self, id: u64) -> Result<(), AdminServiceError> { - self.token_manager - .reset_and_enable(id) - .map_err(|e| self.classify_error(e, id)) - } - - /// 获取凭据余额 - pub async fn get_balance(&self, id: u64) -> Result { - let usage = self - .token_manager - .get_usage_limits_for(id) - .await - .map_err(|e| self.classify_balance_error(e, id))?; - - let current_usage = usage.current_usage(); - let usage_limit = usage.usage_limit(); - let remaining = (usage_limit - current_usage).max(0.0); - let usage_percentage = if usage_limit > 0.0 { - (current_usage / usage_limit * 100.0).min(100.0) - } else { - 0.0 - }; - - Ok(BalanceResponse { - id, - subscription_title: usage.subscription_title().map(|s| s.to_string()), - current_usage, - usage_limit, - remaining, - usage_percentage, - next_reset_at: usage.next_date_reset, - }) - } - - /// 添加新凭据 - pub async fn add_credential( - &self, - req: AddCredentialRequest, - ) -> Result { - // 构建凭据对象 - let new_cred = KiroCredentials { - id: None, - access_token: None, - refresh_token: Some(req.refresh_token), - profile_arn: None, - expires_at: None, - auth_method: Some(req.auth_method), - client_id: req.client_id, - client_secret: req.client_secret, - priority: req.priority, - region: req.region, - machine_id: req.machine_id, - }; - - // 调用 token_manager 添加凭据 - let credential_id = self - .token_manager - .add_credential(new_cred) - .await - .map_err(|e| self.classify_add_error(e))?; - - Ok(AddCredentialResponse { - success: true, - message: format!("凭据添加成功,ID: {}", credential_id), - credential_id, - }) - } - - /// 删除凭据 - pub fn delete_credential(&self, id: u64) -> Result<(), AdminServiceError> { - self.token_manager - .delete_credential(id) - .map_err(|e| self.classify_delete_error(e, id)) - } - - /// 分类简单操作错误(set_disabled, set_priority, reset_and_enable) - fn classify_error(&self, e: anyhow::Error, id: u64) -> AdminServiceError { - let msg = e.to_string(); - if msg.contains("不存在") { - AdminServiceError::NotFound { id } - } else { - AdminServiceError::InternalError(msg) - } - } - - /// 分类余额查询错误(可能涉及上游 API 调用) - fn classify_balance_error(&self, e: anyhow::Error, id: u64) -> AdminServiceError { - let msg = e.to_string(); - - // 1. 凭据不存在 - if msg.contains("不存在") { - return AdminServiceError::NotFound { id }; - } - - // 2. 上游服务错误特征:HTTP 响应错误或网络错误 - let is_upstream_error = - // HTTP 响应错误(来自 refresh_*_token 的错误消息) - msg.contains("凭证已过期或无效") || - msg.contains("权限不足") || - msg.contains("已被限流") || - msg.contains("服务器错误") || - msg.contains("Token 刷新失败") || - msg.contains("暂时不可用") || - // 网络错误(reqwest 错误) - msg.contains("error trying to connect") || - msg.contains("connection") || - msg.contains("timeout") || - msg.contains("timed out"); - - if is_upstream_error { - AdminServiceError::UpstreamError(msg) - } else { - // 3. 默认归类为内部错误(本地验证失败、配置错误等) - // 包括:缺少 refreshToken、refreshToken 已被截断、无法生成 machineId 等 - AdminServiceError::InternalError(msg) - } - } - - /// 分类添加凭据错误 - fn classify_add_error(&self, e: anyhow::Error) -> AdminServiceError { - let msg = e.to_string(); - - // 凭据验证失败(refreshToken 无效、格式错误等) - let is_invalid_credential = msg.contains("缺少 refreshToken") - || msg.contains("refreshToken 为空") - || msg.contains("refreshToken 已被截断") - || msg.contains("凭证已过期或无效") - || msg.contains("权限不足") - || msg.contains("已被限流"); - - if is_invalid_credential { - AdminServiceError::InvalidCredential(msg) - } else if msg.contains("error trying to connect") - || msg.contains("connection") - || msg.contains("timeout") - { - AdminServiceError::UpstreamError(msg) - } else { - AdminServiceError::InternalError(msg) - } - } - - /// 分类删除凭据错误 - fn classify_delete_error(&self, e: anyhow::Error, id: u64) -> AdminServiceError { - let msg = e.to_string(); - if msg.contains("不存在") { - AdminServiceError::NotFound { id } - } else if msg.contains("只能删除已禁用的凭据") { - AdminServiceError::InvalidCredential(msg) - } else { - AdminServiceError::InternalError(msg) - } - } -} diff --git a/src/admin/types.rs b/src/admin/types.rs deleted file mode 100644 index 52cd593ea9723de0bcdd0eb1f05f6977390ec91d..0000000000000000000000000000000000000000 --- a/src/admin/types.rs +++ /dev/null @@ -1,187 +0,0 @@ -//! Admin API 类型定义 - -use serde::{Deserialize, Serialize}; - -// ============ 凭据状态 ============ - -/// 所有凭据状态响应 -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CredentialsStatusResponse { - /// 凭据总数 - pub total: usize, - /// 可用凭据数量(未禁用) - pub available: usize, - /// 当前活跃凭据 ID - pub current_id: u64, - /// 各凭据状态列表 - pub credentials: Vec, -} - -/// 单个凭据的状态信息 -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CredentialStatusItem { - /// 凭据唯一 ID - pub id: u64, - /// 优先级(数字越小优先级越高) - pub priority: u32, - /// 是否被禁用 - pub disabled: bool, - /// 连续失败次数 - pub failure_count: u32, - /// 是否为当前活跃凭据 - pub is_current: bool, - /// Token 过期时间(RFC3339 格式) - pub expires_at: Option, - /// 认证方式 - pub auth_method: Option, - /// 是否有 Profile ARN - pub has_profile_arn: bool, -} - -// ============ 操作请求 ============ - -/// 启用/禁用凭据请求 -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SetDisabledRequest { - /// 是否禁用 - pub disabled: bool, -} - -/// 修改优先级请求 -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SetPriorityRequest { - /// 新优先级值 - pub priority: u32, -} - -/// 添加凭据请求 -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct AddCredentialRequest { - /// 刷新令牌(必填) - pub refresh_token: String, - - /// 认证方式(可选,默认 social) - #[serde(default = "default_auth_method")] - pub auth_method: String, - - /// OIDC Client ID(IdC 认证需要) - pub client_id: Option, - - /// OIDC Client Secret(IdC 认证需要) - pub client_secret: Option, - - /// 优先级(可选,默认 0) - #[serde(default)] - pub priority: u32, - - /// 凭据级 Region 配置(用于 OIDC token 刷新) - /// 未配置时回退到 config.json 的全局 region - pub region: Option, - - /// 凭据级 Machine ID(可选,64 位字符串) - /// 未配置时回退到 config.json 的 machineId - pub machine_id: Option, -} - -fn default_auth_method() -> String { - "social".to_string() -} - -/// 添加凭据成功响应 -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct AddCredentialResponse { - pub success: bool, - pub message: String, - /// 新添加的凭据 ID - pub credential_id: u64, -} - -// ============ 余额查询 ============ - -/// 余额查询响应 -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct BalanceResponse { - /// 凭据 ID - pub id: u64, - /// 订阅类型 - pub subscription_title: Option, - /// 当前使用量 - pub current_usage: f64, - /// 使用限额 - pub usage_limit: f64, - /// 剩余额度 - pub remaining: f64, - /// 使用百分比 - pub usage_percentage: f64, - /// 下次重置时间(Unix 时间戳) - pub next_reset_at: Option, -} - -// ============ 通用响应 ============ - -/// 操作成功响应 -#[derive(Debug, Serialize)] -pub struct SuccessResponse { - pub success: bool, - pub message: String, -} - -impl SuccessResponse { - pub fn new(message: impl Into) -> Self { - Self { - success: true, - message: message.into(), - } - } -} - -/// 错误响应 -#[derive(Debug, Serialize)] -pub struct AdminErrorResponse { - pub error: AdminError, -} - -#[derive(Debug, Serialize)] -pub struct AdminError { - #[serde(rename = "type")] - pub error_type: String, - pub message: String, -} - -impl AdminErrorResponse { - pub fn new(error_type: impl Into, message: impl Into) -> Self { - Self { - error: AdminError { - error_type: error_type.into(), - message: message.into(), - }, - } - } - - pub fn invalid_request(message: impl Into) -> Self { - Self::new("invalid_request", message) - } - - pub fn authentication_error() -> Self { - Self::new("authentication_error", "Invalid or missing admin API key") - } - - pub fn not_found(message: impl Into) -> Self { - Self::new("not_found", message) - } - - pub fn api_error(message: impl Into) -> Self { - Self::new("api_error", message) - } - - pub fn internal_error(message: impl Into) -> Self { - Self::new("internal_error", message) - } -} diff --git a/src/admin_ui/mod.rs b/src/admin_ui/mod.rs deleted file mode 100644 index 9537d2827434bf97f9e2eee1d1869399a28249e1..0000000000000000000000000000000000000000 --- a/src/admin_ui/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Admin UI 静态文件服务模块 -//! -//! 使用 rust-embed 嵌入前端构建产物 - -mod router; - -pub use router::create_admin_ui_router; diff --git a/src/admin_ui/router.rs b/src/admin_ui/router.rs deleted file mode 100644 index 36d8bc7a0e9d08d2b38e77f6f30e98f4ee38cb96..0000000000000000000000000000000000000000 --- a/src/admin_ui/router.rs +++ /dev/null @@ -1,109 +0,0 @@ -//! Admin UI 路由配置 - -use axum::{ - Router, - body::Body, - http::{Response, StatusCode, Uri, header}, - response::IntoResponse, - routing::get, -}; -use rust_embed::Embed; - -/// 嵌入前端构建产物 -#[derive(Embed)] -#[folder = "admin-ui/dist"] -struct Asset; - -/// 创建 Admin UI 路由 -pub fn create_admin_ui_router() -> Router { - Router::new() - .route("/", get(index_handler)) - .route("/{*file}", get(static_handler)) -} - -/// 处理首页请求 -async fn index_handler() -> impl IntoResponse { - serve_index() -} - -/// 处理静态文件请求 -async fn static_handler(uri: Uri) -> impl IntoResponse { - let path = uri.path().trim_start_matches('/'); - - // 安全检查:拒绝包含 .. 的路径 - if path.contains("..") { - return Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid path")) - .expect("Failed to build response"); - } - - // 尝试获取请求的文件 - if let Some(content) = Asset::get(path) { - let mime = mime_guess::from_path(path) - .first_or_octet_stream() - .to_string(); - - // 根据文件类型设置不同的缓存策略 - let cache_control = get_cache_control(path); - - return Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, mime) - .header(header::CACHE_CONTROL, cache_control) - .body(Body::from(content.data.into_owned())) - .expect("Failed to build response"); - } - - // SPA fallback: 如果文件不存在且不是资源文件,返回 index.html - if !is_asset_path(path) { - return serve_index(); - } - - // 404 - Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from("Not found")) - .expect("Failed to build response") -} - -/// 提供 index.html -fn serve_index() -> Response { - match Asset::get("index.html") { - Some(content) => Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "text/html; charset=utf-8") - .header(header::CACHE_CONTROL, "no-cache") - .body(Body::from(content.data.into_owned())) - .expect("Failed to build response"), - None => Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from( - "Admin UI not built. Run 'pnpm build' in admin-ui directory.", - )) - .expect("Failed to build response"), - } -} - -/// 根据文件类型返回合适的缓存策略 -fn get_cache_control(path: &str) -> &'static str { - if path.ends_with(".html") { - // HTML 文件不缓存,确保用户获取最新版本 - "no-cache" - } else if path.starts_with("assets/") { - // assets/ 目录下的文件带有内容哈希,可以长期缓存 - "public, max-age=31536000, immutable" - } else { - // 其他文件(如 favicon)使用较短的缓存 - "public, max-age=3600" - } -} - -/// 判断是否为资源文件路径(有扩展名的文件) -fn is_asset_path(path: &str) -> bool { - // 检查最后一个路径段是否包含扩展名 - path.rsplit('/') - .next() - .map(|filename| filename.contains('.')) - .unwrap_or(false) -} diff --git a/src/anthropic/converter.rs b/src/anthropic/converter.rs deleted file mode 100644 index 874ff283b39f2656ed679af0d653f4ac1b018b67..0000000000000000000000000000000000000000 --- a/src/anthropic/converter.rs +++ /dev/null @@ -1,1118 +0,0 @@ -//! Anthropic → Kiro 协议转换器 -//! -//! 负责将 Anthropic API 请求格式转换为 Kiro API 请求格式 - -use uuid::Uuid; - -use crate::kiro::model::requests::conversation::{ - AssistantMessage, ConversationState, CurrentMessage, HistoryAssistantMessage, - HistoryUserMessage, KiroImage, Message, UserInputMessage, UserInputMessageContext, UserMessage, -}; -use crate::kiro::model::requests::tool::{ - InputSchema, Tool, ToolResult, ToolSpecification, ToolUseEntry, -}; - -use super::types::{ContentBlock, MessagesRequest, Thinking}; - -/// 模型映射:将 Anthropic 模型名映射到 Kiro 模型 ID -/// -/// 按照用户要求: -/// - 所有 sonnet → claude-sonnet-4.5 -/// - 所有 opus → claude-opus-4.5 -/// - 所有 haiku → claude-haiku-4.5 -pub fn map_model(model: &str) -> Option { - let model_lower = model.to_lowercase(); - - if model_lower.contains("sonnet") { - Some("claude-sonnet-4.5".to_string()) - } else if model_lower.contains("opus") { - Some("claude-opus-4.5".to_string()) - } else if model_lower.contains("haiku") { - Some("claude-haiku-4.5".to_string()) - } else { - None - } -} - -/// 转换结果 -#[derive(Debug)] -pub struct ConversionResult { - /// 转换后的 Kiro 请求 - pub conversation_state: ConversationState, -} - -/// 转换错误 -#[derive(Debug)] -pub enum ConversionError { - UnsupportedModel(String), - EmptyMessages, -} - -impl std::fmt::Display for ConversionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ConversionError::UnsupportedModel(model) => write!(f, "模型不支持: {}", model), - ConversionError::EmptyMessages => write!(f, "消息列表为空"), - } - } -} - -impl std::error::Error for ConversionError {} - -/// 从 metadata.user_id 中提取 session UUID -/// -/// user_id 格式: user_xxx_account__session_0b4445e1-f5be-49e1-87ce-62bbc28ad705 -/// 提取 session_ 后面的 UUID 作为 conversationId -fn extract_session_id(user_id: &str) -> Option { - // 查找 "session_" 后面的内容 - if let Some(pos) = user_id.find("session_") { - let session_part = &user_id[pos + 8..]; // "session_" 长度为 8 - // session_part 应该是 UUID 格式: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx - // 验证是否是有效的 UUID 格式(36 字符,包含 4 个连字符) - if session_part.len() >= 36 { - let uuid_str = &session_part[..36]; - // 简单验证 UUID 格式 - if uuid_str.chars().filter(|c| *c == '-').count() == 4 { - return Some(uuid_str.to_string()); - } - } - } - None -} - -/// 收集历史消息中使用的所有工具名称 -fn collect_history_tool_names(history: &[Message]) -> Vec { - let mut tool_names = Vec::new(); - - for msg in history { - if let Message::Assistant(assistant_msg) = msg { - if let Some(ref tool_uses) = assistant_msg.assistant_response_message.tool_uses { - for tool_use in tool_uses { - if !tool_names.contains(&tool_use.name) { - tool_names.push(tool_use.name.clone()); - } - } - } - } - } - - tool_names -} - -/// 为历史中使用但不在 tools 列表中的工具创建占位符定义 -/// Kiro API 要求:历史消息中引用的工具必须在 currentMessage.tools 中有定义 -fn create_placeholder_tool(name: &str) -> Tool { - Tool { - tool_specification: ToolSpecification { - name: name.to_string(), - description: "Tool used in conversation history".to_string(), - input_schema: InputSchema::from_json(serde_json::json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {}, - "required": [], - "additionalProperties": true - })), - }, - } -} - -/// 将 Anthropic 请求转换为 Kiro 请求 -pub fn convert_request(req: &MessagesRequest) -> Result { - // 1. 映射模型 - let model_id = map_model(&req.model) - .ok_or_else(|| ConversionError::UnsupportedModel(req.model.clone()))?; - - // 2. 检查消息列表 - if req.messages.is_empty() { - return Err(ConversionError::EmptyMessages); - } - - // 3. 生成会话 ID 和代理 ID - // 优先从 metadata.user_id 中提取 session UUID 作为 conversationId - let conversation_id = req - .metadata - .as_ref() - .and_then(|m| m.user_id.as_ref()) - .and_then(|user_id| extract_session_id(user_id)) - .unwrap_or_else(|| Uuid::new_v4().to_string()); - let agent_continuation_id = Uuid::new_v4().to_string(); - - // 4. 确定触发类型 - let chat_trigger_type = determine_chat_trigger_type(req); - - // 5. 处理最后一条消息作为 current_message - let last_message = req.messages.last().unwrap(); - let (text_content, images, tool_results) = process_message_content(&last_message.content)?; - - // 6. 转换工具定义 - let mut tools = convert_tools(&req.tools); - - // 7. 构建历史消息(需要先构建,以便收集历史中使用的工具) - let history = build_history(req, &model_id)?; - - // 8. 验证并过滤 tool_use/tool_result 配对 - // 移除孤立的 tool_result(没有对应的 tool_use) - let validated_tool_results = validate_tool_pairing(&history, &tool_results); - - // 9. 收集历史中使用的工具名称,为缺失的工具生成占位符定义 - // Kiro API 要求:历史消息中引用的工具必须在 tools 列表中有定义 - // 注意:Kiro 匹配工具名称时忽略大小写,所以这里也需要忽略大小写比较 - let history_tool_names = collect_history_tool_names(&history); - let existing_tool_names: std::collections::HashSet<_> = tools - .iter() - .map(|t| t.tool_specification.name.to_lowercase()) - .collect(); - - for tool_name in history_tool_names { - if !existing_tool_names.contains(&tool_name.to_lowercase()) { - tools.push(create_placeholder_tool(&tool_name)); - } - } - - // 10. 构建 UserInputMessageContext - let mut context = UserInputMessageContext::new(); - if !tools.is_empty() { - context = context.with_tools(tools); - } - if !validated_tool_results.is_empty() { - context = context.with_tool_results(validated_tool_results); - } - - // 11. 构建当前消息 - // 保留文本内容,即使有工具结果也不丢弃用户文本 - let content = text_content; - - let mut user_input = UserInputMessage::new(content, &model_id) - .with_context(context) - .with_origin("AI_EDITOR"); - - if !images.is_empty() { - user_input = user_input.with_images(images); - } - - let current_message = CurrentMessage::new(user_input); - - // 12. 构建 ConversationState - let conversation_state = ConversationState::new(conversation_id) - .with_agent_continuation_id(agent_continuation_id) - .with_agent_task_type("vibe") - .with_chat_trigger_type(chat_trigger_type) - .with_current_message(current_message) - .with_history(history); - - Ok(ConversionResult { conversation_state }) -} - -/// 确定聊天触发类型 -/// "AUTO" 模式可能会导致 400 Bad Request 错误 -fn determine_chat_trigger_type(_req: &MessagesRequest) -> String { - "MANUAL".to_string() -} - -/// 处理消息内容,提取文本、图片和工具结果 -fn process_message_content( - content: &serde_json::Value, -) -> Result<(String, Vec, Vec), ConversionError> { - let mut text_parts = Vec::new(); - let mut images = Vec::new(); - let mut tool_results = Vec::new(); - - match content { - serde_json::Value::String(s) => { - text_parts.push(s.clone()); - } - serde_json::Value::Array(arr) => { - for item in arr { - if let Ok(block) = serde_json::from_value::(item.clone()) { - match block.block_type.as_str() { - "text" => { - if let Some(text) = block.text { - text_parts.push(text); - } - } - "image" => { - if let Some(source) = block.source { - if let Some(format) = get_image_format(&source.media_type) { - images.push(KiroImage::from_base64(format, source.data)); - } - } - } - "tool_result" => { - if let Some(tool_use_id) = block.tool_use_id { - let result_content = extract_tool_result_content(&block.content); - let is_error = block.is_error.unwrap_or(false); - - let mut result = if is_error { - ToolResult::error(&tool_use_id, result_content) - } else { - ToolResult::success(&tool_use_id, result_content) - }; - result.status = - Some(if is_error { "error" } else { "success" }.to_string()); - - tool_results.push(result); - } - } - "tool_use" => { - // tool_use 在 assistant 消息中处理,这里忽略 - } - _ => {} - } - } - } - } - _ => {} - } - - Ok((text_parts.join("\n"), images, tool_results)) -} - -/// 从 media_type 获取图片格式 -fn get_image_format(media_type: &str) -> Option { - match media_type { - "image/jpeg" => Some("jpeg".to_string()), - "image/png" => Some("png".to_string()), - "image/gif" => Some("gif".to_string()), - "image/webp" => Some("webp".to_string()), - _ => None, - } -} - -/// 提取工具结果内容 -fn extract_tool_result_content(content: &Option) -> String { - match content { - Some(serde_json::Value::String(s)) => s.clone(), - Some(serde_json::Value::Array(arr)) => { - let mut parts = Vec::new(); - for item in arr { - if let Some(text) = item.get("text").and_then(|v| v.as_str()) { - parts.push(text.to_string()); - } - } - parts.join("\n") - } - Some(v) => v.to_string(), - None => String::new(), - } -} - -/// 验证并过滤 tool_use/tool_result 配对 -/// -/// 收集所有 tool_use_id,验证 tool_result 是否匹配 -/// 静默跳过孤立的 tool_use 和 tool_result,输出警告日志 -/// -/// # Arguments -/// * `history` - 历史消息引用 -/// * `tool_results` - 当前消息中的 tool_result 列表 -/// -/// # Returns -/// 经过验证和过滤后的 tool_result 列表 -fn validate_tool_pairing(history: &[Message], tool_results: &[ToolResult]) -> Vec { - use std::collections::HashSet; - - // 1. 收集所有历史中的 tool_use_id - let mut all_tool_use_ids: HashSet = HashSet::new(); - // 2. 收集历史中已经有 tool_result 的 tool_use_id - let mut history_tool_result_ids: HashSet = HashSet::new(); - - for msg in history { - match msg { - Message::Assistant(assistant_msg) => { - if let Some(ref tool_uses) = assistant_msg.assistant_response_message.tool_uses { - for tool_use in tool_uses { - all_tool_use_ids.insert(tool_use.tool_use_id.clone()); - } - } - } - Message::User(user_msg) => { - // 收集历史 user 消息中的 tool_results - for result in &user_msg.user_input_message.user_input_message_context.tool_results - { - history_tool_result_ids.insert(result.tool_use_id.clone()); - } - } - } - } - - // 3. 计算真正未配对的 tool_use_ids(排除历史中已配对的) - let mut unpaired_tool_use_ids: HashSet = all_tool_use_ids - .difference(&history_tool_result_ids) - .cloned() - .collect(); - - // 4. 过滤并验证当前消息的 tool_results - let mut filtered_results = Vec::new(); - - for result in tool_results { - if unpaired_tool_use_ids.contains(&result.tool_use_id) { - // 配对成功 - filtered_results.push(result.clone()); - unpaired_tool_use_ids.remove(&result.tool_use_id); - } else if all_tool_use_ids.contains(&result.tool_use_id) { - // tool_use 存在但已经在历史中配对过了,这是重复的 tool_result - tracing::warn!( - "跳过重复的 tool_result:该 tool_use 已在历史中配对,tool_use_id={}", - result.tool_use_id - ); - } else { - // 孤立 tool_result - 找不到对应的 tool_use - tracing::warn!( - "跳过孤立的 tool_result:找不到对应的 tool_use,tool_use_id={}", - result.tool_use_id - ); - } - } - - // 5. 检测真正孤立的 tool_use(有 tool_use 但在历史和当前消息中都没有 tool_result) - for orphaned_id in &unpaired_tool_use_ids { - tracing::warn!( - "检测到孤立的 tool_use:找不到对应的 tool_result,tool_use_id={}", - orphaned_id - ); - } - - filtered_results -} - -/// 转换工具定义 -fn convert_tools(tools: &Option>) -> Vec { - let Some(tools) = tools else { - return Vec::new(); - }; - - tools - .iter() - .map(|t| { - let description = t.description.clone(); - // 限制描述长度为 10000 字符(安全截断 UTF-8,单次遍历) - let description = match description.char_indices().nth(10000) { - Some((idx, _)) => description[..idx].to_string(), - None => description, - }; - - Tool { - tool_specification: ToolSpecification { - name: t.name.clone(), - description, - input_schema: InputSchema::from_json(serde_json::json!(t.input_schema)), - }, - } - }) - .collect() -} - -/// 生成thinking标签前缀 -fn generate_thinking_prefix(thinking: &Option) -> Option { - if let Some(t) = thinking { - if t.thinking_type == "enabled" { - return Some(format!( - "enabled{}", - t.budget_tokens - )); - } - } - None -} - -/// 检查内容是否已包含thinking标签 -fn has_thinking_tags(content: &str) -> bool { - content.contains("") || content.contains("") -} - -/// 构建历史消息 -fn build_history(req: &MessagesRequest, model_id: &str) -> Result, ConversionError> { - let mut history = Vec::new(); - - // 生成thinking前缀(如果需要) - let thinking_prefix = generate_thinking_prefix(&req.thinking); - - // 1. 处理系统消息 - if let Some(ref system) = req.system { - let system_content: String = system - .iter() - .map(|s| s.text.clone()) - .collect::>() - .join("\n"); - - if !system_content.is_empty() { - // 注入thinking标签到系统消息最前面(如果需要且不存在) - let final_content = if let Some(ref prefix) = thinking_prefix { - if !has_thinking_tags(&system_content) { - format!("{}\n{}", prefix, system_content) - } else { - system_content - } - } else { - system_content - }; - - // 系统消息作为 user + assistant 配对 - let user_msg = HistoryUserMessage::new(final_content, model_id); - history.push(Message::User(user_msg)); - - let assistant_msg = HistoryAssistantMessage::new("I will follow these instructions."); - history.push(Message::Assistant(assistant_msg)); - } - } else if let Some(ref prefix) = thinking_prefix { - // 没有系统消息但有thinking配置,插入新的系统消息 - let user_msg = HistoryUserMessage::new(prefix.clone(), model_id); - history.push(Message::User(user_msg)); - - let assistant_msg = HistoryAssistantMessage::new("I will follow these instructions."); - history.push(Message::Assistant(assistant_msg)); - } - - // 2. 处理常规消息历史 - // 最后一条消息作为 currentMessage,不加入历史 - let history_end_index = req.messages.len().saturating_sub(1); - - // 如果最后一条是 assistant,则包含在历史中 - let last_is_assistant = req - .messages - .last() - .map(|m| m.role == "assistant") - .unwrap_or(false); - - let history_end_index = if last_is_assistant { - req.messages.len() - } else { - history_end_index - }; - - // 收集并配对消息 - let mut user_buffer: Vec<&super::types::Message> = Vec::new(); - - for i in 0..history_end_index { - let msg = &req.messages[i]; - - if msg.role == "user" { - user_buffer.push(msg); - } else if msg.role == "assistant" { - // 遇到 assistant,处理累积的 user 消息 - if !user_buffer.is_empty() { - let merged_user = merge_user_messages(&user_buffer, model_id)?; - history.push(Message::User(merged_user)); - user_buffer.clear(); - - // 添加 assistant 消息 - let assistant = convert_assistant_message(msg)?; - history.push(Message::Assistant(assistant)); - } - } - } - - // 处理结尾的孤立 user 消息 - if !user_buffer.is_empty() { - let merged_user = merge_user_messages(&user_buffer, model_id)?; - history.push(Message::User(merged_user)); - - // 自动配对一个 "OK" 的 assistant 响应 - let auto_assistant = HistoryAssistantMessage::new("OK"); - history.push(Message::Assistant(auto_assistant)); - } - - Ok(history) -} - -/// 合并多个 user 消息 -fn merge_user_messages( - messages: &[&super::types::Message], - model_id: &str, -) -> Result { - let mut content_parts = Vec::new(); - let mut all_images = Vec::new(); - let mut all_tool_results = Vec::new(); - - for msg in messages { - let (text, images, tool_results) = process_message_content(&msg.content)?; - if !text.is_empty() { - content_parts.push(text); - } - all_images.extend(images); - all_tool_results.extend(tool_results); - } - - let content = content_parts.join("\n"); - // 保留文本内容,即使有工具结果也不丢弃用户文本 - let mut user_msg = UserMessage::new(&content, model_id); - - if !all_images.is_empty() { - user_msg = user_msg.with_images(all_images); - } - - if !all_tool_results.is_empty() { - let mut ctx = UserInputMessageContext::new(); - ctx = ctx.with_tool_results(all_tool_results); - user_msg = user_msg.with_context(ctx); - } - - Ok(HistoryUserMessage { - user_input_message: user_msg, - }) -} - -/// 转换 assistant 消息 -fn convert_assistant_message( - msg: &super::types::Message, -) -> Result { - let mut thinking_content = String::new(); - let mut text_content = String::new(); - let mut tool_uses = Vec::new(); - - match &msg.content { - serde_json::Value::String(s) => { - text_content = s.clone(); - } - serde_json::Value::Array(arr) => { - for item in arr { - if let Ok(block) = serde_json::from_value::(item.clone()) { - match block.block_type.as_str() { - "thinking" => { - if let Some(thinking) = block.thinking { - thinking_content.push_str(&thinking); - } - } - "text" => { - if let Some(text) = block.text { - text_content.push_str(&text); - } - } - "tool_use" => { - if let (Some(id), Some(name)) = (block.id, block.name) { - let input = block.input.unwrap_or(serde_json::json!({})); - tool_uses.push(ToolUseEntry::new(id, name).with_input(input)); - } - } - _ => {} - } - } - } - } - _ => {} - } - - // 组合 thinking 和 text 内容 - // 格式: 思考内容\n\ntext内容 - // 注意: Kiro API 要求 content 字段不能为空,当只有 tool_use 时需要占位符 - let final_content = if !thinking_content.is_empty() { - if !text_content.is_empty() { - format!( - "{}\n\n{}", - thinking_content, text_content - ) - } else { - format!("{}", thinking_content) - } - } else if text_content.is_empty() && !tool_uses.is_empty() { - "There is a tool use.".to_string() - } else { - text_content - }; - - let mut assistant = AssistantMessage::new(final_content); - if !tool_uses.is_empty() { - assistant = assistant.with_tool_uses(tool_uses); - } - - Ok(HistoryAssistantMessage { - assistant_response_message: assistant, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_map_model_sonnet() { - assert!( - map_model("claude-sonnet-4-20250514") - .unwrap() - .contains("sonnet") - ); - assert!( - map_model("claude-3-5-sonnet-20241022") - .unwrap() - .contains("sonnet") - ); - } - - #[test] - fn test_map_model_opus() { - assert!( - map_model("claude-opus-4-20250514") - .unwrap() - .contains("opus") - ); - } - - #[test] - fn test_map_model_haiku() { - assert!( - map_model("claude-haiku-4-20250514") - .unwrap() - .contains("haiku") - ); - } - - #[test] - fn test_map_model_unsupported() { - assert!(map_model("gpt-4").is_none()); - } - - #[test] - fn test_determine_chat_trigger_type() { - // 无工具时返回 MANUAL - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![], - stream: false, - system: None, - tools: None, - tool_choice: None, - thinking: None, - metadata: None, - }; - assert_eq!(determine_chat_trigger_type(&req), "MANUAL"); - } - - #[test] - fn test_collect_history_tool_names() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 创建包含工具使用的历史消息 - let mut assistant_msg = AssistantMessage::new("I'll read the file."); - assistant_msg = assistant_msg.with_tool_uses(vec![ - ToolUseEntry::new("tool-1", "read") - .with_input(serde_json::json!({"path": "/test.txt"})), - ToolUseEntry::new("tool-2", "write") - .with_input(serde_json::json!({"path": "/out.txt"})), - ]); - - let history = vec![ - Message::User(HistoryUserMessage::new( - "Read the file", - "claude-sonnet-4.5", - )), - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg, - }), - ]; - - let tool_names = collect_history_tool_names(&history); - assert_eq!(tool_names.len(), 2); - assert!(tool_names.contains(&"read".to_string())); - assert!(tool_names.contains(&"write".to_string())); - } - - #[test] - fn test_create_placeholder_tool() { - let tool = create_placeholder_tool("my_custom_tool"); - - assert_eq!(tool.tool_specification.name, "my_custom_tool"); - assert!(!tool.tool_specification.description.is_empty()); - - // 验证 JSON 序列化正确 - let json = serde_json::to_string(&tool).unwrap(); - assert!(json.contains("\"name\":\"my_custom_tool\"")); - } - - #[test] - fn test_history_tools_added_to_tools_list() { - use super::super::types::Message as AnthropicMessage; - - // 创建一个请求,历史中有工具使用,但 tools 列表为空 - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![ - AnthropicMessage { - role: "user".to_string(), - content: serde_json::json!("Read the file"), - }, - AnthropicMessage { - role: "assistant".to_string(), - content: serde_json::json!([ - {"type": "text", "text": "I'll read the file."}, - {"type": "tool_use", "id": "tool-1", "name": "read", "input": {"path": "/test.txt"}} - ]), - }, - AnthropicMessage { - role: "user".to_string(), - content: serde_json::json!([ - {"type": "tool_result", "tool_use_id": "tool-1", "content": "file content"} - ]), - }, - ], - stream: false, - system: None, - tools: None, // 没有提供工具定义 - tool_choice: None, - thinking: None, - metadata: None, - }; - - let result = convert_request(&req).unwrap(); - - // 验证 tools 列表中包含了历史中使用的工具的占位符定义 - let tools = &result - .conversation_state - .current_message - .user_input_message - .user_input_message_context - .tools; - - assert!(!tools.is_empty(), "tools 列表不应为空"); - assert!( - tools.iter().any(|t| t.tool_specification.name == "read"), - "tools 列表应包含 'read' 工具的占位符定义" - ); - } - - #[test] - fn test_extract_session_id_valid() { - // 测试有效的 user_id 格式 - let user_id = "user_0dede55c6dcc4a11a30bbb5e7f22e6fdf86cdeba3820019cc27612af4e1243cd_account__session_8bb5523b-ec7c-4540-a9ca-beb6d79f1552"; - let session_id = extract_session_id(user_id); - assert_eq!( - session_id, - Some("8bb5523b-ec7c-4540-a9ca-beb6d79f1552".to_string()) - ); - } - - #[test] - fn test_extract_session_id_no_session() { - // 测试没有 session 的 user_id - let user_id = "user_0dede55c6dcc4a11a30bbb5e7f22e6fdf86cdeba3820019cc27612af4e1243cd"; - let session_id = extract_session_id(user_id); - assert_eq!(session_id, None); - } - - #[test] - fn test_extract_session_id_invalid_uuid() { - // 测试无效的 UUID 格式 - let user_id = "user_xxx_session_invalid-uuid"; - let session_id = extract_session_id(user_id); - assert_eq!(session_id, None); - } - - #[test] - fn test_convert_request_with_session_metadata() { - use super::super::types::{Message as AnthropicMessage, Metadata}; - - // 测试带有 metadata 的请求,应该使用 session UUID 作为 conversationId - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![AnthropicMessage { - role: "user".to_string(), - content: serde_json::json!("Hello"), - }], - stream: false, - system: None, - tools: None, - tool_choice: None, - thinking: None, - metadata: Some(Metadata { - user_id: Some( - "user_0dede55c6dcc4a11a30bbb5e7f22e6fdf86cdeba3820019cc27612af4e1243cd_account__session_a0662283-7fd3-4399-a7eb-52b9a717ae88".to_string(), - ), - }), - }; - - let result = convert_request(&req).unwrap(); - assert_eq!( - result.conversation_state.conversation_id, - "a0662283-7fd3-4399-a7eb-52b9a717ae88" - ); - } - - #[test] - fn test_convert_request_without_metadata() { - use super::super::types::Message as AnthropicMessage; - - // 测试没有 metadata 的请求,应该生成新的 UUID - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![AnthropicMessage { - role: "user".to_string(), - content: serde_json::json!("Hello"), - }], - stream: false, - system: None, - tools: None, - tool_choice: None, - thinking: None, - metadata: None, - }; - - let result = convert_request(&req).unwrap(); - // 验证生成的是有效的 UUID 格式 - assert_eq!(result.conversation_state.conversation_id.len(), 36); - assert_eq!( - result - .conversation_state - .conversation_id - .chars() - .filter(|c| *c == '-') - .count(), - 4 - ); - } - - #[test] - fn test_validate_tool_pairing_orphaned_result() { - // 测试孤立的 tool_result 被过滤 - // 历史中没有 tool_use,但 tool_results 中有 tool_result - let history = vec![ - Message::User(HistoryUserMessage::new("Hello", "claude-sonnet-4.5")), - Message::Assistant(HistoryAssistantMessage::new("Hi there!")), - ]; - - let tool_results = vec![ToolResult::success("orphan-123", "some result")]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 孤立的 tool_result 应该被过滤掉 - assert!(filtered.is_empty(), "孤立的 tool_result 应该被过滤"); - } - - #[test] - fn test_validate_tool_pairing_orphaned_use() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 测试孤立的 tool_use(有 tool_use 但没有对应的 tool_result) - let mut assistant_msg = AssistantMessage::new("I'll read the file."); - assistant_msg = assistant_msg.with_tool_uses(vec![ToolUseEntry::new("tool-orphan", "read") - .with_input(serde_json::json!({"path": "/test.txt"}))]); - - let history = vec![ - Message::User(HistoryUserMessage::new( - "Read the file", - "claude-sonnet-4.5", - )), - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg, - }), - ]; - - // 没有 tool_result - let tool_results: Vec = vec![]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 结果应该为空(因为没有 tool_result) - // 同时应该输出警告日志(孤立的 tool_use) - assert!(filtered.is_empty()); - } - - #[test] - fn test_validate_tool_pairing_valid() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 测试正常配对的情况 - let mut assistant_msg = AssistantMessage::new("I'll read the file."); - assistant_msg = assistant_msg.with_tool_uses(vec![ToolUseEntry::new("tool-1", "read") - .with_input(serde_json::json!({"path": "/test.txt"}))]); - - let history = vec![ - Message::User(HistoryUserMessage::new( - "Read the file", - "claude-sonnet-4.5", - )), - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg, - }), - ]; - - let tool_results = vec![ToolResult::success("tool-1", "file content")]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 配对成功,应该保留 - assert_eq!(filtered.len(), 1); - assert_eq!(filtered[0].tool_use_id, "tool-1"); - } - - #[test] - fn test_validate_tool_pairing_mixed() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 测试混合情况:部分配对成功,部分孤立 - let mut assistant_msg = AssistantMessage::new("I'll use two tools."); - assistant_msg = assistant_msg.with_tool_uses(vec![ - ToolUseEntry::new("tool-1", "read").with_input(serde_json::json!({})), - ToolUseEntry::new("tool-2", "write").with_input(serde_json::json!({})), - ]); - - let history = vec![ - Message::User(HistoryUserMessage::new("Do something", "claude-sonnet-4.5")), - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg, - }), - ]; - - // tool_results: tool-1 配对,tool-3 孤立 - let tool_results = vec![ - ToolResult::success("tool-1", "result 1"), - ToolResult::success("tool-3", "orphan result"), // 孤立 - ]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 只有 tool-1 应该保留 - assert_eq!(filtered.len(), 1); - assert_eq!(filtered[0].tool_use_id, "tool-1"); - // tool-2 是孤立的 tool_use(无 result),tool-3 是孤立的 tool_result - } - - #[test] - fn test_validate_tool_pairing_history_already_paired() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 测试历史中已配对的 tool_use 不应该被报告为孤立 - // 场景:多轮对话中,之前的 tool_use 已经在历史中有对应的 tool_result - let mut assistant_msg1 = AssistantMessage::new("I'll read the file."); - assistant_msg1 = assistant_msg1.with_tool_uses(vec![ToolUseEntry::new("tool-1", "read") - .with_input(serde_json::json!({"path": "/test.txt"}))]); - - // 构建历史中的 user 消息,包含 tool_result - let mut user_msg_with_result = UserMessage::new("", "claude-sonnet-4.5"); - let mut ctx = UserInputMessageContext::new(); - ctx = ctx.with_tool_results(vec![ToolResult::success("tool-1", "file content")]); - user_msg_with_result = user_msg_with_result.with_context(ctx); - - let history = vec![ - // 第一轮:用户请求 - Message::User(HistoryUserMessage::new( - "Read the file", - "claude-sonnet-4.5", - )), - // 第一轮:assistant 使用工具 - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg1, - }), - // 第二轮:用户返回工具结果(历史中已配对) - Message::User(HistoryUserMessage { - user_input_message: user_msg_with_result, - }), - // 第二轮:assistant 响应 - Message::Assistant(HistoryAssistantMessage::new("The file contains...")), - ]; - - // 当前消息没有 tool_results(用户只是继续对话) - let tool_results: Vec = vec![]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 结果应该为空,且不应该有孤立 tool_use 的警告 - // 因为 tool-1 已经在历史中配对了 - assert!(filtered.is_empty()); - } - - #[test] - fn test_validate_tool_pairing_duplicate_result() { - use crate::kiro::model::requests::tool::ToolUseEntry; - - // 测试重复的 tool_result(历史中已配对,当前消息又发送了相同的 tool_result) - let mut assistant_msg = AssistantMessage::new("I'll read the file."); - assistant_msg = assistant_msg.with_tool_uses(vec![ToolUseEntry::new("tool-1", "read") - .with_input(serde_json::json!({"path": "/test.txt"}))]); - - // 历史中已有 tool_result - let mut user_msg_with_result = UserMessage::new("", "claude-sonnet-4.5"); - let mut ctx = UserInputMessageContext::new(); - ctx = ctx.with_tool_results(vec![ToolResult::success("tool-1", "file content")]); - user_msg_with_result = user_msg_with_result.with_context(ctx); - - let history = vec![ - Message::User(HistoryUserMessage::new( - "Read the file", - "claude-sonnet-4.5", - )), - Message::Assistant(HistoryAssistantMessage { - assistant_response_message: assistant_msg, - }), - Message::User(HistoryUserMessage { - user_input_message: user_msg_with_result, - }), - Message::Assistant(HistoryAssistantMessage::new("Done")), - ]; - - // 当前消息又发送了相同的 tool_result(重复) - let tool_results = vec![ToolResult::success("tool-1", "file content again")]; - - let filtered = validate_tool_pairing(&history, &tool_results); - - // 重复的 tool_result 应该被过滤掉 - assert!(filtered.is_empty(), "重复的 tool_result 应该被过滤"); - } - - #[test] - fn test_convert_assistant_message_tool_use_only() { - use super::super::types::Message as AnthropicMessage; - - // 测试仅包含 tool_use 的 assistant 消息(无 text 块) - // Kiro API 要求 content 字段不能为空 - let msg = AnthropicMessage { - role: "assistant".to_string(), - content: serde_json::json!([ - {"type": "tool_use", "id": "toolu_01ABC", "name": "read_file", "input": {"path": "/test.txt"}} - ]), - }; - - let result = convert_assistant_message(&msg).expect("应该成功转换"); - - // 验证 content 不为空(使用占位符) - assert!( - !result.assistant_response_message.content.is_empty(), - "content 不应为空" - ); - assert_eq!( - result.assistant_response_message.content, "There is a tool use.", - "仅 tool_use 时应使用 'There is a tool use.' 占位符" - ); - - // 验证 tool_uses 被正确保留 - let tool_uses = result - .assistant_response_message - .tool_uses - .expect("应该有 tool_uses"); - assert_eq!(tool_uses.len(), 1); - assert_eq!(tool_uses[0].tool_use_id, "toolu_01ABC"); - assert_eq!(tool_uses[0].name, "read_file"); - } - - #[test] - fn test_convert_assistant_message_with_text_and_tool_use() { - use super::super::types::Message as AnthropicMessage; - - // 测试同时包含 text 和 tool_use 的 assistant 消息 - let msg = AnthropicMessage { - role: "assistant".to_string(), - content: serde_json::json!([ - {"type": "text", "text": "Let me read that file for you."}, - {"type": "tool_use", "id": "toolu_02XYZ", "name": "read_file", "input": {"path": "/data.json"}} - ]), - }; - - let result = convert_assistant_message(&msg).expect("应该成功转换"); - - // 验证 content 使用原始文本(不是占位符) - assert_eq!( - result.assistant_response_message.content, - "Let me read that file for you." - ); - - // 验证 tool_uses 被正确保留 - let tool_uses = result - .assistant_response_message - .tool_uses - .expect("应该有 tool_uses"); - assert_eq!(tool_uses.len(), 1); - assert_eq!(tool_uses[0].tool_use_id, "toolu_02XYZ"); - } -} diff --git a/src/anthropic/handlers.rs b/src/anthropic/handlers.rs deleted file mode 100644 index 49861226acd6513eddee0ed334c5671427f678b7..0000000000000000000000000000000000000000 --- a/src/anthropic/handlers.rs +++ /dev/null @@ -1,523 +0,0 @@ -//! Anthropic API Handler 函数 - -use std::convert::Infallible; - -use crate::kiro::model::events::Event; -use crate::kiro::model::requests::kiro::KiroRequest; -use crate::kiro::parser::decoder::EventStreamDecoder; -use crate::token; -use axum::{ - Json as JsonExtractor, - body::Body, - extract::State, - http::{StatusCode, header}, - response::{IntoResponse, Json, Response}, -}; -use bytes::Bytes; -use futures::{Stream, StreamExt, stream}; -use serde_json::json; -use std::time::Duration; -use tokio::time::interval; -use uuid::Uuid; - -use super::converter::{ConversionError, convert_request}; -use super::middleware::AppState; -use super::stream::{SseEvent, StreamContext}; -use super::types::{ - CountTokensRequest, CountTokensResponse, ErrorResponse, MessagesRequest, Model, ModelsResponse, -}; -use super::websearch; - -/// GET /v1/models -/// -/// 返回可用的模型列表 -pub async fn get_models() -> impl IntoResponse { - tracing::info!("Received GET /v1/models request"); - - let models = vec![ - Model { - id: "claude-sonnet-4-5-20250929".to_string(), - object: "model".to_string(), - created: 1727568000, - owned_by: "anthropic".to_string(), - display_name: "Claude Sonnet 4.5".to_string(), - model_type: "chat".to_string(), - max_tokens: 32000, - }, - Model { - id: "claude-opus-4-5-20251101".to_string(), - object: "model".to_string(), - created: 1730419200, - owned_by: "anthropic".to_string(), - display_name: "Claude Opus 4.5".to_string(), - model_type: "chat".to_string(), - max_tokens: 32000, - }, - Model { - id: "claude-haiku-4-5-20251001".to_string(), - object: "model".to_string(), - created: 1727740800, - owned_by: "anthropic".to_string(), - display_name: "Claude Haiku 4.5".to_string(), - model_type: "chat".to_string(), - max_tokens: 32000, - }, - ]; - - Json(ModelsResponse { - object: "list".to_string(), - data: models, - }) -} - -/// POST /v1/messages -/// -/// 创建消息(对话) -pub async fn post_messages( - State(state): State, - JsonExtractor(payload): JsonExtractor, -) -> Response { - tracing::info!( - model = %payload.model, - max_tokens = %payload.max_tokens, - stream = %payload.stream, - message_count = %payload.messages.len(), - "Received POST /v1/messages request" - ); - // 检查 KiroProvider 是否可用 - let provider = match &state.kiro_provider { - Some(p) => p.clone(), - None => { - tracing::error!("KiroProvider 未配置"); - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(ErrorResponse::new( - "service_unavailable", - "Kiro API provider not configured", - )), - ) - .into_response(); - } - }; - - // 检查是否为 WebSearch 请求 - if websearch::has_web_search_tool(&payload) { - tracing::info!("检测到 WebSearch 工具,路由到 WebSearch 处理"); - - // 估算输入 tokens - let input_tokens = token::count_all_tokens( - payload.model.clone(), - payload.system.clone(), - payload.messages.clone(), - payload.tools.clone(), - ) as i32; - - return websearch::handle_websearch_request(provider, &payload, input_tokens).await; - } - - // 转换请求 - let conversion_result = match convert_request(&payload) { - Ok(result) => result, - Err(e) => { - let (error_type, message) = match &e { - ConversionError::UnsupportedModel(model) => { - ("invalid_request_error", format!("模型不支持: {}", model)) - } - ConversionError::EmptyMessages => { - ("invalid_request_error", "消息列表为空".to_string()) - } - }; - tracing::warn!("请求转换失败: {}", e); - return ( - StatusCode::BAD_REQUEST, - Json(ErrorResponse::new(error_type, message)), - ) - .into_response(); - } - }; - - // 构建 Kiro 请求 - let kiro_request = KiroRequest { - conversation_state: conversion_result.conversation_state, - profile_arn: state.profile_arn.clone(), - }; - - let request_body = match serde_json::to_string(&kiro_request) { - Ok(body) => body, - Err(e) => { - tracing::error!("序列化请求失败: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new( - "internal_error", - format!("序列化请求失败: {}", e), - )), - ) - .into_response(); - } - }; - - tracing::debug!("Kiro request body: {}", request_body); - - // 估算输入 tokens - let input_tokens = token::count_all_tokens( - payload.model.clone(), - payload.system, - payload.messages, - payload.tools, - ) as i32; - - // 检查是否启用了thinking - let thinking_enabled = payload - .thinking - .as_ref() - .map(|t| t.thinking_type == "enabled") - .unwrap_or(false); - - if payload.stream { - // 流式响应 - handle_stream_request( - provider, - &request_body, - &payload.model, - input_tokens, - thinking_enabled, - ) - .await - } else { - // 非流式响应 - handle_non_stream_request(provider, &request_body, &payload.model, input_tokens).await - } -} - -/// 处理流式请求 -async fn handle_stream_request( - provider: std::sync::Arc, - request_body: &str, - model: &str, - input_tokens: i32, - thinking_enabled: bool, -) -> Response { - // 调用 Kiro API(支持多凭据故障转移) - let response = match provider.call_api_stream(request_body).await { - Ok(resp) => resp, - Err(e) => { - tracing::error!("Kiro API 调用失败: {}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new( - "api_error", - format!("上游 API 调用失败: {}", e), - )), - ) - .into_response(); - } - }; - - // 创建流处理上下文 - let mut ctx = StreamContext::new_with_thinking(model, input_tokens, thinking_enabled); - - // 生成初始事件 - let initial_events = ctx.generate_initial_events(); - - // 创建 SSE 流 - let stream = create_sse_stream(response, ctx, initial_events); - - // 返回 SSE 响应 - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "text/event-stream") - .header(header::CACHE_CONTROL, "no-cache") - .header(header::CONNECTION, "keep-alive") - .body(Body::from_stream(stream)) - .unwrap() -} - -/// Ping 事件间隔(25秒) -const PING_INTERVAL_SECS: u64 = 25; - -/// 创建 ping 事件的 SSE 字符串 -fn create_ping_sse() -> Bytes { - Bytes::from("event: ping\ndata: {\"type\": \"ping\"}\n\n") -} - -/// 创建 SSE 事件流 -fn create_sse_stream( - response: reqwest::Response, - ctx: StreamContext, - initial_events: Vec, -) -> impl Stream> { - // 先发送初始事件 - let initial_stream = stream::iter( - initial_events - .into_iter() - .map(|e| Ok(Bytes::from(e.to_sse_string()))), - ); - - // 然后处理 Kiro 响应流,同时每25秒发送 ping 保活 - let body_stream = response.bytes_stream(); - - let processing_stream = stream::unfold( - (body_stream, ctx, EventStreamDecoder::new(), false, interval(Duration::from_secs(PING_INTERVAL_SECS))), - |(mut body_stream, mut ctx, mut decoder, finished, mut ping_interval)| async move { - if finished { - return None; - } - - // 使用 select! 同时等待数据和 ping 定时器 - tokio::select! { - // 处理数据流 - chunk_result = body_stream.next() => { - match chunk_result { - Some(Ok(chunk)) => { - // 解码事件 - if let Err(e) = decoder.feed(&chunk) { - tracing::warn!("缓冲区溢出: {}", e); - } - - let mut events = Vec::new(); - for result in decoder.decode_iter() { - match result { - Ok(frame) => { - if let Ok(event) = Event::from_frame(frame) { - let sse_events = ctx.process_kiro_event(&event); - events.extend(sse_events); - } - } - Err(e) => { - tracing::warn!("解码事件失败: {}", e); - } - } - } - - // 转换为 SSE 字节流 - let bytes: Vec> = events - .into_iter() - .map(|e| Ok(Bytes::from(e.to_sse_string()))) - .collect(); - - Some((stream::iter(bytes), (body_stream, ctx, decoder, false, ping_interval))) - } - Some(Err(e)) => { - tracing::error!("读取响应流失败: {}", e); - // 发送最终事件并结束 - let final_events = ctx.generate_final_events(); - let bytes: Vec> = final_events - .into_iter() - .map(|e| Ok(Bytes::from(e.to_sse_string()))) - .collect(); - Some((stream::iter(bytes), (body_stream, ctx, decoder, true, ping_interval))) - } - None => { - // 流结束,发送最终事件 - let final_events = ctx.generate_final_events(); - let bytes: Vec> = final_events - .into_iter() - .map(|e| Ok(Bytes::from(e.to_sse_string()))) - .collect(); - Some((stream::iter(bytes), (body_stream, ctx, decoder, true, ping_interval))) - } - } - } - // 发送 ping 保活 - _ = ping_interval.tick() => { - tracing::trace!("发送 ping 保活事件"); - let bytes: Vec> = vec![Ok(create_ping_sse())]; - Some((stream::iter(bytes), (body_stream, ctx, decoder, false, ping_interval))) - } - } - }, - ) - .flatten(); - - initial_stream.chain(processing_stream) -} - -/// 上下文窗口大小(200k tokens) -const CONTEXT_WINDOW_SIZE: i32 = 200_000; - -/// 处理非流式请求 -async fn handle_non_stream_request( - provider: std::sync::Arc, - request_body: &str, - model: &str, - input_tokens: i32, -) -> Response { - // 调用 Kiro API(支持多凭据故障转移) - let response = match provider.call_api(request_body).await { - Ok(resp) => resp, - Err(e) => { - tracing::error!("Kiro API 调用失败: {}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new( - "api_error", - format!("上游 API 调用失败: {}", e), - )), - ) - .into_response(); - } - }; - - // 读取响应体 - let body_bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(e) => { - tracing::error!("读取响应体失败: {}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new( - "api_error", - format!("读取响应失败: {}", e), - )), - ) - .into_response(); - } - }; - - // 解析事件流 - let mut decoder = EventStreamDecoder::new(); - if let Err(e) = decoder.feed(&body_bytes) { - tracing::warn!("缓冲区溢出: {}", e); - } - - let mut text_content = String::new(); - let mut tool_uses: Vec = Vec::new(); - let mut has_tool_use = false; - let mut stop_reason = "end_turn".to_string(); - // 从 contextUsageEvent 计算的实际输入 tokens - let mut context_input_tokens: Option = None; - - // 收集工具调用的增量 JSON - let mut tool_json_buffers: std::collections::HashMap = - std::collections::HashMap::new(); - - for result in decoder.decode_iter() { - match result { - Ok(frame) => { - if let Ok(event) = Event::from_frame(frame) { - match event { - Event::AssistantResponse(resp) => { - text_content.push_str(&resp.content); - } - Event::ToolUse(tool_use) => { - has_tool_use = true; - - // 累积工具的 JSON 输入 - let buffer = tool_json_buffers - .entry(tool_use.tool_use_id.clone()) - .or_insert_with(String::new); - buffer.push_str(&tool_use.input); - - // 如果是完整的工具调用,添加到列表 - if tool_use.stop { - let input: serde_json::Value = serde_json::from_str(buffer) - .unwrap_or_else(|e| { - tracing::warn!( - "工具输入 JSON 解析失败: {}, tool_use_id: {}, 原始内容: {}", - e, tool_use.tool_use_id, buffer - ); - serde_json::json!({}) - }); - - tool_uses.push(json!({ - "type": "tool_use", - "id": tool_use.tool_use_id, - "name": tool_use.name, - "input": input - })); - } - } - Event::ContextUsage(context_usage) => { - // 从上下文使用百分比计算实际的 input_tokens - // 公式: percentage * 200000 / 100 = percentage * 2000 - let actual_input_tokens = (context_usage.context_usage_percentage - * (CONTEXT_WINDOW_SIZE as f64) - / 100.0) - as i32; - context_input_tokens = Some(actual_input_tokens); - tracing::debug!( - "收到 contextUsageEvent: {}%, 计算 input_tokens: {}", - context_usage.context_usage_percentage, - actual_input_tokens - ); - } - Event::Exception { exception_type, .. } => { - if exception_type == "ContentLengthExceededException" { - stop_reason = "max_tokens".to_string(); - } - } - _ => {} - } - } - } - Err(e) => { - tracing::warn!("解码事件失败: {}", e); - } - } - } - - // 确定 stop_reason - if has_tool_use && stop_reason == "end_turn" { - stop_reason = "tool_use".to_string(); - } - - // 构建响应内容 - let mut content: Vec = Vec::new(); - - if !text_content.is_empty() { - content.push(json!({ - "type": "text", - "text": text_content - })); - } - - content.extend(tool_uses); - - // 估算输出 tokens - let output_tokens = token::estimate_output_tokens(&content); - - // 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值 - let final_input_tokens = context_input_tokens.unwrap_or(input_tokens); - - // 构建 Anthropic 响应 - let response_body = json!({ - "id": format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")), - "type": "message", - "role": "assistant", - "content": content, - "model": model, - "stop_reason": stop_reason, - "stop_sequence": null, - "usage": { - "input_tokens": final_input_tokens, - "output_tokens": output_tokens - } - }); - - (StatusCode::OK, Json(response_body)).into_response() -} - -/// POST /v1/messages/count_tokens -/// -/// 计算消息的 token 数量 -pub async fn count_tokens( - JsonExtractor(payload): JsonExtractor, -) -> impl IntoResponse { - tracing::info!( - model = %payload.model, - message_count = %payload.messages.len(), - "Received POST /v1/messages/count_tokens request" - ); - - let total_tokens = token::count_all_tokens( - payload.model, - payload.system, - payload.messages, - payload.tools, - ) as i32; - - Json(CountTokensResponse { - input_tokens: total_tokens.max(1) as i32, - }) -} diff --git a/src/anthropic/middleware.rs b/src/anthropic/middleware.rs deleted file mode 100644 index 731c3e0b4d116f6118567ff84d47ca2c2201fbd2..0000000000000000000000000000000000000000 --- a/src/anthropic/middleware.rs +++ /dev/null @@ -1,84 +0,0 @@ -//! Anthropic API 中间件 - -use std::sync::Arc; - -use axum::{ - body::Body, - extract::State, - http::{Request, StatusCode}, - middleware::Next, - response::{IntoResponse, Json, Response}, -}; - -use crate::common::auth; -use crate::kiro::provider::KiroProvider; - -use super::types::ErrorResponse; - -/// 应用共享状态 -#[derive(Clone)] -pub struct AppState { - /// API 密钥 - pub api_key: String, - /// Kiro Provider(可选,用于实际 API 调用) - /// 内部使用 MultiTokenManager,已支持线程安全的多凭据管理 - pub kiro_provider: Option>, - /// Profile ARN(可选,用于请求) - pub profile_arn: Option, -} - -impl AppState { - /// 创建新的应用状态 - pub fn new(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - kiro_provider: None, - profile_arn: None, - } - } - - /// 设置 KiroProvider - pub fn with_kiro_provider(mut self, provider: KiroProvider) -> Self { - self.kiro_provider = Some(Arc::new(provider)); - self - } - - /// 设置 Profile ARN - pub fn with_profile_arn(mut self, arn: impl Into) -> Self { - self.profile_arn = Some(arn.into()); - self - } -} - -/// API Key 认证中间件 -pub async fn auth_middleware( - State(state): State, - request: Request, - next: Next, -) -> Response { - match auth::extract_api_key(&request) { - Some(key) if auth::constant_time_eq(&key, &state.api_key) => next.run(request).await, - _ => { - let error = ErrorResponse::authentication_error(); - (StatusCode::UNAUTHORIZED, Json(error)).into_response() - } - } -} - -/// CORS 中间件层 -/// -/// **安全说明**:当前配置允许所有来源(Any),这是为了支持公开 API 服务。 -/// 如果需要更严格的安全控制,请根据实际需求配置具体的允许来源、方法和头信息。 -/// -/// # 配置说明 -/// - `allow_origin(Any)`: 允许任何来源的请求 -/// - `allow_methods(Any)`: 允许任何 HTTP 方法 -/// - `allow_headers(Any)`: 允许任何请求头 -pub fn cors_layer() -> tower_http::cors::CorsLayer { - use tower_http::cors::{Any, CorsLayer}; - - CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any) -} diff --git a/src/anthropic/mod.rs b/src/anthropic/mod.rs deleted file mode 100644 index a5f3842a7a63e4ca7f90df5bdf1cd3e3f9d07e1f..0000000000000000000000000000000000000000 --- a/src/anthropic/mod.rs +++ /dev/null @@ -1,27 +0,0 @@ -//! Anthropic API 兼容服务模块 -//! -//! 提供与 Anthropic Claude API 兼容的 HTTP 服务端点。 -//! -//! # 支持的端点 -//! - `GET /v1/models` - 获取可用模型列表 -//! - `POST /v1/messages` - 创建消息(对话) -//! - `POST /v1/messages/count_tokens` - 计算 token 数量 -//! -//! # 使用示例 -//! ```rust,ignore -//! use kiro_rs::anthropic; -//! -//! let app = anthropic::create_router("your-api-key"); -//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; -//! axum::serve(listener, app).await?; -//! ``` - -mod converter; -mod handlers; -mod middleware; -mod router; -mod stream; -pub mod types; -mod websearch; - -pub use router::create_router_with_provider; diff --git a/src/anthropic/router.rs b/src/anthropic/router.rs deleted file mode 100644 index c5aafc75a38c417511a3f79092e21c71e9ce70a7..0000000000000000000000000000000000000000 --- a/src/anthropic/router.rs +++ /dev/null @@ -1,65 +0,0 @@ -//! Anthropic API 路由配置 - -use axum::{ - Router, - extract::DefaultBodyLimit, - middleware, - routing::{get, post}, -}; - -use crate::kiro::provider::KiroProvider; - -use super::{ - handlers::{count_tokens, get_models, post_messages}, - middleware::{AppState, auth_middleware, cors_layer}, -}; - -/// 请求体最大大小限制 (50MB) -const MAX_BODY_SIZE: usize = 50 * 1024 * 1024; - -/// 创建 Anthropic API 路由 -/// -/// # 端点 -/// - `GET /v1/models` - 获取可用模型列表 -/// - `POST /v1/messages` - 创建消息(对话) -/// - `POST /v1/messages/count_tokens` - 计算 token 数量 -/// -/// # 认证 -/// 所有 `/v1` 路径需要 API Key 认证,支持: -/// - `x-api-key` header -/// - `Authorization: Bearer ` header -/// -/// # 参数 -/// - `api_key`: API 密钥,用于验证客户端请求 -/// - `kiro_provider`: 可选的 KiroProvider,用于调用上游 API - -/// 创建带有 KiroProvider 的 Anthropic API 路由 -pub fn create_router_with_provider( - api_key: impl Into, - kiro_provider: Option, - profile_arn: Option, -) -> Router { - let mut state = AppState::new(api_key); - if let Some(provider) = kiro_provider { - state = state.with_kiro_provider(provider); - } - if let Some(arn) = profile_arn { - state = state.with_profile_arn(arn); - } - - // 需要认证的 /v1 路由 - let v1_routes = Router::new() - .route("/models", get(get_models)) - .route("/messages", post(post_messages)) - .route("/messages/count_tokens", post(count_tokens)) - .layer(middleware::from_fn_with_state( - state.clone(), - auth_middleware, - )); - - Router::new() - .nest("/v1", v1_routes) - .layer(cors_layer()) - .layer(DefaultBodyLimit::max(MAX_BODY_SIZE)) - .with_state(state) -} diff --git a/src/anthropic/stream.rs b/src/anthropic/stream.rs deleted file mode 100644 index 1d52388040d5d89ad49078b27901adfdd61bf375..0000000000000000000000000000000000000000 --- a/src/anthropic/stream.rs +++ /dev/null @@ -1,1423 +0,0 @@ -//! 流式响应处理模块 -//! -//! 实现 Kiro → Anthropic 流式响应转换和 SSE 状态管理 - -use std::collections::HashMap; - -use serde_json::json; -use uuid::Uuid; - -use crate::kiro::model::events::Event; - -/// 找到小于等于目标位置的最近有效UTF-8字符边界 -/// -/// UTF-8字符可能占用1-4个字节,直接按字节位置切片可能会切在多字节字符中间导致panic。 -/// 这个函数从目标位置向前搜索,找到最近的有效字符边界。 -fn find_char_boundary(s: &str, target: usize) -> usize { - if target >= s.len() { - return s.len(); - } - if target == 0 { - return 0; - } - // 从目标位置向前搜索有效的字符边界 - let mut pos = target; - while pos > 0 && !s.is_char_boundary(pos) { - pos -= 1; - } - pos -} - -/// 需要跳过的包裹字符 -/// -/// 当 thinking 标签被这些字符包裹时,认为是在引用标签而非真正的标签: -/// - 反引号 (`):行内代码 -/// - 双引号 ("):字符串 -/// - 单引号 ('):字符串 -const QUOTE_CHARS: &[u8] = &[ - b'`', b'"', b'\'', b'\\', b'#', b'!', b'@', b'$', b'%', b'^', b'&', b'*', b'(', b')', b'-', - b'_', b'=', b'+', b'[', b']', b'{', b'}', b';', b':', b'<', b'>', b',', b'.', b'?', b'/', -]; - -/// 检查指定位置的字符是否是引用字符 -fn is_quote_char(buffer: &str, pos: usize) -> bool { - buffer - .as_bytes() - .get(pos) - .map(|c| QUOTE_CHARS.contains(c)) - .unwrap_or(false) -} - -/// 查找真正的 thinking 结束标签(不被引用字符包裹,且后面有双换行符) -/// -/// 当模型在思考过程中提到 `` 时,通常会用反引号、引号等包裹, -/// 或者在同一行有其他内容(如"关于 标签")。 -/// 这个函数会跳过这些情况,只返回真正的结束标签位置。 -/// -/// 跳过的情况: -/// - 被引用字符包裹(反引号、引号等) -/// - 后面没有双换行符(真正的结束标签后面会有 `\n\n`) -/// - 标签在缓冲区末尾(流式处理时需要等待更多内容) -/// -/// # 参数 -/// - `buffer`: 要搜索的字符串 -/// -/// # 返回值 -/// - `Some(pos)`: 真正的结束标签的起始位置 -/// - `None`: 没有找到真正的结束标签 -fn find_real_thinking_end_tag(buffer: &str) -> Option { - const TAG: &str = ""; - let mut search_start = 0; - - while let Some(pos) = buffer[search_start..].find(TAG) { - let absolute_pos = search_start + pos; - - // 检查前面是否有引用字符 - let has_quote_before = absolute_pos > 0 && is_quote_char(buffer, absolute_pos - 1); - - // 检查后面是否有引用字符 - let after_pos = absolute_pos + TAG.len(); - let has_quote_after = is_quote_char(buffer, after_pos); - - // 如果被引用字符包裹,跳过 - if has_quote_before || has_quote_after { - search_start = absolute_pos + 1; - continue; - } - - // 检查后面的内容 - let after_content = &buffer[after_pos..]; - - // 如果标签后面内容不足以判断是否有双换行符,等待更多内容 - if after_content.len() < 2 { - return None; - } - - // 真正的 thinking 结束标签后面会有双换行符 `\n\n` - if after_content.starts_with("\n\n") { - return Some(absolute_pos); - } - - // 不是双换行符,跳过继续搜索 - search_start = absolute_pos + 1; - } - - None -} - -/// 查找缓冲区末尾的 thinking 结束标签(允许末尾只有空白字符) -/// -/// 用于“边界事件”场景:例如 thinking 结束后立刻进入 tool_use,或流结束, -/// 此时 `` 后面可能没有 `\n\n`,但结束标签依然应被识别并过滤。 -/// -/// 约束:只有当 `` 之后全部都是空白字符时才认为是结束标签, -/// 以避免在 thinking 内容中提到 ``(非结束标签)时误判。 -fn find_real_thinking_end_tag_at_buffer_end(buffer: &str) -> Option { - const TAG: &str = ""; - let mut search_start = 0; - - while let Some(pos) = buffer[search_start..].find(TAG) { - let absolute_pos = search_start + pos; - - // 检查前面是否有引用字符 - let has_quote_before = absolute_pos > 0 && is_quote_char(buffer, absolute_pos - 1); - - // 检查后面是否有引用字符 - let after_pos = absolute_pos + TAG.len(); - let has_quote_after = is_quote_char(buffer, after_pos); - - if has_quote_before || has_quote_after { - search_start = absolute_pos + 1; - continue; - } - - // 只有当标签后面全部是空白字符时才认定为结束标签 - if buffer[after_pos..].trim().is_empty() { - return Some(absolute_pos); - } - - search_start = absolute_pos + 1; - } - - None -} - -/// 查找真正的 thinking 开始标签(不被引用字符包裹) -/// -/// 与 `find_real_thinking_end_tag` 类似,跳过被引用字符包裹的开始标签。 -fn find_real_thinking_start_tag(buffer: &str) -> Option { - const TAG: &str = ""; - let mut search_start = 0; - - while let Some(pos) = buffer[search_start..].find(TAG) { - let absolute_pos = search_start + pos; - - // 检查前面是否有引用字符 - let has_quote_before = absolute_pos > 0 && is_quote_char(buffer, absolute_pos - 1); - - // 检查后面是否有引用字符 - let after_pos = absolute_pos + TAG.len(); - let has_quote_after = is_quote_char(buffer, after_pos); - - // 如果不被引用字符包裹,则是真正的开始标签 - if !has_quote_before && !has_quote_after { - return Some(absolute_pos); - } - - // 继续搜索下一个匹配 - search_start = absolute_pos + 1; - } - - None -} - -/// SSE 事件 -#[derive(Debug, Clone)] -pub struct SseEvent { - pub event: String, - pub data: serde_json::Value, -} - -impl SseEvent { - pub fn new(event: impl Into, data: serde_json::Value) -> Self { - Self { - event: event.into(), - data, - } - } - - /// 格式化为 SSE 字符串 - pub fn to_sse_string(&self) -> String { - format!( - "event: {}\ndata: {}\n\n", - self.event, - serde_json::to_string(&self.data).unwrap_or_default() - ) - } -} - -/// 内容块状态 -#[derive(Debug, Clone)] -struct BlockState { - block_type: String, - started: bool, - stopped: bool, -} - -impl BlockState { - fn new(block_type: impl Into) -> Self { - Self { - block_type: block_type.into(), - started: false, - stopped: false, - } - } -} - -/// SSE 状态管理器 -/// -/// 确保 SSE 事件序列符合 Claude API 规范: -/// 1. message_start 只能出现一次 -/// 2. content_block 必须先 start 再 delta 再 stop -/// 3. message_delta 只能出现一次,且在所有 content_block_stop 之后 -/// 4. message_stop 在最后 -#[derive(Debug)] -pub struct SseStateManager { - /// message_start 是否已发送 - message_started: bool, - /// message_delta 是否已发送 - message_delta_sent: bool, - /// 活跃的内容块状态 - active_blocks: HashMap, - /// 消息是否已结束 - message_ended: bool, - /// 下一个块索引 - next_block_index: i32, - /// 当前 stop_reason - stop_reason: Option, - /// 是否有工具调用 - has_tool_use: bool, -} - -impl Default for SseStateManager { - fn default() -> Self { - Self::new() - } -} - -impl SseStateManager { - pub fn new() -> Self { - Self { - message_started: false, - message_delta_sent: false, - active_blocks: HashMap::new(), - message_ended: false, - next_block_index: 0, - stop_reason: None, - has_tool_use: false, - } - } - - /// 判断指定块是否处于可接收 delta 的打开状态 - fn is_block_open_of_type(&self, index: i32, expected_type: &str) -> bool { - self.active_blocks - .get(&index) - .is_some_and(|b| b.started && !b.stopped && b.block_type == expected_type) - } - - /// 获取下一个块索引 - pub fn next_block_index(&mut self) -> i32 { - let index = self.next_block_index; - self.next_block_index += 1; - index - } - - /// 记录工具调用 - pub fn set_has_tool_use(&mut self, has: bool) { - self.has_tool_use = has; - } - - /// 设置 stop_reason - pub fn set_stop_reason(&mut self, reason: impl Into) { - self.stop_reason = Some(reason.into()); - } - - /// 获取最终的 stop_reason - pub fn get_stop_reason(&self) -> String { - if let Some(ref reason) = self.stop_reason { - reason.clone() - } else if self.has_tool_use { - "tool_use".to_string() - } else { - "end_turn".to_string() - } - } - - /// 处理 message_start 事件 - pub fn handle_message_start(&mut self, event: serde_json::Value) -> Option { - if self.message_started { - tracing::debug!("跳过重复的 message_start 事件"); - return None; - } - self.message_started = true; - Some(SseEvent::new("message_start", event)) - } - - /// 处理 content_block_start 事件 - pub fn handle_content_block_start( - &mut self, - index: i32, - block_type: &str, - data: serde_json::Value, - ) -> Vec { - let mut events = Vec::new(); - - // 如果是 tool_use 块,先关闭之前的文本块 - if block_type == "tool_use" { - self.has_tool_use = true; - for (block_index, block) in self.active_blocks.iter_mut() { - if block.block_type == "text" && block.started && !block.stopped { - // 自动发送 content_block_stop 关闭文本块 - events.push(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": block_index - }), - )); - block.stopped = true; - } - } - } - - // 检查块是否已存在 - if let Some(block) = self.active_blocks.get_mut(&index) { - if block.started { - tracing::debug!("块 {} 已启动,跳过重复的 content_block_start", index); - return events; - } - block.started = true; - } else { - let mut block = BlockState::new(block_type); - block.started = true; - self.active_blocks.insert(index, block); - } - - events.push(SseEvent::new("content_block_start", data)); - events - } - - /// 处理 content_block_delta 事件 - pub fn handle_content_block_delta( - &mut self, - index: i32, - data: serde_json::Value, - ) -> Option { - // 确保块已启动 - if let Some(block) = self.active_blocks.get(&index) { - if !block.started || block.stopped { - tracing::warn!( - "块 {} 状态异常: started={}, stopped={}", - index, - block.started, - block.stopped - ); - return None; - } - } else { - // 块不存在,可能需要先创建 - tracing::warn!("收到未知块 {} 的 delta 事件", index); - return None; - } - - Some(SseEvent::new("content_block_delta", data)) - } - - /// 处理 content_block_stop 事件 - pub fn handle_content_block_stop(&mut self, index: i32) -> Option { - if let Some(block) = self.active_blocks.get_mut(&index) { - if block.stopped { - tracing::debug!("块 {} 已停止,跳过重复的 content_block_stop", index); - return None; - } - block.stopped = true; - return Some(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": index - }), - )); - } - None - } - - /// 生成最终事件序列 - pub fn generate_final_events( - &mut self, - input_tokens: i32, - output_tokens: i32, - ) -> Vec { - let mut events = Vec::new(); - - // 关闭所有未关闭的块 - for (index, block) in self.active_blocks.iter_mut() { - if block.started && !block.stopped { - events.push(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": index - }), - )); - block.stopped = true; - } - } - - // 发送 message_delta - if !self.message_delta_sent { - self.message_delta_sent = true; - events.push(SseEvent::new( - "message_delta", - json!({ - "type": "message_delta", - "delta": { - "stop_reason": self.get_stop_reason(), - "stop_sequence": null - }, - "usage": { - "input_tokens": input_tokens, - "output_tokens": output_tokens - } - }), - )); - } - - // 发送 message_stop - if !self.message_ended { - self.message_ended = true; - events.push(SseEvent::new( - "message_stop", - json!({ "type": "message_stop" }), - )); - } - - events - } -} - -/// 上下文窗口大小(200k tokens) -const CONTEXT_WINDOW_SIZE: i32 = 200_000; - -/// 流处理上下文 -pub struct StreamContext { - /// SSE 状态管理器 - pub state_manager: SseStateManager, - /// 请求的模型名称 - pub model: String, - /// 消息 ID - pub message_id: String, - /// 输入 tokens(估算值) - pub input_tokens: i32, - /// 从 contextUsageEvent 计算的实际输入 tokens - pub context_input_tokens: Option, - /// 输出 tokens 累计 - pub output_tokens: i32, - /// 工具块索引映射 (tool_id -> block_index) - pub tool_block_indices: HashMap, - /// thinking 是否启用 - pub thinking_enabled: bool, - /// thinking 内容缓冲区 - pub thinking_buffer: String, - /// 是否在 thinking 块内 - pub in_thinking_block: bool, - /// thinking 块是否已提取完成 - pub thinking_extracted: bool, - /// thinking 块索引 - pub thinking_block_index: Option, - /// 文本块索引(thinking 启用时动态分配) - pub text_block_index: Option, -} - -impl StreamContext { - /// 创建启用thinking的StreamContext - pub fn new_with_thinking( - model: impl Into, - input_tokens: i32, - thinking_enabled: bool, - ) -> Self { - Self { - state_manager: SseStateManager::new(), - model: model.into(), - message_id: format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")), - input_tokens, - context_input_tokens: None, - output_tokens: 0, - tool_block_indices: HashMap::new(), - thinking_enabled, - thinking_buffer: String::new(), - in_thinking_block: false, - thinking_extracted: false, - thinking_block_index: None, - text_block_index: None, - } - } - - /// 生成 message_start 事件 - pub fn create_message_start_event(&self) -> serde_json::Value { - json!({ - "type": "message_start", - "message": { - "id": self.message_id, - "type": "message", - "role": "assistant", - "content": [], - "model": self.model, - "stop_reason": null, - "stop_sequence": null, - "usage": { - "input_tokens": self.input_tokens, - "output_tokens": 1 - } - } - }) - } - - /// 生成初始事件序列 (message_start + 文本块 start) - /// - /// 当 thinking 启用时,不在初始化时创建文本块,而是等到实际收到内容时再创建。 - /// 这样可以确保 thinking 块(索引 0)在文本块(索引 1)之前。 - pub fn generate_initial_events(&mut self) -> Vec { - let mut events = Vec::new(); - - // message_start - let msg_start = self.create_message_start_event(); - if let Some(event) = self.state_manager.handle_message_start(msg_start) { - events.push(event); - } - - // 如果启用了 thinking,不在这里创建文本块 - // thinking 块和文本块会在 process_content_with_thinking 中按正确顺序创建 - if self.thinking_enabled { - return events; - } - - // 创建初始文本块(仅在未启用 thinking 时) - let text_block_index = self.state_manager.next_block_index(); - self.text_block_index = Some(text_block_index); - let text_block_events = self.state_manager.handle_content_block_start( - text_block_index, - "text", - json!({ - "type": "content_block_start", - "index": text_block_index, - "content_block": { - "type": "text", - "text": "" - } - }), - ); - events.extend(text_block_events); - - events - } - - /// 处理 Kiro 事件并转换为 Anthropic SSE 事件 - pub fn process_kiro_event(&mut self, event: &Event) -> Vec { - match event { - Event::AssistantResponse(resp) => self.process_assistant_response(&resp.content), - Event::ToolUse(tool_use) => self.process_tool_use(tool_use), - Event::ContextUsage(context_usage) => { - // 从上下文使用百分比计算实际的 input_tokens - // 公式: percentage * 200000 / 100 = percentage * 2000 - let actual_input_tokens = (context_usage.context_usage_percentage - * (CONTEXT_WINDOW_SIZE as f64) - / 100.0) as i32; - self.context_input_tokens = Some(actual_input_tokens); - tracing::debug!( - "收到 contextUsageEvent: {}%, 计算 input_tokens: {}", - context_usage.context_usage_percentage, - actual_input_tokens - ); - Vec::new() - } - Event::Error { - error_code, - error_message, - } => { - tracing::error!("收到错误事件: {} - {}", error_code, error_message); - Vec::new() - } - Event::Exception { - exception_type, - message, - } => { - // 处理 ContentLengthExceededException - if exception_type == "ContentLengthExceededException" { - self.state_manager.set_stop_reason("max_tokens"); - } - tracing::warn!("收到异常事件: {} - {}", exception_type, message); - Vec::new() - } - _ => Vec::new(), - } - } - - /// 处理助手响应事件 - fn process_assistant_response(&mut self, content: &str) -> Vec { - if content.is_empty() { - return Vec::new(); - } - - // 估算 tokens - self.output_tokens += estimate_tokens(content); - - // 如果启用了thinking,需要处理thinking块 - if self.thinking_enabled { - return self.process_content_with_thinking(content); - } - - // 非 thinking 模式同样复用统一的 text_delta 发送逻辑, - // 以便在 tool_use 自动关闭文本块后能够自愈重建新的文本块,避免“吞字”。 - self.create_text_delta_events(content) - } - - /// 处理包含thinking块的内容 - fn process_content_with_thinking(&mut self, content: &str) -> Vec { - let mut events = Vec::new(); - - // 将内容添加到缓冲区进行处理 - self.thinking_buffer.push_str(content); - - loop { - if !self.in_thinking_block && !self.thinking_extracted { - // 查找 开始标签(跳过被反引号包裹的) - if let Some(start_pos) = find_real_thinking_start_tag(&self.thinking_buffer) { - // 发送 之前的内容作为 text_delta - let before_thinking = self.thinking_buffer[..start_pos].to_string(); - if !before_thinking.is_empty() { - events.extend(self.create_text_delta_events(&before_thinking)); - } - - // 进入 thinking 块 - self.in_thinking_block = true; - self.thinking_buffer = - self.thinking_buffer[start_pos + "".len()..].to_string(); - - // 创建 thinking 块的 content_block_start 事件 - let thinking_index = self.state_manager.next_block_index(); - self.thinking_block_index = Some(thinking_index); - let start_events = self.state_manager.handle_content_block_start( - thinking_index, - "thinking", - json!({ - "type": "content_block_start", - "index": thinking_index, - "content_block": { - "type": "thinking", - "thinking": "" - } - }), - ); - events.extend(start_events); - } else { - // 没有找到 ,检查是否可能是部分标签 - // 保留可能是部分标签的内容 - let target_len = self - .thinking_buffer - .len() - .saturating_sub("".len()); - let safe_len = find_char_boundary(&self.thinking_buffer, target_len); - if safe_len > 0 { - let safe_content = self.thinking_buffer[..safe_len].to_string(); - if !safe_content.is_empty() { - events.extend(self.create_text_delta_events(&safe_content)); - } - self.thinking_buffer = self.thinking_buffer[safe_len..].to_string(); - } - break; - } - } else if self.in_thinking_block { - // 在 thinking 块内,查找 结束标签(跳过被反引号包裹的) - if let Some(end_pos) = find_real_thinking_end_tag(&self.thinking_buffer) { - // 提取 thinking 内容 - let thinking_content = self.thinking_buffer[..end_pos].to_string(); - if !thinking_content.is_empty() { - if let Some(thinking_index) = self.thinking_block_index { - events.push( - self.create_thinking_delta_event(thinking_index, &thinking_content), - ); - } - } - - // 结束 thinking 块 - self.in_thinking_block = false; - self.thinking_extracted = true; - - // 发送空的 thinking_delta 事件,然后发送 content_block_stop 事件 - if let Some(thinking_index) = self.thinking_block_index { - // 先发送空的 thinking_delta - events.push(self.create_thinking_delta_event(thinking_index, "")); - // 再发送 content_block_stop - if let Some(stop_event) = - self.state_manager.handle_content_block_stop(thinking_index) - { - events.push(stop_event); - } - } - - self.thinking_buffer = - self.thinking_buffer[end_pos + "".len()..].to_string(); - } else { - // 没有找到结束标签,发送当前缓冲区内容作为 thinking_delta - // 保留可能是部分标签的内容 - let target_len = self - .thinking_buffer - .len() - .saturating_sub("".len()); - let safe_len = find_char_boundary(&self.thinking_buffer, target_len); - if safe_len > 0 { - let safe_content = self.thinking_buffer[..safe_len].to_string(); - if !safe_content.is_empty() { - if let Some(thinking_index) = self.thinking_block_index { - events.push( - self.create_thinking_delta_event(thinking_index, &safe_content), - ); - } - } - self.thinking_buffer = self.thinking_buffer[safe_len..].to_string(); - } - break; - } - } else { - // thinking 已提取完成,剩余内容作为 text_delta - if !self.thinking_buffer.is_empty() { - let remaining = self.thinking_buffer.clone(); - self.thinking_buffer.clear(); - events.extend(self.create_text_delta_events(&remaining)); - } - break; - } - } - - events - } - - /// 创建 text_delta 事件 - /// - /// 如果文本块尚未创建,会先创建文本块。 - /// 当发生 tool_use 时,状态机会自动关闭当前文本块;后续文本会自动创建新的文本块继续输出。 - /// - /// 返回值包含可能的 content_block_start 事件和 content_block_delta 事件。 - fn create_text_delta_events(&mut self, text: &str) -> Vec { - let mut events = Vec::new(); - - // 如果当前 text_block_index 指向的块已经被关闭(例如 tool_use 开始时自动 stop), - // 则丢弃该索引并创建新的文本块继续输出,避免 delta 被状态机拒绝导致“吞字”。 - if let Some(idx) = self.text_block_index { - if !self.state_manager.is_block_open_of_type(idx, "text") { - self.text_block_index = None; - } - } - - // 获取或创建文本块索引 - let text_index = if let Some(idx) = self.text_block_index { - idx - } else { - // 文本块尚未创建,需要先创建 - let idx = self.state_manager.next_block_index(); - self.text_block_index = Some(idx); - - // 发送 content_block_start 事件 - let start_events = self.state_manager.handle_content_block_start( - idx, - "text", - json!({ - "type": "content_block_start", - "index": idx, - "content_block": { - "type": "text", - "text": "" - } - }), - ); - events.extend(start_events); - idx - }; - - // 发送 content_block_delta 事件 - if let Some(delta_event) = self.state_manager.handle_content_block_delta( - text_index, - json!({ - "type": "content_block_delta", - "index": text_index, - "delta": { - "type": "text_delta", - "text": text - } - }), - ) { - events.push(delta_event); - } - - events - } - - /// 创建 thinking_delta 事件 - fn create_thinking_delta_event(&self, index: i32, thinking: &str) -> SseEvent { - SseEvent::new( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { - "type": "thinking_delta", - "thinking": thinking - } - }), - ) - } - - /// 处理工具使用事件 - fn process_tool_use( - &mut self, - tool_use: &crate::kiro::model::events::ToolUseEvent, - ) -> Vec { - let mut events = Vec::new(); - - self.state_manager.set_has_tool_use(true); - - // tool_use 必须发生在 thinking 结束之后。 - // 但当 `` 后面没有 `\n\n`(例如紧跟 tool_use 或流结束)时, - // thinking 结束标签会滞留在 thinking_buffer,导致后续 flush 时把 `` 当作内容输出。 - // 这里在开始 tool_use block 前做一次“边界场景”的结束标签识别与过滤。 - if self.thinking_enabled && self.in_thinking_block { - if let Some(end_pos) = find_real_thinking_end_tag_at_buffer_end(&self.thinking_buffer) { - let thinking_content = self.thinking_buffer[..end_pos].to_string(); - if !thinking_content.is_empty() { - if let Some(thinking_index) = self.thinking_block_index { - events.push( - self.create_thinking_delta_event(thinking_index, &thinking_content), - ); - } - } - - // 结束 thinking 块 - self.in_thinking_block = false; - self.thinking_extracted = true; - - if let Some(thinking_index) = self.thinking_block_index { - // 先发送空的 thinking_delta - events.push(self.create_thinking_delta_event(thinking_index, "")); - // 再发送 content_block_stop - if let Some(stop_event) = - self.state_manager.handle_content_block_stop(thinking_index) - { - events.push(stop_event); - } - } - - // 把结束标签后的内容当作普通文本(通常为空或空白) - let after_pos = end_pos + "".len(); - let remaining = self.thinking_buffer[after_pos..].to_string(); - self.thinking_buffer.clear(); - if !remaining.is_empty() { - events.extend(self.create_text_delta_events(&remaining)); - } - } - } - - // thinking 模式下,process_content_with_thinking 可能会为了探测 `` 而暂存一小段尾部文本。 - // 如果此时直接开始 tool_use,状态机会自动关闭 text block,导致这段“待输出文本”看起来被 tool_use 吞掉。 - // 约束:只在尚未进入 thinking block、且 thinking 尚未被提取时,将缓冲区当作普通文本 flush。 - if self.thinking_enabled - && !self.in_thinking_block - && !self.thinking_extracted - && !self.thinking_buffer.is_empty() - { - let buffered = std::mem::take(&mut self.thinking_buffer); - events.extend(self.create_text_delta_events(&buffered)); - } - - // 获取或分配块索引 - let block_index = if let Some(&idx) = self.tool_block_indices.get(&tool_use.tool_use_id) { - idx - } else { - let idx = self.state_manager.next_block_index(); - self.tool_block_indices - .insert(tool_use.tool_use_id.clone(), idx); - idx - }; - - // 发送 content_block_start - let start_events = self.state_manager.handle_content_block_start( - block_index, - "tool_use", - json!({ - "type": "content_block_start", - "index": block_index, - "content_block": { - "type": "tool_use", - "id": tool_use.tool_use_id, - "name": tool_use.name, - "input": {} - } - }), - ); - events.extend(start_events); - - // 发送参数增量 (ToolUseEvent.input 是 String 类型) - if !tool_use.input.is_empty() { - self.output_tokens += (tool_use.input.len() as i32 + 3) / 4; // 估算 token - - if let Some(delta_event) = self.state_manager.handle_content_block_delta( - block_index, - json!({ - "type": "content_block_delta", - "index": block_index, - "delta": { - "type": "input_json_delta", - "partial_json": tool_use.input - } - }), - ) { - events.push(delta_event); - } - } - - // 如果是完整的工具调用(stop=true),发送 content_block_stop - if tool_use.stop { - if let Some(stop_event) = self.state_manager.handle_content_block_stop(block_index) { - events.push(stop_event); - } - } - - events - } - - /// 生成最终事件序列 - pub fn generate_final_events(&mut self) -> Vec { - let mut events = Vec::new(); - - // Flush thinking_buffer 中的剩余内容 - if self.thinking_enabled && !self.thinking_buffer.is_empty() { - if self.in_thinking_block { - // 末尾可能残留 ``(例如紧跟 tool_use 或流结束),需要在 flush 时过滤掉结束标签。 - if let Some(end_pos) = - find_real_thinking_end_tag_at_buffer_end(&self.thinking_buffer) - { - let thinking_content = self.thinking_buffer[..end_pos].to_string(); - if !thinking_content.is_empty() { - if let Some(thinking_index) = self.thinking_block_index { - events.push( - self.create_thinking_delta_event(thinking_index, &thinking_content), - ); - } - } - - // 关闭 thinking 块:先发送空的 thinking_delta,再发送 content_block_stop - if let Some(thinking_index) = self.thinking_block_index { - events.push(self.create_thinking_delta_event(thinking_index, "")); - if let Some(stop_event) = - self.state_manager.handle_content_block_stop(thinking_index) - { - events.push(stop_event); - } - } - - // 把结束标签后的内容当作普通文本(通常为空或空白) - let after_pos = end_pos + "".len(); - let remaining = self.thinking_buffer[after_pos..].to_string(); - self.thinking_buffer.clear(); - self.in_thinking_block = false; - self.thinking_extracted = true; - if !remaining.is_empty() { - events.extend(self.create_text_delta_events(&remaining)); - } - } else { - // 如果还在 thinking 块内,发送剩余内容作为 thinking_delta - if let Some(thinking_index) = self.thinking_block_index { - events.push( - self.create_thinking_delta_event(thinking_index, &self.thinking_buffer), - ); - } - // 关闭 thinking 块:先发送空的 thinking_delta,再发送 content_block_stop - if let Some(thinking_index) = self.thinking_block_index { - // 先发送空的 thinking_delta - events.push(self.create_thinking_delta_event(thinking_index, "")); - // 再发送 content_block_stop - if let Some(stop_event) = - self.state_manager.handle_content_block_stop(thinking_index) - { - events.push(stop_event); - } - } - } - } else { - // 否则发送剩余内容作为 text_delta - let buffer_content = self.thinking_buffer.clone(); - events.extend(self.create_text_delta_events(&buffer_content)); - } - self.thinking_buffer.clear(); - } - - // 使用从 contextUsageEvent 计算的 input_tokens,如果没有则使用估算值 - let final_input_tokens = self.context_input_tokens.unwrap_or(self.input_tokens); - - // 生成最终事件 - events.extend( - self.state_manager - .generate_final_events(final_input_tokens, self.output_tokens), - ); - events - } -} - -/// 简单的 token 估算 -fn estimate_tokens(text: &str) -> i32 { - let chars: Vec = text.chars().collect(); - let mut chinese_count = 0; - let mut other_count = 0; - - for c in &chars { - if *c >= '\u{4E00}' && *c <= '\u{9FFF}' { - chinese_count += 1; - } else { - other_count += 1; - } - } - - // 中文约 1.5 字符/token,英文约 4 字符/token - let chinese_tokens = (chinese_count * 2 + 2) / 3; - let other_tokens = (other_count + 3) / 4; - - (chinese_tokens + other_tokens).max(1) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sse_event_format() { - let event = SseEvent::new("message_start", json!({"type": "message_start"})); - let sse_str = event.to_sse_string(); - - assert!(sse_str.starts_with("event: message_start\n")); - assert!(sse_str.contains("data: ")); - assert!(sse_str.ends_with("\n\n")); - } - - #[test] - fn test_sse_state_manager_message_start() { - let mut manager = SseStateManager::new(); - - // 第一次应该成功 - let event = manager.handle_message_start(json!({"type": "message_start"})); - assert!(event.is_some()); - - // 第二次应该被跳过 - let event = manager.handle_message_start(json!({"type": "message_start"})); - assert!(event.is_none()); - } - - #[test] - fn test_sse_state_manager_block_lifecycle() { - let mut manager = SseStateManager::new(); - - // 创建块 - let events = manager.handle_content_block_start(0, "text", json!({})); - assert_eq!(events.len(), 1); - - // delta - let event = manager.handle_content_block_delta(0, json!({})); - assert!(event.is_some()); - - // stop - let event = manager.handle_content_block_stop(0); - assert!(event.is_some()); - - // 重复 stop 应该被跳过 - let event = manager.handle_content_block_stop(0); - assert!(event.is_none()); - } - - #[test] - fn test_text_delta_after_tool_use_restarts_text_block() { - let mut ctx = StreamContext::new_with_thinking("test-model", 1, false); - - let initial_events = ctx.generate_initial_events(); - assert!( - initial_events - .iter() - .any(|e| e.event == "content_block_start" - && e.data["content_block"]["type"] == "text") - ); - - let initial_text_index = ctx - .text_block_index - .expect("initial text block index should exist"); - - // tool_use 开始会自动关闭现有 text block - let tool_events = ctx.process_tool_use(&crate::kiro::model::events::ToolUseEvent { - name: "test_tool".to_string(), - tool_use_id: "tool_1".to_string(), - input: "{}".to_string(), - stop: false, - }); - assert!( - tool_events.iter().any(|e| { - e.event == "content_block_stop" - && e.data["index"].as_i64() == Some(initial_text_index as i64) - }), - "tool_use should stop the previous text block" - ); - - // 之后再来文本增量,应自动创建新的 text block 而不是往已 stop 的块里写 delta - let text_events = ctx.process_assistant_response("hello"); - let new_text_start_index = text_events.iter().find_map(|e| { - if e.event == "content_block_start" && e.data["content_block"]["type"] == "text" { - e.data["index"].as_i64() - } else { - None - } - }); - assert!( - new_text_start_index.is_some(), - "should start a new text block" - ); - assert_ne!( - new_text_start_index.unwrap(), - initial_text_index as i64, - "new text block index should differ from the stopped one" - ); - assert!( - text_events.iter().any(|e| { - e.event == "content_block_delta" - && e.data["delta"]["type"] == "text_delta" - && e.data["delta"]["text"] == "hello" - }), - "should emit text_delta after restarting text block" - ); - } - - #[test] - fn test_tool_use_flushes_pending_thinking_buffer_text_before_tool_block() { - // thinking 模式下,短文本可能被暂存在 thinking_buffer 以等待 `` 的跨 chunk 匹配。 - // 当紧接着出现 tool_use 时,应先 flush 这段文本,再开始 tool_use block。 - let mut ctx = StreamContext::new_with_thinking("test-model", 1, true); - let _initial_events = ctx.generate_initial_events(); - - // 两段短文本(各 2 个中文字符),总长度仍可能不足以满足 safe_len>0 的输出条件, - // 因而会留在 thinking_buffer 中等待后续 chunk。 - let ev1 = ctx.process_assistant_response("有修"); - assert!( - ev1.iter().all(|e| e.event != "content_block_delta"), - "short prefix should be buffered under thinking mode" - ); - let ev2 = ctx.process_assistant_response("改:"); - assert!( - ev2.iter().all(|e| e.event != "content_block_delta"), - "short prefix should still be buffered under thinking mode" - ); - - let events = ctx.process_tool_use(&crate::kiro::model::events::ToolUseEvent { - name: "Write".to_string(), - tool_use_id: "tool_1".to_string(), - input: "{}".to_string(), - stop: false, - }); - - let text_start_index = events.iter().find_map(|e| { - if e.event == "content_block_start" && e.data["content_block"]["type"] == "text" { - e.data["index"].as_i64() - } else { - None - } - }); - let pos_text_delta = events.iter().position(|e| { - e.event == "content_block_delta" && e.data["delta"]["type"] == "text_delta" - }); - let pos_text_stop = text_start_index.and_then(|idx| { - events.iter().position(|e| { - e.event == "content_block_stop" && e.data["index"].as_i64() == Some(idx) - }) - }); - let pos_tool_start = events.iter().position(|e| { - e.event == "content_block_start" && e.data["content_block"]["type"] == "tool_use" - }); - - assert!( - text_start_index.is_some(), - "should start a text block to flush buffered text" - ); - assert!( - pos_text_delta.is_some(), - "should flush buffered text as text_delta" - ); - assert!( - pos_text_stop.is_some(), - "should stop text block before tool_use block starts" - ); - assert!(pos_tool_start.is_some(), "should start tool_use block"); - - let pos_text_delta = pos_text_delta.unwrap(); - let pos_text_stop = pos_text_stop.unwrap(); - let pos_tool_start = pos_tool_start.unwrap(); - - assert!( - pos_text_delta < pos_text_stop && pos_text_stop < pos_tool_start, - "ordering should be: text_delta -> text_stop -> tool_use_start" - ); - - assert!( - events.iter().any(|e| { - e.event == "content_block_delta" - && e.data["delta"]["type"] == "text_delta" - && e.data["delta"]["text"] == "有修改:" - }), - "flushed text should equal the buffered prefix" - ); - } - - #[test] - fn test_estimate_tokens() { - assert!(estimate_tokens("Hello") > 0); - assert!(estimate_tokens("你好") > 0); - assert!(estimate_tokens("Hello 你好") > 0); - } - - #[test] - fn test_find_real_thinking_start_tag_basic() { - // 基本情况:正常的开始标签 - assert_eq!(find_real_thinking_start_tag(""), Some(0)); - assert_eq!(find_real_thinking_start_tag("prefix"), Some(6)); - } - - #[test] - fn test_find_real_thinking_start_tag_with_backticks() { - // 被反引号包裹的应该被跳过 - assert_eq!(find_real_thinking_start_tag("``"), None); - assert_eq!(find_real_thinking_start_tag("use `` tag"), None); - - // 先有被包裹的,后有真正的开始标签 - assert_eq!( - find_real_thinking_start_tag("about `` tagcontent"), - Some(22) - ); - } - - #[test] - fn test_find_real_thinking_start_tag_with_quotes() { - // 被双引号包裹的应该被跳过 - assert_eq!(find_real_thinking_start_tag("\"\""), None); - assert_eq!(find_real_thinking_start_tag("the \"\" tag"), None); - - // 被单引号包裹的应该被跳过 - assert_eq!(find_real_thinking_start_tag("''"), None); - - // 混合情况 - assert_eq!( - find_real_thinking_start_tag("about \"\" and '' then"), - Some(40) - ); - } - - #[test] - fn test_find_real_thinking_end_tag_basic() { - // 基本情况:正常的结束标签后面有双换行符 - assert_eq!(find_real_thinking_end_tag("\n\n"), Some(0)); - assert_eq!( - find_real_thinking_end_tag("content\n\n"), - Some(7) - ); - assert_eq!( - find_real_thinking_end_tag("some text\n\nmore text"), - Some(9) - ); - - // 没有双换行符的情况 - assert_eq!(find_real_thinking_end_tag(""), None); - assert_eq!(find_real_thinking_end_tag("\n"), None); - assert_eq!(find_real_thinking_end_tag(" more"), None); - } - - #[test] - fn test_find_real_thinking_end_tag_with_backticks() { - // 被反引号包裹的应该被跳过 - assert_eq!(find_real_thinking_end_tag("``\n\n"), None); - assert_eq!( - find_real_thinking_end_tag("mention `` in code\n\n"), - None - ); - - // 只有前面有反引号 - assert_eq!(find_real_thinking_end_tag("`\n\n"), None); - - // 只有后面有反引号 - assert_eq!(find_real_thinking_end_tag("`\n\n"), None); - } - - #[test] - fn test_find_real_thinking_end_tag_with_quotes() { - // 被双引号包裹的应该被跳过 - assert_eq!(find_real_thinking_end_tag("\"\"\n\n"), None); - assert_eq!( - find_real_thinking_end_tag("the string \"\" is a tag\n\n"), - None - ); - - // 被单引号包裹的应该被跳过 - assert_eq!(find_real_thinking_end_tag("''\n\n"), None); - assert_eq!( - find_real_thinking_end_tag("use '' as marker\n\n"), - None - ); - - // 混合情况:双引号包裹后有真正的标签 - assert_eq!( - find_real_thinking_end_tag("about \"\" tag\n\n"), - Some(23) - ); - - // 混合情况:单引号包裹后有真正的标签 - assert_eq!( - find_real_thinking_end_tag("about '' tag\n\n"), - Some(23) - ); - } - - #[test] - fn test_find_real_thinking_end_tag_mixed() { - // 先有被包裹的,后有真正的结束标签 - assert_eq!( - find_real_thinking_end_tag("discussing `` tag\n\n"), - Some(28) - ); - - // 多个被包裹的,最后一个是真正的 - assert_eq!( - find_real_thinking_end_tag("`` and `` done\n\n"), - Some(36) - ); - - // 多种引用字符混合 - assert_eq!( - find_real_thinking_end_tag( - "`` and \"\" and '' done\n\n" - ), - Some(54) - ); - } - - #[test] - fn test_tool_use_immediately_after_thinking_filters_end_tag_and_closes_thinking_block() { - let mut ctx = StreamContext::new_with_thinking("test-model", 1, true); - let _initial_events = ctx.generate_initial_events(); - - let mut all_events = Vec::new(); - - // thinking 内容以 `` 结尾,但后面没有 `\n\n`(模拟紧跟 tool_use 的场景) - all_events.extend(ctx.process_assistant_response("abc")); - - let tool_events = ctx.process_tool_use(&crate::kiro::model::events::ToolUseEvent { - name: "Write".to_string(), - tool_use_id: "tool_1".to_string(), - input: "{}".to_string(), - stop: false, - }); - all_events.extend(tool_events); - - all_events.extend(ctx.generate_final_events()); - - // 不应把 `` 当作 thinking 内容输出 - assert!( - all_events.iter().all(|e| { - !(e.event == "content_block_delta" - && e.data["delta"]["type"] == "thinking_delta" - && e.data["delta"]["thinking"] == "") - }), - "`` should be filtered from output" - ); - - // thinking block 必须在 tool_use block 之前关闭 - let thinking_index = ctx - .thinking_block_index - .expect("thinking block index should exist"); - let pos_thinking_stop = all_events.iter().position(|e| { - e.event == "content_block_stop" - && e.data["index"].as_i64() == Some(thinking_index as i64) - }); - let pos_tool_start = all_events.iter().position(|e| { - e.event == "content_block_start" && e.data["content_block"]["type"] == "tool_use" - }); - assert!( - pos_thinking_stop.is_some(), - "thinking block should be stopped" - ); - assert!(pos_tool_start.is_some(), "tool_use block should be started"); - assert!( - pos_thinking_stop.unwrap() < pos_tool_start.unwrap(), - "thinking block should stop before tool_use block starts" - ); - } - - #[test] - fn test_final_flush_filters_standalone_thinking_end_tag() { - let mut ctx = StreamContext::new_with_thinking("test-model", 1, true); - let _initial_events = ctx.generate_initial_events(); - - let mut all_events = Vec::new(); - all_events.extend(ctx.process_assistant_response("abc")); - all_events.extend(ctx.generate_final_events()); - - assert!( - all_events.iter().all(|e| { - !(e.event == "content_block_delta" - && e.data["delta"]["type"] == "thinking_delta" - && e.data["delta"]["thinking"] == "") - }), - "`` should be filtered during final flush" - ); - } -} diff --git a/src/anthropic/types.rs b/src/anthropic/types.rs deleted file mode 100644 index 37f52ab6e85fbea3e54377bc3afd9d5953cb1a96..0000000000000000000000000000000000000000 --- a/src/anthropic/types.rs +++ /dev/null @@ -1,270 +0,0 @@ -//! Anthropic API 类型定义 - -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -// === 错误响应 === - -/// API 错误响应 -#[derive(Debug, Serialize)] -pub struct ErrorResponse { - pub error: ErrorDetail, -} - -/// 错误详情 -#[derive(Debug, Serialize)] -pub struct ErrorDetail { - #[serde(rename = "type")] - pub error_type: String, - pub message: String, -} - -impl ErrorResponse { - /// 创建新的错误响应 - pub fn new(error_type: impl Into, message: impl Into) -> Self { - Self { - error: ErrorDetail { - error_type: error_type.into(), - message: message.into(), - }, - } - } - - /// 创建认证错误响应 - pub fn authentication_error() -> Self { - Self::new("authentication_error", "Invalid API key") - } -} - -// === Models 端点类型 === - -/// 模型信息 -#[derive(Debug, Serialize)] -pub struct Model { - pub id: String, - pub object: String, - pub created: i64, - pub owned_by: String, - pub display_name: String, - #[serde(rename = "type")] - pub model_type: String, - pub max_tokens: i32, -} - -/// 模型列表响应 -#[derive(Debug, Serialize)] -pub struct ModelsResponse { - pub object: String, - pub data: Vec, -} - -// === Messages 端点类型 === - -/// 最大思考预算 tokens -const MAX_BUDGET_TOKENS: i32 = 24576; - -/// Thinking 配置 -#[derive(Debug, Deserialize, Clone)] -pub struct Thinking { - #[serde(rename = "type")] - pub thinking_type: String, - #[serde( - default = "default_budget_tokens", - deserialize_with = "deserialize_budget_tokens" - )] - pub budget_tokens: i32, -} - -fn default_budget_tokens() -> i32 { - 20000 -} -fn deserialize_budget_tokens<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let value = i32::deserialize(deserializer)?; - Ok(value.min(MAX_BUDGET_TOKENS)) -} - -/// Claude Code 请求中的 metadata -#[derive(Debug, Clone, Deserialize)] -pub struct Metadata { - /// 用户 ID,格式如: user_xxx_account__session_0b4445e1-f5be-49e1-87ce-62bbc28ad705 - pub user_id: Option, -} - -/// Messages 请求体 -#[derive(Debug, Deserialize)] -pub struct MessagesRequest { - pub model: String, - pub max_tokens: i32, - pub messages: Vec, - #[serde(default)] - pub stream: bool, - #[serde(default, deserialize_with = "deserialize_system")] - pub system: Option>, - pub tools: Option>, - pub tool_choice: Option, - pub thinking: Option, - /// Claude Code 请求中的 metadata,包含 session 信息 - pub metadata: Option, -} - -/// 反序列化 system 字段,支持字符串或数组格式 -fn deserialize_system<'de, D>(deserializer: D) -> Result>, D::Error> -where - D: serde::Deserializer<'de>, -{ - use serde::de::Error; - - // 创建一个 visitor 来处理 string 或 array - struct SystemVisitor; - - impl<'de> serde::de::Visitor<'de> for SystemVisitor { - type Value = Option>; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a string or an array of system messages") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - Ok(Some(vec![SystemMessage { - text: value.to_string(), - }])) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut messages = Vec::new(); - while let Some(msg) = seq.next_element()? { - messages.push(msg); - } - Ok(if messages.is_empty() { None } else { Some(messages) }) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_some(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - serde::de::Deserialize::deserialize(deserializer) - } - } - - deserializer.deserialize_any(SystemVisitor) -} - -/// 消息 -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Message { - pub role: String, - /// 可以是 string 或 ContentBlock 数组 - pub content: serde_json::Value, -} - -/// 系统消息 -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct SystemMessage { - pub text: String, -} - -/// 工具定义 -/// -/// 支持两种格式: -/// 1. 普通工具:{ name, description, input_schema } -/// 2. WebSearch 工具:{ type: "web_search_20250305", name: "web_search", max_uses: 8 } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Tool { - /// 工具类型,如 "web_search_20250305"(可选,仅 WebSearch 工具) - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - pub tool_type: Option, - /// 工具名称 - #[serde(default)] - pub name: String, - /// 工具描述(普通工具必需,WebSearch 工具可选) - #[serde(default)] - pub description: String, - /// 输入参数 schema(普通工具必需,WebSearch 工具无此字段) - #[serde(default)] - pub input_schema: HashMap, - /// 最大使用次数(仅 WebSearch 工具) - #[serde(skip_serializing_if = "Option::is_none")] - pub max_uses: Option, -} - -impl Tool { - /// 检查是否为 WebSearch 工具 - pub fn is_web_search(&self) -> bool { - self.tool_type - .as_ref() - .is_some_and(|t| t.starts_with("web_search")) - } -} - -/// 内容块 -#[derive(Debug, Deserialize, Serialize)] -pub struct ContentBlock { - #[serde(rename = "type")] - pub block_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub thinking: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_use_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub input: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub source: Option, -} - -/// 图片数据源 -#[derive(Debug, Deserialize, Serialize)] -pub struct ImageSource { - #[serde(rename = "type")] - pub source_type: String, - pub media_type: String, - pub data: String, -} - -// === Count Tokens 端点类型 === - -/// Token 计数请求 -#[derive(Debug, Serialize, Deserialize)] -pub struct CountTokensRequest { - pub model: String, - pub messages: Vec, - #[serde( - default, - skip_serializing_if = "Option::is_none", - deserialize_with = "deserialize_system" - )] - pub system: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, -} - -/// Token 计数响应 -#[derive(Debug, Serialize, Deserialize)] -pub struct CountTokensResponse { - pub input_tokens: i32, -} diff --git a/src/anthropic/websearch.rs b/src/anthropic/websearch.rs deleted file mode 100644 index 6977492f50363921bf82ba6d22e0114ac7e95595..0000000000000000000000000000000000000000 --- a/src/anthropic/websearch.rs +++ /dev/null @@ -1,726 +0,0 @@ -//! WebSearch 工具处理模块 -//! -//! 实现 Anthropic WebSearch 请求到 Kiro MCP 的转换和响应生成 - -use std::convert::Infallible; - -use axum::{ - body::Body, - http::{StatusCode, header}, - response::{IntoResponse, Json, Response}, -}; -use bytes::Bytes; -use futures::{Stream, stream}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use uuid::Uuid; - -use super::stream::SseEvent; -use super::types::{ErrorResponse, MessagesRequest}; - -/// MCP 请求 -#[derive(Debug, Serialize)] -pub struct McpRequest { - pub id: String, - pub jsonrpc: String, - pub method: String, - pub params: McpParams, -} - -/// MCP 请求参数 -#[derive(Debug, Serialize)] -pub struct McpParams { - pub name: String, - pub arguments: McpArguments, -} - -/// MCP 参数 -#[derive(Debug, Serialize)] -pub struct McpArguments { - pub query: String, -} - -/// MCP 响应 -#[derive(Debug, Deserialize)] -pub struct McpResponse { - pub error: Option, - pub id: String, - pub jsonrpc: String, - pub result: Option, -} - -/// MCP 错误 -#[derive(Debug, Deserialize)] -pub struct McpError { - pub code: Option, - pub message: Option, -} - -/// MCP 结果 -#[derive(Debug, Deserialize)] -pub struct McpResult { - pub content: Vec, - #[serde(rename = "isError")] - pub is_error: bool, -} - -/// MCP 内容 -#[derive(Debug, Deserialize)] -pub struct McpContent { - #[serde(rename = "type")] - pub content_type: String, - pub text: String, -} - -/// WebSearch 搜索结果 -#[derive(Debug, Deserialize)] -pub struct WebSearchResults { - pub results: Vec, - #[serde(rename = "totalResults")] - pub total_results: Option, - pub query: Option, - pub error: Option, -} - -/// 单个搜索结果 -#[derive(Debug, Deserialize, Clone)] -pub struct WebSearchResult { - pub title: String, - pub url: String, - pub snippet: Option, - #[serde(rename = "publishedDate")] - pub published_date: Option, - pub id: Option, - pub domain: Option, - #[serde(rename = "maxVerbatimWordLimit")] - pub max_verbatim_word_limit: Option, - #[serde(rename = "publicDomain")] - pub public_domain: Option, -} - -/// 检查请求是否为纯 WebSearch 请求 -/// -/// 条件:tools 有且只有一个,且 name 为 web_search -pub fn has_web_search_tool(req: &MessagesRequest) -> bool { - req.tools.as_ref().is_some_and(|tools| { - tools.len() == 1 && tools.first().is_some_and(|t| t.name == "web_search") - }) -} - -/// 从消息中提取搜索查询 -/// -/// 读取 messages 的第一条消息的第一个内容块 -/// 并去除 "Perform a web search for the query: " 前缀 -pub fn extract_search_query(req: &MessagesRequest) -> Option { - // 获取第一条消息 - let first_msg = req.messages.first()?; - - // 提取文本内容 - let text = match &first_msg.content { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Array(arr) => { - // 获取第一个内容块 - let first_block = arr.first()?; - if first_block.get("type")?.as_str()? == "text" { - first_block.get("text")?.as_str()?.to_string() - } else { - return None; - } - } - _ => return None, - }; - - // 去除前缀 "Perform a web search for the query: " - const PREFIX: &str = "Perform a web search for the query: "; - let query = if text.starts_with(PREFIX) { - text[PREFIX.len()..].to_string() - } else { - text - }; - - if query.is_empty() { - None - } else { - Some(query) - } -} - -/// 生成22位大小写字母和数字的随机字符串 -fn generate_random_id_22() -> String { - const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; - (0..22) - .map(|_| { - let idx = fastrand::usize(..CHARSET.len()); - CHARSET[idx] as char - }) - .collect() -} - -/// 生成8位小写字母和数字的随机字符串 -fn generate_random_id_8() -> String { - const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; - (0..8) - .map(|_| { - let idx = fastrand::usize(..CHARSET.len()); - CHARSET[idx] as char - }) - .collect() -} - -/// 创建 MCP 请求 -/// -/// ID 格式: web_search_tooluse_{22位随机}_{毫秒时间戳}_{8位随机} -pub fn create_mcp_request(query: &str) -> (String, McpRequest) { - let random_22 = generate_random_id_22(); - let timestamp = chrono::Utc::now().timestamp_millis(); - let random_8 = generate_random_id_8(); - - let request_id = format!("web_search_tooluse_{}_{}_{}",random_22, timestamp, random_8); - - // tool_use_id 使用相同格式 - let tool_use_id = format!( - "srvtoolu_{}", - Uuid::new_v4().to_string().replace('-', "")[..32].to_string() - ); - - let request = McpRequest { - id: request_id, - jsonrpc: "2.0".to_string(), - method: "tools/call".to_string(), - params: McpParams { - name: "web_search".to_string(), - arguments: McpArguments { - query: query.to_string(), - }, - }, - }; - - (tool_use_id, request) -} - -/// 解析 MCP 响应中的搜索结果 -pub fn parse_search_results(mcp_response: &McpResponse) -> Option { - let result = mcp_response.result.as_ref()?; - let content = result.content.first()?; - - if content.content_type != "text" { - return None; - } - - serde_json::from_str(&content.text).ok() -} - -/// 生成 WebSearch SSE 响应流 -pub fn create_websearch_sse_stream( - model: String, - query: String, - tool_use_id: String, - search_results: Option, - input_tokens: i32, -) -> impl Stream> { - let events = generate_websearch_events(&model, &query, &tool_use_id, search_results, input_tokens); - - stream::iter( - events - .into_iter() - .map(|e| Ok(Bytes::from(e.to_sse_string()))), - ) -} - -/// 生成 WebSearch SSE 事件序列 -fn generate_websearch_events( - model: &str, - query: &str, - tool_use_id: &str, - search_results: Option, - input_tokens: i32, -) -> Vec { - let mut events = Vec::new(); - let message_id = format!("msg_{}", Uuid::new_v4().to_string().replace('-', "")[..24].to_string()); - - // 1. message_start - events.push(SseEvent::new( - "message_start", - json!({ - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "model": model, - "content": [], - "stop_reason": null, - "stop_sequence": null, - "usage": { - "input_tokens": input_tokens, - "output_tokens": 0, - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0 - } - } - }), - )); - - // 2. content_block_start (server_tool_use) - events.push(SseEvent::new( - "content_block_start", - json!({ - "type": "content_block_start", - "index": 0, - "content_block": { - "id": tool_use_id, - "type": "server_tool_use", - "name": "web_search", - "input": {} - } - }), - )); - - // 3. content_block_delta (input_json_delta) - let input_json = json!({"query": query}); - events.push(SseEvent::new( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": 0, - "delta": { - "type": "input_json_delta", - "partial_json": serde_json::to_string(&input_json).unwrap_or_default() - } - }), - )); - - // 4. content_block_stop (server_tool_use) - events.push(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": 0 - }), - )); - - // 5. content_block_start (web_search_tool_result) - let search_content = if let Some(ref results) = search_results { - results - .results - .iter() - .map(|r| { - json!({ - "type": "web_search_result", - "title": r.title, - "url": r.url, - "encrypted_content": r.snippet.clone().unwrap_or_default(), - "page_age": null - }) - }) - .collect::>() - } else { - vec![] - }; - - events.push(SseEvent::new( - "content_block_start", - json!({ - "type": "content_block_start", - "index": 1, - "content_block": { - "type": "web_search_tool_result", - "tool_use_id": tool_use_id, - "content": search_content - } - }), - )); - - // 6. content_block_stop (web_search_tool_result) - events.push(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": 1 - }), - )); - - // 7. content_block_start (text) - events.push(SseEvent::new( - "content_block_start", - json!({ - "type": "content_block_start", - "index": 2, - "content_block": { - "type": "text", - "text": "" - } - }), - )); - - // 8. content_block_delta (text_delta) - 生成搜索结果摘要 - let summary = generate_search_summary(query, &search_results); - - // 分块发送文本 - let chunk_size = 100; - for chunk in summary.chars().collect::>().chunks(chunk_size) { - let text: String = chunk.iter().collect(); - events.push(SseEvent::new( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": 2, - "delta": { - "type": "text_delta", - "text": text - } - }), - )); - } - - // 9. content_block_stop (text) - events.push(SseEvent::new( - "content_block_stop", - json!({ - "type": "content_block_stop", - "index": 2 - }), - )); - - // 10. message_delta - let output_tokens = (summary.len() as i32 + 3) / 4; // 简单估算 - events.push(SseEvent::new( - "message_delta", - json!({ - "type": "message_delta", - "delta": { - "stop_reason": "end_turn", - "stop_sequence": null - }, - "usage": { - "output_tokens": output_tokens - } - }), - )); - - // 11. message_stop - events.push(SseEvent::new( - "message_stop", - json!({ - "type": "message_stop" - }), - )); - - events -} - -/// 生成搜索结果摘要 -fn generate_search_summary(query: &str, results: &Option) -> String { - let mut summary = format!("Here are the search results for \"{}\":\n\n", query); - - if let Some(results) = results { - for (i, result) in results.results.iter().enumerate() { - summary.push_str(&format!("{}. **{}**\n", i + 1, result.title)); - if let Some(ref snippet) = result.snippet { - // 截断过长的摘要 - let truncated = if snippet.len() > 200 { - format!("{}...", &snippet[..200]) - } else { - snippet.clone() - }; - summary.push_str(&format!(" {}\n", truncated)); - } - summary.push_str(&format!(" Source: {}\n\n", result.url)); - } - } else { - summary.push_str("No results found.\n"); - } - - summary.push_str("\nPlease note that these are web search results and may not be fully accurate or up-to-date."); - - summary -} - -/// 处理 WebSearch 请求 -pub async fn handle_websearch_request( - provider: std::sync::Arc, - payload: &MessagesRequest, - input_tokens: i32, -) -> Response { - // 1. 提取搜索查询 - let query = match extract_search_query(payload) { - Some(q) => q, - None => { - return ( - StatusCode::BAD_REQUEST, - Json(ErrorResponse::new( - "invalid_request_error", - "无法从消息中提取搜索查询", - )), - ) - .into_response(); - } - }; - - tracing::info!(query = %query, "处理 WebSearch 请求"); - - // 2. 创建 MCP 请求 - let (tool_use_id, mcp_request) = create_mcp_request(&query); - - // 3. 调用 Kiro MCP API - let search_results = match call_mcp_api(&provider, &mcp_request).await { - Ok(response) => parse_search_results(&response), - Err(e) => { - tracing::warn!("MCP API 调用失败: {}", e); - None - } - }; - - // 4. 生成 SSE 响应 - let model = payload.model.clone(); - let stream = create_websearch_sse_stream( - model, - query, - tool_use_id, - search_results, - input_tokens, - ); - - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "text/event-stream") - .header(header::CACHE_CONTROL, "no-cache") - .header(header::CONNECTION, "keep-alive") - .body(Body::from_stream(stream)) - .unwrap() -} - -/// 调用 Kiro MCP API -async fn call_mcp_api( - provider: &crate::kiro::provider::KiroProvider, - request: &McpRequest, -) -> anyhow::Result { - let request_body = serde_json::to_string(request)?; - - tracing::debug!("MCP request: {}", request_body); - - let response = provider.call_mcp(&request_body).await?; - - let body = response.text().await?; - tracing::debug!("MCP response: {}", body); - - let mcp_response: McpResponse = serde_json::from_str(&body)?; - - if let Some(ref error) = mcp_response.error { - anyhow::bail!( - "MCP error: {} - {}", - error.code.unwrap_or(-1), - error.message.as_deref().unwrap_or("Unknown error") - ); - } - - Ok(mcp_response) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_has_web_search_tool_only_one() { - use crate::anthropic::types::{Message, Tool}; - - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![Message { - role: "user".to_string(), - content: serde_json::json!("test"), - }], - stream: true, - system: None, - tools: Some(vec![Tool { - tool_type: Some("web_search_20250305".to_string()), - name: "web_search".to_string(), - description: String::new(), - input_schema: Default::default(), - max_uses: Some(8), - }]), - tool_choice: None, - thinking: None, - metadata: None, - }; - - assert!(has_web_search_tool(&req)); - } - - #[test] - fn test_has_web_search_tool_multiple_tools() { - use crate::anthropic::types::{Message, Tool}; - - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![Message { - role: "user".to_string(), - content: serde_json::json!("test"), - }], - stream: true, - system: None, - tools: Some(vec![ - Tool { - tool_type: Some("web_search_20250305".to_string()), - name: "web_search".to_string(), - description: String::new(), - input_schema: Default::default(), - max_uses: Some(8), - }, - Tool { - tool_type: None, - name: "other_tool".to_string(), - description: "Other tool".to_string(), - input_schema: Default::default(), - max_uses: None, - }, - ]), - tool_choice: None, - thinking: None, - metadata: None, - }; - - // 多个工具时不应该被识别为纯 websearch 请求 - assert!(!has_web_search_tool(&req)); - } - - #[test] - fn test_extract_search_query_with_prefix() { - use crate::anthropic::types::Message; - - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![Message { - role: "user".to_string(), - content: serde_json::json!([{ - "type": "text", - "text": "Perform a web search for the query: rust latest version 2026" - }]), - }], - stream: true, - system: None, - tools: None, - tool_choice: None, - thinking: None, - metadata: None, - }; - - let query = extract_search_query(&req); - // 前缀应该被去除 - assert_eq!(query, Some("rust latest version 2026".to_string())); - } - - #[test] - fn test_extract_search_query_plain_text() { - use crate::anthropic::types::Message; - - let req = MessagesRequest { - model: "claude-sonnet-4".to_string(), - max_tokens: 1024, - messages: vec![Message { - role: "user".to_string(), - content: serde_json::json!("What is the weather today?"), - }], - stream: true, - system: None, - tools: None, - tool_choice: None, - thinking: None, - metadata: None, - }; - - let query = extract_search_query(&req); - assert_eq!(query, Some("What is the weather today?".to_string())); - } - - #[test] - fn test_create_mcp_request() { - let (tool_use_id, request) = create_mcp_request("test query"); - - assert!(tool_use_id.starts_with("srvtoolu_")); - assert_eq!(request.jsonrpc, "2.0"); - assert_eq!(request.method, "tools/call"); - assert_eq!(request.params.name, "web_search"); - assert_eq!(request.params.arguments.query, "test query"); - - // 验证 ID 格式: web_search_tooluse_{22位}_{时间戳}_{8位} - assert!(request.id.starts_with("web_search_tooluse_")); - } - - #[test] - fn test_mcp_request_id_format() { - let (_, request) = create_mcp_request("test"); - - // 格式: web_search_tooluse_{22位}_{毫秒时间戳}_{8位} - let id = &request.id; - assert!(id.starts_with("web_search_tooluse_")); - - let suffix = &id["web_search_tooluse_".len()..]; - let parts: Vec<&str> = suffix.split('_').collect(); - assert_eq!(parts.len(), 3, "应该有3个部分: 22位随机_时间戳_8位随机"); - - // 第一部分: 22位大小写字母和数字 - assert_eq!(parts[0].len(), 22); - assert!(parts[0].chars().all(|c| c.is_ascii_alphanumeric())); - - // 第二部分: 毫秒时间戳 - assert!(parts[1].parse::().is_ok()); - - // 第三部分: 8位小写字母和数字 - assert_eq!(parts[2].len(), 8); - assert!(parts[2] - .chars() - .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit())); - } - - #[test] - fn test_parse_search_results() { - let response = McpResponse { - error: None, - id: "test_id".to_string(), - jsonrpc: "2.0".to_string(), - result: Some(McpResult { - content: vec![McpContent { - content_type: "text".to_string(), - text: r#"{"results":[{"title":"Test","url":"https://example.com","snippet":"Test snippet"}],"totalResults":1}"#.to_string(), - }], - is_error: false, - }), - }; - - let results = parse_search_results(&response); - assert!(results.is_some()); - let results = results.unwrap(); - assert_eq!(results.results.len(), 1); - assert_eq!(results.results[0].title, "Test"); - } - - #[test] - fn test_generate_search_summary() { - let results = WebSearchResults { - results: vec![WebSearchResult { - title: "Test Result".to_string(), - url: "https://example.com".to_string(), - snippet: Some("This is a test snippet".to_string()), - published_date: None, - id: None, - domain: None, - max_verbatim_word_limit: None, - public_domain: None, - }], - total_results: Some(1), - query: Some("test".to_string()), - error: None, - }; - - let summary = generate_search_summary("test", &Some(results)); - - assert!(summary.contains("Test Result")); - assert!(summary.contains("https://example.com")); - assert!(summary.contains("This is a test snippet")); - } -} diff --git a/src/common/auth.rs b/src/common/auth.rs deleted file mode 100644 index 74f4b173494973449a6a1a95c0d353a4b57c8dcb..0000000000000000000000000000000000000000 --- a/src/common/auth.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! 公共认证工具函数 - -use axum::{ - body::Body, - http::{Request, header}, -}; -use subtle::ConstantTimeEq; - -/// 从请求中提取 API Key -/// -/// 支持两种认证方式: -/// - `x-api-key` header -/// - `Authorization: Bearer ` header -pub fn extract_api_key(request: &Request) -> Option { - // 优先检查 x-api-key - if let Some(key) = request - .headers() - .get("x-api-key") - .and_then(|v| v.to_str().ok()) - { - return Some(key.to_string()); - } - - // 其次检查 Authorization: Bearer - request - .headers() - .get(header::AUTHORIZATION) - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")) - .map(|s| s.to_string()) -} - -/// 常量时间字符串比较,防止时序攻击 -/// -/// 无论字符串内容如何,比较所需的时间都是恒定的, -/// 这可以防止攻击者通过测量响应时间来猜测 API Key。 -/// -/// 使用经过安全审计的 `subtle` crate 实现 -pub fn constant_time_eq(a: &str, b: &str) -> bool { - a.as_bytes().ct_eq(b.as_bytes()).into() -} diff --git a/src/common/mod.rs b/src/common/mod.rs deleted file mode 100644 index 56c77f4fc2b04972caf271e3e1390ada9ce47dbd..0000000000000000000000000000000000000000 --- a/src/common/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! 公共工具模块 - -pub mod auth; diff --git a/src/debug.rs b/src/debug.rs deleted file mode 100644 index 206ea9c518db082c3aaee7a441ab97637f0cf689..0000000000000000000000000000000000000000 --- a/src/debug.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! 调试工具模块 -//! -//! 提供 hex 打印和 CRC 调试等功能 - -use crate::kiro::model::events::Event; -use std::io::Write; - -/// 打印 hex 数据 (类似 xxd 格式) -pub fn print_hex(data: &[u8]) { - for (i, chunk) in data.chunks(16).enumerate() { - // 打印偏移 - print!("{:08x}: ", i * 16); - - // 打印 hex - for (j, byte) in chunk.iter().enumerate() { - if j == 8 { - print!(" "); - } - print!("{:02x} ", byte); - } - - // 补齐空格 - let padding = 16 - chunk.len(); - for j in 0..padding { - if chunk.len() + j == 8 { - print!(" "); - } - print!(" "); - } - - // 打印 ASCII - print!(" |"); - for byte in chunk { - if *byte >= 0x20 && *byte < 0x7f { - print!("{}", *byte as char); - } else { - print!("."); - } - } - println!("|"); - } - std::io::stdout().flush().ok(); -} - -/// 调试 CRC 计算 - 分析 AWS Event Stream 帧的 CRC -pub fn debug_crc(data: &[u8]) { - if data.len() < 12 { - println!("[CRC 调试] 数据不足 12 字节"); - return; - } - - use crc::{Crc, CRC_32_BZIP2, CRC_32_ISO_HDLC, CRC_32_ISCSI, CRC_32_JAMCRC}; - - let total_length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]); - let header_length = u32::from_be_bytes([data[4], data[5], data[6], data[7]]); - let prelude_crc = u32::from_be_bytes([data[8], data[9], data[10], data[11]]); - - println!("\n[CRC 调试]"); - println!(" total_length: {} (0x{:08x})", total_length, total_length); - println!( - " header_length: {} (0x{:08x})", - header_length, header_length - ); - println!(" prelude_crc (from data): 0x{:08x}", prelude_crc); - - // 测试各种 CRC32 变种 - let crc32c: Crc = Crc::::new(&CRC_32_ISCSI); - let crc32_iso: Crc = Crc::::new(&CRC_32_ISO_HDLC); - let crc32_bzip2: Crc = Crc::::new(&CRC_32_BZIP2); - let crc32_jamcrc: Crc = Crc::::new(&CRC_32_JAMCRC); - - let prelude = &data[..8]; - - println!(" CRC32C (ISCSI): 0x{:08x}", crc32c.checksum(prelude)); - println!( - " CRC32 ISO-HDLC: 0x{:08x} {}", - crc32_iso.checksum(prelude), - if crc32_iso.checksum(prelude) == prelude_crc { - "<-- MATCH" - } else { - "" - } - ); - println!(" CRC32 BZIP2: 0x{:08x}", crc32_bzip2.checksum(prelude)); - println!( - " CRC32 JAMCRC: 0x{:08x}", - crc32_jamcrc.checksum(prelude) - ); - - // 打印前 8 字节 - print!(" 前 8 字节: "); - for byte in prelude { - print!("{:02x} ", byte); - } - println!(); -} - -/// 打印帧摘要信息 -pub fn print_frame_summary(data: &[u8]) { - if data.len() < 12 { - println!("[帧摘要] 数据不足"); - return; - } - - let total_length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize; - let header_length = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize; - - println!("\n[帧摘要]"); - println!(" 总长度: {} 字节", total_length); - println!(" 头部长度: {} 字节", header_length); - println!(" Payload 长度: {} 字节", total_length.saturating_sub(12 + header_length + 4)); - println!(" 数据可用: {} 字节", data.len()); - - if data.len() >= total_length { - println!(" 状态: 完整帧"); - } else { - println!( - " 状态: 不完整 (缺少 {} 字节)", - total_length - data.len() - ); - } -} - -/// 详细打印事件 (调试格式,包含事件类型和完整数据) -pub fn print_event_verbose(event: &Event) { - match event { - Event::AssistantResponse(e) => { - println!("\n[事件] AssistantResponse"); - println!(" content: {:?}", e.content()); - } - Event::ToolUse(e) => { - println!("\n[事件] ToolUse"); - println!(" name: {:?}", e.name()); - println!(" tool_use_id: {:?}", e.tool_use_id()); - println!(" input: {:?}", e.input()); - println!(" stop: {}", e.is_complete()); - } - Event::Metering(e) => { - println!("\n[事件] Metering"); - println!(" unit: {:?}", e.unit); - println!(" unit_plural: {:?}", e.unit_plural); - println!(" usage: {}", e.usage); - } - Event::ContextUsage(e) => { - println!("\n[事件] ContextUsage"); - println!(" context_usage_percentage: {}", e.context_usage_percentage); - } - Event::Unknown { event_type, payload } => { - println!("\n[事件] Unknown"); - println!(" event_type: {:?}", event_type); - println!(" payload ({} bytes):", payload.len()); - print_hex(payload); - } - Event::Error { - error_code, - error_message, - } => { - println!("\n[事件] Error"); - println!(" error_code: {:?}", error_code); - println!(" error_message: {:?}", error_message); - } - Event::Exception { - exception_type, - message, - } => { - println!("\n[事件] Exception"); - println!(" exception_type: {:?}", exception_type); - println!(" message: {:?}", message); - } - } -} - -/// 简洁打印事件 (用于正常输出) -pub fn print_event(event: &Event) { - match event { - Event::AssistantResponse(e) => { - // 实时打印助手响应,不换行 - print!("{}", e.content()); - std::io::stdout().flush().ok(); - } - Event::ToolUse(e) => { - println!("\n[工具调用] {} (id: {})", e.name(), e.tool_use_id()); - println!(" 输入: {}", e.input()); - if e.is_complete() { - println!(" [调用结束]"); - } - } - Event::Metering(e) => { - println!("\n[计费] {}", e); - } - Event::ContextUsage(e) => { - println!("\n[上下文使用率] {}", e); - } - Event::Unknown { event_type, .. } => { - println!("\n[未知事件] {}", event_type); - } - Event::Error { - error_code, - error_message, - } => { - println!("\n[错误] {}: {}", error_code, error_message); - } - Event::Exception { - exception_type, - message, - } => { - println!("\n[异常] {}: {}", exception_type, message); - } - } -} diff --git a/src/http_client.rs b/src/http_client.rs deleted file mode 100644 index f68ff55af26a4c10b9833960e01a256cf9f764c8..0000000000000000000000000000000000000000 --- a/src/http_client.rs +++ /dev/null @@ -1,95 +0,0 @@ -//! HTTP Client 构建模块 -//! -//! 提供统一的 HTTP Client 构建功能,支持代理配置 - -use reqwest::{Client, Proxy}; -use std::time::Duration; - -/// 代理配置 -#[derive(Debug, Clone, Default)] -pub struct ProxyConfig { - /// 代理地址,支持 http/https/socks5 - pub url: String, - /// 代理认证用户名 - pub username: Option, - /// 代理认证密码 - pub password: Option, -} - -impl ProxyConfig { - /// 从 url 创建代理配置 - pub fn new(url: impl Into) -> Self { - Self { - url: url.into(), - username: None, - password: None, - } - } - - /// 设置认证信息 - pub fn with_auth(mut self, username: impl Into, password: impl Into) -> Self { - self.username = Some(username.into()); - self.password = Some(password.into()); - self - } -} - -/// 构建 HTTP Client -/// -/// # Arguments -/// * `proxy` - 可选的代理配置 -/// * `timeout_secs` - 超时时间(秒) -/// -/// # Returns -/// 配置好的 reqwest::Client -pub fn build_client(proxy: Option<&ProxyConfig>, timeout_secs: u64) -> anyhow::Result { - let mut builder = Client::builder().timeout(Duration::from_secs(timeout_secs)); - - if let Some(proxy_config) = proxy { - let mut proxy = Proxy::all(&proxy_config.url)?; - - // 设置代理认证 - if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) { - proxy = proxy.basic_auth(username, password); - } - - builder = builder.proxy(proxy); - tracing::debug!("HTTP Client 使用代理: {}", proxy_config.url); - } - - Ok(builder.build()?) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_proxy_config_new() { - let config = ProxyConfig::new("http://127.0.0.1:7890"); - assert_eq!(config.url, "http://127.0.0.1:7890"); - assert!(config.username.is_none()); - assert!(config.password.is_none()); - } - - #[test] - fn test_proxy_config_with_auth() { - let config = ProxyConfig::new("socks5://127.0.0.1:1080").with_auth("user", "pass"); - assert_eq!(config.url, "socks5://127.0.0.1:1080"); - assert_eq!(config.username, Some("user".to_string())); - assert_eq!(config.password, Some("pass".to_string())); - } - - #[test] - fn test_build_client_without_proxy() { - let client = build_client(None, 30); - assert!(client.is_ok()); - } - - #[test] - fn test_build_client_with_proxy() { - let config = ProxyConfig::new("http://127.0.0.1:7890"); - let client = build_client(Some(&config), 30); - assert!(client.is_ok()); - } -} diff --git a/src/kiro/machine_id.rs b/src/kiro/machine_id.rs deleted file mode 100644 index b48599dba5b074150040baca0839c589413d51bf..0000000000000000000000000000000000000000 --- a/src/kiro/machine_id.rs +++ /dev/null @@ -1,167 +0,0 @@ -//! 设备指纹生成器 -//! - -use sha2::{Digest, Sha256}; - -use crate::kiro::model::credentials::KiroCredentials; -use crate::model::config::Config; - -/// 标准化 machineId 格式 -/// -/// 支持以下格式: -/// - 64 字符十六进制字符串(直接返回) -/// - UUID 格式(如 "2582956e-cc88-4669-b546-07adbffcb894",移除连字符后补齐到 64 字符) -fn normalize_machine_id(machine_id: &str) -> Option { - let trimmed = machine_id.trim(); - - // 如果已经是 64 字符,直接返回 - if trimmed.len() == 64 && trimmed.chars().all(|c| c.is_ascii_hexdigit()) { - return Some(trimmed.to_string()); - } - - // 尝试解析 UUID 格式(移除连字符) - let without_dashes: String = trimmed.chars().filter(|c| *c != '-').collect(); - - // UUID 去掉连字符后是 32 字符 - if without_dashes.len() == 32 && without_dashes.chars().all(|c| c.is_ascii_hexdigit()) { - // 补齐到 64 字符(重复一次) - return Some(format!("{}{}", without_dashes, without_dashes)); - } - - // 无法识别的格式 - None -} - -/// 根据凭证信息生成唯一的 Machine ID -/// -/// 优先使用凭据级 machineId,其次使用 config.machineId,然后使用 refreshToken 生成 -pub fn generate_from_credentials(credentials: &KiroCredentials, config: &Config) -> Option { - // 如果配置了凭据级 machineId,优先使用 - if let Some(ref machine_id) = credentials.machine_id { - if let Some(normalized) = normalize_machine_id(machine_id) { - return Some(normalized); - } - } - - // 如果配置了全局 machineId,作为默认值 - if let Some(ref machine_id) = config.machine_id { - if let Some(normalized) = normalize_machine_id(machine_id) { - return Some(normalized); - } - } - - // 使用 refreshToken 生成 - if let Some(ref refresh_token) = credentials.refresh_token { - if !refresh_token.is_empty() { - return Some(sha256_hex(&format!("KotlinNativeAPI/{}", refresh_token))); - } - } - - // 没有有效的凭证 - None -} - -/// SHA256 哈希实现(返回十六进制字符串) -fn sha256_hex(input: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(input.as_bytes()); - let result = hasher.finalize(); - hex::encode(result) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sha256_hex() { - let result = sha256_hex("test"); - assert_eq!(result.len(), 64); - assert_eq!( - result, - "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08" - ); - } - - #[test] - fn test_generate_with_custom_machine_id() { - let credentials = KiroCredentials::default(); - let mut config = Config::default(); - config.machine_id = Some("a".repeat(64)); - - let result = generate_from_credentials(&credentials, &config); - assert_eq!(result, Some("a".repeat(64))); - } - - #[test] - fn test_generate_with_credential_machine_id_overrides_config() { - let mut credentials = KiroCredentials::default(); - credentials.machine_id = Some("b".repeat(64)); - - let mut config = Config::default(); - config.machine_id = Some("a".repeat(64)); - - let result = generate_from_credentials(&credentials, &config); - assert_eq!(result, Some("b".repeat(64))); - } - - #[test] - fn test_generate_with_refresh_token() { - let mut credentials = KiroCredentials::default(); - credentials.refresh_token = Some("test_refresh_token".to_string()); - let config = Config::default(); - - let result = generate_from_credentials(&credentials, &config); - assert!(result.is_some()); - assert_eq!(result.as_ref().unwrap().len(), 64); - } - - #[test] - fn test_generate_without_credentials() { - let credentials = KiroCredentials::default(); - let config = Config::default(); - - let result = generate_from_credentials(&credentials, &config); - assert!(result.is_none()); - } - - #[test] - fn test_normalize_uuid_format() { - // UUID 格式应该被转换为 64 字符 - let uuid = "2582956e-cc88-4669-b546-07adbffcb894"; - let result = normalize_machine_id(uuid); - assert!(result.is_some()); - let normalized = result.unwrap(); - assert_eq!(normalized.len(), 64); - // UUID 去掉连字符后重复一次 - assert_eq!(normalized, "2582956ecc884669b54607adbffcb8942582956ecc884669b54607adbffcb894"); - } - - #[test] - fn test_normalize_64_char_hex() { - // 64 字符十六进制应该直接返回 - let hex64 = "a".repeat(64); - let result = normalize_machine_id(&hex64); - assert_eq!(result, Some(hex64)); - } - - #[test] - fn test_normalize_invalid_format() { - // 无效格式应该返回 None - assert!(normalize_machine_id("invalid").is_none()); - assert!(normalize_machine_id("too-short").is_none()); - assert!(normalize_machine_id(&"g".repeat(64)).is_none()); // 非十六进制 - } - - #[test] - fn test_generate_with_uuid_machine_id() { - let mut credentials = KiroCredentials::default(); - credentials.machine_id = Some("2582956e-cc88-4669-b546-07adbffcb894".to_string()); - - let config = Config::default(); - - let result = generate_from_credentials(&credentials, &config); - assert!(result.is_some()); - assert_eq!(result.as_ref().unwrap().len(), 64); - } -} diff --git a/src/kiro/mod.rs b/src/kiro/mod.rs deleted file mode 100644 index 7a314bfeccefe951fcb27e14729508c7736574ea..0000000000000000000000000000000000000000 --- a/src/kiro/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Kiro API 客户端模块 - -pub mod machine_id; -pub mod model; -pub mod parser; -pub mod provider; -pub mod token_manager; diff --git a/src/kiro/model/common/mod.rs b/src/kiro/model/common/mod.rs deleted file mode 100644 index 8a414093361b72ab35ac0835432e35d24292f644..0000000000000000000000000000000000000000 --- a/src/kiro/model/common/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -//! 共享类型模块 -//! -//! 此模块已简化,移除了未使用的类型定义。 -//! 如果将来需要扩展,可以在此添加新的共享类型。 diff --git a/src/kiro/model/credentials.rs b/src/kiro/model/credentials.rs deleted file mode 100644 index 6c4c9dad6a1d5746531d0642d4460673a76f32f8..0000000000000000000000000000000000000000 --- a/src/kiro/model/credentials.rs +++ /dev/null @@ -1,462 +0,0 @@ -//! Kiro OAuth 凭证数据模型 -//! -//! 支持从 Kiro IDE 的凭证文件加载,使用 Social 认证方式 -//! 支持单凭据和多凭据配置格式 - -use serde::{Deserialize, Serialize}; -use std::fs; -use std::path::Path; - -/// Kiro OAuth 凭证 -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(rename_all = "camelCase")] -pub struct KiroCredentials { - /// 凭据唯一标识符(自增 ID) - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - - /// 访问令牌 - #[serde(skip_serializing_if = "Option::is_none")] - pub access_token: Option, - - /// 刷新令牌 - #[serde(skip_serializing_if = "Option::is_none")] - pub refresh_token: Option, - - /// Profile ARN - #[serde(skip_serializing_if = "Option::is_none")] - pub profile_arn: Option, - - /// 过期时间 (RFC3339 格式) - #[serde(skip_serializing_if = "Option::is_none")] - pub expires_at: Option, - - /// 认证方式 (social / idc / builder-id) - #[serde(skip_serializing_if = "Option::is_none")] - pub auth_method: Option, - - /// OIDC Client ID (IdC 认证需要) - #[serde(skip_serializing_if = "Option::is_none")] - pub client_id: Option, - - /// OIDC Client Secret (IdC 认证需要) - #[serde(skip_serializing_if = "Option::is_none")] - pub client_secret: Option, - - /// 凭据优先级(数字越小优先级越高,默认为 0) - #[serde(default)] - #[serde(skip_serializing_if = "is_zero")] - pub priority: u32, - - /// 凭据级 Region 配置(用于 OIDC token 刷新) - /// 未配置时回退到 config.json 的全局 region - #[serde(skip_serializing_if = "Option::is_none")] - pub region: Option, - - /// 凭据级 Machine ID 配置(可选) - /// 未配置时回退到 config.json 的 machineId;都未配置时由 refreshToken 派生 - #[serde(skip_serializing_if = "Option::is_none")] - pub machine_id: Option, -} - -/// 判断是否为零(用于跳过序列化) -fn is_zero(value: &u32) -> bool { - *value == 0 -} - -/// 凭据配置(支持单对象或数组格式) -/// -/// 自动识别配置文件格式: -/// - 单对象格式(旧格式,向后兼容) -/// - 数组格式(新格式,支持多凭据) -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum CredentialsConfig { - /// 单个凭据(旧格式) - Single(KiroCredentials), - /// 多凭据数组(新格式) - Multiple(Vec), -} - -impl CredentialsConfig { - /// 从文件加载凭据配置 - /// - /// - 如果文件不存在,返回空数组 - /// - 如果文件内容为空,返回空数组 - /// - 支持单对象或数组格式 - pub fn load>(path: P) -> anyhow::Result { - let path = path.as_ref(); - - // 文件不存在时返回空数组 - if !path.exists() { - return Ok(CredentialsConfig::Multiple(vec![])); - } - - let content = fs::read_to_string(path)?; - - // 文件为空时返回空数组 - if content.trim().is_empty() { - return Ok(CredentialsConfig::Multiple(vec![])); - } - - let config = serde_json::from_str(&content)?; - Ok(config) - } - - /// 转换为按优先级排序的凭据列表 - pub fn into_sorted_credentials(self) -> Vec { - match self { - CredentialsConfig::Single(cred) => vec![cred], - CredentialsConfig::Multiple(mut creds) => { - // 按优先级排序(数字越小优先级越高) - creds.sort_by_key(|c| c.priority); - creds - } - } - } - - /// 获取凭据数量 - pub fn len(&self) -> usize { - match self { - CredentialsConfig::Single(_) => 1, - CredentialsConfig::Multiple(creds) => creds.len(), - } - } - - /// 判断是否为空 - pub fn is_empty(&self) -> bool { - match self { - CredentialsConfig::Single(_) => false, - CredentialsConfig::Multiple(creds) => creds.is_empty(), - } - } - - /// 判断是否为多凭据格式(数组格式) - pub fn is_multiple(&self) -> bool { - matches!(self, CredentialsConfig::Multiple(_)) - } -} - -impl KiroCredentials { - /// 获取默认凭证文件路径 - pub fn default_credentials_path() -> &'static str { - "credentials.json" - } - - /// 从 JSON 字符串解析凭证 - pub fn from_json(json_string: &str) -> Result { - serde_json::from_str(json_string) - } - - /// 从文件加载凭证 - pub fn load>(path: P) -> anyhow::Result { - let content = fs::read_to_string(path.as_ref())?; - if content.is_empty() { - anyhow::bail!("凭证文件为空: {:?}", path.as_ref()); - } - let credentials = Self::from_json(&content)?; - Ok(credentials) - } - - /// 序列化为格式化的 JSON 字符串 - pub fn to_pretty_json(&self) -> Result { - serde_json::to_string_pretty(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_from_json() { - let json = r#"{ - "accessToken": "test_token", - "refreshToken": "test_refresh", - "profileArn": "arn:aws:test", - "expiresAt": "2024-01-01T00:00:00Z", - "authMethod": "social" - }"#; - - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.access_token, Some("test_token".to_string())); - assert_eq!(creds.refresh_token, Some("test_refresh".to_string())); - assert_eq!(creds.profile_arn, Some("arn:aws:test".to_string())); - assert_eq!(creds.expires_at, Some("2024-01-01T00:00:00Z".to_string())); - assert_eq!(creds.auth_method, Some("social".to_string())); - } - - #[test] - fn test_from_json_with_unknown_keys() { - let json = r#"{ - "accessToken": "test_token", - "unknownField": "should be ignored" - }"#; - - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.access_token, Some("test_token".to_string())); - } - - #[test] - fn test_to_json() { - let creds = KiroCredentials { - id: None, - access_token: Some("token".to_string()), - refresh_token: None, - profile_arn: None, - expires_at: None, - auth_method: Some("social".to_string()), - client_id: None, - client_secret: None, - priority: 0, - region: None, - machine_id: None, - }; - - let json = creds.to_pretty_json().unwrap(); - assert!(json.contains("accessToken")); - assert!(json.contains("authMethod")); - assert!(!json.contains("refreshToken")); - // priority 为 0 时不序列化 - assert!(!json.contains("priority")); - } - - #[test] - fn test_default_credentials_path() { - assert_eq!( - KiroCredentials::default_credentials_path(), - "credentials.json" - ); - } - - #[test] - fn test_priority_default() { - let json = r#"{"refreshToken": "test"}"#; - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.priority, 0); - } - - #[test] - fn test_priority_explicit() { - let json = r#"{"refreshToken": "test", "priority": 5}"#; - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.priority, 5); - } - - #[test] - fn test_credentials_config_single() { - let json = r#"{"refreshToken": "test", "expiresAt": "2025-12-31T00:00:00Z"}"#; - let config: CredentialsConfig = serde_json::from_str(json).unwrap(); - assert!(matches!(config, CredentialsConfig::Single(_))); - assert_eq!(config.len(), 1); - } - - #[test] - fn test_credentials_config_multiple() { - let json = r#"[ - {"refreshToken": "test1", "priority": 1}, - {"refreshToken": "test2", "priority": 0} - ]"#; - let config: CredentialsConfig = serde_json::from_str(json).unwrap(); - assert!(matches!(config, CredentialsConfig::Multiple(_))); - assert_eq!(config.len(), 2); - } - - #[test] - fn test_credentials_config_priority_sorting() { - let json = r#"[ - {"refreshToken": "t1", "priority": 2}, - {"refreshToken": "t2", "priority": 0}, - {"refreshToken": "t3", "priority": 1} - ]"#; - let config: CredentialsConfig = serde_json::from_str(json).unwrap(); - let list = config.into_sorted_credentials(); - - // 验证按优先级排序 - assert_eq!(list[0].refresh_token, Some("t2".to_string())); // priority 0 - assert_eq!(list[1].refresh_token, Some("t3".to_string())); // priority 1 - assert_eq!(list[2].refresh_token, Some("t1".to_string())); // priority 2 - } - - // ============ Region 字段测试 ============ - - #[test] - fn test_region_field_parsing() { - // 测试解析包含 region 字段的 JSON - let json = r#"{ - "refreshToken": "test_refresh", - "region": "us-east-1" - }"#; - - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.refresh_token, Some("test_refresh".to_string())); - assert_eq!(creds.region, Some("us-east-1".to_string())); - } - - #[test] - fn test_region_field_missing_backward_compat() { - // 测试向后兼容:不包含 region 字段的旧格式 JSON - let json = r#"{ - "refreshToken": "test_refresh", - "authMethod": "social" - }"#; - - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.refresh_token, Some("test_refresh".to_string())); - assert_eq!(creds.region, None); - } - - #[test] - fn test_region_field_serialization() { - // 测试序列化时正确输出 region 字段 - let creds = KiroCredentials { - id: None, - access_token: None, - refresh_token: Some("test".to_string()), - profile_arn: None, - expires_at: None, - auth_method: None, - client_id: None, - client_secret: None, - priority: 0, - region: Some("eu-west-1".to_string()), - machine_id: None, - }; - - let json = creds.to_pretty_json().unwrap(); - assert!(json.contains("region")); - assert!(json.contains("eu-west-1")); - } - - #[test] - fn test_region_field_none_not_serialized() { - // 测试 region 为 None 时不序列化 - let creds = KiroCredentials { - id: None, - access_token: None, - refresh_token: Some("test".to_string()), - profile_arn: None, - expires_at: None, - auth_method: None, - client_id: None, - client_secret: None, - priority: 0, - region: None, - machine_id: None, - }; - - let json = creds.to_pretty_json().unwrap(); - assert!(!json.contains("region")); - } - - // ============ MachineId 字段测试 ============ - - #[test] - fn test_machine_id_field_parsing() { - let machine_id = "a".repeat(64); - let json = format!( - r#"{{ - "refreshToken": "test_refresh", - "machineId": "{machine_id}" - }}"# - ); - - let creds = KiroCredentials::from_json(&json).unwrap(); - assert_eq!(creds.refresh_token, Some("test_refresh".to_string())); - assert_eq!(creds.machine_id, Some(machine_id)); - } - - #[test] - fn test_machine_id_field_serialization() { - let mut creds = KiroCredentials::default(); - creds.refresh_token = Some("test".to_string()); - creds.machine_id = Some("b".repeat(64)); - - let json = creds.to_pretty_json().unwrap(); - assert!(json.contains("machineId")); - } - - #[test] - fn test_machine_id_field_none_not_serialized() { - let mut creds = KiroCredentials::default(); - creds.refresh_token = Some("test".to_string()); - creds.machine_id = None; - - let json = creds.to_pretty_json().unwrap(); - assert!(!json.contains("machineId")); - } - - #[test] - fn test_multiple_credentials_with_different_regions() { - // 测试多凭据场景下不同凭据使用各自的 region - let json = r#"[ - {"refreshToken": "t1", "region": "us-east-1"}, - {"refreshToken": "t2", "region": "eu-west-1"}, - {"refreshToken": "t3"} - ]"#; - - let config: CredentialsConfig = serde_json::from_str(json).unwrap(); - let list = config.into_sorted_credentials(); - - assert_eq!(list[0].region, Some("us-east-1".to_string())); - assert_eq!(list[1].region, Some("eu-west-1".to_string())); - assert_eq!(list[2].region, None); - } - - #[test] - fn test_region_field_with_all_fields() { - // 测试包含所有字段的完整 JSON - let json = r#"{ - "id": 1, - "accessToken": "access", - "refreshToken": "refresh", - "profileArn": "arn:aws:test", - "expiresAt": "2025-12-31T00:00:00Z", - "authMethod": "idc", - "clientId": "client123", - "clientSecret": "secret456", - "priority": 5, - "region": "ap-northeast-1" - }"#; - - let creds = KiroCredentials::from_json(json).unwrap(); - assert_eq!(creds.id, Some(1)); - assert_eq!(creds.access_token, Some("access".to_string())); - assert_eq!(creds.refresh_token, Some("refresh".to_string())); - assert_eq!(creds.profile_arn, Some("arn:aws:test".to_string())); - assert_eq!(creds.expires_at, Some("2025-12-31T00:00:00Z".to_string())); - assert_eq!(creds.auth_method, Some("idc".to_string())); - assert_eq!(creds.client_id, Some("client123".to_string())); - assert_eq!(creds.client_secret, Some("secret456".to_string())); - assert_eq!(creds.priority, 5); - assert_eq!(creds.region, Some("ap-northeast-1".to_string())); - } - - #[test] - fn test_region_roundtrip() { - // 测试序列化和反序列化的往返一致性 - let original = KiroCredentials { - id: Some(42), - access_token: Some("token".to_string()), - refresh_token: Some("refresh".to_string()), - profile_arn: None, - expires_at: None, - auth_method: Some("social".to_string()), - client_id: None, - client_secret: None, - priority: 3, - region: Some("us-west-2".to_string()), - machine_id: Some("c".repeat(64)), - }; - - let json = original.to_pretty_json().unwrap(); - let parsed = KiroCredentials::from_json(&json).unwrap(); - - assert_eq!(parsed.id, original.id); - assert_eq!(parsed.access_token, original.access_token); - assert_eq!(parsed.refresh_token, original.refresh_token); - assert_eq!(parsed.priority, original.priority); - assert_eq!(parsed.region, original.region); - assert_eq!(parsed.machine_id, original.machine_id); - } -} diff --git a/src/kiro/model/events/assistant.rs b/src/kiro/model/events/assistant.rs deleted file mode 100644 index 68f1114811324fccac837819f5c2ab260d6f2458..0000000000000000000000000000000000000000 --- a/src/kiro/model/events/assistant.rs +++ /dev/null @@ -1,115 +0,0 @@ -//! 助手响应事件 -//! -//! 处理 assistantResponseEvent 类型的事件 - -use serde::{Deserialize, Serialize}; - -use crate::kiro::parser::error::ParseResult; -use crate::kiro::parser::frame::Frame; - -use super::base::EventPayload; - -/// 助手响应事件 -/// -/// 包含 AI 助手的流式响应内容 -/// -/// # 设计说明 -/// -/// 此结构体只保留实际使用的 `content` 字段,其他 API 返回的字段 -/// 通过 `#[serde(flatten)]` 捕获到 `extra` 中,确保反序列化不会失败。 -/// -/// # 示例 -/// -/// ```rust -/// use kiro_rs::kiro::model::events::AssistantResponseEvent; -/// -/// let json = r#"{"content":"Hello, world!"}"#; -/// let event: AssistantResponseEvent = serde_json::from_str(json).unwrap(); -/// assert_eq!(event.content, "Hello, world!"); -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct AssistantResponseEvent { - /// 响应内容片段 - #[serde(default)] - pub content: String, - - /// 捕获其他未使用的字段,确保反序列化兼容性 - #[serde(flatten)] - #[serde(skip_serializing)] - #[allow(dead_code)] - extra: serde_json::Value, -} - -impl EventPayload for AssistantResponseEvent { - fn from_frame(frame: &Frame) -> ParseResult { - frame.payload_as_json() - } -} - -impl Default for AssistantResponseEvent { - fn default() -> Self { - Self { - content: String::new(), - extra: serde_json::Value::Null, - } - } -} - -impl std::fmt::Display for AssistantResponseEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.content) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_deserialize_simple() { - let json = r#"{"content":"Hello, world!"}"#; - let event: AssistantResponseEvent = serde_json::from_str(json).unwrap(); - assert_eq!(event.content, "Hello, world!"); - } - - #[test] - fn test_deserialize_with_extra_fields() { - // 确保包含额外字段时反序列化不会失败 - let json = r#"{ - "content": "Done", - "conversationId": "conv-123", - "messageId": "msg-456", - "messageStatus": "COMPLETED", - "followupPrompt": { - "content": "Would you like me to explain further?", - "userIntent": "EXPLAIN_CODE_SELECTION" - } - }"#; - let event: AssistantResponseEvent = serde_json::from_str(json).unwrap(); - assert_eq!(event.content, "Done"); - } - - #[test] - fn test_serialize_minimal() { - let event = AssistantResponseEvent::default(); - let event = AssistantResponseEvent { - content: "Test".to_string(), - ..event - }; - - let json = serde_json::to_string(&event).unwrap(); - assert!(json.contains("\"content\":\"Test\"")); - // extra 字段不应该被序列化 - assert!(!json.contains("extra")); - } - - #[test] - fn test_display() { - let event = AssistantResponseEvent { - content: "test".to_string(), - ..Default::default() - }; - assert_eq!(format!("{}", event), "test"); - } -} diff --git a/src/kiro/model/events/base.rs b/src/kiro/model/events/base.rs deleted file mode 100644 index 28d2e74f0c21188dfe037f760229665c965f39a3..0000000000000000000000000000000000000000 --- a/src/kiro/model/events/base.rs +++ /dev/null @@ -1,186 +0,0 @@ -//! 事件基础定义 -//! -//! 定义事件类型枚举、trait 和统一事件结构 - -use crate::kiro::parser::error::{ParseError, ParseResult}; -use crate::kiro::parser::frame::Frame; - -/// 事件类型枚举 -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum EventType { - /// 助手响应事件 - AssistantResponse, - /// 工具使用事件 - ToolUse, - /// 计费事件 - Metering, - /// 上下文使用率事件 - ContextUsage, - /// 未知事件类型 - Unknown, -} - -impl EventType { - /// 从事件类型字符串解析 - pub fn from_str(s: &str) -> Self { - match s { - "assistantResponseEvent" => Self::AssistantResponse, - "toolUseEvent" => Self::ToolUse, - "meteringEvent" => Self::Metering, - "contextUsageEvent" => Self::ContextUsage, - _ => Self::Unknown, - } - } - - /// 转换为事件类型字符串 - pub fn as_str(&self) -> &'static str { - match self { - Self::AssistantResponse => "assistantResponseEvent", - Self::ToolUse => "toolUseEvent", - Self::Metering => "meteringEvent", - Self::ContextUsage => "contextUsageEvent", - Self::Unknown => "unknown", - } - } -} - -impl std::fmt::Display for EventType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -/// 事件 payload trait -/// -/// 所有具体事件类型都需要实现此 trait -pub trait EventPayload: Sized { - /// 从帧解析事件负载 - fn from_frame(frame: &Frame) -> ParseResult; -} - -/// 统一事件枚举 -/// -/// 封装所有可能的事件类型 -#[derive(Debug, Clone)] -pub enum Event { - /// 助手响应 - AssistantResponse(super::AssistantResponseEvent), - /// 工具使用 - ToolUse(super::ToolUseEvent), - /// 计费 - Metering(()), - /// 上下文使用率 - ContextUsage(super::ContextUsageEvent), - /// 未知事件 (保留原始帧数据) - Unknown {}, - /// 服务端错误 - Error { - /// 错误代码 - error_code: String, - /// 错误消息 - error_message: String, - }, - /// 服务端异常 - Exception { - /// 异常类型 - exception_type: String, - /// 异常消息 - message: String, - }, -} - -impl Event { - /// 从帧解析事件 - pub fn from_frame(frame: Frame) -> ParseResult { - let message_type = frame.message_type().unwrap_or("event"); - - match message_type { - "event" => Self::parse_event(frame), - "error" => Self::parse_error(frame), - "exception" => Self::parse_exception(frame), - other => Err(ParseError::InvalidMessageType(other.to_string())), - } - } - - /// 解析事件类型消息 - fn parse_event(frame: Frame) -> ParseResult { - let event_type_str = frame.event_type().unwrap_or("unknown"); - let event_type = EventType::from_str(event_type_str); - - match event_type { - EventType::AssistantResponse => { - let payload = super::AssistantResponseEvent::from_frame(&frame)?; - Ok(Self::AssistantResponse(payload)) - } - EventType::ToolUse => { - let payload = super::ToolUseEvent::from_frame(&frame)?; - Ok(Self::ToolUse(payload)) - } - EventType::Metering => Ok(Self::Metering(())), - EventType::ContextUsage => { - let payload = super::ContextUsageEvent::from_frame(&frame)?; - Ok(Self::ContextUsage(payload)) - } - EventType::Unknown => Ok(Self::Unknown {}), - } - } - - /// 解析错误类型消息 - fn parse_error(frame: Frame) -> ParseResult { - let error_code = frame - .headers - .error_code() - .unwrap_or("UnknownError") - .to_string(); - let error_message = frame.payload_as_str(); - - Ok(Self::Error { - error_code, - error_message, - }) - } - - /// 解析异常类型消息 - fn parse_exception(frame: Frame) -> ParseResult { - let exception_type = frame - .headers - .exception_type() - .unwrap_or("UnknownException") - .to_string(); - let message = frame.payload_as_str(); - - Ok(Self::Exception { - exception_type, - message, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_event_type_from_str() { - assert_eq!( - EventType::from_str("assistantResponseEvent"), - EventType::AssistantResponse - ); - assert_eq!(EventType::from_str("toolUseEvent"), EventType::ToolUse); - assert_eq!(EventType::from_str("meteringEvent"), EventType::Metering); - assert_eq!( - EventType::from_str("contextUsageEvent"), - EventType::ContextUsage - ); - assert_eq!(EventType::from_str("unknown_type"), EventType::Unknown); - } - - #[test] - fn test_event_type_as_str() { - assert_eq!( - EventType::AssistantResponse.as_str(), - "assistantResponseEvent" - ); - assert_eq!(EventType::ToolUse.as_str(), "toolUseEvent"); - } -} diff --git a/src/kiro/model/events/context_usage.rs b/src/kiro/model/events/context_usage.rs deleted file mode 100644 index 34034921f0f12dd181fc099c6b2866a3b644989f..0000000000000000000000000000000000000000 --- a/src/kiro/model/events/context_usage.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! 上下文使用率事件 -//! -//! 处理 contextUsageEvent 类型的事件 - -use serde::Deserialize; - -use crate::kiro::parser::error::ParseResult; -use crate::kiro::parser::frame::Frame; - -use super::base::EventPayload; - -/// 上下文使用率事件 -/// -/// 包含当前上下文窗口的使用百分比 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ContextUsageEvent { - /// 上下文使用百分比 (0-100) - #[serde(default)] - pub context_usage_percentage: f64, -} - -impl EventPayload for ContextUsageEvent { - fn from_frame(frame: &Frame) -> ParseResult { - frame.payload_as_json() - } -} - -impl ContextUsageEvent { - /// 获取格式化的百分比字符串 - pub fn formatted_percentage(&self) -> String { - format!("{:.2}%", self.context_usage_percentage) - } -} - -impl std::fmt::Display for ContextUsageEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.formatted_percentage()) - } -} diff --git a/src/kiro/model/events/mod.rs b/src/kiro/model/events/mod.rs deleted file mode 100644 index c9880f910d455726bfe7a1d2e853ad7578c79f14..0000000000000000000000000000000000000000 --- a/src/kiro/model/events/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! 事件模型 -//! -//! 定义 generateAssistantResponse 流式响应的事件类型 - -mod assistant; -mod base; -mod context_usage; -mod tool_use; - -pub use assistant::AssistantResponseEvent; -pub use base::Event; -pub use context_usage::ContextUsageEvent; -pub use tool_use::ToolUseEvent; diff --git a/src/kiro/model/events/tool_use.rs b/src/kiro/model/events/tool_use.rs deleted file mode 100644 index 9bfab9dea786114f5b2a9c9535e80190e7f9be60..0000000000000000000000000000000000000000 --- a/src/kiro/model/events/tool_use.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! 工具使用事件 -//! -//! 处理 toolUseEvent 类型的事件 - -use serde::Deserialize; - -use crate::kiro::parser::error::ParseResult; -use crate::kiro::parser::frame::Frame; - -use super::base::EventPayload; - -/// 工具使用事件 -/// -/// 包含工具调用的流式数据 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolUseEvent { - /// 工具名称 - pub name: String, - /// 工具调用 ID - pub tool_use_id: String, - /// 工具输入数据 (JSON 字符串,可能是流式的部分数据) - #[serde(default)] - pub input: String, - /// 是否是最后一个块 - #[serde(default)] - pub stop: bool, -} - -impl EventPayload for ToolUseEvent { - fn from_frame(frame: &Frame) -> ParseResult { - frame.payload_as_json() - } -} - -impl std::fmt::Display for ToolUseEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.stop { - write!( - f, - "ToolUse[{}] (id={}, complete): {}", - self.name, self.tool_use_id, self.input - ) - } else { - write!( - f, - "ToolUse[{}] (id={}, partial): {}", - self.name, self.tool_use_id, self.input - ) - } - } -} diff --git a/src/kiro/model/mod.rs b/src/kiro/model/mod.rs deleted file mode 100644 index b0342d7c102f322c9c6c74831045f7012ae9d40e..0000000000000000000000000000000000000000 --- a/src/kiro/model/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -//! Kiro 数据模型 -//! -//! 包含 Kiro API 的所有数据类型定义: -//! - `common`: 共享类型(枚举和辅助结构体) -//! - `events`: 响应事件类型 -//! - `requests`: 请求类型 -//! - `credentials`: OAuth 凭证 -//! - `token_refresh`: Token 刷新 -//! - `usage_limits`: 使用额度查询 - -pub mod common; -pub mod credentials; -pub mod events; -pub mod requests; -pub mod token_refresh; -pub mod usage_limits; diff --git a/src/kiro/model/requests/conversation.rs b/src/kiro/model/requests/conversation.rs deleted file mode 100644 index 393b06ee257d198e6658c4073e1a205dc8f81574..0000000000000000000000000000000000000000 --- a/src/kiro/model/requests/conversation.rs +++ /dev/null @@ -1,408 +0,0 @@ -//! 对话类型定义 -//! -//! 定义 Kiro API 中对话相关的类型,包括消息、历史记录等 - -use serde::{Deserialize, Serialize}; - -use super::tool::{Tool, ToolResult, ToolUseEntry}; - -/// 对话状态 -/// -/// Kiro API 请求中的核心结构,包含当前消息和历史记录 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ConversationState { - /// 代理延续 ID - #[serde(skip_serializing_if = "Option::is_none")] - pub agent_continuation_id: Option, - /// 代理任务类型(通常为 "vibe") - #[serde(skip_serializing_if = "Option::is_none")] - pub agent_task_type: Option, - /// 聊天触发类型("MANUAL" 或 "AUTO") - #[serde(skip_serializing_if = "Option::is_none")] - pub chat_trigger_type: Option, - /// 当前消息 - pub current_message: CurrentMessage, - /// 会话 ID - pub conversation_id: String, - /// 历史消息列表 - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub history: Vec, -} - -impl ConversationState { - /// 创建新的对话状态 - pub fn new(conversation_id: impl Into) -> Self { - Self { - agent_continuation_id: None, - agent_task_type: None, - chat_trigger_type: None, - current_message: CurrentMessage::default(), - conversation_id: conversation_id.into(), - history: Vec::new(), - } - } - - /// 设置代理延续 ID - pub fn with_agent_continuation_id(mut self, id: impl Into) -> Self { - self.agent_continuation_id = Some(id.into()); - self - } - - /// 设置代理任务类型 - pub fn with_agent_task_type(mut self, task_type: impl Into) -> Self { - self.agent_task_type = Some(task_type.into()); - self - } - - /// 设置聊天触发类型 - pub fn with_chat_trigger_type(mut self, trigger_type: impl Into) -> Self { - self.chat_trigger_type = Some(trigger_type.into()); - self - } - - /// 设置当前消息 - pub fn with_current_message(mut self, message: CurrentMessage) -> Self { - self.current_message = message; - self - } - - /// 添加历史消息 - pub fn with_history(mut self, history: Vec) -> Self { - self.history = history; - self - } -} - -/// 当前消息容器 -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CurrentMessage { - /// 用户输入消息 - pub user_input_message: UserInputMessage, -} - -impl CurrentMessage { - /// 创建新的当前消息 - pub fn new(user_input_message: UserInputMessage) -> Self { - Self { user_input_message } - } -} - -/// 用户输入消息 -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UserInputMessage { - /// 用户输入消息上下文 - pub user_input_message_context: UserInputMessageContext, - /// 消息内容 - pub content: String, - /// 模型 ID - pub model_id: String, - /// 图片列表 - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub images: Vec, - /// 消息来源(通常为 "AI_EDITOR") - #[serde(skip_serializing_if = "Option::is_none")] - pub origin: Option, -} - -impl UserInputMessage { - /// 创建新的用户输入消息 - pub fn new(content: impl Into, model_id: impl Into) -> Self { - Self { - user_input_message_context: UserInputMessageContext::default(), - content: content.into(), - model_id: model_id.into(), - images: Vec::new(), - origin: Some("AI_EDITOR".to_string()), - } - } - - /// 设置消息上下文 - pub fn with_context(mut self, context: UserInputMessageContext) -> Self { - self.user_input_message_context = context; - self - } - - /// 添加图片 - pub fn with_images(mut self, images: Vec) -> Self { - self.images = images; - self - } - - /// 设置来源 - pub fn with_origin(mut self, origin: impl Into) -> Self { - self.origin = Some(origin.into()); - self - } -} - -/// 用户输入消息上下文 -/// -/// 包含工具定义和工具执行结果 -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UserInputMessageContext { - /// 工具执行结果列表 - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub tool_results: Vec, - /// 可用工具列表 - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub tools: Vec, -} - -impl UserInputMessageContext { - /// 创建新的消息上下文 - pub fn new() -> Self { - Self::default() - } - - /// 设置工具列表 - pub fn with_tools(mut self, tools: Vec) -> Self { - self.tools = tools; - self - } - - /// 设置工具结果 - pub fn with_tool_results(mut self, results: Vec) -> Self { - self.tool_results = results; - self - } -} - -/// Kiro 图片 -/// -/// API 中使用的图片格式 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct KiroImage { - /// 图片格式("jpeg", "png", "gif", "webp") - pub format: String, - /// 图片数据源 - pub source: KiroImageSource, -} - -impl KiroImage { - /// 从 base64 数据创建图片 - pub fn from_base64(format: impl Into, data: impl Into) -> Self { - Self { - format: format.into(), - source: KiroImageSource { bytes: data.into() }, - } - } -} - -/// Kiro 图片数据源 -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct KiroImageSource { - /// base64 编码的图片数据 - pub bytes: String, -} - -/// 历史消息 -/// -/// 可以是用户消息或助手消息 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum Message { - /// 用户消息 - User(HistoryUserMessage), - /// 助手消息 - Assistant(HistoryAssistantMessage), -} - -#[allow(dead_code)] -impl Message { - /// 创建用户消息 - pub fn user(content: impl Into, model_id: impl Into) -> Self { - Self::User(HistoryUserMessage::new(content, model_id)) - } - - /// 创建助手消息 - pub fn assistant(content: impl Into) -> Self { - Self::Assistant(HistoryAssistantMessage::new(content)) - } - - /// 判断是否为用户消息 - pub fn is_user(&self) -> bool { - matches!(self, Self::User(_)) - } - - /// 判断是否为助手消息 - pub fn is_assistant(&self) -> bool { - matches!(self, Self::Assistant(_)) - } -} - -/// 历史用户消息 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct HistoryUserMessage { - /// 用户输入消息 - pub user_input_message: UserMessage, -} - -impl HistoryUserMessage { - /// 创建新的历史用户消息 - pub fn new(content: impl Into, model_id: impl Into) -> Self { - Self { - user_input_message: UserMessage::new(content, model_id), - } - } -} - -/// 用户消息(历史记录中使用) -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UserMessage { - /// 消息内容 - pub content: String, - /// 模型 ID - pub model_id: String, - /// 消息来源 - #[serde(skip_serializing_if = "Option::is_none")] - pub origin: Option, - /// 图片列表 - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub images: Vec, - /// 用户输入消息上下文 - #[serde(default, skip_serializing_if = "is_default_context")] - pub user_input_message_context: UserInputMessageContext, -} - -fn is_default_context(ctx: &UserInputMessageContext) -> bool { - ctx.tools.is_empty() && ctx.tool_results.is_empty() -} - -impl UserMessage { - /// 创建新的用户消息 - pub fn new(content: impl Into, model_id: impl Into) -> Self { - Self { - content: content.into(), - model_id: model_id.into(), - origin: Some("AI_EDITOR".to_string()), - images: Vec::new(), - user_input_message_context: UserInputMessageContext::default(), - } - } - - /// 设置图片 - pub fn with_images(mut self, images: Vec) -> Self { - self.images = images; - self - } - - /// 设置上下文 - pub fn with_context(mut self, context: UserInputMessageContext) -> Self { - self.user_input_message_context = context; - self - } -} - -/// 历史助手消息 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct HistoryAssistantMessage { - /// 助手响应消息 - pub assistant_response_message: AssistantMessage, -} - -impl HistoryAssistantMessage { - /// 创建新的历史助手消息 - pub fn new(content: impl Into) -> Self { - Self { - assistant_response_message: AssistantMessage::new(content), - } - } -} - -/// 助手消息(历史记录中使用) -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct AssistantMessage { - /// 响应内容 - pub content: String, - /// 工具使用列表 - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_uses: Option>, -} - -impl AssistantMessage { - /// 创建新的助手消息 - pub fn new(content: impl Into) -> Self { - Self { - content: content.into(), - tool_uses: None, - } - } - - /// 设置工具使用 - pub fn with_tool_uses(mut self, tool_uses: Vec) -> Self { - self.tool_uses = Some(tool_uses); - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_conversation_state_new() { - let state = ConversationState::new("conv-123") - .with_agent_task_type("vibe") - .with_chat_trigger_type("MANUAL"); - - assert_eq!(state.conversation_id, "conv-123"); - assert_eq!(state.agent_task_type, Some("vibe".to_string())); - assert_eq!(state.chat_trigger_type, Some("MANUAL".to_string())); - } - - #[test] - fn test_user_input_message() { - let msg = UserInputMessage::new("Hello", "claude-3-5-sonnet").with_origin("AI_EDITOR"); - - assert_eq!(msg.content, "Hello"); - assert_eq!(msg.model_id, "claude-3-5-sonnet"); - assert_eq!(msg.origin, Some("AI_EDITOR".to_string())); - } - - #[test] - fn test_message_enum() { - let user_msg = Message::user("Hello", "model-id"); - assert!(user_msg.is_user()); - assert!(!user_msg.is_assistant()); - - let assistant_msg = Message::assistant("Hi there!"); - assert!(assistant_msg.is_assistant()); - assert!(!assistant_msg.is_user()); - } - - #[test] - fn test_history_serialize() { - let history = vec![ - Message::user("Hello", "claude-3-5-sonnet"), - Message::assistant("Hi! How can I help you?"), - ]; - - let json = serde_json::to_string(&history).unwrap(); - assert!(json.contains("userInputMessage")); - assert!(json.contains("assistantResponseMessage")); - } - - #[test] - fn test_conversation_state_serialize() { - let state = ConversationState::new("conv-123") - .with_agent_task_type("vibe") - .with_current_message(CurrentMessage::new(UserInputMessage::new( - "Hello", - "claude-3-5-sonnet", - ))); - - let json = serde_json::to_string(&state).unwrap(); - assert!(json.contains("\"conversationId\":\"conv-123\"")); - assert!(json.contains("\"agentTaskType\":\"vibe\"")); - assert!(json.contains("\"content\":\"Hello\"")); - } -} diff --git a/src/kiro/model/requests/kiro.rs b/src/kiro/model/requests/kiro.rs deleted file mode 100644 index 4af70663cb6128d7cfcd4535b7daebb5987db869..0000000000000000000000000000000000000000 --- a/src/kiro/model/requests/kiro.rs +++ /dev/null @@ -1,68 +0,0 @@ -//! Kiro 请求类型定义 -//! -//! 定义 Kiro API 的主请求结构 - -use serde::{Deserialize, Serialize}; - -use super::conversation::ConversationState; - -/// Kiro API 请求 -/// -/// 用于构建发送给 Kiro API 的请求 -/// -/// # 示例 -/// -/// ```rust -/// use kiro_rs::kiro::model::requests::{ -/// KiroRequest, ConversationState, CurrentMessage, UserInputMessage, Tool -/// }; -/// -/// // 创建简单请求 -/// let state = ConversationState::new("conv-123") -/// .with_agent_task_type("vibe") -/// .with_current_message(CurrentMessage::new( -/// UserInputMessage::new("Hello", "claude-3-5-sonnet") -/// )); -/// -/// let request = KiroRequest::new(state); -/// let json = request.to_json().unwrap(); -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct KiroRequest { - /// 对话状态 - pub conversation_state: ConversationState, - /// Profile ARN(可选) - #[serde(skip_serializing_if = "Option::is_none")] - pub profile_arn: Option, -} -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_kiro_request_deserialize() { - let json = r#"{ - "conversationState": { - "conversationId": "conv-456", - "currentMessage": { - "userInputMessage": { - "content": "Test message", - "modelId": "claude-3-5-sonnet", - "userInputMessageContext": {} - } - } - } - }"#; - - let request: KiroRequest = serde_json::from_str(json).unwrap(); - assert_eq!(request.conversation_state.conversation_id, "conv-456"); - assert_eq!( - request - .conversation_state - .current_message - .user_input_message - .content, - "Test message" - ); - } -} diff --git a/src/kiro/model/requests/mod.rs b/src/kiro/model/requests/mod.rs deleted file mode 100644 index ec82c6db88bba3b131cc2d88c6fda52ce8101a4b..0000000000000000000000000000000000000000 --- a/src/kiro/model/requests/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! 请求类型模块 -//! -//! 包含 Kiro API 请求相关的类型定义 - -pub mod conversation; -pub mod kiro; -pub mod tool; diff --git a/src/kiro/model/requests/tool.rs b/src/kiro/model/requests/tool.rs deleted file mode 100644 index c251d614581f3a633fa325b4e5884701a009c4d6..0000000000000000000000000000000000000000 --- a/src/kiro/model/requests/tool.rs +++ /dev/null @@ -1,192 +0,0 @@ -//! 工具类型定义 -//! -//! 定义 Kiro API 中工具相关的类型 - -use serde::{Deserialize, Serialize}; - -/// 工具定义 -/// -/// 用于在请求中定义可用的工具 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Tool { - /// 工具规范 - pub tool_specification: ToolSpecification, -} - -/// 工具规范 -/// -/// 定义工具的名称、描述和输入模式 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolSpecification { - /// 工具名称 - pub name: String, - /// 工具描述 - pub description: String, - /// 输入模式(JSON Schema) - pub input_schema: InputSchema, -} - -/// 输入模式 -/// -/// 包装 JSON Schema 定义 -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InputSchema { - /// JSON Schema 定义 - pub json: serde_json::Value, -} - -impl Default for InputSchema { - fn default() -> Self { - Self { - json: serde_json::json!({ - "type": "object", - "properties": {} - }), - } - } -} - -impl InputSchema { - /// 从 JSON 值创建 - pub fn from_json(json: serde_json::Value) -> Self { - Self { json } - } -} - -/// 工具执行结果 -/// -/// 用于返回工具执行的结果 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolResult { - /// 工具使用 ID(与请求中的 tool_use_id 对应) - pub tool_use_id: String, - /// 结果内容(数组格式) - pub content: Vec>, - /// 执行状态("success" 或 "error") - #[serde(skip_serializing_if = "Option::is_none")] - pub status: Option, - /// 是否为错误 - #[serde(default, skip_serializing_if = "is_false")] - pub is_error: bool, -} - -fn is_false(b: &bool) -> bool { - !*b -} - -impl ToolResult { - /// 创建成功的工具结果 - pub fn success(tool_use_id: impl Into, content: impl Into) -> Self { - let mut map = serde_json::Map::new(); - map.insert( - "text".to_string(), - serde_json::Value::String(content.into()), - ); - - Self { - tool_use_id: tool_use_id.into(), - content: vec![map], - status: Some("success".to_string()), - is_error: false, - } - } - - /// 创建错误的工具结果 - pub fn error(tool_use_id: impl Into, error_message: impl Into) -> Self { - let mut map = serde_json::Map::new(); - map.insert( - "text".to_string(), - serde_json::Value::String(error_message.into()), - ); - - Self { - tool_use_id: tool_use_id.into(), - content: vec![map], - status: Some("error".to_string()), - is_error: true, - } - } -} - -/// 工具使用条目 -/// -/// 用于历史消息中记录工具调用 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolUseEntry { - /// 工具使用 ID - pub tool_use_id: String, - /// 工具名称 - pub name: String, - /// 工具输入参数 - pub input: serde_json::Value, -} - -impl ToolUseEntry { - /// 创建新的工具使用条目 - pub fn new(tool_use_id: impl Into, name: impl Into) -> Self { - Self { - tool_use_id: tool_use_id.into(), - name: name.into(), - input: serde_json::json!({}), - } - } - - /// 设置输入参数 - pub fn with_input(mut self, input: serde_json::Value) -> Self { - self.input = input; - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tool_result_success() { - let result = ToolResult::success("tool-123", "Operation completed"); - - assert!(!result.is_error); - assert_eq!(result.status, Some("success".to_string())); - } - - #[test] - fn test_tool_result_error() { - let result = ToolResult::error("tool-456", "File not found"); - - assert!(result.is_error); - assert_eq!(result.status, Some("error".to_string())); - } - - #[test] - fn test_tool_result_serialize() { - let result = ToolResult::success("tool-789", "Done"); - let json = serde_json::to_string(&result).unwrap(); - - assert!(json.contains("\"toolUseId\":\"tool-789\"")); - assert!(json.contains("\"status\":\"success\"")); - // is_error = false 应该被跳过 - assert!(!json.contains("isError")); - } - - #[test] - fn test_tool_use_entry() { - let entry = ToolUseEntry::new("use-123", "read_file") - .with_input(serde_json::json!({"path": "/test.txt"})); - - let json = serde_json::to_string(&entry).unwrap(); - assert!(json.contains("\"toolUseId\":\"use-123\"")); - assert!(json.contains("\"name\":\"read_file\"")); - assert!(json.contains("\"path\":\"/test.txt\"")); - } - - #[test] - fn test_input_schema_default() { - let schema = InputSchema::default(); - assert_eq!(schema.json["type"], "object"); - } -} diff --git a/src/kiro/model/token_refresh.rs b/src/kiro/model/token_refresh.rs deleted file mode 100644 index 4a3846b1e99d0d4701ebdbe9555e2c41a0b70b11..0000000000000000000000000000000000000000 --- a/src/kiro/model/token_refresh.rs +++ /dev/null @@ -1,44 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// 刷新 Token 的请求体 (Social 认证) -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct RefreshRequest { - pub refresh_token: String, -} - -/// 刷新 Token 的响应体 (Social 认证) -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RefreshResponse { - pub access_token: String, - #[serde(default)] - pub refresh_token: Option, - #[serde(default)] - pub profile_arn: Option, - #[serde(default)] - pub expires_in: Option, -} - -/// IdC Token 刷新请求体 (AWS SSO OIDC) -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct IdcRefreshRequest { - pub client_id: String, - pub client_secret: String, - pub refresh_token: String, - pub grant_type: String, -} - -/// IdC Token 刷新响应体 (AWS SSO OIDC) -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct IdcRefreshResponse { - pub access_token: String, - #[serde(default)] - pub refresh_token: Option, - // #[serde(default)] - // pub token_type: Option, - #[serde(default)] - pub expires_in: Option, -} diff --git a/src/kiro/model/usage_limits.rs b/src/kiro/model/usage_limits.rs deleted file mode 100644 index 719547d8edfd62282de5140df320746ef18f8e36..0000000000000000000000000000000000000000 --- a/src/kiro/model/usage_limits.rs +++ /dev/null @@ -1,200 +0,0 @@ -//! 使用额度查询数据模型 -//! -//! 包含 getUsageLimits API 的响应类型定义 - -use serde::Deserialize; - -/// 使用额度查询响应 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UsageLimitsResponse { - /// 下次重置日期 (Unix 时间戳) - #[serde(default)] - pub next_date_reset: Option, - - /// 订阅信息 - #[serde(default)] - pub subscription_info: Option, - - /// 使用量明细列表 - #[serde(default)] - pub usage_breakdown_list: Vec, -} - -/// 订阅信息 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SubscriptionInfo { - /// 订阅标题 (KIRO PRO+ / KIRO FREE 等) - #[serde(default)] - pub subscription_title: Option, -} - -/// 使用量明细 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UsageBreakdown { - /// 当前使用量 - #[serde(default)] - pub current_usage: i64, - - /// 当前使用量(精确值) - #[serde(default)] - pub current_usage_with_precision: f64, - - /// 奖励额度列表 - #[serde(default)] - pub bonuses: Vec, - - /// 免费试用信息 - #[serde(default)] - pub free_trial_info: Option, - - /// 下次重置日期 (Unix 时间戳) - #[serde(default)] - pub next_date_reset: Option, - - /// 使用限额 - #[serde(default)] - pub usage_limit: i64, - - /// 使用限额(精确值) - #[serde(default)] - pub usage_limit_with_precision: f64, -} - -/// 奖励额度 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Bonus { - /// 当前使用量 - #[serde(default)] - pub current_usage: f64, - - /// 使用限额 - #[serde(default)] - pub usage_limit: f64, - - /// 状态 (ACTIVE / EXPIRED) - #[serde(default)] - pub status: Option, -} - -impl Bonus { - /// 检查 bonus 是否处于激活状态 - pub fn is_active(&self) -> bool { - self.status - .as_deref() - .map(|s| s == "ACTIVE") - .unwrap_or(false) - } -} - -/// 免费试用信息 -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FreeTrialInfo { - /// 当前使用量 - #[serde(default)] - pub current_usage: i64, - - /// 当前使用量(精确值) - #[serde(default)] - pub current_usage_with_precision: f64, - - /// 免费试用过期时间 (Unix 时间戳) - #[serde(default)] - pub free_trial_expiry: Option, - - /// 免费试用状态 (ACTIVE / EXPIRED) - #[serde(default)] - pub free_trial_status: Option, - - /// 使用限额 - #[serde(default)] - pub usage_limit: i64, - - /// 使用限额(精确值) - #[serde(default)] - pub usage_limit_with_precision: f64, -} - -// ============ 便捷方法实现 ============ - -impl FreeTrialInfo { - /// 检查免费试用是否处于激活状态 - pub fn is_active(&self) -> bool { - self.free_trial_status - .as_deref() - .map(|s| s == "ACTIVE") - .unwrap_or(false) - } -} - -impl UsageLimitsResponse { - /// 获取订阅标题 - pub fn subscription_title(&self) -> Option<&str> { - self.subscription_info - .as_ref() - .and_then(|info| info.subscription_title.as_deref()) - } - - /// 获取第一个使用量明细 - fn primary_breakdown(&self) -> Option<&UsageBreakdown> { - self.usage_breakdown_list.first() - } - - /// 获取总使用限额(精确值) - /// - /// 累加基础额度、激活的免费试用额度和激活的奖励额度 - pub fn usage_limit(&self) -> f64 { - let Some(breakdown) = self.primary_breakdown() else { - return 0.0; - }; - - let mut total = breakdown.usage_limit_with_precision; - - // 累加激活的 free trial 额度 - if let Some(trial) = &breakdown.free_trial_info { - if trial.is_active() { - total += trial.usage_limit_with_precision; - } - } - - // 累加激活的 bonus 额度 - for bonus in &breakdown.bonuses { - if bonus.is_active() { - total += bonus.usage_limit; - } - } - - total - } - - /// 获取总当前使用量(精确值) - /// - /// 累加基础使用量、激活的免费试用使用量和激活的奖励使用量 - pub fn current_usage(&self) -> f64 { - let Some(breakdown) = self.primary_breakdown() else { - return 0.0; - }; - - let mut total = breakdown.current_usage_with_precision; - - // 累加激活的 free trial 使用量 - if let Some(trial) = &breakdown.free_trial_info { - if trial.is_active() { - total += trial.current_usage_with_precision; - } - } - - // 累加激活的 bonus 使用量 - for bonus in &breakdown.bonuses { - if bonus.is_active() { - total += bonus.current_usage; - } - } - - total - } -} diff --git a/src/kiro/parser/crc.rs b/src/kiro/parser/crc.rs deleted file mode 100644 index 03f47f4ea538c13acb3acfe341fe72700a45ea56..0000000000000000000000000000000000000000 --- a/src/kiro/parser/crc.rs +++ /dev/null @@ -1,37 +0,0 @@ -//! CRC32 校验实现 -//! -//! AWS Event Stream 使用 CRC32 (ISO-HDLC/以太网/ZIP 标准) - -use crc::{CRC_32_ISO_HDLC, Crc}; - -/// CRC32 计算器实例 (ISO-HDLC 标准,多项式 0xEDB88320) -const CRC32: Crc = Crc::::new(&CRC_32_ISO_HDLC); - -/// 计算 CRC32 校验和 (ISO-HDLC 标准) -/// -/// # Arguments -/// * `data` - 要计算校验和的数据 -/// -/// # Returns -/// CRC32 校验和值 -pub fn crc32(data: &[u8]) -> u32 { - CRC32.checksum(data) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_crc32_empty() { - // 空数据的 CRC32 应该是 0 - assert_eq!(crc32(&[]), 0); - } - - #[test] - fn test_crc32_known_value() { - // "123456789" 的 CRC32 (ISO-HDLC) 值是 0xCBF43926 - let data = b"123456789"; - assert_eq!(crc32(data), 0xCBF43926); - } -} diff --git a/src/kiro/parser/decoder.rs b/src/kiro/parser/decoder.rs deleted file mode 100644 index 40d2cbf5f67fddf20f310e9e3c902cb0f5d208a0..0000000000000000000000000000000000000000 --- a/src/kiro/parser/decoder.rs +++ /dev/null @@ -1,465 +0,0 @@ -//! AWS Event Stream 流式解码器 -//! -//! 使用状态机处理流式数据,支持断点续传和容错处理 -//! -//! ## 状态机设计 -//! -//! 参考 kiro-kt 项目的状态机设计,采用四态模型: -//! -//! ```text -//! ┌─────────────────┐ -//! │ Ready │ (初始态,就绪接收数据) -//! └────────┬────────┘ -//! │ feed() 提供数据 -//! ↓ -//! ┌─────────────────┐ -//! │ Parsing │ decode() 尝试解析 -//! └────────┬────────┘ -//! │ -//! ┌────┴────────────┐ -//! ↓ ↓ -//! [成功] [失败] -//! │ │ -//! ↓ ├─> error_count++ -//! ┌─────────┐ │ -//! │ Ready │ ├─> error_count < max_errors? -//! └─────────┘ │ YES → Recovering → Ready -//! │ NO ↓ -//! ┌────────────┐ -//! │ Stopped │ (终止态) -//! └────────────┘ -//! ``` - -use super::error::{ParseError, ParseResult}; -use super::frame::{Frame, PRELUDE_SIZE, parse_frame}; -use bytes::{Buf, BytesMut}; - -/// 默认最大缓冲区大小 (16 MB) -pub const DEFAULT_MAX_BUFFER_SIZE: usize = 16 * 1024 * 1024; - -/// 默认最大连续错误数 -pub const DEFAULT_MAX_ERRORS: usize = 5; - -/// 默认初始缓冲区容量 -pub const DEFAULT_BUFFER_CAPACITY: usize = 8192; - -/// 解码器状态 -/// -/// 采用四态模型,参考 kiro-kt 的设计: -/// - Ready: 就绪状态,可以接收数据 -/// - Parsing: 正在解析帧 -/// - Recovering: 恢复中(尝试跳过损坏数据) -/// - Stopped: 已停止(错误过多,终止态) -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum DecoderState { - /// 就绪,可以接收数据 - Ready, - /// 正在解析帧 - Parsing, - /// 恢复中(跳过损坏数据) - Recovering, - /// 已停止(错误过多) - Stopped, -} - -/// 流式事件解码器 -/// -/// 用于从字节流中解析 AWS Event Stream 消息帧 -/// -/// # Example -/// -/// ```rust,ignore -/// use kiro_rs::kiro::parser::EventStreamDecoder; -/// -/// let mut decoder = EventStreamDecoder::new(); -/// -/// // 提供流数据 -/// decoder.feed(chunk)?; -/// -/// // 解码所有可用帧 -/// for result in decoder.decode_iter() { -/// match result { -/// Ok(frame) => println!("Got frame: {:?}", frame.event_type()), -/// Err(e) => eprintln!("Parse error: {}", e), -/// } -/// } -/// ``` -pub struct EventStreamDecoder { - /// 内部缓冲区 - buffer: BytesMut, - /// 当前状态 - state: DecoderState, - /// 已处理的帧数量 - frames_decoded: usize, - /// 连续错误计数 - error_count: usize, - /// 最大连续错误数 - max_errors: usize, - /// 最大缓冲区大小 - max_buffer_size: usize, - /// 跳过的字节数(用于调试) - bytes_skipped: usize, -} - -impl Default for EventStreamDecoder { - fn default() -> Self { - Self::new() - } -} - -impl EventStreamDecoder { - /// 创建新的解码器 - pub fn new() -> Self { - Self::with_capacity(DEFAULT_BUFFER_CAPACITY) - } - - /// 创建具有指定缓冲区大小的解码器 - pub fn with_capacity(capacity: usize) -> Self { - Self { - buffer: BytesMut::with_capacity(capacity), - state: DecoderState::Ready, - frames_decoded: 0, - error_count: 0, - max_errors: DEFAULT_MAX_ERRORS, - max_buffer_size: DEFAULT_MAX_BUFFER_SIZE, - bytes_skipped: 0, - } - } - - /// 创建具有自定义配置的解码器 - pub fn with_config(capacity: usize, max_errors: usize, max_buffer_size: usize) -> Self { - Self { - buffer: BytesMut::with_capacity(capacity), - state: DecoderState::Ready, - frames_decoded: 0, - error_count: 0, - max_errors, - max_buffer_size, - bytes_skipped: 0, - } - } - - /// 向解码器提供数据 - /// - /// # Returns - /// - `Ok(())` - 数据已添加到缓冲区 - /// - `Err(BufferOverflow)` - 缓冲区已满 - pub fn feed(&mut self, data: &[u8]) -> ParseResult<()> { - // 检查缓冲区大小限制 - let new_size = self.buffer.len() + data.len(); - if new_size > self.max_buffer_size { - return Err(ParseError::BufferOverflow { - size: new_size, - max: self.max_buffer_size, - }); - } - - self.buffer.extend_from_slice(data); - - // 从 Recovering 状态恢复到 Ready - if self.state == DecoderState::Recovering { - self.state = DecoderState::Ready; - } - - Ok(()) - } - - /// 尝试解码下一个帧 - /// - /// # Returns - /// - `Ok(Some(frame))` - 成功解码一个帧 - /// - `Ok(None)` - 数据不足,需要更多数据 - /// - `Err(e)` - 解码错误 - pub fn decode(&mut self) -> ParseResult> { - // 如果已停止,直接返回错误 - if self.state == DecoderState::Stopped { - return Err(ParseError::TooManyErrors { - count: self.error_count, - last_error: "解码器已停止".to_string(), - }); - } - - // 缓冲区为空,保持 Ready 状态 - if self.buffer.is_empty() { - self.state = DecoderState::Ready; - return Ok(None); - } - - // 转移到 Parsing 状态 - self.state = DecoderState::Parsing; - - match parse_frame(&self.buffer) { - Ok(Some((frame, consumed))) => { - // 成功解析 - self.buffer.advance(consumed); - self.state = DecoderState::Ready; - self.frames_decoded += 1; - self.error_count = 0; // 重置连续错误计数 - Ok(Some(frame)) - } - Ok(None) => { - // 数据不足,回到 Ready 状态等待更多数据 - self.state = DecoderState::Ready; - Ok(None) - } - Err(e) => { - self.error_count += 1; - let error_msg = e.to_string(); - - // 检查是否超过最大错误数 - if self.error_count >= self.max_errors { - self.state = DecoderState::Stopped; - tracing::error!( - "解码器停止: 连续 {} 次错误,最后错误: {}", - self.error_count, - error_msg - ); - return Err(ParseError::TooManyErrors { - count: self.error_count, - last_error: error_msg, - }); - } - - // 根据错误类型采用不同的恢复策略 - self.try_recover(&e); - self.state = DecoderState::Recovering; - Err(e) - } - } - } - - /// 创建解码迭代器 - pub fn decode_iter(&mut self) -> DecodeIter<'_> { - DecodeIter { decoder: self } - } - - /// 尝试容错恢复 - /// - /// 根据错误类型采用不同的恢复策略(参考 kiro-kt 的设计): - /// - Prelude 阶段错误(CRC 失败、长度异常):跳过 1 字节,尝试找下一帧边界 - /// - Data 阶段错误(Message CRC 失败、Header 解析失败):跳过整个损坏帧 - fn try_recover(&mut self, error: &ParseError) { - if self.buffer.is_empty() { - return; - } - - match error { - // Prelude 阶段错误:可能是帧边界错位,逐字节扫描找下一个有效边界 - ParseError::PreludeCrcMismatch { .. } - | ParseError::MessageTooSmall { .. } - | ParseError::MessageTooLarge { .. } => { - let skipped_byte = self.buffer[0]; - self.buffer.advance(1); - self.bytes_skipped += 1; - tracing::warn!( - "Prelude 错误恢复: 跳过字节 0x{:02x} (累计跳过 {} 字节)", - skipped_byte, - self.bytes_skipped - ); - } - - // Data 阶段错误:帧边界正确但数据损坏,跳过整个帧 - ParseError::MessageCrcMismatch { .. } | ParseError::HeaderParseFailed(_) => { - // 尝试读取 total_length 来跳过整帧 - if self.buffer.len() >= PRELUDE_SIZE { - let total_length = u32::from_be_bytes([ - self.buffer[0], - self.buffer[1], - self.buffer[2], - self.buffer[3], - ]) as usize; - - // 确保 total_length 合理且缓冲区有足够数据 - if total_length >= 16 && total_length <= self.buffer.len() { - tracing::warn!("Data 错误恢复: 跳过损坏帧 ({} 字节)", total_length); - self.buffer.advance(total_length); - self.bytes_skipped += total_length; - return; - } - } - - // 无法确定帧长度,回退到逐字节跳过 - let skipped_byte = self.buffer[0]; - self.buffer.advance(1); - self.bytes_skipped += 1; - tracing::warn!( - "Data 错误恢复 (回退): 跳过字节 0x{:02x} (累计跳过 {} 字节)", - skipped_byte, - self.bytes_skipped - ); - } - - // 其他错误:逐字节跳过 - _ => { - let skipped_byte = self.buffer[0]; - self.buffer.advance(1); - self.bytes_skipped += 1; - tracing::warn!( - "通用错误恢复: 跳过字节 0x{:02x} (累计跳过 {} 字节)", - skipped_byte, - self.bytes_skipped - ); - } - } - } - - // ==================== 生命周期管理方法 ==================== - - /// 重置解码器到初始状态 - /// - /// 清空缓冲区和所有计数器,恢复到 Ready 状态 - pub fn reset(&mut self) { - self.buffer.clear(); - self.state = DecoderState::Ready; - self.frames_decoded = 0; - self.error_count = 0; - self.bytes_skipped = 0; - } - - /// 获取当前状态 - pub fn state(&self) -> DecoderState { - self.state - } - - /// 检查是否处于 Ready 状态 - pub fn is_ready(&self) -> bool { - self.state == DecoderState::Ready - } - - /// 检查是否处于 Stopped 状态 - pub fn is_stopped(&self) -> bool { - self.state == DecoderState::Stopped - } - - /// 检查是否处于 Recovering 状态 - pub fn is_recovering(&self) -> bool { - self.state == DecoderState::Recovering - } - - /// 获取已解码的帧数量 - pub fn frames_decoded(&self) -> usize { - self.frames_decoded - } - - /// 获取当前连续错误计数 - pub fn error_count(&self) -> usize { - self.error_count - } - - /// 获取跳过的字节数 - pub fn bytes_skipped(&self) -> usize { - self.bytes_skipped - } - - /// 获取缓冲区中待处理的字节数 - pub fn buffer_len(&self) -> usize { - self.buffer.len() - } - - /// 尝试从 Stopped 状态恢复 - /// - /// 重置错误计数并转移到 Ready 状态 - /// 注意:缓冲区内容保留,可能仍包含损坏数据 - pub fn try_resume(&mut self) { - if self.state == DecoderState::Stopped { - self.error_count = 0; - self.state = DecoderState::Ready; - tracing::info!("解码器从 Stopped 状态恢复"); - } - } -} - -/// 解码迭代器 -pub struct DecodeIter<'a> { - decoder: &'a mut EventStreamDecoder, -} - -impl<'a> Iterator for DecodeIter<'a> { - type Item = ParseResult; - - fn next(&mut self) -> Option { - // 如果处于 Stopped 或 Recovering 状态,停止迭代 - match self.decoder.state { - DecoderState::Stopped => return None, - DecoderState::Recovering => return None, - _ => {} - } - - match self.decoder.decode() { - Ok(Some(frame)) => Some(Ok(frame)), - Ok(None) => None, - Err(e) => Some(Err(e)), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_decoder_new() { - let decoder = EventStreamDecoder::new(); - assert_eq!(decoder.state(), DecoderState::Ready); - assert_eq!(decoder.frames_decoded(), 0); - assert_eq!(decoder.error_count(), 0); - } - - #[test] - fn test_decoder_feed() { - let mut decoder = EventStreamDecoder::new(); - assert!(decoder.feed(&[1, 2, 3, 4]).is_ok()); - assert_eq!(decoder.buffer_len(), 4); - } - - #[test] - fn test_decoder_buffer_overflow() { - let mut decoder = EventStreamDecoder::with_config(1024, 5, 100); - let result = decoder.feed(&[0u8; 101]); - assert!(matches!(result, Err(ParseError::BufferOverflow { .. }))); - } - - #[test] - fn test_decoder_insufficient_data() { - let mut decoder = EventStreamDecoder::new(); - decoder.feed(&[0u8; 10]).unwrap(); - - let result = decoder.decode(); - assert!(matches!(result, Ok(None))); - assert_eq!(decoder.state(), DecoderState::Ready); - } - - #[test] - fn test_decoder_reset() { - let mut decoder = EventStreamDecoder::new(); - decoder.feed(&[1, 2, 3, 4]).unwrap(); - - decoder.reset(); - assert_eq!(decoder.state(), DecoderState::Ready); - assert_eq!(decoder.buffer_len(), 0); - assert_eq!(decoder.frames_decoded(), 0); - } - - #[test] - fn test_decoder_state_transitions() { - let decoder = EventStreamDecoder::new(); - - // 初始状态 - assert!(decoder.is_ready()); - assert!(!decoder.is_stopped()); - assert!(!decoder.is_recovering()); - } - - #[test] - fn test_decoder_try_resume() { - let mut decoder = EventStreamDecoder::new(); - - // 手动设置为 Stopped 状态进行测试 - decoder.state = DecoderState::Stopped; - decoder.error_count = 5; - - decoder.try_resume(); - assert!(decoder.is_ready()); - assert_eq!(decoder.error_count(), 0); - } -} diff --git a/src/kiro/parser/error.rs b/src/kiro/parser/error.rs deleted file mode 100644 index 5918a6dfffda1dff31b113201fad9d2bfc28dbe6..0000000000000000000000000000000000000000 --- a/src/kiro/parser/error.rs +++ /dev/null @@ -1,94 +0,0 @@ -//! AWS Event Stream 解析错误定义 - -use std::fmt; - -/// 解析错误类型 -#[derive(Debug)] -pub enum ParseError { - /// 数据不足,需要更多字节 - Incomplete { needed: usize, available: usize }, - /// Prelude CRC 校验失败 - PreludeCrcMismatch { expected: u32, actual: u32 }, - /// Message CRC 校验失败 - MessageCrcMismatch { expected: u32, actual: u32 }, - /// 无效的头部值类型 - InvalidHeaderType(u8), - /// 头部解析错误 - HeaderParseFailed(String), - /// 消息长度超限 - MessageTooLarge { length: u32, max: u32 }, - /// 消息长度过小 - MessageTooSmall { length: u32, min: u32 }, - /// 无效的消息类型 - InvalidMessageType(String), - /// Payload 反序列化失败 - PayloadDeserialize(serde_json::Error), - /// IO 错误 - Io(std::io::Error), - /// 连续错误过多,解码器已停止 - TooManyErrors { count: usize, last_error: String }, - /// 缓冲区溢出 - BufferOverflow { size: usize, max: usize }, -} - -impl std::error::Error for ParseError {} - -impl fmt::Display for ParseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Incomplete { needed, available } => { - write!(f, "数据不足: 需要 {} 字节, 当前 {} 字节", needed, available) - } - Self::PreludeCrcMismatch { expected, actual } => { - write!( - f, - "Prelude CRC 校验失败: 期望 0x{:08x}, 实际 0x{:08x}", - expected, actual - ) - } - Self::MessageCrcMismatch { expected, actual } => { - write!( - f, - "Message CRC 校验失败: 期望 0x{:08x}, 实际 0x{:08x}", - expected, actual - ) - } - Self::InvalidHeaderType(t) => write!(f, "无效的头部值类型: {}", t), - Self::HeaderParseFailed(msg) => write!(f, "头部解析失败: {}", msg), - Self::MessageTooLarge { length, max } => { - write!(f, "消息长度超限: {} 字节 (最大 {})", length, max) - } - Self::MessageTooSmall { length, min } => { - write!(f, "消息长度过小: {} 字节 (最小 {})", length, min) - } - Self::InvalidMessageType(t) => write!(f, "无效的消息类型: {}", t), - Self::PayloadDeserialize(e) => write!(f, "Payload 反序列化失败: {}", e), - Self::Io(e) => write!(f, "IO 错误: {}", e), - Self::TooManyErrors { count, last_error } => { - write!( - f, - "连续错误过多 ({} 次),解码器已停止: {}", - count, last_error - ) - } - Self::BufferOverflow { size, max } => { - write!(f, "缓冲区溢出: {} 字节 (最大 {})", size, max) - } - } - } -} - -impl From for ParseError { - fn from(e: std::io::Error) -> Self { - Self::Io(e) - } -} - -impl From for ParseError { - fn from(e: serde_json::Error) -> Self { - Self::PayloadDeserialize(e) - } -} - -/// 解析结果类型 -pub type ParseResult = Result; diff --git a/src/kiro/parser/frame.rs b/src/kiro/parser/frame.rs deleted file mode 100644 index 0bf8db804fcf7ff86264d8949c73ad45254a93f7..0000000000000000000000000000000000000000 --- a/src/kiro/parser/frame.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! AWS Event Stream 消息帧解析 -//! -//! ## 消息格式 -//! -//! ```text -//! ┌──────────────┬──────────────┬──────────────┬──────────┬──────────┬───────────┐ -//! │ Total Length │ Header Length│ Prelude CRC │ Headers │ Payload │ Msg CRC │ -//! │ (4 bytes) │ (4 bytes) │ (4 bytes) │ (变长) │ (变长) │ (4 bytes) │ -//! └──────────────┴──────────────┴──────────────┴──────────┴──────────┴───────────┘ -//! ``` -//! -//! - Total Length: 整个消息的总长度(包括自身 4 字节) -//! - Header Length: 头部数据的长度 -//! - Prelude CRC: 前 8 字节(Total Length + Header Length)的 CRC32 校验 -//! - Headers: 头部数据 -//! - Payload: 载荷数据(通常是 JSON) -//! - Message CRC: 整个消息(不含 Message CRC 自身)的 CRC32 校验 - -use super::crc::crc32; -use super::error::{ParseError, ParseResult}; -use super::header::{Headers, parse_headers}; - -/// Prelude 固定大小 (12 字节) -pub const PRELUDE_SIZE: usize = 12; - -/// 最小消息大小 (Prelude + Message CRC) -pub const MIN_MESSAGE_SIZE: usize = PRELUDE_SIZE + 4; - -/// 最大消息大小限制 (16 MB) -pub const MAX_MESSAGE_SIZE: u32 = 16 * 1024 * 1024; - -/// 解析后的消息帧 -#[derive(Debug, Clone)] -pub struct Frame { - /// 消息头部 - pub headers: Headers, - /// 消息负载 - pub payload: Vec, -} - -impl Frame { - /// 获取消息类型 - pub fn message_type(&self) -> Option<&str> { - self.headers.message_type() - } - - /// 获取事件类型 - pub fn event_type(&self) -> Option<&str> { - self.headers.event_type() - } - - /// 将 payload 解析为 JSON - pub fn payload_as_json(&self) -> ParseResult { - serde_json::from_slice(&self.payload).map_err(ParseError::PayloadDeserialize) - } - - /// 将 payload 解析为字符串 - pub fn payload_as_str(&self) -> String { - String::from_utf8_lossy(&self.payload).to_string() - } -} - -/// 尝试从缓冲区解析一个完整的帧 -/// -/// 这是一个无状态的纯函数,每次调用独立解析。 -/// 缓冲区管理由上层 `EventStreamDecoder` 负责。 -/// -/// # Arguments -/// * `buffer` - 输入缓冲区 -/// -/// # Returns -/// - `Ok(Some((frame, consumed)))` - 成功解析,返回帧和消费的字节数 -/// - `Ok(None)` - 数据不足,需要更多数据 -/// - `Err(e)` - 解析错误 -pub fn parse_frame(buffer: &[u8]) -> ParseResult> { - // 检查是否有足够的数据读取 prelude - if buffer.len() < PRELUDE_SIZE { - return Ok(None); - } - - // 读取 prelude - let total_length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); - let header_length = u32::from_be_bytes([buffer[4], buffer[5], buffer[6], buffer[7]]); - let prelude_crc = u32::from_be_bytes([buffer[8], buffer[9], buffer[10], buffer[11]]); - - // 验证消息长度范围 - if total_length < MIN_MESSAGE_SIZE as u32 { - return Err(ParseError::MessageTooSmall { - length: total_length, - min: MIN_MESSAGE_SIZE as u32, - }); - } - - if total_length > MAX_MESSAGE_SIZE { - return Err(ParseError::MessageTooLarge { - length: total_length, - max: MAX_MESSAGE_SIZE, - }); - } - - let total_length = total_length as usize; - let header_length = header_length as usize; - - // 检查是否有完整的消息 - if buffer.len() < total_length { - return Ok(None); - } - - // 验证 Prelude CRC - let actual_prelude_crc = crc32(&buffer[..8]); - if actual_prelude_crc != prelude_crc { - return Err(ParseError::PreludeCrcMismatch { - expected: prelude_crc, - actual: actual_prelude_crc, - }); - } - - // 读取 Message CRC - let message_crc = u32::from_be_bytes([ - buffer[total_length - 4], - buffer[total_length - 3], - buffer[total_length - 2], - buffer[total_length - 1], - ]); - - // 验证 Message CRC (对整个消息不含最后4字节) - let actual_message_crc = crc32(&buffer[..total_length - 4]); - if actual_message_crc != message_crc { - return Err(ParseError::MessageCrcMismatch { - expected: message_crc, - actual: actual_message_crc, - }); - } - - // 解析头部 - let headers_start = PRELUDE_SIZE; - let headers_end = headers_start + header_length; - - // 验证头部边界 - if headers_end > total_length - 4 { - return Err(ParseError::HeaderParseFailed( - "头部长度超出消息边界".to_string(), - )); - } - - let headers = parse_headers(&buffer[headers_start..headers_end], header_length)?; - - // 提取 payload (去除最后4字节的 message_crc) - let payload_start = headers_end; - let payload_end = total_length - 4; - let payload = buffer[payload_start..payload_end].to_vec(); - - Ok(Some((Frame { headers, payload }, total_length))) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_frame_insufficient_data() { - let buffer = [0u8; 10]; // 小于 PRELUDE_SIZE - assert!(matches!(parse_frame(&buffer), Ok(None))); - } - - #[test] - fn test_frame_message_too_small() { - // 构造一个 total_length = 10 的 prelude (小于最小值) - let mut buffer = vec![0u8; 16]; - buffer[0..4].copy_from_slice(&10u32.to_be_bytes()); // total_length - buffer[4..8].copy_from_slice(&0u32.to_be_bytes()); // header_length - let prelude_crc = crc32(&buffer[0..8]); - buffer[8..12].copy_from_slice(&prelude_crc.to_be_bytes()); - - let result = parse_frame(&buffer); - assert!(matches!(result, Err(ParseError::MessageTooSmall { .. }))); - } -} diff --git a/src/kiro/parser/header.rs b/src/kiro/parser/header.rs deleted file mode 100644 index 449832cd580e1bb2756ab19191a981504e8547e5..0000000000000000000000000000000000000000 --- a/src/kiro/parser/header.rs +++ /dev/null @@ -1,317 +0,0 @@ -//! AWS Event Stream 头部解析 -//! -//! 实现 AWS Event Stream 协议的头部解析功能 - -use super::error::{ParseError, ParseResult}; -use std::collections::HashMap; - -/// 头部值类型标识 -/// -/// AWS Event Stream 协议定义的 10 种值类型 -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum HeaderValueType { - BoolTrue = 0, - BoolFalse = 1, - Byte = 2, - Short = 3, - Integer = 4, - Long = 5, - ByteArray = 6, - String = 7, - Timestamp = 8, - Uuid = 9, -} - -impl TryFrom for HeaderValueType { - type Error = ParseError; - - fn try_from(value: u8) -> ParseResult { - match value { - 0 => Ok(Self::BoolTrue), - 1 => Ok(Self::BoolFalse), - 2 => Ok(Self::Byte), - 3 => Ok(Self::Short), - 4 => Ok(Self::Integer), - 5 => Ok(Self::Long), - 6 => Ok(Self::ByteArray), - 7 => Ok(Self::String), - 8 => Ok(Self::Timestamp), - 9 => Ok(Self::Uuid), - _ => Err(ParseError::InvalidHeaderType(value)), - } - } -} - -/// 头部值 -/// -/// 支持 AWS Event Stream 协议定义的所有值类型 -#[derive(Debug, Clone, PartialEq)] -pub enum HeaderValue { - Bool(bool), - Byte(i8), - Short(i16), - Integer(i32), - Long(i64), - ByteArray(Vec), - String(String), - Timestamp(i64), - Uuid([u8; 16]), -} - -impl HeaderValue { - /// 尝试获取字符串值 - pub fn as_str(&self) -> Option<&str> { - match self { - Self::String(s) => Some(s), - _ => None, - } - } -} - -/// 消息头部集合 -#[derive(Debug, Clone, Default)] -pub struct Headers { - inner: HashMap, -} - -impl Headers { - /// 创建空的头部集合 - pub fn new() -> Self { - Self { - inner: HashMap::new(), - } - } - - /// 插入头部 - pub fn insert(&mut self, name: String, value: HeaderValue) { - self.inner.insert(name, value); - } - - /// 获取头部值 - pub fn get(&self, name: &str) -> Option<&HeaderValue> { - self.inner.get(name) - } - - /// 获取字符串类型的头部值 - pub fn get_string(&self, name: &str) -> Option<&str> { - self.get(name).and_then(|v| v.as_str()) - } - - /// 获取消息类型 (:message-type) - pub fn message_type(&self) -> Option<&str> { - self.get_string(":message-type") - } - - /// 获取事件类型 (:event-type) - pub fn event_type(&self) -> Option<&str> { - self.get_string(":event-type") - } - - /// 获取异常类型 (:exception-type) - pub fn exception_type(&self) -> Option<&str> { - self.get_string(":exception-type") - } - - /// 获取错误代码 (:error-code) - pub fn error_code(&self) -> Option<&str> { - self.get_string(":error-code") - } -} - -/// 从字节流解析头部 -/// -/// # Arguments -/// * `data` - 头部数据切片 -/// * `header_length` - 头部总长度 -/// -/// # Returns -/// 解析后的 Headers 结构 -pub fn parse_headers(data: &[u8], header_length: usize) -> ParseResult { - // 验证数据长度是否足够 - if data.len() < header_length { - return Err(ParseError::Incomplete { - needed: header_length, - available: data.len(), - }); - } - - let mut headers = Headers::new(); - let mut offset = 0; - - while offset < header_length { - // 读取头部名称长度 (1 byte) - if offset >= data.len() { - break; - } - let name_len = data[offset] as usize; - offset += 1; - - // 验证名称长度 - if name_len == 0 { - return Err(ParseError::HeaderParseFailed( - "头部名称长度不能为 0".to_string(), - )); - } - - // 读取头部名称 - if offset + name_len > data.len() { - return Err(ParseError::Incomplete { - needed: name_len, - available: data.len() - offset, - }); - } - let name = String::from_utf8_lossy(&data[offset..offset + name_len]).to_string(); - offset += name_len; - - // 读取值类型 (1 byte) - if offset >= data.len() { - return Err(ParseError::Incomplete { - needed: 1, - available: 0, - }); - } - let value_type = HeaderValueType::try_from(data[offset])?; - offset += 1; - - // 根据类型解析值 - let value = parse_header_value(&data[offset..], value_type, &mut offset)?; - headers.insert(name, value); - } - - Ok(headers) -} - -/// 解析头部值 -fn parse_header_value( - data: &[u8], - value_type: HeaderValueType, - global_offset: &mut usize, -) -> ParseResult { - let mut local_offset = 0; - - let result = match value_type { - HeaderValueType::BoolTrue => Ok(HeaderValue::Bool(true)), - HeaderValueType::BoolFalse => Ok(HeaderValue::Bool(false)), - HeaderValueType::Byte => { - ensure_bytes(data, 1)?; - let v = data[0] as i8; - local_offset = 1; - Ok(HeaderValue::Byte(v)) - } - HeaderValueType::Short => { - ensure_bytes(data, 2)?; - let v = i16::from_be_bytes([data[0], data[1]]); - local_offset = 2; - Ok(HeaderValue::Short(v)) - } - HeaderValueType::Integer => { - ensure_bytes(data, 4)?; - let v = i32::from_be_bytes([data[0], data[1], data[2], data[3]]); - local_offset = 4; - Ok(HeaderValue::Integer(v)) - } - HeaderValueType::Long => { - ensure_bytes(data, 8)?; - let v = i64::from_be_bytes([ - data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], - ]); - local_offset = 8; - Ok(HeaderValue::Long(v)) - } - HeaderValueType::Timestamp => { - ensure_bytes(data, 8)?; - let v = i64::from_be_bytes([ - data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], - ]); - local_offset = 8; - Ok(HeaderValue::Timestamp(v)) - } - HeaderValueType::ByteArray => { - ensure_bytes(data, 2)?; - let len = u16::from_be_bytes([data[0], data[1]]) as usize; - ensure_bytes(data, 2 + len)?; - let v = data[2..2 + len].to_vec(); - local_offset = 2 + len; - Ok(HeaderValue::ByteArray(v)) - } - HeaderValueType::String => { - ensure_bytes(data, 2)?; - let len = u16::from_be_bytes([data[0], data[1]]) as usize; - ensure_bytes(data, 2 + len)?; - let v = String::from_utf8_lossy(&data[2..2 + len]).to_string(); - local_offset = 2 + len; - Ok(HeaderValue::String(v)) - } - HeaderValueType::Uuid => { - ensure_bytes(data, 16)?; - let mut uuid = [0u8; 16]; - uuid.copy_from_slice(&data[..16]); - local_offset = 16; - Ok(HeaderValue::Uuid(uuid)) - } - }; - - *global_offset += local_offset; - result -} - -/// 确保有足够的字节 -fn ensure_bytes(data: &[u8], needed: usize) -> ParseResult<()> { - if data.len() < needed { - Err(ParseError::Incomplete { - needed, - available: data.len(), - }) - } else { - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_header_value_type_conversion() { - assert_eq!( - HeaderValueType::try_from(0).unwrap(), - HeaderValueType::BoolTrue - ); - assert_eq!( - HeaderValueType::try_from(7).unwrap(), - HeaderValueType::String - ); - assert!(HeaderValueType::try_from(10).is_err()); - } - - #[test] - fn test_header_value_as_str() { - let value = HeaderValue::String("test".to_string()); - assert_eq!(value.as_str(), Some("test")); - - let value = HeaderValue::Bool(true); - assert_eq!(value.as_str(), None); - } - - #[test] - fn test_headers_get_string() { - let mut headers = Headers::new(); - headers.insert( - ":message-type".to_string(), - HeaderValue::String("event".to_string()), - ); - assert_eq!(headers.message_type(), Some("event")); - } - - #[test] - fn test_parse_headers_string() { - // 构造一个简单的头部: name_len(1) + name + type(7=string) + value_len(2) + value - // 头部名: "x" (长度 1) - // 值类型: 7 (String) - // 值: "ab" (长度 2) - let data = [1u8, b'x', 7, 0, 2, b'a', b'b']; - let headers = parse_headers(&data, data.len()).unwrap(); - assert_eq!(headers.get_string("x"), Some("ab")); - } -} diff --git a/src/kiro/parser/mod.rs b/src/kiro/parser/mod.rs deleted file mode 100644 index d7ab70ee94b35c10efbb8cb08b52026d5fc66134..0000000000000000000000000000000000000000 --- a/src/kiro/parser/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! AWS Event Stream 解析器 -//! -//! 提供对 AWS Event Stream 协议的解析支持, -//! 用于处理 generateAssistantResponse 端点的流式响应 - -pub mod crc; -pub mod decoder; -pub mod error; -pub mod frame; -pub mod header; diff --git a/src/kiro/provider.rs b/src/kiro/provider.rs deleted file mode 100644 index 13f78ffbe1aadee0a028be13950abdc970e6819a..0000000000000000000000000000000000000000 --- a/src/kiro/provider.rs +++ /dev/null @@ -1,639 +0,0 @@ -//! Kiro API Provider -//! -//! 核心组件,负责与 Kiro API 通信 -//! 支持流式和非流式请求 -//! 支持多凭据故障转移和重试 - -use reqwest::Client; -use reqwest::header::{AUTHORIZATION, CONNECTION, CONTENT_TYPE, HOST, HeaderMap, HeaderValue}; -use std::sync::Arc; -use std::time::Duration; -use tokio::time::sleep; -use uuid::Uuid; - -use crate::http_client::{ProxyConfig, build_client}; -use crate::kiro::machine_id; -use crate::kiro::token_manager::{CallContext, MultiTokenManager}; - -#[cfg(test)] -use crate::kiro::model::credentials::KiroCredentials; - -/// 每个凭据的最大重试次数 -const MAX_RETRIES_PER_CREDENTIAL: usize = 3; - -/// 总重试次数硬上限(避免无限重试) -const MAX_TOTAL_RETRIES: usize = 9; - -/// Kiro API Provider -/// -/// 核心组件,负责与 Kiro API 通信 -/// 支持多凭据故障转移和重试机制 -pub struct KiroProvider { - token_manager: Arc, - client: Client, -} - -impl KiroProvider { - /// 创建新的 KiroProvider 实例 - pub fn new(token_manager: Arc) -> Self { - Self::with_proxy(token_manager, None) - } - - /// 创建带代理配置的 KiroProvider 实例 - pub fn with_proxy(token_manager: Arc, proxy: Option) -> Self { - let client = build_client(proxy.as_ref(), 720) // 12 分钟超时 - .expect("创建 HTTP 客户端失败"); - - Self { - token_manager, - client, - } - } - - /// 获取 token_manager 的引用 - pub fn token_manager(&self) -> &MultiTokenManager { - &self.token_manager - } - - /// 获取 API 基础 URL - pub fn base_url(&self) -> String { - format!( - "https://q.{}.amazonaws.com/generateAssistantResponse", - self.token_manager.config().region - ) - } - - /// 获取 MCP API URL - pub fn mcp_url(&self) -> String { - format!( - "https://q.{}.amazonaws.com/mcp", - self.token_manager.config().region - ) - } - - /// 获取 API 基础域名 - pub fn base_domain(&self) -> String { - format!("q.{}.amazonaws.com", self.token_manager.config().region) - } - - /// 构建请求头 - /// - /// # Arguments - /// * `ctx` - API 调用上下文,包含凭据和 token - fn build_headers(&self, ctx: &CallContext) -> anyhow::Result { - let config = self.token_manager.config(); - - let machine_id = machine_id::generate_from_credentials(&ctx.credentials, config) - .ok_or_else(|| anyhow::anyhow!("无法生成 machine_id,请检查凭证配置"))?; - - let kiro_version = &config.kiro_version; - let os_name = &config.system_version; - let node_version = &config.node_version; - - let x_amz_user_agent = format!("aws-sdk-js/1.0.27 KiroIDE-{}-{}", kiro_version, machine_id); - - let user_agent = format!( - "aws-sdk-js/1.0.27 ua/2.1 os/{} lang/js md/nodejs#{} api/codewhispererstreaming#1.0.27 m/E KiroIDE-{}-{}", - os_name, node_version, kiro_version, machine_id - ); - - let mut headers = HeaderMap::new(); - - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - headers.insert( - "x-amzn-codewhisperer-optout", - HeaderValue::from_static("true"), - ); - headers.insert("x-amzn-kiro-agent-mode", HeaderValue::from_static("vibe")); - headers.insert( - "x-amz-user-agent", - HeaderValue::from_str(&x_amz_user_agent).unwrap(), - ); - headers.insert( - reqwest::header::USER_AGENT, - HeaderValue::from_str(&user_agent).unwrap(), - ); - headers.insert(HOST, HeaderValue::from_str(&self.base_domain()).unwrap()); - headers.insert( - "amz-sdk-invocation-id", - HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(), - ); - headers.insert( - "amz-sdk-request", - HeaderValue::from_static("attempt=1; max=3"), - ); - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {}", ctx.token)).unwrap(), - ); - headers.insert(CONNECTION, HeaderValue::from_static("close")); - - Ok(headers) - } - - /// 构建 MCP 请求头 - fn build_mcp_headers(&self, ctx: &CallContext) -> anyhow::Result { - let config = self.token_manager.config(); - - let machine_id = machine_id::generate_from_credentials(&ctx.credentials, config) - .ok_or_else(|| anyhow::anyhow!("无法生成 machine_id,请检查凭证配置"))?; - - let kiro_version = &config.kiro_version; - let os_name = &config.system_version; - let node_version = &config.node_version; - - let x_amz_user_agent = format!("aws-sdk-js/1.0.27 KiroIDE-{}-{}", kiro_version, machine_id); - - let user_agent = format!( - "aws-sdk-js/1.0.27 ua/2.1 os/{} lang/js md/nodejs#{} api/codewhispererstreaming#1.0.27 m/E KiroIDE-{}-{}", - os_name, node_version, kiro_version, machine_id - ); - - let mut headers = HeaderMap::new(); - - // 按照严格顺序添加请求头 - headers.insert( - "content-type", - HeaderValue::from_static("application/json"), - ); - headers.insert( - "x-amz-user-agent", - HeaderValue::from_str(&x_amz_user_agent).unwrap(), - ); - headers.insert( - "user-agent", - HeaderValue::from_str(&user_agent).unwrap(), - ); - headers.insert( - "host", - HeaderValue::from_str(&self.base_domain()).unwrap(), - ); - headers.insert( - "amz-sdk-invocation-id", - HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(), - ); - headers.insert( - "amz-sdk-request", - HeaderValue::from_static("attempt=1; max=3"), - ); - headers.insert( - "Authorization", - HeaderValue::from_str(&format!("Bearer {}", ctx.token)).unwrap(), - ); - headers.insert("Connection", HeaderValue::from_static("close")); - - Ok(headers) - } - - /// 发送非流式 API 请求 - /// - /// 支持多凭据故障转移: - /// - 400 Bad Request: 直接返回错误,不计入凭据失败 - /// - 401/403: 视为凭据/权限问题,计入失败次数并允许故障转移 - /// - 402 MONTHLY_REQUEST_COUNT: 视为额度用尽,禁用凭据并切换 - /// - 429/5xx/网络等瞬态错误: 重试但不禁用或切换凭据(避免误把所有凭据锁死) - /// - /// # Arguments - /// * `request_body` - JSON 格式的请求体字符串 - /// - /// # Returns - /// 返回原始的 HTTP Response,不做解析 - pub async fn call_api(&self, request_body: &str) -> anyhow::Result { - self.call_api_with_retry(request_body, false).await - } - - /// 发送流式 API 请求 - /// - /// 支持多凭据故障转移: - /// - 400 Bad Request: 直接返回错误,不计入凭据失败 - /// - 401/403: 视为凭据/权限问题,计入失败次数并允许故障转移 - /// - 402 MONTHLY_REQUEST_COUNT: 视为额度用尽,禁用凭据并切换 - /// - 429/5xx/网络等瞬态错误: 重试但不禁用或切换凭据(避免误把所有凭据锁死) - /// - /// # Arguments - /// * `request_body` - JSON 格式的请求体字符串 - /// - /// # Returns - /// 返回原始的 HTTP Response,调用方负责处理流式数据 - pub async fn call_api_stream(&self, request_body: &str) -> anyhow::Result { - self.call_api_with_retry(request_body, true).await - } - - /// 发送 MCP API 请求 - /// - /// 用于 WebSearch 等工具调用 - /// - /// # Arguments - /// * `request_body` - JSON 格式的 MCP 请求体字符串 - /// - /// # Returns - /// 返回原始的 HTTP Response - pub async fn call_mcp(&self, request_body: &str) -> anyhow::Result { - self.call_mcp_with_retry(request_body).await - } - - /// 内部方法:带重试逻辑的 MCP API 调用 - async fn call_mcp_with_retry(&self, request_body: &str) -> anyhow::Result { - let total_credentials = self.token_manager.total_count(); - let max_retries = (total_credentials * MAX_RETRIES_PER_CREDENTIAL).min(MAX_TOTAL_RETRIES); - let mut last_error: Option = None; - - for attempt in 0..max_retries { - // 获取调用上下文 - let ctx = match self.token_manager.acquire_context().await { - Ok(c) => c, - Err(e) => { - last_error = Some(e); - continue; - } - }; - - let url = self.mcp_url(); - let headers = match self.build_mcp_headers(&ctx) { - Ok(h) => h, - Err(e) => { - last_error = Some(e); - continue; - } - }; - - // 发送请求 - let response = match self - .client - .post(&url) - .headers(headers) - .body(request_body.to_string()) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - tracing::warn!( - "MCP 请求发送失败(尝试 {}/{}): {}", - attempt + 1, - max_retries, - e - ); - last_error = Some(e.into()); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - continue; - } - }; - - let status = response.status(); - - // 成功响应 - if status.is_success() { - self.token_manager.report_success(ctx.id); - return Ok(response); - } - - // 失败响应 - let body = response.text().await.unwrap_or_default(); - - // 402 额度用尽 - if status.as_u16() == 402 && Self::is_monthly_request_limit(&body) { - let has_available = self.token_manager.report_quota_exhausted(ctx.id); - if !has_available { - anyhow::bail!("MCP 请求失败(所有凭据已用尽): {} {}", status, body); - } - last_error = Some(anyhow::anyhow!("MCP 请求失败: {} {}", status, body)); - continue; - } - - // 400 Bad Request - if status.as_u16() == 400 { - anyhow::bail!("MCP 请求失败: {} {}", status, body); - } - - // 401/403 凭据问题 - if matches!(status.as_u16(), 401 | 403) { - let has_available = self.token_manager.report_failure(ctx.id); - if !has_available { - anyhow::bail!("MCP 请求失败(所有凭据已用尽): {} {}", status, body); - } - last_error = Some(anyhow::anyhow!("MCP 请求失败: {} {}", status, body)); - continue; - } - - // 瞬态错误 - if matches!(status.as_u16(), 408 | 429) || status.is_server_error() { - tracing::warn!( - "MCP 请求失败(上游瞬态错误,尝试 {}/{}): {} {}", - attempt + 1, - max_retries, - status, - body - ); - last_error = Some(anyhow::anyhow!("MCP 请求失败: {} {}", status, body)); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - continue; - } - - // 其他 4xx - if status.is_client_error() { - anyhow::bail!("MCP 请求失败: {} {}", status, body); - } - - // 兜底 - last_error = Some(anyhow::anyhow!("MCP 请求失败: {} {}", status, body)); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - } - - Err(last_error.unwrap_or_else(|| { - anyhow::anyhow!("MCP 请求失败:已达到最大重试次数({}次)", max_retries) - })) - } - - /// 内部方法:带重试逻辑的 API 调用 - /// - /// 重试策略: - /// - 每个凭据最多重试 MAX_RETRIES_PER_CREDENTIAL 次 - /// - 总重试次数 = min(凭据数量 × 每凭据重试次数, MAX_TOTAL_RETRIES) - /// - 硬上限 9 次,避免无限重试 - async fn call_api_with_retry( - &self, - request_body: &str, - is_stream: bool, - ) -> anyhow::Result { - let total_credentials = self.token_manager.total_count(); - let max_retries = (total_credentials * MAX_RETRIES_PER_CREDENTIAL).min(MAX_TOTAL_RETRIES); - let mut last_error: Option = None; - let api_type = if is_stream { "流式" } else { "非流式" }; - - for attempt in 0..max_retries { - // 获取调用上下文(绑定 index、credentials、token) - let ctx = match self.token_manager.acquire_context().await { - Ok(c) => c, - Err(e) => { - last_error = Some(e); - continue; - } - }; - - let url = self.base_url(); - let headers = match self.build_headers(&ctx) { - Ok(h) => h, - Err(e) => { - last_error = Some(e); - continue; - } - }; - - // 发送请求 - let response = match self - .client - .post(&url) - .headers(headers) - .body(request_body.to_string()) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - tracing::warn!( - "API 请求发送失败(尝试 {}/{}): {}", - attempt + 1, - max_retries, - e - ); - // 网络错误通常是上游/链路瞬态问题,不应导致"禁用凭据"或"切换凭据" - // (否则一段时间网络抖动会把所有凭据都误禁用,需要重启才能恢复) - last_error = Some(e.into()); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - continue; - } - }; - - let status = response.status(); - - // 成功响应 - if status.is_success() { - self.token_manager.report_success(ctx.id); - return Ok(response); - } - - // 失败响应:读取 body 用于日志/错误信息 - let body = response.text().await.unwrap_or_default(); - - // 402 Payment Required 且额度用尽:禁用凭据并故障转移 - if status.as_u16() == 402 && Self::is_monthly_request_limit(&body) { - tracing::warn!( - "API 请求失败(额度已用尽,禁用凭据并切换,尝试 {}/{}): {} {}", - attempt + 1, - max_retries, - status, - body - ); - - let has_available = self.token_manager.report_quota_exhausted(ctx.id); - if !has_available { - anyhow::bail!( - "{} API 请求失败(所有凭据已用尽): {} {}", - api_type, - status, - body - ); - } - - last_error = Some(anyhow::anyhow!("{} API 请求失败: {} {}", api_type, status, body)); - continue; - } - - // 400 Bad Request - 请求问题,重试/切换凭据无意义 - if status.as_u16() == 400 { - anyhow::bail!("{} API 请求失败: {} {}", api_type, status, body); - } - - // 401/403 - 更可能是凭据/权限问题:计入失败并允许故障转移 - if matches!(status.as_u16(), 401 | 403) { - tracing::warn!( - "API 请求失败(可能为凭据错误,尝试 {}/{}): {} {}", - attempt + 1, - max_retries, - status, - body - ); - - let has_available = self.token_manager.report_failure(ctx.id); - if !has_available { - anyhow::bail!( - "{} API 请求失败(所有凭据已用尽): {} {}", - api_type, - status, - body - ); - } - - last_error = Some(anyhow::anyhow!("{} API 请求失败: {} {}", api_type, status, body)); - continue; - } - - // 429/408/5xx - 瞬态上游错误:重试但不禁用或切换凭据 - // (避免 429 high traffic / 502 high load 等瞬态错误把所有凭据锁死) - if matches!(status.as_u16(), 408 | 429) || status.is_server_error() { - tracing::warn!( - "API 请求失败(上游瞬态错误,尝试 {}/{}): {} {}", - attempt + 1, - max_retries, - status, - body - ); - last_error = Some(anyhow::anyhow!("{} API 请求失败: {} {}", api_type, status, body)); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - continue; - } - - // 其他 4xx - 通常为请求/配置问题:直接返回,不计入凭据失败 - if status.is_client_error() { - anyhow::bail!("{} API 请求失败: {} {}", api_type, status, body); - } - - // 兜底:当作可重试的瞬态错误处理(不切换凭据) - tracing::warn!( - "API 请求失败(未知错误,尝试 {}/{}): {} {}", - attempt + 1, - max_retries, - status, - body - ); - last_error = Some(anyhow::anyhow!("{} API 请求失败: {} {}", api_type, status, body)); - if attempt + 1 < max_retries { - sleep(Self::retry_delay(attempt)).await; - } - } - - // 所有重试都失败 - Err(last_error.unwrap_or_else(|| { - anyhow::anyhow!( - "{} API 请求失败:已达到最大重试次数({}次)", - api_type, - max_retries - ) - })) - } - - fn retry_delay(attempt: usize) -> Duration { - // 指数退避 + 少量抖动,避免上游抖动时放大故障 - const BASE_MS: u64 = 200; - const MAX_MS: u64 = 2_000; - let exp = BASE_MS.saturating_mul(2u64.saturating_pow(attempt.min(6) as u32)); - let backoff = exp.min(MAX_MS); - let jitter_max = (backoff / 4).max(1); - let jitter = fastrand::u64(0..=jitter_max); - Duration::from_millis(backoff.saturating_add(jitter)) - } - - fn is_monthly_request_limit(body: &str) -> bool { - if body.contains("MONTHLY_REQUEST_COUNT") { - return true; - } - - let Ok(value) = serde_json::from_str::(body) else { - return false; - }; - - if value - .get("reason") - .and_then(|v| v.as_str()) - .is_some_and(|v| v == "MONTHLY_REQUEST_COUNT") - { - return true; - } - - value - .pointer("/error/reason") - .and_then(|v| v.as_str()) - .is_some_and(|v| v == "MONTHLY_REQUEST_COUNT") - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::kiro::token_manager::CallContext; - use crate::model::config::Config; - - fn create_test_provider(config: Config, credentials: KiroCredentials) -> KiroProvider { - let tm = MultiTokenManager::new(config, vec![credentials], None, None, false).unwrap(); - KiroProvider::new(Arc::new(tm)) - } - - #[test] - fn test_base_url() { - let config = Config::default(); - let credentials = KiroCredentials::default(); - let provider = create_test_provider(config, credentials); - assert!(provider.base_url().contains("amazonaws.com")); - assert!(provider.base_url().contains("generateAssistantResponse")); - } - - #[test] - fn test_base_domain() { - let mut config = Config::default(); - config.region = "us-east-1".to_string(); - let credentials = KiroCredentials::default(); - let provider = create_test_provider(config, credentials); - assert_eq!(provider.base_domain(), "q.us-east-1.amazonaws.com"); - } - - #[test] - fn test_build_headers() { - let mut config = Config::default(); - config.region = "us-east-1".to_string(); - config.kiro_version = "0.8.0".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.profile_arn = Some("arn:aws:sso::123456789:profile/test".to_string()); - credentials.refresh_token = Some("a".repeat(150)); - - let provider = create_test_provider(config, credentials.clone()); - let ctx = CallContext { - id: 1, - credentials, - token: "test_token".to_string(), - }; - let headers = provider.build_headers(&ctx).unwrap(); - - assert_eq!(headers.get(CONTENT_TYPE).unwrap(), "application/json"); - assert_eq!(headers.get("x-amzn-codewhisperer-optout").unwrap(), "true"); - assert_eq!(headers.get("x-amzn-kiro-agent-mode").unwrap(), "vibe"); - assert!( - headers - .get(AUTHORIZATION) - .unwrap() - .to_str() - .unwrap() - .starts_with("Bearer ") - ); - assert_eq!(headers.get(CONNECTION).unwrap(), "close"); - } - - #[test] - fn test_is_monthly_request_limit_detects_reason() { - let body = r#"{"message":"You have reached the limit.","reason":"MONTHLY_REQUEST_COUNT"}"#; - assert!(KiroProvider::is_monthly_request_limit(body)); - } - - #[test] - fn test_is_monthly_request_limit_nested_reason() { - let body = r#"{"error":{"reason":"MONTHLY_REQUEST_COUNT"}}"#; - assert!(KiroProvider::is_monthly_request_limit(body)); - } - - #[test] - fn test_is_monthly_request_limit_false() { - let body = r#"{"message":"nope","reason":"DAILY_REQUEST_COUNT"}"#; - assert!(!KiroProvider::is_monthly_request_limit(body)); - } -} diff --git a/src/kiro/token_manager.rs b/src/kiro/token_manager.rs deleted file mode 100644 index 2d793daef96bc8018ba688600caebcce49331bba..0000000000000000000000000000000000000000 --- a/src/kiro/token_manager.rs +++ /dev/null @@ -1,1647 +0,0 @@ -//! Token 管理模块 -//! -//! 负责 Token 过期检测和刷新,支持 Social 和 IdC 认证方式 -//! 支持单凭据 (TokenManager) 和多凭据 (MultiTokenManager) 管理 - -use anyhow::bail; -use chrono::{DateTime, Duration, Utc}; -use parking_lot::Mutex; -use serde::Serialize; -use tokio::sync::Mutex as TokioMutex; - -use std::path::PathBuf; - -use crate::http_client::{ProxyConfig, build_client}; -use crate::kiro::machine_id; -use crate::kiro::model::credentials::KiroCredentials; -use crate::kiro::model::token_refresh::{ - IdcRefreshRequest, IdcRefreshResponse, RefreshRequest, RefreshResponse, -}; -use crate::kiro::model::usage_limits::UsageLimitsResponse; -use crate::model::config::Config; - -/// Token 管理器 -/// -/// 负责管理凭据和 Token 的自动刷新 -pub struct TokenManager { - config: Config, - credentials: KiroCredentials, - proxy: Option, -} - -impl TokenManager { - /// 创建新的 TokenManager 实例 - pub fn new(config: Config, credentials: KiroCredentials, proxy: Option) -> Self { - Self { - config, - credentials, - proxy, - } - } - - /// 获取凭据的引用 - pub fn credentials(&self) -> &KiroCredentials { - &self.credentials - } - - /// 获取配置的引用 - pub fn config(&self) -> &Config { - &self.config - } - - /// 确保获取有效的访问 Token - /// - /// 如果 Token 过期或即将过期,会自动刷新 - pub async fn ensure_valid_token(&mut self) -> anyhow::Result { - if is_token_expired(&self.credentials) || is_token_expiring_soon(&self.credentials) { - self.credentials = - refresh_token(&self.credentials, &self.config, self.proxy.as_ref()).await?; - - // 刷新后再次检查 token 时间有效性 - if is_token_expired(&self.credentials) { - anyhow::bail!("刷新后的 Token 仍然无效或已过期"); - } - } - - self.credentials - .access_token - .clone() - .ok_or_else(|| anyhow::anyhow!("没有可用的 accessToken")) - } - - /// 获取使用额度信息 - /// - /// 调用 getUsageLimits API 查询当前账户的使用额度 - pub async fn get_usage_limits(&mut self) -> anyhow::Result { - let token = self.ensure_valid_token().await?; - get_usage_limits(&self.credentials, &self.config, &token, self.proxy.as_ref()).await - } -} - -/// 检查 Token 是否在指定时间内过期 -pub(crate) fn is_token_expiring_within( - credentials: &KiroCredentials, - minutes: i64, -) -> Option { - credentials - .expires_at - .as_ref() - .and_then(|expires_at| DateTime::parse_from_rfc3339(expires_at).ok()) - .map(|expires| expires <= Utc::now() + Duration::minutes(minutes)) -} - -/// 检查 Token 是否已过期(提前 5 分钟判断) -pub(crate) fn is_token_expired(credentials: &KiroCredentials) -> bool { - is_token_expiring_within(credentials, 5).unwrap_or(true) -} - -/// 检查 Token 是否即将过期(10分钟内) -pub(crate) fn is_token_expiring_soon(credentials: &KiroCredentials) -> bool { - is_token_expiring_within(credentials, 10).unwrap_or(false) -} - -/// 验证 refreshToken 的基本有效性 -pub(crate) fn validate_refresh_token(credentials: &KiroCredentials) -> anyhow::Result<()> { - let refresh_token = credentials - .refresh_token - .as_ref() - .ok_or_else(|| anyhow::anyhow!("缺少 refreshToken"))?; - - if refresh_token.is_empty() { - bail!("refreshToken 为空"); - } - - if refresh_token.len() < 100 || refresh_token.ends_with("...") || refresh_token.contains("...") - { - bail!( - "refreshToken 已被截断(长度: {} 字符)。\n\ - 这通常是 Kiro IDE 为了防止凭证被第三方工具使用而故意截断的。", - refresh_token.len() - ); - } - - Ok(()) -} - -/// 刷新 Token -pub(crate) async fn refresh_token( - credentials: &KiroCredentials, - config: &Config, - proxy: Option<&ProxyConfig>, -) -> anyhow::Result { - validate_refresh_token(credentials)?; - - // 根据 auth_method 选择刷新方式 - // 如果未指定 auth_method,根据是否有 clientId/clientSecret 自动判断 - let auth_method = credentials.auth_method.as_deref().unwrap_or_else(|| { - if credentials.client_id.is_some() && credentials.client_secret.is_some() { - "idc" - } else { - "social" - } - }); - - match auth_method.to_lowercase().as_str() { - "idc" | "builder-id" => refresh_idc_token(credentials, config, proxy).await, - _ => refresh_social_token(credentials, config, proxy).await, - } -} - -/// 刷新 Social Token -async fn refresh_social_token( - credentials: &KiroCredentials, - config: &Config, - proxy: Option<&ProxyConfig>, -) -> anyhow::Result { - tracing::info!("正在刷新 Social Token..."); - - let refresh_token = credentials.refresh_token.as_ref().unwrap(); - // 优先使用凭据级 region,未配置时回退到 config.region - let region = credentials.region.as_ref().unwrap_or(&config.region); - - let refresh_url = format!("https://prod.{}.auth.desktop.kiro.dev/refreshToken", region); - let refresh_domain = format!("prod.{}.auth.desktop.kiro.dev", region); - let machine_id = machine_id::generate_from_credentials(credentials, config) - .ok_or_else(|| anyhow::anyhow!("无法生成 machineId"))?; - let kiro_version = &config.kiro_version; - - let client = build_client(proxy, 60)?; - let body = RefreshRequest { - refresh_token: refresh_token.to_string(), - }; - - let response = client - .post(&refresh_url) - .header("Accept", "application/json, text/plain, */*") - .header("Content-Type", "application/json") - .header( - "User-Agent", - format!("KiroIDE-{}-{}", kiro_version, machine_id), - ) - .header("Accept-Encoding", "gzip, compress, deflate, br") - .header("host", &refresh_domain) - .header("Connection", "close") - .json(&body) - .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let body_text = response.text().await.unwrap_or_default(); - let error_msg = match status.as_u16() { - 401 => "OAuth 凭证已过期或无效,需要重新认证", - 403 => "权限不足,无法刷新 Token", - 429 => "请求过于频繁,已被限流", - 500..=599 => "服务器错误,AWS OAuth 服务暂时不可用", - _ => "Token 刷新失败", - }; - bail!("{}: {} {}", error_msg, status, body_text); - } - - let data: RefreshResponse = response.json().await?; - - let mut new_credentials = credentials.clone(); - new_credentials.access_token = Some(data.access_token); - - if let Some(new_refresh_token) = data.refresh_token { - new_credentials.refresh_token = Some(new_refresh_token); - } - - if let Some(profile_arn) = data.profile_arn { - new_credentials.profile_arn = Some(profile_arn); - } - - if let Some(expires_in) = data.expires_in { - let expires_at = Utc::now() + Duration::seconds(expires_in); - new_credentials.expires_at = Some(expires_at.to_rfc3339()); - } - - Ok(new_credentials) -} - -/// IdC Token 刷新所需的 x-amz-user-agent header -const IDC_AMZ_USER_AGENT: &str = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"; - -/// 刷新 IdC Token (AWS SSO OIDC) -async fn refresh_idc_token( - credentials: &KiroCredentials, - config: &Config, - proxy: Option<&ProxyConfig>, -) -> anyhow::Result { - tracing::info!("正在刷新 IdC Token..."); - - let refresh_token = credentials.refresh_token.as_ref().unwrap(); - let client_id = credentials - .client_id - .as_ref() - .ok_or_else(|| anyhow::anyhow!("IdC 刷新需要 clientId"))?; - let client_secret = credentials - .client_secret - .as_ref() - .ok_or_else(|| anyhow::anyhow!("IdC 刷新需要 clientSecret"))?; - - // 优先使用凭据级 region,未配置时回退到 config.region - let region = credentials.region.as_ref().unwrap_or(&config.region); - let refresh_url = format!("https://oidc.{}.amazonaws.com/token", region); - - let client = build_client(proxy, 60)?; - let body = IdcRefreshRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - refresh_token: refresh_token.to_string(), - grant_type: "refresh_token".to_string(), - }; - - let response = client - .post(&refresh_url) - .header("Content-Type", "application/json") - .header("Host", format!("oidc.{}.amazonaws.com", region)) - .header("Connection", "keep-alive") - .header("x-amz-user-agent", IDC_AMZ_USER_AGENT) - .header("Accept", "*/*") - .header("Accept-Language", "*") - .header("sec-fetch-mode", "cors") - .header("User-Agent", "node") - .header("Accept-Encoding", "br, gzip, deflate") - .json(&body) - .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let body_text = response.text().await.unwrap_or_default(); - let error_msg = match status.as_u16() { - 401 => "IdC 凭证已过期或无效,需要重新认证", - 403 => "权限不足,无法刷新 Token", - 429 => "请求过于频繁,已被限流", - 500..=599 => "服务器错误,AWS OIDC 服务暂时不可用", - _ => "IdC Token 刷新失败", - }; - bail!("{}: {} {}", error_msg, status, body_text); - } - - let data: IdcRefreshResponse = response.json().await?; - - let mut new_credentials = credentials.clone(); - new_credentials.access_token = Some(data.access_token); - - if let Some(new_refresh_token) = data.refresh_token { - new_credentials.refresh_token = Some(new_refresh_token); - } - - if let Some(expires_in) = data.expires_in { - let expires_at = Utc::now() + Duration::seconds(expires_in); - new_credentials.expires_at = Some(expires_at.to_rfc3339()); - } - - Ok(new_credentials) -} - -/// getUsageLimits API 所需的 x-amz-user-agent header 前缀 -const USAGE_LIMITS_AMZ_USER_AGENT_PREFIX: &str = "aws-sdk-js/1.0.0"; - -/// 获取使用额度信息 -pub(crate) async fn get_usage_limits( - credentials: &KiroCredentials, - config: &Config, - token: &str, - proxy: Option<&ProxyConfig>, -) -> anyhow::Result { - tracing::debug!("正在获取使用额度信息..."); - - let region = &config.region; - let host = format!("q.{}.amazonaws.com", region); - let machine_id = machine_id::generate_from_credentials(credentials, config) - .ok_or_else(|| anyhow::anyhow!("无法生成 machineId"))?; - let kiro_version = &config.kiro_version; - - // 构建 URL - let mut url = format!( - "https://{}/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", - host - ); - - // profileArn 是可选的 - if let Some(profile_arn) = &credentials.profile_arn { - url.push_str(&format!("&profileArn={}", urlencoding::encode(profile_arn))); - } - - // 构建 User-Agent headers - let user_agent = format!( - "aws-sdk-js/1.0.0 ua/2.1 os/darwin#24.6.0 lang/js md/nodejs#22.21.1 \ - api/codewhispererruntime#1.0.0 m/N,E KiroIDE-{}-{}", - kiro_version, machine_id - ); - let amz_user_agent = format!( - "{} KiroIDE-{}-{}", - USAGE_LIMITS_AMZ_USER_AGENT_PREFIX, kiro_version, machine_id - ); - - let client = build_client(proxy, 60)?; - - let response = client - .get(&url) - .header("x-amz-user-agent", &amz_user_agent) - .header("User-Agent", &user_agent) - .header("host", &host) - .header("amz-sdk-invocation-id", uuid::Uuid::new_v4().to_string()) - .header("amz-sdk-request", "attempt=1; max=1") - .header("Authorization", format!("Bearer {}", token)) - .header("Connection", "close") - .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let body_text = response.text().await.unwrap_or_default(); - let error_msg = match status.as_u16() { - 401 => "认证失败,Token 无效或已过期", - 403 => "权限不足,无法获取使用额度", - 429 => "请求过于频繁,已被限流", - 500..=599 => "服务器错误,AWS 服务暂时不可用", - _ => "获取使用额度失败", - }; - bail!("{}: {} {}", error_msg, status, body_text); - } - - let data: UsageLimitsResponse = response.json().await?; - Ok(data) -} - -// ============================================================================ -// 多凭据 Token 管理器 -// ============================================================================ - -/// 单个凭据条目的状态 -struct CredentialEntry { - /// 凭据唯一 ID - id: u64, - /// 凭据信息 - credentials: KiroCredentials, - /// API 调用连续失败次数 - failure_count: u32, - /// 是否已禁用 - disabled: bool, - /// 禁用原因(用于区分手动禁用 vs 自动禁用,便于自愈) - disabled_reason: Option, -} - -/// 禁用原因 -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum DisabledReason { - /// Admin API 手动禁用 - Manual, - /// 连续失败达到阈值后自动禁用 - TooManyFailures, - /// 额度已用尽(如 MONTHLY_REQUEST_COUNT) - QuotaExceeded, -} - -// ============================================================================ -// Admin API 公开结构 -// ============================================================================ - -/// 凭据条目快照(用于 Admin API 读取) -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CredentialEntrySnapshot { - /// 凭据唯一 ID - pub id: u64, - /// 优先级 - pub priority: u32, - /// 是否被禁用 - pub disabled: bool, - /// 连续失败次数 - pub failure_count: u32, - /// 认证方式 - pub auth_method: Option, - /// 是否有 Profile ARN - pub has_profile_arn: bool, - /// Token 过期时间 - pub expires_at: Option, -} - -/// 凭据管理器状态快照 -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ManagerSnapshot { - /// 凭据条目列表 - pub entries: Vec, - /// 当前活跃凭据 ID - pub current_id: u64, - /// 总凭据数量 - pub total: usize, - /// 可用凭据数量 - pub available: usize, -} - -/// 多凭据 Token 管理器 -/// -/// 支持多个凭据的管理,实现固定优先级 + 故障转移策略 -/// 故障统计基于 API 调用结果,而非 Token 刷新结果 -pub struct MultiTokenManager { - config: Config, - proxy: Option, - /// 凭据条目列表 - entries: Mutex>, - /// 当前活动凭据 ID - current_id: Mutex, - /// Token 刷新锁,确保同一时间只有一个刷新操作 - refresh_lock: TokioMutex<()>, - /// 凭据文件路径(用于回写) - credentials_path: Option, - /// 是否为多凭据格式(数组格式才回写) - is_multiple_format: bool, -} - -/// 每个凭据最大 API 调用失败次数 -const MAX_FAILURES_PER_CREDENTIAL: u32 = 3; - -/// API 调用上下文 -/// -/// 绑定特定凭据的调用上下文,确保 token、credentials 和 id 的一致性 -/// 用于解决并发调用时 current_id 竞态问题 -#[derive(Clone)] -pub struct CallContext { - /// 凭据 ID(用于 report_success/report_failure) - pub id: u64, - /// 凭据信息(用于构建请求头) - pub credentials: KiroCredentials, - /// 访问 Token - pub token: String, -} - -impl MultiTokenManager { - /// 创建多凭据 Token 管理器 - /// - /// # Arguments - /// * `config` - 应用配置 - /// * `credentials` - 凭据列表 - /// * `proxy` - 可选的代理配置 - /// * `credentials_path` - 凭据文件路径(用于回写) - /// * `is_multiple_format` - 是否为多凭据格式(数组格式才回写) - pub fn new( - config: Config, - credentials: Vec, - proxy: Option, - credentials_path: Option, - is_multiple_format: bool, - ) -> anyhow::Result { - // 计算当前最大 ID,为没有 ID 的凭据分配新 ID - let max_existing_id = credentials.iter().filter_map(|c| c.id).max().unwrap_or(0); - let mut next_id = max_existing_id + 1; - let mut has_new_ids = false; - let mut has_new_machine_ids = false; - let config_ref = &config; - - let entries: Vec = credentials - .into_iter() - .map(|mut cred| { - let id = cred.id.unwrap_or_else(|| { - let id = next_id; - next_id += 1; - cred.id = Some(id); - has_new_ids = true; - id - }); - if cred.machine_id.is_none() { - if let Some(machine_id) = - machine_id::generate_from_credentials(&cred, config_ref) - { - cred.machine_id = Some(machine_id); - has_new_machine_ids = true; - } - } - CredentialEntry { - id, - credentials: cred, - failure_count: 0, - disabled: false, - disabled_reason: None, - } - }) - .collect(); - - // 检测重复 ID - let mut seen_ids = std::collections::HashSet::new(); - let mut duplicate_ids = Vec::new(); - for entry in &entries { - if !seen_ids.insert(entry.id) { - duplicate_ids.push(entry.id); - } - } - if !duplicate_ids.is_empty() { - anyhow::bail!("检测到重复的凭据 ID: {:?}", duplicate_ids); - } - - // 选择初始凭据:优先级最高(priority 最小)的凭据,无凭据时为 0 - let initial_id = entries - .iter() - .min_by_key(|e| e.credentials.priority) - .map(|e| e.id) - .unwrap_or(0); - - let manager = Self { - config, - proxy, - entries: Mutex::new(entries), - current_id: Mutex::new(initial_id), - refresh_lock: TokioMutex::new(()), - credentials_path, - is_multiple_format, - }; - - // 如果有新分配的 ID 或新生成的 machineId,立即持久化到配置文件 - if has_new_ids || has_new_machine_ids { - if let Err(e) = manager.persist_credentials() { - tracing::warn!("补全凭据 ID/machineId 后持久化失败: {}", e); - } else { - tracing::info!("已补全凭据 ID/machineId 并写回配置文件"); - } - } - - Ok(manager) - } - - /// 获取配置的引用 - pub fn config(&self) -> &Config { - &self.config - } - - /// 获取当前活动凭据的克隆 - pub fn credentials(&self) -> KiroCredentials { - let entries = self.entries.lock(); - let current_id = *self.current_id.lock(); - entries - .iter() - .find(|e| e.id == current_id) - .map(|e| e.credentials.clone()) - .unwrap_or_default() - } - - /// 获取凭据总数 - pub fn total_count(&self) -> usize { - self.entries.lock().len() - } - - /// 获取可用凭据数量 - pub fn available_count(&self) -> usize { - self.entries.lock().iter().filter(|e| !e.disabled).count() - } - - /// 获取 API 调用上下文 - /// - /// 返回绑定了 id、credentials 和 token 的调用上下文 - /// 确保整个 API 调用过程中使用一致的凭据信息 - /// - /// 如果 Token 过期或即将过期,会自动刷新 - /// Token 刷新失败时会尝试下一个可用凭据(不计入失败次数) - pub async fn acquire_context(&self) -> anyhow::Result { - let total = self.total_count(); - let mut tried_count = 0; - - loop { - if tried_count >= total { - anyhow::bail!( - "所有凭据均无法获取有效 Token(可用: {}/{})", - self.available_count(), - total - ); - } - - let (id, credentials) = { - let mut entries = self.entries.lock(); - let current_id = *self.current_id.lock(); - - // 找到当前凭据 - if let Some(entry) = entries.iter().find(|e| e.id == current_id && !e.disabled) { - (entry.id, entry.credentials.clone()) - } else { - // 当前凭据不可用,选择优先级最高的可用凭据 - let mut best = entries - .iter() - .filter(|e| !e.disabled) - .min_by_key(|e| e.credentials.priority); - - // 没有可用凭据:如果是“自动禁用导致全灭”,做一次类似重启的自愈 - if best.is_none() - && entries.iter().any(|e| { - e.disabled && e.disabled_reason == Some(DisabledReason::TooManyFailures) - }) - { - tracing::warn!( - "所有凭据均已被自动禁用,执行自愈:重置失败计数并重新启用(等价于重启)" - ); - for e in entries.iter_mut() { - if e.disabled_reason == Some(DisabledReason::TooManyFailures) { - e.disabled = false; - e.disabled_reason = None; - e.failure_count = 0; - } - } - best = entries - .iter() - .filter(|e| !e.disabled) - .min_by_key(|e| e.credentials.priority); - } - - if let Some(entry) = best { - // 先提取数据 - let new_id = entry.id; - let new_creds = entry.credentials.clone(); - drop(entries); - // 更新 current_id - let mut current_id = self.current_id.lock(); - *current_id = new_id; - (new_id, new_creds) - } else { - // 注意:必须在 bail! 之前计算 available_count, - // 因为 available_count() 会尝试获取 entries 锁, - // 而此时我们已经持有该锁,会导致死锁 - let available = entries.iter().filter(|e| !e.disabled).count(); - anyhow::bail!("所有凭据均已禁用({}/{})", available, total); - } - } - }; - - // 尝试获取/刷新 Token - match self.try_ensure_token(id, &credentials).await { - Ok(ctx) => { - return Ok(ctx); - } - Err(e) => { - tracing::warn!("凭据 #{} Token 刷新失败,尝试下一个凭据: {}", id, e); - - // Token 刷新失败,切换到下一个优先级的凭据(不计入失败次数) - self.switch_to_next_by_priority(); - tried_count += 1; - } - } - } - } - - /// 切换到下一个优先级最高的可用凭据(内部方法) - fn switch_to_next_by_priority(&self) { - let entries = self.entries.lock(); - let mut current_id = self.current_id.lock(); - - // 选择优先级最高的未禁用凭据(排除当前凭据) - if let Some(entry) = entries - .iter() - .filter(|e| !e.disabled && e.id != *current_id) - .min_by_key(|e| e.credentials.priority) - { - *current_id = entry.id; - tracing::info!( - "已切换到凭据 #{}(优先级 {})", - entry.id, - entry.credentials.priority - ); - } - } - - /// 选择优先级最高的未禁用凭据作为当前凭据(内部方法) - /// - /// 与 `switch_to_next_by_priority` 不同,此方法不排除当前凭据, - /// 纯粹按优先级选择,用于优先级变更后立即生效 - fn select_highest_priority(&self) { - let entries = self.entries.lock(); - let mut current_id = self.current_id.lock(); - - // 选择优先级最高的未禁用凭据(不排除当前凭据) - if let Some(best) = entries - .iter() - .filter(|e| !e.disabled) - .min_by_key(|e| e.credentials.priority) - { - if best.id != *current_id { - tracing::info!( - "优先级变更后切换凭据: #{} -> #{}(优先级 {})", - *current_id, - best.id, - best.credentials.priority - ); - *current_id = best.id; - } - } - } - - /// 尝试使用指定凭据获取有效 Token - /// - /// 使用双重检查锁定模式,确保同一时间只有一个刷新操作 - /// - /// # Arguments - /// * `id` - 凭据 ID,用于更新正确的条目 - /// * `credentials` - 凭据信息 - async fn try_ensure_token( - &self, - id: u64, - credentials: &KiroCredentials, - ) -> anyhow::Result { - // 第一次检查(无锁):快速判断是否需要刷新 - let needs_refresh = is_token_expired(credentials) || is_token_expiring_soon(credentials); - - let creds = if needs_refresh { - // 获取刷新锁,确保同一时间只有一个刷新操作 - let _guard = self.refresh_lock.lock().await; - - // 第二次检查:获取锁后重新读取凭据,因为其他请求可能已经完成刷新 - let current_creds = { - let entries = self.entries.lock(); - entries - .iter() - .find(|e| e.id == id) - .map(|e| e.credentials.clone()) - .ok_or_else(|| anyhow::anyhow!("凭据 #{} 不存在", id))? - }; - - if is_token_expired(¤t_creds) || is_token_expiring_soon(¤t_creds) { - // 确实需要刷新 - let new_creds = - refresh_token(¤t_creds, &self.config, self.proxy.as_ref()).await?; - - if is_token_expired(&new_creds) { - anyhow::bail!("刷新后的 Token 仍然无效或已过期"); - } - - // 更新凭据 - { - let mut entries = self.entries.lock(); - if let Some(entry) = entries.iter_mut().find(|e| e.id == id) { - entry.credentials = new_creds.clone(); - } - } - - // 回写凭据到文件(仅多凭据格式),失败只记录警告 - if let Err(e) = self.persist_credentials() { - tracing::warn!("Token 刷新后持久化失败(不影响本次请求): {}", e); - } - - new_creds - } else { - // 其他请求已经完成刷新,直接使用新凭据 - tracing::debug!("Token 已被其他请求刷新,跳过刷新"); - current_creds - } - } else { - credentials.clone() - }; - - let token = creds - .access_token - .clone() - .ok_or_else(|| anyhow::anyhow!("没有可用的 accessToken"))?; - - Ok(CallContext { - id, - credentials: creds, - token, - }) - } - - /// 将凭据列表回写到源文件 - /// - /// 仅在以下条件满足时回写: - /// - 源文件是多凭据格式(数组) - /// - credentials_path 已设置 - /// - /// # Returns - /// - `Ok(true)` - 成功写入文件 - /// - `Ok(false)` - 跳过写入(非多凭据格式或无路径配置) - /// - `Err(_)` - 写入失败 - fn persist_credentials(&self) -> anyhow::Result { - use anyhow::Context; - - // 仅多凭据格式才回写 - if !self.is_multiple_format { - return Ok(false); - } - - let path = match &self.credentials_path { - Some(p) => p, - None => return Ok(false), - }; - - // 收集所有凭据 - let credentials: Vec = { - let entries = self.entries.lock(); - entries.iter().map(|e| e.credentials.clone()).collect() - }; - - // 序列化为 pretty JSON - let json = serde_json::to_string_pretty(&credentials).context("序列化凭据失败")?; - - // 写入文件(在 Tokio runtime 内使用 block_in_place 避免阻塞 worker) - if tokio::runtime::Handle::try_current().is_ok() { - tokio::task::block_in_place(|| std::fs::write(path, &json)) - .with_context(|| format!("回写凭据文件失败: {:?}", path))?; - } else { - std::fs::write(path, &json).with_context(|| format!("回写凭据文件失败: {:?}", path))?; - } - - tracing::debug!("已回写凭据到文件: {:?}", path); - Ok(true) - } - - /// 报告指定凭据 API 调用成功 - /// - /// 重置该凭据的失败计数 - /// - /// # Arguments - /// * `id` - 凭据 ID(来自 CallContext) - pub fn report_success(&self, id: u64) { - let mut entries = self.entries.lock(); - if let Some(entry) = entries.iter_mut().find(|e| e.id == id) { - entry.failure_count = 0; - tracing::debug!("凭据 #{} API 调用成功", id); - } - } - - /// 报告指定凭据 API 调用失败 - /// - /// 增加失败计数,达到阈值时禁用凭据并切换到优先级最高的可用凭据 - /// 返回是否还有可用凭据可以重试 - /// - /// # Arguments - /// * `id` - 凭据 ID(来自 CallContext) - pub fn report_failure(&self, id: u64) -> bool { - let mut entries = self.entries.lock(); - let mut current_id = self.current_id.lock(); - - let entry = match entries.iter_mut().find(|e| e.id == id) { - Some(e) => e, - None => return entries.iter().any(|e| !e.disabled), - }; - - entry.failure_count += 1; - let failure_count = entry.failure_count; - - tracing::warn!( - "凭据 #{} API 调用失败({}/{})", - id, - failure_count, - MAX_FAILURES_PER_CREDENTIAL - ); - - if failure_count >= MAX_FAILURES_PER_CREDENTIAL { - entry.disabled = true; - entry.disabled_reason = Some(DisabledReason::TooManyFailures); - tracing::error!("凭据 #{} 已连续失败 {} 次,已被禁用", id, failure_count); - - // 切换到优先级最高的可用凭据 - if let Some(next) = entries - .iter() - .filter(|e| !e.disabled) - .min_by_key(|e| e.credentials.priority) - { - *current_id = next.id; - tracing::info!( - "已切换到凭据 #{}(优先级 {})", - next.id, - next.credentials.priority - ); - } else { - tracing::error!("所有凭据均已禁用!"); - return false; - } - } - - // 检查是否还有可用凭据 - entries.iter().any(|e| !e.disabled) - } - - /// 报告指定凭据额度已用尽 - /// - /// 用于处理 402 Payment Required 且 reason 为 `MONTHLY_REQUEST_COUNT` 的场景: - /// - 立即禁用该凭据(不等待连续失败阈值) - /// - 切换到下一个可用凭据继续重试 - /// - 返回是否还有可用凭据 - pub fn report_quota_exhausted(&self, id: u64) -> bool { - let mut entries = self.entries.lock(); - let mut current_id = self.current_id.lock(); - - let entry = match entries.iter_mut().find(|e| e.id == id) { - Some(e) => e, - None => return entries.iter().any(|e| !e.disabled), - }; - - if entry.disabled { - return entries.iter().any(|e| !e.disabled); - } - - entry.disabled = true; - entry.disabled_reason = Some(DisabledReason::QuotaExceeded); - // 设为阈值,便于在管理面板中直观看到该凭据已不可用 - entry.failure_count = MAX_FAILURES_PER_CREDENTIAL; - - tracing::error!( - "凭据 #{} 额度已用尽(MONTHLY_REQUEST_COUNT),已被禁用", - id - ); - - // 切换到优先级最高的可用凭据 - if let Some(next) = entries - .iter() - .filter(|e| !e.disabled) - .min_by_key(|e| e.credentials.priority) - { - *current_id = next.id; - tracing::info!( - "已切换到凭据 #{}(优先级 {})", - next.id, - next.credentials.priority - ); - return true; - } - - tracing::error!("所有凭据均已禁用!"); - false - } - - /// 切换到优先级最高的可用凭据 - /// - /// 返回是否成功切换 - pub fn switch_to_next(&self) -> bool { - let entries = self.entries.lock(); - let mut current_id = self.current_id.lock(); - - // 选择优先级最高的未禁用凭据(排除当前凭据) - if let Some(next) = entries - .iter() - .filter(|e| !e.disabled && e.id != *current_id) - .min_by_key(|e| e.credentials.priority) - { - *current_id = next.id; - tracing::info!( - "已切换到凭据 #{}(优先级 {})", - next.id, - next.credentials.priority - ); - true - } else { - // 没有其他可用凭据,检查当前凭据是否可用 - entries.iter().any(|e| e.id == *current_id && !e.disabled) - } - } - - /// 获取使用额度信息 - pub async fn get_usage_limits(&self) -> anyhow::Result { - let ctx = self.acquire_context().await?; - get_usage_limits( - &ctx.credentials, - &self.config, - &ctx.token, - self.proxy.as_ref(), - ) - .await - } - - // ======================================================================== - // Admin API 方法 - // ======================================================================== - - /// 获取管理器状态快照(用于 Admin API) - pub fn snapshot(&self) -> ManagerSnapshot { - let entries = self.entries.lock(); - let current_id = *self.current_id.lock(); - let available = entries.iter().filter(|e| !e.disabled).count(); - - ManagerSnapshot { - entries: entries - .iter() - .map(|e| CredentialEntrySnapshot { - id: e.id, - priority: e.credentials.priority, - disabled: e.disabled, - failure_count: e.failure_count, - auth_method: e.credentials.auth_method.clone(), - has_profile_arn: e.credentials.profile_arn.is_some(), - expires_at: e.credentials.expires_at.clone(), - }) - .collect(), - current_id, - total: entries.len(), - available, - } - } - - /// 设置凭据禁用状态(Admin API) - pub fn set_disabled(&self, id: u64, disabled: bool) -> anyhow::Result<()> { - { - let mut entries = self.entries.lock(); - let entry = entries - .iter_mut() - .find(|e| e.id == id) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))?; - entry.disabled = disabled; - if !disabled { - // 启用时重置失败计数 - entry.failure_count = 0; - entry.disabled_reason = None; - } else { - entry.disabled_reason = Some(DisabledReason::Manual); - } - } - // 持久化更改 - self.persist_credentials()?; - Ok(()) - } - - /// 设置凭据优先级(Admin API) - /// - /// 修改优先级后会立即按新优先级重新选择当前凭据。 - /// 即使持久化失败,内存中的优先级和当前凭据选择也会生效。 - pub fn set_priority(&self, id: u64, priority: u32) -> anyhow::Result<()> { - { - let mut entries = self.entries.lock(); - let entry = entries - .iter_mut() - .find(|e| e.id == id) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))?; - entry.credentials.priority = priority; - } - // 立即按新优先级重新选择当前凭据(无论持久化是否成功) - self.select_highest_priority(); - // 持久化更改 - self.persist_credentials()?; - Ok(()) - } - - /// 重置凭据失败计数并重新启用(Admin API) - pub fn reset_and_enable(&self, id: u64) -> anyhow::Result<()> { - { - let mut entries = self.entries.lock(); - let entry = entries - .iter_mut() - .find(|e| e.id == id) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))?; - entry.failure_count = 0; - entry.disabled = false; - entry.disabled_reason = None; - } - // 持久化更改 - self.persist_credentials()?; - Ok(()) - } - - /// 获取指定凭据的使用额度(Admin API) - pub async fn get_usage_limits_for(&self, id: u64) -> anyhow::Result { - let credentials = { - let entries = self.entries.lock(); - entries - .iter() - .find(|e| e.id == id) - .map(|e| e.credentials.clone()) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))? - }; - - // 检查是否需要刷新 token - let needs_refresh = is_token_expired(&credentials) || is_token_expiring_soon(&credentials); - - let token = if needs_refresh { - let _guard = self.refresh_lock.lock().await; - let current_creds = { - let entries = self.entries.lock(); - entries - .iter() - .find(|e| e.id == id) - .map(|e| e.credentials.clone()) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))? - }; - - if is_token_expired(¤t_creds) || is_token_expiring_soon(¤t_creds) { - let new_creds = - refresh_token(¤t_creds, &self.config, self.proxy.as_ref()).await?; - { - let mut entries = self.entries.lock(); - if let Some(entry) = entries.iter_mut().find(|e| e.id == id) { - entry.credentials = new_creds.clone(); - } - } - // 持久化失败只记录警告,不影响本次请求 - if let Err(e) = self.persist_credentials() { - tracing::warn!("Token 刷新后持久化失败(不影响本次请求): {}", e); - } - new_creds - .access_token - .ok_or_else(|| anyhow::anyhow!("刷新后无 access_token"))? - } else { - current_creds - .access_token - .ok_or_else(|| anyhow::anyhow!("凭据无 access_token"))? - } - } else { - credentials - .access_token - .ok_or_else(|| anyhow::anyhow!("凭据无 access_token"))? - }; - - let credentials = { - let entries = self.entries.lock(); - entries - .iter() - .find(|e| e.id == id) - .map(|e| e.credentials.clone()) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))? - }; - - get_usage_limits(&credentials, &self.config, &token, self.proxy.as_ref()).await - } - - /// 添加新凭据(Admin API) - /// - /// # 流程 - /// 1. 验证凭据基本字段(refresh_token 不为空) - /// 2. 尝试刷新 Token 验证凭据有效性 - /// 3. 分配新 ID(当前最大 ID + 1) - /// 4. 添加到 entries 列表 - /// 5. 持久化到配置文件 - /// - /// # 返回 - /// - `Ok(u64)` - 新凭据 ID - /// - `Err(_)` - 验证失败或添加失败 - pub async fn add_credential(&self, new_cred: KiroCredentials) -> anyhow::Result { - // 1. 基本验证 - validate_refresh_token(&new_cred)?; - - // 2. 尝试刷新 Token 验证凭据有效性 - let mut validated_cred = - refresh_token(&new_cred, &self.config, self.proxy.as_ref()).await?; - - // 3. 分配新 ID - let new_id = { - let entries = self.entries.lock(); - entries.iter().map(|e| e.id).max().unwrap_or(0) + 1 - }; - - // 4. 设置 ID 并保留用户输入的元数据 - validated_cred.id = Some(new_id); - validated_cred.priority = new_cred.priority; - validated_cred.auth_method = new_cred.auth_method; - validated_cred.client_id = new_cred.client_id; - validated_cred.client_secret = new_cred.client_secret; - - { - let mut entries = self.entries.lock(); - entries.push(CredentialEntry { - id: new_id, - credentials: validated_cred, - failure_count: 0, - disabled: false, - disabled_reason: None, - }); - } - - // 5. 持久化 - self.persist_credentials()?; - - tracing::info!("成功添加凭据 #{}", new_id); - Ok(new_id) - } - - /// 删除凭据(Admin API) - /// - /// # 前置条件 - /// - 凭据必须已禁用(disabled = true) - /// - /// # 行为 - /// 1. 验证凭据存在 - /// 2. 验证凭据已禁用 - /// 3. 从 entries 移除 - /// 4. 如果删除的是当前凭据,切换到优先级最高的可用凭据 - /// 5. 如果删除后没有凭据,将 current_id 重置为 0 - /// 6. 持久化到文件 - /// - /// # 返回 - /// - `Ok(())` - 删除成功 - /// - `Err(_)` - 凭据不存在、未禁用或持久化失败 - pub fn delete_credential(&self, id: u64) -> anyhow::Result<()> { - let was_current = { - let mut entries = self.entries.lock(); - - // 查找凭据 - let entry = entries - .iter() - .find(|e| e.id == id) - .ok_or_else(|| anyhow::anyhow!("凭据不存在: {}", id))?; - - // 检查是否已禁用 - if !entry.disabled { - anyhow::bail!("只能删除已禁用的凭据(请先禁用凭据 #{})", id); - } - - // 记录是否是当前凭据 - let current_id = *self.current_id.lock(); - let was_current = current_id == id; - - // 删除凭据 - entries.retain(|e| e.id != id); - - was_current - }; - - // 如果删除的是当前凭据,切换到优先级最高的可用凭据 - if was_current { - self.select_highest_priority(); - } - - // 如果删除后没有任何凭据,将 current_id 重置为 0(与初始化行为保持一致) - { - let entries = self.entries.lock(); - if entries.is_empty() { - let mut current_id = self.current_id.lock(); - *current_id = 0; - tracing::info!("所有凭据已删除,current_id 已重置为 0"); - } - } - - // 持久化更改 - self.persist_credentials()?; - - tracing::info!("已删除凭据 #{}", id); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_token_manager_new() { - let config = Config::default(); - let credentials = KiroCredentials::default(); - let tm = TokenManager::new(config, credentials, None); - assert!(tm.credentials().access_token.is_none()); - } - - #[test] - fn test_is_token_expired_with_expired_token() { - let mut credentials = KiroCredentials::default(); - credentials.expires_at = Some("2020-01-01T00:00:00Z".to_string()); - assert!(is_token_expired(&credentials)); - } - - #[test] - fn test_is_token_expired_with_valid_token() { - let mut credentials = KiroCredentials::default(); - let future = Utc::now() + Duration::hours(1); - credentials.expires_at = Some(future.to_rfc3339()); - assert!(!is_token_expired(&credentials)); - } - - #[test] - fn test_is_token_expired_within_5_minutes() { - let mut credentials = KiroCredentials::default(); - let expires = Utc::now() + Duration::minutes(3); - credentials.expires_at = Some(expires.to_rfc3339()); - assert!(is_token_expired(&credentials)); - } - - #[test] - fn test_is_token_expired_no_expires_at() { - let credentials = KiroCredentials::default(); - assert!(is_token_expired(&credentials)); - } - - #[test] - fn test_is_token_expiring_soon_within_10_minutes() { - let mut credentials = KiroCredentials::default(); - let expires = Utc::now() + Duration::minutes(8); - credentials.expires_at = Some(expires.to_rfc3339()); - assert!(is_token_expiring_soon(&credentials)); - } - - #[test] - fn test_is_token_expiring_soon_beyond_10_minutes() { - let mut credentials = KiroCredentials::default(); - let expires = Utc::now() + Duration::minutes(15); - credentials.expires_at = Some(expires.to_rfc3339()); - assert!(!is_token_expiring_soon(&credentials)); - } - - #[test] - fn test_validate_refresh_token_missing() { - let credentials = KiroCredentials::default(); - let result = validate_refresh_token(&credentials); - assert!(result.is_err()); - } - - #[test] - fn test_validate_refresh_token_valid() { - let mut credentials = KiroCredentials::default(); - credentials.refresh_token = Some("a".repeat(150)); - let result = validate_refresh_token(&credentials); - assert!(result.is_ok()); - } - - // MultiTokenManager 测试 - - #[test] - fn test_multi_token_manager_new() { - let config = Config::default(); - let mut cred1 = KiroCredentials::default(); - cred1.priority = 0; - let mut cred2 = KiroCredentials::default(); - cred2.priority = 1; - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - assert_eq!(manager.total_count(), 2); - assert_eq!(manager.available_count(), 2); - } - - #[test] - fn test_multi_token_manager_empty_credentials() { - let config = Config::default(); - let result = MultiTokenManager::new(config, vec![], None, None, false); - // 支持 0 个凭据启动(可通过管理面板添加) - assert!(result.is_ok()); - let manager = result.unwrap(); - assert_eq!(manager.total_count(), 0); - assert_eq!(manager.available_count(), 0); - } - - #[test] - fn test_multi_token_manager_duplicate_ids() { - let config = Config::default(); - let mut cred1 = KiroCredentials::default(); - cred1.id = Some(1); - let mut cred2 = KiroCredentials::default(); - cred2.id = Some(1); // 重复 ID - - let result = MultiTokenManager::new(config, vec![cred1, cred2], None, None, false); - assert!(result.is_err()); - let err_msg = result.err().unwrap().to_string(); - assert!( - err_msg.contains("重复的凭据 ID"), - "错误消息应包含 '重复的凭据 ID',实际: {}", - err_msg - ); - } - - #[test] - fn test_multi_token_manager_report_failure() { - let config = Config::default(); - let cred1 = KiroCredentials::default(); - let cred2 = KiroCredentials::default(); - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - - // 凭据会自动分配 ID(从 1 开始) - // 前两次失败不会禁用(使用 ID 1) - assert!(manager.report_failure(1)); - assert!(manager.report_failure(1)); - assert_eq!(manager.available_count(), 2); - - // 第三次失败会禁用第一个凭据 - assert!(manager.report_failure(1)); - assert_eq!(manager.available_count(), 1); - - // 继续失败第二个凭据(使用 ID 2) - assert!(manager.report_failure(2)); - assert!(manager.report_failure(2)); - assert!(!manager.report_failure(2)); // 所有凭据都禁用了 - assert_eq!(manager.available_count(), 0); - } - - #[test] - fn test_multi_token_manager_report_success() { - let config = Config::default(); - let cred = KiroCredentials::default(); - - let manager = MultiTokenManager::new(config, vec![cred], None, None, false).unwrap(); - - // 失败两次(使用 ID 1) - manager.report_failure(1); - manager.report_failure(1); - - // 成功后重置计数(使用 ID 1) - manager.report_success(1); - - // 再失败两次不会禁用 - manager.report_failure(1); - manager.report_failure(1); - assert_eq!(manager.available_count(), 1); - } - - #[test] - fn test_multi_token_manager_switch_to_next() { - let config = Config::default(); - let mut cred1 = KiroCredentials::default(); - cred1.refresh_token = Some("token1".to_string()); - let mut cred2 = KiroCredentials::default(); - cred2.refresh_token = Some("token2".to_string()); - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - - // 初始是第一个凭据 - assert_eq!( - manager.credentials().refresh_token, - Some("token1".to_string()) - ); - - // 切换到下一个 - assert!(manager.switch_to_next()); - assert_eq!( - manager.credentials().refresh_token, - Some("token2".to_string()) - ); - } - - #[tokio::test] - async fn test_multi_token_manager_acquire_context_auto_recovers_all_disabled() { - let config = Config::default(); - let mut cred1 = KiroCredentials::default(); - cred1.access_token = Some("t1".to_string()); - cred1.expires_at = Some((Utc::now() + Duration::hours(1)).to_rfc3339()); - let mut cred2 = KiroCredentials::default(); - cred2.access_token = Some("t2".to_string()); - cred2.expires_at = Some((Utc::now() + Duration::hours(1)).to_rfc3339()); - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - - // 凭据会自动分配 ID(从 1 开始) - for _ in 0..MAX_FAILURES_PER_CREDENTIAL { - manager.report_failure(1); - } - for _ in 0..MAX_FAILURES_PER_CREDENTIAL { - manager.report_failure(2); - } - - assert_eq!(manager.available_count(), 0); - - // 应触发自愈:重置失败计数并重新启用,避免必须重启进程 - let ctx = manager.acquire_context().await.unwrap(); - assert!(ctx.token == "t1" || ctx.token == "t2"); - assert_eq!(manager.available_count(), 2); - } - - #[test] - fn test_multi_token_manager_report_quota_exhausted() { - let config = Config::default(); - let cred1 = KiroCredentials::default(); - let cred2 = KiroCredentials::default(); - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - - // 凭据会自动分配 ID(从 1 开始) - assert_eq!(manager.available_count(), 2); - assert!(manager.report_quota_exhausted(1)); - assert_eq!(manager.available_count(), 1); - - // 再禁用第二个后,无可用凭据 - assert!(!manager.report_quota_exhausted(2)); - assert_eq!(manager.available_count(), 0); - } - - #[tokio::test] - async fn test_multi_token_manager_quota_disabled_is_not_auto_recovered() { - let config = Config::default(); - let cred1 = KiroCredentials::default(); - let cred2 = KiroCredentials::default(); - - let manager = - MultiTokenManager::new(config, vec![cred1, cred2], None, None, false).unwrap(); - - manager.report_quota_exhausted(1); - manager.report_quota_exhausted(2); - assert_eq!(manager.available_count(), 0); - - let err = manager.acquire_context().await.err().unwrap().to_string(); - assert!( - err.contains("所有凭据均已禁用"), - "错误应提示所有凭据禁用,实际: {}", - err - ); - assert_eq!(manager.available_count(), 0); - } - - // ============ 凭据级 Region 优先级测试 ============ - - /// 辅助函数:获取 OIDC 刷新使用的 region(用于测试) - fn get_oidc_region_for_credential<'a>( - credentials: &'a KiroCredentials, - config: &'a Config, - ) -> &'a str { - credentials.region.as_ref().unwrap_or(&config.region) - } - - #[test] - fn test_credential_region_priority_uses_credential_region() { - // 凭据配置了 region 时,应使用凭据的 region - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.region = Some("eu-west-1".to_string()); - - let region = get_oidc_region_for_credential(&credentials, &config); - assert_eq!(region, "eu-west-1"); - } - - #[test] - fn test_credential_region_priority_fallback_to_config() { - // 凭据未配置 region 时,应回退到 config.region - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let credentials = KiroCredentials::default(); - assert!(credentials.region.is_none()); - - let region = get_oidc_region_for_credential(&credentials, &config); - assert_eq!(region, "us-west-2"); - } - - #[test] - fn test_multiple_credentials_use_respective_regions() { - // 多凭据场景下,不同凭据使用各自的 region - let mut config = Config::default(); - config.region = "ap-northeast-1".to_string(); - - let mut cred1 = KiroCredentials::default(); - cred1.region = Some("us-east-1".to_string()); - - let mut cred2 = KiroCredentials::default(); - cred2.region = Some("eu-west-1".to_string()); - - let cred3 = KiroCredentials::default(); // 无 region,使用 config - - assert_eq!(get_oidc_region_for_credential(&cred1, &config), "us-east-1"); - assert_eq!(get_oidc_region_for_credential(&cred2, &config), "eu-west-1"); - assert_eq!( - get_oidc_region_for_credential(&cred3, &config), - "ap-northeast-1" - ); - } - - #[test] - fn test_idc_oidc_endpoint_uses_credential_region() { - // 验证 IdC OIDC endpoint URL 使用凭据 region - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.region = Some("eu-central-1".to_string()); - - let region = get_oidc_region_for_credential(&credentials, &config); - let refresh_url = format!("https://oidc.{}.amazonaws.com/token", region); - - assert_eq!(refresh_url, "https://oidc.eu-central-1.amazonaws.com/token"); - } - - #[test] - fn test_social_refresh_endpoint_uses_credential_region() { - // 验证 Social refresh endpoint URL 使用凭据 region - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.region = Some("ap-southeast-1".to_string()); - - let region = get_oidc_region_for_credential(&credentials, &config); - let refresh_url = format!("https://prod.{}.auth.desktop.kiro.dev/refreshToken", region); - - assert_eq!( - refresh_url, - "https://prod.ap-southeast-1.auth.desktop.kiro.dev/refreshToken" - ); - } - - #[test] - fn test_api_call_still_uses_config_region() { - // 验证 API 调用(如 getUsageLimits)仍使用 config.region - // 这确保只有 OIDC 刷新使用凭据 region,API 调用行为不变 - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.region = Some("eu-west-1".to_string()); - - // API 调用应使用 config.region,而非 credentials.region - let api_region = &config.region; - let api_host = format!("q.{}.amazonaws.com", api_region); - - assert_eq!(api_host, "q.us-west-2.amazonaws.com"); - // 确认凭据 region 不影响 API 调用 - assert_ne!(api_region, credentials.region.as_ref().unwrap()); - } - - #[test] - fn test_credential_region_empty_string_treated_as_set() { - // 空字符串 region 被视为已设置(虽然不推荐,但行为应一致) - let mut config = Config::default(); - config.region = "us-west-2".to_string(); - - let mut credentials = KiroCredentials::default(); - credentials.region = Some("".to_string()); - - let region = get_oidc_region_for_credential(&credentials, &config); - // 空字符串被视为已设置,不会回退到 config - assert_eq!(region, ""); - } -} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index c01ddd6d96ffe7d3b26baf00e87747ceca87e33a..0000000000000000000000000000000000000000 --- a/src/main.rs +++ /dev/null @@ -1,161 +0,0 @@ -mod admin; -mod admin_ui; -mod anthropic; -mod common; -mod http_client; -mod kiro; -mod model; -pub mod token; - -use std::sync::Arc; - -use clap::Parser; -use kiro::model::credentials::{CredentialsConfig, KiroCredentials}; -use kiro::provider::KiroProvider; -use kiro::token_manager::MultiTokenManager; -use model::arg::Args; -use model::config::Config; - -#[tokio::main] -async fn main() { - // 解析命令行参数 - let args = Args::parse(); - - // 初始化日志 - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::from_default_env() - .add_directive(tracing::Level::INFO.into()), - ) - .init(); - - // 加载配置 - let config_path = args - .config - .unwrap_or_else(|| Config::default_config_path().to_string()); - let config = Config::load(&config_path).unwrap_or_else(|e| { - tracing::error!("加载配置失败: {}", e); - std::process::exit(1); - }); - - // 加载凭证(支持单对象或数组格式) - let credentials_path = args - .credentials - .unwrap_or_else(|| KiroCredentials::default_credentials_path().to_string()); - let credentials_config = CredentialsConfig::load(&credentials_path).unwrap_or_else(|e| { - tracing::error!("加载凭证失败: {}", e); - std::process::exit(1); - }); - - // 判断是否为多凭据格式(用于刷新后回写) - let is_multiple_format = credentials_config.is_multiple(); - - // 转换为按优先级排序的凭据列表 - let credentials_list = credentials_config.into_sorted_credentials(); - tracing::info!("已加载 {} 个凭据配置", credentials_list.len()); - - // 获取第一个凭据用于日志显示 - let first_credentials = credentials_list.first().cloned().unwrap_or_default(); - tracing::debug!("主凭证: {:?}", first_credentials); - - // 获取 API Key - let api_key = config.api_key.clone().unwrap_or_else(|| { - tracing::error!("配置文件中未设置 apiKey"); - std::process::exit(1); - }); - - // 构建代理配置 - let proxy_config = config.proxy_url.as_ref().map(|url| { - let mut proxy = http_client::ProxyConfig::new(url); - if let (Some(username), Some(password)) = (&config.proxy_username, &config.proxy_password) { - proxy = proxy.with_auth(username, password); - } - proxy - }); - - if proxy_config.is_some() { - tracing::info!("已配置 HTTP 代理: {}", config.proxy_url.as_ref().unwrap()); - } - - // 创建 MultiTokenManager 和 KiroProvider - let token_manager = MultiTokenManager::new( - config.clone(), - credentials_list, - proxy_config.clone(), - Some(credentials_path.into()), - is_multiple_format, - ) - .unwrap_or_else(|e| { - tracing::error!("创建 Token 管理器失败: {}", e); - std::process::exit(1); - }); - let token_manager = Arc::new(token_manager); - let kiro_provider = KiroProvider::with_proxy(token_manager.clone(), proxy_config.clone()); - - // 初始化 count_tokens 配置 - token::init_config(token::CountTokensConfig { - api_url: config.count_tokens_api_url.clone(), - api_key: config.count_tokens_api_key.clone(), - auth_type: config.count_tokens_auth_type.clone(), - proxy: proxy_config, - }); - - // 构建 Anthropic API 路由(从第一个凭据获取 profile_arn) - let anthropic_app = anthropic::create_router_with_provider( - &api_key, - Some(kiro_provider), - first_credentials.profile_arn.clone(), - ); - - // 构建 Admin API 路由(如果配置了非空的 admin_api_key) - // 安全检查:空字符串被视为未配置,防止空 key 绕过认证 - let admin_key_valid = config - .admin_api_key - .as_ref() - .map(|k| !k.trim().is_empty()) - .unwrap_or(false); - - let app = if let Some(admin_key) = &config.admin_api_key { - if admin_key.trim().is_empty() { - tracing::warn!("admin_api_key 配置为空,Admin API 未启用"); - anthropic_app - } else { - let admin_service = admin::AdminService::new(token_manager.clone()); - let admin_state = admin::AdminState::new(admin_key, admin_service); - let admin_app = admin::create_admin_router(admin_state); - - // 创建 Admin UI 路由 - let admin_ui_app = admin_ui::create_admin_ui_router(); - - tracing::info!("Admin API 已启用"); - tracing::info!("Admin UI 已启用: /admin"); - anthropic_app - .nest("/api/admin", admin_app) - .nest("/admin", admin_ui_app) - } - } else { - anthropic_app - }; - - // 启动服务器 - let addr = format!("{}:{}", config.host, config.port); - tracing::info!("启动 Anthropic API 端点: {}", addr); - tracing::info!("API Key: {}***", &api_key[..(api_key.len() / 2)]); - tracing::info!("可用 API:"); - tracing::info!(" GET /v1/models"); - tracing::info!(" POST /v1/messages"); - tracing::info!(" POST /v1/messages/count_tokens"); - if admin_key_valid { - tracing::info!("Admin API:"); - tracing::info!(" GET /api/admin/credentials"); - tracing::info!(" POST /api/admin/credentials/:index/disabled"); - tracing::info!(" POST /api/admin/credentials/:index/priority"); - tracing::info!(" POST /api/admin/credentials/:index/reset"); - tracing::info!(" GET /api/admin/credentials/:index/balance"); - tracing::info!("Admin UI:"); - tracing::info!(" GET /admin"); - } - - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); -} diff --git a/src/model/arg.rs b/src/model/arg.rs deleted file mode 100644 index f572d8327ba13f8f92ccebe44bcce85b757cc9d6..0000000000000000000000000000000000000000 --- a/src/model/arg.rs +++ /dev/null @@ -1,14 +0,0 @@ -use clap::Parser; - -/// Anthropic <-> Kiro API 客户端 -#[derive(Parser, Debug)] -#[command(version, about, long_about = None)] -pub struct Args { - /// 配置文件路径 - #[arg(short, long)] - pub config: Option, - - /// 凭证文件路径 - #[arg(long)] - pub credentials: Option, -} diff --git a/src/model/config.rs b/src/model/config.rs deleted file mode 100644 index 15d32a24a9072a357cc923ea7febc3b6a4e77cad..0000000000000000000000000000000000000000 --- a/src/model/config.rs +++ /dev/null @@ -1,132 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::fs; -use std::path::Path; - -/// KNA 应用配置 -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Config { - #[serde(default = "default_host")] - pub host: String, - - #[serde(default = "default_port")] - pub port: u16, - - #[serde(default = "default_region")] - pub region: String, - - #[serde(default = "default_kiro_version")] - pub kiro_version: String, - - #[serde(default)] - pub machine_id: Option, - - #[serde(default)] - pub api_key: Option, - - #[serde(default = "default_system_version")] - pub system_version: String, - - #[serde(default = "default_node_version")] - pub node_version: String, - - /// 外部 count_tokens API 地址(可选) - #[serde(default)] - pub count_tokens_api_url: Option, - - /// count_tokens API 密钥(可选) - #[serde(default)] - pub count_tokens_api_key: Option, - - /// count_tokens API 认证类型(可选,"x-api-key" 或 "bearer",默认 "x-api-key") - #[serde(default = "default_count_tokens_auth_type")] - pub count_tokens_auth_type: String, - - /// HTTP 代理地址(可选) - /// 支持格式: http://host:port, https://host:port, socks5://host:port - #[serde(default)] - pub proxy_url: Option, - - /// 代理认证用户名(可选) - #[serde(default)] - pub proxy_username: Option, - - /// 代理认证密码(可选) - #[serde(default)] - pub proxy_password: Option, - - /// Admin API 密钥(可选,启用 Admin API 功能) - #[serde(default)] - pub admin_api_key: Option, -} - -fn default_host() -> String { - "127.0.0.1".to_string() -} - -fn default_port() -> u16 { - 8080 -} - -fn default_region() -> String { - "us-east-1".to_string() -} - -fn default_kiro_version() -> String { - "0.8.0".to_string() -} - -fn default_system_version() -> String { - const SYSTEM_VERSIONS: &[&str] = &["darwin#24.6.0", "win32#10.0.22631"]; - SYSTEM_VERSIONS[fastrand::usize(..SYSTEM_VERSIONS.len())].to_string() -} - -fn default_node_version() -> String { - "22.21.1".to_string() -} - -fn default_count_tokens_auth_type() -> String { - "x-api-key".to_string() -} - -impl Default for Config { - fn default() -> Self { - Self { - host: default_host(), - port: default_port(), - region: default_region(), - kiro_version: default_kiro_version(), - machine_id: None, - api_key: None, - system_version: default_system_version(), - node_version: default_node_version(), - count_tokens_api_url: None, - count_tokens_api_key: None, - count_tokens_auth_type: default_count_tokens_auth_type(), - proxy_url: None, - proxy_username: None, - proxy_password: None, - admin_api_key: None, - } - } -} - -impl Config { - /// 获取默认配置文件路径 - pub fn default_config_path() -> &'static str { - "config.json" - } - - /// 从文件加载配置 - pub fn load>(path: P) -> anyhow::Result { - let path = path.as_ref(); - if !path.exists() { - // 配置文件不存在,返回默认配置 - return Ok(Self::default()); - } - - let content = fs::read_to_string(path)?; - let config: Config = serde_json::from_str(&content)?; - Ok(config) - } -} diff --git a/src/model/mod.rs b/src/model/mod.rs deleted file mode 100644 index 31871f75da8fed1d5bbe536edd8c542abd8bf51f..0000000000000000000000000000000000000000 --- a/src/model/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -//! 应用配置模型 - -pub mod arg; -pub mod config; diff --git a/src/test.rs b/src/test.rs deleted file mode 100644 index 48299fa3f6f0cb8caeecc76d264132cece075ce1..0000000000000000000000000000000000000000 --- a/src/test.rs +++ /dev/null @@ -1,107 +0,0 @@ -use futures::StreamExt; - -use crate::debug::{print_event, print_event_verbose, debug_crc, print_hex}; -use crate::kiro::model::credentials::KiroCredentials; -use crate::kiro::model::events::Event; -use crate::kiro::model::requests::KiroRequest; -use crate::kiro::parser::EventStreamDecoder; -use crate::kiro::provider::KiroProvider; -use crate::kiro::token_manager::TokenManager; -use crate::model::config::Config; - - -/// 调用流式 API 并实时打印返回 -pub(crate) async fn call_stream_api() -> anyhow::Result<()> { - // 读取 test.json 作为请求体 - let request_body = std::fs::read_to_string("test.json")?; - println!("已加载请求体,长度: {} 字节", request_body.len()); - - // 解析请求体为 KiroRequest 对象 - let request: KiroRequest = serde_json::from_str(&request_body)?; - println!("已解析请求对象:"); - println!(" 会话 ID: {}", request.conversation_id()); - println!(" 模型 ID: {}", request.model_id()); - println!(" 消息内容长度: {} 字符", request.current_content().len()); - if let Some(ref task_type) = request.conversation_state.agent_task_type { - println!(" 任务类型: {}", task_type); - } - if let Some(ref trigger_type) = request.conversation_state.chat_trigger_type { - println!(" 触发类型: {}", trigger_type); - } - println!(" 历史消息数: {}", request.conversation_state.history.len()); - println!(" 工具数量: {}", request.conversation_state.current_message.user_input_message.user_input_message_context.tools.len()); - - // 加载凭证 - let credentials = KiroCredentials::load_default()?; - println!("已加载凭证"); - - // 加载配置 - let config = Config::load_default()?; - println!("API 区域: {}", config.region); - - // 创建 TokenManager 和 KiroProvider - let token_manager = TokenManager::new(config, credentials); - let mut provider = KiroProvider::new(token_manager); - - println!("\n开始调用流式 API...\n"); - println!("{}", "=".repeat(60)); - - // 调用流式 API - let response = provider.call_api_stream(&request_body).await?; - - // 获取字节流 - let mut stream = response.bytes_stream(); - let mut decoder = EventStreamDecoder::new(); - - // 处理流式数据 - let mut total_bytes = 0usize; - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - // 调试模式:打印原始 hex 数据 - // println!("\n[收到数据块] {} 字节, 偏移 {}", chunk.len(), total_bytes); - // print_hex(&chunk); - // debug_crc(&chunk); - - total_bytes += chunk.len(); - - // 将数据喂给解码器 - if let Err(e) = decoder.feed(&chunk) { - eprintln!("[缓冲区错误] {}", e); - continue; - } - - // 解码所有可用的帧 - for result in decoder.decode_iter() { - match result { - Ok(frame) => { - // 解析事件 - match Event::from_frame(frame) { - Ok(event) => { - // 简洁输出 - // print_event(&event); - // 详细输出 (调试用) - print_event_verbose(&event); - } - Err(e) => eprintln!("[解析错误] {}", e), - } - } - Err(e) => { - eprintln!("[帧解析错误] {}", e); - } - } - } - } - Err(e) => { - eprintln!("[网络错误] {}", e); - break; - } - } - } - - println!("\n{}", "=".repeat(60)); - println!("流式响应结束"); - println!("共接收 {} 字节,解码 {} 帧", total_bytes, decoder.frames_decoded()); - - Ok(()) -} \ No newline at end of file diff --git a/src/token.rs b/src/token.rs deleted file mode 100644 index 7b0395f2619d25ac5b53e90d26089069c0174f5f..0000000000000000000000000000000000000000 --- a/src/token.rs +++ /dev/null @@ -1,242 +0,0 @@ -//! Token 计算模块 -//! -//! 提供文本 token 数量计算功能。 -//! -//! # 计算规则 -//! - 非西文字符:每个计 4.5 个字符单位 -//! - 西文字符:每个计 1 个字符单位 -//! - 4 个字符单位 = 1 token(四舍五入) - -use crate::anthropic::types::{ - CountTokensRequest, CountTokensResponse, Message, SystemMessage, Tool, -}; -use crate::http_client::{ProxyConfig, build_client}; -use std::sync::OnceLock; - -/// Count Tokens API 配置 -#[derive(Clone, Default)] -pub struct CountTokensConfig { - /// 外部 count_tokens API 地址 - pub api_url: Option, - /// count_tokens API 密钥 - pub api_key: Option, - /// count_tokens API 认证类型("x-api-key" 或 "bearer") - pub auth_type: String, - /// 代理配置 - pub proxy: Option, -} - -/// 全局配置存储 -static COUNT_TOKENS_CONFIG: OnceLock = OnceLock::new(); - -/// 初始化 count_tokens 配置 -/// -/// 应在应用启动时调用一次 -pub fn init_config(config: CountTokensConfig) { - let _ = COUNT_TOKENS_CONFIG.set(config); -} - -/// 获取配置 -fn get_config() -> Option<&'static CountTokensConfig> { - COUNT_TOKENS_CONFIG.get() -} - -/// 判断字符是否为非西文字符 -/// -/// 西文字符包括: -/// - ASCII 字符 (U+0000..U+007F) -/// - 拉丁字母扩展 (U+0080..U+024F) -/// - 拉丁字母扩展附加 (U+1E00..U+1EFF) -/// -/// 返回 true 表示该字符是非西文字符(如中文、日文、韩文、阿拉伯文等) -fn is_non_western_char(c: char) -> bool { - !matches!(c, - // 基本 ASCII - '\u{0000}'..='\u{007F}' | - // 拉丁字母扩展-A (Latin Extended-A) - '\u{0080}'..='\u{00FF}' | - // 拉丁字母扩展-B (Latin Extended-B) - '\u{0100}'..='\u{024F}' | - // 拉丁字母扩展附加 (Latin Extended Additional) - '\u{1E00}'..='\u{1EFF}' | - // 拉丁字母扩展-C/D/E - '\u{2C60}'..='\u{2C7F}' | - '\u{A720}'..='\u{A7FF}' | - '\u{AB30}'..='\u{AB6F}' - ) -} - -/// 计算文本的 token 数量 -/// -/// # 计算规则 -/// - 非西文字符:每个计 4.5 个字符单位 -/// - 西文字符:每个计 1 个字符单位 -/// - 4 个字符单位 = 1 token(四舍五入) -/// ``` -pub fn count_tokens(text: &str) -> u64 { - // println!("text: {}", text); - - let char_units: f64 = text - .chars() - .map(|c| if is_non_western_char(c) { 4.0 } else { 1.0 }) - .sum(); - - let tokens = char_units / 4.0; - - let acc_token = if tokens < 100.0 { - tokens * 1.5 - } else if tokens < 200.0 { - tokens * 1.3 - } else if tokens < 300.0 { - tokens * 1.25 - } else if tokens < 800.0 { - tokens * 1.2 - } else { - tokens * 1.0 - } as u64; - - // println!("tokens: {}, acc_tokens: {}", tokens, acc_token); - acc_token -} - -/// 估算请求的输入 tokens -/// -/// 优先调用远程 API,失败时回退到本地计算 -pub(crate) fn count_all_tokens( - model: String, - system: Option>, - messages: Vec, - tools: Option>, -) -> u64 { - // 检查是否配置了远程 API - if let Some(config) = get_config() { - if let Some(api_url) = &config.api_url { - // 尝试调用远程 API - let result = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(call_remote_count_tokens( - api_url, config, model, &system, &messages, &tools, - )) - }); - - match result { - Ok(tokens) => { - tracing::debug!("远程 count_tokens API 返回: {}", tokens); - return tokens; - } - Err(e) => { - tracing::warn!("远程 count_tokens API 调用失败,回退到本地计算: {}", e); - } - } - } - } - - // 本地计算 - count_all_tokens_local(system, messages, tools) -} - -/// 调用远程 count_tokens API -async fn call_remote_count_tokens( - api_url: &str, - config: &CountTokensConfig, - model: String, - system: &Option>, - messages: &Vec, - tools: &Option>, -) -> Result> { - let client = build_client(config.proxy.as_ref(), 300)?; - - // 构建请求体 - let request = CountTokensRequest { - model: model, // 模型名称用于 token 计算 - messages: messages.clone(), - system: system.clone(), - tools: tools.clone(), - }; - - // 构建请求 - let mut req_builder = client.post(api_url); - - // 设置认证头 - if let Some(api_key) = &config.api_key { - if config.auth_type == "bearer" { - req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key)); - } else { - req_builder = req_builder.header("x-api-key", api_key); - } - } - - // 发送请求 - let response = req_builder - .header("Content-Type", "application/json") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - return Err(format!("API 返回错误状态: {}", response.status()).into()); - } - - let result: CountTokensResponse = response.json().await?; - Ok(result.input_tokens as u64) -} - -/// 本地计算请求的输入 tokens -fn count_all_tokens_local( - system: Option>, - messages: Vec, - tools: Option>, -) -> u64 { - let mut total = 0; - - // 系统消息 - if let Some(ref system) = system { - for msg in system { - total += count_tokens(&msg.text); - } - } - - // 用户消息 - for msg in &messages { - if let serde_json::Value::String(s) = &msg.content { - total += count_tokens(s); - } else if let serde_json::Value::Array(arr) = &msg.content { - for item in arr { - if let Some(text) = item.get("text").and_then(|v| v.as_str()) { - total += count_tokens(text); - } - } - } - } - - // 工具定义 - if let Some(ref tools) = tools { - for tool in tools { - total += count_tokens(&tool.name); - total += count_tokens(&tool.description); - let input_schema_json = serde_json::to_string(&tool.input_schema).unwrap_or_default(); - total += count_tokens(&input_schema_json); - } - } - - total.max(1) -} - -/// 估算输出 tokens -pub(crate) fn estimate_output_tokens(content: &[serde_json::Value]) -> i32 { - let mut total = 0; - - for block in content { - if let Some(text) = block.get("text").and_then(|v| v.as_str()) { - total += count_tokens(text) as i32; - } - if block.get("type").and_then(|v| v.as_str()) == Some("tool_use") { - // 工具调用开销 - if let Some(input) = block.get("input") { - let input_str = serde_json::to_string(input).unwrap_or_default(); - total += count_tokens(&input_str) as i32; - } - } - } - - total.max(1) -} diff --git a/tools/event-viewer.html b/tools/event-viewer.html deleted file mode 100644 index 98cea731efa43878bfb899923a7d2ba928bc526f..0000000000000000000000000000000000000000 --- a/tools/event-viewer.html +++ /dev/null @@ -1,896 +0,0 @@ - - - - - - AWS Event Stream Viewer - - - -
-

AWS Event Stream Viewer

- -
- -
- - - - -
- -
- - - - - - -
- - - -