| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .models import register_backbone |
| from .blocks import (get_sinusoid_encoding, TransformerBlock, MaskedConv1D, |
| ConvBlock, LayerNorm) |
|
|
|
|
@register_backbone("convTransformer")
class ConvTransformerBackbone(nn.Module):
    """
    A backbone that combines convolutions with transformers.

    Structure, controlled by ``arch = (n_embd_convs, n_stem_blocks, n_branch_blocks)``:

    * embedding: ``arch[0]`` MaskedConv1D layers (the first maps ``n_in`` to
      ``n_embd`` channels, the rest ``n_embd`` to ``n_embd``), each followed by
      a LayerNorm (if ``with_ln``) or identity, then a ReLU;
    * optional absolute sinusoidal position embedding (``use_abs_pe``);
    * stem: ``arch[1]`` TransformerBlocks built with ``n_ds_strides=(1, 1)``;
    * branch: ``arch[2]`` TransformerBlocks built with
      ``n_ds_strides=(scale_factor, scale_factor)``.

    ``forward`` returns the stem output followed by the output of every branch
    block, i.e. ``1 + arch[2]`` feature/mask pairs.
    """

    def __init__(
        self,
        n_in,
        n_embd,
        n_head,
        n_embd_ks,
        max_len,
        arch=(2, 2, 5),
        mha_win_size=None,
        scale_factor=2,
        with_ln=False,
        attn_pdrop=0.0,
        proj_pdrop=0.0,
        path_pdrop=0.0,
        use_abs_pe=False,
        use_rel_pe=False,
    ):
        """
        Args:
            n_in: channel count of the input features.
            n_embd: channel count produced by the embedding convs and used by
                every transformer block.
            n_head: number of attention heads per TransformerBlock.
            n_embd_ks: kernel size of the embedding convs (padding is
                ``n_embd_ks // 2``).
            max_len: length of the precomputed absolute position embedding.
            arch: ``(n_embd_convs, n_stem_blocks, n_branch_blocks)``.
            mha_win_size: attention window size per level: entry 0 is used by
                every stem block, entry ``1 + i`` by branch block ``i``; must
                have ``1 + arch[2]`` entries. ``None`` (the default) uses
                ``-1`` for every level.
            scale_factor: downsampling stride passed to each branch block.
            with_ln: add a LayerNorm after each embedding conv (and build those
                convs without bias).
            attn_pdrop, proj_pdrop, path_pdrop: dropout rates forwarded to
                every TransformerBlock.
            use_abs_pe: add the sinusoidal absolute position embedding.
            use_rel_pe: forwarded to every TransformerBlock.
        """
        super().__init__()
        assert len(arch) == 3
        # A fresh list per instance: the old ``[-1]*6`` default was a single
        # list object shared by all instances (and stored on ``self``), and it
        # only fit ``arch[2] == 5``.
        if mha_win_size is None:
            mha_win_size = [-1] * (1 + arch[2])
        assert len(mha_win_size) == (1 + arch[2])
        self.arch = arch
        self.mha_win_size = mha_win_size
        self.max_len = max_len
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.use_abs_pe = use_abs_pe
        self.use_rel_pe = use_rel_pe

        # absolute position embedding, scaled by 1/sqrt(n_embd); non-persistent,
        # so it is not saved in the state_dict
        if self.use_abs_pe:
            pos_embd = get_sinusoid_encoding(self.max_len, n_embd) / (n_embd ** 0.5)
            self.register_buffer("pos_embd", pos_embd, persistent=False)

        # embedding convs, each paired with a norm (LayerNorm or identity);
        # built interleaved to keep the original construction order
        self.embd = nn.ModuleList()
        self.embd_norm = nn.ModuleList()
        for idx in range(arch[0]):
            in_channels = n_in if idx == 0 else n_embd
            self.embd.append(MaskedConv1D(
                in_channels, n_embd, n_embd_ks,
                stride=1, padding=n_embd_ks // 2, bias=(not with_ln),
            ))
            self.embd_norm.append(LayerNorm(n_embd) if with_ln else nn.Identity())

        # stem: stride-1 transformer blocks using the first window size
        self.stem = nn.ModuleList([
            TransformerBlock(
                n_embd, n_head,
                n_ds_strides=(1, 1),
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop,
                path_pdrop=path_pdrop,
                mha_win_size=self.mha_win_size[0],
                use_rel_pe=self.use_rel_pe,
            )
            for _ in range(arch[1])
        ])

        # branch: strided transformer blocks, one output level each
        self.branch = nn.ModuleList([
            TransformerBlock(
                n_embd, n_head,
                n_ds_strides=(self.scale_factor, self.scale_factor),
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop,
                path_pdrop=path_pdrop,
                mha_win_size=self.mha_win_size[1 + idx],
                use_rel_pe=self.use_rel_pe,
            )
            for idx in range(arch[2])
        ])

        self.apply(self.__init_weights__)

    def __init_weights__(self, module):
        """Zero the bias of every nn.Linear / nn.Conv1d submodule that has one."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0.)

    def forward(self, x, mask):
        """
        Run the backbone.

        Args:
            x: input features, unpacked as ``(B, C, T)`` (must be 3-D).
            mask: mask threaded through every MaskedConv1D / TransformerBlock;
                it is cast to ``x.dtype`` and multiplied into the position
                embedding (presumably ``(B, 1, T)`` boolean — TODO confirm).

        Returns:
            ``(out_feats, out_masks)``: tuples of length ``1 + len(self.branch)``,
            holding the stem output first and then each branch block's output.
        """
        _, _, T = x.size()

        # embedding convs
        for conv, norm in zip(self.embd, self.embd_norm):
            x, mask = conv(x, mask)
            x = self.relu(norm(x))

        # absolute position embedding
        if self.use_abs_pe:
            if self.training:
                assert T <= self.max_len, "Reached max length."
                pe = self.pos_embd
            elif T >= self.max_len:
                # at inference, resample the embedding to length T so that
                # inputs longer than max_len are still covered
                pe = F.interpolate(
                    self.pos_embd, T, mode='linear', align_corners=False)
            else:
                pe = self.pos_embd
            # where mask is 0, no embedding is added
            x = x + pe[:, :, :T] * mask.to(x.dtype)

        # stem
        for block in self.stem:
            x, mask = block(x, mask)

        # collect the stem output, then one level per branch block
        out_feats = (x, )
        out_masks = (mask, )
        for block in self.branch:
            x, mask = block(x, mask)
            out_feats += (x, )
            out_masks += (mask, )

        return out_feats, out_masks
|
|
@register_backbone("conv")
class ConvBackbone(nn.Module):
    """
    A convolution-only backbone.

    Structure, controlled by ``arch = (n_embd_convs, n_stem_blocks, n_branch_blocks)``:

    * embedding: ``arch[0]`` MaskedConv1D layers (the first maps ``n_in`` to
      ``n_embd`` channels), each followed by a LayerNorm (if ``with_ln``) or
      identity, then a ReLU;
    * stem: ``arch[1]`` ``ConvBlock(n_embd, 3, 1)`` layers;
    * branch: ``arch[2]`` ``ConvBlock(n_embd, 3, scale_factor)`` layers.

    ``forward`` returns the stem output followed by the output of every branch
    block, i.e. ``1 + arch[2]`` feature/mask pairs.
    """

    def __init__(
        self,
        n_in,
        n_embd,
        n_embd_ks,
        arch=(2, 2, 5),
        scale_factor=2,
        with_ln=False,
    ):
        """
        Args:
            n_in: channel count of the input features.
            n_embd: channel count produced by the embedding convs and used by
                every ConvBlock.
            n_embd_ks: kernel size of the embedding convs (padding is
                ``n_embd_ks // 2``).
            arch: ``(n_embd_convs, n_stem_blocks, n_branch_blocks)``.
            scale_factor: third argument passed to each branch ConvBlock.
            with_ln: add a LayerNorm after each embedding conv (and build those
                convs without bias).
        """
        super().__init__()
        assert len(arch) == 3
        n_embd_layers, n_stem_layers, n_branch_layers = arch[0], arch[1], arch[2]
        self.arch = arch
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

        # embedding convs paired with their norms; each conv is created before
        # its norm, matching the original construction order
        self.embd = nn.ModuleList()
        self.embd_norm = nn.ModuleList()
        for layer_no in range(n_embd_layers):
            channels_in = n_embd if layer_no else n_in
            conv = MaskedConv1D(
                channels_in, n_embd, n_embd_ks,
                stride=1, padding=n_embd_ks // 2, bias=(not with_ln),
            )
            self.embd.append(conv)
            self.embd_norm.append(LayerNorm(n_embd) if with_ln else nn.Identity())

        # stem: stride-1 conv blocks
        self.stem = nn.ModuleList(
            [ConvBlock(n_embd, 3, 1) for _ in range(n_stem_layers)]
        )

        # branch: strided conv blocks, one output level each
        self.branch = nn.ModuleList(
            [ConvBlock(n_embd, 3, self.scale_factor) for _ in range(n_branch_layers)]
        )

        self.apply(self.__init_weights__)

    def __init_weights__(self, module):
        """Zero the bias of every nn.Linear / nn.Conv1d submodule that has one."""
        targeted = isinstance(module, (nn.Linear, nn.Conv1d))
        if targeted and module.bias is not None:
            torch.nn.init.constant_(module.bias, 0.)

    def forward(self, x, mask):
        """
        Run the backbone.

        Args:
            x: input features, unpacked as ``(B, C, T)`` (must be 3-D).
            mask: mask threaded through every MaskedConv1D / ConvBlock.

        Returns:
            ``(out_feats, out_masks)``: tuples of length ``1 + len(self.branch)``,
            holding the stem output first and then each branch block's output.
        """
        # the unpack only enforces a 3-D input; the sizes are not used
        _batch, _channels, _length = x.size()

        # embedding convs
        for conv, norm in zip(self.embd, self.embd_norm):
            x, mask = conv(x, mask)
            x = self.relu(norm(x))

        # stem
        for block in self.stem:
            x, mask = block(x, mask)

        # the stem output is level 0; each branch block adds one level
        feats = [x]
        masks = [mask]
        for block in self.branch:
            x, mask = block(x, mask)
            feats.append(x)
            masks.append(mask)

        return tuple(feats), tuple(masks)
|
|