| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""Positional Encoding Module."""
| |
|
| | import math |
| | from typing import Tuple, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| |
|
| |
|
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal absolute positional encoding.

    The precomputed table follows::

        PE(pos, 2i)   = sin(pos / (10000^(2i / d_model)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i / d_model)))

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Probability used by the dropout layer.
        max_len (int): Number of positions precomputed in the table.
        reverse (bool): Accepted by the constructor but not used by it.
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        # Build the table locally, then store it once with a leading
        # broadcast dimension: (1, max_len, d_model).
        table = torch.zeros(self.max_len, self.d_model)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        self.pe = table.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and add the positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.Tensor): position offset

        Returns:
            torch.Tensor: ``x * xscale + pos_emb`` after dropout.
                Its shape is (batch, time, ...)
            torch.Tensor: The positional embedding after dropout, returned
                so the interface matches RelPositionalEncoding.
        """
        # ``pe`` is a plain attribute here, so it is moved to the input's
        # device explicitly on every call.
        self.pe = self.pe.to(x.device)
        # Dropout is skipped inside position_encoding and applied once below.
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Return the encoding for ``size`` positions starting at ``offset``.

        Intended for streaming use. Note that in the non-streaming path
        dropout is applied once over the whole utterance, whereas a streaming
        caller invokes this function repeatedly with growing input, so
        dropout ends up being applied several times.

        Args:
            offset (int or torch.Tensor): start offset; an int, a 0-dim
                tensor, or a tensor holding one offset per batch item.
            size (int): required size of position encoding
            apply_dropout (bool): whether to run dropout on the result.

        Returns:
            torch.Tensor: Corresponding encoding
        """
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:
            # Per-item offsets: gather one row range per batch element.
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size, device=offset.device)
            # Negative positions are mapped to position 0.
            index = torch.clamp(index, min=0)
            pos_emb = F.embedding(index, self.pe[0])

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
| |
|
| |
|
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and return it together with the position table.

        Unlike the parent class, the positional embedding is not added to
        ``x``; it is returned separately.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int, torch.Tensor): Start position passed on to
                ``position_encoding``.

        Returns:
            torch.Tensor: Scaled input after dropout (batch, time, `*`).
            torch.Tensor: Positional embedding after dropout.
        """
        self.pe = self.pe.to(x.device)
        seq_len = x.size(1)
        # Dropout is disabled here and applied exactly once in the return.
        pos_emb = self.position_encoding(offset, seq_len, False)
        scaled = x * self.xscale
        return self.dropout(scaled), self.dropout(pos_emb)
| |
|
| |
|
class WhisperPositionalEncoding(PositionalEncoding):
    """Sinusoids position encoding used in openai-whisper.encoder.

    The table concatenates all sine columns followed by all cosine columns
    (rather than interleaving them) and is stored as a registered buffer.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
        super().__init__(d_model, dropout_rate, max_len)
        # With xscale == 1.0 the ``x * xscale`` step in forward leaves the
        # input magnitude unchanged.
        self.xscale = 1.0

        half_dim = d_model // 2
        log_timescale_increment = np.log(10000) / (half_dim - 1)
        inv_timescales = torch.exp(-log_timescale_increment *
                                   torch.arange(half_dim))
        positions = torch.arange(max_len).unsqueeze(1)
        scaled_time = positions * inv_timescales.unsqueeze(0)
        table = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)),
                          dim=1)

        # The parent constructor stored ``pe`` as a plain attribute; remove
        # it before registering a buffer under the same name.
        delattr(self, "pe")
        self.register_buffer("pe", table.unsqueeze(0))
| |
|
| |
|
class LearnablePositionalEncoding(PositionalEncoding):
    """Learnable position encoding used in openai-whisper.decoder.

    The sinusoidal table built by the parent constructor is replaced by a
    trainable parameter of shape (1, max_len, d_model).
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
        super().__init__(d_model, dropout_rate, max_len)
        # No input rescaling for this variant.
        self.xscale = 1.0
        # torch.empty leaves the values uninitialised; presumably they are
        # filled in later (training or checkpoint load) - confirm with caller.
        self.pe = torch.nn.Parameter(torch.empty((1, max_len, d_model)))
| |
|
| |
|
class NoPositionalEncoding(torch.nn.Module):
    """No position encoding.

    Exposes ``forward`` and ``position_encoding`` like the other encodings
    in this file, but the positional embedding is always all zeros.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Probability used by the dropout layer.
    """

    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply dropout to the input and return a zero embedding.

        Args:
            x (torch.Tensor): Input tensor; dimension 1 is used as the
                length of the returned embedding.
            offset (int, torch.Tensor): Ignored; kept for interface
                compatibility.

        Returns:
            torch.Tensor: ``x`` after dropout.
            torch.Tensor: Zeros of shape (1, x.size(1), d_model) on the
                same device as ``x``.
        """
        # Allocate directly on the input's device instead of creating the
        # tensor on CPU and copying it over.
        pos_emb = torch.zeros(1, x.size(1), self.d_model, device=x.device)
        return self.dropout(x), pos_emb

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Return a zero embedding of shape (1, size, d_model).

        Args:
            offset (int or torch.Tensor): Ignored.
            size (int): Number of positions requested.
            apply_dropout (bool): Ignored. Accepted so the signature matches
                ``PositionalEncoding.position_encoding``.

        Returns:
            torch.Tensor: All-zero tensor of shape (1, size, d_model).
        """
        return torch.zeros(1, size, self.d_model)
| |
|
| |
|
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    The cached table ``pe`` has shape (1, 2 * T - 1, d_model): relative
    positions T-1 down to 0, followed by -1 down to -(T-1), so relative
    position 0 sits at the middle index.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an EspnetRelPositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # A dummy (1, max_len) tensor only conveys length, dtype and device.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Rebuild the cached table if it is too short for ``x``.

        Args:
            x (torch.Tensor): Tensor whose dimension 1 gives the required
                length and whose dtype/device the table is moved to.
        """
        length = x.size(1)
        if self.pe is not None:
            # The table covers both positive and negative relative
            # positions, hence 2 * length - 1 entries are needed.
            if self.pe.size(1) >= length * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = torch.sin(angles)
        pe_positive[:, 1::2] = torch.cos(angles)
        pe_negative[:, 0::2] = torch.sin(-1 * angles)
        pe_negative[:, 1::2] = torch.cos(-1 * angles)

        # Reverse the positive half so it runs from length-1 down to 0, then
        # append the negative half without its duplicate position 0.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and return the relative position table.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int, torch.Tensor): Passed on to ``position_encoding``,
                which does not use it.

        Returns:
            torch.Tensor: Scaled input after dropout (batch, time, `*`).
            torch.Tensor: Positional embedding after dropout, with
                2 * time - 1 entries along dimension 1.

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Return the table slice for relative positions size-1 .. -(size-1).

        Intended for streaming use. No dropout is applied here; ``forward``
        applies it to the returned slice.

        Args:
            offset (int or torch.Tensor): start offset. Not used by this
                implementation; the slice is always centred on position 0.
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding, with 2 * size - 1 entries
                along dimension 1.
        """
        # If the cached table is too short the start index below would go
        # negative and silently select the wrong rows, so grow it first,
        # keeping the table's current dtype and device.
        if self.pe.size(1) < 2 * size - 1:
            self.extend_pe(self.pe.new_zeros(1, 1).expand(1, size))
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - size + 1: center + size]
        return pos_emb
| |
|