"""
Tasks for the Representation Learning Dynamics experiment.
============================================================
Graduated dissimilarity ladder of algorithmic tasks on integers mod p.
All share the same vocabulary but require increasingly different circuits.
Level 0 — Task A: Modular Addition (a + b mod p) → Fourier circuit
Level 1 — Task B: Modular Subtraction (a - b mod p) → Same Fourier circuit (sign flip)
Level 2 — Task C: Modular Multiplication (a * b mod p) → Discrete-log Fourier circuit
Level 3 — Task D: Max (ordered comparison) max(a, b) → Linear/ordinal circuit
Level 4 — Task E: Bitwise XOR (a XOR b mod p) → Bit-level circuit, no algebraic structure
Literature grounding:
- Nanda et al. 2023: Addition uses 5-frequency Fourier multiplication algorithm
- Chughtai et al. 2023: All cyclic group ops use GCR algorithm
- Yang et al. 2024: Comparison uses linear parallel circuit (not circular)
- Feature Emergence (2311.07568): Max-margin solutions use irreducible representations
"""
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Optional
# Special tokens — one per operation type. Integer operands and results are
# shifted up by NUM_SPECIAL when tokenized, so they never collide with these ids.
PAD_TOKEN = 0
EQ_TOKEN = 1     # "="
PLUS_TOKEN = 2   # "+"
MINUS_TOKEN = 3  # "-"
TIMES_TOKEN = 4  # "×"
MAX_TOKEN = 5    # "max"
XOR_TOKEN = 6    # "⊕"
NUM_SPECIAL = 7  # number of reserved ids above (0..6)

# Default prime modulus for the arithmetic tasks.
DEFAULT_P = 97

# Operation name -> operator token id placed between the two operands.
OP_TOKENS = dict(
    add=PLUS_TOKEN,
    subtract=MINUS_TOKEN,
    multiply=TIMES_TOKEN,
    max=MAX_TOKEN,
    xor=XOR_TOKEN,
)

# Every operation, ordered by predicted dissimilarity from addition.
# This is exactly the insertion order of OP_TOKENS above.
ALL_OPERATIONS = list(OP_TOKENS)
class ModularArithmeticDataset(Dataset):
    """
    Binary-operation dataset over the integers 0..p-1: each example is op(a, b).

    Every (a, b) pair in range(p) x range(p) is enumerated once, shuffled with
    a ``np.random.RandomState(seed)``, and cut at ``train_frac``: the 'train'
    split keeps the leading fraction; any other ``split`` value keeps the rest.
    The shuffle depends only on ``p`` and ``seed``, so every operation shares
    the same train/test partition of pairs for a given seed.

    Each item is the token sequence ``[a, op_token, b, eq_token, c]`` (numbers
    shifted by NUM_SPECIAL). Labels are -100 everywhere except the final
    position, so only the answer c is predicted.

    Supported operations:
        'add':      (a + b) mod p   → Fourier circuit
        'subtract': (a - b) mod p   → Fourier circuit (sign flip)
        'multiply': (a * b) mod p   → Discrete-log Fourier circuit
        'max':      max(a, b)       → Linear/ordinal circuit
        'xor':      (a XOR b) mod p → Bit-level circuit
    """

    def __init__(self, operation: str = 'add', p: int = DEFAULT_P,
                 split: str = 'train', train_frac: float = 0.5,
                 seed: int = 42):
        self.p = p
        self.operation = operation
        # An unknown operation name fails here with KeyError.
        self.op_token = OP_TOKENS[operation]

        # Row-major enumeration of the full p*p grid, then a seeded in-place
        # shuffle so the split is reproducible.
        grid = [(x, y) for x in range(p) for y in range(p)]
        np.random.RandomState(seed).shuffle(grid)

        cut = int(len(grid) * train_frac)
        self.pairs = grid[:cut] if split == 'train' else grid[cut:]

    def _compute(self, a: int, b: int) -> int:
        """Return op(a, b) for this dataset's operation.

        Raises:
            ValueError: if ``self.operation`` is not a supported name.
        """
        evaluators = {
            'add': lambda: (a + b) % self.p,
            'subtract': lambda: (a - b) % self.p,
            'multiply': lambda: (a * b) % self.p,
            # Both operands already lie in [0, p-1]; no reduction needed.
            'max': lambda: max(a, b),
            # a ^ b can exceed p - 1 (e.g. 96 ^ 1 == 97 for p = 97), so it is
            # reduced mod p to stay inside the vocabulary.
            'xor': lambda: (a ^ b) % self.p,
        }
        evaluate = evaluators.get(self.operation)
        if evaluate is None:
            raise ValueError(f"Unknown operation: {self.operation}")
        return evaluate()

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return ``{'input_ids': ..., 'labels': ...}``, two length-5 long tensors."""
        a, b = self.pairs[idx]
        # Shift operands and result past the reserved special-token ids.
        a_tok, b_tok, c_tok = (v + NUM_SPECIAL for v in (a, b, self._compute(a, b)))
        sequence = [a_tok, self.op_token, b_tok, EQ_TOKEN, c_tok]
        # Only the final (answer) position is supervised.
        targets = [-100] * (len(sequence) - 1) + [c_tok]
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'labels': torch.tensor(targets, dtype=torch.long),
        }

    @property
    def vocab_size(self):
        """Total number of token ids: NUM_SPECIAL reserved ids plus p numbers."""
        return NUM_SPECIAL + self.p
def get_probe_data(dataset: ModularArithmeticDataset,
                   n_samples: Optional[int] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Extract a fixed probe batch from the start of a dataset.

    Items are taken in index order (``dataset[0] .. dataset[n-1]``), so the
    probe set is the same on every call for a given dataset. This is useful
    for representation tracking.

    Args:
        dataset: source dataset; each item must provide ``'input_ids'`` and
            ``'labels'`` tensors, with the answer token at ``labels[-1]``.
        n_samples: number of leading items to take, or None for all of them.
            The count is clamped to ``[0, len(dataset)]``.

    Returns:
        ``(input_ids, answers)``, where:
        - ``input_ids`` is the stacked ``input_ids`` tensors, with shape
          ``(n, seq_len)``.
        - ``answers`` is an int64 array of the raw answers (last label minus
          NUM_SPECIAL), with shape ``(n,)``.
        When n is 0 both are empty (``input_ids`` has shape ``(0, 0)``)
        instead of raising.
    """
    total = len(dataset)
    # Explicit None check: n_samples=0 must mean "no samples", not "all"
    # (the previous `n_samples or len(dataset)` conflated the two).
    n = total if n_samples is None else max(0, min(n_samples, total))
    items = [dataset[i] for i in range(n)]
    if not items:
        # torch.stack rejects an empty sequence, so build empty outputs directly.
        return (torch.empty((0, 0), dtype=torch.long),
                np.empty(0, dtype=np.int64))
    input_ids = torch.stack([item['input_ids'] for item in items])
    # Fixed int64 dtype so the result does not depend on the platform default int.
    answers = np.array(
        [item['labels'][-1].item() - NUM_SPECIAL for item in items],
        dtype=np.int64,
    )
    return input_ids, answers
def get_all_dataloaders(p: int = DEFAULT_P,
                        batch_size: int = 512,
                        train_frac: float = 0.5,
                        seed: int = 42) -> Dict:
    """
    Build train/test DataLoaders for every operation in ALL_OPERATIONS.

    Returns a dict keyed ``'<op>_<split>'`` (e.g. ``'add_train'``,
    ``'add_test'``). Train loaders shuffle and test loaders do not; no loader
    drops its last partial batch.
    """
    def _loader_for(op, split):
        # One dataset per (operation, split), sharing p / train_frac / seed.
        dataset = ModularArithmeticDataset(
            operation=op, p=p, split=split,
            train_frac=train_frac, seed=seed,
        )
        return DataLoader(
            dataset, batch_size=batch_size,
            shuffle=(split == 'train'), drop_last=False,
        )

    return {
        f'{op}_{split}': _loader_for(op, split)
        for op in ALL_OPERATIONS
        for split in ('train', 'test')
    }
# Backward-compatible alias kept for older callers: forwards every argument
# unchanged to get_all_dataloaders.
def get_dataloaders(p=DEFAULT_P, batch_size=512, train_frac=0.5, seed=42):
    """Alias of get_all_dataloaders with the same arguments and result."""
    return get_all_dataloaders(p=p, batch_size=batch_size,
                               train_frac=train_frac, seed=seed)
|