Karl El Hajal committed on
Commit
4baa40f
·
1 Parent(s): d6fcab3

Added code + requirements

Browse files
Files changed (4) hide show
  1. app.py +68 -0
  2. audio_preprocessing.py +103 -0
  3. pronunciation_checker.py +87 -0
  4. requirements.txt +86 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # SPDX-FileContributor: Karl El Hajal
3
+ # SPDX-FileContributor: Ali Dulaimi
4
+
5
+ import gradio as gr
6
+ import tempfile
7
+ import matplotlib.pyplot as plt
8
+ from src.pronunciation_checker import PronunciationChecker
9
+
10
def check_pronunciation(reference_audio, input_audio):
    """Compare two recordings and render a DTW-distance overlay image.

    Both files are preprocessed, turned into model features, aligned with
    DTW, and the alignment distances are drawn over the reference waveform.

    Args:
        reference_audio: File path of the reference recording.
        input_audio: File path of the recording to compare against it.

    Returns:
        Path of a temporary PNG file containing the overlay plot.

    Raises:
        ValueError: If an input is missing, a preprocessed waveform is
            empty, or the DTW computation returns no result.
    """
    # Gradio passes None for an audio component the user left empty.
    if reference_audio is None or input_audio is None:
        raise ValueError("Both a reference and an input audio file are required.")

    # Reuse the checker created once at module level instead of loading the
    # model again for every request.
    checker = pronunciation_checker

    # Hidden-state index handed to extract_features for both recordings.
    layer = 6
    ref_wav, _ = PronunciationChecker.preprocess_wav(reference_audio)
    comparison_wav, _ = PronunciationChecker.preprocess_wav(input_audio)

    # Check if waveforms are not empty (a loaded tensor is never None, so the
    # element count has to be inspected as well).
    for wav in (ref_wav, comparison_wav):
        if wav is None or wav.numel() == 0:
            raise ValueError("One or both of the waveforms are empty.")

    # Extract features
    ref_features, ref_wav, sr = checker.extract_features(ref_wav, layer)
    input_features, comparison_wav, _ = checker.extract_features(comparison_wav, layer)

    # Compute DTW
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)

    # Check if DTW path is valid
    if path is None or dist_matrix is None:
        raise ValueError("DTW computation failed.")

    fig = PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref")

    # Reserve a temporary file name, then save the returned figure explicitly
    # rather than relying on pyplot's implicit "current figure".
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        tmp_path = tmp.name
    try:
        fig.savefig(tmp_path)
    finally:
        # Always release the figure so repeated requests do not accumulate them.
        plt.close(fig)

    # Return the image file path for Gradio to display
    return tmp_path
43
+
44
# Shared checker instance, constructed once when the module is imported.
pronunciation_checker = PronunciationChecker("microsoft/wavlm-large")

# The two audio inputs are declared with type="filepath" and format="wav".
reference_input = gr.Audio(type="filepath", label="Reference Audio", format="wav")
comparison_input = gr.Audio(type="filepath", label="Input Audio", format="wav")

# Create Gradio interface
demo = gr.Interface(
    fn=check_pronunciation,
    inputs=[reference_input, comparison_input],
    outputs=gr.Image(type="filepath"),
    title="Pronunciation Checker",
    description="Compare pronunciation using WavLM and visualize with DTW overlays.",
)


if __name__ == "__main__":
    demo.launch(share=True, height=700)
audio_preprocessing.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileContributor: Karl El Hajal
2
+
3
+ import numpy as np
4
+ import webrtcvad
5
+ from pydub import AudioSegment
6
+
7
+ VAD_SR = 16000
8
+ VAD_MODE = 3 # Aggressiveness level (0-3, where 3 is the most aggressive)
9
+ VAD_FRAME_DURATION = 10 # Frame duration in milliseconds
10
+
11
def get_speech_segments_webrtcvad(audio_array, sample_rate, frame_duration, vad_mode):
    """Locate contiguous speech regions in a sample array with WebRTC VAD.

    Args:
        audio_array: NumPy array of samples; each frame is passed to the VAD
            as raw bytes via ``tobytes()`` (assumes 16-bit mono PCM -- TODO
            confirm against callers).
        sample_rate: Sampling rate of ``audio_array`` in Hz.
        frame_duration: Length of one VAD frame in milliseconds.
        vad_mode: Aggressiveness level passed to ``webrtcvad.Vad``.

    Returns:
        List of ``(start, end)`` sample-index pairs, one per speech run, in
        order of appearance. ``end`` is exclusive.
    """
    vad = webrtcvad.Vad(vad_mode)

    # Convert the frame duration to samples
    frame_duration_samples = int(sample_rate * frame_duration / 1000)
    total_samples = len(audio_array)

    # Detect speech regions using VAD
    speech_segments = []
    start = -1
    for i in range(0, total_samples, frame_duration_samples):
        frame = audio_array[i : i + frame_duration_samples]

        # A trailing partial frame is treated as non-speech. Compare against
        # the real frame length rather than a constant that only matches
        # 10 ms at 16 kHz.
        if len(frame) < frame_duration_samples:
            is_speech = False
        else:
            is_speech = vad.is_speech(frame.tobytes(), sample_rate)

        if is_speech and start == -1:
            start = i
        elif not is_speech and start != -1:
            speech_segments.append((start, i))
            start = -1

    # Speech that runs up to the very last full frame never sees a closing
    # non-speech frame, so close the open segment explicitly.
    if start != -1:
        speech_segments.append((start, total_samples))

    return speech_segments
37
+
38
+
39
def get_start_end_using_vad(audio, sample_rate):
    """Return the start and end time (in seconds) of the speech in ``audio``.

    Speech is detected with ``VAD_MODE``; if nothing is found, detection is
    retried once with the next lower aggressiveness level.

    Args:
        audio: Object exposing ``get_array_of_samples()``.
        sample_rate: Sampling rate of ``audio`` in Hz; used both for the VAD
            and for converting sample indices to seconds.

    Returns:
        ``(start_time, end_time)`` as floats, spanning the first detected
        speech segment's start to the last one's end. If no speech is
        detected in either pass, the full clip ``(0.0, duration)`` is
        returned so that nothing is trimmed.
    """
    audio_array = np.array(audio.get_array_of_samples())

    speech_segments = get_speech_segments_webrtcvad(audio_array, sample_rate, VAD_FRAME_DURATION, VAD_MODE)
    if not speech_segments:
        # Retry less aggressively before giving up.
        speech_segments = get_speech_segments_webrtcvad(audio_array, sample_rate, VAD_FRAME_DURATION, VAD_MODE - 1)

    if not speech_segments:
        # Still nothing: keep the whole clip instead of failing with IndexError.
        return 0.0, float(len(audio_array) / sample_rate)

    start_sample = speech_segments[0][0]
    end_sample = speech_segments[-1][1]

    # Convert with the rate the samples were actually analysed at.
    start_time = float(start_sample / sample_rate)
    end_time = float(end_sample / sample_rate)

    return start_time, end_time
53
+
54
+
55
def trim_silences(audio, target_sr):
    """Cut leading and trailing non-speech from ``audio``.

    The speech boundaries are found on a copy resampled to ``VAD_SR`` and
    then mapped back onto the samples of ``audio``.

    Args:
        audio: pydub ``AudioSegment`` to trim.
        target_sr: Frame rate used to map the detected times onto sample
            indices and stamped on the result (assumes ``audio`` is already
            at this rate -- TODO confirm against callers).

    Returns:
        A new ``AudioSegment`` with the same sample width and channel count
        as ``audio``, containing only the span between the detected speech
        start and end.
    """
    # Analyse a mono, 16-bit copy at the VAD rate; the conversions return new
    # segments, so ``audio`` itself is left untouched.
    vad_audio = audio.set_frame_rate(VAD_SR).set_channels(1).set_sample_width(2)

    start_time, end_time = get_start_end_using_vad(vad_audio, VAD_SR)

    # get_array_of_samples() interleaves channels, so frame indices must be
    # scaled by the channel count to stay aligned on frame boundaries.
    channels = audio.channels
    start_sample_orig_sr = int(start_time * target_sr) * channels
    end_sample_orig_sr = int(end_time * target_sr) * channels

    filtered_audio_array = np.array(audio.get_array_of_samples())
    filtered_audio_array = filtered_audio_array[start_sample_orig_sr:end_sample_orig_sr]

    filtered_audio = AudioSegment(
        filtered_audio_array.tobytes(),
        frame_rate=target_sr,
        sample_width=audio.sample_width,
        channels=channels,
    )

    return filtered_audio
76
+
77
+
78
def match_target_amplitude(audio, target_dBFS):
    """Apply the gain needed to bring ``audio`` to ``target_dBFS``.

    Args:
        audio: Object exposing ``dBFS`` and ``apply_gain(gain)``.
        target_dBFS: Desired loudness in dBFS.

    Returns:
        The result of ``audio.apply_gain`` with the required gain, or
        ``audio`` unchanged when its loudness is ``-inf`` (pure silence),
        since no finite gain can reach the target in that case.
    """
    current_dBFS = audio.dBFS
    if current_dBFS == float("-inf"):
        # Silence: the required gain would be infinite, so leave it as is.
        return audio

    change_in_dBFS = target_dBFS - current_dBFS
    return audio.apply_gain(change_in_dBFS)
81
+
82
+
83
def process_wav(wav_path, target_sr, do_trim_silences=True):
    """Load an audio file and normalise its format and loudness.

    Steps, in order: downmix to mono when the file has more than one
    channel, resample to ``target_sr``, convert to a 2-byte sample width,
    optionally trim silences, then normalise loudness to -20 dBFS.

    Args:
        wav_path: Path of the audio file to load.
        target_sr: Frame rate to resample to, in Hz.
        do_trim_silences: Whether to run ``trim_silences`` on the result.

    Returns:
        The processed ``AudioSegment``.
    """
    segment = AudioSegment.from_file(wav_path)

    # Downmix only when needed.
    if segment.channels > 1:
        segment = segment.set_channels(1)

    # Resample, then force 16-bit samples.
    segment = segment.set_frame_rate(target_sr)
    segment = segment.set_sample_width(2)

    if do_trim_silences:
        segment = trim_silences(segment, target_sr)

    return match_target_amplitude(segment, -20.0)
pronunciation_checker.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileContributor: Karl El Hajal
2
+
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from transformers import AutoFeatureExtractor, AutoModel
8
+ from scipy.spatial.distance import cdist
9
+ from dtw import accelerated_dtw
10
+
11
+ from src.audio_preprocessing import process_wav
12
+
13
class PronunciationChecker:
    """Compare two recordings through model features and DTW alignment.

    The instance holds a pretrained feature extractor and model; the static
    helpers cover audio preprocessing, DTW computation and plotting.
    """

    def __init__(self, model_name="microsoft/wavlm-large"):
        """Load the feature extractor and model named ``model_name``.

        The model is put in eval mode and moved to CUDA when available,
        otherwise to the CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

        self.processor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name).eval().to(self.device)

    @staticmethod
    def preprocess_wav(wav_path):
        """Preprocess an audio file and load it as a tensor.

        The file is run through ``process_wav`` at 16 kHz with silence
        trimming, exported as WAV and read back with ``torchaudio.load``.

        Args:
            wav_path: Path of the audio file to preprocess.

        Returns:
            ``(wav, sr)`` as returned by ``torchaudio.load``.
        """
        import os
        import tempfile

        audio_segment = process_wav(wav_path, 16000, do_trim_silences=True)

        # Use a private temporary directory instead of a fixed "temp.wav" in
        # the working directory, so concurrent calls cannot overwrite each
        # other's file and nothing is left behind afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_audio_path = os.path.join(tmp_dir, "temp.wav")
            # export() hands back the open output file; close it before reading.
            audio_segment.export(temp_audio_path, format="wav").close()
            wav, sr = torchaudio.load(temp_audio_path)
        return wav, sr

    def extract_features(self, wav, layer=None):
        """Run the model on a waveform and return one layer's features.

        Args:
            wav: Waveform tensor (assumes a single 16 kHz mono recording,
                e.g. the output of ``preprocess_wav`` -- TODO confirm).
            layer: Index into the model's ``hidden_states``; when ``None``
                the ``last_hidden_state`` is used.

        Returns:
            ``(features, waveform, 16000)`` where ``features`` and
            ``waveform`` are NumPy arrays.
        """
        # The feature extractor works on host arrays, so hand it NumPy data
        # rather than a tensor that may live on the GPU.
        waveform = wav.squeeze().cpu().numpy()
        inputs = self.processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        if layer is None:
            features = outputs.last_hidden_state
        else:
            features = outputs.hidden_states[layer]

        # Drop only the leading batch axis so a one-frame sequence stays 2-D.
        features = features.squeeze(0).cpu().numpy()
        return features, waveform, 16000

    @staticmethod
    def compute_dtw(ref_features, input_features):
        """Compute the cosine distance matrix and DTW path of two sequences.

        Args:
            ref_features: 2-D feature array of the reference recording.
            input_features: 2-D feature array of the compared recording.

        Returns:
            ``(dist_matrix, path)``: the pairwise cosine distances from
            ``cdist`` and the alignment path from ``accelerated_dtw``.
        """
        distance_metric = "cosine"
        dist_matrix = cdist(ref_features, input_features, metric=distance_metric)
        _, _, _, path = accelerated_dtw(ref_features, input_features, dist=distance_metric)
        return dist_matrix, path

    @staticmethod
    def _normalise_distances(dist_matrix):
        """Min-max scale ``dist_matrix`` to [0, 1]; a constant matrix maps to zeros."""
        dist_matrix = np.asarray(dist_matrix, dtype=float)
        lowest = dist_matrix.min()
        span = dist_matrix.max() - lowest
        if span == 0:
            # All distances equal: avoid 0/0 and treat every step as "close".
            return np.zeros_like(dist_matrix)
        return (dist_matrix - lowest) / span

    @staticmethod
    def plot_waveform_with_overlay(wav, sr, dist_matrix, path, wav_type='ref'):
        """Plot a waveform with red/green spans coloured by DTW distance.

        Each step of the alignment path shades one feature frame of the
        waveform: green when its normalised distance is below 0.5, red
        otherwise.

        Args:
            wav: 1-D array of waveform samples.
            sr: Sampling rate of ``wav`` in Hz.
            dist_matrix: Pairwise distance matrix indexed as ``[i, j]``.
            path: Pair of index sequences ``(i_indices, j_indices)``.
            wav_type: ``"ref"`` to place spans by the first path index, any
                other value to use the second.

        Returns:
            The created matplotlib figure.
        """
        feature_stride = 320  # samples per feature frame used to place spans
        time_ref = np.linspace(0, len(wav) / sr, len(wav))

        fig, ax = plt.subplots(figsize=(15, 6))

        # Plot the reference waveform
        ax.plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)

        # Normalise once up front instead of rescanning the whole matrix for
        # its min/max on every path step.
        norm_matrix = PronunciationChecker._normalise_distances(dist_matrix)
        use_ref_axis = wav_type == "ref"

        # Overlay colors based on DTW distances
        for i, j in zip(*path):
            index = i if use_ref_axis else j
            start_time = index * feature_stride / sr
            end_time = (index + 1) * feature_stride / sr
            norm_dist = norm_matrix[i, j]

            green_color = float(norm_dist < 0.5)
            red_color = float(norm_dist >= 0.5)
            color = (red_color, green_color, 0)  # Green to Red

            ax.axvspan(start_time, end_time, facecolor=color, alpha=0.7)

        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Amplitude")
        ax.set_title("Waveform with DTW Distance Overlay")
        ax.legend()

        return fig
requirements.txt ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.8.0
4
+ certifi==2024.12.14
5
+ charset-normalizer==3.4.1
6
+ click==8.1.8
7
+ contourpy==1.3.1
8
+ cycler==0.12.1
9
+ dtw==1.4.0
10
+ exceptiongroup==1.2.2
11
+ fastapi==0.115.6
12
+ ffmpy==0.5.0
13
+ filelock==3.16.1
14
+ fonttools==4.55.3
15
+ fsspec==2024.12.0
16
+ gradio==5.12.0
17
+ gradio_client==1.5.4
18
+ h11==0.14.0
19
+ httpcore==1.0.7
20
+ httpx==0.28.1
21
+ huggingface-hub==0.27.1
22
+ idna==3.10
23
+ Jinja2==3.1.5
24
+ kiwisolver==1.4.8
25
+ markdown-it-py==3.0.0
26
+ MarkupSafe==2.1.5
27
+ matplotlib==3.10.0
28
+ mdurl==0.1.2
29
+ mpmath==1.3.0
30
+ networkx==3.4.2
31
+ numpy==2.2.1
32
+ nvidia-cublas-cu12==12.4.5.8
33
+ nvidia-cuda-cupti-cu12==12.4.127
34
+ nvidia-cuda-nvrtc-cu12==12.4.127
35
+ nvidia-cuda-runtime-cu12==12.4.127
36
+ nvidia-cudnn-cu12==9.1.0.70
37
+ nvidia-cufft-cu12==11.2.1.3
38
+ nvidia-curand-cu12==10.3.5.147
39
+ nvidia-cusolver-cu12==11.6.1.9
40
+ nvidia-cusparse-cu12==12.3.1.170
41
+ nvidia-nccl-cu12==2.21.5
42
+ nvidia-nvjitlink-cu12==12.4.127
43
+ nvidia-nvtx-cu12==12.4.127
44
+ orjson==3.10.14
45
+ packaging==24.2
46
+ pandas==2.2.3
47
+ pillow==11.1.0
48
+ pip==22.0.2
49
+ pydantic==2.10.5
50
+ pydantic_core==2.27.2
51
+ pydub==0.25.1
52
+ Pygments==2.19.1
53
+ pyparsing==3.2.1
54
+ python-dateutil==2.9.0.post0
55
+ python-multipart==0.0.20
56
+ pytz==2024.2
57
+ PyYAML==6.0.2
58
+ regex==2024.11.6
59
+ requests==2.32.3
60
+ rich==13.9.4
61
+ ruff==0.9.2
62
+ safehttpx==0.1.6
63
+ safetensors==0.5.2
64
+ scipy==1.15.1
65
+ semantic-version==2.10.0
66
+ setuptools==59.6.0
67
+ shellingham==1.5.4
68
+ six==1.17.0
69
+ sniffio==1.3.1
70
+ starlette==0.41.3
71
+ sympy==1.13.1
72
+ tokenizers==0.21.0
73
+ tomlkit==0.13.2
74
+ torch==2.5.1
75
+ torchaudio==2.5.1
76
+ tqdm==4.67.1
77
+ transformers==4.48.0
78
+ triton==3.1.0
79
+ typer==0.15.1
80
+ typing_extensions==4.12.2
81
+ tzdata==2024.2
82
+ urllib3==2.3.0
83
+ uvicorn==0.34.0
84
+ webrtcvad==2.0.10
85
+ websockets==14.1
86
+ wheel==0.37.1