"""Bidirectional transformer encoder with rotary embeddings and a masked-LM head.

The classes here are packaged as Hugging Face ``PreTrainedModel`` subclasses.
"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput

from .configuration_theo_bert_base import TheoBertBaseConfig
from .muon import Muon


def norm(x: torch.Tensor) -> torch.Tensor:
    """Apply weight-free RMS normalization over the last dimension of ``x``."""
    return F.rms_norm(x, (x.size(-1),))


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension of ``x`` by ``cos``/``sin``.

    The last dimension is split at its midpoint into ``x1`` and ``x2``; the
    result is ``[x1*cos + x2*sin, -x1*sin + x2*cos]`` cast back to ``x.dtype``.
    ``cos`` and ``sin`` must broadcast against each half.
    """
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], dim=-1).to(x.dtype)


class RMSNorm(nn.Module):
    """Module wrapper around ``F.rms_norm`` with no learnable weight."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last ``dim`` features."""
        return F.rms_norm(x, (self.dim,))


class SelfAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary embeddings and QK norm."""

    def __init__(self, config: TheoBertBaseConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        D = config.n_embd
        self.c_q = nn.Linear(D, D, bias=False)
        self.c_k = nn.Linear(D, D, bias=False)
        self.c_v = nn.Linear(D, D, bias=False)
        self.c_proj = nn.Linear(D, D, bias=False)

    def forward(self, x, cos_sin, ve=None, attention_mask=None):
        """Attend over ``x`` and return the projected result.

        Args:
            x: Input of shape ``(B, T, D)``.
            cos_sin: ``(cos, sin)`` pair passed to ``apply_rotary_emb``.
            ve: Optional tensor added to the value projection before it is
                split into heads.
            attention_mask: Optional mask indexed as ``[:, None, None, :]`` and
                cast to bool, so it is expected to be ``(B, T)`` with nonzero
                entries marking positions that may be attended to.

        Returns:
            Tensor of shape ``(B, T, D)``.
        """
        B, T, D = x.shape
        H, Dh = self.n_head, self.head_dim
        q = self.c_q(x).view(B, T, H, Dh)
        k = self.c_k(x).view(B, T, H, Dh)
        v_proj = self.c_v(x)
        if ve is not None:
            v_proj = v_proj + ve
        v = v_proj.view(B, T, H, Dh)

        cos, sin = cos_sin
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Queries and keys are RMS-normalized after the rotation.
        q = norm(q)
        k = norm(k)

        # (B, T, H, Dh) -> (B, H, T, Dh), the layout SDPA expects.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_mask = None
        if attention_mask is not None:
            # Broadcast the per-key mask over heads and query positions.
            attn_mask = attention_mask[:, None, None, :].bool()
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, D))


class MLP(nn.Module):
    """Feed-forward block: expand 4x, squared ReLU, project back."""

    def __init__(self, config: TheoBertBaseConfig):
        super().__init__()
        D = config.n_embd
        self.c_fc = nn.Linear(D, 4 * D, bias=False)
        self.c_proj = nn.Linear(4 * D, D, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        return self.c_proj(F.relu(self.c_fc(x)).square())


class Block(nn.Module):
    """Transformer block with scaled residual and a skip from the embedding.

    Even-indexed layers also own a token-indexed value embedding whose
    contribution to the attention values is gated per position.
    """

    def __init__(self, config: TheoBertBaseConfig, layer_idx: int):
        super().__init__()
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)
        self.resid_lambda = nn.Parameter(torch.ones(1))
        self.x0_lambda = nn.Parameter(torch.full((1,), 0.1))
        # NOTE(review): the source formatting lost its indentation; both the
        # value embedding and its gate are assumed to exist only on even
        # layers. Confirm against the checkpoint's parameter keys.
        if layer_idx % 2 == 0:
            self.value_embed = nn.Embedding(config.vocab_size, config.n_embd)
            # The gate reads the first 32 channels; cap at n_embd so models
            # narrower than 32 still produce a matching input slice.
            self.ve_gate = nn.Linear(min(32, config.n_embd), 1, bias=False)

    def forward(self, x, cos_sin, x0, token_ids, attention_mask=None):
        """Run attention and MLP sub-layers, then mix with ``x0``.

        Args:
            x: Current hidden state.
            cos_sin: ``(cos, sin)`` pair forwarded to the attention layer.
            x0: Tensor blended into the output, scaled by ``x0_lambda``.
            token_ids: Token ids used to look up the value embedding.
            attention_mask: Optional mask forwarded to the attention layer.

        Returns:
            ``resid_lambda * x + x0_lambda * x0`` after both sub-layers.
        """
        normed = norm(x)
        ve = None
        if hasattr(self, "value_embed"):
            raw_ve = self.value_embed(token_ids)
            gate_in = normed[..., : self.ve_gate.in_features]
            # 2 * sigmoid keeps the gate in (0, 2); it equals 1 at zero input.
            gate = 2 * torch.sigmoid(self.ve_gate(gate_in))
            ve = gate * raw_ve
        x = x + self.attn(normed, cos_sin, ve=ve, attention_mask=attention_mask)
        x = x + self.mlp(norm(x))
        return self.resid_lambda * x + self.x0_lambda * x0


class TheoBertBasePreTrainedModel(PreTrainedModel):
    """Shared base: config binding, weight init, gradient-checkpointing hook."""

    config_class = TheoBertBaseConfig
    base_model_prefix = "theo_bert_base"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` and ``nn.Embedding`` weights from a normal.

        Linear layers use ``std = fan_in**-0.5 * min(1, sqrt(fan_out / fan_in))``;
        embeddings use ``std = 1.0``. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            fan_in = module.weight.size(1)
            fan_out = module.weight.size(0)
            std = (1.0 / math.sqrt(fan_in)) * min(1.0, math.sqrt(fan_out / fan_in))
            nn.init.normal_(module.weight, std=std)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``use_gradient_checkpointing`` on ``TheoBertBaseModel`` modules."""
        # NOTE(review): this is the (module, value) hook form; newer
        # transformers releases may call this hook differently - confirm
        # against the pinned transformers version.
        if isinstance(module, TheoBertBaseModel):
            module.use_gradient_checkpointing = value


class TheoBertBaseModel(TheoBertBasePreTrainedModel):
    """Encoder stack: token embedding, ``n_layer`` blocks and a final norm."""

    def __init__(self, config: TheoBertBaseConfig):
        super().__init__(config)
        self.use_gradient_checkpointing = False
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
        # Retained on the base model so the exported checkpoint can be consumed by
        # AutoModel and AutoModelForMaskedLM from the same repository without key drift.
        self.mlm_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd, bias=False),
            nn.GELU(),
            RMSNorm(config.n_embd),
            nn.Linear(config.n_embd, config.vocab_size, bias=False),
        )
        self._refresh_rope_cache()
        self.post_init()
        self._post_init_architecture()

    def _post_init_architecture(self):
        """Apply architecture-specific overrides after the generic init.

        Zeroes the final MLM projection and every block's output projections,
        and resets the residual scalars to 1.0 and 0.1.
        """
        nn.init.zeros_(self.mlm_head[-1].weight)
        for block in self.blocks:
            nn.init.zeros_(block.attn.c_proj.weight)
            nn.init.zeros_(block.mlp.c_proj.weight)
            nn.init.ones_(block.resid_lambda)
            block.x0_lambda.data.fill_(0.1)

    def _make_rotary(self, seq_len, head_dim, base=10000, device=None):
        """Build rotary ``cos``/``sin`` tables in float32.

        Args:
            seq_len: Number of positions to tabulate.
            head_dim: Per-head dimension; one frequency per even index.
            base: Exponent base for the inverse frequencies.
            device: Target device; defaults to the device of ``wte.weight``.

        Returns:
            ``(cos, sin)``, each of shape ``(1, seq_len, 1, n_freqs)`` where
            ``n_freqs`` is the number of even indices below ``head_dim``.
        """
        if device is None:
            device = self.wte.weight.device
        inv_freq = 1.0 / (
            base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
        )
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos = freqs.cos()[None, :, None, :]
        sin = freqs.sin()[None, :, None, :]
        return cos, sin

    def _refresh_rope_cache(self):
        """(Re)register non-persistent ``cos``/``sin`` buffers from the config.

        The cache spans ``config.seq_len * config.rope_cache_factor`` positions.
        """
        head_dim = self.config.n_embd // self.config.n_head
        cache_len = self.config.seq_len * self.config.rope_cache_factor
        cos, sin = self._make_rotary(cache_len, head_dim, base=self.config.rope_base)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def mean_pool(self, hidden, mask=None):
        """Average ``hidden`` over dimension 1, optionally weighted by ``mask``.

        Without a mask this is a plain mean. With a mask, positions are
        weighted by the float mask and divided by the mask sum, clamped to at
        least 1 so that an all-zero mask row does not divide by zero.
        """
        if mask is None:
            return hidden.mean(dim=1)
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(1) / m.sum(1).clamp(min=1)

    def setup_optimizers(self, embedding_lr=0.3, matrix_lr=0.02):
        """Create the AdamW and Muon optimizers for this model.

        Block attention/MLP weight matrices go to Muon. Token and value
        embeddings, the MLM head, the residual scalars and the value-embedding
        gates go to AdamW in separate parameter groups with their own learning
        rates. Every group also records its starting rate as ``initial_lr``.

        Args:
            embedding_lr: AdamW learning rate for token and value embeddings.
            matrix_lr: Muon learning rate for block weight matrices.

        Returns:
            ``(adamw, muon)``.
        """
        model_dim = self.config.n_embd
        # Head learning rate is scaled relative to a 768-wide model.
        mlm_head_lr = 0.004 * math.sqrt(768 / model_dim)
        embed_params = list(self.wte.parameters())
        ve_params, ve_gate_params, resid_params, x0_params, matrix_params = [], [], [], [], []
        for block in self.blocks:
            matrix_params += [
                block.attn.c_q.weight,
                block.attn.c_k.weight,
                block.attn.c_v.weight,
                block.attn.c_proj.weight,
                block.mlp.c_fc.weight,
                block.mlp.c_proj.weight,
            ]
            resid_params.append(block.resid_lambda)
            x0_params.append(block.x0_lambda)
            if hasattr(block, "value_embed"):
                ve_params += list(block.value_embed.parameters())
                ve_gate_params += list(block.ve_gate.parameters())

        adamw_groups = [
            {"params": embed_params + ve_params, "lr": embedding_lr},
            {"params": list(self.mlm_head.parameters()), "lr": mlm_head_lr},
            {"params": resid_params, "lr": 0.005},
            {"params": x0_params, "lr": 0.5, "betas": (0.96, 0.95)},
            {"params": ve_gate_params, "lr": 0.004},
        ]
        # Drop groups that ended up empty (e.g. a model with no blocks).
        adamw_groups = [g for g in adamw_groups if g["params"]]
        adamw = torch.optim.AdamW(
            adamw_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True
        )
        muon = Muon(matrix_params, lr=matrix_lr, momentum=0.95)
        for opt in (adamw, muon):
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return adamw, muon

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Encode ``input_ids`` into contextual hidden states.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            attention_mask: Optional mask forwarded to every block.
            output_hidden_states: Whether to also return the embedding output
                and each block's output; defaults to the config value.
            return_dict: Whether to return a ``BaseModelOutput``; defaults to
                ``config.use_return_dict``.
            **kwargs: Accepted and discarded.

        Returns:
            ``BaseModelOutput`` or the tuple ``(last_hidden_state, hidden_states)``.

        Raises:
            ValueError: If ``T`` exceeds the rotary cache length.
        """
        del kwargs
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # Non-persistent rotary buffers are skipped by state_dict load; after
        # from_pretrained's meta→to_empty path they hold uninitialized memory.
        # Refresh once per instance, on the actual parameter device.
        if not getattr(self, "_rope_initialized", False):
            self._refresh_rope_cache()
            self._rope_initialized = True

        B, T = input_ids.shape
        if T > self.cos.size(1):
            raise ValueError(
                f"Input sequence length {T} exceeds rotary cache length {self.cos.size(1)}."
            )
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = norm(self.wte(input_ids))
        # The normalized embedding is fed to every block as its x0 input.
        x0 = x
        hidden_states = () if output_hidden_states else None
        if output_hidden_states:
            hidden_states = hidden_states + (x,)

        if self.training and self.use_gradient_checkpointing:
            from torch.utils.checkpoint import checkpoint

            for block in self.blocks:
                x = checkpoint(
                    block, x, cos_sin, x0, input_ids, attention_mask, use_reentrant=False
                )
                if output_hidden_states:
                    hidden_states = hidden_states + (x,)
        else:
            for block in self.blocks:
                x = block(x, cos_sin, x0, input_ids, attention_mask=attention_mask)
                if output_hidden_states:
                    hidden_states = hidden_states + (x,)

        x = norm(x)
        if not return_dict:
            return (x, hidden_states)
        return BaseModelOutput(last_hidden_state=x, hidden_states=hidden_states)


class TheoBertBaseForMaskedLM(TheoBertBaseModel):
    """Encoder plus masked-language-model logits and optional loss."""

    _keys_to_ignore_on_load_unexpected = [r"cls\..*"]

    def get_output_embeddings(self):
        """Return the final linear layer of the MLM head."""
        return self.mlm_head[-1]

    def set_output_embeddings(self, new_embeddings):
        """Replace the final linear layer of the MLM head."""
        self.mlm_head[-1] = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Compute MLM logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            attention_mask: Optional mask forwarded to the encoder.
            labels: Optional target ids; positions equal to ``-100`` are
                ignored by the cross-entropy loss.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Whether to return a ``MaskedLMOutput``; defaults to
                ``config.use_return_dict``, matching the base model.
            **kwargs: Forwarded to the encoder, which discards them.

        Returns:
            ``MaskedLMOutput``, or a tuple ``(logits, hidden_states)`` prefixed
            with ``loss`` when it was computed.
        """
        # Resolve the default the same way the base forward does, so a config
        # with return_dict disabled yields tuples here as well.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        # Logits are upcast to float32 before the loss.
        logits = self.mlm_head(outputs.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Flatten by the logits' own width so a replaced output layer stays
            # consistent; reshape tolerates non-contiguous labels.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            result = (logits, outputs.hidden_states)
            return ((loss,) + result) if loss is not None else result
        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)