0ahkd1 committed on
Commit
d4d0c2c
·
1 Parent(s): 69e2bf6

refactored red-green decision code

Browse files
app.py CHANGED
@@ -26,17 +26,16 @@ def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, la
26
 
27
  log_timing("Start")
28
 
 
 
 
 
29
  # Extract features from both audio files
30
  ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
31
  log_timing("Reference Audio Preprocessing")
32
 
33
  comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
34
  log_timing("Input Audio Preprocessing")
35
-
36
-
37
- # ref_wav = denoise_audio(ref_wav)
38
- # comparison_wav = denoise_audio(comparison_wav)
39
- # log_timing("Audio Denoising")
40
 
41
  # Check if waveforms are not empty
42
  if ref_wav is None or comparison_wav is None:
 
26
 
27
  log_timing("Start")
28
 
29
+ # ref_wav = denoise_audio(ref_wav)
30
+ input_audio = denoise_audio(input_audio)
31
+ log_timing("Input Audio Denoising")
32
+
33
  # Extract features from both audio files
34
  ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
35
  log_timing("Reference Audio Preprocessing")
36
 
37
  comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
38
  log_timing("Input Audio Preprocessing")
 
 
 
 
 
39
 
40
  # Check if waveforms are not empty
41
  if ref_wav is None or comparison_wav is None:
src/audio_preprocessing.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import webrtcvad
5
  from pydub import AudioSegment
6
  import subprocess
7
- import tempfile
8
 
9
 
10
  VAD_SR = 16000
@@ -106,19 +105,27 @@ def process_wav(wav_path, target_sr, do_trim_silences=True):
106
  return audio
107
 
108
 
109
- def assess_pronunciation_quality(dist_matrix, path):
110
- # Extract distances along the alignment path
111
- path_distances = [dist_matrix[i, j] for i, j in zip(*path)]
112
-
113
- num_wav_frames = len(dist_matrix) # For the reference wav
114
  wav_distances = [0] * num_wav_frames
115
  for (i, j) in zip(*path):
116
- wav_distances[i] = dist_matrix[i, j] # For the reference wav
 
 
 
 
 
117
 
118
- threshold = 0.3
 
 
 
119
 
120
  # Analyze normalized distances
121
- num_red_segments = sum(1 for d in wav_distances if d >= threshold)
122
  total_segments = len(wav_distances)
123
  red_percentage = num_red_segments / total_segments if total_segments > 0 else 0.0
124
 
@@ -128,9 +135,9 @@ def assess_pronunciation_quality(dist_matrix, path):
128
 
129
  # Print debug information
130
  print(f"Raw distance stats:")
131
- print(f" Min distance: {min(path_distances):.4f}")
132
- print(f" Max distance: {max(path_distances):.4f}")
133
- print(f" Mean distance: {np.mean(path_distances):.4f}")
134
  print(f"\nNormalized distance stats:")
135
  print(f" Number of red segments (>= 0.5): {num_red_segments}")
136
  print(f" Total segments: {total_segments}")
@@ -140,7 +147,7 @@ def assess_pronunciation_quality(dist_matrix, path):
140
 
141
 
142
  def denoise_audio(input_audio_path):
143
-
144
  output_audio_path = input_audio_path.replace(".wav", "_denoised.wav")
145
  subprocess.run(["denoise", input_audio_path, output_audio_path, "--plot"], check=True)
146
 
 
4
  import webrtcvad
5
  from pydub import AudioSegment
6
  import subprocess
 
7
 
8
 
9
  VAD_SR = 16000
 
105
  return audio
106
 
107
 
108
def get_red_green_segments(dist_matrix, path, wav_type='ref', threshold=0.3):
    """Split aligned frames into red (mismatch) and green (match) index sets.

    dist_matrix: 2-D pairwise distance matrix (reference rows x comparison
        columns); assumed to support tuple indexing, e.g. a numpy array.
    path: DTW alignment path as a pair of parallel index sequences (rows, cols).
    wav_type: 'ref' scores the reference axis (rows); any other value scores
        the comparison axis (columns).
    threshold: distances >= threshold are flagged red.

    Returns (red_segments, green_segments, wav_distances), where the segment
    lists contain frame indices along the selected axis.
    """
    if wav_type == "ref":
        num_wav_frames = len(dist_matrix)
    else:
        num_wav_frames = len(dist_matrix[0])

    wav_distances = [0] * num_wav_frames
    for i, j in zip(*path):
        # Index along the axis being scored: rows for the reference wav,
        # columns for the comparison wav. Indexing with `i` unconditionally
        # mis-places distances on the comparison axis (and can IndexError
        # when the reference is longer than the comparison).
        index = i if wav_type == "ref" else j
        wav_distances[index] = dist_matrix[i, j]

    red_segments = [k for k, d in enumerate(wav_distances) if d >= threshold]
    green_segments = [k for k, d in enumerate(wav_distances) if d < threshold]

    return red_segments, green_segments, wav_distances
121
 
122
+
123
def assess_pronunciation_quality(dist_matrix, path):
    # Score along the reference axis: wav_type="ref" sizes the per-frame
    # distance list by len(dist_matrix), matching the pre-refactor behavior.
    # (Passing None took the else branch and sized it by the comparison axis,
    # mis-indexing whenever the two waveforms differ in length.)
    # green_segments is unused here.
    red_segments, _, wav_distances = get_red_green_segments(
        dist_matrix, path, wav_type="ref"
    )

    # Analyze normalized distances
    num_red_segments = len(red_segments)
    total_segments = len(wav_distances)
    red_percentage = num_red_segments / total_segments if total_segments > 0 else 0.0
131
 
 
135
 
136
  # Print debug information
137
  print(f"Raw distance stats:")
138
+ print(f" Min distance: {min(wav_distances):.4f}")
139
+ print(f" Max distance: {max(wav_distances):.4f}")
140
+ print(f" Mean distance: {np.mean(wav_distances):.4f}")
141
  print(f"\nNormalized distance stats:")
142
  print(f" Number of red segments (>= 0.5): {num_red_segments}")
143
  print(f" Total segments: {total_segments}")
 
147
 
148
 
149
def denoise_audio(input_audio_path):
    """Run the external `denoise` CLI on a .wav file.

    Writes the denoised audio next to the input as `<name>_denoised.wav`.
    Raises TypeError for non-string paths and CalledProcessError if the
    `denoise` command fails.
    """
    # Validate with a real exception: `assert` is stripped under `python -O`.
    if not isinstance(input_audio_path, str):
        raise TypeError("Input path must be a string")
    output_audio_path = input_audio_path.replace(".wav", "_denoised.wav")
    # List form (shell=False) avoids shell injection; check=True surfaces failures.
    subprocess.run(["denoise", input_audio_path, output_audio_path, "--plot"], check=True)
153
 
src/pronunciation_checker.py CHANGED
@@ -1,14 +1,13 @@
1
  # SPDX-FileContributor: Karl El Hajal
2
 
3
  import torch
4
- import torchaudio
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
  from transformers import AutoFeatureExtractor, AutoModel
8
  from scipy.spatial.distance import cdist
9
  from dtw import accelerated_dtw
10
 
11
- from src.audio_preprocessing import process_wav
12
 
13
  class PronunciationChecker:
14
  def __init__(self, model_name = "microsoft/wavlm-large"):
@@ -76,43 +75,19 @@ class PronunciationChecker:
76
 
77
  fig, ax = plt.subplots(3, 1, figsize=(15, 10), gridspec_kw={'height_ratios': [5, 1, 1]})
78
 
79
- # Plot the reference waveform
80
  ax[0].plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)
81
-
82
- # DTW distance overlay
83
- if wav_type == "ref":
84
- num_wav_frames = len(dist_matrix)
85
- else:
86
- num_wav_frames = len(dist_matrix[0])
87
 
88
- wav_distances = [0] * num_wav_frames
89
 
90
- for (i, j) in zip(*path):
91
- index = i if wav_type == "ref" else j
92
- wav_distances[index] = dist_matrix[i, j]
93
-
94
- # cur_index = -1
95
- # for (i, j) in zip(*path):
96
- # if wav_type == "ref":
97
- # index = i
98
- # else:
99
- # index = j
100
- # if index == cur_index:
101
- # continue
102
- # wav_distances[index] = dist_matrix[i, j]
103
- # cur_index = index
104
-
105
- # Overlay colors based on DTW distances
106
- for index in range(0, num_wav_frames):
107
  start_time = index * scaling_factor
108
  end_time = (index + 1) * scaling_factor
109
- norm_dist = wav_distances[index]
110
-
111
- green_color = float(norm_dist < threshold)
112
- red_color = float(norm_dist >= threshold)
113
 
114
- color = (red_color, green_color, 0) # Green to Red
115
- ax[0].axvspan(start_time, end_time, facecolor=color, alpha=0.5)
 
 
116
 
117
  ax[0].set_xlabel("Time (s)")
118
  ax[0].set_ylabel("Amplitude")
@@ -122,8 +97,6 @@ class PronunciationChecker:
122
  ax[1].set_xlim(ax[0].get_xlim())
123
  ax[2].set_xlim(ax[0].get_xlim())
124
 
125
- print(input_number)
126
-
127
  if labels_data:
128
  for start, end, grapheme, *boolean_labels in labels_data:
129
  ax[0].axvline(start, color='gray', linestyle='--', alpha=0.7)
@@ -143,6 +116,4 @@ class PronunciationChecker:
143
  ax[2].set_title("Boolean Labels")
144
  ax[2].grid(False)
145
 
146
- return fig
147
-
148
-
 
1
  # SPDX-FileContributor: Karl El Hajal
2
 
3
  import torch
 
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from transformers import AutoFeatureExtractor, AutoModel
7
  from scipy.spatial.distance import cdist
8
  from dtw import accelerated_dtw
9
 
10
+ from src.audio_preprocessing import process_wav, get_red_green_segments
11
 
12
  class PronunciationChecker:
13
  def __init__(self, model_name = "microsoft/wavlm-large"):
 
75
 
76
  fig, ax = plt.subplots(3, 1, figsize=(15, 10), gridspec_kw={'height_ratios': [5, 1, 1]})
77
 
 
78
  ax[0].plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)
 
 
 
 
 
 
79
 
80
+ red_segments, green_segments, _ = get_red_green_segments(dist_matrix, path, threshold, wav_type=wav_type)
81
 
82
+ for index in green_segments:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  start_time = index * scaling_factor
84
  end_time = (index + 1) * scaling_factor
85
+ ax[0].axvspan(start_time, end_time, facecolor=(0, 1, 0), alpha=0.5)
 
 
 
86
 
87
+ for index in red_segments:
88
+ start_time = index * scaling_factor
89
+ end_time = (index + 1) * scaling_factor
90
+ ax[0].axvspan(start_time, end_time, facecolor=(1, 0, 0), alpha=0.5)
91
 
92
  ax[0].set_xlabel("Time (s)")
93
  ax[0].set_ylabel("Amplitude")
 
97
  ax[1].set_xlim(ax[0].get_xlim())
98
  ax[2].set_xlim(ax[0].get_xlim())
99
 
 
 
100
  if labels_data:
101
  for start, end, grapheme, *boolean_labels in labels_data:
102
  ax[0].axvline(start, color='gray', linestyle='--', alpha=0.7)
 
116
  ax[2].set_title("Boolean Labels")
117
  ax[2].grid(False)
118
 
119
+ return fig