Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, HTTPException, Response | |
| from fastapi.responses import StreamingResponse | |
| import httpx | |
| import logging | |
| import os, json | |
| import re # 用于URL路径处理 | |
| from key_selector import KeySelector # 自动选择key | |
| from app.routers import key_management # Import the new router | |
| def get_target_url(url: str) -> str: | |
| """将url参数变了转换为合法的目标url;from http/ or https/ to http:// or https://""" | |
| url = re.sub(r"^http/", "http://", url) | |
| url = re.sub(r"^https/", "https://", url) | |
| return url | |
| app = FastAPI() | |
| # Include the new key management router | |
| app.include_router(key_management.router, prefix="/api/keys", tags=["Key Management"]) | |
| # 配置日志 | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger("uvicorn.error") | |
| # 从环境变量获取配置 | |
| # X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "") | |
| async def read_root(): | |
| return {"message": "FastAPI Proxy is running"} | |
| async def proxy(request: Request, path: str): | |
| # 添加流式请求判断逻辑 | |
| is_streaming = ":streamGenerateContent" in path.lower() | |
| target_url = get_target_url(path) | |
| method = request.method | |
| headers = {k: v for k, v in request.headers.items() | |
| # if k.lower() not in ["host", "connection", "Postman-Token", "content-length"]} | |
| if k.lower() not in ["host", "content-length"]} | |
| key_selector = KeySelector() | |
| headers["X-Goog-Api-Key"] = key_selector.get_api_key_info()['key_value'] # 从数据库获取API密钥 | |
| try: | |
| # 关键修复:禁用KeepAlive防止连接冲突 | |
| transport = httpx.AsyncHTTPTransport(retries=3, http1=True) | |
| async with httpx.AsyncClient( | |
| transport=transport, | |
| timeout=httpx.Timeout(300.0, connect=30.0) | |
| ) as client: | |
| # 处理请求体 | |
| req_content = await request.body() | |
| # 发送请求到上游服务 | |
| response = await client.request( | |
| method=method, | |
| url=target_url, | |
| headers=headers, | |
| content=req_content, | |
| follow_redirects=True # 自动处理重定向 | |
| ) | |
| if is_streaming: | |
| # 流式响应处理 | |
| async def stream_generator(): | |
| try: | |
| async for chunk in response.aiter_bytes(): | |
| yield chunk | |
| except Exception as e: | |
| logger.error(f"Stream interrupted: {str(e)}") | |
| yield json.dumps({"error": "流中断"}).encode() | |
| # 移除冲突头部 | |
| headers = dict(response.headers) | |
| headers.pop("Content-Length", None) | |
| return StreamingResponse( | |
| content=stream_generator(), | |
| status_code=response.status_code, | |
| headers=headers, | |
| media_type="application/x-ndjson" # Gemini流式格式 | |
| ) | |
| else: | |
| # 非流式响应处理 | |
| # 解析 JSON 字符串 | |
| try: | |
| data = json.loads(response.text) | |
| # 格式化输出 JSON 数据 | |
| formatted_json = json.dumps(data, ensure_ascii=False, indent=4) | |
| # return formatted_json | |
| return Response( | |
| content=formatted_json, | |
| media_type="application/json" | |
| ) | |
| except json.JSONDecodeError as e: | |
| print(f"Error decoding JSON: {e}") | |
| except httpx.ConnectError as e: | |
| logger.error(f"Connection failed to {target_url}: {e}") | |
| raise HTTPException(502, f"无法连接到上游服务: {target_url}") # Modified error message | |
| except httpx.ReadTimeout as e: | |
| logger.error(f"Timeout: {e}") | |
| raise HTTPException(504, "上游服务响应超时") | |
| except httpx.HTTPError as e: # 捕获所有HTTP异常 | |
| try: | |
| # 安全地获取异常信息 | |
| error_type = type(e).__name__ | |
| # 尝试获取状态码(如果存在) | |
| status_code = getattr(e, 'response', None) and e.response.status_code | |
| # 安全地获取错误详情 | |
| error_detail = "" | |
| try: | |
| # 尝试获取文本响应(限制长度) | |
| if hasattr(e, 'response') and e.response: | |
| error_detail = e.response.text[:500] # 只取前500个字符 | |
| except Exception as ex: | |
| error_detail = f"无法获取错误详情: {type(ex).__name__}" | |
| # 安全地记录日志 | |
| logger.error( | |
| "HTTP代理错误 | " | |
| f"类型: {error_type} | " | |
| f"状态码: {status_code or 'N/A'} | " | |
| f"目标URL: {target_url} | " | |
| f"详情: {error_detail[:200]}" # 日志中只记录前200字符 | |
| ) | |
| # 打印到控制台以便调试 | |
| print(f"目标URL: {target_url}") | |
| print(f"状态码: {status_code or 'N/A'}") | |
| print(f"错误详情: {error_detail[:500]}") | |
| except Exception as ex: | |
| # 如果记录日志本身出错,使用最安全的方式记录 | |
| logger.error(f"记录HTTP错误时发生异常: {type(ex).__name__}") | |
| print(f"严重错误: 记录HTTP错误时发生异常: {ex}") | |
| # 返回用户友好的错误响应 | |
| raise HTTPException( | |
| status_code=502, | |
| detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})" | |
| ) | |
| except Exception as e: | |
| logger.exception("Unexpected proxy error") | |
| raise HTTPException(500, f"内部服务器错误: {str(e)}") | |
| if __name__ == "__main__": | |
| # In a real application, you would typically run uvicorn here: | |
| # import uvicorn | |
| # uvicorn.run(app, host="0.0.0.0", port=8000) | |
| pass # Placeholder to fix indentation error | |