| import asyncio |
| import json |
| import os |
| import time |
| from contextlib import asynccontextmanager |
|
|
| import httpx |
| from fastapi import FastAPI, Request, Response |
| from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse |
|
|
# Location of the upstream llama.cpp server. Host and port can each be
# overridden through the LLAMA_HOST / LLAMA_PORT environment variables.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)
|
|
# Header fields that must not be copied from one side of the proxy to the
# other. Most are hop-by-hop (RFC 9110 section 7.6.1). The framing/encoding
# fields (content-length, content-encoding, transfer-encoding) are included
# because the proxy may rewrite request bodies and relays decoded upstream
# bodies, so the original framing no longer applies.
# "trailers" is kept for compatibility (it appears in RFC 2616's list, a
# known typo); "trailer" is the actual field name.
HOP_BY_HOP = {
    "content-length",
    "transfer-encoding",
    "content-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "trailer",
    "upgrade",
}


def clean_headers(headers):
    """Return a new dict of *headers* without fields that must not be forwarded.

    Removes every name in ``HOP_BY_HOP`` and also every field named as an
    option in a ``Connection`` header. Proxies are required to drop those
    as well (RFC 9110 section 7.6.1). Name matching is case-insensitive;
    surviving keys keep their original spelling and values.

    Args:
        headers: any mapping with an ``items()`` method (name -> value).

    Returns:
        A plain ``dict`` containing only the forwardable headers.
    """
    # Collect connection-options, e.g. "Connection: keep-alive, X-Foo" means
    # the X-Foo header is scoped to this hop only.
    connection_options = {
        option.strip().lower()
        for name, value in headers.items()
        if name.lower() == "connection"
        for option in value.split(",")
    }
    drop = HOP_BY_HOP | connection_options
    return {k: v for k, v in headers.items() if k.lower() not in drop}
|
|
async def wait_for_llama(timeout: float = 600.0) -> bool:
    """Poll ``{LLAMA_URL}/health`` roughly once per second until it answers 200.

    Each probe has its own 2-second request timeout. Transport failures
    (e.g. connection refused while the server is still starting) count as
    "not ready yet" and are retried.

    Args:
        timeout: overall time budget in seconds. It is checked between
            probes, so the real wait can overrun it by about one probe plus
            one sleep.

    Returns:
        True as soon as a probe gets HTTP 200, False if the budget runs out.
    """
    # Monotonic clock: wall-clock adjustments (NTP, manual changes) must not
    # shorten or stretch the wait.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except httpx.HTTPError:
                pass  # server not reachable/ready yet; retry after the sleep
            else:
                if r.status_code == 200:
                    return True
            await asyncio.sleep(1)
    return False
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler for the proxy.

    On startup it waits for the llama.cpp server via ``wait_for_llama()``.
    The app starts serving either way, but a warning is logged if the
    upstream never reported healthy. On shutdown it closes the shared
    module-level ``http_client`` so its pooled connections are released.
    """
    import logging  # local: only needed for the not-ready warning

    if not await wait_for_llama():
        logging.getLogger(__name__).warning(
            "llama server at %s did not become healthy; starting anyway",
            LLAMA_URL,
        )
    try:
        yield
    finally:
        await http_client.aclose()
|
|
app = FastAPI(lifespan=lifespan)
# Shared connection pool used by the proxy route for every upstream call.
# Read/write/pool waits stay unbounded because generations can take a long
# time. Establishing a TCP connection to the upstream is capped at 10 s
# instead of being left unbounded.
http_client = httpx.AsyncClient(
    base_url=LLAMA_URL,
    timeout=httpx.Timeout(None, connect=10.0),
)
|
|
# The chat UI is read once at import time from "chat.html" next to this
# module. If the file is missing, unreadable or not valid UTF-8, a
# placeholder page is served instead. Only those expected errors are
# caught, so genuine bugs are not masked.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, "r", encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeDecodeError):
    CHAT_HTML = "<h1>Chat UI not found</h1>"
|
|
@app.get("/health")
async def health():
    """Liveness check for the proxy process itself.

    It is registered before the catch-all proxy route, so /health is
    answered here and is not forwarded to the llama server.
    """
    payload = {"status": "ok"}
    return payload
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI page loaded at import time (or its placeholder)."""
    page = CHAT_HTML
    return HTMLResponse(content=page)
|
|
@app.get("/api-info")
async def api_info():
    """Report that the proxy is up and which llama server URL it targets."""
    info = dict(status="ok", llama_server=LLAMA_URL)
    return JSONResponse(content=info)
|
|
def _prepare_v1_body(body):
    """Drop ``model`` from a JSON-object request body and detect ``stream``.

    Returns ``(body, is_stream)``. A body that is not valid JSON, or whose
    top-level value is not an object, is returned unchanged with
    ``is_stream`` False.
    """
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):  # empty/invalid JSON or bad encoding
        return body, False
    if not isinstance(payload, dict):
        return body, False
    # The client-supplied model name is not forwarded. Presumably the upstream
    # serves whatever model it was started with; confirm against deployment.
    payload.pop("model", None)
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


def _upstream_unavailable(exc):
    """Build the 502 response returned when the upstream request fails."""
    return JSONResponse(
        {"error": f"upstream llama server unavailable: {type(exc).__name__}: {exc}"},
        status_code=502,
    )


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Forward any request not matched by an earlier route to the llama server.

    The path and raw query string are replayed through ``http_client``.
    Hop-by-hop headers and ``Host`` are removed first. For ``POST /v1/...``
    requests with a JSON-object body, the ``model`` field is stripped and
    the ``stream`` flag picks how the reply is relayed:

    * streaming: the upstream response is opened first, so its real status
      code and headers are returned; its decoded body is then relayed chunk
      by chunk.
    * otherwise: the full upstream response is buffered and returned.

    If the upstream request fails, a 502 with a JSON error body is returned.
    """
    # Keep the query string exactly as the client sent it.
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)  # let httpx set Host for the upstream URL
    body = await request.body()
    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _prepare_v1_body(body)

    if is_stream:
        upstream_request = http_client.build_request(
            request.method, url, headers=headers, content=body
        )
        try:
            # Open the response before returning, so upstream errors (4xx/5xx)
            # reach the client with their real status instead of a fake 200.
            upstream = await http_client.send(upstream_request, stream=True)
        except httpx.RequestError as exc:
            return _upstream_unavailable(exc)

        async def event_stream():
            try:
                # aiter_bytes() undoes any content-encoding. content-encoding
                # is stripped from the forwarded headers, so raw compressed
                # bytes would be unreadable to the client.
                async for chunk in upstream.aiter_bytes():
                    yield chunk
            finally:
                await upstream.aclose()

        return StreamingResponse(
            event_stream(),
            status_code=upstream.status_code,
            headers=clean_headers(dict(upstream.headers)),
            media_type=upstream.headers.get("content-type", "text/event-stream"),
        )

    try:
        upstream = await http_client.request(
            method=request.method, url=url, headers=headers, content=body
        )
    except httpx.RequestError as exc:
        return _upstream_unavailable(exc)
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=clean_headers(dict(upstream.headers)),
        media_type=upstream.headers.get("content-type"),
    )
|
|