from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor try: from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # pyright: ignore[reportMissingImports] FLASH_ATTN_AVAILABLE = True except ImportError: FLASH_ATTN_AVAILABLE = False flash_attn_func = None flash_attn_varlen_func = None class Rotary(nn.Module): cos_cached: Tensor sin_cached: Tensor def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) t = torch.arange(max_seq_len) freqs = torch.outer(t, inv_freq) self.register_buffer("cos_cached", freqs.cos().bfloat16(), persistent=False) self.register_buffer("sin_cached", freqs.sin().bfloat16(), persistent=False) def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: seq_len = x.shape[1] return self.cos_cached[None, :seq_len, None, :], self.sin_cached[None, :seq_len, None, :] def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: assert x.ndim == 4 # [batch, seq_len, n_heads, head_dim] d: int = x.shape[3] // 2 x1 = x[..., :d] x2 = x[..., d:] y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat([y1, y2], 3).type_as(x) class CausalSelfAttention(nn.Module): def __init__(self, config: "GPTConfig") -> None: super().__init__() self.n_head: int = config.n_head self.n_embd: int = config.n_embd self.head_dim: int = self.n_embd // self.n_head assert self.n_embd % self.n_head == 0 self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_proj.weight.data.zero_() self.rotary = Rotary(self.head_dim, max_seq_len=config.sequence_length) def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor: assert x.ndim == 3, f"x must be 3D, got shape {x.shape}" B, T, C = x.size() assert C == self.n_embd, f"hidden dim mismatch: {C} != {self.n_embd}" assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}" q = self.c_q(x).view(B, T, self.n_head, self.head_dim) k = self.c_k(x).view(B, T, self.n_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_head, self.head_dim) assert q.shape == (B, T, self.n_head, self.head_dim), f"q shape mismatch: {q.shape}" cos, sin = self.rotary(q) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) use_flash = FLASH_ATTN_AVAILABLE and x.is_cuda if use_flash and flash_attn_varlen_func is not None and cu_seqlens is not None: q_flat = q.reshape(-1, self.n_head, self.head_dim) k_flat = k.reshape(-1, self.n_head, self.head_dim) v_flat = v.reshape(-1, self.n_head, self.head_dim) # Use pre-computed max_seqlen from dataloader (avoids .item() graph break) seqlen: int = max_seqlen if max_seqlen is not None else int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item()) y_flat = flash_attn_varlen_func( q_flat, k_flat, v_flat, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seqlen, max_seqlen_k=seqlen, causal=True, ) y = y_flat.reshape(B, T, C) elif use_flash and flash_attn_func is not None: y = flash_attn_func(q, k, v, causal=True) y = y.contiguous().view_as(x) else: y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) y = y.transpose(1, 2).contiguous().view_as(x) y = self.c_proj(y) return y class MLP(nn.Module): def __init__(self, config: "GPTConfig") -> None: super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) self.c_proj.weight.data.zero_() def forward(self, x: Tensor) -> Tensor: x = self.c_fc(x) x = F.relu(x).square() x = self.c_proj(x) return x class Block(nn.Module): def __init__(self, config: "GPTConfig") -> None: super().__init__() self.attn = CausalSelfAttention(config) self.mlp = MLP(config) def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor: x = x + self.attn(F.rms_norm(x, (x.size(-1),)), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) x = x + self.mlp(F.rms_norm(x, (x.size(-1),))) return x @dataclass class GPTConfig: vocab_size: int = 32256 n_layer: int = 12 n_head: int = 12 n_embd: int = 768 sequence_length: int = 1024 class Transformer(nn.Module): wte: nn.Embedding h: nn.ModuleList def __init__(self, config: GPTConfig): super().__init__() self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) class GPT(nn.Module): config: GPTConfig transformer: Transformer lm_head: nn.Linear def __init__(self, config: GPTConfig): super().__init__() self.config = config self.transformer = Transformer(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.transformer.wte.weight = self.lm_head.weight def forward( self, idx: Tensor, targets: Optional[Tensor] = None, return_logits: bool = True, return_hidden: bool = False, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]] | tuple[Optional[Tensor], Optional[Tensor], Tensor]: assert idx.ndim == 2, f"idx must be 2D, got shape {idx.shape}" B, T = idx.shape assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}" if targets is not None: assert targets.shape == idx.shape, f"targets shape {targets.shape} != idx shape {idx.shape}" x = self.transformer.wte(idx) assert x.shape == (B, T, self.config.n_embd), f"embedding output shape mismatch: {x.shape}" for block in self.transformer.h: x = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) assert x.shape == (B, T, self.config.n_embd), f"block output shape mismatch: {x.shape}" hidden = F.rms_norm(x, (x.size(-1),)) assert hidden.shape == x.shape, f"rms_norm shape mismatch: {hidden.shape}" if targets is not None: logits = self.lm_head(hidden) assert logits.shape == (B, T, self.config.vocab_size), f"logits shape mismatch: {logits.shape}" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) assert loss.ndim == 0, f"loss must be scalar, got shape {loss.shape}" else: if return_logits: logits = self.lm_head(hidden) else: logits = self.lm_head(hidden[:, [-1], :]) loss = None if not return_logits: logits = None if return_hidden: return logits, loss, hidden return logits, loss def get_num_params(self) -> int: return sum(p.numel() for p in self.parameters()) class StackedGPT(nn.Module): def __init__(self, model1: GPT, model2: GPT) -> None: super().__init__() self.model1 = model1 self.model2 = model2 assert model1.config.vocab_size == model2.config.vocab_size assert model1.config.n_embd == model2.config.n_embd self.model2.transformer.wte = self.model1.transformer.wte self.model2.lm_head = self.model1.lm_head def forward( self, idx: Tensor, targets: Optional[Tensor] = None, return_logits: bool = True, cu_seqlens: Optional[Tensor] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: logits1, _ = self.model1(idx, targets=targets, return_logits=True, cu_seqlens=cu_seqlens) logits2, _ = self.model2(idx, targets=targets, return_logits=True, cu_seqlens=cu_seqlens) logits = (logits1 + logits2) / 2.0 if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: loss = None if not return_logits: logits = None return logits, loss def generate_expert_vectors(n_experts: int, embed_dim: int, seed: int = 42) -> torch.Tensor: torch.manual_seed(seed) vectors = torch.randn(n_experts, embed_dim) vectors = F.normalize(vectors, p=2, dim=1) return vectors class MoEGPT(nn.Module): expert_vectors: nn.Parameter expert_models: list[GPT] top_k: int temperature: float config: GPTConfig def __init__(self, expert_vectors: Tensor, *models: GPT, top_k: int = 2, temperature: float = 20.0): super().__init__() self.expert_vectors = nn.Parameter(expert_vectors, requires_grad=False) self.expert_models = list(models) self.models = nn.ModuleList(models) # For state_dict compatibility self.top_k = top_k self.temperature = temperature self.config = models[0].config for model in models[1:]: assert model.config.vocab_size == self.config.vocab_size assert model.config.n_embd == self.config.n_embd def forward( self, idx: Tensor, targets: Optional[Tensor] = None, return_logits: bool = True, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: B, T = idx.size() vocab_size = self.config.vocab_size token_embeds = self.expert_models[0].transformer.wte(idx) token_embeds_flat = token_embeds.reshape(-1, self.config.n_embd) token_embeds_norm = F.normalize(token_embeds_flat, p=2, dim=1) cosine_similarities = torch.matmul(token_embeds_norm, self.expert_vectors.T) # Apply temperature scaling before topk/softmax scaled_similarities = cosine_similarities * self.temperature top_k_similarities, top_k_indices = torch.topk(scaled_similarities, self.top_k, dim=-1) top_k_weights = F.softmax(top_k_similarities, dim=-1) # Process experts sequentially to avoid OOM output: Optional[Tensor] = None for expert_id, expert_model in enumerate(self.expert_models): logits, _ = expert_model( idx, targets=None, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen ) if output is None: output = torch.zeros_like(logits) for k in range(self.top_k): routing_mask = (top_k_indices[:, k] == expert_id).float().view(B, T, 1) expert_weight = top_k_weights[:, k].view(B, T, 1) output = output + logits * routing_mask * expert_weight del logits assert output is not None, "No experts to process" if targets is not None: loss = F.cross_entropy(output.view(-1, vocab_size), targets.view(-1), ignore_index=-1) else: loss = None if not return_logits: return None, loss return output, loss