karlhajal committed on
Commit
863c9ae
·
verified ·
1 Parent(s): 2394567

Update src/pronunciation_checker.py

Browse files
Files changed (1) hide show
  1. src/pronunciation_checker.py +6 -2
src/pronunciation_checker.py CHANGED
@@ -35,15 +35,19 @@ class PronunciationChecker:
35
  return wav
36
 
37
 
38
- def preprocess_wav(self, wav_path):
39
  temp_audio_path = "temp.wav"
40
  audio_segment = process_wav(wav_path, 16000, do_trim_silences=False)
41
  audio_segment.export(temp_audio_path, format="wav")
42
 
43
- wav = self.trim_silences_with_silero_vad(temp_audio_path)
 
 
 
44
 
45
  return wav, self.sr
46
 
 
47
  def extract_features(self, wav, layer=None):
48
  inputs = self.processor(wav.squeeze().to(self.device), sampling_rate=16000, return_tensors="pt", padding=True)
49
  inputs = {key: val.to(self.device) for key, val in inputs.items()}
 
35
  return wav
36
 
37
 
38
def preprocess_wav(self, wav_path, do_trim_silences=True):
    """Convert an audio file to a 16 kHz WAV and load it as a waveform.

    The input is passed through ``process_wav`` at 16000 Hz (with that
    helper's own silence trimming disabled), exported to a temporary WAV
    file, and then read back by one of this class's loaders.

    Args:
        wav_path: Path of the audio file to preprocess.
        do_trim_silences: When true (the default), load the converted
            audio with ``self.trim_silences_with_silero_vad``; otherwise
            load it untrimmed with ``self.read_audio``.

    Returns:
        A ``(wav, self.sr)`` tuple: the loaded waveform and the sample
        rate stored on the instance.
    """
    # Imported locally so this method stands on its own regardless of
    # what the module imports at the top of the file.
    import os
    import tempfile

    audio_segment = process_wav(wav_path, 16000, do_trim_silences=False)

    # Use a unique per-call file rather than a shared "temp.wav" in the
    # working directory: overlapping calls would otherwise overwrite each
    # other's audio, and the file would be left behind afterwards.
    fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # only the path is needed; export() reopens the file
    try:
        audio_segment.export(temp_audio_path, format="wav")

        if do_trim_silences:
            wav = self.trim_silences_with_silero_vad(temp_audio_path)
        else:
            wav = self.read_audio(temp_audio_path)
    finally:
        # Best-effort cleanup; a failed delete must not mask the result
        # or the original exception.
        # NOTE(review): assumes both loaders fully read the file before
        # returning -- confirm against their implementations.
        try:
            os.remove(temp_audio_path)
        except OSError:
            pass

    return wav, self.sr
49
 
50
+
51
  def extract_features(self, wav, layer=None):
52
  inputs = self.processor(wav.squeeze().to(self.device), sampling_rate=16000, return_tensors="pt", padding=True)
53
  inputs = {key: val.to(self.device) for key, val in inputs.items()}