import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import BackgroundTasks, FastAPI, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
|
# Location of the upstream llama server; both parts can be overridden through
# the LLAMA_HOST / LLAMA_PORT environment variables.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)
|
|
| |
| |
| |
| |
# Lower-case header names that are never copied between the client and the
# llama server.
#
# Most are hop-by-hop fields that a proxy must not forward (RFC 9110 §7.6.1).
# "content-length" and "content-encoding" are not hop-by-hop. They are dropped
# for two reasons:
#   - the proxy may rewrite request bodies, so the client's length would be stale;
#   - httpx's `Response.content` is already decompressed.
# Forwarding the original values would therefore misdescribe the body.
HOP_BY_HOP = {
    "content-length",
    "transfer-encoding",
    "content-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",   # the actual field name defined by the HTTP spec
    "trailers",  # the misspelling from RFC 2616's list; kept for compatibility
    "upgrade",
}
|
|
|
|
def clean_headers(headers, *, hop_by_hop=None):
    """Return a copy of *headers* without the fields that must not cross the proxy.

    Drops every field named in ``hop_by_hop`` (the module-level ``HOP_BY_HOP``
    set when not given). It also drops every field the sender listed as a
    connection option in its ``Connection`` header, which RFC 9110 §7.6.1
    requires of proxies. Name matching is case-insensitive.

    Args:
        headers: Mapping of header name to string value (anything with ``.items()``).
        hop_by_hop: Optional collection of lower-case field names to drop
            instead of ``HOP_BY_HOP``.

    Returns:
        A new ``dict`` of the surviving headers, original key spelling preserved.
    """
    if hop_by_hop is None:
        hop_by_hop = HOP_BY_HOP
    # "Connection: keep-alive, X-Internal" also marks X-Internal as hop-by-hop.
    connection_options = {
        option.strip().lower()
        for name, value in headers.items()
        if name.lower() == "connection"
        for option in value.split(",")
    }
    drop = set(hop_by_hop) | connection_options
    return {k: v for k, v in headers.items() if k.lower() not in drop}
|
|
|
|
async def wait_for_llama(timeout: float = 600.0, interval: float = 1.0) -> bool:
    """Poll the llama server's ``/health`` endpoint until it answers 200 OK.

    Args:
        timeout: Total seconds to keep polling before giving up.
        interval: Seconds to sleep after each unsuccessful probe.

    Returns:
        True as soon as a probe returns HTTP 200; False once ``timeout`` elapses.
    """
    # monotonic() is immune to wall-clock jumps (NTP sync, manual changes) that
    # could otherwise stretch or cut short the wait.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except httpx.HTTPError:
                # Server not accepting connections yet, or the probe timed out: retry.
                pass
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(interval)
    return False
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wait for the llama server on startup; close the shared proxy client on shutdown.

    If ``wait_for_llama`` gives up, a warning is logged and the app starts anyway.
    Proxied requests may then fail until the upstream comes up.
    """
    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not report healthy; starting anyway", LLAMA_URL
        )
    try:
        yield
    finally:
        # The module-level http_client is otherwise never closed, leaking its
        # pooled upstream connections at shutdown.
        await http_client.aclose()
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Shared client for all proxied requests.
# Read, write and pool waits stay unbounded so long-running responses (e.g.
# streamed completions) are never cut off. Establishing a connection is capped,
# so requests fail fast instead of hanging when the llama server is unreachable.
http_client = httpx.AsyncClient(
    base_url=LLAMA_URL,
    timeout=httpx.Timeout(None, connect=10.0),
)
|
|
|
|
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")


def _read_chat_html(path):
    """Return the UTF-8 text of the chat page at *path*, or a placeholder if it can't be read."""
    try:
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()
    except Exception:
        return "<h1>Chat UI not found</h1>"


# Read once at import time and served as-is by the "/" route.
CHAT_HTML = _read_chat_html(CHAT_HTML_PATH)
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe for this proxy process; it does not contact the llama server."""
    status = {"status": "ok"}
    return status
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI page loaded at import time (or its placeholder text)."""
    page = CHAT_HTML
    return HTMLResponse(content=page)
|
|
|
|
@app.get("/api-info")
async def api_info():
    """Report that the proxy is up and which llama server URL it forwards to."""
    info = {"status": "ok", "llama_server": LLAMA_URL}
    return JSONResponse(content=info)
|
|
|
|
def _prepare_body(method, path, body):
    """Normalise an OpenAI-style request body before it is forwarded.

    Only ``POST`` requests under ``v1/`` whose body is a JSON object are
    touched. For those, the ``model`` field is removed (presumably because the
    upstream serves a single model; confirm) and the ``stream`` flag is read.
    Every other request, or a body that is not a JSON object, passes through
    unchanged.

    Returns:
        ``(body, is_stream)``: the possibly re-encoded body bytes and whether
        the upstream response should be relayed as a stream.
    """
    if method != "POST" or not path.startswith("v1/"):
        return body, False
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return body, False
    if not isinstance(payload, dict):
        # e.g. a JSON array: nothing to strip, forward verbatim.
        return body, False
    payload.pop("model", None)
    is_stream = bool(payload.get("stream", False))
    return json.dumps(payload).encode(), is_stream


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Catch-all route that forwards the request to the llama server.

    Hop-by-hop headers are stripped in both directions and ``Host`` is left for
    httpx to set. When the (possibly rewritten) body asks for ``stream``, the
    upstream body is relayed chunk by chunk. Otherwise it is buffered. In both
    cases the upstream status code and headers are passed through.

    Returns:
        The relayed upstream response, or a 502 JSON error when the llama
        server cannot be reached.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)

    body, is_stream = _prepare_body(request.method, path, await request.body())

    upstream_request = http_client.build_request(
        request.method, url, headers=headers, content=body
    )
    try:
        upstream = await http_client.send(upstream_request, stream=is_stream)
    except httpx.TransportError:
        return JSONResponse(
            {"error": "upstream llama server unreachable"},
            status_code=502,
        )

    response_headers = clean_headers(dict(upstream.headers))

    if is_stream:
        # Release the upstream connection once the client response has finished.
        cleanup = BackgroundTasks()
        cleanup.add_task(upstream.aclose)
        return StreamingResponse(
            # aiter_bytes() decodes any content-encoding, matching the dropped
            # content-encoding header (aiter_raw() would relay compressed bytes).
            upstream.aiter_bytes(),
            status_code=upstream.status_code,
            headers=response_headers,
            media_type=upstream.headers.get("content-type", "text/event-stream"),
            background=cleanup,
        )

    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=response_headers,
        media_type=upstream.headers.get("content-type"),
    )
|
|