# audioDeepFake / model.py
import functools
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset
from torch import flatten
from typing import Optional
import torchaudio.functional as F
import random
def find_wav_files(path_to_dir: Union[Path, str]) -> Optional[List[Path]]:
    """Recursively collect all ``.wav`` files under ``path_to_dir``.

    Returns the paths sorted lexicographically, or ``None`` (not an empty
    list) when nothing is found — callers use ``None`` as a sentinel.
    """
    # sorted() already returns a list; the original wrapped it in list() redundantly.
    paths = sorted(Path(path_to_dir).glob("**/*.wav"))
    return paths or None
def set_seed_all(seed: int = 0):
    """Seed every RNG the pipeline touches for reproducibility.

    Covers ``random``, NumPy, torch (CPU and, when present, CUDA) and the
    hash seed. A non-int ``seed`` silently falls back to 0.
    """
    if not isinstance(seed, int):
        seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels: trades throughput for repeatability.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed)
# Sox effect chain used to trim silence from waveforms: per sox's `silence`
# effect, this removes audio quieter than 1% amplitude lasting at least 0.2 s
# from the start, and (via the "-1" period) from the end as well.
SOX_SILENCE = [
    ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
]
class AudioDataset(Dataset):
    """Dataset of ``.wav`` files resampled to ``sample_rate`` and optionally
    silence-trimmed.

    Accepts either an explicit list of paths or a directory that is searched
    recursively for ``.wav`` files.
    """

    def __init__(
        self,
        directory_or_path_list: Union[Union[str, Path], List[Union[str, Path]]],
        sample_rate: int = 16_000,
        amount: Optional[int] = None,
        normalize: bool = True,
        trim: bool = True,
    ) -> None:
        super().__init__()
        self.trim = trim
        self.sample_rate = sample_rate
        self.normalize = normalize
        if isinstance(directory_or_path_list, list):
            paths = directory_or_path_list
        elif isinstance(directory_or_path_list, (str, Path)):
            paths = find_wav_files(Path(directory_or_path_list))
            if paths is None:
                # find_wav_files returns None for an empty directory; fail
                # loudly here instead of crashing later with a cryptic error.
                raise FileNotFoundError(
                    f"No .wav files found in {directory_or_path_list!r}"
                )
        else:
            # Previously any other type fell through and hit UnboundLocalError.
            raise TypeError(
                "directory_or_path_list must be a str, Path or list of paths"
            )
        # Fix: honor `amount` for explicit path lists too — it was silently
        # ignored unless a directory was given.
        if amount is not None:
            paths = paths[:amount]
        self._paths = paths

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, str]:
        """Return ``(waveform, sample_rate, path)`` for the file at ``index``."""
        path = self._paths[index]
        waveform, sample_rate = torchaudio.load(path, normalize=self.normalize)
        if sample_rate != self.sample_rate:
            # Resample through sox rather than torchaudio.transforms.Resample.
            waveform, sample_rate = torchaudio.sox_effects.apply_effects_file(
                path, [["rate", f"{self.sample_rate}"]], normalize=self.normalize
            )
        if self.trim:
            (
                waveform_trimmed,
                sample_rate_trimmed,
            ) = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate, SOX_SILENCE
            )
            # Keep the trimmed signal only if trimming left any samples.
            if waveform_trimmed.size()[1] > 0:
                waveform = waveform_trimmed
                sample_rate = sample_rate_trimmed
        return waveform, sample_rate, str(path)

    def __len__(self) -> int:
        return len(self._paths)
class PadDataset(Dataset):
    """Crop or cyclically pad each waveform to a fixed number of samples."""

    def __init__(self, dataset: Dataset, cut: int = 64600, label=None):
        self.dataset = dataset
        self.cut = cut      # target length in samples (~4 s at 16 kHz)
        self.label = label  # optional class label appended to each item

    def __getitem__(self, index):
        waveform, sample_rate, audio_path = self.dataset[index]
        waveform = waveform.squeeze(0)
        n_samples = waveform.shape[0]
        if n_samples >= self.cut:
            fixed = waveform[: self.cut]
        else:
            # Tile the short clip until it covers `cut`, then crop.
            n_tiles = int(self.cut / n_samples) + 1
            fixed = torch.tile(waveform, (1, n_tiles))[:, : self.cut][0]
        item = (fixed, sample_rate, str(audio_path))
        return item if self.label is None else item + (self.label,)

    def __len__(self):
        return len(self.dataset)
class TransformDataset(Dataset):
    """Wraps a dataset and lazily applies a transform to each waveform.

    The transform is constructed on first access (optionally receiving the
    sample rate as its first argument) and reused for subsequent items.

    Fix: ``transform_kwargs`` previously used a mutable default ``{}`` that
    was shared across instances.
    """

    def __init__(
        self,
        dataset: Dataset,
        transformation: Callable,
        needs_sample_rate: bool = False,
        transform_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()
        self._dataset = dataset
        self._transform_constructor = transformation
        self._needs_sample_rate = needs_sample_rate
        self._transform_kwargs = {} if transform_kwargs is None else transform_kwargs
        self._transform = None  # built lazily on first __getitem__

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, str]:
        waveform, sample_rate, audio_path = self._dataset[index]
        if self._transform is None:
            # Some transforms (e.g. MFCC) take the sample rate as first arg.
            if self._needs_sample_rate:
                self._transform = self._transform_constructor(
                    sample_rate, **self._transform_kwargs
                )
            else:
                self._transform = self._transform_constructor(**self._transform_kwargs)
        return self._transform(waveform), sample_rate, str(audio_path)
class DoubleDeltaTransform(torch.nn.Module):
    """Stack features with their first and second temporal derivatives."""

    def __init__(self, win_length: int = 5, mode: str = "replicate"):
        super().__init__()
        self.win_length = win_length
        self.mode = mode
        self._delta = torchaudio.transforms.ComputeDeltas(
            win_length=self.win_length, mode=self.mode
        )

    def forward(self, X):
        first = self._delta(X)
        second = self._delta(first)
        # Concatenate along the feature axis: [X | delta | delta-delta].
        return torch.hstack((X, first, second))
def _build_preprocessing(
    directory_or_audiodataset: Union[Union[str, Path], AudioDataset],
    transform: torch.nn.Module,
    audiokwargs: Optional[dict] = None,
    transformkwargs: Optional[dict] = None,
):
    """Build a TransformDataset applying ``transform`` to audio data.

    Accepts an existing AudioDataset/PadDataset, or a path from which an
    AudioDataset is constructed (``audiokwargs`` forwarded to it).

    Fixes: mutable default-dict arguments; the path branch passed a
    nonexistent ``directory=`` keyword to AudioDataset (TypeError at
    runtime); unsupported inputs now raise instead of returning None.
    """
    audiokwargs = {} if audiokwargs is None else audiokwargs
    transformkwargs = {} if transformkwargs is None else transformkwargs
    if isinstance(directory_or_audiodataset, (AudioDataset, PadDataset)):
        return TransformDataset(
            dataset=directory_or_audiodataset,
            transformation=transform,
            needs_sample_rate=True,
            transform_kwargs=transformkwargs,
        )
    elif isinstance(directory_or_audiodataset, (str, Path)):
        return TransformDataset(
            dataset=AudioDataset(directory_or_audiodataset, **audiokwargs),
            transformation=transform,
            needs_sample_rate=True,
            transform_kwargs=transformkwargs,
        )
    raise TypeError(
        "directory_or_audiodataset must be an AudioDataset, PadDataset, str or Path"
    )
# Convenience builder: a preprocessing pipeline whose transform is torchaudio's MFCC.
mfcc = functools.partial(_build_preprocessing, transform=torchaudio.transforms.MFCC)
def double_delta(dataset: Dataset, delta_kwargs: Optional[dict] = None) -> TransformDataset:
    """Wrap ``dataset`` so each item's features are stacked with their delta
    and double-delta (see DoubleDeltaTransform).

    Fix: ``delta_kwargs`` previously used a mutable default ``{}``.
    """
    return TransformDataset(
        dataset=dataset,
        transformation=DoubleDeltaTransform,
        transform_kwargs={} if delta_kwargs is None else delta_kwargs,
    )
# def load_directory_split_train_test(
# path: Union[Path, str],
# feature_fn: Callable,
# feature_kwargs: dict,
# test_size: float,
# use_double_delta: bool = True,
# pad: bool = False,
# label: Optional[int] = None,
# ):
# paths = find_wav_files(path)
# test_size = int(test_size * len(paths))
# train_paths = paths[:-test_size]
# test_paths = paths[-test_size:]
# train_dataset = AudioDataset(train_paths)
# if pad:
# train_dataset = PadDataset(train_dataset, label=label)
# test_dataset = AudioDataset(test_paths)
# if pad:
# test_dataset = PadDataset(test_dataset, label=label)
# dataset_train = feature_fn(
# directory_or_audiodataset=train_dataset,
# transformkwargs=feature_kwargs,
# )
# dataset_test = feature_fn(
# directory_or_audiodataset=test_dataset,
# transformkwargs=feature_kwargs,
# )
# if use_double_delta:
# dataset_train = double_delta(dataset_train)
# dataset_test = double_delta(dataset_test)
# return dataset_train, dataset_test
class ShallowCNN(nn.Module):
    """Small 4-conv CNN classifier over 2-D feature maps.

    Input: (batch, H, W) features — a channel axis is added internally.
    NOTE(review): fc1's 15104 input size hard-codes the feature-map shape;
    it matches (H, W) = (40, 972) inputs — confirm against the caller.
    """

    def __init__(self, in_features, out_dim, **kwargs):
        super(ShallowCNN, self).__init__()
        # Attribute names are kept identical so saved state_dicts still load.
        self.conv1 = nn.Conv2d(in_features, 32, kernel_size=4, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1, padding=1)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=4, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=(2, 4), stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(15104, 128)
        self.fc2 = nn.Linear(128, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        x = x.unsqueeze(1)  # add channel axis: (B, H, W) -> (B, 1, H, W)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(self.relu(conv(x)))
        x = flatten(x, 1)
        return self.fc2(self.relu(self.fc1(x)))
class SimpleLSTM(nn.Module):
    """Bi-LSTM classifier: 2-layer BiLSTM -> pointwise Conv1d -> linear head.

    Input: (batch, feat_dim, time_dim); output: (batch, out_dim) logits.
    """

    def __init__(
        self,
        feat_dim: int,
        time_dim: int,
        mid_dim: int,
        out_dim: int,
        **kwargs,
    ):
        super(SimpleLSTM, self).__init__()
        # Attribute names are kept identical so saved state_dicts still load.
        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=mid_dim,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.01,
        )
        # Squeeze the 2*mid_dim BiLSTM channels down to 10 per time step.
        self.conv = nn.Conv1d(in_channels=mid_dim * 2, out_channels=10, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(in_features=time_dim * 10, out_features=out_dim)

    def forward(self, x: torch.Tensor):
        batch = x.size(0)
        # (B, F, T) -> (B, T, F) for the batch_first LSTM.
        seq = x.transpose(1, 2)
        lstm_out, _ = self.lstm(seq)
        # Back to channel-first for the pointwise conv.
        activated = self.relu(self.conv(lstm_out.transpose(1, 2)))
        return self.fc(activated.reshape(batch, -1))
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn
def get_nonlinear(config_str, channels):
    """Assemble an activation/normalization stack from a '-'-separated spec.

    Recognized tokens: relu, prelu, batchnorm, batchnorm_ (affine-free BN).
    Raises ValueError on anything else.
    """
    factories = {
        'relu': ('relu', lambda: nn.ReLU(inplace=True)),
        'prelu': ('prelu', lambda: nn.PReLU(channels)),
        'batchnorm': ('batchnorm', lambda: nn.BatchNorm1d(channels)),
        # batchnorm_ registers under the same module name 'batchnorm'.
        'batchnorm_': ('batchnorm', lambda: nn.BatchNorm1d(channels, affine=False)),
    }
    stack = nn.Sequential()
    for name in config_str.split('-'):
        if name not in factories:
            raise ValueError('Unexpected module ({}).'.format(name))
        module_name, build = factories[name]
        stack.add_module(module_name, build())
    return stack
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Pool a tensor into concatenated (mean, std) statistics along ``dim``.

    NOTE(review): ``unbiased`` and ``eps`` are accepted for signature
    symmetry with high_order_statistics_pooling but are NOT used here —
    std is always computed with unbiased=False. Honoring them would change
    the numbers produced for trained checkpoints, so this is left as-is.
    """
    pooled = torch.cat([x.mean(dim=dim), x.std(dim=dim, unbiased=False)], dim=-1)
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
def high_order_statistics_pooling(x,
                                  dim=-1,
                                  keepdim=False,
                                  unbiased=True,
                                  eps=1e-2):
    """Pool into concatenated (mean, std, skewness, kurtosis) along ``dim``.

    ``eps`` clamps the std before normalization to avoid division by ~0.
    """
    mu = x.mean(dim=dim)
    sigma = x.std(dim=dim, unbiased=unbiased)
    z = (x - mu.unsqueeze(dim=dim)) / sigma.clamp(min=eps).unsqueeze(dim=dim)
    stats = torch.cat(
        [mu, sigma, z.pow(3).mean(dim=dim), z.pow(4).mean(dim=dim)], dim=-1
    )
    return stats.unsqueeze(dim=dim) if keepdim else stats
class StatsPool(nn.Module):
    """Module wrapper around statistics_pooling (mean + std)."""

    def forward(self, x):
        return statistics_pooling(x)
class HighOrderStatsPool(nn.Module):
    """Module wrapper around high_order_statistics_pooling
    (mean + std + skewness + kurtosis)."""

    def forward(self, x):
        return high_order_statistics_pooling(x)
class TDNNLayer(nn.Module):
    """Conv1d (TDNN) layer followed by a configurable nonlinearity stack."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            # Negative padding requests "same" output length; requires an
            # odd kernel so padding is symmetric.
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        # Attribute names kept identical so saved state_dicts still load.
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
class DenseTDNNLayer(nn.Module):
    """Bottleneck (1x1) + context (k-wide) TDNN layer used inside a dense block.

    Pre-activation style: nonlinearity precedes each convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        same_padding = (kernel_size - 1) // 2 * dilation
        # NOTE(review): `memory_efficient` is stored but never consulted in
        # forward here (torch.utils.checkpoint is imported but unused).
        self.memory_efficient = memory_efficient
        # Attribute names kept identical so saved state_dicts still load.
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.linear2 = nn.Conv1d(bn_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=stride,
                                 padding=same_padding,
                                 dilation=dilation,
                                 bias=bias)

    def bn_function(self, x):
        # Bottleneck: activation/norm, then the 1x1 channel reduction.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        return self.linear2(self.nonlinear2(self.bn_function(x)))
class DenseTDNNBlock(nn.ModuleList):
    """Stack of DenseTDNNLayers with dense (concatenated) connectivity:
    each layer receives the input plus all previous layers' outputs."""

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(DenseTDNNBlock, self).__init__()
        for index in range(num_layers):
            # Module names ('tdnnd1', ...) kept identical for state_dict compat.
            self.add_module(
                'tdnnd%d' % (index + 1),
                DenseTDNNLayer(in_channels=in_channels + index * out_channels,
                               out_channels=out_channels,
                               bn_channels=bn_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               dilation=dilation,
                               bias=bias,
                               config_str=config_str,
                               memory_efficient=memory_efficient))

    def forward(self, x):
        for layer in self:
            # Grow the channel axis with each layer's output.
            x = torch.cat([x, layer(x)], dim=1)
        return x
class StatsSelect(nn.Module):
    """Attention-weighted fusion of parallel branches.

    Branch weights come from high-order statistics of the branches' sum;
    with ``null=True`` an extra branch absorbs attention mass but is dropped
    from the output.
    """

    def __init__(self, channels, branches, null=False, reduction=1):
        super(StatsSelect, self).__init__()
        # Attribute names kept identical so saved state_dicts still load.
        self.gather = HighOrderStatsPool()
        self.linear1 = nn.Conv1d(channels * 4, channels // reduction, 1)
        self.linear2 = nn.ModuleList()
        if null:
            branches += 1  # extra "null" branch competing for weight
        for _ in range(branches):
            self.linear2.append(nn.Conv1d(channels // reduction, channels, 1))
        self.channels = channels
        self.branches = branches
        self.null = null
        self.reduction = reduction

    def forward(self, x):
        # x: list of per-branch tensors; stack on a new branch axis.
        stacked = torch.cat([branch.unsqueeze(dim=1) for branch in x], dim=1)
        summed = torch.sum(stacked, dim=1)
        descriptor = self.linear1(self.gather(summed).unsqueeze(dim=-1))
        scores = torch.cat(
            [proj(descriptor).view(-1, 1, self.channels) for proj in self.linear2],
            dim=1)
        weights = F.softmax(scores, dim=1).unsqueeze(dim=-1)
        if self.null:
            weights = weights[:, :-1, :, :]  # discard the null branch's share
        return torch.sum(stacked * weights, dim=1)

    def extra_repr(self):
        return 'channels={}, branches={}, reduction={}'.format(
            self.channels, self.branches, self.reduction)
class TransitLayer(nn.Module):
    """Pre-activation 1x1 conv used to shrink channels between dense blocks."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        # Attribute names kept identical so saved state_dicts still load.
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
class DenseLayer(nn.Module):
    """1x1 conv + nonlinearity; accepts 2-D (N, C) or 3-D (N, C, T) input."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        # Attribute names kept identical so saved state_dicts still load.
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if x.dim() == 2:
            # Treat (N, C) as (N, C, 1) so Conv1d applies, then drop the axis.
            out = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            out = self.linear(x)
        return self.nonlinear(out)
from collections import OrderedDict
from torch import nn
class DTDNN(nn.Module):
    """Densely-connected TDNN x-vector network with optional classifier head.

    Two dense blocks (6 and 12 layers) with transition layers, stats pooling
    and a dense embedding layer; when ``num_classes`` is given, a linear
    classifier maps embeddings to class logits.

    Fixes: forward previously called ``self.classifier`` unconditionally and
    crashed with AttributeError when ``num_classes`` was None; ``nn.Softmax``
    now has an explicit ``dim`` (implicit-dim softmax is deprecated).
    """

    def __init__(self,
                 feat_dim=30,
                 embedding_size=512,
                 num_classes=None,
                 growth_rate=64,
                 bn_size=2,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(DTDNN, self).__init__()
        # Module names ('tdnn', 'block1', ...) must stay fixed for checkpoints.
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(feat_dim,
                           init_channels,
                           5,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((6, 12), (3, 3), (1, 3))):
            block = DenseTDNNBlock(num_layers=num_layers,
                                   in_channels=channels,
                                   out_channels=growth_rate,
                                   bn_channels=bn_size * growth_rate,
                                   kernel_size=kernel_size,
                                   dilation=dilation,
                                   config_str=config_str,
                                   memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            # Transition halves the channel count after each dense block.
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2
        self.xvector.add_module('stats', StatsPool())
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
        if num_classes is not None:
            self.classifier = nn.Linear(embedding_size, num_classes)
            self.softmax = nn.Softmax(dim=-1)
        else:
            self.classifier = None  # embedding-only mode
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # (batch, feat) -> (batch, feat, 1): treat the flat feature vector as
        # a 1-step sequence for the Conv1d-based backbone.
        x = x.unsqueeze(1).permute(0, 2, 1)
        x = self.xvector(x)
        if self.classifier is not None:
            x = self.classifier(x)  # raw logits; softmax left to the caller
        return x
def pred_audio(path):
    """Run the CNN / LSTM / D-TDNN ensemble on one audio file.

    Builds the preprocessing pipeline (pad/crop -> MFCC -> double delta),
    loads the three checkpoints and returns binary predictions as
    ``[cnn, lstm, dtdnn, combined]`` (combined = mean-probability vote).

    Fixes: the input tensor is now moved to the model device (previously it
    stayed on CPU and crashed on CUDA machines) and inference runs under
    ``torch.no_grad()`` so no autograd graph is built.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Preprocessing: fixed-length crop/pad, MFCC features, delta stacking.
    audio_ds = AudioDataset([path])
    audio_ds = PadDataset(audio_ds)
    audio_ds = mfcc(directory_or_audiodataset=audio_ds, transformkwargs={})
    audio_ds = double_delta(audio_ds)

    def _load(model, ckpt_path):
        # Load a checkpoint onto `device` and switch the model to eval mode.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        return model.eval()

    cnn_model = _load(ShallowCNN(in_features=1, out_dim=1).to(device),
                      "./models/best_cnn.pt")
    lstm_model = _load(SimpleLSTM(feat_dim=40, time_dim=972, mid_dim=30,
                                  out_dim=1).to(device),
                       "./models/best_lstm.pt")
    dtdnn_model = _load(DTDNN(feat_dim=38880, num_classes=1).to(device),
                        "./models/best_tdnn.pt")

    # Single (1, 40, 972) feature batch, on the same device as the models.
    input_data = audio_ds[0][0].unsqueeze(0).to(device)
    with torch.no_grad():
        cnn_prob = torch.sigmoid(cnn_model(input_data))
        lstm_prob = torch.sigmoid(lstm_model(input_data))
        # D-TDNN consumes the flattened feature vector.
        dtdnn_prob = torch.sigmoid(
            dtdnn_model(input_data.view(input_data.size(0), -1)))

    # Mean-probability ensemble, thresholded at 0.5.
    combined_prob = (cnn_prob + lstm_prob + dtdnn_prob) / 3
    combined_pred = (combined_prob >= 0.5).int()
    cnn_pred = (cnn_prob >= 0.5).int()
    lstm_pred = (lstm_prob >= 0.5).int()
    dtdnn_pred = (dtdnn_prob >= 0.5).int()
    return [cnn_pred.item(), lstm_pred.item(), dtdnn_pred.item(),
            combined_pred.item()]