# diffvox-ito / ito.py
import torch
import numpy as np
import torchaudio
import torch.nn.functional as F
import argparse
from pathlib import Path
import yaml
from typing import Callable, Tuple, Optional
import json
from hydra.utils import instantiate
from tqdm import tqdm
from functools import reduce
import math
import pyloudnorm as pyln
from functools import partial
from auraloss.freq import MultiResolutionSTFTLoss, SumAndDifferenceSTFTLoss
from modules.utils import chain_functions, get_chunks, vec2statedict
from st_ito.utils import (
load_param_model,
get_param_embeds,
get_feature_embeds,
load_mfcc_feature_extractor,
load_mir_feature_extractor,
)
from utils import remove_window_fn, jsonparse2hydra

# NOTE: MLDRLoss and get_reference_query_chunks are referenced below; the import
# paths here are assumptions and may need adjusting to the actual repo layout.
from modules.loss import MLDRLoss
from utils import get_reference_query_chunks
def logp_y_given_x(y, mu, std):
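    """Log-likelihood of unit-norm embeddings `y` given mean direction(s) `mu`, modelling
    the angular (arc-cosine) distance as a zero-mean Gaussian with scale `std`."""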
cos_dist = torch.arccos(y @ mu)
return -0.5 * (cos_dist / std).pow(2) - 0.5 * math.log(2 * math.pi) - std.log()
def one_evaluation(
fx: torch.nn.Module,
mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
to_fx_state_dict: Callable[[torch.Tensor], dict],
logp_x: Callable[[torch.Tensor], torch.Tensor],
init_vec: torch.Tensor,
ref_audio: torch.Tensor,
raw_audio: torch.Tensor,
optimiser_type: str,
lr: float,
steps: int,
weight: float,
    progress=None,
) -> torch.Tensor:
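    """Inference-time optimisation (ITO): starting from `init_vec`, optimise the effect
    parameter logits so the mid/side embeddings of the processed `raw_audio` match those
    of `ref_audio`, optionally regularised by the parameter prior `logp_x` scaled by
    `weight`. Returns the optimised parameter vector, detached from the graph."""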
peak_scaler = 1 / ref_audio.abs().max()
ref_audio = ref_audio * peak_scaler
print(ref_audio.shape, raw_audio.shape)
param_logits = torch.nn.Parameter(init_vec.clone())
# optimiser = torch.optim.Adam([param_logits], lr=lr)
optimiser = getattr(torch.optim, optimiser_type)([param_logits], lr=lr)
with torch.no_grad():
ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)
    pbar = progress.tqdm(range(steps)) if progress is not None else tqdm(range(steps))
    for i in pbar:
cur_state_dict = to_fx_state_dict(param_logits)
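        # Render raw_audio with the candidate parameters via functional_call, leaving
        # fx's own registered parameters untouched.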
preds = (
sum(torch.func.functional_call(fx, cur_state_dict, raw_audio)) * peak_scaler
)
mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)
mid_cos = torch.arccos(mid_embs_pred @ ref_mid_embs.T)
side_cos = torch.arccos(side_embs_pred @ ref_side_embs.T)
mid_std = mid_cos.square().mean().sqrt()
side_std = side_cos.square().mean().sqrt()
y_x_ll = (
logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
+ logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
)
if weight > 0:
x_ll = logp_x(param_logits)
loss = -y_x_ll - x_ll * weight
else:
x_ll = y_x_ll.new_zeros(1)
loss = -y_x_ll
optimiser.zero_grad()
loss.backward()
optimiser.step()
postfix_dict = {
"y_x_ll": y_x_ll.item(),
"x_ll": x_ll.item(),
"loss": loss.item(),
"mid_std": mid_std.item() / math.pi * 180,
"side_std": side_std.item() / math.pi * 180,
}
        if hasattr(pbar, "set_postfix"):
            pbar.set_postfix(**postfix_dict)
        else:
            print(postfix_dict)
return param_logits.detach()
@torch.no_grad()
def find_closest_training_sample(
fx: torch.nn.Module,
mid_side_embeds_fn: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]],
to_fx_state_dict: Callable[[torch.Tensor], dict],
training_samples: torch.Tensor,
ref_audio: torch.Tensor,
raw_audio: torch.Tensor,
    progress=None,
) -> torch.Tensor:
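    """Nearest-neighbour-in-embedding baseline: render `raw_audio` with every parameter
    vector in `training_samples` and return the one whose mid/side embeddings score the
    highest angular log-likelihood against the reference embeddings."""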
peak_scaler = 1 / ref_audio.abs().max()
ref_audio = ref_audio * peak_scaler
print(ref_audio.shape, raw_audio.shape)
ref_mid_embs, ref_side_embs = mid_side_embeds_fn(ref_audio)
def reduce_closure(
x: Tuple[float, torch.Tensor], next_param: torch.Tensor
) -> Tuple[float, torch.Tensor]:
cur_best_logp, cur_best_param = x
cur_state_dict = to_fx_state_dict(next_param)
preds = (
sum(torch.func.functional_call(fx, cur_state_dict, raw_audio)) * peak_scaler
)
mid_embs_pred, side_embs_pred = mid_side_embeds_fn(preds)
mid_cos = torch.arccos(mid_embs_pred @ ref_mid_embs.T)
side_cos = torch.arccos(side_embs_pred @ ref_side_embs.T)
mid_std = mid_cos.square().mean().sqrt()
side_std = side_cos.square().mean().sqrt()
y_x_ll = (
logp_y_given_x(ref_mid_embs, mid_embs_pred.T, mid_std).mean()
+ logp_y_given_x(ref_side_embs, side_embs_pred.T, side_std).mean()
).item()
return (
(cur_best_logp, cur_best_param)
if y_x_ll < cur_best_logp
else (y_x_ll, next_param)
)
best_logp, best_param = reduce(
reduce_closure,
        (progress.tqdm if progress is not None else tqdm)(training_samples.unbind(0)),
(-float("inf"), torch.tensor([])),
)
print(f"Best log-likelihood: {best_logp}")
return best_param
def main():
parser = argparse.ArgumentParser()
parser.add_argument("eval_analysis_dir", type=str)
parser.add_argument("train_analysis_dir", type=str)
parser.add_argument("output_dir", type=str)
parser.add_argument("--config", type=str, help="Path to fx config file")
parser.add_argument("--chunk-duration", type=float, default=11.0)
parser.add_argument("--weight", type=float, default=0.01)
parser.add_argument("--steps", type=int, default=1000)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument(
"--method",
type=str,
choices=["ito", "oracle", "nn_param", "nn_emb", "mean", "regression"],
default="ito",
)
parser.add_argument(
"--encoder", type=str, default="afx-rep", choices=["afx-rep", "mfcc", "mir"]
)
parser.add_argument("--save-pred", action="store_true")
parser.add_argument("--ckpt-dir", type=str)
args = parser.parse_args()
    # load the Gaussian prior (mean/covariance) fitted on the training-set parameters
train_analysis_folder = Path(args.train_analysis_dir).resolve()
eval_analysis_folder = Path(args.eval_analysis_dir).resolve()
gauss_data = np.load(train_analysis_folder / "gaussian.npz")
baseline_vec = torch.tensor(gauss_data["mean"]).cuda()
cov = torch.tensor(gauss_data["cov"]).cuda()
cov_logdet = cov.logdet()
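    # Multivariate-Gaussian log-density of a parameter vector under the statistics
    # loaded from gaussian.npz; used as the prior term during ITO.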
def logp_x(x):
diff = x - baseline_vec
b = torch.linalg.solve(cov, diff)
norm = diff @ b
return -0.5 * (
norm + cov_logdet + baseline_vec.shape[0] * math.log(2 * math.pi)
)
print(f"Baseline logp: {logp_x(baseline_vec).item()}")
with open(eval_analysis_folder / "info.json") as f:
info = json.load(f)
param_keys = info["params_keys"]
original_shapes = list(
map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
*vec2dict_args, dimensions_not_need = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
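    # vec2dict maps a flat parameter vector back to the effect chain's named state dict,
    # using the key/shape/chunk bookkeeping recovered above.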
vec2dict = partial(
vec2statedict,
**dict(
zip(
[
"keys",
"original_shapes",
"selected_chunks",
"position",
"U_matrix_shape",
],
vec2dict_args,
)
),
)
if args.config is not None:
config_path = Path(args.config).resolve()
else:
config_path = Path(info["runs"][0]) / "config.yaml"
with open(config_path) as fp:
fx_config = yaml.safe_load(fp)
fx = instantiate(fx_config["model"])
fx = fx.cuda()
fx.eval()
fx.load_state_dict(vec2dict(baseline_vec), strict=False)
ndim_dict = {k: v.ndim for k, v in fx.state_dict().items()}
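    # vec2dict yields at-least-1-D tensors, so entries that are scalars (0-dim) in the
    # actual fx state dict are indexed back down before functional_call.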
to_fx_state_dict = lambda x: {
k: v[0] if ndim_dict[k] == 0 else v for k, v in vec2dict(x).items()
}
if args.method == "regression":
ckpt_dir = Path(args.ckpt_dir)
with open(ckpt_dir / "config.yaml") as f:
config = yaml.safe_load(f)
model_config = config["model"]
data_config = config["data"]
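        # Pick the checkpoint with the lowest validation loss, assuming filenames of the
        # form "...val_loss=<value>.ckpt".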
checkpoints = (ckpt_dir / "checkpoints").glob("*val_loss*.ckpt")
lowest_checkpoint = min(checkpoints, key=lambda x: float(x.stem.split("=")[-1]))
print(f"Loading checkpoint: {lowest_checkpoint}")
last_ckpt = torch.load(lowest_checkpoint, map_location="cpu")
model = chain_functions(remove_window_fn, jsonparse2hydra, instantiate)(
model_config
)
model.load_state_dict(last_ckpt["state_dict"])
model = model.cuda()
model.eval()
train_root = Path(data_config["init_args"]["train_root"])
try:
param_stats = torch.load(train_root / "param_stats.pt")
except FileNotFoundError:
param_stats = torch.load(ckpt_dir / "param_stats.pt")
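        # De-normalise the regressor's raw outputs with the training-set parameter statistics.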
param_mu, param_std = (
param_stats["mu"].float().cuda(),
param_stats["std"].float().cuda(),
)
regressor = lambda wet: model(wet, dry=None) * param_std + param_mu
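        # Identity stand-in: the regression path feeds audio to the model directly and
        # never calls an embedding extractor.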
mid_side_embeds_fn = lambda x: (x, x)
else:
match args.encoder:
case "afx-rep":
afx_rep = load_param_model().cuda()
mid_side_embeds_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
case "mfcc":
mfcc = load_mfcc_feature_extractor().cuda()
mid_side_embeds_fn = lambda x: get_feature_embeds(x, mfcc)
case "mir":
mir = load_mir_feature_extractor().cuda()
mid_side_embeds_fn = lambda x: get_feature_embeds(x, mir)
case _:
raise ValueError(f"Unknown encoder: {args.encoder}")
loss_fns = {
"mss_lr": MultiResolutionSTFTLoss(
[128, 512, 2048],
[32, 128, 512],
[128, 512, 2048],
sample_rate=44100,
perceptual_weighting=True,
).cuda(),
"mss_ms": SumAndDifferenceSTFTLoss(
[128, 512, 2048],
[32, 128, 512],
[128, 512, 2048],
sample_rate=44100,
perceptual_weighting=True,
        ).cuda(),  # keep on GPU like the other losses; the inputs below are CUDA tensors
"mldr_lr": MLDRLoss(
sr=44100,
s_taus=[50, 100],
l_taus=[1000, 2000],
).cuda(),
"mldr_ms": MLDRLoss(
sr=44100,
s_taus=[50, 100],
l_taus=[1000, 2000],
mid_side=True,
).cuda(),
}
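    # Ground-truth parameters of the evaluation set and the feature-masked training-set
    # parameters used by the nearest-neighbour baselines.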
raw_params = np.load(eval_analysis_folder / "raw_params.npy")
feature_mask = np.load(train_analysis_folder / "feature_mask.npy")
gt_params = raw_params[:, feature_mask]
train_params = np.load(train_analysis_folder / "raw_params.npy")
train_index = np.load(train_analysis_folder / "train_index.npy")
train_params = torch.from_numpy(train_params[train_index][:, feature_mask]).cuda()
output_root = Path(args.output_dir)
weights = []
losses = []
for dry_file, wet_file, shifts, gt_param in zip(
info["dry_files"], info["wet_files"], info["alignment_shifts"], gt_params
):
dry, sr = torchaudio.load(dry_file)
wet, _ = torchaudio.load(wet_file)
assert sr == _
dry = dry[:, : wet.shape[1]]
wet = wet[:, : dry.shape[1]]
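        # Compensate for the annotated alignment shift between the dry and wet recordings.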
dry = torch.roll(dry, shifts=int(shifts), dims=1)
print(shifts, dry.shape, dry_file)
dry = dry.mean(0, keepdim=True)
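        # Loudness-normalise both dry and wet to -18 LUFS before rendering and scoring.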
meter = pyln.Meter(sr)
normaliser = lambda x: pyln.normalize.loudness(
x, meter.integrated_loudness(x), -18.0
)
dry = torch.from_numpy(normaliser(dry.numpy().T).T).float().cuda()
wet = torch.from_numpy(normaliser(wet.numpy().T).T).float().cuda()
gt_param = torch.tensor(gt_param).cuda()
match args.method:
case "ito":
try:
ref_audio, raw_audio = get_reference_query_chunks(
dry, wet, int(sr * args.chunk_duration), sr
)
except ValueError as e:
print(f"Skipping {dry_file}: {e}")
continue
pred_param = one_evaluation(
fx,
mid_side_embeds_fn,
to_fx_state_dict,
logp_x,
baseline_vec,
ref_audio,
                    raw_audio,
                    optimiser_type="Adam",  # required argument; matches the commented-out Adam default above
lr=args.lr,
steps=args.steps,
weight=args.weight,
)
case "oracle":
pred_param = gt_param
case "nn_param":
pred_param = train_params[
torch.argmin((train_params - gt_param).square().mean(1))
]
case "nn_emb":
try:
ref_audio, raw_audio = get_reference_query_chunks(
dry, wet, int(sr * args.chunk_duration), sr
)
except ValueError as e:
print(f"Skipping {dry_file}: {e}")
continue
pred_param = find_closest_training_sample(
fx,
mid_side_embeds_fn,
to_fx_state_dict,
train_params,
ref_audio,
raw_audio,
)
case "mean":
pred_param = baseline_vec
case "regression":
try:
ref_audio, _ = get_reference_query_chunks(
dry, wet, int(sr * args.chunk_duration), sr
)
except ValueError as e:
print(f"Skipping {dry_file}: {e}")
continue
with torch.no_grad():
pred_param = regressor(ref_audio).mean(0)
case _:
raise ValueError(f"Unknown method: {args.method}")
fx.load_state_dict(vec2dict(pred_param), strict=False)
with torch.no_grad():
rendered = fx(dry.unsqueeze(0)).squeeze()
loss = {
k: f(rendered.unsqueeze(0), wet.unsqueeze(0)).item()
for k, f in loss_fns.items()
}
param_mse_loss = F.mse_loss(pred_param, gt_param).item()
loss["param_mse"] = param_mse_loss
print(", ".join([f"{k}: {v}" for k, v in loss.items()]))
losses.append(loss)
weights.append(wet.shape[1])
dry_file = Path(dry_file)
out_dir = output_root / dry_file.parts[-2] / dry_file.stem
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "metrics.yaml", "w") as fp:
yaml.safe_dump(
loss,
fp,
)
torch.save(pred_param.cpu(), out_dir / "pred_param.pth")
with open(out_dir / "meta.yaml", "w") as fp:
yaml.safe_dump(
{
"model": fx_config["model"],
"params_keys": param_keys,
"params_original_shapes": original_shapes,
"alignment_shift": shifts,
},
fp,
)
# symbolic link
original_wet = out_dir / "wet.wav"
original_dry = out_dir / "dry.wav"
if not original_wet.exists():
original_wet.symlink_to(wet_file)
if not original_dry.exists():
original_dry.symlink_to(dry_file)
if args.save_pred:
torchaudio.save(out_dir / "pred.wav", rendered.cpu(), sr)
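    # Aggregate per-file metrics with weights proportional to each wet file's length.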
weights = np.array(weights)
weights = weights / weights.sum()
print({k: np.array([l[k] for l in losses]) @ weights for k in losses[0].keys()})
if __name__ == "__main__":
main()