mazesmazes committed on
Commit
8d11529
·
verified ·
1 Parent(s): f6cb04c

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +51 -11
alignment.py CHANGED
@@ -3,6 +3,14 @@
3
  import numpy as np
4
  import torch
5
 
 
 
 
 
 
 
 
 
6
 
7
  def _get_device() -> str:
8
  """Get best available device for non-transformers models."""
@@ -44,6 +52,30 @@ class ForcedAligner:
44
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
45
  return cls._model, cls._labels, cls._dictionary
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @staticmethod
48
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
49
  """Build trellis for forced alignment using forward algorithm.
@@ -53,7 +85,7 @@ class ForcedAligner:
53
 
54
  Args:
55
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
56
- tokens: List of target token indices
57
  blank_id: Index of the blank/CTC token (default 0)
58
 
59
  Returns:
@@ -71,7 +103,13 @@ class ForcedAligner:
71
  stay = trellis[t, j] + emission[t, blank_id]
72
 
73
  # Move: emit token j and advance to j+1 tokens
74
- move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
 
 
 
 
 
 
75
 
76
  trellis[t + 1, j] = max(stay, move) # Viterbi: take best path
77
 
@@ -116,7 +154,10 @@ class ForcedAligner:
116
  while t > 0 and j > 0:
117
  # Check: did we transition from j-1 to j at frame t-1?
118
  stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
119
- move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
 
 
 
120
 
121
  if move_score >= stay_score:
122
  # Token j-1 was emitted at frame t-1
@@ -148,11 +189,6 @@ class ForcedAligner:
148
 
149
  return token_spans
150
 
151
- # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
152
- # Calibrated on librispeech-alignments dataset
153
- START_OFFSET = 0.06 # Subtract from start times (shift earlier)
154
- END_OFFSET = -0.03 # Add to end times (shift later)
155
-
156
  @classmethod
157
  def align(
158
  cls,
@@ -207,16 +243,20 @@ class ForcedAligner:
207
 
208
  emission = emissions[0].cpu()
209
 
210
- # Normalize text: uppercase, keep only valid characters
211
  transcript = text.upper()
212
 
213
  # Build tokens from transcript (including word separators)
 
214
  tokens = []
215
  for char in transcript:
216
  if char in dictionary:
217
  tokens.append(dictionary[char])
218
  elif char == " ":
219
  tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
 
 
 
220
 
221
  if not tokens:
222
  return []
@@ -229,8 +269,8 @@ class ForcedAligner:
229
  frame_duration = 320 / cls._bundle.sample_rate
230
 
231
  # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
232
- start_offset = cls.START_OFFSET
233
- end_offset = cls.END_OFFSET
234
 
235
  # Group aligned tokens into words based on pipe separator
236
  words = text.split()
 
3
  import numpy as np
4
  import torch
5
 
6
+ # Wildcard token ID for out-of-vocabulary characters
7
+ WILDCARD_TOKEN = -1
8
+
9
+ # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
10
+ # Calibrated on librispeech-alignments dataset
11
+ START_OFFSET = 0.06 # Subtract from start times (shift earlier)
12
+ END_OFFSET = -0.03 # Add to end times (shift later)
13
+
14
 
15
  def _get_device() -> str:
16
  """Get best available device for non-transformers models."""
 
52
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
53
  return cls._model, cls._labels, cls._dictionary
54
 
55
+ @staticmethod
56
+ def _get_emission_score(
57
+ emission: torch.Tensor, token: int, blank_id: int = 0
58
+ ) -> torch.Tensor:
59
+ """Get emission score for a token, handling wildcards.
60
+
61
+ For wildcard tokens (WILDCARD_TOKEN), returns the max score over all
62
+ non-blank tokens - allowing any character to match.
63
+
64
+ Args:
65
+ emission: Emission vector for a single frame (num_classes,)
66
+ token: Token index, or WILDCARD_TOKEN for out-of-vocabulary chars
67
+ blank_id: Index of the blank/CTC token
68
+
69
+ Returns:
70
+ Emission score (scalar tensor)
71
+ """
72
+ if token == WILDCARD_TOKEN:
73
+ # Wildcard: take max over all non-blank tokens
74
+ mask = torch.ones(emission.size(0), dtype=torch.bool)
75
+ mask[blank_id] = False
76
+ return emission[mask].max()
77
+ return emission[token]
78
+
79
  @staticmethod
80
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
81
  """Build trellis for forced alignment using forward algorithm.
 
85
 
86
  Args:
87
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
88
+ tokens: List of target token indices (WILDCARD_TOKEN for OOV chars)
89
  blank_id: Index of the blank/CTC token (default 0)
90
 
91
  Returns:
 
103
  stay = trellis[t, j] + emission[t, blank_id]
104
 
105
  # Move: emit token j and advance to j+1 tokens
106
+ if j > 0:
107
+ token_score = ForcedAligner._get_emission_score(
108
+ emission[t], tokens[j - 1], blank_id
109
+ )
110
+ move = trellis[t, j - 1] + token_score
111
+ else:
112
+ move = -float("inf")
113
 
114
  trellis[t + 1, j] = max(stay, move) # Viterbi: take best path
115
 
 
154
  while t > 0 and j > 0:
155
  # Check: did we transition from j-1 to j at frame t-1?
156
  stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
157
+ token_score = ForcedAligner._get_emission_score(
158
+ emission[t - 1], tokens[j - 1], blank_id
159
+ )
160
+ move_score = trellis[t - 1, j - 1] + token_score
161
 
162
  if move_score >= stay_score:
163
  # Token j-1 was emitted at frame t-1
 
189
 
190
  return token_spans
191
 
 
 
 
 
 
192
  @classmethod
193
  def align(
194
  cls,
 
243
 
244
  emission = emissions[0].cpu()
245
 
246
+ # Normalize text: uppercase
247
  transcript = text.upper()
248
 
249
  # Build tokens from transcript (including word separators)
250
+ # Unknown characters get WILDCARD_TOKEN which matches any non-blank emission
251
  tokens = []
252
  for char in transcript:
253
  if char in dictionary:
254
  tokens.append(dictionary[char])
255
  elif char == " ":
256
  tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
257
+ else:
258
+ # Out-of-vocabulary character - use wildcard
259
+ tokens.append(WILDCARD_TOKEN)
260
 
261
  if not tokens:
262
  return []
 
269
  frame_duration = 320 / cls._bundle.sample_rate
270
 
271
  # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
272
+ start_offset = START_OFFSET
273
+ end_offset = END_OFFSET
274
 
275
  # Group aligned tokens into words based on pipe separator
276
  words = text.split()