Spaces:
Sleeping
Sleeping
| """ | |
| Gemini API 兼容中间件服务器 | |
| 透明代理 - 将 HTTP 请求通过 WebSocket 转发给 WSClient 处理 | |
| 版本: 1.0.0 | |
| 协议: Gemini Compatible WebSocket Proxy Protocol | |
| """ | |
import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Path, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
# ============================================================
# Logging configuration
# ============================================================
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the level name and the message in ANSI colour codes.

    Colouring is applied to a copy of the record, so any other handler that
    formats the same ``LogRecord`` (e.g. after propagation) still sees the
    plain level name and message, and formatting a record twice does not
    colour it twice.
    """

    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
    }
    RESET = '\033[0m'

    def format(self, record):
        # Unknown (custom) levels fall back to RESET, i.e. no colour.
        color = self.COLORS.get(record.levelname, self.RESET)
        # Shallow copy: mutating ``record`` in place would leak the escape
        # codes into every other handler that receives the same record.
        colored = logging.makeLogRecord(record.__dict__)
        colored.levelname = f"{color}{record.levelname}{self.RESET}"
        colored.msg = f"{color}{record.msg}{self.RESET}"
        return super().format(colored)
def setup_logging():
    """Return the ``"middleware"`` logger, configured with a coloured console handler.

    The logger level is set to DEBUG. Calling this more than once (for
    example when the module is imported both as ``__main__`` and under its
    module name) attaches the handler only the first time, so log lines are
    not duplicated.

    Returns:
        logging.Logger: the ``"middleware"`` logger.
    """
    logger = logging.getLogger("middleware")
    logger.setLevel(logging.DEBUG)
    # getLogger returns the same object every time, so an unconditional
    # addHandler would stack one more handler per call. A marker attribute is
    # used instead of an isinstance check because a second import of this
    # module defines a distinct ColoredFormatter class.
    if not any(getattr(h, "_middleware_handler", False) for h in logger.handlers):
        handler = logging.StreamHandler()
        handler.setFormatter(ColoredFormatter(
            fmt='%(asctime)s | %(levelname)-17s | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        handler._middleware_handler = True
        logger.addHandler(handler)
    return logger


log = setup_logging()
# ============================================================
# Configuration
# ============================================================
class Config:
    """Server settings.

    Each value keeps its previous default but can be overridden with an
    environment variable named ``MIDDLEWARE_<SETTING>`` (e.g.
    ``MIDDLEWARE_API_KEY``), so deployments do not need to edit secrets into
    the source file. Integer settings raise ``ValueError`` at import time if
    the variable is not a valid integer.
    """
    HOST = os.environ.get("MIDDLEWARE_HOST", "0.0.0.0")
    PORT = int(os.environ.get("MIDDLEWARE_PORT", "8000"))
    # Set MIDDLEWARE_API_KEY in production; the fallback is for local use only.
    API_KEY = os.environ.get("MIDDLEWARE_API_KEY", "sk-123456")
    REQUEST_TIMEOUT = int(os.environ.get("MIDDLEWARE_REQUEST_TIMEOUT", "120"))        # seconds
    HEARTBEAT_INTERVAL = int(os.environ.get("MIDDLEWARE_HEARTBEAT_INTERVAL", "30"))   # seconds
    REGISTER_TIMEOUT = int(os.environ.get("MIDDLEWARE_REGISTER_TIMEOUT", "10"))       # seconds
    LOG_BODY_MAX_LENGTH = int(os.environ.get("MIDDLEWARE_LOG_BODY_MAX_LENGTH", "500"))  # chars shown in logs
# ============================================================
# Protocol definition
# ============================================================
class MessageType(str, Enum):
    """Values of the ``type`` field in messages exchanged with the WSClient."""

    # -- handshake ------------------------------------------------------
    REGISTER = "register"          # WSClient -> server
    REGISTER_ACK = "register_ack"  # server -> WSClient

    # -- work sent to the WSClient --------------------------------------
    REQUEST = "request"            # server -> WSClient

    # -- results coming back from the WSClient --------------------------
    RESPONSE = "response"          # full non-streaming reply
    CHUNK = "chunk"                # one streaming piece
    END = "end"                    # end of a stream
    ERROR = "error"                # request failed

    # -- control ----------------------------------------------------------
    ABORT = "abort"                # server -> WSClient: stop a request
    PING = "ping"                  # server -> WSClient heartbeat
    PONG = "pong"                  # WSClient -> server heartbeat reply
# ============================================================
# Request context
# ============================================================
class RequestStatus(str, Enum):
    """Lifecycle states of a request forwarded to the WSClient."""
    PENDING = "pending"
    STREAMING = "streaming"
    COMPLETED = "completed"
    ERROR = "error"
    ABORTED = "aborted"


@dataclass
class RequestContext:
    """State of one request forwarded to the WSClient.

    Must be a dataclass: callers construct it with keyword arguments, and the
    ``field(default_factory=...)`` defaults only take effect under
    ``@dataclass``. Instantiate it from inside a running event loop, because
    the default ``response_future`` is created with
    ``asyncio.get_running_loop()``.
    """
    id: str
    created_at: float  # time.time() when the request was created
    is_stream: bool
    model: str
    status: RequestStatus = RequestStatus.PENDING
    # Non-streaming requests: resolved with the complete response body.
    response_future: asyncio.Future = field(
        default_factory=lambda: asyncio.get_running_loop().create_future()
    )
    # Streaming requests: carries chunk bodies; ``None`` marks end of stream.
    chunk_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    # Number of chunks handed to the consumer so far.
    chunk_count: int = 0

    def elapsed_ms(self) -> int:
        """Return the milliseconds elapsed since ``created_at``."""
        return int((time.time() - self.created_at) * 1000)
# ============================================================
# Request manager
# ============================================================
class RequestManager:
    """Registry of in-flight requests forwarded to the WSClient, keyed by id.

    HTTP handlers register a context with ``create_request``. They then wait
    on it with ``wait_for_response`` (non-streaming) or ``wait_for_stream``
    (streaming); both remove the context when they finish. The WebSocket side
    completes contexts through ``resolve_request``, ``push_chunk``,
    ``end_stream``, ``fail_request`` and ``abort_request``.
    """

    def __init__(self):
        # request id -> context for every request still being waited on
        self.pending_requests: Dict[str, RequestContext] = {}
        # guards insertions into / removals from ``pending_requests``
        self._lock = asyncio.Lock()

    async def create_request(self, is_stream: bool, model: str) -> RequestContext:
        """Register and return a new context whose id is a random UUID4 string."""
        request_id = str(uuid.uuid4())
        ctx = RequestContext(
            id=request_id,
            created_at=time.time(),
            is_stream=is_stream,
            model=model
        )
        async with self._lock:
            self.pending_requests[request_id] = ctx
        log.debug(f"[ReqMgr] 创建请求 | id={request_id[:8]}... | stream={is_stream} | model={model}")
        return ctx

    def get_request(self, request_id: str) -> Optional[RequestContext]:
        """Return the pending context for ``request_id``, or ``None``."""
        return self.pending_requests.get(request_id)

    async def wait_for_response(self, request_id: str, timeout: float) -> Dict:
        """Wait for the complete body of a non-streaming request.

        The context is removed from the registry however this call ends.

        Raises:
            ValueError: ``request_id`` is not pending.
            TimeoutError: nothing arrived within ``timeout`` seconds.
            Exception: the failure set by ``fail_request`` (whose message is
                the JSON-encoded error body) or by ``abort_request``.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")
        try:
            log.debug(f"[ReqMgr] 等待响应 | id={request_id[:8]}... | timeout={timeout}s")
            result = await asyncio.wait_for(ctx.response_future, timeout=timeout)
            log.info(f"[ReqMgr] 收到响应 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            return result
        except asyncio.TimeoutError:
            ctx.status = RequestStatus.ERROR
            log.error(f"[ReqMgr] 请求超时 | id={request_id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
            raise TimeoutError(f"Request {request_id} timed out after {timeout}s") from None
        finally:
            await self._cleanup_request(request_id)

    async def wait_for_stream(self, request_id: str, timeout: float) -> AsyncGenerator[Dict, None]:
        """Yield the chunk bodies of a streaming request until the end sentinel.

        ``timeout`` applies to the wait for each individual chunk. The context
        is removed from the registry when the generator finishes or is closed.

        Raises:
            ValueError: ``request_id`` is not pending.
            TimeoutError: no chunk arrived within ``timeout`` seconds.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            raise ValueError(f"Request {request_id} not found")
        ctx.status = RequestStatus.STREAMING
        log.debug(f"[ReqMgr] 开始流式接收 | id={request_id[:8]}...")
        try:
            while True:
                try:
                    chunk = await asyncio.wait_for(ctx.chunk_queue.get(), timeout=timeout)
                except asyncio.TimeoutError:
                    ctx.status = RequestStatus.ERROR
                    log.error(f"[ReqMgr] 流超时 | id={request_id[:8]}... | chunks={ctx.chunk_count}")
                    raise TimeoutError(f"Stream {request_id} timed out") from None
                if chunk is None:  # end-of-stream sentinel
                    log.info(f"[ReqMgr] 流结束 | id={request_id[:8]}... | chunks={ctx.chunk_count} | elapsed={ctx.elapsed_ms()}ms")
                    break
                ctx.chunk_count += 1
                yield chunk
        finally:
            await self._cleanup_request(request_id)

    def resolve_request(self, request_id: str, response: Dict):
        """Complete a pending request with ``response`` (ignored if unknown or already done)."""
        ctx = self.pending_requests.get(request_id)
        if ctx and not ctx.response_future.done():
            ctx.status = RequestStatus.COMPLETED
            ctx.response_future.set_result(response)
            log.debug(f"[ReqMgr] 请求已解决 | id={request_id[:8]}...")

    def push_chunk(self, request_id: str, chunk: Dict):
        """Queue one chunk for a streaming request; ignored for unknown/non-streaming ids."""
        ctx = self.pending_requests.get(request_id)
        if ctx and ctx.is_stream:
            ctx.chunk_queue.put_nowait(chunk)

    def end_stream(self, request_id: str, final_body: Optional[Dict] = None):
        """Mark a stream completed, queueing ``final_body`` (if truthy) and then the sentinel."""
        ctx = self.pending_requests.get(request_id)
        if ctx:
            ctx.status = RequestStatus.COMPLETED
            if final_body:
                ctx.chunk_queue.put_nowait(final_body)
            ctx.chunk_queue.put_nowait(None)  # end-of-stream sentinel

    def fail_request(self, request_id: str, error: Dict):
        """Mark a request as failed and wake whoever is waiting on it.

        A non-streaming waiter receives an ``Exception`` whose message is
        ``json.dumps(error)``. A streaming consumer receives ``error`` as a
        final chunk (when non-empty) followed by the end sentinel, so the
        HTTP client can see why the stream stopped.
        """
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return
        ctx.status = RequestStatus.ERROR
        log.error(f"[ReqMgr] 请求失败 | id={request_id[:8]}... | error={error}")
        if ctx.is_stream:
            if error:
                ctx.chunk_queue.put_nowait(error)
            ctx.chunk_queue.put_nowait(None)
        elif not ctx.response_future.done():
            # Only non-streaming requests ever await the future. Setting an
            # exception on a future nobody awaits makes asyncio log
            # "Future exception was never retrieved".
            ctx.response_future.set_exception(Exception(json.dumps(error)))

    def abort_request(self, request_id: str):
        """Mark a request as aborted and wake whoever is waiting on it."""
        ctx = self.pending_requests.get(request_id)
        if not ctx:
            return
        ctx.status = RequestStatus.ABORTED
        log.warning(f"[ReqMgr] 请求中止 | id={request_id[:8]}...")
        if ctx.is_stream:
            ctx.chunk_queue.put_nowait(None)
        elif not ctx.response_future.done():
            ctx.response_future.set_exception(Exception("Request aborted"))

    async def fail_all_requests(self, error_message: str):
        """Fail every pending request with a 503 ``UNAVAILABLE`` error body."""
        async with self._lock:
            count = len(self.pending_requests)
            if count > 0:
                log.warning(f"[ReqMgr] 批量失败 | count={count} | reason={error_message}")
            for request_id in list(self.pending_requests.keys()):
                self.fail_request(request_id, {
                    "error": {
                        "code": 503,
                        "message": error_message,
                        "status": "UNAVAILABLE"
                    }
                })

    async def _cleanup_request(self, request_id: str):
        """Remove a finished request from the registry (no-op if already gone)."""
        async with self._lock:
            self.pending_requests.pop(request_id, None)

    def pending_count(self) -> int:
        """Return the number of requests still registered."""
        return len(self.pending_requests)
# ============================================================
# WebSocket manager
# ============================================================
class WebSocketManager:
    """Holds the single active WSClient connection and its registration data.

    Registering a new connection closes the previous one, so at most one
    WSClient is served at a time. While a client is registered, a background
    task sends it a ``ping`` every ``Config.HEARTBEAT_INTERVAL`` seconds.
    """

    def __init__(self):
        self.active_connection: Optional[WebSocket] = None
        self.client_id: Optional[str] = None
        self.client_version: Optional[str] = None
        self.models: List[str] = []
        self.max_concurrent: int = 1
        self.connected_at: Optional[float] = None
        self._lock = asyncio.Lock()
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def register(self, websocket: WebSocket, payload: Dict):
        """Make ``websocket`` the active WSClient, send ``register_ack`` and (re)start heartbeats.

        Args:
            websocket: the connection that sent the ``register`` message.
            payload: the message ``payload``. ``clientId``, ``clientVersion``,
                ``models`` and ``maxConcurrent`` are read, with defaults for
                missing keys.
        """
        async with self._lock:
            previous = self.active_connection
            if previous is not None and previous is not websocket:
                log.warning("[WSMgr] 关闭旧连接...")
                try:
                    await previous.close()
                except Exception as e:
                    # Best effort: the old socket is frequently already dead.
                    log.debug(f"[WSMgr] closing previous connection failed: {e}")

            self.active_connection = websocket
            self.client_id = payload.get("clientId", "unknown")
            self.client_version = payload.get("clientVersion", "unknown")
            self.models = payload.get("models", [])
            self.max_concurrent = payload.get("maxConcurrent", 1)
            self.connected_at = time.time()

            # Acknowledge the registration, advertising timings in milliseconds.
            ack_message = {
                "id": str(uuid.uuid4()),
                "type": MessageType.REGISTER_ACK.value,
                "timestamp": int(time.time() * 1000),
                "payload": {
                    "success": True,
                    "serverId": "gemini-middleware-001",
                    "config": {
                        "heartbeatInterval": Config.HEARTBEAT_INTERVAL * 1000,
                        "requestTimeout": Config.REQUEST_TIMEOUT * 1000
                    }
                }
            }
            await websocket.send_json(ack_message)

            log.info(f"[WSMgr] ✅ WSClient 已注册")
            log.info(f"[WSMgr] ├─ clientId: {self.client_id}")
            log.info(f"[WSMgr] ├─ version: {self.client_version}")
            log.info(f"[WSMgr] ├─ models: {self.models}")
            log.info(f"[WSMgr] └─ maxConcurrent: {self.max_concurrent}")

            # Restart the heartbeat for the new connection.
            if self._heartbeat_task:
                self._heartbeat_task.cancel()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def unregister(self, websocket: WebSocket):
        """Forget ``websocket`` if it is the active WSClient; otherwise do nothing.

        Resets all registration fields to their initial values and stops the
        heartbeat task.
        """
        async with self._lock:
            if self.active_connection is not websocket:
                return
            uptime = int(time.time() - self.connected_at) if self.connected_at else 0
            log.warning(f"[WSMgr] ❌ WSClient 已断开 | uptime={uptime}s")
            self.active_connection = None
            self.client_id = None
            # Also reset these so get_status() does not report stale values.
            self.client_version = None
            self.max_concurrent = 1
            self.models = []
            self.connected_at = None
            if self._heartbeat_task:
                self._heartbeat_task.cancel()
                self._heartbeat_task = None

    def is_available(self) -> bool:
        """Return True when a WSClient is currently registered."""
        return self.active_connection is not None

    async def send_message(self, message: Dict):
        """Send ``message`` as JSON to the active WSClient.

        If no WSClient is connected the message is dropped, with a warning
        rather than silently.
        """
        connection = self.active_connection
        if connection is None:
            log.warning(f"[WSMgr] no active WSClient, message dropped | type={message.get('type')}")
            return
        await connection.send_json(message)
        log.debug(f"[WSMgr] 发送消息 | type={message.get('type')} | id={message.get('id', '')[:8]}...")

    async def send_request(self, request_id: str, is_stream: bool, body: Dict):
        """Forward an HTTP request body to the WSClient as a ``request`` message."""
        message = {
            "id": request_id,
            "type": MessageType.REQUEST.value,
            "timestamp": int(time.time() * 1000),
            "stream": is_stream,
            "body": body
        }
        await self.send_message(message)
        body_preview = self._truncate_body(body)
        log.info(f"[WSMgr] 📤 发送请求 | id={request_id[:8]}... | stream={is_stream}")
        log.debug(f"[WSMgr] └─ body: {body_preview}")

    async def send_abort(self, request_id: str, reason: str = "client_disconnected"):
        """Ask the WSClient to stop working on ``request_id``."""
        message = {
            "id": request_id,
            "type": MessageType.ABORT.value,
            "timestamp": int(time.time() * 1000),
            "reason": reason
        }
        await self.send_message(message)
        log.warning(f"[WSMgr] 🚫 发送取消 | id={request_id[:8]}... | reason={reason}")

    async def _heartbeat_loop(self):
        """Send a ``ping`` every HEARTBEAT_INTERVAL seconds until cancelled; errors are logged."""
        while True:
            try:
                await asyncio.sleep(Config.HEARTBEAT_INTERVAL)
                if self.active_connection:
                    ping_id = str(uuid.uuid4())
                    await self.send_message({
                        "id": ping_id,
                        "type": MessageType.PING.value,
                        "timestamp": int(time.time() * 1000)
                    })
                    log.debug(f"[WSMgr] 💓 发送心跳 | id={ping_id[:8]}...")
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.error(f"[WSMgr] 心跳错误: {e}")

    def _truncate_body(self, body: Dict) -> str:
        """Serialise ``body`` to JSON, cut to LOG_BODY_MAX_LENGTH chars for logging."""
        s = json.dumps(body, ensure_ascii=False)
        if len(s) > Config.LOG_BODY_MAX_LENGTH:
            return s[:Config.LOG_BODY_MAX_LENGTH] + "..."
        return s

    def get_status(self) -> Dict:
        """Return a JSON-friendly snapshot of the connection state."""
        return {
            "connected": self.is_available(),
            "clientId": self.client_id,
            "clientVersion": self.client_version,
            "models": self.models,
            "maxConcurrent": self.max_concurrent,
            "uptime": int(time.time() - self.connected_at) if self.connected_at else 0
        }
# ============================================================
# Module-level singletons shared by the HTTP routes and the WebSocket endpoint
# ============================================================
# The single tracked WSClient connection.
ws_manager = WebSocketManager()
# All requests currently waiting for a WSClient reply.
request_manager = RequestManager()
# ============================================================
# FastAPI application
# ============================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: log a banner with the HTTP, WebSocket and docs URLs, then a shutdown line.

    Decorated with ``asynccontextmanager`` because FastAPI/Starlette expect
    ``lifespan`` to return an async context manager. A bare async-generator
    function is only accepted through a deprecated compatibility path.
    """
    log.info("=" * 60)
    log.info("  Gemini API 兼容中间件")
    log.info("=" * 60)
    log.info(f"  HTTP 端点: http://{Config.HOST}:{Config.PORT}")
    log.info(f"  WebSocket: ws://{Config.HOST}:{Config.PORT}/ws")
    log.info(f"  API 文档: http://{Config.HOST}:{Config.PORT}/docs")
    log.info("=" * 60)
    try:
        yield
    finally:
        # Log shutdown even if the application exits with an error.
        log.info("[Server] 服务关闭")
# ASGI application object; the route handlers below are attached to it.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Gemini API Compatible Middleware",
    description="透明代理 - 将 Gemini API 请求通过 WebSocket 转发给 WSClient",
)
# ============================================================
# API key verification - several ways of passing the key are accepted
# ============================================================
def extract_api_key(request: Request) -> Optional[str]:
    """Return the API key supplied with ``request``, or ``None`` if there is none.

    Sources are tried in this order, first match wins:
      1. ``x-goog-api-key`` header
      2. ``Authorization: Bearer <key>`` header (scheme matched case-insensitively)
      3. ``?key=<key>`` query parameter
    """
    headers = request.headers

    goog_key = headers.get("x-goog-api-key")
    if goog_key:
        log.debug(f"[Auth] 从 x-goog-api-key header 获取 key")
        return goog_key

    authorization = headers.get("authorization", "")
    if authorization.lower().startswith("bearer "):
        bearer_token = authorization[7:].strip()
        log.debug(f"[Auth] 从 Authorization header 获取 key")
        return bearer_token

    query_key = request.query_params.get("key")
    if query_key:
        log.debug(f"[Auth] 从 query parameter 获取 key")
        return query_key

    return None
def verify_api_key(request: Request):
    """Raise HTTP 401 unless ``request`` carries the configured API key.

    The key is obtained with ``extract_api_key`` and compared with
    ``Config.API_KEY`` in constant time. Only a masked form of the submitted
    key (its first 4 characters and its length) is logged. The old code
    logged the first 10 characters, which for a short key is the whole secret.

    Raises:
        HTTPException: 401 with a Gemini-style ``UNAUTHENTICATED`` error body.
    """
    import hmac

    api_key = extract_api_key(request)
    masked = f"{api_key[:4]}...(len={len(api_key)})" if api_key else 'None'
    log.debug(f"[Auth] 提取到的 Key: {masked}")

    # compare_digest does not reveal, through timing, how many leading bytes matched.
    is_valid = api_key is not None and hmac.compare_digest(
        api_key.encode("utf-8"), Config.API_KEY.encode("utf-8")
    )
    if not is_valid:
        log.warning(f"[Auth] ⛔ 认证失败 | key={masked}")
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "code": 401,
                    "message": "Invalid API key. Provide via 'x-goog-api-key' header, 'Authorization: Bearer <key>' header, or '?key=<key>' query parameter.",
                    "status": "UNAUTHENTICATED"
                }
            }
        )
    log.debug("[Auth] ✓ API Key 验证通过")
# ============================================================
# HTTP routes - Gemini API compatible
# ============================================================
@app.get("/")
async def root():
    """Service information: version, WSClient connection state and endpoint map.

    Registered with ``@app.get("/")``; without the decorator the handler was
    never exposed.
    """
    return {
        "service": "Gemini API Compatible Middleware",
        "version": "1.0.0",
        "status": "running",
        "wsClientConnected": ws_manager.is_available(),
        "endpoints": {
            "models": "/v1beta/models",
            "generateContent": "/v1beta/models/{model}:generateContent",
            "streamGenerateContent": "/v1beta/models/{model}:streamGenerateContent",
            "health": "/health",
            "websocket": "/ws"
        }
    }
@app.get("/health")
async def health_check():
    """Health check.

    ``status`` is "ok" when a WSClient is connected and "degraded" otherwise.
    The response also includes the WSClient status snapshot and the number
    of in-flight requests.
    """
    status = ws_manager.get_status()
    return {
        "status": "ok" if status["connected"] else "degraded",
        "timestamp": datetime.now().isoformat(),
        "wsClient": status,
        # pending_count is a method; returning the bound method itself is not
        # JSON-serialisable and made this endpoint fail.
        "pendingRequests": request_manager.pending_count()
    }
@app.get("/v1beta/models")
async def list_models(request: Request):
    """List the models announced by the connected WSClient.

    Requires a valid API key. Returns ``{"models": []}`` when no WSClient is
    connected.
    """
    verify_api_key(request)
    log.info("[API] GET /v1beta/models")

    if not ws_manager.is_available():
        log.warning("[API] WSClient 未连接,返回空模型列表")
        return {"models": []}

    models = [
        {
            "name": f"models/{model}",
            "displayName": model,
            "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
        }
        for model in ws_manager.models
    ]
    log.info(f"[API] 返回 {len(models)} 个模型")
    return {"models": models}
@app.get("/v1beta/models/{model}")
async def get_model(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Return metadata for one model announced by the WSClient.

    Requires a valid API key.

    Raises:
        HTTPException: 503 if no WSClient is connected; 404 if the WSClient
            did not announce ``model``.
    """
    verify_api_key(request)
    log.info(f"[API] GET /v1beta/models/{model}")

    # Strip a possible "models/" prefix from the name.
    model_name = model.replace("models/", "")

    if not ws_manager.is_available():
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "Service unavailable", "status": "UNAVAILABLE"}
        })

    if model_name not in ws_manager.models:
        raise HTTPException(status_code=404, detail={
            "error": {"code": 404, "message": f"Model '{model}' not found", "status": "NOT_FOUND"}
        })

    return {
        "name": f"models/{model_name}",
        "displayName": model_name,
        "supportedGenerationMethods": ["generateContent", "streamGenerateContent"]
    }
@app.post("/v1beta/models/{model}:generateContent")
async def generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Non-streaming generateContent: relay the body to the WSClient and return its reply.

    The JSON request body is forwarded unchanged. Error responses:
      * 401: bad or missing API key (raised by ``verify_api_key``)
      * 400: the request body is not valid JSON
      * 503: no WSClient connected, or the request could not be sent to it
      * 504: no reply within ``Config.REQUEST_TIMEOUT`` seconds; an ``abort``
        is also sent to the WSClient for the request
      * otherwise: the WSClient's error body, with its ``error.code`` used as
        the HTTP status when that is a 4xx/5xx integer (500 if not)
    """
    verify_api_key(request)
    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:generateContent")

    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # Read the body (passed through untouched); malformed JSON is a client error.
    try:
        body = await request.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        log.warning("[API] invalid JSON request body")
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Invalid JSON payload", "status": "INVALID_ARGUMENT"}
        })
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    ctx = await request_manager.create_request(is_stream=False, model=model_name)

    try:
        await ws_manager.send_request(ctx.id, is_stream=False, body=body)
    except Exception as e:
        # wait_for_response (which normally removes the context) will not run.
        await request_manager._cleanup_request(ctx.id)
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": f"Failed to send request to WSClient: {e}", "status": "UNAVAILABLE"}
        })

    try:
        response_body = await request_manager.wait_for_response(ctx.id, Config.REQUEST_TIMEOUT)
    except TimeoutError:
        log.error(f"[API] ⏱️ 请求超时 | id={ctx.id[:8]}...")
        # Best effort: tell the WSClient that nobody is waiting for this request any more.
        try:
            await ws_manager.send_abort(ctx.id, "timeout")
        except Exception as abort_error:
            log.warning(f"[API] failed to send abort | id={ctx.id[:8]}... | error={abort_error}")
        raise HTTPException(status_code=504, detail={
            "error": {"code": 504, "message": "Request timeout", "status": "DEADLINE_EXCEEDED"}
        })
    except Exception as e:
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        # fail_request encodes the WSClient's error body as the exception message.
        try:
            error_detail = json.loads(str(e))
        except ValueError:
            error_detail = None
        if not isinstance(error_detail, dict):
            error_detail = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}
        error_info = error_detail.get("error")
        code = error_info.get("code") if isinstance(error_info, dict) else None
        status_code = code if isinstance(code, int) and 400 <= code <= 599 else 500
        raise HTTPException(status_code=status_code, detail=error_detail)

    log.info(f"[API] ✅ 请求完成 | id={ctx.id[:8]}... | elapsed={ctx.elapsed_ms()}ms")
    return JSONResponse(content=response_body)
@app.post("/v1beta/models/{model}:streamGenerateContent")
async def stream_generate_content(
    request: Request,
    model: str = Path(..., description="模型名称")
):
    """Streaming streamGenerateContent: relay the body to the WSClient and stream its chunks as SSE.

    Error responses before streaming starts:
      * 401: bad or missing API key (raised by ``verify_api_key``)
      * 400: the request body is not valid JSON
      * 503: no WSClient connected, or the request could not be sent to it
    Errors that happen later are delivered inside the SSE stream by
    ``stream_generator``.
    """
    verify_api_key(request)
    model_name = model.replace("models/", "")
    log.info(f"[API] POST /v1beta/models/{model_name}:streamGenerateContent")

    if not ws_manager.is_available():
        log.error("[API] WSClient 未连接")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": "WSClient not connected", "status": "UNAVAILABLE"}
        })

    # Read the body (passed through untouched); malformed JSON is a client error.
    try:
        body = await request.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        log.warning("[API] invalid JSON request body")
        raise HTTPException(status_code=400, detail={
            "error": {"code": 400, "message": "Invalid JSON payload", "status": "INVALID_ARGUMENT"}
        })
    log.debug(f"[API] 请求体: {ws_manager._truncate_body(body)}")

    ctx = await request_manager.create_request(is_stream=True, model=model_name)

    try:
        await ws_manager.send_request(ctx.id, is_stream=True, body=body)
    except Exception as e:
        # No stream consumer will ever run, so remove the context ourselves.
        await request_manager._cleanup_request(ctx.id)
        log.error(f"[API] ❌ 请求失败 | id={ctx.id[:8]}... | error={e}")
        raise HTTPException(status_code=503, detail={
            "error": {"code": 503, "message": f"Failed to send request to WSClient: {e}", "status": "UNAVAILABLE"}
        })

    return StreamingResponse(
        stream_generator(ctx.id, request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # disable proxy (nginx) buffering of the SSE stream
        }
    )
async def stream_generator(request_id: str, http_request: Request) -> AsyncGenerator[str, None]:
    """Yield the WSClient's chunks for ``request_id`` as SSE ``data:`` events.

    Chunk bodies are passed through unchanged as JSON. If the HTTP client has
    disconnected, the request is aborted locally, an ``abort`` is sent to the
    WSClient (best effort) and the stream stops. A timeout or other failure
    is reported to the client as a final event carrying a Gemini-style error
    body.
    """
    chunks = request_manager.wait_for_stream(request_id, Config.REQUEST_TIMEOUT)
    aborted = False
    try:
        async for chunk_body in chunks:
            # Detect a client that went away while we were waiting.
            if await http_request.is_disconnected():
                log.warning(f"[Stream] 客户端断开 | id={request_id[:8]}...")
                aborted = True
                request_manager.abort_request(request_id)
                try:
                    await ws_manager.send_abort(request_id, "client_disconnected")
                except Exception as e:
                    log.warning(f"[Stream] failed to send abort | id={request_id[:8]}... | error={e}")
                break
            yield f"data: {json.dumps(chunk_body)}\n\n"
        if not aborted:
            log.info(f"[Stream] ✅ 流完成 | id={request_id[:8]}...")
    except TimeoutError:
        log.error(f"[Stream] ⏱️ 流超时 | id={request_id[:8]}...")
        # Best effort: let the WSClient stop producing chunks nobody will read.
        try:
            await ws_manager.send_abort(request_id, "timeout")
        except Exception as e:
            log.warning(f"[Stream] failed to send abort | id={request_id[:8]}... | error={e}")
        error_body = {"error": {"code": 504, "message": "Stream timeout", "status": "DEADLINE_EXCEEDED"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    except Exception as e:
        log.error(f"[Stream] ❌ 流错误 | id={request_id[:8]}... | error={e}")
        error_body = {"error": {"code": 500, "message": str(e), "status": "INTERNAL"}}
        yield f"data: {json.dumps(error_body)}\n\n"
    finally:
        # Breaking out of ``async for`` does not close the inner generator.
        # Close it explicitly so its cleanup (dropping the request from the
        # registry) runs now, not whenever it is garbage-collected.
        await chunks.aclose()
# ============================================================
# WebSocket endpoint
# ============================================================
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WSClient connection until it closes.

    Protocol enforced here:
      * the first frame must be a JSON ``register`` message received within
        ``Config.REGISTER_TIMEOUT`` seconds, otherwise the socket is closed
        with code 1008 (policy violation);
      * after registration, frames that are not valid JSON are logged and
        skipped instead of terminating the connection.

    Pending requests are failed (503) only when the connection that goes
    away is the active, registered WSClient. A stray or unregistered socket
    disconnecting therefore cannot kill requests served by the real client.

    NOTE(review): /ws performs no authentication. Any peer that can reach it
    can register and then receive every proxied request. Consider checking
    ``Config.API_KEY`` during registration.
    """
    await websocket.accept()
    client_ip = websocket.client.host if websocket.client else "unknown"
    log.info(f"[WS] 🔌 新连接 | ip={client_ip}")
    log.info(f"[WS] └─ 等待 REGISTER 消息 (timeout={Config.REGISTER_TIMEOUT}s)...")

    try:
        # --- handshake: a timely ``register`` message must come first ---
        try:
            first = await asyncio.wait_for(websocket.receive_json(), timeout=Config.REGISTER_TIMEOUT)
        except asyncio.TimeoutError:
            log.warning(f"[WS] no REGISTER within {Config.REGISTER_TIMEOUT}s, closing | ip={client_ip}")
            await websocket.close(code=1008)
            return
        if not isinstance(first, dict) or first.get("type") != MessageType.REGISTER.value:
            log.warning(f"[WS] first message is not REGISTER, closing | ip={client_ip}")
            await websocket.close(code=1008)
            return
        await handle_ws_message(websocket, first)

        # --- message loop ---
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except ValueError:
                log.warning(f"[WS] ignoring non-JSON message | ip={client_ip}")
                continue
            await handle_ws_message(websocket, data)
    except WebSocketDisconnect:
        log.warning(f"[WS] 连接断开 | ip={client_ip}")
    except Exception as e:
        log.error(f"[WS] 连接错误 | ip={client_ip} | error={e}")
    finally:
        was_active = ws_manager.active_connection is websocket
        await ws_manager.unregister(websocket)
        if was_active:
            await request_manager.fail_all_requests("WSClient disconnected")
async def handle_ws_message(websocket: WebSocket, data: Dict):
    """Dispatch one decoded WSClient message according to its ``type`` field.

    * ``register``: goes to ``ws_manager.register``.
    * ``response`` / ``chunk`` / ``end`` / ``error``: complete the pending
      request named by ``id`` through ``request_manager``.
    * ``pong``: only logged.
    * Anything else: logged as unknown.

    Messages that are not JSON objects are ignored with a warning. A missing,
    null or non-string ``id`` is coerced to ``str``, so formatting it for the
    log cannot raise and tear down the connection loop.
    """
    if not isinstance(data, dict):
        log.warning(f"[WS] ❓ ignoring non-object message | type={type(data).__name__}")
        return

    msg_type = data.get("type", "")
    raw_id = data.get("id")
    msg_id = "" if raw_id is None else str(raw_id)

    log.debug(f"[WS] 📩 收到消息 | type={msg_type} | id={msg_id[:8] if msg_id else 'N/A'}...")

    # MessageType values are the raw wire strings, so one comparison per type suffices.
    if msg_type == MessageType.REGISTER.value:
        payload = data.get("payload", {})
        await ws_manager.register(websocket, payload)

    elif msg_type == MessageType.RESPONSE.value:
        body = data.get("body", {})
        log.info(f"[WS] 📥 收到响应 | id={msg_id[:8]}...")
        log.debug(f"[WS] └─ body: {ws_manager._truncate_body(body)}")
        request_manager.resolve_request(msg_id, body)

    elif msg_type == MessageType.CHUNK.value:
        body = data.get("body", {})
        index = data.get("index", 0)
        log.debug(f"[WS] 📦 收到数据块 | id={msg_id[:8]}... | index={index}")
        request_manager.push_chunk(msg_id, body)

    elif msg_type == MessageType.END.value:
        body = data.get("body")
        log.info(f"[WS] 🏁 流结束 | id={msg_id[:8]}...")
        request_manager.end_stream(msg_id, body)

    elif msg_type == MessageType.ERROR.value:
        body = data.get("body", {})
        log.error(f"[WS] ⚠️ 收到错误 | id={msg_id[:8]}... | body={body}")
        request_manager.fail_request(msg_id, body)

    elif msg_type == MessageType.PONG.value:
        log.debug(f"[WS] 💓 收到心跳响应 | id={msg_id[:8]}...")

    else:
        log.warning(f"[WS] ❓ 未知消息类型 | type={msg_type}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # With reload=True uvicorn needs an import string ("module:app"). Derive
    # the module name from this file instead of hard-coding "main", so the
    # script keeps working if the file is renamed (e.g. app.py).
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host=Config.HOST,
        port=Config.PORT,
        reload=True,
        log_level="warning"  # application logging goes through our own "middleware" logger
    )