| """Application factory: loads config, wires the router, mounts auth gate.""" # module docstring | |
| from __future__ import annotations # defer annotation evaluation (PEP 563) | |
| import logging # structured logging config + named logger | |
| import os # read env vars for KEY (auth) and IPS (whitelist) | |
| import json # stdlib json for _peek_model (free-model balance gate) | |
| from pathlib import Path # locate the data directory relative to this file | |
| from fastapi import FastAPI # the app factory target | |
| from fastapi.middleware.cors import CORSMiddleware # CORS for browser clients | |
| from fastapi.responses import JSONResponse # JSON error responses from the gate | |
| from fastapi.staticfiles import StaticFiles # serve the dreamy website (html/css/js) | |
| from .models.config import ModelConfig # loads models.json + resolves names | |
| from .routes import proxy, auth as auth_routes # proxy + auth (google oauth) routers | |
| from .routes import me as me_routes # per-user billing API (/api/me/*) | |
| from .routes import pricing as pricing_routes # public pricing catalog (/api/pricing) | |
| from .services.forwarder import make_client, set_shared_client # shared httpx connection pool | |
| from .auth import session as sess # session cookie verify for the gate | |
# Configure root logging once, at import time.
logging.basicConfig(
    format="%(asctime)s %(levelname)-7s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("router.app")  # named logger for this module

# Directory two levels above this file; "data" is resolved relative to it.
BASE_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = BASE_DIR / "data"

# Master key read from the KEY env var (empty string when unset).
PROXY_KEY = os.environ.get("KEY", "")

# IP allowlist read from the IPS env var: one entry per line, blanks ignored.
# An empty set means "no allowlist configured".
WHITELIST_IPS = {
    entry.strip()
    for entry in os.environ.get("IPS", "").splitlines()
    if entry.strip()
}
def get_client_ip(request) -> str:
    """Return the caller's IP address as a string.

    The first entry of the ``X-Forwarded-For`` header is preferred. When the
    header is absent, or its first entry is blank (whitespace-only value or a
    leading comma), the address of the directly connected peer is used.

    Args:
        request: Object exposing ``headers.get(name, default)`` and a
            ``client`` attribute that is either ``None`` or has ``.host``.

    Returns:
        The resolved IP string, or ``""`` when nothing is known.

    NOTE(review): ``X-Forwarded-For`` is supplied by the client unless a
    trusted proxy overwrites it. The IP allowlist check in this file uses this
    value — confirm the deployment strips or rewrites inbound XFF headers.
    """
    forwarded = request.headers.get("x-forwarded-for", "")
    # Only the first hop matters; maxsplit avoids splitting the whole chain.
    first_hop = forwarded.split(",", 1)[0].strip()
    if first_hop:
        return first_hop
    # No usable forwarded hop: fall back to the socket peer, if any.
    return request.client.host if request.client else ""
def extract_bearer(auth_header: str) -> str:
    """Return the token from a ``Bearer <token>`` header, or ``""``.

    The scheme is matched case-insensitively and must be followed by a single
    space; surrounding whitespace on the token is stripped. A missing header,
    a different scheme, or a header without a space all yield ``""``.
    This function never raises for string or empty input.
    """
    if not auth_header:
        return ""
    scheme, separator, credentials = auth_header.partition(" ")
    if not separator:  # no space at all -> not "<scheme> <token>"
        return ""
    if scheme.lower() != "bearer":
        return ""
    return credentials.strip()
| def _peek_model(body_bytes: bytes) -> str: # pull the "model" field from a request body without consuming it | |
| """Best-effort extract the `model` field from a JSON body (for the free-model gate).""" # docstring | |
| if not body_bytes: # GET/DELETE have no body | |
| return "" # no model | |
| try: # attempt to parse | |
| parsed = json.loads(body_bytes) # parse JSON | |
| except Exception: # not valid JSON | |
| return "" # can't peek | |
| if isinstance(parsed, dict): # only objects carry a model field | |
| return str(parsed.get("model") or "") # the model name (or empty) | |
| return "" # arrays/scalars -> no model | |
def create_app() -> FastAPI:
    """Build and return the fully wired FastAPI application.

    In order, this: loads the model config from ``DATA_DIR``; creates the
    shared HTTP client and hands it to the proxy, pricing and auth modules;
    registers the startup/shutdown hooks, the public ``/api/status`` and
    ``/health`` endpoints and the ``gate`` middleware; adds CORS; mounts the
    API routers; and finally mounts the static website when ``static/``
    exists under ``BASE_DIR``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(title="router", docs_url=None, redoc_url=None)  # no docs endpoints exposed

    config = ModelConfig(DATA_DIR)
    client = make_client()
    set_shared_client(client)  # share the client without a circular import
    proxy.init_router(config, client)
    pricing_routes.init_config(config)
    logger.info("loaded %d models, %d total targets",
                len(config.models), config.total_targets())

    @app.on_event("startup")
    async def startup():
        """Load the billing registry and start the usage tracker, if billing is enabled."""
        from .services.billing import gate as billing_gate, registry as billing_registry
        if not billing_gate.is_enabled():
            return
        try:
            await billing_registry.get_or_create()
            logger.info("billing enabled: %d users in registry",
                        len(billing_registry._CACHE))
            from .services.billing.tracker import TRACKER
            TRACKER.start()
        except Exception as exc:
            # Degrade instead of refusing to boot: the app keeps serving.
            logger.error("billing registry load failed: %s — running UNMETERED", exc)

    @app.on_event("shutdown")
    async def shutdown():
        """Stop the usage tracker, then close the shared connection pool."""
        from .services.billing.tracker import TRACKER
        try:
            await TRACKER.stop()
        finally:
            # Always release pooled connections, even if the tracker stop fails.
            await client.aclose()
            logger.info("connection pool closed")

    @app.get("/api/status")
    async def status():
        """Public status summary (no auth)."""
        return {
            "service": "router",
            "models": len(config.models),
            "targets": config.total_targets(),
            "down": proxy.ROUTER.health.down_count(),
            "model_list": list(config.models.keys()),
            "google": auth_routes.gg.is_configured(),
        }

    @app.get("/health")
    async def health():
        """Public health check (no auth)."""
        return {"status": "ok", "models": len(config.models),
                "down": proxy.ROUTER.health.down_count()}

    @app.middleware("http")
    async def gate(request, call_next):
        """IP allowlist, authentication and billing gate run before every route.

        Responses produced here:
            403 -- an allowlist is configured and the caller's IP is not on it.
            401 -- a protected ``/v1/`` path with no admin key, user key or
                   valid session.
            402 -- billing is enabled, the model is not ``free:``-prefixed and
                   ``check_balance`` is falsy for the user.
        Otherwise the request is passed on; for protected paths
        ``request.state.user_id`` and ``request.state.is_admin`` are set first.
        """
        client_ip = get_client_ip(request)
        if WHITELIST_IPS and client_ip not in WHITELIST_IPS:
            logger.warning("forbidden IP: %s", client_ip)
            return JSONResponse({"error": "forbidden"}, status_code=403)

        # Only proxy endpoints under /v1/ (except the public models list)
        # require auth; every other path is public.
        path = request.url.path
        needs_auth = path.startswith("/v1/") and path != "/v1/models"
        if not needs_auth:
            return await call_next(request)

        # --- resolve caller identity (admin key > user key > session) ---
        from .services.billing import gate as billing_gate, provision as billing_provision
        from .services.billing import registry as billing_registry

        user_id = None          # billing user id (None = not a billing user)
        is_admin = False        # master-key callers skip the billing gate
        session_payload = None  # decoded session, used for lazy provisioning

        # Tiers 1 and 2: Authorization header (master key or a user API key).
        auth_header = request.headers.get("authorization", "")
        resolved_uid, resolved_admin = billing_gate.resolve(auth_header, PROXY_KEY)
        if resolved_admin:
            is_admin = True
        elif resolved_uid:
            user_id = resolved_uid

        # Tier 3: session cookie, tried only when the header resolved nothing.
        if user_id is None and not is_admin:
            cookie = request.cookies.get(sess.COOKIE_NAME, "")
            session_payload = sess.verify_session(cookie)
            if session_payload:
                # .get: a session without "uid" becomes a 401 below, not a 500.
                user_id = session_payload.get("uid")

        if user_id is None and not is_admin:
            return JSONResponse({"error": "unauthorized"}, status_code=401)

        # --- billing gate (users only; admin skips) ---
        if billing_gate.is_enabled() and user_id is not None:
            # Lazy-provision users who are authenticated but not yet registered.
            if not billing_registry.find_user(user_id):
                uname = (session_payload or {}).get("username", "user")
                try:
                    await billing_provision.provision(user_id, "", uname)
                except Exception as exc:
                    # Best-effort: log and carry on to the balance check.
                    logger.warning("lazy provision failed for %s: %s", user_id, exc)

            # Models named "free:..." bypass the balance check, so the model
            # name is peeked from the body here.
            # NOTE(review): relies on the framework caching the body so the
            # route handler can read it again — confirm for the pinned version.
            body_bytes = await request.body()
            is_free_model = _peek_model(body_bytes).startswith("free:")

            if not is_free_model and not await billing_gate.check_balance(user_id):
                return JSONResponse(
                    {"error": {"message": "insufficient credits. Deposit ETH to your wallet to continue.",
                               "type": "insufficient_quota"}},
                    status_code=402)

        # Stash identity on the request for downstream handlers.
        request.state.user_id = user_id    # None for admin, uid for billing users
        request.state.is_admin = is_admin
        return await call_next(request)

    # CORS is registered AFTER the gate on purpose: the most recently added
    # middleware is the outermost one, so CORS answers preflight OPTIONS
    # requests (which carry no Authorization header) before the gate can
    # reject them, and the gate's own 401/402/403 responses get CORS headers.
    # Wildcard origins are paired with allow_credentials=False, as the CORS
    # spec requires.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app.include_router(auth_routes.router)
    app.include_router(me_routes.router)
    app.include_router(pricing_routes.router)
    app.include_router(proxy.router)

    # Mount the website LAST so it only catches paths no router or endpoint
    # above has claimed.
    static_dir = BASE_DIR / "static"
    if static_dir.is_dir():  # skipped when the directory is absent
        app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")

    return app
app = create_app()  # module-level ASGI instance, built at import time (e.g. `uvicorn src.app:app`)