Spaces:
Running on Zero
Running on Zero
Commit ·
48e1ce4
1
Parent(s): 0c4dc06
feat: add regression model and update inference methods
Browse files- app.py +73 -22
- ito.py +4 -20
- modules/model.py +33 -0
app.py
CHANGED
|
@@ -19,7 +19,12 @@ from typing import Tuple, List, Optional, Union, Callable
|
|
| 19 |
|
| 20 |
from modules.utils import vec2statedict, get_chunks
|
| 21 |
from modules.fx import clip_delay_eq_Q, hadamard
|
| 22 |
-
from utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from ito import find_closest_training_sample, one_evaluation
|
| 24 |
from st_ito.utils import (
|
| 25 |
load_param_model,
|
|
@@ -62,6 +67,7 @@ PRESET_PATH = {
|
|
| 62 |
"internal": Path("presets/internal/"),
|
| 63 |
"medleydb": Path("presets/medleydb/"),
|
| 64 |
}
|
|
|
|
| 65 |
|
| 66 |
PCA_PARAM_FILE = "gaussian.npz"
|
| 67 |
INFO_PATH = "info.json"
|
|
@@ -100,6 +106,7 @@ def logp_x(mu, cov, cov_logdet, x):
|
|
| 100 |
diff = x - mu
|
| 101 |
b = torch.linalg.solve(cov, diff)
|
| 102 |
norm = diff @ b
|
|
|
|
| 103 |
return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
|
| 104 |
|
| 105 |
|
|
@@ -168,6 +175,32 @@ def get_embedding_model(embedding: str) -> Callable:
|
|
| 168 |
return two_chs_emb_fn
|
| 169 |
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
|
| 172 |
if sr != 44100:
|
| 173 |
x = resample(x, sr, 44100)
|
|
@@ -200,6 +233,10 @@ def inference(
|
|
| 200 |
ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
|
| 201 |
ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
y = convert2float(*input_audio)
|
| 204 |
loudness = meter.integrated_loudness(y)
|
| 205 |
y = pyln.normalize.loudness(y, loudness, -18.0)
|
|
@@ -219,7 +256,13 @@ def inference(
|
|
| 219 |
match method:
|
| 220 |
case "Nearest Neighbour":
|
| 221 |
vec = find_closest_training_sample(
|
| 222 |
-
fx,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
case "ST-ITO":
|
|
@@ -228,7 +271,7 @@ def inference(
|
|
| 228 |
two_chs_emb_fn,
|
| 229 |
to_fx_state_dict,
|
| 230 |
partial(logp_x, *[x.to(device) for x in gaussian_params_dict[dataset]]),
|
| 231 |
-
|
| 232 |
ref,
|
| 233 |
y,
|
| 234 |
optimiser_type=optimiser,
|
|
@@ -438,13 +481,7 @@ with gr.Blocks() as demo:
|
|
| 438 |
# fx = vec2fx(fx_params.value)
|
| 439 |
# sr, y = read(EXAMPLE_PATH)
|
| 440 |
|
| 441 |
-
default_pc_slider = partial(
|
| 442 |
-
gr.Slider, minimum=SLIDER_MIN, maximum=SLIDER_MAX, interactive=True, value=0
|
| 443 |
-
)
|
| 444 |
default_audio_block = partial(gr.Audio, type="numpy", loop=True)
|
| 445 |
-
default_freq_slider = partial(gr.Slider, label="Frequency (Hz)", interactive=True)
|
| 446 |
-
default_gain_slider = partial(gr.Slider, label="Gain (dB)", interactive=True)
|
| 447 |
-
default_q_slider = partial(gr.Slider, label="Q", interactive=True)
|
| 448 |
|
| 449 |
gr.Markdown(
|
| 450 |
title_md,
|
|
@@ -468,6 +505,15 @@ with gr.Blocks() as demo:
|
|
| 468 |
sources="upload",
|
| 469 |
label="Reference Audio",
|
| 470 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
with gr.Column():
|
| 473 |
audio_output = default_audio_block(label="Output Audio", interactive=False)
|
|
@@ -481,18 +527,8 @@ with gr.Blocks() as demo:
|
|
| 481 |
direct_output = default_audio_block(label="Direct Audio", interactive=False)
|
| 482 |
wet_output = default_audio_block(label="Wet Audio", interactive=False)
|
| 483 |
|
|
|
|
| 484 |
with gr.Row():
|
| 485 |
-
process_button = gr.Button("Run", elem_id="render-button", variant="primary")
|
| 486 |
-
reset_button = gr.Button("Reset", elem_id="reset-button")
|
| 487 |
-
|
| 488 |
-
_ = gr.Markdown("## Common Parameters")
|
| 489 |
-
with gr.Row():
|
| 490 |
-
method_dropdown = gr.Dropdown(
|
| 491 |
-
["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
|
| 492 |
-
value="ST-ITO",
|
| 493 |
-
label=f"Style Transfer Method",
|
| 494 |
-
interactive=True,
|
| 495 |
-
)
|
| 496 |
dataset_dropdown = gr.Dropdown(
|
| 497 |
[("Internal", "internal"), ("MedleyDB", "medleydb")],
|
| 498 |
label="Prior Distribution",
|
|
@@ -503,7 +539,7 @@ with gr.Blocks() as demo:
|
|
| 503 |
embedding_dropdown = gr.Dropdown(
|
| 504 |
[("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
|
| 505 |
label="Embedding Model",
|
| 506 |
-
info="This parameter
|
| 507 |
value="afx-rep",
|
| 508 |
interactive=True,
|
| 509 |
)
|
|
@@ -516,6 +552,7 @@ with gr.Blocks() as demo:
|
|
| 516 |
)
|
| 517 |
mid_side_checkbox = gr.Checkbox(
|
| 518 |
label="Use Mid-Side Processing",
|
|
|
|
| 519 |
value=True,
|
| 520 |
interactive=True,
|
| 521 |
)
|
|
@@ -524,7 +561,7 @@ with gr.Blocks() as demo:
|
|
| 524 |
with gr.Row():
|
| 525 |
optimisation_steps = gr.Slider(
|
| 526 |
minimum=1,
|
| 527 |
-
maximum=
|
| 528 |
value=100,
|
| 529 |
step=1,
|
| 530 |
label="Number of Optimisation Steps",
|
|
@@ -620,5 +657,19 @@ with gr.Blocks() as demo:
|
|
| 620 |
],
|
| 621 |
)
|
| 622 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
demo.launch()
|
|
|
|
| 19 |
|
| 20 |
from modules.utils import vec2statedict, get_chunks
|
| 21 |
from modules.fx import clip_delay_eq_Q, hadamard
|
| 22 |
+
from utils import (
|
| 23 |
+
get_log_mags_from_eq,
|
| 24 |
+
chain_functions,
|
| 25 |
+
remove_window_fn,
|
| 26 |
+
jsonparse2hydra,
|
| 27 |
+
)
|
| 28 |
from ito import find_closest_training_sample, one_evaluation
|
| 29 |
from st_ito.utils import (
|
| 30 |
load_param_model,
|
|
|
|
| 67 |
"internal": Path("presets/internal/"),
|
| 68 |
"medleydb": Path("presets/medleydb/"),
|
| 69 |
}
|
| 70 |
+
CKPT_PATH = Path("reg-ckpts/")
|
| 71 |
|
| 72 |
PCA_PARAM_FILE = "gaussian.npz"
|
| 73 |
INFO_PATH = "info.json"
|
|
|
|
| 106 |
diff = x - mu
|
| 107 |
b = torch.linalg.solve(cov, diff)
|
| 108 |
norm = diff @ b
|
| 109 |
+
assert torch.all(norm >= 0), "Negative norm detected, check covariance matrix."
|
| 110 |
return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
|
| 111 |
|
| 112 |
|
|
|
|
| 175 |
return two_chs_emb_fn
|
| 176 |
|
| 177 |
|
| 178 |
+
def get_regressor() -> Callable:
    """Load the trained regression model and return an inference closure.

    Reads the model config from ``CKPT_PATH/config.yaml``, restores the
    checkpoint with the lowest validation loss, and returns a function that
    maps a wet (processed) audio tensor to de-normalised effect parameters.

    Returns:
        Callable: ``regressor(wet)`` producing predicted parameters on the
        device named in ``DEVICE.txt``.
    """
    with open(CKPT_PATH / "config.yaml") as f:
        config = yaml.safe_load(f)

    model_config = config["model"]

    # Pick the checkpoint with the lowest validation loss; assumes the file
    # stem ends with "...val_loss=<value>" — TODO confirm naming convention.
    checkpoints = (CKPT_PATH / "checkpoints").glob("*val_loss*.ckpt")
    best_checkpoint = min(checkpoints, key=lambda p: float(p.stem.split("=")[-1]))
    ckpt = torch.load(best_checkpoint, map_location="cpu")

    model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
        model_config
    )
    model.load_state_dict(ckpt["state_dict"])

    # .strip() guards against a trailing newline in DEVICE.txt, which torch
    # would reject as an invalid device string.
    device = Path("DEVICE.txt").read_text().strip()
    model = model.to(device)
    model.eval()

    # Load on CPU first so a GPU-saved stats file still works, then move over.
    param_stats = torch.load(CKPT_PATH / "param_stats.pt", map_location="cpu")
    param_mu = param_stats["mu"].float().to(device)
    param_std = param_stats["std"].float().to(device)

    def regressor(wet):
        # The model predicts z-score-normalised parameters; undo the
        # normalisation. no_grad: pure inference, no autograd bookkeeping.
        with torch.no_grad():
            return model(wet, dry=None) * param_std + param_mu

    return regressor
|
| 202 |
+
|
| 203 |
+
|
| 204 |
def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
|
| 205 |
if sr != 44100:
|
| 206 |
x = resample(x, sr, 44100)
|
|
|
|
| 233 |
ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
|
| 234 |
ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
|
| 235 |
|
| 236 |
+
if method == "Regression":
|
| 237 |
+
regressor = get_regressor()
|
| 238 |
+
return regressor(ref).mean(0)
|
| 239 |
+
|
| 240 |
y = convert2float(*input_audio)
|
| 241 |
loudness = meter.integrated_loudness(y)
|
| 242 |
y = pyln.normalize.loudness(y, loudness, -18.0)
|
|
|
|
| 256 |
match method:
|
| 257 |
case "Nearest Neighbour":
|
| 258 |
vec = find_closest_training_sample(
|
| 259 |
+
fx,
|
| 260 |
+
two_chs_emb_fn,
|
| 261 |
+
to_fx_state_dict,
|
| 262 |
+
preset_dict[dataset].to(device),
|
| 263 |
+
ref,
|
| 264 |
+
y,
|
| 265 |
+
progress,
|
| 266 |
)
|
| 267 |
|
| 268 |
case "ST-ITO":
|
|
|
|
| 271 |
two_chs_emb_fn,
|
| 272 |
to_fx_state_dict,
|
| 273 |
partial(logp_x, *[x.to(device) for x in gaussian_params_dict[dataset]]),
|
| 274 |
+
gaussian_params_dict[dataset][0].to(device),
|
| 275 |
ref,
|
| 276 |
y,
|
| 277 |
optimiser_type=optimiser,
|
|
|
|
| 481 |
# fx = vec2fx(fx_params.value)
|
| 482 |
# sr, y = read(EXAMPLE_PATH)
|
| 483 |
|
|
|
|
|
|
|
|
|
|
| 484 |
default_audio_block = partial(gr.Audio, type="numpy", loop=True)
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
gr.Markdown(
|
| 487 |
title_md,
|
|
|
|
| 505 |
sources="upload",
|
| 506 |
label="Reference Audio",
|
| 507 |
)
|
| 508 |
+
method_dropdown = gr.Dropdown(
|
| 509 |
+
["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
|
| 510 |
+
value="ST-ITO",
|
| 511 |
+
label=f"Style Transfer Method",
|
| 512 |
+
interactive=True,
|
| 513 |
+
)
|
| 514 |
+
process_button = gr.Button(
|
| 515 |
+
"Run", elem_id="render-button", variant="primary"
|
| 516 |
+
)
|
| 517 |
|
| 518 |
with gr.Column():
|
| 519 |
audio_output = default_audio_block(label="Output Audio", interactive=False)
|
|
|
|
| 527 |
direct_output = default_audio_block(label="Direct Audio", interactive=False)
|
| 528 |
wet_output = default_audio_block(label="Wet Audio", interactive=False)
|
| 529 |
|
| 530 |
+
_ = gr.Markdown("## Control Parameters")
|
| 531 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
dataset_dropdown = gr.Dropdown(
|
| 533 |
[("Internal", "internal"), ("MedleyDB", "medleydb")],
|
| 534 |
label="Prior Distribution",
|
|
|
|
| 539 |
embedding_dropdown = gr.Dropdown(
|
| 540 |
[("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
|
| 541 |
label="Embedding Model",
|
| 542 |
+
info="This parameter has no effect when using the Mean and Regression methods.",
|
| 543 |
value="afx-rep",
|
| 544 |
interactive=True,
|
| 545 |
)
|
|
|
|
| 552 |
)
|
| 553 |
mid_side_checkbox = gr.Checkbox(
|
| 554 |
label="Use Mid-Side Processing",
|
| 555 |
+
info="This option has no effect when using the Mean and Regression methods.",
|
| 556 |
value=True,
|
| 557 |
interactive=True,
|
| 558 |
)
|
|
|
|
| 561 |
with gr.Row():
|
| 562 |
optimisation_steps = gr.Slider(
|
| 563 |
minimum=1,
|
| 564 |
+
maximum=2000,
|
| 565 |
value=100,
|
| 566 |
step=1,
|
| 567 |
label="Number of Optimisation Steps",
|
|
|
|
| 657 |
],
|
| 658 |
)
|
| 659 |
|
| 660 |
+
dry_wet_ratio.input(
    # Equal-power crossfade between the direct (dry) and processed (wet)
    # outputs whenever the dry/wet slider moves.
    chain_functions(
        # Gradio delivers (sr, int16 samples); convert samples to float in [-1, 1).
        lambda ratio, *audios: (ratio, *map(lambda a: a[1] / 32768, audios)),
        # cos/sin weights keep perceived loudness roughly constant across the
        # slider range; sqrt(2) is make-up gain at the midpoint.
        lambda ratio, d, w: math.sqrt(2)
        * (
            math.cos(ratio * math.pi * 0.5) * d
            + math.sin(ratio * math.pi * 0.5) * w
        ),
        # Clip before the int16 cast: the sqrt(2) make-up gain can push the mix
        # past full scale, and a bare cast would wrap around (harsh distortion)
        # instead of saturating.
        lambda x: (44100, (np.clip(x, -1.0, 32767 / 32768) * 32768).astype(np.int16)),
    ),
    inputs=[dry_wet_ratio, direct_output, wet_output],
    outputs=[audio_output],
)
|
| 673 |
+
|
| 674 |
|
| 675 |
demo.launch()
|
ito.py
CHANGED
|
@@ -26,25 +26,6 @@ from st_ito.utils import (
|
|
| 26 |
from utils import remove_window_fn, jsonparse2hydra
|
| 27 |
|
| 28 |
|
| 29 |
-
def get_reference_query_chunks(dry_audio, wet_audio, chunk_size, sr):
|
| 30 |
-
dry = dry_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
|
| 31 |
-
wet = wet_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
|
| 32 |
-
|
| 33 |
-
max_filtered = F.max_pool1d(wet.mean(1).abs(), int(sr * 0.05), stride=1)
|
| 34 |
-
active_mask = torch.quantile(max_filtered, 0.5, dim=1) > 0.001 # -60 dB
|
| 35 |
-
if not active_mask.any():
|
| 36 |
-
raise ValueError("No active frames")
|
| 37 |
-
elif active_mask.count_nonzero() < 2:
|
| 38 |
-
raise ValueError("Too few active frames")
|
| 39 |
-
|
| 40 |
-
dry = dry[active_mask]
|
| 41 |
-
wet = wet[active_mask]
|
| 42 |
-
|
| 43 |
-
ref_audio = wet[::2].contiguous()
|
| 44 |
-
raw_audio = dry[1::2].contiguous()
|
| 45 |
-
return ref_audio, raw_audio
|
| 46 |
-
|
| 47 |
-
|
| 48 |
def logp_y_given_x(y, mu, std):
|
| 49 |
cos_dist = torch.arccos(y @ mu)
|
| 50 |
return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
|
|
@@ -130,6 +111,7 @@ def find_closest_training_sample(
|
|
| 130 |
training_samples: torch.Tensor,
|
| 131 |
ref_audio: torch.Tensor,
|
| 132 |
raw_audio: torch.Tensor,
|
|
|
|
| 133 |
) -> torch.Tensor:
|
| 134 |
|
| 135 |
peak_scaler = 1 / ref_audio.abs().max()
|
|
@@ -167,7 +149,9 @@ def find_closest_training_sample(
|
|
| 167 |
)
|
| 168 |
|
| 169 |
best_logp, best_param = reduce(
|
| 170 |
-
reduce_closure,
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
print(f"Best log-likelihood: {best_logp}")
|
| 173 |
return best_param
|
|
|
|
| 26 |
from utils import remove_window_fn, jsonparse2hydra
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def logp_y_given_x(y, mu, std):
    """Gaussian log-likelihood of the arc distance between embeddings.

    Treats the angle ``arccos(y @ mu)`` between the (assumed unit-norm —
    TODO confirm at call sites) vectors as zero-mean Gaussian with standard
    deviation ``std``.

    Args:
        y: embedding tensor.
        mu: reference embedding tensor.
        std: scalar tensor, standard deviation of the angular distance.

    Returns:
        Tensor of log-likelihood values.
    """
    # Clamp guards against |y @ mu| drifting slightly above 1 from
    # floating-point error, which would make arccos return NaN.
    cos_sim = (y @ mu).clamp(-1.0, 1.0)
    cos_dist = torch.arccos(cos_sim)
    return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
|
|
|
|
| 111 |
training_samples: torch.Tensor,
|
| 112 |
ref_audio: torch.Tensor,
|
| 113 |
raw_audio: torch.Tensor,
|
| 114 |
+
progress,
|
| 115 |
) -> torch.Tensor:
|
| 116 |
|
| 117 |
peak_scaler = 1 / ref_audio.abs().max()
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
best_logp, best_param = reduce(
|
| 152 |
+
reduce_closure,
|
| 153 |
+
progress.tqdm(training_samples.unbind(0)),
|
| 154 |
+
(-float("inf"), torch.tensor([])),
|
| 155 |
)
|
| 156 |
print(f"Best log-likelihood: {best_logp}")
|
| 157 |
return best_param
|
modules/model.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from functools import partial, reduce
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
from torchaudio.transforms import MelSpectrogram, MFCC
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LogMelSpectrogram(MelSpectrogram):
    """Mel spectrogram on a log scale.

    Identical to ``torchaudio.transforms.MelSpectrogram`` except the output is
    passed through ``log``; a small epsilon avoids ``log(0)`` on silent frames.
    """

    def forward(self, waveform):
        mel = super().forward(waveform)
        return torch.log(mel + 1e-8)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LogMFCC(MFCC):
    """MFCC transform that always operates on log-mel spectrograms.

    Thin wrapper around ``torchaudio.transforms.MFCC`` hard-wiring
    ``log_mels=True``; every other argument is forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, log_mels=True, **kwargs)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LightningSequential(nn.Sequential):
    """``nn.Sequential`` variant whose layers may pass along multiple values.

    Takes the layer list as a single argument. During ``forward``, whenever a
    layer returns a tuple it is unpacked into the next layer's positional
    arguments; a single value is passed through as-is.
    """

    def __init__(self, modules: List[nn.Module]):
        super().__init__(*modules)

    def forward(self, *args):
        # Start from the full argument tuple so the first layer receives all
        # positional inputs.
        out = args
        for layer in self:
            out = layer(*out) if isinstance(out, tuple) else layer(out)
        return out
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ResidualWrapper(nn.Module):
    """Skip connection around an arbitrary module: ``forward(x) = x + m(x)``."""

    def __init__(self, m: nn.Module):
        super().__init__()
        # Wrapped module; registered as a submodule so its parameters train.
        self.m = m

    def forward(self, x):
        residual = self.m(x)
        return x + residual
|