"""Riprap Inference Space β€” bearer-auth proxy on port 7860.

Forwards /v1/chat/completions and /v1/embeddings to Ollama on
localhost:11434 (which exposes an OpenAI-compatible surface), and
forwards everything else to riprap-models on localhost:7861.

A single shared secret (env var RIPRAP_PROXY_TOKEN) gates all
inbound calls; clients pass it as `Authorization: Bearer <token>`.
The two UI Spaces (lablab + personal mirror) carry the same token
in their RIPRAP_LLM_API_KEY env var.

Streaming endpoints (SSE for chat completions) are forwarded with
correct chunk-by-chunk relay; non-streaming endpoints are buffered.
"""
from __future__ import annotations

import hmac
import json
import os
from typing import AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from starlette.background import BackgroundTask

OLLAMA_URL = "http://127.0.0.1:11434"
MODELS_URL = "http://127.0.0.1:7861"

PROXY_TOKEN = os.environ.get("RIPRAP_PROXY_TOKEN", "")

app = FastAPI(title="Riprap Inference Proxy")


def _check_auth(request: Request) -> None:
    if not PROXY_TOKEN:
        raise HTTPException(503, "RIPRAP_PROXY_TOKEN not set on the inference Space")
    auth = request.headers.get("authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(401, "missing bearer token")
    # compare_digest keeps the check constant-time, so the token can't be
    # recovered byte-by-byte from response timing.
    if not hmac.compare_digest(auth.removeprefix("Bearer ").strip(), PROXY_TOKEN):
        raise HTTPException(401, "invalid bearer token")


@app.get("/")
def root():
    # HF Spaces hits / for health on idle-wakeup. Don't require auth.
    return {"service": "riprap-inference", "ok": True}


@app.get("/healthz")
async def healthz():
    out = {"proxy": "ok"}
    async with httpx.AsyncClient(timeout=5) as client:
        try:
            r = await client.get(f"{OLLAMA_URL}/api/tags")
            out["ollama"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
        except Exception as e:
            out["ollama"] = f"err: {type(e).__name__}"
        try:
            r = await client.get(f"{MODELS_URL}/healthz")
            out["riprap_models"] = "ok" if r.status_code == 200 else f"http_{r.status_code}"
        except Exception as e:
            out["riprap_models"] = f"err: {type(e).__name__}"
    return out
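
# Probe sketch — / and /healthz are the only unauthenticated routes, so a
# plain curl works (hostname is a placeholder; the response keys mirror the
# dict built above):
#
#     curl -s https://<owner>-riprap-inference.hf.space/healthz
#     # -> {"proxy": "ok", "ollama": "ok", "riprap_models": "ok"}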


# ── Ollama (OpenAI-compat) routes ─────────────────────────────────────
async def _stream_passthrough(upstream: httpx.Response) -> AsyncIterator[bytes]:
    async for chunk in upstream.aiter_raw():
        yield chunk


async def _close_upstream(upstream: httpx.Response, client: httpx.AsyncClient) -> None:
    await upstream.aclose()
    await client.aclose()


async def _proxy_post(upstream_base: str, path: str, request: Request,
                      *, timeout: float = 300.0) -> Response:
    body = await request.body()
    headers = {
        "content-type": request.headers.get("content-type", "application/json"),
        "accept": request.headers.get("accept", "*/*"),
    }
    # Parse the body to detect streaming; a raw substring check such as
    # b'"stream":true' misses legal JSON spellings like '"stream" :true'.
    try:
        is_stream = bool(json.loads(body).get("stream"))
    except (ValueError, AttributeError):
        is_stream = False
    client = httpx.AsyncClient(timeout=timeout)
    upstream_req = client.build_request(
        "POST", f"{upstream_base}{path}", content=body, headers=headers
    )
    try:
        upstream = await client.send(upstream_req, stream=is_stream)
    except httpx.HTTPError as exc:
        await client.aclose()
        raise HTTPException(502, f"upstream error: {type(exc).__name__}")

    if is_stream:
        # Close the upstream response *and* the client once the relay ends
        # (or the caller disconnects); closing only the response would leak
        # the AsyncClient and its connection pool.
        return StreamingResponse(
            _stream_passthrough(upstream),
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
            background=BackgroundTask(_close_upstream, upstream, client),
        )
    content = await upstream.aread()
    await upstream.aclose()
    await client.aclose()
    return Response(
        content=content,
        status_code=upstream.status_code,
        media_type=upstream.headers.get("content-type", "application/json"),
    )
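
# Sketch of consuming the SSE relay from the caller's side (assumes httpx;
# base URL, token, and model name are placeholders, not values from this
# file):
#
#     async with httpx.AsyncClient(timeout=None) as c:
#         async with c.stream(
#             "POST", f"{base}/v1/chat/completions",
#             headers={"Authorization": f"Bearer {token}"},
#             json={"model": "<model-name>", "messages": msgs, "stream": True},
#         ) as resp:
#             async for line in resp.aiter_lines():
#                 if line.startswith("data: "):
#                     ...  # OpenAI-style delta chunks; "data: [DONE]" ends it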


@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(OLLAMA_URL, "/v1/chat/completions", request)


@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(OLLAMA_URL, "/v1/completions", request)


@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
    """OpenAI-style embeddings. Routed to riprap-models's granite-embed
    endpoint, which returns the same {data: [{embedding: [...]}]} shape."""
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/granite-embed", request)
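
# Shape sketch for the embeddings route (the {"data": [...]} envelope is
# stated in the docstring above; the "input" field follows the OpenAI
# convention and is an assumption about riprap-models):
#
#     POST /v1/embeddings   {"input": ["first text", "second text"]}
#     # -> {"data": [{"embedding": [0.01, ...]}, {"embedding": [...]}]}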


@app.get("/v1/models")
async def models(request: Request) -> Response:
    _check_auth(request)
    async with httpx.AsyncClient(timeout=10) as client:
        r = await client.get(f"{OLLAMA_URL}/v1/models")
        return Response(content=r.content, status_code=r.status_code,
                        media_type=r.headers.get("content-type", "application/json"))


# ── riprap-models (specialist ML) routes ──────────────────────────────
@app.post("/v1/prithvi-pluvial")
async def prithvi_pluvial(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/prithvi-pluvial", request)


@app.post("/v1/terramind")
async def terramind(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/terramind", request)


@app.post("/v1/ttm-forecast")
async def ttm_forecast(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/ttm-forecast", request)


@app.post("/v1/gliner-extract")
async def gliner_extract(request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, "/v1/gliner-extract", request)


# Catch-all for any riprap-models endpoints not explicitly listed above.
@app.api_route("/v1/{path:path}", methods=["POST"])
async def catch_all(path: str, request: Request) -> Response:
    _check_auth(request)
    return await _proxy_post(MODELS_URL, f"/v1/{path}", request)
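

# Local launch sketch — the Space's real start command isn't shown in this
# file, so treat the runner below as an assumption; port 7860 comes from
# the module docstring.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)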