# NOTE(review): the three lines below were non-code upload metadata captured
# into the source ("tflux2011's picture" / "Upload 7 files" / "b2e0e38
# verified") — preserved here as a comment so the module parses.
# src/infrastructure/inference_router.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration Constants
COMPLEXITY_THRESHOLD = 0.3
MIN_BATTERY_PERCENT = 20
MAX_PROMPT_LENGTH = 10000
class NetworkStatus(Enum):
CONNECTED = "CONNECTED"
DISCONNECTED = "DISCONNECTED"
HIGH_LATENCY = "HIGH_LATENCY"
class RoutingDecision(Enum):
LOCAL = "LOCAL"
CLOUD = "CLOUD"
DEGRADED_LOCAL = "DEGRADED_LOCAL"
@dataclass
class InferenceResult:
response: Optional[str]
routing_decision: RoutingDecision
is_degraded: bool = False
error_message: Optional[str] = None
class ComplexityClassifier(ABC):
@abstractmethod
def predict(self, prompt: str) -> float:
"""Returns complexity score between 0.0 and 1.0."""
pass
class NetworkMonitor(ABC):
@abstractmethod
def get_status(self) -> NetworkStatus:
pass
class DeviceMonitor(ABC):
@abstractmethod
def get_battery_percent(self) -> int:
pass
class LanguageModel(ABC):
@abstractmethod
def generate(self, prompt: str) -> str:
pass
class UserNotifier(ABC):
@abstractmethod
def warn_user(self, message: str) -> None:
pass
class InferenceRouter:
"""Routes inference requests based on complexity and infrastructure context."""
def __init__(
self,
local_classifier: ComplexityClassifier,
network_monitor: NetworkMonitor,
device_monitor: DeviceMonitor,
local_slm: LanguageModel,
cloud_llm: LanguageModel,
user_notifier: UserNotifier,
complexity_threshold: float = COMPLEXITY_THRESHOLD,
min_battery_percent: int = MIN_BATTERY_PERCENT
):
self._local_classifier = local_classifier
self._network_monitor = network_monitor
self._device_monitor = device_monitor
self._local_slm = local_slm
self._cloud_llm = cloud_llm
self._user_notifier = user_notifier
self._complexity_threshold = complexity_threshold
self._min_battery_percent = min_battery_percent
def route(self, user_prompt: str) -> InferenceResult:
"""Route inference request to appropriate model."""
# Input validation
if not user_prompt or not isinstance(user_prompt, str):
logger.error("Invalid user prompt provided")
return InferenceResult(
response=None,
routing_decision=RoutingDecision.LOCAL,
error_message="Invalid prompt: must be a non-empty string"
)
if len(user_prompt) > MAX_PROMPT_LENGTH:
logger.error(
f"Prompt exceeds maximum length of {MAX_PROMPT_LENGTH}")
return InferenceResult(
response=None,
routing_decision=RoutingDecision.LOCAL,
error_message=f"Prompt too long: max {MAX_PROMPT_LENGTH} characters"
)
# Step 1: Analyze Complexity locally
try:
complexity_score = self._local_classifier.predict(user_prompt)
except Exception as e:
logger.exception("Complexity classification failed")
# Default to local model on classification failure
return self._execute_local(user_prompt, is_degraded=True)
# Step 2: Check Infrastructure Context
network_status = self._network_monitor.get_status()
battery_percent = self._device_monitor.get_battery_percent()
is_offline = network_status == NetworkStatus.DISCONNECTED
battery_low = battery_percent < self._min_battery_percent
# ROUTING DECISION MATRIX
# Scenario A: Simple Task (e.g., "Set an alarm")
if complexity_score < self._complexity_threshold:
logger.info(f"Simple task (score={complexity_score:.2f}) -> LOCAL")
return self._execute_local(user_prompt)
# Scenario B: Hard Task, but No Internet or Low Battery
if is_offline or battery_low:
reason = "offline" if is_offline else "low battery"
logger.warning(f"Complex task but {reason} -> DEGRADED_LOCAL")
self._user_notifier.warn_user("Limited capability mode")
return self._execute_local(user_prompt, is_degraded=True)
# Scenario C: Hard Task + Good Internet + Battery OK
logger.info(f"Complex task (score={complexity_score:.2f}) -> CLOUD")
return self._execute_cloud(user_prompt)
def _execute_local(self, prompt: str, is_degraded: bool = False) -> InferenceResult:
"""Execute inference using local small language model."""
try:
response = self._local_slm.generate(prompt)
return InferenceResult(
response=response,
routing_decision=(
RoutingDecision.DEGRADED_LOCAL if is_degraded
else RoutingDecision.LOCAL
),
is_degraded=is_degraded
)
except Exception as e:
logger.exception("Local model inference failed")
return InferenceResult(
response=None,
routing_decision=RoutingDecision.LOCAL,
is_degraded=is_degraded,
error_message=f"Local inference failed: {str(e)}"
)
def _execute_cloud(self, prompt: str) -> InferenceResult:
"""Execute inference using cloud LLM with fallback to local."""
try:
response = self._cloud_llm.generate(prompt)
return InferenceResult(
response=response,
routing_decision=RoutingDecision.CLOUD
)
except Exception as e:
logger.exception("Cloud inference failed, falling back to local")
self._user_notifier.warn_user(
"Cloud service unavailable, using limited capability mode"
)
return self._execute_local(prompt, is_degraded=True)