Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import yaml | |
| import json | |
| import pyloudnorm as pyln | |
| from hydra.utils import instantiate | |
| from soxr import resample | |
| from functools import partial | |
| from modules.utils import chain_functions, vec2statedict, get_chunks | |
| from modules.fx import clip_delay_eq_Q | |
| from plot_utils import get_log_mags_from_eq | |
# Markdown shown at the top of the Gradio page.
title_md = "# Vocal Effects Generator"
description_md = """
This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025.
In this demo, you can upload a raw vocal audio file (in mono) and apply random effects to make it sound better!
The effects consist of a series of EQ, compressor, delay, and reverb.
The generator is a PCA model derived from 365 vocal effects presets fitted with the same effects chain.
This interface allows you to control the principal components (PCs) of the generator, randomise them, and render the audio.
To give you an idea, we empirically found that the first PC controls the amount of reverb and the second PC controls the amount of brightness.
Note that adding these PCs together does not necessarily mean that their effects are additive in the final audio.
We found that sometimes the effects of the less important PCs are more perceptible.
Try to play around with the sliders and buttons and see what you can come up with!
Currently only PCs are tweakable, but in the future we will add more controls and visualisation tools.
For example:
- Directly controlling the parameters of the effects
- Visualising the PCA space
- Visualising the frequency responses/dynamic curves of the effects
"""

# Range of every PC slider. Latent coordinates are multiplied by
# eigvecs * sqrt(eigvals), so one unit is one standard deviation along that PC;
# randomisation also clips its N(0, 1) draws to this range.
SLIDER_MAX = 3
SLIDER_MIN = -3
# Number of leading PCs that get their own slider; the rest share one slider.
NUMBER_OF_PCS = 10
# NOTE(review): not referenced by the code visible in this file.
TEMPERATURE = 0.7
# Hydra config whose "model" entry describes the effect chain.
CONFIG_PATH = "presets/rt_config.yaml"
# Gaussian fitted to the presets: arrays "mean" and "cov".
PCA_PARAM_FILE = "presets/internal/gaussian.npz"
# Parameter metadata: "params_keys" and "params_original_shapes".
INFO_PATH = "presets/internal/info.json"
# Mask selecting the PCA-modelled entries of the flattened parameter vector.
MASK_PATH = "presets/internal/feature_mask.npy"
# Build the global effect chain from the "model" entry of the Hydra config.
with open(CONFIG_PATH) as fp:
    fx_config = yaml.safe_load(fp)["model"]
# Global effect
fx = instantiate(fx_config)
fx.eval()

# Gaussian fitted to the preset parameter vectors.
pca_params = np.load(PCA_PARAM_FILE)
mean = pca_params["mean"]
cov = pca_params["cov"]
# np.linalg.eigh returns eigenvalues in ascending order; flip so index 0 is the
# largest-variance direction ("PC 1"), then keep the top 75 components.
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.flip(eigvals, axis=0)[:75]
eigvecs = np.flip(eigvecs, axis=1)[:, :75]
# Scale each eigenvector (column) by its standard deviation, so that
# parameters are decoded as x = U @ z + mean with z ~ N(0, I) for the Gaussian.
# NOTE(review): assumes the top 75 eigenvalues are positive; a (numerically)
# negative one would make sqrt produce NaN.
U = eigvecs * np.sqrt(eigvals)
U = torch.from_numpy(U).float()
mean = torch.from_numpy(mean).float()
# Selects the PCA-modelled entries of the flattened parameter vector;
# presumably a boolean mask — TODO confirm dtype.
feature_mask = torch.from_numpy(np.load(MASK_PATH))
# Global latent variable
# One coordinate per kept PC; mutated in place by the UI handlers.
z = torch.zeros(75)
# Parameter metadata for the effect chain: ordered state-dict keys and the
# original tensor shape of each entry.
with open(INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
# Empty shapes (presumably 0-d / scalar parameters) are replaced by [1].
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
# The last value returned by get_chunks is discarded. The remaining values,
# prefixed by the keys and shapes, are bound positionally to the keyword names
# below (zip stops at the shorter of the two sequences).
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
# vec2dict(x): turn a flat PCA-space parameter vector into a state dict that
# fx.load_state_dict(..., strict=False) accepts (used here and in z2fx).
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)
# Start from the mean preset (equivalent to z == 0).
fx.load_state_dict(vec2dict(mean), strict=False)
# Loudness meter used by inference(); its input is resampled to 44.1 kHz first.
meter = pyln.Meter(44100)
def z2fx():
    """Load the parameters encoded by the global latent ``z`` into ``fx``.

    Decodes ``mean + U @ z`` into a state dict with ``vec2dict`` and loads it
    non-strictly, so keys missing from that dict keep their current values.
    Every open pyplot figure is closed first, because the UI handlers call
    ``plot_eq()`` right after this.
    """
    # Close all figures to avoid too many open figures.
    plt.close("all")
    decoded_params = mean + U @ z
    fx.load_state_dict(vec2dict(decoded_params), strict=False)
def fx2z(func):
    """Decorate ``func`` so the global latent ``z`` is re-synced from ``fx``.

    After ``func`` returns, the current parameters are read back from
    ``fx.state_dict()``. They are flattened in ``param_keys`` order and
    restricted to the PCA features with ``feature_mask``. Their least-squares
    coordinates in the PCA basis are then written into ``z`` in place.

    Args:
        func: any callable; its arguments and return value pass through.

    Returns:
        The wrapped callable, carrying ``func``'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        ret = func(*args, **kwargs)
        state_dict = fx.state_dict()
        flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
        x = flattened[feature_mask]
        # Invert x = U @ z + mean. U = eigvecs * sqrt(eigvals) has orthogonal
        # columns, so U.T @ U = diag(eigvals), and the least-squares solution is
        # (U.T @ (x - mean)) / eigvals. Using U.T alone scales each coordinate
        # by its eigenvalue instead of recovering z.
        eigvals_t = U.square().sum(dim=0)
        z.copy_((U.T @ (x - mean)) / eigvals_t)
        return ret

    return wrapper
def inference(audio):
    """Render an uploaded recording through the current global effect chain ``fx``.

    Processing steps:

    1. Convert the input to float.
    2. Resample to 44.1 kHz.
    3. Loudness-normalise to -18 LUFS.
    4. Downmix to mono and process with ``fx``.
    5. Peak-normalise the result, but only if it exceeds full scale.
    6. Return 16-bit PCM.

    Args:
        audio: ``(sample_rate, samples)`` from ``gr.Audio(type="numpy")``, where
            ``samples`` is 1-D or ``(frames, channels)``. It is ``None`` when
            nothing has been uploaded.

    Returns:
        ``(44100, pcm)``, where ``pcm`` is an int16 array laid out as
        ``(frames, channels)``. The channel count comes from ``fx``'s output
        (presumably 1 or 2 — depends on the chain).

    Raises:
        gr.Error: if ``audio`` is ``None``.
    """
    if audio is None:
        # Clicking "Run" before uploading passes None; show a readable message
        # instead of failing on tuple unpacking.
        raise gr.Error("Please upload an audio file first.")
    sr, y = audio

    # Integer PCM -> float in [-1, 1). The offset/scale come from the dtype's
    # full range, so int16 (/32768), int32 (/2**31) and unsigned 8-bit
    # (centred on 128) are all mapped correctly, not just int16.
    if y.dtype.kind in "iu":
        limits = np.iinfo(y.dtype)
        offset = (int(limits.max) + int(limits.min) + 1) / 2
        scale = (int(limits.max) - int(limits.min) + 1) / 2
        y = (y.astype(np.float64) - offset) / scale
    else:
        y = y.astype(np.float64)

    # Resample after the float conversion so no integer rounding is involved.
    if sr != 44100:
        y = resample(y, sr, 44100)
    if y.ndim == 1:
        y = y[:, None]

    # pyloudnorm raises ValueError for clips shorter than one gating block and
    # returns -inf for digital silence. Normalising with an infinite gain would
    # fill the signal with NaNs, so such input keeps its original level.
    try:
        loudness = meter.integrated_loudness(y)
    except ValueError:
        loudness = float("-inf")
    if np.isfinite(loudness):
        y = pyln.normalize.loudness(y, loudness, -18.0)

    # (frames, channels) -> (1, channels, frames), then downmix to mono.
    x = torch.from_numpy(y).float().T.unsqueeze(0)
    if x.shape[1] != 1:
        x = x.mean(dim=1, keepdim=True)

    fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    # Inference only: skip building an autograd graph (and .numpy() would
    # refuse a tensor that requires grad).
    with torch.no_grad():
        rendered = fx(x).squeeze(0).T.numpy()

    # initial=0.0 keeps this well-defined for an empty render.
    peak = np.max(np.abs(rendered), initial=0.0)
    if peak > 1:
        rendered = rendered / peak
    # A full-scale +1.0 sample times 32768 is one past the int16 maximum and
    # would wrap to -32768, so clip to the int16 range before casting.
    pcm = np.clip(rendered * 32768, -32768, 32767).astype(np.int16)
    return (44100, pcm)
def get_important_pcs(n=10, **kwargs):
    """Create one slider per leading principal component.

    Args:
        n: number of sliders, labelled "PC 1" through "PC n".
        **kwargs: forwarded unchanged to every ``gr.Slider`` (e.g. ``value``).

    Returns:
        A list of ``n`` sliders spanning ``[SLIDER_MIN, SLIDER_MAX]``.
    """
    pc_sliders = []
    for pc_number in range(1, n + 1):
        pc_sliders.append(
            gr.Slider(
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                label=f"PC {pc_number}",
                **kwargs,
            )
        )
    return pc_sliders
def model2json():
    """Serialise the current settings of the effect chain ``fx`` to JSON.

    The output has two sections:

    - ``"Direct"``: ``fx[0]`` to ``fx[6]`` under the names PK1, PK2, LS, HS,
      LP, HP and DRC, plus the panner of ``fx[7]``.
    - ``"Sends"``: the two effects in ``fx[7].effects``. The delay is
      serialised with its low-pass EQ, and the FDN with its four tone-correction
      PEQ bands. The cross-send gain ``sends_0`` is added, converted to dB.

    Returns:
        The JSON string.
    """
    stage_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    send_bus = fx[7]

    direct = {}
    for name, stage in zip(stage_names, fx):
        direct[name] = stage.toJSON()
    direct["Panner"] = send_bus.pan.toJSON()

    delay = send_bus.effects[0]
    delay_json = delay.toJSON() | {"LP": delay.eq.toJSON()}

    fdn = send_bus.effects[1]
    tone_peq = {}
    for name, band in zip(stage_names[:4], fdn.eq):
        tone_peq[name] = band.toJSON()
    fdn_json = fdn.toJSON() | {"Tone correction PEQ": tone_peq}

    cross_send_db = send_bus.params.sends_0.log10().mul(20).item()

    sends = {
        "DLY": delay_json,
        "FDN": fdn_json,
        "Cross Send (dB)": cross_send_db,
    }
    return json.dumps({"Direct": direct, "Sends": sends})
def plot_eq():
    """Plot the magnitude responses of the first six stages of ``fx``.

    The per-stage responses come from ``get_log_mags_from_eq(fx[:6])``. Each
    one is drawn as a faint line with a light fill down to 0 dB. Their sum,
    the overall response in dB, is drawn as a solid black line.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for ``gr.Plot``.
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates
    # alive until plt.close(), and the "Run" handler calls plot_eq() without
    # ever closing them, so figures and memory piled up across clicks. A bare
    # Figure is freed once nothing references it.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    w, eq_log_mags = get_log_mags_from_eq(fx[:6])
    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    for eq_log_mag in eq_log_mags:
        ax.plot(w, eq_log_mag, "k-", alpha=0.3)
        ax.fill_between(
            w, eq_log_mag, 0, facecolor="gray", edgecolor="none", alpha=0.1
        )
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig
# Gradio UI.
# Layout: title; description beside the diagram; an input column (audio,
# buttons, PC sliders) beside an output column (rendered audio, EQ plot); and
# finally the JSON dump of the effect settings. All handlers share the
# module-level latent `z` and effect chain `fx`.
# NOTE(review): `chain_functions` is assumed to apply its functions left to
# right, feeding each result into the next, as the argument plumbing below
# implies — confirm in modules.utils.
with gr.Blocks() as demo:
    gr.Markdown(
        title_md,
        elem_id="title",
    )
    with gr.Row():
        gr.Markdown(
            description_md,
            elem_id="description",
        )
        gr.Image("diffvox_diagram.png", elem_id="diagram")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                type="numpy", sources="upload", label="Input Audio", loop=True
            )
            with gr.Row():
                random_button = gr.Button(
                    f"Randomise PCs",
                    elem_id="randomise-button",
                )
                reset_button = gr.Button(
                    "Reset",
                    elem_id="reset-button",
                )
                render_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
            # random_rest_checkbox = gr.Checkbox(
            #     label=f"Randomise PCs > {NUMBER_OF_PCS} (default to zeros)",
            #     value=False,
            #     elem_id="randomise-checkbox",
            # )
            # One slider per leading PC; slider i (0-based) edits z[i].
            sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
            # PCs NUMBER_OF_PCS + 1 .. 75 (1-based) share one slider; the
            # dropdown selects which of them it edits.
            extra_pc_dropdown = gr.Dropdown(
                list(range(NUMBER_OF_PCS + 1, 76)),
                label=f"PC > {NUMBER_OF_PCS}",
                info="Select which extra PC to adjust",
                interactive=True,
            )
            extra_slider = gr.Slider(
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                label="Extra PC",
                value=0,
            )
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy", label="Output Audio", interactive=False, loop=True
            )
            # Initial plot shows the mean preset loaded at start-up.
            peq_plot = gr.Plot(
                plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
            )
    with gr.Row():
        json_output = gr.JSON(label="Effect Settings", max_height=800, open=True)

    # "Run": render the uploaded audio with the current fx, then refresh the
    # JSON settings and the EQ plot. It does not call z2fx(); fx already holds
    # the parameters set by the other handlers.
    render_button.click(
        lambda *args: (lambda x: (x, model2json(), plot_eq()))(inference(*args)),
        inputs=[
            audio_input,
            # random_rest_checkbox,
        ]
        # + sliders,
        ,
        outputs=[audio_output, json_output, peq_plot],
    )
    # "Randomise PCs": redraw all 75 entries of z in place from N(0, 1),
    # clipped to the slider range, and push them into fx. Then return the
    # first NUMBER_OF_PCS values for the sliders, the value of the PC selected
    # in the dropdown (1-based, hence `- 1`) for the extra slider, and a new
    # plot.
    random_button.click(
        # lambda *xs: [
        #     chain_functions(
        #         partial(max, SLIDER_MIN),
        #         partial(min, SLIDER_MAX),
        #     )(normalvariate(0, 1))
        #     for _ in range(len(xs))
        # ],
        # lambda i: (lambda x: x[:NUMBER_OF_PCS].tolist() + [x[i - 1].item()])(
        #     z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX)
        # ),
        chain_functions(
            lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
            lambda args: args + (z2fx(),),
            lambda args: args[0][:NUMBER_OF_PCS].tolist()
            + [args[0][args[1] - 1].item(), plot_eq()],
        ),
        inputs=extra_pc_dropdown,
        outputs=sliders + [extra_slider, peq_plot],
    )
    # "Reset": zero z in place, reload fx from it (the mean preset), and zero
    # all NUMBER_OF_PCS + 1 sliders before redrawing the plot.
    reset_button.click(
        # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
        lambda: chain_functions(
            lambda _: z.zero_(),
            lambda _: z2fx(),
            lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)] + [plot_eq()],
        )(None),
        # inputs=sliders + [extra_slider],
        outputs=sliders + [extra_slider, peq_plot],
    )

    # Write slider value s into latent coordinate z[i] (0-based).
    def update_z(s, i):
        z[i] = s
        return

    # Each PC slider updates its own coordinate, then fx and the plot.
    # partial(..., i=i) binds the loop index now, which avoids the
    # late-binding-closure pitfall.
    for i, slider in enumerate(sliders):
        slider.input(
            chain_functions(
                partial(update_z, i=i),
                lambda _: z2fx(),
                lambda _: plot_eq(),
            ),
            inputs=slider,
            outputs=peq_plot,
        )
    # The extra slider edits the PC chosen in the dropdown (1-based, hence
    # `- 1`).
    # NOTE(review): if the dropdown had no value, `None - 1` would raise;
    # presumably Gradio preselects the first choice — confirm for the
    # installed version.
    extra_slider.input(
        lambda *xs: chain_functions(
            lambda args: update_z(args[0], args[1] - 1),
            lambda _: z2fx(),
            lambda _: plot_eq(),
        )(xs),
        inputs=[extra_slider, extra_pc_dropdown],
        outputs=peq_plot,
    )
    # Selecting another extra PC moves the extra slider to that PC's current
    # value.
    extra_pc_dropdown.input(
        lambda i: z[i - 1].item(),
        inputs=extra_pc_dropdown,
        outputs=extra_slider,
    )
# Start the Gradio server.
demo.launch()