| | """ |
| | Learnable linear attention feature map classes and functions |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
| | def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict): |
| | """ |
| | Initialize feature map final activation for linear attention |
| | """ |
| | return FeatureMap(activation_name=name, mlp=mlp, **kwargs) |
| |
|
| |
|
| | def init_feature_map_act(name: str, fullspace: bool = True, **kwargs): |
| | """ |
| | Initialize feature map final activation for linear attention |
| | """ |
| | if name == 'softmax_dim' and fullspace: |
| | return SoftmaxDim(**kwargs) |
| | elif name == 'softmax_dim' and not fullspace: |
| | return SoftmaxDimHalfspace(**kwargs) |
| | elif name == 'exp_dim' and fullspace: |
| | return Exp(**kwargs) |
| | elif name == 'exp_dim' and not fullspace: |
| | return ExpHalfspace(**kwargs) |
| | elif name == 'pos_elu': |
| | return PosELU(**kwargs) |
| | elif name == 'relu': |
| | return ReLU(**kwargs) |
| | |
| | else: |
| | raise NotImplementedError |
| |
|
| |
|
| | def init_learned_kernel(name: str, **kwargs: any): |
| | """ |
| | Initialize feature map MLP for linear attention |
| | """ |
| | if name == 'untied_head_einsum': |
| | return FeatureMapMLP(**kwargs) |
| | elif name == 'untied_head_adapter': |
| | return FeatureMapAdapter(**kwargs) |
| | else: |
| | raise NotImplementedError |
| |
|
| |
|
| | class FeatureMap(nn.Module): |
| | """ |
| | Final 'activation' of feature map. Can probably be combined with |
| | `FeatureMapMLP` below |
| | |
| | Full feature map is like f(xW + b) |
| | -> This is the `f` part |
| | """ |
| | def __init__(self, |
| | activation_name: str, |
| | head_dim_idx: int = -1, |
| | eps: float = 1e-12, |
| | mlp: nn.Module = None, |
| | fullspace: bool = True,): |
| | super().__init__() |
| | self.head_dim_idx = head_dim_idx |
| | self.eps = eps |
| | self.mlp = mlp if mlp is not None else nn.Identity() |
| | self.activation = init_feature_map_act(activation_name, fullspace, eps=eps) |
| | |
| | def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any): |
| | """ |
| | Assume x.shape is (batch_size, n_heads, seq_len, head_dim) |
| | """ |
| | return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x) |
| |
|
| | def q_map(self, *args: any, **kwargs: any): |
| | """ |
| | Use for inference in case q and k feature maps differ |
| | """ |
| | return self.forward(*args, **kwargs) |
| |
|
| | def k_map(self, *args: any, **kwargs: any): |
| | """ |
| | Use for inference in case q and k feature maps differ |
| | """ |
| | return self.forward(*args, **kwargs) |
| |
|
| |
|
| | |
| | |
| | |
| | class FeatureMapAct(nn.Module): |
| | """ |
| | Base class for feature map activations |
| | """ |
| | def __init__(self, eps: float = 1e-12): |
| | super().__init__() |
| | self.eps = eps |
| |
|
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | """ |
| | x.shape is (batch_size, n_heads, seq_len, head_dim) |
| | """ |
| | return x |
| |
|
| |
|
| | class PosELU(FeatureMapAct): |
| | """ |
| | 1 + ELU activation as in https://arxiv.org/abs/2006.16236 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | return (1 + F.elu(x)).clamp(min=self.eps) |
| |
|
| |
|
| | class ReLU(FeatureMapAct): |
| | """ |
| | ReLU activation as in https://arxiv.org/abs/2103.13076 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | return F.relu(x).clamp(min=self.eps) |
| |
|
| |
|
| | class SoftmaxDim(FeatureMapAct): |
| | """ |
| | Softmax activation as in https://arxiv.org/abs/2402.04347 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | return torch.cat([ |
| | torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) |
| | ], dim=-1).clamp(min=self.eps) |
| |
|
| | |
| | class SoftmaxDimHalfspace(FeatureMapAct): |
| | """ |
| | Softmax activation as in https://arxiv.org/abs/2402.04347 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | return torch.softmax(x, dim=-1).clamp(min=self.eps) |
| |
|
| |
|
| | class Exp(FeatureMapAct): |
| | """ |
| | Exp activation as in https://arxiv.org/abs/2402.04347 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | x_max = torch.amax(x, dim=-1, keepdim=True) |
| | x_min = torch.amin(x, dim=-1, keepdim=True) |
| | return torch.cat([ |
| | torch.exp(x - x_max), torch.exp(-x + x_min) |
| | ], dim=-1).clamp(min=self.eps) |
| |
|
| |
|
| | class ExpHalfspace(FeatureMapAct): |
| | """ |
| | Exp activation as in https://arxiv.org/abs/2402.04347 |
| | """ |
| | def forward(self, x: torch.Tensor, *args: any, **kwargs: any): |
| | x_max = torch.amax(x, dim=-1, keepdim=True) |
| | return torch.exp(x - x_max).clamp(min=self.eps) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | class FeatureMapMLP(nn.Module): |
| | """ |
| | Learnable MLP in feature map. |
| | |
| | Full feature map is like f(xW + b) |
| | -> This is the `W` and (optional) `b` part |
| | """ |
| | def __init__(self, |
| | num_heads: int, |
| | head_dim: int, |
| | feature_dim: int, |
| | dtype: torch.dtype, |
| | device: torch.device, |
| | skip_connection: bool = False, |
| | bias: bool = False, |
| | zero_init: bool = False, |
| | normal_init: bool = False,): |
| | super().__init__() |
| | self.num_heads = num_heads |
| | self.head_dim = head_dim |
| | self.feature_dim = feature_dim |
| | self.dtype = dtype |
| | self.device = device |
| | self.skip_connection = skip_connection |
| | self.bias = bias |
| | self.zero_init = zero_init |
| | self.normal_init = normal_init |
| | self.init_weights_() |
| |
|
| | if self.zero_init: |
| | self.zero_init_with_skip_() if self.skip_connection else self.zero_init_() |
| |
|
| | if self.normal_init: |
| | with torch.no_grad(): |
| | nn.init.normal_(self.layer) |
| | |
| | if self.skip_connection: |
| | assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}' |
| | assert self.head_dim == self.feature_dim, assertion_fail |
| |
|
| | def init_weights_(self): |
| | """ |
| | Initialize (W)eights and (b)iases |
| | """ |
| | self.layer = nn.Parameter(torch.zeros( |
| | (self.num_heads, self.head_dim, self.feature_dim), |
| | dtype=self.dtype, device=self.device, |
| | )) |
| | nn.init.kaiming_uniform_(self.layer) |
| |
|
| | if self.bias: |
| | self.bias = nn.Parameter(torch.zeros( |
| | (1, self.num_heads, 1, 1), |
| | dtype=self.dtype, device=self.device, |
| | )) |
| | nn.init.kaiming_uniform_(self.bias) |
| | else: |
| | self.bias = 0. |
| |
|
| | def zero_init_with_skip_(self): |
| | """ |
| | Initialize weights to zero matrix if skip connection |
| | """ |
| | with torch.no_grad(): |
| | nn.init.zeros_(self.layer) |
| |
|
| | def zero_init_(self): |
| | """ |
| | Initialize weights to identity matrix if no skip connection |
| | """ |
| | with torch.no_grad(): |
| | for i in range(self.layer.shape[0]): |
| | try: |
| | nn.init.eye_(self.layer[i]) |
| | except RuntimeError: |
| | with torch.no_grad(): |
| | dtype = self.layer[i].dtype |
| | weight = torch.eye(*self.layer[i].shape, |
| | requires_grad=self.layer[i].requires_grad, |
| | device=self.layer[i].device) |
| | self.layer[i] = weight.to(dtype=dtype) |
| |
|
| | def forward(self, x: torch.Tensor): |
| | """ |
| | Assume x.shape is (batch_size, num_heads, seq_len, head_dim) |
| | """ |
| | _x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias |
| | return x + _x if self.skip_connection else _x |
| |
|
| |
|
| | class FeatureMapAdapter(FeatureMapMLP): |
| | """ |
| | Learnable Feature map with bottleneck adapter |
| | as in https://arxiv.org/abs/1902.00751 |
| | |
| | We don't use but could be fun to try |
| | """ |
| | def __init__(self, hidden_dim: int, *args, **kwargs): |
| | kwargs['skip_connection'] = True |
| | kwargs['bias'] = True |
| | kwargs['zero_init'] = True |
| | self.hidden_dim = hidden_dim |
| | super().__init__(*args, **kwargs) |
| | |
| | def init_weights_(self): |
| | """ |
| | Initialize (W)eights and (b)iases |
| | """ |
| | kwargs = {'dtype': self.dtype, 'device': self.device} |
| | self.layer0 = nn.Parameter( |
| | torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs) |
| | ) |
| | self.layer1 = nn.Parameter( |
| | torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs) |
| | ) |
| | nn.init.kaiming_uniform_(self.layer0) |
| | nn.init.kaiming_uniform_(self.layer1) |
| |
|
| | self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)) |
| | self.bias1 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)) |
| | nn.init.kaiming_uniform_(self.bias0) |
| | nn.init.kaiming_uniform_(self.bias1) |
| |
|
| | def zero_init_with_skip_(self): |
| | with torch.no_grad(): |
| | nn.init.zeros_(self.layer0) |
| | nn.init.zeros_(self.layer1) |
| | nn.init.zeros_(self.bias0) |
| | nn.init.zeros_(self.bias1) |
| |
|
| | def zero_init_(self): |
| | assert NotImplementedError |
| |
|
| | def forward(self, x: torch.Tensor): |
| | """ |
| | Assume x.shape is (batch_size, num_heads, seq_len, head_dim) |
| | -> Down-project, apply nonlinearity, up-project; add skip connection |
| | """ |
| | _x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0 |
| | _x = F.relu(_x) |
| | _x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1 |
| | return x + _x if self.skip_connection else _x |
| |
|