ClementDuhamel commited on
Commit
387ced5
·
verified ·
1 Parent(s): d80e890

fix: critical T5 conditioner key sanitization and metadata

Browse files

- extract_t5.py: sanitize .layer.0. → .layer_0. and .layer.1. → .layer_1.
for MLX ModuleParameters.unflattened() compatibility. Without this fix,
all 24 T5 transformer block weights remain at random initialization.
Also adds automatic T5 variant detection (small/base/large).

- config.json: correct T5 metadata from t5-small to t5-large
(d_model=1024, 24 layers, 16 heads, d_ff=4096).

- README.md: fix text encoder reference, add T5 extraction section,
MLX key sanitization note, T5 unscaled attention note, update Swift usage.

- verify_t5.py: new file — Python MLX reference implementation for
verifying T5 encoder output against the Swift implementation.

Files changed (4) hide show
  1. README.md +25 -5
  2. config.json +9 -9
  3. extract_t5.py +271 -0
  4. verify_t5.py +274 -0
README.md CHANGED
@@ -21,15 +21,28 @@ This is the MLX-native port of [facebook/audiogen-medium](https://huggingface.co
21
  - **Parameters**: ~1.5B (LM) + EnCodec compression model
22
  - **Sampling rate**: 16 kHz
23
  - **Frame rate**: 50 Hz (4 codebooks, delayed pattern)
24
- - **Text encoder**: T5-small (loaded separately)
25
  - **Max duration**: 10 seconds (configurable)
26
 
27
  ## Files
28
 
29
- - `config.json` — Model configuration
30
  - `model.safetensors` — LM + EnCodec weights
31
  - `model.safetensors.index.json` — Weight index (for sharded variants)
32
- - `tokenizer.json` / `tokenizer_config.json` — T5 tokenizer files
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  ## Usage (Swift/MLX)
35
 
@@ -40,15 +53,22 @@ let model = try await AudioGenModel.fromPretrained(
40
  modelFolder: modelURL,
41
  t5Folder: t5URL
42
  )
43
- let audio = try await model.generateAudio(
44
- description: "dog barking",
 
45
  duration: 5.0,
46
  cfgCoef: 3.0,
47
  temperature: 1.0,
48
  topK: 250
49
  )
 
 
50
  ```
51
 
 
 
 
 
52
  ## License
53
 
54
  This model is published under the [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license (non-commercial use only), following the original [AudioGen license](https://huggingface.co/facebook/audiogen-medium).
 
21
  - **Parameters**: ~1.5B (LM) + EnCodec compression model
22
  - **Sampling rate**: 16 kHz
23
  - **Frame rate**: 50 Hz (4 codebooks, delayed pattern)
24
+ - **Text encoder**: T5-large (d_model=1024, 24 layers, 16 heads)
25
  - **Max duration**: 10 seconds (configurable)
26
 
27
  ## Files
28
 
29
+ - `config.json` — Model configuration (includes `t5_model_name` reference)
30
  - `model.safetensors` — LM + EnCodec weights
31
  - `model.safetensors.index.json` — Weight index (for sharded variants)
32
+
33
+ ### T5 Conditioner (extracted separately)
34
+
35
+ The T5-large text encoder weights are not included in this repository. Use `extract_t5.py` to extract them from the original `facebook/audiogen-medium` checkpoint:
36
+
37
+ ```bash
38
+ python extract_t5.py --output /path/to/audiogen-mlx/t5
39
+ ```
40
+
41
+ This produces a `t5/` directory with `config.json`, `model.safetensors`, and tokenizer files.
42
+
43
+ > **Note**: The T5 safetensors keys use MLX-compatible naming (`.layer_0.` / `.layer_1.`
44
+ > instead of HuggingFace's `.layer.0.` / `.layer.1.`). This is required because MLX's
45
+ > `ModuleParameters.unflattened()` splits on all dots.
46
 
47
  ## Usage (Swift/MLX)
48
 
 
53
  modelFolder: modelURL,
54
  t5Folder: t5URL
55
  )
56
+
57
+ let tokens = try await model.generate(
58
+ descriptions: ["dog barking"],
59
  duration: 5.0,
60
  cfgCoef: 3.0,
61
  temperature: 1.0,
62
  topK: 250
63
  )
64
+
65
+ let audio = model.decode(tokens: tokens)
66
  ```
67
 
68
+ ## T5 Attention
69
+
70
+ T5's self-attention intentionally does **not** scale scores by `1/sqrt(d_k)`. This is a deliberate design choice in the T5 architecture — do not add scaling in the inference code.
71
+
72
  ## License
73
 
74
  This model is published under the [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license (non-commercial use only), following the original [AudioGen license](https://huggingface.co/facebook/audiogen-medium).
config.json CHANGED
@@ -36,14 +36,14 @@
36
  "duration": 10.0,
37
  "numSamples": 1,
38
  "specialToken": 2048,
39
- "tokenizer": "t5-small",
40
- "t5_model_name": "t5-small",
41
  "clsToken": 2048,
42
  "padToken": 2048,
43
  "encodec": {
44
  "model_type": "encodec",
45
  "audio_channels": 1,
46
- "num_filters": 64,
47
  "kernel_size": 7,
48
  "num_residual_layers": 1,
49
  "dilation_growth_rate": 2,
@@ -67,15 +67,15 @@
67
  "use_conv_shortcut": false
68
  },
69
  "t5": {
70
- "model_name": "t5-small",
71
- "d_model": 512,
72
  "d_kv": 64,
73
- "d_ff": 2048,
74
- "num_layers": 8,
75
- "num_heads": 6,
76
  "relative_attention_num_buckets": 32,
77
  "relative_attention_max_distance": 128,
78
- "dropout_rate": 0.1,
79
  "layer_norm_epsilon": 1e-06,
80
  "feed_forward_proj": "relu",
81
  "vocab_size": 32128,
 
36
  "duration": 10.0,
37
  "numSamples": 1,
38
  "specialToken": 2048,
39
+ "tokenizer": "t5-large",
40
+ "t5_model_name": "t5-large",
41
  "clsToken": 2048,
42
  "padToken": 2048,
43
  "encodec": {
44
  "model_type": "encodec",
45
  "audio_channels": 1,
46
+ "num_filters": 32,
47
  "kernel_size": 7,
48
  "num_residual_layers": 1,
49
  "dilation_growth_rate": 2,
 
67
  "use_conv_shortcut": false
68
  },
69
  "t5": {
70
+ "model_name": "t5-large",
71
+ "d_model": 1024,
72
  "d_kv": 64,
73
+ "d_ff": 4096,
74
+ "num_layers": 24,
75
+ "num_heads": 16,
76
  "relative_attention_num_buckets": 32,
77
  "relative_attention_max_distance": 128,
78
+ "dropout_rate": 0.0,
79
  "layer_norm_epsilon": 1e-06,
80
  "feed_forward_proj": "relu",
81
  "vocab_size": 32128,
extract_t5.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract T5 conditioner weights from facebook/audiogen-medium for MLX.
4
+
5
+ The original AudioGen model bundles a frozen T5 text encoder and a trained
6
+ output projection inside condition_provider.*. The main MLX conversion strips
7
+ these keys. This script extracts them into a t5/ subdirectory that the MLX
8
+ AudioGen loader expects.
9
+
10
+ Usage:
11
+ # Automatic: downloads from HuggingFace, extracts, cleans up
12
+ python extract_t5.py --output /path/to/audiogen-mlx/t5
13
+
14
+ # Manual: use a local state_dict.bin you already downloaded
15
+ python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5
16
+
17
+ Output (in --output directory):
18
+ config.json T5 encoder config (derived from weight shapes)
19
+ model.safetensors T5 encoder weights + output_proj
20
+ tokenizer.json Downloaded from google-t5/t5-small
21
+ tokenizer_config.json Downloaded from google-t5/t5-small
22
+
23
+ Requirements:
24
+ pip install torch safetensors huggingface_hub
25
+ """
26
+
27
+ import argparse
28
+ import json
29
+ import os
30
+ import struct
31
+ import tempfile
32
+ import shutil
33
+
34
+ import torch
35
+ from safetensors.torch import save_file
36
+ from huggingface_hub import hf_hub_download
37
+
38
+
39
+ T5_PREFIX = "condition_provider.conditioners.description.model."
40
+ OUTPUT_PROJ_PREFIX = "condition_provider.conditioners.description.output_proj."
41
+
42
+
43
def load_lm_state(path):
    """Read a PyTorch checkpoint from ``path`` and return its state dict.

    The checkpoint is deserialized onto the CPU with ``weights_only=True``.
    If the loaded object contains a ``"best_state"`` entry, that entry is
    returned; otherwise the loaded object itself is returned unchanged.
    """
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    return checkpoint["best_state"] if "best_state" in checkpoint else checkpoint
49
+
50
+
51
def extract_t5_weights(lm_state):
    """Split the ``condition_provider.*`` entries of an LM state dict.

    Returns a 3-tuple ``(t5_weights, output_proj, other_cp)``:

    * ``t5_weights``: entries whose key starts with ``T5_PREFIX``, stored
      under the key with that prefix removed.
    * ``output_proj``: entries whose key starts with ``OUTPUT_PROJ_PREFIX``,
      stored under ``"output_proj." + <remainder of the key>``.
    * ``other_cp``: list of the remaining ``condition_provider.*`` key names
      (names only, tensors are not kept).

    Keys that do not start with ``"condition_provider."`` are ignored.
    """
    encoder_tensors = {}
    projection_tensors = {}
    unmatched_keys = []

    conditioner_items = (
        (name, value)
        for name, value in lm_state.items()
        if name.startswith("condition_provider.")
    )
    for name, value in conditioner_items:
        if name.startswith(T5_PREFIX):
            # Dropping the prefix leaves the T5 model's own key naming.
            encoder_tensors[name[len(T5_PREFIX):]] = value
            continue
        if name.startswith(OUTPUT_PROJ_PREFIX):
            # e.g. "...output_proj.weight" -> "output_proj.weight"
            suffix = name[len(OUTPUT_PROJ_PREFIX):]
            projection_tensors["output_proj." + suffix] = value
            continue
        unmatched_keys.append(name)

    return encoder_tensors, projection_tensors, unmatched_keys
73
+
74
+
75
def sanitize_keys_for_mlx(weights):
    """Return a copy of ``weights`` with T5 sub-layer names made dot-free.

    Every occurrence of ``".layer.0."`` in a key becomes ``".layer_0."`` and
    every ``".layer.1."`` becomes ``".layer_1."``; values are carried over
    untouched. The commit notes motivate this: MLX's parameter unflattening
    splits keys on every dot, so ``layer.0`` would otherwise be read as two
    nesting levels rather than one sub-module name.
    """
    renames = (
        (".layer.0.", ".layer_0."),
        (".layer.1.", ".layer_1."),
    )
    result = {}
    for name, tensor in weights.items():
        for old, new in renames:
            name = name.replace(old, new)
        result[name] = tensor
    return result
93
+
94
+
95
def infer_t5_config(t5_weights):
    """Derive a T5 encoder config dict from the shapes of its weights.

    Args:
        t5_weights: Mapping from HuggingFace-style T5 key names (with
            ``.layer.0.`` / ``.layer.1.``, i.e. *before* MLX sanitization) to
            tensors exposing ``.shape``.

    Returns:
        A dict with the architecture fields (``d_model``, ``d_kv``, ``d_ff``,
        ``num_heads``, ``num_layers``, ``vocab_size``, bucket count) read from
        the weights, plus fixed values for the remaining fields.

    Raises:
        ValueError: If a required tensor is missing, or if the attention
            projection width is not a multiple of the detected head count.

    ``num_heads`` is read from the second dimension of block 0's
    ``relative_attention_bias`` embedding when that tensor is present, and
    ``d_kv`` is derived from it. Only when the bias table is absent does the
    function fall back to assuming ``d_kv == 64``. The variant name is picked
    from ``(d_model, d_ff)`` because the 3B/11B checkpoints share
    ``d_model == 1024`` with t5-large.
    """
    # shared.weight: [vocab_size, d_model]
    shared = t5_weights.get("shared.weight")
    if shared is None:
        raise ValueError("Cannot find shared.weight in T5 weights")
    vocab_size = shared.shape[0]
    d_model = shared.shape[1]

    # q.weight: [num_heads * d_kv, d_model]
    q_weight = t5_weights.get("encoder.block.0.layer.0.SelfAttention.q.weight")
    if q_weight is None:
        raise ValueError("Cannot find SelfAttention.q.weight")
    total_kv = q_weight.shape[0]

    # wi.weight: [d_ff, d_model]
    wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
    if wi is None:
        raise ValueError("Cannot find DenseReluDense.wi.weight")
    d_ff = wi.shape[0]

    # Count consecutive encoder blocks starting from index 0.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
        num_layers += 1

    # relative_attention_bias.weight: [num_buckets, num_heads]
    rab = t5_weights.get(
        "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
    )
    if rab is not None:
        num_buckets = rab.shape[0]
        num_heads = rab.shape[1]
        if num_heads <= 0 or total_kv % num_heads != 0:
            raise ValueError(
                f"SelfAttention.q.weight width {total_kv} is not divisible by "
                f"num_heads={num_heads} from relative_attention_bias"
            )
        d_kv = total_kv // num_heads
    else:
        # No bias table to read the head count from: assume the d_kv used by
        # t5-small/base/large.
        num_buckets = 32
        d_kv = 64
        num_heads = total_kv // d_kv

    # Exact (d_model, d_ff) match first; otherwise fall back to d_model alone
    # for the sizes where d_model is unambiguous enough to be a useful label.
    variants_by_dims = {
        (512, 2048): "t5-small",
        (768, 3072): "t5-base",
        (1024, 4096): "t5-large",
        (1024, 16384): "t5-3b",
        (1024, 65536): "t5-11b",
    }
    variants_by_d_model = {512: "t5-small", 768: "t5-base", 1024: "t5-large"}
    t5_variant = variants_by_dims.get(
        (d_model, d_ff), variants_by_d_model.get(d_model, "t5-unknown")
    )

    config = {
        "architectures": ["T5EncoderModel"],
        "model_name": t5_variant,
        "d_model": d_model,
        "d_kv": d_kv,
        "d_ff": d_ff,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "vocab_size": vocab_size,
        "relative_attention_num_buckets": num_buckets,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.0,
        "layer_norm_epsilon": 1e-6,
        "feed_forward_proj": "relu",
        "tie_word_embeddings": True,
        "decoder_start_token_id": 0,
        "model_type": "t5",
    }
    return config
165
+
166
+
167
def download_tokenizer(output_dir):
    """Fetch the T5 tokenizer files and copy them into ``output_dir``.

    ``tokenizer.json`` and ``tokenizer_config.json`` are pulled from the
    ``google-t5/t5-small`` repository via ``hf_hub_download`` and then copied
    (with metadata) into ``output_dir``. The original author's note states
    that all T5 sizes share the same SentencePiece tokenizer (32128 tokens),
    which is why the small repository is used.
    """
    repo = "google-t5/t5-small"
    for filename in ("tokenizer.json", "tokenizer_config.json"):
        cached_file = hf_hub_download(repo_id=repo, filename=filename)
        shutil.copy2(cached_file, os.path.join(output_dir, filename))
        print(f" Copied {filename}")
179
+
180
+
181
def main():
    """Command-line entry point: extract the T5 conditioner into ``--output``.

    Steps, in order:
      1. Parse ``--lm`` (optional local checkpoint) and ``--output`` (required).
      2. Load the checkpoint, downloading ``state_dict.bin`` from
         ``facebook/audiogen-medium`` when ``--lm`` is not given.
      3. Split out the T5 encoder and output-projection tensors.
      4. Infer the T5 config from the un-sanitized key names, then rename the
         keys for MLX and write ``model.safetensors`` and ``config.json``.
      5. Copy the tokenizer files next to them.

    Raises:
        SystemExit: With a non-zero status when the checkpoint contains no
            T5 weights, so shell pipelines do not treat that as success.
    """
    parser = argparse.ArgumentParser(
        description="Extract T5 conditioner from facebook/audiogen-medium"
    )
    parser.add_argument(
        "--lm",
        help="Path to local state_dict.bin (skips download)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for T5 weights (e.g. /path/to/model/t5)",
    )
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    # Get the state dict
    if args.lm:
        lm_path = args.lm
        print(f"Loading local checkpoint: {lm_path}")
    else:
        print("Downloading facebook/audiogen-medium state_dict.bin ...")
        lm_path = hf_hub_download(
            repo_id="facebook/audiogen-medium",
            filename="state_dict.bin",
        )
        print(f" Downloaded to cache: {lm_path}")

    print("Loading state dict ...")
    lm_state = load_lm_state(lm_path)

    print("Extracting T5 weights ...")
    t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)

    print(f" T5 encoder keys: {len(t5_weights)}")
    print(f" Output projection keys: {len(output_proj)}")
    if other_cp:
        print(f" Other condition_provider keys (skipped): {len(other_cp)}")

    if not t5_weights:
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit("ERROR: No T5 weights found in checkpoint!")

    # Infer T5 architecture (must run on the original, un-sanitized key names)
    config = infer_t5_config(t5_weights)
    print(f" T5 config: {config['model_name']} — d_model={config['d_model']}, "
          f"num_heads={config['num_heads']}, "
          f"num_layers={config['num_layers']}, "
          f"d_ff={config['d_ff']}, "
          f"vocab_size={config['vocab_size']}")

    if output_proj:
        proj_w = output_proj.get("output_proj.weight")
        if proj_w is not None:
            print(f" Output projection: {list(proj_w.shape)} "
                  f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")

    # Sanitize keys for MLX compatibility before saving
    sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
    print(f" Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")

    # Combine sanitized T5 weights + output_proj into one safetensors
    all_weights = {}
    all_weights.update(sanitized_t5)
    all_weights.update(output_proj)

    # Save safetensors
    st_path = os.path.join(args.output, "model.safetensors")
    print(f"Saving {len(all_weights)} tensors to {st_path} ...")
    save_file(all_weights, st_path)

    total_bytes = os.path.getsize(st_path)
    print(f" Size: {total_bytes / 1e6:.1f} MB")

    # Save config
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print("Saved config.json")

    # Download tokenizer
    print("Downloading T5 tokenizer ...")
    download_tokenizer(args.output)

    print(f"\nDone! T5 conditioner saved to: {args.output}")
    print("Files:", sorted(os.listdir(args.output)))
268
+
269
+
270
+ if __name__ == "__main__":
271
+ main()
verify_t5.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Verify T5 encoder output against Swift implementation.
3
+
4
+ Loads the same T5 safetensors weights, runs the encoder on the same tokens,
5
+ and prints output stats for comparison with the Swift logs.
6
+ """
7
+
8
+ import math
9
+ import mlx.core as mx
10
+ import mlx.nn as nn
11
+ import json
12
+ from pathlib import Path
13
+
14
+ MODEL_DIR = Path.home() / "Library/Application Support/Velvox/Models/audiogen-mlx/t5"
15
+
16
+
17
+ # ── T5 LayerNorm (RMSNorm, no centering) ──
18
+
19
class T5LayerNorm(nn.Module):
    """Scale-only normalization over the last axis.

    The input is divided by the root of its mean square (no mean subtraction,
    no bias term) and multiplied by a learned per-feature ``weight``.
    """

    def __init__(self, dims, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        # Statistics are computed in float32, then cast back to x's dtype
        # before the learned scale is applied.
        x32 = x.astype(mx.float32)
        mean_square = mx.mean(x32 * x32, axis=-1, keepdims=True)
        normalized = x32 * mx.rsqrt(mean_square + self.eps)
        return self.weight * normalized.astype(x.dtype)
29
+
30
+
31
+ # ── T5 DenseReluDense ──
32
+
33
class T5DenseActDense(nn.Module):
    """Two bias-free linear layers with an activation in between.

    ``wi`` maps ``d_model -> d_ff`` and ``wo`` maps ``d_ff -> d_model``.
    ReLU is used when ``act == "relu"``; any other value selects GELU.
    """

    def __init__(self, d_model, d_ff, act="relu"):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.act = act

    def __call__(self, x):
        hidden = self.wi(x)
        if self.act == "relu":
            hidden = nn.relu(hidden)
        else:
            hidden = nn.gelu(hidden)
        return self.wo(hidden)
44
+
45
+
46
+ # ── T5 Attention (NO sqrt(d_k) scaling — this is T5's design) ──
47
+
48
class T5Attention(nn.Module):
    """Multi-head self-attention with unscaled dot-product scores.

    ``config`` is a dict providing ``num_heads``, ``d_kv``, ``d_model``,
    ``relative_attention_num_buckets`` and, optionally,
    ``relative_attention_max_distance`` (default 128). When
    ``has_relative_attention_bias`` is true the module owns the embedding
    table from which the additive position bias is looked up.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.num_heads = config["num_heads"]
        self.d_kv = config["d_kv"]
        self.d_model = config["d_model"]
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = config["relative_attention_num_buckets"]
        self.max_distance = config.get("relative_attention_max_distance", 128)

        inner_dim = self.num_heads * self.d_kv
        self.q = nn.Linear(self.d_model, inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, inner_dim, bias=False)
        self.o = nn.Linear(inner_dim, self.d_model, bias=False)

        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)

    @staticmethod
    def _relative_position_bucket(rp, bidirectional=True, num_buckets=32, max_distance=128):
        """Map relative positions ``rp`` to integer bucket indices.

        In bidirectional mode half of the buckets are reserved for positive
        offsets. Distances below ``bucket_count // 2`` get one bucket each;
        larger ones are binned logarithmically and capped at the last bucket.
        """
        if bidirectional:
            bucket_count = num_buckets // 2
            # Positive offsets are shifted into the upper half of the table.
            offset = mx.where(
                rp > 0,
                mx.array(bucket_count, dtype=mx.int32),
                mx.array(0, dtype=mx.int32),
            )
            distance = mx.abs(rp)
        else:
            bucket_count = num_buckets
            offset = mx.zeros(rp.shape, dtype=mx.int32)
            distance = -mx.minimum(rp, mx.zeros_like(rp))

        exact_limit = bucket_count // 2
        within_exact = distance < exact_limit

        # The log branch is evaluated for every element (including distance 0)
        # and the unwanted entries are discarded by the final `where`.
        distance_f32 = distance.astype(mx.float32)
        log_ratio = mx.log(distance_f32 / exact_limit) / math.log(max_distance / exact_limit)
        log_bucket = (log_ratio * (bucket_count - exact_limit)).astype(mx.int32) + exact_limit
        log_bucket = mx.minimum(log_bucket, mx.array(bucket_count - 1, dtype=mx.int32))

        return offset + mx.where(within_exact, distance.astype(mx.int32), log_bucket)

    def compute_bias(self, q_len, k_len):
        """Return the position bias shaped ``[1, H, Q, K]``, or ``None``.

        ``None`` is returned when this layer has no bias embedding.
        """
        if not self.has_relative_attention_bias:
            return None
        query_pos = mx.arange(q_len, dtype=mx.int32).reshape(-1, 1).astype(mx.float32)
        key_pos = mx.arange(k_len, dtype=mx.int32).reshape(1, -1).astype(mx.float32)
        relative = key_pos - query_pos
        bucket_ids = self._relative_position_bucket(
            relative,
            bidirectional=True,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        looked_up = self.relative_attention_bias(bucket_ids.reshape(-1))
        per_head = looked_up.reshape(q_len, k_len, self.num_heads)
        return per_head.transpose(2, 0, 1)[None, :, :, :]  # [1, H, Q, K]

    def __call__(self, hidden, mask=None, position_bias=None):
        """Attend over ``hidden`` (``[B, T, d_model]``).

        ``mask`` is accepted for interface compatibility but is not used by
        this implementation. If ``position_bias`` is ``None`` the layer's own
        bias (if any) is computed.
        """
        B, T, _ = hidden.shape

        def heads(projection):
            return projection(hidden).reshape(B, T, self.num_heads, self.d_kv)

        q = heads(self.q).transpose(0, 2, 1, 3)  # [B, H, T, d]
        k = heads(self.k).transpose(0, 2, 3, 1)  # [B, H, d, T]
        v = heads(self.v).transpose(0, 2, 1, 3)  # [B, H, T, d]

        # T5: NO scaling by 1/sqrt(d_k)
        scores = q @ k

        if position_bias is None:
            position_bias = self.compute_bias(T, T)
        if position_bias is not None:
            scores = scores + position_bias

        # Softmax in float32, then back to the score dtype.
        weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        context = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
        return self.o(context)
126
+
127
+
128
+ # ── T5 Block ──
129
+
130
class T5Block(nn.Module):
    """One encoder layer: pre-norm self-attention, then pre-norm feed-forward.

    Both sub-layers are wrapped in a residual connection. ``config`` supplies
    ``d_model`` and ``d_ff``; ``layer_norm_epsilon`` (default 1e-6) and
    ``feed_forward_proj`` (default ``"relu"``) are optional.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        d_model = config["d_model"]
        eps = config.get("layer_norm_epsilon", 1e-6)
        self.self_attn = T5Attention(config, has_relative_attention_bias)
        self.layer_norm_sa = T5LayerNorm(d_model, eps)
        self.ff = T5DenseActDense(d_model, config["d_ff"], config.get("feed_forward_proj", "relu"))
        self.layer_norm_ff = T5LayerNorm(d_model, eps)

    def __call__(self, x, mask=None, position_bias=None):
        # Attention sub-layer with residual.
        x = x + self.self_attn(self.layer_norm_sa(x), mask=mask, position_bias=position_bias)
        # Feed-forward sub-layer with residual.
        return x + self.ff(self.layer_norm_ff(x))
146
+
147
+
148
+ # ── T5 Encoder ──
149
+
150
class T5Encoder(nn.Module):
    """Token embedding, a stack of ``T5Block`` layers, and a final norm.

    Only block 0 is built with a relative-attention-bias table; the bias it
    produces is passed to every block in the stack.
    """

    def __init__(self, config):
        super().__init__()
        self.shared = nn.Embedding(config["vocab_size"], config["d_model"])
        layers = []
        for index in range(config["num_layers"]):
            layers.append(T5Block(config, has_relative_attention_bias=(index == 0)))
        self.blocks = layers
        self.final_layer_norm = T5LayerNorm(config["d_model"], config.get("layer_norm_epsilon", 1e-6))

    def __call__(self, input_ids):
        hidden = self.shared(input_ids)
        # Compute the position bias once from the first block and reuse it.
        seq_len = hidden.shape[1]
        shared_bias = self.blocks[0].self_attn.compute_bias(seq_len, seq_len)
        for layer in self.blocks:
            hidden = layer(hidden, position_bias=shared_bias)
        return self.final_layer_norm(hidden)
165
+
166
+
167
def load_and_remap_weights(t5_dir):
    """Load every ``*.safetensors`` file in ``t5_dir`` and rename the keys.

    Args:
        t5_dir: ``pathlib.Path`` of the directory holding the T5 weights.

    Returns:
        ``(remapped, output_proj_w, output_proj_b)`` where ``remapped`` maps
        key names matching the ``T5Encoder`` module tree in this file to
        arrays, and the two projection entries are the ``output_proj.weight``
        / ``output_proj.bias`` arrays (``None`` when absent).

    Both sub-layer spellings are accepted: HuggingFace's ``.layer.0.`` /
    ``.layer.1.`` and the MLX-sanitized ``.layer_0.`` / ``.layer_1.`` that
    ``extract_t5.py`` writes. Handling only the dotted form would leave the
    sanitized block keys unmapped.
    """
    import glob

    all_weights = {}
    for path in sorted(glob.glob(str(t5_dir / "*.safetensors"))):
        all_weights.update(mx.load(path))

    # Separate output_proj from T5 weights
    output_proj_w = all_weights.pop("output_proj.weight", None)
    output_proj_b = all_weights.pop("output_proj.bias", None)

    # (sub-layer suffix after ".layer<sep>", replacement in our module tree)
    sublayer_renames = (
        ("0.SelfAttention.", ".self_attn."),
        ("0.layer_norm.", ".layer_norm_sa."),
        ("1.DenseReluDense.", ".ff."),
        ("1.layer_norm.", ".layer_norm_ff."),
    )

    remapped = {}
    for key, val in all_weights.items():
        # shared.weight is already named correctly.
        # encoder.block.N.<...> → blocks.N.<...>
        new_key = key.replace("encoder.block.", "blocks.")

        # "." is the HuggingFace separator, "_" the MLX-sanitized one.
        for sep in (".", "_"):
            for suffix, target in sublayer_renames:
                new_key = new_key.replace(f".layer{sep}{suffix}", target)

        # encoder.final_layer_norm → final_layer_norm
        new_key = new_key.replace("encoder.final_layer_norm.", "final_layer_norm.")

        remapped[new_key] = val

    return remapped, output_proj_w, output_proj_b
201
+
202
+
203
def main():
    """Run the reference T5 encoder on fixed prompts and print output stats.

    Reads ``config.json`` and the safetensors weights from ``MODEL_DIR``,
    builds the encoder (and the output projection when its weight is present),
    then prints min/max/sum statistics for each hard-coded test prompt so
    they can be compared with the Swift implementation's logs.

    The model directory is taken from the module-level ``MODEL_DIR`` only, so
    the function does not depend on a global that exists solely when the file
    is executed as a script.
    """
    print("=" * 60)
    print("T5 Encoder Verification (MLX Python reference)")
    print("=" * 60)

    t5_dir_p = MODEL_DIR

    # Load config
    with open(t5_dir_p / "config.json") as f:
        config = json.load(f)

    print(f"Config: d_model={config['d_model']} layers={config['num_layers']} "
          f"heads={config['num_heads']} d_kv={config['d_kv']} d_ff={config['d_ff']}")

    # Build model
    encoder = T5Encoder(config)

    # Load weights
    weights, proj_w, proj_b = load_and_remap_weights(t5_dir_p)

    # Apply weights
    encoder.load_weights(list(weights.items()))

    # Build output_proj
    output_proj = None
    if proj_w is not None:
        # proj_w is laid out as [out_features, in_features].
        output_proj = nn.Linear(proj_w.shape[1], proj_w.shape[0])
        proj_params = [("weight", proj_w)]
        if proj_b is not None:
            proj_params.append(("bias", proj_b))
        output_proj.load_weights(proj_params)
        print(f"output_proj: {proj_w.shape[1]} → {proj_w.shape[0]}")

    # Test prompts with known token IDs from Swift logs
    test_cases = [
        ("dog barking", [1782, 21696, 53, 1]),
        ("cars in the street", [2948, 16, 8, 2815, 1]),
        ("A metro train leaving the platform", [71, 12810, 2412, 3140, 8, 1585, 1]),
    ]

    for prompt, token_ids in test_cases:
        print(f"\n--- '{prompt}' ---")
        print(f"Tokens: {token_ids}")

        input_ids = mx.array([token_ids], dtype=mx.int32)
        features = encoder(input_ids)
        mx.eval(features)

        print(f"Encoder output: shape={features.shape} "
              f"min={features.min().item():.7f} max={features.max().item():.7f} "
              f"sum={features.sum().item():.4f}")

        # Per-position stats
        for i in range(features.shape[1]):
            pos_feat = features[0, i]
            print(f" pos[{i}]: min={pos_feat.min().item():.5f} "
                  f"max={pos_feat.max().item():.5f} "
                  f"mean={pos_feat.mean().item():.5f}")

        if output_proj is not None:
            projected = output_proj(features)
            mx.eval(projected)
            print(f"After output_proj: shape={projected.shape} "
                  f"min={projected.min().item():.7f} max={projected.max().item():.7f} "
                  f"sum={projected.sum().item():.4f}")
267
+
268
+
269
+ if __name__ == "__main__":
270
+ t5_dir = MODEL_DIR
271
+ if not t5_dir.exists():
272
+ print(f"T5 directory not found: {t5_dir}")
273
+ exit(1)
274
+ main()