sshenai committed on
Commit cd09bfc · verified · 1 Parent(s): 819e284

Upload 2 files

Files changed (2)
  1. app.py +200 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,200 @@
+ import os
+ import time
+ import streamlit as st
+ from transformers import pipeline
+ from pydub import AudioSegment
+ import tempfile
+ import torch
+ from datasets import load_dataset
+ import jiwer
+ import librosa  # noqa: F401 (audio decoding/resampling backend for datasets)
+ import soundfile  # noqa: F401 (audio decoding backend for datasets)
+ 
+ # Page configuration
+ st.set_page_config(page_title="Audio-to-Text with Grammar Check", page_icon="🎤", layout="wide")
+ 
+ # Model configurations: three ASR checkpoints plus one grammar-correction model
+ MODELS = {
+     "automatic-speech-recognition": {
+         "whisper-tiny": "openai/whisper-tiny",
+         "whisper-small": "openai/whisper-small",
+         "whisper-base": "openai/whisper-base"
+     },
+     "text2text-generation": {
+         "grammar-synthesis-small": "pszemraj/grammar-synthesis-small"
+     }
+ }
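+ # (Approximate checkpoint sizes: whisper-tiny ~39M, whisper-base ~74M,
+ # whisper-small ~244M parameters; accuracy and runtime grow together.)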
+ 
+ # Cached model loading: st.cache_resource keeps one pipeline instance per
+ # (model_key, task) across Streamlit reruns instead of reloading each time
+ @st.cache_resource
+ def load_model(model_key, task):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     with st.spinner(f"Loading {model_key} model..."):
+         return pipeline(task, model=MODELS[task][model_key], device=device)
+ 
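+ # (An ASR pipeline is a callable: asr(path_or_audio) returns {"text": ...};
+ # a text2text-generation pipeline returns [{"generated_text": ...}].)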
+ def convert_audio_to_wav(audio_file):
+     """Convert uploaded audio to WAV (pydub relies on ffmpeg for non-WAV input)"""
+     try:
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+             audio = AudioSegment.from_file(audio_file)
+             audio.export(tmp_file.name, format="wav")
+             return tmp_file.name
+     except Exception as e:
+         st.error(f"Audio conversion failed: {str(e)}")
+         return None
+ 
+ def evaluate_asr_accuracy(transcription, reference):
+     """Return word- and character-level accuracy (1 - WER and 1 - CER)"""
+     ref_processed = reference.lower().strip()
+     hyp_processed = transcription.lower().strip()
+ 
+     if not ref_processed:
+         return 0.0, 0.0
+ 
+     wer = jiwer.wer(ref_processed, hyp_processed)
+     cer = jiwer.cer(ref_processed, hyp_processed)
+ 
+     return 1 - wer, 1 - cer
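+ # Worked example: jiwer.wer("the cat sat", "the cat sat down") == 1/3
+ # (one insertion against a 3-word reference), so word accuracy is ~66.67%.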
+ 
+ # Cached dataset loading with audio decoding
+ @st.cache_data(show_spinner=False)
+ def load_cached_dataset(num_samples=1):
+     st.info("Loading dataset...")
+     try:
+         dataset = load_dataset(
+             "librispeech_asr",
+             "clean",
+             split="test",
+             streaming=True,
+             trust_remote_code=True
+         ).take(num_samples)
+         return [sample for sample in dataset]
+     except Exception as e:
+         st.error(f"Dataset loading failed: {str(e)}")
+         return None
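+ # (streaming=True avoids downloading the full split up front; .take(n) yields
+ # only the first n examples. LibriSpeech audio is 16 kHz, the rate Whisper expects.)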
+ 
+ def main():
+     st.title("🎤 Audio Grammar Evaluation System for Language Learners")
+ 
+     # Session state for persisting results across reruns
+     if "transcription" not in st.session_state:
+         st.session_state.transcription = ""
+     if "grammar_feedback" not in st.session_state:
+         st.session_state.grammar_feedback = ""
+ 
+     # One tab for audio processing, one for model comparison
+     tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])
+ 
+     with tab1:
+         st.subheader("Upload & Process Audio")
+         audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])
+ 
+         if audio_file:
+             st.audio(audio_file)  # let Streamlit infer the format of the upload
+             wav_path = convert_audio_to_wav(audio_file)
+ 
+             if wav_path:
+                 asr_model = load_model("whisper-tiny", "automatic-speech-recognition")
+ 
+                 with st.spinner("Generating transcription..."):
+                     transcription = asr_model(wav_path)["text"]
+                     st.session_state.transcription = transcription
+                     st.text_area("Transcription Result", transcription, height=150)
+ 
+                 if st.session_state.transcription:
+                     grammar_model = load_model("grammar-synthesis-small", "text2text-generation")
+                     with st.spinner("Checking grammar..."):
+                         # grammar-synthesis checkpoints take the raw text to
+                         # correct, with no instruction prefix
+                         grammar_feedback = grammar_model(transcription)[0]["generated_text"]
+                         st.session_state.grammar_feedback = grammar_feedback
+                         st.success("Grammar Corrected Text:")
+                         st.write(grammar_feedback)
+ 
+                 os.unlink(wav_path)  # remove the temporary WAV file
+ 
+     with tab2:
+         st.subheader("Triple Model Evaluation with Runtime")
+ 
+         # Model selection
+         model_options = list(MODELS["automatic-speech-recognition"].keys())
+         col1, col2, col3 = st.columns(3)
+         with col1:
+             selected_model1 = st.selectbox("Select Model 1", model_options, index=0)
+         with col2:
+             selected_model2 = st.selectbox("Select Model 2", model_options, index=1)
+         with col3:
+             selected_model3 = st.selectbox("Select Model 3", model_options, index=2)
+ 
+         if st.button("Run Triple Evaluation"):
+             dataset = load_cached_dataset(num_samples=1)
+             if not dataset:
+                 return
+ 
+             # Load the three selected ASR models
+             model1 = load_model(selected_model1, "automatic-speech-recognition")
+             model2 = load_model(selected_model2, "automatic-speech-recognition")
+             model3 = load_model(selected_model3, "automatic-speech-recognition")
+ 
+             results = []
+             total_runtime_model1 = 0.0
+             total_runtime_model2 = 0.0
+             total_runtime_model3 = 0.0
+ 
+             for i, sample in enumerate(dataset):
+                 with st.spinner(f"Processing sample {i + 1}..."):
+                     # Pass the sampling rate alongside the array so the
+                     # pipeline can resample if the audio is not 16 kHz
+                     audio_input = {
+                         "raw": sample["audio"]["array"],
+                         "sampling_rate": sample["audio"]["sampling_rate"]
+                     }
+                     reference_text = sample["text"]
+ 
+                     # Evaluate Model 1
+                     start_time = time.perf_counter()
+                     transcription1 = model1(audio_input)["text"]
+                     end_time = time.perf_counter()
+                     runtime1 = end_time - start_time
+                     total_runtime_model1 += runtime1
+                     word_acc1, char_acc1 = evaluate_asr_accuracy(transcription1, reference_text)
+ 
+                     # Evaluate Model 2
+                     start_time = time.perf_counter()
+                     transcription2 = model2(audio_input)["text"]
+                     end_time = time.perf_counter()
+                     runtime2 = end_time - start_time
+                     total_runtime_model2 += runtime2
+                     word_acc2, char_acc2 = evaluate_asr_accuracy(transcription2, reference_text)
+ 
+                     # Evaluate Model 3
+                     start_time = time.perf_counter()
+                     transcription3 = model3(audio_input)["text"]
+                     end_time = time.perf_counter()
+                     runtime3 = end_time - start_time
+                     total_runtime_model3 += runtime3
+                     word_acc3, char_acc3 = evaluate_asr_accuracy(transcription3, reference_text)
+ 
+                     # Organize results (accuracy = 1 - error rate)
+                     model1_result = {
+                         "Model": selected_model1,
+                         "Runtime": f"{runtime1:.4f}s",
+                         "Word Accuracy": f"{word_acc1*100:.2f}%",
+                         "Char Accuracy": f"{char_acc1*100:.2f}%"
+                     }
+                     model2_result = {
+                         "Model": selected_model2,
+                         "Runtime": f"{runtime2:.4f}s",
+                         "Word Accuracy": f"{word_acc2*100:.2f}%",
+                         "Char Accuracy": f"{char_acc2*100:.2f}%"
+                     }
+                     model3_result = {
+                         "Model": selected_model3,
+                         "Runtime": f"{runtime3:.4f}s",
+                         "Word Accuracy": f"{word_acc3*100:.2f}%",
+                         "Char Accuracy": f"{char_acc3*100:.2f}%"
+                     }
+                     results.extend([model1_result, model2_result, model3_result])
+ 
+             # Display results
+             st.subheader("Model Evaluation Results")
+             st.table(results)
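+             # A small sketch of surfacing the accumulated runtimes, which are
+             # otherwise computed but never displayed (st.caption renders
+             # small helper text under the table):
+             st.caption(
+                 f"Total runtime: {selected_model1} {total_runtime_model1:.4f}s | "
+                 f"{selected_model2} {total_runtime_model2:.4f}s | "
+                 f"{selected_model3} {total_runtime_model3:.4f}s"
+             )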
+ 
+ 
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers>=4.30.0
+ torch>=2.0.0
+ pydub>=0.25.1
+ streamlit>=1.25.0
+ jiwer>=2.0.0
+ datasets>=2.0.0
+ librosa>=0.10.0
+ soundfile>=0.12.1
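
To run this locally (a minimal sketch, assuming ffmpeg is available on the PATH, which pydub needs to decode non-WAV uploads): install the dependencies with pip install -r requirements.txt, then launch the app with streamlit run app.py.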