mazesmazes committed (verified)
Commit 6a72d9f · Parent(s): 213396d

Training in progress - step 500

Files changed (2)
  1. asr_modeling.py +12 -3
  2. asr_pipeline.py +3 -27
asr_modeling.py CHANGED
@@ -212,8 +212,18 @@ class ASRModel(PreTrainedModel, GenerationMixin):
                     **cache_kwargs,
                 )
             else:
-                # No saved adapters - initialize fresh LoRA for training
-                model._setup_lora(config)
+                # No saved adapters - initialize fresh LLM LoRA for training
+                from peft import LoraConfig, get_peft_model
+
+                lora_config = LoraConfig(
+                    r=config.lora_rank,
+                    lora_alpha=config.lora_alpha,
+                    target_modules=config.lora_target_modules,
+                    lora_dropout=config.lora_dropout,
+                    bias="none",
+                    task_type="CAUSAL_LM",
+                )
+                model.language_model = get_peft_model(model.language_model, lora_config)

             return model
         finally:
@@ -382,7 +392,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             task_type="CAUSAL_LM",
         )
         self.language_model = get_peft_model(self.language_model, lora_config)
-        # LoRA params are trainable by default, base model stays frozen

     def _init_tokenizer(self, config: ASRConfig):
         """Initialize tokenizer with audio token."""
asr_pipeline.py CHANGED
@@ -485,40 +485,16 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if not text:
             return ""

-        original_len = len(text.split())
-        original_text = text  # Keep for debug
-
         # 1. LOWERCASE
         text = text.lower()

-        # 2. REMOVE REPETITIVE LOOPS
-        # If the model repeats the same phrase, keep only one instance.
-        words = text.split()
-        for n in range(1, min(15, len(words) // 2 + 1)):
-            last_sequence = words[-n:]
-            repeat_count = 0
-            idx = len(words) - n
-            while idx >= n and words[idx - n : idx] == last_sequence:
-                repeat_count += 1
-                idx -= n
-
-            if repeat_count >= 1:
-                words = words[: idx + n]
-                text = " ".join(words)
-                print(
-                    f"[POSTPROCESS] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})"
-                )
-                print(f"[POSTPROCESS] Before: {original_text[:100]}...")
-                print(f"[POSTPROCESS] After: {text[:100]}...")
-                break
-
-        # 3. COMBINE ACRONYMS
+        # 2. COMBINE ACRONYMS
         # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
         text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)

-        # 4. NORMALIZE CURRENCY
+        # 3. NORMALIZE CURRENCY
         # Convert "eur X" to "X euros" for Whisper normalizer compatibility
         text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)

-        # 5. STRIP WHITESPACE
+        # 4. STRIP WHITESPACE
         return re.sub(r"\s+", " ", text).strip()
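For reference, the deleted step 2 implemented a tail-repetition truncation: for each n-gram length up to 14, it counted how many times the final n words immediately repeat before the tail and, if they repeat at least once, kept a single instance. A self-contained sketch of that logic (same algorithm as the removed block; the debug prints are dropped and the function name is ours):

def truncate_tail_repetition(text: str) -> str:
    # For each n-gram size, count immediate repeats of the final n words,
    # then keep a single instance of the repeated phrase.
    words = text.split()
    for n in range(1, min(15, len(words) // 2 + 1)):
        last_sequence = words[-n:]
        repeat_count = 0
        idx = len(words) - n
        while idx >= n and words[idx - n : idx] == last_sequence:
            repeat_count += 1
            idx -= n
        if repeat_count >= 1:
            return " ".join(words[: idx + n])
    return text

print(truncate_tail_repetition("thank you thank you thank you"))  # -> "thank you"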
 
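The surviving normalization, end to end, as a runnable sketch (the regexes are copied from the diff above; the sample input is made up):

import re

def postprocess(text: str) -> str:
    if not text:
        return ""
    # 1. LOWERCASE
    text = text.lower()
    # 2. COMBINE ACRONYMS: merge consecutive single letters ("u s a" -> "usa")
    text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
    # 3. NORMALIZE CURRENCY: "eur 50" -> "50 euros" for Whisper normalizer compatibility
    text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
    # 4. STRIP WHITESPACE
    return re.sub(r"\s+", " ", text).strip()

print(postprocess("It cost  EUR 20 in the U S A"))  # -> "it cost 20 euros in the usa"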