KuyaToto committed on
Commit
883f9e7
·
verified ·
1 Parent(s): ab0c8a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -56
app.py CHANGED
@@ -1,27 +1,27 @@
1
  import gradio as gr
2
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3
  import librosa
4
  import torch
5
  import epitran
6
  import re
 
7
  import editdistance
 
 
8
  import string
9
  import eng_to_ipa as ipa
10
  import numpy as np
11
 
12
- # --- Device setup ---
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
-
15
- # --- Load faster Wav2Vec2 models for English & Arabic ---
16
  MODELS = {
17
  "Arabic": {
18
- "processor": Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xls-r-300m"),
19
- "model": Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xls-r-300m").to(device),
20
  "epitran": epitran.Epitran("ara-Arab")
21
  },
22
  "English": {
23
- "processor": Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-100k"),
24
- "model": Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-100k").to(device),
25
  "epitran": epitran.Epitran("eng-Latn")
26
  }
27
  }
@@ -29,23 +29,27 @@ MODELS = {
29
  for lang in MODELS.values():
30
  lang["model"].config.ctc_loss_reduction = "mean"
31
 
32
- # --- Precompute IPA mapping for single letters ---
33
- LETTER_IPA = {l: ipa.convert(l.lower()).replace(".", "") for l in string.ascii_uppercase}
34
-
35
  def clean_phonemes(ipa_text):
36
  return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
37
 
38
  def safe_transliterate_arabic(epi, word):
39
  try:
40
- ipa_text = epi.transliterate(word.strip())
41
- return clean_phonemes(ipa_text)
42
- except:
 
 
 
 
43
  return ""
44
 
45
  def transliterate_english(word):
46
  try:
47
- return LETTER_IPA.get(word.upper(), "")
48
- except:
 
 
 
49
  return ""
50
 
51
  def analyze_phonemes(language, reference_text, audio_file):
@@ -56,61 +60,160 @@ def analyze_phonemes(language, reference_text, audio_file):
56
 
57
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
58
 
59
- # --- Load & normalize audio ---
 
 
60
  audio, sr = librosa.load(audio_file, sr=16000)
61
- if len(audio) < sr * 0.1:
62
- return {"language": language, "transcription": "No speech detected", "correct": False}
63
 
64
- audio = audio / max(np.abs(audio), 1e-9)
 
 
 
 
 
65
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
66
- trimmed_audio = trimmed_audio[:int(sr*0.75)] # max 0.75s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # --- Wav2Vec2 inference ---
69
- input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values.to(device)
70
  with torch.no_grad():
71
  logits = model(input_values).logits
72
- pred_ids = torch.argmax(logits, dim=-1)
73
- transcription = processor.batch_decode(pred_ids)[0].strip()
74
 
75
- # --- Quick confidence check ---
76
  probs = torch.softmax(logits, dim=-1)
77
- if probs.max(dim=-1).values.mean().item() < 0.6:
78
- return {"language": language, "transcription": "Low confidence", "correct": False}
79
-
80
- # --- Single-letter optimization ---
81
- if len(reference_text.strip()) == 1:
82
- ref_ipa = transliterate_fn(reference_text.strip())
83
- trans_ipa = transliterate_fn(transcription)
84
- correct = ref_ipa == trans_ipa or reference_text.upper() == transcription.upper()
85
- return {"language": language, "reference": reference_text, "transcription": transcription, "correct": correct}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # --- Full phoneme alignment (for multi-letter words) ---
88
- ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]
89
- obs_phonemes = [list(transliterate_fn(word)) for word in transcription.split()]
90
-
91
- results = []
92
- for r, o in zip(ref_phonemes, obs_phonemes):
93
- results.append({
94
- "reference": ''.join(r),
95
- "observed": ''.join(o),
96
- "edit_distance": editdistance.eval(r, o)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  })
98
 
99
- return {"language": language, "reference_text": reference_text, "transcription": transcription, "word_alignment": results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- # --- Gradio UI ---
102
  def get_default_text(language):
103
- return {"Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ", "English": "A"}.get(language, "")
 
 
 
104
 
105
  with gr.Blocks() as demo:
106
- gr.Markdown("# Fast Multilingual Letter & Word Phoneme Analysis (Wav2Vec2)")
107
- language = gr.Dropdown(["Arabic", "English"], value="English", label="Language")
 
 
 
 
108
  reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
109
- audio_input = gr.Audio(label="Record Audio", type="filepath")
110
  submit_btn = gr.Button("Analyze")
111
- output = gr.JSON(label="Results")
112
-
113
- language.change(fn=get_default_text, inputs=language, outputs=reference_text)
114
- submit_btn.click(fn=analyze_phonemes, inputs=[language, reference_text, audio_input], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  demo.launch()
 
# --- Third-party imports ---
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import librosa
import torch
import epitran
import re
import difflib
import editdistance
from jiwer import wer
import json
import string
import eng_to_ipa as ipa
import numpy as np

# One Wav2Vec2 checkpoint (and matching Epitran language code) per supported
# language.
_LANG_CONFIG = {
    "Arabic": ("jonatasgrosman/wav2vec2-large-xlsr-53-arabic", "ara-Arab"),
    "English": ("jonatasgrosman/wav2vec2-large-xlsr-53-english", "eng-Latn"),
}

# MODELS[lang] -> {"processor": ..., "model": ..., "epitran": ...}
# Built eagerly at import time; each from_pretrained call loads (and on first
# run downloads) the pretrained checkpoint.
MODELS = {
    name: {
        "processor": Wav2Vec2Processor.from_pretrained(checkpoint),
        "model": Wav2Vec2ForCTC.from_pretrained(checkpoint),
        "epitran": epitran.Epitran(epi_code),
    }
    for name, (checkpoint, epi_code) in _LANG_CONFIG.items()
}

# Average (rather than sum) the CTC loss over tokens for every loaded model.
for entry in MODELS.values():
    entry["model"].config.ctc_loss_reduction = "mean"
def clean_phonemes(ipa_text):
    """Strip diacritic noise from an IPA string.

    Removes Arabic harakat / tanween marks (U+064B through U+0652) and the
    IPA length mark (U+02D0) so that phoneme comparisons ignore vowel
    diacritics and length.
    """
    unwanted = re.compile(r'[\u064B-\u0652\u02D0]')
    return unwanted.sub('', ipa_text)
def safe_transliterate_arabic(epi, word):
    """Transliterate an Arabic *word* to a cleaned IPA string via *epi*.

    Never raises: on any transliteration failure, or when the transliterator
    returns an effectively empty string, a warning is printed and "" is
    returned so callers can treat the word as unpronounceable.
    """
    try:
        word = word.strip()
        # NOTE: do not name this local `ipa` — that would shadow the
        # module-level `eng_to_ipa as ipa` import used elsewhere in this file.
        ipa_text = epi.transliterate(word)
    except Exception as e:
        print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
        return ""
    if not ipa_text.strip():
        # Empty result is a soft failure; report it the same way as an
        # exception would have been (message text kept identical).
        print(f"[Warning] Arabic transliteration failed for '{word}': Empty IPA string")
        return ""
    return clean_phonemes(ipa_text)
def transliterate_english(word):
    """Convert an English *word* to a cleaned IPA string.

    The word is lowercased and stripped of ASCII punctuation before
    conversion. Returns "" (never raises) if conversion fails, printing a
    warning so the failure shows up in the logs.
    """
    try:
        # Remove all punctuation in a single translate pass, then convert.
        punct_table = str.maketrans('', '', string.punctuation)
        word = word.lower().translate(punct_table)
        return clean_phonemes(ipa.convert(word))
    except Exception as e:
        print(f"[Warning] English IPA conversion failed for '{word}': {e}")
        return ""
 
55
  def analyze_phonemes(language, reference_text, audio_file):
 
60
 
61
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
62
 
63
+ ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]
64
+
65
+ # Load audio
66
  audio, sr = librosa.load(audio_file, sr=16000)
 
 
67
 
68
+ # Normalize volume
69
+ max_amp = np.max(np.abs(audio))
70
+ if max_amp > 0:
71
+ audio = audio / max_amp # Normalize to [-1, 1]
72
+
73
+ # Stricter silence trimming
74
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
75
+ if len(trimmed_audio) < (sr * 0.15):
76
+ return json.dumps({
77
+ "language": language,
78
+ "reference_text": reference_text,
79
+ "transcription": "No speech detected",
80
+ "word_alignment": [],
81
+ "metrics": {"message": "Audio appears silent or too noisy. Try speaking louder or in a quieter environment."}
82
+ }, indent=2, ensure_ascii=False)
83
+
84
+ # Cap to 0.75s for single letters
85
+ max_duration = 0.75
86
+ if len(trimmed_audio) > int(sr * max_duration):
87
+ trimmed_audio = trimmed_audio[:int(sr * max_duration)]
88
+
89
+ # Noise gate
90
+ noise_gate_threshold = 0.02
91
+ trimmed_audio[np.abs(trimmed_audio) < noise_gate_threshold] = 0
92
+
93
+ input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values
94
 
 
 
95
  with torch.no_grad():
96
  logits = model(input_values).logits
97
+ pred_ids = torch.argmax(logits, dim=-1)
98
+ transcription = processor.batch_decode(pred_ids)[0].strip()
99
 
100
+ # Stricter confidence check
101
  probs = torch.softmax(logits, dim=-1)
102
+ max_probs = probs.max(dim=-1).values.mean().item()
103
+ if max_probs < 0.6:
104
+ return json.dumps({
105
+ "language": language,
106
+ "reference_text": reference_text,
107
+ "transcription": "No speech detected",
108
+ "word_alignment": [],
109
+ "metrics": {"message": "Low confidence transcription (possible noise). Try again with clearer speech."}
110
+ }, indent=2, ensure_ascii=False)
111
+
112
+ # Filter vowel-heavy or overly long transcriptions
113
+ transcription_clean = transcription.lower().replace("the", "").strip()
114
+ if len(transcription_clean) > 3 or re.match(r'^[aeiou]+$', transcription_clean):
115
+ return json.dumps({
116
+ "language": language,
117
+ "reference_text": reference_text,
118
+ "transcription": "No speech detected",
119
+ "word_alignment": [],
120
+ "metrics": {"message": "Detected noise or unclear speech. Try again with clear pronunciation."}
121
+ }, indent=2, ensure_ascii=False)
122
+
123
+ obs_phonemes = [list(transliterate_fn(word)) for word in transcription_clean.split()]
124
+
125
+ results = {
126
+ "language": language,
127
+ "reference_text": reference_text,
128
+ "transcription": transcription_clean or "No speech detected",
129
+ "word_alignment": [],
130
+ "metrics": {}
131
+ }
132
 
133
+ total_phoneme_errors = 0
134
+ total_phoneme_length = 0
135
+ correct_words = 0
136
+ total_word_length = len(ref_phonemes)
137
+
138
+ for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
139
+ ref_str = ''.join(ref)
140
+ obs_str = ''.join(obs)
141
+ edits = editdistance.eval(ref, obs)
142
+ acc = round((1 - edits / max(1, len(ref))) * 100, 2)
143
+
144
+ matcher = difflib.SequenceMatcher(None, ref, obs)
145
+ ops = matcher.get_opcodes()
146
+ error_details = []
147
+ for tag, i1, i2, j1, j2 in ops:
148
+ ref_seg = ''.join(ref[i1:i2]) or '-'
149
+ obs_seg = ''.join(obs[j1:j2]) or '-'
150
+ if tag != 'equal':
151
+ error_details.append({
152
+ "type": tag.upper(),
153
+ "reference": ref_seg,
154
+ "observed": obs_seg
155
+ })
156
+
157
+ results["word_alignment"].append({
158
+ "word_index": i,
159
+ "reference_phonemes": ref_str,
160
+ "observed_phonemes": obs_str,
161
+ "edit_distance": edits,
162
+ "accuracy": acc,
163
+ "is_correct": edits == 0,
164
+ "errors": error_details
165
  })
166
 
167
+ total_phoneme_errors += edits
168
+ total_phoneme_length += len(ref)
169
+ correct_words += int(edits == 0)
170
+
171
+ phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
172
+ phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
173
+ word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
174
+ word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
175
+ text_wer = round(wer(reference_text, transcription_clean or "") * 100, 2)
176
+
177
+ results["metrics"] = {
178
+ "word_accuracy": word_acc,
179
+ "word_error_rate": word_er,
180
+ "phoneme_accuracy": phoneme_acc,
181
+ "phoneme_error_rate": phoneme_er,
182
+ "asr_word_error_rate": text_wer
183
+ }
184
+
185
+ return json.dumps(results, indent=2, ensure_ascii=False)
186
 
 
def get_default_text(language):
    """Return the sample reference text for *language* ("" if unrecognized)."""
    defaults = {
        "Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ",
        "English": "A",
    }
    return defaults.get(language, "")
 
with gr.Blocks() as demo:
    # --- Header ---
    gr.Markdown("# Multilingual Phoneme Alignment Analysis")
    gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Tip: Speak clearly; silence or noise may cause errors.")

    # --- Input widgets ---
    with gr.Row():
        language = gr.Dropdown(["Arabic", "English"], label="Language", value="English")

    reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
    audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    submit_btn = gr.Button("Analyze")
    output = gr.JSON(label="Phoneme Alignment Results")

    # Swap in the language-appropriate sample text whenever the dropdown changes.
    language.change(fn=get_default_text, inputs=language, outputs=reference_text, api_name="/get_default_text")
    # Run the full phoneme-alignment pipeline on click.
    submit_btn.click(fn=analyze_phonemes, inputs=[language, reference_text, audio_input], outputs=output, api_name="/analyze_phonemes")

demo.launch()