mazesmazes committed on
Commit
1b9681a
·
verified ·
1 Parent(s): 1d0a54b

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +52 -53
asr_pipeline.py CHANGED
@@ -56,10 +56,10 @@ class ForcedAligner:
56
 
57
  @staticmethod
58
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
59
- """Build Viterbi trellis for forced alignment.
60
 
61
- The trellis is a 2D matrix where trellis[t, j] represents the log probability
62
- of the most likely path that has emitted j tokens at time t.
63
 
64
  Args:
65
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
@@ -72,25 +72,21 @@ class ForcedAligner:
72
  num_frames = emission.size(0)
73
  num_tokens = len(tokens)
74
 
75
- # Initialize trellis with -inf (impossible paths)
76
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
77
- trellis[0, 0] = 0 # Start state has probability 1
78
 
79
  for t in range(num_frames):
80
  for j in range(num_tokens + 1):
81
- # Stay in current state (emit blank)
82
- if j < num_tokens + 1:
83
- stay_prob = trellis[t, j] + emission[t, blank_id]
84
- else:
85
- stay_prob = -float("inf")
86
 
87
- # Move to next state (emit token)
88
  if j > 0:
89
- move_prob = trellis[t, j - 1] + emission[t, tokens[j - 1]]
90
  else:
91
- move_prob = -float("inf")
92
 
93
- trellis[t + 1, j] = max(stay_prob, move_prob)
94
 
95
  return trellis
96
 
@@ -100,60 +96,63 @@ class ForcedAligner:
100
  ) -> list[tuple[int, int, int]]:
101
  """Backtrack through trellis to find optimal alignment path.
102
 
103
- Args:
104
- trellis: Trellis matrix from _get_trellis
105
- emission: Log-softmax emission matrix
106
- tokens: List of target token indices
107
- blank_id: Index of the blank/CTC token
108
-
109
- Returns:
110
- List of (token_idx, start_frame, end_frame) tuples
111
  """
112
  num_frames = emission.size(0)
113
  num_tokens = len(tokens)
114
 
115
- # Start from the end
116
  t = num_frames
117
  j = num_tokens
118
- path = []
119
 
120
- # Backtrack to find where each token was emitted
121
- while j > 0:
122
- # Find the frame where token j-1 was first emitted
123
- token_end = t
124
 
125
- # Walk back while staying in state j (emitting blanks)
126
- while t > 0:
127
- stay_prob = trellis[t - 1, j] + emission[t - 1, blank_id]
128
- if j > 0:
129
- move_prob = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
130
- else:
131
- move_prob = -float("inf")
132
-
133
- # Check if we moved into this state or stayed
134
- if move_prob > stay_prob:
135
- # We moved into state j at time t-1
136
- token_start = t - 1
137
- path.append((tokens[j - 1], token_start, token_end))
138
- j -= 1
139
- t -= 1
140
- break
141
- else:
142
- # We stayed in state j
143
- t -= 1
144
 
145
- if t == 0 and j > 0:
146
- # Handle edge case: remaining tokens at the start
147
- path.append((tokens[j - 1], 0, token_end))
 
 
 
 
148
  j -= 1
149
 
150
- # Reverse to get chronological order
 
151
  path.reverse()
152
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  # Sub-frame offset to compensate for Wav2Vec2 convolutional look-ahead (in seconds)
155
  # This makes timestamps feel more "natural" by shifting them earlier
156
- OFFSET_COMPENSATION = 0.04 # 40ms
157
 
158
  @classmethod
159
  def align(
 
56
 
57
  @staticmethod
58
  def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
59
+ """Build trellis for forced alignment using forward algorithm.
60
 
61
+ The trellis[t, j] represents the log probability of the best path that
62
+ aligns the first j tokens to the first t frames.
63
 
64
  Args:
65
  emission: Log-softmax emission matrix of shape (num_frames, num_classes)
 
72
  num_frames = emission.size(0)
73
  num_tokens = len(tokens)
74
 
 
75
  trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
76
+ trellis[0, 0] = 0
77
 
78
  for t in range(num_frames):
79
  for j in range(num_tokens + 1):
80
+ # Stay: emit blank and stay at j tokens
81
+ stay = trellis[t, j] + emission[t, blank_id]
 
 
 
82
 
83
+ # Move: emit token j and advance to j+1 tokens
84
  if j > 0:
85
+ move = trellis[t, j - 1] + emission[t, tokens[j - 1]]
86
  else:
87
+ move = torch.tensor(-float("inf"))
88
 
89
+ trellis[t + 1, j] = torch.logaddexp(torch.tensor(stay), move).item()
90
 
91
  return trellis
92
 
 
96
  ) -> list[tuple[int, int, int]]:
97
  """Backtrack through trellis to find optimal alignment path.
98
 
99
+ Returns list of (token_id, start_frame, end_frame) for each token.
 
 
 
 
 
 
 
100
  """
101
  num_frames = emission.size(0)
102
  num_tokens = len(tokens)
103
 
104
+ # Trace back from final state
105
  t = num_frames
106
  j = num_tokens
107
+ path = [] # Will store (frame, token_index) pairs
108
 
109
+ while t > 0 and j >= 0:
110
+ # At position (t, j), we need to determine if we got here by:
111
+ # 1. Staying at j (emitting blank at frame t-1)
112
+ # 2. Moving from j-1 to j (emitting token j-1 at frame t-1)
113
 
114
+ if j == 0:
115
+ # Can only stay (no previous token state to come from)
116
+ t -= 1
117
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # Compare which transition was more likely
120
+ stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
121
+ move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
122
+
123
+ if move_score > stay_score:
124
+ # Token j-1 was emitted at frame t-1
125
+ path.append((t - 1, j - 1))
126
  j -= 1
127
 
128
+ t -= 1
129
+
130
  path.reverse()
131
+
132
+ # Convert path to token spans with start/end frames
133
+ if not path:
134
+ return []
135
+
136
+ token_spans = []
137
+ i = 0
138
+ while i < len(path):
139
+ frame, token_idx = path[i]
140
+ start_frame = frame
141
+
142
+ # Find end frame (where this token stops being emitted)
143
+ end_frame = frame + 1
144
+ while i + 1 < len(path) and path[i + 1][1] == token_idx:
145
+ i += 1
146
+ end_frame = path[i][0] + 1
147
+
148
+ token_spans.append((tokens[token_idx], start_frame, end_frame))
149
+ i += 1
150
+
151
+ return token_spans
152
 
153
# Sub-frame offset to compensate for Wav2Vec2 convolutional look-ahead (in seconds)
# This makes timestamps feel more "natural" by shifting them earlier
OFFSET_COMPENSATION = 0.02  # 20ms — comment kept in sync with the 0.04 -> 0.02 change
156
 
157
  @classmethod
158
  def align(