File size: 2,233 Bytes
1cd928a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from torch import nn
from .deepunet import DeepUnet, DeepUnet0
from .constants import *
from .spec import MelSpectrogram
from .seq import BiGRU
class E2E(nn.Module):
    """DeepUnet + 3-channel conv + dense head producing sigmoid outputs.

    Pipeline (see ``forward``): input -> ``DeepUnet`` -> 3x3 ``Conv2d`` down
    to 3 channels -> flatten those 3 channels with the last axis -> head.

    Args:
        n_blocks: forwarded to ``DeepUnet``.
        n_gru: number of layers for the ``BiGRU`` head. A falsy value
            (e.g. ``0``) selects a head made of a single ``nn.Linear`` instead.
        kernel_size: forwarded to ``DeepUnet``.
        en_de_layers: forwarded to ``DeepUnet``.
        inter_layers: forwarded to ``DeepUnet``.
        in_channels: forwarded to ``DeepUnet``.
        en_out_channels: forwarded to ``DeepUnet``; also the number of input
            channels of the 3x3 convolution applied to the U-Net output.
    """

    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
                 en_out_channels=16):
        super().__init__()
        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers,
                             in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))

        # Layer order (and thus nn.Sequential indices / state_dict keys and
        # the order in which parameters are initialised) must stay:
        # [BiGRU,] Linear, Dropout, Sigmoid.
        if n_gru:
            # BiGRU with hidden size 256 feeds a 512-wide Linear; presumably
            # the two directions' outputs are concatenated — confirm in seq.py.
            head = [BiGRU(3 * N_MELS, 256, n_gru), nn.Linear(512, N_CLASS)]
        else:
            head = [nn.Linear(3 * N_MELS, N_CLASS)]
        # NOTE(review): dropout sits after the last Linear, right before the
        # sigmoid; kept as-is because reordering would shift fc.<index> keys.
        self.fc = nn.Sequential(*head, nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        """Map ``mel`` to the head's sigmoid outputs (values in [0, 1]).

        ``mel`` is assumed to be (batch, N_MELS, frames) — TODO confirm with
        callers. Its last two axes are swapped and a singleton channel axis is
        inserted at position 1 before the U-Net.
        """
        spec = mel.transpose(-1, -2).unsqueeze(1)
        feats = self.cnn(self.unet(spec))
        # Swap the 3 conv channels (dim 1) with dim 2, then merge them with
        # the last axis; the head expects this merged axis to be 3 * N_MELS.
        feats = feats.transpose(1, 2).flatten(-2)
        return self.fc(feats)
class E2E0(nn.Module):
    """Same architecture as ``E2E`` but built on ``DeepUnet0``.

    Author's note (translated): the difference from E2E is that DeepUnet is
    replaced by DeepUnet0, i.e. no skip connections and no residual blocks;
    accuracy is somewhat lower but it runs faster.

    Args:
        n_blocks: forwarded to ``DeepUnet0``.
        n_gru: number of layers for the ``BiGRU`` head. A falsy value
            (e.g. ``0``) selects a head made of a single ``nn.Linear`` instead.
        kernel_size: forwarded to ``DeepUnet0``.
        en_de_layers: forwarded to ``DeepUnet0``.
        inter_layers: forwarded to ``DeepUnet0``.
        in_channels: forwarded to ``DeepUnet0``.
        en_out_channels: forwarded to ``DeepUnet0``; also the number of input
            channels of the 3x3 convolution applied to the U-Net output.
    """

    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
                 en_out_channels=16):
        super().__init__()
        self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers,
                              in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))

        # Layer order (and thus nn.Sequential indices / state_dict keys and
        # the order in which parameters are initialised) must stay:
        # [BiGRU,] Linear, Dropout, Sigmoid.
        if n_gru:
            # BiGRU with hidden size 256 feeds a 512-wide Linear; presumably
            # the two directions' outputs are concatenated — confirm in seq.py.
            head = [BiGRU(3 * N_MELS, 256, n_gru), nn.Linear(512, N_CLASS)]
        else:
            head = [nn.Linear(3 * N_MELS, N_CLASS)]
        # NOTE(review): dropout sits after the last Linear, right before the
        # sigmoid; kept as-is because reordering would shift fc.<index> keys.
        self.fc = nn.Sequential(*head, nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        """Map ``mel`` to the head's sigmoid outputs (values in [0, 1]).

        ``mel`` is assumed to be (batch, N_MELS, frames) — TODO confirm with
        callers. Its last two axes are swapped and a singleton channel axis is
        inserted at position 1 before the U-Net.
        """
        spec = mel.transpose(-1, -2).unsqueeze(1)
        feats = self.cnn(self.unet(spec))
        # Swap the 3 conv channels (dim 1) with dim 2, then merge them with
        # the last axis; the head expects this merged axis to be 3 * N_MELS.
        feats = feats.transpose(1, 2).flatten(-2)
        return self.fc(feats)
|