mazesmazes commited on
Commit
3e23818
·
verified ·
1 Parent(s): c7d61ca

Training in progress - step 22000

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. asr_config.py +2 -4
  3. asr_pipeline.py +6 -42
.gitattributes CHANGED
@@ -1,3 +1,4 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_config.py CHANGED
@@ -25,6 +25,7 @@ class ASRConfig(transformers.PretrainedConfig):
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
 
28
  encoder_dim: Optional[int] = None,
29
  llm_dim: Optional[int] = None,
30
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
@@ -103,6 +104,7 @@ class ASRConfig(transformers.PretrainedConfig):
103
  self.attn_implementation = attn_implementation
104
  self.model_dtype = model_dtype
105
  self.system_prompt = system_prompt
 
106
  self.encoder_dim = encoder_dim
107
  self.llm_dim = llm_dim
108
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
@@ -204,10 +206,6 @@ class ASRConfig(transformers.PretrainedConfig):
204
 
205
  super().__init__(**kwargs)
206
 
207
- # Point encoder to audio_config so pipeline uses correct feature extractor
208
- # The pipeline looks for config.encoder._name_or_path for feature extractor
209
- self.encoder = self.audio_config
210
-
211
  self.auto_map = {
212
  "AutoConfig": "asr_config.ASRConfig",
213
  "AutoModel": "asr_modeling.ASRModel",
 
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
28
+ user_prompt: str = "Please transcribe this English audio into text: <audio>",
29
  encoder_dim: Optional[int] = None,
30
  llm_dim: Optional[int] = None,
31
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
 
104
  self.attn_implementation = attn_implementation
105
  self.model_dtype = model_dtype
106
  self.system_prompt = system_prompt
107
+ self.user_prompt = user_prompt
108
  self.encoder_dim = encoder_dim
109
  self.llm_dim = llm_dim
110
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
 
206
 
207
  super().__init__(**kwargs)
208
 
 
 
 
 
209
  self.auto_map = {
210
  "AutoConfig": "asr_config.ASRConfig",
211
  "AutoModel": "asr_modeling.ASRModel",
asr_pipeline.py CHANGED
@@ -523,13 +523,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
526
- # Known hallucination patterns that should be deleted entirely
527
- HALLUCINATION_PATTERNS = frozenset(
528
- [
529
- "and gt and gt",
530
- ]
531
- )
532
-
533
  def _post_process_prediction(self, text: str) -> str:
534
  """Post-process model output to fix common issues."""
535
  if not text:
@@ -538,29 +531,22 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
538
  # 1. LOWERCASE
539
  text = text.lower()
540
 
541
- # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
542
- if text.strip() in self.HALLUCINATION_PATTERNS:
543
- return ""
544
-
545
- # 3. COMBINE ACRONYMS
546
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
547
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
548
 
549
- # 4. NORMALIZE CURRENCY
550
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
551
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
552
 
553
- # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
554
- text = self._truncate_character_repetitions(text)
555
-
556
- # 6. TRUNCATE TRAILING REPEATS (word-level)
557
  text = self._truncate_trailing_repeats(text)
558
 
559
- # 7. STRIP WHITESPACE
560
  return re.sub(r"\s+", " ", text).strip()
561
 
562
- def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
563
- """Remove trailing repeated n-grams (1-10 words)."""
564
  words = text.split()
565
  if len(words) < 2:
566
  return text
@@ -580,25 +566,3 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
580
  break # Restart from largest n-gram
581
 
582
  return " ".join(words)
583
-
584
- def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
585
- """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
586
-
587
- Handles hallucinations where the model outputs the same character many times,
588
- like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
589
-
590
- Args:
591
- text: Input text to clean
592
- max_repeats: Maximum allowed consecutive repetitions of a character
593
-
594
- Returns:
595
- Text with character repetitions truncated
596
- """
597
- if not text:
598
- return text
599
-
600
- # Replace any character repeated more than max_repeats times with max_repeats
601
- # Pattern: any character followed by itself N+ times
602
- pattern = rf"(.)\1{{{max_repeats},}}"
603
- replacement = r"\1" * max_repeats
604
- return re.sub(pattern, replacement, text)
 
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
 
 
 
 
 
 
 
526
  def _post_process_prediction(self, text: str) -> str:
527
  """Post-process model output to fix common issues."""
528
  if not text:
 
531
  # 1. LOWERCASE
532
  text = text.lower()
533
 
534
+ # 2. COMBINE ACRONYMS
 
 
 
 
535
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
536
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
537
 
538
+ # 3. NORMALIZE CURRENCY
539
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
540
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
541
 
542
+ # 4. TRUNCATE TRAILING REPEATS
 
 
 
543
  text = self._truncate_trailing_repeats(text)
544
 
545
+ # 5. STRIP WHITESPACE
546
  return re.sub(r"\s+", " ", text).strip()
547
 
548
+ def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
549
+ """Remove trailing repeated n-grams (1-4 words)."""
550
  words = text.split()
551
  if len(words) < 2:
552
  return text
 
566
  break # Restart from largest n-gram
567
 
568
  return " ".join(words)