api / app.py
bobocup's picture
Update app.py
7dbc742 verified
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx
import json
from typing import Optional
# Application instance; the route handlers in this module are registered on it.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Base URL of the upstream server that incoming requests are forwarded to.
TARGET_URL = "http://beibeioo.top"
async def get_http_client():
    """Create a new ``httpx.AsyncClient`` for talking to the upstream server.

    The client is configured with a 30 second timeout, follows redirects and
    has HTTP/2 enabled. The caller is responsible for closing it.
    """
    client_options = {
        "timeout": 30.0,
        "follow_redirects": True,
        "http2": True,
    }
    return httpx.AsyncClient(**client_options)
async def is_stream_request(request: Request) -> bool:
    """Report whether the client asked for a streamed response.

    A request counts as streaming when its ``Accept`` header mentions
    ``text/event-stream`` or when its ``stream`` query parameter equals
    ``true``. Both checks ignore letter case.

    Args:
        request: The incoming request; only its headers and query parameters
            are read.

    Returns:
        ``True`` if either streaming indicator is present, else ``False``.
    """
    # Media type names are case-insensitive, so normalise the header the same
    # way the query parameter is normalised.
    accept = request.headers.get('accept', '').lower()
    stream = request.query_params.get('stream', '').lower()
    return 'text/event-stream' in accept or stream == 'true'
def _relay_response_headers(response):
    """Return the upstream headers that are safe to copy onto the proxy reply."""
    # The body is re-framed (and already decoded) by this proxy, so the
    # upstream framing/encoding headers no longer describe what is sent.
    # NOTE(review): dict() keeps one value per header name, so repeated
    # upstream headers (e.g. several Set-Cookie) are presumably merged into a
    # single value - verify if cookies matter for this deployment.
    relayed = dict(response.headers)
    for name in ('transfer-encoding', 'content-encoding', 'content-length'):
        relayed.pop(name, None)
    return relayed


async def proxy_handler(url: str, request: Request):
    """Forward ``request`` to ``url`` and relay the upstream response.

    The incoming method, query parameters, headers and body are replayed
    against ``url``. When ``is_stream_request`` reports a streaming request,
    the upstream body is relayed chunk by chunk as it arrives; otherwise it
    is read completely before the reply is built.

    Args:
        url: Absolute upstream URL, without query string.
        request: The incoming request to replay.

    Returns:
        A ``StreamingResponse`` for streaming requests, else a ``Response``,
        carrying the upstream status code and filtered headers. Failures are
        reported as responses rather than raised: 504 on an upstream timeout,
        502 on any other ``httpx.RequestError`` and 500 on anything else.
    """
    client = None
    # Set once the streaming generator has taken over closing the client.
    handed_off = False
    try:
        headers = dict(request.headers)
        # host/connection/length/transfer headers describe the inbound hop and
        # are regenerated by httpx. accept-encoding is dropped so that httpx
        # advertises only encodings it can decode: the body is relayed decoded
        # and without a content-encoding header.
        for name in ('host', 'connection', 'content-length',
                     'transfer-encoding', 'accept-encoding'):
            headers.pop(name, None)
        # multi_items() keeps repeated query keys (?a=1&a=2); dict() would
        # keep only one value per key.
        params = request.query_params.multi_items()
        body = await request.body()
        client = await get_http_client()

        if await is_stream_request(request):
            # Streaming request. client.stream() is an async context manager
            # and cannot be awaited, so the request is built and sent
            # explicitly with stream=True to obtain an open response.
            upstream_request = client.build_request(
                method=request.method,
                url=url,
                params=params,
                headers=headers,
                content=body if body else None,
            )
            response = await client.send(upstream_request, stream=True)

            async def stream_generator():
                try:
                    # aiter_bytes() yields decoded content, matching the
                    # removal of the content-encoding header.
                    async for chunk in response.aiter_bytes():
                        # Forward every non-empty chunk as soon as it arrives.
                        if chunk:
                            yield chunk
                except Exception as e:
                    print(f"Streaming error: {e}")
                finally:
                    # The generator outlives this handler, so it is the one
                    # that must release the response and the client.
                    try:
                        await response.aclose()
                    finally:
                        await client.aclose()

            proxied = StreamingResponse(
                stream_generator(),
                status_code=response.status_code,
                headers=_relay_response_headers(response),
                media_type=response.headers.get('content-type'),
            )
            handed_off = True
            return proxied

        # Non-streaming request: read the whole upstream body, then reply.
        response = await client.request(
            method=request.method,
            url=url,
            params=params,
            headers=headers,
            content=body if body else None,
        )
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=_relay_response_headers(response),
            media_type=response.headers.get('content-type'),
        )
    except httpx.TimeoutException:
        return Response(content="请求超时", status_code=504)
    except httpx.RequestError:
        return Response(content="无法连接到目标服务器", status_code=502)
    except Exception as e:
        return Response(content=f"服务器错误: {str(e)}", status_code=500)
    finally:
        # Close the client here unless the streaming generator now owns it.
        if client is not None and not handed_off:
            await client.aclose()
# Handle the /api/v1 prefix: forward /api/v1/<path> to <TARGET_URL>/v1/<path>.
# Documented with comments rather than a docstring because FastAPI publishes
# a route handler's docstring as its OpenAPI description.
@app.api_route(
    "/api/v1/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
)
async def api_v1_proxy(path: str, request: Request):
    return await proxy_handler(f"{TARGET_URL}/v1/{path}", request)
# Local liveness probe. It is registered BEFORE the catch-all route below
# because routes are matched in registration order: declared after
# "/{path:path}", GET /healthcheck would be swallowed by the catch-all and
# proxied upstream instead of being answered here.
@app.get("/healthcheck")
async def healthcheck():
    return {"status": "healthy"}


# Handle the root path and every other path: forward the request to the same
# path on the upstream server. Keep this route last so that more specific
# routes registered earlier take precedence.
@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
)
async def root_proxy(path: str, request: Request):
    # An empty path (or a bare "/") maps to the upstream base URL itself.
    if not path or path == "/":
        url = TARGET_URL
    else:
        url = f"{TARGET_URL}/{path}"
    return await proxy_handler(url, request)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here so that importing this module does
    # not require it.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)