| |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
|
|
class SelfAttention(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
        """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V

        :param int input_size: Feature input size of Q, K, V.
        :param int output_size: Feature -hidden- size of Q, K, V (split evenly across the heads).
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param int heads: Number of heads for the attention module.
        :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative].
        :raises ValueError: If `heads` is not positive or does not evenly divide `output_size`.
        """
        super(SelfAttention, self).__init__()

        self.permitted_encodings = ["absolute", "relative"]
        if pos_enc is not None:
            pos_enc = pos_enc.lower()
            assert pos_enc in self.permitted_encodings, f"Supported encodings: {*self.permitted_encodings,}"

        # Each head projects to output_size // heads features and the heads are concatenated
        # before `self.out` (which expects exactly output_size features). A non-divisible
        # configuration could be constructed but would always fail inside forward().
        if heads < 1 or output_size % heads != 0:
            raise ValueError(
                f"heads must be a positive divisor of output_size (got heads={heads}, output_size={output_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq

        # NOTE: the layers are created in the order Wk, Wq, Wv per head so that seeded
        # initialisation stays reproducible.
        self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            self.Wk.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
            self.Wq.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
            self.Wv.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
        self.out = nn.Linear(in_features=output_size, out_features=input_size, bias=False)

        self.softmax = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(p=0.5)

    def getAbsolutePosition(self, T):
        """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
        Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)

        Entry [pos, c] is sin(pos / freq^(2i / input_size)) for even columns c = 2i and
        cos(pos / freq^(2i / input_size)) for odd columns c = 2i + 1. Every column is filled,
        including the last (even-indexed) one when T is odd.

        :param int T: Number of frames contained in Q, K and V
        :return: Tensor with shape [T, T]
        """
        device = self.out.weight.device

        pos = torch.arange(T, device=device).unsqueeze(1)  # [T, 1] frame positions
        col = torch.arange(T, device=device).unsqueeze(0)  # [1, T] encoding columns
        # Columns 2i and 2i+1 share the exponent 2i, i.e. the column index rounded down to even.
        paired = col - col % 2
        angles = pos / self.freq ** (paired / self.input_size)  # broadcast to [T, T]

        AP = torch.zeros(T, T, device=device)
        AP[:, 0::2] = torch.sin(angles[:, 0::2])
        AP[:, 1::2] = torch.cos(angles[:, 1::2])
        return AP

    def getRelativePosition(self, T):
        """Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
        r_pos calculations as here: https://theaisummer.com/positional-embeddings/

        Entry [i, j] uses the shifted relative position r_pos = j - i + (T - 1) (so that it is
        never negative) and the angle r_pos / freq^((i + j) / (2 * T)); even columns hold its
        sine and odd columns its cosine. Every column is filled, including the last
        (even-indexed) one when T is odd.

        :param int T: Number of frames contained in Q, K and V
        :return: Tensor with shape [T, T]
        """
        device = self.out.weight.device
        d = 2 * T
        min_rpos = -(T - 1)

        i = torch.arange(T, device=device).unsqueeze(1)  # [T, 1] row (query) index
        j = torch.arange(T, device=device).unsqueeze(0)  # [1, T] column (key) index

        # Shift by the minimum relative position so that r_pos lies in [0, 2T - 2].
        r_pos = j - i - min_rpos
        angles = r_pos / self.freq ** ((i + j) / d)

        RP = torch.zeros(T, T, device=device)
        RP[:, 0::2] = torch.sin(angles[:, 0::2])
        RP[:, 1::2] = torch.cos(angles[:, 1::2])
        return RP

    def forward(self, x):
        """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism.

        :param torch.tensor x: Frame features with shape [T, input_size]
        :return: A tuple of:
                    y: Weighted features based on the attention weights, with shape [T, input_size]
                    att_weights : The attention weights (before dropout) of the last head, with shape [T, T]
        """
        # The positional encoding depends only on T, so compute it once and share it across heads.
        T = x.shape[0]
        if self.pos_enc == "absolute":
            pos_bias = self.getAbsolutePosition(T=T)
        elif self.pos_enc == "relative":
            pos_bias = self.getRelativePosition(T=T)
        else:
            pos_bias = None

        outputs = []
        for Wk, Wq, Wv in zip(self.Wk, self.Wq, self.Wv):
            K = Wk(x)
            Q = Wq(x)
            V = Wv(x)

            # Pairwise frame similarities, shape [T, T].
            energies = torch.matmul(Q, K.transpose(1, 0))
            if pos_bias is not None:
                energies = energies + pos_bias

            att_weights = self.softmax(energies)
            # Dropout is applied only to the weights used for aggregation; the returned
            # weights stay un-dropped.
            _att_weights = self.drop(att_weights)
            outputs.append(torch.matmul(_att_weights, V))

        # Concatenate the per-head features and project back to input_size.
        y = self.out(torch.cat(outputs, dim=1))
        return y, att_weights.clone()
|
|
|
|
if __name__ == '__main__':
    # Running this module as a script is intentionally a no-op. The string literal below is
    # never executed; it keeps a small proof of concept that can be enabled by removing the
    # surrounding triple quotes (as written it moves the model and input to CUDA).
    """Uncomment for a quick proof of concept
    model = SelfAttention(input_size=256, output_size=256, pos_enc="absolute").cuda()
    _input = torch.randn(500, 256).cuda() # [seq_len, hidden_size]
    output, weights = model(_input)
    print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
    """
|
|