import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


class SincConv_fast(nn.Module):
    """1-D convolution whose kernels are windowed sinc band-pass filters.

    Only the lower cut-off (``low_hz_``) and the bandwidth (``band_hz_``) of
    each filter are learnable; the kernels are rebuilt from them on every
    forward pass.

    Args:
        out_channels: Number of band-pass filters.
        kernel_size: Filter length; an even value is bumped to the next odd
            number so the kernel is symmetric around a centre tap.
        sample_rate: Sampling rate used to convert Hz to radians per sample.
        in_channels: Must be 1.
        stride, padding, dilation: Forwarded to ``F.conv1d``.
        bias: Must be falsy.
        groups: Must not exceed 1.
        min_low_hz: Offset added to the absolute learnt low cut-off.
        min_band_hz: Offset added to the absolute learnt bandwidth.

    Raises:
        ValueError: If ``in_channels != 1``, ``bias`` is truthy or
            ``groups > 1``.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz (inverse of ``to_mel``)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000,
                 in_channels=1, stride=1, padding=0, dilation=1, bias=False,
                 groups=1, min_low_hz=0, min_band_hz=0):
        super().__init__()

        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        self.out_channels = out_channels
        # Force an odd length: left half + centre tap + mirrored right half.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initialise the filterbank with cut-offs equally spaced on the mel
        # scale between 0 Hz and (Nyquist - min_low_hz - min_band_hz).
        low_hz = 0
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)

        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Half of a Hamming window (0.54 - 0.46 cos); only the left half of
        # each kernel is computed, the right half is its mirror image.
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1,
                               steps=int(self.kernel_size / 2))
        window = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)

        # Time axis (left half only), pre-multiplied by 2*pi / sample_rate.
        n = (self.kernel_size - 1) / 2.0
        time_axis = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate

        # Non-persistent buffers: they follow ``module.to(...)`` but stay out
        # of ``state_dict`` so existing checkpoints keep loading unchanged.
        self.register_buffer("window_", window, persistent=False)
        self.register_buffer("n_", time_axis, persistent=False)

    def forward(self, waveforms):
        """Filter ``waveforms`` with the current sinc filterbank.

        Args:
            waveforms: Tensor passed to ``F.conv1d`` with a
                ``(out_channels, 1, kernel_size)`` weight, so it needs a
                single input channel.

        Returns:
            The output of ``F.conv1d``. The kernels used are also stored on
            ``self.filters``.
        """
        # Local device-cast copies; the module state is not mutated here.
        n_ = self.n_.to(waveforms.device)
        window_ = self.window_.to(waveforms.device)

        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, n_)
        f_times_t_high = torch.matmul(high, n_)

        # Difference of two low-pass sinc filters gives a band-pass filter.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
                          / (n_ / 2)) * window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat(
            [band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Normalise so the centre tap of every filter equals 1.
        band_pass = band_pass / (2 * band[:, None])

        self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)


class Res2Block(nn.Module):
    """Residual block with ``nums`` chained 3x3 branches and an SE gate.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        nums: Number of channel groups; each branch works on
            ``out_channels // nums`` channels.
    """

    def __init__(self, nb_filts, nums=4):
        super().__init__()
        self.nb_filts = nb_filts
        self.nums = nums
        width = nb_filts[1] // self.nums

        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=1, padding=0, stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.relu = nn.ReLU(inplace=True)
        self.SE = SE_Block(nb_filts[1])

        self.convs = nn.ModuleList(
            nn.Conv2d(in_channels=width, out_channels=width,
                      kernel_size=3, stride=1, padding=1)
            for _ in range(self.nums)
        )
        self.bns = nn.ModuleList(
            nn.BatchNorm2d(width) for _ in range(self.nums)
        )

        self.conv3 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=1, padding=0, stride=1)
        self.bn3 = nn.BatchNorm2d(nb_filts[1])

        # Project the shortcut only when the channel count changes.
        self.downsample = nb_filts[0] != nb_filts[1]
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        """Apply the block and max-pool the last axis by a factor of 3."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))

        # Each branch receives its own channel group plus the previous
        # branch's output. Out-of-place adds keep autograd bookkeeping simple.
        groups = torch.split(out, self.nb_filts[1] // self.nums, 1)
        branch_outputs = []
        sp = None
        for group, conv, bn in zip(groups, self.convs, self.bns):
            sp = group if sp is None else sp + group
            sp = bn(conv(sp))
            branch_outputs.append(sp)
        out = torch.cat(branch_outputs, 1)

        out = self.SE(self.bn3(self.conv3(out)))

        if self.downsample:
            residual = self.conv_downsample(residual)

        out = self.relu(out + residual)
        return self.mp(out)


class SE_Block(nn.Module):
    """Squeeze-and-excitation channel gate.

    credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4

    Args:
        c: Number of channels of the input.
        r: Reduction ratio of the bottleneck (hidden size is ``c // r``).
    """

    def __init__(self, c, r=8):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale each channel of the 4-D input ``x`` by a gate in (0, 1)."""
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)


class Encoder(nn.Module):
    """Sinc filterbank front-end followed by six ``Res2Block`` stages."""

    def __init__(self):
        super().__init__()
        # filts[0]: number of sinc filters; the rest: [in, out] per stage.
        filts = [70, [1, 32], [32, 32], [32, 64], [64, 64]]

        self.sinc_conv = SincConv_fast(out_channels=filts[0], kernel_size=128)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)

        # The nested Sequential wrappers are kept on purpose: they determine
        # the parameter names in ``state_dict``.
        self.res_encoder = nn.Sequential(
            nn.Sequential(Res2Block(nb_filts=filts[1])),
            nn.Sequential(Res2Block(nb_filts=filts[2])),
            nn.Sequential(Res2Block(nb_filts=filts[3])),
            nn.Sequential(Res2Block(nb_filts=filts[4])),
            nn.Sequential(Res2Block(nb_filts=filts[4])),
            nn.Sequential(Res2Block(nb_filts=filts[4])),
        )

    def forward(self, x):
        """Encode waveforms into a 4-D feature map.

        Args:
            x: presumably ``(batch, samples)`` -- a channel axis is inserted
                at dim 1 before the sinc convolution; confirm with callers.
        """
        x = x.unsqueeze(1)
        x = self.sinc_conv(x)
        # Treat the filterbank output as a one-channel 2-D map.
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)
        return self.res_encoder(x)


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps of the final axis of a 3-D tensor."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` would be empty, so a zero chomp must be a no-op.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Two weight-normalised dilated convolutions with a residual shortcut.

    Each convolution is followed by ``Chomp1d(padding)``, ReLU and dropout.
    A 1x1 convolution adapts the shortcut when ``n_inputs != n_outputs``.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation,
                 padding, dropout=0.2):
        super().__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding,
                                           dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding,
                                           dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1,
                                 self.dropout1, self.conv2, self.chomp2,
                                 self.relu2, self.dropout2)
        self.downsample = (nn.Conv1d(n_inputs, n_outputs, 1)
                           if n_inputs != n_outputs else None)
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01)."""
        # NOTE(review): with the hook-based ``weight_norm`` the effective
        # weight is presumably recomputed from ``weight_g``/``weight_v`` on
        # each forward, so the conv1/conv2 initialisation may have no lasting
        # effect -- confirm before relying on it.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` layers with dilation doubling per level.

    Args:
        num_inputs: Channel count of the input.
        num_channels: Output channel count of each level, in order.
        kernel_size: Kernel size shared by all levels.
        dropout: Dropout probability shared by all levels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        layers = []
        in_channels = num_inputs
        for level, out_channels in enumerate(num_channels):
            dilation_size = 2 ** level
            layers.append(
                TemporalBlock(in_channels, out_channels, kernel_size,
                              stride=1, dilation=dilation_size,
                              padding=(kernel_size - 1) * dilation_size,
                              dropout=dropout))
            in_channels = out_channels
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class TestModel(nn.Module):
    """Encoder plus two temporal-convolution heads and a 2-way classifier.

    The sizes of ``linear1`` (138 = 6 * 23) and ``linear2`` (174 = 6 * 29)
    fix the encoder output to 23 x 29 on its last two axes, which in turn
    fixes the input length (64600 samples by my arithmetic -- verify).
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.tempCNN1 = TemporalConvNet(64, [72, 36, 24, 12, 6])
        self.tempCNN2 = TemporalConvNet(64, [72, 36, 24, 12, 6])
        # The original passed 0.1 positionally, which ``nn.ReLU`` reads as a
        # truthy ``inplace`` flag; spelled out here with the same behaviour.
        # NOTE(review): if a leaky slope of 0.1 was intended, this should be
        # ``nn.LeakyReLU(0.1)`` -- that would change the model's outputs.
        self.relu = nn.ReLU(inplace=True)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(138, 4)
        self.linear2 = nn.Linear(174, 4)
        self.linear3 = nn.Linear(8, 54)
        self.linear4 = nn.Linear(54, 2)
        self.drop = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return ``(last_layer, logits)``.

        ``last_layer`` is the 54-dimensional output of ``linear3`` after
        ReLU and ``logits`` is the 2-dimensional output of ``linear4``.
        """
        x = self.encoder(x)

        matrix1, _ = torch.max(x, dim=2)  # T
        matrix2, _ = torch.max(x, dim=3)  # S

        x1 = self.tempCNN1(matrix2)
        x1 = torch.flatten(x1, 1, 2)
        x1 = self.relu(self.drop(self.linear1(x1)))

        x2 = self.tempCNN2(matrix1)
        x2 = torch.flatten(x2, 1, 2)
        x2 = self.relu(self.drop(self.linear2(x2)))

        last_layer = self.relu(self.linear3(torch.cat((x1, x2), dim=1)))
        return last_layer, self.linear4(last_layer)