|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
from typing import Optional |
|
|
import logging |
|
|
|
|
|
# NOTE(review): basicConfig() at import time configures the root logger as a
# module side effect; appropriate only if this module is the application
# entry point — confirm, otherwise move into the caller.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
# Complexity score (0.0–1.0) below which a prompt is routed to the local SLM.
COMPLEXITY_THRESHOLD = 0.3


# Battery percentage below which complex tasks stay on-device in degraded mode.
MIN_BATTERY_PERCENT = 20


# Hard cap on accepted prompt length, in characters; longer prompts are rejected.
MAX_PROMPT_LENGTH = 10000
|
|
|
|
|
|
|
|
class NetworkStatus(Enum):
    """Connectivity states reported by a NetworkMonitor."""

    CONNECTED = "CONNECTED"

    DISCONNECTED = "DISCONNECTED"

    # NOTE(review): HIGH_LATENCY is defined but routing currently treats it
    # the same as CONNECTED — confirm whether that is intended.
    HIGH_LATENCY = "HIGH_LATENCY"
|
|
|
|
|
|
|
|
class RoutingDecision(Enum):
    """Where an inference request was (or would be) executed."""

    # Local small language model, chosen on merit (simple task).
    LOCAL = "LOCAL"

    # Cloud large language model.
    CLOUD = "CLOUD"

    # Local model used as a fallback for a task that merited the cloud.
    DEGRADED_LOCAL = "DEGRADED_LOCAL"
|
|
|
|
|
|
|
|
@dataclass
class InferenceResult:
    """Outcome of a routed inference request."""

    # Generated text, or None when inference failed.
    response: Optional[str]

    # Routing path taken (or attempted) for this request.
    routing_decision: RoutingDecision

    # True when a local model served a request in limited-capability mode.
    is_degraded: bool = False

    # Human-readable failure description; None on success.
    error_message: Optional[str] = None
|
|
|
|
|
|
|
|
class ComplexityClassifier(ABC):
    """Interface for estimating how complex a prompt is to answer."""

    @abstractmethod

    def predict(self, prompt: str) -> float:


        """Returns complexity score between 0.0 and 1.0."""


        pass
|
|
|
|
|
|
|
|
class NetworkMonitor(ABC):
    """Interface for querying current network connectivity."""

    @abstractmethod

    def get_status(self) -> NetworkStatus:
        """Return the current NetworkStatus."""
        pass
|
|
|
|
|
|
|
|
class DeviceMonitor(ABC):
    """Interface for querying device state relevant to routing."""

    @abstractmethod

    def get_battery_percent(self) -> int:
        """Return the current battery charge as an integer percentage."""
        pass
|
|
|
|
|
|
|
|
class LanguageModel(ABC):
    """Interface for a text-generation model (local SLM or cloud LLM)."""

    @abstractmethod

    def generate(self, prompt: str) -> str:
        """Generate and return a response for the given prompt."""
        pass
|
|
|
|
|
|
|
|
class UserNotifier(ABC):
    """Interface for surfacing warnings to the end user."""

    @abstractmethod

    def warn_user(self, message: str) -> None:
        """Display a warning message to the user."""
        pass
|
|
|
|
|
|
|
|
class InferenceRouter:
    """Routes inference requests based on complexity and infrastructure context.

    Policy (see route()): simple prompts run on the local SLM; complex
    prompts go to the cloud LLM unless the device is offline or low on
    battery, in which case they run locally in degraded mode. Cloud
    failures also fall back to degraded local execution.
    """

    def __init__(
        self,
        local_classifier: ComplexityClassifier,
        network_monitor: NetworkMonitor,
        device_monitor: DeviceMonitor,
        local_slm: LanguageModel,
        cloud_llm: LanguageModel,
        user_notifier: UserNotifier,
        complexity_threshold: float = COMPLEXITY_THRESHOLD,
        min_battery_percent: int = MIN_BATTERY_PERCENT
    ):
        """Store collaborators and routing thresholds.

        Args:
            local_classifier: Scores prompt complexity on-device.
            network_monitor: Reports current connectivity.
            device_monitor: Reports battery level.
            local_slm: On-device small language model.
            cloud_llm: Remote large language model.
            user_notifier: Channel for warning the user about degraded mode.
            complexity_threshold: Score below which a task is "simple".
            min_battery_percent: Battery floor below which cloud routing is
                skipped for complex tasks.
        """
        self._local_classifier = local_classifier
        self._network_monitor = network_monitor
        self._device_monitor = device_monitor
        self._local_slm = local_slm
        self._cloud_llm = cloud_llm
        self._user_notifier = user_notifier
        self._complexity_threshold = complexity_threshold
        self._min_battery_percent = min_battery_percent

    def route(self, user_prompt: str) -> InferenceResult:
        """Route an inference request to the appropriate model.

        Args:
            user_prompt: Prompt text; must be a non-empty string no longer
                than MAX_PROMPT_LENGTH characters.

        Returns:
            An InferenceResult carrying the response (or an error_message)
            and the routing decision taken.
        """
        validation_failure = self._validate_prompt(user_prompt)
        if validation_failure is not None:
            return validation_failure

        try:
            complexity_score = self._local_classifier.predict(user_prompt)
        except Exception:
            logger.exception("Complexity classification failed")
            # Without a score we cannot justify a cloud round-trip; run
            # locally in degraded mode as the safe default.
            return self._execute_local(user_prompt, is_degraded=True)

        network_status = self._network_monitor.get_status()
        battery_percent = self._device_monitor.get_battery_percent()

        is_offline = network_status == NetworkStatus.DISCONNECTED
        battery_low = battery_percent < self._min_battery_percent

        # Simple tasks always stay on-device, regardless of connectivity.
        if complexity_score < self._complexity_threshold:
            logger.info("Simple task (score=%.2f) -> LOCAL", complexity_score)
            return self._execute_local(user_prompt)

        # Complex task, but the cloud is unreachable or too costly right now.
        if is_offline or battery_low:
            reason = "offline" if is_offline else "low battery"
            logger.warning("Complex task but %s -> DEGRADED_LOCAL", reason)
            self._user_notifier.warn_user("Limited capability mode")
            return self._execute_local(user_prompt, is_degraded=True)

        logger.info("Complex task (score=%.2f) -> CLOUD", complexity_score)
        return self._execute_cloud(user_prompt)

    @staticmethod
    def _validate_prompt(user_prompt: str) -> Optional[InferenceResult]:
        """Return an error InferenceResult if the prompt is invalid, else None."""
        if not user_prompt or not isinstance(user_prompt, str):
            logger.error("Invalid user prompt provided")
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message="Invalid prompt: must be a non-empty string"
            )

        if len(user_prompt) > MAX_PROMPT_LENGTH:
            logger.error(
                "Prompt exceeds maximum length of %d", MAX_PROMPT_LENGTH)
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message=f"Prompt too long: max {MAX_PROMPT_LENGTH} characters"
            )

        return None

    def _execute_local(self, prompt: str, is_degraded: bool = False) -> InferenceResult:
        """Execute inference using the local small language model.

        Args:
            prompt: Validated prompt text.
            is_degraded: True when the local model is a fallback for a task
                that merited the cloud.
        """
        try:
            response = self._local_slm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=(
                    RoutingDecision.DEGRADED_LOCAL if is_degraded
                    else RoutingDecision.LOCAL
                ),
                is_degraded=is_degraded
            )
        except Exception as e:
            logger.exception("Local model inference failed")
            # Local execution is the last line of defense; surface the
            # failure in the result rather than raising.
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                is_degraded=is_degraded,
                error_message=f"Local inference failed: {str(e)}"
            )

    def _execute_cloud(self, prompt: str) -> InferenceResult:
        """Execute inference using the cloud LLM, falling back to local.

        On any cloud failure the user is warned and the request is retried
        on-device in degraded mode.
        """
        try:
            response = self._cloud_llm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=RoutingDecision.CLOUD
            )
        except Exception:
            logger.exception("Cloud inference failed, falling back to local")
            self._user_notifier.warn_user(
                "Cloud service unavailable, using limited capability mode"
            )
            return self._execute_local(prompt, is_degraded=True)
|
|
|