"""Positional Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """
    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        # Precompute the sinusoidal table once and register it as a buffer so
        # it moves with the module (e.g. to GPU) but is not a learnable
        # parameter.
        pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        # div_term[i] = exp(-log(10000) * 2i / d_model) = 1 / 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
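
    # For reference, the table built in __init__ satisfies, for any position
    # pos and dimension index i:
    #   pe[0, pos, 2 * i]     == sin(pos / 10000 ** (2 * i / d_model))
    #   pe[0, pos, 2 * i + 1] == cos(pos / 10000 ** (2 * i / d_model))
    # e.g. pe[0, 1, 0] == sin(1.0) ~= 0.8415, matching the PE(pos, 2i) and
    # PE(pos, 2i+1) formulas in the class docstring.
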
    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int or torch.Tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: Positional embedding, returned for compatibility
                with RelPositionalEncoding
        """
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)
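
    # Illustrative (non-streaming) use of forward() above; the shapes and
    # hyper-parameters are made-up example values:
    #
    #   pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
    #   xs = torch.randn(4, 100, 256)   # (batch, time, d_model)
    #   xs, pos_emb = pos_enc(xs)       # xs: (4, 100, 256), pos_emb: (1, 100, 256)
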
    def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!! We apply dropout only once at the whole utterance
        level in a non-streaming way, but this function will be called
        several times with increasing input size in a streaming scenario,
        so the dropout would be applied several times if it were done here.

        Args:
            offset (int or torch.Tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        if isinstance(offset, int):
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size < self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # batched offsets, e.g. batched streaming decoding
            assert torch.max(offset) + size < self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B x T
            flag = index > 0
            # remove negative offsets
            index = index * flag
            pos_emb = F.embedding(index, self.pe[0])  # B x T x d_model

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb

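
# Illustrative streaming use of PositionalEncoding.position_encoding(); the
# chunk sizes and offsets are made-up example values. The caller tracks how
# many frames it has already consumed and passes that count as `offset`, so
# each chunk continues where the previous one stopped:
#
#   pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
#   first = pos_enc.position_encoding(offset=0, size=16, apply_dropout=False)
#   second = pos_enc.position_encoding(offset=16, size=16, apply_dropout=False)
#   # torch.cat([first, second], dim=1) equals
#   # pos_enc.position_encoding(offset=0, size=32, apply_dropout=False)
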
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See: Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """
    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.pe = self.pe.to(x.device)
        x = x * self.xscale
        pos_emb = self.position_encoding(offset, x.size(1), False)
        return self.dropout(x), self.dropout(pos_emb)

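
# Note on RelPositionalEncoding above: unlike the base class, its forward()
# does not add pos_emb to the scaled input; the encoding is returned
# separately so that a relative-position attention layer (Transformer-XL
# style, see the paper linked in the class docstring) can consume it. A
# hypothetical sketch:
#
#   rel_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
#   xs, pos_emb = rel_enc(torch.randn(4, 100, 256))
#   # xs is only scaled by sqrt(d_model); pos_emb has shape (1, 100, 256) and
#   # is passed to the attention layer together with xs.
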
class NoPositionalEncoding(torch.nn.Module):
    """ No position encoding
    """
    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """ Just return a zero vector for interface compatibility
        """
        pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
        return self.dropout(x), pos_emb

    def position_encoding(
            self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
        return torch.zeros(1, size, self.d_model)
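

# A minimal, runnable sanity check for the three modules above. The shapes and
# hyper-parameters are made-up illustrative values, not defaults used anywhere
# else in the codebase.
if __name__ == '__main__':
    batch, time, dim = 2, 7, 8
    xs = torch.randn(batch, time, dim)
    for enc in (PositionalEncoding(dim, dropout_rate=0.1),
                RelPositionalEncoding(dim, dropout_rate=0.1),
                NoPositionalEncoding(dim, dropout_rate=0.1)):
        ys, pos_emb = enc(xs)
        print(type(enc).__name__, tuple(ys.shape), tuple(pos_emb.shape))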
|