| | import torch |
| | from typeguard import check_argument_types |
| | from typing import Optional |
| | from typing import Tuple |
| |
|
| | from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
| |
|
| |
|
class LabelAggregate(torch.nn.Module):
    """Aggregate sample-level labels into frame-level labels.

    The input is cut into overlapping windows of ``win_length`` samples taken
    every ``hop_length`` samples.  For each window and each label dimension the
    label values are summed; the frame label is 1.0 when that sum is strictly
    greater than ``win_length // 2`` and 0.0 otherwise.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        """Store the framing configuration.

        Args:
            win_length: Number of samples per aggregation window.
            hop_length: Number of samples between the starts of two
                consecutive windows.
            center: If True, zero-pad ``win_length // 2`` samples on both
                sides of the time axis before framing.
        """
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center

    def extra_repr(self):
        """Return the configuration string shown in the module's repr."""
        return (
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch), number of valid samples per sequence, or None.
        Returns:
            output: (Batch, Frames, Label_dim), float tensor of 0.0 / 1.0.
            olens: (Batch), number of valid frames per sequence, or None
                when ``ilens`` is None.

        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # Step 1: optional center padding with zeros along the time axis.
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)

        # The frame count is needed whether or not the input was padded.
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step 2: build a zero-copy sliding-window view.
        # The strides are taken from the tensor itself rather than assumed
        # from a contiguous layout, so a non-contiguous input (possible when
        # center is False and no padded copy was made) is framed correctly.
        batch_stride = input.stride(0)
        time_stride = input.stride(1)
        label_stride = input.stride(2)
        frames = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (
                batch_stride,
                self.hop_length * time_stride,
                time_stride,
                label_stride,
            ),
        )

        # Step 3: majority vote inside each window.
        output = torch.gt(frames.sum(dim=2, keepdim=False), self.win_length // 2)
        output = output.float()

        # Step 4: derive frame lengths and zero out frames past each length.
        olens = None
        if ilens is not None:
            if self.center:
                ilens = ilens + 2 * (self.win_length // 2)

            olens = (ilens - self.win_length) // self.hop_length + 1
            # NOTE(review): make_pad_mask is assumed to mark positions at or
            # beyond olens along dim 1 of output -- confirm against espnet.
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)

        return output, olens
| |
|