KuyaToto committed (verified)
Commit fa83439 · 1 Parent(s): daa79d8

Update app.py

Files changed (1):
  1. app.py +17 -43
app.py CHANGED
@@ -10,13 +10,11 @@ import orjson
 import string
 import eng_to_ipa as ipa
 import numpy as np
-from functools import lru_cache
-from collections import defaultdict
 
-# Device setup
+# --- Device setup ---
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# WordMap
+# --- WordMap ---
 WORD_MAP = {
     'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
     'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
@@ -46,32 +44,15 @@ WORD_MAP = {
     'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'}
 }
 
-# Model configuration
-MODELS = {
-    "English": {
-        "processor_path": "facebook/wav2vec2-base-960h",
-        "model_path": "facebook/wav2vec2-base-960h",
-        "epitran": lambda: epitran.Epitran("eng-Latn"),
-        "processor": None,
-        "model": None,
-        "epitran_instance": None
-    }
-}
-
-@lru_cache(maxsize=1)
-def load_model(language):
-    if language not in MODELS:
-        raise ValueError(f"Unsupported language: {language}")
-    if MODELS[language]["processor"] is None:
-        MODELS[language]["processor"] = Wav2Vec2Processor.from_pretrained(MODELS[language]["processor_path"])
-        MODELS[language]["model"] = Wav2Vec2ForCTC.from_pretrained(MODELS[language]["model_path"]).to(device).eval()
-        MODELS[language]["epitran_instance"] = MODELS[language]["epitran"]()
+# --- Load model once at startup ---
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
+epi = epitran.Epitran("eng-Latn")
 
-@lru_cache(maxsize=1000)
+# --- Helper functions ---
 def clean_phonemes(ipa_text):
     return re.sub(r'[^\w\s]', '', ipa_text)
 
-@lru_cache(maxsize=1000)
 def transliterate_english(word):
     try:
         word = word.lower().translate(str.maketrans('', '', string.punctuation))
@@ -92,22 +73,16 @@ def find_closest_word(transcription, reference_word):
     similarity = round((1 - distances[closest_word] / max(1, max_len)) * 100, 2)
     return closest_word, similarity
 
+# --- Main analysis function ---
 def analyze_phonemes(language, reference_text, audio_input):
     try:
-        load_model(language)
-        lang_models = MODELS[language]
-        processor = lang_models["processor"]
-        model = lang_models["model"]
-
-        # Handle audio input (numpy array from browser recording)
-        if isinstance(audio_input, tuple) or isinstance(audio_input, list):
+        # Handle audio input
+        if isinstance(audio_input, (tuple, list)):
             audio, sr = audio_input[0], audio_input[1]
-        else:  # fallback to file path
+        else:
             audio, sr = librosa.load(audio_input, sr=16000, mono=True)
-
-        max_amp = np.max(np.abs(audio))
-        if max_amp > 0:
-            audio = audio / max_amp
+        audio = audio.astype(np.float32)
+        audio = audio / max(1e-9, np.max(np.abs(audio)))
 
         trimmed_audio, _ = librosa.effects.trim(audio, top_db=25)
         if len(trimmed_audio) < 2400:
@@ -119,9 +94,9 @@ def analyze_phonemes(language, reference_text, audio_input):
                 "metrics": {"message": "Audio too short or silent."}
             }).decode()
 
-        trimmed_audio = trimmed_audio[:12000]  # 0.75s cap
-
+        trimmed_audio = trimmed_audio[:48000]  # up to 3 seconds
         input_values = processor(trimmed_audio, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)
+
         with torch.no_grad():
            logits = model(input_values).logits
            pred_ids = torch.argmax(logits, dim=-1)
@@ -202,11 +177,10 @@ def analyze_phonemes(language, reference_text, audio_input):
                 "metrics": {"message": f"Error: {str(e)}"}
             }).decode()
 
-
+# --- Gradio UI ---
 def get_default_text(language):
     return "A" if language == "English" else ""
 
-
 with gr.Blocks() as demo:
     gr.Markdown("# Multilingual Phoneme Alignment Analysis")
     gr.Markdown("Compare audio pronunciation with reference text at phoneme level.")
@@ -214,7 +188,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         language = gr.Dropdown(["English"], label="Language", value="English")
         reference_text = gr.Textbox(label="Reference Text", value="A")
-        audio_input = gr.Audio(label="Record Audio", type="numpy")  # <-- change here
+        audio_input = gr.Audio(label="Record Audio", type="numpy")
         submit_btn = gr.Button("Analyze")
         output = gr.JSON(label="Phoneme Alignment Results")
 
 
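A few notes and sketches on the changes follow. First, since the commit drops the lazy load_model/lru_cache path in favor of loading the processor, model, and Epitran instance once at module import, a quick standalone smoke test (not part of the commit; names mirror the diff) confirms the eager path works before the Gradio app starts:

# Smoke test for the eager-loading path (a sketch, not in app.py).
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()

dummy = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = processor(dummy, sampling_rate=16000, return_tensors="pt").input_values.to(device)
with torch.no_grad():
    logits = model(inputs).logits
print(logits.shape)  # (1, frames, vocab_size); confirms the forward pass runs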
50
+ epi = epitran.Epitran("eng-Latn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # --- Helper functions ---
53
  def clean_phonemes(ipa_text):
54
  return re.sub(r'[^\w\s]', '', ipa_text)
55
 
 
56
  def transliterate_english(word):
57
  try:
58
  word = word.lower().translate(str.maketrans('', '', string.punctuation))
 
73
  similarity = round((1 - distances[closest_word] / max(1, max_len)) * 100, 2)
74
  return closest_word, similarity
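The similarity line kept as context in the third hunk normalizes an edit distance into a 0-100 score. A minimal sketch of that computation, assuming distances holds Levenshtein distances between the transcription and each candidate word (the full find_closest_word body is outside the hunks shown):

# Levenshtein distance via dynamic programming (illustrative helper; the
# actual implementation in app.py is outside the hunks shown).
def levenshtein(a, b):
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                  # deletion
                           cur[j - 1] + 1,               # insertion
                           prev[j - 1] + (ca != cb)))    # substitution
        prev = cur
    return prev[-1]

reference, hypothesis = "apple", "appel"
dist = levenshtein(reference, hypothesis)
max_len = max(len(reference), len(hypothesis))
similarity = round((1 - dist / max(1, max_len)) * 100, 2)
print(similarity)  # 60.0: two substitutions over five characters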
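On the new input handling: Gradio's type="numpy" audio component delivers a (sample_rate, data) tuple, typically int16 at the device's native rate, so the tuple branch in the diff appears to unpack the waveform and rate in the opposite order from Gradio's documented convention, and nothing resamples browser audio to the 16 kHz that wav2vec2 expects. A more defensive sketch (hypothetical helper, not in the commit):

# Hypothetical helper: normalize any Gradio audio payload to 16 kHz mono float32.
import librosa
import numpy as np

def to_16k_mono(audio_input):
    if isinstance(audio_input, (tuple, list)):
        sr, audio = audio_input                    # Gradio order: rate first
        audio = audio.astype(np.float32)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)             # downmix stereo to mono
        if sr != 16000:                            # wav2vec2 expects 16 kHz
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    else:
        audio, _ = librosa.load(audio_input, sr=16000, mono=True)
    audio /= max(1e-9, np.max(np.abs(audio)))      # peak-normalize, guard /0
    return audio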
 
 
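At 16 kHz, the length checks work out as follows: 2400 samples is a 0.15 s minimum after trimming, and the cap raised by this commit goes from 12000 samples (0.75 s) to 48000 samples (3 s). The decoding itself is standard greedy CTC; the step after pred_ids is outside the hunks shown, but the usual wav2vec2 pattern is:

# Greedy CTC decoding (sketch of the usual pattern; the line after pred_ids
# is outside the hunks shown in this diff).
with torch.no_grad():
    logits = model(input_values).logits        # (batch, frames, vocab)
pred_ids = torch.argmax(logits, dim=-1)        # best token per frame
# batch_decode collapses repeated tokens and removes CTC blanks
transcription = processor.batch_decode(pred_ids)[0]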
 
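Finally, the click wiring for these components is outside the hunks shown; based on the analyze_phonemes signature in the diff, it presumably looks like:

# Presumed wiring (outside the hunks shown), matching analyze_phonemes'
# (language, reference_text, audio_input) signature and the JSON output.
submit_btn.click(
    fn=analyze_phonemes,
    inputs=[language, reference_text, audio_input],
    outputs=output,
)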