# -*- coding: utf-8 -*- # SPDX-FileContributor: Karl El Hajal # SPDX-FileContributor: Ali Dulaimi import spaces import gradio as gr import tempfile import matplotlib.pyplot as plt import os from src.pronunciation_checker import PronunciationChecker from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio, get_red_green_segments from datetime import datetime from pathlib import Path import re import json import csv @spaces.GPU def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, labels_data=None, input_number=None): wavlm_layer = int(wavlm_layer) timing_results = [] # can be moved to src/utils def log_timing(step): timing_results.append((step, datetime.now())) log_timing("Start") # ref_wav = denoise_audio(ref_wav) input_audio = denoise_audio(input_audio) log_timing("Input Audio Denoising") # Extract features from both audio files ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio) log_timing("Reference Audio Preprocessing") comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio) log_timing("Input Audio Preprocessing") # Check if waveforms are not empty if ref_wav is None or comparison_wav is None: raise ValueError("One or both of the waveforms are empty.") # Extract features ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer) log_timing("Reference Feature Extraction") input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer) log_timing("Input Feature Extraction") # Compute DTW dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features) log_timing("DTW Computation") quality_score, needs_repeat = assess_pronunciation_quality(dist_matrix, path, threshold, "ref") log_timing("Quality Assessment") # Check if DTW path is valid if path is None or dist_matrix is None: raise ValueError("DTW computation failed.") PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref", threshold, labels_data, input_number) log_timing("Visualization Generation") # Save the visualization to a temporary image file with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp: tmp_path = tmp.name plt.savefig(tmp_path) plt.close() timing_analysis = "Runtime Analysis:\n" for i in range(1, len(timing_results)): step_duration = (timing_results[i][1] - timing_results[i-1][1]).total_seconds() timing_analysis += f"{timing_results[i][0]}: {step_duration:.2f}s\n" # Return the image file path for Gradio to display return tmp_path, quality_score, needs_repeat, timing_analysis def parse_labels_file(labels_file): labels_data = [] with open(labels_file, 'r') as f: for line in f: parts = re.split(r'[\t ]+', line) if len(parts) < 3: continue start, end = map(float, parts[:2]) grapheme = parts[2] boolean_labels = list(map(int, parts[3:])) labels_data.append((start, end, grapheme, *boolean_labels)) return labels_data @spaces.GPU def run_tests(threshold=0.4, wavlm_layer=24): # Create a directory for storing plots if it doesn't exist plots_dir = Path("temp_plots") plots_dir.mkdir(exist_ok=True) benchmark_dir = Path("benchmark") test_cases = [] # Clear any old plot files for old_plot in plots_dir.glob("*.png"): old_plot.unlink() for test_word_dir in benchmark_dir.iterdir(): if test_word_dir.is_dir(): word = test_word_dir.name ref_file = test_word_dir / "ref.wav" labels_file = test_word_dir / "labels.txt" if not ref_file.exists() or not labels_file.exists(): test_cases.append({ "heading": f"Error: Missing files for '{word}'", "plot": None, "text": f"Missing required files (ref.wav or labels.txt) for test word '{word}'." }) continue labels_data = parse_labels_file(labels_file) for input_file in test_word_dir.glob("input_*.wav"): if not input_file.exists(): test_cases.append({ "heading": f"Error: Missing input files for '{word}'", "plot": None, "text": f"No input files found for test word '{word}'." }) continue try: # Generate a unique filename for this test's plot plot_filename = f"{word}_{input_file.stem}_{datetime.now().strftime('%H%M%S')}.png" plot_path = plots_dir / plot_filename input_number = str(input_file.stem).split('_')[1] input_number = int(input_number)-1 # Run the pronunciation checker tmp_path, quality_score, needs_repeat, timing_analysis = check_pronunciation( ref_file, input_file, threshold, wavlm_layer, labels_data, input_number ) # Move the plot to our persistent directory if os.path.exists(tmp_path): os.rename(tmp_path, plot_path) test_cases.append({ "heading": f"Word: {word}, Input: {input_file.name}", "plot": str(plot_path), "text": f"Quality Score: {quality_score:.2f}\n" f"Needs Repeat: {needs_repeat}" }) except Exception as e: test_cases.append({ "heading": f"Error for '{word}' with '{input_file.name}'", "plot": None, "text": str(e) }) return test_cases def display_test_results(threshold, wavlm_layer): test_cases = run_tests(threshold, wavlm_layer) components = [] for case in test_cases: components.append(gr.Markdown(f"### {case['heading']}")) if case["plot"]: components.append(gr.Image(case["plot"], type="filepath", label="Plot")) components.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8)) return components def calculate_red_percentage_demo(red_segments, labels_data, scaling_factor=0.02): red_percentages = [] def intersection_length(start1, end1, start2, end2): overlap_start = max(start1, start2) overlap_end = min(end1, end2) return max(0, overlap_end - overlap_start) for start, end, latin, arabic in labels_data: red_intersection = 0.0 for index in red_segments: red_start_time = index * scaling_factor red_end_time = (index + 1) * scaling_factor red_intersection += intersection_length(start, end, red_start_time, red_end_time) total_grapheme_duration = end - start red_percentage = (red_intersection / total_grapheme_duration) red_percentages.append(min(red_percentage, 1.)) return red_percentages @spaces.GPU def check_pronunciation_demo(reference_audio, input_audio, labels_data, threshold=0.4, wavlm_layer=24): wavlm_layer = int(wavlm_layer) # ref_wav = denoise_audio(ref_wav) input_audio = denoise_audio(input_audio) ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio, do_trim_silences=False) comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio) if ref_wav is None or comparison_wav is None: raise ValueError("One or both of the waveforms are empty.") ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer) input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer) dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features) red_segments, _, _ = get_red_green_segments(dist_matrix, path, wav_type="ref", threshold=threshold) red_percentages = calculate_red_percentage_demo(red_segments, labels_data) is_red = [percentage > 0.0 for percentage in red_percentages] return is_red def parse_tsv(file_path): transcriptions = [] num_to_subtract = None if os.path.exists(file_path): with open(file_path, "r", encoding="utf-8") as f: reader = csv.reader(f, delimiter="\t") for row in reader: if len(row) == 4: # Ensure it has the expected 4 columns start, end, latin, arabic = row start = int(start) end = int(end) if num_to_subtract is None: num_to_subtract = start start -= num_to_subtract end -= num_to_subtract start = float(start)/1000. end = float(end)/1000. transcriptions.append((start, end, latin, arabic)) return transcriptions def collect_demo_data(): dialects = [] themes = set() sentence_ids = set() for dialect in os.listdir(DATA_DIR): dialect_path = os.path.join(DATA_DIR, dialect) if os.path.isdir(dialect_path): dialects.append(dialect) # Collect dialects for theme in os.listdir(dialect_path): theme_path = os.path.join(dialect_path, theme) if os.path.isdir(theme_path): themes.add(theme) # Collect themes for file in os.listdir(theme_path): if file.endswith("_word.wav"): sentence_id = file.split("_")[0] sentence_ids.add(sentence_id) # Collect word IDs return sorted(dialects), sorted(themes), sorted(sentence_ids) @spaces.GPU def run_demo(dialect, theme, sentence_id, input_audio): reference_audio = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_word.wav") label_path = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_labels.tsv") labels_data = parse_tsv(label_path) results = check_pronunciation_demo(reference_audio, input_audio, labels_data) latin_output = [] arabic_output = [] for index, (is_red, (start, end, latin, arabic)) in enumerate(zip(results, labels_data)): latin_output.append({ "index" : index, "letter": latin, "result": not is_red }) arabic_output.append({ "index" : index, "letter": arabic, "result": not is_red }) print(labels_data) print(results) result = { "highlighted_text_id": sentence_id, "highlighted_text_latin_payload": latin_output, "highlighted_text_arabic_payload": arabic_output } return json.dumps(result, indent=2, ensure_ascii=False) DATA_DIR = "data" dialects, themes, sentence_ids = collect_demo_data() with gr.Blocks() as demo: with gr.Tabs(): with gr.Tab("Demo"): dialect_dropdown = gr.Dropdown(choices=dialects, label="Select Dialect") theme_dropdown = gr.Dropdown(choices=themes, label="Select Theme") word_dropdown = gr.Dropdown(choices=sentence_ids, label="Select Sentence ID") input_audio = gr.Audio(type="filepath", label="Reference Audio", format="wav", show_download_button=True) # JSON output output_json = gr.JSON(label="Output JSON") # Button to trigger JSON output submit_btn = gr.Button("Get JSON Output") # JSON output function triggered by button submit_btn.click(run_demo, inputs=[dialect_dropdown, theme_dropdown, word_dropdown, input_audio], outputs=[output_json]) with gr.Tab("Pronunciation Checker"): gr.Interface( fn=check_pronunciation, inputs=[ gr.Audio( type="filepath", label="Reference Audio", format="wav", show_download_button=True, ), gr.Audio( type="filepath", label="Input Audio", format="wav", show_download_button=True, ), gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold"), gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer"), ], outputs=[ gr.Image(type="filepath", label="Pronunciation Comparison"), gr.Number(label="Quality Score"), gr.Textbox(label="Feedback"), gr.Textbox(label="Runtime Analysis") ], live=False, title="Pronunciation Checker", description="Pronunciation Checker" ) with gr.Tab("Run Tests"): with gr.Column(): gr.Markdown("## Run Automated Tests") with gr.Row(): threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold") wavlm_layer_slider = gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer") run_tests_button = gr.Button("Run Tests", variant="primary", interactive=True) # Create static containers for results with gr.Column(visible=False) as results_container: # Pre-create components for up to 10 test cases test_case_components = [] for i in range(10): # Adjust number based on your needs with gr.Column(visible=False) as case_container: heading = gr.Markdown() details = gr.Markdown() image = gr.Image(width=800) separator = gr.Markdown("---") test_case_components.append({ "container": case_container, "heading": heading, "details": details, "image": image, "separator": separator }) def format_test_results(threshold, wavlm_layer): test_cases = run_tests(threshold, wavlm_layer) updates = {results_container: gr.update(visible=True)} # Update pre-created components with new results for i, components in enumerate(test_case_components): if i < len(test_cases) and test_cases[i]['plot']: case = test_cases[i] updates[components["container"]] = gr.update(visible=True) updates[components["heading"]] = f"### {case['heading']}" updates[components["details"]] = case['text'] updates[components["image"]] = case['plot'] else: updates[components["container"]] = gr.update(visible=False) return updates # Update results when button is clicked run_tests_button.click( fn=format_test_results, inputs=[threshold_slider, wavlm_layer_slider], outputs=[results_container] + [comp for components in test_case_components for comp in [components["container"], components["heading"], components["details"], components["image"]]] ) if __name__ == "__main__": pronunciation_checker = PronunciationChecker("microsoft/wavlm-large") demo.launch()