Update custom model files, README, and requirements

- .gitattributes +0 -1
- asr_config.py +13 -15
- asr_modeling.py +55 -116
- asr_pipeline.py +53 -6
.gitattributes
CHANGED

@@ -1,4 +1,3 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
-tokenizer.json filter=lfs diff=lfs merge=lfs -text

asr_config.py
CHANGED

@@ -25,7 +25,6 @@ class ASRConfig(transformers.PretrainedConfig):
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
         system_prompt: str = "You are a helpful assistant.",
-        user_prompt: str = "Please transcribe this English audio into text: <audio>",
         encoder_dim: Optional[int] = None,
         llm_dim: Optional[int] = None,
         # Encoder conv layers: list of (padding, kernel_size, stride) tuples
@@ -51,14 +50,12 @@ class ASRConfig(transformers.PretrainedConfig):
         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_warmup_tokens: int = 10,
-        # SpecAugment settings
+        # SpecAugment settings
         use_specaugment: bool = False,
-        mask_time_prob: float = 0.05,
-        mask_time_length: int = 10,
-        mask_time_min_masks: int = 2,
-        mask_feature_prob: float = 0.0,
-        mask_feature_length: int = 10,  # Max length of frequency mask
-        mask_feature_min_masks: int = 0,  # Min number of frequency masks
+        num_time_masks: int = 2,
+        time_mask_length: int = 10,
+        num_freq_masks: int = 0,
+        freq_mask_length: int = 10,
         # LoRA configuration (for Stage 2 fine-tuning)
         use_lora: bool = False,
         lora_rank: int = 8,  # SALMONN default
@@ -104,7 +101,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.attn_implementation = attn_implementation
         self.model_dtype = model_dtype
         self.system_prompt = system_prompt
-        self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
         # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
@@ -131,12 +127,10 @@ class ASRConfig(transformers.PretrainedConfig):
         self.inference_warmup_tokens = inference_warmup_tokens
         # SpecAugment configuration
         self.use_specaugment = use_specaugment
-        self.mask_time_prob = mask_time_prob
-        self.mask_time_length = mask_time_length
-        self.mask_time_min_masks = mask_time_min_masks
-        self.mask_feature_prob = mask_feature_prob
-        self.mask_feature_length = mask_feature_length
-        self.mask_feature_min_masks = mask_feature_min_masks
+        self.num_time_masks = num_time_masks
+        self.time_mask_length = time_mask_length
+        self.num_freq_masks = num_freq_masks
+        self.freq_mask_length = freq_mask_length
         # LoRA configuration
         self.use_lora = use_lora
         self.lora_rank = lora_rank
@@ -206,6 +200,10 @@ class ASRConfig(transformers.PretrainedConfig):
 
         super().__init__(**kwargs)
 
+        # Point encoder to audio_config so pipeline uses correct feature extractor
+        # The pipeline looks for config.encoder._name_or_path for feature extractor
+        self.encoder = self.audio_config
+
         self.auto_map = {
             "AutoConfig": "asr_config.ASRConfig",
             "AutoModel": "asr_modeling.ASRModel",

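Net effect for config users: the per-model `user_prompt` field is gone (it re-surfaces as a per-call pipeline kwarg below), and the Wav2Vec2-style probability knobs are replaced by explicit mask counts. A minimal sketch of enabling the new fields, assuming the remaining `ASRConfig` constructor arguments have workable defaults (the values here are illustrative, not from this commit):

```python
from asr_config import ASRConfig

config = ASRConfig(
    use_specaugment=True,  # forward() only applies masking when model.training is True
    num_time_masks=2,      # apply TimeMasking twice per batch
    time_mask_length=10,   # each mask hides up to 10 frames
    num_freq_masks=0,      # frequency masking stays off by default
    freq_mask_length=10,
)

# After __init__, config.encoder aliases config.audio_config, so transformers'
# pipeline() resolves the feature extractor from the audio encoder repo.
print(config.encoder is config.audio_config)  # True
```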
asr_modeling.py
CHANGED

@@ -24,120 +24,26 @@ except ImportError:
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
 
 
-def _compute_mask_indices(
-    shape: tuple[int, int],
-    mask_prob: float,
-    mask_length: int,
-    min_masks: int = 0,
-    device: torch.device = None,
-) -> torch.Tensor:
-    """Compute random mask spans for SpecAugment.
-
-    Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper.
-
-    Args:
-        shape: (batch_size, sequence_length)
-        mask_prob: Probability for each token to be chosen as start of mask span
-        mask_length: Maximum length of mask span
-        min_masks: Minimum number of masks per sample
-        device: Device to create tensor on
-
-    Returns:
-        Boolean mask tensor of shape (batch_size, sequence_length)
-    """
-    batch_size, sequence_length = shape
-
-    if mask_length < 1:
-        raise ValueError(f"mask_length must be >= 1, got {mask_length}")
-
-    if mask_length > sequence_length:
-        raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}")
-
-    # Compute number of masked spans per sample
-    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
-    num_masked_spans = max(num_masked_spans, min_masks)
-
-    # Clamp to ensure we don't exceed sequence length
-    if num_masked_spans * mask_length > sequence_length:
-        num_masked_spans = sequence_length // mask_length
-
-    if num_masked_spans == 0:
-        return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
-
-    # Uniformly sample span start indices
-    mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device)
-
-    for i in range(batch_size):
-        # Random start indices for this sample
-        spec_aug_start_indices = torch.randint(
-            0, sequence_length - mask_length + 1, (num_masked_spans,), device=device
-        )
-
-        # Create mask spans
-        for start_idx in spec_aug_start_indices:
-            mask[i, start_idx : start_idx + mask_length] = True
-
-    return mask
+from torchaudio.transforms import FrequencyMasking, TimeMasking
 
 
 def apply_specaugment(
-    input_features: torch.Tensor,
-    mask_time_prob: float = 0.05,
-    mask_time_length: int = 10,
-    mask_time_min_masks: int = 2,
-    mask_feature_prob: float = 0.0,
-    mask_feature_length: int = 10,
-    mask_feature_min_masks: int = 0,
+    x: torch.Tensor,
+    num_time_masks: int = 2,
+    time_mask_length: int = 10,
+    num_freq_masks: int = 0,
+    freq_mask_length: int = 10,
 ) -> torch.Tensor:
-    """Apply SpecAugment
-
-    Args:
-        input_features: Mel spectrogram of shape (batch, n_mels, time)
-        mask_time_prob: Probability for each token to be chosen as start of time mask
-        mask_time_length: Max length of time mask
-        mask_time_min_masks: Min number of time masks
-        mask_feature_prob: Probability for frequency masking
-        mask_feature_length: Max length of frequency mask
-        mask_feature_min_masks: Min number of frequency masks
-
-    Returns:
-        Augmented mel spectrogram with same shape
-    """
-    batch_size, n_mels, time_steps = input_features.shape
-    device = input_features.device
-
-    # Clone to avoid modifying original
-    augmented = input_features.clone()
-
-    # Time masking (along time dimension)
-    # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
-    if mask_time_prob > 0 or mask_time_min_masks > 0:
-        time_mask = _compute_mask_indices(
-            shape=(batch_size, time_steps),
-            mask_prob=mask_time_prob,
-            mask_length=mask_time_length,
-            min_masks=mask_time_min_masks,
-            device=device,
-        )
-        # Expand to (batch, 1, time) for broadcasting
-        time_mask = time_mask.unsqueeze(1)
-        augmented = augmented.masked_fill(time_mask, 0.0)
-
-    # Frequency masking (along mel dimension)
-    # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
-    if mask_feature_prob > 0 or mask_feature_min_masks > 0:
-        feature_mask = _compute_mask_indices(
-            shape=(batch_size, n_mels),
-            mask_prob=mask_feature_prob,
-            mask_length=mask_feature_length,
-            min_masks=mask_feature_min_masks,
-            device=device,
-        )
-        # Expand to (batch, n_mels, 1) for broadcasting
-        feature_mask = feature_mask.unsqueeze(2)
-        augmented = augmented.masked_fill(feature_mask, 0.0)
-
-    return augmented
+    """Apply SpecAugment using torchaudio. Input shape: (batch, n_mels, time)."""
+    if num_time_masks > 0:
+        tm = TimeMasking(time_mask_param=time_mask_length, iid_masks=True)
+        for _ in range(num_time_masks):
+            x = tm(x)
+    if num_freq_masks > 0:
+        fm = FrequencyMasking(freq_mask_param=freq_mask_length, iid_masks=True)
+        for _ in range(num_freq_masks):
+            x = fm(x)
+    return x
 
 
 class ASRModel(PreTrainedModel, GenerationMixin):
@@ -225,6 +131,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         )
         model.language_model = get_peft_model(model.language_model, lora_config)
 
+        # Clear base_model_name_or_path so PEFT doesn't save a reference
+        # to the base LLM. See _setup_lora for details.
+        model.language_model.peft_config["default"].base_model_name_or_path = None
+
         return model
     finally:
         cls._is_loading_from_pretrained = False
@@ -393,6 +303,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         )
         self.language_model = get_peft_model(self.language_model, lora_config)
 
+        # Clear base_model_name_or_path so PEFT doesn't save a reference to the
+        # base LLM (e.g. Qwen). This prevents pipeline() from redirecting to the
+        # wrong model. The correct path gets set during save_pretrained/push_to_hub.
+        self.language_model.peft_config["default"].base_model_name_or_path = None
+
     def _init_tokenizer(self, config: ASRConfig):
         """Initialize tokenizer with audio token."""
         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
@@ -551,12 +466,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if self.training and getattr(self.config, "use_specaugment", False):
             input_features = apply_specaugment(
                 input_features,
-                mask_time_prob=self.config.mask_time_prob,
-                mask_time_length=self.config.mask_time_length,
-                mask_time_min_masks=self.config.mask_time_min_masks,
-                mask_feature_prob=self.config.mask_feature_prob,
-                mask_feature_length=self.config.mask_feature_length,
-                mask_feature_min_masks=self.config.mask_feature_min_masks,
+                num_time_masks=self.config.num_time_masks,
+                time_mask_length=self.config.time_mask_length,
+                num_freq_masks=self.config.num_freq_masks,
+                freq_mask_length=self.config.freq_mask_length,
             )
 
         # Encode audio -> flattened (total_audio_tokens, hidden_dim)
@@ -841,6 +754,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if hasattr(self.language_model, "peft_config"):
             self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
 
+            # Fix adapter_config.json to point base_model_name_or_path to the repo itself
+            # This prevents transformers pipeline() from redirecting to the base LLM repo
+            # (like Qwen) which breaks feature extractor loading for multimodal models.
+            # See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
+            adapter_config_path = save_dir / "adapter_config.json"
+            if adapter_config_path.exists():
+                with adapter_config_path.open() as f:
+                    adapter_config = json.load(f)
+
+                # Use repo_id from kwargs or config - never use checkpoint directory name
+                repo_id = (
+                    kwargs.get("repo_id")
+                    or kwargs.get("push_to_hub_model_id")
+                    or getattr(self.config, "pretrained_model_path", None)
+                )
+                if repo_id:
+                    adapter_config["base_model_name_or_path"] = repo_id
+
+                with adapter_config_path.open("w") as f:
+                    json.dump(adapter_config, f, indent=2)
+
         # Add processor auto_map to preprocessor_config.json
         config_path = save_dir / "preprocessor_config.json"
         if config_path.exists():
@@ -866,6 +800,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Copy projectors module
         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
 
+    def push_to_hub(self, repo_id: str, **kwargs) -> str:
+        """Push model to HuggingFace Hub, ensuring adapter_config points to repo."""
+        self.config.pretrained_model_path = repo_id  # let save_pretrained pick up the repo id
+        return super().push_to_hub(repo_id, **kwargs)
+
     def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
        """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
        pass

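The rewrite swaps roughly 95 lines of hand-rolled span sampling for torchaudio's built-in transforms. A standalone sanity check of what the new `apply_specaugment` does (shapes are illustrative, not from the commit): each `TimeMasking` call zeroes one random span of up to `time_mask_param` frames. One caveat worth knowing: in the torchaudio releases I'm aware of, `iid_masks=True` only takes effect on 4D `(batch, channel, freq, time)` input, so on the 3D `(batch, n_mels, time)` batches used here each call masks the same span across the whole batch.

```python
import torch
from torchaudio.transforms import TimeMasking

mel = torch.randn(4, 80, 3000)  # (batch, n_mels, time): Whisper-sized dummy features

# num_time_masks=2 with time_mask_length=10 -> apply TimeMasking twice
tm = TimeMasking(time_mask_param=10, iid_masks=True)
out = tm(tm(mel))

# Frames zeroed across every mel bin are the masked ones (up to 20 per example)
masked_frames = (out == 0).all(dim=1).sum(dim=-1)
print(masked_frames)  # same count for all examples, per the 3D caveat above
```

If per-example masks matter, unsqueezing to `(batch, 1, n_mels, time)` before masking and squeezing afterwards should engage torchaudio's i.i.d. path.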
asr_pipeline.py
CHANGED

@@ -332,6 +332,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         kwargs.pop("min_speakers", None)
         kwargs.pop("max_speakers", None)
         kwargs.pop("hf_token", None)
+        kwargs.pop("user_prompt", None)
 
         return super()._sanitize_parameters(**kwargs)
 
@@ -346,6 +347,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
             return_timestamps: If True, return word-level timestamps using forced alignment
             return_speakers: If True, return speaker labels for each word
+            user_prompt: Custom transcription prompt (default: "Transcribe: ")
             num_speakers: Exact number of speakers (if known, for diarization)
             min_speakers: Minimum number of speakers (for diarization)
             max_speakers: Maximum number of speakers (for diarization)
@@ -359,6 +361,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
         return_timestamps = kwargs.pop("return_timestamps", False)
         return_speakers = kwargs.pop("return_speakers", False)
+        user_prompt = kwargs.pop("user_prompt", None)
         diarization_params = {
             "num_speakers": kwargs.pop("num_speakers", None),
             "min_speakers": kwargs.pop("min_speakers", None),
@@ -369,6 +372,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if return_speakers:
             return_timestamps = True
 
+        # Set custom user prompt if provided
+        original_prompt = None
+        if user_prompt:
+            original_prompt = self.model.TRANSCRIBE_PROMPT
+            self.model.TRANSCRIBE_PROMPT = user_prompt
+
         # Store audio for timestamp alignment and diarization
         if return_timestamps or return_speakers:
             self._current_audio = self._extract_audio(inputs)
@@ -416,6 +425,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         # Clean up
         self._current_audio = None
+        if original_prompt is not None:
+            self.model.TRANSCRIBE_PROMPT = original_prompt
 
         return result
 
@@ -523,6 +534,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self._post_process_prediction(text)
         return {"text": text}
 
+    # Known hallucination patterns that should be deleted entirely
+    HALLUCINATION_PATTERNS = frozenset(
+        [
+            "and gt and gt",
+        ]
+    )
+
     def _post_process_prediction(self, text: str) -> str:
         """Post-process model output to fix common issues."""
         if not text:
@@ -531,22 +549,29 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         # 1. LOWERCASE
         text = text.lower()
 
-        # 2. COMBINE ACRONYMS
+        # 2. CHECK FOR KNOWN HALLUCINATIONS (delete entirely)
+        if text.strip() in self.HALLUCINATION_PATTERNS:
+            return ""
+
+        # 3. COMBINE ACRONYMS
         # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
         text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
 
-        # 3. NORMALIZE CURRENCY
+        # 4. NORMALIZE CURRENCY
         # Convert "eur X" to "X euros" for Whisper normalizer compatibility
         text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
 
-        # 4. TRUNCATE TRAILING REPEATS
+        # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhhh")
+        text = self._truncate_character_repetitions(text)
+
+        # 6. TRUNCATE TRAILING REPEATS (word-level)
         text = self._truncate_trailing_repeats(text)
 
-        # 5. STRIP WHITESPACE
+        # 7. STRIP WHITESPACE
         return re.sub(r"\s+", " ", text).strip()
 
-    def _truncate_trailing_repeats(self, text: str, max_ngram: int =
-        """Remove trailing repeated n-grams (1-
+    def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
+        """Remove trailing repeated n-grams (1-10 words)."""
         words = text.split()
         if len(words) < 2:
             return text
@@ -566,3 +591,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
                 break  # Restart from largest n-gram
 
         return " ".join(words)
+
+    def _truncate_character_repetitions(self, text: str, max_repeats: int = 3) -> str:
+        """Remove excessive character repetitions (e.g., 'uhhhhhh' -> 'uhhh').
+
+        Handles hallucinations where the model outputs the same character many times,
+        like "uhhhhhhhhhhhhhhhhhhhhhhhhh" at the end of a prediction.
+
+        Args:
+            text: Input text to clean
+            max_repeats: Maximum allowed consecutive repetitions of a character
+
+        Returns:
+            Text with character repetitions truncated
+        """
+        if not text:
+            return text
+
+        # Replace any character repeated more than max_repeats times with max_repeats
+        # Pattern: any character followed by itself N+ times
+        pattern = rf"(.)\1{{{max_repeats},}}"
+        replacement = r"\1" * max_repeats
+        return re.sub(pattern, replacement, text)
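Taken together, the pipeline changes make the transcription prompt a per-call option rather than a config field. A hypothetical invocation (the repo id and prompt string are placeholders, not from this commit; `trust_remote_code=True` is what loads this custom `ASRPipeline`):

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-org/your-asr-model",  # placeholder repo id
    trust_remote_code=True,
)

# user_prompt temporarily overrides model.TRANSCRIBE_PROMPT for this call;
# __call__ restores the original prompt during cleanup, and
# _sanitize_parameters pops the kwarg so the base pipeline never sees it.
result = asr("sample.wav", user_prompt="Please transcribe this audio into text: <audio>")
print(result["text"])
```

The new post-processing steps are plain regex and set lookup: `(.)\1{3,}` trims any character repeated four or more times down to three ("uhhhhhh" becomes "uhhh"), while the frozenset check drops outputs that consist solely of a known hallucination string such as "and gt and gt".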