Spaces:
Build error
Build error
Update model.py
Browse files
model.py
CHANGED
|
@@ -13,8 +13,6 @@ from typing import Optional
|
|
| 13 |
import torchaudio.functional as F
|
| 14 |
import random
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
def find_wav_files(path_to_dir: Union[Path, str]):
|
| 19 |
paths = list(sorted(Path(path_to_dir).glob("**/*.wav")))
|
| 20 |
|
|
@@ -627,7 +625,7 @@ def pred_audio(path):
|
|
| 627 |
audio = [path]
|
| 628 |
|
| 629 |
audio_ds = AudioDataset(audio)
|
| 630 |
-
audio_ds = PadDataset(
|
| 631 |
|
| 632 |
audio_ds = mfcc(
|
| 633 |
directory_or_audiodataset=audio_ds,
|
|
@@ -635,20 +633,19 @@ def pred_audio(path):
|
|
| 635 |
)
|
| 636 |
|
| 637 |
audio_ds = double_delta(audio_ds)
|
| 638 |
-
|
| 639 |
|
| 640 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 641 |
|
| 642 |
cnn_model = ShallowCNN(in_features= 1,out_dim=1).to(device)
|
| 643 |
-
cnn_checkpoint = torch.load("models/best_cnn.pt", map_location=device)
|
| 644 |
cnn_model.load_state_dict(cnn_checkpoint['state_dict'])
|
| 645 |
|
| 646 |
lstm_model = SimpleLSTM(feat_dim= 40, time_dim= 972, mid_dim= 30, out_dim= 1).to(device)
|
| 647 |
-
lstm_checkpoint = torch.load("models/best_lstm.pt", map_location=device)
|
| 648 |
lstm_model.load_state_dict(lstm_checkpoint['state_dict'])
|
| 649 |
|
| 650 |
dtdnn_model = DTDNN(feat_dim= 38880,num_classes= 1).to(device)
|
| 651 |
-
dtdnn_checkpoint = torch.load("models/best_tdnn.pt", map_location=device)
|
| 652 |
dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict'])
|
| 653 |
|
| 654 |
# Set models to evaluation mode
|
|
|
|
| 13 |
import torchaudio.functional as F
|
| 14 |
import random
|
| 15 |
|
|
|
|
|
|
|
| 16 |
def find_wav_files(path_to_dir: Union[Path, str]):
|
| 17 |
paths = list(sorted(Path(path_to_dir).glob("**/*.wav")))
|
| 18 |
|
|
|
|
| 625 |
audio = [path]
|
| 626 |
|
| 627 |
audio_ds = AudioDataset(audio)
|
| 628 |
+
audio_ds = PadDataset(audio_ds)
|
| 629 |
|
| 630 |
audio_ds = mfcc(
|
| 631 |
directory_or_audiodataset=audio_ds,
|
|
|
|
| 633 |
)
|
| 634 |
|
| 635 |
audio_ds = double_delta(audio_ds)
|
|
|
|
| 636 |
|
| 637 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 638 |
|
| 639 |
cnn_model = ShallowCNN(in_features= 1,out_dim=1).to(device)
|
| 640 |
+
cnn_checkpoint = torch.load("./models/best_cnn.pt", map_location=device)
|
| 641 |
cnn_model.load_state_dict(cnn_checkpoint['state_dict'])
|
| 642 |
|
| 643 |
lstm_model = SimpleLSTM(feat_dim= 40, time_dim= 972, mid_dim= 30, out_dim= 1).to(device)
|
| 644 |
+
lstm_checkpoint = torch.load("./models/best_lstm.pt", map_location=device)
|
| 645 |
lstm_model.load_state_dict(lstm_checkpoint['state_dict'])
|
| 646 |
|
| 647 |
dtdnn_model = DTDNN(feat_dim= 38880,num_classes= 1).to(device)
|
| 648 |
+
dtdnn_checkpoint = torch.load("./models/best_tdnn.pt", map_location=device)
|
| 649 |
dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict'])
|
| 650 |
|
| 651 |
# Set models to evaluation mode
|