Training in progress - step 10

Files changed:

- .gitattributes +1 -0
- README.md +199 -0
- asr_config.py +131 -0
- asr_modeling.py +874 -0
- asr_pipeline.py +293 -0
- asr_processing.py +78 -0
- chat_template.jinja +94 -0
- preprocessor_config.json +21 -0
- special_tokens_map.json +19 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED

@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->


## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]
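
Until the card is filled in, the `auto_map` and `custom_pipelines` entries registered in `asr_config.py` point at the intended entry point. A minimal sketch, assuming the checkpoint is used through the custom pipeline (the repo id and audio path below are placeholders, not values from this repo):

```python
import transformers

# asr_config.py registers a custom "automatic-speech-recognition" pipeline
# (asr_pipeline.ASRPipeline), so trust_remote_code=True is required to load
# the ASRConfig / ASRModel / ASRProcessor classes shipped with the checkpoint.
asr = transformers.pipeline(
    "automatic-speech-recognition",
    model="<repo-id>",  # placeholder repo id
    trust_remote_code=True,
)

# ASRConfig defaults to 16 kHz audio (audio_sample_rate=16000).
print(asr("sample.wav"))  # hypothetical audio file
```
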
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary


## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
asr_config.py ADDED

@@ -0,0 +1,131 @@
from typing import Optional

import transformers


class ASRConfig(transformers.PretrainedConfig):
    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "openai/whisper-large-v3-turbo",
        text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        audio_downsample_rate: int = 5,  # Deprecated: use projector_pool_stride instead
        num_beams: Optional[int] = None,
        system_prompt: str = "/no_think /system_override",
        user_prompt: str = "Transcribe: <audio>",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Audio processing constants
        audio_sample_rate: int = 16000,
        # Projector initialization constants
        projector_init_std: float = 0.02,
        projector_pool_stride: int = 2,  # AvgPool1d stride (2 = 4x total with Whisper, 1 = no pooling)
        projector_hidden_dim: Optional[int] = None,  # SwiGLU hidden dimension (defaults to encoder_dim * 4)
        projector_dropout: float = 0.1,  # Dropout rate for projector layers
        # Inference parameters
        inference_diversity_penalty: float = 0.0,
        inference_warmup_tokens: int = 10,
        # Generation parameters
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        do_sample: Optional[bool] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        early_stopping: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        # Set default generation parameters
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 128,
            "min_new_tokens": 1,
            "do_sample": False,
            "repetition_penalty": 1.05,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
        }

        # Apply defaults (config.json values take precedence)
        kwargs = {**generation_defaults, **kwargs}

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.audio_downsample_rate = audio_downsample_rate
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.audio_sample_rate = audio_sample_rate
        self.projector_init_std = projector_init_std
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_dropout = projector_dropout
        self.inference_diversity_penalty = inference_diversity_penalty
        self.inference_warmup_tokens = inference_warmup_tokens

        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
        else:
            self.text_config = kwargs.pop("text_config")

        # Ensure configs are PretrainedConfig objects (in case loaded from dict)
        if isinstance(self.text_config, dict):
            # Reconstruct config from dict using the model_type stored in the dict
            model_type = self.text_config.get("model_type")
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.text_config = config_class(**self.text_config)
            else:
                # Fallback: try to load from model_id
                self.text_config = transformers.AutoConfig.from_pretrained(
                    text_model_id, trust_remote_code=True
                )

        if isinstance(self.audio_config, dict):
            model_type = self.audio_config.get("model_type")
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


# Register the config with transformers
# This is needed for AutoConfig.from_pretrained to work
transformers.AutoConfig.register("asr_model", ASRConfig)
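
A quick sketch of how the composed config behaves (assumes `asr_config.py` is importable and network access is available, since a bare `ASRConfig()` resolves the Whisper and SmolLM3 sub-configs from the Hub):

```python
from asr_config import ASRConfig

cfg = ASRConfig()

# generation_defaults is merged into kwargs before super().__init__(),
# so the defaults surface as ordinary config attributes:
assert cfg.max_new_tokens == 128
assert cfg.num_beams == 1

# The sub-configs are materialized as real PretrainedConfig objects:
print(type(cfg.audio_config).__name__)  # e.g. WhisperConfig
print(type(cfg.text_config).__name__)   # config class of the text model
```
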
asr_modeling.py ADDED

@@ -0,0 +1,874 @@
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F # noqa: N812
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoConfig,
|
| 9 |
+
AutoModel,
|
| 10 |
+
AutoModelForCausalLM,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
PreTrainedModel,
|
| 13 |
+
Wav2Vec2FeatureExtractor,
|
| 14 |
+
)
|
| 15 |
+
from transformers.generation.utils import (
|
| 16 |
+
GenerateBeamDecoderOnlyOutput,
|
| 17 |
+
GenerateBeamEncoderDecoderOutput,
|
| 18 |
+
GenerateDecoderOnlyOutput,
|
| 19 |
+
GenerateEncoderDecoderOutput,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from .asr_config import ASRConfig
|
| 24 |
+
except ImportError:
|
| 25 |
+
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SwiGLU(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
SwiGLU activation MLP - based on LlamaMLP but with flexible output dimension.
|
| 31 |
+
|
| 32 |
+
This implements the same gated activation pattern as transformers.models.llama.modeling_llama.LlamaMLP,
|
| 33 |
+
but allows for different input/output dimensions (needed for cross-modal projection).
|
| 34 |
+
|
| 35 |
+
Structure: w1 (gate), w2 (up), w3 (down) with w3(silu(w1) * w2)
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 41 |
+
self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 42 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 43 |
+
self.act = nn.SiLU()
|
| 44 |
+
self.dropout = nn.Dropout(dropout)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
x_gate = self.act(self.w1(x))
|
| 48 |
+
x_val = self.w2(x)
|
| 49 |
+
x = x_gate * x_val
|
| 50 |
+
x = self.dropout(x) # Apply dropout after the gating operation
|
| 51 |
+
return self.w3(x)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class AudioProjector(nn.Module):
|
| 55 |
+
"""
|
| 56 |
+
AudioProjector using a SwiGLU MLP with dropout.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, config):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.k = getattr(config, "projector_pool_stride", 2) # Downsampling rate
|
| 62 |
+
in_dim = config.encoder_dim * self.k # Input is k frames concatenated
|
| 63 |
+
out_dim = config.llm_dim
|
| 64 |
+
hidden_dim = config.projector_hidden_dim
|
| 65 |
+
if hidden_dim is None:
|
| 66 |
+
hidden_dim = config.encoder_dim * 4 # Default: 4x encoder dim for SwiGLU
|
| 67 |
+
|
| 68 |
+
# Get dropout rate from config
|
| 69 |
+
dropout_rate = getattr(config, "projector_dropout", 0.1)
|
| 70 |
+
|
| 71 |
+
# SwiGLU MLP (now takes concatenated frames as input) with dropout
|
| 72 |
+
self.proj = SwiGLU(in_dim, hidden_dim, out_dim, dropout=dropout_rate)
|
| 73 |
+
|
| 74 |
+
# Optional output dropout layer for additional regularization
|
| 75 |
+
self.output_dropout = nn.Dropout(dropout_rate)
|
| 76 |
+
|
| 77 |
+
# Initialize weights following LLaMA-style initialization for SwiGLU
|
| 78 |
+
# Uses smaller std to account for the multiplicative gating
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
# Standard deviation from config or default (0.02 is common for transformers)
|
| 81 |
+
std = getattr(config, "projector_init_std", 0.02)
|
| 82 |
+
|
| 83 |
+
# Initialize gate and up projections
|
| 84 |
+
nn.init.normal_(self.proj.w1.weight, mean=0.0, std=std)
|
| 85 |
+
nn.init.normal_(self.proj.w2.weight, mean=0.0, std=std)
|
| 86 |
+
|
| 87 |
+
# Initialize down projection with scaling to preserve variance after SwiGLU
|
| 88 |
+
# The 1/sqrt(2) factor accounts for the multiplicative interaction
|
| 89 |
+
nn.init.normal_(self.proj.w3.weight, mean=0.0, std=std / (2**0.5))
|
| 90 |
+
|
| 91 |
+
def forward(self, x):
|
| 92 |
+
# x: [batch, seq_len, dim]
|
| 93 |
+
batch_size, seq_len, dim = x.size()
|
| 94 |
+
|
| 95 |
+
# Ensure input dtype matches the projector weights
|
| 96 |
+
# This is crucial for MPS devices where encoder may output bfloat16
|
| 97 |
+
# but projector weights might be in float32 when loaded from checkpoint
|
| 98 |
+
target_dtype = self.proj.w1.weight.dtype
|
| 99 |
+
if x.dtype != target_dtype:
|
| 100 |
+
x = x.to(target_dtype)
|
| 101 |
+
|
| 102 |
+
# Pad the sequence to be divisible by k instead of truncating
|
| 103 |
+
remainder = seq_len % self.k
|
| 104 |
+
if remainder:
|
| 105 |
+
pad_len = self.k - remainder
|
| 106 |
+
x = F.pad(x, (0, 0, 0, pad_len))
|
| 107 |
+
|
| 108 |
+
# Reshape for temporal compression - concatenate k consecutive frames
|
| 109 |
+
x = x.contiguous().view(batch_size, -1, dim * self.k)
|
| 110 |
+
|
| 111 |
+
# Apply SwiGLU block
|
| 112 |
+
x = self.proj(x)
|
| 113 |
+
|
| 114 |
+
# Apply output dropout for additional regularization
|
| 115 |
+
return self.output_dropout(x)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ASRModel(PreTrainedModel):
|
| 119 |
+
config_class = ASRConfig
|
| 120 |
+
base_model_prefix = "model"
|
| 121 |
+
main_input_name = "input_values"
|
| 122 |
+
_supports_flash_attn_2 = True
|
| 123 |
+
supports_gradient_checkpointing = True
|
| 124 |
+
_is_loading_from_pretrained: bool = False
|
| 125 |
+
_pretrained_model_path: Optional[str] = None
|
| 126 |
+
|
| 127 |
+
# Task to prompt mapping for generation
|
| 128 |
+
TASK_PROMPTS = {
|
| 129 |
+
"transcribe": "Transcribe: <audio>",
|
| 130 |
+
"continue": "Continue: <audio>",
|
| 131 |
+
"describe": "Describe: <audio>",
|
| 132 |
+
"emotion": "Emotion: <audio>",
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def _create_feature_extractor(audio_model_id: str):
|
| 137 |
+
"""Factory method to create the appropriate feature extractor."""
|
| 138 |
+
is_whisper = "whisper" in audio_model_id.lower()
|
| 139 |
+
if is_whisper:
|
| 140 |
+
from transformers import WhisperConfig, WhisperFeatureExtractor
|
| 141 |
+
|
| 142 |
+
encoder_config = WhisperConfig.from_pretrained(audio_model_id)
|
| 143 |
+
num_mel_bins = encoder_config.num_mel_bins
|
| 144 |
+
return WhisperFeatureExtractor.from_pretrained(
|
| 145 |
+
audio_model_id,
|
| 146 |
+
feature_size=num_mel_bins,
|
| 147 |
+
)
|
| 148 |
+
return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
| 152 |
+
from transformers import AutoFeatureExtractor
|
| 153 |
+
|
| 154 |
+
config = kwargs.pop("config", None)
|
| 155 |
+
if config is None:
|
| 156 |
+
config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 157 |
+
|
| 158 |
+
# Load feature extractor from saved model directory
|
| 159 |
+
kwargs["feature_extractor"] = AutoFeatureExtractor.from_pretrained(
|
| 160 |
+
pretrained_model_name_or_path, **kwargs
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
cls._is_loading_from_pretrained = True
|
| 164 |
+
cls._pretrained_model_path = pretrained_model_name_or_path
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
# Let parent class handle loading config and model.safetensors
|
| 168 |
+
model = super().from_pretrained(
|
| 169 |
+
pretrained_model_name_or_path, *args, config=config, **kwargs
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Convert projector to target dtype after loading weights
|
| 173 |
+
target_dtype = getattr(torch, config.model_dtype)
|
| 174 |
+
model.projector = model.projector.to(dtype=target_dtype)
|
| 175 |
+
|
| 176 |
+
return model
|
| 177 |
+
finally:
|
| 178 |
+
cls._is_loading_from_pretrained = False
|
| 179 |
+
del cls._pretrained_model_path
|
| 180 |
+
|
| 181 |
+
def __init__(self, config: ASRConfig, **kwargs):
|
| 182 |
+
super().__init__(config)
|
| 183 |
+
|
| 184 |
+
feature_extractor = kwargs.pop("feature_extractor", None)
|
| 185 |
+
|
| 186 |
+
self.system_prompt = config.system_prompt
|
| 187 |
+
|
| 188 |
+
self.encoder = self._create_encoder(config)
|
| 189 |
+
|
| 190 |
+
is_whisper = "whisper" in config.audio_model_id.lower() or (
|
| 191 |
+
hasattr(self.encoder.config, "model_type")
|
| 192 |
+
and "whisper" in self.encoder.config.model_type.lower()
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if is_whisper:
|
| 196 |
+
self.main_input_name = "input_features"
|
| 197 |
+
else:
|
| 198 |
+
self.main_input_name = "input_values"
|
| 199 |
+
|
| 200 |
+
if feature_extractor is not None:
|
| 201 |
+
self.feature_extractor = feature_extractor
|
| 202 |
+
else:
|
| 203 |
+
self.feature_extractor = self._create_feature_extractor(config.audio_model_id)
|
| 204 |
+
|
| 205 |
+
self.decoder = self._create_decoder(config)
|
| 206 |
+
self.generation_config = self.decoder.generation_config
|
| 207 |
+
|
| 208 |
+
self._init_tokenizer()
|
| 209 |
+
|
| 210 |
+
from types import SimpleNamespace
|
| 211 |
+
|
| 212 |
+
# Auto-detect encoder_dim and llm_dim if not specified
|
| 213 |
+
encoder_dim = config.encoder_dim
|
| 214 |
+
if encoder_dim is None:
|
| 215 |
+
if hasattr(self.encoder.config, "hidden_size"):
|
| 216 |
+
encoder_dim = self.encoder.config.hidden_size
|
| 217 |
+
elif hasattr(self.encoder.config, "d_model"):
|
| 218 |
+
encoder_dim = self.encoder.config.d_model
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
|
| 221 |
+
|
| 222 |
+
llm_dim = config.llm_dim
|
| 223 |
+
if llm_dim is None:
|
| 224 |
+
if hasattr(self.decoder.config, "hidden_size"):
|
| 225 |
+
llm_dim = self.decoder.config.hidden_size
|
| 226 |
+
elif hasattr(self.decoder.config, "d_model"):
|
| 227 |
+
llm_dim = self.decoder.config.d_model
|
| 228 |
+
else:
|
| 229 |
+
raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
|
| 230 |
+
|
| 231 |
+
projector_config = SimpleNamespace(
|
| 232 |
+
encoder_dim=encoder_dim,
|
| 233 |
+
llm_dim=llm_dim,
|
| 234 |
+
projector_pool_stride=getattr(config, "projector_pool_stride", 2),
|
| 235 |
+
projector_hidden_dim=getattr(config, "projector_hidden_dim", None),
|
| 236 |
+
projector_init_std=getattr(config, "projector_init_std", 0.02),
|
| 237 |
+
projector_dropout=getattr(config, "projector_dropout", 0.1),
|
| 238 |
+
)
|
| 239 |
+
self.projector = AudioProjector(projector_config)
|
| 240 |
+
|
| 241 |
+
# Convert projector to the same dtype as encoder/decoder
|
| 242 |
+
target_dtype = getattr(torch, config.model_dtype)
|
| 243 |
+
self.projector = self.projector.to(dtype=target_dtype)
|
| 244 |
+
|
| 245 |
+
self._no_split_modules = self.decoder._no_split_modules
|
| 246 |
+
|
| 247 |
+
@classmethod
|
| 248 |
+
def _create_encoder(cls, config: ASRConfig):
|
| 249 |
+
"""Create and configure the audio encoder.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
config: Model configuration
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
Configured encoder model
|
| 256 |
+
"""
|
| 257 |
+
target_dtype = getattr(torch, config.model_dtype)
|
| 258 |
+
|
| 259 |
+
encoder_kwargs = {
|
| 260 |
+
"attn_implementation": config.attn_implementation,
|
| 261 |
+
"dtype": target_dtype,
|
| 262 |
+
"low_cpu_mem_usage": True,
|
| 263 |
+
}
|
| 264 |
+
if not cls._is_loading_from_pretrained:
|
| 265 |
+
encoder_kwargs["device_map"] = "auto"
|
| 266 |
+
|
| 267 |
+
if "whisper" in config.audio_model_id.lower():
|
| 268 |
+
from transformers import WhisperModel
|
| 269 |
+
|
| 270 |
+
full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
|
| 271 |
+
encoder = full_model.encoder
|
| 272 |
+
del full_model
|
| 273 |
+
else:
|
| 274 |
+
encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
|
| 275 |
+
|
| 276 |
+
is_whisper = "whisper" in config.audio_model_id.lower() or (
|
| 277 |
+
hasattr(encoder.config, "model_type") and "whisper" in encoder.config.model_type.lower()
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# Wrap encoder forward to handle Whisper's input_features vs input_values
|
| 281 |
+
original_forward = encoder.forward
|
| 282 |
+
input_key = "input_features" if is_whisper else "input_values"
|
| 283 |
+
|
| 284 |
+
def safe_encoder_forward(self_encoder, input_values=None, **kwargs):
|
| 285 |
+
# Catch and discard invalid kwargs like input_ids
|
| 286 |
+
kwargs.pop("input_ids", None)
|
| 287 |
+
return original_forward(**{input_key: input_values}, **kwargs)
|
| 288 |
+
|
| 289 |
+
import types
|
| 290 |
+
|
| 291 |
+
encoder.forward = types.MethodType(safe_encoder_forward, encoder)
|
| 292 |
+
|
| 293 |
+
# Freeze all encoder parameters
|
| 294 |
+
encoder.requires_grad_(False)
|
| 295 |
+
|
| 296 |
+
return encoder
|
| 297 |
+
|
| 298 |
+
@classmethod
|
| 299 |
+
def _create_decoder(cls, config: ASRConfig):
|
| 300 |
+
"""Create and configure the language model decoder.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
config: Model configuration
|
| 304 |
+
|
| 305 |
+
Returns:
|
| 306 |
+
Configured decoder model
|
| 307 |
+
"""
|
| 308 |
+
target_dtype = getattr(torch, config.model_dtype)
|
| 309 |
+
|
| 310 |
+
# When loading from pretrained, avoid device_map="auto" to prevent meta tensor issues
|
| 311 |
+
decoder_kwargs = {
|
| 312 |
+
"attn_implementation": config.attn_implementation,
|
| 313 |
+
"dtype": target_dtype,
|
| 314 |
+
"trust_remote_code": True,
|
| 315 |
+
}
|
| 316 |
+
# Don't use device_map="auto" as it can cause meta tensor issues with Trainer
|
| 317 |
+
# The Trainer will handle device placement
|
| 318 |
+
|
| 319 |
+
decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
|
| 320 |
+
|
| 321 |
+
# use_cache is now safe because we pre-expand audio tokens for consistent sequence length
|
| 322 |
+
# Cache can be enabled/disabled via config.use_cache
|
| 323 |
+
decoder.config.use_cache = config.use_cache
|
| 324 |
+
|
| 325 |
+
# Freeze all decoder parameters (only projector is trainable)
|
| 326 |
+
decoder.requires_grad_(False)
|
| 327 |
+
|
| 328 |
+
return decoder
|
| 329 |
+
|
| 330 |
+
def _init_weights(self, module):
|
| 331 |
+
"""Initialize weights for trainable modules.
|
| 332 |
+
|
| 333 |
+
Note: This is a no-op since:
|
| 334 |
+
- AudioProjector self-initializes in its __init__
|
| 335 |
+
- Encoder/decoder are loaded from pretrained weights
|
| 336 |
+
"""
|
| 337 |
+
pass
|
| 338 |
+
|
| 339 |
+
def can_generate(self) -> bool:
|
| 340 |
+
"""Return True to indicate this model supports generation.
|
| 341 |
+
|
| 342 |
+
Required for Transformers 4.50+ where PreTrainedModel no longer
|
| 343 |
+
inherits from GenerationMixin.
|
| 344 |
+
"""
|
| 345 |
+
return True
|
| 346 |
+
|
| 347 |
+
@property
|
| 348 |
+
def _tied_weights_keys(self):
|
| 349 |
+
"""Return list of weight keys that should be tied.
|
| 350 |
+
|
| 351 |
+
In this model, input and output embeddings of the decoder may be tied.
|
| 352 |
+
"""
|
| 353 |
+
if hasattr(self.decoder, "_tied_weights_keys"):
|
| 354 |
+
return [f"decoder.{k}" for k in self.decoder._tied_weights_keys]
|
| 355 |
+
return []
|
| 356 |
+
|
| 357 |
+
def _init_tokenizer(self):
|
| 358 |
+
model_path = (
|
| 359 |
+
self.__class__._pretrained_model_path
|
| 360 |
+
if self._is_loading_from_pretrained
|
| 361 |
+
else self.config.text_model_id
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 365 |
+
|
| 366 |
+
# Set pad_token if not already set to avoid warnings during generation
|
| 367 |
+
# If pad_token is same as eos_token, we need a different token for padding
|
| 368 |
+
if (
|
| 369 |
+
self.tokenizer.pad_token is None
|
| 370 |
+
or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
|
| 371 |
+
) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
|
| 372 |
+
# For SmolLM3, use the dedicated finetune_right_pad_id token
|
| 373 |
+
self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
|
| 374 |
+
|
| 375 |
+
existing_special = self.tokenizer.additional_special_tokens or []
|
| 376 |
+
|
| 377 |
+
# Add single audio token if not present
|
| 378 |
+
if "<audio>" not in existing_special:
|
| 379 |
+
special_tokens = {"additional_special_tokens": existing_special + ["<audio>"]}
|
| 380 |
+
num_added_tokens = self.tokenizer.add_special_tokens(special_tokens)
|
| 381 |
+
if num_added_tokens > 0:
|
| 382 |
+
# Use mean_resizing=False since this is a structural token, not semantic
|
| 383 |
+
self.decoder.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
|
| 384 |
+
|
| 385 |
+
current_embed_size = self.decoder.get_input_embeddings().weight.shape[0]
|
| 386 |
+
expected_size = len(self.tokenizer)
|
| 387 |
+
if current_embed_size != expected_size:
|
| 388 |
+
self.decoder.resize_token_embeddings(expected_size, mean_resizing=False)
|
| 389 |
+
|
| 390 |
+
self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
|
| 391 |
+
|
| 392 |
+
self.tokenizer.padding_side = "right"
|
| 393 |
+
|
| 394 |
+
for cfg in [self.config.text_config, self.decoder.config, self.generation_config]:
|
| 395 |
+
if isinstance(cfg, dict):
|
| 396 |
+
cfg["pad_token_id"] = self.tokenizer.pad_token_id
|
| 397 |
+
cfg["eos_token_id"] = self.tokenizer.eos_token_id
|
| 398 |
+
cfg["bos_token_id"] = self.tokenizer.bos_token_id
|
| 399 |
+
else:
|
| 400 |
+
cfg.pad_token_id = self.tokenizer.pad_token_id
|
| 401 |
+
cfg.eos_token_id = self.tokenizer.eos_token_id
|
| 402 |
+
cfg.bos_token_id = self.tokenizer.bos_token_id
|
| 403 |
+
|
| 404 |
+
def get_processor(self):
|
| 405 |
+
try:
|
| 406 |
+
from .asr_processing import ASRProcessor
|
| 407 |
+
except ImportError:
|
| 408 |
+
from asr_processing import ASRProcessor # type: ignore[no-redef]
|
| 409 |
+
|
| 410 |
+
return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
|
| 411 |
+
|
| 412 |
+
def state_dict(self, *args, **kwargs):
|
| 413 |
+
"""Return only trainable parameters (projector weights).
|
| 414 |
+
|
| 415 |
+
Called by HuggingFace Trainer to save model.safetensors in checkpoints.
|
| 416 |
+
"""
|
| 417 |
+
return self._get_trainable_state_dict()
|
| 418 |
+
|
| 419 |
+
def _get_trainable_state_dict(self):
|
| 420 |
+
"""Get all trainable parameters as a single state dict.
|
| 421 |
+
|
| 422 |
+
This is used by Trainer for checkpointing during training.
|
| 423 |
+
"""
|
| 424 |
+
state = {}
|
| 425 |
+
|
| 426 |
+
# Only projector params are trainable now (encoder and decoder are frozen)
|
| 427 |
+
projector_state = self.projector.state_dict()
|
| 428 |
+
for name, tensor in projector_state.items():
|
| 429 |
+
state[f"projector.{name}"] = tensor
|
| 430 |
+
|
| 431 |
+
return state
|
| 432 |
+
|
| 433 |
+
def get_input_embeddings(self):
|
| 434 |
+
"""Delegate to decoder for proper HF Trainer integration."""
|
| 435 |
+
return self.decoder.get_input_embeddings()
|
| 436 |
+
|
| 437 |
+
def set_input_embeddings(self, value):
|
| 438 |
+
"""Delegate to decoder for proper HF Trainer integration."""
|
| 439 |
+
self.decoder.set_input_embeddings(value)
|
| 440 |
+
|
| 441 |
+
def get_output_embeddings(self):
|
| 442 |
+
"""Delegate to decoder for proper HF Trainer integration."""
|
| 443 |
+
return self.decoder.get_output_embeddings()
|
| 444 |
+
|
| 445 |
+
def set_output_embeddings(self, value):
|
| 446 |
+
"""Delegate to decoder for proper HF Trainer integration."""
|
| 447 |
+
self.decoder.set_output_embeddings(value)
|
| 448 |
+
|
| 449 |
+
def _encode_audio(
|
| 450 |
+
self,
|
| 451 |
+
input_values: torch.Tensor,
|
| 452 |
+
audio_attention_mask: Optional[torch.Tensor] = None,
|
| 453 |
+
) -> torch.Tensor:
|
| 454 |
+
# Ensure input is on encoder's device and has the right dtype
|
| 455 |
+
encoder_device = next(self.encoder.parameters()).device
|
| 456 |
+
encoder_dtype = next(self.encoder.parameters()).dtype
|
| 457 |
+
# Clone to prevent user tensor reuse contamination
|
| 458 |
+
input_values = input_values.clone().to(device=encoder_device, dtype=encoder_dtype)
|
| 459 |
+
|
| 460 |
+
# Only pass explicit valid arguments to encoder
|
| 461 |
+
# Never use **kwargs to prevent torch.compile from injecting decoder args like input_ids
|
| 462 |
+
# Always use no_grad since encoder is frozen
|
| 463 |
+
with torch.no_grad():
|
| 464 |
+
audio_features = self.encoder(
|
| 465 |
+
input_values=input_values,
|
| 466 |
+
attention_mask=audio_attention_mask,
|
| 467 |
+
).last_hidden_state
|
| 468 |
+
|
| 469 |
+
# Project audio features and ensure dtype matches decoder
|
| 470 |
+
audio_embeds = self.projector(audio_features)
|
| 471 |
+
|
| 472 |
+
# Convert to decoder's dtype if needed (e.g., bfloat16)
|
| 473 |
+
decoder_dtype = next(self.decoder.parameters()).dtype
|
| 474 |
+
if audio_embeds.dtype != decoder_dtype:
|
| 475 |
+
audio_embeds = audio_embeds.to(dtype=decoder_dtype)
|
| 476 |
+
|
| 477 |
+
return audio_embeds
|
| 478 |
+
|
| 479 |
+
def _get_audio_expansion_details(self, input_ids: torch.Tensor, num_audio_tokens: int) -> dict:
|
| 480 |
+
"""Calculate the positions and masks needed to expand audio tokens.
|
| 481 |
+
|
| 482 |
+
This helper consolidates the common cumsum logic used by both
|
| 483 |
+
_expand_audio_tokens and _expand_for_audio_tokens.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
input_ids: Token IDs with single <audio> token per sample
|
| 487 |
+
num_audio_tokens: Number of tokens each audio token expands to
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
Dictionary containing:
|
| 491 |
+
- new_seq_len: The total sequence length after expansion
|
| 492 |
+
- new_start_positions: [batch, old_seq_len] tensor mapping old indices to new
|
| 493 |
+
- audio_mask: [batch, old_seq_len] boolean mask for audio token positions
|
| 494 |
+
"""
|
| 495 |
+
batch_size, seq_len = input_ids.shape
|
| 496 |
+
device = input_ids.device
|
| 497 |
+
|
| 498 |
+
# Find audio token positions
|
| 499 |
+
audio_mask = input_ids == self.audio_token_id
|
| 500 |
+
|
| 501 |
+
# Validate: each sample must have exactly one audio token
|
| 502 |
+
audio_counts = audio_mask.sum(dim=1)
|
| 503 |
+
if not (audio_counts == 1).all():
|
| 504 |
+
missing = (audio_counts == 0).any()
|
| 505 |
+
multiple = (audio_counts > 1).any()
|
| 506 |
+
if missing:
|
| 507 |
+
raise ValueError("Some samples are missing audio token")
|
| 508 |
+
if multiple:
|
| 509 |
+
raise ValueError("Some samples have multiple audio tokens")
|
| 510 |
+
|
| 511 |
+
# Create placeholder tensor: 1 for normal tokens, num_audio_tokens for audio token
|
| 512 |
+
token_counts = torch.where(audio_mask, num_audio_tokens, 1)
|
| 513 |
+
|
| 514 |
+
# Cumsum - 1 gives us the ENDING position of each token's expansion
|
| 515 |
+
cumsum_counts = torch.cumsum(token_counts, dim=1)
|
| 516 |
+
|
| 517 |
+
# The starting position of token i is cumsum[i-1]
|
| 518 |
+
new_start_positions = torch.cat(
|
| 519 |
+
[
|
| 520 |
+
torch.zeros(batch_size, 1, dtype=torch.long, device=device),
|
| 521 |
+
cumsum_counts[:, :-1],
|
| 522 |
+
],
|
| 523 |
+
dim=1,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
# Calculate new sequence length
|
| 527 |
+
new_seq_len = seq_len - 1 + num_audio_tokens
|
| 528 |
+
|
| 529 |
+
return {
|
| 530 |
+
"new_seq_len": new_seq_len,
|
| 531 |
+
"new_start_positions": new_start_positions,
|
| 532 |
+
"audio_mask": audio_mask,
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
def _expand_tensor_for_audio(
|
| 536 |
+
self,
|
| 537 |
+
input_ids: torch.Tensor,
|
| 538 |
+
tensor_to_expand: Optional[torch.Tensor],
|
| 539 |
+
num_audio_tokens: int,
|
| 540 |
+
fill_value: Optional[Union[int, float]] = None,
|
| 541 |
+
audio_fill_value: Optional[Union[int, float]] = None,
|
| 542 |
+
) -> torch.Tensor:
|
| 543 |
+
"""Generic method to expand any tensor to match audio token expansion.
|
| 544 |
+
|
| 545 |
+
Args:
|
| 546 |
+
input_ids: Token IDs with single <audio> token per sample
|
| 547 |
+
tensor_to_expand: Tensor to expand (input_ids, attention_mask, labels) or None
|
| 548 |
+
num_audio_tokens: Number of tokens each audio token expands to
|
| 549 |
+
fill_value: Default fill value for new tensor
|
| 550 |
+
audio_fill_value: Value to use for audio token positions (if different from fill_value)
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
Expanded tensor matching the expanded sequence length
|
| 554 |
+
"""
|
| 555 |
+
batch_size, seq_len = input_ids.shape
|
| 556 |
+
device = input_ids.device
|
| 557 |
+
|
| 558 |
+
details = self._get_audio_expansion_details(input_ids, num_audio_tokens)
|
| 559 |
+
new_seq_len = details["new_seq_len"]
|
| 560 |
+
new_start_positions = details["new_start_positions"]
|
| 561 |
+
audio_mask = details["audio_mask"]
|
| 562 |
+
|
| 563 |
+
# Determine the tensor we're actually expanding
|
| 564 |
+
if tensor_to_expand is None:
|
| 565 |
+
# Expanding input_ids themselves
|
| 566 |
+
tensor_to_expand = input_ids
|
| 567 |
+
fill_value = fill_value or self.tokenizer.pad_token_id
|
| 568 |
+
audio_fill_value = audio_fill_value or self.audio_token_id
|
| 569 |
+
else:
|
| 570 |
+
# Expanding other tensors (attention_mask, labels)
|
| 571 |
+
if fill_value is None:
|
| 572 |
+
raise ValueError("fill_value must be provided when expanding non-input_ids tensors")
|
| 573 |
+
if audio_fill_value is None:
|
| 574 |
+
audio_fill_value = fill_value
|
| 575 |
+
|
| 576 |
+
# Create output tensor
|
| 577 |
+
expanded = torch.full(
|
| 578 |
+
(batch_size, new_seq_len),
|
| 579 |
+
fill_value,
|
| 580 |
+
dtype=tensor_to_expand.dtype,
|
| 581 |
+
device=device,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# Scatter non-audio positions to their new positions
|
| 585 |
+
batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, seq_len)
|
| 586 |
+
non_audio_mask = ~audio_mask
|
| 587 |
+
expanded[batch_indices[non_audio_mask], new_start_positions[non_audio_mask]] = (
|
| 588 |
+
tensor_to_expand[non_audio_mask]
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# Fill audio positions if different from default fill
|
| 592 |
+
if audio_fill_value != fill_value:
|
| 593 |
+
audio_positions = audio_mask.int().argmax(dim=1)
|
| 594 |
+
audio_new_start = new_start_positions[
|
| 595 |
+
torch.arange(batch_size, device=device), audio_positions
|
| 596 |
+
]
|
| 597 |
+
audio_token_indices = torch.arange(num_audio_tokens, device=device).unsqueeze(0)
|
| 598 |
+
audio_positions_expanded = audio_new_start.unsqueeze(1) + audio_token_indices
|
| 599 |
+
batch_idx_expanded = (
|
| 600 |
+
torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_audio_tokens)
|
| 601 |
+
)
|
| 602 |
+
expanded[batch_idx_expanded, audio_positions_expanded] = audio_fill_value
|
| 603 |
+
|
| 604 |
+
return expanded
|
| 605 |
+
|
| 606 |
+
def _expand_audio_tokens(self, input_ids: torch.Tensor, num_audio_tokens: int) -> torch.Tensor:
|
| 607 |
+
"""Convenience method for expanding input_ids."""
|
| 608 |
+
return self._expand_tensor_for_audio(input_ids, None, num_audio_tokens)
|
| 609 |
+
|
| 610 |
+
def _expand_for_audio_tokens(
|
| 611 |
+
self,
|
| 612 |
+
input_ids: torch.Tensor,
|
| 613 |
+
tensor_to_expand: torch.Tensor,
|
| 614 |
+
num_audio_tokens: int,
|
| 615 |
+
fill_value: Union[int, float],
|
| 616 |
+
) -> torch.Tensor:
|
| 617 |
+
"""Convenience method for expanding attention_mask or labels."""
|
| 618 |
+
return self._expand_tensor_for_audio(
|
| 619 |
+
input_ids, tensor_to_expand, num_audio_tokens, fill_value
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
def _prepare_audio_inputs_embeds(
|
| 623 |
+
self, expanded_input_ids: torch.Tensor, audio_embeds: torch.Tensor
|
| 624 |
+
) -> torch.Tensor:
|
| 625 |
+
"""Prepare inputs_embeds by replacing audio token embeddings with actual audio embeddings.
|
| 626 |
+
|
| 627 |
+
Args:
|
| 628 |
+
expanded_input_ids: Input IDs with expanded audio tokens
|
| 629 |
+
audio_embeds: Audio embeddings to inject
|
| 630 |
+
|
| 631 |
+
Returns:
|
| 632 |
+
inputs_embeds with audio embeddings injected
|
| 633 |
+
"""
|
| 634 |
+
# Get text embeddings for expanded input_ids
|
| 635 |
+
inputs_embeds = self.decoder.get_input_embeddings()(expanded_input_ids)
|
| 636 |
+
|
| 637 |
+
# Simple masked scatter: replace audio token embeddings with actual audio embeddings
|
| 638 |
+
special_audio_mask = (expanded_input_ids == self.audio_token_id).unsqueeze(-1)
|
| 639 |
+
special_audio_mask = special_audio_mask.expand_as(inputs_embeds)
|
| 640 |
+
audio_embeds_flat = audio_embeds.reshape(-1, audio_embeds.shape[-1])
|
| 641 |
+
return inputs_embeds.masked_scatter(special_audio_mask, audio_embeds_flat)
|
| 642 |
+
|
| 643 |
+
def forward(
|
| 644 |
+
self,
|
| 645 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 646 |
+
input_values: Optional[torch.Tensor] = None,
|
| 647 |
+
input_features: Optional[torch.Tensor] = None, # For Whisper
|
| 648 |
+
labels: Optional[torch.Tensor] = None,
|
| 649 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 650 |
+
num_items_in_batch: Optional[
|
| 651 |
+
int
|
| 652 |
+
] = None, # HF Trainer provides this for gradient accumulation
|
| 653 |
+
**kwargs,
|
| 654 |
+
):
|
| 655 |
+
audio_inputs = input_values if input_values is not None else input_features
|
| 656 |
+
if audio_inputs is not None:
|
| 657 |
+
# During inference, the pipeline may call forward with only audio inputs
|
| 658 |
+
# In that case, we should raise an error directing to use generate() instead
|
| 659 |
+
if input_ids is None:
|
| 660 |
+
raise ValueError(
|
| 661 |
+
"forward() requires both audio inputs and input_ids (for training). "
|
| 662 |
+
"For inference, use the generate() method instead, or use the pipeline "
|
| 663 |
+
"which will automatically call generate()."
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
# Extract audio-specific kwargs, don't pass input_ids to encoder
|
| 667 |
+
audio_attention_mask = kwargs.pop("audio_attention_mask", None)
|
| 668 |
+
|
| 669 |
+
# Remove any decoder-specific kwargs that shouldn't go to the encoder
|
| 670 |
+
kwargs.pop("past_key_values", None)
|
| 671 |
+
use_cache = kwargs.pop("use_cache", None)
|
| 672 |
+
|
| 673 |
+
# Encode audio to get embeddings
|
| 674 |
+
audio_embeds = self._encode_audio(
|
| 675 |
+
input_values=audio_inputs, # Will be mapped to input_features for Whisper by safe_encoder_forward
|
| 676 |
+
audio_attention_mask=audio_attention_mask,
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# Validate audio token ID before using it
|
| 680 |
+
if self.audio_token_id is None:
|
| 681 |
+
raise ValueError(f"Audio token not properly initialized: {self.audio_token_id}")
|
| 682 |
+
|
| 683 |
+
vocab_size = self.decoder.get_input_embeddings().weight.shape[0]
|
| 684 |
+
if self.audio_token_id >= vocab_size:
|
| 685 |
+
raise ValueError(
|
| 686 |
+
f"Audio token ID out of range. ID: {self.audio_token_id}, Vocab size: {vocab_size}"
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
# Check that audio token exists
|
| 690 |
+
if not (input_ids == self.audio_token_id).any():
|
| 691 |
+
raise ValueError("Audio token <audio> must be present in input")
|
| 692 |
+
|
| 693 |
+
# Expand audio tokens to match audio embedding length
|
| 694 |
+
num_audio_tokens = audio_embeds.shape[1]
|
| 695 |
+
expanded_input_ids = self._expand_audio_tokens(input_ids, num_audio_tokens)
|
| 696 |
+
|
| 697 |
+
# Prepare inputs_embeds with audio embeddings injected
|
| 698 |
+
inputs_embeds = self._prepare_audio_inputs_embeds(expanded_input_ids, audio_embeds)
|
| 699 |
+
|
| 700 |
+
# Expand attention mask to match new sequence length (vectorized)
|
| 701 |
+
if attention_mask is not None:
|
| 702 |
+
full_attention_mask = self._expand_for_audio_tokens(
|
| 703 |
+
input_ids, attention_mask, num_audio_tokens, fill_value=1
|
| 704 |
+
)
|
| 705 |
+
else:
|
| 706 |
+
full_attention_mask = None
|
| 707 |
+
|
| 708 |
+
# Expand labels to match new sequence length (vectorized, mark audio tokens as -100)
|
| 709 |
+
if labels is not None:
|
| 710 |
+
labels = self._expand_for_audio_tokens(
|
| 711 |
+
input_ids, labels, num_audio_tokens, fill_value=-100
|
| 712 |
+
)
|
| 713 |
+
else:
|
| 714 |
+
inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
|
| 715 |
+
full_attention_mask = attention_mask
|
| 716 |
+
use_cache = kwargs.pop("use_cache", None)
|
| 717 |
+
|
| 718 |
+
# Standard forward pass with built-in loss computation
|
| 719 |
+
return self.decoder(
|
| 720 |
+
inputs_embeds=inputs_embeds,
|
| 721 |
+
            attention_mask=full_attention_mask,
            labels=labels,
            use_cache=use_cache if use_cache is not None else False,
            **kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,  # For Whisper
        system_prompt: Optional[str] = None,
        user_prompt: Optional[str] = None,
        task: Optional[str] = None,
        **generate_kwargs,
    ) -> Union[
        torch.Tensor,
        GenerateDecoderOnlyOutput,
        GenerateEncoderDecoderOutput,
        GenerateBeamDecoderOnlyOutput,
        GenerateBeamEncoderDecoderOutput,
    ]:
        audio_inputs = input_values if input_values is not None else input_features
        if audio_inputs is None:
            raise ValueError("input_values or input_features must be provided for generation")

        audio_embeds = self._encode_audio(audio_inputs)
        batch_size = audio_embeds.shape[0]
        device = audio_embeds.device

        if system_prompt is None:
            system_prompt = self.system_prompt

        if user_prompt is None:
            user_prompt = (
                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
                or "Transcribe: <audio>"
            )

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": user_prompt,
            }
        )

        prompt_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=False,
        ).to(device)

        if len(prompt_ids.shape) == 1:
            prompt_ids = prompt_ids.unsqueeze(0)

        if prompt_ids.shape[0] == 1 and batch_size > 1:
            prompt_ids = prompt_ids.expand(batch_size, -1)

        if not (prompt_ids == self.audio_token_id).any():
            raise ValueError("Audio token <audio> not found in prompt")

        # Expand audio tokens to match the audio embedding length
        num_audio_tokens = audio_embeds.shape[1]
        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)

        # Prepare inputs_embeds with the audio embeddings injected
        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)

        # Create an attention mask for the expanded sequence
        total_seq_len = inputs_embeds.shape[1]
        attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)

        # Apply generation defaults from the config
        config_params = [
            "max_new_tokens",
            "min_new_tokens",
            "num_beams",
            "do_sample",
            "temperature",
            "top_k",
            "top_p",
            "repetition_penalty",
            "length_penalty",
            "no_repeat_ngram_size",
            "early_stopping",
        ]
        for param in config_params:
            if hasattr(self.config, param) and getattr(self.config, param) is not None:
                generate_kwargs.setdefault(param, getattr(self.config, param))

        # Add special token defaults
        generate_kwargs.setdefault("use_cache", True)
        generate_kwargs.setdefault(
            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        )
        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        # Track the prompt length to extract only newly generated tokens
        prompt_length = expanded_prompt_ids.shape[1]

        # Generate the full sequence
        generated_ids = self.decoder.generate(
            input_ids=expanded_prompt_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

        # Return only the newly generated tokens (exclude the prompt)
        return generated_ids[:, prompt_length:]

    def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
        import shutil
        from pathlib import Path as PathlibPath

        save_dir = PathlibPath(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        actual_vocab_size = self.decoder.config.vocab_size
        self.config.vocab_size = actual_vocab_size
        self.config.text_config.vocab_size = actual_vocab_size

        if hasattr(self.encoder.config, "num_mel_bins"):
            self.config.audio_config.num_mel_bins = self.encoder.config.num_mel_bins

        # Use the parent class to save the config and model.safetensors
        super().save_pretrained(save_dir, **kwargs)

        self.tokenizer.save_pretrained(save_dir)

        # For Whisper models, ensure feature_size matches num_mel_bins from the encoder config
        if hasattr(self.encoder.config, "num_mel_bins"):
            num_mel_bins = self.encoder.config.num_mel_bins
            self.feature_extractor.feature_size = num_mel_bins
            self.feature_extractor.num_mel_bins = num_mel_bins
            if hasattr(self.feature_extractor, "n_mels"):
                self.feature_extractor.n_mels = num_mel_bins
            self.feature_extractor.nb_max_frames = 3000  # Whisper's max frames

        self.get_processor().save_pretrained(save_dir)

        src_dir = PathlibPath(__file__).parent
        for asr_file in src_dir.glob("asr_*.py"):
            shutil.copy(asr_file, save_dir / asr_file.name)


AutoConfig.register("asr_model", ASRConfig)
AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py
ADDED
@@ -0,0 +1,293 @@
from typing import Any, Dict

import torch
import transformers
from truecase import get_true_case

try:
    from .asr_modeling import ASRModel
except ImportError:
    from asr_modeling import ASRModel  # type: ignore[no-redef]


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        feature_extractor = kwargs.pop("feature_extractor", model.feature_extractor)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )

        # Initialize the text normalizer (same as train.py)
        if hasattr(tokenizer, "normalize"):
            self.text_normalizer = tokenizer
        else:
            # Fall back to the whisper-tiny tokenizer for its normalize() method only
            from transformers import WhisperTokenizer

            self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    def __call__(self, inputs, **kwargs):
        generate_kwargs = {}
        for key in [
            "max_new_tokens",
            "num_beams",
            "do_sample",
            "length_penalty",
            "repetition_penalty",
            "no_repeat_ngram_size",
            "early_stopping",
            "num_beam_groups",
            "diversity_penalty",
            "top_k",
            "temperature",
            "top_p",
            "user_prompt",
            "task",
            "text_input",
        ]:
            if key in kwargs:
                generate_kwargs[key] = kwargs.pop(key)

        # Handle text-only mode
        task = generate_kwargs.get("task")
        if task == "text" or generate_kwargs.get("text_input"):
            return self._process_text_only(generate_kwargs)

        if isinstance(inputs, list):
            results = []
            for single_input in inputs:
                result = self.__call__(single_input, **kwargs, **generate_kwargs)
                results.append(result)
            return results

        model_inputs = self.preprocess(inputs, **kwargs)

        from collections.abc import Iterator

        if isinstance(model_inputs, Iterator):
            # Convert the iterator to a list to process chunks
            chunks = list(model_inputs)

            all_outputs = []
            for _chunk_num, chunk in enumerate(chunks, start=1):
                chunk_output = self._forward(chunk, **generate_kwargs)
                # Move tensors to CPU before adding them to the outputs
                for key, value in chunk_output.items():
                    if torch.is_tensor(value):
                        chunk_output[key] = value.cpu()
                all_outputs.append(chunk_output)

            # Merge the chunks and decode ourselves to ensure skip_special_tokens=True
            all_tokens: list[int] = []
            for output in all_outputs:
                tokens = output.get("tokens")
                if tokens is None:
                    tokens = output.get("generated_ids")
                if tokens is not None:
                    if torch.is_tensor(tokens):
                        tokens = tokens.cpu()
                        if len(tokens.shape) > 1:
                            tokens = tokens[0]
                    all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)

            # Decode the merged tokens with skip_special_tokens
            text = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
            text = text.strip()

            # Apply Whisper normalization (matches training)
            text = self.text_normalizer.normalize(text)

            # Apply truecasing for proper capitalization
            text = get_true_case(text)

            return {"text": text}

        model_outputs = self._forward(model_inputs, **generate_kwargs)
        return self.postprocess(model_outputs)

    def preprocess(self, inputs, **preprocess_params):
        if isinstance(inputs, list):
            raise ValueError("Lists should not reach preprocess - bug in __call__")

        # Default chunking: 30-second chunks with 5-second overlap
        preprocess_params.setdefault("chunk_length_s", 30)
        preprocess_params.setdefault("stride_length_s", (5, 5))

        # Handle the different formats produced by datasets
        if isinstance(inputs, dict):
            if "bytes" in inputs:
                # Decode bytes to an audio array using torchcodec
                import tempfile

                from torchcodec.decoders import AudioDecoder

                wav_bytes = inputs["bytes"]
                # Write to a temp file for torchcodec to read
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    f.write(wav_bytes)
                    temp_path = f.name
                try:
                    decoder = AudioDecoder(temp_path)
                    # Get all audio samples
                    audio_result = decoder.get_all_samples()
                    audio_tensor = audio_result.data
                    sample_rate = audio_result.sample_rate
                    inputs = {"raw": audio_tensor.squeeze().numpy(), "sampling_rate": sample_rate}
                finally:
                    from pathlib import Path

                    Path(temp_path).unlink()
            elif "array" in inputs:
                # Convert the "array" key to the "raw" key
                inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
            # If it already has "raw" and "sampling_rate", it is good to go
        elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
            # Audio object with attributes (not a dict)
            inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
        elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
            inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
        elif torch.is_tensor(inputs):
            inputs = {
                "raw": inputs.cpu().numpy(),
                "sampling_rate": self.model.config.audio_sample_rate,
            }

        return super().preprocess(inputs, **preprocess_params)

    def _forward(self, model_inputs, **generate_kwargs):
        # Extract the task and set sampling parameters
        task = generate_kwargs.pop("task", None)

        # Task-specific sampling parameters
        task_params: Dict[str, Dict[str, Any]] = {
            "transcribe": {"do_sample": False},
            "emotion": {"do_sample": True, "temperature": 0.7},
            "describe": {"do_sample": True, "temperature": 0.7},
            "continue": {"do_sample": True, "temperature": 1.0},
        }

        if task in task_params:
            for key, value in task_params[task].items():
                generate_kwargs.setdefault(key, value)

        # Extract audio inputs from the various formats
        is_last = True
        audio_inputs = None
        is_whisper = False  # Track whether this is Whisper input

        # Normalize model_inputs to dict format
        if isinstance(model_inputs, torch.Tensor):
            audio_inputs = model_inputs
        elif isinstance(model_inputs, (list, tuple)) and model_inputs:
            model_inputs = (
                model_inputs[0]
                if isinstance(model_inputs[0], dict)
                else {"input_values": model_inputs[0]}
            )

        if isinstance(model_inputs, dict):
            # Pop metadata fields
            is_last = model_inputs.pop("is_last", True)
            model_inputs.pop("stride", None)
            # Get the audio input (Whisper uses input_features, others use input_values)
            if "input_features" in model_inputs:
                audio_inputs = model_inputs["input_features"]
                is_whisper = True
            else:
                audio_inputs = model_inputs.get("input_values")

        if audio_inputs is None:
            raise ValueError(
                f"Could not extract input_values or input_features from {type(model_inputs)}"
            )

        if isinstance(audio_inputs, torch.Tensor):
            audio_inputs = audio_inputs.to(self.model.device)
        else:
            raise ValueError(f"audio inputs must be a tensor, got {type(audio_inputs)}")

        im_end_id = self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
        generate_kwargs.setdefault("eos_token_id", im_end_id)
        generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)

        # Pass the appropriate input type to generate
        if is_whisper:
            # Whisper model - use input_features
            generated_ids = self.model.generate(
                input_features=audio_inputs,
                system_prompt=self.model.config.system_prompt,
                task=task,
                **generate_kwargs,
            )
        else:
            # Wav2Vec2/HuBERT model - use input_values
            generated_ids = self.model.generate(
                input_values=audio_inputs,
                system_prompt=self.model.config.system_prompt,
                task=task,
                **generate_kwargs,
            )

        return {"tokens": generated_ids, "is_last": is_last}

    def _process_text_only(self, generate_kwargs):
        """Process text-only input without audio encoding."""
        text_input = generate_kwargs.pop("text_input", None)
        if text_input is None:
            raise ValueError("text_input is required for text task")

        # Remove task from generate_kwargs to avoid a duplicate argument
        generate_kwargs.pop("task", None)

        # Generate text using the model
        generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)

        # Decode the generated text
        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        return {"text": generated_text}

    def postprocess(
        self, model_outputs: Dict[str, Any], return_timestamps=None, return_language=None
    ):
        # Handle chunked outputs from an iterator
        if isinstance(model_outputs, list):
            # Move all tensors to CPU before calling the parent postprocess
            for output_dict in model_outputs:
                for key, value in output_dict.items():
                    if torch.is_tensor(value):
                        output_dict[key] = value.cpu()
            return super().postprocess(model_outputs)

        if "is_last" in model_outputs:
            model_outputs.pop("is_last")

        tokens = model_outputs.get("tokens")
        if tokens is None:
            tokens = model_outputs.get("generated_ids")

        if tokens is None:
            raise ValueError(
                f"Expected 'tokens' or 'generated_ids' in model_outputs, got: {model_outputs.keys()}"
            )

        # Move to CPU if on MPS or another device
        if torch.is_tensor(tokens) and tokens.device.type != "cpu":
            tokens = tokens.cpu()

        if len(tokens.shape) > 1:
            tokens = tokens[0]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        text = text.strip()

        # Apply Whisper normalization (matches training)
        text = self.text_normalizer.normalize(text)

        # Apply truecasing for proper capitalization
        text = get_true_case(text)

        return {"text": text}
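
A usage sketch for the pipeline (the `model` object is assumed to be an already-loaded `ASRModel`); it accepts raw arrays, dataset-style dicts, or encoded bytes, and defaults to 30-second chunks with 5-second overlap:

```python
import numpy as np

# The pipeline pulls the tokenizer and feature extractor from the model itself
pipe = ASRPipeline(model=model)

# Dataset-style dict input; a real clip would replace the zeros
clip = {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}
result = pipe(clip, task="transcribe", max_new_tokens=128)
print(result["text"])
```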
asr_processing.py
ADDED
@@ -0,0 +1,78 @@
import transformers
from transformers import AutoTokenizer, ProcessorMixin

# Handle both package and standalone imports
try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Generic processor that can handle both Wav2Vec2 and Whisper feature extractors."""

    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        from transformers import AutoFeatureExtractor

        # Load the feature extractor and tokenizer from the saved model directory
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True, **kwargs
        )

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to avoid attribute errors from the base class."""
        import json
        from pathlib import Path

        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save the feature extractor (this creates preprocessor_config.json
        # with all feature extractor settings)
        if self.feature_extractor is not None:
            self.feature_extractor.save_pretrained(save_directory)

        # Save the tokenizer
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_directory)

        # Load the existing preprocessor_config.json and add processor-specific metadata
        config_path = save_path / "preprocessor_config.json"
        if config_path.exists():
            with config_path.open() as f:
                processor_config = json.load(f)
        else:
            processor_config = {}

        # Add/update processor metadata while preserving the feature extractor settings
        feature_extractor_type = self.feature_extractor.__class__.__name__
        processor_config.update(
            {
                "processor_class": self.__class__.__name__,
                "feature_extractor_class": self.feature_extractor_class,
                "tokenizer_class": self.tokenizer_class,
                "feature_extractor_type": feature_extractor_type,  # Dynamic based on the actual type
                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
            }
        )

        # Save the merged config
        with config_path.open("w") as f:
            json.dump(processor_config, f, indent=2)


ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
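
The two registrations at the bottom make the processor discoverable through `AutoProcessor`. A roundtrip sketch, assuming a `processor` instance as above and that the `asr_*.py` modules sit next to the saved config (which `ASRModel.save_pretrained` arranges by copying them):

```python
from transformers import AutoProcessor

processor.save_pretrained("./asr-checkpoint")  # writes preprocessor_config.json with the auto_map
reloaded = AutoProcessor.from_pretrained("./asr-checkpoint", trust_remote_code=True)
```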
chat_template.jinja
ADDED
@@ -0,0 +1,94 @@
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
    {%- set enable_thinking = true -%}
{%- endif -%}
{# Guard against a missing system message so the membership checks below never hit an undefined variable #}
{%- set system_message = "" -%}
{%- set custom_instructions = "" -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
    {%- set reasoning_mode = "/think" -%}
{%- else -%}
    {%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{%- if messages[0].role == "system" -%}
    {%- set system_message = messages[0].content -%}
    {%- if "/no_think" in system_message -%}
        {%- set reasoning_mode = "/no_think" -%}
    {%- elif "/think" in system_message -%}
        {%- set reasoning_mode = "/think" -%}
    {%- endif -%}
    {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
    {{- custom_instructions.replace("/system_override", "").rstrip() -}}
    {{- "<|im_end|>\n" -}}
{%- else -%}
    {{- "## Metadata\n\n" -}}
    {{- "Knowledge Cutoff Date: June 2025\n" -}}
    {%- set today = strftime_now("%d %B %Y") -%}
    {{- "Today Date: " ~ today ~ "\n" -}}
    {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

    {{- "## Custom Instructions\n\n" -}}
    {%- if custom_instructions -%}
        {{- custom_instructions + "\n\n" -}}
    {%- elif reasoning_mode == "/think" -%}
        {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
    {%- else -%}
        {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
    {%- endif -%}

    {%- if xml_tools or python_tools or tools -%}
        {{- "### Tools\n\n" -}}
        {%- if xml_tools or tools -%}
            {%- if tools -%}
                {%- set xml_tools = tools -%}
            {%- endif -%}
            {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
            {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
                {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
            {%- endfor -%}
            {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
            {{- xml_tool_string -}}
        {%- endif -%}
        {%- if python_tools -%}
            {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continue reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
            {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
                {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
            {%- endfor -%}
            {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
            {{- python_tool_string -}}
        {%- endif -%}
        {{- "\n\n" -}}
    {%- endif -%}
    {# Close the system message whether or not tools were supplied #}
    {{- "<|im_end|>\n" -}}
{%- endif -%}

{# ───── main loop ───── #}
{%- for message in messages -%}
    {%- set content = message.content if message.content is string else "" -%}
    {%- if message.role == "user" -%}
        {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
    {%- elif message.role == "assistant" -%}
        {% generation %}
        {%- if reasoning_mode == "/think" -%}
            {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
        {%- else -%}
            {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
        {%- endif -%}
        {% endgeneration %}
    {%- elif message.role == "tool" -%}
        {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
    {%- endif -%}
{%- endfor -%}

{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
    {%- if reasoning_mode == "/think" -%}
        {{ "<|im_start|>assistant\n" }}
    {%- else -%}
        {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
    {%- endif -%}
{%- endif -%}
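
A sketch of what this template renders for the prompt built in `asr_modeling.py` (`enable_thinking=False`, so the assistant turn opens with an empty think block); the system text here is a placeholder:

```python
prompt = processor.tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "You are a transcription assistant."},
        {"role": "user", "content": "Transcribe: <audio>"},
    ],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
# ...<|im_start|>user\nTranscribe: <audio><|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n
print(prompt)
```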
preprocessor_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "chunk_length": 30,
  "dither": 0.0,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 128,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "num_mel_bins": 128,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "ASRProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "feature_extractor_class": "AutoFeatureExtractor",
  "tokenizer_class": "AutoTokenizer",
  "auto_map": {
    "AutoProcessor": "asr_processing.ASRProcessor"
  }
}
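
These are `WhisperFeatureExtractor` settings (128 mel bins, as in Whisper large-v3; 30-second windows at 16 kHz). A quick sanity-check sketch, with the checkpoint path again a placeholder:

```python
from transformers import AutoFeatureExtractor

fe = AutoFeatureExtractor.from_pretrained("./asr-checkpoint")
assert fe.feature_size == 128
assert fe.sampling_rate == 16000
assert fe.n_samples == 30 * fe.sampling_rate  # 480000 samples per chunk
```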
special_tokens_map.json
ADDED
@@ -0,0 +1,19 @@
{
  "additional_special_tokens": [
    {
      "content": "<audio>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false
    }
  ],
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<|finetune_right_pad_id|>"
}
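
These entries back the token IDs the modeling and pipeline code look up at runtime; a lookup sketch with the `processor` from above:

```python
audio_token_id = processor.tokenizer.convert_tokens_to_ids("<audio>")  # audio placeholder
im_end_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")    # eos for generation
pad_id = processor.tokenizer.pad_token_id                              # "<|finetune_right_pad_id|>"
```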
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
size 17209003
tokenizer_config.json
ADDED
Binary file (50.6 kB)