# Source: UltraThinking-LLM-Training/src/models/dynamic_reasoning.py
# Uploaded by Vedisasi using huggingface_hub (commit 54c5666, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from typing import Optional, Dict, Any, List, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class ReasoningPath(Enum):
    """Compute tiers a request can be routed to, declared cheapest first.

    The declaration order matters to the rest of this file: the engine maps
    a router output index to a member via ``list(ReasoningPath)[index]``.
    """

    # Cached / simple responses; budget under 100 ms.
    FAST = "fast"
    # Normal forward pass; roughly 1-5 s.
    STANDARD = "standard"
    # Expert MoE path (activates experts).
    EXPERT = "expert"
    # Chain-of-thought; roughly 10-60 s.
    DEEP = "deep"
    # Recursive reasoning; can take minutes.
    ULTRA_DEEP = "ultra_deep"
@dataclass
class ComplexityFeatures:
    """Hand-crafted signals describing one input, consumed by the scorer and router.

    The first seven fields are required and positional; the remaining ones
    carry defaults so callers only set them when they have the information.
    """

    # Number of tokens in the input.
    token_length: int
    # Entropy estimate over the input tokens.
    token_entropy: float
    # True when the text contains math-like markers.
    has_math: bool
    # True when the text contains code-like markers.
    has_code: bool
    # Count of distinct capitalised words used as a named-entity proxy.
    named_entities_count: int
    # Rough syntactic-complexity estimate.
    syntactic_depth: float
    # Depth of the surrounding conversation.
    conversation_depth: int
    # Number of earlier failed attempts for this request.
    prior_failures: int = 0
    # User preference signal; 0.5 is the neutral default.
    user_preference_score: float = 0.5
    # Whether to use MoE for this path.
    use_moe: bool = False
    # Optional per-domain scores; a fresh dict is created per instance.
    domain_signals: Dict[str, float] = field(default_factory=dict)
@dataclass
class RoutingDecision:
    """Result of one routing call: the chosen path plus supporting numbers."""

    # Reasoning path that should handle the request.
    path: ReasoningPath
    # Confidence attached to the decision.
    confidence: float
    # Complexity score the decision was based on.
    complexity_score: float
    # Expected latency of the chosen path, in milliseconds.
    estimated_latency_ms: float
    # Free-form diagnostics; a fresh dict is created per instance.
    debug_info: Dict[str, Any] = field(default_factory=dict)
class ComplexityScorer(nn.Module):
    """Neural network that scores input complexity.

    ``extract_features`` derives cheap heuristic features from the raw text
    and token ids; ``forward`` packs them into a vector, encodes it and maps
    it through a sigmoid-terminated head to a single score.

    Args:
        feature_dim: Width of the (zero-padded) input feature vector.
        hidden_dim: Width of the encoder / head hidden layers.
    """

    def __init__(self, feature_dim: int = 128, hidden_dim: int = 256):
        super().__init__()
        # Feature extractors
        self.text_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
        )
        # Domain-specific encoders
        self.math_encoder = nn.Linear(32, hidden_dim // 4)
        self.code_encoder = nn.Linear(32, hidden_dim // 4)
        # Complexity predictor
        self.complexity_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        # Feature statistics used to normalise the input vector
        self.register_buffer('feature_mean', torch.zeros(feature_dim))
        self.register_buffer('feature_std', torch.ones(feature_dim))

    @staticmethod
    def _token_entropy(tokens: torch.Tensor) -> float:
        """Return the Shannon entropy (in nats) of the empirical token-id distribution.

        An empty input yields 0.0. The result is deterministic for a given
        token sequence.
        """
        flat = torch.as_tensor(tokens).reshape(-1)
        total = flat.numel()
        if total == 0:
            return 0.0
        _, counts = torch.unique(flat, return_counts=True)
        # Every count is >= 1, so no epsilon is needed inside the log.
        probs = counts.to(torch.float32) / total
        return float(-(probs * torch.log(probs)).sum().item())

    def extract_features(self, text: str, tokens: torch.Tensor) -> "ComplexityFeatures":
        """Extract complexity features from the raw text and its token ids.

        Args:
            text: Raw input text used for the substring/regex heuristics.
            tokens: Token ids of the input; ``len(tokens)`` is the token count.

        Returns:
            A ``ComplexityFeatures`` with ``conversation_depth`` set to 0
            (to be filled in by the conversation manager).
        """
        import re

        # Token statistics
        token_length = len(tokens)
        token_entropy = self._token_entropy(tokens)

        # Domain detection via plain substring markers (simplified heuristics)
        has_math = any(symbol in text for symbol in ['=', '∫', '∑', '∂', 'sqrt', 'log'])
        has_code = any(keyword in text for keyword in ['def', 'class', 'function', '{', '}', '()', '[]'])

        # Named entities (simplified): distinct capitalised words
        capitals = re.findall(r'\b[A-Z][a-z]+\b', text)
        named_entities_count = len(set(capitals))

        # Syntactic complexity (simplified - could use an actual parser);
        # cast so the dataclass field holds a plain Python float.
        syntactic_depth = float(len(text.split('.')) * np.log(1 + len(text.split(','))))

        return ComplexityFeatures(
            token_length=token_length,
            token_entropy=token_entropy,
            has_math=has_math,
            has_code=has_code,
            named_entities_count=named_entities_count,
            syntactic_depth=syntactic_depth,
            conversation_depth=0,  # Set by conversation manager
        )

    def forward(self, features: "ComplexityFeatures") -> torch.Tensor:
        """Compute a complexity score from extracted features.

        Args:
            features: Object exposing the ``ComplexityFeatures`` attributes.

        Returns:
            A 0-dim tensor holding the sigmoid output of the complexity head.
        """
        reference = next(self.parameters())
        dtype, device = reference.dtype, reference.device

        # Pack the nine scalar features, each scaled by a fixed constant
        feature_vec = torch.tensor([
            features.token_length / 1000.0,
            features.token_entropy / 10.0,
            float(features.has_math),
            float(features.has_code),
            features.named_entities_count / 20.0,
            features.syntactic_depth / 100.0,
            features.conversation_depth / 10.0,
            features.prior_failures / 5.0,
            features.user_preference_score,
        ], dtype=dtype, device=device).unsqueeze(0)

        # Zero-pad up to feature_dim
        feature_dim = self.feature_mean.shape[0]
        if feature_vec.shape[1] < feature_dim:
            padding = torch.zeros((1, feature_dim - feature_vec.shape[1]), dtype=dtype, device=device)
            feature_vec = torch.cat([feature_vec, padding], dim=1)

        # Normalise with the stored statistics
        mean = self.feature_mean.to(dtype=dtype, device=device)
        std = self.feature_std.to(dtype=dtype, device=device)
        feature_vec = (feature_vec - mean) / (std + 1e-8)

        # Encode features
        text_features = self.text_encoder(feature_vec)

        # Add domain-specific features if present.
        # NOTE: the encoder inputs are random placeholders, so the score is
        # non-deterministic whenever has_math / has_code is set.
        if features.has_math:
            math_features = self.math_encoder(torch.randn(1, 32, dtype=dtype, device=device))  # Placeholder
            text_features = torch.cat([text_features, math_features], dim=-1)
        if features.has_code:
            code_features = self.code_encoder(torch.randn(1, 32, dtype=dtype, device=device))  # Placeholder
            text_features = torch.cat([text_features, code_features], dim=-1)

        # Zero-pad to the head's actual input width (was hard-coded to 256,
        # which broke every hidden_dim other than 256).
        head_in = self.complexity_head[0].in_features
        missing = head_in - text_features.shape[1]
        if missing > 0:
            padding = torch.zeros((1, missing), dtype=dtype, device=device)
            text_features = torch.cat([text_features, padding], dim=1)

        # Predict complexity
        complexity_score = self.complexity_head(text_features)
        return complexity_score.squeeze()
class RouterNetwork(nn.Module):
    """Neural router that turns hidden states plus features into path probabilities.

    Args:
        hidden_dim: Width of the incoming hidden states.
        router_hidden: Width of the first router MLP layer.
        n_paths: Number of routing targets the router scores.
    """

    def __init__(self, hidden_dim: int = 4096, router_hidden: int = 1024, n_paths: int = 4):
        super().__init__()
        self.n_paths = n_paths
        half_hidden = router_hidden // 2
        # Router MLP; the input is the pooled hidden state plus the nine
        # scalar complexity features built in ``forward``.
        self.router = nn.Sequential(
            nn.Linear(hidden_dim + 9, router_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(router_hidden, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, n_paths),
        )
        # Confidence predictor; sees the pooled state and the path probabilities.
        self.confidence = nn.Sequential(
            nn.Linear(hidden_dim + n_paths, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden_states: torch.Tensor, complexity_features: ComplexityFeatures) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score every path for the given input.

        Args:
            hidden_states: Tensor that is mean-pooled over dim 1 to
                ``[batch, hidden_dim]``.
            complexity_features: Features shared by every row of the batch.

        Returns:
            ``(probs, confidence)``: softmax path probabilities of shape
            ``[batch, n_paths]`` and a sigmoid confidence of shape ``[batch]``.
        """
        n_rows = hidden_states.size(0)
        # Pool hidden states over dim 1
        summary = hidden_states.mean(dim=1)

        # Same nine scaled scalars for every row of the batch
        cf = complexity_features
        scalars = [
            cf.token_length / 1000.0,
            cf.token_entropy / 10.0,
            float(cf.has_math),
            float(cf.has_code),
            cf.named_entities_count / 20.0,
            cf.syntactic_depth / 100.0,
            cf.conversation_depth / 10.0,
            cf.prior_failures / 5.0,
            cf.user_preference_score,
        ]
        scalar_row = torch.tensor(scalars, dtype=hidden_states.dtype, device=hidden_states.device)
        tiled = scalar_row.unsqueeze(0).expand(n_rows, -1)

        # Routing probabilities
        path_logits = self.router(torch.cat([summary, tiled], dim=-1))
        probs = F.softmax(path_logits, dim=-1)

        # Confidence from the pooled state and the probabilities
        confidence = self.confidence(torch.cat([summary, probs], dim=-1)).squeeze(-1)
        return probs, confidence
class DynamicReasoningEngine(nn.Module):
    """Main DRE orchestrator for adaptive inference.

    For every call the engine scores the input, asks the router for path
    probabilities, picks a ``ReasoningPath`` and runs the matching inference
    routine on ``base_model``. It also tracks latency / usage metrics and,
    in training mode, exposes an auxiliary loss for the router.

    Args:
        base_model: Model used for all non-fast paths. The code calls
            ``base_model(input_ids, **kwargs)`` and
            ``base_model.embed_tokens(input_ids)``.
        config: Settings dict. Reads ``hidden_dim`` (default 4096) and
            ``max_cache_size`` (default 1024; a value <= 0 disables the bound).
        fast_model: Optional small model for the FAST path; a placeholder MLP
            is built from ``base_model.config`` when omitted.
        enable_caching: Whether FAST-path outputs are cached per input.
    """

    # Relative compute cost per path in ReasoningPath declaration order
    # (FAST, STANDARD, EXPERT, DEEP, ULTRA_DEEP); used by the router aux loss.
    _PATH_COSTS = (0.1, 1.0, 1.5, 2.5, 4.0)
    _DEFAULT_MAX_CACHE_SIZE = 1024

    def __init__(
        self,
        base_model: nn.Module,
        config: Dict[str, Any],
        fast_model: Optional[nn.Module] = None,
        enable_caching: bool = True
    ):
        super().__init__()
        self.base_model = base_model
        # Explicit None check: container modules define __len__, so an empty
        # nn.Sequential passed by the caller would be falsy under ``or``.
        self.fast_model = fast_model if fast_model is not None else self._create_distilled_model()
        self.config = config

        # Components
        self.complexity_scorer = ComplexityScorer()
        self.router = RouterNetwork(
            hidden_dim=config.get('hidden_dim', 4096),
            n_paths=len(ReasoningPath)
        )
        # Hidden-state based complexity head to avoid placeholder randomness and to vary per-input
        self.hidden_complexity_head = nn.Sequential(
            nn.Linear(config.get('hidden_dim', 4096), 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        # Caching (bounded so a long-running process cannot grow it forever)
        self.enable_caching = enable_caching
        self.cache = {} if enable_caching else None
        self.max_cache_size = int(config.get('max_cache_size', self._DEFAULT_MAX_CACHE_SIZE))
        self.cache_hits = 0
        self.cache_misses = 0

        # Thresholds for routing (can be learned)
        # Reordered so EXPERT sits between STANDARD and DEEP for better MoE engagement
        self.complexity_thresholds = {
            ReasoningPath.FAST: 0.2,
            ReasoningPath.STANDARD: 0.35,
            ReasoningPath.EXPERT: 0.5,  # MoE experts - moderate complexity
            ReasoningPath.DEEP: 0.75,  # Chain-of-thought - high complexity
            ReasoningPath.ULTRA_DEEP: 0.9  # Recursive reasoning - very high complexity
        }

        # Latency tracking
        self.latency_history = {path: [] for path in ReasoningPath}

        # DRE metrics tracking
        self.activation_counts = {path: 0 for path in ReasoningPath}
        self.total_activations = 0
        self.complexity_scores = []
        self.confidence_scores = []
        self.reasoning_steps = []
        # Base-model passes made by _deep_inference during the current forward()
        self._reasoning_step_count = 0

    def _create_distilled_model(self):
        """Create a smaller stand-in for the base model (placeholder MLP).

        In practice a pre-distilled model would be loaded instead.
        """
        return nn.Sequential(
            nn.Linear(self.base_model.config.n_embd, 512),
            nn.ReLU(),
            nn.Linear(512, self.base_model.config.vocab_size)
        )

    @staticmethod
    def _cache_key(input_ids: torch.Tensor) -> str:
        """Build the cache key for ``input_ids``.

        Shape and dtype are part of the key so that tensors whose raw bytes
        coincide but whose layout differs do not share a cache entry.
        """
        host = input_ids.detach().cpu()
        return str(hash((tuple(host.shape), str(host.dtype), host.numpy().tobytes())))

    def _check_cache(self, input_hash: str) -> Optional[torch.Tensor]:
        """Return the cached output for ``input_hash`` or None, updating hit/miss counters."""
        if not self.enable_caching:
            return None
        if input_hash in self.cache:
            self.cache_hits += 1
            logger.info("Cache hit! Hits: %s, Misses: %s", self.cache_hits, self.cache_misses)
            return self.cache[input_hash]
        self.cache_misses += 1
        return None

    def _store_in_cache(self, key: str, value: torch.Tensor) -> None:
        """Insert ``value`` under ``key``, evicting the oldest entries when full."""
        if not self.enable_caching:
            return
        if key not in self.cache:
            limit = getattr(self, 'max_cache_size', self._DEFAULT_MAX_CACHE_SIZE)
            # dicts keep insertion order, so the first key is the oldest one
            while limit > 0 and len(self.cache) >= limit:
                del self.cache[next(iter(self.cache))]
        self.cache[key] = value

    def _fast_inference(self, input_ids: torch.Tensor, **kwargs) -> Optional[torch.Tensor]:
        """Fast path: return a cached output or run the distilled model.

        Returns None only when no fast model is available.
        """
        # Check cache first
        key = self._cache_key(input_ids)
        cached = self._check_cache(key)
        if cached is not None:
            return cached

        # Use distilled model on mean-pooled token embeddings
        if self.fast_model is not None:
            with torch.no_grad():
                embeddings = self.base_model.embed_tokens(input_ids)
                pooled = embeddings.mean(dim=1)
                output = self.fast_model(pooled)
            self._store_in_cache(key, output)
            return output
        return None

    def _standard_inference(self, input_ids: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Standard path: a single forward pass of the base model."""
        return self.base_model(input_ids, **kwargs)

    def _deep_inference(
        self,
        input_ids: torch.Tensor,
        max_steps: int = 10,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Deep path: run up to ``max_steps`` base-model passes and aggregate them.

        At least one pass is always executed so there is something to
        aggregate. The number of passes is added to
        ``self._reasoning_step_count`` for metric tracking.
        """
        outputs = []
        current_input = input_ids
        for _ in range(max(1, max_steps)):
            # Generate reasoning step
            step_output = self.base_model(current_input, **kwargs)
            outputs.append(step_output)
            # Check if reasoning is complete (simplified)
            if self._is_reasoning_complete(step_output):
                break
            # Prepare next input (would include generated tokens in practice)
            current_input = input_ids  # Placeholder
        self._reasoning_step_count = getattr(self, '_reasoning_step_count', 0) + len(outputs)
        # Aggregate outputs
        return self._aggregate_reasoning_steps(outputs)

    def _ultra_deep_inference(
        self,
        input_ids: torch.Tensor,
        max_depth: int = 5,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Ultra-deep path: recursive reasoning with self-reflection.

        Each level runs the deep path, then merges its response with the
        result of the next level down; depth 0 is a standard forward pass.
        """
        def recursive_reason(ids, depth):
            if depth == 0:
                return self._standard_inference(ids, **kwargs)
            # Generate initial response
            response = self._deep_inference(ids, **kwargs)
            # Self-critique (placeholder): the critique is generated but not
            # yet fed into the refinement step.
            self._generate_critique(response)
            # Refine based on critique
            refined = recursive_reason(ids, depth - 1)
            return self._merge_responses(response, refined)

        return recursive_reason(input_ids, max_depth)

    def _is_reasoning_complete(self, output: Dict[str, torch.Tensor]) -> bool:
        """Return True when the last-position max probability exceeds 0.95.

        Returns False when ``output`` has no ``'logits'`` entry.
        """
        logits = output.get('logits', None)
        if logits is None:
            return False
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return probs.max().item() > 0.95  # High confidence

    def _aggregate_reasoning_steps(self, outputs: List[Dict]) -> Dict[str, torch.Tensor]:
        """Aggregate reasoning steps: average tensor entries, keep the last non-tensor one."""
        aggregated = {}
        for key in outputs[0].keys():
            if isinstance(outputs[0][key], torch.Tensor):
                stacked = torch.stack([o[key] for o in outputs])
                aggregated[key] = stacked.mean(dim=0)
            else:
                aggregated[key] = outputs[-1][key]  # Take last
        return aggregated

    def _generate_critique(self, response: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Generate a self-critique of ``response`` (placeholder: random tensor shaped like the logits)."""
        return torch.randn_like(response['logits'])

    def _merge_responses(self, response1: Dict, response2: Dict) -> Dict[str, torch.Tensor]:
        """Merge two responses, keyed by ``response1``.

        Tensor entries present in both are blended 0.6 / 0.4; any other
        entry (non-tensor, or missing / non-tensor in ``response2``) is
        taken from ``response1`` unchanged.
        """
        merged = {}
        for key, first in response1.items():
            second = response2[key] if key in response2 else None
            if isinstance(first, torch.Tensor) and isinstance(second, torch.Tensor):
                # Weighted average
                merged[key] = 0.6 * first + 0.4 * second
            else:
                merged[key] = first
        return merged

    def route(
        self,
        input_ids: torch.Tensor,
        text: str = "",
        use_soft_routing: bool = False,
        override_path: Optional["ReasoningPath"] = None
    ) -> "RoutingDecision":
        """Decide which reasoning path to use.

        Args:
            input_ids: Token ids; features are extracted from ``input_ids[0]``.
            text: Raw text used by the heuristic feature extractor.
            use_soft_routing: Return path probabilities for a weighted
                combination instead of committing to one path.
            override_path: Force this path (confidence is reported as 1.0).

        Returns:
            A ``RoutingDecision``. Scalar confidence and the hard path choice
            are averaged over the batch, so batches larger than one work.

        Side effects:
            Stores the router tensors in ``self._last_router_tensors`` on
            every call so the aux loss never reads a previous call's graph.
        """
        # Extract features
        features = self.complexity_scorer.extract_features(text, input_ids[0])

        # Get complexity score - combine hidden-state signal with features for better variation
        # Use base embeddings as input signal but DETACH to avoid training the base model from DRE aux loss
        embeddings = self.base_model.embed_tokens(input_ids).detach()
        pooled = embeddings.mean(dim=1)
        complexity_hidden = self.hidden_complexity_head(pooled).squeeze(-1)
        complexity_features = self.complexity_scorer(features).squeeze()

        # Blend signals; the feature score is a single value shared by the batch
        if isinstance(complexity_features, torch.Tensor) and complexity_features.dim() == 0:
            complexity_features_tensor = complexity_features
        else:
            # Coerce to tensor on the right device/dtype
            complexity_features_tensor = torch.as_tensor(
                complexity_features, dtype=complexity_hidden.dtype, device=complexity_hidden.device
            )
        complexity_score_tensor = 0.7 * complexity_hidden + 0.3 * complexity_features_tensor
        complexity_score = float(complexity_score_tensor.mean().detach().cpu().item())

        # Get router prediction (allow grads for router so it can learn via aux loss)
        probs, confidence = self.router(embeddings, features)

        # Stash tensors for aux loss computation during forward(); done before
        # any early return so override / soft routing never leave stale tensors.
        self._last_router_tensors = {
            'probs': probs,
            'confidence': confidence,
            'complexity': complexity_score_tensor,
        }
        # Batch-averaged scalar; identical to .item() for a batch of one
        confidence_value = float(confidence.detach().mean().item())

        # Override if specified
        if override_path is not None:
            return RoutingDecision(
                path=override_path,
                confidence=1.0,
                complexity_score=complexity_score,
                estimated_latency_ms=self._estimate_latency(override_path),
                debug_info={'override': True}
            )

        probs_np = probs.detach().to(torch.float32).cpu().numpy()

        # Soft routing: combine outputs from multiple paths
        if use_soft_routing:
            return RoutingDecision(
                path=ReasoningPath.STANDARD,  # Default
                confidence=confidence_value,
                complexity_score=complexity_score,
                estimated_latency_ms=self._estimate_latency_weighted(probs),
                debug_info={'probs': probs_np, 'soft_routing': True}
            )

        # Hard routing: select the path with the highest batch-mean probability
        path_idx = int(probs.detach().mean(dim=0).argmax().item())
        selected_path = list(ReasoningPath)[path_idx]

        # Apply complexity threshold override only when NOT training
        # During training, allow the router to learn the mapping; rely on thresholds at inference time
        if not self.training:
            thresholds = self.complexity_thresholds
            if complexity_score < thresholds[ReasoningPath.FAST]:
                selected_path = ReasoningPath.FAST
            elif complexity_score < thresholds[ReasoningPath.STANDARD]:
                selected_path = ReasoningPath.STANDARD
            elif complexity_score < thresholds[ReasoningPath.EXPERT]:
                # EXPERT band was previously skipped, so MoE was never
                # selected by the thresholds.
                selected_path = ReasoningPath.EXPERT
            elif complexity_score < thresholds[ReasoningPath.DEEP]:
                selected_path = ReasoningPath.DEEP
            elif complexity_score >= thresholds[ReasoningPath.ULTRA_DEEP]:
                selected_path = ReasoningPath.ULTRA_DEEP
            # Scores between the DEEP and ULTRA_DEEP thresholds keep the router's choice.

        return RoutingDecision(
            path=selected_path,
            confidence=confidence_value,
            complexity_score=complexity_score,
            estimated_latency_ms=self._estimate_latency(selected_path),
            debug_info={
                'probs': probs_np,
                'features': features.__dict__
            }
        )

    def _estimate_latency(self, path: "ReasoningPath") -> float:
        """Estimate latency in ms: mean of the last 10 observations, else the range midpoint."""
        latency_ranges = {
            ReasoningPath.FAST: (10, 100),
            ReasoningPath.STANDARD: (1000, 5000),
            ReasoningPath.EXPERT: (3000, 10000),  # MoE experts - slower than standard, faster than deep
            ReasoningPath.DEEP: (10000, 60000),
            ReasoningPath.ULTRA_DEEP: (60000, 300000)
        }
        history = self.latency_history[path]
        if history:
            # Use historical average
            return float(np.mean(history[-10:]))
        # Use midpoint of range
        min_lat, max_lat = latency_ranges[path]
        return (min_lat + max_lat) / 2

    def _estimate_latency_weighted(self, probs: torch.Tensor) -> float:
        """Estimate probability-weighted latency for soft routing (uses the first batch row)."""
        latencies = [self._estimate_latency(path) for path in ReasoningPath]
        weights = probs[0].detach().to(torch.float32).cpu().tolist()
        return float(sum(w * lat for w, lat in zip(weights, latencies)))

    def get_current_metrics(self) -> Dict[str, Any]:
        """Get current DRE metrics for logging.

        The same keys are returned whether or not any activation has been
        recorded. Note that ``'activation_rate'`` holds the total activation
        count, and ``'path_distribution'`` / ``'cache_hit_rate'`` are
        percentages.
        """
        # Cache efficiency
        cache_lookups = self.cache_hits + self.cache_misses
        cache_hit_rate = 0.0
        if self.enable_caching and cache_lookups > 0:
            cache_hit_rate = self.cache_hits / cache_lookups * 100
        cache_metrics = {
            'cache_hit_rate': cache_hit_rate,
            'total_cache_hits': self.cache_hits,
            'total_cache_misses': self.cache_misses
        }

        if self.total_activations == 0:
            return {
                'activation_rate': 0.0,
                'avg_complexity': 0.0,
                'avg_confidence': 0.0,
                'avg_reasoning_steps': 0.0,
                'path_distribution': {path.value: 0.0 for path in ReasoningPath},
                **cache_metrics
            }

        # Calculate activation rates per path
        path_distribution = {
            path.value: self.activation_counts[path] / self.total_activations * 100
            for path in ReasoningPath
        }
        # Calculate averages over the most recent observations
        avg_complexity = float(np.mean(self.complexity_scores[-100:])) if self.complexity_scores else 0.0
        avg_confidence = float(np.mean(self.confidence_scores[-100:])) if self.confidence_scores else 0.0
        avg_reasoning_steps = float(np.mean(self.reasoning_steps[-50:])) if self.reasoning_steps else 0.0

        return {
            'activation_rate': self.total_activations,
            'avg_complexity': avg_complexity,
            'avg_confidence': avg_confidence,
            'avg_reasoning_steps': avg_reasoning_steps,
            'path_distribution': path_distribution,
            **cache_metrics
        }

    def _compute_aux_loss(self, probs: torch.Tensor, confidence: torch.Tensor) -> torch.Tensor:
        """Return the router auxiliary loss: balance + 0.1 * expected cost + 0.01 * confidence term.

        Args:
            probs: Path probabilities of shape ``[batch, n_paths]``.
            confidence: Router confidence of shape ``[batch]``.
        """
        # Encourage balanced usage across paths (Switch-Transformer style)
        target_uniform = torch.full_like(probs[0], 1.0 / probs.shape[-1])
        balance_loss = (probs.mean(dim=0) - target_uniform).pow(2).mean()
        # Penalize expected latency (prefer cheaper paths unless LM loss demands otherwise)
        path_costs = torch.tensor(self._PATH_COSTS, dtype=probs.dtype, device=probs.device)
        expected_cost = (probs * path_costs).sum(dim=-1).mean()
        # Encourage higher confidence
        conf_loss = -torch.log(confidence.clamp_min(1e-6)).mean()
        return balance_loss + 0.1 * expected_cost + 0.01 * conf_loss

    def forward(
        self,
        input_ids: torch.Tensor,
        text: str = "",
        override_path: Optional["ReasoningPath"] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Main forward pass with dynamic routing.

        Args:
            input_ids: Token ids passed to the router and the chosen path.
            text: Raw text for heuristic feature extraction.
            override_path: Force a specific path instead of routing.
            **kwargs: Forwarded to the base model (``labels`` is inspected to
                decide whether the FAST path must produce a loss in training).

        Returns:
            The selected path's output with ``'routing_info'`` added and, in
            training mode when it could be computed, ``'dre_aux_loss'``.

        Raises:
            ValueError: If the routing decision names an unknown path.
        """
        # Route to appropriate path
        routing_decision = self.route(input_ids, text, override_path=override_path)
        self._reasoning_step_count = 0

        # Track timing with a monotonic clock
        start_time = time.perf_counter()

        # Execute selected path
        if routing_decision.path == ReasoningPath.FAST:
            # During training with labels, run STANDARD inference to get valid loss/hidden_states
            train_needs_loss = self.training and (kwargs.get('labels', None) is not None)
            if train_needs_loss:
                output = self._standard_inference(input_ids, **kwargs)
            else:
                output = self._fast_inference(input_ids, **kwargs)
            # Convert to standard format if needed
            if not isinstance(output, dict):
                output = {'logits': output}
        elif routing_decision.path == ReasoningPath.STANDARD:
            output = self._standard_inference(input_ids, **kwargs)
        elif routing_decision.path == ReasoningPath.EXPERT:
            # Expert path shares the same base forward; UltraThinkCore will apply MoE based on routing_info['use_moe']
            output = self._standard_inference(input_ids, **kwargs)
        elif routing_decision.path == ReasoningPath.DEEP:
            output = self._deep_inference(input_ids, **kwargs)
        elif routing_decision.path == ReasoningPath.ULTRA_DEEP:
            output = self._ultra_deep_inference(input_ids, **kwargs)
        else:
            raise ValueError(f"Unknown reasoning path: {routing_decision.path}")

        # Track latency
        latency_ms = (time.perf_counter() - start_time) * 1000
        self.latency_history[routing_decision.path].append(latency_ms)

        # Update DRE metrics
        self.activation_counts[routing_decision.path] += 1
        self.total_activations += 1
        self.complexity_scores.append(routing_decision.complexity_score)
        self.confidence_scores.append(routing_decision.confidence)

        # Compute a small auxiliary loss to train the router (balance + latency + confidence)
        dre_aux_loss = None
        router_tensors = getattr(self, '_last_router_tensors', None)
        if self.training and router_tensors is not None:
            try:
                dre_aux_loss = self._compute_aux_loss(
                    router_tensors['probs'], router_tensors['confidence']
                )
            except Exception:
                # Best effort: training continues without the aux loss, but
                # the failure is no longer silent.
                logger.warning("DRE: auxiliary router loss could not be computed", exc_info=True)
                dre_aux_loss = None

        # Track reasoning steps for deep paths
        if routing_decision.path in [ReasoningPath.DEEP, ReasoningPath.ULTRA_DEEP]:
            steps = routing_decision.debug_info.get(
                'reasoning_steps', max(1, self._reasoning_step_count)
            )
            self.reasoning_steps.append(steps)

        # Add routing info to output
        output['routing_info'] = {
            'path': routing_decision.path.value,
            'complexity_score': routing_decision.complexity_score,
            'confidence': routing_decision.confidence,
            'latency_ms': latency_ms,
            'debug': routing_decision.debug_info,
            'dre_metrics': self.get_current_metrics(),
            'use_moe': (routing_decision.path == ReasoningPath.EXPERT)
        }
        # Expose aux loss to the trainer for joint optimization
        if dre_aux_loss is not None:
            output['dre_aux_loss'] = dre_aux_loss

        # Avoid issues with torch.compile/torch._dynamo tracing Python f-strings and time
        try:
            is_compiling = getattr(torch._dynamo, 'is_compiling', lambda: False)()
        except Exception:
            is_compiling = False
        if not is_compiling:
            # Use logger parameter interpolation to avoid formatting issues
            logger.info("DRE: Path=%s, Complexity=%.3f, Latency=%.1fms",
                        routing_decision.path.value,
                        float(routing_decision.complexity_score),
                        float(latency_ms))
        return output