mazesmazes committed
Commit 6c50acb · verified · 1 parent: 4bf23cb

Update custom model files, README, and requirements

Files changed (4):
  1. .gitattributes +0 -1
  2. asr_config.py +13 -15
  3. asr_modeling.py +55 -116
  4. asr_pipeline.py +53 -6
.gitattributes CHANGED
@@ -1,4 +1,3 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
-tokenizer.json filter=lfs diff=lfs merge=lfs -text

asr_config.py CHANGED
@@ -25,7 +25,6 @@ class ASRConfig(transformers.PretrainedConfig):
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
         system_prompt: str = "You are a helpful assistant.",
-        user_prompt: str = "Please transcribe this English audio into text: <audio>",
         encoder_dim: Optional[int] = None,
         llm_dim: Optional[int] = None,
         # Encoder conv layers: list of (padding, kernel_size, stride) tuples
@@ -51,14 +50,12 @@ class ASRConfig(transformers.PretrainedConfig):
         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_warmup_tokens: int = 10,
-        # SpecAugment settings (Whisper defaults)
+        # SpecAugment settings
         use_specaugment: bool = False,
-        mask_time_prob: float = 0.05,  # Probability of masking time steps
-        mask_time_length: int = 10,  # Max length of time mask
-        mask_time_min_masks: int = 2,  # Min number of time masks
-        mask_feature_prob: float = 0.0,  # Probability of masking frequency bins (disabled by default)
-        mask_feature_length: int = 10,  # Max length of frequency mask
-        mask_feature_min_masks: int = 0,  # Min number of frequency masks
+        num_time_masks: int = 2,
+        time_mask_length: int = 10,
+        num_freq_masks: int = 0,
+        freq_mask_length: int = 10,
         # LoRA configuration (for Stage 2 fine-tuning)
         use_lora: bool = False,
         lora_rank: int = 8,  # SALMONN default
@@ -104,7 +101,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.attn_implementation = attn_implementation
         self.model_dtype = model_dtype
         self.system_prompt = system_prompt
-        self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
         # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
@@ -131,12 +127,10 @@ class ASRConfig(transformers.PretrainedConfig):
         self.inference_warmup_tokens = inference_warmup_tokens
         # SpecAugment configuration
         self.use_specaugment = use_specaugment
-        self.mask_time_prob = mask_time_prob
-        self.mask_time_length = mask_time_length
-        self.mask_time_min_masks = mask_time_min_masks
-        self.mask_feature_prob = mask_feature_prob
-        self.mask_feature_length = mask_feature_length
-        self.mask_feature_min_masks = mask_feature_min_masks
+        self.num_time_masks = num_time_masks
+        self.time_mask_length = time_mask_length
+        self.num_freq_masks = num_freq_masks
+        self.freq_mask_length = freq_mask_length
         # LoRA configuration
         self.use_lora = use_lora
         self.lora_rank = lora_rank
@@ -206,6 +200,10 @@ class ASRConfig(transformers.PretrainedConfig):
 
         super().__init__(**kwargs)
 
+        # Point encoder to audio_config so pipeline uses correct feature extractor
+        # (the pipeline looks for config.encoder._name_or_path)
+        self.encoder = self.audio_config
+
         self.auto_map = {
             "AutoConfig": "asr_config.ASRConfig",
             "AutoModel": "asr_modeling.ASRModel",
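A minimal sketch of the new SpecAugment knobs (only the field names come from this commit; it assumes the remaining constructor arguments can stay at their defaults):

```python
from asr_config import ASRConfig

config = ASRConfig(
    use_specaugment=True,
    num_time_masks=2,     # fixed count of time masks per example
    time_mask_length=10,  # each mask up to 10 frames wide
    num_freq_masks=0,     # frequency masking disabled
    freq_mask_length=10,
)

# The commit also aliases config.encoder to config.audio_config so that
# pipeline() can resolve the feature extractor from config.encoder._name_or_path.
assert config.encoder is config.audio_config
```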
asr_modeling.py CHANGED
@@ -24,120 +24,26 @@ except ImportError:
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
 
 
-def _compute_mask_indices(
-    shape: tuple[int, int],
-    mask_prob: float,
-    mask_length: int,
-    min_masks: int = 0,
-    device: torch.device = None,
-) -> torch.Tensor:
-    """Compute random mask spans for SpecAugment.
-
-    Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper.
-
-    Args:
-        shape: (batch_size, sequence_length)
-        mask_prob: Probability for each token to be chosen as start of mask span
-        mask_length: Maximum length of mask span
-        min_masks: Minimum number of masks per sample
-        device: Device to create tensor on
-
-    Returns:
-        Boolean mask tensor of shape (batch_size, sequence_length)
-    """
-    batch_size, sequence_length = shape
-
-    if mask_length < 1:
-        raise ValueError(f"mask_length must be >= 1, got {mask_length}")
-
-    if mask_length > sequence_length:
-        raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}")
-
-    # Compute number of masked spans per sample
-    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
-    num_masked_spans = max(num_masked_spans, min_masks)
-
-    # Clamp to ensure we don't exceed sequence length
-    if num_masked_spans * mask_length > sequence_length:
-        num_masked_spans = sequence_length // mask_length
-
-    if num_masked_spans == 0:
-        return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
-
-    # Uniformly sample span start indices
-    mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
-
-    for i in range(batch_size):
-        # Random start indices for this sample
-        spec_aug_start_indices = torch.randint(
-            0, sequence_length - mask_length + 1, (num_masked_spans,), device=device
-        )
-
-        # Create mask spans
-        for start_idx in spec_aug_start_indices:
-            mask[i, start_idx : start_idx + mask_length] = True
-
-    return mask
+from torchaudio.transforms import FrequencyMasking, TimeMasking
 
 
 def apply_specaugment(
-    input_features: torch.Tensor,
-    mask_time_prob: float = 0.05,
-    mask_time_length: int = 10,
-    mask_time_min_masks: int = 2,
-    mask_feature_prob: float = 0.0,
-    mask_feature_length: int = 10,
-    mask_feature_min_masks: int = 0,
+    x: torch.Tensor,
+    num_time_masks: int = 2,
+    time_mask_length: int = 10,
+    num_freq_masks: int = 0,
+    freq_mask_length: int = 10,
 ) -> torch.Tensor:
-    """Apply SpecAugment to mel spectrogram features.
-
-    Args:
-        input_features: Mel spectrogram of shape (batch, n_mels, time)
-        mask_time_prob: Probability of masking time steps
-        mask_time_length: Max length of time mask
-        mask_time_min_masks: Min number of time masks
-        mask_feature_prob: Probability of masking frequency bins
-        mask_feature_length: Max length of frequency mask
-        mask_feature_min_masks: Min number of frequency masks
-
-    Returns:
-        Augmented mel spectrogram with same shape
-    """
-    batch_size, n_mels, time_steps = input_features.shape
-    device = input_features.device
-
-    # Clone to avoid modifying original
-    augmented = input_features.clone()
-
-    # Time masking (along time dimension)
-    # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
-    if mask_time_prob > 0 or mask_time_min_masks > 0:
-        time_mask = _compute_mask_indices(
-            shape=(batch_size, time_steps),
-            mask_prob=mask_time_prob,
-            mask_length=mask_time_length,
-            min_masks=mask_time_min_masks,
-            device=device,
-        )
-        # Expand to (batch, 1, time) for broadcasting
-        time_mask = time_mask.unsqueeze(1)
-        augmented = augmented.masked_fill(time_mask, 0.0)
-
-    # Frequency masking (along mel dimension)
-    # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
-    if mask_feature_prob > 0 or mask_feature_min_masks > 0:
-        feature_mask = _compute_mask_indices(
-            shape=(batch_size, n_mels),
-            mask_prob=mask_feature_prob,
-            mask_length=mask_feature_length,
-            min_masks=mask_feature_min_masks,
-            device=device,
-        )
-        # Expand to (batch, n_mels, 1) for broadcasting
-        feature_mask = feature_mask.unsqueeze(2)
-        augmented = augmented.masked_fill(feature_mask, 0.0)
-
-    return augmented
+    """Apply SpecAugment using torchaudio. Input shape: (batch, n_mels, time)."""
+    if num_time_masks > 0:
+        tm = TimeMasking(time_mask_param=time_mask_length, iid_masks=True)
+        for _ in range(num_time_masks):
+            x = tm(x)
+    if num_freq_masks > 0:
+        fm = FrequencyMasking(freq_mask_param=freq_mask_length, iid_masks=True)
+        for _ in range(num_freq_masks):
+            x = fm(x)
+    return x
 
 
 class ASRModel(PreTrainedModel, GenerationMixin):
@@ -225,6 +131,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             )
             model.language_model = get_peft_model(model.language_model, lora_config)
 
+            # Clear base_model_name_or_path so PEFT doesn't save a reference
+            # to the base LLM. See _setup_lora for details.
+            model.language_model.peft_config["default"].base_model_name_or_path = None
+
             return model
         finally:
             cls._is_loading_from_pretrained = False
@@ -393,6 +303,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         )
         self.language_model = get_peft_model(self.language_model, lora_config)
 
+        # Clear base_model_name_or_path so PEFT doesn't save a reference to the
+        # base LLM (e.g. Qwen). This prevents pipeline() from redirecting to the
+        # wrong model. The correct path gets set during save_pretrained/push_to_hub.
+        self.language_model.peft_config["default"].base_model_name_or_path = None
+
     def _init_tokenizer(self, config: ASRConfig):
         """Initialize tokenizer with audio token."""
         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
@@ -551,12 +466,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if self.training and getattr(self.config, "use_specaugment", False):
             input_features = apply_specaugment(
                 input_features,
-                mask_time_prob=self.config.mask_time_prob,
-                mask_time_length=self.config.mask_time_length,
-                mask_time_min_masks=self.config.mask_time_min_masks,
-                mask_feature_prob=self.config.mask_feature_prob,
-                mask_feature_length=self.config.mask_feature_length,
-                mask_feature_min_masks=self.config.mask_feature_min_masks,
+                num_time_masks=self.config.num_time_masks,
+                time_mask_length=self.config.time_mask_length,
+                num_freq_masks=self.config.num_freq_masks,
+                freq_mask_length=self.config.freq_mask_length,
             )
 
         # Encode audio -> flattened (total_audio_tokens, hidden_dim)
@@ -841,6 +754,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if hasattr(self.language_model, "peft_config"):
             self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
 
+        # Fix adapter_config.json to point base_model_name_or_path to the repo itself.
+        # This prevents transformers pipeline() from redirecting to the base LLM repo
+        # (like Qwen), which breaks feature extractor loading for multimodal models.
+        # See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
+        adapter_config_path = save_dir / "adapter_config.json"
+        if adapter_config_path.exists():
+            with adapter_config_path.open() as f:
+                adapter_config = json.load(f)
+
+            # Use repo_id from kwargs or config - never use checkpoint directory name
+            repo_id = (
+                kwargs.get("repo_id")
+                or kwargs.get("push_to_hub_model_id")
+                or getattr(self.config, "pretrained_model_path", None)
+            )
+            if repo_id:
+                adapter_config["base_model_name_or_path"] = repo_id
+
+            with adapter_config_path.open("w") as f:
+                json.dump(adapter_config, f, indent=2)
+
         # Add processor auto_map to preprocessor_config.json
         config_path = save_dir / "preprocessor_config.json"
         if config_path.exists():
@@ -866,6 +800,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Copy projectors module
         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
 
+    def push_to_hub(self, repo_id: str, **kwargs) -> str:
+        """Push model to HuggingFace Hub, ensuring adapter_config points to repo."""
+        # Call parent's push_to_hub with repo_id in kwargs so save_pretrained can use it
+        return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
+
     def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
        """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
         pass
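A quick smoke test of the torchaudio-backed `apply_specaugment` added above (a sketch assuming `asr_modeling.py` and its dependencies, including torchaudio, are importable locally):

```python
import torch

from asr_modeling import apply_specaugment

mel = torch.randn(4, 80, 3000)  # Whisper-style batch: (batch, n_mels, time)
out = apply_specaugment(mel, num_time_masks=2, time_mask_length=10)

assert out.shape == mel.shape  # masking zeroes spans; the shape is unchanged
# torchaudio masks with 0.0 by default, so count fully zeroed time steps
print((out == 0).all(dim=1).sum(dim=-1))
```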
 
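To see the `adapter_config.json` rewrite in effect, a hedged sketch: `model` is assumed to be an already-constructed LoRA-enabled ASRModel, and the repo id is a placeholder, not taken from this commit.

```python
import json
from pathlib import Path

# `model` is assumed to exist; the repo id below is a placeholder
model.save_pretrained("checkpoint", repo_id="your-username/your-asr-repo")

adapter_config = json.loads(Path("checkpoint/adapter_config.json").read_text())
# base_model_name_or_path now points at the repo itself, not the base LLM
print(adapter_config["base_model_name_or_path"])  # -> "your-username/your-asr-repo"
```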
asr_pipeline.py CHANGED
@@ -332,6 +332,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         kwargs.pop("min_speakers", None)
         kwargs.pop("max_speakers", None)
         kwargs.pop("hf_token", None)
+        kwargs.pop("user_prompt", None)
 
         return super()._sanitize_parameters(**kwargs)
 
@@ -346,6 +347,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
             return_timestamps: If True, return word-level timestamps using forced alignment
             return_speakers: If True, return speaker labels for each word
+            user_prompt: Custom transcription prompt (default: "Transcribe: ")
             num_speakers: Exact number of speakers (if known, for diarization)
             min_speakers: Minimum number of speakers (for diarization)
             max_speakers: Maximum number of speakers (for diarization)
@@ -359,6 +361,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
         return_timestamps = kwargs.pop("return_timestamps", False)
         return_speakers = kwargs.pop("return_speakers", False)
+        user_prompt = kwargs.pop("user_prompt", None)
         diarization_params = {
             "num_speakers": kwargs.pop("num_speakers", None),
             "min_speakers": kwargs.pop("min_speakers", None),
@@ -369,6 +372,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if return_speakers:
             return_timestamps = True
 
+        # Set custom user prompt if provided
+        original_prompt = None
+        if user_prompt:
+            original_prompt = self.model.TRANSCRIBE_PROMPT
+            self.model.TRANSCRIBE_PROMPT = user_prompt
+
         # Store audio for timestamp alignment and diarization
         if return_timestamps or return_speakers:
             self._current_audio = self._extract_audio(inputs)
@@ -416,6 +425,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         # Clean up
         self._current_audio = None
+        if original_prompt is not None:
+            self.model.TRANSCRIBE_PROMPT = original_prompt
 
         return result
 
@@ -523,6 +534,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self._post_process_prediction(text)
         return {"text": text}
 
+    # Known hallucination patterns that should be deleted entirely
+    HALLUCINATION_PATTERNS = frozenset(
+        [
+            "and gt and gt",
+        ]
+    )
+
     def _post_process_prediction(self, text: str) -> str:
         """Post-process model output to fix common issues."""
         if not text:
@@ -531,22 +549,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         # 1. LOWERCASE
         text = text.lower()
 
-        # 2. COMBINE ACRONYMS
+        # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
+        if text.strip() in self.HALLUCINATION_PATTERNS:
+            return ""
+
+        # 3. COMBINE ACRONYMS
         # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
         text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
 
-        # 3. NORMALIZE CURRENCY
+        # 4. NORMALIZE CURRENCY
         # Convert "eur X" to "X euros" for Whisper normalizer compatibility
         text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
 
-        # 4. TRUNCATE TRAILING REPEATS
+        # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhhh")
+        text = self._truncate_character_repetitions(text)
+
+        # 6. TRUNCATE TRAILING REPEATS (word-level)
         text = self._truncate_trailing_repeats(text)
 
-        # 5. STRIP WHITESPACE
+        # 7. STRIP WHITESPACE
         return re.sub(r"\s+", " ", text).strip()
 
-    def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
-        """Remove trailing repeated n-grams (1-4 words)."""
+    def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
+        """Remove trailing repeated n-grams (1-10 words)."""
         words = text.split()
         if len(words) < 2:
             return text
@@ -566,3 +591,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             break  # Restart from largest n-gram
 
         return " ".join(words)
+
+    def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
+        """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhhh').
+
+        Handles hallucinations where the model outputs the same character many times,
+        like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
+
+        Args:
+            text: Input text to clean
+            max_repeats: Maximum allowed consecutive repetitions of a character
+
+        Returns:
+            Text with character repetitions truncated
+        """
+        if not text:
+            return text
+
+        # Replace any run of a character longer than max_repeats with exactly max_repeats
+        # Pattern: any character followed by itself max_repeats or more times
+        pattern = rf"(.)\1{{{max_repeats},}}"
+        replacement = r"\1" * max_repeats
+        return re.sub(pattern, replacement, text)
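Usage sketch for the new `user_prompt` override (the repo id and audio file are placeholders; `trust_remote_code=True` is needed to load this custom pipeline class):

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/your-asr-repo",  # placeholder repo id
    trust_remote_code=True,
)

# Default prompt
print(asr("sample.wav")["text"])

# One-off override; the pipeline restores TRANSCRIBE_PROMPT afterwards
print(asr("sample.wav", user_prompt="Please transcribe this meeting audio: ")["text"])
```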
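A standalone check of the character-repetition regex (same pattern as `_truncate_character_repetitions` above, using only `re`):

```python
import re

def truncate_character_repetitions(text: str, max_repeats: int = 3) -> str:
    # Collapse any character repeated more than max_repeats times down to max_repeats
    pattern = rf"(.)\1{{{max_repeats},}}"
    return re.sub(pattern, r"\1" * max_repeats, text)

print(truncate_character_repetitions("uhhhhhhhhhhhh"))   # -> "uhhh"
print(truncate_character_repetitions("so sooooo good"))  # -> "so sooo good"
print(truncate_character_repetitions("hello"))           # -> "hello" (short runs untouched)
```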