Instructions to use mlx-community/audiogen-medium-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/audiogen-medium-mlx with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir audiogen-medium-mlx mlx-community/audiogen-medium-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| #!/usr/bin/env python3 | |
| """ | |
| Extract T5 conditioner weights from facebook/audiogen-medium for MLX. | |
| The original AudioGen model bundles a frozen T5 text encoder and a trained | |
| output projection inside condition_provider.*. The main MLX conversion strips | |
| these keys. This script extracts them into a t5/ subdirectory that the MLX | |
| AudioGen loader expects. | |
| Usage: | |
| # Automatic: downloads state_dict.bin from HuggingFace into the local hub cache, then extracts | |
| python extract_t5.py --output /path/to/audiogen-mlx/t5 | |
| # Manual: use a local state_dict.bin you already downloaded | |
| python extract_t5.py --lm /path/to/state_dict.bin --output /path/to/audiogen-mlx/t5 | |
| Output (in --output directory): | |
| config.json T5 encoder config (derived from weight shapes) | |
| model.safetensors T5 encoder weights + output_proj | |
| tokenizer.json Downloaded from google-t5/t5-small | |
| tokenizer_config.json Downloaded from google-t5/t5-small | |
| Requirements: | |
| pip install torch safetensors huggingface_hub | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import struct | |
| import tempfile | |
| import shutil | |
| import torch | |
| from safetensors.torch import save_file | |
| from huggingface_hub import hf_hub_download | |
# Key prefix of the bundled T5 text-encoder weights inside the LM state dict;
# it is stripped from matching keys during extraction.
T5_PREFIX = "condition_provider.conditioners.description.model."
# Key prefix of the conditioner's output projection weights; the remainder of
# each matching key is re-emitted under "output_proj." during extraction.
OUTPUT_PROJ_PREFIX = "condition_provider.conditioners.description.output_proj."
def load_lm_state(path):
    """Read a PyTorch checkpoint from ``path`` and return its LM state dict.

    The file is deserialized onto the CPU with ``weights_only=True``. If the
    loaded object contains a ``"best_state"`` entry, that entry is returned;
    otherwise the loaded object itself is returned as-is.
    """
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    return checkpoint["best_state"] if "best_state" in checkpoint else checkpoint
def extract_t5_weights(lm_state):
    """Split the ``condition_provider.*`` entries of ``lm_state`` into groups.

    Returns a tuple ``(t5_weights, output_proj, other_cp)``:

    * ``t5_weights`` -- tensors whose key starts with ``T5_PREFIX``, keyed by
      the part of the key that follows the prefix.
    * ``output_proj`` -- tensors whose key starts with ``OUTPUT_PROJ_PREFIX``,
      re-keyed as ``output_proj.<remainder>``.
    * ``other_cp`` -- list of every other ``condition_provider.*`` key.

    Keys that do not start with ``condition_provider.`` are ignored.
    """
    encoder_tensors = {}
    projection_tensors = {}
    unmatched_keys = []

    for name, value in lm_state.items():
        if name.startswith("condition_provider."):
            if name.startswith(T5_PREFIX):
                # Dropping the prefix leaves the standard HuggingFace T5 key.
                encoder_tensors[name[len(T5_PREFIX):]] = value
            elif name.startswith(OUTPUT_PROJ_PREFIX):
                # e.g. output_proj.weight / output_proj.bias
                suffix = name[len(OUTPUT_PROJ_PREFIX):]
                projection_tensors[f"output_proj.{suffix}"] = value
            else:
                unmatched_keys.append(name)

    return encoder_tensors, projection_tensors, unmatched_keys
def sanitize_keys_for_mlx(weights):
    """Return a copy of ``weights`` with T5 keys renamed for MLX loading.

    HuggingFace T5 keys look like
    "encoder.block.0.layer.0.SelfAttention.q.weight", where "layer.0" and
    "layer.1" name sub-modules. MLX's ModuleParameters.unflattened() splits on
    every dot and would read "layer.0" as {"layer": {"0": ...}} rather than as
    one key, so ".layer.0." becomes ".layer_0." and ".layer.1." becomes
    ".layer_1.". All other keys pass through unchanged; values are untouched.
    """
    renames = (
        (".layer.0.", ".layer_0."),
        (".layer.1.", ".layer_1."),
    )
    result = {}
    for name, tensor in weights.items():
        for old, new in renames:
            name = name.replace(old, new)
        result[name] = tensor
    return result
def infer_t5_config(t5_weights):
    """Derive a HuggingFace-style T5 encoder config from weight shapes.

    Args:
        t5_weights: Mapping from un-prefixed, un-sanitized HuggingFace T5 keys
            (e.g. ``encoder.block.0.layer.0.SelfAttention.q.weight``) to
            objects exposing ``.shape``.

    Returns:
        A dict suitable for writing out as ``config.json``.

    Raises:
        ValueError: If ``shared.weight``, the first block's
            ``SelfAttention.q.weight`` or the first block's
            ``DenseReluDense.wi.weight`` is missing.
    """
    attn0 = "encoder.block.0.layer.0.SelfAttention."

    # shared.weight: [vocab_size, d_model]
    shared = t5_weights.get("shared.weight")
    if shared is None:
        raise ValueError("Cannot find shared.weight in T5 weights")
    vocab_size = shared.shape[0]
    d_model = shared.shape[1]

    # q.weight: [num_heads * d_kv, d_model]
    q_weight = t5_weights.get(attn0 + "q.weight")
    if q_weight is None:
        raise ValueError("Cannot find SelfAttention.q.weight")
    inner_dim = q_weight.shape[0]

    # DenseReluDense.wi.weight: [d_ff, d_model]
    wi = t5_weights.get("encoder.block.0.layer.1.DenseReluDense.wi.weight")
    if wi is None:
        raise ValueError("Cannot find DenseReluDense.wi.weight")
    d_ff = wi.shape[0]

    # Encoder depth = number of consecutive blocks that have a q projection.
    num_layers = 0
    while f"encoder.block.{num_layers}.layer.0.SelfAttention.q.weight" in t5_weights:
        num_layers += 1

    # relative_attention_bias.weight: [num_buckets, num_heads]
    rab = t5_weights.get(attn0 + "relative_attention_bias.weight")
    if rab is not None:
        num_buckets = rab.shape[0]
        bias_heads = rab.shape[1]
    else:
        num_buckets = 32
        bias_heads = 0

    # Prefer the head count recorded in the bias table: d_kv is 64 for
    # t5-small/base/large but 128 for t5-3b/11b, so assuming 64 everywhere
    # would report twice the real number of heads for the larger models.
    if bias_heads > 0 and inner_dim % bias_heads == 0:
        num_heads = bias_heads
        d_kv = inner_dim // num_heads
    else:
        # No usable bias table: fall back to the d_kv shared by small/base/large.
        d_kv = 64
        num_heads = inner_dim // d_kv

    # Informational variant name. t5-large, t5-3b and t5-11b all use
    # d_model=1024, so d_ff is what tells them apart.
    if d_model == 512:
        t5_variant = "t5-small"
    elif d_model == 768:
        t5_variant = "t5-base"
    elif d_model == 1024:
        t5_variant = {16384: "t5-3b", 65536: "t5-11b"}.get(d_ff, "t5-large")
    else:
        t5_variant = "t5-unknown"

    return {
        "architectures": ["T5EncoderModel"],
        "model_name": t5_variant,
        "d_model": d_model,
        "d_kv": d_kv,
        "d_ff": d_ff,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "vocab_size": vocab_size,
        "relative_attention_num_buckets": num_buckets,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.0,
        "layer_norm_epsilon": 1e-6,
        "feed_forward_proj": "relu",
        "tie_word_embeddings": True,
        "decoder_start_token_id": 0,
        "model_type": "t5",
    }
def download_tokenizer(output_dir):
    """Fetch the T5 tokenizer files and copy them into ``output_dir``.

    ``tokenizer.json`` and ``tokenizer_config.json`` are obtained from the
    ``google-t5/t5-small`` repo via ``hf_hub_download`` and copied (with
    metadata) under their original names. t5-small is used for convenience:
    all T5 model sizes share the same SentencePiece tokenizer (32128 tokens).
    """
    source_repo = "google-t5/t5-small"
    tokenizer_files = ("tokenizer.json", "tokenizer_config.json")
    for name in tokenizer_files:
        cached = hf_hub_download(repo_id=source_repo, filename=name)
        shutil.copy2(cached, os.path.join(output_dir, name))
        print(f" Copied {name}")
def main():
    """Command-line entry point: extract the T5 conditioner into ``--output``.

    Steps:
      1. Parse ``--lm`` (optional local checkpoint) and ``--output``.
      2. Load the checkpoint, downloading ``state_dict.bin`` from
         ``facebook/audiogen-medium`` when ``--lm`` is not given.
      3. Extract the T5 encoder and output-projection tensors, infer the T5
         config, and rename keys for MLX.
      4. Write ``model.safetensors`` and ``config.json`` and copy the
         tokenizer files into the output directory.

    Raises:
        SystemExit: With a non-zero status if the checkpoint contains no T5
            weights, so shells and CI can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Extract T5 conditioner from facebook/audiogen-medium"
    )
    parser.add_argument(
        "--lm",
        help="Path to local state_dict.bin (skips download)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for T5 weights (e.g. /path/to/model/t5)",
    )
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)

    # Locate the checkpoint: either the user-supplied file or the hub copy.
    if args.lm:
        lm_path = args.lm
        print(f"Loading local checkpoint: {lm_path}")
    else:
        print("Downloading facebook/audiogen-medium state_dict.bin ...")
        lm_path = hf_hub_download(
            repo_id="facebook/audiogen-medium",
            filename="state_dict.bin",
        )
        print(f" Downloaded to cache: {lm_path}")

    print("Loading state dict ...")
    lm_state = load_lm_state(lm_path)

    print("Extracting T5 weights ...")
    t5_weights, output_proj, other_cp = extract_t5_weights(lm_state)
    print(f" T5 encoder keys: {len(t5_weights)}")
    print(f" Output projection keys: {len(output_proj)}")
    if other_cp:
        print(f" Other condition_provider keys (skipped): {len(other_cp)}")
    if not t5_weights:
        # SystemExit with a string prints it to stderr and exits with status 1;
        # a plain return here would report success to the calling shell.
        raise SystemExit("ERROR: No T5 weights found in checkpoint!")

    # Infer T5 architecture from the original (un-sanitized) key names.
    config = infer_t5_config(t5_weights)
    print(f" T5 config: {config['model_name']} — d_model={config['d_model']}, "
          f"num_heads={config['num_heads']}, "
          f"num_layers={config['num_layers']}, "
          f"d_ff={config['d_ff']}, "
          f"vocab_size={config['vocab_size']}")

    proj_w = output_proj.get("output_proj.weight")
    if proj_w is not None:
        print(f" Output projection: {list(proj_w.shape)} "
              f"(T5 d_model={proj_w.shape[1]} → LM dim={proj_w.shape[0]})")

    # Sanitize keys for MLX compatibility before saving.
    sanitized_t5 = sanitize_keys_for_mlx(t5_weights)
    print(f" Sanitized {len(sanitized_t5)} T5 keys (.layer.N. → .layer_N.)")

    # Sanitized T5 weights and output_proj go into a single safetensors file.
    all_weights = {**sanitized_t5, **output_proj}

    st_path = os.path.join(args.output, "model.safetensors")
    print(f"Saving {len(all_weights)} tensors to {st_path} ...")
    save_file(all_weights, st_path)
    total_bytes = os.path.getsize(st_path)
    print(f" Size: {total_bytes / 1e6:.1f} MB")

    # Explicit encoding keeps config.json identical across platforms.
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print("Saved config.json")

    print("Downloading T5 tokenizer ...")
    download_tokenizer(args.output)

    print(f"\nDone! T5 conditioner saved to: {args.output}")
    print("Files:", sorted(os.listdir(args.output)))


if __name__ == "__main__":
    main()