| """ |
| H4 Language Model — Transformer LM with H4 geometric attention. |
| |
| Architecture: |
| - Token embedding + golden-angle positional encoding (PhiPositionalEncoding) |
| - N × H4TransformerBlock (H4 attention + FFN) |
| - LM head (Linear to vocab_size) |
| |
| The frozen H4 geometry handles spatial partitioning of attention space. |
| Trainable adapters (nudge matrices, chamber bonuses, projections) learn |
| which directions to query and how to weight chambers. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from h4_hybrid_attention import H4TransformerBlock |
| from utils.phi_positional import PhiPositionalEncoding |
| from bitlinear import BitLinear |
|
|
|
|
| class H4LanguageModel(nn.Module): |
| """ |
| Full language model with H4 polytopic attention. |
| |
| Args: |
| vocab_size: vocabulary size |
| d_model: model dimension |
| n_heads: number of H4 attention heads per layer |
| n_layers: number of transformer blocks |
| d_value: value dimension per head |
| d_ffn: FFN hidden dimension (default: 4 * d_model) |
| top_k: max candidates per query in ChamberTree lookup |
| max_seq_len: max sequence length for positional encoding cache |
| dropout: dropout rate |
| """ |
|
|
| def __init__( |
| self, |
| vocab_size: int, |
| d_model: int = 64, |
| n_heads: int = 8, |
| n_layers: int = 4, |
| d_value: int = 16, |
| d_ffn: int = None, |
| top_k: int = 32, |
| max_seq_len: int = 8192, |
| dropout: float = 0.1, |
| use_bitlinear: bool = False, |
| ): |
| super().__init__() |
| self.vocab_size = vocab_size |
| self.d_model = d_model |
| self.n_layers = n_layers |
| self.use_bitlinear = use_bitlinear |
|
|
| if d_ffn is None: |
| d_ffn = d_model * 4 |
|
|
| |
| self.token_emb = nn.Embedding(vocab_size, d_model) |
| |
| self.emb_scale = math.sqrt(d_model) |
|
|
| |
| self.pos_enc = PhiPositionalEncoding(d_model, max_cached=max_seq_len) |
|
|
| |
| self.emb_dropout = nn.Dropout(dropout) |
|
|
| |
| self.blocks = nn.ModuleList([ |
| H4TransformerBlock( |
| d_model=d_model, |
| n_heads=n_heads, |
| d_value=d_value, |
| d_ffn=d_ffn, |
| top_k=top_k, |
| dropout=dropout, |
| use_bitlinear=use_bitlinear, |
| ) |
| for _ in range(n_layers) |
| ]) |
|
|
| |
| self.ln_f = nn.LayerNorm(d_model) |
|
|
| |
| self.lm_head = nn.Linear(d_model, vocab_size, bias=False) |
| |
| self.lm_head.weight = self.token_emb.weight |
|
|
| self._init_weights() |
|
|
| def _init_weights(self): |
| """Initialize weights following GPT-2 conventions.""" |
| for module in self.modules(): |
| if isinstance(module, BitLinear): |
| |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| elif isinstance(module, nn.Linear): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| if module.bias is not None: |
| torch.nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.Embedding): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| use_tree: bool = True, |
| return_diagnostics: bool = False, |
| ) -> torch.Tensor: |
| """ |
| Args: |
| input_ids: (batch, seq_len) token indices |
| use_tree: if True, use ChamberTree for O(log t) attention |
| return_diagnostics: if True, return (logits, list_of_diag_dicts) |
| |
| Returns: |
| logits: (batch, seq_len, vocab_size) |
| """ |
| B, T = input_ids.shape |
|
|
| |
| tok_emb = self.token_emb(input_ids) * self.emb_scale |
| pos_emb = self.pos_enc(T).unsqueeze(0).to(tok_emb.device) |
| x = self.emb_dropout(tok_emb + pos_emb) |
|
|
| |
| diagnostics = [] |
| for block in self.blocks: |
| if return_diagnostics: |
| x, diag = block(x, use_tree=use_tree, return_diagnostics=True) |
| diagnostics.append(diag) |
| else: |
| x = block(x, use_tree=use_tree) |
|
|
| |
| x = self.ln_f(x) |
| logits = self.lm_head(x) |
|
|
| if return_diagnostics: |
| return logits, diagnostics |
| return logits |
|
|
| def count_params(self): |
| """Count trainable and frozen parameters.""" |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad) |
| buffers = sum(b.numel() for b in self.buffers()) |
| return { |
| 'trainable': trainable, |
| 'frozen': frozen, |
| 'buffers': buffers, |
| 'total': trainable + frozen, |
| } |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| input_ids: torch.Tensor, |
| max_new_tokens: int = 100, |
| temperature: float = 1.0, |
| top_k_sample: int = 0, |
| ) -> torch.Tensor: |
| """Autoregressive generation.""" |
| for _ in range(max_new_tokens): |
| |
| logits = self.forward(input_ids, use_tree=False) |
| logits = logits[:, -1, :] / temperature |
|
|
| if top_k_sample > 0: |
| v, _ = torch.topk(logits, min(top_k_sample, logits.size(-1))) |
| logits[logits < v[:, [-1]]] = float('-inf') |
|
|
| probs = F.softmax(logits, dim=-1) |
| next_id = torch.multinomial(probs, num_samples=1) |
| input_ids = torch.cat([input_ids, next_id], dim=1) |
|
|
| return input_ids |
|
|