Upload 3 files
Browse files- moe_config.py +119 -0
- moe_layers.py +323 -0
- moe_model.py +460 -0
moe_config.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HuggingFace-compatible MoE Configuration
|
| 3 |
+
Basierend auf dem nanoMoE Blog Post
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MoEGPTConfig(PretrainedConfig):
    """
    Configuration for the MoE-based GPT model.

    Args:
        vocab_size (int): Vocabulary size.
        n_positions (int): Maximum sequence length.
        n_embd (int): Embedding dimensionality (d in the blog post).
        n_layer (int): Number of transformer blocks.
        n_head (int): Number of attention heads.
        n_experts (int): Number of experts per MoE layer.
        n_experts_active (int): Number of active experts per token (top-k).
        moe_layer_frequency (int): Every n-th layer becomes an MoE layer (P in the blog post).
        capacity_factor (float): Expert capacity factor during training.
        eval_capacity_factor (float): Expert capacity factor during evaluation.
        use_noisy_gating (bool): Whether to use noisy top-k gating.
        aux_loss_alpha (float): Scale of the load-balancing loss.
        router_z_loss_alpha (float): Scale of the router z-loss.
        bias (bool): Whether linear layers use bias terms.
        dropout (float): Dropout probability.
        activation_function (str): Activation function (gelu, relu, swiglu).
        initializer_range (float): Scale used for weight initialization.
        layer_norm_epsilon (float): Epsilon for layer normalization.

    Raises:
        ValueError: If the dimension/expert settings are inconsistent.
    """

    model_type = "moe_gpt"

    def __init__(
        self,
        vocab_size=128256,  # Llama 3.2 tokenizer (incl. special tokens)
        n_positions=2048,  # default 2048 for RoPE
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_experts=8,
        n_experts_active=2,
        moe_layer_frequency=2,
        capacity_factor=1.25,
        eval_capacity_factor=2.0,
        use_noisy_gating=True,
        aux_loss_alpha=0.01,
        router_z_loss_alpha=0.001,
        bias=False,
        dropout=0.1,
        activation_function="gelu",
        initializer_range=0.1,
        layer_norm_epsilon=1e-5,
        use_cache=True,
        rope_theta=10000.0,  # RoPE base theta
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Validation — raise ValueError instead of assert, because asserts are
        # stripped when Python runs with -O and invalid configs must never
        # slip through silently.
        if n_embd % n_head != 0:
            raise ValueError("n_embd muss durch n_head teilbar sein")
        if n_experts_active > n_experts:
            raise ValueError("n_experts_active darf nicht größer als n_experts sein")
        if moe_layer_frequency < 1:
            raise ValueError("moe_layer_frequency muss mindestens 1 sein")

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.moe_layer_frequency = moe_layer_frequency
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.use_noisy_gating = use_noisy_gating
        self.aux_loss_alpha = aux_loss_alpha
        self.router_z_loss_alpha = router_z_loss_alpha
        self.bias = bias
        self.dropout = dropout
        self.activation_function = activation_function
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Standard HuggingFace attribute aliases (required for .generate())
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions

    @property
    def head_dim(self):
        """Dimension per attention head."""
        return self.n_embd // self.n_head

    @property
    def total_experts(self):
        """Total number of experts across the whole model."""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        return num_moe_layers * self.n_experts

    @property
    def active_parameters_ratio(self):
        """Approximate ratio of active to total parameters (FFN-only estimate)."""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        num_dense_layers = self.n_layer - num_moe_layers

        # Rough estimate that ignores attention and embedding parameters
        dense_params = num_dense_layers * (8 * self.n_embd**2)  # FFN params
        moe_total_params = num_moe_layers * self.n_experts * (8 * self.n_embd**2)
        moe_active_params = num_moe_layers * self.n_experts_active * (8 * self.n_embd**2)

        total = dense_params + moe_total_params
        active = dense_params + moe_active_params

        return active / total if total > 0 else 1.0
|
moe_layers.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MoE Layer Komponenten
|
| 3 |
+
Basierend auf dem nanoMoE Blog Post und HuggingFace Best Practices
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Tuple, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MoERouter(nn.Module):
    """
    Noisy top-k router for MoE.

    Assigns every token to its top-k experts based on learned gate
    probabilities (Shazeer et al. 2017) and packs the selected tokens into
    fixed-capacity per-expert batches so the expert bank can run as one
    batched matmul.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        n_experts_active: int,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
    ):
        """
        Args:
            d_model: Model/embedding dimension of the incoming tokens.
            n_experts: Total number of experts to route over.
            n_experts_active: Number of experts selected per token (top-k).
            use_noisy_gating: Add learned, input-dependent Gaussian noise to
                the gate logits during training (load spreading / exploration).
            capacity_factor: Over-provisioning factor for the per-expert capacity.
        """
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor

        # Linear projections for the router (no bias, see Shazeer et al. 2017)
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            expert_weights: Gate weight per (token, expert, capacity slot)
                [batch_size * seq_len, n_experts, capacity]
            expert_mask: Boolean mask of occupied expert-batch slots
                [batch_size * seq_len, n_experts, capacity]
            expert_batches: Tokens gathered per expert [n_experts, capacity, d_model]
            router_logits: Raw router logits for the z-loss [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()

            # Compute router logits
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]

            # Noisy top-k gating (optional; training only — eval stays deterministic)
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise

            # Select the top-k experts per token
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]

            # Softmax over the selected logits only: non-selected experts get
            # -inf so they receive exactly zero probability mass.
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]

            # Compute the per-expert capacity
            capacity = self._compute_capacity(num_tokens)

            # Multi-hot mask of the chosen experts
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]

            # Position of each token within its expert batch; the cumsum over
            # the flattened (K * tokens) axis gives top-1 choices priority
            # over lower-ranked choices when capacity runs out.
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )

            # Drop tokens that exceed the expert capacity
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)

            # Slot index within the expert batch
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]

            # Multiply probabilities with the (capacity-pruned) mask
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]

            # One-hot of the slot position in the expert batch
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]

            # Scatter the gate weights into their expert-batch slots
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            # NOTE: dropped tokens have zero weight everywhere, so they vanish
            # from the boolean mask and contribute nothing to expert_batches.
            expert_mask = expert_weights.bool()

            # Build the expert batches (matmul acts as a gather here)
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]

        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Compute the per-expert capacity (max tokens per expert batch)."""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # even number for better hardware utilization
        return max(int(capacity), 2)  # minimum of 2 for tiny batches
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ExpertMLP(nn.Module):
    """
    Batch of MLP experts.

    All experts share the same architecture but have independent weights,
    stored as batched parameter tensors so the whole expert bank runs in a
    single batched matmul (torch.bmm).
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        """
        Args:
            d_model: Model/embedding dimension.
            n_experts: Number of experts in the bank.
            bias: Whether the expert linear maps use bias terms.
            dropout: Dropout probability applied to the expert output.
            activation: One of "gelu", "relu" or "swiglu".

        Raises:
            ValueError: If `activation` is not a supported name.
        """
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias

        # 4x hidden dimension (GPT standard)
        hidden_dim = 4 * d_model

        # Weights for all experts at once (batched-matmul layout: x @ W)
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))

        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)

        # Activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs an extra gate weight
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unbekannte Aktivierung: {activation}")

        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation

        # BUGFIX: the parameters above are allocated with torch.empty and were
        # never initialized anywhere. Module.apply(_init_weights) only visits
        # nn.Module instances — never raw nn.Parameter tensors — so the
        # isinstance(module, nn.Parameter) branch in the model's _init_weights
        # was dead code and the experts ran on uninitialized memory.
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Fan-in-scaled truncated-normal init for weights, zeros for biases."""
        for name, param in self.named_parameters(recurse=False):
            if "bias" in name:
                nn.init.zeros_(param)
            else:
                # Layout is x @ W, so fan-in is the second-to-last dimension.
                fan_in = param.shape[-2]
                std = fan_in ** -0.5
                nn.init.trunc_normal_(param, mean=0.0, std=std, a=-2 * std, b=2 * std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]

        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First linear layer via batched matmul
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias

        # Activation
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)

        # Second linear layer
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias

        output = self.dropout(output)

        return output
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class MoELayer(nn.Module):
    """
    Complete mixture-of-experts layer.

    Combines a noisy top-k router (token-to-expert assignment) with a
    batched bank of expert MLPs and reports the two auxiliary losses used
    to keep routing healthy.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        # Token-to-expert assignment
        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )

        # Bank of expert feed-forward networks
        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )

        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: scalar load-balancing loss
            router_z_loss: scalar router z-loss
        """
        bsz, seq, dim = x.shape
        n_tokens = bsz * seq

        # Route tokens into fixed-capacity expert batches
        weights, mask, batches, logits = self.router(x)

        # Run every expert on its own batch
        per_expert_out = self.experts(batches)  # [n_experts, capacity, d_model]

        # Recombine: each token's output is the gate-weighted sum of the
        # expert outputs sitting in its assigned slots.
        combined = weights.view(n_tokens, -1) @ per_expert_out.view(-1, dim)
        output = combined.view(bsz, seq, dim)

        # Auxiliary losses
        balance = self._compute_load_balance_loss(logits, mask)
        z = self._compute_router_z_loss(logits)

        return output, balance, z

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing loss (Switch Transformer, Fedus et al. 2022).

        Encourages a uniform distribution of tokens across experts by dotting
        the mean routing probability with the realized token fraction.
        """
        bsz, seq, n_experts = router_logits.shape
        n_tokens = bsz * seq

        # Mean routing probability per expert
        mean_prob = router_logits.softmax(dim=-1).mean(dim=(0, 1))  # [n_experts]

        # Fraction of tokens actually dispatched to each expert (no gradient)
        with torch.no_grad():
            # expert_mask is [num_tokens, n_experts, capacity]
            token_frac = expert_mask.float().sum(dim=(0, 2)) / (
                n_tokens * self.n_experts_active
            )  # [n_experts]

        # Dot product, scaled by the number of experts
        return self.n_experts * torch.sum(mean_prob * token_frac)

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Router z-loss (ST-MoE, Zoph et al. 2022).

        Penalizes large router logits for numerical stability.
        """
        # Mean of the squared logsumexp over the expert dimension
        return torch.logsumexp(router_logits, dim=-1).square().mean()
|
moe_model.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MoE GPT Model - HuggingFace kompatibel
|
| 3 |
+
Basiert auf nanoMoE und dem Blog Post
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple, Union
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
from transformers import PreTrainedModel
|
| 14 |
+
from transformers.generation import GenerationMixin
|
| 15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 16 |
+
|
| 17 |
+
from moe_config import MoEGPTConfig
|
| 18 |
+
from moe_layers import MoELayer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class MoECausalLMOutput(CausalLMOutputWithPast):
    """
    Causal-LM output extended with the MoE-specific auxiliary losses.
    """

    # Load-balancing loss — presumably accumulated over the MoE layers by the
    # model forward (that code is outside this chunk; confirm against it).
    aux_loss: Optional[torch.FloatTensor] = None
    # Router z-loss — same caveat as aux_loss.
    router_z_loss: Optional[torch.FloatTensor] = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to an input tensor.

    Args:
        x: Input tensor of shape [B, H, T, D].
        freqs_cos: Cosine frequencies of shape [T, D//2].
        freqs_sin: Sine frequencies of shape [T, D//2].

    Returns:
        Tensor of the same shape and dtype as `x` with the rotation applied.
    """
    # View the last dimension as D//2 (real, imaginary) pairs; rotate in fp32.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    real, imag = pairs[..., 0], pairs[..., 1]

    # Complex rotation: (a + bi) * (cos + i*sin) = (a*cos - b*sin) + i(a*sin + b*cos)
    rotated = torch.stack(
        (
            real * freqs_cos - imag * freqs_sin,
            real * freqs_sin + imag * freqs_cos,
        ),
        dim=-1,
    )

    # Interleave the rotated pairs back into the last dimension.
    return rotated.flatten(-2).type_as(x)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def precompute_freqs_rope(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the RoPE rotation angles as cosine/sine tables.

    Args:
        dim: Head dimension.
        max_seq_len: Maximum sequence length.
        theta: RoPE base for the frequency geometric progression.

    Returns:
        Tuple (freqs_cos, freqs_sin), each of shape [max_seq_len, dim//2].
    """
    # Inverse frequencies per dimension pair: theta^(-2i/dim)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Angle for every (position, pair) combination: [max_seq_len, dim//2]
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]

    return angles.cos(), angles.sin()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with rotary position embeddings (RoPE).

    Uses PyTorch's scaled_dot_product_attention (SDPA) for the attention
    kernel; causal masking and attention dropout are delegated to SDPA.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Key, query and value projections for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization
        # NOTE(review): attn_dropout is never used in forward() — SDPA already
        # applies attention dropout via dropout_p below. Consider removing.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = config.n_embd // config.n_head

        # Precompute RoPE frequency tables for the full context length
        freqs_cos, freqs_sin = precompute_freqs_rope(
            dim=self.head_dim,
            max_seq_len=config.n_positions,
            theta=config.rope_theta
        )
        # Non-persistent buffers: follow .to()/device moves but are excluded
        # from the state dict (they are cheap to recompute).
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # Compute Q, K, V in one projection, then split
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # [B, H, T, d]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K (first T rows of the precomputed tables)
        q = apply_rotary_emb(q, self.freqs_cos[:T], self.freqs_sin[:T])
        k = apply_rotary_emb(k, self.freqs_cos[:T], self.freqs_sin[:T])

        # PyTorch SDPA: optimized, memory-efficient attention kernel that
        # handles causal masking and attention dropout internally.
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # causal mask handled by is_causal
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True  # efficient causal masking
        )  # [B, H, T, d]

        # Merge heads back into the embedding dimension
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection + residual dropout
        y = self.resid_dropout(self.c_proj(y))

        return y
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class MLP(nn.Module):
    """
    Standard feed-forward network for the dense (non-MoE) transformer blocks.

    Supports the same activations as ExpertMLP — "gelu", "relu" and "swiglu" —
    so a config that selects swiglu (documented in MoEGPTConfig) works for the
    dense blocks too; previously this class rejected it.
    """

    def __init__(self, config: "MoEGPTConfig"):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.activation_type = config.activation_function

        if config.activation_function == "gelu":
            self.activation = nn.GELU()
        elif config.activation_function == "relu":
            self.activation = nn.ReLU()
        elif config.activation_function == "swiglu":
            # Parity with ExpertMLP: SwiGLU needs an extra gate projection
            self.c_gate = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unbekannte Aktivierung: {config.activation_function}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.c_fc(x)
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            h = self.activation(self.c_gate(x)) * h
        else:
            h = self.activation(h)
        return self.dropout(self.c_proj(h))
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TransformerBlock(nn.Module):
    """
    Dense transformer block: pre-norm causal self-attention followed by a
    pre-norm feed-forward network, each wrapped in a residual connection.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual attention sub-layer
        attended = x + self.attn(self.ln_1(x))
        # Residual feed-forward sub-layer
        return attended + self.mlp(self.ln_2(attended))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class MoETransformerBlock(nn.Module):
    """
    MoE transformer block: pre-norm causal self-attention followed by a
    pre-norm mixture-of-experts layer, each with a residual connection.
    Forwards the MoE auxiliary losses to the caller.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # NOTE(review): only config.capacity_factor is wired in here;
        # config.eval_capacity_factor is not used at this point — confirm
        # whether eval-time capacity switching happens elsewhere.
        self.moe = MoELayer(
            d_model=config.n_embd,
            n_experts=config.n_experts,
            n_experts_active=config.n_experts_active,
            use_noisy_gating=config.use_noisy_gating,
            capacity_factor=config.capacity_factor,
            bias=config.bias,
            dropout=config.dropout,
            activation=config.activation_function,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Residual attention sub-layer
        hidden = x + self.attn(self.ln_1(x))

        # Residual MoE sub-layer; the MoE layer also returns its aux losses
        expert_out, aux_loss, router_z_loss = self.moe(self.ln_2(hidden))

        return hidden + expert_out, aux_loss, router_z_loss
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class MoEGPTPreTrainedModel(PreTrainedModel):
    """
    Base class wiring the MoE GPT into HuggingFace's PreTrainedModel
    machinery (config handling, save/load, gradient checkpointing).
    """

    config_class = MoEGPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Weight initialization following ST-MoE (Zoph et al. 2022):
        truncated normal with fan-in-scaled (reduced) std for MoE stability.

        Invoked by HuggingFace via ``self.apply`` on every submodule after
        construction.
        """
        if isinstance(module, nn.Linear):
            # Fan-in scaling: std = sqrt(initializer_range / fan_in),
            # truncated at +/- 2 std.
            fan_in = module.weight.shape[-1]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            # Embeddings use a plain (non-truncated) normal at full range.
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

        elif isinstance(module, nn.Parameter):
            # Intended for raw expert parameter tensors.
            # NOTE(review): ``Module.apply`` only visits nn.Module instances,
            # and nn.Parameter is a Tensor subclass, not a Module — so this
            # branch appears unreachable through the normal HF init path.
            # Verify whether expert parameters are initialized elsewhere.
            fan_in = module.shape[-1] if len(module.shape) >= 2 else module.shape[0]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class MoEGPTModel(MoEGPTPreTrainedModel):
    """
    MoE GPT backbone (no LM head).

    Stacks standard and MoE transformer blocks, interleaved every
    ``moe_layer_frequency`` layers, and accumulates the MoE auxiliary
    losses across all MoE blocks.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False  # toggled by HF gradient-checkpointing support

        # Token embeddings only (RoPE handles positions)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks (mixed: standard + MoE)
        self.h = nn.ModuleList()
        for i in range(config.n_layer):
            # Every moe_layer_frequency-th layer (including layer 0) is MoE.
            if i % config.moe_layer_frequency == 0:
                # MoE block
                self.h.append(MoETransformerBlock(config))
            else:
                # Standard block
                self.h.append(TransformerBlock(config))

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Initialize weights (HF hook; also runs _init_weights via apply)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns (hidden_states [B, T, n_embd], total_aux_loss, total_router_z_loss).

        NOTE(review): ``attention_mask`` is accepted for HF API compatibility
        but never used — the blocks are called with ``x`` only. Confirm this
        is intentional (causal-only attention, no padding masking).
        """
        device = input_ids.device  # NOTE(review): unused local; kept as-is
        b, t = input_ids.size()

        assert t <= self.config.n_positions, f"Sequenz zu lang: {t} > {self.config.n_positions}"

        # Token embeddings only (RoPE applied inside attention layers)
        tok_emb = self.wte(input_ids)  # [B, T, n_embd]
        x = self.drop(tok_emb)

        # Accumulate auxiliary losses across MoE blocks
        total_aux_loss = 0.0
        total_router_z_loss = 0.0

        # Iterate through all blocks
        for block in self.h:
            if isinstance(block, MoETransformerBlock):
                if self.gradient_checkpointing and self.training:
                    # Gradient checkpointing for MoE blocks: wrap so the
                    # checkpointed callable returns the full 3-tuple.
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    x, aux_loss, router_z_loss = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        use_reentrant=False
                    )
                else:
                    x, aux_loss, router_z_loss = block(x)
                total_aux_loss = total_aux_loss + aux_loss
                total_router_z_loss = total_router_z_loss + router_z_loss
            else:
                if self.gradient_checkpointing and self.training:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x,
                        use_reentrant=False
                    )
                else:
                    x = block(x)

        x = self.ln_f(x)

        return x, total_aux_loss, total_router_z_loss
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class MoEGPTForCausalLM(MoEGPTPreTrainedModel, GenerationMixin):
    """
    MoE GPT with a language-modeling head (for pretraining).
    Inherits from GenerationMixin for ``.generate()`` support.
    """

    # Tell HuggingFace which weights are shared (tied)
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.transformer = MoEGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the LM head shares its weights with the token embedding
        self.lm_head.weight = self.transformer.wte.weight

        # Initialize weights
        self.post_init()

    def get_output_embeddings(self):
        """For HuggingFace weight tying."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """For HuggingFace weight tying."""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.transformer.wte = new_embeddings

    def tie_weights(self):
        """
        Tie lm_head weights to the input embeddings (weight tying).
        Called after loading a checkpoint to fix a missing lm_head.weight.
        """
        self.lm_head.weight = self.transformer.wte.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # Accept additional kwargs like use_cache for HuggingFace compatibility
    ) -> Union[Tuple, MoECausalLMOutput]:
        """
        Causal-LM forward pass.

        With ``labels``: logits cover every position and ``loss`` is the
        shifted cross-entropy; the MoE auxiliary losses are added only in
        training mode. Without ``labels``: logits are computed for the last
        position only, so the returned logits have shape [B, 1, vocab_size].
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through the transformer backbone
        hidden_states, aux_loss, router_z_loss = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # LM head
        if labels is not None:
            # Training: logits over all positions (needed for the shifted loss)
            logits = self.lm_head(hidden_states)
        else:
            # Inference: only the last position (keeps the time dim via [-1] list index)
            logits = self.lm_head(hidden_states[:, [-1], :])

        # Compute loss
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Add auxiliary losses (training only; eval loss is pure LM loss)
            loss = lm_loss
            if self.training:
                loss = loss + self.config.aux_loss_alpha * aux_loss
                loss = loss + self.config.router_z_loss_alpha * router_z_loss

        if not return_dict:
            # Tuple output omits the auxiliary losses by design
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MoECausalLMOutput(
            loss=loss,
            logits=logits,
            aux_loss=aux_loss if self.training else None,
            router_z_loss=router_z_loss if self.training else None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """For the HuggingFace generate() loop; the model only needs input_ids."""
        return {"input_ids": input_ids}
|