"""
Load Balancer for Mamba Swarm API

Distributes requests across multiple API server instances
"""
|
|
|
|
|
|
import asyncio
|
|
|
import aiohttp
|
|
|
import random
|
|
|
import time
|
|
|
import logging
|
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
from dataclasses import dataclass, field
|
|
|
from enum import Enum
|
|
|
from collections import defaultdict, deque
|
|
|
import json
|
|
|
import hashlib
|
|
|
|
|
|
class LoadBalancingStrategy(Enum):
    """Server-selection policies understood by the load balancer.

    Each member's value is the policy's lowercase string identifier.
    """

    # Cycle through the servers in order.
    ROUND_ROBIN = "round_robin"

    # Prefer the server with the fewest in-flight connections.
    LEAST_CONNECTIONS = "least_connections"

    # Pick servers in proportion to their configured weight.
    WEIGHTED_ROUND_ROBIN = "weighted_round_robin"

    # Prefer the server with the lowest average response time.
    LEAST_RESPONSE_TIME = "least_response_time"

    # Map a request to a server by hashing part of its payload.
    HASH_BASED = "hash_based"

    # Favour servers with the lowest combined load score.
    RESOURCE_AWARE = "resource_aware"
|
|
|
|
|
|
@dataclass
class ServerInstance:
    """Address and rolling runtime statistics of one backend API server.

    The counters are plain mutable fields; the properties derive metrics
    (URL, average latency, success rate, load score) from them on demand.
    """

    # Network location of the backend.
    host: str
    port: int
    # Relative share of traffic for weighted selection.
    weight: float = 1.0
    # Capacity figure used to normalise ``current_connections`` in ``load_score``.
    max_connections: int = 100
    # Per-request timeout.
    timeout: float = 30.0
    # Number of requests currently in flight to this server.
    current_connections: int = 0
    # Lifetime request counters.
    total_requests: int = 0
    failed_requests: int = 0
    # Sliding window holding at most the 100 most recent response times.
    response_times: deque = field(default_factory=lambda: deque(maxlen=100))
    # Timestamp of the last successful health check (0.0 until one succeeds).
    last_health_check: float = 0.0
    is_healthy: bool = True
    health_check_failures: int = 0

    @property
    def url(self) -> str:
        """Base HTTP URL of the server, e.g. ``http://host:port``."""
        return f"http://{self.host}:{self.port}"

    @property
    def avg_response_time(self) -> float:
        """Mean of the recorded response times, or ``0.0`` with no samples."""
        if not self.response_times:
            return 0.0
        return sum(self.response_times) / len(self.response_times)

    @property
    def success_rate(self) -> float:
        """Fraction of requests that did not fail, always within ``[0.0, 1.0]``.

        Returns ``1.0`` while no requests have been recorded.
        """
        total = self.total_requests
        if total <= 0:
            return 1.0
        # Clamp: the counters are updated independently, so failures can
        # momentarily exceed the total and must not yield a negative rate.
        return min(max((total - self.failed_requests) / total, 0.0), 1.0)

    @property
    def load_score(self) -> float:
        """Calculate load score for resource-aware balancing.

        Weighted sum of connection utilisation (40%), average response time
        scaled so that 1000 saturates (40%) and failure rate (20%). Lower
        means less loaded.
        """
        if self.max_connections > 0:
            connection_load = self.current_connections / self.max_connections
        else:
            # A server without any configured capacity counts as fully loaded
            # instead of raising ZeroDivisionError.
            connection_load = 1.0

        response_time_load = min(self.avg_response_time / 1000.0, 1.0)
        # Capped at 1.0 so inconsistent counters cannot inflate the score.
        failure_rate = min(self.failed_requests / max(self.total_requests, 1), 1.0)

        return connection_load * 0.4 + response_time_load * 0.4 + failure_rate * 0.2
|
|
|
|
|
|
class LoadBalancer:
    """Advanced load balancer for Mamba Swarm API servers.

    Keeps a pool of ``ServerInstance`` objects, periodically probes each
    server's ``/health`` endpoint and forwards requests to a healthy server
    chosen by the configured ``LoadBalancingStrategy``.

    Use it as an async context manager (``async with LoadBalancer(...)``) or
    call ``start()`` / ``stop()`` explicitly.
    """

    def __init__(self,
                 servers: List[Tuple[str, int]],
                 strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE,
                 health_check_interval: float = 30.0,
                 health_check_timeout: float = 5.0,
                 max_retries: int = 3):
        """Create a load balancer over ``servers``.

        Args:
            servers: ``(host, port)`` pairs of the backend API servers.
            strategy: Selection strategy used by ``select_server``.
            health_check_interval: Seconds between periodic health checks.
            health_check_timeout: Total timeout in seconds for one health probe.
            max_retries: Retries after the first failed attempt, so a request
                is tried at most ``max_retries + 1`` times.
        """
        self.logger = logging.getLogger(__name__)
        self.strategy = strategy
        self.health_check_interval = health_check_interval
        self.health_check_timeout = health_check_timeout
        self.max_retries = max_retries

        # Backend pool.
        self.servers = [ServerInstance(host=host, port=port) for host, port in servers]

        # Selection state.
        self.round_robin_index = 0
        self.request_counts = defaultdict(int)

        # Created in start(), released in stop().
        self.session: Optional[aiohttp.ClientSession] = None
        self.health_check_task: Optional[asyncio.Task] = None

        # Balancer-level statistics.
        self.total_requests = 0
        self.failed_requests = 0
        self.start_time = time.time()

    async def __aenter__(self):
        """Async context manager entry: start the balancer and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: stop the balancer."""
        await self.stop()

    async def start(self):
        """Start the load balancer.

        Opens the shared HTTP session, launches the periodic health-check
        task and runs one health check immediately. Resources that are
        already live are reused, so calling ``start`` twice neither leaks a
        session nor spawns a second health-check loop.
        """
        if self.session is None or self.session.closed:
            timeout = aiohttp.ClientTimeout(total=30.0, connect=10.0)
            self.session = aiohttp.ClientSession(timeout=timeout)

        if self.health_check_task is None or self.health_check_task.done():
            self.health_check_task = asyncio.create_task(self._health_check_loop())

        # The loop sleeps before its first probe, so check once right away.
        await self._check_all_servers_health()

        self.logger.info(
            "Load balancer started with %d servers using %s strategy",
            len(self.servers),
            self.strategy.value,
        )

    async def stop(self):
        """Stop the load balancer.

        Cancels the health-check task, closes the HTTP session and clears
        both references so the balancer can be started again.
        """
        if self.health_check_task is not None:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                pass
            self.health_check_task = None

        if self.session is not None:
            await self.session.close()
            self.session = None

        self.logger.info("Load balancer stopped")

    def get_healthy_servers(self) -> List[ServerInstance]:
        """Get list of healthy servers."""
        return [server for server in self.servers if server.is_healthy]

    def select_server(self, request_data: Optional[Dict[str, Any]] = None) -> Optional[ServerInstance]:
        """Select a healthy server based on the configured strategy.

        Args:
            request_data: Request payload; only the hash-based strategy reads it.

        Returns:
            The chosen server, or ``None`` when no server is healthy.
        """
        healthy_servers = self.get_healthy_servers()
        if not healthy_servers:
            self.logger.warning("No healthy servers available")
            return None

        # The hash-based strategy is the only one that needs the payload.
        if self.strategy == LoadBalancingStrategy.HASH_BASED:
            return self._hash_based_selection(healthy_servers, request_data)

        selectors = {
            LoadBalancingStrategy.ROUND_ROBIN: self._round_robin_selection,
            LoadBalancingStrategy.LEAST_CONNECTIONS: self._least_connections_selection,
            LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN: self._weighted_round_robin_selection,
            LoadBalancingStrategy.LEAST_RESPONSE_TIME: self._least_response_time_selection,
            LoadBalancingStrategy.RESOURCE_AWARE: self._resource_aware_selection,
        }
        selector = selectors.get(self.strategy)
        if selector is None:
            # Unknown strategy: fall back to a uniformly random pick.
            return random.choice(healthy_servers)
        return selector(healthy_servers)

    def _round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Round-robin server selection."""
        server = servers[self.round_robin_index % len(servers)]
        self.round_robin_index += 1
        return server

    def _least_connections_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server with least connections."""
        return min(servers, key=lambda s: s.current_connections)

    def _weighted_round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Weighted selection.

        Implemented as a weighted *random* draw: each server is picked with
        probability proportional to its ``weight``.
        """
        total_weight = sum(s.weight for s in servers)
        random_weight = random.uniform(0, total_weight)

        current_weight = 0
        for server in servers:
            current_weight += server.weight
            if random_weight <= current_weight:
                return server

        # Only reachable through floating-point rounding at the upper bound.
        return servers[-1]

    def _least_response_time_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server with least average response time.

        A server without recorded samples (average ``0.0``) is ranked as
        infinitely slow, so it is only chosen when no server has samples.
        """
        return min(servers, key=lambda s: s.avg_response_time or float('inf'))

    def _hash_based_selection(self, servers: List[ServerInstance], request_data: Optional[Dict[str, Any]]) -> ServerInstance:
        """Hash-based selection for session affinity.

        Hashes the ``prompt`` entry of the payload so identical prompts map to
        the same position in ``servers``. Payloads that are not a dict, or
        that have no ``prompt``, get a random server.
        """
        if not isinstance(request_data, dict) or 'prompt' not in request_data:
            return random.choice(servers)

        # str() lets non-string prompts (numbers, lists, ...) be hashed
        # instead of raising AttributeError on .encode().
        prompt = str(request_data['prompt'])
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
        server_index = int(prompt_hash, 16) % len(servers)
        return servers[server_index]

    def _resource_aware_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server based on resource utilization.

        Weighted random draw in which a lower ``load_score`` gives a higher
        probability of being picked.
        """
        sorted_servers = sorted(servers, key=lambda s: s.load_score)

        # The 0.1 offset keeps the weight finite for a load score of zero.
        weights = [1.0 / (s.load_score + 0.1) for s in sorted_servers]
        total_weight = sum(weights)

        random_value = random.uniform(0, total_weight)
        current_weight = 0

        for server, weight in zip(sorted_servers, weights):
            current_weight += weight
            if random_value <= current_weight:
                return server

        # Fallback if no bucket matched: use the least loaded server.
        return sorted_servers[0]

    async def forward_request(self,
                              path: str,
                              method: str = "POST",
                              data: Optional[Dict[str, Any]] = None,
                              headers: Optional[Dict[str, str]] = None,
                              **kwargs) -> Tuple[int, Dict[str, Any]]:
        """Forward request to selected server with retry logic.

        A server is selected anew for every attempt. An attempt is retried
        when the request raises or the backend answers with a status >= 400,
        up to ``max_retries`` retries.

        Args:
            path: Path (including the leading ``/``) appended to the server URL.
            method: HTTP method.
            data: JSON-serialisable payload, sent as the request body unless ``None``.
            headers: Optional request headers.
            **kwargs: Extra keyword arguments passed to ``ClientSession.request``.

        Returns:
            ``(status, body)`` of the last backend response, ``(503, error)``
            when no healthy server exists, or ``(502, error)`` when the final
            attempt raised.
        """
        self.total_requests += 1

        for attempt in range(self.max_retries + 1):
            server = self.select_server(data)
            if not server:
                self.failed_requests += 1
                return 503, {"error": "No healthy servers available"}

            request_kwargs = {
                "timeout": aiohttp.ClientTimeout(total=server.timeout),
                **kwargs,
            }
            if headers:
                request_kwargs["headers"] = headers
            # ``is not None`` so valid falsy payloads such as {} are still sent.
            if data is not None:
                request_kwargs["json"] = data

            error: Optional[Exception] = None
            start_time = time.time()
            server.current_connections += 1
            try:
                async with self.session.request(method, f"{server.url}{path}", **request_kwargs) as response:
                    response_time = time.time() - start_time
                    response_data = await response.json()
                    status = response.status
            except Exception as e:
                error = e
            finally:
                # Runs on task cancellation too, so the in-flight counter
                # cannot leak when a caller gives up on the request.
                server.current_connections -= 1

            # Count every attempt so failed_requests can never exceed total_requests.
            server.total_requests += 1

            if error is not None:
                server.failed_requests += 1
                self.logger.error("Request failed on %s: %s", server.url, error)
                if attempt < self.max_retries:
                    # Linear back-off before trying again.
                    await asyncio.sleep(0.1 * (attempt + 1))
                continue

            server.response_times.append(response_time * 1000)

            if status >= 400:
                server.failed_requests += 1
                if attempt < self.max_retries:
                    self.logger.warning(
                        "Request failed on %s (attempt %d), retrying...",
                        server.url,
                        attempt + 1,
                    )
                    continue

            return status, response_data

        self.failed_requests += 1
        return 502, {"error": "All servers failed after retries"}

    async def _check_server_health(self, server: ServerInstance) -> bool:
        """Check health of a single server.

        The server passes when ``GET /health`` returns status 200 with a body
        that parses as JSON. A pass resets the server's failure counter and
        records the check time; any failure increments the counter.

        Returns:
            ``True`` if the check passed, ``False`` otherwise.
        """
        try:
            timeout = aiohttp.ClientTimeout(total=self.health_check_timeout)
            async with self.session.get(f"{server.url}/health", timeout=timeout) as response:
                if response.status != 200:
                    server.health_check_failures += 1
                    return False

                # Parsed only to validate the body; its content is not used.
                await response.json()
                server.last_health_check = time.time()
                server.health_check_failures = 0
                return True

        except Exception as e:
            server.health_check_failures += 1
            self.logger.debug("Health check failed for %s: %s", server.url, e)
            return False

    async def _check_all_servers_health(self):
        """Check health of all servers concurrently and update ``is_healthy``."""
        # Snapshot the pool: add_server/remove_server may change self.servers
        # while the probes are awaited, which would misalign the zip below.
        servers = list(self.servers)
        results = await asyncio.gather(
            *(self._check_server_health(server) for server in servers),
            return_exceptions=True,
        )

        for server, result in zip(servers, results):
            was_healthy = server.is_healthy

            # BaseException also covers a CancelledError returned by gather,
            # which must not be mistaken for a truthy "healthy" result.
            if isinstance(result, BaseException):
                server.is_healthy = False
                server.health_check_failures += 1
            else:
                # NOTE(review): a passing check resets health_check_failures to
                # 0, so the ``< 3`` test never changes the outcome and one
                # failed check marks the server unhealthy. Confirm whether a
                # tolerance of three consecutive failures was intended.
                server.is_healthy = bool(result) and server.health_check_failures < 3

            if not was_healthy and server.is_healthy:
                self.logger.info("Server %s is back online", server.url)
            elif was_healthy and not server.is_healthy:
                self.logger.warning("Server %s is unhealthy", server.url)

    async def _health_check_loop(self):
        """Periodic health check loop; runs until the task is cancelled."""
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._check_all_servers_health()
            except asyncio.CancelledError:
                # Propagate cancellation; stop() awaits the task and handles it.
                raise
            except Exception as e:
                # Keep the loop alive after an unexpected error.
                self.logger.error("Health check loop error: %s", e)

    def add_server(self, host: str, port: int, weight: float = 1.0):
        """Add a new server to the pool."""
        server = ServerInstance(host=host, port=port, weight=weight)
        self.servers.append(server)
        self.logger.info("Added server %s", server.url)

    def remove_server(self, host: str, port: int):
        """Remove every server matching ``host`` and ``port`` from the pool."""
        remaining = [s for s in self.servers if not (s.host == host and s.port == port)]
        if len(remaining) == len(self.servers):
            # Nothing matched: do not claim a removal in the log.
            self.logger.warning("Server http://%s:%s not found in pool", host, port)
            return

        self.servers = remaining
        self.logger.info("Removed server http://%s:%s", host, port)

    def get_stats(self) -> Dict[str, Any]:
        """Get load balancer statistics.

        Returns:
            Balancer-wide counters plus one entry per server under ``"servers"``.
        """
        uptime = time.time() - self.start_time

        server_stats = [
            {
                "url": server.url,
                "is_healthy": server.is_healthy,
                "current_connections": server.current_connections,
                "total_requests": server.total_requests,
                "failed_requests": server.failed_requests,
                "success_rate": server.success_rate,
                "avg_response_time_ms": server.avg_response_time,
                "load_score": server.load_score,
                "weight": server.weight,
            }
            for server in self.servers
        ]

        return {
            "strategy": self.strategy.value,
            "uptime_seconds": uptime,
            "total_requests": self.total_requests,
            "failed_requests": self.failed_requests,
            "success_rate": (self.total_requests - self.failed_requests) / max(self.total_requests, 1),
            "healthy_servers": len(self.get_healthy_servers()),
            "total_servers": len(self.servers),
            "servers": server_stats,
        }
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, Request, HTTPException
|
|
|
from fastapi.responses import JSONResponse
|
|
|
import uvicorn
|
|
|
|
|
|
def create_load_balancer_app(servers: List[Tuple[str, int]],
                             strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE) -> FastAPI:
    """Create FastAPI app with load balancer.

    The app starts and stops a ``LoadBalancer`` with its own lifecycle,
    exposes ``/lb/health`` and ``/lb/stats``, and proxies every other path
    to the backend pool.

    Args:
        servers: ``(host, port)`` pairs of the backend API servers.
        strategy: Server-selection strategy for the balancer.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(title="Mamba Swarm Load Balancer", version="1.0.0")
    load_balancer = LoadBalancer(servers, strategy)
    logger = logging.getLogger(__name__)

    # Headers describing the incoming connection or the original body framing.
    # The body is re-serialised before forwarding, so the client's
    # content-length / transfer-encoding would no longer match it.
    dropped_headers = ("host", "connection", "content-length", "transfer-encoding")

    @app.on_event("startup")
    async def startup():
        await load_balancer.start()

    @app.on_event("shutdown")
    async def shutdown():
        await load_balancer.stop()

    @app.get("/lb/health")
    async def lb_health():
        """Load balancer health endpoint"""
        return {"status": "healthy", "stats": load_balancer.get_stats()}

    @app.get("/lb/stats")
    async def lb_stats():
        """Get load balancer statistics"""
        return load_balancer.get_stats()

    @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
    async def proxy_request(request: Request, path: str):
        """Proxy all requests to backend servers"""
        try:
            body = await request.body()

            headers = dict(request.headers)
            for name in dropped_headers:
                headers.pop(name, None)

            # Best effort: only JSON bodies are forwarded.
            # NOTE(review): a body that is not valid JSON leaves ``data`` as
            # None, so the backend receives the request without a body.
            data = None
            if body:
                try:
                    data = json.loads(body.decode())
                except ValueError:
                    # Covers JSONDecodeError and UnicodeDecodeError.
                    logger.debug("Request body for /%s is not valid JSON; forwarding without body", path)

            status, response_data = await load_balancer.forward_request(
                f"/{path}",
                request.method,
                data=data,
                headers=headers,
                # multi_items() keeps repeated query keys (?a=1&a=2);
                # dict() would keep only the last value of each key.
                params=request.query_params.multi_items(),
            )

            return JSONResponse(content=response_data, status_code=status)

        except Exception as e:
            return JSONResponse(
                content={"error": f"Load balancer error: {str(e)}"},
                status_code=500,
            )

    return app
|
|
|
|
|
|
def run_load_balancer(servers: List[Tuple[str, int]],
                      host: str = "0.0.0.0",
                      port: int = 8080,
                      strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE):
    """Run the load balancer.

    Builds the proxy application for ``servers`` and serves it with uvicorn
    on ``host``:``port`` at log level ``info``. Blocks until the server exits.
    """
    uvicorn_config = uvicorn.Config(
        app=create_load_balancer_app(servers, strategy),
        host=host,
        port=port,
        log_level="info",
    )
    uvicorn.Server(uvicorn_config).run()
|
|
|
|
|
|
if __name__ == "__main__":
    # Example pool: three backend API servers on consecutive local ports.
    servers = [
        ("localhost", 8000),
        ("localhost", 8001),
        ("localhost", 8002),
    ]

    # Serves on the default host/port of run_load_balancer.
    run_load_balancer(servers, strategy=LoadBalancingStrategy.RESOURCE_AWARE)