# jdsan's picture
# deploy: sync
# 63b9729 verified
"""Thin public proxy in front of the private plot-digitizer backend.
Auth model (Option A):
  - Caller sends their own `Authorization: Bearer <HF_TOKEN>` — any HF read token.
  - Gateway validates it via https://huggingface.co/api/whoami-v2 and
    rate-limits per HF username.
  - Gateway forwards to the private backend with:
      Authorization: Bearer <GATEWAY_HF_TOKEN>   # unlocks the private Space
      X-Forwarded-User: <caller's HF username>   # attribution + backend rate-limit key
  - The caller's HF token is NEVER forwarded to the backend.
  - Trust boundary = HF's private-Space gate; only the gateway's
    GATEWAY_HF_TOKEN can reach the backend.
Gateway secrets (HF Space Secrets):
- GATEWAY_HF_TOKEN : fine-grained HF token with Read on the backend Space (the gateway falls back to HF_TOKEN when this is unset or empty).
- BACKEND_URL : e.g. https://jdsan-plot-digitizer.hf.space
This file is public. No secrets live in code.
"""
import os
import time
from typing import Dict, Tuple
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
# --- Required configuration, read once at import time ------------------------
# Trailing slashes are stripped so paths can be appended with a single "/".
BACKEND_URL = os.getenv("BACKEND_URL", "").rstrip("/")
# GATEWAY_HF_TOKEN wins; HF_TOKEN is consulted only when it is unset or empty.
GATEWAY_HF_TOKEN = os.getenv("GATEWAY_HF_TOKEN") or os.getenv("HF_TOKEN", "")

# Fail fast at import: the gateway cannot do anything useful without both.
if not BACKEND_URL:
    raise RuntimeError("BACKEND_URL env var not set on gateway")
if not GATEWAY_HF_TOKEN:
    raise RuntimeError("GATEWAY_HF_TOKEN env var not set on gateway")

# --- Caller identity lookup ---------------------------------------------------
WHOAMI_URL = "https://huggingface.co/api/whoami-v2"
WHOAMI_TTL_SECONDS = 300  # 5 min cache to avoid whoami on every request

# Maps a caller bearer token to (expiry as time.time() seconds, HF username).
_whoami_cache: Dict[str, Tuple[float, str]] = {}

# One shared async client for both whoami lookups and backend calls.
_http = httpx.AsyncClient(timeout=90)
async def _resolve_caller(bearer: str) -> str:
    """Return the HF username for a caller-supplied bearer token.

    A successful lookup is cached in ``_whoami_cache`` for
    ``WHOAMI_TTL_SECONDS`` so that repeat requests skip the whoami round trip.

    Args:
        bearer: The raw token taken from the caller's Authorization header.

    Returns:
        The ``name`` field reported by the whoami endpoint.

    Raises:
        HTTPException: 401 when the token is rejected or the response names
            no user; 502 when the whoami endpoint cannot be reached, is
            failing or throttling (5xx / 429), or returns a non-JSON body.
    """
    now = time.time()
    cached = _whoami_cache.get(bearer)
    if cached is not None:
        expires_at, cached_name = cached
        if expires_at > now:
            return cached_name
        # Drop the stale entry so a token that no longer validates does not
        # linger in the cache forever.
        _whoami_cache.pop(bearer, None)

    try:
        r = await _http.get(
            WHOAMI_URL,
            headers={"Authorization": f"Bearer {bearer}"},
            timeout=10,
        )
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="whoami upstream error") from exc

    # An outage or throttling on the identity provider is not the caller's
    # fault; do not report it as a bad token.
    if r.status_code >= 500 or r.status_code == 429:
        raise HTTPException(status_code=502, detail="whoami upstream error")
    if r.status_code != 200:
        raise HTTPException(status_code=401, detail="AUTH_FAILED — invalid HF token")

    try:
        data = r.json()
    except ValueError as exc:
        # A 200 with an unparsable body is an upstream fault, not a bad token.
        raise HTTPException(status_code=502, detail="whoami upstream error") from exc

    name = data.get("name") if isinstance(data, dict) else None
    if not isinstance(name, str) or not name:
        raise HTTPException(status_code=401, detail="AUTH_FAILED — token has no user")

    _whoami_cache[bearer] = (now + WHOAMI_TTL_SECONDS, name)
    return name
def _extract_bearer(request: Request) -> str:
    """Return the token from the request's ``Authorization: Bearer ...`` header.

    The scheme is matched case-insensitively and surrounding whitespace is
    removed from the token.

    Raises:
        HTTPException: 401 when the header is absent, uses another scheme, or
            carries no token after the scheme.
    """
    prefix = "bearer "
    auth = request.headers.get("Authorization", "")
    if auth[: len(prefix)].lower() != prefix:
        raise HTTPException(status_code=401, detail="missing Authorization: Bearer <HF token>")
    # Slice rather than split: a whitespace split of "Bearer   " produces a
    # single field, so indexing the second one would raise IndexError instead
    # of reaching the empty-token check below.
    token = auth[len(prefix):].strip()
    if not token:
        raise HTTPException(status_code=401, detail="empty bearer token")
    return token
def _rate_limit_key(request: Request) -> str:
    """Return the bucket key the rate limiter should charge this request to.

    Prefers ``hf:<username>`` when ``request.state.hf_username`` is populated;
    otherwise uses the client host, or ``"anon"`` when no client address is
    available.
    """
    # NOTE(review): the username is only visible here if it was stored on
    # request.state before the limiter calls this function.
    hf_user = getattr(request.state, "hf_username", None)
    if not hf_user:
        client = request.client
        return client.host if client else "anon"
    return f"hf:{hf_user}"
# Rate limiter keyed by _rate_limit_key (HF username when known, else client host).
limiter = Limiter(
    key_func=_rate_limit_key,
    default_limits=["100/day", "10/minute"],
)

# Interactive docs are switched off (docs_url / redoc_url set to None).
app = FastAPI(
    title="plot-digitizer-gateway",
    version="2.0.0",
    docs_url=None,
    redoc_url=None,
)

# Expose the limiter on app state and register slowapi's handler for
# RateLimitExceeded.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@app.get("/health")
async def health() -> dict:
    """Report that the gateway process is up; performs no auth and no backend call."""
    payload = {"status": "ok", "gateway": True}
    return payload
# Largest request body the gateway will forward to the backend (16 MiB).
_MAX_BODY_BYTES = 16 * 1024 * 1024


@limiter.limit("100/day;10/minute")
async def _forward_digitize(request: Request) -> Response:
    """Rate-limit an authenticated request and relay it to the backend.

    Expects ``request.state.hf_username`` to be set already. slowapi computes
    the rate-limit key before this body runs, so the username must be on the
    request at call time for the limit to be charged per user.

    Raises:
        HTTPException: 413 when the body exceeds ``_MAX_BODY_BYTES``; 504 when
            the backend times out; 502 on any other backend transport error.
    """
    username = request.state.hf_username

    # Reject on the declared length first so an oversized upload is refused
    # without buffering it into memory.
    declared = request.headers.get("Content-Length")
    if declared is not None:
        try:
            declared_too_big = int(declared) > _MAX_BODY_BYTES
        except ValueError:
            # Malformed header: rely on the measured check below.
            declared_too_big = False
        if declared_too_big:
            raise HTTPException(status_code=413, detail="payload too large")

    body = await request.body()
    # Content-Length can be absent (chunked upload), so measure as well.
    if len(body) > _MAX_BODY_BYTES:
        raise HTTPException(status_code=413, detail="payload too large")

    try:
        r = await _http.post(
            f"{BACKEND_URL}/v1/digitize",
            content=body,
            headers={
                # The gateway's own token; the caller's token is never sent on.
                "Authorization": f"Bearer {GATEWAY_HF_TOKEN}",
                "Content-Type": request.headers.get("Content-Type", "application/json"),
                "X-Forwarded-User": username,
            },
        )
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=504, detail="backend timeout") from exc
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="backend upstream error") from exc

    # Relay the backend's status, body and content type unchanged.
    return Response(
        content=r.content,
        status_code=r.status_code,
        media_type=r.headers.get("content-type"),
    )


@app.post("/v1/digitize")
async def digitize(request: Request) -> Response:
    """Authenticate the caller, then hand off to the rate-limited forwarder.

    Authentication runs here, outside the ``limiter.limit`` wrapper. With the
    decorator on this handler itself, the key function ran before
    ``hf_username`` existed and every request was keyed by client host.

    Raises:
        HTTPException: 401 for a missing or invalid token, 502 when identity
            lookup fails upstream, plus anything ``_forward_digitize`` raises.
    """
    bearer = _extract_bearer(request)
    username = await _resolve_caller(bearer)
    request.state.hf_username = username
    return await _forward_digitize(request)