| """ |
| model.py — MechanismBase |
| ======================== |
| |
| The transformer decoder implementing P / G → Q. |
| |
| Two configurations: |
| SmallConfig (~10M params) — appropriate for ~200K tokens. |
| Generalizes. Recommended for current corpus. |
| |
| FullConfig (~235M params) — appropriate for ~2M+ tokens. |
| Use after expanding the training corpus. |
| |
| Architecture maps to PL terminology: |
| wte — token embedding: seeds patterns P with initial loaded history |
| wpe — position encoding: adds positional loaded history |
| PropagationBlock — one complete P / G → Q step: |
| attention = gradient family G applied to P |
| residual = loaded history H_P accumulating |
| pre-norm = coherence check before each propagation |
| MLP = reconfiguration toward coherent state |
| ln_f — final coherence check |
| lm_head — output: weight-tied to wte (same carrier in and out) |
| |
| Parameter counts (approximate): |
| SmallConfig: 10.5M params |
| FullConfig: 235.0M params |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
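
# Usage sketch (illustrative only; tokenization and the training loop live in
# the surrounding scripts and are not shown here):
#
#     config = SmallConfig()
#     model = MechanismBase(config)
#     tokens = torch.zeros((1, 1), dtype=torch.long)    # a single seed token id
#     logits, loss = model(tokens)                       # loss is None without targets
#     out = model.generate(tokens, max_new_tokens=32)    # autoregressive sampling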
|
|
|
|
|
|
@dataclass
class SmallConfig:
    """
    ~10M params. Appropriate for 100K–500K tokens.
    This is the working configuration for the current corpus (~200K tokens).
    Trains in ~30 minutes on RTX 4060 Ti.
    Will generalize, not just memorize.
    """
    vocab_size: int = 16384
    n_embd: int = 256
    n_layer: int = 8
    n_head: int = 8
    block_size: int = 256
    dropout: float = 0.1
    name: str = "SmallBase"
|
|
|
|
@dataclass
class MediumConfig:
    """
    ~46M params. Appropriate for 500K–2M tokens.
    Use after expanding generate_data.py to produce more derivation traces.
    Trains in ~2–3 hours on RTX 4060 Ti.
    """
    vocab_size: int = 16384
    n_embd: int = 512
    n_layer: int = 12
    n_head: int = 8
    block_size: int = 256
    dropout: float = 0.1
    name: str = "MediumBase"
|
|
|
|
@dataclass
class FullConfig:
    """
    ~219M params with the lm_head/wte weight tying (~235M if the tied weight
    were counted separately). The full AGI Base V1.
    Appropriate for 2M+ tokens.
    Requires expanding generate_data.py significantly (see comments there).
    Trains in ~6 hours on RTX 4060 Ti when data is sufficient.
    """
    vocab_size: int = 16384
    n_embd: int = 1024
    n_layer: int = 16
    n_head: int = 16
    block_size: int = 256
    dropout: float = 0.1
    name: str = "FullBase"
|
|
|
|
# Default configuration for the current corpus.
MechanismConfig = SmallConfig
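

# Rough analytic sketch of the parameter count implied by a config. It assumes
# PyTorch's standard nn.MultiheadAttention parameterization (fused QKV
# projection plus output projection, both with biases) and counts the tied
# lm_head/wte weight once, matching MechanismBase.count_parameters() below.
def estimate_param_count(config) -> int:
    """Approximate parameter count implied by a config (wte/lm_head tied)."""
    E, V, L, S = config.n_embd, config.vocab_size, config.n_layer, config.block_size
    embeddings = V * E + S * E                    # wte + wpe (lm_head reuses wte)
    ln = 2 * E                                    # one LayerNorm: weight + bias
    attn = (3 * E * E + 3 * E) + (E * E + E)      # fused QKV proj + out proj
    mlp = (E * 4 * E + 4 * E) + (4 * E * E + E)   # Linear(E, 4E) + Linear(4E, E)
    block = 2 * ln + attn + mlp                   # ln1 + attention + ln2 + MLP
    return embeddings + L * block + ln            # all blocks + final ln_f
# e.g. estimate_param_count(SmallConfig()) gives ~10.6M, FullConfig() ~218.6M.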
|
|
|
|
|
|
class PropagationBlock(nn.Module):
    """
    One complete P / G → Q propagation step.

    Attention : gradient family G applied to pattern P
    Residual  : loaded history H_P accumulating
    LayerNorm : coherence threshold check (pre-norm: check BEFORE propagating)
    MLP       : reconfiguration toward coherent state
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = nn.MultiheadAttention(
            config.n_embd,
            config.n_head,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x, attn_mask=None):
        # Pre-norm self-attention: coherence check, then propagate.
        normed = self.ln1(x)
        attn_out, _ = self.attn(
            normed, normed, normed,
            attn_mask=attn_mask,
            need_weights=False,
        )
        # Residual connections accumulate the loaded history.
        x = x + self.drop(attn_out)
        x = x + self.mlp(self.ln2(x))
        return x
|
|
|
|
|
|
class MechanismBase(nn.Module):
    """
    The mechanism instantiated in the weight carrier.

    wte     : token embedding — seeds patterns
    wpe     : position encoding — adds positional loaded history
    h       : propagation blocks
    ln_f    : final coherence check
    lm_head : output (weight-tied to wte)
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList(
            [PropagationBlock(config) for _ in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the same carrier maps tokens in and logits out.
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
|
    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.block_size, \
            f"Sequence length {T} exceeds block_size {self.config.block_size}"

        positions = torch.arange(T, device=idx.device)
        x = self.drop(self.wte(idx) + self.wpe(positions))

        # Causal mask: each position may attend only to itself and earlier positions.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            T, device=idx.device
        )

        for block in self.h:
            x = block(x, attn_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens: int = 200,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
    ):
        """
        Autoregressive generation with temperature + top-k + top-p sampling.
        Expects a single sequence (batch size 1) in `idx`.
        """
        self.eval()  # disable dropout for sampling
        for _ in range(max_new_tokens):
            x = idx[:, -self.config.block_size:]
            logits, _ = self(x, None)
            next_logits = logits[0, -1, :] / temperature

            # Top-k: keep only the k highest-scoring tokens.
            if top_k > 0:
                k = min(top_k, next_logits.size(-1))
                topk_vals, _ = torch.topk(next_logits, k)
                next_logits[next_logits < topk_vals[-1]] = float("-inf")

            # Top-p (nucleus): drop tokens once the cumulative probability of
            # the tokens ranked above them already exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                remove = (cumprobs - sorted_probs) > top_p
                sorted_logits[remove] = float("-inf")
                next_logits = torch.zeros_like(next_logits).scatter_(
                    0, sorted_idx, sorted_logits
                )

            probs = F.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id.unsqueeze(0)], dim=1)

        return idx

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def parameter_summary(self) -> str:
        total = self.count_parameters()
        embed = self.wte.weight.numel()
        lines = [
            f" Configuration: {self.config.name}",
            f" Total params: {total:,}",
            f" Embed params: {embed:,} ({embed/total:.1%} of total)",
            f" n_embd={self.config.n_embd}, "
            f"n_layer={self.config.n_layer}, "
            f"n_head={self.config.n_head}",
        ]
        return "\n".join(lines)
|
|
|
|
if __name__ == "__main__":
    for ConfigClass in [SmallConfig, MediumConfig, FullConfig]:
        config = ConfigClass()
        model = MechanismBase(config)
        print(model.parameter_summary())
        print()
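
    # Quick smoke test: a minimal sketch, assuming only what is defined above.
    # Runs an untrained SmallBase on random token ids to confirm the forward
    # pass and the sampling path execute end to end (the output is meaningless).
    config = SmallConfig()
    model = MechanismBase(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))   # batch of 2, 16 tokens each
    logits, loss = model(tokens, targets=tokens)             # cross-entropy against the inputs, just to exercise the loss path
    print(f"smoke test: logits {tuple(logits.shape)}, loss {loss.item():.3f}")
    sample = model.generate(tokens[:1], max_new_tokens=8)    # generate() expects batch size 1
    print(f"smoke test: generated ids {sample[0].tolist()}")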
|
|