Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from gradio_client import Client, handle_file | |
| import jiwer | |
| import os | |
| import time | |
| import warnings | |
# Silence jiwer's UserWarnings about empty reference / hypothesis strings.
for _empty_input_pattern in ("Reference is empty.*", "Hypothesis is empty.*"):
    warnings.filterwarnings("ignore", message=_empty_input_pattern, category=UserWarning)
del _empty_input_pattern

# --- Constants ---
# Hugging Face Space identifiers of the two remote Gradio apps this tool calls.
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
# --- Gradio API Clients ---
# A fresh Client is created on every call. That costs a round-trip to the
# remote Space per request, but avoids holding on to a client whose Space may
# have restarted or gone to sleep between requests.


def _create_client(api_url, service_name):
    """Connect to a remote Gradio app; return its Client, or None on failure.

    Args:
        api_url (str): Hugging Face Space identifier ("owner/space").
        service_name (str): Human-readable service name used in the log message.

    Returns:
        Client | None: The connected client, or None if construction raised.
    """
    try:
        # NOTE: no request timeout is configured here.
        # NOTE(review): the accepted type of `download_files` differs across
        # gradio_client versions (bool vs. directory path) — confirm for the
        # pinned version.
        return Client(api_url, download_files=True)
    except Exception as e:
        # Best-effort: callers check for None and show a user-facing error.
        print(f"Error initializing {service_name} client: {e}")
        return None


def get_diacritization_client():
    """Initializes and returns the client for the text diacritization API (None on failure)."""
    return _create_client(DIACRITIZATION_API_URL, "diacritization")


def get_transcription_client():
    """Initializes and returns the client for the audio transcription API (None on failure)."""
    return _create_client(TRANSCRIPTION_API_URL, "transcription")
# --- Helper Functions ---
def diacritize_text_api(text_to_diacritize):
    """
    Calls the Hugging Face space to diacritize the input text.

    Args:
        text_to_diacritize (str): The undiacritized Arabic text.

    Returns:
        tuple: (display_text, reference_text)
            display_text (str): The diacritized text, or a user-facing
                message explaining why diacritization failed. This is shown
                in the output textbox.
            reference_text (str): The diacritized text to keep in app state
                as the comparison reference, or "" when diacritization
                failed. An error/prompt message must never be stored here,
                because it would later be scored as if it were the reference.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""
    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""
    try:
        print(f"Sending text to diacritization API: {text_to_diacritize}")
        result = client.predict(
            model_type="Encoder-Only",  # Or 'Encoder-Decoder' if preferred
            input_text=text_to_diacritize,
            api_name="/predict",
        )
        print(f"Received diacritized text: {result}")
        if result is None:
            return "Error: Received empty response from diacritization service.", ""
        result_str = str(result)
        # Success: the same text is displayed and stored as the reference.
        return result_str, result_str
    except Exception as e:
        print(f"Error during text diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""
def transcribe_audio_api(audio_filepath):
    """
    Calls the Hugging Face space to transcribe and diacritize the input audio.

    Args:
        audio_filepath (str): Path to a local audio file, as provided by a
            Gradio Audio component with ``type="filepath"``.

    Returns:
        str: The diacritized transcript, or an error message. Every failure
            path returns a message starting with "Error:" so callers can tell
            a failure apart from a transcript.
    """
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    # Reject missing paths and non-files (e.g. directories) before contacting the service.
    if not os.path.isfile(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"
    client = get_transcription_client()
    if not client:
        return "Error: Could not connect to the transcription service."
    try:
        print(f"Sending audio file to transcription API: {audio_filepath}")
        # handle_file wraps the local path so the client uploads it with the request.
        result = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
        print(f"Received transcript: {result}")
        # Accept either the text itself or a dict carrying it under "text".
        # NOTE(review): the remote app's exact return shape is not visible here — confirm.
        if isinstance(result, dict) and "text" in result:
            transcript = result["text"]
        elif isinstance(result, str):
            transcript = result
        else:
            print(f"Unexpected transcription result format: {result}")
            return "Error: Unexpected format received from transcription service."
        if transcript is None:
            return "Error: Received empty response from transcription service."
        return str(transcript)
    except Exception as e:
        print(f"Error during audio transcription API call: {e}")
        # Keep the "Error:" prefix: a message like "Error during ..." does not
        # pass a startswith("Error:") check and would be scored as a transcript.
        return f"Error: Transcription failed: {e}"
def calculate_metrics(reference, hypothesis):
    """
    Calculates Word Error Rate (WER) and a character-level "Diacritic Error Rate" (DER).

    DER here is the edit-distance error rate over individual non-whitespace
    characters: base letters and diacritic marks alike. It therefore also
    counts letter errors, not only diacritic errors.

    Args:
        reference (str): The original diacritized text.
        hypothesis (str): The diacritized transcript from the audio.

    Returns:
        tuple: (wer, der) rounded to 4 decimals, or (None, None) if the jiwer
            calculation raises. Non-string inputs are logged and treated as "".
            An empty reference with a non-empty hypothesis is reported as
            (1.0, 1.0). So is a non-empty reference with an empty hypothesis.
    """
    # Ensure inputs are strings before proceeding
    if not isinstance(reference, str):
        print(f"Error: Reference input is not a string (type: {type(reference)}). Value: {reference}")
        reference = ""  # Default to empty string to avoid downstream errors
    if not isinstance(hypothesis, str):
        print(f"Error: Hypothesis input is not a string (type: {type(hypothesis)}). Value: {hypothesis}")
        hypothesis = ""  # Default to empty string

    # Collapse every whitespace run (including newlines from multi-line
    # textboxes) into a single space. jiwer's default transform splits words
    # on " " only, so "word1\nword2" would otherwise count as one word.
    ref_norm = " ".join(reference.split())
    hyp_norm = " ".join(hypothesis.split())

    if not ref_norm and not hyp_norm:
        return 0.0, 0.0  # Both empty, 0% error
    if not ref_norm:
        print("Warning: Reference text is empty.")
        # The error rate is undefined for an empty reference; report a full error.
        return 1.0, 1.0
    if not hyp_norm:
        # Every reference token is a deletion, so both rates are exactly 1.0.
        # Handled here instead of relying on how the installed jiwer version
        # treats an empty hypothesis.
        return 1.0, 1.0

    try:
        # 1. Word Error Rate over whitespace-separated words.
        wer = jiwer.wer(ref_norm, hyp_norm)
        # 2. Character-level rate: make each non-whitespace character its own
        #    "word" so jiwer's word-level edit distance works per character.
        ref_chars = " ".join(ref_norm.replace(" ", ""))
        hyp_chars = " ".join(hyp_norm.replace(" ", ""))
        der = jiwer.wer(ref_chars, hyp_chars)
        return round(wer, 4), round(der, 4)
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return None, None
def process_audio_and_compare(audio_input, original_diacritized_text):
    """
    Main function triggered after audio input.
    Transcribes audio, calculates metrics, and returns results.

    Args:
        audio_input (str | None): Audio file path from the Gradio audio component.
        original_diacritized_text (str): Reference diacritized text held in app state.

    Returns:
        tuple: (transcript, wer, der)
            transcript (str): The transcribed text or an error message.
            wer (float | None): Word Error Rate or None if error.
            der (float | None): Diacritic Error Rate or None if error.
    """
    print("Processing audio and comparing...")
    # Reject a missing or blank reference, or one that is really an error
    # message. Matching "Error" (not only "Error:") also rejects messages such
    # as "Error during diacritization: ...".
    if (
        not isinstance(original_diacritized_text, str)
        or not original_diacritized_text.strip()
        or original_diacritized_text.startswith("Error")
    ):
        error_msg = "Error: Valid reference diacritized text not available. Please diacritize text first."
        print(error_msg)
        return error_msg, None, None

    # --- 1. Transcribe Audio ---
    transcript = transcribe_audio_api(audio_input)
    if not isinstance(transcript, str) or transcript.startswith("Error"):
        # Transcription failed: surface the error and leave the metrics empty.
        error_msg = transcript if isinstance(transcript, str) else "Error: Transcription failed with non-string output."
        print(error_msg)
        return error_msg, None, None

    # --- 2. Calculate Metrics ---
    wer, der = calculate_metrics(original_diacritized_text, transcript)
    if wer is None or der is None:
        print("Metrics calculation failed.")
        # Still show the transcript, but signal that the metrics are unavailable.
        return transcript, None, None
    print(f"Comparison complete. WER: {wer}, DER: {der}")
    return transcript, wer, der
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
# Arabic Diacritization and Reading Assessment Tool
1. Enter undiacritized Arabic text and click **Diacritize Text**.
2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.
3. Click **Transcribe and Compare** to get the transcript and see the WER/DER scores compared to the original diacritized text.
"""
    )
    # Holds the reference diacritized text between the two button clicks.
    original_diacritized_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="1. Enter Undiacritized Arabic Text",
                # "Example: Peace be upon you" (was mojibake: UTF-8 decoded as cp874).
                placeholder="مثال: السلام عليكم",
                lines=3,
                text_align="right",  # Align text right for Arabic
            )
            diacritize_button = gr.Button("Diacritize Text")
            diacritized_text_output = gr.Textbox(
                label="2. Diacritized Text (Reference)",
                lines=3,
                interactive=False,  # User shouldn't edit this directly
                text_align="right",
            )
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",  # Handlers receive the path of the saved audio file
                label="3. Record or Upload Audio of Reading Diacritized Text",
            )
            transcribe_button = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="4. Diacritized Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )

    with gr.Row():
        # Metrics are displayed with 4 decimal places.
        wer_output = gr.Number(label="Word Error Rate (WER)", interactive=False, precision=4)
        der_output = gr.Number(label="Diacritic Error Rate (DER)", interactive=False, precision=4)

    # --- Connect Components ---
    # Diacritize: fills the visible textbox and the reference state in one call.
    diacritize_button.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        outputs=[diacritized_text_output, original_diacritized_state],
    )
    # Transcribe and compare the recording against the stored reference.
    transcribe_button.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, original_diacritized_state],
        outputs=[transcript_output, wer_output, der_output],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    app.launch(debug=True, share=True)