File size: 7,892 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""v10: Sparse Distributed Memory char-LM.

Zero backprop. Zero learned parameters (except optional final codebook).
Training = single pass of Hamming-ball writes; inference = Hamming-ball retrieval.

Per Bricken & Pehlevan 2021, attention approximates SDM under norm conditions.
This is the pure-SDM baseline — essentially the classical associative-memory answer.

Pipeline:
  1. Fix random ±1 hard-address matrix A ∈ {±1}^{N×D}, random char hypervectors
     C ∈ {±1}^{V×D}, integer counter matrix M ∈ ℤ^{N×D}.
  2. Context embedding: cyclic-shift-bind last k chars into ±1 query q ∈ {±1}^D.
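  3. For each (context, next_char) training pair: activate the `topk` addresses
     i with the largest dot(A_i, q) = D - 2·Hamming(A_i, q) — i.e. a Hamming ball
     whose radius adapts per query so that it holds the `topk` nearest addresses.
     For each active i, accumulate C[next_char] into M_i (update counter toward
     target).
  4. At inference: retrieve y_est = sign(Σ_{i active} M_i) (zeros → +1) and score
     each char v by y_est · C_v^T / D; the predicted char is the argmax, and BPC is
     computed from a softmax over these scores divided by a temperature.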

No gradient, no training loop over parameters — just a single pass through the
data updating integer counters.
"""
import math
import os
import time
import numpy as np
import torch


def char_hv(vocab_size, d, seed=0):
    """Build a reproducible random ±1 hypervector for every vocabulary symbol.

    Each entry is the sign of an independent standard-normal draw from a
    generator seeded with ``seed``. Returns an int8 tensor of shape
    (vocab_size, d).
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    gaussian = torch.randn((vocab_size, d), generator=rng)
    return gaussian.sign().to(dtype=torch.int8)


def random_hard_addresses(n, d, seed=1):
    """Sample the fixed ±1 hard-address matrix of the memory.

    Row i is the address of memory location i: the element-wise sign of a
    seeded standard-normal draw. Returns an int8 tensor of shape (n, d).
    """
    gen = torch.Generator().manual_seed(seed)
    samples = torch.randn(n, d, generator=gen)
    return samples.sign().type(torch.int8)


def context_embed(ctx_ids, C, permutation_matrix=None):
    """Bind a window of character ids into a single ±1 query hypervector.

    Position p of the window is marked by cyclically shifting its character
    code C[id] by p places along the hypervector axis (circular shift as a
    cheap permutation, per Rachkovskij 2112); the k shifted codes are then
    bundled by element-wise majority vote (sign of their sum).

    Args:
        ctx_ids: int64 tensor of shape (..., k); any number of leading batch
            dimensions (including none) is accepted.
        C: (V, D) ±1 character codebook; copied to ``ctx_ids.device``.
        permutation_matrix: unused; kept so existing callers keep working.

    Returns:
        int8 tensor of shape (..., D) with entries in {-1, +1}. Coordinates
        whose votes sum to zero (possible when k is even) are set to +1.
    """
    D = C.shape[1]
    *lead, k = ctx_ids.shape
    device = ctx_ids.device
    # Flatten all leading dims into one batch axis (explicit size so that
    # k == 0 does not make reshape(-1, k) ambiguous).
    batch = torch.Size(lead).numel()
    codes = C.to(device)[ctx_ids.reshape(batch, k)]  # (B, k, D) ±1
    # torch.roll(x, p)[j] == x[(j - p) mod D]: gather all k shifts in one call
    # instead of rolling each position separately.
    pos = torch.arange(k, device=device).unsqueeze(1)  # (k, 1)
    col = torch.arange(D, device=device).unsqueeze(0)  # (1, D)
    src = (col - pos) % D  # (k, D), non-negative (tensor % is remainder)
    rolled = torch.gather(codes, 2, src.expand(batch, k, D))  # (B, k, D)
    # Majority vote: widen from int8 before summing the k votes.
    votes = rolled.to(torch.int32).sum(dim=1)  # (B, D)
    out = torch.sign(votes).to(torch.int8)
    out[out == 0] = 1  # tie-break at zero -> +1
    return out.reshape(*lead, D)


def retrieve_topk(query, A, topk):
    """Return a boolean mask of the ``topk`` addresses most similar to each query.

    Similarity is the dot product of ±1 vectors, which equals
    D - 2 * Hamming distance, so the largest dots are the Hamming-nearest
    addresses.

    Args:
        query: (B, D) ±1 tensor.
        A: (N, D) ±1 hard-address matrix.
        topk: addresses to activate per query; values above N activate all N.

    Returns:
        (B, N) bool tensor with min(topk, N) True entries per row.
    """
    # ±1 dot products lie in [-D, D] and are exact in float32 for D < 2**24.
    # Float matmul is used because integer matmul is not implemented for CUDA.
    dots = query.to(torch.float32) @ A.to(torch.float32).t()  # (B, N)
    k = min(topk, A.shape[0])  # torch.topk raises if k exceeds N
    idx = torch.topk(dots, k=k, dim=1).indices  # (B, k)
    mask = torch.zeros_like(dots, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    return mask


class SDMCharLM:
    """Sparse Distributed Memory next-character model.

    Attributes:
        V, D, N, k: vocabulary size, hypervector width, number of hard
            addresses, and context length.
        topk: number of hard addresses activated per query (at most N).
        C: (V, D) int8 ±1 character codebook.
        A: (N, D) int8 ±1 hard-address matrix.
        M: (N, D) int32 counter matrix written by ``train``.
    """

    def __init__(self, vocab_size=128, d=512, n_addrs=2**15, context=16, topk=None,
                 device='cuda', seed=0):
        self.V = vocab_size
        self.D = d
        self.N = n_addrs
        self.k = context
        # Activate top-k addresses per query (default ~1% of N, at least 8),
        # but never more addresses than exist (torch.topk would raise).
        requested = topk if topk is not None else max(8, n_addrs // 100)
        self.topk = min(requested, n_addrs)
        self.device = device
        self.C = char_hv(vocab_size, d, seed=seed).to(device)  # (V,D) ±1 int8
        self.A = random_hard_addresses(n_addrs, d, seed=seed + 1).to(device)  # (N,D)
        self.M = torch.zeros(n_addrs, d, dtype=torch.int32, device=device)  # counters

    def _num_windows(self, data):
        """Count valid window starts s (context data[s:s+k], target data[s+k])."""
        # Valid starts are 0 .. len(data) - k - 1 inclusive.
        n = len(data) - self.k
        if n <= 0:
            raise ValueError(f"need more than {self.k} tokens, got {len(data)}")
        return n

    def _windows(self, data, starts):
        """Gather (context, next_char) pairs as int64 tensors (B, k) and (B,) on device."""
        offsets = starts[:, None] + np.arange(self.k)  # (B, k) vectorized slicing
        ctx = np.asarray(data[offsets], dtype=np.int64)
        nxt = np.asarray(data[starts + self.k], dtype=np.int64)
        return (torch.from_numpy(ctx).to(self.device),
                torch.from_numpy(nxt).to(self.device))

    def _active_indices(self, q):
        """Indices (B, topk) of the addresses with the largest dot(A_i, q)."""
        # ±1 dot products lie in [-D, D] and are exact in float32; float matmul is
        # used because integer matmul is not implemented for CUDA tensors.
        dots = q.to(torch.float32) @ self.A.to(torch.float32).t()  # (B, N)
        return torch.topk(dots, k=self.topk, dim=1).indices

    def train(self, data: np.memmap, max_samples=200_000, batch=512, verbose=True):
        """Single pass of writes over random (context, next_char) pairs from ``data``.

        For each sampled pair, C[next_char] is added to the counter row of every
        address activated by the context query.

        Args:
            data: 1-D array of character ids (e.g. a uint8 memmap).
            max_samples: number of (context, next_char) pairs to write.
            batch: pairs processed per step.
            verbose: print progress each time another 10,000 pairs are written.

        Raises:
            ValueError: if ``data`` is too short to hold a single window.
        """
        n_windows = self._num_windows(data)
        log_every = 10_000
        n_written = 0
        t0 = time.time()
        while n_written < max_samples:
            b = min(batch, max_samples - n_written)
            starts = np.random.randint(0, n_windows, size=b)
            ctx_t, nxt_t = self._windows(data, starts)
            q = context_embed(ctx_t, self.C)  # (B, D) ±1
            idx = self._active_indices(q)  # (B, topk)
            active = torch.zeros(b, self.N, dtype=torch.float32, device=self.device)
            active.scatter_(1, idx, 1.0)  # (B, N) 0/1 activation mask
            target = self.C[nxt_t].to(torch.float32)  # (B, D) ±1
            # M[j] += sum_i active[i, j] * target[i]. Every entry is a sum of at most
            # B values of ±1, so float32 is exact before converting back to int32.
            update = active.t() @ target  # (N, D)
            self.M.add_(update.to(torch.int32))
            prev = n_written
            n_written += b
            # Log whenever a multiple of log_every is crossed; a plain modulo test
            # almost never fires when batch does not divide log_every.
            if verbose and n_written // log_every > prev // log_every:
                print(f"written {n_written:,} | elapsed {time.time()-t0:.1f}s | "
                      f"avg active/query {active.sum(dim=1).mean().item():.0f}")

    @torch.no_grad()
    def predict_logits(self, ctx_t):
        """Score every character for each context.

        Args:
            ctx_t: (B, k) int64 context ids on ``self.device``.

        Returns:
            (B, V) float32 scores in [-1, 1]: dot product of sign(Σ active M_j)
            with each character code, divided by D.
        """
        q = context_embed(ctx_t, self.C)
        idx = self._active_indices(q)  # (B, topk)
        # Sum the active counter rows by gathering them: the integer sum (int64) is
        # exact and avoids integer matmul, which CUDA does not implement.
        sums = self.M[idx].sum(dim=1)  # (B, D)
        y_est = torch.sign(sums)
        y_est = torch.where(y_est == 0, torch.ones_like(y_est), y_est)  # ties -> +1
        # Scores against char codebook: (B, V) = (B, D) @ (D, V) / D
        scores = y_est.to(torch.float32) @ self.C.to(torch.float32).t() / self.D
        return scores

    @torch.no_grad()
    def evaluate_bpc(self, data: np.memmap, max_samples=20_000, batch=256, temperature=0.1):
        """Estimate next-character cross-entropy on held-out data.

        Draws ``min(max_samples, #windows)`` window starts with a fixed seed (42),
        divides the scores by ``temperature`` to form logits, and averages the
        cross-entropy against the true next character.

        Returns:
            (average loss in nats per character, bits per character).

        Raises:
            ValueError: if ``data`` holds no window or ``max_samples`` < 1.
        """
        import torch.nn.functional as F
        n_windows = self._num_windows(data)
        n = min(max_samples, n_windows)
        if n <= 0:
            raise ValueError(f"max_samples must be positive, got {max_samples}")
        rng = np.random.RandomState(42)
        starts = rng.randint(0, n_windows, size=n)
        total_loss = 0.0
        total_cnt = 0
        for i in range(0, n, batch):
            chunk = starts[i:i + batch]
            ctx_t, nxt_t = self._windows(data, chunk)
            logits = self.predict_logits(ctx_t) / temperature  # lower temp = sharper
            total_loss += F.cross_entropy(logits, nxt_t, reduction='sum').item()
            total_cnt += len(chunk)
        avg = total_loss / total_cnt
        return avg, avg / math.log(2)


if __name__ == '__main__':
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument('--data-dir', default='/root/bitnet1/data')
    # Data files are uint8, so ids may go up to 255; the codebook must cover them.
    ap.add_argument('--vocab-size', type=int, default=128,
                    help='codebook size; every id in the data must be below it')
    ap.add_argument('--d', type=int, default=512)
    ap.add_argument('--n-addrs', type=int, default=2**15)
    ap.add_argument('--context', type=int, default=16)
    ap.add_argument('--topk', type=int, default=None)
    ap.add_argument('--train-samples', type=int, default=500_000)
    ap.add_argument('--eval-samples', type=int, default=20_000)
    ap.add_argument('--temperature', type=float, default=0.1)
    ap.add_argument('--device', default='cuda')
    args = ap.parse_args()

    train_data = np.memmap(os.path.join(args.data_dir, 'train.bin'), dtype=np.uint8, mode='r')
    val_data = np.memmap(os.path.join(args.data_dir, 'validation.bin'), dtype=np.uint8, mode='r')

    # Character ids index the (V, D) codebook directly; an id >= V would be an
    # out-of-range index (a device-side assert on CUDA) deep inside training.
    for split, arr in (('train', train_data), ('validation', val_data)):
        max_id = int(arr.max())
        if max_id >= args.vocab_size:
            ap.error(f"{split}.bin contains id {max_id}, but --vocab-size is {args.vocab_size}")

    sdm = SDMCharLM(vocab_size=args.vocab_size, d=args.d, n_addrs=args.n_addrs,
                    context=args.context, topk=args.topk, device=args.device)
    print(f"SDM: D={sdm.D} N={sdm.N} k={sdm.k} topk={sdm.topk}")
    print(f"Memory: {sdm.M.numel() * 4 / 1e6:.1f} MB")

    t0 = time.time()
    sdm.train(train_data, max_samples=args.train_samples)
    print(f"Training complete in {time.time()-t0:.1f}s")

    for temp in [1.0, 0.3, 0.1, 0.03]:
        loss, bpc = sdm.evaluate_bpc(val_data, max_samples=args.eval_samples, temperature=temp)
        print(f"temp={temp}: loss={loss:.4f} bpc={bpc:.4f}")