Bisher's picture
Update app.py
0570125 verified
raw
history blame
15.1 kB
import difflib
import html
import os
import time
import warnings

import gradio as gr
import jiwer
import pyarabic.araby as araby
from gradio_client import Client, handle_file
# jiwer emits UserWarnings when either side of a comparison is empty; those
# cases are handled explicitly in calculate_metrics, so silence the noise.
for _empty_warning_pattern in ("Reference is empty.*", "Hypothesis is empty.*"):
    warnings.filterwarnings("ignore", message=_empty_warning_pattern, category=UserWarning)
del _empty_warning_pattern

# --- Constants ---
# Remote Gradio apps (identifiers passed to gradio_client.Client below).
DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
SYLLABLE_TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"
# Arabic diacritic marks used for DER: the tanwin forms (FATHATAN, DAMMATAN,
# KASRATAN), the short vowels (FATHA, DAMMA, KASRA), SUKUN and SHADDA,
# i.e. the code points U+064B..U+0652 listed in the fallback set below.
# pyarabic itself is imported unconditionally at the top of the file, so a
# missing package fails there, not here. What this guard can actually catch is
# an attribute lookup failing on the araby module (AttributeError), or araby
# being unbound (NameError); ImportError is kept for backward compatibility.
try:
    ARABIC_DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except (ImportError, NameError, AttributeError):
    print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
    ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
# --- API Clients ---
# Module-level slots for lazily created gradio_client.Client instances; the
# get_*_client() helpers fill them on first successful use and reuse them.
diacritization_client = transcription_client = syllable_transcription_client = None
def get_diacritization_client():
    """Return the cached diacritization ``Client``, creating it on first call.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached on failure, so the next call tries again.
    """
    global diacritization_client
    if diacritization_client is not None:
        return diacritization_client
    try:
        diacritization_client = Client(DIACRITIZATION_API_URL, download_files=True)
    except Exception as exc:
        print(f"Error initializing diacritization client: {exc}")
        return None
    return diacritization_client
def get_transcription_client():
    """Return the cached transcription ``Client``, creating it on first call.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached on failure, so the next call tries again.
    """
    global transcription_client
    if transcription_client is not None:
        return transcription_client
    try:
        transcription_client = Client(TRANSCRIPTION_API_URL, download_files=True)
    except Exception as exc:
        print(f"Error initializing transcription client: {exc}")
        return None
    return transcription_client
def get_syllable_transcription_client():
    """Return the cached syllable-transcription ``Client``, creating it on first call.

    If construction raises, the error is printed and ``None`` is returned.
    Nothing is cached on failure, so the next call tries again.
    """
    global syllable_transcription_client
    if syllable_transcription_client is not None:
        return syllable_transcription_client
    try:
        syllable_transcription_client = Client(SYLLABLE_TRANSCRIPTION_API_URL, download_files=True)
    except Exception as exc:
        print(f"Error initializing syllable transcription client: {exc}")
        return None
    return syllable_transcription_client
# --- Helper Functions ---
def diacritize_text_api(text_to_diacritize):
    """Diacritize Arabic text through the remote diacritization service.

    Args:
        text_to_diacritize: Undiacritized text from the input textbox.

    Returns:
        A ``(display_text, reference_text)`` pair feeding the click handler's
        two outputs (the visible textbox and the reference ``gr.State``).
        On success both are the diacritized text. On any failure
        ``display_text`` is a user-facing message and ``reference_text`` is
        ``""``, so an error message is never stored as the reference.
    """
    if not text_to_diacritize or not text_to_diacritize.strip():
        return "Please enter some text to diacritize.", ""

    client = get_diacritization_client()
    if not client:
        return "Error: Could not connect to the diacritization service.", ""

    try:
        result = client.predict(
            model_type="Encoder-Only",
            input_text=text_to_diacritize,
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during diacritization API call: {e}")
        return f"Error during diacritization: {e}", ""

    if result is None:
        # Previously this message was also written to the state, and the
        # compare step then scored the recording against the error text.
        return "Error: Empty response from diacritization service.", ""

    result_str = str(result)
    return result_str, result_str
def _parse_transcript_result(result, service_name, empty_label):
    """Normalize a transcription ``predict`` response into a transcript or error string.

    Args:
        result: Raw response; a dict with a ``'text'`` key or a plain string
            is accepted, anything else is reported as an unexpected format.
        service_name: Service label used in messages, e.g. ``"transcription"``.
        empty_label: Noun used when the payload is ``None``, e.g. ``"transcript"``.

    Returns:
        The transcript string, or a message starting with ``"Error"``.
    """
    if isinstance(result, dict) and 'text' in result:
        transcript = result['text']
    elif isinstance(result, str):
        transcript = result
    else:
        return f"Error: Unexpected response format from {service_name} service: {type(result)}"
    if transcript is None:
        return f"Error: Empty {empty_label} received."
    # Callers apply string methods to the result (``startswith("Error")``), so a
    # non-string payload is coerced instead of crashing downstream.
    return transcript if isinstance(transcript, str) else str(transcript)


def _request_transcript(client, audio_filepath, service_name, empty_label):
    """Send ``audio_filepath`` to ``client``'s ``/predict`` endpoint and parse the reply.

    Exceptions raised by the call are printed and returned as an
    ``"Error during <service_name>: ..."`` message.
    """
    try:
        result = client.predict(
            audio=handle_file(audio_filepath),
            api_name="/predict",
        )
    except Exception as e:
        print(f"Error during {service_name} API call: {e}")
        return f"Error during {service_name}: {e}"
    return _parse_transcript_result(result, service_name, empty_label)


def transcribe_audio_api(audio_filepath):
    """Transcribe an audio file with the diacritic-level transcription service.

    Args:
        audio_filepath: Local path of the recording; falsy means no audio given.

    Returns:
        The transcript, or a message starting with ``"Error"`` describing
        which step failed (missing input, missing file, connection, call,
        or response format).
    """
    if not audio_filepath:
        return "Error: Please provide an audio recording or file."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"
    client = get_transcription_client()
    if not client:
        return "Error: Could not connect to the transcription service."
    return _request_transcript(client, audio_filepath, "transcription", "transcript")


def transcribe_syllable_audio_api(audio_filepath):
    """Transcribe an audio file with the syllable transcription service.

    Args:
        audio_filepath: Local path of the recording; falsy means no audio given.

    Returns:
        The syllable transcript, or a message starting with ``"Error"``.
    """
    if not audio_filepath:
        return "Error: Audio file path missing for syllable transcription."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath} for syllable transcription."
    client = get_syllable_transcription_client()
    if not client:
        return "Error: Could not connect to the syllable transcription service."
    return _request_transcript(
        client, audio_filepath, "syllable transcription", "syllable transcript"
    )
def get_diacritics_sequence(text):
    """Return the diacritic marks of ``text`` as a single space-separated string.

    Characters not in ``ARABIC_DIACRITICS`` (letters, spaces, punctuation) are
    dropped, and their order is preserved. Non-string input yields ``""``.
    """
    if isinstance(text, str):
        return ' '.join(ch for ch in text if ch in ARABIC_DIACRITICS)
    return ""
def calculate_metrics(reference, hypothesis):
    """Compute WER, DER and CER of ``hypothesis`` against ``reference``.

    DER is the WER between the space-joined diacritic sequences produced by
    ``get_diacritics_sequence``, so it compares diacritics only and ignores
    the base letters.

    Args:
        reference: Diacritized reference text; ``None`` is treated as "".
        hypothesis: Transcribed text; ``None`` is treated as "".

    Returns:
        ``(wer, der, cer)`` rounded to 4 decimals, or ``(None, None, None)``
        if computing them raises (the error is printed).
    """
    ref = reference or ""
    hyp = hypothesis or ""
    ref_blank = not ref.strip()
    hyp_blank = not hyp.strip()

    if ref_blank and hyp_blank:
        return 0.0, 0.0, 0.0  # nothing expected, nothing said
    if ref_blank:
        return 1.0, 1.0, 1.0  # nothing expected but something was said: max error

    try:
        ref_d = get_diacritics_sequence(ref)
        hyp_d = get_diacritics_sequence(hyp)
        if not ref_d.strip():
            der = 1.0 if hyp_d.strip() else 0.0
        elif not hyp_d.strip():
            der = 1.0  # every reference diacritic is a deletion
        else:
            der = jiwer.wer(ref_d, hyp_d)

        if hyp_blank:
            # Every reference word/character is a deletion, so WER = CER = 1.0
            # exactly. Returning it directly avoids depending on how a given
            # jiwer version treats an empty hypothesis (warn vs. raise).
            wer = cer = 1.0
        else:
            wer = jiwer.wer(ref, hyp)
            cer = jiwer.cer(ref, hyp)
        return round(wer, 4), round(der, 4), round(cer, 4)
    except Exception as e:
        # UI boundary: report the failure as missing metrics instead of crashing.
        print(f"Error calculating metrics: {e}")
        return None, None, None
def highlight_errors(reference, hypothesis):
    """Annotate word-level differences between a reference and a hypothesis.

    Whitespace-split words are aligned with ``difflib.SequenceMatcher``.
    In the HTML, substituted hypothesis words get a red ``<mark>`` and
    inserted words a green one. Reference words missing from the hypothesis
    do not appear in the HTML, but they are included in the error-word list.

    Args:
        reference: Diacritized reference text; ``None`` is treated as "".
        hypothesis: Transcript to annotate; ``None`` is treated as "".

    Returns:
        ``(html_output, error_words)``: the hypothesis words joined by spaces
        as HTML, each word HTML-escaped because the transcript comes from an
        external service and is rendered by ``gr.HTML``; and a comma-separated,
        sorted list of the distinct words involved in substitutions,
        deletions and insertions (plain text, not escaped).
    """
    ref_words = (reference or "").split()
    hyp_words = (hypothesis or "").split()
    if not ref_words and not hyp_words:
        return "", ""

    mark_colors = {'replace': '#ffcccb', 'insert': '#ccffcc'}
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
    highlighted = []
    error_words = set()
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        hyp_span = hyp_words[j1:j2]
        if tag == 'equal':
            highlighted.extend(html.escape(word) for word in hyp_span)
            continue
        # 'replace' and 'delete' consume reference words; 'replace' and
        # 'insert' consume hypothesis words. Empty slices cover the rest.
        error_words.update(ref_words[i1:i2])
        error_words.update(hyp_span)
        color = mark_colors.get(tag)
        if color is not None:
            highlighted.extend(
                f"<mark style='background-color: {color};'>{html.escape(word)}</mark>"
                for word in hyp_span
            )

    return ' '.join(highlighted), ', '.join(sorted(error_words))
# --- Comparison pipeline (wired to the "Transcribe and Compare" button) ---
def process_audio_and_compare(audio_filepath, reference_text):
    """Transcribe a recording and score it against the diacritized reference.

    Args:
        audio_filepath: Path from the ``gr.Audio`` input (``type="filepath"``),
            or a falsy value when no audio was provided.
        reference_text: Diacritized reference held in ``reference_text_state``.

    Returns:
        A 7-tuple matching the click handler outputs:
        ``(transcript, syllable_transcript, wer, der, cer, html_output, error_words)``.
        The metrics are ``None`` whenever they could not be computed.
    """
    if not audio_filepath:
        return "Error: No audio provided.", "Error: No audio provided.", None, None, None, "", ""
    # An "Error..." value in the state is a failed diacritization message,
    # not a reference to read; treat it the same as a missing reference.
    if not reference_text or reference_text.startswith("Error"):
        return (
            "Error: No reference text found. Please diacritize first.",
            "Error: No reference text found.",
            None, None, None, "", "",
        )

    transcript = transcribe_audio_api(audio_filepath)
    # The syllable transcript is requested even if the main transcription
    # failed; its own error message (if any) is shown as-is.
    syllable_transcript = transcribe_syllable_audio_api(audio_filepath)

    if transcript.startswith("Error"):
        return (
            transcript, syllable_transcript, None, None, None,
            "Highlighting not available due to transcription error.", "N/A",
        )

    wer, der, cer = calculate_metrics(reference_text, transcript)
    html_output, error_words = highlight_errors(reference_text, transcript)
    return transcript, syllable_transcript, wer, der, cer, html_output, error_words


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
"""
# Arabic Diacritization and Reading Assessment Tool
1. Enter undiacritized Arabic text and click **Diacritize Text**.
2. Read the generated **Diacritized Text** aloud and record or upload audio.
3. Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
"""
    )
    # Carries the diacritized reference from the diacritize step to the compare step.
    reference_text_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            diacritize_btn = gr.Button("Diacritize Text")
            diacritized_output = gr.Textbox(
                label="Diacritized Text (Reference)",
                lines=3,
                interactive=False,  # display only; the reference used for scoring is in the state
                text_align="right",
            )
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Record or Upload Audio", type="filepath", sources=["microphone", "upload"])
            transcribe_btn = gr.Button("Transcribe and Compare")
            transcript_output = gr.Textbox(
                label="Transcript (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )
            transcript_syllables_output = gr.Textbox(
                label="Transcript Syllables (Hypothesis)",
                lines=3,
                interactive=False,
                text_align="right",
            )

    with gr.Row():
        wer_out = gr.Number(label="WER", interactive=False, precision=4)
        der_out = gr.Number(label="DER", interactive=False, precision=4)
        cer_out = gr.Number(label="CER", interactive=False, precision=4)
    error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
    error_list = gr.Textbox(label="Words Involved in Errors", interactive=False)

    # --- Event Handlers ---
    # Diacritize: show the result and store it as the reference in the state.
    diacritize_btn.click(
        fn=diacritize_text_api,
        inputs=[text_input],
        outputs=[diacritized_output, reference_text_state],
    )
    # Compare: the output order must match process_audio_and_compare's 7-tuple.
    transcribe_btn.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, reference_text_state],
        outputs=[
            transcript_output,
            transcript_syllables_output,
            wer_out,
            der_out,
            cer_out,
            error_html,
            error_list,
        ],
    )
# Entry point: start the Gradio server when run as a script.
if __name__ == "__main__":
    # share=True requests a public share link in addition to the local server;
    # NOTE(review): confirm a public link is intended when running locally.
    # debug=True is Gradio's launch() debug mode (blocks the main thread and
    # surfaces errors in the console, per Blocks.launch docs).
    app.launch(share=True, debug=True)