"""
Tasks for the Representation Learning Dynamics experiment.
============================================================

Graduated dissimilarity ladder of algorithmic tasks on integers mod p.
All tasks share the same vocabulary but require increasingly different
circuits:

    Level 0 — Task A: Modular Addition        (a + b) mod p    → Fourier circuit
    Level 1 — Task B: Modular Subtraction     (a - b) mod p    → Same Fourier circuit (sign flip)
    Level 2 — Task C: Modular Multiplication  (a * b) mod p    → Discrete-log Fourier circuit
    Level 3 — Task D: Max (ordered comparison) max(a, b)       → Linear/ordinal circuit
    Level 4 — Task E: Bitwise XOR             (a XOR b) mod p  → Bit-level circuit, no algebraic structure

Literature grounding:
    - Nanda et al. 2023: Addition uses 5-frequency Fourier multiplication algorithm
    - Chughtai et al. 2023: All cyclic group ops use GCR algorithm
    - Yang et al. 2024: Comparison uses linear parallel circuit (not circular)
    - Feature Emergence (2311.07568): Max-margin solutions use irreducible representations
"""

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Optional

# Special tokens — one per operation type
PAD_TOKEN = 0
EQ_TOKEN = 1     # "=" token
PLUS_TOKEN = 2   # "+" token
MINUS_TOKEN = 3  # "-" token
TIMES_TOKEN = 4  # "×" token
MAX_TOKEN = 5    # "max" token
XOR_TOKEN = 6    # "⊕" token
NUM_SPECIAL = 7

# Default prime for modular arithmetic
DEFAULT_P = 97

# Operator token lookup
OP_TOKENS = {
    'add': PLUS_TOKEN,
    'subtract': MINUS_TOKEN,
    'multiply': TIMES_TOKEN,
    'max': MAX_TOKEN,
    'xor': XOR_TOKEN,
}

# All operations in order of predicted dissimilarity from addition
ALL_OPERATIONS = ['add', 'subtract', 'multiply', 'max', 'xor']


class ModularArithmeticDataset(Dataset):
    """
    Dataset for modular/integer arithmetic: op(a, b).

    Input sequence: [a, op_token, b, eq_token, c]
    Labels are masked (-100) for the input tokens, so only c is predicted.
    Number tokens are offset by NUM_SPECIAL so they never collide with the
    special tokens.

    Supported operations:
        'add':      (a + b) mod p    — Fourier circuit
        'subtract': (a - b) mod p    — Fourier circuit (sign flip)
        'multiply': (a * b) mod p    — Discrete-log Fourier circuit
        'max':      max(a, b)        — Linear/ordinal circuit
        'xor':      (a XOR b) mod p  — Bit-level circuit

    Args:
        operation: One of the keys of OP_TOKENS (see ALL_OPERATIONS).
        p: Modulus; operands a and b each range over [0, p).
        split: 'train' selects the first ``train_frac`` of the shuffled
            pairs; any other value selects the remaining (held-out) pairs.
        train_frac: Fraction of the p*p pairs assigned to the train split.
        seed: Seed for the pair shuffle. The shuffle depends only on ``p``
            and ``seed`` (not on ``operation``), so datasets built with the
            same p/seed/train_frac use the identical train/test pairs for
            every operation.

    Raises:
        KeyError: If ``operation`` is not a key of OP_TOKENS.
    """

    def __init__(self, operation: str = 'add', p: int = DEFAULT_P,
                 split: str = 'train', train_frac: float = 0.5,
                 seed: int = 42):
        try:
            op_token = OP_TOKENS[operation]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that lists the valid choices.
            raise KeyError(
                f"Unknown operation {operation!r}; expected one of {ALL_OPERATIONS}"
            ) from None

        self.p = p
        self.operation = operation
        self.op_token = op_token

        # Generate all p*p pairs and shuffle them deterministically. Keep the
        # list + RandomState.shuffle combination: a different shuffling scheme
        # would assign different pairs to each split for the same seed.
        all_pairs = [(a, b) for a in range(p) for b in range(p)]
        rng = np.random.RandomState(seed)
        rng.shuffle(all_pairs)

        n_train = int(len(all_pairs) * train_frac)
        if split == 'train':
            self.pairs = all_pairs[:n_train]
        else:
            self.pairs = all_pairs[n_train:]

    def _compute(self, a: int, b: int) -> int:
        """Return op(a, b) for this dataset's operation.

        For operands in [0, p) the result is also in [0, p).

        Raises:
            ValueError: If ``self.operation`` is not a supported operation.
        """
        if self.operation == 'add':
            return (a + b) % self.p
        elif self.operation == 'subtract':
            return (a - b) % self.p
        elif self.operation == 'multiply':
            return (a * b) % self.p
        elif self.operation == 'max':
            return max(a, b)  # result is in [0, p-1], no mod needed
        elif self.operation == 'xor':
            # a ^ b can exceed p - 1 when p is not a power of two; mod p folds
            # those values back into the number-token range.
            return (a ^ b) % self.p
        else:
            raise ValueError(f"Unknown operation: {self.operation}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return {'input_ids', 'labels'} for the pair at ``idx``.

        Both are 5-element LongTensors; labels hold -100 at the four input
        positions and the answer token at the last position.
        """
        a, b = self.pairs[idx]
        c = self._compute(a, b)

        # Offset numbers by NUM_SPECIAL to avoid collision with special tokens
        a_tok = a + NUM_SPECIAL
        b_tok = b + NUM_SPECIAL
        c_tok = c + NUM_SPECIAL

        input_ids = torch.tensor([a_tok, self.op_token, b_tok, EQ_TOKEN, c_tok],
                                 dtype=torch.long)
        labels = torch.tensor([-100, -100, -100, -100, c_tok], dtype=torch.long)
        return {'input_ids': input_ids, 'labels': labels}

    @property
    def vocab_size(self):
        """Total token count: NUM_SPECIAL special tokens plus p number tokens."""
        return self.p + NUM_SPECIAL


def get_probe_data(dataset: ModularArithmeticDataset,
                   n_samples: Optional[int] = None
                   ) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Extract fixed probe data from a dataset.

    Takes the first ``n_samples`` items in dataset order, for representation
    tracking.

    Args:
        dataset: Source dataset.
        n_samples: Number of items to take. ``None`` means the whole dataset;
            values larger than ``len(dataset)`` are clamped to it.

    Returns:
        (input_ids_batch, answer_labels): a LongTensor of shape (n, 5) with the
        token sequences, and an int array of the n answers c as raw numbers
        (NUM_SPECIAL offset removed).

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples is None:
        n = len(dataset)
    elif n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")
    else:
        # Explicit None check above: `n_samples or len(dataset)` would treat
        # a request for 0 samples as "the whole dataset".
        n = min(n_samples, len(dataset))

    if n == 0:
        # torch.stack rejects an empty list; return correctly-shaped empties
        # (sequence layout is [a, op, b, =, c], i.e. 5 tokens).
        return torch.empty((0, 5), dtype=torch.long), np.array([], dtype=int)

    items = [dataset[i] for i in range(n)]
    input_ids = torch.stack([item['input_ids'] for item in items])
    answers = np.array(
        [item['labels'][-1].item() - NUM_SPECIAL for item in items], dtype=int
    )
    return input_ids, answers


def get_all_dataloaders(p: int = DEFAULT_P, batch_size: int = 512,
                        train_frac: float = 0.5, seed: int = 42) -> Dict:
    """Get train/test dataloaders for ALL tasks.

    Returns a dict keyed ``f'{op}_{split}'`` (e.g. ``'add_train'``,
    ``'xor_test'``) for every op in ALL_OPERATIONS. Train loaders shuffle;
    test loaders do not. All operations share the same p/seed, hence the
    same split of (a, b) pairs.
    """
    loaders = {}
    for op in ALL_OPERATIONS:
        for split in ['train', 'test']:
            ds = ModularArithmeticDataset(
                operation=op, p=p, split=split, train_frac=train_frac, seed=seed
            )
            loaders[f'{op}_{split}'] = DataLoader(
                ds, batch_size=batch_size, shuffle=(split == 'train'),
                drop_last=False
            )
    return loaders


# Backward compatibility
def get_dataloaders(p=DEFAULT_P, batch_size=512, train_frac=0.5, seed=42):
    """Backward-compatible alias for get_all_dataloaders."""
    return get_all_dataloaders(p, batch_size, train_frac, seed)