| """ |
| Tasks for the Representation Learning Dynamics experiment. |
| ============================================================ |
| Graduated dissimilarity ladder of algorithmic tasks on integers mod p. |
| All share the same vocabulary but require increasingly different circuits. |
| |
| Level 0 — Task A: Modular Addition (a + b mod p) → Fourier circuit |
| Level 1 — Task B: Modular Subtraction (a - b mod p) → Same Fourier circuit (sign flip) |
| Level 2 — Task C: Modular Multiplication (a * b mod p) → Discrete-log Fourier circuit |
| Level 3 — Task D: Max (ordered comparison) max(a, b) → Linear/ordinal circuit |
| Level 4 — Task E: Bitwise XOR (a XOR b mod p) → Bit-level circuit, no algebraic structure |
| |
| Literature grounding: |
| - Nanda et al. 2023: Addition uses 5-frequency Fourier multiplication algorithm |
| - Chughtai et al. 2023: All cyclic group ops use GCR algorithm |
| - Yang et al. 2024: Comparison uses linear parallel circuit (not circular) |
| - Feature Emergence (2311.07568): Max-margin solutions use irreducible representations |
| """ |
|
|
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset, DataLoader |
| from typing import Tuple, Dict, Optional |
|
|
|
|
| |
# Special (non-numeric) token ids sit at the bottom of the vocabulary.
(
    PAD_TOKEN,
    EQ_TOKEN,
    PLUS_TOKEN,
    MINUS_TOKEN,
    TIMES_TOKEN,
    MAX_TOKEN,
    XOR_TOKEN,
) = range(7)

# Count of special tokens; the dataset encodes integer n as n + NUM_SPECIAL.
NUM_SPECIAL = 7

# Default modulus p used by the dataset and the dataloader factories.
DEFAULT_P = 97

# Operation name -> token id of its operator symbol.
OP_TOKENS = dict(
    add=PLUS_TOKEN,
    subtract=MINUS_TOKEN,
    multiply=TIMES_TOKEN,
    max=MAX_TOKEN,
    xor=XOR_TOKEN,
)

# Every supported operation name, in the order the mapping declares them.
ALL_OPERATIONS = list(OP_TOKENS)
|
|
|
|
class ModularArithmeticDataset(Dataset):
    """
    Binary-operation dataset over the integers 0..p-1.

    Each example is the token sequence [a, op_token, b, eq_token, c] with
    c = op(a, b). Labels are -100 at the four input positions, so only the
    answer token c carries a training target.

    Supported operations (circuit notes as given by the module author):
        'add':      (a + b) mod p   — Fourier circuit
        'subtract': (a - b) mod p   — Fourier circuit (sign flip)
        'multiply': (a * b) mod p   — Discrete-log Fourier circuit
        'max':      max(a, b)       — Linear/ordinal circuit
        'xor':      (a XOR b) mod p — Bit-level circuit

    All p*p ordered pairs are shuffled with a RandomState seeded by `seed`;
    the first int(p*p * train_frac) pairs form the 'train' split and the
    remainder is used for any other `split` value.
    """

    def __init__(self, operation: str = 'add', p: int = DEFAULT_P,
                 split: str = 'train', train_frac: float = 0.5,
                 seed: int = 42):
        self.p = p
        self.operation = operation
        # An operation name missing from OP_TOKENS raises KeyError here.
        self.op_token = OP_TOKENS[operation]

        # Enumerate every ordered pair, then shuffle deterministically so
        # that the same seed always yields the same ordering (and hence
        # complementary train / non-train partitions).
        shuffled = [(a, b) for a in range(p) for b in range(p)]
        np.random.RandomState(seed).shuffle(shuffled)

        cut = int(len(shuffled) * train_frac)
        self.pairs = shuffled[:cut] if split == 'train' else shuffled[cut:]

    def _compute(self, a: int, b: int) -> int:
        """Return op(a, b) for the configured operation.

        Raises ValueError if `self.operation` is not a supported name.
        """
        # Lambdas keep evaluation lazy: only the selected operation runs.
        table = {
            'add': lambda: (a + b) % self.p,
            'subtract': lambda: (a - b) % self.p,
            'multiply': lambda: (a * b) % self.p,
            'max': lambda: max(a, b),
            'xor': lambda: (a ^ b) % self.p,
        }
        thunk = table.get(self.operation)
        if thunk is None:
            raise ValueError(f"Unknown operation: {self.operation}")
        return thunk()

    def __len__(self):
        """Number of (a, b) pairs in this split."""
        return len(self.pairs)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Build the 5-token example and its masked labels for pair `idx`."""
        lhs, rhs = self.pairs[idx]
        # Integers are shifted past the special-token ids.
        answer_tok = self._compute(lhs, rhs) + NUM_SPECIAL

        sequence = [lhs + NUM_SPECIAL, self.op_token,
                    rhs + NUM_SPECIAL, EQ_TOKEN, answer_tok]
        # -100 masks the four input positions (presumably consumed by a loss
        # whose ignore_index is -100 — confirm against the training code).
        targets = [-100] * 4 + [answer_tok]

        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'labels': torch.tensor(targets, dtype=torch.long),
        }

    @property
    def vocab_size(self):
        """Total token count: special tokens plus the p integer tokens."""
        return NUM_SPECIAL + self.p
|
|
|
|
def get_probe_data(dataset: ModularArithmeticDataset,
                   n_samples: Optional[int] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Collect a fixed probe batch from the front of `dataset`.

    Takes the first min(n_samples, len(dataset)) items in dataset order.
    A falsy `n_samples` (None or 0) selects the whole dataset.

    Returns:
        (input_ids_batch, answer_labels): the stacked input sequences and a
        NumPy array of the answers as plain integers (answer token id minus
        NUM_SPECIAL), for representation tracking.
    """
    total = len(dataset)
    count = min(n_samples or total, total)

    sequences = []
    decoded_answers = []
    for index in range(count):
        sample = dataset[index]
        sequences.append(sample['input_ids'])
        # The last label position holds the answer token; undo the offset.
        decoded_answers.append(sample['labels'][-1].item() - NUM_SPECIAL)

    return torch.stack(sequences), np.array(decoded_answers)
|
|
|
|
def get_all_dataloaders(p: int = DEFAULT_P,
                        batch_size: int = 512,
                        train_frac: float = 0.5,
                        seed: int = 42) -> Dict:
    """Build train and test dataloaders for every operation in ALL_OPERATIONS.

    Returns a dict keyed '<operation>_<split>' (e.g. 'add_train'). Train
    loaders shuffle; test loaders do not. Both splits of an operation use
    the same seed and train_frac, and the last partial batch is kept.
    """
    def _make_loader(op_name: str, split_name: str) -> DataLoader:
        # One dataset + loader per (operation, split) combination.
        dataset = ModularArithmeticDataset(
            operation=op_name, p=p, split=split_name,
            train_frac=train_frac, seed=seed,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split_name == 'train'),
            drop_last=False,
        )

    return {
        f'{op_name}_{split_name}': _make_loader(op_name, split_name)
        for op_name in ALL_OPERATIONS
        for split_name in ('train', 'test')
    }
|
|
|
|
| |
def get_dataloaders(p=DEFAULT_P, batch_size=512, train_frac=0.5, seed=42):
    """Thin wrapper: forward all arguments to get_all_dataloaders unchanged."""
    return get_all_dataloaders(
        p=p,
        batch_size=batch_size,
        train_frac=train_frac,
        seed=seed,
    )
|
|