Model save
Browse files- .gitattributes +1 -0
- README.md +54 -0
- alignment.py +299 -0
- asr_config.py +212 -0
- asr_modeling.py +1110 -0
- asr_pipeline.py +569 -0
- asr_processing.py +133 -0
- audio_head.py +396 -0
- chat_template.jinja +94 -0
- diarization.py +759 -0
- modules/__init__.py +5 -0
- modules/mlp.py +197 -0
- preprocessor_config.json +19 -0
- projectors.py +505 -0
- tokenizer.json +3 -0
- tokenizer_config.json +19 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- generated_from_trainer
|
| 5 |
+
model-index:
|
| 6 |
+
- name: test_s2s_output
|
| 7 |
+
results: []
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
| 11 |
+
should probably proofread and complete it, then remove this comment. -->
|
| 12 |
+
|
| 13 |
+
# test_s2s_output
|
| 14 |
+
|
| 15 |
+
This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
|
| 16 |
+
|
| 17 |
+
## Model description
|
| 18 |
+
|
| 19 |
+
More information needed
|
| 20 |
+
|
| 21 |
+
## Intended uses & limitations
|
| 22 |
+
|
| 23 |
+
More information needed
|
| 24 |
+
|
| 25 |
+
## Training and evaluation data
|
| 26 |
+
|
| 27 |
+
More information needed
|
| 28 |
+
|
| 29 |
+
## Training procedure
|
| 30 |
+
|
| 31 |
+
### Training hyperparameters
|
| 32 |
+
|
| 33 |
+
The following hyperparameters were used during training:
|
| 34 |
+
- learning_rate: 0.0001
|
| 35 |
+
- train_batch_size: 16
|
| 36 |
+
- eval_batch_size: 16
|
| 37 |
+
- seed: 42
|
| 38 |
+
- gradient_accumulation_steps: 2
|
| 39 |
+
- total_train_batch_size: 32
|
| 40 |
+
- optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
|
| 41 |
+
- lr_scheduler_type: polynomial
|
| 42 |
+
- lr_scheduler_warmup_steps: 500
|
| 43 |
+
- training_steps: 5
|
| 44 |
+
|
| 45 |
+
### Training results
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
### Framework versions
|
| 50 |
+
|
| 51 |
+
- Transformers 5.0.0
|
| 52 |
+
- Pytorch 2.8.0
|
| 53 |
+
- Datasets 3.6.0
|
| 54 |
+
- Tokenizers 0.22.2
|
alignment.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Forced alignment for word-level timestamps using Wav2Vec2."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
# Calibrated on librispeech-alignments dataset (n=25, MAE=48ms)
START_OFFSET = 0.04  # Subtracted from start times (shifts word starts earlier)
END_OFFSET = -0.04  # Subtracted from end times; negative value shifts word ends later
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _get_device() -> str:
|
| 13 |
+
"""Get best available device for non-transformers models."""
|
| 14 |
+
if torch.cuda.is_available():
|
| 15 |
+
return "cuda"
|
| 16 |
+
if torch.backends.mps.is_available():
|
| 17 |
+
return "mps"
|
| 18 |
+
return "cpu"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses Viterbi trellis algorithm for optimal alignment path finding.

    All state is class-level (singleton): the wav2vec2 bundle, model, label set
    and character->index dictionary are created once by ``get_instance`` and
    shared by every subsequent call.
    """

    # Singleton state, populated lazily by get_instance().
    _bundle = None  # torchaudio pipeline bundle (WAV2VEC2_ASR_BASE_960H)
    _model = None  # the wav2vec2 acoustic model, in eval mode
    _labels = None  # CTC label set from the bundle
    _dictionary = None  # mapping: character label -> token index

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        NOTE: ``device`` is only honored on the first call; later calls return
        the cached model wherever it was originally placed.

        Args:
            device: Device to run model on ("cuda" or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            # Imported lazily so torchaudio is only required when alignment is used.
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build trellis for forced alignment using forward algorithm.

        The trellis[t, j] represents the log probability of the best path that
        aligns the first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        # All states start unreachable except the origin (0 frames, 0 tokens).
        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        # Force alignment to use all tokens by preventing staying in blank
        # at the end when there are still tokens to emit
        # NOTE(review): the forward loop below assigns trellis[t + 1, j] for every
        # row >= 1 and every column, which overwrites these preset +inf entries,
        # so this constraint appears to have no effect in practice — TODO confirm
        # intent (and whether -inf rather than +inf was meant).
        if num_tokens > 1:
            trellis[-num_tokens + 1 :, 0] = float("inf")

        for t in range(num_frames):
            for j in range(num_tokens + 1):
                # Stay: emit blank and stay at j tokens
                stay = trellis[t, j] + emission[t, blank_id]

                # Move: emit token j and advance to j+1 tokens
                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")

                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

        return trellis

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float, float]]:
        """Backtrack through trellis to find optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after previous token's
        - No frame skipping or token teleporting

        Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
        The peak_frame is the frame with highest emission probability for that token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # Find the best ending point (should be at num_tokens)
        # But verify trellis reached a valid state
        if trellis[num_frames, num_tokens] == -float("inf"):
            # Alignment failed - fall back to uniform distribution
            frames_per_token = num_frames / num_tokens
            return [
                (
                    tokens[i],
                    i * frames_per_token,
                    (i + 1) * frames_per_token,
                    (i + 0.5) * frames_per_token,
                )
                for i in range(num_tokens)
            ]

        # Backtrack: find where each token transition occurred
        # Store (frame, emission_score) for each token
        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Check: did we transition from j-1 to j at frame t-1?
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Token j-1 was emitted at frame t-1
                # Store frame and its emission probability
                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
                token_frames[j - 1].insert(0, (t - 1, emit_prob))
                j -= 1
            # Always decrement time (monotonic)
            t -= 1

        # Handle any remaining tokens at the start (edge case):
        # pin them all to frame 0 with zero confidence.
        while j > 0:
            token_frames[j - 1].insert(0, (0, 0.0))
            j -= 1

        # Convert to spans with peak frame
        token_spans: list[tuple[int, float, float, float]] = []
        for token_idx, frames_with_scores in enumerate(token_frames):
            if not frames_with_scores:
                # Token never emitted - assign minimal span after previous
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames_with_scores = [(int(prev_end), 0.0)]
                else:
                    frames_with_scores = [(0, 0.0)]

            token_id = tokens[token_idx]
            frames = [f for f, _ in frames_with_scores]
            start_frame = float(min(frames))
            # end is exclusive: one past the last frame the token occupied.
            end_frame = float(max(frames)) + 1.0

            # Find peak frame (highest emission probability)
            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])

            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))

        return token_spans

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        _language: str = "eng",
        _batch_size: int = 16,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Uses Viterbi trellis algorithm for optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance

        # Convert audio to tensor (copy to ensure array is writable)
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        # Ensure 2D (channels, time)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Resample if needed (wav2vec2 expects 16kHz)
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Get emissions from model
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # Normalize text: uppercase, keep only valid characters
        transcript = text.upper()

        # Build tokens from transcript (including word separators).
        # Characters not in the dictionary and not spaces are silently dropped.
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        # Build Viterbi trellis and backtrack for optimal path
        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
        frame_duration = 320 / cls._bundle.sample_rate

        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
        start_offset = START_OFFSET
        end_offset = END_OFFSET

        # Group aligned tokens into words based on pipe separator
        # Use peak emission frame for more accurate word boundaries
        words = text.split()
        word_timestamps = []
        first_char_peak = None  # peak frame of the current word's first character
        last_char_peak = None  # peak frame of the current word's latest character
        word_idx = 0
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
            if token_id == separator_id:  # Word separator
                if (
                    first_char_peak is not None
                    and last_char_peak is not None
                    and word_idx < len(words)
                ):
                    # Use peak frames for word boundaries
                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                    first_char_peak = None
                    last_char_peak = None
            else:
                if first_char_peak is None:
                    first_char_peak = peak_frame
                last_char_peak = peak_frame

        # Don't forget the last word
        if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
|
asr_config.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import transformers
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    Composes an audio-encoder config (``audio_config``) and a text-LLM config
    (``text_config``), plus projector / training / generation settings.  Also
    wires up ``auto_map`` and a custom ASR pipeline so the saved config is
    directly loadable with ``trust_remote_code=True``.
    """

    model_type = "asr_model"
    is_composition = True

    # Generation defaults, merged into kwargs in __init__ (explicit kwargs win).
    GENERATION_DEFAULTS = {
        "num_beams": 1,
        "max_new_tokens": 128,
        "min_new_tokens": 0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "use_cache": True,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }

    def __init__(
        self,
        # Model IDs
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        # Model settings
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        system_prompt: str = "You are a helpful assistant.",
        enable_thinking: bool = False,
        # Encoder settings (auto-detected if None)
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        # Projector settings
        projector_type: str = "mlp",
        projector_pool_stride: int = 4,
        projector_hidden_dim: Optional[int] = None,
        projector_num_layers: int = 2,
        projector_init_std: float = 0.02,
        projector_dropout: float = 0.0,
        # MoE projector settings
        num_experts: int = 4,
        num_experts_per_tok: int = 2,
        router_aux_loss_coef: float = 0.01,
        # QFormer projector settings
        qformer_window_size: int = 15,
        qformer_hidden_size: Optional[int] = None,
        qformer_num_layers: int = 2,
        qformer_num_heads: int = 16,
        qformer_intermediate_size: Optional[int] = None,
        downsample_rate: int = 5,
        # Training settings (not saved to config.json for inference)
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,
        freeze_projector: bool = False,
        label_smoothing: float = 0.0,
        # Audio Head settings (flow matching with pocket-tts)
        use_audio_head: bool = False,
        freeze_audio_head: bool = False,  # Freeze entire audio head
        lsd_decode_steps: int = 1,  # LSD decoding integration steps
        flow_temperature: float = 1.0,  # Sampling temperature for flow generation
        pocket_tts_weights: Optional[str] = None,  # Path to pretrained pocket-tts weights
        freeze_flow_net: bool = True,  # Freeze flow_net, only train llm_proj
        **kwargs,
    ) -> None:
        # Merge generation defaults with kwargs (kwargs takes precedence)
        for key, default in self.GENERATION_DEFAULTS.items():
            if key not in kwargs:
                kwargs[key] = default

        # Core model settings
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.enable_thinking = enable_thinking

        # Encoder settings
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate

        # Projector settings
        self.projector_type = projector_type
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_num_layers = projector_num_layers
        self.projector_init_std = projector_init_std
        self.projector_dropout = projector_dropout

        # MoE settings
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer settings
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.downsample_rate = downsample_rate

        # Training settings
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        self.freeze_projector = freeze_projector
        self.label_smoothing = label_smoothing

        # Audio Head settings (flow matching with pocket-tts)
        self.use_audio_head = use_audio_head
        self.freeze_audio_head = freeze_audio_head
        self.lsd_decode_steps = lsd_decode_steps
        self.flow_temperature = flow_temperature
        self.pocket_tts_weights = pocket_tts_weights
        self.freeze_flow_net = freeze_flow_net

        # Generation parameters (from kwargs after merge with defaults).
        # pop() is safe here: the merge above guarantees every key is present.
        self.num_beams = kwargs.pop("num_beams")
        self.max_new_tokens = kwargs.pop("max_new_tokens")
        self.min_new_tokens = kwargs.pop("min_new_tokens")
        self.repetition_penalty = kwargs.pop("repetition_penalty")
        self.length_penalty = kwargs.pop("length_penalty")
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
        self.use_cache = kwargs.pop("use_cache")
        self.do_sample = kwargs.pop("do_sample")
        self.temperature = kwargs.pop("temperature")
        self.top_p = kwargs.pop("top_p")
        self.top_k = kwargs.pop("top_k")

        # Load sub-configs: either fetch from the hub (fresh config) or rebuild
        # from the serialized dict (round-trip through config.json).
        self.audio_config = kwargs.pop("audio_config", None)
        if self.audio_config is None:
            self.audio_config = transformers.AutoConfig.from_pretrained(
                audio_model_id, trust_remote_code=True
            )
            self.audio_config.dtype = model_dtype
        elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
            # for_model returns a config *instance*; .__class__ recovers the
            # concrete class so it can be re-instantiated from the saved dict.
            # TODO confirm against the installed transformers version.
            config_class = transformers.AutoConfig.for_model(
                self.audio_config["model_type"]
            ).__class__
            self.audio_config = config_class(**self.audio_config)

        self.text_config = kwargs.pop("text_config", None)
        if self.text_config is None:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            self.text_config.dtype = model_dtype
        # NOTE(review): unlike the audio branch above, this dict branch does not
        # guard on `.get("model_type")` — a dict without that key would raise
        # KeyError here. Verify whether that asymmetry is intentional.
        elif isinstance(self.text_config, dict):
            config_class = transformers.AutoConfig.for_model(
                self.text_config["model_type"]
            ).__class__
            self.text_config = config_class(**self.text_config)

        super().__init__(**kwargs)

        # Pipeline configuration: expose the encoder config and register the
        # remote-code entry points so AutoConfig/AutoModel/AutoProcessor and the
        # custom ASR pipeline resolve to the files shipped with this repo.
        self.encoder = self.audio_config
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
transformers.AutoConfig.register("asr_model", ASRConfig)
|
asr_modeling.py
ADDED
|
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from threading import Thread
|
| 4 |
+
from typing import Iterator, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import (
|
| 9 |
+
AutoConfig,
|
| 10 |
+
AutoModel,
|
| 11 |
+
AutoModelForCausalLM,
|
| 12 |
+
AutoTokenizer,
|
| 13 |
+
PreTrainedModel,
|
| 14 |
+
TextIteratorStreamer,
|
| 15 |
+
)
|
| 16 |
+
from transformers.generation import GenerationMixin
|
| 17 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from .asr_config import ASRConfig
|
| 21 |
+
from .projectors import PROJECTOR_CLASSES
|
| 22 |
+
except ImportError:
|
| 23 |
+
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 24 |
+
from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
from torchaudio.transforms import SpecAugment
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ASRModel(PreTrainedModel, GenerationMixin):
|
| 31 |
+
"""Audio-to-text model combining an audio encoder, projector, and language model."""
|
| 32 |
+
|
| 33 |
+
config_class = ASRConfig
|
| 34 |
+
base_model_prefix = "model"
|
| 35 |
+
main_input_name = "input_features"
|
| 36 |
+
_supports_flash_attn_2 = True
|
| 37 |
+
supports_gradient_checkpointing = True
|
| 38 |
+
_is_loading_from_pretrained: bool = False
|
| 39 |
+
_pretrained_model_path: Optional[str] = None
|
| 40 |
+
|
| 41 |
+
TRANSCRIBE_PROMPT = ""
|
| 42 |
+
|
| 43 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
    """Load model from pretrained, handling device placement correctly.

    Rebuilds the composite model (frozen encoder + projector + frozen LLM)
    from its config, then overlays the trainable weights saved in
    ``model.safetensors`` and, when ``config.use_lora`` is set, the PEFT
    LoRA adapters (loading saved adapters if ``adapter_config.json``
    exists, otherwise initializing fresh ones).

    Args:
        pretrained_model_name_or_path: HF Hub repo id or local directory.
        **kwargs: May contain ``config``, ``subfolder``, ``revision``;
            all kwargs are also forwarded to the constructor.

    Returns:
        A fully assembled ``ASRModel``.
    """
    from safetensors.torch import load_file
    from transformers.utils.hub import cached_file

    config = kwargs.pop("config", None)
    if config is None:
        config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

    # Set flag to avoid device_map="auto" in sub-model loaders.
    # NOTE(review): class-level mutable flags — not safe if two
    # from_pretrained calls run concurrently; confirm single-threaded use.
    cls._is_loading_from_pretrained = True
    cls._pretrained_model_path = pretrained_model_name_or_path

    try:
        model = cls(config, **kwargs)

        # Load projector weights from safetensors
        subfolder = kwargs.get("subfolder")
        revision = kwargs.get("revision")
        cache_kwargs = {}
        if subfolder:
            cache_kwargs["subfolder"] = subfolder
        if revision:
            cache_kwargs["revision"] = revision

        # Resolve model.safetensors locally or via the Hub cache;
        # returns None (instead of raising) when the file is absent.
        model_file = cached_file(
            pretrained_model_name_or_path,
            "model.safetensors",
            _raise_exceptions_for_missing_entries=False,
            **cache_kwargs,
        )

        if model_file is not None:
            state_dict = load_file(model_file)
            # strict=False: the checkpoint holds only the trainable
            # sub-modules (see the state_dict override), not the full model.
            model.load_state_dict(state_dict, strict=False)

        # Load LoRA adapters if use_lora is enabled
        if getattr(config, "use_lora", False):
            # Check for adapter_config.json (required by PEFT to load adapters)
            adapter_config_file = cached_file(
                pretrained_model_name_or_path,
                "adapter_config.json",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )
            if adapter_config_file is not None:
                # Load saved adapter weights using the original repo_id/path
                # PEFT handles Hub downloads and caching internally
                from peft import PeftModel

                model.language_model = PeftModel.from_pretrained(
                    model.language_model,
                    pretrained_model_name_or_path,
                    is_trainable=True,
                    **cache_kwargs,
                )
            else:
                # No saved adapters - initialize fresh LLM LoRA for training
                from peft import LoraConfig, get_peft_model

                lora_config = LoraConfig(
                    r=config.lora_rank,
                    lora_alpha=config.lora_alpha,
                    target_modules=config.lora_target_modules,
                    lora_dropout=config.lora_dropout,
                    bias="none",
                    task_type="CAUSAL_LM",
                )
                model.language_model = get_peft_model(model.language_model, lora_config)

        return model
    finally:
        # Always clear the class-level loading flags, even if loading failed.
        cls._is_loading_from_pretrained = False
        cls._pretrained_model_path = None
|
| 118 |
+
|
| 119 |
+
def __init__(self, config: ASRConfig, **kwargs) -> None:
    """Assemble the composite ASR model from config.

    Construction order matters: encoder and LLM are loaded first, then the
    tokenizer (which may resize LLM embeddings), then generation config,
    feature extractor, projector, optional LoRA/SpecAugment/audio head.

    Args:
        config: Full model configuration (sub-model ids, dims, options).
        **kwargs: Accepted for from_pretrained compatibility; unused here.
    """
    super().__init__(config)

    self.system_prompt = config.system_prompt
    # e.g. "bfloat16" -> torch.bfloat16
    target_dtype = getattr(torch, config.model_dtype)

    # Audio encoder (frozen)
    self.audio_tower = self._load_audio_encoder(config, target_dtype)

    # Language model (frozen)
    self.language_model = self._load_language_model(config, target_dtype)

    # Initialize tokenizer and special tokens
    self._init_tokenizer(config)

    # Set up generation config with greedy decoding defaults
    self.generation_config = self.language_model.generation_config
    self.generation_config.max_new_tokens = config.max_new_tokens
    self.generation_config.min_new_tokens = config.min_new_tokens
    self.generation_config.num_beams = config.num_beams
    self.generation_config.do_sample = config.do_sample
    # Set sampling params from config (None means use model defaults)
    self.generation_config.temperature = config.temperature
    self.generation_config.top_p = config.top_p
    self.generation_config.top_k = config.top_k
    self.generation_config.use_cache = config.use_cache
    self.generation_config.length_penalty = config.length_penalty
    self.generation_config.repetition_penalty = config.repetition_penalty
    self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
    # Set EOS tokens, filtering out any that don't exist in the tokenizer
    eos_candidates = [
        self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
        self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
    ]
    self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
    self.generation_config.pad_token_id = self.tokenizer.pad_token_id

    # Feature extractor for audio preprocessing
    self.feature_extractor = self._create_feature_extractor(config)

    # Audio projector (trainable unless freeze_projector is set)
    self.projector = self._create_projector(config, target_dtype)

    # Setup LoRA if enabled (Stage 2 fine-tuning)
    # Skip if loading from pretrained - from_pretrained will handle adapter loading
    if getattr(config, "use_lora", False) and not getattr(
        self.__class__, "_is_loading_from_pretrained", False
    ):
        self._setup_lora(config)

    # Freeze projector if specified (for Stage 2 LoRA-only training)
    if getattr(config, "freeze_projector", False):
        self.projector.requires_grad_(False)

    # SpecAugment for data augmentation during training
    if getattr(config, "use_specaugment", False):
        self.spec_augment = SpecAugment(
            n_time_masks=config.num_time_masks,
            time_mask_param=config.time_mask_length,
            n_freq_masks=config.num_freq_masks,
            freq_mask_param=config.freq_mask_length,
        )
    else:
        self.spec_augment = None

    # Audio head for S2S (flow matching)
    if getattr(config, "use_audio_head", False):
        from .audio_head import AudioHead

        # Co-locate the audio head with the LLM whose hidden states it consumes.
        device = next(self.language_model.parameters()).device
        llm_dim = self.language_model.config.hidden_size

        self.audio_head = AudioHead(config, llm_dim=llm_dim).to(
            device=device, dtype=target_dtype
        )

        # Load pretrained pocket-tts flow_net if configured
        pocket_tts_weights = getattr(config, "pocket_tts_weights", None)
        freeze_flow_net = getattr(config, "freeze_flow_net", True)
        if pocket_tts_weights is not None or freeze_flow_net:
            # If freeze_flow_net is True but no weights specified, download from HF
            self.audio_head.load_pretrained_flow_net(
                weights_path=pocket_tts_weights,
                freeze=freeze_flow_net,
            )

        if getattr(config, "freeze_audio_head", False):
            self.audio_head.requires_grad_(False)
    else:
        self.audio_head = None

    # For model parallelism
    self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
|
| 212 |
+
|
| 213 |
+
def _create_feature_extractor(self, config: ASRConfig):
    """Build the feature extractor that matches the configured audio encoder.

    Padding is disabled so mel features reflect the true audio length;
    batching/padding decisions are left to downstream code.
    """
    from transformers import AutoFeatureExtractor

    extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
    # Keep features at their natural length instead of padding to a fixed size.
    extractor.padding = False
    return extractor
|
| 221 |
+
|
| 222 |
+
@classmethod
def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
    """Load and freeze the audio encoder.

    Handles three cases by inspecting ``config.audio_model_id``:
    Whisper (keeps only the encoder half), GLM-ASR (keeps ``audio_tower``),
    and any other AutoModel-compatible encoder.
    """
    encoder_kwargs = {
        "attn_implementation": config.attn_implementation,
        "low_cpu_mem_usage": True,
        # NOTE(review): uses the legacy "torch_dtype" key while
        # _load_language_model passes "dtype" — confirm both spellings are
        # accepted by the pinned transformers version.
        "torch_dtype": dtype,
    }

    if "whisper" in config.audio_model_id.lower():
        from transformers import WhisperModel

        # Only the encoder half is needed; drop the seq2seq decoder.
        full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
        encoder = full_model.encoder
        del full_model
    elif "glm" in config.audio_model_id.lower():
        # GLM-ASR models use audio_tower as the encoder
        # Requires transformers >= 5.x or installed from source
        from transformers import AutoModelForSeq2SeqLM

        full_model = AutoModelForSeq2SeqLM.from_pretrained(
            config.audio_model_id, trust_remote_code=True, **encoder_kwargs
        )
        # GLM stores encoder at audio_tower (GlmAsrEncoder)
        encoder = full_model.audio_tower
        # Clear references to free VRAM from the LLM decoder
        full_model.language_model = None
        full_model.multi_modal_projector = None
        del full_model
    else:
        encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

    # The encoder is inference-only: freeze and switch to eval mode.
    encoder.requires_grad_(False)
    encoder.eval()
    return encoder
|
| 257 |
+
|
| 258 |
+
@classmethod
def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
    """Load the causal-LM decoder, frozen and in eval mode."""
    decoder = AutoModelForCausalLM.from_pretrained(
        config.text_model_id,
        attn_implementation=config.attn_implementation,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        dtype=dtype,
    )
    # Honor the configured KV-cache preference (defaults to enabled).
    decoder.config.use_cache = getattr(config, "use_cache", True)
    # The LLM itself never trains; only the projector (and optional LoRA) do.
    decoder.requires_grad_(False)
    decoder.eval()
    return decoder
|
| 273 |
+
|
| 274 |
+
def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
    """Create the trainable audio projector.

    Auto-detects the encoder/LLM hidden sizes from the loaded sub-models
    when they are not specified in the config, instantiates the configured
    projector class, and moves it to the language model's device.

    Args:
        config: Model configuration; ``encoder_dim``/``llm_dim`` may be None.
        dtype: Target parameter dtype for the projector.

    Returns:
        The projector module on the language model's device.

    Raises:
        ValueError: If a hidden size cannot be auto-detected or
            ``projector_type`` is unknown.
    """

    def _detect_dim(model_cfg, name: str) -> int:
        # HF configs expose model width as hidden_size or d_model depending
        # on architecture; the two detection sites previously duplicated this.
        dim = getattr(model_cfg, "hidden_size", None) or getattr(model_cfg, "d_model", None)
        if dim is None:
            raise ValueError(f"Could not auto-detect {name}. Please specify in config.")
        return dim

    if config.encoder_dim is None:
        config.encoder_dim = _detect_dim(self.audio_tower.config, "encoder_dim")
    if config.llm_dim is None:
        config.llm_dim = _detect_dim(self.language_model.config, "llm_dim")

    # Select projector type based on config
    projector_type = getattr(config, "projector_type", "mlp")
    projector_class = PROJECTOR_CLASSES.get(projector_type)
    if projector_class is None:
        raise ValueError(
            f"Unknown projector_type: {projector_type}. "
            f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
        )
    projector = projector_class(config)

    # Move projector to same device as language model (important when using quantization)
    device = next(self.language_model.parameters()).device
    return projector.to(device=device, dtype=dtype)
|
| 306 |
+
|
| 307 |
+
def _setup_lora(self, config: ASRConfig):
    """Wrap the language model with LoRA adapters for Stage 2 fine-tuning."""
    from peft import LoraConfig, get_peft_model

    adapter_cfg = LoraConfig(
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        target_modules=config.lora_target_modules,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    # Replace the frozen LLM with its PEFT-wrapped counterpart.
    self.language_model = get_peft_model(self.language_model, adapter_cfg)
|
| 320 |
+
|
| 321 |
+
def _init_tokenizer(self, config: ASRConfig):
    """Initialize tokenizer with audio token.

    Ensures a dedicated pad token where available, registers the ``<audio>``
    placeholder as a special token (resizing the LLM embedding table when it
    is newly added), and syncs special-token ids into the relevant configs.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

    # Set pad token
    # Prefer a dedicated pad token so padding is distinguishable from EOS.
    if (
        self.tokenizer.pad_token is None
        or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
    ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
        self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

    # Add audio token
    existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
    if "<audio>" not in existing_special:
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": existing_special + ["<audio>"]}
        )
        # mean_resizing=False: new embedding rows use default init rather
        # than the vocabulary mean.
        self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)

    # Cache the placeholder id; forward() uses it to locate audio positions.
    self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
    self.tokenizer.padding_side = "right"

    # Sync token IDs to configs
    for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
        if cfg is not None:
            cfg.pad_token_id = self.tokenizer.pad_token_id
            cfg.eos_token_id = self.tokenizer.eos_token_id
            cfg.bos_token_id = self.tokenizer.bos_token_id
|
| 349 |
+
|
| 350 |
+
def _init_weights(self, _module):
    """Intentional no-op: each projector submodule initializes its own weights."""
    return None
|
| 353 |
+
|
| 354 |
+
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
    """Forward gradient-checkpointing toggles to the wrapped language model.

    Checkpointing trades compute for memory: activations are recomputed
    during backprop instead of being stored through the whole forward pass
    (the frozen LLM still stores activations for backprop to the projector).
    """
    lm = self.language_model
    # Prefer the model's own private hook when it exists.
    if hasattr(lm, "_set_gradient_checkpointing"):
        lm._set_gradient_checkpointing(enable, gradient_checkpointing_func)
        return
    # Otherwise fall back to the public enable/disable API.
    if enable and hasattr(lm, "gradient_checkpointing_enable"):
        lm.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    elif not enable and hasattr(lm, "gradient_checkpointing_disable"):
        lm.gradient_checkpointing_disable()
|
| 366 |
+
|
| 367 |
+
def get_input_embeddings(self) -> nn.Module:
    """Return the token-embedding module of the wrapped language model."""
    lm = self.language_model
    return lm.get_input_embeddings()
|
| 369 |
+
|
| 370 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
    """Replace the token-embedding module of the wrapped language model."""
    lm = self.language_model
    lm.set_input_embeddings(value)
|
| 372 |
+
|
| 373 |
+
def get_output_embeddings(self) -> nn.Module:
    """Return the output projection (LM head) of the wrapped language model."""
    lm = self.language_model
    return lm.get_output_embeddings()
|
| 375 |
+
|
| 376 |
+
def set_output_embeddings(self, value: nn.Module) -> None:
    """Replace the output projection (LM head) of the wrapped language model."""
    lm = self.language_model
    lm.set_output_embeddings(value)
|
| 378 |
+
|
| 379 |
+
def get_processor(self):
    """Build an ASRProcessor bundling this model's preprocessing components."""
    # Mirror the package/script dual-import pattern used at module level.
    try:
        from .asr_processing import ASRProcessor
    except ImportError:
        from asr_processing import ASRProcessor  # type: ignore[no-redef]

    return ASRProcessor(
        tokenizer=self.tokenizer,
        feature_extractor=self.feature_extractor,
        projector=self.projector,
        encoder_conv_layers=self.config.encoder_conv_layers,
    )
|
| 392 |
+
|
| 393 |
+
def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
    """Save trainable weights (projector + audio_head if present).

    NOTE(review): this override intentionally checkpoints ONLY the
    trainable sub-modules — the frozen encoder/LLM are reloaded from
    their own pretrained checkpoints — and ignores the standard
    ``destination``/``prefix``/``keep_vars`` arguments of
    ``nn.Module.state_dict``; confirm no caller relies on those.
    """
    # Namespace keys so load_state_dict(strict=False) maps them back in place.
    state = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
    if self.audio_head is not None:
        state.update({f"audio_head.{k}": v for k, v in self.audio_head.state_dict().items()})
    return state
|
| 399 |
+
|
| 400 |
+
def _compute_encoder_output_lengths(
    self,
    audio_attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Map real mel-frame counts to encoder output lengths.

    Args:
        audio_attention_mask: Mask of real vs padded mel frames (batch, mel_len).

    Returns:
        Per-sample encoder output lengths, shape (batch,).
    """
    # Real (unpadded) mel frames per sample.
    out_lengths = audio_attention_mask.sum(dim=-1)

    # Standard conv1d length formula per configured layer:
    # out = (in + 2*pad - (kernel - 1) - 1) // stride + 1
    for pad, kernel, stride in self.config.encoder_conv_layers:
        out_lengths = (out_lengths + 2 * pad - kernel) // stride + 1

    return out_lengths
|
| 420 |
+
|
| 421 |
+
def _encode_audio(
    self,
    audio_features: torch.Tensor,
    audio_attention_mask: torch.Tensor,
    expected_token_counts: torch.Tensor | None = None,
) -> torch.Tensor:
    """Encode audio and project into the LLM embedding space.

    Args:
        audio_features: Mel spectrogram features (batch, n_mels, mel_len).
        audio_attention_mask: Mask of real vs padded mel frames (batch, mel_len).
        expected_token_counts: Per-sample <audio> token counts taken from
            input_ids; when given, the output matches them exactly
            (truncating or zero-padding as needed).

    Returns:
        Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
    """
    # The encoder is frozen, so skip gradient tracking through it.
    with torch.no_grad():
        encoded = self.audio_tower(input_features=audio_features).last_hidden_state

    # The projector IS trainable — it must run outside no_grad.
    projected = self.projector(encoded)

    if expected_token_counts is None:
        # Derive token counts from real audio lengths via the conv formulas.
        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
        token_counts = torch.tensor(
            [self.projector.get_output_length(int(n.item())) for n in encoder_lengths],
            device=projected.device,
        )
    else:
        token_counts = expected_token_counts

    hidden_dim = projected.shape[2]
    pieces = []
    for idx in range(projected.shape[0]):
        count = int(token_counts[idx].item())
        # Keep at most `count` frames from this sample.
        chunk = projected[idx, :count, :]
        short_by = count - chunk.shape[0]
        if short_by > 0:
            # Projector produced fewer frames than expected: zero-pad the tail.
            filler = torch.zeros(
                short_by,
                hidden_dim,
                device=projected.device,
                dtype=projected.dtype,
            )
            chunk = torch.cat([chunk, filler], dim=0)
        pieces.append(chunk)

    return torch.cat(pieces, dim=0)
|
| 479 |
+
|
| 480 |
+
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.Tensor] = None,
    latent_targets: Optional[torch.Tensor] = None,
    latent_lengths: Optional[torch.Tensor] = None,
    **kwargs,
) -> CausalLMOutputWithPast:
    """Forward pass for training and inference.

    Splices projected audio embeddings into the text embedding sequence at
    the <audio> placeholder positions, runs the LLM, then optionally adds a
    MoE auxiliary loss and an audio-head (flow matching) loss.

    Args:
        input_ids: Token ids containing <audio> placeholders (batch, seq).
        input_features: Mel spectrogram features for the audio encoder.
        audio_attention_mask: Real-vs-padded mel frame mask.
        attention_mask/position_ids/past_key_values/inputs_embeds/labels/
            use_cache/cache_position: Standard causal-LM arguments, passed
            through to the language model.
        latent_targets: Target audio latents for S2S audio-head training.
        latent_lengths: Per-sample lengths of latent_targets.
        **kwargs: Extra LLM kwargs; may include "assistant_mask" marking
            assistant-response positions for the audio head.

    Returns:
        CausalLMOutputWithPast; loss combines LLM, auxiliary, and
        audio-head terms when applicable.
    """
    # Get text embeddings if not provided
    if inputs_embeds is None:
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

    if input_features is not None and input_ids is not None:
        # Apply SpecAugment during training if enabled
        if self.training and self.spec_augment is not None:
            input_features = self.spec_augment(input_features)

        # Count expected audio tokens from input_ids (ground truth from collator)
        audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)

        # Encode audio -> flattened (total_audio_tokens, hidden_dim)
        audio_embeds = self._encode_audio(
            input_features, audio_attention_mask, audio_token_counts
        )

        # Replace <audio> token placeholders with audio embeddings using masked_scatter
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)

        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

    # Request hidden states if training audio head with latent targets
    if self.audio_head is not None and latent_targets is not None:
        kwargs["output_hidden_states"] = True

    # Remove TRL-specific keys that shouldn't go to the LLM
    kwargs.pop("prompts", None)
    kwargs.pop("prompt_attention_mask", None)

    # Run through language model (let it compute loss if labels provided)
    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )

    # Add auxiliary loss from MoE projectors if available
    if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
        aux_loss = self.projector.get_aux_loss()
        if aux_loss is not None and aux_loss.numel() > 0:
            outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)

    # Compute audio head loss if training S2S with latent targets
    if self.audio_head is not None and latent_targets is not None:
        if outputs.hidden_states is None:
            raise ValueError(
                "LLM did not return hidden_states for audio head. "
                "Ensure output_hidden_states=True is passed to the LLM."
            )
        hidden_states = outputs.hidden_states[-1]  # Last layer hidden states

        # Extract only assistant-position hidden states using assistant_mask
        # This mask identifies text output positions (where LLM generates response)
        assistant_mask = kwargs.get("assistant_mask")
        if assistant_mask is not None:
            batch_size = hidden_states.shape[0]

            # Extract assistant hidden states for each sample
            assistant_hidden_list = []
            assistant_lengths = []
            for i in range(batch_size):
                mask_i = assistant_mask[i]  # [seq_len]
                hidden_i = hidden_states[i][mask_i]  # [num_assistant_tokens, hidden_dim]
                assistant_hidden_list.append(hidden_i)
                assistant_lengths.append(hidden_i.shape[0])

            # Pad sequences while preserving gradients
            # Use pad_sequence which maintains gradient flow
            hidden_states = torch.nn.utils.rnn.pad_sequence(
                assistant_hidden_list, batch_first=True, padding_value=0.0
            )
            # Note: latent_lengths stays as original Mimi latent lengths for masking
            # audio_head._compute_loss handles interpolation between different seq lengths

        # No detach needed: LLM is frozen (requires_grad=False), so gradients
        # naturally stop there. Hidden states keep their grad_fn for proper backprop.
        audio_head_loss = self.audio_head(
            hidden_states,
            latent_targets=latent_targets,
            latent_lengths=latent_lengths,
        )

        # Combine with LLM loss if present (e.g., joint ASR+S2S training)
        if outputs.loss is not None:
            total_loss = outputs.loss + audio_head_loss
        else:
            total_loss = audio_head_loss

        # Return new output object (direct assignment doesn't work with Accelerator/DDP)
        from transformers.modeling_outputs import CausalLMOutputWithPast

        return CausalLMOutputWithPast(
            loss=total_loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return outputs
|
| 606 |
+
|
| 607 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Delegate to the LLM, attaching audio features only on the prefill step."""
    audio_inputs = kwargs.pop("input_features", None)
    cache_position = kwargs.get("cache_position")

    prepared = self.language_model.prepare_inputs_for_generation(*args, **kwargs)

    # During cached decoding the audio is already baked into the KV cache;
    # re-inject features only when starting from position 0 (prefill).
    is_first_step = cache_position is not None and cache_position[0] == 0
    if is_first_step and audio_inputs is not None:
        prepared["input_features"] = audio_inputs

    return prepared
|
| 619 |
+
|
| 620 |
+
def _get_num_audio_tokens(
    self,
    audio_attention_mask: torch.Tensor,
) -> int:
    """Return the number of <audio> placeholder tokens for a batch.

    Pipeline: mel frames -> encoder frames (conv length formulas) ->
    projector output tokens. The batch maximum is used so every sample
    receives the same placeholder count during generation.
    """
    encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
    longest = int(encoder_lengths.max().item())
    return int(self.projector.get_output_length(longest))
|
| 633 |
+
|
| 634 |
+
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    system_prompt: Optional[str] = None,
    **generate_kwargs,
) -> torch.Tensor:
    """Generate a transcription from audio input.

    Can be called in two ways:
    1. With ``input_ids`` containing <audio> tokens (from the processor).
    2. With just audio, in which case the chat prompt is built internally.

    Args:
        input_ids: Optional prompt token ids containing <audio> placeholders.
        input_features: Mel spectrogram features (batch, n_mels, mel_len). Required.
        audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len). Required.
        attention_mask: Optional attention mask matching ``input_ids``. Honored only
            when ``input_ids`` is supplied; ignored when the prompt is built here.
        system_prompt: Optional system prompt override (falls back to the model default).
        **generate_kwargs: Extra arguments forwarded to ``language_model.generate``.

    Returns:
        Tensor of newly generated token ids with the prompt stripped.

    Raises:
        ValueError: If ``input_features`` or ``audio_attention_mask`` is missing.
    """
    if input_features is None:
        raise ValueError("input_features required for generation")
    if audio_attention_mask is None:
        raise ValueError("audio_attention_mask required for generation")

    device = input_features.device
    batch_size = input_features.shape[0]

    # Encode audio -> flattened embeddings
    audio_embeds = self._encode_audio(input_features, audio_attention_mask)

    # If input_ids not provided, build prompt with correct number of audio tokens
    if input_ids is None:
        num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
        audio_placeholder = "<audio>" * num_audio_tokens

        system_prompt = system_prompt or self.system_prompt

        messages: list[dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Audio tokens only (instruction-free)
        user_content = audio_placeholder
        if self.TRANSCRIBE_PROMPT:
            user_content += " " + self.TRANSCRIBE_PROMPT
        messages.append({"role": "user", "content": user_content})

        chat_result = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=getattr(self.config, "enable_thinking", False),
        )
        input_ids = chat_result.input_ids.to(device)
        # Any caller-supplied mask cannot match the freshly built prompt.
        attention_mask = None

    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.shape[0] == 1 and batch_size > 1:
        input_ids = input_ids.expand(batch_size, -1)

    # BUG FIX: previously a caller-provided attention_mask was unconditionally
    # overwritten with ones, discarding padding information for batched prompts
    # produced by the processor. Only synthesize a mask when none was given.
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Get text embeddings and replace audio tokens with audio embeddings
    inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
    audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
    inputs_embeds = inputs_embeds.masked_scatter(
        audio_token_mask.to(inputs_embeds.device),
        audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
    )

    # Pass both input_ids and inputs_embeds so repetition_penalty works
    # correctly (it needs input_ids to track which tokens have been used).
    output = self.language_model.generate(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=self.generation_config,
        **generate_kwargs,
    )

    # When using inputs_embeds together with input_ids, generate returns the
    # full sequence; strip the prompt so only new tokens are returned.
    sequences = output if isinstance(output, torch.Tensor) else output.sequences
    input_len = input_ids.shape[1]
    return sequences[:, input_len:]
|
| 717 |
+
|
| 718 |
+
def generate_streaming(
    self,
    input_features: torch.Tensor,
    audio_attention_mask: torch.Tensor,
    system_prompt: Optional[str] = None,
    **generate_kwargs,
) -> Iterator[str]:
    """Generate transcription with streaming token output.

    Yields partial transcript strings as tokens are generated.
    Reduces time-to-first-word by streaming tokens as they're decoded.
    Generation runs on a background thread; this generator consumes the
    streamer and filters out <think>...</think> spans before yielding.

    Args:
        input_features: Mel spectrogram features (batch, n_mels, mel_len)
        audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
        system_prompt: Optional system prompt override
        **generate_kwargs: Additional generation arguments

    Yields:
        Partial transcript text as each token is generated
    """
    device = input_features.device
    batch_size = input_features.shape[0]

    # Encode audio -> flattened embeddings
    audio_embeds = self._encode_audio(input_features, audio_attention_mask)

    # Build prompt with correct number of audio tokens
    num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
    audio_placeholder = "<audio>" * num_audio_tokens

    system_prompt = system_prompt or self.system_prompt

    messages: list[dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Audio tokens only (instruction-free)
    user_content = audio_placeholder
    if self.TRANSCRIBE_PROMPT:
        user_content += " " + self.TRANSCRIBE_PROMPT
    messages.append({"role": "user", "content": user_content})

    chat_result = self.tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=getattr(self.config, "enable_thinking", False),
    )
    input_ids = chat_result.input_ids.to(device)

    # Normalize to (batch, seq) and broadcast a single prompt across the batch.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.shape[0] == 1 and batch_size > 1:
        input_ids = input_ids.expand(batch_size, -1)

    attention_mask = torch.ones_like(input_ids)

    # Get text embeddings and replace audio tokens with audio embeddings
    inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
    audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
    inputs_embeds = inputs_embeds.masked_scatter(
        audio_token_mask.to(inputs_embeds.device),
        audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
    )

    # Setup streamer for token-by-token output
    streamer = TextIteratorStreamer(
        self.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Prepare generation kwargs (caller kwargs may override the defaults).
    gen_kwargs = {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "generation_config": self.generation_config,
        "streamer": streamer,
        **generate_kwargs,
    }

    # Run generation in background thread; the streamer is the channel back.
    thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
    thread.start()

    # Yield tokens as they're generated, filtering out <think>...</think> blocks
    # Start assuming no think block - only filter when we see <think>
    in_think_block = False
    buffer = ""

    for text in streamer:
        buffer += text

        # Check for think block start (in case model outputs think blocks)
        while "<think>" in buffer:
            in_think_block = True
            # Yield any text before <think>
            before_think = buffer.split("<think>")[0]
            if before_think:
                yield before_think
            # Keep only the remainder after the first <think> tag.
            buffer = buffer.split("<think>", 1)[-1]

        # Check for think block end
        while in_think_block and "</think>" in buffer:
            in_think_block = False
            buffer = buffer.split("</think>", 1)[-1]

        # Yield text if not in think block
        # NOTE(review): a tag split across streamer chunks (e.g. "<thi" + "nk>")
        # is yielded before it can be recognized — confirm acceptable.
        if not in_think_block and buffer:
            yield buffer
            buffer = ""

    # Yield any remaining buffer
    if buffer and not in_think_block:
        yield buffer

    thread.join()
|
| 836 |
+
|
| 837 |
+
@torch.no_grad()
def generate_text_only(
    self,
    messages: list[dict[str, str]],
    max_new_tokens: int = 256,
    **generate_kwargs,
) -> str:
    """Run the LLM on a plain chat prompt, with no audio encoding involved.

    Used for SIFT-style response generation from metadata prompts.

    Args:
        messages: Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
        max_new_tokens: Upper bound on the number of generated tokens.
        **generate_kwargs: Extra arguments forwarded to ``generate``.

    Returns:
        The decoded, whitespace-stripped model response.
    """
    target_device = next(self.language_model.parameters()).device

    # Render the chat template into prompt token ids on the model's device.
    prompt_ids = self.tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=getattr(self.config, "enable_thinking", False),
    ).to(target_device)

    if prompt_ids.dim() == 1:
        prompt_ids = prompt_ids.unsqueeze(0)

    full_mask = torch.ones_like(prompt_ids)

    # Greedy decoding directly on the wrapped language model.
    generated = self.language_model.generate(
        input_ids=prompt_ids,
        attention_mask=full_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=self.tokenizer.pad_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
        **generate_kwargs,
    )

    # Keep only the continuation past the prompt, then decode it.
    continuation = generated[0, prompt_ids.shape[1] :]
    return self.tokenizer.decode(continuation, skip_special_tokens=True).strip()
|
| 887 |
+
|
| 888 |
+
@torch.no_grad()
def generate_with_audio(
    self,
    input_features: torch.Tensor,
    audio_attention_mask: torch.Tensor,
    **generate_kwargs,
) -> dict[str, torch.Tensor | list[str]]:
    """Generate text and audio for Speech-to-Speech.

    Runs the usual audio-conditioned text generation, then feeds the
    last-layer hidden states of the generated tokens through the audio head
    to synthesize a waveform via Mimi latents.

    Args:
        input_features: Mel spectrogram features (batch, n_mels, mel_len)
        audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
        **generate_kwargs: Additional generation arguments

    Returns:
        Dict with:
            - text: Decoded text strings (list of str)
            - audio: Audio waveform at 24kHz (batch, samples)

    Raises:
        ValueError: If the model was built without an audio head.
    """
    if self.audio_head is None:
        raise ValueError("Audio head not configured. Set use_audio_head=True in config.")

    device = input_features.device
    batch_size = input_features.shape[0]

    # Encode audio -> flattened embeddings
    audio_embeds = self._encode_audio(input_features, audio_attention_mask)

    # Build prompt with correct number of audio tokens
    num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
    audio_placeholder = "<audio>" * num_audio_tokens

    messages: list[dict[str, str]] = []
    if self.system_prompt:
        messages.append({"role": "system", "content": self.system_prompt})
    user_content = audio_placeholder
    if self.TRANSCRIBE_PROMPT:
        user_content += " " + self.TRANSCRIBE_PROMPT
    messages.append({"role": "user", "content": user_content})

    chat_result = self.tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=getattr(self.config, "enable_thinking", False),
    )
    input_ids = chat_result.input_ids.to(device)

    # Normalize to (batch, seq) and broadcast one prompt across the batch.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.shape[0] == 1 and batch_size > 1:
        input_ids = input_ids.expand(batch_size, -1)

    attention_mask = torch.ones_like(input_ids)

    # Get text embeddings and replace audio tokens with audio embeddings
    inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
    audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
    inputs_embeds = inputs_embeds.masked_scatter(
        audio_token_mask.to(inputs_embeds.device),
        audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
    )

    # Generate with hidden states (needed below to drive the audio head).
    output = self.language_model.generate(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=self.generation_config,
        output_hidden_states=True,
        return_dict_in_generate=True,
        **generate_kwargs,
    )

    # Extract generated text (strip the prompt tokens first).
    text_ids = output.sequences[:, input_ids.shape[1] :]
    text = self.tokenizer.batch_decode(text_ids, skip_special_tokens=True)

    # Extract hidden states from generation steps and concatenate
    # output.hidden_states is tuple of (step,) where each step is tuple of (layer,)
    # Each layer tensor is (batch, 1, hidden_dim) for generated tokens
    last_layer_states = []
    for step_hidden in output.hidden_states:
        # step_hidden is tuple of (num_layers,) tensors
        # Get last layer: shape (batch, 1, hidden_dim)
        last_layer_states.append(step_hidden[-1])

    # Concatenate across generation steps: (batch, gen_seq_len, hidden_dim)
    hidden_states = torch.cat(last_layer_states, dim=1)

    # Generate Mimi latents from LLM hidden states via flow matching
    latents = self.audio_head(hidden_states)

    # Load Mimi decoder if not already loaded (lazy, first call only)
    if self.audio_head.mimi is None:
        self.audio_head.load_mimi_decoder(device=device)

    # Decode latents to audio waveform
    audio = self.audio_head.decode_to_audio(latents)

    return {
        "text": text,
        "audio": audio,
    }
|
| 993 |
+
|
| 994 |
+
def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
    """Save model, tokenizer, and processor.

    Writes the checkpoint plus every auxiliary artifact needed to reload the
    model via Auto classes with remote code: tokenizer, feature extractor,
    optional LoRA adapters, the processor auto_map, and the Python source
    files for the custom classes.

    Args:
        save_directory: Target directory (created if missing).
        **kwargs: Forwarded to the parent ``save_pretrained``; ``repo_id`` and
            ``push_to_hub_model_id`` are also consulted when rewriting
            adapter_config.json.
    """
    import shutil
    from pathlib import Path as PathlibPath

    save_dir = PathlibPath(save_directory)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Update config with actual vocab size
    self.config.vocab_size = self.language_model.config.vocab_size
    self.config.text_config.vocab_size = self.language_model.config.vocab_size

    if hasattr(self.audio_tower.config, "num_mel_bins"):
        self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

    # Save model (temporarily remove non-serializable attributes)
    tokenizer = self.tokenizer
    del self.tokenizer

    try:
        super().save_pretrained(save_dir, **kwargs)
    finally:
        # Restore the tokenizer even if the parent save raises.
        self.tokenizer = tokenizer

    # Save tokenizer and feature extractor
    self.tokenizer.save_pretrained(save_dir)
    self.feature_extractor.save_pretrained(save_dir)

    # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
    # Don't save embedding layers - the <audio> token embedding is never used
    # (it's replaced with projected audio embeddings before the LLM sees it)
    if hasattr(self.language_model, "peft_config"):
        self.language_model.save_pretrained(save_dir, save_embedding_layers=False)

    # Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
    # from redirecting to the base LLM repo (like Qwen) which breaks feature
    # extractor loading for multimodal models. If a repo_id is provided, use that
    # so the model can be loaded directly from the Hub.
    adapter_config_path = save_dir / "adapter_config.json"
    if adapter_config_path.exists():
        with adapter_config_path.open() as f:
            adapter_config = json.load(f)

        # Use repo_id if available, otherwise clear to prevent redirect.
        # Use empty string instead of None to avoid str(None) -> "None" bug
        # in some transformers/PEFT versions.
        repo_id = (
            kwargs.get("repo_id")
            or kwargs.get("push_to_hub_model_id")
            or getattr(self.config, "pretrained_model_path", None)
            or ""  # Use empty string instead of None
        )
        adapter_config["base_model_name_or_path"] = repo_id

        with adapter_config_path.open("w") as f:
            json.dump(adapter_config, f, indent=2)

    # Add processor auto_map to preprocessor_config.json
    config_path = save_dir / "preprocessor_config.json"
    if config_path.exists():
        with config_path.open() as f:
            processor_config = json.load(f)
    else:
        processor_config = {}

    processor_config.update(
        {
            "processor_class": "ASRProcessor",
            "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
        }
    )

    with config_path.open("w") as f:
        json.dump(processor_config, f, indent=2)

    # Copy source files for auto-loading
    src_dir = PathlibPath(__file__).parent
    for asr_file in src_dir.glob("asr_*.py"):
        shutil.copy(asr_file, save_dir / asr_file.name)
    # Copy projectors module
    shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
    # Copy alignment module
    shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
    # Copy diarization module
    shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
    # Copy audio head for S2S
    audio_head_path = src_dir / "audio_head.py"
    if audio_head_path.exists():
        shutil.copy(audio_head_path, save_dir / "audio_head.py")
    # Copy modules directory (for audio head dependencies)
    modules_dir = src_dir / "modules"
    if modules_dir.exists():
        save_modules_dir = save_dir / "modules"
        save_modules_dir.mkdir(exist_ok=True)
        for module_file in modules_dir.glob("*.py"):
            shutil.copy(module_file, save_modules_dir / module_file.name)
|
| 1090 |
+
|
| 1091 |
+
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Push model to HuggingFace Hub, ensuring adapter_config points to repo.

    IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
    so that transformers pipeline() can load the model correctly. Without this,
    the pipeline tries to load from "None" which fails.

    Args:
        repo_id: Target Hub repository id (e.g. "org/model-name").
        **kwargs: Forwarded to the parent class ``push_to_hub``.

    Returns:
        Whatever the parent ``push_to_hub`` implementation returns.
    """
    # Store repo_id in config so save_pretrained can access it when rewriting
    # adapter_config.json's base_model_name_or_path.
    self.config.pretrained_model_path = repo_id
    # Call parent's push_to_hub
    return super().push_to_hub(repo_id, **kwargs)
|
| 1102 |
+
|
| 1103 |
+
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
    """Intentionally a no-op: the repo ships a curated MODEL_CARD.md, so no
    model card is auto-generated here."""
    return None
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
# Register with transformers Auto classes so that AutoConfig.from_pretrained /
# AutoModel.from_pretrained can resolve the custom "asr_model" model_type
# to these local classes (used together with trust_remote_code loading).
AutoConfig.register("asr_model", ASRConfig)
AutoModel.register(ASRConfig, ASRModel)
|
asr_pipeline.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Iterator, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import transformers
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from .alignment import ForcedAligner
|
| 13 |
+
from .asr_modeling import ASRModel
|
| 14 |
+
from .diarization import SpeakerDiarizer
|
| 15 |
+
except ImportError:
|
| 16 |
+
from alignment import ForcedAligner # type: ignore[no-redef]
|
| 17 |
+
from asr_modeling import ASRModel # type: ignore[no-redef]
|
| 18 |
+
from diarization import SpeakerDiarizer # type: ignore[no-redef]
|
| 19 |
+
|
| 20 |
+
# Re-export for backwards compatibility
|
| 21 |
+
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"]
|
| 22 |
+
|
| 23 |
+
# Default TTS voice for Kokoro
|
| 24 |
+
DEFAULT_TTS_VOICE = "af_heart"
|
| 25 |
+
TTS_SAMPLE_RATE = 24000
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def strip_thinking(text: str) -> str:
    """Remove <think>...</think> reasoning spans from model output.

    Args:
        text: Model output text that may contain thinking tags.

    Returns:
        The text with every think span (and the whitespace that follows it)
        removed, then stripped of surrounding whitespace. Falsy input is
        returned unchanged.
    """
    if not text:
        return text
    think_span = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
    return think_span.sub("", text).strip()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
| 44 |
+
"""ASR Pipeline for audio-to-text transcription."""
|
| 45 |
+
|
| 46 |
+
model: ASRModel
|
| 47 |
+
|
| 48 |
+
def __init__(self, model: ASRModel, **kwargs):
    """Build the pipeline around an ASRModel.

    Args:
        model: ASRModel instance used for transcription.
        **kwargs: Optional ``feature_extractor``, ``tokenizer``, and other
            HF pipeline arguments (e.g. ``device``).
    """
    tok = kwargs.pop("tokenizer", model.tokenizer)
    fe = kwargs.pop("feature_extractor", None)
    # Fall back to the model's own processor when no extractor was supplied.
    if fe is None:
        fe = model.get_processor().feature_extractor

    super().__init__(model=model, feature_extractor=fe, tokenizer=tok, **kwargs)

    # Scratch slot for the audio being processed, and a lazily built TTS pipeline.
    self._current_audio = None
    self._tts_pipeline = None
|
| 66 |
+
|
| 67 |
+
@property
def tts_pipeline(self):
    """Lazy-load Kokoro TTS pipeline on first use.

    The kokoro package is imported (and the pipeline constructed) only when
    audio output is actually requested, so text-only usage does not require
    the TTS dependency to be installed.

    Raises:
        ImportError: If the ``kokoro`` package is not installed.
    """
    if self._tts_pipeline is None:
        try:
            from kokoro import KPipeline

            # American-English voice set from the Kokoro-82M repo.
            self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
        except ImportError as e:
            raise ImportError(
                "Kokoro TTS is required for audio output. "
                "Install with: pip install kokoro>=0.9.2\n"
                "Also requires espeak-ng: apt-get install espeak-ng"
            ) from e
    return self._tts_pipeline
|
| 82 |
+
|
| 83 |
+
def text_to_speech(self, text: str, voice: str = DEFAULT_TTS_VOICE) -> dict[str, Any]:
    """Synthesize *text* to a waveform with Kokoro TTS.

    Args:
        text: Text to synthesize; blank input yields an empty waveform.
        voice: Kokoro voice ID (default: "af_heart").

    Returns:
        Dict with 'audio' (numpy array) and 'sample_rate' keys.
    """
    empty_wave = np.array([], dtype=np.float32)
    if not text or not text.strip():
        return {"audio": empty_wave, "sample_rate": TTS_SAMPLE_RATE}

    # The Kokoro pipeline yields (graphemes, phonemes, audio) triples;
    # collect the audio chunks and stitch them into one waveform.
    chunks = [wave for _, _, wave in self.tts_pipeline(text, voice=voice)]
    waveform = np.concatenate(chunks) if chunks else empty_wave
    return {"audio": waveform, "sample_rate": TTS_SAMPLE_RATE}
|
| 103 |
+
|
| 104 |
+
def transcribe_streaming(
    self,
    inputs: Union[str, bytes, np.ndarray, dict],
    system_prompt: str | None = None,
) -> Iterator[str]:
    """Transcribe audio with streaming token output for low-latency applications.

    Yields partial transcript strings as tokens are generated, reducing
    time-to-first-word compared to waiting for full transcription.

    Args:
        inputs: Audio input in any supported format:
            - str: File path to audio file
            - bytes: Raw audio bytes
            - np.ndarray: Audio samples as numpy array
            - dict: {"array": np.ndarray, "sampling_rate": int}
        system_prompt: Optional system prompt override (uses model's default if not provided)

    Yields:
        Partial transcript text as each token is generated

    Example:
        >>> for partial in pipeline.transcribe_streaming("audio.wav"):
        ...     print(partial, end="", flush=True)
    """
    # Extract audio array from various input formats
    audio_data = self._extract_audio(inputs)
    if audio_data is None:
        # Nothing decodable: end the generator without yielding anything.
        return

    audio_array = audio_data["array"]
    # NOTE(review): 16 kHz is assumed when the input omits sampling_rate —
    # confirm this matches the feature extractor's expected rate.
    sample_rate = audio_data.get("sampling_rate", 16000)

    # Preprocess audio through feature extractor
    model_inputs = self.feature_extractor(
        audio_array,
        sampling_rate=sample_rate,
        return_tensors="pt",
        return_attention_mask=True,
    )

    # Get model dtype and device, cast inputs to match
    device = self.model.device
    model_dtype = next(self.model.parameters()).dtype
    input_features = model_inputs.input_features.to(device, dtype=model_dtype)
    attention_mask = model_inputs.attention_mask.to(device)

    # Stream tokens from model
    yield from self.model.generate_streaming(
        input_features=input_features,
        audio_attention_mask=attention_mask,
        system_prompt=system_prompt,
    )
|
| 157 |
+
|
| 158 |
+
def transcribe_streaming_with_audio(
    self,
    inputs: Union[str, bytes, np.ndarray, dict],
    voice: str = DEFAULT_TTS_VOICE,
    system_prompt: str | None = None,
) -> Iterator[dict[str, Any]]:
    """Transcribe audio with streaming text AND audio output.

    Yields partial text as tokens are generated, and audio chunks
    as complete sentences are detected. This enables low-latency
    voice agents that can start speaking before transcription completes.

    Args:
        inputs: Audio input (same formats as transcribe_streaming)
        voice: Kokoro TTS voice ID
        system_prompt: Optional system prompt override (uses model's default if not provided)

    Yields:
        Dicts with either:
        - {"type": "text", "text": str, "interim": bool} for text updates
        - {"type": "audio", "audio": np.ndarray, "sample_rate": int} for audio chunks

    Example:
        >>> for chunk in pipeline.transcribe_streaming_with_audio(audio):
        ...     if chunk["type"] == "text":
        ...         print(chunk["text"], end="", flush=True)
        ...     elif chunk["type"] == "audio":
        ...         play_audio(chunk["audio"], chunk["sample_rate"])
    """

    def _tts_chunk(text: str) -> Iterator[dict[str, Any]]:
        """Synthesize one sentence; yield at most one audio chunk.

        TTS failures are swallowed deliberately: audio output is
        best-effort and must never break the text stream.
        """
        try:
            tts_result = self.text_to_speech(text, voice=voice)
        except Exception:
            return  # Skip audio on TTS errors
        if tts_result["audio"] is not None and len(tts_result["audio"]) > 0:
            yield {
                "type": "audio",
                "audio": tts_result["audio"],
                "sample_rate": tts_result["sample_rate"],
            }

    sentence_buffer = ""
    full_text = ""

    # Sentence-ending patterns (., !, ? followed by whitespace or end of buffer).
    # Uses the module-level `re` import; no local import needed.
    sentence_end_pattern = re.compile(r"[.!?](?:\s|$)")

    for token_text in self.transcribe_streaming(inputs, system_prompt=system_prompt):
        full_text += token_text
        sentence_buffer += token_text

        # Interim text update after every token.
        yield {"type": "text", "text": full_text, "interim": True}

        # Check for a complete sentence in the buffer.
        match = sentence_end_pattern.search(sentence_buffer)
        if match:
            # Split off the complete sentence(s); keep the remainder buffered.
            end_pos = match.end()
            complete_text = sentence_buffer[:end_pos].strip()
            sentence_buffer = sentence_buffer[end_pos:]

            # Generate audio for the complete sentence.
            if complete_text:
                yield from _tts_chunk(complete_text)

    # Final text update (not interim).
    yield {"type": "text", "text": full_text, "interim": False}

    # Generate audio for any trailing text without sentence punctuation.
    remaining = sentence_buffer.strip()
    if remaining:
        yield from _tts_chunk(remaining)
|
| 239 |
+
|
| 240 |
+
def _sanitize_parameters(self, **kwargs):
    """Strip this pipeline's custom keyword arguments before the parent validates the rest."""
    # Keywords consumed by our own __call__ / streaming helpers (diarization,
    # prompts, TTS); the parent pipeline would reject them as unknown.
    custom_keys = (
        "return_timestamps",
        "return_speakers",
        "num_speakers",
        "min_speakers",
        "max_speakers",
        "hf_token",
        "user_prompt",
        "system_prompt",
        "diarization_backend",
        # TTS parameters
        "return_audio",
        "tts_voice",
    )
    for key in custom_keys:
        kwargs.pop(key, None)

    return super()._sanitize_parameters(**kwargs)
|
| 257 |
+
|
| 258 |
+
def __call__(
    self,
    inputs,
    **kwargs,
):
    """Transcribe audio with optional word-level timestamps and speaker diarization.

    Args:
        inputs: Audio input (file path, dict with array/sampling_rate, etc.)
        return_timestamps: If True, return word-level timestamps using forced alignment
        return_speakers: If True, return speaker labels for each word
        return_audio: If True, synthesize transcription as speech using Kokoro TTS
        tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
        user_prompt: Custom transcription prompt (default: "Transcribe: ")
        system_prompt: Custom system prompt override (uses model's default if not provided)
        num_speakers: Exact number of speakers (if known, for diarization)
        min_speakers: Minimum number of speakers (for diarization)
        max_speakers: Maximum number of speakers (for diarization)
        **kwargs: Additional arguments passed to the pipeline

    Returns:
        Dict with 'text' key, 'words' key if return_timestamps=True,
        speaker labels on words if return_speakers=True,
        and 'audio'/'sample_rate' keys if return_audio=True
    """
    # Extract our params before super().__call__ (which will also call _sanitize_parameters)
    return_timestamps = kwargs.pop("return_timestamps", False)
    return_speakers = kwargs.pop("return_speakers", False)
    return_audio = kwargs.pop("return_audio", False)
    tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
    user_prompt = kwargs.pop("user_prompt", None)
    system_prompt = kwargs.pop("system_prompt", None)
    diarization_params = {
        "num_speakers": kwargs.pop("num_speakers", None),
        "min_speakers": kwargs.pop("min_speakers", None),
        "max_speakers": kwargs.pop("max_speakers", None),
    }

    # Speaker assignment requires word timestamps.
    if return_speakers:
        return_timestamps = True

    # Set custom user prompt if provided (restored in finally below).
    original_prompt = None
    if user_prompt:
        original_prompt = self.model.TRANSCRIBE_PROMPT
        self.model.TRANSCRIBE_PROMPT = user_prompt

    # Set custom system prompt if provided (restored in finally below).
    original_system_prompt = None
    if system_prompt:
        original_system_prompt = self.model.system_prompt
        self.model.system_prompt = system_prompt

    # BUGFIX: the original only restored the prompt overrides and cleared
    # self._current_audio on the success path, so an exception during
    # transcription/alignment leaked a custom prompt into subsequent calls.
    # try/finally guarantees cleanup.
    try:
        # Store audio for timestamp alignment and diarization.
        if return_timestamps or return_speakers:
            self._current_audio = self._extract_audio(inputs)

        # Run standard transcription.
        result = super().__call__(inputs, **kwargs)

        # Add timestamps if requested.
        if return_timestamps and self._current_audio is not None:
            text = result.get("text", "")
            if text:
                try:
                    words = ForcedAligner.align(
                        self._current_audio["array"],
                        text,
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                    )
                    result["words"] = words
                except Exception as e:
                    # Alignment is best-effort: surface the error, keep the text.
                    result["words"] = []
                    result["timestamp_error"] = str(e)
            else:
                result["words"] = []

        # Add speaker diarization if requested.
        if return_speakers and self._current_audio is not None:
            try:
                speaker_segments = SpeakerDiarizer.diarize(
                    self._current_audio["array"],
                    sample_rate=self._current_audio.get("sampling_rate", 16000),
                    **{k: v for k, v in diarization_params.items() if v is not None},
                )
                result["speaker_segments"] = speaker_segments

                # Assign speakers to words (only when alignment succeeded).
                if result.get("words"):
                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                        result["words"],
                        speaker_segments,
                    )
            except Exception as e:
                result["speaker_segments"] = []
                result["diarization_error"] = str(e)

        # Synthesize transcription as speech if requested.
        if return_audio:
            text = result.get("text", "")
            try:
                tts_result = self.text_to_speech(text, voice=tts_voice)
                result["audio"] = tts_result["audio"]
                result["sample_rate"] = tts_result["sample_rate"]
            except Exception as e:
                # Best-effort: empty audio plus an error field on TTS failure.
                result["audio"] = np.array([], dtype=np.float32)
                result["sample_rate"] = TTS_SAMPLE_RATE
                result["tts_error"] = str(e)

        return result
    finally:
        # Clean up: always drop cached audio and restore prompt overrides.
        self._current_audio = None
        if original_prompt is not None:
            self.model.TRANSCRIBE_PROMPT = original_prompt
        if original_system_prompt is not None:
            self.model.system_prompt = original_system_prompt
|
| 376 |
+
|
| 377 |
+
def _extract_audio(self, inputs) -> dict | None:
|
| 378 |
+
"""Extract audio array from various input formats.
|
| 379 |
+
|
| 380 |
+
Supported input formats:
|
| 381 |
+
- str: File path to audio file
|
| 382 |
+
- bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
|
| 383 |
+
- np.ndarray: Audio samples as float32 array
|
| 384 |
+
- dict with "array": Audio samples as numpy array
|
| 385 |
+
- dict with "raw": Alias for "array" (HF pipeline compat)
|
| 386 |
+
- dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")
|
| 387 |
+
|
| 388 |
+
For raw PCM bytes (e.g., from pipecat), use:
|
| 389 |
+
{"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}
|
| 390 |
+
"""
|
| 391 |
+
from transformers.pipelines.audio_utils import ffmpeg_read
|
| 392 |
+
|
| 393 |
+
if isinstance(inputs, dict):
|
| 394 |
+
if "array" in inputs:
|
| 395 |
+
return {
|
| 396 |
+
"array": inputs["array"],
|
| 397 |
+
"sampling_rate": inputs.get("sampling_rate", 16000),
|
| 398 |
+
}
|
| 399 |
+
if "raw" in inputs:
|
| 400 |
+
return {
|
| 401 |
+
"array": inputs["raw"],
|
| 402 |
+
"sampling_rate": inputs.get("sampling_rate", 16000),
|
| 403 |
+
}
|
| 404 |
+
if "raw_bytes" in inputs:
|
| 405 |
+
# Raw PCM bytes - convert to float32 array
|
| 406 |
+
dtype = inputs.get("dtype", "int16")
|
| 407 |
+
sample_rate = inputs.get("sampling_rate", 16000)
|
| 408 |
+
audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
|
| 409 |
+
# Normalize based on dtype
|
| 410 |
+
if dtype == "int16":
|
| 411 |
+
audio = audio / 32768.0
|
| 412 |
+
elif dtype == "int32":
|
| 413 |
+
audio = audio / 2147483648.0
|
| 414 |
+
return {"array": audio, "sampling_rate": sample_rate}
|
| 415 |
+
elif isinstance(inputs, str):
|
| 416 |
+
# File path - load audio using ffmpeg (same as HF pipeline)
|
| 417 |
+
with Path(inputs).open("rb") as f:
|
| 418 |
+
audio = ffmpeg_read(f.read(), sampling_rate=16000)
|
| 419 |
+
return {"array": audio, "sampling_rate": 16000}
|
| 420 |
+
elif isinstance(inputs, bytes):
|
| 421 |
+
audio = ffmpeg_read(inputs, sampling_rate=16000)
|
| 422 |
+
return {"array": audio, "sampling_rate": 16000}
|
| 423 |
+
elif isinstance(inputs, np.ndarray):
|
| 424 |
+
return {"array": inputs, "sampling_rate": 16000}
|
| 425 |
+
|
| 426 |
+
return None
|
| 427 |
+
|
| 428 |
+
def preprocess(self, inputs, **preprocess_params):
    """Preprocess audio inputs for the model.

    Args:
        inputs: Audio input (dict with array, file path, etc.)
        **preprocess_params: Additional preprocessing parameters

    Yields:
        Model input dicts with input_features and attention_mask
    """
    # Datasets-style dicts use "array"; the parent pipeline expects "raw".
    if isinstance(inputs, dict) and "array" in inputs:
        sampling_rate = inputs.get("sampling_rate", self.feature_extractor.sampling_rate)
        inputs = {"raw": inputs["array"], "sampling_rate": sampling_rate}

    for model_input in super().preprocess(inputs, **preprocess_params):
        # The parent may omit the chunking flag; default to a single segment.
        model_input.setdefault("is_last", True)
        yield model_input
|
| 449 |
+
|
| 450 |
+
def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
    """Run model forward pass to generate transcription.

    Args:
        model_inputs: Dict with input_features and attention_mask
        **generate_kwargs: Generation parameters

    Returns:
        Dict with generated token IDs and the chunking flag
    """
    # Pop the chunking flag so it is not forwarded to generate().
    if isinstance(model_inputs, dict):
        is_last = model_inputs.pop("is_last", True)
    else:
        is_last = True

    device = self.model.device
    features = model_inputs["input_features"].to(device)
    audio_mask = model_inputs["attention_mask"].to(device)

    token_ids = self.model.generate(
        input_features=features,
        audio_attention_mask=audio_mask,
        **generate_kwargs,
    )

    return {"tokens": token_ids, "is_last": is_last}
|
| 473 |
+
|
| 474 |
+
def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
    """Convert model output tokens to text.

    Args:
        model_outputs: Dict with 'tokens' key containing generated IDs
        **kwargs: Additional postprocessing parameters

    Returns:
        Dict with 'text' key containing transcription
    """
    # Handle list of outputs (from chunking)
    if isinstance(model_outputs, list):
        model_outputs = model_outputs[0] if model_outputs else {}

    tokens = model_outputs.get("tokens")
    if tokens is None:
        return super().postprocess(model_outputs, **kwargs)

    # BUGFIX: normalize tokens to a plain Python list up front. The original
    # called tokens.tolist() inside the eos-filter branch, which raised
    # AttributeError whenever tokens was already a list (e.g. merged chunks).
    if torch.is_tensor(tokens):
        tokens = tokens.cpu()
        if tokens.dim() > 1:
            tokens = tokens[0]
        tokens = tokens.tolist()
    else:
        tokens = list(tokens)

    # Filter out eos tokens that the tokenizer doesn't recognize as special
    # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
    if hasattr(self, "model") and hasattr(self.model, "generation_config"):
        eos_ids = self.model.generation_config.eos_token_id
        if eos_ids is not None:
            eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
            tokens = [t for t in tokens if t not in eos_set]

    text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
    # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
    text = strip_thinking(text)
    # Truncate repetitions at end of text
    text = _truncate_repetitions(text)
    return {"text": text}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|
| 514 |
+
"""Truncate repeated words/phrases/characters at end of text.
|
| 515 |
+
|
| 516 |
+
Detects patterns like:
|
| 517 |
+
- Repeated words: "the the the the" -> "the"
|
| 518 |
+
- Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
|
| 519 |
+
- Repeated characters: "444444" -> "4"
|
| 520 |
+
|
| 521 |
+
Args:
|
| 522 |
+
text: Input text to process
|
| 523 |
+
min_repeats: Minimum repetitions to trigger truncation (default 3)
|
| 524 |
+
|
| 525 |
+
Returns:
|
| 526 |
+
Text with trailing repetitions removed
|
| 527 |
+
"""
|
| 528 |
+
if not text:
|
| 529 |
+
return text
|
| 530 |
+
|
| 531 |
+
# 1. Truncate repeated characters at end (e.g., "444444" -> "4")
|
| 532 |
+
char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
|
| 533 |
+
text = char_pattern.sub(r"\1", text)
|
| 534 |
+
|
| 535 |
+
# 2. Truncate repeated words at end (e.g., "the the the" -> "the")
|
| 536 |
+
word_pattern = re.compile(
|
| 537 |
+
r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
|
| 538 |
+
)
|
| 539 |
+
while word_pattern.search(text):
|
| 540 |
+
text = word_pattern.sub(r"\1", text)
|
| 541 |
+
|
| 542 |
+
# 3. Truncate repeated phrases (2-20 words) at end
|
| 543 |
+
# e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
|
| 544 |
+
words = text.split()
|
| 545 |
+
if len(words) >= min_repeats * 2:
|
| 546 |
+
# Try phrase lengths from 2 to 20 words
|
| 547 |
+
for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
|
| 548 |
+
# Check if the last phrase_len words repeat
|
| 549 |
+
phrase = " ".join(words[-phrase_len:])
|
| 550 |
+
# Build pattern to match repeated phrases at end
|
| 551 |
+
phrase_escaped = re.escape(phrase)
|
| 552 |
+
phrase_pattern = re.compile(
|
| 553 |
+
r"(^|.*?\s)("
|
| 554 |
+
+ phrase_escaped
|
| 555 |
+
+ r")(?:\s+"
|
| 556 |
+
+ phrase_escaped
|
| 557 |
+
+ r"){"
|
| 558 |
+
+ str(min_repeats - 1)
|
| 559 |
+
+ r",}\s*$",
|
| 560 |
+
re.IGNORECASE,
|
| 561 |
+
)
|
| 562 |
+
match = phrase_pattern.match(text)
|
| 563 |
+
if match:
|
| 564 |
+
# Keep prefix + one instance of the phrase
|
| 565 |
+
text = (match.group(1) + match.group(2)).strip()
|
| 566 |
+
words = text.split()
|
| 567 |
+
break
|
| 568 |
+
|
| 569 |
+
return text
|
asr_processing.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import transformers
|
| 5 |
+
from transformers import ProcessorMixin
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from .asr_config import ASRConfig
|
| 9 |
+
except ImportError:
|
| 10 |
+
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs a Whisper-style audio feature extractor with the language model's
    tokenizer, and builds chat-templated prompts where each projected audio
    frame is represented by one ``<audio>`` placeholder token.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    # Placeholder token repeated once per projected audio frame in the prompt.
    AUDIO_TOKEN = "<audio>"
    # Optional instruction appended after the audio tokens (empty = instruction-free).
    TRANSCRIBE_PROMPT = ""
    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        projector=None,
        encoder_conv_layers: Optional[list] = None,
    ):
        """Initialize the ASR processor.

        Args:
            feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
            tokenizer: Text tokenizer for the language model
            projector: Audio projector module (for computing output lengths)
            encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
        """
        # NOTE(review): ProcessorMixin.__init__ is not called here; attributes
        # are assigned directly instead — confirm this is deliberate.
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        # Resolve the placeholder id once; the model uses it to locate audio positions.
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector
        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

    def _compute_encoder_output_length(self, mel_length: int) -> int:
        """Compute encoder output length using conv layer formulas."""
        # Standard Conv1d output-length formula applied per layer (dilation = 1).
        length = mel_length
        for padding, kernel_size, stride in self.encoder_conv_layers:
            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
        return length

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s)
            text: Target transcription (optional, for training - but use DataCollator instead)
            system_prompt: Optional system prompt
            return_tensors: Return format ("pt" for PyTorch)

        Returns:
            Dict with input_features, input_ids, attention_mask
        """
        result = {}

        # Process audio
        if audio is not None:
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]

            # Use actual audio length (from attention mask) for token count
            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
            # NOTE(review): assumes self.projector is set whenever audio is
            # provided; a None projector raises AttributeError here.
            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
        else:
            num_audio_tokens = 0

        # Build prompt with audio token placeholders (instruction-free)
        if num_audio_tokens > 0:
            user_content = self.AUDIO_TOKEN * num_audio_tokens
            if self.TRANSCRIBE_PROMPT:
                user_content += " " + self.TRANSCRIBE_PROMPT
        else:
            user_content = self.TRANSCRIBE_PROMPT or ""

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Tokenize; a generation prompt is only appended for inference (no target text).
        tokenized = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
        )

        # Handle both tensor and BatchEncoding returns
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
        else:
            # BatchEncoding or dict-like object
            input_ids = tokenized.get("input_ids", tokenized.input_ids)

        # Ensure a batch dimension.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
ASRProcessor.register_for_auto_class()
|
| 133 |
+
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
|
audio_head.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow matching audio head for speech-to-speech.
|
| 2 |
+
|
| 3 |
+
Generates audio from LLM hidden states via flow matching:
|
| 4 |
+
LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio
|
| 5 |
+
|
| 6 |
+
Supports two modes:
|
| 7 |
+
1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out)
|
| 8 |
+
2. Using pretrained pocket-tts flow_net with 32-dim normalized latents
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from functools import partial
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
from .modules.mlp import SimpleMLPAdaLN
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def lsd_decode(
    v_t,
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
    """Lagrangian Self-Distillation decoding.

    Iteratively refines noise into latents using the flow velocity network.

    Args:
        v_t: Velocity function v(s, t, x) -> velocity
        x_0: Initial noise, shape [N, latent_dim]
        num_steps: Number of integration steps

    Returns:
        Decoded latents, shape [N, latent_dim]
    """
    state = x_0
    for step_idx in range(num_steps):
        # Interval endpoints on [0, 1], broadcast over the batch as [..., 1] tensors.
        interval_start = step_idx / num_steps
        interval_end = (step_idx + 1) / num_steps
        start_t = torch.full_like(x_0[..., :1], interval_start)
        end_t = torch.full_like(x_0[..., :1], interval_end)
        # Euler step along the predicted flow direction.
        velocity = v_t(start_t, end_t, state)
        state = state + velocity / num_steps
    return state
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AudioHead(nn.Module):
|
| 52 |
+
"""Flow matching head: LLM hidden -> Mimi latents -> audio.
|
| 53 |
+
|
| 54 |
+
Architecture:
|
| 55 |
+
- llm_proj: Linear projection from LLM hidden dim to flow conditioning
|
| 56 |
+
- latent_proj_in/out: Project between Mimi 512-dim and flow 32-dim
|
| 57 |
+
- flow_net: SimpleMLPAdaLN that predicts flow velocity
|
| 58 |
+
- Mimi decoder for latent -> audio
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
config: ASRConfig with:
|
| 62 |
+
- llm_dim: LLM hidden dimension (default: 2048)
|
| 63 |
+
- lsd_decode_steps: Number of LSD integration steps (default: 1)
|
| 64 |
+
- flow_temperature: Sampling temperature for noise (default: 1.0)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
# Architecture dimensions
|
| 68 |
+
COND_DIM = 1024 # Conditioning dimension
|
| 69 |
+
LATENT_DIM = 32 # Flow latent dimension (matches Mimi's 32 codebooks)
|
| 70 |
+
MIMI_DIM = 512 # Mimi encoder output dimension
|
| 71 |
+
FLOW_DIM = 512 # Flow network hidden dimension
|
| 72 |
+
FLOW_DEPTH = 6 # Number of residual blocks
|
| 73 |
+
|
| 74 |
+
def __init__(self, config, llm_dim: Optional[int] = None):
    """Build the flow-matching audio head.

    Args:
        config: Config object read via getattr for llm_dim,
            lsd_decode_steps, and flow_temperature.
        llm_dim: Explicit LLM hidden size; overrides config when given.
    """
    super().__init__()
    # llm_dim can be passed directly or from config; final fallback is 2048.
    self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048
    self.cond_dim = self.COND_DIM
    self.latent_dim = self.LATENT_DIM
    self.mimi_dim = self.MIMI_DIM
    self.lsd_steps = getattr(config, "lsd_decode_steps", 1)
    self.temp = getattr(config, "flow_temperature", 1.0)

    # LLM -> conditioning projection
    self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False)

    # Mimi embedding projections
    # Projects 512-dim Mimi embeddings to 32-dim flow latents and back
    self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False)
    self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False)

    # Flow network: predicts velocity in 32-dim latent space, conditioned
    # on the projected LLM hidden states plus two time embeddings (s, t).
    self.flow_net = SimpleMLPAdaLN(
        in_channels=self.latent_dim,
        model_channels=self.FLOW_DIM,
        out_channels=self.latent_dim,
        cond_channels=self.cond_dim,
        num_res_blocks=self.FLOW_DEPTH,
        num_time_conds=2,
    )

    # Normalization buffers for pretrained pocket-tts flow_net
    # When using pretrained weights, the flow operates in normalized 32-dim space
    self.register_buffer("emb_mean", torch.zeros(self.latent_dim))
    self.register_buffer("emb_std", torch.ones(self.latent_dim))
    # Flipped to True by load_pretrained_flow_net when stats are loaded.
    self._use_pretrained_normalization = False

    # Mimi decoder components (loaded separately via load_mimi_decoder)
    self.mimi = None
|
| 110 |
+
|
| 111 |
+
def load_mimi_decoder(self, device: torch.device = None, dtype: torch.dtype = None):
    """Load Mimi model for decoding latents to audio."""
    from transformers import MimiModel

    # Frozen, eval-mode decoder: it is used for inference only.
    mimi = MimiModel.from_pretrained("kyutai/mimi")
    mimi.requires_grad_(False)
    mimi.eval()

    if device is not None:
        mimi = mimi.to(device)
    if dtype is not None:
        mimi = mimi.to(dtype)

    self.mimi = mimi
    logger.info("Loaded Mimi decoder from kyutai/mimi")
|
| 125 |
+
|
| 126 |
+
def load_pretrained_flow_net(
    self,
    weights_path: Optional[str] = None,
    freeze: bool = True,
):
    """Load pretrained pocket-tts flow_net weights.

    This enables using the pretrained flow matching network from pocket-tts,
    which operates in normalized 32-dim latent space.

    Args:
        weights_path: Path to safetensors file. If None, downloads from HuggingFace.
        freeze: Whether to freeze flow_net weights (default: True, only train llm_proj)
    """
    import safetensors.torch

    if weights_path is None:
        # Fall back to the published pocket-tts checkpoint on the Hub.
        from huggingface_hub import hf_hub_download

        weights_path = hf_hub_download(
            repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
        )

    state = safetensors.torch.load_file(weights_path)

    # Extract flow_net weights: checkpoint keys are namespaced under
    # "flow_lm.flow_net."; strip the prefix to match this module's names.
    flow_state = {}
    for k, v in state.items():
        if k.startswith("flow_lm.flow_net."):
            new_key = k.replace("flow_lm.flow_net.", "")
            flow_state[new_key] = v

    self.flow_net.load_state_dict(flow_state)
    logger.info(f"Loaded pretrained flow_net from {weights_path}")

    # Load normalization buffers (statistics of the 32-dim latent space).
    if "flow_lm.emb_mean" in state:
        self.emb_mean.copy_(state["flow_lm.emb_mean"])
    if "flow_lm.emb_std" in state:
        self.emb_std.copy_(state["flow_lm.emb_std"])
        # Enable normalization for generate
        self._use_pretrained_normalization = True
        logger.info("Loaded emb_mean and emb_std for normalization")

    if freeze:
        self.flow_net.requires_grad_(False)
        logger.info("Froze flow_net weights (only llm_proj will train)")
|
| 173 |
+
|
| 174 |
+
def forward(
|
| 175 |
+
self,
|
| 176 |
+
hidden_states: torch.Tensor,
|
| 177 |
+
latent_targets: Optional[torch.Tensor] = None,
|
| 178 |
+
latent_lengths: Optional[torch.Tensor] = None,
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
"""Forward pass for training or inference.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
|
| 184 |
+
latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
|
| 185 |
+
latent_lengths: Actual lengths per sample, shape [batch]
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Training: scalar flow matching loss
|
| 189 |
+
Inference: generated Mimi latents, shape [batch, seq_len, 512]
|
| 190 |
+
"""
|
| 191 |
+
# Project LLM hidden states to conditioning
|
| 192 |
+
cond = self.llm_proj(hidden_states)
|
| 193 |
+
|
| 194 |
+
if latent_targets is not None:
|
| 195 |
+
return self._compute_loss(cond, latent_targets, latent_lengths)
|
| 196 |
+
return self._generate(cond)
|
| 197 |
+
|
| 198 |
+
    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with reconstruction term.

        The loss has two components:
        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking, shape [batch]

        Returns:
            Scalar combined loss ``flow_loss + 0.1 * recon_loss``; a zero tensor
            (with requires_grad) for empty sequences or an all-empty mask.
        """
        # Debug: check inputs for NaN/Inf (logged only; training continues)
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences: return a differentiable zero so backward() is safe
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # Project 512-dim Mimi embeddings to 32-dim flow latents
        targets_proj = self.latent_proj_in(targets)

        # Compute reconstruction loss to train latent_proj_out
        # This ensures the projection learns a good inverse mapping
        targets_reconstructed = self.latent_proj_out(targets_proj)

        # Interpolate targets to match conditioning sequence length.
        # interpolate() wants channels-first, hence the transpose pairs below.
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()

            # Also interpolate original targets for reconstruction loss
            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()

            # Interpolate reconstructed targets to match
            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()

            # Rescale valid lengths into the resampled time base
            if lengths is not None:
                scale = cond_seq_len / target_seq_len
                lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        # x_1 is the "data" endpoint of the probability path
        x_1 = targets_proj

        # Random timesteps for each sample/position (match input dtype)
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)

        # Sample noise (the x_0 endpoint)
        x_0 = torch.randn_like(x_1)

        # Linear interpolation: x_t = (1-t) * x_0 + t * x_1
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity: dx/dt = x_1 - x_0
        v_target = x_1 - x_0

        # Flatten for flow_net: [batch * seq_len, dim]
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # Predict velocity.
        # NOTE(review): flow_net is called as (cond, s, t, x) — here the same
        # timestep is passed for both s and t; confirm this matches the
        # pretrained pocket-tts flow_net contract.
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        # Compute masked losses over valid positions only
        if lengths is not None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)

            # Check if mask is all False (no valid positions)
            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)

            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Combined loss (reconstruction loss weighted at 0.1 to not dominate)
        return flow_loss + 0.1 * recon_loss
|
| 312 |
+
def _generate(self, cond: torch.Tensor) -> torch.Tensor:
|
| 313 |
+
"""Generate Mimi embeddings via LSD decoding.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
Generated Mimi embeddings, shape [batch, seq_len, 512]
|
| 320 |
+
"""
|
| 321 |
+
batch, seq_len, _ = cond.shape
|
| 322 |
+
device = cond.device
|
| 323 |
+
dtype = cond.dtype
|
| 324 |
+
|
| 325 |
+
# Handle empty sequences
|
| 326 |
+
if seq_len == 0:
|
| 327 |
+
return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)
|
| 328 |
+
|
| 329 |
+
# Clamp temperature to non-negative to avoid complex numbers from sqrt
|
| 330 |
+
temp = max(0.0, self.temp)
|
| 331 |
+
|
| 332 |
+
latents = []
|
| 333 |
+
for t in range(seq_len):
|
| 334 |
+
cond_t = cond[:, t]
|
| 335 |
+
|
| 336 |
+
# Sample initial noise in 32-dim flow space
|
| 337 |
+
noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
|
| 338 |
+
noise = noise * (temp**0.5)
|
| 339 |
+
|
| 340 |
+
def velocity_fn(cond_fixed, s, t, x):
|
| 341 |
+
return self.flow_net(cond_fixed, s, t, x)
|
| 342 |
+
|
| 343 |
+
conditioned_flow = partial(velocity_fn, cond_t)
|
| 344 |
+
latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
|
| 345 |
+
latents.append(latent)
|
| 346 |
+
|
| 347 |
+
latents = torch.stack(latents, dim=1)
|
| 348 |
+
|
| 349 |
+
# Denormalize if using pretrained pocket-tts normalization
|
| 350 |
+
if self._use_pretrained_normalization:
|
| 351 |
+
latents = latents * self.emb_std + self.emb_mean
|
| 352 |
+
|
| 353 |
+
# Project back to 512-dim Mimi embedding space
|
| 354 |
+
return self.latent_proj_out(latents)
|
| 355 |
+
|
| 356 |
+
def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
"""Decode Mimi latents to audio waveform.
|
| 358 |
+
|
| 359 |
+
Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
|
| 360 |
+
embeddings. We bypass the quantizer and call upsample → decoder_transformer
|
| 361 |
+
→ decoder directly to decode from continuous latents.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
latents: Mimi latents, shape [batch, seq_len, 512]
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
Audio waveform, shape [batch, samples]
|
| 368 |
+
"""
|
| 369 |
+
if self.mimi is None:
|
| 370 |
+
raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")
|
| 371 |
+
|
| 372 |
+
# [batch, seq, 512] → [batch, 512, seq]
|
| 373 |
+
latents = latents.transpose(1, 2)
|
| 374 |
+
|
| 375 |
+
with torch.no_grad():
|
| 376 |
+
# Upsample latents (2x temporal upsampling)
|
| 377 |
+
emb = self.mimi.upsample(latents)
|
| 378 |
+
|
| 379 |
+
# Decoder transformer expects [batch, seq, dim]
|
| 380 |
+
emb = emb.transpose(1, 2)
|
| 381 |
+
decoder_out = self.mimi.decoder_transformer(emb)
|
| 382 |
+
emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])
|
| 383 |
+
|
| 384 |
+
# Final decoder expects [batch, dim, seq]
|
| 385 |
+
emb = emb.transpose(1, 2)
|
| 386 |
+
audio = self.mimi.decoder(emb)
|
| 387 |
+
|
| 388 |
+
return audio.squeeze(1)
|
| 389 |
+
|
| 390 |
+
def get_output_length(self, input_length: int) -> int:
|
| 391 |
+
"""Estimate output audio frames from input hidden state length.
|
| 392 |
+
|
| 393 |
+
For Mimi at 12.5 Hz frame rate with 24kHz audio:
|
| 394 |
+
Each latent frame = 24000 / 12.5 = 1920 audio samples
|
| 395 |
+
"""
|
| 396 |
+
return input_length * 1920
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
{%- set enable_thinking = true -%}
{%- endif -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
{%- set reasoning_mode = "/think" -%}
{%- else -%}
{%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{%- if messages[0].role == "system" -%}
{%- set system_message = messages[0].content -%}
{%- if "/no_think" in system_message -%}
{%- set reasoning_mode = "/no_think" -%}
{%- elif "/think" in system_message -%}
{%- set reasoning_mode = "/think" -%}
{%- endif -%}
{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
{{- custom_instructions.replace("/system_override", "").rstrip() -}}
{{- "<|im_end|>\n" -}}
{%- else -%}
{{- "## Metadata\n\n" -}}
{{- "Knowledge Cutoff Date: June 2025\n" -}}
{%- set today = strftime_now("%d %B %Y") -%}
{{- "Today Date: " ~ today ~ "\n" -}}
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

{{- "## Custom Instructions\n\n" -}}
{%- if custom_instructions -%}
{{- custom_instructions + "\n\n" -}}
{%- elif reasoning_mode == "/think" -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
{%- else -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
{%- endif -%}

{%- if xml_tools or python_tools or tools -%}
{{- "### Tools\n\n" -}}
{%- if xml_tools or tools -%}
{%- if tools -%}
{%- set xml_tools = tools -%}
{%- endif -%}
{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
{%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
{%- endfor -%}
{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
{{- xml_tool_string -}}
{%- endif -%}
{%- if python_tools -%}
{%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
{%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
{%- endfor -%}
{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
{{- python_tool_string -}}
{%- endif -%}
{{- "\n\n" -}}
{%- endif -%}
{# FIX: close the system message unconditionally. Previously "<|im_end|>" was
   emitted inside the tools branch, so conversations without tools never
   closed the system block before the first "<|im_start|>user". #}
{{- "<|im_end|>\n" -}}
{%- endif -%}
{# ───── main loop ───── #}
{%- for message in messages -%}
{%- set content = message.content if message.content is string else "" -%}
{%- if message.role == "user" -%}
{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
{%- elif message.role == "assistant" -%}
{% generation %}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- endif -%}
{% endgeneration %}
{%- elif message.role == "tool" -%}
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
{%- endif -%}
{%- endfor -%}
{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
{%- endif -%}
{%- endif -%}
|
diarization.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
|
| 2 |
+
|
| 3 |
+
Spectral clustering implementation adapted from FunASR/3D-Speaker:
|
| 4 |
+
https://github.com/alibaba-damo-academy/FunASR
|
| 5 |
+
MIT License (https://opensource.org/licenses/MIT)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import scipy
|
| 12 |
+
import sklearn.metrics.pairwise
|
| 13 |
+
import torch
|
| 14 |
+
from sklearn.cluster._kmeans import k_means
|
| 15 |
+
from sklearn.preprocessing import normalize
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_device() -> torch.device:
|
| 19 |
+
"""Get best available device for inference."""
|
| 20 |
+
if torch.cuda.is_available():
|
| 21 |
+
return torch.device("cuda")
|
| 22 |
+
if torch.backends.mps.is_available():
|
| 23 |
+
return torch.device("mps")
|
| 24 |
+
return torch.device("cpu")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SpectralCluster:
    """Spectral clustering using unnormalized Laplacian of affinity matrix.

    Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
    Uses eigenvalue gap to automatically determine number of speakers.
    """

    def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
        # Bounds on the automatically estimated speaker count.
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        # Per-row fraction of the largest similarities kept by p_pruning.
        self.pval = pval

    def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
        """Run spectral clustering on embeddings.

        Args:
            embeddings: Speaker embeddings of shape [N, D]
            oracle_num: Optional known number of speakers

        Returns:
            Cluster labels of shape [N]
        """
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(embeddings)

        # Refining similarity matrix with pval
        # (p_pruning zeroes entries of sim_mat in place and returns it)
        prunned_sim_mat = self.p_pruning(sim_mat)

        # Symmetrization
        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)

        # Laplacian calculation
        laplacian = self.get_laplacian(sym_prund_sim_mat)

        # Get Spectral Embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)

        # Perform clustering
        return self.cluster_embs(emb, num_of_spk)

    def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity matrix of shape [N, N]."""
        return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)

    def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
        """Prune low similarity values in affinity matrix (keep top pval fraction).

        Note: modifies ``affinity`` in place (entries outside each row's top-k
        are set to 0) and returns the same array.
        """
        n = affinity.shape[0]
        # Keep at least ~6 neighbours per row regardless of configured pval.
        pval = max(self.pval, 6.0 / n)
        k_keep = max(1, int(pval * n))

        # Vectorized: find top-k indices per row and zero out the rest
        top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
        mask = np.zeros_like(affinity, dtype=bool)
        np.put_along_axis(mask, top_k_idx, True, axis=1)
        affinity[~mask] = 0
        return affinity

    def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
        """Compute unnormalized Laplacian matrix.

        Note: zeroes the diagonal of ``sim_mat`` in place before computing.
        """
        from scipy.sparse.csgraph import laplacian

        np.fill_diagonal(sim_mat, 0)
        return laplacian(sim_mat, normed=False)

    def get_spec_embs(
        self, laplacian: np.ndarray, k_oracle: int | None = None
    ) -> tuple[np.ndarray, int]:
        """Extract spectral embeddings from Laplacian.

        Uses the eigengap heuristic to estimate the number of clusters:
        The number of clusters k is chosen where the gap between consecutive
        eigenvalues is largest, indicating a transition from "cluster" eigenvalues
        (near 0) to "noise" eigenvalues.

        Returns:
            Tuple of (spectral embeddings [N, k], speaker count k). When
            ``k_oracle`` is given it is used directly and no estimation runs.
        """
        # eigh returns eigenvalues in ascending order with matching eigenvectors.
        lambdas, eig_vecs = scipy.linalg.eigh(laplacian)

        num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)

        # Use the k smallest-eigenvalue eigenvectors as cluster coordinates.
        emb = eig_vecs[:, :num_of_spk]
        return emb, num_of_spk

    def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
        """Estimate number of speakers using refined eigengap heuristic.

        For spectral clustering, we look for the largest gap in eigenvalues.
        The eigenvalues corresponding to clusters are close to 0, and there
        should be a significant jump to the remaining eigenvalues.
        """
        # Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
        # We need gaps between positions, so look at indices 1 to max_num_spks+1
        max_idx = min(self.max_num_spks + 1, len(lambdas))
        relevant_lambdas = lambdas[1:max_idx]  # Skip first eigenvalue

        if len(relevant_lambdas) < 2:
            return self.min_num_spks

        # Compute absolute gaps (not ratios - ratios are unstable near 0)
        gaps = np.diff(relevant_lambdas)

        # Find the largest gap - the index gives us (k-1) since we skipped first
        # Add 1 to convert from gap index to number of speakers
        # Add 1 again because we skipped the first eigenvalue
        max_gap_idx = int(np.argmax(gaps))
        num_of_spk = max_gap_idx + 2  # +1 for gap->count, +1 for skipped eigenvalue

        # Clamp between min and max
        return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))

    def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
        """Cluster spectral embeddings using k-means; returns integer labels [N]."""
        _, labels, _ = k_means(emb, k, n_init=10)
        return labels

    def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
        """Compute gaps between consecutive eigenvalues."""
        return np.diff(eig_vals)
|
| 144 |
+
|
| 145 |
+
class SpeakerClusterer:
|
| 146 |
+
"""Speaker clustering backend using spectral clustering with speaker merging.
|
| 147 |
+
|
| 148 |
+
Features:
|
| 149 |
+
- Spectral clustering with eigenvalue gap for auto speaker count detection
|
| 150 |
+
- P-pruning for affinity matrix refinement
|
| 151 |
+
- Post-clustering speaker merging by cosine similarity
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
min_num_spks: int = 2,
|
| 157 |
+
max_num_spks: int = 10,
|
| 158 |
+
merge_thr: float = 0.90, # Moderate merging
|
| 159 |
+
):
|
| 160 |
+
self.min_num_spks = min_num_spks
|
| 161 |
+
self.max_num_spks = max_num_spks
|
| 162 |
+
self.merge_thr = merge_thr
|
| 163 |
+
self._spectral_cluster: SpectralCluster | None = None
|
| 164 |
+
|
| 165 |
+
def _get_spectral_cluster(self) -> SpectralCluster:
|
| 166 |
+
"""Lazy-load spectral clusterer."""
|
| 167 |
+
if self._spectral_cluster is None:
|
| 168 |
+
self._spectral_cluster = SpectralCluster(
|
| 169 |
+
min_num_spks=self.min_num_spks,
|
| 170 |
+
max_num_spks=self.max_num_spks,
|
| 171 |
+
)
|
| 172 |
+
return self._spectral_cluster
|
| 173 |
+
|
| 174 |
+
def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
|
| 175 |
+
"""Cluster speaker embeddings and return labels.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
embeddings: Speaker embeddings of shape [N, D]
|
| 179 |
+
num_speakers: Optional oracle number of speakers
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Cluster labels of shape [N]
|
| 183 |
+
"""
|
| 184 |
+
import warnings
|
| 185 |
+
|
| 186 |
+
if len(embeddings.shape) != 2:
|
| 187 |
+
raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
|
| 188 |
+
|
| 189 |
+
# Handle edge cases
|
| 190 |
+
if embeddings.shape[0] == 0:
|
| 191 |
+
return np.array([], dtype=int)
|
| 192 |
+
if embeddings.shape[0] == 1:
|
| 193 |
+
return np.array([0], dtype=int)
|
| 194 |
+
if embeddings.shape[0] < 6:
|
| 195 |
+
return np.zeros(embeddings.shape[0], dtype=int)
|
| 196 |
+
|
| 197 |
+
# Normalize embeddings and replace NaN/inf
|
| 198 |
+
embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
|
| 199 |
+
embeddings = normalize(embeddings)
|
| 200 |
+
|
| 201 |
+
# Run spectral clustering (suppress numerical warnings)
|
| 202 |
+
spectral = self._get_spectral_cluster()
|
| 203 |
+
|
| 204 |
+
# Update min/max for oracle case
|
| 205 |
+
if num_speakers is not None:
|
| 206 |
+
spectral.min_num_spks = num_speakers
|
| 207 |
+
spectral.max_num_spks = num_speakers
|
| 208 |
+
|
| 209 |
+
with warnings.catch_warnings():
|
| 210 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
| 211 |
+
labels = spectral(embeddings, oracle_num=num_speakers)
|
| 212 |
+
|
| 213 |
+
# Reset min/max
|
| 214 |
+
if num_speakers is not None:
|
| 215 |
+
spectral.min_num_spks = self.min_num_spks
|
| 216 |
+
spectral.max_num_spks = self.max_num_spks
|
| 217 |
+
|
| 218 |
+
# Merge similar speakers if no oracle
|
| 219 |
+
if num_speakers is None:
|
| 220 |
+
labels = self._merge_by_cos(labels, embeddings, self.merge_thr)
|
| 221 |
+
|
| 222 |
+
# Re-index labels sequentially
|
| 223 |
+
_, labels = np.unique(labels, return_inverse=True)
|
| 224 |
+
|
| 225 |
+
return labels
|
| 226 |
+
|
| 227 |
+
def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
|
| 228 |
+
"""Merge similar speakers by cosine similarity of centroids."""
|
| 229 |
+
from scipy.cluster.hierarchy import fcluster, linkage
|
| 230 |
+
from scipy.spatial.distance import pdist
|
| 231 |
+
|
| 232 |
+
unique_labels = np.unique(labels)
|
| 233 |
+
if len(unique_labels) <= 1:
|
| 234 |
+
return labels
|
| 235 |
+
|
| 236 |
+
# Compute normalized speaker centroids
|
| 237 |
+
centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
|
| 238 |
+
centroids = normalize(centroids)
|
| 239 |
+
|
| 240 |
+
# Hierarchical clustering with cosine distance
|
| 241 |
+
distances = pdist(centroids, metric="cosine")
|
| 242 |
+
linkage_matrix = linkage(distances, method="average")
|
| 243 |
+
merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
|
| 244 |
+
|
| 245 |
+
# Map original labels to merged labels
|
| 246 |
+
label_map = dict(zip(unique_labels, merged_labels))
|
| 247 |
+
return np.array([label_map[lbl] for lbl in labels])
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class LocalSpeakerDiarizer:
    """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.

    Pipeline:
    1. TEN-VAD detects speech segments
    2. Sliding window (0.75s window, 0.15s step = 80% overlap) for uniform embedding extraction
    3. ECAPA-TDNN extracts speaker embeddings per window
    4. Spectral clustering with eigenvalue gap for auto speaker detection
    5. Frame-level consensus voting for segment reconstruction
    6. Post-processing merges short segments to reduce flicker

    Tunable Parameters (class attributes):
    - WINDOW_SIZE: Embedding extraction window size in seconds
    - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
    - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
    - VAD_MIN_DURATION: Minimum speech segment duration
    - VAD_MAX_GAP: Maximum gap to bridge between segments
    - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
    - VOTING_RATE: Frame resolution for consensus voting
    - MIN_SEGMENT_DURATION: Minimum final segment duration
    - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
    - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
    """

    # Lazily-created singletons shared across all calls (see the _get_* helpers).
    _ten_vad_model = None
    _ecapa_model = None
    _device = None

    # ==================== TUNABLE PARAMETERS ====================

    # Sliding window for embedding extraction
    WINDOW_SIZE = 0.75  # seconds - shorter window for finer resolution
    STEP_SIZE = 0.15  # seconds (80% overlap for more votes)
    TAIL_COVERAGE_RATIO = 0.1  # Add extra window if tail > this ratio of window

    # VAD hysteresis parameters
    VAD_THRESHOLD = 0.25  # Balanced threshold
    VAD_MIN_DURATION = 0.05  # Minimum speech segment duration (seconds)
    VAD_MAX_GAP = 0.50  # Bridge gaps shorter than this (seconds)
    VAD_PAD_ONSET = 0.05  # Padding at segment start (seconds)
    VAD_PAD_OFFSET = 0.05  # Padding at segment end (seconds)

    # Frame-level voting
    VOTING_RATE = 0.01  # 10ms resolution for consensus voting

    # Post-processing
    MIN_SEGMENT_DURATION = 0.15  # Minimum final segment duration (seconds)
    SHORT_SEGMENT_GAP = 0.1  # Gap threshold for merging short segments
    SAME_SPEAKER_GAP = 0.5  # Gap threshold for merging same-speaker segments

    # ===========================================================
|
| 301 |
+
|
| 302 |
+
@classmethod
def _get_ten_vad_model(cls):
    """Return the shared TEN-VAD instance, creating it on first use."""
    if cls._ten_vad_model is not None:
        return cls._ten_vad_model

    from ten_vad import TenVad

    # hop_size=256 matches the frame arithmetic in _get_speech_segments.
    cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
    return cls._ten_vad_model
|
| 310 |
+
|
| 311 |
+
@classmethod
def _get_device(cls) -> torch.device:
    """Return (and cache on the class) the preferred inference device."""
    if cls._device is not None:
        return cls._device
    # Delegates to the module-level helper of the same name.
    cls._device = _get_device()
    return cls._device
|
| 317 |
+
|
| 318 |
+
@classmethod
def _get_ecapa_model(cls):
    """Return the shared ECAPA-TDNN speaker encoder, loading it on first use."""
    if cls._ecapa_model is not None:
        return cls._ecapa_model

    # SpeechBrain emits a torchaudio deprecation warning on import; hide it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="torchaudio._backend")
        from speechbrain.inference.speaker import EncoderClassifier

        device = cls._get_device()
        cls._ecapa_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(device)},
        )

    return cls._ecapa_model
|
| 334 |
+
|
| 335 |
+
@classmethod
def diarize(
    cls,
    audio: np.ndarray | str,
    sample_rate: int = 16000,
    num_speakers: int | None = None,
    min_speakers: int = 2,
    max_speakers: int = 10,
    **_kwargs,
) -> list[dict]:
    """Run the full diarization pipeline on a waveform or audio file.

    Args:
        audio: Mono waveform as a numpy array, or a path to an audio file.
        sample_rate: Sample rate of ``audio`` when given as an array;
            anything other than 16 kHz is resampled.
        num_speakers: Exact speaker count if known (bypasses auto-detection).
        min_speakers: Lower bound for automatic speaker counting.
        max_speakers: Upper bound for automatic speaker counting.
        **_kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        List of dicts with 'speaker', 'start', 'end' keys (times in seconds).
    """
    # Load from disk if a path was given; librosa resamples to 16 kHz.
    if isinstance(audio, str):
        import librosa

        audio, sample_rate = librosa.load(audio, sr=16000)

    # Normalize the sample rate to the 16 kHz the models expect.
    if sample_rate != 16000:
        import librosa

        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    audio = audio.astype(np.float32)
    duration_s = len(audio) / sample_rate

    # 1) Voice activity detection (segments + raw per-frame decisions).
    speech_segments, raw_vad = cls._get_speech_segments(audio, sample_rate)
    if not speech_segments:
        return []

    # 2) Sliding-window ECAPA-TDNN embeddings over the speech regions.
    embs, windows = cls._extract_embeddings(audio, speech_segments, sample_rate)
    if len(embs) == 0:
        return []

    # 3) Spectral clustering into speaker identities.
    clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
    speaker_labels = clusterer(embs, num_speakers)

    # 4) Frame-level consensus voting + short-segment cleanup.
    return cls._postprocess_segments(windows, speaker_labels, duration_s, raw_vad)
|
| 389 |
+
|
| 390 |
+
@classmethod
def _get_speech_segments(
    cls, audio_array: np.ndarray, sample_rate: int = 16000
) -> tuple[list[dict], list[bool]]:
    """Get speech segments using TEN-VAD.

    Args:
        audio_array: Mono waveform; float samples in [-1, 1] or already int16.
        sample_rate: Sample rate of ``audio_array``. The 256-sample hop is
            fixed, so callers are expected to pass 16 kHz audio — TODO confirm.

    Returns:
        Tuple of (hysteresis-smoothed segments list, vad_frames list of
        per-frame speech decisions, one bool per 256-sample hop).
    """
    vad_model = cls._get_ten_vad_model()

    # Convert to int16 as required by TEN-VAD
    # Clip to prevent integer overflow
    if audio_array.dtype != np.int16:
        audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
    else:
        audio_int16 = audio_array

    # Process frame by frame; the VAD model is stateful, so frames must be
    # fed strictly in order.  A trailing partial hop (< 256 samples) is dropped.
    hop_size = 256
    frame_duration = hop_size / sample_rate
    speech_frames: list[bool] = []

    for i in range(0, len(audio_int16) - hop_size, hop_size):
        frame = audio_int16[i : i + hop_size]
        _, is_speech = vad_model.process(frame)
        speech_frames.append(is_speech)

    # Convert frame-level decisions to segments by tracking rising/falling edges.
    segments = []
    in_speech = False
    start_idx = 0

    for i, is_speech in enumerate(speech_frames):
        if is_speech and not in_speech:
            # Rising edge: remember where the speech run began.
            start_idx = i
            in_speech = True
        elif not is_speech and in_speech:
            # Falling edge: close the current speech run.
            start_time = start_idx * frame_duration
            end_time = i * frame_duration
            segments.append(
                {
                    "start": start_time,
                    "end": end_time,
                    "start_sample": int(start_time * sample_rate),
                    "end_sample": int(end_time * sample_rate),
                }
            )
            in_speech = False

    # Handle trailing speech (audio ends while still inside a speech run).
    if in_speech:
        start_time = start_idx * frame_duration
        end_time = len(speech_frames) * frame_duration
        segments.append(
            {
                "start": start_time,
                "end": end_time,
                "start_sample": int(start_time * sample_rate),
                "end_sample": int(end_time * sample_rate),
            }
        )

    return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
|
| 454 |
+
|
| 455 |
+
@classmethod
|
| 456 |
+
def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
|
| 457 |
+
"""Apply hysteresis-like post-processing to VAD segments."""
|
| 458 |
+
if not segments:
|
| 459 |
+
return segments
|
| 460 |
+
|
| 461 |
+
segments = sorted(segments, key=lambda x: x["start"])
|
| 462 |
+
|
| 463 |
+
# Fill short gaps
|
| 464 |
+
merged = [segments[0].copy()]
|
| 465 |
+
for seg in segments[1:]:
|
| 466 |
+
gap = seg["start"] - merged[-1]["end"]
|
| 467 |
+
if gap <= cls.VAD_MAX_GAP:
|
| 468 |
+
merged[-1]["end"] = seg["end"]
|
| 469 |
+
merged[-1]["end_sample"] = seg["end_sample"]
|
| 470 |
+
else:
|
| 471 |
+
merged.append(seg.copy())
|
| 472 |
+
|
| 473 |
+
# Remove short segments
|
| 474 |
+
filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
|
| 475 |
+
|
| 476 |
+
# Dilate segments (add padding)
|
| 477 |
+
for seg in filtered:
|
| 478 |
+
seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
|
| 479 |
+
seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
|
| 480 |
+
seg["start_sample"] = int(seg["start"] * sample_rate)
|
| 481 |
+
seg["end_sample"] = int(seg["end"] * sample_rate)
|
| 482 |
+
|
| 483 |
+
return filtered
|
| 484 |
+
|
| 485 |
+
@classmethod
def _extract_embeddings(
    cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
) -> tuple[np.ndarray, list[dict]]:
    """Extract speaker embeddings using sliding windows.

    Each VAD segment is tiled with WINDOW_SIZE-second windows advanced by
    STEP_SIZE; every window produces one ECAPA-TDNN embedding.

    Args:
        audio_array: Mono float waveform.
        segments: VAD segments carrying 'start_sample'/'end_sample' keys.
        sample_rate: Sample rate of ``audio_array``.

    Returns:
        Tuple of (L2-normalized embeddings, parallel list of
        {'start', 'end'} window dicts in seconds). Returns
        (empty array, []) when no valid embedding was produced.
    """
    speaker_model = cls._get_ecapa_model()

    window_samples = int(cls.WINDOW_SIZE * sample_rate)
    step_samples = int(cls.STEP_SIZE * sample_rate)

    embeddings = []
    window_segments = []

    with torch.no_grad():
        for seg in segments:
            seg_start = seg["start_sample"]
            seg_end = seg["end_sample"]
            seg_len = seg_end - seg_start

            # Generate window positions
            if seg_len <= window_samples:
                # Segment shorter than one window: take it whole.
                starts = [seg_start]
                ends = [seg_end]
            else:
                starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
                ends = [s + window_samples for s in starts]

            # Cover tail if > TAIL_COVERAGE_RATIO of window remains
            # (the extra window is aligned to the segment end and may
            # overlap the previous one; for whole-segment windows above,
            # ends[-1] == seg_end so this never triggers).
            if ends and ends[-1] < seg_end:
                remainder = seg_end - ends[-1]
                if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
                    starts.append(seg_end - window_samples)
                    ends.append(seg_end)

            for c_start, c_end in zip(starts, ends):
                chunk = audio_array[c_start:c_end]

                # Pad short chunks with reflection
                # NOTE(review): np.pad(mode="reflect") requires
                # pad_width < len(chunk) - would raise on very short chunks.
                if len(chunk) < window_samples:
                    pad_width = window_samples - len(chunk)
                    chunk = np.pad(chunk, (0, pad_width), mode="reflect")

                # Extract embedding using SpeechBrain's encode_batch
                chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
                embedding = (
                    speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
                )

                # Validate embedding: keep only finite, non-degenerate vectors.
                if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
                    embeddings.append(embedding)
                    window_segments.append(
                        {
                            "start": c_start / sample_rate,
                            "end": c_end / sample_rate,
                        }
                    )

    # Normalize all embeddings at once
    if embeddings:
        return normalize(np.array(embeddings)), window_segments
    return np.array([]), []
|
| 547 |
+
|
| 548 |
+
@classmethod
|
| 549 |
+
def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
|
| 550 |
+
"""Resample VAD frame decisions to match voting grid resolution.
|
| 551 |
+
|
| 552 |
+
VAD operates at 256 samples / 16000 Hz = 16ms per frame.
|
| 553 |
+
Voting operates at VOTING_RATE (default 10ms) per frame.
|
| 554 |
+
This maps VAD decisions to the finer voting grid.
|
| 555 |
+
"""
|
| 556 |
+
if not vad_frames:
|
| 557 |
+
return np.zeros(num_frames, dtype=bool)
|
| 558 |
+
|
| 559 |
+
vad_rate = 256 / 16000 # 16ms per VAD frame
|
| 560 |
+
vad_arr = np.array(vad_frames)
|
| 561 |
+
|
| 562 |
+
# Vectorized: compute VAD frame indices for each voting frame
|
| 563 |
+
voting_times = np.arange(num_frames) * cls.VOTING_RATE
|
| 564 |
+
vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
|
| 565 |
+
return vad_arr[vad_indices]
|
| 566 |
+
|
| 567 |
+
@classmethod
|
| 568 |
+
def _postprocess_segments(
|
| 569 |
+
cls,
|
| 570 |
+
window_segments: list[dict],
|
| 571 |
+
labels: np.ndarray,
|
| 572 |
+
total_duration: float,
|
| 573 |
+
vad_frames: list[bool],
|
| 574 |
+
) -> list[dict]:
|
| 575 |
+
"""Post-process using frame-level consensus voting with VAD-aware silence."""
|
| 576 |
+
if not window_segments or len(labels) == 0:
|
| 577 |
+
return []
|
| 578 |
+
|
| 579 |
+
# Correct labels to be contiguous
|
| 580 |
+
unique_labels = np.unique(labels)
|
| 581 |
+
label_map = {old: new for new, old in enumerate(unique_labels)}
|
| 582 |
+
clean_labels = np.array([label_map[lbl] for lbl in labels])
|
| 583 |
+
num_speakers = len(unique_labels)
|
| 584 |
+
|
| 585 |
+
if num_speakers == 0:
|
| 586 |
+
return []
|
| 587 |
+
|
| 588 |
+
# Create voting grid
|
| 589 |
+
num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
|
| 590 |
+
votes = np.zeros((num_frames, num_speakers), dtype=np.float32)
|
| 591 |
+
|
| 592 |
+
# Accumulate votes
|
| 593 |
+
for win, label in zip(window_segments, clean_labels):
|
| 594 |
+
start_frame = int(win["start"] / cls.VOTING_RATE)
|
| 595 |
+
end_frame = int(win["end"] / cls.VOTING_RATE)
|
| 596 |
+
end_frame = min(end_frame, num_frames)
|
| 597 |
+
if start_frame < end_frame:
|
| 598 |
+
votes[start_frame:end_frame, label] += 1.0
|
| 599 |
+
|
| 600 |
+
# Determine winner per frame
|
| 601 |
+
frame_speakers = np.argmax(votes, axis=1)
|
| 602 |
+
max_votes = np.max(votes, axis=1)
|
| 603 |
+
|
| 604 |
+
# Resample VAD to voting grid resolution for silence-aware voting
|
| 605 |
+
vad_resampled = cls._resample_vad(vad_frames, num_frames)
|
| 606 |
+
|
| 607 |
+
# Convert frames to segments
|
| 608 |
+
final_segments = []
|
| 609 |
+
current_speaker = -1
|
| 610 |
+
seg_start = 0.0
|
| 611 |
+
|
| 612 |
+
for f in range(num_frames):
|
| 613 |
+
speaker = int(frame_speakers[f])
|
| 614 |
+
score = max_votes[f]
|
| 615 |
+
|
| 616 |
+
# Force silence if VAD says no speech OR no votes
|
| 617 |
+
if score == 0 or not vad_resampled[f]:
|
| 618 |
+
speaker = -1
|
| 619 |
+
|
| 620 |
+
if speaker != current_speaker:
|
| 621 |
+
if current_speaker != -1:
|
| 622 |
+
final_segments.append(
|
| 623 |
+
{
|
| 624 |
+
"speaker": f"SPEAKER_{current_speaker}",
|
| 625 |
+
"start": seg_start,
|
| 626 |
+
"end": f * cls.VOTING_RATE,
|
| 627 |
+
}
|
| 628 |
+
)
|
| 629 |
+
current_speaker = speaker
|
| 630 |
+
seg_start = f * cls.VOTING_RATE
|
| 631 |
+
|
| 632 |
+
# Close last segment
|
| 633 |
+
if current_speaker != -1:
|
| 634 |
+
final_segments.append(
|
| 635 |
+
{
|
| 636 |
+
"speaker": f"SPEAKER_{current_speaker}",
|
| 637 |
+
"start": seg_start,
|
| 638 |
+
"end": num_frames * cls.VOTING_RATE,
|
| 639 |
+
}
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
return cls._merge_short_segments(final_segments)
|
| 643 |
+
|
| 644 |
+
@classmethod
|
| 645 |
+
def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
|
| 646 |
+
"""Merge short segments to reduce flicker."""
|
| 647 |
+
if not segments:
|
| 648 |
+
return []
|
| 649 |
+
|
| 650 |
+
clean: list[dict] = []
|
| 651 |
+
for seg in segments:
|
| 652 |
+
dur = seg["end"] - seg["start"]
|
| 653 |
+
if dur < cls.MIN_SEGMENT_DURATION:
|
| 654 |
+
if (
|
| 655 |
+
clean
|
| 656 |
+
and clean[-1]["speaker"] == seg["speaker"]
|
| 657 |
+
and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
|
| 658 |
+
):
|
| 659 |
+
clean[-1]["end"] = seg["end"]
|
| 660 |
+
continue
|
| 661 |
+
|
| 662 |
+
if (
|
| 663 |
+
clean
|
| 664 |
+
and clean[-1]["speaker"] == seg["speaker"]
|
| 665 |
+
and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
|
| 666 |
+
):
|
| 667 |
+
clean[-1]["end"] = seg["end"]
|
| 668 |
+
else:
|
| 669 |
+
clean.append(seg)
|
| 670 |
+
|
| 671 |
+
return clean
|
| 672 |
+
|
| 673 |
+
@classmethod
|
| 674 |
+
def assign_speakers_to_words(
|
| 675 |
+
cls,
|
| 676 |
+
words: list[dict],
|
| 677 |
+
speaker_segments: list[dict],
|
| 678 |
+
) -> list[dict]:
|
| 679 |
+
"""Assign speaker labels to words based on timestamp overlap.
|
| 680 |
+
|
| 681 |
+
Args:
|
| 682 |
+
words: List of word dicts with 'word', 'start', 'end' keys
|
| 683 |
+
speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
|
| 684 |
+
|
| 685 |
+
Returns:
|
| 686 |
+
Words list with 'speaker' key added to each word
|
| 687 |
+
"""
|
| 688 |
+
for word in words:
|
| 689 |
+
word_mid = (word["start"] + word["end"]) / 2
|
| 690 |
+
|
| 691 |
+
# Find the speaker segment that contains this word's midpoint
|
| 692 |
+
best_speaker = None
|
| 693 |
+
for seg in speaker_segments:
|
| 694 |
+
if seg["start"] <= word_mid <= seg["end"]:
|
| 695 |
+
best_speaker = seg["speaker"]
|
| 696 |
+
break
|
| 697 |
+
|
| 698 |
+
# If no exact match, find closest segment
|
| 699 |
+
if best_speaker is None and speaker_segments:
|
| 700 |
+
min_dist = float("inf")
|
| 701 |
+
for seg in speaker_segments:
|
| 702 |
+
seg_mid = (seg["start"] + seg["end"]) / 2
|
| 703 |
+
dist = abs(word_mid - seg_mid)
|
| 704 |
+
if dist < min_dist:
|
| 705 |
+
min_dist = dist
|
| 706 |
+
best_speaker = seg["speaker"]
|
| 707 |
+
|
| 708 |
+
word["speaker"] = best_speaker
|
| 709 |
+
|
| 710 |
+
return words
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
class SpeakerDiarizer:
    """Public facade over :class:`LocalSpeakerDiarizer`.

    Example:
        >>> segments = SpeakerDiarizer.diarize(audio_array)
        >>> for seg in segments:
        ...     print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
    """

    @classmethod
    def diarize(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        **_kwargs,
    ) -> list[dict]:
        """Run speaker diarization on audio.

        Args:
            audio: Audio waveform as numpy array or path to audio file.
            sample_rate: Audio sample rate (default 16000).
            num_speakers: Exact number of speakers (if known).
            min_speakers: Minimum number of speakers (defaults to 2).
            max_speakers: Maximum number of speakers (defaults to 10).

        Returns:
            List of dicts with 'speaker', 'start', 'end' keys.
        """
        # Unset (or falsy) bounds fall back to the local defaults.
        return LocalSpeakerDiarizer.diarize(
            audio,
            sample_rate=sample_rate,
            num_speakers=num_speakers,
            min_speakers=min_speakers or 2,
            max_speakers=max_speakers or 10,
        )

    @classmethod
    def assign_speakers_to_words(
        cls,
        words: list[dict],
        speaker_segments: list[dict],
    ) -> list[dict]:
        """Assign speaker labels to words based on timestamp overlap."""
        return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
|
modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modules for flow matching audio synthesis."""
|
| 2 |
+
|
| 3 |
+
from .mlp import SimpleMLPAdaLN
|
| 4 |
+
|
| 5 |
+
__all__ = ["SimpleMLPAdaLN"]
|
modules/mlp.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow matching MLP with adaptive layer normalization.
|
| 2 |
+
|
| 3 |
+
Adapted from pocket-tts, originally from:
|
| 4 |
+
https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py
|
| 5 |
+
|
| 6 |
+
Reference: https://arxiv.org/abs/2406.11838
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Shift-and-scale modulation used by adaLN: ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RMSNorm(nn.Module):
    """Normalize by ``rsqrt(eps + var)`` along the last dim, with a learnable gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialized to identity.
        self.alpha = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in the variance's dtype, then restore the caller's dtype.
        input_dtype = x.dtype
        variance = self.eps + x.var(dim=-1, keepdim=True)
        gain = self.alpha.to(variance) * torch.rsqrt(variance)
        return (x * gain).to(input_dtype)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class LayerNorm(nn.Module):
    """Layer normalization written out explicitly (supports JVP for flow matching)."""

    def __init__(self, channels: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # Affine parameters exist only when requested; ``forward`` probes
        # for their presence rather than storing a flag.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        normed = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        if hasattr(self, "weight"):
            normed = normed * self.weight + self.bias
        return normed
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to sinusoidal features, then through a small MLP."""

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        max_period: int = 10000,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
            RMSNorm(hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # Geometric frequency ladder, precomputed once and registered as a
        # buffer so it follows the module across devices and dtypes.
        half = frequency_embedding_size // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half) / half)
        self.register_buffer("freqs", freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        phases = t * self.freqs.to(t.dtype)
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        return self.mlp(features)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ResBlock(nn.Module):
    """Residual MLP block whose norm is shifted/scaled and output gated by a condition."""

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.in_ln = LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Conditioning yields per-channel shift/scale for the norm plus a
        # gate applied to the residual branch.
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        branch = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * branch
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class FinalLayer(nn.Module):
    """adaLN-modulated projection head (DiT-style final layer)."""

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # Non-affine norm: the modulation below supplies shift and scale.
        self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class SimpleMLPAdaLN(nn.Module):
    """MLP velocity head for flow matching, conditioned via adaptive LayerNorm.

    Takes conditioning from an AR transformer plus two scalar time
    conditions and predicts the flow velocity for a noisy latent.

    Args:
        in_channels: Input latent dimension (e.g., 256 for Mimi).
        model_channels: Hidden width of the MLP trunk.
        out_channels: Output dimension (same as in_channels for flow matching).
        cond_channels: Conditioning dimension from the LLM.
        num_res_blocks: Number of adaLN residual blocks.
        num_time_conds: Number of time conditions; must be 2 (start, end) for LSD.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        cond_channels: int,
        num_res_blocks: int,
        num_time_conds: int = 2,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.num_time_conds = num_time_conds

        # The LSD formulation fixes the time conditioning to (start, end).
        assert num_time_conds == 2, "LSD requires exactly 2 time conditions (start, end)"

        self.time_embed = nn.ModuleList(
            [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
        )
        self.cond_embed = nn.Linear(cond_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)

    def forward(
        self,
        c: torch.Tensor,
        s: torch.Tensor,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Predict flow velocity.

        Args:
            c: Conditioning from LLM, shape [N, cond_channels].
            s: Start time, shape [N, 1].
            t: Target time, shape [N, 1].
            x: Noisy latent, shape [N, in_channels].

        Returns:
            Predicted velocity, shape [N, out_channels].
        """
        h = self.input_proj(x)

        # Average of the start- and end-time embeddings.
        times = (s, t)
        pooled = sum(emb(times[i]) for i, emb in enumerate(self.time_embed))
        pooled = pooled / self.num_time_conds

        # Fold in the LLM conditioning to form the modulation signal.
        y = pooled + self.cond_embed(c)

        for blk in self.res_blocks:
            h = blk(h, y)

        return self.final_layer(h, y)
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chunk_length": 30,
|
| 3 |
+
"dither": 0.0,
|
| 4 |
+
"feature_extractor_type": "WhisperFeatureExtractor",
|
| 5 |
+
"feature_size": 128,
|
| 6 |
+
"hop_length": 160,
|
| 7 |
+
"n_fft": 400,
|
| 8 |
+
"n_samples": 480000,
|
| 9 |
+
"nb_max_frames": 3000,
|
| 10 |
+
"padding": false,
|
| 11 |
+
"padding_side": "right",
|
| 12 |
+
"padding_value": 0.0,
|
| 13 |
+
"return_attention_mask": false,
|
| 14 |
+
"sampling_rate": 16000,
|
| 15 |
+
"processor_class": "ASRProcessor",
|
| 16 |
+
"auto_map": {
|
| 17 |
+
"AutoProcessor": "asr_processing.ASRProcessor"
|
| 18 |
+
}
|
| 19 |
+
}
|
projectors.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio projector modules for bridging encoder and decoder embeddings.
|
| 2 |
+
|
| 3 |
+
This module contains all projector architectures:
|
| 4 |
+
- MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
|
| 5 |
+
- MOSAProjector: MOSA-style dense mixture of experts
|
| 6 |
+
- SharedMoEAudioProjector: Shared expert + sparse routed experts
|
| 7 |
+
- QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F # noqa: N812
|
| 15 |
+
from transformers import AutoModel, Blip2QFormerConfig
|
| 16 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# MLP Projector
|
| 20 |
+
# =============================================================================
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        """Initialize MLP projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
        """
        super().__init__()

        enc_dim = getattr(config, "encoder_dim", 768)
        out_dim = getattr(config, "llm_dim", 2048)
        self.k = getattr(config, "projector_pool_stride", 4)

        # Frame stacking: k adjacent encoder frames are concatenated along the
        # feature axis before projection, so the input width is enc_dim * k.
        stacked_dim = enc_dim * self.k
        # projector_hidden_dim may override the hidden width; defaults to llm_dim.
        mid_dim = getattr(config, "projector_hidden_dim", None) or out_dim
        self.linear_1 = nn.Linear(stacked_dim, mid_dim, bias=False)
        self.norm = LlamaRMSNorm(mid_dim, eps=1e-6)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(mid_dim, out_dim, bias=False)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches GLM-ASR)."""
        # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features to LLM embedding space.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
        """
        n_batch, n_seq, n_feat = x.shape
        # Drop trailing frames that don't fill a complete k-frame window,
        # then fold each window into one wide feature vector (GLM-ASR behavior).
        n_out = (n_seq - self.k) // self.k + 1
        stacked = x[:, : n_out * self.k, :].reshape(n_batch, n_out, n_feat * self.k)

        hidden = self.act(self.norm(self.linear_1(stacked)))
        return self.linear_2(hidden)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# =============================================================================
|
| 75 |
+
# MoE Projector (MOSA-style)
|
| 76 |
+
# =============================================================================
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class SimpleAdapter(nn.Module):
    """Simple 2-layer GELU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # Attribute names (fc1 / act / fc2) are part of the state_dict layout
        # and are also touched by MoEAudioProjector._init_weights — keep them.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class SwiGLU(nn.Module):
    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        # w1 produces the SiLU gate, w2 the value branch, w3 maps back to `dim`.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class AsymmetricSwiGLU(nn.Module):
    """SwiGLU that handles different input and output dimensions."""

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
    ):
        super().__init__()
        # Same gating scheme as SwiGLU, but the output projection (w3) may
        # change the feature width: in_features -> hidden_features -> out_features.
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(gated)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class MOSAProjector(nn.Module):
    """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.

    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
    Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
    """

    def __init__(self, config):
        """Initialize MOSA projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts
        """
        super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
        self.llm_dim = getattr(config, "llm_dim", None) or 2048
        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
        expert_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
        gate_hidden = getattr(config, "router_hidden_dim", None) or 512

        # --- 1. Conv1d downsampler: two stride-2 convolutions = 4x reduction,
        # widening from encoder_dim to llm_dim in the second conv. ---
        self.downsampler = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )

        # --- 2. Router (MOSA-Base: 2 layers with ReLU), llm_dim -> 512 -> num_experts. ---
        self.router = nn.Sequential(
            nn.Linear(self.llm_dim, gate_hidden),
            nn.ReLU(),
            nn.Linear(gate_hidden, self.num_experts),
        )

        # --- 3. Experts: simple 2-layer GELU adapters, llm_dim -> hidden -> llm_dim. ---
        self.experts = nn.ModuleList(
            SimpleAdapter(self.llm_dim, expert_hidden, self.llm_dim)
            for _ in range(self.num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using mixture of experts.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # Conv1d wants channels-first: [B, S, D] -> [B, D, S] -> conv -> [B, S', D'].
        features = self.downsampler(x.transpose(1, 2)).transpose(1, 2)

        # Dense softmax gating: every expert sees every token.
        gates = F.softmax(self.router(features), dim=-1)  # (B, out_len, num_experts)

        # Run all experts and mix their outputs with the gate weights.
        stacked = torch.stack([expert(features) for expert in self.experts])  # (E, B, S', D)
        return torch.einsum("ebsd, bse -> bsd", stacked, gates)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
        # Each conv (stride 2, kernel 3, padding 1): out = (in + 2 - 3) // 2 + 1.
        # Applied twice for the full 4x reduction.
        length = (input_length + 2 * 1 - 3) // 2 + 1
        return (length + 2 * 1 - 3) // 2 + 1
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# =============================================================================
|
| 199 |
+
# MoE Projector (Pure PyTorch with Shared Expert)
|
| 200 |
+
# =============================================================================
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class MoEAudioProjector(nn.Module):
    """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.

    Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
    No external dependencies (megablocks removed).

    Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
    """

    def __init__(self, config):
        """Initialize MoE projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
        """
        super().__init__()

        # Frame-stacking factor: k adjacent encoder frames are concatenated per token.
        self.k = getattr(config, "projector_pool_stride", 4)
        self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)

        # Stability coefficients
        self.router_z_loss_coef = getattr(
            config, "router_z_loss_coef", 1e-4
        )  # Prevents logit explosion
        self.router_jitter_noise = getattr(
            config, "router_jitter_noise", 0.01
        )  # Prevents expert collapse

        in_dim = config.encoder_dim * self.k
        out_dim = config.llm_dim

        # Expert hidden dim (default = output dim)
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim

        # Number of experts and top-k selection
        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)

        # A. Normalize stacked input (like main branch SharedMoEBlock)
        self.norm = LlamaRMSNorm(in_dim, eps=1e-6)

        # B. Router (operates on stacked input)
        self.router = nn.Linear(in_dim, self.num_experts, bias=False)

        # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
        self.experts = nn.ModuleList(
            [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
        )

        # D. Shared Expert (same architecture)
        self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)

        # E. Initialize weights for stable training
        self._init_weights()

        # NOTE(review): created on CPU; forward() replaces it each call with a
        # tensor on x.device, so a device mismatch only matters before the
        # first forward — confirm how get_aux_loss() is consumed by the trainer.
        self.last_aux_loss = torch.tensor(0.0)

    def _init_weights(self):
        """Initialize weights for stable training start."""
        with torch.no_grad():
            # Router: small weights -> uniform probability
            nn.init.normal_(self.router.weight, mean=0.0, std=0.02)

            # Experts: xavier for fc1, small for fc2 (output)
            for expert in [self.shared_expert, *self.experts]:
                nn.init.xavier_uniform_(expert.fc1.weight)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01)  # Small init

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches MLP projector)."""
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using shared + sparse MoE.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # 1. Frame Stacking: drop trailing frames that don't fill a k-window,
        # then fold each window into one wide feature vector.
        batch, seq, dim = x.shape
        out_len = (seq - self.k) // self.k + 1
        x = x[:, : out_len * self.k, :]
        x = x.reshape(batch, out_len, dim * self.k)

        # 2. Normalize stacked input (like main branch SharedMoEBlock)
        x = self.norm(x)
        flat_x = x.view(-1, x.size(-1))  # [tokens, in_dim]

        # 3. Shared Expert (compute first, creates output tensor)
        output = self.shared_expert(flat_x)

        # 4. Sparse Experts (in-place add to shared output).
        # _forward_sparse mutates `output` via index_add_ and returns the aux loss.
        self.last_aux_loss = self._forward_sparse(flat_x, output)

        return output.view(batch, out_len, -1)

    def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        """Stability-hardened sparse expert dispatch (in-place add to output).

        Args:
            x: Flattened input of shape [tokens, dim]
            output: Output tensor to add sparse expert results into (in-place)

        Returns:
            Auxiliary loss tensor
        """
        # A. Router Logic with Jitter
        logits = self.router(x)  # [tokens, num_experts]

        if self.training and self.router_jitter_noise > 0:
            # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
            # Prevents router from getting stuck on one expert early in training
            noise = torch.empty_like(logits).uniform_(
                1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
            )
            logits = logits * noise

        # Force float32 for softmax (bf16/fp16 exponentials can overflow)
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)

        # B. Top-K Selection
        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)

        # Normalize weights so they sum to 1.0
        # (+1e-6 guards against division by zero if all selected probs underflow)
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        # C. Aux Loss + Z-Loss (only during training; zero tensor at eval)
        aux_loss = torch.tensor(0.0, device=x.device)

        if self.training:
            # Load balancing loss (batch-size invariant)
            prob_per_expert = probs.mean(0)  # [num_experts]
            target = 1.0 / self.num_experts
            balance_loss = (
                self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
            )

            # Z-loss: penalty on large logits to prevent softmax saturation
            z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()

            aux_loss = balance_loss + z_loss

        # D. Dispatch Loop (in-place add to output)
        for i, expert in enumerate(self.experts):
            # Create boolean mask for tokens that selected Expert 'i'
            mask = top_k_indices == i

            if mask.any():
                # token_idx = which tokens, k_idx = 1st or 2nd choice
                token_idx, k_idx = torch.where(mask)

                # Gather inputs and compute
                expert_input = x[token_idx]
                expert_output = expert(expert_input)

                # Apply routing weight
                weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
                weighted_output = (expert_output * weight).type_as(output)

                # Scatter back in-place (index_add_ is atomic and deterministic)
                output.index_add_(0, token_idx, weighted_output)

        return aux_loss

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary load balancing loss."""
        return self.last_aux_loss
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# =============================================================================
|
| 376 |
+
# QFormer Projector (Granite-style)
|
| 377 |
+
# =============================================================================
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class QFormerAudioProjector(nn.Module):
    """
    BLIP-2 QFormer projector with learnable queries.

    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
    query embeddings to compress and project audio encoder outputs. The audio
    sequence is processed in windows and downsampled via cross-attention.
    """

    def __init__(self, config):
        """Initialize QFormer projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
        """
        super().__init__()

        enc_dim = config.encoder_dim
        out_dim = config.llm_dim

        # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
        self.window_size = getattr(config, "qformer_window_size", 15)
        self.downsample_rate = getattr(config, "downsample_rate", 5)
        self.num_queries = self.window_size // self.downsample_rate

        # QFormer hidden size (matches encoder for cross-attention)
        hidden = getattr(config, "qformer_hidden_size", None) or enc_dim
        n_layers = getattr(config, "qformer_num_layers", 2)
        n_heads = getattr(config, "qformer_num_heads", 16)
        inter = getattr(config, "qformer_intermediate_size", None) or (hidden * 4)

        # Learnable query embeddings (Granite uses std=1.0)
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, hidden))
        self.query.data.normal_(mean=0.0, std=1.0)

        # Optional projection if encoder dim != qformer hidden
        self.encoder_proj = (
            nn.Linear(enc_dim, hidden, bias=False) if enc_dim != hidden else None
        )

        # Configure QFormer to match Granite's exact config
        self.qformer = AutoModel.from_config(
            Blip2QFormerConfig(
                hidden_size=hidden,
                num_hidden_layers=n_layers,
                num_attention_heads=n_heads,
                intermediate_size=inter,
                encoder_hidden_size=hidden,
                cross_attention_frequency=1,
                # Granite-specific settings
                hidden_act="gelu",
                attention_probs_dropout_prob=0.1,
                hidden_dropout_prob=0.1,
                layer_norm_eps=1e-12,
                initializer_range=0.02,
            )
        )

        # Final projection to LLM dimension (Granite uses bias=True)
        self.linear = nn.Linear(hidden, out_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # Each full or partial window yields num_queries output tokens.
        return math.ceil(input_length / self.window_size) * self.num_queries

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, encoder_dim]

        Returns:
            projected: [batch_size, num_output_tokens, llm_dim]
        """
        batch_size, seq_len, _ = hidden_states.size()

        # The query parameter defines the working dtype; cast inputs to match.
        if hidden_states.dtype != self.query.dtype:
            hidden_states = hidden_states.to(self.query.dtype)

        # Bridge encoder width to QFormer width when they differ.
        if self.encoder_proj is not None:
            hidden_states = self.encoder_proj(hidden_states)

        # Zero-pad the time axis so it splits into whole windows.
        nblocks = math.ceil(seq_len / self.window_size)
        missing = nblocks * self.window_size - seq_len
        if missing > 0:
            hidden_states = F.pad(hidden_states, (0, 0, 0, missing), "constant", 0)

        # Fold windows into the batch dim: [batch * nblocks, window_size, hidden].
        effective_batch = batch_size * nblocks
        windows = hidden_states.view(effective_batch, self.window_size, -1)

        # Cross-attend the learnable queries against each window.
        queries = self.query.expand(effective_batch, -1, -1)
        qformer_out = self.qformer(
            query_embeds=queries,
            encoder_hidden_states=windows,
            return_dict=True,
        )

        # Unfold back to [batch, nblocks * num_queries, hidden], then map to LLM width.
        compressed = qformer_out.last_hidden_state.view(
            batch_size, nblocks * self.num_queries, -1
        )
        return self.linear(compressed)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# =============================================================================
|
| 497 |
+
# Projector Registry
|
| 498 |
+
# =============================================================================
|
| 499 |
+
|
| 500 |
+
# Registry of the projector architectures defined above, keyed by short name.
# The lookup key presumably comes from the model config — verify against the caller.
PROJECTOR_CLASSES = {
    "mlp": MLPAudioProjector,
    "mosa": MOSAProjector,
    "moe": MoEAudioProjector,
    "qformer": QFormerAudioProjector,
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
|
| 3 |
+
size 17209003
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": null,
|
| 4 |
+
"clean_up_tokenization_spaces": true,
|
| 5 |
+
"eos_token": "<|im_end|>",
|
| 6 |
+
"extra_special_tokens": [
|
| 7 |
+
"<audio>"
|
| 8 |
+
],
|
| 9 |
+
"fast": false,
|
| 10 |
+
"is_local": false,
|
| 11 |
+
"model_input_names": [
|
| 12 |
+
"input_ids",
|
| 13 |
+
"attention_mask"
|
| 14 |
+
],
|
| 15 |
+
"model_max_length": 131072,
|
| 16 |
+
"model_specific_special_tokens": {},
|
| 17 |
+
"pad_token": "<|finetune_right_pad_id|>",
|
| 18 |
+
"tokenizer_class": "TokenizersBackend"
|
| 19 |
+
}
|