Assembled S2S model (base + AudioHead)
Browse files- .gitattributes +1 -0
- README.md +199 -0
- alignment.py +295 -0
- asr_config.py +166 -0
- asr_modeling.py +817 -0
- asr_pipeline.py +370 -0
- asr_processing.py +133 -0
- audio_head.py +357 -0
- chat_template.jinja +94 -0
- config.json +409 -0
- diarization.py +706 -0
- full_duplex.py +475 -0
- model.safetensors +3 -0
- preprocessor_config.json +19 -0
- projectors.py +64 -0
- tokenizer.json +3 -0
- tokenizer_config.json +19 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags: []
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
alignment.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Forced alignment for word-level timestamps using Wav2Vec2."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
|
| 7 |
+
# Calibrated on librispeech-alignments dataset (n=25, MAE=48ms)
|
| 8 |
+
START_OFFSET = 0.04 # Subtract from start times (shift earlier)
|
| 9 |
+
END_OFFSET = -0.04 # Subtract from end times (shift later)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _get_device() -> str:
|
| 13 |
+
"""Get best available device for non-transformers models."""
|
| 14 |
+
if torch.cuda.is_available():
|
| 15 |
+
return "cuda"
|
| 16 |
+
if torch.backends.mps.is_available():
|
| 17 |
+
return "mps"
|
| 18 |
+
return "cpu"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses Viterbi trellis algorithm for optimal alignment path finding.
    """

    # Class-level singleton state, populated lazily by get_instance().
    _bundle = None       # torchaudio pipeline bundle (WAV2VEC2_ASR_BASE_960H)
    _model = None        # the wav2vec2 acoustic model, in eval mode
    _labels = None       # tuple of CTC output labels
    _dictionary = None   # label character -> index mapping

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        Args:
            device: Device to run model on ("cuda" or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            # Imported lazily so torchaudio is only required when alignment is used.
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        # NOTE: on subsequent calls the cached model stays on whatever device it
        # was first created with; the `device` argument is ignored then.
        return cls._model, cls._labels, cls._dictionary

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build trellis for forced alignment using forward algorithm.

        The trellis[t, j] represents the log probability of the best path that
        aligns the first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        # Force alignment to use all tokens by preventing staying in blank
        # at the end when there are still tokens to emit
        # (the +inf acts as a barrier in the zero-token column for the last
        # num_tokens-1 frames — same trick as the torchaudio forced-alignment
        # tutorial).
        if num_tokens > 1:
            trellis[-num_tokens + 1 :, 0] = float("inf")

        for t in range(num_frames):
            for j in range(num_tokens + 1):
                # Stay: emit blank and stay at j tokens
                stay = trellis[t, j] + emission[t, blank_id]

                # Move: emit token j and advance to j+1 tokens
                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")

                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

        return trellis

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float, float]]:
        """Backtrack through trellis to find optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after previous token's
        - No frame skipping or token teleporting

        Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
        The peak_frame is the frame with highest emission probability for that token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # Find the best ending point (should be at num_tokens)
        # But verify trellis reached a valid state
        if trellis[num_frames, num_tokens] == -float("inf"):
            # Alignment failed - fall back to uniform distribution
            # (frame values are floats here, unlike the exact-path branch below).
            frames_per_token = num_frames / num_tokens
            return [
                (
                    tokens[i],
                    i * frames_per_token,
                    (i + 1) * frames_per_token,
                    (i + 0.5) * frames_per_token,
                )
                for i in range(num_tokens)
            ]

        # Backtrack: find where each token transition occurred
        # Store (frame, emission_score) for each token
        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Check: did we transition from j-1 to j at frame t-1?
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Token j-1 was emitted at frame t-1
                # Store frame and its emission probability
                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
                token_frames[j - 1].insert(0, (t - 1, emit_prob))
                j -= 1
            # Always decrement time (monotonic)
            t -= 1

        # Handle any remaining tokens at the start (edge case)
        while j > 0:
            token_frames[j - 1].insert(0, (0, 0.0))
            j -= 1

        # Convert to spans with peak frame
        token_spans: list[tuple[int, float, float, float]] = []
        for token_idx, frames_with_scores in enumerate(token_frames):
            if not frames_with_scores:
                # Token never emitted - assign minimal span after previous
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames_with_scores = [(int(prev_end), 0.0)]
                else:
                    frames_with_scores = [(0, 0.0)]

            token_id = tokens[token_idx]
            frames = [f for f, _ in frames_with_scores]
            start_frame = float(min(frames))
            end_frame = float(max(frames)) + 1.0

            # Find peak frame (highest emission probability)
            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])

            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))

        return token_spans

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Uses Viterbi trellis algorithm for optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array (mono; presumably 1-D float
                samples — TODO confirm against callers)
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance

        # Convert audio to tensor (copy to ensure array is writable)
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        # Ensure 2D (channels, time)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Resample if needed (wav2vec2 expects 16kHz)
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Get emissions from model
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # Normalize text: uppercase, keep only valid characters
        transcript = text.upper()

        # Build tokens from transcript (including word separators)
        # Characters absent from the label set (punctuation etc.) are dropped.
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        # Build Viterbi trellis and backtrack for optimal path
        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
        frame_duration = 320 / cls._bundle.sample_rate

        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
        start_offset = START_OFFSET
        end_offset = END_OFFSET

        # Group aligned tokens into words based on pipe separator
        # Use peak emission frame for more accurate word boundaries
        words = text.split()
        word_timestamps = []
        first_char_peak = None
        last_char_peak = None
        word_idx = 0
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
            if token_id == separator_id:  # Word separator
                if (
                    first_char_peak is not None
                    and last_char_peak is not None
                    and word_idx < len(words)
                ):
                    # Use peak frames for word boundaries
                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                    first_char_peak = None
                    last_char_peak = None
            else:
                if first_char_peak is None:
                    first_char_peak = peak_frame
                last_char_peak = peak_frame

        # Don't forget the last word
        if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
|
asr_config.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import transformers
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model."""

    model_type = "asr_model"
    is_composition = True

    # Generation defaults
    # Merged into kwargs in __init__ (explicit kwargs win), then popped back
    # out as instance attributes before super().__init__ sees the rest.
    GENERATION_DEFAULTS = {
        "num_beams": 1,
        "max_new_tokens": 128,
        "min_new_tokens": 0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "use_cache": True,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }

    def __init__(
        self,
        # Model IDs
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        # Model settings
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        system_prompt: str = "You are a helpful assistant.",
        enable_thinking: bool = False,
        # Encoder settings (auto-detected if None)
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        # Projector settings
        projector_type: str = "mlp",
        projector_pool_stride: int = 4,
        projector_hidden_dim: Optional[int] = None,
        # Training settings (not saved to config.json for inference)
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        freeze_projector: bool = False,
        label_smoothing: float = 0.0,
        # Audio Head settings (trainable AR decoder + NeuCodec)
        use_audio_head: bool = False,
        freeze_audio_head: bool = False,
        max_audio_tokens: int = 500,
        decoder_dim: int = 512,
        decoder_layers: int = 6,
        decoder_heads: int = 8,
        neucodec_model_id: str = "neuphonic/neucodec",
        **kwargs,
    ):
        """Build the composite ASR config.

        Sub-configs (`audio_config`, `text_config`) may arrive via kwargs as
        dicts (when reloaded from config.json) or are fetched from the Hub by
        model id when absent. All remaining kwargs are forwarded to
        ``transformers.PretrainedConfig.__init__``.
        """
        # Merge generation defaults with kwargs (kwargs takes precedence)
        for key, default in self.GENERATION_DEFAULTS.items():
            if key not in kwargs:
                kwargs[key] = default

        # Core model settings
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.enable_thinking = enable_thinking

        # Encoder settings
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate

        # Projector settings
        self.projector_type = projector_type
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim

        # Training settings
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        self.freeze_projector = freeze_projector
        self.label_smoothing = label_smoothing

        # Audio Head settings (trainable AR decoder + NeuCodec)
        self.use_audio_head = use_audio_head
        self.freeze_audio_head = freeze_audio_head
        self.max_audio_tokens = max_audio_tokens
        self.decoder_dim = decoder_dim
        self.decoder_layers = decoder_layers
        self.decoder_heads = decoder_heads
        self.neucodec_model_id = neucodec_model_id

        # Generation parameters (from kwargs after merge with defaults)
        # Popped so they are NOT passed on to super().__init__ below.
        self.num_beams = kwargs.pop("num_beams")
        self.max_new_tokens = kwargs.pop("max_new_tokens")
        self.min_new_tokens = kwargs.pop("min_new_tokens")
        self.repetition_penalty = kwargs.pop("repetition_penalty")
        self.length_penalty = kwargs.pop("length_penalty")
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
        self.use_cache = kwargs.pop("use_cache")
        self.do_sample = kwargs.pop("do_sample")
        self.temperature = kwargs.pop("temperature")
        self.top_p = kwargs.pop("top_p")
        self.top_k = kwargs.pop("top_k")

        # Load sub-configs
        self.audio_config = kwargs.pop("audio_config", None)
        if self.audio_config is None:
            # Not provided: fetch from the Hub by model id (requires network
            # access the first time).
            self.audio_config = transformers.AutoConfig.from_pretrained(
                audio_model_id, trust_remote_code=True
            )
            self.audio_config.dtype = model_dtype
        elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
            # Provided as a plain dict (e.g. deserialized config.json):
            # resolve the concrete config class via AutoConfig, then rebuild.
            # NOTE(review): for_model(...) builds a throwaway instance just to
            # read its class — works, but double construction; verify intended.
            config_class = transformers.AutoConfig.for_model(
                self.audio_config["model_type"]
            ).__class__
            self.audio_config = config_class(**self.audio_config)

        self.text_config = kwargs.pop("text_config", None)
        if self.text_config is None:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            self.text_config.dtype = model_dtype
        elif isinstance(self.text_config, dict):
            # NOTE(review): unlike audio_config above, this branch does not
            # guard on "model_type" being present — a dict without it raises
            # KeyError here; confirm whether that asymmetry is intentional.
            config_class = transformers.AutoConfig.for_model(
                self.text_config["model_type"]
            ).__class__
            self.text_config = config_class(**self.text_config)

        super().__init__(**kwargs)

        # Pipeline configuration
        # Set after super().__init__ so these always override anything the
        # base class restored from kwargs.
        self.encoder = self.audio_config
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


# Make this config discoverable via AutoConfig for model_type "asr_model".
transformers.AutoConfig.register("asr_model", ASRConfig)
|
asr_modeling.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoConfig,
|
| 9 |
+
AutoModel,
|
| 10 |
+
AutoModelForCausalLM,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
PreTrainedModel,
|
| 13 |
+
)
|
| 14 |
+
from transformers.generation import GenerationMixin
|
| 15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from .asr_config import ASRConfig
|
| 19 |
+
from .projectors import PROJECTOR_CLASSES
|
| 20 |
+
except ImportError:
|
| 21 |
+
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 22 |
+
from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
from torchaudio.transforms import SpecAugment
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    # Audio features are the primary input of this multimodal model.
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    # Class-level flags set by from_pretrained() so sub-model loaders know a
    # hub load is in progress (NOTE(review): class-level mutable state — not
    # safe for concurrent from_pretrained calls; confirm single-threaded use).
    _is_loading_from_pretrained: bool = False
    _pretrained_model_path: Optional[str] = None

    # Extra text appended after the <audio> placeholders in the user turn of
    # generation prompts (empty by default; see _build_audio_prompt).
    TRANSCRIBE_PROMPT = ""
|
| 40 |
+
|
| 41 |
+
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
        """Load model from pretrained, handling device placement correctly.

        Builds the model from its config (which reloads the frozen encoder and
        LLM from their original model ids), then overlays only the trainable
        weights stored in this repo's ``model.safetensors``.

        Args:
            pretrained_model_name_or_path: Hub repo id or local path.
            *args: Unused; accepted for signature compatibility.
            **kwargs: Forwarded to ``ASRConfig.from_pretrained`` and
                ``__init__``; ``subfolder`` / ``revision`` are also used to
                resolve the safetensors file.

        Returns:
            The assembled ASRModel with trainable weights restored.
        """
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Set flag to avoid device_map="auto" in sub-model loaders
        cls._is_loading_from_pretrained = True
        cls._pretrained_model_path = pretrained_model_name_or_path

        try:
            model = cls(config, **kwargs)

            # Load projector weights from safetensors
            subfolder = kwargs.get("subfolder")
            revision = kwargs.get("revision")
            cache_kwargs = {}
            if subfolder:
                cache_kwargs["subfolder"] = subfolder
            if revision:
                cache_kwargs["revision"] = revision

            # _raise_exceptions_for_missing_entries=False: a repo without a
            # checkpoint simply yields a freshly initialized model.
            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )

            if model_file is not None:
                state_dict = load_file(model_file)
                # strict=False: the checkpoint holds only the trainable
                # subset (projector / audio head), not the frozen sub-models.
                model.load_state_dict(state_dict, strict=False)

            return model
        finally:
            # Always clear the class-level loading flags, even on failure.
            cls._is_loading_from_pretrained = False
            cls._pretrained_model_path = None
|
| 82 |
+
|
| 83 |
+
    def __init__(self, config: ASRConfig, **kwargs) -> None:
        """Assemble the model: frozen audio encoder + frozen LLM + trainable
        projector (and optional trainable audio head for S2S).

        Args:
            config: ASRConfig carrying sub-model ids, dtype, generation
                defaults, and feature flags (SpecAugment, audio head, ...).
            **kwargs: Accepted for from_pretrained compatibility; unused here.
        """
        super().__init__(config)

        self.system_prompt = config.system_prompt
        # e.g. "bfloat16" -> torch.bfloat16
        target_dtype = getattr(torch, config.model_dtype)

        # Audio encoder (frozen)
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen)
        self.language_model = self._load_language_model(config, target_dtype)

        # Initialize tokenizer and special tokens
        self._init_tokenizer(config)

        # Set up generation config with greedy decoding defaults
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.min_new_tokens = config.min_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = config.do_sample
        # Set sampling params from config (None means use model defaults)
        self.generation_config.temperature = config.temperature
        self.generation_config.top_p = config.top_p
        self.generation_config.top_k = config.top_k
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
        # Set EOS tokens, filtering out any that don't exist in the tokenizer
        eos_candidates = [
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        ]
        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Feature extractor for audio preprocessing
        self.feature_extractor = self._create_feature_extractor(config)

        # Audio projector (trainable unless freeze_projector is set)
        self.projector = self._create_projector(config, target_dtype)

        # Learned padding embedding for audio tokens (used when projector output is short)
        # Using a learned embedding instead of zeros keeps values in the embedding distribution
        self.audio_pad_embedding = nn.Parameter(torch.randn(1, config.llm_dim) * 0.02)

        # Freeze projector if specified
        if getattr(config, "freeze_projector", False):
            self.projector.requires_grad_(False)

        # SpecAugment for data augmentation during training
        if getattr(config, "use_specaugment", False):
            self.spec_augment = SpecAugment(
                n_time_masks=config.num_time_masks,
                time_mask_param=config.time_mask_length,
                n_freq_masks=config.num_freq_masks,
                freq_mask_param=config.freq_mask_length,
            )
        else:
            self.spec_augment = None

        # Audio head for S2S (trainable AR decoder + NeuCodec)
        if getattr(config, "use_audio_head", False):
            # NOTE(review): relative-only import — inconsistent with the
            # try/except relative/absolute fallback used at module top; this
            # will fail when the file is imported as a top-level module.
            from .audio_head import AudioHead, AudioHeadConfig

            device = next(self.language_model.parameters()).device

            audio_head_config = AudioHeadConfig(
                decoder_dim=config.decoder_dim,
                decoder_layers=config.decoder_layers,
                decoder_heads=config.decoder_heads,
                text_vocab_size=len(self.tokenizer),
                max_audio_tokens=config.max_audio_tokens,
                neucodec_model_id=getattr(config, "neucodec_model_id", "neuphonic/neucodec"),
                temperature=getattr(config, "audio_head_temperature", 1.0),
                top_k=getattr(config, "audio_head_top_k", 50),
            )
            self.audio_head = AudioHead(audio_head_config).to(
                device=device, dtype=target_dtype
            )

            if getattr(config, "freeze_audio_head", False):
                self.audio_head.requires_grad_(False)
        else:
            self.audio_head = None

        # Silero VAD for interruption detection (Freeze-Omni style)
        # Loaded lazily on first use to avoid startup cost
        self._vad_model = None
        self._vad_utils = None

        # For model parallelism
        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
|
| 177 |
+
|
| 178 |
+
def _tie_weights(self):
|
| 179 |
+
"""No-op: AudioHead manages its own embeddings."""
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
def _create_feature_extractor(self, config: ASRConfig):
|
| 183 |
+
"""Create the appropriate feature extractor for the audio encoder."""
|
| 184 |
+
from transformers import AutoFeatureExtractor
|
| 185 |
+
|
| 186 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
|
| 187 |
+
# Disable padding by default - use actual audio length
|
| 188 |
+
feature_extractor.padding = False
|
| 189 |
+
return feature_extractor
|
| 190 |
+
|
| 191 |
+
    @classmethod
    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Load and freeze the audio encoder.

        Handles three families by substring match on ``audio_model_id``:
        Whisper (take ``.encoder`` of a WhisperModel), GLM-ASR (take
        ``.audio_tower``), or a generic AutoModel encoder.

        Args:
            cls: Unused beyond classmethod binding.
            config: Configuration providing audio_model_id and attention impl.
            dtype: Weight dtype for the loaded encoder.

        Returns:
            The frozen encoder module in eval mode.
        """
        encoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "low_cpu_mem_usage": True,
            # NOTE(review): passes "torch_dtype" here while
            # _load_language_model passes "dtype" — confirm both spellings
            # are accepted by the pinned transformers version.
            "torch_dtype": dtype,
        }

        if "whisper" in config.audio_model_id.lower():
            from transformers import WhisperModel

            # Only the encoder half is needed; drop the decoder right away.
            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
            encoder = full_model.encoder
            del full_model
        elif "glm" in config.audio_model_id.lower():
            # GLM-ASR models use audio_tower as the encoder
            # Requires transformers >= 5.x or installed from source
            from transformers import AutoModelForSeq2SeqLM

            full_model = AutoModelForSeq2SeqLM.from_pretrained(
                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
            )
            # GLM stores encoder at audio_tower (GlmAsrEncoder)
            encoder = full_model.audio_tower
            # Clear references to free VRAM from the LLM decoder
            full_model.language_model = None
            full_model.multi_modal_projector = None
            del full_model
        else:
            encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

        # The encoder is never trained in this pipeline.
        encoder.requires_grad_(False)
        encoder.eval()
        return encoder
|
| 226 |
+
|
| 227 |
+
@classmethod
|
| 228 |
+
def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
|
| 229 |
+
"""Load and freeze the language model."""
|
| 230 |
+
decoder_kwargs = {
|
| 231 |
+
"attn_implementation": config.attn_implementation,
|
| 232 |
+
"trust_remote_code": True,
|
| 233 |
+
"low_cpu_mem_usage": True,
|
| 234 |
+
"dtype": dtype,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
|
| 238 |
+
decoder.config.use_cache = getattr(config, "use_cache", True)
|
| 239 |
+
decoder.requires_grad_(False)
|
| 240 |
+
decoder.eval()
|
| 241 |
+
return decoder
|
| 242 |
+
|
| 243 |
+
def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
|
| 244 |
+
"""Create the trainable audio projector."""
|
| 245 |
+
# Auto-detect dimensions if not specified
|
| 246 |
+
if config.encoder_dim is None:
|
| 247 |
+
enc_cfg = self.audio_tower.config
|
| 248 |
+
config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
|
| 249 |
+
enc_cfg, "d_model", None
|
| 250 |
+
)
|
| 251 |
+
if config.encoder_dim is None:
|
| 252 |
+
raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
|
| 253 |
+
|
| 254 |
+
if config.llm_dim is None:
|
| 255 |
+
dec_cfg = self.language_model.config
|
| 256 |
+
config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
|
| 257 |
+
dec_cfg, "d_model", None
|
| 258 |
+
)
|
| 259 |
+
if config.llm_dim is None:
|
| 260 |
+
raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
|
| 261 |
+
|
| 262 |
+
# Select projector type based on config
|
| 263 |
+
projector_type = getattr(config, "projector_type", "mlp")
|
| 264 |
+
projector_class = PROJECTOR_CLASSES.get(projector_type)
|
| 265 |
+
if projector_class is None:
|
| 266 |
+
raise ValueError(
|
| 267 |
+
f"Unknown projector_type: {projector_type}. "
|
| 268 |
+
f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
|
| 269 |
+
)
|
| 270 |
+
projector = projector_class(config)
|
| 271 |
+
|
| 272 |
+
# Move projector to same device as language model (important when using quantization)
|
| 273 |
+
device = next(self.language_model.parameters()).device
|
| 274 |
+
return projector.to(device=device, dtype=dtype)
|
| 275 |
+
|
| 276 |
+
    def _init_tokenizer(self, config: ASRConfig):
        """Initialize tokenizer with audio token.

        Adds the ``<audio>`` placeholder special token (resizing the LLM's
        embedding table when newly added) and synchronizes pad/eos/bos ids
        across the relevant configs.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

        # Set pad token
        # Prefer a dedicated pad token over reusing EOS, which would mask
        # real end-of-sequence positions during training.
        if (
            self.tokenizer.pad_token is None
            or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
            self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

        # Add audio token
        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
        if "<audio>" not in existing_special:
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": existing_special + ["<audio>"]}
            )
            # pad_to_multiple_of=64 keeps the embedding matrix GPU-friendly;
            # mean_resizing=False leaves new rows at default initialization.
            self.language_model.resize_token_embeddings(
                len(self.tokenizer), mean_resizing=False, pad_to_multiple_of=64
            )

        self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
        self.tokenizer.padding_side = "right"

        # Sync token IDs to configs
        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
            if cfg is not None:
                cfg.pad_token_id = self.tokenizer.pad_token_id
                cfg.eos_token_id = self.tokenizer.eos_token_id
                cfg.bos_token_id = self.tokenizer.bos_token_id
| 307 |
+
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
|
| 308 |
+
"""Enable/disable gradient checkpointing for the language model."""
|
| 309 |
+
# The LLM still stores activations during forward for backprop to projector
|
| 310 |
+
# Gradient checkpointing trades compute for memory by recomputing activations
|
| 311 |
+
if hasattr(self.language_model, "_set_gradient_checkpointing"):
|
| 312 |
+
self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
|
| 313 |
+
elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
|
| 314 |
+
self.language_model.gradient_checkpointing_enable(
|
| 315 |
+
gradient_checkpointing_kwargs={"use_reentrant": False}
|
| 316 |
+
)
|
| 317 |
+
elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
|
| 318 |
+
self.language_model.gradient_checkpointing_disable()
|
| 319 |
+
|
| 320 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 321 |
+
return self.language_model.get_input_embeddings()
|
| 322 |
+
|
| 323 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 324 |
+
self.language_model.set_input_embeddings(value)
|
| 325 |
+
|
| 326 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 327 |
+
return self.language_model.get_output_embeddings()
|
| 328 |
+
|
| 329 |
+
def set_output_embeddings(self, value: nn.Module) -> None:
|
| 330 |
+
self.language_model.set_output_embeddings(value)
|
| 331 |
+
|
| 332 |
+
def get_processor(self):
|
| 333 |
+
"""Get the processor for this model."""
|
| 334 |
+
try:
|
| 335 |
+
from .asr_processing import ASRProcessor
|
| 336 |
+
except ImportError:
|
| 337 |
+
from asr_processing import ASRProcessor # type: ignore[no-redef]
|
| 338 |
+
|
| 339 |
+
return ASRProcessor(
|
| 340 |
+
feature_extractor=self.feature_extractor,
|
| 341 |
+
tokenizer=self.tokenizer,
|
| 342 |
+
projector=self.projector,
|
| 343 |
+
encoder_conv_layers=self.config.encoder_conv_layers,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# =========================================================================
|
| 347 |
+
# Silero VAD for Interruption Detection (Freeze-Omni style)
|
| 348 |
+
# =========================================================================
|
| 349 |
+
|
| 350 |
+
def load_vad(self, force_reload: bool = False) -> None:
|
| 351 |
+
"""Load Silero VAD model for interruption detection.
|
| 352 |
+
|
| 353 |
+
Silero VAD is a lightweight (~2MB) voice activity detector that runs
|
| 354 |
+
in real-time. Used as the first layer of interruption detection.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
force_reload: Force reload even if already loaded
|
| 358 |
+
"""
|
| 359 |
+
if self._vad_model is not None and not force_reload:
|
| 360 |
+
return
|
| 361 |
+
|
| 362 |
+
model, utils = torch.hub.load(
|
| 363 |
+
repo_or_dir="snakers4/silero-vad",
|
| 364 |
+
model="silero_vad",
|
| 365 |
+
force_reload=force_reload,
|
| 366 |
+
trust_repo=True,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
self._vad_model = model
|
| 370 |
+
self._vad_utils = utils
|
| 371 |
+
|
| 372 |
+
# Freeze VAD model
|
| 373 |
+
self._vad_model.eval()
|
| 374 |
+
for param in self._vad_model.parameters():
|
| 375 |
+
param.requires_grad = False
|
| 376 |
+
|
| 377 |
+
def detect_speech(
|
| 378 |
+
self,
|
| 379 |
+
audio_chunk: torch.Tensor,
|
| 380 |
+
sample_rate: int = 16000,
|
| 381 |
+
threshold: float = 0.5,
|
| 382 |
+
) -> tuple[bool, float]:
|
| 383 |
+
"""Detect speech in an audio chunk using Silero VAD.
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
audio_chunk: Audio waveform [samples] or [1, samples] at sample_rate
|
| 387 |
+
sample_rate: Audio sample rate (default 16kHz)
|
| 388 |
+
threshold: Speech probability threshold (default 0.5)
|
| 389 |
+
|
| 390 |
+
Returns:
|
| 391 |
+
Tuple of (is_speech, probability)
|
| 392 |
+
"""
|
| 393 |
+
if self._vad_model is None:
|
| 394 |
+
self.load_vad()
|
| 395 |
+
|
| 396 |
+
# Ensure 1D tensor
|
| 397 |
+
if audio_chunk.dim() > 1:
|
| 398 |
+
audio_chunk = audio_chunk.squeeze()
|
| 399 |
+
|
| 400 |
+
# VAD expects specific sample rates (8000 or 16000)
|
| 401 |
+
if sample_rate not in (8000, 16000):
|
| 402 |
+
import torchaudio.functional as audio_functional
|
| 403 |
+
|
| 404 |
+
audio_chunk = audio_functional.resample(audio_chunk, sample_rate, 16000)
|
| 405 |
+
sample_rate = 16000
|
| 406 |
+
|
| 407 |
+
# Run VAD
|
| 408 |
+
with torch.no_grad():
|
| 409 |
+
speech_prob = self._vad_model(audio_chunk, sample_rate).item()
|
| 410 |
+
|
| 411 |
+
return speech_prob > threshold, speech_prob
|
| 412 |
+
|
| 413 |
+
def reset_vad_state(self) -> None:
|
| 414 |
+
"""Reset VAD internal state between utterances."""
|
| 415 |
+
if self._vad_model is not None:
|
| 416 |
+
self._vad_model.reset_states()
|
| 417 |
+
|
| 418 |
+
def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
|
| 419 |
+
"""Save trainable weights (projector + audio_head if present)."""
|
| 420 |
+
state = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
|
| 421 |
+
if self.audio_head is not None:
|
| 422 |
+
state.update({f"audio_head.{k}": v for k, v in self.audio_head.state_dict().items()})
|
| 423 |
+
return state
|
| 424 |
+
|
| 425 |
+
def _compute_encoder_output_lengths(
|
| 426 |
+
self,
|
| 427 |
+
audio_attention_mask: torch.Tensor,
|
| 428 |
+
) -> torch.Tensor:
|
| 429 |
+
"""Compute per-sample encoder output lengths using conv layer formulas.
|
| 430 |
+
|
| 431 |
+
Args:
|
| 432 |
+
audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
Tensor of encoder output lengths per sample (batch,)
|
| 436 |
+
"""
|
| 437 |
+
# Get mel frame lengths from attention mask
|
| 438 |
+
lengths = audio_attention_mask.sum(dim=-1)
|
| 439 |
+
|
| 440 |
+
# Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
|
| 441 |
+
for padding, kernel_size, stride in self.config.encoder_conv_layers:
|
| 442 |
+
lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
|
| 443 |
+
|
| 444 |
+
return lengths
|
| 445 |
+
|
| 446 |
+
    def _encode_audio(
        self,
        audio_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        expected_token_counts: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode audio and project to LLM embedding space.

        Args:
            audio_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
            expected_token_counts: Expected number of audio tokens per sample from input_ids.
                If provided, output will match these counts exactly (padding/truncating as needed).

        Returns:
            Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
        """
        # Encoder is frozen, so no gradients flow through it.
        with torch.no_grad():
            encoder_out = self.audio_tower(input_features=audio_features)
            hidden_states = encoder_out.last_hidden_state

        # Project to LLM space
        audio_embeds = self.projector(hidden_states)

        # Use expected token counts if provided (from input_ids), otherwise compute from audio
        if expected_token_counts is not None:
            token_counts = expected_token_counts
        else:
            # Compute per-sample encoder output lengths using conv formulas
            encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
            token_counts = torch.tensor(
                [
                    self.projector.get_output_length(int(length.item()))
                    for length in encoder_lengths
                ],
                device=audio_embeds.device,
            )

        # Extract embeddings matching expected token counts per sample
        batch_size = audio_embeds.shape[0]

        result_embeds = []
        for i in range(batch_size):
            count = int(token_counts[i].item())
            sample_embeds = audio_embeds[i, :count, :]  # Take first 'count' embeddings
            # Pad with learned embedding if we don't have enough embeddings
            # (the slice yields fewer than `count` rows when the projector
            # produced a shorter sequence than input_ids expects).
            if sample_embeds.shape[0] < count:
                pad_count = count - sample_embeds.shape[0]
                padding = self.audio_pad_embedding.expand(pad_count, -1).to(
                    device=audio_embeds.device, dtype=audio_embeds.dtype
                )
                sample_embeds = torch.cat([sample_embeds, padding], dim=0)
            result_embeds.append(sample_embeds)

        # Flat concatenation across the batch; the caller scatters these into
        # the <audio> placeholder positions in order.
        return torch.cat(result_embeds, dim=0)
|
| 501 |
+
|
| 502 |
+
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Forward pass for training and inference.

        Embeds input_ids, splices projected audio embeddings into the
        positions of ``<audio>`` placeholder tokens, and delegates to the
        language model (which computes the LM loss when labels are given).

        Args:
            input_ids: Token ids containing ``<audio>`` placeholders.
            input_features: Mel spectrogram features for the audio encoder.
            audio_attention_mask: Mask of real vs padded mel frames.
            attention_mask, position_ids, past_key_values, use_cache,
                cache_position: Standard causal-LM arguments, passed through.
            inputs_embeds: Precomputed embeddings; skips the embedding step.
            labels: Targets for the LM loss.

        Returns:
            CausalLMOutputWithPast from the underlying language model.
        """
        # Get text embeddings if not provided
        # NOTE(review): assumes input_ids is not None on this path — confirm
        # all callers supply one of input_ids / inputs_embeds.
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        if input_features is not None and input_ids is not None:
            # Apply SpecAugment during training if enabled
            if self.training and self.spec_augment is not None:
                input_features = self.spec_augment(input_features)

            # Count expected audio tokens from input_ids (ground truth from collator)
            audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)

            # Encode audio -> flattened (total_audio_tokens, hidden_dim)
            audio_embeds = self._encode_audio(
                input_features, audio_attention_mask, audio_token_counts
            )

            # Replace <audio> token placeholders with audio embeddings using masked_scatter
            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)

            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device),
                audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
            )

        # Remove TRL-specific keys that shouldn't go to the LLM
        kwargs.pop("prompts", None)
        kwargs.pop("prompt_attention_mask", None)

        # Run through language model (let it compute loss if labels provided)
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        return outputs
|
| 559 |
+
|
| 560 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 561 |
+
"""Prepare inputs for generation, handling audio features for cached decoding."""
|
| 562 |
+
input_features = kwargs.pop("input_features", None)
|
| 563 |
+
cache_position = kwargs.get("cache_position")
|
| 564 |
+
|
| 565 |
+
model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
|
| 566 |
+
|
| 567 |
+
# Only pass audio features on the first generation step (cache_position[0] == 0)
|
| 568 |
+
if cache_position is not None and cache_position[0] == 0 and input_features is not None:
|
| 569 |
+
model_inputs["input_features"] = input_features
|
| 570 |
+
|
| 571 |
+
return model_inputs
|
| 572 |
+
|
| 573 |
+
def _get_num_audio_tokens(
|
| 574 |
+
self,
|
| 575 |
+
audio_attention_mask: torch.Tensor,
|
| 576 |
+
) -> int:
|
| 577 |
+
"""Calculate number of audio tokens based on actual audio length.
|
| 578 |
+
|
| 579 |
+
Uses attention mask to get real audio length, then computes:
|
| 580 |
+
mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
|
| 581 |
+
"""
|
| 582 |
+
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 583 |
+
# Use max length for batch (all samples should have same token count for generation)
|
| 584 |
+
encoder_output_len = int(encoder_lengths.max().item())
|
| 585 |
+
return int(self.projector.get_output_length(encoder_output_len))
|
| 586 |
+
|
| 587 |
+
def _build_audio_prompt(
|
| 588 |
+
self,
|
| 589 |
+
audio_attention_mask: torch.Tensor,
|
| 590 |
+
batch_size: int,
|
| 591 |
+
device: torch.device,
|
| 592 |
+
system_prompt: Optional[str] = None,
|
| 593 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 594 |
+
"""Build input_ids and attention_mask for audio-conditioned generation.
|
| 595 |
+
|
| 596 |
+
Args:
|
| 597 |
+
audio_attention_mask: Mask for real vs padded mel frames
|
| 598 |
+
batch_size: Batch size for expanding single prompts
|
| 599 |
+
device: Device to place tensors on
|
| 600 |
+
system_prompt: Optional system prompt override
|
| 601 |
+
|
| 602 |
+
Returns:
|
| 603 |
+
Tuple of (input_ids, attention_mask) tensors
|
| 604 |
+
"""
|
| 605 |
+
num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
|
| 606 |
+
audio_placeholder = "<audio>" * num_audio_tokens
|
| 607 |
+
|
| 608 |
+
system_prompt = system_prompt or self.system_prompt
|
| 609 |
+
|
| 610 |
+
messages: list[dict[str, str]] = []
|
| 611 |
+
if system_prompt:
|
| 612 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 613 |
+
user_content = audio_placeholder
|
| 614 |
+
if self.TRANSCRIBE_PROMPT:
|
| 615 |
+
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 616 |
+
messages.append({"role": "user", "content": user_content})
|
| 617 |
+
|
| 618 |
+
chat_result = self.tokenizer.apply_chat_template(
|
| 619 |
+
messages,
|
| 620 |
+
tokenize=True,
|
| 621 |
+
add_generation_prompt=True,
|
| 622 |
+
return_tensors="pt",
|
| 623 |
+
enable_thinking=getattr(self.config, "enable_thinking", False),
|
| 624 |
+
)
|
| 625 |
+
input_ids = chat_result.input_ids.to(device)
|
| 626 |
+
|
| 627 |
+
if input_ids.dim() == 1:
|
| 628 |
+
input_ids = input_ids.unsqueeze(0)
|
| 629 |
+
if input_ids.shape[0] == 1 and batch_size > 1:
|
| 630 |
+
input_ids = input_ids.expand(batch_size, -1)
|
| 631 |
+
|
| 632 |
+
return input_ids, torch.ones_like(input_ids)
|
| 633 |
+
|
| 634 |
+
def _inject_audio_embeddings(
|
| 635 |
+
self,
|
| 636 |
+
input_ids: torch.Tensor,
|
| 637 |
+
audio_embeds: torch.Tensor,
|
| 638 |
+
) -> torch.Tensor:
|
| 639 |
+
"""Replace audio token placeholders with actual audio embeddings.
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
input_ids: Token IDs containing <audio> placeholder tokens
|
| 643 |
+
audio_embeds: Encoded audio embeddings to inject
|
| 644 |
+
|
| 645 |
+
Returns:
|
| 646 |
+
Input embeddings with audio tokens replaced by audio embeddings
|
| 647 |
+
"""
|
| 648 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 649 |
+
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 650 |
+
return inputs_embeds.masked_scatter(
|
| 651 |
+
audio_token_mask.to(inputs_embeds.device),
|
| 652 |
+
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    system_prompt: Optional[str] = None,
    **generate_kwargs,
) -> torch.Tensor:
    """Generate a transcription for the given audio.

    Two calling conventions are supported:
    1. ``input_ids`` containing ``<audio>`` placeholder tokens (from the processor).
    2. Audio only — the chat prompt is constructed internally.

    Returns only the newly generated tokens (prompt stripped).
    """
    if input_features is None:
        raise ValueError("input_features required for generation")
    if audio_attention_mask is None:
        raise ValueError("audio_attention_mask required for generation")

    # Audio encoder -> flattened embedding sequence.
    audio_embeds = self._encode_audio(input_features, audio_attention_mask)

    # Build the prompt ourselves when the caller did not supply one; the
    # placeholder count is derived from the real (unpadded) audio length.
    if input_ids is None:
        input_ids, attention_mask = self._build_audio_prompt(
            audio_attention_mask,
            input_features.shape[0],
            input_features.device,
            system_prompt,
        )

    # Swap placeholder embeddings for the encoded audio.
    inputs_embeds = self._inject_audio_embeddings(input_ids, audio_embeds)

    # input_ids are passed alongside inputs_embeds so that repetition_penalty
    # can track which tokens have already been used.
    output = self.language_model.generate(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=self.generation_config,
        **generate_kwargs,
    )

    # With both input_ids and inputs_embeds, generate() returns the full
    # sequence (prompt + continuation); keep only the continuation.
    full_sequences = output if isinstance(output, torch.Tensor) else output.sequences
    prompt_length = input_ids.shape[1]
    return full_sequences[:, prompt_length:]
|
| 707 |
+
|
| 708 |
+
def _process_audio(
|
| 709 |
+
self,
|
| 710 |
+
audio,
|
| 711 |
+
sampling_rate: int = 16000,
|
| 712 |
+
) -> dict[str, torch.Tensor]:
|
| 713 |
+
"""Process raw audio waveform to model inputs."""
|
| 714 |
+
# Convert to numpy if tensor
|
| 715 |
+
if isinstance(audio, torch.Tensor):
|
| 716 |
+
audio = audio.cpu().numpy()
|
| 717 |
+
|
| 718 |
+
# Get mel features from feature extractor
|
| 719 |
+
inputs = self.feature_extractor(
|
| 720 |
+
audio,
|
| 721 |
+
sampling_rate=sampling_rate,
|
| 722 |
+
return_attention_mask=True,
|
| 723 |
+
return_tensors="pt",
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
device = next(self.language_model.parameters()).device
|
| 727 |
+
return {
|
| 728 |
+
"input_features": inputs["input_features"].to(device),
|
| 729 |
+
"attention_mask": inputs["attention_mask"].to(device),
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
    """Write weights, config, tokenizer/feature-extractor, and python sources.

    Args:
        save_directory: Target directory (created if missing).
        **kwargs: Supports ``safe_serialization`` (default True) to choose
            safetensors over a raw ``torch.save`` checkpoint.
    """
    import shutil
    from pathlib import Path as PathlibPath

    out_dir = PathlibPath(save_directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Keep the serialized config in sync with the actual LM vocab size.
    lm_vocab = self.language_model.config.vocab_size
    self.config.vocab_size = lm_vocab
    self.config.text_config.vocab_size = lm_vocab

    if hasattr(self.audio_tower.config, "num_mel_bins"):
        self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

    self.config.save_pretrained(out_dir)

    # Serialize the raw state dict ourselves: HuggingFace's tied-weights
    # handling conflicts with the shared AudioHead embedding.
    weights = self.state_dict()
    if kwargs.get("safe_serialization", True):
        from safetensors.torch import save_file

        save_file(weights, out_dir / "model.safetensors")
    else:
        import torch

        torch.save(weights, out_dir / "pytorch_model.bin")

    self.tokenizer.save_pretrained(out_dir)
    self.feature_extractor.save_pretrained(out_dir)

    # Advertise the custom processor class via preprocessor_config.json so
    # AutoProcessor can resolve it with trust_remote_code.
    config_path = out_dir / "preprocessor_config.json"
    processor_config = {}
    if config_path.exists():
        with config_path.open() as f:
            processor_config = json.load(f)
    processor_config.update(
        {
            "processor_class": "ASRProcessor",
            "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
        }
    )
    with config_path.open("w") as f:
        json.dump(processor_config, f, indent=2)

    # Ship the python sources alongside the weights for auto-loading.
    src_dir = PathlibPath(__file__).parent
    for asr_file in src_dir.glob("asr_*.py"):
        shutil.copy(asr_file, out_dir / asr_file.name)
    for module_name in ("projectors.py", "alignment.py", "diarization.py"):
        shutil.copy(src_dir / module_name, out_dir / module_name)
    # S2S extras are optional and copied only when present.
    for optional_name in ("audio_head.py", "full_duplex.py"):
        optional_path = src_dir / optional_name
        if optional_path.exists():
            shutil.copy(optional_path, out_dir / optional_name)
|
| 804 |
+
|
| 805 |
+
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Upload the model to the HuggingFace Hub.

    Records the destination repo in the config first so the saved config
    points back at its hub location, then defers to the standard upload.
    """
    self.config.pretrained_model_path = repo_id
    return super().push_to_hub(repo_id, **kwargs)
|
| 809 |
+
|
| 810 |
+
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
    """Intentionally a no-op: the repository ships a curated MODEL_CARD.md,
    so the auto-generated card is suppressed."""
    return None
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
# Register with transformers Auto classes
|
| 816 |
+
AutoConfig.register("asr_model", ASRConfig)
|
| 817 |
+
AutoModel.register(ASRConfig, ASRModel)
|
asr_pipeline.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import transformers
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from .alignment import ForcedAligner
|
| 13 |
+
from .asr_modeling import ASRModel
|
| 14 |
+
from .diarization import LocalSpeakerDiarizer
|
| 15 |
+
except ImportError:
|
| 16 |
+
from alignment import ForcedAligner # type: ignore[no-redef]
|
| 17 |
+
from asr_modeling import ASRModel # type: ignore[no-redef]
|
| 18 |
+
from diarization import LocalSpeakerDiarizer # type: ignore[no-redef]
|
| 19 |
+
|
| 20 |
+
# Re-export for backwards compatibility
|
| 21 |
+
__all__ = ["ForcedAligner", "LocalSpeakerDiarizer", "ASRPipeline", "strip_thinking"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def strip_thinking(text: str) -> str:
    """Drop every ``<think>...</think>`` block from model output.

    Args:
        text: Model output text that may contain thinking tags.

    Returns:
        The text with all thinking blocks (and trailing whitespace after
        each) removed, then stripped of surrounding whitespace.
    """
    if not text:
        return text
    cleaned = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
    return cleaned.strip()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    # The model attribute is the custom ASRModel (not a plain PreTrainedModel).
    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        # Fall back to the model's own processor when no extractor is supplied.
        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Scratch slot holding the decoded waveform between __call__ and the
        # alignment/diarization steps; reset to None at the end of each call.
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before parent class validates them."""
        # Remove our custom parameters so parent doesn't see them.
        # __call__ pops most of these already; this also guards re-entrant
        # calls made by the parent pipeline machinery.
        kwargs.pop("return_timestamps", None)
        kwargs.pop("return_speakers", None)
        kwargs.pop("num_speakers", None)
        kwargs.pop("min_speakers", None)
        kwargs.pop("max_speakers", None)
        kwargs.pop("hf_token", None)
        kwargs.pop("user_prompt", None)
        kwargs.pop("system_prompt", None)
        kwargs.pop("diarization_backend", None)
        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            system_prompt: Custom system prompt override (uses model's default if not provided)
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            speaker labels on words if return_speakers=True
        """
        # Extract our params before super().__call__ (which will also call _sanitize_parameters)
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        system_prompt = kwargs.pop("system_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        # Speaker assignment needs word timings, so speakers imply timestamps.
        if return_speakers:
            return_timestamps = True

        # Set custom user prompt if provided.
        # NOTE(review): these overrides mutate shared model state; if
        # super().__call__ raises, they are not restored (no try/finally) —
        # confirm single-threaded, non-raising usage is acceptable.
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        # Set custom system prompt if provided
        original_system_prompt = None
        if system_prompt:
            original_system_prompt = self.model.system_prompt
            self.model.system_prompt = system_prompt

        # Store audio for timestamp alignment and diarization
        if return_timestamps or return_speakers:
            self._current_audio = self._extract_audio(inputs)

        # Run standard transcription
        result = super().__call__(inputs, **kwargs)

        # Add timestamps if requested; alignment failures degrade gracefully
        # to an empty word list plus an error message instead of raising.
        if return_timestamps and self._current_audio is not None:
            text = result.get("text", "")
            if text:
                try:
                    words = ForcedAligner.align(
                        self._current_audio["array"],
                        text,
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                    )
                    result["words"] = words
                except Exception as e:
                    result["words"] = []
                    result["timestamp_error"] = str(e)
            else:
                result["words"] = []

        # Add speaker diarization if requested
        if return_speakers and self._current_audio is not None:
            try:
                # Run diarization
                speaker_segments = LocalSpeakerDiarizer.diarize(
                    self._current_audio["array"],
                    sample_rate=self._current_audio.get("sampling_rate", 16000),
                    **{k: v for k, v in diarization_params.items() if v is not None},
                )
                result["speaker_segments"] = speaker_segments

                # Assign speakers to words
                if result.get("words"):
                    result["words"] = LocalSpeakerDiarizer.assign_speakers_to_words(
                        result["words"],
                        speaker_segments,
                    )
            except Exception as e:
                result["speaker_segments"] = []
                result["diarization_error"] = str(e)

        # Clean up: drop the cached waveform and restore any prompt overrides.
        self._current_audio = None
        if original_prompt is not None:
            self.model.TRANSCRIBE_PROMPT = original_prompt
        if original_system_prompt is not None:
            self.model.system_prompt = original_system_prompt

        return result

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats.

        Supported input formats:
        - str: File path to audio file
        - bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
        - np.ndarray: Audio samples as float32 array
        - dict with "array": Audio samples as numpy array
        - dict with "raw": Alias for "array" (HF pipeline compat)
        - dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")

        For raw PCM bytes, use:
            {"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}
        """
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw_bytes" in inputs:
                # Raw PCM bytes - convert to float32 array
                dtype = inputs.get("dtype", "int16")
                sample_rate = inputs.get("sampling_rate", 16000)
                audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
                # Normalize based on dtype (full-scale int -> [-1, 1] float)
                if dtype == "int16":
                    audio = audio / 32768.0
                elif dtype == "int32":
                    audio = audio / 2147483648.0
                return {"array": audio, "sampling_rate": sample_rate}
        elif isinstance(inputs, str):
            # File path - load audio using ffmpeg (same as HF pipeline)
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            # Assumed to already be samples at 16 kHz — TODO confirm callers resample.
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
        """
        # Handle dict with "array" key (from datasets): the HF base pipeline
        # expects the "raw" key instead.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        # Ensure every yielded chunk carries an is_last flag for _forward.
        for item in super().preprocess(inputs, **preprocess_params):
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters

        Returns:
            Dict with generated token IDs
        """
        # Extract audio features and is_last flag
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription
        """
        # Handle list of outputs (from chunking)
        # NOTE(review): only the first chunk is decoded here — confirm
        # multi-chunk outputs are merged upstream or not used.
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                # Batched output: decode the first sequence only.
                tokens = tokens[0]

        # Filter out eos tokens that the tokenizer doesn't recognize as special
        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                tokens = [t for t in tokens.tolist() if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
        text = strip_thinking(text)
        # Truncate repetitions at end of text
        text = _truncate_repetitions(text)
        return {"text": text}
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|
| 315 |
+
"""Truncate repeated words/phrases/characters at end of text.
|
| 316 |
+
|
| 317 |
+
Detects patterns like:
|
| 318 |
+
- Repeated words: "the the the the" -> "the"
|
| 319 |
+
- Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
|
| 320 |
+
- Repeated characters: "444444" -> "4"
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
text: Input text to process
|
| 324 |
+
min_repeats: Minimum repetitions to trigger truncation (default 3)
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
Text with trailing repetitions removed
|
| 328 |
+
"""
|
| 329 |
+
if not text:
|
| 330 |
+
return text
|
| 331 |
+
|
| 332 |
+
# 1. Truncate repeated characters at end (e.g., "444444" -> "4")
|
| 333 |
+
char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
|
| 334 |
+
text = char_pattern.sub(r"\1", text)
|
| 335 |
+
|
| 336 |
+
# 2. Truncate repeated words at end (e.g., "the the the" -> "the")
|
| 337 |
+
word_pattern = re.compile(
|
| 338 |
+
r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
|
| 339 |
+
)
|
| 340 |
+
while word_pattern.search(text):
|
| 341 |
+
text = word_pattern.sub(r"\1", text)
|
| 342 |
+
|
| 343 |
+
# 3. Truncate repeated phrases (2-20 words) at end
|
| 344 |
+
# e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
|
| 345 |
+
words = text.split()
|
| 346 |
+
if len(words) >= min_repeats * 2:
|
| 347 |
+
# Try phrase lengths from 2 to 20 words
|
| 348 |
+
for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
|
| 349 |
+
# Check if the last phrase_len words repeat
|
| 350 |
+
phrase = " ".join(words[-phrase_len:])
|
| 351 |
+
# Build pattern to match repeated phrases at end
|
| 352 |
+
phrase_escaped = re.escape(phrase)
|
| 353 |
+
phrase_pattern = re.compile(
|
| 354 |
+
r"(^|.*?\s)("
|
| 355 |
+
+ phrase_escaped
|
| 356 |
+
+ r")(?:\s+"
|
| 357 |
+
+ phrase_escaped
|
| 358 |
+
+ r"){"
|
| 359 |
+
+ str(min_repeats - 1)
|
| 360 |
+
+ r",}\s*$",
|
| 361 |
+
re.IGNORECASE,
|
| 362 |
+
)
|
| 363 |
+
match = phrase_pattern.match(text)
|
| 364 |
+
if match:
|
| 365 |
+
# Keep prefix + one instance of the phrase
|
| 366 |
+
text = (match.group(1) + match.group(2)).strip()
|
| 367 |
+
words = text.split()
|
| 368 |
+
break
|
| 369 |
+
|
| 370 |
+
return text
|
asr_processing.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import transformers
|
| 5 |
+
from transformers import ProcessorMixin
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from .asr_config import ASRConfig
|
| 9 |
+
except ImportError:
|
| 10 |
+
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models."""

    # ProcessorMixin wiring: the two sub-processors this class wraps.
    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    # Placeholder token inserted once per projected audio frame.
    AUDIO_TOKEN = "<audio>"
    # Optional instruction appended after the audio placeholders (empty by default).
    TRANSCRIBE_PROMPT = ""
    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
feature_extractor,
|
| 27 |
+
tokenizer,
|
| 28 |
+
projector=None,
|
| 29 |
+
encoder_conv_layers: Optional[list] = None,
|
| 30 |
+
):
|
| 31 |
+
"""Initialize the ASR processor.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
|
| 35 |
+
tokenizer: Text tokenizer for the language model
|
| 36 |
+
projector: Audio projector module (for computing output lengths)
|
| 37 |
+
encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
|
| 38 |
+
"""
|
| 39 |
+
self.feature_extractor = feature_extractor
|
| 40 |
+
self.tokenizer = tokenizer
|
| 41 |
+
self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
|
| 42 |
+
self.projector = projector
|
| 43 |
+
self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
|
| 44 |
+
|
| 45 |
+
def _compute_encoder_output_length(self, mel_length: int) -> int:
|
| 46 |
+
"""Compute encoder output length using conv layer formulas."""
|
| 47 |
+
length = mel_length
|
| 48 |
+
for padding, kernel_size, stride in self.encoder_conv_layers:
|
| 49 |
+
length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
|
| 50 |
+
return length
|
| 51 |
+
|
| 52 |
+
def __call__(
|
| 53 |
+
self,
|
| 54 |
+
audio: Optional[Union[list, "torch.Tensor"]] = None,
|
| 55 |
+
text: Optional[str] = None,
|
| 56 |
+
system_prompt: Optional[str] = None,
|
| 57 |
+
return_tensors: str = "pt",
|
| 58 |
+
**kwargs,
|
| 59 |
+
) -> dict:
|
| 60 |
+
"""Process audio and text inputs for inference.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
audio: Raw audio waveform(s)
|
| 64 |
+
text: Target transcription (optional, for training - but use DataCollator instead)
|
| 65 |
+
system_prompt: Optional system prompt
|
| 66 |
+
return_tensors: Return format ("pt" for PyTorch)
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Dict with input_features, input_ids, attention_mask
|
| 70 |
+
"""
|
| 71 |
+
result = {}
|
| 72 |
+
|
| 73 |
+
# Process audio
|
| 74 |
+
if audio is not None:
|
| 75 |
+
audio_inputs = self.feature_extractor(
|
| 76 |
+
audio,
|
| 77 |
+
sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
|
| 78 |
+
return_attention_mask=True,
|
| 79 |
+
return_tensors=return_tensors,
|
| 80 |
+
**kwargs,
|
| 81 |
+
)
|
| 82 |
+
result["input_features"] = audio_inputs["input_features"]
|
| 83 |
+
result["audio_attention_mask"] = audio_inputs["attention_mask"]
|
| 84 |
+
|
| 85 |
+
# Use actual audio length (from attention mask) for token count
|
| 86 |
+
real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
|
| 87 |
+
encoder_output_len = self._compute_encoder_output_length(real_mel_len)
|
| 88 |
+
num_audio_tokens = self.projector.get_output_length(encoder_output_len)
|
| 89 |
+
else:
|
| 90 |
+
num_audio_tokens = 0
|
| 91 |
+
|
| 92 |
+
# Build prompt with audio token placeholders (instruction-free)
|
| 93 |
+
if num_audio_tokens > 0:
|
| 94 |
+
user_content = self.AUDIO_TOKEN * num_audio_tokens
|
| 95 |
+
if self.TRANSCRIBE_PROMPT:
|
| 96 |
+
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 97 |
+
else:
|
| 98 |
+
user_content = self.TRANSCRIBE_PROMPT or ""
|
| 99 |
+
|
| 100 |
+
messages = []
|
| 101 |
+
if system_prompt:
|
| 102 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 103 |
+
messages.append({"role": "user", "content": user_content})
|
| 104 |
+
if text is not None:
|
| 105 |
+
messages.append({"role": "assistant", "content": text})
|
| 106 |
+
|
| 107 |
+
# Tokenize
|
| 108 |
+
tokenized = self.tokenizer.apply_chat_template(
|
| 109 |
+
messages,
|
| 110 |
+
tokenize=True,
|
| 111 |
+
add_generation_prompt=(text is None),
|
| 112 |
+
return_tensors=return_tensors,
|
| 113 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Handle both tensor and BatchEncoding returns
|
| 117 |
+
if isinstance(tokenized, torch.Tensor):
|
| 118 |
+
input_ids = tokenized
|
| 119 |
+
else:
|
| 120 |
+
# BatchEncoding or dict-like object
|
| 121 |
+
input_ids = tokenized.get("input_ids", tokenized.input_ids)
|
| 122 |
+
|
| 123 |
+
if input_ids.dim() == 1:
|
| 124 |
+
input_ids = input_ids.unsqueeze(0)
|
| 125 |
+
|
| 126 |
+
result["input_ids"] = input_ids
|
| 127 |
+
result["attention_mask"] = torch.ones_like(input_ids)
|
| 128 |
+
|
| 129 |
+
return result
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
ASRProcessor.register_for_auto_class()
|
| 133 |
+
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
|
audio_head.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio head for speech-to-speech using a trainable AR decoder + NeuCodec.
|
| 2 |
+
|
| 3 |
+
Generates audio from text tokens via a trainable LlamaModel decoder:
|
| 4 |
+
Text tokens -> Embedding -> LlamaModel -> head -> NeuCodec FSQ codes -> audio
|
| 5 |
+
|
| 6 |
+
NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec,
|
| 7 |
+
outputting 24kHz audio. No multi-codebook handling needed.
|
| 8 |
+
|
| 9 |
+
Training: S2SDataCollator prepares codec_input_ids/codec_labels (both 2D: [batch, seq_len]).
|
| 10 |
+
AudioHead predicts FSQ codes via a single head with teacher forcing.
|
| 11 |
+
|
| 12 |
+
Inference: Autoregressive generation with KV cache, feeding back predicted codes.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Iterator, Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.nn import functional as F # noqa: N812
|
| 22 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 23 |
+
from transformers.modeling_outputs import ModelOutput
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# NeuCodec FSQ constants (levels=[4]*8, 1 quantizer -> 4^8 = 65536 codes)
NEUCODEC_VOCAB_SIZE = 65536
NEUCODEC_SAMPLE_RATE = 24000  # output audio sample rate in Hz

# Special tokens (above vocab range)
BOS_TOKEN = NEUCODEC_VOCAB_SIZE      # start-of-audio marker fed to the AR decoder
EOS_TOKEN = NEUCODEC_VOCAB_SIZE + 1  # end-of-audio marker; stops generation
PAD_TOKEN = NEUCODEC_VOCAB_SIZE + 2  # padding for batched teacher-forced training
TOTAL_VOCAB = NEUCODEC_VOCAB_SIZE + 3  # 65539
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AudioHeadConfig(PretrainedConfig):
    """Configuration for :class:`AudioHead`.

    Args:
        decoder_dim: Hidden size of the AR decoder backbone.
        decoder_layers: Number of decoder transformer layers.
        decoder_heads: Number of attention heads per layer.
        text_vocab_size: Size of the text-token vocabulary used for the prefix.
        max_audio_tokens: Hard cap on the number of generated codec tokens.
        neucodec_model_id: Hub id of the frozen NeuCodec used for audio decoding.
        temperature: Sampling temperature applied at inference time.
        top_k: Top-k cutoff applied at inference time.
    """

    model_type = "audio_head"

    def __init__(
        self,
        decoder_dim: int = 512,
        decoder_layers: int = 6,
        decoder_heads: int = 8,
        text_vocab_size: int = 32000,
        max_audio_tokens: int = 500,
        neucodec_model_id: str = "neuphonic/neucodec",
        temperature: float = 1.0,
        top_k: int = 50,
        **kwargs,
    ):
        # Architecture of the trainable decoder backbone.
        self.decoder_dim = decoder_dim
        self.decoder_layers = decoder_layers
        self.decoder_heads = decoder_heads
        self.text_vocab_size = text_vocab_size
        # Generation settings.
        self.max_audio_tokens = max_audio_tokens
        self.temperature = temperature
        self.top_k = top_k
        # Frozen codec used only for decoding codes to waveforms.
        self.neucodec_model_id = neucodec_model_id
        super().__init__(**kwargs)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class AudioHeadOutput(ModelOutput):
    """Result of an :class:`AudioHead` forward pass.

    Exactly one field is populated depending on the mode: ``loss`` during
    teacher-forced training, ``codes`` during autoregressive inference.

    Attributes:
        loss: Scalar cross-entropy loss when codec_labels were provided.
        codes: Generated codec codes [batch, gen_len] in inference mode.
    """

    loss: Optional[torch.Tensor] = None
    codes: Optional[torch.Tensor] = None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AudioHead(PreTrainedModel):
    """Trainable AR decoder that predicts NeuCodec FSQ codes.

    NeuCodec uses a single FSQ codebook (4^8 = 65536 codes) at 50 tokens/sec,
    so the head models one flat code sequence — no multi-codebook handling.

    The text prefix attends bidirectionally; codec positions attend causally.
    Output logits are tied to ``token_embedding`` (no separate LM head).
    """

    config_class = AudioHeadConfig

    def __init__(self, config: AudioHeadConfig):
        super().__init__(config)
        self.text_vocab_size = config.text_vocab_size
        self.decoder_dim = config.decoder_dim
        self.max_tokens = config.max_audio_tokens
        self.vocab_size = NEUCODEC_VOCAB_SIZE

        # Embed text tokens to decoder dim
        self.text_embedding = nn.Embedding(config.text_vocab_size, config.decoder_dim)

        # Codec token embedding (FSQ codes + BOS/EOS/PAD special tokens)
        self.token_embedding = nn.Embedding(TOTAL_VOCAB, config.decoder_dim)

        # Small LlamaModel as decoder backbone (from config, NOT pretrained)
        from transformers import LlamaConfig, LlamaModel

        llama_config = LlamaConfig(
            hidden_size=config.decoder_dim,
            intermediate_size=config.decoder_dim * 4,
            num_hidden_layers=config.decoder_layers,
            num_attention_heads=config.decoder_heads,
            vocab_size=TOTAL_VOCAB,
            max_position_embeddings=4096,
        )
        self.decoder = LlamaModel(llama_config)
        # We handle embeddings ourselves; drop the unused table to save memory.
        self.decoder.embed_tokens = None

        # Sampling parameters for inference
        self.temperature = config.temperature
        self.top_k = config.top_k

        # NeuCodec model (loaded lazily, frozen, inference only)
        self.neucodec_model = None

        # Initialize weights
        self.post_init()

    def forward(
        self,
        text_token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        codec_labels: Optional[torch.Tensor] = None,
        codec_input_ids: Optional[torch.Tensor] = None,
        codec_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> AudioHeadOutput:
        """Forward pass for training or inference.

        Args:
            text_token_ids: Text token IDs [batch, seq_len]
            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
            codec_labels: Target codes [batch, audio_len] (-100 for ignore)
            codec_input_ids: Teacher-forced input [batch, audio_len]
            codec_attention_mask: Codec attention mask [batch, audio_len]

        Returns:
            AudioHeadOutput with loss (training) or codes (inference).

        Raises:
            ValueError: If codec_labels is given without codec_input_ids.
        """
        # Embed text tokens (clamp to valid range)
        if (text_token_ids >= self.text_vocab_size).any() or (text_token_ids < 0).any():
            logger.warning(
                "text_token_ids out of range [0, %d): min=%d max=%d. Clamping.",
                self.text_vocab_size, text_token_ids.min().item(), text_token_ids.max().item(),
            )
            text_token_ids = text_token_ids.clamp(0, self.text_vocab_size - 1)
        prefix = self.text_embedding(text_token_ids)  # [batch, text_len, decoder_dim]
        batch_size, text_len, _ = prefix.shape

        if codec_labels is not None:
            # Teacher forcing: codec_input_ids is [batch, audio_len]
            if codec_input_ids is None:
                # Fail fast rather than crashing on None.clamp below.
                raise ValueError(
                    "codec_input_ids must be provided when codec_labels is set "
                    "(teacher forcing requires shifted inputs)."
                )
            cb_input = codec_input_ids
            if (cb_input >= TOTAL_VOCAB).any() or (cb_input < 0).any():
                logger.warning(
                    "codec_input_ids out of range [0, %d): min=%d max=%d. Clamping.",
                    TOTAL_VOCAB, cb_input.min().item(), cb_input.max().item(),
                )
                cb_input = cb_input.clamp(0, TOTAL_VOCAB - 1)
            token_emb = self.token_embedding(cb_input)  # [batch, audio_len, dim]

            audio_len = token_emb.shape[1]

            # Concatenate prefix + codec tokens
            hidden = torch.cat([prefix, token_emb], dim=1)  # [batch, text+audio, dim]

            # Build combined attention mask
            if attention_mask is not None:
                prefix_mask = attention_mask
            else:
                prefix_mask = torch.ones(
                    batch_size, text_len, device=hidden.device, dtype=torch.long
                )

            if codec_attention_mask is not None:
                audio_mask = codec_attention_mask
            else:
                audio_mask = torch.ones(
                    batch_size, audio_len, device=hidden.device, dtype=torch.long
                )

            combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)

            # Causal mask over codec positions; the text prefix attends
            # bidirectionally (its square sub-block is zeroed out).
            total_len = text_len + audio_len
            causal_mask = torch.triu(
                torch.full((total_len, total_len), float("-inf"), device=hidden.device),
                diagonal=1,
            )
            causal_mask[:text_len, :text_len] = 0.0
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

            # Also block attention to padded key positions.
            padding_mask = (1 - combined_mask).bool()
            padding_mask_expanded = padding_mask.unsqueeze(1).unsqueeze(2).expand_as(causal_mask)
            causal_mask = causal_mask.masked_fill(padding_mask_expanded, float("-inf"))

            position_ids = (
                torch.arange(total_len, device=hidden.device).unsqueeze(0).expand(batch_size, -1)
            )

            # Run through LlamaModel
            outputs = self.decoder(
                inputs_embeds=hidden,
                attention_mask=causal_mask,
                position_ids=position_ids,
            )

            # Extract audio positions only
            audio_hidden = outputs.last_hidden_state[:, text_len:]  # [batch, audio_len, dim]

            # Predict codes and compute loss (labels clamped, -100 preserved)
            labels = codec_labels.clone()  # [batch, audio_len]
            valid_mask = labels != -100
            labels[valid_mask] = labels[valid_mask].clamp(0, TOTAL_VOCAB - 1)

            # Weight-tied output projection.
            logits = F.linear(audio_hidden, self.token_embedding.weight)  # [batch, audio_len, total_vocab]
            loss = F.cross_entropy(
                logits.reshape(-1, TOTAL_VOCAB),
                labels.reshape(-1),
                ignore_index=-100,
            )
            return AudioHeadOutput(loss=loss)

        # Inference: autoregressive generation
        codes = self._generate(prefix, attention_mask)
        return AudioHeadOutput(codes=codes)

    def _generate(
        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """AR generation: predict codes one timestep at a time with KV cache.

        NOTE(review): ``prefix_mask`` is currently unused, so a left-padded
        prefix would attend to padding. Safe for batch size 1; confirm before
        using batched padded prefixes here.
        """
        batch_size, text_len, _ = prefix.shape
        device = prefix.device

        all_codes = []

        # Build initial input: prefix + BOS embedding
        bos_token = torch.full((batch_size, 1), BOS_TOKEN, dtype=torch.long, device=device)
        bos_emb = self.token_embedding(bos_token)  # [batch, 1, dim]
        hidden = torch.cat([prefix, bos_emb], dim=1)  # [batch, text_len+1, dim]

        # Position IDs for initial forward (prefix at 0..text_len-1, BOS at text_len)
        position_ids = torch.arange(text_len + 1, device=device).unsqueeze(0).expand(batch_size, -1)

        # Initial forward pass (no KV cache yet)
        outputs = self.decoder(
            inputs_embeds=hidden,
            position_ids=position_ids,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        last_hidden = outputs.last_hidden_state[:, -1:]  # [batch, 1, dim]

        for step in range(self.max_tokens):
            # Predict code token (weight-tied projection)
            logits = F.linear(last_hidden.squeeze(1), self.token_embedding.weight)  # [batch, vocab]

            # Apply temperature and top-k sampling
            if self.temperature > 0 and self.top_k > 0:
                logits = logits / self.temperature
                # Mask out logits below the top-k threshold
                top_k_vals, _ = logits.topk(self.top_k, dim=-1)
                logits[logits < top_k_vals[:, -1:]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                token = torch.multinomial(probs, num_samples=1).squeeze(-1)  # [batch]
            else:
                token = logits.argmax(dim=-1)  # [batch]

            # Stop only when EVERY sequence emits EOS at the same step.
            # NOTE(review): sequences that finish earlier keep accumulating
            # their sampled tokens (including EOS) in the output.
            if (token == EOS_TOKEN).all():
                break

            all_codes.append(token)

            # Feed back prediction for next step
            next_emb = self.token_embedding(token.unsqueeze(1))  # [batch, 1, dim]

            # BUGFIX: generated token `step` occupies position text_len+1+step
            # (prefix fills 0..text_len-1, BOS sits at text_len). The previous
            # `text_len + 1 + step + 1` skipped position text_len+1 and shifted
            # every subsequent RoPE position by one.
            next_pos = torch.full(
                (batch_size, 1), text_len + 1 + step, dtype=torch.long, device=device
            )

            # Forward with KV cache
            outputs = self.decoder(
                inputs_embeds=next_emb,
                position_ids=next_pos,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.last_hidden_state  # [batch, 1, dim]

        if all_codes:
            # [batch, gen_len]
            codes = torch.stack(all_codes, dim=1)
        else:
            codes = torch.empty(batch_size, 0, dtype=torch.long, device=device)

        return codes

    def _load_neucodec(self):
        """Load frozen NeuCodec model for audio decoding (lazy, inference only)."""
        from neucodec import NeuCodec

        self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id)
        self.neucodec_model.eval()
        self.neucodec_model.requires_grad_(False)
        logger.info("Loaded frozen NeuCodec model for audio decoding")

    def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]:
        """Decode NeuCodec FSQ tokens to audio waveforms.

        Args:
            codes: Codec tokens [batch, seq_len]

        Returns:
            List of audio waveform tensors (one per batch item)
        """
        if self.neucodec_model is None:
            self._load_neucodec()
        assert self.neucodec_model is not None

        # NeuCodec decode_code expects [batch, 1, seq_len]
        codes_3d = codes.unsqueeze(1).to(self.neucodec_model.device)

        with torch.no_grad():
            audio_values = self.neucodec_model.decode_code(codes_3d)  # [batch, 1, samples]

        return [audio_values[i, 0] for i in range(audio_values.shape[0])]

    def generate_streaming(
        self,
        text_token_ids: torch.Tensor,
        chunk_samples: int = 24000,
    ) -> Iterator[torch.Tensor]:
        """Generate audio and yield waveform chunks for streaming playback.

        Note: generation and codec decoding run to completion before the first
        chunk is yielded; only playback is chunked.

        Args:
            text_token_ids: Text token IDs [batch, seq_len]
            chunk_samples: Audio samples per chunk (default 1s at 24kHz)

        Yields:
            Audio waveform chunks [samples]
        """
        output = self(text_token_ids)
        codes = output.codes
        audios = self.decode_to_audio(codes)

        for audio in audios:
            for start in range(0, audio.shape[-1], chunk_samples):
                end = min(start + chunk_samples, audio.shape[-1])
                yield audio[..., start:end]
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{# ───── defaults ───── #}
|
| 2 |
+
{%- if enable_thinking is not defined -%}
|
| 3 |
+
{%- set enable_thinking = true -%}
|
| 4 |
+
{%- endif -%}
|
| 5 |
+
|
| 6 |
+
{# ───── reasoning mode ───── #}
|
| 7 |
+
{%- if enable_thinking -%}
|
| 8 |
+
{%- set reasoning_mode = "/think" -%}
|
| 9 |
+
{%- else -%}
|
| 10 |
+
{%- set reasoning_mode = "/no_think" -%}
|
| 11 |
+
{%- endif -%}
|
| 12 |
+
|
| 13 |
+
{# ───── header (system message) ───── #}
|
| 14 |
+
{{- "<|im_start|>system\n" -}}
|
| 15 |
+
|
| 16 |
+
{%- if messages[0].role == "system" -%}
|
| 17 |
+
{%- set system_message = messages[0].content -%}
|
| 18 |
+
{%- if "/no_think" in system_message -%}
|
| 19 |
+
{%- set reasoning_mode = "/no_think" -%}
|
| 20 |
+
{%- elif "/think" in system_message -%}
|
| 21 |
+
{%- set reasoning_mode = "/think" -%}
|
| 22 |
+
{%- endif -%}
|
| 23 |
+
{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
|
| 24 |
+
{%- endif -%}
|
| 25 |
+
|
| 26 |
+
{%- if "/system_override" in system_message -%}
|
| 27 |
+
{{- custom_instructions.replace("/system_override", "").rstrip() -}}
|
| 28 |
+
{{- "<|im_end|>\n" -}}
|
| 29 |
+
{%- else -%}
|
| 30 |
+
{{- "## Metadata\n\n" -}}
|
| 31 |
+
{{- "Knowledge Cutoff Date: June 2025\n" -}}
|
| 32 |
+
{%- set today = strftime_now("%d %B %Y") -%}
|
| 33 |
+
{{- "Today Date: " ~ today ~ "\n" -}}
|
| 34 |
+
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
|
| 35 |
+
|
| 36 |
+
{{- "## Custom Instructions\n\n" -}}
|
| 37 |
+
{%- if custom_instructions -%}
|
| 38 |
+
{{- custom_instructions + "\n\n" -}}
|
| 39 |
+
{%- elif reasoning_mode == "/think" -%}
|
| 40 |
+
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
|
| 41 |
+
{%- else -%}
|
| 42 |
+
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
|
| 43 |
+
{%- endif -%}
|
| 44 |
+
|
| 45 |
+
{%- if xml_tools or python_tools or tools -%}
|
| 46 |
+
{{- "### Tools\n\n" -}}
|
| 47 |
+
{%- if xml_tools or tools -%}
|
| 48 |
+
{%- if tools -%}
|
| 49 |
+
{%- set xml_tools = tools -%}
|
| 50 |
+
{%- endif -%}
|
| 51 |
+
{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
|
| 52 |
+
{%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
|
| 53 |
+
{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
|
| 54 |
+
{%- endfor -%}
|
| 55 |
+
{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
|
| 56 |
+
{{- xml_tool_string -}}
|
| 57 |
+
{%- endif -%}
|
| 58 |
+
{%- if python_tools -%}
|
| 59 |
+
{%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
|
| 60 |
+
{%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
|
| 61 |
+
{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
|
| 62 |
+
{%- endfor -%}
|
| 63 |
+
{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
|
| 64 |
+
{{- python_tool_string -}}
|
| 65 |
+
{%- endif -%}
|
| 66 |
+
{{- "\n\n" -}}
|
| 67 |
+
{{- "<|im_end|>\n" -}}
|
| 68 |
+
{%- endif -%}
|
| 69 |
+
{%- endif -%}
|
| 70 |
+
{# ───── main loop ───── #}
|
| 71 |
+
{%- for message in messages -%}
|
| 72 |
+
{%- set content = message.content if message.content is string else "" -%}
|
| 73 |
+
{%- if message.role == "user" -%}
|
| 74 |
+
{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
|
| 75 |
+
{%- elif message.role == "assistant" -%}
|
| 76 |
+
{% generation %}
|
| 77 |
+
{%- if reasoning_mode == "/think" -%}
|
| 78 |
+
{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
|
| 79 |
+
{%- else -%}
|
| 80 |
+
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
|
| 81 |
+
{%- endif -%}
|
| 82 |
+
{% endgeneration %}
|
| 83 |
+
{%- elif message.role == "tool" -%}
|
| 84 |
+
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
|
| 85 |
+
{%- endif -%}
|
| 86 |
+
{%- endfor -%}
|
| 87 |
+
{# ───── generation prompt ───── #}
|
| 88 |
+
{%- if add_generation_prompt -%}
|
| 89 |
+
{%- if reasoning_mode == "/think" -%}
|
| 90 |
+
{{ "<|im_start|>assistant\n" }}
|
| 91 |
+
{%- else -%}
|
| 92 |
+
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
|
| 93 |
+
{%- endif -%}
|
| 94 |
+
{%- endif -%}
|
config.json
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"ASRModel"
|
| 4 |
+
],
|
| 5 |
+
"attn_implementation": null,
|
| 6 |
+
"audio_config": {
|
| 7 |
+
"_name_or_path": "zai-org/GLM-ASR-Nano-2512",
|
| 8 |
+
"architectures": [
|
| 9 |
+
"GlmAsrForConditionalGeneration"
|
| 10 |
+
],
|
| 11 |
+
"audio_config": {
|
| 12 |
+
"_name_or_path": "",
|
| 13 |
+
"architectures": null,
|
| 14 |
+
"attention_dropout": 0.0,
|
| 15 |
+
"chunk_size_feed_forward": 0,
|
| 16 |
+
"dtype": null,
|
| 17 |
+
"head_dim": 64,
|
| 18 |
+
"hidden_act": "gelu",
|
| 19 |
+
"hidden_size": 1280,
|
| 20 |
+
"id2label": {
|
| 21 |
+
"0": "LABEL_0",
|
| 22 |
+
"1": "LABEL_1"
|
| 23 |
+
},
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"intermediate_size": 5120,
|
| 26 |
+
"is_encoder_decoder": false,
|
| 27 |
+
"label2id": {
|
| 28 |
+
"LABEL_0": 0,
|
| 29 |
+
"LABEL_1": 1
|
| 30 |
+
},
|
| 31 |
+
"max_position_embeddings": 1500,
|
| 32 |
+
"model_type": "glmasr_encoder",
|
| 33 |
+
"num_attention_heads": 20,
|
| 34 |
+
"num_hidden_layers": 32,
|
| 35 |
+
"num_key_value_heads": 20,
|
| 36 |
+
"num_mel_bins": 128,
|
| 37 |
+
"output_attentions": false,
|
| 38 |
+
"output_hidden_states": false,
|
| 39 |
+
"partial_rotary_factor": 0.5,
|
| 40 |
+
"problem_type": null,
|
| 41 |
+
"return_dict": true,
|
| 42 |
+
"rope_parameters": {
|
| 43 |
+
"partial_rotary_factor": 0.5,
|
| 44 |
+
"rope_theta": 10000.0,
|
| 45 |
+
"rope_type": "default"
|
| 46 |
+
}
|
| 47 |
+
},
|
| 48 |
+
"audio_token_id": 59260,
|
| 49 |
+
"dtype": "bfloat16",
|
| 50 |
+
"hidden_size": 2048,
|
| 51 |
+
"model_type": "glmasr",
|
| 52 |
+
"num_mel_bins": 128,
|
| 53 |
+
"projector_hidden_act": "gelu",
|
| 54 |
+
"text_config": {
|
| 55 |
+
"_name_or_path": "",
|
| 56 |
+
"architectures": null,
|
| 57 |
+
"attention_bias": false,
|
| 58 |
+
"attention_dropout": 0.0,
|
| 59 |
+
"bos_token_id": 1,
|
| 60 |
+
"chunk_size_feed_forward": 0,
|
| 61 |
+
"dtype": null,
|
| 62 |
+
"eos_token_id": [
|
| 63 |
+
59246,
|
| 64 |
+
59253,
|
| 65 |
+
59255
|
| 66 |
+
],
|
| 67 |
+
"head_dim": 128,
|
| 68 |
+
"hidden_act": "silu",
|
| 69 |
+
"hidden_size": 2048,
|
| 70 |
+
"id2label": {
|
| 71 |
+
"0": "LABEL_0",
|
| 72 |
+
"1": "LABEL_1"
|
| 73 |
+
},
|
| 74 |
+
"initializer_range": 0.02,
|
| 75 |
+
"intermediate_size": 6144,
|
| 76 |
+
"is_encoder_decoder": false,
|
| 77 |
+
"label2id": {
|
| 78 |
+
"LABEL_0": 0,
|
| 79 |
+
"LABEL_1": 1
|
| 80 |
+
},
|
| 81 |
+
"max_position_embeddings": 8192,
|
| 82 |
+
"mlp_bias": false,
|
| 83 |
+
"model_type": "llama",
|
| 84 |
+
"num_attention_heads": 16,
|
| 85 |
+
"num_hidden_layers": 28,
|
| 86 |
+
"num_key_value_heads": 4,
|
| 87 |
+
"output_attentions": false,
|
| 88 |
+
"output_hidden_states": false,
|
| 89 |
+
"pad_token_id": null,
|
| 90 |
+
"pretraining_tp": 1,
|
| 91 |
+
"problem_type": null,
|
| 92 |
+
"return_dict": true,
|
| 93 |
+
"rms_norm_eps": 1e-05,
|
| 94 |
+
"rope_parameters": {
|
| 95 |
+
"rope_theta": 10000.0,
|
| 96 |
+
"rope_type": "default"
|
| 97 |
+
},
|
| 98 |
+
"tie_word_embeddings": false,
|
| 99 |
+
"use_cache": true,
|
| 100 |
+
"vocab_size": 59264
|
| 101 |
+
},
|
| 102 |
+
"vocab_size": 59264
|
| 103 |
+
},
|
| 104 |
+
"audio_model_id": "zai-org/GLM-ASR-Nano-2512",
|
| 105 |
+
"audio_sample_rate": 16000,
|
| 106 |
+
"auto_map": {
|
| 107 |
+
"AutoConfig": "asr_config.ASRConfig",
|
| 108 |
+
"AutoModel": "asr_modeling.ASRModel",
|
| 109 |
+
"AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
|
| 110 |
+
"AutoProcessor": "asr_processing.ASRProcessor"
|
| 111 |
+
},
|
| 112 |
+
"custom_pipelines": {
|
| 113 |
+
"automatic-speech-recognition": {
|
| 114 |
+
"impl": "asr_pipeline.ASRPipeline",
|
| 115 |
+
"pt": [
|
| 116 |
+
"AutoModelForSpeechSeq2Seq"
|
| 117 |
+
],
|
| 118 |
+
"tf": [],
|
| 119 |
+
"type": "audio"
|
| 120 |
+
}
|
| 121 |
+
},
|
| 122 |
+
"decoder_dim": 256,
|
| 123 |
+
"decoder_heads": 4,
|
| 124 |
+
"decoder_layers": 4,
|
| 125 |
+
"do_sample": false,
|
| 126 |
+
"downsample_rate": 5,
|
| 127 |
+
"dtype": "bfloat16",
|
| 128 |
+
"enable_thinking": false,
|
| 129 |
+
"encoder": {
|
| 130 |
+
"_name_or_path": "zai-org/GLM-ASR-Nano-2512",
|
| 131 |
+
"architectures": [
|
| 132 |
+
"GlmAsrForConditionalGeneration"
|
| 133 |
+
],
|
| 134 |
+
"audio_config": {
|
| 135 |
+
"_name_or_path": "",
|
| 136 |
+
"architectures": null,
|
| 137 |
+
"attention_dropout": 0.0,
|
| 138 |
+
"chunk_size_feed_forward": 0,
|
| 139 |
+
"dtype": null,
|
| 140 |
+
"head_dim": 64,
|
| 141 |
+
"hidden_act": "gelu",
|
| 142 |
+
"hidden_size": 1280,
|
| 143 |
+
"id2label": {
|
| 144 |
+
"0": "LABEL_0",
|
| 145 |
+
"1": "LABEL_1"
|
| 146 |
+
},
|
| 147 |
+
"initializer_range": 0.02,
|
| 148 |
+
"intermediate_size": 5120,
|
| 149 |
+
"is_encoder_decoder": false,
|
| 150 |
+
"label2id": {
|
| 151 |
+
"LABEL_0": 0,
|
| 152 |
+
"LABEL_1": 1
|
| 153 |
+
},
|
| 154 |
+
"max_position_embeddings": 1500,
|
| 155 |
+
"model_type": "glmasr_encoder",
|
| 156 |
+
"num_attention_heads": 20,
|
| 157 |
+
"num_hidden_layers": 32,
|
| 158 |
+
"num_key_value_heads": 20,
|
| 159 |
+
"num_mel_bins": 128,
|
| 160 |
+
"output_attentions": false,
|
| 161 |
+
"output_hidden_states": false,
|
| 162 |
+
"partial_rotary_factor": 0.5,
|
| 163 |
+
"problem_type": null,
|
| 164 |
+
"return_dict": true,
|
| 165 |
+
"rope_parameters": {
|
| 166 |
+
"partial_rotary_factor": 0.5,
|
| 167 |
+
"rope_theta": 10000.0,
|
| 168 |
+
"rope_type": "default"
|
| 169 |
+
}
|
| 170 |
+
},
|
| 171 |
+
"audio_token_id": 59260,
|
| 172 |
+
"dtype": "bfloat16",
|
| 173 |
+
"hidden_size": 2048,
|
| 174 |
+
"model_type": "glmasr",
|
| 175 |
+
"num_mel_bins": 128,
|
| 176 |
+
"projector_hidden_act": "gelu",
|
| 177 |
+
"text_config": {
|
| 178 |
+
"_name_or_path": "",
|
| 179 |
+
"architectures": null,
|
| 180 |
+
"attention_bias": false,
|
| 181 |
+
"attention_dropout": 0.0,
|
| 182 |
+
"bos_token_id": 1,
|
| 183 |
+
"chunk_size_feed_forward": 0,
|
| 184 |
+
"dtype": null,
|
| 185 |
+
"eos_token_id": [
|
| 186 |
+
59246,
|
| 187 |
+
59253,
|
| 188 |
+
59255
|
| 189 |
+
],
|
| 190 |
+
"head_dim": 128,
|
| 191 |
+
"hidden_act": "silu",
|
| 192 |
+
"hidden_size": 2048,
|
| 193 |
+
"id2label": {
|
| 194 |
+
"0": "LABEL_0",
|
| 195 |
+
"1": "LABEL_1"
|
| 196 |
+
},
|
| 197 |
+
"initializer_range": 0.02,
|
| 198 |
+
"intermediate_size": 6144,
|
| 199 |
+
"is_encoder_decoder": false,
|
| 200 |
+
"label2id": {
|
| 201 |
+
"LABEL_0": 0,
|
| 202 |
+
"LABEL_1": 1
|
| 203 |
+
},
|
| 204 |
+
"max_position_embeddings": 8192,
|
| 205 |
+
"mlp_bias": false,
|
| 206 |
+
"model_type": "llama",
|
| 207 |
+
"num_attention_heads": 16,
|
| 208 |
+
"num_hidden_layers": 28,
|
| 209 |
+
"num_key_value_heads": 4,
|
| 210 |
+
"output_attentions": false,
|
| 211 |
+
"output_hidden_states": false,
|
| 212 |
+
"pad_token_id": null,
|
| 213 |
+
"pretraining_tp": 1,
|
| 214 |
+
"problem_type": null,
|
| 215 |
+
"return_dict": true,
|
| 216 |
+
"rms_norm_eps": 1e-05,
|
| 217 |
+
"rope_parameters": {
|
| 218 |
+
"rope_theta": 10000.0,
|
| 219 |
+
"rope_type": "default"
|
| 220 |
+
},
|
| 221 |
+
"tie_word_embeddings": false,
|
| 222 |
+
"use_cache": true,
|
| 223 |
+
"vocab_size": 59264
|
| 224 |
+
},
|
| 225 |
+
"vocab_size": 59264
|
| 226 |
+
},
|
| 227 |
+
"encoder_conv_layers": [
|
| 228 |
+
[
|
| 229 |
+
1,
|
| 230 |
+
3,
|
| 231 |
+
1
|
| 232 |
+
],
|
| 233 |
+
[
|
| 234 |
+
1,
|
| 235 |
+
3,
|
| 236 |
+
2
|
| 237 |
+
]
|
| 238 |
+
],
|
| 239 |
+
"encoder_dim": 1280,
|
| 240 |
+
"freeze_audio_head": false,
|
| 241 |
+
"freeze_projector": false,
|
| 242 |
+
"freq_mask_length": 27,
|
| 243 |
+
"label_smoothing": 0.0,
|
| 244 |
+
"length_penalty": 1.0,
|
| 245 |
+
"llm_dim": 2048,
|
| 246 |
+
"lora_alpha": 32,
|
| 247 |
+
"lora_dropout": 0.0,
|
| 248 |
+
"lora_rank": 8,
|
| 249 |
+
"lora_target_modules": [
|
| 250 |
+
"q_proj",
|
| 251 |
+
"k_proj",
|
| 252 |
+
"v_proj",
|
| 253 |
+
"o_proj",
|
| 254 |
+
"gate_proj",
|
| 255 |
+
"up_proj",
|
| 256 |
+
"down_proj"
|
| 257 |
+
],
|
| 258 |
+
"max_audio_tokens": 500,
|
| 259 |
+
"max_new_tokens": 128,
|
| 260 |
+
"min_new_tokens": 0,
|
| 261 |
+
"model_dtype": "bfloat16",
|
| 262 |
+
"model_type": "asr_model",
|
| 263 |
+
"neucodec_model_id": "neuphonic/neucodec",
|
| 264 |
+
"no_repeat_ngram_size": 0,
|
| 265 |
+
"num_beams": 1,
|
| 266 |
+
"num_experts": 4,
|
| 267 |
+
"num_experts_per_tok": 2,
|
| 268 |
+
"num_freq_masks": 2,
|
| 269 |
+
"num_time_masks": 2,
|
| 270 |
+
"pipeline_tag": "automatic-speech-recognition",
|
| 271 |
+
"pretrained_model_path": "mazesmazes/tiny-audio-s2s-full",
|
| 272 |
+
"projector_dropout": 0.0,
|
| 273 |
+
"projector_hidden_dim": 1024,
|
| 274 |
+
"projector_init_std": 0.02,
|
| 275 |
+
"projector_num_layers": 2,
|
| 276 |
+
"projector_pool_stride": 4,
|
| 277 |
+
"projector_type": "mlp",
|
| 278 |
+
"qformer_hidden_size": null,
|
| 279 |
+
"qformer_intermediate_size": null,
|
| 280 |
+
"qformer_num_heads": 16,
|
| 281 |
+
"qformer_num_layers": 2,
|
| 282 |
+
"qformer_window_size": 15,
|
| 283 |
+
"repetition_penalty": 1.1,
|
| 284 |
+
"router_aux_loss_coef": 0.01,
|
| 285 |
+
"system_prompt": "",
|
| 286 |
+
"temperature": 1.0,
|
| 287 |
+
"text_config": {
|
| 288 |
+
"_name_or_path": "HuggingFaceTB/SmolLM3-3B",
|
| 289 |
+
"architectures": [
|
| 290 |
+
"SmolLM3ForCausalLM"
|
| 291 |
+
],
|
| 292 |
+
"attention_bias": false,
|
| 293 |
+
"attention_dropout": 0.0,
|
| 294 |
+
"bos_token_id": null,
|
| 295 |
+
"dtype": "bfloat16",
|
| 296 |
+
"eos_token_id": 128012,
|
| 297 |
+
"hidden_act": "silu",
|
| 298 |
+
"hidden_size": 2048,
|
| 299 |
+
"initializer_range": 0.02,
|
| 300 |
+
"intermediate_size": 11008,
|
| 301 |
+
"layer_types": [
|
| 302 |
+
"full_attention",
|
| 303 |
+
"full_attention",
|
| 304 |
+
"full_attention",
|
| 305 |
+
"full_attention",
|
| 306 |
+
"full_attention",
|
| 307 |
+
"full_attention",
|
| 308 |
+
"full_attention",
|
| 309 |
+
"full_attention",
|
| 310 |
+
"full_attention",
|
| 311 |
+
"full_attention",
|
| 312 |
+
"full_attention",
|
| 313 |
+
"full_attention",
|
| 314 |
+
"full_attention",
|
| 315 |
+
"full_attention",
|
| 316 |
+
"full_attention",
|
| 317 |
+
"full_attention",
|
| 318 |
+
"full_attention",
|
| 319 |
+
"full_attention",
|
| 320 |
+
"full_attention",
|
| 321 |
+
"full_attention",
|
| 322 |
+
"full_attention",
|
| 323 |
+
"full_attention",
|
| 324 |
+
"full_attention",
|
| 325 |
+
"full_attention",
|
| 326 |
+
"full_attention",
|
| 327 |
+
"full_attention",
|
| 328 |
+
"full_attention",
|
| 329 |
+
"full_attention",
|
| 330 |
+
"full_attention",
|
| 331 |
+
"full_attention",
|
| 332 |
+
"full_attention",
|
| 333 |
+
"full_attention",
|
| 334 |
+
"full_attention",
|
| 335 |
+
"full_attention",
|
| 336 |
+
"full_attention",
|
| 337 |
+
"full_attention"
|
| 338 |
+
],
|
| 339 |
+
"max_position_embeddings": 65536,
|
| 340 |
+
"max_window_layers": 28,
|
| 341 |
+
"mlp_bias": false,
|
| 342 |
+
"model_type": "smollm3",
|
| 343 |
+
"no_rope_layer_interval": 4,
|
| 344 |
+
"no_rope_layers": [
|
| 345 |
+
1,
|
| 346 |
+
1,
|
| 347 |
+
1,
|
| 348 |
+
0,
|
| 349 |
+
1,
|
| 350 |
+
1,
|
| 351 |
+
1,
|
| 352 |
+
0,
|
| 353 |
+
1,
|
| 354 |
+
1,
|
| 355 |
+
1,
|
| 356 |
+
0,
|
| 357 |
+
1,
|
| 358 |
+
1,
|
| 359 |
+
1,
|
| 360 |
+
0,
|
| 361 |
+
1,
|
| 362 |
+
1,
|
| 363 |
+
1,
|
| 364 |
+
0,
|
| 365 |
+
1,
|
| 366 |
+
1,
|
| 367 |
+
1,
|
| 368 |
+
0,
|
| 369 |
+
1,
|
| 370 |
+
1,
|
| 371 |
+
1,
|
| 372 |
+
0,
|
| 373 |
+
1,
|
| 374 |
+
1,
|
| 375 |
+
1,
|
| 376 |
+
0,
|
| 377 |
+
1,
|
| 378 |
+
1,
|
| 379 |
+
1,
|
| 380 |
+
0
|
| 381 |
+
],
|
| 382 |
+
"num_attention_heads": 16,
|
| 383 |
+
"num_hidden_layers": 36,
|
| 384 |
+
"num_key_value_heads": 4,
|
| 385 |
+
"pad_token_id": 128004,
|
| 386 |
+
"pretraining_tp": 2,
|
| 387 |
+
"rms_norm_eps": 1e-06,
|
| 388 |
+
"rope_parameters": {
|
| 389 |
+
"rope_theta": 5000000.0,
|
| 390 |
+
"rope_type": "default"
|
| 391 |
+
},
|
| 392 |
+
"sliding_window": null,
|
| 393 |
+
"tie_word_embeddings": true,
|
| 394 |
+
"use_cache": false,
|
| 395 |
+
"use_sliding_window": false,
|
| 396 |
+
"vocab_size": 128320
|
| 397 |
+
},
|
| 398 |
+
"text_model_id": "HuggingFaceTB/SmolLM3-3B",
|
| 399 |
+
"text_vocab_size": 128257,
|
| 400 |
+
"time_mask_length": 100,
|
| 401 |
+
"top_k": 50,
|
| 402 |
+
"top_p": null,
|
| 403 |
+
"transformers_version": "5.0.0",
|
| 404 |
+
"use_audio_head": true,
|
| 405 |
+
"use_cache": false,
|
| 406 |
+
"use_lora": false,
|
| 407 |
+
"use_specaugment": true,
|
| 408 |
+
"vocab_size": 128320
|
| 409 |
+
}
|
diarization.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
|
| 2 |
+
|
| 3 |
+
Spectral clustering implementation adapted from FunASR/3D-Speaker:
|
| 4 |
+
https://github.com/alibaba-damo-academy/FunASR
|
| 5 |
+
MIT License (https://opensource.org/licenses/MIT)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import scipy
|
| 12 |
+
import sklearn.metrics.pairwise
|
| 13 |
+
import torch
|
| 14 |
+
from sklearn.cluster._kmeans import k_means
|
| 15 |
+
from sklearn.preprocessing import normalize
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_device() -> torch.device:
|
| 19 |
+
"""Get best available device for inference."""
|
| 20 |
+
if torch.cuda.is_available():
|
| 21 |
+
return torch.device("cuda")
|
| 22 |
+
if torch.backends.mps.is_available():
|
| 23 |
+
return torch.device("mps")
|
| 24 |
+
return torch.device("cpu")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SpectralCluster:
    """Spectral clustering on the unnormalized Laplacian of an affinity matrix.

    Adapted from the FunASR/3D-Speaker and SpeechBrain implementations. When
    no oracle speaker count is given, the number of clusters is inferred from
    the largest gap in the Laplacian eigenvalue spectrum.
    """

    def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.pval = pval

    def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
        """Run spectral clustering on embeddings.

        Args:
            embeddings: Speaker embeddings of shape [N, D].
            oracle_num: Optional known number of speakers.

        Returns:
            Cluster labels of shape [N].
        """
        affinity = self.get_sim_mat(embeddings)
        # Keep only the strongest affinities, then restore symmetry.
        pruned = self.p_pruning(affinity)
        symmetric = 0.5 * (pruned + pruned.T)
        laplacian = self.get_laplacian(symmetric)
        spec_embs, n_spk = self.get_spec_embs(laplacian, oracle_num)
        return self.cluster_embs(spec_embs, n_spk)

    def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
        """Pairwise cosine-similarity matrix of shape [N, N]."""
        return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)

    def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
        """Zero all but the strongest ``pval`` fraction of each row (in place)."""
        n = affinity.shape[0]
        # Guarantee at least ~6 neighbours per row regardless of pval.
        keep = max(1, int(max(self.pval, 6.0 / n) * n))

        # Vectorized top-k selection per row; everything else is zeroed.
        strongest = np.argpartition(affinity, -keep, axis=1)[:, -keep:]
        keep_mask = np.zeros_like(affinity, dtype=bool)
        np.put_along_axis(keep_mask, strongest, True, axis=1)
        affinity[~keep_mask] = 0
        return affinity

    def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
        """Unnormalized graph Laplacian; zeroes the diagonal of ``sim_mat`` in place."""
        from scipy.sparse.csgraph import laplacian

        np.fill_diagonal(sim_mat, 0)
        return laplacian(sim_mat, normed=False)

    def get_spec_embs(
        self, laplacian: np.ndarray, k_oracle: int | None = None
    ) -> tuple[np.ndarray, int]:
        """Return (first-k eigenvectors, k) of the Laplacian.

        k is the oracle count when given; otherwise it comes from the eigengap
        heuristic: "cluster" eigenvalues sit near 0 and the count is taken at
        the transition to the remaining "noise" eigenvalues.
        """
        eigvals, eigvecs = scipy.linalg.eigh(laplacian)
        k = k_oracle if k_oracle is not None else self._estimate_num_speakers(eigvals)
        return eigvecs[:, :k], k

    def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
        """Estimate the speaker count from the largest eigenvalue gap."""
        upper = min(self.max_num_spks + 1, len(lambdas))
        candidates = lambdas[1:upper]  # first eigenvalue is always ~0; skip it
        if len(candidates) < 2:
            return self.min_num_spks

        # Absolute gaps, not ratios — ratios are numerically unstable near 0.
        widest = int(np.argmax(np.diff(candidates)))
        # +1 converts a gap index into a count, +1 for the skipped eigenvalue.
        estimate = widest + 2
        return max(self.min_num_spks, min(estimate, self.max_num_spks))

    def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
        """K-means over the spectral embeddings; returns the label vector."""
        return k_means(emb, k, n_init=10)[1]
|
| 140 |
+
|
| 141 |
+
class SpeakerClusterer:
    """Speaker clustering backend: spectral clustering plus centroid merging.

    Features:
      - Spectral clustering with the eigengap heuristic for automatic
        speaker-count detection.
      - P-pruning refinement of the affinity matrix (done by SpectralCluster).
      - Post-clustering merging of speakers whose centroids are very similar.
    """

    def __init__(
        self,
        min_num_spks: int = 2,
        max_num_spks: int = 10,
        merge_thr: float = 0.90,  # Moderate merging
    ):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.merge_thr = merge_thr
        self._spectral_cluster: SpectralCluster | None = None

    def _get_spectral_cluster(self) -> SpectralCluster:
        """Lazily build and cache the spectral clusterer."""
        if self._spectral_cluster is None:
            self._spectral_cluster = SpectralCluster(
                min_num_spks=self.min_num_spks,
                max_num_spks=self.max_num_spks,
            )
        return self._spectral_cluster

    def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
        """Cluster speaker embeddings and return sequential integer labels.

        Args:
            embeddings: Speaker embeddings of shape [N, D].
            num_speakers: Optional oracle number of speakers.

        Returns:
            Cluster labels of shape [N].
        """
        # NOTE: the redundant function-local `import warnings` was removed;
        # the module already imports `warnings` at top level.
        if len(embeddings.shape) != 2:
            raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")

        # Degenerate inputs: too few windows to cluster meaningfully.
        if embeddings.shape[0] == 0:
            return np.array([], dtype=int)
        if embeddings.shape[0] == 1:
            return np.array([0], dtype=int)
        if embeddings.shape[0] < 6:
            return np.zeros(embeddings.shape[0], dtype=int)

        # Sanitize and L2-normalize before computing cosine affinities.
        embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
        embeddings = normalize(embeddings)

        spectral = self._get_spectral_cluster()

        # With an oracle count, pin the clusterer to exactly that many speakers.
        if num_speakers is not None:
            spectral.min_num_spks = num_speakers
            spectral.max_num_spks = num_speakers

        # Suppress numerical RuntimeWarnings from the eigendecomposition.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            labels = spectral(embeddings, oracle_num=num_speakers)

        # Restore the configured bounds for subsequent calls.
        if num_speakers is not None:
            spectral.min_num_spks = self.min_num_spks
            spectral.max_num_spks = self.max_num_spks

        # Without an oracle, merge clusters whose centroids nearly coincide.
        if num_speakers is None:
            labels = self._merge_by_cos(labels, embeddings, self.merge_thr)

        # Re-index labels sequentially to 0..K-1.
        _, labels = np.unique(labels, return_inverse=True)

        return labels

    def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
        """Merge speakers whose centroid cosine similarity exceeds ``cos_thr``."""
        from scipy.cluster.hierarchy import fcluster, linkage
        from scipy.spatial.distance import pdist

        unique_labels = np.unique(labels)
        if len(unique_labels) <= 1:
            return labels

        # Normalized per-speaker centroids.
        centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
        centroids = normalize(centroids)

        # Average-linkage hierarchical clustering in cosine distance; a cosine
        # similarity of cos_thr corresponds to a distance of 1 - cos_thr.
        distances = pdist(centroids, metric="cosine")
        linkage_matrix = linkage(distances, method="average")
        merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1

        # Map original labels to merged labels.
        label_map = dict(zip(unique_labels, merged_labels))
        return np.array([label_map[lbl] for lbl in labels])
|
| 245 |
+
|
| 246 |
+
class LocalSpeakerDiarizer:
    """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.

    Pipeline:
    1. TEN-VAD detects speech segments
    2. Sliding window (0.75s window, 0.15s step, 80% overlap) for uniform
       embedding extraction — see WINDOW_SIZE / STEP_SIZE below
    3. ECAPA-TDNN extracts speaker embeddings per window
    4. Spectral clustering with eigenvalue gap for auto speaker detection
    5. Frame-level consensus voting for segment reconstruction
    6. Post-processing merges short segments to reduce flicker

    Tunable Parameters (class attributes):
    - WINDOW_SIZE: Embedding extraction window size in seconds
    - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
    - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
    - VAD_MIN_DURATION: Minimum speech segment duration
    - VAD_MAX_GAP: Maximum gap to bridge between segments
    - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
    - VOTING_RATE: Frame resolution for consensus voting
    - MIN_SEGMENT_DURATION: Minimum final segment duration
    - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
    - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
    """

    # Lazily-initialized singletons shared by all calls (see the _get_* classmethods).
    _ten_vad_model = None
    _ecapa_model = None
    _device = None

    # ==================== TUNABLE PARAMETERS ====================

    # Sliding window for embedding extraction
    WINDOW_SIZE = 0.75  # seconds - shorter window for finer resolution
    STEP_SIZE = 0.15  # seconds (80% overlap for more votes)
    TAIL_COVERAGE_RATIO = 0.1  # Add extra window if tail > this ratio of window

    # VAD hysteresis parameters
    VAD_THRESHOLD = 0.25  # Balanced threshold
    VAD_MIN_DURATION = 0.05  # Minimum speech segment duration (seconds)
    VAD_MAX_GAP = 0.50  # Bridge gaps shorter than this (seconds)
    VAD_PAD_ONSET = 0.05  # Padding at segment start (seconds)
    VAD_PAD_OFFSET = 0.05  # Padding at segment end (seconds)

    # Frame-level voting
    VOTING_RATE = 0.01  # 10ms resolution for consensus voting

    # Post-processing
    MIN_SEGMENT_DURATION = 0.15  # Minimum final segment duration (seconds)
    SHORT_SEGMENT_GAP = 0.1  # Gap threshold for merging short segments
    SAME_SPEAKER_GAP = 0.5  # Gap threshold for merging same-speaker segments

    # ===========================================================
| 298 |
+
@classmethod
def _get_ten_vad_model(cls):
    """Lazily construct and cache the process-wide TEN-VAD instance."""
    if cls._ten_vad_model is not None:
        return cls._ten_vad_model

    from ten_vad import TenVad  # deferred: heavy optional dependency

    cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
    return cls._ten_vad_model
|
| 307 |
+
@classmethod
def _get_device(cls) -> torch.device:
    """Resolve the preferred torch device once and reuse it afterwards."""
    if cls._device is not None:
        return cls._device
    cls._device = _get_device()
    return cls._device
|
| 314 |
+
@classmethod
def _get_ecapa_model(cls):
    """Lazily load the ECAPA-TDNN speaker-embedding model (shared singleton)."""
    if cls._ecapa_model is not None:
        return cls._ecapa_model

    # SpeechBrain's import path emits a torchaudio deprecation warning; hide it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="torchaudio._backend")
        from speechbrain.inference.speaker import EncoderClassifier

        target = cls._get_device()
        cls._ecapa_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(target)},
        )

    return cls._ecapa_model
|
| 331 |
+
@classmethod
def diarize(
    cls,
    audio: np.ndarray | str,
    sample_rate: int = 16000,
    num_speakers: int | None = None,
    min_speakers: int = 2,
    max_speakers: int = 10,
    **_kwargs,
) -> list[dict]:
    """Run speaker diarization on audio.

    Args:
        audio: Audio waveform as numpy array or path to audio file
        sample_rate: Audio sample rate (default 16000)
        num_speakers: Exact number of speakers (if known)
        min_speakers: Minimum number of speakers
        max_speakers: Maximum number of speakers

    Returns:
        List of dicts with 'speaker', 'start', 'end' keys
    """
    # Path input: librosa loads and resamples straight to 16 kHz.
    if isinstance(audio, str):
        import librosa

        audio, sample_rate = librosa.load(audio, sr=16000)

    # In-memory audio at another rate gets resampled to 16 kHz as well.
    if sample_rate != 16000:
        import librosa

        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    audio = audio.astype(np.float32)
    total_duration = len(audio) / sample_rate

    # 1) VAD: coarse speech segments plus the raw per-frame decisions.
    segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
    if not segments:
        return []

    # 2) One speaker embedding per sliding window over the speech regions.
    embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
    if len(embeddings) == 0:
        return []

    # 3) Spectral clustering of the window embeddings.
    clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
    labels = clusterer(embeddings, num_speakers)

    # 4) VAD-aware consensus voting rebuilds contiguous labeled segments.
    return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
|
| 386 |
+
@classmethod
def _get_speech_segments(
    cls, audio_array: np.ndarray, sample_rate: int = 16000
) -> tuple[list[dict], list[bool]]:
    """Get speech segments using TEN-VAD.

    Args:
        audio_array: Mono waveform, float32 in [-1, 1] or int16.
        sample_rate: Sample rate of ``audio_array`` (TEN-VAD assumes 16 kHz).

    Returns:
        Tuple of (segments list, vad_frames list of per-frame speech
        decisions). Each segment dict has 'start'/'end' in seconds and
        'start_sample'/'end_sample' indices.
    """
    vad_model = cls._get_ten_vad_model()

    # Convert to int16 as required by TEN-VAD.
    # Clip first to prevent integer overflow for out-of-range floats.
    if audio_array.dtype != np.int16:
        audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
    else:
        audio_int16 = audio_array

    # Process frame by frame (TEN-VAD consumes fixed 256-sample hops).
    hop_size = 256
    frame_duration = hop_size / sample_rate
    speech_frames: list[bool] = []

    # Upper bound is len - hop_size + 1 so the final *complete* frame is
    # processed when the signal length is an exact multiple of hop_size
    # (the previous bound of len - hop_size silently dropped it).
    for i in range(0, len(audio_int16) - hop_size + 1, hop_size):
        frame = audio_int16[i : i + hop_size]
        _, is_speech = vad_model.process(frame)
        speech_frames.append(is_speech)

    # Convert frame-level decisions to contiguous speech segments.
    segments = []
    in_speech = False
    start_idx = 0

    for i, is_speech in enumerate(speech_frames):
        if is_speech and not in_speech:
            start_idx = i
            in_speech = True
        elif not is_speech and in_speech:
            start_time = start_idx * frame_duration
            end_time = i * frame_duration
            segments.append(
                {
                    "start": start_time,
                    "end": end_time,
                    "start_sample": int(start_time * sample_rate),
                    "end_sample": int(end_time * sample_rate),
                }
            )
            in_speech = False

    # Handle trailing speech that runs to the end of the signal.
    if in_speech:
        start_time = start_idx * frame_duration
        end_time = len(speech_frames) * frame_duration
        segments.append(
            {
                "start": start_time,
                "end": end_time,
                "start_sample": int(start_time * sample_rate),
                "end_sample": int(end_time * sample_rate),
            }
        )

    return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
|
| 450 |
+
|
| 451 |
+
@classmethod
def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
    """Apply hysteresis-like post-processing to VAD segments.

    Three passes: (1) merge segments separated by a gap no larger than
    ``cls.VAD_MAX_GAP`` seconds, (2) drop segments shorter than
    ``cls.VAD_MIN_DURATION``, (3) dilate the survivors by
    ``cls.VAD_PAD_ONSET`` / ``cls.VAD_PAD_OFFSET`` and recompute sample
    indices.

    Args:
        segments: Segment dicts with 'start', 'end' (seconds) and
            'start_sample', 'end_sample' keys.
        sample_rate: Sample rate used to recompute the sample indices.

    Returns:
        Post-processed list of segment dicts.
    """
    if not segments:
        return segments

    segments = sorted(segments, key=lambda x: x["start"])

    # Fill short gaps between consecutive segments.
    merged = [segments[0].copy()]
    for seg in segments[1:]:
        gap = seg["start"] - merged[-1]["end"]
        if gap <= cls.VAD_MAX_GAP:
            # max() so an overlapping or fully-contained segment can never
            # move the merged end *backwards* (plain assignment could).
            merged[-1]["end"] = max(merged[-1]["end"], seg["end"])
            merged[-1]["end_sample"] = max(merged[-1]["end_sample"], seg["end_sample"])
        else:
            merged.append(seg.copy())

    # Remove segments still too short to be real speech.
    filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]

    # Dilate segments (add padding) and keep sample indices consistent.
    for seg in filtered:
        seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
        seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
        seg["start_sample"] = int(seg["start"] * sample_rate)
        seg["end_sample"] = int(seg["end"] * sample_rate)

    return filtered
|
| 480 |
+
|
| 481 |
+
@classmethod
def _extract_embeddings(
    cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
) -> tuple[np.ndarray, list[dict]]:
    """Extract speaker embeddings using sliding windows.

    Each VAD segment is covered with windows of ``cls.WINDOW_SIZE`` seconds
    advanced by ``cls.STEP_SIZE`` seconds; one ECAPA embedding is computed
    per window.

    Args:
        audio_array: Full waveform (indexed by sample).
        segments: VAD segments with 'start_sample'/'end_sample' keys.
        sample_rate: Sample rate of ``audio_array``.

    Returns:
        Tuple of (normalized embedding matrix [num_windows, dim], list of
        window dicts with 'start'/'end' in seconds). Both are empty when no
        valid embedding could be extracted.
    """
    speaker_model = cls._get_ecapa_model()

    window_samples = int(cls.WINDOW_SIZE * sample_rate)
    step_samples = int(cls.STEP_SIZE * sample_rate)

    embeddings = []
    window_segments = []

    with torch.no_grad():
        for seg in segments:
            seg_start = seg["start_sample"]
            seg_end = seg["end_sample"]
            seg_len = seg_end - seg_start

            # Generate window positions: a single window for short
            # segments, otherwise a strided sweep across the segment.
            if seg_len <= window_samples:
                starts = [seg_start]
                ends = [seg_end]
            else:
                starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
                ends = [s + window_samples for s in starts]

                # Cover tail if > TAIL_COVERAGE_RATIO of window remains
                # uncovered by the strided sweep.
                if ends and ends[-1] < seg_end:
                    remainder = seg_end - ends[-1]
                    if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
                        starts.append(seg_end - window_samples)
                        ends.append(seg_end)

            for c_start, c_end in zip(starts, ends):
                chunk = audio_array[c_start:c_end]

                # Pad short chunks with reflection.
                # NOTE(review): np.pad(mode="reflect") raises when
                # pad_width >= len(chunk) — confirm a chunk can never be
                # shorter than half a window here.
                if len(chunk) < window_samples:
                    pad_width = window_samples - len(chunk)
                    chunk = np.pad(chunk, (0, pad_width), mode="reflect")

                # Extract embedding using SpeechBrain's encode_batch;
                # the double squeeze assumes a [1, 1, dim]-shaped output.
                chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
                embedding = (
                    speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
                )

                # Validate embedding: drop NaN/inf or near-zero vectors
                # that would corrupt clustering.
                if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
                    embeddings.append(embedding)
                    window_segments.append(
                        {
                            "start": c_start / sample_rate,
                            "end": c_end / sample_rate,
                        }
                    )

    # Normalize all embeddings at once (``normalize`` imported at module
    # level — presumably sklearn-style L2 row normalization; verify).
    if embeddings:
        return normalize(np.array(embeddings)), window_segments
    return np.array([]), []
|
| 543 |
+
|
| 544 |
+
@classmethod
def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
    """Resample VAD frame decisions to match the voting grid resolution.

    TEN-VAD emits one decision per 256 samples at 16 kHz (16 ms per frame),
    while the consensus voting grid uses ``cls.VOTING_RATE`` seconds per
    frame (default 10 ms). Each voting frame is assigned the decision of
    the VAD frame covering its start time.

    Args:
        vad_frames: Per-frame VAD speech decisions (16 ms resolution).
        num_frames: Number of frames in the voting grid.

    Returns:
        Boolean array of length ``num_frames`` on the voting grid (all
        False when ``vad_frames`` is empty).
    """
    if not vad_frames:
        return np.zeros(num_frames, dtype=bool)

    vad_frame_seconds = 256 / 16000  # 16 ms per TEN-VAD frame
    decisions = np.array(vad_frames)

    # Vectorized mapping: for every voting frame, locate the VAD frame
    # that covers its start time, clamped into the valid index range.
    grid_start_times = np.arange(num_frames) * cls.VOTING_RATE
    source_indices = (grid_start_times / vad_frame_seconds).astype(int)
    source_indices = np.clip(source_indices, 0, len(decisions) - 1)
    return decisions[source_indices]
|
| 562 |
+
|
| 563 |
+
@classmethod
def _postprocess_segments(
    cls,
    window_segments: list[dict],
    labels: np.ndarray,
    total_duration: float,
    vad_frames: list[bool],
) -> list[dict]:
    """Post-process using frame-level consensus voting with VAD-aware silence.

    Overlapping embedding windows each vote for their cluster label on a
    fine time grid (``cls.VOTING_RATE`` seconds per frame); the majority
    speaker wins each frame, frames outside VAD speech (or with no votes)
    are forced to silence, and the resulting runs are converted back into
    speaker segments.

    Args:
        window_segments: Window dicts with 'start'/'end' in seconds.
        labels: Cluster label per window, aligned with ``window_segments``.
        total_duration: Total audio duration in seconds.
        vad_frames: Per-frame VAD decisions (16 ms resolution).

    Returns:
        List of dicts with 'speaker', 'start', 'end' keys.
    """
    if not window_segments or len(labels) == 0:
        return []

    # Remap cluster labels to a contiguous 0..K-1 range.
    unique_labels = np.unique(labels)
    label_map = {old: new for new, old in enumerate(unique_labels)}
    clean_labels = np.array([label_map[lbl] for lbl in labels])
    num_speakers = len(unique_labels)

    if num_speakers == 0:
        return []

    # Create the voting grid: one row per time frame, one column per
    # speaker (+1 frame so the final partial frame is representable).
    num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
    votes = np.zeros((num_frames, num_speakers), dtype=np.float32)

    # Accumulate one vote per window over every frame it spans.
    for win, label in zip(window_segments, clean_labels):
        start_frame = int(win["start"] / cls.VOTING_RATE)
        end_frame = int(win["end"] / cls.VOTING_RATE)
        end_frame = min(end_frame, num_frames)
        if start_frame < end_frame:
            votes[start_frame:end_frame, label] += 1.0

    # Determine the winning speaker (and its vote count) per frame.
    frame_speakers = np.argmax(votes, axis=1)
    max_votes = np.max(votes, axis=1)

    # Resample VAD to voting grid resolution for silence-aware voting.
    vad_resampled = cls._resample_vad(vad_frames, num_frames)

    # Convert per-frame winners into contiguous segments; speaker == -1
    # marks silence and never emits a segment.
    final_segments = []
    current_speaker = -1
    seg_start = 0.0

    for f in range(num_frames):
        speaker = int(frame_speakers[f])
        score = max_votes[f]

        # Force silence if VAD says no speech OR no window voted here.
        if score == 0 or not vad_resampled[f]:
            speaker = -1

        if speaker != current_speaker:
            # Close the previous speaker run (silence runs emit nothing).
            if current_speaker != -1:
                final_segments.append(
                    {
                        "speaker": f"SPEAKER_{current_speaker}",
                        "start": seg_start,
                        "end": f * cls.VOTING_RATE,
                    }
                )
            current_speaker = speaker
            seg_start = f * cls.VOTING_RATE

    # Close the last segment if it is still open at end-of-audio.
    if current_speaker != -1:
        final_segments.append(
            {
                "speaker": f"SPEAKER_{current_speaker}",
                "start": seg_start,
                "end": num_frames * cls.VOTING_RATE,
            }
        )

    return cls._merge_short_segments(final_segments)
|
| 639 |
+
|
| 640 |
+
@classmethod
def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
    """Merge short segments into their predecessor to reduce speaker flicker.

    A segment shorter than ``cls.MIN_SEGMENT_DURATION`` is absorbed into
    the previous result segment when both share a speaker and the gap is
    below ``cls.SHORT_SEGMENT_GAP``. Independently, consecutive
    same-speaker segments separated by less than ``cls.SAME_SPEAKER_GAP``
    are joined.

    Args:
        segments: Dicts with 'speaker', 'start', 'end' keys, time-ordered.

    Returns:
        Cleaned-up list of segments.
    """
    if not segments:
        return []

    result: list[dict] = []
    for current in segments:
        previous = result[-1] if result else None
        same_speaker = previous is not None and previous["speaker"] == current["speaker"]
        gap = current["start"] - previous["end"] if previous is not None else None

        # Absorb a too-short same-speaker segment into its predecessor.
        duration = current["end"] - current["start"]
        if duration < cls.MIN_SEGMENT_DURATION:
            if same_speaker and gap < cls.SHORT_SEGMENT_GAP:
                previous["end"] = current["end"]
                continue

        # Join consecutive same-speaker segments across small gaps.
        if same_speaker and gap < cls.SAME_SPEAKER_GAP:
            previous["end"] = current["end"]
        else:
            result.append(current)

    return result
|
| 668 |
+
|
| 669 |
+
@classmethod
def assign_speakers_to_words(
    cls,
    words: list[dict],
    speaker_segments: list[dict],
) -> list[dict]:
    """Assign speaker labels to words based on timestamp overlap.

    Each word gets the speaker of the first segment containing the word's
    midpoint; when no segment contains it, the segment whose own midpoint
    is nearest in time is used instead.

    Args:
        words: List of word dicts with 'word', 'start', 'end' keys
        speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys

    Returns:
        Words list with 'speaker' key added to each word (None when no
        speaker segments are available)
    """
    for word in words:
        midpoint = (word["start"] + word["end"]) / 2

        # Prefer the first segment whose span contains the word midpoint.
        speaker = next(
            (
                seg["speaker"]
                for seg in speaker_segments
                if seg["start"] <= midpoint <= seg["end"]
            ),
            None,
        )

        # Fall back to the segment whose midpoint is closest in time
        # (min() keeps the first of any ties, like the strict-< scan).
        if speaker is None and speaker_segments:
            nearest = min(
                speaker_segments,
                key=lambda seg: abs(midpoint - (seg["start"] + seg["end"]) / 2),
            )
            speaker = nearest["speaker"]

        word["speaker"] = speaker

    return words
|
full_duplex.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Full-duplex audio session for speech-to-speech.
|
| 2 |
+
|
| 3 |
+
Implements Freeze-Omni style full-duplex conversation where the model
|
| 4 |
+
can listen and speak simultaneously, with support for user interruption.
|
| 5 |
+
|
| 6 |
+
Architecture:
|
| 7 |
+
- Dual queue system: PCMQueue (input) + AudioQueue (output)
|
| 8 |
+
- Multi-threaded: Listen thread + Generate thread run concurrently
|
| 9 |
+
- State machine: listen -> speak -> (interrupt) -> listen
|
| 10 |
+
- VAD-based turn detection using model's built-in Silero VAD
|
| 11 |
+
|
| 12 |
+
Usage (sync):
|
| 13 |
+
session = FullDuplexSession(model)
|
| 14 |
+
session.start()
|
| 15 |
+
|
| 16 |
+
while has_audio:
|
| 17 |
+
session.push_audio(audio_chunk)
|
| 18 |
+
output = session.pop_audio()
|
| 19 |
+
if output is not None:
|
| 20 |
+
speaker.play(output)
|
| 21 |
+
|
| 22 |
+
session.stop()
|
| 23 |
+
|
| 24 |
+
Usage (async/web):
|
| 25 |
+
session = FullDuplexSession(
|
| 26 |
+
model,
|
| 27 |
+
on_state_change=lambda s: send_status(s),
|
| 28 |
+
on_text=lambda t: send_text(t),
|
| 29 |
+
on_audio=lambda a: send_audio(a),
|
| 30 |
+
)
|
| 31 |
+
session.start()
|
| 32 |
+
|
| 33 |
+
# In your receive loop:
|
| 34 |
+
session.push_audio(audio_chunk)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import queue
|
| 39 |
+
import threading
|
| 40 |
+
import time
|
| 41 |
+
from dataclasses import dataclass, field
|
| 42 |
+
from enum import Enum
|
| 43 |
+
from typing import TYPE_CHECKING, Callable, Optional
|
| 44 |
+
|
| 45 |
+
import numpy as np
|
| 46 |
+
import torch
|
| 47 |
+
|
| 48 |
+
if TYPE_CHECKING:
|
| 49 |
+
from .asr_modeling import ASRModel
|
| 50 |
+
|
| 51 |
+
logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ConversationState(Enum):
    """State machine for full-duplex conversation.

    Transitions observed in FullDuplexSession: IDLE -> LISTENING on
    start(); LISTENING -> PROCESSING when a complete utterance is
    detected; PROCESSING -> SPEAKING while audio is synthesized; back to
    LISTENING when generation finishes or is interrupted; -> IDLE on
    stop().
    """

    IDLE = "idle"  # Session not running
    LISTENING = "listening"  # Waiting for / accumulating user speech
    PROCESSING = "processing"  # Generating the text response
    SPEAKING = "speaking"  # Streaming synthesized audio out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class FullDuplexConfig:
    """Configuration for full-duplex session."""

    # Audio settings
    sample_rate: int = 16000  # Input (microphone) sample rate in Hz
    chunk_size: int = 512  # Samples per chunk (32ms at 16kHz)
    output_sample_rate: int = 44100  # DAC output rate

    # VAD settings
    vad_threshold: float = 0.5  # Speech-probability threshold passed to detect_speech
    silence_duration_ms: float = 700  # Silence to end turn
    min_speech_duration_ms: float = 100  # Minimum speech to trigger

    # Generation settings
    audio_chunk_size: int = 4  # Tokens per audio chunk

    # Timing
    poll_interval: float = 0.01  # Sleep between polls when no input is ready (s)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class PCMQueue:
    """Thread-safe queue for streaming PCM audio input.

    Incoming chunks are appended to a single contiguous float32 buffer so
    the consumer can pop fixed-size frames regardless of the producer's
    chunk sizes.
    """

    def __init__(self):
        self.buffer = np.array([], dtype=np.float32)  # pending samples
        self.lock = threading.Lock()  # guards every buffer access

    def put(self, audio: np.ndarray) -> None:
        """Append samples (coerced to float32) to the buffer."""
        samples = audio.astype(np.float32)
        with self.lock:
            self.buffer = np.concatenate([self.buffer, samples])

    def get(self, length: int) -> Optional[np.ndarray]:
        """Pop exactly ``length`` samples, or None if not enough buffered."""
        with self.lock:
            if len(self.buffer) < length:
                return None
            head, self.buffer = self.buffer[:length], self.buffer[length:]
            return head

    def clear(self) -> None:
        """Discard all buffered samples."""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)

    def __len__(self) -> int:
        """Number of samples currently buffered."""
        with self.lock:
            return len(self.buffer)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AudioQueue:
    """Thread-safe FIFO for generated output audio chunks."""

    def __init__(self):
        self._queue: queue.Queue = queue.Queue()

    def put(self, audio: torch.Tensor) -> None:
        """Enqueue one audio chunk."""
        self._queue.put(audio)

    def get(self) -> Optional[torch.Tensor]:
        """Dequeue the next chunk without blocking; None when empty."""
        try:
            return self._queue.get_nowait()
        except queue.Empty:
            return None

    def clear(self) -> None:
        """Drain every pending chunk."""
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                return

    def is_empty(self) -> bool:
        """True when no chunks are waiting."""
        return self._queue.empty()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class _SessionState:
    """Internal mutable state for full-duplex session.

    Shared between the listen and generate threads; FullDuplexSession
    guards access with ``_state_lock``.
    """

    state: ConversationState = ConversationState.IDLE  # current conversation phase
    speech_buffer: list = field(default_factory=list)  # audio chunks of the utterance in progress
    speech_start_time: float = 0.0  # wall-clock time the current utterance began
    last_speech_time: float = 0.0  # wall-clock time speech was last detected
    silence_frames: int = 0  # consecutive non-speech chunks since last speech
    stop_generate: bool = False  # flag asking the generate thread to abort
    is_generating: bool = False  # True while the generate thread is active
    generated_text: str = ""  # text produced by the most recent generation
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class FullDuplexSession:
    """Full-duplex speech-to-speech session (Freeze-Omni style).

    Manages simultaneous listening and speaking with VAD-based turn detection.
    Designed to be easy to integrate with both sync and async (web) applications.

    Two daemon threads cooperate: a listen thread (started by ``start``)
    pulls fixed-size chunks from ``input_queue`` and runs VAD, and a
    per-utterance generate thread produces the text/audio response.
    Shared mutable state lives in ``_SessionState`` behind ``_state_lock``.

    Args:
        model: ASRModel with audio_head configured
        config: FullDuplexConfig for session parameters
        on_state_change: Callback when state changes (state: ConversationState)
        on_text: Callback when text is generated (text: str, interim: bool)
        on_audio: Callback when audio chunk is ready (audio: torch.Tensor)
            If provided, audio is sent here instead of output_queue
        on_interrupted: Callback when generation is interrupted
    """

    def __init__(
        self,
        model: "ASRModel",
        config: Optional[FullDuplexConfig] = None,
        on_state_change: Optional[Callable[[ConversationState], None]] = None,
        on_text: Optional[Callable[[str, bool], None]] = None,
        on_audio: Optional[Callable[[torch.Tensor], None]] = None,
        on_interrupted: Optional[Callable[[], None]] = None,
    ):
        self.model = model
        self.config = config or FullDuplexConfig()

        # Callbacks (all optional; exceptions they raise are logged, not propagated)
        self.on_state_change = on_state_change
        self.on_text = on_text
        self.on_audio = on_audio
        self.on_interrupted = on_interrupted

        # Queues: raw PCM in, synthesized audio out
        self.input_queue = PCMQueue()
        self.output_queue = AudioQueue()

        # State shared between threads, guarded by _state_lock
        self._state = _SessionState()
        self._running = False
        self._state_lock = threading.Lock()

        # Threads (created lazily; daemonized so they never block interpreter exit)
        self._listen_thread: Optional[threading.Thread] = None
        self._generate_thread: Optional[threading.Thread] = None

        # Precompute timing thresholds in units of whole chunks
        ms_per_chunk = self.config.chunk_size / self.config.sample_rate * 1000
        self._silence_threshold = int(self.config.silence_duration_ms / ms_per_chunk)
        self._min_speech_chunks = int(self.config.min_speech_duration_ms / ms_per_chunk)

        # Ensure VAD is loaded before the listen loop first needs it
        self.model.load_vad()

    @property
    def state(self) -> ConversationState:
        """Current conversation state (thread-safe read)."""
        with self._state_lock:
            return self._state.state

    def _set_state(self, value: ConversationState) -> None:
        """Update the state and fire on_state_change when it actually changed.

        NOTE(review): the callback runs while _state_lock is held here —
        a callback that re-enters ``state`` / ``is_generating`` would
        deadlock on the non-reentrant lock; confirm this is intended.
        """
        with self._state_lock:
            old_state = self._state.state
            self._state.state = value
            if old_state != value:
                logger.debug(f"State: {old_state.value} -> {value.value}")
                if self.on_state_change:
                    try:
                        self.on_state_change(value)
                    except Exception as e:
                        logger.error(f"on_state_change callback error: {e}")

    @property
    def is_generating(self) -> bool:
        # Thread-safe view of whether the generate thread is active.
        with self._state_lock:
            return self._state.is_generating

    @property
    def generated_text(self) -> str:
        # Thread-safe copy of the latest generated text.
        with self._state_lock:
            return self._state.generated_text

    def start(self) -> None:
        """Start the full-duplex session."""
        if self._running:
            return  # idempotent: already running

        self._running = True
        self._set_state(ConversationState.LISTENING)

        self._listen_thread = threading.Thread(target=self._listen_loop, daemon=True)
        self._listen_thread.start()

        logger.info("Full-duplex session started")

    def stop(self) -> None:
        """Stop the full-duplex session and drain both queues."""
        self._running = False

        # Ask any in-flight generation to abort.
        with self._state_lock:
            self._state.stop_generate = True

        # Bounded joins so stop() can never hang indefinitely.
        if self._listen_thread:
            self._listen_thread.join(timeout=2.0)
        if self._generate_thread:
            self._generate_thread.join(timeout=2.0)

        self.input_queue.clear()
        self.output_queue.clear()
        self._set_state(ConversationState.IDLE)

        logger.info("Full-duplex session stopped")

    def push_audio(self, audio: np.ndarray) -> None:
        """Push audio samples to the input queue.

        Args:
            audio: Audio samples as numpy array (float32 normalized or int16)
        """
        # Normalize int16 PCM into the float32 [-1, 1) range.
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        self.input_queue.put(audio)

    def pop_audio(self) -> Optional[torch.Tensor]:
        """Pop generated audio from the output queue.

        Only used if on_audio callback is not set.

        Returns:
            Audio tensor [samples] or None
        """
        return self.output_queue.get()

    def interrupt(self) -> None:
        """Interrupt current generation and return to listening."""
        with self._state_lock:
            self._state.stop_generate = True

        # Wait for generation to stop.
        # NOTE(review): this reads _state.is_generating without the lock;
        # racy but tolerable for a bounded polling wait — confirm.
        timeout = 2.0
        start = time.time()
        while self._state.is_generating and (time.time() - start) < timeout:
            time.sleep(self.config.poll_interval)

        # Clear any queued-but-unplayed response audio.
        self.output_queue.clear()

        # Reset the per-utterance state.
        with self._state_lock:
            self._state.stop_generate = False
            self._state.generated_text = ""
            self._state.speech_buffer.clear()
            self._state.silence_frames = 0

        self._set_state(ConversationState.LISTENING)
        self.model.reset_vad_state()

        if self.on_interrupted:
            try:
                self.on_interrupted()
            except Exception as e:
                logger.error(f"on_interrupted callback error: {e}")

        logger.debug("Generation interrupted")

    def _emit_audio(self, audio: torch.Tensor) -> None:
        """Send audio to the on_audio callback, or to the output queue."""
        if self.on_audio:
            try:
                self.on_audio(audio)
            except Exception as e:
                logger.error(f"on_audio callback error: {e}")
        else:
            self.output_queue.put(audio)

    def _emit_text(self, text: str, interim: bool = False) -> None:
        """Send text to the on_text callback (no-op when unset)."""
        if self.on_text:
            try:
                self.on_text(text, interim)
            except Exception as e:
                logger.error(f"on_text callback error: {e}")

    def _listen_loop(self) -> None:
        """Main listening loop - processes audio and detects speech.

        Runs on the listen thread until stop(). Pulls fixed-size chunks,
        classifies each with VAD, and drives the utterance state machine;
        a finished utterance spawns a generate thread.
        """
        is_speaking = False  # thread-local view of "user currently talking"

        while self._running:
            audio = self.input_queue.get(self.config.chunk_size)
            if audio is None:
                # Not enough samples buffered yet — back off briefly.
                time.sleep(self.config.poll_interval)
                continue

            # Run VAD on this chunk.
            audio_tensor = torch.from_numpy(audio)
            is_speech, prob = self.model.detect_speech(
                audio_tensor,
                self.config.sample_rate,
                self.config.vad_threshold,
            )

            current_time = time.time()

            # Barge-in: user speech while the model is generating.
            # NOTE(review): _state.is_generating is read without the lock
            # here — a stale read only delays the interrupt by one chunk.
            if self._state.is_generating and is_speech:
                logger.debug(f"Interruption detected (prob={prob:.2f})")
                self.interrupt()
                # Start new utterance with this chunk
                is_speaking = True
                with self._state_lock:
                    self._state.speech_buffer = [audio]
                    self._state.speech_start_time = current_time
                    self._state.last_speech_time = current_time
                    self._state.silence_frames = 0
                continue

            # Normal VAD state machine
            if is_speech:
                if not is_speaking:
                    # Speech onset: start a fresh utterance buffer.
                    is_speaking = True
                    with self._state_lock:
                        self._state.speech_buffer = []
                        self._state.speech_start_time = current_time
                with self._state_lock:
                    self._state.speech_buffer.append(audio)
                    self._state.last_speech_time = current_time
                    self._state.silence_frames = 0

            elif is_speaking:
                # Trailing silence while an utterance is open: keep the
                # chunk (so the tail isn't clipped) and count silence.
                with self._state_lock:
                    self._state.speech_buffer.append(audio)
                    self._state.silence_frames += 1

                    if self._state.silence_frames >= self._silence_threshold:
                        # Enough silence: the turn has ended.
                        is_speaking = False

                        # Check minimum speech duration before responding.
                        if len(self._state.speech_buffer) >= self._min_speech_chunks:
                            speech_audio = np.concatenate(self._state.speech_buffer)
                            self._state.speech_buffer = []
                            self._state.silence_frames = 0

                            # Start generation on its own daemon thread.
                            self._generate_thread = threading.Thread(
                                target=self._generate_loop,
                                args=(speech_audio,),
                                daemon=True,
                            )
                            self._generate_thread.start()
                        else:
                            # Too short to be a real utterance — discard.
                            self._state.speech_buffer = []
                            self._state.silence_frames = 0

    def _generate_loop(self, speech_audio: np.ndarray) -> None:
        """Generation loop - produces text and audio response.

        Runs on a per-utterance daemon thread. Checks
        ``_state.stop_generate`` between stages so interrupt() can abort
        at any stage boundary.
        """
        with self._state_lock:
            self._state.is_generating = True
            self._state.generated_text = ""
            self._state.stop_generate = False

        try:
            self._set_state(ConversationState.PROCESSING)

            # Process the captured utterance into model features.
            device = next(self.model.language_model.parameters()).device
            inputs = self.model._process_audio(speech_audio, self.config.sample_rate)
            input_features = inputs["input_features"]
            audio_attention_mask = inputs["attention_mask"]

            # Encode audio and splice the embeddings into the prompt.
            audio_embeds = self.model._encode_audio(input_features, audio_attention_mask)
            input_ids, attention_mask = self.model._build_audio_prompt(
                audio_attention_mask, 1, device
            )
            inputs_embeds = self.model._inject_audio_embeddings(input_ids, audio_embeds)

            # Check for interruption before the expensive generate call.
            if self._state.stop_generate:
                return

            # Generate the text response.
            # NOTE(review): both input_ids and inputs_embeds are passed;
            # stock HF generate() rejects that combination — presumably
            # this model's generate supports it. Confirm.
            with torch.no_grad():
                output = self.model.language_model.generate(
                    input_ids=input_ids,
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    generation_config=self.model.generation_config,
                )

            if self._state.stop_generate:
                return

            # Extract only the newly generated tokens and decode them.
            text_ids = output[:, input_ids.shape[1] :]
            text = self.model.tokenizer.decode(text_ids[0], skip_special_tokens=True)

            with self._state_lock:
                self._state.generated_text = text

            self._emit_text(text, interim=False)

            if self._state.stop_generate:
                return

            # Synthesize the audio response, streamed chunk by chunk.
            if self.model.audio_head is not None:
                self._set_state(ConversationState.SPEAKING)

                for audio_chunk in self.model.audio_head.generate_streaming(
                    text_token_ids=text_ids,
                ):
                    # Abort mid-stream on interrupt.
                    if self._state.stop_generate:
                        return
                    self._emit_audio(audio_chunk)

            self._set_state(ConversationState.LISTENING)

        except Exception as e:
            # Never let a generation failure kill the session; go back to listening.
            logger.error(f"Generation error: {e}")
            self._set_state(ConversationState.LISTENING)

        finally:
            with self._state_lock:
                self._state.is_generating = False
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4df14201e66792d6b8ccefd124852fdc5d47f013a6276f9973540d63956097ad
|
| 3 |
+
size 122303840
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chunk_length": 30,
|
| 3 |
+
"dither": 0.0,
|
| 4 |
+
"feature_extractor_type": "WhisperFeatureExtractor",
|
| 5 |
+
"feature_size": 128,
|
| 6 |
+
"hop_length": 160,
|
| 7 |
+
"n_fft": 400,
|
| 8 |
+
"n_samples": 480000,
|
| 9 |
+
"nb_max_frames": 3000,
|
| 10 |
+
"padding": false,
|
| 11 |
+
"padding_side": "right",
|
| 12 |
+
"padding_value": 0.0,
|
| 13 |
+
"return_attention_mask": false,
|
| 14 |
+
"sampling_rate": 16000,
|
| 15 |
+
"processor_class": "ASRProcessor",
|
| 16 |
+
"auto_map": {
|
| 17 |
+
"AutoProcessor": "asr_processing.ASRProcessor"
|
| 18 |
+
}
|
| 19 |
+
}
|
projectors.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio projector module for bridging encoder and decoder embeddings.
|
| 2 |
+
|
| 3 |
+
MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        """Build the projector.

        Args:
            config: ASRConfig providing encoder_dim, llm_dim,
                projector_pool_stride and optionally projector_hidden_dim
                (all read with defaults if absent).
        """
        super().__init__()

        d_enc = getattr(config, "encoder_dim", 768)
        d_llm = getattr(config, "llm_dim", 2048)
        # Number of adjacent encoder frames merged into one projected frame.
        self.k = getattr(config, "projector_pool_stride", 4)

        # A missing (or falsy) projector_hidden_dim falls back to the LLM width.
        d_hidden = getattr(config, "projector_hidden_dim", None) or d_llm

        # Frame stacking widens the input by a factor of k before the MLP.
        self.linear_1 = nn.Linear(d_enc * self.k, d_hidden, bias=False)
        self.norm = LlamaRMSNorm(d_hidden, eps=1e-6)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_llm, bias=False)

    def get_output_length(self, input_length: int) -> int:
        """Return the projected sequence length for `input_length` input frames.

        Uses the GLM-ASR formula (L - k) // k + 1: trailing frames that do not
        fill a complete k-frame window are dropped.
        """
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio encoder features into the LLM embedding space.

        Args:
            x: Encoder output of shape [batch, seq_len, encoder_dim].

        Returns:
            Tensor of shape [batch, (seq_len - k) // k + 1, llm_dim].
        """
        batch, seq, dim = x.shape
        n_out = (seq - self.k) // self.k + 1
        # Drop the incomplete trailing window, then fold every k adjacent
        # frames into a single wide feature vector.
        stacked = x[:, : n_out * self.k, :].reshape(batch, n_out, dim * self.k)
        return self.linear_2(self.act(self.norm(self.linear_1(stacked))))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Registry mapping projector-type identifiers (as referenced from the model
# config) to their implementing classes.
PROJECTOR_CLASSES = dict(mlp=MLPAudioProjector)
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
|
| 3 |
+
size 17209003
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": null,
|
| 4 |
+
"clean_up_tokenization_spaces": true,
|
| 5 |
+
"eos_token": "<|im_end|>",
|
| 6 |
+
"extra_special_tokens": [
|
| 7 |
+
"<audio>"
|
| 8 |
+
],
|
| 9 |
+
"fast": false,
|
| 10 |
+
"is_local": false,
|
| 11 |
+
"model_input_names": [
|
| 12 |
+
"input_ids",
|
| 13 |
+
"attention_mask"
|
| 14 |
+
],
|
| 15 |
+
"model_max_length": 131072,
|
| 16 |
+
"model_specific_special_tokens": {},
|
| 17 |
+
"pad_token": "<|finetune_right_pad_id|>",
|
| 18 |
+
"tokenizer_class": "TokenizersBackend"
|
| 19 |
+
}
|