Spaces:
Sleeping
Sleeping
| """ | |
| Simple Python Load Balancer | |
| A lightweight load balancer for vLLM and Reranker services. | |
| Uses FastAPI to forward requests to multiple backend services. | |
| """ | |
| import os | |
| import sys | |
| import time | |
| import asyncio | |
| import httpx | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Any | |
| from threading import Lock | |
| import argparse | |
| try: | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import StreamingResponse, Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| HAS_FASTAPI = True | |
| except ImportError: | |
| HAS_FASTAPI = False | |
| print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn httpx") | |
class SimpleLoadBalancer:
    """Distribute HTTP requests across a fixed list of backend services.

    Two selection strategies are available: ``"least_conn"`` picks the backend
    with the fewest in-flight requests; ``"round_robin"`` (also used for any
    unrecognised strategy name) cycles through the backends in order.
    Backends that fail a health probe or a forwarded request are skipped until
    a later probe succeeds; if every backend is marked unhealthy, all of them
    are tried anyway.
    """

    # Headers that describe a single connection hop and must not be relayed
    # by an intermediary.
    _HOP_BY_HOP_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    })

    def __init__(
        self,
        backends: List[str],
        strategy: str = "round_robin",
        health_check_interval: float = 10.0,
    ):
        """
        Initialize load balancer

        Args:
            backends: List of backend URLs (e.g., ["http://localhost:8000", "http://localhost:8001"])
            strategy: Load balancing strategy ("round_robin" or "least_conn")
            health_check_interval: Health check interval in seconds
        """
        self.backends = backends
        self.strategy = strategy
        self.health_check_interval = health_check_interval

        # Round-robin state
        self.current_index = 0
        self.index_lock = Lock()

        # Least-connection state: in-flight request count per backend
        self.connection_counts: Dict[str, int] = {backend: 0 for backend in backends}
        self.conn_lock = Lock()

        # Health check state: every backend starts out assumed healthy
        self.healthy_backends: Dict[str, bool] = {backend: True for backend in backends}
        self.health_lock = Lock()

        # Shared HTTP client used for health probes and forwarded requests
        self.client = httpx.AsyncClient(timeout=300.0)

        print(f"Load balancer initialized with {len(backends)} backends")
        print(f"Strategy: {strategy}")
        for i, backend in enumerate(backends):
            print(f" [{i+1}] {backend}")

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _build_url(backend: str, path: str, query: str = "") -> str:
        """Join backend base URL, request path and raw query string."""
        # A trailing slash on the backend would otherwise yield "//path".
        url = f"{backend.rstrip('/')}{path}"
        return f"{url}?{query}" if query else url

    @classmethod
    def _request_headers(cls, headers: Any) -> Dict[str, str]:
        """Return the client headers that are safe to send upstream."""
        # host/content-length are recomputed by the HTTP client for the new hop.
        skipped = cls._HOP_BY_HOP_HEADERS | {"host", "content-length"}
        return {k: v for k, v in headers.items() if k.lower() not in skipped}

    @classmethod
    def _response_headers(cls, headers: Any, *, decoded: bool) -> Dict[str, str]:
        """Return the upstream headers that are safe to relay to the client.

        Args:
            headers: Mapping-like upstream response headers.
            decoded: True when the body being relayed has already been
                content-decoded, in which case the upstream
                ``content-encoding``/``content-length`` no longer describe it.
        """
        skipped = set(cls._HOP_BY_HOP_HEADERS)
        if decoded:
            skipped |= {"content-length", "content-encoding"}
        return {k: v for k, v in headers.items() if k.lower() not in skipped}

    def _mark_unhealthy(self, backend: str) -> None:
        """Flag *backend* as unhealthy until a later health probe succeeds."""
        with self.health_lock:
            self.healthy_backends[backend] = False

    # ------------------------------------------------------------------
    # Backend selection
    # ------------------------------------------------------------------
    def get_backend(self) -> Optional[str]:
        """Get next backend based on strategy.

        Returns:
            The chosen backend URL, or None when no backends are configured.
            Under "least_conn" the chosen backend's in-flight count is
            incremented; pair the call with ``release_backend``.
        """
        with self.health_lock:
            candidates = [b for b in self.backends if self.healthy_backends.get(b, True)]
        if not candidates:
            # If no healthy backends, try all backends
            candidates = self.backends
        if not candidates:
            return None

        if self.strategy == "least_conn":
            with self.conn_lock:
                chosen = min(candidates, key=lambda b: self.connection_counts.get(b, 0))
                self.connection_counts[chosen] = self.connection_counts.get(chosen, 0) + 1
            return chosen

        # "round_robin" and any unknown strategy
        with self.index_lock:
            chosen = candidates[self.current_index % len(candidates)]
            self.current_index = (self.current_index + 1) % len(candidates)
        return chosen

    def release_backend(self, backend: str):
        """Release a backend (for least-conn strategy); never goes below zero."""
        if self.strategy == "least_conn":
            with self.conn_lock:
                self.connection_counts[backend] = max(0, self.connection_counts.get(backend, 0) - 1)

    # ------------------------------------------------------------------
    # Health checking
    # ------------------------------------------------------------------
    async def health_check(self, backend: str) -> bool:
        """Check if a backend is healthy.

        A backend counts as healthy when any probed endpoint answers with a
        status code below 500.
        """
        base = backend.rstrip("/")
        # For vLLM backends (URLs ending with /v1), use /models endpoint
        # (which becomes /v1/models); other services: try /health, then root.
        endpoints = ["/models", "/"] if base.endswith("/v1") else ["/health", "/"]
        for endpoint in endpoints:
            try:
                response = await self.client.get(f"{base}{endpoint}", timeout=5.0)
            except Exception:
                # Not a bare ``except``: task cancellation must propagate.
                continue
            if response.status_code < 500:
                return True
        return False

    async def check_all_backends(self):
        """Check health of all backends forever, once per interval."""
        while True:
            backends = list(self.backends)
            # Probe concurrently so one slow backend does not delay the rest.
            results = await asyncio.gather(*(self.health_check(b) for b in backends))
            for backend, is_healthy in zip(backends, results):
                with self.health_lock:
                    self.healthy_backends[backend] = is_healthy
                if not is_healthy:
                    print(f"Warning: Backend {backend} is unhealthy")
            await asyncio.sleep(self.health_check_interval)

    # ------------------------------------------------------------------
    # Request forwarding
    # ------------------------------------------------------------------
    async def forward_request(
        self,
        method: str,
        path: str,
        request: "Request",
        backend: Optional[str] = None
    ) -> "Response":
        """Forward a request to a backend and return the buffered response.

        Args:
            method: HTTP method to use upstream.
            path: Path (with leading slash) appended to the backend URL.
            request: Incoming request; its body, headers and raw query string
                are forwarded.
            backend: Backend to use; chosen via ``get_backend`` when None.

        Raises:
            HTTPException: 503 when no backend is available, 502 when the
                upstream call fails (the backend is then marked unhealthy).
        """
        if backend is None:
            backend = self.get_backend()
        if backend is None:
            raise HTTPException(status_code=503, detail="No healthy backends available")

        try:
            body = await request.body()
            upstream = await self.client.request(
                method=method,
                # The raw query string keeps percent-encoding and repeated keys.
                url=self._build_url(backend, path, request.url.query),
                content=body,
                headers=self._request_headers(request.headers),
            )
            return Response(
                content=upstream.content,
                status_code=upstream.status_code,
                # ``upstream.content`` is already decoded, so drop encoding/length.
                headers=self._response_headers(upstream.headers, decoded=True),
            )
        except Exception as e:
            self._mark_unhealthy(backend)
            raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}") from e
        finally:
            # Single release point: releasing in ``except`` as well would
            # decrement the in-flight count twice on failure.
            self.release_backend(backend)

    async def forward_streaming_request(
        self,
        method: str,
        path: str,
        request: "Request",
        backend: Optional[str] = None
    ) -> "StreamingResponse":
        """Forward a request to a backend and stream the response back.

        The upstream response stays open until the returned response body has
        been fully sent (or the client goes away); only then is the backend
        released.

        Args:
            method: HTTP method to use upstream.
            path: Path (with leading slash) appended to the backend URL.
            request: Incoming request; its body, headers and raw query string
                are forwarded.
            backend: Backend to use; chosen via ``get_backend`` when None.

        Raises:
            HTTPException: 503 when no backend is available, 502 when the
                upstream connection fails (the backend is then marked unhealthy).
        """
        if backend is None:
            backend = self.get_backend()
        if backend is None:
            raise HTTPException(status_code=503, detail="No healthy backends available")

        try:
            body = await request.body()
            upstream_request = self.client.build_request(
                method=method,
                url=self._build_url(backend, path, request.url.query),
                content=body,
                headers=self._request_headers(request.headers),
            )
            # stream=True returns once headers arrive; the body is read lazily.
            upstream = await self.client.send(upstream_request, stream=True)
        except Exception as e:
            self._mark_unhealthy(backend)
            self.release_backend(backend)
            raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}") from e

        async def generate():
            try:
                # Raw bytes keep the upstream content-encoding intact, so the
                # relayed headers still describe the body.
                async for chunk in upstream.aiter_raw():
                    yield chunk
            finally:
                await upstream.aclose()
                self.release_backend(backend)

        return StreamingResponse(
            generate(),
            status_code=upstream.status_code,
            headers=self._response_headers(upstream.headers, decoded=False),
        )
def create_load_balancer_app(
    backends: List[str],
    strategy: str = "round_robin",
    health_check_interval: float = 10.0,
) -> "FastAPI":
    """Create FastAPI app with load balancer.

    The app exposes ``GET /health`` (load balancer and per-backend status) and
    forwards every other path to a backend chosen by the load balancer. A
    background task re-probes backend health for as long as the app runs.

    Args:
        backends: Backend base URLs.
        strategy: "round_robin" or "least_conn".
        health_check_interval: Seconds between health probe rounds.

    Returns:
        The configured FastAPI application.

    Raises:
        RuntimeError: If FastAPI could not be imported.
    """
    if not HAS_FASTAPI:
        raise RuntimeError("FastAPI not installed. Install with: pip install fastapi uvicorn httpx")

    app = FastAPI(title="Simple Load Balancer", version="1.0.0")

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Create load balancer
    lb = SimpleLoadBalancer(backends, strategy, health_check_interval)

    # Start health check task
    @app.on_event("startup")
    async def start_health_check():
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        app.state.health_check_task = asyncio.create_task(lb.check_all_backends())

    @app.on_event("shutdown")
    async def stop_health_check():
        task = getattr(app.state, "health_check_task", None)
        if task is not None:
            task.cancel()
        await lb.client.aclose()

    # Health check endpoint (registered before the catch-all route below)
    @app.get("/health")
    async def health():
        with lb.health_lock:
            health_by_backend = dict(lb.healthy_backends)
        healthy_count = sum(1 for h in health_by_backend.values() if h)
        return {
            "status": "healthy" if healthy_count > 0 else "unhealthy",
            "healthy_backends": healthy_count,
            "total_backends": len(lb.backends),
            "backends": [
                {
                    "url": backend,
                    "healthy": health_by_backend.get(backend, False),
                    "connections": lb.connection_counts.get(backend, 0) if lb.strategy == "least_conn" else None,
                }
                for backend in lb.backends
            ]
        }

    # Forward all other requests
    @app.api_route(
        "/{path:path}",
        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
    )
    async def forward(request: Request, path: str):
        method = request.method
        # NOTE(review): only the URL is inspected; a "stream" flag carried in
        # a JSON request body is not detected here and is forwarded buffered.
        is_streaming = "stream" in request.query_params or "stream=true" in str(request.url)
        if is_streaming:
            return await lb.forward_streaming_request(method, f"/{path}", request)
        return await lb.forward_request(method, f"/{path}", request)

    return app
def main():
    """Parse command-line options, build the load balancer app and serve it."""
    # (flag, argparse keyword arguments), in the order they appear in --help
    option_specs = (
        ("--backends", {
            "type": str,
            "nargs": "+",
            "required": True,
            "help": "Backend URLs (e.g., http://localhost:8000 http://localhost:8001)",
        }),
        ("--host", {
            "type": str,
            "default": "0.0.0.0",
            "help": "Host to bind to (default: 0.0.0.0)",
        }),
        ("--port", {
            "type": int,
            "default": 8000,
            "help": "Port to bind to (default: 8000)",
        }),
        ("--strategy", {
            "type": str,
            "default": "round_robin",
            "choices": ["round_robin", "least_conn"],
            "help": "Load balancing strategy (default: round_robin)",
        }),
        ("--health-check-interval", {
            "type": float,
            "default": 10.0,
            "help": "Health check interval in seconds (default: 10.0)",
        }),
    )

    cli = argparse.ArgumentParser(description="Simple Python Load Balancer")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Without the web stack there is nothing to serve.
    if not HAS_FASTAPI:
        print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
        sys.exit(1)

    # Create app
    application = create_load_balancer_app(
        backends=opts.backends,
        strategy=opts.strategy,
        health_check_interval=opts.health_check_interval,
    )

    # Run server
    backend_list = ", ".join(opts.backends)
    print(f"Starting load balancer on {opts.host}:{opts.port}")
    print(f"Strategy: {opts.strategy}")
    print(f"Backends: {backend_list}")
    uvicorn.run(application, host=opts.host, port=opts.port, log_level="info")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# Example invocation:
#   python -m shared.utils.load_balancer --backends http://localhost:8000 http://localhost:8001 http://localhost:8002 http://localhost:8003 --port 8004 --strategy round_robin