# ppdone / app.py — Hugging Face Space source by lbb123
# Commit 9fda7d3 ("Create app.py", verified), 4.82 kB.
import asyncio
import json
import logging
import os
import random
import secrets
import time
from contextlib import asynccontextmanager

import aiohttp
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
# ================= Configuration (environment variables take precedence) =================
# Comma-separated upstream base URLs from $TARGET_URLS; blank entries are ignored.
# Falls back to the built-in list when the variable is unset or contains only blanks.
TARGET_URLS = [u.strip() for u in os.getenv("TARGET_URLS", "").split(",") if u.strip()] or [
    "https://lbb123-p01.hf.space/v1",
    "https://lbb123-p02.hf.space/v1",
    "https://lbb123-p03.hf.space/v1",
    "https://lbb123-p04.hf.space/v1",
    "https://lbb123-p05.hf.space/v1",
    "https://xunjunaa-pp01.hf.space/v1",
    "https://xunjunaa-pp02.hf.space/v1",
    "https://xunjunaa-pp03.hf.space/v1",
    "https://xunjunaa-pp04.hf.space/v1",
    "https://xunjunaa-pp05.hf.space/v1",
    "https://benwe2-ppp01.hf.space/v1",
    "https://benwe2-ppp02.hf.space/v1",
    "https://benwe2-ppp03.hf.space/v1",
    "https://benwe2-ppp04.hf.space/v1",
    "https://benwe2-ppp05.hf.space/v1",
]
# Shared secret that clients must present as "Bearer <key>".
# NOTE(review): "123456" is a weak, well-known fallback; set LOCAL_API_KEY in production.
LOCAL_KEY = os.getenv("LOCAL_API_KEY", "123456")
# Optional outbound HTTP proxy; an empty string is treated as "no proxy".
PROXY = os.getenv("PROXY") or None
# Maximum number of distinct backends tried per request. The value is clamped to >= 1:
# 0 would make every request fail without contacting any backend, and a negative value
# would silently drop backends from the candidate list (negative slice).
MAX_RETRY = max(1, int(os.getenv("MAX_RETRY", "3")))
# ================= Logging & backend health state =================
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
logger = logging.getLogger("lb")
# Backend base URL -> UNIX timestamp of its most recent failure (0 = never failed).
# pick_backends() reads this to skip recently failed backends; proxy() updates it on errors.
_fail_time = dict.fromkeys(TARGET_URLS, 0)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared aiohttp session for the lifetime of the FastAPI app.

    On startup, create one ``ClientSession`` and store it on
    ``app.state.session`` for the request handlers. The session allows at
    most 100 simultaneous connections, 10 per host. On shutdown, close it.
    """
    app.state.session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=100, limit_per_host=10)
    )
    # Never write the API key itself to the logs.
    logger.info(f"启动完成 | 后端:{len(TARGET_URLS)} 密钥:***")
    if LOCAL_KEY == "123456":
        logger.warning("LOCAL_API_KEY uses the well-known default value; set a strong key")
    try:
        yield
    finally:
        # Close even if the shutdown phase raised, so pooled connections are not leaked.
        await app.state.session.close()
# Application instance; `lifespan` creates and closes the shared HTTP session.
app = FastAPI(lifespan=lifespan)
# Permissive CORS: every origin, method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def pick_backends(cooldown: float = 30.0):
    """Return up to ``MAX_RETRY`` backend base URLs, in random order, for one request.

    A backend is eligible when its last failure recorded in ``_fail_time`` is
    more than ``cooldown`` seconds ago, or when it never failed. If every
    backend failed recently, all of them become eligible again, so the request
    still gets a chance instead of failing immediately.

    Args:
        cooldown: Seconds a backend is skipped after a failure (default 30).

    Returns:
        A new list of URLs. ``TARGET_URLS`` itself is never modified.
    """
    now = time.time()
    healthy = [u for u in TARGET_URLS if now - _fail_time.get(u, 0) > cooldown]
    # Copy before shuffling. Aliasing TARGET_URLS here would shuffle the global
    # configuration list in place.
    candidates = healthy or list(TARGET_URLS)
    random.shuffle(candidates)
    return candidates[:MAX_RETRY]
# Upstream statuses that mean "this backend is unavailable". The request is
# retried on the next backend instead of being relayed to the client.
_RETRYABLE_STATUS = frozenset({502, 503, 504})
# Incoming headers that are not copied to the upstream request. They are
# hop-by-hop, recomputed by aiohttp, or replaced in proxy().
_SKIP_HEADERS = frozenset({"host", "content-length", "connection", "authorization", "transfer-encoding"})


def _is_authorized(request: Request) -> bool:
    """Return True if the request carries ``Authorization: Bearer <LOCAL_KEY>``."""
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    # The auth scheme is case-insensitive, but the key itself must match exactly.
    # compare_digest avoids leaking the key through timing differences.
    return scheme.lower() == "bearer" and secrets.compare_digest(
        token.strip().encode(), LOCAL_KEY.encode()
    )


def _wants_stream(body: bytes) -> bool:
    """Return True if *body* is a JSON object whose ``stream`` field is truthy."""
    if not body:
        return False
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):  # not JSON / undecodable / absurdly nested
        return False
    return isinstance(payload, dict) and bool(payload.get("stream", False))


def _mark_failed(base: str, err) -> None:
    """Record *base* as failed now so pick_backends() skips it for a while."""
    _fail_time[base] = time.time()
    logger.warning(f"后端故障: {base} -> {err}")


async def _relay(resp):
    """Yield the upstream body chunk by chunk, always closing the upstream response."""
    try:
        async for chunk in resp.content.iter_any():
            yield chunk
    finally:
        resp.close()


@app.api_route("/v1", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy(path: str = "", request: Request = None):
    """Forward an authenticated ``/v1`` request to a backend, with failover.

    Up to ``MAX_RETRY`` backends from ``pick_backends()`` are tried in turn.
    A backend that raises a connection or timeout error, or that answers with
    a status in ``_RETRYABLE_STATUS``, is marked as failed and the next one is
    tried. If the JSON body has a truthy ``stream`` field, the upstream body is
    relayed as a ``StreamingResponse``. Otherwise it is read in full first.

    Returns:
        A 401 JSON response for a wrong API key, the upstream response
        (status and content type preserved) on success, or a 503 JSON
        response when every attempted backend failed.
    """
    if not _is_authorized(request):
        return JSONResponse(status_code=401, content={"error": "Invalid API key"})

    headers = {k: v for k, v in request.headers.items() if k.lower() not in _SKIP_HEADERS}
    # Backends are authenticated with the same shared key.
    headers["Authorization"] = f"Bearer {LOCAL_KEY}"
    body = await request.body()
    # multi_items() keeps repeated query parameters (?a=1&a=2) that dict() would collapse.
    params = request.query_params.multi_items()
    is_stream = _wants_stream(body)
    session = request.app.state.session

    last_err = None
    for base in pick_backends():
        url = base.rstrip("/") + (f"/{path}" if path else "")
        upstream = dict(
            method=request.method, url=url, headers=headers,
            data=body, params=params, proxy=PROXY,
            timeout=aiohttp.ClientTimeout(connect=10, sock_read=60),
        )
        try:
            if is_stream:
                resp = await session.request(**upstream)
                if resp.status in _RETRYABLE_STATUS:
                    resp.close()
                    last_err = f"HTTP {resp.status}"
                    _mark_failed(base, last_err)
                    continue
                return StreamingResponse(
                    _relay(resp),
                    # Preserve the upstream status; it was previously always 200.
                    status_code=resp.status,
                    media_type=resp.headers.get("content-type", "text/event-stream"),
                    headers={"cache-control": "no-cache", "connection": "keep-alive"},
                )
            async with session.request(**upstream) as resp:
                if resp.status in _RETRYABLE_STATUS:
                    last_err = f"HTTP {resp.status}"
                    _mark_failed(base, last_err)
                    continue
                return Response(
                    content=await resp.read(),
                    status_code=resp.status,
                    headers={"content-type": resp.headers.get("content-type", "application/json")},
                )
        except (aiohttp.ClientError, asyncio.TimeoutError, ConnectionRefusedError) as e:
            last_err = e
            _mark_failed(base, e)

    logger.error(f"全部后端不可用. 最后错误: {last_err}")
    return JSONResponse(status_code=503, content={"error": f"All backends failed: {last_err}"})
# NOTE: the local-startup entry point has been removed; Hugging Face Spaces launches this app automatically via uvicorn.