"""Compact transformer with interleaved-thinking reasoning paths.

Contains the base transformer (`CompactTransformer`), the reasoning-path
modules, the `InterleavedThinking` controller that fuses parallel reasoning
paths, and the top-level `CompactAIModel` wrapper.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Any
import math
from ..configs.config import ModelConfig, InterleavedThinkingConfig


class EfficientAttention(nn.Module):
    """Memory-efficient multi-head attention with RoPE and optional flash attention."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.heads
        self.head_dim = config.dim // config.heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.k_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.v_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.o_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

        # RoPE cache: (max_seq_len, head_dim // 2, 2) holding (cos, sin) pairs.
        self.rope_cache = self._build_rope_cache(config.max_seq_len)

    def _build_rope_cache(self, max_seq_len: int) -> torch.Tensor:
        """Precompute (cos, sin) rotation pairs per position and frequency.

        Fix: the original concatenated sin and cos along the feature axis and
        later reshaped into pairs, which paired sin-with-sin instead of
        (cos, sin) per frequency. Stacking on a trailing axis makes pair i
        equal to (cos(t * f_i), sin(t * f_i)).
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i , j -> i j', t, inv_freq)  # (T, head_dim/2)
        return torch.stack((freqs.cos(), freqs.sin()), dim=-1)  # (T, head_dim/2, 2)

    def _apply_rope(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Apply rotary position embedding.

        Expects x of shape (B, T, n_heads, head_dim) so the cache is sliced
        by sequence position (x.size(1) == T).
        """
        rope = self.rope_cache[start_pos:start_pos + x.size(1)].to(x.device)
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (B, T, H, D/2, 2)
        rope = rope.reshape(1, xshaped.size(1), 1, xshaped.size(3), 2)
        cos, sin = rope[..., 0], rope[..., 1]
        # Standard complex rotation: (x0 + i*x1) * (cos + i*sin).
        x_out = torch.stack([
            xshaped[..., 0] * cos - xshaped[..., 1] * sin,
            xshaped[..., 1] * cos + xshaped[..., 0] * sin,
        ], dim=-1)
        return x_out.flatten(3).type_as(x)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Self-attention over x (B, T, dim); mask is a broadcastable bool
        tensor where True marks positions allowed to attend."""
        B, T, C = x.size()

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        # Fix: RoPE is applied in (B, T, H, D) layout; the original transposed
        # to (B, H, T, D) first, which sliced the cache by head index instead
        # of by sequence position.
        q = self._apply_rope(q)
        k = self._apply_rope(k)

        # (B, H, T, D) layout for attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if self.config.flash_attention and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            # PyTorch's fused attention; bool attn_mask: True = attend.
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout.p if self.training else 0.0
            )
        else:
            attn_weights = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = attn_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(attn_output)


class FeedForward(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        hidden_dim = 4 * config.dim  # standard expansion factor
        self.gate_proj = nn.Linear(config.dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # SwiGLU: silu(gate) elementwise-scales the up projection.
        hidden = F.silu(gate) * up
        return self.down_proj(self.dropout(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + feed-forward with residuals."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.attention = EfficientAttention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = nn.RMSNorm(config.dim)
        self.ffn_norm = nn.RMSNorm(config.dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x), mask)
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class CompactTransformer(nn.Module):
    """Compact decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token and (learned absolute) position embeddings.
        # NOTE(review): attention also applies RoPE, so absolute position
        # embeddings are redundant; kept for state-dict compatibility.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        self.embed_positions = nn.Embedding(config.max_seq_len, config.dim)

        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.layers)
        ])

        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Tie embedding and LM-head weights.
        self.embed_tokens.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize linear/embedding weights with N(0, 0.02)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the transformer.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) padding mask (1/True = keep).
            position_ids: optional (B, T) positions; defaults to arange(T).

        Returns:
            dict with "logits" (B, T, vocab) and "hidden_states" (B, T, dim).
        """
        B, T = input_ids.size()

        if position_ids is None:
            position_ids = torch.arange(T, dtype=torch.long, device=input_ids.device).unsqueeze(0)

        x = self.embed_tokens(input_ids) + self.embed_positions(position_ids)

        # Build a (B or 1, 1, T, T) bool mask where True = attend.
        causal_mask = torch.triu(torch.ones(T, T, device=input_ids.device), diagonal=1).bool()
        if attention_mask is not None:
            # Fix: `&` requires bool operands; the incoming mask may be int/float.
            attention_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            attention_mask = attention_mask & ~causal_mask.unsqueeze(0).unsqueeze(0)
        else:
            attention_mask = ~causal_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.lm_head(x)
        return {"logits": logits, "hidden_states": x}

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())


class ReasoningPath(nn.Module):
    """Reasoning path for interleaved thinking with uncertainty estimation."""

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, path_id: int = 0):
        super().__init__()
        self.config = config
        self.thinking_config = thinking_config
        self.path_id = path_id

        # Reasoning-specific layers (smaller than the main model).
        self.reasoning_layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(min(2, config.layers // 2))
        ])

        # Two outputs: mean and variance parameters for the confidence.
        self.confidence_head = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 2)
        )
        self.output_projection = nn.Linear(config.dim, config.vocab_size)

        if thinking_config.path_specialization:
            self.specialization_adapter = nn.Linear(config.dim, config.dim)
        else:
            self.specialization_adapter = None

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run this path on (B, T, dim) hidden states.

        Returns reasoning logits, a (B, 1) confidence, its variance (or None),
        and the final reasoning states.
        """
        x = hidden_states

        if self.specialization_adapter is not None:
            # Deterministic per-path bias keeps paths from collapsing to
            # identical functions.
            path_bias = torch.sin(torch.tensor(float(self.path_id) * 0.1)) * 0.1
            x = x + path_bias
            x = self.specialization_adapter(x)

        for layer in self.reasoning_layers:
            x = layer(x, mask)

        confidence_params = self.confidence_head(x.mean(dim=1))  # (B, 2)
        if self.thinking_config.uncertainty_estimation:
            confidence_mean = torch.sigmoid(confidence_params[:, 0:1])
            confidence_var = F.softplus(confidence_params[:, 1:2]) + 1e-6  # keep variance positive

            if self.training:
                # Reparameterized sample for robustness during training.
                eps = torch.randn_like(confidence_var)
                confidence = confidence_mean + torch.sqrt(confidence_var) * eps
                confidence = torch.clamp(confidence, 0.0, 1.0)
            else:
                confidence = confidence_mean
        else:
            # Fix: the head emits 2 channels; the original applied sigmoid to
            # both, yielding a (B, 2) confidence. Use only the mean channel so
            # the shape matches the uncertainty branch.
            confidence_var = None
            confidence = torch.sigmoid(confidence_params[:, 0:1])

        reasoning_logits = self.output_projection(x)

        return {
            "reasoning_logits": reasoning_logits,
            "confidence": confidence,
            "confidence_var": confidence_var,
            "reasoning_states": x
        }


class EarlyStopController(nn.Module):
    """Controller for early stopping with optional task-specific thresholds."""

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.thinking_config = thinking_config

        # 3-way task complexity classifier: simple / medium / complex.
        self.complexity_classifier = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 3),
        )

        self.stop_predictor = nn.Linear(config.dim, 1)

        if thinking_config.task_specific_thresholds:
            self.task_threshold_predictor = nn.Sequential(
                nn.Linear(config.dim, config.dim // 4),
                nn.ReLU(),
                nn.Linear(config.dim // 4, 1),
                nn.Sigmoid()
            )
        else:
            self.task_threshold_predictor = None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (complexity_probs (B, 3), stop_prob (B, 1), task_threshold (B, 1) or None)."""
        pooled = hidden_states.mean(dim=1)

        complexity_probs = F.softmax(self.complexity_classifier(pooled), dim=-1)
        stop_prob = torch.sigmoid(self.stop_predictor(pooled))

        task_threshold = None
        if self.task_threshold_predictor is not None:
            task_threshold = self.task_threshold_predictor(pooled)

        return complexity_probs, stop_prob, task_threshold


class HierarchicalReasoningPath(nn.Module):
    """Hierarchical reasoning path with different abstraction levels."""

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, level: int):
        super().__init__()
        self.level = level
        self.config = config
        self.thinking_config = thinking_config

        # Deeper stacks and wider abstraction at lower levels; higher levels
        # compress more aggressively.
        if level == 0:  # low-level, detailed reasoning
            self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(2)])
            self.abstraction_projection = nn.Linear(config.dim, config.dim // 2)
        elif level == 1:  # mid-level, pattern recognition
            self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(1)])
            self.abstraction_projection = nn.Linear(config.dim, config.dim // 4)
        else:  # high-level, conceptual reasoning
            self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(1)])
            self.abstraction_projection = nn.Linear(config.dim, config.dim // 8)

        self.confidence_head = nn.Linear(config.dim, 1)
        self.output_projection = nn.Linear(config.dim, config.vocab_size)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        for layer in self.layers:
            x = layer(x, mask)

        abstracted = self.abstraction_projection(x)
        confidence = torch.sigmoid(self.confidence_head(x.mean(dim=1)))  # (B, 1)
        reasoning_logits = self.output_projection(x)

        return {
            "reasoning_logits": reasoning_logits,
            "confidence": confidence,
            "abstracted_states": abstracted,
            "level": self.level
        }


class InterleavedThinking(nn.Module):
    """Interleaved thinking: parallel reasoning paths plus fusion and early stop."""

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config

        if thinking_config.hierarchical_paths:
            self.reasoning_paths = nn.ModuleList([
                HierarchicalReasoningPath(model_config, thinking_config,
                                          level % thinking_config.num_hierarchy_levels)
                for level in range(thinking_config.max_reasoning_paths)
            ])
        else:
            self.reasoning_paths = nn.ModuleList([
                ReasoningPath(model_config, thinking_config, path_id=i)
                for i in range(thinking_config.max_reasoning_paths)
            ])

        self.early_stop_controller = EarlyStopController(model_config, thinking_config)

        if thinking_config.attention_fusion:
            # Fix: inputs are laid out (batch, seq, embed); the default
            # MultiheadAttention layout is (seq, batch, embed), so batch_first
            # must be set for the (B*T, P, V) call in forward() to be valid.
            self.fusion_attention = nn.MultiheadAttention(
                embed_dim=model_config.vocab_size,
                num_heads=8,
                dropout=0.1,
                batch_first=True
            )
            self.fusion_norm = nn.LayerNorm(model_config.vocab_size)
        else:
            self.path_combiner = nn.Linear(
                model_config.vocab_size * thinking_config.max_reasoning_paths,
                model_config.vocab_size
            )

        # Adaptive memory compression with reconstruction.
        if thinking_config.adaptive_compression:
            compression_dim = model_config.dim // 4
            self.memory_compressor = nn.Linear(model_config.dim, compression_dim)
            self.memory_reconstructor = nn.Linear(compression_dim, model_config.dim)
            self.compression_gate = nn.Sequential(
                nn.Linear(model_config.dim, 1),
                nn.Sigmoid()
            )
        elif thinking_config.memory_compression:
            self.memory_compressor = nn.Linear(model_config.dim, model_config.dim // 4)
            self.memory_reconstructor = None
            self.compression_gate = None
        else:
            self.memory_compressor = None
            self.memory_reconstructor = None
            self.compression_gate = None

    def forward(
        self,
        base_hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        current_depth: int = 0
    ) -> Dict[str, Any]:
        """One thinking step over (B, T, dim) hidden states.

        Returns a dict with "should_stop", per-path results, fused logits,
        averaged confidence, and optional visualization data.
        """
        batch_size = base_hidden_states.size(0)

        complexity_probs, stop_prob, task_threshold = self.early_stop_controller(base_hidden_states)

        effective_threshold = (task_threshold if task_threshold is not None
                               else self.thinking_config.early_stop_threshold)
        should_stop = stop_prob > effective_threshold

        # Fix: .item() fails for batch > 1; stop only when the whole batch agrees.
        if bool(should_stop.all()) and current_depth > 1:
            return {
                "should_stop": True,
                "reasoning_results": None,
                "final_logits": None,
                "confidence_scores": None,
                "complexity": complexity_probs,
                "visualization_data": ({"early_stop_depth": current_depth}
                                       if self.thinking_config.visualization_enabled else None)
            }

        # Run all reasoning paths; keep full result dicts (the original kept
        # only the logits and later string-indexed them as dicts).
        raw_results = [path(base_hidden_states, mask) for path in self.reasoning_paths]
        path_logits = torch.stack([r["reasoning_logits"] for r in raw_results], dim=1)  # (B, P, T, V)
        confidence_scores = torch.stack([r["confidence"] for r in raw_results], dim=1)  # (B, P, 1)
        confidence_vars = [r["confidence_var"] for r in raw_results
                           if r.get("confidence_var") is not None]

        B, P, T, V = path_logits.size()

        if self.thinking_config.attention_fusion:
            # Fix: view(B*T, P, V) on a (B, P, T, V) tensor scrambles axes;
            # permute so each (batch, position) pair becomes one fusion batch.
            flat_logits = path_logits.permute(0, 2, 1, 3).reshape(B * T, P, V)
            attn_output, _ = self.fusion_attention(
                flat_logits.mean(dim=1, keepdim=True),  # query: mean across paths
                flat_logits,                            # keys: per-path logits
                flat_logits                             # values
            )
            combined_logits = self.fusion_norm(attn_output.squeeze(1)).view(B, T, V)
        else:
            # Confidence-weighted combination.
            confidence_weights = F.softmax(confidence_scores.squeeze(-1), dim=-1)  # (B, P)
            weighted = path_logits * confidence_weights[:, :, None, None]  # (B, P, T, V)
            # Fix: the original summed over paths and then reshaped (B, T, V)
            # into P*V features, scrambling the sequence axis. Concatenate the
            # weighted paths along the feature axis as path_combiner expects.
            concat = weighted.permute(0, 2, 1, 3).reshape(B, T, P * V)
            combined_logits = self.path_combiner(concat)

        # Adaptive memory compression with reconstruction.
        compressed_states = None
        reconstruction_loss = None
        if self.memory_compressor is not None:
            states_per_path = [r.get("reasoning_states") for r in raw_results]
            if all(s is not None for s in states_per_path):
                stacked_states = torch.stack(states_per_path, dim=1)  # (B, P, T, dim)
                best_path_idx = confidence_scores.mean(dim=-1).argmax(dim=-1)  # (B,)
                best_reasoning_states = stacked_states[
                    torch.arange(batch_size, device=stacked_states.device), best_path_idx
                ]
            else:
                # Hierarchical paths expose only level-dependent abstracted
                # states (different widths), so fall back to the input states.
                best_reasoning_states = base_hidden_states

            if self.thinking_config.adaptive_compression and self.memory_reconstructor is not None:
                compression_gate = self.compression_gate(best_reasoning_states.mean(dim=1))  # (B, 1)
                compressed = self.memory_compressor(best_reasoning_states)
                reconstructed = self.memory_reconstructor(compressed)
                reconstruction_loss = F.mse_loss(reconstructed, best_reasoning_states)
                # Fix: torch.where needs same-width tensors; use the full-width
                # reconstruction (not the dim/4 code) when the gate selects
                # compression.
                compressed_states = torch.where(
                    compression_gate.unsqueeze(-1) > 0.5,
                    reconstructed,
                    best_reasoning_states
                )
            else:
                compressed_states = self.memory_compressor(best_reasoning_states)

        visualization_data = None
        if self.thinking_config.visualization_enabled:
            # detach() so numpy conversion works on grad-tracking tensors.
            visualization_data = {
                "confidence_scores": confidence_scores.detach().cpu().numpy(),
                "confidence_vars": ([v.detach().cpu().numpy() for v in confidence_vars]
                                    if confidence_vars else None),
                "complexity_probs": complexity_probs.detach().cpu().numpy(),
                "task_threshold": (task_threshold.detach().cpu().numpy()
                                   if task_threshold is not None else None),
                "path_logits_shape": path_logits.shape,
                "hierarchical_levels": ([getattr(path, 'level', 0) for path in self.reasoning_paths]
                                        if self.thinking_config.hierarchical_paths else None),
                "reconstruction_loss": (reconstruction_loss.item()
                                        if reconstruction_loss is not None else None)
            }

        return {
            "should_stop": False,
            "reasoning_results": {
                "path_logits": path_logits,
                "confidence_scores": confidence_scores,
                "complexity": complexity_probs,
                "compressed_states": compressed_states,
                "confidence_vars": confidence_vars if confidence_vars else None,
                "reconstruction_loss": reconstruction_loss
            },
            "final_logits": combined_logits,
            "confidence_scores": confidence_scores.mean(dim=1),
            "visualization_data": visualization_data
        }


class CompactAIModel(nn.Module):
    """Complete compact AI model with interleaved thinking."""

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config

        self.base_model = CompactTransformer(model_config)
        self.thinking = InterleavedThinking(model_config, thinking_config)

        # Predicts a reasoning depth class from pooled hidden states.
        self.depth_controller = nn.Linear(model_config.dim, thinking_config.reasoning_depth)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_thinking: bool = True,
        max_reasoning_depth: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run the base model and, optionally, iterative interleaved thinking.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) padding mask (1/True = keep).
            use_thinking: skip the thinking loop entirely when False.
            max_reasoning_depth: explicit depth; otherwise dynamic or config default.
        """
        base_outputs = self.base_model(input_ids, attention_mask)
        base_logits = base_outputs["logits"]
        base_hidden = base_outputs["hidden_states"]

        if not use_thinking:
            return {
                "logits": base_logits,
                "thinking_results": None,
                "final_tokens": 0,
                "visualization_data": None
            }

        if max_reasoning_depth is None:
            if self.thinking_config.dynamic_depth:
                depth_logits = self.depth_controller(base_hidden.mean(dim=1))
                depth_probs = F.softmax(depth_logits, dim=-1)
                # Fix: .item() on a per-batch argmax fails for batch > 1; use
                # the deepest requested depth across the batch.
                max_reasoning_depth = int(depth_probs.argmax(dim=-1).max().item()) + 1
            else:
                max_reasoning_depth = self.thinking_config.reasoning_depth

        # Fix: reasoning attention needs a broadcastable 4-D bool mask; the
        # raw (B, T) padding mask cannot broadcast against (B, H, T, T).
        thinking_mask = None
        if attention_mask is not None:
            thinking_mask = attention_mask.bool()[:, None, None, :]  # (B, 1, 1, T)

        current_hidden = base_hidden
        thinking_results = []
        last_final_logits = None
        total_reasoning_tokens = 0
        visualization_history = [] if self.thinking_config.visualization_enabled else None

        for depth in range(max_reasoning_depth):
            thinking_output = self.thinking(current_hidden, thinking_mask, depth)

            if thinking_output["should_stop"]:
                break

            thinking_results.append(thinking_output["reasoning_results"])
            # Fix: remember the last fused logits here; the original read them
            # from thinking_output after the loop, which is None on early stop.
            last_final_logits = thinking_output["final_logits"]

            if self.thinking_config.visualization_enabled and thinking_output["visualization_data"]:
                thinking_output["visualization_data"]["depth"] = depth
                visualization_history.append(thinking_output["visualization_data"])

            compressed = thinking_output["reasoning_results"]["compressed_states"]
            if compressed is not None and compressed.size(-1) == current_hidden.size(-1):
                # Only feed compressed states back when they keep model width
                # (simple compression yields dim/4 and cannot be reused here).
                current_hidden = compressed
            else:
                # Fix: logits live in vocab space; map them back to model dim
                # through the tied embedding before nudging the hidden state
                # (the original added (B, T, vocab) to (B, T, dim)).
                soft_tokens = thinking_output["final_logits"].detach().softmax(dim=-1)
                current_hidden = current_hidden + (soft_tokens @ self.base_model.embed_tokens.weight) * 0.1

            total_reasoning_tokens += input_ids.size(1)
            if total_reasoning_tokens >= self.thinking_config.token_budget:
                break

        final_logits = last_final_logits if last_final_logits is not None else base_logits

        return {
            "logits": final_logits,
            "thinking_results": thinking_results,
            "final_tokens": total_reasoning_tokens,
            "visualization_data": visualization_history
        }


def create_compact_model(model_size: str = "small") -> CompactAIModel:
    """Create a compact AI model with the specified size.

    Args:
        model_size: one of "tiny", "small", "medium".

    Raises:
        ValueError: for an unknown model size.
    """
    if model_size == "tiny":
        model_config = ModelConfig(
            model_size="tiny",
            dim=256,
            layers=8,
            heads=8,
        )
    elif model_size == "small":
        model_config = ModelConfig(
            model_size="small",
            dim=512,
            layers=12,
            heads=8,
        )
    elif model_size == "medium":
        model_config = ModelConfig(
            model_size="medium",
            dim=768,
            layers=16,
            heads=12,
        )
    else:
        raise ValueError(f"Unknown model size: {model_size}")

    thinking_config = InterleavedThinkingConfig()
    return CompactAIModel(model_config, thinking_config)