import datetime
import json
import logging
import os
import pickle
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import yaml

from models.audiosep import AudioSep, get_model_class


def ignore_warnings():
    import warnings

    # Silence a known, harmless UserWarning raised inside torch.functional.
    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    # Silence the notice about unused lm_head weights when loading the
    # roberta-base checkpoint.
    pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=pattern)


def create_logging(log_dir, filemode):
    r"""Create a root logger that writes to a new, incrementally numbered log
    file in ``log_dir`` and also echoes INFO-level messages to the console."""
    os.makedirs(log_dir, exist_ok=True)

    # Pick the first unused file name: 0000.log, 0001.log, ...
    i1 = 0
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1
    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Mirror INFO-level messages to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging
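

# Usage sketch (the log directory is an assumption, not from the original
# source): the first call writes to ./workspace/logs/0000.log and logs to
# both the file and the console.
#
#   logging = create_logging("./workspace/logs", filemode="w")
#   logging.info("training started")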


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    # Clip to [-1, 1] before scaling so the result fits in int16.
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    return (x / 32767.0).astype(np.float32)
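

# Round-trip sketch: the conversion is lossy only by int16 quantization
# (and by clipping for inputs outside [-1, 1]).
#
#   x = np.array([0.5, -1.2, 0.0], dtype=np.float32)
#   y = float32_to_int16(x)    # -> array([ 16383, -32767, 0], dtype=int16)
#   int16_to_float32(y)        # -> approx. array([ 0.5, -1.0, 0.0], dtype=float32)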


def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse yaml file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get AudioSet 632 classes ID to label mapping."""
    audioset632_id_to_lb = {}

    with open(ontology_path) as f:
        data_list = json.load(f)

    # Each ontology entry maps a class ID (e.g., "/m/09x0r") to its name.
    for e in data_list:
        audioset632_id_to_lb[e["id"]] = e["name"]

    return audioset632_id_to_lb


def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type (str): e.g., "Cnn14"
        checkpoint_path (str): e.g., "Cnn14_mAP=0.431.pth"
        freeze (bool): freeze the loaded weights

    Returns:
        model: nn.Module
    """
    # NOTE: Cnn14 and Cnn14_DecisionLevelMax are the PANNs architectures
    # (https://github.com/qiuqiangkong/audioset_tagging_cnn). They are assumed
    # to be imported into this module's namespace; the original file does not
    # show the import.
    if model_type == "Cnn14":
        Model = Cnn14
    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax
    else:
        raise NotImplementedError

    # Standard PANNs hyperparameters for the 32 kHz AudioSet models
    # (527 classes).
    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
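

# Usage sketch (the checkpoint file name follows the docstring above; PANNs
# checkpoints are distributed separately):
#
#   panns = load_pretrained_panns(
#       model_type="Cnn14",
#       checkpoint_path="Cnn14_mAP=0.431.pth",
#       freeze=True,
#   )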


def energy(x):
    # Mean squared amplitude of a tensor.
    return torch.mean(x ** 2)


def magnitude_to_db(x):
    # x is a scalar magnitude; eps avoids log10(0).
    eps = 1e-10
    return 20. * np.log10(max(x, eps))


def db_to_magnitude(x):
    return 10. ** (x / 20)
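

# Round-trip sketch: magnitude_to_db(0.5) is about -6.02 dB, and
# db_to_magnitude(-6.02) is about 0.5. Note that magnitude_to_db expects a
# scalar (it uses the builtin max), so apply it elementwise for arrays.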


def ids_to_hots(ids, classes_num, device):
    # Build a multi-hot vector with ones at the given class indexes.
    hots = torch.zeros(classes_num).to(device)
    for class_id in ids:
        hots[class_id] = 1
    return hots
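

# Example: ids_to_hots([1, 3], classes_num=5, device="cpu")
# -> tensor([0., 1., 0., 1., 0.])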


def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10
) -> float:
    r"""Calculate SDR between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal

    Returns:
        sdr (float): signal-to-distortion ratio in dB
    """
    reference = ref
    noise = est - reference

    # Clip both energies at eps so the ratio and the log are always defined.
    numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)

    sdr = 10. * np.log10(numerator / denominator)

    return sdr
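

# Worked example: if the estimate differs from the reference by noise with
# one-hundredth of the reference's power, the SDR is 10 * log10(100) = 20 dB.
#
#   rng = np.random.default_rng(0)
#   ref = rng.standard_normal(32000)
#   est = ref + 0.1 * rng.standard_normal(32000)
#   calculate_sdr(ref=ref, est=est)   # approx. 20 dB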


def calculate_sisdr(ref, est):
    r"""Calculate scale-invariant SDR (SI-SDR) between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    eps = np.finfo(ref.dtype).eps

    reference = ref.copy()
    estimate = est.copy()

    reference = reference.reshape(reference.size, 1)
    estimate = estimate.reshape(estimate.size, 1)

    Rss = np.dot(reference.T, reference)

    # Optimal scaling of the reference: a = <ref, est> / ||ref||^2.
    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    # Decompose the estimate into the scaled reference and a residual.
    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true ** 2).sum()
    Snn = (e_res ** 2).sum()

    sisdr = 10 * np.log10((eps + Sss) / (eps + Snn))

    return sisdr
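

# SI-SDR is invariant to rescaling of the estimate, unlike plain SDR: a
# scaled-down but otherwise perfect copy of the reference still scores a
# near-infinite SI-SDR (limited only by numerical precision), while plain
# SDR penalizes the gain error.
#
#   ref = np.random.default_rng(0).standard_normal(32000)
#   calculate_sisdr(ref, 0.5 * ref)          # very large (residual is ~0)
#   calculate_sdr(ref=ref, est=0.5 * ref)    # about 6 dB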


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        # A timestamped backup path, written alongside the main pickle.
        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        # Use context managers so the pickle files are closed promptly.
        with open(self.statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        with open(self.backup_statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))


def get_mean_sdr_from_dict(sdris_dict):
    # nanmean skips entries where the SDR could not be computed.
    mean_sdr = np.nanmean(list(sdris_dict.values()))
    return mean_sdr


def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames."""
    # Split the audio into non-overlapping 100 ms frames; a frame is kept if
    # its peak amplitude exceeds the threshold.
    window_size = int(sample_rate * 0.1)
    threshold = 0.02

    frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T

    new_frames = get_active_frames(frames, threshold)

    new_audio = new_frames.flatten()

    return new_audio


def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Get active frames."""
    # Peak absolute amplitude per frame.
    energy = np.max(np.abs(frames), axis=-1)

    active_indexes = np.where(energy > threshold)[0]

    new_frames = frames[active_indexes]

    return new_frames


def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Tile audio until it covers segment_samples, then truncate."""
    repeats_num = (segment_samples // audio.shape[-1]) + 1
    audio = np.tile(audio, repeats_num)[0 : segment_samples]

    return audio
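

# Example: repeat_to_length(np.array([1, 2, 3]), segment_samples=7)
# -> array([1, 2, 3, 1, 2, 3, 1])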


def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    # Compute SDR over consecutive non-overlapping segments of hop_samples
    # (inputs are (channels, samples)), then summarize with the median,
    # which is robust to outlier segments.
    min_len = min(ref.shape[-1], est.shape[-1])
    pointer = 0
    sdrs = []
    while pointer + hop_samples < min_len:
        sdr = calculate_sdr(
            ref=ref[:, pointer : pointer + hop_samples],
            est=est[:, pointer : pointer + hop_samples],
        )
        sdrs.append(sdr)
        pointer += hop_samples

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    else:
        return sdr


def loudness(data, input_loudness, target_loudness):
    """Loudness normalize a signal.

    Normalize an input signal to a target loudness in dB LUFS.

    Parameters
    ----------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float
        Loudness of the input in dB LUFS.
    target_loudness : float
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    # Convert the loudness difference in dB to a linear gain.
    delta_loudness = target_loudness - input_loudness
    gain = torch.pow(10.0, delta_loudness / 20.0)

    output = gain * data

    return output
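

# Example: lowering a signal from -20 LUFS to -26 LUFS applies a linear gain
# of 10 ** (-6 / 20), i.e. roughly 0.501:
#
#   x = torch.randn(2, 32000)
#   y = loudness(x, input_loudness=-20.0, target_loudness=-26.0)   # y ~= 0.501 * x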


def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module
) -> nn.Module:
    r"""Load trained universal source separation model.

    Args:
        configs (Dict): parsed training config
        checkpoint_path (str): path of the checkpoint to load
        query_encoder (nn.Module): encoder for the separation query

    Returns:
        pl_model: pl.LightningModule
    """
    ss_model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]

    # Build the separation backbone named in the config.
    SsModel = get_model_class(model_type=ss_model_type)

    ss_model = SsModel(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Load the Lightning checkpoint; training-only components are omitted.
    pl_model = AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )

    return pl_model
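

# Usage sketch (the config and checkpoint paths are assumptions, and the
# query encoder depends on the project setup):
#
#   configs = parse_yaml("config/audiosep_base.yaml")
#   pl_model = load_ss_model(
#       configs=configs,
#       checkpoint_path="checkpoint/audiosep_base.ckpt",
#       query_encoder=query_encoder,
#   )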