mazesmazes committed
Commit f7d3870 · verified · 1 Parent(s): c13ef82

Training in progress - step 500

Files changed (4):
  1. asr_config.py +4 -2
  2. asr_modeling.py +21 -13
  3. asr_pipeline.py +3 -1
  4. projectors.py +3 -1
asr_config.py CHANGED
@@ -54,7 +54,7 @@ class ASRConfig(transformers.PretrainedConfig):
         lora_rank: int = 8,  # SALMONN default
         lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
         lora_dropout: float = 0.0,
-        lora_target_modules: Optional[list] = None,  # Default: ["q_proj", "v_proj"]
+        lora_target_modules: Optional[list] = None,  # Default: all linear layers
         freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
         max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
@@ -121,7 +121,9 @@ class ASRConfig(transformers.PretrainedConfig):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha
         self.lora_dropout = lora_dropout
-        self.lora_target_modules = lora_target_modules or ["q_proj", "v_proj"]
+        self.lora_target_modules = lora_target_modules or [
+            "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
+        ]
         self.freeze_projector = freeze_projector

         # Generation parameters (use explicit value if provided, else use default)
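
This widens the default LoRA target set from the two attention projections (q_proj, v_proj) to all seven linear layers of a LLaMA-style decoder block (attention plus gated MLP), which substantially increases the trainable adapter parameter count at the same rank. A minimal sketch of how these fields would typically feed PEFT, assuming _setup_lora builds a standard LoraConfig (the helper itself is not shown in this diff, and task_type is an assumption):

from peft import LoraConfig, get_peft_model

def apply_lora(language_model, config):
    # Hypothetical sketch of what _setup_lora might do with the config fields above.
    lora_config = LoraConfig(
        r=config.lora_rank,                         # 8, SALMONN default
        lora_alpha=config.lora_alpha,               # 32 -> scaling alpha/r = 4.0
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,  # now all 7 linear projections
        task_type="CAUSAL_LM",                      # assumption, not in the diff
    )
    return get_peft_model(language_model, lora_config)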
asr_modeling.py CHANGED
@@ -190,27 +190,30 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             state_dict = load_file(model_file)
             model.load_state_dict(state_dict, strict=False)

-            # Load LoRA adapters if they exist and use_lora is enabled
+            # Load LoRA adapters if use_lora is enabled
             if getattr(config, "use_lora", False):
-                adapter_file = cached_file(
+                # Check for adapter_config.json (required by PEFT to load adapters)
+                adapter_config_file = cached_file(
                     pretrained_model_name_or_path,
-                    "adapter_model.safetensors",
+                    "adapter_config.json",
                     _raise_exceptions_for_missing_entries=False,
                     **cache_kwargs,
                 )
-                if adapter_file is not None:
+                if adapter_config_file is not None:
+                    # Load saved adapter weights using the original repo_id/path
+                    # PEFT handles Hub downloads and caching internally
                     from peft import PeftModel

-                    # Get the directory containing the adapter
-                    import os
-
-                    adapter_dir = os.path.dirname(adapter_file)
-                    # Load adapter weights into the already-wrapped PeftModel
+                    # language_model is bare (not PEFT-wrapped) since we skipped _setup_lora
                     model.language_model = PeftModel.from_pretrained(
-                        model.language_model.base_model,
-                        adapter_dir,
+                        model.language_model,
+                        pretrained_model_name_or_path,  # Use original repo_id, not cache path
                         is_trainable=True,
+                        **cache_kwargs,
                     )
+                else:
+                    # No saved adapters - initialize fresh LoRA for training
+                    model._setup_lora(config)

             return model
         finally:
@@ -259,7 +262,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.projector = self._create_projector(config, target_dtype)

         # Setup LoRA if enabled (Stage 2 fine-tuning)
-        if getattr(config, "use_lora", False):
+        # Skip if loading from pretrained - from_pretrained will handle adapter loading
+        if getattr(config, "use_lora", False) and not getattr(
+            self.__class__, "_is_loading_from_pretrained", False
+        ):
             self._setup_lora(config)

         # Freeze projector if specified (for Stage 2 LoRA-only training)
@@ -821,8 +827,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.feature_extractor.save_pretrained(save_dir)

         # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
+        # Don't save embedding layers - the <audio> token embedding is never used
+        # (it's replaced with projected audio embeddings before the LLM sees it)
         if hasattr(self.language_model, "peft_config"):
-            self.language_model.save_pretrained(save_dir)
+            self.language_model.save_pretrained(save_dir, save_embedding_layers=False)

         # Add processor auto_map to preprocessor_config.json
         config_path = save_dir / "preprocessor_config.json"
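
The load path now probes for adapter_config.json rather than adapter_model.safetensors, and hands the original repo_id to PeftModel.from_pretrained instead of a resolved cache path, letting PEFT do its own Hub download and caching. The probe works because cached_file can be told not to raise on missing entries; a small sketch, with a hypothetical repo id:

from transformers.utils import cached_file

# With _raise_exceptions_for_missing_entries=False, cached_file returns None
# for a file that does not exist, so it doubles as an existence check for
# both Hub repos and local directories. "some-org/some-asr-repo" is illustrative.
probe = cached_file(
    "some-org/some-asr-repo",
    "adapter_config.json",
    _raise_exceptions_for_missing_entries=False,
)
resume_saved_adapters = probe is not None  # else: fall back to fresh _setup_lora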
asr_pipeline.py CHANGED
@@ -504,7 +504,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
                 if repeat_count >= 1:
                     words = words[: idx + n]
                     text = " ".join(words)
-                    print(f"[DEBUG] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})")
+                    print(
+                        f"[DEBUG] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})"
+                    )
                     break

         # 3. COMBINE ACRONYMS
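
The reflowed debug print reports an n-gram repetition cutoff. A rough standalone reconstruction of the heuristic around this hunk (the full loop is outside the diff, so the scan structure here is inferred from the variable names):

def truncate_repetition(text: str, n: int = 3) -> str:
    # Hypothetical sketch: find a window of n words that immediately repeats,
    # keep one copy of it, and drop everything after it.
    words = text.split()
    original_len = len(words)
    for idx in range(max(0, len(words) - 2 * n + 1)):
        ngram = words[idx : idx + n]
        repeat_count, j = 0, idx + n
        while words[j : j + n] == ngram:
            repeat_count += 1
            j += n
        if repeat_count >= 1:
            words = words[: idx + n]
            text = " ".join(words)
            print(
                f"[DEBUG] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})"
            )
            break
    return text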
projectors.py CHANGED
@@ -135,7 +135,9 @@ class MOSAProjector(nn.Module):

         # --- 1. Router Branch ---
         # Mean pool encoder outputs for routing decisions
-        x_pooled = x.reshape(batch_size, out_len, self.k, self.encoder_dim).mean(dim=2)  # (B, out_len, D)
+        x_pooled = x.reshape(batch_size, out_len, self.k, self.encoder_dim).mean(
+            dim=2
+        )  # (B, out_len, D)

         # Router logits and softmax gating (dense MoE)
         routing_weights = F.softmax(self.router(x_pooled), dim=-1)  # (B, out_len, num_experts)
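
For context, the router branch folds k consecutive encoder frames into one output slot and computes dense (softmax over all experts) mixture weights per slot. A shape sketch with illustrative sizes; the real MOSAProjector also runs the expert branches, which this hunk does not touch:

import torch
import torch.nn.functional as F

B, T, D, k, num_experts = 2, 96, 1024, 4, 4   # illustrative sizes, not from the repo
out_len = T // k                               # k frames fold into one output slot
x = torch.randn(B, T, D)                       # stand-in for encoder outputs
router = torch.nn.Linear(D, num_experts)

x_pooled = x.reshape(B, out_len, k, D).mean(dim=2)     # (B, out_len, D)
routing_weights = F.softmax(router(x_pooled), dim=-1)  # (B, out_len, num_experts)
# Dense MoE: every expert's output is blended with these weights,
# rather than hard-selecting a top-k subset of experts.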