devme committed on
Commit
7ea1434
·
verified ·
1 Parent(s): 10c515f

Upload 10 files

Browse files
Files changed (10) hide show
  1. Dockerfile +12 -0
  2. app.py +452 -0
  3. claude_converter.py +386 -0
  4. claude_parser.py +222 -0
  5. claude_stream.py +145 -0
  6. claude_types.py +20 -0
  7. config.py +40 -0
  8. replicate.py +199 -0
  9. requirements.txt +5 -0
  10. utils.py +53 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 8000
11
+
12
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import traceback
4
+ import uuid
5
+ import time
6
+ import asyncio
7
+ import importlib.util
8
+ from pathlib import Path
9
+ from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
10
+
11
+ from contextlib import asynccontextmanager
12
+ from fastapi import FastAPI, Depends, HTTPException, Header
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.responses import StreamingResponse
15
+ from dotenv import load_dotenv
16
+ import httpx
17
+ import hashlib
18
+
19
+ from utils import get_proxies, create_proxy_mounts
20
+
21
+ # ------------------------------------------------------------------------------
22
+ # Bootstrap
23
+ # ------------------------------------------------------------------------------
24
+
25
+ BASE_DIR = Path(__file__).resolve().parent
26
+
27
+ load_dotenv(BASE_DIR / ".env")
28
+
29
+ app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")
30
+
31
+ # CORS for simple testing in browser
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=["*"],
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
+
39
+ # ------------------------------------------------------------------------------
40
+ # Dynamic import of replicate.py to avoid package __init__ needs
41
+ # ------------------------------------------------------------------------------
42
+
43
+ def _load_replicate_module():
44
+ mod_path = BASE_DIR / "replicate.py"
45
+ spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
46
+ module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
47
+ assert spec is not None and spec.loader is not None
48
+ spec.loader.exec_module(module) # type: ignore[attr-defined]
49
+ return module
50
+
51
+ _replicate = _load_replicate_module()
52
+ send_chat_request = _replicate.send_chat_request
53
+
54
+ # ------------------------------------------------------------------------------
55
+ # Dynamic import of Claude modules
56
+ # ------------------------------------------------------------------------------
57
+
58
+ def _load_claude_modules():
59
+ # claude_types
60
+ spec_types = importlib.util.spec_from_file_location("v2_claude_types", str(BASE_DIR / "claude_types.py"))
61
+ mod_types = importlib.util.module_from_spec(spec_types)
62
+ spec_types.loader.exec_module(mod_types)
63
+
64
+ # claude_converter
65
+ spec_conv = importlib.util.spec_from_file_location("v2_claude_converter", str(BASE_DIR / "claude_converter.py"))
66
+ mod_conv = importlib.util.module_from_spec(spec_conv)
67
+
68
+ import sys
69
+ sys.modules["v2.claude_types"] = mod_types
70
+
71
+ spec_conv.loader.exec_module(mod_conv)
72
+
73
+ # claude_stream
74
+ spec_stream = importlib.util.spec_from_file_location("v2_claude_stream", str(BASE_DIR / "claude_stream.py"))
75
+ mod_stream = importlib.util.module_from_spec(spec_stream)
76
+ spec_stream.loader.exec_module(mod_stream)
77
+
78
+ return mod_types, mod_conv, mod_stream
79
+
80
+ _claude_types, _claude_converter, _claude_stream = _load_claude_modules()
81
+ ClaudeRequest = _claude_types.ClaudeRequest
82
+ convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
83
+ ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
84
+
85
+ # ------------------------------------------------------------------------------
86
+ # Global HTTP Client
87
+ # ------------------------------------------------------------------------------
88
+
89
+ GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
90
+
91
+ async def _init_global_client():
92
+ global GLOBAL_CLIENT
93
+ mounts = create_proxy_mounts()
94
+ # Increased limits for high concurrency with streaming
95
+ # max_connections: 总连接数上限
96
+ # max_keepalive_connections: 保持活跃的连接数
97
+ # keepalive_expiry: 连接保持时间
98
+ limits = httpx.Limits(
99
+ max_keepalive_connections=60,
100
+ max_connections=60, # 提高到500以支持更高并发
101
+ keepalive_expiry=30.0 # 30秒后释放空闲连接
102
+ )
103
+ # 为流式响应设置更长的超时
104
+ timeout = httpx.Timeout(
105
+ connect=30.0, # 连接超时,TLS 握手需要足够时间
106
+ read=300.0, # 读取超时(流式响应需要更长时间)
107
+ write=30.0, # 写入超时
108
+ pool=10.0 # 从连接池获取连接的超时时间
109
+ )
110
+ # 只在有代理时才传递 mounts 参数
111
+ if mounts:
112
+ GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
113
+ else:
114
+ GLOBAL_CLIENT = httpx.AsyncClient(timeout=timeout, limits=limits)
115
+
116
+ async def _close_global_client():
117
+ global GLOBAL_CLIENT
118
+ if GLOBAL_CLIENT:
119
+ await GLOBAL_CLIENT.aclose()
120
+ GLOBAL_CLIENT = None
121
+
122
+ # ------------------------------------------------------------------------------
123
+ # Token 缓存和管理
124
+ # ------------------------------------------------------------------------------
125
+
126
+ # 内存缓存: {hash: {accessToken, refreshToken, clientId, clientSecret, lastRefresh}}
127
+ TOKEN_MAP: Dict[str, Dict[str, Any]] = {}
128
+
129
+ def _sha256(text: str) -> str:
130
+ """计算 SHA256 哈希"""
131
+ return hashlib.sha256(text.encode()).hexdigest()
132
+
133
+ def _parse_bearer_token(bearer_token: str) -> Tuple[str, str, str]:
134
+ """
135
+ 解析 Bearer token: clientId:clientSecret:refreshToken
136
+ 重要: refreshToken 中可能包含冒号,所以要正确处理
137
+ """
138
+ temp_array = bearer_token.split(":")
139
+ client_id = temp_array[0] if len(temp_array) > 0 else ""
140
+ client_secret = temp_array[1] if len(temp_array) > 1 else ""
141
+ refresh_token = ":".join(temp_array[2:]) if len(temp_array) > 2 else ""
142
+ return client_id, client_secret, refresh_token
143
+
144
+ async def _handle_token_refresh(client_id: str, client_secret: str, refresh_token: str) -> Optional[str]:
145
+ """刷新 access token"""
146
+ payload = {
147
+ "grantType": "refresh_token",
148
+ "clientId": client_id,
149
+ "clientSecret": client_secret,
150
+ "refreshToken": refresh_token,
151
+ }
152
+
153
+ try:
154
+ client = GLOBAL_CLIENT
155
+ if not client:
156
+ async with httpx.AsyncClient(timeout=60.0) as temp_client:
157
+ r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
158
+ r.raise_for_status()
159
+ data = r.json()
160
+ else:
161
+ r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
162
+ r.raise_for_status()
163
+ data = r.json()
164
+
165
+ return data.get("accessToken")
166
+ except httpx.HTTPStatusError as e:
167
+ print(f"Token refresh HTTP error: {e.response.status_code} - {e.response.text}")
168
+ traceback.print_exc()
169
+ return None
170
+ except Exception as e:
171
+ print(f"Token refresh error: {e}")
172
+ traceback.print_exc()
173
+ return None
174
+
175
+ # ------------------------------------------------------------------------------
176
+ # 全局 Token 刷新器
177
+ # ------------------------------------------------------------------------------
178
+
179
+ async def _global_token_refresher():
180
+ """全局刷新器: 每 45 分钟刷新所有缓存的 token"""
181
+ while True:
182
+ try:
183
+ await asyncio.sleep(45 * 60) # 45 minutes
184
+ if not TOKEN_MAP:
185
+ continue
186
+ print(f"[Token Refresher] Starting token refresh cycle...")
187
+ refresh_count = 0
188
+ for hash_key, token_data in list(TOKEN_MAP.items()):
189
+ try:
190
+ new_token = await _handle_token_refresh(
191
+ token_data["clientId"],
192
+ token_data["clientSecret"],
193
+ token_data["refreshToken"]
194
+ )
195
+ if new_token:
196
+ TOKEN_MAP[hash_key]["accessToken"] = new_token
197
+ TOKEN_MAP[hash_key]["lastRefresh"] = time.time()
198
+ refresh_count += 1
199
+ else:
200
+ print(f"[Token Refresher] Failed to refresh token for hash: {hash_key[:8]}...")
201
+ except Exception as e:
202
+ print(f"[Token Refresher] Exception refreshing token: {e}")
203
+ traceback.print_exc()
204
+ print(f"[Token Refresher] Refreshed {refresh_count}/{len(TOKEN_MAP)} tokens")
205
+ except Exception:
206
+ traceback.print_exc()
207
+ await asyncio.sleep(60) # 发生异常时等待 1 分钟后重试
208
+
209
+ # ------------------------------------------------------------------------------
210
+ # Token refresh (OIDC)
211
+ # ------------------------------------------------------------------------------
212
+
213
+ OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
214
+ TOKEN_URL = f"{OIDC_BASE}/token"
215
+
216
+ def _oidc_headers() -> Dict[str, str]:
217
+ return {
218
+ "content-type": "application/json",
219
+ "user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
220
+ "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
221
+ "amz-sdk-request": "attempt=1; max=3",
222
+ "amz-sdk-invocation-id": str(uuid.uuid4()),
223
+ }
224
+
225
+ # ------------------------------------------------------------------------------
226
+ # 认证中间件
227
+ # ------------------------------------------------------------------------------
228
+
229
+ async def auth_middleware(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
230
+ """
231
+ 认证中间件: 解析 Bearer token 并返回账户信息
232
+ Bearer token 格式: clientId:clientSecret:refreshToken
233
+ """
234
+ if not authorization or not authorization.startswith("Bearer "):
235
+ raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
236
+
237
+ bearer_token = authorization[7:] # 移除 "Bearer " 前缀
238
+ token_hash = _sha256(bearer_token)
239
+
240
+ # 检查缓存
241
+ if token_hash in TOKEN_MAP:
242
+ return {
243
+ "accessToken": TOKEN_MAP[token_hash]["accessToken"],
244
+ "clientId": TOKEN_MAP[token_hash]["clientId"],
245
+ "clientSecret": TOKEN_MAP[token_hash]["clientSecret"],
246
+ "refreshToken": TOKEN_MAP[token_hash]["refreshToken"],
247
+ }
248
+
249
+ # 解析 bearer token
250
+ client_id, client_secret, refresh_token = _parse_bearer_token(bearer_token)
251
+
252
+ if not client_id or not client_secret or not refresh_token:
253
+ raise HTTPException(status_code=401, detail="Invalid token format. Expected: clientId:clientSecret:refreshToken")
254
+
255
+ # 刷新 token
256
+ access_token = await _handle_token_refresh(client_id, client_secret, refresh_token)
257
+ if not access_token:
258
+ raise HTTPException(status_code=401, detail="Failed to refresh access token")
259
+
260
+ # 缓存
261
+ TOKEN_MAP[token_hash] = {
262
+ "accessToken": access_token,
263
+ "refreshToken": refresh_token,
264
+ "clientId": client_id,
265
+ "clientSecret": client_secret,
266
+ "lastRefresh": time.time()
267
+ }
268
+
269
+ return {
270
+ "accessToken": access_token,
271
+ "clientId": client_id,
272
+ "clientSecret": client_secret,
273
+ "refreshToken": refresh_token,
274
+ }
275
+
276
+ # ------------------------------------------------------------------------------
277
+ # Dependencies
278
+ # ------------------------------------------------------------------------------
279
+
280
+ async def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
281
+ return await auth_middleware(authorization)
282
+
283
+ # ------------------------------------------------------------------------------
284
+ # Claude Messages API endpoint
285
+ # ------------------------------------------------------------------------------
286
+
287
+ @app.post("/v1/messages")
288
+ async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(require_account)):
289
+ """
290
+ Claude-compatible messages endpoint.
291
+ """
292
+ # 1. Convert request
293
+ try:
294
+ aq_request = convert_claude_to_amazonq_request(req)
295
+ except Exception as e:
296
+ traceback.print_exc()
297
+ raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")
298
+
299
+ # 2. Send upstream - always stream from upstream to get full event details
300
+ try:
301
+ access = account.get("accessToken")
302
+ if not access:
303
+ raise HTTPException(status_code=502, detail="Access token unavailable")
304
+
305
+ # We call with stream=True to get the event iterator
306
+ _, _, tracker, event_iter = await send_chat_request(
307
+ access_token=access,
308
+ messages=[],
309
+ model=req.model,
310
+ stream=True,
311
+ client=GLOBAL_CLIENT,
312
+ raw_payload=aq_request
313
+ )
314
+
315
+ if not event_iter:
316
+ raise HTTPException(status_code=502, detail="No event stream returned")
317
+
318
+ # Handler
319
+ # Estimate input tokens (simple count or 0)
320
+ # For now 0 or simple len
321
+ input_tokens = 0
322
+ handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens)
323
+
324
+ async def event_generator():
325
+ try:
326
+ async for event_type, payload in event_iter:
327
+ async for sse in handler.handle_event(event_type, payload):
328
+ yield sse
329
+ async for sse in handler.finish():
330
+ yield sse
331
+ except GeneratorExit:
332
+ # Client disconnected
333
+ raise
334
+ except Exception:
335
+ raise
336
+
337
+ if req.stream:
338
+ return StreamingResponse(event_generator(), media_type="text/event-stream")
339
+ else:
340
+ # Accumulate for non-streaming
341
+ # This is a bit complex because we need to reconstruct the full response object
342
+ # For now, let's just support streaming as it's the main use case for Claude Code
343
+ # But to be nice, let's try to support non-streaming by consuming the generator
344
+
345
+ content_blocks = []
346
+ usage = {"input_tokens": 0, "output_tokens": 0}
347
+ stop_reason = None
348
+
349
+ # We need to parse the SSE strings back to objects... inefficient but works
350
+ # Or we could refactor handler to yield objects.
351
+ # For now, let's just raise error for non-streaming or implement basic text
352
+ # Claude Code uses streaming.
353
+
354
+ # Let's implement a basic accumulator from the SSE stream
355
+ final_content = []
356
+
357
+ async for sse_line in event_generator():
358
+ if sse_line.startswith("data: "):
359
+ data_str = sse_line[6:].strip()
360
+ if data_str == "[DONE]": continue
361
+ try:
362
+ data = json.loads(data_str)
363
+ dtype = data.get("type")
364
+ if dtype == "content_block_start":
365
+ idx = data.get("index", 0)
366
+ while len(final_content) <= idx:
367
+ final_content.append(None)
368
+ final_content[idx] = data.get("content_block")
369
+ elif dtype == "content_block_delta":
370
+ idx = data.get("index", 0)
371
+ delta = data.get("delta", {})
372
+ if final_content[idx]:
373
+ if delta.get("type") == "text_delta":
374
+ final_content[idx]["text"] += delta.get("text", "")
375
+ elif delta.get("type") == "input_json_delta":
376
+ # We need to accumulate partial json
377
+ # But wait, content_block for tool_use has 'input' as dict?
378
+ # No, in start it is empty.
379
+ # We need to track partial json string
380
+ if "partial_json" not in final_content[idx]:
381
+ final_content[idx]["partial_json"] = ""
382
+ final_content[idx]["partial_json"] += delta.get("partial_json", "")
383
+ elif dtype == "content_block_stop":
384
+ idx = data.get("index", 0)
385
+ # If tool use, parse json
386
+ if final_content[idx] and final_content[idx]["type"] == "tool_use":
387
+ if "partial_json" in final_content[idx]:
388
+ try:
389
+ final_content[idx]["input"] = json.loads(final_content[idx]["partial_json"])
390
+ except:
391
+ pass
392
+ del final_content[idx]["partial_json"]
393
+ elif dtype == "message_delta":
394
+ usage = data.get("usage", usage)
395
+ stop_reason = data.get("delta", {}).get("stop_reason")
396
+ except:
397
+ pass
398
+
399
+ return {
400
+ "id": f"msg_{uuid.uuid4()}",
401
+ "type": "message",
402
+ "role": "assistant",
403
+ "model": req.model,
404
+ "content": [c for c in final_content if c is not None],
405
+ "stop_reason": stop_reason,
406
+ "stop_sequence": None,
407
+ "usage": usage
408
+ }
409
+
410
+ except Exception as e:
411
+ raise
412
+
413
+ # ------------------------------------------------------------------------------
414
+ # Startup / Shutdown Events
415
+ # ------------------------------------------------------------------------------
416
+
417
+ async def _startup():
418
+ """初始化全局客户端和启动后台任务"""
419
+ await _init_global_client()
420
+ asyncio.create_task(_global_token_refresher())
421
+
422
+ async def _shutdown():
423
+ """清理资源"""
424
+ await _close_global_client()
425
+
426
+ # 更新 lifespan 上下文管理器使用实际的启动/关闭逻辑
427
+ @asynccontextmanager
428
+ async def lifespan(app_instance: FastAPI):
429
+ """
430
+ 管理应用生命周期事件
431
+ 启动时初始化数据库和后台任务,关闭时清理资源
432
+ """
433
+ await _startup()
434
+ yield
435
+ await _shutdown()
436
+
437
+ # 将 lifespan 设置到 app
438
+ app.router.lifespan_context = lifespan
439
+
440
+ # ------------------------------------------------------------------------------
441
+ # 直接运行支持
442
+ # ------------------------------------------------------------------------------
443
+
444
+ if __name__ == "__main__":
445
+ import uvicorn
446
+ port = int(os.getenv("PORT", "8000"))
447
+ uvicorn.run(
448
+ app,
449
+ host="0.0.0.0",
450
+ port=port,
451
+ log_level="info"
452
+ )
claude_converter.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ from datetime import datetime
4
+ from typing import List, Dict, Any, Optional, Union
5
+
6
+ try:
7
+ from .claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
8
+ except ImportError:
9
+ # Fallback for dynamic loading where relative import might fail
10
+ # We assume claude_types is available in sys.modules or we can import it directly if in same dir
11
+ import sys
12
+ if "v2.claude_types" in sys.modules:
13
+ from v2.claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
14
+ else:
15
+ # Try absolute import assuming v2 is in path or current dir
16
+ try:
17
+ from claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
18
+ except ImportError:
19
+ # Last resort: if loaded via importlib in app.py, we might need to rely on app.py injecting it
20
+ # But app.py loads this module.
21
+ pass
22
+
23
+ def get_current_timestamp() -> str:
24
+ """Get current timestamp in Amazon Q format."""
25
+ now = datetime.now().astimezone()
26
+ weekday = now.strftime("%A")
27
+ iso_time = now.isoformat(timespec='milliseconds')
28
+ return f"{weekday}, {iso_time}"
29
+
30
+ def map_model_name(claude_model: str) -> str:
31
+ """Map Claude model name to Amazon Q model ID."""
32
+ model_lower = claude_model.lower()
33
+ if model_lower.startswith("claude-sonnet-4.5") or model_lower.startswith("claude-sonnet-4-5"):
34
+ return "claude-sonnet-4.5"
35
+ return "claude-sonnet-4"
36
+
37
+ def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
38
+ """Extract text from Claude content."""
39
+ if isinstance(content, str):
40
+ return content
41
+ elif isinstance(content, list):
42
+ parts = []
43
+ for block in content:
44
+ if isinstance(block, dict):
45
+ if block.get("type") == "text":
46
+ parts.append(block.get("text", ""))
47
+ return "\n".join(parts)
48
+ return ""
49
+
50
+ def process_tool_result_block(block: Dict[str, Any], tool_results: List[Dict[str, Any]]) -> None:
51
+ """
52
+ 处理单个 tool_result 块,提取内容并添加到 tool_results 列表
53
+
54
+ Args:
55
+ block: tool_result 类型的内容块
56
+ tool_results: 用于存储处理结果的列表
57
+ """
58
+ tool_use_id = block.get("tool_use_id")
59
+ raw_c = block.get("content", [])
60
+
61
+ aq_content = []
62
+ if isinstance(raw_c, str):
63
+ aq_content = [{"text": raw_c}]
64
+ elif isinstance(raw_c, list):
65
+ for item in raw_c:
66
+ if isinstance(item, dict):
67
+ if item.get("type") == "text":
68
+ aq_content.append({"text": item.get("text", "")})
69
+ elif "text" in item:
70
+ aq_content.append({"text": item["text"]})
71
+ elif isinstance(item, str):
72
+ aq_content.append({"text": item})
73
+
74
+ if not any(i.get("text", "").strip() for i in aq_content):
75
+ aq_content = [{"text": "Tool use was cancelled by the user"}]
76
+
77
+ # Merge if exists
78
+ existing = next((r for r in tool_results if r["toolUseId"] == tool_use_id), None)
79
+ if existing:
80
+ existing["content"].extend(aq_content)
81
+ else:
82
+ tool_results.append({
83
+ "toolUseId": tool_use_id,
84
+ "content": aq_content,
85
+ "status": block.get("status", "success")
86
+ })
87
+
88
+ def extract_images_from_content(content: Union[str, List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
89
+ """Extract images from Claude content and convert to Amazon Q format."""
90
+ if not isinstance(content, list):
91
+ return None
92
+
93
+ images = []
94
+ for block in content:
95
+ if isinstance(block, dict) and block.get("type") == "image":
96
+ source = block.get("source", {})
97
+ if source.get("type") == "base64":
98
+ media_type = source.get("media_type", "image/png")
99
+ fmt = media_type.split("/")[-1] if "/" in media_type else "png"
100
+ images.append({
101
+ "format": fmt,
102
+ "source": {
103
+ "bytes": source.get("data", "")
104
+ }
105
+ })
106
+ return images if images else None
107
+
108
+ def convert_tool(tool: ClaudeTool) -> Dict[str, Any]:
109
+ """Convert Claude tool to Amazon Q tool."""
110
+ desc = tool.description or ""
111
+ if len(desc) > 10240:
112
+ desc = desc[:10100] + "\n\n...(Full description provided in TOOL DOCUMENTATION section)"
113
+
114
+ return {
115
+ "toolSpecification": {
116
+ "name": tool.name,
117
+ "description": desc,
118
+ "inputSchema": {"json": tool.input_schema}
119
+ }
120
+ }
121
+
122
+ def merge_user_messages(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
123
+ """Merge consecutive user messages, keeping only the last 2 messages' images."""
124
+ if not messages:
125
+ return {}
126
+
127
+ all_contents = []
128
+ base_context = None
129
+ base_origin = None
130
+ base_model = None
131
+ all_images = []
132
+
133
+ for msg in messages:
134
+ content = msg.get("content", "")
135
+ if base_context is None:
136
+ base_context = msg.get("userInputMessageContext", {})
137
+ if base_origin is None:
138
+ base_origin = msg.get("origin", "CLI")
139
+ if base_model is None:
140
+ base_model = msg.get("modelId")
141
+
142
+ if content:
143
+ all_contents.append(content)
144
+
145
+ # Collect images from each message
146
+ msg_images = msg.get("images")
147
+ if msg_images:
148
+ all_images.append(msg_images)
149
+
150
+ result = {
151
+ "content": "\n\n".join(all_contents),
152
+ "userInputMessageContext": base_context or {},
153
+ "origin": base_origin or "CLI",
154
+ "modelId": base_model
155
+ }
156
+
157
+ # Only keep images from the last 2 messages that have images
158
+ if all_images:
159
+ kept_images = []
160
+ for img_list in all_images[-2:]: # Take last 2 messages' images
161
+ kept_images.extend(img_list)
162
+ if kept_images:
163
+ result["images"] = kept_images
164
+
165
+ return result
166
+
167
+ def process_history(messages: List[ClaudeMessage]) -> List[Dict[str, Any]]:
168
+ """Process history messages to match Amazon Q format (alternating user/assistant)."""
169
+ history = []
170
+ seen_tool_use_ids = set()
171
+
172
+ raw_history = []
173
+
174
+ # First pass: convert individual messages
175
+ for msg in messages:
176
+ if msg.role == "user":
177
+ content = msg.content
178
+ text_content = ""
179
+ tool_results = None
180
+ images = extract_images_from_content(content)
181
+
182
+ if isinstance(content, list):
183
+ text_parts = []
184
+ for block in content:
185
+ if isinstance(block, dict):
186
+ btype = block.get("type")
187
+ if btype == "text":
188
+ text_parts.append(block.get("text", ""))
189
+ elif btype == "tool_result":
190
+ if tool_results is None:
191
+ tool_results = []
192
+ process_tool_result_block(block, tool_results)
193
+ text_content = "\n".join(text_parts)
194
+ else:
195
+ text_content = extract_text_from_content(content)
196
+
197
+ user_ctx = {
198
+ "envState": {
199
+ "operatingSystem": "macos",
200
+ "currentWorkingDirectory": "/"
201
+ }
202
+ }
203
+ if tool_results:
204
+ user_ctx["toolResults"] = tool_results
205
+
206
+ u_msg = {
207
+ "content": text_content,
208
+ "userInputMessageContext": user_ctx,
209
+ "origin": "CLI"
210
+ }
211
+ if images:
212
+ u_msg["images"] = images
213
+
214
+ raw_history.append({"userInputMessage": u_msg})
215
+
216
+ elif msg.role == "assistant":
217
+ content = msg.content
218
+ text_content = extract_text_from_content(content)
219
+
220
+ entry = {
221
+ "assistantResponseMessage": {
222
+ "messageId": str(uuid.uuid4()),
223
+ "content": text_content
224
+ }
225
+ }
226
+
227
+ if isinstance(content, list):
228
+ tool_uses = []
229
+ for block in content:
230
+ if isinstance(block, dict) and block.get("type") == "tool_use":
231
+ tid = block.get("id")
232
+ if tid and tid not in seen_tool_use_ids:
233
+ seen_tool_use_ids.add(tid)
234
+ tool_uses.append({
235
+ "toolUseId": tid,
236
+ "name": block.get("name"),
237
+ "input": block.get("input", {})
238
+ })
239
+ if tool_uses:
240
+ entry["assistantResponseMessage"]["toolUses"] = tool_uses
241
+
242
+ raw_history.append(entry)
243
+
244
+ # Second pass: merge consecutive user messages
245
+ pending_user_msgs = []
246
+ for item in raw_history:
247
+ if "userInputMessage" in item:
248
+ pending_user_msgs.append(item["userInputMessage"])
249
+ elif "assistantResponseMessage" in item:
250
+ if pending_user_msgs:
251
+ merged = merge_user_messages(pending_user_msgs)
252
+ history.append({"userInputMessage": merged})
253
+ pending_user_msgs = []
254
+ history.append(item)
255
+
256
+ if pending_user_msgs:
257
+ merged = merge_user_messages(pending_user_msgs)
258
+ history.append({"userInputMessage": merged})
259
+
260
+ return history
261
+
262
+ def convert_claude_to_amazonq_request(req: ClaudeRequest, conversation_id: Optional[str] = None) -> Dict[str, Any]:
263
+ """Convert ClaudeRequest to Amazon Q request body."""
264
+ if conversation_id is None:
265
+ conversation_id = str(uuid.uuid4())
266
+
267
+ # 1. Tools
268
+ aq_tools = []
269
+ long_desc_tools = []
270
+ if req.tools:
271
+ for t in req.tools:
272
+ if t.description and len(t.description) > 10240:
273
+ long_desc_tools.append({"name": t.name, "full_description": t.description})
274
+ aq_tools.append(convert_tool(t))
275
+
276
+ # 2. Current Message (last user message)
277
+ last_msg = req.messages[-1] if req.messages else None
278
+ prompt_content = ""
279
+ tool_results = None
280
+ has_tool_result = False
281
+ images = None
282
+
283
+ if last_msg and last_msg.role == "user":
284
+ content = last_msg.content
285
+ images = extract_images_from_content(content)
286
+
287
+ if isinstance(content, list):
288
+ text_parts = []
289
+ for block in content:
290
+ if isinstance(block, dict):
291
+ btype = block.get("type")
292
+ if btype == "text":
293
+ text_parts.append(block.get("text", ""))
294
+ elif btype == "tool_result":
295
+ has_tool_result = True
296
+ if tool_results is None:
297
+ tool_results = []
298
+ process_tool_result_block(block, tool_results)
299
+ prompt_content = "\n".join(text_parts)
300
+ else:
301
+ prompt_content = extract_text_from_content(content)
302
+
303
+ # 3. Context
304
+ user_ctx = {
305
+ "envState": {
306
+ "operatingSystem": "macos",
307
+ "currentWorkingDirectory": "/"
308
+ }
309
+ }
310
+ if aq_tools:
311
+ user_ctx["tools"] = aq_tools
312
+ if tool_results:
313
+ user_ctx["toolResults"] = tool_results
314
+
315
+ # 4. Format Content
316
+ formatted_content = ""
317
+ if has_tool_result and not prompt_content:
318
+ formatted_content = ""
319
+ else:
320
+ formatted_content = (
321
+ "--- CONTEXT ENTRY BEGIN ---\n"
322
+ f"Current time: {get_current_timestamp()}\n"
323
+ "--- CONTEXT ENTRY END ---\n\n"
324
+ "--- USER MESSAGE BEGIN ---\n"
325
+ f"{prompt_content}\n"
326
+ "--- USER MESSAGE END ---"
327
+ )
328
+
329
+ if long_desc_tools:
330
+ docs = []
331
+ for info in long_desc_tools:
332
+ docs.append(f"Tool: {info['name']}\nFull Description:\n{info['full_description']}\n")
333
+ formatted_content = (
334
+ "--- TOOL DOCUMENTATION BEGIN ---\n"
335
+ f"{''.join(docs)}"
336
+ "--- TOOL DOCUMENTATION END ---\n\n"
337
+ f"{formatted_content}"
338
+ )
339
+
340
+ if req.system and formatted_content:
341
+ sys_text = ""
342
+ if isinstance(req.system, str):
343
+ sys_text = req.system
344
+ elif isinstance(req.system, list):
345
+ parts = []
346
+ for b in req.system:
347
+ if isinstance(b, dict) and b.get("type") == "text":
348
+ parts.append(b.get("text", ""))
349
+ sys_text = "\n".join(parts)
350
+
351
+ if sys_text:
352
+ formatted_content = (
353
+ "--- SYSTEM PROMPT BEGIN ---\n"
354
+ f"{sys_text}\n"
355
+ "--- SYSTEM PROMPT END ---\n\n"
356
+ f"{formatted_content}"
357
+ )
358
+
359
+ # 5. Model
360
+ model_id = map_model_name(req.model)
361
+
362
+ # 6. User Input Message
363
+ user_input_msg = {
364
+ "content": formatted_content,
365
+ "userInputMessageContext": user_ctx,
366
+ "origin": "CLI",
367
+ "modelId": model_id
368
+ }
369
+ if images:
370
+ user_input_msg["images"] = images
371
+
372
+ # 7. History
373
+ history_msgs = req.messages[:-1] if len(req.messages) > 1 else []
374
+ aq_history = process_history(history_msgs)
375
+
376
+ # 8. Final Body
377
+ return {
378
+ "conversationState": {
379
+ "conversationId": conversation_id,
380
+ "history": aq_history,
381
+ "currentMessage": {
382
+ "userInputMessage": user_input_msg
383
+ },
384
+ "chatTriggerType": "MANUAL"
385
+ }
386
+ }
claude_parser.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import struct
3
+ import logging
4
+ from typing import Optional, Dict, Any, AsyncIterator
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
class EventStreamParser:
    """AWS Event Stream binary format parser (v2 style).

    Frame layout (all big-endian): 4-byte total length, 4-byte headers
    length, 4-byte prelude CRC, headers, payload, 4-byte message CRC.
    CRCs are NOT verified here — frames are trusted from the upstream.
    """

    @staticmethod
    def parse_headers(headers_data: bytes) -> Dict[str, str]:
        """Parse event stream headers.

        Each header is: 1-byte name length, name, 1-byte value type,
        2-byte value length (>H), value. Value type 7 is a UTF-8 string;
        any other type is kept as raw bytes. Truncated data ends parsing
        early and returns whatever was decoded so far.
        """
        headers = {}
        offset = 0

        while offset < len(headers_data):
            name_length = headers_data[offset]
            offset += 1

            if offset + name_length > len(headers_data):
                break
            name = headers_data[offset:offset + name_length].decode('utf-8')
            offset += name_length

            if offset >= len(headers_data):
                break
            value_type = headers_data[offset]
            offset += 1

            if offset + 2 > len(headers_data):
                break
            value_length = struct.unpack('>H', headers_data[offset:offset + 2])[0]
            offset += 2

            if offset + value_length > len(headers_data):
                break

            if value_type == 7:
                # Type 7 is the string header type in the event-stream encoding.
                value = headers_data[offset:offset + value_length].decode('utf-8')
            else:
                value = headers_data[offset:offset + value_length]

            offset += value_length
            headers[name] = value

        return headers

    @staticmethod
    def parse_message(data: bytes) -> Optional[Dict[str, Any]]:
        """Parse a single Event Stream message.

        Returns a dict with 'headers', 'payload' (JSON-decoded when
        possible, raw bytes otherwise) and 'total_length', or None when
        the buffer is too short or malformed.
        """
        try:
            if len(data) < 16:
                # Minimum frame: 12-byte prelude + 4-byte message CRC.
                return None

            total_length = struct.unpack('>I', data[0:4])[0]
            headers_length = struct.unpack('>I', data[4:8])[0]

            if len(data) < total_length:
                logger.warning(f"Incomplete message: expected {total_length} bytes, got {len(data)}")
                return None

            # Headers start right after the 12-byte prelude (lengths + CRC).
            headers_data = data[12:12 + headers_length]
            headers = EventStreamParser.parse_headers(headers_data)

            payload_start = 12 + headers_length
            payload_end = total_length - 4  # strip trailing message CRC
            payload_data = data[payload_start:payload_end]

            payload = None
            if payload_data:
                try:
                    payload = json.loads(payload_data.decode('utf-8'))
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # Non-JSON payloads are surfaced as raw bytes.
                    payload = payload_data

            return {
                'headers': headers,
                'payload': payload,
                'total_length': total_length
            }

        except Exception as e:
            logger.error(f"Failed to parse message: {e}", exc_info=True)
            return None

    @staticmethod
    async def parse_stream(byte_stream: AsyncIterator[bytes]) -> AsyncIterator[Dict[str, Any]]:
        """Parse a byte stream and yield complete event messages.

        Buffers incoming chunks and emits one parsed message per complete
        frame; partial frames are kept until more bytes arrive.
        """
        buffer = bytearray()

        async for chunk in byte_stream:
            buffer.extend(chunk)

            while len(buffer) >= 12:
                try:
                    total_length = struct.unpack('>I', buffer[0:4])[0]
                except struct.error:
                    break

                # BUGFIX: a corrupt frame declaring total_length < 16 would
                # never consume any bytes from the buffer, spinning this
                # loop forever. Drop the unusable buffer and stop parsing.
                if total_length < 16:
                    logger.error(f"Corrupt frame: declared length {total_length} < 16; dropping buffer")
                    buffer.clear()
                    break

                if len(buffer) < total_length:
                    break

                message_data = bytes(buffer[:total_length])
                buffer = buffer[total_length:]

                message = EventStreamParser.parse_message(message_data)
                if message:
                    yield message
112
+
113
def extract_event_info(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Normalize a parsed Event Stream message into a flat event record.

    Header names may appear with or without the leading ':' depending on
    the producer, so both spellings are accepted for each field.
    """
    hdrs = message.get('headers', {})

    def pick(field: str):
        # Prefer the ':'-prefixed spelling, fall back to the bare name.
        return hdrs.get(f':{field}') or hdrs.get(field)

    return {
        'event_type': pick('event-type'),
        'content_type': pick('content-type'),
        'message_type': pick('message-type'),
        'payload': message.get('payload'),
    }
128
+
129
+ def _sse_format(event_type: str, data: Dict[str, Any]) -> str:
130
+ """Format SSE event."""
131
+ json_data = json.dumps(data, ensure_ascii=False)
132
+ return f"event: {event_type}\ndata: {json_data}\n\n"
133
+
134
+ def build_message_start(conversation_id: str, model: str = "claude-sonnet-4.5", input_tokens: int = 0) -> str:
135
+ """Build message_start SSE event."""
136
+ data = {
137
+ "type": "message_start",
138
+ "message": {
139
+ "id": conversation_id,
140
+ "type": "message",
141
+ "role": "assistant",
142
+ "content": [],
143
+ "model": model,
144
+ "stop_reason": None,
145
+ "stop_sequence": None,
146
+ "usage": {"input_tokens": input_tokens, "output_tokens": 0}
147
+ }
148
+ }
149
+ return _sse_format("message_start", data)
150
+
151
+ def build_content_block_start(index: int, block_type: str = "text") -> str:
152
+ """Build content_block_start SSE event."""
153
+ data = {
154
+ "type": "content_block_start",
155
+ "index": index,
156
+ "content_block": {"type": block_type, "text": ""} if block_type == "text" else {"type": block_type}
157
+ }
158
+ return _sse_format("content_block_start", data)
159
+
160
+ def build_content_block_delta(index: int, text: str) -> str:
161
+ """Build content_block_delta SSE event (text)."""
162
+ data = {
163
+ "type": "content_block_delta",
164
+ "index": index,
165
+ "delta": {"type": "text_delta", "text": text}
166
+ }
167
+ return _sse_format("content_block_delta", data)
168
+
169
+ def build_content_block_stop(index: int) -> str:
170
+ """Build content_block_stop SSE event."""
171
+ data = {
172
+ "type": "content_block_stop",
173
+ "index": index
174
+ }
175
+ return _sse_format("content_block_stop", data)
176
+
177
+ def build_ping() -> str:
178
+ """Build ping SSE event."""
179
+ data = {"type": "ping"}
180
+ return _sse_format("ping", data)
181
+
182
+ def build_message_stop(input_tokens: int, output_tokens: int, stop_reason: Optional[str] = None) -> str:
183
+ """Build message_delta and message_stop SSE events."""
184
+ delta_data = {
185
+ "type": "message_delta",
186
+ "delta": {"stop_reason": stop_reason or "end_turn", "stop_sequence": None},
187
+ "usage": {"output_tokens": output_tokens}
188
+ }
189
+ delta_event = _sse_format("message_delta", delta_data)
190
+
191
+ stop_data = {
192
+ "type": "message_stop"
193
+ }
194
+ stop_event = _sse_format("message_stop", stop_data)
195
+
196
+ return delta_event + stop_event
197
+
198
+ def build_tool_use_start(index: int, tool_use_id: str, tool_name: str) -> str:
199
+ """Build tool_use content_block_start SSE event."""
200
+ data = {
201
+ "type": "content_block_start",
202
+ "index": index,
203
+ "content_block": {
204
+ "type": "tool_use",
205
+ "id": tool_use_id,
206
+ "name": tool_name,
207
+ "input": {}
208
+ }
209
+ }
210
+ return _sse_format("content_block_start", data)
211
+
212
+ def build_tool_use_input_delta(index: int, input_json_delta: str) -> str:
213
+ """Build tool_use input_json_delta SSE event."""
214
+ data = {
215
+ "type": "content_block_delta",
216
+ "index": index,
217
+ "delta": {
218
+ "type": "input_json_delta",
219
+ "partial_json": input_json_delta
220
+ }
221
+ }
222
+ return _sse_format("content_block_delta", data)
claude_stream.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import AsyncGenerator, Optional, Dict, Any, List, Set
5
+
6
+ from utils import load_module
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ _parser = load_module("v2_claude_parser", "claude_parser.py")
11
+ build_message_start = _parser.build_message_start
12
+ build_content_block_start = _parser.build_content_block_start
13
+ build_content_block_delta = _parser.build_content_block_delta
14
+ build_content_block_stop = _parser.build_content_block_stop
15
+ build_ping = _parser.build_ping
16
+ build_message_stop = _parser.build_message_stop
17
+ build_tool_use_start = _parser.build_tool_use_start
18
+ build_tool_use_input_delta = _parser.build_tool_use_input_delta
19
+
20
class ClaudeStreamHandler:
    """Translates Amazon Q streaming events into Claude-style SSE events.

    Maintains a small state machine over content blocks: text blocks fed by
    assistantResponseEvent and tool_use blocks fed by toolUseEvent, emitted
    as message_start / content_block_* / message_delta / message_stop.
    """

    def __init__(self, model: str, input_tokens: int = 0):
        self.model = model
        self.input_tokens = input_tokens
        self.response_buffer: List[str] = []          # accumulated assistant text
        self.content_block_index: int = -1            # index of the current SSE content block
        self.content_block_started: bool = False      # a block has been opened and not closed
        self.content_block_start_sent: bool = False   # content_block_start emitted for current block
        self.content_block_stop_sent: bool = False    # content_block_stop emitted for current block
        self.message_start_sent: bool = False
        self.conversation_id: Optional[str] = None

        # Tool use state
        self.current_tool_use: Optional[Dict[str, Any]] = None
        self.tool_input_buffer: List[str] = []
        self.tool_use_id: Optional[str] = None
        self.tool_name: Optional[str] = None
        self._processed_tool_use_ids: Set[str] = set()
        self.all_tool_inputs: List[str] = []          # full input JSON per finished tool call

    def _mark_block_closed(self) -> None:
        """Reset open-block flags after emitting content_block_stop.

        BUGFIX: previously content_block_start_sent stayed True after a block
        was closed, so text arriving after a tool_use block was emitted
        without a new content_block_start and at an already-stopped index.
        """
        self.content_block_stop_sent = True
        self.content_block_start_sent = False
        self.content_block_started = False

    async def handle_event(self, event_type: str, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Process a single Amazon Q event and yield Claude SSE event strings."""

        # 1. Message Start (initial-response)
        if event_type == "initial-response":
            if not self.message_start_sent:
                conv_id = payload.get('conversationId', self.conversation_id or 'unknown')
                self.conversation_id = conv_id
                yield build_message_start(conv_id, self.model, self.input_tokens)
                self.message_start_sent = True
                yield build_ping()

        # 2. Content Block Delta (assistantResponseEvent)
        elif event_type == "assistantResponseEvent":
            content = payload.get("content", "")

            # Close any open tool use block before emitting text
            if self.current_tool_use and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self._mark_block_closed()
                self.current_tool_use = None

            # Start a fresh text block if none is currently open
            if not self.content_block_start_sent:
                self.content_block_index += 1
                yield build_content_block_start(self.content_block_index, "text")
                self.content_block_start_sent = True
                self.content_block_stop_sent = False
                self.content_block_started = True

            # Send delta
            if content:
                self.response_buffer.append(content)
                yield build_content_block_delta(self.content_block_index, content)

        # 3. Tool Use (toolUseEvent)
        elif event_type == "toolUseEvent":
            tool_use_id = payload.get("toolUseId")
            tool_name = payload.get("name")
            tool_input = payload.get("input", {})
            is_stop = payload.get("stop", False)

            # Start new tool use
            if tool_use_id and tool_name and not self.current_tool_use:
                # Close a preceding open text block first
                if self.content_block_start_sent and not self.content_block_stop_sent:
                    yield build_content_block_stop(self.content_block_index)
                    self._mark_block_closed()

                self._processed_tool_use_ids.add(tool_use_id)
                self.content_block_index += 1
                yield build_tool_use_start(self.content_block_index, tool_use_id, tool_name)

                self.content_block_started = True
                self.current_tool_use = {"toolUseId": tool_use_id, "name": tool_name}
                self.tool_use_id = tool_use_id
                self.tool_name = tool_name
                self.tool_input_buffer = []
                self.content_block_stop_sent = False
                self.content_block_start_sent = True

            # Accumulate input fragments (string fragments pass through as-is,
            # dict fragments are serialized to JSON)
            if self.current_tool_use and tool_input:
                if isinstance(tool_input, str):
                    fragment = tool_input
                else:
                    fragment = json.dumps(tool_input, ensure_ascii=False)
                self.tool_input_buffer.append(fragment)
                yield build_tool_use_input_delta(self.content_block_index, fragment)

            # Stop tool use
            if is_stop and self.current_tool_use:
                self.all_tool_inputs.append("".join(self.tool_input_buffer))
                yield build_content_block_stop(self.content_block_index)
                self._mark_block_closed()
                self.current_tool_use = None
                self.tool_use_id = None
                self.tool_name = None
                self.tool_input_buffer = []

        # 4. Assistant Response End (assistantResponseEnd)
        elif event_type == "assistantResponseEnd":
            if self.content_block_started and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self._mark_block_closed()

    async def finish(self) -> AsyncGenerator[str, None]:
        """Emit the closing message_delta/message_stop after the upstream ends."""
        # Ensure the last block is closed
        if self.content_block_started and not self.content_block_stop_sent:
            yield build_content_block_stop(self.content_block_index)
            self._mark_block_closed()

        full_text = "".join(self.response_buffer)
        full_tool_input = "".join(self.all_tool_inputs)
        # Rough token estimate: ~4 characters per token, minimum 1.
        output_tokens = max(1, (len(full_text) + len(full_tool_input)) // 4)

        yield build_message_stop(self.input_tokens, output_tokens, "end_turn")
claude_types.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Any, Literal
2
+ from pydantic import BaseModel
3
+
4
class ClaudeMessage(BaseModel):
    """One turn in a Claude conversation."""
    # Message author role (e.g. "user"/"assistant") — not constrained by this model.
    role: str
    # Either plain text or a list of Claude content-block dicts
    # (text / image / tool_use / tool_result shapes).
    content: Union[str, List[Dict[str, Any]]]
7
+
8
class ClaudeTool(BaseModel):
    """Tool definition as sent in a Claude request's `tools` array."""
    name: str
    # Optional human-readable description; defaults to empty string.
    description: Optional[str] = ""
    # JSON Schema describing the tool's input parameters.
    input_schema: Dict[str, Any]
12
+
13
class ClaudeRequest(BaseModel):
    """Body of a Claude /v1/messages-style request."""
    model: str
    messages: List[ClaudeMessage]
    # Generation cap; defaults to 8192 when the client omits it.
    max_tokens: int = 8192
    temperature: Optional[float] = None
    tools: Optional[List[ClaudeTool]] = None
    # True -> server-sent-events streaming response.
    stream: bool = False
    # System prompt: plain string or a list of {"type": "text", ...} blocks.
    system: Optional[Union[str, List[Dict[str, Any]]]] = None
config.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Amazon Q API 配置文件
3
+ 包含请求模板和默认配置
4
+ """
5
+
6
+ # Amazon Q API 端点
7
+ AMAZONQ_API_URL = "https://q.us-east-1.amazonaws.com/"
8
+
9
+ # 默认请求头模板
10
+ DEFAULT_HEADERS = {
11
+ "content-type": "application/x-amz-json-1.0",
12
+ "x-amz-target": "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
13
+ "user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/codewhispererstreaming/0.1.11582 os/windows lang/rust/1.87.0 md/appVersion-1.19.4 app/AmazonQ-For-CLI",
14
+ "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/codewhispererstreaming/0.1.11582 os/windows lang/rust/1.87.0 m/F app/AmazonQ-For-CLI",
15
+ "x-amzn-codewhisperer-optout": "false",
16
+ "amz-sdk-request": "attempt=1; max=3"
17
+ }
18
+
19
+ # 默认请求体模板(仅作为结构参考,实际使用时会被 raw_payload 替换)
20
+ DEFAULT_BODY_TEMPLATE = {
21
+ "conversationState": {
22
+ "conversationId": "", # 运行时动态生成
23
+ "history": [],
24
+ "currentMessage": {
25
+ "userInputMessage": {
26
+ "content": "",
27
+ "userInputMessageContext": {
28
+ "envState": {
29
+ "operatingSystem": "windows",
30
+ "currentWorkingDirectory": ""
31
+ },
32
+ "tools": []
33
+ },
34
+ "origin": "CLI",
35
+ "modelId": "claude-sonnet-4"
36
+ }
37
+ },
38
+ "chatTriggerType": "MANUAL"
39
+ }
40
+ }
replicate.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ import asyncio
4
+ from typing import Dict, Optional, Tuple, List, AsyncGenerator, Any
5
+ import httpx
6
+
7
+ from utils import get_proxies, load_module, create_proxy_mounts
8
+ from config import AMAZONQ_API_URL, DEFAULT_HEADERS
9
+
10
+ try:
11
+ _parser = load_module("v2_claude_parser", "claude_parser.py")
12
+ EventStreamParser = _parser.EventStreamParser
13
+ extract_event_info = _parser.extract_event_info
14
+ except Exception as e:
15
+ print(f"Warning: Failed to load claude_parser: {e}")
16
+ EventStreamParser = None
17
+ extract_event_info = None
18
+
19
class StreamTracker:
    """Records whether a wrapped async stream produced any truthy item."""

    def __init__(self):
        # Flipped to True the first time a non-empty item passes through.
        self.has_content = False

    async def track(self, gen: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
        """Pass *gen* through unchanged while noting truthy items."""
        async for chunk in gen:
            self.has_content = self.has_content or bool(chunk)
            yield chunk
28
+
29
def load_template() -> Tuple[str, Dict[str, str]]:
    """
    Load the Amazon Q API request template.

    Returns:
        (url, headers): the API endpoint URL and a copy of the default
        request headers (copied so callers may mutate them freely).
    """
    return AMAZONQ_API_URL, DEFAULT_HEADERS.copy()
37
+
38
+ def _merge_headers(as_log: Dict[str, str], bearer_token: str) -> Dict[str, str]:
39
+ headers = dict(as_log)
40
+ for k in list(headers.keys()):
41
+ kl = k.lower()
42
+ if kl in ("content-length","host","connection","transfer-encoding"):
43
+ headers.pop(k, None)
44
+ def set_header(name: str, value: str):
45
+ for key in list(headers.keys()):
46
+ if key.lower() == name.lower():
47
+ del headers[key]
48
+ headers[name] = value
49
+ set_header("Authorization", f"Bearer {bearer_token}")
50
+ set_header("amz-sdk-invocation-id", str(uuid.uuid4()))
51
+ return headers
52
+
53
async def send_chat_request(
    access_token: str,
    messages: List[Dict[str, Any]],
    model: Optional[str] = None,
    stream: bool = False,
    timeout: Tuple[int,int] = (30,300),
    client: Optional[httpx.AsyncClient] = None,
    raw_payload: Dict[str, Any] = None
) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker, Optional[AsyncGenerator[Any, None]]]:
    """
    Send a chat request to the Amazon Q API.

    Args:
        access_token: Amazon Q access token (sent as a Bearer token)
        messages: message list (deprecated; use raw_payload)
        model: model name (deprecated; use raw_payload)
        stream: whether to return a streaming event generator
        timeout: (write, read) timeout configuration in seconds
        client: optional shared HTTP client; a local one is created (and
            closed by this function) when omitted
        raw_payload: request body already converted from the Claude API
            shape (required)

    Returns:
        (None, None, tracker, events) where *events* is an async generator
        of (event_type, payload) tuples when stream=True, else None.

    Raises:
        ValueError: when raw_payload is missing
        httpx.HTTPError: when the upstream returns a >=400 status
    """
    if raw_payload is None:
        raise ValueError("raw_payload is required")

    url, headers_from_log = load_template()
    headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())

    # Use raw payload (for Claude API)
    body_json = raw_payload
    # Ensure conversationId is set if missing
    if "conversationState" in body_json and "conversationId" not in body_json["conversationState"]:
        body_json["conversationState"]["conversationId"] = str(uuid.uuid4())

    payload_str = json.dumps(body_json, ensure_ascii=False)
    headers = _merge_headers(headers_from_log, access_token)

    # local_client marks ownership: only clients created here are closed here.
    local_client = False
    if client is None:
        local_client = True
        mounts = create_proxy_mounts()
        # Longer connect timeout to avoid TLS handshake timeouts.
        timeout_config = httpx.Timeout(connect=60.0, read=timeout[1], write=timeout[0], pool=10.0)
        # Only pass the mounts argument when a proxy is configured.
        if mounts:
            client = httpx.AsyncClient(mounts=mounts, timeout=timeout_config)
        else:
            client = httpx.AsyncClient(timeout=timeout_config)

    # Use manual request sending to control stream lifetime
    req = client.build_request("POST", url, headers=headers, content=payload_str)

    resp = None
    try:
        resp = await client.send(req, stream=True)

        if resp.status_code >= 400:
            try:
                await resp.read()
                err = resp.text
            except Exception:
                err = f"HTTP {resp.status_code}"
            await resp.aclose()
            if local_client:
                await client.aclose()
            raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")

        tracker = StreamTracker()

        # Track if the response has been consumed to avoid double-close
        response_consumed = False

        async def _iter_events() -> AsyncGenerator[Any, None]:
            # Decode the binary event stream into (event_type, payload) tuples.
            nonlocal response_consumed
            try:
                # Use EventStreamParser from claude_parser.py
                async def byte_gen():
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            yield chunk

                async for message in EventStreamParser.parse_stream(byte_gen()):
                    event_info = extract_event_info(message)
                    if event_info:
                        event_type = event_info.get('event_type')
                        payload = event_info.get('payload')
                        if event_type and payload:
                            yield (event_type, payload)
            except Exception:
                # Swallow upstream errors only once some content was already
                # delivered; otherwise surface the failure to the caller.
                if not tracker.has_content:
                    raise
            finally:
                response_consumed = True
                await resp.aclose()
                if local_client:
                    await client.aclose()

        if stream:
            # Wrap generator to ensure cleanup on early termination
            async def _safe_iter_events():
                try:
                    # Safety net: hard 300-second cap on the whole stream.
                    async with asyncio.timeout(300):
                        async for item in _iter_events():
                            yield item
                except asyncio.TimeoutError:
                    # Force-close everything on timeout.
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
                except GeneratorExit:
                    # Generator was closed without being fully consumed
                    # Ensure cleanup happens even if finally block wasn't reached
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
                except Exception:
                    # Any exception should also trigger cleanup
                    if resp and not resp.is_closed:
                        await resp.aclose()
                    if local_client and client:
                        await client.aclose()
                    raise
            return None, None, tracker, _safe_iter_events()
        else:
            # Non-streaming: consume all events
            try:
                async for _ in _iter_events():
                    pass
            finally:
                # Ensure response is closed even if iteration is incomplete
                if not response_consumed and resp:
                    await resp.aclose()
                if local_client:
                    await client.aclose()
            return None, None, tracker, None

    except Exception:
        # Critical: close response on any exception before generators are created
        if resp and not resp.is_closed:
            await resp.aclose()
        if local_client and client:
            await client.aclose()
        raise
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi==0.115.5
2
+ uvicorn[standard]==0.32.0
3
+ pydantic==2.9.2
4
+ python-dotenv==1.0.1
5
+ httpx==0.28.1
utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """公共工具函数"""
2
+ import os
3
+ import importlib.util
4
+ import httpx
5
+ from pathlib import Path
6
+ from typing import Dict, Optional
7
+
8
+
9
def get_proxies() -> Optional[Dict[str, str]]:
    """Read proxy settings from the HTTP_PROXY environment variable.

    Returns a {"http": ..., "https": ...} mapping using the same proxy for
    both schemes, or None when HTTP_PROXY is unset or blank.
    """
    configured = os.getenv("HTTP_PROXY", "").strip()
    return {"http": configured, "https": configured} if configured else None
18
+
19
+
20
def load_module(module_name: str, file_name: str):
    """
    Dynamically load a sibling module by file name.

    Args:
        module_name: name to register the module under
        file_name: file name, relative to this file's directory

    Returns:
        the executed module object
    """
    module_path = Path(__file__).resolve().parent / file_name
    spec = importlib.util.spec_from_file_location(module_name, str(module_path))
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
36
+
37
+
38
def create_proxy_mounts() -> Optional[Dict[str, httpx.AsyncHTTPTransport]]:
    """
    Build httpx mount transports for the configured proxy.

    Returns:
        a mounts dict routing both http:// and https:// traffic through the
        proxy, or None when no proxy is configured.
    """
    proxy_settings = get_proxies()
    if not proxy_settings:
        return None

    proxy_url = proxy_settings.get("https") or proxy_settings.get("http")
    if not proxy_url:
        return None

    return {
        "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
        "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
    }