# NanoBotAIAgent's picture
# Upgrade proxy to match standard (chat UI, hop-by-hop, api-info)
# 1cd335f verified
import asyncio
import json
import os
import time
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
# Address of the llama server; both parts can be overridden through the
# environment variables of the same name.
LLAMA_HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
LLAMA_PORT = int(os.environ.get("LLAMA_PORT", "8080"))
# Base URL that every upstream request is built on.
LLAMA_URL = "http://{}:{}".format(LLAMA_HOST, LLAMA_PORT)
# Lower-case names of headers that are dropped instead of being copied between
# the client and the upstream server.
#
# Besides the connection-level (hop-by-hop) fields, the body-framing headers
# "content-length", "transfer-encoding" and "content-encoding" are listed too.
# The real header field is "Trailer"; the historical spelling "trailers" is
# kept as well so nothing that was filtered before slips through now.
HOP_BY_HOP = {
    "content-length",
    "transfer-encoding",
    "content-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "upgrade",
}


def clean_headers(headers):
    """Return a new dict holding *headers* minus the ``HOP_BY_HOP`` entries.

    Args:
        headers: Mapping of header name to header value.

    Returns:
        A plain ``dict`` with the surviving items. Names are compared
        case-insensitively, but the original spelling of each kept name is
        preserved. The input mapping is not modified.
    """
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in HOP_BY_HOP
    }
async def wait_for_llama(timeout: float = 600.0, poll_interval: float = 1.0):
    """Poll ``{LLAMA_URL}/health`` until it answers with HTTP 200.

    Args:
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between two consecutive attempts.

    Returns:
        ``True`` as soon as one health request returns status 200, ``False``
        if *timeout* elapses without that happening.
    """
    # A monotonic clock keeps the deadline unaffected by wall-clock changes
    # (NTP corrections, manual adjustments) that happen while waiting.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                reply = await client.get(f"{LLAMA_URL}/health", timeout=2)
            except httpx.HTTPError:
                # Connection refused / timed out: treat as "not ready yet"
                # and try again after the pause below.
                pass
            else:
                if reply.status_code == 200:
                    return True
            await asyncio.sleep(poll_interval)
    return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: wait for the llama server, then serve, then clean up.

    Startup blocks on ``wait_for_llama()``. Its boolean result is not
    inspected, so the application starts even when the wait timed out.

    On shutdown the shared module-level ``http_client`` is closed so its
    connection pool is released instead of being left open.
    NOTE(review): the client is created once at import time; if the same
    process runs this lifespan more than once, later runs will find the
    client already closed -- confirm that is acceptable for your test setup.
    """
    await wait_for_llama()
    try:
        yield
    finally:
        await http_client.aclose()
# Single AsyncClient shared by all proxied requests. Relative URLs are resolved
# against LLAMA_URL, and timeout=None disables httpx's timeouts for it.
http_client = httpx.AsyncClient(
    timeout=None,
    base_url=LLAMA_URL,
)

# The ASGI application; `lifespan` wraps its startup and shutdown.
app = FastAPI(lifespan=lifespan)
# The chat page is read once, at import time, from "chat.html" located in the
# same directory as this module.
CHAT_HTML_PATH = os.path.join(os.path.dirname(__file__), "chat.html")
try:
    with open(CHAT_HTML_PATH, "r", encoding="utf-8") as _f:
        CHAT_HTML = _f.read()
except (OSError, UnicodeError):
    # File missing, unreadable, or not valid UTF-8: fall back to a placeholder
    # page. Anything else is a programming error and is allowed to surface.
    CHAT_HTML = "<h1>Chat UI not found</h1>"
@app.get("/health")
async def health():
    """Answer ``GET /health`` with a fixed ``{"status": "ok"}`` payload.

    The reply is constant; no request is made to the llama server here.
    """
    payload = {"status": "ok"}
    return payload
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat page text held in ``CHAT_HTML`` for ``GET /``."""
    page = HTMLResponse(content=CHAT_HTML)
    return page
@app.get("/api-info")
async def api_info():
    """Answer ``GET /api-info`` with a status flag and the upstream base URL."""
    info = {
        "status": "ok",
        "llama_server": LLAMA_URL,
    }
    return JSONResponse(content=info)
def _strip_model_field(body):
    """Drop the ``"model"`` key from a JSON object body.

    Args:
        body: Raw request body as ``bytes``.

    Returns:
        ``(new_body, is_stream)``. When *body* is a JSON object, ``new_body``
        is that object re-serialized without ``"model"`` and ``is_stream`` is
        the truthiness of its ``"stream"`` value. For anything else (invalid
        JSON, non-object JSON) the body is returned untouched with
        ``is_stream`` set to ``False``.
    """
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # Not decodable / not JSON / pathologically nested: forward as-is.
        return body, False
    if not isinstance(payload, dict):
        return body, False
    payload.pop("model", None)
    return json.dumps(payload).encode(), bool(payload.get("stream", False))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str):
    """Forward any request not handled by another route to the llama server.

    The request is re-issued through ``http_client`` with the same method,
    path and query string. Headers listed in ``HOP_BY_HOP`` and ``Host`` are
    removed first. For ``POST`` requests whose path starts with ``v1/`` the
    JSON body is rewritten by ``_strip_model_field``.

    Args:
        request: The incoming request.
        path: The matched path, without its leading slash.

    Returns:
        * a ``StreamingResponse`` relaying the upstream body chunk by chunk
          when the rewritten JSON body asked for ``"stream"``;
        * otherwise a ``Response`` carrying the fully read upstream body;
        * a 502 ``JSONResponse`` when the upstream request could not be
          completed at the transport level.

        In the first two cases the upstream status code and its headers
        (minus ``HOP_BY_HOP``) are passed on to the client.
    """
    url = httpx.URL(path="/" + path, query=request.url.query.encode("utf-8"))
    headers = clean_headers(dict(request.headers))
    headers.pop("host", None)
    body = await request.body()

    is_stream = False
    if request.method == "POST" and path.startswith("v1/"):
        body, is_stream = _strip_model_field(body)

    try:
        if is_stream:
            # Open the upstream response before building ours, so that its
            # real status code and headers can be relayed instead of a
            # hard-coded "200 text/event-stream".
            upstream_request = http_client.build_request(
                request.method, url, headers=headers, content=body
            )
            upstream = await http_client.send(upstream_request, stream=True)
        else:
            upstream = await http_client.request(
                method=request.method, url=url, headers=headers, content=body
            )
    except httpx.RequestError as exc:
        # Connection refused, DNS failure, protocol error, ...: the gateway
        # itself is fine, the upstream is not.
        return JSONResponse(
            {"error": f"upstream request failed: {type(exc).__name__}"},
            status_code=502,
        )

    response_headers = clean_headers(dict(upstream.headers))

    if not is_stream:
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            headers=response_headers,
            media_type=upstream.headers.get("content-type"),
        )

    async def event_stream():
        # aiter_bytes() yields decoded content, which matches the fact that
        # "content-encoding" has been stripped from the relayed headers.
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            # Always release the upstream connection, including when the
            # client disconnects in the middle of the stream.
            await upstream.aclose()

    return StreamingResponse(
        event_stream(),
        status_code=upstream.status_code,
        headers=response_headers,
        # Keep the previous default when the upstream sends no content type.
        media_type=upstream.headers.get("content-type", "text/event-stream"),
    )