mazesmazes commited on
Commit
593eaa4
·
verified ·
1 Parent(s): 7d32caa

Update custom model files, README, and requirements

Browse files
Files changed (4) hide show
  1. .gitattributes +0 -1
  2. asr_config.py +4 -2
  3. asr_modeling.py +26 -0
  4. asr_pipeline.py +42 -6
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
- tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
asr_config.py CHANGED
@@ -25,7 +25,6 @@ class ASRConfig(transformers.PretrainedConfig):
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
28
- user_prompt: str = "Please transcribe this English audio into text: <audio>",
29
  encoder_dim: Optional[int] = None,
30
  llm_dim: Optional[int] = None,
31
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
@@ -104,7 +103,6 @@ class ASRConfig(transformers.PretrainedConfig):
104
  self.attn_implementation = attn_implementation
105
  self.model_dtype = model_dtype
106
  self.system_prompt = system_prompt
107
- self.user_prompt = user_prompt
108
  self.encoder_dim = encoder_dim
109
  self.llm_dim = llm_dim
110
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
@@ -206,6 +204,10 @@ class ASRConfig(transformers.PretrainedConfig):
206
 
207
  super().__init__(**kwargs)
208
 
 
 
 
 
209
  self.auto_map = {
210
  "AutoConfig": "asr_config.ASRConfig",
211
  "AutoModel": "asr_modeling.ASRModel",
 
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
 
28
  encoder_dim: Optional[int] = None,
29
  llm_dim: Optional[int] = None,
30
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
 
103
  self.attn_implementation = attn_implementation
104
  self.model_dtype = model_dtype
105
  self.system_prompt = system_prompt
 
106
  self.encoder_dim = encoder_dim
107
  self.llm_dim = llm_dim
108
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
 
204
 
205
  super().__init__(**kwargs)
206
 
207
+ # Point encoder to audio_config so pipeline uses correct feature extractor
208
+ # The pipeline looks for config.encoder._name_or_path for feature extractor
209
+ self.encoder = self.audio_config
210
+
211
  self.auto_map = {
212
  "AutoConfig": "asr_config.ASRConfig",
213
  "AutoModel": "asr_modeling.ASRModel",
asr_modeling.py CHANGED
@@ -841,6 +841,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
841
  if hasattr(self.language_model, "peft_config"):
842
  self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
844
  # Add processor auto_map to preprocessor_config.json
845
  config_path = save_dir / "preprocessor_config.json"
846
  if config_path.exists():
@@ -866,6 +887,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
866
  # Copy projectors module
867
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
868
 
 
 
 
 
 
869
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
870
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
871
  pass
 
841
  if hasattr(self.language_model, "peft_config"):
842
  self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
843
 
844
+ # Fix adapter_config.json to point base_model_name_or_path to the repo itself
845
+ # This prevents transformers pipeline() from redirecting to the base LLM repo
846
+ # (like Qwen) which breaks feature extractor loading for multimodal models.
847
+ # See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
848
+ adapter_config_path = save_dir / "adapter_config.json"
849
+ if adapter_config_path.exists():
850
+ with adapter_config_path.open() as f:
851
+ adapter_config = json.load(f)
852
+
853
+ # Use repo_id if provided, otherwise use the save directory name
854
+ # (which becomes the repo ID when pushed to hub)
855
+ repo_id = kwargs.get("repo_id") or kwargs.get("push_to_hub_model_id")
856
+ if repo_id:
857
+ adapter_config["base_model_name_or_path"] = repo_id
858
+ else:
859
+ # Fallback: use save_dir name (works when save_dir matches repo structure)
860
+ adapter_config["base_model_name_or_path"] = save_dir.name
861
+
862
+ with adapter_config_path.open("w") as f:
863
+ json.dump(adapter_config, f, indent=2)
864
+
865
  # Add processor auto_map to preprocessor_config.json
866
  config_path = save_dir / "preprocessor_config.json"
867
  if config_path.exists():
 
887
  # Copy projectors module
888
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
889
 
890
+ def push_to_hub(self, repo_id: str, **kwargs) -> str:
891
+ """Push model to HuggingFace Hub, ensuring adapter_config points to repo."""
892
+ # Call parent's push_to_hub with repo_id in kwargs so save_pretrained can use it
893
+ return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
894
+
895
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
896
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
897
  pass
asr_pipeline.py CHANGED
@@ -523,6 +523,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
 
 
 
 
 
 
 
526
  def _post_process_prediction(self, text: str) -> str:
527
  """Post-process model output to fix common issues."""
528
  if not text:
@@ -531,22 +538,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
531
  # 1. LOWERCASE
532
  text = text.lower()
533
 
534
- # 2. COMBINE ACRONYMS
 
 
 
 
535
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
536
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
537
 
538
- # 3. NORMALIZE CURRENCY
539
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
540
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
541
 
542
- # 4. TRUNCATE TRAILING REPEATS
 
 
 
543
  text = self._truncate_trailing_repeats(text)
544
 
545
- # 5. STRIP WHITESPACE
546
  return re.sub(r"\s+", " ", text).strip()
547
 
548
- def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
549
- """Remove trailing repeated n-grams (1-4 words)."""
550
  words = text.split()
551
  if len(words) < 2:
552
  return text
@@ -566,3 +580,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
566
  break # Restart from largest n-gram
567
 
568
  return " ".join(words)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
526
+ # Known hallucination patterns that should be deleted entirely
527
+ HALLUCINATION_PATTERNS = frozenset(
528
+ [
529
+ "and gt and gt",
530
+ ]
531
+ )
532
+
533
  def _post_process_prediction(self, text: str) -> str:
534
  """Post-process model output to fix common issues."""
535
  if not text:
 
538
  # 1. LOWERCASE
539
  text = text.lower()
540
 
541
+ # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
542
+ if text.strip() in self.HALLUCINATION_PATTERNS:
543
+ return ""
544
+
545
+ # 3. COMBINE ACRONYMS
546
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
547
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
548
 
549
+ # 4. NORMALIZE CURRENCY
550
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
551
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
552
 
553
+ # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
554
+ text = self._truncate_character_repetitions(text)
555
+
556
+ # 6. TRUNCATE TRAILING REPEATS (word-level)
557
  text = self._truncate_trailing_repeats(text)
558
 
559
+ # 7. STRIP WHITESPACE
560
  return re.sub(r"\s+", " ", text).strip()
561
 
562
+ def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
563
+ """Remove trailing repeated n-grams (1-10 words)."""
564
  words = text.split()
565
  if len(words) < 2:
566
  return text
 
580
  break # Restart from largest n-gram
581
 
582
  return " ".join(words)
583
+
584
+ def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
585
+ """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
586
+
587
+ Handles hallucinations where the model outputs the same character many times,
588
+ like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
589
+
590
+ Args:
591
+ text: Input text to clean
592
+ max_repeats: Maximum allowed consecutive repetitions of a character
593
+
594
+ Returns:
595
+ Text with character repetitions truncated
596
+ """
597
+ if not text:
598
+ return text
599
+
600
+ # Replace any character repeated more than max_repeats times with max_repeats
601
+ # Pattern: any character followed by itself N+ times
602
+ pattern = rf"(.)\1{{{max_repeats},}}"
603
+ replacement = r"\1" * max_repeats
604
+ return re.sub(pattern, replacement, text)