| """Synthetic data for FOG ablation — algorithmic tasks. | |
| Easy: CopyTask, ReverseTask, SelectiveRetrieval | |
| Hard: DistractorRetrieval, NoisyRetrieval, MultiQueryRetrieval, ChainedRetrieval | |
| """ | |
| from __future__ import annotations | |
| import random | |
| import torch | |
| from torch.utils.data import Dataset | |

def _build_item(ids: list[int], sep_pos: int, seq_len: int) -> dict[str, torch.Tensor]:
    """Shared helper: pad/truncate to seq_len, build shifted input/target
    tensors and a loss mask that supervises only the tokens after SEP."""
    real_len = len(ids)
    ids = ids[:seq_len]
    ids += [0] * (seq_len - len(ids))
    x = torch.tensor(ids[:-1], dtype=torch.long)
    y = torch.tensor(ids[1:], dtype=torch.long)
    m = torch.zeros_like(y)
    # Supervise only real tokens after SEP, never the zero padding.
    end = min(real_len - 1, len(m))  # -1 because targets are shifted by one
    if sep_pos < end:
        m[sep_pos:end] = 1
    return {"input_ids": x, "targets": y, "loss_mask": m}

def prebatch_dataset(dataset: Dataset, seq_len: int) -> dict[str, torch.Tensor]:
    """Pre-stack the entire dataset into contiguous tensors for fast batching."""
    n = len(dataset)
    all_x = torch.zeros(n, seq_len - 1, dtype=torch.long)
    all_y = torch.zeros(n, seq_len - 1, dtype=torch.long)
    all_m = torch.zeros(n, seq_len - 1, dtype=torch.long)
    for i in range(n):
        item = dataset[i]
        L = item["input_ids"].size(0)
        all_x[i, :L] = item["input_ids"]
        all_y[i, :L] = item["targets"]
        all_m[i, :L] = item["loss_mask"]
    return {"input_ids": all_x, "targets": all_y, "loss_mask": all_m}

class TensorBatchIterator:
    """Fast batch iterator over pre-stacked tensors. No DataLoader overhead."""

    def __init__(self, data: dict[str, torch.Tensor], batch_size: int, shuffle: bool = False):
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n = data["input_ids"].size(0)

    def __iter__(self):
        if self.shuffle:
            perm = torch.randperm(self.n)
        else:
            perm = torch.arange(self.n)
        for start in range(0, self.n, self.batch_size):
            idx = perm[start : start + self.batch_size]
            yield {k: v[idx] for k, v in self.data.items()}

    def __len__(self) -> int:
        return (self.n + self.batch_size - 1) // self.batch_size
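
# Minimal usage sketch (sizes are illustrative; CopyTask is defined below):
#
#     ds = CopyTask(vocab_size=32, seq_len=16, n_samples=1024)
#     data = prebatch_dataset(ds, seq_len=16)
#     for batch in TensorBatchIterator(data, batch_size=64, shuffle=True):
#         ...  # each tensor in `batch` has shape (batch_size, seq_len - 1)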

# ── Easy tasks ──────────────────────────────────────────────────

class CopyTask(Dataset):
    """Copy: [a, b, c, SEP] -> [a, b, c]. Tests memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 1
        half = seq_len // 2 - 1
        self.items = []
        for _ in range(n_samples):
            c = [rng.randint(0, cv - 1) for _ in range(half)]
            ids = c + [sep] + c
            self.items.append(_build_item(ids, len(c), seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class ReverseTask(Dataset):
    """Reverse: [a, b, c, SEP] -> [c, b, a]. Tests compare + memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 1
        half = seq_len // 2 - 1
        self.items = []
        for _ in range(n_samples):
            c = [rng.randint(0, cv - 1) for _ in range(half)]
            ids = c + [sep] + list(reversed(c))
            self.items.append(_build_item(ids, len(c), seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class SelectiveRetrieval(Dataset):
    """[k1, v1, k2, v2, SEP, query_key] -> answer_value. Tests compare + select + memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = rng.randint(0, len(keys) - 1)
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
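
# Concrete SelectiveRetrieval layout (illustrative, n_pairs = 2):
#   keys = [5, 9], values = [3, 7], query key = 9
#   ids  = [5, 3, 9, 7, SEP, 9, 7], sep_pos = 4
# _build_item supervises every token after SEP, i.e. both the echoed
# query key (9) and the answer value (7).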

# ── Hard tasks ──────────────────────────────────────────────────

class DistractorRetrieval(Dataset):
    """Distractor keys differ from the query key by small offsets
    (±1 up to ±n_pairs). Forces a precise compare."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            qk = rng.randint(n_pairs, cv - n_pairs - 1)
            offsets = [i for i in range(-n_pairs, n_pairs + 1) if i != 0]
            dk = [qk + o for o in offsets if 0 <= qk + o < cv]
            keys = [qk] + dk[:n_pairs - 1]
            rng.shuffle(keys)
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = keys.index(qk)
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, qk, values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class NoisyRetrieval(Dataset):
    """Noise tokens between KV pairs. Forces select to filter, memory to retrieve."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 3, noise_len: int = 2, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = rng.randint(0, len(keys) - 1)
            ids = []
            for i, (k, v) in enumerate(zip(keys, values)):
                ids.extend([k, v])
                if i < len(keys) - 1:
                    ids.extend([rng.randint(0, cv - 1) for _ in range(noise_len)])
            sp = len(ids)
            ids += [sep, keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class MultiQueryRetrieval(Dataset):
    """2 sequential queries — retrieve 2 values. Tests compose."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qis = rng.sample(range(len(keys)), min(2, len(keys)))
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids.append(sep)
            for qi in qis:
                ids += [keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class ChainedRetrieval(Dataset):
    """Two-hop lookup: find the value for the query key, then use that value
    as a key for a second lookup.

        [k1, v1, k2, v2, ..., kN, vN, SEP, query_key, final_answer]

    The model must:
      1. Compare query_key against all keys → find the matching value
         (Φ_compare + Φ_memory)
      2. Use that value as a new key → find its value (Φ_compose + Φ_memory)
      3. Output the final value

    This is compositional: uniform models with shared compare/memory
    struggle when capacity is tight, while motif-aware models with
    dedicated compare (narrow) and memory (wide) subspaces can separate
    the two lookups.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        attempts = 0
        while len(self.items) < n_samples and attempts < n_samples * 20:
            attempts += 1
            if cv < n_pairs:
                break
            keys = rng.sample(range(cv), n_pairs)
            values = [rng.randint(0, cv - 1) for _ in keys]
            # Find a valid chain: query_key → value_1, where value_1 also
            # appears as the key of a *different* pair → value_2.
            chain_found = False
            for qi in range(n_pairs):
                v1 = values[qi]
                for hop2 in range(n_pairs):
                    if hop2 != qi and keys[hop2] == v1:
                        # Chain: keys[qi] → values[qi] = v1,
                        # then v1 = keys[hop2] → values[hop2]
                        ids = []
                        for k, v in zip(keys, values):
                            ids.extend([k, v])
                        sp = len(ids)
                        answer = values[hop2]
                        ids += [sep, keys[qi], answer]
                        self.items.append(_build_item(ids, sp, seq_len))
                        chain_found = True
                        break
                if chain_found:
                    break

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
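
# NOTE: with uniformly random values, a usable chain is fairly rare, hence
# the attempts cap above; the dataset may end up holding fewer than
# n_samples items, so callers should check len() before sizing epochs.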

# ── Stress tasks for 400-800K models ────────────────────────────

class ConditionalRetrieval(Dataset):
    """Conditional branching:

        [k1, v1, k2, v2, ..., SEP, query_key, threshold] -> value if key > threshold else 0

    Forces the model to compare query_key against the keys AND compare the
    matched key against a threshold: two distinct comparison operations
    chained with a branch.
    Tests: Φ_compare (twice), Φ_memory, conditional logic.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(1, cv - 1) for _ in keys]  # values start at 1; 0 marks the failed branch
            qi = rng.randint(0, len(keys) - 1)
            threshold = rng.randint(0, cv - 1)
            answer = values[qi] if keys[qi] > threshold else 0
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi], threshold, answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class SetIntersection(Dataset):
    """Two sets A and B, output the sorted intersection.

        [a1, a2, ..., aN, SEP1, b1, b2, ..., bM, SEP2, sorted intersection tokens...]

    Forces the model to: hold set A in memory, scan B comparing each element,
    collect matches, output them sorted. Multi-output; requires
    compare + memory + ordering.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 set_size: int = 8, overlap: int = 3, seed: int = 42):
        super().__init__()
        sep1 = vocab_size - 1
        sep2 = vocab_size - 2
        rng = random.Random(seed)
        cv = vocab_size - 3
        self.items = []
        for _ in range(n_samples):
            pool = rng.sample(range(cv), min(set_size * 2, cv))
            shared = pool[:overlap]
            a_only = pool[overlap:set_size]
            b_only = pool[set_size:set_size + (set_size - overlap)]
            set_a = shared + a_only
            set_b = shared + b_only
            rng.shuffle(set_a)
            rng.shuffle(set_b)
            intersection = sorted(shared)
            ids = set_a + [sep1] + set_b + [sep2] + intersection
            sp = len(set_a) + 1 + len(set_b)  # loss mask starts after SEP2
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
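
# Concrete SetIntersection layout (illustrative, set_size = 3, overlap = 2):
#   set_a = [7, 2, 9], set_b = [2, 11, 7], shared = {2, 7}
#   ids   = [7, 2, 9, SEP1, 2, 11, 7, SEP2, 2, 7]
# Only the sorted intersection tokens after SEP2 are supervised.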

class ComposeArithmetic(Dataset):
    """Retrieve two values by key, output (v1 + v2) mod M.

        [k1, v1, ..., kN, vN, SEP, query_key1, query_key2, answer]

    Forces: two parallel retrievals + arithmetic composition.
    Tests: Φ_memory (dual retrieval), Φ_compose (modular addition).
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        modulus = cv  # keeps answers inside the valid content-token range
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi1, qi2 = rng.sample(range(len(keys)), 2)
            answer = (values[qi1] + values[qi2]) % modulus
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi1], keys[qi2], answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
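
# Worked example (illustrative): with modulus = cv = 30, values[qi1] = 17
# and values[qi2] = 21, the supervised answer is (17 + 21) % 30 = 8.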

class MultiHopChained(Dataset):
    """3-hop chained retrieval with guaranteed chains (no retry loop).
    Constructs chains deterministically: wires three of the sampled keys
    into k0 → k1 → k2 → answer, then fills the remaining pairs randomly.
    Output = the value at the end of the chain.
    Forces: 3 sequential compose operations, each requiring compare + memory.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 10, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            all_keys = rng.sample(range(cv), min(n_pairs, cv))
            # Wire a guaranteed 3-hop chain: k0 → k1, k1 → k2, k2 → answer
            values = [rng.randint(0, cv - 1) for _ in all_keys]
            values[0] = all_keys[1]             # hop 1: k0 → k1
            values[1] = all_keys[2]             # hop 2: k1 → k2
            values[2] = rng.randint(0, cv - 1)  # hop 3: k2 → answer
            answer = values[2]
            # Shuffle pair order so the chain isn't positionally obvious
            pairs = list(zip(all_keys, values))
            rng.shuffle(pairs)
            ids = []
            for k, v in pairs:
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, all_keys[0], answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
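
# Quick smoke test: a minimal sketch, not part of the training pipeline.
# vocab_size=64, seq_len=32 and the batch size are illustrative numbers
# chosen to satisfy every task's default parameters.
if __name__ == "__main__":
    tasks = [
        CopyTask, ReverseTask, SelectiveRetrieval,
        DistractorRetrieval, NoisyRetrieval, MultiQueryRetrieval, ChainedRetrieval,
        ConditionalRetrieval, SetIntersection, ComposeArithmetic, MultiHopChained,
    ]
    for cls in tasks:
        ds = cls(vocab_size=64, seq_len=32, n_samples=8)
        data = prebatch_dataset(ds, seq_len=32)
        batch = next(iter(TensorBatchIterator(data, batch_size=4, shuffle=True)))
        # ChainedRetrieval may yield fewer than n_samples items (retry cap).
        print(f"{cls.__name__}: {len(ds)} items, "
              f"input_ids {tuple(batch['input_ids'].shape)}")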