KuyaToto commited on
Commit
c173e30
·
verified ·
1 Parent(s): 883f9e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -89
app.py CHANGED
@@ -7,120 +7,114 @@ import re
7
  import difflib
8
  import editdistance
9
  from jiwer import wer
10
- import json
11
  import string
12
  import eng_to_ipa as ipa
13
  import numpy as np
 
14
 
15
- # Models: Wav2Vec2 for both Arabic and English
 
 
 
16
  MODELS = {
17
  "Arabic": {
18
- "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
19
- "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
20
- "epitran": epitran.Epitran("ara-Arab")
 
 
 
21
  },
22
  "English": {
23
- "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
24
- "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
25
- "epitran": epitran.Epitran("eng-Latn")
 
 
 
26
  }
27
  }
28
 
29
- for lang in MODELS.values():
30
- lang["model"].config.ctc_loss_reduction = "mean"
 
 
 
 
31
 
 
32
  def clean_phonemes(ipa_text):
33
- return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
34
 
35
- def safe_transliterate_arabic(epi, word):
 
36
  try:
37
  word = word.strip()
38
- ipa = epi.transliterate(word)
39
- if not ipa.strip():
40
- raise ValueError("Empty IPA string")
41
- return clean_phonemes(ipa)
42
- except Exception as e:
43
- print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
44
  return ""
45
 
 
46
  def transliterate_english(word):
47
  try:
48
  word = word.lower().translate(str.maketrans('', '', string.punctuation))
49
- ipa_text = ipa.convert(word)
50
- return clean_phonemes(ipa_text)
51
- except Exception as e:
52
- print(f"[Warning] English IPA conversion failed for '{word}': {e}")
53
  return ""
54
 
55
  def analyze_phonemes(language, reference_text, audio_file):
 
56
  lang_models = MODELS[language]
57
  processor = lang_models["processor"]
58
  model = lang_models["model"]
59
- epi = lang_models["epitran"]
60
 
61
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
62
 
63
- ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]
64
-
65
- # Load audio
66
- audio, sr = librosa.load(audio_file, sr=16000)
67
 
68
- # Normalize volume
 
69
  max_amp = np.max(np.abs(audio))
70
  if max_amp > 0:
71
- audio = audio / max_amp # Normalize to [-1, 1]
72
 
73
- # Stricter silence trimming
74
- trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
75
- if len(trimmed_audio) < (sr * 0.15):
76
- return json.dumps({
77
  "language": language,
78
  "reference_text": reference_text,
79
  "transcription": "No speech detected",
80
  "word_alignment": [],
81
- "metrics": {"message": "Audio appears silent or too noisy. Try speaking louder or in a quieter environment."}
82
- }, indent=2, ensure_ascii=False)
83
 
84
- # Cap to 0.75s for single letters
85
- max_duration = 0.75
86
- if len(trimmed_audio) > int(sr * max_duration):
87
- trimmed_audio = trimmed_audio[:int(sr * max_duration)]
88
 
89
- # Noise gate
90
- noise_gate_threshold = 0.02
91
- trimmed_audio[np.abs(trimmed_audio) < noise_gate_threshold] = 0
92
-
93
- input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values
94
 
95
  with torch.no_grad():
96
  logits = model(input_values).logits
97
  pred_ids = torch.argmax(logits, dim=-1)
98
- transcription = processor.batch_decode(pred_ids)[0].strip()
99
 
100
- # Stricter confidence check
101
  probs = torch.softmax(logits, dim=-1)
102
  max_probs = probs.max(dim=-1).values.mean().item()
103
- if max_probs < 0.6:
104
- return json.dumps({
105
- "language": language,
106
- "reference_text": reference_text,
107
- "transcription": "No speech detected",
108
- "word_alignment": [],
109
- "metrics": {"message": "Low confidence transcription (possible noise). Try again with clearer speech."}
110
- }, indent=2, ensure_ascii=False)
111
-
112
- # Filter vowel-heavy or overly long transcriptions
113
- transcription_clean = transcription.lower().replace("the", "").strip()
114
- if len(transcription_clean) > 3 or re.match(r'^[aeiou]+$', transcription_clean):
115
- return json.dumps({
116
  "language": language,
117
  "reference_text": reference_text,
118
  "transcription": "No speech detected",
119
  "word_alignment": [],
120
- "metrics": {"message": "Detected noise or unclear speech. Try again with clear pronunciation."}
121
- }, indent=2, ensure_ascii=False)
122
 
123
- obs_phonemes = [list(transliterate_fn(word)) for word in transcription_clean.split()]
124
 
125
  results = {
126
  "language": language,
@@ -142,17 +136,10 @@ def analyze_phonemes(language, reference_text, audio_file):
142
  acc = round((1 - edits / max(1, len(ref))) * 100, 2)
143
 
144
  matcher = difflib.SequenceMatcher(None, ref, obs)
145
- ops = matcher.get_opcodes()
146
- error_details = []
147
- for tag, i1, i2, j1, j2 in ops:
148
- ref_seg = ''.join(ref[i1:i2]) or '-'
149
- obs_seg = ''.join(obs[j1:j2]) or '-'
150
- if tag != 'equal':
151
- error_details.append({
152
- "type": tag.upper(),
153
- "reference": ref_seg,
154
- "observed": obs_seg
155
- })
156
 
157
  results["word_alignment"].append({
158
  "word_index": i,
@@ -182,7 +169,7 @@ def analyze_phonemes(language, reference_text, audio_file):
182
  "asr_word_error_rate": text_wer
183
  }
184
 
185
- return json.dumps(results, indent=2, ensure_ascii=False)
186
 
187
  def get_default_text(language):
188
  return {
@@ -192,28 +179,17 @@ def get_default_text(language):
192
 
193
  with gr.Blocks() as demo:
194
  gr.Markdown("# Multilingual Phoneme Alignment Analysis")
195
- gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Tip: Speak clearly; silence or noise may cause errors.")
196
 
197
  with gr.Row():
198
  language = gr.Dropdown(["Arabic", "English"], label="Language", value="English")
199
 
200
- reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
201
  audio_input = gr.Audio(label="Upload Audio File", type="filepath")
202
  submit_btn = gr.Button("Analyze")
203
  output = gr.JSON(label="Phoneme Alignment Results")
204
 
205
- language.change(
206
- fn=get_default_text,
207
- inputs=language,
208
- outputs=reference_text,
209
- api_name="/get_default_text"
210
- )
211
-
212
- submit_btn.click(
213
- fn=analyze_phonemes,
214
- inputs=[language, reference_text, audio_input],
215
- outputs=output,
216
- api_name="/analyze_phonemes"
217
- )
218
 
219
  demo.launch()
 
7
  import difflib
8
  import editdistance
9
  from jiwer import wer
10
+ import orjson
11
  import string
12
  import eng_to_ipa as ipa
13
  import numpy as np
14
+ from functools import lru_cache
15
 
16
+ # Check for GPU
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ # Lazy-load models
20
  MODELS = {
21
  "Arabic": {
22
+ "processor_path": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
23
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
24
+ "epitran": lambda: epitran.Epitran("ara-Arab"),
25
+ "processor": None,
26
+ "model": None,
27
+ "epitran_instance": None
28
  },
29
  "English": {
30
+ "processor_path": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
31
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
32
+ "epitran": lambda: epitran.Epitran("eng-Latn"),
33
+ "processor": None,
34
+ "model": None,
35
+ "epitran_instance": None
36
  }
37
  }
38
 
39
+ def load_model(language):
40
+ if MODELS[language]["processor"] is None:
41
+ MODELS[language]["processor"] = Wav2Vec2Processor.from_pretrained(MODELS[language]["processor_path"])
42
+ MODELS[language]["model"] = Wav2Vec2ForCTC.from_pretrained(MODELS[language]["model_path"]).to(device)
43
+ MODELS[language]["model"].config.ctc_loss_reduction = "mean"
44
+ MODELS[language]["epitran_instance"] = MODELS[language]["epitran"]()
45
 
46
+ @lru_cache(maxsize=1000)
47
  def clean_phonemes(ipa_text):
48
+ return re.sub(r'[^\w\s]', '', ipa_text)
49
 
50
+ @lru_cache(maxsize=1000)
51
+ def safe_transliterate_arabic(word):
52
  try:
53
  word = word.strip()
54
+ ipa = MODELS["Arabic"]["epitran_instance"].transliterate(word)
55
+ return clean_phonemes(ipa) if ipa.strip() else ""
56
+ except Exception:
 
 
 
57
  return ""
58
 
59
+ @lru_cache(maxsize=1000)
60
  def transliterate_english(word):
61
  try:
62
  word = word.lower().translate(str.maketrans('', '', string.punctuation))
63
+ return clean_phonemes(ipa.convert(word))
64
+ except Exception:
 
 
65
  return ""
66
 
67
  def analyze_phonemes(language, reference_text, audio_file):
68
+ load_model(language)
69
  lang_models = MODELS[language]
70
  processor = lang_models["processor"]
71
  model = lang_models["model"]
 
72
 
73
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
74
 
75
+ ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split() if transliterate_fn(word)]
 
 
 
76
 
77
+ # Load and preprocess audio
78
+ audio, _ = librosa.load(audio_file, sr=16000)
79
  max_amp = np.max(np.abs(audio))
80
  if max_amp > 0:
81
+ audio = audio / max_amp
82
 
83
+ trimmed_audio, _ = librosa.effects.trim(audio, top_db=25)
84
+ if len(trimmed_audio) < 2400: # 0.15s at 16kHz
85
+ return orjson.dumps({
 
86
  "language": language,
87
  "reference_text": reference_text,
88
  "transcription": "No speech detected",
89
  "word_alignment": [],
90
+ "metrics": {"message": "Audio too short or silent."}
91
+ }).decode()
92
 
93
+ # Cap audio length to 0.75s
94
+ if len(trimmed_audio) > 12000:
95
+ trimmed_audio = trimmed_audio[:12000]
 
96
 
97
+ input_values = processor(trimmed_audio, sampling_rate=16000, return_tensors="pt").input_values.to(device)
 
 
 
 
98
 
99
  with torch.no_grad():
100
  logits = model(input_values).logits
101
  pred_ids = torch.argmax(logits, dim=-1)
102
+ transcription = processor.batch_decode(pred_ids)[0].strip().lower()
103
 
104
+ # Combined validation
105
  probs = torch.softmax(logits, dim=-1)
106
  max_probs = probs.max(dim=-1).values.mean().item()
107
+ transcription_clean = transcription.replace("the", "").strip()
108
+ if max_probs < 0.6 or len(transcription_clean) > 3 or re.match(r'^[aeiou]+$', transcription_clean):
109
+ return orjson.dumps({
 
 
 
 
 
 
 
 
 
 
110
  "language": language,
111
  "reference_text": reference_text,
112
  "transcription": "No speech detected",
113
  "word_alignment": [],
114
+ "metrics": {"message": "Unclear or noisy speech."}
115
+ }).decode()
116
 
117
+ obs_phonemes = [list(transliterate_fn(word)) for word in transcription_clean.split() if transliterate_fn(word)]
118
 
119
  results = {
120
  "language": language,
 
136
  acc = round((1 - edits / max(1, len(ref))) * 100, 2)
137
 
138
  matcher = difflib.SequenceMatcher(None, ref, obs)
139
+ error_details = [
140
+ {"type": tag.upper(), "reference": ''.join(ref[i1:i2]) or '-', "observed": ''.join(obs[j1:j2]) or '-'}
141
+ for tag, i1, i2, j1, j2 in matcher.get_opcodes() if tag != 'equal'
142
+ ]
 
 
 
 
 
 
 
143
 
144
  results["word_alignment"].append({
145
  "word_index": i,
 
169
  "asr_word_error_rate": text_wer
170
  }
171
 
172
+ return orjson.dumps(results).decode()
173
 
174
  def get_default_text(language):
175
  return {
 
179
 
180
  with gr.Blocks() as demo:
181
  gr.Markdown("# Multilingual Phoneme Alignment Analysis")
182
+ gr.Markdown("Compare audio pronunciation with reference text at phoneme level.")
183
 
184
  with gr.Row():
185
  language = gr.Dropdown(["Arabic", "English"], label="Language", value="English")
186
 
187
+ reference_text = gr.Textbox(label="Reference Text", value="A")
188
  audio_input = gr.Audio(label="Upload Audio File", type="filepath")
189
  submit_btn = gr.Button("Analyze")
190
  output = gr.JSON(label="Phoneme Alignment Results")
191
 
192
+ language.change(fn=get_default_text, inputs=language, outputs=reference_text)
193
+ submit_btn.click(fn=analyze_phonemes, inputs=[language, reference_text, audio_input], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  demo.launch()