yoyolicoris committed on
Commit
40355de
·
1 Parent(s): a25cbf8

update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -30
app.py CHANGED
@@ -15,19 +15,19 @@ from itertools import accumulate
15
  from torchcomp import coef2ms, ms2coef
16
  from copy import deepcopy
17
  from pathlib import Path
18
- from typing import Tuple, List, Optional, Union
19
 
20
  from modules.utils import vec2statedict, get_chunks
21
- from modules.fx import clip_delay_eq_Q
22
- from plot_utils import get_log_mags_from_eq
23
-
24
-
25
- def chain_functions(*functions):
26
- return lambda *initial_args: reduce(
27
- lambda xs, f: f(*xs) if isinstance(xs, tuple) else f(xs),
28
- functions,
29
- initial_args,
30
- )
31
 
32
 
33
  title_md = "# Vocal Effects Style Transfer Demo"
@@ -135,34 +135,65 @@ global_fx = instantiate(fx_config)
135
  # global_fx.eval()
136
  global_fx.load_state_dict(vec2dict(internal_mean), strict=False)
137
 
 
 
 
 
138
 
139
  meter = pyln.Meter(44100)
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def inference(
143
- audio,
 
144
  ratio,
145
  method,
146
  dataset,
147
  embedding,
148
  remove_approx,
 
149
  steps,
150
  prior_weight,
151
  optimiser,
152
  lr,
153
  ):
154
- sr, y = audio
155
- if sr != 44100:
156
- y = resample(y, sr, 44100)
157
- if y.dtype.kind != "f":
158
- y = y / 32768.0
159
 
160
- if y.ndim == 1:
161
- y = y[:, None]
162
  loudness = meter.integrated_loudness(y)
163
  y = pyln.normalize.loudness(y, loudness, -18.0)
164
-
165
  y = torch.from_numpy(y).float().T.unsqueeze(0)
 
 
 
 
 
166
  if y.shape[1] != 1:
167
  y = y.mean(dim=1, keepdim=True)
168
 
@@ -172,8 +203,16 @@ def inference(
172
  match method:
173
  case "Mean":
174
  vec = gaussian_params_dict[dataset][0]
 
 
 
 
 
 
 
 
175
  case _:
176
- vec = internal_mean.clone()
177
 
178
  if remove_approx:
179
  infer_fx = instantiate(rt_config)
@@ -407,8 +446,8 @@ with gr.Blocks() as demo:
407
  wet_output = default_audio_block(label="Wet Audio", interactive=False)
408
 
409
  with gr.Row():
410
- reset_button = gr.Button("Reset", elem_id="reset-button")
411
  render_button = gr.Button("Run", elem_id="render-button", variant="primary")
 
412
 
413
  _ = gr.Markdown("## Common Parameters")
414
  with gr.Row():
@@ -426,18 +465,24 @@ with gr.Blocks() as demo:
426
  interactive=True,
427
  )
428
  embedding_dropdown = gr.Dropdown(
429
- ["AFx-Rep", "MFCC", "MIR Features"],
430
  label="Embedding Model",
431
  info="This parameter is used in the Nearest Neighbour and ST-ITO methods.",
432
- value="AFx-Rep",
433
- interactive=True,
434
- )
435
- remove_approx_checkbox = gr.Checkbox(
436
- label="Use Real-time Effects",
437
- info="Use real-time delay and reverb effects instead of approximated ones.",
438
- value=False,
439
  interactive=True,
440
  )
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  _ = gr.Markdown("## Parameters for ST-ITO Method")
443
  with gr.Row():
@@ -492,11 +537,13 @@ with gr.Blocks() as demo:
492
  ),
493
  inputs=[
494
  audio_input,
 
495
  dry_wet_ratio,
496
  method_dropdown,
497
  dataset_dropdown,
498
  embedding_dropdown,
499
  remove_approx_checkbox,
 
500
  optimisation_steps,
501
  prior_weight,
502
  optimiser_dropdown,
 
15
  from torchcomp import coef2ms, ms2coef
16
  from copy import deepcopy
17
  from pathlib import Path
18
+ from typing import Tuple, List, Optional, Union, Callable
19
 
20
  from modules.utils import vec2statedict, get_chunks
21
+ from modules.fx import clip_delay_eq_Q, hadamard
22
+ from utils import get_log_mags_from_eq, chain_functions
23
+ from ito import find_closest_training_sample
24
+ from st_ito.utils import (
25
+ load_param_model,
26
+ get_param_embeds,
27
+ get_feature_embeds,
28
+ load_mfcc_feature_extractor,
29
+ load_mir_feature_extractor,
30
+ )
31
 
32
 
33
  title_md = "# Vocal Effects Style Transfer Demo"
 
135
  # global_fx.eval()
136
  global_fx.load_state_dict(vec2dict(internal_mean), strict=False)
137
 
138
+ ndim_dict = {k: v.ndim for k, v in global_fx.state_dict().items()}
139
+ to_fx_state_dict = lambda x: {
140
+ k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
141
+ }
142
 
143
  meter = pyln.Meter(44100)
144
 
145
 
146
def get_embedding_model(embedding: str) -> Callable:
    """Build a callable mapping a two-channel audio batch to its embedding.

    Args:
        embedding: encoder key — one of ``"afx-rep"``, ``"mfcc"``, ``"mir"``.

    Returns:
        A single-argument function computing embeddings for an audio tensor.

    Raises:
        ValueError: if *embedding* is not a recognised encoder key.
    """
    if embedding == "afx-rep":
        # Parameter-estimation model; embeds at a fixed 44.1 kHz sample rate.
        model = load_param_model()
        return lambda audio: get_param_embeds(audio, model, 44100)
    if embedding == "mfcc":
        extractor = load_mfcc_feature_extractor()
        return lambda audio: get_feature_embeds(audio, extractor)
    if embedding == "mir":
        extractor = load_mir_feature_extractor()
        return lambda audio: get_feature_embeds(audio, extractor)
    raise ValueError(f"Unknown encoder: {embedding}")
160
+
161
+
162
def convert2float(sr: int, x: np.ndarray) -> np.ndarray:
    """Normalise raw audio samples to float, 44.1 kHz, 2-D (samples, channels).

    Args:
        sr: input sample rate in Hz.
        x: audio samples, 1-D mono or 2-D (samples, channels); integer PCM
           or float.

    Returns:
        Float array at 44.1 kHz with shape (samples, channels).
    """
    # Scale integer PCM to [-1, 1) BEFORE resampling: resamplers return
    # float arrays, so checking the dtype after resampling would silently
    # skip the scaling for non-44.1 kHz integer input.
    if x.dtype.kind != "f":
        # 32768.0 assumes 16-bit PCM — TODO confirm upstream audio format.
        x = x / 32768.0
    if sr != 44100:
        x = resample(x, sr, 44100)
    if x.ndim == 1:
        # Promote mono to a single-channel column for downstream code.
        x = x[:, None]
    return x
170
+
171
+
172
  def inference(
173
+ input_audio,
174
+ ref_audio,
175
  ratio,
176
  method,
177
  dataset,
178
  embedding,
179
  remove_approx,
180
+ mid_side,
181
  steps,
182
  prior_weight,
183
  optimiser,
184
  lr,
185
  ):
186
+ y = convert2float(*input_audio)
187
+ ref = convert2float(*ref_audio)
 
 
 
188
 
 
 
189
  loudness = meter.integrated_loudness(y)
190
  y = pyln.normalize.loudness(y, loudness, -18.0)
 
191
  y = torch.from_numpy(y).float().T.unsqueeze(0)
192
+
193
+ ref_loudness = meter.integrated_loudness(ref)
194
+ ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
195
+ ref = torch.from_numpy(ref).float().T.unsqueeze(0)
196
+
197
  if y.shape[1] != 1:
198
  y = y.mean(dim=1, keepdim=True)
199
 
 
203
  match method:
204
  case "Mean":
205
  vec = gaussian_params_dict[dataset][0]
206
+ case "Nearest Neighbour":
207
+ two_chs_emb_fn = chain_functions(
208
+ hadamard if mid_side else lambda x: x,
209
+ get_embedding_model(embedding),
210
+ )
211
+ vec = find_closest_training_sample(
212
+ fx, two_chs_emb_fn, to_fx_state_dict, preset_dict[dataset], ref, y
213
+ )
214
  case _:
215
+ raise ValueError(f"Unknown method: {method}")
216
 
217
  if remove_approx:
218
  infer_fx = instantiate(rt_config)
 
446
  wet_output = default_audio_block(label="Wet Audio", interactive=False)
447
 
448
  with gr.Row():
 
449
  render_button = gr.Button("Run", elem_id="render-button", variant="primary")
450
+ reset_button = gr.Button("Reset", elem_id="reset-button")
451
 
452
  _ = gr.Markdown("## Common Parameters")
453
  with gr.Row():
 
465
  interactive=True,
466
  )
467
  embedding_dropdown = gr.Dropdown(
468
+ [("AFx-Rep", "afx-rep"), ("MFCC", "mfcc"), ("MIR Features", "mir")],
469
  label="Embedding Model",
470
  info="This parameter is used in the Nearest Neighbour and ST-ITO methods.",
471
+ value="afx-rep",
 
 
 
 
 
 
472
  interactive=True,
473
  )
474
+ with gr.Column():
475
+ remove_approx_checkbox = gr.Checkbox(
476
+ label="Use Real-time Effects",
477
+ info="Use real-time delay and reverb effects instead of approximated ones.",
478
+ value=False,
479
+ interactive=True,
480
+ )
481
+ mid_side_checkbox = gr.Checkbox(
482
+ label="Use Mid-Side Processing",
483
+ value=True,
484
+ interactive=True,
485
+ )
486
 
487
  _ = gr.Markdown("## Parameters for ST-ITO Method")
488
  with gr.Row():
 
537
  ),
538
  inputs=[
539
  audio_input,
540
+ audio_reference,
541
  dry_wet_ratio,
542
  method_dropdown,
543
  dataset_dropdown,
544
  embedding_dropdown,
545
  remove_approx_checkbox,
546
+ mid_side_checkbox,
547
  optimisation_steps,
548
  prior_weight,
549
  optimiser_dropdown,