"""Riprap vLLM Space — bearer-auth proxy on port 7860. Forwards /v1/chat/completions and /v1/completions to vLLM on localhost:8000 (which exposes the OpenAI-compatible surface natively), and forwards specialist endpoints to riprap-models on localhost:7861. Same auth shape as the Ollama-backed riprap-inference Space — both UI Spaces (lablab + msradam/riprap) carry the shared bearer in RIPRAP_LLM_API_KEY. GPU power --------- A background sampler reads `nvmlDeviceGetPowerUsage` every 100 ms into a ring buffer. Each forwarded POST records the wall-clock window of the upstream call and reports: X-GPU-Power-W mean draw (W) across the window X-GPU-Energy-J energy (J) over the window X-GPU-Duration-S forwarded-call duration in seconds The `/v1/power` GET also exposes the instantaneous reading for clients that prefer to bracket their own work with two samples (used by the LLM client path where LiteLLM hides response headers). Reading from NVML costs <1 ms per sample; the ring buffer holds 60 s of 100 ms samples (600 entries). Sampler degrades to a no-op if NVML init fails (unlikely on an L4 Space, but possible on CPU-only sims). """ from __future__ import annotations import asyncio import logging import os import time from collections import deque from typing import AsyncIterator import httpx from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse log = logging.getLogger("riprap.proxy") VLLM_URL = "http://127.0.0.1:8000" MODELS_URL = "http://127.0.0.1:7861" PROXY_TOKEN = os.environ.get("RIPRAP_PROXY_TOKEN", "") app = FastAPI(title="Riprap vLLM Proxy") # --------------------------------------------------------------------------- # GPU power sampler # --------------------------------------------------------------------------- # Ring buffer of (unix_ts, power_w) samples. 600 entries × 100 ms = 60 s # of history, which covers the longest single inference call we'd see # (vLLM cold-compile is ~120 s but that surfaces as multiple shorter # reads from inside vLLM's loop, not as a single forwarded POST). _SAMPLES: deque[tuple[float, float]] = deque(maxlen=600) _SAMPLER_TASK: asyncio.Task | None = None _NVML_OK: bool = False _NVML_HANDLE = None _NVML_ERR: str | None = None def _init_nvml() -> None: """Best-effort NVML init. On failure we record the error string and leave _NVML_OK=False — the proxy still serves traffic, just without real power data.""" global _NVML_OK, _NVML_HANDLE, _NVML_ERR try: import pynvml pynvml.nvmlInit() # Single-GPU L4 Space — device 0 is the L4. If a future deploy # uses multi-GPU we'd average across handles, but that's not # the current shape. _NVML_HANDLE = pynvml.nvmlDeviceGetHandleByIndex(0) # Probe once to confirm power query works. pynvml.nvmlDeviceGetPowerUsage(_NVML_HANDLE) _NVML_OK = True log.info("NVML initialized for GPU 0") except Exception as e: # noqa: BLE001 _NVML_ERR = f"{type(e).__name__}: {e}" _NVML_OK = False log.warning("NVML init failed (%s); power data will be unavailable", _NVML_ERR) def _read_power_w() -> float | None: """Instantaneous package power in watts. None if NVML is dead.""" if not _NVML_OK: return None try: import pynvml # nvmlDeviceGetPowerUsage returns milliwatts. mw = pynvml.nvmlDeviceGetPowerUsage(_NVML_HANDLE) return mw / 1000.0 except Exception: return None async def _power_sampler() -> None: """Background loop, 100 ms cadence. 
async def _power_sampler() -> None:
    """Background loop, 100 ms cadence. Cheap (~1 ms NVML query)."""
    while True:
        p = _read_power_w()
        if p is not None:
            _SAMPLES.append((time.time(), p))
        await asyncio.sleep(0.1)


def _avg_power_over(t0: float, t1: float) -> float | None:
    """Mean of samples in the [t0, t1] window.

    Returns None only when the buffer is empty. If no samples landed in
    the window (window too short / sampler hadn't ticked yet), falls
    back to the most recent reading as the next-best signal.
    """
    if not _SAMPLES:
        return None
    bucket = [p for ts, p in _SAMPLES if t0 <= ts <= t1]
    if not bucket:
        return _SAMPLES[-1][1]
    return sum(bucket) / len(bucket)


@app.on_event("startup")
async def _startup() -> None:
    _init_nvml()
    if _NVML_OK:
        global _SAMPLER_TASK
        _SAMPLER_TASK = asyncio.create_task(_power_sampler())


@app.on_event("shutdown")
async def _shutdown() -> None:
    if _SAMPLER_TASK is not None:
        _SAMPLER_TASK.cancel()
    if _NVML_OK:
        try:
            import pynvml

            pynvml.nvmlShutdown()
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Auth + routing
# ---------------------------------------------------------------------------

def _check_auth(request: Request) -> None:
    if not PROXY_TOKEN:
        raise HTTPException(503, "RIPRAP_PROXY_TOKEN not set on the inference Space")
    auth = request.headers.get("authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(401, "missing bearer token")
    if auth.removeprefix("Bearer ").strip() != PROXY_TOKEN:
        raise HTTPException(401, "invalid bearer token")


@app.get("/")
def root():
    return {
        "service": "riprap-vllm",
        "ok": True,
        "nvml": _NVML_OK,
        "nvml_err": None if _NVML_OK else _NVML_ERR,
    }


@app.get("/vllm-log", include_in_schema=False)
async def vllm_log(request: Request, lines: int = 100) -> Response:
    """Last N lines of $HOME/vllm.log — operator diagnostic."""
    _check_auth(request)
    log_path = os.path.join(os.environ.get("HOME", "/home/user"), "vllm.log")
    try:
        with open(log_path) as f:
            tail = f.readlines()[-lines:]
        return JSONResponse({"ok": True, "path": log_path, "lines": len(tail), "log": "".join(tail)})
    except FileNotFoundError:
        return JSONResponse({"ok": False, "err": "vllm.log not found"}, status_code=404)
    except Exception as e:
        return JSONResponse({"ok": False, "err": str(e)}, status_code=500)


@app.get("/healthz")
async def healthz():
    out = {"proxy": "ok", "nvml": "ok" if _NVML_OK else f"err: {_NVML_ERR}"}
    async with httpx.AsyncClient(timeout=5) as client:
        try:
            r = await client.get(f"{VLLM_URL}/health")
            out["vllm"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
        except Exception as e:
            out["vllm"] = f"err: {type(e).__name__}"
        try:
            r = await client.get(f"{MODELS_URL}/healthz")
            if r.status_code == 200:
                out["riprap_models"] = "ok"
                # Bubble up the loaded-model list + last-error map so
                # operators can diagnose without hitting /v1/diag.
                try:
                    body = r.json()
                    out["riprap_models_loaded"] = body.get("models_loaded")
                    out["riprap_models_last_errors"] = body.get("last_errors")
                except Exception:
                    pass
            else:
                out["riprap_models"] = f"http_{r.status_code}"
        except Exception as e:
            out["riprap_models"] = f"err: {type(e).__name__}"
    return out
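
# Operator usage sketch (illustration only; "https://<space-url>" is a
# placeholder for the deployed Space URL, and the token comes from the
# RIPRAP_PROXY_TOKEN secret shared with the UI Spaces):
#
#   import httpx
#   headers = {"Authorization": f"Bearer {os.environ['RIPRAP_PROXY_TOKEN']}"}
#   httpx.get("https://<space-url>/healthz").json()              # no auth needed
#   httpx.get("https://<space-url>/vllm-log", params={"lines": 50},
#             headers=headers).json()["log"]                     # auth required
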
@app.get("/v1/diag", include_in_schema=False)
async def diag(request: Request) -> Response:
    """Forward to riprap-models /v1/diag (auth-required).

    Operator-only diagnostic snapshot — what's loaded, last per-stage
    error with traceback tail, and CUDA memory state per device.
    """
    _check_auth(request)
    async with httpx.AsyncClient(timeout=10) as client:
        try:
            r = await client.get(f"{MODELS_URL}/v1/diag")
        except Exception as e:
            return JSONResponse(
                {"ok": False, "err": f"upstream: {type(e).__name__}: {e}"},
                status_code=503,
            )
    return Response(
        content=r.content,
        status_code=r.status_code,
        media_type=r.headers.get("content-type", "application/json"),
    )


@app.get("/v1/power")
async def power(request: Request) -> Response:
    """Instantaneous and recent-window GPU power (W).

    Used by the LLM client path: LiteLLM doesn't surface response
    headers, so the client samples /v1/power before/after its
    chat.completions call to bracket the energy reading. The recent
    1-second average smooths over the 100 ms sampler cadence.
    """
    _check_auth(request)
    if not _NVML_OK:
        return JSONResponse(
            {"ok": False, "err": _NVML_ERR or "NVML unavailable"},
            status_code=503,
        )
    now = time.time()
    inst = _read_power_w()
    avg_1s = _avg_power_over(now - 1.0, now)
    avg_5s = _avg_power_over(now - 5.0, now)
    return JSONResponse({
        "ok": True,
        "ts": now,
        "power_w": inst,
        "power_w_avg_1s": avg_1s,
        "power_w_avg_5s": avg_5s,
        "samples_held": len(_SAMPLES),
        "device": "NVIDIA L4",
    })
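
# Client-side bracketing sketch for the path described above (illustration
# only; "call_llm" is a hypothetical stand-in for the LiteLLM
# chat.completions call whose response headers we can't see, and averaging
# the two bracket samples is an assumption, not a prescribed method):
#
#   p0 = httpx.get(f"{base}/v1/power", headers=headers).json()["power_w_avg_1s"]
#   t0 = time.time()
#   call_llm(...)
#   dt = time.time() - t0
#   p1 = httpx.get(f"{base}/v1/power", headers=headers).json()["power_w_avg_1s"]
#   energy_j = (p0 + p1) / 2 * dt    # crude trapezoid over the two samples
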
# ---------------------------------------------------------------------------
# Forwarding
# ---------------------------------------------------------------------------

async def _stream_passthrough(upstream: httpx.Response) -> AsyncIterator[bytes]:
    # Streaming responses can't carry headers added after the first byte —
    # the caller sets them on the StreamingResponse before the body starts.
    async for chunk in upstream.aiter_raw():
        yield chunk


async def _proxy_post(upstream_base: str, path: str, request: Request, *,
                      timeout: float = 300.0) -> Response:
    body = await request.body()
    headers = {
        "content-type": request.headers.get("content-type", "application/json"),
        "accept": request.headers.get("accept", "*/*"),
    }
    is_stream = b'"stream":true' in body or b'"stream": true' in body

    client = httpx.AsyncClient(timeout=timeout)
    upstream_req = client.build_request(
        "POST", f"{upstream_base}{path}", content=body, headers=headers
    )
    t0 = time.time()
    upstream = await client.send(upstream_req, stream=is_stream)

    if is_stream:
        # We can't measure end-of-stream wallclock without consuming the
        # body, but the client sees per-call duration on its side. For
        # streaming we record headers describing only the start power
        # snapshot — useful as a sanity signal, not a true energy.
        snap = _read_power_w()
        hdrs = {
            "x-gpu-power-w": f"{snap:.2f}" if snap is not None else "",
            "x-gpu-stream": "1",
        }

        async def _close() -> None:
            # Close both the response and the client once the stream ends;
            # closing only the response would leak the client's connection pool.
            await upstream.aclose()
            await client.aclose()

        return StreamingResponse(
            _stream_passthrough(upstream),
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
            headers=hdrs,
            background=_close,
        )

    content = await upstream.aread()
    await upstream.aclose()
    await client.aclose()
    t1 = time.time()
    duration_s = max(t1 - t0, 0.0)

    extra_headers: dict[str, str] = {}
    if _NVML_OK:
        avg_w = _avg_power_over(t0, t1)
        if avg_w is not None:
            energy_j = avg_w * duration_s
            extra_headers = {
                "x-gpu-power-w": f"{avg_w:.3f}",
                "x-gpu-energy-j": f"{energy_j:.3f}",
                "x-gpu-duration-s": f"{duration_s:.3f}",
                "x-gpu-device": "NVIDIA L4",
            }

    media_type = upstream.headers.get("content-type", "application/json")
    return Response(
        content=content,
        status_code=upstream.status_code,
        media_type=media_type,
        headers=extra_headers,
    )


# vLLM (OpenAI-compat) routes

@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(VLLM_URL, "/v1/chat/completions", request)


@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(VLLM_URL, "/v1/completions", request)


@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
    """Routed to riprap-models' granite-embed (vLLM doesn't serve our
    embedding model)."""
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/granite-embed", request)


@app.get("/v1/models")
async def models(request: Request) -> Response:
    _check_auth(request)
    async with httpx.AsyncClient(timeout=10) as client:
        r = await client.get(f"{VLLM_URL}/v1/models")
    return Response(
        content=r.content,
        status_code=r.status_code,
        media_type=r.headers.get("content-type", "application/json"),
    )


# riprap-models specialist routes

@app.post("/v1/prithvi-pluvial")
async def prithvi_pluvial(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/prithvi-pluvial", request)


@app.post("/v1/terramind")
async def terramind(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/terramind", request)


@app.post("/v1/ttm-forecast")
async def ttm_forecast(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/ttm-forecast", request)


@app.post("/v1/gliner-extract")
async def gliner_extract(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/gliner-extract", request)


@app.api_route("/v1/{path:path}", methods=["POST"])
async def catch_all(path: str, request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, f"/v1/{path}", request)
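
# Non-streaming client sketch for reading the X-GPU-* headers this proxy
# attaches (illustration only; the base URL, token, and model name are
# placeholders for deployment-specific values):
#
#   r = httpx.post(
#       "https://<space-url>/v1/chat/completions",
#       headers={"Authorization": "Bearer <RIPRAP_PROXY_TOKEN>"},
#       json={"model": "<served-model>",
#             "messages": [{"role": "user", "content": "hi"}]},
#       timeout=300,
#   )
#   print(r.headers.get("x-gpu-power-w"), r.headers.get("x-gpu-energy-j"))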