# bdh.py import dataclasses import math from typing import Optional, Tuple, List import torch import torch.nn as nn import torch.nn.functional as F @dataclasses.dataclass class BDHConfig: n_layer: int = 32 n_embd: int = 4096 dropout: float = 0.1 n_head: int = 32 mlp_internal_dim_multiplier: int = 1 vocab_size: int = 256 use_alibi: bool = True use_l1_norm: bool = True relu_threshold: float = 0.0 rotary_embedding: str = "rope" rope_theta: float = 65536.0 use_plasticity: bool = True plasticity_lr: float = 0.01 consolidation_rate: float = 0.01 forget_rate: float = 0.1 use_rho_cache: bool = True def latent_per_head(self) -> int: return self.mlp_internal_dim_multiplier * self.n_embd // self.n_head def latent_total(self) -> int: return self.latent_per_head() * self.n_head class TernaryLinear3D(nn.Module): def __init__(self, n_head: int, in_features: int, out_features: int): super().__init__() self.n_head = n_head self.in_features = in_features self.out_features = out_features self.register_buffer('weight_ternary', torch.zeros(n_head, out_features, in_features, dtype=torch.int8)) self.weight_fp32 = nn.Parameter(torch.zeros(n_head, out_features, in_features)) self.register_buffer('weight_scale', torch.ones(n_head, 1, 1)) self._init_weights() def _init_weights(self): with torch.no_grad(): rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32) self.weight_fp32.data = rand_vals self.weight_ternary.data = rand_vals.to(torch.int8) def forward(self, x: torch.Tensor) -> torch.Tensor: if x.dim() == 4 and x.size(1) == 1: x = x.expand(-1, self.n_head, -1, -1) weight = self.weight_ternary.float() return torch.einsum('bhtd,hnd->bhtn', x, weight) def update_ternary_weights(self): with torch.no_grad(): gamma = self.weight_fp32.abs().mean(dim=(1, 2), keepdim=True).clamp(min=1e-5) self.weight_scale.data = gamma w_scaled = self.weight_fp32 / gamma w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8) self.weight_ternary.data = w_ternary self.weight_fp32.data = w_ternary.float() * gamma class TernaryLinear2D(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.in_features = in_features self.out_features = out_features self.register_buffer('weight_ternary', torch.zeros(out_features, in_features, dtype=torch.int8)) self.weight_fp32 = nn.Parameter(torch.zeros(out_features, in_features)) self.register_buffer('weight_scale', torch.ones(1)) self._init_weights() def _init_weights(self): with torch.no_grad(): rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32) self.weight_fp32.data = rand_vals self.weight_ternary.data = rand_vals.to(torch.int8) def forward(self, x: torch.Tensor) -> torch.Tensor: orig_shape = x.shape if x.dim() == 4: B, _, T, D = x.shape x = x.view(B * T, D) elif x.dim() == 3: B, T, D = x.shape x = x.view(B * T, D) weight = self.weight_ternary.float() out = F.linear(x, weight) if len(orig_shape) == 4: B, _, T, _ = orig_shape out = out.view(B, 1, T, -1) elif len(orig_shape) == 3: B, T, _ = orig_shape out = out.view(B, T, -1) return out def update_ternary_weights(self): with torch.no_grad(): gamma = self.weight_fp32.abs().mean().clamp(min=1e-5) self.weight_scale.data = gamma w_scaled = self.weight_fp32 / gamma w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8) self.weight_ternary.data = w_ternary self.weight_fp32.data = w_ternary.float() * gamma def get_freqs(n: int, theta: float, dtype: torch.dtype, rotary_type: str = "rope") -> torch.Tensor: if rotary_type == "alibi": return torch.zeros(n, dtype=dtype) def quantize(t, q=2): return (t / q).floor() * q indices = torch.arange(0, n, 1, dtype=dtype) if rotary_type == "rope": indices = quantize(indices) return 1.0 / (theta ** (indices / n)) / (2 * math.pi) def row_normalize(scores: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: denom = scores.abs().sum(dim=-1, keepdim=True) + eps return scores / denom class Attention(nn.Module): def __init__(self, config: BDHConfig): super().__init__() self.config = config nh = config.n_head N = config.latent_per_head() self.use_alibi = config.use_alibi self.use_l1_norm = config.use_l1_norm self.rotary_type = config.rotary_embedding freqs = get_freqs(N, config.rope_theta, torch.float32, self.rotary_type) self.register_buffer('freqs', freqs.view(1, 1, 1, N)) if self.use_alibi: slopes = torch.tensor([2 ** (-8 * i / nh) for i in range(1, nh + 1)], dtype=torch.float32) self.register_buffer('alibi_slopes', slopes.view(1, nh, 1, 1)) def _rope(self, phases: torch.Tensor, v: torch.Tensor) -> torch.Tensor: v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size()) phases_cos, phases_sin = torch.cos(phases), torch.sin(phases) return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype) def _rotate(self, v: torch.Tensor, start: int = 0) -> torch.Tensor: if self.rotary_type == "alibi": return v _, _, T, _ = v.size() device = v.device positions = torch.arange(start, start + T, device=device, dtype=self.freqs.dtype).view(1, 1, -1, 1) raw = positions * self.freqs phases = (raw - raw.floor()) * (2 * math.pi) return self._rope(phases, v) def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, start_pos: int = 0) -> torch.Tensor: assert K is Q B, nh, T, N = Q.size() QR = self._rotate(Q, start_pos) KR = QR scores = (QR @ KR.mT).tril(diagonal=-1) if self.use_alibi: pos_row = torch.arange(start_pos, start_pos + T, device=scores.device) pos_col = torch.arange(start_pos, start_pos + T, device=scores.device) alibi = (pos_col.view(1, 1, 1, -1) - pos_row.view(1, 1, -1, 1)).tril(-1) scores = scores + alibi * self.alibi_slopes if self.use_l1_norm: scores = row_normalize(scores) return scores @ V class BDHState: def __init__(self, n_layer: int, n_head: int, latent_dim: int, n_embd: int): self.n_layer = n_layer self.n_head = n_head self.latent_dim = latent_dim self.n_embd = n_embd self.layers: List[dict] = [{'rho': None, 'hidden': None} for _ in range(n_layer)] self.total_position = 0 def get_rho(self, layer_idx: int, batch_size: int, device: torch.device) -> torch.Tensor: rho = self.layers[layer_idx]['rho'] if rho is None: rho = torch.zeros(batch_size, self.n_head, self.latent_dim, self.n_embd, device=device) self.layers[layer_idx]['rho'] = rho return rho def update_rho(self, layer_idx: int, x_latent: torch.Tensor, v: torch.Tensor, decay: float = 1.0): rho = self.layers[layer_idx]['rho'] rho = rho * decay + torch.einsum('bhn,bhd->bhnd', x_latent, v) self.layers[layer_idx]['rho'] = rho def set_hidden(self, layer_idx: int, hidden: torch.Tensor): self.layers[layer_idx]['hidden'] = hidden def get_hidden(self, layer_idx: int) -> Optional[torch.Tensor]: return self.layers[layer_idx]['hidden'] def advance_position(self): self.total_position += 1 class BDH(nn.Module): def __init__(self, config: BDHConfig): super().__init__() self.config = config nh = config.n_head D = config.n_embd N = config.latent_per_head() self.encoder = TernaryLinear3D(nh, D, N) self.encoder_v = TernaryLinear3D(nh, D, N) self.decoder = TernaryLinear2D(nh * N, D) self.attn = Attention(config) self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False) self.embed = nn.Embedding(config.vocab_size, D) self.drop = nn.Dropout(config.dropout) self.lm_head = TernaryLinear2D(D, config.vocab_size) self.relu_threshold = config.relu_threshold self.plasticity = None if config.use_plasticity: from plasticity import UnifiedPlasticity self.plasticity = UnifiedPlasticity( modules=[self.encoder, self.encoder_v, self.decoder, self.lm_head], lr=config.plasticity_lr, consolidation_rate=config.consolidation_rate, forget_rate=config.forget_rate ) def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None, state: Optional[BDHState] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: C = self.config B, T = idx.size() D = C.n_embd nh = C.n_head N = C.latent_per_head() x = self.embed(idx).unsqueeze(1) x = self.ln(x) start_pos = state.total_position if state is not None else 0 for layer_idx in range(C.n_layer): x_latent = self.encoder(x) if self.relu_threshold != 0: x_latent = x_latent - self.relu_threshold x_sparse = F.relu(x_latent) if state is not None and C.use_rho_cache: yKV = self._recurrent_attention(x_sparse, x, state, layer_idx, start_pos) else: yKV = self.attn(Q=x_sparse, K=x_sparse, V=x, start_pos=start_pos) yKV = self.ln(yKV) y_latent = self.encoder_v(yKV) if self.relu_threshold != 0: y_latent = y_latent - self.relu_threshold y_sparse = F.relu(y_latent) xy_sparse = x_sparse * y_sparse xy_sparse = self.drop(xy_sparse) xy_flat = xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) yMLP = self.decoder(xy_flat) y = self.ln(yMLP) x = self.ln(x + y) if state is not None: state.set_hidden(layer_idx, x.clone()) if state is not None: state.advance_position() logits = self.lm_head(x.squeeze(1)) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss def forward_with_states(self, idx: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: C = self.config B, T = idx.size() D = C.n_embd nh = C.n_head N = C.latent_per_head() state = BDHState(C.n_layer, nh, N, D) logits, _ = self.forward(idx, state=state) hidden_states = [] for layer_idx in range(C.n_layer): h = state.get_hidden(layer_idx) if h is not None: hidden_states.append(h.squeeze(1)) return logits, hidden_states def _recurrent_attention(self, Q: torch.Tensor, V: torch.Tensor, state: BDHState, layer_idx: int, start_pos: int) -> torch.Tensor: B, nh, T, N = Q.size() D = V.size(-1) device = Q.device QR = self.attn._rotate(Q, start_pos) outputs = [] for t in range(T): q_t = QR[:, :, t:t+1, :] v_t = V[:, :, t:t+1, :].repeat(1, nh, 1, 1) rho = state.get_rho(layer_idx, B, device) attn_t = (rho * q_t.transpose(-1, -2)).sum(dim=2, keepdim=True) outputs.append(attn_t) state.update_rho(layer_idx, q_t.squeeze(2), v_t.squeeze(2)) return torch.cat(outputs, dim=2) def update_ternary_weights(self): for module in self.modules(): if isinstance(module, (TernaryLinear2D, TernaryLinear3D)): module.update_ternary_weights() if self.plasticity is not None: self.plasticity._update_ternary() def save(self, path: str): torch.save({ 'config': self.config, 'state_dict': self.state_dict() }, path) @classmethod def load(cls, path: str, device: str = 'cpu') -> 'BDH': checkpoint = torch.load(path, map_location=device, weights_only=False) config = checkpoint['config'] model = cls(config).to(device) model.load_state_dict(checkpoint['state_dict']) return model