mazesmazes committed on
Commit
d26738c
·
verified ·
1 Parent(s): 10d57f2

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +58 -164
alignment.py CHANGED
@@ -1,31 +1,9 @@
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
3
- import math
4
- from dataclasses import dataclass, field
5
-
6
  import numpy as np
7
  import torch
8
 
9
 
10
- @dataclass
11
- class Point:
12
- """A point in the alignment path."""
13
-
14
- token_index: int
15
- time_index: int
16
- score: float
17
-
18
-
19
- @dataclass
20
- class BeamState:
21
- """State in beam search backtracking."""
22
-
23
- token_index: int
24
- time_index: int
25
- score: float
26
- path: list[Point] = field(default_factory=list)
27
-
28
-
29
  def _get_device() -> str:
30
  """Get best available device for non-transformers models."""
31
  if torch.cuda.is_available():
@@ -38,7 +16,7 @@ def _get_device() -> str:
38
  class ForcedAligner:
39
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
40
 
41
- Uses CTC trellis with beam search backtracking for optimal alignment path finding.
42
  """
43
 
44
  _bundle = None
@@ -100,158 +78,73 @@ class ForcedAligner:
100
  return trellis
101
 
102
  @staticmethod
103
- def _backtrack_beam(
104
- trellis: torch.Tensor,
105
- emission: torch.Tensor,
106
- tokens: list[int],
107
- blank_id: int = 0,
108
- beam_width: int = 5,
109
- ) -> list[Point] | None:
110
- """Beam search backtracking through trellis.
111
-
112
- Maintains multiple hypotheses during decoding, pruning to top candidates
113
- by cumulative score at each step.
114
-
115
- Args:
116
- trellis: Trellis matrix of shape (num_frames + 1, num_tokens + 1)
117
- emission: Log-softmax emission matrix of shape (num_frames, num_classes)
118
- tokens: List of target token indices
119
- blank_id: Index of the blank/CTC token (default 0)
120
- beam_width: Number of top paths to keep during beam search (default 5)
121
-
122
- Returns:
123
- List of Point objects representing the best alignment path, or None if failed.
124
- """
125
- num_frames = trellis.size(0) - 1
126
- num_tokens = trellis.size(1) - 1
127
-
128
- if num_tokens == 0:
129
- return None
130
-
131
- # Check if alignment is possible
132
- if math.isinf(trellis[num_frames, num_tokens].item()):
133
- return None
134
-
135
- # Initialize beam with final state
136
- init_state = BeamState(
137
- token_index=num_tokens,
138
- time_index=num_frames,
139
- score=trellis[num_frames, num_tokens].item(),
140
- path=[Point(num_tokens, num_frames, emission[num_frames - 1, blank_id].exp().item())],
141
- )
142
- beams = [init_state]
143
-
144
- # Beam search backtracking
145
- while beams and beams[0].token_index > 0:
146
- next_beams = []
147
-
148
- for beam in beams:
149
- t, j = beam.time_index, beam.token_index
150
-
151
- if t <= 0:
152
- continue
153
-
154
- stay_score = trellis[t - 1, j].item()
155
- change_score = trellis[t - 1, j - 1].item() if j > 0 else float("-inf")
156
-
157
- # Stay transition (emit blank)
158
- if not math.isinf(stay_score):
159
- prob = emission[t - 1, blank_id].exp().item()
160
- new_path = beam.path.copy()
161
- new_path.append(Point(j, t - 1, prob))
162
- next_beams.append(
163
- BeamState(
164
- token_index=j,
165
- time_index=t - 1,
166
- score=stay_score,
167
- path=new_path,
168
- )
169
- )
170
-
171
- # Change transition (emit token)
172
- if j > 0 and not math.isinf(change_score):
173
- prob = emission[t - 1, tokens[j - 1]].exp().item()
174
- new_path = beam.path.copy()
175
- new_path.append(Point(j - 1, t - 1, prob))
176
- next_beams.append(
177
- BeamState(
178
- token_index=j - 1,
179
- time_index=t - 1,
180
- score=change_score,
181
- path=new_path,
182
- )
183
- )
184
-
185
- # Prune to top beam_width candidates
186
- beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
187
-
188
- if not beams:
189
- break
190
-
191
- if not beams:
192
- return None
193
-
194
- # Complete path to beginning
195
- best_beam = beams[0]
196
- t = best_beam.time_index
197
- j = best_beam.token_index
198
-
199
- while t > 0:
200
- prob = emission[t - 1, blank_id].exp().item()
201
- best_beam.path.append(Point(j, t - 1, prob))
202
- t -= 1
203
-
204
- return best_beam.path[::-1]
205
-
206
- @staticmethod
207
- def _path_to_spans(
208
- path: list[Point] | None, tokens: list[int], num_frames: int
209
  ) -> list[tuple[int, float, float]]:
210
- """Convert beam search path to token spans.
211
 
212
- Args:
213
- path: List of Point objects from beam search, or None
214
- tokens: List of target token indices
215
- num_frames: Total number of frames
216
 
217
- Returns:
218
- List of (token_id, start_frame, end_frame) for each token.
219
  """
 
220
  num_tokens = len(tokens)
221
 
222
- if path is None or num_tokens == 0:
223
- # Fall back to uniform distribution
224
- if num_tokens == 0:
225
- return []
 
 
 
226
  frames_per_token = num_frames / num_tokens
227
  return [
228
  (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
229
  for i in range(num_tokens)
230
  ]
231
 
232
- # Group frames by token index
 
233
  token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
234
- for point in path:
235
- # Token index in path is 1-indexed (0 = before first token)
236
- if 0 < point.token_index <= num_tokens:
237
- token_frames[point.token_index - 1].append(point.time_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  # Convert to spans
240
  token_spans: list[tuple[int, float, float]] = []
241
- for token_idx in range(num_tokens):
242
- frames = token_frames[token_idx]
243
  if not frames:
244
- # Token never emitted - assign span after previous
245
  if token_spans:
246
  prev_end = token_spans[-1][2]
247
- start_frame = prev_end
248
  else:
249
- start_frame = 0.0
250
- token_spans.append((tokens[token_idx], start_frame, start_frame + 1.0))
251
- else:
252
- start_frame = float(min(frames))
253
- end_frame = float(max(frames)) + 1.0
254
- token_spans.append((tokens[token_idx], start_frame, end_frame))
255
 
256
  return token_spans
257
 
@@ -263,20 +156,22 @@ class ForcedAligner:
263
  @classmethod
264
  def align(
265
  cls,
266
- audio: np.ndarray | torch.Tensor,
267
  text: str,
268
  sample_rate: int = 16000,
269
- beam_width: int = 5,
 
270
  ) -> list[dict]:
271
  """Align transcript to audio and return word-level timestamps.
272
 
273
- Uses CTC trellis with beam search backtracking for optimal forced alignment.
274
 
275
  Args:
276
- audio: Audio waveform as numpy array or torch tensor
277
  text: Transcript text to align
278
  sample_rate: Audio sample rate (default 16000)
279
- beam_width: Number of paths to keep during beam search (default 5)
 
280
 
281
  Returns:
282
  List of dicts with 'word', 'start', 'end' keys
@@ -284,7 +179,7 @@ class ForcedAligner:
284
  import torchaudio
285
 
286
  device = _get_device()
287
- model, _, dictionary = cls.get_instance(device)
288
  assert cls._bundle is not None and dictionary is not None # Initialized by get_instance
289
 
290
  # Convert audio to tensor (copy to ensure array is writable)
@@ -326,10 +221,9 @@ class ForcedAligner:
326
  if not tokens:
327
  return []
328
 
329
- # Build CTC trellis and use beam search backtracking
330
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
331
- path = cls._backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=beam_width)
332
- alignment_path = cls._path_to_spans(path, tokens, emission.size(0))
333
 
334
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
335
  frame_duration = 320 / cls._bundle.sample_rate
 
1
  """Forced alignment for word-level timestamps using Wav2Vec2."""
2
 
 
 
 
3
  import numpy as np
4
  import torch
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def _get_device() -> str:
8
  """Get best available device for non-transformers models."""
9
  if torch.cuda.is_available():
 
16
  class ForcedAligner:
17
  """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
18
 
19
+ Uses Viterbi trellis algorithm for optimal alignment path finding.
20
  """
21
 
22
  _bundle = None
 
78
  return trellis
79
 
80
  @staticmethod
81
+ def _backtrack(
82
+ trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  ) -> list[tuple[int, float, float]]:
84
+ """Backtrack through trellis to find optimal forced monotonic alignment.
85
 
86
+ Guarantees:
87
+ - All tokens are emitted exactly once
88
+ - Strictly monotonic: each token's frames come after previous token's
89
+ - No frame skipping or token teleporting
90
 
91
+ Returns list of (token_id, start_frame, end_frame) for each token.
 
92
  """
93
+ num_frames = emission.size(0)
94
  num_tokens = len(tokens)
95
 
96
+ if num_tokens == 0:
97
+ return []
98
+
99
+ # Find the best ending point (should be at num_tokens)
100
+ # But verify trellis reached a valid state
101
+ if trellis[num_frames, num_tokens] == -float("inf"):
102
+ # Alignment failed - fall back to uniform distribution
103
  frames_per_token = num_frames / num_tokens
104
  return [
105
  (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
106
  for i in range(num_tokens)
107
  ]
108
 
109
+ # Backtrack: find where each token transition occurred
110
+ # path[i] = frame where token i was first emitted
111
  token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
112
+
113
+ t = num_frames
114
+ j = num_tokens
115
+
116
+ while t > 0 and j > 0:
117
+ # Check: did we transition from j-1 to j at frame t-1?
118
+ stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
119
+ move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
120
+
121
+ if move_score >= stay_score:
122
+ # Token j-1 was emitted at frame t-1
123
+ token_frames[j - 1].insert(0, t - 1)
124
+ j -= 1
125
+ # Always decrement time (monotonic)
126
+ t -= 1
127
+
128
+ # Handle any remaining tokens at the start (edge case)
129
+ while j > 0:
130
+ token_frames[j - 1].insert(0, 0)
131
+ j -= 1
132
 
133
  # Convert to spans
134
  token_spans: list[tuple[int, float, float]] = []
135
+ for token_idx, frames in enumerate(token_frames):
 
136
  if not frames:
137
+ # Token never emitted - assign minimal span after previous
138
  if token_spans:
139
  prev_end = token_spans[-1][2]
140
+ frames = [int(prev_end)]
141
  else:
142
+ frames = [0]
143
+
144
+ token_id = tokens[token_idx]
145
+ start_frame = float(min(frames))
146
+ end_frame = float(max(frames)) + 1.0
147
+ token_spans.append((token_id, start_frame, end_frame))
148
 
149
  return token_spans
150
 
 
156
  @classmethod
157
  def align(
158
  cls,
159
+ audio: np.ndarray,
160
  text: str,
161
  sample_rate: int = 16000,
162
+ _language: str = "eng",
163
+ _batch_size: int = 16,
164
  ) -> list[dict]:
165
  """Align transcript to audio and return word-level timestamps.
166
 
167
+ Uses Viterbi trellis algorithm for optimal forced alignment.
168
 
169
  Args:
170
+ audio: Audio waveform as numpy array
171
  text: Transcript text to align
172
  sample_rate: Audio sample rate (default 16000)
173
+ _language: ISO-639-3 language code (default "eng" for English, unused)
174
+ _batch_size: Batch size for alignment model (unused)
175
 
176
  Returns:
177
  List of dicts with 'word', 'start', 'end' keys
 
179
  import torchaudio
180
 
181
  device = _get_device()
182
+ model, _labels, dictionary = cls.get_instance(device)
183
  assert cls._bundle is not None and dictionary is not None # Initialized by get_instance
184
 
185
  # Convert audio to tensor (copy to ensure array is writable)
 
221
  if not tokens:
222
  return []
223
 
224
+ # Build Viterbi trellis and backtrack for optimal path
225
  trellis = cls._get_trellis(emission, tokens, blank_id=0)
226
+ alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
 
227
 
228
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
229
  frame_duration = 320 / cls._bundle.sample_rate