import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE), used in place of learned absolute position embeddings."""

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base

        # Inverse frequencies for each channel pair: 1 / base^(2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[seq_dim]
        device = x.device
        dtype = x.dtype

        # Angle table of shape (seq_len, dim // 2).
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)

        cos = torch.cos(freqs).to(dtype)
        sin = torch.sin(freqs).to(dtype)

        return cos, sin


def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to a (batch, num_heads, seq_len, head_dim) tensor."""
    batch_size, num_heads, seq_len, head_dim = x.shape
    half_dim = head_dim // 2

    # Treat consecutive channel pairs as complex numbers (interleaved layout).
    x_reshaped = x.view(batch_size, num_heads, seq_len, half_dim, 2)
    x_real = x_reshaped[..., 0]
    x_imag = x_reshaped[..., 1]

    # Broadcast cos/sin from (seq_len, half_dim) to (1, 1, seq_len, half_dim).
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    # Rotate each pair by its position-dependent angle.
    x_real_rot = x_real * cos - x_imag * sin
    x_imag_rot = x_real * sin + x_imag * cos

    x_rotated = torch.stack([x_real_rot, x_imag_rot], dim=-1)
    x_rotated = x_rotated.view(batch_size, num_heads, seq_len, head_dim)

    return x_rotated.type_as(x)


class VariableGroupedQueryAttention(nn.Module):
    """Variable Grouped Query Attention with layer-specific head grouping and optional RoPE/NoPE."""

    def __init__(self, dim: int, num_heads: int = 8, layer_idx: int = 0,
                 num_layers: int = 12, variable_groups: bool = True,
                 use_rope: bool = True):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.variable_groups = variable_groups
        self.layer_idx = layer_idx
        self.num_layers = num_layers
        self.use_rope = use_rope

        if variable_groups:
            # Scale the number of KV heads with depth: early layers share KV heads more
            # aggressively, later layers keep more of them.
            layer_ratio = layer_idx / max(1, num_layers - 1)

            min_kv_heads = max(1, num_heads // 6)
            max_kv_heads = max(2, num_heads // 3)

            raw_kv_heads = int(min_kv_heads + (max_kv_heads - min_kv_heads) * layer_ratio)

            # num_kv_heads must divide num_heads evenly; fall back to the largest divisor
            # not greater than the raw value (this always succeeds, since 1 divides num_heads).
            self.num_kv_heads = raw_kv_heads
            if self.num_heads % self.num_kv_heads != 0:
                for i in range(self.num_kv_heads, 0, -1):
                    if self.num_heads % i == 0:
                        self.num_kv_heads = i
                        break
        else:
            self.num_kv_heads = max(2, num_heads // 2)

        assert self.num_heads % self.num_kv_heads == 0, \
            f"Layer {layer_idx}: num_heads ({num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"

        # Queries keep the full head count; keys/values are projected down to num_kv_heads.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        # NoPE layers carry no explicit positional signal at all.
        if self.use_rope:
            self.rope = RotaryPositionEmbedding(self.head_dim)
        else:
            self.rope = None

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys (skipped entirely in NoPE layers).
        if self.use_rope and self.rope is not None:
            cos, sin = self.rope(q)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)

        # Expand the grouped KV heads so every query head has a matching key/value head.
        if self.num_kv_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)

        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.dim
        )

        return self.out_proj(attn_output)


class Expert(nn.Module):
    """Single expert in the MOE layer"""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MOELayer(nn.Module):
    """Mixture of Experts layer with adaptive routing based on input complexity"""

    def __init__(self, dim: int, hidden_dim: int, num_experts: int = 4,
                 capacity_factor: float = 1.0, noisy_gating: bool = True,
                 adaptive_routing: bool = True):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.noisy_gating = noisy_gating
        self.adaptive_routing = adaptive_routing

        self.experts = nn.ModuleList([
            Expert(dim, hidden_dim) for _ in range(num_experts)
        ])

        # Router: one logit per expert for every token.
        self.gate = nn.Linear(dim, num_experts)

        if adaptive_routing:
            # Small network that scores token "complexity" in [0, 1]; higher scores lead
            # to sharper (lower-temperature) and less noisy routing in forward().
            self.complexity_encoder = nn.Sequential(
                nn.Linear(dim, dim // 4),
                nn.GELU(),
                nn.Linear(dim // 4, 1),
                nn.Sigmoid()
            )

        # Load-balancing auxiliary loss, refreshed on every forward pass.
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape

        # Route per token rather than per sequence.
        x_flat = x.reshape(-1, dim)

        gate_scores = self.gate(x_flat)

        if self.adaptive_routing:
            # Per-token complexity score in [0, 1].
            complexity_scores = self.complexity_encoder(x_flat)

            # High-complexity tokens get a low temperature (~0.5) and therefore a sharper
            # routing distribution; low-complexity tokens get a softer one (~2.0).
            adaptive_temp = 0.5 + 1.5 * (1.0 - complexity_scores.squeeze(-1))
            scaled_scores = gate_scores / (adaptive_temp.unsqueeze(-1) + 1e-8)

            if self.noisy_gating and self.training:
                # Exploration noise, damped for high-complexity tokens.
                noise_scale = (1.0 / self.num_experts) * (1.0 - complexity_scores.squeeze(-1) * 0.5)
                noise = torch.randn_like(gate_scores) * noise_scale.unsqueeze(-1)
                scaled_scores = scaled_scores + noise
        else:
            scaled_scores = gate_scores
            if self.noisy_gating and self.training:
                noise = torch.randn_like(gate_scores) * (1.0 / self.num_experts)
                scaled_scores = scaled_scores + noise

        # Top-2 routing: each token is handled by its two highest-scoring experts.
        top_k = 2
        top_scores, top_indices = torch.topk(scaled_scores, k=top_k, dim=-1)
        top_gates = F.softmax(top_scores, dim=-1, dtype=torch.float32).to(x_flat.dtype)

        final_output = torch.zeros_like(x_flat)

        self.aux_loss = self._load_balancing_loss(gate_scores, top_indices)

        for i in range(top_k):
            expert_indices = top_indices[:, i]
            gate_values = top_gates[:, i].unsqueeze(-1)

            # Process each expert's tokens in one batch and scatter the gated outputs back.
            for expert_idx, expert in enumerate(self.experts):
                token_indices = (expert_indices == expert_idx).nonzero(as_tuple=True)[0]

                if token_indices.numel() > 0:
                    selected_tokens = x_flat[token_indices]
                    selected_gates = gate_values[token_indices]

                    expert_output = expert(selected_tokens)
                    final_output.index_add_(0, token_indices, expert_output * selected_gates)

        return final_output.reshape(batch_size, seq_len, dim)

    def _load_balancing_loss(self, gate_scores: torch.Tensor, top_indices: torch.Tensor) -> torch.Tensor:
        """Compute the load-balancing auxiliary loss"""
        if not self.training:
            return torch.tensor(0.0, device=gate_scores.device)

        # Fraction of tokens whose top-1 choice is each expert.
        top1_indices = top_indices[:, 0]
        expert_mask = F.one_hot(top1_indices, num_classes=self.num_experts).float()
        routing_fraction = expert_mask.mean(dim=0)

        # Mean router probability assigned to each expert.
        gate_prob = F.softmax(gate_scores, dim=-1)
        gate_fraction = gate_prob.mean(dim=0)

        # Product of the two is minimized when tokens are spread evenly across experts.
        load_balance_loss = self.num_experts * torch.sum(routing_fraction * gate_fraction)

        return load_balance_loss


class SlimMoETransformerBlock(nn.Module):
    """Transformer block with VGQA attention, a dense FFN, and an MoE FFN"""

    def __init__(self, dim: int, num_heads: int, hidden_dim: int,
                 num_experts: int = 4, dropout: float = 0.1,
                 layer_idx: int = 0, num_layers: int = 12,
                 adaptive_routing: bool = True):
        super().__init__()
        self.dim = dim
        self.adaptive_routing = adaptive_routing

        self.attn_norm = nn.LayerNorm(dim)

        # Every fourth layer (layer_idx % 4 == 3) uses NoPE (no positional encoding);
        # all other layers use RoPE.
        use_rope = (layer_idx % 4 != 3)

        self.attention = VariableGroupedQueryAttention(
            dim, num_heads, layer_idx=layer_idx,
            num_layers=num_layers, variable_groups=True,
            use_rope=use_rope
        )

        # Dense FFN applied to every token.
        self.dense_ffn_norm = nn.LayerNorm(dim)
        self.dense_ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        # Sparse MoE FFN applied after the dense FFN.
        self.moe_norm = nn.LayerNorm(dim)
        self.moe = MOELayer(dim, hidden_dim, num_experts, adaptive_routing=adaptive_routing)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm attention with residual connection.
        attn_norm_out = self.attn_norm(x)
        attn_out = self.attention(attn_norm_out, attention_mask)
        x = x + self.dropout(attn_out)

        # Dense FFN (dropout is already applied inside the Sequential).
        dense_ffn_norm_out = self.dense_ffn_norm(x)
        dense_ffn_out = self.dense_ffn(dense_ffn_norm_out)
        x = x + dense_ffn_out

        # MoE FFN.
        moe_norm_out = self.moe_norm(x)
        moe_out = self.moe(moe_norm_out)
        x = x + self.dropout(moe_out)

        return x


class SlimMOETransformer(nn.Module):
    """Complete MOE Transformer with Variable Grouped Query Attention and RoPE"""

    def __init__(self, vocab_size: int = 50257, dim: int = 768, num_layers: int = 12,
                 num_heads: int = 12, hidden_dim: int = 2048, num_experts: int = 4,
                 max_seq_len: int = 2048, dropout: float = 0.1, adaptive_routing: bool = True):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        # Token embeddings only: positional information comes from RoPE inside the
        # attention layers (and from nothing at all in the NoPE layers).
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            SlimMoETransformerBlock(
                dim=dim,
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                num_experts=num_experts,
                dropout=dropout,
                layer_idx=i,
                num_layers=num_layers,
                adaptive_routing=adaptive_routing
            ) for i in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)
        self._calculate_parameters()

    def _init_weights(self, module):
        """Initialize linear/embedding weights from N(0, 0.02) and LayerNorm to identity"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def _calculate_parameters(self):
        """Print the total parameter count of the model"""
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total Parameters: {total_params:,}")

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> dict:
        _, seq_len = input_ids.shape

        # Additive causal mask: 0 on and below the diagonal, a large negative value above it.
        causal_mask = torch.triu(
            torch.full((1, 1, seq_len, seq_len), -torch.finfo(torch.get_default_dtype()).max,
                       device=input_ids.device),
            diagonal=1
        )

        if attention_mask is not None:
            # attention_mask is (batch, seq_len) with 1 for real tokens and 0 for padding.
            padding_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -torch.finfo(
                torch.get_default_dtype()).max
            extended_attention_mask = causal_mask + padding_mask
        else:
            extended_attention_mask = causal_mask

        x = self.token_embedding(input_ids) * math.sqrt(self.dim)
        x = self.dropout(x)

        # Accumulate the per-layer MoE load-balancing losses during training.
        total_aux_loss = 0.0
        for layer in self.layers:
            x = layer(x, extended_attention_mask)
            if self.training:
                total_aux_loss += layer.moe.aux_loss

        x = self.norm(x)

        return {
            'last_hidden_state': x,
            'aux_loss': total_aux_loss
        }


def create_moe_model(vocab_size: int = 50257) -> SlimMOETransformer:
    """
    Create a MOE model targeting a sub-300M parameter budget.

    Configuration:
    - dim=768, num_layers=16, num_heads=12
    - hidden_dim=1536, num_experts=4
    - This yields roughly 250M parameters (the exact count is printed at
      construction time), safely under 300M.
    """
    return SlimMOETransformer(
        vocab_size=vocab_size,
        dim=768,
        num_layers=16,
        num_heads=12,
        hidden_dim=1536,
        num_experts=4,
        max_seq_len=2048,
        dropout=0.1
    )
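

# Example usage: a minimal smoke test that builds the sub-300M configuration and runs a
# single forward pass on random token ids, with the last two positions of one sequence
# treated as padding. Shapes and dictionary keys follow the forward() method above.
if __name__ == "__main__":
    model = create_moe_model(vocab_size=50257)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, 50257, (batch_size, seq_len))

    # 1 = real token, 0 = padding.
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[1, -2:] = 0.0

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    print(outputs['last_hidden_state'].shape)  # torch.Size([2, 16, 768])
    print(outputs['aux_loss'])                 # 0.0 in eval mode (no load-balancing loss)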