| """Training tasks for standalone WrinkleBrane evaluation. |
| |
| Three tasks of increasing difficulty: |
| |
| 1. **Sequence Copy**: Show a random sequence and a separator, then |
| reproduce the sequence from memory. Tests basic memory write/read capability. |
| |
| 2. **Associative Recall**: Given key-value pairs followed by a query key, |
| predict the associated value. Tests selective retrieval. |
| |
| 3. **Synthetic Grammar LM**: Next-token prediction on sequences generated |
| by a procedural grammar with deterministic and stochastic rules. |
| Tests whether the model can learn distributional patterns. |
| |
| All tasks produce ``(input_ids, target_ids)`` pairs suitable for |
| cross-entropy training with the same model interface. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Tuple |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
| |
| |
| |
|
|
class SequenceCopyTask:
    """Memorize-and-reproduce task for testing memory write/read.

    The model sees a random sequence, then a SEP token, then must
    reproduce the sequence from memory (teacher-forced):

        Input:  ``[t_0, t_1, ..., t_{L-1}, SEP, t_0, t_1, ..., t_{L-2}]``
        Target: ``[IGN, IGN, ..., IGN,     t_0, t_1, ..., t_{L-1}]``

    The target starts with ``L`` ``IGN`` entries, so the SEP position is
    scored against ``t_0`` and each reproduced ``t_i`` against ``t_{i+1}``.
    Only the reproduction phase (from SEP onward) is scored. This directly
    tests the model's ability to store tokens in the membrane and retrieve
    them in order.

    Parameters
    ----------
    vocab_size : int
        Number of tokens, including the special SEP token ``0``. Data
        tokens are drawn uniformly from ``[1, vocab_size)``.
    seq_len : int
        Length of the random sequence to memorize. Must be at least 1.

    Raises
    ------
    ValueError
        If ``seq_len < 1`` (input and target lengths would disagree) or
        ``vocab_size < 2`` (no room for a data token).
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 8,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.sep_token = 0
        self.token_offset = 1  # data tokens are 1 .. vocab_size - 1
        self.ignore_index = -100  # same as F.cross_entropy's default
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if vocab_size <= self.token_offset:
            raise ValueError(
                f"vocab_size must be > {self.token_offset} to leave room for "
                f"data tokens, got {vocab_size}"
            )

    def generate_batch(
        self,
        batch_size: int,
        *,
        generator: "torch.Generator | None" = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of copy sequences.

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG used to draw the tokens. ``None`` uses torch's global RNG;
            pass a seeded generator for a reproducible (e.g. evaluation) batch.

        Returns
        -------
        input_ids : Tensor ``[B, 2 * seq_len]`` (``torch.long``)
        target_ids : Tensor ``[B, 2 * seq_len]`` (``torch.long``)
            First ``seq_len`` positions are ``ignore_index``.
        """
        L = self.seq_len

        tokens = torch.randint(
            self.token_offset, self.vocab_size, (batch_size, L),
            generator=generator,
        )

        # Teacher forcing: after SEP the model is fed the previous
        # ground-truth token, so t_{L-1} only ever appears as a target.
        sep = torch.full((batch_size, 1), self.sep_token, dtype=torch.long)
        input_ids = torch.cat([tokens, sep, tokens[:, :-1]], dim=1)

        # The memorization phase is unscored.
        ignore = torch.full((batch_size, L), self.ignore_index, dtype=torch.long)
        target_ids = torch.cat([ignore, tokens], dim=1)

        return input_ids, target_ids
|
|
|
|
| |
| |
| |
|
|
class AssociativeRecallTask:
    """Generate key-value association sequences.

    Format: ``[BOS, k_1, v_1, ..., k_n, v_n, SEP, k_query]``
    Target: ``[IGN, IGN, ...,              IGN, v_query]``

    Only the final position (where the query key is read) is scored: the
    model must output the value that was paired with ``k_query`` earlier in
    the sequence. Keys are distinct within a sequence, so the answer is
    unambiguous; values may repeat.

    Token layout: ``0`` = BOS, ``1`` = SEP, ``2`` = PAD (reserved; every
    position is filled, so it is not emitted). Keys come from the lower half
    of the remaining ids and values from the upper half, so a key can never
    be confused with a value.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary, including the 3 special tokens.
    n_pairs : int
        Number of key-value pairs per sequence. Must be in
        ``[1, (vocab_size - 3) // 2]`` (the number of distinct keys).

    Raises
    ------
    ValueError
        If ``n_pairs`` is outside that range.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        n_pairs: int = 4,
    ):
        self.vocab_size = vocab_size
        self.n_pairs = n_pairs

        self.bos_token = 0
        self.sep_token = 1
        self.pad_token = 2
        self.token_offset = 3
        self.ignore_index = -100

        n_keys = (vocab_size - self.token_offset) // 2
        if not 1 <= n_pairs <= n_keys:
            raise ValueError(
                f"n_pairs must be in [1, {n_keys}] for vocab_size={vocab_size} "
                f"(keys must be distinct), got {n_pairs}"
            )

    def generate_batch(
        self,
        batch_size: int,
        *,
        generator: "torch.Generator | None" = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of associative recall sequences.

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG used for keys, values and query choice. ``None`` uses torch's
            global RNG.

        Returns
        -------
        input_ids : Tensor ``[B, 2*n_pairs + 3]`` (``torch.long``)
        target_ids : Tensor ``[B, 2*n_pairs + 3]`` (``torch.long``)
            All positions are ``ignore_index`` except the last.
        """
        n = self.n_pairs
        data_range = self.vocab_size - self.token_offset
        n_keys = data_range // 2
        key_lo = self.token_offset
        value_lo = key_lo + n_keys

        # n distinct keys per row: the first n entries of a per-row random
        # permutation of the key range (argsort of iid uniforms). Sampling
        # with replacement would let the queried key appear twice with
        # different values, making the target ambiguous.
        keys = (
            torch.rand(batch_size, n_keys, generator=generator)
            .argsort(dim=1)[:, :n]
            + key_lo
        )
        values = torch.randint(
            value_lo, self.vocab_size, (batch_size, n), generator=generator,
        )

        rows = torch.arange(batch_size)
        query_idx = torch.randint(0, n, (batch_size,), generator=generator)
        query_keys = keys[rows, query_idx]
        query_values = values[rows, query_idx]

        # Interleave as k_1, v_1, k_2, v_2, ...
        pairs = torch.stack([keys, values], dim=2).reshape(batch_size, 2 * n)
        bos = torch.full((batch_size, 1), self.bos_token, dtype=torch.long)
        sep = torch.full((batch_size, 1), self.sep_token, dtype=torch.long)
        input_ids = torch.cat([bos, pairs, sep, query_keys.unsqueeze(1)], dim=1)

        # Score only the prediction made at the query-key position.
        target_ids = torch.full_like(input_ids, self.ignore_index)
        target_ids[:, -1] = query_values

        return input_ids, target_ids
|
|
|
|
| |
| |
| |
|
|
class SyntheticGrammarTask:
    """Procedural grammar with learnable deterministic and stochastic rules.

    Grammar structure:

    - Vocabulary: ``vocab_size`` tokens. Ids 0/1/2 are reserved for
      BOS/EOS/PAD and ids ``>= 3`` are data tokens. Only BOS is ever emitted
      (as the first input token); EOS and PAD are never generated.
    - Every data token has exactly one rule for choosing the next token:

      * **deterministic** (``deterministic_frac`` of data tokens): always
        ``A -> B``, with ``B`` drawn uniformly from all data tokens. ``B``
        may be ``A`` itself or another deterministic token, so a sequence
        can enter a purely deterministic cycle and stay in it.
      * **stochastic** (``stochastic_frac`` of data tokens): ``C -> D`` with
        probability ``p`` and ``C -> E`` otherwise, where ``D != E`` and
        ``p`` is drawn uniformly from ``[0.3, 0.7)``.
      * **wild** (all remaining data tokens): the next token is uniform over
        the data tokens.

    If the two fractions sum to more than 1, deterministic rules take
    priority and the stochastic set is truncated. The rules are fixed by
    ``seed``; only the sampled sequences vary.

    This creates a learnable language with enough structure to test whether
    the model captures distributional patterns.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary including special tokens.
    seq_len : int
        Sequence length (at least 1).
    deterministic_frac : float
        Fraction of data tokens with deterministic next-token rules, in [0, 1].
    stochastic_frac : float
        Fraction of data tokens with 2-way stochastic rules, in [0, 1].
    seed : int
        RNG seed for rule generation (grammar is fixed, sequences vary).

    Raises
    ------
    ValueError
        If ``seq_len < 1``, there are no data tokens, a fraction lies outside
        ``[0, 1]``, or stochastic rules are requested with fewer than two
        data tokens (a stochastic rule needs two distinct successors).
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 64,
        deterministic_frac: float = 0.4,
        stochastic_frac: float = 0.3,
        seed: int = 42,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.bos_token = 0
        self.eos_token = 1
        self.pad_token = 2
        self.token_offset = 3

        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if vocab_size <= self.token_offset:
            raise ValueError(
                f"vocab_size must exceed the {self.token_offset} reserved "
                f"tokens, got {vocab_size}"
            )
        for name, frac in (
            ("deterministic_frac", deterministic_frac),
            ("stochastic_frac", stochastic_frac),
        ):
            if not 0.0 <= frac <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {frac}")

        # NOTE: the order of RNG draws below defines the grammar for a given
        # seed; reordering them silently changes every grammar.
        gen = torch.Generator().manual_seed(seed)
        data_tokens = list(range(self.token_offset, vocab_size))
        n_data = len(data_tokens)
        n_det = int(n_data * deterministic_frac)
        n_stoch = int(n_data * stochastic_frac)

        # Random disjoint split: the first n_det shuffled tokens get
        # deterministic rules, the next n_stoch get stochastic rules.
        perm = torch.randperm(n_data, generator=gen).tolist()
        det_tokens = [data_tokens[i] for i in perm[:n_det]]
        stoch_tokens = [data_tokens[i] for i in perm[n_det:n_det + n_stoch]]
        if stoch_tokens and n_data < 2:
            # Otherwise the "distinct successors" loop below never ends.
            raise ValueError(
                "stochastic rules need at least two data tokens; "
                f"vocab_size={vocab_size} provides {n_data}"
            )

        def draw_token() -> int:
            """Draw one data token uniformly from the grammar RNG."""
            idx = torch.randint(0, n_data, (1,), generator=gen).item()
            return data_tokens[idx]

        self.det_rules = {t: draw_token() for t in det_tokens}

        self.stoch_rules = {}
        for t in stoch_tokens:
            a = draw_token()
            b = draw_token()
            while b == a:  # resample until the two successors differ
                b = draw_token()
            prob_a = 0.3 + 0.4 * torch.rand(1, generator=gen).item()
            self.stoch_rules[t] = (a, b, prob_a)

        self.wild_tokens = [
            t for t in data_tokens
            if t not in self.det_rules and t not in self.stoch_rules
        ]

        self._build_lookup_tables()

    def _build_lookup_tables(self) -> None:
        """Mirror the rule dicts into dense per-token tensors for vectorised sampling."""
        V = self.vocab_size
        # Rule type per token id: 0 = deterministic, 1 = stochastic, 2 = wild.
        # Reserved ids keep the "wild" default but are never looked up, since
        # sampling only ever produces data tokens.
        self._rule_type = torch.full((V,), 2, dtype=torch.long)
        self._det_target = torch.zeros(V, dtype=torch.long)
        self._stoch_a = torch.zeros(V, dtype=torch.long)
        self._stoch_b = torch.zeros(V, dtype=torch.long)
        self._stoch_p = torch.zeros(V)

        for t, nt in self.det_rules.items():
            self._rule_type[t] = 0
            self._det_target[t] = nt
        for t, (a, b, p) in self.stoch_rules.items():
            self._rule_type[t] = 1
            self._stoch_a[t] = a
            self._stoch_b[t] = b
            self._stoch_p[t] = p

    def generate_batch(
        self,
        batch_size: int,
        *,
        generator: "torch.Generator | None" = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of grammar sequences (vectorised across the batch).

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG used for sampling sequences (not for the fixed grammar).
            ``None`` uses torch's global RNG.

        Returns
        -------
        input_ids : Tensor ``[B, seq_len]`` (``torch.long``)
            Starts with BOS.
        target_ids : Tensor ``[B, seq_len]`` (``torch.long``)
            Shifted by one (standard LM target); only data tokens.
        """
        B = batch_size
        S = self.seq_len + 1  # BOS + seq_len sampled tokens, shifted below

        tokens = torch.empty(B, S, dtype=torch.long)
        tokens[:, 0] = self.bos_token

        # The first data token is uniform; every later one follows the rules.
        current = torch.randint(
            self.token_offset, self.vocab_size, (B,), generator=generator,
        )
        tokens[:, 1] = current

        # Pre-draw all randomness so the loop is pure table lookups.
        # Columns 0 and 1 are drawn but unused.
        rand_vals = torch.rand(B, S, generator=generator)
        wild_draws = torch.randint(
            self.token_offset, self.vocab_size, (B, S), generator=generator,
        )

        for t in range(2, S):
            rt = self._rule_type[current]
            stoch_next = torch.where(
                rand_vals[:, t] < self._stoch_p[current],
                self._stoch_a[current],
                self._stoch_b[current],
            )
            next_tok = torch.where(
                rt == 0,
                self._det_target[current],
                torch.where(rt == 1, stoch_next, wild_draws[:, t]),
            )
            tokens[:, t] = next_tok
            current = next_tok

        input_ids = tokens[:, :-1].contiguous()
        target_ids = tokens[:, 1:].contiguous()
        return input_ids, target_ids
|
|
|
|
| |
| |
| |
|
|
def compute_accuracy(
    logits: Tensor,
    targets: Tensor,
    ignore_index: int = -100,
) -> float:
    """Return the fraction of scored positions whose argmax prediction is correct.

    Parameters
    ----------
    logits : Tensor ``[B, T, V]``
        Scores per class; the prediction is the argmax over the last axis.
    targets : Tensor ``[B, T]``
        Gold token ids; positions equal to ``ignore_index`` are not scored.
    ignore_index : int
        Target value marking unscored positions.

    Returns
    -------
    float
        Accuracy in ``[0, 1]``; ``0.0`` when no position is scored.
    """
    predicted = logits.argmax(dim=-1)
    scored = targets.ne(ignore_index)
    total = scored.sum()
    if total == 0:
        return 0.0
    hits = predicted.eq(targets).logical_and(scored).sum()
    return float(hits) / float(total)
|
|
|
|
def compute_perplexity(loss: float) -> float:
    """Convert a natural-log cross-entropy loss to perplexity.

    The exponent is capped at 100, so a diverged loss yields a large finite
    value (about 2.7e43) instead of raising ``OverflowError`` in
    ``math.exp``. A NaN loss passes through and yields NaN.
    """
    capped = 100 if loss > 100 else loss
    return math.exp(capped)
|
|
|
|
| import math |
|
|