audiogen-medium-mlx / extract_t5.py
ClementDuhamel's picture
fix: critical T5 conditioner key sanitization and metadata
387ced5 verified
#!/usr/bin/env python3
"""
Extract T5 conditioner weights from facebook/audiogen-medium for MLX.
The original AudioGen model bundles a frozen T5 text encoder and a trained
output projection inside condition_provider.*. The main MLX conversion strips
these keys. This script extracts them into a t5/ subdirectory that the MLX
AudioGen loader expects.
Usage:
# Automatic: downloads from HuggingFace (into the local HF cache) and extracts
python extract_t5.py --output /path/to/audiogen-mlx/t5
# Manual: use a local state_dict.bin you already downloaded
python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5
Output (in --output directory):
config.json T5 encoder config (derived from weight shapes)
model.safetensors T5 encoder weights + output_proj
tokenizer.json Downloaded from google-t5/t5-small
tokenizer_config.json Downloaded from google-t5/t5-small
Requirements:
pip install torch safetensors huggingface_hub
"""
import argparse
import json
import os
import struct
import tempfile
import shutil
import torch
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download
# Both key families live under the text-description conditioner of the
# AudioGen LM checkpoint; share the common part so the two stay in sync.
_DESCRIPTION_CONDITIONER = "condition_provider.conditioners.description."
# T5 encoder weights; this prefix is stripped to recover HuggingFace T5 keys.
T5_PREFIX = _DESCRIPTION_CONDITIONER + "model."
# Trained projection from T5 d_model to the LM dimension (weight/bias).
OUTPUT_PROJ_PREFIX = _DESCRIPTION_CONDITIONER + "output_proj."
def load_lm_state(path):
    """Return the LM state dict stored in the PyTorch checkpoint at *path*.

    The checkpoint is loaded onto the CPU with ``weights_only=True``. When it
    wraps the weights under a ``"best_state"`` entry, that inner mapping is
    returned; otherwise the loaded object itself is treated as the state dict.
    """
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    return checkpoint["best_state"] if "best_state" in checkpoint else checkpoint
def extract_t5_weights(lm_state):
    """Split the conditioner weights out of an AudioGen LM state dict.

    Returns a 3-tuple:
        * T5 encoder tensors, keyed with ``T5_PREFIX`` removed (standard
          HuggingFace T5 key format);
        * output projection tensors, keyed ``"output_proj.<rest>"``;
        * the names of any other ``condition_provider.*`` keys, which are
          not extracted.
    Keys outside ``condition_provider.`` are ignored entirely.
    """
    encoder = {}
    projection = {}
    skipped = []
    for name, value in lm_state.items():
        if name.startswith(T5_PREFIX):
            encoder[name[len(T5_PREFIX):]] = value
        elif name.startswith(OUTPUT_PROJ_PREFIX):
            # Re-root under "output_proj." (e.g. output_proj.weight / .bias).
            projection["output_proj." + name[len(OUTPUT_PROJ_PREFIX):]] = value
        elif name.startswith("condition_provider."):
            # Both prefixes above already begin with "condition_provider.",
            # so this only catches the remaining conditioner keys.
            skipped.append(name)
    return encoder, projection, skipped
def sanitize_keys_for_mlx(weights):
    """Rename T5 weight keys for MLX compatibility.

    HuggingFace T5 uses keys like "encoder.block.0.layer.0.SelfAttention.q.weight"
    where "layer.N" names the N-th sub-module of a block (0 = self-attention,
    1 = feed-forward in the encoder; decoders also have a 2). Per the MLX
    loader's expectations, "layer.N" must be a single name rather than being
    split on the dot by MLX's parameter unflattening.

    Every ".layer.N." segment, for any index N, is rewritten to ".layer_N."
    (e.g. ".layer.0." -> ".layer_0."). Other keys, such as "block.N" or
    "final_layer_norm", are left unchanged.

    Args:
        weights: mapping of T5 key -> tensor (values are passed through).

    Returns:
        A new dict with the renamed keys. Insertion order is preserved.
    """
    import re

    # The lookahead leaves the trailing dot unconsumed, so back-to-back
    # segments such as "x.layer.0.layer.1.y" are both rewritten.
    layer_segment = re.compile(r"\.layer\.(\d+)(?=\.)")
    return {
        layer_segment.sub(r".layer_\1", key): value
        for key, value in weights.items()
    }
def infer_t5_config(t5_weights):
    """Determine the T5 encoder architecture from weight shapes.

    Args:
        t5_weights: mapping of HuggingFace-style T5 keys (conditioner prefix
            stripped, ".layer.N." not yet sanitized) to tensors. Only each
            entry's ``.shape`` is read.

    Returns:
        A HuggingFace ``T5Config``-style dict suitable for ``config.json``.

    Raises:
        ValueError: if a required weight is missing, or if the attention
            projection width is not divisible by the inferred head count.
    """
    # shared.weight: [vocab_size, d_model]
    shared = t5_weights.get("shared.weight")
    if shared is None:
        raise ValueError("Cannot find shared.weight in T5 weights")
    vocab_size = shared.shape[0]
    d_model = shared.shape[1]

    # q.weight: [num_heads * d_kv, d_model]
    q_weight = t5_weights.get("encoder.block.0.layer.0.SelfAttention.q.weight")
    if q_weight is None:
        raise ValueError("Cannot find SelfAttention.q.weight")
    inner_dim = q_weight.shape[0]

    # DenseReluDense.wi.weight: [d_ff, d_model]
    wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
    if wi is None:
        raise ValueError("Cannot find DenseReluDense.wi.weight")
    d_ff = wi.shape[0]

    # Count encoder blocks by probing consecutive indices.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
        num_layers += 1

    # The q projection alone cannot separate num_heads from d_kv. HuggingFace's
    # T5Attention defines relative_attention_bias as
    # nn.Embedding(num_buckets, n_heads), so its weight is [num_buckets,
    # num_heads]. That pins the head count, and with it d_kv. This matters for
    # checkpoints with d_kv=128 (e.g. t5-3b).
    rab = t5_weights.get(
        "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
    )
    if rab is not None:
        num_buckets = rab.shape[0]
        num_heads = rab.shape[1]
        if num_heads <= 0 or inner_dim % num_heads:
            raise ValueError(
                f"Attention width {inner_dim} is not divisible by "
                f"{num_heads} heads"
            )
        d_kv = inner_dim // num_heads
    else:
        # No bias table to read from: fall back to the common T5 defaults
        # (d_kv=64, 32 buckets).
        num_buckets = 32
        d_kv = 64
        num_heads = inner_dim // d_kv

    # d_model alone is ambiguous (t5-large, t5-3b and t5-11b all use 1024), so
    # identify the original T5 checkpoints by (d_model, d_ff).
    known_variants = {
        (512, 2048): "t5-small",
        (768, 3072): "t5-base",
        (1024, 4096): "t5-large",
        (1024, 16384): "t5-3b",
        (1024, 65536): "t5-11b",
    }
    t5_variant = known_variants.get((d_model, d_ff), "t5-unknown")

    config = {
        "architectures": ["T5EncoderModel"],
        "model_name": t5_variant,
        "d_model": d_model,
        "d_kv": d_kv,
        "d_ff": d_ff,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "vocab_size": vocab_size,
        "relative_attention_num_buckets": num_buckets,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.0,
        "layer_norm_epsilon": 1e-6,
        "feed_forward_proj": "relu",
        "tie_word_embeddings": True,
        "decoder_start_token_id": 0,
        "model_type": "t5",
    }
    return config
def download_tokenizer(output_dir):
    """Place the T5 tokenizer files from the Hugging Face Hub in *output_dir*.

    The original T5 checkpoints share one SentencePiece vocabulary, so the
    files are taken from the small ``google-t5/t5-small`` repo. Each file is
    resolved through the local HF cache and then copied, with metadata, into
    *output_dir*.
    """
    source_repo = "google-t5/t5-small"
    wanted_files = ("tokenizer.json", "tokenizer_config.json")
    for name in wanted_files:
        cached_path = hf_hub_download(repo_id=source_repo, filename=name)
        shutil.copy2(cached_path, os.path.join(output_dir, name))
        print(f" Copied {name}")
def main():
    """Command-line entry point: extract the T5 conditioner into ``--output``.

    Steps:
      1. Resolve the AudioGen checkpoint, either from ``--lm`` or by
         downloading it into the local HF cache.
      2. Split out the T5 encoder and output-projection weights.
      3. Infer a T5 config from the weight shapes.
      4. Write ``model.safetensors`` (MLX-sanitized keys), ``config.json``
         and the tokenizer files into ``--output``.

    Exits with a non-zero status if the checkpoint contains no T5 weights.
    """
    parser = argparse.ArgumentParser(
        description="Extract T5 conditioner from facebook/audiogen-medium"
    )
    parser.add_argument(
        "--lm",
        help="Path to local state_dict.bin (skips download)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for T5 weights (e.g. /path/to/model/t5)",
    )
    args = parser.parse_args()

    # Create the output directory up front so a bad path fails before the
    # (large) checkpoint download.
    os.makedirs(args.output, exist_ok=True)

    # Get the state dict
    if args.lm:
        lm_path = args.lm
        print(f"Loading local checkpoint: {lm_path}")
    else:
        print("Downloading facebook/audiogen-medium state_dict.bin ...")
        lm_path = hf_hub_download(
            repo_id="facebook/audiogen-medium",
            filename="state_dict.bin",
        )
        print(f" Downloaded to cache: {lm_path}")

    print("Loading state dict ...")
    lm_state = load_lm_state(lm_path)

    print("Extracting T5 weights ...")
    t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)
    # Only the extracted tensors are needed from here on. Drop the rest of
    # the LM state before the copies made below, to lower peak memory.
    del lm_state
    print(f" T5 encoder keys: {len(t5_weights)}")
    print(f" Output projection keys: {len(output_proj)}")
    if other_cp:
        print(f" Other condition_provider keys (skipped): {len(other_cp)}")

    if not t5_weights:
        # Exit non-zero: the previous print-and-return exited with status 0,
        # which hid the failure from shells and scripts.
        raise SystemExit("ERROR: No T5 weights found in checkpoint!")

    # Infer T5 architecture
    config = infer_t5_config(t5_weights)
    print(f" T5 config: {config['model_name']} — d_model={config['d_model']}, "
          f"num_heads={config['num_heads']}, "
          f"num_layers={config['num_layers']}, "
          f"d_ff={config['d_ff']}, "
          f"vocab_size={config['vocab_size']}")

    if output_proj:
        proj_w = output_proj.get("output_proj.weight")
        if proj_w is not None:
            print(f" Output projection: {list(proj_w.shape)} "
                  f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")

    # Sanitize keys for MLX compatibility before saving
    sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
    print(f" Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")

    # Combine the sanitized T5 weights and output_proj into one file.
    # safetensors' save_file rejects tensors that share storage (e.g. tied
    # embeddings, if the checkpoint ties them) and non-contiguous tensors, so
    # give every tensor its own contiguous buffer.
    all_weights = {
        name: tensor.detach().clone(memory_format=torch.contiguous_format)
        for name, tensor in {**sanitized_t5, **output_proj}.items()
    }

    # Save safetensors
    st_path = os.path.join(args.output, "model.safetensors")
    print(f"Saving {len(all_weights)} tensors to {st_path} ...")
    save_file(all_weights, st_path)
    total_bytes = os.path.getsize(st_path)
    print(f" Size: {total_bytes / 1e6:.1f} MB")

    # Save config
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print("Saved config.json")

    # Download tokenizer
    print("Downloading T5 tokenizer ...")
    download_tokenizer(args.output)

    print(f"\nDone! T5 conditioner saved to: {args.output}")
    print("Files:", sorted(os.listdir(args.output)))


if __name__ == "__main__":
    main()