| import torch |
| from typeguard import check_argument_types |
| from typing import Sequence |
| from typing import Union |
|
|
|
|
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Mask randomly placed contiguous spans of ``spec`` along one axis.

    For every batch item, ``num_mask`` spans are drawn independently with
    ``torch.randint``: a width in ``[mask_width_range[0], mask_width_range[1])``
    and a start position chosen so that the widest span drawn in this call
    still fits on the axis (start is 0 when it cannot fit).  Overlapping spans
    are merged.

    Args:
        spec: Tensor of shape (Batch, Length, Freq).  A 4-D tensor is also
            accepted; its first two dimensions are flattened into the batch
            for masking and the original shape is restored on return.
        spec_lengths: Not used by this implementation; returned unchanged.
        mask_width_range: ``(low, high)`` bounds for the span width; ``high``
            is exclusive and must be greater than ``low``.
        dim: Axis of the 3-D view to mask: 1 (Length) or 2 (Freq).  The
            negative equivalents -2 and -1 are accepted as well.
        num_mask: Number of spans drawn per batch item.  0 disables masking.
        replace_with_zero: Fill masked entries with 0.0 if True, otherwise
            with the mean of ``spec`` computed before masking.

    Returns:
        Tuple ``(masked_spec, spec_lengths)`` where ``masked_spec`` has the
        same shape as the input ``spec``.

    Raises:
        ValueError: If ``spec`` is not 3-D or 4-D, or if ``dim`` does not
            select the Length or Freq axis.

    Note:
        When ``spec.requires_grad`` is False the fill is done in place, so the
        tensor passed by the caller is modified.  When it is True a new tensor
        is produced so that autograd is not broken.
    """
    if spec.dim() not in (3, 4):
        raise ValueError(
            f"spec must be a 3-D or 4-D tensor, but got {spec.dim()} dimensions"
        )

    # Masking always happens on a 3-D view, so negative axes are resolved
    # against rank 3 regardless of the input rank.
    axis = dim + 3 if dim < 0 else dim
    if axis not in (1, 2):
        raise ValueError(f"dim must be 1 (time) or 2 (freq), but got {dim}")

    # Nothing to mask: avoids reducing over an empty tensor further below.
    if num_mask == 0 or spec.numel() == 0:
        return spec, spec_lengths

    org_size = spec.size()
    if spec.dim() == 4:
        # Fold the first two dimensions together -> (B', Length, Freq).
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    D = spec.shape[axis]

    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # Use a plain Python int as the upper bound for the start positions.
    longest = int(mask_length.max())

    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max(1, D - longest), (B, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # span: (B, num_mask, D), True inside each drawn span.
    span = (mask_pos <= aran) & (aran < (mask_pos + mask_length))
    # Merge the spans of each batch item -> (B, D).
    mask = span.any(dim=1)

    if axis == 1:
        # (B, Length, 1): broadcast over Freq.
        mask = mask.unsqueeze(2)
    else:
        # (B, 1, Freq): broadcast over Length.
        mask = mask.unsqueeze(1)

    value = 0.0 if replace_with_zero else spec.mean()

    if spec.requires_grad:
        # Out-of-place so the autograd graph stays valid.
        spec = spec.masked_fill(mask, value)
    else:
        # In-place: writes through the view into the caller's tensor.
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths
|
|
|
|
class MaskAlongAxis(torch.nn.Module):
    """Module that calls ``mask_along_axis`` with settings fixed at construction.

    Args:
        mask_width_range: ``(low, high)`` pair, or a single int ``high`` which
            is expanded to ``(0, high)``.  ``high`` must be greater than
            ``low``.
        num_mask: Forwarded to ``mask_along_axis``.
        dim: ``"time"`` (stored as 1), ``"freq"`` (stored as 2), or an int
            that is stored as given.
        replace_with_zero: Forwarded to ``mask_along_axis``.

    Raises:
        TypeError: If ``mask_width_range`` does not contain exactly two items.
        ValueError: If ``dim`` is a string other than ``"time"`` or ``"freq"``.
        AssertionError: If ``high`` is not greater than ``low``.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        # Argument annotations are checked through typeguard.
        assert check_argument_types()

        # A bare int is shorthand for the upper bound of the range.
        width_range = (
            (0, mask_width_range)
            if isinstance(mask_width_range, int)
            else mask_width_range
        )
        if len(width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{width_range}",
            )
        assert width_range[1] > width_range[0]

        # Translate a named axis into its integer index.
        if isinstance(dim, str):
            named_axes = {"time": 1, "freq": 2}
            if dim not in named_axes:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = named_axes[dim]

        super().__init__()
        # Human-readable axis label, used only by extra_repr().
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")
        self.mask_width_range = width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        """Return the width range, mask count and axis label as one string."""
        width_part = f"mask_width_range={self.mask_width_range}, "
        rest_part = f"num_mask={self.num_mask}, axis={self.mask_axis}"
        return width_part + rest_part

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Apply the configured masking to ``spec``.

        Args:
            spec: (Batch, Length, Freq)
            spec_lengths: Passed to ``mask_along_axis`` together with ``spec``.

        Returns:
            Whatever ``mask_along_axis`` returns for these arguments.
        """
        settings = {
            "mask_width_range": self.mask_width_range,
            "dim": self.dim,
            "num_mask": self.num_mask,
            "replace_with_zero": self.replace_with_zero,
        }
        return mask_along_axis(spec, spec_lengths, **settings)
|
|