karlhajal commited on
Commit
4578f82
·
verified ·
1 Parent(s): 413e8eb

Add demo code

Browse files
Files changed (1) hide show
  1. app.py +148 -1
app.py CHANGED
@@ -8,10 +8,12 @@ import tempfile
8
  import matplotlib.pyplot as plt
9
  import os
10
  from src.pronunciation_checker import PronunciationChecker
11
- from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio
12
  from datetime import datetime
13
  from pathlib import Path
14
  import re
 
 
15
 
16
 
17
  @spaces.GPU
@@ -172,8 +174,153 @@ def display_test_results(threshold, wavlm_layer):
172
  components.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
173
  return components
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with gr.Blocks() as demo:
176
  with gr.Tabs():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  with gr.Tab("Pronunciation Checker"):
178
  gr.Interface(
179
  fn=check_pronunciation,
 
8
  import matplotlib.pyplot as plt
9
  import os
10
  from src.pronunciation_checker import PronunciationChecker
11
+ from src.audio_preprocessing import assess_pronunciation_quality, denoise_audio, get_red_green_segments
12
  from datetime import datetime
13
  from pathlib import Path
14
  import re
15
+ import json
16
+ import csv
17
 
18
 
19
  @spaces.GPU
 
174
  components.append(gr.Textbox(case["text"], label="Details", interactive=False, lines=8))
175
  return components
176
 
177
+ def calculate_red_percentage_demo(red_segments, labels_data, scaling_factor=0.02):
178
+ red_percentages = []
179
+
180
+ def intersection_length(start1, end1, start2, end2):
181
+ overlap_start = max(start1, start2)
182
+ overlap_end = min(end1, end2)
183
+
184
+ return max(0, overlap_end - overlap_start)
185
+
186
+ for start, end, latin, arabic in labels_data:
187
+ red_intersection = 0.0
188
+ for index in red_segments:
189
+ red_start_time = index * scaling_factor
190
+ red_end_time = (index + 1) * scaling_factor
191
+
192
+ red_intersection += intersection_length(start, end, red_start_time, red_end_time)
193
+
194
+ total_grapheme_duration = end - start
195
+ red_percentage = (red_intersection / total_grapheme_duration)
196
+ red_percentages.append(min(red_percentage, 1.))
197
+
198
+ return red_percentages
199
+
200
+
201
+ @spaces.GPU
202
+ def check_pronunciation_demo(reference_audio, input_audio, labels_data, threshold=0.4, wavlm_layer=24):
203
+ wavlm_layer = int(wavlm_layer)
204
+
205
+ # ref_wav = denoise_audio(ref_wav)
206
+ input_audio = denoise_audio(input_audio)
207
+
208
+ ref_wav, sr = pronunciation_checker.preprocess_wav(reference_audio)
209
+ comparison_wav, _ = pronunciation_checker.preprocess_wav(input_audio)
210
+
211
+ if ref_wav is None or comparison_wav is None:
212
+ raise ValueError("One or both of the waveforms are empty.")
213
+
214
+ ref_features, ref_wav, sr = pronunciation_checker.extract_features(ref_wav, wavlm_layer)
215
+ input_features, comparison_wav, _ = pronunciation_checker.extract_features(comparison_wav, wavlm_layer)
216
+
217
+ dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
218
+
219
+ red_segments, _, _ = get_red_green_segments(dist_matrix, path, wav_type="ref", threshold=threshold)
220
+
221
+ red_percentages = calculate_red_percentage_demo(red_segments, labels_data)
222
+
223
+ is_red = [percentage > 0.0 for percentage in red_percentages]
224
+
225
+ return is_red
226
+
227
+ def parse_tsv(file_path):
228
+ transcriptions = []
229
+
230
+ num_to_subtract = float('inf')
231
+
232
+ if os.path.exists(file_path):
233
+ with open(file_path, "r", encoding="utf-8") as f:
234
+ reader = csv.reader(f, delimiter="\t")
235
+ for row in reader:
236
+ if len(row) == 4: # Ensure it has the expected 4 columns
237
+ start, end, latin, arabic = row
238
+ start = float(start)/1000.
239
+ end = float(end)/1000.
240
+
241
+ num_to_subtract = min(num_to_subtract, start)
242
+ start -= num_to_subtract
243
+ end -= num_to_subtract
244
+
245
+ transcriptions.append((start, end, latin, arabic))
246
+ return transcriptions
247
+
248
+ def collect_demo_data():
249
+ dialects = []
250
+ themes = set()
251
+ sentence_ids = set()
252
+
253
+ for dialect in os.listdir(DATA_DIR):
254
+ dialect_path = os.path.join(DATA_DIR, dialect)
255
+ if os.path.isdir(dialect_path):
256
+ dialects.append(dialect) # Collect dialects
257
+ for theme in os.listdir(dialect_path):
258
+ theme_path = os.path.join(dialect_path, theme)
259
+ if os.path.isdir(theme_path):
260
+ themes.add(theme) # Collect themes
261
+ for file in os.listdir(theme_path):
262
+ if file.endswith("_word.wav"):
263
+ sentence_id = file.split("_")[0]
264
+ sentence_ids.add(sentence_id) # Collect word IDs
265
+
266
+ return sorted(dialects), sorted(themes), sorted(sentence_ids)
267
+
268
+
269
+ @spaces.GPU
270
+ def run_demo(dialect, theme, sentence_id, input_audio):
271
+ reference_audio = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_word.wav")
272
+ label_path = os.path.join(DATA_DIR, dialect, theme, f"{sentence_id}_labels.tsv")
273
+
274
+ labels_data = parse_tsv(label_path)
275
+
276
+ results = check_pronunciation_demo(reference_audio, input_audio, labels_data)
277
+
278
+ latin_output = []
279
+ arabic_output = []
280
+ for index, (is_red, (start, end, latin, arabic)) in enumerate(zip(results, labels_data)):
281
+ latin_output.append({
282
+ "index" : index,
283
+ "letter": latin,
284
+ "result": not is_red
285
+ })
286
+ arabic_output.append({
287
+ "index" : index,
288
+ "letter": arabic,
289
+ "result": not is_red
290
+ })
291
+
292
+ print(labels_data)
293
+ print(results)
294
+
295
+ result = {
296
+ "highlighted_text_id": sentence_id,
297
+ "highlighted_text_latin_payload": latin_output,
298
+ "highlighted_text_arabic_payload": arabic_output
299
+ }
300
+
301
+ return json.dumps(result, indent=2, ensure_ascii=False)
302
+
303
+ DATA_DIR = "data"
304
+ dialects, themes, sentence_ids = collect_demo_data()
305
+
306
  with gr.Blocks() as demo:
307
  with gr.Tabs():
308
+ with gr.Tab("Demo"):
309
+ dialect_dropdown = gr.Dropdown(choices=dialects, label="Select Dialect")
310
+ theme_dropdown = gr.Dropdown(choices=themes, label="Select Theme")
311
+ word_dropdown = gr.Dropdown(choices=sentence_ids, label="Select Sentence ID")
312
+
313
+ input_audio = gr.Audio(type="filepath", label="Reference Audio", format="wav", show_download_button=True)
314
+
315
+ # JSON output
316
+ output_json = gr.JSON(label="Output JSON")
317
+
318
+ # Button to trigger JSON output
319
+ submit_btn = gr.Button("Get JSON Output")
320
+
321
+ # JSON output function triggered by button
322
+ submit_btn.click(run_demo, inputs=[dialect_dropdown, theme_dropdown, word_dropdown, input_audio], outputs=[output_json])
323
+
324
  with gr.Tab("Pronunciation Checker"):
325
  gr.Interface(
326
  fn=check_pronunciation,