# diffvox-ito / st_ito / utils.py
# yoyolicoris's picture
# refactor: use modules from the dataset repo
# 53e45f1
import os
import yaml
import torch
import torchaudio
import pyloudnorm as pyln
from typing import Optional
from importlib import import_module
from huggingface_hub import hf_hub_download
from reg_modules.encoder import (
StatisticReduction,
LogCrest,
LogRMS,
LogSpread,
LogSpectralBandwidth,
LogSpectralCentroid,
LogSpectralFlatness,
Frame,
MapAndMerge,
)
# The following import should be done after downloading the diffvox dataset module files
from modules.fx import hadamard
# ------------------ Normalization functions ------------------
def apply_fade_in(x: torch.Tensor, num_samples: int = 16384):
    """Apply a linear fade-in to the start of an audio signal, in place.

    The gain ramps linearly from 0 to 1 along the last dimension of ``x``.
    If the signal is shorter than ``num_samples``, the ramp is shortened to
    the signal length instead of failing with a shape mismatch.

    Args:
        x (torch.Tensor): Input audio tensor; the last dimension is the one
            the fade is applied along. It is modified in place.
        num_samples (int, optional): Number of samples to apply fade in.
            Defaults to 16384.

    Returns:
        torch.Tensor: The same tensor object as ``x`` with the fade applied.
    """
    # Never build a ramp longer than the signal: the slice below would be
    # shorter than the ramp and the element-wise product would not broadcast.
    fade_len = min(num_samples, x.shape[-1])
    fade = torch.linspace(0, 1, fade_len, device=x.device)
    x[..., :fade_len] = x[..., :fade_len] * fade
    return x
def batch_peak_normalize(x: torch.Tensor):
    """Divide ``x`` by its peak absolute value taken along dimension 1.

    The peak is clamped to at least 1e-8 before the division so that an
    all-zero item does not cause a division by zero.

    Args:
        x (torch.Tensor): Input tensor; the peak is reduced over ``dim=1``.

    Returns:
        torch.Tensor: A new tensor holding the normalized values.
    """
    peak = x.abs().max(dim=1).values
    # Re-insert the reduced dimension so the peak broadcasts against x.
    safe_peak = peak.unsqueeze(1).clamp(min=1e-8)
    return x / safe_peak
def batch_loudness_normalize(x: torch.Tensor, meter: pyln.Meter, target_lufs: float):
    """Scale every batch item, in place, to a target integrated loudness.

    Each item ``x[i]`` is measured with ``meter.integrated_loudness`` and then
    multiplied by the linear gain that moves the measured loudness to
    ``target_lufs``. Items whose measured loudness is not finite (for example
    ``-inf``, which a loudness meter reports for silent input) are left
    untouched, because the resulting infinite gain would turn them into NaNs.

    Args:
        x (torch.Tensor): 2-D tensor indexed as (batch, samples); every item
            is handed to the meter as a (samples, 1) NumPy array. Modified in
            place.
        meter (pyln.Meter): Loudness meter used for the measurement.
        target_lufs (float): Desired integrated loudness in LUFS.

    Returns:
        torch.Tensor: The same tensor object as ``x`` after normalization.
    """
    import math  # local import: only needed for the finiteness check below

    for batch_idx in range(x.shape[0]):
        item = x[batch_idx : batch_idx + 1, ...]
        # detach() so tensors that require grad can still be converted.
        lufs = meter.integrated_loudness(item.permute(1, 0).detach().cpu().numpy())
        if not math.isfinite(lufs):
            # No meaningful gain exists for this item; keep it unchanged.
            continue
        gain_db = target_lufs - lufs
        gain_lin = 10 ** (gain_db / 20)
        x[batch_idx, :] = gain_lin * x[batch_idx, :]
    return x
# -------- self-supervised parameter estimation model -------- #
def get_param_embeds(
    x: torch.Tensor,
    model: torch.nn.Module,
    sample_rate: int,
    dropout: float = 0.0,
):
    """Compute L2-normalized mid and side embeddings of ``x`` using ``model``.

    The input is resampled to 48 kHz when ``sample_rate`` differs from 48000
    and is then passed to ``model``, which must return a
    ``(mid_embeddings, side_embeddings)`` pair. No cropping, padding or peak
    normalization is applied here.

    Args:
        x (torch.Tensor): Input audio tensor.
        model (torch.nn.Module): Encoder returning a pair of embeddings.
        sample_rate (int): Sample rate of ``x`` in Hz.
        dropout (float, optional): Dropout probability applied to both
            embeddings. It is applied with ``training=True``, i.e. regardless
            of the mode ``model`` is in. Defaults to 0.0 (disabled).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Mid and side embeddings, each
        L2-normalized along the last dimension.
    """
    if sample_rate != 48000:
        x = torchaudio.functional.resample(x, sample_rate, 48000)

    mid_embeddings, side_embeddings = model(x)

    # add dropout
    if dropout > 0.0:
        mid_embeddings = torch.nn.functional.dropout(
            mid_embeddings, p=dropout, training=True
        )
        side_embeddings = torch.nn.functional.dropout(
            side_embeddings, p=dropout, training=True
        )

    # check for nan -- the two embeddings are checked independently so that
    # NaNs in the side embeddings are replaced even when the mid embeddings
    # also contain NaNs.
    if torch.isnan(mid_embeddings).any():
        print("Warning: NaNs found in mid_embeddings")
        mid_embeddings = torch.nan_to_num(mid_embeddings)
    if torch.isnan(side_embeddings).any():
        print("Warning: NaNs found in side_embeddings")
        side_embeddings = torch.nan_to_num(side_embeddings)

    # l2 normalize
    mid_embeddings = torch.nn.functional.normalize(mid_embeddings, p=2, dim=-1)
    side_embeddings = torch.nn.functional.normalize(side_embeddings, p=2, dim=-1)

    return mid_embeddings, side_embeddings
def load_param_model(ckpt_path: Optional[str] = None):
    """Build the encoder described by ``config.yaml`` and load its weights.

    When ``ckpt_path`` is ``None``, ``afx-rep.ckpt`` and ``config.yaml`` are
    downloaded from the Hugging Face repository ``csteinmetz1/afx-rep`` into
    the local ``tmp`` directory. Otherwise ``config.yaml`` is expected to sit
    in the same directory as the checkpoint.

    The encoder class is resolved from
    ``config["model"]["init_args"]["encoder"]["class_path"]``, with ``lcap``
    in the module path replaced by ``st_ito``, and instantiated with the
    matching ``init_args``. Only checkpoint entries under the ``encoder.``
    prefix are loaded, with that prefix stripped.

    Args:
        ckpt_path (Optional[str]): Path to a checkpoint file, or ``None`` to
            download the default one.

    Returns:
        torch.nn.Module: The encoder with loaded weights, in eval mode.
    """
    if ckpt_path is None:  # look in tmp directory
        os.makedirs("tmp", exist_ok=True)
        ckpt_path = hf_hub_download(
            repo_id="csteinmetz1/afx-rep",
            filename="afx-rep.ckpt",
            local_dir="tmp",
        )
        config_path = hf_hub_download(
            repo_id="csteinmetz1/afx-rep",
            filename="config.yaml",
            local_dir="tmp",
        )
    else:
        config_path = os.path.join(os.path.dirname(ckpt_path), "config.yaml")

    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    encoder_configs = config["model"]["init_args"]["encoder"]

    module_path, class_name = encoder_configs["class_path"].rsplit(".", 1)
    # The config refers to the original package name; remap it to this one.
    module_path = module_path.replace("lcap", "st_ito")
    module = import_module(module_path)
    model = getattr(module, class_name)(**encoder_configs["init_args"])

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # load state dicts: keep only the encoder's entries. The filter and the
    # stripped prefix are the same string, so a sibling key such as
    # "encoder_x.w" is not picked up by mistake.
    prefix = "encoder."
    state_dict = {
        k[len(prefix) :]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(prefix)
    }

    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_mfcc_feature_extractor():
    """Build the MFCC-based feature extractor.

    The returned pipeline chains, in order, a torchaudio ``MFCC`` transform
    (44.1 kHz, 25 coefficients, 128 mel bands, FFT size 2048, hop 1024, no
    centering), a ``StatisticReduction`` module, and a ``Flatten`` over the
    last two dimensions.

    Returns:
        torch.nn.Sequential: The assembled feature extractor.
    """
    mel_settings = {
        "n_fft": 2048,
        "hop_length": 1024,
        "n_mels": 128,
        "center": False,
    }
    mfcc = torchaudio.transforms.MFCC(
        sample_rate=44100,
        n_mfcc=25,
        melkwargs=mel_settings,
    )
    stages = [mfcc, StatisticReduction(), torch.nn.Flatten(-2, -1)]
    return torch.nn.Sequential(*stages)
def load_mir_feature_extractor():
    """Build the feature extractor combining frame-level and spectral features.

    Two branches are merged along ``dim=-2``:

    * a ``Frame(2048, 1024, center=False)`` stage followed by ``LogRMS``,
      ``LogCrest`` and ``LogSpread``;
    * a magnitude ``Spectrogram`` (FFT size 2048, hop 1024, no centering,
      ``power=1``) followed by ``LogSpectralCentroid``,
      ``LogSpectralBandwidth`` and ``LogSpectralFlatness``.

    The merged features then go through ``StatisticReduction`` and a
    ``Flatten`` over the last two dimensions.

    Returns:
        torch.nn.Sequential: The assembled feature extractor.
    """
    frame_size, hop_size = 2048, 1024

    # Branch operating on framed input.
    frame_branch = torch.nn.Sequential(
        Frame(frame_size, hop_size, center=False),
        MapAndMerge([LogRMS(), LogCrest(), LogSpread()], dim=-2),
    )

    # Branch operating on the magnitude spectrogram.
    spectral_branch = torch.nn.Sequential(
        torchaudio.transforms.Spectrogram(
            n_fft=frame_size, hop_length=hop_size, center=False, power=1
        ),
        MapAndMerge(
            [
                LogSpectralCentroid(),
                LogSpectralBandwidth(),
                LogSpectralFlatness(),
            ],
            dim=-2,
        ),
    )

    return torch.nn.Sequential(
        MapAndMerge([frame_branch, spectral_branch], dim=-2),
        StatisticReduction(),
        torch.nn.Flatten(-2, -1),
    )
def get_feature_embeds(
    x: torch.Tensor,
    model: torch.nn.Module,
):
    """Run a feature extractor on a stereo batch and split the result per channel.

    The input is passed to ``model`` unchanged (the hadamard transform of the
    input is currently disabled). The output is L2-normalized along the last
    dimension and then split at index 0 and 1 of its second dimension.

    Args:
        x (torch.Tensor): 3-D tensor whose second dimension has size 2.
        model (torch.nn.Module): Feature extractor applied to ``x``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The normalized embeddings at index
        0 and index 1 of the second dimension.
    """
    _, num_channels, _ = x.shape
    assert num_channels == 2, "MFCC feature extractor expects stereo input"

    # Get embeddings and l2 normalize them in one go.
    features = torch.nn.functional.normalize(model(x), p=2, dim=-1)

    first, second = features[:, 0], features[:, 1]
    return first, second