Aleksei Žavoronkov committed on
Commit
8cf218e
·
1 Parent(s): ed26f9c

update model architecture to the latest

Browse files
Files changed (5) hide show
  1. app.py +37 -74
  2. constants.py +3 -3
  3. gop_model.py +209 -153
  4. models.py +40 -20
  5. utils.py +62 -32
app.py CHANGED
@@ -1,58 +1,60 @@
 
 
1
  import gradio as gr
2
- from constants import ALL_PHONEMES, QUALITY_MODEL_REPO_ID, DURATION_MODEL_REPO_ID
 
3
  from utils import load_model_and_processor, run_inference, validate_phonemes
4
- from multiprocessing import Process, Queue, set_start_method
5
- import logging
6
 
7
  logging.basicConfig(
8
  level=logging.INFO,
9
  format="%(asctime)s - %(levelname)s - %(message)s",
10
- handlers=[logging.StreamHandler()]
11
  )
12
  logger = logging.getLogger(__name__)
13
 
14
- logger.info("Loading models into memory globally...")
15
- phoneme_model, phoneme_processor = load_model_and_processor(QUALITY_MODEL_REPO_ID)
16
- duration_model, duration_processor = load_model_and_processor(DURATION_MODEL_REPO_ID)
17
- logger.info("Models loaded successfully.")
18
 
19
  css = """
20
  .phoneme-scores { display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; }
21
  .phoneme-container { text-align: center; padding: 10px; border: 1px solid #ddd; border-radius: 8px; }
22
  .phoneme { font-size: 1.5em; font-weight: bold; margin-bottom: 5px; }
23
  .score { padding: 8px 12px; border-radius: 5px; color: white; font-weight: bold; }
24
- .good { background-color: #28a745; } /* Green */
25
- .medium { background-color: #ffc107; } /* Yellow */
26
- .bad { background-color: #dc3545; } /* Red */
27
  """
28
 
29
 
30
  def get_score_class(score, score_type):
31
  if score_type == "quality":
32
- if score == 1: return 'good'
33
- if score == 2: return 'medium'
34
- return 'bad'
35
- else:
36
- return 'good' if score == 1 else 'bad'
 
37
 
38
 
39
  def generate_html_output(result, score_type):
40
  if isinstance(result, str):
41
  return result
42
 
43
- predicted_scores, tokens, token_lengths = result
44
- html_output = f"<div class='phoneme-section'><h3 class='scores-title'>{'Quality Scores' if score_type == 'quality' else 'Duration Scores'}</h3></div><div class='phoneme-scores'>"
45
 
46
- for i in range(token_lengths[0]):
47
- token = tokens[i]
48
- score = predicted_scores[0][i]
49
- score = int(score) + 1
50
- score_class = get_score_class(score, score_type)
51
 
 
 
 
52
  html_output += f"""
53
  <div class='phoneme-container'>
54
  <div class='phoneme'>{token}</div>
55
- <div class='score {score_class}'>{score}</div>
56
  </div>
57
  """
58
 
@@ -60,24 +62,7 @@ def generate_html_output(result, score_type):
60
  return html_output
61
 
62
 
63
- def inference_wrapper(model_type, model, processor, audio_path, transcript, queue):
64
- result = run_inference(audio_path, transcript, model, processor)
65
-
66
- if isinstance(result, str):
67
- queue.put((model_type, result))
68
- return
69
-
70
- predicted_scores, tokens, token_lengths = result
71
-
72
- scores_list = predicted_scores.cpu().tolist()
73
- lengths_list = token_lengths.cpu().tolist()
74
-
75
- safe_result = (scores_list, tokens, lengths_list)
76
-
77
- queue.put((model_type, safe_result))
78
-
79
-
80
- def score_phonemes_in_parallel(phoneme_text, audio_file):
81
  if audio_file is None:
82
  return "<p style='text-align:center; color:red;'>Please upload a .wav audio file.</p>", ""
83
 
@@ -85,27 +70,9 @@ def score_phonemes_in_parallel(phoneme_text, audio_file):
85
  if phonemes_validation_error:
86
  return phonemes_validation_error, ""
87
 
88
- results_queue = Queue()
89
-
90
- quality_process = Process(
91
- target=inference_wrapper,
92
- args=("quality", phoneme_model, phoneme_processor, audio_file, phoneme_text, results_queue)
93
- )
94
- duration_process = Process(
95
- target=inference_wrapper,
96
- args=("duration", duration_model, duration_processor, audio_file, phoneme_text, results_queue)
97
- )
98
-
99
- quality_process.start()
100
- duration_process.start()
101
-
102
- quality_process.join()
103
- duration_process.join()
104
-
105
- results = {}
106
- while not results_queue.empty():
107
- key, value = results_queue.get()
108
- results[key] = value
109
 
110
  quality_result = results.get("quality")
111
  duration_result = results.get("duration")
@@ -120,8 +87,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
120
  gr.Markdown(
121
  """
122
  # Phoneme Pronunciation and Duration Scorer
123
- Enter the phonemes directly into the text box (space-separated).
124
- Then, upload a `.wav` file or record the audio of the pronounced word.
 
125
  The application will provide a pronunciation (quality) and duration score for each phoneme.
126
 
127
  Scores legend:
@@ -144,7 +112,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
144
  ["aja (L2 speaker)", "a j a", "./audio/L2/03ac-e45b-ec8a-6fa0_aja_take1.wav"],
145
  ["sõpra (L2 speaker)", "s õ pp r a", "./audio/L2/4071-0c77-e1d3-9587_sõpra_take1.wav"],
146
  ],
147
- inputs=[word_input, phoneme_text_input, audio_input]
148
  )
149
 
150
  gr.Markdown("---")
@@ -155,16 +123,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
155
  duration_output_html = gr.HTML()
156
 
157
  btn.click(
158
- fn=score_phonemes_in_parallel,
159
  inputs=[phoneme_text_input, audio_input],
160
- outputs=[phoneme_output_html, duration_output_html]
 
161
  )
162
 
163
  if __name__ == "__main__":
164
- try:
165
- set_start_method("fork", force=True)
166
- logger.info("Multiprocessing start method set to 'fork'.")
167
- except RuntimeError:
168
- logger.warning("Start method has already been set.")
169
-
170
  demo.queue(default_concurrency_limit=2).launch()
 
1
+ import logging
2
+
3
  import gradio as gr
4
+
5
+ from constants import ALL_PHONEMES, MODEL_REPO_ID
6
  from utils import load_model_and_processor, run_inference, validate_phonemes
 
 
7
 
8
  logging.basicConfig(
9
  level=logging.INFO,
10
  format="%(asctime)s - %(levelname)s - %(message)s",
11
+ handlers=[logging.StreamHandler()],
12
  )
13
  logger = logging.getLogger(__name__)
14
 
15
+ logger.info("Loading model into memory globally")
16
+ model, processor = load_model_and_processor(MODEL_REPO_ID)
17
+ logger.info("Model loaded successfully")
 
18
 
19
  css = """
20
  .phoneme-scores { display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; }
21
  .phoneme-container { text-align: center; padding: 10px; border: 1px solid #ddd; border-radius: 8px; }
22
  .phoneme { font-size: 1.5em; font-weight: bold; margin-bottom: 5px; }
23
  .score { padding: 8px 12px; border-radius: 5px; color: white; font-weight: bold; }
24
+ .good { background-color: #28a745; }
25
+ .medium { background-color: #ffc107; }
26
+ .bad { background-color: #dc3545; }
27
  """
28
 
29
 
30
def get_score_class(score, score_type):
    """Map a 1-based score to a CSS class name.

    For "quality" scoring, 1 -> "good", 2 -> "medium", anything else -> "bad".
    For any other score type (duration), 1 -> "good", everything else -> "bad".
    """
    if score_type == "quality":
        quality_classes = {1: "good", 2: "medium"}
        return quality_classes.get(score, "bad")
    return "good" if score == 1 else "bad"
38
 
39
 
40
  def generate_html_output(result, score_type):
41
  if isinstance(result, str):
42
  return result
43
 
44
+ if not result:
45
+ return "<p style='text-align:center; color:red;'>No scores were produced.</p>"
46
 
47
+ predicted_scores, tokens = result
48
+ title = "Quality Scores" if score_type == "quality" else "Duration Scores"
49
+ html_output = f"<div class='phoneme-section'><h3 class='scores-title'>{title}</h3></div><div class='phoneme-scores'>"
 
 
50
 
51
+ for token, score in zip(tokens, predicted_scores):
52
+ display_score = int(score) + 1
53
+ score_class = get_score_class(display_score, score_type)
54
  html_output += f"""
55
  <div class='phoneme-container'>
56
  <div class='phoneme'>{token}</div>
57
+ <div class='score {score_class}'>{display_score}</div>
58
  </div>
59
  """
60
 
 
62
  return html_output
63
 
64
 
65
+ def score_phonemes(phoneme_text, audio_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if audio_file is None:
67
  return "<p style='text-align:center; color:red;'>Please upload a .wav audio file.</p>", ""
68
 
 
70
  if phonemes_validation_error:
71
  return phonemes_validation_error, ""
72
 
73
+ results = run_inference(audio_file, phoneme_text, model, processor)
74
+ if isinstance(results, str):
75
+ return results, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  quality_result = results.get("quality")
78
  duration_result = results.get("duration")
 
87
  gr.Markdown(
88
  """
89
  # Phoneme Pronunciation and Duration Scorer
90
+ Enter phonemes directly into the text box, separated by spaces.
91
+ Use `|` between words if you want to score a multi-word sequence.
92
+ Then upload a `.wav` file or record the audio of the pronounced word.
93
  The application will provide a pronunciation (quality) and duration score for each phoneme.
94
 
95
  Scores legend:
 
112
  ["aja (L2 speaker)", "a j a", "./audio/L2/03ac-e45b-ec8a-6fa0_aja_take1.wav"],
113
  ["sõpra (L2 speaker)", "s õ pp r a", "./audio/L2/4071-0c77-e1d3-9587_sõpra_take1.wav"],
114
  ],
115
+ inputs=[word_input, phoneme_text_input, audio_input],
116
  )
117
 
118
  gr.Markdown("---")
 
123
  duration_output_html = gr.HTML()
124
 
125
  btn.click(
126
+ fn=score_phonemes,
127
  inputs=[phoneme_text_input, audio_input],
128
+ outputs=[phoneme_output_html, duration_output_html],
129
+ api_name=False,
130
  )
131
 
132
  if __name__ == "__main__":
 
 
 
 
 
 
133
  demo.queue(default_concurrency_limit=2).launch()
constants.py CHANGED
@@ -1,5 +1,4 @@
1
- QUALITY_MODEL_REPO_ID = "alzavo/sayest-quality"
2
- DURATION_MODEL_REPO_ID = "alzavo/sayest-duration"
3
  SAMPLING_RATE = 16000
4
  MONO_CHANNEL = 1
5
  MAX_AUDIO_DURATION_SECONDS = 5
@@ -94,4 +93,5 @@ ALL_PHONEMES = {
94
  "u:",
95
  "w",
96
  "j:j",
97
- }
 
 
1
+ MODEL_REPO_ID = "alzavo/sayest-latest"
 
2
  SAMPLING_RATE = 16000
3
  MONO_CHANNEL = 1
4
  MAX_AUDIO_DURATION_SECONDS = 5
 
93
  "u:",
94
  "w",
95
  "j:j",
96
+ "|",
97
+ }
gop_model.py CHANGED
@@ -12,9 +12,6 @@ from models import OrdinalLogLoss
12
 
13
 
14
  class GOPWav2Vec2Config(PretrainedConfig):
15
- """
16
- Configuration for GOP-enhanced model that wraps a Wav2Vec2ForCTC backbone.
17
- """
18
  model_type = "gop-wav2vec2"
19
 
20
  def __init__(
@@ -33,6 +30,7 @@ class GOPWav2Vec2Config(PretrainedConfig):
33
  unk_id: Optional[int] = None,
34
  bos_id: Optional[int] = None,
35
  eos_id: Optional[int] = None,
 
36
  token_id_vocab: Optional[List[int]] = None,
37
  ctc_config: Optional[dict] = None,
38
  **kwargs,
@@ -57,21 +55,17 @@ class GOPWav2Vec2Config(PretrainedConfig):
57
  self.unk_id = unk_id
58
  self.bos_id = bos_id
59
  self.eos_id = eos_id
 
60
  self.token_id_vocab = token_id_vocab
61
  self.ctc_config = ctc_config
62
 
63
 
64
  class GOPPhonemeClassifier(PreTrainedModel):
65
- """
66
- GOP classifier that wraps a pretrained Wav2Vec2ForCTC backbone.
67
- Computes per-phoneme scores using GOP-derived features + a small Transformer + classifier head.
68
- """
69
-
70
  config_class = GOPWav2Vec2Config
71
-
72
  def __init__(self, config: GOPWav2Vec2Config, load_pretrained_backbone: bool = False):
73
  super().__init__(config)
74
-
75
  if config.ctc_config is not None:
76
  backbone_config = Wav2Vec2Config.from_dict(config.ctc_config)
77
  elif config.base_model_name_or_path is not None:
@@ -82,35 +76,43 @@ class GOPPhonemeClassifier(PreTrainedModel):
82
  self.ctc_model = Wav2Vec2ForCTC(backbone_config)
83
  self.config.ctc_config = backbone_config.to_dict()
84
 
85
- # Special ids
86
  self.blank_id = config.pad_id
87
  self.unk_id = config.unk_id
88
  self.bos_id = config.bos_id
89
  self.eos_id = config.eos_id
90
  self.pad_id = self.blank_id
 
91
 
92
  special_ids = {self.blank_id, self.unk_id, self.bos_id, self.eos_id, self.pad_id}
93
  self.special_ids = {i for i in special_ids if i is not None}
94
 
95
  vocab_size = int(self.ctc_model.config.vocab_size)
96
  self.token_id_vocab = (
97
- config.token_id_vocab if config.token_id_vocab is not None else [i for i in range(vocab_size) if i not in self.special_ids]
 
 
98
  )
99
 
100
  self.gop_feature_dim = 1 + len(self.token_id_vocab) + 1
101
  self.embedding_dim = int(config.gop_embedding_dim)
102
- self.token_embedding = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=self.pad_id if self.pad_id is not None else 0)
 
 
 
 
103
  self.combined_feature_dim = self.gop_feature_dim + self.embedding_dim
104
-
105
- enc_layer = nn.TransformerEncoderLayer(
106
- d_model=self.combined_feature_dim,
107
- nhead=config.gop_transformer_nhead,
108
- dim_feedforward=config.gop_transformer_dim_feedforward,
109
- dropout=config.gop_transformer_dropout,
110
- activation=F.relu,
 
 
 
111
  batch_first=True,
112
  )
113
- self.gop_transformer_encoder = nn.TransformerEncoder(enc_layer, num_layers=config.gop_transformer_nlayers)
114
 
115
  head_label_config = getattr(config, "gop_head_labels", None)
116
  if head_label_config is None:
@@ -118,13 +120,24 @@ class GOPPhonemeClassifier(PreTrainedModel):
118
  raise ValueError("Config must provide gop_head_labels or num_gop_labels for the classifier.")
119
  head_label_config = {"default": int(config.num_gop_labels)}
120
  self.head_label_config = {str(k): int(v) for k, v in head_label_config.items()}
121
- self.classifiers = nn.ModuleDict({
122
- head: nn.Linear(self.combined_feature_dim, num_labels)
123
- for head, num_labels in self.head_label_config.items()
124
- })
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  self._init_losses()
127
-
128
  self.post_init()
129
 
130
  if load_pretrained_backbone:
@@ -152,11 +165,15 @@ class GOPPhonemeClassifier(PreTrainedModel):
152
  elif weights is not None:
153
  head_weights = weights
154
  if head_weights is not None:
155
- head_weights = head_weights if isinstance(head_weights, torch.Tensor) else torch.tensor(head_weights, dtype=torch.float)
 
 
 
 
156
  loss_modules[head] = OrdinalLogLoss(
157
  num_classes=int(num_labels),
158
  alpha=alpha,
159
- reduction='mean',
160
  class_weights=head_weights,
161
  )
162
  self.loss_fns = nn.ModuleDict(loss_modules)
@@ -168,7 +185,6 @@ class GOPPhonemeClassifier(PreTrainedModel):
168
  target_ids: torch.Tensor,
169
  target_lengths: torch.Tensor,
170
  ) -> torch.Tensor:
171
- """CTC log p(target|input) per item for a batch."""
172
  target_ids_cpu = target_ids.cpu()
173
  target_lengths_cpu = target_lengths.cpu()
174
  log_probs_cpu = log_probs_TNC.cpu()
@@ -176,23 +192,22 @@ class GOPPhonemeClassifier(PreTrainedModel):
176
 
177
  targets_flat = []
178
  for i in range(target_ids_cpu.size(0)):
179
- valid_targets = target_ids_cpu[i, :target_lengths_cpu[i]]
180
  targets_flat.append(valid_targets)
181
  targets_cat = torch.cat(targets_flat) if targets_flat else torch.tensor([], dtype=torch.long)
182
 
183
  if target_lengths_cpu.sum() == 0:
184
- return torch.full((log_probs_TNC.size(1),), -float('inf'), device=log_probs_TNC.device)
185
 
186
- ctc_loss_fn = torch.nn.CTCLoss(blank=self.blank_id, reduction='none', zero_infinity=True)
187
  try:
188
  loss_per_item = ctc_loss_fn(log_probs_cpu, targets_cat, input_lengths_cpu, target_lengths_cpu)
189
  return -loss_per_item.to(log_probs_TNC.device)
190
- except Exception as e:
191
- warnings.warn(f"CTCLoss calculation failed: {e}. Returning -inf for batch.")
192
- return torch.full((log_probs_TNC.size(1),), -float('inf'), device=log_probs_TNC.device)
193
 
194
  def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
195
- """Compute time dimension after backbone feature extractor."""
196
  def _conv_out_length(input_length, kernel_size, stride):
197
  return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
198
 
@@ -200,6 +215,92 @@ class GOPPhonemeClassifier(PreTrainedModel):
200
  input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
201
  return input_lengths
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def forward(
204
  self,
205
  input_values: torch.Tensor,
@@ -212,10 +313,11 @@ class GOPPhonemeClassifier(PreTrainedModel):
212
  return_dict: Optional[bool] = None,
213
  labels: Optional[torch.Tensor] = None,
214
  ) -> SequenceClassifierOutput:
215
-
216
  device = input_values.device
 
217
 
218
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
219
 
220
  if self.training or labels is not None:
221
  if canonical_token_ids is None or token_lengths is None or token_mask is None:
@@ -224,7 +326,15 @@ class GOPPhonemeClassifier(PreTrainedModel):
224
  if token_mask is None:
225
  raise ValueError("`token_mask` must be provided to GOPPhonemeClassifier.forward.")
226
 
227
- # 1) Backbone forward to get hidden states
 
 
 
 
 
 
 
 
228
  outputs = self.ctc_model.wav2vec2(
229
  input_values,
230
  attention_mask=attention_mask,
@@ -234,150 +344,96 @@ class GOPPhonemeClassifier(PreTrainedModel):
234
  )
235
  hidden_states = outputs.last_hidden_state
236
 
237
- # 2) Frame-level logits for CTC
238
  logits_ctc = self.ctc_model.lm_head(hidden_states)
239
  log_probs_ctc = F.log_softmax(logits_ctc, dim=-1)
240
  log_probs_TNC = log_probs_ctc.permute(1, 0, 2).contiguous()
241
 
242
- # 3) Frame lengths
243
  batch_size = input_values.size(0)
244
- if attention_mask is None:
245
- input_lengths_frames = torch.full((batch_size,), log_probs_TNC.size(0), dtype=torch.long, device=device)
246
- else:
247
- input_lengths_samples = attention_mask.sum(dim=-1)
248
- input_lengths_frames = self._get_feat_extract_output_lengths(input_lengths_samples)
249
- input_lengths_frames = torch.clamp(input_lengths_frames, max=log_probs_TNC.size(0))
250
 
251
- # 4) GOP feature calculation over tokens
252
  max_token_len = canonical_token_ids.size(1) if canonical_token_ids is not None else 0
253
- batch_combined_features_list = [[] for _ in range(batch_size)]
254
- token_mask_bool = token_mask.to(device=device).bool()
 
255
 
256
  lpp_log_prob_batch = self._calculate_log_prob(
257
  log_probs_TNC, input_lengths_frames, canonical_token_ids, token_lengths
258
  )
259
 
260
  for token_idx in range(max_token_len):
261
- current_token_ids = canonical_token_ids[:, token_idx]
262
- current_token_embeddings = self.token_embedding(current_token_ids)
263
- token_out_of_bounds_mask = (token_idx >= token_lengths)
264
- mask_column = token_mask_bool[:, token_idx] if token_mask_bool.dim() == 2 else token_mask_bool
265
- skip_mask = token_out_of_bounds_mask | ~mask_column
266
-
267
- if skip_mask.all():
268
- continue
269
-
270
- active_mask = ~skip_mask
271
 
272
  all_sub_log_probs = []
273
- if self.token_id_vocab:
274
- for sub_token_id in self.token_id_vocab:
275
- sub_ids_batch = canonical_token_ids.clone()
276
- sub_ids_batch[active_mask, token_idx] = sub_token_id
277
- log_prob_sub_batch = self._calculate_log_prob(
278
- log_probs_TNC, input_lengths_frames, sub_ids_batch, token_lengths
279
- )
280
- all_sub_log_probs.append(log_prob_sub_batch)
281
-
282
- if all_sub_log_probs:
283
- sub_lpr_batch = lpp_log_prob_batch.unsqueeze(1) - torch.stack(all_sub_log_probs, dim=1)
284
- sub_lpr_batch = torch.nan_to_num(sub_lpr_batch, nan=0.0, posinf=1e10, neginf=-1e10)
285
- if skip_mask.any():
286
- sub_lpr_batch[skip_mask, :] = 0.0
287
- else:
288
- sub_lpr_batch = torch.zeros((batch_size, 0), device=device)
289
 
290
- # Deletion GOP component
291
  del_lpr_list = []
292
  for b_idx in range(batch_size):
293
  if skip_mask[b_idx]:
294
  del_lpr_list.append(torch.tensor(-1e10, device=device))
295
- continue
296
- item_tokens = canonical_token_ids[b_idx, : token_lengths[b_idx]].tolist()
297
- del_tokens_list = item_tokens[:token_idx] + item_tokens[token_idx + 1:]
298
- if not del_tokens_list:
299
- log_prob_del_item = torch.tensor(-float('inf'), device=device)
300
  else:
301
- del_ids_tensor = torch.tensor([del_tokens_list], dtype=torch.long, device=canonical_token_ids.device)
302
- del_len_tensor = torch.tensor([len(del_tokens_list)], dtype=torch.long, device=canonical_token_ids.device)
303
- log_probs_item_TNC = log_probs_TNC[:, b_idx:b_idx + 1, :]
304
- input_len_item = input_lengths_frames[b_idx:b_idx + 1]
305
- log_prob_del_item = self._calculate_log_prob(
306
- log_probs_item_TNC, input_len_item, del_ids_tensor, del_len_tensor
307
- )
308
- if log_prob_del_item.dim() > 0:
309
- log_prob_del_item = log_prob_del_item[0]
310
- lpr_del_item = lpp_log_prob_batch[b_idx] - log_prob_del_item
311
- lpr_del_item = torch.nan_to_num(lpr_del_item, nan=0.0, posinf=1e10, neginf=-1e10)
312
- del_lpr_list.append(lpr_del_item)
 
 
 
 
 
 
 
 
 
313
  del_lpr_batch = torch.stack(del_lpr_list)
314
 
315
- gop_part = torch.cat([lpp_log_prob_batch.unsqueeze(1), sub_lpr_batch, del_lpr_batch.unsqueeze(1)], dim=1)
 
 
 
 
316
  combined_features = torch.cat([gop_part, current_token_embeddings], dim=1)
317
- for b_idx in range(batch_size):
318
- if active_mask[b_idx]:
319
- batch_combined_features_list[b_idx].append(combined_features[b_idx])
320
-
321
- # 5) Pad phoneme feature sequences and mask
322
- feature_lengths_list = [len(seq_list) for seq_list in batch_combined_features_list]
323
- if feature_lengths_list:
324
- feature_lengths = torch.tensor(feature_lengths_list, dtype=torch.long, device=device)
325
- target_pad_len = int(feature_lengths.max().item())
326
- else:
327
- feature_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
328
- target_pad_len = 0
329
-
330
- padded_sequences = []
331
- for seq_list in batch_combined_features_list:
332
- if seq_list:
333
- seq_tensor = torch.stack(seq_list, dim=0)
334
- pad_len = target_pad_len - seq_tensor.size(0)
335
- if pad_len > 0:
336
- seq_tensor = F.pad(seq_tensor, (0, 0, 0, pad_len))
337
- padded_sequences.append(seq_tensor)
338
- else:
339
- padded_sequences.append(torch.zeros((target_pad_len, self.combined_feature_dim), device=device))
340
- transformer_input = torch.stack(padded_sequences, dim=0) if padded_sequences else torch.zeros((0, 0, self.combined_feature_dim), device=device)
341
- transformer_padding_mask = torch.arange(target_pad_len, device=device)[None, :] >= feature_lengths[:, None]
342
-
343
- # 6) GOP transformer
344
- gop_transformer_output = self.gop_transformer_encoder(
345
- transformer_input,
346
- src_key_padding_mask=transformer_padding_mask
347
- )
348
 
349
- # 7) Per-phoneme classifier (multi-head)
350
- final_logits = {
351
- head: classifier(gop_transformer_output)
352
- for head, classifier in self.classifiers.items()
353
- }
354
 
355
- # 8) Loss
356
  loss = None
357
  if labels is not None:
358
- if isinstance(labels, torch.Tensor):
359
- label_map = {next(iter(final_logits.keys())): labels}
360
- elif isinstance(labels, dict):
361
- label_map = labels
362
- else:
363
- raise TypeError("labels must be a Tensor or a dict of Tensors when provided.")
364
-
365
- active_mask = ~transformer_padding_mask.view(-1)
366
  for head, head_logits in final_logits.items():
367
- head_labels = label_map.get(head)
368
- if head_labels is None:
369
- continue
370
- logits_flat = head_logits.view(-1, head_logits.size(-1))
371
- labels_flat = head_labels.view(-1)
372
- active_logits = logits_flat[active_mask]
373
- #breakpoint()
374
- active_labels = labels_flat[active_mask]
375
- if active_labels.numel() == 0:
376
- continue
377
- head_loss = self.loss_fns[head](active_logits, active_labels)
378
- loss = head_loss if loss is None else loss + head_loss
379
- if loss is None:
380
- loss = torch.tensor(0.0, device=device, requires_grad=True)
381
 
382
  if not return_dict:
383
  output = (final_logits,)
 
12
 
13
 
14
  class GOPWav2Vec2Config(PretrainedConfig):
 
 
 
15
  model_type = "gop-wav2vec2"
16
 
17
  def __init__(
 
30
  unk_id: Optional[int] = None,
31
  bos_id: Optional[int] = None,
32
  eos_id: Optional[int] = None,
33
+ word_boundary_id: Optional[int] = None,
34
  token_id_vocab: Optional[List[int]] = None,
35
  ctc_config: Optional[dict] = None,
36
  **kwargs,
 
55
  self.unk_id = unk_id
56
  self.bos_id = bos_id
57
  self.eos_id = eos_id
58
+ self.word_boundary_id = word_boundary_id
59
  self.token_id_vocab = token_id_vocab
60
  self.ctc_config = ctc_config
61
 
62
 
63
  class GOPPhonemeClassifier(PreTrainedModel):
 
 
 
 
 
64
  config_class = GOPWav2Vec2Config
65
+
66
  def __init__(self, config: GOPWav2Vec2Config, load_pretrained_backbone: bool = False):
67
  super().__init__(config)
68
+
69
  if config.ctc_config is not None:
70
  backbone_config = Wav2Vec2Config.from_dict(config.ctc_config)
71
  elif config.base_model_name_or_path is not None:
 
76
  self.ctc_model = Wav2Vec2ForCTC(backbone_config)
77
  self.config.ctc_config = backbone_config.to_dict()
78
 
 
79
  self.blank_id = config.pad_id
80
  self.unk_id = config.unk_id
81
  self.bos_id = config.bos_id
82
  self.eos_id = config.eos_id
83
  self.pad_id = self.blank_id
84
+ self.word_boundary_id = config.word_boundary_id
85
 
86
  special_ids = {self.blank_id, self.unk_id, self.bos_id, self.eos_id, self.pad_id}
87
  self.special_ids = {i for i in special_ids if i is not None}
88
 
89
  vocab_size = int(self.ctc_model.config.vocab_size)
90
  self.token_id_vocab = (
91
+ config.token_id_vocab
92
+ if config.token_id_vocab is not None
93
+ else [i for i in range(vocab_size) if i not in self.special_ids]
94
  )
95
 
96
  self.gop_feature_dim = 1 + len(self.token_id_vocab) + 1
97
  self.embedding_dim = int(config.gop_embedding_dim)
98
+ self.token_embedding = nn.Embedding(
99
+ vocab_size,
100
+ self.embedding_dim,
101
+ padding_idx=self.pad_id if self.pad_id is not None else 0,
102
+ )
103
  self.combined_feature_dim = self.gop_feature_dim + self.embedding_dim
104
+ self.gop_part_dropout = nn.Dropout(config.gop_transformer_dropout)
105
+ self.gop_part_norm = nn.LayerNorm(self.gop_feature_dim)
106
+
107
+ self.lstm_hidden_size = int(config.gop_transformer_dim_feedforward)
108
+ self.gop_rnn = nn.LSTM(
109
+ input_size=self.combined_feature_dim,
110
+ hidden_size=self.lstm_hidden_size,
111
+ num_layers=config.gop_transformer_nlayers,
112
+ dropout=config.gop_transformer_dropout if config.gop_transformer_nlayers > 1 else 0.0,
113
+ bidirectional=True,
114
  batch_first=True,
115
  )
 
116
 
117
  head_label_config = getattr(config, "gop_head_labels", None)
118
  if head_label_config is None:
 
120
  raise ValueError("Config must provide gop_head_labels or num_gop_labels for the classifier.")
121
  head_label_config = {"default": int(config.num_gop_labels)}
122
  self.head_label_config = {str(k): int(v) for k, v in head_label_config.items()}
123
+ self.head_hidden_dim = self.lstm_hidden_size * 2
124
+ self.head_mlps = nn.ModuleDict(
125
+ {
126
+ head: nn.Sequential(
127
+ nn.Linear(self.head_hidden_dim, self.head_hidden_dim),
128
+ nn.LeakyReLU(),
129
+ )
130
+ for head in self.head_label_config.keys()
131
+ }
132
+ )
133
+ self.classifiers = nn.ModuleDict(
134
+ {
135
+ head: nn.Linear(self.head_hidden_dim, num_labels)
136
+ for head, num_labels in self.head_label_config.items()
137
+ }
138
+ )
139
 
140
  self._init_losses()
 
141
  self.post_init()
142
 
143
  if load_pretrained_backbone:
 
165
  elif weights is not None:
166
  head_weights = weights
167
  if head_weights is not None:
168
+ head_weights = (
169
+ head_weights
170
+ if isinstance(head_weights, torch.Tensor)
171
+ else torch.tensor(head_weights, dtype=torch.float)
172
+ )
173
  loss_modules[head] = OrdinalLogLoss(
174
  num_classes=int(num_labels),
175
  alpha=alpha,
176
+ reduction="mean",
177
  class_weights=head_weights,
178
  )
179
  self.loss_fns = nn.ModuleDict(loss_modules)
 
185
  target_ids: torch.Tensor,
186
  target_lengths: torch.Tensor,
187
  ) -> torch.Tensor:
 
188
  target_ids_cpu = target_ids.cpu()
189
  target_lengths_cpu = target_lengths.cpu()
190
  log_probs_cpu = log_probs_TNC.cpu()
 
192
 
193
  targets_flat = []
194
  for i in range(target_ids_cpu.size(0)):
195
+ valid_targets = target_ids_cpu[i, : target_lengths_cpu[i]]
196
  targets_flat.append(valid_targets)
197
  targets_cat = torch.cat(targets_flat) if targets_flat else torch.tensor([], dtype=torch.long)
198
 
199
  if target_lengths_cpu.sum() == 0:
200
+ return torch.full((log_probs_TNC.size(1),), -float("inf"), device=log_probs_TNC.device)
201
 
202
+ ctc_loss_fn = torch.nn.CTCLoss(blank=self.blank_id, reduction="none", zero_infinity=True)
203
  try:
204
  loss_per_item = ctc_loss_fn(log_probs_cpu, targets_cat, input_lengths_cpu, target_lengths_cpu)
205
  return -loss_per_item.to(log_probs_TNC.device)
206
+ except Exception as exc:
207
+ warnings.warn(f"CTCLoss calculation failed: {exc}. Returning -inf for batch.")
208
+ return torch.full((log_probs_TNC.size(1),), -float("inf"), device=log_probs_TNC.device)
209
 
210
  def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
 
211
  def _conv_out_length(input_length, kernel_size, stride):
212
  return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
213
 
 
215
  input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
216
  return input_lengths
217
 
218
+ def _prepare_labels(
219
+ self, labels: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
220
+ ) -> Optional[Dict[str, torch.Tensor]]:
221
+ if labels is None:
222
+ return None
223
+ if isinstance(labels, torch.Tensor):
224
+ if len(self.head_label_config) != 1:
225
+ raise ValueError("Multi-head setup requires `labels` to be a dict keyed by head name.")
226
+ head_name = next(iter(self.head_label_config))
227
+ return {head_name: labels}
228
+ if not isinstance(labels, dict):
229
+ raise ValueError("`labels` must be a Tensor for single-head setups or a dict for multi-head setups.")
230
+ return labels
231
+
232
+ def _validate_inputs(
233
+ self,
234
+ input_values: torch.Tensor,
235
+ attention_mask: torch.Tensor,
236
+ canonical_token_ids: torch.Tensor,
237
+ token_lengths: torch.Tensor,
238
+ token_mask: torch.Tensor,
239
+ labels: Optional[Dict[str, torch.Tensor]],
240
+ ) -> torch.Tensor:
241
+ if input_values.dim() != 2:
242
+ raise ValueError(f"`input_values` must be 2D (batch, time); got shape {tuple(input_values.shape)}.")
243
+ if attention_mask is None or attention_mask.shape != input_values.shape:
244
+ raise ValueError("`attention_mask` must be provided and match the shape of `input_values`.")
245
+
246
+ if canonical_token_ids is None or token_lengths is None or token_mask is None:
247
+ raise ValueError("`canonical_token_ids`, `token_lengths`, and `token_mask` are required.")
248
+
249
+ if canonical_token_ids.dim() != 2:
250
+ raise ValueError(
251
+ f"`canonical_token_ids` must be 2D (batch, tokens); got shape {tuple(canonical_token_ids.shape)}."
252
+ )
253
+ batch_size, max_tokens = canonical_token_ids.shape
254
+ if batch_size != input_values.shape[0]:
255
+ raise ValueError("Batch size mismatch between `input_values` and `canonical_token_ids`.")
256
+
257
+ if token_mask.dim() != 2 or token_mask.shape != canonical_token_ids.shape:
258
+ raise ValueError("`token_mask` must be the same shape as `canonical_token_ids`.")
259
+
260
+ if token_lengths.dim() != 1 or token_lengths.shape[0] != batch_size:
261
+ raise ValueError("`token_lengths` must be 1D with length equal to batch size.")
262
+ if torch.any(token_lengths < 0):
263
+ raise ValueError("`token_lengths` must be non-negative.")
264
+ if torch.any(token_lengths > max_tokens):
265
+ raise ValueError("`token_lengths` cannot exceed the number of provided tokens.")
266
+
267
+ token_mask_bool = token_mask.to(device=canonical_token_ids.device, dtype=torch.bool)
268
+ arange_positions = torch.arange(max_tokens, device=canonical_token_ids.device)
269
+ padded_active = token_mask_bool & (arange_positions.unsqueeze(0) >= token_lengths.unsqueeze(1))
270
+ if torch.any(padded_active):
271
+ raise ValueError("`token_mask` marks padded positions as valid (indices >= token_lengths).")
272
+ if torch.any(token_mask_bool.sum(dim=1) > token_lengths):
273
+ raise ValueError("`token_mask` has more active positions than `token_lengths` for some batch items.")
274
+
275
+ if labels is not None:
276
+ if not isinstance(labels, dict):
277
+ raise ValueError("`labels` must be a dict keyed by head name after normalization.")
278
+ expected_heads = set(self.head_label_config.keys())
279
+ label_heads = set(labels.keys())
280
+ unknown_heads = label_heads - expected_heads
281
+ missing_heads = expected_heads - label_heads
282
+ if unknown_heads:
283
+ raise ValueError(f"Unexpected label heads provided: {sorted(unknown_heads)}.")
284
+ if missing_heads:
285
+ raise ValueError(f"Missing label heads: {sorted(missing_heads)}.")
286
+ for head, head_labels in labels.items():
287
+ if head_labels.shape != canonical_token_ids.shape:
288
+ raise ValueError(
289
+ f"Labels for head '{head}' must match `canonical_token_ids` shape "
290
+ f"{tuple(canonical_token_ids.shape)}; got {tuple(head_labels.shape)}."
291
+ )
292
+ if head_labels.dtype not in (torch.int64, torch.long):
293
+ raise ValueError(f"Labels for head '{head}' must be integer tensors; got dtype {head_labels.dtype}.")
294
+ masked_positions = token_mask_bool.logical_not()
295
+ bad_mask = masked_positions & (head_labels != -100)
296
+ if torch.any(bad_mask):
297
+ bad_count = int(bad_mask.sum().item())
298
+ raise ValueError(
299
+ f"Labels for head '{head}' must be -100 at masked positions; found {bad_count} mismatches."
300
+ )
301
+
302
+ return token_mask_bool
303
+
304
  def forward(
305
  self,
306
  input_values: torch.Tensor,
 
313
  return_dict: Optional[bool] = None,
314
  labels: Optional[torch.Tensor] = None,
315
  ) -> SequenceClassifierOutput:
 
316
  device = input_values.device
317
+ assert attention_mask is not None
318
 
319
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320
+ labels = self._prepare_labels(labels)
321
 
322
  if self.training or labels is not None:
323
  if canonical_token_ids is None or token_lengths is None or token_mask is None:
 
326
  if token_mask is None:
327
  raise ValueError("`token_mask` must be provided to GOPPhonemeClassifier.forward.")
328
 
329
+ token_mask_bool = self._validate_inputs(
330
+ input_values=input_values,
331
+ attention_mask=attention_mask,
332
+ canonical_token_ids=canonical_token_ids,
333
+ token_lengths=token_lengths,
334
+ token_mask=token_mask,
335
+ labels=labels,
336
+ ).to(device=device)
337
+
338
  outputs = self.ctc_model.wav2vec2(
339
  input_values,
340
  attention_mask=attention_mask,
 
344
  )
345
  hidden_states = outputs.last_hidden_state
346
 
 
347
  logits_ctc = self.ctc_model.lm_head(hidden_states)
348
  log_probs_ctc = F.log_softmax(logits_ctc, dim=-1)
349
  log_probs_TNC = log_probs_ctc.permute(1, 0, 2).contiguous()
350
 
 
351
  batch_size = input_values.size(0)
352
+ input_lengths_samples = attention_mask.sum(dim=-1)
353
+ input_lengths_frames = self._get_feat_extract_output_lengths(input_lengths_samples)
354
+ input_lengths_frames = torch.clamp(input_lengths_frames, max=log_probs_TNC.size(0))
 
 
 
355
 
 
356
  max_token_len = canonical_token_ids.size(1) if canonical_token_ids is not None else 0
357
+ batch_combined_features_list = []
358
+
359
+ token_embeddings = self.token_embedding(canonical_token_ids)
360
 
361
  lpp_log_prob_batch = self._calculate_log_prob(
362
  log_probs_TNC, input_lengths_frames, canonical_token_ids, token_lengths
363
  )
364
 
365
  for token_idx in range(max_token_len):
366
+ current_token_embeddings = token_embeddings[:, token_idx, :]
367
+ active_mask = token_mask_bool[:, token_idx]
368
+ skip_mask = ~active_mask
 
 
 
 
 
 
 
369
 
370
  all_sub_log_probs = []
371
+ for sub_token_id in self.token_id_vocab:
372
+ sub_ids_batch = canonical_token_ids.clone()
373
+ sub_ids_batch[active_mask, token_idx] = sub_token_id
374
+ log_prob_sub_batch = self._calculate_log_prob(
375
+ log_probs_TNC, input_lengths_frames, sub_ids_batch, token_lengths
376
+ )
377
+ all_sub_log_probs.append(log_prob_sub_batch)
378
+
379
+ sub_log_probs_batch = torch.stack(all_sub_log_probs, dim=1)
380
+ sub_log_probs_batch = F.log_softmax(sub_log_probs_batch, dim=1)
381
+ sub_log_probs_batch = torch.nan_to_num(sub_log_probs_batch, nan=0.0, posinf=1e10, neginf=-1e10)
382
+ if skip_mask.any():
383
+ sub_log_probs_batch[skip_mask, :] = 0.0
 
 
 
384
 
 
385
  del_lpr_list = []
386
  for b_idx in range(batch_size):
387
  if skip_mask[b_idx]:
388
  del_lpr_list.append(torch.tensor(-1e10, device=device))
 
 
 
 
 
389
  else:
390
+ item_tokens = canonical_token_ids[b_idx, : token_lengths[b_idx]].tolist()
391
+ del_tokens_list = item_tokens[:token_idx] + item_tokens[token_idx + 1 :]
392
+ if not del_tokens_list:
393
+ log_prob_del_item = torch.tensor(-float("inf"), device=device)
394
+ else:
395
+ del_ids_tensor = torch.tensor(
396
+ [del_tokens_list], dtype=torch.long, device=canonical_token_ids.device
397
+ )
398
+ del_len_tensor = torch.tensor(
399
+ [len(del_tokens_list)], dtype=torch.long, device=canonical_token_ids.device
400
+ )
401
+ log_probs_item_TNC = log_probs_TNC[:, b_idx : b_idx + 1, :]
402
+ input_len_item = input_lengths_frames[b_idx : b_idx + 1]
403
+ log_prob_del_item = self._calculate_log_prob(
404
+ log_probs_item_TNC, input_len_item, del_ids_tensor, del_len_tensor
405
+ )
406
+ if log_prob_del_item.dim() > 0:
407
+ log_prob_del_item = log_prob_del_item[0]
408
+ lpr_del_item = lpp_log_prob_batch[b_idx] - log_prob_del_item
409
+ lpr_del_item = torch.nan_to_num(lpr_del_item, nan=0.0, posinf=1e10, neginf=-1e10)
410
+ del_lpr_list.append(lpr_del_item)
411
  del_lpr_batch = torch.stack(del_lpr_list)
412
 
413
+ gop_part = torch.cat(
414
+ [lpp_log_prob_batch.unsqueeze(1), sub_log_probs_batch, del_lpr_batch.unsqueeze(1)], dim=1
415
+ )
416
+ gop_part = self.gop_part_norm(gop_part)
417
+ gop_part = self.gop_part_dropout(gop_part)
418
  combined_features = torch.cat([gop_part, current_token_embeddings], dim=1)
419
+ batch_combined_features_list.append(combined_features)
420
+
421
+ transformer_input = torch.stack(batch_combined_features_list, dim=1)
422
+
423
+ gop_rnn_output, _ = self.gop_rnn(transformer_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
+ final_logits = {}
426
+ for head, classifier in self.classifiers.items():
427
+ head_features = self.head_mlps[head](gop_rnn_output)
428
+ final_logits[head] = classifier(head_features)
 
429
 
 
430
  loss = None
431
  if labels is not None:
432
+ loss = 0.0
 
 
 
 
 
 
 
433
  for head, head_logits in final_logits.items():
434
+ head_labels = labels[head]
435
+ head_loss = self.loss_fns[head](head_logits, head_labels)
436
+ loss += head_loss
 
 
 
 
 
 
 
 
 
 
 
437
 
438
  if not return_dict:
439
  output = (final_logits,)
models.py CHANGED
@@ -4,19 +4,21 @@ import torch.nn as nn
4
 
5
  class OrdinalLogLoss(nn.Module):
6
  def __init__(
7
- self,
8
- num_classes,
9
- alpha=1.0,
10
- reduction='mean',
11
- distance_matrix=None,
12
- class_weights=None,
13
- eps=1e-8
 
14
  ):
15
  super(OrdinalLogLoss, self).__init__()
16
  self.num_classes = num_classes
17
  self.alpha = alpha
18
  self.reduction = reduction
19
  self.eps = eps
 
20
 
21
  if distance_matrix is not None:
22
  assert distance_matrix.shape == (num_classes, num_classes), \
@@ -35,21 +37,39 @@ class OrdinalLogLoss(nn.Module):
35
  self.class_weights = None
36
 
37
  def forward(self, logits, target):
38
- probs = torch.softmax(logits, dim=1).clamp(max=1 - self.eps)
39
- distances = self.distance_matrix[target] ** self.alpha
40
- per_class_loss = -torch.log(1 - probs + self.eps)
41
- loss = (per_class_loss * distances).sum(dim=1) # shape (batch_size,)
42
-
43
- # Apply class weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if self.class_weights is not None:
45
- sample_weights = self.class_weights[target]
46
- loss = loss * sample_weights
 
 
 
 
 
47
 
48
- # Apply reduction
49
  if self.reduction == 'mean':
50
- return loss.mean()
51
  elif self.reduction == 'sum':
52
- return loss.sum()
53
  else:
54
- return loss
55
-
 
4
 
5
  class OrdinalLogLoss(nn.Module):
6
  def __init__(
7
+ self,
8
+ num_classes,
9
+ alpha=1.0,
10
+ reduction='mean',
11
+ distance_matrix=None,
12
+ class_weights=None,
13
+ eps=1e-8,
14
+ ignore_index=-100,
15
  ):
16
  super(OrdinalLogLoss, self).__init__()
17
  self.num_classes = num_classes
18
  self.alpha = alpha
19
  self.reduction = reduction
20
  self.eps = eps
21
+ self.ignore_index = ignore_index
22
 
23
  if distance_matrix is not None:
24
  assert distance_matrix.shape == (num_classes, num_classes), \
 
37
  self.class_weights = None
38
 
39
  def forward(self, logits, target):
40
+ if logits.numel() == 0:
41
+ return logits.new_tensor(0.0)
42
+
43
+ probs = torch.softmax(logits, dim=-1).clamp(max=1 - self.eps)
44
+
45
+ if self.ignore_index is not None:
46
+ valid_mask = target != self.ignore_index
47
+ else:
48
+ valid_mask = torch.ones_like(target, dtype=torch.bool)
49
+
50
+ if not valid_mask.any():
51
+ if self.reduction == 'none':
52
+ return logits.new_zeros(target.shape, dtype=logits.dtype)
53
+ return logits.new_tensor(0.0)
54
+
55
+ active_probs = probs[valid_mask]
56
+ active_target = target[valid_mask]
57
+ distances = self.distance_matrix[active_target] ** self.alpha
58
+ per_class_loss = -torch.log(1 - active_probs + self.eps)
59
+ loss_active = (per_class_loss * distances).sum(dim=-1)
60
+
61
  if self.class_weights is not None:
62
+ sample_weights = self.class_weights[active_target]
63
+ loss_active = loss_active * sample_weights
64
+
65
+ if self.reduction == 'none':
66
+ full_loss = logits.new_zeros(target.shape, dtype=logits.dtype)
67
+ full_loss[valid_mask] = loss_active
68
+ return full_loss
69
 
 
70
  if self.reduction == 'mean':
71
+ return loss_active.mean()
72
  elif self.reduction == 'sum':
73
+ return loss_active.sum()
74
  else:
75
+ raise ValueError(f"Unsupported reduction: {self.reduction}")
 
utils.py CHANGED
@@ -1,24 +1,27 @@
1
- from typing import List
 
 
 
2
  import torch
3
  import torchaudio
4
- from transformers import Wav2Vec2Processor, QuantoConfig
 
5
  from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
6
  from gop_model import GOPPhonemeClassifier
7
- import logging
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
 
12
  def load_model_and_processor(model_repo_id: str):
13
- logger.info(f"Loading model and processor from Hugging Face Hub: {model_repo_id}")
14
 
15
  quantization_config = QuantoConfig(weights="int8")
16
- logger.info("Applying INT8 dynamic quantization during model loading...")
17
 
18
  model = GOPPhonemeClassifier.from_pretrained(
19
  model_repo_id,
20
  quantization_config=quantization_config,
21
- device_map="auto"
22
  )
23
  processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
24
  model.eval()
@@ -36,6 +39,46 @@ def validate_phonemes(phoneme_text, allowed_phonemes):
36
  return None
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
40
  if not audio_file_path or not transcript:
41
  return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
@@ -56,19 +99,14 @@ def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassi
56
 
57
  audio_input = waveform.squeeze(0)
58
  processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
59
- input_values = processed_audio.input_values.to(model.device)
60
- attention_mask = processed_audio.attention_mask.to(model.device)
61
-
62
- phonemes: List[str] = transcript.strip().split()
63
- tokenizer = processor.tokenizer
64
- unk_id = getattr(tokenizer, "unk_token_id", None)
65
- ids = tokenizer.convert_tokens_to_ids(phonemes)
66
- if isinstance(ids, int):
67
- ids = [ids]
68
- ids = [i if i is not None else unk_id for i in ids]
69
- canonical_token_ids = torch.tensor([ids], dtype=torch.long).to(model.device)
70
- token_lengths = torch.tensor([len(ids)], dtype=torch.long).to(model.device)
71
- token_mask = torch.ones_like(canonical_token_ids).to(model.device)
72
 
73
  with torch.no_grad():
74
  outputs = model(
@@ -76,19 +114,11 @@ def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassi
76
  attention_mask=attention_mask,
77
  canonical_token_ids=canonical_token_ids,
78
  token_lengths=token_lengths,
79
- token_mask=token_mask
80
  )
81
 
82
- logits = outputs.logits
83
-
84
- head_name = next(iter(logits))
85
- scores_tensor = logits[head_name]
86
- predicted_scores = torch.argmax(scores_tensor, dim=-1)
87
-
88
- tokens = processor.tokenizer.convert_ids_to_tokens(canonical_token_ids[0])
89
-
90
- return predicted_scores, tokens, token_lengths
91
 
92
- except Exception as e:
93
- logger.error(f"An error occurred during inference: {e}", exc_info=True)
94
- return f"<p style='text-align:center; color:red;'>An error occurred: {e}</p>"
 
1
+ from typing import Dict, List, Tuple
2
+
3
+ import logging
4
+
5
  import torch
6
  import torchaudio
7
+ from transformers import QuantoConfig, Wav2Vec2Processor
8
+
9
  from constants import MAX_AUDIO_DURATION_SECONDS, MONO_CHANNEL, SAMPLING_RATE
10
  from gop_model import GOPPhonemeClassifier
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
15
  def load_model_and_processor(model_repo_id: str):
16
+ logger.info("Loading model and processor from Hugging Face Hub: %s", model_repo_id)
17
 
18
  quantization_config = QuantoConfig(weights="int8")
19
+ logger.info("Applying INT8 dynamic quantization during model loading")
20
 
21
  model = GOPPhonemeClassifier.from_pretrained(
22
  model_repo_id,
23
  quantization_config=quantization_config,
24
+ device_map="auto",
25
  )
26
  processor = Wav2Vec2Processor.from_pretrained(model_repo_id)
27
  model.eval()
 
39
  return None
40
 
41
 
42
+ def _prepare_canonical_tokens(transcript: str, processor: Wav2Vec2Processor, device: torch.device):
43
+ phonemes: List[str] = transcript.strip().split()
44
+ if not phonemes:
45
+ raise ValueError("Please enter at least one phoneme.")
46
+
47
+ token_mask_values = [token != "|" for token in phonemes]
48
+ if not any(token_mask_values):
49
+ raise ValueError("The phoneme sequence must contain at least one non-boundary token.")
50
+
51
+ tokenizer = processor.tokenizer
52
+ unk_id = getattr(tokenizer, "unk_token_id", None)
53
+ ids = tokenizer.convert_tokens_to_ids(phonemes)
54
+ if isinstance(ids, int):
55
+ ids = [ids]
56
+ ids = [token_id if token_id is not None else unk_id for token_id in ids]
57
+
58
+ canonical_token_ids = torch.tensor([ids], dtype=torch.long, device=device)
59
+ token_lengths = torch.tensor([len(ids)], dtype=torch.long, device=device)
60
+ token_mask = torch.tensor([token_mask_values], dtype=torch.bool, device=device)
61
+
62
+ display_tokens = [token for token, is_active in zip(phonemes, token_mask_values) if is_active]
63
+ return canonical_token_ids, token_lengths, token_mask, display_tokens
64
+
65
+
66
+ def _extract_head_predictions(
67
+ logits_by_head: Dict[str, torch.Tensor],
68
+ token_mask: torch.Tensor,
69
+ display_tokens: List[str],
70
+ ) -> Dict[str, Tuple[List[int], List[str]]]:
71
+ active_mask = token_mask[0].bool()
72
+ results: Dict[str, Tuple[List[int], List[str]]] = {}
73
+
74
+ for head_name, head_logits in logits_by_head.items():
75
+ predicted_scores = torch.argmax(head_logits, dim=-1)[0]
76
+ filtered_scores = predicted_scores[active_mask].detach().cpu().tolist()
77
+ results[head_name] = (filtered_scores, display_tokens)
78
+
79
+ return results
80
+
81
+
82
  def run_inference(audio_file_path: str, transcript: str, model: GOPPhonemeClassifier, processor: Wav2Vec2Processor):
83
  if not audio_file_path or not transcript:
84
  return "<p style='text-align:center; color:red;'>Please provide both an audio file and the transcript.</p>"
 
99
 
100
  audio_input = waveform.squeeze(0)
101
  processed_audio = processor(audio_input, sampling_rate=SAMPLING_RATE, return_tensors="pt", padding=True)
102
+
103
+ model_device = next(model.parameters()).device
104
+ input_values = processed_audio.input_values.to(model_device)
105
+ attention_mask = processed_audio.attention_mask.to(model_device)
106
+
107
+ canonical_token_ids, token_lengths, token_mask, display_tokens = _prepare_canonical_tokens(
108
+ transcript, processor, model_device
109
+ )
 
 
 
 
 
110
 
111
  with torch.no_grad():
112
  outputs = model(
 
114
  attention_mask=attention_mask,
115
  canonical_token_ids=canonical_token_ids,
116
  token_lengths=token_lengths,
117
+ token_mask=token_mask,
118
  )
119
 
120
+ return _extract_head_predictions(outputs.logits, token_mask, display_tokens)
 
 
 
 
 
 
 
 
121
 
122
+ except Exception as exc:
123
+ logger.error("An error occurred during inference: %s", exc, exc_info=True)
124
+ return f"<p style='text-align:center; color:red;'>An error occurred: {exc}</p>"