# SlimMoE-250M-base / slim_moe_transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class RotaryPositionEmbedding(nn.Module):
"""RoPE implementation without traditional position embeddings"""
def __init__(self, dim: int, base: int = 10000):
super().__init__()
self.dim = dim
self.base = base
# Only compute frequencies for half the dimensions (complex form)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
def forward(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
seq_len = x.shape[seq_dim]
device = x.device
dtype = x.dtype
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
# Create cosine and sine components
cos = torch.cos(freqs).to(dtype)
sin = torch.sin(freqs).to(dtype)
return cos, sin
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary position embedding to input tensor"""
# x shape: [batch_size, num_heads, seq_len, head_dim]
# cos, sin shape: [seq_len, head_dim//2]
batch_size, num_heads, seq_len, head_dim = x.shape
half_dim = head_dim // 2
# Reshape x to separate real and imaginary parts
x_reshaped = x.view(batch_size, num_heads, seq_len, half_dim, 2)
x_real = x_reshaped[..., 0]
x_imag = x_reshaped[..., 1]
# Expand cos and sin to match dimensions
cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, half_dim]
sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, half_dim]
# Apply rotation
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
# Combine back
x_rotated = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_rotated = x_rotated.view(batch_size, num_heads, seq_len, head_dim)
return x_rotated.type_as(x)
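# Example (illustrative sketch): how RotaryPositionEmbedding and apply_rotary_pos_emb
# compose. The helper name and the shapes below are arbitrary example values, not a
# configuration used by the model itself.
def _rope_usage_example() -> torch.Tensor:
    rope = RotaryPositionEmbedding(dim=64)
    q = torch.randn(2, 8, 16, 64)   # [batch, num_heads, seq_len, head_dim]
    cos, sin = rope(q)              # each has shape [seq_len, head_dim // 2] = [16, 32]
    q_rot = apply_rotary_pos_emb(q, cos, sin)
    assert q_rot.shape == q.shape   # the rotation preserves the tensor shape
    return q_rot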
class VariableGroupedQueryAttention(nn.Module):
"""Variable Grouped Query Attention with layer-specific head grouping and optional RoPE/NoPE"""
def __init__(self, dim: int, num_heads: int = 8, layer_idx: int = 0,
num_layers: int = 12, variable_groups: bool = True,
use_rope: bool = True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.variable_groups = variable_groups
self.layer_idx = layer_idx
self.num_layers = num_layers
self.use_rope = use_rope
# Variable group calculation - different KV heads for each layer
if variable_groups:
# Create progressive pattern: more KV heads in deeper layers
# Early layers: fewer KV heads (more compression)
# Later layers: more KV heads (more detail)
# Normalized layer position (0 to 1)
layer_ratio = layer_idx / max(1, num_layers - 1)
# Calculate KV heads with progressive scaling
# Start with fewer KV heads (e.g., 2-3) and increase toward end
min_kv_heads = max(1, num_heads // 6) # Minimum 1/6 of heads
max_kv_heads = max(2, num_heads // 3) # Maximum 1/3 of heads
# Progressive scaling: early layers use fewer, later use more
raw_kv_heads = int(min_kv_heads + (max_kv_heads - min_kv_heads) * layer_ratio)
# Ensure it's a valid divisor
self.num_kv_heads = raw_kv_heads
if self.num_heads % self.num_kv_heads != 0:
# Find the nearest valid num_kv_heads
for i in range(self.num_kv_heads, 0, -1):
if self.num_heads % i == 0:
self.num_kv_heads = i
break
# If that didn't work, try going up
if self.num_heads % self.num_kv_heads != 0:
for i in range(self.num_kv_heads + 1, max_kv_heads + 1):
if self.num_heads % i == 0:
self.num_kv_heads = i
break
else:
self.num_kv_heads = max(2, num_heads // 2)
# Final validation
assert self.num_heads % self.num_kv_heads == 0, \
f"Layer {layer_idx}: num_heads ({num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
# Query projections
self.q_proj = nn.Linear(dim, dim, bias=False)
# Key-Value projections with grouped attention
self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
# Output projection
self.out_proj = nn.Linear(dim, dim, bias=False)
# RoPE - only create if using positional embeddings
# NoPE layers (every 4th layer) skip positional embeddings entirely
if self.use_rope:
self.rope = RotaryPositionEmbedding(self.head_dim)
else:
self.rope = None
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
# Project queries, keys, values
q = self.q_proj(x) # [batch, seq_len, dim]
k = self.k_proj(x) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(x) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape for multi-head attention
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
# Apply RoPE to queries and keys (NoPE layers skip this)
# NoPE layers rely on causal attention mask for positional information
if self.use_rope and self.rope is not None:
cos, sin = self.rope(q)
q = apply_rotary_pos_emb(q, cos, sin)
k = apply_rotary_pos_emb(k, cos, sin)
# else: NoPE - no positional embeddings applied, causal mask provides ordering
# Expand KV heads for grouped query attention
if self.num_kv_heads != self.num_heads:
k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
# Compute attention scores
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Apply attention mask if provided
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
# Apply attention to values
attn_output = torch.matmul(attn_weights, v)
# Reshape and project back
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.dim
)
return self.out_proj(attn_output)
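# Example (illustrative sketch): the progressive KV-head schedule that the constructor
# above produces, collected per layer. The head and layer counts below are arbitrary
# example values, not necessarily the shipped configuration.
def _kv_head_schedule_example(num_heads: int = 12, num_layers: int = 16) -> list:
    schedule = []
    for layer_idx in range(num_layers):
        attn = VariableGroupedQueryAttention(
            dim=num_heads * 64, num_heads=num_heads,
            layer_idx=layer_idx, num_layers=num_layers,
        )
        schedule.append(attn.num_kv_heads)
    # Early layers end up with fewer KV heads (more sharing); later layers get more.
    return schedule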
class Expert(nn.Module):
"""Single expert in the MOE layer"""
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class MOELayer(nn.Module):
"""Mixture of Experts Layer with adaptive routing based on input complexity"""
def __init__(self, dim: int, hidden_dim: int, num_experts: int = 4,
capacity_factor: float = 1.0, noisy_gating: bool = True,
adaptive_routing: bool = True):
super().__init__()
self.dim = dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.noisy_gating = noisy_gating
self.adaptive_routing = adaptive_routing
# Create experts
self.experts = nn.ModuleList([
Expert(dim, hidden_dim) for _ in range(num_experts)
])
# Standard gate network
self.gate = nn.Linear(dim, num_experts)
# NOVEL: Adaptive complexity-based routing
# Learns to route tokens based on their complexity/importance
if adaptive_routing:
# Complexity encoder: estimates how "complex" each token representation is
self.complexity_encoder = nn.Sequential(
nn.Linear(dim, dim // 4),
nn.GELU(),
nn.Linear(dim // 4, 1),
nn.Sigmoid() # Output: 0 (simple) to 1 (complex)
)
# Adaptive temperature: dynamically adjusts expert selection based on complexity
self.complexity_proj = nn.Linear(dim, 1)
# Learnable bias for complexity-aware routing
self.complexity_bias = nn.Parameter(torch.zeros(1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, dim = x.shape
# Flatten for expert routing
x_flat = x.reshape(-1, dim)
num_tokens = x_flat.shape[0]
# Compute standard gate scores
gate_scores = self.gate(x_flat)
# NOVEL: Adaptive routing based on token complexity
if self.adaptive_routing:
# Estimate complexity of each token (0 = simple, 1 = complex)
complexity_scores = self.complexity_encoder(x_flat) # [num_tokens, 1]
            # Compute an adaptive temperature from complexity:
            # complex tokens get a lower temperature (sharper expert distribution),
            # simple tokens get a higher temperature (softer distribution).
            # Note: complexity_proj / complexity_bias are evaluated here, but their output
            # is not used by the temperature formula below.
            complexity_temp = self.complexity_proj(x_flat) + self.complexity_bias
            # Temperature in [0.5, 2.0], inversely related to the complexity score
            adaptive_temp = 0.5 + 1.5 * (1.0 - complexity_scores.squeeze(-1))
# Apply adaptive temperature scaling to gate scores
# Lower temp = sharper = focus on fewer experts
# Higher temp = softer = distribute more evenly
scaled_scores = gate_scores / (adaptive_temp.unsqueeze(-1) + 1e-8)
if self.noisy_gating and self.training:
# Reduced noise for complex tokens (they should be more confident)
noise_scale = (1.0 / self.num_experts) * (1.0 - complexity_scores.squeeze(-1) * 0.5)
noise = torch.randn_like(gate_scores) * noise_scale.unsqueeze(-1)
scaled_scores = scaled_scores + noise
else:
scaled_scores = gate_scores
if self.noisy_gating and self.training:
noise = torch.randn_like(gate_scores) * (1.0 / self.num_experts)
scaled_scores = scaled_scores + noise
# Get top-2 experts using adaptive scores
top_k = 2
top_scores, top_indices = torch.topk(scaled_scores, k=top_k, dim=-1)
top_gates = F.softmax(top_scores, dim=-1, dtype=torch.float32).to(x_flat.dtype)
# Create placeholder for final output
final_output = torch.zeros_like(x_flat)
# Compute auxiliary loss for load balancing (use original gate_scores, not scaled)
self.aux_loss = self._load_balancing_loss(gate_scores, top_indices)
# Route tokens to experts
for i in range(top_k):
# Process tokens for the i-th choice expert
expert_indices = top_indices[:, i]
gate_values = top_gates[:, i].unsqueeze(-1)
for expert_idx, expert in enumerate(self.experts):
token_indices = (expert_indices == expert_idx).nonzero(as_tuple=True)[0]
if token_indices.numel() > 0:
selected_tokens = x_flat[token_indices]
selected_gates = gate_values[token_indices]
expert_output = expert(selected_tokens)
final_output.index_add_(0, token_indices, expert_output * selected_gates)
# Reshape back to original dimensions
return final_output.reshape(batch_size, seq_len, dim)
def _load_balancing_loss(self, gate_scores: torch.Tensor, top_indices: torch.Tensor) -> torch.Tensor:
"""Compute load balancing auxiliary loss"""
if not self.training:
return torch.tensor(0.0, device=gate_scores.device)
# Compute fraction of tokens routed to each expert (based on top-1 choice)
top1_indices = top_indices[:, 0]
expert_mask = F.one_hot(top1_indices, num_classes=self.num_experts).float()
routing_fraction = expert_mask.mean(dim=0)
# Compute fraction of gate probability for each expert
gate_prob = F.softmax(gate_scores, dim=-1)
gate_fraction = gate_prob.mean(dim=0)
# Load balancing loss
load_balance_loss = self.num_experts * torch.sum(routing_fraction * gate_fraction)
return load_balance_loss
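# Example (illustrative sketch): a standalone forward pass through the MoE layer on random
# data, reading back the load-balancing auxiliary loss. Dimensions are arbitrary example values.
def _moe_layer_example():
    moe = MOELayer(dim=64, hidden_dim=128, num_experts=4)
    moe.train()                      # noisy gating and the aux loss are only active in training mode
    x = torch.randn(2, 10, 64)       # [batch, seq_len, dim]
    out = moe(x)
    assert out.shape == x.shape      # expert routing preserves the input shape
    return out, moe.aux_loss         # aux_loss is set as a side effect of forward()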
class SlimMoETransformerBlock(nn.Module):
"""Transformer block with VGQA and MOE"""
def __init__(self, dim: int, num_heads: int, hidden_dim: int,
num_experts: int = 4, dropout: float = 0.1,
layer_idx: int = 0, num_layers: int = 12,
adaptive_routing: bool = True):
super().__init__()
self.dim = dim
self.adaptive_routing = adaptive_routing
# Attention components with layer-specific KV heads
self.attn_norm = nn.LayerNorm(dim)
        # NoPE on every 4th block (0-indexed layers 3, 7, 11, ...); RoPE on all others.
        # layer_idx % 4 == 3 selects the NoPE layers.
        use_rope = (layer_idx % 4 != 3)
self.attention = VariableGroupedQueryAttention(
dim, num_heads, layer_idx=layer_idx,
num_layers=num_layers, variable_groups=True,
use_rope=use_rope
)
# Dense transformer feed-forward (before MoE)
self.dense_ffn_norm = nn.LayerNorm(dim)
self.dense_ffn = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
# MOE components
self.moe_norm = nn.LayerNorm(dim)
self.moe = MOELayer(dim, hidden_dim, num_experts, adaptive_routing=adaptive_routing)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# Attention branch with residual
attn_norm_out = self.attn_norm(x)
attn_out = self.attention(attn_norm_out, attention_mask)
x = x + self.dropout(attn_out)
# Dense transformer feed-forward branch with residual
dense_ffn_norm_out = self.dense_ffn_norm(x)
dense_ffn_out = self.dense_ffn(dense_ffn_norm_out)
x = x + dense_ffn_out
# MOE branch with residual
moe_norm_out = self.moe_norm(x)
moe_out = self.moe(moe_norm_out)
x = x + self.dropout(moe_out)
return x
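# Example (illustrative sketch): the RoPE/NoPE pattern that the use_rope rule above yields
# for a hypothetical 16-layer stack.
def _rope_nope_pattern_example(num_layers: int = 16) -> list:
    # e.g. ['RoPE', 'RoPE', 'RoPE', 'NoPE', 'RoPE', ...]
    return ["NoPE" if layer_idx % 4 == 3 else "RoPE" for layer_idx in range(num_layers)]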
class SlimMOETransformer(nn.Module):
"""Complete MOE Transformer with Variable Grouped Query Attention and RoPE"""
def __init__(self, vocab_size: int = 50257, dim: int = 768, num_layers: int = 12,
num_heads: int = 12, hidden_dim: int = 2048, num_experts: int = 4,
max_seq_len: int = 2048, dropout: float = 0.1, adaptive_routing: bool = True):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.num_layers = num_layers
self.max_seq_len = max_seq_len
self.token_embedding = nn.Embedding(vocab_size, dim)
self.dropout = nn.Dropout(dropout)
self.layers = nn.ModuleList([
SlimMoETransformerBlock(
dim=dim,
num_heads=num_heads,
hidden_dim=hidden_dim,
num_experts=num_experts,
dropout=dropout,
layer_idx=i,
num_layers=num_layers,
adaptive_routing=adaptive_routing
) for i in range(num_layers)
])
self.norm = nn.LayerNorm(dim)
# --- FIX: Remove the lm_head from the base transformer model ---
# self.lm_head = nn.Linear(dim, vocab_size, bias=False)
# The CausalLM wrapper will handle the final projection.
self.apply(self._init_weights)
        self._calculate_parameters()  # Reports the parameter count (the LM head is not part of this module)
def _init_weights(self, module):
"""Initialize weights"""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
    def _calculate_parameters(self):
        """Print the total parameter count of the model."""
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total Parameters: {total_params:,}")
    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> dict:
        # labels are accepted for interface compatibility but not used here; the
        # CausalLM wrapper handles the final projection.
batch_size, seq_len = input_ids.shape
causal_mask = torch.triu(
torch.full((1, 1, seq_len, seq_len), -torch.finfo(torch.get_default_dtype()).max, device=input_ids.device),
diagonal=1
)
if attention_mask is not None:
padding_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -torch.finfo(
torch.get_default_dtype()).max
extended_attention_mask = causal_mask + padding_mask
else:
extended_attention_mask = causal_mask
x = self.token_embedding(input_ids) * math.sqrt(self.dim)
x = self.dropout(x)
total_aux_loss = 0.0
for layer in self.layers:
x = layer(x, extended_attention_mask)
if self.training:
total_aux_loss += layer.moe.aux_loss
x = self.norm(x)
# --- FIX: Return hidden states and aux loss, not logits ---
return {
'last_hidden_state': x,
'aux_loss': total_aux_loss
}
def create_moe_model(vocab_size: int = 50257) -> SlimMOETransformer:
"""
Create a MOE model with approximately 300M parameters.
Configuration:
- dim=768, num_layers=16, num_heads=12
- hidden_dim=1536, num_experts=4
- This yields ~280-290M parameters, safely under 300M
"""
return SlimMOETransformer(
vocab_size=vocab_size,
dim=768,
num_layers=16,
num_heads=12,
hidden_dim=1536,
num_experts=4,
max_seq_len=2048,
dropout=0.1
)
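# Example usage (illustrative sketch): build the factory-default model and run one forward
# pass on random token ids. The printed shape assumes the configuration above; guarded so
# it only runs when this file is executed directly.
if __name__ == "__main__":
    model = create_moe_model(vocab_size=50257)
    model.eval()
    input_ids = torch.randint(0, 50257, (1, 32))       # [batch, seq_len]
    attention_mask = torch.ones(1, 32)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    print(outputs['last_hidden_state'].shape)           # expected: torch.Size([1, 32, 768])
    print(outputs['aux_loss'])                           # 0.0 outside training mode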