""" Multi-Server LM Studio Client Pool for Felix Framework. This module enables true parallel processing by allowing agents to use different LM Studio servers and models simultaneously, removing the bottleneck of a single server. Key Features: - Multiple LM Studio server support - Agent-type to server/model mapping - Load balancing and health checks - Failover and fault tolerance - Performance monitoring per server Usage: pool = LMStudioClientPool("config/server_config.json") response = await pool.complete_for_agent_type("research", system_prompt, user_prompt) """ import json import time import asyncio import logging from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass from enum import Enum from pathlib import Path from .lm_studio_client import LMStudioClient, LLMResponse, RequestPriority, LMStudioConnectionError logger = logging.getLogger(__name__) class LoadBalanceStrategy(Enum): """Load balancing strategies for server selection.""" ROUND_ROBIN = "round_robin" LEAST_BUSY = "least_busy" FASTEST_RESPONSE = "fastest_response" AGENT_TYPE_MAPPING = "agent_type_mapping" @dataclass class ServerConfig: """Configuration for a single LM Studio server.""" name: str url: str model: str timeout: float = 120.0 max_concurrent: int = 4 weight: float = 1.0 # For weighted load balancing enabled: bool = True @dataclass class ServerStats: """Runtime statistics for a server.""" total_requests: int = 0 total_tokens: int = 0 total_response_time: float = 0.0 current_load: int = 0 # Current active requests last_health_check: float = 0.0 health_status: bool = True average_response_time: float = 0.0 def update_stats(self, tokens: int, response_time: float): """Update server statistics.""" self.total_requests += 1 self.total_tokens += tokens self.total_response_time += response_time self.average_response_time = self.total_response_time / self.total_requests class LMStudioClientPool: """ Pool of LM Studio clients for multi-server parallel processing. Manages multiple LM Studio servers, assigns requests to appropriate servers based on agent types, and provides load balancing and failover capabilities. """ def __init__(self, config_path: Optional[str] = None, debug_mode: bool = False): """ Initialize the client pool. 

        Args:
            config_path: Path to server configuration JSON file
            debug_mode: Enable verbose debug output
        """
        self.debug_mode = debug_mode
        self.config_path = config_path

        # Server management
        self.servers: Dict[str, ServerConfig] = {}
        self.clients: Dict[str, LMStudioClient] = {}
        self.stats: Dict[str, ServerStats] = {}

        # Agent type mapping
        self.agent_mappings: Dict[str, str] = {}

        # Load balancing
        self.load_balance_strategy = LoadBalanceStrategy.AGENT_TYPE_MAPPING
        self._round_robin_index = 0

        # Health monitoring
        self._health_check_interval = 30.0  # seconds
        self._last_global_health_check = 0.0

        # Load configuration
        if config_path and Path(config_path).exists():
            self.load_config(config_path)
        else:
            self._create_default_config()

    def _create_default_config(self):
        """Create default single-server configuration."""
        default_server = ServerConfig(
            name="default",
            url="http://localhost:1234/v1",
            model="local-model",
            timeout=120.0,
            max_concurrent=4
        )

        self.servers["default"] = default_server
        self.clients["default"] = LMStudioClient(
            base_url=default_server.url,
            timeout=default_server.timeout,
            max_concurrent_requests=default_server.max_concurrent,
            debug_mode=self.debug_mode
        )
        self.stats["default"] = ServerStats()

        # Default mapping: all agent types use the default server
        self.agent_mappings = {
            "research": "default",
            "analysis": "default",
            "synthesis": "default",
            "critic": "default"
        }

        if self.debug_mode:
            print("šŸ”§ Using default single-server configuration")

    def load_config(self, config_path: str):
        """
        Load server configuration from JSON file.

        Args:
            config_path: Path to configuration file
        """
        try:
            with open(config_path, 'r') as f:
                config_data = json.load(f)

            # Load servers
            for server_data in config_data.get("servers", []):
                server_config = ServerConfig(
                    name=server_data["name"],
                    url=server_data["url"],
                    model=server_data["model"],
                    timeout=server_data.get("timeout", 120.0),
                    max_concurrent=server_data.get("max_concurrent", 4),
                    weight=server_data.get("weight", 1.0),
                    enabled=server_data.get("enabled", True)
                )
                self.servers[server_config.name] = server_config

                if server_config.enabled:
                    self.clients[server_config.name] = LMStudioClient(
                        base_url=server_config.url,
                        timeout=server_config.timeout,
                        max_concurrent_requests=server_config.max_concurrent,
                        debug_mode=self.debug_mode
                    )
                    self.stats[server_config.name] = ServerStats()

            # Load agent mappings
            self.agent_mappings = config_data.get("agent_mapping", {})

            # Load load-balancing strategy
            strategy_name = config_data.get("load_balance_strategy", "agent_type_mapping")
            try:
                self.load_balance_strategy = LoadBalanceStrategy(strategy_name)
            except ValueError:
                self.load_balance_strategy = LoadBalanceStrategy.AGENT_TYPE_MAPPING

            if self.debug_mode:
                print(f"šŸ”§ Loaded multi-server config: {len(self.servers)} servers")
                for name, server in self.servers.items():
                    status = "enabled" if server.enabled else "disabled"
                    print(f"   - {name}: {server.url} ({server.model}) [{status}]")

        except Exception as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            self._create_default_config()
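
    # Illustrative shape of the JSON file consumed by load_config(). This is a
    # sketch based on the keys parsed above; the server names, URLs, and model
    # identifiers are placeholders, not required values.
    #
    #   {
    #     "load_balance_strategy": "agent_type_mapping",
    #     "servers": [
    #       {"name": "primary", "url": "http://localhost:1234/v1",
    #        "model": "local-model", "timeout": 120.0,
    #        "max_concurrent": 4, "weight": 1.0, "enabled": true},
    #       {"name": "secondary", "url": "http://192.168.1.50:1234/v1",
    #        "model": "local-model", "enabled": true}
    #     ],
    #     "agent_mapping": {
    #       "research": "primary",
    #       "analysis": "primary",
    #       "synthesis": "secondary",
    #       "critic": "secondary"
    #     }
    #   }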

    def get_server_for_agent_type(self, agent_type: str) -> Optional[str]:
        """
        Get the appropriate server for an agent type.

        Args:
            agent_type: Type of agent (research, analysis, synthesis, critic)

        Returns:
            Server name, or None if no suitable server is available
        """
        if self.load_balance_strategy == LoadBalanceStrategy.AGENT_TYPE_MAPPING:
            return self.agent_mappings.get(agent_type, self._get_fallback_server())
        elif self.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
            return self._get_round_robin_server()
        elif self.load_balance_strategy == LoadBalanceStrategy.LEAST_BUSY:
            return self._get_least_busy_server()
        elif self.load_balance_strategy == LoadBalanceStrategy.FASTEST_RESPONSE:
            return self._get_fastest_server()

        return self._get_fallback_server()

    def _get_fallback_server(self) -> Optional[str]:
        """Get the first available server as a fallback."""
        for name, server in self.servers.items():
            if server.enabled and name in self.clients:
                return name
        return None

    def _get_round_robin_server(self) -> Optional[str]:
        """Get the next server using round-robin selection."""
        available_servers = [name for name, server in self.servers.items()
                             if server.enabled and name in self.clients]
        if not available_servers:
            return None

        server_name = available_servers[self._round_robin_index % len(available_servers)]
        self._round_robin_index = (self._round_robin_index + 1) % len(available_servers)
        return server_name

    def _get_least_busy_server(self) -> Optional[str]:
        """Get the server with the lowest current load."""
        available_servers = [(name, self.stats[name].current_load)
                             for name, server in self.servers.items()
                             if server.enabled and name in self.clients]
        if not available_servers:
            return None

        # Sort by current load (ascending)
        available_servers.sort(key=lambda x: x[1])
        return available_servers[0][0]

    def _get_fastest_server(self) -> Optional[str]:
        """Get the server with the fastest average response time."""
        available_servers = [(name, self.stats[name].average_response_time)
                             for name, server in self.servers.items()
                             if server.enabled and name in self.clients
                             and self.stats[name].total_requests > 0]
        if not available_servers:
            return self._get_fallback_server()

        # Sort by average response time (ascending)
        available_servers.sort(key=lambda x: x[1])
        return available_servers[0][0]
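
    # Strategy selection sketch: the strategy is normally read from the
    # "load_balance_strategy" field of the config file, but the attribute is
    # public and can be switched at runtime if a different policy fits the
    # workload better, e.g.:
    #
    #   pool.load_balance_strategy = LoadBalanceStrategy.LEAST_BUSY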

    async def complete_for_agent_type(self,
                                      agent_type: str,
                                      agent_id: str,
                                      system_prompt: str,
                                      user_prompt: str,
                                      temperature: float = 0.7,
                                      max_tokens: Optional[int] = None,
                                      priority: RequestPriority = RequestPriority.NORMAL) -> LLMResponse:
        """
        Complete a request using the appropriate server for the agent type.

        Args:
            agent_type: Type of agent making the request
            agent_id: ID of the requesting agent
            system_prompt: System prompt
            user_prompt: User prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens
            priority: Request priority

        Returns:
            LLM response

        Raises:
            LMStudioConnectionError: If no servers are available
        """
        server_name = self.get_server_for_agent_type(agent_type)
        if not server_name or server_name not in self.clients:
            raise LMStudioConnectionError(f"No available server for agent type: {agent_type}")

        # Check health if needed
        await self._check_server_health(server_name)

        client = self.clients[server_name]
        server_config = self.servers[server_name]
        stats = self.stats[server_name]

        if self.debug_mode:
            print(f"🌐 {agent_id} ({agent_type}) → {server_name} "
                  f"({server_config.url}, {server_config.model})")

        # Track load
        stats.current_load += 1

        try:
            start_time = time.perf_counter()

            # Make request with the server's configured model
            response = await client.complete_async(
                agent_id=agent_id,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                model=server_config.model,
                priority=priority
            )

            end_time = time.perf_counter()
            response_time = end_time - start_time

            # Update statistics
            stats.update_stats(response.tokens_used, response_time)

            return response

        finally:
            # Decrease load counter
            stats.current_load = max(0, stats.current_load - 1)

    async def complete(self,
                       agent_id: str,
                       system_prompt: str,
                       user_prompt: str,
                       temperature: float = 0.7,
                       max_tokens: Optional[int] = None,
                       priority: RequestPriority = RequestPriority.NORMAL) -> LLMResponse:
        """
        Complete a request using default load balancing (fallback method).

        Args:
            agent_id: ID of the requesting agent
            system_prompt: System prompt
            user_prompt: User prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens
            priority: Request priority

        Returns:
            LLM response
        """
        # Extract agent type from agent_id if possible (e.g. "research_001" -> "research")
        agent_type = "general"
        if "_" in agent_id:
            agent_type = agent_id.split("_")[0]

        return await self.complete_for_agent_type(
            agent_type, agent_id, system_prompt, user_prompt,
            temperature, max_tokens, priority
        )

    async def _check_server_health(self, server_name: str):
        """Check health of a specific server, at most once per check interval."""
        current_time = time.time()
        stats = self.stats[server_name]

        # Only check if enough time has passed since the last check
        if current_time - stats.last_health_check < self._health_check_interval:
            return

        client = self.clients[server_name]
        try:
            health_ok = client.test_connection()
            stats.health_status = health_ok
            stats.last_health_check = current_time

            if not health_ok and self.debug_mode:
                print(f"āš ļø Server {server_name} health check failed")

        except Exception as e:
            stats.health_status = False
            stats.last_health_check = current_time
            logger.warning(f"Health check failed for {server_name}: {e}")

    async def health_check_all_servers(self) -> Dict[str, bool]:
        """
        Check health of all servers.

        Returns:
            Dictionary mapping server names to health status
        """
        health_results = {}
        for server_name in self.clients:
            await self._check_server_health(server_name)
            health_results[server_name] = self.stats[server_name].health_status
        return health_results
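
    # Concurrent dispatch sketch: because different agent types can map to
    # different servers, requests issued together can run in true parallel.
    # Agent ids and prompts below are placeholders.
    #
    #   research_task = pool.complete_for_agent_type(
    #       "research", "research_001", system_prompt, user_prompt)
    #   analysis_task = pool.complete_for_agent_type(
    #       "analysis", "analysis_001", system_prompt, user_prompt)
    #   responses = await asyncio.gather(research_task, analysis_task)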

    def get_pool_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics for the entire pool.

        Returns:
            Dictionary with pool statistics
        """
        pool_stats = {
            "total_servers": len(self.servers),
            "active_servers": len(self.clients),
            "load_balance_strategy": self.load_balance_strategy.value,
            "servers": {}
        }

        for name, stats in self.stats.items():
            server_config = self.servers[name]
            pool_stats["servers"][name] = {
                "config": {
                    "url": server_config.url,
                    "model": server_config.model,
                    "enabled": server_config.enabled
                },
                "stats": {
                    "total_requests": stats.total_requests,
                    "total_tokens": stats.total_tokens,
                    "average_response_time": stats.average_response_time,
                    "current_load": stats.current_load,
                    "health_status": stats.health_status
                }
            }

        return pool_stats

    def get_agent_mapping_info(self) -> Dict[str, str]:
        """Get current agent type to server mappings."""
        return self.agent_mappings.copy()

    async def close_all(self):
        """Close all client connections."""
        for client in self.clients.values():
            if hasattr(client, 'close_async'):
                await client.close_async()

    def display_pool_status(self):
        """Display current pool status for debugging."""
        if not self.debug_mode:
            return

        print(f"\n╭─ LM STUDIO POOL STATUS ─╮")
        print(f"│ Strategy: {self.load_balance_strategy.value}")
        print(f"│ Servers: {len(self.clients)} active / {len(self.servers)} total")

        for name, stats in self.stats.items():
            server = self.servers[name]
            status = "🟢" if stats.health_status else "šŸ”“"
            load = f"{stats.current_load}/{server.max_concurrent}"
            avg_time = f"{stats.average_response_time:.2f}s" if stats.total_requests > 0 else "N/A"
            print(f"│ {status} {name}: {load} load, {stats.total_requests} reqs, {avg_time} avg")

        print(f"│ Agent Mapping:")
        for agent_type, server_name in self.agent_mappings.items():
            print(f"│   {agent_type} → {server_name}")

        print(f"ā•°{'─'*35}╯")
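

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the pool API). The config path, agent id,
# and prompts are placeholders; substitute your own. Because this module uses
# a relative import, run the demo from the package context (e.g. via -m).
# ---------------------------------------------------------------------------
async def _demo() -> None:
    """Spin up a pool, verify server health, run one request, and report."""
    pool = LMStudioClientPool("config/server_config.json", debug_mode=True)
    try:
        # Confirm the configured servers respond before dispatching work
        health = await pool.health_check_all_servers()
        print(f"Server health: {health}")

        response = await pool.complete_for_agent_type(
            agent_type="research",
            agent_id="research_001",
            system_prompt="You are a research agent.",
            user_prompt="Summarize the benefits of multi-server inference."
        )
        print(f"Completed request using {response.tokens_used} tokens")

        pool.display_pool_status()
    finally:
        await pool.close_all()


if __name__ == "__main__":
    asyncio.run(_demo())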