# gpt-api-zhen / app.py
# Author: fiewolf1000
# Commit: ca5b62a ("Update app.py")
import json
import logging
import os
import re # 统一导入,避免重复
import sys
import time
import uuid
from datetime import datetime, timezone
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, List, Optional, Union  # 统一导入类型注解

import requests
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
# --------------------------
# 1. 日志配置(保持DEBUG级别,便于调试)
# --------------------------
def setup_logging():
    """Configure root logging and return the service's business logger.

    Records are written both to a size-rotated file (``logs/app.log``) and
    to stdout.  Each handler carries a filter that fills in ``request_id``
    with ``'unknown'`` when a record has none, so the shared format string
    also works for records emitted outside a request (for example by
    third-party libraries logging through the root logger).

    Returns:
        The ``"api_proxy"`` logger, with its level set to DEBUG.
    """

    class DefaultRequestIDFilter(logging.Filter):
        """Guarantee that every record has a ``request_id`` attribute."""

        def filter(self, record):
            if not hasattr(record, "request_id"):
                record.request_id = "unknown"
            return True

    fmt = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - request_id=%(request_id)s - %(message)s"
    )

    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    handlers = [
        # Rotate at 10 MB and keep 10 backups; UTF-8 so Chinese text is written intact.
        RotatingFileHandler(
            f"{log_dir}/app.log",
            maxBytes=10 * 1024 * 1024,
            backupCount=10,
            encoding="utf-8",
        ),
        logging.StreamHandler(sys.stdout),
    ]
    for handler in handlers:
        handler.setFormatter(fmt)
        handler.addFilter(DefaultRequestIDFilter())

    # Root logger at DEBUG, routed through the two handlers above.
    logging.basicConfig(level=logging.DEBUG, handlers=handlers)

    service_logger = logging.getLogger("api_proxy")
    service_logger.setLevel(logging.DEBUG)
    return service_logger


logger = setup_logging()
# --------------------------
# 2. Global configuration (environment variables take precedence; built-in defaults are the fallback)
# --------------------------
def parse_api_key_list(raw: str) -> set[str]:
    """Split a comma-separated key list into a set of trimmed, non-empty keys.

    Example: ``" a, b ,,"`` -> ``{"a", "b"}``.

    Trimming matters because env files are often written as ``KEY1, KEY2``;
    without it the second key keeps a leading space and can never match a
    client's Bearer token.  Empty entries are dropped so that a trailing
    comma (or an empty variable) cannot whitelist the empty string.
    """
    return {key.strip() for key in raw.split(",") if key.strip()}


# NOTE(review): the defaults below embed real-looking secrets in source code.
# They are kept so existing deployments keep working, but they should be
# rotated and supplied only through environment variables.
EMBEDDING_API_BASE = os.getenv("EMBEDDING_API_BASE", "https://fiewolf1000-gpt-text-api.hf.space/v1")
CHAT_API_BASE = os.getenv("CHAT_API_BASE", "https://free.v36.cm/v1")
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", "sk-OR0eRlmirRsSdCrA9bAbEa805d5f42448b7d0d184b268791")
CHAT_API_KEY = os.getenv("CHAT_API_KEY", "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430")
# NOTE(review): the default client whitelist contains the upstream CHAT_API_KEY
# itself, so anyone holding the upstream key can also use this proxy.
ALLOWED_CLIENT_API_KEYS = parse_api_key_list(
    os.getenv("ALLOWED_CLIENT_API_KEYS", "sk-tLB1LCAGfBVMW1mt54F1A5026dD246E582809454Ea93E430,sk-client-456")
)
# Model whitelist; the endpoints reject any other model with HTTP 400.
SUPPORTED_MODELS = {
    "embedding": ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"],
    "chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4o-mini"]
}
# The ASGI application; the routes, middleware and exception handlers in this
# file are all registered on this instance.
app: FastAPI = FastAPI(title="API Proxy Service")
# --------------------------
# 3. 中间件(请求ID生成与日志绑定)
# --------------------------
class RequestIDLogAdapter(logging.LoggerAdapter):
    """LoggerAdapter that adds the adapter's ``extra`` (e.g. ``request_id``)
    to every record, merged with any per-call ``extra``.

    Per-call keys win on conflict.  The stock ``LoggerAdapter.process``
    replaces a caller's ``extra`` instead; this merges the two.
    """

    def process(self, msg, kwargs):
        # `extra=None` is accepted by Logger methods; treat it as "no extra"
        # rather than failing on `**None`.  `self.extra` may also be None.
        call_extra = kwargs.get("extra") or {}
        merged_extra = {**(self.extra or {}), **call_extra}
        return f"{msg}", {**kwargs, "extra": merged_extra}
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a request ID and a request-scoped logger to every request.

    The ID comes from the ``X-Request-ID`` header when it is present and at
    most ``MAX_REQUEST_ID_LENGTH`` characters long; otherwise a random UUID4
    is generated.  It is stored on ``request.state.request_id``, bound into
    ``request.state.logger`` and echoed back in the ``X-Request-ID``
    response header.
    """

    # Cap on a client-supplied ID so a hostile header cannot bloat every log line.
    MAX_REQUEST_ID_LENGTH = 128

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        request_id = request.headers.get("X-Request-ID")
        if not request_id or len(request_id) > self.MAX_REQUEST_ID_LENGTH:
            request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        req_logger = RequestIDLogAdapter(logger, {"request_id": request_id})
        request.state.logger = req_logger

        # `request.client` is None when the ASGI server does not report a peer address.
        client = f"{request.client.host}:{request.client.port}" if request.client else "unknown"
        req_logger.info(f"接收请求: {request.method} {request.url.path},客户端: {client}")

        # Drop credentials before logging headers.
        filtered_headers = {
            k: v for k, v in request.headers.items() if k.lower() not in ("authorization", "cookie")
        }
        req_logger.debug(f"请求头: {filtered_headers}")

        start_time = time.time()
        try:
            response = await call_next(request)
        except Exception as e:
            # There is no response on this path, so nothing to time or tag.
            # Log and re-raise so the server-error handler sees the real exception.
            req_logger.error(f"处理请求异常: {str(e)}", exc_info=True)
            raise

        process_time = time.time() - start_time
        req_logger.info(
            f"请求完成: {request.method} {request.url.path},状态码: {response.status_code},处理时间: {process_time:.6f}秒"
        )
        # Echo the ID so clients can correlate responses with server logs.
        response.headers["X-Request-ID"] = request_id
        return response


app.add_middleware(RequestIDMiddleware)
# --------------------------
# 4. 工具函数(客户端API Key验证)
# --------------------------
def _mask_secret(secret: str) -> str:
    """Return a log-safe rendering of *secret*.

    Long values keep their first and last 4 characters.  Values of 8
    characters or fewer are fully masked, because showing 4 + 4 characters
    of them would reveal (almost) the whole secret.
    """
    if len(secret) <= 8:
        return "***"
    return f"{secret[:4]}***{secret[-4:]}"


def validate_client_api_key(request: Request) -> str:
    """Authenticate the client from its ``Authorization: Bearer <key>`` header.

    Args:
        request: The incoming request.  ``request.state.logger`` must already
            be set (``RequestIDMiddleware`` does this).

    Returns:
        The client's API key, without the ``Bearer`` prefix.

    Raises:
        HTTPException: 401 when the header is missing or malformed, or when
            the key is not in ``ALLOWED_CLIENT_API_KEYS``.
    """
    log = request.state.logger
    log.info("进入客户端API Key验证流程")

    # 1. The Authorization header must be present.
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        log.warning("未检测到Authorization请求头")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供API密钥,请使用 Bearer <API_KEY> 格式在Authorization头中携带"
        )

    # 2. Format check.  The auth scheme is case-insensitive (RFC 7235), and
    #    empty credentials are rejected outright.
    scheme, _, credentials = auth_header.partition(" ")
    client_api_key = credentials.strip()
    if scheme.lower() != "bearer" or not client_api_key:
        # Do not echo the raw header: a client that forgot the "Bearer "
        # prefix has put its key right at the start of it.
        log.warning(f"Authorization格式错误,原始值: {_mask_secret(auth_header)}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization头格式错误,正确格式为: Bearer <API_KEY>"
        )

    # 3. Whitelist check; only a masked form of the key is ever logged.
    masked_key = _mask_secret(client_api_key)
    if client_api_key not in ALLOWED_CLIENT_API_KEYS:
        log.warning(f"API Key验证失败,密钥: {masked_key}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效或未授权的API密钥"
        )
    log.info(f"API Key验证通过,密钥: {masked_key}")
    return client_api_key
# --------------------------
# 5. Pydantic模型(请求格式校验)
# --------------------------
class EmbeddingRequest(BaseModel):
    # Body of POST /v1/embeddings.  The endpoint checks `model` against
    # SUPPORTED_MODELS["embedding"] and forwards the fields the client set.
    input: str | list[str]
    model: str
    encoding_format: str | None = "float"
    user: str | None = None


class MessageContent(BaseModel):
    # One part of a multi-part message.  Both fields are required, so only
    # parts that carry `text` validate (e.g. an image part without `text`
    # is rejected).
    type: str
    text: str


class Message(BaseModel):
    # A chat message: `content` is plain text or a list of content parts.
    role: str
    content: str | list[MessageContent]


class ChatRequest(BaseModel):
    # Body of POST /v1/chat/completions (an OpenAI-compatible subset).
    # No model_config is set, so pydantic's default `extra="ignore"` applies:
    # fields not declared here are dropped and never reach the upstream API.
    model: str
    messages: list[Message]
    temperature: float | None = 1.0
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool | None = False
    stop: str | list[str] | None = None
    max_tokens: int | None = None
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    user: str | None = None
# --------------------------
# 6. 核心工具函数(请求转发,含中文修复)
# --------------------------
def _truncate_for_log(text: str, limit: int = 100) -> str:
    """Shorten *text* to *limit* characters plus "..." for readable logs."""
    return text[:limit] + "..." if len(text) > limit else text


def _messages_for_log(messages: Any) -> Any:
    """Return a log-friendly copy of chat *messages* with long texts truncated.

    Builds new dicts and lists instead of editing in place: the original
    payload is sent upstream afterwards, so mutating it would truncate what
    the model actually receives.
    """
    if not isinstance(messages, list):
        return messages
    safe_messages = []
    for msg in messages:
        if not isinstance(msg, dict):
            safe_messages.append(msg)
            continue
        safe_msg = dict(msg)
        content = msg.get("content")
        if isinstance(content, str):
            safe_msg["content"] = _truncate_for_log(content)
        elif isinstance(content, list):
            safe_msg["content"] = [
                {**part, "text": _truncate_for_log(part["text"])}
                if isinstance(part, dict) and isinstance(part.get("text"), str)
                else part
                for part in content
            ]
        safe_messages.append(safe_msg)
    return safe_messages


def _override_model(data: Any, target_model: Optional[str]) -> None:
    """Replace upstream ``model`` values with the client-requested model, in place.

    Touches the top-level ``model`` key and ``model`` keys inside ``choices``
    entries, but only where those keys already exist.  Non-dict data and a
    falsy *target_model* leave *data* unchanged.
    """
    if not target_model or not isinstance(data, dict):
        return
    if "model" in data:
        data["model"] = target_model
    choices = data.get("choices")
    if isinstance(choices, list):
        for choice in choices:
            if isinstance(choice, dict) and "model" in choice:
                choice["model"] = target_model


def _describe_request_error(exc: Exception) -> str:
    """Return a short label classifying a requests exception.

    SSLError is tested first because requests defines it as a subclass of
    ConnectionError.  ConnectTimeout (both a Timeout and a ConnectionError)
    is reported as a timeout.
    """
    if isinstance(exc, requests.exceptions.SSLError):
        return "(SSL证书错误)"
    if isinstance(exc, requests.exceptions.Timeout):
        return "(请求超时)"
    if isinstance(exc, requests.exceptions.ConnectionError):
        return "(连接失败)"
    return ""


def _sse_error_event(message: str, code: int) -> str:
    """Build a complete SSE event whose JSON payload carries an error (properly escaped)."""
    return f"data: {json.dumps({'error': message, 'code': code}, ensure_ascii=False)}\n\n"


def _rewrite_sse_line(line: str, target_model: Optional[str], chunk_no: int, logger: logging.Logger) -> str:
    """Rewrite one non-blank SSE line and return it with a trailing newline.

    ``data: {json}`` lines get their ``model`` fields replaced and are
    re-serialized with ``ensure_ascii=False`` so Chinese stays readable.
    ``data: [DONE]``, non-``data:`` lines and unparsable payloads pass through.
    """
    if not target_model or not line.startswith("data: "):
        return f"{line}\n"
    data_part = line[len("data: "):]
    if data_part.strip() == "[DONE]":
        return "data: [DONE]\n"
    try:
        data = json.loads(data_part)
    except json.JSONDecodeError as e:
        # Lines arrive complete and correctly decoded, so this is genuinely
        # non-JSON upstream data: forward it untouched rather than drop it.
        logger.warning(f"片段#{chunk_no}解析失败,原样转发: {str(e)}")
        return f"{line}\n"
    _override_model(data, target_model)
    return f"data: {json.dumps(data, ensure_ascii=False)}\n"


def _stream_upstream(url: str, headers: Dict[str, str], payload: Dict[str, Any], logger: logging.Logger):
    """Generator: relay the upstream SSE stream line by line.

    Upstream event framing is preserved: every forwarded line ends with
    "\\n", and the blank line that terminates each SSE event is forwarded
    too (consecutive blanks are collapsed).  Without the blank lines a
    compliant SSE client would merge all chunks into a single event.

    Network errors are reported as an SSE error event rather than raised.
    This generator runs after the HTTP response headers have been sent, so
    raising would simply cut the stream.
    """
    logger.info("启动上游流式响应接收")
    target_model = payload.get("model")
    chunk_count = 0
    total_bytes = 0
    # True when nothing has been sent yet or the last output ended an SSE event.
    at_event_boundary = True
    request_start_time = time.time()
    try:
        with requests.post(
            url,
            json=payload,
            headers=headers,
            stream=True,
            timeout=60,  # timeout protection
            verify=True  # keep TLS verification on
        ) as r:
            conn_time = time.time() - request_start_time
            logger.info(f"上游连接建立,耗时: {conn_time:.3f}秒,状态码: {r.status_code}")
            if r.status_code != 200:
                body = r.content.decode("utf-8", errors="replace")
                error_msg = f"上游请求失败: {r.status_code},响应片段: {body[:500]}"
                logger.error(error_msg)
                yield _sse_error_event(error_msg, r.status_code)
                return
            # Iterate raw bytes and decode each complete line as UTF-8 ourselves.
            # iter_lines(decode_unicode=True) would use requests' guessed
            # encoding (ISO-8859-1 for a text/* type without a charset) and then
            # str.splitlines(), which also splits on U+0085/U+2028/..., i.e. in
            # the middle of multi-byte Chinese characters.  bytes.splitlines()
            # splits only on \r and \n, which never occur inside a UTF-8
            # multi-byte sequence.
            for raw_line in r.iter_lines():
                line = raw_line.decode("utf-8", errors="replace")
                logger.debug(f"原始上游行内容: {line}")
                if not line:
                    # Blank line = end of an SSE event; forward it once.
                    if not at_event_boundary:
                        yield "\n"
                        at_event_boundary = True
                    continue
                chunk_count += 1
                raw_len = len(raw_line)
                total_bytes += raw_len
                processed_line = _rewrite_sse_line(line, target_model, chunk_count, logger)
                # Log every 10th chunk, plus any large chunk.
                if chunk_count % 10 == 0 or raw_len > 1024:
                    logger.debug(
                        f"接收上游片段 #{chunk_count},大小: {raw_len}字节,"
                        f"内容前200字符: {processed_line[:200]}..."
                    )
                yield processed_line
                at_event_boundary = False
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_describe_request_error(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        if not at_event_boundary:
            # Close the half-sent event so the error is a separate event.
            yield "\n"
        yield _sse_error_event(error_msg, 500)
        return
    if not at_event_boundary:
        # Terminate the final event if upstream did not end with a blank line.
        yield "\n"
    logger.info(f"上游流式响应完成,共{chunk_count}个片段,总大小: {total_bytes/1024:.2f}KB")


def _post_upstream(url: str, headers: Dict[str, str], payload: Dict[str, Any], logger: logging.Logger) -> Any:
    """Send one non-streaming request and return the parsed JSON body.

    Raises:
        HTTPException: the upstream status code for a non-200 reply, 502 for
            a 200 reply whose body is not JSON.
        requests.exceptions.RequestException: network failures (translated
            by the caller).
    """
    logger.info("发送上游非流式请求")
    request_start_time = time.time()
    response = requests.post(
        url,
        json=payload,
        headers=headers,
        timeout=60,
        verify=True
    )
    resp_time = time.time() - request_start_time
    logger.info(
        f"上游非流式响应接收,耗时: {resp_time:.3f}秒,状态码: {response.status_code},"
        f"响应大小: {len(response.content)/1024:.2f}KB"
    )
    logger.debug(f"上游响应头: {dict(response.headers)}")
    resp_text = response.text
    if len(resp_text) > 1000:
        logger.debug(f"上游响应内容(截断): {resp_text[:1000]}...")
    else:
        logger.debug(f"上游响应内容: {resp_text}")
    if response.status_code != 200:
        error_msg = f"上游请求失败: {response.status_code},响应: {resp_text[:500]}"
        logger.error(error_msg)
        raise HTTPException(status_code=response.status_code, detail=error_msg)
    try:
        resp_data = response.json()
    except ValueError as e:  # json / requests JSONDecodeError are both ValueErrors
        error_msg = f"上游返回了非JSON响应: {resp_text[:500]}"
        logger.error(error_msg)
        raise HTTPException(status_code=502, detail=error_msg) from e
    _override_model(resp_data, payload.get("model"))
    return resp_data


def forward_request(
    url: str,
    api_key: str,
    payload: Dict[str, Any],
    logger: logging.Logger,
    stream: bool = False
) -> Any:
    """Forward *payload* to the upstream API and return its response.

    Args:
        url: Full upstream endpoint URL.
        api_key: Upstream key, sent as ``Authorization: Bearer <key>`` when
            non-empty (no Authorization header is sent otherwise).
        payload: JSON request body.  Its ``model`` value is written over the
            ``model`` fields of the upstream response, so clients see the
            model they asked for.  The payload itself is never modified.
        logger: Request-scoped logger (callers here pass ``request.state.logger``).
        stream: When true, return a generator of SSE text lines instead of a dict.

    Returns:
        Non-streaming: the parsed JSON response.  Streaming: a generator; no
        network traffic happens until it is iterated.

    Raises:
        HTTPException: non-streaming only.  Carries the upstream status code
            for a non-200 reply, 502 for a non-JSON body, and 500 for
            network errors.  Streaming failures are reported as an SSE
            ``data: {"error": ...}`` event instead.
    """
    # 1. Request headers; the Authorization header is masked before logging.
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    logger.debug(f"上游请求URL: {url}")
    safe_headers = dict(headers)
    if "Authorization" in safe_headers:
        safe_headers["Authorization"] = safe_headers["Authorization"][:10] + "***"
    logger.debug(f"上游请求头: {safe_headers}")

    # 2. Log a truncated copy of the body; the payload sent upstream stays intact.
    if "messages" in payload:
        safe_payload = {**payload, "messages": _messages_for_log(payload["messages"])}
        logger.debug(f"上游请求体(脱敏): {json.dumps(safe_payload, ensure_ascii=False)}")
    else:
        logger.debug(f"上游请求体(脱敏): {json.dumps(payload, ensure_ascii=False)[:500]}...")

    # 3. Dispatch.  The streaming generator handles its own network errors.
    if stream:
        return _stream_upstream(url, headers, payload, logger)
    try:
        return _post_upstream(url, headers, payload, logger)
    except requests.exceptions.RequestException as e:
        error_msg = f"与上游API通信异常{_describe_request_error(e)}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg) from e
# --------------------------
# 7. 接口实现(健康检查、嵌入、聊天)
# --------------------------
def _probe_upstream(base_url: str, label: str, log) -> bool:
    """Return True when a HEAD request to *base_url* gets any HTTP response within 5 s.

    Any exception (timeout, DNS, TLS, refused connection) is logged as a
    warning and counts as unhealthy; the HTTP status code is not inspected.
    """
    try:
        requests.head(base_url, timeout=5)
    except Exception as e:
        log.warning(f"{label}API健康检查失败: {str(e)}")
        return False
    return True


@app.get("/health")
def health_check(request: Request):
    """Health-check endpoint reporting reachability of both upstream APIs.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool.  The
    two blocking ``requests.head`` probes (up to 5 s each) would otherwise
    stall the event loop and every other in-flight request.

    Returns:
        A dict with the overall ``status`` ("healthy" or "degraded"), a UTC
        ISO-8601 ``timestamp`` ending in "Z", ``supported_models``, and
        per-service status and base URL.
    """
    log = request.state.logger
    log.info("处理健康检查请求")

    embedding_healthy = _probe_upstream(EMBEDDING_API_BASE, "嵌入", log)
    chat_healthy = _probe_upstream(CHAT_API_BASE, "聊天", log)

    result = {
        "status": "healthy" if (embedding_healthy and chat_healthy) else "degraded",
        # Timezone-aware replacement for the deprecated datetime.utcnow(); same "...Z" format.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "supported_models": SUPPORTED_MODELS,
        "services": {
            "embedding_api": {
                "status": "healthy" if embedding_healthy else "unhealthy",
                "base_url": EMBEDDING_API_BASE
            },
            "chat_api": {
                "status": "healthy" if chat_healthy else "unhealthy",
                "base_url": CHAT_API_BASE
            }
        }
    }
    log.info(f"健康检查完成,整体状态: {result['status']}")
    return result
@app.post("/v1/embeddings")
def create_embedding(request: Request, req_body: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Authenticates the client, checks the model against
    ``SUPPORTED_MODELS["embedding"]`` and forwards the fields the client
    actually set to ``{EMBEDDING_API_BASE}/embeddings``.

    Declared as a plain ``def`` (not ``async def``).  ``forward_request``
    makes a blocking ``requests`` call with a 60 s timeout, and FastAPI runs
    sync endpoints in its threadpool instead of on the event loop.

    Raises:
        HTTPException: 401 for a bad client key, 400 for an unsupported
            model, or the upstream error raised by ``forward_request``.
    """
    log = request.state.logger
    log.info("处理嵌入请求")

    # 1. Authenticate the client.
    validate_client_api_key(request)

    # 2. Model whitelist.
    if req_body.model not in SUPPORTED_MODELS["embedding"]:
        error_msg = f"不支持的嵌入模型: {req_body.model},支持模型: {SUPPORTED_MODELS['embedding']}"
        log.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    log.info(f"嵌入模型验证通过: {req_body.model}")

    # 3. Forward only the fields the client set (exclude_unset).
    payload = req_body.model_dump(exclude_unset=True)
    target_url = f"{EMBEDDING_API_BASE}/embeddings"
    log.debug(f"转发嵌入请求到: {target_url}, payload: {json.dumps(payload, ensure_ascii=False)[:500]}...")
    response_data = forward_request(target_url, EMBEDDING_API_KEY, payload, log)
    log.info("嵌入请求处理完成,返回响应")
    return response_data
@app.post("/v1/chat/completions")
async def create_chat_completion(request: Request):
    """OpenAI-compatible chat completion endpoint (streaming and non-streaming).

    Flow:
        1. Authenticate the client.
        2. Parse the JSON body and validate it against ``ChatRequest``.
        3. Check the model against ``SUPPORTED_MODELS["chat"]``.
        4. Forward the request to ``{CHAT_API_BASE}/chat/completions``.

    Upstream calls use blocking ``requests``.  They are moved off the event
    loop (``run_in_threadpool`` / ``iterate_in_threadpool``), so one slow
    upstream does not freeze every other request.

    Raises:
        HTTPException: 401 (authentication), 400 (invalid JSON or unsupported
            model), 422 (body fails ``ChatRequest`` validation), or the
            upstream error raised by ``forward_request``.
    """
    log = request.state.logger
    log.info("处理聊天请求")

    # 1. Authenticate before reading (and logging) the body.
    validate_client_api_key(request)

    # 2. Parse the body.  json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
    try:
        req_body_dict = await request.json()
    except ValueError:
        log.error("聊天请求体不是有效JSON")
        raise HTTPException(status_code=400, detail="请求体必须是有效JSON格式")
    log.debug(f"原始聊天请求体: {json.dumps(req_body_dict, ensure_ascii=False)[:1000]}...")

    # 3. Validate.  model_validate also turns a non-object body (list, null, ...)
    #    into a ValidationError (422) instead of a TypeError (500).
    try:
        req_body = ChatRequest.model_validate(req_body_dict)
    except ValidationError as e:
        log.error(f"聊天请求参数错误: {str(e)}")
        raise HTTPException(status_code=422, detail=f"请求参数错误: {str(e)}")
    log.info(f"聊天请求解析完成: 模型={req_body.model},流式={req_body.stream}")
    log.debug(f"最终发送的messages: {json.dumps([msg.model_dump() for msg in req_body.messages], ensure_ascii=False)}")

    # 4. Model whitelist.
    if req_body.model not in SUPPORTED_MODELS["chat"]:
        error_msg = f"不支持的聊天模型: {req_body.model},支持模型: {SUPPORTED_MODELS['chat']}"
        log.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    log.info(f"聊天模型验证通过: {req_body.model}")

    # 5. Forward only the fields the client set.
    payload = req_body.model_dump(exclude_unset=True)
    target_url = f"{CHAT_API_BASE}/chat/completions"
    log.debug(f"转发聊天请求到: {target_url}")

    try:
        if req_body.stream:
            log.info("启用流式响应模式")
            upstream_generator = forward_request(target_url, CHAT_API_KEY, payload, log, stream=True)

            async def client_stream():
                client_chunk_count = 0
                try:
                    # Each step of the upstream generator may block on network I/O,
                    # so it is advanced in a worker thread.
                    async for chunk in iterate_in_threadpool(upstream_generator):
                        client_chunk_count += 1
                        log.debug(f"发送给客户端片段 #{client_chunk_count}: {chunk.strip()[:200]}...")
                        yield chunk
                    log.info(f"流式响应完成,共发送 {client_chunk_count} 个片段")
                except Exception as e:
                    log.error(f"流式响应异常: {str(e)}", exc_info=True)
                    raise
                finally:
                    # Release the upstream connection if the client disconnects early.
                    upstream_generator.close()

            return StreamingResponse(client_stream(), media_type="text/event-stream")

        log.info("启用非流式响应模式")
        response_data = await run_in_threadpool(forward_request, target_url, CHAT_API_KEY, payload, log)
        log.debug(f"非流式聊天响应大小: {len(str(response_data))}字符")
        log.info("非流式聊天请求处理完成")
        return response_data
    except HTTPException:
        # Expected errors; the HTTP exception handler logs and renders them.
        raise
    except Exception as e:
        log.error(f"聊天请求处理异常: {str(e)}", exc_info=True)
        raise
# --------------------------
# 8. 全局异常处理器(统一响应格式)
# --------------------------
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP errors in the OpenAI-style ``{"error": {...}}`` envelope.

    Registered for Starlette's HTTPException, the base class of FastAPI's.
    Router-generated errors (404 for an unknown path, 405 for a wrong
    method) therefore get the same envelope as the errors raised in this
    module.  Headers attached to the exception are kept on the response.
    """
    log = request.state.logger if hasattr(request.state, "logger") else logging.getLogger("api_proxy")
    error_msg = f"HTTP异常: {exc.status_code} - {exc.detail}"
    log.error(error_msg)
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "type": "invalid_request_error",
                "request_id": getattr(request.state, "request_id", "unknown")
            }
        },
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the failure and return a generic 500.

    The exception text goes to the log only.  The client gets a fixed
    message plus the request ID, which it can use to correlate with the logs.
    """
    state = request.state
    log = state.logger if hasattr(state, "logger") else logging.getLogger("api_proxy")
    log.error(f"服务器内部异常: {str(exc)}", exc_info=True)

    body = {
        "error": {
            "message": "服务器内部错误,请联系管理员",
            "type": "server_error",
            "request_id": getattr(state, "request_id", "unknown"),
        }
    }
    return JSONResponse(status_code=500, content=body)
# --------------------------
# 9. 服务启动入口
# --------------------------
if __name__ == "__main__":
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    logger.info(f"启动API代理服务,端口: {listen_port},允许客户端API Key数量: {len(ALLOWED_CLIENT_API_KEYS)}")
    # log_config=None: uvicorn applies no logging config of its own, so the
    # root handlers installed by setup_logging() stay in effect.
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_config=None)