File size: 12,198 Bytes
3d7f6c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""Training tasks for standalone WrinkleBrane evaluation.

Three tasks of increasing difficulty:

1. **Sequence Copy**: Read a random sequence, then reproduce it after a
   SEP separator token.  Tests basic memory write/read capability.

2. **Associative Recall**: Given key-value pairs followed by a query key,
   predict the associated value.  Tests selective retrieval.

3. **Synthetic Grammar LM**: Next-token prediction on sequences generated
   by a procedural grammar with deterministic and stochastic rules.
   Tests whether the model can learn distributional patterns.

All tasks produce ``(input_ids, target_ids)`` pairs suitable for
cross-entropy training with the same model interface.
"""

from __future__ import annotations

from typing import Tuple

import torch
from torch import Tensor


# ---------------------------------------------------------------------------
# Task 1: Sequence Copy
# ---------------------------------------------------------------------------

class SequenceCopyTask:
    """Copy-from-memory task: read a random sequence, then replay it.

    Each example is a random payload, a single SEP token, and then the
    payload again (minus its last token) as teacher-forced input:

    Input:  ``[t_0, t_1, ..., t_{L-1}, SEP, t_0, t_1, ..., t_{L-2}]``
    Target: ``[IGN, IGN, ..., IGN,     t_0, t_1, ..., t_{L-1}]``

    The first ``L`` target positions carry ``ignore_index``, so only the
    replay phase (starting at the SEP position) contributes to the loss.

    Parameters
    ----------
    vocab_size : int
        Number of tokens (including special tokens).
    seq_len : int
        Length of the random sequence to memorize.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 8,
    ):
        # Token 0 is the separator; payload tokens occupy [1, vocab_size).
        self.sep_token = 0
        self.token_offset = 1
        # Target value marking positions that are not scored.
        self.ignore_index = -100
        self.vocab_size = vocab_size
        self.seq_len = seq_len

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Sample ``batch_size`` copy examples.

        Returns
        -------
        input_ids : Tensor ``[B, 2 * seq_len]``
        target_ids : Tensor ``[B, 2 * seq_len]``
            First ``seq_len`` positions are ``ignore_index``.
        """
        length = self.seq_len

        # Payload tokens, drawn uniformly from [token_offset, vocab_size).
        payload = torch.randint(
            self.token_offset, self.vocab_size, (batch_size, length),
        )

        # One SEP column, then the payload shifted right by one position
        # (the last payload token is never fed back as input).
        separator = payload.new_full((batch_size, 1), self.sep_token)
        input_ids = torch.cat((payload, separator, payload[:, :-1]), dim=1)

        # The reading phase is masked out; the replay phase is scored.
        masked = payload.new_full((batch_size, length), self.ignore_index)
        target_ids = torch.cat((masked, payload), dim=1)

        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Task 2: Associative Recall
# ---------------------------------------------------------------------------

class AssociativeRecallTask:
    """Generate key-value association sequences.

    Input:  ``[BOS, k_1, v_1, ..., k_n, v_n, SEP, k_query]``
    Target: ``[IGN, IGN, IGN, ..., IGN, IGN, IGN, v_query]``

    Only the final position's prediction is scored: at the query key the
    model must output the value that was paired with that key.

    Data tokens ``[token_offset, vocab_size)`` are split in two: the lower
    half is used for keys, the upper half for values.  Keys within one
    sequence are distinct whenever the key range holds at least
    ``n_pairs`` tokens, so the queried key identifies exactly one value.
    Values may repeat.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary.
    n_pairs : int
        Number of key-value pairs per sequence.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        n_pairs: int = 4,
    ):
        self.vocab_size = vocab_size
        self.n_pairs = n_pairs
        # Special tokens
        self.bos_token = 0
        self.sep_token = 1
        self.pad_token = 2  # fill value only; every position is overwritten
        self.token_offset = 3  # data tokens start here
        self.ignore_index = -100

        n_key_tokens = (vocab_size - self.token_offset) // 2
        if n_key_tokens < n_pairs:
            # Distinct keys are impossible here.  Keep the configuration
            # usable (as before) but tell the user the task is ambiguous.
            import warnings

            warnings.warn(
                f"AssociativeRecallTask: only {n_key_tokens} key tokens are "
                f"available for n_pairs={n_pairs}; keys will repeat within a "
                "sequence and the queried value may be ambiguous.",
                stacklevel=2,
            )

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Generate a batch of associative recall sequences.

        Returns
        -------
        input_ids : Tensor ``[B, 2*n_pairs + 3]``
            ``[BOS, k_1, v_1, ..., k_n, v_n, SEP, k_query]``.
        target_ids : Tensor ``[B, 2*n_pairs + 3]``
            All positions are ``ignore_index`` except the last, which
            holds the value paired with the queried key.
        """
        n = self.n_pairs
        data_range = self.vocab_size - self.token_offset
        n_key_tokens = data_range // 2
        value_start = self.token_offset + n_key_tokens

        if n <= n_key_tokens:
            # Sorting i.i.d. uniform noise gives an independent random
            # permutation of the key range per row; the first n columns
            # are therefore n distinct keys.
            noise = torch.rand(batch_size, n_key_tokens)
            keys = noise.argsort(dim=1)[:, :n] + self.token_offset
        else:
            # Not enough key tokens for distinct keys: sample with
            # replacement (legacy behaviour; duplicates are possible).
            keys = torch.randint(
                self.token_offset, value_start, (batch_size, n),
            )
        values = torch.randint(
            value_start, self.vocab_size, (batch_size, n),
        )

        # Pick one pair per sequence to be queried.
        rows = torch.arange(batch_size)
        query_idx = torch.randint(0, n, (batch_size,))
        query_keys = keys[rows, query_idx]
        query_values = values[rows, query_idx]

        # Build input: [BOS, k_1, v_1, ..., k_n, v_n, SEP, k_query]
        seq_len = 2 * n + 3
        input_ids = torch.full(
            (batch_size, seq_len), self.pad_token, dtype=torch.long,
        )
        input_ids[:, 0] = self.bos_token
        # Keys sit at odd positions 1, 3, ...; values at even positions 2, 4, ...
        input_ids[:, 1:1 + 2 * n:2] = keys
        input_ids[:, 2:2 + 2 * n:2] = values
        input_ids[:, 1 + 2 * n] = self.sep_token
        input_ids[:, 2 + 2 * n] = query_keys

        # Target: ignore all except last position
        target_ids = torch.full(
            (batch_size, seq_len), self.ignore_index, dtype=torch.long,
        )
        target_ids[:, -1] = query_values

        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Task 3: Synthetic Grammar LM
# ---------------------------------------------------------------------------

class SyntheticGrammarTask:
    """Procedural grammar with learnable deterministic and stochastic rules.

    Grammar structure:
    - Vocabulary: ``vocab_size`` tokens (first 3 reserved for BOS/EOS/PAD)
    - Rules are of the form: ``if current token is X, next token is Y``
      (deterministic) or ``next is Y1 or Y2 with probabilities p, 1-p``
      (stochastic), with ``p`` drawn from [0.3, 0.7)
    - Some tokens trigger deterministic bigram patterns (always A→B)
    - Some tokens trigger probabilistic choices (C → D or E)
    - Remaining tokens are "wild" (uniform random next data token)

    Every data token gets exactly one rule type.  The rule tables are
    drawn once from a private generator seeded with ``seed``; sequence
    sampling in :meth:`generate_batch` uses torch's global RNG.
    Generated sequences contain only BOS (at position 0) and data tokens.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary including special tokens.
    seq_len : int
        Sequence length.
    deterministic_frac : float
        Fraction of tokens with deterministic next-token rules.
    stochastic_frac : float
        Fraction of tokens with 2-way stochastic rules.
    seed : int
        RNG seed for rule generation (grammar is fixed, sequences vary).

    Raises
    ------
    ValueError
        If a stochastic rule is requested but fewer than two data tokens
        exist, so two distinct outcomes cannot be chosen.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 64,
        deterministic_frac: float = 0.4,
        stochastic_frac: float = 0.3,
        seed: int = 42,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.bos_token = 0
        self.eos_token = 1
        self.pad_token = 2
        self.token_offset = 3

        gen = torch.Generator().manual_seed(seed)
        data_tokens = list(range(self.token_offset, vocab_size))
        n_data = len(data_tokens)
        n_det = int(n_data * deterministic_frac)
        n_stoch = int(n_data * stochastic_frac)

        # Shuffle to assign rule types
        perm = torch.randperm(n_data, generator=gen).tolist()
        det_tokens = [data_tokens[i] for i in perm[:n_det]]
        stoch_tokens = [data_tokens[i] for i in perm[n_det:n_det + n_stoch]]

        if stoch_tokens and n_data < 2:
            # With a single data token the "pick b != a" loop below could
            # never terminate; fail loudly instead of hanging.
            raise ValueError(
                "stochastic rules need at least 2 data tokens "
                f"(vocab_size={vocab_size} leaves {n_data})"
            )

        def draw_token() -> int:
            # One uniform draw over the data tokens from the grammar RNG.
            return data_tokens[
                torch.randint(0, n_data, (1,), generator=gen).item()
            ]

        # Build rule tables.  The order of generator draws below is kept
        # fixed so that a given seed always yields the same grammar.
        self.det_rules = {}  # token -> next_token
        self.stoch_rules = {}  # token -> (token_a, token_b, prob_a)

        for t in det_tokens:
            self.det_rules[t] = draw_token()

        for t in stoch_tokens:
            a = draw_token()
            b = draw_token()
            while b == a:  # the two outcomes must differ
                b = draw_token()
            prob_a = 0.3 + 0.4 * torch.rand(1, generator=gen).item()  # 0.3-0.7
            self.stoch_rules[t] = (a, b, prob_a)

        self.wild_tokens = [
            t for t in data_tokens
            if t not in self.det_rules and t not in self.stoch_rules
        ]

        # Pre-compute vectorised lookup tables for fast batch generation.
        # rule_type[t]: 0=det, 1=stoch, 2=wild
        self._rule_type = torch.full((vocab_size,), 2, dtype=torch.long)
        # det_target[t]: deterministic next token (only valid when rule_type==0)
        self._det_target = torch.zeros(vocab_size, dtype=torch.long)
        # stoch_a[t], stoch_b[t], stoch_p[t]: stochastic rule params
        self._stoch_a = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_b = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_p = torch.zeros(vocab_size)

        for t, nt in self.det_rules.items():
            self._rule_type[t] = 0
            self._det_target[t] = nt
        for t, (a, b, p) in self.stoch_rules.items():
            self._rule_type[t] = 1
            self._stoch_a[t] = a
            self._stoch_b[t] = b
            self._stoch_p[t] = p

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Generate a batch of grammar sequences (vectorised).

        Each sequence starts with BOS followed by a uniformly random data
        token; every later token is produced by applying the rule of the
        previous token.

        Returns
        -------
        input_ids : Tensor ``[B, seq_len]``
        target_ids : Tensor ``[B, seq_len]``
            Shifted by one (standard LM target).
        """
        B = batch_size
        S = self.seq_len + 1  # need one extra for shift

        tokens = torch.empty(B, S, dtype=torch.long)
        tokens[:, 0] = self.bos_token

        # Random start tokens for the whole batch
        current = torch.randint(self.token_offset, self.vocab_size, (B,))
        tokens[:, 1] = current

        # Pre-sample all random numbers we'll need
        rand_vals = torch.rand(B, S)
        wild_tokens = torch.randint(self.token_offset, self.vocab_size, (B, S))

        # Sequential over time (each step depends on the previous token),
        # vectorised over the batch.
        for t in range(2, S):
            rt = self._rule_type[current]            # [B]
            det_next = self._det_target[current]     # [B]
            sa = self._stoch_a[current]              # [B]
            sb = self._stoch_b[current]              # [B]
            sp = self._stoch_p[current]              # [B]

            # Stochastic: pick a if rand < p, else b
            stoch_next = torch.where(rand_vals[:, t] < sp, sa, sb)

            # Combine: det if rt==0, stoch if rt==1, wild if rt==2
            non_det_next = torch.where(rt == 1, stoch_next, wild_tokens[:, t])
            next_tok = torch.where(rt == 0, det_next, non_det_next)

            tokens[:, t] = next_tok
            current = next_tok

        input_ids = tokens[:, :-1].contiguous()   # [B, seq_len]
        target_ids = tokens[:, 1:].contiguous()   # [B, seq_len]
        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Evaluation utilities
# ---------------------------------------------------------------------------

def compute_accuracy(
    logits: Tensor,
    targets: Tensor,
    ignore_index: int = -100,
) -> float:
    """Return the fraction of scored positions whose argmax matches the target.

    Positions whose target equals ``ignore_index`` are excluded from both
    the numerator and the denominator.

    Parameters
    ----------
    logits : Tensor ``[B, T, V]``
    targets : Tensor ``[B, T]``
    ignore_index : int
        Target values to ignore.

    Returns
    -------
    float
        Accuracy in [0, 1]; ``0.0`` when no position is scored.
    """
    predicted = logits.argmax(dim=-1)  # [B, T]
    scored = targets.ne(ignore_index)
    n_scored = scored.sum()
    if n_scored == 0:
        # Nothing to score: avoid dividing by zero.
        return 0.0
    n_hits = (predicted.eq(targets) & scored).sum()
    return float(n_hits) / float(n_scored)


def compute_perplexity(loss: float) -> float:
    """Return ``exp(loss)``, the perplexity for a cross-entropy ``loss``.

    The exponent is capped at 100 so very large losses do not overflow.
    """
    max_exponent = 100
    exponent = min(loss, max_exponent)
    return math.exp(exponent)


import math