Fill-Mask
Transformers
Safetensors
English
theo_bert_base
masked-language-modeling
bible
theology
christianity
trust-remote-code
custom_code
Eval Results (legacy)
Instructions to use toranb/theo-bert-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use toranb/theo-bert-base with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="toranb/theo-bert-base", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("toranb/theo-bert-base", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput | |
| from .configuration_theo_bert_base import TheoBertBaseConfig | |
| from .muon import Muon | |
| def norm(x: torch.Tensor) -> torch.Tensor: | |
| return F.rms_norm(x, (x.size(-1),)) | |
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension of ``x`` by the given angles.

    The last dimension is split at its midpoint into a lower and an upper half,
    which are combined as ``(lower*cos + upper*sin, -lower*sin + upper*cos)``.

    Args:
        x: Tensor whose last dimension is rotated.
        cos: Cosine table, broadcastable against each half of ``x``.
        sin: Sine table, broadcastable against each half of ``x``.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    half = x.size(-1) // 2
    lower = x[..., :half]
    upper = x[..., half:]
    rotated_lower = lower * cos + upper * sin
    rotated_upper = lower * (-sin) + upper * cos
    rotated = torch.cat([rotated_lower, rotated_upper], dim=-1)
    # The tables may promote the result to a wider dtype; return the input's dtype.
    return rotated.to(x.dtype)
| class RMSNorm(nn.Module): | |
| def __init__(self, dim: int): | |
| super().__init__() | |
| self.dim = dim | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return F.rms_norm(x, (self.dim,)) | |
class SelfAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary embeddings and normalized Q/K."""

    def __init__(self, config: TheoBertBaseConfig):
        """Create the bias-free query, key, value and output projections."""
        super().__init__()
        width = config.n_embd
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        # Creation order is kept (q, k, v, proj) so parameter ordering is unchanged.
        self.c_q = nn.Linear(width, width, bias=False)
        self.c_k = nn.Linear(width, width, bias=False)
        self.c_v = nn.Linear(width, width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)

    def forward(self, x, cos_sin, ve=None, attention_mask=None):
        """Attend over the sequence and project the result back to model width.

        Args:
            x: Input of shape (batch, seq, width).
            cos_sin: Pair ``(cos, sin)`` of rotary tables applied to queries and keys.
            ve: Optional tensor added to the value projection before the head split.
            attention_mask: Optional mask indexed as ``[:, None, None, :]`` and cast
                to bool — assumes a (batch, seq) padding mask; TODO confirm with callers.

        Returns:
            Tensor of shape (batch, seq, width).
        """
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.n_head, self.head_dim)

        values = self.c_v(x)
        if ve is not None:
            values = values + ve
        values = values.view(per_head)

        # Rotary embedding first, then RMS normalization, for both queries and keys.
        cos, sin = cos_sin
        queries = norm(apply_rotary_emb(self.c_q(x).view(per_head), cos, sin))
        keys = norm(apply_rotary_emb(self.c_k(x).view(per_head), cos, sin))

        # Move heads ahead of the sequence axis: (batch, heads, seq, head_dim).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        key_mask = None
        if attention_mask is not None:
            # Broadcast the mask over heads and query positions.
            key_mask = attention_mask[:, None, None, :].bool()

        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=key_mask, is_causal=False
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Bias-free feed-forward layer with 4x expansion and a squared-ReLU activation."""

    def __init__(self, config: TheoBertBaseConfig):
        """Create the up-projection (width -> 4*width) and the down-projection."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = F.relu(self.c_fc(x))
        return self.c_proj(activated.square())
class Block(nn.Module):
    """Transformer layer: attention and MLP residuals, then a learned mix with ``x0``.

    Layers with an even ``layer_idx`` additionally own a token-indexed value
    embedding and a small gate that scales it before it is added to the
    attention values.
    """

    def __init__(self, config: TheoBertBaseConfig, layer_idx: int):
        """Build the sub-layers, the two scalar mixing weights and the optional gate."""
        super().__init__()
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)
        # Scalar weights used in the final mix, starting at 1.0 and 0.1.
        self.resid_lambda = nn.Parameter(torch.ones(1))
        self.x0_lambda = nn.Parameter(torch.full((1,), 0.1))
        owns_value_embed = layer_idx % 2 == 0
        if owns_value_embed:
            self.value_embed = nn.Embedding(config.vocab_size, config.n_embd)
            # The gate reads only the first 32 channels of the normalized input.
            self.ve_gate = nn.Linear(32, 1, bias=False)

    def _gated_value_embedding(self, normed, token_ids):
        """Return the gated value embedding, or ``None`` if this layer has none."""
        if not hasattr(self, "value_embed"):
            return None
        embedded = self.value_embed(token_ids)
        # 2 * sigmoid keeps the gate in (0, 2).
        gate = 2 * torch.sigmoid(self.ve_gate(normed[..., :32]))
        return gate * embedded

    def forward(self, x, cos_sin, x0, token_ids, attention_mask=None):
        """Run the layer.

        Args:
            x: Current hidden states.
            cos_sin: Rotary ``(cos, sin)`` tables passed through to attention.
            x0: Tensor mixed back into the output, scaled by ``x0_lambda``.
            token_ids: Token ids used to look up the value embedding.
            attention_mask: Optional mask passed through to attention.

        Returns:
            ``resid_lambda * x + x0_lambda * x0`` after both residual updates.
        """
        normed = norm(x)
        ve = self._gated_value_embedding(normed, token_ids)
        attn_out = self.attn(normed, cos_sin, ve=ve, attention_mask=attention_mask)
        x = x + attn_out
        x = x + self.mlp(norm(x))
        return self.resid_lambda * x + self.x0_lambda * x0
class TheoBertBasePreTrainedModel(PreTrainedModel):
    """Shared base class: config binding, weight init and checkpointing toggle."""

    config_class = TheoBertBaseConfig
    base_model_prefix = "theo_bert_base"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` and ``nn.Embedding`` weights; leave other modules alone.

        Embeddings use a unit-variance normal. Linear layers use a normal with
        std ``1/sqrt(fan_in)``, further scaled by ``sqrt(fan_out/fan_in)`` when
        the layer has fewer outputs than inputs.
        """
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=1.0)
            return
        if not isinstance(module, nn.Linear):
            return
        # nn.Linear stores its weight as (out_features, in_features).
        fan_out = module.weight.size(0)
        fan_in = module.weight.size(1)
        shrink = min(1.0, math.sqrt(fan_out / fan_in))
        std = (1.0 / math.sqrt(fan_in)) * shrink
        nn.init.normal_(module.weight, std=std)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``use_gradient_checkpointing`` on ``module`` if it is a ``TheoBertBaseModel``."""
        if not isinstance(module, TheoBertBaseModel):
            return
        module.use_gradient_checkpointing = value
class TheoBertBaseModel(TheoBertBasePreTrainedModel):
    """Encoder stack: token embedding, ``n_layer`` blocks and a final RMS norm.

    The MLM head lives on this class (not only on the masked-LM subclass) so a
    single exported checkpoint can be loaded by both AutoModel and
    AutoModelForMaskedLM without key drift.
    """

    def __init__(self, config: TheoBertBaseConfig):
        """Build embeddings, blocks, the MLM head and the rotary cache, then initialize."""
        super().__init__(config)
        self.use_gradient_checkpointing = False
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)])
        # Retained on the base model so the exported checkpoint can be consumed by
        # AutoModel and AutoModelForMaskedLM from the same repository without key drift.
        head_layers = [
            nn.Linear(config.n_embd, config.n_embd, bias=False),
            nn.GELU(),
            RMSNorm(config.n_embd),
            nn.Linear(config.n_embd, config.vocab_size, bias=False),
        ]
        self.mlm_head = nn.Sequential(*head_layers)
        # Order matters: rotary cache, then generic init, then the overrides below.
        self._refresh_rope_cache()
        self.post_init()
        self._post_init_architecture()

    def _post_init_architecture(self):
        """Apply architecture-specific init on top of the generic weight init.

        Zeros the final MLM projection and every block's attention/MLP output
        projection, and resets the per-block mixing scalars to 1.0 and 0.1.
        """
        nn.init.zeros_(self.mlm_head[-1].weight)
        for block in self.blocks:
            for out_proj in (block.attn.c_proj, block.mlp.c_proj):
                nn.init.zeros_(out_proj.weight)
            nn.init.ones_(block.resid_lambda)
            block.x0_lambda.data.fill_(0.1)

    def _make_rotary(self, seq_len, head_dim, base=10000, device=None):
        """Build float32 rotary tables of shape (1, seq_len, 1, head_dim // 2).

        Args:
            seq_len: Number of positions to tabulate.
            head_dim: Per-head width; one frequency per pair of channels.
            base: Base of the geometric frequency progression.
            device: Target device; defaults to the embedding weight's device.

        Returns:
            Tuple ``(cos, sin)``.
        """
        if device is None:
            device = self.wte.weight.device
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        angles = torch.outer(positions, inv_freq)
        cos = angles.cos()[None, :, None, :]
        sin = angles.sin()[None, :, None, :]
        return cos, sin

    def _refresh_rope_cache(self):
        """(Re)register the non-persistent ``cos``/``sin`` buffers from the config."""
        cfg = self.config
        cos, sin = self._make_rotary(
            cfg.seq_len * cfg.rope_cache_factor,
            cfg.n_embd // cfg.n_head,
            base=cfg.rope_base,
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def mean_pool(self, hidden, mask=None):
        """Average ``hidden`` over dimension 1, optionally weighted by ``mask``.

        Args:
            hidden: Hidden states; dimension 1 is the one being averaged.
            mask: Optional mask that gets a trailing axis and is cast to float
                — assumes (batch, seq) with 1 for real tokens; TODO confirm.

        Returns:
            The pooled tensor. With a mask, the divisor is clamped to at least 1
            so a fully masked row does not divide by zero.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).float()
        totals = (hidden * weights).sum(1)
        counts = weights.sum(1).clamp(min=1)
        return totals / counts

    def setup_optimizers(self, embedding_lr=0.3, matrix_lr=0.02):
        """Create the AdamW and Muon optimizers used for training.

        Block projection matrices go to Muon. Embeddings, the MLM head and the
        per-block scalars and gates go to AdamW in separate groups with their
        own learning rates. Each group's ``initial_lr`` is set to its ``lr``.

        Args:
            embedding_lr: AdamW learning rate for token and value embeddings.
            matrix_lr: Muon learning rate for the projection matrices.

        Returns:
            Tuple ``(adamw, muon)``.
        """
        # Head LR shrinks with the square root of width relative to 768.
        mlm_head_lr = 0.004 * math.sqrt(768 / self.config.n_embd)

        matrix_params, resid_params, x0_params = [], [], []
        ve_params, ve_gate_params = [], []
        for block in self.blocks:
            matrix_params.extend(
                [
                    block.attn.c_q.weight,
                    block.attn.c_k.weight,
                    block.attn.c_v.weight,
                    block.attn.c_proj.weight,
                    block.mlp.c_fc.weight,
                    block.mlp.c_proj.weight,
                ]
            )
            resid_params.append(block.resid_lambda)
            x0_params.append(block.x0_lambda)
            if hasattr(block, "value_embed"):
                ve_params.extend(block.value_embed.parameters())
                ve_gate_params.extend(block.ve_gate.parameters())

        candidate_groups = [
            {"params": list(self.wte.parameters()) + ve_params, "lr": embedding_lr},
            {"params": list(self.mlm_head.parameters()), "lr": mlm_head_lr},
            {"params": resid_params, "lr": 0.005},
            {"params": x0_params, "lr": 0.5, "betas": (0.96, 0.95)},
            {"params": ve_gate_params, "lr": 0.004},
        ]
        # Groups with no parameters are dropped before building the optimizer.
        adamw_groups = [group for group in candidate_groups if group["params"]]
        adamw = torch.optim.AdamW(
            adamw_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True
        )
        muon = Muon(matrix_params, lr=matrix_lr, momentum=0.95)

        for optimizer in (adamw, muon):
            for group in optimizer.param_groups:
                group["initial_lr"] = group["lr"]
        return adamw, muon

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Encode ``input_ids`` into normalized hidden states.

        Args:
            input_ids: 2-D tensor of token ids, (batch, seq).
            attention_mask: Optional mask forwarded to every block.
            output_hidden_states: Whether to also return the embedding output and
                each block's output; defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``BaseModelOutput`` rather than a
                tuple; defaults to ``config.use_return_dict``.
            **kwargs: Accepted for call compatibility and discarded.

        Returns:
            ``BaseModelOutput`` or the tuple ``(last_hidden_state, hidden_states)``.

        Raises:
            ValueError: If the sequence is longer than the rotary cache.
        """
        del kwargs
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Non-persistent rotary buffers are skipped by state_dict load; after
        # from_pretrained's meta→to_empty path they hold uninitialized memory.
        # Refresh once per instance, on the actual parameter device.
        if not getattr(self, "_rope_initialized", False):
            self._refresh_rope_cache()
            self._rope_initialized = True

        _, seq_len = input_ids.shape
        if seq_len > self.cos.size(1):
            raise ValueError(
                f"Input sequence length {seq_len} exceeds rotary cache length {self.cos.size(1)}."
            )
        cos_sin = (self.cos[:, :seq_len], self.sin[:, :seq_len])

        x = norm(self.wte(input_ids))
        x0 = x  # every block mixes this normalized embedding back in
        collected = [x] if output_hidden_states else []

        checkpointing = self.training and self.use_gradient_checkpointing
        if checkpointing:
            from torch.utils.checkpoint import checkpoint

        for block in self.blocks:
            if checkpointing:
                x = checkpoint(
                    block, x, cos_sin, x0, input_ids, attention_mask, use_reentrant=False
                )
            else:
                x = block(x, cos_sin, x0, input_ids, attention_mask=attention_mask)
            if output_hidden_states:
                collected.append(x)

        hidden_states = tuple(collected) if output_hidden_states else None
        x = norm(x)
        if not return_dict:
            return (x, hidden_states)
        return BaseModelOutput(last_hidden_state=x, hidden_states=hidden_states)
class TheoBertBaseForMaskedLM(TheoBertBaseModel):
    """Masked-language-model wrapper that applies ``mlm_head`` to the encoder output."""

    # Checkpoint keys under a ``cls.`` prefix are ignored when loading.
    _keys_to_ignore_on_load_unexpected = [r"cls\..*"]

    def get_output_embeddings(self):
        """Return the final vocabulary projection of the MLM head."""
        return self.mlm_head[-1]

    def set_output_embeddings(self, new_embeddings):
        """Replace the final vocabulary projection of the MLM head."""
        self.mlm_head[-1] = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Compute vocabulary logits and, when ``labels`` is given, the MLM loss.

        Args:
            input_ids: 2-D tensor of token ids.
            attention_mask: Optional mask forwarded to the encoder.
            labels: Optional target ids; positions equal to ``-100`` are ignored
                by the loss.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Whether to return a ``MaskedLMOutput`` rather than a
                tuple. When ``None`` it follows ``config.use_return_dict``, as the
                encoder's own forward does, and falls back to ``True`` if the
                config does not expose that attribute.
            **kwargs: Forwarded to the encoder forward.

        Returns:
            ``MaskedLMOutput(loss, logits, hidden_states)``, or the tuple
            ``(logits, hidden_states)`` with ``loss`` prepended when it was computed.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        # Always ask the encoder for a dict so fields can be read by name.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        # Logits are cast to float32 before the loss.
        logits = self.mlm_head(outputs.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Flatten by the actual logits width, and use reshape so that
            # non-contiguous label tensors are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            result = (logits, outputs.hidden_states)
            return ((loss,) + result) if loss is not None else result
        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)