mazesmazes committed on
Commit
10ed507
·
verified ·
1 Parent(s): 8ade3a7

Training in progress - step 25000

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. asr_config.py +2 -4
  3. asr_modeling.py +0 -35
  4. asr_pipeline.py +6 -42
.gitattributes CHANGED
@@ -1,3 +1,4 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_config.py CHANGED
@@ -25,6 +25,7 @@ class ASRConfig(transformers.PretrainedConfig):
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
 
28
  encoder_dim: Optional[int] = None,
29
  llm_dim: Optional[int] = None,
30
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
@@ -103,6 +104,7 @@ class ASRConfig(transformers.PretrainedConfig):
103
  self.attn_implementation = attn_implementation
104
  self.model_dtype = model_dtype
105
  self.system_prompt = system_prompt
 
106
  self.encoder_dim = encoder_dim
107
  self.llm_dim = llm_dim
108
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
@@ -204,10 +206,6 @@ class ASRConfig(transformers.PretrainedConfig):
204
 
205
  super().__init__(**kwargs)
206
 
207
- # Point encoder to audio_config so pipeline uses correct feature extractor
208
- # The pipeline looks for config.encoder._name_or_path for feature extractor
209
- self.encoder = self.audio_config
210
-
211
  self.auto_map = {
212
  "AutoConfig": "asr_config.ASRConfig",
213
  "AutoModel": "asr_modeling.ASRModel",
 
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
28
+ user_prompt: str = "Please transcribe this English audio into text: <audio>",
29
  encoder_dim: Optional[int] = None,
30
  llm_dim: Optional[int] = None,
31
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
 
104
  self.attn_implementation = attn_implementation
105
  self.model_dtype = model_dtype
106
  self.system_prompt = system_prompt
107
+ self.user_prompt = user_prompt
108
  self.encoder_dim = encoder_dim
109
  self.llm_dim = llm_dim
110
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
 
206
 
207
  super().__init__(**kwargs)
208
 
 
 
 
 
209
  self.auto_map = {
210
  "AutoConfig": "asr_config.ASRConfig",
211
  "AutoModel": "asr_modeling.ASRModel",
asr_modeling.py CHANGED
@@ -225,10 +225,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
225
  )
226
  model.language_model = get_peft_model(model.language_model, lora_config)
227
 
228
- # Clear base_model_name_or_path so PEFT doesn't save a reference
229
- # to the base LLM. See _setup_lora for details.
230
- model.language_model.peft_config["default"].base_model_name_or_path = None
231
-
232
  return model
233
  finally:
234
  cls._is_loading_from_pretrained = False
@@ -397,11 +393,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
397
  )
398
  self.language_model = get_peft_model(self.language_model, lora_config)
399
 
400
- # Clear base_model_name_or_path so PEFT doesn't save a reference to the
401
- # base LLM (e.g. Qwen). This prevents pipeline() from redirecting to the
402
- # wrong model. The correct path gets set during save_pretrained/push_to_hub.
403
- self.language_model.peft_config["default"].base_model_name_or_path = None
404
-
405
  def _init_tokenizer(self, config: ASRConfig):
406
  """Initialize tokenizer with audio token."""
407
  self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
@@ -850,27 +841,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
850
  if hasattr(self.language_model, "peft_config"):
851
  self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
852
 
853
- # Fix adapter_config.json to point base_model_name_or_path to the repo itself
854
- # This prevents transformers pipeline() from redirecting to the base LLM repo
855
- # (like Qwen) which breaks feature extractor loading for multimodal models.
856
- # See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
857
- adapter_config_path = save_dir / "adapter_config.json"
858
- if adapter_config_path.exists():
859
- with adapter_config_path.open() as f:
860
- adapter_config = json.load(f)
861
-
862
- # Use repo_id if provided, otherwise use the save directory name
863
- # (which becomes the repo ID when pushed to hub)
864
- repo_id = kwargs.get("repo_id") or kwargs.get("push_to_hub_model_id")
865
- if repo_id:
866
- adapter_config["base_model_name_or_path"] = repo_id
867
- else:
868
- # Fallback: use save_dir name (works when save_dir matches repo structure)
869
- adapter_config["base_model_name_or_path"] = save_dir.name
870
-
871
- with adapter_config_path.open("w") as f:
872
- json.dump(adapter_config, f, indent=2)
873
-
874
  # Add processor auto_map to preprocessor_config.json
875
  config_path = save_dir / "preprocessor_config.json"
876
  if config_path.exists():
@@ -896,11 +866,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
896
  # Copy projectors module
897
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
898
 
899
- def push_to_hub(self, repo_id: str, **kwargs) -> str:
900
- """Push model to HuggingFace Hub, ensuring adapter_config points to repo."""
901
- # Call parent's push_to_hub with repo_id in kwargs so save_pretrained can use it
902
- return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
903
-
904
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
905
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
906
  pass
 
225
  )
226
  model.language_model = get_peft_model(model.language_model, lora_config)
227
 
 
 
 
 
228
  return model
229
  finally:
230
  cls._is_loading_from_pretrained = False
 
393
  )
394
  self.language_model = get_peft_model(self.language_model, lora_config)
395
 
 
 
 
 
 
396
  def _init_tokenizer(self, config: ASRConfig):
397
  """Initialize tokenizer with audio token."""
398
  self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
 
841
  if hasattr(self.language_model, "peft_config"):
842
  self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
844
  # Add processor auto_map to preprocessor_config.json
845
  config_path = save_dir / "preprocessor_config.json"
846
  if config_path.exists():
 
866
  # Copy projectors module
867
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
868
 
 
 
 
 
 
869
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
870
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
871
  pass
asr_pipeline.py CHANGED
@@ -523,13 +523,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
526
- # Known hallucination patterns that should be deleted entirely
527
- HALLUCINATION_PATTERNS = frozenset(
528
- [
529
- "and gt and gt",
530
- ]
531
- )
532
-
533
  def _post_process_prediction(self, text: str) -> str:
534
  """Post-process model output to fix common issues."""
535
  if not text:
@@ -538,29 +531,22 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
538
  # 1. LOWERCASE
539
  text = text.lower()
540
 
541
- # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
542
- if text.strip() in self.HALLUCINATION_PATTERNS:
543
- return ""
544
-
545
- # 3. COMBINE ACRONYMS
546
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
547
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
548
 
549
- # 4. NORMALIZE CURRENCY
550
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
551
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
552
 
553
- # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
554
- text = self._truncate_character_repetitions(text)
555
-
556
- # 6. TRUNCATE TRAILING REPEATS (word-level)
557
  text = self._truncate_trailing_repeats(text)
558
 
559
- # 7. STRIP WHITESPACE
560
  return re.sub(r"\s+", " ", text).strip()
561
 
562
- def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
563
- """Remove trailing repeated n-grams (1-10 words)."""
564
  words = text.split()
565
  if len(words) < 2:
566
  return text
@@ -580,25 +566,3 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
580
  break # Restart from largest n-gram
581
 
582
  return " ".join(words)
583
-
584
- def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
585
- """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
586
-
587
- Handles hallucinations where the model outputs the same character many times,
588
- like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
589
-
590
- Args:
591
- text: Input text to clean
592
- max_repeats: Maximum allowed consecutive repetitions of a character
593
-
594
- Returns:
595
- Text with character repetitions truncated
596
- """
597
- if not text:
598
- return text
599
-
600
- # Replace any character repeated more than max_repeats times with max_repeats
601
- # Pattern: any character followed by itself N+ times
602
- pattern = rf"(.)\1{{{max_repeats},}}"
603
- replacement = r"\1" * max_repeats
604
- return re.sub(pattern, replacement, text)
 
523
  text = self._post_process_prediction(text)
524
  return {"text": text}
525
 
 
 
 
 
 
 
 
526
  def _post_process_prediction(self, text: str) -> str:
527
  """Post-process model output to fix common issues."""
528
  if not text:
 
531
  # 1. LOWERCASE
532
  text = text.lower()
533
 
534
+ # 2. COMBINE ACRONYMS
 
 
 
 
535
  # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
536
  text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
537
 
538
+ # 3. NORMALIZE CURRENCY
539
  # Convert "eur X" to "X euros" for Whisper normalizer compatibility
540
  text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
541
 
542
+ # 4. TRUNCATE TRAILING REPEATS
 
 
 
543
  text = self._truncate_trailing_repeats(text)
544
 
545
+ # 5. STRIP WHITESPACE
546
  return re.sub(r"\s+", " ", text).strip()
547
 
548
+ def _truncate_trailing_repeats(self, text: str, max_ngram: int = 4) -> str:
549
+ """Remove trailing repeated n-grams (1-4 words)."""
550
  words = text.split()
551
  if len(words) < 2:
552
  return text
 
566
  break # Restart from largest n-gram
567
 
568
  return " ".join(words)