# Hugging Face Space file: app.py (author: Bisher)
# Last commit: "Update app.py" (33b395c, verified) — 11.7 kB
import os
import time
import unicodedata
import warnings

import gradio as gr
import jiwer
from gradio_client import Client, handle_file
# jiwer can emit UserWarnings about empty reference/hypothesis strings; the
# empty cases are handled explicitly in calculate_metrics, so silence just those.
for _ignored_message in ("Reference is empty.*", "Hypothesis is empty.*"):
    warnings.filterwarnings("ignore", message=_ignored_message, category=UserWarning)
del _ignored_message

# --- Constants ---
# Hugging Face Space ids handed verbatim to gradio_client.Client. Keep each
# string exactly as-is: it must match the remote Space id character for character.
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
# --- Gradio API Clients ---
# A fresh client is built on every request instead of sharing one module-level
# instance. That keeps requests independent of any stale state in long-running
# deployments, at the cost of re-initializing the client each time.
def get_diacritization_client():
    """Build a gradio_client ``Client`` for the diacritization Space.

    Returns:
        Client | None: the new client, or None if construction raised. The
        exception is printed rather than propagated so the caller can show a
        user-facing error message instead.
    """
    client = None
    try:
        client = Client(DIACRITIZATION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing diacritization client: {err}")
    return client
def get_transcription_client():
    """Build a gradio_client ``Client`` for the audio-transcription Space.

    Returns:
        Client | None: the new client, or None if construction raised. The
        exception is printed rather than propagated so the caller can show a
        user-facing error message instead.
    """
    client = None
    try:
        client = Client(TRANSCRIPTION_API_URL, download_files=True)
    except Exception as err:
        print(f"Error initializing transcription client: {err}")
    return client
# --- Helper Functions ---
def diacritize_text_api(text_to_diacritize):
    """
    Call the diacritization Space and return its output for the UI.

    Args:
        text_to_diacritize (str): The undiacritized Arabic text.

    Returns:
        tuple[str, str]: ``(display_text, reference_state)``.
            On success, both elements are the diacritized text.
            On any failure, the first element is a user-facing message and
            the second is ``""``. This clears the stored reference, so a later
            comparison can never score a transcript against an error message.
            Not every message starts with "Error:", e.g. the empty-input prompt
            and the "Error during diacritization" message.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""

    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""

    try:
        print(f"Sending text to diacritization API: {text_to_diacritize}")
        result = client.predict(
            model_type="Encoder-Only",  # Or 'Encoder-Decoder' if preferred
            input_text=text_to_diacritize,
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during text diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""

    print(f"Received diacritized text: {result}")
    if result is None:
        return "Error: Received empty response from diacritization service.", ""

    # The UI textbox and the State both expect a string.
    result_str = str(result)
    return result_str, result_str
def transcribe_audio_api(audio_filepath):
    """
    Send an audio file to the transcription Space and return the transcript.

    Args:
        audio_filepath (str): Path to the recorded/uploaded audio file.

    Returns:
        str: The diacritized transcript. On any failure, returns a message
        starting with "Error:" so callers can detect failures by that prefix.
    """
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    # Reject missing paths and non-files (e.g. directories) before any network call.
    if not os.path.isfile(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"

    client = get_transcription_client()
    if not client:
        return "Error: Could not connect to the transcription service."

    try:
        print(f"Sending audio file to transcription API: {audio_filepath}")
        # handle_file wraps the local path so gradio_client uploads it.
        result = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during audio transcription API call: {e}")
        # Keep the exact "Error:" prefix: process_audio_and_compare checks for
        # it. The old "Error during transcription: ..." text slipped past that
        # check and was scored as if it were a transcript.
        return f"Error: Transcription failed: {e}"

    print(f"Received transcript: {result}")
    # Accept either a bare string or a dict carrying a 'text' key; the exact
    # response shape of the remote Space is not pinned down here.
    if isinstance(result, dict) and "text" in result:
        transcript = result["text"]
    elif isinstance(result, str):
        transcript = result
    else:
        print(f"Unexpected transcription result format: {result}")
        return "Error: Unexpected format received from transcription service."

    if transcript is None:
        return "Error: Received empty response from transcription service."
    return str(transcript)
def calculate_metrics(reference, hypothesis):
    """
    Compute Word Error Rate (WER) and this app's "Diacritic Error Rate" (DER).

    The "DER" here is a character-level error rate. Each character of the
    text, base letters and diacritics alike, is scored as a token. It is not
    a diacritic-only metric.

    Both texts are first NFC-normalized. Canonical ordering of combining marks
    makes equivalent spellings compare equal, e.g. shadda+fatha versus
    fatha+shadda on the same letter.

    Args:
        reference (str): The original diacritized text. Non-strings are
            treated as "" (with a printed error).
        hypothesis (str): The diacritized transcript from the audio.
            Non-strings are treated as "" (with a printed error).

    Returns:
        tuple: ``(wer, der)`` rounded to 4 decimals, or ``(None, None)`` if
        jiwer raises.
    """
    # Coerce invalid inputs to "" so the logic below only deals with strings.
    if not isinstance(reference, str):
        print(f"Error: Reference input is not a string (type: {type(reference)}). Value: {reference}")
        reference = ""
    if not isinstance(hypothesis, str):
        print(f"Error: Hypothesis input is not a string (type: {type(hypothesis)}). Value: {hypothesis}")
        hypothesis = ""

    reference = unicodedata.normalize("NFC", reference)
    hypothesis = unicodedata.normalize("NFC", hypothesis)
    ref_strip = reference.strip()
    hyp_strip = hypothesis.strip()

    if not ref_strip:
        if not hyp_strip:
            return 0.0, 0.0  # Both empty: nothing to get wrong.
        print("Warning: Reference text is empty.")
        # No reference tokens to normalize by; report maximal error.
        return 1.0, 1.0
    if ref_strip == hyp_strip:
        return 0.0, 0.0  # Identical text: no edits at any granularity.
    if not hyp_strip:
        # Every reference token is a deletion, so both rates are exactly 1.0.
        # Computed directly rather than relying on jiwer's empty-hypothesis
        # handling, which may vary between versions.
        return 1.0, 1.0

    try:
        wer = jiwer.wer(reference, hypothesis)
        # Space-separate every character so jiwer treats each one as a "word".
        # Whitespace already in the text becomes extra separators and is
        # presumably not scored as a character (verify against the installed
        # jiwer's default transform).
        der = jiwer.wer(" ".join(reference), " ".join(hypothesis))
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return None, None
    return round(wer, 4), round(der, 4)
def process_audio_and_compare(audio_input, original_diacritized_text):
    """
    Transcribe the recording and score it against the stored reference.

    Args:
        audio_input (str | None): Filepath supplied by the gr.Audio component.
        original_diacritized_text (str): Value of the reference State, as set
            by the Diacritize step.

    Returns:
        tuple: (transcript, wer, der)
            transcript (str): The transcribed text or an error message.
            wer (float | None): Word Error Rate, or None on error.
            der (float | None): Diacritic Error Rate, or None on error.
    """
    print("Processing audio and comparing...")

    # The State can hold a user-facing message instead of real reference text.
    # Failure messages start with "Error" but not always "Error:" (e.g.
    # "Error during diacritization: ..."), and the empty-input prompt has no
    # "Error" prefix at all. Reject all of these, plus blank text.
    reference_ok = (
        isinstance(original_diacritized_text, str)
        and bool(original_diacritized_text.strip())
        and not original_diacritized_text.startswith(("Error", "Please enter some text"))
    )
    if not reference_ok:
        error_msg = "Error: Valid reference diacritized text not available. Please diacritize text first."
        print(error_msg)
        return error_msg, None, None

    # --- 1. Transcribe Audio ---
    transcript = transcribe_audio_api(audio_input)
    # Prefix "Error" (not just "Error:") so an "Error during ..." message is
    # never scored as if it were a transcript.
    if not isinstance(transcript, str) or transcript.startswith("Error"):
        error_msg = transcript if isinstance(transcript, str) else "Error: Transcription failed with non-string output."
        print(error_msg)
        return error_msg, None, None

    # --- 2. Calculate Metrics ---
    wer, der = calculate_metrics(original_diacritized_text, transcript)
    if wer is None or der is None:
        print("Metrics calculation failed.")
        # Still show the transcript even though scoring failed.
        return transcript, None, None

    print(f"Comparison complete. WER: {wer}, DER: {der}")
    return transcript, wer, der
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        "# Arabic Diacritization and Reading Assessment Tool\n"
        "1. Enter undiacritized Arabic text and click **Diacritize Text**.\n"
        "2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.\n"
        "3. Click **Transcribe and Compare** to get the transcript and see the WER/DER scores compared to the original diacritized text.\n"
    )

    # Per-session reference text written by the Diacritize step and read by
    # the comparison step.
    original_diacritized_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="1. Enter Undiacritized Arabic Text",
                # "Example: peace be upon you" (the previous literal was
                # mis-decoded UTF-8 and rendered as gibberish).
                placeholder="مثال: السلام عليكم",
                lines=3,
                text_align="right",
                rtl=True,  # Right-to-left caret/direction for Arabic.
            )
            diacritize_button = gr.Button("Diacritize Text")
            diacritized_text_output = gr.Textbox(
                label="2. Diacritized Text (Reference)",
                lines=3,
                interactive=False,  # Reference must not be edited by the user.
                text_align="right",
                rtl=True,
            )
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",  # Handler receives the saved file's path.
                label="3. Record or Upload Audio of Reading Diacritized Text",
            )
            transcribe_button = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="4. Diacritized Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
                rtl=True,
            )

    with gr.Row():
        wer_output = gr.Number(label="Word Error Rate (WER)", interactive=False, precision=4)
        der_output = gr.Number(label="Diacritic Error Rate (DER)", interactive=False, precision=4)

    # --- Connect Components ---
    # diacritize_text_api returns (display_text, state_value).
    diacritize_button.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        outputs=[diacritized_text_output, original_diacritized_state],
    )
    # process_audio_and_compare returns (transcript, wer, der).
    transcribe_button.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, original_diacritized_state],
        outputs=[transcript_output, wer_output, der_output],
    )

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    # share=True requests a public link for local runs (presumably ignored on
    # Hugging Face Spaces).
    app.launch(debug=True, share=True)