mazesmazes committed on
Commit
f6987eb
·
verified ·
1 Parent(s): f914ea1

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -1
  2. asr_pipeline.py +42 -6
  3. handler.py +19 -0
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
- tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
asr_pipeline.py CHANGED
@@ -523,6 +523,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
 
 
 
 
 
 
 
526
  def _post_process_prediction(self, text: str) -> str:
527
  """Post-process model output to fix common issues."""
528
  if not text:
@@ -531,22 +538,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
531
  # 1. LOWERCASE
532
  text = text.lower()
533
 
534
- # 2. COMBINE ACRONYMS
 
 
 
 
535
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
536
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
537
 
538
- # 3. NORMALIZE CURRENCY
539
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
540
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
541
 
542
- # 4. TRUNCATE TRAILING REPEATS
 
 
 
543
  text = self._truncate_trailing_repeats(text)
544
 
545
- # 5. STRIP WHITESPACE
546
  return re.sub(r"\s+", " ", text).strip()
547
 
548
- def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
549
- """Remove trailing repeated n-grams (1-4 words)."""
550
  words = text.split()
551
  if len(words) < 2:
552
  return text
@@ -566,3 +580,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
566
  break # Restart from largest n-gram
567
 
568
  return " ".join(words)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
526
# Known hallucination patterns: a prediction consisting solely of one of
# these strings is treated as a hallucination and deleted entirely.
HALLUCINATION_PATTERNS = frozenset({"and gt and gt"})
532
+
533
  def _post_process_prediction(self, text: str) -> str:
534
  """Post-process model output to fix common issues."""
535
  if not text:
 
538
  # 1. LOWERCASE
539
  text = text.lower()
540
 
541
+ # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
542
+ if text.strip() in self.HALLUCINATION_PATTERNS:
543
+ return ""
544
+
545
+ # 3. COMBINE ACRONYMS
546
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
547
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
548
 
549
+ # 4. NORMALIZE CURRENCY
550
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
551
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
552
 
553
+ # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhhh")
554
+ text = self._truncate_character_repetitions(text)
555
+
556
+ # 6. TRUNCATE TRAILING REPEATS (word-level)
557
  text = self._truncate_trailing_repeats(text)
558
 
559
+ # 7. STRIP WHITESPACE
560
  return re.sub(r"\s+", " ", text).strip()
561
 
562
+ def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
563
+ """Remove trailing repeated n-grams (1-10 words)."""
564
  words = text.split()
565
  if len(words) < 2:
566
  return text
 
580
  break # Restart from largest n-gram
581
 
582
  return " ".join(words)
583
+
584
+ def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
585
+ """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
586
+
587
+ Handles hallucinations where the model outputs the same character many times,
588
+ like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
589
+
590
+ Args:
591
+ text: Input text to clean
592
+ max_repeats: Maximum allowed consecutive repetitions of a character
593
+
594
+ Returns:
595
+ Text with character repetitions truncated
596
+ """
597
+ if not text:
598
+ return text
599
+
600
+ # Replace any character repeated more than max_repeats times with max_repeats
601
+ # Pattern: any character followed by itself N+ times
602
+ pattern = rf"(.)\1{{{max_repeats},}}"
603
+ replacement = r"\1" * max_repeats
604
+ return re.sub(pattern, replacement, text)
handler.py CHANGED
@@ -15,7 +15,18 @@ except ImportError:
15
 
16
 
17
  class EndpointHandler:
 
 
 
 
 
 
18
  def __init__(self, path: str = ""):
 
 
 
 
 
19
  import os
20
 
21
  import nltk
@@ -104,6 +115,14 @@ class EndpointHandler:
104
  print(f"Warmup skipped due to: {e}")
105
 
106
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
 
 
 
 
 
 
 
 
107
  inputs = data.get("inputs")
108
  if inputs is None:
109
  raise ValueError("Missing 'inputs' in request data")
 
15
 
16
 
17
  class EndpointHandler:
18
+ """HuggingFace Inference Endpoints handler for ASR model.
19
+
20
+ Handles model loading, warmup, and inference requests for deployment
21
+ on HuggingFace Inference Endpoints or similar services.
22
+ """
23
+
24
  def __init__(self, path: str = ""):
25
+ """Initialize the endpoint handler.
26
+
27
+ Args:
28
+ path: Path to model directory or HuggingFace model ID
29
+ """
30
  import os
31
 
32
  import nltk
 
115
  print(f"Warmup skipped due to: {e}")
116
 
117
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
118
+ """Process an inference request.
119
+
120
+ Args:
121
+ data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'
122
+
123
+ Returns:
124
+ Transcription result with 'text' key
125
+ """
126
  inputs = data.get("inputs")
127
  if inputs is None:
128
  raise ValueError("Missing 'inputs' in request data")