# h4-polytopic-attention / python/h4_language_model.py
# Uploaded by grapheneaffiliates via huggingface_hub (commit f742bbc, verified).
"""
H4 Language Model — Transformer LM with H4 geometric attention.
Architecture:
- Token embedding + golden-angle positional encoding (PhiPositionalEncoding)
- N × H4TransformerBlock (H4 attention + FFN)
- LM head (Linear to vocab_size)
The frozen H4 geometry handles spatial partitioning of attention space.
Trainable adapters (nudge matrices, chamber bonuses, projections) learn
which directions to query and how to weight chambers.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from h4_hybrid_attention import H4TransformerBlock
from utils.phi_positional import PhiPositionalEncoding
from bitlinear import BitLinear
class H4LanguageModel(nn.Module):
    """
    Full language model with H4 polytopic attention.

    Architecture: token embedding (scaled by sqrt(d_model)) + golden-angle
    positional encoding, a stack of n_layers H4TransformerBlocks, a final
    LayerNorm, and a weight-tied LM head projecting back to the vocabulary.

    Args:
        vocab_size: vocabulary size
        d_model: model dimension
        n_heads: number of H4 attention heads per layer
        n_layers: number of transformer blocks
        d_value: value dimension per head
        d_ffn: FFN hidden dimension (default: 4 * d_model)
        top_k: max candidates per query in ChamberTree lookup
        max_seq_len: max sequence length for positional encoding cache;
            also used as the context-window limit during generation
        dropout: dropout rate
        use_bitlinear: if True, blocks use BitLinear projections
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 64,
        n_heads: int = 8,
        n_layers: int = 4,
        d_value: int = 16,
        d_ffn: int = None,
        top_k: int = 32,
        max_seq_len: int = 8192,
        dropout: float = 0.1,
        use_bitlinear: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.use_bitlinear = use_bitlinear
        # Remember the positional-encoding cache limit so generate() can
        # crop its context window to stay within it (previously the limit
        # was not stored and generation could grow the context past it).
        self.max_seq_len = max_seq_len
        if d_ffn is None:
            d_ffn = d_model * 4
        # Token embedding (always float — lookup table, not a matmul)
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Scale embedding by sqrt(d_model) as in the original transformer
        self.emb_scale = math.sqrt(d_model)
        # Golden-angle positional encoding (caches up to max_seq_len positions)
        self.pos_enc = PhiPositionalEncoding(d_model, max_cached=max_seq_len)
        # Embedding dropout
        self.emb_dropout = nn.Dropout(dropout)
        # Transformer blocks with H4 attention
        self.blocks = nn.ModuleList([
            H4TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_value=d_value,
                d_ffn=d_ffn,
                top_k=top_k,
                dropout=dropout,
                use_bitlinear=use_bitlinear,
            )
            for _ in range(n_layers)
        ])
        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)
        # LM head (tied with token embedding weights — stays float)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: lm_head.weight and token_emb.weight are the SAME
        # tensor from here on. Note this happens before _init_weights(), so
        # both the nn.Linear and nn.Embedding init branches below touch the
        # shared tensor; the embedding init (applied last) is the one that
        # sticks — harmless since both use the same N(0, 0.02).
        self.lm_head.weight = self.token_emb.weight
        self._init_weights()

    def _init_weights(self):
        """Initialize weights following GPT-2 conventions (N(0, 0.02), zero bias)."""
        for module in self.modules():
            # BitLinear is checked before nn.Linear in case it subclasses it;
            # both get the same GPT-2 scale.
            if isinstance(module, BitLinear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        use_tree: bool = True,
        return_diagnostics: bool = False,
    ) -> torch.Tensor:
        """
        Run the full model: embed, apply all blocks, project to logits.

        Args:
            input_ids: (batch, seq_len) token indices
            use_tree: if True, use ChamberTree for O(log t) attention
            return_diagnostics: if True, return (logits, list_of_diag_dicts)
        Returns:
            logits: (batch, seq_len, vocab_size)
            (logits, diagnostics) when return_diagnostics is True, with one
            diagnostics dict per block.
        """
        B, T = input_ids.shape
        # Token + positional embedding.
        tok_emb = self.token_emb(input_ids) * self.emb_scale  # (B, T, D)
        # NOTE(review): assumes pos_enc(T) returns a (T, D) tensor — verify
        # against PhiPositionalEncoding; moved to the embedding's device here.
        pos_emb = self.pos_enc(T).unsqueeze(0).to(tok_emb.device)  # (1, T, D)
        x = self.emb_dropout(tok_emb + pos_emb)
        # Transformer blocks
        diagnostics = []
        for block in self.blocks:
            if return_diagnostics:
                x, diag = block(x, use_tree=use_tree, return_diagnostics=True)
                diagnostics.append(diag)
            else:
                x = block(x, use_tree=use_tree)
        # Final norm + LM head
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if return_diagnostics:
            return logits, diagnostics
        return logits

    def count_params(self):
        """Count trainable, frozen, and buffer parameters.

        Returns:
            dict with keys 'trainable', 'frozen', 'buffers', 'total'
            ('total' excludes buffers, matching the original behavior).
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        buffers = sum(b.numel() for b in self.buffers())
        return {
            'trainable': trainable,
            'frozen': frozen,
            'buffers': buffers,
            'total': trainable + frozen,
        }

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k_sample: int = 0,
    ) -> torch.Tensor:
        """Autoregressive generation by iterative sampling.

        Args:
            input_ids: (batch, seq_len) prompt token indices
            max_new_tokens: number of tokens to append
            temperature: softmax temperature (must be > 0; no greedy path)
            top_k_sample: if > 0, restrict sampling to the k highest logits
        Returns:
            (batch, seq_len + max_new_tokens) tensor including the prompt.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the last max_seq_len tokens. BUG FIX: the
            # original only had a comment promising this crop; without it the
            # context could grow past the positional-encoding cache limit.
            if input_ids.size(1) > self.max_seq_len:
                idx_cond = input_ids[:, -self.max_seq_len:]
            else:
                idx_cond = input_ids
            logits = self.forward(idx_cond, use_tree=False)
            # Only the last position's distribution matters for the next token.
            logits = logits[:, -1, :] / temperature
            if top_k_sample > 0:
                # Mask everything below the k-th largest logit per row.
                v, _ = torch.topk(logits, min(top_k_sample, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)
        return input_ids