import os
import sys
import time
import uuid
import logging
import json
import re  # imports consolidated here to avoid duplicates
from logging.handlers import RotatingFileHandler
from datetime import datetime
from typing import List, Optional, Dict, Any, Union  # type-annotation imports in one place

import requests
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response


# --------------------------
# 1. Logging setup (DEBUG by default to ease debugging; override with LOG_LEVEL)
# --------------------------
def setup_logging(
    log_dir: str = "logs",
    level: Union[int, str] = logging.DEBUG,
) -> logging.Logger:
    """Configure root logging (rotating file + stdout) and return the "api_proxy" logger.

    Every record is formatted with a ``request_id`` field; records that do not
    carry one (e.g. from third-party libraries) are given ``request_id=unknown``.

    Args:
        log_dir: Directory for ``app.log`` (created if missing).
        level: Level (number or name such as ``"INFO"``) applied to the root
            configuration and to the "api_proxy" logger.

    Returns:
        The "api_proxy" business logger.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers. The file handler is created with ``delay=True`` so that, in
        that case, it never opens ``app.log``.
    """
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - request_id=%(request_id)s - %(message)s"
    )

    class DefaultRequestIDFilter(logging.Filter):
        """Guarantee a ``request_id`` attribute so the formatter never fails."""

        def filter(self, record: logging.LogRecord) -> bool:
            if not hasattr(record, "request_id"):
                record.request_id = "unknown"
            return True

    request_id_filter = DefaultRequestIDFilter()
    os.makedirs(log_dir, exist_ok=True)

    # Rotating file log: 10 MB per file, 10 backups, UTF-8 so Chinese text is not garbled.
    file_handler = RotatingFileHandler(
        os.path.join(log_dir, "app.log"),
        maxBytes=10 * 1024 * 1024,
        backupCount=10,
        encoding="utf-8",
        delay=True,
    )
    console_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        handler.addFilter(request_id_filter)

    logging.basicConfig(level=level, handlers=[file_handler, console_handler])

    app_logger = logging.getLogger("api_proxy")
    app_logger.setLevel(level)
    return app_logger


logger = setup_logging(level=os.getenv("LOG_LEVEL", "DEBUG").upper())
# --------------------------
# 2. Global configuration (environment variables take precedence over defaults)
# --------------------------
def _parse_csv_set(raw: str) -> set:
    """Split a comma-separated string into a set of stripped, non-empty entries."""
    return {item.strip() for item in raw.split(",") if item.strip()}


EMBEDDING_API_BASE = os.getenv("EMBEDDING_API_BASE", "https://fiewolf1000-gpt-text-api.hf.space/v1")
CHAT_API_BASE = os.getenv("CHAT_API_BASE", "https://free.v36.cm/v1")
# SECURITY NOTE(review): the fallbacks below are real-looking secrets committed to
# source, and the upstream chat key doubles as an allowed *client* key. Supply all
# keys through the environment and consider rotating these values.
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", "sk-OR0eRlmirRsSdCrA9bAbEa805d5f42448b7d0d184b268791")
CHAT_API_KEY = os.getenv("CHAT_API_KEY", "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430")
# Entries are stripped and blanks dropped: "a, b," -> {"a", "b"}, never {"a", " b", ""}.
ALLOWED_CLIENT_API_KEYS = _parse_csv_set(
    os.getenv(
        "ALLOWED_CLIENT_API_KEYS",
        "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430,sk-client-456",
    )
)

# Models accepted per endpoint (strictly validated).
SUPPORTED_MODELS = {
    "embedding": ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"],
    "chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4o-mini"],
}

app = FastAPI(title="API Proxy Service")


# --------------------------
# 3. Middleware (request ID generation and log binding)
# --------------------------
class RequestIDLogAdapter(logging.LoggerAdapter):
    """LoggerAdapter that merges its bound ``extra`` (e.g. request_id) into every call."""

    def process(self, msg, kwargs):
        merged_extra = {**self.extra, **kwargs.get("extra", {})}
        return f"{msg}", {**kwargs, "extra": merged_extra}


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Assign a request ID, bind a per-request logger, and log entry/exit with timing.

    The ID comes from a non-empty ``X-Request-ID`` header, otherwise a UUID4 is
    generated. It is stored on ``request.state.request_id``, a bound logger on
    ``request.state.logger``, and it is echoed in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
        request.state.request_id = request_id
        req_logger = RequestIDLogAdapter(logger, {"request_id": request_id})
        request.state.logger = req_logger

        # request.client may be None (e.g. some ASGI servers / test transports).
        client = request.client
        client_desc = f"{client.host}:{client.port}" if client else "unknown"
        req_logger.info(f"接收请求: {request.method} {request.url.path},客户端: {client_desc}")

        # Never write credentials to the log.
        filtered_headers = {
            k: v for k, v in request.headers.items() if k.lower() not in ("authorization", "cookie")
        }
        req_logger.debug(f"请求头: {filtered_headers}")

        start_time = time.perf_counter()
        # Reported when call_next raises and no response object exists; reading an
        # unbound `response` in `finally` would mask the real exception.
        status_code = 500
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            req_logger.error(f"处理请求异常: {str(e)}", exc_info=True)
            raise
        finally:
            # For streaming responses this covers the time until headers are sent.
            process_time = time.perf_counter() - start_time
            req_logger.info(
                f"请求完成: {request.method} {request.url.path},状态码: {status_code},处理时间: {process_time:.6f}秒"
            )

        response.headers["X-Request-ID"] = request_id
        return response


app.add_middleware(RequestIDMiddleware)


# --------------------------
# 4. Helpers (client API key validation)
# --------------------------
def _mask_secret(value: str) -> str:
    """Render a credential for logs as ``first4***last4``, or ``***`` if it is short.

    Values of 8 characters or fewer would otherwise be revealed completely.
    """
    if len(value) <= 8:
        return "***"
    return f"{value[:4]}***{value[-4:]}"


def validate_client_api_key(request: Request) -> str:
    """Validate the ``Authorization: Bearer <key>`` header against ALLOWED_CLIENT_API_KEYS.

    Returns:
        The client's API key.

    Raises:
        HTTPException: 401 when the header is missing or malformed, or when the key is not allowed.
    """
    req_logger = request.state.logger
    req_logger.info("进入客户端API Key验证流程")

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        req_logger.warning("未检测到Authorization请求头")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供API密钥,请使用 Bearer 格式在Authorization头中携带",
        )

    # The auth scheme is case-insensitive (RFC 7235), so "bearer" is accepted too.
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        req_logger.warning(f"Authorization格式错误,原始值: {_mask_secret(auth_header)}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization头格式错误,正确格式为: Bearer <API_KEY>",
        )

    client_api_key = credentials.strip()
    masked_key = _mask_secret(client_api_key)
    if not client_api_key or client_api_key not in ALLOWED_CLIENT_API_KEYS:
        req_logger.warning(f"API Key验证失败,密钥: {masked_key}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效或未授权的API密钥",
        )

    req_logger.info(f"API Key验证通过,密钥: {masked_key}")
    return client_api_key
# --------------------------
# 5. Pydantic models (request body validation)
# --------------------------
class EmbeddingRequest(BaseModel):
    """OpenAI-style embeddings request body."""

    input: str | List[str]
    model: str
    encoding_format: Optional[str] = "float"
    user: Optional[str] = None


class MessageContent(BaseModel):
    """One part of a multi-part message; only text parts (``type`` + ``text``) are modeled."""

    type: str
    text: str


class Message(BaseModel):
    """A chat message whose content is plain text or a list of content parts."""

    role: str
    content: Union[str, List[MessageContent]]


class ChatRequest(BaseModel):
    """OpenAI-style chat completion request body.

    NOTE(review): with pydantic's default ``extra="ignore"``, request fields not
    declared here (e.g. ``tools``) are dropped and never reach the upstream API.
    """

    model: str
    messages: List[Message]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[str | List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


# --------------------------
# 6. Upstream forwarding
# --------------------------
_LOG_TEXT_LIMIT = 100  # characters of each message text kept in debug logs


def _truncate_for_log(text: str, limit: int = _LOG_TEXT_LIMIT) -> str:
    """Return *text* cut to *limit* characters, with "..." appended when cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _payload_for_log(payload: Dict[str, Any]) -> str:
    """Serialize *payload* for debug logs with message text truncated, without mutating it.

    New dicts/lists are built for every message and content part so truncation
    never touches the payload that is sent upstream. A shallow ``msg.copy()``
    would still share the content-part dicts with the real payload.
    """
    if "messages" not in payload:
        return json.dumps(payload, ensure_ascii=False)[:500] + "..."

    safe_messages = []
    for msg in payload["messages"]:
        safe_msg = dict(msg)
        content = safe_msg.get("content")
        if isinstance(content, str):
            safe_msg["content"] = _truncate_for_log(content)
        elif isinstance(content, list):
            safe_msg["content"] = [
                {**part, "text": _truncate_for_log(part["text"])}
                if isinstance(part, dict) and isinstance(part.get("text"), str)
                else part
                for part in content
            ]
        safe_messages.append(safe_msg)
    return json.dumps({**payload, "messages": safe_messages}, ensure_ascii=False)


def _apply_target_model(data: Any, target_model: Optional[str]) -> Any:
    """Overwrite ``model`` (top level and in each choice) with *target_model*; returns *data*."""
    if not target_model or not isinstance(data, dict):
        return data
    if "model" in data:
        data["model"] = target_model
    for choice in data.get("choices") or []:
        if isinstance(choice, dict) and "model" in choice:
            choice["model"] = target_model
    return data


def _sse_error_event(message: str, code: int) -> str:
    """Return one complete SSE event whose data is ``{"error": message, "code": code}``.

    Built with json.dumps so quotes or newlines in *message* (e.g. an upstream
    response body) cannot yield invalid JSON.
    """
    body = json.dumps({"error": message, "code": code}, ensure_ascii=False)
    return f"data: {body}\n\n"


def _request_error_label(exc: Exception) -> str:
    """Return the category suffix used in upstream-communication error messages.

    Order matters: ``SSLError`` subclasses ``ConnectionError`` and must be tested
    before it; ``ConnectTimeout`` (both Timeout and ConnectionError) counts as a timeout.
    """
    if isinstance(exc, requests.exceptions.Timeout):
        return "(请求超时)"
    if isinstance(exc, requests.exceptions.SSLError):
        return "(SSL证书错误)"
    if isinstance(exc, requests.exceptions.ConnectionError):
        return "(连接失败)"
    return ""


def _rewrite_sse_line(
    line: str,
    target_model: Optional[str],
    chunk_index: int,
    logger: Union[logging.Logger, logging.LoggerAdapter],
) -> str:
    """Return the newline-terminated text to forward for one non-empty SSE line.

    JSON ``data:`` lines are re-serialized (non-ASCII kept literal) with ``model``
    overwritten by *target_model*. ``[DONE]`` is normalized to ``data: [DONE]``.
    Non-``data`` fields and data that is not valid JSON are forwarded unchanged,
    so the stream is never broken by a parse failure.
    """
    if not target_model or not line.startswith("data:"):
        return f"{line}\n"
    data_part = line[len("data:"):]
    if data_part.startswith(" "):
        # In the SSE format a single space after the colon is not part of the value.
        data_part = data_part[1:]
    if data_part.strip() == "[DONE]":
        return "data: [DONE]\n"
    try:
        data = json.loads(data_part)
    except json.JSONDecodeError as e:
        logger.warning(f"片段#{chunk_index}解析失败,原样转发: {str(e)}")
        return f"{line}\n"
    return f"data: {json.dumps(_apply_target_model(data, target_model), ensure_ascii=False)}\n"


def _stream_upstream(
    url: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    logger: Union[logging.Logger, logging.LoggerAdapter],
):
    """Generate SSE text for the client from a streaming upstream POST.

    Each non-empty upstream line is yielded newline-terminated; each blank line
    is yielded as ``"\\n"`` so upstream event boundaries survive. Without them,
    SSE clients never see an event complete. HTTP errors and ``requests``
    failures are reported as a final SSE error event, because the client
    response has already started with status 200 by the time this runs.
    """
    logger.info("启动上游流式响应接收")
    target_model = payload.get("model")
    request_start_time = time.perf_counter()
    chunk_count = 0
    total_bytes = 0
    try:
        with requests.post(
            url, json=payload, headers=headers, stream=True, timeout=60, verify=True
        ) as r:
            conn_time = time.perf_counter() - request_start_time
            logger.info(f"上游连接建立,耗时: {conn_time:.3f}秒,状态码: {r.status_code}")

            if r.status_code != 200:
                body_preview = r.content.decode("utf-8", errors="replace")[:500]
                error_msg = f"上游请求失败: {r.status_code},响应片段: {body_preview}"
                logger.error(error_msg)
                yield _sse_error_event(error_msg, r.status_code)
                return

            # Iterate raw bytes and decode each complete line as UTF-8 (the SSE
            # encoding). Letting requests decode would use ISO-8859-1 for a
            # text/* response without a charset, which garbles CJK text.
            # str.splitlines() would also break lines at Unicode separators
            # such as U+2028 inside JSON strings.
            for raw_line in r.iter_lines():
                line = raw_line.decode("utf-8", errors="replace")
                logger.debug(f"原始上游行内容: {line}")
                if not line:
                    yield "\n"  # event boundary
                    continue
                chunk_count += 1
                raw_len = len(raw_line)
                total_bytes += raw_len
                processed_line = _rewrite_sse_line(line, target_model, chunk_count, logger)
                # Log every 10th fragment and any large one.
                if chunk_count % 10 == 0 or raw_len > 1024:
                    logger.debug(
                        f"接收上游片段 #{chunk_count},大小: {raw_len}字节,"
                        f"内容前200字符: {processed_line[:200]}..."
                    )
                yield processed_line
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_request_error_label(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        yield _sse_error_event(error_msg, 500)
        return
    logger.info(f"上游流式响应完成,共{chunk_count}个片段,总大小: {total_bytes/1024:.2f}KB")


def forward_request(
    url: str,
    api_key: str,
    payload: Dict[str, Any],
    logger: Union[logging.Logger, logging.LoggerAdapter],
    stream: bool = False,
) -> Any:
    """POST *payload* as JSON to the upstream *url* and relay the result.

    The ``model`` field of the upstream response (top level and per choice) is
    overwritten with ``payload["model"]`` so clients see the name they asked
    for. This performs blocking I/O; async callers should run it in a worker
    thread (or, when streaming, iterate the generator from one).

    Args:
        url: Full upstream endpoint URL.
        api_key: Upstream key sent as ``Authorization: Bearer``; the header is omitted if empty.
        payload: JSON body to forward (not mutated by the debug logging).
        logger: Logger or request-bound LoggerAdapter.
        stream: If True, return a generator of SSE text instead of a dict.

    Returns:
        The decoded JSON response (non-streaming). When streaming, it returns a generator
        that opens the upstream connection on first iteration and yields SSE text.

    Raises:
        HTTPException: non-streaming only. An upstream non-200 is re-raised with the same
            status code; a communication failure or non-JSON body raises 500.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    logger.debug(f"上游请求URL: {url}")
    safe_headers = dict(headers)
    if "Authorization" in safe_headers:
        safe_headers["Authorization"] = safe_headers["Authorization"][:10] + "***"
    logger.debug(f"上游请求头: {safe_headers}")
    logger.debug(f"上游请求体(脱敏): {_payload_for_log(payload)}")

    if stream:
        return _stream_upstream(url, headers, payload, logger)

    logger.info("发送上游非流式请求")
    request_start_time = time.perf_counter()
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=60, verify=True)
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_request_error_label(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg) from e

    resp_time = time.perf_counter() - request_start_time
    logger.info(
        f"上游非流式响应接收,耗时: {resp_time:.3f}秒,状态码: {response.status_code},"
        f"响应大小: {len(response.content)/1024:.2f}KB"
    )
    logger.debug(f"上游响应头: {dict(response.headers)}")

    resp_text = response.text
    if len(resp_text) > 1000:
        logger.debug(f"上游响应内容(截断): {resp_text[:1000]}...")
    else:
        logger.debug(f"上游响应内容: {resp_text}")

    if response.status_code != 200:
        error_msg = f"上游请求失败: {response.status_code},响应: {resp_text[:500]}"
        logger.error(error_msg)
        raise HTTPException(status_code=response.status_code, detail=error_msg)

    try:
        resp_data = response.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        error_msg = f"上游响应不是有效JSON: {str(e)}"
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg) from e

    return _apply_target_model(resp_data, payload.get("model"))


# --------------------------
# 7. Endpoints (health check, embeddings, chat)
# --------------------------
@app.get("/health")
async def health_check(request: Request):
    """Report reachability of both upstream APIs (monitoring endpoint).

    An upstream counts as healthy when a HEAD request to its base URL returns any
    HTTP response within 5 seconds; only exceptions mark it unhealthy.
    """
    import asyncio
    from datetime import timezone

    from starlette.concurrency import run_in_threadpool

    req_logger = request.state.logger
    req_logger.info("处理健康检查请求")

    async def probe(base_url: str, label: str) -> bool:
        """HEAD *base_url* in a worker thread (requests is blocking); True on any response."""
        try:
            await run_in_threadpool(requests.head, base_url, timeout=5)
        except Exception as e:
            req_logger.warning(f"{label}健康检查失败: {str(e)}")
            return False
        return True

    embedding_healthy, chat_healthy = await asyncio.gather(
        probe(EMBEDDING_API_BASE, "嵌入API"),
        probe(CHAT_API_BASE, "聊天API"),
    )

    result = {
        "status": "healthy" if (embedding_healthy and chat_healthy) else "degraded",
        # Same "...Z" format as before, without the deprecated datetime.utcnow().
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "supported_models": SUPPORTED_MODELS,
        "services": {
            "embedding_api": {
                "status": "healthy" if embedding_healthy else "unhealthy",
                "base_url": EMBEDDING_API_BASE,
            },
            "chat_api": {
                "status": "healthy" if chat_healthy else "unhealthy",
                "base_url": CHAT_API_BASE,
            },
        },
    }
    req_logger.info(f"健康检查完成,整体状态: {result['status']}")
    return result


@app.post("/v1/embeddings")
async def create_embedding(request: Request, req_body: EmbeddingRequest):
    """Embeddings endpoint: authenticate, check the model, forward upstream.

    Raises:
        HTTPException: 401 for a bad key or 400 for an unsupported model; errors from
            forward_request propagate.
    """
    from starlette.concurrency import run_in_threadpool

    req_logger = request.state.logger
    req_logger.info("处理嵌入请求")

    validate_client_api_key(request)

    if req_body.model not in SUPPORTED_MODELS["embedding"]:
        error_msg = f"不支持的嵌入模型: {req_body.model},支持模型: {SUPPORTED_MODELS['embedding']}"
        req_logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    req_logger.info(f"嵌入模型验证通过: {req_body.model}")

    payload = req_body.model_dump(exclude_unset=True)  # forward only fields the client set
    target_url = f"{EMBEDDING_API_BASE}/embeddings"
    req_logger.debug(f"转发嵌入请求到: {target_url}, payload: {json.dumps(payload, ensure_ascii=False)[:500]}...")
    # forward_request blocks on network I/O; keep it off the event loop.
    response_data = await run_in_threadpool(
        forward_request, target_url, EMBEDDING_API_KEY, payload, req_logger
    )
    req_logger.info("嵌入请求处理完成,返回响应")
    return response_data


@app.post("/v1/chat/completions")
async def create_chat_completion(request: Request):
    """Chat completions endpoint, non-streaming or SSE streaming.

    Raises:
        HTTPException: 401 for a bad key, 400 for invalid JSON or an unsupported model,
            and 422 for schema errors; errors from forward_request propagate.
    """
    from starlette.concurrency import run_in_threadpool

    req_logger = request.state.logger
    req_logger.info("处理聊天请求")
    try:
        # Authenticate first so unauthenticated bodies are neither parsed nor logged.
        validate_client_api_key(request)

        req_body_dict = await request.json()
        req_logger.debug(f"原始聊天请求体: {json.dumps(req_body_dict, ensure_ascii=False)[:1000]}...")

        # model_validate (unlike ChatRequest(**body)) reports a non-object JSON
        # body as a ValidationError instead of crashing with a TypeError.
        req_body = ChatRequest.model_validate(req_body_dict)
        req_logger.info(f"聊天请求解析完成: 模型={req_body.model},流式={req_body.stream}")
        req_logger.debug(
            f"最终发送的messages: {json.dumps([msg.model_dump() for msg in req_body.messages], ensure_ascii=False)}"
        )

        if req_body.model not in SUPPORTED_MODELS["chat"]:
            error_msg = f"不支持的聊天模型: {req_body.model},支持模型: {SUPPORTED_MODELS['chat']}"
            req_logger.error(error_msg)
            raise HTTPException(status_code=400, detail=error_msg)
        req_logger.info(f"聊天模型验证通过: {req_body.model}")

        payload = req_body.model_dump(exclude_unset=True)  # forward only fields the client set
        target_url = f"{CHAT_API_BASE}/chat/completions"
        req_logger.debug(f"转发聊天请求到: {target_url}")

        if req_body.stream:
            req_logger.info("启用流式响应模式")
            upstream_generator = forward_request(
                target_url, CHAT_API_KEY, payload, req_logger, stream=True
            )

            def client_stream():
                """Relay upstream SSE text to the client, logging each non-blank chunk.

                A plain (sync) generator: StreamingResponse iterates it in a worker
                thread, so the blocking upstream reads never stall the event loop.
                """
                client_chunk_count = 0
                try:
                    for chunk in upstream_generator:
                        if chunk.strip():
                            client_chunk_count += 1
                            req_logger.debug(
                                f"发送给客户端片段 #{client_chunk_count}: {chunk.strip()[:200]}..."
                            )
                        yield chunk
                    req_logger.info(f"流式响应完成,共发送 {client_chunk_count} 个片段")
                except Exception as e:
                    req_logger.error(f"流式响应异常: {str(e)}", exc_info=True)
                    raise

            return StreamingResponse(client_stream(), media_type="text/event-stream")

        req_logger.info("启用非流式响应模式")
        response_data = await run_in_threadpool(
            forward_request, target_url, CHAT_API_KEY, payload, req_logger
        )
        req_logger.debug(f"非流式聊天响应大小: {len(str(response_data))}字符")
        req_logger.info("非流式聊天请求处理完成")
        return response_data

    except HTTPException:
        # Already carries the intended status and detail; pass it through unchanged.
        raise
    except (json.JSONDecodeError, UnicodeDecodeError):
        req_logger.error("聊天请求体不是有效JSON")
        raise HTTPException(status_code=400, detail="请求体必须是有效JSON格式")
    except ValidationError as e:
        req_logger.error(f"聊天请求参数错误: {str(e)}")
        raise HTTPException(status_code=422, detail=f"请求参数错误: {str(e)}")
    except Exception as e:
        req_logger.error(f"聊天请求处理异常: {str(e)}", exc_info=True)
        raise
# --------------------------
# 8. Global exception handlers (uniform error response format)
# --------------------------
# Handlers are looked up along the exception's MRO. Registering on Starlette's
# base HTTPException therefore also covers FastAPI's subclass and the router's
# own 404/405 errors; a handler for fastapi.HTTPException never sees the latter.
from starlette.exceptions import HTTPException as StarletteHTTPException


def _request_logger(request: Request) -> Union[logging.Logger, logging.LoggerAdapter]:
    """Return the request-bound logger if the middleware set one, else the "api_proxy" logger."""
    return getattr(request.state, "logger", None) or logging.getLogger("api_proxy")


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP errors in the ``{"error": {message, type, request_id}}`` envelope.

    5xx statuses are typed ``server_error`` (matching the generic handler below);
    everything else is ``invalid_request_error``. Headers set on the exception
    (e.g. ``WWW-Authenticate``) are preserved.
    """
    req_logger = _request_logger(request)
    req_logger.error(f"HTTP异常: {exc.status_code} - {exc.detail}")
    error_type = "server_error" if exc.status_code >= 500 else "invalid_request_error"
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "type": error_type,
                "request_id": getattr(request.state, "request_id", "unknown"),
            }
        },
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any uncaught exception and return a generic 500 without internal details."""
    req_logger = _request_logger(request)
    req_logger.error(f"服务器内部异常: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": "服务器内部错误,请联系管理员",
                "type": "server_error",
                "request_id": getattr(request.state, "request_id", "unknown"),
            }
        },
    )


# --------------------------
# 9. Service entry point
# --------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    logger.info(f"启动API代理服务,端口: {port},允许客户端API Key数量: {len(ALLOWED_CLIENT_API_KEYS)}")
    # log_config=None: uvicorn installs no logging config of its own; its loggers
    # propagate to the root logger's handlers instead.
    uvicorn.run(app, host="0.0.0.0", port=port, log_config=None)