|
|
import os |
|
|
import time |
|
|
import streamlit as st |
|
|
from transformers import pipeline |
|
|
from pydub import AudioSegment |
|
|
import tempfile |
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
import jiwer |
|
|
import librosa |
|
|
import soundfile |
|
|
|
|
|
|
|
|
# Configure the browser tab (title and icon) and use the wide page layout.
st.set_page_config(
    page_title="Audio-to-Text with Grammar Check",
    page_icon="🎤",
    layout="wide",
)
|
|
|
|
|
|
|
|
# Registry of Hugging Face checkpoints, keyed first by pipeline task and then
# by the short name used when calling load_model().
MODELS = {
    # Whisper checkpoints; the insertion order (tiny, small, base) is the
    # order in which the options are listed in the evaluator's select boxes.
    "automatic-speech-recognition": {
        f"whisper-{size}": f"openai/whisper-{size}"
        for size in ("tiny", "small", "base")
    },
    # NOTE: the short key says "flan-t5-base" but the checkpoint it maps to is
    # "pszemraj/grammar-synthesis-small".
    "text2text-generation": dict(
        [("flan-t5-base", "pszemraj/grammar-synthesis-small")]
    ),
}
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(model_key, task):
    """Build the transformers pipeline registered under ``MODELS[task][model_key]``.

    The function is wrapped in ``st.cache_resource``, so the pipeline object
    is created once per ``(model_key, task)`` pair and reused afterwards.

    Args:
        model_key: Short model name, a key of ``MODELS[task]``.
        task: Pipeline task name, a key of ``MODELS``.

    Returns:
        The pipeline returned by ``transformers.pipeline``.

    Raises:
        KeyError: If ``task`` or ``model_key`` is not present in ``MODELS``.
    """
    # Use the GPU whenever torch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    with st.spinner(f"Loading {model_key} model..."):
        checkpoint = MODELS[task][model_key]
        loaded_pipeline = pipeline(task, model=checkpoint, device=device)
        return loaded_pipeline
|
|
|
|
|
def convert_audio_to_wav(audio_file):
    """Convert uploaded audio to a temporary WAV file.

    Args:
        audio_file: Anything ``AudioSegment.from_file`` accepts (here, the
            object returned by the Streamlit file uploader).

    Returns:
        The path of the newly written ``.wav`` file, or ``None`` when the
        conversion failed (the error is reported through ``st.error``).
        On success the caller owns the file and must delete it.
    """
    wav_path = None
    try:
        # Only reserve a unique path here; the handle is closed before the
        # export so that a single writer touches the file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            wav_path = tmp_file.name

        audio = AudioSegment.from_file(audio_file)
        # export() hands back the file object it wrote to; close it so the
        # handle is not leaked.
        exported = audio.export(wav_path, format="wav")
        if hasattr(exported, "close"):
            exported.close()
        return wav_path
    except Exception as e:
        # The temp file was created with delete=False, so it must be removed
        # by hand when the conversion does not complete.
        if wav_path is not None:
            try:
                os.unlink(wav_path)
            except OSError:
                pass  # best-effort cleanup; the conversion error is what matters
        st.error(f"Audio conversion failed: {str(e)}")
        return None
|
|
|
|
|
def evaluate_asr_accuracy(transcription, reference):
    """Score a transcription against a reference as word/character accuracy.

    Both texts are lower-cased and stripped of surrounding whitespace before
    being compared with ``jiwer``. Each score is ``1 - error_rate``.

    Args:
        transcription: Hypothesis text produced by the ASR model.
        reference: Ground-truth text.

    Returns:
        A ``(word_accuracy, char_accuracy)`` tuple of floats in ``[0.0, 1.0]``.
        ``(0.0, 0.0)`` is returned when the reference is empty after
        stripping.
    """
    ref_processed = reference.lower().strip()
    hyp_processed = transcription.lower().strip()

    # No reference text means there is nothing to score against.
    if not ref_processed:
        return 0.0, 0.0

    wer = jiwer.wer(ref_processed, hyp_processed)
    cer = jiwer.cer(ref_processed, hyp_processed)

    # Error rates exceed 1.0 when the hypothesis contains many insertions;
    # clamp so the reported accuracy never drops below zero.
    return max(0.0, 1 - wer), max(0.0, 1 - cer)
|
|
|
|
|
|
|
|
@st.cache_data(show_spinner=False)
def _fetch_dataset_samples(num_samples):
    """Stream the first ``num_samples`` LibriSpeech "clean" test samples.

    Errors propagate to the caller on purpose: a call that raises stores
    nothing in the Streamlit cache, so the next call tries again.
    """
    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading script on this machine - confirm this source is trusted.
    dataset = load_dataset(
        "librispeech_asr",
        "clean",
        split="test",
        streaming=True,
        trust_remote_code=True,
    ).take(num_samples)
    # Materialise the stream so a plain list is what gets cached.
    return list(dataset)


def load_cached_dataset(num_samples=1):
    """Return the first ``num_samples`` evaluation samples, cached on success.

    Args:
        num_samples: Number of samples to take from the streamed test split.

    Returns:
        A list of dataset samples, or ``None`` when loading failed (the error
        is reported through ``st.error``). Failures are not cached, so a later
        call retries the download.
    """
    st.info("Loading dataset...")
    try:
        return _fetch_dataset_samples(num_samples)
    except Exception as e:
        st.error(f"Dataset loading failed: {str(e)}")
        return None
|
|
|
|
|
def _render_audio_processor_tab():
    """Render the upload tab: convert, transcribe and grammar-correct one file."""
    st.subheader("Upload & Process Audio")
    audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])
    if not audio_file:
        return

    st.audio(audio_file, format="audio/wav")
    wav_path = convert_audio_to_wav(audio_file)
    if not wav_path:
        # convert_audio_to_wav already reported the failure to the user.
        return

    try:
        asr_model = load_model("whisper-tiny", "automatic-speech-recognition")

        with st.spinner("Generating transcription..."):
            transcription = asr_model(wav_path)["text"]
            st.session_state.transcription = transcription
            st.text_area("Transcription Result", transcription, height=150)

        if st.session_state.transcription:
            grammar_model = load_model("flan-t5-base", "text2text-generation")
            with st.spinner("Checking grammar..."):
                grammar_feedback = grammar_model(
                    f"Correct the grammar in: {transcription}"
                )[0]["generated_text"]
                st.session_state.grammar_feedback = grammar_feedback
                st.success("Grammar Corrected Text:")
                st.write(grammar_feedback)
    finally:
        # Remove the temporary WAV even when a model call raises.
        os.unlink(wav_path)


def _render_model_evaluator_tab():
    """Render the comparison tab: run three ASR models on the dataset sample(s)."""
    st.subheader("Triple Model Evaluation with Runtime")

    model_options = list(MODELS["automatic-speech-recognition"].keys())
    selected_models = []
    # One select box per column; box N defaults to the N-th registered model.
    for position, column in enumerate(st.columns(3)):
        with column:
            selected_models.append(
                st.selectbox(f"Select Model {position + 1}", model_options, index=position)
            )

    if not st.button("Run Triple Evaluation"):
        return

    dataset = load_cached_dataset(num_samples=1)
    if not dataset:
        return

    asr_models = [
        load_model(model_name, "automatic-speech-recognition")
        for model_name in selected_models
    ]

    results = []
    for sample in dataset:
        with st.spinner("Processing Sample..."):
            audio_array = sample["audio"]["array"]
            reference_text = sample["text"]

            for model_name, asr_model in zip(selected_models, asr_models):
                # Time only the inference call itself.
                start_time = time.perf_counter()
                transcription = asr_model(audio_array)["text"]
                runtime = time.perf_counter() - start_time

                word_accuracy, char_accuracy = evaluate_asr_accuracy(
                    transcription, reference_text
                )
                results.append(
                    {
                        "Model": model_name,
                        "Runtime": f"{runtime:.4f}s",
                        # evaluate_asr_accuracy returns 1 - error rate, so the
                        # columns are labelled as accuracies, not WER/CER.
                        "Word Accuracy": f"{word_accuracy*100:.2f}%",
                        "Char Accuracy": f"{char_accuracy*100:.2f}%",
                    }
                )

    st.subheader("Model Evaluation Results")
    st.table(results)


def main():
    """Build the Streamlit page: title, session defaults and the two tabs.

    The first tab transcribes an uploaded audio file and shows a
    grammar-corrected version of the text; the second tab times three
    selectable ASR models on the evaluation dataset and tabulates their
    scores.
    """
    st.title("🎤 Audio Grammar Evaluation System for Language Learners")

    # Seed the session keys that the audio tab writes its results to.
    if "transcription" not in st.session_state:
        st.session_state.transcription = ""
    if "grammar_feedback" not in st.session_state:
        st.session_state.grammar_feedback = ""

    tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])

    with tab1:
        _render_audio_processor_tab()

    with tab2:
        _render_model_evaluator_tab()
|
|
|
|
|
|
|
|
# Build the page only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()