# NOTE(review): the three lines below were non-code residue from the hosting
# page ("Spaces: / Sleeping / Sleeping") and would be a syntax error if left
# bare; they are preserved here as comments only.
# Spaces:
# Sleeping
# Sleeping
| # -*- coding: utf-8 -*- | |
| # SPDX-FileContributor: Karl El Hajal | |
| # SPDX-FileContributor: Ali Dulaimi | |
| import spaces | |
| import gradio as gr | |
| import tempfile | |
| import matplotlib.pyplot as plt | |
| import os | |
| from src.pronunciation_checker import PronunciationChecker | |
| from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio | |
| from datetime import datetime | |
| from pathlib import Path | |
| import re | |
def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, labels_data=None, input_number=None):
    """Compare a learner's recording against a reference pronunciation.

    Pipeline: denoise input -> preprocess both waveforms -> WavLM feature
    extraction (at `wavlm_layer`) -> DTW alignment -> quality assessment ->
    overlay visualization saved to a temporary PNG.

    Args:
        reference_audio: path to the reference audio file.
        input_audio: path to the learner's audio file.
        threshold: decision threshold forwarded to the overlay plot.
        wavlm_layer: WavLM layer index to extract features from (coerced to int).
        labels_data: optional list of (start, end, grapheme, *bools) tuples
            from `parse_labels_file`; forwarded to the plot.
        input_number: optional zero-based index of the input take; forwarded
            to the plot.

    Returns:
        (plot_png_path, quality_score, needs_repeat, timing_analysis_text)

    Raises:
        ValueError: if a preprocessed waveform is empty or DTW fails.
    """
    wavlm_layer = int(wavlm_layer)
    timing_results = []

    # can be moved to src/utils
    def log_timing(step):
        # Record (label, timestamp); per-step durations are derived at the end.
        timing_results.append((step, datetime.now()))

    log_timing("Start")

    # Reference denoising is intentionally disabled; only the user input is denoised.
    # ref_wav = denoise_audio(ref_wav)
    input_audio = denoise_audio(input_audio)
    log_timing("Input Audio Denoising")

    # Preprocess both audio files.
    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
    log_timing("Reference Audio Preprocessing")
    comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
    log_timing("Input Audio Preprocessing")

    # Check if waveforms are not empty.
    if ref_wav is None or comparison_wav is None:
        raise ValueError("One or both of the waveforms are empty.")

    # Extract features.
    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer)
    log_timing("Reference Feature Extraction")
    input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer)
    log_timing("Input Feature Extraction")

    # Compute DTW.
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    log_timing("DTW Computation")

    # BUG FIX: validate the DTW outputs *before* using them. The original code
    # called assess_pronunciation_quality(dist_matrix, path) first and only
    # then checked for None, so the guard could never fire usefully.
    if path is None or dist_matrix is None:
        raise ValueError("DTW computation failed.")

    quality_score, needs_repeat = assess_pronunciation_quality(dist_matrix, path)
    log_timing("Quality Assessment")

    PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref", threshold, labels_data, input_number)
    log_timing("Visualization Generation")

    # Save the visualization to a temporary image file (delete=False so the
    # path survives for Gradio to serve; callers may move/delete it later).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        tmp_path = tmp.name
    plt.savefig(tmp_path)
    plt.close()

    # Summarize wall-clock duration of each pipeline step.
    timing_analysis = "Runtime Analysis:\n"
    for i in range(1, len(timing_results)):
        step_duration = (timing_results[i][1] - timing_results[i-1][1]).total_seconds()
        timing_analysis += f"{timing_results[i][0]}: {step_duration:.2f}s\n"

    # Return the image file path for Gradio to display.
    return tmp_path, quality_score, needs_repeat, timing_analysis
def parse_labels_file(labels_file):
    """Parse a whitespace-delimited label file.

    Each meaningful line is "<start> <end> <grapheme> [bool ...]" with fields
    separated by tabs and/or spaces. Lines with fewer than three fields are
    skipped.

    Args:
        labels_file: path to the label file.

    Returns:
        List of tuples (start: float, end: float, grapheme: str, *bools: int).
    """
    labels_data = []
    with open(labels_file, 'r') as f:
        for line in f:
            # BUG FIX: str.split() collapses whitespace runs AND strips the
            # trailing newline. The original re.split(r'[\t ]+', line) left
            # '\n' glued to the last field (e.g. grapheme == 'ka\n' on
            # three-column lines) and yielded an empty leading field on
            # indented lines, crashing float('').
            parts = line.split()
            if len(parts) < 3:
                continue
            start, end = float(parts[0]), float(parts[1])
            grapheme = parts[2]
            boolean_labels = [int(p) for p in parts[3:]]
            labels_data.append((start, end, grapheme, *boolean_labels))
    return labels_data
def run_tests(threshold=0.4, wavlm_layer=24):
    """Run the pronunciation checker over every word folder in ./benchmark.

    Each subdirectory of `benchmark/` must contain `ref.wav`, `labels.txt`,
    and one or more `input_*.wav` takes. Plots are written into `temp_plots/`
    (cleared at the start of each run).

    Args:
        threshold: decision threshold forwarded to `check_pronunciation`.
        wavlm_layer: WavLM layer index forwarded to `check_pronunciation`.

    Returns:
        List of dicts with keys "heading", "plot" (path or None), "text";
        failures are reported as result entries rather than raised.
    """
    # Create a directory for storing plots if it doesn't exist.
    plots_dir = Path("temp_plots")
    plots_dir.mkdir(exist_ok=True)
    benchmark_dir = Path("benchmark")
    test_cases = []

    # Clear any old plot files so stale images are never shown.
    for old_plot in plots_dir.glob("*.png"):
        old_plot.unlink()

    for test_word_dir in benchmark_dir.iterdir():
        if not test_word_dir.is_dir():
            continue
        word = test_word_dir.name
        ref_file = test_word_dir / "ref.wav"
        labels_file = test_word_dir / "labels.txt"
        if not ref_file.exists() or not labels_file.exists():
            test_cases.append({
                "heading": f"Error: Missing files for '{word}'",
                "plot": None,
                "text": f"Missing required files (ref.wav or labels.txt) for test word '{word}'."
            })
            continue
        labels_data = parse_labels_file(labels_file)

        # BUG FIX: glob() only ever yields existing files, so the original
        # per-file exists() check was dead code and the "No input files found"
        # error could never be reported. Materialize (and sort, for a
        # deterministic order) and test for emptiness instead.
        input_files = sorted(test_word_dir.glob("input_*.wav"))
        if not input_files:
            test_cases.append({
                "heading": f"Error: Missing input files for '{word}'",
                "plot": None,
                "text": f"No input files found for test word '{word}'."
            })
            continue

        for input_file in input_files:
            try:
                # Generate a unique filename for this test's plot.
                plot_filename = f"{word}_{input_file.stem}_{datetime.now().strftime('%H%M%S')}.png"
                plot_path = plots_dir / plot_filename
                # "input_<k>.wav" -> zero-based take index k-1.
                input_number = int(str(input_file.stem).split('_')[1]) - 1

                # Run the pronunciation checker.
                tmp_path, quality_score, needs_repeat, timing_analysis = check_pronunciation(
                    ref_file, input_file, threshold, wavlm_layer, labels_data, input_number
                )

                # Move the plot to our persistent directory.
                if os.path.exists(tmp_path):
                    os.rename(tmp_path, plot_path)
                test_cases.append({
                    "heading": f"Word: {word}, Input: {input_file.name}",
                    "plot": str(plot_path),
                    "text": f"Quality Score: {quality_score:.2f}\n"
                            f"Needs Repeat: {needs_repeat}"
                })
            except Exception as e:
                # Surface per-file failures as result cards instead of
                # aborting the whole benchmark run.
                test_cases.append({
                    "heading": f"Error for '{word}' with '{input_file.name}'",
                    "plot": None,
                    "text": str(e)
                })
    return test_cases
def display_test_results(threshold, wavlm_layer):
    """Render every benchmark result as a flat list of Gradio components.

    Args:
        threshold: decision threshold forwarded to `run_tests`.
        wavlm_layer: WavLM layer index forwarded to `run_tests`.

    Returns:
        List of gr.Markdown / gr.Image / gr.Textbox components, one group
        per test case (the image is omitted when no plot was produced).
    """
    rendered = []
    for case in run_tests(threshold, wavlm_layer):
        rendered.append(gr.Markdown(f"### {case['heading']}"))
        plot_path = case["plot"]
        if plot_path:
            rendered.append(gr.Image(plot_path, type="filepath", label="Plot"))
        rendered.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
    return rendered
# Top-level UI definition: a two-tab Gradio app (single interactive check +
# batch benchmark runner). NOTE(review): original indentation was lost in
# extraction; structure reconstructed from the `with` nesting — verify
# against the running app.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Pronunciation Checker"):
            # One-shot comparison UI backed directly by check_pronunciation.
            gr.Interface(
                fn=check_pronunciation,
                inputs=[
                    gr.Audio(
                        type="filepath",
                        label="Reference Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Audio(
                        type="filepath",
                        label="Input Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold"),
                    gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer"),
                ],
                outputs=[
                    gr.Image(type="filepath", label="Pronunciation Comparison"),
                    gr.Number(label="Quality Score"),
                    gr.Textbox(label="Feedback"),
                    gr.Textbox(label="Runtime Analysis")
                ],
                live=False,
                title="Pronunciation Checker",
                description="Pronunciation Checker"
            )
        with gr.Tab("Run Tests"):
            with gr.Column():
                gr.Markdown("## Run Automated Tests")
                with gr.Row():
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold")
                    wavlm_layer_slider = gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer")
                run_tests_button = gr.Button("Run Tests", variant="primary", interactive=True)
                # Create static containers for results: Gradio outputs must
                # exist at build time, so a fixed pool of hidden slots is
                # pre-created and shown/hidden per run.
                with gr.Column(visible=False) as results_container:
                    # Pre-create components for up to 10 test cases.
                    test_case_components = []
                    for i in range(10):  # Adjust number based on your needs
                        with gr.Column(visible=False) as case_container:
                            heading = gr.Markdown()
                            details = gr.Markdown()
                            image = gr.Image(width=800)
                            separator = gr.Markdown("---")
                        test_case_components.append({
                            "container": case_container,
                            "heading": heading,
                            "details": details,
                            "image": image,
                            "separator": separator
                        })

                def format_test_results(threshold, wavlm_layer):
                    # Map each benchmark result onto a pre-created slot;
                    # returns a component->value dict (Gradio dict-return
                    # style). Slots beyond len(test_cases), and cases with
                    # no plot, are hidden.
                    test_cases = run_tests(threshold, wavlm_layer)
                    updates = {results_container: gr.update(visible=True)}
                    # Update pre-created components with new results.
                    for i, components in enumerate(test_case_components):
                        if i < len(test_cases) and test_cases[i]['plot']:
                            case = test_cases[i]
                            updates[components["container"]] = gr.update(visible=True)
                            updates[components["heading"]] = f"### {case['heading']}"
                            updates[components["details"]] = case['text']
                            updates[components["image"]] = case['plot']
                        else:
                            updates[components["container"]] = gr.update(visible=False)
                    return updates

                # Update results when button is clicked. The outputs list
                # must cover every component format_test_results may address.
                run_tests_button.click(
                    fn=format_test_results,
                    inputs=[threshold_slider, wavlm_layer_slider],
                    outputs=[results_container] + [comp for components in test_case_components
                                                   for comp in [components["container"],
                                                                components["heading"],
                                                                components["details"],
                                                                components["image"]]]
                )
if __name__ == "__main__":
    # Module-level global read by check_pronunciation; constructed here so
    # the (heavy) WavLM model only loads when the app is actually launched.
    # NOTE(review): importing this module without running it leaves
    # check_pronunciation with an undefined `pronunciation_checker`.
    pronunciation_checker = PronunciationChecker("microsoft/wavlm-large")
    demo.launch()