nachi1326 commited on
Commit
9c711c1
·
verified ·
1 Parent(s): f621eda

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -7
model.py CHANGED
@@ -13,8 +13,6 @@ from typing import Optional
13
  import torchaudio.functional as F
14
  import random
15
 
16
-
17
-
18
  def find_wav_files(path_to_dir: Union[Path, str]):
19
  paths = list(sorted(Path(path_to_dir).glob("**/*.wav")))
20
 
@@ -627,7 +625,7 @@ def pred_audio(path):
627
  audio = [path]
628
 
629
  audio_ds = AudioDataset(audio)
630
- audio_ds = PadDataset(train_dataset)
631
 
632
  audio_ds = mfcc(
633
  directory_or_audiodataset=audio_ds,
@@ -635,20 +633,19 @@ def pred_audio(path):
635
  )
636
 
637
  audio_ds = double_delta(audio_ds)
638
-
639
 
640
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
641
 
642
  cnn_model = ShallowCNN(in_features= 1,out_dim=1).to(device)
643
- cnn_checkpoint = torch.load("models/best_cnn.pt", map_location=device)
644
  cnn_model.load_state_dict(cnn_checkpoint['state_dict'])
645
 
646
  lstm_model = SimpleLSTM(feat_dim= 40, time_dim= 972, mid_dim= 30, out_dim= 1).to(device)
647
- lstm_checkpoint = torch.load("models/best_lstm.pt", map_location=device)
648
  lstm_model.load_state_dict(lstm_checkpoint['state_dict'])
649
 
650
  dtdnn_model = DTDNN(feat_dim= 38880,num_classes= 1).to(device)
651
- dtdnn_checkpoint = torch.load("models/best_tdnn.pt", map_location=device)
652
  dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict'])
653
 
654
  # Set models to evaluation mode
 
13
  import torchaudio.functional as F
14
  import random
15
 
 
 
16
  def find_wav_files(path_to_dir: Union[Path, str]):
17
  paths = list(sorted(Path(path_to_dir).glob("**/*.wav")))
18
 
 
625
  audio = [path]
626
 
627
  audio_ds = AudioDataset(audio)
628
+ audio_ds = PadDataset(audio_ds)
629
 
630
  audio_ds = mfcc(
631
  directory_or_audiodataset=audio_ds,
 
633
  )
634
 
635
  audio_ds = double_delta(audio_ds)
 
636
 
637
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
638
 
639
  cnn_model = ShallowCNN(in_features= 1,out_dim=1).to(device)
640
+ cnn_checkpoint = torch.load("./models/best_cnn.pt", map_location=device)
641
  cnn_model.load_state_dict(cnn_checkpoint['state_dict'])
642
 
643
  lstm_model = SimpleLSTM(feat_dim= 40, time_dim= 972, mid_dim= 30, out_dim= 1).to(device)
644
+ lstm_checkpoint = torch.load("./models/best_lstm.pt", map_location=device)
645
  lstm_model.load_state_dict(lstm_checkpoint['state_dict'])
646
 
647
  dtdnn_model = DTDNN(feat_dim= 38880,num_classes= 1).to(device)
648
+ dtdnn_checkpoint = torch.load("./models/best_tdnn.pt", map_location=device)
649
  dtdnn_model.load_state_dict(dtdnn_checkpoint['state_dict'])
650
 
651
  # Set models to evaluation mode