|
|
|
|
|
|
|
|
import hashlib
import logging
import re
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Dict, Any, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
|
# Module-scoped logger; handlers and levels are left to the application.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ReasoningPath(Enum):
    """Available reasoning paths with different compute requirements"""

    # Declaration order matters: callers index ``list(ReasoningPath)`` with the
    # router's argmax, so members must stay in cheapest-to-most-expensive order.
    FAST = "fast"                # distilled model / cached result
    STANDARD = "standard"        # single base-model forward pass
    EXPERT = "expert"            # standard pass, flagged for MoE use
    DEEP = "deep"                # repeated passes with early stopping
    ULTRA_DEEP = "ultra_deep"    # recursive deep passes merged together
|
|
|
|
|
|
|
|
|
@dataclass
class ComplexityFeatures:
    """Features used for complexity scoring"""

    # -- required, positional fields (order is part of the constructor API) --
    token_length: int            # number of tokens in the sequence
    token_entropy: float         # entropy estimate of the token sequence
    has_math: bool               # text contains a math marker
    has_code: bool               # text contains a code marker
    named_entities_count: int    # distinct capitalised words (rough NER proxy)
    syntactic_depth: float       # heuristic nesting / length measure
    conversation_depth: int      # number of prior conversation turns

    # -- optional fields with defaults --
    prior_failures: int = 0
    user_preference_score: float = 0.5
    # NOTE(review): not read by any code visible in this module; kept for callers.
    use_moe: bool = False
    # A fresh dict per instance (never shared between instances).
    domain_signals: Dict[str, float] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
@dataclass
class RoutingDecision:
    """Routing decision output"""

    path: ReasoningPath              # the reasoning path chosen for this input
    confidence: float                # router confidence (1.0 for overrides)
    complexity_score: float          # blended complexity estimate
    estimated_latency_ms: float      # expected latency of the chosen path(s)
    # Free-form diagnostics (router probabilities, features, override flags, ...);
    # a new dict is created for each instance.
    debug_info: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
class ComplexityScorer(nn.Module):
    """Neural network for scoring input complexity.

    ``forward`` maps a ``ComplexityFeatures`` record to a scalar score in
    ``[0, 1]`` (sigmoid output). The vector fed to ``complexity_head`` has width
    ``hidden_dim`` and a fixed layout::

        [ text_encoder (hidden_dim // 2) | math slot (hidden_dim // 4)
          | code slot (hidden_dim // 4) | zero padding, if any ]

    A slot is all zeros when its flag is false. Each slot therefore always sits
    at the same position, regardless of which flags are set.
    """

    # Number of scalar features built in ``forward`` before padding to feature_dim.
    _N_SCALAR_FEATURES = 9
    # Input width of the math/code indicator encoders.
    _INDICATOR_DIM = 32
    # Substring markers used by ``extract_features``.
    _MATH_MARKERS = ('=', '∫', '∑', '∂', 'sqrt', 'log')
    _CODE_MARKERS = ('def', 'class', 'function', '{', '}', '()', '[]')
    _CAPITALIZED_WORD = re.compile(r'\b[A-Z][a-z]+\b')

    def __init__(self, feature_dim: int = 128, hidden_dim: int = 256):
        """Build the encoders and scoring head.

        Args:
            feature_dim: Width of the normalised feature vector. This must be
                at least 9, because the 9 scalar features are zero-padded up to it.
            hidden_dim: Width of the combined representation fed to the head.

        Raises:
            ValueError: If ``feature_dim`` is smaller than the number of scalar features.
        """
        super().__init__()
        if feature_dim < self._N_SCALAR_FEATURES:
            raise ValueError(
                f"feature_dim must be >= {self._N_SCALAR_FEATURES}, got {feature_dim}"
            )
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        self.text_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
        )

        self.math_encoder = nn.Linear(self._INDICATOR_DIM, hidden_dim // 4)
        self.code_encoder = nn.Linear(self._INDICATOR_DIM, hidden_dim // 4)

        self.complexity_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Normalisation statistics (identity by default); saved in the state dict.
        self.register_buffer('feature_mean', torch.zeros(feature_dim))
        self.register_buffer('feature_std', torch.ones(feature_dim))

    @staticmethod
    def _unigram_entropy(tokens: torch.Tensor) -> float:
        """Return the Shannon entropy (nats) of the empirical token-id distribution.

        An empty tensor yields 0.0.
        """
        flat = tokens.reshape(-1)
        if flat.numel() == 0:
            return 0.0
        _, counts = torch.unique(flat, return_counts=True)
        probs = counts.to(torch.float64) / flat.numel()
        return float(-(probs * probs.log()).sum().item())

    def extract_features(self, text: str, tokens: torch.Tensor) -> "ComplexityFeatures":
        """Extract heuristic complexity features from raw text and its token ids.

        Args:
            text: Raw input text (may be empty).
            tokens: Token ids of one sequence; ``len(tokens)`` is used as its length.

        Returns:
            A ``ComplexityFeatures`` with ``conversation_depth=0``; the optional
            fields keep their defaults.
        """
        token_length = len(tokens)
        # Deterministic: entropy of the token-id histogram (was random noise before).
        token_entropy = self._unigram_entropy(tokens)

        # Plain substring checks: cheap but coarse (e.g. 'def' also matches 'define').
        has_math = any(symbol in text for symbol in self._MATH_MARKERS)
        has_code = any(keyword in text for keyword in self._CODE_MARKERS)

        # Distinct capitalised words as a rough proxy for named entities.
        named_entities_count = len(set(self._CAPITALIZED_WORD.findall(text)))

        # Sentence count scaled by log(1 + comma-separated chunk count).
        syntactic_depth = float(len(text.split('.')) * np.log(1 + len(text.split(','))))

        return ComplexityFeatures(
            token_length=token_length,
            token_entropy=token_entropy,
            has_math=has_math,
            has_code=has_code,
            named_entities_count=named_entities_count,
            syntactic_depth=syntactic_depth,
            conversation_depth=0,
        )

    def _indicator_features(self, encoder: nn.Linear, active: bool,
                            dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Encode a boolean flag into ``encoder.out_features`` values; zeros when inactive."""
        if active:
            # Constant input => the encoder acts as a learned embedding of the flag.
            indicator = torch.ones(1, self._INDICATOR_DIM, dtype=dtype, device=device)
            return encoder(indicator)
        return torch.zeros(1, encoder.out_features, dtype=dtype, device=device)

    def forward(self, features: "ComplexityFeatures") -> torch.Tensor:
        """Compute the complexity score for ``features``.

        Returns:
            A 0-dim tensor in ``[0, 1]``, in this module's parameter dtype and device.
            The result is deterministic in eval mode.
        """
        param = next(self.parameters())
        dtype, device = param.dtype, param.device

        feature_vec = torch.tensor([[
            features.token_length / 1000.0,
            features.token_entropy / 10.0,
            float(features.has_math),
            float(features.has_code),
            features.named_entities_count / 20.0,
            features.syntactic_depth / 100.0,
            features.conversation_depth / 10.0,
            features.prior_failures / 5.0,
            features.user_preference_score,
        ]], dtype=dtype, device=device)

        # Zero-pad the scalar features up to feature_dim, then normalise.
        missing = self.feature_mean.shape[0] - feature_vec.shape[1]
        if missing > 0:
            feature_vec = F.pad(feature_vec, (0, missing))
        mean = self.feature_mean.to(dtype=dtype, device=device)
        std = self.feature_std.to(dtype=dtype, device=device)
        feature_vec = (feature_vec - mean) / (std + 1e-8)

        combined = torch.cat([
            self.text_encoder(feature_vec),
            self._indicator_features(self.math_encoder, features.has_math, dtype, device),
            self._indicator_features(self.code_encoder, features.has_code, dtype, device),
        ], dim=-1)

        # Integer division can leave the concatenation short of hidden_dim.
        missing = self.hidden_dim - combined.shape[1]
        if missing > 0:
            combined = F.pad(combined, (0, missing))

        return self.complexity_head(combined).squeeze()
|
|
|
|
|
|
|
|
|
class RouterNetwork(nn.Module):
    """Neural router for path selection.

    The router combines the mean-pooled hidden states with 9 scalar complexity
    features. It scores ``n_paths`` reasoning paths and also emits a
    per-row confidence.
    """

    def __init__(self, hidden_dim: int = 4096, router_hidden: int = 1024, n_paths: int = 4):
        super().__init__()
        self.n_paths = n_paths

        half_hidden = router_hidden // 2
        # The "+ 9" accounts for the scalar complexity features appended in forward().
        self.router = nn.Sequential(
            nn.Linear(hidden_dim + 9, router_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(router_hidden, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, n_paths),
        )

        # Confidence is conditioned on the pooled state and the path distribution.
        self.confidence = nn.Sequential(
            nn.Linear(hidden_dim + n_paths, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden_states: torch.Tensor, complexity_features: ComplexityFeatures) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route to appropriate path based on input.

        Args:
            hidden_states: Rank-3 tensor. It is averaged over dim 1, and its last
                dim must equal the ``hidden_dim`` given at construction.
            complexity_features: The same scalar features are used for every row.

        Returns:
            ``(probs, confidence)``: a softmax over paths of shape ``(batch, n_paths)``
            and a sigmoid confidence of shape ``(batch,)``.
        """
        cf = complexity_features
        scalars = [
            cf.token_length / 1000.0,
            cf.token_entropy / 10.0,
            float(cf.has_math),
            float(cf.has_code),
            cf.named_entities_count / 20.0,
            cf.syntactic_depth / 100.0,
            cf.conversation_depth / 10.0,
            cf.prior_failures / 5.0,
            cf.user_preference_score,
        ]

        n_rows = hidden_states.shape[0]
        pooled = hidden_states.mean(dim=1)

        extra = torch.tensor(scalars, dtype=hidden_states.dtype, device=hidden_states.device)
        extra = extra.unsqueeze(0).repeat(n_rows, 1)

        logits = self.router(torch.cat([pooled, extra], dim=-1))
        probs = F.softmax(logits, dim=-1)

        confidence = self.confidence(torch.cat([pooled, probs], dim=-1)).squeeze(-1)
        return probs, confidence
|
|
|
|
|
|
|
|
|
class DynamicReasoningEngine(nn.Module):
    """Main DRE orchestrator for adaptive inference.

    For each call, the engine blends two complexity estimates and asks a learned
    router which :class:`ReasoningPath` to take. It then runs that path against
    ``base_model`` (or the small ``fast_model``) and records latency and
    activation statistics, which :meth:`get_current_metrics` reports.

    Optional ``config`` keys read here:
        ``hidden_dim`` (default 4096): input width of the router and of
            ``hidden_complexity_head``. It must equal the last dim of
            ``base_model.embed_tokens(input_ids)``.
        ``cache_size`` (default 1024): maximum number of fast-path cache
            entries. The least recently used entry is evicted first.
        ``history_size`` (default 1000): maximum length of each metrics history list.

    NOTE(review): the following are assumptions to confirm against the concrete
    model class. ``base_model`` is assumed to provide ``embed_tokens(input_ids)``.
    It needs a ``config`` with ``n_embd``/``vocab_size``, but only when
    ``fast_model`` is not given. Deep paths need its outputs to be mappings
    (``.get('logits')`` / ``.keys()``).
    """

    # Relative compute cost per path, in ReasoningPath declaration order
    # (FAST, STANDARD, EXPERT, DEEP, ULTRA_DEEP); used by the auxiliary loss.
    _PATH_COSTS = (0.1, 1.0, 1.5, 2.5, 4.0)

    def __init__(
        self,
        base_model: nn.Module,
        config: Dict[str, Any],
        fast_model: Optional[nn.Module] = None,
        enable_caching: bool = True
    ):
        super().__init__()

        self.base_model = base_model
        # `is not None` rather than `or`: nn.Sequential/ModuleList define __len__,
        # so an empty container would otherwise be mistaken for "missing".
        self.fast_model = fast_model if fast_model is not None else self._create_distilled_model()
        self.config = config

        hidden_dim = config.get('hidden_dim', 4096)
        self.complexity_scorer = ComplexityScorer()
        self.router = RouterNetwork(
            hidden_dim=hidden_dim,
            n_paths=len(ReasoningPath)
        )

        # Estimates complexity from the mean-pooled token embeddings.
        self.hidden_complexity_head = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        # Fast-path cache: bounded LRU keyed by a digest of the input ids.
        self.enable_caching = enable_caching
        self.cache_size = max(1, int(config.get('cache_size', 1024)))
        self.cache = OrderedDict() if enable_caching else None
        self.cache_hits = 0
        self.cache_misses = 0

        # Exclusive upper bounds on the complexity score, used in eval mode
        # (see _threshold_path). ULTRA_DEEP's value is a lower bound instead.
        self.complexity_thresholds = {
            ReasoningPath.FAST: 0.2,
            ReasoningPath.STANDARD: 0.35,
            ReasoningPath.EXPERT: 0.5,
            ReasoningPath.DEEP: 0.75,
            ReasoningPath.ULTRA_DEEP: 0.9
        }

        # Metrics history; each list is trimmed to `history_size` entries.
        self.history_size = max(1, int(config.get('history_size', 1000)))
        self.latency_history = {path: [] for path in ReasoningPath}
        self.activation_counts = {path: 0 for path in ReasoningPath}
        self.total_activations = 0
        self.complexity_scores = []
        self.confidence_scores = []
        self.reasoning_steps = []

        # Base-model passes made by the most recent deep / ultra-deep call.
        self._last_reasoning_steps = 1
        # Router tensors from the most recent route() call (used for the aux loss).
        self._last_router_tensors = None

    def _create_distilled_model(self):
        """Create a smaller stand-in for the base model (pooled embedding -> vocab logits)."""
        return nn.Sequential(
            nn.Linear(self.base_model.config.n_embd, 512),
            nn.ReLU(),
            nn.Linear(512, self.base_model.config.vocab_size)
        )

    # ------------------------------------------------------------------ cache

    @staticmethod
    def _cache_key(input_ids: torch.Tensor) -> str:
        """Return a collision-resistant key that covers the ids' shape, dtype and values."""
        ids = input_ids.detach().cpu()
        digest = hashlib.sha256(ids.numpy().tobytes()).hexdigest()
        # Shape and dtype are part of the key: identical bytes can belong to
        # differently shaped batches (e.g. (1, 4) vs (2, 2)).
        return f"{tuple(ids.shape)}|{ids.dtype}|{digest}"

    def _check_cache(self, input_hash: str) -> Optional[torch.Tensor]:
        """Return the cached output for ``input_hash``, or None (and count a miss)."""
        if not self.enable_caching:
            return None

        if input_hash in self.cache:
            self.cache.move_to_end(input_hash)  # mark as most recently used
            self.cache_hits += 1
            logger.debug("Cache hit! Hits: %d, Misses: %d", self.cache_hits, self.cache_misses)
            return self.cache[input_hash]

        self.cache_misses += 1
        return None

    def _cache_store(self, key: str, value: torch.Tensor) -> None:
        """Insert ``value`` and evict least-recently-used entries beyond ``cache_size``."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        while len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)

    # ---------------------------------------------------------------- paths

    def _fast_inference(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        """Fast path: a cached result, or the distilled model on pooled embeddings.

        Returns None when no ``fast_model`` is available and nothing is cached.
        """
        key = self._cache_key(input_ids) if self.enable_caching else None
        if key is not None:
            cached = self._check_cache(key)
            if cached is not None:
                return cached

        if self.fast_model is None:
            return None

        with torch.no_grad():
            pooled = self.base_model.embed_tokens(input_ids).mean(dim=1)
            output = self.fast_model(pooled)

        if key is not None:
            self._cache_store(key, output)
        return output

    def _standard_inference(self, input_ids: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Standard path: one normal forward pass of the base model."""
        return self.base_model(input_ids, **kwargs)

    def _deep_inference(
        self,
        input_ids: torch.Tensor,
        max_steps: int = 10,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Deep path: repeated base-model passes, stopping early once confident.

        The number of passes actually made is stored in ``_last_reasoning_steps``.

        NOTE(review): every step re-feeds the original ``input_ids`` and appends
        no generated tokens. This is not yet true chain-of-thought.

        Raises:
            ValueError: If ``max_steps`` < 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")

        outputs = []
        for _ in range(max_steps):
            step_output = self.base_model(input_ids, **kwargs)
            outputs.append(step_output)
            if self._is_reasoning_complete(step_output):
                break

        self._last_reasoning_steps = len(outputs)
        return self._aggregate_reasoning_steps(outputs)

    def _ultra_deep_inference(
        self,
        input_ids: torch.Tensor,
        max_depth: int = 5,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Ultra-deep path: nested deep passes merged with the next level's result.

        At depth 0 (or below), one standard pass is made. At every other level,
        a deep pass is made and merged (0.6 / 0.4) with the result of the level
        below. The total number of base-model passes is stored in
        ``_last_reasoning_steps``.
        """
        total_steps = 0

        def recursive_reason(depth: int) -> Dict[str, torch.Tensor]:
            nonlocal total_steps
            if depth <= 0:  # `<=` so a negative depth cannot recurse forever
                total_steps += 1
                return self._standard_inference(input_ids, **kwargs)

            response = self._deep_inference(input_ids, **kwargs)
            total_steps += self._last_reasoning_steps
            # The previous version also generated an (unused) random critique
            # here; it had no effect on the result and has been dropped.
            refined = recursive_reason(depth - 1)
            return self._merge_responses(response, refined)

        result = recursive_reason(max_depth)
        self._last_reasoning_steps = total_steps
        return result

    def _is_reasoning_complete(self, output: Dict[str, torch.Tensor]) -> bool:
        """Return True when every sequence's last-position max probability exceeds 0.95."""
        logits = output.get('logits', None) if hasattr(output, 'get') else None
        if logits is None:
            return False
        probs = F.softmax(logits[:, -1, :], dim=-1)
        # Require *all* rows to be confident; the global max stopped as soon as any row was.
        return bool(probs.max(dim=-1).values.min().item() > 0.95)

    def _aggregate_reasoning_steps(self, outputs: List[Dict]) -> Dict[str, torch.Tensor]:
        """Average tensor entries across steps; keep non-tensor entries from the last step.

        Raises:
            ValueError: If ``outputs`` is empty.
        """
        if not outputs:
            raise ValueError("cannot aggregate an empty list of reasoning steps")

        aggregated = {}
        for key, value in outputs[0].items():
            if isinstance(value, torch.Tensor):
                aggregated[key] = torch.stack([o[key] for o in outputs]).mean(dim=0)
            else:
                aggregated[key] = outputs[-1][key]
        return aggregated

    def _generate_critique(self, response: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Placeholder self-critique: random noise shaped like ``response['logits']``."""
        return torch.randn_like(response['logits'])

    def _merge_responses(self, response1: Dict, response2: Dict) -> Dict[str, torch.Tensor]:
        """Blend tensor entries as ``0.6 * response1 + 0.4 * response2``.

        If ``response1`` has a non-tensor value, or ``response2`` has no value for
        the key, the ``response1`` value is kept unchanged.
        """
        merged = {}
        for key, value in response1.items():
            other = response2.get(key) if hasattr(response2, 'get') else None
            if isinstance(value, torch.Tensor) and other is not None:
                merged[key] = 0.6 * value + 0.4 * other
            else:
                merged[key] = value
        return merged

    # -------------------------------------------------------------- routing

    def _threshold_path(self, complexity_score: float, default: ReasoningPath) -> ReasoningPath:
        """Map a complexity score onto a path using ``complexity_thresholds``.

        The FAST, STANDARD, EXPERT and DEEP thresholds are exclusive upper bounds,
        checked in that order. Scores at or above the ULTRA_DEEP threshold force
        ULTRA_DEEP. Any score in between keeps the router's choice ``default``.
        """
        thresholds = self.complexity_thresholds
        for path in (ReasoningPath.FAST, ReasoningPath.STANDARD,
                     ReasoningPath.EXPERT, ReasoningPath.DEEP):
            if complexity_score < thresholds[path]:
                return path
        if complexity_score >= thresholds[ReasoningPath.ULTRA_DEEP]:
            return ReasoningPath.ULTRA_DEEP
        return default

    def route(
        self,
        input_ids: torch.Tensor,
        text: str = "",
        use_soft_routing: bool = False,
        override_path: Optional[ReasoningPath] = None
    ) -> RoutingDecision:
        """Decide which reasoning path to use.

        Router probabilities and confidence are averaged over the batch, which
        gives the same result for batch size 1 and also works for larger batches.

        Args:
            input_ids: Token ids; row 0 is used for the heuristic text features.
            text: Raw text for the heuristic features.
            use_soft_routing: Report STANDARD with a probability-weighted latency estimate.
            override_path: Force this path (confidence 1.0).
        """
        features = self.complexity_scorer.extract_features(text, input_ids[0])

        # Detached: the router and complexity heads do not train the embeddings.
        embeddings = self.base_model.embed_tokens(input_ids).detach()
        pooled = embeddings.mean(dim=1)
        complexity_hidden = self.hidden_complexity_head(pooled).squeeze(-1)
        complexity_feat = self.complexity_scorer(features).squeeze().to(
            dtype=complexity_hidden.dtype, device=complexity_hidden.device
        )
        complexity_score_tensor = 0.7 * complexity_hidden + 0.3 * complexity_feat
        complexity_score = float(complexity_score_tensor.mean().detach().cpu().item())

        probs, confidence = self.router(embeddings, features)

        # Refresh on every call so forward() never builds the aux loss from a
        # previous call's (stale) graph.
        self._last_router_tensors = {
            'probs': probs,
            'confidence': confidence,
            'complexity': complexity_score_tensor,
        }

        if override_path is not None:
            return RoutingDecision(
                path=override_path,
                confidence=1.0,
                complexity_score=complexity_score,
                estimated_latency_ms=self._estimate_latency(override_path),
                debug_info={'override': True}
            )

        batch_confidence = float(confidence.mean().item())
        probs_np = probs.detach().to(torch.float32).cpu().numpy()

        if use_soft_routing:
            return RoutingDecision(
                path=ReasoningPath.STANDARD,
                confidence=batch_confidence,
                complexity_score=complexity_score,
                estimated_latency_ms=self._estimate_latency_weighted(probs),
                debug_info={'probs': probs_np, 'soft_routing': True}
            )

        path_idx = int(probs.mean(dim=0).argmax().item())
        selected_path = list(ReasoningPath)[path_idx]

        # In eval mode the complexity thresholds take precedence over the router.
        if not self.training:
            selected_path = self._threshold_path(complexity_score, selected_path)

        return RoutingDecision(
            path=selected_path,
            confidence=batch_confidence,
            complexity_score=complexity_score,
            estimated_latency_ms=self._estimate_latency(selected_path),
            debug_info={
                'probs': probs_np,
                # Shallow copy so callers cannot mutate the features object.
                'features': dict(vars(features))
            }
        )

    def _estimate_latency(self, path: ReasoningPath) -> float:
        """Return the mean of the last 10 measured latencies, else a static midpoint (ms)."""
        latency_ranges = {
            ReasoningPath.FAST: (10, 100),
            ReasoningPath.STANDARD: (1000, 5000),
            ReasoningPath.EXPERT: (3000, 10000),
            ReasoningPath.DEEP: (10000, 60000),
            ReasoningPath.ULTRA_DEEP: (60000, 300000)
        }

        recent = self.latency_history[path][-10:]
        if recent:
            return float(np.mean(recent))

        min_lat, max_lat = latency_ranges[path]
        return (min_lat + max_lat) / 2

    def _estimate_latency_weighted(self, probs: torch.Tensor) -> float:
        """Return the expected latency under ``probs``, averaged over the batch if batched."""
        weights = probs.detach().to(torch.float32)
        if weights.dim() > 1:
            weights = weights.reshape(-1, weights.shape[-1]).mean(dim=0)
        latencies = [self._estimate_latency(path) for path in ReasoningPath]
        return float(sum(w * lat for w, lat in zip(weights.cpu().tolist(), latencies)))

    # -------------------------------------------------------------- metrics

    def _record_history(self, history: list, value: Any) -> None:
        """Append ``value`` to ``history`` and drop the oldest entries beyond ``history_size``."""
        history.append(value)
        overflow = len(history) - self.history_size
        if overflow > 0:
            del history[:overflow]

    def get_current_metrics(self) -> Dict[str, Any]:
        """Get current DRE metrics for logging.

        The same keys are always present, including before the first activation.
        Note that ``activation_rate`` holds the total activation *count*.
        """
        total = self.total_activations
        path_distribution = {
            path.value: (self.activation_counts[path] / total * 100) if total else 0.0
            for path in ReasoningPath
        }

        avg_complexity = float(np.mean(self.complexity_scores[-100:])) if self.complexity_scores else 0.0
        avg_confidence = float(np.mean(self.confidence_scores[-100:])) if self.confidence_scores else 0.0
        avg_reasoning_steps = float(np.mean(self.reasoning_steps[-50:])) if self.reasoning_steps else 0.0

        lookups = self.cache_hits + self.cache_misses
        cache_hit_rate = (self.cache_hits / lookups * 100) if (self.enable_caching and lookups) else 0.0

        return {
            'activation_rate': total,
            'avg_complexity': avg_complexity,
            'avg_confidence': avg_confidence,
            'avg_reasoning_steps': avg_reasoning_steps,
            'path_distribution': path_distribution,
            'cache_hit_rate': cache_hit_rate,
            'total_cache_hits': self.cache_hits,
            'total_cache_misses': self.cache_misses
        }

    def _compute_aux_loss(self) -> Optional[torch.Tensor]:
        """Build the auxiliary routing loss from the latest router tensors.

        The loss is the sum of three terms:
        - a load-balance penalty (squared deviation of the mean path probability
          from uniform);
        - 0.1 times the expected path cost;
        - 0.01 times the negative log confidence.

        Returns None if no router tensors are available or the computation fails.
        A failure is logged rather than raised.
        """
        tensors = self._last_router_tensors
        if not tensors:
            return None
        try:
            probs = tensors['probs']
            confidence = tensors['confidence']

            target_uniform = torch.full_like(probs[0], 1.0 / probs.shape[-1])
            balance_loss = (probs.mean(dim=0) - target_uniform).pow(2).mean()

            path_costs = torch.tensor(self._PATH_COSTS, dtype=probs.dtype, device=probs.device)
            expected_cost = (probs * path_costs).sum(dim=-1).mean()

            conf_loss = -torch.log(confidence.clamp_min(1e-6)).mean()
            return balance_loss + 0.1 * expected_cost + 0.01 * conf_loss
        except Exception:
            # Best-effort: never fail the main forward pass over the aux loss.
            logger.warning("DRE: auxiliary routing loss could not be computed", exc_info=True)
            return None

    @staticmethod
    def _is_compiling() -> bool:
        """Return True while torch.compile / dynamo is tracing, False otherwise or if unknown."""
        try:
            return bool(getattr(torch._dynamo, 'is_compiling', lambda: False)())
        except Exception:
            return False

    # -------------------------------------------------------------- forward

    def _run_path(self, path: ReasoningPath, input_ids: torch.Tensor, **kwargs) -> Any:
        """Dispatch to the inference routine for ``path``.

        Raises:
            ValueError: If ``path`` is not a known ReasoningPath.
        """
        if path == ReasoningPath.FAST:
            # The distilled model yields no loss, so use the full model when training with labels.
            if self.training and kwargs.get('labels', None) is not None:
                return self._standard_inference(input_ids, **kwargs)
            return self._fast_inference(input_ids, **kwargs)
        if path in (ReasoningPath.STANDARD, ReasoningPath.EXPERT):
            # EXPERT currently runs the standard pass; routing_info['use_moe'] flags it.
            return self._standard_inference(input_ids, **kwargs)
        if path == ReasoningPath.DEEP:
            return self._deep_inference(input_ids, **kwargs)
        if path == ReasoningPath.ULTRA_DEEP:
            return self._ultra_deep_inference(input_ids, **kwargs)
        raise ValueError(f"Unknown reasoning path: {path}")

    def forward(
        self,
        input_ids: torch.Tensor,
        text: str = "",
        override_path: Optional[ReasoningPath] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Main forward pass with dynamic routing.

        The chosen path's output is returned as a dict; a non-dict output is
        wrapped as ``{'logits': output}``. The dict also gets a ``routing_info``
        entry. In training mode it gets ``dre_aux_loss`` as well, whenever that
        loss can be computed.
        """
        routing_decision = self.route(input_ids, text, override_path=override_path)
        path = routing_decision.path

        # Reset so that non-deep paths never report a stale step count.
        self._last_reasoning_steps = 1
        start_time = time.perf_counter()  # monotonic, high-resolution timer
        output = self._run_path(path, input_ids, **kwargs)
        latency_ms = (time.perf_counter() - start_time) * 1000

        if not isinstance(output, dict):
            output = {'logits': output}

        self._record_history(self.latency_history[path], latency_ms)
        self.activation_counts[path] += 1
        self.total_activations += 1
        self._record_history(self.complexity_scores, routing_decision.complexity_score)
        self._record_history(self.confidence_scores, routing_decision.confidence)

        dre_aux_loss = self._compute_aux_loss() if self.training else None

        if path in (ReasoningPath.DEEP, ReasoningPath.ULTRA_DEEP):
            self._record_history(self.reasoning_steps, self._last_reasoning_steps)

        output['routing_info'] = {
            'path': path.value,
            'complexity_score': routing_decision.complexity_score,
            'confidence': routing_decision.confidence,
            'latency_ms': latency_ms,
            'debug': routing_decision.debug_info,
            'dre_metrics': self.get_current_metrics(),
            'use_moe': (path == ReasoningPath.EXPERT)
        }

        if dre_aux_loss is not None:
            output['dre_aux_loss'] = dre_aux_loss

        # Python-side logging is skipped while torch.compile is tracing.
        if not self._is_compiling():
            logger.info("DRE: Path=%s, Complexity=%.3f, Latency=%.1fms",
                        path.value,
                        float(routing_decision.complexity_score),
                        float(latency_ms))

        return output
|
|
|
|