KuyaToto commited on
Commit
63e7642
·
verified ·
1 Parent(s): e7078f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -158
app.py CHANGED
@@ -4,24 +4,25 @@ import librosa
4
  import torch
5
  import epitran
6
  import re
7
- import difflib
8
  import editdistance
9
- from jiwer import wer
10
  import json
11
  import string
12
  import eng_to_ipa as ipa
13
  import numpy as np
14
 
15
- # Models: Wav2Vec2 for both Arabic and English
 
 
 
16
  MODELS = {
17
  "Arabic": {
18
  "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
19
- "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
20
  "epitran": epitran.Epitran("ara-Arab")
21
  },
22
  "English": {
23
- "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
24
- "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
25
  "epitran": epitran.Epitran("eng-Latn")
26
  }
27
  }
@@ -29,27 +30,23 @@ MODELS = {
29
  for lang in MODELS.values():
30
  lang["model"].config.ctc_loss_reduction = "mean"
31
 
 
 
 
32
  def clean_phonemes(ipa_text):
33
  return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
34
 
35
  def safe_transliterate_arabic(epi, word):
36
  try:
37
- word = word.strip()
38
- ipa = epi.transliterate(word)
39
- if not ipa.strip():
40
- raise ValueError("Empty IPA string")
41
- return clean_phonemes(ipa)
42
- except Exception as e:
43
- print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
44
  return ""
45
 
46
  def transliterate_english(word):
47
  try:
48
- word = word.lower().translate(str.maketrans('', '', string.punctuation))
49
- ipa_text = ipa.convert(word)
50
- return clean_phonemes(ipa_text)
51
- except Exception as e:
52
- print(f"[Warning] English IPA conversion failed for '{word}': {e}")
53
  return ""
54
 
55
  def analyze_phonemes(language, reference_text, audio_file):
@@ -60,160 +57,61 @@ def analyze_phonemes(language, reference_text, audio_file):
60
 
61
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
62
 
63
- ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]
64
-
65
- # Load audio
66
  audio, sr = librosa.load(audio_file, sr=16000)
 
 
67
 
68
- # Normalize volume
69
- max_amp = np.max(np.abs(audio))
70
- if max_amp > 0:
71
- audio = audio / max_amp # Normalize to [-1, 1]
72
-
73
- # Stricter silence trimming
74
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
75
- if len(trimmed_audio) < (sr * 0.15):
76
- return json.dumps({
77
- "language": language,
78
- "reference_text": reference_text,
79
- "transcription": "No speech detected",
80
- "word_alignment": [],
81
- "metrics": {"message": "Audio appears silent or too noisy. Try speaking louder or in a quieter environment."}
82
- }, indent=2, ensure_ascii=False)
83
-
84
- # Cap to 0.75s for single letters
85
- max_duration = 0.75
86
- if len(trimmed_audio) > int(sr * max_duration):
87
- trimmed_audio = trimmed_audio[:int(sr * max_duration)]
88
-
89
- # Noise gate
90
- noise_gate_threshold = 0.02
91
- trimmed_audio[np.abs(trimmed_audio) < noise_gate_threshold] = 0
92
-
93
- input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values
94
 
 
 
95
  with torch.no_grad():
96
  logits = model(input_values).logits
97
- pred_ids = torch.argmax(logits, dim=-1)
98
- transcription = processor.batch_decode(pred_ids)[0].strip()
99
 
100
- # Stricter confidence check
101
  probs = torch.softmax(logits, dim=-1)
102
- max_probs = probs.max(dim=-1).values.mean().item()
103
- if max_probs < 0.6:
104
- return json.dumps({
105
- "language": language,
106
- "reference_text": reference_text,
107
- "transcription": "No speech detected",
108
- "word_alignment": [],
109
- "metrics": {"message": "Low confidence transcription (possible noise). Try again with clearer speech."}
110
- }, indent=2, ensure_ascii=False)
111
-
112
- # Filter vowel-heavy or overly long transcriptions
113
- transcription_clean = transcription.lower().replace("the", "").strip()
114
- if len(transcription_clean) > 3 or re.match(r'^[aeiou]+$', transcription_clean):
115
- return json.dumps({
116
- "language": language,
117
- "reference_text": reference_text,
118
- "transcription": "No speech detected",
119
- "word_alignment": [],
120
- "metrics": {"message": "Detected noise or unclear speech. Try again with clear pronunciation."}
121
- }, indent=2, ensure_ascii=False)
122
-
123
- obs_phonemes = [list(transliterate_fn(word)) for word in transcription_clean.split()]
124
-
125
- results = {
126
- "language": language,
127
- "reference_text": reference_text,
128
- "transcription": transcription_clean or "No speech detected",
129
- "word_alignment": [],
130
- "metrics": {}
131
- }
132
 
133
- total_phoneme_errors = 0
134
- total_phoneme_length = 0
135
- correct_words = 0
136
- total_word_length = len(ref_phonemes)
137
-
138
- for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
139
- ref_str = ''.join(ref)
140
- obs_str = ''.join(obs)
141
- edits = editdistance.eval(ref, obs)
142
- acc = round((1 - edits / max(1, len(ref))) * 100, 2)
143
-
144
- matcher = difflib.SequenceMatcher(None, ref, obs)
145
- ops = matcher.get_opcodes()
146
- error_details = []
147
- for tag, i1, i2, j1, j2 in ops:
148
- ref_seg = ''.join(ref[i1:i2]) or '-'
149
- obs_seg = ''.join(obs[j1:j2]) or '-'
150
- if tag != 'equal':
151
- error_details.append({
152
- "type": tag.upper(),
153
- "reference": ref_seg,
154
- "observed": obs_seg
155
- })
156
-
157
- results["word_alignment"].append({
158
- "word_index": i,
159
- "reference_phonemes": ref_str,
160
- "observed_phonemes": obs_str,
161
- "edit_distance": edits,
162
- "accuracy": acc,
163
- "is_correct": edits == 0,
164
- "errors": error_details
165
- })
166
 
167
- total_phoneme_errors += edits
168
- total_phoneme_length += len(ref)
169
- correct_words += int(edits == 0)
170
-
171
- phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
172
- phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
173
- word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
174
- word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
175
- text_wer = round(wer(reference_text, transcription_clean or "") * 100, 2)
176
-
177
- results["metrics"] = {
178
- "word_accuracy": word_acc,
179
- "word_error_rate": word_er,
180
- "phoneme_accuracy": phoneme_acc,
181
- "phoneme_error_rate": phoneme_er,
182
- "asr_word_error_rate": text_wer
183
- }
184
 
185
- return json.dumps(results, indent=2, ensure_ascii=False)
186
 
 
187
  def get_default_text(language):
188
- return {
189
- "Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ",
190
- "English": "A"
191
- }.get(language, "")
192
 
193
  with gr.Blocks() as demo:
194
- gr.Markdown("# Multilingual Phoneme Alignment Analysis")
195
- gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Tip: Speak clearly; silence or noise may cause errors.")
196
-
197
- with gr.Row():
198
- language = gr.Dropdown(["Arabic", "English"], label="Language", value="English")
199
-
200
  reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
201
- audio_input = gr.Audio(label="Upload Audio File", type="filepath")
202
  submit_btn = gr.Button("Analyze")
203
- output = gr.JSON(label="Phoneme Alignment Results")
204
-
205
- language.change(
206
- fn=get_default_text,
207
- inputs=language,
208
- outputs=reference_text,
209
- api_name="/get_default_text"
210
- )
211
-
212
- submit_btn.click(
213
- fn=analyze_phonemes,
214
- inputs=[language, reference_text, audio_input],
215
- outputs=output,
216
- api_name="/analyze_phonemes"
217
- )
218
-
219
- demo.launch()
 
4
  import torch
5
  import epitran
6
  import re
 
7
  import editdistance
 
8
  import json
9
  import string
10
  import eng_to_ipa as ipa
11
  import numpy as np
12
 
13
+ # --- Device setup ---
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # --- Load faster/smaller models for English & Arabic ---
17
  MODELS = {
18
  "Arabic": {
19
  "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
20
+ "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic").to(device),
21
  "epitran": epitran.Epitran("ara-Arab")
22
  },
23
  "English": {
24
+ "processor": Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h"),
25
+ "model": Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device),
26
  "epitran": epitran.Epitran("eng-Latn")
27
  }
28
  }
 
30
  for lang in MODELS.values():
31
  lang["model"].config.ctc_loss_reduction = "mean"
32
 
33
+ # --- Precompute IPA mapping for single letters ---
34
+ LETTER_IPA = {l: ipa.convert(l.lower()).replace(".", "") for l in string.ascii_uppercase}
35
+
36
  def clean_phonemes(ipa_text):
37
  return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
38
 
39
  def safe_transliterate_arabic(epi, word):
40
  try:
41
+ ipa_text = epi.transliterate(word.strip())
42
+ return clean_phonemes(ipa_text)
43
+ except:
 
 
 
 
44
  return ""
45
 
46
  def transliterate_english(word):
47
  try:
48
+ return LETTER_IPA.get(word.upper(), "")
49
+ except:
 
 
 
50
  return ""
51
 
52
  def analyze_phonemes(language, reference_text, audio_file):
 
57
 
58
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
59
 
60
+ # --- Load & normalize audio ---
 
 
61
  audio, sr = librosa.load(audio_file, sr=16000)
62
+ if len(audio) < sr * 0.1:
63
+ return {"language": language, "transcription": "No speech detected", "correct": False}
64
 
65
+ audio = audio / max(np.max(np.abs(audio)), 1e-9)
 
 
 
 
 
66
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
67
+ trimmed_audio = trimmed_audio[:int(sr*0.75)] # max 0.75s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # --- Wav2Vec2 inference ---
70
+ input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values.to(device)
71
  with torch.no_grad():
72
  logits = model(input_values).logits
73
+ pred_ids = torch.argmax(logits, dim=-1)
74
+ transcription = processor.batch_decode(pred_ids)[0].strip()
75
 
76
+ # --- Quick confidence check ---
77
  probs = torch.softmax(logits, dim=-1)
78
+ if probs.max(dim=-1).values.mean().item() < 0.6:
79
+ return {"language": language, "transcription": "Low confidence", "correct": False}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # --- Single-letter optimization ---
82
+ if len(reference_text.strip()) == 1:
83
+ ref_ipa = transliterate_fn(reference_text.strip())
84
+ trans_ipa = transliterate_fn(transcription)
85
+ correct = ref_ipa == trans_ipa or reference_text.upper() == transcription.upper()
86
+ return {"language": language, "reference": reference_text, "transcription": transcription, "correct": correct}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ # --- Full phoneme alignment (for multi-letter words) ---
89
+ ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]
90
+ obs_phonemes = [list(transliterate_fn(word)) for word in transcription.split()]
91
+
92
+ results = []
93
+ for r, o in zip(ref_phonemes, obs_phonemes):
94
+ results.append({
95
+ "reference": ''.join(r),
96
+ "observed": ''.join(o),
97
+ "edit_distance": editdistance.eval(r, o)
98
+ })
 
 
 
 
 
 
99
 
100
+ return {"language": language, "reference_text": reference_text, "transcription": transcription, "word_alignment": results}
101
 
102
+ # --- Gradio UI ---
103
  def get_default_text(language):
104
+ return {"Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ", "English": "A"}.get(language, "")
 
 
 
105
 
106
  with gr.Blocks() as demo:
107
+ gr.Markdown("# Fast Multilingual Letter & Word Phoneme Analysis")
108
+ language = gr.Dropdown(["Arabic","English"], value="English", label="Language")
 
 
 
 
109
  reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
110
+ audio_input = gr.Audio(label="Record Audio", type="filepath")
111
  submit_btn = gr.Button("Analyze")
112
+ output = gr.JSON(label="Results")
113
+
114
+ language.change(fn=get_default_text, inputs=language, outputs=reference_text)
115
+ submit_btn.click(fn=analyze_phonemes, inputs=[language, reference_text, audio_input], outputs=output)
116
+
117
+ demo.launch()