fix: critical T5 conditioner key sanitization and metadata
- extract_t5.py: sanitize .layer.0. → .layer_0. and .layer.1. → .layer_1.
for MLX ModuleParameters.unflattened() compatibility. Without this fix,
all 24 T5 transformer block weights remain at random initialization.
Also adds automatic T5 variant detection (small/base/large).
- config.json: correct T5 metadata from t5-small to t5-large
(d_model=1024, 24 layers, 16 heads, d_ff=4096).
- README.md: fix text encoder reference, add T5 extraction section,
MLX key sanitization note, T5 unscaled attention note, update Swift usage.
- verify_t5.py: new file, a Python MLX reference implementation for
verifying T5 encoder output against the Swift implementation.
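
The key-sanitization problem described above can be illustrated with a small pure-Python sketch. The `unflatten` function here is a simplified stand-in written for this example; it only models the "split on every dot" behaviour the commit message attributes to `ModuleParameters.unflattened()` and is not the MLX implementation.

```python
def unflatten(flat):
    """Build a nested dict by splitting every key on '.' (a simplified model)."""
    tree = {}
    for key, value in flat.items():
        node = tree
        parts = key.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return tree


def sanitize(key):
    """Same renaming as the commit: '.layer.N.' becomes '.layer_N.' for N in {0, 1}."""
    return key.replace(".layer.0.", ".layer_0.").replace(".layer.1.", ".layer_1.")


hf_key = "encoder.block.0.layer.0.SelfAttention.q.weight"

raw_tree = unflatten({hf_key: "W"})
clean_tree = unflatten({sanitize(hf_key): "W"})

# Unsanitized: 'layer' and '0' become two separate nesting levels.
print(list(raw_tree["encoder"]["block"]["0"].keys()))    # ['layer']
# Sanitized: 'layer_0' is a single attribute-like name.
print(list(clean_tree["encoder"]["block"]["0"].keys()))  # ['layer_0']
```

A loader that expects a property named `layer_0` would find nothing under the unsanitized tree, which is consistent with the weights staying at their initial values.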
- README.md +25 -5
- config.json +9 -9
- extract_t5.py +271 -0
- verify_t5.py +274 -0
README.md
CHANGED
|
@@ -21,15 +21,28 @@ This is the MLX-native port of [facebook/audiogen-medium](https://huggingface.co
|
|
| 21 |
- **Parameters**: ~1.5B (LM) + EnCodec compression model
|
| 22 |
- **Sampling rate**: 16 kHz
|
| 23 |
- **Frame rate**: 50 Hz (4 codebooks, delayed pattern)
|
| 24 |
-
- **Text encoder**: T5-
|
| 25 |
- **Max duration**: 10 seconds (configurable)
|
| 26 |
|
| 27 |
## Files
|
| 28 |
|
| 29 |
-
- `config.json` - Model configuration
|
| 30 |
- `model.safetensors` - LM + EnCodec weights
|
| 31 |
- `model.safetensors.index.json` - Weight index (for sharded variants)
|
| 32 |
-
|
| 33 |
|
| 34 |
## Usage (Swift/MLX)
|
| 35 |
|
|
@@ -40,15 +53,22 @@ let model = try await AudioGenModel.fromPretrained(
|
|
| 40 |
modelFolder: modelURL,
|
| 41 |
t5Folder: t5URL
|
| 42 |
)
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
duration: 5.0,
|
| 46 |
cfgCoef: 3.0,
|
| 47 |
temperature: 1.0,
|
| 48 |
topK: 250
|
| 49 |
)
|
|
|
|
|
|
|
| 50 |
```
|
| 51 |
|
| 52 |
## License
|
| 53 |
|
| 54 |
This model is published under the [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license (non-commercial use only), following the original [AudioGen license](https://huggingface.co/facebook/audiogen-medium).
|
|
|
|
| 21 |
- **Parameters**: ~1.5B (LM) + EnCodec compression model
|
| 22 |
- **Sampling rate**: 16 kHz
|
| 23 |
- **Frame rate**: 50 Hz (4 codebooks, delayed pattern)
|
| 24 |
+
- **Text encoder**: T5-large (d_model=1024, 24 layers, 16 heads)
|
| 25 |
- **Max duration**: 10 seconds (configurable)
|
| 26 |
|
| 27 |
## Files
|
| 28 |
|
| 29 |
+
- `config.json` - Model configuration (includes `t5_model_name` reference)
|
| 30 |
- `model.safetensors` - LM + EnCodec weights
|
| 31 |
- `model.safetensors.index.json` - Weight index (for sharded variants)
|
| 32 |
+
|
| 33 |
+
### T5 Conditioner (extracted separately)
|
| 34 |
+
|
| 35 |
+
The T5-large text encoder weights are not included in this repository. Use `extract_t5.py` to extract them from the original `facebook/audiogen-medium` checkpoint:
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
python extract_t5.py --output /path/to/audiogen-mlx/t5
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
This produces a `t5/` directory with `config.json`, `model.safetensors`, and tokenizer files.
|
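
As a quick sanity check after extraction, one might confirm that the output directory contains the files listed in the `extract_t5.py` docstring. This is a minimal sketch; `missing_t5_files` is a hypothetical helper written for this example and is not part of the repository.

```python
from pathlib import Path
import tempfile

# File names taken from the "Output" section of the extract_t5.py docstring.
EXPECTED_FILES = [
    "config.json",
    "model.safetensors",
    "tokenizer.json",
    "tokenizer_config.json",
]


def missing_t5_files(t5_dir):
    """Return the expected file names that are absent from t5_dir."""
    t5_dir = Path(t5_dir)
    return [name for name in EXPECTED_FILES if not (t5_dir / name).is_file()]


# Demonstration on a throwaway directory that only has config.json.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").write_text("{}")
    demo_missing = missing_t5_files(tmp)

print(demo_missing)  # ['model.safetensors', 'tokenizer.json', 'tokenizer_config.json']
```

An empty list means the directory has the files the docstring promises; it says nothing about whether the weights themselves are correct.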
| 42 |
+
|
| 43 |
+
> **Note**: The T5 safetensors keys use MLX-compatible naming (`.layer_0.` / `.layer_1.`
|
| 44 |
+
> instead of HuggingFace's `.layer.0.` / `.layer.1.`). This is required because MLX's
|
| 45 |
+
> `ModuleParameters.unflattened()` splits on all dots.
|
| 46 |
|
| 47 |
## Usage (Swift/MLX)
|
| 48 |
|
|
|
|
| 53 |
modelFolder: modelURL,
|
| 54 |
t5Folder: t5URL
|
| 55 |
)
|
| 56 |
+
|
| 57 |
+
let tokens = try await model.generate(
|
| 58 |
+
descriptions: ["dog barking"],
|
| 59 |
duration: 5.0,
|
| 60 |
cfgCoef: 3.0,
|
| 61 |
temperature: 1.0,
|
| 62 |
topK: 250
|
| 63 |
)
|
| 64 |
+
|
| 65 |
+
let audio = model.decode(tokens: tokens)
|
| 66 |
```
|
| 67 |
|
| 68 |
+
## T5 Attention
|
| 69 |
+
|
| 70 |
+
T5's self-attention intentionally does **not** scale scores by `1/sqrt(d_k)`. This is a deliberate design choice in the T5 architecture; do not add scaling in the inference code.
|
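
To make the effect of the missing `1/sqrt(d_k)` factor concrete, here is a small pure-Python sketch that compares attention weights with and without scaling. The vectors and the `attention_weights` helper are made up for illustration; relative position bias is left out for brevity.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, keys, scale=1.0):
    """Softmax over the dot products q . k, each multiplied by `scale`."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    return softmax(scores)


d_k = 4
q = [1.0, 2.0, 0.5, -1.0]
keys = [
    [1.0, 0.0, 1.0, 0.0],
    [0.5, 1.5, 0.0, -1.0],
    [-1.0, 0.0, 0.0, 1.0],
]

t5_style = attention_weights(q, keys)                          # T5: raw dot products
scaled = attention_weights(q, keys, scale=1 / math.sqrt(d_k))  # standard Transformer

print(t5_style)
print(scaled)
```

With the same weights, the scaled variant produces a flatter distribution, so adding the factor to a T5 port changes every attention map and therefore the encoder output.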
| 71 |
+
|
| 72 |
## License
|
| 73 |
|
| 74 |
This model is published under the [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license (non-commercial use only), following the original [AudioGen license](https://huggingface.co/facebook/audiogen-medium).
|
config.json
CHANGED
|
@@ -36,14 +36,14 @@
|
|
| 36 |
"duration": 10.0,
|
| 37 |
"numSamples": 1,
|
| 38 |
"specialToken": 2048,
|
| 39 |
-
"tokenizer": "t5-
|
| 40 |
-
"t5_model_name": "t5-
|
| 41 |
"clsToken": 2048,
|
| 42 |
"padToken": 2048,
|
| 43 |
"encodec": {
|
| 44 |
"model_type": "encodec",
|
| 45 |
"audio_channels": 1,
|
| 46 |
-
"num_filters":
|
| 47 |
"kernel_size": 7,
|
| 48 |
"num_residual_layers": 1,
|
| 49 |
"dilation_growth_rate": 2,
|
|
@@ -67,15 +67,15 @@
|
|
| 67 |
"use_conv_shortcut": false
|
| 68 |
},
|
| 69 |
"t5": {
|
| 70 |
-
"model_name": "t5-
|
| 71 |
-
"d_model":
|
| 72 |
"d_kv": 64,
|
| 73 |
-
"d_ff":
|
| 74 |
-
"num_layers":
|
| 75 |
-
"num_heads":
|
| 76 |
"relative_attention_num_buckets": 32,
|
| 77 |
"relative_attention_max_distance": 128,
|
| 78 |
-
"dropout_rate": 0.
|
| 79 |
"layer_norm_epsilon": 1e-06,
|
| 80 |
"feed_forward_proj": "relu",
|
| 81 |
"vocab_size": 32128,
|
|
|
|
| 36 |
"duration": 10.0,
|
| 37 |
"numSamples": 1,
|
| 38 |
"specialToken": 2048,
|
| 39 |
+
"tokenizer": "t5-large",
|
| 40 |
+
"t5_model_name": "t5-large",
|
| 41 |
"clsToken": 2048,
|
| 42 |
"padToken": 2048,
|
| 43 |
"encodec": {
|
| 44 |
"model_type": "encodec",
|
| 45 |
"audio_channels": 1,
|
| 46 |
+
"num_filters": 32,
|
| 47 |
"kernel_size": 7,
|
| 48 |
"num_residual_layers": 1,
|
| 49 |
"dilation_growth_rate": 2,
|
|
|
|
| 67 |
"use_conv_shortcut": false
|
| 68 |
},
|
| 69 |
"t5": {
|
| 70 |
+
"model_name": "t5-large",
|
| 71 |
+
"d_model": 1024,
|
| 72 |
"d_kv": 64,
|
| 73 |
+
"d_ff": 4096,
|
| 74 |
+
"num_layers": 24,
|
| 75 |
+
"num_heads": 16,
|
| 76 |
"relative_attention_num_buckets": 32,
|
| 77 |
"relative_attention_max_distance": 128,
|
| 78 |
+
"dropout_rate": 0.0,
|
| 79 |
"layer_norm_epsilon": 1e-06,
|
| 80 |
"feed_forward_proj": "relu",
|
| 81 |
"vocab_size": 32128,
|
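
The metadata mismatch this commit fixes (a `t5-small` name paired with the wrong dimensions) is the kind of error a small consistency check could catch. The sketch below is an assumption-laden illustration: `check_t5_section` is a hypothetical helper, and the table holds the hyperparameters of the original `t5-small`, `t5-base` and `t5-large` checkpoints.

```python
# Hyperparameters of the original T5 checkpoints (small / base / large).
T5_VARIANTS = {
    "t5-small": {"d_model": 512, "d_ff": 2048, "num_layers": 6, "num_heads": 8},
    "t5-base": {"d_model": 768, "d_ff": 3072, "num_layers": 12, "num_heads": 12},
    "t5-large": {"d_model": 1024, "d_ff": 4096, "num_layers": 24, "num_heads": 16},
}


def check_t5_section(t5):
    """Return a list of mismatches between t5['model_name'] and its dimensions."""
    expected = T5_VARIANTS.get(t5.get("model_name"))
    if expected is None:
        return [f"unknown variant: {t5.get('model_name')!r}"]
    return [
        f"{key}: expected {value}, found {t5.get(key)}"
        for key, value in expected.items()
        if t5.get(key) != value
    ]


# Values from the updated "t5" section of config.json.
fixed = {"model_name": "t5-large", "d_model": 1024, "d_kv": 64,
         "d_ff": 4096, "num_layers": 24, "num_heads": 16}
print(check_t5_section(fixed))  # []

# A deliberately inconsistent section: small name, large dimensions.
broken = dict(fixed, model_name="t5-small")
print(len(check_t5_section(broken)))  # 4
```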
extract_t5.py
ADDED
|
@@ -0,0 +1,271 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract T5 conditioner weights from facebook/audiogen-medium for MLX.
|
| 4 |
+
|
| 5 |
+
The original AudioGen model bundles a frozen T5 text encoder and a trained
|
| 6 |
+
output projection inside condition_provider.*. The main MLX conversion strips
|
| 7 |
+
these keys. This script extracts them into a t5/ subdirectory that the MLX
|
| 8 |
+
AudioGen loader expects.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# Automatic: downloads from HuggingFace, extracts, cleans up
|
| 12 |
+
python extract_t5.py --output /path/to/audiogen-mlx/t5
|
| 13 |
+
|
| 14 |
+
# Manual: use a local state_dict.bin you already downloaded
|
| 15 |
+
python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5
|
| 16 |
+
|
| 17 |
+
Output (in --output directory):
|
| 18 |
+
config.json T5 encoder config (derived from weight shapes)
|
| 19 |
+
model.safetensors T5 encoder weights + output_proj
|
| 20 |
+
tokenizer.json Downloaded from google-t5/t5-small
|
| 21 |
+
tokenizer_config.json Downloaded from google-t5/t5-small
|
| 22 |
+
|
| 23 |
+
Requirements:
|
| 24 |
+
pip install torch safetensors huggingface_hub
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import struct
|
| 31 |
+
import tempfile
|
| 32 |
+
import shutil
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
from safetensors.torch import save_file
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
T5_PREFIX = "condition_provider.conditioners.description.model."
|
| 40 |
+
OUTPUT_PROJ_PREFIX = "condition_provider.conditioners.description.output_proj."
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_lm_state(path):
|
| 44 |
+
"""Load the LM state dict from a PyTorch checkpoint."""
|
| 45 |
+
ckpt = torch.load(path, map_location="cpu", weights_only=True)
|
| 46 |
+
if "best_state" in ckpt:
|
| 47 |
+
return ckpt["best_state"]
|
| 48 |
+
return ckpt
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def extract_t5_weights(lm_state):
|
| 52 |
+
"""Extract T5 encoder and output_proj weights from the LM state dict."""
|
| 53 |
+
t5_weights = {}
|
| 54 |
+
output_proj = {}
|
| 55 |
+
other_cp = []
|
| 56 |
+
|
| 57 |
+
for key, tensor in lm_state.items():
|
| 58 |
+
if not key.startswith("condition_provider."):
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
if key.startswith(T5_PREFIX):
|
| 62 |
+
# Strip prefix to get standard HuggingFace T5 key format
|
| 63 |
+
new_key = key[len(T5_PREFIX):]
|
| 64 |
+
t5_weights[new_key] = tensor
|
| 65 |
+
elif key.startswith(OUTPUT_PROJ_PREFIX):
|
| 66 |
+
# output_proj.weight / output_proj.bias
|
| 67 |
+
new_key = key[len(OUTPUT_PROJ_PREFIX):]
|
| 68 |
+
output_proj[f"output_proj.{new_key}"] = tensor
|
| 69 |
+
else:
|
| 70 |
+
other_cp.append(key)
|
| 71 |
+
|
| 72 |
+
return t5_weights, output_proj, other_cp
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def sanitize_keys_for_mlx(weights):
|
| 76 |
+
"""Rename T5 weight keys for MLX compatibility.
|
| 77 |
+
|
| 78 |
+
HuggingFace T5 uses keys like "encoder.block.0.layer.0.SelfAttention.q.weight"
|
| 79 |
+
where "layer.0" and "layer.1" are sub-module names. MLX's
|
| 80 |
+
ModuleParameters.unflattened() splits on ALL dots, which misparses "layer.0"
|
| 81 |
+
as {"layer": {"0": ...}} instead of treating it as a single key.
|
| 82 |
+
|
| 83 |
+
This renames ".layer.0." to ".layer_0." and ".layer.1." to ".layer_1." so
|
| 84 |
+
the keys work correctly with MLX's parameter loading.
|
| 85 |
+
"""
|
| 86 |
+
sanitized = {}
|
| 87 |
+
for key, value in weights.items():
|
| 88 |
+
new_key = key
|
| 89 |
+
new_key = new_key.replace(".layer.0.", ".layer_0.")
|
| 90 |
+
new_key = new_key.replace(".layer.1.", ".layer_1.")
|
| 91 |
+
sanitized[new_key] = value
|
| 92 |
+
return sanitized
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def infer_t5_config(t5_weights):
|
| 96 |
+
"""Determine T5 architecture from weight shapes."""
|
| 97 |
+
# shared.weight: [vocab_size, d_model]
|
| 98 |
+
shared = t5_weights.get("shared.weight")
|
| 99 |
+
if shared is None:
|
| 100 |
+
raise ValueError("Cannot find shared.weight in T5 weights")
|
| 101 |
+
|
| 102 |
+
vocab_size = shared.shape[0]
|
| 103 |
+
d_model = shared.shape[1]
|
| 104 |
+
|
| 105 |
+
# Find q projection to determine d_kv and num_heads
|
| 106 |
+
q_weight = t5_weights.get("encoder.block.0.layer.0.SelfAttention.q.weight")
|
| 107 |
+
if q_weight is None:
|
| 108 |
+
raise ValueError("Cannot find SelfAttention.q.weight")
|
| 109 |
+
|
| 110 |
+
# q.weight: [num_heads * d_kv, d_model]
|
| 111 |
+
total_kv = q_weight.shape[0]
|
| 112 |
+
|
| 113 |
+
# Find DenseReluDense.wi to determine d_ff
|
| 114 |
+
wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
|
| 115 |
+
if wi is None:
|
| 116 |
+
raise ValueError("Cannot find DenseReluDense.wi.weight")
|
| 117 |
+
d_ff = wi.shape[0]
|
| 118 |
+
|
| 119 |
+
# Count encoder layers
|
| 120 |
+
num_layers = 0
|
| 121 |
+
while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
|
| 122 |
+
num_layers += 1
|
| 123 |
+
|
| 124 |
+
# Determine d_kv and num_heads
|
| 125 |
+
# d_kv is 64 for t5-small, t5-base and t5-large; t5-3b and t5-11b use 128 and are not handled here
|
| 126 |
+
d_kv = 64
|
| 127 |
+
num_heads = total_kv // d_kv
|
| 128 |
+
|
| 129 |
+
# Check relative_attention_bias
|
| 130 |
+
rab = t5_weights.get(
|
| 131 |
+
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
| 132 |
+
)
|
| 133 |
+
num_buckets = rab.shape[0] if rab is not None else 32
|
| 134 |
+
|
| 135 |
+
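# Determine T5 variant name from d_model (heuristic; the published t5-3b and t5-11b configs also use d_model=1024, so only small/base/large are identified reliably)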
|
| 136 |
+
t5_variant = "t5-unknown"
|
| 137 |
+
if d_model == 512:
|
| 138 |
+
t5_variant = "t5-small"
|
| 139 |
+
elif d_model == 768:
|
| 140 |
+
t5_variant = "t5-base"
|
| 141 |
+
elif d_model == 1024:
|
| 142 |
+
t5_variant = "t5-large"
|
| 143 |
+
elif d_model == 4096:
|
| 144 |
+
t5_variant = "t5-unknown"  # t5-3b uses d_model=1024, not 4096, so do not label this case as t5-3b
|
| 145 |
+
|
| 146 |
+
config = {
|
| 147 |
+
"architectures": ["T5EncoderModel"],
|
| 148 |
+
"model_name": t5_variant,
|
| 149 |
+
"d_model": d_model,
|
| 150 |
+
"d_kv": d_kv,
|
| 151 |
+
"d_ff": d_ff,
|
| 152 |
+
"num_heads": num_heads,
|
| 153 |
+
"num_layers": num_layers,
|
| 154 |
+
"vocab_size": vocab_size,
|
| 155 |
+
"relative_attention_num_buckets": num_buckets,
|
| 156 |
+
"relative_attention_max_distance": 128,
|
| 157 |
+
"dropout_rate": 0.0,
|
| 158 |
+
"layer_norm_epsilon": 1e-6,
|
| 159 |
+
"feed_forward_proj": "relu",
|
| 160 |
+
"tie_word_embeddings": True,
|
| 161 |
+
"decoder_start_token_id": 0,
|
| 162 |
+
"model_type": "t5",
|
| 163 |
+
}
|
| 164 |
+
return config
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def download_tokenizer(output_dir):
|
| 168 |
+
"""Download T5 tokenizer files from HuggingFace.
|
| 169 |
+
|
| 170 |
+
All T5 model sizes share the same SentencePiece tokenizer (32128 tokens),
|
| 171 |
+
so we download from t5-small for convenience.
|
| 172 |
+
"""
|
| 173 |
+
repo = "google-t5/t5-small"
|
| 174 |
+
for filename in ["tokenizer.json", "tokenizer_config.json"]:
|
| 175 |
+
path = hf_hub_download(repo_id=repo, filename=filename)
|
| 176 |
+
dst = os.path.join(output_dir, filename)
|
| 177 |
+
shutil.copy2(path, dst)
|
| 178 |
+
print(f" Copied {filename}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def main():
|
| 182 |
+
parser = argparse.ArgumentParser(
|
| 183 |
+
description="Extract T5 conditioner from facebook/audiogen-medium"
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--lm",
|
| 187 |
+
help="Path to local state_dict.bin (skips download)",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--output",
|
| 191 |
+
required=True,
|
| 192 |
+
help="Output directory for T5 weights (e.g. /path/to/model/t5)",
|
| 193 |
+
)
|
| 194 |
+
args = parser.parse_args()
|
| 195 |
+
|
| 196 |
+
os.makedirs(args.output, exist_ok=True)
|
| 197 |
+
|
| 198 |
+
# Get the state dict
|
| 199 |
+
if args.lm:
|
| 200 |
+
lm_path = args.lm
|
| 201 |
+
print(f"Loading local checkpoint: {lm_path}")
|
| 202 |
+
else:
|
| 203 |
+
print("Downloading facebook/audiogen-medium state_dict.bin ...")
|
| 204 |
+
lm_path = hf_hub_download(
|
| 205 |
+
repo_id="facebook/audiogen-medium",
|
| 206 |
+
filename="state_dict.bin",
|
| 207 |
+
)
|
| 208 |
+
print(f" Downloaded to cache: {lm_path}")
|
| 209 |
+
|
| 210 |
+
print("Loading state dict ...")
|
| 211 |
+
lm_state = load_lm_state(lm_path)
|
| 212 |
+
|
| 213 |
+
print("Extracting T5 weights ...")
|
| 214 |
+
t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)
|
| 215 |
+
|
| 216 |
+
print(f" T5 encoder keys: {len(t5_weights)}")
|
| 217 |
+
print(f" Output projection keys: {len(output_proj)}")
|
| 218 |
+
if other_cp:
|
| 219 |
+
print(f" Other condition_provider keys (skipped): {len(other_cp)}")
|
| 220 |
+
|
| 221 |
+
if not t5_weights:
|
| 222 |
+
print("ERROR: No T5 weights found in checkpoint!")
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
# Infer T5 architecture
|
| 226 |
+
config = infer_t5_config(t5_weights)
|
| 227 |
+
print(f" T5 config: {config['model_name']} - d_model={config['d_model']}, "
|
| 228 |
+
f"num_heads={config['num_heads']}, "
|
| 229 |
+
f"num_layers={config['num_layers']}, "
|
| 230 |
+
f"d_ff={config['d_ff']}, "
|
| 231 |
+
f"vocab_size={config['vocab_size']}")
|
| 232 |
+
|
| 233 |
+
if output_proj:
|
| 234 |
+
proj_w = output_proj.get("output_proj.weight")
|
| 235 |
+
if proj_w is not None:
|
| 236 |
+
print(f" Output projection: {list(proj_w.shape)} "
|
| 237 |
+
f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")
|
| 238 |
+
|
| 239 |
+
# Sanitize keys for MLX compatibility before saving
|
| 240 |
+
sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
|
| 241 |
+
print(f" Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")
|
| 242 |
+
|
| 243 |
+
# Combine sanitized T5 weights + output_proj into one safetensors
|
| 244 |
+
all_weights = {}
|
| 245 |
+
all_weights.update(sanitized_t5)
|
| 246 |
+
all_weights.update(output_proj)
|
| 247 |
+
|
| 248 |
+
# Save safetensors
|
| 249 |
+
st_path = os.path.join(args.output, "model.safetensors")
|
| 250 |
+
print(f"Saving {len(all_weights)} tensors to {st_path} ...")
|
| 251 |
+
save_file(all_weights, st_path)
|
| 252 |
+
|
| 253 |
+
total_bytes = os.path.getsize(st_path)
|
| 254 |
+
print(f" Size: {total_bytes / 1e6:.1f} MB")
|
| 255 |
+
|
| 256 |
+
# Save config
|
| 257 |
+
config_path = os.path.join(args.output, "config.json")
|
| 258 |
+
with open(config_path, "w") as f:
|
| 259 |
+
json.dump(config, f, indent=2)
|
| 260 |
+
print(f"Saved config.json")
|
| 261 |
+
|
| 262 |
+
# Download tokenizer
|
| 263 |
+
print("Downloading T5 tokenizer ...")
|
| 264 |
+
download_tokenizer(args.output)
|
| 265 |
+
|
| 266 |
+
print(f"\nDone! T5 conditioner saved to: {args.output}")
|
| 267 |
+
print("Files:", sorted(os.listdir(args.output)))
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
main()
|
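
The shape-based detection in `infer_t5_config` can be tried out without loading a checkpoint. The sketch below mirrors that logic on plain shape tuples; `infer_dims` and `fake_shapes` are hypothetical helpers for illustration, and the shapes are constructed by hand rather than read from `state_dict.bin`.

```python
def infer_dims(shapes, d_kv=64):
    """Derive encoder hyperparameters from a {key: shape} mapping."""
    vocab_size, d_model = shapes["shared.weight"]
    inner_dim = shapes["encoder.block.0.layer.0.SelfAttention.q.weight"][0]
    d_ff = shapes["encoder.block.0.layer.1.DenseReluDense.wi.weight"][0]

    # Count consecutive encoder blocks, as extract_t5.py does.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in shapes:
        num_layers += 1

    if inner_dim % d_kv:
        raise ValueError(f"inner dim {inner_dim} is not a multiple of d_kv={d_kv}")
    return {
        "vocab_size": vocab_size,
        "d_model": d_model,
        "d_ff": d_ff,
        "num_layers": num_layers,
        "num_heads": inner_dim // d_kv,
    }


def fake_shapes(d_model, inner_dim, d_ff, num_layers, vocab_size=32128):
    """Hypothetical helper: shapes only, no tensors."""
    shapes = {"shared.weight": (vocab_size, d_model)}
    for i in range(num_layers):
        shapes[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = (inner_dim, d_model)
        shapes[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = (d_ff, d_model)
    return shapes


large = infer_dims(fake_shapes(d_model=1024, inner_dim=1024, d_ff=4096, num_layers=24))
print(large["num_heads"], large["num_layers"])  # 16 24
```

Because `num_heads` is derived from a fixed `d_kv`, the result is only as good as that assumption: for a checkpoint with `d_kv=128`, the default would report twice the real head count, so passing `d_kv` explicitly is the safer choice when the variant is not known in advance.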
verify_t5.py
ADDED
|
@@ -0,0 +1,274 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Verify T5 encoder output against Swift implementation.
|
| 3 |
+
|
| 4 |
+
Loads the same T5 safetensors weights, runs the encoder on the same tokens,
|
| 5 |
+
and prints output stats for comparison with the Swift logs.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import mlx.core as mx
|
| 10 |
+
import mlx.nn as nn
|
| 11 |
+
import json
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
MODEL_DIR = Path.home() / "Library/Application Support/Velvox/Models/audiogen-mlx/t5"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ── T5 LayerNorm (RMSNorm, no centering) ──
|
| 18 |
+
|
| 19 |
+
class T5LayerNorm(nn.Module):
|
| 20 |
+
def __init__(self, dims, eps=1e-6):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.weight = mx.ones((dims,))
|
| 23 |
+
self.eps = eps
|
| 24 |
+
|
| 25 |
+
def __call__(self, x):
|
| 26 |
+
y = x.astype(mx.float32)
|
| 27 |
+
y = y * mx.rsqrt(mx.mean(y * y, axis=-1, keepdims=True) + self.eps)
|
| 28 |
+
return self.weight * y.astype(x.dtype)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── T5 DenseReluDense ──
|
| 32 |
+
|
| 33 |
+
class T5DenseActDense(nn.Module):
|
| 34 |
+
def __init__(self, d_model, d_ff, act="relu"):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.wi = nn.Linear(d_model, d_ff, bias=False)
|
| 37 |
+
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
| 38 |
+
self.act = act
|
| 39 |
+
|
| 40 |
+
def __call__(self, x):
|
| 41 |
+
h = self.wi(x)
|
| 42 |
+
h = nn.relu(h) if self.act == "relu" else nn.gelu(h)
|
| 43 |
+
return self.wo(h)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ── T5 Attention (NO sqrt(d_k) scaling; this is T5's design) ──
|
| 47 |
+
|
| 48 |
+
class T5Attention(nn.Module):
|
| 49 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.num_heads = config["num_heads"]
|
| 52 |
+
self.d_kv = config["d_kv"]
|
| 53 |
+
self.d_model = config["d_model"]
|
| 54 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
| 55 |
+
self.num_buckets = config["relative_attention_num_buckets"]
|
| 56 |
+
self.max_distance = config.get("relative_attention_max_distance", 128)
|
| 57 |
+
|
| 58 |
+
self.q = nn.Linear(self.d_model, self.num_heads * self.d_kv, bias=False)
|
| 59 |
+
self.k = nn.Linear(self.d_model, self.num_heads * self.d_kv, bias=False)
|
| 60 |
+
self.v = nn.Linear(self.d_model, self.num_heads * self.d_kv, bias=False)
|
| 61 |
+
self.o = nn.Linear(self.num_heads * self.d_kv, self.d_model, bias=False)
|
| 62 |
+
|
| 63 |
+
if has_relative_attention_bias:
|
| 64 |
+
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def _relative_position_bucket(rp, bidirectional=True, num_buckets=32, max_distance=128):
|
| 68 |
+
nb = num_buckets
|
| 69 |
+
result = mx.zeros(rp.shape, dtype=mx.int32)
|
| 70 |
+
if bidirectional:
|
| 71 |
+
nb = nb // 2
|
| 72 |
+
is_pos = mx.where(rp > 0, mx.array(nb, dtype=mx.int32), mx.array(0, dtype=mx.int32))
|
| 73 |
+
result = is_pos
|
| 74 |
+
rp = mx.abs(rp)
|
| 75 |
+
else:
|
| 76 |
+
rp = -mx.minimum(rp, mx.zeros_like(rp))
|
| 77 |
+
|
| 78 |
+
max_exact = nb // 2
|
| 79 |
+
is_small = rp < max_exact
|
| 80 |
+
|
| 81 |
+
large_rp = rp.astype(mx.float32)
|
| 82 |
+
log_ratio = mx.log(large_rp / max_exact) / math.log(max_distance / max_exact)
|
| 83 |
+
large_bucket = (log_ratio * (nb - max_exact)).astype(mx.int32) + max_exact
|
| 84 |
+
clamped = mx.minimum(large_bucket, mx.array(nb - 1, dtype=mx.int32))
|
| 85 |
+
|
| 86 |
+
buckets = mx.where(is_small, rp.astype(mx.int32), clamped)
|
| 87 |
+
return result + buckets
|
| 88 |
+
|
| 89 |
+
def compute_bias(self, q_len, k_len):
|
| 90 |
+
if not self.has_relative_attention_bias:
|
| 91 |
+
return None
|
| 92 |
+
ctx = mx.arange(q_len, dtype=mx.int32)
|
| 93 |
+
mem = mx.arange(k_len, dtype=mx.int32)
|
| 94 |
+
rp = mem.reshape(1, -1).astype(mx.float32) - ctx.reshape(-1, 1).astype(mx.float32)
|
| 95 |
+
rp_bucket = self._relative_position_bucket(
|
| 96 |
+
rp, bidirectional=True,
|
| 97 |
+
num_buckets=self.num_buckets, max_distance=self.max_distance
|
| 98 |
+
)
|
| 99 |
+
flat = rp_bucket.reshape(-1)
|
| 100 |
+
bias_flat = self.relative_attention_bias(flat)
|
| 101 |
+
bias = bias_flat.reshape(q_len, k_len, self.num_heads)
|
| 102 |
+
bias = bias.transpose(2, 0, 1)[None, :, :, :] # [1, H, Q, K]
|
| 103 |
+
return bias
|
| 104 |
+
|
| 105 |
+
def __call__(self, hidden, mask=None, position_bias=None):
|
| 106 |
+
B, T, _ = hidden.shape
|
| 107 |
+
q = self.q(hidden).reshape(B, T, self.num_heads, self.d_kv)
|
| 108 |
+
k = self.k(hidden).reshape(B, T, self.num_heads, self.d_kv)
|
| 109 |
+
v = self.v(hidden).reshape(B, T, self.num_heads, self.d_kv)
|
| 110 |
+
|
| 111 |
+
q = q.transpose(0, 2, 1, 3) # [B, H, T, d]
|
| 112 |
+
k = k.transpose(0, 2, 3, 1) # [B, H, d, T]
|
| 113 |
+
v = v.transpose(0, 2, 1, 3) # [B, H, T, d]
|
| 114 |
+
|
| 115 |
+
# T5: NO scaling by 1/sqrt(d_k)
|
| 116 |
+
scores = q @ k
|
| 117 |
+
|
| 118 |
+
if position_bias is None:
|
| 119 |
+
position_bias = self.compute_bias(T, T)
|
| 120 |
+
if position_bias is not None:
|
| 121 |
+
scores = scores + position_bias
|
| 122 |
+
|
| 123 |
+
weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
| 124 |
+
out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
|
| 125 |
+
return self.o(out)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ── T5 Block ──
|
| 129 |
+
|
| 130 |
+
class T5Block(nn.Module):
|
| 131 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.self_attn = T5Attention(config, has_relative_attention_bias)
|
| 134 |
+
self.layer_norm_sa = T5LayerNorm(config["d_model"], config.get("layer_norm_epsilon", 1e-6))
|
| 135 |
+
self.ff = T5DenseActDense(config["d_model"], config["d_ff"], config.get("feed_forward_proj", "relu"))
|
| 136 |
+
self.layer_norm_ff = T5LayerNorm(config["d_model"], config.get("layer_norm_epsilon", 1e-6))
|
| 137 |
+
|
| 138 |
+
def __call__(self, x, mask=None, position_bias=None):
|
| 139 |
+
normed = self.layer_norm_sa(x)
|
| 140 |
+
attn_out = self.self_attn(normed, mask=mask, position_bias=position_bias)
|
| 141 |
+
x = x + attn_out
|
| 142 |
+
normed = self.layer_norm_ff(x)
|
| 143 |
+
ff_out = self.ff(normed)
|
| 144 |
+
x = x + ff_out
|
| 145 |
+
return x
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ── T5 Encoder ──
|
| 149 |
+
|
| 150 |
+
class T5Encoder(nn.Module):
|
| 151 |
+
def __init__(self, config):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.shared = nn.Embedding(config["vocab_size"], config["d_model"])
|
| 154 |
+
self.blocks = [T5Block(config, has_relative_attention_bias=(i == 0))
|
| 155 |
+
for i in range(config["num_layers"])]
|
| 156 |
+
self.final_layer_norm = T5LayerNorm(config["d_model"], config.get("layer_norm_epsilon", 1e-6))
|
| 157 |
+
|
| 158 |
+
def __call__(self, input_ids):
|
| 159 |
+
x = self.shared(input_ids)
|
| 160 |
+
# Compute position bias from first block, reuse for all
|
| 161 |
+
pos_bias = self.blocks[0].self_attn.compute_bias(x.shape[1], x.shape[1])
|
| 162 |
+
for block in self.blocks:
|
| 163 |
+
x = block(x, position_bias=pos_bias)
|
| 164 |
+
return self.final_layer_norm(x)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_and_remap_weights(t5_dir):
|
| 168 |
+
"""Load safetensors and remap HuggingFace T5 keys to our module structure."""
|
| 169 |
+
import glob
|
| 170 |
+
safetensors_files = sorted(glob.glob(str(t5_dir / "*.safetensors")))
|
| 171 |
+
|
| 172 |
+
all_weights = {}
|
| 173 |
+
for f in safetensors_files:
|
| 174 |
+
w = mx.load(f)
|
| 175 |
+
all_weights.update(w)
|
| 176 |
+
|
| 177 |
+
# Separate output_proj from T5 weights
|
| 178 |
+
output_proj_w = all_weights.pop("output_proj.weight", None)
|
| 179 |
+
output_proj_b = all_weights.pop("output_proj.bias", None)
|
| 180 |
+
|
| 181 |
+
# Remap HuggingFace keys to our module structure
|
| 182 |
+
remapped = {}
|
| 183 |
+
for key, val in all_weights.items():
|
| 184 |
+
new_key = key
|
| 185 |
+
|
| 186 |
+
# shared.weight → shared.weight (OK)
|
| 187 |
+
|
| 188 |
+
# encoder.block.N.layer.0.SelfAttention.X → blocks.N.self_attn.X
|
| 189 |
+
new_key = new_key.replace("encoder.block.", "blocks.")
|
| 190 |
+
new_key = new_key.replace(".layer.0.SelfAttention.", ".self_attn.").replace(".layer_0.SelfAttention.", ".self_attn.")  # accept HF and MLX-sanitized keys
|
| 191 |
+
new_key = new_key.replace(".layer.0.layer_norm.", ".layer_norm_sa.").replace(".layer_0.layer_norm.", ".layer_norm_sa.")  # accept HF and MLX-sanitized keys
|
| 192 |
+
new_key = new_key.replace(".layer.1.DenseReluDense.", ".ff.").replace(".layer_1.DenseReluDense.", ".ff.")  # accept HF and MLX-sanitized keys
|
| 193 |
+
new_key = new_key.replace(".layer.1.layer_norm.", ".layer_norm_ff.").replace(".layer_1.layer_norm.", ".layer_norm_ff.")  # accept HF and MLX-sanitized keys
|
| 194 |
+
|
| 195 |
+
# encoder.final_layer_norm → final_layer_norm
|
| 196 |
+
new_key = new_key.replace("encoder.final_layer_norm.", "final_layer_norm.")
|
| 197 |
+
|
| 198 |
+
remapped[new_key] = val
|
| 199 |
+
|
| 200 |
+
return remapped, output_proj_w, output_proj_b
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def main():
|
| 204 |
+
print("=" * 60)
|
| 205 |
+
print("T5 Encoder Verification (MLX Python reference)")
|
| 206 |
+
print("=" * 60)
|
| 207 |
+
|
| 208 |
+
# Load config
|
| 209 |
+
with open(MODEL_DIR / "config.json") as f:
|
| 210 |
+
config = json.load(f)
|
| 211 |
+
|
| 212 |
+
print(f"Config: d_model={config['d_model']} layers={config['num_layers']} "
|
| 213 |
+
f"heads={config['num_heads']} d_kv={config['d_kv']} d_ff={config['d_ff']}")
|
| 214 |
+
|
| 215 |
+
# Build model
|
| 216 |
+
encoder = T5Encoder(config)
|
| 217 |
+
|
| 218 |
+
# Load weights
|
| 219 |
+
t5_dir_p = MODEL_DIR
|
| 220 |
+
weights, proj_w, proj_b = load_and_remap_weights(t5_dir_p)
|
| 221 |
+
|
| 222 |
+
# Apply weights
|
| 223 |
+
encoder.load_weights(list(weights.items()))
|
| 224 |
+
|
| 225 |
+
# Build output_proj
|
| 226 |
+
output_proj = None
|
| 227 |
+
if proj_w is not None:
|
| 228 |
+
output_proj = nn.Linear(proj_w.shape[1], proj_w.shape[0])
|
| 229 |
+
proj_params = [("weight", proj_w)]
|
| 230 |
+
if proj_b is not None:
|
| 231 |
+
proj_params.append(("bias", proj_b))
|
| 232 |
+
output_proj.load_weights(proj_params)
|
| 233 |
+
print(f"output_proj: {proj_w.shape[1]} → {proj_w.shape[0]}")
|
| 234 |
+
|
| 235 |
+
# Test prompts with known token IDs from Swift logs
|
| 236 |
+
test_cases = [
|
| 237 |
+
("dog barking", [1782, 21696, 53, 1]),
|
| 238 |
+
("cars in the street", [2948, 16, 8, 2815, 1]),
|
| 239 |
+
("A metro train leaving the platform", [71, 12810, 2412, 3140, 8, 1585, 1]),
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
for prompt, token_ids in test_cases:
|
| 243 |
+
print(f"\n--- '{prompt}' ---")
|
| 244 |
+
print(f"Tokens: {token_ids}")
|
| 245 |
+
|
| 246 |
+
input_ids = mx.array([token_ids], dtype=mx.int32)
|
| 247 |
+
features = encoder(input_ids)
|
| 248 |
+
mx.eval(features)
|
| 249 |
+
|
| 250 |
+
print(f"Encoder output: shape={features.shape} "
|
| 251 |
+
f"min={features.min().item():.7f} max={features.max().item():.7f} "
|
| 252 |
+
f"sum={features.sum().item():.4f}")
|
| 253 |
+
|
| 254 |
+
# Per-position stats
|
| 255 |
+
for i in range(features.shape[1]):
|
| 256 |
+
pos_feat = features[0, i]
|
| 257 |
+
print(f" pos[{i}]: min={pos_feat.min().item():.5f} "
|
| 258 |
+
f"max={pos_feat.max().item():.5f} "
|
| 259 |
+
f"mean={pos_feat.mean().item():.5f}")
|
| 260 |
+
|
| 261 |
+
if output_proj is not None:
|
| 262 |
+
projected = output_proj(features)
|
| 263 |
+
mx.eval(projected)
|
| 264 |
+
print(f"After output_proj: shape={projected.shape} "
|
| 265 |
+
f"min={projected.min().item():.7f} max={projected.max().item():.7f} "
|
| 266 |
+
f"sum={projected.sum().item():.4f}")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
|
| 270 |
+
t5_dir = MODEL_DIR
|
| 271 |
+
if not t5_dir.exists():
|
| 272 |
+
print(f"T5 directory not found: {t5_dir}")
|
| 273 |
+
exit(1)
|
| 274 |
+
main()
|
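
The relative position bucketing in `verify_t5.py` is easier to follow as a scalar function. The sketch below is a pure-Python rewrite of the bidirectional case only, intended as a reading aid; `relative_position_bucket` here is written for this example and takes a single integer offset (key position minus query position) instead of an MLX array.

```python
import math


def relative_position_bucket(rel_pos, num_buckets=32, max_distance=128):
    """Scalar version of T5's bidirectional bucketing."""
    nb = num_buckets // 2                 # half of the buckets per direction
    offset = nb if rel_pos > 0 else 0     # positive offsets use the upper half
    dist = abs(rel_pos)

    max_exact = nb // 2                   # small distances get one bucket each
    if dist < max_exact:
        return offset + dist

    # Larger distances share logarithmically spaced buckets up to max_distance.
    log_bucket = max_exact + int(
        math.log(dist / max_exact)
        / math.log(max_distance / max_exact)
        * (nb - max_exact)
    )
    return offset + min(log_bucket, nb - 1)


for rel in (0, -1, 1, -7, 8, 1000, -1000):
    print(rel, relative_position_bucket(rel))
```

Offsets with a small magnitude map to distinct buckets, while everything at or beyond `max_distance` collapses into the last bucket of its direction, which is why the bias table needs only `num_buckets` rows per head.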