Instructions to use mlx-community/audiogen-medium-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/audiogen-medium-mlx with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir audiogen-medium-mlx mlx-community/audiogen-medium-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| #!/usr/bin/env python3 | |
| """Verify T5 encoder output against Swift implementation. | |
| Loads the same T5 safetensors weights, runs the encoder on the same tokens, | |
| and prints output stats for comparison with the Swift logs. | |
| """ | |
| import math | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import json | |
| from pathlib import Path | |
| MODEL_DIR = Path.home() / "Library/Application Support/Velvox/Models/audiogen-mlx/t5" | |
| # ── T5 LayerNorm (RMSNorm, no centering) ── | |
class T5LayerNorm(nn.Module):
    """Scale-only layer norm: divides by the root mean square over the last axis.

    No mean is subtracted and no bias is added; the only learned parameter is
    the per-feature ``weight``.
    """

    def __init__(self, dims, eps=1e-6):
        """Create a unit-initialised scale of length ``dims``; ``eps`` is added to the mean square."""
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        """Normalise ``x`` over its last axis and apply the learned scale."""
        # The statistics are computed in float32; the normalised values are
        # cast back to the input dtype before the scale is applied.
        x32 = x.astype(mx.float32)
        mean_square = mx.mean(x32 * x32, axis=-1, keepdims=True)
        normalised = x32 * mx.rsqrt(mean_square + self.eps)
        return self.weight * normalised.astype(x.dtype)
| # ── T5 DenseReluDense ── | |
class T5DenseActDense(nn.Module):
    """Two bias-free linear layers with an activation in between."""

    def __init__(self, d_model, d_ff, act="relu"):
        """Build ``wi`` (d_model -> d_ff) and ``wo`` (d_ff -> d_model); remember the activation name."""
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.act = act

    def __call__(self, x):
        """Apply ``wo(act(wi(x)))``."""
        # Any activation name other than "relu" selects GELU.
        activation = nn.relu if self.act == "relu" else nn.gelu
        return self.wo(activation(self.wi(x)))
| # ── T5 Attention (NO sqrt(d_k) scaling — this is T5's design) ── | |
class T5Attention(nn.Module):
    """Multi-head self-attention with an optional learned relative-position bias.

    Attention scores are the raw ``q @ k`` products: no ``1/sqrt(d_kv)``
    scaling is applied. When ``has_relative_attention_bias`` is true the
    module owns an embedding table of shape ``(num_buckets, num_heads)`` that
    ``compute_bias`` turns into an additive score bias.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        """Build the q/k/v/o projections from ``config``.

        Reads ``num_heads``, ``d_kv``, ``d_model`` and
        ``relative_attention_num_buckets`` from ``config``;
        ``relative_attention_max_distance`` defaults to 128 when absent.
        """
        super().__init__()
        self.num_heads = config["num_heads"]
        self.d_kv = config["d_kv"]
        self.d_model = config["d_model"]
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = config["relative_attention_num_buckets"]
        self.max_distance = config.get("relative_attention_max_distance", 128)
        inner_dim = self.num_heads * self.d_kv
        self.q = nn.Linear(self.d_model, inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, inner_dim, bias=False)
        self.o = nn.Linear(inner_dim, self.d_model, bias=False)
        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)

    # Declared static: the function takes no ``self`` but is invoked through
    # the instance in ``compute_bias``. Without the decorator the instance
    # would be bound to ``rp`` and the call would fail with a TypeError.
    @staticmethod
    def _relative_position_bucket(rp, bidirectional=True, num_buckets=32, max_distance=128):
        """Map relative positions ``rp`` to int32 bucket indices.

        In bidirectional mode half of the buckets are used for positive
        offsets and half for the rest, and the absolute offset is bucketed.
        Otherwise positive offsets are clipped to zero and the negated offset
        is bucketed. Offsets below ``max_exact`` get their own bucket; larger
        ones are spaced logarithmically and clamped to the last bucket.
        """
        nb = num_buckets
        result = mx.zeros(rp.shape, dtype=mx.int32)
        if bidirectional:
            nb = nb // 2
            # Positive offsets are shifted into the upper half of the buckets.
            result = mx.where(
                rp > 0, mx.array(nb, dtype=mx.int32), mx.array(0, dtype=mx.int32)
            )
            rp = mx.abs(rp)
        else:
            rp = -mx.minimum(rp, mx.zeros_like(rp))
        max_exact = nb // 2
        is_small = rp < max_exact
        large_rp = rp.astype(mx.float32)
        # The logarithm is evaluated for every entry; values computed for
        # entries flagged by ``is_small`` are discarded by the ``where`` below.
        log_ratio = mx.log(large_rp / max_exact) / math.log(max_distance / max_exact)
        large_bucket = (log_ratio * (nb - max_exact)).astype(mx.int32) + max_exact
        clamped = mx.minimum(large_bucket, mx.array(nb - 1, dtype=mx.int32))
        buckets = mx.where(is_small, rp.astype(mx.int32), clamped)
        return result + buckets

    def compute_bias(self, q_len, k_len):
        """Return the position bias shaped ``[1, H, q_len, k_len]``.

        Returns ``None`` when this module was built without a relative
        attention bias table.
        """
        if not self.has_relative_attention_bias:
            return None
        ctx = mx.arange(q_len, dtype=mx.int32)
        mem = mx.arange(k_len, dtype=mx.int32)
        # rp[i, j] = j - i (key position minus query position), as float32.
        rp = mem.reshape(1, -1).astype(mx.float32) - ctx.reshape(-1, 1).astype(mx.float32)
        rp_bucket = self._relative_position_bucket(
            rp,
            bidirectional=True,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        flat = rp_bucket.reshape(-1)
        bias_flat = self.relative_attention_bias(flat)
        bias = bias_flat.reshape(q_len, k_len, self.num_heads)
        bias = bias.transpose(2, 0, 1)[None, :, :, :]  # [1, H, Q, K]
        return bias

    def __call__(self, hidden, mask=None, position_bias=None):
        """Run self-attention over ``hidden`` (unpacked as ``B, T, _``).

        ``position_bias`` is added to the scores when given; when it is
        ``None`` the module falls back to its own ``compute_bias`` result,
        which may itself be ``None``.

        NOTE(review): ``mask`` is accepted but never applied to the scores in
        this implementation; callers that need padding masks must not rely
        on it.
        """
        B, T, _ = hidden.shape
        q = self.q(hidden).reshape(B, T, self.num_heads, self.d_kv)
        k = self.k(hidden).reshape(B, T, self.num_heads, self.d_kv)
        v = self.v(hidden).reshape(B, T, self.num_heads, self.d_kv)
        q = q.transpose(0, 2, 1, 3)  # [B, H, T, d]
        k = k.transpose(0, 2, 3, 1)  # [B, H, d, T]
        v = v.transpose(0, 2, 1, 3)  # [B, H, T, d]
        # T5: NO scaling by 1/sqrt(d_k)
        scores = q @ k
        if position_bias is None:
            position_bias = self.compute_bias(T, T)
        if position_bias is not None:
            scores = scores + position_bias
        # Softmax is evaluated in float32 and cast back to the score dtype.
        weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, -1)
        return self.o(out)
| # ── T5 Block ── | |
class T5Block(nn.Module):
    """One encoder layer: pre-normed self-attention, then a pre-normed feed-forward, each with a residual."""

    def __init__(self, config, has_relative_attention_bias=False):
        """Build the attention, feed-forward and the two norms from ``config``."""
        super().__init__()
        width = config["d_model"]
        norm_eps = config.get("layer_norm_epsilon", 1e-6)
        # Sub-modules are created in the same order as before so that any
        # random initialisation is consumed identically.
        self.self_attn = T5Attention(config, has_relative_attention_bias)
        self.layer_norm_sa = T5LayerNorm(width, norm_eps)
        self.ff = T5DenseActDense(width, config["d_ff"], config.get("feed_forward_proj", "relu"))
        self.layer_norm_ff = T5LayerNorm(width, norm_eps)

    def __call__(self, x, mask=None, position_bias=None):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.self_attn(self.layer_norm_sa(x), mask=mask, position_bias=position_bias)
        return x + self.ff(self.layer_norm_ff(x))
| # ── T5 Encoder ── | |
class T5Encoder(nn.Module):
    """Token embedding, a stack of ``T5Block`` layers, and a final norm."""

    def __init__(self, config):
        """Build the embedding, ``config["num_layers"]`` blocks and the final norm."""
        super().__init__()
        self.shared = nn.Embedding(config["vocab_size"], config["d_model"])
        layers = []
        for idx in range(config["num_layers"]):
            # Only the first block owns the relative attention bias table.
            layers.append(T5Block(config, has_relative_attention_bias=(idx == 0)))
        self.blocks = layers
        self.final_layer_norm = T5LayerNorm(config["d_model"], config.get("layer_norm_epsilon", 1e-6))

    def __call__(self, input_ids):
        """Encode ``input_ids`` and return the normalised hidden states."""
        hidden = self.shared(input_ids)
        seq_len = hidden.shape[1]
        # The bias is computed once by the first block and passed to every block.
        shared_bias = self.blocks[0].self_attn.compute_bias(seq_len, seq_len)
        for layer in self.blocks:
            hidden = layer(hidden, position_bias=shared_bias)
        return self.final_layer_norm(hidden)
def load_and_remap_weights(t5_dir):
    """Load safetensors and remap sanitized MLX keys to our module structure.

    The safetensors use MLX-sanitized keys with layer_0/layer_1 (underscores),
    not the original HuggingFace layer.0/layer.1 (dots).

    Args:
        t5_dir: Directory (path object supporting ``/``) that contains one or
            more ``*.safetensors`` files.

    Returns:
        A tuple ``(remapped, output_proj_w, output_proj_b)``: the renamed
        weight dict, plus the ``output_proj.weight`` / ``output_proj.bias``
        entries removed from it (each ``None`` when absent).

    Raises:
        FileNotFoundError: If ``t5_dir`` contains no ``*.safetensors`` file.
    """
    import glob

    safetensors_files = sorted(glob.glob(str(t5_dir / "*.safetensors")))
    if not safetensors_files:
        # Fail here with a clear message instead of handing an empty weight
        # dict to the caller.
        raise FileNotFoundError(f"No .safetensors files found in {t5_dir}")

    all_weights = {}
    for path in safetensors_files:
        # Later files override earlier ones on key collisions.
        all_weights.update(mx.load(path))

    # Separate output_proj from T5 weights
    output_proj_w = all_weights.pop("output_proj.weight", None)
    output_proj_b = all_weights.pop("output_proj.bias", None)

    # Ordered (old, new) substring renames, applied to every key in sequence.
    # encoder.block.N.layer_0.SelfAttention.X -> blocks.N.self_attn.X, etc.
    renames = (
        ("encoder.block.", "blocks."),
        (".layer_0.SelfAttention.", ".self_attn."),
        (".layer_0.layer_norm.", ".layer_norm_sa."),
        (".layer_1.DenseReluDense.", ".ff."),
        (".layer_1.layer_norm.", ".layer_norm_ff."),
        ("encoder.final_layer_norm.", "final_layer_norm."),
    )
    remapped = {}
    for key, val in all_weights.items():
        new_key = key
        for old, new in renames:
            new_key = new_key.replace(old, new)
        remapped[new_key] = val
    return remapped, output_proj_w, output_proj_b
def main():
    """Build the encoder, load its weights and print output statistics.

    Reads ``config.json`` and the safetensors weights from ``MODEL_DIR``, runs
    the encoder on a fixed set of token sequences, and prints min/max/sum
    statistics (plus per-position stats) for comparison with the Swift logs.
    When the checkpoint contains ``output_proj`` weights the projected output
    is reported as well.
    """
    # Bound locally so that main() does not depend on a global that is only
    # assigned inside the ``__main__`` guard.
    t5_dir = MODEL_DIR

    print("=" * 60)
    print("T5 Encoder Verification (MLX Python reference)")
    print("=" * 60)
    # Load config
    with open(t5_dir / "config.json") as f:
        config = json.load(f)
    print(f"Config: d_model={config['d_model']} layers={config['num_layers']} "
          f"heads={config['num_heads']} d_kv={config['d_kv']} d_ff={config['d_ff']}")
    # Build model
    encoder = T5Encoder(config)
    # Load weights
    weights, proj_w, proj_b = load_and_remap_weights(t5_dir)
    # Apply weights
    encoder.load_weights(list(weights.items()))
    # Build output_proj
    output_proj = None
    if proj_w is not None:
        # Only create a bias parameter when the checkpoint provides one, so
        # the module's parameters match exactly what is loaded below.
        output_proj = nn.Linear(proj_w.shape[1], proj_w.shape[0], bias=proj_b is not None)
        proj_params = [("weight", proj_w)]
        if proj_b is not None:
            proj_params.append(("bias", proj_b))
        output_proj.load_weights(proj_params)
        print(f"output_proj: {proj_w.shape[1]} → {proj_w.shape[0]}")
    # Test prompts with known token IDs from Swift logs
    test_cases = [
        ("dog barking", [1782, 21696, 53, 1]),
        ("cars in the street", [2948, 16, 8, 2815, 1]),
        ("A metro train leaving the platform", [71, 12810, 2412, 3140, 8, 1585, 1]),
    ]
    for prompt, token_ids in test_cases:
        print(f"\n--- '{prompt}' ---")
        print(f"Tokens: {token_ids}")
        # A batch of one sequence.
        input_ids = mx.array([token_ids], dtype=mx.int32)
        features = encoder(input_ids)
        mx.eval(features)
        print(f"Encoder output: shape={features.shape} "
              f"min={features.min().item():.7f} max={features.max().item():.7f} "
              f"sum={features.sum().item():.4f}")
        for i in range(features.shape[1]):
            pos_feat = features[0, i]
            print(f"  pos[{i}]: min={pos_feat.min().item():.5f} "
                  f"max={pos_feat.max().item():.5f} "
                  f"mean={pos_feat.mean().item():.5f}")
        if output_proj is not None:
            projected = output_proj(features)
            mx.eval(projected)
            print(f"After output_proj: shape={projected.shape} "
                  f"min={projected.min().item():.7f} max={projected.max().item():.7f} "
                  f"sum={projected.sum().item():.4f}")
if __name__ == "__main__":
    # Bound at module level so the name is visible as a global to the rest of
    # the script.
    t5_dir = MODEL_DIR
    if not t5_dir.exists():
        print(f"T5 directory not found: {t5_dir}")
        # Raise SystemExit directly: the bare ``exit`` helper is installed by
        # the ``site`` module and is not guaranteed to exist.
        raise SystemExit(1)
    main()