# 0ahkd1
# refactored red-green decision code
# d4d0c2c
# raw
# history blame
# 11.1 kB
# -*- coding: utf-8 -*-
# SPDX-FileContributor: Karl El Hajal
# SPDX-FileContributor: Ali Dulaimi
import spaces
import gradio as gr
import tempfile
import matplotlib.pyplot as plt
import os
from src.pronunciation_checker import PronunciationChecker
from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio
from datetime import datetime
from pathlib import Path
import re
@spaces.GPU
def check_pronunciation(reference_audio, input_audio, threshold, wavlm_layer, labels_data=None, input_number=None):
    """Compare an input recording against a reference and score pronunciation.

    Args:
        reference_audio: path to the reference recording (wav).
        input_audio: path to the learner's recording (wav).
        threshold: decision threshold forwarded to the overlay plot.
        wavlm_layer: WavLM hidden layer to extract features from (coerced to int).
        labels_data: optional label tuples produced by parse_labels_file.
        input_number: optional 0-based index selecting which label column to show.

    Returns:
        (plot_path, quality_score, needs_repeat, timing_analysis) where
        plot_path is a temporary .png file and timing_analysis is a
        human-readable per-step runtime report.

    Raises:
        ValueError: if preprocessing yields an empty waveform or DTW fails.
    """
    wavlm_layer = int(wavlm_layer)
    timing_results = []

    # can be moved to src/utils
    def log_timing(step):
        timing_results.append((step, datetime.now()))

    log_timing("Start")

    # ref_wav = denoise_audio(ref_wav)
    input_audio = denoise_audio(input_audio)
    log_timing("Input Audio Denoising")

    # Extract features from both audio files
    ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
    log_timing("Reference Audio Preprocessing")
    comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
    log_timing("Input Audio Preprocessing")

    # Check if waveforms are not empty
    if ref_wav is None or comparison_wav is None:
        raise ValueError("One or both of the waveforms are empty.")

    # Extract features
    ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer)
    log_timing("Reference Feature Extraction")
    input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer)
    log_timing("Input Feature Extraction")

    # Compute DTW
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    log_timing("DTW Computation")

    # BUG FIX: validate the DTW result *before* using it. The original code
    # called assess_pronunciation_quality first, so a failed DTW would crash
    # inside the assessor instead of raising this explicit error.
    if path is None or dist_matrix is None:
        raise ValueError("DTW computation failed.")

    quality_score, needs_repeat = assess_pronunciation_quality(dist_matrix, path)
    log_timing("Quality Assessment")

    PronunciationChecker.plot_waveform_with_overlay(ref_wav, sr, dist_matrix, path, "ref", threshold, labels_data, input_number)
    log_timing("Visualization Generation")

    # Save the visualization to a temporary image file; delete=False so the
    # file survives for Gradio to read after this function returns.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        tmp_path = tmp.name
        plt.savefig(tmp_path)
    plt.close()

    # Report the wall-clock duration of each consecutive pipeline step.
    timing_analysis = "Runtime Analysis:\n"
    for prev, curr in zip(timing_results, timing_results[1:]):
        step_duration = (curr[1] - prev[1]).total_seconds()
        timing_analysis += f"{curr[0]}: {step_duration:.2f}s\n"

    # Return the image file path for Gradio to display
    return tmp_path, quality_score, needs_repeat, timing_analysis
def parse_labels_file(labels_file):
    """Parse a whitespace-delimited grapheme labels file.

    Each non-empty line has the form:
        <start> <end> <grapheme> [<int label> ...]
    Lines with fewer than three fields (including blank lines) are skipped.

    Args:
        labels_file: path to the labels file.

    Returns:
        List of tuples (start: float, end: float, grapheme: str, *labels: int).
    """
    labels_data = []
    with open(labels_file, 'r') as f:
        for line in f:
            # BUG FIX: str.split() collapses any run of whitespace and drops
            # the trailing newline. The original re.split(r'[\t ]+', ...) left
            # '\n' attached to the last field (corrupting the grapheme or a
            # label) and produced an empty first field on indented lines.
            parts = line.split()
            if len(parts) < 3:
                continue
            start, end = map(float, parts[:2])
            grapheme = parts[2]
            boolean_labels = list(map(int, parts[3:]))
            labels_data.append((start, end, grapheme, *boolean_labels))
    return labels_data
@spaces.GPU
def run_tests(threshold=0.4, wavlm_layer=24):
    """Run the pronunciation checker against every word in benchmark/.

    Expects benchmark/<word>/ directories each containing ref.wav,
    labels.txt and one or more input_<n>.wav recordings.

    Args:
        threshold: decision threshold forwarded to check_pronunciation.
        wavlm_layer: WavLM layer forwarded to check_pronunciation.

    Returns:
        List of dicts with keys "heading", "plot" (path or None on error)
        and "text" (score summary or error message).
    """
    # Create a directory for storing plots if it doesn't exist
    plots_dir = Path("temp_plots")
    plots_dir.mkdir(exist_ok=True)
    benchmark_dir = Path("benchmark")
    test_cases = []

    # Clear any old plot files
    for old_plot in plots_dir.glob("*.png"):
        old_plot.unlink()

    for test_word_dir in benchmark_dir.iterdir():
        if not test_word_dir.is_dir():
            continue
        word = test_word_dir.name
        ref_file = test_word_dir / "ref.wav"
        labels_file = test_word_dir / "labels.txt"
        if not ref_file.exists() or not labels_file.exists():
            test_cases.append({
                "heading": f"Error: Missing files for '{word}'",
                "plot": None,
                "text": f"Missing required files (ref.wav or labels.txt) for test word '{word}'."
            })
            continue
        labels_data = parse_labels_file(labels_file)

        # BUG FIX: glob() only ever yields existing files, so the original
        # per-file exists() check inside the loop was dead code and the
        # "no input files" error could never be reported. Check emptiness
        # of the glob result instead. sorted() makes the run order stable.
        input_files = sorted(test_word_dir.glob("input_*.wav"))
        if not input_files:
            test_cases.append({
                "heading": f"Error: Missing input files for '{word}'",
                "plot": None,
                "text": f"No input files found for test word '{word}'."
            })
            continue

        for input_file in input_files:
            try:
                # Generate a unique filename for this test's plot
                plot_filename = f"{word}_{input_file.stem}_{datetime.now().strftime('%H%M%S')}.png"
                plot_path = plots_dir / plot_filename
                # input_3.wav -> label column index 2 (files are 1-based).
                input_number = str(input_file.stem).split('_')[1]
                input_number = int(input_number) - 1
                # Run the pronunciation checker
                tmp_path, quality_score, needs_repeat, timing_analysis = check_pronunciation(
                    ref_file, input_file, threshold, wavlm_layer, labels_data, input_number
                )
                # Move the plot to our persistent directory
                if os.path.exists(tmp_path):
                    os.rename(tmp_path, plot_path)
                test_cases.append({
                    "heading": f"Word: {word}, Input: {input_file.name}",
                    "plot": str(plot_path),
                    "text": f"Quality Score: {quality_score:.2f}\n"
                            f"Needs Repeat: {needs_repeat}"
                })
            except Exception as e:
                # Record the failure for this input but keep testing the rest.
                test_cases.append({
                    "heading": f"Error for '{word}' with '{input_file.name}'",
                    "plot": None,
                    "text": str(e)
                })
    return test_cases
def display_test_results(threshold, wavlm_layer):
    """Render each benchmark result as a flat list of Gradio components.

    Each case contributes a Markdown heading, an optional Image (when a
    plot was produced) and a read-only Textbox with the details.
    """
    rendered = []
    for result in run_tests(threshold, wavlm_layer):
        rendered.append(gr.Markdown(f"### {result['heading']}"))
        if result["plot"]:
            rendered.append(gr.Image(result["plot"], type="filepath", label="Plot"))
        rendered.append(gr.Textbox(result["text"], label="Details", interactive=False, lines=8))
    return rendered
# Top-level Gradio UI: one tab for interactive checking, one for batch tests.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Pronunciation Checker"):
            # Interactive single-comparison UI; labels_data/input_number keep
            # their defaults (None) since gr.Interface only wires 4 inputs.
            gr.Interface(
                fn=check_pronunciation,
                inputs=[
                    gr.Audio(
                        type="filepath",
                        label="Reference Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Audio(
                        type="filepath",
                        label="Input Audio",
                        format="wav",
                        show_download_button=True,
                    ),
                    gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold"),
                    gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer"),
                ],
                outputs=[
                    gr.Image(type="filepath", label="Pronunciation Comparison"),
                    gr.Number(label="Quality Score"),
                    gr.Textbox(label="Feedback"),
                    gr.Textbox(label="Runtime Analysis")
                ],
                live=False,
                title="Pronunciation Checker",
                description="Pronunciation Checker"
            )
        with gr.Tab("Run Tests"):
            with gr.Column():
                gr.Markdown("## Run Automated Tests")
                with gr.Row():
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Decision Threshold")
                    wavlm_layer_slider = gr.Slider(minimum=0, maximum=24, value=24, step=1, label="WavLM Layer")
                run_tests_button = gr.Button("Run Tests", variant="primary", interactive=True)
                # Create static containers for results
                # Components must exist before the click handler is wired, so a
                # fixed pool of result slots is pre-created and toggled visible.
                with gr.Column(visible=False) as results_container:
                    # Pre-create components for up to 10 test cases
                    test_case_components = []
                    for i in range(10):  # Adjust number based on your needs
                        with gr.Column(visible=False) as case_container:
                            heading = gr.Markdown()
                            details = gr.Markdown()
                            image = gr.Image(width=800)
                            separator = gr.Markdown("---")
                        test_case_components.append({
                            "container": case_container,
                            "heading": heading,
                            "details": details,
                            "image": image,
                            "separator": separator
                        })

                def format_test_results(threshold, wavlm_layer):
                    """Run the benchmark and map results onto the pre-created slots.

                    Returns a dict of component -> update so only changed
                    components are touched.
                    """
                    test_cases = run_tests(threshold, wavlm_layer)
                    updates = {results_container: gr.update(visible=True)}
                    # Update pre-created components with new results
                    for i, components in enumerate(test_case_components):
                        # NOTE(review): cases without a plot (error entries from
                        # run_tests) are hidden entirely here, so their error text
                        # is never shown — confirm this is intentional.
                        if i < len(test_cases) and test_cases[i]['plot']:
                            case = test_cases[i]
                            updates[components["container"]] = gr.update(visible=True)
                            updates[components["heading"]] = f"### {case['heading']}"
                            updates[components["details"]] = case['text']
                            updates[components["image"]] = case['plot']
                        else:
                            updates[components["container"]] = gr.update(visible=False)
                    return updates

                # Update results when button is clicked
                run_tests_button.click(
                    fn=format_test_results,
                    inputs=[threshold_slider, wavlm_layer_slider],
                    # Flattened list of every component the handler may update;
                    # the "separator" markdown is static and deliberately omitted.
                    outputs=[results_container] + [comp for components in test_case_components
                                                   for comp in [components["container"],
                                                                components["heading"],
                                                                components["details"],
                                                                components["image"]]]
                )
if __name__ == "__main__":
    # Module-level singleton read by check_pronunciation(); constructing it
    # here means the handlers only work when this file is run as a script
    # (which is how a Gradio/Spaces app is launched).
    pronunciation_checker = PronunciationChecker("microsoft/wavlm-large")
    demo.launch()