File size: 3,465 Bytes
27cae04
4903eb0
 
9d9b911
fde7a66
4903eb0
 
9d9b911
 
27cae04
4903eb0
72c54d6
 
27cae04
9d9b911
72c54d6
9d9b911
72c54d6
9d9b911
 
 
 
 
72c54d6
9d9b911
 
 
 
 
4903eb0
 
72c54d6
 
 
 
 
 
 
 
 
 
 
 
788951d
72c54d6
 
 
077da68
9d9b911
 
 
 
72c54d6
9d9b911
788951d
9d9b911
 
72c54d6
9d9b911
72c54d6
9d9b911
 
 
 
 
 
 
 
 
4903eb0
 
 
 
 
72c54d6
4903eb0
72c54d6
4903eb0
72c54d6
 
 
 
9d9b911
 
72c54d6
9d9b911
 
72c54d6
 
 
 
 
 
4903eb0
 
fde7a66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import httpx
import websockets
import asyncio
import uvicorn
from fastapi import FastAPI, Request, WebSocket, Response
from starlette.responses import StreamingResponse
from starlette.background import BackgroundTask
from contextlib import asynccontextmanager

# --- CONFIGURATION ---
# Upstream base URL every request is forwarded to (your private direct URL).
# Surrounding whitespace and trailing slashes are stripped so that appending a
# request path (which always starts with "/") never yields a double slash.
TARGET_URL = (os.environ.get("TARGET_URL") or "").strip().rstrip("/")
# Hugging Face read token placed in the upstream "Authorization" header.
# Stripped so a pasted trailing newline cannot produce an invalid header
# value; an unset or blank variable becomes None.
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None

if not TARGET_URL:
    raise ValueError("❌ TARGET_URL is missing.")

# Shared upstream HTTP client; assigned on startup by ``lifespan``.
http_client = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared upstream HTTP client on startup and close it on shutdown.

    The client is published through the module-level ``http_client`` so every
    request handler reuses one connection pool. ``async with`` guarantees the
    client is closed even when the application shuts down because of an error.
    """
    global http_client
    from http.cookiejar import CookieJar, DefaultCookiePolicy

    # One client serves every visitor, so it must not remember cookies: httpx
    # stores Set-Cookie values from responses in the client's jar and attaches
    # them to later requests that carry no Cookie header of their own, which
    # would hand one visitor's session to the next. A policy with an empty
    # allow-list neither stores nor returns any cookie. Visitors' own cookies
    # still travel in the forwarded Cookie request header, and upstream
    # Set-Cookie response headers are still passed back to the visitor.
    no_cookie_jar = CookieJar(policy=DefaultCookiePolicy(allowed_domains=[]))
    async with httpx.AsyncClient(
        timeout=60.0, follow_redirects=True, cookies=no_cookie_jar
    ) as client:
        http_client = client
        try:
            yield
        finally:
            # Drop the reference so nothing reuses the client once it is closed.
            http_client = None

app = FastAPI(lifespan=lifespan)

@app.middleware("http")
async def proxy_http(request: Request, call_next):
    """Forward every HTTP request to ``TARGET_URL`` and stream the reply back.

    Installed as middleware and never calls ``call_next``, so no local route is
    reached: method, path, query string, headers and body all go upstream.

    Header swap: the caller's ``Authorization`` value, if any, is moved to
    ``X-User-Auth``, and ``Authorization`` is set to the HF token (when one is
    configured) so the upstream gate accepts the request while the app behind
    it can still read the caller's credentials.

    Returns:
        A ``StreamingResponse`` carrying the upstream status, headers and
        decoded body, or a 502 ``Response`` if the upstream request fails.
    """
    url = f"{TARGET_URL.rstrip('/')}{request.url.path}"
    if request.url.query:
        url += f"?{request.url.query}"

    # 1. Prepare headers. Names are normalised to lower case so the swap below
    # replaces values; a mixed-case key would become a second, duplicate header.
    excluded = {
        "host", "content-length", "connection", "upgrade", "accept-encoding",
        "keep-alive", "proxy-connection", "transfer-encoding",
    }
    proxy_headers = {
        k.lower(): v for k, v in request.headers.items() if k.lower() not in excluded
    }

    # --- THE HEADER SWAP ---
    # A. Move the caller's token out of 'authorization' into 'x-user-auth'.
    user_auth = proxy_headers.pop("authorization", None)
    if user_auth is not None:
        proxy_headers["x-user-auth"] = user_auth
    # B. Put the HF token in the 'authorization' slot to pass the gate. With no
    # token configured, send none rather than the literal "Bearer None".
    if HF_TOKEN:
        proxy_headers["authorization"] = f"Bearer {HF_TOKEN}"
    # -----------------------

    # Only stream a request body when the client announced one; an async
    # iterator as content makes httpx send a chunked body even for plain GETs.
    has_body = "content-length" in request.headers or "transfer-encoding" in request.headers

    try:
        rp_req = http_client.build_request(
            method=request.method,
            url=url,
            headers=proxy_headers,
            # The Cookie header is already forwarded above; passing cookies=
            # as well is redundant (and deprecated per request in httpx).
            content=request.stream() if has_body else None,
        )
        rp_resp = await http_client.send(rp_req, stream=True)
    except Exception as e:
        return Response(f"Proxy Error: {str(e)}", status_code=502)

    # The client's Accept-Encoding is not forwarded, so httpx negotiates
    # compression with the upstream itself. aiter_bytes() yields the decoded
    # body, which is what dropping Content-Encoding/Content-Length requires
    # (aiter_raw() would pass compressed bytes without their encoding header).
    res_excluded = {
        "content-encoding", "content-length", "transfer-encoding", "connection",
        "keep-alive",
    }
    response = StreamingResponse(
        rp_resp.aiter_bytes(),
        status_code=rp_resp.status_code,
        background=BackgroundTask(rp_resp.aclose),
    )
    # multi_items() keeps repeated headers separate; headers.items() would join
    # several Set-Cookie values with ", ", which browsers cannot split apart.
    for key, value in rp_resp.headers.multi_items():
        if key.lower() not in res_excluded:
            response.headers.append(key, value)
    return response

async def _relay_frames(client, upstream):
    """Copy text/binary frames in both directions until either side stops.

    ``client`` is the Starlette ``WebSocket``; ``upstream`` is the websockets
    client connection. When one direction finishes (disconnect, close or
    error) the other is cancelled, and both tasks are awaited so their errors
    are collected instead of being reported as unretrieved task exceptions.
    """

    async def client_to_upstream():
        while True:
            message = await client.receive()
            if message["type"] == "websocket.disconnect":
                return
            if message.get("text") is not None:
                await upstream.send(message["text"])
            elif message.get("bytes") is not None:
                await upstream.send(message["bytes"])

    async def upstream_to_client():
        # Iteration ends when the upstream closes the connection normally.
        async for message in upstream:
            if isinstance(message, str):
                await client.send_text(message)
            else:
                await client.send_bytes(message)

    tasks = [
        asyncio.create_task(client_to_upstream()),
        asyncio.create_task(upstream_to_client()),
    ]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def _close_quietly(websocket, code=1000):
    """Close ``websocket``, ignoring errors if it is already closed or gone."""
    from starlette.websockets import WebSocketDisconnect

    try:
        await websocket.close(code=code)
    except (RuntimeError, OSError, WebSocketDisconnect):
        pass


@app.websocket_route("/{path:path}")
async def proxy_ws(websocket: WebSocket):
    """Bridge a client WebSocket to the same path (and query) on ``TARGET_URL``.

    The upstream connection is opened first so the client handshake can be
    accepted with the subprotocol the upstream selected, or rejected if the
    upstream cannot be reached. Frames are then relayed both ways until either
    side closes, after which the client socket is closed too.

    Applies the same header swap as the HTTP proxy: the caller's
    ``Authorization`` moves to ``X-User-Auth`` and the HF token takes its place.
    """
    base_url = TARGET_URL.rstrip("/")
    # Rewrite only the scheme prefix, not occurrences elsewhere in the URL.
    if base_url.startswith("https://"):
        base_url = "wss://" + base_url[len("https://"):]
    elif base_url.startswith("http://"):
        base_url = "ws://" + base_url[len("http://"):]
    target_ws_url = f"{base_url}/{websocket.path_params['path']}"
    if websocket.url.query:
        target_ws_url += f"?{websocket.url.query}"

    # The websockets client writes its own handshake headers; forwarding the
    # client's copies would duplicate them (two Host / Sec-WebSocket-Key
    # values, ...) and the upstream would refuse the upgrade.
    handshake_headers = {
        "host", "connection", "upgrade", "content-length",
        "sec-websocket-key", "sec-websocket-version",
        "sec-websocket-extensions", "sec-websocket-protocol",
    }
    headers = {
        k.lower(): v
        for k, v in websocket.headers.items()
        if k.lower() not in handshake_headers
    }
    connect_kwargs = {}
    if "user-agent" in headers:
        # Send the caller's User-Agent instead of adding websockets' own.
        connect_kwargs["user_agent_header"] = headers.pop("user-agent")

    # Same swap as the HTTP proxy; lower-case keys so values are replaced
    # rather than sent as duplicate headers.
    user_auth = headers.pop("authorization", None)
    if user_auth is not None:
        headers["x-user-auth"] = user_auth
    if HF_TOKEN:
        headers["authorization"] = f"Bearer {HF_TOKEN}"

    # Offer the client's requested subprotocols upstream.
    subprotocols = websocket.scope.get("subprotocols") or None

    try:
        async with websockets.connect(
            target_ws_url,
            additional_headers=headers,
            subprotocols=subprotocols,
            **connect_kwargs,
        ) as target_ws:
            await websocket.accept(subprotocol=target_ws.subprotocol)
            await _relay_frames(websocket, target_ws)
    except Exception:
        # Upstream unreachable, handshake refused, or the bridge failed. Before
        # accept() an ASGI close rejects the handshake; afterwards it closes
        # the socket with 1011 (internal error).
        await _close_quietly(websocket, code=1011)
        return
    await _close_quietly(websocket)

if __name__ == '__main__':
    # Script entry point: serve the proxy on every interface at port 7860.
    bind_host, bind_port = '0.0.0.0', 7860
    uvicorn.run(app, host=bind_host, port=bind_port)