# Copyright 2025 Pathway Technology, Inc.

import dataclasses
import math

import torch
import torch.nn.functional as F
from torch import nn


@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 6
    n_embd: int = 256
    dropout: float = 0.1
    n_head: int = 4
    mlp_internal_dim_multiplier: int = 128
    vocab_size: int = 256


def get_freqs(n, theta, dtype):
    # Rotary-style inverse frequencies, expressed in cycles (hence the
    # division by 2*pi). quantize() pairs consecutive channels so that
    # channels 2k and 2k+1 share a rotation frequency.
    def quantize(t, q=2):
        return (t / q).floor() * q

    return (
        1.0
        / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n))
        / (2 * math.pi)
    )


class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh
        self.freqs = torch.nn.Buffer(
            get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N)
        )

    @staticmethod
    def phases_cos_sin(phases):
        # Phases are in cycles; reduce mod 1 and convert to radians.
        phases = (phases % 1) * (2 * math.pi)
        phases_cos = torch.cos(phases)
        phases_sin = torch.sin(phases)
        return phases_cos, phases_sin

    @staticmethod
    def rope(phases, v):
        # Standard interleaved rotary embedding: rotate each (even, odd)
        # channel pair by its per-channel phase.
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = Attention.phases_cos_sin(phases)
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def forward(self, Q, K, V):
        assert self.freqs.dtype == torch.float32
        assert K is Q  # queries and keys share one representation
        _, _, T, _ = Q.size()

        # Phase of each position: position index times per-channel frequency.
        r_phases = (
            torch.arange(
                0,
                T,
                device=self.freqs.device,
                dtype=self.freqs.dtype,
            ).view(1, 1, -1, 1)
        ) * self.freqs
        QR = self.rope(r_phases, Q)
        KR = QR

        # Current attention: no softmax (linear attention), and
        # tril(diagonal=-1) keeps only strictly earlier positions.
        scores = (QR @ KR.mT).tril(diagonal=-1)
        return scores @ V


class BDH(nn.Module):
    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # Per-head encoders map the D-dim residual stream to an N-dim latent;
        # the decoder maps the concatenated heads back to D.
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.attn = Attention(config)
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.lm_head = nn.Parameter(
            torch.zeros((D, config.vocab_size)).normal_(std=0.02)
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        x = self.embed(idx).unsqueeze(1)  # B, 1, T, D (singleton head axis)
        # actually helps with training
        x = self.ln(x)  # B, 1, T, D

        # All levels share the same encoder/decoder/attention weights.
        for level in range(C.n_layer):
            x_latent = x @ self.encoder
            x_sparse = F.relu(x_latent)  # B, nh, T, N

            # Attention reads the sparse latent as Q/K but mixes the dense
            # residual stream as V.
            yKV = self.attn(
                Q=x_sparse,
                K=x_sparse,
                V=x,
            )
            yKV = self.ln(yKV)

            y_latent = yKV @ self.encoder_v
            y_sparse = F.relu(y_latent)
            # Elementwise gate: the attention output modulates the input latent.
            xy_sparse = x_sparse * y_sparse  # B, nh, T, N

            xy_sparse = self.drop(xy_sparse)

            # Concatenate heads and project back to the residual stream.
            yMLP = (
                xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder
            )  # B, 1, T, D
            y = self.ln(yMLP)
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        for _ in range(max_new_tokens):
            # No fixed context window, so condition on the full sequence.
            idx_cond = idx
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k largest logits; mask out the rest.
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
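
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# instantiates the model with the default config, runs one supervised forward
# pass, and samples a few tokens. The batch size, sequence length, and
# sampling hyperparameters below are illustrative assumptions, not values
# prescribed by this file.
if __name__ == "__main__":
    config = BDHConfig()
    model = BDH(config)
    model.eval()  # disable dropout for the smoke test

    # Random byte-level batch: B=2 sequences of T=32 tokens.
    idx = torch.randint(0, config.vocab_size, (2, 32))
    targets = torch.randint(0, config.vocab_size, (2, 32))

    logits, loss = model(idx, targets)
    print("logits:", tuple(logits.shape), "loss:", float(loss))

    # Autoregressively sample 16 new tokens from an 8-token prompt.
    out = model.generate(idx[:, :8], max_new_tokens=16, temperature=0.8, top_k=40)
    print("generated:", tuple(out.shape))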