| | """Positonal Encoding Module.""" |
| |
|
| | import math |
| | from typing import Tuple, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/d_model)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/d_model)))
    """
    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        # Precompute the sinusoidal table once; pe has shape (1, max_len, d_model).
        pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len).unsqueeze(1)
        # exp(2i * -log(10000) / d_model) == 1 / 10000^(2i / d_model)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int or torch.Tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: Positional embedding, returned for compatibility
                with RelPositionalEncoding
        """
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Get positional encoding in a streaming fashion.

        Note: dropout is applied only once at the whole-utterance level in
        the non-streaming case, but this function is called several times
        with increasing input size in a streaming scenario, so dropout would
        be applied several times there (hence the `apply_dropout` flag).

        Args:
            offset (int or torch.Tensor): start offset
            size (int): required size of position encoding
            apply_dropout (bool): whether to apply dropout to the encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        if isinstance(offset, int):
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:
            # scalar tensor offset
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:
            # batched streaming decoding: offset holds one start offset per
            # utterance, shape (B,)
            assert torch.max(offset) + size < self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B x T
            flag = index > 0
            # zero out negative indices introduced by padding
            index = index * flag
            pos_emb = F.embedding(index, self.pe[0])  # B x T x d_model

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
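

# A minimal usage sketch for PositionalEncoding, covering the full-utterance
# path (forward) and the streaming path (position_encoding with an explicit
# offset). The function name and tensor shapes below are assumptions chosen
# for illustration, not part of the library interface.
def _example_positional_encoding():
    """Usage sketch: full-utterance vs. streaming positional encoding."""
    torch.manual_seed(0)
    pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.0)
    x = torch.randn(2, 10, 256)  # (batch, time, d_model)

    # Full utterance: x is scaled by sqrt(d_model) and pe[:, 0:10] is added.
    y, pos_emb = pos_enc(x)
    assert y.shape == (2, 10, 256) and pos_emb.shape == (1, 10, 256)

    # Streaming: decode chunk by chunk, passing the running offset so each
    # chunk picks up the encoding at its absolute position.
    chunk_size = 4
    for offset in (0, 4, 8):
        emb = pos_enc.position_encoding(offset, chunk_size,
                                        apply_dropout=False)
        assert torch.equal(emb, pos_enc.pe[:, offset:offset + chunk_size])
    return y, pos_emb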


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """
    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.pe = self.pe.to(x.device)
        x = x * self.xscale
        pos_emb = self.position_encoding(offset, x.size(1), False)
        return self.dropout(x), self.dropout(pos_emb)
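

# A minimal usage sketch, assuming the shapes below: unlike PositionalEncoding,
# RelPositionalEncoding does not add the embedding to the input; it only scales
# x and returns the positional embedding separately, to be consumed by a
# relative-position attention layer. The function name is an illustrative
# assumption, not part of the library interface.
def _example_rel_positional_encoding():
    """Usage sketch: relative positional encoding returns pos_emb separately."""
    rel_pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
    x = torch.randn(2, 10, 256)
    y, pos_emb = rel_pos_enc(x)
    # y is just the scaled input; pos_emb carries the sinusoidal table slice.
    assert torch.allclose(y, x * math.sqrt(256))
    assert pos_emb.shape == (1, 10, 256)
    return y, pos_emb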


class NoPositionalEncoding(torch.nn.Module):
    """No positional encoding."""
    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Just return a zero vector for interface compatibility."""
        pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
        return self.dropout(x), pos_emb

    def position_encoding(
            self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
        return torch.zeros(1, size, self.d_model)
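

# A minimal usage sketch, assuming the shapes below: NoPositionalEncoding keeps
# the same interface but injects no position information, so the returned
# "embedding" is all zeros. The function name and the __main__ smoke test are
# illustrative assumptions, not part of the library interface.
def _example_no_positional_encoding():
    """Usage sketch: NoPositionalEncoding passes x through and returns zeros."""
    no_pos_enc = NoPositionalEncoding(d_model=256, dropout_rate=0.0)
    x = torch.randn(2, 10, 256)
    y, pos_emb = no_pos_enc(x)
    assert torch.equal(y, x)  # input is not scaled, only (no-op) dropout
    assert torch.all(pos_emb == 0)
    return y, pos_emb


if __name__ == '__main__':
    # Run the usage sketches as a quick smoke test.
    _example_positional_encoding()
    _example_rel_positional_encoding()
    _example_no_positional_encoding()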