| """Riprap vLLM Space — bearer-auth proxy on port 7860. | |
| Forwards /v1/chat/completions and /v1/completions to vLLM on | |
| localhost:8000 (which exposes the OpenAI-compatible surface | |
| natively), and forwards specialist endpoints to riprap-models on | |
| localhost:7861. | |
| Same auth shape as the Ollama-backed riprap-inference Space — both | |
| UI Spaces (lablab + msradam/riprap) carry the shared bearer in | |
| RIPRAP_LLM_API_KEY. | |
| GPU power | |
| --------- | |
| A background sampler reads `nvmlDeviceGetPowerUsage` every 100 ms | |
| into a ring buffer. Each forwarded POST records the wall-clock | |
| window of the upstream call and reports: | |
| X-GPU-Power-W mean draw (W) across the window | |
| X-GPU-Energy-J energy (J) over the window | |
| X-GPU-Duration-S forwarded-call duration in seconds | |
| The `/v1/power` GET also exposes the instantaneous reading for | |
| clients that prefer to bracket their own work with two samples | |
| (used by the LLM client path where LiteLLM hides response headers). | |
| Reading from NVML costs <1 ms per sample; the ring buffer holds 60 s | |
| of 100 ms samples (600 entries). Sampler degrades to a no-op if NVML | |
| init fails (unlikely on an L4 Space, but possible on CPU-only sims). | |
| """ | |
from __future__ import annotations

import asyncio
import logging
import os
import time
from collections import deque
from typing import AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

log = logging.getLogger("riprap.proxy")

VLLM_URL = "http://127.0.0.1:8000"
MODELS_URL = "http://127.0.0.1:7861"
PROXY_TOKEN = os.environ.get("RIPRAP_PROXY_TOKEN", "")

app = FastAPI(title="Riprap vLLM Proxy")

# ---------------------------------------------------------------------------
# GPU power sampler
# ---------------------------------------------------------------------------
# Ring buffer of (unix_ts, power_w) samples. 600 entries × 100 ms = 60 s
# of history, which covers the longest single inference call we'd see
# (vLLM cold-compile is ~120 s but that surfaces as multiple shorter
# reads from inside vLLM's loop, not as a single forwarded POST).
_SAMPLES: deque[tuple[float, float]] = deque(maxlen=600)
_SAMPLER_TASK: asyncio.Task | None = None

_NVML_OK: bool = False
_NVML_HANDLE = None
_NVML_ERR: str | None = None

def _init_nvml() -> None:
    """Best-effort NVML init. On failure we record the error string and
    leave _NVML_OK=False — the proxy still serves traffic, just without
    real power data."""
    global _NVML_OK, _NVML_HANDLE, _NVML_ERR
    try:
        import pynvml

        pynvml.nvmlInit()
        # Single-GPU L4 Space — device 0 is the L4. If a future deploy
        # uses multi-GPU we'd average across handles, but that's not
        # the current shape.
        _NVML_HANDLE = pynvml.nvmlDeviceGetHandleByIndex(0)
        # Probe once to confirm power query works.
        pynvml.nvmlDeviceGetPowerUsage(_NVML_HANDLE)
        _NVML_OK = True
        log.info("NVML initialized for GPU 0")
    except Exception as e:  # noqa: BLE001
        _NVML_ERR = f"{type(e).__name__}: {e}"
        _NVML_OK = False
        log.warning("NVML init failed (%s); power data will be unavailable",
                    _NVML_ERR)

def _read_power_w() -> float | None:
    """Instantaneous package power in watts. None if NVML is dead."""
    if not _NVML_OK:
        return None
    try:
        import pynvml

        # nvmlDeviceGetPowerUsage returns milliwatts.
        mw = pynvml.nvmlDeviceGetPowerUsage(_NVML_HANDLE)
        return mw / 1000.0
    except Exception:
        return None

async def _power_sampler() -> None:
    """Background loop, 100 ms cadence. Cheap (~1 ms NVML query)."""
    while True:
        p = _read_power_w()
        if p is not None:
            _SAMPLES.append((time.time(), p))
        await asyncio.sleep(0.1)

def _avg_power_over(t0: float, t1: float) -> float | None:
    """Mean of samples in the [t0, t1] window. Falls back to the most
    recent sample when none landed in the window; returns None only when
    the buffer is empty (callers then take a single instantaneous read)."""
    if not _SAMPLES:
        return None
    bucket = [p for ts, p in _SAMPLES if t0 <= ts <= t1]
    if not bucket:
        # Window may be too short / sampler hadn't ticked yet — return
        # the most recent reading we have as the next-best signal.
        return _SAMPLES[-1][1] if _SAMPLES else None
    return sum(bucket) / len(bucket)

@app.on_event("startup")
async def _startup() -> None:
    _init_nvml()
    if _NVML_OK:
        global _SAMPLER_TASK
        _SAMPLER_TASK = asyncio.create_task(_power_sampler())


@app.on_event("shutdown")
async def _shutdown() -> None:
    if _SAMPLER_TASK is not None:
        _SAMPLER_TASK.cancel()
    if _NVML_OK:
        try:
            import pynvml

            pynvml.nvmlShutdown()
        except Exception:
            pass

# ---------------------------------------------------------------------------
# Auth + routing
# ---------------------------------------------------------------------------
def _check_auth(request: Request) -> None:
    if not PROXY_TOKEN:
        raise HTTPException(503, "RIPRAP_PROXY_TOKEN not set on the inference Space")
    auth = request.headers.get("authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(401, "missing bearer token")
    if auth.removeprefix("Bearer ").strip() != PROXY_TOKEN:
        raise HTTPException(401, "invalid bearer token")

@app.get("/")
def root():
    return {"service": "riprap-vllm", "ok": True,
            "nvml": _NVML_OK,
            "nvml_err": None if _NVML_OK else _NVML_ERR}

@app.get("/v1/vllm-log")  # route path not recorded in this file; assumed
async def vllm_log(request: Request, lines: int = 100) -> Response:
    """Last N lines of $HOME/vllm.log — operator diagnostic."""
    _check_auth(request)
    log_path = os.path.join(os.environ.get("HOME", "/home/user"), "vllm.log")
    try:
        with open(log_path) as f:
            tail = f.readlines()[-lines:]
        return JSONResponse({"ok": True, "path": log_path,
                             "lines": len(tail), "log": "".join(tail)})
    except FileNotFoundError:
        return JSONResponse({"ok": False, "err": "vllm.log not found"}, status_code=404)
    except Exception as e:
        return JSONResponse({"ok": False, "err": str(e)}, status_code=500)

@app.get("/healthz")
async def healthz():
    out = {"proxy": "ok", "nvml": "ok" if _NVML_OK else f"err: {_NVML_ERR}"}
    async with httpx.AsyncClient(timeout=5) as client:
        try:
            r = await client.get(f"{VLLM_URL}/health")
            out["vllm"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
        except Exception as e:
            out["vllm"] = f"err: {type(e).__name__}"
        try:
            r = await client.get(f"{MODELS_URL}/healthz")
            if r.status_code == 200:
                out["riprap_models"] = "ok"
                # Bubble up the loaded-model list + last-error map so
                # operators can diagnose without hitting /v1/diag.
                try:
                    body = r.json()
                    out["riprap_models_loaded"] = body.get("models_loaded")
                    out["riprap_models_last_errors"] = body.get("last_errors")
                except Exception:
                    pass
            else:
                out["riprap_models"] = f"http_{r.status_code}"
        except Exception as e:
            out["riprap_models"] = f"err: {type(e).__name__}"
    return out

@app.get("/v1/diag")
async def diag(request: Request) -> Response:
    """Forward to riprap-models /v1/diag (auth-required).

    Operator-only diagnostic snapshot — what's loaded, last per-stage
    error with traceback tail, and CUDA memory state per device."""
    _check_auth(request)
    async with httpx.AsyncClient(timeout=10) as client:
        try:
            r = await client.get(f"{MODELS_URL}/v1/diag")
        except Exception as e:
            return JSONResponse({"ok": False,
                                 "err": f"upstream: {type(e).__name__}: {e}"},
                                status_code=503)
    return Response(content=r.content,
                    status_code=r.status_code,
                    media_type=r.headers.get("content-type", "application/json"))

@app.get("/v1/power")
async def power(request: Request) -> Response:
    """Instantaneous and recent-window GPU power (W).

    Used by the LLM client path: LiteLLM doesn't surface response
    headers, so the client samples /v1/power before/after its
    chat.completions call to bracket the energy reading. The recent
    1-second average smooths over the 100 ms sampler cadence.
    """
    _check_auth(request)
    if not _NVML_OK:
        return JSONResponse(
            {"ok": False, "err": _NVML_ERR or "NVML unavailable"},
            status_code=503,
        )
    now = time.time()
    inst = _read_power_w()
    avg_1s = _avg_power_over(now - 1.0, now) if _SAMPLES else None
    avg_5s = _avg_power_over(now - 5.0, now) if _SAMPLES else None
    return JSONResponse({
        "ok": True,
        "ts": now,
        "power_w": inst,
        "power_w_avg_1s": avg_1s,
        "power_w_avg_5s": avg_5s,
        "samples_held": len(_SAMPLES),
        "device": "NVIDIA L4",
    })

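
# Illustrative sketch only (nothing in the proxy calls this): the client-side
# bracketing pattern the /v1/power docstring above describes, for callers that
# cannot see response headers. The base URL and model id below are assumptions
# for the example, not configuration this Space defines; the bearer env var
# name follows the module docstring.
async def _example_bracket_energy() -> None:
    base = "https://example-riprap-vllm.hf.space"  # assumed proxy URL
    token = os.environ.get("RIPRAP_LLM_API_KEY", "")  # shared bearer (see module docstring)
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(timeout=120) as client:
        p0 = (await client.get(f"{base}/v1/power", headers=headers)).json()
        t0 = time.time()
        await client.post(
            f"{base}/v1/chat/completions", headers=headers,
            json={"model": "example-model",  # assumed model id
                  "messages": [{"role": "user", "content": "hello"}]},
        )
        t1 = time.time()
        p1 = (await client.get(f"{base}/v1/power", headers=headers)).json()
        # Two-point estimate: mean of the bracketing samples times wall-clock duration.
        if p0.get("power_w") is not None and p1.get("power_w") is not None:
            mean_w = (p0["power_w"] + p1["power_w"]) / 2.0
            log.info("bracketed estimate: %.1f W over %.2f s ≈ %.1f J",
                     mean_w, t1 - t0, mean_w * (t1 - t0))
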
# ---------------------------------------------------------------------------
# Forwarding
# ---------------------------------------------------------------------------
async def _stream_passthrough(upstream: httpx.Response,
                              add_headers: dict[str, str]) -> AsyncIterator[bytes]:
    # Headers can't be added once the first byte has gone out, so the caller
    # sets add_headers on the StreamingResponse up front; this generator only
    # relays raw bytes from the upstream response.
    async for chunk in upstream.aiter_raw():
        yield chunk

async def _proxy_post(upstream_base: str, path: str, request: Request,
                      *, timeout: float = 300.0) -> Response:
    body = await request.body()
    headers = {
        "content-type": request.headers.get("content-type", "application/json"),
        "accept": request.headers.get("accept", "*/*"),
    }
    is_stream = b'"stream":true' in body or b'"stream": true' in body

    client = httpx.AsyncClient(timeout=timeout)
    upstream_req = client.build_request(
        "POST", f"{upstream_base}{path}", content=body, headers=headers
    )
    t0 = time.time()
    upstream = await client.send(upstream_req, stream=is_stream)

    if is_stream:
        # We can't measure end-of-stream wallclock without consuming the
        # body, but the client sees per-call duration on its side. For
        # streaming we record headers describing only the start power
        # snapshot — useful as a sanity signal, not a true energy.
        snap = _read_power_w()
        hdrs = {
            "x-gpu-power-w": f"{snap:.2f}" if snap is not None else "",
            "x-gpu-stream": "1",
        }

        async def _close() -> None:
            # Close both the upstream response and the client once the
            # stream has been fully relayed.
            await upstream.aclose()
            await client.aclose()

        return StreamingResponse(
            _stream_passthrough(upstream, hdrs),
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
            headers=hdrs,
            background=_close,
        )

    content = await upstream.aread()
    await upstream.aclose()
    await client.aclose()
    t1 = time.time()
    duration_s = max(t1 - t0, 0.0)

    extra_headers: dict[str, str] = {}
    if _NVML_OK:
        avg_w = _avg_power_over(t0, t1)
        if avg_w is not None:
            energy_j = avg_w * duration_s
            extra_headers = {
                "x-gpu-power-w": f"{avg_w:.3f}",
                "x-gpu-energy-j": f"{energy_j:.3f}",
                "x-gpu-duration-s": f"{duration_s:.3f}",
                "x-gpu-device": "NVIDIA L4",
            }

    media_type = upstream.headers.get("content-type", "application/json")
    response_headers = dict(extra_headers)
    return Response(
        content=content,
        status_code=upstream.status_code,
        media_type=media_type,
        headers=response_headers,
    )

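
# Illustrative sketch only: how a caller that does see response headers (plain
# httpx, no LiteLLM in between) would read the per-call energy data _proxy_post
# attaches. Base URL, model id, and request body fields are assumptions for the
# example, as above.
async def _example_read_power_headers() -> None:
    base = "https://example-riprap-vllm.hf.space"  # assumed proxy URL
    token = os.environ.get("RIPRAP_LLM_API_KEY", "")
    async with httpx.AsyncClient(timeout=300) as client:
        r = await client.post(
            f"{base}/v1/completions",
            headers={"Authorization": f"Bearer {token}"},
            json={"model": "example-model", "prompt": "hello", "max_tokens": 8},  # assumed body
        )
        # Non-streaming responses carry the measured window; streaming responses
        # only carry x-gpu-power-w (start snapshot) and x-gpu-stream: 1.
        log.info("power=%s W  energy=%s J  duration=%s s",
                 r.headers.get("x-gpu-power-w"),
                 r.headers.get("x-gpu-energy-j"),
                 r.headers.get("x-gpu-duration-s"))
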
# vLLM (OpenAI-compat) routes
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(VLLM_URL, "/v1/chat/completions", request)


@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(VLLM_URL, "/v1/completions", request)


@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
    """Routed to riprap-models's granite-embed (vLLM doesn't serve our
    embedding model)."""
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/granite-embed", request)


@app.get("/v1/models")
async def models(request: Request) -> Response:
    _check_auth(request)
    async with httpx.AsyncClient(timeout=10) as client:
        r = await client.get(f"{VLLM_URL}/v1/models")
    return Response(content=r.content, status_code=r.status_code,
                    media_type=r.headers.get("content-type", "application/json"))

# riprap-models specialist routes
@app.post("/v1/prithvi-pluvial")
async def prithvi_pluvial(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/prithvi-pluvial", request)


@app.post("/v1/terramind")
async def terramind(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/terramind", request)


@app.post("/v1/ttm-forecast")
async def ttm_forecast(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/ttm-forecast", request)


@app.post("/v1/gliner-extract")
async def gliner_extract(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/gliner-extract", request)


@app.post("/v1/{path}")
async def catch_all(path: str, request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, f"/v1/{path}", request)
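
# Local-run sketch. The Space's actual launch command is not part of this file;
# port 7860 matches the module docstring, the host binding and uvicorn runner
# are assumptions for running the proxy standalone.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)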