File size: 8,907 Bytes
387ced5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
#!/usr/bin/env python3
"""
Extract T5 conditioner weights from facebook/audiogen-medium for MLX.

The original AudioGen model bundles a frozen T5 text encoder and a trained
output projection inside condition_provider.*. The main MLX conversion strips
these keys. This script extracts them into a t5/ subdirectory that the MLX
AudioGen loader expects.

Usage:
    # Automatic: downloads from HuggingFace (kept in the HF cache), then extracts
    python extract_t5.py --output /path/to/audiogen-mlx/t5

    # Manual: use a local state_dict.bin you already downloaded
    python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5

Output (in --output directory):
    config.json             T5 encoder config (derived from weight shapes)
    model.safetensors       T5 encoder weights + output_proj
    tokenizer.json          Downloaded from google-t5/t5-small
    tokenizer_config.json   Downloaded from google-t5/t5-small

Requirements:
    pip install torch safetensors huggingface_hub
"""

import argparse
import json
import os
import re
import shutil
import struct
import tempfile

import torch
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download


# Both tensors we extract live under the "description" text conditioner inside
# the AudioGen LM checkpoint's condition_provider.
_DESCRIPTION_PREFIX = "condition_provider.conditioners.description."
# T5 encoder weights; stripping this prefix yields standard HuggingFace T5 keys.
T5_PREFIX = _DESCRIPTION_PREFIX + "model."
# Projection from the T5 hidden size to the LM dimension.
OUTPUT_PROJ_PREFIX = _DESCRIPTION_PREFIX + "output_proj."


def load_lm_state(path):
    """Return the LM state dict stored in the PyTorch checkpoint at *path*.

    The file is loaded onto the CPU with ``weights_only=True`` (no arbitrary
    unpickling). If the loaded mapping nests its weights under a
    ``"best_state"`` entry, that inner mapping is returned; otherwise the
    loaded object itself is treated as the state dict.
    """
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    return checkpoint["best_state"] if "best_state" in checkpoint else checkpoint


def extract_t5_weights(lm_state):
    """Split the text-conditioner tensors out of a full LM state dict.

    Args:
        lm_state: Mapping of parameter name -> tensor for the whole LM.

    Returns:
        Tuple ``(t5_weights, output_proj, other_cp)``:

        * ``t5_weights``: tensors whose key starts with ``T5_PREFIX``,
          re-keyed with that prefix removed.
        * ``output_proj``: tensors whose key starts with
          ``OUTPUT_PROJ_PREFIX``, re-keyed as ``"output_proj.<suffix>"``.
        * ``other_cp``: names of every other ``condition_provider.*`` key,
          in iteration order; these are not extracted.

        Keys outside ``condition_provider.`` are ignored entirely.
    """
    t5_weights, output_proj, other_cp = {}, {}, []
    t5_cut = len(T5_PREFIX)
    proj_cut = len(OUTPUT_PROJ_PREFIX)

    for name, tensor in lm_state.items():
        if not name.startswith("condition_provider."):
            continue
        if name.startswith(T5_PREFIX):
            # Remaining suffix is already in HuggingFace T5 key format.
            t5_weights[name[t5_cut:]] = tensor
        elif name.startswith(OUTPUT_PROJ_PREFIX):
            output_proj["output_proj." + name[proj_cut:]] = tensor
        else:
            other_cp.append(name)

    return t5_weights, output_proj, other_cp


def sanitize_keys_for_mlx(weights):
    """Rename T5 weight keys for MLX compatibility.

    HuggingFace T5 uses keys like "encoder.block.0.layer.0.SelfAttention.q.weight",
    where "layer.<N>" names a sub-module of each block. MLX's
    ModuleParameters.unflattened() splits on ALL dots, so "layer.0" would be
    misparsed as a nested {"layer": {"0": ...}} instead of a single name.

    Every ".layer.<N>." segment (any non-negative integer N) is therefore
    rewritten to ".layer_<N>.". Other key segments are left untouched, and the
    input's key order is preserved.

    Args:
        weights: Mapping of HuggingFace-style parameter names to tensors.

    Returns:
        A new dict with the same values under sanitized keys.
    """
    # The lookahead keeps the trailing dot unconsumed, so back-to-back segments
    # such as ".layer.0.layer.1." are both rewritten.
    sublayer = re.compile(r"\.layer\.(\d+)(?=\.)")
    return {sublayer.sub(r".layer_\1", key): value for key, value in weights.items()}


def infer_t5_config(t5_weights):
    """Derive a HuggingFace-style T5 encoder config from weight shapes.

    Args:
        t5_weights: Mapping of HuggingFace T5 parameter names (un-sanitized,
            i.e. still using ".layer.N.") to tensors. Only ``.shape`` is read.

    Returns:
        A config dict suitable for writing out as ``config.json``.

    Raises:
        ValueError: If a required weight is missing, or the query projection
            cannot be split evenly into the inferred number of heads.
    """
    shared = t5_weights.get("shared.weight")
    if shared is None:
        raise ValueError("Cannot find shared.weight in T5 weights")
    # shared.weight: [vocab_size, d_model]
    vocab_size = shared.shape[0]
    d_model = shared.shape[1]

    q_weight = t5_weights.get("encoder.block.0.layer.0.SelfAttention.q.weight")
    if q_weight is None:
        raise ValueError("Cannot find SelfAttention.q.weight")
    # q.weight: [num_heads * d_kv, d_model]
    inner_dim = q_weight.shape[0]

    wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
    if wi is None:
        raise ValueError("Cannot find DenseReluDense.wi.weight")
    # wi.weight: [d_ff, d_model]
    d_ff = wi.shape[0]

    # Count encoder blocks by probing consecutive indices.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
        num_layers += 1

    # relative_attention_bias.weight: [num_buckets, num_heads]. This is the one
    # tensor whose shape fixes the head count directly. d_kv is not always 64
    # (t5-3b / t5-11b use 128), so it must be derived, not assumed.
    rab = t5_weights.get(
        "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
    )
    if rab is not None:
        num_buckets = rab.shape[0]
        num_heads = rab.shape[1]
    else:
        # No bias table to read from: fall back to the common d_kv of 64.
        num_buckets = 32
        num_heads = inner_dim // 64

    if num_heads <= 0 or inner_dim % num_heads:
        raise ValueError(
            f"Cannot split SelfAttention.q (size {inner_dim}) into "
            f"{num_heads} attention heads"
        )
    d_kv = inner_dim // num_heads

    # Original T5 release sizes. d_model alone is ambiguous (large, 3b and 11b
    # all use 1024), so the feed-forward width is part of the key.
    known_variants = {
        (512, 2048): "t5-small",
        (768, 3072): "t5-base",
        (1024, 4096): "t5-large",
        (1024, 16384): "t5-3b",
        (1024, 65536): "t5-11b",
    }
    t5_variant = known_variants.get((d_model, d_ff), "t5-unknown")

    config = {
        "architectures": ["T5EncoderModel"],
        "model_name": t5_variant,
        "d_model": d_model,
        "d_kv": d_kv,
        "d_ff": d_ff,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "vocab_size": vocab_size,
        "relative_attention_num_buckets": num_buckets,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.0,
        "layer_norm_epsilon": 1e-6,
        "feed_forward_proj": "relu",
        "tie_word_embeddings": True,
        "decoder_start_token_id": 0,
        "model_type": "t5",
    }
    return config


def download_tokenizer(output_dir):
    """Copy the T5 tokenizer files into *output_dir*.

    The original T5 checkpoints share one SentencePiece vocabulary, so the
    files are always taken from google-t5/t5-small, whatever the encoder size.
    Each file is fetched into the HuggingFace cache and then copied (with
    file metadata) into *output_dir*.
    """
    source_repo = "google-t5/t5-small"
    wanted = ("tokenizer.json", "tokenizer_config.json")
    for name in wanted:
        cached = hf_hub_download(repo_id=source_repo, filename=name)
        shutil.copy2(cached, os.path.join(output_dir, name))
        print(f"  Copied {name}")


def main():
    """Command-line entry point: extract the T5 conditioner into ``--output``.

    Steps:

    1. Obtain ``state_dict.bin``, either from ``--lm`` or via the
       HuggingFace cache.
    2. Split out the T5 encoder and ``output_proj`` tensors.
    3. Infer a T5 config from their shapes.
    4. Write ``model.safetensors``, ``config.json`` and the tokenizer files.

    Raises:
        SystemExit: With a non-zero status if the checkpoint contains no
            T5 weights.
    """
    parser = argparse.ArgumentParser(
        description="Extract T5 conditioner from facebook/audiogen-medium"
    )
    parser.add_argument(
        "--lm",
        help="Path to local state_dict.bin (skips download)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for T5 weights (e.g. /path/to/model/t5)",
    )
    args = parser.parse_args()

    # Create the output directory first so a bad --output path fails before
    # the (large) checkpoint is downloaded and loaded.
    os.makedirs(args.output, exist_ok=True)

    # Get the state dict
    if args.lm:
        lm_path = args.lm
        print(f"Loading local checkpoint: {lm_path}")
    else:
        print("Downloading facebook/audiogen-medium state_dict.bin ...")
        lm_path = hf_hub_download(
            repo_id="facebook/audiogen-medium",
            filename="state_dict.bin",
        )
        print(f"  Downloaded to cache: {lm_path}")

    print("Loading state dict ...")
    lm_state = load_lm_state(lm_path)

    print("Extracting T5 weights ...")
    t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)
    # The extracted dicts keep their own references to the tensors they need;
    # dropping the full LM state dict releases the rest before saving.
    del lm_state

    print(f"  T5 encoder keys: {len(t5_weights)}")
    print(f"  Output projection keys: {len(output_proj)}")
    if other_cp:
        print(f"  Other condition_provider keys (skipped): {len(other_cp)}")

    if not t5_weights:
        # SystemExit(str) prints to stderr and exits with status 1, so shell
        # scripts and CI notice the failure (a bare return would exit 0).
        raise SystemExit("ERROR: No T5 weights found in checkpoint!")

    # Infer T5 architecture
    config = infer_t5_config(t5_weights)
    print(f"  T5 config: {config['model_name']} — d_model={config['d_model']}, "
          f"num_heads={config['num_heads']}, "
          f"num_layers={config['num_layers']}, "
          f"d_ff={config['d_ff']}, "
          f"vocab_size={config['vocab_size']}")

    if output_proj:
        proj_w = output_proj.get("output_proj.weight")
        if proj_w is not None:
            print(f"  Output projection: {list(proj_w.shape)} "
                  f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")
    else:
        print("  WARNING: no output_proj weights found; the saved model will "
              "contain the T5 encoder only.")

    # Sanitize keys for MLX compatibility before saving
    sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
    print(f"  Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")

    # Combine sanitized T5 weights + output_proj into one safetensors file
    all_weights = {**sanitized_t5, **output_proj}

    # Save safetensors
    st_path = os.path.join(args.output, "model.safetensors")
    print(f"Saving {len(all_weights)} tensors to {st_path} ...")
    save_file(all_weights, st_path)

    total_bytes = os.path.getsize(st_path)
    print(f"  Size: {total_bytes / 1e6:.1f} MB")

    # Save config
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print("Saved config.json")

    # Download tokenizer
    print("Downloading T5 tokenizer ...")
    download_tokenizer(args.output)

    print(f"\nDone! T5 conditioner saved to: {args.output}")
    print("Files:", sorted(os.listdir(args.output)))


# Run the extractor only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()