Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import torchaudio | |
| import torch.nn.functional as F | |
| import argparse | |
| from pathlib import Path | |
| import yaml | |
| from typing import Callable, Tuple, Optional | |
| import json | |
| from hydra.utils import instantiate | |
| from tqdm import tqdm | |
| from functools import reduce | |
| import math | |
| import pyloudnorm as pyln | |
| from functools import partial | |
| from auraloss.freq import MultiResolutionSTFTLoss, SumAndDifferenceSTFTLoss | |
| from modules.utils import chain_functions, get_chunks, vec2statedict | |
| from st_ito.utils import ( | |
| load_param_model, | |
| get_param_embeds, | |
| get_feature_embeds, | |
| load_mfcc_feature_extractor, | |
| load_mir_feature_extractor, | |
| ) | |
| from utils import remove_window_fn, jsonparse2hydra | |
def logp_y_given_x(y, mu, std):
    """Return the Gaussian log-density of the angle between ``y`` and ``mu``.

    The angle is ``arccos(y @ mu)``; it is then scored under a zero-mean
    normal distribution with standard deviation ``std``.

    Args:
        y: Left operand of the matrix product (the callers in this file pass
            reference embeddings).
        mu: Right operand of the matrix product (the callers in this file pass
            transposed predicted embeddings).
        std: Tensor holding the standard deviation of the angular error.

    Returns:
        A tensor of log-densities with the shape of ``y @ mu``.
    """
    # assumes the embeddings are unit-norm so that `y @ mu` is a cosine
    # similarity — TODO confirm against the embedding functions.
    # Rounding can push a similarity marginally outside [-1, 1], where arccos
    # returns NaN; clamping leaves every in-range value untouched.
    cos_sim = (y @ mu).clamp(-1.0, 1.0)
    cos_dist = torch.arccos(cos_sim)
    return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
def one_evaluation(
    fx: torch.nn.Module,
    mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
    to_fx_state_dict: Callable[[torch.Tensor], dict],
    logp_x: Callable[[torch.Tensor], torch.Tensor],
    init_vec: torch.Tensor,
    ref_audio: torch.Tensor,
    raw_audio: torch.Tensor,
    optimiser_type: str,
    lr: float,
    steps: int,
    weight: float,
    progress=None,
) -> torch.Tensor:
    """Optimise an effect-parameter vector so processed audio matches a reference.

    Starting from ``init_vec``, the vector is updated by gradient descent to
    maximise the likelihood of the reference embeddings given the embeddings
    of ``fx`` applied to ``raw_audio``, plus ``weight`` times the prior
    ``logp_x`` when ``weight > 0``.

    Args:
        fx: Effect module, called statelessly through
            ``torch.func.functional_call``; its output is iterated and summed.
        mid_side_embeds_fn: Maps audio to a pair of embedding tensors.
        to_fx_state_dict: Converts a parameter vector into a state dict
            accepted by ``fx``.
        logp_x: Log-prior of a parameter vector; only used if ``weight > 0``.
        init_vec: Initial parameter vector (cloned, never modified in place).
        ref_audio: Reference audio whose embeddings are the target.
        raw_audio: Unprocessed audio fed to ``fx``.
        optimiser_type: Name of a class in ``torch.optim`` (e.g. ``"Adam"``).
        lr: Learning rate passed to the optimiser.
        steps: Number of optimisation steps.
        weight: Multiplier of the prior term in the loss.
        progress: Optional object exposing ``tqdm(iterable)``; when ``None``
            the loop runs without a progress bar.

    Returns:
        The optimised parameter vector, detached from the autograd graph.
    """
    # Peak-normalise the reference; the same gain is applied to predictions.
    peak_scaler = 1 / ref_audio.abs().max()
    ref_audio = ref_audio * peak_scaler
    print(ref_audio.shape, raw_audio.shape)

    param_logits = torch.nn.Parameter(init_vec.clone())
    optimiser = getattr(torch.optim, optimiser_type)([param_logits], lr=lr)

    # The reference embeddings are constant targets, so no graph is needed.
    with torch.no_grad():
        ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)

    step_iter = range(steps) if progress is None else progress.tqdm(range(steps))

    y_x_ll = x_ll = loss = mid_std = side_std = None
    for _ in step_iter:
        cur_state_dict = to_fx_state_dict(param_logits)
        preds = (
            sum(torch.func.functional_call(fx, cur_state_dict, raw_audio)) * peak_scaler
        )
        mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)

        # Clamp before arccos: rounding can push a similarity marginally
        # outside [-1, 1], which would otherwise turn the whole loss into NaN.
        mid_cos = torch.arccos((mid_embs_pred @ ref_mid_embs.T).clamp(-1.0, 1.0))
        side_cos = torch.arccos((side_embs_pred @ ref_side_embs.T).clamp(-1.0, 1.0))

        # RMS angular error, used as the standard deviation of the likelihood.
        mid_std = mid_cos.square().mean().sqrt()
        side_std = side_cos.square().mean().sqrt()

        y_x_ll = (
            logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
            + logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
        )

        if weight > 0:
            x_ll = logp_x(param_logits)
            loss = -y_x_ll - x_ll * weight
        else:
            x_ll = y_x_ll.new_zeros(1)
            loss = -y_x_ll

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # NOTE(review): the original indentation was lost, so it is unclear whether
    # these diagnostics were printed every step or once; they are emitted once
    # with the values of the final step — confirm this is the intended cadence.
    if loss is not None:
        print(y_x_ll.item(), x_ll.item(), loss.item())
        # Standard deviations converted from radians to degrees.
        print(mid_std.item() / math.pi * 180, side_std.item() / math.pi * 180)

    return param_logits.detach()
def find_closest_training_sample(
    fx: torch.nn.Module,
    mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
    to_fx_state_dict: Callable[[torch.Tensor], dict],
    training_samples: torch.Tensor,
    ref_audio: torch.Tensor,
    raw_audio: torch.Tensor,
    progress=None,
) -> torch.Tensor:
    """Pick the training parameter vector whose rendering best matches a reference.

    Each row of ``training_samples`` is applied to ``fx`` on ``raw_audio`` and
    scored by the log-likelihood of the reference embeddings given the
    predicted embeddings; the best-scoring row is returned.

    Args:
        fx: Effect module, called statelessly through
            ``torch.func.functional_call``; its output is iterated and summed.
        mid_side_embeds_fn: Maps audio to a pair of embedding tensors.
        to_fx_state_dict: Converts a parameter vector into a state dict
            accepted by ``fx``.
        training_samples: Candidate parameter vectors stacked along dim 0.
        ref_audio: Reference audio whose embeddings are the target.
        raw_audio: Unprocessed audio fed to ``fx``.
        progress: Optional object exposing ``tqdm(iterable)``; when ``None``
            the candidates are iterated without a progress bar.

    Returns:
        The best candidate, or an empty tensor if ``training_samples`` has no
        rows. On ties the later candidate wins.
    """
    # Peak-normalise the reference; the same gain is applied to predictions.
    peak_scaler = 1 / ref_audio.abs().max()
    ref_audio = ref_audio * peak_scaler
    print(ref_audio.shape, raw_audio.shape)

    candidates = training_samples.unbind(0)
    if progress is not None:
        candidates = progress.tqdm(candidates)

    best_logp, best_param = -float("inf"), torch.tensor([])

    # Only scalar scores are kept, so nothing here needs an autograd graph.
    with torch.no_grad():
        ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)

        for next_param in candidates:
            cur_state_dict = to_fx_state_dict(next_param)
            preds = (
                sum(torch.func.functional_call(fx, cur_state_dict, raw_audio))
                * peak_scaler
            )
            mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)

            # Clamp before arccos: rounding can push a similarity marginally
            # outside [-1, 1], which would otherwise produce NaN scores.
            mid_cos = torch.arccos((mid_embs_pred @ ref_mid_embs.T).clamp(-1.0, 1.0))
            side_cos = torch.arccos(
                (side_embs_pred @ ref_side_embs.T).clamp(-1.0, 1.0)
            )
            mid_std = mid_cos.square().mean().sqrt()
            side_std = side_cos.square().mean().sqrt()

            y_x_ll = (
                logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
                + logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
            ).item()

            # Keep the incumbent only when the new score is strictly worse,
            # which matches the original reduction (ties go to the newcomer).
            if not y_x_ll < best_logp:
                best_logp, best_param = y_x_ll, next_param

    print(f"Best log-likelihood: {best_logp}")
    return best_param
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("eval_analysis_dir", type=str) | |
| parser.add_argument("train_analysis_dir", type=str) | |
| parser.add_argument("output_dir", type=str) | |
| parser.add_argument("--config", type=str, help="Path to fx config file") | |
| parser.add_argument("--chunk-duration", type=float, default=11.0) | |
| parser.add_argument("--weight", type=float, default=0.01) | |
| parser.add_argument("--steps", type=int, default=1000) | |
| parser.add_argument("--lr", type=float, default=0.01) | |
| parser.add_argument( | |
| "--method", | |
| type=str, | |
| choices=["ito", "oracle", "nn_param", "nn_emb", "mean", "regression"], | |
| default="ito", | |
| ) | |
| parser.add_argument( | |
| "--encoder", type=str, default="afx-rep", choices=["afx-rep", "mfcc", "mir"] | |
| ) | |
| parser.add_argument("--save-pred", action="store_true") | |
| parser.add_argument("--ckpt-dir", type=str) | |
| args = parser.parse_args() | |
| # load PCA | |
| train_analysis_folder = Path(args.train_analysis_dir).resolve() | |
| eval_analysis_folder = Path(args.eval_analysis_dir).resolve() | |
| gauss_data = np.load(train_analysis_folder / "gaussian.npz") | |
| baseline_vec = torch.tensor(gauss_data["mean"]).cuda() | |
| cov = torch.tensor(gauss_data["cov"]).cuda() | |
| cov_logdet = cov.logdet() | |
| def logp_x(x): | |
| diff = x - baseline_vec | |
| b = torch.linalg.solve(cov, diff) | |
| norm = diff @ b | |
| return -0.5 * ( | |
| norm + cov_logdet + baseline_vec.shape[0] * math.log(2 * math.pi) | |
| ) | |
| print(f"Baseline logp: {logp_x(baseline_vec).item()}") | |
| with open(eval_analysis_folder / "info.json") as f: | |
| info = json.load(f) | |
| param_keys = info["params_keys"] | |
| original_shapes = list( | |
| map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"]) | |
| ) | |
| *vec2dict_args, dimensions_not_need = get_chunks(param_keys, original_shapes) | |
| vec2dict_args = [param_keys, original_shapes] + vec2dict_args | |
| vec2dict = partial( | |
| vec2statedict, | |
| **dict( | |
| zip( | |
| [ | |
| "keys", | |
| "original_shapes", | |
| "selected_chunks", | |
| "position", | |
| "U_matrix_shape", | |
| ], | |
| vec2dict_args, | |
| ) | |
| ), | |
| ) | |
| if args.config is not None: | |
| config_path = Path(args.config).resolve() | |
| else: | |
| config_path = Path(info["runs"][0]) / "config.yaml" | |
| with open(config_path) as fp: | |
| fx_config = yaml.safe_load(fp) | |
| fx = instantiate(fx_config["model"]) | |
| fx = fx.cuda() | |
| fx.eval() | |
| fx.load_state_dict(vec2dict(baseline_vec), strict=False) | |
| ndim_dict = {k: v.ndim for k, v in fx.state_dict().items()} | |
| to_fx_state_dict = lambda x: { | |
| k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items() | |
| } | |
| if args.method == "regression": | |
| ckpt_dir = Path(args.ckpt_dir) | |
| with open(ckpt_dir / "config.yaml") as f: | |
| config = yaml.safe_load(f) | |
| model_config = config["model"] | |
| data_config = config["data"] | |
| checkpoints = (ckpt_dir / "checkpoints").glob("*val_loss*.ckpt") | |
| lowest_checkpoint = min(checkpoints, key=lambda x: float(x.stem.split("=")[-1])) | |
| print(f"Loading checkpoint: {lowest_checkpoint}") | |
| last_ckpt = torch.load(lowest_checkpoint, map_location="cpu") | |
| model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)( | |
| model_config | |
| ) | |
| model.load_state_dict(last_ckpt["state_dict"]) | |
| model = model.cuda() | |
| model.eval() | |
| train_root = Path(data_config["init_args"]["train_root"]) | |
| try: | |
| param_stats = torch.load(train_root / "param_stats.pt") | |
| except FileNotFoundError: | |
| param_stats = torch.load(ckpt_dir / "param_stats.pt") | |
| param_mu, param_std = ( | |
| param_stats["mu"].float().cuda(), | |
| param_stats["std"].float().cuda(), | |
| ) | |
| regressor = lambda wet: model(wet, dry=None) * param_std + param_mu | |
| mid_side_embeds_fn = lambda x: (x, x) | |
| else: | |
| match args.encoder: | |
| case "afx-rep": | |
| afx_rep = load_param_model().cuda() | |
| mid_side_embeds_fn = lambda x: get_param_embeds(x, afx_rep, 44100) | |
| case "mfcc": | |
| mfcc = load_mfcc_feature_extractor().cuda() | |
| mid_side_embeds_fn = lambda x: get_feature_embeds(x, mfcc) | |
| case "mir": | |
| mir = load_mir_feature_extractor().cuda() | |
| mid_side_embeds_fn = lambda x: get_feature_embeds(x, mir) | |
| case _: | |
| raise ValueError(f"Unknown encoder: {args.encoder}") | |
| loss_fns = { | |
| "mss_lr": MultiResolutionSTFTLoss( | |
| [128, 512, 2048], | |
| [32, 128, 512], | |
| [128, 512, 2048], | |
| sample_rate=44100, | |
| perceptual_weighting=True, | |
| ).cuda(), | |
| "mss_ms": SumAndDifferenceSTFTLoss( | |
| [128, 512, 2048], | |
| [32, 128, 512], | |
| [128, 512, 2048], | |
| sample_rate=44100, | |
| perceptual_weighting=True, | |
| ), | |
| "mldr_lr": MLDRLoss( | |
| sr=44100, | |
| s_taus=[50, 100], | |
| l_taus=[1000, 2000], | |
| ).cuda(), | |
| "mldr_ms": MLDRLoss( | |
| sr=44100, | |
| s_taus=[50, 100], | |
| l_taus=[1000, 2000], | |
| mid_side=True, | |
| ).cuda(), | |
| } | |
| raw_params = np.load(eval_analysis_folder / "raw_params.npy") | |
| feature_mask = np.load(train_analysis_folder / "feature_mask.npy") | |
| gt_params = raw_params[:, feature_mask] | |
| train_params = np.load(train_analysis_folder / "raw_params.npy") | |
| train_index = np.load(train_analysis_folder / "train_index.npy") | |
| train_params = torch.from_numpy(train_params[train_index][:, feature_mask]).cuda() | |
| output_root = Path(args.output_dir) | |
| weights = [] | |
| losses = [] | |
| for dry_file, wet_file, shifts, gt_param in zip( | |
| info["dry_files"], info["wet_files"], info["alignment_shifts"], gt_params | |
| ): | |
| dry, sr = torchaudio.load(dry_file) | |
| wet, _ = torchaudio.load(wet_file) | |
| assert sr == _ | |
| dry = dry[:, : wet.shape[1]] | |
| wet = wet[:, : dry.shape[1]] | |
| dry = torch.roll(dry, shifts=int(shifts), dims=1) | |
| print(shifts, dry.shape, dry_file) | |
| dry = dry.mean(0, keepdim=True) | |
| meter = pyln.Meter(sr) | |
| normaliser = lambda x: pyln.normalize.loudness( | |
| x, meter.integrated_loudness(x), -18.0 | |
| ) | |
| dry = torch.from_numpy(normaliser(dry.numpy().T).T).float().cuda() | |
| wet = torch.from_numpy(normaliser(wet.numpy().T).T).float().cuda() | |
| gt_param = torch.tensor(gt_param).cuda() | |
| match args.method: | |
| case "ito": | |
| try: | |
| ref_audio, raw_audio = get_reference_query_chunks( | |
| dry, wet, int(sr * args.chunk_duration), sr | |
| ) | |
| except ValueError as e: | |
| print(f"Skipping {dry_file}: {e}") | |
| continue | |
| pred_param = one_evaluation( | |
| fx, | |
| mid_side_embeds_fn, | |
| to_fx_state_dict, | |
| logp_x, | |
| baseline_vec, | |
| ref_audio, | |
| raw_audio, | |
| lr=args.lr, | |
| steps=args.steps, | |
| weight=args.weight, | |
| ) | |
| case "oracle": | |
| pred_param = gt_param | |
| case "nn_param": | |
| pred_param = train_params[ | |
| torch.argmin((train_params - gt_param).square().mean(1)) | |
| ] | |
| case "nn_emb": | |
| try: | |
| ref_audio, raw_audio = get_reference_query_chunks( | |
| dry, wet, int(sr * args.chunk_duration), sr | |
| ) | |
| except ValueError as e: | |
| print(f"Skipping {dry_file}: {e}") | |
| continue | |
| pred_param = find_closest_training_sample( | |
| fx, | |
| mid_side_embeds_fn, | |
| to_fx_state_dict, | |
| train_params, | |
| ref_audio, | |
| raw_audio, | |
| ) | |
| case "mean": | |
| pred_param = baseline_vec | |
| case "regression": | |
| try: | |
| ref_audio, _ = get_reference_query_chunks( | |
| dry, wet, int(sr * args.chunk_duration), sr | |
| ) | |
| except ValueError as e: | |
| print(f"Skipping {dry_file}: {e}") | |
| continue | |
| with torch.no_grad(): | |
| pred_param = regressor(ref_audio).mean(0) | |
| case _: | |
| raise ValueError(f"Unknown method: {args.method}") | |
| fx.load_state_dict(vec2dict(pred_param), strict=False) | |
| with torch.no_grad(): | |
| rendered = fx(dry.unsqueeze(0)).squeeze() | |
| loss = { | |
| k: f(rendered.unsqueeze(0), wet.unsqueeze(0)).item() | |
| for k, f in loss_fns.items() | |
| } | |
| param_mse_loss = F.mse_loss(pred_param, gt_param).item() | |
| loss["param_mse"] = param_mse_loss | |
| print(", ".join([f"{k}: {v}" for k, v in loss.items()])) | |
| losses.append(loss) | |
| weights.append(wet.shape[1]) | |
| dry_file = Path(dry_file) | |
| out_dir = output_root / dry_file.parts[-2] / dry_file.stem | |
| out_dir.mkdir(parents=True, exist_ok=True) | |
| with open(out_dir / "metrics.yaml", "w") as fp: | |
| yaml.safe_dump( | |
| loss, | |
| fp, | |
| ) | |
| torch.save(pred_param.cpu(), out_dir / "pred_param.pth") | |
| with open(out_dir / "meta.yaml", "w") as fp: | |
| yaml.safe_dump( | |
| { | |
| "model": fx_config["model"], | |
| "params_keys": param_keys, | |
| "params_original_shapes": original_shapes, | |
| "alignment_shift": shifts, | |
| }, | |
| fp, | |
| ) | |
| # symbolic link | |
| original_wet = out_dir / "wet.wav" | |
| original_dry = out_dir / "dry.wav" | |
| if not original_wet.exists(): | |
| original_wet.symlink_to(wet_file) | |
| if not original_dry.exists(): | |
| original_dry.symlink_to(dry_file) | |
| if args.save_pred: | |
| torchaudio.save(out_dir / "pred.wav", rendered.cpu(), sr) | |
| weights = np.array(weights) | |
| weights = weights / weights.sum() | |
| print({k: np.array([l[k] for l in losses]) @ weights for k in losses[0].keys()}) | |
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()