KuyaToto committed on
Commit
e7078f9
·
verified ·
1 Parent(s): 9594d9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -48
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import WhisperProcessor, WhisperForConditionalGeneration, Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import librosa
4
  import torch
5
  import epitran
@@ -10,27 +10,24 @@ from jiwer import wer
10
  import json
11
  import string
12
  import eng_to_ipa as ipa
13
- import numpy as np # For normalization
14
 
15
- # Models: Use Whisper for English (better silence/noise handling), Wav2Vec2 for Arabic
16
  MODELS = {
17
  "Arabic": {
18
  "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
19
  "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
20
- "epitran": epitran.Epitran("ara-Arab"),
21
- "is_whisper": False
22
  },
23
  "English": {
24
- "processor": WhisperProcessor.from_pretrained("openai/whisper-tiny.en"),
25
- "model": WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en"),
26
- "epitran": epitran.Epitran("eng-Latn"),
27
- "is_whisper": True
28
  }
29
  }
30
 
31
  for lang in MODELS.values():
32
- if not lang["is_whisper"]:
33
- lang["model"].config.ctc_loss_reduction = "mean"
34
 
35
  def clean_phonemes(ipa_text):
36
  return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
@@ -60,7 +57,6 @@ def analyze_phonemes(language, reference_text, audio_file):
60
  processor = lang_models["processor"]
61
  model = lang_models["model"]
62
  epi = lang_models["epitran"]
63
- is_whisper = lang_models["is_whisper"]
64
 
65
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
66
 
@@ -74,9 +70,9 @@ def analyze_phonemes(language, reference_text, audio_file):
74
  if max_amp > 0:
75
  audio = audio / max_amp # Normalize to [-1, 1]
76
 
77
- # Trim silence (increase top_db to 30 for stricter noise removal)
78
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
79
- if len(trimmed_audio) < (sr * 0.1): # Too short = silence
80
  return json.dumps({
81
  "language": language,
82
  "reference_text": reference_text,
@@ -85,44 +81,51 @@ def analyze_phonemes(language, reference_text, audio_file):
85
  "metrics": {"message": "Audio appears silent or too noisy. Try speaking louder or in a quieter environment."}
86
  }, indent=2, ensure_ascii=False)
87
 
88
- # Cap to 1.5s
89
- max_duration = 1.5
90
  if len(trimmed_audio) > int(sr * max_duration):
91
  trimmed_audio = trimmed_audio[:int(sr * max_duration)]
92
 
93
- if is_whisper:
94
- # Whisper processing
95
- input_features = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_features
96
- with torch.no_grad():
97
- predicted_ids = model.generate(input_features)
98
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
99
- else:
100
- # Wav2Vec2 processing (for Arabic)
101
- input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values
102
- with torch.no_grad():
103
- logits = model(input_values).logits
104
- pred_ids = torch.argmax(logits, dim=-1)
105
- transcription = processor.batch_decode(pred_ids)[0].strip()
106
-
107
- # Confidence check (for Wav2Vec2; Whisper has internal VAD)
108
- if not is_whisper:
109
- probs = torch.softmax(logits, dim=-1)
110
- max_probs = probs.max(dim=-1).values.mean().item()
111
- if max_probs < 0.4: # Lower threshold for stricter filtering
112
- return json.dumps({
113
- "language": language,
114
- "reference_text": reference_text,
115
- "transcription": transcription,
116
- "word_alignment": [],
117
- "metrics": {"message": "Low confidence transcription (possible noise). Try again with clearer speech."}
118
- }, indent=2, ensure_ascii=False)
119
-
120
- obs_phonemes = [list(transliterate_fn(word)) for word in transcription.split()]
 
 
 
 
 
 
 
121
 
122
  results = {
123
  "language": language,
124
  "reference_text": reference_text,
125
- "transcription": transcription,
126
  "word_alignment": [],
127
  "metrics": {}
128
  }
@@ -169,7 +172,7 @@ def analyze_phonemes(language, reference_text, audio_file):
169
  phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
170
  word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
171
  word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
172
- text_wer = round(wer(reference_text, transcription) * 100, 2)
173
 
174
  results["metrics"] = {
175
  "word_accuracy": word_acc,
@@ -184,7 +187,7 @@ def analyze_phonemes(language, reference_text, audio_file):
184
  def get_default_text(language):
185
  return {
186
  "Arabic": "ููŽุจูุฃูŽูŠู‘ู ุขู„ูŽุงุกู ุฑูŽุจู‘ููƒูู…ูŽุง ุชููƒูŽุฐู‘ูุจูŽุงู†ู",
187
- "English": "The quick brown fox jumps over the lazy dog"
188
  }.get(language, "")
189
 
190
  with gr.Blocks() as demo:
@@ -192,7 +195,7 @@ with gr.Blocks() as demo:
192
  gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Tip: Speak clearly; silence or noise may cause errors.")
193
 
194
  with gr.Row():
195
- language = gr.Dropdown(["Arabic", "English"], label="Language", value="English") # Default to English
196
 
197
  reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
198
  audio_input = gr.Audio(label="Upload Audio File", type="filepath")
 
1
  import gradio as gr
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import librosa
4
  import torch
5
  import epitran
 
10
  import json
11
  import string
12
  import eng_to_ipa as ipa
13
+ import numpy as np
14
 
15
+ # Models: Wav2Vec2 for both Arabic and English
16
  MODELS = {
17
  "Arabic": {
18
  "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
19
  "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
20
+ "epitran": epitran.Epitran("ara-Arab")
 
21
  },
22
  "English": {
23
+ "processor": Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
24
+ "model": Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english"),
25
+ "epitran": epitran.Epitran("eng-Latn")
 
26
  }
27
  }
28
 
29
  for lang in MODELS.values():
30
+ lang["model"].config.ctc_loss_reduction = "mean"
 
31
 
32
  def clean_phonemes(ipa_text):
33
  return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa_text)
 
57
  processor = lang_models["processor"]
58
  model = lang_models["model"]
59
  epi = lang_models["epitran"]
 
60
 
61
  transliterate_fn = safe_transliterate_arabic if language == "Arabic" else transliterate_english
62
 
 
70
  if max_amp > 0:
71
  audio = audio / max_amp # Normalize to [-1, 1]
72
 
73
+ # Stricter silence trimming
74
  trimmed_audio, _ = librosa.effects.trim(audio, top_db=30)
75
+ if len(trimmed_audio) < (sr * 0.15):
76
  return json.dumps({
77
  "language": language,
78
  "reference_text": reference_text,
 
81
  "metrics": {"message": "Audio appears silent or too noisy. Try speaking louder or in a quieter environment."}
82
  }, indent=2, ensure_ascii=False)
83
 
84
+ # Cap to 0.75s for single letters
85
+ max_duration = 0.75
86
  if len(trimmed_audio) > int(sr * max_duration):
87
  trimmed_audio = trimmed_audio[:int(sr * max_duration)]
88
 
89
+ # Noise gate
90
+ noise_gate_threshold = 0.02
91
+ trimmed_audio[np.abs(trimmed_audio) < noise_gate_threshold] = 0
92
+
93
+ input_values = processor(trimmed_audio, sampling_rate=sr, return_tensors="pt").input_values
94
+
95
+ with torch.no_grad():
96
+ logits = model(input_values).logits
97
+ pred_ids = torch.argmax(logits, dim=-1)
98
+ transcription = processor.batch_decode(pred_ids)[0].strip()
99
+
100
+ # Stricter confidence check
101
+ probs = torch.softmax(logits, dim=-1)
102
+ max_probs = probs.max(dim=-1).values.mean().item()
103
+ if max_probs < 0.6:
104
+ return json.dumps({
105
+ "language": language,
106
+ "reference_text": reference_text,
107
+ "transcription": "No speech detected",
108
+ "word_alignment": [],
109
+ "metrics": {"message": "Low confidence transcription (possible noise). Try again with clearer speech."}
110
+ }, indent=2, ensure_ascii=False)
111
+
112
+ # Filter vowel-heavy or overly long transcriptions
113
+ transcription_clean = transcription.lower().replace("the", "").strip()
114
+ if len(transcription_clean) > 3 or re.match(r'^[aeiou]+$', transcription_clean):
115
+ return json.dumps({
116
+ "language": language,
117
+ "reference_text": reference_text,
118
+ "transcription": "No speech detected",
119
+ "word_alignment": [],
120
+ "metrics": {"message": "Detected noise or unclear speech. Try again with clear pronunciation."}
121
+ }, indent=2, ensure_ascii=False)
122
+
123
+ obs_phonemes = [list(transliterate_fn(word)) for word in transcription_clean.split()]
124
 
125
  results = {
126
  "language": language,
127
  "reference_text": reference_text,
128
+ "transcription": transcription_clean or "No speech detected",
129
  "word_alignment": [],
130
  "metrics": {}
131
  }
 
172
  phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
173
  word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
174
  word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
175
+ text_wer = round(wer(reference_text, transcription_clean or "") * 100, 2)
176
 
177
  results["metrics"] = {
178
  "word_accuracy": word_acc,
 
187
  def get_default_text(language):
188
  return {
189
  "Arabic": "ููŽุจูุฃูŽูŠู‘ู ุขู„ูŽุงุกู ุฑูŽุจู‘ููƒูู…ูŽุง ุชููƒูŽุฐู‘ูุจูŽุงู†ู",
190
+ "English": "A"
191
  }.get(language, "")
192
 
193
  with gr.Blocks() as demo:
 
195
  gr.Markdown("Compare audio pronunciation with reference text at phoneme level. Tip: Speak clearly; silence or noise may cause errors.")
196
 
197
  with gr.Row():
198
+ language = gr.Dropdown(["Arabic", "English"], label="Language", value="English")
199
 
200
  reference_text = gr.Textbox(label="Reference Text", value=get_default_text("English"))
201
  audio_input = gr.Audio(label="Upload Audio File", type="filepath")