Fix postprocess: filter eos tokens, preserve casing
Browse files
- asr_pipeline.py +19 -14
asr_pipeline.py
CHANGED
|
@@ -534,6 +534,14 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 534 |
if tokens.dim() > 1:
|
| 535 |
tokens = tokens[0]
|
| 536 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
|
| 538 |
# Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
|
| 539 |
text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
|
|
@@ -565,14 +573,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 565 |
if not text:
|
| 566 |
return ""
|
| 567 |
|
| 568 |
-
# 1.
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
# 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
|
| 572 |
-
if text.strip() in self.HALLUCINATION_PATTERNS:
|
| 573 |
return ""
|
| 574 |
|
| 575 |
-
#
|
| 576 |
for pattern in self.HALLUCINATION_REGEXES:
|
| 577 |
if pattern.search(text):
|
| 578 |
# If hallucination is the entire output, return empty
|
|
@@ -581,21 +586,21 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 581 |
# Otherwise remove the hallucinated portion
|
| 582 |
text = pattern.sub("", text)
|
| 583 |
|
| 584 |
-
#
|
| 585 |
-
# Merge consecutive single letters into one word (e.g., "
|
| 586 |
-
text = re.sub(r"\b([a-
|
| 587 |
|
| 588 |
-
#
|
| 589 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 590 |
-
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 591 |
|
| 592 |
-
#
|
| 593 |
text = self._truncate_character_repetitions(text)
|
| 594 |
|
| 595 |
-
#
|
| 596 |
text = self._truncate_trailing_repeats(text)
|
| 597 |
|
| 598 |
-
#
|
| 599 |
return re.sub(r"\s+", " ", text).strip()
|
| 600 |
|
| 601 |
def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
|
|
|
|
| 534 |
if tokens.dim() > 1:
|
| 535 |
tokens = tokens[0]
|
| 536 |
|
| 537 |
+
# Filter out eos tokens that the tokenizer doesn't recognize as special
|
| 538 |
+
# (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
|
| 539 |
+
if hasattr(self, "model") and hasattr(self.model, "generation_config"):
|
| 540 |
+
eos_ids = self.model.generation_config.eos_token_id
|
| 541 |
+
if eos_ids is not None:
|
| 542 |
+
eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
|
| 543 |
+
tokens = [t for t in tokens.tolist() if t not in eos_set]
|
| 544 |
+
|
| 545 |
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
|
| 546 |
# Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
|
| 547 |
text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
|
|
|
|
| 573 |
if not text:
|
| 574 |
return ""
|
| 575 |
|
| 576 |
+
# 1. CHECK FOR KNOWN HALLUCINATIONS (delete entirely, case-insensitive)
|
| 577 |
+
if text.strip().lower() in self.HALLUCINATION_PATTERNS:
|
|
|
|
|
|
|
|
|
|
| 578 |
return ""
|
| 579 |
|
| 580 |
+
# 2. CHECK FOR REGEX-BASED HALLUCINATIONS
|
| 581 |
for pattern in self.HALLUCINATION_REGEXES:
|
| 582 |
if pattern.search(text):
|
| 583 |
# If hallucination is the entire output, return empty
|
|
|
|
| 586 |
# Otherwise remove the hallucinated portion
|
| 587 |
text = pattern.sub("", text)
|
| 588 |
|
| 589 |
+
# 3. COMBINE ACRONYMS
|
| 590 |
+
# Merge consecutive single letters into one word (e.g., "U S A" -> "USA")
|
| 591 |
+
text = re.sub(r"\b([a-zA-Z])((?:\s+[a-zA-Z])+)\b", lambda m: m.group(0).replace(" ", ""), text, flags=re.IGNORECASE)
|
| 592 |
|
| 593 |
+
# 4. NORMALIZE CURRENCY
|
| 594 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 595 |
+
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text, flags=re.IGNORECASE)
|
| 596 |
|
| 597 |
+
# 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
|
| 598 |
text = self._truncate_character_repetitions(text)
|
| 599 |
|
| 600 |
+
# 6. TRUNCATE TRAILING REPEATS (word-level)
|
| 601 |
text = self._truncate_trailing_repeats(text)
|
| 602 |
|
| 603 |
+
# 7. STRIP WHITESPACE
|
| 604 |
return re.sub(r"\s+", " ", text).strip()
|
| 605 |
|
| 606 |
def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
|