demucs-mlx / export_from_pytorch.py
Kyle Howells
Replace demucs-mlx conversion scripts with direct PyTorch exporter
3ff2e58
#!/usr/bin/env python3
"""
Export Demucs PyTorch models directly to safetensors + JSON config for Swift MLX.
Converts all 8 pretrained models directly from the original PyTorch demucs package.
No dependency on demucs-mlx or any other re-implementation.
Usage:
# Export all models
python scripts/export_from_pytorch.py --out-dir ~/.cache/demucs-mlx-swift-models
# Export specific models
python scripts/export_from_pytorch.py --models htdemucs htdemucs_ft --out-dir ./Models
Requirements:
pip install demucs safetensors numpy
"""
from __future__ import annotations
import argparse
import inspect
import json
import re
import sys
from fractions import Fraction
from pathlib import Path
import numpy as np
import torch
ALL_MODELS = [
"htdemucs",
"htdemucs_ft",
"htdemucs_6s",
"hdemucs_mmi",
"mdx",
"mdx_extra",
"mdx_q",
"mdx_extra_q",
]
# Map PyTorch class names to MLX class names used by Swift loader
CLASS_MAP = {
"Demucs": "DemucsMLX",
"HDemucs": "HDemucsMLX",
"HTDemucs": "HTDemucsMLX",
}
# Conv-like layer names that get .conv. wrapper in MLX
CONV_LAYER_NAMES = {
"conv", "conv_tr", "rewrite",
"channel_upsampler", "channel_downsampler",
"channel_upsampler_t", "channel_downsampler_t",
}
# DConv attention sub-module names (LocalState)
DCONV_ATTN_NAMES = {"content", "key", "query", "proj", "query_decay", "query_freqs"}
def to_json_serializable(obj):
"""Convert Python objects to JSON-serializable types."""
if isinstance(obj, Fraction):
return f"{obj.numerator}/{obj.denominator}"
if isinstance(obj, torch.Tensor):
return obj.item() if obj.numel() == 1 else obj.tolist()
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (list, tuple)):
return [to_json_serializable(x) for x in obj]
if isinstance(obj, dict):
return {str(k): to_json_serializable(v) for k, v in obj.items()}
return obj
def transpose_conv_weights(key: str, value: np.ndarray, is_conv_transpose: bool = False) -> np.ndarray:
"""Transpose PyTorch conv weights to MLX layout.
Conv1d: (out, in, k) → MLX: (out, k, in) transpose (0,2,1)
Conv2d: (out, in, h, w) → MLX: (out, h, w, in) transpose (0,2,3,1)
ConvTranspose1d: (in, out, k) → MLX: (out, k, in) transpose (1,2,0)
ConvTranspose2d: (in, out, h, w) → MLX: (out, h, w, in) transpose (1,2,3,0)
"""
if not key.endswith(".weight"):
return value
if len(value.shape) == 3:
return np.transpose(value, (1, 2, 0) if is_conv_transpose else (0, 2, 1))
if len(value.shape) == 4:
return np.transpose(value, (1, 2, 3, 0) if is_conv_transpose else (0, 2, 3, 1))
return value
def remap_key(
key: str,
value: np.ndarray,
model_type: str = "HTDemucs",
dconv_conv_slots: set | None = None,
seq_conv_slots: set | None = None,
) -> list[tuple[str, np.ndarray]]:
"""Remap a PyTorch state dict key to MLX key convention.
Returns a list of (key, value) pairs (multiple for attention in_proj splits).
Duplicate target keys (e.g. LSTM bias_ih + bias_hh) are merged by the caller.
Args:
key: PyTorch state dict key
value: numpy array (already transposed for conv weights)
model_type: PyTorch class name ("Demucs", "HDemucs", "HTDemucs")
dconv_conv_slots: set of (block_prefix, slot_str) for DConv slots with 3D weights
seq_conv_slots: set of (enc_dec, layer, slot) for Demucs v1/v2 Sequential Conv slots
"""
dconv_conv_slots = dconv_conv_slots or set()
seq_conv_slots = seq_conv_slots or set()
# =========================================================================
# Step 1: Demucs v1/v2 Sequential insertion
# encoder.{i}.{j}.rest → encoder.{i}.layers.{j}.rest
# decoder.{i}.{j}.rest → decoder.{i}.layers.{j}.rest
# =========================================================================
if model_type == "Demucs":
m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", key)
if m:
enc_dec, layer, slot, rest = m.groups()
rest = rest or ""
key = f"{enc_dec}.{layer}.layers.{slot}{rest}"
# =========================================================================
# Step 1.5: Demucs v1/v2 Sequential Conv/Norm slot wrapping
# encoder.{i}.layers.{j}.weight → encoder.{i}.layers.{j}.conv.weight (if Conv slot)
# =========================================================================
if model_type == "Demucs":
m = re.match(r"(encoder|decoder)\.(\d+)\.layers\.(\d+)\.(weight|bias)$", key)
if m:
enc_dec, layer, slot, param = m.groups()
if (enc_dec, layer, slot) in seq_conv_slots:
return [(f"{enc_dec}.{layer}.layers.{slot}.conv.{param}", value)]
else:
return [(f"{enc_dec}.{layer}.layers.{slot}.{param}", value)]
# =========================================================================
# Step 2: DConv internal slot handling
# Matches: *.layers.{block_idx}.{slot_idx}.{rest}
# Both HDemucs (.dconv.layers.) and Demucs v1/v2 (.layers.{N}.layers.) end
# with this pattern after Step 1.
# =========================================================================
m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.(.+)$", key)
if m:
block_prefix = m.group(1)
slot = m.group(2)
rest = m.group(3)
# --- 2a. Simple weight/bias/scale ---
if rest in ("weight", "bias", "scale"):
if rest == "weight" and len(value.shape) >= 2:
# 3D weight = Conv1d → add .conv.
return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
elif rest == "weight":
# 1D weight = GroupNorm → no wrapper
return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
elif rest == "bias":
if (block_prefix, slot) in dconv_conv_slots:
return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
else:
return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
else: # scale
return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
# --- 2b. LSTM weights/biases ---
m_lstm = re.match(r"lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", rest)
if m_lstm:
wb, ih_hh, layer_idx, reverse = m_lstm.groups()
direction = "backward_lstms" if reverse else "forward_lstms"
if wb == "weight":
param = "Wx" if ih_hh == "ih" else "Wh"
return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.{param}", value)]
else: # bias — both bias_ih and bias_hh map to same key; caller merges
return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.bias", value)]
# --- 2c. LSTM linear ---
m_linear = re.match(r"linear\.(weight|bias)$", rest)
if m_linear:
param = m_linear.group(1)
return [(f"{block_prefix}.layers.{slot}.linear.{param}", value)]
# --- 2d. Attention sub-modules (LocalState) ---
m_attn = re.match(r"(content|key|query|proj|query_decay|query_freqs)\.(weight|bias)$", rest)
if m_attn:
attn_name, param = m_attn.groups()
# These are all Conv1d modules → add .conv. wrapper
return [(f"{block_prefix}.layers.{slot}.{attn_name}.conv.{param}", value)]
# --- 2e. Fallback for unknown compound keys ---
return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
# =========================================================================
# Step 3: MultiheadAttention in_proj split (HTDemucs transformer)
# =========================================================================
m = re.match(r"(.+)\.(self_attn|cross_attn)\.in_proj_(weight|bias)$", key)
if m:
prefix, attn_type, param = m.group(1), m.group(2), m.group(3)
mlx_attn = "attn" if attn_type == "self_attn" else "cross_attn"
dim = value.shape[0] // 3
q, k_val, v = value[:dim], value[dim : 2 * dim], value[2 * dim :]
return [
(f"{prefix}.{mlx_attn}.query_proj.{param}", q),
(f"{prefix}.{mlx_attn}.key_proj.{param}", k_val),
(f"{prefix}.{mlx_attn}.value_proj.{param}", v),
]
# self_attn.out_proj → attn.out_proj
m = re.match(r"(.+)\.self_attn\.out_proj\.(weight|bias)$", key)
if m:
prefix, param = m.group(1), m.group(2)
return [(f"{prefix}.attn.out_proj.{param}", value)]
# =========================================================================
# Step 4: norm_out wrapping → norm_out.gn
# =========================================================================
m = re.match(r"(.+)\.norm_out\.(weight|bias)$", key)
if m:
prefix, param = m.group(1), m.group(2)
return [(f"{prefix}.norm_out.gn.{param}", value)]
# =========================================================================
# Step 5: Bottleneck LSTM (Demucs v1/v2 and HDemucs)
# lstm.lstm.weight_ih_l0 → lstm.forward_lstms.0.Wx
# =========================================================================
m = re.match(r"(.+)\.lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", key)
if m:
prefix = m.group(1)
wb = m.group(2)
ih_hh = m.group(3)
layer_idx = m.group(4)
reverse = m.group(5)
direction = "backward_lstms" if reverse else "forward_lstms"
if wb == "weight":
param = "Wx" if ih_hh == "ih" else "Wh"
return [(f"{prefix}.{direction}.{layer_idx}.{param}", value)]
else: # bias — merge handled by caller
return [(f"{prefix}.{direction}.{layer_idx}.bias", value)]
# =========================================================================
# Step 6: Conv/ConvTranspose/Rewrite named layers → add .conv. wrapper
# =========================================================================
parts = key.rsplit(".", 1)
if len(parts) == 2:
path, param = parts
path_parts = path.split(".")
last_name = path_parts[-1]
if last_name in CONV_LAYER_NAMES and param in ("weight", "bias"):
return [(f"{path}.conv.{param}", value)]
# =========================================================================
# Default: no change
# =========================================================================
return [(key, value)]
def convert_sub_model(model, prefix: str) -> dict[str, np.ndarray]:
"""Convert a single sub-model's state dict to MLX-compatible numpy arrays."""
cls_name = type(model).__name__
# --- Pre-scan: identify ConvTranspose modules by type ---
conv_tr_paths = set()
for name, module in model.named_modules():
if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
conv_tr_paths.add(name)
# --- Collect state dict as numpy ---
state_items = []
for key, tensor in model.state_dict().items():
arr = tensor.detach().cpu().float().numpy()
state_items.append((key, arr))
# --- Pre-scan: identify DConv Conv slots (3D weights) ---
# Pattern: *.layers.{block}.{slot}.weight where value is 3D
# For Demucs v1/v2, apply Sequential insertion first so lookups match remap_key
dconv_conv_slots: set[tuple[str, str]] = set()
for key, arr in state_items:
scan_key = key
if cls_name == "Demucs":
m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", scan_key)
if m:
enc_dec, layer, slot, rest = m.groups()
rest = rest or ""
scan_key = f"{enc_dec}.{layer}.layers.{slot}{rest}"
m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.weight$", scan_key)
if m and len(arr.shape) >= 2:
dconv_conv_slots.add((m.group(1), m.group(2)))
# --- Pre-scan: Demucs v1/v2 Sequential Conv slots ---
seq_conv_slots: set[tuple[str, str, str]] = set()
if cls_name == "Demucs":
for key, arr in state_items:
m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)\.weight$", key)
if m and len(arr.shape) >= 2:
seq_conv_slots.add((m.group(1), m.group(2), m.group(3)))
# --- Convert ---
weights: dict[str, np.ndarray] = {}
for key, arr in state_items:
# Determine if this belongs to a ConvTranspose module
is_conv_tr = any(key.startswith(p + ".") for p in conv_tr_paths)
# Transpose conv weights
arr = transpose_conv_weights(key, arr, is_conv_transpose=is_conv_tr)
# Remap key
remapped = remap_key(key, arr, cls_name, dconv_conv_slots, seq_conv_slots)
for new_key, new_val in remapped:
full_key = f"{prefix}{new_key}"
if full_key in weights:
# LSTM bias merge: bias_ih + bias_hh → bias (additive)
weights[full_key] = weights[full_key] + new_val
else:
weights[full_key] = new_val
return weights
def extract_kwargs(model) -> dict:
"""Extract constructor kwargs from a model using _init_args_kwargs or inspection."""
if hasattr(model, "_init_args_kwargs"):
_, kwargs = model._init_args_kwargs
return {k: to_json_serializable(v) for k, v in kwargs.items()
if isinstance(v, (int, float, str, bool, list, tuple, type(None), Fraction))}
# Fallback: inspect __init__ signature and read matching attributes
sig = inspect.signature(type(model).__init__)
kwargs = {}
for name in sig.parameters:
if name == "self":
continue
if hasattr(model, name):
val = getattr(model, name)
kwargs[name] = to_json_serializable(val)
return kwargs
def export_model(model_name: str, out_dir: Path) -> bool:
"""Export a single model (or bag) to safetensors + config JSON."""
from demucs.pretrained import get_model
from demucs.apply import BagOfModels
print(f"\n--- Exporting {model_name} ---")
try:
model = get_model(model_name)
except Exception as e:
print(f" Failed to load model: {e}")
return False
is_bag = isinstance(model, BagOfModels)
if is_bag:
sub_models = list(model.models)
num_models = len(sub_models)
bag_weights = model.weights.tolist() if hasattr(model.weights, "tolist") else list(model.weights)
else:
sub_models = [model]
num_models = 1
bag_weights = None
print(f" {'Bag of ' + str(num_models) + ' models' if is_bag else 'Single model'}")
# Collect all weights and metadata
all_weights: dict[str, np.ndarray] = {}
model_classes: list[str] = []
model_configs: list[dict] = []
for i, sub in enumerate(sub_models):
cls_name = type(sub).__name__
mlx_cls = CLASS_MAP.get(cls_name, cls_name)
model_classes.append(mlx_cls)
print(f" Model {i}: {cls_name}{mlx_cls}")
prefix = f"model_{i}." if is_bag else ""
sub_weights = convert_sub_model(sub, prefix)
all_weights.update(sub_weights)
kwargs = extract_kwargs(sub)
model_configs.append({
"model_class": mlx_cls,
"kwargs": kwargs,
})
# Build config JSON
config: dict = {
"model_name": model_name,
"tensor_count": len(all_weights),
}
if is_bag:
config["model_class"] = "BagOfModelsMLX"
config["num_models"] = num_models
config["weights"] = bag_weights
config["sub_model_classes"] = model_classes
# If all sub-models are the same class, set sub_model_class for compat
unique = set(model_classes)
if len(unique) == 1:
config["sub_model_class"] = unique.pop()
config["model_configs"] = model_configs
# Also put kwargs at top level for single-model bags (common case)
if num_models == 1:
config["kwargs"] = model_configs[0]["kwargs"]
else:
config["model_class"] = model_classes[0]
config["kwargs"] = model_configs[0]["kwargs"]
# Save files
model_dir = out_dir / model_name
model_dir.mkdir(parents=True, exist_ok=True)
safetensors_path = model_dir / f"{model_name}.safetensors"
config_path = model_dir / f"{model_name}_config.json"
# Save safetensors (prefer safetensors library, fallback to mlx)
try:
from safetensors.numpy import save_file
save_file(all_weights, str(safetensors_path))
except ImportError:
import mlx.core as mx
mlx_weights = {k: mx.array(v) for k, v in all_weights.items()}
mx.save_safetensors(str(safetensors_path), mlx_weights)
with config_path.open("w") as f:
json.dump(config, f, indent=2, default=str)
size_mb = safetensors_path.stat().st_size / (1024 * 1024)
print(f" Wrote {safetensors_path} ({len(all_weights)} tensors, {size_mb:.0f} MB)")
print(f" Wrote {config_path}")
return True
def main():
ap = argparse.ArgumentParser(
description="Export Demucs PyTorch models to safetensors for Swift MLX"
)
ap.add_argument(
"--models",
nargs="*",
default=None,
help=f"Models to export (default: all). Choices: {', '.join(ALL_MODELS)}",
)
ap.add_argument(
"--out-dir",
default="./Models",
help="Output root directory (files go into <out-dir>/<model_name>/)",
)
args = ap.parse_args()
models = args.models or ALL_MODELS
out_dir = Path(args.out_dir).resolve()
exported = 0
failed = 0
for name in models:
if export_model(name, out_dir):
exported += 1
else:
failed += 1
print(f"\n=== Done: {exported} exported, {failed} failed ===")
if failed:
sys.exit(1)
if __name__ == "__main__":
main()