mazesmazes committed on
Commit
1d0a54b
·
verified ·
1 Parent(s): 6fbb3b5

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +133 -22
asr_pipeline.py CHANGED
@@ -24,7 +24,10 @@ def _get_device() -> str:
24
 
25
 
26
  class ForcedAligner:
27
- """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
 
 
 
28
 
29
  _bundle = None
30
  _model = None
@@ -51,6 +54,107 @@ class ForcedAligner:
51
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
52
  return cls._model, cls._labels, cls._dictionary
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  @classmethod
55
  def align(
56
  cls,
@@ -59,21 +163,26 @@ class ForcedAligner:
59
  sample_rate: int = 16000,
60
  _language: str = "eng",
61
  _batch_size: int = 16,
 
62
  ) -> list[dict]:
63
  """Align transcript to audio and return word-level timestamps.
64
 
 
 
65
  Args:
66
  audio: Audio waveform as numpy array
67
  text: Transcript text to align
68
  sample_rate: Audio sample rate (default 16000)
69
  _language: ISO-639-3 language code (default "eng" for English, unused)
70
  _batch_size: Batch size for alignment model (unused)
 
 
 
71
 
72
  Returns:
73
  List of dicts with 'word', 'start', 'end' keys
74
  """
75
  import torchaudio
76
- from torchaudio.functional import forced_align, merge_tokens
77
 
78
  device = _get_device()
79
  model, labels, dictionary = cls.get_instance(device)
@@ -105,7 +214,8 @@ class ForcedAligner:
105
 
106
  # Normalize text: uppercase, keep only valid characters
107
  transcript = text.upper()
108
- # Build tokens from transcript
 
109
  tokens = []
110
  for char in transcript:
111
  if char in dictionary:
@@ -116,35 +226,34 @@ class ForcedAligner:
116
  if not tokens:
117
  return []
118
 
119
- targets = torch.tensor([tokens], dtype=torch.int32)
120
-
121
- # Run forced alignment
122
- # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
123
- # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
124
- aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
125
-
126
- # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
127
- token_spans = merge_tokens(aligned_tokens[0], scores[0])
128
 
129
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
130
  frame_duration = 320 / cls._bundle.sample_rate
131
 
132
- # Group token spans into words based on pipe separator
 
 
 
133
  words = text.split()
134
  word_timestamps = []
135
  current_word_start = None
136
  current_word_end = None
137
  word_idx = 0
 
138
 
139
- for span in token_spans:
140
- token_char = labels[span.token]
141
- if token_char == "|": # Word separator
142
  if current_word_start is not None and word_idx < len(words):
 
 
143
  word_timestamps.append(
144
  {
145
  "word": words[word_idx],
146
- "start": current_word_start * frame_duration,
147
- "end": current_word_end * frame_duration,
148
  }
149
  )
150
  word_idx += 1
@@ -152,16 +261,18 @@ class ForcedAligner:
152
  current_word_end = None
153
  else:
154
  if current_word_start is None:
155
- current_word_start = span.start
156
- current_word_end = span.end
157
 
158
  # Don't forget the last word
159
  if current_word_start is not None and word_idx < len(words):
 
 
160
  word_timestamps.append(
161
  {
162
  "word": words[word_idx],
163
- "start": current_word_start * frame_duration,
164
- "end": current_word_end * frame_duration,
165
  }
166
  )
167
 
 
24
 
25
 
26
  class ForcedAligner:
27
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
28
+
29
+ Uses Viterbi trellis algorithm for optimal alignment path finding.
30
+ """
31
 
32
  _bundle = None
33
  _model = None
 
54
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
55
  return cls._model, cls._labels, cls._dictionary
56
 
57
+ @staticmethod
58
+ def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
59
+ """Build Viterbi trellis for forced alignment.
60
+
61
+ The trellis is a 2D matrix where trellis[t, j] represents the log probability
62
+ of the most likely path that has emitted j tokens at time t.
63
+
64
+ Args:
65
+ emission: Log-softmax emission matrix of shape (num_frames, num_classes)
66
+ tokens: List of target token indices
67
+ blank_id: Index of the blank/CTC token (default 0)
68
+
69
+ Returns:
70
+ Trellis matrix of shape (num_frames + 1, num_tokens + 1)
71
+ """
72
+ num_frames = emission.size(0)
73
+ num_tokens = len(tokens)
74
+
75
+ # Initialize trellis with -inf (impossible paths)
76
+ trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
77
+ trellis[0, 0] = 0 # Start state has probability 1
78
+
79
+ for t in range(num_frames):
80
+ for j in range(num_tokens + 1):
81
+ # Stay in current state (emit blank)
82
+ if j < num_tokens + 1:
83
+ stay_prob = trellis[t, j] + emission[t, blank_id]
84
+ else:
85
+ stay_prob = -float("inf")
86
+
87
+ # Move to next state (emit token)
88
+ if j > 0:
89
+ move_prob = trellis[t, j - 1] + emission[t, tokens[j - 1]]
90
+ else:
91
+ move_prob = -float("inf")
92
+
93
+ trellis[t + 1, j] = max(stay_prob, move_prob)
94
+
95
+ return trellis
96
+
97
+ @staticmethod
98
+ def _backtrack(
99
+ trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
100
+ ) -> list[tuple[int, int, int]]:
101
+ """Backtrack through trellis to find optimal alignment path.
102
+
103
+ Args:
104
+ trellis: Trellis matrix from _get_trellis
105
+ emission: Log-softmax emission matrix
106
+ tokens: List of target token indices
107
+ blank_id: Index of the blank/CTC token
108
+
109
+ Returns:
110
+ List of (token_idx, start_frame, end_frame) tuples
111
+ """
112
+ num_frames = emission.size(0)
113
+ num_tokens = len(tokens)
114
+
115
+ # Start from the end
116
+ t = num_frames
117
+ j = num_tokens
118
+ path = []
119
+
120
+ # Backtrack to find where each token was emitted
121
+ while j > 0:
122
+ # Find the frame where token j-1 was first emitted
123
+ token_end = t
124
+
125
+ # Walk back while staying in state j (emitting blanks)
126
+ while t > 0:
127
+ stay_prob = trellis[t - 1, j] + emission[t - 1, blank_id]
128
+ if j > 0:
129
+ move_prob = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
130
+ else:
131
+ move_prob = -float("inf")
132
+
133
+ # Check if we moved into this state or stayed
134
+ if move_prob > stay_prob:
135
+ # We moved into state j at time t-1
136
+ token_start = t - 1
137
+ path.append((tokens[j - 1], token_start, token_end))
138
+ j -= 1
139
+ t -= 1
140
+ break
141
+ else:
142
+ # We stayed in state j
143
+ t -= 1
144
+
145
+ if t == 0 and j > 0:
146
+ # Handle edge case: remaining tokens at the start
147
+ path.append((tokens[j - 1], 0, token_end))
148
+ j -= 1
149
+
150
+ # Reverse to get chronological order
151
+ path.reverse()
152
+ return path
153
+
154
+ # Sub-frame offset to compensate for Wav2Vec2 convolutional look-ahead (in seconds)
155
+ # This makes timestamps feel more "natural" by shifting them earlier
156
+ OFFSET_COMPENSATION = 0.04 # 40ms
157
+
158
  @classmethod
159
  def align(
160
  cls,
 
163
  sample_rate: int = 16000,
164
  _language: str = "eng",
165
  _batch_size: int = 16,
166
+ offset_compensation: float | None = None,
167
  ) -> list[dict]:
168
  """Align transcript to audio and return word-level timestamps.
169
 
170
+ Uses Viterbi trellis algorithm for optimal forced alignment.
171
+
172
  Args:
173
  audio: Audio waveform as numpy array
174
  text: Transcript text to align
175
  sample_rate: Audio sample rate (default 16000)
176
  _language: ISO-639-3 language code (default "eng" for English, unused)
177
  _batch_size: Batch size for alignment model (unused)
178
+ offset_compensation: Time offset in seconds to subtract from timestamps
179
+ to compensate for Wav2Vec2 look-ahead (default: 0.04s / 40ms).
180
+ Set to 0 to disable.
181
 
182
  Returns:
183
  List of dicts with 'word', 'start', 'end' keys
184
  """
185
  import torchaudio
 
186
 
187
  device = _get_device()
188
  model, labels, dictionary = cls.get_instance(device)
 
214
 
215
  # Normalize text: uppercase, keep only valid characters
216
  transcript = text.upper()
217
+
218
+ # Build tokens from transcript (including word separators)
219
  tokens = []
220
  for char in transcript:
221
  if char in dictionary:
 
226
  if not tokens:
227
  return []
228
 
229
+ # Build Viterbi trellis and backtrack for optimal path
230
+ trellis = cls._get_trellis(emission, tokens, blank_id=0)
231
+ alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
 
 
 
 
 
 
232
 
233
  # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
234
  frame_duration = 320 / cls._bundle.sample_rate
235
 
236
+ # Apply offset compensation for Wav2Vec2 look-ahead
237
+ offset = offset_compensation if offset_compensation is not None else cls.OFFSET_COMPENSATION
238
+
239
+ # Group aligned tokens into words based on pipe separator
240
  words = text.split()
241
  word_timestamps = []
242
  current_word_start = None
243
  current_word_end = None
244
  word_idx = 0
245
+ separator_id = dictionary.get("|", dictionary.get(" ", 0))
246
 
247
+ for token_id, start_frame, end_frame in alignment_path:
248
+ if token_id == separator_id: # Word separator
 
249
  if current_word_start is not None and word_idx < len(words):
250
+ start_time = max(0.0, current_word_start * frame_duration - offset)
251
+ end_time = max(0.0, current_word_end * frame_duration - offset)
252
  word_timestamps.append(
253
  {
254
  "word": words[word_idx],
255
+ "start": start_time,
256
+ "end": end_time,
257
  }
258
  )
259
  word_idx += 1
 
261
  current_word_end = None
262
  else:
263
  if current_word_start is None:
264
+ current_word_start = start_frame
265
+ current_word_end = end_frame
266
 
267
  # Don't forget the last word
268
  if current_word_start is not None and word_idx < len(words):
269
+ start_time = max(0.0, current_word_start * frame_duration - offset)
270
+ end_time = max(0.0, current_word_end * frame_duration - offset)
271
  word_timestamps.append(
272
  {
273
  "word": words[word_idx],
274
+ "start": start_time,
275
+ "end": end_time,
276
  }
277
  )
278