File size: 18,115 Bytes
d4519e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
#!/usr/bin/env python3
"""
Export Demucs PyTorch models directly to safetensors + JSON config for Swift MLX.

Converts all 8 pretrained models directly from the original PyTorch demucs package.
No dependency on demucs-mlx or any other re-implementation.

Usage:
    # Export all models
    python scripts/export_from_pytorch.py --out-dir ~/.cache/demucs-mlx-swift-models

    # Export specific models
    python scripts/export_from_pytorch.py --models htdemucs htdemucs_ft --out-dir ./Models

Requirements:
    pip install demucs safetensors numpy
"""
from __future__ import annotations

import argparse
import inspect
import json
import re
import sys
from fractions import Fraction
from pathlib import Path

import numpy as np
import torch

# Pretrained model names exported when --models is not given, in export order.
ALL_MODELS = [
    "htdemucs", "htdemucs_ft", "htdemucs_6s",
    "hdemucs_mmi",
    "mdx", "mdx_extra",
    "mdx_q", "mdx_extra_q",
]

# PyTorch model class name -> class name the Swift MLX loader expects in the config.
# Unknown class names are passed through unchanged by the exporter.
CLASS_MAP = dict(
    Demucs="DemucsMLX",
    HDemucs="HDemucsMLX",
    HTDemucs="HTDemucsMLX",
)

# Module names whose weight/bias get an extra ".conv." path component in MLX
# (e.g. "encoder.0.conv.weight" -> "encoder.0.conv.conv.weight").
CONV_LAYER_NAMES = set((
    "conv",
    "conv_tr",
    "rewrite",
    "channel_upsampler",
    "channel_downsampler",
    "channel_upsampler_t",
    "channel_downsampler_t",
))

# Names of the LocalState attention sub-modules found inside DConv blocks.
# NOTE: not referenced elsewhere in this script; remap_key spells the same
# names out in its attention regex.
DCONV_ATTN_NAMES = set((
    "content",
    "key",
    "query",
    "proj",
    "query_decay",
    "query_freqs",
))


def to_json_serializable(obj):
    """Recursively convert *obj* into plain JSON-serializable Python values.

    Conversions:
        - ``Fraction``      -> ``"numerator/denominator"`` string
        - ``torch.Tensor``  -> Python scalar if it has exactly one element,
          otherwise a (nested) list
        - ``np.ndarray``    -> (nested) list
        - NumPy scalar (``np.generic``: ``np.int64``, ``np.float32``,
          ``np.bool_``, ...) -> equivalent Python scalar
        - ``list`` / ``tuple`` -> list with each element converted
        - ``dict``          -> dict with ``str`` keys and converted values

    Any other object is returned unchanged.
    """
    if isinstance(obj, Fraction):
        return f"{obj.numerator}/{obj.denominator}"
    if isinstance(obj, torch.Tensor):
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars (e.g. np.int64, np.float32) are not accepted by json;
        # without this they would be stringified by json.dump(default=str).
        return obj.item()
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(x) for x in obj]
    if isinstance(obj, dict):
        return {str(k): to_json_serializable(v) for k, v in obj.items()}
    return obj


def transpose_conv_weights(key: str, value: np.ndarray, is_conv_transpose: bool = False) -> np.ndarray:
    """Permute a PyTorch convolution weight into the channels-last MLX layout.

    Only arrays whose key ends in ``.weight`` and that have 3 or 4 dimensions
    are permuted; everything else (biases, norm weights, 1D/2D weights) is
    returned as-is.

    Layouts handled:
        Conv1d           (out, in, k)    -> (out, k, in)       axes (0, 2, 1)
        Conv2d           (out, in, h, w) -> (out, h, w, in)    axes (0, 2, 3, 1)
        ConvTranspose1d  (in, out, k)    -> (out, k, in)       axes (1, 2, 0)
        ConvTranspose2d  (in, out, h, w) -> (out, h, w, in)    axes (1, 2, 3, 0)

    Args:
        key: state-dict key of the array.
        value: the array to (possibly) permute.
        is_conv_transpose: True if the array belongs to a ConvTranspose module.

    Returns:
        The permuted view, or ``value`` itself when no permutation applies.
    """
    # (rank, is_transposed_conv) -> axes for np.transpose
    axes_by_kind = {
        (3, False): (0, 2, 1),
        (3, True): (1, 2, 0),
        (4, False): (0, 2, 3, 1),
        (4, True): (1, 2, 3, 0),
    }

    axes = None
    if key.endswith(".weight"):
        axes = axes_by_kind.get((len(value.shape), bool(is_conv_transpose)))

    if axes is None:
        return value
    return np.transpose(value, axes)


def remap_key(
    key: str,
    value: np.ndarray,
    model_type: str = "HTDemucs",
    dconv_conv_slots: set | None = None,
    seq_conv_slots: set | None = None,
) -> list[tuple[str, np.ndarray]]:
    """Remap a PyTorch state dict key to MLX key convention.

    The steps below are tried in order, and every step except Step 1 returns
    as soon as its pattern matches, so the ordering is significant. For
    example, a DConv LSTM key (``...layers.0.3.lstm.weight_ih_l0``) also
    matches the Step 5 pattern, but Step 2 claims it first and inserts the
    extra ``.layers.`` component.

    Values are returned unchanged, except for the MultiheadAttention
    ``in_proj`` split in Step 3, which slices ``value`` into three equal
    blocks along axis 0.

    Returns a list of (key, value) pairs (multiple for attention in_proj splits).
    Duplicate target keys (e.g. LSTM bias_ih + bias_hh) are merged by the caller.

    Args:
        key: PyTorch state dict key
        value: numpy array (already transposed for conv weights)
        model_type: PyTorch class name ("Demucs", "HDemucs", "HTDemucs")
        dconv_conv_slots: set of (block_prefix, slot_str) for DConv slots whose
            ``weight`` has 2+ dimensions; consulted only to decide whether a
            slot's ``bias`` gets the ``.conv.`` wrapper
        seq_conv_slots: set of (enc_dec, layer, slot) for Demucs v1/v2
            Sequential slots whose ``weight`` has 2+ dimensions (Conv slots)
    """
    dconv_conv_slots = dconv_conv_slots or set()
    seq_conv_slots = seq_conv_slots or set()

    # =========================================================================
    # Step 1: Demucs v1/v2 Sequential insertion
    # encoder.{i}.{j}.rest → encoder.{i}.layers.{j}.rest
    # decoder.{i}.{j}.rest → decoder.{i}.layers.{j}.rest
    # This step does not return: the rewritten key falls through to later steps.
    # =========================================================================
    if model_type == "Demucs":
        m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", key)
        if m:
            enc_dec, layer, slot, rest = m.groups()
            rest = rest or ""
            key = f"{enc_dec}.{layer}.layers.{slot}{rest}"

    # =========================================================================
    # Step 1.5: Demucs v1/v2 Sequential Conv/Norm slot wrapping
    # encoder.{i}.layers.{j}.weight → encoder.{i}.layers.{j}.conv.weight (if Conv slot)
    # Non-Conv slots (e.g. norms) return with their key unchanged.
    # =========================================================================
    if model_type == "Demucs":
        m = re.match(r"(encoder|decoder)\.(\d+)\.layers\.(\d+)\.(weight|bias)$", key)
        if m:
            enc_dec, layer, slot, param = m.groups()
            if (enc_dec, layer, slot) in seq_conv_slots:
                return [(f"{enc_dec}.{layer}.layers.{slot}.conv.{param}", value)]
            else:
                return [(f"{enc_dec}.{layer}.layers.{slot}.{param}", value)]

    # =========================================================================
    # Step 2: DConv internal slot handling
    # Matches: *.layers.{block_idx}.{slot_idx}.{rest}
    # Both HDemucs (.dconv.layers.) and Demucs v1/v2 (.layers.{N}.layers.) end
    # with this pattern after Step 1.
    # The greedy (.+) makes block_prefix end at the LAST ".layers.{N}" that is
    # followed by a numeric slot index. Output inserts ".layers." between the
    # block and the slot: {block_prefix}.layers.{slot}.{...}
    # =========================================================================
    m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.(.+)$", key)
    if m:
        block_prefix = m.group(1)
        slot = m.group(2)
        rest = m.group(3)

        # --- 2a. Simple weight/bias/scale ---
        if rest in ("weight", "bias", "scale"):
            if rest == "weight" and len(value.shape) >= 2:
                # weight with 2+ dims is treated as a Conv1d → add .conv.
                return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
            elif rest == "weight":
                # 1D weight is treated as a norm (GroupNorm) → no wrapper
                return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
            elif rest == "bias":
                # A bias has no shape hint, so look up whether this slot's
                # weight was classified as a conv during the caller's pre-scan.
                if (block_prefix, slot) in dconv_conv_slots:
                    return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
                else:
                    return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
            else:  # scale
                return [(f"{block_prefix}.layers.{slot}.{rest}", value)]

        # --- 2b. LSTM weights/biases ---
        # lstm.weight_ih_l{n}[_reverse] → {forward|backward}_lstms.{n}.Wx
        # lstm.weight_hh_l{n}[_reverse] → {forward|backward}_lstms.{n}.Wh
        m_lstm = re.match(r"lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", rest)
        if m_lstm:
            wb, ih_hh, layer_idx, reverse = m_lstm.groups()
            direction = "backward_lstms" if reverse else "forward_lstms"
            if wb == "weight":
                param = "Wx" if ih_hh == "ih" else "Wh"
                return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.{param}", value)]
            else:  # bias — both bias_ih and bias_hh map to same key; caller merges
                return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.bias", value)]

        # --- 2c. LSTM linear ---
        m_linear = re.match(r"linear\.(weight|bias)$", rest)
        if m_linear:
            param = m_linear.group(1)
            return [(f"{block_prefix}.layers.{slot}.linear.{param}", value)]

        # --- 2d. Attention sub-modules (LocalState) ---
        m_attn = re.match(r"(content|key|query|proj|query_decay|query_freqs)\.(weight|bias)$", rest)
        if m_attn:
            attn_name, param = m_attn.groups()
            # These are all Conv1d modules → add .conv. wrapper
            return [(f"{block_prefix}.layers.{slot}.{attn_name}.conv.{param}", value)]

        # --- 2e. Fallback for unknown compound keys ---
        # Still inserts ".layers." so the path shape matches the cases above.
        return [(f"{block_prefix}.layers.{slot}.{rest}", value)]

    # =========================================================================
    # Step 3: MultiheadAttention in_proj split (HTDemucs transformer)
    # Packed in_proj_{weight,bias} is split into three equal parts along
    # axis 0 in (query, key, value) order; self_attn is renamed to attn.
    # =========================================================================
    m = re.match(r"(.+)\.(self_attn|cross_attn)\.in_proj_(weight|bias)$", key)
    if m:
        prefix, attn_type, param = m.group(1), m.group(2), m.group(3)
        mlx_attn = "attn" if attn_type == "self_attn" else "cross_attn"
        dim = value.shape[0] // 3
        q, k_val, v = value[:dim], value[dim : 2 * dim], value[2 * dim :]
        return [
            (f"{prefix}.{mlx_attn}.query_proj.{param}", q),
            (f"{prefix}.{mlx_attn}.key_proj.{param}", k_val),
            (f"{prefix}.{mlx_attn}.value_proj.{param}", v),
        ]

    # self_attn.out_proj → attn.out_proj
    # (cross_attn.out_proj keeps its name and falls through to the default.)
    m = re.match(r"(.+)\.self_attn\.out_proj\.(weight|bias)$", key)
    if m:
        prefix, param = m.group(1), m.group(2)
        return [(f"{prefix}.attn.out_proj.{param}", value)]

    # =========================================================================
    # Step 4: norm_out wrapping → norm_out.gn
    # =========================================================================
    m = re.match(r"(.+)\.norm_out\.(weight|bias)$", key)
    if m:
        prefix, param = m.group(1), m.group(2)
        return [(f"{prefix}.norm_out.gn.{param}", value)]

    # =========================================================================
    # Step 5: Bottleneck LSTM (Demucs v1/v2 and HDemucs)
    # lstm.lstm.weight_ih_l0 → lstm.forward_lstms.0.Wx
    # (DConv-internal LSTM keys never reach here; Step 2 handles them.)
    # =========================================================================
    m = re.match(r"(.+)\.lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", key)
    if m:
        prefix = m.group(1)
        wb = m.group(2)
        ih_hh = m.group(3)
        layer_idx = m.group(4)
        reverse = m.group(5)
        direction = "backward_lstms" if reverse else "forward_lstms"
        if wb == "weight":
            param = "Wx" if ih_hh == "ih" else "Wh"
            return [(f"{prefix}.{direction}.{layer_idx}.{param}", value)]
        else:  # bias — merge handled by caller
            return [(f"{prefix}.{direction}.{layer_idx}.bias", value)]

    # =========================================================================
    # Step 6: Conv/ConvTranspose/Rewrite named layers → add .conv. wrapper
    # Applies when the module name right before weight/bias is in CONV_LAYER_NAMES.
    # =========================================================================
    parts = key.rsplit(".", 1)
    if len(parts) == 2:
        path, param = parts
        path_parts = path.split(".")
        last_name = path_parts[-1]
        if last_name in CONV_LAYER_NAMES and param in ("weight", "bias"):
            return [(f"{path}.conv.{param}", value)]

    # =========================================================================
    # Default: no change
    # =========================================================================
    return [(key, value)]


def convert_sub_model(model, prefix: str) -> dict[str, np.ndarray]:
    """Convert a single sub-model's state dict to MLX-compatible numpy arrays.

    Every state-dict tensor is detached, moved to CPU and cast to float32.
    Conv weights are permuted to the MLX layout with ``transpose_conv_weights``,
    and keys are renamed with ``remap_key``.

    Args:
        model: the PyTorch sub-model (its class name selects the remapping rules).
        prefix: string prepended verbatim to every output key (``"model_{i}."``
            for bag members, ``""`` otherwise).

    Returns:
        Mapping from MLX key to numpy array.

    Raises:
        ValueError: if two source tensors map to the same MLX key and that key
            is not an LSTM bias. LSTM ``bias_ih`` + ``bias_hh`` are the only
            collisions that are expected; they are summed into one bias.
    """
    cls_name = type(model).__name__

    # --- Pre-scan: ConvTranspose modules (their weights use a different permutation) ---
    conv_tr_paths = {
        name
        for name, module in model.named_modules()
        if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d))
    }

    # --- Collect state dict as float32 numpy ---
    state_items = [
        (key, tensor.detach().cpu().float().numpy())
        for key, tensor in model.state_dict().items()
    ]

    # --- Pre-scan: DConv Conv slots (weights with 2+ dims) ---
    # Pattern: *.layers.{block}.{slot}.weight. For Demucs v1/v2 the Sequential
    # ".layers." insertion is applied first so the lookups match remap_key.
    dconv_conv_slots: set[tuple[str, str]] = set()
    for key, arr in state_items:
        scan_key = key
        if cls_name == "Demucs":
            m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", scan_key)
            if m:
                enc_dec, layer, slot, rest = m.groups()
                scan_key = f"{enc_dec}.{layer}.layers.{slot}{rest or ''}"
        m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.weight$", scan_key)
        if m and len(arr.shape) >= 2:
            dconv_conv_slots.add((m.group(1), m.group(2)))

    # --- Pre-scan: Demucs v1/v2 Sequential Conv slots (on the raw keys) ---
    seq_conv_slots: set[tuple[str, str, str]] = set()
    if cls_name == "Demucs":
        for key, arr in state_items:
            m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)\.weight$", key)
            if m and len(arr.shape) >= 2:
                seq_conv_slots.add((m.group(1), m.group(2), m.group(3)))

    # --- Convert ---
    weights: dict[str, np.ndarray] = {}
    for key, arr in state_items:
        is_conv_tr = any(key.startswith(p + ".") for p in conv_tr_paths)
        arr = transpose_conv_weights(key, arr, is_conv_transpose=is_conv_tr)

        for new_key, new_val in remap_key(key, arr, cls_name, dconv_conv_slots, seq_conv_slots):
            full_key = f"{prefix}{new_key}"
            if full_key not in weights:
                weights[full_key] = new_val
            elif re.search(r"_lstms\.\d+\.bias$", full_key):
                # LSTM bias merge: bias_ih + bias_hh → bias (additive)
                weights[full_key] = weights[full_key] + new_val
            else:
                # Any other collision means two different tensors were mapped
                # to one name; summing them would silently corrupt the export.
                raise ValueError(
                    f"Duplicate MLX key {full_key!r} (from PyTorch key {key!r}); "
                    "only LSTM biases may be merged"
                )

    return weights


def extract_kwargs(model) -> dict:
    """Extract constructor kwargs from a model for the JSON config.

    Preferred source: ``model._init_args_kwargs``, an ``(args, kwargs)`` pair.
    Positional ``args`` are paired with their ``__init__`` parameter names so
    they are not silently dropped. Keyword arguments take precedence on a
    (theoretically impossible) name clash. Only values of simple types (int,
    float, str, bool, list, tuple, None, Fraction) are kept, each converted
    with ``to_json_serializable``.

    Fallback: for every named ``__init__`` parameter (excluding ``self`` and
    ``*args`` / ``**kwargs`` catch-alls), read a same-named attribute from the
    model if it exists.
    """
    simple_types = (int, float, str, bool, list, tuple, type(None), Fraction)

    if hasattr(model, "_init_args_kwargs"):
        args, kwargs = model._init_args_kwargs
        merged = dict(_name_positional_args(type(model), args))
        merged.update(kwargs)
        return {k: to_json_serializable(v) for k, v in merged.items()
                if isinstance(v, simple_types)}

    # Fallback: inspect __init__ signature and read matching attributes
    sig = inspect.signature(type(model).__init__)
    kwargs = {}
    for name, param in sig.parameters.items():
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if hasattr(model, name):
            kwargs[name] = to_json_serializable(getattr(model, name))
    return kwargs


def _name_positional_args(cls, args) -> list[tuple[str, object]]:
    """Pair positional constructor ``args`` with the ``__init__`` parameter names they bind to."""
    if not args:
        return []
    try:
        params = inspect.signature(cls.__init__).parameters.values()
    except (TypeError, ValueError):
        # Signature not introspectable: nothing can be named reliably.
        return []
    names = [
        p.name for p in params
        if p.name != "self" and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    # zip stops at the shorter sequence: extras absorbed by *args have no name.
    return list(zip(names, args))


def export_model(model_name: str, out_dir: Path) -> bool:
    """Export a single pretrained model (or bag of models) to safetensors + config JSON.

    Output files:
        ``<out_dir>/<model_name>/<model_name>.safetensors``
        ``<out_dir>/<model_name>/<model_name>_config.json``

    Args:
        model_name: name passed to ``demucs.pretrained.get_model``.
        out_dir: root output directory; a ``<model_name>`` subdirectory is created.

    Returns:
        True on success. False if the model could not be loaded, converted or
        written. Errors are printed (with a traceback for conversion/write
        failures) instead of raised, so ``main`` can continue with the
        remaining models and count the failure.
    """
    from demucs.pretrained import get_model
    from demucs.apply import BagOfModels

    print(f"\n--- Exporting {model_name} ---")
    try:
        model = get_model(model_name)
    except Exception as e:
        print(f"  Failed to load model: {e}")
        return False

    try:
        all_weights, config = _build_export(model_name, model, BagOfModels)
        safetensors_path, config_path = _write_export(model_name, out_dir, all_weights, config)
        size_mb = safetensors_path.stat().st_size / (1024 * 1024)
    except Exception as e:  # per-model boundary: report and let main() count it
        import traceback

        print(f"  Failed to export model: {e}")
        traceback.print_exc()
        return False

    print(f"  Wrote {safetensors_path} ({len(all_weights)} tensors, {size_mb:.0f} MB)")
    print(f"  Wrote {config_path}")
    return True


def _build_export(model_name: str, model, bag_cls) -> tuple[dict[str, np.ndarray], dict]:
    """Convert every (sub-)model's weights and build the config dict for the Swift loader."""
    is_bag = isinstance(model, bag_cls)

    if is_bag:
        sub_models = list(model.models)
        num_models = len(sub_models)
        # Bag mixing weights may be a tensor (has .tolist()) or a plain sequence.
        bag_weights = model.weights.tolist() if hasattr(model.weights, "tolist") else list(model.weights)
    else:
        sub_models = [model]
        num_models = 1
        bag_weights = None

    print(f"  {'Bag of ' + str(num_models) + ' models' if is_bag else 'Single model'}")

    # Collect all weights and metadata
    all_weights: dict[str, np.ndarray] = {}
    model_classes: list[str] = []
    model_configs: list[dict] = []

    for i, sub in enumerate(sub_models):
        cls_name = type(sub).__name__
        mlx_cls = CLASS_MAP.get(cls_name, cls_name)
        model_classes.append(mlx_cls)
        print(f"  Model {i}: {cls_name} → {mlx_cls}")

        # Bag members are namespaced so their keys cannot collide.
        prefix = f"model_{i}." if is_bag else ""
        all_weights.update(convert_sub_model(sub, prefix))

        model_configs.append({
            "model_class": mlx_cls,
            "kwargs": extract_kwargs(sub),
        })

    # Build config JSON
    config: dict = {
        "model_name": model_name,
        "tensor_count": len(all_weights),
    }

    if is_bag:
        config["model_class"] = "BagOfModelsMLX"
        config["num_models"] = num_models
        config["weights"] = bag_weights
        config["sub_model_classes"] = model_classes

        # If all sub-models are the same class, set sub_model_class for compat
        unique = set(model_classes)
        if len(unique) == 1:
            config["sub_model_class"] = unique.pop()

        config["model_configs"] = model_configs

        # Also put kwargs at top level for single-model bags (common case)
        if num_models == 1:
            config["kwargs"] = model_configs[0]["kwargs"]
    else:
        config["model_class"] = model_classes[0]
        config["kwargs"] = model_configs[0]["kwargs"]

    return all_weights, config


def _write_export(model_name: str, out_dir: Path, all_weights: dict[str, np.ndarray], config: dict) -> tuple[Path, Path]:
    """Write the safetensors weights and JSON config under ``out_dir/model_name``; return both paths."""
    model_dir = out_dir / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    safetensors_path = model_dir / f"{model_name}.safetensors"
    config_path = model_dir / f"{model_name}_config.json"

    # Prefer the safetensors library; fall back to MLX only if it is not
    # installed. The try covers just the import so that an ImportError raised
    # while saving is not mistaken for a missing library.
    try:
        from safetensors.numpy import save_file
    except ImportError:
        import mlx.core as mx
        mlx_weights = {k: mx.array(v) for k, v in all_weights.items()}
        mx.save_safetensors(str(safetensors_path), mlx_weights)
    else:
        save_file(all_weights, str(safetensors_path))

    with config_path.open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, default=str)

    return safetensors_path, config_path


def main():
    """Command-line entry point: export the requested models and print a summary.

    ``--models`` selects model names (an omitted or empty list means every
    entry of ALL_MODELS). ``--out-dir`` sets the output root. The process exits
    with status 1 if at least one export reported failure.
    """
    parser = argparse.ArgumentParser(
        description="Export Demucs PyTorch models to safetensors for Swift MLX"
    )
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help=f"Models to export (default: all). Choices: {', '.join(ALL_MODELS)}",
    )
    parser.add_argument(
        "--out-dir",
        default="./Models",
        help="Output root directory (files go into <out-dir>/<model_name>/)",
    )
    opts = parser.parse_args()

    requested = opts.models if opts.models else ALL_MODELS
    root = Path(opts.out_dir).resolve()

    # Exports run one after another, in the requested order.
    outcomes = [bool(export_model(name, root)) for name in requested]
    n_ok = outcomes.count(True)
    n_failed = len(outcomes) - n_ok

    print(f"\n=== Done: {n_ok} exported, {n_failed} failed ===")
    if n_failed:
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()