Training in progress - step 1000
Browse files- asr_config.py +14 -0
- asr_modeling.py +52 -1
- asr_processing.py +1 -0
- chat_template.jinja +1 -1
asr_config.py
CHANGED
|
@@ -49,6 +49,13 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 49 |
mask_feature_prob: float = 0.0, # Probability of masking frequency bins (disabled by default)
|
| 50 |
mask_feature_length: int = 10, # Max length of frequency mask
|
| 51 |
mask_feature_min_masks: int = 0, # Min number of frequency masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
max_new_tokens: Optional[int] = None,
|
| 53 |
min_new_tokens: Optional[int] = None,
|
| 54 |
repetition_penalty: Optional[float] = None,
|
|
@@ -109,6 +116,13 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 109 |
self.mask_feature_prob = mask_feature_prob
|
| 110 |
self.mask_feature_length = mask_feature_length
|
| 111 |
self.mask_feature_min_masks = mask_feature_min_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Generation parameters (use explicit value if provided, else use default)
|
| 114 |
self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
|
|
|
|
| 49 |
mask_feature_prob: float = 0.0, # Probability of masking frequency bins (disabled by default)
|
| 50 |
mask_feature_length: int = 10, # Max length of frequency mask
|
| 51 |
mask_feature_min_masks: int = 0, # Min number of frequency masks
|
| 52 |
+
# LoRA configuration (for Stage 2 fine-tuning)
|
| 53 |
+
use_lora: bool = False,
|
| 54 |
+
lora_rank: int = 8, # SALMONN default
|
| 55 |
+
lora_alpha: int = 32, # SALMONN default (scaling factor 4.0)
|
| 56 |
+
lora_dropout: float = 0.0,
|
| 57 |
+
lora_target_modules: Optional[list] = None, # Default: ["q_proj", "v_proj"]
|
| 58 |
+
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
| 59 |
max_new_tokens: Optional[int] = None,
|
| 60 |
min_new_tokens: Optional[int] = None,
|
| 61 |
repetition_penalty: Optional[float] = None,
|
|
|
|
| 116 |
self.mask_feature_prob = mask_feature_prob
|
| 117 |
self.mask_feature_length = mask_feature_length
|
| 118 |
self.mask_feature_min_masks = mask_feature_min_masks
|
| 119 |
+
# LoRA configuration
|
| 120 |
+
self.use_lora = use_lora
|
| 121 |
+
self.lora_rank = lora_rank
|
| 122 |
+
self.lora_alpha = lora_alpha
|
| 123 |
+
self.lora_dropout = lora_dropout
|
| 124 |
+
self.lora_target_modules = lora_target_modules or ["q_proj", "v_proj"]
|
| 125 |
+
self.freeze_projector = freeze_projector
|
| 126 |
|
| 127 |
# Generation parameters (use explicit value if provided, else use default)
|
| 128 |
self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
|
asr_modeling.py
CHANGED
|
@@ -190,6 +190,28 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 190 |
state_dict = load_file(model_file)
|
| 191 |
model.load_state_dict(state_dict, strict=False)
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
return model
|
| 194 |
finally:
|
| 195 |
cls._is_loading_from_pretrained = False
|
|
@@ -233,9 +255,17 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 233 |
# Feature extractor for audio preprocessing
|
| 234 |
self.feature_extractor = self._create_feature_extractor(config)
|
| 235 |
|
| 236 |
-
# Audio projector (trainable)
|
| 237 |
self.projector = self._create_projector(config, target_dtype)
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
# For model parallelism
|
| 240 |
self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
|
| 241 |
|
|
@@ -333,6 +363,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 333 |
device = next(self.language_model.parameters()).device
|
| 334 |
return projector.to(device=device, dtype=dtype)
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
def _init_tokenizer(self, config: ASRConfig):
|
| 337 |
"""Initialize tokenizer with audio token."""
|
| 338 |
self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
|
|
@@ -600,6 +645,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 600 |
tokenize=True,
|
| 601 |
add_generation_prompt=True,
|
| 602 |
return_tensors="pt",
|
|
|
|
| 603 |
)
|
| 604 |
input_ids = chat_result.input_ids.to(device)
|
| 605 |
|
|
@@ -674,6 +720,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 674 |
tokenize=True,
|
| 675 |
add_generation_prompt=True,
|
| 676 |
return_tensors="pt",
|
|
|
|
| 677 |
)
|
| 678 |
input_ids = chat_result.input_ids.to(device)
|
| 679 |
|
|
@@ -773,6 +820,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 773 |
self.tokenizer.save_pretrained(save_dir)
|
| 774 |
self.feature_extractor.save_pretrained(save_dir)
|
| 775 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
# Add processor auto_map to preprocessor_config.json
|
| 777 |
config_path = save_dir / "preprocessor_config.json"
|
| 778 |
if config_path.exists():
|
|
|
|
| 190 |
state_dict = load_file(model_file)
|
| 191 |
model.load_state_dict(state_dict, strict=False)
|
| 192 |
|
| 193 |
+
# Load LoRA adapters if they exist and use_lora is enabled
|
| 194 |
+
if getattr(config, "use_lora", False):
|
| 195 |
+
adapter_file = cached_file(
|
| 196 |
+
pretrained_model_name_or_path,
|
| 197 |
+
"adapter_model.safetensors",
|
| 198 |
+
_raise_exceptions_for_missing_entries=False,
|
| 199 |
+
**cache_kwargs,
|
| 200 |
+
)
|
| 201 |
+
if adapter_file is not None:
|
| 202 |
+
from peft import PeftModel
|
| 203 |
+
|
| 204 |
+
# Get the directory containing the adapter
|
| 205 |
+
import os
|
| 206 |
+
|
| 207 |
+
adapter_dir = os.path.dirname(adapter_file)
|
| 208 |
+
# Load adapter weights into the already-wrapped PeftModel
|
| 209 |
+
model.language_model = PeftModel.from_pretrained(
|
| 210 |
+
model.language_model.base_model,
|
| 211 |
+
adapter_dir,
|
| 212 |
+
is_trainable=True,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
return model
|
| 216 |
finally:
|
| 217 |
cls._is_loading_from_pretrained = False
|
|
|
|
| 255 |
# Feature extractor for audio preprocessing
|
| 256 |
self.feature_extractor = self._create_feature_extractor(config)
|
| 257 |
|
| 258 |
+
# Audio projector (trainable unless freeze_projector is set)
|
| 259 |
self.projector = self._create_projector(config, target_dtype)
|
| 260 |
|
| 261 |
+
# Setup LoRA if enabled (Stage 2 fine-tuning)
|
| 262 |
+
if getattr(config, "use_lora", False):
|
| 263 |
+
self._setup_lora(config)
|
| 264 |
+
|
| 265 |
+
# Freeze projector if specified (for Stage 2 LoRA-only training)
|
| 266 |
+
if getattr(config, "freeze_projector", False):
|
| 267 |
+
self.projector.requires_grad_(False)
|
| 268 |
+
|
| 269 |
# For model parallelism
|
| 270 |
self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
|
| 271 |
|
|
|
|
| 363 |
device = next(self.language_model.parameters()).device
|
| 364 |
return projector.to(device=device, dtype=dtype)
|
| 365 |
|
| 366 |
+
def _setup_lora(self, config: ASRConfig):
|
| 367 |
+
"""Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
|
| 368 |
+
from peft import LoraConfig, get_peft_model
|
| 369 |
+
|
| 370 |
+
lora_config = LoraConfig(
|
| 371 |
+
r=config.lora_rank,
|
| 372 |
+
lora_alpha=config.lora_alpha,
|
| 373 |
+
target_modules=config.lora_target_modules,
|
| 374 |
+
lora_dropout=config.lora_dropout,
|
| 375 |
+
bias="none",
|
| 376 |
+
task_type="CAUSAL_LM",
|
| 377 |
+
)
|
| 378 |
+
self.language_model = get_peft_model(self.language_model, lora_config)
|
| 379 |
+
# LoRA params are trainable by default, base model stays frozen
|
| 380 |
+
|
| 381 |
def _init_tokenizer(self, config: ASRConfig):
|
| 382 |
"""Initialize tokenizer with audio token."""
|
| 383 |
self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
|
|
|
|
| 645 |
tokenize=True,
|
| 646 |
add_generation_prompt=True,
|
| 647 |
return_tensors="pt",
|
| 648 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 649 |
)
|
| 650 |
input_ids = chat_result.input_ids.to(device)
|
| 651 |
|
|
|
|
| 720 |
tokenize=True,
|
| 721 |
add_generation_prompt=True,
|
| 722 |
return_tensors="pt",
|
| 723 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 724 |
)
|
| 725 |
input_ids = chat_result.input_ids.to(device)
|
| 726 |
|
|
|
|
| 820 |
self.tokenizer.save_pretrained(save_dir)
|
| 821 |
self.feature_extractor.save_pretrained(save_dir)
|
| 822 |
|
| 823 |
+
# Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
|
| 824 |
+
if hasattr(self.language_model, "peft_config"):
|
| 825 |
+
self.language_model.save_pretrained(save_dir)
|
| 826 |
+
|
| 827 |
# Add processor auto_map to preprocessor_config.json
|
| 828 |
config_path = save_dir / "preprocessor_config.json"
|
| 829 |
if config_path.exists():
|
asr_processing.py
CHANGED
|
@@ -99,6 +99,7 @@ class ASRProcessor(ProcessorMixin):
|
|
| 99 |
tokenize=True,
|
| 100 |
add_generation_prompt=(text is None),
|
| 101 |
return_tensors=return_tensors,
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
# Handle both tensor and BatchEncoding returns
|
|
|
|
| 99 |
tokenize=True,
|
| 100 |
add_generation_prompt=(text is None),
|
| 101 |
return_tensors=return_tensors,
|
| 102 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 103 |
)
|
| 104 |
|
| 105 |
# Handle both tensor and BatchEncoding returns
|
chat_template.jinja
CHANGED
|
@@ -83,7 +83,7 @@
|
|
| 83 |
{%- endfor %}
|
| 84 |
{%- if add_generation_prompt %}
|
| 85 |
{{- '<|im_start|>assistant\n' }}
|
| 86 |
-
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 87 |
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
{%- endif %}
|
| 89 |
{%- endif %}
|
|
|
|
| 83 |
{%- endfor %}
|
| 84 |
{%- if add_generation_prompt %}
|
| 85 |
{{- '<|im_start|>assistant\n' }}
|
| 86 |
+
{%- if true %}
|
| 87 |
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
{%- endif %}
|
| 89 |
{%- endif %}
|