|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import torch, einops |
|
|
|
|
|
|
|
|
class HookPoint(nn.Module):
    """Identity module that serves as a named attachment site for hooks.

    Dropped into a model's forward pass, it lets callers observe (or cache)
    the tensor flowing through it via PyTorch forward/backward hooks, keyed
    by the name assigned with `give_name`.
    """

    def __init__(self):
        super().__init__()
        # Handles returned by register_*_hook, kept so they can be removed.
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Record the module-path name used to label this hook point."""
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        """Attach `hook(tensor, name=...)` in the given direction ('fwd' or 'bwd')."""

        def full_hook(module, module_input, module_output):
            # Adapt PyTorch's (module, input, output) hook signature to the
            # simpler hook(tensor, name=...) convention used by callers.
            return hook(module_output, name=self.name)

        if dir == 'fwd':
            self.fwd_hooks.append(self.register_forward_hook(full_hook))
        elif dir == 'bwd':
            # NOTE(review): register_backward_hook is deprecated in newer
            # torch in favor of register_full_backward_hook -- kept as-is to
            # preserve existing gradient-hook semantics.
            self.bwd_hooks.append(self.register_backward_hook(full_hook))
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        """Detach and forget hooks in the given direction ('fwd', 'bwd', or 'both')."""
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []
        if dir not in ('fwd', 'bwd', 'both'):
            raise ValueError(f"Invalid direction {dir}")

    def forward(self, x):
        # Pure pass-through; the registered hooks do the actual work.
        return x
|
|
|
|
|
|
|
|
class Embed(nn.Module):
    """Embeds a pair of tokens by summing their two embedding vectors.

    Supports a parameter-free one-hot embedding (output width d_vocab;
    presumably callers use d_model == d_vocab in that case -- confirm) or a
    learned d_model-dimensional embedding matrix.
    """

    def __init__(self, d_vocab, d_model, embed_type='one_hot'):
        super().__init__()
        self.d_vocab = d_vocab
        self.embed_type = embed_type

        if embed_type == 'learned':
            # Learned embedding, scaled by 1/sqrt(d_model) at init.
            self.W_E = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_model))
        elif embed_type == 'one_hot':
            # One-hot needs no parameters.
            self.W_E = None
        else:
            raise ValueError(f"Invalid embed_type: {embed_type}. Must be 'one_hot' or 'learned'")

    def forward(self, x):
        """Map (batch, 2) token pairs to (batch, 1, d) summed embeddings."""
        if isinstance(x, list):
            # Accept raw Python lists; put the tensor wherever the weights live.
            target = 'cpu' if self.W_E is None else self.W_E.device
            x = torch.tensor(x, device=target)

        assert x.ndim == 2 and x.shape[1] == 2, f"Expected input shape (batch_size, 2), got {x.shape}"

        if self.embed_type == 'one_hot':
            # Sum the two one-hot vectors, keeping a singleton position axis.
            pair = F.one_hot(x, num_classes=self.d_vocab).float()
            embed = pair.sum(dim=1).unsqueeze(1)
        elif self.embed_type == 'learned':
            # Gather the column of W_E for each token: (d, batch, 2) -> (batch, 2, d),
            # then sum the two token embeddings.
            gathered = self.W_E[:, x].permute(1, 2, 0)
            embed = gathered.sum(dim=1).unsqueeze(1)

        return embed
|
|
|
|
|
class LayerNorm(nn.Module):
    """LayerNorm whose application is gated by the owning model's `use_ln` flag.

    The owner is passed wrapped in a one-element list (presumably so the
    reference can be installed without nn.Module treating the parent as a
    submodule -- confirm with callers).
    """

    def __init__(self, d_model, epsilon=1e-4, model=None):
        """
        Args:
            d_model: width of the normalized (last) dimension.
            epsilon: numerical-stability constant added to the std.
            model: one-element list holding the owning model; defaults to [None].
        """
        super().__init__()
        # BUG FIX: the default was the mutable literal [None], shared across
        # every instance constructed without an explicit model. Use a None
        # sentinel and build a fresh list per instance instead.
        self.model = model if model is not None else [None]

        self.w_ln = nn.Parameter(torch.ones(d_model))   # scale
        self.b_ln = nn.Parameter(torch.zeros(d_model))  # shift
        self.epsilon = epsilon

    def forward(self, x):
        # Normalization is skipped entirely when the owning model disables it.
        if self.model[0].use_ln:
            x = x - x.mean(axis=-1)[..., None]
            # NOTE: epsilon is added to the std (not variance), and torch's
            # std() here is the unbiased estimator -- both differ from
            # nn.LayerNorm; preserved deliberately.
            x = x / (x.std(axis=-1)[..., None] + self.epsilon)
            x = x * self.w_ln
            x = x + self.b_ln
            return x
        else:
            return x
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Single hidden layer (W_in -> activation -> W_out) with activation hooks.

    Weights are initialized either randomly or sparsely in a Fourier basis
    ('single-freq'), where each hidden neuron reads/writes a single frequency.
    """

    def __init__(self, d_model, d_mlp, d_vocab, act_type, model, init_type='random', init_scale=0.1):
        """
        Args:
            d_model: input width.
            d_mlp: hidden width.
            d_vocab: output width.
            act_type: 'ReLU' | 'GeLU' | 'Quad' | 'Id', or a callable applied elementwise.
            model: one-element list holding the owning model.
            init_type: 'random' or 'single-freq'.
            init_scale: multiplier applied to the initial weights.

        Raises:
            ValueError: on an unknown init_type or non-callable, non-string act_type.
        """
        super().__init__()
        self.model = model
        self.init_type = init_type
        self.init_scale = init_scale

        if init_type == 'random':
            self.W_in = nn.Parameter(self.init_scale * torch.randn(d_mlp, d_model)/np.sqrt(d_model))
            self.W_out = nn.Parameter(self.init_scale * torch.randn(d_vocab, d_mlp)/np.sqrt(d_model))
        elif init_type == 'single-freq':
            # Give each neuron one frequency and initialize its weights inside
            # that frequency's (cos, sin) subspace of the Fourier basis.
            # NOTE(review): this path appears to assume d_model == d_vocab --
            # sparse_initialization uses d_model columns but is multiplied by
            # a d_vocab-sized Fourier basis; confirm with callers.
            freq_num = (d_vocab-1)//2
            init_freq = decide_frequencies(d_mlp, d_model, freq_num)
            fourier_basis, _ = get_fourier_basis(d_vocab)

            self.W_in = nn.Parameter(self.init_scale * np.sqrt(d_vocab/2) * sparse_initialization(d_mlp, d_model, init_freq) @ fourier_basis)
            self.W_out = nn.Parameter(self.init_scale * np.sqrt(d_vocab/2) * fourier_basis.T @ sparse_initialization(d_mlp, d_model, init_freq).T)
        else:
            # BUG FIX: message previously rendered as "Invalid init_type: ini<value>".
            raise ValueError(f"Invalid init_type: {init_type}. Must be 'random' or 'single-freq'")

        self.act_type = act_type
        self.hook_pre = HookPoint()   # pre-activation
        self.hook_post = HookPoint()  # post-activation

        # Validate act_type up front so bad configs fail at construction time.
        if isinstance(act_type, str):
            assert act_type in ['ReLU', 'GeLU', 'Quad', 'Id'], f"Invalid activation type: {act_type}"
        elif not callable(act_type):
            raise ValueError("act_type must be either a string ('ReLU', 'GeLU', 'Quad', 'Id') or a callable function")

        # Keep the Fourier basis as a buffer so it follows .to()/.cuda().
        fourier_basis, _ = get_fourier_basis(d_vocab)
        self.register_buffer('basis', fourier_basis.clone().detach())

    def forward(self, x):
        """(batch, pos, d_model) -> (batch, pos, d_vocab)."""
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x))

        # Custom callables take precedence over the named activations.
        if callable(self.act_type):
            x = self.act_type(x)
        elif self.act_type == 'ReLU':
            x = F.relu(x)
        elif self.act_type == 'GeLU':
            x = F.gelu(x)
        elif self.act_type == "Quad":
            x = torch.square(x)
        elif self.act_type == "Id":
            x = x

        x = self.hook_post(x)

        x = torch.einsum('dm,bpm->bpd', self.W_out, x)
        return x
|
|
|
|
|
class EmbedMLP(nn.Module):
    """Embed a token pair, run it through an MLP, and return per-vocab logits.

    Also provides the hook-point plumbing used to cache intermediate
    activations (and optionally their gradients) during a forward/backward pass.
    """

    def __init__(self, d_vocab, d_model, d_mlp, act_type, use_cache=False, use_ln=True, init_type='random', init_scale=0.1, embed_type='one_hot'):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache
        self.init_type = init_type

        self.embed = Embed(d_vocab, d_model, embed_type=embed_type)
        # The model hands itself to submodules inside a one-element list
        # (presumably to avoid nn.Module registering a parent cycle -- confirm).
        self.mlp = MLP(d_model, d_mlp, d_vocab, act_type, model=[self], init_type=init_type, init_scale=init_scale)

        self.use_ln = use_ln

        # Name every HookPoint after its module path (e.g. 'mlp.hook_pre') so
        # cached activations are keyed consistently.
        for name, module in self.named_modules():
            # Idiom fix: isinstance instead of exact type() comparison.
            if isinstance(module, HookPoint):
                module.give_name(name)

    def forward(self, x):
        """(batch, 2) token pairs -> (batch, d_vocab) logits."""
        x = self.embed(x)
        x = self.mlp(x)
        # Drop the singleton position axis introduced by Embed.
        return x.squeeze(1)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """All HookPoint submodules (matched by 'hook' in their module path)."""
        return [module for name, module in self.named_modules() if 'hook' in name]

    def remove_all_hooks(self):
        """Remove every forward and backward hook from every hook point."""
        for hp in self.hook_points():
            hp.remove_hooks('fwd')
            hp.remove_hooks('bwd')

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that store every activation in `cache`, keyed by
        hook-point name; gradients (if incl_bwd) get a '_grad' suffix."""
        def save_hook(tensor, name):
            cache[name] = tensor.detach()
        def save_hook_back(tensor, name):
            # Backward hooks receive a tuple of grads; keep the first entry.
            cache[name + '_grad'] = tensor[0].detach()
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')
|
|
|
|
|
|
|
|
|
|
|
def get_fourier_basis(p):
    """Return an orthonormal Fourier basis over Z_p.

    Rows are: the constant vector, then (cos, sin) pairs for each frequency
    below p/2, and -- for even p -- a final cos row at the Nyquist frequency
    p/2 (whose sin counterpart is identically zero).

    Args:
        p: modulus / vector length.

    Returns:
        (basis, names): a (p, p) tensor with orthonormal rows, and a list of
        p human-readable row names ('Const', 'cos 1', 'sin 1', ...).
    """
    fourier_basis = []
    fourier_basis_names = []

    # Frequency-0 (constant) component.
    fourier_basis.append(torch.ones(p) / np.sqrt(p))
    fourier_basis_names.append('Const')

    # BUG FIX: iterate only over frequencies strictly below p/2. The previous
    # loop ran to p//2 inclusive, so for even p it produced an all-zero sin
    # row (normalized to NaN) and a duplicated Nyquist cos row, giving p+2
    # rows instead of p. Odd p is unaffected by this change.
    for i in range(1, (p + 1) // 2):
        cosine = torch.cos(2 * torch.pi * torch.arange(p) * i / p)
        sine = torch.sin(2 * torch.pi * torch.arange(p) * i / p)

        cosine /= cosine.norm()
        sine /= sine.norm()

        fourier_basis.append(cosine)
        fourier_basis.append(sine)
        fourier_basis_names.append(f'cos {i}')
        fourier_basis_names.append(f'sin {i}')

    # Nyquist frequency for even p: alternating +1/-1, no sin counterpart.
    if p % 2 == 0:
        cosine = torch.cos(torch.pi * torch.arange(p))
        cosine /= cosine.norm()
        fourier_basis.append(cosine)
        fourier_basis_names.append(f'cos {p // 2}')

    fourier_basis = torch.stack(fourier_basis, dim=0)

    return fourier_basis, fourier_basis_names
|
|
|
|
|
def decide_frequencies(d_mlp, d_model, freq_num):
    """Assign a Fourier frequency to each of d_mlp neurons.

    Samples `freq_num` distinct frequencies uniformly from [1, (d_model-1)//2]
    and spreads them over the neurons as evenly as possible (counts differ by
    at most one), in shuffled order.

    Args:
        d_mlp (int): number of neurons to assign.
        d_model (int): weight-matrix width; bounds the usable frequencies.
        freq_num (int): how many distinct frequencies to draw.

    Returns:
        np.ndarray: length-d_mlp array giving each neuron's frequency.

    Raises:
        ValueError: if freq_num exceeds the number of available frequencies.
    """
    max_freq = (d_model - 1) // 2
    if freq_num > max_freq:
        raise ValueError(f"freq_num ({freq_num}) cannot exceed the number of available frequencies ({max_freq}).")

    # Draw the distinct frequencies without replacement.
    chosen = np.random.choice(np.arange(1, max_freq + 1), size=freq_num, replace=False)

    # Cycle them out to length d_mlp, then shuffle away the cyclic pattern.
    assignments = np.resize(chosen, d_mlp)
    np.random.shuffle(assignments)

    return assignments
|
|
|
|
|
def sparse_initialization(d_mlp, d_model, freq_assignments):
    """Build a (d_mlp, d_model) matrix that is nonzero only in each row's
    assigned-frequency column pair.

    Row i with frequency f receives a random unit-norm 2-vector in columns
    (2f-1, 2f); all other entries stay zero.

    Args:
        d_mlp (int): number of rows (neurons).
        d_model (int): number of columns.
        freq_assignments (np.ndarray): per-row frequency, length d_mlp.

    Returns:
        torch.Tensor: the sparse (d_mlp, d_model) weight matrix.

    Raises:
        IndexError: if a frequency's column pair falls outside d_model.
    """
    weight = torch.zeros(d_mlp, d_model)

    for row, freq in enumerate(freq_assignments):
        cos_col, sin_col = 2 * freq - 1, 2 * freq

        if sin_col >= d_model:
            raise IndexError(f"Computed column index {sin_col} is out of bounds for d_model={d_model}.")

        # Random direction on the unit circle within this frequency's subspace.
        direction = torch.randn(2, device=weight.device, dtype=weight.dtype)
        direction = direction / torch.norm(direction, p=2)

        weight[row, cos_col] = direction[0]
        weight[row, sin_col] = direction[1]

    return weight