Bisher commited on
Commit
31c4539
·
verified ·
1 Parent(s): 1590f7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -80
app.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import time
6
  import warnings
7
  import pyarabic.araby as araby
8
-
9
 
10
  # Suppress specific UserWarnings from jiwer related to empty strings
11
  warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
@@ -15,41 +15,59 @@ warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=User
15
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
16
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
17
  SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"
 
18
  # Define Arabic diacritics
19
- if araby:
 
20
  ARABIC_DIACRITICS = {
21
  araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
22
  araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
23
  }
24
- else:
 
25
  ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
26
 
27
  # --- API Clients ---
 
 
 
 
 
28
  def get_diacritization_client():
29
- try:
30
- return Client(DIACRITIZATION_API_URL, download_files=True)
31
- except Exception as e:
32
- print(f"Error initializing diacritization client: {e}")
33
- return None
 
 
 
34
 
35
  def get_transcription_client():
36
- try:
37
- return Client(TRANSCRIPTION_API_URL, download_files=True)
38
- except Exception as e:
39
- print(f"Error initializing transcription client: {e}")
40
- return None
 
 
 
41
 
42
  def get_syllable_transcription_client():
43
- try:
44
- return Client(SYLLABLE_TRANSCRIPTION_API_URL, download_files=True)
45
- except Exception as e:
46
- print(f"Error initializing transcription client: {e}")
47
- return None
 
 
 
48
 
49
  # --- Helper Functions ---
50
  def diacritize_text_api(text_to_diacritize):
 
51
  if not text_to_diacritize or not text_to_diacritize.strip():
52
- return "Please enter some text to diacritize.", ""
53
  client = get_diacritization_client()
54
  if not client:
55
  return "Error: Could not connect to the diacritization service.", ""
@@ -59,145 +77,275 @@ def diacritize_text_api(text_to_diacritize):
59
  input_text=text_to_diacritize,
60
  api_name="/predict"
61
  )
 
62
  result_str = str(result) if result is not None else "Error: Empty response from diacritization service."
 
63
  return result_str, result_str
64
  except Exception as e:
 
65
  return f"Error during diacritization: {e}", ""
66
 
67
  def transcribe_audio_api(audio_filepath):
 
68
  if not audio_filepath:
69
  return "Error: Please provide an audio recording or file."
70
  if not os.path.exists(audio_filepath):
71
  return f"Error: Audio file not found at {audio_filepath}"
 
72
  client = get_transcription_client()
73
  if not client:
74
  return "Error: Could not connect to the transcription service."
75
  try:
 
 
76
  result = client.predict(
77
  audio=handle_file(audio_filepath),
78
  api_name="/predict"
79
  )
 
80
  if isinstance(result, dict) and 'text' in result:
81
  transcript = result['text']
 
 
82
  else:
83
- transcript = str(result)
84
- return transcript
85
  except Exception as e:
 
86
  return f"Error during transcription: {e}"
87
 
88
  def transcribe_syllable_audio_api(audio_filepath):
 
89
  if not audio_filepath:
90
- return "Error: Please provide an audio recording or file."
 
91
  if not os.path.exists(audio_filepath):
92
- return f"Error: Audio file not found at {audio_filepath}"
 
93
  client = get_syllable_transcription_client()
94
  if not client:
95
- return "Error: Could not connect to the transcription service."
96
  try:
 
 
97
  result = client.predict(
98
  audio=handle_file(audio_filepath),
99
  api_name="/predict"
100
  )
 
101
  if isinstance(result, dict) and 'text' in result:
102
  transcript = result['text']
 
 
103
  else:
104
- transcript = str(result)
105
- return transcript
106
  except Exception as e:
107
- return f"Error during transcription: {e}"
 
108
 
109
  def get_diacritics_sequence(text):
 
110
  if not isinstance(text, str):
111
  return ""
112
  diacritics_only = [c for c in text if c in ARABIC_DIACRITICS]
113
  return ' '.join(diacritics_only)
114
 
115
  def calculate_metrics(reference, hypothesis):
 
116
  ref = reference or ""
117
  hyp = hypothesis or ""
118
- # WER
119
- wer = jiwer.wer(ref, hyp) if ref.strip() else (1.0 if hyp.strip() else 0.0)
120
- # DER
121
- ref_d = get_diacritics_sequence(ref)
122
- hyp_d = get_diacritics_sequence(hyp)
123
- der = jiwer.wer(ref_d, hyp_d) if ref_d.strip() else (1.0 if hyp_d.strip() else 0.0)
124
- # CER
125
- cer = jiwer.cer(ref, hyp) if ref.strip() else (1.0 if hyp.strip() else 0.0)
126
- return round(wer, 4), round(der, 4), round(cer, 4)
127
-
128
- import difflib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def highlight_errors(reference, hypothesis):
131
- ref_words = reference.split()
132
- hyp_words = hypothesis.split()
133
- matcher = difflib.SequenceMatcher(a=ref_words, b=hyp_words)
134
- highlighted = []
135
- errors = []
136
- # Iterate over matched blocks and insert highlights for mismatches
137
- i = j = 0
138
- for tag, a0, a1, b0, b1 in matcher.get_opcodes():
 
 
 
 
 
 
 
139
  if tag == 'equal':
140
- for w in ref_words[a0:a1]:
141
- highlighted.append(w)
142
- else:
143
- # highlight reference words as errors
144
- for w in ref_words[a0:a1]:
145
- highlighted.append(f"<mark>{w}</mark>")
146
- errors.append(w)
147
- i = a1
148
- j = b1
149
- html = ' '.join(highlighted)
150
- return html, ', '.join(errors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  # --- Gradio Interface ---
153
  with gr.Blocks(theme=gr.themes.Soft()) as app:
154
  gr.Markdown(
155
  """
156
  # Arabic Diacritization and Reading Assessment Tool
157
- 1. Enter undiacritized Arabic text and click **Diacritize Text**.
158
- 2. Read the generated **Diacritized Text** aloud and record or upload audio.
159
- 3. Click **Transcribe and Compare** to see the transcript, WER/DER/CER, and mispronounced words highlighted.
160
  """
161
  )
162
 
163
- original_state = gr.State("")
 
 
164
  with gr.Row():
165
  with gr.Column(scale=1):
166
  text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
167
  diacritize_btn = gr.Button("Diacritize Text")
168
- diacritized_output = gr.Textbox(label="Diacritized Text (Reference)", lines=3, interactive=False, text_align="right")
 
 
 
 
 
169
 
170
  with gr.Column(scale=1):
171
- audio_input = gr.Audio(label="Record or Upload Audio", type="filepath")
172
  transcribe_btn = gr.Button("Transcribe and Compare")
173
- transcript_output = gr.Textbox(label="Transcript (Hypothesis)", lines=3, interactive=False, text_align="right")
174
- transcript_syllables_output = gr.Textbox(label="Transcript syllables (Hypothesis)", lines=3, interactive=False, text_align="right")
 
 
 
 
 
 
 
 
 
 
 
175
  with gr.Row():
176
  wer_out = gr.Number(label="WER", interactive=False, precision=4)
177
  der_out = gr.Number(label="DER", interactive=False, precision=4)
178
  cer_out = gr.Number(label="CER", interactive=False, precision=4)
179
- error_html = gr.HTML(label="Highlighted Errors")
180
- error_list = gr.Textbox(label="Mispronounced Words", interactive=False)
 
181
 
 
 
 
182
  diacritize_btn.click(
183
  fn=diacritize_text_api,
184
  inputs=[text_input],
185
- outputs=[diacritized_output, original_state]
 
186
  )
187
 
188
- def process(audio, ref_text):
189
- transcript = transcribe_audio_api(audio)
190
- syllable_transcript = transcribe_syllable_audio_api(audio)
191
- if transcript.startswith("Error"):
192
- return transcript, None, None, None, "", ""
193
- wer, der, cer = calculate_metrics(ref_text, transcript)
194
- html, errs = highlight_errors(ref_text, transcript)
195
- return transcript,syllable_transcript, wer, der, cer, html, errs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
 
197
  transcribe_btn.click(
198
- fn=process,
199
- inputs=[audio_input, original_state],
200
- outputs=[transcript_output, transcript_syllables_output, wer_out, der_out, cer_out, error_html, error_list]
 
 
 
 
 
 
 
 
 
 
201
  )
202
 
203
- app.launch(debug=True, share=True)
 
5
  import time
6
  import warnings
7
  import pyarabic.araby as araby
8
+ import difflib # Import difflib
9
 
10
  # Suppress specific UserWarnings from jiwer related to empty strings
11
  warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
 
15
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
16
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
17
  SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"
18
+
19
  # Define Arabic diacritics
20
+ # Use a try-except block in case pyarabic is not installed or fails to import
21
+ try:
22
  ARABIC_DIACRITICS = {
23
  araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
24
  araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
25
  }
26
+ except (ImportError, NameError):
27
+ print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
28
  ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
29
 
30
  # --- API Clients ---
31
+ # Use caching or global clients to avoid re-initializing on every call
32
+ diacritization_client = None
33
+ transcription_client = None
34
+ syllable_transcription_client = None
35
+
36
  def get_diacritization_client():
37
+ global diacritization_client
38
+ if diacritization_client is None:
39
+ try:
40
+ diacritization_client = Client(DIACRITIZATION_API_URL, download_files=True)
41
+ except Exception as e:
42
+ print(f"Error initializing diacritization client: {e}")
43
+ return None
44
+ return diacritization_client
45
 
46
  def get_transcription_client():
47
+ global transcription_client
48
+ if transcription_client is None:
49
+ try:
50
+ transcription_client = Client(TRANSCRIPTION_API_URL, download_files=True)
51
+ except Exception as e:
52
+ print(f"Error initializing transcription client: {e}")
53
+ return None
54
+ return transcription_client
55
 
56
  def get_syllable_transcription_client():
57
+ global syllable_transcription_client
58
+ if syllable_transcription_client is None:
59
+ try:
60
+ syllable_transcription_client = Client(SYLLABLE_TRANSCRIPTION_API_URL, download_files=True)
61
+ except Exception as e:
62
+ print(f"Error initializing syllable transcription client: {e}")
63
+ return None
64
+ return syllable_transcription_client
65
 
66
  # --- Helper Functions ---
67
  def diacritize_text_api(text_to_diacritize):
68
+ """Calls the diacritization API."""
69
  if not text_to_diacritize or not text_to_diacritize.strip():
70
+ return "Please enter some text to diacritize.", "" # Return two values as expected by the click handler
71
  client = get_diacritization_client()
72
  if not client:
73
  return "Error: Could not connect to the diacritization service.", ""
 
77
  input_text=text_to_diacritize,
78
  api_name="/predict"
79
  )
80
+ # Ensure result is a string, handle potential None or unexpected types
81
  result_str = str(result) if result is not None else "Error: Empty response from diacritization service."
82
+ # Return the result for both the output textbox and the state
83
  return result_str, result_str
84
  except Exception as e:
85
+ print(f"Error during diacritization API call: {e}")
86
  return f"Error during diacritization: {e}", ""
87
 
88
  def transcribe_audio_api(audio_filepath):
89
+ """Calls the standard transcription API."""
90
  if not audio_filepath:
91
  return "Error: Please provide an audio recording or file."
92
  if not os.path.exists(audio_filepath):
93
  return f"Error: Audio file not found at {audio_filepath}"
94
+
95
  client = get_transcription_client()
96
  if not client:
97
  return "Error: Could not connect to the transcription service."
98
  try:
99
+ # Add a small delay if needed, sometimes helps with API race conditions
100
+ # time.sleep(0.5)
101
  result = client.predict(
102
  audio=handle_file(audio_filepath),
103
  api_name="/predict"
104
  )
105
+ # Process result, expecting a dictionary or string
106
  if isinstance(result, dict) and 'text' in result:
107
  transcript = result['text']
108
+ elif isinstance(result, str):
109
+ transcript = result
110
  else:
111
+ transcript = f"Error: Unexpected response format from transcription service: {type(result)}"
112
+ return transcript if transcript is not None else "Error: Empty transcript received."
113
  except Exception as e:
114
+ print(f"Error during transcription API call: {e}")
115
  return f"Error during transcription: {e}"
116
 
117
  def transcribe_syllable_audio_api(audio_filepath):
118
+ """Calls the syllable transcription API."""
119
  if not audio_filepath:
120
+ # This case might not be strictly needed if called after the first check, but good practice
121
+ return "Error: Audio file path missing for syllable transcription."
122
  if not os.path.exists(audio_filepath):
123
+ return f"Error: Audio file not found at {audio_filepath} for syllable transcription."
124
+
125
  client = get_syllable_transcription_client()
126
  if not client:
127
+ return "Error: Could not connect to the syllable transcription service."
128
  try:
129
+ # Add a small delay if needed
130
+ # time.sleep(0.5)
131
  result = client.predict(
132
  audio=handle_file(audio_filepath),
133
  api_name="/predict"
134
  )
135
+ # Process result, expecting a dictionary or string
136
  if isinstance(result, dict) and 'text' in result:
137
  transcript = result['text']
138
+ elif isinstance(result, str):
139
+ transcript = result
140
  else:
141
+ transcript = f"Error: Unexpected response format from syllable transcription service: {type(result)}"
142
+ return transcript if transcript is not None else "Error: Empty syllable transcript received."
143
  except Exception as e:
144
+ print(f"Error during syllable transcription API call: {e}")
145
+ return f"Error during syllable transcription: {e}"
146
 
147
  def get_diacritics_sequence(text):
148
+ """Extracts diacritics from a string."""
149
  if not isinstance(text, str):
150
  return ""
151
  diacritics_only = [c for c in text if c in ARABIC_DIACRITICS]
152
  return ' '.join(diacritics_only)
153
 
154
  def calculate_metrics(reference, hypothesis):
155
+ """Calculates WER, DER, CER."""
156
  ref = reference or ""
157
  hyp = hypothesis or ""
158
+
159
+ # Handle cases where one or both are empty or just whitespace
160
+ if not ref.strip() and not hyp.strip():
161
+ return 0.0, 0.0, 0.0 # Both empty, 0 error
162
+ if not ref.strip():
163
+ return 1.0, 1.0, 1.0 # Reference empty, hypothesis not: Max error
164
+ if not hyp.strip():
165
+ # Hypothesis empty, reference not: Max error (though jiwer might handle this)
166
+ # Let jiwer calculate based on its rules for empty hypothesis
167
+ pass
168
+
169
+ try:
170
+ # WER
171
+ wer = jiwer.wer(ref, hyp)
172
+ # DER
173
+ ref_d = get_diacritics_sequence(ref)
174
+ hyp_d = get_diacritics_sequence(hyp)
175
+ # Handle empty diacritic sequences for DER calculation
176
+ if not ref_d.strip() and not hyp_d.strip():
177
+ der = 0.0
178
+ elif not ref_d.strip():
179
+ der = 1.0
180
+ else:
181
+ der = jiwer.wer(ref_d, hyp_d) # jiwer handles empty hyp_d if ref_d is not empty
182
+ # CER
183
+ cer = jiwer.cer(ref, hyp)
184
+ return round(wer, 4), round(der, 4), round(cer, 4)
185
+ except Exception as e:
186
+ print(f"Error calculating metrics: {e}")
187
+ return None, None, None # Indicate error in calculation
188
+
189
 
190
  def highlight_errors(reference, hypothesis):
191
+ """Highlights differences between reference and hypothesis using HTML mark tag."""
192
+ ref = reference or ""
193
+ hyp = hypothesis or ""
194
+ ref_words = ref.split()
195
+ hyp_words = hyp.split()
196
+
197
+ if not ref_words and not hyp_words:
198
+ return "", "" # No errors if both are empty
199
+
200
+ matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
201
+ highlighted_hyp_words = []
202
+ error_words_ref = [] # Words in reference that were deleted or replaced
203
+ error_words_hyp = [] # Words in hypothesis that were inserted or replaced
204
+
205
+ for tag, i1, i2, j1, j2 in matcher.get_opcodes():
206
  if tag == 'equal':
207
+ highlighted_hyp_words.extend(hyp_words[j1:j2])
208
+ elif tag == 'replace':
209
+ # Mark incorrect words in hypothesis red
210
+ for word in hyp_words[j1:j2]:
211
+ highlighted_hyp_words.append(f"<mark style='background-color: #ffcccb;'>{word}</mark>")
212
+ error_words_ref.extend(ref_words[i1:i2])
213
+ error_words_hyp.extend(hyp_words[j1:j2])
214
+ elif tag == 'delete':
215
+ # Indicate missing words (maybe with a placeholder?) - for now, just note them
216
+ # We don't add anything to highlighted_hyp_words here as they are missing
217
+ error_words_ref.extend(ref_words[i1:i2])
218
+ # Optionally add a placeholder in the output to show where deletion happened
219
+ # highlighted_hyp_words.append("<mark style='background-color: #lightgrey;'>[missing]</mark>")
220
+ elif tag == 'insert':
221
+ # Mark inserted words in hypothesis green
222
+ for word in hyp_words[j1:j2]:
223
+ highlighted_hyp_words.append(f"<mark style='background-color: #ccffcc;'>{word}</mark>")
224
+ error_words_hyp.extend(hyp_words[j1:j2])
225
+
226
+ html_output = ' '.join(highlighted_hyp_words)
227
+ # Combine unique error words for the list
228
+ error_list = sorted(list(set(error_words_ref + error_words_hyp)))
229
+
230
+ return html_output, ', '.join(error_list)
231
+
232
 
233
  # --- Gradio Interface ---
234
  with gr.Blocks(theme=gr.themes.Soft()) as app:
235
  gr.Markdown(
236
  """
237
  # Arabic Diacritization and Reading Assessment Tool
238
+ 1. Enter undiacritized Arabic text and click **Diacritize Text**.
239
+ 2. Read the generated **Diacritized Text** aloud and record or upload audio.
240
+ 3. Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
241
  """
242
  )
243
 
244
+ # Using gr.State to hold the diacritized reference text between steps
245
+ reference_text_state = gr.State("")
246
+
247
  with gr.Row():
248
  with gr.Column(scale=1):
249
  text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
250
  diacritize_btn = gr.Button("Diacritize Text")
251
+ diacritized_output = gr.Textbox(
252
+ label="Diacritized Text (Reference)",
253
+ lines=3,
254
+ interactive=False, # User shouldn't edit this directly
255
+ text_align="right"
256
+ )
257
 
258
  with gr.Column(scale=1):
259
+ audio_input = gr.Audio(label="Record or Upload Audio", type="filepath", sources=["microphone", "upload"])
260
  transcribe_btn = gr.Button("Transcribe and Compare")
261
+ transcript_output = gr.Textbox(
262
+ label="Transcript (Hypothesis)",
263
+ lines=3,
264
+ interactive=False,
265
+ text_align="right"
266
+ )
267
+ # Ensure this Textbox is defined correctly
268
+ transcript_syllables_output = gr.Textbox(
269
+ label="Transcript Syllables (Hypothesis)", # Corrected label slightly for clarity
270
+ lines=3,
271
+ interactive=False,
272
+ text_align="right"
273
+ )
274
  with gr.Row():
275
  wer_out = gr.Number(label="WER", interactive=False, precision=4)
276
  der_out = gr.Number(label="DER", interactive=False, precision=4)
277
  cer_out = gr.Number(label="CER", interactive=False, precision=4)
278
+ # Use Markdown for potentially richer HTML display if needed, but HTML component is fine
279
+ error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
280
+ error_list = gr.Textbox(label="Words Involved in Errors", interactive=False) # Changed label
281
 
282
+ # --- Event Handlers ---
283
+
284
+ # When Diacritize button is clicked
285
  diacritize_btn.click(
286
  fn=diacritize_text_api,
287
  inputs=[text_input],
288
+ # Output to the display box AND the hidden state
289
+ outputs=[diacritized_output, reference_text_state]
290
  )
291
 
292
+ # Define the main processing function that returns all 7 values
293
+ def process_audio_and_compare(audio_filepath, reference_text):
294
+ """Processes audio, gets both transcripts, calculates metrics, and highlights errors."""
295
+ # Default values in case of errors
296
+ transcript = "Error: Processing failed."
297
+ syllable_transcript = "Error: Processing failed."
298
+ wer, der, cer = None, None, None
299
+ html_output = ""
300
+ error_words = ""
301
+
302
+ # Validate inputs
303
+ if not audio_filepath:
304
+ transcript = "Error: No audio provided."
305
+ syllable_transcript = "Error: No audio provided."
306
+ # Return 7 values even on input error
307
+ return transcript, syllable_transcript, None, None, None, "", ""
308
+ if not reference_text:
309
+ transcript = "Error: No reference text found. Please diacritize first."
310
+ syllable_transcript = "Error: No reference text found."
311
+ # Return 7 values
312
+ return transcript, syllable_transcript, None, None, None, "", ""
313
+
314
+ # --- Call Transcription APIs ---
315
+ transcript = transcribe_audio_api(audio_filepath)
316
+ # Call syllable transcription regardless of the first one's success for now,
317
+ # but handle its potential error message.
318
+ syllable_transcript = transcribe_syllable_audio_api(audio_filepath)
319
+
320
+ # --- Calculate Metrics and Highlight Errors (only if first transcript is not an error) ---
321
+ if not transcript.startswith("Error"):
322
+ wer, der, cer = calculate_metrics(reference_text, transcript)
323
+ # Use the standard transcript for highlighting, adjust if needed
324
+ html_output, error_words = highlight_errors(reference_text, transcript)
325
+ else:
326
+ # If the main transcript failed, indicate no metrics/highlighting possible
327
+ wer, der, cer = None, None, None
328
+ html_output = "Highlighting not available due to transcription error."
329
+ error_words = "N/A"
330
+
331
+ # --- Return all 7 values ---
332
+ return transcript, syllable_transcript, wer, der, cer, html_output, error_words
333
 
334
+ # When Transcribe button is clicked
335
  transcribe_btn.click(
336
+ fn=process_audio_and_compare,
337
+ # Get audio path and the reference text from the state
338
+ inputs=[audio_input, reference_text_state],
339
+ # Update all 7 output components
340
+ outputs=[
341
+ transcript_output,
342
+ transcript_syllables_output, # This should now update correctly
343
+ wer_out,
344
+ der_out,
345
+ cer_out,
346
+ error_html,
347
+ error_list
348
+ ]
349
  )
350
 
351
+ app.launch(debug=True, share=True)