|
|
import os |
|
|
import httpx |
|
|
import websockets |
|
|
import asyncio |
|
|
import uvicorn |
|
|
from fastapi import FastAPI, Request, WebSocket, Response |
|
|
from starlette.responses import StreamingResponse |
|
|
from starlette.background import BackgroundTask |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
|
|
|
# Base URL of the upstream service every request is forwarded to.
# Surrounding whitespace and trailing slashes are removed because the handlers
# build URLs as f"{TARGET_URL}{path}" / f"{ws_url}/{path}"; a trailing slash
# here would otherwise produce "//" in every forwarded URL.
TARGET_URL = (os.environ.get("TARGET_URL") or "").strip().rstrip("/")

# Token injected as "Authorization: Bearer <token>" on forwarded requests.
# NOTE(review): nothing validates this value; when the variable is unset the
# handlers will send the literal string "Bearer None" -- confirm whether a
# missing token should be a startup error.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast at import time: the proxy cannot do anything without a target.
if not TARGET_URL:
    raise ValueError("❌ TARGET_URL is missing.")

# Shared httpx.AsyncClient; stays None until the application lifespan starts.
http_client = None
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared HTTP client on startup and close it on shutdown.

    The client is stored in the module-level ``http_client`` so the request
    handlers can reuse one connection pool. It is configured with a 60 second
    timeout and follows redirects.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global http_client
    http_client = httpx.AsyncClient(timeout=60.0, follow_redirects=True)
    try:
        yield
    finally:
        # Runs even when an exception or cancellation is thrown in at the
        # yield point, so the connection pool is never left open.
        await http_client.aclose()


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
@app.middleware("http")
async def proxy_http(request: Request, call_next):
    """Forward an incoming HTTP request to TARGET_URL and stream the reply back.

    The middleware never calls ``call_next``: every HTTP request is answered
    from the upstream service instead of from local routes.

    The caller's ``Authorization`` header (if any) is moved to ``X-User-Auth``
    and replaced with ``Bearer {HF_TOKEN}``.

    Args:
        request: The incoming request; its method, path, query string,
            headers, cookies and body stream are forwarded.
        call_next: The next ASGI handler (intentionally unused).

    Returns:
        A streaming response mirroring the upstream status, headers and body,
        or a 502 ``Response`` with the error text if the upstream call fails.
    """
    url = f"{TARGET_URL}{request.url.path}"
    if request.url.query:
        url += f"?{request.url.query}"

    # Headers that describe the client<->proxy connection and must not be
    # replayed upstream. Dropping accept-encoding lets httpx negotiate its own.
    excluded = {"host", "content-length", "connection", "upgrade", "accept-encoding"}
    # Keys are lowercased so the lookups and overwrites below cannot create a
    # second header that differs only by case.
    proxy_headers = {
        name.lower(): value
        for name, value in request.headers.items()
        if name.lower() not in excluded
    }

    # Remove the caller's credential before injecting ours; leaving both
    # "authorization" and "Authorization" in the dict would send two
    # Authorization headers upstream.
    user_auth = proxy_headers.pop("authorization", None)
    if user_auth is not None:
        proxy_headers["x-user-auth"] = user_auth
    proxy_headers["authorization"] = f"Bearer {HF_TOKEN}"

    # NOTE(review): http_client is shared by all requests and httpx clients
    # keep a cookie jar -- confirm upstream Set-Cookie values cannot leak
    # between different callers.
    try:
        rp_req = http_client.build_request(
            method=request.method,
            url=url,
            headers=proxy_headers,
            content=request.stream(),
            cookies=request.cookies,
        )
        rp_resp = await http_client.send(rp_req, stream=True)
    except Exception as e:
        return Response(f"Proxy Error: {e}", status_code=502)

    # The body is re-streamed decoded and re-framed by the ASGI server, so the
    # upstream's encoding/framing headers no longer describe it.
    res_excluded = {"content-encoding", "content-length", "transfer-encoding", "connection"}

    response = StreamingResponse(
        # aiter_bytes() yields *decoded* content. Since content-encoding is
        # stripped above, streaming the raw (possibly gzip/br) bytes would hand
        # the client a body it cannot interpret.
        rp_resp.aiter_bytes(),
        status_code=rp_resp.status_code,
        # Close the upstream response once the body has been sent.
        background=BackgroundTask(rp_resp.aclose),
    )
    # multi_items() keeps repeated headers (notably Set-Cookie) as separate
    # entries instead of joining them with commas.
    for name, value in rp_resp.headers.multi_items():
        if name.lower() not in res_excluded:
            response.headers.append(name, value)
    return response
|
|
|
|
|
@app.websocket_route("/{path:path}")
async def proxy_ws(websocket: WebSocket):
    """Relay a WebSocket session between the client and TARGET_URL.

    The client connection is accepted first, then an upstream connection is
    opened to the same path (and query string) on TARGET_URL with its scheme
    switched to ws/wss. Text and binary messages are forwarded in both
    directions until either side ends; the other direction is then stopped and
    the client socket is closed.

    The caller's ``Authorization`` header (if any) is moved to ``X-User-Auth``
    and replaced with ``Bearer {HF_TOKEN}``.

    Args:
        websocket: The incoming client WebSocket.
    """
    await websocket.accept()

    ws_url = TARGET_URL.replace("https://", "wss://").replace("http://", "ws://")
    target_ws_url = f"{ws_url}/{websocket.path_params['path']}"
    if websocket.url.query:
        # Keep the query string, matching what the HTTP proxy forwards.
        target_ws_url += f"?{websocket.url.query}"

    # The websockets client generates these itself for the upstream handshake;
    # replaying the client's copies would send conflicting duplicates.
    # NOTE(review): sec-websocket-protocol is still forwarded as a plain
    # header, but subprotocol negotiation is not actually proxied -- confirm
    # whether upstream relies on subprotocols.
    handshake_headers = {
        "host",
        "connection",
        "upgrade",
        "content-length",
        "sec-websocket-key",
        "sec-websocket-version",
        "sec-websocket-extensions",
    }
    headers = {
        name: value
        for name, value in websocket.headers.items()
        if name.lower() not in handshake_headers
    }
    # Pop the caller's credential so only one Authorization header is sent.
    user_auth = headers.pop("authorization", None)
    if user_auth is not None:
        headers["X-User-Auth"] = user_auth
    headers["Authorization"] = f"Bearer {HF_TOKEN}"

    close_code = 1000
    try:
        async with websockets.connect(target_ws_url, additional_headers=headers) as target_ws:

            async def client_to_upstream():
                # Read raw ASGI messages so both text and binary frames work.
                while True:
                    message = await websocket.receive()
                    if message["type"] == "websocket.disconnect":
                        return
                    if message.get("text") is not None:
                        await target_ws.send(message["text"])
                    elif message.get("bytes") is not None:
                        await target_ws.send(message["bytes"])

            async def upstream_to_client():
                # Starlette's WebSocket.send() expects a raw ASGI message, so
                # payloads must go through send_text()/send_bytes().
                while True:
                    data = await target_ws.recv()
                    if isinstance(data, (bytes, bytearray)):
                        await websocket.send_bytes(bytes(data))
                    else:
                        await websocket.send_text(data)

            relays = [
                asyncio.ensure_future(client_to_upstream()),
                asyncio.ensure_future(upstream_to_client()),
            ]
            try:
                # Stop as soon as either direction ends (disconnect or error).
                await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for relay in relays:
                    relay.cancel()
                # Collect results so relay errors are retrieved rather than
                # reported as never-retrieved task exceptions.
                await asyncio.gather(*relays, return_exceptions=True)
    except Exception:
        # Upstream connect/handshake failure or unexpected relay setup error;
        # report an internal error to the client instead of a normal closure.
        close_code = 1011
    finally:
        try:
            await websocket.close(code=close_code)
        except RuntimeError:
            # The client socket is already closed.
            pass
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly: serve the proxy app on every interface, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|