"""Gradio demo: vocal effects style transfer with a Gaussian prior (WASPAA 2025).

Downloads the AFx-Rep embedding model and the DiffVox preset dataset at import
time, builds a global effect chain from the dataset mean, and wires up a Gradio
UI that estimates effect parameters from a reference mix via one of four
methods (Mean / Nearest Neighbour / ST-ITO / Regression) and renders them onto
a dry vocal.
"""

import spaces
import gradio as gr
import numpy as np
from scipy.io.wavfile import read
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import math
import yaml
import os
from huggingface_hub import hf_hub_download, snapshot_download
import json
import pyloudnorm as pyln
from hydra.utils import instantiate
from soxr import resample
from functools import partial, reduce
from itertools import accumulate
from torchcomp import coef2ms, ms2coef
from copy import deepcopy
from pathlib import Path
from typing import Tuple, List, Optional, Union, Callable
from os import environ

# Download AFx-Rep model and config files
os.makedirs("tmp", exist_ok=True)
ckpt_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="afx-rep.ckpt",
    local_dir="tmp",
    local_files_only=False,
)
config_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="config.yaml",
    local_dir="tmp",
    local_files_only=False,
)
# Fetch the DiffVox presets AND the `modules/` Python package into the CWD.
preset_path = snapshot_download(
    "yoyolicoris/diffvox",
    repo_type="dataset",
    local_dir="./",
    local_files_only=False,
    allow_patterns=["presets/*", "modules/*"],
)

# NOTE: these imports must stay BELOW snapshot_download() — the `modules`
# package is part of the snapshot just fetched and does not exist before it.
from modules.utils import vec2statedict, get_chunks
from modules.fx import clip_delay_eq_Q, hadamard
from utils import (
    get_log_mags_from_eq,
    chain_functions,
    remove_window_fn,
    jsonparse2hydra,
)
from ito import find_closest_training_sample, one_evaluation
from st_ito.utils import (
    load_param_model,
    get_param_embeds,
    get_feature_embeds,
    load_mfcc_feature_extractor,
    load_mir_feature_extractor,
)

# --- UI copy (rendered as Markdown in the Blocks layout below) ---
title_md = "# Vocal Effects Style Transfer Demo"
description_md = """
This is a demo of the paper [Improving Inference-Time Optimisation for Vocal Effects Style Transfer with a Gaussian Prior](https://iamycy.github.io/diffvox-ito-demo/), published at WASPAA 2025.
In this demo, you can upload a raw vocal audio file (in mono) and a reference vocal mix (in stereo).
This system will apply vocal effects to the raw vocal such that the processed audio matches the style of the reference mix.
The effects is the same as in [DiffVox](https://huggingface.co/spaces/yoyolicoris/diffvox/), which consists of series of EQ, compressor, delay, and reverb.
We offer four different methods for style transfer:

1. **Mean**: Use the mean of the chosen vocal preset dataset as the effect parameters.
2. **Nearest Neighbour**: Find the closest vocal preset in the chosen dataset based on chosen embedding features, and use its effect parameters.
3. **ST-ITO**: Our proposed method that perform inference-time optimisation in the chosen embedding space, regularised by a Gaussian prior learned from the chosen vocal preset dataset.
4. **Regression**: A pre-trained regression model that directly predicts effect parameters from the reference mix.

### Datasets

- **Internal**: Derived from a proprietary dataset in Sony.
- **MedleyDB**: Derived from solo vocal tracks from the MedleyDB v1/v2 dataset.

### Embedding Models

- **AFx-Rep**: A self-supervised audio representation [model](https://huggingface.co/csteinmetz1/afx-rep) trained on random audio effects.
- **MFCC**: Mel-frequency cepstral coefficients.
- **MIR Features**: A set of common features in MIR like RMS, spectral centroid, spectral flatness, etc.

> **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
"""

# device = "cpu"

# UI/optimisation constants. SLIDER_* and NUMBER_OF_PCS/TEMPERATURE appear
# unused in this file — presumably leftovers from an earlier PCA-slider UI
# (cf. the DiffVox space); kept for reference.
SLIDER_MAX = 3
SLIDER_MIN = -3
NUMBER_OF_PCS = 4
TEMPERATURE = 0.7
CONFIG_PATH = {
    "realtime": "presets/rt_config.yaml",
    "approx": "fx_config.yaml",
}
PRESET_PATH = {
    "internal": Path("presets/internal/"),
    "medleydb": Path("presets/medleydb/"),
}
CKPT_PATH = Path("reg-ckpts/")
PCA_PARAM_FILE = "gaussian.npz"
INFO_PATH = "info.json"
MASK_PATH = "feature_mask.npy"
PARAMS_PATH = "raw_params.npy"
TRAIN_INDEX_PATH = "train_index.npy"
EXAMPLE_PATH = "eleanor_erased.wav"

# Two effect-chain configs: "approx" uses differentiable approximations of
# delay/reverb (used for optimisation), "realtime" uses the exact effects.
with open(CONFIG_PATH["approx"]) as fp:
    fx_config = yaml.safe_load(fp)["model"]
with open(CONFIG_PATH["realtime"]) as fp:
    rt_config = yaml.safe_load(fp)["model"]


def load_presets(preset_folder: Path) -> Tensor:
    """Load the preset parameter matrix for one dataset.

    Returns a (num_presets, num_features) tensor of raw parameters, restricted
    to the columns selected by the dataset's feature mask. If a train-index
    file exists, only the training-split presets are kept.
    """
    raw_params = torch.from_numpy(np.load(preset_folder / PARAMS_PATH))
    feature_mask = torch.from_numpy(np.load(preset_folder / MASK_PATH))
    train_index_path = preset_folder / TRAIN_INDEX_PATH
    if train_index_path.exists():
        train_index = torch.from_numpy(np.load(train_index_path))
        raw_params = raw_params[train_index]
    presets = raw_params[:, feature_mask].contiguous()
    return presets


def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor, Tensor]:
    """Load the Gaussian prior (mean, covariance, log-determinant) from an .npz.

    The logdet is precomputed once here so `logp_x` doesn't recompute it per
    evaluation.
    """
    gauss_params = np.load(f)
    mean = torch.from_numpy(gauss_params["mean"]).float()
    cov = torch.from_numpy(gauss_params["cov"]).float()
    return mean, cov, cov.logdet()


def logp_x(mu, cov, cov_logdet, x):
    """Gaussian log-density of parameter vector `x` under N(mu, cov).

    Used (via functools.partial over the first three args) as the prior term
    in the ST-ITO objective. The assert guards against a non-PSD covariance
    making the Mahalanobis norm negative.
    """
    diff = x - mu
    # Solve cov @ b = diff instead of inverting cov (cheaper, more stable).
    b = torch.linalg.solve(cov, diff)
    norm = diff @ b
    assert torch.all(norm >= 0), "Negative norm detected, check covariance matrix."
    return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))


# Per-dataset preset banks and Gaussian priors, loaded once at import.
preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
gaussian_params_dict = {
    k: load_gaussian_params(v / PCA_PARAM_FILE) for k, v in PRESET_PATH.items()
}

# Global latent variable
# z = torch.zeros_like(mean)

# The info.json describes how the flat parameter vector maps back onto the
# effect chain's state dict (keys + original tensor shapes).
with open(PRESET_PATH["internal"] / INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
# Scalars are recorded with an empty shape list; promote them to [1].
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
# get_chunks returns the chunking layout; the last element is discarded here.
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
# vec2dict: flat parameter vector -> effect-chain state dict.
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)

internal_mean = gaussian_params_dict["internal"][0]

# Global effect
# Shared template effect chain, initialised at the internal-dataset mean.
# Deep-copied (never mutated) by inference()/vec2fx().
global_fx = instantiate(fx_config)
# global_fx.eval()
global_fx.load_state_dict(vec2dict(internal_mean), strict=False)

# Scalar entries come out of vec2dict as 1-element tensors; squeeze them back
# to 0-dim so load_state_dict shapes match.
ndim_dict = {k: v.ndim for k, v in global_fx.state_dict().items()}
to_fx_state_dict = lambda x: {
    k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
}

# Loudness meter for -18 LUFS normalisation; everything runs at 44.1 kHz.
meter = pyln.Meter(44100)


def get_embedding_model(embedding: str) -> Callable:
    """Build the embedding function used for similarity / ITO loss.

    Returns a callable mapping a two-channel audio tensor to an embedding.
    Models are (re)loaded on each call — TODO confirm this is intentional
    (they are not cached).
    """
    device = environ.get("DEVICE", "cpu")
    match embedding:
        case "afx-rep":
            afx_rep = load_param_model().to(device)
            two_chs_emb_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
        case "mfcc":
            mfcc = load_mfcc_feature_extractor().to(device)
            two_chs_emb_fn = lambda x: get_feature_embeds(x, mfcc)
        case "mir":
            mir = load_mir_feature_extractor().to(device)
            two_chs_emb_fn = lambda x: get_feature_embeds(x, mir)
        case _:
            raise ValueError(f"Unknown encoder: {embedding}")
    return two_chs_emb_fn


def get_regressor() -> Callable:
    """Load the pre-trained parameter-regression model.

    Picks the checkpoint with the lowest validation loss (parsed from the
    filename after the last '='), then returns a callable that maps a wet
    reference tensor to de-standardised effect parameters.
    """
    with open(CKPT_PATH / "config.yaml") as f:
        config = yaml.safe_load(f)
    model_config = config["model"]
    checkpoints = (CKPT_PATH / "checkpoints").glob("*val_loss*.ckpt")
    # Filenames look like "...val_loss=<value>.ckpt"; lower is better.
    lowest_checkpoint = min(checkpoints, key=lambda x: float(x.stem.split("=")[-1]))
    last_ckpt = torch.load(lowest_checkpoint, map_location="cpu")
    model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
        model_config
    )
    model.load_state_dict(last_ckpt["state_dict"])
    device = environ.get("DEVICE", "cpu")
    model = model.to(device)
    model.eval()
    # Undo the training-time standardisation of the parameter targets.
    param_stats = torch.load(CKPT_PATH / "param_stats.pt")
    param_mu, param_std = (
        param_stats["mu"].float().to(device),
        param_stats["std"].float().to(device),
    )
    regressor = lambda wet: model(wet, dry=None) * param_std + param_mu
    return regressor


def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
    """Normalise a Gradio (sr, samples) pair to float samples at 44.1 kHz.

    Output is always 2-D (samples, channels); mono input gains a channel
    axis. Integer input is assumed to be int16 PCM (hence the /32768) —
    NOTE(review): other integer widths would be scaled wrongly; Gradio's
    numpy audio is int16 by default, so this holds in practice.
    """
    if sr != 44100:
        x = resample(x, sr, 44100)
    if x.dtype.kind != "f":
        x = x / 32768.0
    if x.ndim == 1:
        x = x[:, None]
    return x


# @spaces.GPU(duration=60)
def inference(
    input_audio,
    ref_audio,
    method,
    dataset,
    embedding,
    mid_side,
    steps,
    prior_weight,
    optimiser,
    lr,
    progress=gr.Progress(track_tqdm=True),
):
    """Estimate an effect-parameter vector for `input_audio` matching `ref_audio`.

    Dispatches on `method`:
    - "Mean": dataset prior mean (no audio needed beyond the selection).
    - "Regression": pre-trained regressor on the reference only.
    - "Nearest Neighbour": closest preset in embedding space.
    - "ST-ITO": gradient-based optimisation with the Gaussian prior.

    Both audios are loudness-normalised to -18 LUFS; the input is collapsed
    to mono. Returns a parameter tensor (on DEVICE; caller moves it to CPU).
    Note: ST-ITO always uses the "internal" prior regardless of `dataset`
    (matching the UI hint that the dataset choice has no effect for ST-ITO).
    """
    # close all figures to avoid too many open figures
    plt.close("all")
    device = environ.get("DEVICE", "cpu")
    if method == "Mean":
        return gaussian_params_dict[dataset][0].to(device)

    # (samples, ch) float -> (1, ch, samples) tensor expected downstream.
    ref = convert2float(*ref_audio)
    ref_loudness = meter.integrated_loudness(ref)
    ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
    ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)

    if method == "Regression":
        regressor = get_regressor()
        with torch.no_grad():
            # .mean(0) averages over the batch dim — presumably the model may
            # return per-chunk predictions; verify against the regressor.
            return regressor(ref).mean(0)

    y = convert2float(*input_audio)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    # The effect chain expects a mono source; downmix if needed.
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)

    # Work on a copy so the optimiser never mutates the shared template.
    fx = deepcopy(global_fx).to(device)
    fx.train()

    # Optionally convert L/R to mid-side (hadamard) before embedding.
    two_chs_emb_fn = chain_functions(
        hadamard if mid_side else lambda x: x,
        get_embedding_model(embedding),
    )

    match method:
        case "Nearest Neighbour":
            vec = find_closest_training_sample(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                preset_dict[dataset].to(device),
                ref,
                y,
                progress,
            )
        case "ST-ITO":
            vec = one_evaluation(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                # Prior log-density with (mu, cov, logdet) bound; always the
                # internal-dataset Gaussian.
                partial(
                    logp_x,
                    *[x.to(device) for x in gaussian_params_dict["internal"]]
                ),
                internal_mean.to(device),
                ref,
                y,
                optimiser_type=optimiser,
                lr=lr,
                steps=steps,
                weight=prior_weight,
                progress=progress,
            )
        case _:
            raise ValueError(f"Unknown method: {method}")
    return vec


# @spaces.GPU(duration=10)
def render(y, remove_approx, ratio, vec):
    """Render audio `y` through the effect chain defined by parameter vector `vec`.

    `remove_approx` selects the real-time (exact) delay/reverb config instead
    of the differentiable approximation. `ratio` in [0, 1] sets the dry/wet
    mix via an equal-power (sin/cos) crossfade. Returns three Gradio audio
    tuples: (mixed, direct, wet), each (44100, int16 samples).
    """
    device = environ.get("DEVICE", "cpu")
    y = convert2float(*y)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)

    if remove_approx:
        infer_fx = instantiate(rt_config).to(device)
    else:
        infer_fx = instantiate(fx_config).to(device)
    infer_fx.load_state_dict(vec2dict(vec), strict=False)
    # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    infer_fx.eval()
    with torch.no_grad():
        direct, wet = infer_fx(y)
    direct = direct.squeeze(0).T.cpu().numpy()
    wet = wet.squeeze(0).T.cpu().numpy()
    angle = ratio * math.pi * 0.5
    # Pre-scale both stems if their straight sum would clip, so the dry/wet
    # slider can later remix them client-side without re-normalising.
    # NOTE(review): the sqrt(2) equal-power gain below can still push the mix
    # slightly above 1 at intermediate ratios — presumably accepted.
    test_clipping = direct + wet
    # rendered = fx(y).squeeze(0).T.numpy()
    if np.max(np.abs(test_clipping)) > 1:
        scaler = np.max(np.abs(test_clipping))
        # rendered = rendered / scaler
        direct = direct / scaler
        wet = wet / scaler
    rendered = math.sqrt(2) * (math.cos(angle) * direct + math.sin(angle) * wet)
    return (
        (44100, (rendered * 32768).astype(np.int16)),
        (44100, (direct * 32768).astype(np.int16)),
        (
            44100,
            (wet * 32768).astype(np.int16),
        ),
    )


def model2json(fx):
    """Serialise the effect chain to a nested dict for the JSON viewer.

    Assumes the chain layout: fx[0..5] = PEQ bands, fx[6] = compressor,
    fx[7] = spatial block holding a panner and [delay, FDN reverb] sends —
    TODO confirm against the fx_config schema.
    """
    fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    # zip stops at the shorter sequence, so only fx[0..6] are named here.
    results = {k: v.toJSON() for k, v in zip(fx_names, fx)} | {
        "Panner": fx[7].pan.toJSON()
    }
    spatial_fx = {
        "DLY": fx[7].effects[0].toJSON() | {"LP": fx[7].effects[0].eq.toJSON()},
        "FDN": fx[7].effects[1].toJSON()
        | {
            "Tone correction PEQ": {
                k: v.toJSON() for k, v in zip(fx_names[:4], fx[7].effects[1].eq)
            }
        },
        # Linear send gain -> dB.
        "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
    }
    return {
        "Direct": results,
        "Sends": spatial_fx,
    }


@torch.no_grad()
def plot_eq(fx):
    """Plot the combined and per-band magnitude response of the six EQ bands."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    w, eq_log_mags = get_log_mags_from_eq(fx[:6])
    # Black solid line: sum of all bands; faint lines/fills: individual bands.
    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    for i, eq_log_mag in enumerate(eq_log_mags):
        ax.plot(w, eq_log_mag, "k-", alpha=0.3)
        ax.fill_between(w, eq_log_mag, 0, facecolor="gray", edgecolor="none", alpha=0.1)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig


@torch.no_grad()
def plot_comp(fx):
    """Plot the compressor/expander static transfer curve (in/out in dB)."""
    fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
    comp = fx[6]
    cmp_th = comp.params.cmp_th.item()
    exp_th = comp.params.exp_th.item()
    cmp_ratio = comp.params.cmp_ratio.item()
    exp_ratio = comp.params.exp_ratio.item()
    make_up = comp.params.make_up.item()
    # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
    comp_in = np.linspace(-80, 0, 100)
    # Above the compressor threshold: reduce gain by the compression ratio.
    comp_curve = np.where(
        comp_in > cmp_th,
        comp_in - (comp_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio,
        comp_in,
    )
    # Below the expander threshold: push the level further down, then add
    # make-up gain everywhere.
    comp_out = (
        np.where(
            comp_curve < exp_th,
            comp_curve - (exp_th - comp_curve) / exp_ratio,
            comp_curve,
        )
        + make_up
    )
    ax.plot(comp_in, comp_out, c="black", linestyle="-")
    # Red reference: unity (no compression) line.
    ax.plot(comp_in, comp_in, c="r", alpha=0.5)
    ax.set_xlabel("Input Level (dB)")
    ax.set_ylabel("Output Level (dB)")
    ax.set_xlim(-80, 0)
    ax.set_ylim(-80, 0)
    ax.grid()
    return fig


@torch.no_grad()
def plot_delay(fx):
    """Plot the delay send's response plus successive feedback repeats.

    Each repeat i passes the EQ (i+1 times) and the feedback gain (i times);
    its plot alpha fades with the delay time so long delays show fewer taps.
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    delay = fx[7].effects[0]
    w, eq_log_mags = get_log_mags_from_eq([delay.eq])
    log_gain = delay.params.gain.log10().item() * 20
    d = delay.params.delay.item() / 1000  # delay time in seconds
    log_mag = sum(eq_log_mags)
    ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
    log_feedback = delay.params.feedback.log10().item() * 20
    for i in range(1, 10):
        feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
        ax.plot(
            w,
            feedback_log_mag,
            c="black",
            alpha=max(0, (10 - i * d * 4) / 10),
            linestyle="-",
        )
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-80, 0)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig


@torch.no_grad()
def plot_reverb(fx):
    """Plot the FDN reverb's tone-correction EQ, offset by the overall
    input/output gain (||b|| * ||c||) in dB."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    w, eq_log_mags = get_log_mags_from_eq(fdn.eq)
    bc = fdn.params.c.norm() * fdn.params.b.norm()
    log_bc = torch.log10(bc).item() * 20
    # eq_log_mags = [x + log_bc / len(eq_log_mags) for x in eq_log_mags]
    # ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    eq_log_mags = sum(eq_log_mags) + log_bc
    ax.plot(w, eq_log_mags, color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig


@torch.no_grad()
def plot_t60(fx):
    """Plot the FDN reverb's frequency-dependent T60 decay time.

    T60 is derived from the per-band feedback magnitude gamma: the time for
    -60 dB decay given gamma dB of loss per pass of the shortest delay line
    (np.min(delays)) at 44.1 kHz — NOTE(review): using the minimum delay for
    all bands looks like a simplification; confirm against the FDN module.
    The 1e-10 floor avoids log10(0).
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    gamma = fdn.params.gamma.squeeze().numpy()
    delays = fdn.delays.numpy()
    w = np.linspace(0, 22050, gamma.size)
    t60 = -60 / (20 * np.log10(gamma + 1e-10) / np.min(delays)) / 44100
    ax.plot(w, t60, color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("T60 (s)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(0, 9)
    ax.set_xscale("log")
    ax.grid()
    return fig


def vec2fx(x):
    """Return a fresh copy of the template chain loaded with parameters `x`."""
    fx = deepcopy(global_fx)
    fx.load_state_dict(vec2dict(x), strict=False)
    # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    return fx


# ----------------------------- Gradio UI ------------------------------------
with gr.Blocks() as demo:
    # Last estimated parameter vector, shared between event handlers.
    fx_params = gr.State(internal_mean)
    # fx = vec2fx(fx_params.value)
    # sr, y = read(EXAMPLE_PATH)
    default_audio_block = partial(gr.Audio, type="numpy", loop=True)
    gr.Markdown(
        title_md,
        elem_id="title",
    )
    with gr.Row():
        gr.Markdown(
            description_md,
            elem_id="description",
        )
        gr.Image("overview.png", elem_id="diagram", height=500)
    with gr.Row():
        with gr.Column():
            audio_input = default_audio_block(
                sources="upload",
                label="Input Audio",
                # value=(sr, y)
            )
            audio_reference = default_audio_block(
                sources="upload",
                label="Reference Audio",
            )
            with gr.Row():
                method_dropdown = gr.Dropdown(
                    ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
                    value="ST-ITO",
                    label=f"Style Transfer Method",
                    interactive=True,
                )
                process_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
        with gr.Column():
            audio_output = default_audio_block(label="Output Audio", interactive=False)
            dry_wet_ratio = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="Dry/Wet Ratio",
                interactive=True,
            )
            direct_output = default_audio_block(label="Direct Audio", interactive=False)
            wet_output = default_audio_block(label="Wet Audio", interactive=False)

    with gr.Row():
        with gr.Column():
            _ = gr.Markdown("## Control Parameters")
            with gr.Row():
                dataset_dropdown = gr.Dropdown(
                    [("Internal", "internal"), ("MedleyDB", "medleydb")],
                    label="Prior Distribution (Dataset)",
                    info="This parameter has no effect when using the ST-ITO and Regression methods.",
                    value="internal",
                    interactive=True,
                )
                embedding_dropdown = gr.Dropdown(
                    [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
                    label="Embedding Model",
                    info="This parameter has no effect when using the Mean and Regression methods.",
                    value="afx-rep",
                    interactive=True,
                )
            # with gr.Column():
            remove_approx_checkbox = gr.Checkbox(
                label="Use Real-time Effects",
                info="Use real-time delay and reverb effects instead of approximated ones.",
                value=False,
                interactive=True,
            )
            mid_side_checkbox = gr.Checkbox(
                label="Use Mid-Side Processing",
                info="This option has no effect when using the Mean and Regression methods.",
                value=True,
                interactive=True,
            )
        with gr.Column():
            _ = gr.Markdown("## Parameters for ST-ITO Method")
            with gr.Row():
                optimisation_steps = gr.Slider(
                    minimum=1,
                    maximum=2000,
                    value=100,
                    step=1,
                    label="Number of Optimisation Steps",
                    interactive=True,
                )
                prior_weight = gr.Dropdown(
                    [
                        ("0", 0.0),
                        ("0.001", 0.001),
                        ("0.01", 0.01),
                        ("0.1", 0.1),
                        ("1", 1.0),
                    ],
                    info="Weight of the prior distribution in the loss function. A higher value means the model will try to stay closer to the prior distribution.",
                    value=0.01,
                    label="Prior Weight",
                    interactive=True,
                )
                optimiser_dropdown = gr.Dropdown(
                    [
                        "Adadelta",
                        "Adafactor",
                        "Adagrad",
                        "Adam",
                        "AdamW",
                        "Adamax",
                        "RMSprop",
                        "ASGD",
                        "NAdam",
                        "RAdam",
                        "SGD",
                    ],
                    value="Adam",
                    label="Optimiser",
                    interactive=True,
                )
                lr_slider = gr.Dropdown(
                    [("0.0001", 1e-4), ("0.001", 1e-3), ("0.01", 1e-2), ("0.1", 1e-1)],
                    value=1e-2,
                    label="Learning Rate",
                    interactive=True,
                )

    _ = gr.Markdown("## Effect Parameters Visualisation")
    with gr.Row():
        # Plots are seeded with the dataset-mean chain so the page isn't empty.
        peq_plot = gr.Plot(
            plot_eq(global_fx), label="PEQ Frequency Response", elem_id="peq-plot"
        )
        comp_plot = gr.Plot(
            plot_comp(global_fx), label="Compressor Curve", elem_id="comp-plot"
        )
    with gr.Row():
        delay_plot = gr.Plot(
            plot_delay(global_fx),
            label="Delay Frequency Response",
            elem_id="delay-plot",
        )
        reverb_plot = gr.Plot(
            plot_reverb(global_fx),
            label="Tone Correction PEQ",
            elem_id="reverb-plot",
            min_width=160,
        )
        t60_plot = gr.Plot(
            plot_t60(global_fx), label="Decay Time", elem_id="t60-plot", min_width=160
        )
    _ = gr.Markdown("## Effect Settings JSON")
    with gr.Row():
        json_output = gr.JSON(
            model2json(global_fx), label="Effect Settings", max_height=800, open=True
        )

    # Main pipeline, one GPU-decorated composite function:
    #   1. run inference() on (reference, method, ...) -> parameter vector
    #   2. build the fx chain and render the input audio with it
    #   3. regenerate all five plots and the JSON view from the chain
    #   4. move the vector to CPU for storage in gr.State
    # The positional shuffling in each lambda must stay in lockstep with the
    # `inputs`/`outputs` component lists below.
    process_button.click(
        spaces.GPU(duration=60)(
            chain_functions(
                lambda audio, approx, ratio, *args: (
                    audio,
                    approx,
                    ratio,
                    inference(audio, *args),
                ),
                lambda audio, approx, ratio, vec: (
                    vec2fx(vec),
                    *render(audio, approx, ratio, vec),
                    vec,
                ),
                lambda fx, *args: (
                    *args,
                    *map(
                        lambda f: f(fx),
                        [
                            plot_eq,
                            plot_comp,
                            plot_delay,
                            plot_reverb,
                            plot_t60,
                            model2json,
                        ],
                    ),
                ),
                lambda out, dir, wet, vec, *args: (
                    out,
                    dir,
                    wet,
                    vec.cpu(),
                    *args,
                ),
            )
        ),
        inputs=[
            audio_input,
            remove_approx_checkbox,
            dry_wet_ratio,
            audio_reference,
            method_dropdown,
            dataset_dropdown,
            embedding_dropdown,
            mid_side_checkbox,
            optimisation_steps,
            prior_weight,
            optimiser_dropdown,
            lr_slider,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
            fx_params,
            peq_plot,
            comp_plot,
            delay_plot,
            reverb_plot,
            t60_plot,
            json_output,
        ],
    )

    # Toggling real-time effects re-renders with the stored parameter vector;
    # if no input audio is loaded yet, pass the current outputs through.
    remove_approx_checkbox.change(
        spaces.GPU(duration=10)(
            lambda in_audio, out_audio, di, wet, *args: (
                (out_audio, di, wet) if in_audio is None else render(in_audio, *args)
            )
        ),
        inputs=[
            audio_input,
            audio_output,
            direct_output,
            wet_output,
            remove_approx_checkbox,
            dry_wet_ratio,
            fx_params,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
        ],
    )

    # Dry/wet slider remixes the already-rendered direct/wet stems client-side
    # (int16 -> float, equal-power crossfade, back to int16) — no GPU needed.
    dry_wet_ratio.input(
        chain_functions(
            lambda _, *args: (_, *map(lambda x: x[1] / 32768, args)),
            lambda ratio, d, w: math.sqrt(2)
            * (
                math.cos(ratio * math.pi * 0.5) * d
                + math.sin(ratio * math.pi * 0.5) * w
            ),
            lambda x: (44100, (x * 32768).astype(np.int16)),
        ),
        inputs=[dry_wet_ratio, direct_output, wet_output],
        outputs=[audio_output],
    )

demo.launch()