# QMoE-400 / modeling_qmoe.py
import math
from typing import Tuple

import torch
import torch.nn as nn
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_qmoe import QMoEConfig
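
# Root-mean-square layer normalization: no mean-centering, just a learned per-channel
# scale over the RMS-normalized activations. Used in place of LayerNorm throughout.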
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.scale
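
# Bias-free linear layer. The weight is stored as (in_features, out_features) and
# applied as `x @ kernel`, so it transposes the usual nn.Linear weight layout.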
class DenseNoBias(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.kernel = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        nn.init.normal_(self.kernel, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.kernel
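
# Lower-triangular boolean mask: position i may attend only to positions <= i.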
def causal_mask(t: int, *, device: torch.device) -> torch.Tensor:
    return torch.tril(torch.ones((t, t), dtype=torch.bool, device=device))
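
# Standard multi-head self-attention with bias-free projections and an explicit
# boolean causal mask applied to the attention scores.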
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError('d_model must be divisible by num_heads')
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = DenseNoBias(d_model, d_model)
        self.k_proj = DenseNoBias(d_model, d_model)
        self.v_proj = DenseNoBias(d_model, d_model)
        self.out_proj = DenseNoBias(d_model, d_model)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        # Project to per-head queries/keys/values: (b, t, num_heads, head_dim).
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)
        scale = 1.0 / math.sqrt(self.head_dim)
        # Scaled dot-product scores: (b, num_heads, t, t).
        att = torch.einsum('bthd,bshd->bhts', q, k) * scale
        # Blank out future positions with a large negative value before the softmax.
        att = att.masked_fill(~attn_mask.view(1, 1, t, t), -1e30)
        att = torch.softmax(att, dim=-1)
        out = torch.einsum('bhts,bshd->bthd', att, v).contiguous()
        out = out.view(b, t, d)
        return self.out_proj(out)
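
# Top-k softmax router: scores every expert per token, keeps the k highest-probability
# experts, and renormalizes their probabilities into mixture gates.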
class Router(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = DenseNoBias(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.gate(x)
        probs = torch.softmax(logits, dim=-1)
        topk_vals, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        # Renormalize the selected top-k probabilities so the gates sum to 1 per token.
        denom = topk_vals.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        gates = topk_vals / denom
        return topk_idx, gates
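
# All expert MLPs stored as stacked (num_experts, ...) parameters, so each token's
# selected expert weights can be gathered with index_select and applied via batched einsum.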
class ExpertMLPBank(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, hidden_dim, dtype=torch.float32))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_dim, dtype=torch.float32))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, d_model, dtype=torch.float32))
        self.b2 = nn.Parameter(torch.zeros(num_experts, d_model, dtype=torch.float32))
        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        # Gather each token's expert weights: x is (n, d), expert_idx is (n,).
        w1 = self.w1.index_select(0, expert_idx)
        b1 = self.b1.index_select(0, expert_idx)
        w2 = self.w2.index_select(0, expert_idx)
        b2 = self.b2.index_select(0, expert_idx)
        # Per-token two-layer MLP with SiLU activation.
        h = torch.einsum('nd,ndh->nh', x, w1) + b1
        h = torch.nn.functional.silu(h)
        y = torch.einsum('nh,nhd->nd', h, w2) + b2
        return y
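
# Mixture-of-experts feed-forward layer: the router picks top-k experts per token and
# the expert bank computes the gate-weighted combination of their outputs.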
class MoEFeedForward(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.router = Router(d_model=d_model, num_experts=num_experts, top_k=top_k)
        self.experts = ExpertMLPBank(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        topk_idx, gates = self.router(x)
        x_flat = x.reshape(b * t, d)
        idx_flat = topk_idx.reshape(b * t, self.top_k)
        gates_flat = gates.reshape(b * t, self.top_k)
        y = torch.zeros_like(x_flat)
        # Dense dispatch: run every token through each of its top-k experts
        # and accumulate the gate-weighted outputs.
        for j in range(self.top_k):
            e_idx = idx_flat[:, j]
            y_j = self.experts(x_flat, e_idx)
            y = y + y_j * gates_flat[:, j : j + 1]
        return y.reshape(b, t, d)
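
# One transformer decoder block: pre-norm self-attention followed by a pre-norm MoE
# feed-forward sub-layer, each wrapped in a residual connection.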
class Block(nn.Module):
    def __init__(self, d_model: int, num_heads: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.rmsnorm_0 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.rmsnorm_1 = RMSNorm(d_model)
        self.moe = MoEFeedForward(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts, top_k=top_k)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual layout: normalize, transform, then add back.
        h = self.rmsnorm_0(x)
        x = x + self.attn(h, attn_mask=attn_mask)
        h = self.rmsnorm_1(x)
        x = x + self.moe(h)
        return x
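
# Causal language model over the QMoE decoder stack: learned token and absolute position
# embeddings, a stack of MoE transformer blocks, a final RMSNorm, and a bias-free
# vocabulary projection. GenerationMixin makes it usable with transformers' generate().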
class QMoEForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = QMoEConfig
    main_input_name = 'input_ids'

    def __init__(self, config: QMoEConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([
            Block(config.d_model, config.num_heads, config.ffn_dim, config.num_experts, config.moe_top_k)
            for _ in range(config.num_layers)
        ])
        self.rmsnorm_f = RMSNorm(config.d_model)
        self.lm_head = DenseNoBias(config.d_model, config.vocab_size)
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so the full sequence is re-run at every decoding step.
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError('input_ids is required')
        b, t = input_ids.shape
        device = input_ids.device
        # Learned token + absolute position embeddings (sequence length capped at max_seq_len).
        tok = self.tok_emb(input_ids)
        pos_idx = torch.arange(t, device=device).unsqueeze(0)
        pos = self.pos_emb(pos_idx)
        x = tok + pos
        # Only the causal mask is applied here; a padding `attention_mask` is accepted but not folded in.
        attn_mask = causal_mask(t, device=device)
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
        x = self.rmsnorm_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that position i predicts token i + 1; label -100 is ignored.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithCrossAttentions(logits=logits, loss=loss)
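

# Minimal smoke-test sketch (not part of the model definition). It assumes QMoEConfig
# accepts the keyword arguments read in __init__ above (vocab_size, max_seq_len, d_model,
# num_heads, ffn_dim, num_experts, moe_top_k, num_layers); the tiny sizes below are
# illustrative and are not the QMoE-400 release configuration. Because of the relative
# import at the top of this file, run it as a module inside its package rather than as
# a standalone script.
if __name__ == '__main__':
    config = QMoEConfig(
        vocab_size=256,
        max_seq_len=64,
        d_model=32,
        num_heads=4,
        ffn_dim=64,
        num_experts=4,
        moe_top_k=2,
        num_layers=2,
    )
    model = QMoEForCausalLM(config)
    # Random token ids double as labels for a quick forward/loss check.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape, out.loss)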