| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class Proj(nn.Module):
    """Linear projection whose weight can be rescaled inside learned subspaces.

    For the expert whose id matches ``self.id``, and for every leading entry
    of ``scaling_mask`` that is true, the weight ``W`` passed to ``forward``
    is replaced by::

        W <- W + W @ S @ (A @ A.T - I) @ S.T

    where ``S`` is the matching entry of ``self.space``, ``A`` is the
    top-left ``r x r`` crop of the matching ``scale_param`` (``r`` being
    ``S.shape[1]``) and ``I`` is the same crop of ``self.eye``.  While ``A``
    is the identity the update is a no-op.

    Attributes:
        eye: ``d_model x d_model`` identity matrix.
        space: two placeholder tensors; callers are expected to overwrite
            them with 2-D matrices before switching the matching mask on
            (presumably of shape ``(d_model, r)`` -- confirm with the code
            that fills them in).
        scale_param: two independent ``d_model x d_model`` parameters, each
            initialised to the identity.
        scaling_mask: per-subspace on/off flags, all off initially.
        id: expert id this module applies to.
    """

    def __init__(self,
                 d_model=None,
                 id=-1):
        """Build the module.

        Args:
            d_model: side length of the identity / scale matrices.  Must be
                an integer; the ``None`` default is rejected by ``torch.eye``.
            id: expert id for which the rescaling is applied in ``forward``.
        """
        super().__init__()

        # NOTE(review): registered as a Parameter, so it requires grad and is
        # returned by ``parameters()``; it will be trained unless the caller
        # freezes or excludes it.  Kept as-is to preserve the state-dict and
        # parameter layout.
        self.eye = nn.Parameter(torch.eye(d_model))

        # Plain Python list: these tensors are not registered with the module
        # and therefore are not moved by ``.to()`` / ``.cuda()``.
        self.space = [torch.tensor(1), torch.tensor(1)]

        # Each scale matrix gets its own freshly allocated identity.  Wrapping
        # ``self.eye`` directly would make both scale matrices and ``eye``
        # share one storage, so an in-place update of one would change all.
        self.scale_param = nn.ParameterList(
            [nn.Parameter(torch.eye(d_model)) for _ in range(2)]
        )
        self.scaling_mask = [False, False]

        # Honour the constructor argument (it used to be ignored).
        self.id = id

    def forward(self, x, kv_w, expert_id):
        """Apply ``F.linear(x, W)`` with ``W`` optionally rescaled.

        Args:
            x: input passed to ``F.linear``.
            kv_w: weight matrix passed to ``F.linear``; its last dimension
                must match the row count of every active ``space`` entry.
            expert_id: id of the expert being evaluated; rescaling is only
                applied when it equals ``self.id``.

        Returns:
            ``F.linear(x, kv_w)`` when the expert does not match or no mask
            is active, otherwise ``F.linear`` with the rescaled weight.
        """
        # Other experts use the weight untouched.
        if expert_id != self.id:
            return F.linear(x, kv_w)

        weight = kv_w
        for active, scale, space in zip(self.scaling_mask,
                                        self.scale_param,
                                        self.space):
            # Subspaces are applied in order; the first inactive one stops
            # the chain, so later entries are skipped even if flagged on.
            if not active:
                break

            rank = space.shape[1]
            cropped_scale = scale[:rank, :rank]
            # A @ A.T is symmetric positive semi-definite by construction.
            gram = cropped_scale @ cropped_scale.T
            # Match the identity crop to the dtype/device of ``x``.
            identity = self.eye[:rank, :rank].to(x)

            weight = weight + weight @ space @ (gram - identity) @ space.T

        return F.linear(x, weight)
|
|
class Proj2(nn.Module):
    """Linear projection whose weight can be rescaled inside learned subspaces.

    For the expert whose id matches ``self.id``, and for every leading entry
    of ``scaling_mask`` that is true, the weight ``W`` passed to ``forward``
    is replaced by::

        W <- W + W @ S @ (A @ A.T - I) @ S.T

    where ``S`` is the matching entry of ``self.space``, ``A`` is the
    top-left ``r x r`` crop of the matching ``scale_param`` (``r`` being
    ``S.shape[1]``) and ``I`` is the same crop of ``self.eye``.  While ``A``
    is the identity the update is a no-op.

    Attributes:
        eye: ``d_model x d_model`` identity matrix.
        space: two placeholder tensors; callers are expected to overwrite
            them with 2-D matrices before switching the matching mask on
            (presumably of shape ``(d_model, r)`` -- confirm with the code
            that fills them in).
        scale_param: two independent ``d_model x d_model`` parameters, each
            initialised to the identity.
        scaling_mask: per-subspace on/off flags, all off initially.
        id: expert id this module applies to.
    """

    def __init__(self,
                 d_model=None,
                 id=-1):
        """Build the module.

        Args:
            d_model: side length of the identity / scale matrices.  Must be
                an integer; the ``None`` default is rejected by ``torch.eye``.
            id: expert id for which the rescaling is applied in ``forward``.
        """
        super().__init__()

        # NOTE(review): registered as a Parameter, so it requires grad and is
        # returned by ``parameters()``; it will be trained unless the caller
        # freezes or excludes it.  Kept as-is to preserve the state-dict and
        # parameter layout.
        self.eye = nn.Parameter(torch.eye(d_model))

        # Plain Python list: these tensors are not registered with the module
        # and therefore are not moved by ``.to()`` / ``.cuda()``.
        self.space = [torch.tensor(1), torch.tensor(1)]

        # Each scale matrix gets its own freshly allocated identity.  Wrapping
        # ``self.eye`` directly would make both scale matrices and ``eye``
        # share one storage, so an in-place update of one would change all.
        self.scale_param = nn.ParameterList(
            [nn.Parameter(torch.eye(d_model)) for _ in range(2)]
        )
        self.scaling_mask = [False, False]

        # Honour the constructor argument (it used to be ignored).
        self.id = id

    def forward(self, x, kv_w, expert_id):
        """Apply ``F.linear(x, W)`` with ``W`` optionally rescaled.

        Args:
            x: input passed to ``F.linear``.
            kv_w: weight matrix passed to ``F.linear``; its last dimension
                must match the row count of every active ``space`` entry.
            expert_id: id of the expert being evaluated; rescaling is only
                applied when it equals ``self.id``.

        Returns:
            ``F.linear(x, kv_w)`` when the expert does not match or no mask
            is active, otherwise ``F.linear`` with the rescaled weight.
        """
        # Other experts use the weight untouched.
        if expert_id != self.id:
            return F.linear(x, kv_w)

        weight = kv_w
        for active, scale, space in zip(self.scaling_mask,
                                        self.scale_param,
                                        self.space):
            # Subspaces are applied in order; the first inactive one stops
            # the chain, so later entries are skipped even if flagged on.
            if not active:
                break

            rank = space.shape[1]
            cropped_scale = scale[:rank, :rank]
            # A @ A.T is symmetric positive semi-definite by construction.
            gram = cropped_scale @ cropped_scale.T
            # Match the identity crop to the dtype/device of ``x``.
            identity = self.eye[:rank, :rank].to(x)

            weight = weight + weight @ space @ (gram - identity) @ space.T

        return F.linear(x, weight)