mazesmazes committed on
Commit
86c5e8d
·
verified ·
1 Parent(s): 8d11529

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +177 -44
alignment.py CHANGED
@@ -1,10 +1,13 @@
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
 
 
 
3
  import numpy as np
4
  import torch
5
 
6
- # Wildcard token ID for out-of-vocabulary characters
7
- WILDCARD_TOKEN = -1
8
 
9
  # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
10
  # Calibrated on librispeech-alignments dataset
@@ -21,6 +24,25 @@ def _get_device() -> str:
21
  return "cpu"
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class ForcedAligner:
25
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
26
 
@@ -52,30 +74,6 @@ class ForcedAligner:
52
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
53
  return cls._model, cls._labels, cls._dictionary
54
 
55
- @staticmethod
56
- def _get_emission_score(
57
- emission: torch.Tensor, token: int, blank_id: int = 0
58
- ) -> torch.Tensor:
59
- """Get emission score for a token, handling wildcards.
60
-
61
- For wildcard tokens (WILDCARD_TOKEN), returns the max score over all
62
- non-blank tokens - allowing any character to match.
63
-
64
- Args:
65
- emission: Emission vector for a single frame (num_classes,)
66
- token: Token index, or WILDCARD_TOKEN for out-of-vocabulary chars
67
- blank_id: Index of the blank/CTC token
68
-
69
- Returns:
70
- Emission score (scalar tensor)
71
- """
72
- if token == WILDCARD_TOKEN:
73
- # Wildcard: take max over all non-blank tokens
74
- mask = torch.ones(emission.size(0), dtype=torch.bool)
75
- mask[blank_id] = False
76
- return emission[mask].max()
77
- return emission[token]
78
-
79
  @staticmethod
80
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
81
  """Build trellis for forced alignment using forward algorithm.
@@ -85,7 +83,7 @@ class ForcedAligner:
85
 
86
  Args:
87
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
88
- tokens: List of target token indices (WILDCARD_TOKEN for OOV chars)
89
  blank_id: Index of the blank/CTC token (default 0)
90
 
91
  Returns:
@@ -103,13 +101,7 @@ class ForcedAligner:
103
  stay = trellis[t, j] + emission[t, blank_id]
104
 
105
  # Move: emit token j and advance to j+1 tokens
106
- if j > 0:
107
- token_score = ForcedAligner._get_emission_score(
108
- emission[t], tokens[j - 1], blank_id
109
- )
110
- move = trellis[t, j - 1] + token_score
111
- else:
112
- move = -float("inf")
113
 
114
  trellis[t + 1, j] = max(stay, move) # Viterbi: take best path
115
 
@@ -154,10 +146,7 @@ class ForcedAligner:
154
  while t > 0 and j > 0:
155
  # Check: did we transition from j-1 to j at frame t-1?
156
  stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
157
- token_score = ForcedAligner._get_emission_score(
158
- emission[t - 1], tokens[j - 1], blank_id
159
- )
160
- move_score = trellis[t - 1, j - 1] + token_score
161
 
162
  if move_score >= stay_score:
163
  # Token j-1 was emitted at frame t-1
@@ -189,6 +178,148 @@ class ForcedAligner:
189
 
190
  return token_spans
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  @classmethod
193
  def align(
194
  cls,
@@ -243,27 +374,29 @@ class ForcedAligner:
243
 
244
  emission = emissions[0].cpu()
245
 
246
- # Normalize text: uppercase
247
  transcript = text.upper()
248
 
249
  # Build tokens from transcript (including word separators)
250
- # Unknown characters get WILDCARD_TOKEN which matches any non-blank emission
251
  tokens = []
252
  for char in transcript:
253
  if char in dictionary:
254
  tokens.append(dictionary[char])
255
  elif char == " ":
256
  tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
257
- else:
258
- # Out-of-vocabulary character - use wildcard
259
- tokens.append(WILDCARD_TOKEN)
260
 
261
  if not tokens:
262
  return []
263
 
264
  # Build Viterbi trellis and backtrack for optimal path
265
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
266
- alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
 
 
 
 
 
 
267
 
268
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
269
  frame_duration = 320 / cls._bundle.sample_rate
 
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
  import numpy as np
7
  import torch
8
 
9
+ # Beam search width for backtracking (from WhisperX)
10
+ BEAM_WIDTH = 2
11
 
12
  # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
13
  # Calibrated on librispeech-alignments dataset
 
24
  return "cpu"
25
 
26
 
27
@dataclass
class Point:
    """A point in the alignment path."""

    # Trellis column at this step; 1-based over consumed tokens, so a
    # token_index of j corresponds to tokens[j - 1] (0 means "no token yet").
    token_index: int
    # Frame index into the emission matrix.
    time_index: int
    # Emission probability at this point (exp of the log-softmax emission,
    # i.e. a probability, not a log-prob).
    score: float
34
+
35
+
36
@dataclass
class BeamState:
    """State in beam search backtracking."""

    # Current trellis column (tokens still unexplained walking backward).
    token_index: int
    # Current frame index in the emission matrix.
    time_index: int
    # Trellis log-score of this state; used to rank and prune beams.
    score: float
    # Path of Points accumulated so far, in reverse-chronological order.
    path: list[Point]
44
+
45
+
46
  class ForcedAligner:
47
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
48
 
 
74
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
75
  return cls._model, cls._labels, cls._dictionary
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @staticmethod
78
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
79
  """Build trellis for forced alignment using forward algorithm.
 
83
 
84
  Args:
85
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
86
+ tokens: List of target token indices
87
  blank_id: Index of the blank/CTC token (default 0)
88
 
89
  Returns:
 
101
  stay = trellis[t, j] + emission[t, blank_id]
102
 
103
  # Move: emit token j and advance to j+1 tokens
104
+ move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
 
 
 
 
 
 
105
 
106
  trellis[t + 1, j] = max(stay, move) # Viterbi: take best path
107
 
 
146
  while t > 0 and j > 0:
147
  # Check: did we transition from j-1 to j at frame t-1?
148
  stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
149
+ move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
 
 
 
150
 
151
  if move_score >= stay_score:
152
  # Token j-1 was emitted at frame t-1
 
178
 
179
  return token_spans
180
 
181
    @staticmethod
    def _backtrack_beam(
        trellis: torch.Tensor,
        emission: torch.Tensor,
        tokens: list[int],
        blank_id: int = 0,
        beam_width: int = BEAM_WIDTH,
    ) -> list[Point] | None:
        """Beam search backtracking for better alignment paths.

        Explores multiple candidate paths simultaneously, keeping the top beam_width
        paths at each step. This can find better alignments than greedy backtracking.

        Based on WhisperX implementation.

        Args:
            trellis: Trellis matrix from forward pass
            emission: Log-softmax emission matrix
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token
            beam_width: Number of candidate paths to keep

        Returns:
            List of Points representing the best alignment path, or None if failed
        """
        # Start at the bottom-right corner of the trellis: all J tokens
        # consumed at the final frame T, and walk backward in time.
        T, J = trellis.size(0) - 1, trellis.size(1) - 1

        if J == 0:
            return None

        init_state = BeamState(
            token_index=J,
            time_index=T,
            # Beam ranking uses trellis log-scores; Point scores are
            # probabilities (exp of log-softmax emissions).
            score=trellis[T, J].item(),
            path=[Point(J, T, emission[T, blank_id].exp().item())],
        )

        beams = [init_state]

        # Loop until the best-scoring beam has worked its way back to
        # token column 0 (entire transcript explained).
        while beams and beams[0].token_index > 0:
            next_beams = []

            for beam in beams:
                t, j = beam.time_index, beam.token_index

                # Out of frames before reaching column 0: this beam dies.
                if t <= 0:
                    continue

                p_stay = emission[t - 1, blank_id]
                p_change = emission[t - 1, tokens[j - 1]] if j > 0 else float("-inf")

                stay_score = trellis[t - 1, j].item()
                change_score = trellis[t - 1, j - 1].item() if j > 0 else float("-inf")

                # Stay option: frame t-1 emitted blank, token column unchanged.
                if not math.isinf(stay_score):
                    new_path = beam.path.copy()
                    new_path.append(Point(j, t - 1, p_stay.exp().item()))
                    next_beams.append(
                        BeamState(
                            token_index=j,
                            time_index=t - 1,
                            score=stay_score,
                            path=new_path,
                        )
                    )

                # Change option: frame t-1 emitted tokens[j-1], so the beam
                # retreats to token column j-1.
                if j > 0 and not math.isinf(change_score):
                    new_path = beam.path.copy()
                    new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
                    next_beams.append(
                        BeamState(
                            token_index=j - 1,
                            time_index=t - 1,
                            score=change_score,
                            path=new_path,
                        )
                    )

            # Keep top beam_width paths by score
            beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

            if not beams:
                break

        # All candidate paths died (every beam ran out of frames): signal
        # failure so the caller can fall back to greedy backtracking.
        if not beams:
            return None

        # Fill remaining time steps with blank emissions
        best_beam = beams[0]
        t = best_beam.time_index
        j = best_beam.token_index
        while t > 0:
            prob = emission[t - 1, blank_id].exp().item()
            best_beam.path.append(Point(j, t - 1, prob))
            t -= 1

        # The path was built backward in time; reverse it to run start -> end.
        return best_beam.path[::-1]
280
+
281
+ @staticmethod
282
+ def _path_to_spans(
283
+ path: list[Point], tokens: list[int]
284
+ ) -> list[tuple[int, float, float]]:
285
+ """Convert a beam search path to token spans.
286
+
287
+ Args:
288
+ path: List of Points from beam search
289
+ tokens: List of target token indices
290
+
291
+ Returns:
292
+ List of (token_id, start_frame, end_frame) tuples
293
+ """
294
+ if not path or not tokens:
295
+ return []
296
+
297
+ num_tokens = len(tokens)
298
+ token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
299
+
300
+ # Group frames by token index
301
+ for point in path:
302
+ if 0 < point.token_index <= num_tokens:
303
+ token_frames[point.token_index - 1].append(point.time_index)
304
+
305
+ # Convert to spans
306
+ token_spans: list[tuple[int, float, float]] = []
307
+ for token_idx, frames in enumerate(token_frames):
308
+ if not frames:
309
+ # Token never emitted - assign minimal span after previous
310
+ if token_spans:
311
+ prev_end = token_spans[-1][2]
312
+ frames = [int(prev_end)]
313
+ else:
314
+ frames = [0]
315
+
316
+ token_id = tokens[token_idx]
317
+ start_frame = float(min(frames))
318
+ end_frame = float(max(frames)) + 1.0
319
+ token_spans.append((token_id, start_frame, end_frame))
320
+
321
+ return token_spans
322
+
323
  @classmethod
324
  def align(
325
  cls,
 
374
 
375
  emission = emissions[0].cpu()
376
 
377
+ # Normalize text: uppercase, keep only valid characters
378
  transcript = text.upper()
379
 
380
  # Build tokens from transcript (including word separators)
 
381
  tokens = []
382
  for char in transcript:
383
  if char in dictionary:
384
  tokens.append(dictionary[char])
385
  elif char == " ":
386
  tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
 
 
 
387
 
388
  if not tokens:
389
  return []
390
 
391
  # Build Viterbi trellis and backtrack for optimal path
392
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
393
+
394
+ # Try beam search first, fall back to greedy if it fails
395
+ beam_path = cls._backtrack_beam(trellis, emission, tokens, blank_id=0)
396
+ if beam_path is not None:
397
+ alignment_path = cls._path_to_spans(beam_path, tokens)
398
+ else:
399
+ alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
400
 
401
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
402
  frame_duration = 320 / cls._bundle.sample_rate