mazesmazes committed on
Commit
10d57f2
·
verified ·
1 Parent(s): f3ed069

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +164 -60
alignment.py CHANGED
@@ -1,9 +1,31 @@
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
 
 
 
3
  import numpy as np
4
  import torch
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def _get_device() -> str:
8
  """Get best available device for non-transformers models."""
9
  if torch.cuda.is_available():
@@ -16,7 +38,7 @@ def _get_device() -> str:
16
  class ForcedAligner:
17
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
18
 
19
- Uses Viterbi trellis algorithm for optimal alignment path finding.
20
  """
21
 
22
  _bundle = None
@@ -78,75 +100,158 @@ class ForcedAligner:
78
  return trellis
79
 
80
  @staticmethod
81
- def _backtrack(
82
- trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
83
- ) -> list[tuple[int, float, float]]:
84
- """Backtrack through trellis to find optimal forced monotonic alignment.
 
 
 
 
 
 
 
85
 
86
- Guarantees:
87
- - All tokens are emitted exactly once
88
- - Strictly monotonic: each token's frames come after previous token's
89
- - No frame skipping or token teleporting
 
 
90
 
91
- Returns list of (token_id, start_frame, end_frame) for each token.
 
92
  """
93
- num_frames = emission.size(0)
94
- num_tokens = len(tokens)
95
 
96
  if num_tokens == 0:
97
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Find the best ending point (should be at num_tokens)
100
- # But verify trellis reached a valid state
101
- if trellis[num_frames, num_tokens] == -float("inf"):
102
- # Alignment failed - fall back to uniform distribution
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  frames_per_token = num_frames / num_tokens
104
  return [
105
  (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
106
  for i in range(num_tokens)
107
  ]
108
 
109
- # Backtrack: find where each token transition occurred
110
- # path[i] = frame where token i was first emitted
111
  token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
112
-
113
- t = num_frames
114
- j = num_tokens
115
-
116
- while t > 0 and j > 0:
117
- # Check: did we transition from j-1 to j at frame t-1?
118
- stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
119
- move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
120
-
121
- if move_score >= stay_score:
122
- # Token j-1 was emitted at frame t-1
123
- token_frames[j - 1].insert(0, t - 1)
124
- j -= 1
125
- # Always decrement time (monotonic)
126
- t -= 1
127
-
128
- # Handle any remaining tokens at the start (edge case)
129
- while j > 0:
130
- token_frames[j - 1].insert(0, 0)
131
- j -= 1
132
 
133
  # Convert to spans
134
  token_spans: list[tuple[int, float, float]] = []
135
- for token_idx, frames in enumerate(token_frames):
 
136
  if not frames:
137
- # Token never emitted - assign minimal span after previous
138
  if token_spans:
139
  prev_end = token_spans[-1][2]
140
- frames = [int(prev_end)]
141
  else:
142
- frames = [0]
143
-
144
- token_id = tokens[token_idx]
145
- frame_probs = emission[frames, token_id]
146
- peak_idx = int(torch.argmax(frame_probs).item())
147
- peak_frame = frames[peak_idx]
148
-
149
- token_spans.append((token_id, float(peak_frame), float(peak_frame) + 1.0))
150
 
151
  return token_spans
152
 
@@ -158,22 +263,20 @@ class ForcedAligner:
158
  @classmethod
159
  def align(
160
  cls,
161
- audio: np.ndarray,
162
  text: str,
163
  sample_rate: int = 16000,
164
- _language: str = "eng",
165
- _batch_size: int = 16,
166
  ) -> list[dict]:
167
  """Align transcript to audio and return word-level timestamps.
168
 
169
- Uses Viterbi trellis algorithm for optimal forced alignment.
170
 
171
  Args:
172
- audio: Audio waveform as numpy array
173
  text: Transcript text to align
174
  sample_rate: Audio sample rate (default 16000)
175
- _language: ISO-639-3 language code (default "eng" for English, unused)
176
- _batch_size: Batch size for alignment model (unused)
177
 
178
  Returns:
179
  List of dicts with 'word', 'start', 'end' keys
@@ -181,7 +284,7 @@ class ForcedAligner:
181
  import torchaudio
182
 
183
  device = _get_device()
184
- model, _labels, dictionary = cls.get_instance(device)
185
  assert cls._bundle is not None and dictionary is not None # Initialized by get_instance
186
 
187
  # Convert audio to tensor (copy to ensure array is writable)
@@ -223,9 +326,10 @@ class ForcedAligner:
223
  if not tokens:
224
  return []
225
 
226
- # Build Viterbi trellis and backtrack for optimal path
227
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
228
- alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
 
229
 
230
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
231
  frame_duration = 320 / cls._bundle.sample_rate
 
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
3
+ import math
4
+ from dataclasses import dataclass, field
5
+
6
  import numpy as np
7
  import torch
8
 
9
 
10
@dataclass
class Point:
    """A point in the alignment path."""

    # Trellis token state at this point (1-based; 0 means "before the first
    # token" — see how _path_to_spans skips index 0 when bucketing frames).
    token_index: int
    # Emission frame index this point refers to.
    time_index: int
    # Probability attached to this point; in _backtrack_beam it is the
    # exponentiated log-softmax emission value for the emitted symbol.
    score: float
17
+
18
+
19
@dataclass
class BeamState:
    """State in beam search backtracking."""

    # Trellis token state this hypothesis currently sits at
    # (number of target tokens still unexplained to the left).
    token_index: int
    # Trellis time (frame) position this hypothesis currently sits at.
    time_index: int
    # Score used to rank and prune beams; taken directly from the trellis
    # cell the hypothesis moved to.
    score: float
    # Alignment points collected so far. Built back-to-front during
    # backtracking; the caller reverses it before returning.
    path: list[Point] = field(default_factory=list)
27
+
28
+
29
  def _get_device() -> str:
30
  """Get best available device for non-transformers models."""
31
  if torch.cuda.is_available():
 
38
  class ForcedAligner:
39
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
40
 
41
+ Uses CTC trellis with beam search backtracking for optimal alignment path finding.
42
  """
43
 
44
  _bundle = None
 
100
  return trellis
101
 
102
    @staticmethod
    def _backtrack_beam(
        trellis: torch.Tensor,
        emission: torch.Tensor,
        tokens: list[int],
        blank_id: int = 0,
        beam_width: int = 5,
    ) -> list[Point] | None:
        """Beam search backtracking through trellis.

        Walks the trellis backwards from (num_frames, num_tokens) to token 0,
        maintaining multiple hypotheses and pruning to the top candidates by
        score at each step.

        Args:
            trellis: Trellis matrix of shape (num_frames + 1, num_tokens + 1)
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)
            beam_width: Number of top paths to keep during beam search (default 5)

        Returns:
            List of Point objects representing the best alignment path
            (in forward time order), or None if alignment failed.
        """
        # Trellis has one extra row/column for the "zero frames / zero tokens"
        # boundary states, hence the -1.
        num_frames = trellis.size(0) - 1
        num_tokens = trellis.size(1) - 1

        if num_tokens == 0:
            return None

        # Check if alignment is possible: an infinite terminal cell means the
        # forward pass never reached the (all frames, all tokens) state.
        if math.isinf(trellis[num_frames, num_tokens].item()):
            return None

        # Initialize beam with final state. The seed point reuses the blank
        # probability of the last frame as its score.
        init_state = BeamState(
            token_index=num_tokens,
            time_index=num_frames,
            score=trellis[num_frames, num_tokens].item(),
            path=[Point(num_tokens, num_frames, emission[num_frames - 1, blank_id].exp().item())],
        )
        beams = [init_state]

        # Beam search backtracking.
        # NOTE(review): the loop condition inspects only the best beam's
        # token_index; lower-ranked beams that already reached token 0 are
        # still expanded on later iterations — confirm this is intended.
        while beams and beams[0].token_index > 0:
            next_beams = []

            for beam in beams:
                t, j = beam.time_index, beam.token_index

                # NOTE(review): a hypothesis that runs out of frames (t <= 0)
                # while j > 0 is silently dropped here rather than completed.
                if t <= 0:
                    continue

                # NOTE(review): beam scores are the raw trellis cell values of
                # the predecessor states; the transition's own emission term is
                # not added, so this ranking differs from an exact Viterbi
                # backtrace — verify against the trellis construction.
                stay_score = trellis[t - 1, j].item()
                change_score = trellis[t - 1, j - 1].item() if j > 0 else float("-inf")

                # Stay transition (emit blank): remain on token state j,
                # step one frame back.
                if not math.isinf(stay_score):
                    prob = emission[t - 1, blank_id].exp().item()
                    new_path = beam.path.copy()
                    new_path.append(Point(j, t - 1, prob))
                    next_beams.append(
                        BeamState(
                            token_index=j,
                            time_index=t - 1,
                            score=stay_score,
                            path=new_path,
                        )
                    )

                # Change transition (emit token): consume token j-1 at frame
                # t-1 and move to token state j-1.
                if j > 0 and not math.isinf(change_score):
                    prob = emission[t - 1, tokens[j - 1]].exp().item()
                    new_path = beam.path.copy()
                    # NOTE(review): the emitted token's frame is recorded with
                    # token_index = j - 1 (the *destination* state); downstream
                    # _path_to_spans buckets via token_index - 1, which
                    # attributes this frame to the previous token — confirm
                    # the off-by-one is intentional.
                    new_path.append(Point(j - 1, t - 1, prob))
                    next_beams.append(
                        BeamState(
                            token_index=j - 1,
                            time_index=t - 1,
                            score=change_score,
                            path=new_path,
                        )
                    )

            # Prune to top beam_width candidates by score (descending).
            beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

            if not beams:
                break

        if not beams:
            return None

        # Complete path to beginning: pad remaining frames of the best beam
        # with blank emissions so every frame appears in the path.
        best_beam = beams[0]
        t = best_beam.time_index
        j = best_beam.token_index

        while t > 0:
            prob = emission[t - 1, blank_id].exp().item()
            best_beam.path.append(Point(j, t - 1, prob))
            t -= 1

        # Path was built backwards in time; reverse into forward order.
        return best_beam.path[::-1]
205
+
206
+ @staticmethod
207
+ def _path_to_spans(
208
+ path: list[Point] | None, tokens: list[int], num_frames: int
209
+ ) -> list[tuple[int, float, float]]:
210
+ """Convert beam search path to token spans.
211
+
212
+ Args:
213
+ path: List of Point objects from beam search, or None
214
+ tokens: List of target token indices
215
+ num_frames: Total number of frames
216
+
217
+ Returns:
218
+ List of (token_id, start_frame, end_frame) for each token.
219
+ """
220
+ num_tokens = len(tokens)
221
+
222
+ if path is None or num_tokens == 0:
223
+ # Fall back to uniform distribution
224
+ if num_tokens == 0:
225
+ return []
226
  frames_per_token = num_frames / num_tokens
227
  return [
228
  (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
229
  for i in range(num_tokens)
230
  ]
231
 
232
+ # Group frames by token index
 
233
  token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
234
+ for point in path:
235
+ # Token index in path is 1-indexed (0 = before first token)
236
+ if 0 < point.token_index <= num_tokens:
237
+ token_frames[point.token_index - 1].append(point.time_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  # Convert to spans
240
  token_spans: list[tuple[int, float, float]] = []
241
+ for token_idx in range(num_tokens):
242
+ frames = token_frames[token_idx]
243
  if not frames:
244
+ # Token never emitted - assign span after previous
245
  if token_spans:
246
  prev_end = token_spans[-1][2]
247
+ start_frame = prev_end
248
  else:
249
+ start_frame = 0.0
250
+ token_spans.append((tokens[token_idx], start_frame, start_frame + 1.0))
251
+ else:
252
+ start_frame = float(min(frames))
253
+ end_frame = float(max(frames)) + 1.0
254
+ token_spans.append((tokens[token_idx], start_frame, end_frame))
 
 
255
 
256
  return token_spans
257
 
 
263
  @classmethod
264
  def align(
265
  cls,
266
+ audio: np.ndarray | torch.Tensor,
267
  text: str,
268
  sample_rate: int = 16000,
269
+ beam_width: int = 5,
 
270
  ) -> list[dict]:
271
  """Align transcript to audio and return word-level timestamps.
272
 
273
+ Uses CTC trellis with beam search backtracking for optimal forced alignment.
274
 
275
  Args:
276
+ audio: Audio waveform as numpy array or torch tensor
277
  text: Transcript text to align
278
  sample_rate: Audio sample rate (default 16000)
279
+ beam_width: Number of paths to keep during beam search (default 5)
 
280
 
281
  Returns:
282
  List of dicts with 'word', 'start', 'end' keys
 
284
  import torchaudio
285
 
286
  device = _get_device()
287
+ model, _, dictionary = cls.get_instance(device)
288
  assert cls._bundle is not None and dictionary is not None # Initialized by get_instance
289
 
290
  # Convert audio to tensor (copy to ensure array is writable)
 
326
  if not tokens:
327
  return []
328
 
329
+ # Build CTC trellis and use beam search backtracking
330
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
331
+ path = cls._backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=beam_width)
332
+ alignment_path = cls._path_to_spans(path, tokens, emission.size(0))
333
 
334
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
335
  frame_duration = 320 / cls._bundle.sample_rate