yoyolicoris committed on
Commit
90573cb
·
1 Parent(s): a59e3f3

feat: enable cuda device option

Browse files
Files changed (1) hide show
  1. app.py +41 -15
app.py CHANGED
@@ -20,7 +20,7 @@ from typing import Tuple, List, Optional, Union, Callable
20
  from modules.utils import vec2statedict, get_chunks
21
  from modules.fx import clip_delay_eq_Q, hadamard
22
  from utils import get_log_mags_from_eq, chain_functions
23
- from ito import find_closest_training_sample
24
  from st_ito.utils import (
25
  load_param_model,
26
  get_param_embeds,
@@ -47,6 +47,7 @@ Try to play around with the sliders and buttons and see what you can come up wit
47
  > **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
48
  """
49
 
 
50
  SLIDER_MAX = 3
51
  SLIDER_MIN = -3
52
  NUMBER_OF_PCS = 4
@@ -88,11 +89,18 @@ def load_presets(preset_folder: Path) -> Tensor:
88
  return presets
89
 
90
 
91
- def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor]:
92
  gauss_params = np.load(f)
93
  mean = torch.from_numpy(gauss_params["mean"]).float()
94
  cov = torch.from_numpy(gauss_params["cov"]).float()
95
- return mean, cov
 
 
 
 
 
 
 
96
 
97
 
98
  preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
@@ -146,13 +154,13 @@ meter = pyln.Meter(44100)
146
  def get_embedding_model(embedding: str) -> Callable:
147
  match embedding:
148
  case "afx-rep":
149
- afx_rep = load_param_model()
150
  two_chs_emb_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
151
  case "mfcc":
152
- mfcc = load_mfcc_feature_extractor()
153
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mfcc)
154
  case "mir":
155
- mir = load_mir_feature_extractor()
156
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mir)
157
  case _:
158
  raise ValueError(f"Unknown encoder: {embedding}")
@@ -188,34 +196,52 @@ def inference(
188
 
189
  loudness = meter.integrated_loudness(y)
190
  y = pyln.normalize.loudness(y, loudness, -18.0)
191
- y = torch.from_numpy(y).float().T.unsqueeze(0)
192
 
193
  ref_loudness = meter.integrated_loudness(ref)
194
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
195
- ref = torch.from_numpy(ref).float().T.unsqueeze(0)
196
 
197
  if y.shape[1] != 1:
198
  y = y.mean(dim=1, keepdim=True)
199
 
200
- fx = deepcopy(global_fx)
201
  fx.train()
202
 
203
  match method:
204
  case "Mean":
205
  vec = gaussian_params_dict[dataset][0]
206
- case "Nearest Neighbour":
207
  two_chs_emb_fn = chain_functions(
208
  hadamard if mid_side else lambda x: x,
209
  get_embedding_model(embedding),
210
  )
211
- vec = find_closest_training_sample(
212
- fx, two_chs_emb_fn, to_fx_state_dict, preset_dict[dataset], ref, y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  )
214
  case _:
215
  raise ValueError(f"Unknown method: {method}")
216
 
217
  if remove_approx:
218
- infer_fx = instantiate(rt_config)
219
  else:
220
  infer_fx = fx
221
 
@@ -225,8 +251,8 @@ def inference(
225
 
226
  with torch.no_grad():
227
  direct, wet = fx(y)
228
- direct = direct.squeeze(0).T.numpy()
229
- wet = wet.squeeze(0).T.numpy()
230
  angle = ratio * math.pi * 0.5
231
  test_clipping = direct + wet
232
  # rendered = fx(y).squeeze(0).T.numpy()
 
20
  from modules.utils import vec2statedict, get_chunks
21
  from modules.fx import clip_delay_eq_Q, hadamard
22
  from utils import get_log_mags_from_eq, chain_functions
23
+ from ito import find_closest_training_sample, one_evaluation
24
  from st_ito.utils import (
25
  load_param_model,
26
  get_param_embeds,
 
47
  > **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
48
  """
49
 
50
+ DEVICE = "cuda"
51
  SLIDER_MAX = 3
52
  SLIDER_MIN = -3
53
  NUMBER_OF_PCS = 4
 
89
  return presets
90
 
91
 
92
+ def load_gaussian_params(f: Union[Path, str]) -> Tuple[Tensor, Tensor, Tensor]:
93
  gauss_params = np.load(f)
94
  mean = torch.from_numpy(gauss_params["mean"]).float()
95
  cov = torch.from_numpy(gauss_params["cov"]).float()
96
+ return mean, cov, cov.logdet()
97
+
98
+
99
+ def logp_x(mu, cov, cov_logdet, x):
100
+ diff = x - mu
101
+ b = torch.linalg.solve(cov, diff)
102
+ norm = diff @ b
103
+ return -0.5 * (norm + cov_logdet + mu.shape[0] * math.log(2 * math.pi))
104
 
105
 
106
  preset_dict = {k: load_presets(v) for k, v in PRESET_PATH.items()}
 
154
  def get_embedding_model(embedding: str) -> Callable:
155
  match embedding:
156
  case "afx-rep":
157
+ afx_rep = load_param_model().to(DEVICE)
158
  two_chs_emb_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
159
  case "mfcc":
160
+ mfcc = load_mfcc_feature_extractor().to(DEVICE)
161
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mfcc)
162
  case "mir":
163
+ mir = load_mir_feature_extractor().to(DEVICE)
164
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mir)
165
  case _:
166
  raise ValueError(f"Unknown encoder: {embedding}")
 
196
 
197
  loudness = meter.integrated_loudness(y)
198
  y = pyln.normalize.loudness(y, loudness, -18.0)
199
+ y = torch.from_numpy(y).float().T.unsqueeze(0).to(DEVICE)
200
 
201
  ref_loudness = meter.integrated_loudness(ref)
202
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
203
+ ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(DEVICE)
204
 
205
  if y.shape[1] != 1:
206
  y = y.mean(dim=1, keepdim=True)
207
 
208
+ fx = deepcopy(global_fx).to(DEVICE)
209
  fx.train()
210
 
211
  match method:
212
  case "Mean":
213
  vec = gaussian_params_dict[dataset][0]
214
+ case "Nearest Neighbour" | "ST-ITO":
215
  two_chs_emb_fn = chain_functions(
216
  hadamard if mid_side else lambda x: x,
217
  get_embedding_model(embedding),
218
  )
219
+ vec = (
220
+ find_closest_training_sample(
221
+ fx, two_chs_emb_fn, to_fx_state_dict, preset_dict[dataset], ref, y
222
+ )
223
+ if method == "Nearest Neighbour"
224
+ else one_evaluation(
225
+ fx,
226
+ two_chs_emb_fn,
227
+ to_fx_state_dict,
228
+ partial(
229
+ logp_x, *[x.to(DEVICE) for x in gaussian_params_dict[dataset]]
230
+ ),
231
+ internal_mean.to(DEVICE),
232
+ ref,
233
+ y,
234
+ optimiser_type=optimiser,
235
+ lr=lr,
236
+ steps=steps,
237
+ weight=prior_weight,
238
+ )
239
  )
240
  case _:
241
  raise ValueError(f"Unknown method: {method}")
242
 
243
  if remove_approx:
244
+ infer_fx = instantiate(rt_config).to(DEVICE)
245
  else:
246
  infer_fx = fx
247
 
 
251
 
252
  with torch.no_grad():
253
  direct, wet = fx(y)
254
+ direct = direct.squeeze(0).T.cpu().numpy()
255
+ wet = wet.squeeze(0).T.cpu().numpy()
256
  angle = ratio * math.pi * 0.5
257
  test_clipping = direct + wet
258
  # rendered = fx(y).squeeze(0).T.numpy()