File size: 5,639 Bytes
ac2814e
 
 
6da318d
 
 
 
 
 
 
 
 
 
 
 
 
 
ac2814e
 
 
 
 
 
 
 
6da318d
ac2814e
6da318d
 
 
 
 
 
 
ac2814e
 
 
 
6da318d
 
 
 
 
 
 
 
 
 
 
 
ac2814e
 
 
6da318d
 
ac2814e
6da318d
 
 
 
 
 
 
ac2814e
 
 
 
 
 
 
6da318d
ac2814e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da318d
 
 
 
 
 
ac2814e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da318d
 
 
 
 
ac2814e
6da318d
ac2814e
 
 
 
 
 
 
 
 
 
6da318d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Tasks for the Representation Learning Dynamics experiment.
============================================================
Graduated dissimilarity ladder of algorithmic tasks on integers mod p.
All share the same vocabulary but require increasingly different circuits.

Level 0 — Task A: Modular Addition      (a + b mod p)   → Fourier circuit
Level 1 — Task B: Modular Subtraction   (a - b mod p)   → Same Fourier circuit (sign flip)
Level 2 — Task C: Modular Multiplication (a * b mod p)   → Discrete-log Fourier circuit
Level 3 — Task D: Max (ordered comparison) max(a, b)     → Linear/ordinal circuit
Level 4 — Task E: Bitwise XOR           (a XOR b mod p)  → Bit-level circuit, no algebraic structure

Literature grounding:
- Nanda et al. 2023: Addition uses 5-frequency Fourier multiplication algorithm
- Chughtai et al. 2023: All cyclic group ops use GCR algorithm
- Yang et al. 2024: Comparison uses linear parallel circuit (not circular)
- Feature Emergence (2311.07568): Max-margin solutions use irreducible representations
"""

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Optional


# Special tokens — padding, "=", and one operator token per operation type.
# Integer operands/results are encoded as ``value + NUM_SPECIAL`` so they never
# collide with these ids.
PAD_TOKEN = 0
EQ_TOKEN = 1     # "=" token
PLUS_TOKEN = 2   # "+" token
MINUS_TOKEN = 3  # "-" token
TIMES_TOKEN = 4  # "×" token
MAX_TOKEN = 5    # "max" token
XOR_TOKEN = 6    # "⊕" token
# Number of reserved ids; derived from the highest special id so it cannot
# drift out of sync with the list above.
NUM_SPECIAL = XOR_TOKEN + 1

# Default prime for modular arithmetic
DEFAULT_P = 97

# Operator token lookup: operation name -> operator token id.
OP_TOKENS = {
    'add': PLUS_TOKEN,
    'subtract': MINUS_TOKEN,
    'multiply': TIMES_TOKEN,
    'max': MAX_TOKEN,
    'xor': XOR_TOKEN,
}

# All operations in order of predicted dissimilarity from addition
ALL_OPERATIONS = ['add', 'subtract', 'multiply', 'max', 'xor']


class ModularArithmeticDataset(Dataset):
    """
    Dataset for modular/integer arithmetic: op(a, b).
    Input sequence: [a, op_token, b, eq_token, c]
    Labels masked for input tokens (only predict c).

    Supported operations:
      'add':      (a + b) mod p     — Fourier circuit
      'subtract': (a - b) mod p     — Fourier circuit (sign flip)
      'multiply': (a * b) mod p     — Discrete-log Fourier circuit
      'max':      max(a, b)         — Linear/ordinal circuit
      'xor':      (a XOR b) mod p   — Bit-level circuit

    Args:
        operation: key of OP_TOKENS selecting the task.
        p: modulus; operands a and b range over [0, p).
        split: 'train' selects the first ``train_frac`` share of the shuffled
            p*p pairs; any other value selects the remaining (held-out) pairs.
        train_frac: fraction of all pairs assigned to the train split,
            in [0, 1].
        seed: seed for the pair shuffle. The shuffle does not depend on
            ``operation``, so the same seed yields the same partition of
            (a, b) pairs for every task.

    Raises:
        KeyError: ``operation`` is not a key of OP_TOKENS.
        ValueError: ``train_frac`` is outside [0, 1].
    """

    def __init__(self, operation: str = 'add', p: int = DEFAULT_P,
                 split: str = 'train', train_frac: float = 0.5,
                 seed: int = 42):
        # Out-of-range fractions would otherwise slice silently: negative
        # values count from the end, and values > 1 leave the test split empty.
        # NaN also fails this check.
        if not 0.0 <= train_frac <= 1.0:
            raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

        self.p = p
        self.operation = operation
        self.op_token = OP_TOKENS[operation]

        # Generate all p*p pairs, then shuffle deterministically. The list +
        # RandomState.shuffle combination is kept as-is so existing seeds
        # reproduce exactly the same split.
        all_pairs = [(a, b) for a in range(p) for b in range(p)]
        rng = np.random.RandomState(seed)
        rng.shuffle(all_pairs)

        n_train = int(len(all_pairs) * train_frac)
        if split == 'train':
            self.pairs = all_pairs[:n_train]
        else:
            self.pairs = all_pairs[n_train:]

    def _compute(self, a: int, b: int) -> int:
        """Return op(a, b) for this dataset's operation; always in [0, p)."""
        if self.operation == 'add':
            return (a + b) % self.p
        elif self.operation == 'subtract':
            return (a - b) % self.p
        elif self.operation == 'multiply':
            return (a * b) % self.p
        elif self.operation == 'max':
            return max(a, b)  # result is in [0, p-1], no mod needed
        elif self.operation == 'xor':
            # a ^ b can exceed p-1 (e.g. up to 127 for p=97); reduce mod p to
            # keep the answer inside the number-token range.
            return (a ^ b) % self.p
        else:
            raise ValueError(f"Unknown operation: {self.operation}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return {'input_ids', 'labels'} as length-5 long tensors for pair ``idx``."""
        a, b = self.pairs[idx]
        c = self._compute(a, b)

        # Offset numbers by NUM_SPECIAL to avoid collision with special tokens
        a_tok = a + NUM_SPECIAL
        b_tok = b + NUM_SPECIAL
        c_tok = c + NUM_SPECIAL

        input_ids = torch.tensor([a_tok, self.op_token, b_tok, EQ_TOKEN, c_tok],
                                 dtype=torch.long)
        # -100 masks every position except the answer token.
        labels = torch.tensor([-100, -100, -100, -100, c_tok], dtype=torch.long)

        return {'input_ids': input_ids, 'labels': labels}

    @property
    def vocab_size(self):
        """Number of token ids: p number tokens plus NUM_SPECIAL reserved ids."""
        return self.p + NUM_SPECIAL


def get_probe_data(dataset: ModularArithmeticDataset,
                   n_samples: Optional[int] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Extract fixed probe data from a dataset.
    Returns (input_ids_batch, answer_labels) for representation tracking.

    Takes the first ``n_samples`` items of ``dataset`` in index order. When
    ``n_samples`` is None, all items are used. Requests larger than the
    dataset are clamped to its length.

    Returns:
        input_ids: ``torch.stack`` of each item's ``input_ids`` (one row per item).
        answers: numpy array of answers, recovered from each item's last
            label position by subtracting NUM_SPECIAL.

    Raises:
        ValueError: ``n_samples`` is negative, or no items would be selected
            (empty dataset or ``n_samples == 0``). ``torch.stack`` cannot
            build a batch from zero items.
    """
    if n_samples is not None and n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")
    # Only None means "use everything". The previous `n_samples or len(...)`
    # also silently treated an explicit 0 as "all".
    n = len(dataset) if n_samples is None else min(n_samples, len(dataset))
    if n == 0:
        raise ValueError("get_probe_data needs at least one sample to build a probe batch")
    items = [dataset[i] for i in range(n)]
    input_ids = torch.stack([item['input_ids'] for item in items])
    # The last label position holds the shifted answer token; undo the shift.
    answers = np.array([item['labels'][-1].item() - NUM_SPECIAL for item in items])
    return input_ids, answers


def get_all_dataloaders(p: int = DEFAULT_P,
                        batch_size: int = 512,
                        train_frac: float = 0.5,
                        seed: int = 42) -> Dict:
    """Build a DataLoader for every (operation, split) combination.

    Keys have the form ``'<op>_<split>'`` for each op in ALL_OPERATIONS (outer
    order) and each split in 'train', 'test' (inner order). Every dataset
    shares ``p``, ``train_frac`` and ``seed``. Only the train loaders shuffle,
    and no loader drops its final partial batch.
    """
    def _loader_for(op_name: str, split_name: str) -> DataLoader:
        # One dataset plus loader for a single task/split pair.
        dataset = ModularArithmeticDataset(
            operation=op_name, p=p, split=split_name,
            train_frac=train_frac, seed=seed,
        )
        is_train = split_name == 'train'
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=is_train, drop_last=False)

    return {
        f'{op_name}_{split_name}': _loader_for(op_name, split_name)
        for op_name in ALL_OPERATIONS
        for split_name in ('train', 'test')
    }


# Legacy entry point kept so older callers keep working.
def get_dataloaders(p=DEFAULT_P, batch_size=512, train_frac=0.5, seed=42):
    """Alias of get_all_dataloaders; forwards every argument unchanged."""
    return get_all_dataloaders(
        p=p,
        batch_size=batch_size,
        train_frac=train_frac,
        seed=seed,
    )