| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Multi-Head Attention layer definition.""" |
| |
|
| | import math |
| | from typing import Tuple |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| |
|
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0, (
            f"n_feat ({n_feat}) must be divisible by n_head ({n_head})")
        # Per-head width; queries, keys and values all use the same width.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        # Project, then split the feature axis into (head, d_k).
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        # Move the head axis in front of the time axis.
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
                Entries equal to 0 mark positions that must not be
                attended to.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # A mask whose last dimension is empty is the "no mask" sentinel
        # (see the default argument); only a real mask is applied.
        if mask.size(2) > 0:
            # Add a head axis and invert: True now marks masked positions.
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # Trim the mask in case it is longer than the key axis of scores.
            mask = mask[:, :, :, :scores.size(-1)]
            scores = scores.masked_fill(mask, -float('inf'))
            # Zero the masked weights after softmax as well: a row that is
            # entirely -inf produces NaN, which this fill replaces with 0.
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        # Merge the heads back into a single feature axis.
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            pos_emb (torch.Tensor): Not used by this layer; accepted so the
                signature matches RelPositionMultiHeadedAttention.forward.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`. An empty first dimension means
                there is no cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # The cache stores keys and values side by side on the last axis.
        # When one is supplied, split it in half and prepend the cached
        # keys/values along the time axis (dim 2).
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # The returned cache is always rebuilt from the (possibly extended)
        # keys and values, even when no cache was passed in.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
| |
|
| |
|
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # Linear transformation for positional encoding (no bias).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Two learnable per-head biases added to the query: `pos_bias_u` for
        # the query-key product and `pos_bias_v` for the query-position
        # product. Allocated uninitialised, then filled by Xavier init.
        self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor; the last axis is cut to
                `x.size(-1) // 2 + 1` entries, i.e. (batch, head, time1,
                time1) for the input shape above.

        """
        n_batch, n_head, n_query, n_pos = (x.size(0), x.size(1), x.size(2),
                                           x.size(3))
        # Prepend one zero column, reinterpret the padded buffer with the
        # last two axes swapped in size, and drop the first row: reading
        # the result back in the original shape shifts each query row by
        # a different offset.
        zero_pad = torch.zeros((n_batch, n_head, n_query, 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(n_batch, n_head, n_pos + 1, n_query)
        # Keep only the first half (plus centre) of the shifted positions.
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : n_pos // 2 + 1
        ]
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size). It is always projected by
                `linear_pos`, so a real embedding has to be passed. If the
                resulting position scores differ in shape from the
                query-key scores they go through `rel_shift`, which is
                documented for 2*time1-1 positions.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Put the head axis next to d_k so the (head, d_k) biases broadcast.
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # The cache stores keys and values side by side on the last axis.
        # When one is supplied, split it in half and prepend the cached
        # keys/values along the time axis (dim 2).
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # The returned cache is always rebuilt from the (possibly extended)
        # keys and values, even when no cache was passed in.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch_pos, head, n_pos, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # Query-key term: (batch, head, time1, time2).
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Query-position term: (batch, head, time1, n_pos).
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # Only realign the position term when it does not already line up
        # with the query-key term.
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache