bandenamaj commited on
Commit
e2b838f
·
verified ·
1 Parent(s): 22cc708

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +29 -33
app.py CHANGED
@@ -1,52 +1,48 @@
1
  from fastapi import FastAPI, Request, Response
2
  from fastapi.responses import StreamingResponse
3
- import requests
4
  import os
5
 
6
  app = FastAPI()
7
  HF_TOKEN = os.environ.get("HF_TOKEN")
8
  BASE_URL = "https://huggingface.co"
9
 
10
- def get_proxy_response(path: str, query: str = "", method: str = "GET", headers: dict = {}, body: bytes = None):
 
 
 
 
 
 
 
 
11
  url = f"{BASE_URL}/{path}"
12
- if query:
13
- url += f"?{query}"
 
 
 
 
 
 
 
14
 
15
- # Filter headers to avoid conflicts
16
- filtered_headers = {k: v for k, v in headers.items() if k.lower() not in ["host", "content-length", "connection"]}
17
- if HF_TOKEN and "authorization" not in [k.lower() for k in filtered_headers.keys()]:
18
- filtered_headers["Authorization"] = f"Bearer {HF_TOKEN}"
19
-
20
  try:
21
- resp = requests.request(
22
- method=method,
23
- url=url,
24
- headers=filtered_headers,
25
- data=body,
26
- stream=True,
27
- allow_redirects=True
28
- )
29
 
30
- # Exclude headers that FastAPI/Uvicorn might handle or conflict with
31
  excluded = ["content-encoding", "transfer-encoding", "connection", "keep-alive"]
32
- resp_headers = {k: v for k, v in resp.headers.items() if k.lower() not in excluded}
33
 
34
  return StreamingResponse(
35
- resp.iter_content(chunk_size=128*1024),
36
- status_code=resp.status_code,
37
  headers=resp_headers,
38
- media_type=resp.headers.get("content-type")
39
  )
40
  except Exception as e:
 
41
  return Response(content=str(e), status_code=500)
42
-
43
- @app.get("/resolve/{repo_id:path}/{filename:path}")
44
- async def old_resolve_compat(repo_id: str, filename: str):
45
- """Keep compatibility with the previous /resolve structure."""
46
- return get_proxy_response(f"{repo_id}/resolve/main/{filename}")
47
-
48
- @app.api_route("/{path:path}", methods=["GET", "HEAD", "OPTIONS", "POST"])
49
- async def catch_all(request: Request, path: str):
50
- """Universal proxy to Hugging Face."""
51
- body = await request.body() if request.method == "POST" else None
52
- return get_proxy_response(path, request.url.query, request.method, dict(request.headers), body)
 
1
  from fastapi import FastAPI, Request, Response
2
  from fastapi.responses import StreamingResponse
3
+ import httpx
4
  import os
5
 
6
  app = FastAPI()
7
  HF_TOKEN = os.environ.get("HF_TOKEN")
8
  BASE_URL = "https://huggingface.co"
9
 
10
+ # Create a shared async client
11
+ client = httpx.AsyncClient(base_url=BASE_URL, follow_redirects=True)
12
+
13
+ @app.get("/")
14
+ def home():
15
+ return {"status": "running", "message": "Universal HF Proxy Active", "token_detected": bool(HF_TOKEN)}
16
+
17
+ @app.api_route("/{path:path}", methods=["GET", "HEAD", "OPTIONS", "POST"])
18
+ async def proxy(request: Request, path: str):
19
  url = f"{BASE_URL}/{path}"
20
+ if request.url.query:
21
+ url += f"?{request.url.query}"
22
+
23
+ print(f"Proxying: {request.method} {url}")
24
+
25
+ # Filter headers
26
+ headers = {k: v for k, v in request.headers.items() if k.lower() not in ["host", "content-length", "connection"]}
27
+ if HF_TOKEN and "authorization" not in [k.lower() for k in headers.keys()]:
28
+ headers["Authorization"] = f"Bearer {HF_TOKEN}"
29
 
 
 
 
 
 
30
  try:
31
+ # Build request
32
+ body = await request.body()
33
+ rp_req = client.build_request(request.method, url, headers=headers, content=body)
34
+ rp_resp = await client.send(rp_req, stream=True)
 
 
 
 
35
 
36
+ # Filter response headers
37
  excluded = ["content-encoding", "transfer-encoding", "connection", "keep-alive"]
38
+ resp_headers = {k: v for k, v in rp_resp.headers.items() if k.lower() not in excluded}
39
 
40
  return StreamingResponse(
41
+ rp_resp.aiter_raw(),
42
+ status_code=rp_resp.status_code,
43
  headers=resp_headers,
44
+ media_type=rp_resp.headers.get("content-type")
45
  )
46
  except Exception as e:
47
+ print(f"Proxy Error: {e}")
48
  return Response(content=str(e), status_code=500)