File size: 3,086 Bytes
9df7259
 
 
 
 
 
 
 
1cd335f
9df7259
 
 
 
 
1cd335f
 
 
 
9df7259
1cd335f
 
 
 
9df7259
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd335f
9df7259
 
 
1cd335f
9df7259
1cd335f
 
 
 
 
 
9df7259
1cd335f
 
 
 
 
9df7259
1cd335f
9df7259
1cd335f
 
 
9df7259
 
 
1cd335f
 
9df7259
 
1cd335f
 
9df7259
 
 
1cd335f
9df7259
 
 
1cd335f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse

# Upstream llama server location, overridable through the environment.
LLAMA_HOST = os.getenv("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.getenv("LLAMA_PORT", "8080"))
LLAMA_URL = f"http://{LLAMA_HOST}:{LLAMA_PORT}"

# Lower-case header names that are never copied between client and upstream:
# the classic hop-by-hop headers plus the framing/encoding headers that the
# HTTP libraries on each side recompute for the re-sent message.
HOP_BY_HOP = {
    "connection",
    "content-encoding",
    "content-length",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}


def clean_headers(headers):
    """Return a new dict holding *headers* minus any name listed in HOP_BY_HOP.

    Name matching is case-insensitive; surviving entries keep their original
    spelling, value and iteration order.
    """
    kept = {}
    for name, value in headers.items():
        if name.lower() in HOP_BY_HOP:
            continue
        kept[name] = value
    return kept

async def wait_for_llama(timeout: float = 600.0) -> bool:
    """Poll the llama server's ``/health`` endpoint until it answers HTTP 200.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as ``/health`` returns status 200, False if ``timeout``
        seconds elapse first.
    """
    # Monotonic clock: a wall-clock adjustment must not shorten or extend the wait.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except httpx.HTTPError:
                # Server not reachable yet (connection refused, timeout, ...);
                # keep polling instead of failing startup.
                pass
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(1)
    return False

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: block startup until the llama server reports healthy.

    Startup proceeds even if the wait times out (best effort), but that outcome
    is now logged instead of being silently ignored.
    """
    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not become healthy before the startup "
            "timeout; starting the proxy anyway",
            LLAMA_URL,
        )
    yield

app = FastAPI(lifespan=lifespan)
# Single shared client for all proxied requests, rooted at the llama server.
# timeout=None disables httpx's timeouts entirely — presumably so long
# generations are never cut off; confirm this is intended.
http_client = httpx.AsyncClient(timeout=None, base_url=LLAMA_URL)

# The chat UI page served at "/", read once at import time from the file
# sitting next to this module.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, "r", encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeDecodeError):
    # Missing/unreadable file or non-UTF-8 content: serve a placeholder
    # instead of failing startup. Other exceptions indicate bugs and propagate.
    CHAT_HTML = "<h1>Chat UI not found</h1>"

@app.get("/health")
async def health():
    """Liveness check for the proxy itself; the llama server is not contacted."""
    return dict(status="ok")

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI page loaded at import time."""
    page = CHAT_HTML
    return HTMLResponse(content=page)

@app.get("/api-info")
async def api_info():
    """Report proxy status together with the upstream llama server base URL."""
    info = {"status": "ok"}
    info["llama_server"] = LLAMA_URL
    return JSONResponse(info)

def _prepare_v1_body(body):
    """Strip the ``model`` field from a JSON-object request body.

    Returns a ``(body, is_stream)`` tuple: the re-encoded body and whether the
    client set a truthy ``"stream"`` field. Bodies that are not a JSON object
    (empty, malformed, a list, ...) are returned untouched with
    ``is_stream=False``.
    """
    try:
        payload = json.loads(body)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass it
        return body, False
    if not isinstance(payload, dict):
        return body, False
    # The field is removed before forwarding, so the upstream never sees
    # which model the client named.
    payload.pop("model", None)
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Forward any request not matched by an earlier route to the llama server.

    POST bodies under ``v1/`` are rewritten by ``_prepare_v1_body``. When such a
    body asks for ``"stream": true`` the upstream response is relayed
    incrementally; otherwise it is buffered and returned in one piece. In both
    modes the upstream status code and non-hop-by-hop headers are passed
    through. If the upstream cannot be reached, a 502 JSON error is returned.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)  # let httpx set Host for the upstream server
    body = await request.body()
    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _prepare_v1_body(body)

    upstream_request = http_client.build_request(
        request.method, url, headers=headers, content=body
    )
    try:
        # stream=True returns as soon as the status line and headers arrive,
        # so they can be forwarded before the body is relayed.
        upstream = await http_client.send(upstream_request, stream=is_stream)
    except httpx.TransportError:
        return JSONResponse(
            {"error": "upstream llama server unavailable"}, status_code=502
        )

    response_headers = clean_headers(dict(upstream.headers))

    if is_stream:
        async def event_stream():
            try:
                # aiter_bytes() yields *decoded* bytes, consistent with
                # content-encoding being stripped from the forwarded headers.
                async for chunk in upstream.aiter_bytes():
                    yield chunk
            finally:
                await upstream.aclose()

        return StreamingResponse(
            event_stream(),
            status_code=upstream.status_code,
            headers=response_headers,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
        )

    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=response_headers,
        media_type=upstream.headers.get("content-type"),
    )