# theo-bert-base / modeling_theo_bert_base.py
# Uploaded by: toranb
# Initial release: TheoBERT Base — biblical-domain masked language model
# Commit: a64c547
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
from .configuration_theo_bert_base import TheoBertBaseConfig
from .muon import Muon
def norm(x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension of ``x`` by the given angles.

    The last dimension is split at its midpoint into a lower and an upper half.
    The halves are mixed as ``lower * cos + upper * sin`` and
    ``-lower * sin + upper * cos`` and concatenated back in that order.

    Args:
        x: Tensor whose last dimension is split in two.
        cos: Cosine table; must broadcast against each half of ``x``.
        sin: Sine table; must broadcast against each half of ``x``.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    rotated_lower = lower * cos + upper * sin
    rotated_upper = lower * (-sin) + upper * cos
    rotated = torch.cat((rotated_lower, rotated_upper), dim=-1)
    # Multiplying by float32 tables may promote the result; restore the input dtype.
    return rotated.to(x.dtype)
class RMSNorm(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, (self.dim,))
class SelfAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary embeddings and QK normalization.

    Queries and keys are rotated with the supplied cos/sin tables and then
    RMS-normalized per head. An optional extra tensor can be added to the value
    projection before it is split into heads. All projections are bias-free.
    """

    def __init__(self, config: TheoBertBaseConfig):
        """Build the four bias-free projections.

        Raises:
            ValueError: If ``config.n_embd`` is not divisible by
                ``config.n_head``, or if the resulting head width is odd (the
                rotary embedding splits each head into two equal halves).
        """
        super().__init__()
        D = config.n_embd
        # Fail early with a clear message instead of an opaque shape error in forward().
        if D % config.n_head != 0:
            raise ValueError(
                f"n_embd ({D}) must be divisible by n_head ({config.n_head})."
            )
        self.n_head = config.n_head
        self.head_dim = D // config.n_head
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim ({self.head_dim}) must be even for rotary embeddings."
            )
        self.c_q = nn.Linear(D, D, bias=False)
        self.c_k = nn.Linear(D, D, bias=False)
        self.c_v = nn.Linear(D, D, bias=False)
        self.c_proj = nn.Linear(D, D, bias=False)

    def forward(self, x, cos_sin, ve=None, attention_mask=None):
        """Run attention over ``x``.

        Args:
            x: Input whose shape is unpacked as ``(B, T, D)``.
            cos_sin: Pair ``(cos, sin)`` forwarded to ``apply_rotary_emb``.
            ve: Optional tensor added to the value projection before it is
                reshaped into heads; it must broadcast against ``(B, T, D)``.
            attention_mask: Optional mask indexed as ``[:, None, None, :]`` and
                cast to bool, so presumably ``(B, T)`` with nonzero entries
                marking key positions that may be attended to (verify against
                the caller).

        Returns:
            Tensor of shape ``(B, T, D)`` after the output projection.
        """
        B, T, D = x.shape
        H, Dh = self.n_head, self.head_dim
        q = self.c_q(x).view(B, T, H, Dh)
        k = self.c_k(x).view(B, T, H, Dh)
        v_proj = self.c_v(x)
        if ve is not None:
            v_proj = v_proj + ve
        v = v_proj.view(B, T, H, Dh)

        cos, sin = cos_sin
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Queries and keys are normalized after the rotation.
        q = norm(q)
        k = norm(k)

        # Move the head axis ahead of the sequence axis for scaled_dot_product_attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_mask = None
        if attention_mask is not None:
            # Insert singleton head and query axes so the mask broadcasts over them.
            attn_mask = attention_mask[:, None, None, :].bool()
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, D))
class MLP(nn.Module):
    """Bias-free feed-forward block with a 4x hidden width and squared-ReLU activation."""

    def __init__(self, config: TheoBertBaseConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)

    def forward(self, x):
        """Expand ``x``, apply ``relu(.) ** 2`` element-wise, and project back."""
        activated = F.relu(self.c_fc(x))
        return self.c_proj(activated.square())
class Block(nn.Module):
    """One transformer layer: attention and MLP residual branches plus an x0 skip.

    Layers with an even ``layer_idx`` additionally own a token-indexed value
    embedding and a small gate; the gated embedding is handed to the attention
    module, which adds it to its value projection.
    """

    def __init__(self, config: TheoBertBaseConfig, layer_idx: int):
        super().__init__()
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)
        # Learnable scalars that mix the block output with the initial embedding x0.
        self.resid_lambda = nn.Parameter(torch.ones(1))
        self.x0_lambda = nn.Parameter(torch.full((1,), 0.1))
        is_even_layer = layer_idx % 2 == 0
        if is_even_layer:
            self.value_embed = nn.Embedding(config.vocab_size, config.n_embd)
            # The gate reads only the first 32 channels of the normalized input.
            self.ve_gate = nn.Linear(32, 1, bias=False)

    def forward(self, x, cos_sin, x0, token_ids, attention_mask=None):
        """Apply the layer.

        Args:
            x: Current hidden state.
            cos_sin: Rotary tables forwarded to the attention module.
            x0: Tensor blended into the output, scaled by ``x0_lambda``.
            token_ids: Token indices used to look up the value embedding.
            attention_mask: Optional mask forwarded to the attention module.

        Returns:
            ``resid_lambda * x_out + x0_lambda * x0`` where ``x_out`` is the
            state after both residual branches.
        """
        normed = norm(x)

        value_embed = getattr(self, "value_embed", None)
        if value_embed is None:
            ve = None
        else:
            # 2 * sigmoid keeps the gate in (0, 2).
            gate = 2 * torch.sigmoid(self.ve_gate(normed[..., :32]))
            ve = gate * value_embed(token_ids)

        attn_out = self.attn(normed, cos_sin, ve=ve, attention_mask=attention_mask)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        x = x + mlp_out
        return self.resid_lambda * x + self.x0_lambda * x0
class TheoBertBasePreTrainedModel(PreTrainedModel):
    """Shared base class holding the config binding, weight init and checkpointing hook."""

    config_class = TheoBertBaseConfig
    base_model_prefix = "theo_bert_base"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` and ``nn.Embedding`` weights; leave other modules untouched.

        Linear weights are drawn from a normal distribution with standard
        deviation ``(1 / sqrt(fan_in)) * min(1, sqrt(fan_out / fan_in))``.
        Embedding weights are drawn with standard deviation 1.
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            base_scale = 1.0 / math.sqrt(fan_in)
            shrink = min(1.0, math.sqrt(fan_out / fan_in))
            nn.init.normal_(module.weight, std=base_scale * shrink)
            return
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the ``use_gradient_checkpointing`` flag on ``TheoBertBaseModel`` instances."""
        if not isinstance(module, TheoBertBaseModel):
            return
        module.use_gradient_checkpointing = value
class TheoBertBaseModel(TheoBertBasePreTrainedModel):
    """Encoder stack: token embedding, a list of ``Block`` layers and a final RMS norm.

    The masked-LM head is also owned by this class (see the note in
    ``__init__``), although ``forward`` here does not apply it.
    """

    def __init__(self, config: TheoBertBaseConfig):
        super().__init__(config)
        self.use_gradient_checkpointing = False
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        layers = [Block(config, idx) for idx in range(config.n_layer)]
        self.blocks = nn.ModuleList(layers)
        # Retained on the base model so the exported checkpoint can be consumed by
        # AutoModel and AutoModelForMaskedLM from the same repository without key drift.
        self.mlm_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd, bias=False),
            nn.GELU(),
            RMSNorm(config.n_embd),
            nn.Linear(config.n_embd, config.vocab_size, bias=False),
        )
        self._refresh_rope_cache()
        self.post_init()
        # Runs after post_init() so these overrides are applied last during construction.
        self._post_init_architecture()

    def _post_init_architecture(self):
        """Zero the output projections and reset the per-block mixing scalars."""
        nn.init.zeros_(self.mlm_head[-1].weight)
        for block in self.blocks:
            for projection in (block.attn.c_proj, block.mlp.c_proj):
                nn.init.zeros_(projection.weight)
            nn.init.ones_(block.resid_lambda)
            block.x0_lambda.data.fill_(0.1)

    def _make_rotary(self, seq_len, head_dim, base=10000, device=None):
        """Build float32 cos/sin tables of shape ``(1, seq_len, 1, head_dim // 2)``.

        Args:
            seq_len: Number of positions to tabulate.
            head_dim: Per-head width; every second index defines one frequency.
            base: Base of the geometric frequency progression.
            device: Target device; defaults to the device of the token
                embedding weight.
        """
        target_device = self.wte.weight.device if device is None else device
        exponents = (
            torch.arange(0, head_dim, 2, dtype=torch.float32, device=target_device) / head_dim
        )
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(seq_len, dtype=torch.float32, device=target_device)
        angles = torch.outer(positions, inv_freq)
        # Add singleton batch and head axes so the tables broadcast over (B, T, H, Dh/2).
        cos = angles.cos()[None, :, None, :]
        sin = angles.sin()[None, :, None, :]
        return cos, sin

    def _refresh_rope_cache(self):
        """(Re)register the non-persistent ``cos`` / ``sin`` buffers from the config."""
        cfg = self.config
        per_head_width = cfg.n_embd // cfg.n_head
        table_len = cfg.seq_len * cfg.rope_cache_factor
        cos, sin = self._make_rotary(table_len, per_head_width, base=cfg.rope_base)
        for buffer_name, table in (("cos", cos), ("sin", sin)):
            self.register_buffer(buffer_name, table, persistent=False)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def mean_pool(self, hidden, mask=None):
        """Average ``hidden`` over dimension 1, optionally weighted by ``mask``.

        Without a mask this is a plain mean. With a mask, positions are
        weighted by the float-cast mask and the divisor is clamped to at least
        1 so all-zero rows do not divide by zero.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).float()
        weighted_sum = (hidden * weights).sum(1)
        token_count = weights.sum(1).clamp(min=1)
        return weighted_sum / token_count

    def setup_optimizers(self, embedding_lr=0.3, matrix_lr=0.02):
        """Create the AdamW and Muon optimizers used to train this model.

        The attention and MLP weight matrices of every block go to Muon. The
        token and value embeddings, masked-LM head, per-block scalars and
        value-embedding gates go to AdamW in separate groups, each with its own
        learning rate. Every group also records its starting rate under
        ``"initial_lr"``.

        Returns:
            Tuple ``(adamw, muon)``.
        """
        # The head learning rate is scaled relative to a reference width of 768.
        mlm_head_lr = 0.004 * math.sqrt(768 / self.config.n_embd)

        matrix_params, resid_params, x0_params = [], [], []
        ve_params, ve_gate_params = [], []
        for block in self.blocks:
            attn, mlp = block.attn, block.mlp
            matrix_params.extend(
                layer.weight
                for layer in (attn.c_q, attn.c_k, attn.c_v, attn.c_proj, mlp.c_fc, mlp.c_proj)
            )
            resid_params.append(block.resid_lambda)
            x0_params.append(block.x0_lambda)
            if hasattr(block, "value_embed"):
                ve_params.extend(block.value_embed.parameters())
                ve_gate_params.extend(block.ve_gate.parameters())

        candidate_groups = [
            {"params": list(self.wte.parameters()) + ve_params, "lr": embedding_lr},
            {"params": list(self.mlm_head.parameters()), "lr": mlm_head_lr},
            {"params": resid_params, "lr": 0.005},
            {"params": x0_params, "lr": 0.5, "betas": (0.96, 0.95)},
            {"params": ve_gate_params, "lr": 0.004},
        ]
        # Groups with no parameters are dropped before building the optimizer.
        adamw_groups = [group for group in candidate_groups if group["params"]]
        adamw = torch.optim.AdamW(
            adamw_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True
        )
        muon = Muon(matrix_params, lr=matrix_lr, momentum=0.95)

        for optimizer in (adamw, muon):
            for group in optimizer.param_groups:
                group["initial_lr"] = group["lr"]
        return adamw, muon

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Encode ``input_ids`` and return the final normalized hidden state.

        Args:
            input_ids: Token indices; the shape is unpacked as two dimensions.
            attention_mask: Optional mask forwarded unchanged to every block.
            output_hidden_states: Whether to also return the embedding output
                and each block's output (collected before the final norm).
                Defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``BaseModelOutput`` instead of a
                tuple. Defaults to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the sequence is longer than the rotary cache.
        """
        del kwargs
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Non-persistent rotary buffers are skipped by state_dict load; after
        # from_pretrained's meta→to_empty path they hold uninitialized memory.
        # Refresh once per instance, on the actual parameter device.
        if not getattr(self, "_rope_initialized", False):
            self._refresh_rope_cache()
            self._rope_initialized = True

        _, T = input_ids.shape
        if T > self.cos.size(1):
            raise ValueError(
                f"Input sequence length {T} exceeds rotary cache length {self.cos.size(1)}."
            )
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        x = norm(self.wte(input_ids))
        x0 = x
        collected = [x] if output_hidden_states else []

        checkpointing = self.training and self.use_gradient_checkpointing
        if checkpointing:
            from torch.utils.checkpoint import checkpoint

        for block in self.blocks:
            if checkpointing:
                x = checkpoint(
                    block, x, cos_sin, x0, input_ids, attention_mask, use_reentrant=False
                )
            else:
                x = block(x, cos_sin, x0, input_ids, attention_mask=attention_mask)
            if output_hidden_states:
                collected.append(x)

        hidden_states = tuple(collected) if output_hidden_states else None
        x = norm(x)
        if not return_dict:
            return (x, hidden_states)
        return BaseModelOutput(last_hidden_state=x, hidden_states=hidden_states)
class TheoBertBaseForMaskedLM(TheoBertBaseModel):
    """Masked-language-model wrapper that applies ``mlm_head`` to the encoder output."""

    # Checkpoint keys under a ``cls.`` prefix are ignored instead of reported as unexpected.
    _keys_to_ignore_on_load_unexpected = [r"cls\..*"]

    def get_output_embeddings(self):
        """Return the final linear layer of the masked-LM head."""
        return self.mlm_head[-1]

    def set_output_embeddings(self, new_embeddings):
        """Replace the final linear layer of the masked-LM head."""
        self.mlm_head[-1] = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Compute vocabulary logits and, when ``labels`` is given, the MLM loss.

        Args:
            input_ids: Token indices passed to the encoder.
            attention_mask: Optional mask passed to the encoder.
            labels: Optional target ids; positions equal to ``-100`` are
                ignored by the cross-entropy loss.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Whether to return a ``MaskedLMOutput`` instead of a
                tuple. Defaults to ``config.use_return_dict``, matching the
                encoder's own ``forward``.
            **kwargs: Forwarded to the encoder.

        Returns:
            ``MaskedLMOutput`` or, in tuple mode,
            ``(loss, logits, hidden_states)`` with ``loss`` omitted when no
            labels were supplied.
        """
        # Resolve the flag the same way the encoder does so both classes agree
        # when the caller leaves return_dict unset.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The encoder is always asked for a dict so its fields can be read by name.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        # Logits are cast to float32 before the loss is computed.
        logits = self.mlm_head(outputs.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Flatten with reshape (tolerates non-contiguous labels) and use the
            # actual logits width, which stays correct if the output layer was replaced.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            result = (logits, outputs.hidden_states)
            return ((loss,) + result) if loss is not None else result
        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)