KuyaToto committed
Commit f26b2f5 · verified · 1 Parent(s): 7e41a9d

Update app.py

Files changed (1)
  1. app.py +51 -111
app.py CHANGED
@@ -1,17 +1,16 @@
 import gradio as gr
 import torch
 import torchaudio
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import epitran
 import re
 import editdistance
-from jiwer import wer
 import orjson
-import eng_to_ipa as ipa
+from jiwer import wer
 
-# --- Device setup ---
+# --- Device ---
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("Using device:", device)
+print("Using:", device)
 
 # --- WordMap ---
 WORD_MAP = {
@@ -43,142 +42,83 @@ WORD_MAP = {
     'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'}
 }
 
-# --- Load Whisper tiny ---
-processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
-model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to(device).eval()
-epi = epitran.Epitran("eng-Latn")
+# --- Load wav2vec2 (smaller + faster than Whisper) ---
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
 
-# --- Precompute IPA cache ---
+epi = epitran.Epitran("eng-Latn")
 IPA_CACHE = {v['word'].lower(): re.sub(r'[^\w\s]', '', v['phonetic']) for v in WORD_MAP.values()}
 
 # --- Helpers ---
-def clean_phonemes(ipa_text):
-    return re.sub(r'[^\w\s]', '', ipa_text)
-
-def transliterate_english(word):
+def transliterate(word):
     word_lower = word.lower()
     if word_lower in IPA_CACHE:
         return IPA_CACHE[word_lower]
     try:
-        return clean_phonemes(ipa.convert(word)) or ""
+        return epi.transliterate(word_lower)
     except Exception:
         return ""
 
-def find_closest_word(transcription, reference_word):
-    if not transcription:
-        return reference_word, 0.0
-    transcription = transcription.lower().strip()
-    distances = {entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower()) for entry in WORD_MAP.values()}
-    closest_word = min(distances, key=distances.get)
-    max_len = max(len(transcription), len(closest_word))
-    similarity = round((1 - distances[closest_word] / max(1, max_len)) * 100, 2)
-    return closest_word, similarity
-
-def fast_transcribe(audio_path):
+def transcribe(audio_path):
     waveform, sr = torchaudio.load(audio_path)
     if sr != 16000:
         waveform = torchaudio.functional.resample(waveform, sr, 16000)
-
-    input_features = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
+    inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True).to(device)
 
     with torch.no_grad():
-        predicted_ids = model.generate(input_features)
-
-    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-    return transcription.strip().lower()
-
-# --- Main function ---
-def analyze_phonemes(language, reference_text, audio_input, detailed=True):
+        logits = model(**inputs).logits
+    pred_ids = torch.argmax(logits, dim=-1)
+    return processor.decode(pred_ids[0]).lower()
+
+def analyze(language, reference_text, audio_input, detailed=True):
     try:
-        transcription = fast_transcribe(audio_input)
+        transcription = transcribe(audio_input)
 
+        # match closest word from WORD_MAP
+        distances = {entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower()) for entry in WORD_MAP.values()}
+        closest_word = min(distances, key=distances.get)
+        similarity = round((1 - distances[closest_word] / max(1, len(closest_word))) * 100, 2)
+
         if not detailed:
-            return orjson.dumps({
-                "language": language,
-                "reference_text": reference_text,
-                "transcription": transcription
-            }).decode()
+            return {"language": language, "reference": reference_text, "transcription": closest_word}
 
-        # Detailed phoneme alignment
-        closest_word, similarity = find_closest_word(transcription, reference_text.lower())
-        transcription_clean = closest_word
+        # phoneme-level alignment
+        ref_ph = list(transliterate(reference_text))
+        obs_ph = list(transliterate(closest_word))
 
-        obs_phonemes = [list(transliterate_english(word)) for word in transcription_clean.split()]
-        ref_words = reference_text.lower().split()
-        ref_phonemes = [list(transliterate_english(word)) for word in ref_words]
+        edits = editdistance.eval(ref_ph, obs_ph)
+        phon_acc = round((1 - edits / max(1, len(ref_ph))) * 100, 2)
 
-        results = {
+        return {
             "language": language,
-            "reference_text": reference_text,
-            "transcription": transcription_clean,
-            "word_alignment": [],
-            "metrics": {"similarity": similarity}
+            "reference": reference_text,
+            "transcription": closest_word,
+            "metrics": {
+                "similarity": similarity,
+                "phoneme_accuracy": phon_acc,
+                "asr_word_error_rate": round(wer(reference_text, closest_word) * 100, 2)
+            },
+            "alignment": {
+                "reference_phonemes": "".join(ref_ph),
+                "observed_phonemes": "".join(obs_ph),
+                "edit_distance": edits
+            }
         }
-
-        total_phoneme_errors = 0
-        total_phoneme_length = 0
-        correct_words = 0
-        total_word_length = len(ref_phonemes)
-
-        for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
-            ref_str = ''.join(ref)
-            obs_str = ''.join(obs)
-            edits = editdistance.eval(ref, obs)
-            acc = round((1 - edits / max(1, len(ref))) * 100, 2)
-            results["word_alignment"].append({
-                "word_index": i,
-                "reference_phonemes": ref_str,
-                "observed_phonemes": obs_str,
-                "edit_distance": edits,
-                "accuracy": acc,
-                "is_correct": edits == 0
-            })
-            total_phoneme_errors += edits
-            total_phoneme_length += len(ref)
-            correct_words += int(edits == 0)
-
-        phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
-        phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
-        word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
-        word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
-        text_wer = round(wer(reference_text, transcription_clean) * 100, 2)
-
-        results["metrics"].update({
-            "word_accuracy": word_acc,
-            "word_error_rate": word_er,
-            "phoneme_accuracy": phoneme_acc,
-            "phoneme_error_rate": phoneme_er,
-            "asr_word_error_rate": text_wer
-        })
-
-        return orjson.dumps(results).decode()
-
     except Exception as e:
-        return orjson.dumps({
-            "language": language,
-            "reference_text": reference_text,
-            "transcription": "Error processing audio",
-            "word_alignment": [],
-            "metrics": {"message": f"Error: {str(e)}"}
-        }).decode()
+        return {"error": str(e)}
 
 # --- Gradio UI ---
-def get_default_text(language):
-    return "A" if language == "English" else ""
-
 with gr.Blocks() as demo:
-    gr.Markdown("# Multilingual Phoneme Alignment (Fast Whisper Backend)")
-    gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Toggle fast vs detailed mode.")
+    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")
 
     with gr.Row():
-        language = gr.Dropdown(["English"], label="Language", value="English")
-        reference_text = gr.Textbox(label="Reference Text", value="A")
-        audio_input = gr.Audio(label="Record Audio", type="filepath")  # ⚡ filepath, not numpy
-        detailed = gr.Checkbox(label="Detailed Mode (phoneme analysis)", value=True)
-        submit_btn = gr.Button("Analyze")
-        output = gr.JSON(label="Results")
-
-    language.change(fn=get_default_text, inputs=language, outputs=reference_text)
-    submit_btn.click(fn=analyze_phonemes, inputs=[language, reference_text, audio_input, detailed], outputs=output)
+        lang = gr.Dropdown(["English"], value="English", label="Language")
+        ref = gr.Textbox(value="A", label="Reference Word")
+        audio = gr.Audio(label="Record Audio", type="filepath")
+        detailed = gr.Checkbox(value=True, label="Detailed Mode")
+    out = gr.JSON(label="Results")
+
+    demo_btn = gr.Button("Analyze")
+    demo_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)
 
 demo.launch()
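
A quick way to sanity-check the new wav2vec2 path is to call analyze() directly, skipping the Gradio UI. The snippet below is an illustrative smoke test, not part of this commit: sample.wav is a placeholder path, and it assumes app.py is importable from the working directory.

# --- Smoke test (illustrative, not part of this commit) ---
# Assumes app.py is on the import path and sample.wav is a short mono
# recording; transcribe() resamples to 16 kHz internally if needed.
from app import analyze

result = analyze("English", "A", "sample.wav", detailed=True)
print(result["transcription"])                # closest WORD_MAP match
print(result["metrics"]["phoneme_accuracy"])  # 0-100 phoneme accuracy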