"""
MoE layer components.

Based on the nanoMoE blog post and HuggingFace best practices.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class MoERouter(nn.Module):
    """
    Noisy top-k router for MoE.

    Routes each token to its top-k experts based on learned gating
    probabilities, with optional noisy gating (Shazeer et al. 2017).
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        n_experts_active: int,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
    ):
        """
        Args:
            d_model: Model (embedding) dimension of incoming tokens.
            n_experts: Total number of experts to route over.
            n_experts_active: Number of experts (k) selected per token.
            use_noisy_gating: If True, add learned input-dependent Gaussian
                noise to the router logits during training.
            capacity_factor: Multiplier on the per-expert token budget; tokens
                routed beyond an expert's capacity are dropped.
        """
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor
        # Linear projections for the router (no bias, see Shazeer et al. 2017).
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            expert_weights: Gating weight per expert slot
                [batch_size * seq_len, n_experts, capacity]
            expert_mask: Boolean mask of occupied expert slots
                [batch_size * seq_len, n_experts, capacity]
            expert_batches: Tokens gathered per expert
                [n_experts, capacity, d_model]
            router_logits: Raw router logits for the z-loss
                [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len
        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()
            # Compute router logits.
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]
            # Noisy top-k gating (optional, training only).
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise
            # Select the top-k experts per token.
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]
            # Softmax over a full-size tensor whose non-top-k entries are -inf,
            # so the k selected probabilities sum to 1 per token.
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]
            # Per-expert token budget.
            capacity = self._compute_capacity(num_tokens)
            # Multi-hot mask of the chosen experts.
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]
            # Arrival rank of each token at its expert (cumsum over the flattened
            # [K * num_tokens] axis so top-1 assignments are filled before top-2).
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )
            # Drop tokens routed beyond an expert's capacity.
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)
            # Scalar slot index per (k, token) within the expert batch.
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]
            # Multiply probabilities with the (capacity-truncated) mask.
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]
            # One-hot of the slot position within the expert batch.
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]
            # Scatter gating weights into their (expert, slot) positions.
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            expert_mask = expert_weights.bool()
            # Build the expert batches: gather each token into its slot
            # via a mask-matmul (one row of x per occupied slot).
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]
        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Compute the expert capacity (token budget per expert)."""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # Round up to an even number for better hardware utilization.
        return max(int(capacity), 2)  # Minimum of 2 for small batches.
class ExpertMLP(nn.Module):
    """
    Batch of MLP experts.

    All experts share the same architecture but have independent weights;
    the forward pass applies all experts at once via batched matmul.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        """
        Args:
            d_model: Model (embedding) dimension.
            n_experts: Number of independent experts.
            bias: If True, add bias terms to both linear layers.
            dropout: Dropout probability applied to the expert output.
            activation: One of "gelu", "relu" or "swiglu".

        Raises:
            ValueError: If `activation` is not a supported name.
        """
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias
        # 4x hidden dimension (standard for GPT-style MLPs).
        hidden_dim = 4 * d_model
        # Weights for all experts at once (consumed by torch.bmm in forward).
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))
        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)
        # Activation function.
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs an extra gate weight.
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unbekannte Aktivierung: {activation}")
        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation
        # BUGFIX: the parameters above were allocated with torch.empty() but
        # never initialized, leaving them filled with arbitrary memory
        # (possibly NaN/inf). Initialize them explicitly.
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize all expert weights (GPT-2 style N(0, 0.02)) and zero the biases."""
        nn.init.normal_(self.w_fc, mean=0.0, std=0.02)
        nn.init.normal_(self.w_proj, mean=0.0, std=0.02)
        if self.activation_type == "swiglu":
            nn.init.normal_(self.w_gate, mean=0.0, std=0.02)
        if self.bias:
            nn.init.zeros_(self.fc_bias)
            nn.init.zeros_(self.proj_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]
        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First linear layer via batched matmul (one matmul per expert).
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias
        # Activation.
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)
        # Second linear layer.
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias
        output = self.dropout(output)
        return output
class MoELayer(nn.Module):
    """
    Complete mixture-of-experts layer.

    Combines a noisy top-k router with a batch of expert MLPs and computes
    the two auxiliary router losses (load balancing and z-loss).
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        """
        Args:
            d_model: Model (embedding) dimension.
            n_experts: Total number of experts.
            n_experts_active: Experts selected per token (k).
            use_noisy_gating: Enable noisy top-k gating during training.
            capacity_factor: Per-expert token-budget multiplier.
            bias: Add bias terms to the expert MLPs.
            dropout: Dropout probability inside the expert MLPs.
            activation: Expert activation ("gelu", "relu" or "swiglu").
        """
        super().__init__()
        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )
        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: scalar load-balancing loss
            router_z_loss: scalar router z-loss
        """
        bsz, seqlen, dim = x.shape
        n_tok = bsz * seqlen

        # Dispatch tokens to experts.
        weights, mask, batches, logits = self.router(x)

        # Run every expert on its gathered batch: [n_experts, capacity, d_model].
        processed = self.experts(batches)

        # Weighted recombination back to token order: each token's output is
        # the gating-weighted sum over its occupied (expert, slot) positions.
        flat_weights = weights.view(n_tok, -1)        # [n_tok, n_experts * capacity]
        flat_outputs = processed.view(-1, dim)        # [n_experts * capacity, d_model]
        combined = (flat_weights @ flat_outputs).view(bsz, seqlen, dim)

        # Auxiliary router losses.
        lb_loss = self._compute_load_balance_loss(logits, mask)
        z_loss = self._compute_router_z_loss(logits)
        return combined, lb_loss, z_loss

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing loss (Switch Transformer, Fedus et al. 2022).
        Encourages a uniform distribution of tokens across experts.
        """
        bsz, seqlen, _ = router_logits.shape
        n_tok = bsz * seqlen
        # Mean routing probability per expert.
        mean_prob = F.softmax(router_logits, dim=-1).mean(dim=(0, 1))  # [n_experts]
        # Fraction of tokens actually dispatched to each expert (non-differentiable).
        with torch.no_grad():
            # expert_mask: [n_tok, n_experts, capacity]
            load = expert_mask.float().sum(dim=(0, 2)) / (n_tok * self.n_experts_active)
        # Scaled dot product of probability mass and realized load.
        return self.n_experts * torch.sum(mean_prob * load)

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Router z-loss (ST-MoE, Zoph et al. 2022).
        Penalizes large router logits for numerical stability.
        """
        # Mean over tokens of the squared logsumexp across experts.
        return router_logits.logsumexp(dim=-1).square().mean()