# bdh.py
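# BDH: a language model built from ternary-quantized linear layers (weights in
# {-1, 0, +1}), causal attention with shared queries/keys and optional
# RoPE/ALiBi positioning, and an optional recurrent rho-state cache for
# incremental decoding. The same block weights are reused across all n_layer
# iterations.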
import dataclasses
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclasses.dataclass
class BDHConfig:
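    # Model hyperparameters. latent_per_head() assumes that
    # mlp_internal_dim_multiplier * n_embd is divisible by n_head.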
n_layer: int = 32
n_embd: int = 4096
dropout: float = 0.1
n_head: int = 32
mlp_internal_dim_multiplier: int = 1
vocab_size: int = 256
use_alibi: bool = True
use_l1_norm: bool = True
relu_threshold: float = 0.0
rotary_embedding: str = "rope"
rope_theta: float = 65536.0
use_plasticity: bool = True
plasticity_lr: float = 0.01
consolidation_rate: float = 0.01
forget_rate: float = 0.1
use_rho_cache: bool = True
def latent_per_head(self) -> int:
return self.mlp_internal_dim_multiplier * self.n_embd // self.n_head
def latent_total(self) -> int:
return self.latent_per_head() * self.n_head
class TernaryLinear3D(nn.Module):
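    # Per-head linear projection with ternary weights. The forward pass reads
    # the int8 ternary buffer directly; weight_fp32 is the fp32 master copy
    # that update_ternary_weights() re-quantizes from.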
def __init__(self, n_head: int, in_features: int, out_features: int):
super().__init__()
self.n_head = n_head
self.in_features = in_features
self.out_features = out_features
self.register_buffer('weight_ternary', torch.zeros(n_head, out_features, in_features, dtype=torch.int8))
self.weight_fp32 = nn.Parameter(torch.zeros(n_head, out_features, in_features))
self.register_buffer('weight_scale', torch.ones(n_head, 1, 1))
self._init_weights()
def _init_weights(self):
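        # Start both the fp32 master weights and the ternary buffer from the
        # same random values in {-1, 0, +1}.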
with torch.no_grad():
rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32)
self.weight_fp32.data = rand_vals
self.weight_ternary.data = rand_vals.to(torch.int8)
def forward(self, x: torch.Tensor) -> torch.Tensor:
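        # x: (B, 1, T, D) shared across heads or (B, n_head, T, D);
        # returns (B, n_head, T, out_features).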
if x.dim() == 4 and x.size(1) == 1:
x = x.expand(-1, self.n_head, -1, -1)
weight = self.weight_ternary.float()
return torch.einsum('bhtd,hnd->bhtn', x, weight)
def update_ternary_weights(self):
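        # Absmean quantization: scale each head's weights by the mean |w|,
        # round to {-1, 0, +1}, and write the rescaled result back into the
        # fp32 master weights.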
with torch.no_grad():
gamma = self.weight_fp32.abs().mean(dim=(1, 2), keepdim=True).clamp(min=1e-5)
self.weight_scale.data = gamma
w_scaled = self.weight_fp32 / gamma
w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8)
self.weight_ternary.data = w_ternary
self.weight_fp32.data = w_ternary.float() * gamma
class TernaryLinear2D(nn.Module):
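    # Single (head-shared) linear projection with the same ternary scheme as
    # TernaryLinear3D.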
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer('weight_ternary', torch.zeros(out_features, in_features, dtype=torch.int8))
self.weight_fp32 = nn.Parameter(torch.zeros(out_features, in_features))
self.register_buffer('weight_scale', torch.ones(1))
self._init_weights()
def _init_weights(self):
with torch.no_grad():
rand_vals = torch.randint(-1, 2, self.weight_fp32.shape, dtype=torch.float32)
self.weight_fp32.data = rand_vals
self.weight_ternary.data = rand_vals.to(torch.int8)
def forward(self, x: torch.Tensor) -> torch.Tensor:
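        # Flatten (B, 1, T, D) or (B, T, D) inputs to (B*T, D), apply the
        # ternary weights, then restore the leading batch/time shape.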
orig_shape = x.shape
if x.dim() == 4:
B, _, T, D = x.shape
x = x.view(B * T, D)
elif x.dim() == 3:
B, T, D = x.shape
x = x.view(B * T, D)
weight = self.weight_ternary.float()
out = F.linear(x, weight)
if len(orig_shape) == 4:
B, _, T, _ = orig_shape
out = out.view(B, 1, T, -1)
elif len(orig_shape) == 3:
B, T, _ = orig_shape
out = out.view(B, T, -1)
return out
def update_ternary_weights(self):
with torch.no_grad():
gamma = self.weight_fp32.abs().mean().clamp(min=1e-5)
self.weight_scale.data = gamma
w_scaled = self.weight_fp32 / gamma
w_ternary = torch.round(w_scaled).clamp(-1, 1).to(torch.int8)
self.weight_ternary.data = w_ternary
self.weight_fp32.data = w_ternary.float() * gamma
def get_freqs(n: int, theta: float, dtype: torch.dtype, rotary_type: str = "rope") -> torch.Tensor:
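    # Per-dimension rotation frequencies in cycles per position. For "rope",
    # indices are quantized in steps of 2 so each (even, odd) dimension pair
    # shares a frequency; for "alibi", rotation is disabled (all zeros).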
if rotary_type == "alibi":
return torch.zeros(n, dtype=dtype)
def quantize(t, q=2):
return (t / q).floor() * q
indices = torch.arange(0, n, 1, dtype=dtype)
if rotary_type == "rope":
indices = quantize(indices)
return 1.0 / (theta ** (indices / n)) / (2 * math.pi)
def row_normalize(scores: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
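    # L1 row normalization: divide each attention row by the sum of |scores|.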
denom = scores.abs().sum(dim=-1, keepdim=True) + eps
return scores / denom
class Attention(nn.Module):
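    # Causal attention that shares queries and keys, with optional RoPE
    # rotation, ALiBi distance bias, and L1 row normalization in place of a
    # softmax.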
def __init__(self, config: BDHConfig):
super().__init__()
self.config = config
nh = config.n_head
N = config.latent_per_head()
self.use_alibi = config.use_alibi
self.use_l1_norm = config.use_l1_norm
self.rotary_type = config.rotary_embedding
freqs = get_freqs(N, config.rope_theta, torch.float32, self.rotary_type)
self.register_buffer('freqs', freqs.view(1, 1, 1, N))
if self.use_alibi:
slopes = torch.tensor([2 ** (-8 * i / nh) for i in range(1, nh + 1)], dtype=torch.float32)
self.register_buffer('alibi_slopes', slopes.view(1, nh, 1, 1))
def _rope(self, phases: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
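        # RoPE rotation over interleaved (even, odd) pairs:
        # (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).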
v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
phases_cos, phases_sin = torch.cos(phases), torch.sin(phases)
return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)
def _rotate(self, v: torch.Tensor, start: int = 0) -> torch.Tensor:
if self.rotary_type == "alibi":
return v
_, _, T, _ = v.size()
device = v.device
positions = torch.arange(start, start + T, device=device, dtype=self.freqs.dtype).view(1, 1, -1, 1)
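        # Keep only the fractional part of position * frequency before
        # converting to radians; this avoids precision loss at large positions.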
raw = positions * self.freqs
phases = (raw - raw.floor()) * (2 * math.pi)
return self._rope(phases, v)
def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        assert K is Q, "BDH attention expects queries and keys to be the same tensor"
B, nh, T, N = Q.size()
QR = self._rotate(Q, start_pos)
KR = QR
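        # Strictly causal scores: diagonal=-1 masks the current token as well
        # as the future, so each position attends only to earlier positions.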
scores = (QR @ KR.mT).tril(diagonal=-1)
        if self.use_alibi:
            # ALiBi: per-head linear distance penalty; (j - i) is negative for
            # keys strictly in the past, and the diagonal/future are masked out.
            pos = torch.arange(start_pos, start_pos + T, device=scores.device)
            alibi = (pos.view(1, 1, 1, -1) - pos.view(1, 1, -1, 1)).tril(-1)
            scores = scores + alibi * self.alibi_slopes
if self.use_l1_norm:
scores = row_normalize(scores)
return scores @ V
class BDHState:
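    # Per-layer recurrent state: a rho fast-weight matrix of shape
    # (B, n_head, latent_dim, n_embd) per layer, the last hidden state, and the
    # running token position used as the RoPE/ALiBi offset.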
def __init__(self, n_layer: int, n_head: int, latent_dim: int, n_embd: int):
self.n_layer = n_layer
self.n_head = n_head
self.latent_dim = latent_dim
self.n_embd = n_embd
self.layers: List[dict] = [{'rho': None, 'hidden': None} for _ in range(n_layer)]
self.total_position = 0
def get_rho(self, layer_idx: int, batch_size: int, device: torch.device) -> torch.Tensor:
rho = self.layers[layer_idx]['rho']
if rho is None:
rho = torch.zeros(batch_size, self.n_head, self.latent_dim, self.n_embd, device=device)
self.layers[layer_idx]['rho'] = rho
return rho
def update_rho(self, layer_idx: int, x_latent: torch.Tensor, v: torch.Tensor, decay: float = 1.0):
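        # Outer-product update: decay the existing fast weights, then add the
        # outer product of x_latent and v for the current token.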
rho = self.layers[layer_idx]['rho']
rho = rho * decay + torch.einsum('bhn,bhd->bhnd', x_latent, v)
self.layers[layer_idx]['rho'] = rho
def set_hidden(self, layer_idx: int, hidden: torch.Tensor):
self.layers[layer_idx]['hidden'] = hidden
def get_hidden(self, layer_idx: int) -> Optional[torch.Tensor]:
return self.layers[layer_idx]['hidden']
    def advance_position(self, n_tokens: int = 1):
        self.total_position += n_tokens
class BDH(nn.Module):
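    # BDH model: token embedding -> n_layer iterations of a shared block
    # (sparse latent encoding, causal attention over the embeddings, gated
    # ternary MLP) -> ternary LM head. Note that encoder, encoder_v, decoder
    # and attn are shared across all layer iterations.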
def __init__(self, config: BDHConfig):
super().__init__()
self.config = config
nh = config.n_head
D = config.n_embd
N = config.latent_per_head()
self.encoder = TernaryLinear3D(nh, D, N)
self.encoder_v = TernaryLinear3D(nh, D, N)
self.decoder = TernaryLinear2D(nh * N, D)
self.attn = Attention(config)
self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
self.embed = nn.Embedding(config.vocab_size, D)
self.drop = nn.Dropout(config.dropout)
self.lm_head = TernaryLinear2D(D, config.vocab_size)
self.relu_threshold = config.relu_threshold
self.plasticity = None
if config.use_plasticity:
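            # Optional plasticity mechanism over the ternary projections,
            # provided by the separate plasticity module.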
from plasticity import UnifiedPlasticity
self.plasticity = UnifiedPlasticity(
modules=[self.encoder, self.encoder_v, self.decoder, self.lm_head],
lr=config.plasticity_lr,
consolidation_rate=config.consolidation_rate,
forget_rate=config.forget_rate
)
def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None,
state: Optional[BDHState] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
C = self.config
B, T = idx.size()
D = C.n_embd
nh = C.n_head
N = C.latent_per_head()
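        # (B, T) token ids -> (B, 1, T, D); the singleton head dimension is
        # broadcast across heads inside the ternary encoders.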
x = self.embed(idx).unsqueeze(1)
x = self.ln(x)
start_pos = state.total_position if state is not None else 0
for layer_idx in range(C.n_layer):
x_latent = self.encoder(x)
if self.relu_threshold != 0:
x_latent = x_latent - self.relu_threshold
x_sparse = F.relu(x_latent)
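            # Queries/keys are the sparse latent code; values are the dense
            # layer input. With a state and rho caching enabled, use the
            # recurrent linear-attention path instead of full attention.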
if state is not None and C.use_rho_cache:
yKV = self._recurrent_attention(x_sparse, x, state, layer_idx, start_pos)
else:
yKV = self.attn(Q=x_sparse, K=x_sparse, V=x, start_pos=start_pos)
yKV = self.ln(yKV)
y_latent = self.encoder_v(yKV)
if self.relu_threshold != 0:
y_latent = y_latent - self.relu_threshold
y_sparse = F.relu(y_latent)
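            # Gate the input code with the attention-derived code elementwise,
            # then decode the concatenated heads back to the embedding size.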
xy_sparse = x_sparse * y_sparse
xy_sparse = self.drop(xy_sparse)
xy_flat = xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh)
yMLP = self.decoder(xy_flat)
y = self.ln(yMLP)
x = self.ln(x + y)
if state is not None:
state.set_hidden(layer_idx, x.clone())
        if state is not None:
            # Advance by the number of tokens just processed so the next call's
            # RoPE/ALiBi positions continue from the right offset.
            state.advance_position(T)
logits = self.lm_head(x.squeeze(1))
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
def forward_with_states(self, idx: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
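        # Run a forward pass with a fresh state and return the logits together
        # with each layer's hidden state of shape (B, T, D).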
C = self.config
B, T = idx.size()
D = C.n_embd
nh = C.n_head
N = C.latent_per_head()
state = BDHState(C.n_layer, nh, N, D)
logits, _ = self.forward(idx, state=state)
hidden_states = []
for layer_idx in range(C.n_layer):
h = state.get_hidden(layer_idx)
if h is not None:
hidden_states.append(h.squeeze(1))
return logits, hidden_states
def _recurrent_attention(self, Q: torch.Tensor, V: torch.Tensor, state: BDHState,
layer_idx: int, start_pos: int) -> torch.Tensor:
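        # Linear-attention recurrence: rho accumulates the outer products
        # q_t v_t^T, and each step reads rho before it is updated, so the
        # output depends only on strictly earlier tokens (matching the
        # diagonal=-1 mask of the parallel path). ALiBi bias and L1
        # normalization are not applied on this path.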
B, nh, T, N = Q.size()
D = V.size(-1)
device = Q.device
QR = self.attn._rotate(Q, start_pos)
outputs = []
for t in range(T):
q_t = QR[:, :, t:t+1, :]
v_t = V[:, :, t:t+1, :].repeat(1, nh, 1, 1)
rho = state.get_rho(layer_idx, B, device)
attn_t = (rho * q_t.transpose(-1, -2)).sum(dim=2, keepdim=True)
outputs.append(attn_t)
state.update_rho(layer_idx, q_t.squeeze(2), v_t.squeeze(2))
return torch.cat(outputs, dim=2)
def update_ternary_weights(self):
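        # Re-quantize every ternary layer from its fp32 master weights, and let
        # the plasticity module refresh its own ternary copies if present.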
for module in self.modules():
if isinstance(module, (TernaryLinear2D, TernaryLinear3D)):
module.update_ternary_weights()
if self.plasticity is not None:
self.plasticity._update_ternary()
def save(self, path: str):
torch.save({
'config': self.config,
'state_dict': self.state_dict()
}, path)
@classmethod
def load(cls, path: str, device: str = 'cpu') -> 'BDH':
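        # weights_only=False is needed because the checkpoint pickles the
        # BDHConfig object; only load checkpoints from trusted sources.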
checkpoint = torch.load(path, map_location=device, weights_only=False)
config = checkpoint['config']
model = cls(config).to(device)
model.load_state_dict(checkpoint['state_dict'])
return model
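
# Minimal usage sketch: builds a small illustrative config (values chosen only
# for a quick check, not the training defaults), runs a forward pass on random
# token ids, and exercises the stateful decoding path. Plasticity is disabled
# so the optional plasticity module is not required.
if __name__ == "__main__":
    cfg = BDHConfig(n_layer=2, n_embd=64, n_head=4, vocab_size=256,
                    use_plasticity=False)
    model = BDH(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, loss = model(idx, targets=idx)
    print(logits.shape, loss.item())

    # Incremental decoding with the recurrent rho cache, one token at a time.
    state = BDHState(cfg.n_layer, cfg.n_head, cfg.latent_per_head(), cfg.n_embd)
    for t in range(idx.size(1)):
        step_logits, _ = model(idx[:, t:t + 1], state=state)
    print(step_logits.shape)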