# bdh.py
import dataclasses
import math
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 32
    n_embd: int = 4096
    dropout: float = 0.1
    n_head: int = 32
    mlp_internal_dim_multiplier: int = 1
    vocab_size: int = 256
    use_alibi: bool = True
    use_l1_norm: bool = True
    relu_threshold: float = 0.0
    rotary_embedding: str = "rope"
    rope_theta: float = 65536.0
    use_plasticity: bool = True
    plasticity_lr: float = 0.01
    consolidation_rate: float = 0.01
    forget_rate: float = 0.1
    use_rho_cache: bool = True

    def latent_per_head(self) -> int:
        return self.mlp_internal_dim_multiplier * self.n_embd // self.n_head

    def latent_total(self) -> int:
        return self.latent_per_head() * self.n_head
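
# Worked example for the defaults above (a sanity check, not part of the model):
# latent_per_head() = 1 * 4096 // 32 = 128 latent neurons per head, and
# latent_total()    = 128 * 32      = 4096 latent neurons across all heads.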

class TernaryLinear3D(nn.Module):
    """Per-head linear map with ternary ({-1, 0, +1}) weights, applied as a batched einsum."""

    def __init__(self, n_head: int, in_features: int, out_features: int):
        super().__init__()
        self.n_head = n_head
        self.in_features = in_features
        self.out_features = out_features
        # int8 ternary weights used in the forward pass, an fp32 shadow copy that
        # receives gradient updates, and a per-head quantization scale.
        self.register_buffer('weight_ternary',
                             torch.zeros(n_head, out_features, in_features, dtype=torch.int8))
        self.weight_fp32 = nn.Parameter(torch.zeros(n_head, out_features, in_features))
        self.register_buffer('weight_scale', torch.ones(n_head, 1, 1))
        self._init_weights()

    def _init_weights(self):
        with torch.no_grad():
            rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32)
            self.weight_fp32.data = rand_vals
            self.weight_ternary.data = rand_vals.to(torch.int8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A shared (B, 1, T, D) input is broadcast to every head.
        if x.dim() == 4 and x.size(1) == 1:
            x = x.expand(-1, self.n_head, -1, -1)
        weight = self.weight_ternary.float()
        return torch.einsum('bhtd,hnd->bhtn', x, weight)

    def update_ternary_weights(self):
        # Absmean quantization: scale by the per-head mean |w|, round to {-1, 0, +1},
        # and snap the fp32 shadow weights back onto the quantized grid.
        with torch.no_grad():
            gamma = self.weight_fp32.abs().mean(dim=(1, 2), keepdim=True).clamp(min=1e-5)
            self.weight_scale.data = gamma
            w_scaled = self.weight_fp32 / gamma
            w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8)
            self.weight_ternary.data = w_ternary
            self.weight_fp32.data = w_ternary.float() * gamma

class TernaryLinear2D(nn.Module):
    """Shared (not per-head) linear map with ternary weights; accepts 2D, 3D, or 4D inputs."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer('weight_ternary',
                             torch.zeros(out_features, in_features, dtype=torch.int8))
        self.weight_fp32 = nn.Parameter(torch.zeros(out_features, in_features))
        self.register_buffer('weight_scale', torch.ones(1))
        self._init_weights()

    def _init_weights(self):
        with torch.no_grad():
            rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32)
            self.weight_fp32.data = rand_vals
            self.weight_ternary.data = rand_vals.to(torch.int8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten (B, 1, T, D) or (B, T, D) inputs to (B*T, D) for F.linear,
        # then restore the original leading dimensions.
        orig_shape = x.shape
        if x.dim() == 4:
            B, _, T, D = x.shape
            x = x.reshape(B * T, D)
        elif x.dim() == 3:
            B, T, D = x.shape
            x = x.reshape(B * T, D)
        weight = self.weight_ternary.float()
        out = F.linear(x, weight)
        if len(orig_shape) == 4:
            B, _, T, _ = orig_shape
            out = out.view(B, 1, T, -1)
        elif len(orig_shape) == 3:
            B, T, _ = orig_shape
            out = out.view(B, T, -1)
        return out

    def update_ternary_weights(self):
        # Same absmean scheme as TernaryLinear3D, with a single global scale.
        with torch.no_grad():
            gamma = self.weight_fp32.abs().mean().clamp(min=1e-5)
            self.weight_scale.data = gamma
            w_scaled = self.weight_fp32 / gamma
            w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8)
            self.weight_ternary.data = w_ternary
            self.weight_fp32.data = w_ternary.float() * gamma
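
# Worked example of the absmean update (illustrative numbers, not from a real run):
# for shadow weights [0.40, -0.05, 1.20], gamma = mean|w| = 0.55, so w/gamma is
# [0.73, -0.09, 2.18], which rounds and clamps to the ternary codes [1, 0, 1];
# the fp32 shadow is then snapped to [0.55, 0.00, 0.55] before the next optimizer step.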

def get_freqs(n: int, theta: float, dtype: torch.dtype, rotary_type: str = "rope") -> torch.Tensor:
    # With ALiBi, positions carry no rotary phase, so all frequencies are zero.
    if rotary_type == "alibi":
        return torch.zeros(n, dtype=dtype)

    def quantize(t, q=2):
        return (t / q).floor() * q

    indices = torch.arange(0, n, 1, dtype=dtype)
    if rotary_type == "rope":
        # Snap indices to even values so each (even, odd) channel pair shares a frequency,
        # as required by the pairwise rotation in Attention._rope.
        indices = quantize(indices)
    return 1.0 / (theta ** (indices / n)) / (2 * math.pi)


def row_normalize(scores: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # L1-normalize each attention row; used in place of a softmax.
    denom = scores.abs().sum(dim=-1, keepdim=True) + eps
    return scores / denom

class Attention(nn.Module):
    def __init__(self, config: BDHConfig):
        super().__init__()
        self.config = config
        nh = config.n_head
        N = config.latent_per_head()
        self.use_alibi = config.use_alibi
        self.use_l1_norm = config.use_l1_norm
        self.rotary_type = config.rotary_embedding
        freqs = get_freqs(N, config.rope_theta, torch.float32, self.rotary_type)
        self.register_buffer('freqs', freqs.view(1, 1, 1, N))
        if self.use_alibi:
            slopes = torch.tensor([2 ** (-8 * i / nh) for i in range(1, nh + 1)], dtype=torch.float32)
            self.register_buffer('alibi_slopes', slopes.view(1, nh, 1, 1))

    def _rope(self, phases: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Rotate each (even, odd) channel pair of v by the given phase angles.
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = torch.cos(phases), torch.sin(phases)
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def _rotate(self, v: torch.Tensor, start: int = 0) -> torch.Tensor:
        if self.rotary_type == "alibi":
            return v
        _, _, T, _ = v.size()
        device = v.device
        positions = torch.arange(start, start + T, device=device, dtype=self.freqs.dtype).view(1, 1, -1, 1)
        raw = positions * self.freqs
        phases = (raw - raw.floor()) * (2 * math.pi)
        return self._rope(phases, v)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        # Queries and keys are the same sparse latent code; K is accepted only to keep
        # the usual attention signature.
        assert K is Q
        B, nh, T, N = Q.size()
        QR = self._rotate(Q, start_pos)
        KR = QR
        # Strictly causal scores: tril(diagonal=-1) excludes the current position.
        scores = (QR @ KR.mT).tril(diagonal=-1)
        if self.use_alibi:
            # Additive linear distance penalty, scaled per head.
            pos_row = torch.arange(start_pos, start_pos + T, device=scores.device)
            pos_col = torch.arange(start_pos, start_pos + T, device=scores.device)
            alibi = (pos_col.view(1, 1, 1, -1) - pos_row.view(1, 1, -1, 1)).tril(-1)
            scores = scores + alibi * self.alibi_slopes
        if self.use_l1_norm:
            # No softmax: rows are L1-normalized instead.
            scores = row_normalize(scores)
        return scores @ V
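
# Shape summary for Attention.forward (as used by BDH.forward below): Q and K are the
# sparse latent code of shape (B, n_head, T, N); V is the shared residual stream of
# shape (B, 1, T, D); broadcasting over the head axis makes the output (B, n_head, T, D).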

class BDHState:
    """Per-layer recurrent state: a rho fast-weight matrix per layer (used by
    BDH._recurrent_attention) plus an optionally cached hidden state."""

    def __init__(self, n_layer: int, n_head: int, latent_dim: int, n_embd: int):
        self.n_layer = n_layer
        self.n_head = n_head
        self.latent_dim = latent_dim
        self.n_embd = n_embd
        self.layers: List[dict] = [{'rho': None, 'hidden': None} for _ in range(n_layer)]
        self.total_position = 0

    def get_rho(self, layer_idx: int, batch_size: int, device: torch.device) -> torch.Tensor:
        rho = self.layers[layer_idx]['rho']
        if rho is None:
            rho = torch.zeros(batch_size, self.n_head, self.latent_dim, self.n_embd, device=device)
            self.layers[layer_idx]['rho'] = rho
        return rho

    def update_rho(self, layer_idx: int, x_latent: torch.Tensor, v: torch.Tensor, decay: float = 1.0):
        # rho <- decay * rho + outer(x_latent, v), accumulated per batch element and head.
        rho = self.layers[layer_idx]['rho']
        rho = rho * decay + torch.einsum('bhn,bhd->bhnd', x_latent, v)
        self.layers[layer_idx]['rho'] = rho

    def set_hidden(self, layer_idx: int, hidden: torch.Tensor):
        self.layers[layer_idx]['hidden'] = hidden

    def get_hidden(self, layer_idx: int) -> Optional[torch.Tensor]:
        return self.layers[layer_idx]['hidden']

    def advance_position(self):
        self.total_position += 1
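
# Sketch of incremental decoding with a persistent state (token_stream below is an
# illustrative placeholder; only BDH, BDHConfig, and BDHState come from this file).
# Tokens are fed one step at a time so that advance_position(), which increments by
# one per forward call, tracks the true position:
#
#   cfg = BDHConfig(use_plasticity=False)
#   model = BDH(cfg)
#   state = BDHState(cfg.n_layer, cfg.n_head, cfg.latent_per_head(), cfg.n_embd)
#   for token in token_stream:                 # token: LongTensor of shape (B, 1)
#       logits, _ = model(token, state=state)  # per-layer rho caches grow inside state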

class BDH(nn.Module):
    def __init__(self, config: BDHConfig):
        super().__init__()
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.latent_per_head()
        self.encoder = TernaryLinear3D(nh, D, N)
        self.encoder_v = TernaryLinear3D(nh, D, N)
        self.decoder = TernaryLinear2D(nh * N, D)
        self.attn = Attention(config)
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)
        self.lm_head = TernaryLinear2D(D, config.vocab_size)
        self.relu_threshold = config.relu_threshold
        self.plasticity = None
        if config.use_plasticity:
            from plasticity import UnifiedPlasticity
            self.plasticity = UnifiedPlasticity(
                modules=[self.encoder, self.encoder_v, self.decoder, self.lm_head],
                lr=config.plasticity_lr,
                consolidation_rate=config.consolidation_rate,
                forget_rate=config.forget_rate
            )

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None,
                state: Optional[BDHState] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = C.latent_per_head()
        # Residual stream of shape (B, 1, T, D); the singleton head axis is broadcast
        # by the per-head encoders.
        x = self.embed(idx).unsqueeze(1)
        x = self.ln(x)
        start_pos = state.total_position if state is not None else 0
        for layer_idx in range(C.n_layer):
            # Sparse positive latent code (B, nh, T, N), used as both query and key.
            x_latent = self.encoder(x)
            if self.relu_threshold != 0:
                x_latent = x_latent - self.relu_threshold
            x_sparse = F.relu(x_latent)
            if state is not None and C.use_rho_cache:
                yKV = self._recurrent_attention(x_sparse, x, state, layer_idx, start_pos)
            else:
                yKV = self.attn(Q=x_sparse, K=x_sparse, V=x, start_pos=start_pos)
            yKV = self.ln(yKV)
            y_latent = self.encoder_v(yKV)
            if self.relu_threshold != 0:
                y_latent = y_latent - self.relu_threshold
            y_sparse = F.relu(y_latent)
            # Gate the attention read-out with the input code, then project back to D.
            xy_sparse = x_sparse * y_sparse
            xy_sparse = self.drop(xy_sparse)
            xy_flat = xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh)
            yMLP = self.decoder(xy_flat)
            y = self.ln(yMLP)
            x = self.ln(x + y)
            if state is not None:
                state.set_hidden(layer_idx, x.clone())
        if state is not None:
            state.advance_position()
        logits = self.lm_head(x.squeeze(1))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def forward_with_states(self, idx: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = C.latent_per_head()
        state = BDHState(C.n_layer, nh, N, D)
        logits, _ = self.forward(idx, state=state)
        hidden_states = []
        for layer_idx in range(C.n_layer):
            h = state.get_hidden(layer_idx)
            if h is not None:
                hidden_states.append(h.squeeze(1))
        return logits, hidden_states

    def _recurrent_attention(self, Q: torch.Tensor, V: torch.Tensor, state: BDHState,
                             layer_idx: int, start_pos: int) -> torch.Tensor:
        # Recurrent form of the attention above: rho accumulates outer(q_s, v_s) over past
        # positions, and each step reads rho before writing to it, matching the strictly
        # causal tril(diagonal=-1) masking of Attention.forward. Note that this path
        # applies neither the ALiBi bias nor L1 row normalization.
        B, nh, T, N = Q.size()
        D = V.size(-1)
        device = Q.device
        QR = self.attn._rotate(Q, start_pos)
        outputs = []
        for t in range(T):
            q_t = QR[:, :, t:t+1, :]
            v_t = V[:, :, t:t+1, :].repeat(1, nh, 1, 1)
            rho = state.get_rho(layer_idx, B, device)
            # Read: sum_n q_t[n] * rho[n, :] -> (B, nh, 1, D).
            attn_t = (rho * q_t.transpose(-1, -2)).sum(dim=2, keepdim=True)
            outputs.append(attn_t)
            # Write: rho <- rho + outer(q_t, v_t).
            state.update_rho(layer_idx, q_t.squeeze(2), v_t.squeeze(2))
        return torch.cat(outputs, dim=2)

    def update_ternary_weights(self):
        for module in self.modules():
            if isinstance(module, (TernaryLinear2D, TernaryLinear3D)):
                module.update_ternary_weights()
        if self.plasticity is not None:
            self.plasticity._update_ternary()

    def save(self, path: str):
        torch.save({
            'config': self.config,
            'state_dict': self.state_dict()
        }, path)

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'BDH':
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        config = checkpoint['config']
        model = cls(config).to(device)
        model.load_state_dict(checkpoint['state_dict'])
        return model
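

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model: the tiny sizes below are chosen
    # only for speed and are not taken from any reference setup. use_plasticity is
    # disabled so the external `plasticity` module is not required.
    cfg = BDHConfig(n_layer=2, n_embd=64, n_head=4, vocab_size=256, use_plasticity=False)
    model = BDH(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 8))       # (B=2, T=8) byte-level tokens
    targets = torch.randint(0, cfg.vocab_size, (2, 8))
    logits, loss = model(idx, targets=targets)
    print(logits.shape, loss.item())                      # expected: torch.Size([2, 8, 256])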