labs / src /app.py
3v324v23's picture
deploy: unified router + dreamy website (2026-06-16T09:46:52Z)
c1a683f
Raw
History Blame Contribute Delete
13.4 kB
"""Application factory: loads config, wires the router, mounts auth gate.""" # module docstring
from __future__ import annotations # defer annotation evaluation (PEP 563)
import logging # structured logging config + named logger
import os # read env vars for KEY (auth) and IPS (whitelist)
import json # stdlib json for _peek_model (free-model balance gate)
from pathlib import Path # locate the data directory relative to this file
from fastapi import FastAPI # the app factory target
from fastapi.middleware.cors import CORSMiddleware # CORS for browser clients
from fastapi.responses import JSONResponse # JSON error responses from the gate
from fastapi.staticfiles import StaticFiles # serve the dreamy website (html/css/js)
from .models.config import ModelConfig # loads models.json + resolves names
from .routes import proxy, auth as auth_routes # proxy + auth (google oauth) routers
from .routes import me as me_routes # per-user billing API (/api/me/*)
from .routes import pricing as pricing_routes # public pricing catalog (/api/pricing)
from .services.forwarder import make_client, set_shared_client # shared httpx connection pool
from .auth import session as sess # session cookie verify for the gate
# Root logging is configured once, at import time, so every module logger
# shares the same line format: timestamp, padded level, logger name, message.
logging.basicConfig(
    format="%(asctime)s %(levelname)-7s %(name)s: %(message)s",
    level=logging.INFO,  # INFO and above (warnings/errors always show)
)
logger = logging.getLogger("router.app")  # named logger for this module

# This file lives in src/, so the repository root is two path components up.
BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = BASE_DIR / "data"  # holds models.json

# Master Bearer key from $KEY ("" when the variable is unset).
PROXY_KEY = os.environ.get("KEY", "")

# Optional IP allowlist from $IPS: one address per line, whitespace trimmed,
# blank lines dropped. An empty set means "no allowlist".
WHITELIST_IPS = {
    address
    for address in map(str.strip, os.environ.get("IPS", "").splitlines())
    if address
}
def get_client_ip(request, trusted_hops: int = 0) -> str:  # client IP, optionally proxy-aware
    """Return the client IP for *request*, consulting X-Forwarded-For.

    Args:
        request: the incoming request; only ``.headers`` and ``.client``
            (with ``.host``) are read.
        trusted_hops: number of trusted reverse proxies in front of the app
            that append to X-Forwarded-For. ``0`` (the default) keeps the
            legacy behaviour of returning the leftmost entry. ``n >= 1``
            returns the n-th entry from the right, i.e. the address the
            outermost trusted proxy actually saw. Chains shorter than ``n``
            are clamped to the leftmost entry.

    Returns:
        The selected IP string, or ``""`` when no address is known.

    SECURITY: the leftmost X-Forwarded-For entry is whatever the client sent,
    so with ``trusted_hops=0`` an IP allowlist built on this value can be
    bypassed by a forged header. Pass the real proxy depth to get an address
    a client cannot choose.
    """
    xff = request.headers.get("x-forwarded-for", "")  # the proxy chain header
    # Drop blank entries (", 1.2.3.4", "a,,b", trailing commas) so an empty
    # string can never be chosen as the client address.
    hops = [part.strip() for part in xff.split(",") if part.strip()]
    if hops:  # we're behind at least one proxy that set the header
        if trusted_hops > 0:
            # Rightmost entries were appended by our own proxies.
            return hops[max(len(hops) - trusted_hops, 0)]
        return hops[0]  # legacy: the claimed original client
    # No usable forwarding info: fall back to the socket peer (empty if unknown).
    return request.client.host if request.client else ""
def extract_bearer(auth_header: str) -> str:  # pull the token out of "Bearer <token>"
    """Return the token from an ``Authorization: Bearer <token>`` value, or "".

    Never raises. A missing/empty header, a non-Bearer scheme, or a header
    with no space after the scheme all yield "", which callers treat as an
    auth failure. The scheme comparison is case-insensitive and the token
    is whitespace-trimmed.
    """
    if not auth_header:  # None or "" -> no credentials
        return ""
    # Split on the FIRST space only; everything after it is the credential.
    scheme, separator, credentials = auth_header.partition(" ")
    is_bearer = bool(separator) and scheme.lower() == "bearer"
    return credentials.strip() if is_bearer else ""
def _peek_model(body_bytes: bytes) -> str:  # read "model" out of a raw request body, best effort
    """Return the JSON body's ``model`` field as a string, or "" if unavailable.

    Used by the free-model balance gate. Every failure mode (no body, body
    that is not valid JSON, a non-object top-level value, a missing or falsy
    ``model``) yields "" instead of raising.
    """
    if not body_bytes:  # bodiless requests (e.g. GET/DELETE) carry no model
        return ""
    try:
        payload = json.loads(body_bytes)
    except Exception:  # undecodable / malformed body: nothing to peek at
        return ""
    if not isinstance(payload, dict):  # arrays and scalars have no "model" key
        return ""
    model = payload.get("model")
    return str(model) if model else ""
def create_app() -> FastAPI:  # application factory (lets tests configure cleanly)
    """Build and return the fully wired FastAPI application.

    Loads models.json from DATA_DIR and creates the shared HTTP client pool,
    handing both to the proxy/auth/pricing modules. Registers startup and
    shutdown hooks (billing registry + usage tracker, pool cleanup) and the
    IP-allowlist / auth / billing gate middleware, with CORS wrapped around
    it. Mounts the API routers and, when the directory exists, the static
    website at "/".
    """
    app = FastAPI(title="router", docs_url=None, redoc_url=None)  # no docs endpoints exposed
    config = ModelConfig(DATA_DIR)  # load models.json into the config
    client = make_client()  # create the shared connection pool
    set_shared_client(client)  # expose it to auth routes (no circular import)
    proxy.init_router(config, client)  # wire the router with the loaded config + pool
    pricing_routes.init_config(config)  # wire the pricing catalog with the loaded config
    logger.info("loaded %d models, %d total targets",  # log startup summary
                len(config.models), config.total_targets())

    @app.on_event("startup")  # runs once the server is ready to accept requests
    async def startup():  # async startup tasks (needs the event loop)
        """Load the billing registry and start the usage tracker when billing is enabled."""
        from .services.billing import gate as billing_gate, registry as billing_registry  # local import
        if not billing_gate.is_enabled():  # billing secrets absent: nothing to start
            return
        try:  # bootstrap/load the registry gist (creates it on first-ever run)
            await billing_registry.get_or_create()  # load users into the in-memory cache
            # NOTE(review): reads the registry's private _CACHE; a public count
            # accessor on the registry module would be cleaner.
            logger.info("billing enabled: %d users in registry",
                        len(billing_registry._CACHE))
            from .services.billing.tracker import TRACKER  # background coroutines
            TRACKER.start()  # launch the usage flusher (persists balances every 30s)
        except Exception as exc:  # registry load failed (bad token, gist gone, etc.)
            logger.error("billing registry load failed: %s — running UNMETERED", exc)  # degrade gracefully

    @app.on_event("shutdown")  # runs on app termination (SIGTERM/SIGINT)
    async def shutdown():  # clean up resources before exit
        """Stop the usage tracker, then close the shared connection pool."""
        try:
            from .services.billing.tracker import TRACKER  # background coroutines
            await TRACKER.stop()  # cancel the flusher + future coroutines
        finally:
            # Always release pooled connections, even if stopping the tracker fails.
            await client.aclose()
            logger.info("connection pool closed")  # confirm shutdown

    @app.get("/api/status")  # dashboard stats (public, no auth) — moved off "/" for static hosting
    async def status():  # lightweight status summary for the dashboard
        return {
            "service": "router",  # service name
            "models": len(config.models),  # how many models are loaded
            "targets": config.total_targets(),  # total targets across all models
            "down": proxy.ROUTER.health.down_count(),  # currently-cooled-down targets
            "model_list": list(config.models.keys()),  # all model names
            "google": auth_routes.gg.is_configured(),  # is Google login enabled?
        }

    @app.get("/health")  # health check endpoint (no auth)
    async def health():  # used by HF Space / uptime monitors
        return {"status": "ok", "models": len(config.models),  # status + counts
                "down": proxy.ROUTER.health.down_count()}  # currently down

    @app.middleware("http")  # runs before every route (CORS wraps it, see below)
    async def gate(request, call_next):  # auth + IP whitelist + billing gate
        """Enforce the IP allowlist, then auth + billing for /v1/* proxy calls."""
        client_ip = get_client_ip(request)  # client IP (consults XFF)
        if WHITELIST_IPS and client_ip not in WHITELIST_IPS:  # whitelist set and caller isn't on it
            logger.warning("forbidden IP: %s", client_ip)  # log the rejection
            return JSONResponse({"error": "forbidden"}, status_code=403)  # reject

        path = request.url.path  # the request path
        # ONLY proxy endpoints under /v1/ (except the public models list) require auth.
        # Everything else (website pages, css/js, /auth/*, /health, /api/status) is public.
        needs_auth = path.startswith("/v1/") and path != "/v1/models"  # API proxy calls only
        if not needs_auth:  # public path
            return await call_next(request)  # no auth required

        # --- resolve caller identity (3 tiers: admin key > user key > session) ---
        from .services.billing import gate as billing_gate, provision as billing_provision  # lazy import
        from .services.billing import registry as billing_registry  # user lookup
        user_id = None  # the billing user_id (None = not a billing user)
        is_admin = False  # admin tier (master key) bypasses billing
        session_payload = None  # the decoded session (for lazy-provision identity)

        # tier 1+2: Bearer token (master admin key OR a user dr_u_ key)
        auth_header = request.headers.get("authorization", "")  # the Authorization header
        resolved_uid, resolved_admin = billing_gate.resolve(auth_header, PROXY_KEY)  # resolve
        if resolved_admin:  # master key matched
            is_admin = True  # admin: unmetered
        elif resolved_uid:  # a valid user API key
            user_id = resolved_uid  # billing user identified

        # tier 3: Google session cookie (dashboard users)
        if user_id is None and not is_admin:  # no bearer resolved; try the session
            cookie = request.cookies.get(sess.COOKIE_NAME, "")  # read the session cookie
            session_payload = sess.verify_session(cookie)  # validate signature + expiry
            if session_payload:  # valid logged-in user
                # A payload without a usable uid must not become a billing
                # identity (or crash with KeyError); treat it as anonymous.
                user_id = session_payload.get("uid") or None  # google subject = billing user_id

        if user_id is None and not is_admin:  # no identity resolved at all
            return JSONResponse({"error": "unauthorized"}, status_code=401)  # reject

        # --- billing gate (users only; admin skips) ---
        if billing_gate.is_enabled() and user_id is not None:  # metered tier
            # lazy-provision: if the user has a session but isn't registered yet
            # (e.g. a session from before billing existed), provision them now.
            if not billing_registry.find_user(user_id):  # not in the registry cache
                uname = (session_payload or {}).get("username", "user")  # best-effort display name
                try:  # provision on-the-fly (grants the free-credit bundle)
                    await billing_provision.provision(user_id, "", uname)  # one-time gist write
                except Exception as exc:  # provisioning failed (gist error, etc.)
                    logger.warning("lazy provision failed for %s: %s", user_id, exc)  # log + continue
            # free-tier models (free:...) bypass the balance gate entirely —
            # a user with zero credits can still use them. Peek the model from
            # the body to decide. (Starlette caches the body, so the proxy
            # handler re-reading it below is a no-op cache hit.)
            body_bytes = await request.body()  # read + cache the request body
            model_name = _peek_model(body_bytes)  # extract the model field (best-effort)
            is_free_model = model_name.startswith("free:")  # free models never need credits
            # balance check: no credits = 402 (but NOT for free models)
            if not is_free_model and not await billing_gate.check_balance(user_id):  # paid model, no balance
                return JSONResponse(  # payment required
                    {"error": {"message": "insufficient credits. Deposit ETH to your wallet to continue.",
                               "type": "insufficient_quota"}},
                    status_code=402)  # HTTP 402 Payment Required

        # stash identity for the router (so it can burn after the response)
        request.state.user_id = user_id  # None for admin, uid for billing users
        request.state.is_admin = is_admin  # admin flag
        return await call_next(request)  # pass through to the route

    # CORS is registered AFTER the gate on purpose: each newly added middleware
    # wraps the app built so far, so the last one added is the outermost layer.
    # With CORS outermost it answers browser preflight OPTIONS requests itself
    # (preflights never carry Authorization, so the gate would 401 them) and
    # adds CORS headers to the gate's own 401/402/403 JSON errors so browser
    # clients can read them.
    # allow_credentials stays False: wildcard origins + credentials violates
    # the CORS spec, and API clients authenticate with Bearer tokens.
    app.add_middleware(
        CORSMiddleware,  # Starlette's CORS impl
        allow_origins=["*"],  # allow any origin
        allow_credentials=False,  # Bearer tokens, not cookies (spec-safe with wildcard)
        allow_methods=["*"],  # all verbs
        allow_headers=["*"],  # all headers
    )

    app.include_router(auth_routes.router)  # mount /auth/* (google oauth + session)
    app.include_router(me_routes.router)  # mount /api/me/* (billing console)
    app.include_router(pricing_routes.router)  # mount /api/pricing (public catalog)
    app.include_router(proxy.router)  # mount /v1/* proxy routes

    # Mount the website LAST so it only catches non-API paths.
    # The routers above claim /auth/*, /api/*, and /v1/*; everything else
    # (login.html, dashboard.html, css/, js/, assets/) lands here.
    static_dir = BASE_DIR / "static"  # the website files live here (copied from PUBLIC/ in Docker)
    if static_dir.is_dir():  # only mount if the dir exists (skip in test envs)
        app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")  # serve site at /
    return app  # hand back the configured app
app = create_app() # module-level instance for `uvicorn src.app:app`; built when this module is imported