mazesmazes committed on
Commit
e196076
·
verified ·
1 Parent(s): 4e95951

Training in progress - step 1000

Browse files
Files changed (4) hide show
  1. asr_config.py +14 -0
  2. asr_modeling.py +52 -1
  3. asr_processing.py +1 -0
  4. chat_template.jinja +1 -1
asr_config.py CHANGED
@@ -49,6 +49,13 @@ class ASRConfig(transformers.PretrainedConfig):
49
  mask_feature_prob: float = 0.0, # Probability of masking frequency bins (disabled by default)
50
  mask_feature_length: int = 10, # Max length of frequency mask
51
  mask_feature_min_masks: int = 0, # Min number of frequency masks
 
 
 
 
 
 
 
52
  max_new_tokens: Optional[int] = None,
53
  min_new_tokens: Optional[int] = None,
54
  repetition_penalty: Optional[float] = None,
@@ -109,6 +116,13 @@ class ASRConfig(transformers.PretrainedConfig):
109
  self.mask_feature_prob = mask_feature_prob
110
  self.mask_feature_length = mask_feature_length
111
  self.mask_feature_min_masks = mask_feature_min_masks
 
 
 
 
 
 
 
112
 
113
  # Generation parameters (use explicit value if provided, else use default)
114
  self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
 
49
  mask_feature_prob: float = 0.0, # Probability of masking frequency bins (disabled by default)
50
  mask_feature_length: int = 10, # Max length of frequency mask
51
  mask_feature_min_masks: int = 0, # Min number of frequency masks
52
+ # LoRA configuration (for Stage 2 fine-tuning)
53
+ use_lora: bool = False,
54
+ lora_rank: int = 8, # SALMONN default
55
+ lora_alpha: int = 32, # SALMONN default (scaling factor 4.0)
56
+ lora_dropout: float = 0.0,
57
+ lora_target_modules: Optional[list] = None, # Default: ["q_proj", "v_proj"]
58
+ freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
59
  max_new_tokens: Optional[int] = None,
60
  min_new_tokens: Optional[int] = None,
61
  repetition_penalty: Optional[float] = None,
 
116
  self.mask_feature_prob = mask_feature_prob
117
  self.mask_feature_length = mask_feature_length
118
  self.mask_feature_min_masks = mask_feature_min_masks
119
+ # LoRA configuration
120
+ self.use_lora = use_lora
121
+ self.lora_rank = lora_rank
122
+ self.lora_alpha = lora_alpha
123
+ self.lora_dropout = lora_dropout
124
+ self.lora_target_modules = lora_target_modules or ["q_proj", "v_proj"]
125
+ self.freeze_projector = freeze_projector
126
 
127
  # Generation parameters (use explicit value if provided, else use default)
128
  self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
asr_modeling.py CHANGED
@@ -190,6 +190,28 @@ class ASRModel(PreTrainedModel, GenerationMixin):
190
  state_dict = load_file(model_file)
191
  model.load_state_dict(state_dict, strict=False)
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  return model
194
  finally:
195
  cls._is_loading_from_pretrained = False
@@ -233,9 +255,17 @@ class ASRModel(PreTrainedModel, GenerationMixin):
233
  # Feature extractor for audio preprocessing
234
  self.feature_extractor = self._create_feature_extractor(config)
235
 
236
- # Audio projector (trainable)
237
  self.projector = self._create_projector(config, target_dtype)
238
 
 
 
 
 
 
 
 
 
239
  # For model parallelism
240
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
241
 
@@ -333,6 +363,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
333
  device = next(self.language_model.parameters()).device
334
  return projector.to(device=device, dtype=dtype)
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  def _init_tokenizer(self, config: ASRConfig):
337
  """Initialize tokenizer with audio token."""
338
  self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
@@ -600,6 +645,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
600
  tokenize=True,
601
  add_generation_prompt=True,
602
  return_tensors="pt",
 
603
  )
604
  input_ids = chat_result.input_ids.to(device)
605
 
@@ -674,6 +720,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
674
  tokenize=True,
675
  add_generation_prompt=True,
676
  return_tensors="pt",
 
677
  )
678
  input_ids = chat_result.input_ids.to(device)
679
 
@@ -773,6 +820,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
773
  self.tokenizer.save_pretrained(save_dir)
774
  self.feature_extractor.save_pretrained(save_dir)
775
 
 
 
 
 
776
  # Add processor auto_map to preprocessor_config.json
777
  config_path = save_dir / "preprocessor_config.json"
778
  if config_path.exists():
 
190
  state_dict = load_file(model_file)
191
  model.load_state_dict(state_dict, strict=False)
192
 
193
+ # Load LoRA adapters if they exist and use_lora is enabled
194
+ if getattr(config, "use_lora", False):
195
+ adapter_file = cached_file(
196
+ pretrained_model_name_or_path,
197
+ "adapter_model.safetensors",
198
+ _raise_exceptions_for_missing_entries=False,
199
+ **cache_kwargs,
200
+ )
201
+ if adapter_file is not None:
202
+ from peft import PeftModel
203
+
204
+ # Get the directory containing the adapter
205
+ import os
206
+
207
+ adapter_dir = os.path.dirname(adapter_file)
208
+ # Load adapter weights into the already-wrapped PeftModel
209
+ model.language_model = PeftModel.from_pretrained(
210
+ model.language_model.base_model,
211
+ adapter_dir,
212
+ is_trainable=True,
213
+ )
214
+
215
  return model
216
  finally:
217
  cls._is_loading_from_pretrained = False
 
255
  # Feature extractor for audio preprocessing
256
  self.feature_extractor = self._create_feature_extractor(config)
257
 
258
+ # Audio projector (trainable unless freeze_projector is set)
259
  self.projector = self._create_projector(config, target_dtype)
260
 
261
+ # Setup LoRA if enabled (Stage 2 fine-tuning)
262
+ if getattr(config, "use_lora", False):
263
+ self._setup_lora(config)
264
+
265
+ # Freeze projector if specified (for Stage 2 LoRA-only training)
266
+ if getattr(config, "freeze_projector", False):
267
+ self.projector.requires_grad_(False)
268
+
269
  # For model parallelism
270
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
271
 
 
363
  device = next(self.language_model.parameters()).device
364
  return projector.to(device=device, dtype=dtype)
365
 
366
def _setup_lora(self, config: ASRConfig):
    """Wrap the language model with LoRA adapters (Stage 2 fine-tuning).

    Builds a ``LoraConfig`` from the rank/alpha/dropout/target-module
    settings on ``config`` and replaces ``self.language_model`` with the
    PEFT-wrapped model. ``get_peft_model`` freezes the base weights and
    leaves only the injected LoRA parameters trainable.

    Args:
        config: ASRConfig carrying ``lora_rank``, ``lora_alpha``,
            ``lora_dropout`` and ``lora_target_modules``.
    """
    from peft import LoraConfig, get_peft_model

    peft_cfg = LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
    )
    self.language_model = get_peft_model(self.language_model, peft_cfg)
381
  def _init_tokenizer(self, config: ASRConfig):
382
  """Initialize tokenizer with audio token."""
383
  self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
 
645
  tokenize=True,
646
  add_generation_prompt=True,
647
  return_tensors="pt",
648
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
649
  )
650
  input_ids = chat_result.input_ids.to(device)
651
 
 
720
  tokenize=True,
721
  add_generation_prompt=True,
722
  return_tensors="pt",
723
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
724
  )
725
  input_ids = chat_result.input_ids.to(device)
726
 
 
820
  self.tokenizer.save_pretrained(save_dir)
821
  self.feature_extractor.save_pretrained(save_dir)
822
 
823
+ # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
824
+ if hasattr(self.language_model, "peft_config"):
825
+ self.language_model.save_pretrained(save_dir)
826
+
827
  # Add processor auto_map to preprocessor_config.json
828
  config_path = save_dir / "preprocessor_config.json"
829
  if config_path.exists():
asr_processing.py CHANGED
@@ -99,6 +99,7 @@ class ASRProcessor(ProcessorMixin):
99
  tokenize=True,
100
  add_generation_prompt=(text is None),
101
  return_tensors=return_tensors,
 
102
  )
103
 
104
  # Handle both tensor and BatchEncoding returns
 
99
  tokenize=True,
100
  add_generation_prompt=(text is None),
101
  return_tensors=return_tensors,
102
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
103
  )
104
 
105
  # Handle both tensor and BatchEncoding returns
chat_template.jinja CHANGED
@@ -83,7 +83,7 @@
83
  {%- endfor %}
84
  {%- if add_generation_prompt %}
85
  {{- '<|im_start|>assistant\n' }}
86
- {%- if enable_thinking is defined and enable_thinking is false %}
87
  {{- '<think>\n\n</think>\n\n' }}
88
  {%- endif %}
89
  {%- endif %}
 
83
  {%- endfor %}
84
  {%- if add_generation_prompt %}
85
  {{- '<|im_start|>assistant\n' }}
86
+ {%- if true %}
87
  {{- '<think>\n\n</think>\n\n' }}
88
  {%- endif %}
89
  {%- endif %}