Update custom model files, README, and requirements

- .gitattributes +0 -1
- README.md +42 -106
- asr_config.py +6 -6
- asr_modeling.py +14 -30
- asr_pipeline.py +35 -11
- handler.py +8 -60
.gitattributes
CHANGED
@@ -1,4 +1,3 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
-tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,123 +1,59 @@
 ---
+license: mit
+language:
+- en
+datasets:
+- speechbrain/LoquaciousSet
+base_model:
+- zai-org/GLM-ASR-Nano-2512
+- Qwen/Qwen3-0.6B
+pipeline_tag: automatic-speech-recognition
 tags:
+- asr
+- speech-recognition
+- audio
+- qwen
+- glm-asr
 ---
 
-<!-- This model card has been generated automatically according to the information the Trainer had access to. You
-should probably proofread and complete it, then remove this comment. -->
-
-It achieves the following results on the evaluation set:
-- Loss: 0.2566
-
-### Training hyperparameters
-
-- eval_batch_size: 16
-- seed: 936
-- optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.95) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
-- lr_scheduler_type: cosine
-- lr_scheduler_warmup_steps: 500
-- num_epochs: 1
-
-### Training results
-
-| Training Loss | Epoch | Step | Validation Loss |
-|:-------------:|:-----:|:----:|:---------------:|
-| 0.2888 | 0.0149 | 1000 | 0.2819 |
-| 0.3565 | 0.0298 | 2000 | 0.2919 |
-| 0.3189 | 0.0447 | 3000 | 0.2879 |
-| 0.3274 | 0.0596 | 4000 | 0.2929 |
-| 0.3231 | 0.0745 | 5000 | 0.2870 |
-| 0.3270 | 0.0894 | 6000 | 0.2853 |
-| 0.3486 | 0.1043 | 7000 | 0.2860 |
-| 0.3066 | 0.1192 | 8000 | 0.2865 |
-| 0.3487 | 0.1341 | 9000 | 0.2866 |
-| 0.3307 | 0.1490 | 10000 | 0.2871 |
-| 0.3419 | 0.1639 | 11000 | 0.2852 |
-| 0.3601 | 0.1788 | 12000 | 0.2848 |
-| 0.3156 | 0.1936 | 13000 | 0.2860 |
-| 0.3098 | 0.2085 | 14000 | 0.2830 |
-| 0.3133 | 0.2234 | 15000 | 0.2851 |
-| 0.3269 | 0.2383 | 16000 | 0.2826 |
-| 0.3257 | 0.2532 | 17000 | 0.2822 |
-| 0.3281 | 0.2681 | 18000 | 0.2822 |
-| 0.3941 | 0.2830 | 19000 | 0.2813 |
-| 0.3875 | 0.2979 | 20000 | 0.2854 |
-| 0.3214 | 0.3128 | 21000 | 0.2795 |
-| 0.2914 | 0.3277 | 22000 | 0.2792 |
-| 0.2951 | 0.3426 | 23000 | 0.2805 |
-| 0.3343 | 0.3575 | 24000 | 0.2779 |
-| 0.3252 | 0.3724 | 25000 | 0.2771 |
-| 0.3027 | 0.3873 | 26000 | 0.2768 |
-| 0.3287 | 0.4022 | 27000 | 0.2759 |
-| 0.3208 | 0.4171 | 28000 | 0.2749 |
-| 0.3402 | 0.4320 | 29000 | 0.2730 |
-| 0.2928 | 0.4469 | 30000 | 0.2726 |
-| 0.3085 | 0.4618 | 31000 | 0.2737 |
-| 0.3073 | 0.4767 | 32000 | 0.2705 |
-| 0.3471 | 0.4916 | 33000 | 0.2708 |
-| 0.2945 | 0.5065 | 34000 | 0.2690 |
-| 0.3294 | 0.5214 | 35000 | 0.2696 |
-| 0.3095 | 0.5363 | 36000 | 0.2679 |
-| 0.3152 | 0.5512 | 37000 | 0.2659 |
-| 0.3035 | 0.5660 | 38000 | 0.2674 |
-| 0.3342 | 0.5809 | 39000 | 0.2656 |
-| 0.3242 | 0.5958 | 40000 | 0.2653 |
-| 0.2789 | 0.6107 | 41000 | 0.2643 |
-| 0.3082 | 0.6256 | 42000 | 0.2643 |
-| 0.3174 | 0.6405 | 43000 | 0.2633 |
-| 0.2730 | 0.6554 | 44000 | 0.2628 |
-| 0.2934 | 0.6703 | 45000 | 0.2609 |
-| 0.2944 | 0.6852 | 46000 | 0.2606 |
-| 0.3111 | 0.7001 | 47000 | 0.2614 |
-| 0.3431 | 0.7150 | 48000 | 0.2605 |
-| 0.3226 | 0.7299 | 49000 | 0.2601 |
-| 0.2735 | 0.7448 | 50000 | 0.2591 |
-| 0.3208 | 0.7597 | 51000 | 0.2590 |
-| 0.3208 | 0.7746 | 52000 | 0.2584 |
-| 0.3021 | 0.7895 | 53000 | 0.2578 |
-| 0.2730 | 0.8044 | 54000 | 0.2583 |
-| 0.2938 | 0.8193 | 55000 | 0.2581 |
-| 0.2894 | 0.8342 | 56000 | 0.2574 |
-| 0.2781 | 0.8491 | 57000 | 0.2572 |
-| 0.3003 | 0.8640 | 58000 | 0.2568 |
-| 0.2719 | 0.8789 | 59000 | 0.2568 |
-| 0.2878 | 0.8938 | 60000 | 0.2567 |
-| 0.3058 | 0.9087 | 61000 | 0.2568 |
-| 0.3036 | 0.9236 | 62000 | 0.2568 |
-| 0.3050 | 0.9384 | 63000 | 0.2568 |
-| 0.3244 | 0.9533 | 64000 | 0.2567 |
-| 0.3187 | 0.9682 | 65000 | 0.2566 |
-| 0.3016 | 0.9831 | 66000 | 0.2566 |
-| 0.2697 | 0.9980 | 67000 | 0.2566 |
-
-### Framework versions
-
-- Transformers 5.0.0.dev0
-- Pytorch 2.8.0+cu128
-- Datasets 3.6.0
-- Tokenizers 0.22.1
+# Tiny Audio
+
+A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio)—a minimal, hackable ASR framework.
+
+## Architecture
+
+```
+Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
+```
+
+Only the projector is trained (~12M params). The encoder and decoder remain frozen.
+
+## Training
+
+| | |
+|---|---|
+| **Dataset** | LoquaciousSet (25,000 hours) |
+| **Hardware** | Single NVIDIA A40 |
+| **Time** | ~24 hours |
+| **Cost** | ~$12 |
+
+## Usage
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+result = pipe("audio.wav")
+print(result["text"])
+```
+
+## Limitations
+
+- English only
+- 16kHz audio (other sample rates resampled automatically)
+- May degrade on accented speech, noisy audio, or domain-specific terms
+
+## Links
+
+- [Train your own](https://github.com/alexkroman/tiny-audio)
+- [Free 3.5-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)
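The new README compresses the whole design into one line: only the projector between the frozen GLM-ASR encoder and frozen Qwen3 decoder learns. As a rough illustration of that trainable piece, here is a minimal two-layer MLP projector sketch; the class name, layer sizes, and activation are assumptions for illustration, not the repo's actual `projectors.py` implementation.

```python
import torch
import torch.nn as nn


class MLPProjector(nn.Module):
    """Hypothetical MLP bridging audio-encoder states to the LM embedding space."""

    def __init__(self, audio_dim: int = 1024, text_dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(audio_dim, 2 * text_dim),
            nn.GELU(),
            nn.Linear(2 * text_dim, text_dim),
        )

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        # (batch, audio_frames, audio_dim) -> (batch, audio_frames, text_dim)
        return self.net(audio_embeds)


# Parameter count stays in the low millions at these (assumed) dimensions,
# in the same ballpark as the README's "~12M params" for the real projector.
print(sum(p.numel() for p in MLPProjector().parameters()))
```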
asr_config.py
CHANGED
@@ -7,8 +7,8 @@ class ASRConfig(transformers.PretrainedConfig):
     """Configuration class for the ASR model.
 
     This config combines settings for:
-    - Audio encoder (Whisper)
-    - Text decoder (
+    - Audio encoder (GLM-ASR/Whisper)
+    - Text decoder (Qwen)
     - Projector (MLP, MOSA, MoE, QFormer)
     - Generation parameters
     - Training options (SpecAugment, LoRA)
@@ -19,8 +19,8 @@ class ASRConfig(transformers.PretrainedConfig):
 
     def __init__(
         self,
-        audio_model_id: str = "
-        text_model_id: str = "
+        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
+        text_model_id: str = "Qwen/Qwen3-0.6B",
         attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
@@ -74,8 +74,8 @@ class ASRConfig(transformers.PretrainedConfig):
         """Initialize ASR model configuration.
 
         Args:
-            audio_model_id: HuggingFace model ID for audio encoder (Whisper)
-            text_model_id: HuggingFace model ID for text decoder (
+            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
+            text_model_id: HuggingFace model ID for text decoder (Qwen)
            attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
            model_dtype: Model dtype ("bfloat16", "float16", "float32")
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
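To confirm the retargeted defaults, the config can be instantiated directly. A minimal sketch, assuming `asr_config.py` is importable from a local checkout:

```python
from asr_config import ASRConfig

# A fresh config picks up the new encoder/decoder defaults from this commit
config = ASRConfig()
print(config.audio_model_id)  # zai-org/GLM-ASR-Nano-2512
print(config.text_model_id)   # Qwen/Qwen3-0.6B

# Any documented field can still be overridden, as with any PretrainedConfig
config = ASRConfig(attn_implementation="sdpa", model_dtype="float32")
```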
asr_modeling.py
CHANGED
@@ -24,26 +24,7 @@ except ImportError:
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
 
 
-from torchaudio.transforms import FrequencyMasking, TimeMasking
-
-
-def apply_specaugment(
-    x: torch.Tensor,
-    num_time_masks: int = 2,
-    time_mask_length: int = 10,
-    num_freq_masks: int = 0,
-    freq_mask_length: int = 10,
-) -> torch.Tensor:
-    """Apply SpecAugment using torchaudio. Input shape: (batch, n_mels, time)."""
-    if num_time_masks > 0:
-        tm = TimeMasking(time_mask_param=time_mask_length, iid_masks=True)
-        for _ in range(num_time_masks):
-            x = tm(x)
-    if num_freq_masks > 0:
-        fm = FrequencyMasking(freq_mask_param=freq_mask_length, iid_masks=True)
-        for _ in range(num_freq_masks):
-            x = fm(x)
-    return x
+from torchaudio.transforms import SpecAugment
 
 
 class ASRModel(PreTrainedModel, GenerationMixin):
@@ -192,6 +173,17 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if getattr(config, "freeze_projector", False):
             self.projector.requires_grad_(False)
 
+        # SpecAugment for data augmentation during training
+        if getattr(config, "use_specaugment", False):
+            self.spec_augment = SpecAugment(
+                n_time_masks=config.num_time_masks,
+                time_mask_param=config.time_mask_length,
+                n_freq_masks=config.num_freq_masks,
+                freq_mask_param=config.freq_mask_length,
+            )
+        else:
+            self.spec_augment = None
+
         # For model parallelism
         self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
 
@@ -230,8 +222,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         full_model.language_model = None
         full_model.multi_modal_projector = None
         del full_model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
     else:
         encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
 
@@ -463,14 +453,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         if input_features is not None and input_ids is not None:
             # Apply SpecAugment during training if enabled
-            if self.training and self.config.use_specaugment:
-                input_features = apply_specaugment(
-                    input_features,
-                    num_time_masks=self.config.num_time_masks,
-                    time_mask_length=self.config.time_mask_length,
-                    num_freq_masks=self.config.num_freq_masks,
-                    freq_mask_length=self.config.freq_mask_length,
-                )
+            if self.training and self.spec_augment is not None:
+                input_features = self.spec_augment(input_features)
 
             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
             audio_embeds = self._encode_audio(input_features, audio_attention_mask)
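This change swaps the hand-rolled masking loop for torchaudio's built-in `SpecAugment` module (available in torchaudio ≥ 2.1), constructed once in `__init__` and applied only while `self.training` is true. A standalone sketch of the new call path, with example mask settings rather than the repo's config values:

```python
import torch
from torchaudio.transforms import SpecAugment

# Same knobs the model now forwards from its config (values here are examples)
spec_augment = SpecAugment(
    n_time_masks=2,
    time_mask_param=10,
    n_freq_masks=0,
    freq_mask_param=10,
)

features = torch.randn(4, 128, 3000)  # (batch, n_mels, time), as in the old docstring
masked = spec_augment(features)
print(masked.shape)  # shape is preserved; masked time steps are zeroed
```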
asr_pipeline.py
CHANGED
@@ -14,6 +14,15 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
+def _get_device() -> str:
+    """Get best available device for non-transformers models."""
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+
 class ForcedAligner:
     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
 
@@ -66,7 +75,7 @@ class ForcedAligner:
         import torchaudio
         from torchaudio.functional import forced_align, merge_tokens
 
-        device =
+        device = _get_device()
         model, labels, dictionary = cls.get_instance(device)
 
         # Convert audio to tensor (copy to ensure array is writable)
@@ -179,11 +188,8 @@ class SpeakerDiarizer:
             "pyannote/speaker-diarization-3.1",
         )
 
-        # Move to
-        if torch.cuda.is_available():
-            cls._pipeline.to(torch.device("cuda"))
-        elif torch.backends.mps.is_available():
-            cls._pipeline.to(torch.device("mps"))
+        # Move to best available device
+        cls._pipeline.to(torch.device(_get_device()))
 
         return cls._pipeline
 
@@ -539,9 +545,18 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     HALLUCINATION_PATTERNS = frozenset(
         [
             "and gt and gt",
+            "n",  # Single character noise
         ]
     )
 
+    # Regex patterns for hallucinations (compiled for efficiency)
+    HALLUCINATION_REGEXES = [
+        # Repeating decimal hallucinations (e.g., "12.93242424242424")
+        re.compile(r"\d+\.\d*?(\d{2,})\1{3,}"),
+        # Very long repeated digit sequences (e.g., "242424242424")
+        re.compile(r"(\d{2,})\1{4,}"),
+    ]
+
     def _post_process_prediction(self, text: str) -> str:
         """Post-process model output to fix common issues."""
         if not text:
@@ -554,21 +569,30 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if text.strip() in self.HALLUCINATION_PATTERNS:
             return ""
 
-        # 3.
+        # 3. CHECK FOR REGEX-BASED HALLUCINATIONS
+        for pattern in self.HALLUCINATION_REGEXES:
+            if pattern.search(text):
+                # If hallucination is the entire output, return empty
+                if pattern.fullmatch(text.strip()):
+                    return ""
+                # Otherwise remove the hallucinated portion
+                text = pattern.sub("", text)
+
+        # 4. COMBINE ACRONYMS
         # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
         text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
 
-        #
+        # 5. NORMALIZE CURRENCY
         # Convert "eur X" to "X euros" for Whisper normalizer compatibility
         text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
 
-        #
+        # 6. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
         text = self._truncate_character_repetitions(text)
 
-        #
+        # 7. TRUNCATE TRAILING REPEATS (word-level)
         text = self._truncate_trailing_repeats(text)
 
-        #
+        # 8. STRIP WHITESPACE
         return re.sub(r"\s+", " ", text).strip()
 
     def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
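The two new compiled patterns rely on backreferences: `(\d{2,})\1{4,}` matches a digit group of length ≥ 2 repeated at least five times in a row, which is the signature of the digit-loop hallucinations. This standalone snippet (not part of the repo) replays the examples from the inline comments:

```python
import re

repeating_decimal = re.compile(r"\d+\.\d*?(\d{2,})\1{3,}")
repeated_digits = re.compile(r"(\d{2,})\1{4,}")

# The backreference \1 catches a 2+ digit group repeating itself
assert repeating_decimal.search("12.93242424242424")
assert repeated_digits.fullmatch("242424242424")  # entire output -> dropped
assert not repeated_digits.search("call me at 20 24 20 24")  # spaces break the repeat
```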
handler.py
CHANGED
@@ -2,8 +2,6 @@
 
 from typing import Any, Dict, List, Union
 
-import torch
-
 try:
     # For remote execution, imports are relative
     from .asr_modeling import ASRModel
@@ -35,35 +33,21 @@ class EndpointHandler:
 
         os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
 
-        #
-        # Also beneficial for T4 (Turing) which supports TensorFloat-32
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
-        # Set device and dtype
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # Use float16 for better T4 compatibility (bfloat16 not well supported on T4)
-        # T4 has excellent float16 performance with tensor cores
-        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
-
-        # Enable CUDA optimizations
-        if torch.cuda.is_available():
-            torch.backends.cudnn.benchmark = True
-
-        # Prepare model kwargs for pipeline
+        # Prepare model kwargs - let transformers handle device placement
         model_kwargs = {
-            "
+            "device_map": "auto",
+            "torch_dtype": "auto",
             "low_cpu_mem_usage": True,
         }
-        if
-            model_kwargs["attn_implementation"] = (
-                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
-            )
+        if self._is_flash_attn_available():
+            model_kwargs["attn_implementation"] = "flash_attention_2"
 
         # Load model (this loads the model, tokenizer, and feature extractor)
         self.model = ASRModel.from_pretrained(path, **model_kwargs)
 
+        # Get device from model for pipeline
+        self.device = next(self.model.parameters()).device
+
         # Instantiate custom pipeline - it will get feature_extractor and tokenizer from model
         self.pipe = ASRPipeline(
             model=self.model,
@@ -72,48 +56,12 @@ class EndpointHandler:
             device=self.device,
         )
 
-        # Apply torch.compile if enabled (after model is loaded by pipeline)
-        # Use "default" mode for T4 - better compatibility than "reduce-overhead"
-        # "reduce-overhead" is better for A100+ but can be slower on older GPUs
-        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
-            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
-            self.model = torch.compile(self.model, mode=compile_mode)
-            self.pipe.model = self.model
-
-        # Warmup the model to trigger compilation and optimize kernels
-        if torch.cuda.is_available():
-            self._warmup()
-
     def _is_flash_attn_available(self):
         """Check if flash attention is available."""
         import importlib.util
 
         return importlib.util.find_spec("flash_attn") is not None
 
-    def _warmup(self):
-        """Warmup to trigger model compilation and allocate GPU memory."""
-        try:
-            # Create dummy audio (1 second at config sample rate)
-            sample_rate = self.pipe.model.config.audio_sample_rate
-            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)
-
-            # Run inference to trigger torch.compile and kernel optimization
-            with torch.inference_mode():
-                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
-                _ = self.pipe(
-                    {"raw": dummy_audio, "sampling_rate": sample_rate},
-                    max_new_tokens=warmup_tokens,
-                )
-
-            # Force CUDA synchronization to ensure kernels are compiled
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-                # Clear cache after warmup to free memory
-                torch.cuda.empty_cache()
-
-        except Exception as e:
-            print(f"Warmup skipped due to: {e}")
-
     def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
         """Process an inference request.
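With the TF32 tuning, torch.compile, and warmup paths removed, the handler reduces to load-then-serve, deferring device and dtype selection to `device_map="auto"` and `torch_dtype="auto"`. A hypothetical local smoke test; the `{"inputs": ...}` payload shape follows the common Inference Endpoints convention and is an assumption about this handler's exact schema:

```python
import base64

from handler import EndpointHandler

# Device/dtype are now resolved inside from_pretrained via device_map="auto"
handler = EndpointHandler(path=".")

with open("audio.wav", "rb") as f:
    payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

print(handler(payload))
```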