"""Load-balancing reverse proxy in front of a pool of ``/v1`` HTTP backends.

Every request under ``/v1`` is authenticated against ``LOCAL_API_KEY`` and
then forwarded to up to ``MAX_RETRY`` randomly chosen backends. A backend
that raises a connection-level error is put on a short cool-down and the
next candidate is tried. Requests whose JSON body has a truthy ``"stream"``
field are relayed chunk by chunk; all other responses are buffered.
"""
import asyncio
import hmac
import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator, List

import aiohttp
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

# ================= Configuration (environment variables take precedence) =================
TARGET_URLS = [u.strip() for u in os.getenv("TARGET_URLS", "").split(",") if u.strip()] or [
    "https://lbb123-p01.hf.space/v1",
    "https://lbb123-p02.hf.space/v1",
    "https://lbb123-p03.hf.space/v1",
    "https://lbb123-p04.hf.space/v1",
    "https://lbb123-p05.hf.space/v1",
    "https://xunjunaa-pp01.hf.space/v1",
    "https://xunjunaa-pp02.hf.space/v1",
    "https://xunjunaa-pp03.hf.space/v1",
    "https://xunjunaa-pp04.hf.space/v1",
    "https://xunjunaa-pp05.hf.space/v1",
    "https://benwe2-ppp01.hf.space/v1",
    "https://benwe2-ppp02.hf.space/v1",
    "https://benwe2-ppp03.hf.space/v1",
    "https://benwe2-ppp04.hf.space/v1",
    "https://benwe2-ppp05.hf.space/v1",
]
LOCAL_KEY = os.getenv("LOCAL_API_KEY", "123456")
PROXY = os.getenv("PROXY") or None
MAX_RETRY = int(os.getenv("MAX_RETRY", "3"))

# Seconds a backend is skipped after a connection-level failure.
_COOLDOWN_SECONDS = 30

# Timeout applied to every upstream request (connect phase / gap between reads).
_UPSTREAM_TIMEOUT = aiohttp.ClientTimeout(connect=10, sock_read=60)

# Request headers that are never copied to the upstream request.
# "accept-encoding" is dropped so aiohttp advertises only the encodings it can
# decode itself; forwarding the client's value (e.g. "br") could make an
# otherwise healthy backend answer with a body aiohttp fails to decompress.
_SKIPPED_REQUEST_HEADERS = frozenset(
    {
        "host",
        "content-length",
        "connection",
        "authorization",
        "transfer-encoding",
        "accept-encoding",
    }
)

# ================= Logging & health state =================
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger("lb")

# Backend URL -> time.time() of its most recent failure (0 = never failed).
_fail_time = dict.fromkeys(TARGET_URLS, 0)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Create the shared aiohttp session on startup and close it on shutdown."""
    app.state.session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=100, limit_per_host=10)
    )
    # NOTE(review): this line writes the API key to the log in plain text.
    # Kept as-is to preserve the existing log output; consider masking it.
    logger.info("启动完成 | 后端:%d 密钥:%s", len(TARGET_URLS), LOCAL_KEY)
    try:
        yield
    finally:
        await app.state.session.close()


app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def pick_backends() -> List[str]:
    """Return up to ``MAX_RETRY`` backends in random order.

    Backends that failed within the last ``_COOLDOWN_SECONDS`` seconds are
    skipped. If every backend is cooling down, the full list is used so the
    request is still attempted.
    """
    now = time.time()
    healthy = [u for u in TARGET_URLS if now - _fail_time.get(u, 0) > _COOLDOWN_SECONDS]
    pool = healthy or TARGET_URLS
    # Shuffle a copy: shuffling ``TARGET_URLS`` itself would silently reorder
    # the module-level configuration on every fallback.
    shuffled = random.sample(pool, len(pool))
    return shuffled[:MAX_RETRY]


def _is_authorized(header_value: str) -> bool:
    """Return True if *header_value* is ``Bearer <LOCAL_KEY>``.

    The scheme is matched case-insensitively; the token is compared exactly
    and in constant time. (Lower-casing the whole header, as was done before,
    made any key containing an uppercase letter impossible to match.)
    """
    scheme, _, token = header_value.partition(" ")
    if scheme.lower() != "bearer":
        return False
    return hmac.compare_digest(token.strip().encode("utf-8"), LOCAL_KEY.encode("utf-8"))


def _wants_stream(body: bytes) -> bool:
    """Return True if *body* is a JSON object with a truthy ``"stream"`` field."""
    if not body:
        return False
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # Not JSON (or not decodable): treat as a regular, buffered request.
        return False
    return isinstance(payload, dict) and bool(payload.get("stream", False))


async def _relay_stream(upstream: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
    """Yield the upstream body as it arrives, closing the response afterwards."""
    try:
        async for chunk in upstream.content.iter_any():
            yield chunk
    finally:
        upstream.close()


@app.api_route("/v1", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy(path: str = "", request: Request = None):
    """Authenticate the caller and forward the request to a backend.

    Args:
        path: Remainder of the URL after ``/v1``; appended to the backend base URL.
        request: The incoming request (injected by FastAPI).

    Returns:
        401 JSON if the ``Authorization`` header does not carry ``LOCAL_KEY``;
        otherwise the first backend response obtained (streamed when the body
        asks for ``"stream"``), with the upstream status code; or 503 JSON
        when every selected backend raised a connection-level error.
    """
    if not _is_authorized(request.headers.get("authorization", "")):
        return JSONResponse(status_code=401, content={"error": "Invalid API key"})

    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in _SKIPPED_REQUEST_HEADERS
    }
    headers["Authorization"] = f"Bearer {LOCAL_KEY}"

    body = await request.body()
    is_stream = _wants_stream(body)

    session = request.app.state.session
    request_kwargs = dict(
        method=request.method,
        headers=headers,
        data=body,
        # multi_items() keeps repeated query keys (?a=1&a=2); dict() would
        # keep only one value per key.
        params=request.query_params.multi_items(),
        proxy=PROXY,
        timeout=_UPSTREAM_TIMEOUT,
    )

    last_err = None
    for base in pick_backends():
        url = base.rstrip("/") + (f"/{path}" if path else "")
        try:
            if is_stream:
                # No ``async with``: the response must outlive this function
                # and is closed by _relay_stream once the body is consumed.
                upstream = await session.request(url=url, **request_kwargs)
                return StreamingResponse(
                    _relay_stream(upstream),
                    # Propagate the upstream status instead of always sending 200.
                    status_code=upstream.status,
                    media_type=upstream.headers.get("content-type", "text/event-stream"),
                    headers={"cache-control": "no-cache", "connection": "keep-alive"},
                )

            async with session.request(url=url, **request_kwargs) as upstream:
                payload = await upstream.read()
                return Response(
                    content=payload,
                    status_code=upstream.status,
                    headers={
                        "content-type": upstream.headers.get("content-type", "application/json")
                    },
                )
        except (aiohttp.ClientError, asyncio.TimeoutError, ConnectionRefusedError) as e:
            last_err = e
            _fail_time[base] = time.time()
            logger.warning("后端故障: %s -> %s", base, e)
            continue

    logger.error("全部后端不可用. 最后错误: %s", last_err)
    return JSONResponse(status_code=503, content={"error": f"All backends failed: {last_err}"})


# NOTE: the local startup code that used to follow was removed; HF launches the app via uvicorn automatically.