# Source path: GradioDemo/shared/utils/load_balancer.py
# Author: eigentom
# Commit: 90c099b ("Initial Update")
"""
Simple Python Load Balancer
A lightweight load balancer for vLLM and Reranker services.
Uses FastAPI to forward requests to multiple backend services.
"""
import os
import sys
import time
import asyncio
import httpx
from pathlib import Path
from typing import List, Dict, Optional, Any
from threading import Lock
import argparse
try:
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
HAS_FASTAPI = True
except ImportError:
HAS_FASTAPI = False
print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
class SimpleLoadBalancer:
    """Spread HTTP requests over a fixed list of backend services.

    Two selection strategies are implemented: ``"round_robin"`` (also used
    for any unrecognised strategy name) and ``"least_conn"``. Per-backend
    health flags are refreshed by :meth:`check_all_backends` and consulted
    by :meth:`get_backend`. Shared state is guarded by ``threading.Lock``
    objects, and a single ``httpx.AsyncClient`` is reused for upstream calls.
    """

    # Client headers that are not forwarded to the backend. ``accept-encoding``
    # is dropped so httpx advertises only the encodings it can decode itself;
    # the relayed body is the decoded one.
    _SKIPPED_REQUEST_HEADERS = frozenset(
        {"host", "connection", "content-length", "accept-encoding"}
    )

    # Backend headers that are not relayed to the client: hop-by-hop headers,
    # plus ``content-length`` / ``content-encoding`` which describe the
    # on-the-wire body rather than the decoded bytes that are relayed.
    _SKIPPED_RESPONSE_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
            "content-length",
            "content-encoding",
        }
    )

    def __init__(
        self,
        backends: List[str],
        strategy: str = "round_robin",
        health_check_interval: float = 10.0,
    ):
        """Initialize the load balancer.

        Args:
            backends: Backend base URLs
                (e.g. ``["http://localhost:8000", "http://localhost:8001"]``).
            strategy: ``"round_robin"`` or ``"least_conn"``.
            health_check_interval: Seconds to sleep between health sweeps.
        """
        self.backends = backends
        self.strategy = strategy
        self.health_check_interval = health_check_interval

        # Round-robin state
        self.current_index = 0
        self.index_lock = Lock()

        # Least-connection state: in-flight request count per backend
        self.connection_counts: Dict[str, int] = {backend: 0 for backend in backends}
        self.conn_lock = Lock()

        # Health state: every backend starts out assumed healthy
        self.healthy_backends: Dict[str, bool] = {backend: True for backend in backends}
        self.health_lock = Lock()

        # Shared HTTP client used for health probes and request forwarding
        self.client = httpx.AsyncClient(timeout=300.0)

        print(f"Load balancer initialized with {len(backends)} backends")
        print(f"Strategy: {strategy}")
        for i, backend in enumerate(backends):
            print(f" [{i+1}] {backend}")

    def get_backend(self) -> Optional[str]:
        """Pick the next backend according to the configured strategy.

        Only backends currently flagged healthy are considered; if none are,
        every configured backend is tried instead (fail open). For
        ``"least_conn"`` the chosen backend's in-flight counter is
        incremented, so the caller must later call :meth:`release_backend`.

        Returns:
            The selected backend URL, or ``None`` when no backends are
            configured at all.
        """
        with self.health_lock:
            candidates = [b for b in self.backends if self.healthy_backends.get(b, True)]
        if not candidates:
            # No healthy backend known: fall back to the full list
            candidates = self.backends
        if not candidates:
            return None

        if self.strategy == "least_conn":
            with self.conn_lock:
                chosen = min(candidates, key=lambda b: self.connection_counts.get(b, 0))
                self.connection_counts[chosen] = self.connection_counts.get(chosen, 0) + 1
            return chosen

        # "round_robin" and any unknown strategy name
        with self.index_lock:
            chosen = candidates[self.current_index % len(candidates)]
            self.current_index = (self.current_index + 1) % len(candidates)
        return chosen

    def release_backend(self, backend: str):
        """Decrement ``backend``'s in-flight counter (least-conn only).

        The counter never drops below zero. This is a no-op for any other
        strategy.
        """
        if self.strategy == "least_conn":
            with self.conn_lock:
                self.connection_counts[backend] = max(0, self.connection_counts.get(backend, 0) - 1)

    def _mark_unhealthy(self, backend: str) -> None:
        """Flag ``backend`` as unhealthy until a later health sweep clears it."""
        with self.health_lock:
            self.healthy_backends[backend] = False

    @staticmethod
    def _build_url(backend: str, path: str, query: str = "") -> str:
        """Join backend base URL, request path and raw query string.

        A trailing slash on ``backend`` is dropped so the join does not
        produce ``//``. ``query`` is appended verbatim, which keeps the
        client's original encoding and any repeated parameters.
        """
        url = f"{backend.rstrip('/')}{path}"
        return f"{url}?{query}" if query else url

    @classmethod
    def _filter_request_headers(cls, headers: Any) -> Dict[str, str]:
        """Return the client headers that should be forwarded upstream."""
        return {
            name: value
            for name, value in headers.items()
            if name.lower() not in cls._SKIPPED_REQUEST_HEADERS
        }

    @classmethod
    def _filter_response_headers(cls, headers: Any) -> Dict[str, str]:
        """Return the backend headers that are safe to relay to the client."""
        return {
            name: value
            for name, value in headers.items()
            if name.lower() not in cls._SKIPPED_RESPONSE_HEADERS
        }

    async def health_check(self, backend: str) -> bool:
        """Probe one backend and report whether it looks alive.

        Backends whose URL ends with ``/v1`` are probed at ``/models`` then
        ``/``; all others at ``/health`` then ``/``. Any reply with a status
        code below 500 counts as healthy.
        """
        probes = ("/models", "/") if backend.endswith("/v1") else ("/health", "/")
        for probe in probes:
            try:
                reply = await self.client.get(f"{backend}{probe}", timeout=5.0)
            except Exception:
                # ``Exception`` rather than a bare ``except`` so that task
                # cancellation (a BaseException) is not swallowed here.
                continue
            if reply.status_code < 500:
                return True
        return False

    async def check_all_backends(self):
        """Refresh the health flag of every backend, forever.

        Each sweep probes all backends concurrently, stores the results,
        prints a warning per unhealthy backend, then sleeps for
        ``health_check_interval`` seconds.
        """
        while True:
            targets = list(self.backends)
            # health_check() handles its own errors, so gather() only
            # propagates cancellation.
            results = await asyncio.gather(*(self.health_check(b) for b in targets))
            with self.health_lock:
                for backend, is_healthy in zip(targets, results):
                    self.healthy_backends[backend] = is_healthy
            for backend, is_healthy in zip(targets, results):
                if not is_healthy:
                    print(f"Warning: Backend {backend} is unhealthy")
            await asyncio.sleep(self.health_check_interval)

    async def forward_request(
        self,
        method: str,
        path: str,
        request: "Request",
        backend: Optional[str] = None,
    ) -> "Response":
        """Forward a request to a backend and return its buffered reply.

        Args:
            method: HTTP method to use upstream.
            path: Path (with leading slash) appended to the backend URL.
            request: Incoming request; its body, headers and query string
                are forwarded.
            backend: Backend to use; chosen via :meth:`get_backend` if omitted.

        Raises:
            HTTPException: 503 when no backend is configured, 502 when the
                upstream call fails (the backend is then flagged unhealthy).
        """
        if backend is None:
            backend = self.get_backend()
        if backend is None:
            raise HTTPException(status_code=503, detail="No healthy backends available")

        try:
            body = await request.body()
            url = self._build_url(backend, path, request.url.query)
            headers = self._filter_request_headers(request.headers)
            try:
                upstream = await self.client.request(
                    method=method,
                    url=url,
                    content=body,
                    headers=headers,
                )
            except Exception as e:
                # Only a failed upstream call counts against the backend
                self._mark_unhealthy(backend)
                raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}") from e
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                headers=self._filter_response_headers(upstream.headers),
            )
        finally:
            # Single release point: success and failure each release once
            self.release_backend(backend)

    async def forward_streaming_request(
        self,
        method: str,
        path: str,
        request: "Request",
        backend: Optional[str] = None,
    ):
        """Forward a request and relay the backend reply chunk by chunk.

        The upstream response stays open until the returned
        ``StreamingResponse`` has finished iterating; only then is it closed
        and the backend slot released.

        Args:
            method: HTTP method to use upstream.
            path: Path (with leading slash) appended to the backend URL.
            request: Incoming request; its body, headers and query string
                are forwarded.
            backend: Backend to use; chosen via :meth:`get_backend` if omitted.

        Raises:
            HTTPException: 503 when no backend is configured, 502 when the
                upstream call fails before any data is relayed.
        """
        if backend is None:
            backend = self.get_backend()
        if backend is None:
            raise HTTPException(status_code=503, detail="No healthy backends available")

        upstream = None
        handed_off = False
        try:
            body = await request.body()
            url = self._build_url(backend, path, request.url.query)
            headers = self._filter_request_headers(request.headers)
            try:
                upstream_request = self.client.build_request(
                    method=method,
                    url=url,
                    content=body,
                    headers=headers,
                )
                # stream=True returns once headers arrive; the body is read
                # lazily by relay() below.
                upstream = await self.client.send(upstream_request, stream=True)
            except Exception as e:
                self._mark_unhealthy(backend)
                raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}") from e

            async def relay():
                try:
                    async for chunk in upstream.aiter_bytes():
                        yield chunk
                finally:
                    # Runs when the stream ends, fails, or is abandoned
                    await upstream.aclose()
                    self.release_backend(backend)

            reply = StreamingResponse(
                relay(),
                status_code=upstream.status_code,
                headers=self._filter_response_headers(upstream.headers),
            )
            handed_off = True
            return reply
        finally:
            if not handed_off:
                # relay() will never run, so clean up here instead
                if upstream is not None:
                    await upstream.aclose()
                self.release_backend(backend)

    async def aclose(self) -> None:
        """Close the shared HTTP client."""
        await self.client.aclose()
def create_load_balancer_app(
    backends: List[str],
    strategy: str = "round_robin",
    health_check_interval: float = 10.0,
) -> "FastAPI":
    """Build the FastAPI application that fronts the load balancer.

    The app exposes ``GET /health`` (a summary of backend state) and a
    catch-all route that forwards every other request to a backend. A
    background task refreshes backend health for the app's lifetime.

    Args:
        backends: Backend base URLs.
        strategy: ``"round_robin"`` or ``"least_conn"``.
        health_check_interval: Seconds between health sweeps.

    Returns:
        The configured FastAPI application.

    Raises:
        RuntimeError: If FastAPI could not be imported.
    """
    # The return annotation is quoted so that importing this module does not
    # fail with NameError when the optional FastAPI import fell back.
    if not HAS_FASTAPI:
        raise RuntimeError("FastAPI not installed. Install with: pip install fastapi uvicorn httpx")

    app = FastAPI(title="Simple Load Balancer", version="1.0.0")

    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive; confirm this is intended for the deployment.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    lb = SimpleLoadBalancer(backends, strategy, health_check_interval)

    @app.on_event("startup")
    async def start_health_check():
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage-collected mid-run.
        app.state.health_task = asyncio.create_task(lb.check_all_backends())

    @app.on_event("shutdown")
    async def stop_health_check():
        task = getattr(app.state, "health_task", None)
        if task is not None:
            task.cancel()
            # Bounded wait so shutdown cannot hang on the health loop
            await asyncio.wait({task}, timeout=5.0)
        await lb.client.aclose()

    @app.get("/health")
    async def health():
        # Snapshot under the lock so the summary is internally consistent
        with lb.health_lock:
            health_by_backend = dict(lb.healthy_backends)
        healthy_count = sum(1 for is_up in health_by_backend.values() if is_up)
        report_connections = lb.strategy == "least_conn"
        return {
            "status": "healthy" if healthy_count > 0 else "unhealthy",
            "healthy_backends": healthy_count,
            "total_backends": len(lb.backends),
            "backends": [
                {
                    "url": backend,
                    "healthy": health_by_backend.get(backend, False),
                    "connections": lb.connection_counts.get(backend, 0) if report_connections else None,
                }
                for backend in lb.backends
            ],
        }

    @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
    async def forward(request: Request, path: str):
        method = request.method
        # Streaming is detected from the URL only.
        # NOTE(review): a "stream" flag sent in a JSON body is not seen here;
        # such requests take the buffered path.
        is_streaming = "stream" in request.query_params or "stream=true" in str(request.url)
        if is_streaming:
            return await lb.forward_streaming_request(method, f"/{path}", request)
        return await lb.forward_request(method, f"/{path}", request)

    return app
def main():
    """Parse command-line options and serve the load balancer with uvicorn.

    Exits with status 1 (after argument parsing) when FastAPI is unavailable.
    """
    # Each entry is (flag, keyword arguments for add_argument), in help order
    option_specs = [
        (
            "--backends",
            {
                "type": str,
                "nargs": "+",
                "required": True,
                "help": "Backend URLs (e.g., http://localhost:8000 http://localhost:8001)",
            },
        ),
        (
            "--host",
            {
                "type": str,
                "default": "0.0.0.0",
                "help": "Host to bind to (default: 0.0.0.0)",
            },
        ),
        (
            "--port",
            {
                "type": int,
                "default": 8000,
                "help": "Port to bind to (default: 8000)",
            },
        ),
        (
            "--strategy",
            {
                "type": str,
                "default": "round_robin",
                "choices": ["round_robin", "least_conn"],
                "help": "Load balancing strategy (default: round_robin)",
            },
        ),
        (
            "--health-check-interval",
            {
                "type": float,
                "default": 10.0,
                "help": "Health check interval in seconds (default: 10.0)",
            },
        ),
    ]

    cli = argparse.ArgumentParser(description="Simple Python Load Balancer")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    # Checked after parsing so --help and usage errors still behave normally
    if not HAS_FASTAPI:
        print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
        sys.exit(1)

    balancer_app = create_load_balancer_app(
        backends=options.backends,
        strategy=options.strategy,
        health_check_interval=options.health_check_interval,
    )

    backend_list = ", ".join(options.backends)
    print(f"Starting load balancer on {options.host}:{options.port}")
    print(f"Strategy: {options.strategy}")
    print(f"Backends: {backend_list}")

    uvicorn.run(balancer_app, host=options.host, port=options.port, log_level="info")
if __name__ == "__main__":
main()
# python -m shared.utils.load_balancer --backends http://localhost:8000 http://localhost:8001 http://localhost:8002 http://localhost:8003 --port 8004 --strategy round_robin