|
|
""" |
|
|
Automated neuron selection strategies for all primes. |
|
|
Replaces hard-coded neuron indices from the analysis notebooks. |
|
|
""" |
|
|
import torch |
|
|
import numpy as np |
|
|
from collections import Counter |
|
|
|
|
|
|
|
|
def select_top_neurons_by_frequency(max_freq_ls, W_in_decode, n=20):
    """
    Select up to ``n`` neurons, cycling round-robin over frequencies.

    Used for heatmap plots (Tab 2).

    Neurons are grouped by their frequency ``max_freq_ls[i]``. Neurons whose
    frequency is not positive (e.g. the DC component 0) are ignored.

    Within each group, neurons are ranked by magnitude, defined as the
    largest absolute value in their row of ``W_in_decode``. Ties keep the
    lower neuron index first.

    Selection visits the frequencies in ascending order and takes the
    next-best neuron from each in turn. A frequency that has run out is
    skipped. Selection stops once ``n`` neurons are chosen or every group is
    exhausted. This keeps the heatmap diversified across all frequencies,
    matching the blog's Figure 2.

    Args:
        max_freq_ls: Per-neuron frequency, indexable by neuron index. A list,
            NumPy array or torch tensor are all accepted; elements are
            converted to Python scalars before grouping.
        W_in_decode: Tensor whose first dimension is d_mlp (one row per
            neuron).
        n: Maximum number of neurons to return. Non-positive values yield an
            empty list.

    Returns:
        List of neuron indices into the original d_mlp-sized arrays, in
        selection order.
    """
    from itertools import zip_longest

    d_mlp = W_in_decode.shape[0]
    # One Python float per neuron: the row-wise max of |W_in_decode|.
    magnitudes = W_in_decode.abs().max(dim=1).values.tolist()

    freq_groups = {}
    for i in range(d_mlp):
        f = max_freq_ls[i]
        # Convert NumPy/torch scalars to plain Python numbers. Torch tensors
        # hash by identity, so using them as dict keys would put every
        # neuron into its own group and break the round-robin.
        if hasattr(f, "item"):
            f = f.item()
        if f > 0:
            freq_groups.setdefault(f, []).append((magnitudes[i], i))

    # For each frequency (ascending), neuron indices ordered from largest to
    # smallest magnitude. The stable sort keeps lower indices first on ties.
    queues = [
        [idx for _, idx in sorted(freq_groups[f], key=lambda x: -x[0])]
        for f in sorted(freq_groups)
    ]

    selected = []
    # zip_longest yields one "round" per rank. Exhausted frequencies
    # contribute None, which is skipped.
    for round_ in zip_longest(*queues):
        for idx in round_:
            if idx is None:
                continue
            if len(selected) >= n:
                return selected
            selected.append(idx)
    return selected
|
|
|
|
|
|
|
|
def select_lineplot_neurons(sorted_indices, n=3):
    """
    Pick ``n`` evenly spaced positions in a frequency-sorted neuron list.

    Used for line plots (Tab 2). Positions start at 0 and are spaced
    ``len(sorted_indices) // n`` apart. The chosen neurons therefore come
    from different parts of the sorted list, and so (presumably) from
    different frequencies.

    Args:
        sorted_indices: Sequence of neuron indices, sorted by frequency.
        n: Number of positions to return.

    Returns:
        List of positions *into* ``sorted_indices``, not the neuron indices
        themselves; use ``sorted_indices[pos]`` to get a neuron.
        - If the list has at most ``n`` entries, every position is returned.
        - If ``n <= 0``, the result is empty.
    """
    if n <= 0:
        # Guard the floor division below, which would raise
        # ZeroDivisionError for n == 0.
        return []
    count = len(sorted_indices)
    if count <= n:
        return list(range(count))
    step = count // n
    return [i * step for i in range(n)]
|
|
|
|
|
|
|
|
def select_phase_frequency(max_freq_ls, p):
    """
    Choose the frequency for phase distribution analysis (Tab 3).

    Picks the frequency with the most neurons assigned to it (the mode),
    excluding non-positive values such as frequency 0 (the DC component).
    Ties go to the frequency encountered first, following
    ``Counter.most_common`` ordering.

    Args:
        max_freq_ls: Iterable of per-neuron frequencies. A list, NumPy array
            or torch tensor are all accepted; elements are converted to
            Python scalars before counting.
        p: Unused; kept for interface compatibility with callers.

    Returns:
        The most common positive frequency, or 1 if there is none.
    """
    # Torch tensors hash by identity, so counting them directly would treat
    # every element as distinct. Normalise to plain Python numbers first.
    freqs = (f.item() if hasattr(f, "item") else f for f in max_freq_ls)
    freq_counts = Counter(f for f in freqs if f > 0)
    if not freq_counts:
        return 1
    return freq_counts.most_common(1)[0][0]
|
|
|
|
|
|
|
|
def select_lottery_neuron(model_load, fourier_basis, decode_scales_phis_fn):
    """
    Return the index of the most frequency-specialised neuron (Tab 6).

    Specialisation is measured per row of the scales matrix. Column 0 is
    treated as the DC component and ignored. The score is the ratio of the
    largest remaining scale to the second largest; a small epsilon in the
    denominator avoids division by zero.

    Args:
        model_load: Passed through unchanged to ``decode_scales_phis_fn``.
        fourier_basis: Passed through unchanged to ``decode_scales_phis_fn``.
        decode_scales_phis_fn: Callable returning a 3-tuple whose first item
            is the scales tensor (assumed one row per neuron, one column per
            frequency — TODO confirm).

    Returns:
        Integer index of the row with the highest top-1 / top-2 ratio, or 0
        when fewer than two non-DC columns are available.
    """
    (scales, _, _) = decode_scales_phis_fn(model_load, fourier_basis)

    non_dc = scales[:, 1:]
    # A ratio needs at least two candidate frequencies.
    if non_dc.shape[1] < 2:
        return 0

    ranked = torch.sort(non_dc, dim=1, descending=True).values
    best = ranked[:, 0]
    runner_up = ranked[:, 1]
    specialisation = best / (runner_up + 1e-10)
    return specialisation.argmax().item()
|
|
|