Debito commited on
Commit
d38a70f
·
verified ·
1 Parent(s): 1420d02

Upload 3 files

Browse files
Files changed (3) hide show
  1. api/__init__.py +1 -0
  2. api/api_server.py +373 -0
  3. api/load_balancer.py +475 -0
api/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+
api/api_server.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Server for Mamba Swarm
3
+ FastAPI-based server for serving the distributed Mamba language model
4
+ """
5
+
6
import asyncio
import json
import logging
import platform
import time
from contextlib import asynccontextmanager
from typing import List, Optional, Dict, Any, AsyncGenerator

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

# Import your swarm components
from system.mambaSwarm import SwarmEngine
from system.inference import InferenceEngine
from routing.router import Router
from training.trainer import setup_logging
24
+
25
# Pydantic models for API
class GenerationRequest(BaseModel):
    """Request body for a single text-generation call (also reused per item by /generate/batch)."""
    prompt: str = Field(..., description="Input text prompt")
    max_length: int = Field(default=100, ge=1, le=2048, description="Maximum generation length")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Sampling temperature")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p sampling")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Repetition penalty")
    # Must be True for /generate/stream; ignored by /generate.
    stream: bool = Field(default=False, description="Enable streaming response")
    domain: Optional[str] = Field(default=None, description="Specific domain for routing")
35
+
36
class GenerationResponse(BaseModel):
    """Response payload for a completed (non-streaming) generation."""
    generated_text: str  # model output text
    prompt: str  # echo of the input prompt
    generation_time: float  # wall-clock seconds spent generating
    tokens_generated: int
    model_info: Dict[str, Any]
42
+
43
class StreamingToken(BaseModel):
    """One token frame emitted by the /generate/stream SSE-style response."""
    token: str
    is_final: bool = False  # True on the last frame, including the error frame
    metadata: Optional[Dict[str, Any]] = None  # carries {"error": ...} on failure
47
+
48
class HealthResponse(BaseModel):
    """Payload returned by GET /health."""
    status: str
    swarm_status: Dict[str, Any]  # as reported by SwarmEngine.get_status()
    system_info: Dict[str, Any]  # CUDA availability / device count, Python version
    timestamp: float  # UNIX time the snapshot was taken
53
+
54
class ModelInfo(BaseModel):
    """Payload returned by GET /model/info."""
    total_parameters: int
    active_encoders: int
    total_encoders: int
    memory_usage: Dict[str, float]  # keys: system_memory_gb, gpu_memory_gb, cache_size_gb
    device_info: List[str]
60
+
61
# Global singletons, populated by lifespan() on startup; endpoints access them
# through the get_swarm_engine / get_inference_engine dependencies.
swarm_engine: Optional[SwarmEngine] = None
inference_engine: Optional[InferenceEngine] = None
64
+
65
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: bring the swarm up on startup, tear it down on shutdown.

    Raises:
        Exception: re-raises any initialization failure so the server refuses to start.
    """
    global swarm_engine, inference_engine

    # Startup
    logging.info("Initializing Mamba Swarm API Server...")

    try:
        # Initialize swarm engine; initialization is blocking, so run it off
        # the event loop.  get_running_loop() is the supported accessor inside
        # a coroutine (asyncio.get_event_loop() is deprecated here since 3.10).
        swarm_engine = SwarmEngine()
        await asyncio.get_running_loop().run_in_executor(None, swarm_engine.initialize)

        # Initialize inference engine on top of the swarm
        inference_engine = InferenceEngine(swarm_engine)

        logging.info("Mamba Swarm API Server initialized successfully")

    except Exception as e:
        logging.error(f"Failed to initialize swarm: {e}")
        raise

    yield

    # Shutdown
    logging.info("Shutting down Mamba Swarm API Server...")
    if swarm_engine:
        swarm_engine.shutdown()
93
+
94
# Create FastAPI app
app = FastAPI(
    title="Mamba Swarm API",
    description="Distributed Mamba Language Model API with 100 encoder units",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open to any origin — confirm this is intended before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
110
+
111
# Dependency to get swarm engine
async def get_swarm_engine() -> SwarmEngine:
    """FastAPI dependency: return the global swarm engine, or fail with 503
    when startup has not completed yet."""
    engine = swarm_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Swarm engine not initialized")
    return engine
116
+
117
async def get_inference_engine() -> InferenceEngine:
    """FastAPI dependency: return the global inference engine, or fail with
    503 when startup has not completed yet."""
    engine = inference_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Inference engine not initialized")
    return engine
121
+
122
@app.get("/health", response_model=HealthResponse)
async def health_check(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Health check endpoint: swarm status plus basic host/runtime info.

    Raises:
        HTTPException: 500 when the swarm status query fails.
    """
    try:
        swarm_status = swarm.get_status()
        system_info = {
            "cuda_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            # Report the actual interpreter version instead of the previous
            # hard-coded "3.8+" placeholder.
            "python_version": platform.python_version(),
        }

        return HealthResponse(
            status="healthy",
            swarm_status=swarm_status,
            system_info=system_info,
            timestamp=time.time()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
141
+
142
@app.get("/model/info", response_model=ModelInfo)
async def get_model_info(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Get model information

    Builds a ModelInfo from SwarmEngine.get_model_info() plus memory stats,
    falling back to the design-time defaults (100 encoders) when keys are
    missing.  Any failure is surfaced as HTTP 500.
    """
    try:
        info = swarm.get_model_info()
        memory_stats = swarm.memory_manager.get_memory_stats()

        return ModelInfo(
            total_parameters=info.get("total_parameters", 7000000000),  # 100 * 70M
            active_encoders=info.get("active_encoders", 100),
            total_encoders=info.get("total_encoders", 100),
            memory_usage={
                # assumes memory_stats exposes these attributes in GB — TODO confirm units
                "system_memory_gb": memory_stats.used_memory,
                "gpu_memory_gb": memory_stats.gpu_memory,
                "cache_size_gb": memory_stats.cache_size
            },
            device_info=info.get("devices", ["cuda:0" if torch.cuda.is_available() else "cpu"])
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
162
+
163
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text from prompt

    Runs the blocking inference call on the default executor so the event
    loop stays responsive, and reports wall-clock generation time.

    Raises:
        HTTPException: 500 on any generation failure.
    """
    try:
        start_time = time.time()

        # get_running_loop() is the supported accessor inside a coroutine
        # (asyncio.get_event_loop() is deprecated here since Python 3.10).
        result = await asyncio.get_running_loop().run_in_executor(
            None,
            inference.generate,
            request.prompt,
            {
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_p": request.top_p,
                "top_k": request.top_k,
                "repetition_penalty": request.repetition_penalty,
                "domain": request.domain
            }
        )

        generation_time = time.time() - start_time

        return GenerationResponse(
            generated_text=result["generated_text"],
            prompt=request.prompt,
            generation_time=generation_time,
            tokens_generated=result.get("tokens_generated", 0),
            model_info=result.get("model_info", {})
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
199
+
200
@app.post("/generate/stream")
async def generate_text_stream(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text with streaming response

    Rejects requests that do not set `stream=True`.  Tokens are framed as
    "data: {...}\\n\\n" lines; generator failures are reported to the client
    as a final StreamingToken carrying {"error": ...} metadata.
    """
    if not request.stream:
        raise HTTPException(status_code=400, detail="Streaming not requested")

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            # Create generator for streaming
            generator = inference.generate_stream(
                request.prompt,
                {
                    "max_length": request.max_length,
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "top_k": request.top_k,
                    "repetition_penalty": request.repetition_penalty,
                    "domain": request.domain
                }
            )

            # NOTE(review): this iterates a synchronous generator directly on
            # the event loop; if generate_stream blocks per token it will
            # stall other requests — confirm, or hand tokens off via an executor.
            for token_data in generator:
                streaming_token = StreamingToken(
                    token=token_data.get("token", ""),
                    is_final=token_data.get("is_final", False),
                    metadata=token_data.get("metadata", {})
                )

                yield f"data: {streaming_token.json()}\n\n"

                # Stop as soon as the generator marks the final token.
                if streaming_token.is_final:
                    break

        except Exception as e:
            # Surface the failure in-band as a final error frame.
            error_token = StreamingToken(
                token="",
                is_final=True,
                metadata={"error": str(e)}
            )
            yield f"data: {error_token.json()}\n\n"

    return StreamingResponse(
        generate_stream(),
        # NOTE(review): frames use SSE-style "data:" framing; media_type
        # "text/event-stream" may be more appropriate than "text/plain" —
        # confirm what the clients expect.
        media_type="text/plain",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
249
+
250
@app.post("/generate/batch")
async def generate_batch(
    requests: List[GenerationRequest],
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text for multiple prompts

    Runs up to 10 requests concurrently on the default executor.  Per-request
    failures are reported inline (as {"error", "prompt", "index"} dicts)
    instead of failing the whole batch.

    Raises:
        HTTPException: 400 when the batch exceeds 10 items; 500 on batch-level failure.
    """
    if len(requests) > 10:
        raise HTTPException(status_code=400, detail="Batch size too large (max 10)")

    try:
        # Process requests in parallel.  get_running_loop() is the supported
        # accessor inside a coroutine (asyncio.get_event_loop() is deprecated
        # here since Python 3.10).
        loop = asyncio.get_running_loop()
        tasks = []
        for req in requests:
            task = loop.run_in_executor(
                None,
                inference.generate,
                req.prompt,
                {
                    "max_length": req.max_length,
                    "temperature": req.temperature,
                    "top_p": req.top_p,
                    "top_k": req.top_k,
                    "repetition_penalty": req.repetition_penalty,
                    "domain": req.domain
                }
            )
            tasks.append(task)

        results = await asyncio.gather(*tasks, return_exceptions=True)

        responses = []
        for i, (req, result) in enumerate(zip(requests, results)):
            if isinstance(result, Exception):
                # Report this prompt's failure without aborting the batch.
                responses.append({
                    "error": str(result),
                    "prompt": req.prompt,
                    "index": i
                })
            else:
                responses.append(GenerationResponse(
                    generated_text=result["generated_text"],
                    prompt=req.prompt,
                    generation_time=result.get("generation_time", 0),
                    tokens_generated=result.get("tokens_generated", 0),
                    model_info=result.get("model_info", {})
                ))

        return {"responses": responses}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
301
+
302
@app.get("/metrics")
async def get_metrics(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Get system metrics

    Collects memory, swarm, and (if supported) inference statistics into a
    single timestamped payload; any failure becomes an HTTP 500.
    """
    try:
        inference_stats = (
            swarm.get_inference_stats()
            if hasattr(swarm, 'get_inference_stats')
            else {}
        )
        return {
            "memory_report": swarm.memory_manager.get_memory_report(),
            "swarm_metrics": swarm.get_metrics(),
            "inference_stats": inference_stats,
            "timestamp": time.time(),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get metrics: {str(e)}")
315
+
316
@app.post("/admin/reload")
async def reload_model(
    background_tasks: BackgroundTasks,
    swarm: SwarmEngine = Depends(get_swarm_engine)
):
    """Reload the model (admin endpoint)

    Schedules the reload as a background task and returns immediately.
    """
    try:
        background_tasks.add_task(swarm.reload_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}")
    return {"message": "Model reload initiated"}
327
+
328
@app.post("/admin/cleanup")
async def cleanup_memory(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Force memory cleanup (admin endpoint)

    Triggers an aggressive cleanup pass on the swarm's memory manager.
    """
    try:
        swarm.memory_manager.cleanup_memory(aggressive=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to cleanup memory: {str(e)}")
    return {"message": "Memory cleanup completed"}
336
+
337
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render HTTPException as a JSON body carrying the exception's status.

    FastAPI exception handlers must return a Response object; the previous
    version returned a bare dict, which does not propagate exc.status_code
    to the client.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
    )
345
+
346
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all handler: log the error and return a JSON 500 response.

    Returns a JSONResponse (not a bare dict) so the 500 status actually
    reaches the client.
    """
    logging.error(f"Unhandled exception: {exc}")
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "status_code": 500,
            "timestamp": time.time(),
        },
    )
354
+
355
def run_server(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
    """Run the API server

    Configures logging via training.trainer.setup_logging() and serves the
    module-level `app` with uvicorn.
    """
    setup_logging()

    config = uvicorn.Config(
        app=app,
        host=host,
        port=port,
        # NOTE(review): uvicorn.Server.run() serves in-process; the workers
        # setting is presumably ignored without a process supervisor (e.g.
        # uvicorn.run with an import string) — confirm.
        workers=workers,
        log_level="info",
        access_log=True,
        reload=False  # Set to True for development
    )

    server = uvicorn.Server(config)
    server.run()
371
+
372
if __name__ == "__main__":
    # Script entry point: serve on 0.0.0.0:8000 with default settings.
    run_server()
api/load_balancer.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load Balancer for Mamba Swarm API
3
+ Distributes requests across multiple API server instances
4
+ """
5
+
6
+ import asyncio
7
+ import aiohttp
8
+ import random
9
+ import time
10
+ import logging
11
+ from typing import List, Dict, Any, Optional, Tuple
12
+ from dataclasses import dataclass, field
13
+ from enum import Enum
14
+ from collections import defaultdict, deque
15
+ import json
16
+ import hashlib
17
+
18
class LoadBalancingStrategy(Enum):
    """Server-selection strategies dispatched by LoadBalancer.select_server()."""
    ROUND_ROBIN = "round_robin"
    LEAST_CONNECTIONS = "least_connections"
    WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
    LEAST_RESPONSE_TIME = "least_response_time"
    HASH_BASED = "hash_based"  # prompt-hash affinity
    RESOURCE_AWARE = "resource_aware"  # weighted by ServerInstance.load_score
25
+
26
@dataclass
class ServerInstance:
    """One backend API server plus its mutable runtime metrics."""
    host: str
    port: int
    weight: float = 1.0
    max_connections: int = 100
    timeout: float = 30.0
    current_connections: int = 0
    total_requests: int = 0
    failed_requests: int = 0
    response_times: deque = field(default_factory=lambda: deque(maxlen=100))
    last_health_check: float = 0.0
    is_healthy: bool = True
    health_check_failures: int = 0

    @property
    def url(self) -> str:
        """Base HTTP URL of this server."""
        return f"http://{self.host}:{self.port}"

    @property
    def avg_response_time(self) -> float:
        """Mean of the recorded response times; 0.0 when nothing is recorded yet."""
        samples = self.response_times
        if not samples:
            return 0.0
        return sum(samples) / len(samples)

    @property
    def success_rate(self) -> float:
        """Fraction of requests that succeeded; 1.0 before any traffic."""
        if self.total_requests == 0:
            return 1.0
        return (self.total_requests - self.failed_requests) / self.total_requests

    @property
    def load_score(self) -> float:
        """Calculate load score for resource-aware balancing (lower is better)."""
        connection_load = self.current_connections / self.max_connections
        response_time_load = min(self.avg_response_time / 1000.0, 1.0)  # Normalize to seconds
        failure_rate = self.failed_requests / max(self.total_requests, 1)
        return 0.4 * connection_load + 0.4 * response_time_load + 0.2 * failure_rate
64
+
65
class LoadBalancer:
    """Advanced load balancer for Mamba Swarm API servers

    Keeps a pool of ServerInstance records, health-checks them periodically,
    and forwards HTTP requests to a server chosen by the configured
    LoadBalancingStrategy, retrying on failure.
    """

    def __init__(self,
                 servers: List[Tuple[str, int]],
                 strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE,
                 health_check_interval: float = 30.0,
                 health_check_timeout: float = 5.0,
                 max_retries: int = 3):

        self.logger = logging.getLogger(__name__)
        self.strategy = strategy
        self.health_check_interval = health_check_interval
        self.health_check_timeout = health_check_timeout
        self.max_retries = max_retries

        # Initialize server instances (default weight/limits)
        self.servers = [
            ServerInstance(host=host, port=port)
            for host, port in servers
        ]

        # Strategy-specific state
        self.round_robin_index = 0
        # NOTE(review): request_counts is never written in this file — possibly vestigial.
        self.request_counts = defaultdict(int)

        # Session for HTTP requests (created in start())
        self.session: Optional[aiohttp.ClientSession] = None

        # Health check task (created in start())
        self.health_check_task: Optional[asyncio.Task] = None

        # Aggregate metrics
        self.total_requests = 0
        self.failed_requests = 0
        self.start_time = time.time()

    async def __aenter__(self):
        """Async context manager entry"""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        await self.stop()

    async def start(self):
        """Start the load balancer: open the HTTP session, launch the
        periodic health-check task, and run one immediate health sweep."""
        # Create HTTP session
        timeout = aiohttp.ClientTimeout(total=30.0, connect=10.0)
        self.session = aiohttp.ClientSession(timeout=timeout)

        # Start health check task
        self.health_check_task = asyncio.create_task(self._health_check_loop())

        # Initial health check
        await self._check_all_servers_health()

        self.logger.info(f"Load balancer started with {len(self.servers)} servers using {self.strategy.value} strategy")

    async def stop(self):
        """Stop the load balancer: cancel the health-check task and close the session."""
        if self.health_check_task:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                pass

        if self.session:
            await self.session.close()

        self.logger.info("Load balancer stopped")

    def get_healthy_servers(self) -> List[ServerInstance]:
        """Get list of healthy servers"""
        return [server for server in self.servers if server.is_healthy]

    def select_server(self, request_data: Optional[Dict[str, Any]] = None) -> Optional[ServerInstance]:
        """Select server based on configured strategy

        Returns None when no healthy server is available; request_data is
        only consulted by the hash-based strategy.
        """
        healthy_servers = self.get_healthy_servers()

        if not healthy_servers:
            self.logger.warning("No healthy servers available")
            return None

        if self.strategy == LoadBalancingStrategy.ROUND_ROBIN:
            return self._round_robin_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS:
            return self._least_connections_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN:
            return self._weighted_round_robin_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME:
            return self._least_response_time_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.HASH_BASED:
            return self._hash_based_selection(healthy_servers, request_data)
        elif self.strategy == LoadBalancingStrategy.RESOURCE_AWARE:
            return self._resource_aware_selection(healthy_servers)
        else:
            # Unknown strategy: fall back to a uniform random choice.
            return random.choice(healthy_servers)

    def _round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Round-robin server selection"""
        server = servers[self.round_robin_index % len(servers)]
        self.round_robin_index += 1
        return server

    def _least_connections_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server with least connections"""
        return min(servers, key=lambda s: s.current_connections)

    def _weighted_round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Weighted round-robin selection

        Despite the name this is a weighted *random* pick: probability of a
        server is proportional to its weight.
        """
        total_weight = sum(s.weight for s in servers)
        random_weight = random.uniform(0, total_weight)

        current_weight = 0
        for server in servers:
            current_weight += server.weight
            if random_weight <= current_weight:
                return server

        return servers[-1]  # Fallback

    def _least_response_time_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server with least average response time"""
        # NOTE(review): an avg_response_time of 0.0 (no samples yet) is falsy
        # and treated as +inf here, so untried servers are chosen last — confirm intent.
        return min(servers, key=lambda s: s.avg_response_time or float('inf'))

    def _hash_based_selection(self, servers: List[ServerInstance], request_data: Optional[Dict[str, Any]]) -> ServerInstance:
        """Hash-based selection for session affinity"""
        if not request_data or 'prompt' not in request_data:
            return random.choice(servers)

        # Use prompt hash for consistent routing (md5 here is for placement,
        # not for security).
        prompt_hash = hashlib.md5(request_data['prompt'].encode()).hexdigest()
        server_index = int(prompt_hash, 16) % len(servers)
        return servers[server_index]

    def _resource_aware_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Select server based on resource utilization"""
        # Sort by load score (lower is better)
        sorted_servers = sorted(servers, key=lambda s: s.load_score)

        # Use weighted random selection favoring lower load servers
        # (+0.1 keeps the weight finite when a load score is zero).
        weights = [1.0 / (s.load_score + 0.1) for s in sorted_servers]
        total_weight = sum(weights)

        random_value = random.uniform(0, total_weight)
        current_weight = 0

        for server, weight in zip(sorted_servers, weights):
            current_weight += weight
            if random_value <= current_weight:
                return server

        return sorted_servers[0]  # Fallback to best server

    async def forward_request(self,
                              path: str,
                              method: str = "POST",
                              data: Optional[Dict[str, Any]] = None,
                              headers: Optional[Dict[str, str]] = None,
                              **kwargs) -> Tuple[int, Dict[str, Any]]:
        """Forward request to selected server with retry logic

        Returns (status_code, response_json): 503 when no healthy server is
        available, 502 when every attempt failed; otherwise the backend's
        status and JSON body.  4xx/5xx backend responses are retried up to
        max_retries times before being passed through.
        """
        self.total_requests += 1

        for attempt in range(self.max_retries + 1):
            server = self.select_server(data)
            if not server:
                self.failed_requests += 1
                return 503, {"error": "No healthy servers available"}

            try:
                start_time = time.time()
                server.current_connections += 1

                url = f"{server.url}{path}"
                request_kwargs = {
                    "timeout": aiohttp.ClientTimeout(total=server.timeout),
                    **kwargs
                }

                if headers:
                    request_kwargs["headers"] = headers

                if data:
                    request_kwargs["json"] = data

                async with self.session.request(method, url, **request_kwargs) as response:
                    response_time = time.time() - start_time
                    response_data = await response.json()

                    # Update server metrics
                    server.current_connections -= 1
                    server.total_requests += 1
                    server.response_times.append(response_time * 1000)  # Store in ms

                    if response.status >= 400:
                        server.failed_requests += 1

                        if attempt < self.max_retries:
                            self.logger.warning(f"Request failed on {server.url} (attempt {attempt + 1}), retrying...")
                            continue

                    return response.status, response_data

            except Exception as e:
                # Undo the connection count for this failed attempt.
                server.current_connections = max(0, server.current_connections - 1)
                server.failed_requests += 1

                self.logger.error(f"Request failed on {server.url}: {e}")

                if attempt < self.max_retries:
                    # Linearly increasing delay between attempts.
                    await asyncio.sleep(0.1 * (attempt + 1))
                    continue

        self.failed_requests += 1
        return 502, {"error": "All servers failed after retries"}

    async def _check_server_health(self, server: ServerInstance) -> bool:
        """Check health of a single server via GET /health; returns True when
        it responds 200.  Failures bump health_check_failures."""
        try:
            url = f"{server.url}/health"
            timeout = aiohttp.ClientTimeout(total=self.health_check_timeout)

            async with self.session.get(url, timeout=timeout) as response:
                if response.status == 200:
                    health_data = await response.json()
                    server.last_health_check = time.time()
                    server.health_check_failures = 0

                    # Update server metrics from health data if available
                    if 'system_info' in health_data:
                        # Could extract additional metrics here
                        pass

                    return True
                else:
                    server.health_check_failures += 1
                    return False

        except Exception as e:
            server.health_check_failures += 1
            self.logger.debug(f"Health check failed for {server.url}: {e}")
            return False

    async def _check_all_servers_health(self):
        """Check health of all servers concurrently and update is_healthy flags."""
        tasks = [self._check_server_health(server) for server in self.servers]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for server, result in zip(self.servers, results):
            if isinstance(result, Exception):
                server.is_healthy = False
                server.health_check_failures += 1
            else:
                was_healthy = server.is_healthy
                # Healthy only while under 3 consecutive check failures.
                server.is_healthy = result and server.health_check_failures < 3

                if not was_healthy and server.is_healthy:
                    self.logger.info(f"Server {server.url} is back online")
                elif was_healthy and not server.is_healthy:
                    self.logger.warning(f"Server {server.url} is unhealthy")

    async def _health_check_loop(self):
        """Periodic health check loop; exits only on task cancellation."""
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._check_all_servers_health()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Health check loop error: {e}")

    def add_server(self, host: str, port: int, weight: float = 1.0):
        """Add a new server to the pool"""
        server = ServerInstance(host=host, port=port, weight=weight)
        self.servers.append(server)
        self.logger.info(f"Added server {server.url}")

    def remove_server(self, host: str, port: int):
        """Remove a server from the pool"""
        self.servers = [s for s in self.servers if not (s.host == host and s.port == port)]
        self.logger.info(f"Removed server http://{host}:{port}")

    def get_stats(self) -> Dict[str, Any]:
        """Get load balancer statistics

        Returns aggregate counters plus a per-server breakdown of health and
        performance metrics.
        """
        uptime = time.time() - self.start_time

        server_stats = []
        for server in self.servers:
            server_stats.append({
                "url": server.url,
                "is_healthy": server.is_healthy,
                "current_connections": server.current_connections,
                "total_requests": server.total_requests,
                "failed_requests": server.failed_requests,
                "success_rate": server.success_rate,
                "avg_response_time_ms": server.avg_response_time,
                "load_score": server.load_score,
                "weight": server.weight
            })

        return {
            "strategy": self.strategy.value,
            "uptime_seconds": uptime,
            "total_requests": self.total_requests,
            "failed_requests": self.failed_requests,
            "success_rate": (self.total_requests - self.failed_requests) / max(self.total_requests, 1),
            "healthy_servers": len(self.get_healthy_servers()),
            "total_servers": len(self.servers),
            "servers": server_stats
        }
379
+
380
+ # FastAPI integration
381
+ from fastapi import FastAPI, Request, HTTPException
382
+ from fastapi.responses import JSONResponse
383
+ import uvicorn
384
+
385
def create_load_balancer_app(servers: List[Tuple[str, int]],
                             strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE) -> FastAPI:
    """Create FastAPI app with load balancer

    Args:
        servers: (host, port) pairs of backend API servers.
        strategy: Server-selection strategy for the balancer.

    Returns:
        A FastAPI app exposing /lb/health, /lb/stats and a catch-all proxy route.
    """

    app = FastAPI(title="Mamba Swarm Load Balancer", version="1.0.0")
    load_balancer = LoadBalancer(servers, strategy)

    # NOTE(review): on_event is deprecated in newer FastAPI versions in favor
    # of a lifespan handler (as used in api_server.py) — consider migrating.
    @app.on_event("startup")
    async def startup():
        await load_balancer.start()

    @app.on_event("shutdown")
    async def shutdown():
        await load_balancer.stop()

    @app.get("/lb/health")
    async def lb_health():
        """Load balancer health endpoint"""
        return {"status": "healthy", "stats": load_balancer.get_stats()}

    @app.get("/lb/stats")
    async def lb_stats():
        """Get load balancer statistics"""
        return load_balancer.get_stats()

    @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
    async def proxy_request(request: Request, path: str):
        """Proxy all requests to backend servers"""
        try:
            # Get request data
            body = await request.body()
            headers = dict(request.headers)

            # Remove hop-by-hop headers
            headers.pop("host", None)
            headers.pop("connection", None)

            # Parse body if it's JSON.  Non-JSON bodies are forwarded without
            # a parsed payload; catch only decode errors instead of the
            # previous bare `except:` that swallowed every exception.  The
            # module-level json import is used (no shadowing local import).
            data = None
            if body:
                try:
                    data = json.loads(body.decode())
                except (UnicodeDecodeError, ValueError):
                    data = None

            # Forward request
            status, response_data = await load_balancer.forward_request(
                f"/{path}",
                request.method,
                data=data,
                headers=headers,
                params=dict(request.query_params)
            )

            return JSONResponse(content=response_data, status_code=status)

        except Exception as e:
            return JSONResponse(
                content={"error": f"Load balancer error: {str(e)}"},
                status_code=500
            )

    return app
449
+
450
def run_load_balancer(servers: List[Tuple[str, int]],
                      host: str = "0.0.0.0",
                      port: int = 8080,
                      strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE):
    """Run the load balancer

    Builds the proxy app for the given backends and serves it with uvicorn.
    """
    lb_app = create_load_balancer_app(servers, strategy)

    uvicorn_config = uvicorn.Config(
        app=lb_app,
        host=host,
        port=port,
        log_level="info"
    )

    uvicorn.Server(uvicorn_config).run()
466
+
467
if __name__ == "__main__":
    # Example usage: balance across three local API server instances.
    servers = [
        ("localhost", 8000),
        ("localhost", 8001),
        ("localhost", 8002),
    ]

    run_load_balancer(servers, strategy=LoadBalancingStrategy.RESOURCE_AWARE)