Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from gradio_client import Client, handle_file | |
| import jiwer | |
| import os | |
| import time | |
| import warnings | |
| import pyarabic.araby as araby | |
| import difflib # Import difflib | |
# Suppress specific UserWarnings from jiwer related to empty strings
warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=UserWarning)

# --- Constants ---
# Space identifiers ("owner/name") passed to gradio_client.Client.
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"

# Arabic diacritics counted by get_diacritics_sequence (and therefore by DER).
# pyarabic's named constants are preferred. pyarabic itself is imported at the
# top of the module, so a missing package fails there, not here; this guard
# covers `araby` being unavailable (NameError) or a pyarabic build lacking one
# of these constants (AttributeError). The fallback lists the same code points:
# U+064B fathatan, U+064C dammatan, U+064D kasratan, U+064E fatha,
# U+064F damma, U+0650 kasra, U+0651 shadda, U+0652 sukun.
try:
    ARABIC_DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except (ImportError, NameError, AttributeError):
    print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
    ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
# --- API Clients ---
# Module-level cache for the three gradio_client.Client instances. Each starts
# unset and is filled in lazily by its get_*_client() accessor, so a client is
# reused across calls instead of being rebuilt every time.
diacritization_client = transcription_client = syllable_transcription_client = None
def get_diacritization_client():
    """Return the shared diacritization ``Client``, creating it on first use.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached in that case, so the next call tries again.
    """
    global diacritization_client
    if diacritization_client is not None:
        return diacritization_client
    try:
        diacritization_client = Client(DIACRITIZATION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing diacritization client: {err}")
        return None
    return diacritization_client
def get_transcription_client():
    """Return the shared word-level transcription ``Client``, creating it on first use.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached in that case, so the next call tries again.
    """
    global transcription_client
    if transcription_client is not None:
        return transcription_client
    try:
        transcription_client = Client(TRANSCRIPTION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing transcription client: {err}")
        return None
    return transcription_client
def get_syllable_transcription_client():
    """Return the shared syllable transcription ``Client``, creating it on first use.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached in that case, so the next call tries again.
    """
    global syllable_transcription_client
    if syllable_transcription_client is not None:
        return syllable_transcription_client
    try:
        syllable_transcription_client = Client(SYLLABLE_TRANSCRIPTION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing syllable transcription client: {err}")
        return None
    return syllable_transcription_client
| # --- Helper Functions --- | |
def diacritize_text_api(text_to_diacritize):
    """Call the diacritization API on *text_to_diacritize*.

    Returns:
        A ``(display_text, reference_state)`` pair for the click handler's two
        outputs (the visible textbox and the reference ``gr.State``). On
        success both hold the diacritized text. On any failure the first holds
        a message and the second is ``""``, so an error message is never
        stored as the reference the user is scored against.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""
    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""
    try:
        result = client.predict(
            model_type="Encoder-Only",
            input_text=text_to_diacritize,
            api_name="/predict",
        )
        if result is None:
            # Show the problem, but keep the reference empty: the compare step
            # would otherwise score the recording against this error text.
            return "Error: Empty response from diacritization service.", ""
        # Coerce non-string payloads so both outputs receive text.
        result_str = str(result)
        return result_str, result_str
    except Exception as e:
        print(f"Error during diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""
def transcribe_audio_api(audio_filepath):
    """Send *audio_filepath* to the word-level transcription API.

    Returns:
        The transcript string. Failures are reported as a message starting
        with ``"Error"``: missing path, nonexistent file, client creation
        failure, an exception during the call, an unexpected response shape,
        or a ``None`` transcript.
    """
    # Validate the input before touching the network.
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"

    service = get_transcription_client()
    if not service:
        return "Error: Could not connect to the transcription service."

    try:
        response = service.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
        # Accept either a bare string or a {"text": ...} payload.
        if isinstance(response, str):
            text = response
        elif isinstance(response, dict) and 'text' in response:
            text = response['text']
        else:
            return f"Error: Unexpected response format from transcription service: {type(response)}"
        return "Error: Empty transcript received." if text is None else text
    except Exception as exc:
        print(f"Error during transcription API call: {exc}")
        return f"Error during transcription: {exc}"
def transcribe_syllable_audio_api(audio_filepath):
    """Send *audio_filepath* to the syllable-level transcription API.

    Returns:
        The syllable transcript string. Failures are reported as a message
        starting with ``"Error"``: missing path, nonexistent file, client
        creation failure, an exception during the call, an unexpected
        response shape, or a ``None`` transcript.
    """
    if not audio_filepath:
        # The caller normally checks this first; kept so the function is safe on its own.
        return "Error: Audio file path missing for syllable transcription."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath} for syllable transcription."

    service = get_syllable_transcription_client()
    if not service:
        return "Error: Could not connect to the syllable transcription service."

    try:
        response = service.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
        # Accept either a bare string or a {"text": ...} payload.
        if isinstance(response, str):
            text = response
        elif isinstance(response, dict) and 'text' in response:
            text = response['text']
        else:
            return f"Error: Unexpected response format from syllable transcription service: {type(response)}"
        return "Error: Empty syllable transcript received." if text is None else text
    except Exception as exc:
        print(f"Error during syllable transcription API call: {exc}")
        return f"Error during syllable transcription: {exc}"
def get_diacritics_sequence(text):
    """Return the diacritics in *text*, in order, as space-separated tokens.

    Each character found in ``ARABIC_DIACRITICS`` becomes its own token. This
    lets a word-level scorer (``jiwer.wer``) measure diacritic errors.
    Non-string input yields ``""``.
    """
    if isinstance(text, str):
        return ' '.join(mark for mark in text if mark in ARABIC_DIACRITICS)
    return ""
def calculate_metrics(reference, hypothesis):
    """Score *hypothesis* against *reference* with WER, DER and CER.

    DER here is the word error rate of the space-joined diacritic sequences
    (see get_diacritics_sequence). A blank reference short-circuits: the
    result is all zeros if the hypothesis is blank too, otherwise all ones.

    Returns:
        ``(wer, der, cer)``, each rounded to 4 decimals, or
        ``(None, None, None)`` if scoring raises.
    """
    ref = reference or ""
    hyp = hypothesis or ""
    if not ref.strip():
        # Nothing to read: perfect if nothing was said, otherwise maximal error.
        return (0.0, 0.0, 0.0) if not hyp.strip() else (1.0, 1.0, 1.0)
    # A blank hypothesis is left to jiwer's own rules.
    try:
        word_rate = jiwer.wer(ref, hyp)
        ref_marks = get_diacritics_sequence(ref)
        hyp_marks = get_diacritics_sequence(hyp)
        if ref_marks.strip():
            diac_rate = jiwer.wer(ref_marks, hyp_marks)
        else:
            # No reference diacritics: any hypothesis diacritic is a full error.
            diac_rate = 1.0 if hyp_marks.strip() else 0.0
        char_rate = jiwer.cer(ref, hyp)
        return round(word_rate, 4), round(diac_rate, 4), round(char_rate, 4)
    except Exception as err:
        print(f"Error calculating metrics: {err}")
        return None, None, None
def highlight_errors(reference, hypothesis):
    """Highlight word-level differences between *reference* and *hypothesis*.

    Words are split on whitespace and aligned with ``difflib.SequenceMatcher``.
    In the returned HTML:

    - hypothesis words that replace reference words get a red ``<mark>``;
    - inserted hypothesis words get a green ``<mark>``;
    - matching words are left plain.

    Deleted reference words have no hypothesis counterpart and so do not
    appear in the HTML, but they are included in the word list.

    Every emitted word is HTML-escaped. The hypothesis is a transcript from an
    external service and the HTML is rendered as markup, so a word containing
    ``<``, ``>`` or ``&`` must not be interpreted as tags or entities.

    Returns:
        ``(html, words)``. *words* is a comma-separated, sorted list of the
        distinct reference/hypothesis words involved in any error. Both are
        ``""`` when both inputs are empty.
    """
    from html import escape  # stdlib; imported locally

    ref_words = (reference or "").split()
    hyp_words = (hypothesis or "").split()
    if not ref_words and not hyp_words:
        return "", ""

    replaced_style = "background-color: #ffcccb;"
    inserted_style = "background-color: #ccffcc;"

    def _mark(word, style):
        # Wrap one escaped word in a colored <mark> tag.
        return f"<mark style='{style}'>{escape(word)}</mark>"

    pieces = []
    ref_errors = []  # reference words that were deleted or replaced
    hyp_errors = []  # hypothesis words that were inserted or replaced
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.extend(escape(word) for word in hyp_words[j1:j2])
        elif tag == "replace":
            pieces.extend(_mark(word, replaced_style) for word in hyp_words[j1:j2])
            ref_errors.extend(ref_words[i1:i2])
            hyp_errors.extend(hyp_words[j1:j2])
        elif tag == "delete":
            ref_errors.extend(ref_words[i1:i2])
        elif tag == "insert":
            pieces.extend(_mark(word, inserted_style) for word in hyp_words[j1:j2])
            hyp_errors.extend(hyp_words[j1:j2])

    involved = sorted(set(ref_errors) | set(hyp_errors))
    return " ".join(pieces), ", ".join(involved)
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # Arabic Diacritization and Reading Assessment Tool
        1. Enter undiacritized Arabic text and click **Diacritize Text**.
        2. Read the generated **Diacritized Text** aloud and record or upload audio.
        3. Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
        """
    )
    # Holds the diacritized reference text between the Diacritize step and the
    # Transcribe-and-Compare step.
    reference_text_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            diacritize_btn = gr.Button("Diacritize Text")
            diacritized_output = gr.Textbox(
                label="Diacritized Text (Reference)",
                lines=3,
                interactive=False,  # read-only: this is the text the user reads aloud
                text_align="right",
            )
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Record or Upload Audio", type="filepath", sources=["microphone", "upload"])
            transcribe_btn = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )
            transcript_syllables_output = gr.Textbox(
                label="Transcript Syllables (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )

    with gr.Row():
        wer_out = gr.Number(label="WER", interactive=False, precision=4)
        der_out = gr.Number(label="DER", interactive=False, precision=4)
        cer_out = gr.Number(label="CER", interactive=False, precision=4)

    # Rendered as HTML: highlight_errors wraps erroneous words in <mark> tags.
    error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
    error_list = gr.Textbox(label="Words Involved in Errors", interactive=False)

    # --- Event Handlers ---
    # diacritize_text_api returns (display text, reference). The reference goes
    # into reference_text_state so the compare step can read it.
    diacritize_btn.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        outputs=[diacritized_output, reference_text_state],
    )

    def process_audio_and_compare(audio_filepath, reference_text):
        """Transcribe the recording and score it against the diacritized reference.

        Args:
            audio_filepath: Path from the audio component, or a falsy value if
                nothing was recorded/uploaded.
            reference_text: Diacritized reference read from reference_text_state.

        Returns:
            A 7-tuple matching the transcribe_btn outputs:
            ``(transcript, syllable_transcript, wer, der, cer, html, error_words)``.
            The metrics are ``None`` whenever scoring is not possible.
        """
        if not audio_filepath:
            return "Error: No audio provided.", "Error: No audio provided.", None, None, None, "", ""
        if not reference_text:
            return (
                "Error: No reference text found. Please diacritize first.",
                "Error: No reference text found.",
                None, None, None, "", "",
            )

        transcript = transcribe_audio_api(audio_filepath)
        # The syllable transcript is requested even if the word-level call
        # failed; its own error message, if any, is displayed as-is.
        syllable_transcript = transcribe_syllable_audio_api(audio_filepath)

        # transcribe_audio_api signals failure with an "Error..." string. A
        # non-string payload (e.g. a non-text "text" field) cannot be scored either.
        if not isinstance(transcript, str) or transcript.startswith("Error"):
            return (
                transcript, syllable_transcript, None, None, None,
                "Highlighting not available due to transcription error.", "N/A",
            )

        wer, der, cer = calculate_metrics(reference_text, transcript)
        html_output, error_words = highlight_errors(reference_text, transcript)
        return transcript, syllable_transcript, wer, der, cer, html_output, error_words

    transcribe_btn.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, reference_text_state],
        # Order must match the 7-tuple returned by process_audio_and_compare.
        outputs=[
            transcript_output,
            transcript_syllables_output,
            wer_out,
            der_out,
            cer_out,
            error_html,
            error_list,
        ],
    )
# Launch the app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    # share=True additionally requests a public link; drop it to serve locally only.
    app.launch(share=True, debug=True)