Upload 10 files
Browse files- Dockerfile +12 -0
- app.py +452 -0
- claude_converter.py +386 -0
- claude_parser.py +222 -0
- claude_stream.py +145 -0
- claude_types.py +20 -0
- config.py +40 -0
- replicate.py +199 -0
- requirements.txt +5 -0
- utils.py +53 -0
Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Slim Python 3.11 base keeps the image small.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so Docker layer caching survives code edits.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["python", "app.py"]
app.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import traceback
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
import asyncio
|
| 7 |
+
import importlib.util
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
|
| 10 |
+
|
| 11 |
+
from contextlib import asynccontextmanager
|
| 12 |
+
from fastapi import FastAPI, Depends, HTTPException, Header
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
from fastapi.responses import StreamingResponse
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
import httpx
|
| 17 |
+
import hashlib
|
| 18 |
+
|
| 19 |
+
from utils import get_proxies, create_proxy_mounts
|
| 20 |
+
|
| 21 |
+
# ------------------------------------------------------------------------------
# Bootstrap
# ------------------------------------------------------------------------------

# Directory containing this file; sibling modules and .env are resolved from it.
BASE_DIR = Path(__file__).resolve().parent

load_dotenv(BASE_DIR / ".env")

app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")

# CORS for simple testing in browser.
# NOTE(review): wildcard origins/methods/headers are wide open — fine for local
# testing, tighten before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------------------
# Dynamic import of replicate.py to avoid package __init__ needs
# ------------------------------------------------------------------------------

def _load_replicate_module():
    """Load replicate.py via importlib so no package __init__ is required.

    Returns:
        The executed module object, exposing ``send_chat_request``.

    Raises:
        AssertionError: if the file cannot be located by importlib.
    """
    mod_path = BASE_DIR / "replicate.py"
    spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
    # Validate the spec BEFORE using it: the original called
    # module_from_spec(spec) first and only asserted afterwards, so a missing
    # file produced a confusing TypeError instead of this clear assertion.
    assert spec is not None and spec.loader is not None, f"Cannot load {mod_path}"
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

_replicate = _load_replicate_module()
send_chat_request = _replicate.send_chat_request
|
| 53 |
+
|
| 54 |
+
# ------------------------------------------------------------------------------
# Dynamic import of Claude modules
# ------------------------------------------------------------------------------

def _load_claude_modules():
    """Dynamically load the Claude helper modules (types, converter, stream).

    Each spec is validated before use (consistent with _load_replicate_module)
    so a missing file fails with a clear assertion instead of an
    AttributeError on a None spec/loader.

    Returns:
        Tuple of (claude_types, claude_converter, claude_stream) modules.
    """
    import sys

    def _spec_and_module(alias: str, filename: str):
        # One place for the spec/module boilerplate shared by all three loads.
        spec = importlib.util.spec_from_file_location(alias, str(BASE_DIR / filename))
        assert spec is not None and spec.loader is not None, f"Cannot load {filename}"
        return spec, importlib.util.module_from_spec(spec)

    spec_types, mod_types = _spec_and_module("v2_claude_types", "claude_types.py")
    spec_types.loader.exec_module(mod_types)

    spec_conv, mod_conv = _spec_and_module("v2_claude_converter", "claude_converter.py")
    # claude_converter falls back to importing "v2.claude_types"; register the
    # already-executed types module under that name before executing it.
    sys.modules["v2.claude_types"] = mod_types
    spec_conv.loader.exec_module(mod_conv)

    spec_stream, mod_stream = _spec_and_module("v2_claude_stream", "claude_stream.py")
    spec_stream.loader.exec_module(mod_stream)

    return mod_types, mod_conv, mod_stream

_claude_types, _claude_converter, _claude_stream = _load_claude_modules()
ClaudeRequest = _claude_types.ClaudeRequest
convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
|
| 84 |
+
|
| 85 |
+
# ------------------------------------------------------------------------------
# Global HTTP Client
# ------------------------------------------------------------------------------

# Shared AsyncClient, created at startup and closed at shutdown.
GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None

async def _init_global_client():
    """Create the shared httpx.AsyncClient with pool limits and timeouts."""
    global GLOBAL_CLIENT
    mounts = create_proxy_mounts()
    # Increased limits for high concurrency with streaming
    # max_connections: overall connection cap
    # max_keepalive_connections: connections kept alive in the pool
    # keepalive_expiry: how long idle connections are retained
    limits = httpx.Limits(
        max_keepalive_connections=60,
        max_connections=60,  # NOTE(review): original comment claimed 500 for higher concurrency; actual cap is 60
        keepalive_expiry=30.0  # release idle connections after 30 seconds
    )
    # Longer timeouts for streaming responses.
    timeout = httpx.Timeout(
        connect=30.0,  # connect timeout — TLS handshake needs enough time
        read=300.0,  # read timeout (streaming responses need much longer)
        write=30.0,  # write timeout
        pool=10.0  # timeout for acquiring a connection from the pool
    )
    # Only pass the mounts parameter when proxies are actually configured.
    if mounts:
        GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
    else:
        GLOBAL_CLIENT = httpx.AsyncClient(timeout=timeout, limits=limits)
|
| 116 |
+
async def _close_global_client():
    """Close and drop the shared AsyncClient (no-op when none exists)."""
    global GLOBAL_CLIENT
    client, GLOBAL_CLIENT = GLOBAL_CLIENT, None
    if client:
        await client.aclose()
|
| 121 |
+
|
| 122 |
+
# ------------------------------------------------------------------------------
|
| 123 |
+
# Token 缓存和管理
|
| 124 |
+
# ------------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
# 内存缓存: {hash: {accessToken, refreshToken, clientId, clientSecret, lastRefresh}}
|
| 127 |
+
TOKEN_MAP: Dict[str, Dict[str, Any]] = {}
|
| 128 |
+
|
| 129 |
+
def _sha256(text: str) -> str:
|
| 130 |
+
"""计算 SHA256 哈希"""
|
| 131 |
+
return hashlib.sha256(text.encode()).hexdigest()
|
| 132 |
+
|
| 133 |
+
def _parse_bearer_token(bearer_token: str) -> Tuple[str, str, str]:
|
| 134 |
+
"""
|
| 135 |
+
解析 Bearer token: clientId:clientSecret:refreshToken
|
| 136 |
+
重要: refreshToken 中可能包含冒号,所以要正确处理
|
| 137 |
+
"""
|
| 138 |
+
temp_array = bearer_token.split(":")
|
| 139 |
+
client_id = temp_array[0] if len(temp_array) > 0 else ""
|
| 140 |
+
client_secret = temp_array[1] if len(temp_array) > 1 else ""
|
| 141 |
+
refresh_token = ":".join(temp_array[2:]) if len(temp_array) > 2 else ""
|
| 142 |
+
return client_id, client_secret, refresh_token
|
| 143 |
+
|
| 144 |
+
async def _handle_token_refresh(client_id: str, client_secret: str, refresh_token: str) -> Optional[str]:
    """Exchange a refresh token for a new access token via the OIDC endpoint.

    Returns the new access token string, or None when the refresh fails
    (errors are printed, never raised to the caller).
    """
    payload = {
        "grantType": "refresh_token",
        "clientId": client_id,
        "clientSecret": client_secret,
        "refreshToken": refresh_token,
    }

    try:
        if GLOBAL_CLIENT:
            response = await GLOBAL_CLIENT.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
        else:
            # No shared client yet (e.g. during startup): use a throwaway one.
            async with httpx.AsyncClient(timeout=60.0) as temp_client:
                response = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
        response.raise_for_status()
        return response.json().get("accessToken")
    except httpx.HTTPStatusError as e:
        print(f"Token refresh HTTP error: {e.response.status_code} - {e.response.text}")
        traceback.print_exc()
        return None
    except Exception as e:
        print(f"Token refresh error: {e}")
        traceback.print_exc()
        return None
|
| 174 |
+
|
| 175 |
+
# ------------------------------------------------------------------------------
# Global token refresher
# ------------------------------------------------------------------------------

async def _global_token_refresher():
    """Background task: refresh every cached token every 45 minutes.

    Runs forever. Per-token failures are logged and skipped so one bad
    credential does not stall the whole cycle; on an unexpected error the
    loop backs off for one minute and tries again.
    """
    while True:
        try:
            await asyncio.sleep(45 * 60)  # 45 minutes
            if not TOKEN_MAP:
                continue
            print(f"[Token Refresher] Starting token refresh cycle...")
            refresh_count = 0
            # list() snapshot: TOKEN_MAP may gain/lose entries while we await.
            for hash_key, token_data in list(TOKEN_MAP.items()):
                try:
                    new_token = await _handle_token_refresh(
                        token_data["clientId"],
                        token_data["clientSecret"],
                        token_data["refreshToken"]
                    )
                    if new_token:
                        TOKEN_MAP[hash_key]["accessToken"] = new_token
                        TOKEN_MAP[hash_key]["lastRefresh"] = time.time()
                        refresh_count += 1
                    else:
                        print(f"[Token Refresher] Failed to refresh token for hash: {hash_key[:8]}...")
                except Exception as e:
                    print(f"[Token Refresher] Exception refreshing token: {e}")
                    traceback.print_exc()
            print(f"[Token Refresher] Refreshed {refresh_count}/{len(TOKEN_MAP)} tokens")
        except Exception:
            traceback.print_exc()
            await asyncio.sleep(60)  # back off 1 minute after an unexpected error
|
| 208 |
+
|
| 209 |
+
# ------------------------------------------------------------------------------
|
| 210 |
+
# Token refresh (OIDC)
|
| 211 |
+
# ------------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
|
| 214 |
+
TOKEN_URL = f"{OIDC_BASE}/token"
|
| 215 |
+
|
| 216 |
+
def _oidc_headers() -> Dict[str, str]:
|
| 217 |
+
return {
|
| 218 |
+
"content-type": "application/json",
|
| 219 |
+
"user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
|
| 220 |
+
"x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
|
| 221 |
+
"amz-sdk-request": "attempt=1; max=3",
|
| 222 |
+
"amz-sdk-invocation-id": str(uuid.uuid4()),
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
# ------------------------------------------------------------------------------
# Authentication middleware
# ------------------------------------------------------------------------------

async def auth_middleware(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
    """Resolve the Authorization header into account credentials.

    Bearer token format: ``clientId:clientSecret:refreshToken``. The first
    time a token is seen, the refresh token is exchanged for an access token
    and the result is cached under the SHA-256 of the raw bearer token;
    subsequent requests hit the cache (kept fresh by the background refresher).

    Raises:
        HTTPException(401): missing/malformed header, bad token format, or a
            failed token refresh.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    bearer_token = authorization[7:]  # strip the "Bearer " prefix
    token_hash = _sha256(bearer_token)

    # Cache hit: reuse previously refreshed credentials.
    cached = TOKEN_MAP.get(token_hash)
    if cached is not None:
        return {
            "accessToken": cached["accessToken"],
            "clientId": cached["clientId"],
            "clientSecret": cached["clientSecret"],
            "refreshToken": cached["refreshToken"],
        }

    # First sight of this token: parse and validate its three parts.
    client_id, client_secret, refresh_token = _parse_bearer_token(bearer_token)
    if not (client_id and client_secret and refresh_token):
        raise HTTPException(status_code=401, detail="Invalid token format. Expected: clientId:clientSecret:refreshToken")

    access_token = await _handle_token_refresh(client_id, client_secret, refresh_token)
    if not access_token:
        raise HTTPException(status_code=401, detail="Failed to refresh access token")

    # Cache for later requests and for the background refresher.
    TOKEN_MAP[token_hash] = {
        "accessToken": access_token,
        "refreshToken": refresh_token,
        "clientId": client_id,
        "clientSecret": client_secret,
        "lastRefresh": time.time()
    }

    return {
        "accessToken": access_token,
        "clientId": client_id,
        "clientSecret": client_secret,
        "refreshToken": refresh_token,
    }
|
| 275 |
+
|
| 276 |
+
# ------------------------------------------------------------------------------
# Dependencies
# ------------------------------------------------------------------------------

async def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
    """FastAPI dependency: delegate header validation to auth_middleware."""
    account = await auth_middleware(authorization)
    return account
|
| 282 |
+
|
| 283 |
+
# ------------------------------------------------------------------------------
# Claude Messages API endpoint
# ------------------------------------------------------------------------------

@app.post("/v1/messages")
async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(require_account)):
    """
    Claude-compatible messages endpoint.

    Converts the Claude-format request into an Amazon Q request, streams the
    upstream response, and either relays it as SSE (req.stream) or folds the
    SSE events back into a single Claude message object.
    """
    # 1. Convert request
    try:
        aq_request = convert_claude_to_amazonq_request(req)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")

    # 2. Send upstream - always stream from upstream to get full event details
    try:
        access = account.get("accessToken")
        if not access:
            raise HTTPException(status_code=502, detail="Access token unavailable")

        # We call with stream=True to get the event iterator.
        # messages=[] because the fully converted payload travels in raw_payload.
        _, _, tracker, event_iter = await send_chat_request(
            access_token=access,
            messages=[],
            model=req.model,
            stream=True,
            client=GLOBAL_CLIENT,
            raw_payload=aq_request
        )

        if not event_iter:
            raise HTTPException(status_code=502, detail="No event stream returned")

        # Handler
        # Estimate input tokens (simple count or 0)
        # NOTE(review): input token estimation is not implemented, so usage in
        # the response under-reports input_tokens.
        input_tokens = 0
        handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens)

        async def event_generator():
            # Translate upstream events into Claude SSE strings.
            try:
                async for event_type, payload in event_iter:
                    async for sse in handler.handle_event(event_type, payload):
                        yield sse
                async for sse in handler.finish():
                    yield sse
            except GeneratorExit:
                # Client disconnected
                raise
            except Exception:
                # NOTE(review): re-raised without logging; errors surface to caller.
                raise

        if req.stream:
            return StreamingResponse(event_generator(), media_type="text/event-stream")
        else:
            # Non-streaming: consume our own SSE stream and rebuild a full
            # message object. Parsing the SSE strings back is inefficient but
            # avoids refactoring the handler to yield structured objects.

            content_blocks = []  # NOTE(review): unused — final_content below is the accumulator
            usage = {"input_tokens": 0, "output_tokens": 0}
            stop_reason = None

            # Accumulated content blocks, addressed by the SSE "index" field.
            final_content = []

            async for sse_line in event_generator():
                if sse_line.startswith("data: "):
                    data_str = sse_line[6:].strip()
                    if data_str == "[DONE]": continue
                    try:
                        data = json.loads(data_str)
                        dtype = data.get("type")
                        if dtype == "content_block_start":
                            idx = data.get("index", 0)
                            # Grow the list so final_content[idx] is addressable.
                            while len(final_content) <= idx:
                                final_content.append(None)
                            final_content[idx] = data.get("content_block")
                        elif dtype == "content_block_delta":
                            idx = data.get("index", 0)
                            delta = data.get("delta", {})
                            # NOTE(review): assumes content_block_start arrived
                            # first for this index; a delta-before-start would
                            # raise IndexError, swallowed by the outer except.
                            if final_content[idx]:
                                if delta.get("type") == "text_delta":
                                    final_content[idx]["text"] += delta.get("text", "")
                                elif delta.get("type") == "input_json_delta":
                                    # tool_use input streams as JSON fragments;
                                    # accumulate the raw string here and parse
                                    # it once the block stops.
                                    if "partial_json" not in final_content[idx]:
                                        final_content[idx]["partial_json"] = ""
                                    final_content[idx]["partial_json"] += delta.get("partial_json", "")
                        elif dtype == "content_block_stop":
                            idx = data.get("index", 0)
                            # If tool use, parse the accumulated JSON input.
                            if final_content[idx] and final_content[idx]["type"] == "tool_use":
                                if "partial_json" in final_content[idx]:
                                    try:
                                        final_content[idx]["input"] = json.loads(final_content[idx]["partial_json"])
                                    except:
                                        pass
                                    del final_content[idx]["partial_json"]
                        elif dtype == "message_delta":
                            usage = data.get("usage", usage)
                            stop_reason = data.get("delta", {}).get("stop_reason")
                    except:
                        # Malformed SSE payloads are skipped deliberately.
                        pass

            return {
                "id": f"msg_{uuid.uuid4()}",
                "type": "message",
                "role": "assistant",
                "model": req.model,
                "content": [c for c in final_content if c is not None],
                "stop_reason": stop_reason,
                "stop_sequence": None,
                "usage": usage
            }

    except Exception as e:
        raise
|
| 412 |
+
|
| 413 |
+
# ------------------------------------------------------------------------------
# Startup / Shutdown Events
# ------------------------------------------------------------------------------

# Strong references to background tasks: asyncio.create_task only keeps a weak
# reference, so an otherwise-unreferenced task can be garbage-collected while
# still running (the original discarded the task handle).
_background_tasks: set = set()

async def _startup():
    """Initialize the global HTTP client and start background tasks."""
    await _init_global_client()
    task = asyncio.create_task(_global_token_refresher())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

async def _shutdown():
    """Cancel background tasks and release resources."""
    for task in list(_background_tasks):
        task.cancel()
    await _close_global_client()

@asynccontextmanager
async def lifespan(app_instance: FastAPI):
    """
    Manage application lifecycle: run startup logic before serving requests
    and shutdown cleanup afterwards.
    """
    await _startup()
    yield
    await _shutdown()

# Attach the lifespan context to the already-created app instance.
app.router.lifespan_context = lifespan
|
| 439 |
+
|
| 440 |
+
# ------------------------------------------------------------------------------
# Direct execution support
# ------------------------------------------------------------------------------

if __name__ == "__main__":
    import uvicorn

    # PORT env var overrides the default 8000.
    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
|
claude_converter.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import uuid
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import List, Dict, Any, Optional, Union
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from .claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
|
| 8 |
+
except ImportError:
|
| 9 |
+
# Fallback for dynamic loading where relative import might fail
|
| 10 |
+
# We assume claude_types is available in sys.modules or we can import it directly if in same dir
|
| 11 |
+
import sys
|
| 12 |
+
if "v2.claude_types" in sys.modules:
|
| 13 |
+
from v2.claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
|
| 14 |
+
else:
|
| 15 |
+
# Try absolute import assuming v2 is in path or current dir
|
| 16 |
+
try:
|
| 17 |
+
from claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
|
| 18 |
+
except ImportError:
|
| 19 |
+
# Last resort: if loaded via importlib in app.py, we might need to rely on app.py injecting it
|
| 20 |
+
# But app.py loads this module.
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
def get_current_timestamp() -> str:
    """Return the current local time formatted for Amazon Q.

    Format: "<Weekday>, <ISO-8601 timestamp with milliseconds and UTC offset>".
    """
    moment = datetime.now().astimezone()
    return "{}, {}".format(moment.strftime("%A"), moment.isoformat(timespec='milliseconds'))
|
| 29 |
+
|
| 30 |
+
def map_model_name(claude_model: str) -> str:
    """Map a Claude model name onto the Amazon Q model ID to use.

    Any name starting with "claude-sonnet-4.5" or "claude-sonnet-4-5"
    (case-insensitive) maps to "claude-sonnet-4.5"; everything else falls
    back to "claude-sonnet-4".
    """
    normalized = claude_model.lower()
    # startswith accepts a tuple — one call covers both spellings.
    if normalized.startswith(("claude-sonnet-4.5", "claude-sonnet-4-5")):
        return "claude-sonnet-4.5"
    return "claude-sonnet-4"
|
| 36 |
+
|
| 37 |
+
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Flatten Claude message content into plain text.

    Strings pass through unchanged. Lists contribute the "text" field of
    every dict block whose type is "text", joined with newlines. Any other
    value yields the empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    return ""
|
| 49 |
+
|
| 50 |
+
def process_tool_result_block(block: Dict[str, Any], tool_results: List[Dict[str, Any]]) -> None:
    """Convert one Claude ``tool_result`` block into Amazon Q form.

    Extracted text content is appended to an existing entry with the same
    toolUseId, or added as a new entry with the block's status (defaulting
    to "success"). An all-blank result is replaced with a cancellation
    placeholder.

    Args:
        block: a content block of type "tool_result".
        tool_results: mutable list accumulating converted results.
    """
    tool_use_id = block.get("tool_use_id")
    raw_content = block.get("content", [])

    converted: List[Dict[str, Any]] = []
    if isinstance(raw_content, str):
        converted.append({"text": raw_content})
    elif isinstance(raw_content, list):
        for item in raw_content:
            if isinstance(item, str):
                converted.append({"text": item})
            elif isinstance(item, dict):
                if item.get("type") == "text":
                    converted.append({"text": item.get("text", "")})
                elif "text" in item:
                    converted.append({"text": item["text"]})

    # No usable text at all means the tool call never produced output.
    if not any(entry.get("text", "").strip() for entry in converted):
        converted = [{"text": "Tool use was cancelled by the user"}]

    # Merge with a previous result for the same tool call, if one exists.
    for existing in tool_results:
        if existing["toolUseId"] == tool_use_id:
            existing["content"].extend(converted)
            return

    tool_results.append({
        "toolUseId": tool_use_id,
        "content": converted,
        "status": block.get("status", "success"),
    })
|
| 87 |
+
|
| 88 |
+
def extract_images_from_content(content: Union[str, List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
    """Pull base64 image blocks out of Claude content, in Amazon Q format.

    Returns None when content is not a list or contains no base64 images.
    The image format is derived from the media type ("image/png" -> "png").
    """
    if not isinstance(content, list):
        return None

    collected = []
    for block in content:
        if not (isinstance(block, dict) and block.get("type") == "image"):
            continue
        source = block.get("source", {})
        if source.get("type") != "base64":
            continue
        media_type = source.get("media_type", "image/png")
        image_format = media_type.rsplit("/", 1)[-1] if "/" in media_type else "png"
        collected.append({
            "format": image_format,
            "source": {"bytes": source.get("data", "")},
        })
    return collected or None
|
| 107 |
+
|
| 108 |
+
def convert_tool(tool: ClaudeTool) -> Dict[str, Any]:
    """Convert a Claude tool definition into the Amazon Q tool spec shape.

    Descriptions longer than 10240 characters are truncated to the first
    10100 characters plus a pointer to the full documentation section.
    """
    description = tool.description or ""
    if len(description) > 10240:
        description = (
            description[:10100]
            + "\n\n...(Full description provided in TOOL DOCUMENTATION section)"
        )

    spec = {
        "name": tool.name,
        "description": description,
        "inputSchema": {"json": tool.input_schema},
    }
    return {"toolSpecification": spec}
|
| 121 |
+
|
| 122 |
+
def merge_user_messages(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collapse consecutive user messages into one Amazon Q user message.

    Text bodies are joined with blank lines; context, origin, and model id
    come from the first message that provides each. Only the images attached
    to the last two image-bearing messages are retained. Returns {} for an
    empty input.
    """
    if not messages:
        return {}

    texts: List[str] = []
    per_message_images: List[List[Dict[str, Any]]] = []
    context = None
    origin = None
    model_id = None

    for message in messages:
        if context is None:
            context = message.get("userInputMessageContext", {})
        if origin is None:
            origin = message.get("origin", "CLI")
        if model_id is None:
            model_id = message.get("modelId")

        body = message.get("content", "")
        if body:
            texts.append(body)

        images = message.get("images")
        if images:
            per_message_images.append(images)

    merged = {
        "content": "\n\n".join(texts),
        "userInputMessageContext": context or {},
        "origin": origin or "CLI",
        "modelId": model_id
    }

    # Keep only the images belonging to the last two image-bearing messages.
    retained = [img for group in per_message_images[-2:] for img in group]
    if retained:
        merged["images"] = retained

    return merged
|
| 166 |
+
|
| 167 |
+
def process_history(messages: List[ClaudeMessage]) -> List[Dict[str, Any]]:
    """Process history messages to match Amazon Q format (alternating user/assistant).

    Two passes: first convert each Claude message into an Amazon Q history
    entry, then merge runs of consecutive user messages so the history
    strictly alternates user/assistant. tool_use ids are deduplicated
    across the whole history (first occurrence wins).
    """
    history = []
    seen_tool_use_ids = set()  # tool_use ids already emitted, for dedupe

    raw_history = []

    # First pass: convert individual messages
    for msg in messages:
        if msg.role == "user":
            content = msg.content
            text_content = ""
            tool_results = None
            images = extract_images_from_content(content)

            if isinstance(content, list):
                # Block-form content: gather text blocks and tool_result
                # blocks separately; other block types are ignored here.
                text_parts = []
                for block in content:
                    if isinstance(block, dict):
                        btype = block.get("type")
                        if btype == "text":
                            text_parts.append(block.get("text", ""))
                        elif btype == "tool_result":
                            if tool_results is None:
                                tool_results = []
                            process_tool_result_block(block, tool_results)
                text_content = "\n".join(text_parts)
            else:
                text_content = extract_text_from_content(content)

            # NOTE(review): envState is a hard-coded placeholder, not the
            # caller's real environment — confirm this is intentional.
            user_ctx = {
                "envState": {
                    "operatingSystem": "macos",
                    "currentWorkingDirectory": "/"
                }
            }
            if tool_results:
                user_ctx["toolResults"] = tool_results

            u_msg = {
                "content": text_content,
                "userInputMessageContext": user_ctx,
                "origin": "CLI"
            }
            if images:
                u_msg["images"] = images

            raw_history.append({"userInputMessage": u_msg})

        elif msg.role == "assistant":
            content = msg.content
            text_content = extract_text_from_content(content)

            entry = {
                "assistantResponseMessage": {
                    "messageId": str(uuid.uuid4()),
                    "content": text_content
                }
            }

            if isinstance(content, list):
                # Carry over tool_use blocks, skipping ids seen before.
                tool_uses = []
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "tool_use":
                        tid = block.get("id")
                        if tid and tid not in seen_tool_use_ids:
                            seen_tool_use_ids.add(tid)
                            tool_uses.append({
                                "toolUseId": tid,
                                "name": block.get("name"),
                                "input": block.get("input", {})
                            })
                if tool_uses:
                    entry["assistantResponseMessage"]["toolUses"] = tool_uses

            raw_history.append(entry)

    # Second pass: merge consecutive user messages
    pending_user_msgs = []
    for item in raw_history:
        if "userInputMessage" in item:
            pending_user_msgs.append(item["userInputMessage"])
        elif "assistantResponseMessage" in item:
            # Flush accumulated user messages as one merged entry before
            # appending the assistant turn, keeping strict alternation.
            if pending_user_msgs:
                merged = merge_user_messages(pending_user_msgs)
                history.append({"userInputMessage": merged})
                pending_user_msgs = []
            history.append(item)

    # Flush a trailing run of user messages with no assistant reply after it.
    if pending_user_msgs:
        merged = merge_user_messages(pending_user_msgs)
        history.append({"userInputMessage": merged})

    return history
|
| 261 |
+
|
| 262 |
+
def convert_claude_to_amazonq_request(req: ClaudeRequest, conversation_id: Optional[str] = None) -> Dict[str, Any]:
    """Convert ClaudeRequest to Amazon Q request body.

    Assembles the full conversationState payload:
      * tool specifications (tools with over-long descriptions get their
        full text relocated into a TOOL DOCUMENTATION prompt section),
      * the current message, built from the last user message and wrapped
        with timestamp / system-prompt context markers,
      * the converted history (all messages before the last one).

    A fresh conversationId is generated when none is supplied.
    """
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())

    # 1. Tools
    aq_tools = []
    long_desc_tools = []  # tools whose description exceeds the 10240-char cap
    if req.tools:
        for t in req.tools:
            if t.description and len(t.description) > 10240:
                # Keep the full text so it can be injected into the prompt below.
                long_desc_tools.append({"name": t.name, "full_description": t.description})
            aq_tools.append(convert_tool(t))

    # 2. Current Message (last user message)
    last_msg = req.messages[-1] if req.messages else None
    prompt_content = ""
    tool_results = None
    has_tool_result = False
    images = None

    if last_msg and last_msg.role == "user":
        content = last_msg.content
        images = extract_images_from_content(content)

        if isinstance(content, list):
            # Block-form content: collect text and tool_result blocks separately.
            text_parts = []
            for block in content:
                if isinstance(block, dict):
                    btype = block.get("type")
                    if btype == "text":
                        text_parts.append(block.get("text", ""))
                    elif btype == "tool_result":
                        has_tool_result = True
                        if tool_results is None:
                            tool_results = []
                        process_tool_result_block(block, tool_results)
            prompt_content = "\n".join(text_parts)
        else:
            prompt_content = extract_text_from_content(content)

    # 3. Context
    # NOTE(review): envState is a hard-coded placeholder, not the caller's
    # real environment — confirm this is intentional.
    user_ctx = {
        "envState": {
            "operatingSystem": "macos",
            "currentWorkingDirectory": "/"
        }
    }
    if aq_tools:
        user_ctx["tools"] = aq_tools
    if tool_results:
        user_ctx["toolResults"] = tool_results

    # 4. Format Content
    formatted_content = ""
    if has_tool_result and not prompt_content:
        # Pure tool-result turn: results travel in user_ctx, no prompt text.
        formatted_content = ""
    else:
        formatted_content = (
            "--- CONTEXT ENTRY BEGIN ---\n"
            f"Current time: {get_current_timestamp()}\n"
            "--- CONTEXT ENTRY END ---\n\n"
            "--- USER MESSAGE BEGIN ---\n"
            f"{prompt_content}\n"
            "--- USER MESSAGE END ---"
        )

    # Prepend full documentation for tools whose descriptions were truncated.
    if long_desc_tools:
        docs = []
        for info in long_desc_tools:
            docs.append(f"Tool: {info['name']}\nFull Description:\n{info['full_description']}\n")
        formatted_content = (
            "--- TOOL DOCUMENTATION BEGIN ---\n"
            f"{''.join(docs)}"
            "--- TOOL DOCUMENTATION END ---\n\n"
            f"{formatted_content}"
        )

    # Prepend the system prompt (string or text-block list) when present.
    # Skipped when formatted_content is empty (pure tool-result turn).
    if req.system and formatted_content:
        sys_text = ""
        if isinstance(req.system, str):
            sys_text = req.system
        elif isinstance(req.system, list):
            parts = []
            for b in req.system:
                if isinstance(b, dict) and b.get("type") == "text":
                    parts.append(b.get("text", ""))
            sys_text = "\n".join(parts)

        if sys_text:
            formatted_content = (
                "--- SYSTEM PROMPT BEGIN ---\n"
                f"{sys_text}\n"
                "--- SYSTEM PROMPT END ---\n\n"
                f"{formatted_content}"
            )

    # 5. Model
    model_id = map_model_name(req.model)

    # 6. User Input Message
    user_input_msg = {
        "content": formatted_content,
        "userInputMessageContext": user_ctx,
        "origin": "CLI",
        "modelId": model_id
    }
    if images:
        user_input_msg["images"] = images

    # 7. History — everything before the last message.
    history_msgs = req.messages[:-1] if len(req.messages) > 1 else []
    aq_history = process_history(history_msgs)

    # 8. Final Body
    return {
        "conversationState": {
            "conversationId": conversation_id,
            "history": aq_history,
            "currentMessage": {
                "userInputMessage": user_input_msg
            },
            "chatTriggerType": "MANUAL"
        }
    }
|
claude_parser.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import struct
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional, Dict, Any, AsyncIterator
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class EventStreamParser:
    """AWS Event Stream binary format parser (v2 style).

    Message layout: 4-byte big-endian total length, 4-byte headers length,
    4-byte prelude CRC, headers, payload, 4-byte trailing message CRC.
    CRCs are not verified here.
    """

    @staticmethod
    def parse_headers(headers_data: bytes) -> Dict[str, str]:
        """Parse event stream headers into a name -> value mapping.

        Each header is: 1-byte name length, name, 1-byte value type,
        2-byte value length, value. Type 7 (string) values are decoded as
        UTF-8; other value types are returned as raw bytes.

        NOTE(review): fixed-width value types (booleans, integers) in the
        AWS spec carry no 2-byte length prefix; this parser assumes one is
        always present, which holds for the string headers seen in
        practice. Malformed/truncated input terminates the loop early
        rather than raising.
        """
        headers = {}
        offset = 0

        # (Removed a redundant bounds check that duplicated the loop
        # condition and could never fire.)
        while offset < len(headers_data):
            name_length = headers_data[offset]
            offset += 1

            if offset + name_length > len(headers_data):
                break
            name = headers_data[offset:offset + name_length].decode('utf-8')
            offset += name_length

            if offset >= len(headers_data):
                break
            value_type = headers_data[offset]
            offset += 1

            if offset + 2 > len(headers_data):
                break
            value_length = struct.unpack('>H', headers_data[offset:offset + 2])[0]
            offset += 2

            if offset + value_length > len(headers_data):
                break

            if value_type == 7:  # string header
                value = headers_data[offset:offset + value_length].decode('utf-8')
            else:
                value = headers_data[offset:offset + value_length]

            offset += value_length
            headers[name] = value

        return headers

    @staticmethod
    def parse_message(data: bytes) -> Optional[Dict[str, Any]]:
        """Parse a single complete Event Stream message.

        Returns a dict with 'headers', 'payload' (JSON-decoded when
        possible, raw bytes otherwise) and 'total_length', or None when
        the buffer is too short or malformed.
        """
        try:
            if len(data) < 16:  # 12-byte prelude + 4-byte trailing CRC minimum
                return None

            total_length = struct.unpack('>I', data[0:4])[0]
            headers_length = struct.unpack('>I', data[4:8])[0]

            if len(data) < total_length:
                logger.warning(f"Incomplete message: expected {total_length} bytes, got {len(data)}")
                return None

            headers_data = data[12:12 + headers_length]
            headers = EventStreamParser.parse_headers(headers_data)

            # Payload sits between the headers and the 4-byte message CRC.
            payload_start = 12 + headers_length
            payload_end = total_length - 4
            payload_data = data[payload_start:payload_end]

            payload = None
            if payload_data:
                try:
                    payload = json.loads(payload_data.decode('utf-8'))
                except (json.JSONDecodeError, UnicodeDecodeError):
                    payload = payload_data  # non-JSON payload: hand back raw bytes

            return {
                'headers': headers,
                'payload': payload,
                'total_length': total_length
            }

        except Exception as e:
            logger.error(f"Failed to parse message: {e}", exc_info=True)
            return None

    @staticmethod
    async def parse_stream(byte_stream: AsyncIterator[bytes]) -> AsyncIterator[Dict[str, Any]]:
        """Re-frame an async byte stream into parsed Event Stream messages."""
        buffer = bytearray()

        async for chunk in byte_stream:
            buffer.extend(chunk)

            # Drain every complete message currently in the buffer.
            while len(buffer) >= 12:
                try:
                    total_length = struct.unpack('>I', buffer[0:4])[0]
                except struct.error:
                    break

                if len(buffer) < total_length:
                    break  # wait for more bytes

                message_data = bytes(buffer[:total_length])
                buffer = buffer[total_length:]

                message = EventStreamParser.parse_message(message_data)
                if message:
                    yield message
|
| 112 |
+
|
| 113 |
+
def extract_event_info(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pull the event/content/message type headers plus payload out of a
    parsed Event Stream message, accepting both ':'-prefixed and bare
    header names."""
    hdrs = message.get('headers', {})

    def pick(base: str):
        # AWS uses ':'-prefixed system header names; tolerate bare ones too.
        return hdrs.get(f":{base}") or hdrs.get(base)

    return {
        'event_type': pick('event-type'),
        'content_type': pick('content-type'),
        'message_type': pick('message-type'),
        'payload': message.get('payload'),
    }
|
| 128 |
+
|
| 129 |
+
def _sse_format(event_type: str, data: Dict[str, Any]) -> str:
|
| 130 |
+
"""Format SSE event."""
|
| 131 |
+
json_data = json.dumps(data, ensure_ascii=False)
|
| 132 |
+
return f"event: {event_type}\ndata: {json_data}\n\n"
|
| 133 |
+
|
| 134 |
+
def build_message_start(conversation_id: str, model: str = "claude-sonnet-4.5", input_tokens: int = 0) -> str:
    """Build the Claude 'message_start' SSE event opening an assistant message."""
    usage = {"input_tokens": input_tokens, "output_tokens": 0}
    message = {
        "id": conversation_id,
        "type": "message",
        "role": "assistant",
        "content": [],
        "model": model,
        "stop_reason": None,
        "stop_sequence": None,
        "usage": usage,
    }
    return _sse_format("message_start", {"type": "message_start", "message": message})
|
| 150 |
+
|
| 151 |
+
def build_content_block_start(index: int, block_type: str = "text") -> str:
    """Build a 'content_block_start' SSE event; text blocks open with empty text."""
    block = {"type": block_type}
    if block_type == "text":
        block["text"] = ""
    return _sse_format("content_block_start", {
        "type": "content_block_start",
        "index": index,
        "content_block": block,
    })
|
| 159 |
+
|
| 160 |
+
def build_content_block_delta(index: int, text: str) -> str:
    """Build a 'content_block_delta' SSE event carrying a text_delta fragment."""
    delta = {"type": "text_delta", "text": text}
    return _sse_format(
        "content_block_delta",
        {"type": "content_block_delta", "index": index, "delta": delta},
    )
|
| 168 |
+
|
| 169 |
+
def build_content_block_stop(index: int) -> str:
    """Build a 'content_block_stop' SSE event closing block *index*."""
    return _sse_format("content_block_stop", {"type": "content_block_stop", "index": index})
|
| 176 |
+
|
| 177 |
+
def build_ping() -> str:
    """Build a keep-alive 'ping' SSE event."""
    return _sse_format("ping", {"type": "ping"})
|
| 181 |
+
|
| 182 |
+
def build_message_stop(input_tokens: int, output_tokens: int, stop_reason: Optional[str] = None) -> str:
    """Build the closing 'message_delta' + 'message_stop' SSE event pair.

    *input_tokens* is accepted for interface symmetry but is not emitted:
    the message_delta usage block carries output tokens only.
    """
    reason = stop_reason or "end_turn"
    delta_event = _sse_format("message_delta", {
        "type": "message_delta",
        "delta": {"stop_reason": reason, "stop_sequence": None},
        "usage": {"output_tokens": output_tokens},
    })
    stop_event = _sse_format("message_stop", {"type": "message_stop"})
    return delta_event + stop_event
|
| 197 |
+
|
| 198 |
+
def build_tool_use_start(index: int, tool_use_id: str, tool_name: str) -> str:
    """Build a 'content_block_start' SSE event opening a tool_use block."""
    content_block = {
        "type": "tool_use",
        "id": tool_use_id,
        "name": tool_name,
        "input": {},
    }
    return _sse_format(
        "content_block_start",
        {"type": "content_block_start", "index": index, "content_block": content_block},
    )
|
| 211 |
+
|
| 212 |
+
def build_tool_use_input_delta(index: int, input_json_delta: str) -> str:
    """Build a 'content_block_delta' SSE event carrying a partial tool-input JSON string."""
    return _sse_format("content_block_delta", {
        "type": "content_block_delta",
        "index": index,
        "delta": {"type": "input_json_delta", "partial_json": input_json_delta},
    })
|
claude_stream.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import AsyncGenerator, Optional, Dict, Any, List, Set
|
| 5 |
+
|
| 6 |
+
from utils import load_module
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
# Load the SSE builder helpers from claude_parser.py via utils.load_module
# (registered under the aliased module name "v2_claude_parser") and re-export
# them at module level for local use.
_parser = load_module("v2_claude_parser", "claude_parser.py")
build_message_start = _parser.build_message_start
build_content_block_start = _parser.build_content_block_start
build_content_block_delta = _parser.build_content_block_delta
build_content_block_stop = _parser.build_content_block_stop
build_ping = _parser.build_ping
build_message_stop = _parser.build_message_stop
build_tool_use_start = _parser.build_tool_use_start
build_tool_use_input_delta = _parser.build_tool_use_input_delta
|
| 19 |
+
|
| 20 |
+
class ClaudeStreamHandler:
    """Translates Amazon Q streaming events into Claude-style SSE events.

    Tracks content-block indexing and open/close state across text and
    tool_use blocks so the emitted event sequence follows the Claude
    streaming protocol (message_start, content_block_* events, then
    message_delta/message_stop from finish()).
    """

    def __init__(self, model: str, input_tokens: int = 0):
        self.model = model
        self.input_tokens = input_tokens
        self.response_buffer: List[str] = []   # accumulated assistant text, for the token estimate
        self.content_block_index: int = -1     # index of the current block; -1 = none opened yet
        self.content_block_started: bool = False
        self.content_block_start_sent: bool = False
        self.content_block_stop_sent: bool = False
        self.message_start_sent: bool = False
        self.conversation_id: Optional[str] = None

        # Tool use state
        self.current_tool_use: Optional[Dict[str, Any]] = None
        self.tool_input_buffer: List[str] = []  # JSON fragments of the current tool input
        self.tool_use_id: Optional[str] = None
        self.tool_name: Optional[str] = None
        self._processed_tool_use_ids: Set[str] = set()
        self.all_tool_inputs: List[str] = []    # completed tool inputs, for the token estimate

    async def handle_event(self, event_type: str, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Process a single Amazon Q event and yield Claude SSE events."""

        # 1. Message Start (initial-response)
        if event_type == "initial-response":
            if not self.message_start_sent:
                conv_id = payload.get('conversationId', self.conversation_id or 'unknown')
                self.conversation_id = conv_id
                yield build_message_start(conv_id, self.model, self.input_tokens)
                self.message_start_sent = True
                yield build_ping()

        # 2. Content Block Delta (assistantResponseEvent)
        elif event_type == "assistantResponseEvent":
            content = payload.get("content", "")

            # Close any open tool use block
            if self.current_tool_use and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
                self.current_tool_use = None

            # Start content block if needed
            # NOTE(review): after a completed tool_use block,
            # content_block_start_sent remains True, so no fresh text block
            # is opened here and text deltas reuse the already-stopped
            # index — verify against real event sequences.
            if not self.content_block_start_sent:
                self.content_block_index += 1
                yield build_content_block_start(self.content_block_index, "text")
                self.content_block_start_sent = True
                self.content_block_started = True

            # Send delta
            if content:
                self.response_buffer.append(content)
                yield build_content_block_delta(self.content_block_index, content)

        # 3. Tool Use (toolUseEvent)
        elif event_type == "toolUseEvent":
            tool_use_id = payload.get("toolUseId")
            tool_name = payload.get("name")
            tool_input = payload.get("input", {})
            is_stop = payload.get("stop", False)

            # Start new tool use
            if tool_use_id and tool_name and not self.current_tool_use:
                # Close previous text block if open
                if self.content_block_start_sent and not self.content_block_stop_sent:
                    yield build_content_block_stop(self.content_block_index)
                    self.content_block_stop_sent = True

                self._processed_tool_use_ids.add(tool_use_id)
                self.content_block_index += 1

                yield build_tool_use_start(self.content_block_index, tool_use_id, tool_name)

                self.content_block_started = True
                self.current_tool_use = {"toolUseId": tool_use_id, "name": tool_name}
                self.tool_use_id = tool_use_id
                self.tool_name = tool_name
                self.tool_input_buffer = []
                self.content_block_stop_sent = False
                self.content_block_start_sent = True

            # Accumulate input — string fragments pass through as-is,
            # structured fragments are re-serialized to JSON.
            if self.current_tool_use and tool_input:
                fragment = ""
                if isinstance(tool_input, str):
                    fragment = tool_input
                else:
                    fragment = json.dumps(tool_input, ensure_ascii=False)

                self.tool_input_buffer.append(fragment)
                yield build_tool_use_input_delta(self.content_block_index, fragment)

            # Stop tool use — flush accumulated input and reset state.
            if is_stop and self.current_tool_use:
                full_input = "".join(self.tool_input_buffer)
                self.all_tool_inputs.append(full_input)

                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
                self.content_block_started = False
                self.current_tool_use = None
                self.tool_use_id = None
                self.tool_name = None
                self.tool_input_buffer = []

        # 4. Assistant Response End (assistantResponseEnd)
        elif event_type == "assistantResponseEnd":
            # Close any open block
            if self.content_block_started and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True

    async def finish(self) -> AsyncGenerator[str, None]:
        """Send final events."""
        # Ensure last block is closed
        if self.content_block_started and not self.content_block_stop_sent:
            yield build_content_block_stop(self.content_block_index)
            self.content_block_stop_sent = True

        # Calculate output tokens (approximate)
        full_text = "".join(self.response_buffer)
        full_tool_input = "".join(self.all_tool_inputs)
        # Simple approximation: 4 chars per token
        output_tokens = max(1, (len(full_text) + len(full_tool_input)) // 4)

        yield build_message_stop(self.input_tokens, output_tokens, "end_turn")
|
claude_types.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Dict, Any, Literal
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
class ClaudeMessage(BaseModel):
    """One Claude chat message: a role plus string or block-list content."""

    role: str  # consumers branch on "user" / "assistant"
    content: Union[str, List[Dict[str, Any]]]  # plain text, or content blocks (text / tool_use / tool_result / image)
|
| 7 |
+
|
| 8 |
+
class ClaudeTool(BaseModel):
    """A tool definition as supplied in the Claude Messages API request."""

    name: str
    description: Optional[str] = ""  # may be empty; over-long descriptions are relocated by the converter
    input_schema: Dict[str, Any]  # JSON Schema of the tool's input
|
| 12 |
+
|
| 13 |
+
class ClaudeRequest(BaseModel):
    """Incoming Claude Messages API request body."""

    model: str
    messages: List[ClaudeMessage]
    max_tokens: int = 8192
    temperature: Optional[float] = None
    tools: Optional[List[ClaudeTool]] = None
    stream: bool = False  # whether the client asked for SSE streaming
    system: Optional[Union[str, List[Dict[str, Any]]]] = None  # system prompt: string or list of text blocks
|
config.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Amazon Q API 配置文件
|
| 3 |
+
包含请求模板和默认配置
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Amazon Q API 端点
|
| 7 |
+
AMAZONQ_API_URL = "https://q.us-east-1.amazonaws.com/"
|
| 8 |
+
|
| 9 |
+
# 默认请求头模板
|
| 10 |
+
DEFAULT_HEADERS = {
|
| 11 |
+
"content-type": "application/x-amz-json-1.0",
|
| 12 |
+
"x-amz-target": "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
|
| 13 |
+
"user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/codewhispererstreaming/0.1.11582 os/windows lang/rust/1.87.0 md/appVersion-1.19.4 app/AmazonQ-For-CLI",
|
| 14 |
+
"x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/codewhispererstreaming/0.1.11582 os/windows lang/rust/1.87.0 m/F app/AmazonQ-For-CLI",
|
| 15 |
+
"x-amzn-codewhisperer-optout": "false",
|
| 16 |
+
"amz-sdk-request": "attempt=1; max=3"
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
# 默认请求体模板(仅作为结构参考,实际使用时会被 raw_payload 替换)
|
| 20 |
+
DEFAULT_BODY_TEMPLATE = {
|
| 21 |
+
"conversationState": {
|
| 22 |
+
"conversationId": "", # 运行时动态生成
|
| 23 |
+
"history": [],
|
| 24 |
+
"currentMessage": {
|
| 25 |
+
"userInputMessage": {
|
| 26 |
+
"content": "",
|
| 27 |
+
"userInputMessageContext": {
|
| 28 |
+
"envState": {
|
| 29 |
+
"operatingSystem": "windows",
|
| 30 |
+
"currentWorkingDirectory": ""
|
| 31 |
+
},
|
| 32 |
+
"tools": []
|
| 33 |
+
},
|
| 34 |
+
"origin": "CLI",
|
| 35 |
+
"modelId": "claude-sonnet-4"
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"chatTriggerType": "MANUAL"
|
| 39 |
+
}
|
| 40 |
+
}
|
replicate.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import uuid
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Dict, Optional, Tuple, List, AsyncGenerator, Any
|
| 5 |
+
import httpx
|
| 6 |
+
|
| 7 |
+
from utils import get_proxies, load_module, create_proxy_mounts
|
| 8 |
+
from config import AMAZONQ_API_URL, DEFAULT_HEADERS
|
| 9 |
+
|
| 10 |
+
# Load the shared Event Stream parser via utils.load_module (aliased as
# "v2_claude_parser"); degrade to None on failure so import of this module
# never fails — callers must check for None before use.
try:
    _parser = load_module("v2_claude_parser", "claude_parser.py")
    EventStreamParser = _parser.EventStreamParser
    extract_event_info = _parser.extract_event_info
except Exception as e:
    print(f"Warning: Failed to load claude_parser: {e}")
    EventStreamParser = None
    extract_event_info = None
|
| 18 |
+
|
| 19 |
+
class StreamTracker:
    """Pass-through wrapper for an async stream that records whether any
    truthy item flowed through it."""

    def __init__(self):
        # Flipped to True on the first non-empty item observed.
        self.has_content = False

    async def track(self, gen: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
        """Yield every item from *gen* unchanged, noting truthy ones."""
        async for chunk in gen:
            self.has_content = self.has_content or bool(chunk)
            yield chunk
|
| 28 |
+
|
| 29 |
+
def load_template() -> Tuple[str, Dict[str, str]]:
    """Load the Amazon Q API request template.

    Returns:
        (url, headers): the API endpoint URL and a fresh copy of the
        default request headers (copied so callers may mutate safely).
    """
    headers = DEFAULT_HEADERS.copy()
    return AMAZONQ_API_URL, headers
|
| 37 |
+
|
| 38 |
+
def _merge_headers(as_log: Dict[str, str], bearer_token: str) -> Dict[str, str]:
|
| 39 |
+
headers = dict(as_log)
|
| 40 |
+
for k in list(headers.keys()):
|
| 41 |
+
kl = k.lower()
|
| 42 |
+
if kl in ("content-length","host","connection","transfer-encoding"):
|
| 43 |
+
headers.pop(k, None)
|
| 44 |
+
def set_header(name: str, value: str):
|
| 45 |
+
for key in list(headers.keys()):
|
| 46 |
+
if key.lower() == name.lower():
|
| 47 |
+
del headers[key]
|
| 48 |
+
headers[name] = value
|
| 49 |
+
set_header("Authorization", f"Bearer {bearer_token}")
|
| 50 |
+
set_header("amz-sdk-invocation-id", str(uuid.uuid4()))
|
| 51 |
+
return headers
|
| 52 |
+
|
| 53 |
+
async def send_chat_request(
    access_token: str,
    messages: List[Dict[str, Any]],
    model: Optional[str] = None,
    stream: bool = False,
    timeout: Tuple[int, int] = (30, 300),
    client: Optional[httpx.AsyncClient] = None,
    raw_payload: Optional[Dict[str, Any]] = None
) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker, Optional[AsyncGenerator[Any, None]]]:
    """
    Send a chat request to the Amazon Q API.

    Args:
        access_token: Amazon Q access token (sent as a Bearer token).
        messages: message list (deprecated — use raw_payload).
        model: model name (deprecated — use raw_payload).
        stream: whether to return a streaming event generator.
        timeout: (write, read) timeout tuple in seconds.
        client: optional pre-built HTTP client; when None a local one is
            created and closed by this function.
        raw_payload: request body already converted from the Claude API
            format (required).

    Returns:
        ``(None, None, tracker, events)`` where ``events`` is an async
        generator of ``(event_type, payload)`` tuples when ``stream`` is
        True, and ``None`` when the response was fully consumed inline.

    Raises:
        ValueError: if raw_payload is missing.
        httpx.HTTPError: on an upstream HTTP error (status >= 400).
    """
    if raw_payload is None:
        raise ValueError("raw_payload is required")

    url, headers_from_log = load_template()
    # Fresh invocation id per request (also re-set case-insensitively below).
    headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())

    # Use raw payload (for Claude API).
    # NOTE(review): body_json aliases raw_payload, so the conversationId
    # injection below mutates the caller's dict in place.
    body_json = raw_payload
    # Ensure conversationId is set if missing.
    if "conversationState" in body_json and "conversationId" not in body_json["conversationState"]:
        body_json["conversationState"]["conversationId"] = str(uuid.uuid4())

    payload_str = json.dumps(body_json, ensure_ascii=False)
    headers = _merge_headers(headers_from_log, access_token)

    # local_client marks whether we own the client and must close it ourselves.
    local_client = False
    if client is None:
        local_client = True
        mounts = create_proxy_mounts()
        # Increase the connect timeout to avoid TLS handshake timeouts.
        # NOTE(review): write uses timeout[0] and read uses timeout[1],
        # matching the (write, read) meaning documented above.
        timeout_config = httpx.Timeout(connect=60.0, read=timeout[1], write=timeout[0], pool=10.0)
        # Only pass `mounts` when a proxy is actually configured.
        if mounts:
            client = httpx.AsyncClient(mounts=mounts, timeout=timeout_config)
        else:
            client = httpx.AsyncClient(timeout=timeout_config)

    # Build and send the request manually so we control the stream lifetime
    # (response stays open until the event generator is exhausted or closed).
    req = client.build_request("POST", url, headers=headers, content=payload_str)

    resp = None
    try:
        resp = await client.send(req, stream=True)

        if resp.status_code >= 400:
            # Read the error body for diagnostics, then release resources
            # before raising so nothing leaks on the error path.
            try:
                await resp.read()
                err = resp.text
            except Exception:
                err = f"HTTP {resp.status_code}"
            await resp.aclose()
            if local_client:
                await client.aclose()
            raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")

        tracker = StreamTracker()

        # Track if the response has been consumed to avoid double-close.
        response_consumed = False

        async def _iter_events() -> AsyncGenerator[Any, None]:
            """Parse the AWS event-stream body into (event_type, payload) tuples."""
            nonlocal response_consumed
            try:
                # Use EventStreamParser from claude_parser.py.
                # NOTE(review): assumes EventStreamParser loaded successfully;
                # if the import fell back to None this raises at parse time.
                async def byte_gen():
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            yield chunk

                async for message in EventStreamParser.parse_stream(byte_gen()):
                    event_info = extract_event_info(message)
                    if event_info:
                        event_type = event_info.get('event_type')
                        payload = event_info.get('payload')
                        if event_type and payload:
                            yield (event_type, payload)
            except Exception:
                # Swallow mid-stream errors only if some content already got
                # through; otherwise surface the failure to the caller.
                if not tracker.has_content:
                    raise
            finally:
                # Single authoritative cleanup point once iteration ends.
                response_consumed = True
                await resp.aclose()
                if local_client:
                    await client.aclose()

        if stream:
            # Wrap the generator to ensure cleanup on early termination.
            async def _safe_iter_events():
                try:
                    # Safety net: hard 300-second cap on the whole stream.
                    async with asyncio.timeout(300):
                        async for item in _iter_events():
                            yield item
                except asyncio.TimeoutError:
                    # Force-close on timeout.
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
                except GeneratorExit:
                    # Generator was closed without being fully consumed.
                    # Ensure cleanup happens even if the finally block in
                    # _iter_events wasn't reached.
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
                except Exception:
                    # Any other exception should also trigger cleanup.
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
            return None, None, tracker, _safe_iter_events()
        else:
            # Non-streaming: drain all events inline; tracker records content.
            try:
                async for _ in _iter_events():
                    pass
            finally:
                # Ensure the response is closed even if iteration was cut
                # short before _iter_events' own finally ran.
                # NOTE(review): client close is guarded by response_consumed
                # here — assumed nested under the same condition; confirm
                # against the original indentation.
                if not response_consumed and resp:
                    await resp.aclose()
                    if local_client:
                        await client.aclose()
            return None, None, tracker, None

    except Exception:
        # Critical: close the response on any exception raised before the
        # generators took ownership of the cleanup.
        if resp and not resp.is_closed:
            await resp.aclose()
        if local_client and client:
            await client.aclose()
        raise
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.5
|
| 2 |
+
uvicorn[standard]==0.32.0
|
| 3 |
+
pydantic==2.9.2
|
| 4 |
+
python-dotenv==1.0.1
|
| 5 |
+
httpx==0.28.1
|
utils.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""公共工具函数"""
|
| 2 |
+
import os
|
| 3 |
+
import importlib.util
|
| 4 |
+
import httpx
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_proxies() -> Optional[Dict[str, str]]:
    """Read the proxy configuration from the environment.

    Uses the ``HTTP_PROXY`` environment variable for both schemes.

    Returns:
        ``{"http": url, "https": url}`` when a non-blank proxy is set,
        otherwise ``None``.
    """
    proxy_url = os.getenv("HTTP_PROXY", "").strip()
    if not proxy_url:
        return None
    return {"http": proxy_url, "https": proxy_url}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_module(module_name: str, file_name: str):
    """Dynamically load a sibling source file as a module.

    Args:
        module_name: name to register the loaded module under.
        file_name: filename relative to this source file's directory.

    Returns:
        The fully executed module object.
    """
    module_path = Path(__file__).resolve().parent / file_name
    spec = importlib.util.spec_from_file_location(module_name, str(module_path))
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_proxy_mounts() -> Optional[Dict[str, httpx.AsyncHTTPTransport]]:
    """Build httpx transport mounts for the configured proxy.

    Returns:
        A mount dict routing both ``https://`` and ``http://`` traffic
        through the proxy, or ``None`` when no proxy is configured.
    """
    proxies = get_proxies()
    if not proxies:
        return None
    proxy_url = proxies.get("https") or proxies.get("http")
    if not proxy_url:
        return None
    # One independent transport per scheme, mirroring the original wiring.
    return {
        scheme: httpx.AsyncHTTPTransport(proxy=proxy_url)
        for scheme in ("https://", "http://")
    }
|