| |
| |
|
|
| |
| |
|
|
| """Multi-Head Attention layer definition.""" |
|
|
| import math |
|
|
| import numpy |
| import torch |
| from torch import nn |
|
|
|
|
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features (must be divisible by ``n_head``).
        dropout_rate (float): Dropout rate applied to the attention weights.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # Per-head feature size; all heads together span n_feat features.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights from the most recent ``forward_attention`` call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        # Use the fused PyTorch kernel when this PyTorch build provides it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        # Move the head axis in front of the time axis.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor or None): Mask (#batch, 1, time2) or
                (#batch, time1, time2); non-zero/True marks positions to keep.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            # True where attention is NOT allowed; broadcast over heads.
            mask = mask.unsqueeze(1).eq(0)
            # torch.finfo covers every floating dtype, including bfloat16,
            # which numpy cannot represent.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            # Zero masked weights so a fully masked row yields 0 rather than
            # the uniform distribution softmax gives over equal scores.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            self.attn = torch.softmax(scores, dim=-1)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (#batch, n_head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linear_out(x)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor or None): 3-D mask; non-zero/True marks
                positions to keep. It is combined with ``mask[:, None, :, 0]``
                before use, so a (#batch, time, 1) padding mask becomes the
                (#batch, time, time) outer product. ``None`` disables masking.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        n_batch = q.size(0)

        attn_mask = None
        empty_rows = None
        if mask is not None:
            # Boolean keep-mask (non-zero == keep), the same rule
            # forward_attention uses; a float mask would otherwise be *added*
            # to the scores by scaled_dot_product_attention.
            mask = mask != 0
            mask = mask & mask[:, None, :, 0]
            attn_mask = mask[:, None]  # broadcast over heads
            # Rows with no attendable key would make softmax produce NaN.
            # Let them attend everywhere here and zero their output below,
            # matching forward_attention's result for such rows.
            empty_rows = ~attn_mask.any(dim=-1, keepdim=True)
            attn_mask = attn_mask | empty_rows

        if self.flash:
            # SDPA applies the 1/sqrt(d_k) scaling itself, so q is not
            # pre-scaled here.
            dropout_p = self.dropout.p if self.training else 0.0
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False
            )
        else:
            x = self.slow_attn(q, k, v, is_causal=False, attn_mask=attn_mask)

        if empty_rows is not None:
            x = x.masked_fill(empty_rows, 0.0)
        x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)
        return self.linear_out(x)

    def slow_attn(self, Q, K, V, is_causal, attn_mask):
        """Non-fused scaled dot-product attention.

        Follows the semantics of ``torch.nn.functional.scaled_dot_product_attention``:
        the scores are scaled by ``1/sqrt(d_k)``, a boolean ``attn_mask`` marks
        positions that may be attended (True = keep), a floating ``attn_mask``
        is added to the scores, and ``is_causal`` applies a lower-triangular
        mask. Dropout is applied to the weights via ``self.dropout``.

        Args:
            Q (torch.Tensor): Queries (..., L, d_k).
            K (torch.Tensor): Keys (..., S, d_k).
            V (torch.Tensor): Values (..., S, d_v).
            is_causal (bool): Restrict each query to keys at or before it.
            attn_mask (torch.Tensor or None): Mask broadcastable to (..., L, S).

        Returns:
            torch.Tensor: Attention output (..., L, d_v).
        """
        scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
        if is_causal:
            causal = torch.ones(
                Q.size(-2), K.size(-2), dtype=torch.bool, device=Q.device
            ).tril()
            scores = scores.masked_fill(~causal, float('-inf'))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float('-inf'))
            else:
                scores = scores + attn_mask
        attn_weight = self.dropout(torch.softmax(scores, dim=-1))
        return attn_weight @ V
|
|
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # Projection of the positional embedding (no bias).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Per-head learnable biases added to the query: pos_bias_u for the
        # content term (query x key), pos_bias_v for the position term
        # (query x positional embedding).
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu=False):
        """Compute relative positional encoding via the pad-and-reshape shift.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).
            zero_triu (bool): If true, return the lower triangular part of the matrix.

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        # Reinterpreting the padded rows with swapped trailing sizes shifts
        # each successive row by one position; dropping the first row and
        # viewing back yields the shifted scores.
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            # Create the mask on x's device and dtype: a default CPU float32
            # tensor fails against CUDA inputs and up-casts half precision.
            ones = torch.ones((x.size(2), x.size(3)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size); its batch dimension may also be 1,
                which broadcasts in the matmul below.
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)  # (#batch, head, time, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (#batch_pos, head, time2, d_k)

        # Add the (head, d_k) biases to every time step of the query.
        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)

        # Content-based scores: (#batch, head, time1, time2).
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Position-based scores, shifted into relative alignment.
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
|
|