mazesmazes committed
Commit 71301b8 · verified · Parent(s): ab03ae3

Update custom model files, README, and requirements

Files changed (5)
  1. README.md +49 -185
  2. asr_config.py +24 -3
  3. asr_modeling.py +250 -1
  4. asr_pipeline.py +27 -47
  5. asr_processing.py +9 -2
README.md CHANGED
@@ -1,207 +1,71 @@
1
  ---
2
- base_model: Qwen/Qwen3-1.7B
3
- library_name: peft
4
- pipeline_tag: text-generation
5
  tags:
6
- - base_model:adapter:Qwen/Qwen3-1.7B
7
- - lora
8
- - transformers
 
 
 
9
  ---
10
 
11
- # Model Card for Model ID
12
 
13
- <!-- Provide a quick summary of what the model is/does. -->
14
 
 
15
 
 
 
 
16
 
17
- ## Model Details
18
-
19
- ### Model Description
20
-
21
- <!-- Provide a longer summary of what this model is. -->
22
-
23
-
24
-
25
- - **Developed by:** [More Information Needed]
26
- - **Funded by [optional]:** [More Information Needed]
27
- - **Shared by [optional]:** [More Information Needed]
28
- - **Model type:** [More Information Needed]
29
- - **Language(s) (NLP):** [More Information Needed]
30
- - **License:** [More Information Needed]
31
- - **Finetuned from model [optional]:** [More Information Needed]
32
-
33
- ### Model Sources [optional]
34
-
35
- <!-- Provide the basic links for the model. -->
36
-
37
- - **Repository:** [More Information Needed]
38
- - **Paper [optional]:** [More Information Needed]
39
- - **Demo [optional]:** [More Information Needed]
40
-
41
- ## Uses
42
-
43
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
44
-
45
- ### Direct Use
46
-
47
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
48
-
49
- [More Information Needed]
50
-
51
- ### Downstream Use [optional]
52
-
53
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
54
-
55
- [More Information Needed]
56
-
57
- ### Out-of-Scope Use
58
-
59
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
60
-
61
- [More Information Needed]
62
-
63
- ## Bias, Risks, and Limitations
64
-
65
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
66
-
67
- [More Information Needed]
68
-
69
- ### Recommendations
70
-
71
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
72
-
73
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
74
-
75
- ## How to Get Started with the Model
76
-
77
- Use the code below to get started with the model.
78
-
79
- [More Information Needed]
80
 
81
  ## Training Details
82
 
83
- ### Training Data
84
-
85
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
86
-
87
- [More Information Needed]
88
-
89
- ### Training Procedure
90
-
91
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
92
-
93
- #### Preprocessing [optional]
94
-
95
- [More Information Needed]
96
-
97
-
98
- #### Training Hyperparameters
99
-
100
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
101
-
102
- #### Speeds, Sizes, Times [optional]
103
-
104
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
105
-
106
- [More Information Needed]
107
-
108
- ## Evaluation
109
-
110
- <!-- This section describes the evaluation protocols and provides the results. -->
111
-
112
- ### Testing Data, Factors & Metrics
113
-
114
- #### Testing Data
115
-
116
- <!-- This should link to a Dataset Card if possible. -->
117
-
118
- [More Information Needed]
119
-
120
- #### Factors
121
-
122
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
123
-
124
- [More Information Needed]
125
-
126
- #### Metrics
127
-
128
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
129
-
130
- [More Information Needed]
131
-
132
- ### Results
133
-
134
- [More Information Needed]
135
-
136
- #### Summary
137
-
138
-
139
-
140
- ## Model Examination [optional]
141
-
142
- <!-- Relevant interpretability work for the model goes here -->
143
-
144
- [More Information Needed]
145
-
146
- ## Environmental Impact
147
-
148
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
149
-
150
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
151
-
152
- - **Hardware Type:** [More Information Needed]
153
- - **Hours used:** [More Information Needed]
154
- - **Cloud Provider:** [More Information Needed]
155
- - **Compute Region:** [More Information Needed]
156
- - **Carbon Emitted:** [More Information Needed]
157
-
158
- ## Technical Specifications [optional]
159
-
160
- ### Model Architecture and Objective
161
-
162
- [More Information Needed]
163
-
164
- ### Compute Infrastructure
165
-
166
- [More Information Needed]
167
-
168
- #### Hardware
169
-
170
- [More Information Needed]
171
-
172
- #### Software
173
-
174
- [More Information Needed]
175
-
176
- ## Citation [optional]
177
-
178
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
179
-
180
- **BibTeX:**
181
-
182
- [More Information Needed]
183
-
184
- **APA:**
185
 
186
- [More Information Needed]
187
 
188
- ## Glossary [optional]
189
 
190
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
191
 
192
- [More Information Needed]
193
 
194
- ## More Information [optional]
 
195
 
196
- [More Information Needed]
197
 
198
- ## Model Card Authors [optional]
 
 
199
 
200
- [More Information Needed]
201
 
202
- ## Model Card Contact
 
 
 
203
 
204
- [More Information Needed]
205
- ### Framework versions
206
 
207
- - PEFT 0.18.0
 
 
1
  ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ datasets:
6
+ - speechbrain/LoquaciousSet
7
+ base_model:
8
+ - openai/whisper-large-v3-turbo
9
+ - HuggingFaceTB/SmolLM3-3B
10
+ pipeline_tag: automatic-speech-recognition
11
  tags:
12
+ - asr
13
+ - speech-recognition
14
+ - audio
15
+ - smollm
16
+ - whisper
17
+ - mlp
18
  ---
19
 
20
+ # Tiny Audio
21
 
22
+ A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with the [Tiny Audio](https://github.com/alexkroman/tiny-audio) codebase—a minimal, hackable framework for training ASR models.
23
 
24
+ ## Architecture
25
 
26
+ ```
27
+ Audio (16kHz) → Whisper Encoder (frozen) → MLP Projector (trained) → SmolLM3-3B (frozen) → Text
28
+ ```
29
 
30
+ **MLP Projector:**
31
+ - Convolutional downsampling: 4x sequence compression via two stride-2 conv layers
32
+ - Linear (1280 → 2048) → GELU → Linear (2048 → 2048)
33
+ - Output normalization: RMSNorm
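
A minimal PyTorch sketch of the projector shape described above (layer names, conv widths, and padding are assumptions rather than the repository's exact module; `nn.RMSNorm` needs PyTorch ≥ 2.4):

```python
import torch
import torch.nn as nn

class ProjectorSketch(nn.Module):
    """Illustrative only: Whisper hidden size 1280 -> SmolLM3 hidden size 2048."""

    def __init__(self, enc_dim: int = 1280, llm_dim: int = 2048):
        super().__init__()
        # Two stride-2 convolutions along time give the 4x token compression
        self.conv1 = nn.Conv1d(enc_dim, enc_dim, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(enc_dim, enc_dim, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(enc_dim, llm_dim)
        self.fc2 = nn.Linear(llm_dim, llm_dim)
        self.norm = nn.RMSNorm(llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (batch, frames, enc_dim)
        x = self.conv2(self.conv1(x.transpose(1, 2)))      # (batch, enc_dim, frames // 4)
        x = x.transpose(1, 2)                               # (batch, frames // 4, enc_dim)
        return self.norm(self.fc2(nn.functional.gelu(self.fc1(x))))

# Whisper-large-v3 emits 1500 encoder frames per 30 s clip -> 375 projected tokens
print(ProjectorSketch()(torch.randn(1, 1500, 1280)).shape)  # torch.Size([1, 375, 2048])
```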
34
 
35
  ## Training Details
36
 
37
+ | | |
38
+ |---|---|
39
+ | **Dataset** | LoquaciousSet (25,000 hours) |
40
+ | **Hardware** | Single NVIDIA A40 48GB |
41
+ | **Training Time** | ~24 hours |
42
+ | **Cost** | ~$12 |
43
+ | **Trainable Parameters** | ~12M (projector only) |
44
 
45
+ ## Performance
46
 
47
+ **Word Error Rate (WER): 12.14%** on the LoquaciousSet test set.
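
WER is the usual edit-distance metric (substitutions + insertions + deletions over reference words); to reproduce the number on your own split you can use the `jiwer` package (text normalization choices here are assumptions):

```python
import jiwer

references = ["the quick brown fox", "hello world"]
hypotheses = ["the quick brown fox", "hello word"]

# 1 substitution across 6 reference words -> WER ~= 0.167
print(jiwer.wer(references, hypotheses))
```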
48
 
 
49
 
50
+ ## Usage
51
 
52
+ ```python
53
+ from transformers import pipeline
54
 
55
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
56
 
57
+ result = pipe("path/to/audio.wav")
58
+ print(result["text"])
59
+ ```
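
The pipeline also accepts in-memory waveforms in the standard `AutomaticSpeechRecognitionPipeline` input formats (sketched below, reusing `pipe` from the snippet above with a placeholder array; real audio should be a mono float array at 16 kHz):

```python
import numpy as np

audio = np.zeros(16_000 * 5, dtype=np.float32)  # replace with a real waveform
result = pipe({"array": audio, "sampling_rate": 16_000})
print(result["text"])
```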
60
 
61
+ ## Limitations
62
 
63
+ - English only
64
+ - Optimized for 16kHz audio; other sample rates are resampled automatically
65
+ - Performance may degrade on heavily accented speech, noisy environments, or domain-specific jargon
66
+ - Maximum audio length limited by context window
67
 
68
+ ## Learn More
 
69
 
70
+ - **[Train your own model](https://github.com/alexkroman/tiny-audio)** — The full codebase with training scripts
71
+ - **[Free 3.5-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)** — Build your own ASR system from scratch
asr_config.py CHANGED
@@ -22,12 +22,12 @@ class ASRConfig(transformers.PretrainedConfig):
22
  # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
23
  encoder_conv_layers: Optional[list] = None,
24
  audio_sample_rate: int = 16000,
25
- projector_init_std: float = 0.02,
26
  projector_pool_stride: int = 4,
27
  downsample_rate: int = 5, # Granite default
28
  projector_hidden_dim: Optional[int] = None,
29
- projector_type: str = "moe", # "moe", "swiglu", "residual", "shared_moe", "mlp", "qformer"
30
- projector_num_layers: int = 2, # Number of layers (for residual projector)
 
31
  projector_dropout: float = 0.0, # Dropout rate for projector layers
32
  # MoE-specific configuration
33
  num_experts: int = 4, # Number of experts in MoE projectors
@@ -41,7 +41,16 @@ class ASRConfig(transformers.PretrainedConfig):
41
  qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
42
  label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
43
  inference_warmup_tokens: int = 10,
44
  max_new_tokens: Optional[int] = None,
 
45
  repetition_penalty: Optional[float] = None,
46
  length_penalty: Optional[float] = None,
47
  no_repeat_ngram_size: Optional[int] = None,
@@ -52,6 +61,7 @@ class ASRConfig(transformers.PretrainedConfig):
52
  generation_defaults = {
53
  "num_beams": 1,
54
  "max_new_tokens": 256,
 
55
  "repetition_penalty": 1.0,
56
  "length_penalty": 1.0,
57
  "no_repeat_ngram_size": 0,
@@ -91,12 +101,23 @@ class ASRConfig(transformers.PretrainedConfig):
91
  self.qformer_intermediate_size = qformer_intermediate_size
92
  self.label_smoothing = label_smoothing
93
  self.inference_warmup_tokens = inference_warmup_tokens
94
 
95
  # Generation parameters (use explicit value if provided, else use default)
96
  self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
97
  self.max_new_tokens = (
98
  max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
99
  )
 
 
 
100
  self.repetition_penalty = (
101
  repetition_penalty
102
  if repetition_penalty is not None
 
22
  # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
23
  encoder_conv_layers: Optional[list] = None,
24
  audio_sample_rate: int = 16000,
 
25
  projector_pool_stride: int = 4,
26
  downsample_rate: int = 5, # Granite default
27
  projector_hidden_dim: Optional[int] = None,
28
+ projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
29
+ projector_num_layers: int = 2, # Number of layers in MLP projector
30
+ projector_init_std: float = 0.02, # Weight initialization std
31
  projector_dropout: float = 0.0, # Dropout rate for projector layers
32
  # MoE-specific configuration
33
  num_experts: int = 4, # Number of experts in MoE projectors
 
41
  qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
42
  label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
43
  inference_warmup_tokens: int = 10,
44
+ # SpecAugment settings (Whisper defaults)
45
+ use_specaugment: bool = False,
46
+ mask_time_prob: float = 0.05, # Probability of masking time steps
47
+ mask_time_length: int = 10, # Max length of time mask
48
+ mask_time_min_masks: int = 2, # Min number of time masks
49
+ mask_feature_prob: float = 0.0, # Probability of masking frequency bins (disabled by default)
50
+ mask_feature_length: int = 10, # Max length of frequency mask
51
+ mask_feature_min_masks: int = 0, # Min number of frequency masks
52
  max_new_tokens: Optional[int] = None,
53
+ min_new_tokens: Optional[int] = None,
54
  repetition_penalty: Optional[float] = None,
55
  length_penalty: Optional[float] = None,
56
  no_repeat_ngram_size: Optional[int] = None,
 
61
  generation_defaults = {
62
  "num_beams": 1,
63
  "max_new_tokens": 256,
64
+ "min_new_tokens": 1,
65
  "repetition_penalty": 1.0,
66
  "length_penalty": 1.0,
67
  "no_repeat_ngram_size": 0,
 
101
  self.qformer_intermediate_size = qformer_intermediate_size
102
  self.label_smoothing = label_smoothing
103
  self.inference_warmup_tokens = inference_warmup_tokens
104
+ # SpecAugment configuration
105
+ self.use_specaugment = use_specaugment
106
+ self.mask_time_prob = mask_time_prob
107
+ self.mask_time_length = mask_time_length
108
+ self.mask_time_min_masks = mask_time_min_masks
109
+ self.mask_feature_prob = mask_feature_prob
110
+ self.mask_feature_length = mask_feature_length
111
+ self.mask_feature_min_masks = mask_feature_min_masks
112
 
113
  # Generation parameters (use explicit value if provided, else use default)
114
  self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
115
  self.max_new_tokens = (
116
  max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
117
  )
118
+ self.min_new_tokens = (
119
+ min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
120
+ )
121
  self.repetition_penalty = (
122
  repetition_penalty
123
  if repetition_penalty is not None
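
Taken together, the new fields let SpecAugment and the generation floor be switched on from the config; a minimal sketch (all other constructor arguments keep the defaults shown above):

```python
from asr_config import ASRConfig

config = ASRConfig(
    use_specaugment=True,    # time masking is applied only while the model is training
    mask_time_prob=0.05,     # Whisper-style defaults listed above
    mask_time_length=10,
    mask_time_min_masks=2,
    mask_feature_prob=0.0,   # frequency masking stays disabled
    min_new_tokens=1,        # new generation floor introduced in this commit
)
```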
asr_modeling.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  from pathlib import Path
3
- from typing import Optional, Union
 
4
 
5
  import torch
6
  import torch.nn as nn
@@ -10,6 +11,7 @@ from transformers import (
10
  AutoModelForCausalLM,
11
  AutoTokenizer,
12
  PreTrainedModel,
 
13
  )
14
  from transformers.generation import GenerationMixin
15
  from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -22,6 +24,122 @@ except ImportError:
22
  from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
23
 
24
 
25
  class ASRModel(PreTrainedModel, GenerationMixin):
26
  """Audio-to-text model combining an audio encoder, projector, and language model."""
27
 
@@ -110,6 +228,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
110
  # Set up generation config with greedy decoding defaults
111
  self.generation_config = self.language_model.generation_config
112
  self.generation_config.max_new_tokens = config.max_new_tokens
 
113
  self.generation_config.num_beams = config.num_beams
114
  self.generation_config.do_sample = False
115
  # Clear sampling params (inherited from LLM) since we use greedy decoding
@@ -383,6 +502,18 @@ class ASRModel(PreTrainedModel, GenerationMixin):
383
  inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
384
 
385
  if input_features is not None and input_ids is not None:
386
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
387
  audio_embeds = self._encode_audio(input_features, audio_attention_mask)
388
 
@@ -515,6 +646,120 @@ class ASRModel(PreTrainedModel, GenerationMixin):
515
  return output
516
  return output.sequences
517
 
518
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
519
  """Save model, tokenizer, and processor."""
520
  import shutil
@@ -568,6 +813,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
568
  # Copy projectors module
569
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
570
 
 
 
 
 
571
 
572
  # Register with transformers Auto classes
573
  AutoConfig.register("asr_model", ASRConfig)
 
1
  import json
2
  from pathlib import Path
3
+ from threading import Thread
4
+ from typing import Iterator, Optional, Union
5
 
6
  import torch
7
  import torch.nn as nn
 
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
13
  PreTrainedModel,
14
+ TextIteratorStreamer,
15
  )
16
  from transformers.generation import GenerationMixin
17
  from transformers.modeling_outputs import CausalLMOutputWithPast
 
24
  from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
 
26
 
27
+ def _compute_mask_indices(
28
+ shape: tuple[int, int],
29
+ mask_prob: float,
30
+ mask_length: int,
31
+ min_masks: int = 0,
32
+ device: torch.device = None,
33
+ ) -> torch.Tensor:
34
+ """Compute random mask spans for SpecAugment.
35
+
36
+ Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper.
37
+
38
+ Args:
39
+ shape: (batch_size, sequence_length)
40
+ mask_prob: Probability for each token to be chosen as start of mask span
41
+ mask_length: Maximum length of mask span
42
+ min_masks: Minimum number of masks per sample
43
+ device: Device to create tensor on
44
+
45
+ Returns:
46
+ Boolean mask tensor of shape (batch_size, sequence_length)
47
+ """
48
+ batch_size, sequence_length = shape
49
+
50
+ if mask_length < 1:
51
+ raise ValueError(f"mask_length must be >= 1, got {mask_length}")
52
+
53
+ if mask_length > sequence_length:
54
+ raise ValueError(
55
+ f"mask_length {mask_length} must be <= sequence_length {sequence_length}"
56
+ )
57
+
58
+ # Compute number of masked spans per sample
59
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
60
+ num_masked_spans = max(num_masked_spans, min_masks)
61
+
62
+ # Clamp to ensure we don't exceed sequence length
63
+ if num_masked_spans * mask_length > sequence_length:
64
+ num_masked_spans = sequence_length // mask_length
65
+
66
+ if num_masked_spans == 0:
67
+ return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
68
+
69
+ # Uniformly sample span start indices
70
+ mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
71
+
72
+ for i in range(batch_size):
73
+ # Random start indices for this sample
74
+ spec_aug_start_indices = torch.randint(
75
+ 0, sequence_length - mask_length + 1, (num_masked_spans,), device=device
76
+ )
77
+
78
+ # Create mask spans
79
+ for start_idx in spec_aug_start_indices:
80
+ mask[i, start_idx : start_idx + mask_length] = True
81
+
82
+ return mask
83
+
84
+
85
+ def apply_specaugment(
86
+ input_features: torch.Tensor,
87
+ mask_time_prob: float = 0.05,
88
+ mask_time_length: int = 10,
89
+ mask_time_min_masks: int = 2,
90
+ mask_feature_prob: float = 0.0,
91
+ mask_feature_length: int = 10,
92
+ mask_feature_min_masks: int = 0,
93
+ ) -> torch.Tensor:
94
+ """Apply SpecAugment to mel spectrogram features.
95
+
96
+ Args:
97
+ input_features: Mel spectrogram of shape (batch, n_mels, time)
98
+ mask_time_prob: Probability of masking time steps
99
+ mask_time_length: Max length of time mask
100
+ mask_time_min_masks: Min number of time masks
101
+ mask_feature_prob: Probability of masking frequency bins
102
+ mask_feature_length: Max length of frequency mask
103
+ mask_feature_min_masks: Min number of frequency masks
104
+
105
+ Returns:
106
+ Augmented mel spectrogram with same shape
107
+ """
108
+ batch_size, n_mels, time_steps = input_features.shape
109
+ device = input_features.device
110
+
111
+ # Clone to avoid modifying original
112
+ augmented = input_features.clone()
113
+
114
+ # Time masking (along time dimension)
115
+ if mask_time_prob > 0:
116
+ time_mask = _compute_mask_indices(
117
+ shape=(batch_size, time_steps),
118
+ mask_prob=mask_time_prob,
119
+ mask_length=mask_time_length,
120
+ min_masks=mask_time_min_masks,
121
+ device=device,
122
+ )
123
+ # Expand to (batch, 1, time) for broadcasting
124
+ time_mask = time_mask.unsqueeze(1)
125
+ augmented = augmented.masked_fill(time_mask, 0.0)
126
+
127
+ # Frequency masking (along mel dimension)
128
+ if mask_feature_prob > 0:
129
+ feature_mask = _compute_mask_indices(
130
+ shape=(batch_size, n_mels),
131
+ mask_prob=mask_feature_prob,
132
+ mask_length=mask_feature_length,
133
+ min_masks=mask_feature_min_masks,
134
+ device=device,
135
+ )
136
+ # Expand to (batch, n_mels, 1) for broadcasting
137
+ feature_mask = feature_mask.unsqueeze(2)
138
+ augmented = augmented.masked_fill(feature_mask, 0.0)
139
+
140
+ return augmented
141
+
142
+
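
A quick smoke test for `apply_specaugment` (shapes follow Whisper-large-v3's 128 mel bins by 3000 frames; purely illustrative):

```python
feats = torch.randn(2, 128, 3000)                     # (batch, n_mels, time)
aug = apply_specaugment(feats, mask_time_prob=0.05, mask_time_length=10)
print(aug.shape)                                       # unchanged: torch.Size([2, 128, 3000])
print((aug == 0.0).all(dim=1).float().mean().item())   # approx. fraction of zeroed time frames
```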
143
  class ASRModel(PreTrainedModel, GenerationMixin):
144
  """Audio-to-text model combining an audio encoder, projector, and language model."""
145
 
 
228
  # Set up generation config with greedy decoding defaults
229
  self.generation_config = self.language_model.generation_config
230
  self.generation_config.max_new_tokens = config.max_new_tokens
231
+ self.generation_config.min_new_tokens = config.min_new_tokens
232
  self.generation_config.num_beams = config.num_beams
233
  self.generation_config.do_sample = False
234
  # Clear sampling params (inherited from LLM) since we use greedy decoding
 
502
  inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
503
 
504
  if input_features is not None and input_ids is not None:
505
+ # Apply SpecAugment during training if enabled
506
+ if self.training and getattr(self.config, "use_specaugment", False):
507
+ input_features = apply_specaugment(
508
+ input_features,
509
+ mask_time_prob=self.config.mask_time_prob,
510
+ mask_time_length=self.config.mask_time_length,
511
+ mask_time_min_masks=self.config.mask_time_min_masks,
512
+ mask_feature_prob=self.config.mask_feature_prob,
513
+ mask_feature_length=self.config.mask_feature_length,
514
+ mask_feature_min_masks=self.config.mask_feature_min_masks,
515
+ )
516
+
517
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
518
  audio_embeds = self._encode_audio(input_features, audio_attention_mask)
519
 
 
646
  return output
647
  return output.sequences
648
 
649
+ def generate_streaming(
650
+ self,
651
+ input_features: torch.Tensor,
652
+ audio_attention_mask: torch.Tensor,
653
+ system_prompt: Optional[str] = None,
654
+ **generate_kwargs,
655
+ ) -> Iterator[str]:
656
+ """Generate transcription with streaming token output.
657
+
658
+ Yields partial transcript strings as tokens are generated.
659
+ Reduces time-to-first-word by streaming tokens as they're decoded.
660
+
661
+ Args:
662
+ input_features: Mel spectrogram features (batch, n_mels, mel_len)
663
+ audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
664
+ system_prompt: Optional system prompt override
665
+ **generate_kwargs: Additional generation arguments
666
+
667
+ Yields:
668
+ Partial transcript text as each token is generated
669
+ """
670
+ device = input_features.device
671
+ batch_size = input_features.shape[0]
672
+
673
+ # Encode audio -> flattened embeddings
674
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
675
+
676
+ # Build prompt with correct number of audio tokens
677
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
678
+ audio_placeholder = "<audio>" * num_audio_tokens
679
+
680
+ system_prompt = system_prompt or self.system_prompt
681
+
682
+ messages: list[dict[str, str]] = []
683
+ if system_prompt:
684
+ messages.append({"role": "system", "content": system_prompt})
685
+ messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
686
+
687
+ chat_result = self.tokenizer.apply_chat_template(
688
+ messages,
689
+ tokenize=True,
690
+ add_generation_prompt=True,
691
+ return_tensors="pt",
692
+ )
693
+ input_ids = chat_result.input_ids.to(device)
694
+
695
+ if input_ids.dim() == 1:
696
+ input_ids = input_ids.unsqueeze(0)
697
+ if input_ids.shape[0] == 1 and batch_size > 1:
698
+ input_ids = input_ids.expand(batch_size, -1)
699
+
700
+ attention_mask = torch.ones_like(input_ids)
701
+
702
+ # Get text embeddings and replace audio tokens with audio embeddings
703
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
704
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
705
+ inputs_embeds = inputs_embeds.masked_scatter(
706
+ audio_token_mask.to(inputs_embeds.device),
707
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
708
+ )
709
+
710
+ # Setup streamer for token-by-token output
711
+ streamer = TextIteratorStreamer(
712
+ self.tokenizer,
713
+ skip_prompt=True,
714
+ skip_special_tokens=True,
715
+ )
716
+
717
+ # Prepare generation kwargs
718
+ gen_kwargs = {
719
+ "inputs_embeds": inputs_embeds,
720
+ "attention_mask": attention_mask,
721
+ "generation_config": self.generation_config,
722
+ "streamer": streamer,
723
+ **generate_kwargs,
724
+ }
725
+
726
+ # Run generation in background thread
727
+ thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
728
+ thread.start()
729
+
730
+ # Yield tokens as they're generated, filtering out <think>...</think> blocks
731
+ # SmolLM3 always starts in thinking mode, so assume we're in a think block
732
+ in_think_block = True
733
+ buffer = ""
734
+
735
+ for text in streamer:
736
+ buffer += text
737
+
738
+ # Check for think block start (in case model outputs multiple think blocks)
739
+ while "<think>" in buffer:
740
+ in_think_block = True
741
+ # Yield any text before <think>
742
+ before_think = buffer.split("<think>")[0]
743
+ if before_think:
744
+ yield before_think
745
+ buffer = buffer.split("<think>", 1)[-1]
746
+
747
+ # Check for think block end
748
+ while in_think_block and "</think>" in buffer:
749
+ in_think_block = False
750
+ buffer = buffer.split("</think>", 1)[-1]
751
+
752
+ # Yield text if not in think block
753
+ if not in_think_block and buffer:
754
+ yield buffer
755
+ buffer = ""
756
+
757
+ # Yield any remaining buffer
758
+ if buffer and not in_think_block:
759
+ yield buffer
760
+
761
+ thread.join()
762
+
763
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
764
  """Save model, tokenizer, and processor."""
765
  import shutil
 
813
  # Copy projectors module
814
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
815
 
816
+ def create_or_update_model_card(self, output_dir: Union[str, Path]):
817
+ """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
818
+ pass
819
+
820
 
821
  # Register with transformers Auto classes
822
  AutoConfig.register("asr_model", ASRConfig)
asr_pipeline.py CHANGED
@@ -476,57 +476,37 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
- # Collapse spaced-out acronyms (e.g., "I S D S" -> "ISDS")
480
- text = self._collapse_acronyms(text)
481
- # Truncate if a word repeats more than 3 times consecutively
482
- text = self._truncate_repetitions(text, max_repeats=3)
483
  return {"text": text}
484
 
485
- def _collapse_acronyms(self, text: str) -> str:
486
- """Collapse spaced-out acronyms into single words.
 
 
487
 
488
- Converts patterns like "I S D S" to "ISDS" when 2+ single letters
489
- are separated by spaces.
490
 
491
- Args:
492
- text: Input text with potential spaced acronyms
493
-
494
- Returns:
495
- Text with acronyms collapsed
496
- """
497
- # Match 2+ single letters (case-insensitive) separated by spaces
498
- # Pattern: single letter, then one or more (space + single letter)
499
- pattern = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"
500
-
501
- def collapse_match(match: re.Match) -> str:
502
- # Get the full match and remove spaces
503
- full = match.group(0)
504
- return full.replace(" ", "").upper()
505
-
506
- return re.sub(pattern, collapse_match, text)
507
-
508
- def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
509
- """Truncate text when a word repeats more than max_repeats times consecutively.
510
-
511
- Args:
512
- text: Input text to check for repetitions
513
- max_repeats: Maximum allowed consecutive repetitions (default 3)
514
-
515
- Returns:
516
- Truncated text if repetition detected, otherwise original text
517
- """
518
  words = text.split()
519
- if len(words) <= max_repeats:
520
- return text
521
-
522
- repeat_count = 1
523
- for i in range(1, len(words)):
524
- if words[i].lower() == words[i - 1].lower():
525
- repeat_count += 1
526
- if repeat_count > max_repeats:
527
- # Keep up to max_repeats of the repeated word
528
- return " ".join(words[:i])
529
- else:
530
- repeat_count = 1
531
 
532
  return text
 
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
+ # Post-process prediction
480
+ text = self._post_process_prediction(text)
 
 
481
  return {"text": text}
482
 
483
+ def _post_process_prediction(self, text: str) -> str:
484
+ """Post-process model output to fix common issues."""
485
+ if not text:
486
+ return ""
487
 
488
+ # 1. LOWERCASE
489
+ text = text.lower()
490
 
491
+ # 2. REMOVE REPETITIVE LOOPS
492
+ # If the model repeats the same phrase more than twice, cut it off.
 
493
  words = text.split()
494
+ if len(words) > 10:
495
+ # Check for repeating n-grams (1 to 4 words long)
496
+ for n in range(1, 5):
497
+ last_sequence = words[-n:]
498
+ repeat_count = 0
499
+ idx = len(words) - n
500
+ while idx >= n and words[idx - n : idx] == last_sequence:
501
+ repeat_count += 1
502
+ idx -= n
503
+
504
+ # If more than 2 exact repetitions at the end, truncate
505
+ if repeat_count > 2:
506
+ text = " ".join(words[: idx + n])
507
+ break
508
+
509
+ # 3. STRIP WHITESPACE
510
+ text = re.sub(r'\s+', ' ', text).strip()
511
 
512
  return text
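
For intuition, the repetition guard trims trailing n-gram loops after lowercasing; an illustrative call to the (private) helper:

```python
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)

text = "So we went to the store to the store to the store to the store"
print(pipe._post_process_prediction(text))
# -> "so we went to the store"  (lowercased; the trailing "to the store" loop is truncated)
```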
asr_processing.py CHANGED
@@ -94,14 +94,21 @@ class ASRProcessor(ProcessorMixin):
94
  messages.append({"role": "assistant", "content": text})
95
 
96
  # Tokenize
97
- input_ids = self.tokenizer.apply_chat_template(
98
  messages,
99
  tokenize=True,
100
  add_generation_prompt=(text is None),
101
  return_tensors=return_tensors,
102
  )
103
 
104
- if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
105
  input_ids = input_ids.unsqueeze(0)
106
 
107
  result["input_ids"] = input_ids
 
94
  messages.append({"role": "assistant", "content": text})
95
 
96
  # Tokenize
97
+ tokenized = self.tokenizer.apply_chat_template(
98
  messages,
99
  tokenize=True,
100
  add_generation_prompt=(text is None),
101
  return_tensors=return_tensors,
102
  )
103
 
104
+ # Handle both tensor and BatchEncoding returns
105
+ if isinstance(tokenized, torch.Tensor):
106
+ input_ids = tokenized
107
+ else:
108
+ # BatchEncoding or dict-like object
109
+ input_ids = tokenized["input_ids"] if "input_ids" in tokenized else tokenized.input_ids
110
+
111
+ if input_ids.dim() == 1:
112
  input_ids = input_ids.unsqueeze(0)
113
 
114
  result["input_ids"] = input_ids
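
The tensor/`BatchEncoding` guard exists because `apply_chat_template` changes its return container depending on the arguments (and, in older releases, the transformers version); roughly:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
messages = [{"role": "user", "content": "Transcribe the audio."}]

ids = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt")
print(type(ids))   # plain tensor of token ids

enc = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", return_dict=True)
print(type(enc))   # BatchEncoding with "input_ids" (and "attention_mask")
```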