| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from torch import nn |
| from einops import rearrange |
|
|
|
|
| class Res2dModule(nn.Module): |
| def __init__(self, idim, odim, stride=(2, 2)): |
| super(Res2dModule, self).__init__() |
| self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) |
| self.bn1 = nn.BatchNorm2d(odim) |
| self.conv2 = nn.Conv2d(odim, odim, 3, padding=1) |
| self.bn2 = nn.BatchNorm2d(odim) |
| self.relu = nn.ReLU() |
|
|
| |
| self.diff = False |
| if (idim != odim) or (stride[0] > 1): |
| self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) |
| self.bn3 = nn.BatchNorm2d(odim) |
| self.diff = True |
|
|
| def forward(self, x): |
| out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))) |
| if self.diff: |
| x = self.bn3(self.conv3(x)) |
| out = x + out |
| out = self.relu(out) |
| return out |
|
|
|
|
| class Conv2dSubsampling(nn.Module): |
| """Convolutional 2D subsampling (to 1/4 length). |
| |
| Args: |
| idim (int): Input dimension. |
| hdim (int): Hidden dimension. |
| odim (int): Output dimension. |
| strides (list): Sizes of strides. |
| n_bands (int): Number of frequency bands. |
| """ |
|
|
| def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64): |
| """Construct an Conv2dSubsampling object.""" |
| super(Conv2dSubsampling, self).__init__() |
|
|
| self.conv = nn.Sequential( |
| Res2dModule(idim, hdim, (2, strides[0])), |
| Res2dModule(hdim, hdim, (2, strides[1])), |
| ) |
| self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim) |
|
|
| def forward(self, x): |
| """Subsample x. |
| |
| Args: |
| x (torch.Tensor): Input tensor (#batch, idim, time). |
| |
| Returns: |
| torch.Tensor: Subsampled tensor (#batch, time', odim), |
| where time' = time // 4. |
| """ |
|
|
| if x.dim() == 3: |
| x = x.unsqueeze(1) |
| x = self.conv(x) |
| x = rearrange(x, "b c f t -> b t (c f)") |
| x = self.linear(x) |
| return x |
|
|