devme committed on
Commit
bc43157
·
verified ·
1 Parent(s): e51ffd6

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -472
app.py DELETED
@@ -1,472 +0,0 @@
1
- import json
2
- import os
3
- import traceback
4
- import uuid
5
- import time
6
- import asyncio
7
- import importlib.util
8
- from pathlib import Path
9
- from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
10
-
11
- from contextlib import asynccontextmanager
12
- from fastapi import FastAPI, Depends, HTTPException, Header
13
- from fastapi.middleware.cors import CORSMiddleware
14
- from fastapi.responses import StreamingResponse, RedirectResponse
15
- from dotenv import load_dotenv
16
- import httpx
17
- import hashlib
18
-
19
- from utils import get_proxies, create_proxy_mounts
20
-
21
- # ------------------------------------------------------------------------------
22
- # Bootstrap
23
- # ------------------------------------------------------------------------------
24
-
25
# Directory that contains this file; sibling modules and the .env file are
# resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent

# Load environment settings from the .env file next to this module.
load_dotenv(BASE_DIR.joinpath(".env"))

app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")

# Fully permissive CORS so the API can be called straight from a browser
# during simple testing.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
38
-
39
- # ------------------------------------------------------------------------------
40
- # Dynamic import of replicate.py to avoid package __init__ needs
41
- # ------------------------------------------------------------------------------
42
-
43
def _load_replicate_module():
    """Load the sibling ``replicate.py`` file as a standalone module.

    The file is executed directly from its path under the module name
    ``v2_replicate``, so no package ``__init__`` is needed.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the file.
    """
    mod_path = BASE_DIR / "replicate.py"
    spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
    # Validate the spec *before* using it. The previous code asserted only
    # after module_from_spec() had already consumed the spec, and asserts are
    # stripped when Python runs with -O.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {mod_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Loaded once at import time; this file only uses the request helper.
_replicate = _load_replicate_module()
send_chat_request = _replicate.send_chat_request
53
-
54
- # ------------------------------------------------------------------------------
55
- # Dynamic import of Claude modules
56
- # ------------------------------------------------------------------------------
57
-
58
def _load_claude_modules():
    """Load the Claude helper modules that live next to this file.

    ``claude_types.py``, ``claude_converter.py`` and ``claude_stream.py`` are
    executed from their file paths. After the types module has been executed
    and before the converter is executed, the types module is registered in
    ``sys.modules`` under the name ``v2.claude_types``.

    Returns:
        A tuple ``(types_module, converter_module, stream_module)``.

    Raises:
        ImportError: If an import spec or loader cannot be created for one of
            the files.
    """
    import sys

    def _prepare(module_name, file_name):
        """Return ``(spec, module)`` for *file_name* without executing it."""
        path = BASE_DIR / file_name
        spec = importlib.util.spec_from_file_location(module_name, str(path))
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot create an import spec for {path}")
        return spec, importlib.util.module_from_spec(spec)

    # claude_types
    spec_types, mod_types = _prepare("v2_claude_types", "claude_types.py")
    spec_types.loader.exec_module(mod_types)

    # claude_converter
    spec_conv, mod_conv = _prepare("v2_claude_converter", "claude_converter.py")
    # NOTE(review): presumably the converter (and/or stream) module imports
    # ``v2.claude_types``; the alias must exist before they are executed.
    sys.modules["v2.claude_types"] = mod_types
    spec_conv.loader.exec_module(mod_conv)

    # claude_stream
    spec_stream, mod_stream = _prepare("v2_claude_stream", "claude_stream.py")
    spec_stream.loader.exec_module(mod_stream)

    return mod_types, mod_conv, mod_stream


# Loaded once at import time; re-export the three names this file uses.
_claude_types, _claude_converter, _claude_stream = _load_claude_modules()
ClaudeRequest = _claude_types.ClaudeRequest
convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
84
-
85
- # ------------------------------------------------------------------------------
86
- # Global HTTP Client
87
- # ------------------------------------------------------------------------------
88
-
89
# Process-wide HTTP client; set by _init_global_client() and reset to None by
# _close_global_client().
GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None


async def _init_global_client():
    """Create the shared ``httpx.AsyncClient`` and store it in ``GLOBAL_CLIENT``.

    Pool limits are 60 total connections and 60 keep-alive connections, with
    idle connections released after 30 seconds. The read timeout is long
    (300 s) because responses are streamed; connect and write time out after
    30 s and waiting for a pooled connection after 10 s. Proxy mounts from
    ``create_proxy_mounts()`` are passed to the client only when non-empty.
    """
    global GLOBAL_CLIENT

    proxy_mounts = create_proxy_mounts()

    pool_limits = httpx.Limits(
        max_keepalive_connections=60,
        max_connections=60,
        keepalive_expiry=30.0,
    )
    stream_timeout = httpx.Timeout(
        connect=30.0,  # leaves room for the TLS handshake
        read=300.0,    # streamed responses can stay open for a long time
        write=30.0,
        pool=10.0,     # max wait for a free connection from the pool
    )

    client_options = {"timeout": stream_timeout, "limits": pool_limits}
    if proxy_mounts:
        # Only hand over ``mounts`` when a proxy is actually configured.
        client_options["mounts"] = proxy_mounts

    GLOBAL_CLIENT = httpx.AsyncClient(**client_options)
115
-
116
async def _close_global_client():
    """Close the shared HTTP client, if one exists, and clear ``GLOBAL_CLIENT``."""
    global GLOBAL_CLIENT
    if not GLOBAL_CLIENT:
        return
    await GLOBAL_CLIENT.aclose()
    GLOBAL_CLIENT = None
121
-
122
- # ------------------------------------------------------------------------------
123
- # Token 缓存和管理
124
- # ------------------------------------------------------------------------------
125
-
126
# In-memory token cache, keyed by the SHA-256 hex digest of the raw token the
# client sent: {hash: {accessToken, refreshToken, clientId, clientSecret, lastRefresh}}
TOKEN_MAP: Dict[str, Dict[str, Any]] = {}
128
-
129
- def _sha256(text: str) -> str:
130
- """计算 SHA256 哈希"""
131
- return hashlib.sha256(text.encode()).hexdigest()
132
-
133
- def _parse_bearer_token(bearer_token: str) -> Tuple[str, str, str]:
134
- """
135
- ��析 Bearer token: clientId:clientSecret:refreshToken
136
- 重要: refreshToken 中可能包含冒号,所以要正确处理
137
- """
138
- temp_array = bearer_token.split(":")
139
- client_id = temp_array[0] if len(temp_array) > 0 else ""
140
- client_secret = temp_array[1] if len(temp_array) > 1 else ""
141
- refresh_token = ":".join(temp_array[2:]) if len(temp_array) > 2 else ""
142
- return client_id, client_secret, refresh_token
143
-
144
async def _handle_token_refresh(client_id: str, client_secret: str, refresh_token: str) -> Optional[str]:
    """Exchange a refresh token for a new access token.

    Sends a ``refresh_token`` grant as JSON to ``TOKEN_URL``. The shared
    ``GLOBAL_CLIENT`` is used when it exists; otherwise a temporary client
    with a 60-second timeout is created for this one call.

    Returns:
        The ``accessToken`` field of the JSON response, or ``None`` when the
        request fails for any reason (the error and a traceback are printed).
    """
    body = {
        "grantType": "refresh_token",
        "clientId": client_id,
        "clientSecret": client_secret,
        "refreshToken": refresh_token,
    }

    async def _exchange(http_client):
        """POST the grant through *http_client* and return the decoded JSON."""
        response = await http_client.post(TOKEN_URL, headers=_oidc_headers(), json=body)
        response.raise_for_status()
        return response.json()

    try:
        if GLOBAL_CLIENT:
            token_data = await _exchange(GLOBAL_CLIENT)
        else:
            async with httpx.AsyncClient(timeout=60.0) as fallback_client:
                token_data = await _exchange(fallback_client)
        return token_data.get("accessToken")
    except httpx.HTTPStatusError as err:
        print(f"Token refresh HTTP error: {err.response.status_code} - {err.response.text}")
        traceback.print_exc()
    except Exception as err:
        print(f"Token refresh error: {err}")
        traceback.print_exc()
    return None
174
-
175
- # ------------------------------------------------------------------------------
176
- # 全局 Token 刷新器
177
- # ------------------------------------------------------------------------------
178
-
179
async def _global_token_refresher():
    """Background loop that refreshes every cached access token.

    Sleeps 45 minutes, then walks a snapshot of ``TOKEN_MAP`` and calls
    ``_handle_token_refresh`` for each entry. Entries that refresh
    successfully get a new ``accessToken`` and ``lastRefresh`` timestamp. A
    failure on one entry is printed and does not stop the cycle. If the cycle
    itself raises, the traceback is printed and the loop waits one extra
    minute before continuing. The function never returns on its own.
    """
    refresh_interval = 45 * 60  # seconds
    retry_delay = 60            # seconds to wait after an unexpected error

    while True:
        try:
            await asyncio.sleep(refresh_interval)
            if not TOKEN_MAP:
                continue

            print("[Token Refresher] Starting token refresh cycle...")
            succeeded = 0
            # Iterate over a snapshot so entries added meanwhile don't disturb the loop.
            for cache_key, entry in list(TOKEN_MAP.items()):
                try:
                    fresh_token = await _handle_token_refresh(
                        entry["clientId"],
                        entry["clientSecret"],
                        entry["refreshToken"],
                    )
                    if not fresh_token:
                        print(f"[Token Refresher] Failed to refresh token for hash: {cache_key[:8]}...")
                        continue
                    TOKEN_MAP[cache_key]["accessToken"] = fresh_token
                    TOKEN_MAP[cache_key]["lastRefresh"] = time.time()
                    succeeded += 1
                except Exception as exc:
                    print(f"[Token Refresher] Exception refreshing token: {exc}")
                    traceback.print_exc()

            print(f"[Token Refresher] Refreshed {succeeded}/{len(TOKEN_MAP)} tokens")
        except Exception:
            traceback.print_exc()
            await asyncio.sleep(retry_delay)
208
-
209
- # ------------------------------------------------------------------------------
210
- # Token refresh (OIDC)
211
- # ------------------------------------------------------------------------------
212
-
213
- OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
214
- TOKEN_URL = f"{OIDC_BASE}/token"
215
-
216
- def _oidc_headers() -> Dict[str, str]:
217
- return {
218
- "content-type": "application/json",
219
- "user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
220
- "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
221
- "amz-sdk-request": "attempt=1; max=3",
222
- "amz-sdk-invocation-id": str(uuid.uuid4()),
223
- }
224
-
225
- # ------------------------------------------------------------------------------
226
- # 认证中间件
227
- # ------------------------------------------------------------------------------
228
-
229
async def auth_middleware(
    authorization: Optional[str] = Header(default=None),
    x_api_key: Optional[str] = Header(default=None, alias="x-api-key")
) -> Dict[str, Any]:
    """Turn the caller's credential header into Amazon Q account data.

    Accepts either a Claude-style ``x-api-key`` header (preferred) or an
    OpenAI-style ``Authorization: Bearer <token>`` header. The token must have
    the form ``clientId:clientSecret:refreshToken``. The first time a token is
    seen it is exchanged for an access token and the result is cached in
    ``TOKEN_MAP`` under the token's SHA-256 digest; later requests with the
    same token are served from the cache.

    Returns:
        A dict with ``accessToken``, ``clientId``, ``clientSecret`` and
        ``refreshToken``.

    Raises:
        HTTPException: 401 when no credential is supplied, the token does not
            have all three parts, or the access-token exchange fails.
    """
    # x-api-key (Claude format) wins when both headers are present.
    token = x_api_key or None

    # Otherwise fall back to the Authorization header (OpenAI format). HTTP
    # authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" are accepted as well as "Bearer".
    if not token and authorization:
        scheme, _, credentials = authorization.partition(" ")
        if scheme.lower() == "bearer":
            token = credentials.strip()

    if not token:
        raise HTTPException(status_code=401, detail="Missing authentication. Provide Authorization header or x-api-key")

    token_hash = _sha256(token)

    cached = TOKEN_MAP.get(token_hash)
    if cached is None:
        client_id, client_secret, refresh_token = _parse_bearer_token(token)
        if not (client_id and client_secret and refresh_token):
            raise HTTPException(status_code=401, detail="Invalid token format. Expected: clientId:clientSecret:refreshToken")

        access_token = await _handle_token_refresh(client_id, client_secret, refresh_token)
        if not access_token:
            raise HTTPException(status_code=401, detail="Failed to refresh access token")

        cached = {
            "accessToken": access_token,
            "refreshToken": refresh_token,
            "clientId": client_id,
            "clientSecret": client_secret,
            "lastRefresh": time.time()
        }
        TOKEN_MAP[token_hash] = cached

    # Hand back a copy without the bookkeeping field (lastRefresh).
    return {
        "accessToken": cached["accessToken"],
        "clientId": cached["clientId"],
        "clientSecret": cached["clientSecret"],
        "refreshToken": cached["refreshToken"],
    }
284
-
285
- # ------------------------------------------------------------------------------
286
- # Dependencies
287
- # ------------------------------------------------------------------------------
288
-
289
async def require_account(
    authorization: Optional[str] = Header(default=None),
    x_api_key: Optional[str] = Header(default=None, alias="x-api-key")
) -> Dict[str, Any]:
    """FastAPI dependency that forwards both credential headers to ``auth_middleware``."""
    account = await auth_middleware(authorization=authorization, x_api_key=x_api_key)
    return account
294
-
295
- # ------------------------------------------------------------------------------
296
- # Root endpoint
297
- # ------------------------------------------------------------------------------
298
-
299
@app.get("/")
async def root():
    """Redirect requests for the root path to an external video page."""
    target_url = "https://www.bilibili.com/video/BV1SMH5zfEwe/?spm_id_from=333.337.search-card.all.click&vd_source=1f3b8eb28230105c578a443fa6481550"
    return RedirectResponse(url=target_url)
302
-
303
- # ------------------------------------------------------------------------------
304
- # Claude Messages API endpoint
305
- # ------------------------------------------------------------------------------
306
-
307
def _iter_sse_payloads(chunk: str):
    """Yield each JSON object carried on a ``data: `` line of one SSE chunk."""
    # NOTE(review): the chunk layout produced by ClaudeStreamHandler is not
    # visible from this file. Scanning line by line accepts a bare
    # "data: {...}" chunk as well as one that puts an "event: ..." line first;
    # a whole-chunk startswith("data: ") test would skip the latter entirely.
    for line in chunk.splitlines():
        if not line.startswith("data: "):
            continue
        raw = line[6:].strip()
        if not raw or raw == "[DONE]":
            continue
        try:
            decoded = json.loads(raw)
        except json.JSONDecodeError:
            continue  # not JSON; nothing to accumulate
        if isinstance(decoded, dict):
            yield decoded


def _fold_sse_event(state: Dict[str, Any], event: Dict[str, Any]) -> None:
    """Apply one decoded stream event to the non-streaming accumulator *state*."""
    event_type = event.get("type")

    if event_type == "message_delta":
        state["usage"] = event.get("usage", state["usage"])
        state["stop_reason"] = (event.get("delta") or {}).get("stop_reason")
        return

    if event_type not in ("content_block_start", "content_block_delta", "content_block_stop"):
        return

    blocks = state["content"]
    idx = event.get("index", 0)
    if not isinstance(idx, int) or idx < 0:
        return

    if event_type == "content_block_start":
        # Pad with placeholders so blocks can arrive at any index.
        while len(blocks) <= idx:
            blocks.append(None)
        blocks[idx] = event.get("content_block")
        return

    # Deltas/stops for a block that was never started are ignored.
    block = blocks[idx] if idx < len(blocks) else None
    if not block:
        return

    if event_type == "content_block_delta":
        delta = event.get("delta") or {}
        delta_type = delta.get("type")
        if delta_type == "text_delta":
            block["text"] = block.get("text", "") + delta.get("text", "")
        elif delta_type == "input_json_delta":
            # Tool input arrives as JSON fragments; keep the raw text until
            # the block stops and the whole document can be parsed.
            block["partial_json"] = block.get("partial_json", "") + delta.get("partial_json", "")
        return

    # content_block_stop: finalize tool-use input.
    if block.get("type") == "tool_use" and "partial_json" in block:
        raw_input = block.pop("partial_json")
        try:
            block["input"] = json.loads(raw_input)
        except json.JSONDecodeError:
            # Unparsable tool input: keep whatever "input" the start event carried.
            pass


@app.post("/v1/messages")
async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(require_account)):
    """
    Claude-compatible messages endpoint.

    The request is converted to an Amazon Q payload and always sent upstream
    in streaming mode. When ``req.stream`` is set, the handler's SSE chunks
    are streamed back as ``text/event-stream``; otherwise the same chunks are
    consumed here and folded into a single message object.

    Raises:
        HTTPException: 400 when the request cannot be converted; 502 when the
            account has no access token or upstream returns no event stream.
    """
    # 1. Convert request
    try:
        aq_request = convert_claude_to_amazonq_request(req)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}") from e

    # 2. Send upstream - always stream from upstream to get full event details
    access = account.get("accessToken")
    if not access:
        raise HTTPException(status_code=502, detail="Access token unavailable")

    _, _, _tracker, event_iter = await send_chat_request(
        access_token=access,
        messages=[],
        model=req.model,
        stream=True,
        client=GLOBAL_CLIENT,
        raw_payload=aq_request
    )

    if not event_iter:
        raise HTTPException(status_code=502, detail="No event stream returned")

    # Input tokens are not estimated yet; the handler is told 0.
    handler = ClaudeStreamHandler(model=req.model, input_tokens=0)

    async def event_generator():
        async for event_type, payload in event_iter:
            async for sse in handler.handle_event(event_type, payload):
                yield sse
        async for sse in handler.finish():
            yield sse

    if req.stream:
        return StreamingResponse(event_generator(), media_type="text/event-stream")

    # 3. Non-streaming: rebuild the full message from the SSE chunks.
    state: Dict[str, Any] = {
        "content": [],
        "usage": {"input_tokens": 0, "output_tokens": 0},
        "stop_reason": None,
    }
    async for sse_chunk in event_generator():
        for event in _iter_sse_payloads(sse_chunk):
            try:
                _fold_sse_event(state, event)
            except (AttributeError, KeyError, TypeError):
                # Best effort: a malformed event is skipped, but logged
                # instead of being dropped silently.
                traceback.print_exc()

    return {
        "id": f"msg_{uuid.uuid4()}",
        "type": "message",
        "role": "assistant",
        "model": req.model,
        "content": [c for c in state["content"] if c is not None],
        "stop_reason": state["stop_reason"],
        "stop_sequence": None,
        "usage": state["usage"]
    }
432
-
433
- # ------------------------------------------------------------------------------
434
- # Startup / Shutdown Events
435
- # ------------------------------------------------------------------------------
436
-
437
# Strong references to running background tasks. The event loop itself keeps
# only weak references, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: set = set()


async def _startup():
    """Create the shared HTTP client and start the token-refresh task."""
    await _init_global_client()
    refresher = asyncio.create_task(_global_token_refresher())
    _BACKGROUND_TASKS.add(refresher)
    refresher.add_done_callback(_BACKGROUND_TASKS.discard)


async def _shutdown():
    """Cancel background tasks, then close the shared HTTP client."""
    # Stop the refresher first so it cannot use the client after it is closed.
    pending = list(_BACKGROUND_TASKS)
    for task in pending:
        task.cancel()
    if pending:
        # return_exceptions=True collects each task's CancelledError instead
        # of raising it here.
        await asyncio.gather(*pending, return_exceptions=True)
    _BACKGROUND_TASKS.clear()
    await _close_global_client()
445
-
446
- # 更新 lifespan 上下文管理器使用实际的启动/关闭逻辑
447
@asynccontextmanager
async def lifespan(app_instance: FastAPI):
    """
    Application lifespan handler.

    Runs ``_startup()`` before the application starts serving and
    ``_shutdown()`` when it stops. The shutdown step sits in ``finally`` so
    resources are released even if an exception is thrown into the context
    while the application is running.
    """
    await _startup()
    try:
        yield
    finally:
        await _shutdown()


# Attach the lifespan handler to the already-created app.
app.router.lifespan_context = lifespan
459
-
460
- # ------------------------------------------------------------------------------
461
- # 直接运行支持
462
- # ------------------------------------------------------------------------------
463
-
464
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable overrides the default port 8000.
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")