Update custom model files, README, and requirements
Browse files- alignment.py +39 -46
alignment.py
CHANGED
|
@@ -70,6 +70,11 @@ class ForcedAligner:
|
|
| 70 |
trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
|
| 71 |
trellis[0, 0] = 0
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
for t in range(num_frames):
|
| 74 |
for j in range(num_tokens + 1):
|
| 75 |
# Stay: emit blank and stay at j tokens
|
|
@@ -85,19 +90,16 @@ class ForcedAligner:
|
|
| 85 |
@staticmethod
|
| 86 |
def _backtrack(
|
| 87 |
trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
|
| 88 |
-
) -> list[tuple[int, float, float]]:
|
| 89 |
"""Backtrack through trellis to find optimal forced monotonic alignment.
|
| 90 |
|
| 91 |
-
Uses emission probability weighting for sub-frame precision. Since wav2vec2
|
| 92 |
-
has 20ms frame resolution, weighting by emission scores can improve accuracy
|
| 93 |
-
by estimating where within a frame the token boundary likely falls.
|
| 94 |
-
|
| 95 |
Guarantees:
|
| 96 |
- All tokens are emitted exactly once
|
| 97 |
- Strictly monotonic: each token's frames come after previous token's
|
| 98 |
- No frame skipping or token teleporting
|
| 99 |
|
| 100 |
-
Returns list of (token_id, start_frame, end_frame) for each token.
|
|
|
|
| 101 |
"""
|
| 102 |
num_frames = emission.size(0)
|
| 103 |
num_tokens = len(tokens)
|
|
@@ -111,12 +113,12 @@ class ForcedAligner:
|
|
| 111 |
# Alignment failed - fall back to uniform distribution
|
| 112 |
frames_per_token = num_frames / num_tokens
|
| 113 |
return [
|
| 114 |
-
(tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
|
| 115 |
for i in range(num_tokens)
|
| 116 |
]
|
| 117 |
|
| 118 |
# Backtrack: find where each token transition occurred
|
| 119 |
-
#
|
| 120 |
token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
|
| 121 |
|
| 122 |
t = num_frames
|
|
@@ -129,9 +131,9 @@ class ForcedAligner:
|
|
| 129 |
|
| 130 |
if move_score >= stay_score:
|
| 131 |
# Token j-1 was emitted at frame t-1
|
| 132 |
-
# Store frame
|
| 133 |
-
|
| 134 |
-
token_frames[j - 1].insert(0, (t - 1,
|
| 135 |
j -= 1
|
| 136 |
# Always decrement time (monotonic)
|
| 137 |
t -= 1
|
|
@@ -141,8 +143,8 @@ class ForcedAligner:
|
|
| 141 |
token_frames[j - 1].insert(0, (0, 0.0))
|
| 142 |
j -= 1
|
| 143 |
|
| 144 |
-
# Convert to spans with
|
| 145 |
-
token_spans: list[tuple[int, float, float]] = []
|
| 146 |
for token_idx, frames_with_scores in enumerate(token_frames):
|
| 147 |
if not frames_with_scores:
|
| 148 |
# Token never emitted - assign minimal span after previous
|
|
@@ -154,24 +156,13 @@ class ForcedAligner:
|
|
| 154 |
|
| 155 |
token_id = tokens[token_idx]
|
| 156 |
frames = [f for f, _ in frames_with_scores]
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
#
|
| 161 |
-
|
| 162 |
-
if total_score > 0 and len(frames) > 1:
|
| 163 |
-
# Weighted centroid gives sub-frame precision
|
| 164 |
-
weighted_center = sum(f * s for f, s in zip(frames, scores)) / total_score
|
| 165 |
-
# Estimate start/end based on weighted center and span width
|
| 166 |
-
span_width = max(frames) - min(frames) + 1
|
| 167 |
-
start_frame = weighted_center - span_width / 2
|
| 168 |
-
end_frame = weighted_center + span_width / 2
|
| 169 |
-
else:
|
| 170 |
-
# Fall back to simple min/max
|
| 171 |
-
start_frame = float(min(frames))
|
| 172 |
-
end_frame = float(max(frames)) + 1.0
|
| 173 |
|
| 174 |
-
token_spans.append((token_id, start_frame, end_frame))
|
| 175 |
|
| 176 |
return token_spans
|
| 177 |
|
|
@@ -255,22 +246,24 @@ class ForcedAligner:
|
|
| 255 |
end_offset = END_OFFSET
|
| 256 |
|
| 257 |
# Group aligned tokens into words based on pipe separator
|
|
|
|
| 258 |
words = text.split()
|
| 259 |
word_timestamps = []
|
| 260 |
-
|
| 261 |
-
|
| 262 |
word_idx = 0
|
| 263 |
separator_id = dictionary.get("|", dictionary.get(" ", 0))
|
| 264 |
|
| 265 |
-
for token_id, start_frame, end_frame in alignment_path:
|
| 266 |
if token_id == separator_id: # Word separator
|
| 267 |
if (
|
| 268 |
-
|
| 269 |
-
and
|
| 270 |
and word_idx < len(words)
|
| 271 |
):
|
| 272 |
-
|
| 273 |
-
|
|
|
|
| 274 |
word_timestamps.append(
|
| 275 |
{
|
| 276 |
"word": words[word_idx],
|
|
@@ -279,21 +272,21 @@ class ForcedAligner:
|
|
| 279 |
}
|
| 280 |
)
|
| 281 |
word_idx += 1
|
| 282 |
-
|
| 283 |
-
|
| 284 |
else:
|
| 285 |
-
if
|
| 286 |
-
|
| 287 |
-
|
| 288 |
|
| 289 |
# Don't forget the last word
|
| 290 |
if (
|
| 291 |
-
|
| 292 |
-
and
|
| 293 |
and word_idx < len(words)
|
| 294 |
):
|
| 295 |
-
start_time = max(0.0,
|
| 296 |
-
end_time = max(0.0,
|
| 297 |
word_timestamps.append(
|
| 298 |
{
|
| 299 |
"word": words[word_idx],
|
|
|
|
| 70 |
trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
|
| 71 |
trellis[0, 0] = 0
|
| 72 |
|
| 73 |
+
# Force alignment to use all tokens by preventing staying in blank
|
| 74 |
+
# at the end when there are still tokens to emit
|
| 75 |
+
if num_tokens > 1:
|
| 76 |
+
trellis[-num_tokens + 1:, 0] = float("inf")
|
| 77 |
+
|
| 78 |
for t in range(num_frames):
|
| 79 |
for j in range(num_tokens + 1):
|
| 80 |
# Stay: emit blank and stay at j tokens
|
|
|
|
| 90 |
@staticmethod
|
| 91 |
def _backtrack(
|
| 92 |
trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
|
| 93 |
+
) -> list[tuple[int, float, float, float]]:
|
| 94 |
"""Backtrack through trellis to find optimal forced monotonic alignment.
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
Guarantees:
|
| 97 |
- All tokens are emitted exactly once
|
| 98 |
- Strictly monotonic: each token's frames come after previous token's
|
| 99 |
- No frame skipping or token teleporting
|
| 100 |
|
| 101 |
+
Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
|
| 102 |
+
The peak_frame is the frame with highest emission probability for that token.
|
| 103 |
"""
|
| 104 |
num_frames = emission.size(0)
|
| 105 |
num_tokens = len(tokens)
|
|
|
|
| 113 |
# Alignment failed - fall back to uniform distribution
|
| 114 |
frames_per_token = num_frames / num_tokens
|
| 115 |
return [
|
| 116 |
+
(tokens[i], i * frames_per_token, (i + 1) * frames_per_token, (i + 0.5) * frames_per_token)
|
| 117 |
for i in range(num_tokens)
|
| 118 |
]
|
| 119 |
|
| 120 |
# Backtrack: find where each token transition occurred
|
| 121 |
+
# Store (frame, emission_score) for each token
|
| 122 |
token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
|
| 123 |
|
| 124 |
t = num_frames
|
|
|
|
| 131 |
|
| 132 |
if move_score >= stay_score:
|
| 133 |
# Token j-1 was emitted at frame t-1
|
| 134 |
+
# Store frame and its emission probability
|
| 135 |
+
emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
|
| 136 |
+
token_frames[j - 1].insert(0, (t - 1, emit_prob))
|
| 137 |
j -= 1
|
| 138 |
# Always decrement time (monotonic)
|
| 139 |
t -= 1
|
|
|
|
| 143 |
token_frames[j - 1].insert(0, (0, 0.0))
|
| 144 |
j -= 1
|
| 145 |
|
| 146 |
+
# Convert to spans with peak frame
|
| 147 |
+
token_spans: list[tuple[int, float, float, float]] = []
|
| 148 |
for token_idx, frames_with_scores in enumerate(token_frames):
|
| 149 |
if not frames_with_scores:
|
| 150 |
# Token never emitted - assign minimal span after previous
|
|
|
|
| 156 |
|
| 157 |
token_id = tokens[token_idx]
|
| 158 |
frames = [f for f, _ in frames_with_scores]
|
| 159 |
+
start_frame = float(min(frames))
|
| 160 |
+
end_frame = float(max(frames)) + 1.0
|
| 161 |
+
|
| 162 |
+
# Find peak frame (highest emission probability)
|
| 163 |
+
peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))
|
| 166 |
|
| 167 |
return token_spans
|
| 168 |
|
|
|
|
| 246 |
end_offset = END_OFFSET
|
| 247 |
|
| 248 |
# Group aligned tokens into words based on pipe separator
|
| 249 |
+
# Use peak emission frame for more accurate word boundaries
|
| 250 |
words = text.split()
|
| 251 |
word_timestamps = []
|
| 252 |
+
first_char_peak = None
|
| 253 |
+
last_char_peak = None
|
| 254 |
word_idx = 0
|
| 255 |
separator_id = dictionary.get("|", dictionary.get(" ", 0))
|
| 256 |
|
| 257 |
+
for token_id, start_frame, end_frame, peak_frame in alignment_path:
|
| 258 |
if token_id == separator_id: # Word separator
|
| 259 |
if (
|
| 260 |
+
first_char_peak is not None
|
| 261 |
+
and last_char_peak is not None
|
| 262 |
and word_idx < len(words)
|
| 263 |
):
|
| 264 |
+
# Use peak frames for word boundaries
|
| 265 |
+
start_time = max(0.0, first_char_peak * frame_duration - start_offset)
|
| 266 |
+
end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
|
| 267 |
word_timestamps.append(
|
| 268 |
{
|
| 269 |
"word": words[word_idx],
|
|
|
|
| 272 |
}
|
| 273 |
)
|
| 274 |
word_idx += 1
|
| 275 |
+
first_char_peak = None
|
| 276 |
+
last_char_peak = None
|
| 277 |
else:
|
| 278 |
+
if first_char_peak is None:
|
| 279 |
+
first_char_peak = peak_frame
|
| 280 |
+
last_char_peak = peak_frame
|
| 281 |
|
| 282 |
# Don't forget the last word
|
| 283 |
if (
|
| 284 |
+
first_char_peak is not None
|
| 285 |
+
and last_char_peak is not None
|
| 286 |
and word_idx < len(words)
|
| 287 |
):
|
| 288 |
+
start_time = max(0.0, first_char_peak * frame_duration - start_offset)
|
| 289 |
+
end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
|
| 290 |
word_timestamps.append(
|
| 291 |
{
|
| 292 |
"word": words[word_idx],
|