mazesmazes committed
Commit 73cda73 · verified · 1 Parent(s): 370211e

Update custom model files, README, and requirements

Files changed (6):
  1. .gitattributes +0 -1
  2. README.md +42 -106
  3. asr_config.py +6 -6
  4. asr_modeling.py +14 -30
  5. asr_pipeline.py +35 -11
  6. handler.py +8 -60
.gitattributes CHANGED
@@ -1,4 +1,3 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
-tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,123 +1,59 @@
 ---
-library_name: transformers
+license: mit
+language:
+- en
+datasets:
+- speechbrain/LoquaciousSet
+base_model:
+- zai-org/GLM-ASR-Nano-2512
+- Qwen/Qwen3-0.6B
+pipeline_tag: automatic-speech-recognition
 tags:
-- generated_from_trainer
-model-index:
-- name: tiny-audio
-  results: []
+- asr
+- speech-recognition
+- audio
+- qwen
+- glm-asr
 ---
 
-<!-- This model card has been generated automatically according to the information the Trainer had access to. You
-should probably proofread and complete it, then remove this comment. -->
-
-# tiny-audio
-
-This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
-It achieves the following results on the evaluation set:
-- Loss: 0.2566
-
-## Model description
-
-More information needed
-
-## Intended uses & limitations
-
-More information needed
-
-## Training and evaluation data
-
-More information needed
-
-## Training procedure
-
-### Training hyperparameters
-
-The following hyperparameters were used during training:
-- learning_rate: 0.0001
-- train_batch_size: 16
-- eval_batch_size: 16
-- seed: 936
-- optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.95) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
-- lr_scheduler_type: cosine
-- lr_scheduler_warmup_steps: 500
-- num_epochs: 1
-
-### Training results
-
-| Training Loss | Epoch | Step | Validation Loss |
-|:-------------:|:------:|:-----:|:---------------:|
-| 0.2888 | 0.0149 | 1000 | 0.2819 |
-| 0.3565 | 0.0298 | 2000 | 0.2919 |
-| 0.3189 | 0.0447 | 3000 | 0.2879 |
-| 0.3274 | 0.0596 | 4000 | 0.2929 |
-| 0.3231 | 0.0745 | 5000 | 0.2870 |
-| 0.3270 | 0.0894 | 6000 | 0.2853 |
-| 0.3486 | 0.1043 | 7000 | 0.2860 |
-| 0.3066 | 0.1192 | 8000 | 0.2865 |
-| 0.3487 | 0.1341 | 9000 | 0.2866 |
-| 0.3307 | 0.1490 | 10000 | 0.2871 |
-| 0.3419 | 0.1639 | 11000 | 0.2852 |
-| 0.3601 | 0.1788 | 12000 | 0.2848 |
-| 0.3156 | 0.1936 | 13000 | 0.2860 |
-| 0.3098 | 0.2085 | 14000 | 0.2830 |
-| 0.3133 | 0.2234 | 15000 | 0.2851 |
-| 0.3269 | 0.2383 | 16000 | 0.2826 |
-| 0.3257 | 0.2532 | 17000 | 0.2822 |
-| 0.3281 | 0.2681 | 18000 | 0.2822 |
-| 0.3941 | 0.2830 | 19000 | 0.2813 |
-| 0.3875 | 0.2979 | 20000 | 0.2854 |
-| 0.3214 | 0.3128 | 21000 | 0.2795 |
-| 0.2914 | 0.3277 | 22000 | 0.2792 |
-| 0.2951 | 0.3426 | 23000 | 0.2805 |
-| 0.3343 | 0.3575 | 24000 | 0.2779 |
-| 0.3252 | 0.3724 | 25000 | 0.2771 |
-| 0.3027 | 0.3873 | 26000 | 0.2768 |
-| 0.3287 | 0.4022 | 27000 | 0.2759 |
-| 0.3208 | 0.4171 | 28000 | 0.2749 |
-| 0.3402 | 0.4320 | 29000 | 0.2730 |
-| 0.2928 | 0.4469 | 30000 | 0.2726 |
-| 0.3085 | 0.4618 | 31000 | 0.2737 |
-| 0.3073 | 0.4767 | 32000 | 0.2705 |
-| 0.3471 | 0.4916 | 33000 | 0.2708 |
-| 0.2945 | 0.5065 | 34000 | 0.2690 |
-| 0.3294 | 0.5214 | 35000 | 0.2696 |
-| 0.3095 | 0.5363 | 36000 | 0.2679 |
-| 0.3152 | 0.5512 | 37000 | 0.2659 |
-| 0.3035 | 0.5660 | 38000 | 0.2674 |
-| 0.3342 | 0.5809 | 39000 | 0.2656 |
-| 0.3242 | 0.5958 | 40000 | 0.2653 |
-| 0.2789 | 0.6107 | 41000 | 0.2643 |
-| 0.3082 | 0.6256 | 42000 | 0.2643 |
-| 0.3174 | 0.6405 | 43000 | 0.2633 |
-| 0.2730 | 0.6554 | 44000 | 0.2628 |
-| 0.2934 | 0.6703 | 45000 | 0.2609 |
-| 0.2944 | 0.6852 | 46000 | 0.2606 |
-| 0.3111 | 0.7001 | 47000 | 0.2614 |
-| 0.3431 | 0.7150 | 48000 | 0.2605 |
-| 0.3226 | 0.7299 | 49000 | 0.2601 |
-| 0.2735 | 0.7448 | 50000 | 0.2591 |
-| 0.3208 | 0.7597 | 51000 | 0.2590 |
-| 0.3208 | 0.7746 | 52000 | 0.2584 |
-| 0.3021 | 0.7895 | 53000 | 0.2578 |
-| 0.2730 | 0.8044 | 54000 | 0.2583 |
-| 0.2938 | 0.8193 | 55000 | 0.2581 |
-| 0.2894 | 0.8342 | 56000 | 0.2574 |
-| 0.2781 | 0.8491 | 57000 | 0.2572 |
-| 0.3003 | 0.8640 | 58000 | 0.2568 |
-| 0.2719 | 0.8789 | 59000 | 0.2568 |
-| 0.2878 | 0.8938 | 60000 | 0.2567 |
-| 0.3058 | 0.9087 | 61000 | 0.2568 |
-| 0.3036 | 0.9236 | 62000 | 0.2568 |
-| 0.3050 | 0.9384 | 63000 | 0.2568 |
-| 0.3244 | 0.9533 | 64000 | 0.2567 |
-| 0.3187 | 0.9682 | 65000 | 0.2566 |
-| 0.3016 | 0.9831 | 66000 | 0.2566 |
-| 0.2697 | 0.9980 | 67000 | 0.2566 |
-
-
-### Framework versions
-
-- Transformers 5.0.0.dev0
-- Pytorch 2.8.0+cu128
-- Datasets 3.6.0
-- Tokenizers 0.22.1
+# Tiny Audio
+
+A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio), a minimal, hackable ASR framework.
+
+## Architecture
+
+```
+Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
+```
+
+Only the projector is trained (~12M params). The encoder and decoder remain frozen.
+
+## Training
+
+| | |
+|---|---|
+| **Dataset** | LoquaciousSet (25,000 hours) |
+| **Hardware** | Single NVIDIA A40 |
+| **Time** | ~24 hours |
+| **Cost** | ~$12 |
+
+## Usage
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+result = pipe("audio.wav")
+print(result["text"])
+```
+
+## Limitations
+
+- English only
+- 16kHz audio (other sample rates resampled automatically)
+- May degrade on accented speech, noisy audio, or domain-specific terms
+
+## Links
+
+- [Train your own](https://github.com/alexkroman/tiny-audio)
+- [Free 3.5-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)
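
As a companion to the Architecture section in the new README: a minimal sketch of what an MLP projector between a frozen encoder and a frozen decoder can look like. The layer widths and two-layer shape here are illustrative assumptions; the repository's actual implementation lives in its `projectors.py` (the `PROJECTOR_CLASSES` registry imported by `asr_modeling.py`).

```python
import torch
import torch.nn as nn


class MLPProjector(nn.Module):
    """Sketch: map frozen-encoder frames into the decoder's embedding space.

    Dimensions are hypothetical, not the repo's actual values.
    """

    def __init__(self, encoder_dim: int = 768, decoder_dim: int = 1024, hidden_dim: int = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, decoder_dim),
        )

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        # (batch, frames, encoder_dim) -> (batch, frames, decoder_dim)
        return self.net(audio_embeds)


# Only these parameters receive gradients; encoder and decoder stay frozen.
projector = MLPProjector()
print(sum(p.numel() for p in projector.parameters()))
```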
asr_config.py CHANGED
@@ -7,8 +7,8 @@ class ASRConfig(transformers.PretrainedConfig):
     """Configuration class for the ASR model.
 
     This config combines settings for:
-    - Audio encoder (Whisper)
-    - Text decoder (SmolLM/Qwen)
+    - Audio encoder (GLM-ASR/Whisper)
+    - Text decoder (Qwen)
     - Projector (MLP, MOSA, MoE, QFormer)
     - Generation parameters
     - Training options (SpecAugment, LoRA)
@@ -19,8 +19,8 @@ class ASRConfig(transformers.PretrainedConfig):
 
     def __init__(
         self,
-        audio_model_id: str = "openai/whisper-large-v3-turbo",
-        text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
+        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
+        text_model_id: str = "Qwen/Qwen3-0.6B",
         attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
@@ -74,8 +74,8 @@ class ASRConfig(transformers.PretrainedConfig):
         """Initialize ASR model configuration.
 
         Args:
-            audio_model_id: HuggingFace model ID for audio encoder (Whisper)
-            text_model_id: HuggingFace model ID for text decoder (SmolLM/Qwen)
+            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
+            text_model_id: HuggingFace model ID for text decoder (Qwen)
             attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
             model_dtype: Model dtype ("bfloat16", "float16", "float32")
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
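
A quick check of the new defaults, as a minimal sketch assuming `asr_config.py` is importable from the working directory:

```python
from asr_config import ASRConfig

# With no arguments, the config now points at the GLM-ASR encoder
# and the Qwen3 decoder instead of Whisper + SmolLM3.
config = ASRConfig()
assert config.audio_model_id == "zai-org/GLM-ASR-Nano-2512"
assert config.text_model_id == "Qwen/Qwen3-0.6B"

# The old pairing can still be requested explicitly.
legacy = ASRConfig(
    audio_model_id="openai/whisper-large-v3-turbo",
    text_model_id="HuggingFaceTB/SmolLM3-3B",
)
```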
asr_modeling.py CHANGED
@@ -24,26 +24,7 @@ except ImportError:
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
 
 
-from torchaudio.transforms import FrequencyMasking, TimeMasking
-
-
-def apply_specaugment(
-    x: torch.Tensor,
-    num_time_masks: int = 2,
-    time_mask_length: int = 10,
-    num_freq_masks: int = 0,
-    freq_mask_length: int = 10,
-) -> torch.Tensor:
-    """Apply SpecAugment using torchaudio. Input shape: (batch, n_mels, time)."""
-    if num_time_masks > 0:
-        tm = TimeMasking(time_mask_param=time_mask_length, iid_masks=True)
-        for _ in range(num_time_masks):
-            x = tm(x)
-    if num_freq_masks > 0:
-        fm = FrequencyMasking(freq_mask_param=freq_mask_length, iid_masks=True)
-        for _ in range(num_freq_masks):
-            x = fm(x)
-    return x
+from torchaudio.transforms import SpecAugment
 
 
 class ASRModel(PreTrainedModel, GenerationMixin):
@@ -192,6 +173,17 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if getattr(config, "freeze_projector", False):
             self.projector.requires_grad_(False)
 
+        # SpecAugment for data augmentation during training
+        if getattr(config, "use_specaugment", False):
+            self.spec_augment = SpecAugment(
+                n_time_masks=config.num_time_masks,
+                time_mask_param=config.time_mask_length,
+                n_freq_masks=config.num_freq_masks,
+                freq_mask_param=config.freq_mask_length,
+            )
+        else:
+            self.spec_augment = None
+
         # For model parallelism
         self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
 
@@ -230,8 +222,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             full_model.language_model = None
             full_model.multi_modal_projector = None
             del full_model
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
 
@@ -463,14 +453,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         if input_features is not None and input_ids is not None:
             # Apply SpecAugment during training if enabled
-            if self.training and getattr(self.config, "use_specaugment", False):
-                input_features = apply_specaugment(
-                    input_features,
-                    num_time_masks=self.config.num_time_masks,
-                    time_mask_length=self.config.time_mask_length,
-                    num_freq_masks=self.config.num_freq_masks,
-                    freq_mask_length=self.config.freq_mask_length,
-                )
+            if self.training and self.spec_augment is not None:
+                input_features = self.spec_augment(input_features)
 
             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
             audio_embeds = self._encode_audio(input_features, audio_attention_mask)
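
The hand-rolled masking loop is replaced by torchaudio's built-in `SpecAugment` transform (available in recent torchaudio releases), which applies the configured time and frequency masks in one call. A standalone sketch of the equivalent behavior; the mask counts and lengths below mirror the old function's defaults, not necessarily the shipped config values:

```python
import torch
from torchaudio.transforms import SpecAugment

# Same keyword arguments the model now passes from its config.
spec_augment = SpecAugment(
    n_time_masks=2,
    time_mask_param=10,
    n_freq_masks=0,
    freq_mask_param=10,
)

# Input shape matches the old docstring: (batch, n_mels, time).
features = torch.randn(4, 128, 3000)
augmented = spec_augment(features)
assert augmented.shape == features.shape
```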
asr_pipeline.py CHANGED
@@ -14,6 +14,15 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
+def _get_device() -> str:
+    """Get best available device for non-transformers models."""
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+
 class ForcedAligner:
     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
 
@@ -66,7 +75,7 @@ class ForcedAligner:
         import torchaudio
         from torchaudio.functional import forced_align, merge_tokens
 
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = _get_device()
         model, labels, dictionary = cls.get_instance(device)
 
         # Convert audio to tensor (copy to ensure array is writable)
@@ -179,11 +188,8 @@ class SpeakerDiarizer:
             "pyannote/speaker-diarization-3.1",
         )
 
-        # Move to GPU if available
-        if torch.cuda.is_available():
-            cls._pipeline.to(torch.device("cuda"))
-        elif torch.backends.mps.is_available():
-            cls._pipeline.to(torch.device("mps"))
+        # Move to best available device
+        cls._pipeline.to(torch.device(_get_device()))
 
         return cls._pipeline
 
@@ -539,9 +545,18 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     HALLUCINATION_PATTERNS = frozenset(
         [
             "and gt and gt",
+            "n",  # Single character noise
         ]
     )
 
+    # Regex patterns for hallucinations (compiled for efficiency)
+    HALLUCINATION_REGEXES = [
+        # Repeating decimal hallucinations (e.g., "12.93242424242424")
+        re.compile(r"\d+\.\d*?(\d{2,})\1{3,}"),
+        # Very long repeated digit sequences (e.g., "242424242424")
+        re.compile(r"(\d{2,})\1{4,}"),
+    ]
+
     def _post_process_prediction(self, text: str) -> str:
         """Post-process model output to fix common issues."""
         if not text:
@@ -554,21 +569,30 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if text.strip() in self.HALLUCINATION_PATTERNS:
             return ""
 
-        # 3. COMBINE ACRONYMS
+        # 3. CHECK FOR REGEX-BASED HALLUCINATIONS
+        for pattern in self.HALLUCINATION_REGEXES:
+            if pattern.search(text):
+                # If hallucination is the entire output, return empty
+                if pattern.fullmatch(text.strip()):
+                    return ""
+                # Otherwise remove the hallucinated portion
+                text = pattern.sub("", text)
+
+        # 4. COMBINE ACRONYMS
         # Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
         text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
 
-        # 5. NORMALIZE CURRENCY
+        # 5. NORMALIZE CURRENCY
         # Convert "eur X" to "X euros" for Whisper normalizer compatibility
         text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
 
-        # 5. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
+        # 6. TRUNCATE CHARACTER REPETITIONS (e.g., "uhhhhhh" -> "uhh")
         text = self._truncate_character_repetitions(text)
 
-        # 6. TRUNCATE TRAILING REPEATS (word-level)
+        # 7. TRUNCATE TRAILING REPEATS (word-level)
         text = self._truncate_trailing_repeats(text)
 
-        # 7. STRIP WHITESPACE
+        # 8. STRIP WHITESPACE
         return re.sub(r"\s+", " ", text).strip()
 
     def _truncate_trailing_repeats(self, text: str, max_ngram: int = 10) -> str:
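
To see what the two new hallucination regexes catch, a small self-contained check; the example strings are the ones from the comments above:

```python
import re

# The two patterns added in this commit.
repeating_decimal = re.compile(r"\d+\.\d*?(\d{2,})\1{3,}")
repeated_digits = re.compile(r"(\d{2,})\1{4,}")

# Both hallucination shapes from the comments are caught...
assert repeating_decimal.search("12.93242424242424")
assert repeated_digits.search("242424242424")

# ...while ordinary numbers pass through untouched.
assert not repeating_decimal.search("the price is 12.93")
assert not repeated_digits.search("call 555-0142")
```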
handler.py CHANGED
@@ -2,8 +2,6 @@
 
 from typing import Any, Dict, List, Union
 
-import torch
-
 try:
     # For remote execution, imports are relative
     from .asr_modeling import ASRModel
@@ -35,35 +33,21 @@ class EndpointHandler:
 
         os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
 
-        # Enable TF32 for faster matmul on Ampere+ GPUs (A100, etc.)
-        # Also beneficial for T4 (Turing) which supports TensorFloat-32
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
-        # Set device and dtype
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # Use float16 for better T4 compatibility (bfloat16 not well supported on T4)
-        # T4 has excellent float16 performance with tensor cores
-        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
-
-        # Enable CUDA optimizations
-        if torch.cuda.is_available():
-            torch.backends.cudnn.benchmark = True
-
-        # Prepare model kwargs for pipeline
+        # Prepare model kwargs - let transformers handle device placement
         model_kwargs = {
-            "dtype": self.dtype,
+            "device_map": "auto",
+            "torch_dtype": "auto",
             "low_cpu_mem_usage": True,
         }
-        if torch.cuda.is_available():
-            model_kwargs["attn_implementation"] = (
-                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
-            )
+        if self._is_flash_attn_available():
+            model_kwargs["attn_implementation"] = "flash_attention_2"
 
         # Load model (this loads the model, tokenizer, and feature extractor)
         self.model = ASRModel.from_pretrained(path, **model_kwargs)
 
+        # Get device from model for pipeline
+        self.device = next(self.model.parameters()).device
+
         # Instantiate custom pipeline - it will get feature_extractor and tokenizer from model
         self.pipe = ASRPipeline(
             model=self.model,
@@ -72,48 +56,12 @@ class EndpointHandler:
             device=self.device,
         )
 
-        # Apply torch.compile if enabled (after model is loaded by pipeline)
-        # Use "default" mode for T4 - better compatibility than "reduce-overhead"
-        # "reduce-overhead" is better for A100+ but can be slower on older GPUs
-        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
-            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
-            self.model = torch.compile(self.model, mode=compile_mode)
-            self.pipe.model = self.model
-
-        # Warmup the model to trigger compilation and optimize kernels
-        if torch.cuda.is_available():
-            self._warmup()
-
     def _is_flash_attn_available(self):
         """Check if flash attention is available."""
         import importlib.util
 
         return importlib.util.find_spec("flash_attn") is not None
 
-    def _warmup(self):
-        """Warmup to trigger model compilation and allocate GPU memory."""
-        try:
-            # Create dummy audio (1 second at config sample rate)
-            sample_rate = self.pipe.model.config.audio_sample_rate
-            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)
-
-            # Run inference to trigger torch.compile and kernel optimization
-            with torch.inference_mode():
-                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
-                _ = self.pipe(
-                    {"raw": dummy_audio, "sampling_rate": sample_rate},
-                    max_new_tokens=warmup_tokens,
-                )
-
-            # Force CUDA synchronization to ensure kernels are compiled
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-                # Clear cache after warmup to free memory
-                torch.cuda.empty_cache()
-
-        except Exception as e:
-            print(f"Warmup skipped due to: {e}")
-
     def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
         """Process an inference request.