# File size: 3,264 Bytes
# 803b895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
import json
import os
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse

# Address of the upstream llama server. Host and port are read from the
# environment, falling back to a local server on port 8080.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)

# Lower-cased header names that are never copied across the proxy.
HOP_BY_HOP = {
    # Connection-level headers.
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
    # Body framing/encoding headers, dropped as well -- presumably because
    # the proxy may re-serialize or decode the body it forwards.
    "content-length",
    "content-encoding",
}


def clean_headers(headers):
    """Return a new dict holding only the headers that may be forwarded.

    Args:
        headers: Mapping of header name to value. Names are compared
            case-insensitively against ``HOP_BY_HOP``.

    Returns:
        A dict with the surviving entries, keys left in their original case.
    """
    forwarded = {}
    for name, value in headers.items():
        if name.lower() in HOP_BY_HOP:
            continue
        forwarded[name] = value
    return forwarded


async def wait_for_llama(timeout: float = 600.0) -> bool:
    """Poll the llama server's ``/health`` endpoint until it answers 200.

    A probe is sent roughly once per second, each with a 2 second request
    timeout.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as a probe returns HTTP 200, False if the deadline
        passes without a successful probe.
    """
    # Use the monotonic clock: a wall-clock adjustment (NTP step, manual
    # change) must neither cut the wait short nor extend it.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                probe = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except (httpx.HTTPError, OSError):
                # Server is not accepting connections yet, or the probe
                # timed out; retry after the pause below.
                pass
            else:
                if probe.status_code == 200:
                    return True
            await asyncio.sleep(1)
    return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: wait for the llama server, then clean up on exit.

    Startup blocks until ``wait_for_llama`` reports the backend healthy or
    its timeout expires. The application still starts after a timeout, but
    a warning is logged so the condition is visible. On shutdown the shared
    ``http_client`` is closed so its pooled connections are released.
    """
    import logging  # local import: only needed for the startup warning

    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not become healthy before the startup "
            "timeout; proxied requests will fail until it is up",
            LLAMA_URL,
        )
    try:
        yield
    finally:
        # The module-level client is never closed anywhere else.
        await http_client.aclose()


# Single client shared by every proxied request, so connections to the llama
# server are pooled. timeout=None turns off all httpx timeouts.
http_client = httpx.AsyncClient(timeout=None, base_url=LLAMA_URL)

# The lifespan hook runs before the app starts serving requests.
app = FastAPI(lifespan=lifespan)


# The chat UI is read once at import time from chat.html next to this module.
# A missing or unreadable file degrades to a placeholder page instead of
# preventing startup.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, "r", encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeDecodeError):
    # Only I/O and decoding failures are expected from reading a file; any
    # other exception would indicate a real bug and should surface.
    CHAT_HTML = "<h1>Chat UI not found</h1>"


@app.get("/health")
async def health():
    """Report that the proxy process is up.

    Always returns ``{"status": "ok"}``; the llama server is not contacted.
    """
    status_report = dict(status="ok")
    return status_report


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat page that was loaded into ``CHAT_HTML`` at import time."""
    page = HTMLResponse(content=CHAT_HTML)
    return page


@app.get("/api-info")
async def api_info():
    """Return a small JSON document naming the upstream llama server URL."""
    info = {
        "status": "ok",
        "llama_server": LLAMA_URL,
    }
    return JSONResponse(content=info)


def _rewrite_v1_body(body):
    """Drop the ``model`` key from a JSON object body and detect streaming.

    Args:
        body: Raw request body bytes.

    Returns:
        A ``(body, is_stream)`` tuple. For a JSON object, ``body`` is the
        re-serialized object without ``model`` and ``is_stream`` is the
        truthiness of its ``stream`` key. A body that is not valid JSON, or
        whose top-level value is not an object, is returned untouched with
        ``is_stream`` False.
    """
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return body, False
    if not isinstance(payload, dict):
        return body, False
    payload.pop("model", None)
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Forward any otherwise-unrouted request to the llama server.

    Request headers are filtered through ``clean_headers`` and ``host`` is
    dropped. For ``POST`` requests under ``v1/`` the JSON body has its
    ``model`` key removed; when that body sets ``stream`` the upstream
    response is relayed incrementally, otherwise it is buffered.

    Args:
        request: Incoming request.
        path: Path captured by the catch-all route, without a leading slash.

    Returns:
        The upstream response (status, filtered headers, body), or a 502
        JSON response if the llama server could not be reached.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)

    body = await request.body()

    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _rewrite_v1_body(body)

    upstream_request = http_client.build_request(
        request.method, url, headers=headers, content=body
    )
    try:
        # With stream=False the body is read and the connection released
        # before send() returns; with stream=True we must close it ourselves.
        upstream = await http_client.send(upstream_request, stream=is_stream)
    except httpx.RequestError as exc:
        # Connection refused, DNS failure, protocol error, ...: report a
        # gateway error instead of an unhandled 500.
        return JSONResponse(
            {"error": "llama server request failed", "detail": type(exc).__name__},
            status_code=502,
        )

    response_headers = clean_headers(dict(upstream.headers))

    if is_stream:
        async def event_stream():
            try:
                # aiter_bytes() yields decoded content, matching the fact
                # that content-encoding is stripped from the headers.
                async for chunk in upstream.aiter_bytes():
                    yield chunk
            finally:
                await upstream.aclose()

        # Relay the real upstream status and headers so an upstream error
        # is not presented to the client as a 200 event stream.
        return StreamingResponse(
            event_stream(),
            status_code=upstream.status_code,
            headers=response_headers,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
        )

    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=response_headers,
        media_type=upstream.headers.get("content-type"),
    )