mazesmazes committed (verified)
Commit 842b6ba · Parent: 2b1609e

Update custom model files, README, and requirements

Files changed (3):
1. README.md (+222, -14)
2. asr_modeling.py (+37, -14)
3. asr_pipeline.py (+12, -0)
README.md CHANGED
@@ -14,21 +14,177 @@ tags:
 - audio
 - qwen
 - glm-asr
+library_name: transformers
 ---
 
 # Tiny Audio
 
 A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio)—a minimal, hackable ASR framework.
 
+## Quick Start
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+result = pipe("audio.wav")
+print(result["text"])
+```
+
+## Usage Examples
+
+### Basic Transcription
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+
+# From file
+result = pipe("audio.wav")
+print(result["text"])
+
+# From URL
+result = pipe("https://example.com/audio.mp3")
+
+# From numpy array (must be 16kHz)
+import numpy as np
+audio = np.random.randn(16000).astype(np.float32)  # 1 second
+result = pipe(audio)
+```
+
+### Batch Processing
+
+```python
+# Process multiple files
+files = ["audio1.wav", "audio2.wav", "audio3.wav"]
+results = pipe(files, batch_size=4)
+for r in results:
+    print(r["text"])
+```
+
+### Word-Level Timestamps
+
+```python
+result = pipe("audio.wav", return_timestamps="word")
+# Returns:
+# {
+#     "text": "hello world",
+#     "chunks": [
+#         {"text": "hello", "timestamp": (0.0, 0.5)},
+#         {"text": "world", "timestamp": (0.6, 1.0)}
+#     ]
+# }
+```
+
+### Streaming Inference
+
+```python
+from tiny_audio import ASRModel, ASRProcessor
+import torch
+
+model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+
+# Load and process audio
+import librosa
+audio, sr = librosa.load("audio.wav", sr=16000)
+inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+# Stream tokens
+for token in model.generate_streaming(inputs["input_features"]):
+    print(token, end="", flush=True)
+```
+
+### Using torch Directly
+
+```python
+from tiny_audio import ASRModel, ASRProcessor
+import torch
+import librosa
+
+# Load model and processor
+model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+
+# Load audio (16kHz)
+audio, sr = librosa.load("audio.wav", sr=16000)
+
+# Process
+inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+# Generate
+with torch.no_grad():
+    output = model.generate(
+        input_features=inputs["input_features"],
+        attention_mask=inputs["attention_mask"],
+        max_new_tokens=256,
+    )
+
+# Decode
+text = processor.batch_decode(output, skip_special_tokens=True)[0]
+print(text)
+```
+
+### GPU Inference
+
+```python
+import torch
+from transformers import pipeline
+
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model="mazesmazes/tiny-audio",
+    trust_remote_code=True,
+    device="cuda",  # or device=0
+)
+```
+
+### Half Precision
+
+```python
+import torch
+from transformers import pipeline
+
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model="mazesmazes/tiny-audio",
+    trust_remote_code=True,
+    torch_dtype=torch.float16,
+    device="cuda",
+)
+```
+
 ## Architecture
 
 ```
 Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
 ```
 
-Only the projector is trained (~12M params). The encoder and decoder remain frozen.
+Only the projector is trained (~12M params). The encoder and decoder remain frozen, leveraging their pretrained knowledge.
+
+| Component | Model | Parameters | Status |
+|-----------|-------|------------|--------|
+| Audio Encoder | GLM-ASR-Nano-2512 | ~600M | Frozen |
+| Projector | 2-layer MLP | ~12M | Trained |
+| Language Model | Qwen3-0.6B | ~600M | Frozen |
+
+### How It Works
+
+1. **Audio Encoder**: GLM-ASR converts 16kHz audio into frame-level embeddings (768-dim)
+2. **Projector**: A 2-layer MLP with frame stacking bridges the audio and text embedding spaces
+3. **Language Model**: Qwen3 generates text autoregressively, conditioned on the projected audio
+
+The projector reduces sequence length via frame stacking: `output_len = (input_len - 5) // 5 + 1`
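+
+As a quick sanity check of that formula (an illustrative sketch; the ~100 encoder frames per second rate is an assumption, not a documented spec):
+
+```python
+# Frame stacking: every 5 encoder frames collapse into one projected
+# embedding, matching output_len = (input_len - 5) // 5 + 1.
+def projected_len(input_len: int, stack: int = 5) -> int:
+    return (input_len - stack) // stack + 1
+
+assert projected_len(1000) == 200  # e.g. ~10 s of audio at ~100 frames/sec
+assert projected_len(5) == 1       # shortest input that yields a single frame
+```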
+
+## Model Specifications
+
+| Specification | Value |
+|---------------|-------|
+| Input | Audio (16kHz mono) |
+| Output | Text transcription |
+| Max Audio Length | ~30 seconds (limited by encoder) |
+| Vocabulary | Qwen3 tokenizer |
+| Languages | English only |
+| Generation | Greedy decoding (num_beams=1, do_sample=False) |
+
-## Training
+## Training Details
 
 | | |
 |---|---|
@@ -36,24 +192,76 @@ Only the projector is trained (~12M params). The encoder and decoder remain froz
 | **Hardware** | Single NVIDIA A40 |
 | **Time** | ~24 hours |
 | **Cost** | ~$12 |
+| **Optimizer** | AdamW |
+| **Learning Rate** | 1e-4 |
+| **Batch Size** | 4 |
+| **Steps** | 50,000 |
 
-## Usage
-
-```python
-from transformers import pipeline
-
-pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
-result = pipe("audio.wav")
-print(result["text"])
-```
-
-## Limitations
-
-- English only
-- 16kHz audio (other sample rates resampled automatically)
-- May degrade on accented speech, noisy audio, or domain-specific terms
+## Limitations
+
+- **English only**: Not trained on other languages
+- **Sample rate**: Expects 16kHz audio (other rates resampled automatically; see the sketch after this list)
+- **Audio length**: Best for clips under 30 seconds
+- **Accuracy**: May degrade on:
+  - Heavily accented speech
+  - Noisy or low-quality audio
+  - Domain-specific terminology
+  - Overlapping speakers
+- **No punctuation**: Output is lowercase without punctuation by default
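+
+If you prefer to resample yourself rather than rely on the automatic handling (an illustrative sketch using librosa, listed under the optional requirements below):
+
+```python
+import librosa
+
+# librosa resamples to the requested rate on load
+audio, sr = librosa.load("audio_44k.wav", sr=16000)
+assert sr == 16000
+```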
 
+## Requirements
+
+```
+transformers>=4.40.0
+torch>=2.0.0
+torchaudio>=2.0.0
+```
+
+Optional for streaming:
+```
+librosa
+soundfile
+```
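+
+A quick way to verify the pinned environment (an illustrative check, not part of the package):
+
+```python
+# Print installed versions to compare against the pins above
+import torch
+import torchaudio
+import transformers
+
+print("transformers:", transformers.__version__)
+print("torch:", torch.__version__)
+print("torchaudio:", torchaudio.__version__)
+```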
 
+## Files
+
+| File | Description |
+|------|-------------|
+| `config.json` | Model configuration |
+| `model.safetensors` | Projector weights (~48MB) |
+| `preprocessor_config.json` | Audio preprocessing config |
+| `tokenizer.json` | Tokenizer |
+| `tokenizer_config.json` | Tokenizer config |
+| `special_tokens_map.json` | Special tokens |
+
+Note: Only the projector weights are stored. The encoder (GLM-ASR) and decoder (Qwen3) are loaded from their respective HuggingFace repos.
+
+## Citation
+
+If you use this model, please cite:
+
+```bibtex
+@misc{tinyaudio2024,
+  author = {Alex Kroman},
+  title = {Tiny Audio: Minimal ASR Training},
+  year = {2024},
+  publisher = {GitHub},
+  url = {https://github.com/alexkroman/tiny-audio}
+}
+```
 
 ## Links
 
-- [Train your own](https://github.com/alexkroman/tiny-audio)
-- [Free 3.5-hour course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md)
+- [GitHub Repository](https://github.com/alexkroman/tiny-audio) - Train your own model
+- [Free 3.5-hour Course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md) - Learn ASR from scratch
+- [Live Demo](https://huggingface.co/spaces/mazesmazes/tiny-audio) - Try it in your browser
+
+## Acknowledgments
+
+- [GLM-ASR](https://huggingface.co/zai-org/GLM-ASR-Nano-2512) for the audio encoder
+- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B) for the language model
+- [LoquaciousSet](https://huggingface.co/datasets/speechbrain/LoquaciousSet) for training data
+
+## License
+
+MIT
asr_modeling.py CHANGED
@@ -89,13 +89,27 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             if adapter_config_file is not None:
                 # Load saved adapter weights using the original repo_id/path
                 # PEFT handles Hub downloads and caching internally
-                from peft import PeftModel
+                from peft import LoraConfig, PeftModel
+
+                # Pre-load and fix the adapter config to avoid the str(None) -> "None"
+                # bug: some PEFT/transformers versions convert null to the string
+                # "None", which causes HF to try loading a model called "None".
+                with open(adapter_config_file) as f:
+                    adapter_config_dict = json.load(f)
+
+                # Fix base_model_name_or_path if it's None/null
+                if adapter_config_dict.get("base_model_name_or_path") is None:
+                    adapter_config_dict["base_model_name_or_path"] = ""
+
+                # Create a LoraConfig from the fixed dict
+                peft_config = LoraConfig(**adapter_config_dict)
 
                 # language_model is bare (not PEFT-wrapped) since we skipped _setup_lora
                 model.language_model = PeftModel.from_pretrained(
                     model.language_model,
                     pretrained_model_name_or_path,  # Use original repo_id, not cache path
                     is_trainable=True,
+                    config=peft_config,  # Use our fixed config
                     **cache_kwargs,
                 )
             else:
@@ -113,8 +127,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
                 model.language_model = get_peft_model(model.language_model, lora_config)
 
                 # Clear base_model_name_or_path so PEFT doesn't save a reference
-                # to the base LLM. See _setup_lora for details.
-                model.language_model.peft_config["default"].base_model_name_or_path = None
+                # to the base LLM. Use an empty string to avoid the str(None) -> "None" bug.
+                model.language_model.peft_config["default"].base_model_name_or_path = ""
 
             return model
         finally:
@@ -295,8 +309,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         # Clear base_model_name_or_path so PEFT doesn't save a reference to the
        # base LLM (e.g. Qwen). This prevents pipeline() from redirecting to the
-        # wrong model. The correct path gets set during save_pretrained/push_to_hub.
-        self.language_model.peft_config["default"].base_model_name_or_path = None
+        # wrong model. Use an empty string to avoid the str(None) -> "None" bug.
+        self.language_model.peft_config["default"].base_model_name_or_path = ""
 
     def _init_tokenizer(self, config: ASRConfig):
         """Initialize tokenizer with audio token."""
@@ -738,23 +752,25 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if hasattr(self.language_model, "peft_config"):
             self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
 
-        # Fix adapter_config.json to point base_model_name_or_path to the repo itself.
-        # This prevents transformers pipeline() from redirecting to the base LLM repo
-        # (like Qwen), which breaks feature extractor loading for multimodal models.
-        # See: https://huggingface.co/ibm-granite/granite-speech-3.3-2b for reference
+        # Clear base_model_name_or_path in adapter_config.json to prevent the HF
+        # pipeline from redirecting to the base LLM repo (like Qwen), which breaks
+        # feature extractor loading for multimodal models. If a repo_id is provided,
+        # use it so the model can be loaded directly from the Hub.
         adapter_config_path = save_dir / "adapter_config.json"
         if adapter_config_path.exists():
             with adapter_config_path.open() as f:
                 adapter_config = json.load(f)
 
-            # Use repo_id from kwargs or config - never use checkpoint directory name
+            # Use repo_id if available, otherwise clear to prevent a redirect.
+            # Use an empty string instead of None to avoid the str(None) -> "None"
+            # bug in some transformers/PEFT versions.
            repo_id = (
                kwargs.get("repo_id")
                or kwargs.get("push_to_hub_model_id")
                or getattr(self.config, "pretrained_model_path", None)
+               or ""  # Use empty string instead of None
            )
-            if repo_id:
-                adapter_config["base_model_name_or_path"] = repo_id
+            adapter_config["base_model_name_or_path"] = repo_id
 
            with adapter_config_path.open("w") as f:
                json.dump(adapter_config, f, indent=2)
@@ -785,8 +801,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
 
     def push_to_hub(self, repo_id: str, **kwargs) -> str:
-        """Push model to HuggingFace Hub, ensuring adapter_config points to repo."""
-        # Call parent's push_to_hub with repo_id in kwargs so save_pretrained can use it
+        """Push model to HuggingFace Hub, ensuring adapter_config points to the repo.
+
+        IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
+        so that transformers pipeline() can load the model correctly. Without this,
+        the pipeline tries to load from "None", which fails.
+        """
+        # Store repo_id in config so save_pretrained can access it
+        self.config.pretrained_model_path = repo_id
+        # Call parent's push_to_hub with repo_id in kwargs
         return super().push_to_hub(repo_id, repo_id=repo_id, **kwargs)
 
     def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
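
The str(None) -> "None" failure mode that the comments above keep referencing is easy to reproduce in isolation; a minimal sketch (illustrative, not code from this repo):

```python
# A config field holding None gets stringified somewhere during
# serialization, so downstream loaders receive the literal string "None"
# and then try to resolve a Hub repo named "None".
base_model_name_or_path = None
assert str(base_model_name_or_path) == "None"

# Storing an empty string instead of None sidesteps the stringification
# entirely, which is what this commit does.
safe = base_model_name_or_path or ""
assert safe == ""
```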
asr_pipeline.py CHANGED
@@ -521,12 +521,19 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         Returns:
             Dict with 'text' key containing transcription
         """
+        # DEBUG: Track which code path we're using
+        import sys
+        print(f"[DEBUG postprocess] type(model_outputs)={type(model_outputs).__name__}", file=sys.stderr)
+
         # Handle list of outputs (from chunking)
         if isinstance(model_outputs, list):
+            print(f"[DEBUG postprocess] list len={len(model_outputs)}", file=sys.stderr)
             model_outputs = model_outputs[0] if model_outputs else {}
 
         tokens = model_outputs.get("tokens")
+        print(f"[DEBUG postprocess] tokens is None: {tokens is None}", file=sys.stderr)
         if tokens is None:
+            print("[DEBUG postprocess] FALLING BACK TO SUPER", file=sys.stderr)
             return super().postprocess(model_outputs, **kwargs)
 
         if torch.is_tensor(tokens):
@@ -537,15 +544,20 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
         text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+        print(f"[DEBUG postprocess] BEFORE truncation: {len(text.split())} words", file=sys.stderr)
         # Post-process prediction
         text = self._post_process_prediction(text)
+        print(f"[DEBUG postprocess] AFTER truncation: {len(text.split())} words", file=sys.stderr)
         return {"text": text}
 
     # Known hallucination patterns that should be deleted entirely
     HALLUCINATION_PATTERNS = frozenset(
         [
             "and gt and gt",
+            "and gt",
+            "gt and gt",
             "n",  # Single character noise
+            "and",  # Common short hallucination
         ]
     )
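
For context on how a frozenset of known junk strings can be applied during post-processing, a minimal sketch (illustrative; `_post_process_prediction` itself is not shown in this diff):

```python
HALLUCINATION_PATTERNS = frozenset(
    ["and gt and gt", "and gt", "gt and gt", "n", "and"]
)

def drop_hallucinations(text: str) -> str:
    # Delete transcriptions that consist entirely of a known junk pattern
    return "" if text.strip().lower() in HALLUCINATION_PATTERNS else text

assert drop_hallucinations("and gt and gt") == ""
assert drop_hallucinations("hello world") == "hello world"
```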