import functools import os from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torchaudio from torch.utils.data import Dataset from torch import flatten from typing import Optional import torchaudio.functional as F import random def find_wav_files(path_to_dir: Union[Path, str]): paths = list(sorted(Path(path_to_dir).glob("**/*.wav"))) if len(paths) == 0: return None return paths def set_seed_all(seed: int = 0): if not isinstance(seed, int): seed = 0 random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True os.environ["PYTHONHASHSEED"] = str(seed) return None SOX_SILENCE = [ ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"], ] class AudioDataset(Dataset): def __init__( self, directory_or_path_list: Union[Union[str, Path], List[Union[str, Path]]], sample_rate: int = 16_000, amount: Optional[int] = None, normalize: bool = True, trim: bool = True ) : super().__init__() self.trim = trim self.sample_rate = sample_rate self.normalize = normalize if isinstance(directory_or_path_list, list): paths = directory_or_path_list elif isinstance(directory_or_path_list, Path) or isinstance( directory_or_path_list, str ): directory = Path(directory_or_path_list) paths = find_wav_files(directory) if amount is not None: paths = paths[:amount] self._paths = paths def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]: path = self._paths[index] waveform, sample_rate = torchaudio.load(path, normalize=self.normalize) if sample_rate != self.sample_rate: waveform, sample_rate = torchaudio.sox_effects.apply_effects_file( path, [["rate", f"{self.sample_rate}"]], normalize=self.normalize ) if self.trim: ( waveform_trimmed, sample_rate_trimmed, ) = torchaudio.sox_effects.apply_effects_tensor( waveform, sample_rate, SOX_SILENCE ) if waveform_trimmed.size()[1] > 0: waveform = waveform_trimmed sample_rate = sample_rate_trimmed audio_path = str(path) return waveform, sample_rate, str(audio_path) def __len__(self) -> int: return len(self._paths) class PadDataset(Dataset): def __init__(self, dataset: Dataset, cut: int = 64600, label=None): self.dataset = dataset self.cut = cut self.label = label def __getitem__(self, index): waveform, sample_rate, audio_path = self.dataset[index] waveform = waveform.squeeze(0) waveform_len = waveform.shape[0] if waveform_len >= self.cut: if self.label is None: return waveform[: self.cut], sample_rate, str(audio_path) else: return waveform[: self.cut], sample_rate, str(audio_path), self.label # need to pad num_repeats = int(self.cut / waveform_len) + 1 padded_waveform = torch.tile(waveform, (1, num_repeats))[:, : self.cut][0] if self.label is None: return padded_waveform, sample_rate, str(audio_path) else: return padded_waveform, sample_rate, str(audio_path), self.label def __len__(self): return len(self.dataset) class TransformDataset(Dataset): def __init__( self, dataset: Dataset, transformation: Callable, needs_sample_rate: bool = False, transform_kwargs: dict = {}, ) -> None: super().__init__() self._dataset = dataset self._transform_constructor = transformation self._needs_sample_rate = needs_sample_rate self._transform_kwargs = transform_kwargs self._transform = None def __len__(self): return len(self._dataset) def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]: waveform, sample_rate, audio_path = self._dataset[index] if self._transform is None: if self._needs_sample_rate: self._transform = self._transform_constructor( sample_rate, **self._transform_kwargs ) else: self._transform = self._transform_constructor(**self._transform_kwargs) return self._transform(waveform), sample_rate, str(audio_path) class DoubleDeltaTransform(torch.nn.Module): def __init__(self, win_length: int = 5, mode: str = "replicate"): super().__init__() self.win_length = win_length self.mode = mode self._delta = torchaudio.transforms.ComputeDeltas( win_length=self.win_length, mode=self.mode ) def forward(self, X): delta = self._delta(X) double_delta = self._delta(delta) return torch.hstack((X, delta, double_delta)) def _build_preprocessing( directory_or_audiodataset: Union[Union[str, Path], AudioDataset], transform: torch.nn.Module, audiokwargs: dict = {}, transformkwargs: dict = {}, ): if isinstance(directory_or_audiodataset, AudioDataset) or isinstance( directory_or_audiodataset, PadDataset ): return TransformDataset( dataset=directory_or_audiodataset, transformation=transform, needs_sample_rate=True, transform_kwargs=transformkwargs, ) elif isinstance(directory_or_audiodataset, str) or isinstance( directory_or_audiodataset, Path ): return TransformDataset( dataset=AudioDataset(directory=directory_or_audiodataset, **audiokwargs), transformation=transform, needs_sample_rate=True, transform_kwargs=transformkwargs, ) mfcc = functools.partial(_build_preprocessing, transform=torchaudio.transforms.MFCC) def double_delta(dataset: Dataset, delta_kwargs: dict = {}) -> TransformDataset: return TransformDataset( dataset=dataset, transformation=DoubleDeltaTransform, transform_kwargs=delta_kwargs, ) # def load_directory_split_train_test( # path: Union[Path, str], # feature_fn: Callable, # feature_kwargs: dict, # test_size: float, # use_double_delta: bool = True, # pad: bool = False, # label: Optional[int] = None, # ): # paths = find_wav_files(path) # test_size = int(test_size * len(paths)) # train_paths = paths[:-test_size] # test_paths = paths[-test_size:] # train_dataset = AudioDataset(train_paths) # if pad: # train_dataset = PadDataset(train_dataset, label=label) # test_dataset = AudioDataset(test_paths) # if pad: # test_dataset = PadDataset(test_dataset, label=label) # dataset_train = feature_fn( # directory_or_audiodataset=train_dataset, # transformkwargs=feature_kwargs, # ) # dataset_test = feature_fn( # directory_or_audiodataset=test_dataset, # transformkwargs=feature_kwargs, # ) # if use_double_delta: # dataset_train = double_delta(dataset_train) # dataset_test = double_delta(dataset_test) # return dataset_train, dataset_test class ShallowCNN(nn.Module): def __init__(self, in_features, out_dim, **kwargs): super(ShallowCNN, self).__init__() self.conv1 = nn.Conv2d(in_features, 32, kernel_size=4, stride=1, padding=1) self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1, padding=1) self.conv3 = nn.Conv2d(48, 64, kernel_size=4, stride=1, padding=1) self.conv4 = nn.Conv2d(64, 128, kernel_size=(2, 4), stride=1, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(15104, 128) self.fc2 = nn.Linear(128, out_dim) self.relu = nn.ReLU() def forward(self, x: torch.Tensor): x = x.unsqueeze(1) x = self.pool(self.relu(self.conv1(x))) x = self.pool(self.relu(self.conv2(x))) x = self.pool(self.relu(self.conv3(x))) x = self.pool(self.relu(self.conv4(x))) x = flatten(x, 1) x = self.relu(self.fc1(x)) x = self.fc2(x) return x class SimpleLSTM(nn.Module): def __init__( self, feat_dim: int, time_dim: int, mid_dim: int, out_dim: int, **kwargs, ): super(SimpleLSTM, self).__init__() self.lstm = nn.LSTM( input_size=feat_dim, hidden_size=mid_dim, num_layers=2, bidirectional=True, batch_first=True, dropout=0.01, ) self.conv = nn.Conv1d(in_channels=mid_dim * 2, out_channels=10, kernel_size=1) self.relu = nn.ReLU() self.fc = nn.Linear(in_features=time_dim * 10, out_features=out_dim) def forward(self, x: torch.Tensor): B = x.size(0) x = x.permute(0, 2, 1) lstm_out, _ = self.lstm(x) feat = lstm_out.permute(0, 2, 1) feat = self.conv(feat) feat = self.relu(feat) feat = feat.reshape(B, -1) out = self.fc(feat) return out import torch import torch.nn.functional as F import torch.utils.checkpoint as cp from torch import nn def get_nonlinear(config_str, channels): nonlinear = nn.Sequential() for name in config_str.split('-'): if name == 'relu': nonlinear.add_module('relu', nn.ReLU(inplace=True)) elif name == 'prelu': nonlinear.add_module('prelu', nn.PReLU(channels)) elif name == 'batchnorm': nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels)) elif name == 'batchnorm_': nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels, affine=False)) else: raise ValueError('Unexpected module ({}).'.format(name)) return nonlinear def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): mean = x.mean(dim=dim) std = x.std(dim=dim, unbiased=False) stats = torch.cat([mean, std], dim=-1) if keepdim: stats = stats.unsqueeze(dim=dim) return stats def high_order_statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): mean = x.mean(dim=dim) std = x.std(dim=dim, unbiased=unbiased) norm = (x - mean.unsqueeze(dim=dim)) \ / std.clamp(min=eps).unsqueeze(dim=dim) skewness = norm.pow(3).mean(dim=dim) kurtosis = norm.pow(4).mean(dim=dim) stats = torch.cat([mean, std, skewness, kurtosis], dim=-1) if keepdim: stats = stats.unsqueeze(dim=dim) return stats class StatsPool(nn.Module): def forward(self, x): ret = statistics_pooling(x) return ret class HighOrderStatsPool(nn.Module): def forward(self, x): return high_order_statistics_pooling(x) class TDNNLayer(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False, config_str='batchnorm-relu'): super(TDNNLayer, self).__init__() if padding < 0: assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( kernel_size) padding = (kernel_size - 1) // 2 * dilation self.linear = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): x = self.linear(x) # print("linear", x) x = self.nonlinear(x) # print("nonlinear", x) return x class DenseTDNNLayer(nn.Module): def __init__(self, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str='batchnorm-relu', memory_efficient=False): super(DenseTDNNLayer, self).__init__() assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( kernel_size) padding = (kernel_size - 1) // 2 * dilation self.memory_efficient = memory_efficient self.nonlinear1 = get_nonlinear(config_str, in_channels) self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) self.nonlinear2 = get_nonlinear(config_str, bn_channels) self.linear2 = nn.Conv1d(bn_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) def bn_function(self, x): return self.linear1(self.nonlinear1(x)) def forward(self, x): x = self.bn_function(x) x = self.linear2(self.nonlinear2(x)) return x class DenseTDNNBlock(nn.ModuleList): def __init__(self, num_layers, in_channels, out_channels, bn_channels, kernel_size, stride=1, dilation=1, bias=False, config_str='batchnorm-relu', memory_efficient=False): super(DenseTDNNBlock, self).__init__() for i in range(num_layers): layer = DenseTDNNLayer(in_channels=in_channels + i * out_channels, out_channels=out_channels, bn_channels=bn_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias, config_str=config_str, memory_efficient=memory_efficient) self.add_module('tdnnd%d' % (i + 1), layer) def forward(self, x): for layer in self: x = torch.cat([x, layer(x)], dim=1) return x class StatsSelect(nn.Module): def __init__(self, channels, branches, null=False, reduction=1): super(StatsSelect, self).__init__() self.gather = HighOrderStatsPool() self.linear1 = nn.Conv1d(channels * 4, channels // reduction, 1) self.linear2 = nn.ModuleList() if null: branches += 1 for _ in range(branches): self.linear2.append(nn.Conv1d(channels // reduction, channels, 1)) self.channels = channels self.branches = branches self.null = null self.reduction = reduction def forward(self, x): f = torch.cat([_x.unsqueeze(dim=1) for _x in x], dim=1) x = torch.sum(f, dim=1) x = self.linear1(self.gather(x).unsqueeze(dim=-1)) s = [] for linear in self.linear2: s.append(linear(x).view(-1, 1, self.channels)) s = torch.cat(s, dim=1) s = F.softmax(s, dim=1).unsqueeze(dim=-1) if self.null: s = s[:, :-1, :, :] return torch.sum(f * s, dim=1) def extra_repr(self): return 'channels={}, branches={}, reduction={}'.format( self.channels, self.branches, self.reduction) class TransitLayer(nn.Module): def __init__(self, in_channels, out_channels, bias=True, config_str='batchnorm-relu'): super(TransitLayer, self).__init__() self.nonlinear = get_nonlinear(config_str, in_channels) self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) def forward(self, x): x = self.nonlinear(x) # print("nonlinear", x) x = self.linear(x) # print("linear", x) return x class DenseLayer(nn.Module): def __init__(self, in_channels, out_channels, bias=False, config_str='batchnorm-relu'): super(DenseLayer, self).__init__() self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) self.nonlinear = get_nonlinear(config_str, out_channels) def forward(self, x): if len(x.shape) == 2: x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) else: x = self.linear(x) x = self.nonlinear(x) return x from collections import OrderedDict from torch import nn class DTDNN(nn.Module): def __init__(self, feat_dim=30, embedding_size=512, num_classes=None, growth_rate=64, bn_size=2, init_channels=128, config_str='batchnorm-relu', memory_efficient=True): super(DTDNN, self).__init__() self.xvector = nn.Sequential( OrderedDict([ ('tdnn', TDNNLayer(feat_dim, init_channels, 5, dilation=1, padding=-1, config_str=config_str)), ])) channels = init_channels for i, (num_layers, kernel_size, dilation) in enumerate(zip((6, 12), (3, 3), (1, 3))): block = DenseTDNNBlock(num_layers=num_layers, in_channels=channels, out_channels=growth_rate, bn_channels=bn_size * growth_rate, kernel_size=kernel_size, dilation=dilation, config_str=config_str, memory_efficient=memory_efficient) self.xvector.add_module('block%d' % (i + 1), block) channels = channels + num_layers * growth_rate self.xvector.add_module( 'transit%d' % (i + 1), TransitLayer(channels, channels // 2, bias=False, config_str=config_str)) channels //= 2 self.xvector.add_module('stats', StatsPool()) self.xvector.add_module( 'dense', DenseLayer(channels * 2, embedding_size, config_str='batchnorm_')) if num_classes is not None: self.classifier = nn.Linear(embedding_size, num_classes) self.softmax = nn.Softmax() for m in self.modules(): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.kaiming_normal_(m.weight.data) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x): x = x.unsqueeze(1).permute(0,2,1) x = self.xvector(x) x = self.classifier(x) # x = self.softmax(x) return x def pred_audio(path): audio = [path] audio_ds = AudioDataset(audio) audio_ds = PadDataset(audio_ds) audio_ds = mfcc( directory_or_audiodataset=audio_ds, transformkwargs={} ) audio_ds = double_delta(audio_ds) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cnn_model = ShallowCNN(in_features= 1,out_dim=1).to(device) cnn_checkpoint = torch.load("./models/best_cnn.pt", map_location=device) cnn_model.load_state_dict(cnn_checkpoint['state_dict']) lstm_model = SimpleLSTM(feat_dim= 40, time_dim= 972, mid_dim= 30, out_dim= 1).to(device) lstm_checkpoint = torch.load("./models/best_lstm.pt", map_location=device) lstm_model.load_state_dict(lstm_checkpoint['state_dict']) dtdnn_model = DTDNN(feat_dim= 38880,num_classes= 1).to(device) dtdnn_checkpoint = torch.load("./models/best_tdnn.pt", map_location=device) dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict']) # Set models to evaluation mode cnn_model.eval() lstm_model.eval() dtdnn_model.eval() # Prepare input data input_data = audio_ds[0][0].unsqueeze(0) # Forward pass through CNN model cnn_output = cnn_model(input_data) cnn_prob = torch.sigmoid(cnn_output) # Forward pass through LSTM model lstm_output = lstm_model(input_data) lstm_prob = torch.sigmoid(lstm_output) # Forward pass through DT-DNN model dtdnn_input = input_data.view(input_data.size(0), -1) dtdnn_output = dtdnn_model(dtdnn_input) dtdnn_prob = torch.sigmoid(dtdnn_output) # Combine predictions combined_prob = (cnn_prob + lstm_prob + dtdnn_prob) / 3 # Classify based on combined probabilities combined_pred = (combined_prob >= 0.5).int() cnn_pred = (cnn_prob >= 0.5).int() lstm_pred = (lstm_prob >= 0.5).int() dtdnn_pred = (dtdnn_prob >= 0.5).int() return [cnn_pred.item(), lstm_pred.item(), dtdnn_pred.item(), combined_pred.item()]