|
|
from src.metrics.metrics import Metrics |
|
|
import src.utils as utils |
|
|
import argparse |
|
|
import os, json, glob |
|
|
import numpy as np |
|
|
import torch |
|
|
import pandas as pd |
|
|
import torchaudio |
|
|
import matplotlib.pyplot as plt |
|
|
import torch.nn as nn |
|
|
import copy |
|
|
import torch.nn.functional as F |
|
|
from torchmetrics.functional import signal_noise_ratio as snr |
|
|
|
|
|
|
|
|
def mod_pad(x, chunk_size, pad):
    """Zero-pad ``x`` on the right to a multiple of ``chunk_size``, then apply ``pad``.

    Args:
        x: Tensor; only its last dimension is padded by the alignment step.
        chunk_size: The last dimension is extended with zeros until it divides
            evenly by this value.
        pad: Padding spec forwarded unchanged to ``F.pad`` after alignment.

    Returns:
        Tuple ``(padded, mod)`` where ``mod`` is the number of zeros appended by
        the alignment step (0 when the length was already a multiple).
    """
    remainder = x.shape[-1] % chunk_size
    mod = chunk_size - remainder if remainder != 0 else 0

    aligned = F.pad(x, (0, mod))
    return F.pad(aligned, pad), mod
|
|
|
|
|
|
|
|
class LayerNormPermuted(nn.LayerNorm):
    """``nn.LayerNorm`` for channels-first 4-D input.

    ``nn.LayerNorm`` normalizes over the trailing ``normalized_shape`` dims, so the
    input is moved to channels-last, normalized, and moved back. Constructor
    arguments are exactly those of ``nn.LayerNorm`` (typically
    ``normalized_shape == C``).
    """

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]

        Returns:
            Tensor with the same [B, C, T, F] layout as ``x``.
        """
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.permute(0, 3, 1, 2)  # back to [B, C, T, F]
|
|
|
|
|
|
|
|
def save_audio_file_torch(file_path, wavform, sample_rate=16000, rescale=False):
    """Write ``wavform`` to ``file_path`` using ``torchaudio.save``.

    Args:
        file_path: Destination path.
        wavform: Waveform tensor, passed to ``torchaudio.save`` unchanged apart
            from optional rescaling (torchaudio expects ``[channels, time]`` by
            default).
        sample_rate: Sample rate recorded in the file.
        rescale: If True, peak-normalize so the largest absolute sample is 0.9.
            An all-zero waveform is written as-is.
    """
    if rescale:
        # Use the absolute peak: dividing by the signed max flips or explodes
        # signals whose largest excursion is negative.
        peak = wavform.abs().max()
        if peak > 0:
            wavform = wavform / peak * 0.9
    torchaudio.save(file_path, wavform, sample_rate)
|
|
|
|
|
|
|
|
def get_mixture_and_gt(curr_dir, rng, SHIFT_VALUE=0, noise_audio_list=None):
    """Assemble one evaluation example (model inputs and reference targets) from a test-case directory.

    Loads ``metadata.json``, the self-speech track, the other dialogue speakers
    (``target_speech{i}.wav``) and the interference track(s). The interference is
    rescaled so the target sum (self + others) sits at a random SNR in [-10, 10] dB.
    Noise from ``noise_audio_list`` is optionally added. If the mixture's absolute
    peak exceeds 1, all returned signals are divided by that peak together.

    Args:
        curr_dir: Test-case directory.
        rng: ``np.random.RandomState`` (or compatible) driving the SNR draw and the
            noise selection/scale.
        SHIFT_VALUE: Unused; kept for call-site compatibility.
        noise_audio_list: Optional list of noise wav paths; ``None`` or empty
            disables noise.

    Returns:
        ``(inputs, targets, metadata)`` where
        ``inputs`` = {"mixture": all channels (float32), "embed": float32
        embedding (zeros(256) when ``embed.pt`` is missing), "self_speech": first
        channel (float32)};
        ``targets`` = {"self": first channel of self speech (ndarray), "other":
        first channel of summed other speakers (ndarray), "target": first channel
        of self + other (float32 tensor)};
        ``metadata`` is the parsed ``metadata.json``.

    Raises:
        FileNotFoundError: If neither ``self_speech.wav`` nor
            ``self_speech_original.wav`` exists in ``curr_dir``.
    """
    metadata2 = utils.read_json(os.path.join(curr_dir, "metadata.json"))
    diags = metadata2["target_dialogue"]

    self_speech = None
    for fname in ("self_speech.wav", "self_speech_original.wav"):
        path = os.path.join(curr_dir, fname)
        if os.path.exists(path):
            self_speech = utils.read_audio_file_torch(path, 1)
            break
    if self_speech is None:
        raise FileNotFoundError(f"No self_speech.wav or self_speech_original.wav in {curr_dir}")

    # One file per dialogue entry beyond the first: target_speech0.wav ... target_speech{n-2}.wav.
    other_speech = torch.zeros_like(self_speech)
    for i in range(len(diags) - 1):
        other_speech += utils.read_audio_file_torch(os.path.join(curr_dir, f"target_speech{i}.wav"), 1)

    # NOTE: "intereference" spelling must match the dataset's on-disk file names.
    interference_path = os.path.join(curr_dir, "intereference.wav")
    if os.path.exists(interference_path):
        interfere = utils.read_audio_file_torch(interference_path, 1)
    else:
        interfere = torch.zeros_like(self_speech)
        for k in range(2):
            interfere += utils.read_audio_file_torch(os.path.join(curr_dir, f"intereference{k}.wav"), 1)

    gt = self_speech + other_speech
    tgt_snr = rng.uniform(-10, 10)
    interfere = scale_noise_to_snr(gt, interfere, tgt_snr)
    mixture = gt + interfere

    if noise_audio_list:
        print("added noise")
        noise_audio = noise_sample(noise_audio_list, mixture.shape[-1], rng)
        wham_scale = rng.uniform(0, 1)
        mixture += noise_audio * wham_scale

    embed_path = os.path.join(curr_dir, "embed.pt")
    if os.path.exists(embed_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load
        # embeddings from trusted dataset directories.
        # as_tensor accepts both a saved ndarray and a saved tensor.
        embed = torch.as_tensor(torch.load(embed_path, weights_only=False))
    else:
        embed = torch.zeros(256)

    # Joint peak normalization keeps mixture and every reference on the same scale.
    peak = mixture.abs().max()
    if peak > 1:
        mixture /= peak
        self_speech /= peak
        other_speech /= peak
        gt /= peak

    inputs = {
        "mixture": mixture.float(),
        "embed": embed.float(),
        "self_speech": self_speech[0:1, :].float(),
    }
    targets = {
        "self": self_speech[0:1, :].numpy(),
        "other": other_speech[0:1, :].numpy(),
        "target": gt[0:1, :].float(),
    }
    return inputs, targets, metadata2
|
|
|
|
|
|
|
|
def scale_utterance(audio, timestamp, rng, db_change=7):
    """Randomly re-gain individual utterance spans of ``audio`` in place.

    For each ``(start, end)`` span, a coin ``rng.uniform(0, 1)`` is drawn; when it
    is below 0.3 a gain in [-db_change, db_change] dB is drawn and applied to
    ``audio[..., start:end]``.

    Args:
        audio: Array/tensor sliced along its last axis; modified in place.
        timestamp: Iterable of ``(start, end)`` sample-index pairs.
        rng: Object exposing numpy-style ``uniform(low, high)``.
        db_change: Maximum absolute gain change, in dB.

    Returns:
        The same ``audio`` object.
    """
    for seg_start, seg_end in timestamp:
        # Draw order matters for reproducibility: the coin is always drawn,
        # the gain only when the coin succeeds.
        if rng.uniform(0, 1) >= 0.3:
            continue
        gain_db = rng.uniform(-db_change, db_change)
        audio[..., seg_start:seg_end] *= 10 ** (gain_db / 20)
    return audio
|
|
|
|
|
|
|
|
def get_snr(target, mixture, EPS=1e-9):
    """Return the SNR (dB) of ``mixture`` against ``target``, averaged over channels.

    Per-signal scores come from ``torchmetrics.functional.signal_noise_ratio``
    (imported here as ``snr``) and are then averaged. ``EPS`` is accepted for
    backward compatibility but is not used by this function.
    """
    per_channel_db = snr(mixture, target)
    return per_channel_db.mean()
|
|
|
|
|
|
|
|
def scale_noise_to_snr(target_speech: torch.Tensor, noise: torch.Tensor, target_snr: float):
    """Rescale ``noise`` so that ``target_speech + noise`` has SNR ``target_snr`` dB.

    The current SNR is measured with ``get_snr``. Scaling the noise amplitude by
    ``k`` shifts the SNR by ``-20 * log10(k)`` dB, so the required gain is
    ``10 ** ((current - target) / 20)``.

    Returns:
        The scaled noise tensor; inputs are not modified.
    """
    measured_db = get_snr(target_speech, target_speech + noise)
    gain = 10 ** ((measured_db - target_snr) / 20)
    return gain * noise
|
|
|
|
|
|
|
|
def run_testcase(model, inputs, device) -> np.ndarray:
    """Run ``model`` on a single example and return its output as a NumPy array.

    A batch of size 1 is built from ``inputs``. Only the first channel of
    ``"mixture"`` and ``"self_speech"`` is kept, and ``start_idx``/``end_idx`` span
    the whole mixture. The caller's ``inputs`` dict is left untouched.

    Args:
        model: Callable that takes the batch dict and returns a dict with an
            ``"output"`` tensor.
        inputs: Dict with tensors ``"mixture"``, ``"self_speech"`` (presumably
            ``[C, T]`` -- the first axis is sliced to one channel) and ``"embed"``.
        device: Torch device the batch is moved to.

    Returns:
        ``outputs["output"]`` with a leading size-1 axis squeezed, as a CPU ndarray.
    """
    with torch.inference_mode():
        # Work on a shallow copy so the caller's tensors stay unbatched and on
        # their original device.
        batch = dict(inputs)
        batch["mixture"] = inputs["mixture"][0:1, ...].unsqueeze(0).to(device)
        batch["embed"] = inputs["embed"].unsqueeze(0).to(device)
        batch["self_speech"] = inputs["self_speech"][0:1, ...].unsqueeze(0).to(device)
        batch["start_idx"] = 0
        batch["end_idx"] = batch["mixture"].shape[-1]
        outputs = model(batch)

    output_target = outputs["output"].squeeze(0)
    return output_target.cpu().numpy()
|
|
|
|
|
|
|
|
def get_timestamp_mask(timestamps, mask_shape):
    """Return a float tensor of ``mask_shape`` set to 1 inside each span and 0 elsewhere.

    Args:
        timestamps: Iterable of ``(start, end)`` pairs, applied as Python slices
            ``[start:end]`` on the last axis.
        mask_shape: Shape of the returned mask.
    """
    mask = torch.zeros(mask_shape)
    for span in timestamps:
        begin, finish = span
        mask[..., begin:finish] = 1
    return mask
|
|
|
|
|
|
|
|
def noise_sample(noise_file_list, audio_length, rng: np.random.RandomState, target_sr=16000):
    """Build a mono noise clip of exactly ``audio_length`` samples.

    Files are picked with ``rng.choice`` (with replacement). Each is loaded,
    reduced to its first channel and resampled to ``target_sr`` when needed, and
    the pieces are concatenated until their total length exceeds
    ``audio_length``. The result is then truncated.

    Args:
        noise_file_list: Non-empty sequence of audio file paths.
        audio_length: Number of samples to return (must be >= 0).
        rng: Random state used to choose files.
        target_sr: Output sample rate; defaults to the previously hard-coded 16000.

    Returns:
        Tensor of shape ``[1, audio_length]``.

    Raises:
        ValueError: If ``noise_file_list`` is empty or ``audio_length`` is negative.
    """
    if len(noise_file_list) == 0:
        raise ValueError("noise_file_list is empty")
    if audio_length < 0:
        raise ValueError(f"audio_length must be non-negative, got {audio_length}")

    pieces = []
    acc_len = 0
    # `<=` (rather than `<`) keeps the number of rng draws identical to earlier
    # runs, so downstream random values stay reproducible.
    while acc_len <= audio_length:
        noise_file = rng.choice(noise_file_list)
        # torchaudio.load already returns the sample rate; no separate
        # torchaudio.info() pass over the file is needed.
        noise_wav, noise_sr = torchaudio.load(noise_file)
        noise_wav = noise_wav[0:1, ...]

        if noise_sr != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=noise_sr, new_freq=target_sr)
            noise_wav = resampler(noise_wav)

        pieces.append(noise_wav)
        acc_len += noise_wav.shape[-1]

    # A single concatenation avoids re-copying the growing buffer every iteration.
    concatenated_audio = torch.cat(pieces, dim=1)[..., :audio_length]

    # Internal invariant: the loop only exits once acc_len > audio_length.
    assert concatenated_audio.shape[1] == audio_length
    return concatenated_audio
|
|
|
|
|
|
|
|
def _masked_sisdr(si_sdr, timestamps, sr, mix, target, output):
    """Return (mixture SI-SDR, output SI-SDR) restricted to ``timestamps``, excluding the first ``sr`` samples."""
    mask = get_timestamp_mask(timestamps, target.shape)
    mask[..., :sr] = 0
    mix_score = si_sdr(est=mask * mix, gt=mask * target, mix=mask * mix).item()
    pred_score = si_sdr(est=mask * output, gt=mask * target, mix=mask * mix).item()
    return mix_score, pred_score


def _save_case_audio(save_folder, sr, mix, self_speech, output, target):
    """Write mix, self, model-output and target wav files for one test case into ``save_folder``."""
    os.makedirs(save_folder, exist_ok=True)
    self_wav = torch.from_numpy(self_speech)
    if self_wav.dim() == 1:
        self_wav = self_wav.unsqueeze(0)
    save_audio_file_torch(f"{save_folder}/mix.wav", torch.from_numpy(mix), sample_rate=sr, rescale=False)
    save_audio_file_torch(f"{save_folder}/self.wav", self_wav, sample_rate=sr, rescale=False)
    save_audio_file_torch(
        f"{save_folder}/output_target.wav", torch.from_numpy(output), sample_rate=sr, rescale=False
    )
    save_audio_file_torch(
        f"{save_folder}/target_speech.wav", torch.from_numpy(target), sample_rate=sr, rescale=False
    )


def main(args: argparse.Namespace):
    """Evaluate a pretrained model on test cases ``00000``..``00199`` and write a CSV of SI-SDR scores.

    Each case directory ``{test_dir}/{i:05d}`` that contains ``metadata.json`` is
    processed. The mixture/target pair is built with an RNG seeded by ``i`` and the
    model is run on it. SI-SDR is recorded for the unprocessed mixture and for the
    model output in three ways: over the whole clip, restricted to the self
    speaker's timestamps, and restricted to the other speakers' timestamps. The
    restricted scores exclude the first ``args.sr`` samples.

    With ``--save``, per-case audio goes to ``{result_root}/{model_name}/{i}/``.
    The CSV is written to ``{result_root}/{model_name}_multi.csv``. ``result_root``
    is ``./result_{dataset}`` or ``./result_{dataset}_noise`` when noise is used.
    """
    device = "cuda" if args.use_cuda else "cpu"

    # normpath drops a trailing slash, which would otherwise make basename "".
    model_name = os.path.basename(os.path.normpath(args.run_dir))
    dataset_name = os.path.basename(os.path.normpath(args.test_dir))

    model = utils.load_torch_pretrained(args.run_dir).model
    model = model.to(device)
    model.eval()

    si_sdr = Metrics("si_sdr")

    noise_audio_list = []
    if args.noise_dir is not None:
        # Sorted so the seeded rng.choice picks the same files regardless of the
        # filesystem's listing order.
        noise_audio_list = sorted(glob.glob(os.path.join(args.noise_dir, "*.wav")))
        if not noise_audio_list:
            print("no noise file found")

    result_root = f"./result_{dataset_name}_noise" if noise_audio_list else f"./result_{dataset_name}"

    records = []
    for i in range(200):
        curr_dir = os.path.join(args.test_dir, "{:05d}".format(i))
        if not os.path.exists(os.path.join(curr_dir, "metadata.json")):
            continue

        # Seeding by case index makes each case's random SNR/noise reproducible.
        rng = np.random.RandomState(i)
        inputs, targets, metadata = get_mixture_and_gt(curr_dir, rng, noise_audio_list=noise_audio_list)

        # Extract everything needed from `inputs` before running the model.
        mixture = inputs["mixture"].cpu().numpy()
        if mixture.ndim == 1:
            mixture = mixture[np.newaxis, ...]
        mix = mixture[0:1]
        self_speech = inputs["self_speech"].squeeze(0).cpu().numpy()
        target_speech = targets["target"].cpu().numpy()[0:1, ...]

        output_target = run_testcase(model, inputs, device)

        row = {"test_case_index": i}
        row["sisdr_input_total"] = si_sdr(est=mix, gt=target_speech, mix=mix).item()
        row["sisdr_output_total"] = si_sdr(est=output_target, gt=target_speech, mix=mix).item()

        dialogue = metadata["target_dialogue"]
        self_mix, self_pred = _masked_sisdr(
            si_sdr, dialogue[0]["timestamp"], args.sr, mix, target_speech, output_target
        )
        row["sisdr_mix_self"] = self_mix
        row["sisdr_pred_self"] = self_pred

        # Remaining dialogue entries are pooled as "other" speakers; a case with
        # no other speaker simply gets no "other" columns (NaN in the CSV).
        if len(dialogue) > 1:
            other_timestamps = [span for entry in dialogue[1:] for span in entry["timestamp"]]
            other_mix, other_pred = _masked_sisdr(
                si_sdr, other_timestamps, args.sr, mix, target_speech, output_target
            )
            row["sisdr_mix_other"] = other_mix
            row["sisdr_pred_other"] = other_pred

        print(i)
        records.append(row)

        if args.save:
            _save_case_audio(
                f"{result_root}/{model_name}/{i}", args.sr, mix, self_speech, output_target, target_speech
            )

    if not records:
        print(f"No test cases with metadata.json found under {args.test_dir}")
        return

    results_df = pd.DataFrame.from_records(records)
    columns = ["test_case_index"] + [col for col in results_df.columns if col != "test_case_index"]
    results_df = results_df[columns]

    os.makedirs(result_root, exist_ok=True)
    results_csv_path = f"{result_root}/{model_name}_multi.csv"
    results_df.to_csv(results_csv_path, index=False)
|
|
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this evaluation script."""
    cli = argparse.ArgumentParser()

    # Positional locations: test data and trained model.
    cli.add_argument("test_dir", type=str, help="Path to test dataset")
    cli.add_argument("run_dir", type=str, help="Path to model run checkpoint")

    # Optional settings.
    cli.add_argument("--sr", type=int, default=16000, help="Project sampling rate")
    cli.add_argument("--noise_dir", type=str, default=None, help="Wham noise directory")
    cli.add_argument("--use_cuda", action="store_true", help="Whether to use cuda")
    cli.add_argument("--save", action="store_true", help="Whether to save output audio")
    return cli


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
|
|
|