| """eGPT: LLaMA-style decoder-only model — self-contained HuggingFace implementation. |
| |
| Architecture: RMSNorm, RoPE, Grouped-Query Attention, SwiGLU FFN, no bias. |
| Weight keys match eGPT/model.py exactly for checkpoint compatibility. |
| """ |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| try: |
| from .configuration_egpt import eGPTConfig |
| except ImportError: |
| from configuration_egpt import eGPTConfig |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (no bias, no centering).

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. The statistic is
    evaluated in float32 and cast back to the input dtype before scaling by
    ``weight``, so half-precision activations do not overflow when squared.
    For float32 inputs the result is identical to the plain formula.

    Args:
        dim: size of the last (normalized) dimension; length of ``weight``.
        eps: added inside the square root for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned per-channel scale, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its RMS over the last dimension (no learned scale)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # float16 tops out at 65504, so squaring any activation above ~256
        # overflows to inf and rsqrt(inf) == 0 would zero the whole row.
        # Upcast for the statistic, then return to the caller's dtype.
        return self._norm(x.float()).type_as(x) * self.weight
|
|
|
|
def _precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float) -> torch.Tensor:
    """Build the complex rotary-embedding table.

    Returns a complex tensor of shape ``(max_seq_len, ceil(head_dim / 2))`` whose
    entry ``[p, i]`` is ``exp(1j * p * theta ** (-2i / head_dim))``, i.e. a unit
    phasor rotating position ``p`` by the ``i``-th inverse frequency.

    Args:
        head_dim: per-head feature size; one frequency per (even, odd) pair.
        max_seq_len: number of positions (rows) to precompute.
        theta: RoPE base.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)
|
|
|
|
def _apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary position embeddings to queries and keys.

    Adjacent pairs along the last axis of each input are read as the real and
    imaginary parts of a complex number and multiplied by the matching entry of
    ``freqs_cis``. The ``view`` below requires ``freqs_cis`` to have exactly
    ``x.shape[1]`` rows and ``x.shape[-1] // 2`` columns.

    Assumes inputs are laid out ``(batch, seq, heads, head_dim)`` — the table is
    broadcast over dims 0 and 2.

    Returns:
        ``(rotated_xq, rotated_xk)`` with the same shapes and dtypes as the inputs.
    """
    n_freqs = freqs_cis.shape[-1]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Rotation is done in float32/complex64, then cast back to t's dtype.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        table = freqs_cis.view(1, t.shape[1], 1, n_freqs)
        rotated = torch.view_as_real(as_complex * table)
        return rotated.flatten(-2).type_as(t)

    rotated_q = _rotate(xq)
    rotated_k = _rotate(xk)
    return rotated_q, rotated_k
|
|
|
|
class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. When
    ``n_rep = n_heads // n_kv_heads`` exceeds 1, each K/V head is repeated
    ``n_rep`` times so that query head ``h`` attends with K/V head ``h // n_rep``.
    All projections are bias-free.
    """

    def __init__(self, cfg: eGPTConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.head_dim or (cfg.dim // cfg.n_heads)
        self.n_rep = cfg.n_heads // cfg.n_kv_heads

        q_width = cfg.n_heads * self.head_dim
        kv_width = cfg.n_kv_heads * self.head_dim
        # Attribute names are state-dict keys; keep them (and their order) as-is.
        self.wq = nn.Linear(cfg.dim, q_width, bias=False)
        self.wk = nn.Linear(cfg.dim, kv_width, bias=False)
        self.wv = nn.Linear(cfg.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, cfg.dim, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` (``(batch, seq, dim)``) and project back to ``dim``.

        ``freqs_cis`` is the RoPE table; only its first ``seq`` rows are used.
        """
        batch, seq_len, _ = x.shape

        q = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        q, k = _apply_rotary_emb(q, k, freqs_cis[:seq_len])

        if self.n_rep > 1:
            # repeat_interleave keeps copies adjacent: head j reads K/V head j // n_rep.
            k = k.repeat_interleave(self.n_rep, dim=2)
            v = v.repeat_interleave(self.n_rep, dim=2)

        # SDPA wants (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(merged)
|
|
|
|
class FeedForward(nn.Module):
    """Bias-free SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is ``int(dim * ffn_multiplier * 2 / 3)`` rounded up to the
    next multiple of ``multiple_of``.
    """

    def __init__(self, cfg: eGPTConfig):
        super().__init__()
        # Keep this exact expression: float rounding here decides the hidden
        # width, which must match the shapes stored in existing checkpoints.
        raw_width = int(cfg.dim * cfg.ffn_multiplier * 2 / 3)
        step = cfg.multiple_of
        hidden = step * math.ceil(raw_width / step)

        self.w1 = nn.Linear(cfg.dim, hidden, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden, cfg.dim, bias=False)  # output projection
        self.w3 = nn.Linear(cfg.dim, hidden, bias=False)  # value ("up") projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: RMSNorm→attention and RMSNorm→FFN, each residual."""

    def __init__(self, cfg: eGPTConfig):
        super().__init__()
        # Submodule names are state-dict keys; registration order is preserved.
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn = FeedForward(cfg)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Apply the attention sub-layer, then the feed-forward sub-layer."""
        after_attn = x + self.attn(self.attn_norm(x), freqs_cis)
        return after_attn + self.ffn(self.ffn_norm(after_attn))
|
|
|
|
class eGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """eGPT decoder-only causal language model as a HuggingFace ``PreTrainedModel``.

    Layout: token embedding -> ``n_layers`` pre-norm transformer blocks -> final
    RMSNorm -> bias-free linear LM head. Submodule names (``embed``, ``layers``,
    ``norm``, ``head``) are the state-dict keys, so they must not be renamed.

    No KV cache is implemented: ``forward`` never returns ``past_key_values``, so
    every generation step re-encodes the full prefix.
    """

    config_class = eGPTConfig

    def __init__(self, config: eGPTConfig):
        super().__init__(config)
        head_dim = config.head_dim or (config.dim // config.n_heads)

        self.embed = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # NOTE(review): PreTrainedModel.post_init()/tie_weights() may also tie the
        # head to the embedding based on config.tie_word_embeddings — confirm
        # eGPTConfig keeps that flag consistent with weight_tying.
        if config.weight_tying:
            # Share one Parameter between the input embedding and the LM head.
            self.head.weight = self.embed.weight

        # Complex RoPE table for 2 * max_seq_len positions. Non-persistent: it is
        # rebuilt from the config instead of being stored in checkpoints.
        freqs_cis = _precompute_freqs_cis(head_dim, config.max_seq_len * 2, config.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        # Called by PreTrainedModel.resize_token_embeddings when the head is untied.
        self.head = new_embeddings

    def _rope_table(self, seq_len: int) -> torch.Tensor:
        """Return a usable complex RoPE table with at least ``seq_len`` rows.

        The table is rebuilt when it is no longer complex or is too short:
        - ``Module.to(dtype)`` also casts complex buffers, so ``model.to(torch.bfloat16)``
          turns it into a real tensor and drops the imaginary part.
        - Inputs longer than ``2 * max_seq_len`` would otherwise fail inside ``view``.

        Each row depends only on its position, so rows that already existed keep
        the same values after a rebuild.
        """
        table = self.freqs_cis
        if table.is_complex() and table.shape[0] >= seq_len:
            return table
        cfg = self.config
        head_dim = cfg.head_dim or (cfg.dim // cfg.n_heads)
        n_positions = max(cfg.max_seq_len * 2, seq_len)
        table = _precompute_freqs_cis(head_dim, n_positions, cfg.rope_theta).to(table.device)
        self.freqs_cis = table  # assignment updates the registered (non-persistent) buffer
        return table

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the decoder and optionally compute the next-token loss.

        Args:
            input_ids: token ids of shape ``(batch, seq)``.
            attention_mask: accepted for HF API compatibility but NOT applied.
                Attention is purely causal, so padded positions are attended to.
                A warning is emitted when the mask contains padding (zeros).
            labels: target ids with the same layout as ``input_ids`` (typically
                ``labels = input_ids``). They are shifted inside the model
                (HF causal-LM convention): logits at position ``t`` are scored
                against ``labels[..., t + 1]``. Entries equal to -100 are ignored.
            **kwargs: extra arguments (e.g. from ``generate`` or ``Trainer``);
                ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` (float32 mean cross-entropy)
            when ``labels`` is given. ``past_key_values`` is always None.
        """
        if attention_mask is not None and bool((attention_mask == 0).any()):
            import warnings

            warnings.warn(
                "eGPTForCausalLM does not apply attention_mask; padded positions are "
                "attended to, so outputs for padded batches may be incorrect.",
                stacklevel=2,
            )

        freqs_cis = self._rope_table(input_ids.shape[-1])

        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x, freqs_cis)
        logits = self.head(self.norm(x))

        loss = None
        if labels is not None:
            # Next-token objective: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1)).float()
            shift_labels = labels[..., 1:].reshape(-1).to(shift_logits.device)
            loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

        return CausalLMOutputWithPast(loss=loss, logits=logits)
|
|