mazesmazes committed on
Commit
e709a58
·
verified ·
1 Parent(s): acc869c

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. alignment.py +39 -46
alignment.py CHANGED
@@ -70,6 +70,11 @@ class ForcedAligner:
70
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
71
  trellis[0, 0] = 0
72
 
 
 
 
 
 
73
  for t in range(num_frames):
74
  for j in range(num_tokens + 1):
75
  # Stay: emit blank and stay at j tokens
@@ -85,19 +90,16 @@ class ForcedAligner:
85
  @staticmethod
86
  def _backtrack(
87
  trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
88
- ) -> list[tuple[int, float, float]]:
89
  """Backtrack through trellis to find optimal forced monotonic alignment.
90
 
91
- Uses emission probability weighting for sub-frame precision. Since wav2vec2
92
- has 20ms frame resolution, weighting by emission scores can improve accuracy
93
- by estimating where within a frame the token boundary likely falls.
94
-
95
  Guarantees:
96
  - All tokens are emitted exactly once
97
  - Strictly monotonic: each token's frames come after previous token's
98
  - No frame skipping or token teleporting
99
 
100
- Returns list of (token_id, start_frame, end_frame) for each token.
 
101
  """
102
  num_frames = emission.size(0)
103
  num_tokens = len(tokens)
@@ -111,12 +113,12 @@ class ForcedAligner:
111
  # Alignment failed - fall back to uniform distribution
112
  frames_per_token = num_frames / num_tokens
113
  return [
114
- (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
115
  for i in range(num_tokens)
116
  ]
117
 
118
  # Backtrack: find where each token transition occurred
119
- # path[i] = list of (frame, score) tuples where token i was emitted
120
  token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
121
 
122
  t = num_frames
@@ -129,9 +131,9 @@ class ForcedAligner:
129
 
130
  if move_score >= stay_score:
131
  # Token j-1 was emitted at frame t-1
132
- # Store frame index and emission probability for weighting
133
- prob = emission[t - 1, tokens[j - 1]].exp().item()
134
- token_frames[j - 1].insert(0, (t - 1, prob))
135
  j -= 1
136
  # Always decrement time (monotonic)
137
  t -= 1
@@ -141,8 +143,8 @@ class ForcedAligner:
141
  token_frames[j - 1].insert(0, (0, 0.0))
142
  j -= 1
143
 
144
- # Convert to spans with emission-weighted sub-frame precision
145
- token_spans: list[tuple[int, float, float]] = []
146
  for token_idx, frames_with_scores in enumerate(token_frames):
147
  if not frames_with_scores:
148
  # Token never emitted - assign minimal span after previous
@@ -154,24 +156,13 @@ class ForcedAligner:
154
 
155
  token_id = tokens[token_idx]
156
  frames = [f for f, _ in frames_with_scores]
157
- scores = [s for _, s in frames_with_scores]
158
-
159
- # Compute emission-weighted start position for sub-frame precision
160
- # Weight shifts the position toward frames with higher emission probability
161
- total_score = sum(scores)
162
- if total_score > 0 and len(frames) > 1:
163
- # Weighted centroid gives sub-frame precision
164
- weighted_center = sum(f * s for f, s in zip(frames, scores)) / total_score
165
- # Estimate start/end based on weighted center and span width
166
- span_width = max(frames) - min(frames) + 1
167
- start_frame = weighted_center - span_width / 2
168
- end_frame = weighted_center + span_width / 2
169
- else:
170
- # Fall back to simple min/max
171
- start_frame = float(min(frames))
172
- end_frame = float(max(frames)) + 1.0
173
 
174
- token_spans.append((token_id, start_frame, end_frame))
175
 
176
  return token_spans
177
 
@@ -255,22 +246,24 @@ class ForcedAligner:
255
  end_offset = END_OFFSET
256
 
257
  # Group aligned tokens into words based on pipe separator
 
258
  words = text.split()
259
  word_timestamps = []
260
- current_word_start = None
261
- current_word_end = None
262
  word_idx = 0
263
  separator_id = dictionary.get("|", dictionary.get(" ", 0))
264
 
265
- for token_id, start_frame, end_frame in alignment_path:
266
  if token_id == separator_id: # Word separator
267
  if (
268
- current_word_start is not None
269
- and current_word_end is not None
270
  and word_idx < len(words)
271
  ):
272
- start_time = max(0.0, current_word_start * frame_duration - start_offset)
273
- end_time = max(0.0, current_word_end * frame_duration - end_offset)
 
274
  word_timestamps.append(
275
  {
276
  "word": words[word_idx],
@@ -279,21 +272,21 @@ class ForcedAligner:
279
  }
280
  )
281
  word_idx += 1
282
- current_word_start = None
283
- current_word_end = None
284
  else:
285
- if current_word_start is None:
286
- current_word_start = start_frame
287
- current_word_end = end_frame
288
 
289
  # Don't forget the last word
290
  if (
291
- current_word_start is not None
292
- and current_word_end is not None
293
  and word_idx < len(words)
294
  ):
295
- start_time = max(0.0, current_word_start * frame_duration - start_offset)
296
- end_time = max(0.0, current_word_end * frame_duration - end_offset)
297
  word_timestamps.append(
298
  {
299
  "word": words[word_idx],
 
70
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
71
  trellis[0, 0] = 0
72
 
73
+ # Force alignment to use all tokens by preventing staying in blank
74
+ # at the end when there are still tokens to emit
75
+ if num_tokens > 1:
76
+ trellis[-num_tokens + 1:, 0] = float("inf")
77
+
78
  for t in range(num_frames):
79
  for j in range(num_tokens + 1):
80
  # Stay: emit blank and stay at j tokens
 
90
  @staticmethod
91
  def _backtrack(
92
  trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
93
+ ) -> list[tuple[int, float, float, float]]:
94
  """Backtrack through trellis to find optimal forced monotonic alignment.
95
 
 
 
 
 
96
  Guarantees:
97
  - All tokens are emitted exactly once
98
  - Strictly monotonic: each token's frames come after previous token's
99
  - No frame skipping or token teleporting
100
 
101
+ Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
102
+ The peak_frame is the frame with highest emission probability for that token.
103
  """
104
  num_frames = emission.size(0)
105
  num_tokens = len(tokens)
 
113
  # Alignment failed - fall back to uniform distribution
114
  frames_per_token = num_frames / num_tokens
115
  return [
116
+ (tokens[i], i * frames_per_token, (i + 1) * frames_per_token, (i + 0.5) * frames_per_token)
117
  for i in range(num_tokens)
118
  ]
119
 
120
  # Backtrack: find where each token transition occurred
121
+ # Store (frame, emission_score) for each token
122
  token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
123
 
124
  t = num_frames
 
131
 
132
  if move_score >= stay_score:
133
  # Token j-1 was emitted at frame t-1
134
+ # Store frame and its emission probability
135
+ emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
136
+ token_frames[j - 1].insert(0, (t - 1, emit_prob))
137
  j -= 1
138
  # Always decrement time (monotonic)
139
  t -= 1
 
143
  token_frames[j - 1].insert(0, (0, 0.0))
144
  j -= 1
145
 
146
+ # Convert to spans with peak frame
147
+ token_spans: list[tuple[int, float, float, float]] = []
148
  for token_idx, frames_with_scores in enumerate(token_frames):
149
  if not frames_with_scores:
150
  # Token never emitted - assign minimal span after previous
 
156
 
157
  token_id = tokens[token_idx]
158
  frames = [f for f, _ in frames_with_scores]
159
+ start_frame = float(min(frames))
160
+ end_frame = float(max(frames)) + 1.0
161
+
162
+ # Find peak frame (highest emission probability)
163
+ peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))
166
 
167
  return token_spans
168
 
 
246
  end_offset = END_OFFSET
247
 
248
  # Group aligned tokens into words based on pipe separator
249
+ # Use peak emission frame for more accurate word boundaries
250
  words = text.split()
251
  word_timestamps = []
252
+ first_char_peak = None
253
+ last_char_peak = None
254
  word_idx = 0
255
  separator_id = dictionary.get("|", dictionary.get(" ", 0))
256
 
257
+ for token_id, start_frame, end_frame, peak_frame in alignment_path:
258
  if token_id == separator_id: # Word separator
259
  if (
260
+ first_char_peak is not None
261
+ and last_char_peak is not None
262
  and word_idx < len(words)
263
  ):
264
+ # Use peak frames for word boundaries
265
+ start_time = max(0.0, first_char_peak * frame_duration - start_offset)
266
+ end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
267
  word_timestamps.append(
268
  {
269
  "word": words[word_idx],
 
272
  }
273
  )
274
  word_idx += 1
275
+ first_char_peak = None
276
+ last_char_peak = None
277
  else:
278
+ if first_char_peak is None:
279
+ first_char_peak = peak_frame
280
+ last_char_peak = peak_frame
281
 
282
  # Don't forget the last word
283
  if (
284
+ first_char_peak is not None
285
+ and last_char_peak is not None
286
  and word_idx < len(words)
287
  ):
288
+ start_time = max(0.0, first_char_peak * frame_duration - start_offset)
289
+ end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
290
  word_timestamps.append(
291
  {
292
  "word": words[word_idx],