mazesmazes committed on
Commit
4e1e668
·
verified ·
1 Parent(s): 5506373

Fix postprocess: filter eos tokens, preserve casing

Browse files
Files changed (1) hide show
  1. asr_pipeline.py +19 -14
asr_pipeline.py CHANGED
@@ -534,6 +534,14 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
534
  if tokens.dim() > 1:
535
  tokens = tokens[0]
536
 
 
 
 
 
 
 
 
 
537
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
538
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
539
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
@@ -565,14 +573,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
565
  if not text:
566
  return ""
567
 
568
- # 1. LOWERCASE
569
- text = text.lower()
570
-
571
- # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
572
- if text.strip() in self.HALLUCINATION_PATTERNS:
573
  return ""
574
 
575
- # 3. CHECK FOR REGEX-BASED HALLUCINATIONS
576
  for pattern in self.HALLUCINATION_REGEXES:
577
  if pattern.search(text):
578
  # If hallucination is the entire output, return empty
@@ -581,21 +586,21 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
581
  # Otherwise remove the hallucinated portion
582
  text = pattern.sub("", text)
583
 
584
- # 4. COMBINE ACRONYMS
585
- # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
586
- text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
587
 
588
- # 5. NORMALIZE CURRENCY
589
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
590
- text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
591
 
592
- # 6. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
593
  text = self._truncate_character_repetitions(text)
594
 
595
- # 7. TRUNCATE TRAILING REPEATS (word-level)
596
  text = self._truncate_trailing_repeats(text)
597
 
598
- # 8. STRIP WHITESPACE
599
  return re.sub(r"\s+", " ", text).strip()
600
 
601
  def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
 
534
  if tokens.dim() > 1:
535
  tokens = tokens[0]
536
 
537
+ # Filter out eos tokens that the tokenizer doesn't recognize as special
538
+ # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
539
+ if hasattr(self, "model") and hasattr(self.model, "generation_config"):
540
+ eos_ids = self.model.generation_config.eos_token_id
541
+ if eos_ids is not None:
542
+ eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
543
+ tokens = [t for t in tokens.tolist() if t not in eos_set]
544
+
545
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
546
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
547
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
 
573
  if not text:
574
  return ""
575
 
576
+ # 1. CHECK FOR KNOWN HALLUCINATIONS (delete entirely, case-insensitive)
577
+ if text.strip().lower() in self.HALLUCINATION_PATTERNS:
 
 
 
578
  return ""
579
 
580
+ # 2. CHECK FOR REGEX-BASED HALLUCINATIONS
581
  for pattern in self.HALLUCINATION_REGEXES:
582
  if pattern.search(text):
583
  # If hallucination is the entire output, return empty
 
586
  # Otherwise remove the hallucinated portion
587
  text = pattern.sub("", text)
588
 
589
+ # 3. COMBINE ACRONYMS
590
+ # Merge consecutive single letters into one word (e.g., "U S A" -> "USA")
591
+ text = re.sub(r"\b([a-zA-Z])((?:\s+[a-zA-Z])+)\b", lambda m: m.group(0).replace(" ", ""), text, flags=re.IGNORECASE)
592
 
593
+ # 4. NORMALIZE CURRENCY
594
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
595
+ text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text, flags=re.IGNORECASE)
596
 
597
+ # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
598
  text = self._truncate_character_repetitions(text)
599
 
600
+ # 6. TRUNCATE TRAILING REPEATS (word-level)
601
  text = self._truncate_trailing_repeats(text)
602
 
603
+ # 7. STRIP WHITESPACE
604
  return re.sub(r"\s+", " ", text).strip()
605
 
606
  def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str: