# NOTE: removed code-viewer artifacts (file size, commit hash, line-number gutter)
# src/infrastructure/inference_router.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import logging
# NOTE(review): calling basicConfig at import time configures the root logger
# as a module side effect; library code normally leaves that to the
# application — confirm no caller relies on it before removing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration Constants
COMPLEXITY_THRESHOLD = 0.3  # scores below this are routed to the local model
MIN_BATTERY_PERCENT = 20    # below this, cloud routing is avoided
MAX_PROMPT_LENGTH = 10000   # prompts longer than this are rejected outright
class NetworkStatus(Enum):
    """Connectivity states a NetworkMonitor can report."""
    CONNECTED = "CONNECTED"
    DISCONNECTED = "DISCONNECTED"
    # Defined but currently unused by routing: InferenceRouter.route only
    # tests for DISCONNECTED, so a high-latency link is treated as healthy.
    HIGH_LATENCY = "HIGH_LATENCY"
class RoutingDecision(Enum):
    """Which execution path the router selected for a request."""
    LOCAL = "LOCAL"                    # on-device small language model
    CLOUD = "CLOUD"                    # remote large language model
    DEGRADED_LOCAL = "DEGRADED_LOCAL"  # local fallback, limited capability
@dataclass
class InferenceResult:
    """Outcome of a routed inference request."""
    # Model output; None when inference failed or the prompt was rejected.
    response: Optional[str]
    # Which execution path produced (or attempted to produce) the response.
    routing_decision: RoutingDecision
    # True when a degraded local fallback was used instead of the preferred path.
    is_degraded: bool = False
    # Human-readable failure description; None on success.
    error_message: Optional[str] = None
class ComplexityClassifier(ABC):
    """Contract for estimating how hard a prompt is to answer."""

    @abstractmethod
    def predict(self, prompt: str) -> float:
        """Return a complexity score between 0.0 and 1.0 for *prompt*."""
        ...
class NetworkMonitor(ABC):
    """Contract for querying the current network connectivity state."""

    @abstractmethod
    def get_status(self) -> NetworkStatus:
        """Return the current NetworkStatus."""
        ...
class DeviceMonitor(ABC):
    """Contract for reading device hardware state."""

    @abstractmethod
    def get_battery_percent(self) -> int:
        """Return the remaining battery charge as a percentage."""
        ...
class LanguageModel(ABC):
    """Contract for text-generation backends (local or cloud)."""

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Produce a text completion for *prompt*."""
        ...
class UserNotifier(ABC):
    """Contract for surfacing warnings to the end user."""

    @abstractmethod
    def warn_user(self, message: str) -> None:
        """Show *message* to the user."""
        ...
class InferenceRouter:
    """Routes inference requests based on complexity and infrastructure context.

    Simple prompts run on the local small model. Complex prompts go to the
    cloud LLM when the network and battery allow it; otherwise they fall back
    to the local model in degraded mode (with a user-facing warning).
    """

    def __init__(
        self,
        local_classifier: ComplexityClassifier,
        network_monitor: NetworkMonitor,
        device_monitor: DeviceMonitor,
        local_slm: LanguageModel,
        cloud_llm: LanguageModel,
        user_notifier: UserNotifier,
        complexity_threshold: float = COMPLEXITY_THRESHOLD,
        min_battery_percent: int = MIN_BATTERY_PERCENT,
        max_prompt_length: int = MAX_PROMPT_LENGTH,
    ):
        """Wire up the router's collaborators and thresholds.

        Args:
            local_classifier: Scores prompt complexity in [0.0, 1.0].
            network_monitor: Reports current connectivity.
            device_monitor: Reports battery charge (percent).
            local_slm: On-device small language model.
            cloud_llm: Remote large language model.
            user_notifier: Channel for warning the user about degraded mode.
            complexity_threshold: Scores below this are executed locally.
            min_battery_percent: Below this, cloud routing is avoided.
            max_prompt_length: Prompts longer than this are rejected.
                (New, backward-compatible: defaults to the module constant
                previously hard-coded in route().)
        """
        self._local_classifier = local_classifier
        self._network_monitor = network_monitor
        self._device_monitor = device_monitor
        self._local_slm = local_slm
        self._cloud_llm = cloud_llm
        self._user_notifier = user_notifier
        self._complexity_threshold = complexity_threshold
        self._min_battery_percent = min_battery_percent
        self._max_prompt_length = max_prompt_length

    def route(self, user_prompt: str) -> InferenceResult:
        """Route an inference request to the appropriate model.

        Never raises: validation, classifier, and model failures are all
        converted into an InferenceResult carrying an error_message and/or
        a degraded response.
        """
        rejection = self._validate_prompt(user_prompt)
        if rejection is not None:
            return rejection

        # Step 1: Analyze complexity locally.
        try:
            complexity_score = self._local_classifier.predict(user_prompt)
        except Exception:
            logger.exception("Complexity classification failed")
            # Without a score we cannot justify a cloud round-trip, so
            # default to the local model in degraded mode.
            return self._execute_local(user_prompt, is_degraded=True)

        # Step 2: Check infrastructure context.
        network_status = self._network_monitor.get_status()
        battery_percent = self._device_monitor.get_battery_percent()
        is_offline = network_status == NetworkStatus.DISCONNECTED
        battery_low = battery_percent < self._min_battery_percent
        # NOTE(review): HIGH_LATENCY is treated the same as CONNECTED here,
        # so complex tasks still go to the cloud on a slow link — confirm
        # that is intended.

        # Scenario A: simple task (e.g., "Set an alarm") -> always local.
        if complexity_score < self._complexity_threshold:
            logger.info("Simple task (score=%.2f) -> LOCAL", complexity_score)
            return self._execute_local(user_prompt)

        # Scenario B: hard task, but no internet or low battery -> degraded local.
        if is_offline or battery_low:
            reason = "offline" if is_offline else "low battery"
            logger.warning("Complex task but %s -> DEGRADED_LOCAL", reason)
            self._user_notifier.warn_user("Limited capability mode")
            return self._execute_local(user_prompt, is_degraded=True)

        # Scenario C: hard task + good internet + battery OK -> cloud.
        logger.info("Complex task (score=%.2f) -> CLOUD", complexity_score)
        return self._execute_cloud(user_prompt)

    def _validate_prompt(self, user_prompt: str) -> Optional[InferenceResult]:
        """Return an error InferenceResult if the prompt is unusable, else None."""
        if not isinstance(user_prompt, str) or not user_prompt:
            logger.error("Invalid user prompt provided")
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message="Invalid prompt: must be a non-empty string"
            )
        if len(user_prompt) > self._max_prompt_length:
            logger.error(
                "Prompt exceeds maximum length of %d", self._max_prompt_length)
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message=f"Prompt too long: max {self._max_prompt_length} characters"
            )
        return None

    def _execute_local(self, prompt: str, is_degraded: bool = False) -> InferenceResult:
        """Run inference on the local small language model.

        Model failures are caught and reported via error_message, so the
        caller always receives an InferenceResult.
        """
        try:
            response = self._local_slm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=(
                    RoutingDecision.DEGRADED_LOCAL if is_degraded
                    else RoutingDecision.LOCAL
                ),
                is_degraded=is_degraded
            )
        except Exception as e:
            logger.exception("Local model inference failed")
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                is_degraded=is_degraded,
                error_message=f"Local inference failed: {str(e)}"
            )

    def _execute_cloud(self, prompt: str) -> InferenceResult:
        """Run inference on the cloud LLM, falling back to degraded local."""
        try:
            response = self._cloud_llm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=RoutingDecision.CLOUD
            )
        except Exception:
            logger.exception("Cloud inference failed, falling back to local")
            self._user_notifier.warn_user(
                "Cloud service unavailable, using limited capability mode"
            )
            return self._execute_local(prompt, is_degraded=True)