| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import List, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast |
|
|
|
|
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolutional blocks that maps a raw waveform batch to a
    downsampled feature sequence.

    Args:
        conv_layers: one ``(dim, kernel_size, stride)`` tuple per conv block,
            applied in order; each block is Conv1d -> Dropout -> [norm] -> GELU.
        dropout: dropout probability applied after every convolution.
        mode: ``"default"`` applies a single group norm after the first conv
            only; ``"layer_norm"`` applies a (fp32) layer norm over channels
            after every conv.
        conv_bias: whether the Conv1d layers use a bias term.

    Raises:
        AssertionError: if ``mode`` is not one of the two supported values,
            or if any entry of ``conv_layers`` is not a 3-tuple.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            # Factory for a single Conv1d with Kaiming-normal weight init.
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # LayerNorm normalizes the last dim, so transpose
                # (B, C, T) -> (B, T, C) around it and back.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        # Use the factory's own output width `n_out` rather
                        # than closing over the enclosing loop's `dim`
                        # (same value at call time, but no hidden coupling).
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # num_groups == num_channels, i.e. per-channel normalization.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1  # raw waveform: a single input channel
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # group norm only after the very first conv in default mode
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run all conv blocks over a batch of waveforms.

        Args:
            x: tensor of shape ``(batch, time)`` — raw samples.
               NOTE(review): the unsqueeze below assumes callers pass input
               without a channel dim; confirm against call sites.

        Returns:
            Tensor of shape ``(batch, last_dim, reduced_time)``.
        """
        # (B, T) -> (B, 1, T): Conv1d expects an explicit channel dimension.
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
|
|