htdemucs-coreml / convert.py
dexxdean's picture
Initial release: HTDemucs to Core ML conversion (FP32 + FP16)
8b4d481 verified
#!/usr/bin/env python3
"""
convert.py — Convert Demucs (Hybrid Transformer) to Core ML.
Core ML does not support complex64 tensors. This script wraps HTDemucs with
a real-valued STFT/ISTFT implementation (torch.stft -> view_as_real for STFT,
matrix IDFT + overlap-add for ISTFT) while keeping the neural network
(encoder/transformer/decoder) unchanged.
Default output: HTDemucs_CoreML.mlpackage
Prerequisites:
python3 -m venv venv && source venv/bin/activate
pip install -r requirements.txt
Usage:
python convert.py # FP32, ~400 MB
python convert.py --fp16 # FP16, ~200 MB
python convert.py --segment 7 # 7-second segments instead of 10
python convert.py --output Foo.mlpackage
"""
import argparse
import math
import warnings
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# ---------------------------------------------------------------------------
# Module-level defaults. Segment length, precision and output path can be
# overridden from the command line (see parse_args).
# ---------------------------------------------------------------------------
MODEL_NAME = "htdemucs"
SAMPLE_RATE = 44100  # Hz
SEGMENT_SAMPLES = 10 * SAMPLE_RATE  # 441000 samples == 10 s at 44.1 kHz
NUM_CHANNELS = 2  # stereo
NUM_SOURCES = 4
DEFAULT_OUTPUT = "HTDemucs_CoreML.mlpackage"

# Stem order produced by Demucs itself: drums(0), bass(1), other(2), vocals(3).
_DEMUCS_SOURCE_ORDER = ["drums", "bass", "other", "vocals"]
# The exported model re-orders the stems to vocals, drums, bass, other
# (typical UI / DJ convention); SOURCE_REORDER[i] is the Demucs index of
# output stem i.
SOURCE_REORDER = [3, 0, 1, 2]
SOURCE_NAMES = [_DEMUCS_SOURCE_ORDER[i] for i in SOURCE_REORDER]
# ---------------------------------------------------------------------------
# ManualMHA: replaces nn.MultiheadAttention.
# coremltools cannot convert the fused _native_multi_head_attention op,
# so we decompose attention into matmul + softmax explicitly.
# ---------------------------------------------------------------------------
class ManualMHA(nn.Module):
    """Drop-in replacement for nn.MultiheadAttention, decomposed into matmul + softmax.

    coremltools cannot convert the fused ``_native_multi_head_attention`` op,
    so attention is written out explicitly. The projection parameters of the
    wrapped module are shared with it (not copied).

    Supported: batch-first and sequence-first layouts (``mha.batch_first``),
    packed or separate (``kdim``/``vdim``) input projections, with or without
    bias. Unsupported features raise ``NotImplementedError`` instead of being
    silently ignored: ``add_bias_kv``, ``add_zero_attn``, ``attn_mask``,
    ``key_padding_mask`` and ``is_causal=True``.

    Dropout is never applied, so results match the wrapped module in eval
    mode only. Attention weights are not computed; the second return value
    is always ``None``.
    """

    def __init__(self, mha: nn.MultiheadAttention):
        super().__init__()
        if mha.bias_k is not None or mha.bias_v is not None or mha.add_zero_attn:
            raise NotImplementedError(
                "ManualMHA does not support add_bias_kv or add_zero_attn"
            )
        self.embed_dim = mha.embed_dim
        self.num_heads = mha.num_heads
        self.head_dim = mha.embed_dim // mha.num_heads
        # getattr: batch_first may be absent on older PyTorch (sequence-first only).
        self.batch_first = getattr(mha, "batch_first", False)
        # Packed (3E, E) projection; None when kdim/vdim differ from embed_dim.
        self.in_proj_weight = mha.in_proj_weight
        # Separate Q/K/V projections; None when the packed projection is used.
        self.q_proj_weight = mha.q_proj_weight
        self.k_proj_weight = mha.k_proj_weight
        self.v_proj_weight = mha.v_proj_weight
        # (3E,) bias used with either projection layout, or None.
        self.in_proj_bias = mha.in_proj_bias
        self.out_proj = mha.out_proj
        # Cross-attention: key / value feature sizes.
        self.kdim = mha.kdim
        self.vdim = mha.vdim
        self._qkv_same_embed_dim = mha._qkv_same_embed_dim

    def _in_projection(self, query, key, value, is_self_attn):
        """Project batch-first query/key/value; return (q, k, v), each of width embed_dim."""
        if self._qkv_same_embed_dim and is_self_attn:
            # Self-attention: one packed projection, then split into Q, K, V.
            qkv = F.linear(query, self.in_proj_weight, self.in_proj_bias)
            return qkv.chunk(3, dim=-1)
        if self._qkv_same_embed_dim:
            w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)
        else:
            w_q, w_k, w_v = self.q_proj_weight, self.k_proj_weight, self.v_proj_weight
        if self.in_proj_bias is not None:
            b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)
        else:
            b_q = b_k = b_v = None
        return (
            F.linear(query, w_q, b_q),
            F.linear(key, w_k, b_k),
            F.linear(value, w_v, b_v),
        )

    def forward(self, query, key, value, need_weights=False, *,
                attn_mask=None, key_padding_mask=None, is_causal=False,
                **kwargs):
        """Scaled dot-product multi-head attention.

        Args:
            query: (B, T, E) if ``batch_first`` else (T, B, E).
            key: (B, S, kdim) if ``batch_first`` else (S, B, kdim).
            value: (B, S, vdim) if ``batch_first`` else (S, B, vdim).
            need_weights: accepted for API compatibility; weights are never
                returned.
            attn_mask, key_padding_mask, is_causal: must keep their defaults;
                masking is not implemented.
            **kwargs: other nn.MultiheadAttention options; ignored.

        Returns:
            ``(output, None)``; output has the same layout as ``query``.

        Raises:
            NotImplementedError: if a mask or ``is_causal=True`` is given.
        """
        if attn_mask is not None or key_padding_mask is not None or is_causal:
            raise NotImplementedError(
                "ManualMHA does not support attn_mask, key_padding_mask or is_causal"
            )
        # Identity check (not data_ptr) so a differing `value`, or a different
        # view sharing storage, never takes the packed self-attention path.
        # Must be decided before transposes create new tensor objects.
        is_self_attn = query is key and key is value
        if not self.batch_first:
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))
        B, T, E = query.shape
        S = key.shape[1]
        q, k, v = self._in_projection(query, key, value, is_self_attn)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        scale = self.head_dim ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, E)
        out = self.out_proj(out)
        if not self.batch_first:
            out = out.transpose(0, 1)
        return out, None
def _replace_mha_recursive(module: nn.Module) -> None:
"""Replace all nn.MultiheadAttention submodules with ManualMHA, in place."""
for name, child in module.named_children():
if isinstance(child, nn.MultiheadAttention):
setattr(module, name, ManualMHA(child))
else:
_replace_mha_recursive(child)
# ---------------------------------------------------------------------------
# 1D reflect-pad helper (mirrors demucs.hdemucs.pad1d).
# ---------------------------------------------------------------------------
def _pad1d(x: torch.Tensor, paddings: tuple, mode: str = "reflect"):
"""Reflect-pad along the last dim, with a fallback for very short signals."""
pl, pr = paddings
length = x.shape[-1]
max_pad = max(pl, pr)
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, (pl, pr), mode=mode)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
return F.pad(x, (pl, pr), mode=mode)
# ---------------------------------------------------------------------------
# RealSTFT: real-valued STFT via torch.stft(return_complex=True) -> view_as_real.
# Produces (..., freqs, frames, 2); the complex64 result exists only between
# torch.stft and view_as_real, so no complex tensor reaches the network.
# ---------------------------------------------------------------------------
class RealSTFT(nn.Module):
    """Short-time Fourier transform whose output contains only real tensors.

    Calls ``torch.stft(..., return_complex=True)`` and immediately converts the
    result with ``torch.view_as_real``, so the complex tensor never leaves this
    module. Uses a Hann window of length ``n_fft``, ``normalized=True`` and
    ``center=True``.
    """

    def __init__(self, n_fft: int, hop_length: int):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.register_buffer("window", torch.hann_window(n_fft))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C, T) signal to a (B, C, freqs, frames, 2) spectrogram.

        The trailing axis holds [real, imag].
        """
        batch, channels, samples = x.shape
        spec = torch.stft(
            x.reshape(batch * channels, samples),
            self.n_fft,
            self.hop_length,
            win_length=self.n_fft,
            window=self.window,
            center=True,
            normalized=True,
            return_complex=True,
        )
        spec_ri = torch.view_as_real(spec)  # (batch * channels, freqs, frames, 2)
        n_freqs, n_frames = spec_ri.shape[1], spec_ri.shape[2]
        return spec_ri.view(batch, channels, n_freqs, n_frames, 2)
# ---------------------------------------------------------------------------
# RealISTFT: real-valued ISTFT via matrix IDFT + overlap-add.
# Avoids view_as_complex (not supported by coremltools).
# ---------------------------------------------------------------------------
class RealISTFT(nn.Module):
    """Inverse STFT built only from real-valued ops (matrix IDFT + overlap-add).

    Inverts ``torch.stft(..., normalized=True, center=True)`` with a Hann
    window, without ``view_as_complex`` (not supported by coremltools). The
    frame count is fixed at construction so that the overlap-add indices and
    the window normalization can be precomputed as buffers.

    Args:
        n_fft: FFT size, also the window length.
        hop_length: hop between frames, in samples.
        num_frames: number of frames every input to ``forward`` must have.
    """

    def __init__(self, n_fft: int, hop_length: int, num_frames: int):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frames = num_frames
        freqs = n_fft // 2 + 1
        # Synthesis window.
        window = torch.hann_window(n_fft)
        self.register_buffer("window", window)
        # IDFT basis matrices: cos / sin for a one-sided spectrum.
        n = torch.arange(n_fft, dtype=torch.float32).unsqueeze(0)  # (1, N)
        k = torch.arange(freqs, dtype=torch.float32).unsqueeze(1)  # (freqs, 1)
        angles = 2.0 * math.pi * k * n / n_fft  # (freqs, N)
        # One-sided scaling: every bin except DC (and, for even N, Nyquist)
        # stands for a conjugate pair, so it counts twice. 1/N is the inverse
        # DFT factor; sqrt(N) undoes normalized=True of the forward STFT.
        norm = math.sqrt(n_fft)
        scale = torch.ones(freqs, 1) * (2.0 / n_fft * norm)
        scale[0] = 1.0 / n_fft * norm
        if n_fft % 2 == 0:
            # Only even N has a Nyquist bin; for odd N the last bin is a pair.
            scale[-1] = 1.0 / n_fft * norm
        self.register_buffer("cos_basis", torch.cos(angles) * scale)  # (freqs, N)
        self.register_buffer("sin_basis", torch.sin(angles) * scale)  # (freqs, N)
        # Pre-compute OLA indices and window-sum buffer. Per the original
        # author, a pre-built index tensor + the canonical scatter_add_ call
        # avoids a Core ML mis-compilation of 1D scatter_add for some shapes.
        out_length = (num_frames - 1) * hop_length + n_fft
        frame_offsets = torch.arange(num_frames) * hop_length
        local_offsets = torch.arange(n_fft)
        ola_indices = (frame_offsets.unsqueeze(1) + local_offsets.unsqueeze(0)).reshape(-1)
        self.register_buffer("ola_indices", ola_indices.long())
        # Sum of squared windows at each output sample (OLA normalizer),
        # clamped so that uncovered samples do not divide by zero.
        window_sq = window * window
        win_sum = torch.zeros(out_length)
        for i in range(num_frames):
            start = i * hop_length
            win_sum[start:start + n_fft] += window_sq
        self.register_buffer("win_sum", win_sum.clamp(min=1e-8))
        self.out_length = out_length

    def forward(self, z_ri: torch.Tensor, length: int) -> torch.Tensor:
        """Invert a real/imag spectrogram to a waveform.

        Args:
            z_ri: (batch, freqs, frames, 2), [real, imag] in the last axis;
                ``frames`` must equal ``num_frames`` given at construction.
            length: number of samples to keep after dropping the
                ``n_fft // 2`` samples of centre padding.

        Returns:
            (batch, length) tensor in the dtype of ``z_ri`` (shorter if
            ``length`` exceeds the reconstructed signal).

        Raises:
            ValueError: if the frame count does not match ``num_frames``.
        """
        frames = z_ri.shape[-2]
        if frames != self.num_frames:
            # scatter_add_ would otherwise fail cryptically (too few frames)
            # or silently drop frames (too many).
            raise ValueError(
                f"RealISTFT was built for {self.num_frames} frames, got {frames}"
            )
        # Per-frame IDFT: (batch, frames, freqs) @ (freqs, N) -> (batch, frames, N)
        real_t = z_ri[..., 0].transpose(-2, -1)
        imag_t = z_ri[..., 1].transpose(-2, -1)
        frames_signal = (
            torch.matmul(real_t, self.cos_basis)
            - torch.matmul(imag_t, self.sin_basis)
        )
        # Apply synthesis window.
        frames_signal = frames_signal * self.window.unsqueeze(0).unsqueeze(0)
        # --- Overlap-add via scatter_add ---
        batch = frames_signal.shape[0]
        idx = self.ola_indices.unsqueeze(0).expand(batch, -1)
        flat = frames_signal.reshape(batch, -1)
        # Match the signal dtype; scatter_add_ requires identical dtypes.
        output = torch.zeros(
            batch, self.out_length, dtype=flat.dtype, device=z_ri.device
        )
        output.scatter_add_(1, idx, flat)
        # Window normalization (pre-computed buffer).
        output = output / self.win_sum.unsqueeze(0)
        # Strip center padding.
        pad = self.n_fft // 2
        return output[:, pad:pad + length]
# ---------------------------------------------------------------------------
# RealValuedHTDemucs: wrapper that swaps STFT/ISTFT for real-valued versions
# while keeping the actual network (encoder / transformer / decoder) intact.
# ---------------------------------------------------------------------------
class RealValuedHTDemucs(nn.Module):
    """
    Wraps HTDemucs with real-valued STFT/ISTFT.

    The network submodules (encoders, cross-transformer, decoders) are taken
    over from the loaded model; only the spectral front/back end is replaced.

    Data flow:
    1. RealSTFT -> (B, C, Fr, T, 2) [real]
    2. Spec trimming (real instead of complex)
    3. _magnitude (cac=True): permute+reshape -> (B, C*2, Fr, T) [real]
    4. Encoder / CrossTransformer / Decoder [all real]
    5. _mask (cac=True): reshape+permute -> (B*S*C, Fr, T, 2) [real]
    6. RealISTFT -> waveform [real]
    7. + time branch (denormalized) [real]

    The ISTFT frame count is fixed by ``segment_samples``, so ``forward``
    only accepts inputs of shape (B, audio_channels, segment_samples).

    Note: ``__init__`` swaps every nn.MultiheadAttention in the adopted
    submodules for ManualMHA. The submodules are shared with ``model``, so
    ``model`` is modified as well.
    """

    def __init__(self, model: nn.Module, segment_samples: int):
        super().__init__()
        # Only the complex-as-channels (cac=True) layout is implemented below;
        # any other model would be converted into silently wrong output.
        if not getattr(model, "cac", True):
            raise NotImplementedError(
                "RealValuedHTDemucs only supports HTDemucs models with cac=True"
            )
        self.segment_samples = segment_samples
        # Channel count forward() accepts.
        self.audio_channels = getattr(model, "audio_channels", NUM_CHANNELS)
        # Adopt network submodules from the loaded HTDemucs.
        self.encoder = model.encoder
        self.tencoder = model.tencoder
        self.decoder = model.decoder
        self.tdecoder = model.tdecoder
        self.crosstransformer = model.crosstransformer
        self.freq_emb = model.freq_emb
        self.freq_emb_scale = model.freq_emb_scale
        self.sources = model.sources
        self.depth = model.depth
        # Bottom-channel projection (present in some HTDemucs variants).
        self.bottom_channels = model.bottom_channels
        if self.bottom_channels:
            self.channel_upsampler = model.channel_upsampler
            self.channel_downsampler = model.channel_downsampler
            self.channel_upsampler_t = model.channel_upsampler_t
            self.channel_downsampler_t = model.channel_downsampler_t
        # STFT / ISTFT parameters.
        self.nfft = model.nfft
        self.hop_length = model.hop_length
        # Real-valued STFT / ISTFT modules.
        self.real_stft = RealSTFT(model.nfft, model.hop_length)
        # The ISTFT sees le + 4 frames: _real_spec keeps le frames and
        # _real_ispec pads 2 more on each side.
        le = int(math.ceil(segment_samples / model.hop_length))
        num_frames_istft = le + 4
        self.real_istft = RealISTFT(model.nfft, model.hop_length, num_frames_istft)
        # nn.MultiheadAttention -> ManualMHA (fused op not supported by coremltools).
        _replace_mha_recursive(self)

    def _real_spec(self, mix: torch.Tensor) -> torch.Tensor:
        """
        Real-valued STFT + trim.

        The signal is reflect-padded by 3/2 hop on the left and up to a whole
        number of hops (plus 3/2 hop) on the right. The last frequency bin is
        dropped, and frames [2, 2 + le) are kept, where le = ceil(T / hop).

        Input: (B, C, T)
        Output: (B, C, Fr, le, 2) -- trimmed, real
        """
        hl = self.hop_length
        length = mix.shape[-1]
        le = int(math.ceil(length / hl))
        pad = hl // 2 * 3
        x = _pad1d(mix, (pad, pad + le * hl - length), mode="reflect")
        z_ri = self.real_stft(x)  # (B, C, Fr, frames, 2)
        # Trim: drop the last freq bin, keep frames [2 : 2+le].
        z_ri = z_ri[:, :, :-1, :, :]
        z_ri = z_ri[:, :, :, 2:2 + le, :]
        return z_ri

    def _real_magnitude(self, z_ri: torch.Tensor) -> torch.Tensor:
        """
        cac=True: real/imag become channels.

        Input: (B, C, Fr, T, 2)
        Output: (B, C*2, Fr, T); channel 2*c + r holds part r (0=real, 1=imag)
            of input channel c.
        """
        # (B, C, Fr, T, 2) -> (B, C, 2, Fr, T) -> (B, C*2, Fr, T).
        B, C, Fr, T, _ = z_ri.shape
        m = z_ri.permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _real_mask(self, m: torch.Tensor) -> torch.Tensor:
        """
        cac=True: network output -> real/imag tensor (inverse of _real_magnitude
        per source).

        Input: (B, S, C*2, Fr, T) -- denormalized network output
        Output: (B*S*C, Fr, T, 2) -- ready for RealISTFT
        """
        B, S, _, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = out.reshape(B * S * (out.shape[2]), Fr, T, 2)
        return out

    def _real_ispec(self, z_ri: torch.Tensor, length: int) -> torch.Tensor:
        """
        Real-valued ISTFT; undoes the padding/trimming of _real_spec.

        Input: (batch, Fr, T, 2)
        Output: (batch, length)
        """
        hl = self.hop_length
        # Pad freq: add 1 bin at the end.
        z_ri = F.pad(z_ri, (0, 0, 0, 0, 0, 1))
        # Pad frames: add 2 on each side.
        z_ri = F.pad(z_ri, (0, 0, 2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = self.real_istft(z_ri, le)
        x = x[:, pad:pad + length]
        return x

    def forward(self, mix: torch.Tensor) -> torch.Tensor:
        """
        Separate one fixed-size segment into stems.

        Input: (B, audio_channels, segment_samples)
        Output: (B, S, audio_channels, segment_samples), sources reordered with
            SOURCE_REORDER -- [vocals, drums, bass, other] for htdemucs.

        Raises:
            ValueError: if ``mix`` does not have the shape the ISTFT was built for.
        """
        # RealISTFT's overlap-add indices are precomputed for segment_samples;
        # any other length would fail (or silently truncate) deep inside it.
        if (mix.dim() != 3
                or mix.shape[1] != self.audio_channels
                or mix.shape[-1] != self.segment_samples):
            raise ValueError(
                f"expected input of shape (B, {self.audio_channels}, "
                f"{self.segment_samples}), got {tuple(mix.shape)}"
            )
        length = mix.shape[-1]
        # --- Frequency branch: real-valued STFT ---
        z_ri = self._real_spec(mix)  # (B, C, Fr, T, 2)
        mag = self._real_magnitude(z_ri)  # (B, C*2, Fr, T) float
        x = mag
        B, _, Fq, T = x.shape
        # Normalize.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # --- Time branch ---
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)
        # --- Encoder ---
        saved = []
        saved_t = []
        lengths = []
        lengths_t = []
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # Time-branch skip connection.
                    saved_t.append(xt)
                else:
                    # Branches merge here: feed the time features into the freq encoder.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb
            saved.append(x)
        # --- Cross-Transformer ---
        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                from einops import rearrange
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)
            x, xt = self.crosstransformer(x, xt)
            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)
        # --- Decoder ---
        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)
        # --- Frequency branch: denormalize + mask ---
        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]
        # _real_mask -> (B*S*C, Fr, T, 2)
        zout_ri = self._real_mask(x)
        # Real-valued ISTFT.
        x_freq = self._real_ispec(zout_ri, length)
        # x_freq: (B*S*C, length) -> (B, S, C, length)
        x_freq = x_freq.view(B, S, mix.shape[1], length)
        # --- Time branch: denormalize ---
        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        # --- Combine ---
        x_out = x_freq + xt
        # Reorder sources: drums,bass,other,vocals -> vocals,drums,bass,other.
        x_out = x_out[:, SOURCE_REORDER, :, :]
        return x_out
# ---------------------------------------------------------------------------
# Metadata
# ---------------------------------------------------------------------------
def _add_metadata(mlmodel, segment_samples: int) -> None:
    """Attach author, license, description and I/O documentation to *mlmodel*.

    Mutates the model in place. It documents the input named ``audio`` and
    the output named ``sources``, with the given segment length in samples.
    """
    stems = ", ".join(SOURCE_NAMES)
    mlmodel.author = "HTDemucs CoreML conversion"
    mlmodel.license = (
        "MIT. Original Demucs: Copyright (c) Meta Platforms, Inc. and "
        "affiliates, MIT License. See LICENSE and ATTRIBUTION."
    )
    mlmodel.short_description = (
        f"Hybrid Transformer Demucs (HTDemucs) -- music source separation "
        f"into {stems} at {SAMPLE_RATE} Hz."
    )
    inputs = mlmodel.input_description
    inputs["audio"] = (
        f"Stereo audio. Shape (1, 2, {segment_samples}), Float32, "
        f"{SAMPLE_RATE} Hz."
    )
    outputs = mlmodel.output_description
    outputs["sources"] = (
        f"Separated stems. Shape (1, 4, 2, {segment_samples}). "
        f"Order: [{stems}]."
    )
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the converter.

    Args:
        argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        Namespace with ``segment`` (seconds, float), ``fp16`` (bool),
        ``output`` (str or None) and ``compute_units`` (str).

    Exits through ``argparse`` (status 2) on invalid arguments, including a
    ``--segment`` that is not a positive, finite number of seconds.
    """
    p = argparse.ArgumentParser(
        description="Convert Demucs (HTDemucs) to Core ML mlpackage."
    )
    p.add_argument(
        "--segment", type=float, default=10.0,
        help="Segment length in seconds (default: 10.0).",
    )
    p.add_argument(
        "--fp16", action="store_true",
        help="Quantize to FP16 (~half the file size, minor accuracy loss).",
    )
    p.add_argument(
        "--output", type=str, default=None,
        help="Output mlpackage path (default: HTDemucs_CoreML[_FP16].mlpackage).",
    )
    p.add_argument(
        "--compute-units", choices=["cpu_and_gpu", "all", "cpu_only"],
        default="cpu_and_gpu",
        help="Default ComputeUnit baked into the model (default: cpu_and_gpu). "
             "HTDemucs is unstable on the Neural Engine -- keep 'cpu_and_gpu' "
             "unless you have specifically validated 'all'.",
    )
    args = p.parse_args(argv)
    # The segment becomes a sample count; zero, negative, NaN or infinite
    # values would only fail later with an obscure error (chained comparison
    # is False for NaN).
    if not 0 < args.segment < math.inf:
        p.error(f"--segment must be a positive number of seconds, got {args.segment}")
    return args
def _validate_coreml(output_path: str, compute_units, audio: torch.Tensor,
                     ref: np.ndarray, threshold: float) -> None:
    """Reload the saved package and compare its prediction with *ref* (best-effort).

    If the model cannot be loaded or run here (e.g. Core ML prediction is
    unavailable on this platform), the validation is reported as skipped.
    A shape mismatch or a large numerical difference is reported as a
    validation problem, not as a skip.
    """
    import coremltools as ct
    # Reload with the SAME compute_units the model was converted for.
    # MLModel(path) without compute_units defaults to ComputeUnit.ALL, which on
    # HTDemucs may dispatch to the ANE and crash with E5RT errors.
    try:
        mlmodel_loaded = ct.models.MLModel(output_path, compute_units=compute_units)
        cml_out = mlmodel_loaded.predict({"audio": audio.numpy()})["sources"]
    except Exception as e:  # deliberate best-effort: only load/predict are guarded
        print(f" Validation skipped: {e}")
        return
    if ref.shape != cml_out.shape:
        print(f" Validation FAILED: shape mismatch {ref.shape} vs {cml_out.shape}")
        return
    abs_diff = np.abs(ref - cml_out)
    max_diff = float(np.max(abs_diff))
    mean_diff = float(np.mean(abs_diff))
    print(f" Max diff: {max_diff:.6f}")
    print(f" Mean diff: {mean_diff:.6f}")
    if max_diff < threshold:
        print(" Validation OK.")
    else:
        print(f" Large numerical drift (max diff {max_diff:.6f} >= {threshold}).")


def main() -> None:
    """Convert pretrained HTDemucs to a Core ML package and sanity-check it.

    Steps:
    1. Load the Demucs model.
    2. Wrap it with the real-valued STFT/ISTFT.
    3. Run a PyTorch forward pass.
    4. Trace it and convert with coremltools.
    5. Save the package with metadata.
    6. Compare Core ML output with the PyTorch output (best-effort).
    """
    import coremltools as ct
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)
    args = parse_args()
    segment_samples = int(round(args.segment * SAMPLE_RATE))
    output_path = args.output or (
        "HTDemucs_CoreML_FP16.mlpackage" if args.fp16 else DEFAULT_OUTPUT
    )
    precision = ct.precision.FLOAT16 if args.fp16 else ct.precision.FLOAT32
    compute_units = {
        "cpu_and_gpu": ct.ComputeUnit.CPU_AND_GPU,
        "all": ct.ComputeUnit.ALL,
        "cpu_only": ct.ComputeUnit.CPU_ONLY,
    }[args.compute_units]
    print("=" * 60)
    print(" HTDemucs -> Core ML Converter")
    print(" (real-valued STFT / ISTFT wrapper)")
    print("=" * 60)
    print(f" Model: {MODEL_NAME}")
    print(f" Sample rate: {SAMPLE_RATE} Hz")
    print(f" Segment: {segment_samples} samples ({args.segment:.1f}s)")
    print(f" Stems: {', '.join(SOURCE_NAMES)}")
    print(f" Precision: {'FP16' if args.fp16 else 'FP32'}")
    print(f" Compute: {args.compute_units}")
    print(f" Output: {output_path}")
    print("=" * 60)
    # --- Load model ---
    print(f"\n[1/5] Loading Demucs '{MODEL_NAME}' ...")
    from demucs.pretrained import get_model
    bag = get_model(MODEL_NAME)
    model = bag.models[0]
    model.eval()
    model.use_train_segment = False
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" {num_params:.1f}M parameters loaded.")
    # --- Build wrapper ---
    print("\n[2/5] Building real-valued wrapper ...")
    wrapper = RealValuedHTDemucs(model, segment_samples=segment_samples)
    wrapper.eval()
    dummy = torch.randn(1, NUM_CHANNELS, segment_samples)
    # --- PyTorch sanity check ---
    print("\n[3/5] PyTorch forward pass ...")
    with torch.no_grad():
        out_wrapper = wrapper(dummy)
    print(f" Output shape: {out_wrapper.shape}")
    expected = (1, NUM_SOURCES, NUM_CHANNELS, segment_samples)
    # Explicit check (not assert) so it still runs under `python -O`.
    if out_wrapper.shape != expected:
        raise RuntimeError(f"Shape {out_wrapper.shape} != {expected}")
    print(" OK.")
    # --- Trace ---
    print("\n[4/5] torch.jit.trace ...")
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, dummy, strict=False)
    print(" Trace OK.")
    # --- Core ML conversion ---
    print("\n[5/5] Core ML conversion ...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(
                name="audio",
                shape=(1, NUM_CHANNELS, segment_samples),
                dtype=np.float32,
            )
        ],
        outputs=[ct.TensorType(name="sources")],
        convert_to="mlprogram",
        compute_units=compute_units,
        compute_precision=precision,
        minimum_deployment_target=ct.target.macOS14,
    )
    _add_metadata(mlmodel, segment_samples)
    mlmodel.save(output_path)
    # --- Validation ---
    # The wrapper is deterministic in eval mode, so the step-3 output is
    # reused as the reference instead of running the network again.
    print("\n[Val] Validating Core ML vs. PyTorch reference ...")
    _validate_coreml(
        output_path,
        compute_units,
        dummy,
        out_wrapper.numpy(),
        threshold=0.2 if args.fp16 else 0.1,
    )
    # --- Summary ---
    size_mb = sum(
        f.stat().st_size for f in Path(output_path).rglob("*") if f.is_file()
    ) / (1024 * 1024)
    print("\n" + "=" * 60)
    print(f" Done: {output_path} ({size_mb:.0f} MB)")
    print()
    print(" Next step: drag the .mlpackage into your Xcode project")
    print(" and load it via MLModel(contentsOf: ...). See examples/swift/.")
    print("=" * 60)


if __name__ == "__main__":
    main()