# src/infrastructure/inference_router.py
"""Hybrid inference routing between an on-device SLM and a cloud LLM.

Each request is scored for complexity by a local classifier, then routed
using the current infrastructure context (network status, battery level).
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import logging

# NOTE(review): calling basicConfig() from a library module configures the
# root logger as an import side effect; applications normally own this.
# Kept as-is for backward compatibility with existing callers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration Constants (defaults; the router accepts per-instance overrides)
COMPLEXITY_THRESHOLD = 0.3   # scores below this count as "simple" -> local SLM
MIN_BATTERY_PERCENT = 20     # below this, avoid expensive cloud round-trips
MAX_PROMPT_LENGTH = 10000    # hard cap on accepted prompt size (characters)


class NetworkStatus(Enum):
    """Connectivity states reported by a NetworkMonitor."""
    CONNECTED = "CONNECTED"
    DISCONNECTED = "DISCONNECTED"
    HIGH_LATENCY = "HIGH_LATENCY"


class RoutingDecision(Enum):
    """Where an inference request was (or was intended to be) served."""
    LOCAL = "LOCAL"
    CLOUD = "CLOUD"
    DEGRADED_LOCAL = "DEGRADED_LOCAL"


@dataclass
class InferenceResult:
    """Outcome of a routed inference request."""
    # Model output text; None when validation or inference failed.
    response: Optional[str]
    # Which path served (or was chosen for) the request.
    routing_decision: RoutingDecision
    # True when the local model ran as a fallback for a complex task.
    is_degraded: bool = False
    # Human-readable failure description; None on success.
    error_message: Optional[str] = None


class ComplexityClassifier(ABC):
    """On-device classifier that estimates how hard a prompt is."""

    @abstractmethod
    def predict(self, prompt: str) -> float:
        """Returns complexity score between 0.0 and 1.0."""


class NetworkMonitor(ABC):
    """Reports current network connectivity."""

    @abstractmethod
    def get_status(self) -> NetworkStatus:
        """Returns the current NetworkStatus."""


class DeviceMonitor(ABC):
    """Reports device health metrics."""

    @abstractmethod
    def get_battery_percent(self) -> int:
        """Returns remaining battery as an integer percentage."""


class LanguageModel(ABC):
    """Minimal text-generation interface shared by local and cloud models."""

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Returns the model's completion for *prompt*."""


class UserNotifier(ABC):
    """Channel for surfacing degraded-mode warnings to the user."""

    @abstractmethod
    def warn_user(self, message: str) -> None:
        """Shows *message* to the user."""


class InferenceRouter:
    """Routes inference requests based on complexity and infrastructure context.

    Decision matrix (see route()):
      * simple prompt                       -> local SLM
      * complex prompt, offline/low battery -> local SLM, degraded mode
      * complex prompt, healthy context     -> cloud LLM (local fallback)
    """

    def __init__(
        self,
        local_classifier: ComplexityClassifier,
        network_monitor: NetworkMonitor,
        device_monitor: DeviceMonitor,
        local_slm: LanguageModel,
        cloud_llm: LanguageModel,
        user_notifier: UserNotifier,
        complexity_threshold: float = COMPLEXITY_THRESHOLD,
        min_battery_percent: int = MIN_BATTERY_PERCENT,
        max_prompt_length: int = MAX_PROMPT_LENGTH,
    ):
        """All collaborators are injected; thresholds are overridable.

        Args:
            local_classifier: scores prompt complexity on-device.
            network_monitor: supplies current connectivity.
            device_monitor: supplies battery level.
            local_slm: small on-device model (always reachable).
            cloud_llm: large remote model (needs connectivity).
            user_notifier: surfaces degraded-mode warnings.
            complexity_threshold: scores below this route locally.
            min_battery_percent: below this, cloud routing is avoided.
            max_prompt_length: prompts longer than this are rejected.
        """
        self._local_classifier = local_classifier
        self._network_monitor = network_monitor
        self._device_monitor = device_monitor
        self._local_slm = local_slm
        self._cloud_llm = cloud_llm
        self._user_notifier = user_notifier
        self._complexity_threshold = complexity_threshold
        self._min_battery_percent = min_battery_percent
        self._max_prompt_length = max_prompt_length

    def route(self, user_prompt: str) -> InferenceResult:
        """Route an inference request to the appropriate model.

        Never raises for routing/inference failures; errors are reported
        via InferenceResult.error_message.
        """
        # Input validation. Type check first so a truthy non-str value
        # is rejected by the intended branch, not by coincidence.
        if not isinstance(user_prompt, str) or not user_prompt:
            logger.error("Invalid user prompt provided")
            # routing_decision is nominal here: nothing actually executed.
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message="Invalid prompt: must be a non-empty string",
            )
        if len(user_prompt) > self._max_prompt_length:
            logger.error(
                "Prompt exceeds maximum length of %d", self._max_prompt_length
            )
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message=(
                    f"Prompt too long: max {self._max_prompt_length} characters"
                ),
            )

        # Step 1: Analyze complexity locally.
        try:
            complexity_score = self._local_classifier.predict(user_prompt)
        except Exception:
            logger.exception("Complexity classification failed")
            # Fail safe: the local model can always run.
            return self._execute_local(user_prompt, is_degraded=True)

        # Step 2: Check infrastructure context.
        network_status = self._network_monitor.get_status()
        battery_percent = self._device_monitor.get_battery_percent()
        # NOTE(review): HIGH_LATENCY is treated the same as CONNECTED, so
        # complex tasks still go to the cloud on a slow link — confirm intended.
        is_offline = network_status == NetworkStatus.DISCONNECTED
        battery_low = battery_percent < self._min_battery_percent

        # ROUTING DECISION MATRIX
        # Scenario A: simple task (e.g. "Set an alarm") -> local SLM.
        if complexity_score < self._complexity_threshold:
            logger.info("Simple task (score=%.2f) -> LOCAL", complexity_score)
            return self._execute_local(user_prompt)

        # Scenario B: hard task, but no internet or low battery -> degraded local.
        if is_offline or battery_low:
            reason = "offline" if is_offline else "low battery"
            logger.warning("Complex task but %s -> DEGRADED_LOCAL", reason)
            self._user_notifier.warn_user("Limited capability mode")
            return self._execute_local(user_prompt, is_degraded=True)

        # Scenario C: hard task + good internet + battery OK -> cloud LLM.
        logger.info("Complex task (score=%.2f) -> CLOUD", complexity_score)
        return self._execute_cloud(user_prompt)

    def _execute_local(
        self, prompt: str, is_degraded: bool = False
    ) -> InferenceResult:
        """Execute inference using the local small language model."""
        try:
            response = self._local_slm.generate(prompt)
        except Exception as e:
            logger.exception("Local model inference failed")
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                is_degraded=is_degraded,
                error_message=f"Local inference failed: {str(e)}",
            )
        return InferenceResult(
            response=response,
            routing_decision=(
                RoutingDecision.DEGRADED_LOCAL if is_degraded
                else RoutingDecision.LOCAL
            ),
            is_degraded=is_degraded,
        )

    def _execute_cloud(self, prompt: str) -> InferenceResult:
        """Execute inference on the cloud LLM, falling back to local on error."""
        try:
            response = self._cloud_llm.generate(prompt)
        except Exception:
            logger.exception("Cloud inference failed, falling back to local")
            self._user_notifier.warn_user(
                "Cloud service unavailable, using limited capability mode"
            )
            return self._execute_local(prompt, is_degraded=True)
        return InferenceResult(
            response=response,
            routing_decision=RoutingDecision.CLOUD,
        )