"""REX: a recursive decoder-only Transformer language model."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class RexConfig:
    """Hyperparameters for :class:`RexForCausalLM`.

    The model owns ``n_layers`` blocks and re-applies that same stack
    ``recurrence_steps`` times, so the effective depth is
    ``n_layers * recurrence_steps`` while the weights are those of only
    ``n_layers`` blocks.
    """

    vocab_size: int = 50_257
    max_seq_len: int = 2048
    d_model: int = 1536
    n_heads: int = 16
    n_layers: int = 8
    recurrence_steps: int = 2
    ffn_dim: int = 3968
    dropout: float = 0.0
    norm_eps: float = 1e-5
    tie_embeddings: bool = True
    use_step_embeddings: bool = True
    initializer_range: float = 0.02

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RexConfig":
        """Build a config from ``data``; keys that are not config fields are ignored."""
        fields = {name for name in cls.__dataclass_fields__}
        return cls(**{k: v for k, v in data.items() if k in fields})

    def to_dict(self) -> dict[str, Any]:
        """Return the config as a plain dict (all values are int/float/bool)."""
        return asdict(self)


class RMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale.

    The statistics are computed in float32 and the result is cast back to the
    input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)


class RotaryEmbedding(nn.Module):
    """Precomputed rotary-position cos/sin tables.

    For an even ``dim`` the tables have shape ``[max_seq_len, dim // 2]``. They
    are non-persistent buffers: they follow the module across ``.to()`` calls
    but are not written to the state dict.
    """

    def __init__(self, dim: int, max_seq_len: int, base: float = 10_000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.outer(positions, inv_freq)
        self.register_buffer("cos", freqs.cos(), persistent=False)
        self.register_buffer("sin", freqs.sin(), persistent=False)

    def forward(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cos/sin rows for positions ``0 .. seq_len - 1``."""
        return self.cos[:seq_len], self.sin[:seq_len]


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)."""
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to a 4-D ``x``.

    ``cos``/``sin`` are ``[seq, head_dim // 2]`` tables. They are expanded so
    each interleaved channel pair shares one angle, then broadcast over the
    first two axes of ``x``. In this module ``x`` is
    ``[batch, heads, seq, head_dim]``.
    """
    cos = torch.repeat_interleave(cos, 2, dim=-1)[None, None, :, :]
    sin = torch.repeat_interleave(sin, 2, dim=-1)[None, None, :, :]
    return (x * cos) + (_rotate_half(x) * sin)


def _safe_torch_load(path: str | Path, map_location: str | torch.device | None) -> Any:
    """Load ``path`` with ``torch.load(..., weights_only=True)`` where supported.

    Torch releases without the ``weights_only`` keyword are expected to raise
    ``TypeError``. In that case the file is loaded with a plain ``torch.load``.

    NOTE(review): the fallback unpickles arbitrary objects. Only use it on
    checkpoints from trusted sources.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=map_location)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embedding."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        if self.head_dim % 2 != 0:
            raise ValueError("attention head_dim must be even for rotary embeddings")
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = cfg.dropout
        self.rotary = RotaryEmbedding(self.head_dim, cfg.max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``[batch, seq, d_model]``; returns the same shape."""
        bsz, seq_len, width = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # [batch, seq, d_model] -> [batch, heads, seq, head_dim]
        q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary(seq_len)
        # Match q/k's dtype as well as device. Otherwise float32 tables would
        # promote low-precision q/k (e.g. under autocast) to float32 while v
        # stays low precision, and SDPA expects q, k and v to share a dtype.
        cos = cos.to(device=q.device, dtype=q.dtype)
        sin = sin.to(device=q.device, dtype=q.dtype)
        q = apply_rotary(q, cos, sin)
        k = apply_rotary(k, cos, sin)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, width)
        return self.out(y)


class SwiGLU(nn.Module):
    """Gated feed-forward: ``dropout(w2(silu(w1(x)) * w3(x)))``."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.ffn_dim, bias=False)
        self.w2 = nn.Linear(cfg.ffn_dim, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, cfg.ffn_dim, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class RexBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by SwiGLU sub-layer."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.attn = CausalSelfAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x


class RexForCausalLM(nn.Module):
    """Decoder-only LM with a stack of blocks reused across recursive passes."""

    def __init__(self, cfg: RexConfig):
        super().__init__()
        if cfg.recurrence_steps < 1:
            raise ValueError("recurrence_steps must be >= 1")
        self.cfg = cfg
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        # These n_layers blocks are re-applied on every recursive pass.
        self.blocks = nn.ModuleList([RexBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # Input embedding and output projection share one weight tensor.
            self.lm_head.weight = self.token_embedding.weight
        if cfg.use_step_embeddings:
            # One learned offset per recursive pass, starting at zero.
            self.step_embedding = nn.Parameter(
                torch.zeros(cfg.recurrence_steps, cfg.d_model)
            )
        else:
            self.register_parameter("step_embedding", None)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding weights with N(0, initializer_range) and zero Linear biases.

        RMSNorm scales keep their ones init; ``step_embedding`` is not a
        module, so it keeps its zero init.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.cfg.initializer_range)

    def encode(self, input_ids: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Return contextual token representations for downstream tasks.

        Args:
            input_ids: ``[batch, seq]`` token ids with ``seq <= max_seq_len``.
            normalize: apply ``final_norm`` to the output.

        Returns:
            ``[batch, seq, d_model]`` hidden states after ``recurrence_steps``
            passes through the shared block stack.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        if input_ids.ndim != 2:
            raise ValueError("input_ids must have shape [batch, seq]")
        if input_ids.size(1) > self.cfg.max_seq_len:
            raise ValueError(f"sequence length exceeds max_seq_len={self.cfg.max_seq_len}")
        x = self.drop(self.token_embedding(input_ids))
        for step in range(self.cfg.recurrence_steps):
            if self.step_embedding is not None:
                # Tell the shared blocks which recursive pass they are in.
                x = x + self.step_embedding[step].view(1, 1, -1)
            for block in self.blocks:
                x = block(x)
        if normalize:
            x = self.final_norm(x)
        return x

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        ``labels`` is aligned with ``input_ids``; the one-position shift is
        done here (logits at ``t`` are scored against ``labels[t + 1]``).
        Label value ``-100`` is ignored.

        Returns:
            ``{"logits": [batch, seq, vocab_size], "loss": scalar or None}``.
        """
        hidden_states = self.encode(input_ids, normalize=True)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
        no_repeat_ngram_size: int = 0,
    ) -> torch.Tensor:
        """Autoregressively extend ``input_ids`` by ``max_new_tokens`` tokens.

        Args:
            input_ids: ``[batch, seq]`` prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: ``0`` selects greedy decoding. Otherwise logits are
                divided by it before sampling.
            top_k: if set, sample only among the ``top_k`` highest logits.
            no_repeat_ngram_size: if > 0, ban tokens that would repeat an
                n-gram of this size already present in the full (untruncated)
                sequence.

        Returns:
            The prompt followed by the generated ids, ``[batch, seq + max_new_tokens]``.

        Raises:
            ValueError: if ``no_repeat_ngram_size`` or ``temperature`` is
                negative, or ``top_k`` is given and < 1.

        The model runs in eval mode while generating; its previous train/eval
        mode is restored afterwards. There is no KV cache: every step
        re-encodes the last ``max_seq_len`` tokens.
        """
        # Validate once, up front, rather than after the first forward pass.
        if no_repeat_ngram_size < 0:
            raise ValueError("no_repeat_ngram_size must be >= 0")
        if temperature < 0:
            raise ValueError("temperature must be >= 0")
        if top_k is not None and top_k < 1:
            raise ValueError("top_k must be >= 1")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                context = input_ids[:, -self.cfg.max_seq_len :]
                # Only the last position is needed: project that row alone
                # instead of materialising [batch, seq, vocab] logits.
                hidden = self.encode(context, normalize=True)
                logits = self.lm_head(hidden[:, -1, :])
                logits = self._apply_no_repeat_ngram(logits, input_ids, no_repeat_ngram_size)
                next_token = self._select_next_token(logits, temperature, top_k)
                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            # Don't leave a model that was training stuck in eval mode.
            self.train(was_training)
        return input_ids

    @staticmethod
    def _select_next_token(
        logits: torch.Tensor,
        temperature: float,
        top_k: int | None,
    ) -> torch.Tensor:
        """Pick one id per row from ``[batch, vocab]`` logits; returns ``[batch, 1]``."""
        if temperature == 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        if top_k is not None:
            values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Mask everything below the k-th largest logit of each row.
            logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @staticmethod
    def _apply_no_repeat_ngram(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        no_repeat_ngram_size: int,
    ) -> torch.Tensor:
        """Return a copy of ``logits`` with banned n-gram continuations set to ``-inf``.

        Returns ``logits`` itself, unchanged, when ``no_repeat_ngram_size <= 0``.
        """
        if no_repeat_ngram_size <= 0:
            return logits
        logits = logits.clone()
        for batch_idx in range(input_ids.size(0)):
            banned_tokens = RexForCausalLM._get_banned_ngram_tokens(
                input_ids[batch_idx].tolist(),
                no_repeat_ngram_size,
            )
            if banned_tokens:
                logits[batch_idx, banned_tokens] = float("-inf")
        return logits

    @staticmethod
    def _get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
        """Return tokens that would complete an ``ngram_size``-gram already in ``tokens``.

        The current prefix is the last ``ngram_size - 1`` tokens. With
        ``ngram_size == 1`` every previously seen token is banned.
        """
        if ngram_size == 1:
            return list(set(tokens))
        if len(tokens) < ngram_size - 1:
            return []
        prefix_to_next: dict[tuple[int, ...], set[int]] = {}
        for i in range(len(tokens) - ngram_size + 1):
            ngram = tokens[i : i + ngram_size]
            prefix = tuple(ngram[:-1])
            prefix_to_next.setdefault(prefix, set()).add(ngram[-1])
        current_prefix = tuple(tokens[-(ngram_size - 1) :])
        return list(prefix_to_next.get(current_prefix, set()))

    def parameter_count(self, trainable_only: bool = False) -> int:
        """Return the number of parameter elements (only ``requires_grad`` ones if asked)."""
        params = self.parameters()
        if trainable_only:
            params = (p for p in params if p.requires_grad)
        return sum(p.numel() for p in params)

    def save_pretrained(self, save_directory: str | Path) -> None:
        """Save model weights and config in a lightweight HF-style folder.

        Writes ``config.json`` (the config dict) and ``pytorch_model.bin``
        (the state dict), creating the directory if needed.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "config.json", "w", encoding="utf-8") as f:
            json.dump(self.cfg.to_dict(), f, indent=2)
            f.write("\n")
        torch.save(self.state_dict(), save_path / "pytorch_model.bin")

    @classmethod
    def from_pretrained(
        cls,
        load_directory: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load a model saved by ``save_pretrained``."""
        load_path = Path(load_directory)
        with open(load_path / "config.json", "r", encoding="utf-8") as f:
            cfg = RexConfig.from_dict(json.load(f))
        model = cls(cfg)
        state_dict = _safe_torch_load(load_path / "pytorch_model.bin", map_location)
        model.load_state_dict(state_dict, strict=strict)
        return model

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = "cpu",
        strict: bool = True,
    ) -> "RexForCausalLM":
        """Load from a training checkpoint produced by ``train.py``.

        The model config is read from ``checkpoint["model_config"]``, falling
        back to ``checkpoint["config"]["model"]``. Weights come from
        ``checkpoint["model"]``. If that key is absent, the checkpoint itself
        is treated as the state dict, minus its config entries.

        Raises:
            ValueError: if the checkpoint is not a dict or carries no model config.
        """
        checkpoint = _safe_torch_load(checkpoint_path, map_location)
        if not isinstance(checkpoint, dict):
            raise ValueError("checkpoint must be a dict")
        cfg_data = checkpoint.get("model_config")
        if cfg_data is None:
            train_cfg = checkpoint.get("config")
            if isinstance(train_cfg, dict):
                cfg_data = train_cfg.get("model")
        if cfg_data is None:
            raise ValueError("checkpoint does not contain model_config or config.model")
        model = cls(RexConfig.from_dict(cfg_data))
        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            # Flat checkpoint: drop the config entries, otherwise strict
            # loading rejects them as unexpected keys.
            state_dict = {
                k: v for k, v in checkpoint.items() if k not in ("model_config", "config")
            }
        model.load_state_dict(state_dict, strict=strict)
        return model


def build_model(config: dict[str, Any] | RexConfig | None = None) -> RexForCausalLM:
    """Construct a model from a config dict, a ``RexConfig``, or the defaults."""
    if config is None:
        cfg = RexConfig()
    elif isinstance(config, RexConfig):
        cfg = config
    else:
        cfg = RexConfig.from_dict(config)
    return RexForCausalLM(cfg)