File size: 3,552 Bytes
dbd8120
 
 
 
9cc8b76
dbd8120
 
9cc8b76
f348ff6
dbd8120
9cc8b76
 
 
dbd8120
f348ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc8b76
dbd8120
9cc8b76
 
 
 
 
 
 
 
 
 
 
 
 
f348ff6
9cc8b76
 
 
 
f348ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc8b76
 
f348ff6
9cc8b76
f348ff6
 
 
 
 
 
9cc8b76
 
 
 
f348ff6
 
9cc8b76
 
 
 
f348ff6
 
 
9cc8b76
 
 
f348ff6
9cc8b76
 
 
 
f348ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc8b76
 
 
 
 
 
 
f348ff6
 
 
 
9cc8b76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse

# Address of the upstream llama server. Host and port can be overridden
# through the LLAMA_HOST / LLAMA_PORT environment variables.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)

# Header names (lower-case) that must NOT be copied verbatim between the
# client and the llama server, in either direction.
#
# Two groups live here:
#   * framing headers -- the proxy may rewrite the JSON body, which changes
#     its length, so a copied Content-Length would be wrong and produce
#     "Too little data for declared Content-Length" errors;
#   * connection-scoped (hop-by-hop) headers, which describe a single
#     connection and are meaningless on the next hop.
HOP_BY_HOP = {
    # framing / body encoding
    "content-length",
    "transfer-encoding",
    "content-encoding",
    # connection-scoped headers
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    # The real header field is "Trailer"; "Trailers" comes from a typo in the
    # RFC 2616 hop-by-hop list and is kept only for backward compatibility.
    "trailer",
    "trailers",
    "upgrade",
}


def clean_headers(headers):
    """Return a new dict with every header named in ``HOP_BY_HOP`` removed.

    Names are compared case-insensitively; the entries that survive keep
    their original key spelling and value. The input mapping is not modified.
    """
    kept = {}
    for name, value in headers.items():
        if name.lower() in HOP_BY_HOP:
            continue
        kept[name] = value
    return kept


async def wait_for_llama(timeout: float = 600.0, poll_interval: float = 1.0):
    """Poll the llama server's ``/health`` endpoint until it reports ready.

    Args:
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between two consecutive probes.

    Returns:
        True as soon as a probe is answered with HTTP 200, False when
        ``timeout`` elapses without such an answer.
    """
    # monotonic() cannot jump when the wall clock is adjusted, so the
    # timeout stays accurate.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except (httpx.HTTPError, OSError):
                # Server not reachable (yet): keep polling.
                pass
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(poll_interval)
    return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Delay application startup until the llama server passes its health check.

    ``wait_for_llama`` returns False when the server never answered within
    its timeout. Startup still proceeds in that case, but the condition is
    logged instead of being silently discarded.
    """
    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not become healthy; starting anyway",
            LLAMA_URL,
        )
    yield


# One shared client for every upstream call, rooted at the llama server URL.
# timeout=None disables httpx's timeouts for requests made through it.
http_client = httpx.AsyncClient(
    timeout=None,
    base_url=LLAMA_URL,
)

# The ASGI application; `lifespan` is handed to FastAPI as its lifespan handler.
app = FastAPI(lifespan=lifespan)


# The chat page sits next to this module. It is read once at import time and
# kept in memory in CHAT_HTML.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeDecodeError):
    # Missing, unreadable or mis-encoded file: fall back to a placeholder
    # page rather than failing at import time.
    CHAT_HTML = "<h1>Chat UI not found</h1>"


@app.get("/health")
async def health():
    """Report that the proxy process is up; the llama server is not probed."""
    status = {"status": "ok"}
    return status


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat page held in ``CHAT_HTML``."""
    page = HTMLResponse(content=CHAT_HTML)
    return page


@app.get("/api-info")
async def api_info():
    """Return a small JSON document naming the upstream llama server URL."""
    info = {"status": "ok", "llama_server": LLAMA_URL}
    return JSONResponse(content=info)


def _strip_model_field(body):
    """Drop the ``model`` key from a JSON request body and detect streaming.

    Args:
        body: Raw request body bytes as received from the client.

    Returns:
        A ``(body, is_stream)`` tuple. When ``body`` is a JSON object, the
        returned body is that object re-serialized without ``model`` and
        ``is_stream`` is the truthiness of its ``stream`` key. Any other
        input is returned unchanged with ``is_stream`` False.
    """
    try:
        payload = json.loads(body)
        if not isinstance(payload, dict):
            return body, False
        payload.pop("model", None)
        return json.dumps(payload).encode(), bool(payload.get("stream", False))
    except (ValueError, RecursionError):
        # Not JSON, not valid text, or nested too deeply: forward untouched.
        return body, False


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Forward any request not handled by another route to the llama server.

    For ``POST`` requests under ``v1/`` the JSON body is rewritten to drop
    the ``model`` field, and a truthy ``stream`` field switches the reply to
    a streamed relay. In both modes the upstream status code and headers
    (minus the ones in ``HOP_BY_HOP``) are passed back to the client.

    Args:
        request: Incoming client request.
        path: Request path captured by the catch-all route, without the
            leading slash.

    Returns:
        A ``StreamingResponse`` for streaming requests, a plain ``Response``
        otherwise, or a 502 ``JSONResponse`` when the llama server cannot
        be reached.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)
    # The body is relayed decoded and without Content-Encoding, so the
    # client's Accept-Encoding is dropped and httpx's own default applies.
    headers.pop("accept-encoding", None)

    body = await request.body()

    # Detect streaming requests and strip the (ignored) model field
    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _strip_model_field(body)

    upstream_request = http_client.build_request(
        request.method, url, headers=headers, content=body
    )
    try:
        # With stream=True only the response head is read here, so the
        # upstream status and headers are known before the reply is built.
        upstream = await http_client.send(upstream_request, stream=is_stream)
    except httpx.RequestError as exc:
        return JSONResponse(
            {"error": f"llama server unreachable: {exc}"},
            status_code=502,
        )

    response_headers = clean_headers(dict(upstream.headers))
    content_type = upstream.headers.get("content-type")

    if is_stream:
        async def event_stream():
            try:
                # aiter_bytes() yields decoded content, consistent with
                # Content-Encoding being stripped from the relayed headers.
                async for chunk in upstream.aiter_bytes():
                    yield chunk
            finally:
                await upstream.aclose()

        return StreamingResponse(
            event_stream(),
            status_code=upstream.status_code,
            headers=response_headers,
            media_type=content_type or "text/event-stream",
        )

    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=response_headers,
        media_type=content_type,
    )