# bitnet-1bitllm / vm_backup / code / model_v10_sdm.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""v10: Sparse Distributed Memory char-LM.
Zero backprop. Zero learned parameters (except optional final codebook).
Training = single pass of Hamming-ball writes; inference = Hamming-ball retrieval.
Per Bricken & Pehlevan 2021, attention approximates SDM under norm conditions.
This is the pure-SDM baseline — essentially the classical associative-memory answer.
Pipeline:
1. Fix random ±1 hard-address matrix A ∈ {±1}^{N×D}, random char hypervectors
C ∈ {±1}^{V×D}, integer counter matrix M ∈ ℤ^{N×D}.
2. Context embedding: cyclic-shift-bind last k chars into ±1 query q ∈ {±1}^D.
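3. For each (context, next_char) training pair: find addresses i where
Hamming(A_i, q) ≤ r (equivalently dot(A_i, q) ≥ D - 2r). For each such i,
accumulate C[next_char] into M_i (update counter toward target).
Note: the code below does not threshold at a fixed radius r; it activates the
top-k addresses ranked by dot(A_i, q), so every query activates the same
number of addresses.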
4. At inference: retrieve y_est = sign(Σ_{i active} M_i), then classify next
char by argmax of y_est · C_v^T.
No gradient, no training loop over parameters — just a single pass through the
data updating integer counters.
"""
import math
import os
import time
import numpy as np
import torch
def char_hv(vocab_size, d, seed=0):
    """Build the fixed random character codebook.

    Draws a (vocab_size, d) standard-normal matrix from a generator seeded
    with ``seed`` and keeps only the sign of each entry, so the result is an
    int8 tensor of -1/+1 values (an exact-zero draw would stay 0).
    The same seed always yields the same codebook.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    gaussian = torch.randn(vocab_size, d, generator=rng)
    return gaussian.sign().to(torch.int8)
def random_hard_addresses(n, d, seed=1):
    """Build the fixed random hard-address matrix.

    Returns an (n, d) int8 tensor holding the sign of standard-normal draws
    from a generator seeded with ``seed`` (so entries are -1/+1; an exact-zero
    draw would stay 0). Deterministic for a given seed.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    gaussian = torch.randn(n, d, generator=rng)
    return gaussian.sign().to(torch.int8)
def context_embed(ctx_ids, C, permutation_matrix=None):
    """Permutation-bind the last k chars into a single ±1 query vector.

    Each context position p looks up its character code in ``C``, circularly
    shifts that code by p along the feature axis, and the shifted codes are
    bundled by an element-wise majority vote (sign of the sum). Ties
    (sum == 0) resolve to +1.

    The circular shift stands in for a general permutation because it is
    cheap. NOTE(review): the original comment cited "Rachkovskij 2112" for
    this; the year looks like a typo — confirm the intended reference.

    Args:
        ctx_ids: integer tensor of shape (..., k) with k >= 1; values index
            rows of ``C``. Any number of leading dimensions is accepted.
        C: (V, D) codebook of ±1 entries.
        permutation_matrix: accepted for interface compatibility; it is
            not used by this implementation.

    Returns:
        int8 tensor of shape (..., D) with entries in {-1, +1}, on the same
        device as ``ctx_ids``.

    Raises:
        ValueError: if ``ctx_ids`` is a 0-dimensional tensor.
    """
    if ctx_ids.dim() < 1:
        raise ValueError("ctx_ids must have shape (..., k) with at least one dimension")

    D = C.shape[-1]
    lead_shape = tuple(ctx_ids.shape[:-1])
    k = ctx_ids.shape[-1]
    device = ctx_ids.device

    # Collapse all leading dimensions so the core logic only sees (B, k).
    flat_ids = ctx_ids.reshape(-1, k)
    codes = C.to(device)[flat_ids]  # (B, k, D) ±1

    # Accumulate the shifted codes one position at a time instead of
    # materialising a second (B, k, D) tensor; the sum is identical.
    s = torch.zeros(flat_ids.shape[0], D, dtype=torch.int32, device=device)
    for p in range(k):
        # Char at position p is shifted by p so that order matters.
        s += torch.roll(codes[:, p, :], shifts=p, dims=-1).to(torch.int32)

    # Bundle with sign-of-sum (majority vote); tie-break at zero → +1.
    out = torch.sign(s).to(torch.int8)
    out[out == 0] = 1
    return out.reshape(*lead_shape, D)
def retrieve_topk(query, A, topk):
    """Return a boolean mask of the top-k addresses by Hamming similarity.

    Similarity is the dot product between each query row and each address
    row; for ±1 vectors this is a monotone function of Hamming distance.

    Args:
        query: (B, D) tensor of ±1 values.
        A: (N, D) tensor of ±1 hard addresses.
        topk: number of addresses to activate per query. Values larger than
            N are clamped to N (every address active) instead of raising.

    Returns:
        (B, N) bool tensor with exactly ``min(topk, N)`` True entries per row.
    """
    n_addrs = A.shape[0]
    k_eff = min(int(topk), n_addrs)

    # Dot products of ±1 vectors are integers with |value| <= D, which
    # float32 represents exactly, so this matches the integer result.
    # NOTE(review): float is used because integer matmul kernels are, as far
    # as I know, not available on CUDA — confirm on the target build.
    dots = query.to(torch.float32) @ A.to(torch.float32).t()  # (B, N)

    _, idx = torch.topk(dots, k=k_eff, dim=1)  # (B, k_eff)
    mask = torch.zeros(dots.shape, dtype=torch.bool, device=dots.device)
    mask.scatter_(1, idx, True)
    return mask
class SDMCharLM:
    """Character-level language model backed by a sparse distributed memory.

    State:
        C: (V, D) int8 character codebook from ``char_hv``.
        A: (N, D) int8 hard-address matrix from ``random_hard_addresses``.
        M: (N, D) int32 counter matrix, initialised to zero.

    Writing adds the target character's code to the counters of every
    address activated by the context query; reading sums the counters of
    the activated addresses, takes the sign, and scores the result against
    every row of ``C``.
    """

    def __init__(self, vocab_size=128, d=512, n_addrs=2**15, context=16, topk=None,
                 device='cuda', seed=0):
        """Allocate the codebook, address matrix and zeroed counters.

        Args:
            vocab_size: number of rows in the codebook; character ids used
                later index into it and so must be smaller than this.
            d: dimensionality D of codes, addresses and counters.
            n_addrs: number N of hard addresses.
            context: number k of preceding characters forming a query.
            topk: addresses activated per query; defaults to
                ``max(8, n_addrs // 100)`` (about 1% of N).
            device: torch device for all tensors.
            seed: codebook seed; the address matrix uses ``seed + 1``.
        """
        self.V = vocab_size
        self.D = d
        self.N = n_addrs
        self.k = context
        # Activate top-k addresses per query (default ~1% of N)
        self.topk = topk if topk is not None else max(8, n_addrs // 100)
        self.device = device
        self.C = char_hv(vocab_size, d, seed=seed).to(device)  # (V,D) ±1 int8
        self.A = random_hard_addresses(n_addrs, d, seed=seed + 1).to(device)  # (N,D)
        self.M = torch.zeros(n_addrs, d, dtype=torch.int32, device=device)  # counters

    @staticmethod
    def _gather_pairs(data, starts, k):
        """Return (contexts, next_chars) as int64 arrays for window starts.

        contexts[i] is ``data[starts[i] : starts[i] + k]`` and next_chars[i]
        is ``data[starts[i] + k]``; both are gathered with one vectorised
        index instead of a Python loop over samples.
        """
        starts = np.asarray(starts, dtype=np.int64)
        window = starts[:, None] + np.arange(k, dtype=np.int64)[None, :]  # (B, k)
        ctx = np.asarray(data[window], dtype=np.int64)
        nxt = np.asarray(data[starts + k], dtype=np.int64)
        return ctx, nxt

    @torch.no_grad()
    def train(self, data: np.memmap, max_samples=200_000, batch=512, verbose=True):
        """Single-pass write over (context, next_char) pairs drawn from data.

        Window starts are sampled uniformly at random (with NumPy's global
        RNG) until ``max_samples`` pairs have been written.

        Args:
            data: 1-D array of character ids.
            max_samples: total number of pairs to write.
            batch: pairs processed per step.
            verbose: print progress roughly every 10,000 written pairs.

        Raises:
            ValueError: if ``data`` is too short for one window or
                ``batch`` is not positive.
        """
        N_data = len(data) - self.k - 1
        if N_data <= 0:
            raise ValueError(
                f"data of length {len(data)} is too short for context {self.k}")
        if batch <= 0:
            raise ValueError("batch must be positive")

        log_every = 10_000
        next_log = log_every
        n_written = 0
        t0 = time.time()
        while n_written < max_samples:
            b = min(batch, max_samples - n_written)
            starts = np.random.randint(0, N_data, size=b)
            ctx, nxt = self._gather_pairs(data, starts, self.k)
            ctx_t = torch.from_numpy(ctx).to(self.device)
            nxt_t = torch.from_numpy(nxt).to(self.device)
            q = context_embed(ctx_t, self.C)  # (B, D) ±1
            # Active addresses: (B, N)
            active = retrieve_topk(q, self.A, self.topk)
            # Target codes: (B, D) from C[nxt]
            target = self.C[nxt_t].to(torch.float32)  # (B, D) ±1
            # For each active(i, j), M[j] += target[i], i.e. M += active^T @ target.
            # Done in float32: every entry is an integer with |value| <= B,
            # so the product is exact and converts back to int32 losslessly.
            # NOTE(review): avoids integer matmul, which I believe is not
            # implemented on CUDA — confirm on the target build.
            update = active.to(torch.float32).t() @ target  # (N, D)
            self.M.add_(update.to(torch.int32))
            n_written += b
            # A plain `n_written % 10000 == 0` test almost never fires when
            # the batch size does not divide 10,000, so use a threshold.
            if verbose and n_written >= next_log:
                print(f"written {n_written:,} | elapsed {time.time()-t0:.1f}s | "
                      f"avg active/query {active.to(torch.int32).sum(dim=1).float().mean().item():.0f}")
                next_log = (n_written // log_every + 1) * log_every

    @torch.no_grad()
    def predict_logits(self, ctx_t):
        """Score every character for each context.

        Args:
            ctx_t: (B, k) int64 tensor of character ids.

        Returns:
            (B, V) float32 scores: the dot product of the retrieved ±1
            estimate with each codebook row, divided by D.
        """
        q = context_embed(ctx_t, self.C)
        active = retrieve_topk(q, self.A, self.topk)  # (B, N)
        # y_est = sign(Σ active_j · M_j) per sample.
        # float64 keeps the integer sums exact (int32 counters summed over at
        # most N rows stay far below 2**53) without relying on integer matmul.
        sums = active.to(torch.float64) @ self.M.to(torch.float64)  # (B, D)
        y_est = torch.sign(sums)
        # Tie-break at zero → +1, matching the query embedding convention.
        y_est = torch.where(y_est == 0, torch.ones_like(y_est), y_est)
        # Scores against char codebook: (B, V) = (B, D) @ (D, V) / D
        scores = y_est.to(torch.float32) @ self.C.to(torch.float32).t() / self.D
        return scores

    @torch.no_grad()
    def evaluate_bpc(self, data: np.memmap, max_samples=20_000, batch=256, temperature=0.1):
        """Compute cross-entropy and bits-per-character on held-out data.

        Window starts are drawn from a ``RandomState(42)`` so repeated calls
        evaluate the same positions.

        Args:
            data: 1-D array of character ids.
            max_samples: upper bound on the number of evaluated positions.
            batch: positions scored per step.
            temperature: scores are divided by this before the softmax;
                smaller values sharpen the distribution.

        Returns:
            ``(avg_loss_nats, avg_loss_bits)`` averaged over all positions.

        Raises:
            ValueError: if ``temperature`` is not positive or there are no
                positions to evaluate.
        """
        import torch.nn.functional as F

        if temperature <= 0:
            raise ValueError("temperature must be positive")
        N_data = len(data) - self.k - 1
        n = min(max_samples, N_data)
        if n <= 0:
            raise ValueError("no samples to evaluate (data too short or max_samples <= 0)")

        rng = np.random.RandomState(42)
        starts = rng.randint(0, N_data, size=n)
        total_loss = 0.0
        total_cnt = 0
        for i in range(0, n, batch):
            chunk = starts[i:i + batch]
            ctx, nxt = self._gather_pairs(data, chunk, self.k)
            ctx_t = torch.from_numpy(ctx).to(self.device)
            nxt_t = torch.from_numpy(nxt).to(self.device)
            logits = self.predict_logits(ctx_t) / temperature  # sharper
            loss = F.cross_entropy(logits, nxt_t, reduction='sum')
            total_loss += loss.item()
            total_cnt += chunk.shape[0]
        avg = total_loss / total_cnt
        return avg, avg / math.log(2)
def main():
    """Command-line entry point: build the memory, write once, report BPC.

    Reads ``train.bin`` and ``validation.bin`` (uint8) from ``--data-dir``,
    performs a single write pass over the training data, then prints
    loss/BPC on the validation data for a sweep of softmax temperatures.
    The sweep is 1.0, 0.3, 0.1, 0.03; a ``--temperature`` value outside that
    list is evaluated as an extra entry (previously the flag was ignored).
    """
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument('--data-dir', default='/root/bitnet1/data')
    ap.add_argument('--d', type=int, default=512)
    ap.add_argument('--n-addrs', type=int, default=2**15)
    ap.add_argument('--context', type=int, default=16)
    ap.add_argument('--topk', type=int, default=None)
    ap.add_argument('--train-samples', type=int, default=500_000)
    ap.add_argument('--eval-samples', type=int, default=20_000)
    ap.add_argument('--temperature', type=float, default=0.1)
    ap.add_argument('--device', default='cuda')
    args = ap.parse_args()

    device = args.device
    # Fall back rather than crash at tensor allocation when CUDA was
    # requested (it is the default) but is not available on this machine.
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA not available; falling back to cpu")
        device = 'cpu'

    sdm = SDMCharLM(d=args.d, n_addrs=args.n_addrs, context=args.context,
                    topk=args.topk, device=device)
    print(f"SDM: D={sdm.D} N={sdm.N} k={sdm.k} topk={sdm.topk}")
    print(f"Memory: {sdm.M.numel() * 4 / 1e6:.1f} MB")

    train_data = np.memmap(os.path.join(args.data_dir, 'train.bin'), dtype=np.uint8, mode='r')
    val_data = np.memmap(os.path.join(args.data_dir, 'validation.bin'), dtype=np.uint8, mode='r')

    t0 = time.time()
    sdm.train(train_data, max_samples=args.train_samples)
    print(f"Training complete in {time.time()-t0:.1f}s")

    temps = [1.0, 0.3, 0.1, 0.03]
    # Honour --temperature: with the default (0.1) the sweep is unchanged.
    if args.temperature not in temps:
        temps.append(args.temperature)
    for temp in temps:
        loss, bpc = sdm.evaluate_bpc(val_data, max_samples=args.eval_samples, temperature=temp)
        print(f"temp={temp}: loss={loss:.4f} bpc={bpc:.4f}")


if __name__ == '__main__':
    main()