mazesmazes committed
Commit 36087fa · verified · 1 Parent(s): 5a5380c

Training in progress - step 10
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
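
The README's "How to Get Started with the Model" section is still a stub. Based on the `auto_map` and `custom_pipelines` entries registered in `asr_config.py` below, loading should go through the custom pipeline with `trust_remote_code=True`. A minimal sketch, assuming a placeholder repo id and a local `sample.wav`:

```python
# Usage sketch (untested); "your-username/your-asr-repo" is a placeholder repo id.
# The custom "automatic-speech-recognition" pipeline is registered by asr_config.py,
# so trust_remote_code=True is required to pull in the asr_*.py modules.
import transformers

pipe = transformers.pipeline(
    "automatic-speech-recognition",
    model="your-username/your-asr-repo",
    trust_remote_code=True,
)

# ASRPipeline.preprocess accepts a file path, a raw array, or a
# {"raw": array, "sampling_rate": 16000} dict (see asr_pipeline.py below).
result = pipe("sample.wav", task="transcribe")
print(result["text"])
```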
asr_config.py ADDED
@@ -0,0 +1,131 @@
+ from typing import Optional
+
+ import transformers
+
+
+ class ASRConfig(transformers.PretrainedConfig):
+     model_type = "asr_model"
+     is_composition = True
+
+     def __init__(
+         self,
+         audio_model_id: str = "openai/whisper-large-v3-turbo",
+         text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
+         attn_implementation: str = "sdpa",
+         model_dtype: str = "bfloat16",
+         audio_downsample_rate: int = 5,  # Deprecated: use projector_pool_stride instead
+         num_beams: Optional[int] = None,
+         system_prompt: str = "/no_think /system_override",
+         user_prompt: str = "Transcribe: <audio>",
+         encoder_dim: Optional[int] = None,
+         llm_dim: Optional[int] = None,
+         # Audio processing constants
+         audio_sample_rate: int = 16000,
+         # Projector initialization constants
+         projector_init_std: float = 0.02,
+         projector_pool_stride: int = 2,  # AvgPool1d stride (2 = 4x total with Whisper, 1 = no pooling)
+         projector_hidden_dim: Optional[
+             int
+         ] = None,  # SwiGLU hidden dimension (defaults to encoder_dim * 4)
+         projector_dropout: float = 0.1,  # Dropout rate for projector layers
+         # Inference parameters
+         inference_diversity_penalty: float = 0.0,
+         inference_warmup_tokens: int = 10,
+         # Generation parameters
+         max_new_tokens: Optional[int] = None,
+         min_new_tokens: Optional[int] = None,
+         do_sample: Optional[bool] = None,
+         temperature: Optional[float] = None,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         repetition_penalty: Optional[float] = None,
+         length_penalty: Optional[float] = None,
+         no_repeat_ngram_size: Optional[int] = None,
+         early_stopping: Optional[bool] = None,
+         use_cache: Optional[bool] = None,
+         **kwargs,
+     ):
+         # Set default generation parameters
+         generation_defaults = {
+             "num_beams": 1,
+             "max_new_tokens": 128,
+             "min_new_tokens": 1,
+             "do_sample": False,
+             "repetition_penalty": 1.05,
+             "no_repeat_ngram_size": 0,
+             "use_cache": True,
+         }
+
+         # Generation parameters arrive as *named* arguments (e.g. when
+         # PretrainedConfig.from_dict calls cls(**config_dict)), so fold any
+         # explicitly provided values back into kwargs; otherwise values saved
+         # in config.json would be silently dropped in favor of the defaults.
+         explicit_generation_params = {
+             "num_beams": num_beams,
+             "max_new_tokens": max_new_tokens,
+             "min_new_tokens": min_new_tokens,
+             "do_sample": do_sample,
+             "temperature": temperature,
+             "top_k": top_k,
+             "top_p": top_p,
+             "repetition_penalty": repetition_penalty,
+             "length_penalty": length_penalty,
+             "no_repeat_ngram_size": no_repeat_ngram_size,
+             "early_stopping": early_stopping,
+             "use_cache": use_cache,
+         }
+
+         # Apply defaults (explicit and config.json values take precedence)
+         kwargs = {
+             **generation_defaults,
+             **kwargs,
+             **{k: v for k, v in explicit_generation_params.items() if v is not None},
+         }
+
+         self.audio_model_id = audio_model_id
+         self.text_model_id = text_model_id
+         self.attn_implementation = attn_implementation
+         self.model_dtype = model_dtype
+         self.audio_downsample_rate = audio_downsample_rate
+         self.system_prompt = system_prompt
+         self.user_prompt = user_prompt
+         self.encoder_dim = encoder_dim
+         self.llm_dim = llm_dim
+         self.audio_sample_rate = audio_sample_rate
+         self.projector_init_std = projector_init_std
+         self.projector_pool_stride = projector_pool_stride
+         self.projector_hidden_dim = projector_hidden_dim
+         self.projector_dropout = projector_dropout
+         self.inference_diversity_penalty = inference_diversity_penalty
+         self.inference_warmup_tokens = inference_warmup_tokens
+         if "audio_config" not in kwargs:
+             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
+         else:
+             self.audio_config = kwargs.pop("audio_config")
+
+         if "text_config" not in kwargs:
+             self.text_config = transformers.AutoConfig.from_pretrained(
+                 text_model_id, trust_remote_code=True
+             )
+         else:
+             self.text_config = kwargs.pop("text_config")
+
+         # Ensure configs are PretrainedConfig objects (in case loaded from dict)
+         if isinstance(self.text_config, dict):
+             # Reconstruct config from dict using the model_type stored in the dict
+             model_type = self.text_config.get("model_type")
+             if model_type:
+                 config_class = transformers.AutoConfig.for_model(model_type).__class__
+                 self.text_config = config_class(**self.text_config)
+             else:
+                 # Fallback: try to load from model_id
+                 self.text_config = transformers.AutoConfig.from_pretrained(
+                     text_model_id, trust_remote_code=True
+                 )
+
+         if isinstance(self.audio_config, dict):
+             model_type = self.audio_config.get("model_type")
+             if model_type:
+                 config_class = transformers.AutoConfig.for_model(model_type).__class__
+                 self.audio_config = config_class(**self.audio_config)
+
+         super().__init__(**kwargs)
+
+         self.auto_map = {
+             "AutoConfig": "asr_config.ASRConfig",
+             "AutoModel": "asr_modeling.ASRModel",
+             "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
+             "AutoProcessor": "asr_processing.ASRProcessor",
+         }
+         self.custom_pipelines = {
+             "automatic-speech-recognition": {
+                 "impl": "asr_pipeline.ASRPipeline",
+                 "pt": ["AutoModelForSpeechSeq2Seq"],
+                 "tf": [],
+                 "type": "audio",
+             }
+         }
+         self.architectures = ["ASRModel"]
+         self.pipeline_tag = "automatic-speech-recognition"
+
+
+ # Register the config with transformers
+ # This is needed for AutoConfig.from_pretrained to work
+ transformers.AutoConfig.register("asr_model", ASRConfig)
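
To illustrate how the default merging above resolves, here is a small sketch (it assumes the explicit-override handling noted in the constructor; note that `ASRConfig()` fetches the two sub-model configs via `AutoConfig`, so running it needs network access or cached models):

```python
# Sketch: generation defaults vs. explicit overrides in ASRConfig.
from asr_config import ASRConfig

cfg = ASRConfig()
print(cfg.num_beams)           # 1    (from generation_defaults)
print(cfg.max_new_tokens)      # 128  (from generation_defaults)
print(cfg.repetition_penalty)  # 1.05

# An explicitly passed value (or one restored from config.json) wins:
cfg2 = ASRConfig(max_new_tokens=256)
print(cfg2.max_new_tokens)     # 256
```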
asr_modeling.py ADDED
@@ -0,0 +1,874 @@
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # noqa: N812
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PreTrainedModel,
+     Wav2Vec2FeatureExtractor,
+ )
+ from transformers.generation.utils import (
+     GenerateBeamDecoderOnlyOutput,
+     GenerateBeamEncoderDecoderOutput,
+     GenerateDecoderOnlyOutput,
+     GenerateEncoderDecoderOutput,
+ )
+
+ try:
+     from .asr_config import ASRConfig
+ except ImportError:
+     from asr_config import ASRConfig  # type: ignore[no-redef]
+
+
+ class SwiGLU(nn.Module):
+     """
+     SwiGLU activation MLP - based on LlamaMLP but with flexible output dimension.
+
+     This implements the same gated activation pattern as transformers.models.llama.modeling_llama.LlamaMLP,
+     but allows for different input/output dimensions (needed for cross-modal projection).
+
+     Structure: w1 (gate), w2 (up), w3 (down) with w3(silu(w1) * w2)
+     """
+
+     def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
+         super().__init__()
+         self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.act = nn.SiLU()
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x_gate = self.act(self.w1(x))
+         x_val = self.w2(x)
+         x = x_gate * x_val
+         x = self.dropout(x)  # Apply dropout after the gating operation
+         return self.w3(x)
+
+
+ class AudioProjector(nn.Module):
+     """
+     AudioProjector using a SwiGLU MLP with dropout.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.k = getattr(config, "projector_pool_stride", 2)  # Downsampling rate
+         in_dim = config.encoder_dim * self.k  # Input is k frames concatenated
+         out_dim = config.llm_dim
+         hidden_dim = config.projector_hidden_dim
+         if hidden_dim is None:
+             hidden_dim = config.encoder_dim * 4  # Default: 4x encoder dim for SwiGLU
+
+         # Get dropout rate from config
+         dropout_rate = getattr(config, "projector_dropout", 0.1)
+
+         # SwiGLU MLP (now takes concatenated frames as input) with dropout
+         self.proj = SwiGLU(in_dim, hidden_dim, out_dim, dropout=dropout_rate)
+
+         # Optional output dropout layer for additional regularization
+         self.output_dropout = nn.Dropout(dropout_rate)
+
+         # Initialize weights following LLaMA-style initialization for SwiGLU
+         # Uses smaller std to account for the multiplicative gating
+         with torch.no_grad():
+             # Standard deviation from config or default (0.02 is common for transformers)
+             std = getattr(config, "projector_init_std", 0.02)
+
+             # Initialize gate and up projections
+             nn.init.normal_(self.proj.w1.weight, mean=0.0, std=std)
+             nn.init.normal_(self.proj.w2.weight, mean=0.0, std=std)
+
+             # Initialize down projection with scaling to preserve variance after SwiGLU
+             # The 1/sqrt(2) factor accounts for the multiplicative interaction
+             nn.init.normal_(self.proj.w3.weight, mean=0.0, std=std / (2**0.5))
+
+     def forward(self, x):
+         # x: [batch, seq_len, dim]
+         batch_size, seq_len, dim = x.size()
+
+         # Ensure input dtype matches the projector weights
+         # This is crucial for MPS devices where encoder may output bfloat16
+         # but projector weights might be in float32 when loaded from checkpoint
+         target_dtype = self.proj.w1.weight.dtype
+         if x.dtype != target_dtype:
+             x = x.to(target_dtype)
+
+         # Pad the sequence to be divisible by k instead of truncating
+         remainder = seq_len % self.k
+         if remainder:
+             pad_len = self.k - remainder
+             x = F.pad(x, (0, 0, 0, pad_len))
+
+         # Reshape for temporal compression - concatenate k consecutive frames
+         x = x.contiguous().view(batch_size, -1, dim * self.k)
+
+         # Apply SwiGLU block
+         x = self.proj(x)
+
+         # Apply output dropout for additional regularization
+         return self.output_dropout(x)
+
+
+ class ASRModel(PreTrainedModel):
+     config_class = ASRConfig
+     base_model_prefix = "model"
+     main_input_name = "input_values"
+     _supports_flash_attn_2 = True
+     supports_gradient_checkpointing = True
+     _is_loading_from_pretrained: bool = False
+     _pretrained_model_path: Optional[str] = None
+
+     # Task to prompt mapping for generation
+     TASK_PROMPTS = {
+         "transcribe": "Transcribe: <audio>",
+         "continue": "Continue: <audio>",
+         "describe": "Describe: <audio>",
+         "emotion": "Emotion: <audio>",
+     }
+
+     @staticmethod
+     def _create_feature_extractor(audio_model_id: str):
+         """Factory method to create the appropriate feature extractor."""
+         is_whisper = "whisper" in audio_model_id.lower()
+         if is_whisper:
+             from transformers import WhisperConfig, WhisperFeatureExtractor
+
+             encoder_config = WhisperConfig.from_pretrained(audio_model_id)
+             num_mel_bins = encoder_config.num_mel_bins
+             return WhisperFeatureExtractor.from_pretrained(
+                 audio_model_id,
+                 feature_size=num_mel_bins,
+             )
+         return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+         from transformers import AutoFeatureExtractor
+
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         # Load feature extractor from saved model directory
+         kwargs["feature_extractor"] = AutoFeatureExtractor.from_pretrained(
+             pretrained_model_name_or_path, **kwargs
+         )
+
+         cls._is_loading_from_pretrained = True
+         cls._pretrained_model_path = pretrained_model_name_or_path
+
+         try:
+             # Let parent class handle loading config and model.safetensors
+             model = super().from_pretrained(
+                 pretrained_model_name_or_path, *args, config=config, **kwargs
+             )
+
+             # Convert projector to target dtype after loading weights
+             target_dtype = getattr(torch, config.model_dtype)
+             model.projector = model.projector.to(dtype=target_dtype)
+
+             return model
+         finally:
+             cls._is_loading_from_pretrained = False
+             # Reset to the declared default rather than deleting the class
+             # attribute, so later lookups cannot hit an AttributeError
+             cls._pretrained_model_path = None
+
+     def __init__(self, config: ASRConfig, **kwargs):
+         super().__init__(config)
+
+         feature_extractor = kwargs.pop("feature_extractor", None)
+
+         self.system_prompt = config.system_prompt
+
+         self.encoder = self._create_encoder(config)
+
+         is_whisper = "whisper" in config.audio_model_id.lower() or (
+             hasattr(self.encoder.config, "model_type")
+             and "whisper" in self.encoder.config.model_type.lower()
+         )
+
+         if is_whisper:
+             self.main_input_name = "input_features"
+         else:
+             self.main_input_name = "input_values"
+
+         if feature_extractor is not None:
+             self.feature_extractor = feature_extractor
+         else:
+             self.feature_extractor = self._create_feature_extractor(config.audio_model_id)
+
+         self.decoder = self._create_decoder(config)
+         self.generation_config = self.decoder.generation_config
+
+         self._init_tokenizer()
+
+         from types import SimpleNamespace
+
+         # Auto-detect encoder_dim and llm_dim if not specified
+         encoder_dim = config.encoder_dim
+         if encoder_dim is None:
+             if hasattr(self.encoder.config, "hidden_size"):
+                 encoder_dim = self.encoder.config.hidden_size
+             elif hasattr(self.encoder.config, "d_model"):
+                 encoder_dim = self.encoder.config.d_model
+             else:
+                 raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
+
+         llm_dim = config.llm_dim
+         if llm_dim is None:
+             if hasattr(self.decoder.config, "hidden_size"):
+                 llm_dim = self.decoder.config.hidden_size
+             elif hasattr(self.decoder.config, "d_model"):
+                 llm_dim = self.decoder.config.d_model
+             else:
+                 raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
+
+         projector_config = SimpleNamespace(
+             encoder_dim=encoder_dim,
+             llm_dim=llm_dim,
+             projector_pool_stride=getattr(config, "projector_pool_stride", 2),
+             projector_hidden_dim=getattr(config, "projector_hidden_dim", None),
+             projector_init_std=getattr(config, "projector_init_std", 0.02),
+             projector_dropout=getattr(config, "projector_dropout", 0.1),
+         )
+         self.projector = AudioProjector(projector_config)
+
+         # Convert projector to the same dtype as encoder/decoder
+         target_dtype = getattr(torch, config.model_dtype)
+         self.projector = self.projector.to(dtype=target_dtype)
+
+         self._no_split_modules = self.decoder._no_split_modules
+
+     @classmethod
+     def _create_encoder(cls, config: ASRConfig):
+         """Create and configure the audio encoder.
+
+         Args:
+             config: Model configuration
+
+         Returns:
+             Configured encoder model
+         """
+         target_dtype = getattr(torch, config.model_dtype)
+
+         encoder_kwargs = {
+             "attn_implementation": config.attn_implementation,
+             "dtype": target_dtype,
+             "low_cpu_mem_usage": True,
+         }
+         if not cls._is_loading_from_pretrained:
+             encoder_kwargs["device_map"] = "auto"
+
+         if "whisper" in config.audio_model_id.lower():
+             from transformers import WhisperModel
+
+             full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
+             encoder = full_model.encoder
+             del full_model
+         else:
+             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
+
+         is_whisper = "whisper" in config.audio_model_id.lower() or (
+             hasattr(encoder.config, "model_type") and "whisper" in encoder.config.model_type.lower()
+         )
+
+         # Wrap encoder forward to handle Whisper's input_features vs input_values
+         original_forward = encoder.forward
+         input_key = "input_features" if is_whisper else "input_values"
+
+         def safe_encoder_forward(self_encoder, input_values=None, **kwargs):
+             # Catch and discard invalid kwargs like input_ids
+             kwargs.pop("input_ids", None)
+             return original_forward(**{input_key: input_values}, **kwargs)
+
+         import types
+
+         encoder.forward = types.MethodType(safe_encoder_forward, encoder)
+
+         # Freeze all encoder parameters
+         encoder.requires_grad_(False)
+
+         return encoder
+
+     @classmethod
+     def _create_decoder(cls, config: ASRConfig):
+         """Create and configure the language model decoder.
+
+         Args:
+             config: Model configuration
+
+         Returns:
+             Configured decoder model
+         """
+         target_dtype = getattr(torch, config.model_dtype)
+
+         # When loading from pretrained, avoid device_map="auto" to prevent meta tensor issues
+         decoder_kwargs = {
+             "attn_implementation": config.attn_implementation,
+             "dtype": target_dtype,
+             "trust_remote_code": True,
+         }
+         # Don't use device_map="auto" as it can cause meta tensor issues with Trainer
+         # The Trainer will handle device placement
+
+         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
+
+         # use_cache is now safe because we pre-expand audio tokens for consistent sequence length
+         # Cache can be enabled/disabled via config.use_cache
+         decoder.config.use_cache = config.use_cache
+
+         # Freeze all decoder parameters (only projector is trainable)
+         decoder.requires_grad_(False)
+
+         return decoder
+
+     def _init_weights(self, module):
+         """Initialize weights for trainable modules.
+
+         Note: This is a no-op since:
+         - AudioProjector self-initializes in its __init__
+         - Encoder/decoder are loaded from pretrained weights
+         """
+         pass
+
+     def can_generate(self) -> bool:
+         """Return True to indicate this model supports generation.
+
+         Required for Transformers 4.50+ where PreTrainedModel no longer
+         inherits from GenerationMixin.
+         """
+         return True
+
+     @property
+     def _tied_weights_keys(self):
+         """Return list of weight keys that should be tied.
+
+         In this model, input and output embeddings of the decoder may be tied.
+         """
+         if hasattr(self.decoder, "_tied_weights_keys"):
+             return [f"decoder.{k}" for k in self.decoder._tied_weights_keys]
+         return []
+
+     def _init_tokenizer(self):
+         model_path = (
+             self.__class__._pretrained_model_path
+             if self._is_loading_from_pretrained
+             else self.config.text_model_id
+         )
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+         # Set pad_token if not already set to avoid warnings during generation
+         # If pad_token is same as eos_token, we need a different token for padding
+         if (
+             self.tokenizer.pad_token is None
+             or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
+         ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
+             # For SmolLM3, use the dedicated finetune_right_pad_id token
+             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
+
+         existing_special = self.tokenizer.additional_special_tokens or []
+
+         # Add single audio token if not present
+         if "<audio>" not in existing_special:
+             special_tokens = {"additional_special_tokens": existing_special + ["<audio>"]}
+             num_added_tokens = self.tokenizer.add_special_tokens(special_tokens)
+             if num_added_tokens > 0:
+                 # Use mean_resizing=False since this is a structural token, not semantic
+                 self.decoder.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
+
+         current_embed_size = self.decoder.get_input_embeddings().weight.shape[0]
+         expected_size = len(self.tokenizer)
+         if current_embed_size != expected_size:
+             self.decoder.resize_token_embeddings(expected_size, mean_resizing=False)
+
+         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
+
+         self.tokenizer.padding_side = "right"
+
+         for cfg in [self.config.text_config, self.decoder.config, self.generation_config]:
+             if isinstance(cfg, dict):
+                 cfg["pad_token_id"] = self.tokenizer.pad_token_id
+                 cfg["eos_token_id"] = self.tokenizer.eos_token_id
+                 cfg["bos_token_id"] = self.tokenizer.bos_token_id
+             else:
+                 cfg.pad_token_id = self.tokenizer.pad_token_id
+                 cfg.eos_token_id = self.tokenizer.eos_token_id
+                 cfg.bos_token_id = self.tokenizer.bos_token_id
+
+     def get_processor(self):
+         try:
+             from .asr_processing import ASRProcessor
+         except ImportError:
+             from asr_processing import ASRProcessor  # type: ignore[no-redef]
+
+         return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
+
+     def state_dict(self, *args, **kwargs):
+         """Return only trainable parameters (projector weights).
+
+         Called by HuggingFace Trainer to save model.safetensors in checkpoints.
+         """
+         return self._get_trainable_state_dict()
+
+     def _get_trainable_state_dict(self):
+         """Get all trainable parameters as a single state dict.
+
+         This is used by Trainer for checkpointing during training.
+         """
+         state = {}
+
+         # Only projector params are trainable now (encoder and decoder are frozen)
+         projector_state = self.projector.state_dict()
+         for name, tensor in projector_state.items():
+             state[f"projector.{name}"] = tensor
+
+         return state
+
+     def get_input_embeddings(self):
+         """Delegate to decoder for proper HF Trainer integration."""
+         return self.decoder.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         """Delegate to decoder for proper HF Trainer integration."""
+         self.decoder.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         """Delegate to decoder for proper HF Trainer integration."""
+         return self.decoder.get_output_embeddings()
+
+     def set_output_embeddings(self, value):
+         """Delegate to decoder for proper HF Trainer integration."""
+         self.decoder.set_output_embeddings(value)
+
+     def _encode_audio(
+         self,
+         input_values: torch.Tensor,
+         audio_attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # Ensure input is on encoder's device and has the right dtype
+         encoder_device = next(self.encoder.parameters()).device
+         encoder_dtype = next(self.encoder.parameters()).dtype
+         # Clone to prevent user tensor reuse contamination
+         input_values = input_values.clone().to(device=encoder_device, dtype=encoder_dtype)
+
+         # Only pass explicit valid arguments to encoder
+         # Never use **kwargs to prevent torch.compile from injecting decoder args like input_ids
+         # Always use no_grad since encoder is frozen
+         with torch.no_grad():
+             audio_features = self.encoder(
+                 input_values=input_values,
+                 attention_mask=audio_attention_mask,
+             ).last_hidden_state
+
+         # Project audio features and ensure dtype matches decoder
+         audio_embeds = self.projector(audio_features)
+
+         # Convert to decoder's dtype if needed (e.g., bfloat16)
+         decoder_dtype = next(self.decoder.parameters()).dtype
+         if audio_embeds.dtype != decoder_dtype:
+             audio_embeds = audio_embeds.to(dtype=decoder_dtype)
+
+         return audio_embeds
+
+     def _get_audio_expansion_details(self, input_ids: torch.Tensor, num_audio_tokens: int) -> dict:
+         """Calculate the positions and masks needed to expand audio tokens.
+
+         This helper consolidates the common cumsum logic used by both
+         _expand_audio_tokens and _expand_for_audio_tokens.
+
+         Args:
+             input_ids: Token IDs with single <audio> token per sample
+             num_audio_tokens: Number of tokens each audio token expands to
+
+         Returns:
+             Dictionary containing:
+             - new_seq_len: The total sequence length after expansion
+             - new_start_positions: [batch, old_seq_len] tensor mapping old indices to new
+             - audio_mask: [batch, old_seq_len] boolean mask for audio token positions
+         """
+         batch_size, seq_len = input_ids.shape
+         device = input_ids.device
+
+         # Find audio token positions
+         audio_mask = input_ids == self.audio_token_id
+
+         # Validate: each sample must have exactly one audio token
+         audio_counts = audio_mask.sum(dim=1)
+         if not (audio_counts == 1).all():
+             missing = (audio_counts == 0).any()
+             multiple = (audio_counts > 1).any()
+             if missing:
+                 raise ValueError("Some samples are missing audio token")
+             if multiple:
+                 raise ValueError("Some samples have multiple audio tokens")
+
+         # Create placeholder tensor: 1 for normal tokens, num_audio_tokens for audio token
+         token_counts = torch.where(audio_mask, num_audio_tokens, 1)
+
+         # Cumsum - 1 gives us the ENDING position of each token's expansion
+         cumsum_counts = torch.cumsum(token_counts, dim=1)
+
+         # The starting position of token i is cumsum[i-1]
+         new_start_positions = torch.cat(
+             [
+                 torch.zeros(batch_size, 1, dtype=torch.long, device=device),
+                 cumsum_counts[:, :-1],
+             ],
+             dim=1,
+         )
+
+         # Calculate new sequence length
+         new_seq_len = seq_len - 1 + num_audio_tokens
+
+         return {
+             "new_seq_len": new_seq_len,
+             "new_start_positions": new_start_positions,
+             "audio_mask": audio_mask,
+         }
+
+     def _expand_tensor_for_audio(
+         self,
+         input_ids: torch.Tensor,
+         tensor_to_expand: Optional[torch.Tensor],
+         num_audio_tokens: int,
+         fill_value: Optional[Union[int, float]] = None,
+         audio_fill_value: Optional[Union[int, float]] = None,
+     ) -> torch.Tensor:
+         """Generic method to expand any tensor to match audio token expansion.
+
+         Args:
+             input_ids: Token IDs with single <audio> token per sample
+             tensor_to_expand: Tensor to expand (input_ids, attention_mask, labels) or None
+             num_audio_tokens: Number of tokens each audio token expands to
+             fill_value: Default fill value for new tensor
+             audio_fill_value: Value to use for audio token positions (if different from fill_value)
+
+         Returns:
+             Expanded tensor matching the expanded sequence length
+         """
+         batch_size, seq_len = input_ids.shape
+         device = input_ids.device
+
+         details = self._get_audio_expansion_details(input_ids, num_audio_tokens)
+         new_seq_len = details["new_seq_len"]
+         new_start_positions = details["new_start_positions"]
+         audio_mask = details["audio_mask"]
+
+         # Determine the tensor we're actually expanding
+         if tensor_to_expand is None:
+             # Expanding input_ids themselves
+             tensor_to_expand = input_ids
+             fill_value = fill_value or self.tokenizer.pad_token_id
+             audio_fill_value = audio_fill_value or self.audio_token_id
+         else:
+             # Expanding other tensors (attention_mask, labels)
+             if fill_value is None:
+                 raise ValueError("fill_value must be provided when expanding non-input_ids tensors")
+             if audio_fill_value is None:
+                 audio_fill_value = fill_value
+
+         # Create output tensor
+         expanded = torch.full(
+             (batch_size, new_seq_len),
+             fill_value,
+             dtype=tensor_to_expand.dtype,
+             device=device,
+         )
+
+         # Scatter non-audio positions to their new positions
+         batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, seq_len)
+         non_audio_mask = ~audio_mask
+         expanded[batch_indices[non_audio_mask], new_start_positions[non_audio_mask]] = (
+             tensor_to_expand[non_audio_mask]
+         )
+
+         # Fill audio positions if different from default fill
+         if audio_fill_value != fill_value:
+             audio_positions = audio_mask.int().argmax(dim=1)
+             audio_new_start = new_start_positions[
+                 torch.arange(batch_size, device=device), audio_positions
+             ]
+             audio_token_indices = torch.arange(num_audio_tokens, device=device).unsqueeze(0)
+             audio_positions_expanded = audio_new_start.unsqueeze(1) + audio_token_indices
+             batch_idx_expanded = (
+                 torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_audio_tokens)
+             )
+             expanded[batch_idx_expanded, audio_positions_expanded] = audio_fill_value
+
+         return expanded
+
+     def _expand_audio_tokens(self, input_ids: torch.Tensor, num_audio_tokens: int) -> torch.Tensor:
+         """Convenience method for expanding input_ids."""
+         return self._expand_tensor_for_audio(input_ids, None, num_audio_tokens)
+
+     def _expand_for_audio_tokens(
+         self,
+         input_ids: torch.Tensor,
+         tensor_to_expand: torch.Tensor,
+         num_audio_tokens: int,
+         fill_value: Union[int, float],
+     ) -> torch.Tensor:
+         """Convenience method for expanding attention_mask or labels."""
+         return self._expand_tensor_for_audio(
+             input_ids, tensor_to_expand, num_audio_tokens, fill_value
+         )
+
+     def _prepare_audio_inputs_embeds(
+         self, expanded_input_ids: torch.Tensor, audio_embeds: torch.Tensor
+     ) -> torch.Tensor:
+         """Prepare inputs_embeds by replacing audio token embeddings with actual audio embeddings.
+
+         Args:
+             expanded_input_ids: Input IDs with expanded audio tokens
+             audio_embeds: Audio embeddings to inject
+
+         Returns:
+             inputs_embeds with audio embeddings injected
+         """
+         # Get text embeddings for expanded input_ids
+         inputs_embeds = self.decoder.get_input_embeddings()(expanded_input_ids)
+
+         # Simple masked scatter: replace audio token embeddings with actual audio embeddings
+         special_audio_mask = (expanded_input_ids == self.audio_token_id).unsqueeze(-1)
+         special_audio_mask = special_audio_mask.expand_as(inputs_embeds)
+         audio_embeds_flat = audio_embeds.reshape(-1, audio_embeds.shape[-1])
+         return inputs_embeds.masked_scatter(special_audio_mask, audio_embeds_flat)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         input_values: Optional[torch.Tensor] = None,
+         input_features: Optional[torch.Tensor] = None,  # For Whisper
+         labels: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         num_items_in_batch: Optional[
+             int
+         ] = None,  # HF Trainer provides this for gradient accumulation
+         **kwargs,
+     ):
+         audio_inputs = input_values if input_values is not None else input_features
+         if audio_inputs is not None:
+             # During inference, the pipeline may call forward with only audio inputs
+             # In that case, we should raise an error directing to use generate() instead
+             if input_ids is None:
+                 raise ValueError(
+                     "forward() requires both audio inputs and input_ids (for training). "
+                     "For inference, use the generate() method instead, or use the pipeline "
+                     "which will automatically call generate()."
+                 )
+
+             # Extract audio-specific kwargs, don't pass input_ids to encoder
+             audio_attention_mask = kwargs.pop("audio_attention_mask", None)
+
+             # Remove any decoder-specific kwargs that shouldn't go to the encoder
+             kwargs.pop("past_key_values", None)
+             use_cache = kwargs.pop("use_cache", None)
+
+             # Encode audio to get embeddings
+             audio_embeds = self._encode_audio(
+                 input_values=audio_inputs,  # Will be mapped to input_features for Whisper by safe_encoder_forward
+                 audio_attention_mask=audio_attention_mask,
+             )
+
+             # Validate audio token ID before using it
+             if self.audio_token_id is None:
+                 raise ValueError(f"Audio token not properly initialized: {self.audio_token_id}")
+
+             vocab_size = self.decoder.get_input_embeddings().weight.shape[0]
+             if self.audio_token_id >= vocab_size:
+                 raise ValueError(
+                     f"Audio token ID out of range. ID: {self.audio_token_id}, Vocab size: {vocab_size}"
+                 )
+
+             # Check that audio token exists
+             if not (input_ids == self.audio_token_id).any():
+                 raise ValueError("Audio token <audio> must be present in input")
+
+             # Expand audio tokens to match audio embedding length
+             num_audio_tokens = audio_embeds.shape[1]
+             expanded_input_ids = self._expand_audio_tokens(input_ids, num_audio_tokens)
+
+             # Prepare inputs_embeds with audio embeddings injected
+             inputs_embeds = self._prepare_audio_inputs_embeds(expanded_input_ids, audio_embeds)
+
+             # Expand attention mask to match new sequence length (vectorized)
+             if attention_mask is not None:
+                 full_attention_mask = self._expand_for_audio_tokens(
+                     input_ids, attention_mask, num_audio_tokens, fill_value=1
+                 )
+             else:
+                 full_attention_mask = None
+
+             # Expand labels to match new sequence length (vectorized, mark audio tokens as -100)
+             if labels is not None:
+                 labels = self._expand_for_audio_tokens(
+                     input_ids, labels, num_audio_tokens, fill_value=-100
+                 )
+         else:
+             inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
+             full_attention_mask = attention_mask
+             use_cache = kwargs.pop("use_cache", None)
+
+         # Standard forward pass with built-in loss computation
+         return self.decoder(
+             inputs_embeds=inputs_embeds,
+             attention_mask=full_attention_mask,
+             labels=labels,
+             use_cache=use_cache if use_cache is not None else False,
+             **kwargs,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         input_features: Optional[torch.Tensor] = None,  # For Whisper
+         system_prompt: Optional[str] = None,
+         user_prompt: Optional[str] = None,
+         task: Optional[str] = None,
+         **generate_kwargs,
+     ) -> Union[
+         torch.Tensor,
+         GenerateDecoderOnlyOutput,
+         GenerateEncoderDecoderOutput,
+         GenerateBeamDecoderOnlyOutput,
+         GenerateBeamEncoderDecoderOutput,
+     ]:
+         audio_inputs = input_values if input_values is not None else input_features
+         if audio_inputs is None:
+             raise ValueError("input_values or input_features must be provided for generation")
+
+         audio_embeds = self._encode_audio(audio_inputs)
+         batch_size = audio_embeds.shape[0]
+         device = audio_embeds.device
+
+         if system_prompt is None:
+             system_prompt = self.system_prompt
+
+         if user_prompt is None:
+             user_prompt = (
+                 self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
+                 or "Transcribe: <audio>"
+             )
+
+         messages = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append(
+             {
+                 "role": "user",
+                 "content": user_prompt,
+             }
+         )
+
+         prompt_ids = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors="pt",
+             enable_thinking=False,
+         ).to(device)
+
+         if len(prompt_ids.shape) == 1:
+             prompt_ids = prompt_ids.unsqueeze(0)
+
+         if prompt_ids.shape[0] == 1 and batch_size > 1:
+             prompt_ids = prompt_ids.expand(batch_size, -1)
+
+         if not (prompt_ids == self.audio_token_id).any():
+             raise ValueError("Audio token <audio> not found in prompt")
+
+         # Expand audio tokens to match audio embedding length
+         num_audio_tokens = audio_embeds.shape[1]
+         expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
+
+         # Prepare inputs_embeds with audio embeddings injected
+         inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
+
+         # Create attention mask for expanded sequence
+         total_seq_len = inputs_embeds.shape[1]
+         attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
+
+         # Apply generation defaults from config
+         config_params = [
+             "max_new_tokens",
+             "min_new_tokens",
+             "num_beams",
+             "do_sample",
+             "temperature",
+             "top_k",
+             "top_p",
+             "repetition_penalty",
+             "length_penalty",
+             "no_repeat_ngram_size",
+             "early_stopping",
+         ]
+         for param in config_params:
+             if hasattr(self.config, param) and getattr(self.config, param) is not None:
+                 generate_kwargs.setdefault(param, getattr(self.config, param))
+
+         # Add special token defaults
+         generate_kwargs.setdefault("use_cache", True)
+         generate_kwargs.setdefault(
+             "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+         )
+         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
+
+         # Track the prompt length to extract only newly generated tokens
+         prompt_length = expanded_prompt_ids.shape[1]
+
+         # Generate the full sequence
+         generated_ids = self.decoder.generate(
+             input_ids=expanded_prompt_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             **generate_kwargs,
+         )
+
+         # Return only the newly generated tokens (exclude the prompt)
+         return generated_ids[:, prompt_length:]
+
+     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
+         import shutil
+         from pathlib import Path as PathlibPath
+
+         save_dir = PathlibPath(save_directory)
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         actual_vocab_size = self.decoder.config.vocab_size
+         self.config.vocab_size = actual_vocab_size
+         self.config.text_config.vocab_size = actual_vocab_size
+
+         if hasattr(self.encoder.config, "num_mel_bins"):
+             self.config.audio_config.num_mel_bins = self.encoder.config.num_mel_bins
+
+         # Use parent class to save config and model.safetensors
+         super().save_pretrained(save_dir, **kwargs)
+
+         self.tokenizer.save_pretrained(save_dir)
+
+         # For Whisper models, ensure feature_size matches num_mel_bins from encoder config
+         if hasattr(self.encoder.config, "num_mel_bins"):
+             # For Whisper models, explicitly set the correct feature_size before saving
+             num_mel_bins = self.encoder.config.num_mel_bins
+             self.feature_extractor.feature_size = num_mel_bins
+             self.feature_extractor.num_mel_bins = num_mel_bins  # Explicitly set num_mel_bins
+             if hasattr(self.feature_extractor, "n_mels"):
+                 self.feature_extractor.n_mels = num_mel_bins
+             self.feature_extractor.nb_max_frames = 3000  # Whisper's max frames
+
+         self.get_processor().save_pretrained(save_dir)
+
+         src_dir = PathlibPath(__file__).parent
+         for asr_file in src_dir.glob("asr_*.py"):
+             shutil.copy(asr_file, save_dir / asr_file.name)
+
+
+ AutoConfig.register("asr_model", ASRConfig)
+ AutoModel.register(ASRConfig, ASRModel)
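
To make the cumsum-based index math in `_get_audio_expansion_details` concrete, here is a self-contained toy run of the same logic outside the class (standalone sketch with a made-up `audio_token_id`):

```python
# Toy walkthrough of the audio-token expansion math used above.
import torch

audio_token_id = 99                            # hypothetical id for <audio>
num_audio_tokens = 3                           # each <audio> expands to 3 embedding slots
input_ids = torch.tensor([[10, 99, 11, 12]])   # one <audio> at position 1

audio_mask = input_ids == audio_token_id
token_counts = torch.where(audio_mask, num_audio_tokens, 1)  # [1, 3, 1, 1]
cumsum_counts = torch.cumsum(token_counts, dim=1)            # [1, 4, 5, 6]
new_start_positions = torch.cat(
    [torch.zeros(1, 1, dtype=torch.long), cumsum_counts[:, :-1]], dim=1
)                                                            # [0, 1, 4, 5]
new_seq_len = input_ids.shape[1] - 1 + num_audio_tokens      # 6

# Token 10 lands at new index 0, the three audio slots occupy 1..3, and
# tokens 11, 12 land at 4 and 5 - matching _expand_audio_tokens above.
print(new_seq_len, new_start_positions.tolist())
```

For scale: Whisper's 30 s window yields 1500 encoder frames, so with the default `projector_pool_stride=2` the projector emits 750 audio embeddings, and each `<audio>` placeholder expands to 750 slots.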
asr_pipeline.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ import transformers
5
+ from truecase import get_true_case
6
+
7
+ try:
8
+ from .asr_modeling import ASRModel
9
+ except ImportError:
10
+ from asr_modeling import ASRModel # type: ignore[no-redef]
11
+
12
+
13
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
14
+ model: ASRModel
15
+
16
+ def __init__(self, model: ASRModel, **kwargs):
17
+ feature_extractor = kwargs.pop("feature_extractor", model.feature_extractor)
18
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
19
+
20
+ super().__init__(
21
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
22
+ )
23
+
24
+ # Initialize text normalizer (same as train.py)
25
+ if hasattr(tokenizer, "normalize"):
26
+ self.text_normalizer = tokenizer
27
+ else:
28
+ # Fallback to whisper-tiny tokenizer for its normalize() method only
29
+ from transformers import WhisperTokenizer
30
+ self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
31
+
32
+ def __call__(self, inputs, **kwargs):
33
+ generate_kwargs = {}
34
+ for key in [
35
+ "max_new_tokens",
36
+ "num_beams",
37
+ "do_sample",
38
+ "length_penalty",
39
+ "repetition_penalty",
40
+ "no_repeat_ngram_size",
41
+ "early_stopping",
42
+ "num_beam_groups",
43
+ "diversity_penalty",
44
+ "top_k",
45
+ "temperature",
46
+ "top_p",
47
+ "user_prompt",
48
+ "task",
49
+ "text_input",
50
+ ]:
51
+ if key in kwargs:
52
+ generate_kwargs[key] = kwargs.pop(key)
53
+
54
+ # Handle text-only mode
55
+ task = generate_kwargs.get("task")
56
+ if task == "text" or generate_kwargs.get("text_input"):
57
+ return self._process_text_only(generate_kwargs)
58
+
59
+ if isinstance(inputs, list):
60
+ results = []
61
+ for single_input in inputs:
62
+ result = self.__call__(single_input, **kwargs, **generate_kwargs)
63
+ results.append(result)
64
+ return results
65
+
66
+ model_inputs = self.preprocess(inputs, **kwargs)
67
+
68
+ from collections.abc import Iterator
69
+
70
+ if isinstance(model_inputs, Iterator):
71
+ # Convert iterator to list to process chunks
72
+ chunks = list(model_inputs)
73
+
74
+ all_outputs = []
75
+ for _chunk_num, chunk in enumerate(chunks, start=1):
76
+ chunk_output = self._forward(chunk, **generate_kwargs)
77
+ # Move tensors to CPU before adding to outputs
78
+ for key, value in chunk_output.items():
79
+ if torch.is_tensor(value):
80
+ chunk_output[key] = value.cpu()
81
+ all_outputs.append(chunk_output)
82
+
83
+ # Merge chunks and decode ourselves to ensure skip_special_tokens=True
84
+ all_tokens: list[int] = []
85
+ for output in all_outputs:
86
+ tokens = output.get("tokens")
87
+ if tokens is None:
88
+ tokens = output.get("generated_ids")
89
+ if tokens is not None:
90
+ if torch.is_tensor(tokens):
91
+ tokens = tokens.cpu()
92
+ if len(tokens.shape) > 1:
93
+ tokens = tokens[0]
94
+ all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)
95
+
96
+ # Decode the merged tokens with skip_special_tokens
97
+ text = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
98
+ text = text.strip()
99
+
100
+ # Apply Whisper normalization (matches training)
101
+ text = self.text_normalizer.normalize(text)
102
+
103
+ # Apply truecasing for proper capitalization
104
+ text = get_true_case(text)
105
+
106
+ return {"text": text}
107
+
108
+ model_outputs = self._forward(model_inputs, **generate_kwargs)
109
+ return self.postprocess(model_outputs)
110
+
111
+ def preprocess(self, inputs, **preprocess_params):
112
+ if isinstance(inputs, list):
113
+ raise ValueError("Lists should not reach preprocess - bug in __call__")
114
+
115
+ # Set default chunking to 30 seconds with 5 second overlap
116
+ preprocess_params.setdefault("chunk_length_s", 30)
117
+ preprocess_params.setdefault("stride_length_s", (5, 5))
118
+
119
+ # Handle different formats from datasets
120
+ if isinstance(inputs, dict):
121
+ if "bytes" in inputs:
122
+ # Decode bytes to audio array using torchcodec
123
+ import tempfile
124
+
125
+ from torchcodec.decoders import AudioDecoder
126
+
127
+ wav_bytes = inputs["bytes"]
128
+ # Write to temp file for torchcodec to read
129
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
130
+ f.write(wav_bytes)
131
+ temp_path = f.name
132
+ try:
133
+ decoder = AudioDecoder(temp_path)
134
+ # Get all audio samples
135
+ audio_result = decoder.get_all_samples()
136
+ audio_tensor = audio_result.data
137
+ sample_rate = audio_result.sample_rate
138
+ inputs = {"raw": audio_tensor.squeeze().numpy(), "sampling_rate": sample_rate}
139
+ finally:
140
+ from pathlib import Path
141
+
142
+ Path(temp_path).unlink()
143
+ elif "array" in inputs:
144
+ # Convert "array" key to "raw" key
145
+ inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
146
+ # If it already has "raw" and "sampling_rate", it's good to go
147
+ elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
148
+ # Audio object with attributes (not dict)
149
+ inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
150
+ elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
151
+ inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
152
+ elif torch.is_tensor(inputs):
153
+ inputs = {
154
+ "raw": inputs.cpu().numpy(),
155
+ "sampling_rate": self.model.config.audio_sample_rate,
156
+ }
157
+
158
+ return super().preprocess(inputs, **preprocess_params)
159
+
160
+ def _forward(self, model_inputs, **generate_kwargs):
161
+ # Extract task and set sampling parameters
162
+ task = generate_kwargs.pop("task", None)
163
+
164
+ # Task-specific sampling parameters
165
+ task_params: Dict[str, Dict[str, Any]] = {
166
+ "transcribe": {"do_sample": False},
167
+ "emotion": {"do_sample": True, "temperature": 0.7},
168
+ "describe": {"do_sample": True, "temperature": 0.7},
169
+ "continue": {"do_sample": True, "temperature": 1.0},
170
+ }
171
+
172
+ if task in task_params:
173
+ for key, value in task_params[task].items():
174
+ generate_kwargs.setdefault(key, value)
175
+
176
+ # Extract audio inputs from various formats
177
+ is_last = True
178
+ audio_inputs = None
179
+ is_whisper = False # Track if this is Whisper input
180
+
181
+ # Normalize model_inputs to dict format
182
+ if isinstance(model_inputs, torch.Tensor):
183
+ audio_inputs = model_inputs
184
+ elif isinstance(model_inputs, (list, tuple)) and model_inputs:
185
+ model_inputs = (
186
+ model_inputs[0]
187
+ if isinstance(model_inputs[0], dict)
188
+ else {"input_values": model_inputs[0]}
189
+ )
190
+
191
+ if isinstance(model_inputs, dict):
192
+ # Pop metadata fields
193
+ is_last = model_inputs.pop("is_last", True)
194
+ model_inputs.pop("stride", None)
195
+ # Get audio input (Whisper uses input_features, others use input_values)
196
+ if "input_features" in model_inputs:
197
+ audio_inputs = model_inputs["input_features"]
198
+ is_whisper = True
199
+ else:
200
+ audio_inputs = model_inputs.get("input_values")
201
+
202
+ if audio_inputs is None:
203
+ raise ValueError(
204
+ f"Could not extract input_values or input_features from {type(model_inputs)}"
205
+ )
206
+
207
+ if isinstance(audio_inputs, torch.Tensor):
208
+ audio_inputs = audio_inputs.to(self.model.device)
209
+ else:
210
+ raise ValueError(f"audio inputs must be a tensor, got {type(audio_inputs)}")
211
+
212
+ im_end_id = self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
213
+ generate_kwargs.setdefault("eos_token_id", im_end_id)
214
+ generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
215
+
216
+ # Pass the appropriate input type to generate
217
+ if is_whisper:
218
+ # Whisper model - use input_features
219
+ generated_ids = self.model.generate(
220
+ input_features=audio_inputs,
221
+ system_prompt=self.model.config.system_prompt,
222
+ task=task,
223
+ **generate_kwargs,
224
+ )
225
+ else:
226
+ # Wav2Vec2/HuBERT model - use input_values
227
+ generated_ids = self.model.generate(
228
+ input_values=audio_inputs,
229
+ system_prompt=self.model.config.system_prompt,
230
+ task=task,
231
+ **generate_kwargs,
232
+ )
233
+
234
+ return {"tokens": generated_ids, "is_last": is_last}
235
+
236
+ def _process_text_only(self, generate_kwargs):
237
+ """Process text-only input without audio encoding."""
238
+ text_input = generate_kwargs.pop("text_input", None)
239
+ if text_input is None:
240
+ raise ValueError("text_input is required for text task")
241
+
242
+ # Remove task from generate_kwargs to avoid duplicate argument
243
+ generate_kwargs.pop("task", None)
244
+
245
+ # Generate text using the model
246
+ generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
247
+
248
+ # Decode the generated text
249
+ generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
250
+
251
+ return {"text": generated_text}
252
+
253
+ def postprocess(
254
+ self, model_outputs: Dict[str, Any], return_timestamps=None, return_language=None
255
+ ):
256
+ # Handle chunked outputs from iterator
257
+ if isinstance(model_outputs, list):
258
+ # Move all tensors to CPU before calling parent postprocess
259
+ for output_dict in model_outputs:
260
+ for key, value in output_dict.items():
261
+ if torch.is_tensor(value):
262
+ output_dict[key] = value.cpu()
263
+ return super().postprocess(model_outputs)
264
+
265
+ if "is_last" in model_outputs:
266
+ model_outputs.pop("is_last")
267
+
268
+ tokens = model_outputs.get("tokens")
269
+ if tokens is None:
270
+ tokens = model_outputs.get("generated_ids")
271
+
272
+ if tokens is None:
273
+ raise ValueError(
274
+ f"Expected 'tokens' or 'generated_ids' in model_outputs, got: {model_outputs.keys()}"
275
+ )
276
+
277
+ # Move to CPU if on MPS or other device
278
+ if torch.is_tensor(tokens) and tokens.device.type != "cpu":
279
+ tokens = tokens.cpu()
280
+
281
+ if len(tokens.shape) > 1:
282
+ tokens = tokens[0]
283
+
284
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True)
285
+ text = text.strip()
286
+
287
+ # Apply Whisper normalization (matches training)
288
+ text = self.text_normalizer.normalize(text)
289
+
290
+ # Apply truecasing for proper capitalization
291
+ text = get_true_case(text)
292
+
293
+ return {"text": text}
asr_processing.py ADDED
@@ -0,0 +1,78 @@
+ import transformers
+ from transformers import AutoTokenizer, ProcessorMixin
+
+ # Handle both package and standalone imports
+ try:
+     from .asr_config import ASRConfig
+ except ImportError:
+     from asr_config import ASRConfig  # type: ignore[no-redef]
+
+
+ class ASRProcessor(ProcessorMixin):
+     """Generic processor that can handle both Wav2Vec2 and Whisper feature extractors."""
+
+     feature_extractor_class = "AutoFeatureExtractor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, feature_extractor, tokenizer):
+         self.feature_extractor = feature_extractor
+         self.tokenizer = tokenizer
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         from transformers import AutoFeatureExtractor
+
+         # Load feature extractor and tokenizer from saved model directory
+         feature_extractor = AutoFeatureExtractor.from_pretrained(
+             pretrained_model_name_or_path, **kwargs
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path, trust_remote_code=True, **kwargs
+         )
+
+         return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+     def save_pretrained(self, save_directory, **kwargs):
+         """Override save_pretrained to avoid attribute errors from base class."""
+         import json
+         from pathlib import Path
+
+         save_path = Path(save_directory)
+         save_path.mkdir(parents=True, exist_ok=True)
+
+         # Save the feature extractor (this creates preprocessor_config.json with all feature extractor settings)
+         if self.feature_extractor is not None:
+             self.feature_extractor.save_pretrained(save_directory)
+
+         # Save the tokenizer
+         if self.tokenizer is not None:
+             self.tokenizer.save_pretrained(save_directory)
+
+         # Load the existing preprocessor_config.json and add processor-specific metadata
+         config_path = save_path / "preprocessor_config.json"
+         if config_path.exists():
+             with config_path.open() as f:
+                 processor_config = json.load(f)
+         else:
+             processor_config = {}
+
+         # Add/update processor metadata while preserving feature extractor settings
+         feature_extractor_type = self.feature_extractor.__class__.__name__
+         processor_config.update(
+             {
+                 "processor_class": self.__class__.__name__,
+                 "feature_extractor_class": self.feature_extractor_class,
+                 "tokenizer_class": self.tokenizer_class,
+                 "feature_extractor_type": feature_extractor_type,  # Dynamic based on actual type
+                 "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
+             }
+         )
+
+         # Save the merged config
+         with config_path.open("w") as f:
+             json.dump(processor_config, f, indent=2)
+
+
+ ASRProcessor.register_for_auto_class()
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
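
Assuming `asr_config.py` in this repo defines `ASRConfig`, the two registration calls above let `AutoProcessor` resolve to `ASRProcessor`. A loading sketch (the directory path is illustrative):

```python
from transformers import AutoProcessor

# The auto_map written by save_pretrained points AutoProcessor at
# asr_processing.ASRProcessor, so trust_remote_code=True is required.
processor = AutoProcessor.from_pretrained("./saved_model", trust_remote_code=True)

print(type(processor.feature_extractor).__name__)  # WhisperFeatureExtractor for this checkpoint
print(type(processor.tokenizer).__name__)
```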
chat_template.jinja ADDED
@@ -0,0 +1,96 @@
+ {# ───── defaults ───── #}
+ {%- if enable_thinking is not defined -%}
+   {%- set enable_thinking = true -%}
+ {%- endif -%}
+ {%- set system_message = "" -%}
+ {%- set custom_instructions = "" -%}
+
+ {# ───── reasoning mode ───── #}
+ {%- if enable_thinking -%}
+   {%- set reasoning_mode = "/think" -%}
+ {%- else -%}
+   {%- set reasoning_mode = "/no_think" -%}
+ {%- endif -%}
+
+ {# ───── header (system message) ───── #}
+ {{- "<|im_start|>system\n" -}}
+
+ {%- if messages[0].role == "system" -%}
+   {%- set system_message = messages[0].content -%}
+   {%- if "/no_think" in system_message -%}
+     {%- set reasoning_mode = "/no_think" -%}
+   {%- elif "/think" in system_message -%}
+     {%- set reasoning_mode = "/think" -%}
+   {%- endif -%}
+   {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
+ {%- endif -%}
+
+ {%- if "/system_override" in system_message -%}
+   {{- custom_instructions.replace("/system_override", "").rstrip() -}}
+   {{- "<|im_end|>\n" -}}
+ {%- else -%}
+   {{- "## Metadata\n\n" -}}
+   {{- "Knowledge Cutoff Date: June 2025\n" -}}
+   {%- set today = strftime_now("%d %B %Y") -%}
+   {{- "Today Date: " ~ today ~ "\n" -}}
+   {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
+
+   {{- "## Custom Instructions\n\n" -}}
+   {%- if custom_instructions -%}
+     {{- custom_instructions + "\n\n" -}}
+   {%- elif reasoning_mode == "/think" -%}
+     {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
+   {%- else -%}
+     {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
+   {%- endif -%}
+
+   {%- if xml_tools or python_tools or tools -%}
+     {{- "### Tools\n\n" -}}
+     {%- if xml_tools or tools -%}
+       {%- if tools -%}
+         {%- set xml_tools = tools -%}
+       {%- endif -%}
+       {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
+       {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
+         {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
+       {%- endfor -%}
+       {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
+       {{- xml_tool_string -}}
+     {%- endif -%}
+     {%- if python_tools -%}
+       {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continue reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
+       {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
+         {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
+       {%- endfor -%}
+       {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
+       {{- python_tool_string -}}
+     {%- endif -%}
+     {{- "\n\n" -}}
+   {%- endif -%}
+   {{- "<|im_end|>\n" -}}
+ {%- endif -%}
+ {# ───── main loop ───── #}
+ {%- for message in messages -%}
+   {%- set content = message.content if message.content is string else "" -%}
+   {%- if message.role == "user" -%}
+     {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
+   {%- elif message.role == "assistant" -%}
+     {% generation %}
+     {%- if reasoning_mode == "/think" -%}
+       {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+     {%- else -%}
+       {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+     {%- endif -%}
+     {% endgeneration %}
+   {%- elif message.role == "tool" -%}
+     {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
+   {%- endif -%}
+ {%- endfor -%}
+ {# ───── generation prompt ───── #}
+ {%- if add_generation_prompt -%}
+   {%- if reasoning_mode == "/think" -%}
+     {{ "<|im_start|>assistant\n" }}
+   {%- else -%}
+     {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
+   {%- endif -%}
+ {%- endif -%}
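
A rendering sketch for the template above, assuming it is attached to this repo's tokenizer. `apply_chat_template` passes extra keyword arguments such as `enable_thinking` into the template context, and a `/think` or `/no_think` marker in the system message overrides that default:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./saved_model", trust_remote_code=True)  # illustrative path

messages = [
    {"role": "system", "content": "Answer briefly. /no_think"},
    {"role": "user", "content": "What is the capital of France?"},
]

prompt = tok.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
)
# "/no_think" in the system text wins over enable_thinking=True, so the
# generation prompt ends with a pre-filled empty <think>\n\n</think> block.
print(prompt)
```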
preprocessor_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "chunk_length": 30,
+   "dither": 0.0,
+   "feature_extractor_type": "WhisperFeatureExtractor",
+   "feature_size": 128,
+   "hop_length": 160,
+   "n_fft": 400,
+   "n_samples": 480000,
+   "nb_max_frames": 3000,
+   "num_mel_bins": 128,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "ASRProcessor",
+   "return_attention_mask": false,
+   "sampling_rate": 16000,
+   "feature_extractor_class": "AutoFeatureExtractor",
+   "tokenizer_class": "AutoTokenizer",
+   "auto_map": {
+     "AutoProcessor": "asr_processing.ASRProcessor"
+   }
+ }
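
The derived fields in this config are mutually consistent; a quick arithmetic check:

```python
# Sanity check on the Whisper feature-extractor settings above.
chunk_length = 30        # seconds per chunk
sampling_rate = 16_000   # Hz
hop_length = 160         # samples per frame hop (10 ms at 16 kHz)

assert chunk_length * sampling_rate == 480_000  # "n_samples"
assert 480_000 // hop_length == 3_000           # "nb_max_frames"
# feature_size == num_mel_bins == 128: 128 log-mel channels per frame.
```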
special_tokens_map.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "additional_special_tokens": [
+     {
+       "content": "<audio>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false
+     }
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|finetune_right_pad_id|>"
+ }
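
These entries line up with the pipeline code earlier in this commit: `<|im_end|>` is the eos token that `_forward` sets as `eos_token_id`, while `<audio>` is an extra special token, presumably marking where audio features enter the text sequence. A lookup sketch (the path is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./saved_model", trust_remote_code=True)

print(tok.convert_tokens_to_ids("<|im_end|>"))  # used as eos_token_id in _forward
print(tok.convert_tokens_to_ids("<audio>"))     # additional special token
print(tok.pad_token)                            # "<|finetune_right_pad_id|>"
```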
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
+ size 17209003
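
The three lines above are a Git LFS pointer, not the tokenizer itself; the ~17 MB `tokenizer.json` is materialized by LFS on checkout. A quick way to tell which one a local clone has:

```python
from pathlib import Path

p = Path("tokenizer.json")
if p.read_text(errors="ignore").startswith("version https://git-lfs"):
    print("LFS pointer only - run `git lfs pull` to fetch the real file")
else:
    print(f"materialized: {p.stat().st_size} bytes")
```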
tokenizer_config.json ADDED
Binary file (50.6 kB).