"""Synthetic data for FOG ablation — algorithmic tasks.
Easy: CopyTask, ReverseTask, SelectiveRetrieval
Hard: DistractorRetrieval, NoisyRetrieval, MultiQueryRetrieval, ChainedRetrieval
"""
from __future__ import annotations
import random
import torch
from torch.utils.data import Dataset

def _build_item(ids: list[int], sep_pos: int, seq_len: int) -> dict[str, torch.Tensor]:
    """Shared helper: pad/truncate to seq_len, build shifted input/target plus loss mask."""
    real_len = len(ids)
    ids = ids[:seq_len]
    ids += [0] * (seq_len - len(ids))
    x = torch.tensor(ids[:-1], dtype=torch.long)
    y = torch.tensor(ids[1:], dtype=torch.long)
    m = torch.zeros_like(y)
    # Supervise only real tokens after SEP; padding stays unsupervised.
    end = min(real_len - 1, len(m))  # -1 because targets are shifted by one
    if sep_pos < end:
        m[sep_pos:end] = 1
    return {"input_ids": x, "targets": y, "loss_mask": m}

def prebatch_dataset(dataset: Dataset, seq_len: int) -> dict[str, torch.Tensor]:
    """Pre-stack the entire dataset into contiguous tensors for fast batching."""
    n = len(dataset)
    all_x = torch.zeros(n, seq_len - 1, dtype=torch.long)
    all_y = torch.zeros(n, seq_len - 1, dtype=torch.long)
    all_m = torch.zeros(n, seq_len - 1, dtype=torch.long)
    for i in range(n):
        item = dataset[i]
        L = item["input_ids"].size(0)
        all_x[i, :L] = item["input_ids"]
        all_y[i, :L] = item["targets"]
        all_m[i, :L] = item["loss_mask"]
    return {"input_ids": all_x, "targets": all_y, "loss_mask": all_m}

class TensorBatchIterator:
    """Fast batch iterator over pre-stacked tensors. No DataLoader overhead."""

    def __init__(self, data: dict[str, torch.Tensor], batch_size: int, shuffle: bool = False):
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n = data["input_ids"].size(0)

    def __iter__(self):
        perm = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for start in range(0, self.n, self.batch_size):
            idx = perm[start : start + self.batch_size]
            yield {k: v[idx] for k, v in self.data.items()}

    def __len__(self) -> int:
        return (self.n + self.batch_size - 1) // self.batch_size
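
# End-to-end wiring sketch (illustrative hyperparameters; `model` stands in
# for whatever network consumes these batches; it is not defined here):
#
#     ds = CopyTask(vocab_size=32, seq_len=32, n_samples=10_000)
#     tensors = prebatch_dataset(ds, seq_len=32)
#     for batch in TensorBatchIterator(tensors, batch_size=256, shuffle=True):
#         logits = model(batch["input_ids"])  # (B, seq_len - 1, vocab)
#         # cross-entropy over the positions where loss_mask == 1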
# ── Easy tasks ──────────────────────────────────────────────────
class CopyTask(Dataset):
    """Copy: [a, b, c, SEP] -> [a, b, c]. Tests memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 1
        half = seq_len // 2 - 1
        self.items = []
        for _ in range(n_samples):
            c = [rng.randint(0, cv - 1) for _ in range(half)]
            ids = c + [sep] + c
            self.items.append(_build_item(ids, len(c), seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class ReverseTask(Dataset):
    """Reverse: [a, b, c, SEP] -> [c, b, a]. Tests compare + memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 1
        half = seq_len // 2 - 1
        self.items = []
        for _ in range(n_samples):
            c = [rng.randint(0, cv - 1) for _ in range(half)]
            ids = c + [sep] + list(reversed(c))
            self.items.append(_build_item(ids, len(c), seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]

class SelectiveRetrieval(Dataset):
    """[k1, v1, k2, v2, SEP, query_key] -> answer_value. Tests compare + select + memory."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = rng.randint(0, len(keys) - 1)
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
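
# Illustrative instance (vocab_size=16, so SEP=15): pairs (3, 8) and (6, 1)
# with query key 6 give ids = [3, 8, 6, 1, SEP, 6, 1]. The loss region covers
# the tokens after SEP: the echoed query (unpredictable; a constant loss
# floor) and the answer value 1.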
# ── Hard tasks ──────────────────────────────────────────────────

class DistractorRetrieval(Dataset):
    """Distractor keys lie within ±n_pairs of the query key. Forces precise compare."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int, n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            qk = rng.randint(n_pairs, cv - n_pairs - 1)
            offsets = [i for i in range(-n_pairs, n_pairs + 1) if i != 0]
            dk = [qk + o for o in offsets if 0 <= qk + o < cv]
            rng.shuffle(dk)  # mix near and far offsets; without this the slice below kept only the lowest ones
            keys = [qk] + dk[:n_pairs - 1]
            rng.shuffle(keys)
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = keys.index(qk)
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, qk, values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
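
# Illustrative instance: query key 7 with distractors drawn from {3, 5, 6, 8,
# 9, 11}; every candidate key sits within n_pairs of the query, so a fuzzy
# nearest-key match picks a wrong value and the compare must be exact.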

class NoisyRetrieval(Dataset):
    """Noise tokens between KV pairs. Forces select to filter, memory to retrieve."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 3, noise_len: int = 2, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi = rng.randint(0, len(keys) - 1)
            ids = []
            for i, (k, v) in enumerate(zip(keys, values)):
                ids.extend([k, v])
                if i < len(keys) - 1:
                    ids.extend([rng.randint(0, cv - 1) for _ in range(noise_len)])
            sp = len(ids)
            ids += [sep, keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
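
# Illustrative layout with n_pairs=3, noise_len=2 (n = random noise token):
#     [k1, v1, n, n, k2, v2, n, n, k3, v3, SEP, k2, v2]
# Noise is drawn from the same range as keys and values, so positional
# shortcuts ("a value always follows its key at a fixed stride") break down.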

class MultiQueryRetrieval(Dataset):
    """Two sequential queries; retrieve two values. Tests compose."""

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 4, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qis = rng.sample(range(len(keys)), min(2, len(keys)))
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids.append(sep)
            for qi in qis:
                ids += [keys[qi], values[qi]]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
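
# Illustrative instance: pairs (2, 9), (5, 4), (8, 6) with queries 5 and 8
# give a tail of [SEP, 5, 4, 8, 6]; the model answers the first query, then
# the second, inside a single continuation.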

class ChainedRetrieval(Dataset):
    """Two-hop lookup: find the value for the query key, then use that value as the key for a second lookup.

    [k1,v1, k2,v2, ..., kN,vN, SEP, query_key, final_answer]

    The model must:
      1. Compare query_key against all keys → find the matching value (Φ_compare + Φ_memory)
      2. Use that value as a new key → find its value (Φ_compose + Φ_memory)
      3. Output the final value

    This is compositional: uniform models with shared compare/memory
    struggle when capacity is tight, while motif-aware models with
    dedicated compare (narrow) and memory (wide) subspaces can separate
    the two lookups.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        # Rejection sampling: random pairs rarely contain a valid chain, so cap
        # the attempts; the dataset may come up slightly short of n_samples.
        attempts = 0
        while len(self.items) < n_samples and attempts < n_samples * 20:
            attempts += 1
            if cv < n_pairs:
                break
            keys = rng.sample(range(cv), n_pairs)
            values = [rng.randint(0, cv - 1) for _ in keys]
            # A valid chain needs query_key → value_1, where value_1 reappears
            # as a key in a different pair → value_2.
            chain_found = False
            for qi in range(n_pairs):
                v1 = values[qi]
                for hop2 in range(n_pairs):
                    if hop2 != qi and keys[hop2] == v1:
                        # Chain: keys[qi] → values[qi] = v1, then v1 = keys[hop2] → values[hop2]
                        ids = []
                        for k, v in zip(keys, values):
                            ids.extend([k, v])
                        sp = len(ids)
                        answer = values[hop2]
                        ids += [sep, keys[qi], answer]
                        self.items.append(_build_item(ids, sp, seq_len))
                        chain_found = True
                        break
                if chain_found:
                    break

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
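
# Illustrative chain: pairs (4, 9), (9, 2), (7, 5) with query 4.
# Hop 1: 4 → 9. Hop 2: 9 reappears as a key → 2. The tail is [SEP, 4, 2],
# so the supervised answer is the second-hop value, never the first.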
# ── Stress tasks for 400-800K models ────────────────────────────
class ConditionalRetrieval(Dataset):
    """Conditional branching: [k1,v1,k2,v2,..., SEP, query_key, threshold] -> value if key > threshold else 0.

    Forces the model to compare query_key against keys AND compare the matched key
    against a threshold — two distinct comparison operations chained with a branch.
    Tests: Φ_compare (double), Φ_memory, conditional logic.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            # Values start at 1 so the "condition failed" answer 0 is unambiguous.
            values = [rng.randint(1, cv - 1) for _ in keys]
            qi = rng.randint(0, len(keys) - 1)
            threshold = rng.randint(0, cv - 1)
            answer = values[qi] if keys[qi] > threshold else 0
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi], threshold, answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
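
# Illustrative instance: queried pair (9, 4) with threshold 6 ends in
# [SEP, 9, 6, 4] because 9 > 6; with threshold 12 it would end in
# [SEP, 9, 12, 0], the reserved "condition failed" token.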

class SetIntersection(Dataset):
    """Two sets A and B; output their sorted intersection.

    [a1,a2,...,aN, SEP1, b1,b2,...,bM, SEP2, sorted intersection tokens...]

    Forces the model to hold set A in memory, scan B comparing each element,
    collect the matches, and output them sorted. Multi-output; requires
    compare + memory + ordering.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 set_size: int = 8, overlap: int = 3, seed: int = 42):
        super().__init__()
        sep1 = vocab_size - 1
        sep2 = vocab_size - 2
        rng = random.Random(seed)
        cv = vocab_size - 3
        self.items = []
        for _ in range(n_samples):
            pool = rng.sample(range(cv), min(set_size * 2, cv))
            shared = pool[:overlap]
            a_only = pool[overlap:set_size]
            b_only = pool[set_size:set_size + (set_size - overlap)]
            set_a = shared + a_only
            set_b = shared + b_only
            rng.shuffle(set_a)
            rng.shuffle(set_b)
            intersection = sorted(shared)
            ids = set_a + [sep1] + set_b + [sep2] + intersection
            sp = len(set_a) + 1 + len(set_b)  # mask starts after SEP2
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
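
# Illustrative instance (set_size=3, overlap=2): A = [3, 7, 1], B = [7, 9, 3]
# give ids = [3, 7, 1, SEP1, 7, 9, 3, SEP2, 3, 7]; the intersection {3, 7}
# is emitted in sorted order, so the output order is independent of the
# (shuffled) input order.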

class ComposeArithmetic(Dataset):
    """Retrieve two values by key and output (v1 + v2) mod M.

    [k1,v1,...,kN,vN, SEP, query_key1, query_key2, answer]

    Forces two parallel retrievals plus an arithmetic composition.
    Tests: Φ_memory (dual retrieval), Φ_compose (modular addition).
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 6, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        modulus = cv  # keeps answers inside the valid token range
        self.items = []
        for _ in range(n_samples):
            keys = rng.sample(range(cv), min(n_pairs, cv))
            values = [rng.randint(0, cv - 1) for _ in keys]
            qi1, qi2 = rng.sample(range(len(keys)), 2)
            answer = (values[qi1] + values[qi2]) % modulus
            ids = []
            for k, v in zip(keys, values):
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, keys[qi1], keys[qi2], answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
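
# Illustrative instance with modulus cv = 10: pairs (2, 7) and (5, 6) queried
# as (2, 5) end in [SEP, 2, 5, 3], because (7 + 6) % 10 = 3.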

class MultiHopChained(Dataset):
    """3-hop chained retrieval with guaranteed chains (no retry loop).

    Constructs chains deterministically: wires k0→k1→k2→answer among the
    sampled keys and fills the remaining pairs randomly. Output = value of
    the final hop.
    Forces three sequential compose operations, each requiring compare + memory.
    """

    def __init__(self, vocab_size: int, seq_len: int, n_samples: int,
                 n_pairs: int = 10, seed: int = 42):
        super().__init__()
        sep = vocab_size - 1
        rng = random.Random(seed)
        cv = vocab_size - 2
        self.items = []
        for _ in range(n_samples):
            all_keys = rng.sample(range(cv), min(n_pairs, cv))
            # Wire a guaranteed 3-hop chain: k0 → k1, k1 → k2, k2 → final answer
            values = [rng.randint(0, cv - 1) for _ in all_keys]
            values[0] = all_keys[1]             # hop 1: k0 → k1
            values[1] = all_keys[2]             # hop 2: k1 → k2
            values[2] = rng.randint(0, cv - 1)  # hop 3: k2 → answer
            answer = values[2]
            # Shuffle pair order so the chain isn't positionally obvious
            pairs = list(zip(all_keys, values))
            rng.shuffle(pairs)
            ids = []
            for k, v in pairs:
                ids.extend([k, v])
            sp = len(ids)
            ids += [sep, all_keys[0], answer]
            self.items.append(_build_item(ids, sp, seq_len))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.items[idx]
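
# Illustrative chain: keys wired as 4 → 9 → 2 → answer 6 and pairs shuffled,
# so the tail is [SEP, 4, 6]: three lookups from a single query key.

# Minimal smoke test (a sketch for manual runs; not part of the module API).
if __name__ == "__main__":
    _ds = MultiHopChained(vocab_size=64, seq_len=64, n_samples=4)
    _batch = prebatch_dataset(_ds, seq_len=64)
    assert _batch["input_ids"].shape == (4, 63)
    assert _batch["loss_mask"].sum() > 0  # some positions are supervised
    print("ok:", {k: tuple(v.shape) for k, v in _batch.items()})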