# diffvox-ito / app.py
# Author: Chin-Yun Yu
# fix: set height for overview image in demo layout (commit 3816c03)
import spaces
import gradio as gr
import numpy as np
from scipy.io.wavfile import read
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import math
import yaml
import os
from huggingface_hub import hf_hub_download, snapshot_download
import json
import pyloudnorm as pyln
from hydra.utils import instantiate
from soxr import resample
from functools import partial, reduce
from itertools import accumulate
from torchcomp import coef2ms, ms2coef
from copy import deepcopy
from pathlib import Path
from typing import Tuple, List, Optional, Union, Callable
from os import environ
# Download AFx-Rep model and config files
os.makedirs("tmp", exist_ok=True)
# AFx-Rep checkpoint and config from the Hugging Face Hub, cached under ./tmp.
# NOTE(review): ckpt_path/config_path are not referenced later in this file;
# presumably the st_ito loaders read them from ./tmp — confirm.
ckpt_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="afx-rep.ckpt",
    local_dir="tmp",
    local_files_only=False,
)
config_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="config.yaml",
    local_dir="tmp",
    local_files_only=False,
)
# Presets and the `modules` package come from the diffvox dataset repo and are
# placed in the working directory so the `modules`/`presets` imports and paths
# below resolve.
preset_path = snapshot_download(
    "yoyolicoris/diffvox",
    repo_type="dataset",
    local_dir="./",
    local_files_only=False,
    allow_patterns=["presets/*", "modules/*"],
)
from modules.utils import vec2statedict, get_chunks
from modules.fx import clip_delay_eq_Q, hadamard
from utils import (
get_log_mags_from_eq,
chain_functions,
remove_window_fn,
jsonparse2hydra,
)
from ito import find_closest_training_sample, one_evaluation
from st_ito.utils import (
load_param_model,
get_param_embeds,
get_feature_embeds,
load_mfcc_feature_extractor,
load_mir_feature_extractor,
)
# Markdown shown at the top of the demo page. Grammar fixes only:
# "The effects is ... series of" -> "The effects are ... a series of",
# "that perform" -> "that performs".
title_md = "# Vocal Effects Style Transfer Demo"
description_md = """
This is a demo of the paper [Improving Inference-Time Optimisation for Vocal Effects Style Transfer with a Gaussian Prior](https://iamycy.github.io/diffvox-ito-demo/), published at WASPAA 2025.
In this demo, you can upload a raw vocal audio file (in mono) and a reference vocal mix (in stereo).
This system will apply vocal effects to the raw vocal such that the processed audio matches the style of the reference mix.
The effects are the same as in [DiffVox](https://huggingface.co/spaces/yoyolicoris/diffvox/), which consists of a series of EQ, compressor, delay, and reverb.
We offer four different methods for style transfer:
1. **Mean**: Use the mean of the chosen vocal preset dataset as the effect parameters.
2. **Nearest Neighbour**: Find the closest vocal preset in the chosen dataset based on chosen embedding features, and use its effect parameters.
3. **ST-ITO**: Our proposed method that performs inference-time optimisation in the chosen embedding space, regularised by a Gaussian prior learned from the chosen vocal preset dataset.
4. **Regression**: A pre-trained regression model that directly predicts effect parameters from the reference mix.
### Datasets
- **Internal**: Derived from a proprietary dataset in Sony.
- **MedleyDB**: Derived from solo vocal tracks from the MedleyDB v1/v2 dataset.
### Embedding Models
- **AFx-Rep**: A self-supervised audio representation [model](https://huggingface.co/csteinmetz1/afx-rep) trained on random audio effects.
- **MFCC**: Mel-frequency cepstral coefficients.
- **MIR Features**: A set of common features in MIR like RMS, spectral centroid, spectral flatness, etc.
> **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
"""
# device = "cpu"
# UI/optimisation constants.
# NOTE(review): SLIDER_MAX, SLIDER_MIN, NUMBER_OF_PCS and TEMPERATURE are not
# referenced elsewhere in this file — possibly leftovers from a previous demo;
# confirm before removing.
SLIDER_MAX = 3
SLIDER_MIN = -3
NUMBER_OF_PCS = 4
TEMPERATURE = 0.7
# Effect-chain configs: "realtime" uses the real-time delay/reverb
# implementations, "approx" the approximated ones (see the UI checkbox).
CONFIG_PATH = {
    "realtime": "presets/rt_config.yaml",
    "approx": "fx_config.yaml",
}
# Per-dataset preset folders (downloaded above from the diffvox dataset repo).
PRESET_PATH = {
    "internal": Path("presets/internal/"),
    "medleydb": Path("presets/medleydb/"),
}
# Regression-model checkpoint folder and per-preset-folder file names.
CKPT_PATH = Path("reg-ckpts/")
PCA_PARAM_FILE = "gaussian.npz"
INFO_PATH = "info.json"
MASK_PATH = "feature_mask.npy"
PARAMS_PATH = "raw_params.npy"
TRAIN_INDEX_PATH = "train_index.npy"
# NOTE(review): EXAMPLE_PATH is only used in commented-out code below.
EXAMPLE_PATH = "eleanor_erased.wav"
# Load the "model" section of each effect-chain config for hydra instantiate.
with open(CONFIG_PATH["approx"]) as fp:
    fx_config = yaml.safe_load(fp)["model"]
with open(CONFIG_PATH["realtime"]) as fp:
    rt_config = yaml.safe_load(fp)["model"]
def load_presets(preset_folder: Path) -> Tensor:
    """Load the preset parameter matrix stored in *preset_folder*.

    Reads the raw parameter matrix and the boolean feature mask; if a training
    index file is present, only those rows are kept. Returns a contiguous
    (num_presets, num_selected_features) tensor.
    """
    params = torch.from_numpy(np.load(preset_folder / PARAMS_PATH))
    mask = torch.from_numpy(np.load(preset_folder / MASK_PATH))
    index_file = preset_folder / TRAIN_INDEX_PATH
    if index_file.exists():
        # Restrict to the training split when an index file is provided.
        params = params[torch.from_numpy(np.load(index_file))]
    return params[:, mask].contiguous()
def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor, Tensor]:
    """Load a fitted Gaussian ("mean" and "cov" arrays) from an .npz file.

    Returns the mean, the covariance, and the covariance's log-determinant,
    all as float32 tensors.
    """
    data = np.load(f)
    as_float = lambda key: torch.from_numpy(data[key]).float()
    mean, cov = as_float("mean"), as_float("cov")
    return mean, cov, torch.logdet(cov)
def logp_x(mu: Tensor, cov: Tensor, cov_logdet: Tensor, x: Tensor) -> Tensor:
    """Log-density of ``x`` under a multivariate Gaussian N(mu, cov).

    ``cov_logdet`` is passed in precomputed so callers can reuse it across
    evaluations. Raises ValueError if the quadratic form comes out negative,
    which indicates a non-positive-definite covariance.
    """
    diff = x - mu
    # Solve cov @ b = diff rather than forming cov^{-1} explicitly (stabler).
    b = torch.linalg.solve(cov, diff)
    norm = diff @ b
    # Replaced `assert` with an explicit raise: asserts are stripped under
    # `python -O`, silently disabling this sanity check.
    if not torch.all(norm >= 0):
        raise ValueError("Negative norm detected, check covariance matrix.")
    return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
# Per-dataset preset matrices and fitted Gaussian (mean, cov, logdet) tuples.
preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
gaussian_params_dict = {
    k: load_gaussian_params(v / PCA_PARAM_FILE) for k, v in PRESET_PATH.items()
}
# Global latent variable
# z = torch.zeros_like(mean)
# Metadata describing how the flat parameter vector maps back onto the effect
# module's state dict (parameter keys and their original shapes).
with open(PRESET_PATH["internal"] / INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
# Scalar parameters are recorded with an empty shape list; treat them as [1].
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
# vec2dict: flat parameter vector -> state dict for the effect chain.
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)
# Mean preset of the internal dataset; used as the initial parameter state.
internal_mean = gaussian_params_dict["internal"][0]
# Global effect
global_fx = instantiate(fx_config)
# global_fx.eval()
# strict=False: the vector covers only the learnable subset of the state dict.
global_fx.load_state_dict(vec2dict(internal_mean), strict=False)
# Remember which state-dict entries are scalars so they can be unwrapped when
# building a state dict from a vector (see to_fx_state_dict).
ndim_dict = {k: v.ndim for k, v in global_fx.state_dict().items()}
to_fx_state_dict = lambda x: {
    k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
}
# Loudness meter at 44.1 kHz (pyloudnorm) used for -18 LUFS normalisation.
meter = pyln.Meter(44100)
def get_embedding_model(embedding: str) -> Callable:
    """Build the stereo embedding function for the requested model name.

    ``embedding`` is one of "afx-rep", "mfcc" or "mir"; any other value raises
    ValueError. The returned callable maps a 2-channel tensor to embeddings.
    """
    device = environ.get("DEVICE", "cpu")
    if embedding == "afx-rep":
        model = load_param_model().to(device)
        return lambda x: get_param_embeds(x, model, 44100)
    if embedding == "mfcc":
        model = load_mfcc_feature_extractor().to(device)
        return lambda x: get_feature_embeds(x, model)
    if embedding == "mir":
        model = load_mir_feature_extractor().to(device)
        return lambda x: get_feature_embeds(x, model)
    raise ValueError(f"Unknown encoder: {embedding}")
def get_regressor() -> Callable:
    """Load the pre-trained regression model and return a prediction function.

    Picks the checkpoint with the lowest validation loss (parsed from the file
    stem), restores the model and the parameter normalisation statistics, and
    returns a callable mapping a wet reference tensor to a de-normalised
    parameter prediction.
    """
    with open(CKPT_PATH / "config.yaml") as f:
        model_config = yaml.safe_load(f)["model"]
    # Checkpoint stems look like "...val_loss=<value>"; lowest loss wins.
    candidates = (CKPT_PATH / "checkpoints").glob("*val_loss*.ckpt")
    best = min(candidates, key=lambda p: float(p.stem.split("=")[-1]))
    # NOTE(review): torch.load without weights_only — only load trusted files.
    checkpoint = torch.load(best, map_location="cpu")
    model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
        model_config
    )
    model.load_state_dict(checkpoint["state_dict"])
    device = environ.get("DEVICE", "cpu")
    model = model.to(device).eval()
    stats = torch.load(CKPT_PATH / "param_stats.pt")
    mu = stats["mu"].float().to(device)
    std = stats["std"].float().to(device)
    # Predictions are stored normalised; undo with the training statistics.
    return lambda wet: model(wet, dry=None) * std + mu
def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
    """Convert a (sample_rate, samples) audio pair to 44.1 kHz float.

    Resamples when needed, rescales integer PCM (assumed 16-bit) into
    [-1, 1), and promotes mono 1-D input to a (time, 1) column.
    """
    if sr != 44100:
        x = resample(x, sr, 44100)
    if x.dtype.kind != "f":
        # Assumes 16-bit PCM; dividing by 32768 maps int16 into [-1, 1).
        x = x / 32768.0
    return x[:, None] if x.ndim == 1 else x
# @spaces.GPU(duration=60)
def inference(
    input_audio,
    ref_audio,
    method,
    dataset,
    embedding,
    mid_side,
    steps,
    prior_weight,
    optimiser,
    lr,
    progress=gr.Progress(track_tqdm=True),
):
    """Estimate an effect-parameter vector with the selected transfer method.

    ``input_audio``/``ref_audio`` are Gradio ``(sample_rate, ndarray)`` tuples;
    ``method`` is one of "Mean", "Nearest Neighbour", "ST-ITO" or "Regression".
    The remaining arguments configure the dataset prior, embedding model and
    the ST-ITO optimiser. Returns a flat parameter tensor on the active device.
    Raises ValueError for an unknown method.
    """
    # close all figures to avoid too many open figures
    plt.close("all")
    device = environ.get("DEVICE", "cpu")
    # "Mean" ignores both audio inputs: return the dataset's Gaussian mean.
    if method == "Mean":
        return gaussian_params_dict[dataset][0].to(device)
    # Loudness-normalise the reference to -18 LUFS; shape (1, channels, time).
    ref = convert2float(*ref_audio)
    ref_loudness = meter.integrated_loudness(ref)
    ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
    ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
    if method == "Regression":
        regressor = get_regressor()
        with torch.no_grad():
            # Average predictions over the batch dimension.
            return regressor(ref).mean(0)
    # Same -18 LUFS normalisation for the raw input, then fold down to mono.
    y = convert2float(*input_audio)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)
    fx = deepcopy(global_fx).to(device)
    fx.train()
    # Optionally convert L/R to mid-side (hadamard) before embedding.
    two_chs_emb_fn = chain_functions(
        hadamard if mid_side else lambda x: x,
        get_embedding_model(embedding),
    )
    match method:
        case "Nearest Neighbour":
            vec = find_closest_training_sample(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                preset_dict[dataset].to(device),
                ref,
                y,
                progress,
            )
        case "ST-ITO":
            # NOTE: the Gaussian prior here is always the "internal" one,
            # regardless of the selected dataset (consistent with the UI hint
            # that the dataset choice has no effect for ST-ITO).
            vec = one_evaluation(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                partial(
                    logp_x, *[x.to(device) for x in gaussian_params_dict["internal"]]
                ),
                internal_mean.to(device),
                ref,
                y,
                optimiser_type=optimiser,
                lr=lr,
                steps=steps,
                weight=prior_weight,
                progress=progress,
            )
        case _:
            raise ValueError(f"Unknown method: {method}")
    return vec
# @spaces.GPU(duration=10)
def render(y, remove_approx, ratio, vec):
    """Render the input audio through the effect chain described by ``vec``.

    ``y`` is a Gradio ``(sample_rate, ndarray)`` tuple; ``remove_approx``
    selects the real-time effect configuration instead of the approximated
    one; ``ratio`` is the dry/wet mix in [0, 1]. Returns three
    ``(44100, int16)`` tuples: the mixed output, direct-only, and wet-only.
    """
    device = environ.get("DEVICE", "cpu")
    # Normalise the input to -18 LUFS and fold multichannel down to mono.
    y = convert2float(*y)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)
    if remove_approx:
        infer_fx = instantiate(rt_config).to(device)
    else:
        infer_fx = instantiate(fx_config).to(device)
    # strict=False: the vector covers only the learnable subset of the state.
    infer_fx.load_state_dict(vec2dict(vec), strict=False)
    # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    infer_fx.eval()
    with torch.no_grad():
        direct, wet = infer_fx(y)
    direct = direct.squeeze(0).T.cpu().numpy()
    wet = wet.squeeze(0).T.cpu().numpy()
    # Equal-power crossfade angle between the direct and wet paths.
    angle = ratio * math.pi * 0.5
    test_clipping = direct + wet
    # rendered = fx(y).squeeze(0).T.numpy()
    # Jointly normalise both stems if their plain sum would clip.
    # NOTE(review): the check uses direct + wet, but the sqrt(2)-weighted mix
    # below can still exceed full scale at extreme ratios — confirm intended.
    if np.max(np.abs(test_clipping)) > 1:
        scaler = np.max(np.abs(test_clipping))
        # rendered = rendered / scaler
        direct = direct / scaler
        wet = wet / scaler
    rendered = math.sqrt(2) * (math.cos(angle) * direct + math.sin(angle) * wet)
    return (
        (44100, (rendered * 32768).astype(np.int16)),
        (44100, (direct * 32768).astype(np.int16)),
        (
            44100,
            (wet * 32768).astype(np.int16),
        ),
    )
def model2json(fx):
    """Serialise the effect chain into a nested, JSON-serialisable dict.

    The first seven modules are the direct-path EQ bands and the compressor;
    ``fx[7]`` holds the panner and the spatial sends (delay and FDN reverb).
    """
    fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    direct = {name: module.toJSON() for name, module in zip(fx_names, fx)}
    direct["Panner"] = fx[7].pan.toJSON()
    delay = fx[7].effects[0]
    fdn = fx[7].effects[1]
    delay_json = delay.toJSON()
    delay_json["LP"] = delay.eq.toJSON()
    fdn_json = fdn.toJSON()
    fdn_json["Tone correction PEQ"] = {
        name: band.toJSON() for name, band in zip(fx_names[:4], fdn.eq)
    }
    sends = {
        "DLY": delay_json,
        "FDN": fdn_json,
        # Cross-feed send gain, converted from linear to decibels.
        "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
    }
    return {
        "Direct": direct,
        "Sends": sends,
    }
@torch.no_grad()
def plot_eq(fx):
    """Plot the combined and per-band PEQ magnitude responses (log frequency)."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    w, band_mags = get_log_mags_from_eq(fx[:6])
    # Overall response is the sum of the individual band responses (in dB).
    ax.plot(w, sum(band_mags), color="black", linestyle="-")
    for band in band_mags:
        ax.plot(w, band, "k-", alpha=0.3)
        ax.fill_between(w, band, 0, facecolor="gray", edgecolor="none", alpha=0.1)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig
@torch.no_grad()
def plot_comp(fx):
    """Plot the compressor/expander static transfer curve against unity gain."""
    fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
    params = fx[6].params
    cmp_th = params.cmp_th.item()
    exp_th = params.exp_th.item()
    cmp_ratio = params.cmp_ratio.item()
    exp_ratio = params.exp_ratio.item()
    make_up = params.make_up.item()
    level_in = np.linspace(-80, 0, 100)
    # Compression above cmp_th.
    compressed = level_in - (level_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio
    curve = np.where(level_in > cmp_th, compressed, level_in)
    # Downward expansion below exp_th, then make-up gain.
    expanded = curve - (exp_th - curve) / exp_ratio
    level_out = np.where(curve < exp_th, expanded, curve) + make_up
    ax.plot(level_in, level_out, c="black", linestyle="-")
    # Unity-gain reference line.
    ax.plot(level_in, level_in, c="r", alpha=0.5)
    ax.set_xlabel("Input Level (dB)")
    ax.set_ylabel("Output Level (dB)")
    ax.set_xlim(-80, 0)
    ax.set_ylim(-80, 0)
    ax.grid()
    return fig
@torch.no_grad()
def plot_delay(fx):
    """Plot the delay send's frequency response and its feedback repeats."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    delay = fx[7].effects[0]
    w, eq_log_mags = get_log_mags_from_eq([delay.eq])
    base_mag = sum(eq_log_mags)
    gain_db = 20 * delay.params.gain.log10().item()
    feedback_db = 20 * delay.params.feedback.log10().item()
    delay_s = delay.params.delay.item() / 1000
    # First tap.
    ax.plot(w, base_mag + gain_db, color="black", linestyle="-")
    # Each repeat passes through the EQ again and accrues feedback gain;
    # fade the line out with repeat count and delay time.
    for i in range(1, 10):
        repeat = base_mag * (i + 1) + feedback_db * i + gain_db
        ax.plot(
            w,
            repeat,
            c="black",
            alpha=max(0, (10 - i * delay_s * 4) / 10),
            linestyle="-",
        )
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-80, 0)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig
@torch.no_grad()
def plot_reverb(fx):
    """Plot the FDN reverb's tone-correction EQ response plus broadband gain."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    w, band_mags = get_log_mags_from_eq(fdn.eq)
    # Broadband gain from the norms of the input (b) and output (c) vectors.
    gain_db = 20 * torch.log10(fdn.params.c.norm() * fdn.params.b.norm()).item()
    ax.plot(w, sum(band_mags) + gain_db, color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig
@torch.no_grad()
def plot_t60(fx):
    """Plot the FDN reverb's frequency-dependent T60 decay time."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    gamma = fdn.params.gamma.squeeze().numpy()
    delays = fdn.delays.numpy()
    freqs = np.linspace(0, 22050, gamma.size)
    # Per-sample decay in dB is 20*log10(gamma)/min(delay); T60 is the time
    # to fall 60 dB at 44.1 kHz. The 1e-10 floor avoids log10(0).
    t60 = -60 / (20 * np.log10(gamma + 1e-10) / np.min(delays)) / 44100
    ax.plot(freqs, t60, color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("T60 (s)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(0, 9)
    ax.set_xscale("log")
    ax.grid()
    return fig
def vec2fx(x):
    """Return a fresh copy of the global effect chain parameterised by ``x``."""
    fx_copy = deepcopy(global_fx)
    # strict=False: the vector covers only the learnable subset of the state.
    fx_copy.load_state_dict(vec2dict(x), strict=False)
    # fx_copy.apply(partial(clip_delay_eq_Q, Q=0.707))
    return fx_copy
# ---------------------------------------------------------------------------
# Gradio UI layout and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Holds the current effect-parameter vector between callbacks.
    fx_params = gr.State(internal_mean)
    # fx = vec2fx(fx_params.value)
    # sr, y = read(EXAMPLE_PATH)
    # All audio widgets share the numpy format and looping playback.
    default_audio_block = partial(gr.Audio, type="numpy", loop=True)
    gr.Markdown(
        title_md,
        elem_id="title",
    )
    with gr.Row():
        gr.Markdown(
            description_md,
            elem_id="description",
        )
        # Fixed height keeps the overview diagram from dominating the layout.
        gr.Image("overview.png", elem_id="diagram", height=500)
    # Inputs and method selection (left column), outputs (right column).
    with gr.Row():
        with gr.Column():
            audio_input = default_audio_block(
                sources="upload",
                label="Input Audio",
                # value=(sr, y)
            )
            audio_reference = default_audio_block(
                sources="upload",
                label="Reference Audio",
            )
            with gr.Row():
                # NOTE(review): f-string below has no placeholders.
                method_dropdown = gr.Dropdown(
                    ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
                    value="ST-ITO",
                    label=f"Style Transfer Method",
                    interactive=True,
                )
                process_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
        with gr.Column():
            audio_output = default_audio_block(label="Output Audio", interactive=False)
            dry_wet_ratio = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="Dry/Wet Ratio",
                interactive=True,
            )
            direct_output = default_audio_block(label="Direct Audio", interactive=False)
            wet_output = default_audio_block(label="Wet Audio", interactive=False)
    with gr.Row():
        with gr.Column():
            _ = gr.Markdown("## Control Parameters")
            with gr.Row():
                dataset_dropdown = gr.Dropdown(
                    [("Internal", "internal"), ("MedleyDB", "medleydb")],
                    label="Prior Distribution (Dataset)",
                    info="This parameter has no effect when using the ST-ITO and Regression methods.",
                    value="internal",
                    interactive=True,
                )
                embedding_dropdown = gr.Dropdown(
                    [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
                    label="Embedding Model",
                    info="This parameter has no effect when using the Mean and Regression methods.",
                    value="afx-rep",
                    interactive=True,
                )
                # with gr.Column():
                remove_approx_checkbox = gr.Checkbox(
                    label="Use Real-time Effects",
                    info="Use real-time delay and reverb effects instead of approximated ones.",
                    value=False,
                    interactive=True,
                )
                mid_side_checkbox = gr.Checkbox(
                    label="Use Mid-Side Processing",
                    info="This option has no effect when using the Mean and Regression methods.",
                    value=True,
                    interactive=True,
                )
        with gr.Column():
            _ = gr.Markdown("## Parameters for ST-ITO Method")
            with gr.Row():
                optimisation_steps = gr.Slider(
                    minimum=1,
                    maximum=2000,
                    value=100,
                    step=1,
                    label="Number of Optimisation Steps",
                    interactive=True,
                )
                prior_weight = gr.Dropdown(
                    [
                        ("0", 0.0),
                        ("0.001", 0.001),
                        ("0.01", 0.01),
                        ("0.1", 0.1),
                        ("1", 1.0),
                    ],
                    info="Weight of the prior distribution in the loss function. A higher value means the model will try to stay closer to the prior distribution.",
                    value=0.01,
                    label="Prior Weight",
                    interactive=True,
                )
                optimiser_dropdown = gr.Dropdown(
                    [
                        "Adadelta",
                        "Adafactor",
                        "Adagrad",
                        "Adam",
                        "AdamW",
                        "Adamax",
                        "RMSprop",
                        "ASGD",
                        "NAdam",
                        "RAdam",
                        "SGD",
                    ],
                    value="Adam",
                    label="Optimiser",
                    interactive=True,
                )
                lr_slider = gr.Dropdown(
                    [("0.0001", 1e-4), ("0.001", 1e-3), ("0.01", 1e-2), ("0.1", 1e-1)],
                    value=1e-2,
                    label="Learning Rate",
                    interactive=True,
                )
    _ = gr.Markdown("## Effect Parameters Visualisation")
    # Plots are seeded from the global (mean) effect and refreshed on "Run".
    with gr.Row():
        peq_plot = gr.Plot(
            plot_eq(global_fx), label="PEQ Frequency Response", elem_id="peq-plot"
        )
        comp_plot = gr.Plot(
            plot_comp(global_fx), label="Compressor Curve", elem_id="comp-plot"
        )
    with gr.Row():
        delay_plot = gr.Plot(
            plot_delay(global_fx),
            label="Delay Frequency Response",
            elem_id="delay-plot",
        )
        reverb_plot = gr.Plot(
            plot_reverb(global_fx),
            label="Tone Correction PEQ",
            elem_id="reverb-plot",
            min_width=160,
        )
        t60_plot = gr.Plot(
            plot_t60(global_fx), label="Decay Time", elem_id="t60-plot", min_width=160
        )
    _ = gr.Markdown("## Effect Settings JSON")
    with gr.Row():
        json_output = gr.JSON(
            model2json(global_fx), label="Effect Settings", max_height=800, open=True
        )
    # Pipeline on "Run": inference -> render audio -> regenerate plots/JSON.
    # Each lambda threads its leading args through to the next stage.
    process_button.click(
        spaces.GPU(duration=60)(
            chain_functions(
                lambda audio, approx, ratio, *args: (
                    audio,
                    approx,
                    ratio,
                    inference(audio, *args),
                ),
                lambda audio, approx, ratio, vec: (
                    vec2fx(vec),
                    *render(audio, approx, ratio, vec),
                    vec,
                ),
                lambda fx, *args: (
                    *args,
                    *map(
                        lambda f: f(fx),
                        [
                            plot_eq,
                            plot_comp,
                            plot_delay,
                            plot_reverb,
                            plot_t60,
                            model2json,
                        ],
                    ),
                ),
                lambda out, dir, wet, vec, *args: (
                    out,
                    dir,
                    wet,
                    vec.cpu(),
                    *args,
                ),
            )
        ),
        inputs=[
            audio_input,
            remove_approx_checkbox,
            dry_wet_ratio,
            audio_reference,
            method_dropdown,
            dataset_dropdown,
            embedding_dropdown,
            mid_side_checkbox,
            optimisation_steps,
            prior_weight,
            optimiser_dropdown,
            lr_slider,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
            fx_params,
            peq_plot,
            comp_plot,
            delay_plot,
            reverb_plot,
            t60_plot,
            json_output,
        ],
    )
    # Toggling the checkbox re-renders with the other effect implementation;
    # if no input audio is loaded, the current outputs are passed through.
    remove_approx_checkbox.change(
        spaces.GPU(duration=10)(
            lambda in_audio, out_audio, di, wet, *args: (
                (out_audio, di, wet) if in_audio is None else render(in_audio, *args)
            )
        ),
        inputs=[
            audio_input,
            audio_output,
            direct_output,
            wet_output,
            remove_approx_checkbox,
            dry_wet_ratio,
            fx_params,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
        ],
    )
    # Moving the slider recombines the cached direct/wet stems (int16 -> float,
    # equal-power mix, back to int16) without re-running the effects.
    dry_wet_ratio.input(
        chain_functions(
            lambda _, *args: (_, *map(lambda x: x[1] / 32768, args)),
            lambda ratio, d, w: math.sqrt(2)
            * (
                math.cos(ratio * math.pi * 0.5) * d
                + math.sin(ratio * math.pi * 0.5) * w
            ),
            lambda x: (44100, (x * 32768).astype(np.int16)),
        ),
        inputs=[dry_wet_ratio, direct_output, wet_output],
        outputs=[audio_output],
    )
demo.launch()