rex1-base / model.py
DavidSeyserHF's picture
Update rex1-base: mixed-2 checkpoint step 710000
3abc4f7 verified
"""REX: a recursive decoder-only Transformer language model."""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class RexConfig:
    """Hyper-parameters for ``RexForCausalLM`` and its sub-modules.

    The model applies the same stack of ``n_layers`` blocks ``recurrence_steps``
    times, so depth-wise compute grows with ``n_layers * recurrence_steps`` while
    the parameter count grows only with ``n_layers``.
    """

    vocab_size: int = 50_257  # rows of the token embedding / outputs of lm_head
    max_seq_len: int = 2048  # longest accepted input; also length of the rotary tables
    d_model: int = 1536  # hidden width
    n_heads: int = 16  # attention heads; d_model must divide evenly into them
    n_layers: int = 8  # distinct blocks in the shared stack
    recurrence_steps: int = 2  # passes through the shared stack per forward
    ffn_dim: int = 3968  # SwiGLU inner width
    dropout: float = 0.0  # embedding / attention / FFN dropout probability
    norm_eps: float = 1e-5  # RMSNorm epsilon
    tie_embeddings: bool = True  # share the token embedding matrix with lm_head
    use_step_embeddings: bool = True  # add a learned vector at the start of each pass
    initializer_range: float = 0.02  # std of the normal init for Linear/Embedding weights

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RexConfig":
        """Build a config from ``data``, ignoring keys that are not config fields.

        Unknown keys are dropped rather than rejected, so a larger training config
        dict can be passed through directly.
        """
        known = cls.__dataclass_fields__.keys()
        kwargs = {key: value for key, value in data.items() if key in known}
        return cls(**kwargs)

    def to_dict(self) -> dict[str, Any]:
        """Return the config as a plain ``dict`` suitable for JSON serialization."""
        return asdict(self)
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain.

    Unlike LayerNorm there is no mean subtraction and no bias. The statistic is
    computed in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normed).to(in_dtype)
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    Both tables have shape ``[max_seq_len, dim // 2]`` (one column per rotated
    channel pair). They are registered as non-persistent buffers, so they follow
    the module across devices but are not written to ``state_dict``.
    """

    def __init__(self, dim: int, max_seq_len: int, base: float = 10_000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        angles = torch.outer(torch.arange(max_seq_len, dtype=torch.float), inv_freq)
        for buffer_name, table in (("cos", angles.cos()), ("sin", angles.sin())):
            self.register_buffer(buffer_name, table, persistent=False)

    def forward(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of the cos and sin tables."""
        window = slice(None, seq_len)
        return self.cos[window], self.sin[window]
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs by 90 degrees: ``(a, b) -> (-b, a)``.

    Channels ``2i`` and ``2i + 1`` of the last dimension form one pair, so the last
    dimension must have even size.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim=-1).flatten(-2)
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply interleaved rotary position embeddings to ``x``.

    Args:
        x: tensor whose last two dimensions are ``(seq, head_dim)``. The tables are
            broadcast as ``[1, 1, seq, head_dim]``, so in practice ``x`` is
            ``[batch, heads, seq, head_dim]``.
        cos, sin: 2-D ``[seq, head_dim // 2]`` tables (one column per channel pair).

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    # Each frequency covers one (even, odd) channel pair, so duplicate every column
    # to reach head_dim, then add broadcast axes for batch and heads.
    cos = torch.repeat_interleave(cos, 2, dim=-1)[None, None, :, :]
    sin = torch.repeat_interleave(sin, 2, dim=-1)[None, None, :, :]
    rotated = (x * cos) + (_rotate_half(x) * sin)
    # With low-precision activations and float32 tables (e.g. under autocast),
    # type promotion would return float32 here while v keeps the activation dtype.
    # scaled_dot_product_attention expects q, k and v to share a dtype.
    # The math still runs in the promoted dtype; only the result is cast back.
    return rotated.to(x.dtype)
def _safe_torch_load(path: str | Path, map_location: str | torch.device | None) -> Any:
    """Deserialize a ``torch.save`` file with ``weights_only=True`` whenever supported.

    ``weights_only=True`` restricts unpickling to tensors and plain container or
    primitive types, so a crafted file cannot run code on load. The unrestricted
    ``torch.load`` is used only when the installed PyTorch predates that keyword.

    Args:
        path: file to load.
        map_location: forwarded unchanged to ``torch.load``.

    Returns:
        The deserialized object.

    Raises:
        Whatever ``torch.load`` raises. A file rejected by the weights-only loader
        is NOT retried with an unrestricted load.
    """
    import inspect

    # Decide up front instead of catching TypeError around the load itself.
    # Otherwise a TypeError raised while *reading* the file would trigger a silent
    # fallback to full, code-executing unpickling.
    try:
        params = inspect.signature(torch.load).parameters
    except (TypeError, ValueError):
        params = None  # signature unavailable; assume a modern torch.load
    accepts_weights_only = (
        params is None
        or "weights_only" in params
        or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    )
    if accepts_weights_only:
        return torch.load(path, map_location=map_location, weights_only=True)
    return torch.load(path, map_location=map_location)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with interleaved rotary position embeddings.

    Q, K and V come from a single fused projection. Rotary embeddings are applied to
    Q and K, and the attention itself is delegated to
    ``F.scaled_dot_product_attention`` with a causal mask.
    """

    def __init__(self, cfg: RexConfig):
        super().__init__()
        if cfg.d_model % cfg.n_heads:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        if self.head_dim % 2:
            raise ValueError("attention head_dim must be even for rotary embeddings")
        # Fused projection: output is [q | k | v], each d_model wide.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        # Stored as a float probability; only applied while self.training is True.
        self.dropout = cfg.dropout
        self.rotary = RotaryEmbedding(self.head_dim, cfg.max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``[batch, seq, d_model]``; returns the same shape."""
        batch, length, width = x.shape
        heads, head_dim = self.n_heads, self.head_dim
        # Split the fused projection and move heads ahead of the sequence axis.
        q, k, v = (
            part.view(batch, length, heads, head_dim).transpose(1, 2)
            for part in self.qkv(x).chunk(3, dim=-1)
        )
        cos, sin = self.rotary(length)
        q = apply_rotary(q, cos.to(q.device), sin.to(q.device))
        k = apply_rotary(k, cos.to(k.device), sin.to(k.device))
        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, length, width)
        return self.out(merged)
class SwiGLU(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))`` followed by dropout."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(cfg.ffn_dim, cfg.d_model, bias=False)  # down projection
        self.w3 = nn.Linear(cfg.d_model, cfg.ffn_dim, bias=False)  # value projection
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
class RexBlock(nn.Module):
    """Pre-norm transformer block: causal attention then SwiGLU, each with a residual add."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.attn = CausalSelfAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each sub-layer sees a normalized copy; the residual stream stays un-normalized.
        after_attn = x + self.attn(self.attn_norm(x))
        return after_attn + self.ffn(self.ffn_norm(after_attn))
class RexForCausalLM(nn.Module):
    """Decoder-only LM with a stack of blocks reused across recursive passes.

    The same ``cfg.n_layers`` blocks are applied ``cfg.recurrence_steps`` times.
    When ``cfg.use_step_embeddings`` is set, a learned per-pass vector is added to
    the hidden state at the start of each pass.
    """

    def __init__(self, cfg: RexConfig):
        super().__init__()
        if cfg.recurrence_steps < 1:
            raise ValueError("recurrence_steps must be >= 1")
        self.cfg = cfg
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([RexBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # One shared Parameter serves as both input embedding and output projection.
            self.lm_head.weight = self.token_embedding.weight
        if cfg.use_step_embeddings:
            # Starts at zero: _init_weights only touches Linear/Embedding modules.
            self.step_embedding = nn.Parameter(torch.zeros(cfg.recurrence_steps, cfg.d_model))
        else:
            # Registered as None so `self.step_embedding` always exists.
            self.register_parameter("step_embedding", None)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Init Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases.

        Other modules (e.g. RMSNorm) keep their constructor defaults.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)

    def encode(self, input_ids: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Return contextual token representations for downstream tasks.

        Args:
            input_ids: token ids of shape ``[batch, seq]`` with ``seq <= max_seq_len``.
            normalize: apply the final RMSNorm (as the LM head expects).

        Returns:
            Hidden states of shape ``[batch, seq, d_model]``.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        if input_ids.ndim != 2:
            raise ValueError("input_ids must have shape [batch, seq]")
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError(f"sequence length exceeds max_seq_len={self.cfg.max_seq_len}")
        x = self.drop(self.token_embedding(input_ids))
        for step in range(self.cfg.recurrence_steps):
            if self.step_embedding is not None:
                # Tell the shared blocks which recursive pass they are in.
                x = x + self.step_embedding[step].view(1, 1, -1)
            for block in self.blocks:
                x = block(x)
        if normalize:
            x = self.final_norm(x)
        return x

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        ``labels`` uses the same layout as ``input_ids`` and is shifted internally:
        the logits at position ``t`` are scored against ``labels[:, t + 1]``.
        Positions labelled ``-100`` are ignored.

        Returns:
            ``{"logits": [batch, seq, vocab], "loss": scalar tensor or None}``.
        """
        hidden_states = self.encode(input_ids, normalize=True)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
        no_repeat_ngram_size: int = 0,
    ) -> torch.Tensor:
        """Autoregressively append ``max_new_tokens`` tokens to ``input_ids``.

        Args:
            input_ids: prompt token ids ``[batch, seq]``. Only the last
                ``max_seq_len`` tokens are fed to the model at each step.
            max_new_tokens: number of tokens to append; there is no EOS stopping.
            temperature: ``0`` selects greedy decoding; ``> 0`` samples from the
                temperature-scaled softmax.
            top_k: when sampling, keep only the ``top_k`` highest logits.
            no_repeat_ngram_size: if ``> 0``, forbid tokens that would repeat an
                n-gram of this size already present in the sequence.

        Returns:
            ``input_ids`` with the generated tokens concatenated on dim 1.

        Raises:
            ValueError: on a negative ``temperature`` or ``no_repeat_ngram_size``,
                or on ``top_k < 1`` when sampling.
        """
        # Validate once up front: no wasted forward pass, and the checks also run
        # when max_new_tokens == 0.
        if no_repeat_ngram_size < 0:
            raise ValueError("no_repeat_ngram_size must be >= 0")
        if temperature < 0:
            raise ValueError("temperature must be >= 0")
        if temperature > 0 and top_k is not None and top_k < 1:
            raise ValueError("top_k must be >= 1 when sampling")
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                context = input_ids[:, -self.cfg.max_seq_len :]
                # Only the last position is used, so project just that hidden state
                # instead of materialising [batch, seq, vocab] logits every step.
                hidden = self.encode(context, normalize=True)
                logits = self.lm_head(hidden[:, -1, :])
                logits = self._apply_no_repeat_ngram(logits, input_ids, no_repeat_ngram_size)
                if temperature == 0:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None:
                        values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            # Restore the caller's mode (e.g. when sampling mid-training) so dropout
            # is not silently left disabled.
            self.train(was_training)
        return input_ids

    @staticmethod
    def _apply_no_repeat_ngram(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        no_repeat_ngram_size: int,
    ) -> torch.Tensor:
        """Return a copy of ``logits`` ``[batch, vocab]`` with banned n-gram tokens set to -inf.

        The input ``logits`` tensor is not modified. A non-positive size is a no-op.
        """
        if no_repeat_ngram_size <= 0:
            return logits
        logits = logits.clone()
        for batch_idx in range(input_ids.size(0)):
            banned_tokens = RexForCausalLM._get_banned_ngram_tokens(
                input_ids[batch_idx].tolist(),
                no_repeat_ngram_size,
            )
            if banned_tokens:
                logits[batch_idx, banned_tokens] = float("-inf")
        return logits

    @staticmethod
    def _get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
        """List tokens that would complete an n-gram of ``ngram_size`` already in ``tokens``.

        For ``ngram_size == 1`` every token seen so far is banned. Otherwise the
        result is every token that has previously followed the current
        ``ngram_size - 1`` suffix.
        """
        if ngram_size == 1:
            return list(set(tokens))
        if len(tokens) < ngram_size - 1:
            return []
        prefix_to_next: dict[tuple[int, ...], set[int]] = {}
        for i in range(len(tokens) - ngram_size + 1):
            ngram = tokens[i : i + ngram_size]
            prefix = tuple(ngram[:-1])
            prefix_to_next.setdefault(prefix, set()).add(ngram[-1])
        current_prefix = tuple(tokens[-(ngram_size - 1) :])
        return list(prefix_to_next.get(current_prefix, set()))

    def parameter_count(self, trainable_only: bool = False) -> int:
        """Count parameters; tied weights are counted once (``parameters()`` deduplicates)."""
        params = self.parameters()
        if trainable_only:
            params = (p for p in params if p.requires_grad)
        return sum(p.numel() for p in params)

    def save_pretrained(self, save_directory: str | Path) -> None:
        """Save model weights and config in a lightweight HF-style folder.

        Writes ``config.json`` and ``pytorch_model.bin`` into ``save_directory``,
        creating it if needed.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "config.json", "w", encoding="utf-8") as f:
            json.dump(self.cfg.to_dict(), f, indent=2)
            f.write("\n")
        torch.save(self.state_dict(), save_path / "pytorch_model.bin")

    @classmethod
    def from_pretrained(
        cls,
        load_directory: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load a model saved by ``save_pretrained``."""
        load_path = Path(load_directory)
        with open(load_path / "config.json", "r", encoding="utf-8") as f:
            cfg = RexConfig.from_dict(json.load(f))
        model = cls(cfg)
        state_dict = _safe_torch_load(load_path / "pytorch_model.bin", map_location)
        model.load_state_dict(state_dict, strict=strict)
        return model

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load from a training checkpoint produced by ``train.py``.

        The model config is read from ``checkpoint["model_config"]`` or, failing
        that, ``checkpoint["config"]["model"]``. Weights come from
        ``checkpoint["model"``]. If that key is absent, the top level (minus the
        config entries) is treated as the state dict.

        Raises:
            ValueError: if no model config can be found in the checkpoint.
        """
        checkpoint = _safe_torch_load(checkpoint_path, map_location)
        cfg_data = checkpoint.get("model_config")
        if cfg_data is None:
            train_config = checkpoint.get("config")
            # Guard against a present-but-non-dict "config" entry (e.g. None).
            if isinstance(train_config, dict):
                cfg_data = train_config.get("model")
        if cfg_data is None:
            raise ValueError("checkpoint does not contain model_config or config.model")
        model = cls(RexConfig.from_dict(cfg_data))
        state_dict = checkpoint.get("model")
        if state_dict is None:
            # Bare state dict: drop the config entries read above so a strict load
            # does not reject them as unexpected keys.
            state_dict = {
                key: value
                for key, value in checkpoint.items()
                if key not in ("model_config", "config")
            }
        model.load_state_dict(state_dict, strict=strict)
        return model
def build_model(config: dict[str, Any] | RexConfig | None = None) -> RexForCausalLM:
    """Construct a freshly initialized ``RexForCausalLM``.

    Args:
        config: a ``RexConfig``, a plain dict (unknown keys are ignored by
            ``RexConfig.from_dict``), or ``None`` for the default configuration.
    """
    if isinstance(config, RexConfig):
        return RexForCausalLM(config)
    resolved = RexConfig() if config is None else RexConfig.from_dict(config)
    return RexForCausalLM(resolved)