| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | |
| |
|
| |
|
class Downsample1d(nn.Module):
    """Halve the temporal length of a 1-D feature map.

    Uses a strided convolution (kernel 3, stride 2, padding 1), so an
    input of shape (B, dim, L) becomes (B, dim, floor((L - 1) / 2) + 1).
    Channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        # Learned downsampling: stride-2 conv instead of pooling.
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
| |
|
class Upsample1d(nn.Module):
    """Double the temporal length of a 1-D feature map.

    Uses a transposed convolution (kernel 4, stride 2, padding 1), so an
    input of shape (B, dim, L) becomes (B, dim, 2 * L). Channel count is
    unchanged; this is the inverse counterpart of ``Downsample1d``.
    """

    def __init__(self, dim):
        super().__init__()
        # Learned upsampling: transposed conv with even kernel to avoid
        # checkerboard-style length mismatches at stride 2.
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
| |
|
class Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish

    A basic 1-D convolutional block. With ``padding = kernel_size // 2``
    the temporal length is preserved for odd kernel sizes (NOTE(review):
    an even ``kernel_size`` would grow the length by one — presumably
    callers always pass odd kernels; confirm if that changes).

    Args:
        inp_channels: number of input channels.
        out_channels: number of output channels; must be divisible by
            ``n_groups`` for GroupNorm.
        kernel_size: convolution kernel width.
        n_groups: number of GroupNorm groups (default 8).
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size,
                      padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
| |
|
| |
|
def test():
    """Smoke-test Conv1dBlock: check the output shape is as expected.

    Defect fixed: the original computed the output but asserted nothing,
    so it could never fail. Now verifies that a (1, 256, 16) input maps
    to (1, 128, 16) — channels changed by the conv, length preserved by
    the odd-kernel "same" padding.
    """
    cb = Conv1dBlock(256, 128, kernel_size=3)
    x = torch.zeros((1, 256, 16))
    o = cb(x)
    assert o.shape == (1, 128, 16)
| |
|