mazesmazes commited on
Commit
fdd8e4a
·
verified ·
1 Parent(s): a670b3e

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -1
  2. asr_config.py +1 -1
  3. asr_pipeline.py +26 -1
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
- tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
asr_config.py CHANGED
@@ -67,7 +67,7 @@ class ASRConfig(transformers.PretrainedConfig):
67
  # Set default generation parameters (greedy decoding only)
68
  generation_defaults = {
69
  "num_beams": 1,
70
- "max_new_tokens": 256,
71
  "min_new_tokens": 0,
72
  "repetition_penalty": 1.0,
73
  "length_penalty": 1.0,
 
67
  # Set default generation parameters (greedy decoding only)
68
  generation_defaults = {
69
  "num_beams": 1,
70
+ "max_new_tokens": 128,
71
  "min_new_tokens": 0,
72
  "repetition_penalty": 1.0,
73
  "length_penalty": 1.0,
asr_pipeline.py CHANGED
@@ -496,5 +496,30 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
496
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
497
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
498
 
499
- # 4. STRIP WHITESPACE
 
 
 
500
  return re.sub(r"\s+", " ", text).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
497
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
498
 
499
+ # 4. TRUNCATE TRAILING REPEATS
500
+ text = self._truncate_trailing_repeats(text)
501
+
502
+ # 5. STRIP WHITESPACE
503
  return re.sub(r"\s+", " ", text).strip()
504
+
505
+ def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
506
+ """Remove trailing repeated n-grams (1-4 words)."""
507
+ words = text.split()
508
+ if len(words) < 2:
509
+ return text
510
+
511
+ # Keep truncating until no more trailing repeats found
512
+ changed = True
513
+ while changed:
514
+ changed = False
515
+ # Check for repeating n-grams from largest to smallest
516
+ for n in range(min(max_ngram, len(words) // 2), 0, -1):
517
+ if len(words) < n * 2:
518
+ continue
519
+ # Check if last n words repeat the previous n words
520
+ if words[-n:] == words[-2 * n : -n]:
521
+ words = words[:-n] # Remove the trailing repeat
522
+ changed = True
523
+ break # Restart from largest n-gram
524
+
525
+ return " ".join(words)