File size: 15,131 Bytes
647d7b9
 
 
 
 
 
17ddd97
31c4539
17ddd97
647d7b9
 
 
 
 
 
 
1590f7f
31c4539
18d0c35
31c4539
 
17ddd97
18d0c35
 
17ddd97
31c4539
 
17ddd97
647d7b9
18d0c35
31c4539
 
 
 
 
647d7b9
31c4539
 
 
 
 
 
 
 
647d7b9
 
31c4539
 
 
 
 
 
 
 
647d7b9
1590f7f
31c4539
 
 
 
 
 
 
 
1590f7f
647d7b9
 
31c4539
adaae60
31c4539
647d7b9
 
18d0c35
647d7b9
 
17ddd97
647d7b9
 
 
31c4539
18d0c35
31c4539
adaae60
647d7b9
31c4539
18d0c35
647d7b9
 
31c4539
647d7b9
adaae60
647d7b9
18d0c35
31c4539
647d7b9
 
 
 
31c4539
 
647d7b9
 
 
 
31c4539
647d7b9
18d0c35
31c4539
 
647d7b9
31c4539
 
647d7b9
31c4539
647d7b9
 
1590f7f
31c4539
1590f7f
31c4539
 
1590f7f
31c4539
 
1590f7f
 
31c4539
1590f7f
31c4539
 
1590f7f
 
 
 
31c4539
1590f7f
 
31c4539
 
1590f7f
31c4539
 
1590f7f
31c4539
 
1590f7f
17ddd97
31c4539
17ddd97
18d0c35
 
17ddd97
 
647d7b9
31c4539
18d0c35
 
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d0c35
 
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d0c35
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647d7b9
 
 
 
 
 
31c4539
 
 
647d7b9
 
 
31c4539
 
 
647d7b9
 
18d0c35
 
31c4539
 
 
 
 
 
647d7b9
 
31c4539
18d0c35
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
647d7b9
18d0c35
 
 
31c4539
 
 
647d7b9
31c4539
 
 
18d0c35
647d7b9
 
31c4539
 
647d7b9
 
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d0c35
31c4539
18d0c35
31c4539
 
 
 
 
 
 
 
 
 
 
 
 
647d7b9
 
0570125
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import gradio as gr
from gradio_client import Client, handle_file
import jiwer
import os
import time
import warnings
import pyarabic.araby as araby
import difflib # Import difflib

# Silence only the UserWarnings that jiwer emits for an empty reference or hypothesis.
warnings.filterwarnings(action="ignore", category=UserWarning, message="Reference is empty.*")
warnings.filterwarnings(action="ignore", category=UserWarning, message="Hypothesis is empty.*")

# --- Constants ---
# Identifiers handed to gradio_client.Client by the get_*_client() helpers below
# (presumably Hugging Face Space names -- confirm against the deployment).
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"

# The eight Arabic diacritic marks used for the diacritic error rate.
# Prefer pyarabic's named constants; if the `araby` name is unusable, fall back
# to the equivalent literal code points U+064B..U+0652.
try:
    ARABIC_DIACRITICS = set((
        araby.FATHATAN, araby.DAMMATAN, araby.KASRATAN,
        araby.FATHA, araby.DAMMA, araby.KASRA,
        araby.SHADDA, araby.SUKUN,
    ))
except (ImportError, NameError):
    print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
    ARABIC_DIACRITICS = set((
        '\u064B', '\u064C', '\u064D', '\u064E',
        '\u064F', '\u0650', '\u0651', '\u0652',
    ))

# --- API Clients ---
# Module-level slots filled lazily by the get_*_client() helpers, so each
# remote client is constructed at most once instead of on every request.
diacritization_client = transcription_client = syllable_transcription_client = None

def get_diacritization_client():
    """Return the shared diacritization client, creating it on first use.

    Returns:
        The cached ``Client`` for ``DIACRITIZATION_API_URL``, or ``None`` when
        construction fails (the error is printed and nothing is cached, so a
        later call tries again).
    """
    global diacritization_client
    if diacritization_client is not None:
        return diacritization_client
    try:
        diacritization_client = Client(DIACRITIZATION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing diacritization client: {err}")
        return None
    return diacritization_client

def get_transcription_client():
    """Return the shared transcription client, creating it on first use.

    Returns:
        The cached ``Client`` for ``TRANSCRIPTION_API_URL``, or ``None`` when
        construction fails (the error is printed and nothing is cached, so a
        later call tries again).
    """
    global transcription_client
    if transcription_client is not None:
        return transcription_client
    try:
        transcription_client = Client(TRANSCRIPTION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing transcription client: {err}")
        return None
    return transcription_client

def get_syllable_transcription_client():
    """Return the shared syllable-transcription client, creating it on first use.

    Returns:
        The cached ``Client`` for ``SYLLABLE_TRANSCRIPTION_API_URL``, or ``None``
        when construction fails (the error is printed and nothing is cached, so
        a later call tries again).
    """
    global syllable_transcription_client
    if syllable_transcription_client is not None:
        return syllable_transcription_client
    try:
        syllable_transcription_client = Client(SYLLABLE_TRANSCRIPTION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing syllable transcription client: {err}")
        return None
    return syllable_transcription_client

# --- Helper Functions ---
def diacritize_text_api(text_to_diacritize):
    """Diacritize Arabic text through the remote diacritization service.

    Args:
        text_to_diacritize: Undiacritized text entered by the user.

    Returns:
        A ``(display_text, reference_text)`` pair, matching the two outputs of
        the click handler. On success both items are the diacritized string.
        On every failure ``display_text`` carries a message for the user and
        ``reference_text`` is ``""``, so an error message is never stored as
        the reference that later comparisons are scored against.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""
    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""
    try:
        result = client.predict(
            model_type="Encoder-Only",
            input_text=text_to_diacritize,
            api_name="/predict"
        )
        if result is None:
            # An empty response is a failure: report it, but keep the stored
            # reference empty like the other error paths.
            return "Error: Empty response from diacritization service.", ""
        # Coerce to str in case the service returns a non-string payload.
        result_str = str(result)
        return result_str, result_str
    except Exception as e:
        print(f"Error during diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""

def transcribe_audio_api(audio_filepath):
    """Transcribe an audio file with the standard transcription service.

    Args:
        audio_filepath: Path to the recorded or uploaded audio file.

    Returns:
        The transcript on success; otherwise a message beginning with
        ``"Error"`` (missing path, missing file, no client, unexpected
        response shape, ``None`` text, or an exception during the call).
    """
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"

    client = get_transcription_client()
    if not client:
        return "Error: Could not connect to the transcription service."

    try:
        response = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict"
        )
        # The service may answer with a bare string or a dict carrying 'text'.
        if isinstance(response, str):
            return response
        if isinstance(response, dict) and 'text' in response:
            text = response['text']
            return "Error: Empty transcript received." if text is None else text
        return f"Error: Unexpected response format from transcription service: {type(response)}"
    except Exception as err:
        print(f"Error during transcription API call: {err}")
        return f"Error during transcription: {err}"

def transcribe_syllable_audio_api(audio_filepath):
    """Transcribe an audio file with the syllable transcription service.

    Args:
        audio_filepath: Path to the recorded or uploaded audio file.

    Returns:
        The syllable transcript on success; otherwise a message beginning with
        ``"Error"`` (missing path, missing file, no client, unexpected
        response shape, ``None`` text, or an exception during the call).
    """
    if not audio_filepath:
        return "Error: Audio file path missing for syllable transcription."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath} for syllable transcription."

    client = get_syllable_transcription_client()
    if not client:
        return "Error: Could not connect to the syllable transcription service."

    try:
        response = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict"
        )
        # The service may answer with a bare string or a dict carrying 'text'.
        if isinstance(response, str):
            return response
        if isinstance(response, dict) and 'text' in response:
            text = response['text']
            return "Error: Empty syllable transcript received." if text is None else text
        return f"Error: Unexpected response format from syllable transcription service: {type(response)}"
    except Exception as err:
        print(f"Error during syllable transcription API call: {err}")
        return f"Error during syllable transcription: {err}"

def get_diacritics_sequence(text):
    """Return the diacritic marks of *text*, in order, joined by single spaces.

    Only characters contained in ``ARABIC_DIACRITICS`` are kept. A non-string
    argument yields ``""``.
    """
    if isinstance(text, str):
        return ' '.join(ch for ch in text if ch in ARABIC_DIACRITICS)
    return ""

def calculate_metrics(reference, hypothesis):
    """Compute WER, DER and CER of *hypothesis* against *reference*.

    DER is the jiwer word error rate taken over the space-separated diacritic
    sequences produced by ``get_diacritics_sequence``.

    Args:
        reference: Reference text; a falsy value is treated as ``""``.
        hypothesis: Hypothesis text; a falsy value is treated as ``""``.

    Returns:
        ``(wer, der, cer)`` rounded to four decimals. ``(0.0, 0.0, 0.0)`` when
        both texts are blank, ``(1.0, 1.0, 1.0)`` when only the reference is
        blank, and ``(None, None, None)`` if the computation raises.
    """
    ref = reference if reference else ""
    hyp = hypothesis if hypothesis else ""

    # A blank reference is settled here; a blank hypothesis with a non-blank
    # reference is left to jiwer.
    if not ref.strip():
        return (1.0, 1.0, 1.0) if hyp.strip() else (0.0, 0.0, 0.0)

    try:
        wer = jiwer.wer(ref, hyp)

        ref_marks = get_diacritics_sequence(ref)
        hyp_marks = get_diacritics_sequence(hyp)
        if ref_marks.strip():
            der = jiwer.wer(ref_marks, hyp_marks)
        elif hyp_marks.strip():
            # Diacritics in the hypothesis but none in the reference.
            der = 1.0
        else:
            der = 0.0

        cer = jiwer.cer(ref, hyp)
        return round(wer, 4), round(der, 4), round(cer, 4)
    except Exception as err:
        print(f"Error calculating metrics: {err}")
        return None, None, None


def highlight_errors(reference, hypothesis):
    """Mark word-level differences between reference and hypothesis as HTML.

    The texts are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Hypothesis words that replace reference words
    are wrapped in a red ``<mark>``, inserted words in a green ``<mark>``;
    reference words missing from the hypothesis add nothing to the HTML but
    are still reported in the error list.

    Every word is HTML-escaped before being placed in the markup, because both
    texts come from outside the program (user input and a remote service) and
    the result is rendered as raw HTML.

    Args:
        reference: Reference text; a falsy value is treated as ``""``.
        hypothesis: Hypothesis text; a falsy value is treated as ``""``.

    Returns:
        ``(html_output, error_words)``: the highlighted hypothesis as an HTML
        string, and the sorted unique words involved in any difference joined
        by ``", "`` (unescaped plain text). Both are ``""`` when neither text
        contains a word.
    """
    from html import escape  # local import keeps the block self-contained

    ref_words = (reference or "").split()
    hyp_words = (hypothesis or "").split()

    if not ref_words and not hyp_words:
        return "", ""

    # Background colour per opcode that puts hypothesis words on screen.
    mark_colours = {'replace': '#ffcccb', 'insert': '#ccffcc'}

    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
    highlighted_hyp_words = []
    error_words = set()

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        hyp_chunk = hyp_words[j1:j2]
        if tag == 'equal':
            highlighted_hyp_words.extend(escape(word) for word in hyp_chunk)
            continue

        # replace / delete / insert: whichever side is non-empty took part in
        # the error ('delete' has no hypothesis words, 'insert' no reference).
        error_words.update(ref_words[i1:i2])
        error_words.update(hyp_chunk)

        colour = mark_colours.get(tag)
        if colour is not None:
            highlighted_hyp_words.extend(
                f"<mark style='background-color: {colour};'>{escape(word)}</mark>"
                for word in hyp_chunk
            )

    html_output = ' '.join(highlighted_hyp_words)
    return html_output, ', '.join(sorted(error_words))


# --- Gradio Interface ---
# Build the UI. Components are created inside nested context managers, so the
# order of the statements below is the order in which they are laid out.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # Arabic Diacritization and Reading Assessment Tool
        1.  Enter undiacritized Arabic text and click **Diacritize Text**.
        2.  Read the generated **Diacritized Text** aloud and record or upload audio.
        3.  Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
        """
    )

    # Holds the diacritized reference text between the "diacritize" step and
    # the "transcribe and compare" step; starts out empty.
    reference_text_state = gr.State("")

    with gr.Row():
        # First column: text entry and the diacritized reference.
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            diacritize_btn = gr.Button("Diacritize Text")
            diacritized_output = gr.Textbox(
                label="Diacritized Text (Reference)",
                lines=3,
                interactive=False, # read-only: filled by diacritize_text_api, not by the user
                text_align="right"
            )

        # Second column: audio capture, both transcripts, metrics and error report.
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Record or Upload Audio", type="filepath", sources=["microphone", "upload"])
            transcribe_btn = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right"
            )
            # Shows the result of the separate syllable transcription service.
            transcript_syllables_output = gr.Textbox(
                label="Transcript Syllables (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right"
            )
            # The three error rates side by side, displayed with four decimals.
            with gr.Row():
                wer_out = gr.Number(label="WER", interactive=False, precision=4)
                der_out = gr.Number(label="DER", interactive=False, precision=4)
                cer_out = gr.Number(label="CER", interactive=False, precision=4)
            # Receives the <mark>-highlighted hypothesis built by highlight_errors.
            error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
            # Receives the comma-separated list of words involved in any difference.
            error_list = gr.Textbox(label="Words Involved in Errors", interactive=False)

    # --- Event Handlers ---

    # Diacritize button: diacritize_text_api returns two values, one for each
    # output listed below.
    diacritize_btn.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        # First value goes to the visible textbox, second to the hidden state.
        outputs=[diacritized_output, reference_text_state]
    )

    # Define the main processing function that returns all 7 values
    def process_audio_and_compare(audio_filepath, reference_text):
        """Transcribe the audio twice and score the standard transcript.

        Args:
            audio_filepath: Path of the recorded or uploaded audio.
            reference_text: Diacritized reference text held in the state.

        Returns:
            A 7-tuple ``(transcript, syllable_transcript, wer, der, cer,
            html_output, error_words)``. The three metrics are ``None`` and
            the last two items are placeholders whenever an input is missing
            or the standard transcript starts with ``"Error"``.
        """
        # Guard clauses: each still returns all seven values for the outputs.
        if not audio_filepath:
            no_audio = "Error: No audio provided."
            return no_audio, no_audio, None, None, None, "", ""
        if not reference_text:
            return (
                "Error: No reference text found. Please diacritize first.",
                "Error: No reference text found.",
                None, None, None, "", "",
            )

        # Both services are always called; a failure of one is reported in its
        # own output and does not stop the other.
        transcript = transcribe_audio_api(audio_filepath)
        syllable_transcript = transcribe_syllable_audio_api(audio_filepath)

        # Metrics and highlighting are based on the standard transcript only.
        if transcript.startswith("Error"):
            return (
                transcript,
                syllable_transcript,
                None, None, None,
                "Highlighting not available due to transcription error.",
                "N/A",
            )

        wer, der, cer = calculate_metrics(reference_text, transcript)
        html_output, error_words = highlight_errors(reference_text, transcript)
        return transcript, syllable_transcript, wer, der, cer, html_output, error_words

    # When Transcribe button is clicked
    transcribe_btn.click(
        fn=process_audio_and_compare,
        # The audio path plus the reference text kept in the state.
        inputs=[audio_input, reference_text_state],
        # One component per element of the 7-tuple returned by
        # process_audio_and_compare, in the same order.
        outputs=[
            transcript_output,
            transcript_syllables_output,
            wer_out,
            der_out,
            cer_out,
            error_html,
            error_list
        ]
    )

# Launch the app
if __name__ == "__main__":
    # share=True is what requests the public link; pass share=False if one is not needed.
    app.launch(debug=True, share=True)