yoyolicoris commited on
Commit
a25cbf8
·
1 Parent(s): d8ddc04

feat: ito functionalities

Browse files
Files changed (1) hide show
  1. ito.py +486 -0
ito.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torchaudio
4
+ import torch.nn.functional as F
5
+ import argparse
6
+ from pathlib import Path
7
+ import yaml
8
+ from typing import Callable, Tuple, Optional
9
+ import json
10
+ from hydra.utils import instantiate
11
+ from tqdm import tqdm
12
+ from functools import reduce
13
+ import math
14
+ import pyloudnorm as pyln
15
+ from functools import partial
16
+ from auraloss.freq import MultiResolutionSTFTLoss, SumAndDifferenceSTFTLoss
17
+
18
+ from modules.utils import chain_functions, get_chunks, vec2statedict
19
+ from st_ito.utils import (
20
+ load_param_model,
21
+ get_param_embeds,
22
+ get_feature_embeds,
23
+ load_mfcc_feature_extractor,
24
+ load_mir_feature_extractor,
25
+ )
26
+ from utils import remove_window_fn, jsonparse2hydra
27
+
28
+
29
+ def get_reference_query_chunks(dry_audio, wet_audio, chunk_size, sr):
30
+ dry = dry_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
31
+ wet = wet_audio.unfold(1, chunk_size, chunk_size).transpose(0, 1)
32
+
33
+ max_filtered = F.max_pool1d(wet.mean(1).abs(), int(sr * 0.05), stride=1)
34
+ active_mask = torch.quantile(max_filtered, 0.5, dim=1) > 0.001 # -60 dB
35
+ if not active_mask.any():
36
+ raise ValueError("No active frames")
37
+ elif active_mask.count_nonzero() < 2:
38
+ raise ValueError("Too few active frames")
39
+
40
+ dry = dry[active_mask]
41
+ wet = wet[active_mask]
42
+
43
+ ref_audio = wet[::2].contiguous()
44
+ raw_audio = dry[1::2].contiguous()
45
+ return ref_audio, raw_audio
46
+
47
+
48
+ def logp_y_given_x(y, mu, std):
49
+ cos_dist = torch.arccos(y @ mu)
50
+ return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
51
+
52
+
53
+ def one_evaluation(
54
+ fx: torch.nn.Module,
55
+ mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
56
+ to_fx_state_dict: Callable[[torch.Tensor], dict],
57
+ logp_x: Callable[[torch.Tensor], torch.Tensor],
58
+ init_vec: torch.Tensor,
59
+ ref_audio: torch.Tensor,
60
+ raw_audio: torch.Tensor,
61
+ lr: float,
62
+ steps: int,
63
+ weight: float,
64
+ ) -> torch.Tensor:
65
+
66
+ peak_scaler = 1 / ref_audio.abs().max()
67
+ ref_audio = ref_audio * peak_scaler
68
+
69
+ print(ref_audio.shape, raw_audio.shape)
70
+
71
+ param_logits = torch.nn.Parameter(init_vec.clone())
72
+ optimiser = torch.optim.Adam([param_logits], lr=lr)
73
+
74
+ with torch.no_grad():
75
+ ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)
76
+
77
+ with tqdm(range(steps), disable=True) as pbar:
78
+ for i in pbar:
79
+ cur_state_dict = to_fx_state_dict(param_logits)
80
+ preds = (
81
+ torch.func.functional_call(fx, cur_state_dict, raw_audio) * peak_scaler
82
+ )
83
+ mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)
84
+
85
+ mid_cos = torch.arccos(mid_embs_pred @ ref_mid_embs.T)
86
+ side_cos = torch.arccos(side_embs_pred @ ref_side_embs.T)
87
+
88
+ mid_std = mid_cos.square().mean().sqrt()
89
+ side_std = side_cos.square().mean().sqrt()
90
+
91
+ y_x_ll = (
92
+ logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
93
+ + logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
94
+ )
95
+
96
+ if weight > 0:
97
+ x_ll = logp_x(param_logits)
98
+ loss = -y_x_ll - x_ll * weight
99
+ else:
100
+ x_ll = y_x_ll.new_zeros(1)
101
+ loss = -y_x_ll
102
+ optimiser.zero_grad()
103
+ loss.backward()
104
+ optimiser.step()
105
+
106
+ postfix_dict = {
107
+ "y_x_ll": y_x_ll.item(),
108
+ "x_ll": x_ll.item(),
109
+ "loss": loss.item(),
110
+ "mid_std": mid_std.item() / math.pi * 180,
111
+ "side_std": side_std.item() / math.pi * 180,
112
+ }
113
+
114
+ pbar.set_postfix(
115
+ **postfix_dict,
116
+ )
117
+
118
+ print(y_x_ll.item(), x_ll.item(), loss.item())
119
+ print(mid_std.item() / math.pi * 180, side_std.item() / math.pi * 180)
120
+ return param_logits.detach()
121
+
122
+
123
+ @torch.no_grad()
124
+ def find_closest_training_sample(
125
+ fx: torch.nn.Module,
126
+ mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
127
+ to_fx_state_dict: Callable[[torch.Tensor], dict],
128
+ training_samples: torch.Tensor,
129
+ ref_audio: torch.Tensor,
130
+ raw_audio: torch.Tensor,
131
+ ) -> torch.Tensor:
132
+
133
+ peak_scaler = 1 / ref_audio.abs().max()
134
+ ref_audio = ref_audio * peak_scaler
135
+
136
+ print(ref_audio.shape, raw_audio.shape)
137
+
138
+ ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)
139
+
140
+ def reduce_closure(
141
+ x: Tuple[float, torch.Tensor], next_param: torch.Tensor
142
+ ) -> Tuple[float, torch.Tensor]:
143
+ cur_best_logp, cur_best_param = x
144
+ cur_state_dict = to_fx_state_dict(next_param)
145
+ preds = (
146
+ sum(torch.func.functional_call(fx, cur_state_dict, raw_audio)) * peak_scaler
147
+ )
148
+ mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)
149
+
150
+ mid_cos = torch.arccos(mid_embs_pred @ ref_mid_embs.T)
151
+ side_cos = torch.arccos(side_embs_pred @ ref_side_embs.T)
152
+
153
+ mid_std = mid_cos.square().mean().sqrt()
154
+ side_std = side_cos.square().mean().sqrt()
155
+
156
+ y_x_ll = (
157
+ logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
158
+ + logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
159
+ ).item()
160
+
161
+ return (
162
+ (cur_best_logp, cur_best_param)
163
+ if y_x_ll < cur_best_logp
164
+ else (y_x_ll, next_param)
165
+ )
166
+
167
+ best_logp, best_param = reduce(
168
+ reduce_closure, training_samples.unbind(0), (-float("inf"), torch.tensor([]))
169
+ )
170
+ print(f"Best log-likelihood: {best_logp}")
171
+ return best_param
172
+
173
+
174
+ def main():
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument("eval_analysis_dir", type=str)
177
+ parser.add_argument("train_analysis_dir", type=str)
178
+ parser.add_argument("output_dir", type=str)
179
+ parser.add_argument("--config", type=str, help="Path to fx config file")
180
+ parser.add_argument("--chunk-duration", type=float, default=11.0)
181
+ parser.add_argument("--weight", type=float, default=0.01)
182
+ parser.add_argument("--steps", type=int, default=1000)
183
+ parser.add_argument("--lr", type=float, default=0.01)
184
+ parser.add_argument(
185
+ "--method",
186
+ type=str,
187
+ choices=["ito", "oracle", "nn_param", "nn_emb", "mean", "regression"],
188
+ default="ito",
189
+ )
190
+ parser.add_argument(
191
+ "--encoder", type=str, default="afx-rep", choices=["afx-rep", "mfcc", "mir"]
192
+ )
193
+ parser.add_argument("--save-pred", action="store_true")
194
+ parser.add_argument("--ckpt-dir", type=str)
195
+
196
+ args = parser.parse_args()
197
+
198
+ # load PCA
199
+ train_analysis_folder = Path(args.train_analysis_dir).resolve()
200
+ eval_analysis_folder = Path(args.eval_analysis_dir).resolve()
201
+
202
+ gauss_data = np.load(train_analysis_folder / "gaussian.npz")
203
+ baseline_vec = torch.tensor(gauss_data["mean"]).cuda()
204
+ cov = torch.tensor(gauss_data["cov"]).cuda()
205
+ cov_logdet = cov.logdet()
206
+
207
+ def logp_x(x):
208
+ diff = x - baseline_vec
209
+ b = torch.linalg.solve(cov, diff)
210
+ norm = diff @ b
211
+ return -0.5 * (
212
+ norm + cov_logdet + baseline_vec.shape[0] * math.log(2 * math.pi)
213
+ )
214
+
215
+ print(f"Baseline logp: {logp_x(baseline_vec).item()}")
216
+
217
+ with open(eval_analysis_folder / "info.json") as f:
218
+ info = json.load(f)
219
+
220
+ param_keys = info["params_keys"]
221
+ original_shapes = list(
222
+ map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
223
+ )
224
+
225
+ *vec2dict_args, dimensions_not_need = get_chunks(param_keys, original_shapes)
226
+ vec2dict_args = [param_keys, original_shapes] + vec2dict_args
227
+ vec2dict = partial(
228
+ vec2statedict,
229
+ **dict(
230
+ zip(
231
+ [
232
+ "keys",
233
+ "original_shapes",
234
+ "selected_chunks",
235
+ "position",
236
+ "U_matrix_shape",
237
+ ],
238
+ vec2dict_args,
239
+ )
240
+ ),
241
+ )
242
+
243
+ if args.config is not None:
244
+ config_path = Path(args.config).resolve()
245
+ else:
246
+ config_path = Path(info["runs"][0]) / "config.yaml"
247
+
248
+ with open(config_path) as fp:
249
+ fx_config = yaml.safe_load(fp)
250
+ fx = instantiate(fx_config["model"])
251
+ fx = fx.cuda()
252
+ fx.eval()
253
+
254
+ fx.load_state_dict(vec2dict(baseline_vec), strict=False)
255
+
256
+ ndim_dict = {k: v.ndim for k, v in fx.state_dict().items()}
257
+ to_fx_state_dict = lambda x: {
258
+ k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
259
+ }
260
+
261
+ if args.method == "regression":
262
+ ckpt_dir = Path(args.ckpt_dir)
263
+ with open(ckpt_dir / "config.yaml") as f:
264
+ config = yaml.safe_load(f)
265
+
266
+ model_config = config["model"]
267
+ data_config = config["data"]
268
+
269
+ checkpoints = (ckpt_dir / "checkpoints").glob("*val_loss*.ckpt")
270
+ lowest_checkpoint = min(checkpoints, key=lambda x: float(x.stem.split("=")[-1]))
271
+ print(f"Loading checkpoint: {lowest_checkpoint}")
272
+ last_ckpt = torch.load(lowest_checkpoint, map_location="cpu")
273
+ model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
274
+ model_config
275
+ )
276
+ model.load_state_dict(last_ckpt["state_dict"])
277
+
278
+ model = model.cuda()
279
+ model.eval()
280
+
281
+ train_root = Path(data_config["init_args"]["train_root"])
282
+ try:
283
+ param_stats = torch.load(train_root / "param_stats.pt")
284
+ except FileNotFoundError:
285
+ param_stats = torch.load(ckpt_dir / "param_stats.pt")
286
+
287
+ param_mu, param_std = (
288
+ param_stats["mu"].float().cuda(),
289
+ param_stats["std"].float().cuda(),
290
+ )
291
+
292
+ regressor = lambda wet: model(wet, dry=None) * param_std + param_mu
293
+ mid_side_embeds_fn = lambda x: (x, x)
294
+ else:
295
+ match args.encoder:
296
+ case "afx-rep":
297
+ afx_rep = load_param_model().cuda()
298
+ mid_side_embeds_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
299
+ case "mfcc":
300
+ mfcc = load_mfcc_feature_extractor().cuda()
301
+ mid_side_embeds_fn = lambda x: get_feature_embeds(x, mfcc)
302
+ case "mir":
303
+ mir = load_mir_feature_extractor().cuda()
304
+ mid_side_embeds_fn = lambda x: get_feature_embeds(x, mir)
305
+ case _:
306
+ raise ValueError(f"Unknown encoder: {args.encoder}")
307
+
308
+ loss_fns = {
309
+ "mss_lr": MultiResolutionSTFTLoss(
310
+ [128, 512, 2048],
311
+ [32, 128, 512],
312
+ [128, 512, 2048],
313
+ sample_rate=44100,
314
+ perceptual_weighting=True,
315
+ ).cuda(),
316
+ "mss_ms": SumAndDifferenceSTFTLoss(
317
+ [128, 512, 2048],
318
+ [32, 128, 512],
319
+ [128, 512, 2048],
320
+ sample_rate=44100,
321
+ perceptual_weighting=True,
322
+ ),
323
+ "mldr_lr": MLDRLoss(
324
+ sr=44100,
325
+ s_taus=[50, 100],
326
+ l_taus=[1000, 2000],
327
+ ).cuda(),
328
+ "mldr_ms": MLDRLoss(
329
+ sr=44100,
330
+ s_taus=[50, 100],
331
+ l_taus=[1000, 2000],
332
+ mid_side=True,
333
+ ).cuda(),
334
+ }
335
+
336
+ raw_params = np.load(eval_analysis_folder / "raw_params.npy")
337
+ feature_mask = np.load(train_analysis_folder / "feature_mask.npy")
338
+ gt_params = raw_params[:, feature_mask]
339
+
340
+ train_params = np.load(train_analysis_folder / "raw_params.npy")
341
+ train_index = np.load(train_analysis_folder / "train_index.npy")
342
+ train_params = torch.from_numpy(train_params[train_index][:, feature_mask]).cuda()
343
+
344
+ output_root = Path(args.output_dir)
345
+
346
+ weights = []
347
+ losses = []
348
+
349
+ for dry_file, wet_file, shifts, gt_param in zip(
350
+ info["dry_files"], info["wet_files"], info["alignment_shifts"], gt_params
351
+ ):
352
+ dry, sr = torchaudio.load(dry_file)
353
+ wet, _ = torchaudio.load(wet_file)
354
+ assert sr == _
355
+
356
+ dry = dry[:, : wet.shape[1]]
357
+ wet = wet[:, : dry.shape[1]]
358
+
359
+ dry = torch.roll(dry, shifts=int(shifts), dims=1)
360
+ print(shifts, dry.shape, dry_file)
361
+
362
+ dry = dry.mean(0, keepdim=True)
363
+
364
+ meter = pyln.Meter(sr)
365
+ normaliser = lambda x: pyln.normalize.loudness(
366
+ x, meter.integrated_loudness(x), -18.0
367
+ )
368
+ dry = torch.from_numpy(normaliser(dry.numpy().T).T).float().cuda()
369
+ wet = torch.from_numpy(normaliser(wet.numpy().T).T).float().cuda()
370
+ gt_param = torch.tensor(gt_param).cuda()
371
+
372
+ match args.method:
373
+ case "ito":
374
+ try:
375
+ ref_audio, raw_audio = get_reference_query_chunks(
376
+ dry, wet, int(sr * args.chunk_duration), sr
377
+ )
378
+ except ValueError as e:
379
+ print(f"Skipping {dry_file}: {e}")
380
+ continue
381
+ pred_param = one_evaluation(
382
+ fx,
383
+ mid_side_embeds_fn,
384
+ to_fx_state_dict,
385
+ logp_x,
386
+ baseline_vec,
387
+ ref_audio,
388
+ raw_audio,
389
+ lr=args.lr,
390
+ steps=args.steps,
391
+ weight=args.weight,
392
+ )
393
+ case "oracle":
394
+ pred_param = gt_param
395
+ case "nn_param":
396
+ pred_param = train_params[
397
+ torch.argmin((train_params - gt_param).square().mean(1))
398
+ ]
399
+ case "nn_emb":
400
+ try:
401
+ ref_audio, raw_audio = get_reference_query_chunks(
402
+ dry, wet, int(sr * args.chunk_duration), sr
403
+ )
404
+ except ValueError as e:
405
+ print(f"Skipping {dry_file}: {e}")
406
+ continue
407
+ pred_param = find_closest_training_sample(
408
+ fx,
409
+ mid_side_embeds_fn,
410
+ to_fx_state_dict,
411
+ train_params,
412
+ ref_audio,
413
+ raw_audio,
414
+ )
415
+ case "mean":
416
+ pred_param = baseline_vec
417
+ case "regression":
418
+ try:
419
+ ref_audio, _ = get_reference_query_chunks(
420
+ dry, wet, int(sr * args.chunk_duration), sr
421
+ )
422
+ except ValueError as e:
423
+ print(f"Skipping {dry_file}: {e}")
424
+ continue
425
+ with torch.no_grad():
426
+ pred_param = regressor(ref_audio).mean(0)
427
+ case _:
428
+ raise ValueError(f"Unknown method: {args.method}")
429
+
430
+ fx.load_state_dict(vec2dict(pred_param), strict=False)
431
+ with torch.no_grad():
432
+ rendered = fx(dry.unsqueeze(0)).squeeze()
433
+
434
+ loss = {
435
+ k: f(rendered.unsqueeze(0), wet.unsqueeze(0)).item()
436
+ for k, f in loss_fns.items()
437
+ }
438
+ param_mse_loss = F.mse_loss(pred_param, gt_param).item()
439
+ loss["param_mse"] = param_mse_loss
440
+ print(", ".join([f"{k}: {v}" for k, v in loss.items()]))
441
+
442
+ losses.append(loss)
443
+ weights.append(wet.shape[1])
444
+
445
+ dry_file = Path(dry_file)
446
+ out_dir = output_root / dry_file.parts[-2] / dry_file.stem
447
+ out_dir.mkdir(parents=True, exist_ok=True)
448
+
449
+ with open(out_dir / "metrics.yaml", "w") as fp:
450
+ yaml.safe_dump(
451
+ loss,
452
+ fp,
453
+ )
454
+
455
+ torch.save(pred_param.cpu(), out_dir / "pred_param.pth")
456
+
457
+ with open(out_dir / "meta.yaml", "w") as fp:
458
+ yaml.safe_dump(
459
+ {
460
+ "model": fx_config["model"],
461
+ "params_keys": param_keys,
462
+ "params_original_shapes": original_shapes,
463
+ "alignment_shift": shifts,
464
+ },
465
+ fp,
466
+ )
467
+
468
+ # symbolic link
469
+ original_wet = out_dir / "wet.wav"
470
+ original_dry = out_dir / "dry.wav"
471
+ if not original_wet.exists():
472
+ original_wet.symlink_to(wet_file)
473
+ if not original_dry.exists():
474
+ original_dry.symlink_to(dry_file)
475
+
476
+ if args.save_pred:
477
+ torchaudio.save(out_dir / "pred.wav", rendered.cpu(), sr)
478
+
479
+ weights = np.array(weights)
480
+ weights = weights / weights.sum()
481
+
482
+ print({k: np.array([l[k] for l in losses]) @ weights for k in losses[0].keys()})
483
+
484
+
485
+ if __name__ == "__main__":
486
+ main()