# Uploaded via huggingface_hub by likhonsheikh (commit b9b1e87, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Any
import math
from ..configs.config import ModelConfig, InterleavedThinkingConfig
class EfficientAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Uses PyTorch's fused ``scaled_dot_product_attention`` when
    ``config.flash_attention`` is enabled and the kernel exists; otherwise
    falls back to an explicit softmax attention.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.heads
        self.head_dim = config.dim // config.heads
        # RoPE rotates (even, odd) element pairs, so each head needs an even width.
        assert config.dim % config.heads == 0 and self.head_dim % 2 == 0, (
            "dim must be divisible by heads and head_dim must be even for RoPE"
        )
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.k_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.v_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.o_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Non-persistent buffer: follows .to(device)/.cuda() with the module
        # (the original plain attribute stayed on CPU) and is excluded from
        # checkpoints.
        self.register_buffer(
            "rope_cache", self._build_rope_cache(config.max_seq_len), persistent=False
        )

    def _build_rope_cache(self, max_seq_len: int) -> torch.Tensor:
        """Return a (max_seq_len, head_dim // 2, 2) cache of (cos, sin) pairs.

        Stacking cos/sin per frequency keeps each rotation pair adjacent; the
        original ``cat((sin, cos))`` layout paired unrelated entries after the
        reshape in ``_apply_rope``.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i , j -> i j', t, inv_freq)  # (T, head_dim // 2)
        return torch.stack((freqs.cos(), freqs.sin()), dim=-1)

    def _apply_rope(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Rotate ``x`` of shape (B, T, n_heads, head_dim) by token position.

        The cache is sliced along dim 1, so the sequence axis must be dim 1 —
        the original called this on (B, H, T, D) tensors, indexing positions
        by head index instead of token position.
        """
        seq_len = x.size(1)
        rope = self.rope_cache[start_pos:start_pos + seq_len].to(x.device)
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)   # (B, T, H, D/2, 2)
        rope = rope.view(1, seq_len, 1, xshaped.size(3), 2)  # broadcast over B, H
        cos, sin = rope[..., 0], rope[..., 1]
        x_out = torch.stack([
            xshaped[..., 0] * cos - xshaped[..., 1] * sin,
            xshaped[..., 1] * cos + xshaped[..., 0] * sin,
        ], -1)
        return x_out.flatten(3).type_as(x)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` (B, T, dim).

        ``mask`` is an optional boolean mask (True/nonzero = attend),
        broadcastable to (B, heads, T, T).
        """
        B, T, C = x.size()
        # Project to per-head queries, keys, values.
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)
        # Apply RoPE while the sequence axis is still dim 1.
        q = self._apply_rope(q)
        k = self._apply_rope(k)
        # (B, H, T, D) layout for attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if self.config.flash_attention and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            # Fused kernel handles scaling, masking, and dropout internally.
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0.0
            )
        else:
            # Explicit attention fallback.
            attn_weights = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = attn_weights @ v
        # Merge heads and project back to the model dimension.
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(attn_output)
class FeedForward(nn.Module):
    """Position-wise feed-forward block using the SwiGLU gating scheme."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        inner_dim = config.dim * 4  # conventional 4x expansion
        self.gate_proj = nn.Linear(config.dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(config.dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: the silu-activated gate modulates the linear "up" branch.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Applies RMSNorm -> attention -> residual add, then RMSNorm ->
    feed-forward -> residual add.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.attention = EfficientAttention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = nn.RMSNorm(config.dim)
        self.ffn_norm = nn.RMSNorm(config.dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Each sub-layer is pre-normalized and wrapped in a residual connection.
        x = x + self.attention(self.attention_norm(x), mask)
        return x + self.feed_forward(self.ffn_norm(x))
class CompactTransformer(nn.Module):
    """Decoder-only transformer with tied input/output embeddings.

    Attention is always causal; an optional padding mask further restricts
    which key positions may be attended. Returns next-token logits plus the
    final hidden states.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        # Token and learned absolute position embeddings.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        self.embed_positions = nn.Embedding(config.max_seq_len, config.dim)
        # Transformer stack.
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.layers)
        ])
        # Output head.
        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Weight tying: the embedding and LM head share one matrix.
        self.embed_tokens.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize linear/embedding weights with N(0, 0.02); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the transformer.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) padding mask (nonzero = real token).
            position_ids: optional (B, T) positions; defaults to 0..T-1.

        Returns:
            Dict with "logits" (B, T, vocab) and "hidden_states" (B, T, dim).
        """
        B, T = input_ids.size()
        if position_ids is None:
            position_ids = torch.arange(T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        # Sum token and position embeddings.
        x = self.embed_tokens(input_ids) + self.embed_positions(position_ids)
        # Boolean causal mask, True where attention is allowed: (1, 1, T, T).
        causal = ~torch.triu(
            torch.ones(T, T, device=input_ids.device, dtype=torch.bool), diagonal=1
        )
        causal = causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            # Tokenizer masks are commonly int/float; coerce to bool so the
            # bitwise-and below cannot fail on mixed dtypes.
            pad = attention_mask.bool().unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            attention_mask = pad & causal
        else:
            attention_mask = causal
        # Apply transformer layers.
        for layer in self.layers:
            x = layer(x, attention_mask)
        # Final normalization and language-modeling head.
        x = self.norm(x)
        return {"logits": self.lm_head(x), "hidden_states": x}

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())
class ReasoningPath(nn.Module):
    """One reasoning path for interleaved thinking with optional
    uncertainty-aware confidence scoring.

    Runs a shallow transformer stack over the base model's hidden states and
    emits per-token vocabulary logits plus a (B, 1) confidence per batch
    element.
    """

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, path_id: int = 0):
        super().__init__()
        self.config = config
        self.thinking_config = thinking_config
        self.path_id = path_id
        # Shallow reasoning stack: at most 2 blocks, never more than half the
        # base model's depth.
        self.reasoning_layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(min(2, config.layers // 2))
        ])
        # Two outputs: a confidence mean logit and a variance logit.
        self.confidence_head = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 2)  # mean and variance for uncertainty
        )
        self.output_projection = nn.Linear(config.dim, config.vocab_size)
        # Optional per-path adapter to keep parallel paths from collapsing
        # into the same function.
        if thinking_config.path_specialization:
            self.specialization_adapter = nn.Linear(config.dim, config.dim)
        else:
            self.specialization_adapter = None

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return reasoning logits, confidence (B, 1), optional variance, and states."""
        x = hidden_states
        if self.specialization_adapter is not None:
            # Deterministic per-path offset followed by the adapter (the
            # original checked the adapter twice for the same condition).
            path_bias = torch.sin(torch.tensor(float(self.path_id) * 0.1)) * 0.1
            x = self.specialization_adapter(x + path_bias)
        # Apply the reasoning layers.
        for layer in self.reasoning_layers:
            x = layer(x, mask)
        confidence_params = self.confidence_head(x.mean(dim=1))  # (B, 2)
        confidence_var = None
        if self.thinking_config.uncertainty_estimation:
            confidence_mean = torch.sigmoid(confidence_params[:, 0:1])
            confidence_var = F.softplus(confidence_params[:, 1:2]) + 1e-6  # keep variance positive
            if self.training:
                # Reparameterized sample for robustness during training.
                eps = torch.randn_like(confidence_var)
                confidence = torch.clamp(
                    confidence_mean + torch.sqrt(confidence_var) * eps, 0.0, 1.0
                )
            else:
                confidence = confidence_mean
        else:
            # Use only the mean logit so confidence is (B, 1) in this branch
            # too: the original sigmoided both head outputs, returning (B, 2)
            # and breaking downstream confidence stacking/softmax.
            confidence = torch.sigmoid(confidence_params[:, 0:1])
        # Project to vocabulary.
        reasoning_logits = self.output_projection(x)
        return {
            "reasoning_logits": reasoning_logits,
            "confidence": confidence,
            "confidence_var": confidence_var,
            "reasoning_states": x
        }
class EarlyStopController(nn.Module):
    """Controller for early stopping with optional task-specific thresholds.

    Pools the sequence once and predicts (a) a 3-way task-complexity
    distribution, (b) a stop probability, and (c) optionally a per-input
    stopping threshold.
    """

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.thinking_config = thinking_config
        # Task complexity classifier: simple / medium / complex.
        self.complexity_classifier = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 3),
        )
        # Early stop predictor.
        self.stop_predictor = nn.Linear(config.dim, 1)
        # Task-specific threshold predictor (if enabled).
        if thinking_config.task_specific_thresholds:
            self.task_threshold_predictor = nn.Sequential(
                nn.Linear(config.dim, config.dim // 4),
                nn.ReLU(),
                nn.Linear(config.dim // 4, 1),
                nn.Sigmoid()
            )
        else:
            self.task_threshold_predictor = None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (complexity_probs (B, 3), stop_prob (B, 1), threshold (B, 1) or None)."""
        # Mean-pool the sequence once instead of three times.
        pooled = hidden_states.mean(dim=1)
        complexity_probs = F.softmax(self.complexity_classifier(pooled), dim=-1)
        stop_prob = torch.sigmoid(self.stop_predictor(pooled))
        task_threshold = None
        if self.task_threshold_predictor is not None:
            task_threshold = self.task_threshold_predictor(pooled)
        return complexity_probs, stop_prob, task_threshold
class HierarchicalReasoningPath(nn.Module):
    """Hierarchical reasoning path with different abstraction levels."""

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, level: int):
        super().__init__()
        self.level = level
        self.config = config
        self.thinking_config = thinking_config
        # Per-level architecture: (number of blocks, abstraction divisor).
        # Level 0 = detailed reasoning, 1 = pattern recognition, >=2 = conceptual.
        if level == 0:
            depth, divisor = 2, 2
        elif level == 1:
            depth, divisor = 1, 4
        else:
            depth, divisor = 1, 8
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(depth)])
        self.abstraction_projection = nn.Linear(config.dim, config.dim // divisor)
        self.confidence_head = nn.Linear(config.dim, 1)
        self.output_projection = nn.Linear(config.dim, config.vocab_size)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return vocab logits, confidence, level-abstracted states, and level id."""
        # Level-specific reasoning stack.
        for block in self.layers:
            x = block(x, mask)
        return {
            "reasoning_logits": self.output_projection(x),
            "confidence": torch.sigmoid(self.confidence_head(x.mean(dim=1))),
            "abstracted_states": self.abstraction_projection(x),
            "level": self.level,
        }
class InterleavedThinking(nn.Module):
    """Interleaved thinking mechanism: parallel reasoning paths, early
    stopping, attention-based path fusion, and optional memory compression.

    Runs every reasoning path over the base hidden states, combines their
    vocabulary logits (attention fusion or confidence weighting), and
    optionally compresses the best path's states for the next iteration.
    """

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config
        # Reasoning paths: hierarchical (cycled levels) or flat.
        if thinking_config.hierarchical_paths:
            self.reasoning_paths = nn.ModuleList([
                HierarchicalReasoningPath(model_config, thinking_config, level % thinking_config.num_hierarchy_levels)
                for level in range(thinking_config.max_reasoning_paths)
            ])
        else:
            self.reasoning_paths = nn.ModuleList([
                ReasoningPath(model_config, thinking_config, path_id=i)
                for i in range(thinking_config.max_reasoning_paths)
            ])
        # Early stop controller.
        self.early_stop_controller = EarlyStopController(model_config, thinking_config)
        # Path combination: attention fusion over paths, or a linear combiner
        # over the concatenated per-path logits.
        if thinking_config.attention_fusion:
            self.fusion_attention = nn.MultiheadAttention(
                embed_dim=model_config.vocab_size,
                num_heads=8,
                dropout=0.1
            )
            self.fusion_norm = nn.LayerNorm(model_config.vocab_size)
        else:
            self.path_combiner = nn.Linear(
                model_config.vocab_size * thinking_config.max_reasoning_paths,
                model_config.vocab_size
            )
        # Memory compression: adaptive (gated, with reconstruction) or simple.
        if thinking_config.adaptive_compression:
            compression_dim = model_config.dim // 4
            self.memory_compressor = nn.Linear(model_config.dim, compression_dim)
            self.memory_reconstructor = nn.Linear(compression_dim, model_config.dim)
            self.compression_gate = nn.Sequential(
                nn.Linear(model_config.dim, 1),
                nn.Sigmoid()
            )
        elif thinking_config.memory_compression:
            self.memory_compressor = nn.Linear(model_config.dim, model_config.dim // 4)
            self.memory_reconstructor = None
            self.compression_gate = None
        else:
            self.memory_compressor = None
            self.memory_reconstructor = None
            self.compression_gate = None

    def forward(
        self,
        base_hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        current_depth: int = 0
    ) -> Dict[str, Any]:
        """Run one thinking step over ``base_hidden_states`` (B, T, dim).

        Returns a dict with "should_stop", per-path "reasoning_results",
        fused "final_logits" (B, T, vocab), mean "confidence_scores", and
        optional "visualization_data".
        """
        batch_size = base_hidden_states.size(0)
        # Early-stopping check with optional task-specific thresholds.
        complexity_probs, stop_prob, task_threshold = self.early_stop_controller(base_hidden_states)
        effective_threshold = task_threshold if task_threshold is not None else self.thinking_config.early_stop_threshold
        # Stop only when every batch element exceeds its threshold; the
        # original called .item() on a (B, 1) tensor, which crashes for B > 1.
        should_stop = (stop_prob > effective_threshold).all()
        if should_stop.item() and current_depth > 1:
            return {
                "should_stop": True,
                "reasoning_results": None,
                "final_logits": None,
                "confidence_scores": None,
                "complexity": complexity_probs,
                "visualization_data": {"early_stop_depth": current_depth} if self.thinking_config.visualization_enabled else None
            }
        # Run all reasoning paths, keeping logits, confidences and (when the
        # path provides them) full-width hidden states.
        path_logit_list = []
        confidence_list = []
        confidence_vars = []
        path_state_list = []
        for path in self.reasoning_paths:
            result = path(base_hidden_states, mask)
            path_logit_list.append(result["reasoning_logits"])
            confidence_list.append(result["confidence"])
            if result.get("confidence_var") is not None:
                confidence_vars.append(result["confidence_var"])
            # Hierarchical paths expose reduced-dim "abstracted_states" only;
            # those cannot feed the compressor, so default to None for them.
            # (The original indexed the logits tensor with a string key here.)
            path_state_list.append(result.get("reasoning_states"))
        path_logits = torch.stack(path_logit_list, dim=1)        # (B, P, T, V)
        confidence_scores = torch.stack(confidence_list, dim=1)  # (B, P, 1)
        B, P, T, V = path_logits.size()
        if self.thinking_config.attention_fusion:
            # nn.MultiheadAttention defaults to (seq, batch, embed) layout:
            # treat the P paths as the sequence and B*T positions as the
            # batch. (The original viewed (B, P, T, V) as (B*T, P, V), which
            # scrambles the path/time axes and mismatches query/key batches.)
            keys = path_logits.permute(1, 0, 2, 3).reshape(P, B * T, V)
            query = keys.mean(dim=0, keepdim=True)  # (1, B*T, V): mean across paths
            attn_output, _ = self.fusion_attention(query, keys, keys)
            combined_logits = self.fusion_norm(attn_output.squeeze(0)).view(B, T, V)
        else:
            # Confidence-weighted per-path logits.
            confidence_weights = F.softmax(confidence_scores.squeeze(-1), dim=-1)
            weighted = path_logits * confidence_weights.unsqueeze(-1).unsqueeze(-1)  # (B, P, T, V)
            # Concatenate the weighted paths along the vocab axis so the
            # combiner's (P*V -> V) projection sees the shape it was built
            # for (the original reshaped an already path-summed (B, T, V)
            # tensor, which only "works" when T is divisible by P).
            concat = weighted.permute(0, 2, 1, 3).reshape(batch_size, T, P * V)
            combined_logits = self.path_combiner(concat)
        # Memory compression over the most confident path's states.
        compressed_states = None
        reconstruction_loss = None
        if self.memory_compressor is not None:
            best_path_idx = confidence_scores.mean(dim=-1).argmax(dim=-1)  # (B,)
            selected = [path_state_list[i] for i in best_path_idx.tolist()]
            # Skip compression when a chosen path did not expose full states.
            if all(s is not None for s in selected):
                best_reasoning_states = torch.stack(
                    [s[b] for b, s in enumerate(selected)], dim=0
                )
                if self.thinking_config.adaptive_compression and self.memory_reconstructor is not None:
                    # Gated compression with a reconstruction objective.
                    compression_gate = self.compression_gate(best_reasoning_states.mean(dim=1))
                    compressed = self.memory_compressor(best_reasoning_states)
                    reconstructed = self.memory_reconstructor(compressed)
                    reconstruction_loss = F.mse_loss(reconstructed, best_reasoning_states)
                    # Choose the round-tripped (compress -> reconstruct)
                    # representation when the gate fires; both operands are
                    # model-dim wide. (The original mixed the dim//4
                    # compressed tensor with dim-wide states in torch.where,
                    # which cannot broadcast.)
                    compressed_states = torch.where(
                        compression_gate.unsqueeze(-1).unsqueeze(-1) > 0.5,
                        reconstructed,
                        best_reasoning_states
                    )
                else:
                    # Simple (lossy, reduced-dim) compression.
                    compressed_states = self.memory_compressor(best_reasoning_states)
        # Visualization data (detach before .numpy(): these tensors may
        # require grad during training).
        visualization_data = None
        if self.thinking_config.visualization_enabled:
            visualization_data = {
                "confidence_scores": confidence_scores.detach().cpu().numpy(),
                "confidence_vars": [v.detach().cpu().numpy() for v in confidence_vars] if confidence_vars else None,
                "complexity_probs": complexity_probs.detach().cpu().numpy(),
                "task_threshold": task_threshold.detach().cpu().numpy() if task_threshold is not None else None,
                "path_logits_shape": path_logits.shape,
                "hierarchical_levels": [getattr(path, 'level', 0) for path in self.reasoning_paths] if self.thinking_config.hierarchical_paths else None,
                "reconstruction_loss": reconstruction_loss.item() if reconstruction_loss is not None else None
            }
        return {
            "should_stop": False,
            "reasoning_results": {
                "path_logits": path_logits,
                "confidence_scores": confidence_scores,
                "complexity": complexity_probs,
                "compressed_states": compressed_states,
                "confidence_vars": confidence_vars if confidence_vars else None,
                "reconstruction_loss": reconstruction_loss
            },
            "final_logits": combined_logits,
            "confidence_scores": confidence_scores.mean(dim=1),
            "visualization_data": visualization_data
        }
class CompactAIModel(nn.Module):
    """Complete compact AI model: base transformer plus iterative
    interleaved-thinking refinement with a dynamic depth controller."""

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config
        # Base transformer model.
        self.base_model = CompactTransformer(model_config)
        # Interleaved thinking mechanism.
        self.thinking = InterleavedThinking(model_config, thinking_config)
        # Dynamic depth controller: classifies how many thinking steps to run.
        self.depth_controller = nn.Linear(model_config.dim, thinking_config.reasoning_depth)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_thinking: bool = True,
        max_reasoning_depth: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run the base model, then (optionally) iterative thinking.

        Returns a dict with "logits" (the last thinking step's fused logits,
        or the base logits when thinking is disabled/stopped immediately),
        "thinking_results", "final_tokens", and "visualization_data".
        """
        # Base model forward pass.
        base_outputs = self.base_model(input_ids, attention_mask)
        base_logits = base_outputs["logits"]
        base_hidden = base_outputs["hidden_states"]
        if not use_thinking:
            return {
                "logits": base_logits,
                "thinking_results": None,
                "final_tokens": 0,
                "visualization_data": None
            }
        # Determine reasoning depth.
        if max_reasoning_depth is None:
            if self.thinking_config.dynamic_depth:
                depth_logits = self.depth_controller(base_hidden.mean(dim=1))
                depth_probs = F.softmax(depth_logits, dim=-1)
                # Deepest depth requested by any batch element; the original
                # called .item() on a (B,)-shaped argmax, which assumed B == 1.
                max_reasoning_depth = int(depth_probs.argmax(dim=-1).max().item()) + 1
            else:
                max_reasoning_depth = self.thinking_config.reasoning_depth
        # Iterative interleaved thinking.
        current_hidden = base_hidden
        thinking_results = []
        total_reasoning_tokens = 0
        # Fallback logits when thinking stops before producing any result.
        final_logits = base_logits
        visualization_history = [] if self.thinking_config.visualization_enabled else None
        for depth in range(max_reasoning_depth):
            thinking_output = self.thinking(current_hidden, attention_mask, depth)
            if thinking_output["should_stop"]:
                break
            thinking_results.append(thinking_output["reasoning_results"])
            # Capture the fused logits now: a later early-stop iteration
            # returns final_logits=None, so reading thinking_output after the
            # loop (as the original did) could lose the last real result.
            final_logits = thinking_output["final_logits"]
            # Collect visualization data.
            if self.thinking_config.visualization_enabled and thinking_output["visualization_data"]:
                thinking_output["visualization_data"]["depth"] = depth
                visualization_history.append(thinking_output["visualization_data"])
            compressed = thinking_output["reasoning_results"]["compressed_states"]
            # Feed compressed states forward only when they keep the model
            # width: simple (non-adaptive) compression shrinks the feature
            # dim and cannot re-enter the thinking stack. Otherwise keep the
            # current hidden states unchanged (the original added vocab-sized
            # logits to dim-sized hidden states, a guaranteed shape error).
            if compressed is not None and compressed.size(-1) == current_hidden.size(-1):
                current_hidden = compressed
            total_reasoning_tokens += input_ids.size(1)
            # Check token budget.
            if total_reasoning_tokens >= self.thinking_config.token_budget:
                break
        return {
            "logits": final_logits,
            "thinking_results": thinking_results,
            "final_tokens": total_reasoning_tokens,
            "visualization_data": visualization_history
        }
def create_compact_model(model_size: str = "small") -> CompactAIModel:
    """Create a compact AI model with the specified size."""
    # Size presets: model dimension, layer count, and head count.
    presets = {
        "tiny": dict(dim=256, layers=8, heads=8),
        "small": dict(dim=512, layers=12, heads=8),
        "medium": dict(dim=768, layers=16, heads=12),
    }
    if model_size not in presets:
        raise ValueError(f"Unknown model size: {model_size}")
    model_config = ModelConfig(model_size=model_size, **presets[model_size])
    return CompactAIModel(model_config, InterleavedThinkingConfig())