Update custom model files, README, and requirements
- .gitattributes +0 -1
- asr_config.py +4 -2
- asr_modeling.py +26 -0
- asr_pipeline.py +42 -6
.gitattributes
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
tokenizer_config.json -filter -diff -merge text
|
| 4 |
-
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
tokenizer_config.json -filter -diff -merge text
|
|
|
asr_config.py
CHANGED
|
@@ -25,7 +25,6 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 25 |
model_dtype: str = "bfloat16",
|
| 26 |
num_beams: Optional[int] = None,
|
| 27 |
system_prompt: str = "You are a helpful assistant.",
|
| 28 |
-
user_prompt: str = "Please transcribe this English audio into text: <audio>",
|
| 29 |
encoder_dim: Optional[int] = None,
|
| 30 |
llm_dim: Optional[int] = None,
|
| 31 |
# Encoder conv layers: list of (padding, kernel_size, stride) tuples
|
|
@@ -104,7 +103,6 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 104 |
self.attn_implementation = attn_implementation
|
| 105 |
self.model_dtype = model_dtype
|
| 106 |
self.system_prompt = system_prompt
|
| 107 |
-
self.user_prompt = user_prompt
|
| 108 |
self.encoder_dim = encoder_dim
|
| 109 |
self.llm_dim = llm_dim
|
| 110 |
# Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
|
|
@@ -206,6 +204,10 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 206 |
|
| 207 |
super().__init__(**kwargs)
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
self.auto_map = {
|
| 210 |
"AutoConfig": "asr_config.ASRConfig",
|
| 211 |
"AutoModel": "asr_modeling.ASRModel",
|
|
|
|
| 25 |
model_dtype: str = "bfloat16",
|
| 26 |
num_beams: Optional[int] = None,
|
| 27 |
system_prompt: str = "You are a helpful assistant.",
|
|
|
|
| 28 |
encoder_dim: Optional[int] = None,
|
| 29 |
llm_dim: Optional[int] = None,
|
| 30 |
# Encoder conv layers: list of (padding, kernel_size, stride) tuples
|
|
|
|
| 103 |
self.attn_implementation = attn_implementation
|
| 104 |
self.model_dtype = model_dtype
|
| 105 |
self.system_prompt = system_prompt
|
|
|
|
| 106 |
self.encoder_dim = encoder_dim
|
| 107 |
self.llm_dim = llm_dim
|
| 108 |
# Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
|
|
|
|
| 204 |
|
| 205 |
super().__init__(**kwargs)
|
| 206 |
|
| 207 |
+
# Point encoder to audio_config so pipeline uses correct feature extractor
|
| 208 |
+
# The pipeline looks for config.encoder._name_or_path for feature extractor
|
| 209 |
+
self.encoder = self.audio_config
|
| 210 |
+
|
| 211 |
self.auto_map = {
|
| 212 |
"AutoConfig": "asr_config.ASRConfig",
|
| 213 |
"AutoModel": "asr_modeling.ASRModel",
|
asr_modeling.py
CHANGED
|
@@ -841,6 +841,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 841 |
if hasattr(self.language_model, "peft_config"):
|
| 842 |
self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
|
| 843 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
# Add processor auto_map to preprocessor_config.json
|
| 845 |
config_path = save_dir / "preprocessor_config.json"
|
| 846 |
if config_path.exists():
|
|
@@ -866,6 +887,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 866 |
# Copy projectors module
|
| 867 |
shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
|
| 868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
|
| 870 |
"""No-op for model card creation - we use MODEL_CARD.md in repo instead."""
|
| 871 |
pass
|
|
|
|
| 841 |
if hasattr(self.language_model, "peft_config"):
|
| 842 |
self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
|
| 843 |
|
| 844 |
+
# Fix adapter_config.json to point base_model_name_or_path to the repo itself
|
| 845 |
+
# This prevents transformers pipeline() from redirecting to the base LLM repo
|
| 846 |
+
# (like Qwen) which breaks feature extractor loading for multimodal models.
|
| 847 |
+
# See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
|
| 848 |
+
adapter_config_path = save_dir / "adapter_config.json"
|
| 849 |
+
if adapter_config_path.exists():
|
| 850 |
+
with adapter_config_path.open() as f:
|
| 851 |
+
adapter_config = json.load(f)
|
| 852 |
+
|
| 853 |
+
# Use repo_id if provided, otherwise use the save directory name
|
| 854 |
+
# (which becomes the repo ID when pushed to hub)
|
| 855 |
+
repo_id = kwargs.get("repo_id") or kwargs.get("push_to_hub_model_id")
|
| 856 |
+
if repo_id:
|
| 857 |
+
adapter_config["base_model_name_or_path"] = repo_id
|
| 858 |
+
else:
|
| 859 |
+
# Fallback: use save_dir name (works when save_dir matches repo structure)
|
| 860 |
+
adapter_config["base_model_name_or_path"] = save_dir.name
|
| 861 |
+
|
| 862 |
+
with adapter_config_path.open("w") as f:
|
| 863 |
+
json.dump(adapter_config, f, indent=2)
|
| 864 |
+
|
| 865 |
# Add processor auto_map to preprocessor_config.json
|
| 866 |
config_path = save_dir / "preprocessor_config.json"
|
| 867 |
if config_path.exists():
|
|
|
|
| 887 |
# Copy projectors module
|
| 888 |
shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
|
| 889 |
|
| 890 |
+
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Push model to HuggingFace Hub, ensuring adapter_config points to repo.

    Args:
        repo_id: Target Hub repository id (e.g. "org/model-name").
        **kwargs: Forwarded to the parent ``push_to_hub``.

    Returns:
        Whatever the parent ``push_to_hub`` returns (the commit URL).
    """
    # Bug fix: the previous code called
    #   super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
    # which passes `repo_id` both positionally and by keyword and raises
    # TypeError("got multiple values for argument 'repo_id'") on every call.
    # Pass it exactly once, by keyword; setdefault keeps an explicit
    # kwargs["repo_id"] from a caller intact.
    # NOTE(review): the original intent was for save_pretrained to see
    # kwargs["repo_id"] — confirm the parent actually forwards kwargs to
    # save_pretrained; if not, stash repo_id on self before pushing.
    kwargs.setdefault("repo_id", repo_id)
    return super().push_to_hub(**kwargs)
|
| 894 |
+
|
| 895 |
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
|
| 896 |
"""No-op for model card creation - we use MODEL_CARD.md in repo instead."""
|
| 897 |
pass
|
asr_pipeline.py
CHANGED
|
@@ -523,6 +523,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 523 |
text = self._post_process_prediction(text)
|
| 524 |
return {"text": text}
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
def _post_process_prediction(self, text: str) -> str:
|
| 527 |
"""Post-process model output to fix common issues."""
|
| 528 |
if not text:
|
|
@@ -531,22 +538,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 531 |
# 1. LOWERCASE
|
| 532 |
text = text.lower()
|
| 533 |
|
| 534 |
-
# 2.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 536 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 537 |
|
| 538 |
-
#
|
| 539 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 540 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 541 |
|
| 542 |
-
#
|
|
|
|
|
|
|
|
|
|
| 543 |
text = self._truncate_trailing_repeats(text)
|
| 544 |
|
| 545 |
-
#
|
| 546 |
return re.sub(r"\s+", " ", text).strip()
|
| 547 |
|
| 548 |
-
def _truncate_trailing_repeats(self, text: str, max_ngram: int =
|
| 549 |
-
"""Remove trailing repeated n-grams (1-
|
| 550 |
words = text.split()
|
| 551 |
if len(words) < 2:
|
| 552 |
return text
|
|
@@ -566,3 +580,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 566 |
break # Restart from largest n-gram
|
| 567 |
|
| 568 |
return " ".join(words)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
text = self._post_process_prediction(text)
|
| 524 |
return {"text": text}
|
| 525 |
|
| 526 |
+
# Known hallucination patterns that should be deleted entirely.
# frozenset gives O(1) membership tests and guards against accidental mutation.
HALLUCINATION_PATTERNS = frozenset({"and gt and gt"})
|
| 532 |
+
|
| 533 |
def _post_process_prediction(self, text: str) -> str:
|
| 534 |
"""Post-process model output to fix common issues."""
|
| 535 |
if not text:
|
|
|
|
| 538 |
# 1. LOWERCASE
|
| 539 |
text = text.lower()
|
| 540 |
|
| 541 |
+
# 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
|
| 542 |
+
if text.strip() in self.HALLUCINATION_PATTERNS:
|
| 543 |
+
return ""
|
| 544 |
+
|
| 545 |
+
# 3. COMBINE ACRONYMS
|
| 546 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 547 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 548 |
|
| 549 |
+
# 4. NORMALIZE CURRENCY
|
| 550 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 551 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 552 |
|
| 553 |
+
# 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
|
| 554 |
+
text = self._truncate_character_repetitions(text)
|
| 555 |
+
|
| 556 |
+
# 6. TRUNCATE TRAILING REPEATS (word-level)
|
| 557 |
text = self._truncate_trailing_repeats(text)
|
| 558 |
|
| 559 |
+
# 7. STRIP WHITESPACE
|
| 560 |
return re.sub(r"\s+", " ", text).strip()
|
| 561 |
|
| 562 |
+
def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
|
| 563 |
+
"""Remove trailing repeated n-grams (1-10 words)."""
|
| 564 |
words = text.split()
|
| 565 |
if len(words) < 2:
|
| 566 |
return text
|
|
|
|
| 580 |
break # Restart from largest n-gram
|
| 581 |
|
| 582 |
return " ".join(words)
|
| 583 |
+
|
| 584 |
+
def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
|
| 585 |
+
"""Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhh').
|
| 586 |
+
|
| 587 |
+
Handles hallucinations where the model outputs the same character many times,
|
| 588 |
+
like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
|
| 589 |
+
|
| 590 |
+
Args:
|
| 591 |
+
text: Input text to clean
|
| 592 |
+
max_repeats: Maximum allowed consecutive repetitions of a character
|
| 593 |
+
|
| 594 |
+
Returns:
|
| 595 |
+
Text with character repetitions truncated
|
| 596 |
+
"""
|
| 597 |
+
if not text:
|
| 598 |
+
return text
|
| 599 |
+
|
| 600 |
+
# Replace any character repeated more than max_repeats times with max_repeats
|
| 601 |
+
# Pattern: any character followed by itself N+ times
|
| 602 |
+
pattern = rf"(.)\1{{{max_repeats},}}"
|
| 603 |
+
replacement = r"\1" * max_repeats
|
| 604 |
+
return re.sub(pattern, replacement, text)
|