Spaces:
Sleeping
Sleeping
Add demo code
Browse files
app.py
CHANGED
|
@@ -8,10 +8,12 @@ import tempfile
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
import os
|
| 10 |
from src.pronunciation_checker import PronunciationChecker
|
| 11 |
-
from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio
|
| 12 |
from datetime import datetime
|
| 13 |
from pathlib import Path
|
| 14 |
import re
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@spaces.GPU
|
|
@@ -172,8 +174,153 @@ def display_test_results(threshold, wavlm_layer):
|
|
| 172 |
components.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
|
| 173 |
return components
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
with gr.Blocks() as demo:
|
| 176 |
with gr.Tabs():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
with gr.Tab("Pronunciation Checker"):
|
| 178 |
gr.Interface(
|
| 179 |
fn=check_pronunciation,
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
import os
|
| 10 |
from src.pronunciation_checker import PronunciationChecker
|
| 11 |
+
from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio, get_red_green_segments
|
| 12 |
from datetime import datetime
|
| 13 |
from pathlib import Path
|
| 14 |
import re
|
| 15 |
+
import json
|
| 16 |
+
import csv
|
| 17 |
|
| 18 |
|
| 19 |
@spaces.GPU
|
|
|
|
| 174 |
components.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
|
| 175 |
return components
|
| 176 |
|
| 177 |
+
def calculate_red_percentage_demo(red_segments, labels_data, scaling_factor=0.02):
    """Return, per label, the fraction of its duration covered by red frames.

    Args:
        red_segments: Iterable of frame indices flagged as red. Frame ``i``
            is treated as the time interval
            ``[i * scaling_factor, (i + 1) * scaling_factor]``.
        labels_data: Iterable of ``(start, end, latin, arabic)`` tuples.
            ``start`` and ``end`` must use the same time unit as
            ``scaling_factor`` (presumably seconds -- confirm against the
            feature extractor's frame rate).
        scaling_factor: Duration of a single frame.

    Returns:
        A list with one float per label, capped at ``1.0``. Labels whose
        duration is zero or negative yield ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    # Materialise the frame intervals once: this avoids recomputing them for
    # every label and keeps the function correct when ``red_segments`` is a
    # one-shot iterator.
    red_intervals = [
        (index * scaling_factor, (index + 1) * scaling_factor)
        for index in red_segments
    ]

    red_percentages = []
    for start, end, _latin, _arabic in labels_data:
        total_grapheme_duration = end - start
        if total_grapheme_duration <= 0:
            # An empty (or inverted) label cannot overlap anything.
            red_percentages.append(0.0)
            continue

        red_intersection = sum(
            (
                max(0, min(end, red_end) - max(start, red_start))
                for red_start, red_end in red_intervals
            ),
            0.0,
        )
        # Cap at 1.0 to absorb floating-point overshoot.
        red_percentages.append(min(red_intersection / total_grapheme_duration, 1.0))

    return red_percentages
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@spaces.GPU
def check_pronunciation_demo(reference_audio, input_audio, labels_data, threshold=0.4, wavlm_layer=24):
    """Return one boolean per label telling whether any red frame overlaps it.

    The input recording is passed through ``denoise_audio``; both recordings
    are then preprocessed, turned into features at ``wavlm_layer`` and aligned
    with ``PronunciationChecker.compute_dtw``. Red frames are taken from
    ``get_red_green_segments`` for the reference side (``wav_type="ref"``)
    using ``threshold``.

    Args:
        reference_audio: Reference recording, forwarded to
            ``pronunciation_checker.preprocess_wav``.
        input_audio: Recording to evaluate, forwarded to ``denoise_audio``.
        labels_data: ``(start, end, latin, arabic)`` tuples describing the
            reference recording.
        threshold: Forwarded to ``get_red_green_segments``.
        wavlm_layer: Feature layer; coerced to ``int`` before use.

    Returns:
        A list of booleans aligned with ``labels_data``; ``True`` means the
        label has a red share greater than zero.

    Raises:
        ValueError: If either preprocessed waveform is ``None``.
    """
    layer = int(wavlm_layer)

    # Only the input recording is denoised; the reference is used as-is.
    cleaned_input = denoise_audio(input_audio)

    reference_wave, _ = pronunciation_checker.preprocess_wav(reference_audio)
    input_wave, _ = pronunciation_checker.preprocess_wav(cleaned_input)

    if reference_wave is None or input_wave is None:
        raise ValueError("One or both of the waveforms are empty.")

    reference_feats, _, _ = pronunciation_checker.extract_features(reference_wave, layer)
    input_feats, _, _ = pronunciation_checker.extract_features(input_wave, layer)

    dist_matrix, path = PronunciationChecker.compute_dtw(reference_feats, input_feats)
    red_segments, _, _ = get_red_green_segments(
        dist_matrix, path, wav_type="ref", threshold=threshold
    )

    red_shares = calculate_red_percentage_demo(red_segments, labels_data)
    return [share > 0.0 for share in red_shares]
|
| 226 |
+
|
| 227 |
+
def parse_tsv(file_path):
    """Read a four-column label TSV and return time-shifted label tuples.

    Each usable row has the form ``start<TAB>end<TAB>latin<TAB>arabic``.
    ``start`` and ``end`` are divided by 1000 (presumably milliseconds to
    seconds -- confirm against the label files) and then shifted so that the
    earliest start in the file becomes ``0.0``. The shift uses the minimum
    over *all* rows, so every row receives the same offset even when the
    file is not sorted by start time.

    Args:
        file_path: Path of the TSV file.

    Returns:
        A list of ``(start, end, latin, arabic)`` tuples in file order, or an
        empty list when the file does not exist or has no four-column rows.

    Raises:
        ValueError: If a four-column row has a non-numeric start or end.
    """
    if not os.path.exists(file_path):
        return []

    rows = []
    # newline="" lets the csv module do its own line-ending handling.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, delimiter="\t"):
            if len(row) != 4:  # Ignore rows without the expected 4 columns
                continue
            start, end, latin, arabic = row
            rows.append((float(start) / 1000.0, float(end) / 1000.0, latin, arabic))

    if not rows:
        return []

    offset = min(start for start, _end, _latin, _arabic in rows)
    return [
        (start - offset, end - offset, latin, arabic)
        for start, end, latin, arabic in rows
    ]
|
| 247 |
+
|
| 248 |
+
def collect_demo_data():
    """Scan ``DATA_DIR`` and list the available dialects, themes and sentence IDs.

    The directory is walked as
    ``DATA_DIR/<dialect>/<theme>/<sentence_id>_word.wav``. Entries that are
    not directories are ignored at the dialect and theme levels, and only
    files ending in ``_word.wav`` contribute a sentence ID.

    Returns:
        A ``(dialects, themes, sentence_ids)`` tuple of sorted lists. Themes
        and sentence IDs are pooled across all dialects, so not every
        combination is guaranteed to exist on disk.
    """
    word_suffix = "_word.wav"

    dialects = []
    themes = set()
    sentence_ids = set()

    for dialect in os.listdir(DATA_DIR):
        dialect_path = os.path.join(DATA_DIR, dialect)
        if not os.path.isdir(dialect_path):
            continue
        dialects.append(dialect)

        for theme in os.listdir(dialect_path):
            theme_path = os.path.join(dialect_path, theme)
            if not os.path.isdir(theme_path):
                continue
            themes.add(theme)

            for filename in os.listdir(theme_path):
                if filename.endswith(word_suffix):
                    # Strip the suffix instead of splitting on "_" so that an
                    # ID containing underscores still maps back to the
                    # "<sentence_id>_word.wav" file it came from.
                    sentence_ids.add(filename[: -len(word_suffix)])

    return sorted(dialects), sorted(themes), sorted(sentence_ids)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@spaces.GPU
def run_demo(dialect, theme, sentence_id, input_audio):
    """Score ``input_audio`` against a bundled sample and return the result as JSON text.

    The reference recording and its labels are looked up under
    ``DATA_DIR/<dialect>/<theme>/`` as ``<sentence_id>_word.wav`` and
    ``<sentence_id>_labels.tsv``.

    Args:
        dialect: Dialect folder name.
        theme: Theme folder name.
        sentence_id: File-name prefix of the sample.
        input_audio: Recording to evaluate, forwarded to
            ``check_pronunciation_demo``.

    Returns:
        A JSON string (indented, non-ASCII preserved) with the keys
        ``highlighted_text_id``, ``highlighted_text_latin_payload`` and
        ``highlighted_text_arabic_payload``. Each payload entry holds the
        label's ``index``, its ``letter`` and a ``result`` flag that is
        ``True`` when the label was *not* marked red.
    """
    sample_dir = os.path.join(DATA_DIR, dialect, theme)
    reference_audio = os.path.join(sample_dir, f"{sentence_id}_word.wav")
    label_path = os.path.join(sample_dir, f"{sentence_id}_labels.tsv")

    labels_data = parse_tsv(label_path)
    results = check_pronunciation_demo(reference_audio, input_audio, labels_data)

    def build_payload(column):
        # ``column`` selects the script inside a (start, end, latin, arabic) label.
        return [
            {
                "index": position,
                "letter": label[column],
                "result": not flagged,
            }
            for position, (flagged, label) in enumerate(zip(results, labels_data))
        ]

    latin_output = build_payload(2)
    arabic_output = build_payload(3)

    # Debug output kept from the original implementation.
    print(labels_data)
    print(results)

    return json.dumps(
        {
            "highlighted_text_id": sentence_id,
            "highlighted_text_latin_payload": latin_output,
            "highlighted_text_arabic_payload": arabic_output,
        },
        indent=2,
        ensure_ascii=False,
    )
|
| 302 |
+
|
| 303 |
+
# Root folder of the bundled demo samples. It is a relative path, and the
# functions in this file join it with dialect / theme / file names.
DATA_DIR = "data"
|
| 304 |
+
# Scan DATA_DIR once at import time; the three lists feed the dropdown
# choices of the "Demo" tab.
dialects, themes, sentence_ids = collect_demo_data()
|
| 305 |
+
|
| 306 |
with gr.Blocks() as demo:
|
| 307 |
with gr.Tabs():
|
| 308 |
+
with gr.Tab("Demo"):
|
| 309 |
+
dialect_dropdown = gr.Dropdown(choices=dialects, label="Select Dialect")
|
| 310 |
+
theme_dropdown = gr.Dropdown(choices=themes, label="Select Theme")
|
| 311 |
+
word_dropdown = gr.Dropdown(choices=sentence_ids, label="Select Sentence ID")
|
| 312 |
+
|
| 313 |
+
input_audio = gr.Audio(type="filepath", label="Reference Audio", format="wav", show_download_button=True)
|
| 314 |
+
|
| 315 |
+
# JSON output
|
| 316 |
+
output_json = gr.JSON(label="Output JSON")
|
| 317 |
+
|
| 318 |
+
# Button to trigger JSON output
|
| 319 |
+
submit_btn = gr.Button("Get JSON Output")
|
| 320 |
+
|
| 321 |
+
# JSON output function triggered by button
|
| 322 |
+
submit_btn.click(run_demo, inputs=[dialect_dropdown, theme_dropdown, word_dropdown, input_audio], outputs=[output_json])
|
| 323 |
+
|
| 324 |
with gr.Tab("Pronunciation Checker"):
|
| 325 |
gr.Interface(
|
| 326 |
fn=check_pronunciation,
|