# File size: 3,264 Bytes
import asyncio
import json
import os
import time
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
# Location of the llama server. Host and port come from the environment
# variables LLAMA_HOST / LLAMA_PORT, with local defaults when they are unset.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))

# Base URL every upstream request is built from.
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)
# Lower-case names of the headers that ``clean_headers`` removes before a
# message is relayed in either direction.
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-length",
    "content-encoding",
}


def clean_headers(headers):
    """Return a copy of *headers* without the entries named in HOP_BY_HOP.

    Header names are compared case-insensitively; the surviving entries keep
    their original spelling and values.
    """
    forwarded = {}
    for name, value in headers.items():
        if name.lower() in HOP_BY_HOP:
            continue
        forwarded[name] = value
    return forwarded
async def wait_for_llama(timeout: float = 600.0):
    """Poll the llama server's ``/health`` endpoint until it answers 200.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as a health probe returns HTTP 200, False if the
        deadline passes without a successful probe.
    """
    # Use the monotonic clock so that wall-clock adjustments (NTP, manual
    # changes) can neither cut the wait short nor extend it.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except Exception:
                # Deliberate best-effort: any failed probe just means
                # "not ready yet", so keep polling until the deadline.
                pass
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(1)
    return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: wait for the llama server before serving.

    Startup blocks on ``wait_for_llama()``. If that reports failure the
    application still starts, but a warning is logged so the situation is
    not silent. Nothing is done on shutdown.
    """
    ready = await wait_for_llama()
    if not ready:
        # Local import keeps this block self-contained.
        import logging

        logging.getLogger(__name__).warning(
            "llama server at %s did not report healthy before the startup "
            "timeout; proxied requests may fail until it is up",
            LLAMA_URL,
        )
    yield
# ASGI application; startup handling is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan)

# One shared client for every proxied call, rooted at the llama server.
# ``timeout=None`` disables httpx timeouts for these requests.
http_client = httpx.AsyncClient(timeout=None, base_url=LLAMA_URL)

# The chat page is read once, at import time, from a file beside this module.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")

# Start from the placeholder and overwrite it only if the file can be read.
CHAT_HTML = "<h1>Chat UI not found</h1>"
try:
    with open(CHAT_HTML_PATH, encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except Exception:
    # Any failure to load the page leaves the placeholder in place.
    pass
@app.get("/health")
async def health():
    """Health endpoint of the proxy itself; always answers with status ok."""
    status_report = {"status": "ok"}
    return status_report
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat page markup held in ``CHAT_HTML``."""
    page = HTMLResponse(content=CHAT_HTML)
    return page
@app.get("/api-info")
async def api_info():
    """Report the proxy status together with the llama server base URL."""
    info = {"status": "ok", "llama_server": LLAMA_URL}
    return JSONResponse(content=info)
def _strip_model_field(body):
    """Remove the ``model`` key from a JSON-object body.

    Returns a ``(body, is_stream)`` pair. When *body* is a JSON object it is
    re-serialised without ``model`` and ``is_stream`` reflects the truthiness
    of its ``stream`` key. Anything else (invalid JSON, undecodable bytes,
    a non-object JSON value) is returned untouched with ``is_stream`` False.
    """
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # ValueError covers both JSON syntax errors and undecodable bytes.
        return body, False
    if not isinstance(payload, dict):
        return body, False
    payload.pop("model", None)
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
)
async def proxy(request: Request, path: str):
    """Forward any otherwise unmatched request to the llama server.

    The incoming path and query string are replayed against ``http_client``.
    Request headers are passed through ``clean_headers`` and the ``host``
    header is dropped. For ``POST`` requests under ``v1/`` a JSON-object body
    has its ``model`` key removed; if that body asks for ``stream`` the
    upstream response is relayed chunk by chunk, otherwise it is read fully
    and returned in one piece.

    In both modes the upstream status code and (cleaned) headers are passed
    on to the caller, so upstream errors are not disguised as a successful
    event stream.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)
    body = await request.body()

    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _strip_model_field(body)

    if not is_stream:
        upstream = await http_client.request(
            method=request.method,
            url=url,
            headers=headers,
            content=body,
        )
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            headers=clean_headers(dict(upstream.headers)),
            media_type=upstream.headers.get("content-type"),
        )

    # Open the upstream response *before* building our own, so its status
    # code and headers are known and can be relayed.
    upstream_request = http_client.build_request(
        request.method, url, headers=headers, content=body
    )
    upstream = await http_client.send(upstream_request, stream=True)

    async def event_stream():
        try:
            # Decoded bytes: clean_headers drops Content-Encoding, so the
            # relayed payload must not still be compressed.
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            # Release the upstream connection even if the client goes away.
            await upstream.aclose()

    return StreamingResponse(
        event_stream(),
        status_code=upstream.status_code,
        headers=clean_headers(dict(upstream.headers)),
        # Fall back to the event-stream type only if upstream named none.
        media_type=upstream.headers.get("content-type", "text/event-stream"),
    )
|