yoyolicoris committed on
Commit
48e1ce4
·
1 Parent(s): 0c4dc06

feat: add regression model and update inference methods

Browse files
Files changed (3) hide show
  1. app.py +73 -22
  2. ito.py +4 -20
  3. modules/model.py +33 -0
app.py CHANGED
@@ -19,7 +19,12 @@ from typing import Tuple, List, Optional, Union, Callable
19
 
20
  from modules.utils import vec2statedict, get_chunks
21
  from modules.fx import clip_delay_eq_Q, hadamard
22
- from utils import get_log_mags_from_eq, chain_functions
 
 
 
 
 
23
  from ito import find_closest_training_sample, one_evaluation
24
  from st_ito.utils import (
25
  load_param_model,
@@ -62,6 +67,7 @@ PRESET_PATH = {
62
  "internal": Path("presets/internal/"),
63
  "medleydb": Path("presets/medleydb/"),
64
  }
 
65
 
66
  PCA_PARAM_FILE = "gaussian.npz"
67
  INFO_PATH = "info.json"
@@ -100,6 +106,7 @@ def logp_x(mu, cov, cov_logdet, x):
100
  diff = x - mu
101
  b = torch.linalg.solve(cov, diff)
102
  norm = diff @ b
 
103
  return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
104
 
105
 
@@ -168,6 +175,32 @@ def get_embedding_model(embedding: str) -> Callable:
168
  return two_chs_emb_fn
169
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
172
  if sr != 44100:
173
  x = resample(x, sr, 44100)
@@ -200,6 +233,10 @@ def inference(
200
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
201
  ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
202
 
 
 
 
 
203
  y = convert2float(*input_audio)
204
  loudness = meter.integrated_loudness(y)
205
  y = pyln.normalize.loudness(y, loudness, -18.0)
@@ -219,7 +256,13 @@ def inference(
219
  match method:
220
  case "Nearest Neighbour":
221
  vec = find_closest_training_sample(
222
- fx, two_chs_emb_fn, to_fx_state_dict, preset_dict[dataset], ref, y
 
 
 
 
 
 
223
  )
224
 
225
  case "ST-ITO":
@@ -228,7 +271,7 @@ def inference(
228
  two_chs_emb_fn,
229
  to_fx_state_dict,
230
  partial(logp_x, *[x.to(device) for x in gaussian_params_dict[dataset]]),
231
- internal_mean.to(device),
232
  ref,
233
  y,
234
  optimiser_type=optimiser,
@@ -438,13 +481,7 @@ with gr.Blocks() as demo:
438
  # fx = vec2fx(fx_params.value)
439
  # sr, y = read(EXAMPLE_PATH)
440
 
441
- default_pc_slider = partial(
442
- gr.Slider, minimum=SLIDER_MIN, maximum=SLIDER_MAX, interactive=True, value=0
443
- )
444
  default_audio_block = partial(gr.Audio, type="numpy", loop=True)
445
- default_freq_slider = partial(gr.Slider, label="Frequency (Hz)", interactive=True)
446
- default_gain_slider = partial(gr.Slider, label="Gain (dB)", interactive=True)
447
- default_q_slider = partial(gr.Slider, label="Q", interactive=True)
448
 
449
  gr.Markdown(
450
  title_md,
@@ -468,6 +505,15 @@ with gr.Blocks() as demo:
468
  sources="upload",
469
  label="Reference Audio",
470
  )
 
 
 
 
 
 
 
 
 
471
 
472
  with gr.Column():
473
  audio_output = default_audio_block(label="Output Audio", interactive=False)
@@ -481,18 +527,8 @@ with gr.Blocks() as demo:
481
  direct_output = default_audio_block(label="Direct Audio", interactive=False)
482
  wet_output = default_audio_block(label="Wet Audio", interactive=False)
483
 
 
484
  with gr.Row():
485
- process_button = gr.Button("Run", elem_id="render-button", variant="primary")
486
- reset_button = gr.Button("Reset", elem_id="reset-button")
487
-
488
- _ = gr.Markdown("## Common Parameters")
489
- with gr.Row():
490
- method_dropdown = gr.Dropdown(
491
- ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
492
- value="ST-ITO",
493
- label=f"Style Transfer Method",
494
- interactive=True,
495
- )
496
  dataset_dropdown = gr.Dropdown(
497
  [("Internal", "internal"), ("MedleyDB", "medleydb")],
498
  label="Prior Distribution",
@@ -503,7 +539,7 @@ with gr.Blocks() as demo:
503
  embedding_dropdown = gr.Dropdown(
504
  [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
505
  label="Embedding Model",
506
- info="This parameter is used in the Nearest Neighbour and ST-ITO methods.",
507
  value="afx-rep",
508
  interactive=True,
509
  )
@@ -516,6 +552,7 @@ with gr.Blocks() as demo:
516
  )
517
  mid_side_checkbox = gr.Checkbox(
518
  label="Use Mid-Side Processing",
 
519
  value=True,
520
  interactive=True,
521
  )
@@ -524,7 +561,7 @@ with gr.Blocks() as demo:
524
  with gr.Row():
525
  optimisation_steps = gr.Slider(
526
  minimum=1,
527
- maximum=100,
528
  value=100,
529
  step=1,
530
  label="Number of Optimisation Steps",
@@ -620,5 +657,19 @@ with gr.Blocks() as demo:
620
  ],
621
  )
622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
  demo.launch()
 
19
 
20
  from modules.utils import vec2statedict, get_chunks
21
  from modules.fx import clip_delay_eq_Q, hadamard
22
+ from utils import (
23
+ get_log_mags_from_eq,
24
+ chain_functions,
25
+ remove_window_fn,
26
+ jsonparse2hydra,
27
+ )
28
  from ito import find_closest_training_sample, one_evaluation
29
  from st_ito.utils import (
30
  load_param_model,
 
67
  "internal": Path("presets/internal/"),
68
  "medleydb": Path("presets/medleydb/"),
69
  }
70
+ CKPT_PATH = Path("reg-ckpts/")
71
 
72
  PCA_PARAM_FILE = "gaussian.npz"
73
  INFO_PATH = "info.json"
 
106
  diff = x - mu
107
  b = torch.linalg.solve(cov, diff)
108
  norm = diff @ b
109
+ assert torch.all(norm >= 0), "Negative norm detected, check covariance matrix."
110
  return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
111
 
112
 
 
175
  return two_chs_emb_fn
176
 
177
 
178
+ def get_regressor() -> Callable:
179
+ with open(CKPT_PATH / "config.yaml") as f:
180
+ config = yaml.safe_load(f)
181
+
182
+ model_config = config["model"]
183
+
184
+ checkpoints = (CKPT_PATH / "checkpoints").glob("*val_loss*.ckpt")
185
+ lowest_checkpoint = min(checkpoints, key=lambda x: float(x.stem.split("=")[-1]))
186
+ last_ckpt = torch.load(lowest_checkpoint, map_location="cpu")
187
+ model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
188
+ model_config
189
+ )
190
+ model.load_state_dict(last_ckpt["state_dict"])
191
+
192
+ device = Path("DEVICE.txt").read_text()
193
+ model = model.to(device)
194
+ model.eval()
195
+ param_stats = torch.load(CKPT_PATH / "param_stats.pt")
196
+ param_mu, param_std = (
197
+ param_stats["mu"].float().to(device),
198
+ param_stats["std"].float().to(device),
199
+ )
200
+ regressor = lambda wet: model(wet, dry=None) * param_std + param_mu
201
+ return regressor
202
+
203
+
204
  def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
205
  if sr != 44100:
206
  x = resample(x, sr, 44100)
 
233
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
234
  ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
235
 
236
+ if method == "Regression":
237
+ regressor = get_regressor()
238
+ return regressor(ref).mean(0)
239
+
240
  y = convert2float(*input_audio)
241
  loudness = meter.integrated_loudness(y)
242
  y = pyln.normalize.loudness(y, loudness, -18.0)
 
256
  match method:
257
  case "Nearest Neighbour":
258
  vec = find_closest_training_sample(
259
+ fx,
260
+ two_chs_emb_fn,
261
+ to_fx_state_dict,
262
+ preset_dict[dataset].to(device),
263
+ ref,
264
+ y,
265
+ progress,
266
  )
267
 
268
  case "ST-ITO":
 
271
  two_chs_emb_fn,
272
  to_fx_state_dict,
273
  partial(logp_x, *[x.to(device) for x in gaussian_params_dict[dataset]]),
274
+ gaussian_params_dict[dataset][0].to(device),
275
  ref,
276
  y,
277
  optimiser_type=optimiser,
 
481
  # fx = vec2fx(fx_params.value)
482
  # sr, y = read(EXAMPLE_PATH)
483
 
 
 
 
484
  default_audio_block = partial(gr.Audio, type="numpy", loop=True)
 
 
 
485
 
486
  gr.Markdown(
487
  title_md,
 
505
  sources="upload",
506
  label="Reference Audio",
507
  )
508
+ method_dropdown = gr.Dropdown(
509
+ ["Mean", "Nearest Neighbour", "ST-ITO", "Regression"],
510
+ value="ST-ITO",
511
+ label=f"Style Transfer Method",
512
+ interactive=True,
513
+ )
514
+ process_button = gr.Button(
515
+ "Run", elem_id="render-button", variant="primary"
516
+ )
517
 
518
  with gr.Column():
519
  audio_output = default_audio_block(label="Output Audio", interactive=False)
 
527
  direct_output = default_audio_block(label="Direct Audio", interactive=False)
528
  wet_output = default_audio_block(label="Wet Audio", interactive=False)
529
 
530
+ _ = gr.Markdown("## Control Parameters")
531
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
532
  dataset_dropdown = gr.Dropdown(
533
  [("Internal", "internal"), ("MedleyDB", "medleydb")],
534
  label="Prior Distribution",
 
539
  embedding_dropdown = gr.Dropdown(
540
  [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
541
  label="Embedding Model",
542
+ info="This parameter has no effect when using the Mean and Regression methods.",
543
  value="afx-rep",
544
  interactive=True,
545
  )
 
552
  )
553
  mid_side_checkbox = gr.Checkbox(
554
  label="Use Mid-Side Processing",
555
+ info="This option has no effect when using the Mean and Regression methods.",
556
  value=True,
557
  interactive=True,
558
  )
 
561
  with gr.Row():
562
  optimisation_steps = gr.Slider(
563
  minimum=1,
564
+ maximum=2000,
565
  value=100,
566
  step=1,
567
  label="Number of Optimisation Steps",
 
657
  ],
658
  )
659
 
660
+ dry_wet_ratio.input(
661
+ chain_functions(
662
+ lambda _, *args: (_, *map(lambda x: x[1] / 32768, args)),
663
+ lambda ratio, d, w: math.sqrt(2)
664
+ * (
665
+ math.cos(ratio * math.pi * 0.5) * d
666
+ + math.sin(ratio * math.pi * 0.5) * w
667
+ ),
668
+ lambda x: (44100, (x * 32768).astype(np.int16)),
669
+ ),
670
+ inputs=[dry_wet_ratio, direct_output, wet_output],
671
+ outputs=[audio_output],
672
+ )
673
+
674
 
675
  demo.launch()
ito.py CHANGED
@@ -26,25 +26,6 @@ from st_ito.utils import (
26
  from utils import remove_window_fn, jsonparse2hydra
27
 
28
 
29
- def get_reference_query_chunks(dry_audio, wet_audio, chunk_size, sr):
30
- dry = dry_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
31
- wet = wet_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
32
-
33
- max_filtered = F.max_pool1d(wet.mean(1).abs(), int(sr * 0.05), stride=1)
34
- active_mask = torch.quantile(max_filtered, 0.5, dim=1) > 0.001 # -60 dB
35
- if not active_mask.any():
36
- raise ValueError("No active frames")
37
- elif active_mask.count_nonzero() < 2:
38
- raise ValueError("Too few active frames")
39
-
40
- dry = dry[active_mask]
41
- wet = wet[active_mask]
42
-
43
- ref_audio = wet[::2].contiguous()
44
- raw_audio = dry[1::2].contiguous()
45
- return ref_audio, raw_audio
46
-
47
-
48
  def logp_y_given_x(y, mu, std):
49
  cos_dist = torch.arccos(y @ mu)
50
  return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
@@ -130,6 +111,7 @@ def find_closest_training_sample(
130
  training_samples: torch.Tensor,
131
  ref_audio: torch.Tensor,
132
  raw_audio: torch.Tensor,
 
133
  ) -> torch.Tensor:
134
 
135
  peak_scaler = 1 / ref_audio.abs().max()
@@ -167,7 +149,9 @@ def find_closest_training_sample(
167
  )
168
 
169
  best_logp, best_param = reduce(
170
- reduce_closure, training_samples.unbind(0), (-float("inf"), torch.tensor([]))
 
 
171
  )
172
  print(f"Best log-likelihood: {best_logp}")
173
  return best_param
 
26
  from utils import remove_window_fn, jsonparse2hydra
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def logp_y_given_x(y, mu, std):
30
  cos_dist = torch.arccos(y @ mu)
31
  return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
 
111
  training_samples: torch.Tensor,
112
  ref_audio: torch.Tensor,
113
  raw_audio: torch.Tensor,
114
+ progress,
115
  ) -> torch.Tensor:
116
 
117
  peak_scaler = 1 / ref_audio.abs().max()
 
149
  )
150
 
151
  best_logp, best_param = reduce(
152
+ reduce_closure,
153
+ progress.tqdm(training_samples.unbind(0)),
154
+ (-float("inf"), torch.tensor([])),
155
  )
156
  print(f"Best log-likelihood: {best_logp}")
157
  return best_param
modules/model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from functools import partial, reduce
5
+ from typing import Optional, List
6
+ from torchaudio.transforms import MelSpectrogram, MFCC
7
+
8
+
9
class LogMelSpectrogram(MelSpectrogram):
    """Mel spectrogram in log scale; a small epsilon guards against log(0)."""

    def forward(self, waveform):
        mel = super().forward(waveform)
        return torch.log(mel + 1e-8)
12
+
13
+
14
class LogMFCC(MFCC):
    """MFCC transform that always computes coefficients from log-mel spectrograms."""

    def __init__(self, *args, **kwargs):
        # Force log_mels=True by overwriting the kwarg rather than passing it
        # positionally alongside **kwargs, which would raise a duplicate
        # keyword TypeError if a caller supplies log_mels explicitly.
        kwargs["log_mels"] = True
        super().__init__(*args, **kwargs)
17
+
18
+
19
class LightningSequential(nn.Sequential):
    """Sequential container whose forward threads tuples through each stage.

    If a stage returns a tuple, it is unpacked as positional arguments for
    the next stage; otherwise the single value is passed on directly.
    """

    def __init__(self, modules: List[nn.Module]):
        super().__init__(*modules)

    def forward(self, *args):
        out = args  # initial *args arrive as a tuple and are unpacked below
        for stage in self:
            if isinstance(out, tuple):
                out = stage(*out)
            else:
                out = stage(out)
        return out
25
+
26
+
27
class ResidualWrapper(nn.Module):
    """Residual (skip) connection around an arbitrary module: x + m(x)."""

    def __init__(self, m: nn.Module):
        super().__init__()
        self.m = m

    def forward(self, x):
        delta = self.m(x)
        return delta + x