# NOTE: the next three lines are web-page header residue kept for provenance.
# zhuoranyang's picture
# Deploy app with precomputed results for p=15,23,29,31
# b753304 verified
import copy
from functools import partial

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
########## Fourier Methods ##########
def normalize_to_pi(value):
    """Wrap an angle in radians into the half-open interval [-pi, pi)."""
    half_turn = np.pi
    full_turn = 2 * np.pi
    # Shift by pi so the modulo lands in [0, 2*pi), then shift back.
    shifted = value + half_turn
    return shifted % full_turn - half_turn
def get_fourier_basis(p, device):
    """
    Generates the orthonormal real Fourier basis for a given dimensionality `p`.

    Rows are ordered ``Const, cos 1, sin 1, cos 2, sin 2, ...``. For even `p`
    the last row is the Nyquist term ``cos p/2`` (alternating +1 / -1). Its
    sine partner is identically zero and is omitted, so the result is always
    a square ``p x p`` matrix.

    Args:
        p (int): The dimensionality of the Fourier basis.
        device (str): The device to place the Fourier basis tensor on ('cpu' or 'cuda').

    Returns:
        torch.Tensor: A matrix where each row is a unit-norm Fourier basis vector.
        list: A list of names corresponding to the Fourier basis vectors.
    """
    positions = torch.arange(p)

    # Constant term (normalized)
    fourier_basis = [torch.ones(p) / np.sqrt(p)]
    fourier_basis_names = ['Const']

    # Only frequencies strictly below p/2 have both a cosine and a sine.
    # Stopping at (p - 1) // 2 keeps the even-p Nyquist frequency out of this
    # loop, where its (zero) sine would be divided by a ~0 norm and its cosine
    # would be added a second time by the special case below.
    for i in range(1, (p - 1) // 2 + 1):
        phase = 2 * torch.pi * positions * i / p
        cosine = torch.cos(phase)
        sine = torch.sin(phase)
        # Normalize each basis function
        cosine /= cosine.norm()
        sine /= sine.norm()
        fourier_basis.append(cosine)
        fourier_basis.append(sine)
        fourier_basis_names.append(f'cos {i}')
        fourier_basis_names.append(f'sin {i}')

    # Special case for even p: cos(k*pi), alternating +1 and -1
    if p % 2 == 0:
        cosine = torch.cos(torch.pi * positions)
        cosine /= cosine.norm()
        fourier_basis.append(cosine)
        fourier_basis_names.append(f'cos {p // 2}')

    # Stack the basis vectors into a matrix and move to the desired device
    fourier_basis = torch.stack(fourier_basis, dim=0).to(device)
    return fourier_basis, fourier_basis_names
def get_fourier_basis_unstd(p, device):
    """
    Generates the Fourier basis for `p` with *unnormalized* cosine/sine rows.

    Rows are ordered ``Const, cos 1, sin 1, cos 2, sin 2, ...``. The paired
    cosine and sine rows are left at their raw amplitude (not divided by
    their norm). The constant row is divided by ``sqrt(p)``. For even `p` a
    single Nyquist row ``cos p/2`` is appended, so the result is always a
    square ``p x p`` matrix.

    Args:
        p (int): The dimensionality of the Fourier basis.
        device (str): The device to place the Fourier basis tensor on ('cpu' or 'cuda').

    Returns:
        torch.Tensor: A matrix where each row is a Fourier basis vector.
        list: A list of names corresponding to the Fourier basis vectors.
    """
    positions = torch.arange(p)

    # Constant term (this one row is normalized)
    fourier_basis = [torch.ones(p) / np.sqrt(p)]
    fourier_basis_names = ['Const']

    # Only frequencies strictly below p/2 have both a cosine and a sine.
    # Stopping at (p - 1) // 2 keeps the even-p Nyquist frequency out of this
    # loop, where it would contribute an all-zero sine row and a duplicate of
    # the cosine row added by the special case below.
    for i in range(1, (p - 1) // 2 + 1):
        phase = 2 * torch.pi * positions * i / p
        fourier_basis.append(torch.cos(phase))
        fourier_basis.append(torch.sin(phase))
        fourier_basis_names.append(f'cos {i}')
        fourier_basis_names.append(f'sin {i}')

    # Special case for even p: cos(k*pi), alternating +1 and -1
    if p % 2 == 0:
        cosine = torch.cos(torch.pi * positions)
        # NOTE(review): this row is normalized although the paired rows are
        # not; kept as in the original -- confirm this is intended.
        cosine /= cosine.norm()
        fourier_basis.append(cosine)
        fourier_basis_names.append(f'cos {p // 2}')

    # Stack the basis vectors into a matrix and move to the desired device
    fourier_basis = torch.stack(fourier_basis, dim=0).to(device)
    return fourier_basis, fourier_basis_names
def fft1d(tensor, fourier_basis):
    """
    Express `tensor` in the Fourier basis.

    The last axis of `tensor` is contracted against each row of
    `fourier_basis`, i.e. the result is ``tensor @ fourier_basis.T``.
    """
    basis_transposed = fourier_basis.T
    return tensor @ basis_transposed
def fft2d(mat, p, fourier_basis):
    """
    Converts a tensor into the 2D Fourier basis.

    Two input layouts are accepted:
      * ``(p*p, ...)`` -- the two input axes flattened into the leading axis
      * ``(p, p, ...)`` -- the two input axes kept separate

    Any trailing axes are carried through untouched. The output has the same
    shape as the input.

    Parameters:
        mat (torch.Tensor): Tensor in one of the layouts above.
        p (int): Size of each of the two axes being transformed.
        fourier_basis (torch.Tensor): ``[n_basis, p]`` matrix whose rows are
            basis vectors. The final reshape back to the input shape requires
            ``n_basis == p``.

    Returns:
        torch.Tensor: `mat` expressed in the 2D Fourier basis.

    Raises:
        ValueError: If `mat` matches neither accepted layout.
    """
    shape = mat.shape
    is_flat = mat.ndim >= 1 and shape[0] == p * p
    is_grid = mat.ndim >= 2 and shape[0] == p and shape[1] == p
    if not (is_flat or is_grid):
        raise ValueError(
            f"fft2d expects leading shape ({p * p},) or ({p}, {p}); got {tuple(shape)}"
        )
    # Both layouts are row-major views of the same (p, p, rest) grid.
    trailing = mat.numel() // (p * p)
    grid = mat.reshape(p, p, trailing)
    # out[f, F, z] = sum_{x, y} grid[x, y, z] * basis[f, x] * basis[F, y]
    fourier_mat = torch.einsum('xyz,fx,Fy->fFz', grid, fourier_basis, fourier_basis)
    return fourier_mat.reshape(shape)
def to_numpy(tensor, flat=False):
    """
    Convert a torch tensor to a NumPy array; pass anything else through.

    Parameters:
        tensor: A ``torch.Tensor`` (or subclass such as ``nn.Parameter``), or
            any other object.
        flat (bool): If True, flatten the tensor to 1-D before converting.

    Returns:
        numpy.ndarray for tensor input (detached and moved to CPU);
        otherwise `tensor` itself, unchanged.
    """
    # isinstance (not an exact type() match) so Tensor subclasses such as
    # nn.Parameter are converted too instead of being returned as-is.
    if not isinstance(tensor, torch.Tensor):
        return tensor
    detached = tensor.detach()
    if flat:
        detached = detached.flatten()
    return detached.cpu().numpy()
def unflatten_first(tensor, p):
    """
    Split a leading axis of length ``p*p`` into two axes of length ``p``.

    Tensors whose first axis is not ``p*p`` long are returned unchanged.
    """
    if tensor.shape[0] != p * p:
        return tensor
    remaining_axes = tensor.shape[1:]
    return tensor.reshape(p, p, *remaining_axes)
def decode_weights(model_load, fourier_basis):
    """
    Decodes the MLP weights into the Fourier basis and finds each neuron's dominant frequency.

    Parameters:
        model_load (dict): A dictionary containing the model's weights. Reads
            'mlp.W_in' and 'mlp.W_out'.
        fourier_basis (torch.Tensor): The Fourier basis matrix (one basis vector per row).

    Returns:
        tuple: A tuple containing:
            - W_in_decode (torch.Tensor): ``W_in @ fourier_basis.T``.
            - W_out_decode (torch.Tensor): ``W_out.T @ fourier_basis.T``.
            - max_freq_ls (list): For each row of W_in_decode, the frequency of
              its largest-magnitude coefficient (0 for the constant term).
    """
    basis_t = fourier_basis.T
    # Decode the weights
    W_in_decode = model_load['mlp.W_in'] @ basis_t
    W_out_decode = model_load['mlp.W_out'].T @ basis_t

    # Column 0 is the constant; columns 2f-1 and 2f belong to frequency f,
    # so (column + 1) // 2 maps a column index to its frequency.
    dominant_columns = torch.argmax(W_in_decode.abs(), dim=1).tolist()
    max_freq_ls = [(column + 1) // 2 for column in dominant_columns]
    return W_in_decode, W_out_decode, max_freq_ls
def compute_neuron(neuron, max_freq_ls, W_decode):
    """
    Computes the scale and phase of one neuron at its dominant frequency.

    Parameters:
        neuron (int): Index of the neuron (row of `W_decode`).
        max_freq_ls (list): Dominant frequency per neuron; 0 selects the constant term.
        W_decode (torch.Tensor): Fourier-decoded weights, one row per neuron.

    Returns:
        tuple: A tuple containing:
            - coeff_scale (float): Scale coefficient.
            - coeff_phi (float): Phase coefficient (0 when the frequency is 0).
    """
    p = W_decode.shape[1]
    freq = max_freq_ls[neuron]

    # Constant term: no phase, and the raw coefficient is the scale.
    if freq == 0:
        return W_decode[neuron, 0].item(), 0

    # Columns 2f-1 and 2f hold the cosine and sine coefficients of frequency f.
    cos_col = freq * 2 - 1
    sin_col = freq * 2
    pair = W_decode[neuron, [cos_col, sin_col]]

    energy = torch.sum(pair.pow(2)).item()
    coeff_scale = np.sqrt(energy) * np.sqrt(2/p)
    coeff_phi = np.arctan2(-pair[1].item(), pair[0].item())
    return coeff_scale, coeff_phi
import torch
def decode_scales_phis(model_load: dict, fourier_basis: torch.Tensor):
    """
    Decode W_in into scale & phase, and W_out into phase, for **all** frequencies.

    With ``K = (p - 1) // 2``, column 0 of each output is the DC term and
    column f (1 <= f <= K) is frequency f.

    Returns:
        scales: Tensor[n_neurons, K+1]  (from W_in)
        phis:   Tensor[n_neurons, K+1]  (from W_in; column 0 stays 0)
        psis:   Tensor[n_neurons, K+1]  (from W_out; column 0 stays 0)
    """
    # 1) decode both weight matrices into the Fourier basis
    basis_t = fourier_basis.T
    W = model_load['mlp.W_in'] @ basis_t           # [n_neurons, p]
    W_out = model_load['mlp.W_out'].T @ basis_t    # [n_neurons, p]

    # 2) allocate outputs
    n_neurons, p = W.shape
    K = (p - 1) // 2
    like_W = dict(device=W.device, dtype=W.dtype)
    scales = torch.zeros(n_neurons, K+1, **like_W)
    phis = torch.zeros(n_neurons, K+1, **like_W)
    psis = torch.zeros(n_neurons, K+1, **like_W)

    # 3) DC (f=0): magnitude only
    scales[:, 0] = W[:, 0].abs()

    # 4) every other frequency at once: columns 1, 3, ..., 2K-1 are the
    #    cosine coefficients and columns 2, 4, ..., 2K the sine coefficients
    cos_cols = slice(1, 2*K, 2)
    sin_cols = slice(2, 2*K + 1, 2)
    real = W[:, cos_cols]
    imag = W[:, sin_cols]
    scales[:, 1:] = np.sqrt(2/p) * torch.sqrt(real.pow(2) + imag.pow(2))
    phis[:, 1:] = torch.atan2(-imag, real)
    psis[:, 1:] = torch.atan2(-W_out[:, sin_cols], W_out[:, cos_cols])
    return scales, phis, psis
########## Neuron Tracking ##########
def sort_model(model_load, sort_order_mlp, sort_order_d):
    """
    Returns a reordered copy of a model's weights; the input dict is not modified.

    Parameters:
        model_load (dict): The original loaded model dictionary.
        sort_order_mlp (list or array): Sorting order for the MLP dimensions.
        sort_order_d (list or array): Sorting order for the embedding dimensions.

    Returns:
        dict: A deep copy of `model_load` with 'mlp.W_in', 'mlp.W_out',
        'mlp.b_in' and 'embed.W_E' permuted, and 'unembed.embed_layer.W_E'
        set to the very same tensor as the permuted 'embed.W_E'.
    """
    # Work on a deep copy so the caller's tensors are left untouched
    reordered = copy.deepcopy(model_load)

    # MLP input weights: rows follow the MLP order, columns the embedding order
    w_in = reordered['mlp.W_in']
    reordered['mlp.W_in'] = w_in[sort_order_mlp][:, sort_order_d]

    # MLP output weights: rows follow the embedding order, columns the MLP order
    w_out = reordered['mlp.W_out']
    reordered['mlp.W_out'] = w_out[sort_order_d][:, sort_order_mlp]

    # MLP input bias: one entry per MLP unit
    reordered['mlp.b_in'] = reordered['mlp.b_in'][sort_order_mlp]

    # Embedding rows follow the embedding order; the unembedding shares them
    embedding = reordered['embed.W_E'][sort_order_d]
    reordered['embed.W_E'] = embedding
    reordered['unembed.embed_layer.W_E'] = embedding
    return reordered
########## Plotting Helper ##########
def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', p=None, **kwargs):
    """
    Show a tensor as a Plotly heatmap.

    Parameters:
        tensor (torch.Tensor): Data to display. Size-1 axes are squeezed out.
        xaxis (str): Label for the x axis.
        yaxis (str): Label for the y axis.
        animation_name (str): Label for the animation frame axis.
        p (int, optional): If given and the first axis of `tensor` has length
            ``p*p``, that axis is unflattened to ``(p, p)`` first. When
            omitted, a module-level global named ``p`` is used if one exists
            (the original behaviour); otherwise no unflattening happens.
        **kwargs: Forwarded to ``plotly.express.imshow``.
    """
    if p is None:
        # Backward compatibility: the original read a module-level `p`.
        p = globals().get('p')
    if p is not None and tensor.shape[0] == p * p:
        tensor = unflatten_first(tensor, p)
    tensor = torch.squeeze(tensor)
    # 'animation_frame' matches the labels key used by imshow_fourier; the
    # previous key 'animation_name' is not one px.imshow documents.
    fig = px.imshow(
        to_numpy(tensor, flat=False),
        labels={'x': xaxis, 'y': yaxis, 'animation_frame': animation_name},
        **kwargs,
    )
    fig.show()
# Default colour scheme: rebinds `imshow` so later presets build on it
imshow = partial(imshow, color_continuous_scale='Blues')

# Divergent colour scale for data with both positive and negative values
# (midpoint pinned at 0)
imshow_div = partial(
    imshow,
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
)

# Heatmap preset for activations, with x axis being input 1 and y axis
# being input 2
inputs_heatmap = partial(
    imshow,
    xaxis='Input 1',
    yaxis='Input 2',
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
    width=1000,
    height=800,
)
def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    """
    Show a Plotly Express line plot.

    Tensor inputs are flattened to NumPy first; other inputs are used as-is.
    `x` is handed to ``px.line`` as its first positional argument, `y` as
    ``y=`` and `hover` as ``hover_name=``. `xaxis` / `yaxis` set the axis
    titles. Extra keyword arguments are forwarded to ``px.line``.
    """
    # to_numpy leaves non-tensor values untouched, so no type guard is needed
    y = to_numpy(y, flat=True)
    x = to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
def scatter(x, y, **kwargs):
    """Show a Plotly Express scatter plot of `y` against `x` (tensors are flattened)."""
    xs = to_numpy(x, flat=True)
    ys = to_numpy(y, flat=True)
    fig = px.scatter(x=xs, y=ys, **kwargs)
    fig.show()
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title='', log_y=False, hover=None, **kwargs):
    """
    Plot several series on one Plotly figure and show it.

    `lines_list` is a sequence of series, or a tensor whose rows are the
    series. `x` defaults to ``0..len(first series)-1``. Each trace is named
    from `labels` when given, otherwise by its position. `log_y` switches the
    y axis to a log scale. Extra keyword arguments go to every ``go.Scatter``.
    """
    if type(lines_list) is torch.Tensor:
        # One series per row of the tensor
        lines_list = list(lines_list)
    if x is None:
        x = np.arange(len(lines_list[0]))

    fig = go.Figure(layout={'title': title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)

    for position, series in enumerate(lines_list):
        trace_name = position if labels is None else labels[position]
        # to_numpy leaves non-tensor series untouched
        values = to_numpy(series)
        fig.add_trace(go.Scatter(x=x, y=values, mode=mode, name=trace_name, hovertext=hover, **kwargs))

    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()
def line_marker(x, **kwargs):
    """Plot the single series `x` via `lines`, drawn with both lines and markers."""
    single_series = [x]
    lines(single_series, mode='lines+markers', **kwargs)
def animate_lines(lines_list, snapshot_index=None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', **kwargs):
    """
    Show an animated line plot with one line per snapshot.

    Parameters:
        lines_list: Array/tensor indexed ``[snapshot, x]``, or a list of
            tensors that is stacked along a new first axis.
        snapshot_index: Label for each snapshot; defaults to ``0..n-1``.
        snapshot (str): Column name used for the animation frame.
        hover: Optional per-point hover names for one snapshot; repeated for
            every snapshot.
        xaxis (str): Column name for the x values (the point index).
        yaxis (str): Column name for the y values.
        **kwargs: Forwarded to ``plotly.express.line``.

    Requires ``pandas`` to be imported as ``pd`` at module level.
    """
    if isinstance(lines_list, list):
        lines_list = torch.stack(lines_list, dim=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        # One copy of the hover labels per snapshot, in row order
        hover = list(hover) * len(snapshot_index)
    print(lines_list.shape)
    n_snapshots, n_points = lines_list.shape[0], lines_list.shape[1]
    # Long format: one row per (snapshot, point)
    rows = [
        [lines_list[i][j], snapshot_index[i], j]
        for i in range(n_snapshots)
        for j in range(n_points)
    ]
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    fig = px.line(
        df,
        x=xaxis,
        y=yaxis,
        animation_frame=snapshot,
        range_y=[lines_list.min(), lines_list.max()],
        hover_name=hover,
        **kwargs,
    )
    fig.show()
def imshow_fourier(tensor, p, fourier_basis_names, title='', animation_name='snapshot', facet_labels=None, width=1000, height=800, **kwargs):
    """
    Show a tensor as a heatmap whose axes are labelled with Fourier component names.

    Parameters:
        tensor (torch.Tensor): Data to display. A leading axis of length
            ``p*p`` is unflattened to ``(p, p)``; size-1 axes are squeezed.
        p (int): Side length used for the unflattening check.
        fourier_basis_names (list): Tick labels used for both axes.
        title (str): Figure title.
        animation_name (str): Label for the animation frame axis.
        facet_labels (list, optional): Replacement text for the figure's
            annotations, applied in order. Defaults to none.
        width (int): Figure width.
        height (int): Figure height.
        **kwargs: Forwarded to ``plotly.express.imshow``.
    """
    if tensor.shape[0] == p * p:
        tensor = unflatten_first(tensor, p)
    tensor = torch.squeeze(tensor)
    fig = px.imshow(
        to_numpy(tensor),
        x=fourier_basis_names,
        y=fourier_basis_names,
        labels={
            'x': 'x Component',
            'y': 'y Component',
            'animation_frame': animation_name,
        },
        title=title,
        color_continuous_midpoint=0.,
        color_continuous_scale='RdBu',
        width=width,
        height=height,
        **kwargs
    )
    fig.update(data=[{'hovertemplate': "%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    # None (rather than a shared mutable [] default) means "no relabelling"
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    fig.show()
def animate_multi_lines(lines_list, y_index=None, snapshot_index=None, snapshot='snapshot', hover=None, swap_y_animate=False, **kwargs):
    """
    Show an animated plot with several lines per frame.

    Parameters:
        lines_list: Array/tensor indexed ``[snapshot, line, x]``, or a list of
            tensors that is stacked along a new first axis.
        y_index: Name for each line; defaults to ``'0', '1', ...``.
        snapshot_index: Label for each snapshot; defaults to ``0..n-1``.
        snapshot (str): Column name used for the animation frame.
        hover: Optional per-point hover names for one snapshot; repeated for
            every snapshot.
        swap_y_animate (bool): If True, swap the first two axes so the input
            is read as ``[line, snapshot, x]``.
        **kwargs: Forwarded to ``plotly.express.line``.

    Requires ``pandas`` to be imported as ``pd`` at module level.
    """
    if isinstance(lines_list, list):
        lines_list = torch.stack(lines_list, dim=0)
    lines_list = to_numpy(lines_list, flat=False)
    if swap_y_animate:
        lines_list = lines_list.transpose(1, 0, 2)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if y_index is None:
        y_index = [str(i) for i in range(lines_list.shape[1])]
    # Accept any sequence of names; list concatenation below needs a list
    y_index = list(y_index)
    if hover is not None:
        # One copy of the hover labels per snapshot, in row order
        hover = list(hover) * len(snapshot_index)
    print(lines_list.shape)
    # Wide format: one row per (snapshot, x), one column per line
    rows = [
        list(lines_list[i, :, j]) + [snapshot_index[i], j]
        for i in range(lines_list.shape[0])
        for j in range(lines_list.shape[2])
    ]
    df = pd.DataFrame(rows, columns=y_index + [snapshot, 'x'])
    fig = px.line(
        df,
        x='x',
        y=y_index,
        animation_frame=snapshot,
        range_y=[lines_list.min(), lines_list.max()],
        hover_name=hover,
        **kwargs,
    )
    fig.show()
def animate_scatter(lines_list, snapshot_index=None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name='color', **kwargs):
    """
    Show an animated scatter plot.

    Parameters:
        lines_list: Array/tensor of shape ``snapshot x 2 x point`` (index 0 of
            the middle axis is x, index 1 is y), or a list of tensors that is
            stacked along a new first axis.
        snapshot_index: Label for each snapshot; defaults to ``0..n-1``.
        snapshot (str): Column name used for the animation frame.
        hover: Optional per-point hover names for one snapshot; repeated for
            every snapshot.
        yaxis (str): Column name for the y values.
        xaxis (str): Column name for the x values.
        color: Per-point colour values, either 1-D (shared by every snapshot)
            or 2-D ``snapshot x point``. Defaults to all ones.
        color_name (str): Column name for the colour values.
        **kwargs: Forwarded to ``plotly.express.scatter``.

    Requires ``pandas`` to be imported as ``pd`` at module level.
    """
    if isinstance(lines_list, list):
        lines_list = torch.stack(lines_list, dim=0)
    lines_list = to_numpy(lines_list, flat=False)
    n_snapshots, n_points = lines_list.shape[0], lines_list.shape[2]
    if snapshot_index is None:
        snapshot_index = np.arange(n_snapshots)
    if hover is not None:
        # One copy of the hover labels per snapshot, in row order
        hover = list(hover) * len(snapshot_index)
    if color is None:
        color = np.ones(lines_list.shape[-1])
    # asarray so plain Python sequences work as well as tensors/arrays
    color = np.asarray(to_numpy(color))
    if color.ndim == 1:
        # Same colours for every snapshot
        color = np.tile(color, (n_snapshots, 1))
    print(lines_list.shape)
    rows = [
        [lines_list[i, 0, j].item(), lines_list[i, 1, j].item(), snapshot_index[i], color[i, j]]
        for i in range(n_snapshots)
        for j in range(n_points)
    ]
    # Fixed axis ranges so the view does not jump between frames
    x_range = [lines_list[:, 0].min(), lines_list[:, 0].max()]
    y_range = [lines_list[:, 1].min(), lines_list[:, 1].max()]
    print(x_range)
    print(y_range)
    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])
    fig = px.scatter(
        df,
        x=xaxis,
        y=yaxis,
        animation_frame=snapshot,
        range_x=x_range,
        range_y=y_range,
        hover_name=hover,
        color=color_name,
        **kwargs,
    )
    fig.show()
def plot_angles_on_circle(angles, multipliers=(1, 2, 4, 6), title_prefix="Angles Multiplication"):
    """
    Visualize multiple sets of angles (in radians) on unit circles.

    One subplot is drawn per multiplier, side by side, each showing the
    points at ``angles * multiplier`` on the unit circle.

    Parameters:
    - angles: list or array-like of angles in radians (should be in range [-π, π]).
    - multipliers: sequence of factors applied to the angles, one subplot each
      (default is (1, 2, 4, 6)).
    - title_prefix: Prefix for titles of the subplots (default is "Angles Multiplication").
    """
    # Convert up front: multiplying a plain Python list by an int would
    # repeat the list instead of scaling the angles.
    angles = np.atleast_1d(np.asarray(angles, dtype=float)).ravel()
    multipliers = list(multipliers)
    n_panels = len(multipliers)

    # One 5x5 panel per multiplier (20x5 for the default four)
    plt.figure(figsize=(5 * max(n_panels, 1), 5))

    # The unit circle outline is the same in every panel
    theta = np.linspace(0, 2 * np.pi, 500)
    circle_x = np.cos(theta)
    circle_y = np.sin(theta)

    for i, multiplier in enumerate(multipliers):
        # Multiply the angles
        modified_angles = angles * multiplier
        # Convert angles to x and y coordinates on a unit circle
        x = np.cos(modified_angles)
        y = np.sin(modified_angles)

        plt.subplot(1, n_panels, i + 1)
        plt.plot(circle_x, circle_y, color='lightgray', label='Unit Circle')  # Unit circle
        plt.scatter(x, y, color='red', label='Points')  # Points corresponding to the angles
        plt.axhline(0, color='black', linewidth=0.5)  # Horizontal line
        plt.axvline(0, color='black', linewidth=0.5)  # Vertical line

        # Annotate each point with its angle, placed just outside the circle
        for j, angle in enumerate(modified_angles):
            plt.text(x[j] * 1.1, y[j] * 1.1, f'{angle:.2f}', fontsize=9, ha='center')

        # Set title and formatting
        plt.title(f"{title_prefix}: {multiplier}*Angles")
        plt.axis('equal')  # Equal scaling for x and y
        plt.legend()
        plt.grid(True)

    # Adjust spacing between subplots
    plt.tight_layout()
    plt.show()