mazesmazes committed on
Commit
b483d77
·
verified ·
1 Parent(s): 78424be

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. asr_config.py +1 -1
  2. asr_modeling.py +5 -2
  3. asr_pipeline.py +28 -0
asr_config.py CHANGED
@@ -52,7 +52,7 @@ class ASRConfig(transformers.PretrainedConfig):
52
  # Set default generation parameters (greedy decoding only)
53
  generation_defaults = {
54
  "num_beams": 1,
55
- "max_new_tokens": 96,
56
  "repetition_penalty": 1.0,
57
  "length_penalty": 1.0,
58
  "no_repeat_ngram_size": 0,
 
52
  # Set default generation parameters (greedy decoding only)
53
  generation_defaults = {
54
  "num_beams": 1,
55
+ "max_new_tokens": 256,
56
  "repetition_penalty": 1.0,
57
  "length_penalty": 1.0,
58
  "no_repeat_ngram_size": 0,
asr_modeling.py CHANGED
@@ -121,7 +121,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
121
  self.generation_config.length_penalty = config.length_penalty
122
  self.generation_config.repetition_penalty = config.repetition_penalty
123
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
124
- self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
 
 
 
125
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
126
 
127
  # Feature extractor for audio preprocessing
@@ -145,7 +148,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
145
  encoder_kwargs = {
146
  "attn_implementation": config.attn_implementation,
147
  "low_cpu_mem_usage": True,
148
- "torch_dtype": dtype,
149
  }
150
 
151
  if "whisper" in config.audio_model_id.lower():
 
121
  self.generation_config.length_penalty = config.length_penalty
122
  self.generation_config.repetition_penalty = config.repetition_penalty
123
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
124
+ self.generation_config.eos_token_id = [
125
+ self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
126
+ self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
127
+ ]
128
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
129
 
130
  # Feature extractor for audio preprocessing
 
148
  encoder_kwargs = {
149
  "attn_implementation": config.attn_implementation,
150
  "low_cpu_mem_usage": True,
151
+ "dtype": dtype,
152
  }
153
 
154
  if "whisper" in config.audio_model_id.lower():
asr_pipeline.py CHANGED
@@ -476,4 +476,32 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
 
 
479
  return {"text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
+ # Truncate if a word repeats more than 3 times consecutively
480
+ text = self._truncate_repetitions(text, max_repeats=3)
481
  return {"text": text}
482
+
483
+ def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
484
+ """Truncate text when a word repeats more than max_repeats times consecutively.
485
+
486
+ Args:
487
+ text: Input text to check for repetitions
488
+ max_repeats: Maximum allowed consecutive repetitions (default 3)
489
+
490
+ Returns:
491
+ Truncated text if repetition detected, otherwise original text
492
+ """
493
+ words = text.split()
494
+ if len(words) <= max_repeats:
495
+ return text
496
+
497
+ repeat_count = 1
498
+ for i in range(1, len(words)):
499
+ if words[i].lower() == words[i - 1].lower():
500
+ repeat_count += 1
501
+ if repeat_count > max_repeats:
502
+ # Keep up to max_repeats of the repeated word
503
+ return " ".join(words[:i])
504
+ else:
505
+ repeat_count = 1
506
+
507
+ return text