mazesmazes committed on
Commit
acc869c
·
verified ·
1 Parent(s): 86c5e8d

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +33 -184
alignment.py CHANGED
@@ -1,14 +1,8 @@
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
3
- import math
4
- from dataclasses import dataclass
5
-
6
  import numpy as np
7
  import torch
8
 
9
- # Beam search width for backtracking (from WhisperX)
10
- BEAM_WIDTH = 2
11
-
12
  # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
13
  # Calibrated on librispeech-alignments dataset
14
  START_OFFSET = 0.06 # Subtract from start times (shift earlier)
@@ -24,25 +18,6 @@ def _get_device() -> str:
24
  return "cpu"
25
 
26
 
27
- @dataclass
28
- class Point:
29
- """A point in the alignment path."""
30
-
31
- token_index: int
32
- time_index: int
33
- score: float
34
-
35
-
36
- @dataclass
37
- class BeamState:
38
- """State in beam search backtracking."""
39
-
40
- token_index: int
41
- time_index: int
42
- score: float
43
- path: list[Point]
44
-
45
-
46
  class ForcedAligner:
47
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
48
 
@@ -113,6 +88,10 @@ class ForcedAligner:
113
  ) -> list[tuple[int, float, float]]:
114
  """Backtrack through trellis to find optimal forced monotonic alignment.
115
 
 
 
 
 
116
  Guarantees:
117
  - All tokens are emitted exactly once
118
  - Strictly monotonic: each token's frames come after previous token's
@@ -137,8 +116,8 @@ class ForcedAligner:
137
  ]
138
 
139
  # Backtrack: find where each token transition occurred
140
- # path[i] = frame where token i was first emitted
141
- token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
142
 
143
  t = num_frames
144
  j = num_tokens
@@ -150,172 +129,48 @@ class ForcedAligner:
150
 
151
  if move_score >= stay_score:
152
  # Token j-1 was emitted at frame t-1
153
- token_frames[j - 1].insert(0, t - 1)
 
 
154
  j -= 1
155
  # Always decrement time (monotonic)
156
  t -= 1
157
 
158
  # Handle any remaining tokens at the start (edge case)
159
  while j > 0:
160
- token_frames[j - 1].insert(0, 0)
161
  j -= 1
162
 
163
- # Convert to spans
164
  token_spans: list[tuple[int, float, float]] = []
165
- for token_idx, frames in enumerate(token_frames):
166
- if not frames:
167
  # Token never emitted - assign minimal span after previous
168
  if token_spans:
169
  prev_end = token_spans[-1][2]
170
- frames = [int(prev_end)]
171
  else:
172
- frames = [0]
173
 
174
  token_id = tokens[token_idx]
175
- start_frame = float(min(frames))
176
- end_frame = float(max(frames)) + 1.0
177
- token_spans.append((token_id, start_frame, end_frame))
178
-
179
- return token_spans
180
-
181
- @staticmethod
182
- def _backtrack_beam(
183
- trellis: torch.Tensor,
184
- emission: torch.Tensor,
185
- tokens: list[int],
186
- blank_id: int = 0,
187
- beam_width: int = BEAM_WIDTH,
188
- ) -> list[Point] | None:
189
- """Beam search backtracking for better alignment paths.
190
-
191
- Explores multiple candidate paths simultaneously, keeping the top beam_width
192
- paths at each step. This can find better alignments than greedy backtracking.
193
-
194
- Based on WhisperX implementation.
195
-
196
- Args:
197
- trellis: Trellis matrix from forward pass
198
- emission: Log-softmax emission matrix
199
- tokens: List of target token indices
200
- blank_id: Index of the blank/CTC token
201
- beam_width: Number of candidate paths to keep
202
-
203
- Returns:
204
- List of Points representing the best alignment path, or None if failed
205
- """
206
- T, J = trellis.size(0) - 1, trellis.size(1) - 1
207
-
208
- if J == 0:
209
- return None
210
-
211
- init_state = BeamState(
212
- token_index=J,
213
- time_index=T,
214
- score=trellis[T, J].item(),
215
- path=[Point(J, T, emission[T, blank_id].exp().item())],
216
- )
217
-
218
- beams = [init_state]
219
-
220
- while beams and beams[0].token_index > 0:
221
- next_beams = []
222
-
223
- for beam in beams:
224
- t, j = beam.time_index, beam.token_index
225
-
226
- if t <= 0:
227
- continue
228
-
229
- p_stay = emission[t - 1, blank_id]
230
- p_change = emission[t - 1, tokens[j - 1]] if j > 0 else float("-inf")
231
-
232
- stay_score = trellis[t - 1, j].item()
233
- change_score = trellis[t - 1, j - 1].item() if j > 0 else float("-inf")
234
-
235
- # Stay option
236
- if not math.isinf(stay_score):
237
- new_path = beam.path.copy()
238
- new_path.append(Point(j, t - 1, p_stay.exp().item()))
239
- next_beams.append(
240
- BeamState(
241
- token_index=j,
242
- time_index=t - 1,
243
- score=stay_score,
244
- path=new_path,
245
- )
246
- )
247
-
248
- # Change option
249
- if j > 0 and not math.isinf(change_score):
250
- new_path = beam.path.copy()
251
- new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
252
- next_beams.append(
253
- BeamState(
254
- token_index=j - 1,
255
- time_index=t - 1,
256
- score=change_score,
257
- path=new_path,
258
- )
259
- )
260
-
261
- # Keep top beam_width paths by score
262
- beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
263
-
264
- if not beams:
265
- break
266
-
267
- if not beams:
268
- return None
269
-
270
- # Fill remaining time steps with blank emissions
271
- best_beam = beams[0]
272
- t = best_beam.time_index
273
- j = best_beam.token_index
274
- while t > 0:
275
- prob = emission[t - 1, blank_id].exp().item()
276
- best_beam.path.append(Point(j, t - 1, prob))
277
- t -= 1
278
-
279
- return best_beam.path[::-1]
280
-
281
- @staticmethod
282
- def _path_to_spans(
283
- path: list[Point], tokens: list[int]
284
- ) -> list[tuple[int, float, float]]:
285
- """Convert a beam search path to token spans.
286
-
287
- Args:
288
- path: List of Points from beam search
289
- tokens: List of target token indices
290
-
291
- Returns:
292
- List of (token_id, start_frame, end_frame) tuples
293
- """
294
- if not path or not tokens:
295
- return []
296
-
297
- num_tokens = len(tokens)
298
- token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
299
-
300
- # Group frames by token index
301
- for point in path:
302
- if 0 < point.token_index <= num_tokens:
303
- token_frames[point.token_index - 1].append(point.time_index)
304
-
305
- # Convert to spans
306
- token_spans: list[tuple[int, float, float]] = []
307
- for token_idx, frames in enumerate(token_frames):
308
- if not frames:
309
- # Token never emitted - assign minimal span after previous
310
- if token_spans:
311
- prev_end = token_spans[-1][2]
312
- frames = [int(prev_end)]
313
- else:
314
- frames = [0]
315
 
316
- token_id = tokens[token_idx]
317
- start_frame = float(min(frames))
318
- end_frame = float(max(frames)) + 1.0
319
  token_spans.append((token_id, start_frame, end_frame))
320
 
321
  return token_spans
@@ -390,13 +245,7 @@ class ForcedAligner:
390
 
391
  # Build Viterbi trellis and backtrack for optimal path
392
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
393
-
394
- # Try beam search first, fall back to greedy if it fails
395
- beam_path = cls._backtrack_beam(trellis, emission, tokens, blank_id=0)
396
- if beam_path is not None:
397
- alignment_path = cls._path_to_spans(beam_path, tokens)
398
- else:
399
- alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
400
 
401
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
402
  frame_duration = 320 / cls._bundle.sample_rate
 
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
 
 
 
3
  import numpy as np
4
  import torch
5
 
 
 
 
6
  # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
7
  # Calibrated on librispeech-alignments dataset
8
  START_OFFSET = 0.06 # Subtract from start times (shift earlier)
 
18
  return "cpu"
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class ForcedAligner:
22
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
23
 
 
88
  ) -> list[tuple[int, float, float]]:
89
  """Backtrack through trellis to find optimal forced monotonic alignment.
90
 
91
+ Uses emission probability weighting for sub-frame precision. Since wav2vec2
92
+ has 20ms frame resolution, weighting by emission scores can improve accuracy
93
+ by estimating where within a frame the token boundary likely falls.
94
+
95
  Guarantees:
96
  - All tokens are emitted exactly once
97
  - Strictly monotonic: each token's frames come after previous token's
 
116
  ]
117
 
118
  # Backtrack: find where each token transition occurred
119
+ # path[i] = list of (frame, score) tuples where token i was emitted
120
+ token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
121
 
122
  t = num_frames
123
  j = num_tokens
 
129
 
130
  if move_score >= stay_score:
131
  # Token j-1 was emitted at frame t-1
132
+ # Store frame index and emission probability for weighting
133
+ prob = emission[t - 1, tokens[j - 1]].exp().item()
134
+ token_frames[j - 1].insert(0, (t - 1, prob))
135
  j -= 1
136
  # Always decrement time (monotonic)
137
  t -= 1
138
 
139
  # Handle any remaining tokens at the start (edge case)
140
  while j > 0:
141
+ token_frames[j - 1].insert(0, (0, 0.0))
142
  j -= 1
143
 
144
+ # Convert to spans with emission-weighted sub-frame precision
145
  token_spans: list[tuple[int, float, float]] = []
146
+ for token_idx, frames_with_scores in enumerate(token_frames):
147
+ if not frames_with_scores:
148
  # Token never emitted - assign minimal span after previous
149
  if token_spans:
150
  prev_end = token_spans[-1][2]
151
+ frames_with_scores = [(int(prev_end), 0.0)]
152
  else:
153
+ frames_with_scores = [(0, 0.0)]
154
 
155
  token_id = tokens[token_idx]
156
+ frames = [f for f, _ in frames_with_scores]
157
+ scores = [s for _, s in frames_with_scores]
158
+
159
+ # Compute emission-weighted start position for sub-frame precision
160
+ # Weight shifts the position toward frames with higher emission probability
161
+ total_score = sum(scores)
162
+ if total_score > 0 and len(frames) > 1:
163
+ # Weighted centroid gives sub-frame precision
164
+ weighted_center = sum(f * s for f, s in zip(frames, scores)) / total_score
165
+ # Estimate start/end based on weighted center and span width
166
+ span_width = max(frames) - min(frames) + 1
167
+ start_frame = weighted_center - span_width / 2
168
+ end_frame = weighted_center + span_width / 2
169
+ else:
170
+ # Fall back to simple min/max
171
+ start_frame = float(min(frames))
172
+ end_frame = float(max(frames)) + 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
 
 
 
174
  token_spans.append((token_id, start_frame, end_frame))
175
 
176
  return token_spans
 
245
 
246
  # Build Viterbi trellis and backtrack for optimal path
247
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
248
+ alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
 
 
 
 
 
 
249
 
250
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
251
  frame_duration = 320 / cls._bundle.sample_rate