nachi1326 committed on
Commit
9c9e7c7
·
verified ·
1 Parent(s): d761d28

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +679 -0
model.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchaudio
10
+ from torch.utils.data import Dataset
11
+ from torch import flatten
12
+ from typing import Optional
13
+ import torchaudio.functional as F
14
+ import random
15
+
16
+
17
+
18
def find_wav_files(path_to_dir: Union[Path, str]):
    """Recursively collect the ``.wav`` files below ``path_to_dir``.

    Args:
        path_to_dir: Directory to search; sub-directories are included.

    Returns:
        The matching paths as a sorted list of ``Path`` objects, or ``None``
        when no ``.wav`` file is found.
    """
    wav_paths = sorted(Path(path_to_dir).glob("**/*.wav"))

    # An empty search result is reported as None, not as an empty list.
    return wav_paths if wav_paths else None
25
+
26
+
27
def set_seed_all(seed: int = 0):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets the ``PYTHONHASHSEED`` environment variable and, when CUDA is
    available, seeds the CUDA generators and switches cuDNN to
    non-benchmark, deterministic mode.

    Args:
        seed: Seed value. Anything that is not an ``int`` is replaced by 0.

    Returns:
        None
    """
    seed = seed if isinstance(seed, int) else 0

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    os.environ["PYTHONHASHSEED"] = str(seed)
    return None
43
+
44
# sox effect chain handed to ``apply_effects_tensor`` when trimming is enabled
# (a single ``silence`` effect).
SOX_SILENCE = [
    ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
]


class AudioDataset(Dataset):
    """Dataset that loads audio files and yields ``(waveform, sample_rate, path)``.

    Args:
        directory_or_path_list: Either a list of audio file paths, or a
            directory (``str`` / ``Path``) that is searched recursively for
            ``.wav`` files via ``find_wav_files``.
        sample_rate: Target sample rate; files with a different rate are
            re-read through the sox ``rate`` effect.
        amount: If given, only the first ``amount`` paths are kept.
        normalize: Forwarded as ``normalize`` to the torchaudio loaders.
        trim: When true, ``SOX_SILENCE`` is applied to every waveform.

    Raises:
        TypeError: If ``directory_or_path_list`` is neither a list nor a
            ``str`` / ``Path``.
    """

    def __init__(
        self,
        directory_or_path_list: Union[Union[str, Path], List[Union[str, Path]]],
        sample_rate: int = 16_000,
        amount: Optional[int] = None,
        normalize: bool = True,
        trim: bool = True,
    ):
        super().__init__()

        self.trim = trim
        self.sample_rate = sample_rate
        self.normalize = normalize

        if isinstance(directory_or_path_list, list):
            paths = directory_or_path_list
        elif isinstance(directory_or_path_list, (str, Path)):
            # find_wav_files returns None when nothing matches; treat that as
            # an empty dataset so slicing and len() below keep working.
            paths = find_wav_files(Path(directory_or_path_list)) or []
        else:
            # Previously this fell through and failed with an unbound local.
            raise TypeError(
                "directory_or_path_list must be a list of paths, a str or a "
                f"Path, got {type(directory_or_path_list).__name__}"
            )

        if amount is not None:
            paths = paths[:amount]

        self._paths = paths

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, str]:
        """Load item ``index`` and return ``(waveform, sample_rate, path)``."""
        path = self._paths[index]

        waveform, sample_rate = torchaudio.load(path, normalize=self.normalize)

        if sample_rate != self.sample_rate:
            # Re-read the file through sox so it comes back at the target rate.
            waveform, sample_rate = torchaudio.sox_effects.apply_effects_file(
                path, [["rate", f"{self.sample_rate}"]], normalize=self.normalize
            )

        if self.trim:
            (
                waveform_trimmed,
                sample_rate_trimmed,
            ) = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate, SOX_SILENCE
            )

            # Keep the untrimmed audio if the effect removed every sample.
            if waveform_trimmed.size()[1] > 0:
                waveform = waveform_trimmed
                sample_rate = sample_rate_trimmed

        return waveform, sample_rate, str(path)

    def __len__(self) -> int:
        return len(self._paths)
105
+
106
+
107
class PadDataset(Dataset):
    """Wrap a dataset so that every waveform has exactly ``cut`` samples.

    Waveforms with at least ``cut`` samples are truncated; shorter ones are
    repeated end-to-end and then truncated to ``cut`` samples.

    Args:
        dataset: Source dataset yielding ``(waveform, sample_rate, audio_path)``.
        cut: Number of samples every returned waveform has.
        label: Optional value appended as a fourth element of each item.
    """

    def __init__(self, dataset: Dataset, cut: int = 64600, label=None):
        self.dataset = dataset
        self.cut = cut
        self.label = label

    def __getitem__(self, index):
        """Return ``(waveform, sample_rate, path)`` plus ``label`` when set."""
        waveform, sample_rate, audio_path = self.dataset[index]
        # Drops a leading dimension of size one, if there is one.
        waveform = waveform.squeeze(0)
        n_samples = waveform.shape[0]

        if n_samples >= self.cut:
            fixed = waveform[: self.cut]
        else:
            # Too short: repeat the signal often enough to cover ``cut``.
            repeats = int(self.cut / n_samples) + 1
            fixed = torch.tile(waveform, (1, repeats))[:, : self.cut][0]

        item = (fixed, sample_rate, str(audio_path))
        if self.label is not None:
            item = item + (self.label,)
        return item

    def __len__(self):
        return len(self.dataset)
133
+
134
+
135
class TransformDataset(Dataset):
    """Apply a lazily constructed transform to the waveforms of a dataset.

    The transform object is built on first item access, because its
    constructor may need the sample rate of the data.

    Args:
        dataset: Source dataset whose items start with
            ``(waveform, sample_rate, audio_path)``.
        transformation: Callable that builds the transform.
        needs_sample_rate: When true, the sample rate of the first accessed
            item is passed as the first positional argument of
            ``transformation``.
        transform_kwargs: Keyword arguments for ``transformation``. Copied,
            so instances never share or mutate the caller's dict.
    """

    def __init__(
        self,
        dataset: Dataset,
        transformation: Callable,
        needs_sample_rate: bool = False,
        transform_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()
        self._dataset = dataset

        self._transform_constructor = transformation
        self._needs_sample_rate = needs_sample_rate
        # A None default plus a copy avoids the shared-mutable-default pitfall.
        self._transform_kwargs = (
            dict(transform_kwargs) if transform_kwargs is not None else {}
        )

        self._transform = None

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index: int) -> Tuple:
        """Return ``(transformed_waveform, sample_rate, path, *extra)``.

        Fields after the first three (for example a label) are passed
        through unchanged.
        """
        waveform, sample_rate, audio_path, *extra = self._dataset[index]

        if self._transform is None:
            if self._needs_sample_rate:
                self._transform = self._transform_constructor(
                    sample_rate, **self._transform_kwargs
                )
            else:
                self._transform = self._transform_constructor(
                    **self._transform_kwargs
                )

        return (self._transform(waveform), sample_rate, str(audio_path), *extra)
168
+
169
+
170
class DoubleDeltaTransform(torch.nn.Module):
    """Append first- and second-order deltas to a feature tensor.

    Args:
        win_length: Window length forwarded to ``ComputeDeltas``.
        mode: Padding mode forwarded to ``ComputeDeltas``.
    """

    def __init__(self, win_length: int = 5, mode: str = "replicate"):
        super().__init__()
        self.win_length = win_length
        self.mode = mode

        # One delta operator, used twice in forward().
        self._delta = torchaudio.transforms.ComputeDeltas(
            win_length=win_length, mode=mode
        )

    def forward(self, X):
        """Return ``X``, its delta and its delta-delta stacked with ``hstack``.

        ``torch.hstack`` joins along dim 1 for inputs with two or more
        dimensions (dim 0 for 1-D inputs).
        """
        first_order = self._delta(X)
        second_order = self._delta(first_order)

        return torch.hstack((X, first_order, second_order))
187
+
188
+
189
def _build_preprocessing(
    directory_or_audiodataset: Union[Union[str, Path], AudioDataset],
    transform: torch.nn.Module,
    audiokwargs: Optional[dict] = None,
    transformkwargs: Optional[dict] = None,
):
    """Wrap audio data in a ``TransformDataset`` that applies ``transform``.

    Args:
        directory_or_audiodataset: An ``AudioDataset`` / ``PadDataset`` to
            wrap, or a directory (``str`` / ``Path``) from which a new
            ``AudioDataset`` is built.
        transform: Transform constructor. It is called with the sample rate
            as first positional argument (``needs_sample_rate=True``).
        audiokwargs: Extra keyword arguments for ``AudioDataset``; only used
            when a directory is given.
        transformkwargs: Keyword arguments for the transform constructor.

    Returns:
        The wrapping ``TransformDataset``.

    Raises:
        TypeError: If ``directory_or_audiodataset`` has an unsupported type
            (this used to return ``None`` silently).
    """
    # None defaults avoid sharing one mutable dict between calls.
    audiokwargs = {} if audiokwargs is None else audiokwargs
    transformkwargs = {} if transformkwargs is None else transformkwargs

    if isinstance(directory_or_audiodataset, (AudioDataset, PadDataset)):
        dataset = directory_or_audiodataset
    elif isinstance(directory_or_audiodataset, (str, Path)):
        # Passed positionally: AudioDataset has no ``directory`` keyword.
        dataset = AudioDataset(directory_or_audiodataset, **audiokwargs)
    else:
        raise TypeError(
            "directory_or_audiodataset must be an AudioDataset, a PadDataset, "
            f"a str or a Path, got {type(directory_or_audiodataset).__name__}"
        )

    return TransformDataset(
        dataset=dataset,
        transformation=transform,
        needs_sample_rate=True,
        transform_kwargs=transformkwargs,
    )


# MFCC front-end: _build_preprocessing with torchaudio's MFCC transform bound.
mfcc = functools.partial(_build_preprocessing, transform=torchaudio.transforms.MFCC)
216
+
217
def double_delta(dataset: Dataset, delta_kwargs: Optional[dict] = None) -> TransformDataset:
    """Wrap ``dataset`` so that ``DoubleDeltaTransform`` is applied to its items.

    Args:
        dataset: Source dataset yielding ``(features, sample_rate, path)``.
        delta_kwargs: Keyword arguments for ``DoubleDeltaTransform``;
            ``None`` means no extra arguments.

    Returns:
        A ``TransformDataset`` around ``dataset``.
    """
    return TransformDataset(
        dataset=dataset,
        transformation=DoubleDeltaTransform,
        # Fresh dict per call instead of one shared mutable default.
        transform_kwargs={} if delta_kwargs is None else delta_kwargs,
    )
223
+
224
+
225
+ # def load_directory_split_train_test(
226
+ # path: Union[Path, str],
227
+ # feature_fn: Callable,
228
+ # feature_kwargs: dict,
229
+ # test_size: float,
230
+ # use_double_delta: bool = True,
231
+ # pad: bool = False,
232
+ # label: Optional[int] = None,
233
+ # ):
234
+
235
+ # paths = find_wav_files(path)
236
+
237
+ # test_size = int(test_size * len(paths))
238
+
239
+ # train_paths = paths[:-test_size]
240
+ # test_paths = paths[-test_size:]
241
+
242
+ # train_dataset = AudioDataset(train_paths)
243
+ # if pad:
244
+ # train_dataset = PadDataset(train_dataset, label=label)
245
+
246
+ # test_dataset = AudioDataset(test_paths)
247
+ # if pad:
248
+ # test_dataset = PadDataset(test_dataset, label=label)
249
+
250
+ # dataset_train = feature_fn(
251
+ # directory_or_audiodataset=train_dataset,
252
+ # transformkwargs=feature_kwargs,
253
+ # )
254
+
255
+ # dataset_test = feature_fn(
256
+ # directory_or_audiodataset=test_dataset,
257
+ # transformkwargs=feature_kwargs,
258
+ # )
259
+ # if use_double_delta:
260
+ # dataset_train = double_delta(dataset_train)
261
+ # dataset_test = double_delta(dataset_test)
262
+
263
+ # return dataset_train, dataset_test
264
+
265
# Single hard-coded clip that serves as the demo input for the models below.
audio = ["/kaggle/input/the-lj-speech-dataset/LJSpeech-1.1/wavs/LJ001-0001.wav"]

# Load the clip and pad/crop it to PadDataset's default length.
train_dataset = PadDataset(AudioDataset(audio))

# MFCC features with delta and delta-delta features appended.
dataset_train = double_delta(
    mfcc(directory_or_audiodataset=train_dataset, transformkwargs={})
)

print(dataset_train[0][0].shape)
278
+
279
class ShallowCNN(nn.Module):
    """Four conv/ReLU/max-pool stages followed by two linear layers.

    The first linear layer expects 15104 flattened features, which is what
    the conv stack produces for a ``(batch, 40, 972)`` input
    (128 channels x 2 x 59).

    Args:
        in_features: Number of input channels of the first convolution.
        out_dim: Size of the output vector.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_features, out_dim, **kwargs):
        super().__init__()
        self.conv1 = nn.Conv2d(in_features, 32, kernel_size=4, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1, padding=1)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=4, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=(2, 4), stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(15104, 128)
        self.fc2 = nn.Linear(128, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Map a ``(batch, H, W)`` tensor to ``(batch, out_dim)`` outputs."""
        # Insert the channel dimension the 2-D convolutions expect.
        x = x.unsqueeze(1)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(self.relu(conv(x)))
        x = torch.flatten(x, 1)
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
301
+
302
class SimpleLSTM(nn.Module):
    """Two-layer bidirectional LSTM, then a 1x1 convolution and a linear head.

    Args:
        feat_dim: Feature size per time step (LSTM ``input_size``).
        time_dim: Number of time steps; fixes the linear layer's input size
            (``time_dim * 10``).
        mid_dim: LSTM hidden size per direction.
        out_dim: Size of the output vector.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        feat_dim: int,
        time_dim: int,
        mid_dim: int,
        out_dim: int,
        **kwargs,
    ):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=mid_dim,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.01,
        )
        # Bidirectional output carries 2 * mid_dim channels per time step.
        self.conv = nn.Conv1d(in_channels=mid_dim * 2, out_channels=10, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(in_features=time_dim * 10, out_features=out_dim)

    def forward(self, x: torch.Tensor):
        """Map ``(batch, feat_dim, time_dim)`` input to ``(batch, out_dim)``."""
        batch_size = x.size(0)

        # The batch-first LSTM wants (batch, time, feature).
        sequence = x.permute(0, 2, 1)
        lstm_out, _ = self.lstm(sequence)

        # Back to (batch, channel, time) for the 1-D convolution.
        channels_first = lstm_out.permute(0, 2, 1)
        activated = self.relu(self.conv(channels_first))

        return self.fc(activated.reshape(batch_size, -1))
340
+
341
+ import torch
342
+ import torch.nn.functional as F
343
+ import torch.utils.checkpoint as cp
344
+ from torch import nn
345
+
346
+
347
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` from a dash-separated layer description.

    Recognised tokens are ``relu``, ``prelu``, ``batchnorm`` and
    ``batchnorm_`` (batch norm with ``affine=False``, registered under the
    name ``batchnorm``).

    Args:
        config_str: Dash-separated tokens, e.g. ``'batchnorm-relu'``.
        channels: Channel count for the PReLU / BatchNorm1d layers.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        ValueError: If a token is not recognised.
    """
    # token -> factory returning (registered name, module)
    builders = {
        'relu': lambda: ('relu', nn.ReLU(inplace=True)),
        'prelu': lambda: ('prelu', nn.PReLU(channels)),
        'batchnorm': lambda: ('batchnorm', nn.BatchNorm1d(channels)),
        'batchnorm_': lambda: ('batchnorm',
                               nn.BatchNorm1d(channels, affine=False)),
    }

    stack = nn.Sequential()
    for token in config_str.split('-'):
        if token not in builders:
            raise ValueError('Unexpected module ({}).'.format(token))
        stack.add_module(*builders[token]())
    return stack
362
+
363
+
364
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Pool ``x`` along ``dim`` into concatenated mean and standard deviation.

    NOTE: ``unbiased`` and ``eps`` are accepted but not used; the standard
    deviation is always computed with ``unbiased=False``.

    Args:
        x: Input tensor.
        dim: Dimension that is reduced.
        keepdim: When true, a size-one dimension is re-inserted at ``dim``.
        unbiased: Ignored.
        eps: Ignored.

    Returns:
        Mean and standard deviation concatenated along the last dimension.
    """
    pooled = torch.cat(
        (x.mean(dim=dim), x.std(dim=dim, unbiased=False)), dim=-1
    )
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
372
+
373
+
374
def high_order_statistics_pooling(x,
                                  dim=-1,
                                  keepdim=False,
                                  unbiased=True,
                                  eps=1e-2):
    """Pool ``x`` along ``dim`` into mean, std, skewness and kurtosis.

    Skewness and kurtosis are the means of the third and fourth powers of
    the standardised values.

    Args:
        x: Input tensor.
        dim: Dimension that is reduced.
        keepdim: When true, a size-one dimension is re-inserted at ``dim``.
        unbiased: Forwarded to ``Tensor.std``.
        eps: Lower bound applied to the standard deviation before dividing.

    Returns:
        The four statistics concatenated along the last dimension.
    """
    mean = x.mean(dim=dim)
    std = x.std(dim=dim, unbiased=unbiased)

    centered = x - mean.unsqueeze(dim=dim)
    # Clamp so near-constant inputs do not divide by (almost) zero.
    scale = std.clamp(min=eps).unsqueeze(dim=dim)
    standardized = centered / scale

    skewness, kurtosis = (
        standardized.pow(order).mean(dim=dim) for order in (3, 4)
    )

    stats = torch.cat([mean, std, skewness, kurtosis], dim=-1)
    return stats.unsqueeze(dim=dim) if keepdim else stats
389
+
390
+
391
class StatsPool(nn.Module):
    """Module wrapper around ``statistics_pooling`` with default arguments."""

    def forward(self, x):
        """Return ``statistics_pooling(x)``."""
        return statistics_pooling(x)
395
+
396
+
397
class HighOrderStatsPool(nn.Module):
    """Module wrapper around ``high_order_statistics_pooling`` with defaults."""

    def forward(self, x):
        """Return ``high_order_statistics_pooling(x)``."""
        pooled = high_order_statistics_pooling(x)
        return pooled
400
+
401
+
402
class TDNNLayer(nn.Module):
    """1-D convolution followed by the layers named in ``config_str``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding. A negative value requests
            ``(kernel_size - 1) // 2 * dilation`` and requires an odd
            ``kernel_size``.
        dilation: Convolution dilation.
        bias: Whether the convolution has a bias term.
        config_str: Layer description passed to ``get_nonlinear``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        if padding < 0:
            # Negative padding means: derive equal padding from the kernel.
            assert kernel_size % 2 == 1, \
                'Expect equal paddings, but got even kernel size ({})'.format(
                    kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        """Apply the convolution, then the configured layers."""
        return self.nonlinear(self.linear(x))
432
+
433
+
434
class DenseTDNNLayer(nn.Module):
    """Bottleneck layer: 1x1 convolution followed by a dilated convolution.

    Each convolution is preceded by the layers named in ``config_str``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the second convolution.
        bn_channels: Channels of the 1x1 bottleneck convolution.
        kernel_size: Kernel size of the second convolution; must be odd.
        stride: Stride of the second convolution.
        dilation: Dilation of the second convolution.
        bias: Whether the second convolution has a bias term.
        config_str: Layer description passed to ``get_nonlinear``.
        memory_efficient: Stored on the instance; not used by ``forward``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        assert kernel_size % 2 == 1, \
            'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
        same_padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.linear2 = nn.Conv1d(bn_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=stride,
                                 padding=same_padding,
                                 dilation=dilation,
                                 bias=bias)

    def bn_function(self, x):
        """Apply the first layer group and the 1x1 bottleneck convolution."""
        activated = self.nonlinear1(x)
        return self.linear1(activated)

    def forward(self, x):
        """Run the bottleneck, then the second layer group and convolution."""
        bottleneck = self.bn_function(x)
        return self.linear2(self.nonlinear2(bottleneck))
469
+
470
+
471
class DenseTDNNBlock(nn.ModuleList):
    """Stack of ``DenseTDNNLayer`` modules with dense concatenation.

    Layer ``i`` (0-based) takes ``in_channels + i * out_channels`` input
    channels; its output is concatenated to its input along dim 1.

    Args:
        num_layers: Number of layers in the block.
        in_channels: Input channels of the first layer.
        out_channels: Channels added by each layer.
        bn_channels: Bottleneck channels of each layer.
        kernel_size: Kernel size of each layer.
        stride: Stride of each layer.
        dilation: Dilation of each layer.
        bias: Whether each layer's second convolution has a bias term.
        config_str: Layer description forwarded to each layer.
        memory_efficient: Forwarded to each layer.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        for index in range(num_layers):
            # Every earlier layer contributed ``out_channels`` more channels.
            grown_channels = in_channels + index * out_channels
            self.add_module(
                'tdnnd%d' % (index + 1),
                DenseTDNNLayer(in_channels=grown_channels,
                               out_channels=out_channels,
                               bn_channels=bn_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               dilation=dilation,
                               bias=bias,
                               config_str=config_str,
                               memory_efficient=memory_efficient))

    def forward(self, x):
        """Return ``x`` with every layer's output appended along dim 1."""
        features = x
        for tdnn in self:
            features = torch.cat((features, tdnn(features)), dim=1)
        return features
500
+
501
+
502
class StatsSelect(nn.Module):
    """Softmax-weighted combination of several branch tensors.

    Per-branch weights are derived from high-order statistics of the summed
    branches.

    Args:
        channels: Channel count of every branch.
        branches: Number of branches.
        null: When true, one extra weight set is created; it takes part in
            the softmax but is dropped before weighting.
        reduction: Divisor for the hidden channel count.
    """

    def __init__(self, channels, branches, null=False, reduction=1):
        super().__init__()
        hidden_channels = channels // reduction
        self.gather = HighOrderStatsPool()
        self.linear1 = nn.Conv1d(channels * 4, hidden_channels, 1)
        self.linear2 = nn.ModuleList()
        if null:
            branches += 1
        for _ in range(branches):
            self.linear2.append(nn.Conv1d(hidden_channels, channels, 1))
        self.channels = channels
        self.branches = branches
        self.null = null
        self.reduction = reduction

    def forward(self, x):
        """Combine the tensors in ``x`` into one weighted sum."""
        # Stack the branches on a new dim 1.
        stacked = torch.stack([branch for branch in x], dim=1)
        summed = stacked.sum(dim=1)
        hidden = self.linear1(self.gather(summed).unsqueeze(dim=-1))

        scores = torch.cat(
            [proj(hidden).view(-1, 1, self.channels) for proj in self.linear2],
            dim=1,
        )
        # Normalise across the branch dimension.
        weights = F.softmax(scores, dim=1).unsqueeze(dim=-1)
        if self.null:
            weights = weights[:, :-1, :, :]
        return torch.sum(stacked * weights, dim=1)

    def extra_repr(self):
        return 'channels={}, branches={}, reduction={}'.format(
            self.channels, self.branches, self.reduction)
533
+
534
+
535
class TransitLayer(nn.Module):
    """Layers named in ``config_str`` followed by a 1x1 convolution.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the convolution.
        bias: Whether the convolution has a bias term.
        config_str: Layer description passed to ``get_nonlinear``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply the configured layers first, then the convolution."""
        return self.linear(self.nonlinear(x))
551
+
552
+
553
class DenseLayer(nn.Module):
    """1x1 convolution followed by the layers named in ``config_str``.

    Accepts 2-D input as well, by adding a trailing dimension around the
    convolution.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the convolution.
        bias: Whether the convolution has a bias term.
        config_str: Layer description passed to ``get_nonlinear``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        """Apply the convolution, then the configured layers."""
        if len(x.shape) != 2:
            projected = self.linear(x)
        else:
            # Conv1d needs a length dimension: add it, convolve, drop it.
            projected = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        return self.nonlinear(projected)
570
+
571
+ from collections import OrderedDict
572
+
573
+ from torch import nn
574
+
575
class DTDNN(nn.Module):
    """Densely connected TDNN with an optional linear classifier.

    ``self.xvector`` chains a ``TDNNLayer``, two ``DenseTDNNBlock`` /
    ``TransitLayer`` pairs, statistics pooling and a ``DenseLayer``.

    Args:
        feat_dim: Input channels of the first TDNN layer.
        embedding_size: Output size of the final dense layer.
        num_classes: Output size of the classifier. When ``None`` no
            classifier is created and ``forward`` returns the embedding.
        growth_rate: Channels added by every dense layer.
        bn_size: Bottleneck width as a multiple of ``growth_rate``.
        init_channels: Output channels of the first TDNN layer.
        config_str: Layer description forwarded to the sub-layers.
        memory_efficient: Forwarded to the dense blocks.
    """

    def __init__(self,
                 feat_dim=30,
                 embedding_size=512,
                 num_classes=None,
                 growth_rate=64,
                 bn_size=2,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(DTDNN, self).__init__()

        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(feat_dim,
                           init_channels,
                           5,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # (num_layers, kernel_size, dilation) for each of the two blocks.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((6, 12), (3, 3), (1, 3))):
            block = DenseTDNNBlock(num_layers=num_layers,
                                   in_channels=channels,
                                   out_channels=growth_rate,
                                   bn_channels=bn_size * growth_rate,
                                   kernel_size=kernel_size,
                                   dilation=dilation,
                                   config_str=config_str,
                                   memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            # The transit layer halves the channel count after each block.
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2
        self.xvector.add_module('stats', StatsPool())
        # StatsPool concatenates mean and std, hence ``channels * 2`` inputs.
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
        if num_classes is not None:
            self.classifier = nn.Linear(embedding_size, num_classes)
        else:
            # Explicit None so forward() can skip the classifier instead of
            # failing with AttributeError.
            self.classifier = None
        # Not used by forward(); an explicit dim avoids the deprecated
        # implicit-dimension behaviour if it is ever enabled.
        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return classifier outputs, or the embedding without a classifier.

        The input is reshaped with ``unsqueeze(1).permute(0, 2, 1)``, i.e. a
        ``(batch, feat_dim)`` tensor becomes ``(batch, feat_dim, 1)``.
        """
        x = x.unsqueeze(1).permute(0, 2, 1)
        x = self.xvector(x)
        if self.classifier is not None:
            x = self.classifier(x)
        return x
637
+
638
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the checkpoint files, which can execute
# arbitrary code. Only load checkpoints from a trusted source.
cnn_model = ShallowCNN(in_features=1, out_dim=1).to(device)
cnn_checkpoint = torch.load("/kaggle/input/deepfakemodels/best_cnn.pt", map_location=device)
cnn_model.load_state_dict(cnn_checkpoint['state_dict'])

lstm_model = SimpleLSTM(feat_dim=40, time_dim=972, mid_dim=30, out_dim=1).to(device)
lstm_checkpoint = torch.load("/kaggle/input/deepfakemodels/best_lstm.pt", map_location=device)
lstm_model.load_state_dict(lstm_checkpoint['state_dict'])

# 38880 == 40 * 972, presumably the flattened feature map (see dtdnn_input).
dtdnn_model = DTDNN(feat_dim=38880, num_classes=1).to(device)
dtdnn_checkpoint = torch.load("/kaggle/input/deepfakemodels/best_tdnn.pt", map_location=device)
dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict'])

# Set models to evaluation mode
cnn_model.eval()
lstm_model.eval()
dtdnn_model.eval()

# Prepare input data: add a batch dimension and move it to the models' device
# (without the move, running on CUDA fails with a device mismatch).
input_data = dataset_train[0][0].unsqueeze(0).to(device)

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # Forward pass through CNN model
    cnn_output = cnn_model(input_data)
    cnn_prob = torch.sigmoid(cnn_output)

    # Forward pass through LSTM model
    lstm_output = lstm_model(input_data)
    lstm_prob = torch.sigmoid(lstm_output)

    # Forward pass through DT-DNN model, which takes the flattened features
    dtdnn_input = input_data.view(input_data.size(0), -1)
    dtdnn_output = dtdnn_model(dtdnn_input)
    dtdnn_prob = torch.sigmoid(dtdnn_output)

    # Combine predictions: unweighted mean of the three probabilities
    combined_prob = (cnn_prob + lstm_prob + dtdnn_prob) / 3

# Classify based on combined probabilities
combined_pred = (combined_prob >= 0.5).int()

print(combined_pred.item())