# Hugging Face page header captured along with the file (not Python code):
# author "karlhajal"; commit 83f9419 (verified):
# "Set trim silences to false for demo reference"
# -*- coding: utf-8 -*-
# SPDX-FileContributor: Karl El Hajal
# SPDX-FileContributor: Ali Dulaimi
import spaces
import gradio as gr
import tempfile
import matplotlib.pyplot as plt
import os
from src.pronunciation_checker import PronunciationChecker
from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio, get_red_green_segments
from datetime import datetime
from pathlib import Path
import re
import json
import csv
@spaces.GPU
def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, labels_data=None, input_number=None):
    """Compare an input recording against a reference and score pronunciation.

    Pipeline: denoise the input, preprocess both waveforms, extract WavLM
    features, align them with DTW, assess quality, and render a waveform
    visualization to a temporary PNG.

    Args:
        reference_audio: Path to the reference audio file.
        input_audio: Path to the learner's audio file.
        threshold: Decision threshold for flagging mispronounced segments.
        wavlm_layer: WavLM hidden layer to extract features from (coerced to int).
        labels_data: Optional per-grapheme label tuples used to annotate the plot.
        input_number: Optional index selecting which label column applies to
            this input recording.

    Returns:
        Tuple of (plot image path, quality score, needs-repeat feedback,
        per-step runtime analysis string).

    Raises:
        ValueError: If either waveform is empty or DTW computation fails.
    """
    wavlm_layer = int(wavlm_layer)
    timing_results = []

    # Records (step name, timestamp) pairs; could be moved to src/utils.
    def log_timing(step):
        timing_results.append((step, datetime.now()))

    log_timing("Start")
    # Only the learner's input is denoised; the reference is assumed clean.
    input_audio = denoise_audio(input_audio)
    log_timing("Input Audio Denoising")

    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
    log_timing("Reference Audio Preprocessing")
    comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
    log_timing("Input Audio Preprocessing")

    if ref_wav is None or comparison_wav is None:
        raise ValueError("One or both of the waveforms are empty.")

    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer)
    log_timing("Reference Feature Extraction")
    input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer)
    log_timing("Input Feature Extraction")

    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    log_timing("DTW Computation")
    # Validate the DTW result *before* using it (the original checked only
    # after it had already been passed to the quality assessment).
    if path is None or dist_matrix is None:
        raise ValueError("DTW computation failed.")

    quality_score, needs_repeat = assess_pronunciation_quality(dist_matrix, path, threshold, "ref")
    log_timing("Quality Assessment")

    PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref", threshold, labels_data, input_number)
    log_timing("Visualization Generation")

    # Persist the current matplotlib figure so Gradio can display it.
    # delete=False: the file must outlive this function; run_tests() moves it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        tmp_path = tmp.name
        plt.savefig(tmp_path)
    plt.close()

    # Report the wall-clock duration of each pipeline step.
    timing_analysis = "Runtime Analysis:\n"
    for i in range(1, len(timing_results)):
        step_duration = (timing_results[i][1] - timing_results[i - 1][1]).total_seconds()
        timing_analysis += f"{timing_results[i][0]}: {step_duration:.2f}s\n"

    return tmp_path, quality_score, needs_repeat, timing_analysis
def parse_labels_file(labels_file):
    """Parse a whitespace-delimited grapheme labels file.

    Each useful line has the form:
        <start> <end> <grapheme> [<flag> ...]
    where start/end are seconds (floats) and the optional trailing flags are
    0/1 integers. Lines with fewer than three fields are ignored.

    Args:
        labels_file: Path to the labels text file.

    Returns:
        List of (start, end, grapheme, *flags) tuples.
    """
    labels_data = []
    with open(labels_file, 'r') as f:
        for line in f:
            # Strip the line first: without this, a trailing space before the
            # newline produced a '\n' token and int('\n') raised, and leading
            # whitespace produced an empty first field for float('').
            parts = re.split(r'[\t ]+', line.strip())
            if len(parts) < 3:
                continue
            start, end = map(float, parts[:2])
            grapheme = parts[2]
            boolean_labels = list(map(int, parts[3:]))
            labels_data.append((start, end, grapheme, *boolean_labels))
    return labels_data
@spaces.GPU
def run_tests(threshold=0.4, wavlm_layer=24):
    """Run the pronunciation checker over every word in the benchmark dir.

    For each benchmark/<word>/ directory containing ref.wav and labels.txt,
    every input_*.wav is scored against the reference; per-case plots are
    written to temp_plots/.

    Args:
        threshold: Decision threshold forwarded to check_pronunciation.
        wavlm_layer: WavLM layer forwarded to check_pronunciation.

    Returns:
        List of dicts with keys "heading", "plot" (path or None) and "text".
    """
    # Create a directory for storing plots if it doesn't exist.
    plots_dir = Path("temp_plots")
    plots_dir.mkdir(exist_ok=True)
    benchmark_dir = Path("benchmark")
    test_cases = []

    # Clear stale plots from previous runs.
    for old_plot in plots_dir.glob("*.png"):
        old_plot.unlink()

    for test_word_dir in benchmark_dir.iterdir():
        if not test_word_dir.is_dir():
            continue
        word = test_word_dir.name
        ref_file = test_word_dir / "ref.wav"
        labels_file = test_word_dir / "labels.txt"
        if not ref_file.exists() or not labels_file.exists():
            test_cases.append({
                "heading": f"Error: Missing files for '{word}'",
                "plot": None,
                "text": f"Missing required files (ref.wav or labels.txt) for test word '{word}'."
            })
            continue
        labels_data = parse_labels_file(labels_file)

        # glob() only yields files that exist, so the original per-file
        # exists() check was dead code and the "no input files" error could
        # never fire; detect the empty case up front instead.
        input_files = sorted(test_word_dir.glob("input_*.wav"))
        if not input_files:
            test_cases.append({
                "heading": f"Error: Missing input files for '{word}'",
                "plot": None,
                "text": f"No input files found for test word '{word}'."
            })
            continue

        for input_file in input_files:
            try:
                # Generate a unique filename for this test's plot.
                plot_filename = f"{word}_{input_file.stem}_{datetime.now().strftime('%H%M%S')}.png"
                plot_path = plots_dir / plot_filename
                # "input_NN" -> zero-based index NN-1 selecting the label column.
                input_number = int(str(input_file.stem).split('_')[1]) - 1
                # Run the pronunciation checker.
                tmp_path, quality_score, needs_repeat, timing_analysis = check_pronunciation(
                    ref_file, input_file, threshold, wavlm_layer, labels_data, input_number
                )
                # Move the plot out of the temp dir so it survives the request.
                if os.path.exists(tmp_path):
                    os.rename(tmp_path, plot_path)
                test_cases.append({
                    "heading": f"Word: {word}, Input: {input_file.name}",
                    "plot": str(plot_path),
                    "text": f"Quality Score: {quality_score:.2f}\n"
                            f"Needs Repeat: {needs_repeat}"
                })
            except Exception as e:
                # Surface per-case failures in the results instead of aborting
                # the whole benchmark run.
                test_cases.append({
                    "heading": f"Error for '{word}' with '{input_file.name}'",
                    "plot": None,
                    "text": str(e)
                })
    return test_cases
def display_test_results(threshold, wavlm_layer):
    """Render the run_tests() output as a flat list of Gradio components.

    Each test case becomes a heading, an optional plot image, and a
    read-only details textbox, in that order.
    """
    rendered = []
    for case in run_tests(threshold, wavlm_layer):
        rendered.append(gr.Markdown(f"### {case['heading']}"))
        plot_path = case["plot"]
        if plot_path:
            rendered.append(gr.Image(plot_path, type="filepath", label="Plot"))
        rendered.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
    return rendered
def calculate_red_percentage_demo(red_segments, labels_data, scaling_factor=0.02):
    """Compute, per grapheme, the fraction of its duration flagged as red.

    Args:
        red_segments: Iterable of frame indices flagged as mispronounced; each
            index i covers the time span [i, i+1) * scaling_factor seconds.
        labels_data: Iterable of (start, end, latin, arabic) grapheme spans,
            with start/end in seconds.
        scaling_factor: Duration of one red segment in seconds (default 0.02).

    Returns:
        List of floats in [0, 1], one per grapheme, in labels_data order.
    """
    def intersection_length(start1, end1, start2, end2):
        # Length of the overlap between [start1, end1] and [start2, end2].
        overlap_start = max(start1, start2)
        overlap_end = min(end1, end2)
        return max(0, overlap_end - overlap_start)

    red_percentages = []
    for start, end, latin, arabic in labels_data:
        red_intersection = 0.0
        for index in red_segments:
            red_start_time = index * scaling_factor
            red_end_time = (index + 1) * scaling_factor
            red_intersection += intersection_length(start, end, red_start_time, red_end_time)
        total_grapheme_duration = end - start
        # Guard against zero-length (or inverted) grapheme spans, which would
        # otherwise raise ZeroDivisionError.
        if total_grapheme_duration <= 0:
            red_percentages.append(0.0)
            continue
        red_percentages.append(min(red_intersection / total_grapheme_duration, 1.))
    return red_percentages
@spaces.GPU
def check_pronunciation_demo(reference_audio, input_audio, labels_data, threshold=0.4, wavlm_layer=24):
    """Score the demo input against the reference and flag red graphemes.

    Args:
        reference_audio: Path to the reference recording.
        input_audio: Path to the learner's recording.
        labels_data: (start, end, latin, arabic) grapheme spans in seconds.
        threshold: Decision threshold for red/green segmentation.
        wavlm_layer: WavLM hidden layer to extract features from.

    Returns:
        List of booleans, one per labels_data entry: True when the grapheme
        overlaps any segment flagged as mispronounced.
    """
    layer = int(wavlm_layer)
    # Denoise only the learner's recording.
    input_audio = denoise_audio(input_audio)
    # Silence trimming is disabled for the reference so label timings stay
    # aligned with the waveform.
    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio, do_trim_silences=False)
    comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
    if ref_wav is None or comparison_wav is None:
        raise ValueError("One or both of the waveforms are empty.")
    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, layer)
    input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, layer)
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    red_segments, _, _ = get_red_green_segments(dist_matrix, path, wav_type="ref", threshold=threshold)
    red_fractions = calculate_red_percentage_demo(red_segments, labels_data)
    return [fraction > 0.0 for fraction in red_fractions]
def parse_tsv(file_path):
    """Read a 4-column TSV of (start_ms, end_ms, latin, arabic) labels.

    Timestamps are shifted so the first row starts at zero, then converted
    from milliseconds to seconds. Rows without exactly four columns are
    skipped. Returns an empty list when the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    transcriptions = []
    offset_ms = None
    with open(file_path, "r", encoding="utf-8") as handle:
        for row in csv.reader(handle, delimiter="\t"):
            if len(row) != 4:  # Ensure it has the expected 4 columns
                continue
            start_ms = int(row[0])
            end_ms = int(row[1])
            # The first row's start defines the zero point of the timeline.
            if offset_ms is None:
                offset_ms = start_ms
            start_sec = float(start_ms - offset_ms) / 1000.
            end_sec = float(end_ms - offset_ms) / 1000.
            transcriptions.append((start_sec, end_sec, row[2], row[3]))
    return transcriptions
def collect_demo_data():
    """Scan DATA_DIR for available dialects, themes, and sentence ids.

    Expects the layout DATA_DIR/<dialect>/<theme>/<id>_word.wav.

    Returns:
        Three sorted lists: (dialects, themes, sentence_ids).
    """
    dialects = []
    themes = set()
    sentence_ids = set()
    for dialect in os.listdir(DATA_DIR):
        dialect_path = os.path.join(DATA_DIR, dialect)
        if not os.path.isdir(dialect_path):
            continue
        dialects.append(dialect)
        for theme in os.listdir(dialect_path):
            theme_path = os.path.join(dialect_path, theme)
            if not os.path.isdir(theme_path):
                continue
            themes.add(theme)
            for filename in os.listdir(theme_path):
                if filename.endswith("_word.wav"):
                    # "<sentence_id>_word.wav" -> "<sentence_id>"
                    sentence_ids.add(filename.split("_")[0])
    return sorted(dialects), sorted(themes), sorted(sentence_ids)
@spaces.GPU
def run_demo(dialect, theme, sentence_id, input_audio):
    """Score a learner's recording for the demo tab.

    Looks up the reference audio and label TSV for the selected
    dialect/theme/sentence, runs the demo pronunciation check, and packages
    the per-letter results (latin and arabic) as a JSON string.

    Args:
        dialect: Selected dialect directory name.
        theme: Selected theme directory name.
        sentence_id: Sentence identifier; files are named
            <id>_word.wav / <id>_labels.tsv.
        input_audio: Path to the learner's recording.

    Returns:
        JSON string with the sentence id and per-letter pass/fail payloads.
    """
    reference_audio = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_word.wav")
    label_path = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_labels.tsv")
    labels_data = parse_tsv(label_path)
    results = check_pronunciation_demo(reference_audio, input_audio, labels_data)
    latin_output = []
    arabic_output = []
    # "result" is True when the letter was pronounced correctly (not red).
    for index, (is_red, (start, end, latin, arabic)) in enumerate(zip(results, labels_data)):
        latin_output.append({
            "index": index,
            "letter": latin,
            "result": not is_red
        })
        arabic_output.append({
            "index": index,
            "letter": arabic,
            "result": not is_red
        })
    # Debug prints of labels_data/results removed.
    result = {
        "highlighted_text_id": sentence_id,
        "highlighted_text_latin_payload": latin_output,
        "highlighted_text_arabic_payload": arabic_output
    }
    # ensure_ascii=False keeps the Arabic letters human-readable in the output.
    return json.dumps(result, indent=2, ensure_ascii=False)
# Root of the demo dataset: DATA_DIR/<dialect>/<theme>/<id>_word.wav
DATA_DIR = "data"
# Populate the demo dropdown choices once at import time.
dialects, themes, sentence_ids = collect_demo_data()
# Top-level Gradio UI: three tabs — JSON demo, manual checker, benchmark runner.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Demo"):
            # Selection widgets populated from the DATA_DIR scan above.
            dialect_dropdown = gr.Dropdown(choices=dialects, label="Select Dialect")
            theme_dropdown = gr.Dropdown(choices=themes, label="Select Theme")
            word_dropdown = gr.Dropdown(choices=sentence_ids, label="Select Sentence ID")
            # NOTE(review): this is the learner's input recording despite the
            # "Reference Audio" label — confirm the label is intentional.
            input_audio = gr.Audio(type="filepath", label="Reference Audio", format="wav", show_download_button=True)
            # JSON output
            output_json = gr.JSON(label="Output JSON")
            # Button to trigger JSON output
            submit_btn = gr.Button("Get JSON Output")
            # JSON output function triggered by button
            submit_btn.click(run_demo, inputs=[dialect_dropdown, theme_dropdown, word_dropdown, input_audio], outputs=[output_json])
        with gr.Tab("Pronunciation Checker"):
            # Classic Interface wrapper around check_pronunciation with
            # manually tunable threshold and WavLM layer.
            gr.Interface(
                fn=check_pronunciation,
                inputs=[
                    gr.Audio(
                        type="filepath",
                        label="Reference Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Audio(
                        type="filepath",
                        label="Input Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold"),
                    gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer"),
                ],
                outputs=[
                    gr.Image(type="filepath", label="Pronunciation Comparison"),
                    gr.Number(label="Quality Score"),
                    gr.Textbox(label="Feedback"),
                    gr.Textbox(label="Runtime Analysis")
                ],
                live=False,
                title="Pronunciation Checker",
                description="Pronunciation Checker"
            )
        with gr.Tab("Run Tests"):
            with gr.Column():
                gr.Markdown("## Run Automated Tests")
                with gr.Row():
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold")
                    wavlm_layer_slider = gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer")
                run_tests_button = gr.Button("Run Tests", variant="primary", interactive=True)
                # Create static containers for results: Gradio components must
                # exist before wiring, so result slots are pre-created and
                # toggled visible per run.
                with gr.Column(visible=False) as results_container:
                    # Pre-create components for up to 10 test cases.
                    # NOTE(review): extra test cases beyond 10 are silently
                    # dropped by the update loop below — confirm benchmark size.
                    test_case_components = []
                    for i in range(10):  # Adjust number based on your needs
                        with gr.Column(visible=False) as case_container:
                            heading = gr.Markdown()
                            details = gr.Markdown()
                            image = gr.Image(width=800)
                            separator = gr.Markdown("---")
                            test_case_components.append({
                                "container": case_container,
                                "heading": heading,
                                "details": details,
                                "image": image,
                                "separator": separator
                            })

                def format_test_results(threshold, wavlm_layer):
                    """Run the benchmark and map results onto the pre-created
                    component slots as a component->update dict."""
                    test_cases = run_tests(threshold, wavlm_layer)
                    updates = {results_container: gr.update(visible=True)}
                    # Update pre-created components with new results; slots
                    # beyond the result count (or cases without a plot, i.e.
                    # error cases) are hidden.
                    for i, components in enumerate(test_case_components):
                        if i < len(test_cases) and test_cases[i]['plot']:
                            case = test_cases[i]
                            updates[components["container"]] = gr.update(visible=True)
                            updates[components["heading"]] = f"### {case['heading']}"
                            updates[components["details"]] = case['text']
                            updates[components["image"]] = case['plot']
                        else:
                            updates[components["container"]] = gr.update(visible=False)
                    return updates

                # Update results when button is clicked; outputs list every
                # component that format_test_results may address in its dict.
                run_tests_button.click(
                    fn=format_test_results,
                    inputs=[threshold_slider, wavlm_layer_slider],
                    outputs=[results_container] + [comp for components in test_case_components
                                                   for comp in [components["container"],
                                                                components["heading"],
                                                                components["details"],
                                                                components["image"]]]
                )
if __name__ == "__main__":
    # Load the WavLM-based checker only when run as a script.
    # NOTE(review): the callbacks above read this module-level global; if this
    # file is imported rather than executed, pronunciation_checker is never
    # defined and the callbacks will raise NameError — confirm intended.
    pronunciation_checker = PronunciationChecker("microsoft/wavlm-large")
    demo.launch()