Update custom model files, README, and requirements
Browse files- .gitattributes +0 -1
- asr_pipeline.py +42 -6
- handler.py +19 -0
.gitattributes
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
tokenizer_config.json -filter -diff -merge text
|
| 4 |
-
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
tokenizer_config.json -filter -diff -merge text
|
|
|
asr_pipeline.py
CHANGED
|
@@ -523,6 +523,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 523 |
text = self._post_process_prediction(text)
|
| 524 |
return {"text": text}
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
def _post_process_prediction(self, text: str) -> str:
|
| 527 |
"""Post-process model output to fix common issues."""
|
| 528 |
if not text:
|
|
@@ -531,22 +538,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 531 |
# 1. LOWERCASE
|
| 532 |
text = text.lower()
|
| 533 |
|
| 534 |
-
# 2.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 536 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 537 |
|
| 538 |
-
#
|
| 539 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 540 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 541 |
|
| 542 |
-
#
|
|
|
|
|
|
|
|
|
|
| 543 |
text = self._truncate_trailing_repeats(text)
|
| 544 |
|
| 545 |
-
#
|
| 546 |
return re.sub(r"\s+", " ", text).strip()
|
| 547 |
|
| 548 |
-
def _truncate_trailing_repeats(self, text: str, max_ngram: int =
|
| 549 |
-
"""Remove trailing repeated n-grams (1-
|
| 550 |
words = text.split()
|
| 551 |
if len(words) < 2:
|
| 552 |
return text
|
|
@@ -566,3 +580,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 566 |
break # Restart from largest n-gram
|
| 567 |
|
| 568 |
return " ".join(words)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
text = self._post_process_prediction(text)
|
| 524 |
return {"text": text}
|
| 525 |
|
| 526 |
+
# Known hallucination outputs; a prediction matching one of these exactly
# (after lowercasing/stripping) is discarded entirely.
HALLUCINATION_PATTERNS = frozenset({"and gt and gt"})
|
| 532 |
+
|
| 533 |
def _post_process_prediction(self, text: str) -> str:
|
| 534 |
"""Post-process model output to fix common issues."""
|
| 535 |
if not text:
|
|
|
|
| 538 |
# 1. LOWERCASE
|
| 539 |
text = text.lower()
|
| 540 |
|
| 541 |
+
# 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
|
| 542 |
+
if text.strip() in self.HALLUCINATION_PATTERNS:
|
| 543 |
+
return ""
|
| 544 |
+
|
| 545 |
+
# 3. COMBINE ACRONYMS
|
| 546 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 547 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 548 |
|
| 549 |
+
# 4. NORMALIZE CURRENCY
|
| 550 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 551 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 552 |
|
| 553 |
+
# 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhhh")
|
| 554 |
+
text = self._truncate_character_repetitions(text)
|
| 555 |
+
|
| 556 |
+
# 6. TRUNCATE TRAILING REPEATS (word-level)
|
| 557 |
text = self._truncate_trailing_repeats(text)
|
| 558 |
|
| 559 |
+
# 7. STRIP WHITESPACE
|
| 560 |
return re.sub(r"\s+", " ", text).strip()
|
| 561 |
|
| 562 |
+
def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
|
| 563 |
+
"""Remove trailing repeated n-grams (1-10 words)."""
|
| 564 |
words = text.split()
|
| 565 |
if len(words) < 2:
|
| 566 |
return text
|
|
|
|
| 580 |
break # Restart from largest n-gram
|
| 581 |
|
| 582 |
return " ".join(words)
|
| 583 |
+
|
| 584 |
+
def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
|
| 585 |
+
"""Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
|
| 586 |
+
|
| 587 |
+
Handles hallucinations where the model outputs the same character many times,
|
| 588 |
+
like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
|
| 589 |
+
|
| 590 |
+
Args:
|
| 591 |
+
text: Input text to clean
|
| 592 |
+
max_repeats: Maximum allowed consecutive repetitions of a character
|
| 593 |
+
|
| 594 |
+
Returns:
|
| 595 |
+
Text with character repetitions truncated
|
| 596 |
+
"""
|
| 597 |
+
if not text:
|
| 598 |
+
return text
|
| 599 |
+
|
| 600 |
+
# Replace any character repeated more than max_repeats times with max_repeats
|
| 601 |
+
# Pattern: any character followed by itself N+ times
|
| 602 |
+
pattern = rf"(.)\1{{{max_repeats},}}"
|
| 603 |
+
replacement = r"\1" * max_repeats
|
| 604 |
+
return re.sub(pattern, replacement, text)
|
handler.py
CHANGED
|
@@ -15,7 +15,18 @@ except ImportError:
|
|
| 15 |
|
| 16 |
|
| 17 |
class EndpointHandler:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def __init__(self, path: str = ""):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
import os
|
| 20 |
|
| 21 |
import nltk
|
|
@@ -104,6 +115,14 @@ class EndpointHandler:
|
|
| 104 |
print(f"Warmup skipped due to: {e}")
|
| 105 |
|
| 106 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
inputs = data.get("inputs")
|
| 108 |
if inputs is None:
|
| 109 |
raise ValueError("Missing 'inputs' in request data")
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class EndpointHandler:
|
| 18 |
+
"""HuggingFace Inference Endpoints handler for ASR model.
|
| 19 |
+
|
| 20 |
+
Handles model loading, warmup, and inference requests for deployment
|
| 21 |
+
on HuggingFace Inference Endpoints or similar services.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
def __init__(self, path: str = ""):
|
| 25 |
+
"""Initialize the endpoint handler.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
path: Path to model directory or HuggingFace model ID
|
| 29 |
+
"""
|
| 30 |
import os
|
| 31 |
|
| 32 |
import nltk
|
|
|
|
| 115 |
print(f"Warmup skipped due to: {e}")
|
| 116 |
|
| 117 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
| 118 |
+
"""Process an inference request.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Transcription result with 'text' key
|
| 125 |
+
"""
|
| 126 |
inputs = data.get("inputs")
|
| 127 |
if inputs is None:
|
| 128 |
raise ValueError("Missing 'inputs' in request data")
|