arnomatic commited on
Commit
7e2e7b9
·
verified ·
1 Parent(s): db1a28e

Upload 3 files

Browse files
Files changed (3) hide show
  1. moe_config.py +119 -0
  2. moe_layers.py +323 -0
  3. moe_model.py +460 -0
moe_config.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace-compatible MoE Configuration
3
+ Basierend auf dem nanoMoE Blog Post
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
class MoEGPTConfig(PretrainedConfig):
    """
    Configuration for the MoE-based GPT model.

    Args:
        vocab_size (int): Vocabulary size.
        n_positions (int): Maximum sequence length.
        n_embd (int): Embedding dimensionality (d in the blog post).
        n_layer (int): Number of transformer blocks.
        n_head (int): Number of attention heads.
        n_experts (int): Number of experts per MoE layer.
        n_experts_active (int): Number of active experts (top-k).
        moe_layer_frequency (int): Every n-th layer becomes an MoE layer (P in the blog post).
        capacity_factor (float): Expert capacity factor during training.
        eval_capacity_factor (float): Expert capacity factor during evaluation.
        use_noisy_gating (bool): Whether to use noisy top-k gating.
        aux_loss_alpha (float): Scale of the load-balancing loss.
        router_z_loss_alpha (float): Scale of the router z-loss.
        bias (bool): Whether linear layers use a bias term.
        dropout (float): Dropout probability.
        activation_function (str): Activation function (gelu, relu, swiglu).
        initializer_range (float): Std used for weight initialization.
        layer_norm_epsilon (float): Epsilon for layer normalization.
        use_cache (bool): Stored for HuggingFace generation compatibility.
        rope_theta (float): Base theta for the RoPE frequency table.

    Raises:
        ValueError: If the head / expert / layer settings are inconsistent.
    """

    model_type = "moe_gpt"

    def __init__(
        self,
        vocab_size=128256,  # Llama 3.2 tokenizer (incl. special tokens)
        n_positions=2048,  # default 2048 for RoPE
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_experts=8,
        n_experts_active=2,
        moe_layer_frequency=2,
        capacity_factor=1.25,
        eval_capacity_factor=2.0,
        use_noisy_gating=True,
        aux_loss_alpha=0.01,
        router_z_loss_alpha=0.001,
        bias=False,
        dropout=0.1,
        activation_function="gelu",
        initializer_range=0.1,
        layer_norm_epsilon=1e-5,
        use_cache=True,
        rope_theta=10000.0,  # RoPE base theta
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Validate up front, and raise ValueError instead of `assert` so the
        # checks are not silently stripped when Python runs with -O.
        if n_embd % n_head != 0:
            raise ValueError("n_embd muss durch n_head teilbar sein")
        if n_experts_active > n_experts:
            raise ValueError("n_experts_active darf nicht größer als n_experts sein")
        if moe_layer_frequency < 1:
            raise ValueError("moe_layer_frequency muss mindestens 1 sein")

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.moe_layer_frequency = moe_layer_frequency
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.use_noisy_gating = use_noisy_gating
        self.aux_loss_alpha = aux_loss_alpha
        self.router_z_loss_alpha = router_z_loss_alpha
        self.bias = bias
        self.dropout = dropout
        self.activation_function = activation_function
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Standard HuggingFace attribute aliases (needed by .generate())
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions

    @property
    def _num_moe_layers(self):
        """Number of MoE layers: layer 0, then every `moe_layer_frequency`-th layer."""
        return sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)

    @property
    def head_dim(self):
        """Dimension per attention head."""
        return self.n_embd // self.n_head

    @property
    def total_experts(self):
        """Total number of experts across the whole model."""
        return self._num_moe_layers * self.n_experts

    @property
    def active_parameters_ratio(self):
        """Approximate ratio of active to total parameters (FFN-only estimate)."""
        num_moe_layers = self._num_moe_layers
        num_dense_layers = self.n_layer - num_moe_layers

        # Simplified estimate that ignores attention parameters:
        # each FFN has ~8*d^2 parameters (two d x 4d projections).
        dense_params = num_dense_layers * (8 * self.n_embd**2)
        moe_total_params = num_moe_layers * self.n_experts * (8 * self.n_embd**2)
        moe_active_params = num_moe_layers * self.n_experts_active * (8 * self.n_embd**2)

        total = dense_params + moe_total_params
        active = dense_params + moe_active_params

        return active / total if total > 0 else 1.0
moe_layers.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MoE Layer Komponenten
3
+ Basierend auf dem nanoMoE Blog Post und HuggingFace Best Practices
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Tuple, Optional
11
+
12
+
13
class MoERouter(nn.Module):
    """
    Noisy top-k router for MoE.

    Routes each token to its top-k experts based on learned gating
    probabilities and packs the selected tokens into fixed-capacity
    per-expert batches.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        n_experts_active: int,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor

        # Linear projections for the router (no bias, see Shazeer et al. 2017)
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            expert_weights: combine weights per expert slot [num_tokens, n_experts, capacity]
            expert_mask: boolean mask of occupied slots [num_tokens, n_experts, capacity]
            expert_batches: packed token batches per expert [n_experts, capacity, d_model]
            router_logits: router logits for the z-loss [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()

            # Compute router logits
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]

            # Noisy top-k gating (optional; training only, so eval is deterministic)
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise

            # Select the top-k experts per token
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]

            # Softmax over all experts, with non-top-k entries masked to -inf
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]

            # Compute the per-expert token capacity
            capacity = self._compute_capacity(num_tokens)

            # Multi-hot mask of the chosen experts
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]

            # Position of each token within its expert's batch
            # (cumsum over the flattened [K * num_tokens] axis gives top-1
            # assignments priority over lower-ranked choices)
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )

            # Drop tokens that exceed the expert capacity
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)

            # Slot index within the expert batch
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]

            # Multiply probabilities with the (capacity-limited) mask
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]

            # One-hot encoding of the slot index within each expert batch
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]

            # Scatter the combine weights to their expert-batch slots
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            expert_mask = expert_weights.bool()

            # Build the packed per-expert token batches (matmul gathers tokens)
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]

        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Compute the per-expert token capacity for this batch."""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # round up to an even number for better hardware utilization
        return max(int(capacity), 2)  # minimum of 2 so very small batches still route
136
+
137
+
138
class ExpertMLP(nn.Module):
    """
    A batch of MLP experts.

    All experts share the same architecture but have independent weights,
    stored stacked along a leading expert dimension so the forward pass is a
    single batched matmul (``torch.bmm``) per projection.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias

        # 4x hidden dimension (GPT standard)
        hidden_dim = 4 * d_model

        # Stacked weights for all experts (consumed via batched matmul)
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))

        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)

        # Activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs an extra gate projection
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unbekannte Aktivierung: {activation}")

        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation

        # BUG FIX: all parameters above are allocated with torch.empty() and
        # were never initialized anywhere — PreTrainedModel._init_weights is
        # applied per nn.Module and never visits raw nn.Parameter tensors, so
        # the experts previously ran on uninitialized memory.
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Fan-in-scaled truncated-normal init for weights, zeros for biases."""
        for name, param in self.named_parameters(recurse=False):
            if "bias" in name:
                nn.init.zeros_(param)
            else:
                # For x @ W the fan-in is the second-to-last dim of W.
                std = param.shape[-2] ** -0.5
                nn.init.trunc_normal_(param, mean=0.0, std=std, a=-2 * std, b=2 * std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]

        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First projection, all experts at once via batched matmul
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias

        # Activation
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)

        # Second projection
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias

        return self.dropout(output)
216
+
217
+
218
class MoELayer(nn.Module):
    """
    Complete Mixture-of-Experts layer.

    Combines the router with the batched experts and returns the auxiliary
    losses needed for stable MoE training.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )

        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )

        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: scalar load-balancing loss
            router_z_loss: scalar router z-loss
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # Routing
        expert_weights, expert_mask, expert_batches, router_logits = self.router(x)

        # Expert forward pass
        expert_outputs = self.experts(expert_batches)  # [n_experts, capacity, d_model]

        # Combine expert outputs: a single matmul scatters each slot's output
        # back to its token, weighted by the router probability
        expert_weights_flat = expert_weights.view(num_tokens, -1)  # [num_tokens, n_experts * capacity]
        expert_outputs_flat = expert_outputs.view(-1, d_model)  # [n_experts * capacity, d_model]
        output = expert_weights_flat @ expert_outputs_flat  # [num_tokens, d_model]
        output = output.view(batch_size, seq_len, d_model)

        # Compute the auxiliary losses
        load_balance_loss = self._compute_load_balance_loss(router_logits, expert_mask)
        router_z_loss = self._compute_router_z_loss(router_logits)

        return output, load_balance_loss, router_z_loss

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing loss (Switch Transformer, Fedus et al. 2022).
        Encourages a uniform distribution of tokens across experts.
        """
        batch_size, seq_len, n_experts = router_logits.shape
        num_tokens = batch_size * seq_len

        # Mean router probability per expert
        router_probs = F.softmax(router_logits, dim=-1)  # [B, T, n_experts]
        prob_per_expert = torch.mean(router_probs, dim=(0, 1))  # [n_experts]

        # Fraction of tokens actually routed to each expert
        # (no grad: the hard assignment is not differentiable)
        with torch.no_grad():
            # expert_mask is [num_tokens, n_experts, capacity]
            tokens_per_expert = torch.sum(expert_mask.float(), dim=(0, 2))  # [n_experts]
            tokens_per_expert = tokens_per_expert / (num_tokens * self.n_experts_active)

        # Dot product (scaled by n_experts so a uniform distribution gives ~1)
        loss = self.n_experts * torch.sum(prob_per_expert * tokens_per_expert)

        return loss

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Router z-loss (ST-MoE, Zoph et al. 2022).
        Penalizes large router logits for numerical stability.
        """
        # Squared logsumexp over the expert dimension
        z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0  # [B, T]
        z_loss = torch.mean(z_loss)

        return z_loss
moe_model.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MoE GPT Model - HuggingFace kompatibel
3
+ Basiert auf nanoMoE und dem Blog Post
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple, Union
11
+ from dataclasses import dataclass
12
+
13
+ from transformers import PreTrainedModel
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+
17
+ from moe_config import MoEGPTConfig
18
+ from moe_layers import MoELayer
19
+
20
+
21
@dataclass
class MoECausalLMOutput(CausalLMOutputWithPast):
    """
    Causal-LM output extended with the MoE-specific auxiliary losses.
    """

    # Load-balancing loss summed over all MoE layers (set during training only).
    aux_loss: Optional[torch.FloatTensor] = None
    # Router z-loss summed over all MoE layers (set during training only).
    router_z_loss: Optional[torch.FloatTensor] = None
29
+
30
+
31
def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to *x*.

    Args:
        x: Input tensor of shape [B, H, T, D].
        freqs_cos: Cosine table of shape [T, D//2].
        freqs_sin: Sine table of shape [T, D//2].

    Returns:
        Tensor of the same shape and dtype as *x* with RoPE applied.
    """
    # Pair up adjacent feature dims: [..., D] -> [..., D//2, 2], then split
    # each pair into the real and imaginary half of a complex number.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    real, imag = pairs.unbind(dim=-1)

    # Complex rotation: (a + bi) * (cos + i*sin)
    rotated_real = real * freqs_cos - imag * freqs_sin
    rotated_imag = real * freqs_sin + imag * freqs_cos

    # Interleave the halves back into the original layout and restore dtype.
    out = torch.stack((rotated_real, rotated_imag), dim=-1).flatten(-2)
    return out.type_as(x)
56
+
57
+
58
def precompute_freqs_rope(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the RoPE cosine/sine tables.

    Args:
        dim: Head dimension.
        max_seq_len: Maximum sequence length.
        theta: RoPE base used for the geometric frequency spacing.

    Returns:
        Tuple ``(freqs_cos, freqs_sin)``, each of shape [max_seq_len, dim//2].
    """
    # One inverse frequency per dimension pair, geometrically spaced by theta.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Angle table: outer product of positions and inverse frequencies.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # [max_seq_len, dim//2]

    return torch.cos(angles), torch.sin(angles)
84
+
85
+
86
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with rotary position embeddings (RoPE).

    Q/K/V come from one fused projection; the attention itself is delegated to
    PyTorch's scaled_dot_product_attention (SDPA), which handles the causal
    mask and attention dropout internally.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Fused projection producing query, key and value for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization.
        # NOTE(review): attn_dropout is never used in forward() — SDPA receives
        # dropout_p directly; kept so the module layout stays unchanged.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = config.n_embd // config.n_head

        # Precompute the RoPE cos/sin tables once; non-persistent buffers move
        # with .to(device) but are excluded from the state dict.
        freqs_cos, freqs_sin = precompute_freqs_rope(
            dim=self.head_dim,
            max_seq_len=config.n_positions,
            theta=config.rope_theta,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, embd = x.size()  # batch, sequence length, embedding dim

        # Fused QKV projection, then split and reshape to [B, H, T, head_dim]
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        query = query.view(bsz, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        key = key.view(bsz, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        value = value.view(bsz, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        # Rotate queries and keys with the precomputed RoPE tables
        cos, sin = self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
        query = apply_rotary_emb(query, cos, sin)
        key = apply_rotary_emb(key, cos, sin)

        # Memory-efficient attention; is_causal=True applies the causal mask
        attn_out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,  # causal masking handled by is_causal
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )  # [B, H, T, head_dim]

        # Merge the heads back and apply the output projection
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, embd)
        return self.resid_dropout(self.c_proj(attn_out))
150
+
151
+
152
class MLP(nn.Module):
    """
    Position-wise feed-forward network for the dense (non-MoE) blocks:
    Linear -> activation -> Linear -> Dropout, with a 4x hidden expansion.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        # Look up the activation class; anything unknown is rejected.
        act_cls = {"gelu": nn.GELU, "relu": nn.ReLU}.get(config.activation_function)
        if act_cls is None:
            raise ValueError(f"Unbekannte Aktivierung: {config.activation_function}")
        self.activation = act_cls()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(self.activation(self.c_fc(x))))
176
+
177
+
178
class TransformerBlock(nn.Module):
    """
    Dense pre-norm transformer block: self-attention followed by an MLP,
    each wrapped in a residual connection.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual layout: x + f(norm(x))
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
194
+
195
+
196
class MoETransformerBlock(nn.Module):
    """
    Transformer block whose feed-forward sublayer is a mixture-of-experts.

    Returns the block output together with the MoE auxiliary losses so the
    model can accumulate them across layers.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # NOTE(review): only config.capacity_factor is wired through here;
        # config.eval_capacity_factor is never used — confirm whether a larger
        # eval-time capacity was intended.
        self.moe = MoELayer(
            d_model=config.n_embd,
            n_experts=config.n_experts,
            n_experts_active=config.n_experts_active,
            use_noisy_gating=config.use_noisy_gating,
            capacity_factor=config.capacity_factor,
            bias=config.bias,
            dropout=config.dropout,
            activation=config.activation_function,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Attention sublayer with residual connection
        x = x + self.attn(self.ln_1(x))

        # MoE sublayer with residual; auxiliary losses are passed upward
        moe_out, aux_loss, router_z_loss = self.moe(self.ln_2(x))
        return x + moe_out, aux_loss, router_z_loss
230
+
231
+
232
class MoEGPTPreTrainedModel(PreTrainedModel):
    """
    Base class wiring MoE GPT into the HuggingFace PreTrainedModel API.
    """

    config_class = MoEGPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Weight initialization following ST-MoE (Zoph et al. 2022):
        truncated normal with a fan-in-scaled, reduced std for MoE stability.

        BUG FIX: the previous ``isinstance(module, nn.Parameter)`` branch was
        dead code — ``Module.apply`` only visits nn.Module instances, so raw
        nn.Parameter tensors (e.g. the stacked expert weights in ExpertMLP)
        were never initialized. They are now handled via the module that owns
        them, in the fallback branch below.
        """
        if isinstance(module, nn.Linear):
            # Fan-in-scaled initialization
            fan_in = module.weight.shape[-1]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

        elif not isinstance(module, (nn.LayerNorm, nn.Dropout)):
            # Initialize raw (>= 2-D) parameters owned directly by other
            # modules, e.g. the stacked expert weight tensors. 1-D parameters
            # (e.g. norm scales) are intentionally left at their defaults.
            for name, param in module.named_parameters(recurse=False):
                if param.dim() < 2:
                    continue
                if "bias" in name:
                    torch.nn.init.zeros_(param)
                else:
                    # For x @ W the fan-in is the second-to-last dim of W.
                    fan_in = param.shape[-2]
                    std = (self.config.initializer_range / fan_in) ** 0.5
                    torch.nn.init.trunc_normal_(
                        param,
                        mean=0.0,
                        std=std,
                        a=-2 * std,
                        b=2 * std,
                    )
276
+
277
+
278
class MoEGPTModel(MoEGPTPreTrainedModel):
    """
    MoE GPT backbone (without the LM head).

    Alternates dense and MoE transformer blocks according to
    ``config.moe_layer_frequency`` and accumulates the MoE auxiliary losses.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False  # for HF gradient-checkpointing support

        # Token embeddings only (positions are handled by RoPE in the attention layers)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks (mixed: dense + MoE; layer 0 is always MoE)
        self.h = nn.ModuleList()
        for i in range(config.n_layer):
            if i % config.moe_layer_frequency == 0:
                # MoE block
                self.h.append(MoETransformerBlock(config))
            else:
                # Dense block
                self.h.append(TransformerBlock(config))

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        # NOTE(review): attention_mask is accepted for API compatibility but
        # is not used anywhere in this forward pass — confirm padding handling.
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            hidden_states: [B, T, n_embd]
            total_aux_loss: sum of load-balancing losses over all MoE blocks
            total_router_z_loss: sum of router z-losses over all MoE blocks
        """
        device = input_ids.device  # NOTE(review): currently unused
        b, t = input_ids.size()

        assert t <= self.config.n_positions, f"Sequenz zu lang: {t} > {self.config.n_positions}"

        # Token embeddings only (RoPE is applied inside the attention layers)
        tok_emb = self.wte(input_ids)  # [B, T, n_embd]
        x = self.drop(tok_emb)

        # Accumulators for the auxiliary losses
        total_aux_loss = 0.0
        total_router_z_loss = 0.0

        # Run through all blocks
        for block in self.h:
            if isinstance(block, MoETransformerBlock):
                if self.gradient_checkpointing and self.training:
                    # Gradient checkpointing for MoE blocks; the wrapper keeps
                    # the multi-output signature intact through checkpoint()
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    x, aux_loss, router_z_loss = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        use_reentrant=False
                    )
                else:
                    x, aux_loss, router_z_loss = block(x)
                total_aux_loss = total_aux_loss + aux_loss
                total_router_z_loss = total_router_z_loss + router_z_loss
            else:
                if self.gradient_checkpointing and self.training:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x,
                        use_reentrant=False
                    )
                else:
                    x = block(x)

        x = self.ln_f(x)

        return x, total_aux_loss, total_router_z_loss
358
+
359
+
360
class MoEGPTForCausalLM(MoEGPTPreTrainedModel, GenerationMixin):
    """
    MoE GPT with a language-modeling head (for pretraining).
    Inherits from GenerationMixin for .generate() support.
    """

    # Tell HuggingFace which weights are tied/shared
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.transformer = MoEGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (the LM head shares its weights with the token embedding)
        self.lm_head.weight = self.transformer.wte.weight

        # Initialize weights
        self.post_init()

    def get_output_embeddings(self):
        """For HuggingFace weight tying."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """For HuggingFace weight tying."""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.transformer.wte = new_embeddings

    def tie_weights(self):
        """
        Tie lm_head weights to the input embeddings (weight tying).
        Called after loading a checkpoint to fix a missing lm_head.weight.
        """
        self.lm_head.weight = self.transformer.wte.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # accept additional kwargs like use_cache for HuggingFace compatibility
    ) -> Union[Tuple, MoECausalLMOutput]:
        """
        Args:
            input_ids: Token ids [B, T].
            attention_mask: Passed through to the backbone.
            labels: If given, the shifted next-token cross-entropy loss is
                computed (plus the MoE auxiliary losses in training mode).
            return_dict: Whether to return a MoECausalLMOutput instead of a tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through the transformer backbone
        hidden_states, aux_loss, router_z_loss = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # LM head
        if labels is not None:
            # Training: logits for every position (needed for the shifted loss below)
            logits = self.lm_head(hidden_states)
        else:
            # Inference: only the last position is needed for sampling
            logits = self.lm_head(hidden_states[:, [-1], :])

        # Compute the loss
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Add the MoE auxiliary losses (training only)
            loss = lm_loss
            if self.training:
                loss = loss + self.config.aux_loss_alpha * aux_loss
                loss = loss + self.config.router_z_loss_alpha * router_z_loss

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MoECausalLMOutput(
            loss=loss,
            logits=logits,
            aux_loss=aux_loss if self.training else None,
            router_z_loss=router_z_loss if self.training else None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """
        For the HuggingFace generate() function.

        NOTE(review): all extra kwargs (e.g. attention_mask, past_key_values)
        are dropped here — only input_ids are forwarded.
        """
        return {"input_ids": input_ids}