mazesmazes committed
Commit cf9c6ea · verified · 1 Parent(s): 1d1f836

Update custom model files, README, and requirements

Files changed (4)
  1. .gitattributes +0 -1
  2. README.md +263 -83
  3. asr_modeling.py +53 -2
  4. asr_pipeline.py +53 -0
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
- tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
README.md CHANGED
@@ -1,87 +1,267 @@
1
  ---
2
- library_name: transformers
3
  tags:
4
- - generated_from_trainer
5
- model-index:
6
- - name: tiny-audio
7
- results: []
8
  ---
9
 
10
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
11
- should probably proofread and complete it, then remove this comment. -->
12
-
13
- # tiny-audio
14
-
15
- This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
16
- It achieves the following results on the evaluation set:
17
- - Loss: 1.8002
18
-
19
- ## Model description
20
-
21
- More information needed
22
-
23
- ## Intended uses & limitations
24
-
25
- More information needed
26
-
27
- ## Training and evaluation data
28
-
29
- More information needed
30
-
31
- ## Training procedure
32
-
33
- ### Training hyperparameters
34
-
35
- The following hyperparameters were used during training:
36
- - learning_rate: 0.002
37
- - train_batch_size: 14
38
- - eval_batch_size: 14
39
- - seed: 42
40
- - optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
41
- - lr_scheduler_type: polynomial
42
- - lr_scheduler_warmup_steps: 1000
43
- - num_epochs: 4
44
- - label_smoothing_factor: 0.1
45
-
46
- ### Training results
47
-
48
- | Training Loss | Epoch | Step | Validation Loss |
49
- |:-------------:|:------:|:------:|:---------------:|
50
- | 2.1624 | 0.1303 | 10000 | 1.8803 |
51
- | 2.1100 | 0.2607 | 20000 | 1.8542 |
52
- | 2.0734 | 0.3910 | 30000 | 1.8479 |
53
- | 2.1233 | 0.5214 | 40000 | 1.8361 |
54
- | 2.1015 | 0.6517 | 50000 | 1.8280 |
55
- | 2.0839 | 0.7820 | 60000 | 1.8288 |
56
- | 2.0971 | 0.9124 | 70000 | 1.8219 |
57
- | 2.0907 | 1.0427 | 80000 | 1.8218 |
58
- | 2.0599 | 1.1731 | 90000 | 1.8167 |
59
- | 2.0747 | 1.3034 | 100000 | 1.8171 |
60
- | 2.0713 | 1.4337 | 110000 | 1.8152 |
61
- | 2.0866 | 1.5641 | 120000 | 1.8133 |
62
- | 2.0904 | 1.6944 | 130000 | 1.8104 |
63
- | 2.0554 | 1.8248 | 140000 | 1.8092 |
64
- | 2.0968 | 1.9551 | 150000 | 1.8100 |
65
- | 2.0644 | 2.0855 | 160000 | 1.8077 |
66
- | 2.0499 | 2.2158 | 170000 | 1.8054 |
67
- | 2.0570 | 2.3461 | 180000 | 1.8056 |
68
- | 2.0432 | 2.4765 | 190000 | 1.8066 |
69
- | 2.0413 | 2.6068 | 200000 | 1.8050 |
70
- | 2.0373 | 2.7372 | 210000 | 1.8039 |
71
- | 2.0117 | 2.8675 | 220000 | 1.8036 |
72
- | 2.0437 | 2.9978 | 230000 | 1.8036 |
73
- | 2.0454 | 3.1282 | 240000 | 1.8032 |
74
- | 2.0181 | 3.2585 | 250000 | 1.8022 |
75
- | 2.0266 | 3.3889 | 260000 | 1.8015 |
76
- | 2.0451 | 3.5192 | 270000 | 1.8018 |
77
- | 2.0308 | 3.6495 | 280000 | 1.8019 |
78
- | 2.0419 | 3.7799 | 290000 | 1.8005 |
79
- | 2.0172 | 3.9102 | 300000 | 1.8002 |
80
-
81
-
82
- ### Framework versions
83
-
84
- - Transformers 5.0.0.dev0
85
- - Pytorch 2.8.0+cu128
86
- - Datasets 3.6.0
87
- - Tokenizers 0.22.2
 
1
  ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ datasets:
6
+ - speechbrain/LoquaciousSet
7
+ base_model:
8
+ - zai-org/GLM-ASR-Nano-2512
9
+ - Qwen/Qwen3-0.6B
10
+ pipeline_tag: automatic-speech-recognition
11
  tags:
12
+ - asr
13
+ - speech-recognition
14
+ - audio
15
+ - qwen
16
+ - glm-asr
17
+ library_name: transformers
18
  ---
19
 
20
+ # Tiny Audio
21
+
22
+ A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio)—a minimal, hackable ASR framework.
23
+
24
+ ## Quick Start
25
+
26
+ ```python
27
+ from transformers import pipeline
28
+
29
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
30
+ result = pipe("audio.wav")
31
+ print(result["text"])
32
+ ```
33
+
34
+ ## Usage Examples
35
+
36
+ ### Basic Transcription
37
+
38
+ ```python
39
+ from transformers import pipeline
40
+
41
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
42
+
43
+ # From file
44
+ result = pipe("audio.wav")
45
+ print(result["text"])
46
+
47
+ # From URL
48
+ result = pipe("https://example.com/audio.mp3")
49
+
50
+ # From numpy array (must be 16kHz)
51
+ import numpy as np
52
+ audio = np.random.randn(16000).astype(np.float32) # 1 second
53
+ result = pipe(audio)
54
+ ```
55
+
56
+ ### Batch Processing
57
+
58
+ ```python
59
+ # Process multiple files
60
+ files = ["audio1.wav", "audio2.wav", "audio3.wav"]
61
+ results = pipe(files, batch_size=4)
62
+ for r in results:
63
+     print(r["text"])
64
+ ```
65
+
66
+ ### Word-Level Timestamps
67
+
68
+ ```python
69
+ result = pipe("audio.wav", return_timestamps="word")
70
+ # Returns:
71
+ # {
72
+ # "text": "hello world",
73
+ # "chunks": [
74
+ # {"text": "hello", "timestamp": (0.0, 0.5)},
75
+ # {"text": "world", "timestamp": (0.6, 1.0)}
76
+ # ]
77
+ # }
78
+ ```
79
+
80
+ ### Streaming Inference
81
+
82
+ ```python
83
+ from tiny_audio import ASRModel, ASRProcessor
84
+ import torch
85
+
86
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
87
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
88
+
89
+ # Load and process audio
90
+ import librosa
91
+ audio, sr = librosa.load("audio.wav", sr=16000)
92
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
93
+
94
+ # Stream tokens
95
+ for token in model.generate_streaming(inputs["input_features"]):
96
+     print(token, end="", flush=True)
97
+ ```
98
+
99
+ ### Using with torch directly
100
+
101
+ ```python
102
+ from tiny_audio import ASRModel, ASRProcessor
103
+ import torch
104
+ import librosa
105
+
106
+ # Load model and processor
107
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
108
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
109
+
110
+ # Load audio (16kHz)
111
+ audio, sr = librosa.load("audio.wav", sr=16000)
112
+
113
+ # Process
114
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
115
+
116
+ # Generate
117
+ with torch.no_grad():
118
+     output = model.generate(
119
+         input_features=inputs["input_features"],
120
+         attention_mask=inputs["attention_mask"],
121
+         max_new_tokens=256
122
+     )
123
+
124
+ # Decode
125
+ text = processor.batch_decode(output, skip_special_tokens=True)[0]
126
+ print(text)
127
+ ```
128
+
129
+ ### GPU Inference
130
+
131
+ ```python
132
+ import torch
133
+
134
+ pipe = pipeline(
135
+     "automatic-speech-recognition",
136
+     model="mazesmazes/tiny-audio",
137
+     trust_remote_code=True,
138
+     device="cuda" # or device=0
139
+ )
140
+ ```
141
+
142
+ ### Half Precision
143
+
144
+ ```python
145
+ pipe = pipeline(
146
+     "automatic-speech-recognition",
146
+     model="mazesmazes/tiny-audio",
147
+     trust_remote_code=True,
148
+     torch_dtype=torch.float16,
149
+     device="cuda"
151
+ )
152
+ ```
153
+
154
+ ## Architecture
155
+
156
+ ```
157
+ Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
158
+ ```
159
+
160
+ Only the projector is trained (~12M params). The encoder and decoder remain frozen, leveraging their pretrained knowledge.
161
+
162
+ | Component | Model | Parameters | Status |
163
+ |-----------|-------|------------|--------|
164
+ | Audio Encoder | GLM-ASR-Nano-2512 | ~600M | Frozen |
165
+ | Projector | 2-layer MLP | ~12M | Trained |
166
+ | Language Model | Qwen3-0.6B | ~600M | Frozen |
167
+
168
+ ### How It Works
169
+
170
+ 1. **Audio Encoder**: GLM-ASR converts 16kHz audio into frame-level embeddings (768-dim)
171
+ 2. **Projector**: A 2-layer MLP with frame stacking bridges the audio and text embedding spaces
172
+ 3. **Language Model**: Qwen3 generates text autoregressively, conditioned on the projected audio
173
+
174
+ The projector reduces sequence length via frame stacking: `output_len = (input_len - 5) // 5 + 1`
175
+
176
+ ## Model Specifications
177
+
178
+ | Specification | Value |
179
+ |---------------|-------|
180
+ | Input | Audio (16kHz mono) |
181
+ | Output | Text transcription |
182
+ | Max Audio Length | ~30 seconds (limited by encoder) |
183
+ | Vocabulary | Qwen3 tokenizer |
184
+ | Languages | English only |
185
+ | Generation | Greedy decoding (num_beams=1, do_sample=False) |
186
+
187
+ ## Training Details
188
+
189
+ | | |
190
+ |---|---|
191
+ | **Dataset** | LoquaciousSet (25,000 hours) |
192
+ | **Hardware** | Single NVIDIA A40 |
193
+ | **Time** | ~24 hours |
194
+ | **Cost** | ~$12 |
195
+ | **Optimizer** | AdamW |
196
+ | **Learning Rate** | 1e-4 |
197
+ | **Batch Size** | 4 |
198
+ | **Steps** | 50,000 |
199
+
200
+ ## Limitations
201
+
202
+ - **English only**: Not trained on other languages
203
+ - **Sample rate**: Expects 16kHz audio (other rates resampled automatically)
204
+ - **Audio length**: Best for clips under 30 seconds
205
+ - **Accuracy**: May degrade on:
206
+   - Heavily accented speech
207
+   - Noisy or low-quality audio
208
+   - Domain-specific terminology
209
+   - Overlapping speakers
210
+ - **No punctuation**: Output is lowercase without punctuation by default
211
+
212
+ ## Requirements
213
+
214
+ ```
215
+ transformers>=4.40.0
216
+ torch>=2.0.0
217
+ torchaudio>=2.0.0
218
+ ```
219
+
220
+ Optional for streaming:
221
+ ```
222
+ librosa
223
+ soundfile
224
+ ```
225
+
226
+ ## Files
227
+
228
+ | File | Description |
229
+ |------|-------------|
230
+ | `config.json` | Model configuration |
231
+ | `model.safetensors` | Projector weights (~48MB) |
232
+ | `preprocessor_config.json` | Audio preprocessing config |
233
+ | `tokenizer.json` | Tokenizer |
234
+ | `tokenizer_config.json` | Tokenizer config |
235
+ | `special_tokens_map.json` | Special tokens |
236
+
237
+ Note: Only the projector weights are stored. The encoder (GLM-ASR) and decoder (Qwen3) are loaded from their respective HuggingFace repos.
238
+
239
+ ## Citation
240
+
241
+ If you use this model, please cite:
242
+
243
+ ```bibtex
244
+ @misc{tinyaudio2024,
245
+ author = {Alex Kroman},
246
+ title = {Tiny Audio: Minimal ASR Training},
247
+ year = {2024},
248
+ publisher = {GitHub},
249
+ url = {https://github.com/alexkroman/tiny-audio}
250
+ }
251
+ ```
252
+
253
+ ## Links
254
+
255
+ - [GitHub Repository](https://github.com/alexkroman/tiny-audio) - Train your own model
256
+ - [Free 3.5-hour Course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md) - Learn ASR from scratch
257
+ - [Live Demo](https://huggingface.co/spaces/mazesmazes/tiny-audio) - Try it in your browser
258
+
259
+ ## Acknowledgments
260
+
261
+ - [GLM-ASR](https://huggingface.co/zai-org/GLM-ASR-Nano-2512) for the audio encoder
262
+ - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B) for the language model
263
+ - [LoquaciousSet](https://huggingface.co/datasets/speechbrain/LoquaciousSet) for training data
264
+
265
+ ## License
266
+
267
+ MIT
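
Review note on the README's "How It Works" section: the frame-stacking formula `output_len = (input_len - 5) // 5 + 1` can be sanity-checked in a few lines. This is a minimal sketch, assuming the stack size and stride are both 5 as the formula implies; `stacked_length`, its `k` parameter, and the clamp at zero are hypothetical and do not come from the repository.

```python
def stacked_length(input_len: int, k: int = 5) -> int:
    """Projector output frames for `input_len` encoder frames.

    Mirrors the README formula output_len = (input_len - k) // k + 1.
    The clamp at zero for inputs shorter than one window is an assumption.
    """
    return max(0, (input_len - k) // k + 1)


# Print a few input lengths next to the resulting output lengths.
for n in (4, 5, 9, 10, 1500):
    print(n, "->", stacked_length(n))
```

Under this formula, trailing frames that do not fill a complete group of five are dropped, so 9 input frames yield a single output frame.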
asr_modeling.py CHANGED
@@ -703,6 +703,57 @@ class ASRModel(PreTrainedModel, GenerationMixin):
703
 
704
  thread.join()
705
 
706
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
707
  """Save model, tokenizer, and processor."""
708
  import shutil
@@ -796,8 +847,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
796
  """
797
  # Store repo_id in config so save_pretrained can access it
798
  self.config.pretrained_model_path = repo_id
799
- # Call parent's push_to_hub with repo_id in kwargs
800
- return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
801
 
802
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
803
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
 
703
 
704
  thread.join()
705
 
706
+     @torch.no_grad()
707
+     def generate_text_only(
708
+         self,
709
+         messages: list[dict[str, str]],
710
+         max_new_tokens: int = 256,
711
+         **generate_kwargs,
712
+     ) -> str:
713
+         """Generate text using only the LLM (no audio encoding).
714
+
715
+         Used for SIFT-style response generation from metadata prompts.
716
+
717
+         Args:
718
+             messages: List of chat messages [{"role": "user", "content": "..."}]
719
+             max_new_tokens: Maximum tokens to generate
720
+             **generate_kwargs: Additional generation arguments
721
+
722
+         Returns:
723
+             Generated text response
724
+         """
725
+         device = next(self.language_model.parameters()).device
726
+
727
+         # Apply chat template
728
+         input_ids = self.tokenizer.apply_chat_template(
729
+             messages,
730
+             tokenize=True,
731
+             add_generation_prompt=True,
732
+             return_tensors="pt",
733
+             enable_thinking=False,
734
+         ).to(device)
735
+
736
+         if input_ids.dim() == 1:
737
+             input_ids = input_ids.unsqueeze(0)
738
+
739
+         attention_mask = torch.ones_like(input_ids)
740
+
741
+         # Generate using language model directly
742
+         output = self.language_model.generate(
743
+             input_ids=input_ids,
744
+             attention_mask=attention_mask,
745
+             max_new_tokens=max_new_tokens,
746
+             do_sample=False,
747
+             pad_token_id=self.tokenizer.pad_token_id,
748
+             eos_token_id=self.tokenizer.eos_token_id,
749
+             **generate_kwargs,
750
+         )
751
+
752
+         # Decode only the new tokens
753
+         new_tokens = output[0, input_ids.shape[1] :]
754
+         response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
755
+         return response.strip()
756
+
757
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
758
  """Save model, tokenizer, and processor."""
759
  import shutil
 
847
  """
848
  # Store repo_id in config so save_pretrained can access it
849
  self.config.pretrained_model_path = repo_id
850
+ # Call parent's push_to_hub
851
+ return super().push_to_hub(repo_id, **kwargs)
852
 
853
  def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
854
  """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
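
Review note on `generate_text_only`: the "decode only the new tokens" step relies on `generate` returning the prompt ids followed by the continuation, which is how decoder-only models behave in transformers. The sketch below shows the same slicing with plain Python lists standing in for tensors; `new_token_ids` and the token ids are invented for illustration and are not part of the commit.

```python
def new_token_ids(output_ids: list[int], prompt_len: int) -> list[int]:
    """Return only the ids produced after the prompt (hypothetical helper)."""
    return output_ids[prompt_len:]


prompt = [101, 7592, 2088]             # made-up prompt token ids
continuation = [2023, 2003, 102]       # made-up generated token ids
output = prompt + continuation         # shape of what generate() hands back

print(new_token_ids(output, len(prompt)))
```

Slicing by the prompt length keeps the chat-template tokens out of the decoded response, which is what `output[0, input_ids.shape[1] :]` does on the tensor.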
asr_pipeline.py CHANGED
@@ -418,4 +418,57 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
418
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
419
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
420
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
 
 
421
  return {"text": text}
 
418
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
419
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
420
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
421
+ # Truncate repetitions at end of text
422
+ text = _truncate_repetitions(text)
423
  return {"text": text}
424
+
425
+
426
+ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
427
+     """Truncate repeated words/phrases/characters at end of text.
428
+
429
+     Detects patterns like:
430
+     - Repeated words: "the the the the" -> "the"
431
+     - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
432
+     - Repeated characters: "444444" -> "4"
433
+
434
+     Args:
435
+         text: Input text to process
436
+         min_repeats: Minimum repetitions to trigger truncation (default 3)
437
+
438
+     Returns:
439
+         Text with trailing repetitions removed
440
+     """
441
+     if not text:
442
+         return text
443
+
444
+     # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
445
+     char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
446
+     text = char_pattern.sub(r"\1", text)
447
+
448
+     # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
449
+     word_pattern = re.compile(r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE)
450
+     while word_pattern.search(text):
451
+         text = word_pattern.sub(r"\1", text)
452
+
453
+     # 3. Truncate repeated phrases (2-20 words) at end
454
+     # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
455
+     words = text.split()
456
+     if len(words) >= min_repeats * 2:
457
+         # Try phrase lengths from 2 to 20 words
458
+         for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
459
+             # Check if the last phrase_len words repeat
460
+             phrase = " ".join(words[-phrase_len:])
461
+             # Build pattern to match repeated phrases at end
462
+             phrase_escaped = re.escape(phrase)
463
+             phrase_pattern = re.compile(
464
+                 r"(^|.*?\s)(" + phrase_escaped + r")(?:\s+" + phrase_escaped + r"){" + str(min_repeats - 1) + r",}\s*$",
465
+                 re.IGNORECASE,
466
+             )
467
+             match = phrase_pattern.match(text)
468
+             if match:
469
+                 # Keep prefix + one instance of the phrase
470
+                 text = (match.group(1) + match.group(2)).strip()
471
+                 words = text.split()
472
+                 break
473
+
474
+     return text
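
Review note on the post-processing chain in `asr_pipeline.py`: the two regex steps (stripping `<think>` spans, then collapsing a word repeated at the end) can be tried in isolation. This is a simplified sketch rather than the committed function: it covers only the word-level case, and `strip_think_tags` and `truncate_word_repeats` are hypothetical names chosen for this illustration.

```python
import re


def strip_think_tags(text: str) -> str:
    """Remove <think>...</think> spans, using the same regex as the diff."""
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()


def truncate_word_repeats(text: str, min_repeats: int = 3) -> str:
    """Collapse a word repeated at least `min_repeats` times at the end of text."""
    # One captured word followed by (min_repeats - 1) or more further copies,
    # anchored at the end of the string.
    pattern = re.compile(
        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$",
        re.IGNORECASE,
    )
    return pattern.sub(r"\1", text)


raw = "<think>\nhmm\n</think>\n\nhello world bye bye bye"
clean = truncate_word_repeats(strip_think_tags(raw))
print(clean)  # hello world bye
```

Because the pattern is anchored with `$`, repetitions in the middle of a transcript are left alone, and a word that appears only twice at the end does not reach the default threshold of three.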