amirali1985 commited on
Commit
641998d
·
verified ·
1 Parent(s): 241a3ad

Upload modular/code/fourier_analysis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modular/code/fourier_analysis.py +391 -0
modular/code/fourier_analysis.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fourier analysis of SoRL abstract tokens on modular arithmetic.
3
+
4
+ Tests Nanda's hypothesis: do abstract tokens encode Fourier components of (a+b) mod p?
5
+
6
+ Two analyses:
7
+ 1. Assignment analysis — for each (a,b) pair, which abstract token does the model assign?
8
+ Does the assignment function cluster by (a+b) mod p?
9
+ 2. Embedding analysis — do abstract token embeddings organize along sin/cos curves
10
+ in Fourier frequency space?
11
+
12
+ Usage:
13
+ python -m arithmetic.modular.experiments.11_fourier_analysis.run \
14
+ --model_dir arithmetic/runs/mod_sorl_fourier/final \
15
+ --out_dir arithmetic/modular/experiments/11_fourier_analysis/results
16
+ """
17
+ import sys, json, argparse
18
+ from pathlib import Path
19
+ sys.path.insert(0, str(Path(__file__).resolve().parents[5]))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import matplotlib
24
+ matplotlib.use("Agg")
25
+ import matplotlib.pyplot as plt
26
+ from matplotlib.colors import Normalize
27
+
28
+ from sorl.sorl_wrapper import SorlModelWrapper
29
+ from sorl.sorl_trainer import sorl_search
30
+ from arithmetic.modular.data.modular import (
31
+ generate_dataset, P, VOCAB_SIZE, PAD, PROMPT_LEN,
32
+ )
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Load model
37
+ # ---------------------------------------------------------------------------
38
+
39
+ def load_model(model_dir: str, device: str) -> SorlModelWrapper:
40
+ model_dir = Path(model_dir)
41
+ with open(model_dir / "sorl_config.json") as f:
42
+ cfg = json.load(f)
43
+ from transformers import Qwen3Config
44
+ config = Qwen3Config(
45
+ hidden_size=cfg["n_embd"], num_hidden_layers=cfg["n_layer"],
46
+ num_attention_heads=cfg["n_head"], num_key_value_heads=cfg["n_head"],
47
+ intermediate_size=cfg["d_mlp"], vocab_size=VOCAB_SIZE,
48
+ max_position_embeddings=32,
49
+ )
50
+ model = SorlModelWrapper.from_scratch(config, [VOCAB_SIZE, cfg["abs_vocab"]], PAD)
51
+ model.load_state_dict(torch.load(model_dir / "model_state_dict.pt", map_location="cpu"))
52
+ return model.to(device).eval()
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Extract abstract token assignments for all (a, b) pairs
57
+ # ---------------------------------------------------------------------------
58
+
59
+ @torch.no_grad()
60
+ def get_assignments(model, all_examples, K: int, device: str, batch_size: int = 256):
61
+ """
62
+ For every (a, b) pair return the abstract token IDs assigned at each abstract position.
63
+
64
+ Returns:
65
+ assignments: np.ndarray shape (N, n_abs_positions) — token IDs
66
+ sums: np.ndarray shape (N,) — (a+b) mod p
67
+ pairs: list of (a, b) tuples
68
+ """
69
+ base_v = int(model.vocab_sizes[0].item())
70
+ all_assignments = []
71
+ all_sums = []
72
+ all_pairs = []
73
+
74
+ for start in range(0, len(all_examples), batch_size):
75
+ batch = all_examples[start:start + batch_size]
76
+ ids = torch.tensor([e.tokens for e in batch], dtype=torch.long, device=device)
77
+ attn = torch.ones_like(ids)
78
+ pl = torch.full((ids.shape[0],), PROMPT_LEN, dtype=torch.long, device=device)
79
+
80
+ best_data, _, _, _, _ = sorl_search(
81
+ model, ids, attn, pl, PAD,
82
+ n=1, K=K, max_iterations=2,
83
+ memory_span_abs=512, memory_span_traj=512,
84
+ temperature=0.0,
85
+ )
86
+
87
+ # Abstract positions: tokens >= base_v
88
+ for i, ex in enumerate(batch):
89
+ seq = best_data[i].cpu().tolist()
90
+ abs_tokens = [t - base_v for t in seq if t >= base_v]
91
+ all_assignments.append(abs_tokens)
92
+ all_sums.append((ex.a + ex.b) % P)
93
+ all_pairs.append((ex.a, ex.b))
94
+
95
+ max_len = max(len(a) for a in all_assignments)
96
+ padded = np.array([a + [-1] * (max_len - len(a)) for a in all_assignments])
97
+ return padded, np.array(all_sums), all_pairs
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Analysis 1: Assignment purity — does each abstract token cluster by sum?
102
+ # ---------------------------------------------------------------------------
103
+
104
+ def assignment_purity(assignments, sums, out_dir: Path, abs_vocab: int):
105
+ n_pos = assignments.shape[1]
106
+ fig, axes = plt.subplots(1, n_pos, figsize=(5 * n_pos, 4))
107
+ if n_pos == 1:
108
+ axes = [axes]
109
+
110
+ results = {}
111
+ for pos in range(n_pos):
112
+ col = assignments[:, pos]
113
+ valid = col >= 0
114
+ col_v = col[valid]
115
+ sums_v = sums[valid]
116
+
117
+ # For each token, what distribution over sums does it cover?
118
+ token_sum_dist = {}
119
+ for t in range(abs_vocab):
120
+ mask = col_v == t
121
+ if mask.sum() == 0:
122
+ continue
123
+ token_sum_dist[t] = sums_v[mask]
124
+
125
+ # Plot: x=token id, y=sum, scatter
126
+ ax = axes[pos]
127
+ for t, s in token_sum_dist.items():
128
+ ax.scatter([t] * len(s), s, alpha=0.1, s=2, color="steelblue")
129
+ ax.set_xlabel("Abstract token ID")
130
+ ax.set_ylabel("(a+b) mod p")
131
+ ax.set_title(f"Position {pos}: token vs sum")
132
+
133
+ # Compute mean sum per token (how ordered is it?)
134
+ means = {t: s.mean() for t, s in token_sum_dist.items()}
135
+ results[pos] = {"n_used": len(token_sum_dist), "means": means}
136
+ print(f" Position {pos}: {len(token_sum_dist)} tokens used")
137
+
138
+ plt.tight_layout()
139
+ plt.savefig(out_dir / "assignment_scatter.png", dpi=120)
140
+ plt.close()
141
+ return results
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # Analysis 2: Fourier structure in assignments
146
+ # ---------------------------------------------------------------------------
147
+
148
+ def fourier_of_assignments(assignments, sums, out_dir: Path, abs_vocab: int):
149
+ """
150
+ For each abstract position, treat the assignment function f(sum) as a
151
+ discrete signal over Z_p and compute its DFT. Strong peaks at specific
152
+ frequencies indicate Fourier structure.
153
+ """
154
+ n_pos = assignments.shape[1]
155
+ fig, axes = plt.subplots(1, n_pos, figsize=(5 * n_pos, 4))
156
+ if n_pos == 1:
157
+ axes = [axes]
158
+
159
+ for pos in range(n_pos):
160
+ col = assignments[:, pos]
161
+ valid = col >= 0
162
+
163
+ # Build signal: for each possible sum value s in 0..p-1,
164
+ # compute the average abstract token ID assigned
165
+ signal = np.zeros(P)
166
+ counts = np.zeros(P)
167
+ for tok, s in zip(col[valid], sums[valid]):
168
+ signal[s] += tok
169
+ counts[s] += 1
170
+ counts = np.maximum(counts, 1)
171
+ signal /= counts # mean token ID per sum value
172
+
173
+ # DFT
174
+ freqs = np.abs(np.fft.rfft(signal))
175
+ ax = axes[pos]
176
+ ax.bar(range(len(freqs)), freqs)
177
+ ax.set_xlabel("Frequency k")
178
+ ax.set_ylabel("|DFT|")
179
+ ax.set_title(f"Position {pos}: DFT of mean-token-id(sum)")
180
+ top_k = np.argsort(freqs)[::-1][:5]
181
+ print(f" Position {pos} top-5 frequencies: {top_k.tolist()} (magnitudes: {freqs[top_k].round(2).tolist()})")
182
+
183
+ plt.tight_layout()
184
+ plt.savefig(out_dir / "assignment_fourier.png", dpi=120)
185
+ plt.close()
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # Analysis 3: Abstract token embeddings in Fourier space
190
+ # ---------------------------------------------------------------------------
191
+
192
+ def embedding_fourier(model, out_dir: Path):
193
+ """
194
+ Extract abstract token embedding vectors and check if they organize
195
+ along sin/cos curves for specific Fourier frequencies.
196
+
197
+ Analogous to Nanda's analysis of token embeddings via DFT.
198
+ """
199
+ base_v = int(model.vocab_sizes[0].item())
200
+ abs_v = int(model.vocab_sizes[1].item()) # number of abstract token types
201
+
202
+ # Get embedding matrix for abstract tokens (skip placeholder at base_v)
203
+ embed = model.model.model.embed_tokens.weight # (total_vocab, d_model)
204
+ abs_embeds = embed[base_v + 1: base_v + 1 + abs_v].detach().cpu().float().numpy()
205
+ # shape: (abs_v, d_model)
206
+
207
+ print(f" Abstract embedding matrix: {abs_embeds.shape}")
208
+
209
+ # SVD to find dominant directions
210
+ U, S, Vt = np.linalg.svd(abs_embeds, full_matrices=False)
211
+ print(f" Top-5 singular values: {S[:5].round(3).tolist()}")
212
+
213
+ # Plot singular values
214
+ fig, ax = plt.subplots(figsize=(6, 4))
215
+ ax.bar(range(len(S)), S)
216
+ ax.set_xlabel("Component")
217
+ ax.set_ylabel("Singular value")
218
+ ax.set_title("Abstract embedding SVD")
219
+ plt.tight_layout()
220
+ plt.savefig(out_dir / "embedding_svd.png", dpi=120)
221
+ plt.close()
222
+
223
+ # Plot top-2 components as scatter to see if they form a circle
224
+ n_actual = abs_embeds.shape[0]
225
+ if n_actual >= 3:
226
+ fig, ax = plt.subplots(figsize=(5, 5))
227
+ ax.scatter(U[:, 0], U[:, 1], c=list(range(n_actual)), cmap="hsv", s=60)
228
+ for i in range(n_actual):
229
+ ax.annotate(str(i), (U[i, 0], U[i, 1]), fontsize=7)
230
+ ax.set_title("Abstract tokens in top-2 SVD directions")
231
+ ax.set_xlabel("PC1")
232
+ ax.set_ylabel("PC2")
233
+ plt.tight_layout()
234
+ plt.savefig(out_dir / "embedding_pca.png", dpi=120)
235
+ plt.close()
236
+
237
+
238
+ # ---------------------------------------------------------------------------
239
+ # Analysis 4: Heatmap of token assignment over (a, b) grid
240
+ # ---------------------------------------------------------------------------
241
+
242
+ def assignment_heatmap(assignments, all_pairs, out_dir: Path):
243
+ n_pos = assignments.shape[1]
244
+ fig, axes = plt.subplots(1, n_pos, figsize=(5 * n_pos, 4))
245
+ if n_pos == 1:
246
+ axes = [axes]
247
+
248
+ for pos in range(n_pos):
249
+ grid = np.full((P, P), -1, dtype=float)
250
+ for (a, b), tok in zip(all_pairs, assignments[:, pos]):
251
+ if tok >= 0:
252
+ grid[a, b] = tok
253
+ ax = axes[pos]
254
+ im = ax.imshow(grid, origin="lower", cmap="tab20", aspect="auto")
255
+ ax.set_xlabel("b")
256
+ ax.set_ylabel("a")
257
+ ax.set_title(f"Position {pos}: token assignment grid")
258
+ plt.colorbar(im, ax=ax, fraction=0.046)
259
+
260
+ plt.tight_layout()
261
+ plt.savefig(out_dir / "assignment_heatmap.png", dpi=120)
262
+ plt.close()
263
+
264
+
265
+ # ---------------------------------------------------------------------------
266
+ # Analysis 5: Fourier analysis over EMBEDDING SPACE (not token ID)
267
+ # ---------------------------------------------------------------------------
268
+
269
+ def embedding_fourier_by_sum(model, assignments, sums, out_dir: Path):
270
+ """
271
+ For each abstract position, compute h(a,b) = embedding of assigned token.
272
+ Then DFT over (a+b) mod p to find dominant Fourier frequencies.
273
+
274
+ This is the correct Nanda-style analysis: not which token was assigned,
275
+ but what embedding vector the model placed at that position.
276
+ """
277
+ base_v = int(model.vocab_sizes[0].item())
278
+ embed = model.model.model.embed_tokens.weight.detach().cpu().float().numpy()
279
+ # embed[base_v + 1 + t] = embedding of abstract token t
280
+
281
+ n_pos = assignments.shape[1]
282
+ d_model = embed.shape[1]
283
+
284
+ fig, axes = plt.subplots(2, n_pos, figsize=(5 * n_pos, 8))
285
+
286
+ for pos in range(n_pos):
287
+ col = assignments[:, pos] # (N,) token IDs (0-indexed within abs vocab)
288
+ valid = col >= 0
289
+
290
+ # Build (p, d_model) matrix: mean embedding per sum value
291
+ mean_emb = np.zeros((P, d_model))
292
+ counts = np.zeros(P)
293
+ for tok, s in zip(col[valid], sums[valid]):
294
+ emb = embed[base_v + tok]
295
+ mean_emb[s] += emb
296
+ counts[s] += 1
297
+ counts = np.maximum(counts, 1).reshape(-1, 1)
298
+ mean_emb /= counts # (p, d_model)
299
+
300
+ # DFT over sum dimension for each embedding dim
301
+ freq_power = np.abs(np.fft.rfft(mean_emb, axis=0)) # (p//2+1, d_model)
302
+ total_power_per_freq = freq_power.sum(axis=1) # (p//2+1,)
303
+
304
+ top_k = np.argsort(total_power_per_freq)[::-1][:10]
305
+ print(f" Position {pos} top-10 frequencies (by total embedding power):")
306
+ print(f" freqs: {top_k.tolist()}")
307
+ print(f" powers: {total_power_per_freq[top_k].round(1).tolist()}")
308
+
309
+ # Plot: total power per frequency
310
+ ax = axes[0, pos]
311
+ ax.bar(range(len(total_power_per_freq)), total_power_per_freq)
312
+ ax.set_xlabel("Frequency k")
313
+ ax.set_ylabel("Total |DFT| across dims")
314
+ ax.set_title(f"Pos {pos}: embedding DFT power")
315
+ # Zoom in on non-DC frequencies
316
+ ax2 = axes[1, pos]
317
+ ax2.bar(range(1, len(total_power_per_freq)), total_power_per_freq[1:])
318
+ ax2.set_xlabel("Frequency k (DC removed)")
319
+ ax2.set_ylabel("Total |DFT| across dims")
320
+ ax2.set_title(f"Pos {pos}: non-DC frequencies")
321
+
322
+ # Save per-dim frequency matrix for later
323
+ np.save(out_dir / f"freq_power_pos{pos}.npy", freq_power)
324
+
325
+ plt.tight_layout()
326
+ plt.savefig(out_dir / "embedding_freq_by_sum.png", dpi=120)
327
+ plt.close()
328
+
329
+ # Also: check if dominant non-DC freq is consistent across positions
330
+ print("\n Summary: dominant non-DC frequency per position:")
331
+ for pos in range(n_pos):
332
+ freq_power = np.load(out_dir / f"freq_power_pos{pos}.npy")
333
+ total = freq_power.sum(axis=1)
334
+ top_nondc = np.argsort(total[1:])[::-1][:3] + 1
335
+ print(f" pos {pos}: top-3 non-DC = {top_nondc.tolist()}, "
336
+ f"ratio to DC = {total[top_nondc[0]]/total[0]:.3f}")
337
+
338
+
339
+ # ---------------------------------------------------------------------------
340
+ # Main
341
+ # ---------------------------------------------------------------------------
342
+
343
+ def main():
344
+ p = argparse.ArgumentParser()
345
+ p.add_argument("--model_dir", default="arithmetic/runs/mod_sorl_fourier/final")
346
+ p.add_argument("--out_dir", default="arithmetic/modular/experiments/11_fourier_analysis/results")
347
+ p.add_argument("--K", type=int, default=1)
348
+ p.add_argument("--abs_vocab", type=int, default=30)
349
+ p.add_argument("--device", default="cuda:0")
350
+ p.add_argument("--batch_size",type=int, default=256)
351
+ args = p.parse_args()
352
+
353
+ out_dir = Path(args.out_dir)
354
+ out_dir.mkdir(parents=True, exist_ok=True)
355
+
356
+ print("Loading model...")
357
+ model = load_model(args.model_dir, args.device)
358
+
359
+ print("Generating all (a,b) pairs...")
360
+ train_ex, test_ex = generate_dataset(p=P, seed=42)
361
+ all_ex = train_ex + test_ex
362
+ print(f" Total examples: {len(all_ex)}")
363
+
364
+ print("Extracting abstract token assignments...")
365
+ assignments, sums, pairs = get_assignments(
366
+ model, all_ex, K=args.K, device=args.device, batch_size=args.batch_size
367
+ )
368
+ print(f" Assignment matrix shape: {assignments.shape}")
369
+ np.save(out_dir / "assignments.npy", assignments)
370
+ np.save(out_dir / "sums.npy", sums)
371
+
372
+ print("\n--- Analysis 1: Assignment purity ---")
373
+ assignment_purity(assignments, sums, out_dir, args.abs_vocab)
374
+
375
+ print("\n--- Analysis 2: Fourier structure in assignments ---")
376
+ fourier_of_assignments(assignments, sums, out_dir, args.abs_vocab)
377
+
378
+ print("\n--- Analysis 3: Embedding Fourier analysis ---")
379
+ embedding_fourier(model, out_dir)
380
+
381
+ print("\n--- Analysis 4: Assignment heatmap over (a, b) grid ---")
382
+ assignment_heatmap(assignments, pairs, out_dir)
383
+
384
+ print("\n--- Analysis 5: Fourier analysis of abstract token EMBEDDINGS ---")
385
+ embedding_fourier_by_sum(model, assignments, sums, out_dir)
386
+
387
+ print(f"\nDone. Results in {out_dir}")
388
+
389
+
390
+ if __name__ == "__main__":
391
+ main()