# zhuoranyang's picture
# Deploy app with precomputed results for p=15,23,29,31
# b753304 verified
"""
Automated neuron selection strategies for all primes.
Replaces hard-coded neuron indices from the analysis notebooks.
"""
import torch
import numpy as np
from collections import Counter
def select_top_neurons_by_frequency(max_freq_ls, W_in_decode, n=20):
    """
    Select up to ``n`` neurons covering all frequencies (round-robin).

    Used for heatmap plots (Tab 2). Neurons are grouped by their dominant
    frequency and ranked inside each group by magnitude, where a neuron's
    magnitude is the largest absolute entry of its row in ``W_in_decode``.
    Frequencies are then visited in ascending order, taking the next-best
    neuron from each in turn, until ``n`` neurons are chosen or every group
    is exhausted. This keeps the heatmap diversified across frequencies
    (intended to match the blog's Figure 2).

    Args:
        max_freq_ls: Per-neuron dominant frequency, indexable by neuron
            index. Entries may be plain Python numbers or scalar objects
            exposing ``.item()`` (0-d torch tensors, NumPy scalars); they
            are unwrapped so that equal frequencies share one group.
            Entries that are not ``> 0`` (the DC component) are skipped.
        W_in_decode: 2-D torch tensor with one row per neuron.
        n: Maximum number of neurons to return.

    Returns:
        List of neuron indices (row indices of ``W_in_decode``), at most
        ``n`` long; shorter when there are fewer non-DC neurons.
    """
    from collections import defaultdict

    d_mlp = W_in_decode.shape[0]
    # One bulk conversion instead of a separate .item() call per neuron.
    magnitudes = W_in_decode.abs().max(dim=1).values.tolist()

    # Group neurons by dominant frequency.
    freq_groups = defaultdict(list)
    for i in range(d_mlp):
        f = max_freq_ls[i]
        # Tensor elements hash by identity, not value; unwrap them so
        # neurons with the same frequency share one dictionary key.
        if hasattr(f, "item"):
            f = f.item()
        if f > 0:  # skip DC neurons
            freq_groups[f].append((magnitudes[i], i))

    # Strongest neuron first within each group (stable for ties).
    for group in freq_groups.values():
        group.sort(key=lambda pair: -pair[0])

    # Round-robin across frequencies in ascending order. Every frequency
    # still in play has contributed exactly ``rank`` neurons so far, so a
    # single shared counter replaces per-frequency pointers.
    selected = []
    active = sorted(freq_groups)
    rank = 0
    while active and len(selected) < n:
        # Drop frequencies that have no neuron left at this rank.
        active = [f for f in active if rank < len(freq_groups[f])]
        for f in active:
            if len(selected) >= n:
                break
            selected.append(freq_groups[f][rank][1])
        rank += 1
    return selected
def select_lineplot_neurons(sorted_indices, n=3):
    """
    Pick ``n`` evenly spaced positions in the frequency-sorted set for line
    plots (Tab 2), so the plotted neurons span diverse frequencies.

    Args:
        sorted_indices: Sequence of neurons sorted by frequency. Only its
            length is used.
        n: Number of positions to pick.

    Returns:
        List of positions into ``sorted_indices`` (not the entries
        themselves). If the sequence has ``n`` or fewer entries, every
        position is returned. Otherwise the positions are
        ``0, step, 2*step, ...`` with ``step = len(sorted_indices) // n``.
        An empty list is returned when ``n <= 0``.
    """
    # Guard against n == 0, which would otherwise divide by zero below;
    # negative n already produced an empty selection.
    if n <= 0:
        return []
    total = len(sorted_indices)
    if total <= n:
        return list(range(total))
    step = total // n
    return [i * step for i in range(n)]
def select_phase_frequency(max_freq_ls, p):
    """
    Choose the frequency for phase distribution analysis (Tab 3).

    Picks the frequency with the most neurons assigned to it (the mode),
    excluding entries that are not ``> 0`` (frequency 0 is the DC
    component). Ties go to the frequency that appears first in
    ``max_freq_ls``.

    Args:
        max_freq_ls: Iterable of per-neuron dominant frequencies. Entries
            may be plain Python numbers or scalar objects exposing
            ``.item()`` (0-d torch tensors, NumPy scalars); they are
            unwrapped so counting is by value.
        p: Unused by this function; kept for interface compatibility.

    Returns:
        The most common positive frequency as a plain Python number, or
        ``1`` when there is no positive frequency at all.
    """
    freq_counts = Counter()
    for f in max_freq_ls:
        # Tensor elements hash by identity, so without unwrapping every
        # entry would be counted as its own distinct frequency.
        if hasattr(f, "item"):
            f = f.item()
        if f > 0:
            freq_counts[f] += 1
    if not freq_counts:
        return 1
    return freq_counts.most_common(1)[0][0]
def select_lottery_neuron(model_load, fourier_basis, decode_scales_phis_fn):
    """
    Find the neuron with the clearest frequency specialization (Tab 6).

    A neuron's specialization is its largest non-DC frequency scale
    divided by its second-largest one; the neuron maximizing that ratio
    is returned.

    Args:
        model_load: Passed through unchanged to ``decode_scales_phis_fn``.
        fourier_basis: Passed through unchanged to
            ``decode_scales_phis_fn``.
        decode_scales_phis_fn: Callable invoked as
            ``decode_scales_phis_fn(model_load, fourier_basis)``; it must
            return a 3-tuple whose first element is the scales tensor
            (expected layout ``[n_neurons, K+1]`` with the DC term in
            column 0). The other two elements are ignored.

    Returns:
        Index of the selected neuron as a Python int, or ``0`` when fewer
        than two non-DC frequencies are available.
    """
    scales, _, _ = decode_scales_phis_fn(model_load, fourier_basis)

    # Column 0 is the DC component; drop it before ranking.
    non_dc = scales[:, 1:]
    num_freqs = non_dc.shape[1]
    if num_freqs < 2:
        # No runner-up frequency to compare against.
        return 0

    ranked = torch.sort(non_dc, dim=1, descending=True).values
    strongest = ranked[:, 0]
    runner_up = ranked[:, 1]
    # Small epsilon keeps the ratio finite when the runner-up is zero.
    specialization = strongest / (runner_up + 1e-10)
    return torch.argmax(specialization).item()