File size: 29,083 Bytes
ebb8447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2a1544
ebb8447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
"""
Gemini API 兼容中间件服务器
透明代理 - 将 HTTP 请求通过 WebSocket 转发给 WSClient 处理

版本: 1.0.0
协议: Gemini Compatible WebSocket Proxy Protocol
"""

import asyncio
import hmac
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional, Dict, List, Any, AsyncGenerator

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request, Path
from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
from pydantic import BaseModel
import uvicorn


# ============================================================
# 日志配置
# ============================================================

class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the level name and the message in ANSI color codes.

    The color is looked up in ``COLORS`` by the record's level name; levels not
    listed there fall back to ``RESET`` (i.e. effectively uncolored).
    """

    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
    }
    RESET = '\033[0m'

    def format(self, record):
        """Return the formatted, colorized text for *record* without modifying it."""
        # Work on a copy: the same LogRecord object is handed to every handler
        # (including ancestors' handlers when propagation is on), so mutating it
        # in place would leak escape codes into other outputs and re-color the
        # text each time it is formatted again.
        colored = logging.makeLogRecord(record.__dict__)
        color = self.COLORS.get(record.levelname, self.RESET)
        colored.levelname = f"{color}{record.levelname}{self.RESET}"
        # Merge the %-args first so the copy carries the final message text.
        colored.msg = f"{color}{record.getMessage()}{self.RESET}"
        colored.args = None
        return super().format(colored)


def setup_logging():
    """Return the "middleware" logger, configured with a colored stream handler.

    The logger level is set to DEBUG. The handler is attached only when the
    logger has none yet, so calling this more than once does not duplicate
    every log line.

    Returns:
        logging.Logger: the logger named "middleware".
    """
    logger = logging.getLogger("middleware")
    logger.setLevel(logging.DEBUG)

    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(ColoredFormatter(
            # Width 17 = 8 visible chars ("CRITICAL") + 9 chars of ANSI codes
            # (5 for the color prefix, 4 for the reset), keeping columns aligned.
            fmt='%(asctime)s | %(levelname)-17s | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        logger.addHandler(handler)

    return logger


# Module-wide logger (named "middleware") used by every component below.
log = setup_logging()


# ============================================================
# 配置
# ============================================================

class Config:
    """Server settings.

    Each value can be overridden with the environment variable named in its
    expression; when that variable is unset the literal default is used.
    Integer settings raise ValueError at import time if the variable is set
    to something that is not an integer.
    """
    HOST = os.environ.get("GEMINI_MW_HOST", "0.0.0.0")
    PORT = int(os.environ.get("GEMINI_MW_PORT", "8000"))
    # Change the default (or set GEMINI_MW_API_KEY) before exposing the server.
    API_KEY = os.environ.get("GEMINI_MW_API_KEY", "sk-123456")
    # Request timeout (seconds)
    REQUEST_TIMEOUT = int(os.environ.get("GEMINI_MW_REQUEST_TIMEOUT", "120"))
    # Heartbeat (PING) interval (seconds)
    HEARTBEAT_INTERVAL = int(os.environ.get("GEMINI_MW_HEARTBEAT_INTERVAL", "30"))
    # How long a new WebSocket may take to send REGISTER (seconds)
    REGISTER_TIMEOUT = int(os.environ.get("GEMINI_MW_REGISTER_TIMEOUT", "10"))
    # Maximum number of body characters shown in logs
    LOG_BODY_MAX_LENGTH = int(os.environ.get("GEMINI_MW_LOG_BODY_MAX_LENGTH", "500"))


# ============================================================
# 协议定义
# ============================================================

class MessageType(str, Enum):
    """Values of the ``type`` field of WebSocket messages exchanged with WSClient."""

    # Connection management: WSClient sends REGISTER, the server answers REGISTER_ACK.
    REGISTER = "register"
    REGISTER_ACK = "register_ack"
    # Server -> WSClient: a proxied HTTP request.
    REQUEST = "request"
    # WSClient -> server: whole response, stream chunk, end of stream, or error.
    RESPONSE = "response"
    CHUNK = "chunk"
    END = "end"
    ERROR = "error"
    # Control: ABORT and PING are sent by the server; PONG is WSClient's reply.
    ABORT = "abort"
    PING = "ping"
    PONG = "pong"


# ============================================================
# 请求上下文
# ============================================================

class RequestStatus(str, Enum):
    """Lifecycle state kept in ``RequestContext.status``."""

    PENDING = "pending"        # created; no result consumed yet
    STREAMING = "streaming"    # a stream consumer is reading chunks
    COMPLETED = "completed"    # response resolved or stream ended
    ERROR = "error"            # timed out or failed
    ABORTED = "aborted"        # aborted before completion


@dataclass
class RequestContext:
    """Mutable state of one proxied request, from creation until cleanup."""

    id: str
    # Creation time in seconds since the epoch (elapsed_ms compares it to time.time()).
    created_at: float
    is_stream: bool
    model: str
    status: RequestStatus = RequestStatus.PENDING
    # Non-streaming result. Built with get_running_loop(), so an instance can
    # only be created while an event loop is running.
    response_future: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
    # Streaming chunks; a None item marks the end of the stream.
    chunk_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    # Number of chunks handed to the stream consumer so far.
    chunk_count: int = 0

    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since ``created_at``, truncated to an int."""
        elapsed_seconds = time.time() - self.created_at
        return int(elapsed_seconds * 1000)


# ============================================================
# 请求管理器
# ============================================================

class RequestManager:
    """Registry of in-flight proxy requests, keyed by request id.

    HTTP handlers create a context and then await it via ``wait_for_response``
    (non-streaming) or ``wait_for_stream`` (streaming). The WebSocket message
    handler completes it via ``resolve_request``, ``push_chunk``,
    ``end_stream``, ``fail_request`` or ``abort_request``. A context is removed
    from ``pending_requests`` when its ``wait_for_*`` call finishes.
    """

    def __init__(self):
        # request id -> context, for every request not yet cleaned up
        self.pending_requests: Dict[str, RequestContext] = {}
        self._lock = asyncio.Lock()

    async def create_request(self, is_stream: bool, model: str) -> RequestContext:
        """Create, register and return a new context with a random UUID4 id."""
        request_id = str(uuid.uuid4())
        ctx = RequestContext(
            id=request_id,
            created_at=time.time(),
            is_stream=is_stream,
            model=model
        )
        async with self._lock:
            self.pending_requests[request_id] = ctx

        log.debug(f"[ReqMgr] 创建请求 | id={request_id[:8]}... | stream={is_stream} | model={model}")
        return ctx

    def get_request(self, request_id: str) -> Optional[RequestContext]:
        """Return the pending context for *request_id*, or None."""
        return self.pending_requests.get(request_id)

    async def wait_for_response(self, request_id: str, timeout: float) -> Dict:
        """Wait for the complete response body of a non-streaming request.

        The context is always removed before this returns or raises.

        Raises:
            ValueError: no pending request has this id.
            TimeoutError: nothing arrived within *timeout* seconds.
            Exception: the request was failed (message is the JSON-encoded
                error body) or aborted.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")

        try:
            log.debug(f"[ReqMgr] 等待响应 | id={request_id[:8]}... | timeout={timeout}s")
            result = await asyncio.wait_for(ctx.response_future, timeout=timeout)
            log.info(f"[ReqMgr] 收到响应 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            return result
        except asyncio.TimeoutError:
            ctx.status = RequestStatus.ERROR
            log.error(f"[ReqMgr] 请求超时 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            raise TimeoutError(f"Request {request_id} timed out after {timeout}s") from None
        finally:
            await self._cleanup_request(request_id)

    async def wait_for_stream(self, request_id: str, timeout: float) -> AsyncGenerator[Dict, None]:
        """Yield the chunk bodies of a streaming request as they arrive.

        *timeout* applies to the wait for each individual chunk. The context is
        removed when the generator finishes or is closed.

        Raises:
            ValueError: no pending request has this id.
            TimeoutError: no chunk arrived within *timeout* seconds.
            Exception: the request was failed via ``fail_request`` (message is
                the JSON-encoded error body).
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")

        # END may already have been processed before the consumer started;
        # don't overwrite a terminal status with STREAMING.
        if ctx.status == RequestStatus.PENDING:
            ctx.status = RequestStatus.STREAMING
        log.debug(f"[ReqMgr] 开始流式接收 | id={request_id[:8]}...")

        try:
            while True:
                try:
                    chunk = await asyncio.wait_for(ctx.chunk_queue.get(), timeout=timeout)
                except asyncio.TimeoutError:
                    ctx.status = RequestStatus.ERROR
                    log.error(f"[ReqMgr] 流超时 | id={request_id[:8]}... | chunks={ctx.chunk_count}")
                    raise TimeoutError(f"Stream {request_id} timed out") from None

                if chunk is None:  # end-of-stream sentinel
                    log.info(f"[ReqMgr] 流结束 | id={request_id[:8]}... | chunks={ctx.chunk_count} | elapsed={ctx.elapsed_ms()}ms")
                    break
                if isinstance(chunk, Exception):
                    # Queued by fail_request(): surface the failure instead of
                    # ending the stream as if it had succeeded.
                    raise chunk

                ctx.chunk_count += 1
                yield chunk
        finally:
            await self._cleanup_request(request_id)

    def resolve_request(self, request_id: str, response: Dict):
        """Complete a request with its full response body (first result wins)."""
        ctx = self.pending_requests.get(request_id)
        if ctx and not ctx.response_future.done():
            ctx.status = RequestStatus.COMPLETED
            ctx.response_future.set_result(response)
            log.debug(f"[ReqMgr] 请求已解决 | id={request_id[:8]}...")

    def push_chunk(self, request_id: str, chunk: Dict):
        """Queue one chunk for a streaming request; ignored for unknown or non-stream ids."""
        ctx = self.pending_requests.get(request_id)
        if ctx and ctx.is_stream:
            ctx.chunk_queue.put_nowait(chunk)

    def end_stream(self, request_id: str, final_body: Optional[Dict] = None):
        """Finish a stream, optionally queueing a last non-empty *final_body* first."""
        ctx = self.pending_requests.get(request_id)
        if ctx:
            ctx.status = RequestStatus.COMPLETED
            if final_body:
                ctx.chunk_queue.put_nowait(final_body)
            ctx.chunk_queue.put_nowait(None)  # end-of-stream sentinel

    def fail_request(self, request_id: str, error: Dict):
        """Fail a request with *error* (a Gemini-style error body).

        Non-streaming waiters get an Exception whose message is the JSON-encoded
        *error*; streaming consumers get the same exception raised from
        ``wait_for_stream``.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return
        ctx.status = RequestStatus.ERROR
        log.error(f"[ReqMgr] 请求失败 | id={request_id[:8]}... | error={error}")
        failure = Exception(json.dumps(error))
        if ctx.is_stream:
            # Stream consumers never await response_future; setting an exception
            # there would only produce "exception was never retrieved" noise.
            ctx.chunk_queue.put_nowait(failure)
        elif not ctx.response_future.done():
            ctx.response_future.set_exception(failure)

    def abort_request(self, request_id: str):
        """Mark a request aborted and wake up whoever is waiting on it."""
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return
        ctx.status = RequestStatus.ABORTED
        log.warning(f"[ReqMgr] 请求中止 | id={request_id[:8]}...")
        if ctx.is_stream:
            ctx.chunk_queue.put_nowait(None)
        elif not ctx.response_future.done():
            ctx.response_future.set_exception(Exception("Request aborted"))

    async def fail_all_requests(self, error_message: str):
        """Fail every pending request with a 503 UNAVAILABLE error body."""
        async with self._lock:
            count = len(self.pending_requests)
            if count > 0:
                log.warning(f"[ReqMgr] 批量失败 | count={count} | reason={error_message}")
            for request_id in list(self.pending_requests.keys()):
                self.fail_request(request_id, {
                    "error": {
                        "code": 503,
                        "message": error_message,
                        "status": "UNAVAILABLE"
                    }
                })

    async def _cleanup_request(self, request_id: str):
        """Remove a request from the registry; a no-op if already removed."""
        async with self._lock:
            self.pending_requests.pop(request_id, None)

    @property
    def pending_count(self) -> int:
        """Number of requests not yet cleaned up."""
        return len(self.pending_requests)


# ============================================================
# WebSocket 管理器
# ============================================================

class WebSocketManager:
    """Holds the single active WSClient connection and its registration data.

    A new REGISTER replaces (and closes) any previous connection. While a
    client is registered, a background task sends PING messages every
    ``Config.HEARTBEAT_INTERVAL`` seconds.

    Connections are compared by identity: Starlette's ``WebSocket`` is a
    Mapping over its ASGI scope, so ``==`` would compare scope contents
    rather than the connection objects.
    """

    def __init__(self):
        self.active_connection: Optional[WebSocket] = None
        self.client_id: Optional[str] = None
        self.client_version: Optional[str] = None
        self.models: List[str] = []
        self.max_concurrent: int = 1
        self.connected_at: Optional[float] = None  # time.time() of the last REGISTER
        self._lock = asyncio.Lock()
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def register(self, websocket: WebSocket, payload: Dict):
        """Make *websocket* the active connection, send REGISTER_ACK, start heartbeats.

        Args:
            websocket: the connection that sent REGISTER.
            payload: REGISTER payload; ``clientId``, ``clientVersion``,
                ``models`` and ``maxConcurrent`` are read from it. ``None`` is
                treated as an empty payload.
        """
        payload = payload or {}
        async with self._lock:
            # Close the previous connection, if it is a different one.
            if self.active_connection is not None and self.active_connection is not websocket:
                log.warning("[WSMgr] 关闭旧连接...")
                try:
                    await self.active_connection.close()
                except Exception as e:
                    # Best effort: the old peer may already be gone.
                    log.debug(f"[WSMgr] closing old connection failed: {e}")

            self.active_connection = websocket
            self.client_id = payload.get("clientId", "unknown")
            self.client_version = payload.get("clientVersion", "unknown")
            self.models = payload.get("models", [])
            self.max_concurrent = payload.get("maxConcurrent", 1)
            self.connected_at = time.time()

            ack_message = {
                "id": str(uuid.uuid4()),
                "type": MessageType.REGISTER_ACK.value,
                "timestamp": int(time.time() * 1000),
                "payload": {
                    "success": True,
                    "serverId": "gemini-middleware-001",
                    "config": {
                        # Intervals are advertised to the client in milliseconds.
                        "heartbeatInterval": Config.HEARTBEAT_INTERVAL * 1000,
                        "requestTimeout": Config.REQUEST_TIMEOUT * 1000
                    }
                }
            }
            await websocket.send_json(ack_message)

            log.info("[WSMgr] ✅ WSClient 已注册")
            log.info(f"[WSMgr]    ├─ clientId: {self.client_id}")
            log.info(f"[WSMgr]    ├─ version: {self.client_version}")
            log.info(f"[WSMgr]    ├─ models: {self.models}")
            log.info(f"[WSMgr]    └─ maxConcurrent: {self.max_concurrent}")

            # (Re)start the heartbeat for the new connection.
            if self._heartbeat_task:
                self._heartbeat_task.cancel()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def unregister(self, websocket: WebSocket):
        """Forget *websocket* if it is the active connection; otherwise do nothing."""
        async with self._lock:
            if self.active_connection is websocket:
                uptime = int(time.time() - self.connected_at) if self.connected_at else 0
                log.warning(f"[WSMgr] ❌ WSClient 已断开 | uptime={uptime}s")

                # Reset every registration field so get_status() shows no stale data.
                self.active_connection = None
                self.client_id = None
                self.client_version = None
                self.models = []
                self.max_concurrent = 1
                self.connected_at = None

                if self._heartbeat_task:
                    self._heartbeat_task.cancel()
                    self._heartbeat_task = None

    def is_available(self) -> bool:
        """True while a WSClient is registered."""
        return self.active_connection is not None

    async def send_message(self, message: Dict):
        """Send *message* as JSON to the active WSClient; silently dropped if none."""
        if self.active_connection is not None:
            await self.active_connection.send_json(message)
            log.debug(f"[WSMgr] 发送消息 | type={message.get('type')} | id={message.get('id', '')[:8]}...")

    async def send_request(self, request_id: str, is_stream: bool, body: Dict):
        """Forward an HTTP request body to WSClient as a REQUEST message."""
        message = {
            "id": request_id,
            "type": MessageType.REQUEST.value,
            "timestamp": int(time.time() * 1000),
            "stream": is_stream,
            "body": body
        }
        await self.send_message(message)

        body_preview = self._truncate_body(body)
        log.info(f"[WSMgr] 📤 发送请求 | id={request_id[:8]}... | stream={is_stream}")
        log.debug(f"[WSMgr]    └─ body: {body_preview}")

    async def send_abort(self, request_id: str, reason: str = "client_disconnected"):
        """Tell WSClient to stop working on *request_id*."""
        message = {
            "id": request_id,
            "type": MessageType.ABORT.value,
            "timestamp": int(time.time() * 1000),
            "reason": reason
        }
        await self.send_message(message)
        log.warning(f"[WSMgr] 🚫 发送取消 | id={request_id[:8]}... | reason={reason}")

    async def _heartbeat_loop(self):
        """Send a PING every HEARTBEAT_INTERVAL seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(Config.HEARTBEAT_INTERVAL)
                if self.active_connection is not None:
                    ping_id = str(uuid.uuid4())
                    await self.send_message({
                        "id": ping_id,
                        "type": MessageType.PING.value,
                        "timestamp": int(time.time() * 1000)
                    })
                    log.debug(f"[WSMgr] 💓 发送心跳 | id={ping_id[:8]}...")
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.error(f"[WSMgr] 心跳错误: {e}")

    def _truncate_body(self, body: Dict) -> str:
        """JSON-encode *body* for logging, cut to LOG_BODY_MAX_LENGTH chars plus "..."."""
        s = json.dumps(body, ensure_ascii=False)
        if len(s) > Config.LOG_BODY_MAX_LENGTH:
            return s[:Config.LOG_BODY_MAX_LENGTH] + "..."
        return s

    def get_status(self) -> Dict:
        """Snapshot of the connection state for /health."""
        return {
            "connected": self.is_available(),
            "clientId": self.client_id,
            "clientVersion": self.client_version,
            "models": self.models,
            "maxConcurrent": self.max_concurrent,
            "uptime": int(time.time() - self.connected_at) if self.connected_at else 0
        }


# ============================================================
# 全局实例
# ============================================================

# Process-wide singletons shared by the HTTP routes and the /ws endpoint.
# WebSocketManager keeps a single active_connection, so at most one WSClient
# serves requests at a time.
request_manager = RequestManager()
ws_manager = WebSocketManager()


# ============================================================
# FastAPI 应用
# ============================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log a startup banner with the service URLs, run the app, then log shutdown."""
    separator = "=" * 60
    address = f"{Config.HOST}:{Config.PORT}"
    banner_lines = (
        separator,
        "  Gemini API 兼容中间件",
        separator,
        f"  HTTP 端点: http://{address}",
        f"  WebSocket: ws://{address}/ws",
        f"  API 文档:  http://{address}/docs",
        separator,
    )
    for line in banner_lines:
        log.info(line)

    yield

    log.info("[Server] 服务关闭")


# The FastAPI application. `lifespan` logs the startup banner and the shutdown message.
app = FastAPI(
    title="Gemini API Compatible Middleware",
    description="透明代理 - 将 Gemini API 请求通过 WebSocket 转发给 WSClient",
    version="1.0.0",
    lifespan=lifespan
)


# ============================================================
# API Key 校验 - 支持多种传递方式
# ============================================================

def extract_api_key(request: Request) -> Optional[str]:
    """Return the API key supplied with *request*, or None when there is none.

    Sources are tried in this order; the first hit wins:
      1. ``x-goog-api-key`` header (skipped when empty)
      2. ``Authorization: Bearer <key>`` header (scheme matched
         case-insensitively; the stripped remainder is returned, even if empty)
      3. ``?key=<key>`` query parameter (skipped when empty)
    """
    goog_key = request.headers.get("x-goog-api-key")
    if goog_key:
        log.debug("[Auth] 从 x-goog-api-key header 获取 key")
        return goog_key

    authorization = request.headers.get("authorization", "")
    if authorization.lower().startswith("bearer "):
        log.debug("[Auth] 从 Authorization header 获取 key")
        return authorization[len("bearer "):].strip()

    query_key = request.query_params.get("key")
    if not query_key:
        return None
    log.debug("[Auth] 从 query parameter 获取 key")
    return query_key


def _mask_api_key(api_key: Optional[str]) -> str:
    """Render *api_key* for logs without exposing it: a short prefix (long keys only) plus its length."""
    if not api_key:
        return "None"
    prefix = api_key[:3] if len(api_key) > 8 else ""
    return f"{prefix}***(len={len(api_key)})"


def verify_api_key(request: Request):
    """Reject the request with HTTP 401 unless it carries ``Config.API_KEY``.

    The key is read with ``extract_api_key`` (x-goog-api-key header,
    ``Authorization: Bearer`` header, or ``?key=`` query parameter).

    Raises:
        HTTPException: 401 with a Gemini-style ``UNAUTHENTICATED`` body when
            the key is missing or does not match.
    """
    api_key = extract_api_key(request)
    # Never log the key itself: a 10-character preview printed the entire
    # default key ("sk-123456" is only 9 characters long).
    masked = _mask_api_key(api_key)
    log.debug(f"[Auth] 提取到的 Key: {masked}")

    # Constant-time comparison so response timing does not reveal how many
    # leading characters matched. Compare bytes because compare_digest only
    # accepts ASCII-only str arguments.
    is_valid = api_key is not None and hmac.compare_digest(
        api_key.encode("utf-8"), Config.API_KEY.encode("utf-8")
    )
    if not is_valid:
        log.warning(f"[Auth] ⛔ 认证失败 | key={masked}")
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "code": 401,
                    "message": "Invalid API key. Provide via 'x-goog-api-key' header, 'Authorization: Bearer <key>' header, or '?key=<key>' query parameter.",
                    "status": "UNAUTHENTICATED"
                }
            }
        )
    log.debug("[Auth] ✓ API Key 验证通过")


# ============================================================
# HTTP 路由 - Gemini API 兼容
# ============================================================

@app.get("/")
async def root():
    """Describe the service and list its main endpoints (no authentication)."""
    endpoint_map = {
        "models": "/v1beta/models",
        "generateContent": "/v1beta/models/{model}:generateContent",
        "streamGenerateContent": "/v1beta/models/{model}:streamGenerateContent",
        "health": "/health",
        "websocket": "/ws"
    }
    service_info = {
        "service": "Gemini API Compatible Middleware",
        "version": "1.0.0",
        "status": "running",
        "wsClientConnected": ws_manager.is_available(),
        "endpoints": endpoint_map
    }
    return service_info


@app.get("/health")
async def health_check():
    """Report liveness plus WSClient and pending-request status (no authentication).

    ``status`` is "ok" while a WSClient is connected and "degraded" otherwise.
    """
    client_status = ws_manager.get_status()
    overall = "ok" if client_status["connected"] else "degraded"
    generated_at = datetime.now().isoformat()
    return {
        "status": overall,
        "timestamp": generated_at,
        "wsClient": client_status,
        "pendingRequests": request_manager.pending_count
    }


@app.get("/v1beta/models")
async def list_models(request: Request):
    """List the models advertised by the connected WSClient.

    Returns ``{"models": []}`` when no WSClient is connected.
    """
    verify_api_key(request)

    log.info("[API] GET /v1beta/models")

    if not ws_manager.is_available():
        log.warning("[API] WSClient 未连接,返回空模型列表")
        return {"models": []}

    entries = []
    for advertised in ws_manager.models:
        entries.append({
            "name": f"models/{advertised}",
            "displayName": advertised,
            "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
        })

    log.info(f"[API] 返回 {len(entries)} 个模型")
    return {"models": entries}


@app.get("/v1beta/models/{model}")
async def get_model(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Return metadata for one model advertised by the connected WSClient.

    Raises:
        HTTPException: 503 when no WSClient is connected, 404 when the model
            is not in the advertised list.
    """
    verify_api_key(request)

    log.info(f"[API] GET /v1beta/models/{model}")

    # Strip a possible "models/" prefix (replace() drops every occurrence).
    bare_name = model.replace("models/", "")

    if not ws_manager.is_available():
        unavailable = {
            "error": {"code": 503, "message": "Service unavailable", "status": "UNAVAILABLE"}
        }
        raise HTTPException(status_code=503, detail=unavailable)

    if bare_name not in ws_manager.models:
        not_found = {
            "error": {"code": 404, "message": f"Model '{model}' not found", "status": "NOT_FOUND"}
        }
        raise HTTPException(status_code=404, detail=not_found)

    return {
        "name": f"models/{bare_name}",
        "displayName": bare_name,
        "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
    }


@app.post("/v1beta/models/{model}:generateContent")
async def generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Non-streaming generateContent: relay the JSON body to WSClient and return its reply.

    The request body is forwarded unchanged and WSClient's response body is
    returned as-is.

    Raises:
        HTTPException: 400 if the body is not valid JSON; 503 if no WSClient
            is connected; 504 after ``Config.REQUEST_TIMEOUT`` seconds without
            a reply. A WSClient error body ``{"error": {"code": N, ...}}`` is
            returned with HTTP status N when N is in 400-599, otherwise 500.
    """
    verify_api_key(request)

    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:generateContent")

    # Check that a WSClient is connected
    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # A malformed body is the caller's error (400), not an internal one.
    try:
        body = await request.json()
    except ValueError as e:
        log.warning(f"[API] invalid JSON body | error={e}")
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Request body is not valid JSON", "status": "INVALID_ARGUMENT"}
        }) from e
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    ctx = await request_manager.create_request(is_stream=False, model=model_name)

    try:
        await ws_manager.send_request(ctx.id, is_stream=False, body=body)
        response_body = await request_manager.wait_for_response(ctx.id, Config.REQUEST_TIMEOUT)
    except TimeoutError:
        log.error(f"[API] ⏱️ 请求超时 | id={ctx.id[:8]}...")
        raise HTTPException(status_code=504, detail={
            "error": {"code": 504, "message": "Request timeout", "status": "DEADLINE_EXCEEDED"}
        })
    except Exception as e:
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        # Failures from fail_request() carry the JSON-encoded error body.
        try:
            error_detail = json.loads(str(e))
        except ValueError:
            error_detail = None
        if not isinstance(error_detail, dict):
            error_detail = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}
        # Keep the HTTP status consistent with the code in the error body
        # (e.g. 503 when WSClient disconnected), falling back to 500.
        error_info = error_detail.get("error")
        code = error_info.get("code") if isinstance(error_info, dict) else None
        status_code = code if isinstance(code, int) and 400 <= code <= 599 else 500
        raise HTTPException(status_code=status_code, detail=error_detail) from e
    finally:
        # wait_for_response() removes the context itself, but it is never
        # reached when send_request() raises; otherwise this is a no-op.
        await request_manager._cleanup_request(ctx.id)

    log.info(f"[API] ✅ 请求完成 | id={ctx.id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
    return JSONResponse(content=response_body)


@app.post("/v1beta/models/{model}:streamGenerateContent")
async def stream_generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Streaming streamGenerateContent: relay the JSON body to WSClient and stream its chunks as SSE.

    Raises:
        HTTPException: 400 if the body is not valid JSON; 503 if no WSClient
            is connected or the request could not be sent to it. Errors after
            streaming has started are reported inside the stream.
    """
    verify_api_key(request)

    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:streamGenerateContent")

    # Check that a WSClient is connected
    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # A malformed body is the caller's error (400), not an internal one.
    try:
        body = await request.json()
    except ValueError as e:
        log.warning(f"[API] invalid JSON body | error={e}")
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Request body is not valid JSON", "status": "INVALID_ARGUMENT"}
        }) from e
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    ctx = await request_manager.create_request(is_stream=True, model=model_name)

    try:
        await ws_manager.send_request(ctx.id, is_stream=True, body=body)
    except Exception as e:
        # The context is normally removed by wait_for_stream() inside the
        # response generator, which never runs if sending fails here.
        await request_manager._cleanup_request(ctx.id)
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "Failed to forward request to WSClient", "status": "UNAVAILABLE"}
        }) from e

    return StreamingResponse(
        stream_generator(ctx.id, request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style proxies not to buffer the event stream.
            "X-Accel-Buffering": "no"
        }
    )


async def stream_generator(request_id: str, http_request: Request) -> AsyncGenerator[str, None]:
    """Yield Server-Sent Events for a streaming request.

    Each chunk body received from WSClient is emitted verbatim as one
    ``data: <json>`` event. The HTTP status has already been sent once
    streaming starts, so failures are reported as a final
    ``data: {"error": ...}`` event instead of an exception: 504 on timeout;
    otherwise the WSClient's JSON error body when the exception carries one,
    else a generic 500. If the HTTP client is gone when a chunk arrives,
    WSClient is told to abort and the request is aborted locally.
    """
    chunks = request_manager.wait_for_stream(request_id, Config.REQUEST_TIMEOUT)
    aborted = False

    try:
        async for chunk_body in chunks:
            # Stop forwarding if the HTTP client has disconnected.
            if await http_request.is_disconnected():
                log.warning(f"[Stream] 客户端断开 | id={request_id[:8]}...")
                aborted = True
                await ws_manager.send_abort(request_id, "client_disconnected")
                request_manager.abort_request(request_id)
                break

            # Forward the chunk unchanged in SSE format.
            yield f"data: {json.dumps(chunk_body)}\n\n"

        if not aborted:
            log.info(f"[Stream] ✅ 流完成 | id={request_id[:8]}...")

    except TimeoutError:
        log.error(f"[Stream] ⏱️ 流超时 | id={request_id[:8]}...")
        error_body = {"error": {"code": 504, "message": "Stream timeout", "status": "DEADLINE_EXCEEDED"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    except Exception as e:
        log.error(f"[Stream] ❌ 流错误 | id={request_id[:8]}... | error={e}")
        # Failures relayed from WSClient carry a JSON-encoded error body.
        try:
            error_body = json.loads(str(e))
        except ValueError:
            error_body = None
        if not (isinstance(error_body, dict) and "error" in error_body):
            error_body = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    finally:
        # `break` above (or this generator being closed/cancelled) leaves the
        # inner generator suspended; close it now so its cleanup in
        # wait_for_stream runs immediately instead of at garbage collection.
        await chunks.aclose()


# ============================================================
# WebSocket 端点
# ============================================================

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WSClient WebSocket connection for its whole lifetime.

    The first message must be REGISTER and must arrive within
    ``Config.REGISTER_TIMEOUT`` seconds; otherwise the socket is closed with
    code 1008. Later JSON messages go to ``handle_ws_message``; a frame that
    is not valid JSON, or whose handling fails, is logged and skipped. When a
    connection that was registered ends, all pending requests fail with 503.

    NOTE(review): no authentication happens here - any peer that can reach
    /ws can register and will receive proxied request bodies. Consider
    requiring Config.API_KEY once WSClient can send it.
    """
    await websocket.accept()
    client_ip = websocket.client.host if websocket.client else "unknown"
    log.info(f"[WS] 🔌 新连接 | ip={client_ip}")
    log.info(f"[WS]    └─ 等待 REGISTER 消息 (timeout={Config.REGISTER_TIMEOUT}s)...")

    registered = False
    try:
        # Enforce the REGISTER timeout (previously it was only logged).
        try:
            first = await asyncio.wait_for(websocket.receive_json(), timeout=Config.REGISTER_TIMEOUT)
        except asyncio.TimeoutError:
            log.warning(f"[WS] REGISTER timeout, closing | ip={client_ip}")
            await websocket.close(code=1008)  # 1008 = policy violation
            return
        if not isinstance(first, dict) or first.get("type") != MessageType.REGISTER.value:
            log.warning(f"[WS] first message is not REGISTER, closing | ip={client_ip}")
            await websocket.close(code=1008)
            return
        await handle_ws_message(websocket, first)
        registered = ws_manager.active_connection is websocket

        while True:
            try:
                data = await websocket.receive_json()
            except ValueError as e:
                # Malformed JSON: drop this frame but keep the connection.
                log.warning(f"[WS] invalid JSON ignored | ip={client_ip} | error={e}")
                continue
            try:
                await handle_ws_message(websocket, data)
            except WebSocketDisconnect:
                raise
            except Exception as e:
                # One bad message must not tear down the connection (which
                # would also fail every pending request).
                log.error(f"[WS] message handling failed | ip={client_ip} | error={e}")
            registered = registered or ws_manager.active_connection is websocket

    except WebSocketDisconnect:
        log.warning(f"[WS] 连接断开 | ip={client_ip}")
    except Exception as e:
        log.error(f"[WS] 连接错误 | ip={client_ip} | error={e}")
    finally:
        await ws_manager.unregister(websocket)
        # A socket that never became the registered WSClient owns no pending
        # requests, so its closing must not fail the active client's requests.
        if registered:
            await request_manager.fail_all_requests("WSClient disconnected")


async def handle_ws_message(websocket: WebSocket, data: Dict):
    """Dispatch one decoded JSON message received on a WSClient connection.

    REGISTER (re)binds *websocket* as the active connection. RESPONSE, CHUNK,
    END and ERROR are routed to ``request_manager`` by the message ``id``;
    PONG is only logged. Every message except REGISTER is ignored unless it
    comes from the currently registered connection. Non-object messages and
    unknown types are logged and ignored.
    """
    if not isinstance(data, dict):
        log.warning(f"[WS] ❓ ignored non-object message | type={type(data).__name__}")
        return

    msg_type = data.get("type", "")
    raw_id = data.get("id")
    # Ids are expected to be strings; normalise anything else so the id[:8]
    # previews and the request lookups below cannot raise.
    msg_id = "" if raw_id is None else str(raw_id)
    short_id = msg_id[:8]

    log.debug(f"[WS] 📩 收到消息 | type={msg_type} | id={short_id if msg_id else 'N/A'}...")

    if msg_type == MessageType.REGISTER.value:
        # A null payload is treated like a missing one.
        await ws_manager.register(websocket, data.get("payload") or {})
        return

    # Only the registered connection may complete requests; a stale or
    # not-yet-registered socket must not resolve someone else's request.
    if websocket is not ws_manager.active_connection:
        log.warning(f"[WS] ignored message from unregistered connection | type={msg_type} | id={short_id}...")
        return

    if msg_type == MessageType.RESPONSE.value:
        body = data.get("body", {})
        log.info(f"[WS] 📥 收到响应 | id={short_id}...")
        log.debug(f"[WS]    └─ body: {ws_manager._truncate_body(body)}")
        request_manager.resolve_request(msg_id, body)

    elif msg_type == MessageType.CHUNK.value:
        body = data.get("body", {})
        index = data.get("index", 0)
        log.debug(f"[WS] 📦 收到数据块 | id={short_id}... | index={index}")
        request_manager.push_chunk(msg_id, body)

    elif msg_type == MessageType.END.value:
        body = data.get("body")
        log.info(f"[WS] 🏁 流结束 | id={short_id}...")
        request_manager.end_stream(msg_id, body)

    elif msg_type == MessageType.ERROR.value:
        body = data.get("body", {})
        log.error(f"[WS] ⚠️ 收到错误 | id={short_id}... | body={body}")
        request_manager.fail_request(msg_id, body)

    elif msg_type == MessageType.PONG.value:
        log.debug(f"[WS] 💓 收到心跳响应 | id={short_id}...")

    else:
        log.warning(f"[WS] ❓ 未知消息类型 | type={msg_type}")


# ============================================================
# 启动
# ============================================================

if __name__ == "__main__":
    # reload=True requires an import string ("module:attribute"). Derive the
    # module name from this file instead of hard-coding "main", so the server
    # still starts after the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host=Config.HOST,
        port=Config.PORT,
        reload=True,
        log_level="warning"  # keep uvicorn quiet; the "middleware" logger reports app activity
    )