Bisher commited on
Commit
647d7b9
·
verified ·
1 Parent(s): b11a63d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client, handle_file
3
+ import jiwer
4
+ import os
5
+ import time
6
+ import warnings
7
+
8
+ # Suppress specific UserWarnings from jiwer related to empty strings
9
+ warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
10
+ warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=UserWarning)
11
+
12
+ # --- Constants ---
13
+ DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
14
+ TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
15
+
16
+ # --- Gradio API Clients ---
17
+ # It's good practice to initialize clients outside the functions
18
+ # if the app runs continuously, but be mindful of potential state issues
19
+ # or connection timeouts in long-running deployments. For simplicity here,
20
+ # we might re-initialize, though a single initialization is often preferred.
21
+
22
+ def get_diacritization_client():
23
+ """Initializes and returns the client for the text diacritization API."""
24
+ try:
25
+ return Client(DIACRITIZATION_API_URL, download_files=True) # download_files might be needed depending on space setup
26
+ except Exception as e:
27
+ print(f"Error initializing diacritization client: {e}")
28
+ return None
29
+
30
+ def get_transcription_client():
31
+ """Initializes and returns the client for the audio transcription API."""
32
+ try:
33
+ return Client(TRANSCRIPTION_API_URL, download_files=True) # download_files might be needed
34
+ except Exception as e:
35
+ print(f"Error initializing transcription client: {e}")
36
+ return None
37
+
38
+ # --- Helper Functions ---
39
+
40
+ def diacritize_text_api(text_to_diacritize):
41
+ """
42
+ Calls the Hugging Face space to diacritize the input text.
43
+
44
+ Args:
45
+ text_to_diacritize (str): The undiacritized Arabic text.
46
+
47
+ Returns:
48
+ str: The diacritized text, or an error message.
49
+ """
50
+ if not text_to_diacritize:
51
+ return "Please enter some text to diacritize."
52
+
53
+ client = get_diacritization_client()
54
+ if not client:
55
+ return "Error: Could not connect to the diacritization service."
56
+
57
+ try:
58
+ print(f"Sending text to diacritization API: {text_to_diacritize}")
59
+ result = client.predict(
60
+ model_type="Encoder-Only", # Or 'Encoder-Decoder' if preferred
61
+ input_text=text_to_diacritize,
62
+ api_name="/predict"
63
+ )
64
+ print(f"Received diacritized text: {result}")
65
+ return result
66
+ except Exception as e:
67
+ print(f"Error during text diacritization API call: {e}")
68
+ # Provide more specific error feedback if possible
69
+ return f"Error during diacritization: {e}"
70
+
71
+ def transcribe_audio_api(audio_filepath):
72
+ """
73
+ Calls the Hugging Face space to transcribe and diacritize the input audio.
74
+
75
+ Args:
76
+ audio_filepath (str): The path to the audio file.
77
+
78
+ Returns:
79
+ str: The diacritized transcript, or an error message.
80
+ """
81
+ if not audio_filepath:
82
+ return "Please provide an audio recording or file."
83
+
84
+ # Check if file exists and is accessible
85
+ if not os.path.exists(audio_filepath):
86
+ return f"Error: Audio file not found at {audio_filepath}"
87
+
88
+ client = get_transcription_client()
89
+ if not client:
90
+ return "Error: Could not connect to the transcription service."
91
+
92
+ try:
93
+ print(f"Sending audio file to transcription API: {audio_filepath}")
94
+ # Use handle_file to manage the audio file for the API call
95
+ result = client.predict(
96
+ audio=handle_file(audio_filepath),
97
+ api_name="/predict"
98
+ )
99
+ print(f"Received transcript: {result}")
100
+ # The API might return more structure, adapt if needed. Assuming it returns the text directly.
101
+ # Example: if result is {'text': '...'}, use result['text']
102
+ if isinstance(result, dict) and 'text' in result:
103
+ transcript = result['text']
104
+ elif isinstance(result, str):
105
+ transcript = result
106
+ else:
107
+ print(f"Unexpected transcription result format: {result}")
108
+ return "Error: Unexpected format received from transcription service."
109
+
110
+ return transcript
111
+
112
+ except Exception as e:
113
+ print(f"Error during audio transcription API call: {e}")
114
+ # Provide more specific error feedback if possible
115
+ return f"Error during transcription: {e}"
116
+
117
+ def calculate_metrics(reference, hypothesis):
118
+ """
119
+ Calculates Word Error Rate (WER) and Diacritic Error Rate (DER).
120
+
121
+ Args:
122
+ reference (str): The original diacritized text.
123
+ hypothesis (str): The diacritized transcript from the audio.
124
+
125
+ Returns:
126
+ tuple: (wer, der) scores, or (None, None) if inputs are invalid.
127
+ """
128
+ if not isinstance(reference, str) or not isinstance(hypothesis, str):
129
+ print("Error: Invalid input types for metric calculation.")
130
+ return None, None
131
+
132
+ # Handle empty strings to avoid jiwer warnings/errors if not suppressed
133
+ if not reference.strip() and not hypothesis.strip():
134
+ return 0.0, 0.0 # Both empty, 0% error
135
+ if not reference.strip():
136
+ print("Warning: Reference text is empty.")
137
+ # WER/DER are typically 1.0 (or inf) if reference is empty and hypothesis is not.
138
+ # Jiwer might handle this, but let's return 1.0 for clarity.
139
+ return 1.0, 1.0
140
+ if not hypothesis.strip():
141
+ print("Warning: Hypothesis text is empty.")
142
+ # If hypothesis is empty but reference is not, WER/DER is 1.0
143
+ return 1.0, 1.0
144
+
145
+ try:
146
+ # 1. Calculate Word Error Rate (WER)
147
+ wer = jiwer.wer(reference, hypothesis)
148
+
149
+ # 2. Calculate Diacritic Error Rate (DER)
150
+ # - Treat each character (including diacritics) as a token.
151
+ # - Join characters with spaces to make jiwer treat them as "words".
152
+ ref_chars = ' '.join(list(reference))
153
+ hyp_chars = ' '.join(list(hypothesis))
154
+ der = jiwer.wer(ref_chars, hyp_chars)
155
+
156
+ return round(wer, 4), round(der, 4)
157
+
158
+ except Exception as e:
159
+ print(f"Error calculating metrics: {e}")
160
+ return None, None
161
+
162
+
163
+ def process_audio_and_compare(audio_input, original_diacritized_text):
164
+ """
165
+ Main function triggered after audio input.
166
+ Transcribes audio, calculates metrics, and returns results.
167
+ """
168
+ print("Processing audio and comparing...")
169
+ if not original_diacritized_text:
170
+ return "Error: No original diacritized text found. Please diacritize text first.", None, None
171
+
172
+ # --- 1. Transcribe Audio ---
173
+ # Gradio provides the audio data (e.g., filepath for upload/mic)
174
+ transcript = transcribe_audio_api(audio_input)
175
+
176
+ if transcript.startswith("Error:"):
177
+ # If transcription failed, return the error and None for metrics
178
+ return transcript, None, None
179
+
180
+ # --- 2. Calculate Metrics ---
181
+ wer, der = calculate_metrics(original_diacritized_text, transcript)
182
+
183
+ print(f"Comparison complete. WER: {wer}, DER: {der}")
184
+ return transcript, wer, der
185
+
186
+
187
+ # --- Gradio Interface ---
188
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
189
+ gr.Markdown(
190
+ """
191
+ # Arabic Diacritization and Reading Assessment Tool
192
+ 1. Enter undiacritized Arabic text and click **Diacritize Text**.
193
+ 2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.
194
+ 3. Click **Transcribe and Compare** to get the transcript and see the WER/DER scores compared to the original diacritized text.
195
+ """
196
+ )
197
+
198
+ # Store the original diacritized text for comparison later
199
+ original_diacritized_state = gr.State("")
200
+
201
+ with gr.Row():
202
+ with gr.Column(scale=1):
203
+ text_input = gr.Textbox(
204
+ label="1. Enter Undiacritized Arabic Text",
205
+ placeholder="مثال: السلام عليكم",
206
+ lines=3,
207
+ text_align="right", # Align text right for Arabic
208
+ )
209
+ diacritize_button = gr.Button("Diacritize Text")
210
+ diacritized_text_output = gr.Textbox(
211
+ label="2. Diacritized Text (Reference)",
212
+ lines=3,
213
+ interactive=False, # User shouldn't edit this directly
214
+ text_align="right",
215
+ )
216
+
217
+ with gr.Column(scale=1):
218
+ audio_input = gr.Audio(
219
+ sources=["microphone", "upload"],
220
+ type="filepath", # Get the path to the saved audio file
221
+ label="3. Record or Upload Audio of Reading Diacritized Text",
222
+ )
223
+ transcribe_button = gr.Button("Transcribe and Compare")
224
+ transcript_output = gr.Textbox(
225
+ label="4. Diacritized Transcript (Hypothesis)",
226
+ lines=3,
227
+ interactive=False,
228
+ text_align="right",
229
+ )
230
+ with gr.Row():
231
+ wer_output = gr.Number(label="Word Error Rate (WER)", interactive=False)
232
+ der_output = gr.Number(label="Diacritic Error Rate (DER)", interactive=False)
233
+
234
+
235
+ # --- Connect Components ---
236
+
237
+ # Action for Diacritize Button
238
+ diacritize_button.click(
239
+ fn=diacritize_text_api,
240
+ inputs=[text_input],
241
+ outputs=[diacritized_text_output, original_diacritized_state] # Update output and state
242
+ )
243
+
244
+ # Action for Transcribe Button
245
+ transcribe_button.click(
246
+ fn=process_audio_and_compare,
247
+ inputs=[audio_input, original_diacritized_state], # Pass audio and stored text
248
+ outputs=[transcript_output, wer_output, der_output] # Update transcript and metrics
249
+ )
250
+
251
+ app.launch(debug=True)