Update custom model files, README, and requirements
- README.md +49 -185
- asr_config.py +24 -3
- asr_modeling.py +250 -1
- asr_pipeline.py +27 -47
- asr_processing.py +9 -2
README.md
CHANGED
@@ -1,207 +1,71 @@
 ---
 tags:
 ---
-<!-- Provide a longer summary of what this model is. -->
-- **Developed by:** [More Information Needed]
-- **Funded by [optional]:** [More Information Needed]
-- **Shared by [optional]:** [More Information Needed]
-- **Model type:** [More Information Needed]
-- **Language(s) (NLP):** [More Information Needed]
-- **License:** [More Information Needed]
-- **Finetuned from model [optional]:** [More Information Needed]
-### Model Sources [optional]
-<!-- Provide the basic links for the model. -->
-- **Repository:** [More Information Needed]
-- **Paper [optional]:** [More Information Needed]
-- **Demo [optional]:** [More Information Needed]
-## Uses
-<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
-### Direct Use
-<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
-[More Information Needed]
-### Downstream Use [optional]
-<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
-[More Information Needed]
-### Out-of-Scope Use
-<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
-[More Information Needed]
-## Bias, Risks, and Limitations
-<!-- This section is meant to convey both technical and sociotechnical limitations. -->
-[More Information Needed]
-### Recommendations
-<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
-Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
-## How to Get Started with the Model
-Use the code below to get started with the model.
-[More Information Needed]
 ## Training Details
-<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
-#### Preprocessing [optional]
-[More Information Needed]
-#### Training Hyperparameters
-- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
-#### Speeds, Sizes, Times [optional]
-<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
-[More Information Needed]
-## Evaluation
-<!-- This section describes the evaluation protocols and provides the results. -->
-### Testing Data, Factors & Metrics
-#### Testing Data
-<!-- This should link to a Dataset Card if possible. -->
-[More Information Needed]
-#### Factors
-<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
-[More Information Needed]
-#### Metrics
-<!-- These are the evaluation metrics being used, ideally with a description of why. -->
-[More Information Needed]
-### Results
-[More Information Needed]
-#### Summary
-## Model Examination [optional]
-<!-- Relevant interpretability work for the model goes here -->
-[More Information Needed]
-## Environmental Impact
-<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
-Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
-- **Hardware Type:** [More Information Needed]
-- **Hours used:** [More Information Needed]
-- **Cloud Provider:** [More Information Needed]
-- **Compute Region:** [More Information Needed]
-- **Carbon Emitted:** [More Information Needed]
-## Technical Specifications [optional]
-### Model Architecture and Objective
-[More Information Needed]
-### Compute Infrastructure
-[More Information Needed]
-#### Hardware
-[More Information Needed]
-#### Software
-[More Information Needed]
-## Citation [optional]
-<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
-**BibTeX:**
-[More Information Needed]
-**APA:**
-<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
-### Framework versions

 ---
+license: mit
+language:
+- en
+datasets:
+- speechbrain/LoquaciousSet
+base_model:
+- openai/whisper-large-v3-turbo
+- HuggingFaceTB/SmolLM3-3B
+pipeline_tag: automatic-speech-recognition
 tags:
+- asr
+- speech-recognition
+- audio
+- smollm
+- whisper
+- mlp
 ---

+# Tiny Audio

+A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with the [Tiny Audio](https://github.com/alexkroman/tiny-audio) codebase—a minimal, hackable framework for training ASR models.

+## Architecture

+```
+Audio (16kHz) → Whisper Encoder (frozen) → MLP Projector (trained) → SmolLM3-3B (frozen) → Text
+```

+**MLP Projector:**
+- Convolutional downsampling: 4x sequence compression via two stride-2 conv layers
+- Linear (1280 → 2048) → GELU → Linear (2048 → 2048)
+- Output normalization: RMSNorm

 ## Training Details

+| | |
+|---|---|
+| **Dataset** | LoquaciousSet (25,000 hours) |
+| **Hardware** | Single NVIDIA A40 40GB |
+| **Training Time** | ~24 hours |
+| **Cost** | ~$12 |
+| **Trainable Parameters** | ~12M (projector only) |

+## Performance

+**Word Error Rate (WER): 12.14%** on LoquaciousSet test set.

+## Usage

+```python
+from transformers import pipeline

+pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)

+result = pipe("path/to/audio.wav")
+print(result["text"])
+```

+## Limitations

+- English only
+- Optimized for 16kHz audio; other sample rates are resampled automatically
+- Performance may degrade on heavily accented speech, noisy environments, or domain-specific jargon
+- Maximum audio length limited by context window

+## Learn More

+- **[Train your own model](https://github.com/alexkroman/tiny-audio)** — The full codebase with training scripts
+- **[Free 3.5-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)** — Build your own ASR system from scratch
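The MLP projector described in the new Architecture section is small enough to sketch end to end. Below is a minimal, illustrative PyTorch sketch of that stack (two stride-2 convolutions for 4x downsampling, Linear 1280 → 2048, GELU, Linear 2048 → 2048, RMSNorm); the class name, kernel sizes, and exact layer ordering are assumptions for illustration, not the repo's projectors.py implementation.

```python
# Illustrative sketch only; see projectors.py in the repo for the real implementation.
# Assumes Whisper-large-v3 encoder width (1280) and SmolLM3-3B hidden size (2048).
# nn.RMSNorm requires PyTorch >= 2.4.
import torch
import torch.nn as nn


class MLPProjectorSketch(nn.Module):
    def __init__(self, encoder_dim: int = 1280, llm_dim: int = 2048):
        super().__init__()
        # Two stride-2 convolutions -> 4x sequence compression
        self.conv1 = nn.Conv1d(encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(encoder_dim, llm_dim)
        self.fc2 = nn.Linear(llm_dim, llm_dim)
        self.act = nn.GELU()
        self.norm = nn.RMSNorm(llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, encoder_dim) from the frozen Whisper encoder
        x = x.transpose(1, 2)                # (batch, encoder_dim, time) for Conv1d
        x = self.act(self.conv2(self.act(self.conv1(x))))
        x = x.transpose(1, 2)                # (batch, time // 4, encoder_dim)
        x = self.fc2(self.act(self.fc1(x)))  # project into the LLM embedding space
        return self.norm(x)


# Shape check: 100 encoder frames compress to 25 audio tokens of LLM width.
out = MLPProjectorSketch()(torch.randn(1, 100, 1280))
print(out.shape)  # torch.Size([1, 25, 2048])
```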
asr_config.py
CHANGED
@@ -22,12 +22,12 @@ class ASRConfig(transformers.PretrainedConfig):
         # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
         encoder_conv_layers: Optional[list] = None,
         audio_sample_rate: int = 16000,
-        projector_init_std: float = 0.02,
         projector_pool_stride: int = 4,
         downsample_rate: int = 5,  # Granite default
         projector_hidden_dim: Optional[int] = None,
-        projector_type: str = "
-        projector_num_layers: int = 2,  # Number of layers
+        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
+        projector_num_layers: int = 2,  # Number of layers in MLP projector
+        projector_init_std: float = 0.02,  # Weight initialization std
         projector_dropout: float = 0.0,  # Dropout rate for projector layers
         # MoE-specific configuration
         num_experts: int = 4,  # Number of experts in MoE projectors

@@ -41,7 +41,16 @@
         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_warmup_tokens: int = 10,
+        # SpecAugment settings (Whisper defaults)
+        use_specaugment: bool = False,
+        mask_time_prob: float = 0.05,  # Probability of masking time steps
+        mask_time_length: int = 10,  # Max length of time mask
+        mask_time_min_masks: int = 2,  # Min number of time masks
+        mask_feature_prob: float = 0.0,  # Probability of masking frequency bins (disabled by default)
+        mask_feature_length: int = 10,  # Max length of frequency mask
+        mask_feature_min_masks: int = 0,  # Min number of frequency masks
         max_new_tokens: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         length_penalty: Optional[float] = None,
         no_repeat_ngram_size: Optional[int] = None,

@@ -52,6 +61,7 @@
         generation_defaults = {
             "num_beams": 1,
             "max_new_tokens": 256,
+            "min_new_tokens": 1,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,

@@ -91,12 +101,23 @@
         self.qformer_intermediate_size = qformer_intermediate_size
         self.label_smoothing = label_smoothing
         self.inference_warmup_tokens = inference_warmup_tokens
+        # SpecAugment configuration
+        self.use_specaugment = use_specaugment
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks

         # Generation parameters (use explicit value if provided, else use default)
         self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
         self.max_new_tokens = (
             max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
         )
+        self.min_new_tokens = (
+            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+        )
         self.repetition_penalty = (
             repetition_penalty
             if repetition_penalty is not None
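The new fields follow the existing pattern (an explicit value wins, otherwise the entry in generation_defaults applies). A hedged usage sketch, assuming asr_config.py is importable from the model directory; the chosen values are illustrative:

```python
from asr_config import ASRConfig  # assumption: running from the repo / model directory

config = ASRConfig(
    projector_type="mlp",
    use_specaugment=True,   # turn on time masking during training
    mask_time_prob=0.05,    # Whisper-style default from the diff above
    mask_time_length=10,
    mask_feature_prob=0.0,  # frequency masking stays disabled
    min_new_tokens=1,       # falls back to generation_defaults["min_new_tokens"] if omitted
)
print(config.use_specaugment, config.mask_time_min_masks, config.min_new_tokens)
# True 2 1
```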
asr_modeling.py
CHANGED
@@ -1,6 +1,7 @@
 import json
 from pathlib import Path
+from threading import Thread
+from typing import Iterator, Optional, Union

 import torch
 import torch.nn as nn

@@ -10,6 +11,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
+    TextIteratorStreamer,
 )
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast

@@ -22,6 +24,122 @@
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]


+def _compute_mask_indices(
+    shape: tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    min_masks: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Compute random mask spans for SpecAugment.
+
+    Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper.
+
+    Args:
+        shape: (batch_size, sequence_length)
+        mask_prob: Probability for each token to be chosen as start of mask span
+        mask_length: Maximum length of mask span
+        min_masks: Minimum number of masks per sample
+        device: Device to create tensor on
+
+    Returns:
+        Boolean mask tensor of shape (batch_size, sequence_length)
+    """
+    batch_size, sequence_length = shape
+
+    if mask_length < 1:
+        raise ValueError(f"mask_length must be >= 1, got {mask_length}")
+
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"mask_length {mask_length} must be <= sequence_length {sequence_length}"
+        )
+
+    # Compute number of masked spans per sample
+    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
+    num_masked_spans = max(num_masked_spans, min_masks)
+
+    # Clamp to ensure we don't exceed sequence length
+    if num_masked_spans * mask_length > sequence_length:
+        num_masked_spans = sequence_length // mask_length
+
+    if num_masked_spans == 0:
+        return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
+
+    # Uniformly sample span start indices
+    mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
+
+    for i in range(batch_size):
+        # Random start indices for this sample
+        spec_aug_start_indices = torch.randint(
+            0, sequence_length - mask_length + 1, (num_masked_spans,), device=device
+        )
+
+        # Create mask spans
+        for start_idx in spec_aug_start_indices:
+            mask[i, start_idx : start_idx + mask_length] = True
+
+    return mask
+
+
+def apply_specaugment(
+    input_features: torch.Tensor,
+    mask_time_prob: float = 0.05,
+    mask_time_length: int = 10,
+    mask_time_min_masks: int = 2,
+    mask_feature_prob: float = 0.0,
+    mask_feature_length: int = 10,
+    mask_feature_min_masks: int = 0,
+) -> torch.Tensor:
+    """Apply SpecAugment to mel spectrogram features.
+
+    Args:
+        input_features: Mel spectrogram of shape (batch, n_mels, time)
+        mask_time_prob: Probability of masking time steps
+        mask_time_length: Max length of time mask
+        mask_time_min_masks: Min number of time masks
+        mask_feature_prob: Probability of masking frequency bins
+        mask_feature_length: Max length of frequency mask
+        mask_feature_min_masks: Min number of frequency masks
+
+    Returns:
+        Augmented mel spectrogram with same shape
+    """
+    batch_size, n_mels, time_steps = input_features.shape
+    device = input_features.device
+
+    # Clone to avoid modifying original
+    augmented = input_features.clone()
+
+    # Time masking (along time dimension)
+    if mask_time_prob > 0:
+        time_mask = _compute_mask_indices(
+            shape=(batch_size, time_steps),
+            mask_prob=mask_time_prob,
+            mask_length=mask_time_length,
+            min_masks=mask_time_min_masks,
+            device=device,
+        )
+        # Expand to (batch, 1, time) for broadcasting
+        time_mask = time_mask.unsqueeze(1)
+        augmented = augmented.masked_fill(time_mask, 0.0)
+
+    # Frequency masking (along mel dimension)
+    if mask_feature_prob > 0:
+        feature_mask = _compute_mask_indices(
+            shape=(batch_size, n_mels),
+            mask_prob=mask_feature_prob,
+            mask_length=mask_feature_length,
+            min_masks=mask_feature_min_masks,
+            device=device,
+        )
+        # Expand to (batch, n_mels, 1) for broadcasting
+        feature_mask = feature_mask.unsqueeze(2)
+        augmented = augmented.masked_fill(feature_mask, 0.0)
+
+    return augmented
+
+
 class ASRModel(PreTrainedModel, GenerationMixin):
     """Audio-to-text model combining an audio encoder, projector, and language model."""

@@ -110,6 +228,7 @@
         # Set up generation config with greedy decoding defaults
         self.generation_config = self.language_model.generation_config
         self.generation_config.max_new_tokens = config.max_new_tokens
+        self.generation_config.min_new_tokens = config.min_new_tokens
         self.generation_config.num_beams = config.num_beams
         self.generation_config.do_sample = False
         # Clear sampling params (inherited from LLM) since we use greedy decoding

@@ -383,6 +502,18 @@
         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

         if input_features is not None and input_ids is not None:
+            # Apply SpecAugment during training if enabled
+            if self.training and getattr(self.config, "use_specaugment", False):
+                input_features = apply_specaugment(
+                    input_features,
+                    mask_time_prob=self.config.mask_time_prob,
+                    mask_time_length=self.config.mask_time_length,
+                    mask_time_min_masks=self.config.mask_time_min_masks,
+                    mask_feature_prob=self.config.mask_feature_prob,
+                    mask_feature_length=self.config.mask_feature_length,
+                    mask_feature_min_masks=self.config.mask_feature_min_masks,
+                )
+
             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
             audio_embeds = self._encode_audio(input_features, audio_attention_mask)

@@ -515,6 +646,120 @@
             return output
         return output.sequences

+    def generate_streaming(
+        self,
+        input_features: torch.Tensor,
+        audio_attention_mask: torch.Tensor,
+        system_prompt: Optional[str] = None,
+        **generate_kwargs,
+    ) -> Iterator[str]:
+        """Generate transcription with streaming token output.
+
+        Yields partial transcript strings as tokens are generated.
+        Reduces time-to-first-word by streaming tokens as they're decoded.
+
+        Args:
+            input_features: Mel spectrogram features (batch, n_mels, mel_len)
+            audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
+            system_prompt: Optional system prompt override
+            **generate_kwargs: Additional generation arguments
+
+        Yields:
+            Partial transcript text as each token is generated
+        """
+        device = input_features.device
+        batch_size = input_features.shape[0]
+
+        # Encode audio -> flattened embeddings
+        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+        # Build prompt with correct number of audio tokens
+        num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
+        audio_placeholder = "<audio>" * num_audio_tokens
+
+        system_prompt = system_prompt or self.system_prompt
+
+        messages: list[dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+        chat_result = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+        )
+        input_ids = chat_result.input_ids.to(device)
+
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+        if input_ids.shape[0] == 1 and batch_size > 1:
+            input_ids = input_ids.expand(batch_size, -1)
+
+        attention_mask = torch.ones_like(input_ids)
+
+        # Get text embeddings and replace audio tokens with audio embeddings
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            audio_token_mask.to(inputs_embeds.device),
+            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+        )
+
+        # Setup streamer for token-by-token output
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+        )
+
+        # Prepare generation kwargs
+        gen_kwargs = {
+            "inputs_embeds": inputs_embeds,
+            "attention_mask": attention_mask,
+            "generation_config": self.generation_config,
+            "streamer": streamer,
+            **generate_kwargs,
+        }
+
+        # Run generation in background thread
+        thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        # Yield tokens as they're generated, filtering out <think>...</think> blocks
+        # SmolLM3 always starts in thinking mode, so assume we're in a think block
+        in_think_block = True
+        buffer = ""
+
+        for text in streamer:
+            buffer += text
+
+            # Check for think block start (in case model outputs multiple think blocks)
+            while "<think>" in buffer:
+                in_think_block = True
+                # Yield any text before <think>
+                before_think = buffer.split("<think>")[0]
+                if before_think:
+                    yield before_think
+                buffer = buffer.split("<think>", 1)[-1]
+
+            # Check for think block end
+            while in_think_block and "</think>" in buffer:
+                in_think_block = False
+                buffer = buffer.split("</think>", 1)[-1]
+
+            # Yield text if not in think block
+            if not in_think_block and buffer:
+                yield buffer
+                buffer = ""
+
+        # Yield any remaining buffer
+        if buffer and not in_think_block:
+            yield buffer
+
+        thread.join()
+
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         """Save model, tokenizer, and processor."""
         import shutil

@@ -568,6 +813,10 @@
         # Copy projectors module
         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")

+    def create_or_update_model_card(self, output_dir: Union[str, Path]):
+        """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
+        pass
+

 # Register with transformers Auto classes
 AutoConfig.register("asr_model", ASRConfig)
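A hedged sketch of calling the new streaming path. The method name generate_streaming and its arguments come from the diff above; the preprocessing step (the processor call and its output key names) is an assumption and may differ from the repo's actual ASRProcessor.

```python
# Illustrative only: assumes `model` is a loaded ASRModel and `processor` its processor,
# and that the processor returns "input_features" plus an attention mask for mel frames.
import librosa

audio, _ = librosa.load("path/to/audio.wav", sr=16000)
inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")

for partial in model.generate_streaming(
    input_features=inputs["input_features"],
    audio_attention_mask=inputs["attention_mask"],
):
    print(partial, end="", flush=True)  # transcript appears as tokens are decoded
print()
```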
asr_pipeline.py
CHANGED
@@ -476,57 +476,37 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
         text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
-        # Truncate if a word repeats more than 3 times consecutively
-        text = self._truncate_repetitions(text, max_repeats=3)
+        # Post-process prediction
+        text = self._post_process_prediction(text)
         return {"text": text}

-        Returns:
-            Text with acronyms collapsed
-        """
-        # Match 2+ single letters (case-insensitive) separated by spaces
-        # Pattern: single letter, then one or more (space + single letter)
-        pattern = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"
-
-        def collapse_match(match: re.Match) -> str:
-            # Get the full match and remove spaces
-            full = match.group(0)
-            return full.replace(" ", "").upper()
-
-        return re.sub(pattern, collapse_match, text)
-
-    def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
-        """Truncate text when a word repeats more than max_repeats times consecutively.
-
-        Args:
-            text: Input text to check for repetitions
-            max_repeats: Maximum allowed consecutive repetitions (default 3)
-
-        Returns:
-            Truncated text if repetition detected, otherwise original text
-        """
+    def _post_process_prediction(self, text: str) -> str:
+        """Post-process model output to fix common issues."""
+        if not text:
+            return ""
+
+        # 1. LOWERCASE
+        text = text.lower()
+
+        # 2. REMOVE REPETITIVE LOOPS
+        # If the model repeats the same phrase more than twice, cut it off.
         words = text.split()
+        if len(words) > 10:
+            # Check for repeating n-grams (1 to 4 words long)
+            for n in range(1, 5):
+                last_sequence = words[-n:]
+                repeat_count = 0
+                idx = len(words) - n
+                while idx >= n and words[idx - n : idx] == last_sequence:
+                    repeat_count += 1
+                    idx -= n
+
+                # If more than 2 exact repetitions at the end, truncate
+                if repeat_count > 2:
+                    text = " ".join(words[: idx + n])
+                    break
+
+        # 3. STRIP WHITESPACE
+        text = re.sub(r'\s+', ' ', text).strip()

         return text
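To see what the repetition check above actually does, here is the same logic as a standalone function with one worked example; the function name and the test sentence are illustrative only.

```python
import re


def post_process(text: str) -> str:
    """Standalone copy of the post-processing logic shown in the diff above."""
    if not text:
        return ""
    text = text.lower()
    words = text.split()
    if len(words) > 10:
        for n in range(1, 5):  # n-grams of 1 to 4 words
            last_sequence = words[-n:]
            repeat_count = 0
            idx = len(words) - n
            while idx >= n and words[idx - n : idx] == last_sequence:
                repeat_count += 1
                idx -= n
            if repeat_count > 2:  # more than 2 trailing repetitions -> truncate
                text = " ".join(words[: idx + n])
                break
    return re.sub(r"\s+", " ", text).strip()


# A looping transcription gets cut back to a single occurrence of the phrase:
print(post_process("the meeting starts at nine " + "thank you " * 6))
# -> "the meeting starts at nine thank you"
```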
asr_processing.py
CHANGED
@@ -94,14 +94,21 @@ class ASRProcessor(ProcessorMixin):
             messages.append({"role": "assistant", "content": text})

         # Tokenize
-
+        tokenized = self.tokenizer.apply_chat_template(
             messages,
             tokenize=True,
             add_generation_prompt=(text is None),
             return_tensors=return_tensors,
         )

-
+        # Handle both tensor and BatchEncoding returns
+        if isinstance(tokenized, torch.Tensor):
+            input_ids = tokenized
+        else:
+            # BatchEncoding or dict-like object
+            input_ids = tokenized["input_ids"] if "input_ids" in tokenized else tokenized.input_ids
+
+        if input_ids.dim() == 1:
             input_ids = input_ids.unsqueeze(0)

         result["input_ids"] = input_ids
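For context on why the isinstance guard is needed (an illustration, not part of the commit): depending on the transformers version and the arguments used, tokenizer.apply_chat_template can hand back either a bare tensor of token ids or a BatchEncoding, for example when return_dict=True is passed. A small illustrative check, using a model id from the README as an example of a tokenizer with a chat template:

```python
# Illustrative only; exact return types can vary across transformers versions.
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
messages = [{"role": "user", "content": "Transcribe the audio."}]

ids = tok.apply_chat_template(messages, tokenize=True, return_tensors="pt")
enc = tok.apply_chat_template(messages, tokenize=True, return_tensors="pt", return_dict=True)

print(isinstance(ids, torch.Tensor))  # typically True: a bare tensor of token ids
print("input_ids" in enc)             # True: dict-like BatchEncoding with input_ids
```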