File size: 3,552 Bytes
dbd8120 9cc8b76 dbd8120 9cc8b76 f348ff6 dbd8120 9cc8b76 dbd8120 f348ff6 9cc8b76 dbd8120 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 f348ff6 9cc8b76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | import asyncio
import json
import os
import time
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
# Upstream llama server location; host and port may be overridden through
# the LLAMA_HOST / LLAMA_PORT environment variables.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)
# Headers that must NOT be copied verbatim. Stripping framing headers from
# both request and response avoids
# "Too little data for declared Content-Length" errors (we mutate the JSON
# body, which changes its length).
HOP_BY_HOP = {
"content-length",
"transfer-encoding",
"content-encoding",
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailers",
"upgrade",
}
def clean_headers(headers):
return {k: v for k, v in headers.items() if k.lower() not in HOP_BY_HOP}
async def wait_for_llama(timeout: float = 600.0, interval: float = 1.0) -> bool:
    """Poll the llama server's ``/health`` endpoint until it answers HTTP 200.

    Args:
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between failed attempts.

    Returns:
        True as soon as ``/health`` returns 200, False if ``timeout`` seconds
        elapse first.

    Raises:
        Any non-transport httpx error (for example an invalid ``LLAMA_URL``)
        propagates immediately instead of being retried until the timeout.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time().
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except httpx.RequestError:
                pass  # not reachable yet (connection refused, timeout, ...)
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(interval)
    return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wait for the llama server on startup; close the upstream client on shutdown.

    Startup continues even if ``wait_for_llama`` gives up (returns False),
    but a warning is logged so the condition is visible.

    Args:
        app: The FastAPI application (required by the lifespan protocol).
    """
    import logging  # local import: the module header does not import logging

    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not become healthy; starting anyway",
            LLAMA_URL,
        )
    try:
        yield
    finally:
        # http_client is created once at module level; release its
        # connection pool when the application shuts down.
        await http_client.aclose()
app = FastAPI(lifespan=lifespan)

# Shared client for all proxied requests. Read/write/pool timeouts remain
# disabled as before (None), but the connect phase is capped so that an
# unreachable llama server does not hang a request indefinitely.
http_client = httpx.AsyncClient(
    base_url=LLAMA_URL,
    timeout=httpx.Timeout(None, connect=10.0),
)
# The chat UI page is expected next to this module and is read once at
# import time; if it is missing or unreadable, a short placeholder is used.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, "r", encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeDecodeError):
    # Only I/O and decoding failures fall back; anything else is a real bug.
    CHAT_HTML = "<h1>Chat UI not found</h1>"
@app.get("/health")
async def health():
    """Report that this proxy process is up (the llama server is not contacted)."""
    status = {"status": "ok"}
    return status
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI page that was loaded into CHAT_HTML at import time."""
    page = CHAT_HTML
    return HTMLResponse(content=page)
@app.get("/api-info")
async def api_info():
    """Return the proxy status together with the upstream llama server URL."""
    info = dict(status="ok", llama_server=LLAMA_URL)
    return JSONResponse(content=info)
def _rewrite_v1_body(method, path, body):
    """Strip ``model`` from a JSON ``v1/*`` POST body and read its ``stream`` flag.

    Args:
        method: HTTP method of the incoming request.
        path: Request path without the leading slash.
        body: Raw request body bytes.

    Returns:
        A ``(body, is_stream)`` tuple. Only POST bodies under ``v1/`` that
        parse as a JSON object are rewritten; anything else is returned
        unchanged with ``is_stream`` set to False.
    """
    if method != "POST" or not path.startswith("v1/"):
        return body, False
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # Not JSON (ValueError covers JSONDecodeError and UnicodeDecodeError)
        # or pathologically nested: forward the body untouched.
        return body, False
    if not isinstance(payload, dict):
        return body, False
    payload.pop("model", None)  # the model field is ignored upstream
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


async def _proxy_stream(method, url, headers, body):
    """Send a streaming upstream request and relay its body chunk by chunk.

    The upstream status code and cleaned headers are forwarded, so error
    responses from the llama server reach the client as-is. The upstream
    response is closed when the relay finishes or is interrupted.
    """
    upstream_request = http_client.build_request(
        method, url, headers=headers, content=body
    )
    upstream = await http_client.send(upstream_request, stream=True)

    async def relay():
        try:
            # aiter_bytes() (unlike aiter_raw()) undoes any Content-Encoding;
            # required because clean_headers() drops that header.
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            await upstream.aclose()

    return StreamingResponse(
        relay(),
        status_code=upstream.status_code,
        headers=clean_headers(dict(upstream.headers)),
        media_type=upstream.headers.get("content-type", "text/event-stream"),
    )


@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
)
async def proxy(request: Request, path: str):
    """Forward any request not matched by an earlier route to the llama server.

    POST bodies for ``v1/*`` endpoints are parsed as JSON: ``model`` is
    removed and ``stream`` selects whether the reply is relayed incrementally.

    Args:
        request: The incoming client request.
        path: Request path without its leading slash.

    Returns:
        A ``StreamingResponse`` for streaming requests, otherwise a buffered
        ``Response``; both carry the upstream status code and cleaned
        headers. If the upstream request fails at the transport level, a
        502 ``JSONResponse`` is returned instead.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)  # drop the client's Host; the target is LLAMA_URL
    body, is_stream = _rewrite_v1_body(request.method, path, await request.body())

    try:
        if is_stream:
            return await _proxy_stream(request.method, url, headers, body)
        upstream = await http_client.request(
            method=request.method,
            url=url,
            headers=headers,
            content=body,
        )
    except httpx.RequestError as exc:
        # Connection refused, connect timeout, etc.: report a gateway error
        # rather than letting the exception become a generic 500.
        return JSONResponse(
            {"error": "llama server request failed", "detail": str(exc)},
            status_code=502,
        )

    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=clean_headers(dict(upstream.headers)),
        media_type=upstream.headers.get("content-type"),
    )
|