| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
|
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolutional blocks mapping a raw waveform to features.

    Input is expected as (batch, time); output is (batch, channels, frames)
    after the strided Conv1d stack.

    Args:
        conv_layers: one (dim, kernel_size, stride) triple per conv block.
        dropout: dropout probability applied after each convolution.
        mode: "default" applies Fp32GroupNorm to the first block only;
            "layer_norm" applies Fp32LayerNorm after every block.
        conv_bias: whether each Conv1d carries a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            # Build one conv block: Conv1d -> Dropout -> (optional norm) -> GELU.
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # LayerNorm normalizes the channel dim, so transpose to
                # (B, T, C), normalize over C, then transpose back.
                # NOTE: the original referenced the enclosing loop variable
                # `dim` here; use the helper's own `n_out` (equal at every
                # call site) to avoid the late-binding closure hazard.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # num_groups == num_channels: per-channel normalization.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1  # raw waveform: a single input channel
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # "default" mode normalizes only the first block.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run the conv stack: x (batch, time) -> (batch, channels, frames)."""
        # Conv1d expects a channel dim: (B, T) -> (B, 1, T).
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
| |
|