#!/usr/bin/env python3
"""
Extract T5 conditioner weights from facebook/audiogen-medium for MLX.

The original AudioGen model bundles a frozen T5 text encoder and a trained
output projection inside condition_provider.*. The main MLX conversion strips
these keys. This script extracts them into a t5/ subdirectory that the MLX
AudioGen loader expects.

Usage:
    # Automatic: downloads the checkpoint from HuggingFace and extracts it.
    # The downloaded state_dict.bin is left in the local HuggingFace cache;
    # this script does not delete it.
    python extract_t5.py --output /path/to/audiogen-mlx/t5

    # Manual: use a local state_dict.bin you already downloaded
    python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5

Output (in --output directory):
    config.json            T5 encoder config (derived from weight shapes)
    model.safetensors      T5 encoder weights + output_proj
    tokenizer.json         Downloaded from google-t5/t5-small
    tokenizer_config.json  Downloaded from google-t5/t5-small

Requirements:
    pip install torch safetensors huggingface_hub
"""

import argparse
import json
import os
import struct
import tempfile
import shutil

import torch
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download

T5_PREFIX = "condition_provider.conditioners.description.model."
OUTPUT_PROJ_PREFIX = "condition_provider.conditioners.description.output_proj."
def load_lm_state(path):
    """Load the LM state dict from a PyTorch checkpoint.

    Args:
        path: Filesystem path of the checkpoint (e.g. ``state_dict.bin``).

    Returns:
        The mapping stored under ``"best_state"`` when the checkpoint has that
        key, otherwise the loaded checkpoint object itself.
    """
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is the safe mode for a checkpoint fetched from the network.
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    if "best_state" in ckpt:
        return ckpt["best_state"]
    return ckpt


def extract_t5_weights(lm_state):
    """Split the ``condition_provider.*`` entries of an LM state dict.

    Args:
        lm_state: Mapping of parameter name to tensor.

    Returns:
        A ``(t5_weights, output_proj, other_cp)`` tuple where

        * ``t5_weights`` maps keys with ``T5_PREFIX`` stripped to tensors,
        * ``output_proj`` maps ``"output_proj.<suffix>"`` to tensors, and
        * ``other_cp`` lists the remaining ``condition_provider.*`` keys that
          matched neither prefix (they are reported, not exported).
    """
    t5_weights = {}
    output_proj = {}
    other_cp = []

    for key, tensor in lm_state.items():
        if not key.startswith("condition_provider."):
            continue

        if key.startswith(T5_PREFIX):
            # Strip the prefix to get the standard HuggingFace T5 key format.
            t5_weights[key[len(T5_PREFIX):]] = tensor
        elif key.startswith(OUTPUT_PROJ_PREFIX):
            # Becomes output_proj.weight / output_proj.bias.
            suffix = key[len(OUTPUT_PROJ_PREFIX):]
            output_proj[f"output_proj.{suffix}"] = tensor
        else:
            other_cp.append(key)

    return t5_weights, output_proj, other_cp


def sanitize_keys_for_mlx(weights):
    """Rename T5 weight keys for MLX compatibility.

    HuggingFace T5 uses keys like
    "encoder.block.0.layer.0.SelfAttention.q.weight" where "layer.0" and
    "layer.1" are sub-module names. MLX's ModuleParameters.unflattened()
    splits on ALL dots, which misparses "layer.0" as {"layer": {"0": ...}}
    instead of treating it as a single key.

    This renames ".layer.0." to ".layer_0." and ".layer.1." to ".layer_1."
    so the keys work correctly with MLX's parameter loading.

    Args:
        weights: Mapping of HuggingFace-style key to value.

    Returns:
        A new dict with renamed keys; values are passed through untouched and
        the input mapping is not modified.
    """
    return {
        key.replace(".layer.0.", ".layer_0.").replace(".layer.1.", ".layer_1."): value
        for key, value in weights.items()
    }


# (d_model, d_ff) of the original T5 releases. t5-large, t5-3b and t5-11b all
# use d_model=1024, so d_model alone cannot tell them apart.
_T5_VARIANTS_BY_DIMS = {
    (512, 2048): "t5-small",
    (768, 3072): "t5-base",
    (1024, 4096): "t5-large",
    (1024, 16384): "t5-3b",
    (1024, 65536): "t5-11b",
}

# Fallback naming by d_model only, used when d_ff is non-standard.
_T5_VARIANTS_BY_D_MODEL = {
    512: "t5-small",
    768: "t5-base",
    1024: "t5-large",
}


def infer_t5_config(t5_weights):
    """Determine the T5 encoder architecture from weight shapes.

    Args:
        t5_weights: Mapping of *unsanitized* HuggingFace T5 keys (with
            ``.layer.0.`` / ``.layer.1.``) to objects exposing ``.shape``.

    Returns:
        A JSON-serializable dict describing the encoder (``d_model``,
        ``d_kv``, ``d_ff``, ``num_heads``, ``num_layers``, ``vocab_size``,
        relative-attention settings and fixed T5 defaults).

    Raises:
        ValueError: If ``shared.weight``, the first block's
            ``SelfAttention.q.weight`` or ``DenseReluDense.wi.weight`` is
            missing.
    """
    # shared.weight: [vocab_size, d_model]
    shared = t5_weights.get("shared.weight")
    if shared is None:
        raise ValueError("Cannot find shared.weight in T5 weights")
    vocab_size = int(shared.shape[0])
    d_model = int(shared.shape[1])

    # q.weight: [num_heads * d_kv, d_model]
    q_weight = t5_weights.get("encoder.block.0.layer.0.SelfAttention.q.weight")
    if q_weight is None:
        raise ValueError("Cannot find SelfAttention.q.weight")
    total_kv = int(q_weight.shape[0])

    # DenseReluDense.wi.weight: [d_ff, d_model]
    wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
    if wi is None:
        raise ValueError("Cannot find DenseReluDense.wi.weight")
    d_ff = int(wi.shape[0])

    # Count consecutive encoder blocks starting from block 0.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
        num_layers += 1

    # relative_attention_bias is an embedding of shape
    # [num_buckets, num_heads], so it gives the head count directly. That
    # matters for t5-3b / t5-11b, whose d_kv is 128 rather than 64.
    rab = t5_weights.get(
        "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
    )
    num_buckets = int(rab.shape[0]) if rab is not None else 32
    heads_from_bias = int(rab.shape[1]) if rab is not None else 0

    if heads_from_bias > 0 and total_kv % heads_from_bias == 0:
        num_heads = heads_from_bias
        d_kv = total_kv // num_heads
    else:
        # No usable bias table: fall back to the d_kv=64 used by
        # t5-small/base/large.
        d_kv = 64
        num_heads = total_kv // d_kv

    t5_variant = _T5_VARIANTS_BY_DIMS.get(
        (d_model, d_ff),
        _T5_VARIANTS_BY_D_MODEL.get(d_model, "t5-unknown"),
    )

    config = {
        "architectures": ["T5EncoderModel"],
        "model_name": t5_variant,
        "d_model": d_model,
        "d_kv": d_kv,
        "d_ff": d_ff,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "vocab_size": vocab_size,
        "relative_attention_num_buckets": num_buckets,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.0,
        "layer_norm_epsilon": 1e-6,
        "feed_forward_proj": "relu",
        "tie_word_embeddings": True,
        "decoder_start_token_id": 0,
        "model_type": "t5",
    }
    return config


def download_tokenizer(output_dir):
    """Download T5 tokenizer files from HuggingFace into ``output_dir``.

    All T5 model sizes share the same SentencePiece tokenizer (32128 tokens),
    so we download from t5-small for convenience.

    Args:
        output_dir: Existing directory that receives ``tokenizer.json`` and
            ``tokenizer_config.json``.
    """
    repo = "google-t5/t5-small"
    for filename in ["tokenizer.json", "tokenizer_config.json"]:
        # hf_hub_download returns the path of the cached file; copy it so the
        # output directory is self-contained.
        path = hf_hub_download(repo_id=repo, filename=filename)
        dst = os.path.join(output_dir, filename)
        shutil.copy2(path, dst)
        print(f"  Copied {filename}")


def _prepare_for_safetensors(weights):
    """Return a copy of ``weights`` whose tensors are contiguous and unshared."""
    # safetensors' save_file rejects tensors that share storage (tied T5
    # embeddings such as shared.weight / encoder.embed_tokens.weight may alias
    # each other) and tensors that are not contiguous. Giving every tensor
    # its own contiguous storage avoids both errors.
    return {name: tensor.contiguous().clone() for name, tensor in weights.items()}


def main():
    """Command-line entry point: extract the T5 conditioner into ``--output``.

    Loads the AudioGen LM checkpoint (local ``--lm`` path, or downloaded from
    ``facebook/audiogen-medium``), extracts the T5 encoder and output
    projection, then writes ``model.safetensors``, ``config.json`` and the
    tokenizer files to the output directory.

    Raises:
        SystemExit: With status 1 when the checkpoint has no T5 weights.
    """
    parser = argparse.ArgumentParser(
        description="Extract T5 conditioner from facebook/audiogen-medium"
    )
    parser.add_argument(
        "--lm",
        help="Path to local state_dict.bin (skips download)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for T5 weights (e.g. /path/to/model/t5)",
    )
    args = parser.parse_args()

    # Create the output directory first so an unwritable path fails before
    # the (large) checkpoint download.
    os.makedirs(args.output, exist_ok=True)

    # Get the state dict
    if args.lm:
        lm_path = args.lm
        print(f"Loading local checkpoint: {lm_path}")
    else:
        print("Downloading facebook/audiogen-medium state_dict.bin ...")
        lm_path = hf_hub_download(
            repo_id="facebook/audiogen-medium",
            filename="state_dict.bin",
        )
        print(f"  Downloaded to cache: {lm_path}")

    print("Loading state dict ...")
    lm_state = load_lm_state(lm_path)

    print("Extracting T5 weights ...")
    t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)
    # Drop the reference to the full LM state dict; only the extracted
    # tensors are needed from here on.
    del lm_state

    print(f"  T5 encoder keys: {len(t5_weights)}")
    print(f"  Output projection keys: {len(output_proj)}")
    if other_cp:
        print(f"  Other condition_provider keys (skipped): {len(other_cp)}")

    if not t5_weights:
        print("ERROR: No T5 weights found in checkpoint!")
        # Exit non-zero so shell pipelines and CI notice the failure.
        raise SystemExit(1)

    # Infer T5 architecture (must run on the unsanitized key names).
    config = infer_t5_config(t5_weights)
    print(f"  T5 config: {config['model_name']} — d_model={config['d_model']}, "
          f"num_heads={config['num_heads']}, "
          f"num_layers={config['num_layers']}, "
          f"d_ff={config['d_ff']}, "
          f"vocab_size={config['vocab_size']}")

    if output_proj:
        proj_w = output_proj.get("output_proj.weight")
        if proj_w is not None:
            print(f"  Output projection: {list(proj_w.shape)} "
                  f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")

    # Sanitize keys for MLX compatibility before saving
    sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
    print(f"  Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")

    # Combine sanitized T5 weights + output_proj into one safetensors file.
    all_weights = _prepare_for_safetensors({**sanitized_t5, **output_proj})

    # Save safetensors
    st_path = os.path.join(args.output, "model.safetensors")
    print(f"Saving {len(all_weights)} tensors to {st_path} ...")
    save_file(all_weights, st_path)
    total_bytes = os.path.getsize(st_path)
    print(f"  Size: {total_bytes / 1e6:.1f} MB")

    # Save config
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print("Saved config.json")

    # Download tokenizer
    print("Downloading T5 tokenizer ...")
    download_tokenizer(args.output)

    print(f"\nDone! T5 conditioner saved to: {args.output}")
    print("Files:", sorted(os.listdir(args.output)))


if __name__ == "__main__":
    main()