# Spaces: Sleeping  (page-capture residue; not part of the program)
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from datetime import datetime, timezone
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, List, Optional, Union

import requests
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
# --------------------------
# 1. Logging configuration (kept at DEBUG level to ease debugging)
# --------------------------
def setup_logging():
    """Configure root logging and return the ``api_proxy`` logger.

    Two handlers are built -- a rotating file handler writing to
    ``logs/app.log`` and a stdout stream handler -- sharing one format that
    includes a ``request_id`` field. Both are handed to
    ``logging.basicConfig`` at DEBUG level.

    Returns:
        The ``api_proxy`` logger, with its level set to DEBUG.
    """

    class DefaultRequestIDFilter(logging.Filter):
        """Guarantee every record has ``request_id`` so the format never fails."""

        def filter(self, record):
            # Records emitted without the adapter get a placeholder id.
            record.request_id = getattr(record, 'request_id', 'unknown')
            return True

    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - request_id=%(request_id)s - %(message)s'
    )

    # Rotating file log: 10 MB per file, 10 backups, UTF-8 so that non-ASCII
    # messages are written correctly.
    file_handler = RotatingFileHandler(
        os.path.join(log_dir, "app.log"),
        maxBytes=10 * 1024 * 1024,
        backupCount=10,
        encoding='utf-8',
    )
    # Console log.
    console_handler = logging.StreamHandler(sys.stdout)

    handlers = [file_handler, console_handler]
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.addFilter(DefaultRequestIDFilter())

    # Global (root) logging configuration.
    logging.basicConfig(level=logging.DEBUG, handlers=handlers)

    # Application logger.
    api_logger = logging.getLogger("api_proxy")
    api_logger.setLevel(logging.DEBUG)
    return api_logger
logger = setup_logging()

# --------------------------
# 2. Global configuration (environment variables win; defaults are fallbacks)
# --------------------------
EMBEDDING_API_BASE = os.getenv("EMBEDDING_API_BASE", "https://fiewolf1000-gpt-text-api.hf.space/v1")
CHAT_API_BASE = os.getenv("CHAT_API_BASE", "https://free.v36.cm/v1")

# NOTE(review): the fallback values below look like real credentials committed
# to source. They are kept unchanged so existing deployments keep working, but
# they should be rotated and supplied only through the environment.
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", "sk-OR0eRlmirRsSdCrA9bAbEa805d5f42448b7d0d184b268791")
CHAT_API_KEY = os.getenv("CHAT_API_KEY", "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430")

# Comma-separated list of client keys. Entries are stripped and empty entries
# dropped, so "a, b" matches "b" and an empty variable yields an empty set
# rather than a set containing "" (which would authorise an empty key).
ALLOWED_CLIENT_API_KEYS = {
    key.strip()
    for key in os.getenv(
        "ALLOWED_CLIENT_API_KEYS",
        "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430,sk-client-456",
    ).split(',')
    if key.strip()
}

# Supported models (requests naming any other model are rejected).
SUPPORTED_MODELS = {
    "embedding": ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"],
    "chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4o-mini"]
}

# FastAPI application instance.
app = FastAPI(title="API Proxy Service")
# --------------------------
# 3. Middleware (request-ID generation and log binding)
# --------------------------
class RequestIDLogAdapter(logging.LoggerAdapter):
    """Logger adapter that merges its own ``extra`` into every log call."""

    def process(self, msg, kwargs):
        """Return ``(message, kwargs)`` with the adapter's extra merged in.

        The adapter's ``extra`` provides the base values; keys in a
        caller-supplied ``extra`` override them. A caller passing
        ``extra=None``, or an adapter built without extra, is treated as an
        empty mapping instead of raising ``TypeError``.
        """
        merged_extra = {**(self.extra or {}), **(kwargs.get('extra') or {})}
        return f"{msg}", {**kwargs, 'extra': merged_extra}
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a request ID and a request-scoped logger to every request.

    The ID is taken from the ``X-Request-ID`` header when present and
    non-empty, otherwise a UUID4 is generated. It is stored on
    ``request.state.request_id``, bound into ``request.state.logger`` and
    echoed back in the ``X-Request-ID`` response header.
    """

    # Header names whose values must never be written to the log.
    _SENSITIVE_HEADERS = frozenset({'authorization', 'proxy-authorization', 'cookie', 'x-api-key'})

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Log request entry and exit, timing the downstream call.

        Exceptions from ``call_next`` are logged and re-raised unchanged.
        """
        # Use the caller's ID if given; an empty header counts as missing.
        request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
        request.state.request_id = request_id

        # Bind the request ID to the logger used by the rest of the request.
        log = RequestIDLogAdapter(logger, {'request_id': request_id})
        request.state.logger = log

        # request.client can be None (e.g. some ASGI servers / test clients).
        client = request.client
        client_desc = f"{client.host}:{client.port}" if client else "unknown"
        log.info(
            f"接收请求: {request.method} {request.url.path},客户端: {client_desc}"
        )

        # Redact sensitive headers before logging.
        filtered_headers = {
            k: v for k, v in request.headers.items()
            if k.lower() not in self._SENSITIVE_HEADERS
        }
        log.debug(f"请求头: {filtered_headers}")

        start_time = time.perf_counter()
        # Stays None when call_next raises, so the completion log below never
        # touches an unbound `response` and never masks the real exception.
        status_code = None
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            log.error(f"处理请求异常: {str(e)}", exc_info=True)
            raise
        finally:
            process_time = time.perf_counter() - start_time
            shown_status = status_code if status_code is not None else "N/A"
            log.info(
                f"请求完成: {request.method} {request.url.path},状态码: {shown_status},处理时间: {process_time:.6f}秒"
            )

        # Echo the request ID so clients can correlate logs.
        response.headers["X-Request-ID"] = request_id
        return response


app.add_middleware(RequestIDMiddleware)
# --------------------------
# 4. Helper (client API key validation)
# --------------------------
def validate_client_api_key(request: Request) -> str:
    """Validate the client's ``Authorization: Bearer <key>`` header.

    Args:
        request: Incoming request; ``request.state.logger`` must already be
            set (``RequestIDMiddleware`` does this).

    Returns:
        The client API key, when it is a member of ``ALLOWED_CLIENT_API_KEYS``.

    Raises:
        HTTPException: 401 when the header is missing, is not in
            ``Bearer <key>`` form, or the key is empty or not allowed.
    """
    logger = request.state.logger
    logger.info("进入客户端API Key验证流程")
    # A 401 should tell the client which scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}

    # 1. The Authorization header must be present.
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        logger.warning("未检测到Authorization请求头")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供API密钥,请使用 Bearer <API_KEY> 格式在Authorization头中携带",
            headers=challenge,
        )

    # 2. It must use the Bearer scheme.
    if not auth_header.startswith("Bearer "):
        logger.warning(f"Authorization格式错误,原始值: {auth_header[:10]}***")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization头格式错误,正确格式为: Bearer <API_KEY>",
            headers=challenge,
        )

    # 3. Extract and check the key.
    client_api_key = auth_header[len("Bearer "):].strip()
    # Log-safe form. Keys of 8 characters or fewer are hidden entirely,
    # because first-4 + last-4 would otherwise print the whole key.
    if len(client_api_key) > 8:
        masked_key = f"{client_api_key[:4]}***{client_api_key[-4:]}"
    else:
        masked_key = "***"

    # An empty key is rejected explicitly, whatever the allow-list contains.
    if not client_api_key or client_api_key not in ALLOWED_CLIENT_API_KEYS:
        logger.warning(f"API Key验证失败,密钥: {masked_key}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效或未授权的API密钥",
            headers=challenge,
        )

    logger.info(f"API Key验证通过,密钥: {masked_key}")
    return client_api_key
# --------------------------
# 5. Pydantic models (request body validation)
# --------------------------
class EmbeddingRequest(BaseModel):
    """Request body consumed by ``create_embedding``."""

    # Text to embed: a single string or a list of strings.
    input: Union[str, List[str]]
    # Model name; the handler checks it against SUPPORTED_MODELS["embedding"].
    model: str
    encoding_format: Optional[str] = "float"
    user: Optional[str] = None
class MessageContent(BaseModel):
    """One typed part of a multi-part message: a ``type`` tag plus ``text``."""

    type: str
    text: str
class Message(BaseModel):
    """A single chat message."""

    role: str
    # Either plain text or a list of MessageContent parts.
    content: Union[str, List[MessageContent]]
class ChatRequest(BaseModel):
    """Request body consumed by ``create_chat_completion``."""

    # Model name; the handler checks it against SUPPORTED_MODELS["chat"].
    model: str
    messages: List[Message]

    # Optional parameters with their defaults.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    # When true, the handler returns a streaming response.
    stream: Optional[bool] = False
    # A single stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
# --------------------------
# 6. Core helper (forwarding requests to the upstream API)
# --------------------------
def _describe_request_error(exc: Exception) -> str:
    """Return the bracketed category label used in upstream-failure messages."""
    if isinstance(exc, requests.exceptions.Timeout):
        return "(请求超时)"
    # SSLError subclasses ConnectionError, so it has to be tested first or
    # the SSL label can never be produced.
    if isinstance(exc, requests.exceptions.SSLError):
        return "(SSL证书错误)"
    if isinstance(exc, requests.exceptions.ConnectionError):
        return "(连接失败)"
    return ""


def _sse_error_event(message: str, code: int) -> bytes:
    """Encode an error as one ``data: {...}`` event terminated by a blank line."""
    # json.dumps escapes quotes and newlines in upstream text, so the payload
    # stays valid JSON and cannot break the event framing.
    body = json.dumps({"error": message, "code": code}, ensure_ascii=False)
    return f"data: {body}\n\n".encode("utf-8")


def _stream_upstream(url: str, headers: Dict[str, str], payload: Dict[str, Any], logger):
    """Yield the upstream streaming response line by line, as bytes.

    Blank lines are forwarded too: in an event stream they terminate an
    event, so dropping them stops clients from ever dispatching one.
    Failures are reported to the client as a single error event, because
    once streaming has started an exception can no longer become an HTTP
    error response.
    """
    logger.info("启动上游流式响应接收")
    request_start_time = time.time()
    try:
        with requests.post(
            url,
            json=payload,
            headers=headers,
            stream=True,
            timeout=60,  # Timeout protection.
            verify=True  # Keep TLS verification on.
        ) as r:
            conn_time = time.time() - request_start_time
            logger.info(f"上游连接建立,耗时: {conn_time:.3f}秒,状态码: {r.status_code}")

            # Upstream returned something other than 200.
            if r.status_code != 200:
                error_msg = f"上游请求失败: {r.status_code},响应片段: {r.text[:500]}"
                logger.error(error_msg)
                yield _sse_error_event(error_msg, r.status_code)
                return

            chunk_count = 0
            total_bytes = 0
            for line in r.iter_lines(decode_unicode=False):
                if line:
                    # Decode only for logging; the raw bytes are forwarded.
                    try:
                        line_str = line.decode('utf-8')
                    except UnicodeDecodeError:
                        line_str = "[二进制数据无法解码]"
                        logger.warning("流式响应包含非UTF-8字节,可能导致乱码")
                    logger.debug(f"原始上游行内容: {line_str}")
                    chunk_count += 1
                    total_bytes += len(line)
                # An empty `line` becomes a bare newline, i.e. the blank line
                # that delimits events.
                yield line + b"\n"

            logger.info(f"上游流式响应完成,共{chunk_count}个片段,总大小: {total_bytes/1024:.2f}KB")
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_describe_request_error(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        yield _sse_error_event(error_msg, 500)


def forward_request(
    url: str,
    api_key: str,
    payload: Dict[str, Any],
    logger: logging.Logger,
    stream: bool = False
) -> Any:
    """Forward a JSON payload to the upstream API using the upstream key.

    Args:
        url: Full upstream endpoint URL.
        api_key: Upstream API key; when empty, no Authorization header is sent.
        payload: JSON-serialisable request body, forwarded unchanged.
        logger: Logger (or adapter) used for request/response diagnostics.
        stream: When true, return a generator of bytes instead of parsed JSON.

    Returns:
        ``stream=False``: the decoded JSON body of a 200 response.
        ``stream=True``: a generator yielding the upstream response lines
        (bytes). No request is made until it is iterated, and errors are
        delivered as a ``data: {"error": ..., "code": ...}`` event.

    Raises:
        HTTPException: (non-stream only) with the upstream status code when it
            is not 200, 500 when the upstream cannot be reached, or 502 when a
            200 response is not valid JSON.
    """
    # 1. Build request headers (only the API key is substituted).
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    logger.debug(f"上游请求URL: {url}")
    safe_headers = headers.copy()
    if "Authorization" in safe_headers:
        safe_headers["Authorization"] = safe_headers["Authorization"][:10] + "***"
    logger.debug(f"上游请求头: {safe_headers}")

    # 2. Log the request body (truncated for length only, never modified).
    payload_str = json.dumps(payload, ensure_ascii=False)
    if len(payload_str) > 1000:
        logger.debug(f"上游请求体(截断): {payload_str[:1000]}...")
    else:
        logger.debug(f"上游请求体: {payload_str}")

    if stream:
        # The generator does its own error handling while being iterated.
        return _stream_upstream(url, headers, payload, logger)

    # Non-streaming request.
    logger.info("发送上游非流式请求")
    request_start_time = time.time()
    try:
        response = requests.post(
            url,
            json=payload,
            headers=headers,
            timeout=60,
            verify=True
        )
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_describe_request_error(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg) from e

    resp_time = time.time() - request_start_time
    logger.info(
        f"上游非流式响应接收,耗时: {resp_time:.3f}秒,状态码: {response.status_code},"
        f"响应大小: {len(response.content)/1024:.2f}KB"
    )
    logger.debug(f"上游响应头: {dict(response.headers)}")

    # Log the response body, truncated when long.
    resp_text = response.text
    if len(resp_text) > 1000:
        logger.debug(f"上游响应内容(截断): {resp_text[:1000]}...")
    else:
        logger.debug(f"上游响应内容: {resp_text}")

    # Upstream error status.
    if response.status_code != 200:
        error_msg = f"上游请求失败: {response.status_code},响应: {resp_text[:500]}"
        logger.error(error_msg)
        raise HTTPException(status_code=response.status_code, detail=error_msg)

    # Return the upstream data as-is. JSON decode errors are ValueError
    # subclasses in every requests version.
    try:
        return response.json()
    except ValueError as e:
        error_msg = f"Upstream returned a non-JSON body: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=502, detail=error_msg) from e
# --------------------------
# 7. Endpoints (health check, embeddings, chat)
# --------------------------
def _probe_upstream(base_url: str, label: str, log) -> bool:
    """Return True when a HEAD request to *base_url* completes without raising."""
    try:
        requests.head(base_url, timeout=5)
    except Exception as e:  # Best-effort probe: any failure means "unhealthy".
        log.warning(f"{label}API健康检查失败: {str(e)}")
        return False
    return True


async def health_check(request: Request):
    """Health-check endpoint for monitoring.

    Probes both upstream base URLs and returns 200 when both respond, 503
    otherwise. The probes run in worker threads, concurrently, so the
    blocking HTTP calls do not stall the event loop.

    NOTE(review): no route decorator is visible on this function in the file
    as captured; confirm how it is registered with ``app``.
    """
    logger = request.state.logger
    logger.info("处理健康检查请求")

    embedding_healthy, chat_healthy = await asyncio.gather(
        asyncio.to_thread(_probe_upstream, EMBEDDING_API_BASE, "嵌入", logger),
        asyncio.to_thread(_probe_upstream, CHAT_API_BASE, "聊天", logger),
    )

    overall_healthy = embedding_healthy and chat_healthy
    status_code = status.HTTP_200_OK if overall_healthy else status.HTTP_503_SERVICE_UNAVAILABLE
    # Same naive-UTC ISO string as before, without the deprecated utcnow().
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    return JSONResponse(
        status_code=status_code,
        content={
            "status": "healthy" if overall_healthy else "unhealthy",
            "timestamp": timestamp,
            "services": {
                "embedding": {"healthy": embedding_healthy, "base_url": EMBEDDING_API_BASE},
                "chat": {"healthy": chat_healthy, "base_url": CHAT_API_BASE}
            }
        }
    )
async def create_embedding(request: Request, req: EmbeddingRequest):
    """Forward an embedding request to the upstream embedding API.

    Validates the client key and the model name, then posts the fields the
    client actually sent to ``{EMBEDDING_API_BASE}/embeddings``.

    Raises:
        HTTPException: 401 for a bad client key, 400 for an unsupported
            model, or whatever ``forward_request`` raises.

    NOTE(review): no route decorator is visible on this function in the file
    as captured; confirm how it is registered with ``app``.
    """
    logger = request.state.logger
    logger.info("处理嵌入请求")

    # Validate the client API key.
    validate_client_api_key(request)

    # Validate the model.
    if req.model not in SUPPORTED_MODELS["embedding"]:
        logger.warning(f"不支持的嵌入模型: {req.model}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的嵌入模型: {req.model},支持的模型: {SUPPORTED_MODELS['embedding']}"
        )

    # Only the fields the client set are forwarded.
    payload = req.dict(exclude_unset=True)

    # forward_request does blocking HTTP I/O, so it runs in a worker thread
    # to keep the event loop free for other requests.
    url = f"{EMBEDDING_API_BASE}/embeddings"
    result = await asyncio.to_thread(
        forward_request,
        url=url,
        api_key=EMBEDDING_API_KEY,
        payload=payload,
        logger=logger,
        stream=False
    )
    return JSONResponse(content=result)
def _probe_stream_content_type(target_url: str, log) -> str:
    """Return the Content-Type to use for the streamed chat response.

    Sends a HEAD request to *target_url*. Its Content-Type is adopted only
    when the status is exactly 200; an error or redirect response (for
    example a 405 from a POST-only endpoint) describes the error page, not
    the stream. In every other case ``text/event-stream`` is returned.
    """
    content_type = "text/event-stream"  # Default.
    try:
        with requests.head(
            target_url,
            headers={"Authorization": f"Bearer {CHAT_API_KEY}"},
            timeout=10,
            verify=True
        ) as head_resp:
            if head_resp.status_code == 200 and "Content-Type" in head_resp.headers:
                content_type = head_resp.headers["Content-Type"]
            log.debug(f"获取上游Content-Type: {content_type}")
    except Exception as e:  # Best-effort probe: fall back to the default.
        log.warning(f"HEAD请求获取上游Content-Type失败,使用默认值: {str(e)}")
    return content_type


async def create_chat_completion(request: Request, req: ChatRequest):
    """Forward a chat-completion request to the upstream chat API.

    Validates the client key and the model name, then posts the fields the
    client actually sent to ``{CHAT_API_BASE}/chat/completions``. With
    ``stream`` set, a ``StreamingResponse`` wrapping the upstream generator
    is returned; otherwise the upstream JSON is returned.

    Raises:
        HTTPException: 401 for a bad client key, 400 for an unsupported
            model, or whatever ``forward_request`` raises (non-stream mode).

    NOTE(review): no route decorator is visible on this function in the file
    as captured; confirm how it is registered with ``app``.
    """
    logger = request.state.logger
    logger.info("处理聊天请求")

    # Validate the client API key.
    validate_client_api_key(request)

    # Validate the model.
    if req.model not in SUPPORTED_MODELS["chat"]:
        logger.warning(f"不支持的聊天模型: {req.model}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的聊天模型: {req.model},支持的模型: {SUPPORTED_MODELS['chat']}"
        )

    # Only the fields the client set are forwarded.
    payload = req.dict(exclude_unset=True)
    target_url = f"{CHAT_API_BASE}/chat/completions"

    if req.stream:
        logger.info("启用流式响应模式(透传上游Content-Type)")

        # Step 1: determine the Content-Type. The probe is blocking HTTP I/O,
        # so it runs in a worker thread.
        upstream_content_type = await asyncio.to_thread(
            _probe_stream_content_type, target_url, logger
        )

        # Step 2: build the upstream generator. No I/O happens here; the
        # request is made when the response body is iterated.
        upstream_generator = forward_request(
            url=target_url,
            api_key=CHAT_API_KEY,
            payload=payload,
            logger=logger,
            stream=True
        )

        # Step 3: stream it back to the client.
        return StreamingResponse(
            upstream_generator,
            media_type=upstream_content_type,
            headers={"Cache-Control": "no-cache"}
        )

    # Non-streaming response; the blocking call runs in a worker thread.
    logger.info("启用非流式响应模式")
    response_data = await asyncio.to_thread(
        forward_request,
        url=target_url,
        api_key=CHAT_API_KEY,
        payload=payload,
        logger=logger,
        stream=False
    )
    logger.debug(f"非流式聊天响应大小: {len(str(response_data))}字符")
    logger.info("非流式聊天请求处理完成")
    return JSONResponse(content=response_data)
# --------------------------
# 8. Global exception handlers (uniform error response format)
# --------------------------
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as the service's JSON error envelope.

    The response keeps the exception's status code and any headers attached
    to it (for example ``WWW-Authenticate`` on a 401), which would otherwise
    be lost.

    NOTE(review): no ``exception_handler`` registration is visible for this
    function in the file as captured; confirm how it is attached to ``app``.
    """
    # The request-scoped logger is absent if the failure happened before the
    # middleware set it.
    logger = getattr(request.state, 'logger', None) or logging.getLogger("api_proxy")
    error_msg = f"HTTP异常: {exc.status_code} - {exc.detail}"
    logger.error(error_msg)
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "type": "invalid_request_error",
                "request_id": getattr(request.state, 'request_id', 'unknown')
            }
        },
        headers=getattr(exc, "headers", None)
    )
async def general_exception_handler(request: Request, exc: Exception):
    """Turn any uncaught exception into a generic 500 JSON error.

    The exception text is logged with traceback but never returned to the
    client; the body carries only a fixed message and the request ID.
    """
    state = request.state
    # Use the request-scoped logger when the middleware has set one.
    if hasattr(state, 'logger'):
        log = state.logger
    else:
        log = logging.getLogger("api_proxy")

    log.error(f"服务器内部异常: {str(exc)}", exc_info=True)

    error_body = {
        "message": "服务器内部错误,请联系管理员",
        "type": "server_error",
        "request_id": getattr(state, 'request_id', 'unknown'),
    }
    return JSONResponse(status_code=500, content={"error": error_body})
# --------------------------
# 9. Service entry point
# --------------------------
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable wins; 7860 is the fallback.
    port = int(os.environ.get("PORT", 7860))
    allowed_key_count = len(ALLOWED_CLIENT_API_KEYS)
    logger.info(f"启动API代理服务,端口: {port},允许客户端API Key数量: {allowed_key_count}")
    # log_config=None disables uvicorn's default logging setup so the
    # handlers configured by setup_logging() are used instead.
    uvicorn.run(app, host="0.0.0.0", port=port, log_config=None)