# Res2TCNGuard / _net.py
# Author: korallll
# Commit message: Add model code (_net.py, evaluate.py, res2tcnguard.py); fix README usage; precise params
# Commit: f2beec2 (verified)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
class SincConv_fast(nn.Module):
    """Learnable band-pass filter bank: a 1-D convolution with windowed-sinc kernels.

    Each of the ``out_channels`` filters is parameterised only by a low cut-off
    (``low_hz_``) and a bandwidth (``band_hz_``), both in Hz. They are initialised
    from ``out_channels + 1`` points spaced evenly on the mel scale between 0 Hz
    and ``sample_rate / 2 - (min_low_hz + min_band_hz)``.

    Args:
        out_channels: number of band-pass filters.
        kernel_size: filter length; an even value is bumped to the next odd value
            so every filter has a single centre tap.
        sample_rate: sampling rate of the input waveform, in Hz.
        in_channels: must be 1.
        stride, padding, dilation: forwarded to ``F.conv1d``.
        bias, groups: unsupported; a truthy ``bias`` or ``groups > 1`` raises.
        min_low_hz, min_band_hz: floors added to the learned low cut-off and bandwidth.

    Raises:
        ValueError: if ``in_channels != 1``, ``bias`` is truthy, or ``groups > 1``.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz (inverse of ``to_mel``)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=0, min_band_hz=0):
        super(SincConv_fast, self).__init__()
        if in_channels != 1:
            msg = f"SincConv only supports one input channel (here, in_channels = {in_channels})"
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Force an odd length so the filter is symmetric around one centre tap.
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initial cut-offs: out_channels + 1 points evenly spaced on the mel scale.
        low_hz = 0
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Left half of a Hamming window; the full filter is this half mirrored
        # around the centre tap.
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
        window = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
        # Left-half time axis scaled by 2*pi/sample_rate, for t = -n .. -1.
        n = (self.kernel_size - 1) / 2.0
        n_axis = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
        # Non-persistent buffers follow module.to()/.cuda()/.double() like the
        # parameters do, yet stay out of state_dict, so existing checkpoints
        # (which only contain low_hz_ / band_hz_) still load unchanged.
        self.register_buffer("window_", window, persistent=False)
        self.register_buffer("n_", n_axis, persistent=False)

    def forward(self, waveforms):
        """Filter ``waveforms`` with the current band-pass bank.

        Args:
            waveforms: passed straight to ``F.conv1d`` with single-input-channel
                filters, so it must have one channel, e.g. (batch, 1, time).

        Returns:
            The ``F.conv1d`` output with ``out_channels`` channels. The filters used
            are also stored on ``self.filters`` as (out_channels, 1, kernel_size).
        """
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        # Difference of two sinc low-pass responses = band-pass. Only the left half
        # is evaluated (n_ holds strictly negative times, so no 0/0), then mirrored.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)) * self.window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Scale each filter so its centre tap equals 1.
        band_pass = band_pass / (2 * band[:, None])

        self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)
class Res2Block(nn.Module):
    """Res2Net-style residual block on 4-D feature maps, with SE gating and a (1, 3) max-pool.

    Pipeline: 1x1 conv -> BN -> ReLU -> split into ``nums`` channel groups, each
    passed through a 3x3 conv + BN (group i also receives the previous group's
    output) -> concat -> 1x1 conv -> BN -> SE gate -> + residual -> ReLU ->
    MaxPool2d((1, 3)).

    Args:
        nb_filts: two-element sequence ``[in_channels, out_channels]``.
        nums: number of channel groups. Each group conv has ``nb_filts[1] // nums``
            channels, so ``nb_filts[1]`` is assumed divisible by ``nums``; otherwise
            the concatenated groups will not match ``conv3``'s input width.
    """
    def __init__(self, nb_filts, nums=4):
        super(Res2Block, self).__init__()
        self.nb_filts = nb_filts
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=1,
                               padding=0,
                               stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.relu = nn.ReLU(inplace=True)
        self.nums = nums
        self.SE = SE_Block(nb_filts[1])
        convs = []
        bns = []
        # One 3x3 conv + BN per channel group.
        for i in range(self.nums):
            convs.append(nn.Conv2d(in_channels=(nb_filts[1]// self.nums),
                                   out_channels=(nb_filts[1] //self.nums),
                                   kernel_size=3,
                                   stride=1,
                                   padding=1))
            bns.append(nn.BatchNorm2d((nb_filts[1] //self.nums)))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=1,
                               padding=0,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(nb_filts[1])
        if nb_filts[0] != nb_filts[1]:
            # Channel counts differ: project the residual with a (1, 3) conv whose
            # (0, 1) padding keeps the spatial size unchanged.
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        else:
            self.downsample = False
        # Pools only the last axis, by a factor of 3.
        self.mp = nn.MaxPool2d((1,3))
    def forward(self, x):
        """Apply the block to a 4-D input with ``nb_filts[0]`` channels.

        Returns a tensor with ``nb_filts[1]`` channels; the last axis is reduced
        by the (1, 3) max-pool.
        """
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # Split along channels into groups of nb_filts[1] // nums channels.
        spx = torch.split(out, self.nb_filts[1]//self.nums, 1)
        for i in range(self.nums):
            if i==0:
                sp = spx[i]
            else:
                # NOTE(review): in-place add. At i == 1, `sp` is the very tensor
                # `out` was bound to at i == 0, so this also adds spx[1] into the
                # first group that ends up in the concatenation below. Typical
                # Res2Net code writes `sp = sp + spx[i]` (out-of-place); changing it
                # here would alter outputs of weights trained with this code —
                # confirm before "fixing".
                sp += spx[i]
            sp = self.convs[i](sp)
            sp = self.bns[i](sp)
            if i==0:
                out = sp
            else:
                out = torch.cat((out,sp),1)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.SE(out)
        if self.downsample:
            residual = self.conv_downsample(residual)
        out += residual
        out = self.relu(out)
        out = self.mp(out)
        return out
class SE_Block(nn.Module):
    """Squeeze-and-Excitation gate: rescale each channel by a learned factor in (0, 1).

    credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4

    Args:
        c: number of input channels.
        r: reduction ratio of the bottleneck. The hidden width is ``c // r``,
            floored at 1: with ``c < r`` it would otherwise be 0, giving an empty
            Linear layer and a constant 0.5 gate.
    """

    def __init__(self, c, r=8):
        super().__init__()
        hidden = max(c // r, 1)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise; ``x`` is unpacked as (batch, channels, h, w)."""
        bs, c, _, _ = x.shape
        # Global average over both spatial axes -> one descriptor per channel.
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)
class Encoder(nn.Module):
    """Waveform front-end: a 70-filter SincConv bank followed by six Res2Blocks.

    ``forward`` inserts a channel axis before the sinc convolution, so the input
    is presumably a 2-D batch of raw waveforms (batch, samples). The output is
    the 4-D feature map of the last Res2Block, which has 64 channels.
    """

    def __init__(self):
        super().__init__()
        sinc_channels = 70
        # [in, out] channel pairs of the six Res2Blocks, in order.
        block_filts = [[1, 32], [32, 32], [32, 64], [64, 64], [64, 64], [64, 64]]

        self.sinc_conv = SincConv_fast(out_channels=sinc_channels, kernel_size=128)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        # Each block sits in its own single-element nn.Sequential so parameter
        # names keep the "res_encoder.<i>.0.*" layout (state_dict compatibility).
        self.res_encoder = nn.Sequential(
            *[nn.Sequential(Res2Block(nb_filts=pair)) for pair in block_filts]
        )

    def forward(self, x):
        """Encode a batch of raw waveforms into a 4-D feature map."""
        feats = self.sinc_conv(x.unsqueeze(1))
        # Treat the filter-bank output as a one-channel 2-D "image".
        feats = feats.unsqueeze(dim=1)
        # Magnitude of the filter responses, max-pooled by 3 along both axes.
        feats = F.max_pool2d(torch.abs(feats), (3, 3))
        feats = self.selu(self.first_bn(feats))
        return self.res_encoder(feats)
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along dimension 2 (the time axis of Conv1d output).

    ``TemporalBlock`` uses it to trim the right-hand padding its convolutions add.
    A ``chomp_size`` of 0 is a no-op; the bare slice ``x[:, :, :-0]`` would
    instead return an empty tensor.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy/view of ``x`` without its last ``chomp_size`` steps."""
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
class TemporalBlock(nn.Module):
    """Residual TCN block: two weight-normalised dilated Conv1d layers plus a skip path.

    Each conv is followed by ``Chomp1d(padding)`` (dropping the trailing
    ``padding`` time steps), ReLU and dropout. The input is added back directly,
    or through a 1x1 ``nn.Conv1d`` when ``n_inputs != n_outputs``, and a final
    ReLU is applied to the sum.

    Args:
        n_inputs: input channels.
        n_outputs: output channels of both convs (and of the block).
        kernel_size, stride, dilation, padding: forwarded to both ``nn.Conv1d``
            layers; ``padding`` is also the amount chomped after each conv.
        dropout: dropout probability after each conv.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    @staticmethod
    def _init_conv_weight(conv):
        """Draw the effective weight of ``conv`` from N(0, 0.01), with or without weight_norm."""
        with torch.no_grad():
            if hasattr(conv, "weight_v") and hasattr(conv, "weight_g"):
                # torch.nn.utils.weight_norm recomputes ``weight`` as
                # ``weight_g * weight_v / ||weight_v||`` before every forward, so a
                # write into ``weight.data`` would be silently discarded. Sample the
                # direction ``v`` instead and set the magnitude ``g = ||v||`` (per
                # output channel, weight_norm's default dim=0) so the recomputed
                # weight equals the sampled ``v``. RNG draws match the old order.
                conv.weight_v.normal_(0, 0.01)
                norms = conv.weight_v.flatten(1).norm(dim=1)
                conv.weight_g.copy_(norms.view_as(conv.weight_g))
            else:
                conv.weight.normal_(0, 0.01)

    def init_weights(self):
        """Initialise conv1, conv2 and the optional 1x1 downsample conv from N(0, 0.01)."""
        self._init_conv_weight(self.conv1)
        self._init_conv_weight(self.conv2)
        if self.downsample is not None:
            self._init_conv_weight(self.downsample)

    def forward(self, x):
        """Return ``relu(net(x) + skip(x))``; ``skip`` is identity or the 1x1 downsample conv."""
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
class TemporalConvNet(nn.Module):
    """Stack of TemporalBlocks whose dilation doubles at each level (1, 2, 4, ...).

    Args:
        num_inputs: channels of the input sequence.
        num_channels: output channels of each level; its length sets the depth.
        kernel_size: conv kernel size used in every block.
        dropout: dropout probability inside every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        widths = [num_inputs] + list(num_channels)
        blocks = []
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            dilation = 2 ** level
            # Pad by (kernel_size - 1) * dilation; TemporalBlock chomps the same
            # amount from the end of each conv output.
            blocks.append(TemporalBlock(c_in, c_out, kernel_size, stride=1, dilation=dilation,
                                        padding=(kernel_size - 1) * dilation, dropout=dropout))
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order and return the last level's output."""
        return self.network(x)
class TestModel(nn.Module):
    """Res2TCN classifier: Encoder features -> two TCN branches -> (embedding, 2-way logits).

    The 64-channel encoder output (batch, 64, H, W) is max-reduced along each
    spatial axis:
    - the max over dim 2 gives a (batch, 64, W) sequence for ``tempCNN2``;
    - the max over dim 3 gives a (batch, 64, H) sequence for ``tempCNN1``.
    Each branch ends with 6 channels and is flattened, then projected to 4 features.

    NOTE: ``linear1`` (138 = 6 * 23) and ``linear2`` (174 = 6 * 29) hard-code the
    branch lengths, so the input length is fixed. With the Encoder in this file,
    a 64600-sample waveform yields H = 23 and W = 29.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.tempCNN1 = TemporalConvNet(64, [72, 36, 24, 12, 6])
        self.tempCNN2 = TemporalConvNet(64, [72, 36, 24, 12, 6])
        # Plain in-place ReLU. nn.ReLU's only argument is `inplace`, so the
        # previous spelling nn.ReLU(0.1) also meant inplace=True — it was never a
        # LeakyReLU slope. Written explicitly to avoid that misreading.
        self.relu = nn.ReLU(inplace=True)
        # Not used in forward; kept so existing attribute access keeps working.
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(138, 4)
        self.linear2 = nn.Linear(174, 4)
        self.linear3 = nn.Linear(8, 54)
        self.linear4 = nn.Linear(54, 2)
        self.drop = nn.Dropout(p=0.2)

    def forward(self, x):
        """Classify a batch of waveforms.

        Returns:
            A tuple ``(embedding, logits)``: the 54-dim ReLU output of ``linear3`` and
            the 2-dim output of ``linear4`` computed from it.
        """
        x = self.encoder(x)
        matrix1, _ = torch.max(x, dim=2)  # T: max over dim 2 -> (batch, 64, W)
        matrix2, _ = torch.max(x, dim=3)  # S: max over dim 3 -> (batch, 64, H)

        x1 = self.tempCNN1(matrix2)
        x1 = torch.flatten(x1, 1, 2)
        x1 = self.linear1(x1)
        x1 = self.drop(x1)
        x1 = self.relu(x1)

        x2 = self.tempCNN2(matrix1)
        x2 = torch.flatten(x2, 1, 2)
        x2 = self.linear2(x2)
        x2 = self.drop(x2)
        x2 = self.relu(x2)

        last_layer = self.relu(self.linear3(torch.cat((x1, x2), dim=1)))
        return last_layer, self.linear4(last_layer)