Spaces: Running on Zero
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| from scipy.io.wavfile import read | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from torch import Tensor | |
| import math | |
| import yaml | |
| import os | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| import json | |
| import pyloudnorm as pyln | |
| from hydra.utils import instantiate | |
| from soxr import resample | |
| from functools import partial, reduce | |
| from itertools import accumulate | |
| from torchcomp import coef2ms, ms2coef | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from typing import Tuple, List, Optional, Union, Callable | |
| from os import environ | |
# Download AFx-Rep model and config files
# NOTE: these run at import time and require network access on first launch
# (local_files_only=False allows the Hub cache to be populated).
os.makedirs("tmp", exist_ok=True)
# AFx-Rep checkpoint used by the "afx-rep" embedding model.
ckpt_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="afx-rep.ckpt",
    local_dir="tmp",
    local_files_only=False,
)
config_path = hf_hub_download(
    repo_id="csteinmetz1/afx-rep",
    filename="config.yaml",
    local_dir="tmp",
    local_files_only=False,
)
# Preset dataset + helper modules from the diffvox dataset repo; downloaded into
# the working directory so `presets/*` paths and the `modules` package resolve.
preset_path = snapshot_download(
    "yoyolicoris/diffvox",
    repo_type="dataset",
    local_dir="./",
    local_files_only=False,
    allow_patterns=["presets/*", "modules/*"],
)
| from modules.utils import vec2statedict, get_chunks | |
| from modules.fx import clip_delay_eq_Q, hadamard | |
| from utils import ( | |
| get_log_mags_from_eq, | |
| chain_functions, | |
| remove_window_fn, | |
| jsonparse2hydra, | |
| ) | |
| from ito import find_closest_training_sample, one_evaluation | |
| from st_ito.utils import ( | |
| load_param_model, | |
| get_param_embeds, | |
| get_feature_embeds, | |
| load_mfcc_feature_extractor, | |
| load_mir_feature_extractor, | |
| ) | |
# User-facing markdown shown at the top of the demo page.
title_md = "# Vocal Effects Style Transfer Demo"
description_md = """
This is a demo of the paper [Improving Inference-Time Optimisation for Vocal Effects Style Transfer with a Gaussian Prior](https://iamycy.github.io/diffvox-ito-demo/), published at WASPAA 2025.
In this demo, you can upload a raw vocal audio file (in mono) and a reference vocal mix (in stereo).
This system will apply vocal effects to the raw vocal such that the processed audio matches the style of the reference mix.
The effects are the same as in [DiffVox](https://huggingface.co/spaces/yoyolicoris/diffvox/), which consists of a series of EQ, compressor, delay, and reverb.
We offer four different methods for style transfer:
1. **Mean**: Use the mean of the chosen vocal preset dataset as the effect parameters.
2. **Nearest Neighbour**: Find the closest vocal preset in the chosen dataset based on the chosen embedding features, and use its effect parameters.
3. **ST-ITO**: Our proposed method that performs inference-time optimisation in the chosen embedding space, regularised by a Gaussian prior learned from the chosen vocal preset dataset.
4. **Regression**: A pre-trained regression model that directly predicts effect parameters from the reference mix.
### Datasets
- **Internal**: Derived from a proprietary dataset in Sony.
- **MedleyDB**: Derived from solo vocal tracks from the MedleyDB v1/v2 dataset.
### Embedding Models
- **AFx-Rep**: A self-supervised audio representation [model](https://huggingface.co/csteinmetz1/afx-rep) trained on random audio effects.
- **MFCC**: Mel-frequency cepstral coefficients.
- **MIR Features**: A set of common features in MIR like RMS, spectral centroid, spectral flatness, etc.
> **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
"""
# device = "cpu"
# Bounds for principal-component sliders; these appear unused in the UI below.
SLIDER_MAX = 3
SLIDER_MIN = -3
NUMBER_OF_PCS = 4
TEMPERATURE = 0.7
# Two effect-chain configurations: "approx" (differentiable approximation, used
# during optimisation) and "realtime" (real-time variant, optional at render).
CONFIG_PATH = {
    "realtime": "presets/rt_config.yaml",
    "approx": "fx_config.yaml",
}
# Preset folders downloaded from the diffvox dataset snapshot above.
PRESET_PATH = {
    "internal": Path("presets/internal/"),
    "medleydb": Path("presets/medleydb/"),
}
CKPT_PATH = Path("reg-ckpts/")  # regression-model checkpoints
PCA_PARAM_FILE = "gaussian.npz"  # per-dataset Gaussian prior (mean/cov)
INFO_PATH = "info.json"  # parameter names + original tensor shapes
MASK_PATH = "feature_mask.npy"  # boolean mask selecting used parameters
PARAMS_PATH = "raw_params.npy"  # raw preset parameter matrix
TRAIN_INDEX_PATH = "train_index.npy"  # optional train-split row indices
EXAMPLE_PATH = "eleanor_erased.wav"
# Load the model sections of both effect-chain configs for hydra instantiation.
with open(CONFIG_PATH["approx"]) as fp:
    fx_config = yaml.safe_load(fp)["model"]
with open(CONFIG_PATH["realtime"]) as fp:
    rt_config = yaml.safe_load(fp)["model"]
def load_presets(preset_folder: Path) -> Tensor:
    """Load the preset parameter matrix stored under *preset_folder*.

    Keeps only the columns selected by the feature mask and, when a
    train-split index file exists, only the training rows.
    """
    params = torch.from_numpy(np.load(preset_folder / PARAMS_PATH))
    mask = torch.from_numpy(np.load(preset_folder / MASK_PATH))
    index_file = preset_folder / TRAIN_INDEX_PATH
    if index_file.exists():
        params = params[torch.from_numpy(np.load(index_file))]
    return params[:, mask].contiguous()
def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor, Tensor]:
    """Load a Gaussian prior (mean, covariance) from an ``.npz`` archive.

    Args:
        f: Path (or file-like object) of an ``.npz`` with ``mean`` and ``cov``.

    Returns:
        ``(mean, cov, logdet)`` as float32 tensors, where ``logdet`` is the
        log-determinant of the covariance (precomputed for the log-density).
    """
    # np.load on an .npz returns an NpzFile backed by an open zip handle;
    # context-manage it so the handle is closed deterministically.
    with np.load(f) as gauss_params:
        mean = torch.from_numpy(gauss_params["mean"]).float()
        cov = torch.from_numpy(gauss_params["cov"]).float()
    return mean, cov, cov.logdet()
def logp_x(mu: Tensor, cov: Tensor, cov_logdet: Tensor, x: Tensor) -> Tensor:
    """Multivariate-normal log-density of ``x`` under N(mu, cov).

    Args:
        mu: Mean vector of the Gaussian prior.
        cov: Covariance matrix.
        cov_logdet: Precomputed ``cov.logdet()`` (avoids recomputation per call).
        x: Point at which to evaluate the log-density.

    Returns:
        Scalar tensor ``-0.5 * (d^T cov^{-1} d + logdet + k*log(2*pi))``.

    Raises:
        ValueError: If the Mahalanobis term is negative, which indicates a
            non-positive-definite covariance matrix.
    """
    diff = x - mu
    b = torch.linalg.solve(cov, diff)
    norm = diff @ b
    # Use an explicit raise rather than `assert`: asserts are stripped under
    # python -O, and this is a genuine runtime sanity check.
    if not torch.all(norm >= 0):
        raise ValueError("Negative norm detected, check covariance matrix.")
    return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
# One preset matrix and one Gaussian prior (mean, cov, logdet) per dataset.
preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
gaussian_params_dict = {
    k: load_gaussian_params(v / PCA_PARAM_FILE) for k, v in PRESET_PATH.items()
}
# Global latent variable
# z = torch.zeros_like(mean)
with open(PRESET_PATH["internal"] / INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
# Scalar parameters are stored with an empty shape list; treat them as shape [1].
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
# get_chunks yields the slicing metadata vec2statedict needs to turn a flat
# parameter vector back into a named state dict (its last element is dropped).
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
# vec2dict(flat_vector) -> state dict keyed like global_fx.state_dict().
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)
internal_mean = gaussian_params_dict["internal"][0]
# Global effect
# Shared effect-chain instance, initialised with the internal dataset's mean
# parameters; callbacks deep-copy it rather than mutating it in place.
global_fx = instantiate(fx_config)
# global_fx.eval()
global_fx.load_state_dict(vec2dict(internal_mean), strict=False)
ndim_dict = {k: v.ndim for k, v in global_fx.state_dict().items()}
# Like vec2dict, but unwraps entries whose target state-dict tensor is 0-dim.
to_fx_state_dict = lambda x: {
    k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
}
meter = pyln.Meter(44100)  # loudness meter fixed at 44.1 kHz
def get_embedding_model(embedding: str) -> Callable:
    """Return a function mapping a two-channel audio tensor to embeddings.

    Supported values: "afx-rep", "mfcc", "mir". Raises ValueError otherwise.
    The backing model is moved to the device named by the DEVICE env var.
    """
    device = environ.get("DEVICE", "cpu")
    if embedding == "afx-rep":
        model = load_param_model().to(device)
        return lambda x: get_param_embeds(x, model, 44100)
    if embedding == "mfcc":
        model = load_mfcc_feature_extractor().to(device)
        return lambda x: get_feature_embeds(x, model)
    if embedding == "mir":
        model = load_mir_feature_extractor().to(device)
        return lambda x: get_feature_embeds(x, model)
    raise ValueError(f"Unknown encoder: {embedding}")
def get_regressor() -> Callable:
    """Build the pre-trained regression model as ``wet -> parameter vector``.

    Picks the checkpoint with the lowest validation loss (parsed from the
    filename), loads it, and returns a closure that de-normalises the model
    output with the stored per-parameter mean/std.
    """
    device = environ.get("DEVICE", "cpu")
    with open(CKPT_PATH / "config.yaml") as f:
        model_config = yaml.safe_load(f)["model"]
    # Checkpoint filenames embed the metric, e.g. "...val_loss=0.1234.ckpt".
    best_ckpt = min(
        (CKPT_PATH / "checkpoints").glob("*val_loss*.ckpt"),
        key=lambda p: float(p.stem.split("=")[-1]),
    )
    state = torch.load(best_ckpt, map_location="cpu")
    model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
        model_config
    )
    model.load_state_dict(state["state_dict"])
    model = model.to(device)
    model.eval()
    stats = torch.load(CKPT_PATH / "param_stats.pt")
    mu = stats["mu"].float().to(device)
    std = stats["std"].float().to(device)
    return lambda wet: model(wet, dry=None) * std + mu
def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
    """Normalise a gradio audio array to float samples at 44.1 kHz, shape (n, ch).

    Integer PCM input is scaled by 1/32768 (assumes int16 range — TODO confirm
    gradio never delivers other integer widths here); mono 1-D input gains a
    trailing channel axis.
    """
    if sr != 44100:
        x = resample(x, sr, 44100)
    if x.dtype.kind != "f":
        x = x / 32768.0
    return x[:, None] if x.ndim == 1 else x
# @spaces.GPU(duration=60)
def inference(
    input_audio,
    ref_audio,
    method,
    dataset,
    embedding,
    mid_side,
    steps,
    prior_weight,
    optimiser,
    lr,
    progress=gr.Progress(track_tqdm=True),
):
    """Estimate an effect-parameter vector that styles `input_audio` like `ref_audio`.

    Args:
        input_audio: (sr, ndarray) raw vocal from gr.Audio; unused by the
            "Mean" and "Regression" methods.
        ref_audio: (sr, ndarray) reference mix.
        method: One of "Mean", "Nearest Neighbour", "ST-ITO", "Regression".
        dataset: "internal" or "medleydb"; selects prior/presets. Note the
            ST-ITO prior below always uses the "internal" Gaussian, matching
            the UI text that says dataset has no effect for ST-ITO.
        embedding: Embedding-model key for get_embedding_model.
        mid_side: If True, apply a Hadamard (mid/side) transform before embedding.
        steps, prior_weight, optimiser, lr: ST-ITO optimisation settings.
        progress: Gradio progress tracker.

    Returns:
        Flat parameter vector (Tensor) on the DEVICE env-var device.
    """
    # close all figures to avoid too many open figures
    plt.close("all")
    device = environ.get("DEVICE", "cpu")
    if method == "Mean":
        # Prior mean of the chosen dataset — no audio needed.
        return gaussian_params_dict[dataset][0].to(device)
    # Loudness-normalise the reference to -18 LUFS, then shape to (1, ch, n).
    ref = convert2float(*ref_audio)
    ref_loudness = meter.integrated_loudness(ref)
    ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
    ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
    if method == "Regression":
        regressor = get_regressor()
        with torch.no_grad():
            # .mean(0) averages over the batch dimension of the prediction.
            return regressor(ref).mean(0)
    # Same normalisation for the input vocal; collapse to mono if needed.
    y = convert2float(*input_audio)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)
    fx = deepcopy(global_fx).to(device)
    fx.train()
    # Embedding pipeline: optional mid/side transform, then the chosen encoder.
    two_chs_emb_fn = chain_functions(
        hadamard if mid_side else lambda x: x,
        get_embedding_model(embedding),
    )
    match method:
        case "Nearest Neighbour":
            vec = find_closest_training_sample(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                preset_dict[dataset].to(device),
                ref,
                y,
                progress,
            )
        case "ST-ITO":
            vec = one_evaluation(
                fx,
                two_chs_emb_fn,
                to_fx_state_dict,
                # Gaussian log-prior, always from the "internal" dataset.
                partial(
                    logp_x, *[x.to(device) for x in gaussian_params_dict["internal"]]
                ),
                internal_mean.to(device),
                ref,
                y,
                optimiser_type=optimiser,
                lr=lr,
                steps=steps,
                weight=prior_weight,
                progress=progress,
            )
        case _:
            raise ValueError(f"Unknown method: {method}")
    return vec
# @spaces.GPU(duration=10)
def render(y, remove_approx, ratio, vec):
    """Render audio through the effect chain described by parameter vector `vec`.

    Args:
        y: (sr, ndarray) audio tuple from gr.Audio.
        remove_approx: If True, instantiate the real-time effect config
            instead of the differentiable approximation.
        ratio: Dry/wet ratio in [0, 1] (0 -> direct only, 1 -> wet only).
        vec: Flat effect-parameter vector in the layout vec2dict expects.

    Returns:
        Three (44100, int16 ndarray) tuples: cos/sin crossfaded mix, direct
        signal, and wet signal.
    """
    device = environ.get("DEVICE", "cpu")
    # Loudness-normalise to -18 LUFS, shape to (1, ch, n), collapse to mono.
    y = convert2float(*y)
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)
    if remove_approx:
        infer_fx = instantiate(rt_config).to(device)
    else:
        infer_fx = instantiate(fx_config).to(device)
    infer_fx.load_state_dict(vec2dict(vec), strict=False)
    # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    infer_fx.eval()
    with torch.no_grad():
        direct, wet = infer_fx(y)
    direct = direct.squeeze(0).T.cpu().numpy()
    wet = wet.squeeze(0).T.cpu().numpy()
    # Crossfade angle: cos/sin weighting keeps combined power constant.
    angle = ratio * math.pi * 0.5
    test_clipping = direct + wet
    # rendered = fx(y).squeeze(0).T.numpy()
    # Scale both signals down if their plain sum would clip.
    # NOTE(review): the sqrt(2) boost below can still exceed 1.0 after this
    # normalisation for mid-range ratios — confirm intended.
    if np.max(np.abs(test_clipping)) > 1:
        scaler = np.max(np.abs(test_clipping))
        # rendered = rendered / scaler
        direct = direct / scaler
        wet = wet / scaler
    rendered = math.sqrt(2) * (math.cos(angle) * direct + math.sin(angle) * wet)
    return (
        (44100, (rendered * 32768).astype(np.int16)),
        (44100, (direct * 32768).astype(np.int16)),
        (
            44100,
            (wet * 32768).astype(np.int16),
        ),
    )
def model2json(fx):
    """Serialise the effect chain into a nested dict for the JSON viewer.

    fx[0:7] are the direct-path effects (EQ bands + compressor); fx[7] holds
    the panner and the spatial sends (delay at effects[0], FDN at effects[1]).
    """
    fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    direct = {name: module.toJSON() for name, module in zip(fx_names, fx)}
    direct["Panner"] = fx[7].pan.toJSON()
    dly = fx[7].effects[0]
    fdn = fx[7].effects[1]
    tone_peq = {name: band.toJSON() for name, band in zip(fx_names[:4], fdn.eq)}
    sends = {
        "DLY": dly.toJSON() | {"LP": dly.eq.toJSON()},
        "FDN": fdn.toJSON() | {"Tone correction PEQ": tone_peq},
        "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
    }
    return {
        "Direct": direct,
        "Sends": sends,
    }
def plot_eq(fx):
    """Plot the summed and per-band magnitude responses of the six EQ filters."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    w, band_mags = get_log_mags_from_eq(fx[:6])
    # Overall response in black, individual bands faint with shaded area.
    ax.plot(w, sum(band_mags), color="black", linestyle="-")
    for band in band_mags:
        ax.plot(w, band, "k-", alpha=0.3)
        ax.fill_between(w, band, 0, facecolor="gray", edgecolor="none", alpha=0.1)
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig
def plot_comp(fx):
    """Plot the static input/output level curve of the compressor/expander fx[6]."""
    fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
    params = fx[6].params
    cmp_th = params.cmp_th.item()
    exp_th = params.exp_th.item()
    cmp_ratio = params.cmp_ratio.item()
    exp_ratio = params.exp_ratio.item()
    make_up = params.make_up.item()
    level_in = np.linspace(-80, 0, 100)
    # Above the compressor threshold, the slope drops to 1/ratio.
    after_comp = np.where(
        level_in > cmp_th,
        level_in - (level_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio,
        level_in,
    )
    # Below the expander threshold, levels are pushed further down; then add make-up gain.
    level_out = (
        np.where(
            after_comp < exp_th,
            after_comp - (exp_th - after_comp) / exp_ratio,
            after_comp,
        )
        + make_up
    )
    ax.plot(level_in, level_out, c="black", linestyle="-")
    # Unity-gain reference line.
    ax.plot(level_in, level_in, c="r", alpha=0.5)
    ax.set_xlim(-80, 0)
    ax.set_ylim(-80, 0)
    ax.set_xlabel("Input Level (dB)")
    ax.set_ylabel("Output Level (dB)")
    ax.grid()
    return fig
def plot_delay(fx):
    """Plot the delay send's response plus fading curves for feedback repeats."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    dly = fx[7].effects[0]
    w, band_mags = get_log_mags_from_eq([dly.eq])
    base_mag = sum(band_mags)
    gain_db = dly.params.gain.log10().item() * 20
    feedback_db = dly.params.feedback.log10().item() * 20
    delay_sec = dly.params.delay.item() / 1000
    ax.plot(w, base_mag + gain_db, color="black", linestyle="-")
    # Each repeat passes through the EQ once more and is attenuated by the
    # feedback gain; alpha fades faster for longer delay times.
    for rep in range(1, 10):
        ax.plot(
            w,
            base_mag * (rep + 1) + feedback_db * rep + gain_db,
            c="black",
            alpha=max(0, (10 - rep * delay_sec * 4) / 10),
            linestyle="-",
        )
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-80, 0)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig
def plot_reverb(fx):
    """Plot the magnitude response of the FDN reverb's tone-correction EQ."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    w, band_mags = get_log_mags_from_eq(fdn.eq)
    # Broadband gain contributed by the input (b) and output (c) vectors, in dB.
    gain_db = torch.log10(fdn.params.c.norm() * fdn.params.b.norm()).item() * 20
    ax.plot(w, sum(band_mags) + gain_db, color="black", linestyle="-")
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig
def plot_t60(fx):
    """Plot the FDN reverb's frequency-dependent decay time (T60, seconds)."""
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    gamma = fdn.params.gamma.squeeze().numpy()
    shortest_delay = np.min(fdn.delays.numpy())
    freqs = np.linspace(0, 22050, gamma.size)
    # dB of decay per sample (epsilon guards log10(0)), converted to the time
    # needed to decay 60 dB at 44.1 kHz.
    decay_db_per_sample = 20 * np.log10(gamma + 1e-10) / shortest_delay
    t60 = -60 / decay_db_per_sample / 44100
    ax.plot(freqs, t60, color="black", linestyle="-")
    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(0, 9)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("T60 (s)")
    ax.grid()
    return fig
def vec2fx(x):
    """Return a fresh copy of the global effect chain loaded from flat vector `x`."""
    fx = deepcopy(global_fx)
    fx.load_state_dict(vec2dict(x), strict=False)
    # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    return fx
with gr.Blocks() as demo:
    # Holds the current flat effect-parameter vector between callbacks.
    fx_params = gr.State(internal_mean)
    # fx = vec2fx(fx_params.value)
    # sr, y = read(EXAMPLE_PATH)
    default_audio_block = partial(gr.Audio, type="numpy", loop=True)
    gr.Markdown(
        title_md,
        elem_id="title",
    )
    with gr.Row():
        gr.Markdown(
            description_md,
            elem_id="description",
        )
        gr.Image("overview.png", elem_id="diagram", height=500)
    # Left column: inputs + run button; right column: rendered outputs.
    with gr.Row():
        with gr.Column():
            audio_input = default_audio_block(
                sources="upload",
                label="Input Audio",
                # value=(sr, y)
            )
            audio_reference = default_audio_block(
                sources="upload",
                label="Reference Audio",
            )
            with gr.Row():
                method_dropdown = gr.Dropdown(
                    ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
                    value="ST-ITO",
                    label=f"Style Transfer Method",
                    interactive=True,
                )
                process_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
        with gr.Column():
            audio_output = default_audio_block(label="Output Audio", interactive=False)
            dry_wet_ratio = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="Dry/Wet Ratio",
                interactive=True,
            )
            direct_output = default_audio_block(label="Direct Audio", interactive=False)
            wet_output = default_audio_block(label="Wet Audio", interactive=False)
    # General controls (left) and ST-ITO-specific settings (right).
    with gr.Row():
        with gr.Column():
            _ = gr.Markdown("## Control Parameters")
            with gr.Row():
                dataset_dropdown = gr.Dropdown(
                    [("Internal", "internal"), ("MedleyDB", "medleydb")],
                    label="Prior Distribution (Dataset)",
                    info="This parameter has no effect when using the ST-ITO and Regression methods.",
                    value="internal",
                    interactive=True,
                )
                embedding_dropdown = gr.Dropdown(
                    [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
                    label="Embedding Model",
                    info="This parameter has no effect when using the Mean and Regression methods.",
                    value="afx-rep",
                    interactive=True,
                )
                # with gr.Column():
                remove_approx_checkbox = gr.Checkbox(
                    label="Use Real-time Effects",
                    info="Use real-time delay and reverb effects instead of approximated ones.",
                    value=False,
                    interactive=True,
                )
                mid_side_checkbox = gr.Checkbox(
                    label="Use Mid-Side Processing",
                    info="This option has no effect when using the Mean and Regression methods.",
                    value=True,
                    interactive=True,
                )
        with gr.Column():
            _ = gr.Markdown("## Parameters for ST-ITO Method")
            with gr.Row():
                optimisation_steps = gr.Slider(
                    minimum=1,
                    maximum=2000,
                    value=100,
                    step=1,
                    label="Number of Optimisation Steps",
                    interactive=True,
                )
                prior_weight = gr.Dropdown(
                    [
                        ("0", 0.0),
                        ("0.001", 0.001),
                        ("0.01", 0.01),
                        ("0.1", 0.1),
                        ("1", 1.0),
                    ],
                    info="Weight of the prior distribution in the loss function. A higher value means the model will try to stay closer to the prior distribution.",
                    value=0.01,
                    label="Prior Weight",
                    interactive=True,
                )
                optimiser_dropdown = gr.Dropdown(
                    [
                        "Adadelta",
                        "Adafactor",
                        "Adagrad",
                        "Adam",
                        "AdamW",
                        "Adamax",
                        "RMSprop",
                        "ASGD",
                        "NAdam",
                        "RAdam",
                        "SGD",
                    ],
                    value="Adam",
                    label="Optimiser",
                    interactive=True,
                )
                lr_slider = gr.Dropdown(
                    [("0.0001", 1e-4), ("0.001", 1e-3), ("0.01", 1e-2), ("0.1", 1e-1)],
                    value=1e-2,
                    label="Learning Rate",
                    interactive=True,
                )
    # Plots/JSON are seeded from the global (internal-mean) effect chain.
    _ = gr.Markdown("## Effect Parameters Visualisation")
    with gr.Row():
        peq_plot = gr.Plot(
            plot_eq(global_fx), label="PEQ Frequency Response", elem_id="peq-plot"
        )
        comp_plot = gr.Plot(
            plot_comp(global_fx), label="Compressor Curve", elem_id="comp-plot"
        )
    with gr.Row():
        delay_plot = gr.Plot(
            plot_delay(global_fx),
            label="Delay Frequency Response",
            elem_id="delay-plot",
        )
        reverb_plot = gr.Plot(
            plot_reverb(global_fx),
            label="Tone Correction PEQ",
            elem_id="reverb-plot",
            min_width=160,
        )
        t60_plot = gr.Plot(
            plot_t60(global_fx), label="Decay Time", elem_id="t60-plot", min_width=160
        )
    _ = gr.Markdown("## Effect Settings JSON")
    with gr.Row():
        json_output = gr.JSON(
            model2json(global_fx), label="Effect Settings", max_height=800, open=True
        )
    # Run pipeline: inference -> build fx + render audio -> plots/JSON ->
    # reorder so the tuple matches `outputs` (and move the vector to CPU).
    process_button.click(
        spaces.GPU(duration=60)(
            chain_functions(
                # (audio, approx, ratio, <inference args>) -> keep render args,
                # run style transfer to get the parameter vector.
                lambda audio, approx, ratio, *args: (
                    audio,
                    approx,
                    ratio,
                    inference(audio, *args),
                ),
                # -> (fx chain, mix, direct, wet, vec)
                lambda audio, approx, ratio, vec: (
                    vec2fx(vec),
                    *render(audio, approx, ratio, vec),
                    vec,
                ),
                # Append every visualisation computed from the fx chain.
                lambda fx, *args: (
                    *args,
                    *map(
                        lambda f: f(fx),
                        [
                            plot_eq,
                            plot_comp,
                            plot_delay,
                            plot_reverb,
                            plot_t60,
                            model2json,
                        ],
                    ),
                ),
                lambda out, dir, wet, vec, *args: (
                    out,
                    dir,
                    wet,
                    vec.cpu(),
                    *args,
                ),
            )
        ),
        inputs=[
            audio_input,
            remove_approx_checkbox,
            dry_wet_ratio,
            audio_reference,
            method_dropdown,
            dataset_dropdown,
            embedding_dropdown,
            mid_side_checkbox,
            optimisation_steps,
            prior_weight,
            optimiser_dropdown,
            lr_slider,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
            fx_params,
            peq_plot,
            comp_plot,
            delay_plot,
            reverb_plot,
            t60_plot,
            json_output,
        ],
    )
    # Re-render with the other effect backend; no-op when no input audio yet.
    remove_approx_checkbox.change(
        spaces.GPU(duration=10)(
            lambda in_audio, out_audio, di, wet, *args: (
                (out_audio, di, wet) if in_audio is None else render(in_audio, *args)
            )
        ),
        inputs=[
            audio_input,
            audio_output,
            direct_output,
            wet_output,
            remove_approx_checkbox,
            dry_wet_ratio,
            fx_params,
        ],
        outputs=[
            audio_output,
            direct_output,
            wet_output,
        ],
    )
    # Remix direct/wet int16 audio client-side-style: back to float, cos/sin
    # crossfade, back to int16 — without re-running the effect chain.
    dry_wet_ratio.input(
        chain_functions(
            lambda _, *args: (_, *map(lambda x: x[1] / 32768, args)),
            lambda ratio, d, w: math.sqrt(2)
            * (
                math.cos(ratio * math.pi * 0.5) * d
                + math.sin(ratio * math.pi * 0.5) * w
            ),
            lambda x: (44100, (x * 32768).astype(np.int16)),
        ),
        inputs=[dry_wet_ratio, direct_output, wet_output],
        outputs=[audio_output],
    )
demo.launch()