Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # SPDX-FileContributor: Karl El Hajal | |
| # SPDX-FileContributor: Ali Dulaimi | |
| import spaces | |
| import gradio as gr | |
| import tempfile | |
| import matplotlib.pyplot as plt | |
| import os | |
| from src.pronunciation_checker import PronunciationChecker | |
| from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio, get_red_green_segments | |
| from datetime import datetime | |
| from pathlib import Path | |
| import re | |
| import json | |
| import csv | |
def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, labels_data=None, input_number=None):
    """Compare a learner recording against a reference and visualize the result.

    Args:
        reference_audio: path to the reference WAV file.
        input_audio: path to the learner's WAV file (denoised before use).
        threshold: DTW-distance decision threshold in [0, 1].
        wavlm_layer: WavLM layer index to extract features from (cast to int).
        labels_data: optional per-grapheme label tuples for plot annotation.
        input_number: optional index of the input file, forwarded to the plot.

    Returns:
        (plot_image_path, quality_score, needs_repeat, timing_report_str)

    Raises:
        ValueError: if either waveform is empty or DTW fails.
    """
    wavlm_layer = int(wavlm_layer)
    timing_results = []

    # can be moved to src/utils
    def log_timing(step):
        # Record a named checkpoint; durations are derived pairwise below.
        timing_results.append((step, datetime.now()))

    log_timing("Start")
    # Only the learner's input is denoised; the reference is assumed clean.
    # ref_wav = denoise_audio(ref_wav)
    input_audio = denoise_audio(input_audio)
    log_timing("Input Audio Denoising")
    # Preprocess both audio files
    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
    log_timing("Reference Audio Preprocessing")
    comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
    log_timing("Input Audio Preprocessing")
    # Check if waveforms are not empty
    if ref_wav is None or comparison_wav is None:
        raise ValueError("One or both of the waveforms are empty.")
    # Extract features
    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer)
    log_timing("Reference Feature Extraction")
    input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer)
    log_timing("Input Feature Extraction")
    # Compute DTW
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    log_timing("DTW Computation")
    # FIX: validate the DTW result BEFORE consuming it (the original checked
    # validity only after assess_pronunciation_quality had already used it).
    if path is None or dist_matrix is None:
        raise ValueError("DTW computation failed.")
    quality_score, needs_repeat = assess_pronunciation_quality(dist_matrix, path, threshold, "ref")
    log_timing("Quality Assessment")
    PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref", threshold, labels_data, input_number)
    log_timing("Visualization Generation")
    # Reserve a temp filename, then save AFTER the handle is closed
    # (writing while the NamedTemporaryFile is still open fails on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        tmp_path = tmp.name
    plt.savefig(tmp_path)
    plt.close()
    # Build the per-step runtime report from consecutive checkpoint pairs.
    timing_analysis = "Runtime Analysis:\n"
    for (_, prev_time), (step, cur_time) in zip(timing_results, timing_results[1:]):
        timing_analysis += f"{step}: {(cur_time - prev_time).total_seconds():.2f}s\n"
    # Return the image file path for Gradio to display
    return tmp_path, quality_score, needs_repeat, timing_analysis
def parse_labels_file(labels_file):
    """Parse a whitespace-delimited labels file.

    Each useful line is: ``start end grapheme [flag ...]`` where start/end are
    floats (seconds — presumably; confirm against benchmark data) and flags are
    integer booleans. Lines with fewer than three fields are skipped.

    Returns:
        list of tuples ``(start, end, grapheme, *flags)``.
    """
    labels_data = []
    with open(labels_file, 'r') as f:
        for line in f:
            # FIX: str.split() handles runs of any whitespace, strips the
            # trailing newline, and never yields empty leading fields —
            # re.split(r'[\t ]+', line) crashed on indented lines (float(''))
            # and left '\n' glued to the last field.
            parts = line.split()
            if len(parts) < 3:
                continue
            start, end = map(float, parts[:2])
            grapheme = parts[2]
            boolean_labels = list(map(int, parts[3:]))
            labels_data.append((start, end, grapheme, *boolean_labels))
    return labels_data
def run_tests(threshold=0.4, wavlm_layer=24):
    """Run the pronunciation checker over every benchmark word/input pair.

    Expects a ``benchmark/<word>/`` layout containing ``ref.wav``,
    ``labels.txt`` and one or more ``input_*.wav`` files.

    Returns:
        list of dicts with keys ``heading``, ``plot`` (path or None) and
        ``text`` — one entry per input file, or an error entry when files
        are missing or a check raises.
    """
    # Create a directory for storing plots if it doesn't exist
    plots_dir = Path("temp_plots")
    plots_dir.mkdir(exist_ok=True)
    benchmark_dir = Path("benchmark")
    test_cases = []
    # Clear any old plot files
    for old_plot in plots_dir.glob("*.png"):
        old_plot.unlink()
    for test_word_dir in benchmark_dir.iterdir():
        if not test_word_dir.is_dir():
            continue
        word = test_word_dir.name
        ref_file = test_word_dir / "ref.wav"
        labels_file = test_word_dir / "labels.txt"
        if not ref_file.exists() or not labels_file.exists():
            test_cases.append({
                "heading": f"Error: Missing files for '{word}'",
                "plot": None,
                "text": f"Missing required files (ref.wav or labels.txt) for test word '{word}'."
            })
            continue
        labels_data = parse_labels_file(labels_file)
        # FIX: the original checked input_file.exists() INSIDE the glob loop
        # (dead code — glob only yields existing files), so the "no input
        # files" error could never be reported. Materialize the glob and
        # report the empty case explicitly.
        input_files = sorted(test_word_dir.glob("input_*.wav"))
        if not input_files:
            test_cases.append({
                "heading": f"Error: Missing input files for '{word}'",
                "plot": None,
                "text": f"No input files found for test word '{word}'."
            })
            continue
        for input_file in input_files:
            try:
                # Generate a unique filename for this test's plot
                plot_filename = f"{word}_{input_file.stem}_{datetime.now().strftime('%H%M%S')}.png"
                plot_path = plots_dir / plot_filename
                # "input_3.wav" -> 0-based index 2, forwarded to the plotter.
                input_number = int(str(input_file.stem).split('_')[1]) - 1
                # Run the pronunciation checker
                tmp_path, quality_score, needs_repeat, timing_analysis = check_pronunciation(
                    ref_file, input_file, threshold, wavlm_layer, labels_data, input_number
                )
                # Move the plot to our persistent directory
                if os.path.exists(tmp_path):
                    os.rename(tmp_path, plot_path)
                test_cases.append({
                    "heading": f"Word: {word}, Input: {input_file.name}",
                    "plot": str(plot_path),
                    "text": f"Quality Score: {quality_score:.2f}\n"
                            f"Needs Repeat: {needs_repeat}"
                })
            except Exception as e:
                # Surface per-file failures in the UI instead of aborting the run.
                test_cases.append({
                    "heading": f"Error for '{word}' with '{input_file.name}'",
                    "plot": None,
                    "text": str(e)
                })
    return test_cases
def display_test_results(threshold, wavlm_layer):
    """Render the benchmark results as a flat list of Gradio components.

    For every test case: a Markdown heading, an optional plot image
    (skipped for error entries, whose ``plot`` is None), and a read-only
    details textbox.
    """
    rendered = []
    for case in run_tests(threshold, wavlm_layer):
        rendered.append(gr.Markdown(f"### {case['heading']}"))
        if case["plot"]:
            rendered.append(gr.Image(case["plot"], type="filepath", label="Plot"))
        rendered.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
    return rendered
def calculate_red_percentage_demo(red_segments, labels_data, scaling_factor=0.02):
    """Fraction of each labelled grapheme covered by 'red' feature frames.

    Args:
        red_segments: iterable of frame indices flagged as mispronounced.
        labels_data: tuples ``(start, end, latin, arabic)`` with times in
            seconds.
        scaling_factor: seconds per feature frame (0.02 — presumably the
            WavLM ~20 ms frame hop; confirm against the feature extractor).

    Returns:
        list of floats in [0, 1], one per label, capped at 1.0.
    """
    def _overlap(a_start, a_end, b_start, b_end):
        # Length of the intersection of [a_start, a_end] and [b_start, b_end].
        return max(0, min(a_end, b_end) - max(a_start, b_start))

    fractions = []
    for start, end, latin, arabic in labels_data:
        covered = sum(
            _overlap(start, end, idx * scaling_factor, (idx + 1) * scaling_factor)
            for idx in red_segments
        )
        fractions.append(min(covered / (end - start), 1.))
    return fractions
def check_pronunciation_demo(reference_audio, input_audio, labels_data, threshold=0.4, wavlm_layer=24):
    """Per-grapheme pass/fail check used by the demo tab.

    Runs the same feature-extraction + DTW pipeline as check_pronunciation,
    then maps red (mispronounced) frame segments onto the labelled graphemes.

    Returns:
        list[bool], one per label in ``labels_data``; True means the
        grapheme overlaps at least one red segment.

    Raises:
        ValueError: if either waveform is empty.
    """
    layer = int(wavlm_layer)
    # Only the learner's recording is denoised; the reference is used as-is.
    # ref_wav = denoise_audio(ref_wav)
    denoised_input = denoise_audio(input_audio)
    # Silence trimming is disabled on the reference so label timestamps
    # stay aligned with the audio.
    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio, do_trim_silences=False)
    cmp_wav, _ = pronunciation_checker.preprocess_wav(denoised_input)
    if ref_wav is None or cmp_wav is None:
        raise ValueError("One or both of the waveforms are empty.")
    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, layer)
    cmp_features, cmp_wav, _ = pronunciation_checker.extract_features(cmp_wav, layer)
    dist_matrix, dtw_path = PronunciationChecker.compute_dtw(ref_features, cmp_features)
    red_segments, _, _ = get_red_green_segments(dist_matrix, dtw_path, wav_type="ref", threshold=threshold)
    red_fractions = calculate_red_percentage_demo(red_segments, labels_data)
    return [fraction > 0.0 for fraction in red_fractions]
def parse_tsv(file_path):
    """Load a 4-column label TSV: ``start_ms  end_ms  latin  arabic``.

    Timestamps are re-based so the first row starts at 0, then converted
    from milliseconds to seconds. Rows without exactly 4 columns are
    skipped; a missing file yields an empty list.

    Returns:
        list of tuples ``(start_s, end_s, latin, arabic)``.
    """
    transcriptions = []
    if not os.path.exists(file_path):
        return transcriptions
    offset = None
    with open(file_path, "r", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            if len(row) != 4:  # Ensure it has the expected 4 columns
                continue
            start_ms, end_ms, latin, arabic = row
            start_ms = int(start_ms)
            end_ms = int(end_ms)
            if offset is None:
                # Anchor all times to the first row's start.
                offset = start_ms
            transcriptions.append((
                float(start_ms - offset) / 1000.,
                float(end_ms - offset) / 1000.,
                latin,
                arabic,
            ))
    return transcriptions
def collect_demo_data(data_dir=None):
    """Scan the demo data tree and collect dropdown choices.

    Expected layout: ``<data_dir>/<dialect>/<theme>/<sentence_id>_word.wav``.

    Args:
        data_dir: root folder to scan. Defaults to the module-level
            DATA_DIR (resolved at call time, since DATA_DIR is defined
            after this function in the file).

    Returns:
        (sorted dialects, sorted themes, sorted sentence ids) — all lists.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    dialects = []
    themes = set()
    sentence_ids = set()
    for dialect in os.listdir(data_dir):
        dialect_path = os.path.join(data_dir, dialect)
        if not os.path.isdir(dialect_path):
            continue
        dialects.append(dialect)  # Collect dialects
        for theme in os.listdir(dialect_path):
            theme_path = os.path.join(dialect_path, theme)
            if not os.path.isdir(theme_path):
                continue
            themes.add(theme)  # Collect themes
            for file in os.listdir(theme_path):
                if file.endswith("_word.wav"):
                    sentence_ids.add(file.split("_")[0])  # Collect word IDs
    return sorted(dialects), sorted(themes), sorted(sentence_ids)
def run_demo(dialect, theme, sentence_id, input_audio):
    """Run the per-grapheme demo check and return a JSON payload string.

    Resolves the reference WAV and label TSV from DATA_DIR, runs
    check_pronunciation_demo, and builds parallel latin/arabic payloads
    where ``result`` is True when the grapheme passed (was NOT red).

    Returns:
        JSON string (indent=2, non-ASCII preserved) with keys
        ``highlighted_text_id``, ``highlighted_text_latin_payload``,
        ``highlighted_text_arabic_payload``.
    """
    reference_audio = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_word.wav")
    label_path = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_labels.tsv")
    labels_data = parse_tsv(label_path)
    results = check_pronunciation_demo(reference_audio, input_audio, labels_data)
    latin_output = []
    arabic_output = []
    # FIX: deduplicated the twin dict construction; also removed the leftover
    # debug print(labels_data)/print(results) statements.
    for index, (is_red, (start, end, latin, arabic)) in enumerate(zip(results, labels_data)):
        passed = not is_red
        latin_output.append({"index": index, "letter": latin, "result": passed})
        arabic_output.append({"index": index, "letter": arabic, "result": passed})
    result = {
        "highlighted_text_id": sentence_id,
        "highlighted_text_latin_payload": latin_output,
        "highlighted_text_arabic_payload": arabic_output
    }
    return json.dumps(result, indent=2, ensure_ascii=False)
# Root of the demo data tree: <DATA_DIR>/<dialect>/<theme>/<id>_word.wav + <id>_labels.tsv
DATA_DIR = "data"
# Dropdown choices are computed once, at import time, by scanning DATA_DIR.
dialects, themes, sentence_ids = collect_demo_data()
# Gradio UI: three tabs — the JSON demo, a manual checker, and the benchmark runner.
with gr.Blocks() as demo:
    with gr.Tabs():
        # Tab 1: pick dialect/theme/sentence, record audio, get the JSON payload.
        with gr.Tab("Demo"):
            dialect_dropdown = gr.Dropdown(choices=dialects, label="Select Dialect")
            theme_dropdown = gr.Dropdown(choices=themes, label="Select Theme")
            word_dropdown = gr.Dropdown(choices=sentence_ids, label="Select Sentence ID")
            input_audio = gr.Audio(type="filepath", label="Reference Audio", format="wav", show_download_button=True)
            # JSON output
            output_json = gr.JSON(label="Output JSON")
            # Button to trigger JSON output
            submit_btn = gr.Button("Get JSON Output")
            # JSON output function triggered by button
            submit_btn.click(run_demo, inputs=[dialect_dropdown, theme_dropdown, word_dropdown, input_audio], outputs=[output_json])
        # Tab 2: manual two-file comparison with tunable threshold and WavLM layer.
        with gr.Tab("Pronunciation Checker"):
            gr.Interface(
                fn=check_pronunciation,
                inputs=[
                    gr.Audio(
                        type="filepath",
                        label="Reference Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Audio(
                        type="filepath",
                        label="Input Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold"),
                    gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer"),
                ],
                outputs=[
                    gr.Image(type="filepath", label="Pronunciation Comparison"),
                    gr.Number(label="Quality Score"),
                    gr.Textbox(label="Feedback"),
                    gr.Textbox(label="Runtime Analysis")
                ],
                live=False,
                title="Pronunciation Checker",
                description="Pronunciation Checker"
            )
        # Tab 3: run the benchmark suite and show results in pre-created slots.
        with gr.Tab("Run Tests"):
            with gr.Column():
                gr.Markdown("## Run Automated Tests")
                with gr.Row():
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold")
                    wavlm_layer_slider = gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer")
                run_tests_button = gr.Button("Run Tests", variant="primary", interactive=True)
            # Create static containers for results
            with gr.Column(visible=False) as results_container:
                # Pre-create components for up to 10 test cases
                test_case_components = []
                for i in range(10):  # Adjust number based on your needs
                    with gr.Column(visible=False) as case_container:
                        heading = gr.Markdown()
                        details = gr.Markdown()
                        image = gr.Image(width=800)
                        separator = gr.Markdown("---")
                    test_case_components.append({
                        "container": case_container,
                        "heading": heading,
                        "details": details,
                        "image": image,
                        "separator": separator
                    })

            def format_test_results(threshold, wavlm_layer):
                """Run the benchmark and fill the pre-created component slots.

                Returns a dict keyed by component, which Gradio accepts as an
                update mapping. NOTE(review): cases beyond the 10 slots, and
                error cases (plot is None), are hidden rather than shown.
                """
                test_cases = run_tests(threshold, wavlm_layer)
                updates = {results_container: gr.update(visible=True)}
                # Update pre-created components with new results
                for i, components in enumerate(test_case_components):
                    if i < len(test_cases) and test_cases[i]['plot']:
                        case = test_cases[i]
                        updates[components["container"]] = gr.update(visible=True)
                        updates[components["heading"]] = f"### {case['heading']}"
                        updates[components["details"]] = case['text']
                        updates[components["image"]] = case['plot']
                    else:
                        updates[components["container"]] = gr.update(visible=False)
                return updates

            # Update results when button is clicked
            run_tests_button.click(
                fn=format_test_results,
                inputs=[threshold_slider, wavlm_layer_slider],
                outputs=[results_container] + [comp for components in test_case_components
                                               for comp in [components["container"],
                                                            components["heading"],
                                                            components["details"],
                                                            components["image"]]]
            )
if __name__ == "__main__":
    # The model is loaded only when run as a script; the Gradio callbacks
    # above read this module-level global at request time.
    pronunciation_checker = PronunciationChecker("microsoft/wavlm-large")
    demo.launch()