Bisher committed on
Commit
18d0c35
·
verified ·
1 Parent(s): 1ae10e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -234
app.py CHANGED
@@ -15,26 +15,17 @@ warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=User
15
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
16
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
17
 
18
- # Define the set of Arabic diacritic characters using pyarabic constants if available
19
  if araby:
20
  ARABIC_DIACRITICS = {
21
- araby.FATHA, # U+064E
22
- araby.FATHATAN, # U+064B
23
- araby.DAMMA, # U+064F
24
- araby.DAMMATAN, # U+064C
25
- araby.KASRA, # U+0650
26
- araby.KASRATAN, # U+064D
27
- araby.SUKUN, # U+0652
28
- araby.SHADDA, # U+0651
29
  }
30
  else:
31
- # Fallback if pyarabic failed to import
32
  ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
33
 
34
-
35
- # --- Gradio API Clients ---
36
  def get_diacritization_client():
37
- """Initializes and returns the client for the text diacritization API."""
38
  try:
39
  return Client(DIACRITIZATION_API_URL, download_files=True)
40
  except Exception as e:
@@ -42,7 +33,6 @@ def get_diacritization_client():
42
  return None
43
 
44
  def get_transcription_client():
45
- """Initializes and returns the client for the audio transcription API."""
46
  try:
47
  return Client(TRANSCRIPTION_API_URL, download_files=True)
48
  except Exception as e:
@@ -50,280 +40,134 @@ def get_transcription_client():
50
  return None
51
 
52
  # --- Helper Functions ---
53
-
54
  def diacritize_text_api(text_to_diacritize):
55
- """
56
- Calls the Hugging Face space to diacritize the input text.
57
-
58
- Args:
59
- text_to_diacritize (str): The undiacritized Arabic text.
60
-
61
- Returns:
62
- tuple: (str, str) The diacritized text (or error message) returned twice,
63
- once for the output component and once for the state.
64
- """
65
  if not text_to_diacritize or not text_to_diacritize.strip():
66
- error_msg = "Please enter some text to diacritize."
67
- return error_msg, error_msg
68
-
69
  client = get_diacritization_client()
70
  if not client:
71
- error_msg = "Error: Could not connect to the diacritization service."
72
- return error_msg, error_msg
73
-
74
  try:
75
- print(f"Sending text to diacritization API: {text_to_diacritize}")
76
  result = client.predict(
77
  model_type="Encoder-Only",
78
  input_text=text_to_diacritize,
79
  api_name="/predict"
80
  )
81
- print(f"Received diacritized text: {result}")
82
- result_str = str(result) if result is not None else "Error: Received empty response from diacritization service."
83
  return result_str, result_str
84
  except Exception as e:
85
- print(f"Error during text diacritization API call: {e}")
86
- error_msg = f"Error during diacritization: {e}"
87
- return error_msg, error_msg
88
 
89
  def transcribe_audio_api(audio_filepath):
90
- """
91
- Calls the Hugging Face space to transcribe and diacritize the input audio.
92
-
93
- Args:
94
- audio_filepath (str): The path to the audio file.
95
-
96
- Returns:
97
- str: The diacritized transcript, or an error message.
98
- """
99
  if not audio_filepath:
100
  return "Error: Please provide an audio recording or file."
101
-
102
  if not os.path.exists(audio_filepath):
103
- return f"Error: Audio file not found at {audio_filepath}"
104
-
105
  client = get_transcription_client()
106
  if not client:
107
  return "Error: Could not connect to the transcription service."
108
-
109
  try:
110
- print(f"Sending audio file to transcription API: {audio_filepath}")
111
  result = client.predict(
112
  audio=handle_file(audio_filepath),
113
  api_name="/predict"
114
  )
115
- print(f"Received transcript: {result}")
116
  if isinstance(result, dict) and 'text' in result:
117
- transcript = result['text']
118
- elif isinstance(result, str):
119
- transcript = result
120
  else:
121
- print(f"Unexpected transcription result format: {result}")
122
- return "Error: Unexpected format received from transcription service."
123
-
124
- return str(transcript) if transcript is not None else "Error: Received empty response from transcription service."
125
-
126
  except Exception as e:
127
- print(f"Error during audio transcription API call: {e}")
128
  return f"Error during transcription: {e}"
129
 
130
  def get_diacritics_sequence(text):
131
- """
132
- Extracts only the Arabic diacritic characters from a string.
133
-
134
- Args:
135
- text (str): The input string potentially containing diacritics.
136
-
137
- Returns:
138
- str: A space-separated string of diacritics found in the text.
139
- Returns an empty string if no diacritics are found or input is not a string.
140
- """
141
  if not isinstance(text, str):
142
- return "" # Return empty string for non-string input
143
-
144
- if not araby and not ARABIC_DIACRITICS:
145
- print("Warning: pyarabic not loaded, cannot reliably extract diacritics.")
146
- return ""
147
-
148
- diacritics_only = [char for char in text if char in ARABIC_DIACRITICS]
149
  return ' '.join(diacritics_only)
150
 
151
-
152
  def calculate_metrics(reference, hypothesis):
153
- """
154
- Calculates Word Error Rate (WER), Diacritic Error Rate (DER),
155
- and Character Error Rate (CER).
156
- DER is calculated based *only* on the sequence of diacritic marks.
157
-
158
- Args:
159
- reference (str): The original diacritized text.
160
- hypothesis (str): The diacritized transcript from the audio.
161
-
162
- Returns:
163
- tuple: (wer, der, cer) scores, or (None, None, None) if inputs are invalid or calculation fails.
164
- """
165
- if not isinstance(reference, str):
166
- print(f"Error: Reference input is not a string (type: {type(reference)}). Value: {reference}")
167
- reference = ""
168
- if not isinstance(hypothesis, str):
169
- print(f"Error: Hypothesis input is not a string (type: {type(hypothesis)}). Value: {hypothesis}")
170
- hypothesis = ""
171
-
172
- ref_strip = reference.strip()
173
- hyp_strip = hypothesis.strip()
174
-
175
- wer = None
176
- der = None
177
- cer = None
178
-
179
- try:
180
- # Handle cases where both are empty first
181
- if not ref_strip and not hyp_strip:
182
- return 0.0, 0.0, 0.0
183
-
184
- # 1. Calculate Word Error Rate (WER)
185
- if not ref_strip:
186
- wer = 1.0 # Reference empty, hypothesis not
187
- else:
188
- wer = jiwer.wer(reference, hypothesis)
189
-
190
- # 2. Calculate Diacritic Error Rate (DER) based *only* on diacritics
191
- ref_diacritics = get_diacritics_sequence(reference)
192
- hyp_diacritics = get_diacritics_sequence(hypothesis)
193
- ref_diacritics_strip = ref_diacritics.strip()
194
- hyp_diacritics_strip = hyp_diacritics.strip()
195
-
196
- if not ref_diacritics_strip and not hyp_diacritics_strip:
197
- der = 0.0
198
- elif not ref_diacritics_strip:
199
- der = 1.0
200
- print("Warning: No diacritics found in reference text for DER calculation.")
201
- else:
202
- der = jiwer.wer(ref_diacritics, hyp_diacritics)
203
-
204
- # 3. Calculate Character Error Rate (CER)
205
- if not ref_strip:
206
- # If reference is empty, CER is 1.0 (all hypothesis chars are insertions)
207
- # unless hypothesis is also empty (handled above)
208
- cer = 1.0
209
  else:
210
- # jiwer.cer handles empty hypothesis correctly if reference is not empty
211
- cer = jiwer.cer(reference, hypothesis)
212
-
213
- # Round the results
214
- wer_rounded = round(wer, 4) if wer is not None else None
215
- der_rounded = round(der, 4) if der is not None else None
216
- cer_rounded = round(cer, 4) if cer is not None else None
217
-
218
- return wer_rounded, der_rounded, cer_rounded
219
-
220
- except Exception as e:
221
- print(f"Error calculating metrics: {e}")
222
- return None, None, None
223
-
224
-
225
- def process_audio_and_compare(audio_input, original_diacritized_text):
226
- """
227
- Main function triggered after audio input.
228
- Transcribes audio, calculates metrics (WER, DER, CER), and returns results.
229
-
230
- Returns:
231
- tuple: (transcript, wer, der, cer)
232
- transcript (str): The transcribed text or an error message.
233
- wer (float | None): Word Error Rate or None if error.
234
- der (float | None): Diacritic Error Rate or None if error.
235
- cer (float | None): Character Error Rate or None if error.
236
- """
237
- print("Processing audio and comparing...")
238
- if not original_diacritized_text or not isinstance(original_diacritized_text, str) or original_diacritized_text.startswith("Error:"):
239
- error_msg = "Error: Valid reference diacritized text not available. Please diacritize text first."
240
- print(error_msg)
241
- # Return default/error values for all outputs
242
- return error_msg, None, None, None
243
-
244
- transcript = transcribe_audio_api(audio_input)
245
-
246
- if not isinstance(transcript, str) or transcript.startswith("Error:"):
247
- error_msg = transcript if isinstance(transcript, str) else "Error: Transcription failed with non-string output."
248
- print(error_msg)
249
- # Return transcript error and None for metrics
250
- return error_msg, None, None, None
251
-
252
- # Calculate all three metrics
253
- wer, der, cer = calculate_metrics(original_diacritized_text, transcript)
254
-
255
- if wer is None or der is None or cer is None:
256
- print("Metrics calculation failed.")
257
- # Return transcript but None for metrics
258
- return transcript, None, None, None
259
-
260
- print(f"Comparison complete. WER: {wer}, DER: {der}, CER: {cer}")
261
- # Return transcript and all three metrics
262
- return transcript, wer, der, cer
263
-
264
 
265
  # --- Gradio Interface ---
266
  with gr.Blocks(theme=gr.themes.Soft()) as app:
267
  gr.Markdown(
268
  """
269
  # Arabic Diacritization and Reading Assessment Tool
270
- 1. Enter undiacritized Arabic text and click **Diacritize Text**.
271
- 2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.
272
- 3. Click **Transcribe and Compare** to get the transcript and see the WER/DER/CER scores compared to the original diacritized text.
273
  """
274
  )
275
 
276
- original_diacritized_state = gr.State("")
277
-
278
  with gr.Row():
279
  with gr.Column(scale=1):
280
- text_input = gr.Textbox(
281
- label="1. Enter Undiacritized Arabic Text",
282
- placeholder="مثال: السلام عليكم",
283
- lines=3,
284
- text_align="right",
285
- )
286
- diacritize_button = gr.Button("Diacritize Text")
287
- diacritized_text_output = gr.Textbox(
288
- label="2. Diacritized Text (Reference)",
289
- lines=3,
290
- interactive=False,
291
- text_align="right",
292
- )
293
 
294
  with gr.Column(scale=1):
295
- audio_input = gr.Audio(
296
- sources=["microphone", "upload"],
297
- type="filepath",
298
- label="3. Record or Upload Audio of Reading Diacritized Text",
299
- )
300
- transcribe_button = gr.Button("Transcribe and Compare")
301
- transcript_output = gr.Textbox(
302
- label="4. Diacritized Transcript (Hypothesis)",
303
- lines=3,
304
- interactive=False,
305
- text_align="right",
306
- )
307
  with gr.Row():
308
- # Add CER output component
309
- wer_output = gr.Number(label="WER", interactive=False, precision=4) # Shortened label
310
- der_output = gr.Number(label="DER", interactive=False, precision=4) # Shortened label
311
- cer_output = gr.Number(label="CER", interactive=False, precision=4) # Added CER
 
312
 
313
-
314
- # --- Connect Components ---
315
- diacritize_button.click(
316
  fn=diacritize_text_api,
317
  inputs=[text_input],
318
- outputs=[diacritized_text_output, original_diacritized_state]
319
  )
320
 
321
- transcribe_button.click(
322
- fn=process_audio_and_compare,
323
- inputs=[audio_input, original_diacritized_state],
324
- # Update outputs to include the new CER component
325
- outputs=[transcript_output, wer_output, der_output, cer_output]
 
 
 
 
 
 
 
326
  )
327
 
328
-
329
- app.launch(debug=True, share=True)
 
15
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
16
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
17
 
18
+ # Define Arabic diacritics
19
  if araby:
20
  ARABIC_DIACRITICS = {
21
+ araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
22
+ araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
 
 
 
 
 
 
23
  }
24
  else:
 
25
  ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
26
 
27
+ # --- API Clients ---
 
28
  def get_diacritization_client():
 
29
  try:
30
  return Client(DIACRITIZATION_API_URL, download_files=True)
31
  except Exception as e:
 
33
  return None
34
 
35
  def get_transcription_client():
 
36
  try:
37
  return Client(TRANSCRIPTION_API_URL, download_files=True)
38
  except Exception as e:
 
40
  return None
41
 
42
  # --- Helper Functions ---
 
43
  def diacritize_text_api(text_to_diacritize):
 
 
 
 
 
 
 
 
 
 
44
  if not text_to_diacritize or not text_to_diacritize.strip():
45
+ return "Please enter some text to diacritize.", ""
 
 
46
  client = get_diacritization_client()
47
  if not client:
48
+ return "Error: Could not connect to the diacritization service.", ""
 
 
49
  try:
 
50
  result = client.predict(
51
  model_type="Encoder-Only",
52
  input_text=text_to_diacritize,
53
  api_name="/predict"
54
  )
55
+ result_str = str(result) if result is not None else "Error: Empty response from diacritization service."
 
56
  return result_str, result_str
57
  except Exception as e:
58
+ return f"Error during diacritization: {e}", ""
 
 
59
 
60
  def transcribe_audio_api(audio_filepath):
 
 
 
 
 
 
 
 
 
61
  if not audio_filepath:
62
  return "Error: Please provide an audio recording or file."
 
63
  if not os.path.exists(audio_filepath):
64
+ return f"Error: Audio file not found at {audio_filepath}"
 
65
  client = get_transcription_client()
66
  if not client:
67
  return "Error: Could not connect to the transcription service."
 
68
  try:
 
69
  result = client.predict(
70
  audio=handle_file(audio_filepath),
71
  api_name="/predict"
72
  )
 
73
  if isinstance(result, dict) and 'text' in result:
74
+ transcript = result['text']
 
 
75
  else:
76
+ transcript = str(result)
77
+ return transcript
 
 
 
78
  except Exception as e:
 
79
  return f"Error during transcription: {e}"
80
 
81
  def get_diacritics_sequence(text):
 
 
 
 
 
 
 
 
 
 
82
  if not isinstance(text, str):
83
+ return ""
84
+ diacritics_only = [c for c in text if c in ARABIC_DIACRITICS]
 
 
 
 
 
85
  return ' '.join(diacritics_only)
86
 
 
87
  def calculate_metrics(reference, hypothesis):
88
+ ref = reference or ""
89
+ hyp = hypothesis or ""
90
+ # WER
91
+ wer = jiwer.wer(ref, hyp) if ref.strip() else (1.0 if hyp.strip() else 0.0)
92
+ # DER
93
+ ref_d = get_diacritics_sequence(ref)
94
+ hyp_d = get_diacritics_sequence(hyp)
95
+ der = jiwer.wer(ref_d, hyp_d) if ref_d.strip() else (1.0 if hyp_d.strip() else 0.0)
96
+ # CER
97
+ cer = jiwer.cer(ref, hyp) if ref.strip() else (1.0 if hyp.strip() else 0.0)
98
+ return round(wer, 4), round(der, 4), round(cer, 4)
99
+
100
+ import difflib
101
+
102
def highlight_errors(reference, hypothesis):
    """Mark the reference words that do not match the hypothesis.

    Both texts are split on whitespace and aligned word by word with
    ``difflib.SequenceMatcher``. Reference words inside any non-``equal``
    opcode are wrapped in ``<mark>`` tags and collected as errors. Words that
    exist only in the hypothesis have no reference word to mark and therefore
    add nothing to the output.

    Args:
        reference: The reference text; ``None`` is treated as "".
        hypothesis: The transcript to compare; ``None`` is treated as "".

    Returns:
        A ``(html, errors)`` tuple: ``html`` is the reference text with
        mismatched words wrapped in ``<mark>`` (all words HTML-escaped), and
        ``errors`` is the mismatched words joined by ", " as plain text.
    """
    # Local import so this function does not depend on a module-level change.
    from html import escape

    ref_words = (reference or "").split()
    hyp_words = (hypothesis or "").split()
    # autojunk=False: the default heuristic treats frequent items in long
    # sequences as junk, which can distort the alignment of natural text.
    matcher = difflib.SequenceMatcher(a=ref_words, b=hyp_words, autojunk=False)

    highlighted = []
    errors = []
    for tag, a0, a1, _b0, _b1 in matcher.get_opcodes():
        words = ref_words[a0:a1]
        if tag == 'equal':
            # Escape even matching words: the result is rendered as HTML.
            highlighted.extend(escape(w) for w in words)
        else:
            highlighted.extend(f"<mark>{escape(w)}</mark>" for w in words)
            # The error list is shown as plain text, so keep the raw words.
            errors.extend(words)

    return ' '.join(highlighted), ', '.join(errors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # --- Gradio Interface ---
125
  with gr.Blocks(theme=gr.themes.Soft()) as app:
126
  gr.Markdown(
127
  """
128
  # Arabic Diacritization and Reading Assessment Tool
129
+ 1. Enter undiacritized Arabic text and click **Diacritize Text**.
130
+ 2. Read the generated **Diacritized Text** aloud and record or upload audio.
131
+ 3. Click **Transcribe and Compare** to see the transcript, WER/DER/CER, and mispronounced words highlighted.
132
  """
133
  )
134
 
135
+ original_state = gr.State("")
 
136
  with gr.Row():
137
  with gr.Column(scale=1):
138
+ text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
139
+ diacritize_btn = gr.Button("Diacritize Text")
140
+ diacritized_output = gr.Textbox(label="Diacritized Text (Reference)", lines=3, interactive=False, text_align="right")
 
 
 
 
 
 
 
 
 
 
141
 
142
  with gr.Column(scale=1):
143
+ audio_input = gr.Audio(label="Record or Upload Audio", type="filepath")
144
+ transcribe_btn = gr.Button("Transcribe and Compare")
145
+ transcript_output = gr.Textbox(label="Transcript (Hypothesis)", lines=3, interactive=False, text_align="right")
 
 
 
 
 
 
 
 
 
146
  with gr.Row():
147
+ wer_out = gr.Number(label="WER", interactive=False, precision=4)
148
+ der_out = gr.Number(label="DER", interactive=False, precision=4)
149
+ cer_out = gr.Number(label="CER", interactive=False, precision=4)
150
+ error_html = gr.HTML(label="Highlighted Errors")
151
+ error_list = gr.Textbox(label="Mispronounced Words", interactive=False)
152
 
153
+ diacritize_btn.click(
 
 
154
  fn=diacritize_text_api,
155
  inputs=[text_input],
156
+ outputs=[diacritized_output, original_state]
157
  )
158
 
159
def process(audio, ref_text):
    """Transcribe *audio* and compare the transcript with *ref_text*.

    Args:
        audio: Path of the recorded or uploaded audio file.
        ref_text: The diacritized reference text held in the session state.

    Returns:
        A 6-tuple ``(transcript, wer, der, cer, html, errors)``. When the
        reference text is missing or the transcription fails, the first item
        is an error message, the three metrics are ``None`` and the last two
        items are empty strings.
    """
    no_results = (None, None, None, "", "")

    # Validate the reference before calling the transcription service:
    # scoring against an empty or error reference would report misleading
    # metrics instead of telling the user to diacritize first.
    if (
        not isinstance(ref_text, str)
        or not ref_text.strip()
        or ref_text.startswith("Error")
    ):
        return (
            "Error: Valid reference diacritized text not available. "
            "Please diacritize text first.",
            *no_results,
        )

    transcript = transcribe_audio_api(audio)
    if not isinstance(transcript, str):
        return ("Error: Transcription failed with non-string output.", *no_results)
    if transcript.startswith("Error"):
        return (transcript, *no_results)

    wer, der, cer = calculate_metrics(ref_text, transcript)
    html, errs = highlight_errors(ref_text, transcript)
    return transcript, wer, der, cer, html, errs
166
+
167
+ transcribe_btn.click(
168
+ fn=process,
169
+ inputs=[audio_input, original_state],
170
+ outputs=[transcript_output, wer_out, der_out, cer_out, error_html, error_list]
171
  )
172
 
173
+ app.launch(debug=True, share=True)