# Uploaded by tekkmaven: "Upload tasks.py with huggingface_hub" (commit 6da318d, verified)
"""
Tasks for the Representation Learning Dynamics experiment.
============================================================
Graduated dissimilarity ladder of algorithmic tasks on integers mod p.
All share the same vocabulary but require increasingly different circuits.
Level 0 — Task A: Modular Addition, (a + b) mod p → Fourier circuit
Level 1 — Task B: Modular Subtraction, (a - b) mod p → Same Fourier circuit (sign flip)
Level 2 — Task C: Modular Multiplication, (a * b) mod p → Discrete-log Fourier circuit
Level 3 — Task D: Max (ordered comparison), max(a, b) → Linear/ordinal circuit
Level 4 — Task E: Bitwise XOR, (a XOR b) mod p → Bit-level circuit, no algebraic structure
Literature grounding:
- Nanda et al. 2023: Addition uses 5-frequency Fourier multiplication algorithm
- Chughtai et al. 2023: All cyclic group ops use GCR algorithm
- Yang et al. 2024: Comparison uses linear parallel circuit (not circular)
- Feature Emergence (2311.07568): Max-margin solutions use irreducible representations
"""
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Optional
# Reserved token ids. Ids 0..NUM_SPECIAL-1 are set aside for padding, the
# "=" separator and one operator symbol per task.
PAD_TOKEN = 0
EQ_TOKEN = 1     # "="
PLUS_TOKEN = 2   # "+"
MINUS_TOKEN = 3  # "-"
TIMES_TOKEN = 4  # "×"
MAX_TOKEN = 5    # "max"
XOR_TOKEN = 6    # "⊕"
NUM_SPECIAL = 7  # number of reserved ids listed above

# Prime modulus used when a caller does not supply one.
DEFAULT_P = 97

# Task names, ordered by predicted dissimilarity from addition.
ALL_OPERATIONS = ['add', 'subtract', 'multiply', 'max', 'xor']

# Task name -> operator token id.
OP_TOKENS = dict(
    add=PLUS_TOKEN,
    subtract=MINUS_TOKEN,
    multiply=TIMES_TOKEN,
    max=MAX_TOKEN,
    xor=XOR_TOKEN,
)
class ModularArithmeticDataset(Dataset):
    """
    Binary-operation examples over the integers 0..p-1: op(a, b).

    Every item encodes one equation as the token sequence
    [a, op_token, b, eq_token, c] with c = op(a, b). Labels are masked
    with -100 at the four input positions, so only c is predicted.

    Supported operations:
        'add':      (a + b) mod p   — Fourier circuit
        'subtract': (a - b) mod p   — Fourier circuit (sign flip)
        'multiply': (a * b) mod p   — Discrete-log Fourier circuit
        'max':      max(a, b)       — Linear/ordinal circuit
        'xor':      (a XOR b) mod p — Bit-level circuit

    All p*p operand pairs are enumerated, shuffled with a RandomState
    seeded by `seed`, and cut at int(p*p * train_frac). split='train'
    keeps the part before the cut; any other split value keeps the rest.
    """

    def __init__(self, operation: str = 'add', p: int = DEFAULT_P,
                 split: str = 'train', train_frac: float = 0.5,
                 seed: int = 42):
        self.p = p
        self.operation = operation
        # Raises KeyError for an operation name missing from OP_TOKENS.
        self.op_token = OP_TOKENS[operation]

        # Enumerate the full p x p operand grid, then shuffle it in place.
        # The shuffle depends only on `seed` and the grid size, so the
        # train and test instances built with the same arguments partition
        # the grid without overlap.
        grid = [(lhs, rhs) for lhs in range(p) for rhs in range(p)]
        np.random.RandomState(seed).shuffle(grid)

        cut = int(len(grid) * train_frac)
        # Anything other than 'train' selects the held-out remainder.
        self.pairs = grid[:cut] if split == 'train' else grid[cut:]

    def _compute(self, a: int, b: int) -> int:
        """Return op(a, b) for this dataset's operation.

        Raises:
            ValueError: if self.operation is not a supported name.
        """
        op = self.operation
        if op == 'add':
            return (a + b) % self.p
        if op == 'subtract':
            return (a - b) % self.p
        if op == 'multiply':
            return (a * b) % self.p
        if op == 'max':
            # max of two values in [0, p-1] is already in range; no mod.
            return max(a, b)
        if op == 'xor':
            # a ^ b can exceed p-1, so reduce mod p to stay inside the vocab.
            return (a ^ b) % self.p
        raise ValueError(f"Unknown operation: {self.operation}")

    def __len__(self):
        """Number of operand pairs in this split."""
        return len(self.pairs)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the token sequence and masked labels for pair `idx`."""
        lhs, rhs = self.pairs[idx]
        # Number tokens are shifted by NUM_SPECIAL so they never collide
        # with the reserved special-token ids.
        answer_tok = self._compute(lhs, rhs) + NUM_SPECIAL
        sequence = [
            lhs + NUM_SPECIAL,
            self.op_token,
            rhs + NUM_SPECIAL,
            EQ_TOKEN,
            answer_tok,
        ]
        # Only the last position (the answer) carries a real label.
        masked = [-100] * 4 + [answer_tok]
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'labels': torch.tensor(masked, dtype=torch.long),
        }

    @property
    def vocab_size(self):
        """Total token count: p number tokens plus NUM_SPECIAL reserved ids."""
        return NUM_SPECIAL + self.p
def get_probe_data(dataset: "ModularArithmeticDataset",
                   n_samples: Optional[int] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Extract a fixed probe batch from the front of a dataset.

    Args:
        dataset: Indexable, sized dataset whose items are dicts holding
            'input_ids' and 'labels' tensors, with the answer token as the
            last label entry.
        n_samples: Number of leading items to take. None means the whole
            dataset; values above len(dataset) are clamped; 0 yields an
            empty probe set.

    Returns:
        (input_ids_batch, answer_labels): the stacked input_ids tensor and
        a NumPy array of answers with the NUM_SPECIAL offset removed, for
        representation tracking.

    Raises:
        ValueError: if n_samples is negative.
    """
    total = len(dataset)
    # Compare against None explicitly: `n_samples or total` would treat an
    # explicit 0 as "not given" and silently return the full dataset.
    if n_samples is None:
        n = total
    elif n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")
    else:
        n = min(n_samples, total)

    if n == 0:
        # torch.stack rejects an empty list, so build the empty batch by
        # hand, keeping the per-item shape and dtype when an item exists.
        if total:
            template = dataset[0]['input_ids']
            empty_ids = torch.empty((0, *template.shape), dtype=template.dtype)
        else:
            empty_ids = torch.empty((0, 0), dtype=torch.long)
        return empty_ids, np.array([], dtype=np.int64)

    items = [dataset[i] for i in range(n)]
    input_ids = torch.stack([item['input_ids'] for item in items])
    # The last label position holds the answer token; undo the vocab offset.
    answers = np.array([item['labels'][-1].item() - NUM_SPECIAL for item in items])
    return input_ids, answers
def get_all_dataloaders(p: int = DEFAULT_P,
                        batch_size: int = 512,
                        train_frac: float = 0.5,
                        seed: int = 42) -> Dict:
    """
    Build train and test DataLoaders for every task in ALL_OPERATIONS.

    Args:
        p: Modulus passed to each dataset.
        batch_size: Batch size for every loader.
        train_frac: Fraction of operand pairs assigned to the train split.
        seed: Seed for each dataset's pair shuffle.

    Returns:
        Dict keyed '<operation>_<split>' (e.g. 'add_train', 'add_test')
        mapping to the corresponding DataLoader. Only train loaders
        shuffle; no loader drops the last partial batch.
    """
    def _loader_for(op_name, split_name):
        # One dataset per (task, split); identical p/train_frac/seed keep
        # the train and test splits of a task complementary.
        dataset = ModularArithmeticDataset(
            operation=op_name, p=p, split=split_name,
            train_frac=train_frac, seed=seed
        )
        return DataLoader(
            dataset, batch_size=batch_size,
            shuffle=(split_name == 'train'), drop_last=False
        )

    return {
        f'{op_name}_{split_name}': _loader_for(op_name, split_name)
        for op_name in ALL_OPERATIONS
        for split_name in ('train', 'test')
    }
# Backward compatibility
def get_dataloaders(p=DEFAULT_P, batch_size=512, train_frac=0.5, seed=42):
    """Alias of get_all_dataloaders, kept under the older name."""
    return get_all_dataloaders(
        p=p, batch_size=batch_size, train_frac=train_frac, seed=seed
    )