# eGPT-100M-bytes-untrained / modeling_egpt.py
# macabdul9's picture
# Upload folder using huggingface_hub
# 4e74b1b verified
"""eGPT: LLaMA-style decoder-only model — self-contained HuggingFace implementation.
Architecture: RMSNorm, RoPE, Grouped-Query Attention, SwiGLU FFN, no bias.
Weight keys match eGPT/model.py exactly for checkpoint compatibility.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
try:
from .configuration_egpt import eGPTConfig
except ImportError:
from configuration_egpt import eGPTConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed along the last
    dimension and then multiplied element-wise by ``weight`` (initialised to
    ones). There is no bias term and no mean subtraction.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the per-feature gain of size ``dim``; ``eps`` guards the rsqrt."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        return normed * self.weight
def _precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float) -> torch.Tensor:
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, freqs)
return torch.polar(torch.ones_like(freqs), freqs)
def _apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
def rotate(x):
xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
f = freqs_cis.view(1, x.shape[1], 1, freqs_cis.shape[-1])
return torch.view_as_real(xc * f).flatten(-2).type_as(x)
return rotate(xq), rotate(xk)
class Attention(nn.Module):
    """Causal self-attention with rotary embeddings and grouped key/value heads.

    Queries use ``n_heads`` heads while keys and values use ``n_kv_heads``;
    each key/value head is repeated ``n_heads // n_kv_heads`` times so the
    three tensors line up before attention. All projections are bias-free.
    """

    def __init__(self, cfg: eGPTConfig):
        """Read head counts and sizes from ``cfg`` and build the four projections."""
        super().__init__()
        n_q, n_kv = cfg.n_heads, cfg.n_kv_heads
        self.n_heads = n_q
        self.n_kv_heads = n_kv
        # A truthy cfg.head_dim wins; otherwise split dim evenly over query heads.
        self.head_dim = cfg.head_dim or (cfg.dim // n_q)
        self.n_rep = n_q // n_kv

        q_width = n_q * self.head_dim
        kv_width = n_kv * self.head_dim
        self.wq = nn.Linear(cfg.dim, q_width, bias=False)
        self.wk = nn.Linear(cfg.dim, kv_width, bias=False)
        self.wv = nn.Linear(cfg.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, cfg.dim, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` (three axes: batch, sequence, features).

        Args:
            x: Input activations; the last axis feeds the ``cfg.dim``-wide projections.
            freqs_cis: Rotary factor table; only its first ``seq_len`` rows are used.

        Returns:
            A tensor shaped like ``x`` after the output projection.
        """
        bsz, seq_len, _ = x.shape

        q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        q, k = _apply_rotary_emb(q, k, freqs_cis[:seq_len])

        # Expand grouped K/V heads so every query head has a matching K/V head.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=2)
            v = v.repeat_interleave(self.n_rep, dim=2)

        # scaled_dot_product_attention expects the head axis before the sequence axis.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        merged = attended.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.wo(merged)
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with no biases."""

    def __init__(self, cfg: eGPTConfig):
        """Size the hidden layer from ``cfg`` and build the three projections.

        The hidden width is ``int(dim * ffn_multiplier * 2 / 3)`` rounded up to
        the next multiple of ``cfg.multiple_of``.
        """
        super().__init__()
        granularity = cfg.multiple_of
        raw_hidden = int(cfg.dim * cfg.ffn_multiplier * 2 / 3)
        hidden = granularity * math.ceil(raw_hidden / granularity)

        self.w1 = nn.Linear(cfg.dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, cfg.dim, bias=False)
        self.w3 = nn.Linear(cfg.dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated projection and map back to the input width."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: attention and feed-forward, each with a residual."""

    def __init__(self, cfg: eGPTConfig):
        """Build the two RMSNorms, the attention module and the feed-forward module."""
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn = FeedForward(cfg)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Run normalised attention then normalised feed-forward, adding each to the stream.

        Args:
            x: Residual-stream activations.
            freqs_cis: Rotary factor table forwarded unchanged to the attention module.

        Returns:
            The updated residual stream.
        """
        attn_out = self.attn(self.attn_norm(x), freqs_cis)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
class eGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model wrapped as a HuggingFace ``PreTrainedModel``.

    The network is a token embedding, ``config.n_layers`` ``TransformerBlock``
    layers, a final ``RMSNorm`` and a bias-free output projection. No key/value
    cache is kept or returned: every ``forward`` call processes the whole
    ``input_ids`` sequence.
    """

    config_class = eGPTConfig

    def __init__(self, config: eGPTConfig):
        """Build all sub-modules and the rotary table from ``config``."""
        super().__init__(config)
        head_dim = config.head_dim or (config.dim // config.n_heads)

        self.embed = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.head = nn.Linear(config.dim, config.vocab_size, bias=False)
        if config.weight_tying:
            # Share one parameter between the input embedding and the output projection.
            self.head.weight = self.embed.weight

        # The table covers twice max_seq_len positions. It is non-persistent, so
        # it is rebuilt here on construction rather than stored in checkpoints.
        freqs_cis = _precompute_freqs_cis(head_dim, config.max_seq_len * 2, config.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the output projection (LM head) module."""
        return self.head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (LM head) module.

        Counterpart to ``get_output_embeddings`` so the pair of accessors is
        complete for utilities that swap the head.
        """
        self.head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids with batch on axis 0 and sequence on axis 1.
            attention_mask: Accepted but not used; attention is always purely
                causal, so padding positions are not masked out.
            labels: Optional targets, flattened and compared position-by-position
                with the flattened logits. Entries equal to ``-100`` are ignored.
                NOTE(review): no shift is applied here, so callers must supply
                targets already aligned with the logits. Many HF causal-LM
                models shift inside ``forward`` instead; confirm callers agree.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and ``logits``;
            no past key/values are populated.

        Raises:
            ValueError: If the sequence is longer than the precomputed rotary table.
        """
        seq_len = input_ids.shape[1]
        max_positions = self.freqs_cis.shape[0]
        if seq_len > max_positions:
            # Without this guard the failure surfaces as an opaque view/shape
            # error inside the rotary embedding code.
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the {max_positions} "
                "positions available in the precomputed rotary embedding table."
            )

        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x, self.freqs_cis)
        x = self.norm(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            # Cross-entropy is computed in float32 regardless of the logits dtype.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)).float(),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)