StarrySkyWorld committed on
Commit
ebb8447
·
verified ·
1 Parent(s): 83a6b2d

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +799 -0
main.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini API 兼容中间件服务器
3
+ 透明代理 - 将 HTTP 请求通过 WebSocket 转发给 WSClient 处理
4
+
5
+ 版本: 1.0.0
6
+ 协议: Gemini Compatible WebSocket Proxy Protocol
7
+ """
8
+
9
+ import asyncio
10
+ import json
11
+ import time
12
+ import uuid
13
+ import logging
14
+ from dataclasses import dataclass, field
15
+ from enum import Enum
16
+ from typing import Optional, Dict, List, Any, AsyncGenerator
17
+ from contextlib import asynccontextmanager
18
+ from datetime import datetime
19
+
20
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Path
21
+ from fastapi.responses import StreamingResponse, JSONResponse
22
+ from pydantic import BaseModel
23
+ import uvicorn
24
+
25
+
26
+ # ============================================================
27
+ # 日志配置
28
+ # ============================================================
29
+
30
class ColoredFormatter(logging.Formatter):
    """Log formatter that wraps the level name and the message in ANSI colors."""

    # ANSI escape sequence used for each standard level name.
    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
    }
    # Sequence that switches the terminal back to its default color.
    RESET = '\033[0m'

    def format(self, record):
        """Return the formatted, colorized text for ``record``.

        The record itself is left untouched: the colors are applied to a
        shallow copy, so formatting the same record again (or handing it to
        another handler) does not see or re-wrap the escape codes.
        """
        import copy  # function-scope import keeps the block self-contained

        color = self.COLORS.get(record.levelname, self.RESET)
        colored = copy.copy(record)
        colored.levelname = f"{color}{record.levelname}{self.RESET}"
        colored.msg = f"{color}{record.msg}{self.RESET}"
        return super().format(colored)
47
+
48
+
49
def setup_logging():
    """Configure and return the ``"middleware"`` logger.

    The logger is set to DEBUG and given one stream handler that uses
    ``ColoredFormatter``. The handler is attached only when the logger has
    no handlers yet, so calling this function more than once does not make
    every message appear multiple times.

    Returns:
        The configured ``logging.Logger`` named ``"middleware"``.
    """
    logger = logging.getLogger("middleware")
    logger.setLevel(logging.DEBUG)

    # logging.getLogger returns the same object for the same name, so a
    # second call would otherwise stack a duplicate handler on it.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(ColoredFormatter(
            fmt='%(asctime)s | %(levelname)-17s | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        logger.addHandler(handler)
    return logger
61
+
62
+
63
# Module-level logger used by every component defined in this file.
log = setup_logging()
64
+
65
+
66
+ # ============================================================
67
+ # 配置
68
+ # ============================================================
69
+
70
class Config:
    """Static server settings, read directly as class attributes."""

    # Address and port the HTTP/WebSocket server binds to.
    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # API key expected from HTTP clients -- change this to your own key.
    API_KEY: str = "sk-123456"

    # Request timeout, in seconds.
    REQUEST_TIMEOUT: int = 120
    # Heartbeat interval, in seconds.
    HEARTBEAT_INTERVAL: int = 30
    # Registration timeout, in seconds.
    REGISTER_TIMEOUT: int = 10

    # Maximum number of characters of a body shown in log output.
    LOG_BODY_MAX_LENGTH: int = 500
78
+
79
+
80
+ # ============================================================
81
+ # 协议定义
82
+ # ============================================================
83
+
84
class MessageType(str, Enum):
    """Values of the ``type`` field carried by WebSocket protocol messages."""

    # --- connection management ---
    REGISTER = "register"
    REGISTER_ACK = "register_ack"

    # --- request ---
    REQUEST = "request"

    # --- response ---
    RESPONSE = "response"
    CHUNK = "chunk"
    END = "end"
    ERROR = "error"

    # --- control ---
    ABORT = "abort"
    PING = "ping"
    PONG = "pong"
99
+
100
+
101
+ # ============================================================
102
+ # 请求上下文
103
+ # ============================================================
104
+
105
class RequestStatus(str, Enum):
    """Lifecycle states recorded on a request context."""

    PENDING = "pending"        # created, nothing received yet
    STREAMING = "streaming"    # chunks are being consumed
    COMPLETED = "completed"    # response resolved / stream ended
    ERROR = "error"            # failed or timed out
    ABORTED = "aborted"        # cancelled
111
+
112
+
113
@dataclass
class RequestContext:
    """Per-request state used to track one proxied request.

    Instances must be created while an event loop is running, because the
    default ``response_future`` is created from the running loop.
    """

    id: str
    created_at: float
    is_stream: bool
    model: str
    status: RequestStatus = RequestStatus.PENDING
    # Non-streaming requests: a future awaited for the complete response.
    response_future: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
    # Streaming requests: a queue through which chunks are handed over.
    chunk_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    # Statistics: number of chunks consumed so far.
    chunk_count: int = 0

    def elapsed_ms(self) -> int:
        """Return the time since ``created_at`` in whole milliseconds."""
        elapsed_seconds = time.time() - self.created_at
        return int(elapsed_seconds * 1000)
131
+
132
+
133
+ # ============================================================
134
+ # 请求管理器
135
+ # ============================================================
136
+
137
class RequestManager:
    """Registry of in-flight requests, keyed by request id.

    Non-streaming requests are completed through ``response_future``;
    streaming requests receive their chunks through ``chunk_queue``, where
    ``None`` is the end-of-stream sentinel.
    """

    def __init__(self):
        self.pending_requests: Dict[str, RequestContext] = {}
        self._lock = asyncio.Lock()

    async def create_request(self, is_stream: bool, model: str) -> RequestContext:
        """Create a context with a fresh UUID, register it and return it."""
        request_id = str(uuid.uuid4())
        ctx = RequestContext(
            id=request_id,
            created_at=time.time(),
            is_stream=is_stream,
            model=model
        )
        async with self._lock:
            self.pending_requests[request_id] = ctx

        log.debug(f"[ReqMgr] 创建请求 | id={request_id[:8]}... | stream={is_stream} | model={model}")
        return ctx

    def get_request(self, request_id: str) -> Optional[RequestContext]:
        """Return the pending context for ``request_id``, or ``None``."""
        return self.pending_requests.get(request_id)

    async def wait_for_response(self, request_id: str, timeout: float) -> Dict:
        """Wait for the complete response of a non-streaming request.

        Args:
            request_id: Id of a pending request.
            timeout: Maximum number of seconds to wait.

        Returns:
            The value the response future was resolved with.

        Raises:
            ValueError: If the request id is not pending.
            TimeoutError: If no response arrives within ``timeout`` seconds.
            Exception: Whatever exception was set on the future by
                ``fail_request`` or ``abort_request``.

        The request is removed from the registry when this method exits,
        whatever the outcome.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")

        try:
            log.debug(f"[ReqMgr] 等待响应 | id={request_id[:8]}... | timeout={timeout}s")
            result = await asyncio.wait_for(ctx.response_future, timeout=timeout)
            log.info(f"[ReqMgr] 收到响应 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            return result
        except asyncio.TimeoutError as exc:
            ctx.status = RequestStatus.ERROR
            log.error(f"[ReqMgr] 请求超时 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            raise TimeoutError(f"Request {request_id} timed out after {timeout}s") from exc
        finally:
            await self._cleanup_request(request_id)

    async def wait_for_stream(self, request_id: str, timeout: float) -> AsyncGenerator[Dict, None]:
        """Yield the chunks of a streaming request until the end sentinel.

        Args:
            request_id: Id of a pending request.
            timeout: Maximum number of seconds to wait for each chunk.

        Raises:
            ValueError: If the request id is not pending.
            TimeoutError: If no chunk arrives within ``timeout`` seconds.

        The request is removed from the registry when the generator
        finishes or is closed.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")

        ctx.status = RequestStatus.STREAMING
        log.debug(f"[ReqMgr] 开始流式接收 | id={request_id[:8]}...")

        try:
            while True:
                try:
                    chunk = await asyncio.wait_for(ctx.chunk_queue.get(), timeout=timeout)
                except asyncio.TimeoutError as exc:
                    ctx.status = RequestStatus.ERROR
                    log.error(f"[ReqMgr] 流超时 | id={request_id[:8]}... | chunks={ctx.chunk_count}")
                    raise TimeoutError(f"Stream {request_id} timed out") from exc

                if chunk is None:  # end-of-stream sentinel
                    log.info(f"[ReqMgr] 流结束 | id={request_id[:8]}... | chunks={ctx.chunk_count} | elapsed={ctx.elapsed_ms()}ms")
                    break
                ctx.chunk_count += 1
                yield chunk
        finally:
            await self._cleanup_request(request_id)

    def resolve_request(self, request_id: str, response: Dict):
        """Resolve a non-streaming request with ``response``.

        Does nothing when the id is unknown or the future is already done.
        """
        ctx = self.pending_requests.get(request_id)
        if ctx and not ctx.response_future.done():
            ctx.status = RequestStatus.COMPLETED
            ctx.response_future.set_result(response)
            log.debug(f"[ReqMgr] 请求已解决 | id={request_id[:8]}...")

    def push_chunk(self, request_id: str, chunk: Dict):
        """Enqueue one chunk for a pending streaming request."""
        ctx = self.pending_requests.get(request_id)
        if ctx and ctx.is_stream:
            ctx.chunk_queue.put_nowait(chunk)

    def end_stream(self, request_id: str, final_body: Optional[Dict] = None):
        """Mark a stream as completed.

        A non-empty ``final_body`` is enqueued as a last chunk before the
        end sentinel.
        """
        ctx = self.pending_requests.get(request_id)
        if ctx:
            ctx.status = RequestStatus.COMPLETED
            if final_body:
                ctx.chunk_queue.put_nowait(final_body)
            ctx.chunk_queue.put_nowait(None)  # end-of-stream sentinel

    def fail_request(self, request_id: str, error: Dict):
        """Mark a pending request as failed and wake up its waiter.

        Non-streaming requests get an exception (whose message is the JSON
        dump of ``error``) set on their future. Streaming requests receive
        ``error`` as a final chunk followed by the end sentinel, so the
        consumer sees the failure instead of a silent, clean end of stream.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return

        ctx.status = RequestStatus.ERROR
        log.error(f"[ReqMgr] 请求失败 | id={request_id[:8]}... | error={error}")
        if ctx.is_stream:
            # Nobody awaits the future of a stream request; setting an
            # exception on it would only produce a "never retrieved" warning.
            if error:
                ctx.chunk_queue.put_nowait(error)
            ctx.chunk_queue.put_nowait(None)
        elif not ctx.response_future.done():
            ctx.response_future.set_exception(Exception(json.dumps(error)))

    def abort_request(self, request_id: str):
        """Mark a pending request as aborted and wake up its waiter."""
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return

        ctx.status = RequestStatus.ABORTED
        log.warning(f"[ReqMgr] 请求中止 | id={request_id[:8]}...")
        if ctx.is_stream:
            ctx.chunk_queue.put_nowait(None)
        elif not ctx.response_future.done():
            ctx.response_future.set_exception(Exception("Request aborted"))

    async def fail_all_requests(self, error_message: str):
        """Fail every pending request with a 503 ``UNAVAILABLE`` error body."""
        async with self._lock:
            count = len(self.pending_requests)
            if count > 0:
                log.warning(f"[ReqMgr] 批量失败 | count={count} | reason={error_message}")
            # Iterate over a snapshot of the keys.
            for request_id in list(self.pending_requests):
                self.fail_request(request_id, {
                    "error": {
                        "code": 503,
                        "message": error_message,
                        "status": "UNAVAILABLE"
                    }
                })

    async def _cleanup_request(self, request_id: str):
        """Remove a request from the registry; unknown ids are ignored."""
        async with self._lock:
            self.pending_requests.pop(request_id, None)

    @property
    def pending_count(self) -> int:
        """Number of requests currently registered."""
        return len(self.pending_requests)
273
+
274
+
275
+ # ============================================================
276
+ # WebSocket 管理器
277
+ # ============================================================
278
+
279
class WebSocketManager:
    """Holds the single active WSClient connection and sends messages to it."""

    def __init__(self):
        self.active_connection: Optional[WebSocket] = None
        self.client_id: Optional[str] = None
        self.client_version: Optional[str] = None
        self.models: List[str] = []
        self.max_concurrent: int = 1
        self.connected_at: Optional[float] = None
        self._lock = asyncio.Lock()
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def register(self, websocket: WebSocket, payload: Dict):
        """Make ``websocket`` the active connection and acknowledge it.

        A different, previously active connection is closed first. Client
        metadata is read from ``payload`` (``clientId``, ``clientVersion``,
        ``models``, ``maxConcurrent``), a ``register_ack`` message is sent
        back, and the heartbeat task is (re)started.
        """
        async with self._lock:
            # Close the previous connection, if it is a different socket.
            if self.active_connection and self.active_connection is not websocket:
                log.warning("[WSMgr] 关闭旧连接...")
                try:
                    await self.active_connection.close()
                except Exception as e:
                    # Best effort: the old socket may already be closed.
                    log.debug(f"[WSMgr] 关闭旧连接失败 | error={e}")

            self.active_connection = websocket
            self.client_id = payload.get("clientId", "unknown")
            self.client_version = payload.get("clientVersion", "unknown")
            self.models = payload.get("models", [])
            self.max_concurrent = payload.get("maxConcurrent", 1)
            self.connected_at = time.time()

            # Send the registration acknowledgement; intervals are sent in
            # milliseconds.
            ack_message = {
                "id": str(uuid.uuid4()),
                "type": MessageType.REGISTER_ACK.value,
                "timestamp": int(time.time() * 1000),
                "payload": {
                    "success": True,
                    "serverId": "gemini-middleware-001",
                    "config": {
                        "heartbeatInterval": Config.HEARTBEAT_INTERVAL * 1000,
                        "requestTimeout": Config.REQUEST_TIMEOUT * 1000
                    }
                }
            }
            await websocket.send_json(ack_message)

            log.info(f"[WSMgr] ✅ WSClient 已注册")
            log.info(f"[WSMgr] ├─ clientId: {self.client_id}")
            log.info(f"[WSMgr] ├─ version: {self.client_version}")
            log.info(f"[WSMgr] ├─ models: {self.models}")
            log.info(f"[WSMgr] └─ maxConcurrent: {self.max_concurrent}")

            # Restart the heartbeat for the new connection.
            if self._heartbeat_task:
                self._heartbeat_task.cancel()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def unregister(self, websocket: WebSocket):
        """Forget ``websocket`` if it is the active connection.

        All client metadata is reset to its initial value and the heartbeat
        task is cancelled. Calls for any other socket are ignored.
        """
        async with self._lock:
            if self.active_connection is not websocket:
                return

            uptime = int(time.time() - self.connected_at) if self.connected_at else 0
            log.warning(f"[WSMgr] ❌ WSClient 已断开 | uptime={uptime}s")

            self.active_connection = None
            self.client_id = None
            self.client_version = None
            self.models = []
            self.max_concurrent = 1
            self.connected_at = None

            if self._heartbeat_task:
                self._heartbeat_task.cancel()
                self._heartbeat_task = None

    def is_available(self) -> bool:
        """Return ``True`` while a WSClient connection is registered."""
        return self.active_connection is not None

    async def send_message(self, message: Dict):
        """Send ``message`` as JSON to the WSClient.

        Nothing is sent (and nothing is raised) when no client is connected.
        """
        if self.active_connection:
            await self.active_connection.send_json(message)
            log.debug(f"[WSMgr] 发送消息 | type={message.get('type')} | id={message.get('id', '')[:8]}...")

    async def send_request(self, request_id: str, is_stream: bool, body: Dict):
        """Wrap ``body`` in a ``request`` message and send it to the WSClient."""
        message = {
            "id": request_id,
            "type": MessageType.REQUEST.value,
            "timestamp": int(time.time() * 1000),
            "stream": is_stream,
            "body": body
        }
        await self.send_message(message)

        body_preview = self._truncate_body(body)
        log.info(f"[WSMgr] 📤 发送请求 | id={request_id[:8]}... | stream={is_stream}")
        log.debug(f"[WSMgr] └─ body: {body_preview}")

    async def send_abort(self, request_id: str, reason: str = "client_disconnected"):
        """Send an ``abort`` message for ``request_id`` to the WSClient."""
        message = {
            "id": request_id,
            "type": MessageType.ABORT.value,
            "timestamp": int(time.time() * 1000),
            "reason": reason
        }
        await self.send_message(message)
        log.warning(f"[WSMgr] 🚫 发送取消 | id={request_id[:8]}... | reason={reason}")

    async def _heartbeat_loop(self):
        """Send a ``ping`` every ``Config.HEARTBEAT_INTERVAL`` seconds.

        Runs until the task is cancelled; other errors are logged and the
        loop keeps going.
        """
        while True:
            try:
                await asyncio.sleep(Config.HEARTBEAT_INTERVAL)
                if self.active_connection:
                    ping_id = str(uuid.uuid4())
                    await self.send_message({
                        "id": ping_id,
                        "type": MessageType.PING.value,
                        "timestamp": int(time.time() * 1000)
                    })
                    log.debug(f"[WSMgr] 💓 发送心跳 | id={ping_id[:8]}...")
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.error(f"[WSMgr] 心跳错误: {e}")

    def _truncate_body(self, body: Dict) -> str:
        """Return ``body`` as JSON text, cut to ``Config.LOG_BODY_MAX_LENGTH``.

        Truncated text gets a trailing ``"..."``.
        """
        s = json.dumps(body, ensure_ascii=False)
        if len(s) > Config.LOG_BODY_MAX_LENGTH:
            return s[:Config.LOG_BODY_MAX_LENGTH] + "..."
        return s

    def get_status(self) -> Dict:
        """Return a snapshot of the connection state; ``uptime`` is in seconds."""
        return {
            "connected": self.is_available(),
            "clientId": self.client_id,
            "clientVersion": self.client_version,
            "models": self.models,
            "maxConcurrent": self.max_concurrent,
            "uptime": int(time.time() - self.connected_at) if self.connected_at else 0
        }
423
+
424
+
425
+ # ============================================================
426
+ # 全局实例
427
+ # ============================================================
428
+
429
# Process-wide singletons used by the HTTP routes and the WebSocket endpoint.
request_manager = RequestManager()
ws_manager = WebSocketManager()
431
+
432
+
433
+ # ============================================================
434
+ # FastAPI 应用
435
+ # ============================================================
436
+
437
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log a startup banner with the service URLs, then a shutdown notice."""
    separator = "=" * 60
    address = f"{Config.HOST}:{Config.PORT}"
    banner_lines = (
        separator,
        " Gemini API 兼容中间件",
        separator,
        f" HTTP 端点: http://{address}",
        f" WebSocket: ws://{address}/ws",
        f" API 文档: http://{address}/docs",
        separator,
    )
    for line in banner_lines:
        log.info(line)

    yield

    log.info("[Server] 服务关闭")
448
+
449
+
450
# The ASGI application; its startup/shutdown logging is done by `lifespan`.
_app_settings = {
    "title": "Gemini API Compatible Middleware",
    "description": "透明代理 - 将 Gemini API 请求通过 WebSocket 转发给 WSClient",
    "version": "1.0.0",
    "lifespan": lifespan,
}
app = FastAPI(**_app_settings)
456
+
457
+
458
+ # ============================================================
459
+ # API Key 校验 - 支持多种传递方式
460
+ # ============================================================
461
+
462
def extract_api_key(request: Request) -> Optional[str]:
    """Return the API key carried by ``request``, or ``None`` if there is none.

    Locations are checked in this order of priority:
      1. Header ``x-goog-api-key``
      2. Header ``Authorization: Bearer <key>``
      3. Query parameter ``?key=<key>``
    """
    headers = request.headers

    # 1) Dedicated x-goog-api-key header.
    header_key = headers.get("x-goog-api-key")
    if header_key:
        log.debug(f"[Auth] 从 x-goog-api-key header 获取 key")
        return header_key

    # 2) Bearer token in the Authorization header (scheme is case-insensitive).
    authorization = headers.get("authorization", "")
    if authorization.lower().startswith("bearer "):
        bearer_key = authorization[7:].strip()
        log.debug(f"[Auth] 从 Authorization header 获取 key")
        return bearer_key

    # 3) "key" query parameter.
    query_key = request.query_params.get("key")
    if not query_key:
        return None
    log.debug(f"[Auth] 从 query parameter 获取 key")
    return query_key
489
+
490
+
491
def verify_api_key(request: Request):
    """Validate the API key carried by ``request`` against ``Config.API_KEY``.

    Raises:
        HTTPException: 401 with an ``UNAUTHENTICATED`` error body when the
            key is missing or does not match.
    """
    import hmac  # function-scope import keeps the block self-contained

    api_key = extract_api_key(request)

    # Log only a short prefix: the previous 10-character prefix exposed the
    # whole key whenever the key was that short.
    masked_key = f"{api_key[:4]}***" if api_key else "None"
    log.debug(f"[Auth] 提取到的 Key: {masked_key}")

    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    key_matches = api_key is not None and hmac.compare_digest(
        api_key.encode("utf-8"), Config.API_KEY.encode("utf-8")
    )
    if not key_matches:
        log.warning(f"[Auth] ⛔ 认证失败 | key={masked_key}")
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "code": 401,
                    "message": "Invalid API key. Provide via 'x-goog-api-key' header, 'Authorization: Bearer <key>' header, or '?key=<key>' query parameter.",
                    "status": "UNAUTHENTICATED"
                }
            }
        )
    log.debug("[Auth] ✓ API Key 验证通过")
511
+
512
+
513
+ # ============================================================
514
+ # HTTP 路由 - Gemini API 兼容
515
+ # ============================================================
516
+
517
@app.get("/")
async def root():
    """Root path: service information and the list of endpoints."""
    endpoint_map = {
        "models": "/v1beta/models",
        "generateContent": "/v1beta/models/{model}:generateContent",
        "streamGenerateContent": "/v1beta/models/{model}:streamGenerateContent",
        "health": "/health",
        "websocket": "/ws"
    }
    info = {
        "service": "Gemini API Compatible Middleware",
        "version": "1.0.0",
        "status": "running",
        "wsClientConnected": ws_manager.is_available(),
        "endpoints": endpoint_map
    }
    return info
533
+
534
+
535
@app.get("/health")
async def health_check():
    """Health check: "ok" while a WSClient is connected, otherwise "degraded"."""
    client_status = ws_manager.get_status()
    if client_status["connected"]:
        overall = "ok"
    else:
        overall = "degraded"

    return {
        "status": overall,
        "timestamp": datetime.now().isoformat(),
        "wsClient": client_status,
        "pendingRequests": request_manager.pending_count
    }
545
+
546
+
547
@app.get("/v1beta/models")
async def list_models(request: Request):
    """List the models announced by the connected WSClient.

    Returns an empty list when no WSClient is connected.
    """
    verify_api_key(request)

    log.info("[API] GET /v1beta/models")

    if ws_manager.is_available():
        models = []
        for model_id in ws_manager.models:
            models.append({
                "name": f"models/{model_id}",
                "displayName": model_id,
                "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
            })

        log.info(f"[API] 返回 {len(models)} 个模型")
        return {"models": models}

    log.warning("[API] WSClient 未连接,返回空模型列表")
    return {"models": []}
569
+
570
+
571
@app.get("/v1beta/models/{model}")
async def get_model(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Return information about one model.

    Responds 503 when no WSClient is connected and 404 when the model is
    not in the WSClient's model list.
    """
    verify_api_key(request)

    log.info(f"[API] GET /v1beta/models/{model}")

    # Drop a possible "models/" prefix.
    model_name = model.replace("models/", "")

    if not ws_manager.is_available():
        unavailable = {
            "error": {"code": 503, "message": "Service unavailable", "status": "UNAVAILABLE"}
        }
        raise HTTPException(status_code=503, detail=unavailable)

    known_models = ws_manager.models
    if model_name not in known_models:
        not_found = {
            "error": {"code": 404, "message": f"Model '{model}' not found", "status": "NOT_FOUND"}
        }
        raise HTTPException(status_code=404, detail=not_found)

    return {
        "name": f"models/{model_name}",
        "displayName": model_name,
        "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
    }
599
+
600
+
601
@app.post("/v1beta/models/{model}:generateContent")
async def generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Generate content (non-streaming).

    The JSON request body is forwarded unchanged to the WSClient and the
    WSClient's response body is returned unchanged.

    Error responses:
        400 - the request body is not valid JSON
        401 - invalid API key
        503 - no WSClient connected
        504 - no response within ``Config.REQUEST_TIMEOUT`` seconds
        other - the ``error.code`` reported for the failed request when it
            is a valid 4xx/5xx status, otherwise 500
    """
    verify_api_key(request)

    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:generateContent")

    # The WSClient must be connected.
    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # Read the request body (passed through as-is).
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Request body is not valid JSON", "status": "INVALID_ARGUMENT"}
        }) from exc
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    # Create the request context.
    ctx = await request_manager.create_request(is_stream=False, model=model_name)

    try:
        # Forward the request to the WSClient and wait for its response.
        await ws_manager.send_request(ctx.id, is_stream=False, body=body)
        response_body = await request_manager.wait_for_response(ctx.id, Config.REQUEST_TIMEOUT)

        log.info(f"[API] ✅ 请求完成 | id={ctx.id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
        return JSONResponse(content=response_body)

    except TimeoutError:
        log.error(f"[API] ⏱️ 请求超时 | id={ctx.id[:8]}...")
        raise HTTPException(status_code=504, detail={
            "error": {"code": 504, "message": "Request timeout", "status": "DEADLINE_EXCEEDED"}
        })
    except Exception as e:
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        # The exception message may be the JSON dump of an error body.
        try:
            error_detail = json.loads(str(e))
        except ValueError:
            error_detail = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}

        # Reuse the reported error code as HTTP status when it is a valid one.
        status_code = 500
        inner = error_detail.get("error") if isinstance(error_detail, dict) else None
        code = inner.get("code") if isinstance(inner, dict) else None
        if isinstance(code, int) and not isinstance(code, bool) and 400 <= code <= 599:
            status_code = code
        raise HTTPException(status_code=status_code, detail=error_detail)
    finally:
        # wait_for_response cleans up after itself, but it is never reached
        # when send_request fails; removing an unknown id is a no-op.
        await request_manager._cleanup_request(ctx.id)
649
+
650
+
651
@app.post("/v1beta/models/{model}:streamGenerateContent")
async def stream_generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Generate content (streaming).

    The JSON request body is forwarded unchanged to the WSClient and the
    chunks are returned as a ``text/event-stream`` response.

    Error responses:
        400 - the request body is not valid JSON
        401 - invalid API key
        503 - no WSClient connected
    """
    verify_api_key(request)

    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:streamGenerateContent")

    # The WSClient must be connected.
    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # Read the request body (passed through as-is).
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Request body is not valid JSON", "status": "INVALID_ARGUMENT"}
        }) from exc
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    # Create the request context.
    ctx = await request_manager.create_request(is_stream=True, model=model_name)

    # Forward the request to the WSClient. If sending fails, the stream
    # generator never runs, so the context has to be removed here.
    try:
        await ws_manager.send_request(ctx.id, is_stream=True, body=body)
    except Exception:
        await request_manager._cleanup_request(ctx.id)
        raise

    # Return the streaming response.
    return StreamingResponse(
        stream_generator(ctx.id, request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
689
+
690
+
691
async def stream_generator(request_id: str, http_request: Request) -> AsyncGenerator[str, None]:
    """Yield the chunks of a streaming request as SSE ``data:`` events.

    Args:
        request_id: Id of the pending streaming request.
        http_request: The originating HTTP request, polled for disconnects.

    When the HTTP client has disconnected, an abort is sent to the WSClient
    and the stream stops. A timeout or any other error is reported to the
    client as a final ``data:`` event carrying an error body.
    """
    chunks = request_manager.wait_for_stream(request_id, Config.REQUEST_TIMEOUT)
    aborted = False

    try:
        async for chunk_body in chunks:
            # Stop as soon as the HTTP client is gone.
            if await http_request.is_disconnected():
                log.warning(f"[Stream] 客户端断开 | id={request_id[:8]}...")
                await ws_manager.send_abort(request_id, "client_disconnected")
                request_manager.abort_request(request_id)
                aborted = True
                break

            # Emit SSE format (the body is passed through).
            yield f"data: {json.dumps(chunk_body)}\n\n"

        # An aborted stream did not complete; do not log it as such.
        if not aborted:
            log.info(f"[Stream] ✅ 流完成 | id={request_id[:8]}...")

    except TimeoutError:
        log.error(f"[Stream] ⏱️ 流超时 | id={request_id[:8]}...")
        error_body = {"error": {"code": 504, "message": "Stream timeout", "status": "DEADLINE_EXCEEDED"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    except Exception as e:
        log.error(f"[Stream] ❌ 流错误 | id={request_id[:8]}... | error={e}")
        error_body = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    finally:
        # Close the inner generator explicitly so its cleanup runs now
        # rather than whenever it is garbage-collected after a `break`.
        await chunks.aclose()
717
+
718
+
719
+ # ============================================================
720
+ # WebSocket 端点
721
+ # ============================================================
722
+
723
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint the WSClient connects to.

    Every received JSON message is passed to ``handle_ws_message``. When the
    connection ends it is unregistered, and pending requests are failed only
    if this socket had actually registered as the WSClient.
    """
    await websocket.accept()
    client_ip = websocket.client.host if websocket.client else "unknown"
    log.info(f"[WS] 🔌 新连接 | ip={client_ip}")
    # NOTE(review): the timeout is only logged here; nothing in this
    # function enforces it.
    log.info(f"[WS] └─ 等待 REGISTER 消息 (timeout={Config.REGISTER_TIMEOUT}s)...")

    # Becomes True once this socket has been the active WSClient connection.
    registered = False

    try:
        while True:
            data = await websocket.receive_json()
            await handle_ws_message(websocket, data)
            if ws_manager.active_connection is websocket:
                registered = True

    except WebSocketDisconnect:
        log.warning(f"[WS] 连接断开 | ip={client_ip}")
    except Exception as e:
        log.error(f"[WS] 连接错误 | ip={client_ip} | error={e}")
    finally:
        await ws_manager.unregister(websocket)
        # A socket that never registered cannot have been serving requests,
        # so its disconnect must not fail the requests of the real client.
        if registered:
            await request_manager.fail_all_requests("WSClient disconnected")
743
+
744
+
745
async def handle_ws_message(websocket: WebSocket, data: Dict):
    """Dispatch one decoded WSClient message according to its ``type``.

    Args:
        websocket: The connection the message arrived on.
        data: The decoded JSON message; expected to be an object with
            ``type`` and ``id`` fields.

    Messages that are not JSON objects and messages of an unknown type are
    logged and ignored.
    """
    if not isinstance(data, dict):
        # receive_json may yield any JSON value; only objects are messages.
        log.warning(f"[WS] ❓ 忽略非对象消息 | python_type={type(data).__name__}")
        return

    msg_type = data.get("type", "")
    # Normalise the id to a string so the slices below cannot raise.
    raw_id = data.get("id", "")
    msg_id = "" if raw_id is None else str(raw_id)

    # Short log line.
    log.debug(f"[WS] 📩 收到消息 | type={msg_type} | id={msg_id[:8] if msg_id else 'N/A'}...")

    # Handle according to the message type.
    if msg_type == MessageType.REGISTER.value:
        payload = data.get("payload", {})
        if not isinstance(payload, dict):
            payload = {}
        await ws_manager.register(websocket, payload)

    elif msg_type == MessageType.RESPONSE.value:
        body = data.get("body", {})
        log.info(f"[WS] 📥 收到响应 | id={msg_id[:8]}...")
        log.debug(f"[WS] └─ body: {ws_manager._truncate_body(body)}")
        request_manager.resolve_request(msg_id, body)

    elif msg_type == MessageType.CHUNK.value:
        body = data.get("body", {})
        index = data.get("index", 0)
        log.debug(f"[WS] 📦 收到数据块 | id={msg_id[:8]}... | index={index}")
        request_manager.push_chunk(msg_id, body)

    elif msg_type == MessageType.END.value:
        body = data.get("body")
        log.info(f"[WS] 🏁 流结束 | id={msg_id[:8]}...")
        request_manager.end_stream(msg_id, body)

    elif msg_type == MessageType.ERROR.value:
        body = data.get("body", {})
        log.error(f"[WS] ⚠️ 收到错误 | id={msg_id[:8]}... | body={body}")
        request_manager.fail_request(msg_id, body)

    elif msg_type == MessageType.PONG.value:
        log.debug(f"[WS] 💓 收到心跳响应 | id={msg_id[:8]}...")

    else:
        log.warning(f"[WS] ❓ 未知消息类型 | type={msg_type}")
786
+
787
+
788
+ # ============================================================
789
+ # 启动
790
+ # ============================================================
791
+
792
if __name__ == "__main__":
    # Start the server when the file is run as a script.
    server_options = {
        "host": Config.HOST,
        "port": Config.PORT,
        "reload": True,
        # Keep uvicorn quiet; this module does its own logging.
        "log_level": "warning",
    }
    uvicorn.run("main:app", **server_options)