Kyle Howells commited on
Commit
8abe262
·
1 Parent(s): fc73d82

Add conversion script and update README with conversion instructions

Browse files
Files changed (2) hide show
  1. README.md +26 -7
  2. convert_deepfilternet.py +196 -0
README.md CHANGED
@@ -53,15 +53,16 @@ All versions share the same audio parameters:
53
  ## Files
54
 
55
  ```
 
56
  v1/
57
- config.json # v1 architecture configuration
58
- model.safetensors # v1 weights
59
  v2/
60
- config.json # v2 architecture configuration
61
- model.safetensors # v2 weights
62
  v3/
63
- config.json # v3 architecture configuration
64
- model.safetensors # v3 weights
65
  ```
66
 
67
  ## Usage
@@ -90,11 +91,29 @@ let model = try await DeepFilterNetModel.fromPretrained("mlx-community/DeepFilte
90
  let enhanced = try model.enhance(audioArray)
91
  ```
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  ## Origin
94
 
95
  - **Original model:** [DeepFilterNet](https://github.com/Rikorose/DeepFilterNet) by Hendrik Schroeter
96
  - **License:** MIT (same as the original)
97
- - **Conversion:** PyTorch → `safetensors`
98
 
99
  ## Citations
100
 
 
53
  ## Files
54
 
55
  ```
56
+ convert_deepfilternet.py # PyTorch → MLX conversion script
57
  v1/
58
+ config.json # v1 architecture configuration
59
+ model.safetensors # v1 weights
60
  v2/
61
+ config.json # v2 architecture configuration
62
+ model.safetensors # v2 weights
63
  v3/
64
+ config.json # v3 architecture configuration
65
+ model.safetensors # v3 weights
66
  ```
67
 
68
  ## Usage
 
91
  let enhanced = try model.enhance(audioArray)
92
  ```
93
 
94
+ ## Converting from PyTorch
95
+
96
+ To re-create these weights from the original DeepFilterNet checkpoints:
97
+
98
+ ```bash
99
+ # Clone the original repo to get the pretrained checkpoints
100
+ git clone https://github.com/Rikorose/DeepFilterNet
101
+
102
+ # Convert each version
103
+ python convert_deepfilternet.py --input DeepFilterNet/DeepFilterNet --output v1 --name DeepFilterNet
104
+ python convert_deepfilternet.py --input DeepFilterNet/DeepFilterNet2 --output v2 --name DeepFilterNet2
105
+ python convert_deepfilternet.py --input DeepFilterNet/DeepFilterNet3 --output v3 --name DeepFilterNet3
106
+ ```
107
+
108
+ Each input directory should contain a `config.ini` and a `checkpoints/` folder from the original repo.
109
+
110
+ Requires `torch` and `mlx` to be installed.
111
+
112
  ## Origin
113
 
114
  - **Original model:** [DeepFilterNet](https://github.com/Rikorose/DeepFilterNet) by Hendrik Schroeter
115
  - **License:** MIT (same as the original)
116
+ - **Conversion:** PyTorch → `safetensors` via `convert_deepfilternet.py`
117
 
118
  ## Citations
119
 
convert_deepfilternet.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert DeepFilterNet PyTorch weights to MLX format.
4
+
5
+ This script converts pretrained DeepFilterNet models from the original
6
+ PyTorch implementation to MLX-compatible format with proper weight mapping.
7
+ """
8
+
9
+ import argparse
10
+ import json
11
+ import configparser
12
+ from pathlib import Path
13
+ from typing import Dict, Any, List, Tuple
14
+ import re
15
+
16
+ import mlx.core as mx
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
+ def convert_weight(weight: torch.Tensor) -> mx.array:
22
+ """Convert PyTorch tensor to MLX array."""
23
+ return mx.array(weight.detach().cpu().numpy())
24
+
25
+
26
+ def parse_config(config_path: Path) -> Dict[str, Any]:
27
+ """Parse DeepFilterNet config.ini file."""
28
+ config = configparser.ConfigParser()
29
+ config.read(config_path)
30
+
31
+ linear_groups = config.getint("deepfilternet", "linear_groups", fallback=16)
32
+ df_order = config.getint(
33
+ "df",
34
+ "df_order",
35
+ fallback=config.getint("deepfilternet", "df_order", fallback=5),
36
+ )
37
+ df_lookahead = config.getint(
38
+ "df",
39
+ "df_lookahead",
40
+ fallback=config.getint("deepfilternet", "df_lookahead", fallback=0),
41
+ )
42
+
43
+ result = {
44
+ # [df] section
45
+ "sample_rate": config.getint("df", "sr", fallback=48000),
46
+ "fft_size": config.getint("df", "fft_size", fallback=960),
47
+ "hop_size": config.getint("df", "hop_size", fallback=480),
48
+ "nb_erb": config.getint("df", "nb_erb", fallback=32),
49
+ "nb_df": config.getint("df", "nb_df", fallback=96),
50
+ "df_order": df_order,
51
+ "df_lookahead": df_lookahead,
52
+ "lsnr_max": config.getint("df", "lsnr_max", fallback=35),
53
+ "lsnr_min": config.getint("df", "lsnr_min", fallback=-15),
54
+
55
+ # [deepfilternet] section
56
+ "conv_ch": config.getint("deepfilternet", "conv_ch", fallback=64),
57
+ "conv_k_enc": config.getint("deepfilternet", "conv_k_enc", fallback=1),
58
+ "conv_k_dec": config.getint("deepfilternet", "conv_k_dec", fallback=1),
59
+ "conv_width_factor": config.getint("deepfilternet", "conv_width_factor", fallback=1),
60
+ "conv_dec_mode": config.get("deepfilternet", "conv_dec_mode", fallback="transposed"),
61
+ "emb_hidden_dim": config.getint("deepfilternet", "emb_hidden_dim", fallback=256),
62
+ "emb_num_layers": config.getint("deepfilternet", "emb_num_layers", fallback=3),
63
+ "df_hidden_dim": config.getint("deepfilternet", "df_hidden_dim", fallback=256),
64
+ "df_num_layers": config.getint("deepfilternet", "df_num_layers", fallback=2),
65
+ "gru_groups": config.getint("deepfilternet", "gru_groups", fallback=8),
66
+ "linear_groups": linear_groups,
67
+ # DeepFilterNet2 configs do not expose enc_linear_groups separately; in that case it
68
+ # should follow linear_groups to keep grouped-linear tensor shapes aligned.
69
+ "enc_linear_groups": config.getint("deepfilternet", "enc_linear_groups", fallback=linear_groups),
70
+ "group_shuffle": config.getboolean("deepfilternet", "group_shuffle", fallback=False),
71
+ "mask_pf": config.getboolean("deepfilternet", "mask_pf", fallback=False),
72
+ "conv_lookahead": config.getint("deepfilternet", "conv_lookahead", fallback=2),
73
+ "conv_depthwise": config.getboolean("deepfilternet", "conv_depthwise", fallback=True),
74
+ "convt_depthwise": config.getboolean("deepfilternet", "convt_depthwise", fallback=False),
75
+ "enc_concat": config.getboolean("deepfilternet", "enc_concat", fallback=False),
76
+ "emb_gru_skip_enc": config.get("deepfilternet", "emb_gru_skip_enc", fallback="none"),
77
+ "emb_gru_skip": config.get("deepfilternet", "emb_gru_skip", fallback="none"),
78
+ "df_gru_skip": config.get("deepfilternet", "df_gru_skip", fallback="groupedlinear"),
79
+ "dfop_method": config.get("deepfilternet", "dfop_method", fallback="real_unfold"),
80
+ }
81
+
82
+ # Parse conv_kernel strings
83
+ conv_kernel = config.get("deepfilternet", "conv_kernel", fallback="1,3")
84
+ result["conv_kernel"] = [int(x) for x in conv_kernel.split(",")]
85
+
86
+ convt_kernel = config.get("deepfilternet", "convt_kernel", fallback="1,3")
87
+ result["convt_kernel"] = [int(x) for x in convt_kernel.split(",")]
88
+
89
+ conv_kernel_inp = config.get("deepfilternet", "conv_kernel_inp", fallback="3,3")
90
+ result["conv_kernel_inp"] = [int(x) for x in conv_kernel_inp.split(",")]
91
+
92
+ return result
93
+
94
+
95
+ def convert_pytorch_to_mlx(
96
+ checkpoint_path: Path,
97
+ config_path: Path,
98
+ output_dir: Path,
99
+ model_name: str = "DeepFilterNet3",
100
+ ):
101
+ """Convert PyTorch checkpoint to MLX format with proper weight mapping."""
102
+
103
+ print(f"Loading checkpoint from {checkpoint_path}")
104
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
105
+
106
+ # Get state dict
107
+ if "state_dict" in ckpt:
108
+ state_dict = ckpt["state_dict"]
109
+ elif "model_state_dict" in ckpt:
110
+ state_dict = ckpt["model_state_dict"]
111
+ else:
112
+ state_dict = ckpt
113
+
114
+ print(f"Found {len(state_dict)} parameters in checkpoint")
115
+
116
+ # Parse config
117
+ print(f"Parsing config from {config_path}")
118
+ config_dict = parse_config(config_path)
119
+ config_dict["model_version"] = model_name
120
+
121
+ # Print weight shapes for debugging
122
+ print("\nPyTorch weight shapes:")
123
+ for key, value in list(state_dict.items())[:20]:
124
+ print(f" {key}: {tuple(value.shape)}")
125
+ print(" ...")
126
+
127
+ # Convert weights - direct mapping since we'll match the architecture
128
+ print("\nConverting weights to MLX format...")
129
+ mlx_weights = {}
130
+
131
+ for key, value in state_dict.items():
132
+ # Skip buffers that aren't needed for inference
133
+ if "num_batches_tracked" in key:
134
+ continue
135
+
136
+ # Convert weight
137
+ mlx_array = convert_weight(value)
138
+ mlx_weights[key] = mlx_array
139
+
140
+ print(f"Converted {len(mlx_weights)} weights")
141
+
142
+ # Create output directory
143
+ output_dir.mkdir(parents=True, exist_ok=True)
144
+
145
+ # Save weights
146
+ weights_path = output_dir / "model.safetensors"
147
+ print(f"Saving weights to {weights_path}")
148
+ mx.save_safetensors(str(weights_path), mlx_weights)
149
+
150
+ # Save config
151
+ config_out_path = output_dir / "config.json"
152
+ print(f"Saving config to {config_out_path}")
153
+ with open(config_out_path, "w") as f:
154
+ json.dump(config_dict, f, indent=2)
155
+
156
+ print(f"\nConversion complete! Output saved to {output_dir}")
157
+ print(f" - model.safetensors: {weights_path.stat().st_size / 1024 / 1024:.1f} MB")
158
+ print(f" - config.json")
159
+
160
+ return mlx_weights, config_dict
161
+
162
+
163
+ def main():
164
+ parser = argparse.ArgumentParser(description="Convert DeepFilterNet PyTorch weights to MLX")
165
+ parser.add_argument("--input", type=str, required=True, help="Path to DeepFilterNet model directory")
166
+ parser.add_argument("--output", type=str, required=True, help="Output directory for MLX model")
167
+ parser.add_argument("--name", type=str, default="DeepFilterNet3", help="Model name")
168
+ args = parser.parse_args()
169
+
170
+ input_dir = Path(args.input)
171
+ output_dir = Path(args.output)
172
+
173
+ # Find checkpoint
174
+ checkpoint_dir = input_dir / "checkpoints"
175
+ if checkpoint_dir.exists():
176
+ # Look for best checkpoint
177
+ checkpoints = list(checkpoint_dir.glob("*.best"))
178
+ if not checkpoints:
179
+ checkpoints = list(checkpoint_dir.glob("*.ckpt"))
180
+ if checkpoints:
181
+ checkpoint_path = checkpoints[0]
182
+ else:
183
+ raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
184
+ else:
185
+ raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
186
+
187
+ # Find config
188
+ config_path = input_dir / "config.ini"
189
+ if not config_path.exists():
190
+ raise FileNotFoundError(f"Config file not found: {config_path}")
191
+
192
+ convert_pytorch_to_mlx(checkpoint_path, config_path, output_dir, args.name)
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()