boringKey's picture
Upload 236 files
5fee096 verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Proj(nn.Module):
    """Linear projection whose weight can be rescaled inside given subspaces.

    When ``expert_id == self.id`` the weight ``kv_w`` is updated once per
    enabled stage ``i``. Stages run in order, and processing stops at the
    first ``scaling_mask[i]`` that is False::

        W <- W + W @ V_i @ (S_i @ S_i.T - I) @ V_i.T

    Here ``V_i = self.space[i]`` has shape ``(in_features, r_i)`` and ``S_i``
    is the top-left ``r_i x r_i`` crop of ``self.scale_param[i]``. For any
    other ``expert_id`` the plain ``F.linear(x, kv_w)`` is returned.

    ``space`` and ``scaling_mask`` start as placeholders with every stage
    disabled. Presumably code outside this class fills them in — TODO confirm.
    """

    def __init__(self,
                 d_model=None,
                 id=-1):
        """
        Args:
            d_model: side length of the square identity/scale matrices. It
                must be >= the rank ``r`` (``space[i].shape[1]``) of every
                subspace that gets enabled, or the cropped matmul fails.
            id: expert id for which the scaling is applied (default -1).

        Raises:
            TypeError: if ``d_model`` is None.
        """
        super().__init__()
        if d_model is None:
            raise TypeError(f"{type(self).__name__} requires an integer d_model")
        # NOTE(review): registered as a trainable Parameter even though it is
        # used as a constant identity; confirm callers freeze it.
        self.eye = nn.Parameter(torch.eye(d_model))
        # 0-d placeholders. Each must be replaced by an (in_features, r)
        # basis before its scaling_mask entry is switched on, because
        # forward() reads space.shape[1].
        self.space = [torch.tensor(1), torch.tensor(1)]
        # Each scale starts at the identity, so S @ S.T - I == 0 and an
        # enabled stage is a no-op until trained. A fresh tensor per stage is
        # required: nn.Parameter(<Parameter>) aliases its input's storage, so
        # the scales would otherwise share memory with self.eye and each other.
        self.scale_param = nn.ParameterList(
            [nn.Parameter(torch.eye(d_model)) for _ in range(2)]
        )
        self.scaling_mask = [False, False]
        self.id = id

    def forward(self, x, kv_w, expert_id):
        """Apply ``kv_w``, rescaled if this expert is active, to ``x``.

        Args:
            x: input whose last dimension equals ``kv_w.shape[1]``.
            kv_w: weight of shape ``(out_features, in_features)``, as
                expected by ``F.linear``.
            expert_id: id of the active expert. Scaling only happens when it
                equals ``self.id``.

        Returns:
            ``F.linear(x, W)``, where ``W`` is ``kv_w`` after all enabled
            scaling stages (``kv_w`` itself if none apply).
        """
        if expert_id != self.id:
            return F.linear(x, kv_w)

        weight = kv_w
        for enabled, scale, space in zip(self.scaling_mask, self.scale_param, self.space):
            if not enabled:
                # Stages are cumulative: the first disabled one ends the chain.
                break
            rank = space.shape[1]
            cropped = scale[:rank, :rank]
            # S @ S.T is symmetric positive semi-definite. The original author
            # observed empirically that it works better than S alone.
            sym_scale = cropped @ cropped.T
            identity = self.eye[:rank, :rank].to(x)
            # If space has orthonormal columns (presumably — TODO confirm),
            # this rescales only W's component lying in span(space).
            weight = weight + weight @ space @ (sym_scale - identity) @ space.T
        return F.linear(x, weight)
class Proj2(nn.Module):
    """Linear projection whose weight can be rescaled inside given subspaces.

    When ``expert_id == self.id`` the weight ``kv_w`` is updated once per
    enabled stage ``i``. Stages run in order, and processing stops at the
    first ``scaling_mask[i]`` that is False::

        W <- W + W @ V_i @ (S_i @ S_i.T - I) @ V_i.T

    Here ``V_i = self.space[i]`` has shape ``(in_features, r_i)`` and ``S_i``
    is the top-left ``r_i x r_i`` crop of ``self.scale_param[i]``. For any
    other ``expert_id`` the plain ``F.linear(x, kv_w)`` is returned.

    ``space`` and ``scaling_mask`` start as placeholders with every stage
    disabled. Presumably code outside this class fills them in — TODO confirm.
    """

    def __init__(self,
                 d_model=None,
                 id=-1):
        """
        Args:
            d_model: side length of the square identity/scale matrices. It
                must be >= the rank ``r`` (``space[i].shape[1]``) of every
                subspace that gets enabled, or the cropped matmul fails.
            id: expert id for which the scaling is applied (default -1).

        Raises:
            TypeError: if ``d_model`` is None.
        """
        super().__init__()
        if d_model is None:
            raise TypeError(f"{type(self).__name__} requires an integer d_model")
        # NOTE(review): registered as a trainable Parameter even though it is
        # used as a constant identity; confirm callers freeze it.
        self.eye = nn.Parameter(torch.eye(d_model))
        # 0-d placeholders. Each must be replaced by an (in_features, r)
        # basis before its scaling_mask entry is switched on, because
        # forward() reads space.shape[1].
        self.space = [torch.tensor(1), torch.tensor(1)]
        # Each scale starts at the identity, so S @ S.T - I == 0 and an
        # enabled stage is a no-op until trained. A fresh tensor per stage is
        # required: nn.Parameter(<Parameter>) aliases its input's storage, so
        # the scales would otherwise share memory with self.eye and each other.
        self.scale_param = nn.ParameterList(
            [nn.Parameter(torch.eye(d_model)) for _ in range(2)]
        )
        self.scaling_mask = [False, False]
        self.id = id

    def forward(self, x, kv_w, expert_id):
        """Apply ``kv_w``, rescaled if this expert is active, to ``x``.

        Args:
            x: input whose last dimension equals ``kv_w.shape[1]``.
            kv_w: weight of shape ``(out_features, in_features)``, as
                expected by ``F.linear``.
            expert_id: id of the active expert. Scaling only happens when it
                equals ``self.id``.

        Returns:
            ``F.linear(x, W)``, where ``W`` is ``kv_w`` after all enabled
            scaling stages (``kv_w`` itself if none apply).
        """
        if expert_id != self.id:
            return F.linear(x, kv_w)

        weight = kv_w
        for enabled, scale, space in zip(self.scaling_mask, self.scale_param, self.space):
            if not enabled:
                # Stages are cumulative: the first disabled one ends the chain.
                break
            rank = space.shape[1]
            cropped = scale[:rank, :rank]
            # S @ S.T is symmetric positive semi-definite. The original author
            # observed empirically that it works better than S alone.
            sym_scale = cropped @ cropped.T
            identity = self.eye[:rank, :rank].to(x)
            # If space has orthonormal columns (presumably — TODO confirm),
            # this rescales only W's component lying in span(space).
            weight = weight + weight @ space @ (sym_scale - identity) @ space.T
        return F.linear(x, weight)