File size: 11,841 Bytes
7e2e7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
MoE layer components.

Based on the nanoMoE blog post and HuggingFace best practices.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class MoERouter(nn.Module):
    """Noisy top-k router for a Mixture-of-Experts layer.

    Routes each token to its top-k experts based on learned gating
    probabilities (Shazeer et al. 2017). Each expert has a fixed
    capacity; tokens routed beyond that capacity are dropped.
    """

    def __init__(

        self,

        d_model: int,

        n_experts: int,

        n_experts_active: int,

        use_noisy_gating: bool = True,

        capacity_factor: float = 1.25,

    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active  # k in "top-k"
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor

        # Linear projections for the router (no bias, see Shazeer et al. 2017)
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(

        self, x: torch.Tensor

    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            expert_weights: Routing weight per expert slot
                [batch_size * seq_len, n_experts, capacity]
            expert_mask: Bool mask of occupied expert slots
                [batch_size * seq_len, n_experts, capacity]
            expert_batches: Per-expert token batches [n_experts, capacity, d_model]
            router_logits: Raw router logits (input to the z-loss)
                [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()

            # Compute router logits
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]

            # Noisy top-k gating (training only): input-dependent Gaussian
            # noise encourages exploration across experts
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise

            # Select the top-k experts per token
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]

            # Softmax over the full expert axis, but with non-top-k entries
            # set to -inf, so all probability mass falls on the top-k experts
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]

            # Compute expert capacity (max tokens any single expert processes)
            capacity = self._compute_capacity(num_tokens)

            # Multi-hot mask of the selected experts
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]

            # Slot index of each token within its expert's batch; the cumsum
            # runs over the flattened (K * num_tokens) axis with K leading, so
            # top-1 assignments claim capacity before top-2, etc.
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )

            # Drop tokens that exceed the expert capacity
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)

            # Collapse to a single slot index per (k, token) pair
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]

            # Zero out probabilities of dropped assignments via the mask
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]

            # One-hot encode the slot position within the expert batch
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]

            # Scatter the routing weights onto their expert-batch slots
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            expert_mask = expert_weights.bool()

            # Build the expert batches (a single matmul acts as the gather)
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]

        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Compute the per-expert capacity (max tokens per expert)."""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # round up to an even number for better hardware utilization
        return max(int(capacity), 2)  # minimum of 2 for very small batches


class ExpertMLP(nn.Module):
    """Batch of MLP experts.

    All experts share the same 2-layer MLP architecture (4x hidden
    expansion) but have independent weights; each layer is evaluated
    for all experts at once with a single batched matmul.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        """
        Args:
            d_model: Model (input/output) dimension.
            n_experts: Number of independent experts.
            bias: Whether the two linear layers use biases.
            dropout: Dropout probability applied to the output.
            activation: One of "gelu", "relu", "swiglu".

        Raises:
            ValueError: If `activation` is not one of the supported names.
        """
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias

        # 4x hidden dimension (GPT standard)
        hidden_dim = 4 * d_model

        # One weight tensor per layer for ALL experts (used via torch.bmm)
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))

        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)

        # Activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs an extra gate projection
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unbekannte Aktivierung: {activation}")

        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation

        # BUG FIX: the parameters above were allocated with torch.empty()
        # but never initialized, leaving them with arbitrary memory
        # (possibly inf/NaN). Initialize them explicitly.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weights GPT-style (N(0, 0.02)) and biases to zero."""
        for name in ("w_fc", "w_proj", "w_gate"):
            weight = getattr(self, name, None)  # w_gate only exists for swiglu
            if weight is not None:
                nn.init.normal_(weight, mean=0.0, std=0.02)
        for name in ("fc_bias", "proj_bias"):
            b = getattr(self, name)
            if b is not None:
                nn.init.zeros_(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]

        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First linear layer via batched matmul
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias

        # Activation
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)

        # Second linear layer
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias

        output = self.dropout(output)

        return output


class MoELayer(nn.Module):
    """Complete Mixture-of-Experts layer.

    Wires a `MoERouter` to a batched `ExpertMLP` and returns, next to
    the combined output, the two standard auxiliary losses (load
    balancing and router z-loss).
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )
        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )

        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: Scalar load balancing loss.
            router_z_loss: Scalar router z-loss.
        """
        bsz, seqlen, dim = x.shape
        n_tok = bsz * seqlen

        # Route tokens into per-expert batches
        expert_weights, expert_mask, expert_batches, router_logits = self.router(x)

        # Run every expert on its batch: [n_experts, capacity, d_model]
        expert_out = self.experts(expert_batches)

        # Weighted recombination: [n_tok, E*C] @ [E*C, d] -> [n_tok, d]
        combined = expert_weights.view(n_tok, -1) @ expert_out.view(-1, dim)
        output = combined.view(bsz, seqlen, dim)

        # Auxiliary losses
        balance = self._compute_load_balance_loss(router_logits, expert_mask)
        z_loss = self._compute_router_z_loss(router_logits)

        return output, balance, z_loss

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """Load balancing loss (Switch Transformer, Fedus et al. 2022).

        Encourages a uniform distribution of tokens across experts.
        """
        bsz, seqlen, n_experts = router_logits.shape
        n_tok = bsz * seqlen

        # Mean routing probability per expert: [n_experts]
        mean_prob = F.softmax(router_logits, dim=-1).mean(dim=(0, 1))

        # Fraction of token assignments handled per expert (no gradient;
        # expert_mask is [n_tok, n_experts, capacity])
        with torch.no_grad():
            assignments = expert_mask.float().sum(dim=(0, 2))  # [n_experts]
            token_frac = assignments / (n_tok * self.n_experts_active)

        # Dot product, scaled by the number of experts
        return self.n_experts * (mean_prob * token_frac).sum()

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """Router z-loss (ST-MoE, Zoph et al. 2022).

        Penalizes large router logits for numerical stability.
        """
        # Mean squared logsumexp over the expert axis
        lse = torch.logsumexp(router_logits, dim=-1)  # [B, T]
        return lse.pow(2).mean()