"""v10: Sparse Distributed Memory char-LM. Zero backprop. Zero learned parameters (except optional final codebook). Training = single pass of Hamming-ball writes; inference = Hamming-ball retrieval. Per Bricken & Pehlevan 2021, attention approximates SDM under norm conditions. This is the pure-SDM baseline — essentially the classical associative-memory answer. Pipeline: 1. Fix random ±1 hard-address matrix A ∈ {±1}^{N×D}, random char hypervectors C ∈ {±1}^{V×D}, integer counter matrix M ∈ ℤ^{N×D}. 2. Context embedding: cyclic-shift-bind last k chars into ±1 query q ∈ {±1}^D. 3. For each (context, next_char) training pair: find addresses i where Hamming(A_i, q) ≤ r (equivalently dot(A_i, q) ≥ D - 2r). For each such i, accumulate C[next_char] into M_i (update counter toward target). 4. At inference: retrieve y_est = sign(Σ_{i active} M_i), then classify next char by argmax of y_est · C_v^T. No gradient, no training loop over parameters — just a single pass through the data updating integer counters. """ import math import os import time import numpy as np import torch def char_hv(vocab_size, d, seed=0): g = torch.Generator().manual_seed(seed) return torch.sign(torch.randn(vocab_size, d, generator=g)).to(torch.int8) def random_hard_addresses(n, d, seed=1): g = torch.Generator().manual_seed(seed) return torch.sign(torch.randn(n, d, generator=g)).to(torch.int8) def context_embed(ctx_ids, C, permutation_matrix=None): """Permutation-bind last k chars. ctx_ids shape (..., k) int64. Returns (..., D) ±1 int8. Use circular shift by position as the permutation (cheap, per Rachkovskij 2112). """ V, D = C.shape # ctx_ids: (B, k) B, k = ctx_ids.shape device = ctx_ids.device codes = C.to(device)[ctx_ids] # (B, k, D) ±1 # Circular shift by position p (so char at position p is shifted by p) rolled = torch.stack([ torch.roll(codes[:, p, :], shifts=p, dims=-1) for p in range(k) ], dim=1) # (B, k, D) # Bundle with sign-of-sum (majority vote) s = rolled.to(torch.int32).sum(dim=1) # (B, D) out = torch.sign(s).to(torch.int8) # Tie-break at zero → +1 out[out == 0] = 1 return out def retrieve_topk(query, A, topk): """Return boolean mask of the top-k addresses by Hamming similarity. query: (B, D), A: (N, D). Returns (B, N) bool. """ dots = query.to(torch.int32) @ A.to(torch.int32).t() # (B, N) _, idx = torch.topk(dots, k=topk, dim=1) # (B, topk) mask = torch.zeros_like(dots, dtype=torch.bool) mask.scatter_(1, idx, True) return mask class SDMCharLM: def __init__(self, vocab_size=128, d=512, n_addrs=2**15, context=16, topk=None, device='cuda', seed=0): self.V = vocab_size self.D = d self.N = n_addrs self.k = context # Activate top-k addresses per query (default ~1% of N) self.topk = topk if topk is not None else max(8, n_addrs // 100) self.device = device self.C = char_hv(vocab_size, d, seed=seed).to(device) # (V,D) ±1 int8 self.A = random_hard_addresses(n_addrs, d, seed=seed + 1).to(device) # (N,D) self.M = torch.zeros(n_addrs, d, dtype=torch.int32, device=device) # counters def train(self, data: np.memmap, max_samples=200_000, batch=512, verbose=True): """Single-pass write over (context, next_char) pairs drawn from data.""" N_data = len(data) - self.k - 1 n_written = 0 t0 = time.time() while n_written < max_samples: b = min(batch, max_samples - n_written) starts = np.random.randint(0, N_data, size=b) ctx = np.stack([data[s:s + self.k].astype(np.int64) for s in starts]) nxt = np.stack([data[s + self.k].astype(np.int64) for s in starts]) ctx_t = torch.from_numpy(ctx).to(self.device) nxt_t = torch.from_numpy(nxt).to(self.device) q = context_embed(ctx_t, self.C) # (B, D) ±1 # Active addresses: (B, N) active = retrieve_topk(q, self.A, self.topk) # Target codes: (B, D) from C[nxt] target = self.C[nxt_t].to(torch.int32) # (B, D) ±1 # For each active(i, j), M[j] += target[i] # Equivalent: M += active^T @ target update = active.to(torch.int32).t() @ target # (N, D) self.M.add_(update) n_written += b if verbose and n_written % 10000 == 0: print(f"written {n_written:,} | elapsed {time.time()-t0:.1f}s | " f"avg active/query {active.to(torch.int32).sum(dim=1).float().mean().item():.0f}") @torch.no_grad() def predict_logits(self, ctx_t): """ctx_t: (B, k) int64. Returns (B, V) float logits.""" q = context_embed(ctx_t, self.C) active = retrieve_topk(q, self.A, self.topk) # (B, N) # y_est = sign(Σ active_j · M_j) per sample # (B, D) = (B, N) @ (N, D) — active is bool, M is int32 sums = active.to(torch.int32) @ self.M # (B, D) y_est = torch.sign(sums) # (B, D) float y_est = torch.where(y_est == 0, torch.ones_like(y_est), y_est) # Scores against char codebook: (B, V) = (B, D) @ (D, V) / D scores = y_est.to(torch.float32) @ self.C.to(torch.float32).t() / self.D return scores @torch.no_grad() def evaluate_bpc(self, data: np.memmap, max_samples=20_000, batch=256, temperature=0.1): """Compute BPC on held-out data.""" import torch.nn.functional as F N_data = len(data) - self.k - 1 n = min(max_samples, N_data) rng = np.random.RandomState(42) starts = rng.randint(0, N_data, size=n) total_loss = 0.0 total_cnt = 0 for i in range(0, n, batch): chunk = starts[i:i+batch] ctx = np.stack([data[s:s + self.k].astype(np.int64) for s in chunk]) nxt = np.stack([data[s + self.k].astype(np.int64) for s in chunk]) ctx_t = torch.from_numpy(ctx).to(self.device) nxt_t = torch.from_numpy(nxt).to(self.device) logits = self.predict_logits(ctx_t) / temperature # sharper loss = F.cross_entropy(logits, nxt_t, reduction='sum') total_loss += loss.item() total_cnt += chunk.shape[0] avg = total_loss / total_cnt return avg, avg / math.log(2) if __name__ == '__main__': import argparse ap = argparse.ArgumentParser() ap.add_argument('--data-dir', default='/root/bitnet1/data') ap.add_argument('--d', type=int, default=512) ap.add_argument('--n-addrs', type=int, default=2**15) ap.add_argument('--context', type=int, default=16) ap.add_argument('--topk', type=int, default=None) ap.add_argument('--train-samples', type=int, default=500_000) ap.add_argument('--eval-samples', type=int, default=20_000) ap.add_argument('--temperature', type=float, default=0.1) ap.add_argument('--device', default='cuda') args = ap.parse_args() sdm = SDMCharLM(d=args.d, n_addrs=args.n_addrs, context=args.context, topk=args.topk, device=args.device) print(f"SDM: D={sdm.D} N={sdm.N} k={sdm.k} topk={sdm.topk}") print(f"Memory: {sdm.M.numel() * 4 / 1e6:.1f} MB") train_data = np.memmap(os.path.join(args.data_dir, 'train.bin'), dtype=np.uint8, mode='r') val_data = np.memmap(os.path.join(args.data_dir, 'validation.bin'), dtype=np.uint8, mode='r') t0 = time.time() sdm.train(train_data, max_samples=args.train_samples) print(f"Training complete in {time.time()-t0:.1f}s") for temp in [1.0, 0.3, 0.1, 0.03]: loss, bpc = sdm.evaluate_bpc(val_data, max_samples=args.eval_samples, temperature=temp) print(f"temp={temp}: loss={loss:.4f} bpc={bpc:.4f}")