yoyolicoris commited on
Commit
1bfa935
·
1 Parent(s): 6d75109

add vocal effects style transfer demo with Gradio interface

Browse files
Files changed (1) hide show
  1. app.py +439 -0
app.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from scipy.io.wavfile import read
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from torch import Tensor
7
+ import math
8
+ import yaml
9
+ import json
10
+ import pyloudnorm as pyln
11
+ from hydra.utils import instantiate
12
+ from soxr import resample
13
+ from functools import partial, reduce
14
+ from itertools import accumulate
15
+ from torchcomp import coef2ms, ms2coef
16
+ from copy import deepcopy
17
+ from pathlib import Path
18
+ from typing import Tuple, List, Optional, Union
19
+
20
+ from modules.utils import vec2statedict, get_chunks
21
+ from modules.fx import clip_delay_eq_Q
22
+ from plot_utils import get_log_mags_from_eq
23
+
24
+
25
+ def chain_functions(*functions):
26
+ return lambda *initial_args: reduce(
27
+ lambda xs, f: f(*xs) if isinstance(xs, tuple) else f(xs),
28
+ functions,
29
+ initial_args,
30
+ )
31
+
32
+
33
+ title_md = "# Vocal Effects Style Transfer Demo"
34
+ description_md = """
35
+ This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025.
36
+ In this demo, you can upload a raw vocal audio file (in mono) and use our model to apply professional-quality vocal processing by tweaking generated effects settings to enhance your vocals!
37
+
38
+ The effects consist of series of EQ, compressor, delay, and reverb.
39
+ The generator is a PCA model derived from 365 vocal effects presets fitted with the same effects chain.
40
+ This interface allows you to control the principal components (PCs) of the generator, randomise them, and render the audio.
41
+
42
+ To give you some idea, we empirically found that the first PC controls the amount of reverb and the second PC controls the amount of brightness.
43
+ Note that adding these PCs together does not necessarily mean that their effects are additive in the final audio.
44
+ We found sometimes the effects of least important PCs are more perceptible.
45
+ Try to play around with the sliders and buttons and see what you can come up with!
46
+
47
+ > **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
48
+ """
49
+
50
+ SLIDER_MAX = 3
51
+ SLIDER_MIN = -3
52
+ NUMBER_OF_PCS = 4
53
+ TEMPERATURE = 0.7
54
+
55
+ CONFIG_PATH = {
56
+ "realtime": "presets/rt_config.yaml",
57
+ "approx": "presets/fx_config.yaml",
58
+ }
59
+
60
+ PRESET_PATH = {
61
+ "internal": Path("presets/internal/"),
62
+ "medleydb": Path("presets/medleydb/"),
63
+ }
64
+
65
+ PCA_PARAM_FILE = "gaussian.npz"
66
+ INFO_PATH = "info.json"
67
+ MASK_PATH = "feature_mask.npy"
68
+ PARAMS_PATH = "raw_params.npy"
69
+ TRAIN_INDEX_PATH = "train_index.npy"
70
+ EXAMPLE_PATH = "eleanor_erased.wav"
71
+
72
+
73
+ with open(CONFIG_PATH["approx"]) as fp:
74
+ fx_config = yaml.safe_load(fp)["model"]
75
+
76
+
77
+ def load_presets(preset_folder: Path) -> Tensor:
78
+ raw_params = torch.from_numpy(np.load(preset_folder / PARAMS_PATH))
79
+ feature_mask = torch.from_numpy(np.load(preset_folder / MASK_PATH))
80
+ train_index_path = preset_folder / TRAIN_INDEX_PATH
81
+ if train_index_path.exists():
82
+ train_index = torch.from_numpy(np.load(train_index_path))
83
+ raw_params = raw_params[train_index]
84
+ presets = raw_params[:, feature_mask].contiguous()
85
+ return presets
86
+
87
+
88
+ def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor]:
89
+ gauss_params = np.load(f)
90
+ mean = torch.from_numpy(gauss_params["mean"]).float()
91
+ cov = torch.from_numpy(gauss_params["cov"]).float()
92
+ return mean, cov
93
+
94
+
95
+ preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
96
+ gaussian_params_dict = {
97
+ k: load_gaussian_params(v / PCA_PARAM_FILE) for k, v in PRESET_PATH.items()
98
+ }
99
+
100
+ # Global latent variable
101
+ # z = torch.zeros_like(mean)
102
+
103
+ with open(PRESET_PATH["internal"] / INFO_PATH) as f:
104
+ info = json.load(f)
105
+
106
+ param_keys = info["params_keys"]
107
+ original_shapes = list(
108
+ map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
109
+ )
110
+
111
+ *vec2dict_args, _ = get_chunks(param_keys, original_shapes)
112
+ vec2dict_args = [param_keys, original_shapes] + vec2dict_args
113
+ vec2dict = partial(
114
+ vec2statedict,
115
+ **dict(
116
+ zip(
117
+ [
118
+ "keys",
119
+ "original_shapes",
120
+ "selected_chunks",
121
+ "position",
122
+ "U_matrix_shape",
123
+ ],
124
+ vec2dict_args,
125
+ )
126
+ ),
127
+ )
128
+ internal_mean = gaussian_params_dict["internal"][0]
129
+
130
+ # Global effect
131
+ global_fx = instantiate(fx_config)
132
+ # global_fx.eval()
133
+ global_fx.load_state_dict(vec2dict(internal_mean), strict=False)
134
+
135
+
136
+ meter = pyln.Meter(44100)
137
+
138
+
139
+ @torch.no_grad()
140
+ def inference(audio, ratio, fx):
141
+ sr, y = audio
142
+ if sr != 44100:
143
+ y = resample(y, sr, 44100)
144
+ if y.dtype.kind != "f":
145
+ y = y / 32768.0
146
+
147
+ if y.ndim == 1:
148
+ y = y[:, None]
149
+ loudness = meter.integrated_loudness(y)
150
+ y = pyln.normalize.loudness(y, loudness, -18.0)
151
+
152
+ y = torch.from_numpy(y).float().T.unsqueeze(0)
153
+ if y.shape[1] != 1:
154
+ y = y.mean(dim=1, keepdim=True)
155
+
156
+ direct, wet = fx(y)
157
+ direct = direct.squeeze(0).T.numpy()
158
+ wet = wet.squeeze(0).T.numpy()
159
+ angle = ratio * math.pi * 0.5
160
+ test_clipping = direct + wet
161
+ # rendered = fx(y).squeeze(0).T.numpy()
162
+ if np.max(np.abs(test_clipping)) > 1:
163
+ scaler = np.max(np.abs(test_clipping))
164
+ # rendered = rendered / scaler
165
+ direct = direct / scaler
166
+ wet = wet / scaler
167
+
168
+ rendered = math.sqrt(2) * (math.cos(angle) * direct + math.sin(angle) * wet)
169
+ return (
170
+ (44100, (rendered * 32768).astype(np.int16)),
171
+ (44100, (direct * 32768).astype(np.int16)),
172
+ (
173
+ 44100,
174
+ (wet * 32768).astype(np.int16),
175
+ ),
176
+ )
177
+
178
+
179
+ def model2json(fx):
180
+ fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
181
+ results = {k: v.toJSON() for k, v in zip(fx_names, fx)} | {
182
+ "Panner": fx[7].pan.toJSON()
183
+ }
184
+ spatial_fx = {
185
+ "DLY": fx[7].effects[0].toJSON() | {"LP": fx[7].effects[0].eq.toJSON()},
186
+ "FDN": fx[7].effects[1].toJSON()
187
+ | {
188
+ "Tone correction PEQ": {
189
+ k: v.toJSON() for k, v in zip(fx_names[:4], fx[7].effects[1].eq)
190
+ }
191
+ },
192
+ "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
193
+ }
194
+ return {
195
+ "Direct": results,
196
+ "Sends": spatial_fx,
197
+ }
198
+
199
+
200
+ @torch.no_grad()
201
+ def plot_eq(fx):
202
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
203
+ w, eq_log_mags = get_log_mags_from_eq(fx[:6])
204
+ ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
205
+ for i, eq_log_mag in enumerate(eq_log_mags):
206
+ ax.plot(w, eq_log_mag, "k-", alpha=0.3)
207
+ ax.fill_between(w, eq_log_mag, 0, facecolor="gray", edgecolor="none", alpha=0.1)
208
+ ax.set_xlabel("Frequency (Hz)")
209
+ ax.set_ylabel("Magnitude (dB)")
210
+ ax.set_xlim(20, 20000)
211
+ ax.set_ylim(-40, 20)
212
+ ax.set_xscale("log")
213
+ ax.grid()
214
+ return fig
215
+
216
+
217
+ @torch.no_grad()
218
+ def plot_comp(fx):
219
+ fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
220
+ comp = fx[6]
221
+ cmp_th = comp.params.cmp_th.item()
222
+ exp_th = comp.params.exp_th.item()
223
+ cmp_ratio = comp.params.cmp_ratio.item()
224
+ exp_ratio = comp.params.exp_ratio.item()
225
+ make_up = comp.params.make_up.item()
226
+ # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
227
+
228
+ comp_in = np.linspace(-80, 0, 100)
229
+ comp_curve = np.where(
230
+ comp_in > cmp_th,
231
+ comp_in - (comp_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio,
232
+ comp_in,
233
+ )
234
+ comp_out = (
235
+ np.where(
236
+ comp_curve < exp_th,
237
+ comp_curve - (exp_th - comp_curve) / exp_ratio,
238
+ comp_curve,
239
+ )
240
+ + make_up
241
+ )
242
+ ax.plot(comp_in, comp_out, c="black", linestyle="-")
243
+ ax.plot(comp_in, comp_in, c="r", alpha=0.5)
244
+ ax.set_xlabel("Input Level (dB)")
245
+ ax.set_ylabel("Output Level (dB)")
246
+ ax.set_xlim(-80, 0)
247
+ ax.set_ylim(-80, 0)
248
+ ax.grid()
249
+ return fig
250
+
251
+
252
+ @torch.no_grad()
253
+ def plot_delay(fx):
254
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
255
+ delay = fx[7].effects[0]
256
+ w, eq_log_mags = get_log_mags_from_eq([delay.eq])
257
+ log_gain = delay.params.gain.log10().item() * 20
258
+ d = delay.params.delay.item() / 1000
259
+ log_mag = sum(eq_log_mags)
260
+ ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
261
+
262
+ log_feedback = delay.params.feedback.log10().item() * 20
263
+ for i in range(1, 10):
264
+ feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
265
+ ax.plot(
266
+ w,
267
+ feedback_log_mag,
268
+ c="black",
269
+ alpha=max(0, (10 - i * d * 4) / 10),
270
+ linestyle="-",
271
+ )
272
+
273
+ ax.set_xscale("log")
274
+ ax.set_xlim(20, 20000)
275
+ ax.set_ylim(-80, 0)
276
+ ax.set_xlabel("Frequency (Hz)")
277
+ ax.set_ylabel("Magnitude (dB)")
278
+ ax.grid()
279
+ return fig
280
+
281
+
282
+ @torch.no_grad()
283
+ def plot_reverb(fx):
284
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
285
+ fdn = fx[7].effects[1]
286
+ w, eq_log_mags = get_log_mags_from_eq(fdn.eq)
287
+
288
+ bc = fdn.params.c.norm() * fdn.params.b.norm()
289
+ log_bc = torch.log10(bc).item() * 20
290
+ # eq_log_mags = [x + log_bc / len(eq_log_mags) for x in eq_log_mags]
291
+ # ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
292
+ eq_log_mags = sum(eq_log_mags) + log_bc
293
+ ax.plot(w, eq_log_mags, color="black", linestyle="-")
294
+
295
+ ax.set_xlabel("Frequency (Hz)")
296
+ ax.set_ylabel("Magnitude (dB)")
297
+ ax.set_xlim(20, 20000)
298
+ ax.set_ylim(-40, 20)
299
+ ax.set_xscale("log")
300
+ ax.grid()
301
+ return fig
302
+
303
+
304
+ @torch.no_grad()
305
+ def plot_t60(fx):
306
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
307
+ fdn = fx[7].effects[1]
308
+ gamma = fdn.params.gamma.squeeze().numpy()
309
+ delays = fdn.delays.numpy()
310
+ w = np.linspace(0, 22050, gamma.size)
311
+ t60 = -60 / (20 * np.log10(gamma + 1e-10) / np.min(delays)) / 44100
312
+ ax.plot(w, t60, color="black", linestyle="-")
313
+ ax.set_xlabel("Frequency (Hz)")
314
+ ax.set_ylabel("T60 (s)")
315
+ ax.set_xlim(20, 20000)
316
+ ax.set_ylim(0, 9)
317
+ ax.set_xscale("log")
318
+ ax.grid()
319
+ return fig
320
+
321
+
322
+ def vec2fx(x):
323
+ fx = deepcopy(global_fx)
324
+ fx.load_state_dict(vec2dict(x), strict=False)
325
+ fx.apply(partial(clip_delay_eq_Q, Q=0.707))
326
+ return fx
327
+
328
+
329
+ with gr.Blocks() as demo:
330
+ fx_params = gr.State(internal_mean)
331
+ fx = vec2fx(fx_params.value)
332
+ # sr, y = read(EXAMPLE_PATH)
333
+
334
+ default_pc_slider = partial(
335
+ gr.Slider, minimum=SLIDER_MIN, maximum=SLIDER_MAX, interactive=True, value=0
336
+ )
337
+ default_audio_block = partial(gr.Audio, type="numpy", loop=True)
338
+ default_freq_slider = partial(gr.Slider, label="Frequency (Hz)", interactive=True)
339
+ default_gain_slider = partial(gr.Slider, label="Gain (dB)", interactive=True)
340
+ default_q_slider = partial(gr.Slider, label="Q", interactive=True)
341
+
342
+ gr.Markdown(
343
+ title_md,
344
+ elem_id="title",
345
+ )
346
+ with gr.Row():
347
+ gr.Markdown(
348
+ description_md,
349
+ elem_id="description",
350
+ )
351
+ # gr.Image("diffvox_diagram.png", elem_id="diagram")
352
+
353
+ with gr.Row():
354
+ with gr.Column():
355
+ audio_input = default_audio_block(
356
+ sources="upload",
357
+ label="Input Audio",
358
+ # value=(sr, y)
359
+ )
360
+ with gr.Row():
361
+ reset_button = gr.Button(
362
+ "Reset",
363
+ elem_id="reset-button",
364
+ )
365
+ render_button = gr.Button(
366
+ "Run", elem_id="render-button", variant="primary"
367
+ )
368
+
369
+ with gr.Column():
370
+ audio_output = default_audio_block(label="Output Audio", interactive=False)
371
+ dry_wet_ratio = gr.Slider(
372
+ minimum=0,
373
+ maximum=1,
374
+ value=0.5,
375
+ label="Dry/Wet Ratio",
376
+ interactive=True,
377
+ )
378
+ direct_output = default_audio_block(label="Direct Audio", interactive=False)
379
+ wet_output = default_audio_block(label="Wet Audio", interactive=False)
380
+
381
+ _ = gr.Markdown("## Common Parameters")
382
+ with gr.Row():
383
+ method_dropdown = gr.Dropdown(
384
+ ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
385
+ value="ST-ITO",
386
+ label=f"Style Transfer Method",
387
+ interactive=True,
388
+ )
389
+ dataset_dropdown = gr.Dropdown(
390
+ ["Internal", "MedleyDB"],
391
+ label="Prior Distribution",
392
+ info="When using the Regression method, this parameter has no effect as the model is trained on the internal dataset.",
393
+ value="Internal",
394
+ interactive=True,
395
+ )
396
+ embedding_dropdown = gr.Dropdown(
397
+ ["AFx-Rep", "MFCC", "MIR Features"],
398
+ label="Embedding Model",
399
+ info="This parameter is used in the Nearest Neighbour and ST-ITO methods.",
400
+ value="AFx-Rep",
401
+ interactive=True,
402
+ )
403
+
404
+ _ = gr.Markdown("## Parameters for ST-ITO Method")
405
+ with gr.Row():
406
+ optimisation_steps = gr.Slider(
407
+ minimum=1,
408
+ maximum=10000,
409
+ value=1000,
410
+ label="Number of Optimisation Steps",
411
+ interactive=True,
412
+ )
413
+ prior_weight = gr.Slider(
414
+ minimum=0.0,
415
+ maximum=1.0,
416
+ value=0.1,
417
+ label="Prior Weight",
418
+ interactive=True,
419
+ )
420
+ optimiser_dropdown = gr.Dropdown(
421
+ [
422
+ "Adadelta",
423
+ "Adafactor",
424
+ "Adagrad",
425
+ "Adam",
426
+ "AdamW",
427
+ "Adamax",
428
+ "RMSprop",
429
+ "ASGD",
430
+ "NAdam",
431
+ "RAdam",
432
+ "SGD",
433
+ ],
434
+ value="Adam",
435
+ label="Optimiser",
436
+ interactive=True,
437
+ )
438
+
439
+ demo.launch()