"""
MoE GPT model - HuggingFace compatible.

Based on nanoMoE and the blog post.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from typing import Optional, Tuple, Union
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from moe_config import MoEGPTConfig
from moe_layers import MoELayer


@dataclass
class MoECausalLMOutput(CausalLMOutputWithPast):
    """
    Extended output class with MoE-specific losses.
    """

    aux_loss: Optional[torch.FloatTensor] = None
    router_z_loss: Optional[torch.FloatTensor] = None


def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Applies Rotary Position Embeddings (RoPE) to an input tensor.

    Args:
        x: Input tensor of shape [B, H, T, D]
        freqs_cos: Cosine frequencies of shape [T, D//2]
        freqs_sin: Sine frequencies of shape [T, D//2]

    Returns:
        Tensor with RoPE applied
    """
    # Interpret the last dimension as D//2 (real, imaginary) pairs.
    x_complex = x.float().reshape(*x.shape[:-1], -1, 2)

    # Rotate each pair by the precomputed angle (complex multiplication).
    x_rot_real = x_complex[..., 0] * freqs_cos - x_complex[..., 1] * freqs_sin
    x_rot_imag = x_complex[..., 0] * freqs_sin + x_complex[..., 1] * freqs_cos

    # Re-interleave the rotated pairs into the last dimension.
    x_out = torch.stack([x_rot_real, x_rot_imag], dim=-1)
    x_out = x_out.flatten(-2)

    return x_out.type_as(x)


def precompute_freqs_rope(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precomputes RoPE frequencies.

    Args:
        dim: Head dimension
        max_seq_len: Maximum sequence length
        theta: RoPE theta parameter (base for the frequency calculation)

    Returns:
        Tuple of (freqs_cos, freqs_sin) tensors of shape [max_seq_len, dim//2]
    """
    # Per-pair base frequencies: theta^(-2i/dim) for i = 0 .. dim//2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Outer product of positions and frequencies -> one angle per (position, pair).
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)

    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)

    return freqs_cos, freqs_sin
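

# Minimal usage sketch of the two RoPE helpers above (illustrative only; the
# shapes follow the docstrings, the concrete numbers are made up):
#
#   freqs_cos, freqs_sin = precompute_freqs_rope(dim=64, max_seq_len=128)
#   q = torch.randn(2, 8, 16, 64)                     # [B, H, T, head_dim]
#   q_rot = apply_rotary_emb(q, freqs_cos[:16], freqs_sin[:16])
#   assert q_rot.shape == q.shape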


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with Rotary Position Embeddings (RoPE).
    Uses PyTorch SDPA for optimized performance.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Combined query, key, value projection for all heads.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Attention dropout is applied inside SDPA via dropout_p below;
        # self.attn_dropout itself is not used in forward.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = config.n_embd // config.n_head

        # Precompute RoPE frequencies once and keep them as non-persistent buffers.
        freqs_cos, freqs_sin = precompute_freqs_rope(
            dim=self.head_dim,
            max_seq_len=config.n_positions,
            theta=config.rope_theta,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()

        # Project to queries, keys, values and split heads: [B, n_head, T, head_dim].
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys.
        q = apply_rotary_emb(q, self.freqs_cos[:T], self.freqs_sin[:T])
        k = apply_rotary_emb(k, self.freqs_cos[:T], self.freqs_sin[:T])

        # Fused scaled dot-product attention with a causal mask.
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )

        # Merge heads back: [B, T, C].
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection and residual dropout.
        y = self.resid_dropout(self.c_proj(y))

        return y
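

# Shape sketch for the attention block above (illustrative; the MoEGPTConfig
# keyword arguments are assumed, adjust them to the actual config signature):
#
#   config = MoEGPTConfig(n_embd=256, n_head=8, n_positions=128)
#   attn = CausalSelfAttention(config)
#   x = torch.randn(4, 32, 256)        # [B, T, n_embd]
#   y = attn(x)                        # same shape: [4, 32, 256]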


class MLP(nn.Module):
    """
    Standard feed-forward network (for the non-MoE layers).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        if config.activation_function == "gelu":
            self.activation = nn.GELU()
        elif config.activation_function == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation: {config.activation_function}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Standard transformer block (attention + MLP).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class MoETransformerBlock(nn.Module):
    """
    MoE transformer block (attention + MoE layer).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.moe = MoELayer(
            d_model=config.n_embd,
            n_experts=config.n_experts,
            n_experts_active=config.n_experts_active,
            use_noisy_gating=config.use_noisy_gating,
            capacity_factor=config.capacity_factor,
            bias=config.bias,
            dropout=config.dropout,
            activation=config.activation_function,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = x + self.attn(self.ln_1(x))

        # The MoE layer returns its output plus the two auxiliary router losses.
        moe_out, aux_loss, router_z_loss = self.moe(self.ln_2(x))
        x = x + moe_out

        return x, aux_loss, router_z_loss
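

# Unlike TransformerBlock, this block returns a 3-tuple; the caller accumulates
# the auxiliary losses across all MoE blocks (see MoEGPTModel.forward below):
#
#   x, aux_loss, router_z_loss = moe_block(x)
#   total_aux_loss = total_aux_loss + aux_loss
#   total_router_z_loss = total_router_z_loss + router_z_loss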


class MoEGPTPreTrainedModel(PreTrainedModel):
    """
    Base class for MoE GPT, built on the HuggingFace PreTrainedModel API.
    """

    config_class = MoEGPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Weight initialization following ST-MoE (Zoph et al., 2022):
        truncated normal with a reduced std for MoE training stability.
        """
        if isinstance(module, nn.Linear):
            # std = sqrt(initializer_range / fan_in), truncated at +/- 2 std.
            fan_in = module.weight.shape[-1]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

        elif isinstance(module, nn.Parameter):
            # Bare parameter tensors (if _init_weights is ever called on one directly).
            fan_in = module.shape[-1] if len(module.shape) >= 2 else module.shape[0]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
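

# Worked example for the ST-MoE initialization above (the initializer_range
# value is an assumed illustration, not necessarily the config default):
#
#   initializer_range = 0.1, fan_in = 768
#   std = sqrt(0.1 / 768) ≈ 0.0114, truncated to [-0.0228, 0.0228]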


class MoEGPTModel(MoEGPTPreTrainedModel):
    """
    MoE GPT model (without the LM head).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        # Token embeddings (positions are encoded via RoPE inside the attention blocks).
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        # Every moe_layer_frequency-th block is an MoE block, the rest use a dense MLP.
        self.h = nn.ModuleList()
        for i in range(config.n_layer):
            if i % config.moe_layer_frequency == 0:
                self.h.append(MoETransformerBlock(config))
            else:
                self.h.append(TransformerBlock(config))

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device = input_ids.device
        b, t = input_ids.size()

        assert t <= self.config.n_positions, f"Sequence too long: {t} > {self.config.n_positions}"

        # Note: attention_mask is accepted for API compatibility but not applied;
        # attention is always causal (no padding mask).
        tok_emb = self.wte(input_ids)
        x = self.drop(tok_emb)

        # Accumulate the auxiliary losses over all MoE blocks.
        total_aux_loss = 0.0
        total_router_z_loss = 0.0

        for block in self.h:
            if isinstance(block, MoETransformerBlock):
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    x, aux_loss, router_z_loss = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        use_reentrant=False,
                    )
                else:
                    x, aux_loss, router_z_loss = block(x)
                total_aux_loss = total_aux_loss + aux_loss
                total_router_z_loss = total_router_z_loss + router_z_loss
            else:
                if self.gradient_checkpointing and self.training:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x,
                        use_reentrant=False,
                    )
                else:
                    x = block(x)

        x = self.ln_f(x)

        return x, total_aux_loss, total_router_z_loss
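

# Layer layout example: with n_layer = 6 and moe_layer_frequency = 2, the block
# list is [MoE, dense, MoE, dense, MoE, dense] (indices 0, 2, 4 are MoE blocks).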


class MoEGPTForCausalLM(MoEGPTPreTrainedModel, GenerationMixin):
    """
    MoE GPT with a language modeling head (for pretraining).
    Inherits from GenerationMixin for .generate() support.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.transformer = MoEGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the LM head shares its weights with the token embeddings.
        self.lm_head.weight = self.transformer.wte.weight

        self.post_init()

    def get_output_embeddings(self):
        """For HuggingFace weight tying."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """For HuggingFace weight tying."""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """For HuggingFace weight tying."""
        self.transformer.wte = new_embeddings

    def tie_weights(self):
        """
        Tie the lm_head weights to the input embeddings (weight tying).
        Called after loading a checkpoint to fix a missing lm_head.weight.
        """
        self.lm_head.weight = self.transformer.wte.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MoECausalLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states, aux_loss, router_z_loss = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        if labels is not None:
            # Training: compute logits for all positions.
            logits = self.lm_head(hidden_states)
        else:
            # Inference: only the last position is needed (e.g. for generation).
            logits = self.lm_head(hidden_states[:, [-1], :])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Total loss = LM loss + weighted MoE auxiliary losses (training only).
            loss = lm_loss
            if self.training:
                loss = loss + self.config.aux_loss_alpha * aux_loss
                loss = loss + self.config.router_z_loss_alpha * router_z_loss

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MoECausalLMOutput(
            loss=loss,
            logits=logits,
            aux_loss=aux_loss if self.training else None,
            router_z_loss=router_z_loss if self.training else None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """For the HuggingFace generate() function."""
        return {"input_ids": input_ids}
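

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes MoEGPTConfig accepts these keyword
    # arguments; adjust to the actual config signature if it differs).
    config = MoEGPTConfig(
        vocab_size=256,
        n_layer=4,
        n_head=4,
        n_embd=128,
        n_positions=64,
        n_experts=4,
        n_experts_active=2,
        moe_layer_frequency=2,
    )
    model = MoEGPTForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))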