"""Gradio app for Arabic diacritization and reading assessment.

Workflow:
1. The user enters undiacritized Arabic text, which is diacritized by a
   remote Gradio Space and kept as the reference text.
2. The user reads the reference aloud; the audio is sent to two remote
   transcription Spaces (diacritic-level and syllable-level).
3. The diacritic-level transcript is compared with the reference: WER, DER
   and CER are computed and the differing words are highlighted.
"""

import difflib
import html
import os
import time
import warnings

import gradio as gr
import jiwer
from gradio_client import Client, handle_file

# Suppress specific UserWarnings from jiwer related to empty strings
warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=UserWarning)

# --- Constants ---
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"

# Define Arabic diacritics.
# The import itself lives inside the try block: importing pyarabic at the top
# of the file would raise before the fallback below could ever be used.
try:
    import pyarabic.araby as araby

    ARABIC_DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except (ImportError, AttributeError):
    araby = None
    print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
    ARABIC_DIACRITICS = {
        '\u064B', '\u064C', '\u064D', '\u064E',
        '\u064F', '\u0650', '\u0651', '\u0652',
    }

# Inline styles for the two kinds of highlighted hypothesis words.
_REPLACED_STYLE = "background-color: #f8d7da; color: #842029;"  # red
_INSERTED_STYLE = "background-color: #d1e7dd; color: #0f5132;"  # green

# --- API Clients ---
# Module-level clients are created lazily and reused, so the remote Spaces are
# not re-initialized on every call.
diacritization_client = None
transcription_client = None
syllable_transcription_client = None


def _connect(api_url, label):
    """Create a gradio_client ``Client`` for *api_url*.

    Returns None (after printing the error) if the client cannot be created.
    *label* only appears in the printed error message.
    """
    try:
        return Client(api_url, download_files=True)
    except Exception as e:  # boundary with a remote service: report, don't crash
        print(f"Error initializing {label} client: {e}")
        return None


def get_diacritization_client():
    """Return the shared diacritization client, creating it on first use."""
    global diacritization_client
    if diacritization_client is None:
        diacritization_client = _connect(DIACRITIZATION_API_URL, "diacritization")
    return diacritization_client


def get_transcription_client():
    """Return the shared transcription client, creating it on first use."""
    global transcription_client
    if transcription_client is None:
        transcription_client = _connect(TRANSCRIPTION_API_URL, "transcription")
    return transcription_client


def get_syllable_transcription_client():
    """Return the shared syllable transcription client, creating it on first use."""
    global syllable_transcription_client
    if syllable_transcription_client is None:
        syllable_transcription_client = _connect(
            SYLLABLE_TRANSCRIPTION_API_URL, "syllable transcription"
        )
    return syllable_transcription_client


# --- Helper Functions ---
def diacritize_text_api(text_to_diacritize):
    """Call the diacritization API.

    Returns a ``(display_text, reference_text)`` pair, as expected by the
    click handler. ``reference_text`` is the diacritized text on success and
    an empty string on any failure, so an error message is never stored as
    the reference for later scoring.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""

    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""

    try:
        result = client.predict(
            model_type="Encoder-Only",
            input_text=text_to_diacritize,
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""

    if result is None:
        return "Error: Empty response from diacritization service.", ""

    # Ensure the result is a string; it feeds both the textbox and the state.
    result_str = str(result)
    return result_str, result_str


def _request_transcript(client, audio_filepath, service, empty_message):
    """Send *audio_filepath* to *client* and return the transcript as a string.

    The response is expected to be either a string or a dict with a ``'text'``
    key. Any failure is reported as a string starting with ``"Error"``;
    *service* names the service in those messages and *empty_message* is
    returned when the service yields ``None`` as the transcript.
    """
    try:
        result = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during {service} API call: {e}")
        return f"Error during {service}: {e}"

    if isinstance(result, dict) and 'text' in result:
        transcript = result['text']
    elif isinstance(result, str):
        transcript = result
    else:
        return f"Error: Unexpected response format from {service} service: {type(result)}"

    if transcript is None:
        return empty_message
    # Callers treat the return value as a string (e.g. ``startswith``).
    return str(transcript)


def transcribe_audio_api(audio_filepath):
    """Call the standard transcription API and return the transcript or an error string."""
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"

    client = get_transcription_client()
    if not client:
        return "Error: Could not connect to the transcription service."

    return _request_transcript(
        client, audio_filepath, "transcription", "Error: Empty transcript received."
    )


def transcribe_syllable_audio_api(audio_filepath):
    """Call the syllable transcription API and return the transcript or an error string."""
    if not audio_filepath:
        return "Error: Audio file path missing for syllable transcription."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath} for syllable transcription."

    client = get_syllable_transcription_client()
    if not client:
        return "Error: Could not connect to the syllable transcription service."

    return _request_transcript(
        client,
        audio_filepath,
        "syllable transcription",
        "Error: Empty syllable transcript received.",
    )


def get_diacritics_sequence(text):
    """Extract the diacritics of *text*, in order, as a space-separated string.

    Separating the marks with spaces lets a word-level metric treat each
    diacritic as one token. Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""
    return ' '.join(c for c in text if c in ARABIC_DIACRITICS)


def calculate_metrics(reference, hypothesis):
    """Calculate WER, DER and CER of *hypothesis* against *reference*.

    DER is computed as the word error rate over the two diacritic sequences.

    Returns:
        ``(wer, der, cer)`` rounded to 4 decimals; ``(0.0, 0.0, 0.0)`` when
        both texts are blank; ``(1.0, 1.0, 1.0)`` when only the reference is
        blank; ``(None, None, None)`` if the calculation itself fails.
    """
    ref = reference or ""
    hyp = hypothesis or ""

    if not ref.strip():
        # Both blank: no error. Only the reference blank: maximum error.
        return (1.0, 1.0, 1.0) if hyp.strip() else (0.0, 0.0, 0.0)
    # A blank hypothesis with a non-blank reference is left to jiwer's rules.

    try:
        wer = jiwer.wer(ref, hyp)

        ref_d = get_diacritics_sequence(ref)
        hyp_d = get_diacritics_sequence(hyp)
        if ref_d.strip():
            der = jiwer.wer(ref_d, hyp_d)
        else:
            # No diacritics in the reference: any diacritic in the
            # hypothesis counts as maximum error, none as no error.
            der = 1.0 if hyp_d.strip() else 0.0

        cer = jiwer.cer(ref, hyp)
        return round(wer, 4), round(der, 4), round(cer, 4)
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return None, None, None  # Indicate error in calculation


def highlight_errors(reference, hypothesis):
    """Highlight word-level differences between reference and hypothesis.

    Words are compared with ``difflib.SequenceMatcher``. In the returned HTML
    (built from the hypothesis words), words that replace a reference word are
    wrapped in a red ``<mark>`` and inserted words in a green ``<mark>``.
    Reference words missing from the hypothesis are not shown in the HTML but
    are included in the error list. All words are HTML-escaped because the
    hypothesis comes from an external transcription service.

    Returns:
        ``(html_output, error_list)`` where ``error_list`` is a comma-separated,
        sorted list of the unique words involved in any difference.
    """
    ref_words = (reference or "").split()
    hyp_words = (hypothesis or "").split()

    if not ref_words and not hyp_words:
        return "", ""  # No errors if both are empty

    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)

    highlighted_hyp_words = []
    error_words = set()  # words deleted, replaced or inserted

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        hyp_chunk = hyp_words[j1:j2]
        if tag == 'equal':
            highlighted_hyp_words.extend(html.escape(word) for word in hyp_chunk)
        elif tag == 'replace':
            highlighted_hyp_words.extend(
                f'<mark style="{_REPLACED_STYLE}">{html.escape(word)}</mark>'
                for word in hyp_chunk
            )
            error_words.update(ref_words[i1:i2])
            error_words.update(hyp_chunk)
        elif tag == 'delete':
            # Missing words have nothing to show in the hypothesis; just note them.
            error_words.update(ref_words[i1:i2])
        elif tag == 'insert':
            highlighted_hyp_words.extend(
                f'<mark style="{_INSERTED_STYLE}">{html.escape(word)}</mark>'
                for word in hyp_chunk
            )
            error_words.update(hyp_chunk)

    html_output = ' '.join(highlighted_hyp_words)
    return html_output, ', '.join(sorted(error_words))


def process_audio_and_compare(audio_filepath, reference_text):
    """Transcribe the audio, compute metrics and highlight errors.

    Always returns 7 values, matching the output components of the
    "Transcribe and Compare" button: transcript, syllable transcript, WER,
    DER, CER, highlighted HTML and the list of error words.
    """
    # Validate inputs
    if not audio_filepath:
        return "Error: No audio provided.", "Error: No audio provided.", None, None, None, "", ""
    if not reference_text:
        return (
            "Error: No reference text found. Please diacritize first.",
            "Error: No reference text found.",
            None, None, None, "", "",
        )

    # --- Call Transcription APIs ---
    transcript = transcribe_audio_api(audio_filepath)
    # The syllable transcription is requested regardless of the first call's
    # outcome; its own error message, if any, is shown in its textbox.
    syllable_transcript = transcribe_syllable_audio_api(audio_filepath)

    # --- Calculate Metrics and Highlight Errors ---
    # The transcription helpers report failures as strings starting with "Error".
    if transcript.startswith("Error"):
        wer, der, cer = None, None, None
        html_output = "Highlighting not available due to transcription error."
        error_words = "N/A"
    else:
        wer, der, cer = calculate_metrics(reference_text, transcript)
        html_output, error_words = highlight_errors(reference_text, transcript)

    return transcript, syllable_transcript, wer, der, cer, html_output, error_words


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # Arabic Diacritization and Reading Assessment Tool
        1. Enter undiacritized Arabic text and click **Diacritize Text**.
        2. Read the generated **Diacritized Text** aloud and record or upload audio.
        3. Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
        """
    )

    # gr.State holds the diacritized reference text between the two steps
    reference_text_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            diacritize_btn = gr.Button("Diacritize Text")
            diacritized_output = gr.Textbox(
                label="Diacritized Text (Reference)",
                lines=3,
                interactive=False,  # User shouldn't edit this directly
                text_align="right",
            )
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Record or Upload Audio",
                type="filepath",
                sources=["microphone", "upload"],
            )
            transcribe_btn = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )
            transcript_syllables_output = gr.Textbox(
                label="Transcript Syllables (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )

    with gr.Row():
        wer_out = gr.Number(label="WER", interactive=False, precision=4)
        der_out = gr.Number(label="DER", interactive=False, precision=4)
        cer_out = gr.Number(label="CER", interactive=False, precision=4)

    error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
    error_list = gr.Textbox(label="Words Involved in Errors", interactive=False)

    # --- Event Handlers ---
    # Diacritize: output goes to the display box AND the hidden state
    diacritize_btn.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        outputs=[diacritized_output, reference_text_state],
    )

    # Transcribe: takes the audio path and the reference text from the state,
    # and updates all 7 output components
    transcribe_btn.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, reference_text_state],
        outputs=[
            transcript_output,
            transcript_syllables_output,
            wer_out,
            der_out,
            cer_out,
            error_html,
            error_list,
        ],
    )

# Launch the app
if __name__ == "__main__":
    app.launch(debug=True, share=True)  # share=True creates a public link