# NOTE(review): the three lines below are web-page residue from the model-hub
# listing, not Python source; commented out so the module can be imported.
# huazai676's picture
# Upload 6 files
# ba80248 verified
import torch
from torch import nn
import torchaudio as audio
from torch import Tensor
class StutterNet(nn.Module):
    '''Implementation of StutterNet from Sheikh et al.,
    "StutterNet: Stuttering Detection Using Time Delay Neural Network", 2021.

    A TDNN (dilated 1-d convolutions) over mel-filterbank features, followed
    by statistics pooling (mean + std over time) and two fully connected
    layers feeding two 6-way output heads (binary per-type heads and class
    heads), concatenated into a single 12-dim output.

    Args:
        n_mels (int, optional): number of mel filter banks in the input
        dropout (float, optional): dropout probability applied after each of
            the two fully connected layers
        use_batchnorm (bool, optional): accepted for API compatibility;
            batchnorm is currently always applied after every TDNN layer
        scale (float, optional): width scale factor for the TDNN layers
    '''

    def __init__(self, n_mels=40,
                 dropout=0.0, use_batchnorm=False, scale=1):
        super(StutterNet, self).__init__()
        self.n_mels = n_mels
        # Feature extraction (MelSpectrogram / MFCC) is expected to happen
        # outside the model; forward() takes (batch, n_mels, time) input.
        self.tdnn_1 = nn.Conv1d(n_mels, int(512*scale), 5, dilation=1)
        self.tdnn_2 = nn.Conv1d(int(512*scale), int(1536*scale), 5, dilation=2)
        self.tdnn_3 = nn.Conv1d(int(1536*scale), int(512*scale), 7, dilation=3)
        self.tdnn_4 = nn.Conv1d(int(512*scale), int(512*scale), 1)
        self.tdnn_5 = nn.Conv1d(int(512*scale), int(1500*scale), 1)
        # statistics pooling concatenates mean and std -> 2 * 1500 * scale
        self.fc_1 = nn.Linear(int(3000*scale), 512)
        self.relu = nn.ReLU()
        self.bn_1 = nn.BatchNorm1d(int(512*scale))
        self.bn_2 = nn.BatchNorm1d(int(1536*scale))
        self.bn_3 = nn.BatchNorm1d(int(512*scale))
        self.bn_4 = nn.BatchNorm1d(int(512*scale))
        self.bn_5 = nn.BatchNorm1d(int(1500*scale))
        nn.init.xavier_uniform_(self.fc_1.weight)
        self.dropout_1 = nn.Dropout(dropout)
        self.fc_2 = nn.Linear(512, 512)
        # BUG FIX: this line originally re-initialized fc_1 a second time,
        # leaving fc_2 with the default (Kaiming-uniform) init.
        nn.init.xavier_uniform_(self.fc_2.weight)
        self.dropout_2 = nn.Dropout(dropout)
        self.binary_head = nn.Linear(512, 6)
        self.class_head = nn.Linear(512, 6)
        # kept for compatibility; heads currently return raw logits
        self.sig = nn.Sigmoid()

    def forward(self, x):
        '''Run the network.

        Args:
            x (Tensor): features of shape (batch, n_mels, time); time must be
                long enough to survive the dilated convs (>= 31 frames).

        Returns:
            Tensor: (batch, 12) logits — binary-head outputs first, then
            class-head outputs.
        '''
        x = self.tdnn_1(x)
        x = self.relu(x)
        x = self.bn_1(x)
        x = self.tdnn_2(x)
        x = self.relu(x)
        x = self.bn_2(x)
        x = self.tdnn_3(x)
        x = self.relu(x)
        x = self.bn_3(x)
        x = self.tdnn_4(x)
        x = self.relu(x)
        x = self.bn_4(x)
        x = self.tdnn_5(x)
        x = self.relu(x)
        x = self.bn_5(x)
        # statistics pooling over the time axis
        mean = torch.mean(x, -1)
        std = torch.std(x, -1)
        x = torch.cat((mean, std), 1)
        x = self.fc_1(x)
        x = self.dropout_1(x)
        x = self.fc_2(x)
        x = self.dropout_2(x)
        binary = self.binary_head(x)
        classes = self.class_head(x)
        return torch.cat((binary, classes), dim=-1)
class ResBlock1d(nn.Module):
    '''1-d residual block: a strided (optionally downsampling) entry conv
    followed by ``depth - 1`` same-padding convs, with a shortcut added at
    the end.

    Args:
        input_dims (int): number of input channels
        output_dims (int): number of output channels
        depth (int, optional): total number of convolutions in the block
        kernel_size (int, optional): kernel size of the inner convolutions
            (the entry/downsample convs are fixed at kernel size 3)
        use_batchnorm (bool, optional): apply BatchNorm1d after each inner
            conv (the entry conv is always batch-normed)
        downsample (bool, optional): halve the temporal resolution; the
            shortcut then goes through a strided conv as well
        dropout (float, optional): dropout probability after each inner conv
    '''

    def __init__(self, input_dims, output_dims, depth=2, kernel_size=3,
                 use_batchnorm=False, downsample=False, dropout=0.0):
        super(ResBlock1d, self).__init__()
        self.depth = depth
        self.use_batchnorm = use_batchnorm
        self.downsample = downsample
        stride = 2 if downsample else 1
        if downsample:
            # shortcut must match the halved resolution / new channel count
            self.down = nn.Conv1d(int(input_dims), int(output_dims), 3, 2,
                                  padding=1)
        self.conv_1 = nn.Conv1d(int(input_dims), output_dims, 3,
                                stride=stride, padding=1)
        self.convs = nn.ModuleList([nn.Conv1d(output_dims, output_dims,
                                              kernel_size, padding='same')
                                    for _ in range(depth - 1)])
        self.bn_1 = nn.BatchNorm1d(output_dims)
        self.bn = None
        if use_batchnorm:
            self.bn = nn.ModuleList([nn.BatchNorm1d(output_dims)
                                     for _ in range(depth - 1)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''Apply the block; input/output shape (batch, channels, time).'''
        residual = self.down(x) if self.downsample else x
        x = self.conv_1(x)
        x = self.bn_1(x)
        # BUG FIX (structure): the two near-identical per-branch loops were
        # merged; the only difference was the optional batchnorm line.
        for i in range(self.depth - 1):
            x = self.convs[i](x)
            x = self.dropout(x)
            if self.use_batchnorm:
                x = self.bn[i](x)
            # no activation after the last conv: the shortcut is added first
            if i != self.depth - 2:
                x = self.relu(x)
        return residual + x
class ResNet1D(nn.Module):
    '''A ResNet-style 1-d convolutional classifier over mel-filterbank
    features: an entry conv, four stages of three ResBlock1d blocks (the
    first block of each stage downsamples in time), statistics pooling
    (mean + std over time) and a linear classification head.

    NOTE: the original docstring here was copy-pasted from StutterNet and
    described the wrong model; corrected.

    Args:
        n_mels (int, optional): number of mel filter banks in the input
        n_classes (int, optional): number of classes in the output layer
        kernel_size (int, optional): kernel size of the inner residual convs
        dropout (float, optional): dropout probability inside each ResBlock1d
        use_batchnorm (bool, optional): batchnorm the inner residual convs
        scale (float, optional): width scale factor
    '''

    def __init__(self, n_mels=100, n_classes=12, kernel_size=3,
                 dropout=0.0, use_batchnorm=False, scale=1):
        super(ResNet1D, self).__init__()
        self.n_mels = n_mels
        # Feature extraction is expected to happen outside the model;
        # forward() takes (batch, n_mels, time) input.
        self.tdnn_1 = nn.Conv1d(n_mels, int(64*scale), 3, padding=1, bias=False)
        self.res_1_1 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_1_2 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_1_3 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_2_1 = ResBlock1d(int(64*scale), int(128*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_2_2 = ResBlock1d(int(128*scale), int(128*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_2_3 = ResBlock1d(int(128*scale), int(128*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_3_1 = ResBlock1d(int(128*scale), int(256*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_3_2 = ResBlock1d(int(256*scale), int(256*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_3_3 = ResBlock1d(int(256*scale), int(256*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_4_1 = ResBlock1d(int(256*scale), int(512*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_4_2 = ResBlock1d(int(512*scale), int(512*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_4_3 = ResBlock1d(int(512*scale), int(512*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.relu = nn.ReLU()
        # statistics pooling concatenates mean and std -> 2 * 512 * scale
        self.fc = nn.Linear(int(1024*scale), n_classes)

    def forward(self, x):
        '''Run the network.

        Args:
            x (Tensor): features of shape (batch, n_mels, time); time is
                halved four times, so it must be >= 32 for a valid std.

        Returns:
            Tensor: (batch, n_classes) logits.
        '''
        x = self.tdnn_1(x)
        # each residual block is followed by a ReLU
        for block in (self.res_1_1, self.res_1_2, self.res_1_3,
                      self.res_2_1, self.res_2_2, self.res_2_3,
                      self.res_3_1, self.res_3_2, self.res_3_3,
                      self.res_4_1, self.res_4_2, self.res_4_3):
            x = self.relu(block(x))
        # statistics pooling over the time axis
        mean = torch.mean(x, -1)
        std = torch.std(x, -1)
        x = torch.cat((mean, std), 1)
        return self.fc(x)
from torch import Tensor
'''credit: https://github.com/roman-vygon/BCResNet'''
class SubSpectralNorm(nn.Module):
    '''Sub-spectral normalization: split the frequency axis into S sub-bands
    and batch-normalize each (channel, sub-band) pair independently.

    Args:
        C (int): number of input channels
        S (int): number of frequency sub-bands; the input's frequency
            dimension must be divisible by S
        eps (float, optional): numerical-stability epsilon for the batchnorm
    '''

    def __init__(self, C, S, eps=1e-5):
        super(SubSpectralNorm, self).__init__()
        self.S = S
        self.eps = eps
        # BUG FIX: eps was stored but never forwarded to the BatchNorm layer,
        # so a non-default eps was silently ignored.
        self.bn = nn.BatchNorm2d(C * S, eps=eps)

    def forward(self, x):
        # x: input features with shape (N, C, F, T)
        N, C, F, T = x.size()
        # fold each sub-band into the channel axis so the batchnorm keeps
        # separate statistics per sub-band
        x = x.view(N, C * self.S, F // self.S, T)
        x = self.bn(x)
        return x.view(N, C, F, T)
class BroadcastedBlock(nn.Module):
    '''BC-ResNet "normal" broadcasted residual block (in == out channels).

    A frequency-wise depthwise conv with sub-spectral norm feeds a temporal
    branch computed on frequency-averaged features; the temporal result is
    broadcast back over frequency and added to both the frequency branch and
    the identity shortcut.

    Args:
        planes (int): number of channels (unchanged by the block)
        dilation: conv dilation (usually (1, d) to dilate only in time)
        stride: conv stride
        temp_pad: padding of the 1x3 temporal depthwise conv
    '''

    def __init__(
            self,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(BroadcastedBlock, self).__init__()
        # frequency-wise depthwise conv (3x1), normalized per sub-band
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1),
                                      padding=(1, 0), groups=planes,
                                      dilation=dilation, stride=stride,
                                      bias=False)
        self.ssn1 = SubSpectralNorm(planes, 5)
        # temporal depthwise conv (1x3) applied after frequency pooling
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3),
                                      padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride,
                                      bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1 = nn.Conv2d(planes, planes, kernel_size=(1, 1),
                                 bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # frequency branch (f2): depthwise 3x1 conv + sub-spectral norm
        freq_branch = self.ssn1(self.freq_dw_conv(x))
        # temporal branch (f1): average over frequency, depthwise 1x3 conv,
        # BN, swish, pointwise conv, channel dropout
        pooled = freq_branch.mean(2, keepdim=True)
        temp_branch = self.temp_dw_conv(pooled)
        temp_branch = self.swish(self.bn(temp_branch))
        temp_branch = self.channel_drop(self.conv1x1(temp_branch))
        # broadcast-add over frequency, plus the identity shortcut
        return self.relu(temp_branch + x + freq_branch)
class TransitionBlock(nn.Module):
    '''BC-ResNet transition block: like BroadcastedBlock, but first projects
    the channel count from ``inplanes`` to ``planes`` with a 1x1 conv, and
    therefore carries no identity shortcut.

    Args:
        inplanes (int): number of input channels
        planes (int): number of output channels
        dilation: conv dilation (usually (1, d) to dilate only in time)
        stride: conv stride
        temp_pad: padding of the 1x3 temporal depthwise conv
    '''

    def __init__(
            self,
            inplanes: int,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(TransitionBlock, self).__init__()
        # frequency-wise depthwise conv (3x1), normalized per sub-band
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1),
                                      padding=(1, 0), groups=planes,
                                      stride=stride, dilation=dilation,
                                      bias=False)
        self.ssn = SubSpectralNorm(planes, 5)
        # temporal depthwise conv (1x3) applied after frequency pooling
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3),
                                      padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride,
                                      bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1_1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1),
                                   bias=False)
        self.conv1x1_2 = nn.Conv2d(planes, planes, kernel_size=(1, 1),
                                   bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # channel projection, then frequency branch (f2)
        projected = self.relu(self.bn1(self.conv1x1_1(x)))
        freq_branch = self.ssn(self.freq_dw_conv(projected))
        # temporal branch (f1) on frequency-averaged features
        pooled = freq_branch.mean(2, keepdim=True)
        temp_branch = self.temp_dw_conv(pooled)
        temp_branch = self.swish(self.bn2(temp_branch))
        temp_branch = self.channel_drop(self.conv1x1_2(temp_branch))
        # no identity shortcut: channel count changes across the block
        return self.relu(freq_branch + temp_branch)
class BCResNet(torch.nn.Module):
    '''BC-ResNet keyword-spotting model: an entry conv, four stages of
    transition + broadcasted blocks with increasing temporal dilation, then
    depthwise/pointwise head convs, temporal average pooling and a 1x1
    classifier producing 12 logits per example.

    Input is expected as (N, 1, F, T) spectrogram-like features; F must
    reduce to 5 before conv2 (sub-spectral norm uses 5 sub-bands).
    '''

    def __init__(self):
        super(BCResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride=(2, 1), padding=(2, 2))
        # stage 1
        self.block1_1 = TransitionBlock(16, 8)
        self.block1_2 = BroadcastedBlock(8)
        # stage 2: downsample frequency, dilate time by 2
        self.block2_1 = TransitionBlock(8, 12, stride=(2, 1), dilation=(1, 2), temp_pad=(0, 2))
        self.block2_2 = BroadcastedBlock(12, dilation=(1, 2), temp_pad=(0, 2))
        # stage 3: downsample frequency, dilate time by 4
        self.block3_1 = TransitionBlock(12, 16, stride=(2, 1), dilation=(1, 4), temp_pad=(0, 4))
        self.block3_2 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_3 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_4 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        # stage 4: dilate time by 8 (no frequency downsampling)
        self.block4_1 = TransitionBlock(16, 20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_2 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_3 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_4 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        # head: depthwise 5x5, pointwise expansion, classifier
        self.conv2 = nn.Conv2d(20, 20, 5, groups=20, padding=(0, 2))
        self.conv3 = nn.Conv2d(20, 32, 1, bias=False)
        self.conv4 = nn.Conv2d(32, 12, 1, bias=False)

    def forward(self, x):
        '''Return (N, 12) logits for input of shape (N, 1, F, T).'''
        out = self.conv1(x)
        for block in (self.block1_1, self.block1_2,
                      self.block2_1, self.block2_2,
                      self.block3_1, self.block3_2,
                      self.block3_3, self.block3_4,
                      self.block4_1, self.block4_2,
                      self.block4_3, self.block4_4):
            out = block(out)
        out = self.conv2(out)
        out = self.conv3(out)
        # global average over the time axis
        out = out.mean(-1, keepdim=True)
        out = self.conv4(out)
        return out.reshape((-1, 12))