# WCNegentropy's picture
# Upload 510 files
# 3d7f6c5 verified
"""Training tasks for standalone WrinkleBrane evaluation.
Three tasks of increasing difficulty:
1. **Sequence Copy**: Present a random sequence followed by a SEP token, then reproduce the sequence from memory.
Tests basic memory write/read capability.
2. **Associative Recall**: Given key-value pairs followed by a query key,
predict the associated value. Tests selective retrieval.
3. **Synthetic Grammar LM**: Next-token prediction on sequences generated
by a procedural grammar with deterministic and stochastic rules.
Tests whether the model can learn distributional patterns.
All tasks produce ``(input_ids, target_ids)`` pairs suitable for
cross-entropy training with the same model interface.
"""
from __future__ import annotations
from typing import Tuple
import torch
from torch import Tensor
# ---------------------------------------------------------------------------
# Task 1: Sequence Copy
# ---------------------------------------------------------------------------
class SequenceCopyTask:
    """Store-then-recall task probing memory write/read.

    Each example is a random data sequence, a separator, and then the same
    sequence again, teacher-forced and lagging one position behind:

        Input:  ``[t_0, t_1, ..., t_{L-1}, SEP, t_0, t_1, ..., t_{L-2}]``
        Target: ``[IGN, IGN, ..., IGN, t_0, t_1, ..., t_{L-1}]``

    Loss is taken only on the second half (from the SEP position onward).
    A model can score well there only by recalling what it saw in the first
    half.

    Parameters
    ----------
    vocab_size : int
        Number of tokens (including special tokens).
    seq_len : int
        Length of the random sequence to memorize.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 8,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        # Id 0 is reserved for the separator; data tokens live in [1, vocab_size).
        self.sep_token = 0
        self.token_offset = 1
        # -100 is also the default ``ignore_index`` of torch's cross_entropy.
        self.ignore_index = -100

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Sample ``batch_size`` copy examples.

        Returns
        -------
        input_ids : Tensor ``[B, 2 * seq_len]``
            Data sequence, SEP, then the data sequence minus its last token.
        target_ids : Tensor ``[B, 2 * seq_len]``
            ``ignore_index`` for the first ``seq_len`` positions, then the
            data sequence.
        """
        length = self.seq_len
        data = torch.randint(
            self.token_offset, self.vocab_size, (batch_size, length),
        )

        separator = data.new_full((batch_size, 1), self.sep_token)
        masked = data.new_full((batch_size, length), self.ignore_index)

        # The replay half is shifted right by the SEP, so it omits t_{L-1}.
        input_ids = torch.cat((data, separator, data[:, : length - 1]), dim=1)
        target_ids = torch.cat((masked, data), dim=1)
        return input_ids, target_ids
# ---------------------------------------------------------------------------
# Task 2: Associative Recall
# ---------------------------------------------------------------------------
class AssociativeRecallTask:
    """Generate key-value association sequences.

    Format: ``[BOS, k_1, v_1, k_2, v_2, ..., k_n, v_n, SEP, k_query]``
    Target: ``[IGN, IGN, IGN, ..., IGN, IGN, v_query]``

    Only the final position's prediction is scored. At the ``k_query``
    position the model must emit the value that was paired with that key.

    Keys are drawn without replacement within each sequence, so every query
    has exactly one correct answer. Keys come from the lower half of the
    data-token range and values from the upper half, so a key can never be
    confused with a value.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary.
    n_pairs : int
        Number of key-value pairs per sequence. Must satisfy
        ``1 <= n_pairs <= (vocab_size - 3) // 2``; this is checked in
        :meth:`generate_batch`.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        n_pairs: int = 4,
    ):
        self.vocab_size = vocab_size
        self.n_pairs = n_pairs
        # Special tokens. ``pad_token`` is kept for interface compatibility;
        # generated sequences contain no padding.
        self.bos_token = 0
        self.sep_token = 1
        self.pad_token = 2
        self.token_offset = 3  # data tokens start here
        self.ignore_index = -100

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Generate a batch of associative recall sequences.

        Returns
        -------
        input_ids : Tensor ``[B, 2*n_pairs + 3]``
        target_ids : Tensor ``[B, 2*n_pairs + 3]``
            All positions are ``ignore_index`` except the last.

        Raises
        ------
        ValueError
            If ``n_pairs < 1``, or if the key range holds fewer than
            ``n_pairs`` distinct tokens.
        """
        n = self.n_pairs
        data_range = self.vocab_size - self.token_offset
        n_keys = data_range // 2  # keys: [offset, offset + n_keys)
        if n < 1:
            raise ValueError(f"n_pairs must be >= 1, got {n}")
        if n > n_keys:
            # Distinct keys are required, otherwise a query can match two
            # pairs with different values and the target is ambiguous.
            raise ValueError(
                f"n_pairs={n} needs {n} distinct key tokens, but "
                f"vocab_size={self.vocab_size} provides only {n_keys}"
            )
        key_start = self.token_offset
        value_start = self.token_offset + n_keys  # values: [value_start, vocab_size)

        # argsort of i.i.d. uniforms is a random permutation of the key pool;
        # keeping its first n entries samples n keys without replacement.
        keys = torch.rand(batch_size, n_keys).argsort(dim=1)[:, :n] + key_start
        values = torch.randint(value_start, self.vocab_size, (batch_size, n))

        # Pick a random query index per row
        rows = torch.arange(batch_size)
        query_idx = torch.randint(0, n, (batch_size,))
        query_keys = keys[rows, query_idx]
        query_values = values[rows, query_idx]

        # Interleave into [k_1, v_1, ..., k_n, v_n]: stack to [B, n, 2], flatten.
        pairs = torch.stack((keys, values), dim=2).reshape(batch_size, 2 * n)
        bos = torch.full((batch_size, 1), self.bos_token, dtype=torch.long)
        sep = torch.full((batch_size, 1), self.sep_token, dtype=torch.long)
        input_ids = torch.cat([bos, pairs, sep, query_keys.unsqueeze(1)], dim=1)

        # Score only the last position (input k_query -> target v_query).
        target_ids = torch.full_like(input_ids, self.ignore_index)
        target_ids[:, -1] = query_values
        return input_ids, target_ids
# ---------------------------------------------------------------------------
# Task 3: Synthetic Grammar LM
# ---------------------------------------------------------------------------
class SyntheticGrammarTask:
    """Procedural grammar with learnable deterministic and stochastic rules.

    Grammar structure:

    - Vocabulary: ``vocab_size`` tokens. Ids 0/1/2 are reserved for
      BOS/EOS/PAD, and data tokens are ``[3, vocab_size)``. Of the reserved
      ids, only BOS is emitted by :meth:`generate_batch`.
    - Each data token is given exactly one rule type:

      * deterministic: ``X`` is always followed by a fixed ``Y``;
      * stochastic: ``X`` is followed by ``Y1`` with probability ``p`` and by
        a different ``Y2`` otherwise, with ``p`` drawn from ``[0.3, 0.7)``;
      * wild: ``X`` is followed by a uniformly random data token.

    The rules depend only on ``seed``; sampled sequences vary between calls.
    This creates a learnable language with enough structure to test whether
    the model captures distributional patterns.

    Parameters
    ----------
    vocab_size : int
        Total vocabulary including special tokens.
    seq_len : int
        Sequence length.
    deterministic_frac : float
        Fraction of data tokens (rounded down) with deterministic rules.
    stochastic_frac : float
        Fraction of data tokens (rounded down) with 2-way stochastic rules.
        If the two fractions sum past 1, the stochastic set is truncated to
        the tokens that remain.
    seed : int
        RNG seed for rule generation (grammar is fixed, sequences vary).

    Raises
    ------
    ValueError
        If a stochastic rule is requested but there is only one data token,
        so two distinct successors cannot exist.
    """

    def __init__(
        self,
        vocab_size: int = 32,
        seq_len: int = 64,
        deterministic_frac: float = 0.4,
        stochastic_frac: float = 0.3,
        seed: int = 42,
    ):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.bos_token = 0
        self.eos_token = 1
        self.pad_token = 2
        self.token_offset = 3
        # Private generator: the grammar depends only on ``seed`` and does not
        # consume the global torch RNG. The call order below is part of the
        # contract (same seed -> same grammar), so keep it stable.
        gen = torch.Generator().manual_seed(seed)
        data_tokens = list(range(self.token_offset, vocab_size))
        n_data = len(data_tokens)
        n_det = int(n_data * deterministic_frac)
        n_stoch = int(n_data * stochastic_frac)

        # Shuffle to assign rule types
        perm = torch.randperm(n_data, generator=gen).tolist()
        det_tokens = [data_tokens[i] for i in perm[:n_det]]
        stoch_tokens = [data_tokens[i] for i in perm[n_det:n_det + n_stoch]]
        if stoch_tokens and n_data < 2:
            # The rejection loop for a distinct second successor below could
            # never terminate with a single data token.
            raise ValueError(
                "stochastic rules need at least 2 data tokens, but "
                f"vocab_size={vocab_size} provides {n_data}"
            )

        def draw_data_token() -> int:
            # One uniform draw from the data tokens using the grammar RNG.
            return data_tokens[torch.randint(0, n_data, (1,), generator=gen).item()]

        # Build rule tables
        self.det_rules = {}  # token -> next_token
        self.stoch_rules = {}  # token -> (token_a, token_b, prob_a)
        for t in det_tokens:
            self.det_rules[t] = draw_data_token()
        for t in stoch_tokens:
            a = draw_data_token()
            b = draw_data_token()
            while b == a:
                b = draw_data_token()
            prob_a = 0.3 + 0.4 * torch.rand(1, generator=gen).item()  # [0.3, 0.7)
            self.stoch_rules[t] = (a, b, prob_a)
        self.wild_tokens = [
            t for t in data_tokens
            if t not in self.det_rules and t not in self.stoch_rules
        ]

        # Pre-compute vectorised lookup tables for fast batch generation.
        # rule_type[t]: 0=det, 1=stoch, 2=wild
        self._rule_type = torch.full((vocab_size,), 2, dtype=torch.long)
        # det_target[t]: deterministic next token (only valid when rule_type==0)
        self._det_target = torch.zeros(vocab_size, dtype=torch.long)
        # stoch_a[t], stoch_b[t], stoch_p[t]: stochastic rule params
        self._stoch_a = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_b = torch.zeros(vocab_size, dtype=torch.long)
        self._stoch_p = torch.zeros(vocab_size)
        for t, nt in self.det_rules.items():
            self._rule_type[t] = 0
            self._det_target[t] = nt
        for t, (a, b, p) in self.stoch_rules.items():
            self._rule_type[t] = 1
            self._stoch_a[t] = a
            self._stoch_b[t] = b
            self._stoch_p[t] = p

    def generate_batch(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Generate a batch of grammar sequences (vectorised over the batch).

        Each sequence is BOS, a uniformly random start token, and then tokens
        produced by applying the grammar rule of the previous token.

        Returns
        -------
        input_ids : Tensor ``[B, seq_len]``
        target_ids : Tensor ``[B, seq_len]``
            Shifted by one (standard LM target).

        Raises
        ------
        ValueError
            If ``seq_len < 1``.
        """
        if self.seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {self.seq_len}")
        B = batch_size
        S = self.seq_len + 1  # need one extra for shift
        tokens = torch.empty(B, S, dtype=torch.long)
        tokens[:, 0] = self.bos_token

        # Random start tokens for the whole batch
        current = torch.randint(self.token_offset, self.vocab_size, (B,))
        tokens[:, 1] = current

        # Pre-sample all random numbers we'll need (columns 0-1 go unused).
        rand_vals = torch.rand(B, S)
        wild_tokens = torch.randint(self.token_offset, self.vocab_size, (B, S))

        for t in range(2, S):
            rt = self._rule_type[current]  # [B]
            det_next = self._det_target[current]  # [B]
            sa = self._stoch_a[current]  # [B]
            sb = self._stoch_b[current]  # [B]
            sp = self._stoch_p[current]  # [B]
            # Stochastic: pick a if rand < p, else b
            stoch_next = torch.where(rand_vals[:, t] < sp, sa, sb)
            # Combine: det if rt==0, stoch if rt==1, wild if rt==2
            next_tok = torch.where(
                rt == 0,
                det_next,
                torch.where(rt == 1, stoch_next, wild_tokens[:, t]),
            )
            tokens[:, t] = next_tok
            current = next_tok

        input_ids = tokens[:, :-1].contiguous()  # [B, seq_len]
        target_ids = tokens[:, 1:].contiguous()  # [B, seq_len]
        return input_ids, target_ids
# ---------------------------------------------------------------------------
# Evaluation utilities
# ---------------------------------------------------------------------------
def compute_accuracy(
    logits: Tensor,
    targets: Tensor,
    ignore_index: int = -100,
) -> float:
    """Return the fraction of scored positions whose arg-max prediction is right.

    Positions whose target equals ``ignore_index`` count toward neither the
    numerator nor the denominator.

    Parameters
    ----------
    logits : Tensor ``[B, T, V]``
        Per-token scores; the prediction is the arg-max over the last dim.
    targets : Tensor ``[B, T]``
        Target token ids.
    ignore_index : int
        Target value marking positions that are not scored.

    Returns
    -------
    float
        Accuracy in [0, 1], or ``0.0`` when no position is scored.
    """
    predictions = logits.argmax(dim=-1)  # [B, T]
    scored = targets.ne(ignore_index)
    n_scored = int(scored.sum())
    if not n_scored:
        return 0.0
    hits = int((predictions.eq(targets) & scored).sum())
    return hits / n_scored
def compute_perplexity(loss: float) -> float:
    """Turn a cross-entropy loss into perplexity, ``exp(loss)``.

    The natural exponent is used, so ``loss`` is expected in nats. Losses
    above 100 are capped at 100 so a diverged run reports ``exp(100)``
    rather than making ``math.exp`` raise ``OverflowError``. A NaN loss is
    passed through and yields NaN.
    """
    capped = 100 if loss > 100 else loss
    return math.exp(capped)
import math