| """v10: Sparse Distributed Memory char-LM. |
| |
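| Zero backprop. Zero learned parameters: the codebook C and hard addresses A (below) are fixed random matrices, and training only accumulates integer counters in M (no learned final codebook is implemented in this file). |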
| Training = single pass of Hamming-ball writes; inference = Hamming-ball retrieval. |
| |
| Per Bricken & Pehlevan 2021, attention approximates SDM under norm conditions. |
| This is the pure-SDM baseline — essentially the classical associative-memory answer. |
| |
| Pipeline: |
| 1. Fix random ±1 hard-address matrix A ∈ {±1}^{N×D}, random char hypervectors |
| C ∈ {±1}^{V×D}, integer counter matrix M ∈ ℤ^{N×D}. |
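| 2. Context embedding: cyclically shift each of the last k char codes by its position, sum the shifted codes, and take the sign (ties → +1) to get a ±1 query q ∈ {±1}^D. |
| 3. For each (context, next_char) training pair: find the active addresses i. Classical SDM uses a Hamming ball, Hamming(A_i, q) ≤ r |
| (equivalently dot(A_i, q) ≥ D - 2r); this implementation instead activates the topk addresses with the largest dot(A_i, q), i.e. an adaptive radius. |
| For each active i, accumulate C[next_char] into M_i (moving the counters toward the target). |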
| 4. At inference: retrieve y_est = sign(Σ_{i active} M_i), then classify next |
| char by argmax of y_est · C_v^T. |
| |
| No gradient, no training loop over parameters — just a single pass through the |
| data updating integer counters. |
| """ |
| import math |
| import os |
| import time |
| import numpy as np |
| import torch |
|
|
|
|
def char_hv(vocab_size, d, seed=0):
    """Build the fixed random character codebook.

    Draws a (vocab_size, d) standard-normal matrix from a private CPU generator
    seeded with ``seed`` and keeps only the signs, giving one bipolar row per
    character id. Returns an int8 tensor (entries are +/-1; an exact 0.0 draw
    would map to 0).
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    gaussian = torch.randn((vocab_size, d), generator=rng)
    return gaussian.sign().to(dtype=torch.int8)
|
|
|
|
def random_hard_addresses(n, d, seed=1):
    """Sample the fixed SDM hard-address matrix.

    Returns an (n, d) int8 tensor whose entries are the signs of standard-normal
    draws from a CPU generator seeded with ``seed`` (so +/-1; an exact 0.0 draw
    would map to 0).
    """
    shape = (n, d)
    gen = torch.Generator().manual_seed(seed)
    draws = torch.randn(*shape, generator=gen)
    return torch.sign(draws).type(torch.int8)
|
|
|
|
def context_embed(ctx_ids, C, permutation_matrix=None):
    """Bundle a window of character codes into one bipolar context vector.

    The code of the character at window position p is cyclically shifted by p
    along the last axis (a cheap positional permutation; the original note
    cites "Rachkovskij 2112" -- confirm that reference). The shifted codes are
    summed, and the sign of the sum is returned.

    Args:
        ctx_ids: integer tensor of shape (..., k) holding row indices into C.
            Any number of leading dimensions is accepted, including none.
        C: (V, D) codebook of +/-1 entries (int8 in this file).
        permutation_matrix: unused; accepted only for signature compatibility.

    Returns:
        int8 tensor of shape (..., D) with entries in {-1, +1}, on
        ``ctx_ids.device``. Dimensions whose sum is exactly 0 (possible when k
        is even) are set to +1, which biases tied dimensions toward +1 for
        every context.
    """
    codes = C.to(ctx_ids.device)[ctx_ids]  # (..., k, D)
    k = codes.shape[-2]
    out_shape = tuple(codes.shape[:-2]) + (codes.shape[-1],)
    acc = torch.zeros(out_shape, dtype=torch.int32, device=codes.device)
    for p in range(k):
        # Position-dependent permutation: rotate the p-th code by p along D.
        # Accumulate in int32 so the sum cannot overflow int8 for long windows.
        acc += torch.roll(codes[..., p, :], shifts=p, dims=-1).to(torch.int32)
    out = torch.sign(acc).to(torch.int8)
    # sign(0) == 0; map ties to +1 so the output stays strictly bipolar.
    out[out == 0] = 1
    return out
|
|
|
|
def retrieve_topk(query, A, topk):
    """Select, per query, the ``topk`` hard addresses nearest in Hamming distance.

    For +/-1 vectors, dot(A_i, q) = D - 2 * Hamming(A_i, q), so the largest dot
    products are the nearest addresses.

    Args:
        query: (B, D) tensor of +/-1 entries.
        A: (N, D) tensor of +/-1 entries.
        topk: number of addresses to activate per query; values above N are
            clamped to N (torch.topk cannot select more entries than exist).

    Returns:
        (B, N) bool mask with ``min(topk, N)`` True entries per row.
    """
    # Integer matmul is not implemented on CUDA, so compute in float32. Inputs are
    # +/-1, so every product is exact and the integer sums stay exact for D < 2**24.
    dots = query.to(torch.float32) @ A.to(torch.float32).t()
    k = min(int(topk), dots.shape[1])
    _, idx = torch.topk(dots, k=k, dim=1)
    mask = torch.zeros(dots.shape, dtype=torch.bool, device=dots.device)
    mask.scatter_(1, idx, True)
    return mask
|
|
|
|
class SDMCharLM:
    """Sparse Distributed Memory next-character model (no gradient training).

    Attributes:
        V, D, N, k: vocabulary size, hypervector width, number of hard
            addresses, context length.
        topk: hard addresses activated per query (clamped to N).
        C: (V, D) int8 character codebook from ``char_hv``.
        A: (N, D) int8 hard addresses from ``random_hard_addresses``.
        M: (N, D) int32 counters; written by ``train``, read by ``predict_logits``.
    """

    def __init__(self, vocab_size=128, d=512, n_addrs=2**15, context=16, topk=None,
                 device='cuda', seed=0):
        self.V = vocab_size
        self.D = d
        self.N = n_addrs
        self.k = context
        # Default: about 1% of the addresses (at least 8) fire per query. Clamp to N
        # because torch.topk cannot select more entries than exist.
        requested = topk if topk is not None else max(8, n_addrs // 100)
        self.topk = min(requested, n_addrs)
        self.device = device
        self.C = char_hv(vocab_size, d, seed=seed).to(device)
        self.A = random_hard_addresses(n_addrs, d, seed=seed + 1).to(device)
        self.M = torch.zeros(n_addrs, d, dtype=torch.int32, device=device)

    def _num_starts(self, data):
        """Return the exclusive upper bound for random window starts in ``data``.

        The bound is ``len(data) - k - 1``, so the final item is never used as a
        target. It is kept unchanged so the seeded evaluation sample stays
        reproducible.

        Raises:
            ValueError: if ``data`` cannot hold one context plus its target.
        """
        n_starts = len(data) - self.k - 1
        if n_starts <= 0:
            raise ValueError(
                f"data has {len(data)} items; need at least {self.k + 2} "
                f"for context length {self.k}")
        return n_starts

    def _gather_pairs(self, data, starts):
        """Read (context, next_char) pairs at ``starts`` and move them to the device.

        Returns:
            ctx_t: (B, k) int64 tensor holding ``data[s:s + k]`` for each start s.
            nxt_t: (B,) int64 tensor holding ``data[s + k]``.

        Raises:
            ValueError: if any id lies outside [0, V). Without this check the
                codebook lookup fails later (on CUDA typically as a device-side
                assert).
        """
        # One vectorized fancy-index read instead of a Python loop per sample.
        idx = np.asarray(starts, dtype=np.int64)[:, None] + np.arange(self.k + 1)
        windows = np.asarray(data[idx], dtype=np.int64)  # (B, k + 1)
        if windows.size and (windows.min() < 0 or windows.max() >= self.V):
            raise ValueError(
                f"data contains ids outside [0, {self.V}); increase vocab_size")
        ctx_t = torch.from_numpy(np.ascontiguousarray(windows[:, :self.k])).to(self.device)
        nxt_t = torch.from_numpy(np.ascontiguousarray(windows[:, self.k])).to(self.device)
        return ctx_t, nxt_t

    def train(self, data: np.memmap, max_samples=200_000, batch=512, verbose=True,
              log_every=10_000):
        """Single-pass write over (context, next_char) pairs drawn from ``data``.

        Window starts are drawn uniformly with replacement from NumPy's global RNG
        (``np.random``). Each pair activates the ``topk`` nearest hard addresses,
        and the target character's code is added to their counters in ``M``.

        Args:
            data: 1-D array of character ids (e.g. a uint8 memmap).
            max_samples: total number of pairs to write; <= 0 writes nothing.
            batch: pairs per write step; must be >= 1.
            verbose: print progress about every ``log_every`` pairs.
            log_every: progress interval in pairs; must be >= 1 when verbose.

        Raises:
            ValueError: on invalid ``batch``/``log_every``, data too short, or
                out-of-vocabulary ids.
        """
        if batch < 1:
            raise ValueError(f"batch must be >= 1, got {batch}")
        if verbose and log_every < 1:
            raise ValueError(f"log_every must be >= 1, got {log_every}")
        if max_samples <= 0:
            return
        n_starts = self._num_starts(data)
        n_written = 0
        next_log = log_every
        t0 = time.time()
        while n_written < max_samples:
            b = min(batch, max_samples - n_written)
            starts = np.random.randint(0, n_starts, size=b)
            ctx_t, nxt_t = self._gather_pairs(data, starts)
            q = context_embed(ctx_t, self.C)

            active = retrieve_topk(q, self.A, self.topk)  # (b, N) bool

            target = self.C[nxt_t].to(torch.float32)  # (b, D), entries +/-1

            # update[i] = sum of the target codes of the queries that activated
            # address i. Integer matmul is not implemented on CUDA, so compute it
            # in float32: every entry is an integer in [-b, b] and therefore exact.
            update = active.t().to(torch.float32) @ target
            self.M.add_(update.round().to(torch.int32))
            n_written += b
            # A modulo test only fires when batch divides the interval (512 vs
            # 10_000 first matches at 320_000), so use a moving threshold instead.
            if verbose and n_written >= next_log:
                print(f"written {n_written:,} | elapsed {time.time()-t0:.1f}s | "
                      f"avg active/query {active.to(torch.int32).sum(dim=1).float().mean().item():.0f}")
                next_log = (n_written // log_every + 1) * log_every

    @torch.no_grad()
    def predict_logits(self, ctx_t):
        """Score every vocabulary character for each context.

        Args:
            ctx_t: (B, k) int64 tensor of character ids on ``self.device``.

        Returns:
            (B, V) float32 scores ``y_est . C_v / D``. Here ``y_est`` is the sign
            (ties -> +1) of the summed counters of the ``topk`` active addresses.
        """
        q = context_embed(ctx_t, self.C)
        active = retrieve_topk(q, self.A, self.topk)

        # Integer matmul is not implemented on CUDA. float64 keeps the summed int32
        # counters exact while |sum| < 2**53, so the sign below is unaffected.
        sums = active.to(torch.float64) @ self.M.to(torch.float64)
        y_est = torch.sign(sums)
        y_est = torch.where(y_est == 0, torch.ones_like(y_est), y_est)

        scores = y_est.to(torch.float32) @ self.C.to(torch.float32).t() / self.D
        return scores

    @torch.no_grad()
    def evaluate_bpc(self, data: np.memmap, max_samples=20_000, batch=256, temperature=0.1):
        """Compute held-out cross-entropy and bits per character.

        Draws a fixed sample of window starts (``np.random.RandomState(42)``),
        applies softmax to ``predict_logits / temperature`` via
        ``F.cross_entropy``, and averages over the sampled pairs.

        Returns:
            (loss_nats, bits_per_char).

        Raises:
            ValueError: if no pairs would be evaluated (``max_samples <= 0``), the
                data is too short, or it contains out-of-vocabulary ids.
        """
        import torch.nn.functional as F
        n_starts = self._num_starts(data)
        n = min(max_samples, n_starts)
        if n <= 0:
            raise ValueError(f"max_samples must be >= 1, got {max_samples}")
        rng = np.random.RandomState(42)
        starts = rng.randint(0, n_starts, size=n)
        total_loss = 0.0
        total_cnt = 0
        for i in range(0, n, batch):
            chunk = starts[i:i + batch]
            ctx_t, nxt_t = self._gather_pairs(data, chunk)
            logits = self.predict_logits(ctx_t) / temperature
            total_loss += F.cross_entropy(logits, nxt_t, reduction='sum').item()
            total_cnt += chunk.shape[0]
        avg = total_loss / total_cnt
        return avg, avg / math.log(2)
|
|
|
|
if __name__ == '__main__':
    # CLI: build an SDMCharLM, write --train-samples pairs from train.bin, then
    # report validation loss/BPC at several softmax temperatures.
    import argparse
    ap = argparse.ArgumentParser(description="Sparse Distributed Memory char-LM")
    ap.add_argument('--data-dir', default='/root/bitnet1/data')
    ap.add_argument('--d', type=int, default=512)
    ap.add_argument('--n-addrs', type=int, default=2**15)
    ap.add_argument('--context', type=int, default=16)
    ap.add_argument('--topk', type=int, default=None)
    ap.add_argument('--vocab-size', type=int, default=128,
                    help='number of character ids; must exceed every byte value in the data')
    ap.add_argument('--seed', type=int, default=0,
                    help='seed for the random codebook and hard addresses')
    ap.add_argument('--train-samples', type=int, default=500_000)
    ap.add_argument('--eval-samples', type=int, default=20_000)
    ap.add_argument('--temperature', type=float, default=0.1,
                    help='softmax temperature evaluated in addition to the built-in sweep')
    # Fall back to CPU when CUDA is unavailable instead of failing on .to('cuda').
    ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    args = ap.parse_args()

    sdm = SDMCharLM(vocab_size=args.vocab_size, d=args.d, n_addrs=args.n_addrs,
                    context=args.context, topk=args.topk, device=args.device,
                    seed=args.seed)
    print(f"SDM: D={sdm.D} N={sdm.N} k={sdm.k} topk={sdm.topk}")
    # M is int32, i.e. 4 bytes per counter.
    print(f"Memory: {sdm.M.numel() * 4 / 1e6:.1f} MB")

    train_data = np.memmap(os.path.join(args.data_dir, 'train.bin'), dtype=np.uint8, mode='r')
    val_data = np.memmap(os.path.join(args.data_dir, 'validation.bin'), dtype=np.uint8, mode='r')

    t0 = time.time()
    sdm.train(train_data, max_samples=args.train_samples)
    print(f"Training complete in {time.time()-t0:.1f}s")

    # Evaluate the fixed sweep plus --temperature (without repeating a value).
    temperatures = [1.0, 0.3, 0.1, 0.03]
    if args.temperature not in temperatures:
        temperatures.append(args.temperature)
    for temp in temperatures:
        loss, bpc = sdm.evaluate_bpc(val_data, max_samples=args.eval_samples, temperature=temp)
        print(f"temp={temp}: loss={loss:.4f} bpc={bpc:.4f}")
|
|