File size: 6,108 Bytes
b2e0e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# src/infrastructure/inference_router.py

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import logging

# NOTE(review): calling basicConfig() at import time configures the root
# logger globally, which is an application-level decision; for a library
# module, prefer configuring logging at the program entry point instead.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Configuration Constants (defaults used by InferenceRouter)
COMPLEXITY_THRESHOLD = 0.3  # prompts scoring below this stay on-device
MIN_BATTERY_PERCENT = 20  # below this, complex tasks run locally (degraded)
MAX_PROMPT_LENGTH = 10000  # prompts longer than this (chars) are rejected


class NetworkStatus(Enum):
    """Connectivity state reported by a NetworkMonitor implementation."""

    CONNECTED = "CONNECTED"
    DISCONNECTED = "DISCONNECTED"
    # NOTE(review): InferenceRouter.route only branches on DISCONNECTED;
    # HIGH_LATENCY is currently treated the same as CONNECTED.
    HIGH_LATENCY = "HIGH_LATENCY"


class RoutingDecision(Enum):
    """Execution path an inference request was routed to."""

    LOCAL = "LOCAL"
    CLOUD = "CLOUD"
    # Local model used as a fallback (offline, low battery, or cloud failure).
    DEGRADED_LOCAL = "DEGRADED_LOCAL"


@dataclass
class InferenceResult:
    """Outcome of a routed inference request."""

    # Generated text, or None when inference failed outright.
    response: Optional[str]
    # Which path the request took (see RoutingDecision).
    routing_decision: RoutingDecision
    # True when the local model ran as a constrained fallback.
    is_degraded: bool = False
    # Human-readable failure description; None on success.
    error_message: Optional[str] = None


class ComplexityClassifier(ABC):
    """Interface for an on-device prompt-complexity scorer."""

    @abstractmethod
    def predict(self, prompt: str) -> float:
        """Returns complexity score between 0.0 and 1.0."""
        pass


class NetworkMonitor(ABC):
    """Interface for querying the current network connectivity state."""

    @abstractmethod
    def get_status(self) -> NetworkStatus:
        """Return the current NetworkStatus."""
        pass


class DeviceMonitor(ABC):
    """Interface for querying device power state."""

    @abstractmethod
    def get_battery_percent(self) -> int:
        """Return the current battery charge as an integer percentage."""
        pass


class LanguageModel(ABC):
    """Interface for a text-generation model (local SLM or cloud LLM)."""

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Generate and return a response for the given prompt."""
        pass


class UserNotifier(ABC):
    """Interface for surfacing warnings to the end user."""

    @abstractmethod
    def warn_user(self, message: str) -> None:
        """Show a warning message to the user."""
        pass


class InferenceRouter:
    """Routes inference requests based on complexity and infrastructure context.

    Decision matrix (see route()):
      * simple prompt (score below threshold)    -> LOCAL
      * complex prompt, offline or low battery   -> DEGRADED_LOCAL
      * complex prompt, connected, battery OK    -> CLOUD, falling back to
        DEGRADED_LOCAL if the cloud call raises

    route() never raises; failures are reported via
    InferenceResult.error_message.
    """

    def __init__(
        self,
        local_classifier: ComplexityClassifier,
        network_monitor: NetworkMonitor,
        device_monitor: DeviceMonitor,
        local_slm: LanguageModel,
        cloud_llm: LanguageModel,
        user_notifier: UserNotifier,
        complexity_threshold: float = COMPLEXITY_THRESHOLD,
        min_battery_percent: int = MIN_BATTERY_PERCENT,
        max_prompt_length: int = MAX_PROMPT_LENGTH
    ):
        """Wire up collaborators and routing policy knobs.

        Args:
            local_classifier: On-device complexity scorer (0.0-1.0).
            network_monitor: Source of current NetworkStatus.
            device_monitor: Source of current battery percentage.
            local_slm: Small on-device model (always available).
            cloud_llm: Large cloud model (preferred for complex tasks).
            user_notifier: Channel for degraded-mode warnings.
            complexity_threshold: Scores below this route locally.
            min_battery_percent: Below this, avoid cloud round-trips.
            max_prompt_length: Prompts longer than this are rejected.
                Injectable for consistency with the other policy knobs;
                defaults preserve prior behavior.
        """
        self._local_classifier = local_classifier
        self._network_monitor = network_monitor
        self._device_monitor = device_monitor
        self._local_slm = local_slm
        self._cloud_llm = cloud_llm
        self._user_notifier = user_notifier
        self._complexity_threshold = complexity_threshold
        self._min_battery_percent = min_battery_percent
        self._max_prompt_length = max_prompt_length

    def route(self, user_prompt: str) -> InferenceResult:
        """Route an inference request to the appropriate model."""
        validation_failure = self._validate_prompt(user_prompt)
        if validation_failure is not None:
            return validation_failure

        # Step 1: Analyze complexity locally. If the classifier itself
        # fails, default to the local model in degraded mode rather than
        # propagating the error.
        try:
            complexity_score = self._local_classifier.predict(user_prompt)
        except Exception:
            logger.exception("Complexity classification failed")
            return self._execute_local(user_prompt, is_degraded=True)

        # Step 2: Check infrastructure context.
        network_status = self._network_monitor.get_status()
        battery_percent = self._device_monitor.get_battery_percent()

        is_offline = network_status == NetworkStatus.DISCONNECTED
        battery_low = battery_percent < self._min_battery_percent

        # ROUTING DECISION MATRIX

        # Scenario A: Simple task (e.g., "Set an alarm") stays on-device.
        if complexity_score < self._complexity_threshold:
            logger.info("Simple task (score=%.2f) -> LOCAL", complexity_score)
            return self._execute_local(user_prompt)

        # Scenario B: Hard task, but no internet or low battery.
        if is_offline or battery_low:
            reason = "offline" if is_offline else "low battery"
            logger.warning("Complex task but %s -> DEGRADED_LOCAL", reason)
            self._user_notifier.warn_user("Limited capability mode")
            return self._execute_local(user_prompt, is_degraded=True)

        # Scenario C: Hard task + good internet + battery OK.
        logger.info("Complex task (score=%.2f) -> CLOUD", complexity_score)
        return self._execute_cloud(user_prompt)

    def _validate_prompt(self, user_prompt: str) -> Optional[InferenceResult]:
        """Return an error InferenceResult if the prompt is invalid, else None."""
        if not user_prompt or not isinstance(user_prompt, str):
            logger.error("Invalid user prompt provided")
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message="Invalid prompt: must be a non-empty string"
            )

        if len(user_prompt) > self._max_prompt_length:
            logger.error(
                "Prompt exceeds maximum length of %d", self._max_prompt_length)
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                error_message=f"Prompt too long: max {self._max_prompt_length} characters"
            )

        return None

    def _execute_local(self, prompt: str, is_degraded: bool = False) -> InferenceResult:
        """Execute inference using the local small language model."""
        try:
            response = self._local_slm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=(
                    RoutingDecision.DEGRADED_LOCAL if is_degraded
                    else RoutingDecision.LOCAL
                ),
                is_degraded=is_degraded
            )
        except Exception as e:
            logger.exception("Local model inference failed")
            # Last line of defense: report the failure instead of raising.
            return InferenceResult(
                response=None,
                routing_decision=RoutingDecision.LOCAL,
                is_degraded=is_degraded,
                error_message=f"Local inference failed: {str(e)}"
            )

    def _execute_cloud(self, prompt: str) -> InferenceResult:
        """Execute inference using the cloud LLM, falling back to local."""
        try:
            response = self._cloud_llm.generate(prompt)
            return InferenceResult(
                response=response,
                routing_decision=RoutingDecision.CLOUD
            )
        except Exception:
            logger.exception("Cloud inference failed, falling back to local")
            self._user_notifier.warn_user(
                "Cloud service unavailable, using limited capability mode"
            )
            return self._execute_local(prompt, is_degraded=True)