"""
ChromaDB Auth Proxy (robust passthrough)
- Bearer auth at the edge
- Streams/buffers appropriately
- Preserves Content-Type, avoids JSON re-serialization
- Reuses a single AsyncClient (HTTP/2, pooled)
- Filters hop-by-hop headers
- Maps network errors to 502/504
"""
| |
|
import asyncio
import logging
import os
import secrets
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict

import httpx
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
| |
|
| | |
# Root logging configuration; the module logger below is used for
# per-request tracing throughout the proxy.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
| |
|
| | |
| | |
| | |
# Upstream ChromaDB location and the port this proxy listens on.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8001"))
PROXY_PORT = int(os.getenv("PROXY_PORT", "7860"))
# Edge Bearer token. NOTE(review): the fallback is a test value — set
# CHROMA_AUTH_TOKEN in any real deployment.
AUTH_TOKEN = os.getenv("CHROMA_AUTH_TOKEN", "test_token_123")

# httpx timeout budget (seconds). The long read timeout accommodates slow
# ChromaDB operations (see the DELETE warning in proxy_request); pool=None
# means wait indefinitely for a free pooled connection.
TIMEOUT_CONNECT = 10.0
TIMEOUT_READ = 60.0 * 8
TIMEOUT_WRITE = 60.0 * 2
TIMEOUT_POOL = None
| |
|
| | |
| | |
| | |
| | security = HTTPBearer() |
| |
|
| |
|
| | |
| | |
| | |
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application lifespan - startup and shutdown.

    Closes the shared httpx AsyncClient on shutdown. The close is placed in
    a ``finally`` so pooled connections are released even if the serving
    phase exits via an exception, not only on a clean shutdown.
    """
    logger.info("π Starting ChromaDB Auth Proxy lifespan")
    try:
        yield
    finally:
        logger.info("π Shutting down ChromaDB Auth Proxy")
        await _client.aclose()
| |
|
| |
|
| | app = FastAPI(title="ChromaDB Auth Proxy", lifespan=lifespan) |
| |
|
| |
|
@app.get("/")
async def root():
    """Report that the proxy process itself is up (no auth, no upstream call)."""
    payload = {"status": "ok", "service": "chromadb-auth-proxy"}
    return payload
| |
|
| |
|
@app.get("/health")
async def health():
    """Health-check endpoint for orchestrators (does not probe ChromaDB)."""
    payload = {"status": "healthy", "service": "chromadb-auth-proxy"}
    return payload
| |
|
| |
|
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the presented Bearer token against AUTH_TOKEN.

    Returns:
        The credentials object, so routes may depend on it if needed.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token does not match.
    """
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of the token matched (timing-attack hardening).
    if not secrets.compare_digest(credentials.credentials, AUTH_TOKEN):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials
| |
|
| |
|
| | |
| | |
| | |
| | |
# Single shared AsyncClient: HTTP/2-capable, pooled connections reused across
# all proxied requests. Closed by the lifespan shutdown hook.
_client = httpx.AsyncClient(
    http2=True,
    timeout=httpx.Timeout(
        connect=TIMEOUT_CONNECT,
        read=TIMEOUT_READ,
        write=TIMEOUT_WRITE,
        pool=TIMEOUT_POOL,
    ),
    limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
| |
|
| |
|
| | |
# Hop-by-hop headers (RFC 7230 §6.1): meaningful only for a single transport
# link, so a proxy must never forward them end-to-end.
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
}
| |
|
| | |
# Allow-list of upstream response headers worth relaying to the client
# (lowercase for case-insensitive matching in _filter_resp_headers).
PASS_HEADERS = {
    "content-type",
    "cache-control",
    "etag",
    "last-modified",
    "expires",
    "vary",
    "location",
    "content-disposition",
    "content-encoding",
    "x-chroma-trace-id",
}
| |
|
| |
|
def _filter_resp_headers(upstream: httpx.Response) -> Dict[str, str]:
    """Drop hop-by-hop and computed headers; keep useful ones."""
    return {
        name: value
        for name, value in upstream.headers.items()
        if name.lower() not in HOP_BY_HOP and name.lower() in PASS_HEADERS
    }
| |
|
| |
|
@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
)
async def proxy_request(request: Request, path: str, _=Depends(verify_token)):
    """Forward an authenticated request to ChromaDB and relay the response.

    JSON responses are buffered and returned whole; all other content types
    are streamed through chunk by chunk. Timeouts map to 504, connect and
    transport failures to 502, anything unexpected to 500.
    """
    start_time = time.time()
    target_url = f"http://{CHROMA_HOST}:{CHROMA_PORT}/{path}"

    if request.method == "DELETE":
        logger.warning(
            f"β οΈ DELETE operation - may take up to {int(TIMEOUT_READ)}s for large collections"
        )

    params = dict(request.query_params)

    # Drop "host" (it names the proxy, not Chroma) and "authorization" (the
    # edge token must not leak upstream); forward everything else verbatim.
    fwd_headers = {}
    for k, v in request.headers.items():
        if k.lower() in ("host", "authorization"):
            continue
        fwd_headers[k] = v

    # Buffer the request body for methods that carry one.
    body = None
    if request.method in {"POST", "PUT", "PATCH"}:
        body = await request.body()
        logger.info(f" Request body size: {len(body)} bytes")

    try:
        upstream_start = time.time()

        # BUGFIX: the previous version opened the upstream inside
        # `async with _client.stream(...)`, which closed the response the
        # moment the handler returned — before FastAPI ever consumed the
        # StreamingResponse. Open it without a context manager instead, and
        # close it explicitly on every exit path below.
        req = _client.build_request(
            method=request.method,
            url=target_url,
            params=params,
            headers=fwd_headers,
            content=body,
        )
        upstream = await _client.send(req, stream=True)

        upstream_time = time.time() - upstream_start
        status = upstream.status_code
        resp_headers = _filter_resp_headers(upstream)

        logger.info(f" β Upstream response: {status} (took {upstream_time:.2f}s)")

        # HEAD and 204 responses carry no body: close upstream immediately.
        if request.method == "HEAD" or status == 204:
            await upstream.aclose()
            total_time = time.time() - start_time
            logger.info(f" π€ Returning HEAD/204 response (total: {total_time:.2f}s)")
            return Response(status_code=status, headers=resp_headers)

        ctype = upstream.headers.get("content-type", "")

        # Buffer JSON fully so clients receive one complete payload.
        if ctype.startswith("application/json"):
            json_start = time.time()
            try:
                data = await upstream.aread()
            finally:
                await upstream.aclose()
            # BUGFIX: aread() returns the DECODED body (httpx strips
            # gzip/deflate), so relaying Content-Encoding here would describe
            # bytes we are no longer sending.
            resp_headers = {
                k: v for k, v in resp_headers.items() if k.lower() != "content-encoding"
            }
            json_time = time.time() - json_start
            total_time = time.time() - start_time
            logger.info(
                f" π€ Returning JSON response: {len(data)} bytes (json: {json_time:.2f}s, total: {total_time:.2f}s)"
            )
            return Response(
                content=data,
                status_code=status,
                headers=resp_headers,
                media_type=ctype,
            )

        # Everything else: stream the raw (still-encoded) bytes unchanged, so
        # the relayed Content-Encoding header stays accurate.
        async def _aiter():
            chunk_count = 0
            total_bytes = 0
            try:
                async for chunk in upstream.aiter_raw():
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                        # Yield control so one large response cannot starve
                        # the event loop.
                        await asyncio.sleep(0)
                logger.info(f" π€ Streamed {chunk_count} chunks, {total_bytes} bytes")
            finally:
                # Runs on normal completion AND on client disconnect /
                # generator close, returning the connection to the pool.
                await upstream.aclose()

        return StreamingResponse(
            _aiter(),
            status_code=status,
            headers=resp_headers,
            media_type=ctype or None,
        )

    except httpx.ConnectTimeout:
        total_time = time.time() - start_time
        logger.error(f" β Connect timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream connect timeout")
    except httpx.ReadTimeout:
        total_time = time.time() - start_time
        logger.error(f" β Read timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream read timeout")
    except httpx.ConnectError as e:
        total_time = time.time() - start_time
        logger.error(f" β Connect error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream connect error: {e}"
        )
    except httpx.TransportError as e:
        # Catch-all for remaining network-level failures (ReadTimeout and
        # ConnectError above are subclasses, so order matters).
        total_time = time.time() - start_time
        logger.error(f" β Transport error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream transport error: {e}"
        )
    except Exception as e:
        total_time = time.time() - start_time
        logger.error(f" β Unexpected error after {total_time:.2f}s: {e}")
        raise HTTPException(status_code=500, detail=f"Internal proxy error: {e}")
| |
|
| |
|
if __name__ == "__main__":
    # Console banner for operators starting the proxy directly; uvicorn then
    # takes over the process and serves the app defined above.
    print("π Starting ChromaDB Auth Proxy")
    print(f" Proxy URL: http://0.0.0.0:{PROXY_PORT}")
    print(f" ChromaDB URL: http://{CHROMA_HOST}:{CHROMA_PORT}")
    print(
        f" Timeouts: connect={int(TIMEOUT_CONNECT)}s, read={int(TIMEOUT_READ)}s, write={int(TIMEOUT_WRITE)}s"
    )
    print(f" Logging: INFO level")
    logger.info("ChromaDB Auth Proxy starting up")
    uvicorn.run(app, host="0.0.0.0", port=PROXY_PORT)
| |
|