"""Riprap Inference Space - bearer-auth proxy on port 7860.

Forwards /v1/chat/completions and /v1/completions to Ollama on
localhost:11434 (which exposes an OpenAI-compatible surface), and
forwards /v1/embeddings plus every other /v1/* route to riprap-models
on localhost:7861.

A single shared secret (env var RIPRAP_PROXY_TOKEN) gates all
inbound calls; clients pass it as `Authorization: Bearer <token>`.
The two UI Spaces (lablab + personal mirror) carry the same token
in their RIPRAP_LLM_API_KEY env var.

Streaming responses (SSE for chat completions) are relayed to the caller
chunk by chunk; non-streaming responses are buffered and returned whole.
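
Example client call (illustrative host and model name; the real token comes
from RIPRAP_PROXY_TOKEN, mirrored as RIPRAP_LLM_API_KEY in the UI Spaces):

    import httpx
    r = httpx.post(
        "https://<inference-space-host>/v1/chat/completions",
        headers={"Authorization": f"Bearer {token}"},
        json={"model": "<ollama-model>",
              "messages": [{"role": "user", "content": "hi"}]},
        timeout=120,
    )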
"""
from __future__ import annotations

import os
from typing import AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from starlette.background import BackgroundTask

OLLAMA_URL = "http://127.0.0.1:11434"
MODELS_URL = "http://127.0.0.1:7861"
PROXY_TOKEN = os.environ.get("RIPRAP_PROXY_TOKEN", "")

app = FastAPI(title="Riprap Inference Proxy")
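

# Every /v1/* route shares the single proxy token; _check_auth raises 503
# when the token is unset on the Space and 401 when the caller's bearer
# token is missing or wrong.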
def _check_auth(request: Request) -> None:
    if not PROXY_TOKEN:
        raise HTTPException(503, "RIPRAP_PROXY_TOKEN not set on the inference Space")
    auth = request.headers.get("authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(401, "missing bearer token")
    if auth.removeprefix("Bearer ").strip() != PROXY_TOKEN:
        raise HTTPException(401, "invalid bearer token")
@app.get("/")
def root():
# HF Spaces hits / for health on idle-wakeup. Don't require auth.
return {"service": "riprap-inference", "ok": True}
@app.get("/healthz")
async def healthz():
out = {"proxy": "ok"}
async with httpx.AsyncClient(timeout=5) as client:
try:
r = await client.get(f"{OLLAMA_URL}/api/tags")
out["ollama"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
except Exception as e:
out["ollama"] = f"err: {type(e).__name__}"
try:
r = await client.get(f"{MODELS_URL}/healthz")
out["riprap_models"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
except Exception as e:
out["riprap_models"] = f"err: {type(e).__name__}"
return out


# ── Ollama (OpenAI-compat) routes ──────────────────────────────────────
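
# Relay raw bytes from the upstream response as they arrive, so SSE chunks
# reach the caller without being re-buffered.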
async def _stream_passthrough(upstream: httpx.Response) -> AsyncIterator[bytes]:
    async for chunk in upstream.aiter_raw():
        yield chunk
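

# Shared POST forwarder. Streaming is detected with a byte-level check for
# `"stream": true` in the request body (a heuristic that avoids parsing the
# JSON but misses unusual whitespace); streamed responses are relayed live,
# all others are buffered and returned whole.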
async def _proxy_post(upstream_base: str, path: str, request: Request,
                      *, timeout: float = 300.0) -> Response:
    body = await request.body()
    headers = {
        "content-type": request.headers.get("content-type", "application/json"),
        "accept": request.headers.get("accept", "*/*"),
    }
    is_stream = b'"stream":true' in body or b'"stream": true' in body
    client = httpx.AsyncClient(timeout=timeout)
    upstream_req = client.build_request(
        "POST", f"{upstream_base}{path}", content=body, headers=headers
    )
    upstream = await client.send(upstream_req, stream=is_stream)
    if is_stream:
        async def _cleanup() -> None:
            # Close the upstream response *and* its client once the relay
            # finishes; closing only the response would leak the client.
            await upstream.aclose()
            await client.aclose()

        return StreamingResponse(
            _stream_passthrough(upstream),
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
            background=BackgroundTask(_cleanup),
        )
    content = await upstream.aread()
    await upstream.aclose()
    await client.aclose()
    return Response(
        content=content,
        status_code=upstream.status_code,
        media_type=upstream.headers.get("content-type", "application/json"),
    )
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(OLLAMA_URL, "/v1/chat/completions", request)
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(OLLAMA_URL, "/v1/completions", request)
@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
"""OpenAI-style embeddings. Routed to riprap-models's granite-embed
endpoint, which returns the same {data: [{embedding: [...]}]} shape."""
_check_auth(request)
return await _proxy_post(MODELS_URL, "/v1/granite-embed", request)
@app.get("/v1/models")
async def models(request: Request) -> Response:
_check_auth(request)
async with httpx.AsyncClient(timeout=10) as client:
r = await client.get(f"{OLLAMA_URL}/v1/models")
return Response(content=r.content, status_code=r.status_code,
media_type=r.headers.get("content-type", "application/json"))


# ── riprap-models (specialist ML) routes ───────────────────────────────
@app.post("/v1/prithvi-pluvial")
async def prithvi_pluvial(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(MODELS_URL, "/v1/prithvi-pluvial", request)
@app.post("/v1/terramind")
async def terramind(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(MODELS_URL, "/v1/terramind", request)
@app.post("/v1/ttm-forecast")
async def ttm_forecast(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(MODELS_URL, "/v1/ttm-forecast", request)
@app.post("/v1/gliner-extract")
async def gliner_extract(request: Request) -> Response:
_check_auth(request)
return await _proxy_post(MODELS_URL, "/v1/gliner-extract", request)


# Catch-all for any riprap-models endpoints not explicitly listed above.
@app.api_route("/v1/{path:path}", methods=["POST"])
async def catch_all(path: str, request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, f"/v1/{path}", request)