init

- app.py +87 -0
- nnet/ResNet34.py +213 -0
- nnet/__init__.py +0 -0
- nnet/cnns.py +186 -0
- nnet/norm.py +59 -0
- nnet/pooling.py +100 -0
- nnet/speaker_encoder.py +47 -0
- nnet/spex_plus.py +247 -0
- requirement.txt +5 -0
- utils/__init__.py +0 -0
- utils/audio.py +124 -0
- utils/dataset copy.py +284 -0
- utils/dataset.py +402 -0
- utils/load_obj.py +18 -0
- utils/logger.py +22 -0
- utils/sisdr.py +23 -0
- utils/timer.py +17 -0
app.py
ADDED
@@ -0,0 +1,87 @@
import os  # needed for os.path.join below; missing from the original commit
import gradio as gr
import torch as th
import numpy as np
from nnet.spex_plus import SpEx_Plus
from utils.logger import get_logger
from utils.audio import WaveReader, write_wav

logger = get_logger(__name__)

class NnetComputer(object):
    def __init__(self, cpt_dir, gpuid, nnet_conf):
        self.device = th.device("cuda:{}".format(gpuid)) if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
        # set eval mode
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, nnet_conf):
        nnet = SpEx_Plus(**nnet_conf)
        cpt_fname = os.path.join(cpt_dir, "59.pt.tar")
        cpt = th.load(cpt_fname, map_location="cpu")
        nnet.load_state_dict(cpt["model_state_dict"])
        logger.info("Load checkpoint from {}, epoch {:d}".format(
            cpt_fname, cpt["epoch"]))
        return nnet

    def compute(self, samps, aux_samps, aux_samps_len):
        with th.no_grad():
            raw = th.tensor(samps, dtype=th.float32, device=self.device)
            aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
            aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
            aux = aux.unsqueeze(0)
            sps, sps2, sps3, spk_pred = self.nnet(raw, aux, aux_len)
            sp_samps = np.squeeze(sps.detach().cpu().numpy())
            return sp_samps

def compute_output(input_audio, use_gpu, checkpoint, output_dir):
    # Prepare mix_input and aux_input based on the input_audio
    mix_input = {}  # Modify this to include your mix_input
    aux_input = {}  # Modify this to include your aux_input

    # Set GPU index based on the user's choice
    gpu_index = 0 if use_gpu else -1

    # Run the computation
    nnet_conf = {
        "L1": int(0.0025 * 16000),
        "L2": int(0.01 * 16000),
        "L3": int(0.02 * 16000),
        "N": 256,
        "B": 8,
        "O": 256,
        "P": 512,
        "Q": 3,
        "num_spks": 395,
        "spk_embed_dim": 256,
        "causal": False
    }
    sample_rate = 16000  # matches the 16 kHz rate hard-coded in nnet_conf; the original referenced an undefined args.sample_rate
    computer = NnetComputer(checkpoint, gpu_index, nnet_conf)
    for key, mix_samps in mix_input.items():  # the original iterated the dict directly, which yields keys only
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        samps = computer.compute(mix_samps, aux_samps, len(aux_samps))
        norm = np.linalg.norm(mix_samps, np.inf)
        samps = samps[:mix_samps.size]
        # Normalize the output to the peak level of the mixture
        samps = samps * norm / np.max(np.abs(samps))
        # Write output to the specified directory
        write_wav(os.path.join(output_dir, "{}.wav".format(key)), samps, sample_rate=sample_rate)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))

# Define the Gradio interface
inputs = [
    gr.Audio(label="Input Audio"),
    gr.Checkbox(label="Use GPU"),
    gr.Textbox(label="Checkpoint Directory"),  # gr.TextInput does not exist; Textbox is the Gradio component
    gr.Textbox(label="Output Directory")
]
output = gr.Interface(
    fn=compute_output,
    inputs=inputs,
    outputs=None,
    title="Audio Processing with Neural Network",
    description="Process audio input using a neural network model.",
    theme="compact"
)
output.launch()
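For reference, a minimal offline sketch of driving NnetComputer directly from wav files, bypassing the Gradio stub above; it assumes NnetComputer from app.py is in scope, a checkpoint directory exists, and the paths and 16 kHz rate are placeholders, not part of the commit:

# Hypothetical offline usage; "exp/checkpoints", "mix.wav", "enroll.wav" are placeholders.
from utils.audio import read_wav, write_wav

nnet_conf = {"L1": 40, "L2": 160, "L3": 320, "N": 256, "B": 8, "O": 256,
             "P": 512, "Q": 3, "num_spks": 395, "spk_embed_dim": 256, "causal": False}
computer = NnetComputer("exp/checkpoints", -1, nnet_conf)  # gpuid=-1 runs on CPU
mix = read_wav("mix.wav")      # mixture waveform, 16 kHz assumed
aux = read_wav("enroll.wav")   # enrollment utterance of the target speaker
est = computer.compute(mix, aux, len(aux))
write_wav("est.wav", est[:mix.size], sample_rate=16000)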
nnet/ResNet34.py
ADDED
@@ -0,0 +1,213 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
Fast ResNet
https://arxiv.org/pdf/2003.11982.pdf
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
try:
    from .pooling import *
except:
    from pooling import *

class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNetSE(nn.Module):
    def __init__(self, block, layers, num_filters, embedding_dim, n_mels=80, pooling_type="TSP", **kwargs):
        super(ResNetSE, self).__init__()

        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=(1, 1), padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, num_filters[0], layers[0])
        self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))

        out_dim = num_filters[3] * block.expansion * (n_mels // 8)

        if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP":
            self.pooling = Temporal_Average_Pooling()
            self.bn2 = nn.BatchNorm1d(out_dim)
            self.fc = nn.Linear(out_dim, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP":
            self.pooling = Temporal_Statistics_Pooling()
            self.bn2 = nn.BatchNorm1d(out_dim * 2)
            self.fc = nn.Linear(out_dim * 2, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP":
            self.pooling = Self_Attentive_Pooling(out_dim)
            self.bn2 = nn.BatchNorm1d(out_dim)
            self.fc = nn.Linear(out_dim, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP":
            self.pooling = Attentive_Statistics_Pooling(out_dim)
            self.bn2 = nn.BatchNorm1d(out_dim * 2)
            self.fc = nn.Linear(out_dim * 2, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        else:
            raise ValueError('{} pooling type is not defined'.format(pooling_type))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.reshape(x.shape[0], -1, x.shape[-1])

        x = self.pooling(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.bn3(x)
        return x


def Speaker_Encoder(embedding_dim=256, **kwargs):
    # Number of filters
    num_filters = [32, 64, 128, 256]
    model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, embedding_dim, **kwargs)
    return model

if __name__ == '__main__':
    model = Speaker_Encoder()
    total = sum([param.nelement() for param in model.parameters()])
    print(total / 1e6)
    data = torch.randn(10, 80, 100)
    out = model(data)
    print(data.shape)
    print(out.shape)
nnet/__init__.py
ADDED
File without changes
nnet/cnns.py
ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm

class Conv1D(nn.Conv1d):
    """
    1D Conv based on nn.Conv1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: Default 3D tensor with [N, C_out, L_out]
    If C_out=1 and squeeze is true, return 2D tensor
    """

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} requires a 2/3D tensor input".format(
                self.__class__.__name__))  # nn.Module has no __name__; use the class name
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            x = th.squeeze(x)
        return x


class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D Transposed Conv based on nn.ConvTranspose1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: 2D tensor with [N, L_out]
    """

    def __init__(self, *args, **kwargs):
        super(ConvTrans1D, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} requires a 2/3D tensor input".format(
                self.__class__.__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        # squeeze the channel dimension 1 after reconstructing the signal
        return th.squeeze(x, 1)

class TCNBlock(nn.Module):
    """
    Temporal convolutional network block,
    1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    Input: 3D tensor with [N, C_in, L_in]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super(TCNBlock, self).__init__()
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x):
        y = self.conv1x1(x)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y

class TCNBlock_Spk(nn.Module):
    """
    Temporal convolutional network block,
    1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    The first tcn block takes additional speaker embedding as inputs
    Input: 3D tensor with [N, C_in, L_in]
    Input Speaker Embedding: 2D tensor with [N, D]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super(TCNBlock_Spk, self).__init__()
        self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
        self.dilation = dilation

    def forward(self, x, aux):
        # Repeatedly concatenate the speaker embedding aux to each frame of the representation x
        T = x.shape[-1]
        aux = th.unsqueeze(aux, -1)
        aux = aux.repeat(1, 1, T)
        y = th.cat([x, aux], 1)
        y = self.conv1x1(y)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y

class ResBlock(nn.Module):
    """
    Resnet block for speaker encoder to obtain speaker embedding
    ref to
    https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
    and
    https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """
    def __init__(self, in_dims, out_dims):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        if in_dims != out_dims:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        else:
            self.downsample = False

    def forward(self, x):
        y = self.conv1(x)
        y = self.batch_norm1(y)
        y = self.prelu1(y)
        y = self.conv2(y)
        y = self.batch_norm2(y)
        if self.downsample:
            y += self.conv_downsample(x)
        else:
            y += x
        y = self.prelu2(y)
        return self.maxpool(y)
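A quick shape check for the TCN blocks above, as a sketch; the sizes mirror the docstrings ([N, C, T] with the SpEx+ defaults) and the tensors are random placeholders:

# Illustrative smoke test for the TCN blocks; all sizes are placeholders.
import torch as th
from nnet.cnns import TCNBlock, TCNBlock_Spk

x = th.randn(2, 256, 100)    # [N, C_in, T]
spk = th.randn(2, 256)       # [N, spk_embed_dim]
blk = TCNBlock(in_channels=256, conv_channels=512, kernel_size=3, dilation=2)
blk_spk = TCNBlock_Spk(in_channels=256, spk_embed_dim=256, conv_channels=512)
print(blk(x).shape)          # torch.Size([2, 256, 100]) -- residual path keeps the shape
print(blk_spk(x, spk).shape) # torch.Size([2, 256, 100])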
nnet/norm.py
ADDED
@@ -0,0 +1,59 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn

class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))  # nn.Module has no __name__ attribute
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x

class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # register under the names actually used in forward (the original registered "weight"/"bias")
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        # calculate the mean and variance over the channel and time dimensions
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
               "elementwise_affine={elementwise_affine}".format(**self.__dict__)
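As a sanity check, a short sketch contrasting the two norms; it relies only on the shapes documented above and uses random data:

# Illustrative check: both norms preserve the [N, C, T] shape;
# gLN normalizes over (C, T) jointly, cLN over C per frame.
import torch as th
from nnet.norm import ChannelwiseLayerNorm, GlobalLayerNorm

x = th.randn(4, 256, 100) * 3.0 + 1.0
gln = GlobalLayerNorm(256, elementwise_affine=False)
cln = ChannelwiseLayerNorm(256)
y = gln(x)
print(y.shape, float(y.mean()), float(y.var()))  # shape kept; ~0 mean, ~1 variance
print(cln(x).shape)                              # torch.Size([4, 256, 100])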
nnet/pooling.py
ADDED
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling

class Temporal_Average_Pooling(nn.Module):
    def __init__(self, **kwargs):
        """TAP
        Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
        Link: https://arxiv.org/pdf/1903.12058.pdf
        """
        super(Temporal_Average_Pooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Average Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels)
        """
        x = torch.mean(x, axis=2)
        return x


class Temporal_Statistics_Pooling(nn.Module):
    def __init__(self, **kwargs):
        """TSP
        Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
        Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(Temporal_Statistics_Pooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Statistics Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels*2)
        """
        mean = torch.mean(x, axis=2)
        var = torch.var(x, axis=2)
        x = torch.cat((mean, var), axis=1)
        return x


''' Self attentive weighted mean pooling.
'''
class Self_Attentive_Pooling(nn.Module):
    def __init__(self, dim, **kwargs):
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        # attention dim = 128
        super(Self_Attentive_Pooling, self).__init__()
        self.linear1 = nn.Conv1d(dim, dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(dim, dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        return mean


''' Attentive weighted mean and standard deviation pooling.
'''
class Attentive_Statistics_Pooling(nn.Module):
    def __init__(self, dim, **kwargs):
        # Use AttentiveStatisticsPooling and BatchNorm1d from speechbrain
        super(Attentive_Statistics_Pooling, self).__init__()
        self.pooling = AttentiveStatisticsPooling(dim)

    def forward(self, x):
        x = self.pooling(x)
        return x

# class Attentive_Statistics_Pooling(nn.Module):
#     def __init__(self, dim, **kwargs):
#         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
#         # attention dim = 128
#         super(Attentive_Statistics_Pooling, self).__init__()
#         self.linear1 = nn.Conv1d(dim, dim, kernel_size=1)  # equals W and b in the paper
#         self.linear2 = nn.Conv1d(dim, dim, kernel_size=1)  # equals V and k in the paper
#
#     def forward(self, x):
#         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
#         alpha = torch.tanh(self.linear1(x))
#         alpha = torch.softmax(self.linear2(alpha), dim=2)
#         mean = torch.sum(alpha * x, dim=2)
#         residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
#         std = torch.sqrt(residuals.clamp(min=1e-9))
#         return torch.cat([mean, std], dim=1)


if __name__ == "__main__":
    data = torch.randn(10, 128, 100)
    pooling = Self_Attentive_Pooling(128)
    out = pooling(data)
    print(data.shape)
    print(out.shape)
nnet/speaker_encoder.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torchaudio
import torch.nn as nn
from torch.nn import functional as F
from .ResNet34 import Speaker_Encoder


class Speaker_Model(torch.nn.Module):
    # class Speaker_Model(LightningModule):
    def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
        super().__init__()
        # self.save_hyperparameters()

        self.pooling_type = pooling_type
        self.spk_embed_dim = spk_embed_dim
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        sr = self.sample_rate

        self.mel_trans = torch.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
                                                 win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
                                                 window_fn=torch.hamming_window, n_mels=self.n_mels)
        )
        self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type, 'n_mels': self.n_mels}

        self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))

class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # make kernel
        # In PyTorch, the convolution operation uses cross-correlation, so the filter is flipped.
        self.register_buffer(
            'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # reflect padding to match lengths of in/out
        inputs = inputs.unsqueeze(1)
        inputs = F.pad(inputs, (1, 0), 'reflect')
        return F.conv1d(inputs, self.flipped_filter).squeeze(1)
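Note that this copy of Speaker_Model defines no forward pass (the variant embedded in nnet/spex_plus.py adds extract_speaker_embedding); the following is only a sketch of the mel front end it builds, with illustrative sizes at an assumed 16 kHz:

# Illustrative front-end shapes; one second of random audio stands in for speech.
import torch
from nnet.speaker_encoder import Speaker_Model

model = Speaker_Model(pooling_type='ASP', spk_embed_dim=256, sample_rate=16000, n_mels=80)
wav = torch.randn(2, 16000)                # [batch, samples]
mel = model.mel_trans(wav)                 # roughly [2, 80, 101] with the 10 ms hop
emb = model.speaker_encoder(model.instancenorm((mel + 1e-6).log()))
print(mel.shape, emb.shape)                # embedding is [2, 256]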
nnet/spex_plus.py
ADDED
@@ -0,0 +1,247 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock

import torchaudio
from .ResNet34 import Speaker_Encoder
# from .sunine.trainer.utils import PreEmphasis


# Two possibilities to consider: in the frequency domain there is unlikely to be any
# notion of multiple time scales, so the speaker branch surely takes the spectrum
# directly -- but what about the speech branch?
# Watch the dimension order: is it B x N x T or B x T x N?

class SpEx_Plus(nn.Module):
    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 sample_rate=16000,
                 n_mels=80,
                 causal=False,
                 ):
        super(SpEx_Plus, self).__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # before repeat blocks, always cLN
        self.ln = ChannelwiseLayerNorm(3 * N)
        # n x N x T => n x O x T
        self.proj = Conv1D(3 * N, O, 1)
        self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)
        # using ConvTrans1D: n x N x T => n x 1 x To
        # To = (T - 1) * L // 2 + L
        #############################################################
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        # self.spk_encoder = nn.Sequential(
        #     ChannelwiseLayerNorm(3*N),
        #     Conv1D(3*N, O, 1),
        #     ResBlock(O, O),
        #     ResBlock(O, P),
        #     ResBlock(P, P),
        #     Conv1D(P, spk_embed_dim, 1),
        # )

        # self.pred_linear = nn.Linear(spk_embed_dim, num_spks)

        # switched to a pretrained speaker model
        # Two possibilities to consider: the frequency domain is unlikely to have multiple
        # time scales, so the speaker branch surely takes the spectrum directly -- what about the speech branch?
        # /work105/youzhenghai/model/resnet_asp_aam_adamw_welr
        # import ..sunine/trainer/speaker encoder
        # no need to care about **kwargs; just locate self.hparams and adapt it following main_infer
        #############################################################

        # # 1. Acoustic Feature
        # self.mel_trans = th.nn.Sequential(
        #     PreEmphasis(),
        #     torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=512,
        #         win_length=400, hop_length=160, window_fn=th.hamming_window, n_mels=self.n_mels)
        # )

        # self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        # # set the hyperparameters at the call site; remember to pass them in as arguments later
        # self.hparams = {'embedding_dim': spk_embed_dim, 'pooling_type': 'ASP', 'n_mels': self.n_mels}
        # # call with **self.hparams
        # self.speaker_encoder = Speaker_Encoder(**self.hparams)
        self.speaker_embedding_extracter = Speaker_Model(pooling_type='ASP', spk_embed_dim=spk_embed_dim, sample_rate=self.sample_rate, n_mels=self.n_mels)
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)

        #############################################################

        # # 3. Loss / Classifier
        # if not self.hparams.evaluate:
        #     LossFunction = importlib.import_module('trainer.loss.'+self.hparams.loss_type).__getattribute__('LossFunction')
        #     self.loss = LossFunction(**dict(self.hparams))


    def _build_stacks(self, num_blocks, **block_kwargs):
        """
        Stack B numbers of TCN blocks; the first TCN block takes the speaker embedding
        """
        blocks = [
            TCNBlock(**block_kwargs, dilation=(2**b))
            for b in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)
    # Watch the dimension order: is it B x N x T or B x T x N?


    def forward(self, x, aux, aux_len):
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accepts 1/2D tensor as input, but got {:d}".format(
                    self.__class__.__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)

        # n x 1 x S => n x N x T
        w1 = F.relu(self.encoder_1d_short(x))
        T = w1.shape[-1]
        xlen1 = x.shape[-1]
        xlen2 = (T - 1) * (self.L1 // 2) + self.L2
        xlen3 = (T - 1) * (self.L1 // 2) + self.L3
        w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))

        # n x 3N x T
        y = self.ln(th.cat([w1, w2, w3], 1))
        # n x O x T
        y = self.proj(y)

        # speaker encoder (share params from speech encoder)
        # aux_w1 = F.relu(self.encoder_1d_short(aux))
        # aux_T_shape = aux_w1.shape[-1]
        # aux_len1 = aux.shape[-1]
        # aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
        # aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
        # aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
        # aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))

        # spk_encoder + mean pooling
        # aux = self.spk_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1))
        # aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        # aux_T = ((aux_T // 3) // 3) // 3
        # aux = th.sum(aux, -1)/aux_T.view(-1,1).float()

        # spk_encoder + TAP pooling
        aux = self.speaker_embedding_extracter(aux)

        # aux = torch.mean(aux, axis=0)
        # aux = aux.cpu().detach().numpy()
        # no reshape needed: N x D is already the correct shape
        # aux = aux.reshape(-1, self.hparams.nPerSpeaker, self.spk_embed_dim)
        # loss, acc = self.loss(x, label)
        # return loss.mean(), acc
        # consider whether the loss should also live here

        y = self.conv_block_1(y, aux)
        y = self.conv_block_1_other(y)
        y = self.conv_block_2(y, aux)
        y = self.conv_block_2_other(y)
        y = self.conv_block_3(y, aux)
        y = self.conv_block_3_other(y)
        y = self.conv_block_4(y, aux)
        y = self.conv_block_4_other(y)

        # n x N x T
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        S1 = w1 * m1
        S2 = w2 * m2
        S3 = w3 * m3

        return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)

class PreEmphasis(th.nn.Module):
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # make kernel
        # In PyTorch, the convolution operation uses cross-correlation, so the filter is flipped.
        self.register_buffer(
            'flipped_filter', th.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # reflect padding to match lengths of in/out
        inputs = inputs.unsqueeze(1)
        inputs = F.pad(inputs, (1, 0), 'reflect')
        return F.conv1d(inputs, self.flipped_filter).squeeze(1)


class Speaker_Model(nn.Module):
    # class Speaker_Model(LightningModule):
    def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
        super().__init__()
        # self.save_hyperparameters()

        self.pooling_type = pooling_type
        self.spk_embed_dim = spk_embed_dim
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        sr = self.sample_rate

        self.mel_trans = th.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
                                                 win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
                                                 window_fn=th.hamming_window, n_mels=self.n_mels)
        )
        self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type, 'n_mels': self.n_mels}

        self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))

    def extract_speaker_embedding(self, data):
        x = data.reshape(-1, data.size()[-1])
        x = self.mel_trans(x) + 1e-6
        x = x.log()
        x = self.instancenorm(x)
        x = self.speaker_encoder(x)
        return x

    def forward(self, x):
        x = self.extract_speaker_embedding(x)
        return x
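A minimal forward-pass sketch for SpEx_Plus with the defaults above; the 16 kHz lengths are illustrative and the weights are untrained:

# Illustrative end-to-end shape check; random waveforms stand in for real audio.
import torch as th
from nnet.spex_plus import SpEx_Plus

net = SpEx_Plus(num_spks=101)         # defaults: L1=20, N=256, B=8, ...
mix = th.randn(2, 32000)              # [batch, samples], 2 s at 16 kHz
aux = th.randn(2, 16000)              # enrollment utterances
s1, s2, s3, spk_logits = net(mix, aux, th.tensor([16000, 16000]))
print(s1.shape, spk_logits.shape)     # s1: [2, 32000]; logits: [2, 101]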
requirement.txt
ADDED
@@ -0,0 +1,5 @@
torch==1.8.0
torchaudio==0.8.0
speechbrain==0.5.10
soundfile
gradio
utils/__init__.py
ADDED
File without changes
utils/audio.py
ADDED
@@ -0,0 +1,124 @@
#!/usr/bin/env python

import os
import numpy as np
import soundfile as sf

def write_wav(fname, samps, sample_rate=16000, normalize=True):
    """
    Write wav files in float32, support single/multi-channel
    """
    # The wham and whamr mixture and clean data are float32, so scipy.io.wavfile
    # cannot be used to read/write int16; we switched to soundfile. Although the
    # reference speech is int16, soundfile can still read it and outputs floats.
    fdir = os.path.dirname(fname)
    if fdir and not os.path.exists(fdir):
        os.makedirs(fdir)
    sf.write(fname, samps, sample_rate, subtype='FLOAT')


def read_wav(fname, normalize=True, return_rate=False):
    """
    Read wave files (support multi-channel)
    """
    # See the note in write_wav: soundfile is used instead of scipy.io.wavfile.
    samps, samp_rate = sf.read(fname)
    if return_rate:
        return samp_rate, samps
    return samps

def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2):
    """
    Parse kaldi's script(.scp) file
    If num_tokens >= 2, the function will check the token number
    """
    scp_dict = dict()
    line = 0
    with open(scp_path, "r") as f:
        for raw_line in f:
            scp_tokens = raw_line.strip().split()
            line += 1
            if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or len(scp_tokens) < 2:
                raise RuntimeError(
                    "For {}, format error in line[{:d}]: {}".format(
                        scp_path, line, raw_line))
            if num_tokens == 2:
                key, value = scp_tokens
            else:
                key, value = scp_tokens[0], scp_tokens[1:]
            if key in scp_dict:
                raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
                    key, scp_path))
            scp_dict[key] = value_processor(value)
    return scp_dict


class Reader(object):
    """
    Basic Reader Class
    """

    def __init__(self, scp_path, value_processor=lambda x: x):
        self.index_dict = parse_scripts(
            scp_path, value_processor=value_processor, num_tokens=2)
        self.index_keys = list(self.index_dict.keys())

    def _load(self, key):
        # return path
        return self.index_dict[key]

    # number of utterances
    def __len__(self):
        return len(self.index_dict)

    # avoid key error
    def __contains__(self, key):
        return key in self.index_dict

    # sequential index
    def __iter__(self):
        for key in self.index_keys:
            yield key, self._load(key)

    # random index, support str/int as index
    def __getitem__(self, index):
        if type(index) not in [int, str]:
            raise IndexError("Unsupported index type: {}".format(type(index)))
        if type(index) == int:
            # from int index to key
            num_utts = len(self.index_keys)
            if index >= num_utts or index < 0:
                raise KeyError(
                    "Integer index out of range, {:d} vs {:d}".format(
                        index, num_utts))
            index = self.index_keys[index]
        if index not in self.index_dict:
            raise KeyError("Missing utterance {}!".format(index))
        return self._load(index)


class WaveReader(Reader):
    """
    Sequential/Random Reader for single channel wave
    Format of wav.scp follows Kaldi's definition:
        key1 /path/to/wav
        ...
    """

    def __init__(self, wav_scp, sample_rate=None, normalize=True):
        super(WaveReader, self).__init__(wav_scp)
        self.samp_rate = sample_rate
        self.normalize = normalize

    def _load(self, key):
        # return C x N or N
        samp_rate, samps = read_wav(
            self.index_dict[key], normalize=self.normalize, return_rate=True)
        # if a samp_rate was given, check it
        if self.samp_rate is not None and samp_rate != self.samp_rate:
            raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format(
                samp_rate, self.samp_rate))
        return samps
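A short usage sketch for WaveReader; the wav.scp path and utterance keys are hypothetical placeholders:

# Hypothetical wav.scp usage; "data/wav.scp" is a placeholder path.
from utils.audio import WaveReader

reader = WaveReader("data/wav.scp", sample_rate=16000)  # raises on rate mismatch
print(len(reader))             # number of utterances listed in the scp
for key, samps in reader:      # sequential iteration yields (key, samples)
    print(key, samps.shape)
    break
samps = reader[0]              # integer and string indexing both work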
utils/dataset copy.py
ADDED
@@ -0,0 +1,284 @@
#!/usr/bin/env python

import random
import torch as th
import numpy as np

from torch.utils.data.dataloader import default_collate
import torch.utils.data as dat
from torch.nn.utils.rnn import pad_sequence

from .audio import WaveReader

import soundfile as sf

# random_seed = 1453
# random.seed(random_seed)

def make_dataloader(train=True,
                    utt_scp_file=None,
                    spk_list=None,
                    sample_rate=16000,
                    num_workers=4,
                    chunk_size=32000,
                    batch_size=16):
    dataset = Dataset(utt_scp_file=utt_scp_file,
                      spk_list=spk_list,
                      chunk_size=chunk_size,
                      sample_rate=sample_rate)
    return DataLoader(dataset,
                      train=train,
                      chunk_size=chunk_size,
                      batch_size=batch_size,
                      num_workers=num_workers)

class Dataset(object):
    """
    Per Utterance Loader
    """
    def __init__(self, utt_scp_file="", spk_list=None, chunk_size=32000, sample_rate=8000):
        self.sample_rate = sample_rate
        self.spk_list = self._load_spk(spk_list)

        self.seg_least = int(chunk_size // 2)

        # self.mix = WaveReader(mix_scp, sample_rate=sample_rate)
        # self.ref = WaveReader(ref_scp, sample_rate=sample_rate)
        # self.aux = WaveReader(aux_scp, sample_rate=sample_rate)

        with open(utt_scp_file, 'r') as f:
            lines = f.readlines()
        self.data = []
        for line in lines:
            parts = line.strip().split()
            sentence_id = parts[0]
            sentence_path = parts[1]
            data_len = parts[2]
            spk_id = (sentence_id.split('-')[0])[1:5]
            self.data.append((sentence_id, spk_id, sentence_path, data_len))

        if not self.data:
            raise ValueError("No valid lines found in the input file.")
        self.total_lines = len(self.data)

    def _load_spk(self, spk_list_path):
        if spk_list_path is None:
            return []
        lines = open(spk_list_path).readlines()
        new_lines = []
        for line in lines:
            new_lines.append(line.strip())

        return new_lines

    def __len__(self):
        return len(self.data)

    def _get_segment_start_stop(self, seg_len, length):
        if seg_len is not None:
            start = random.randint(0, length - seg_len)
            stop = start + seg_len
        else:
            start = 0
            stop = None
        return start, stop

    def _mix(self, sources_list):
        # if self.seg_len:
        #     mix_length = self.seg_len
        # else:
        #     mix_length = self.common_length
        mix_length = self.common_length
        mixture = np.zeros(mix_length)
        for i, _ in enumerate(sources_list):
            mixture += sources_list[i]

        return mixture

    def __getitem__(self, idx):
        source_id, source_spk, source_path, all_source_length = self.data[idx]
        all_source_length = int(all_source_length)
        spk_idx = self.spk_list.index(source_spk)

        other_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_lines)
            if self.data[random_idx][1] != source_spk:
                other_id, other_spk, other_path, other_length = self.data[random_idx]
                other_length = int(other_length)
                if other_length > self.seg_least:
                    break
            other_counter += 1
            if other_counter >= self.total_lines:
                raise ValueError("All data are too short to mix")

        enroll_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_lines)
            if self.data[random_idx][1] == source_spk:
                enroll_id, enroll_spk, enroll_path, all_enroll_length = self.data[random_idx]
                all_enroll_length = int(all_enroll_length)
                if all_enroll_length > self.seg_least:
                    break
            enroll_counter += 1
            if enroll_counter >= self.total_lines:
                raise ValueError("All data are too short to enroll")
        # lengths = [all_source_length, other_length]

        if all_source_length >= other_length:
            self.common_length = other_length
            start, stop = self._get_segment_start_stop(other_length, all_source_length)
            source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
            other_tmp, _ = sf.read(other_path, dtype="float32")
        elif all_source_length <= other_length:
            self.common_length = all_source_length
            start, stop = self._get_segment_start_stop(all_source_length, other_length)
            source_tmp, _ = sf.read(source_path, dtype="float32")
            other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)

        source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]

        other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]

        mixture = self._mix([source, other])
        mixture = mixture.astype(np.float32)

        enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
        enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]

        return {
            "mix": mixture,
            "ref": source,
            "aux": enroll,
            "aux_len": len(enroll),
            "spk_idx": spk_idx
        }

class ChunkSplitter(object):
    """
    Split utterance into small chunks
    """
    def __init__(self, chunk_size, train=True, least=16000):
        self.chunk_size = chunk_size
        self.least = least
        self.train = train

    def _make_chunk(self, eg, s):
        """
        Make a chunk instance, which contains:
            "mix": ndarray,
            "ref": [ndarray...]
        """
        chunk = dict()
        chunk["mix"] = eg["mix"][s:s + self.chunk_size]
        chunk["ref"] = eg["ref"][s:s + self.chunk_size]
        chunk["aux"] = eg["aux"]
        chunk["aux_len"] = eg["aux_len"]
        chunk["valid_len"] = int(self.chunk_size)
        chunk["spk_idx"] = eg["spk_idx"]
        return chunk

    def split(self, eg):
        N = eg["mix"].size
        # too short, throw away
        if N < self.least:
            return []
        chunks = []
        # padding zeros
        if N < self.chunk_size:
            P = self.chunk_size - N
            chunk = dict()
            chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
            chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
            chunk["aux"] = eg["aux"]
            chunk["aux_len"] = eg["aux_len"]
chunk["aux_len"] = eg["aux_len"]
|
| 203 |
+
chunk["valid_len"] = int(N)
|
| 204 |
+
chunk["spk_idx"] = eg["spk_idx"]
|
| 205 |
+
chunks.append(chunk)
|
| 206 |
+
else:
|
| 207 |
+
# random select start point for training
|
| 208 |
+
s = random.randint(0, N % self.least) if self.train else 0
|
| 209 |
+
while True:
|
| 210 |
+
if s + self.chunk_size > N:
|
| 211 |
+
break
|
| 212 |
+
chunk = self._make_chunk(eg, s)
|
| 213 |
+
chunks.append(chunk)
|
| 214 |
+
s += self.least
|
| 215 |
+
return chunks
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class DataLoader(object):
|
| 219 |
+
"""
|
| 220 |
+
Online dataloader for chunk-level
|
| 221 |
+
"""
|
| 222 |
+
def __init__(self,
|
| 223 |
+
dataset,
|
| 224 |
+
num_workers=4,
|
| 225 |
+
chunk_size=32000,
|
| 226 |
+
batch_size=16,
|
| 227 |
+
train=True):
|
| 228 |
+
self.batch_size = batch_size
|
| 229 |
+
self.train = train
|
| 230 |
+
self.splitter = ChunkSplitter(chunk_size,
|
| 231 |
+
train=train,
|
| 232 |
+
least=chunk_size // 2)
|
| 233 |
+
# just return batch of egs, support multiple workers
|
| 234 |
+
self.eg_loader = dat.DataLoader(dataset,
|
| 235 |
+
batch_size=batch_size // 2,
|
| 236 |
+
num_workers=num_workers,
|
| 237 |
+
shuffle=train,
|
| 238 |
+
collate_fn=self._collate)
|
| 239 |
+
|
| 240 |
+
def _collate(self, batch):
|
| 241 |
+
"""
|
| 242 |
+
Online split utterances
|
| 243 |
+
"""
|
| 244 |
+
chunk = []
|
| 245 |
+
for eg in batch:
|
| 246 |
+
chunk += self.splitter.split(eg)
|
| 247 |
+
return chunk
|
| 248 |
+
|
| 249 |
+
def _pad_aux(self, chunk_list):
|
| 250 |
+
lens_list = []
|
| 251 |
+
for chunk_item in chunk_list:
|
| 252 |
+
lens_list.append(chunk_item['aux_len'])
|
| 253 |
+
max_len = np.max(lens_list)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
for idx in range(len(chunk_list)):
|
| 257 |
+
P = max_len - len(chunk_list[idx]["aux"])
|
| 258 |
+
chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
|
| 259 |
+
|
| 260 |
+
return chunk_list
|
| 261 |
+
|
| 262 |
+
def _merge(self, chunk_list):
|
| 263 |
+
"""
|
| 264 |
+
Merge chunk list into mini-batch
|
| 265 |
+
"""
|
| 266 |
+
N = len(chunk_list)
|
| 267 |
+
if self.train:
|
| 268 |
+
random.shuffle(chunk_list)
|
| 269 |
+
blist = []
|
| 270 |
+
for s in range(0, N - self.batch_size + 1, self.batch_size):
|
| 271 |
+
# padding aux info
|
| 272 |
+
#self._pad_aux(chunk_list[s:s + self.batch_size])
|
| 273 |
+
batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
|
| 274 |
+
blist.append(batch)
|
| 275 |
+
rn = N % self.batch_size
|
| 276 |
+
return blist, chunk_list[-rn:] if rn else []
|
| 277 |
+
|
| 278 |
+
def __iter__(self):
|
| 279 |
+
chunk_list = []
|
| 280 |
+
for chunks in self.eg_loader:
|
| 281 |
+
chunk_list += chunks
|
| 282 |
+
batch, chunk_list = self._merge(chunk_list)
|
| 283 |
+
for obj in batch:
|
| 284 |
+
yield obj
|
utils/dataset.py
ADDED
@@ -0,0 +1,402 @@
#!/usr/bin/env python

import random
import torch as th
import numpy as np

from torch.utils.data.dataloader import default_collate
import torch.utils.data as dat
from torch.nn.utils.rnn import pad_sequence

from .audio import WaveReader

import soundfile as sf

# random_seed = 1453
# random.seed(random_seed)

# "aux_len": all_enroll_length,

EPS = 1e-10

def make_dataloader(train=True,
                    mix_scp_file=None,
                    enroll_scp_file=None,
                    noise_scp_file=None,
                    spk_list=None,
                    sample_rate=16000,
                    num_workers=4,
                    chunk_size=32000,
                    batch_size=16):
    dataset = Dataset(mix_scp_file=mix_scp_file,
                      enroll_scp_file=enroll_scp_file,
                      noise_scp_file=noise_scp_file,
                      spk_list=spk_list,
                      chunk_size=chunk_size,
                      sample_rate=sample_rate)
    return DataLoader(dataset,
                      train=train,
                      chunk_size=chunk_size,
                      batch_size=batch_size,
                      num_workers=num_workers)

class Dataset(object):
    """
    Per Utterance Loader
    """
    def __init__(self, mix_scp_file="", enroll_scp_file="", noise_scp_file="", spk_list=None, chunk_size=32000, sample_rate=8000):
        self.sample_rate = sample_rate
        self.spk_list = self._load_spk(spk_list)
        # a sampled segment must be at least half a chunk long
        self.seg_least = int(chunk_size // 2)

        # each scp line: <utt_id> <wav_path> <num_samples>
        with open(mix_scp_file, 'r') as f:
            lines = f.readlines()
        self.data = []

        for line in lines:
            parts = line.strip().split()
            sentence_id = parts[0]
            sentence_path = parts[1]
            data_len = parts[2]
            spk_id = (sentence_id.split('-')[0])[1:5]
            self.data.append((sentence_id, spk_id, sentence_path, data_len))

        with open(enroll_scp_file, 'r') as f:
            enroll_lines = f.readlines()
        self.enroll_data = []

        for line in enroll_lines:
            parts = line.strip().split()
            sentence_id = parts[0]
            sentence_path = parts[1]
            data_len = parts[2]
            spk_id = (sentence_id.split('-')[0])[1:5]
            self.enroll_data.append((sentence_id, spk_id, sentence_path, data_len))

        with open(noise_scp_file, 'r') as f:
            noise_lines = f.readlines()
        self.noise_data = []

        for line in noise_lines:
            parts = line.strip().split()
            sentence_id = parts[0]
            sentence_path = parts[1]
            data_len = parts[2]
            # spk_id = (sentence_id.split('-')[0])[1:5]
            self.noise_data.append((sentence_id, sentence_path, data_len))

        self.total_lines = len(self.data)
        self.total_enroll = self._enroll_data_len()
        self.total_noise = self._noise_data_len()

        if not self.data:
            raise ValueError("No valid lines found in the input file.")

    def _load_spk(self, spk_list_path):
        if spk_list_path is None:
            return []
        lines = open(spk_list_path).readlines()
        new_lines = []
        for line in lines:
            new_lines.append(line.strip())
        return new_lines

    def __len__(self):
        return len(self.data)

    def _enroll_data_len(self):
        return len(self.enroll_data)

    def _noise_data_len(self):
        return len(self.noise_data)

    def _get_segment_start_stop(self, seg_len, length):
        if seg_len is not None:
            start = random.randint(0, length - seg_len)
            stop = start + seg_len
        else:
            start = 0
            stop = None
        return start, stop

    def _mix(self, sources_list):
        # if self.seg_len:
        #     mix_length = self.seg_len
        # else:
        #     mix_length = self.common_length
        mix_length = self.common_length
        mixture = np.zeros(mix_length)
        for i, _ in enumerate(sources_list):
            mixture += sources_list[i]
        return mixture

    def __getitem__(self, idx):
        source_id, source_spk, source_path, all_source_length = self.data[idx]
        all_source_length = int(all_source_length)
        spk_idx = self.spk_list.index(source_spk)

        # pick an interfering utterance from a different speaker
        other_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_lines)
            if self.data[random_idx][1] != source_spk:
                other_id, other_spk, other_path, other_length = self.data[random_idx]
                other_length = int(other_length)
                if other_length > self.seg_least:
                    break
            other_counter += 1
            if other_counter >= self.total_lines:
                raise ValueError("All utterances are too short to mix")

        # crop the longer utterance to the length of the shorter one
        if all_source_length >= other_length:
            self.common_length = other_length
            start, stop = self._get_segment_start_stop(self.common_length, all_source_length)
            source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
            other_tmp, _ = sf.read(other_path, dtype="float32")
        else:
            self.common_length = all_source_length
            start, stop = self._get_segment_start_stop(self.common_length, other_length)
            source_tmp, _ = sf.read(source_path, dtype="float32")
            other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)

        # pick a noise clip that is long enough to cover the mixture
        noise_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_noise)

            noise_id, noise_path, all_noise_length = self.noise_data[random_idx]
            all_noise_length = int(all_noise_length)

            if all_noise_length >= self.common_length:
                break
            noise_counter += 1
            if noise_counter >= self.total_noise:
                raise ValueError("No noise clip is long enough for mixing")

        # pick an enrollment utterance from the same speaker
        enroll_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_enroll)
            if self.enroll_data[random_idx][1] == source_spk:
                enroll_id, enroll_spk, enroll_path, all_enroll_length = self.enroll_data[random_idx]
                all_enroll_length = int(all_enroll_length)
                break

            enroll_counter += 1
            if enroll_counter >= self.total_enroll:
                raise ValueError("No enrollment utterance found for this speaker")

        # randomly pick one channel of each multi-channel recording
        source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
        other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]

        noise_start, noise_stop = self._get_segment_start_stop(self.common_length, all_noise_length)
        noise, _ = sf.read(noise_path, dtype="float32", start=noise_start, stop=noise_stop)  # single channel?
        # noise = noise_tmp[:, np.random.randint(0, noise_tmp.shape[1])]
        # other_noise = self._mix([other, noise])
        desired_snr = np.random.uniform(-4, 4)  # target SNR in dB
        current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
        scale_factor = 10 ** ((current_snr - desired_snr) / 20)
        scaled_noise = noise * scale_factor

        # sanity check of the resulting SNR (unused)
        snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(scaled_noise ** 2) + EPS) + EPS)
        mixture = self._mix([source, other, scaled_noise])

        mixture = mixture.astype(np.float32)

        enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
        enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]

        return {
            "mix": mixture,
            "ref": source,
            "aux": enroll,
            "aux_len": all_enroll_length,
            "spk_idx": spk_idx
        }

class ChunkSplitter(object):
    """
    Split utterance into small chunks
    """
    def __init__(self, chunk_size, train=True, least=16000):
        self.chunk_size = chunk_size
        self.least = least
        self.train = train

    def _make_chunk(self, eg, s):
        """
        Make a chunk instance, which contains:
            "mix": ndarray,
            "ref": [ndarray...]
        """
        chunk = dict()
        chunk["mix"] = eg["mix"][s:s + self.chunk_size]
        chunk["ref"] = eg["ref"][s:s + self.chunk_size]
        chunk["aux"] = eg["aux"]
        chunk["aux_len"] = eg["aux_len"]
        chunk["valid_len"] = int(self.chunk_size)
        chunk["spk_idx"] = eg["spk_idx"]
        return chunk

    def split(self, eg):
        N = eg["mix"].size
        # too short, throw away
        if N < self.least:
            return []
        chunks = []
        # padding zeros
        if N < self.chunk_size:
            P = self.chunk_size - N
            chunk = dict()
            chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
            chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
            chunk["aux"] = eg["aux"]
            chunk["aux_len"] = eg["aux_len"]
            chunk["valid_len"] = int(N)
            chunk["spk_idx"] = eg["spk_idx"]
            chunks.append(chunk)
        # else:
        #     # random select start point for training
        #     s = random.randint(0, N % self.least) if self.train else 0
        #     while True:
        #         if s + self.chunk_size > N:
        #             break
        #         chunk = self._make_chunk(eg, s)
        #         chunks.append(chunk)
        #         s += self.least
        #     return chunks

        else:
            if self.train:
                # randomly select a single start point for training
                s = random.randint(0, N - self.chunk_size)
                chunk = self._make_chunk(eg, s)
                chunks.append(chunk)
            else:
                s = 0
                while True:
                    if s + self.chunk_size > N:
                        break
                    chunk = self._make_chunk(eg, s)
                    chunks.append(chunk)
                    s += self.least
        return chunks

class DataLoader(object):
    """
    Online dataloader for chunk-level
    """
    def __init__(self,
                 dataset,
                 num_workers=4,
                 chunk_size=32000,
                 batch_size=16,
                 train=True):
        self.batch_size = batch_size
        self.train = train
        self.splitter = ChunkSplitter(chunk_size,
                                      train=train,
                                      least=chunk_size // 2)
        # just return batch of egs, support multiple workers
        self.eg_loader = dat.DataLoader(dataset,
                                        batch_size=batch_size // 2,
                                        num_workers=num_workers,
                                        shuffle=train,
                                        collate_fn=self._collate)

    def _collate(self, batch):
        """
        Online split utterances
        """
        chunk = []
        for eg in batch:
            chunk += self.splitter.split(eg)
        return chunk

    def _pad_aux(self, chunk_list):
        lens_list = []
        for chunk_item in chunk_list:
            lens_list.append(chunk_item['aux_len'])
        max_len = np.max(lens_list)
        # pad 0
        for idx in range(len(chunk_list)):
            P = max_len - len(chunk_list[idx]["aux"])
            chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
        # # pad circle
        # for idx in range(len(chunk_list)):
        #     P = max_len - len(chunk_list[idx]["aux"])
        #     original_aux_len = len(chunk_list[idx]["aux"])
        #     # cyclically repeat the original content to fill the padding
        #     for i in range(P):
        #         chunk_list[idx]["aux"].append(chunk_list[idx]["aux"][i % original_aux_len])

        return chunk_list

    def _merge(self, chunk_list):
        """
        Merge chunk list into mini-batch
        """
        N = len(chunk_list)
        if self.train:
            random.shuffle(chunk_list)
        blist = []
        for s in range(0, N - self.batch_size + 1, self.batch_size):
            # padding aux info
            # self._pad_aux(chunk_list[s:s + self.batch_size])
            batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
            blist.append(batch)
        rn = N % self.batch_size
        return blist, chunk_list[-rn:] if rn else []

    def __iter__(self):
        chunk_list = []
        for chunks in self.eg_loader:
            chunk_list += chunks
            batch, chunk_list = self._merge(chunk_list)
            for obj in batch:
                yield obj



# def snr_xy(x, y):
#     return 10 * np.log10(np.mean(x ** 2) / (np.mean(y ** 2) + EPS) + EPS)

# def main(args):
#     wham_noise_dir = args.wham_dir
#     # Get train dir
#     subdir = os.path.join(wham_noise_dir, 'tr')
#     # List files in that dir
#     sound_paths = glob.glob(os.path.join(subdir, '**/*.wav'),
#                             recursive=True)
#     # Avoid running this script if it has already been run
#     if len(sound_paths) == 60000:
#         print("It appears that augmented files have already been generated.\n"
#               "Skipping data augmentation.")
#         return
#     elif len(sound_paths) != 20000:
#         print("It appears that augmented files have not been generated properly\n"
#               "Resuming augmentation.")
#         originals = [x for x in sound_paths if 'sp' not in x]
#         to_be_removed_08 = [x.replace('sp08', '') for x in sound_paths if 'sp08' in x]
#         to_be_removed_12 = [x.replace('sp12', '') for x in sound_paths if 'sp12' in x]
#         sound_paths_08 = list(set(originals) - set(to_be_removed_08))
#         sound_paths_12 = list(set(originals) - set(to_be_removed_12))
#         augment_noise(sound_paths_08, 0.8)
#         augment_noise(sound_paths_12, 1.2)
#     else:
#         print(f'Augmenting {subdir} files')
#         # Transform audio speed
#         augment_noise(sound_paths, 0.8)
#         augment_noise(sound_paths, 1.2)
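
The __getitem__ above draws a target SNR uniformly from [-4, 4] dB and rescales the noise amplitude to reach it: scaling by 10 ** ((current - desired) / 20) shifts the power ratio by exactly (current - desired) dB. A standalone sanity check of that arithmetic on synthetic signals (not project data):

import numpy as np

EPS = 1e-10
rng = np.random.default_rng(0)
source = rng.standard_normal(16000)
noise = 0.1 * rng.standard_normal(16000)

desired_snr = 5.0  # dB
current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
scaled_noise = noise * 10 ** ((current_snr - desired_snr) / 20)
# the rescaled noise now sits at the desired SNR relative to the source
check = 10 * np.log10(np.mean(source ** 2) / (np.mean(scaled_noise ** 2) + EPS) + EPS)
print(round(check, 3))  # ~5.0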
utils/load_obj.py
ADDED
@@ -0,0 +1,18 @@
#!/usr/bin/env python

import torch as th

def load_obj(obj, device):
    """
    Offload tensor object in obj to cuda device
    """

    def cuda(obj):
        return obj.to(device) if isinstance(obj, th.Tensor) else obj

    if isinstance(obj, dict):
        return {key: load_obj(obj[key], device) for key in obj}
    elif isinstance(obj, list):
        return [load_obj(val, device) for val in obj]
    else:
        return cuda(obj)
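
A short usage sketch: load_obj recurses through dicts and lists, so a collated batch like those produced by utils/dataset.py can be moved to a device in one call, while non-tensor entries pass through unchanged. The batch below is a toy example, not project data.

import torch as th
from utils.load_obj import load_obj

batch = {
    "mix": th.zeros(4, 32000),
    "ref": th.zeros(4, 32000),
    "ids": ["a", "b", "c", "d"],  # left untouched
}
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
batch = load_obj(batch, device)
print(batch["mix"].device)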
utils/logger.py
ADDED
@@ -0,0 +1,22 @@
#!/usr/bin/env python

import logging

def get_logger(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S",
        file=False):
    """
    Get python logger instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # file or console
    handler = logging.StreamHandler() if not file else logging.FileHandler(name)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger
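
Usage sketch. One quirk worth noting: with file=True the logger name doubles as the output file path passed to logging.FileHandler, so a writable path must be supplied (the path below is a placeholder).

from utils.logger import get_logger

console_logger = get_logger(__name__)
console_logger.info("training started")

file_logger = get_logger("train.log", file=True)  # placeholder path
file_logger.info("epoch 1 done")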
utils/sisdr.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python

import numpy as np

def sisdr(x, s, remove_dc=True):
    """
    Compute SI-SDR
    x: extracted signal
    s: reference signal (ground truth)
    """

    def vec_l2norm(x):
        return np.linalg.norm(x, 2)

    if remove_dc:
        x_zm = x - np.mean(x)
        s_zm = s - np.mean(s)
        t = np.inner(x_zm, s_zm) * s_zm / vec_l2norm(s_zm)**2
        n = x_zm - t
    else:
        t = np.inner(x, s) * s / vec_l2norm(s)**2
        n = x - t
    return 20 * np.log10(vec_l2norm(t) / vec_l2norm(n))
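
A quick sanity check on synthetic signals (not project data): a near-perfect estimate yields a very large SI-SDR, while a noisy estimate lands near the plain SNR of the added noise.

import numpy as np
from utils.sisdr import sisdr

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)
print(sisdr(s + 1e-6 * rng.standard_normal(16000), s))  # very large, ~120 dB
print(sisdr(s + 0.3 * rng.standard_normal(16000), s))   # roughly 10 dB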
utils/timer.py
ADDED
@@ -0,0 +1,17 @@
#!/usr/bin/env python

import time

class Timer(object):
    """
    A timer to record the elapsed time
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()

    def elapsed(self):
        # elapsed time in minutes, not seconds
        return (time.time() - self.start) / 60