import math
from typing import Tuple

import torch
import torch.nn as nn
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_qmoe import QMoEConfig


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescales by 1 / RMS(x) with a learned per-channel gain."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.scale


class DenseNoBias(nn.Module):
    """Bias-free linear layer; the kernel is stored as (in_features, out_features) so forward is x @ kernel."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.kernel = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        nn.init.normal_(self.kernel, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.kernel


def causal_mask(t: int, *, device: torch.device) -> torch.Tensor:
    # Lower-triangular boolean mask: query position i may attend to key positions j <= i.
    return torch.tril(torch.ones((t, t), dtype=torch.bool, device=device))


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError('d_model must be divisible by num_heads')
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = DenseNoBias(d_model, d_model)
        self.k_proj = DenseNoBias(d_model, d_model)
        self.v_proj = DenseNoBias(d_model, d_model)
        self.out_proj = DenseNoBias(d_model, d_model)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        # Project and split into heads: (b, t, num_heads, head_dim).
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)
        # Scaled dot-product attention scores: (b, num_heads, t, t).
        scale = 1.0 / math.sqrt(self.head_dim)
        att = torch.einsum('bthd,bshd->bhts', q, k) * scale
        # Masked (future) positions get a large negative score before the softmax.
        att = att.masked_fill(~attn_mask.view(1, 1, t, t), -1e30)
        att = torch.softmax(att, dim=-1)
        # Weighted sum of values, then merge heads back to (b, t, d_model).
        out = torch.einsum('bhts,bshd->bthd', att, v).contiguous()
        out = out.view(b, t, d)
        return self.out_proj(out)


class Router(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = DenseNoBias(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Softmax over experts, keep the top_k per token, and renormalize so the kept gates sum to 1.
        logits = self.gate(x)
        probs = torch.softmax(logits, dim=-1)
        topk_vals, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        denom = topk_vals.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        gates = topk_vals / denom
        return topk_idx, gates


class ExpertMLPBank(nn.Module):
    """All expert MLPs stored as stacked parameter tensors, indexed by expert id along dim 0."""

    def __init__(self, d_model: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, hidden_dim, dtype=torch.float32))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_dim, dtype=torch.float32))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, d_model, dtype=torch.float32))
        self.b2 = nn.Parameter(torch.zeros(num_experts, d_model, dtype=torch.float32))
        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        # Gather each token's expert weights, then apply a two-layer SiLU MLP per token.
        w1 = self.w1.index_select(0, expert_idx)
        b1 = self.b1.index_select(0, expert_idx)
        w2 = self.w2.index_select(0, expert_idx)
        b2 = self.b2.index_select(0, expert_idx)
        h = torch.einsum('nd,ndh->nh', x, w1) + b1
        h = torch.nn.functional.silu(h)
        y = torch.einsum('nh,nhd->nd', h, w2) + b2
        return y


class MoEFeedForward(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.router = Router(d_model=d_model, num_experts=num_experts, top_k=top_k)
        self.experts = ExpertMLPBank(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        topk_idx, gates = self.router(x)
        x_flat = x.reshape(b * t, d)
        idx_flat = topk_idx.reshape(b * t, self.top_k)
        gates_flat = gates.reshape(b * t, self.top_k)
        y = torch.zeros_like(x_flat)
        # Accumulate the gated output of each of the top_k selected experts for every token.
        for j in range(self.top_k):
            e_idx = idx_flat[:, j]
            y_j = self.experts(x_flat, e_idx)
            y = y + y_j * gates_flat[:, j : j + 1]
        return y.reshape(b, t, d)


class Block(nn.Module):
    """Pre-norm transformer block: attention and MoE feed-forward, each with a residual connection."""

    def __init__(self, d_model: int, num_heads: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.rmsnorm_0 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.rmsnorm_1 = RMSNorm(d_model)
        self.moe = MoEFeedForward(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts, top_k=top_k)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        h = self.rmsnorm_0(x)
        x = x + self.attn(h, attn_mask=attn_mask)
        h = self.rmsnorm_1(x)
        x = x + self.moe(h)
        return x


class QMoEForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = QMoEConfig
    main_input_name = 'input_ids'

    def __init__(self, config: QMoEConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList(
            [
                Block(config.d_model, config.num_heads, config.ffn_dim, config.num_experts, config.moe_top_k)
                for _ in range(config.num_layers)
            ]
        )
        self.rmsnorm_f = RMSNorm(config.d_model)
        self.lm_head = DenseNoBias(config.d_model, config.vocab_size)
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache: the full sequence is re-run at every decoding step.
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError('input_ids is required')
        b, t = input_ids.shape
        device = input_ids.device
        # Learned token embeddings plus learned absolute position embeddings.
        tok = self.tok_emb(input_ids)
        pos_idx = torch.arange(t, device=device).unsqueeze(0)
        pos = self.pos_emb(pos_idx)
        x = tok + pos
        # Note: attention_mask is accepted for generation compatibility, but padding is not
        # applied here; only the causal mask is used.
        attn_mask = causal_mask(t, device=device)
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
        x = self.rmsnorm_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that the logits at position i are scored against the label at position i + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithCrossAttentions(logits=logits, loss=loss)
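

# Minimal smoke test (a sketch, not part of the modeling code). It assumes QMoEConfig
# accepts the attribute names used above (vocab_size, d_model, max_seq_len, num_heads,
# ffn_dim, num_experts, moe_top_k, num_layers) as keyword arguments, which is typical
# for a PretrainedConfig subclass but is not guaranteed by this file alone. Because of
# the relative import at the top, run it as a module (python -m <package>.modeling_qmoe).
if __name__ == '__main__':
    cfg = QMoEConfig(
        vocab_size=1000,
        d_model=64,
        max_seq_len=128,
        num_heads=4,
        ffn_dim=256,
        num_experts=8,
        moe_top_k=2,
        num_layers=2,
    )
    model = QMoEForCausalLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print(out.logits.shape)  # expected: (2, 16, 1000)
    print(out.loss)          # scalar cross-entropy loss on the shifted labels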