ranjit-task-logs-analysis / egs /librispeech /ASR /transducer_stateless_multi_datasets /subsampling.py
| # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import torch.nn as nn | |
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(self, idim: int, odim: int) -> None:
        """
        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >= 7, idim >= 7.
            Only ``idim`` is checked here; a shorter T makes the second
            convolution fail inside :meth:`forward`.
          odim:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim).
            It is also the number of output channels of both
            convolution layers.
        """
        # An unpadded kernel-3 / stride-2 convolution maps a length L to
        # (L - 1) // 2, so idim < 7 would leave fewer than 3 frequency bins
        # (the kernel size) for the second convolution.
        assert idim >= 7, f"Conv2dSubsampling requires idim >= 7, got {idim}"
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
        )
        # After both convolutions the frequency axis has
        # ((idim - 1) // 2 - 1) // 2 bins; each frame's (channels x bins)
        # features are projected down to odim.
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
        # Move time next to batch, then flatten (channel, freq) per frame.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        return x
class VggSubsampling(nn.Module):
    """Trying to follow the setup described in the following paper:
    https://arxiv.org/pdf/1910.09799.pdf

    This paper is not 100% explicit so I am guessing to some extent,
    and trying to compare with other VGG implementations.

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a VggSubsampling object.

        This uses 2 VGG blocks with 2 Conv2d layers each, subsampling its
        input by a factor of (about) 4 in both the time and the frequency
        dimensions.

        Args:
          idim:
            Input dim. The input shape is (N, T, idim).
            Caution: It requires: T >= 7, idim >= 7.
            Only ``idim`` is checked here; a shorter T makes a convolution
            in the second block fail inside :meth:`forward`.
          odim:
            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
        """
        # Each block maps a length L to (L - 1) // 2, and the unpadded
        # convolution of the second block needs at least 3 bins, i.e.
        # idim >= 7. Fail fast here instead of building an unusable module
        # (for idim < 7 the Linear layer below would get <= 0 input features).
        # This matches the check done by Conv2dSubsampling.
        assert idim >= 7, f"VggSubsampling requires idim >= 7, got {idim}"
        super().__init__()

        cur_channels = 1
        layers = []
        block_dims = [32, 64]

        # The decision to use padding=1 for the 1st convolution, then padding=0
        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
        # a back-compatibility concern so that the number of frames at the
        # output would be equal to:
        #     (((T-1)//2)-1)//2.
        # We can consider changing this by using padding=1 on the
        # 2nd convolution, so the num-frames at the output would be T//4.
        for block_dim in block_dims:
            layers.append(
                torch.nn.Conv2d(
                    in_channels=cur_channels,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                )
            )
            layers.append(torch.nn.ReLU())
            # NOTE(review): unlike the first convolution, this one is not
            # followed by a ReLU. Inserting one would also shift the indices
            # of the later nn.Sequential children, and thus their state_dict
            # keys in existing checkpoints.
            layers.append(
                torch.nn.Conv2d(
                    in_channels=block_dim,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=0,
                    stride=1,
                )
            )
            layers.append(
                torch.nn.MaxPool2d(
                    kernel_size=2, stride=2, padding=0, ceil_mode=True
                )
            )
            cur_channels = block_dim

        self.layers = nn.Sequential(*layers)
        # Each frame's (block_dims[-1] channels x remaining frequency bins)
        # features are flattened and projected down to odim.
        self.out = nn.Linear(
            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
        """
        # (N, T, idim) -> (N, 1, T, idim), i.e. (N, C, H, W)
        x = x.unsqueeze(1)
        x = self.layers(x)
        # Now x is (N, block_dims[-1], ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
        # Move time next to batch, then flatten (channel, freq) per frame.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x