| import asyncio |
| import json |
| import logging |
| import os |
| import random |
| import time |
| from contextlib import asynccontextmanager |
|
|
| import aiohttp |
| from fastapi import FastAPI, Request, Response |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| |
# Upstream base URLs. TARGET_URLS may be a comma-separated list; surrounding
# whitespace is trimmed and empty entries are dropped. If nothing usable is
# supplied, the built-in pool below is used.
TARGET_URLS = list(
    filter(None, (part.strip() for part in os.getenv("TARGET_URLS", "").split(",")))
)
if not TARGET_URLS:
    TARGET_URLS = [
        "https://lbb123-p01.hf.space/v1",
        "https://lbb123-p02.hf.space/v1",
        "https://lbb123-p03.hf.space/v1",
        "https://lbb123-p04.hf.space/v1",
        "https://lbb123-p05.hf.space/v1",
        "https://xunjunaa-pp01.hf.space/v1",
        "https://xunjunaa-pp02.hf.space/v1",
        "https://xunjunaa-pp03.hf.space/v1",
        "https://xunjunaa-pp04.hf.space/v1",
        "https://xunjunaa-pp05.hf.space/v1",
        "https://benwe2-ppp01.hf.space/v1",
        "https://benwe2-ppp02.hf.space/v1",
        "https://benwe2-ppp03.hf.space/v1",
        "https://benwe2-ppp04.hf.space/v1",
        "https://benwe2-ppp05.hf.space/v1",
    ]

# Shared secret that clients must present as a Bearer token (also forwarded upstream).
LOCAL_KEY = os.environ.get("LOCAL_API_KEY", "123456")
# Optional outbound HTTP proxy for upstream calls; an empty value means "none".
PROXY = os.environ.get("PROXY") or None
# Maximum number of backends tried per incoming request.
MAX_RETRY = int(os.environ.get("MAX_RETRY", "3"))

logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
logger = logging.getLogger("lb")

# Per-backend timestamp (time.time()) of the most recent failure; 0 = never failed.
_fail_time = dict.fromkeys(TARGET_URLS, 0)
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared aiohttp client session for the lifetime of the app.

    On startup a single pooled ``ClientSession`` is created and stored on
    ``app.state.session``. The proxy handler reuses it for every upstream
    call. On shutdown the session is always closed, even if the
    application body raises while suspended at ``yield``.
    """
    app.state.session = aiohttp.ClientSession(
        # At most 100 open connections overall and 10 per backend host.
        connector=aiohttp.TCPConnector(limit=100, limit_per_host=10)
    )
    # Never write the API key itself to the logs. A short prefix of a
    # sufficiently long key is enough to tell which one is configured.
    masked_key = f"{LOCAL_KEY[:2]}****" if len(LOCAL_KEY) >= 8 else "****"
    logger.info(f"启动完成 | 后端:{len(TARGET_URLS)} 密钥:{masked_key}")
    try:
        yield
    finally:
        await app.state.session.close()
|
|
# ASGI application. The lifespan hook opens and closes the shared HTTP session.
app = FastAPI(lifespan=lifespan)

# Permissive CORS: any origin may call the proxy with any method or header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
def pick_backends(urls=None, fail_times=None, limit=None, cooldown=30):
    """Return a freshly shuffled list of backend base URLs to try for one request.

    Backends that failed within the last *cooldown* seconds are skipped. If
    that leaves nothing, every backend becomes eligible again, so the caller
    still has something to try whenever any backend is configured. At most
    *limit* candidates are returned, but never fewer than one while any
    candidate exists.

    Args:
        urls: Candidate base URLs. Defaults to the module's ``TARGET_URLS``.
        fail_times: Mapping of url -> timestamp (``time.time()``) of its last
            failure. Defaults to the module's ``_fail_time``.
        limit: Maximum number of candidates. Defaults to ``MAX_RETRY``.
        cooldown: Seconds a failed backend stays excluded.

    Returns:
        A new list. The *urls* sequence is never reordered in place.
    """
    if urls is None:
        urls = TARGET_URLS
    if fail_times is None:
        fail_times = _fail_time
    if limit is None:
        limit = MAX_RETRY

    now = time.time()
    healthy = [u for u in urls if now - fail_times.get(u, 0) > cooldown]
    # Copy before shuffling. Falling back to `urls` itself would shuffle the
    # shared TARGET_URLS list in place.
    candidates = healthy or list(urls)
    random.shuffle(candidates)
    # Clamp so a limit of 0 or a negative value still tries one backend,
    # rather than slicing to an empty list or trimming from the end.
    return candidates[:max(1, limit)]
|
|
# Incoming request headers that are not copied upstream. Host and the framing
# headers are recomputed by aiohttp for the new request. The client's
# Authorization header is replaced with the proxy's own key.
_SKIP_HEADERS = frozenset(
    {"host", "content-length", "connection", "authorization", "transfer-encoding"}
)


def _is_authorized(request):
    """Return True if *request* carries ``Authorization: Bearer <LOCAL_KEY>``.

    Only the scheme is case-insensitive (per RFC 7235). The token must match
    LOCAL_KEY exactly and is compared in constant time. Lower-casing the
    whole header would make keys case-insensitive, and would reject every
    key that contains upper-case letters.
    """
    import hmac

    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    if scheme.lower() != "bearer":
        return False
    # Compare bytes so non-ASCII header content cannot make compare_digest raise.
    return hmac.compare_digest(token.strip().encode(), LOCAL_KEY.encode())


def _wants_stream(body):
    """Return True if *body* is a JSON object whose ``stream`` field is truthy."""
    if not body:
        return False
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):  # not JSON / undecodable / absurdly nested
        return False
    return bool(payload.get("stream", False)) if isinstance(payload, dict) else False


async def _relay(resp):
    """Yield upstream body chunks as they arrive, and close *resp* when finished."""
    try:
        async for chunk in resp.content.iter_any():
            yield chunk
    finally:
        resp.close()


@app.api_route("/v1", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy(path: str = "", request: Request = None):
    """Forward ``/v1[/path]`` to an upstream backend, failing over on transport errors.

    The caller is authenticated against LOCAL_KEY. The backends returned by
    ``pick_backends()`` are then tried in order. A backend that raises a
    connection or timeout error has its failure time recorded in
    ``_fail_time`` (which ``pick_backends`` uses to skip it for a while), and
    the next backend is tried. Any HTTP response an upstream returns,
    including 4xx/5xx, is relayed to the client with its status code.

    Returns:
        A 401 JSONResponse for bad credentials. A StreamingResponse when the
        JSON body sets ``stream``. Otherwise a buffered Response. A 503
        JSONResponse when every attempted backend failed.
    """
    if not _is_authorized(request):
        return JSONResponse(status_code=401, content={"error": "Invalid API key"})

    headers = {k: v for k, v in request.headers.items() if k.lower() not in _SKIP_HEADERS}
    headers["Authorization"] = f"Bearer {LOCAL_KEY}"
    body = await request.body()
    # multi_items() keeps repeated keys (?a=1&a=2) that dict() would collapse.
    params = list(request.query_params.multi_items())
    is_stream = _wants_stream(body)

    session = request.app.state.session
    timeout = aiohttp.ClientTimeout(connect=10, sock_read=60)

    last_err = None
    for base in pick_backends():
        url = base.rstrip("/") + (f"/{path}" if path else "")
        try:
            if is_stream:
                resp = await session.request(
                    method=request.method, url=url, headers=headers,
                    data=body, params=params, proxy=PROXY, timeout=timeout,
                )
                # Once headers are returned, the response is committed.
                # Errors that occur mid-stream can no longer fail over.
                return StreamingResponse(
                    _relay(resp),
                    status_code=resp.status,
                    media_type=resp.headers.get("content-type", "text/event-stream"),
                    headers={"cache-control": "no-cache", "connection": "keep-alive"},
                )

            async with session.request(
                method=request.method, url=url, headers=headers,
                data=body, params=params, proxy=PROXY, timeout=timeout,
            ) as resp:
                return Response(
                    content=await resp.read(),
                    status_code=resp.status,
                    headers={"content-type": resp.headers.get("content-type", "application/json")},
                )

        except (aiohttp.ClientError, asyncio.TimeoutError, ConnectionRefusedError) as e:
            last_err = e
            _fail_time[base] = time.time()
            logger.warning(f"后端故障: {base} -> {e}")
            continue

    logger.error(f"全部后端不可用. 最后错误: {last_err}")
    return JSONResponse(status_code=503, content={"error": f"All backends failed: {last_err}"})
|
|
| |