"""Training tasks for standalone WrinkleBrane evaluation.

Three tasks of increasing difficulty:

1. **Sequence Copy**: Write a random sequence, then a separator, then
   reproduce the sequence token by token. Tests basic memory write/read
   capability.
2. **Associative Recall**: Given key-value pairs followed by a query key,
   predict the associated value. Tests selective retrieval.
3. **Synthetic Grammar LM**: Next-token prediction on sequences generated by
   a procedural grammar with deterministic and stochastic rules. Tests
   whether the model can learn distributional patterns.

All tasks produce ``(input_ids, target_ids)`` pairs of equal shape, aligned
for next-token cross-entropy training with the same model interface:
``target_ids[:, t]`` is the token the model should emit after reading
``input_ids[:, : t + 1]``. Unscored positions (copy and recall tasks) hold
``-100``, the default ``ignore_index`` of
``torch.nn.functional.cross_entropy``.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
from torch import Tensor


# ---------------------------------------------------------------------------
# Task 1: Sequence Copy
# ---------------------------------------------------------------------------

class SequenceCopyTask:
    """Memorize-and-reproduce task for testing memory write/read.

    The model sees a random sequence, then a SEP token, then must reproduce
    the sequence from memory::

        Input:  [t_0, t_1, ..., t_{L-1}, SEP, t_0, t_1, ..., t_{L-2}]
        Target: [IGN, IGN, ..., IGN,     t_0, t_1, t_2, ..., t_{L-1}]

    Only the reproduction phase is scored, starting at the SEP position
    (where the model must emit ``t_0``). This directly tests the model's
    ability to store tokens in the membrane and retrieve them in order.

    Parameters
    ----------
    vocab_size : int
        Number of tokens including the SEP token ``0``. Data tokens are drawn
        uniformly from ``[1, vocab_size)``.
    seq_len : int
        Length ``L`` of the random sequence to memorize (must be >= 1).

    Raises
    ------
    ValueError
        If ``vocab_size < 2`` (no data tokens) or ``seq_len < 1``.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 8,
    ):
        if vocab_size < 2:
            raise ValueError(f"vocab_size must be >= 2, got {vocab_size}")
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.sep_token = 0
        self.token_offset = 1  # data tokens start at 1
        self.ignore_index = -100

    def generate_batch(
        self,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of copy sequences.

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG to sample from. ``None`` (default) uses torch's global RNG.

        Returns
        -------
        input_ids : Tensor
            ``[B, 2 * seq_len]`` int64.
        target_ids : Tensor
            ``[B, 2 * seq_len]`` int64. First ``seq_len`` positions are
            ``ignore_index``.
        """
        L = self.seq_len

        # Random data tokens in [token_offset, vocab_size).
        tokens = torch.randint(
            self.token_offset,
            self.vocab_size,
            (batch_size, L),
            generator=generator,
        )

        # Input: [t_0, ..., t_{L-1}, SEP, t_0, ..., t_{L-2}]. The echo drops
        # the final token because it is only ever a prediction target.
        sep = torch.full((batch_size, 1), self.sep_token, dtype=torch.long)
        input_ids = torch.cat([tokens, sep, tokens[:, :-1]], dim=1)  # [B, 2L]

        # Target: [IGN, ..., IGN, t_0, ..., t_{L-1}]
        ignore = torch.full((batch_size, L), self.ignore_index, dtype=torch.long)
        target_ids = torch.cat([ignore, tokens], dim=1)  # [B, 2L]

        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Task 2: Associative Recall
# ---------------------------------------------------------------------------

class AssociativeRecallTask:
    """Generate key-value association sequences.

    Format (length ``2 * n_pairs + 3``)::

        Input:  [BOS, k_1, v_1, ..., k_n, v_n, SEP, k_query]
        Target: [IGN, IGN, IGN, ..., IGN, IGN, IGN, v_query]

    Only the final position is scored: after reading the query key the model
    must emit the value that was paired with it. Keys are distinct within a
    sequence, so the answer is always unambiguous. Keys come from the lower
    half of the data-token range and values from the upper half.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary, including the 3 special tokens (BOS=0, SEP=1,
        PAD=2).
    n_pairs : int
        Number of key-value pairs per sequence.

    Raises
    ------
    ValueError
        If ``n_pairs < 1`` or the key range
        ``(vocab_size - 3) // 2`` holds fewer than ``n_pairs`` distinct keys.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        n_pairs: int = 4,
    ):
        self.vocab_size = vocab_size
        self.n_pairs = n_pairs

        # Special tokens (PAD is reserved but never emitted by this task).
        self.bos_token = 0
        self.sep_token = 1
        self.pad_token = 2
        self.token_offset = 3  # data tokens start here
        self.ignore_index = -100

        n_keys = (vocab_size - self.token_offset) // 2
        if n_pairs < 1:
            raise ValueError(f"n_pairs must be >= 1, got {n_pairs}")
        if n_keys < n_pairs:
            # Without enough distinct keys a query could match several
            # pairs and the target would be ambiguous.
            raise ValueError(
                f"vocab_size={vocab_size} allows only {max(n_keys, 0)} "
                f"distinct keys, fewer than n_pairs={n_pairs}"
            )

    def generate_batch(
        self,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of associative recall sequences.

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG to sample from. ``None`` (default) uses torch's global RNG.

        Returns
        -------
        input_ids : Tensor
            ``[B, 2*n_pairs + 3]`` int64.
        target_ids : Tensor
            ``[B, 2*n_pairs + 3]`` int64. All positions are
            ``ignore_index`` except the last.
        """
        n = self.n_pairs
        n_keys = (self.vocab_size - self.token_offset) // 2
        rows = torch.arange(batch_size)

        # Distinct keys per row: argsort of uniform noise is a random
        # permutation of the key pool; keep its first n entries.
        key_noise = torch.rand(batch_size, n_keys, generator=generator)
        keys = key_noise.argsort(dim=1)[:, :n] + self.token_offset
        values = torch.randint(
            self.token_offset + n_keys,
            self.vocab_size,
            (batch_size, n),
            generator=generator,
        )

        # Pick a random query index per batch row.
        query_idx = torch.randint(0, n, (batch_size,), generator=generator)
        query_keys = keys[rows, query_idx]
        query_values = values[rows, query_idx]

        def _column(token: int) -> Tensor:
            # A [B, 1] column filled with one special token.
            return torch.full((batch_size, 1), token, dtype=torch.long)

        # Interleave as k_1, v_1, k_2, v_2, ...
        pairs = torch.stack([keys, values], dim=2).reshape(batch_size, 2 * n)
        input_ids = torch.cat(
            [
                _column(self.bos_token),
                pairs,
                _column(self.sep_token),
                query_keys.unsqueeze(1),
            ],
            dim=1,
        )  # [B, 2n + 3]

        # Target: ignore all positions except the last.
        target_ids = torch.full(
            (batch_size, 2 * n + 3), self.ignore_index, dtype=torch.long
        )
        target_ids[:, -1] = query_values

        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Task 3: Synthetic Grammar LM
# ---------------------------------------------------------------------------

class SyntheticGrammarTask:
    """Procedural grammar with learnable deterministic and stochastic rules.

    Grammar structure:

    - Vocabulary: ``vocab_size`` tokens; the first 3 are reserved for
      BOS/EOS/PAD (only BOS is emitted, at position 0 of every input).
    - Each data token is assigned one rule for choosing the next token:

      - deterministic: always the same successor (``A -> B``). The successor
        may be any data token, including ``A`` itself;
      - stochastic: one of two distinct successors, the first chosen with a
        per-token probability drawn from ``[0.3, 0.7)``;
      - wild: a uniformly random data token.

    The rule table is fixed by ``seed``; only the sampled sequences vary
    between calls. This creates a learnable language with enough structure
    to test whether the model captures distributional patterns.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary including special tokens (must be > 3).
    seq_len : int
        Sequence length (must be >= 1).
    deterministic_frac : float
        Fraction of data tokens with deterministic next-token rules.
    stochastic_frac : float
        Fraction of data tokens with 2-way stochastic rules.
    seed : int
        RNG seed for rule generation (grammar is fixed, sequences vary).

    Raises
    ------
    ValueError
        If there are no data tokens, ``seq_len < 1``, a fraction lies
        outside ``[0, 1]`` or the fractions sum to more than 1, or a
        stochastic rule is requested with fewer than 2 data tokens.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 64,
        deterministic_frac: float = 0.4,
        stochastic_frac: float = 0.3,
        seed: int = 42,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.bos_token = 0
        self.eos_token = 1
        self.pad_token = 2
        self.token_offset = 3

        data_tokens = list(range(self.token_offset, vocab_size))
        n_data = len(data_tokens)
        self._validate(n_data, seq_len, deterministic_frac, stochastic_frac)

        n_det = int(n_data * deterministic_frac)
        n_stoch = int(n_data * stochastic_frac)
        if n_stoch > 0 and n_data < 2:
            # Stochastic rules need two distinct successors; with a single
            # data token the resampling loop below would never terminate.
            raise ValueError(
                "stochastic rules need at least 2 data tokens "
                f"(vocab_size={vocab_size})"
            )

        # NOTE: the order of draws from `gen` below is kept stable so that a
        # given seed always produces the same grammar.
        gen = torch.Generator().manual_seed(seed)

        def _draw_data_token() -> int:
            # One uniformly random data token from the rule-generation RNG.
            return data_tokens[torch.randint(0, n_data, (1,), generator=gen).item()]

        # Shuffle to assign rule types.
        perm = torch.randperm(n_data, generator=gen).tolist()
        det_tokens = [data_tokens[i] for i in perm[:n_det]]
        stoch_tokens = [data_tokens[i] for i in perm[n_det:n_det + n_stoch]]

        # Build rule tables.
        self.det_rules = {}  # token -> next_token
        self.stoch_rules = {}  # token -> (token_a, token_b, prob_a)
        for t in det_tokens:
            self.det_rules[t] = _draw_data_token()
        for t in stoch_tokens:
            a = _draw_data_token()
            b = _draw_data_token()
            while b == a:
                b = _draw_data_token()
            prob_a = 0.3 + 0.4 * torch.rand(1, generator=gen).item()  # [0.3, 0.7)
            self.stoch_rules[t] = (a, b, prob_a)

        # Data tokens with neither rule: successor is uniformly random.
        self.wild_tokens = [
            t for t in data_tokens
            if t not in self.det_rules and t not in self.stoch_rules
        ]

        # Pre-compute vectorised lookup tables for fast batch generation.
        # rule_type[t]: 0=det, 1=stoch, 2=wild
        self._rule_type = torch.full((vocab_size,), 2, dtype=torch.long)
        # det_target[t]: deterministic next token (only valid when rule_type==0)
        self._det_target = torch.zeros(vocab_size, dtype=torch.long)
        # stoch_a[t], stoch_b[t], stoch_p[t]: stochastic rule params
        self._stoch_a = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_b = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_p = torch.zeros(vocab_size)

        for t, nt in self.det_rules.items():
            self._rule_type[t] = 0
            self._det_target[t] = nt
        for t, (a, b, p) in self.stoch_rules.items():
            self._rule_type[t] = 1
            self._stoch_a[t] = a
            self._stoch_b[t] = b
            self._stoch_p[t] = p

    @staticmethod
    def _validate(
        n_data: int,
        seq_len: int,
        deterministic_frac: float,
        stochastic_frac: float,
    ) -> None:
        """Raise ``ValueError`` for constructor arguments that cannot work."""
        if n_data < 1:
            raise ValueError("vocab_size must leave at least 1 data token (> 3)")
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        for name, frac in (
            ("deterministic_frac", deterministic_frac),
            ("stochastic_frac", stochastic_frac),
        ):
            if not 0.0 <= frac <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {frac}")
        # Small tolerance so e.g. 0.7 + 0.3 is not rejected by float rounding.
        if deterministic_frac + stochastic_frac > 1.0 + 1e-9:
            raise ValueError(
                "deterministic_frac + stochastic_frac must not exceed 1, got "
                f"{deterministic_frac} + {stochastic_frac}"
            )

    def generate_batch(
        self,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of grammar sequences (vectorised).

        Parameters
        ----------
        batch_size : int
            Number of sequences ``B``.
        generator : torch.Generator, optional
            RNG for sequence sampling. ``None`` (default) uses torch's global
            RNG. The grammar itself is unaffected; it is fixed at
            construction.

        Returns
        -------
        input_ids : Tensor
            ``[B, seq_len]`` int64, starting with BOS.
        target_ids : Tensor
            ``[B, seq_len]`` int64, shifted by one (standard LM target).
        """
        B = batch_size
        S = self.seq_len + 1  # one extra token so input/target can be shifted

        tokens = torch.empty(B, S, dtype=torch.long)
        tokens[:, 0] = self.bos_token

        # Random start tokens for the whole batch.
        current = torch.randint(
            self.token_offset, self.vocab_size, (B,), generator=generator
        )
        tokens[:, 1] = current

        # Pre-sample all randomness up front (columns 0 and 1 go unused).
        coin = torch.rand(B, S, generator=generator)
        wild_draws = torch.randint(
            self.token_offset, self.vocab_size, (B, S), generator=generator
        )

        for t in range(2, S):
            rule = self._rule_type[current]  # [B]
            det_next = self._det_target[current]  # [B]
            # Stochastic: pick a if coin < p, else b.
            stoch_next = torch.where(
                coin[:, t] < self._stoch_p[current],
                self._stoch_a[current],
                self._stoch_b[current],
            )
            # Combine: det if rule==0, stoch if rule==1, wild if rule==2.
            next_tok = torch.where(
                rule == 0,
                det_next,
                torch.where(rule == 1, stoch_next, wild_draws[:, t]),
            )
            tokens[:, t] = next_tok
            current = next_tok

        input_ids = tokens[:, :-1].contiguous()  # [B, seq_len]
        target_ids = tokens[:, 1:].contiguous()  # [B, seq_len]
        return input_ids, target_ids


# ---------------------------------------------------------------------------
# Evaluation utilities
# ---------------------------------------------------------------------------

def compute_accuracy(
    logits: Tensor,
    targets: Tensor,
    ignore_index: int = -100,
) -> float:
    """Compute token-level argmax accuracy, ignoring masked positions.

    Parameters
    ----------
    logits : Tensor
        ``[B, T, V]``
    targets : Tensor
        ``[B, T]``
    ignore_index : int
        Target value marking positions to exclude from scoring.

    Returns
    -------
    float
        Accuracy in [0, 1]; ``0.0`` when no position is scored.
    """
    preds = logits.argmax(dim=-1)  # [B, T]
    mask = targets != ignore_index
    n_scored = int(mask.sum())
    if n_scored == 0:
        return 0.0
    n_correct = int(((preds == targets) & mask).sum())
    return n_correct / n_scored


def compute_perplexity(loss: float) -> float:
    """Convert a natural-log cross-entropy loss to perplexity.

    The loss is clamped at 100 before exponentiation, so divergent runs
    report ``exp(100)`` instead of raising ``OverflowError``.
    """
    return math.exp(min(loss, 100))