# Author: haoxiangsnr — uploaded via huggingface_hub (commit 50de2e0, verified)
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torchaudio
from accelerate.logging import get_logger
from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.evaluator.haspi import haspi_v2
from clarity.utils.audiogram import Audiogram, Listener
from dataset import DevDataset, DevDatasetArgs, PredictDataset, PredictDatasetArgs, TrainDataset, TrainDatasetArgs
from loss import FreqLoss, MultiResolutionL1SpecLoss, SNRLoss
from model import Model, ModelArgs
from simple_parsing import Serializable, parse
from audiozen.logger import init_logging_logger
from audiozen.metric import DNSMOS, PESQ, SISDR, STOI
from audiozen.trainer_args import TrainingArgs
from audiozen.trainer_v2 import Trainer as BaseTrainer
# Multi-process-aware logger from Accelerate (safe to use under distributed training).
logger = get_logger(__name__)
class Trainer(BaseTrainer):
    """Single-ear speech-enhancement trainer for Clarity-challenge-style data.

    Trains a model at the trainer's working sample rate (``self.sr``) on one ear
    of the binaural input, evaluates with DNSMOS/SI-SDR/PESQ/STOI (and HASPI on
    the dev set after NAL-R amplification), and writes enhanced/amplified WAVs
    for inspection and submission.
    """

    def __init__(self, *args, is_left_ear, loss_function, listeners_file, **kwargs):
        """Build metrics, resamplers, output dirs, and hearing-aid processors.

        Args:
            is_left_ear: If True, process channel 0 (left); otherwise channel 1.
            loss_function: One of "SNRLoss", "FreqLoss", "MultiResolutionL1SpecLoss".
            listeners_file: Path to the Clarity listeners JSON (audiogram definitions).

        Raises:
            ValueError: If ``loss_function`` is not one of the supported names.
        """
        super().__init__(*args, **kwargs)
        # Extract run arguments
        self.ch_idx = 0 if is_left_ear else 1
        # Map the loss name to an instance; unknown names fail fast.
        if loss_function == "SNRLoss":
            self.loss_function = SNRLoss()
        elif loss_function == "FreqLoss":
            self.loss_function = FreqLoss()
        elif loss_function == "MultiResolutionL1SpecLoss":
            self.loss_function = MultiResolutionL1SpecLoss()
        else:
            raise ValueError(f"Invalid loss function: {loss_function}")
        # Objective metrics evaluated at the model sample rate.
        self.dns_mos = DNSMOS(input_sr=self.sr, device=self.process_index)
        self.si_sdr = SISDR()
        self.pesq_wb = PESQ(sr=self.sr)
        self.stoi = STOI(sr=self.sr)
        # Dataset audio is 44.1 kHz; the model runs at self.sr, so resample both ways.
        self.down_sample = torchaudio.transforms.Resample(44100, self.sr, resampling_method="sinc_interp_hann")
        self.down_sample.to(self.device)
        self.up_sample = torchaudio.transforms.Resample(self.sr, 44100, resampling_method="sinc_interp_hann")
        self.up_sample.to(self.device)
        # For prediction
        self.enhanced_signals_dir = self.output_dir / "enhanced_signals"
        self.amplified_dir = self.output_dir / "amplified_signals"
        self.amplified_dir.mkdir(parents=True, exist_ok=True)
        self.enhanced_signals_dir.mkdir(parents=True, exist_ok=True)
        # listener_id -> Listener (audiograms for both ears).
        self.listener_dict = Listener.load_listener_dict(listeners_file)
        # NAL-R prescription filter + compressor: the standard Clarity hearing-aid
        # amplification stage applied before HASPI scoring / submission.
        self.enhancer = NALR(nfir=220, sample_rate=44100)
        self.compressor = Compressor(threshold=0.35, attenuation=0.1, attack=50, release=1000, rms_buffer_size=0.064)

    def training_step(self, batch, batch_idx):
        """One optimization step: forward, backward, (clipped) optimizer update.

        Returns a dict with the scalar loss and, on gradient-sync steps, the
        gradient norm before clipping.
        """
        mix, ref, enroll, scene, length_list = batch
        # Keep only the ear this trainer is responsible for.
        ref = ref[:, self.ch_idx, :]
        mix = self.down_sample(mix)  # [B, 6, T]
        ref = self.down_sample(ref)  # [B, T]
        enroll = self.down_sample(enroll)  # [B, T]
        # NOTE(review): the model computes the loss internally here; the
        # self.loss_function selected in __init__ is presumably consumed by the
        # model or unused on this path — confirm.
        *_, loss = self.model(mix, enroll, ref)  # [B, T]
        self.accelerator.backward(loss)
        # Only clip/step when gradients are synchronized (gradient accumulation).
        if self.accelerator.sync_gradients:
            norm_before = self.accelerator.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return {
            "loss": loss.detach().cpu().numpy(),
            "norm_before": norm_before.detach().cpu().numpy() if self.accelerator.sync_gradients else 0.0,
        }

    @staticmethod
    def amplify_signal(signal, audiogram: Audiogram, enhancer: NALR, compressor: Compressor):
        """Amplify a single-channel signal for a given audiogram.

        Applies the NAL-R prescription FIR built for the audiogram, then the
        compressor; returns the processed signal.
        """
        nalr_fir, _ = enhancer.build(audiogram)
        out = enhancer.apply(nalr_fir, signal)
        out, _, _ = compressor.process(out)
        return out

    @torch.no_grad()
    def prediction_step(self, batch, batch_idx):
        """Enhance one test scene and write enhanced + HA-amplified WAVs at 44.1 kHz.

        Expects batch size 1 (scene/listener ids are taken from index 0).
        """
        y_mix, y_enroll, listener_id, scene = batch
        scene = scene[0]
        listener_id = listener_id[0]
        y_mix = y_mix[:, self.ch_idx, :]
        y_mix = self.down_sample(y_mix)
        y_enroll = self.down_sample(y_enroll)
        y_est = self.model(y_mix, y_enroll)
        # Convert the signal back to 44100 Hz
        y_est = self.up_sample(y_est)
        # Mock two channels — the same enhanced signal is used for both ears.
        y_est = torch.stack([y_est, y_est], dim=1)  # [B, 2, T]
        # Convert to numpy
        y_est = y_est.squeeze(0).detach().cpu().numpy()  # [2, T]
        # Save enhanced signals
        sf.write(self.enhanced_signals_dir / f"{scene}_{listener_id}_enhanced.wav", y_est.T, 44100)
        listener = self.listener_dict[listener_id]
        # Amplify the signal per-ear with the listener's own audiograms.
        y_est_amp_l = self.amplify_signal(y_est[0], listener.audiogram_left, self.enhancer, self.compressor)
        y_est_amp_r = self.amplify_signal(y_est[1], listener.audiogram_right, self.enhancer, self.compressor)
        y_est_amp = np.stack([y_est_amp_l, y_est_amp_r], axis=0)  # [2, T]
        y_est_amp = y_est_amp.astype(np.float32)
        # Save amplified signals
        sf.write(self.amplified_dir / f"{scene}_{listener_id}_HA-output.wav", y_est_amp.T, 44100)

    @torch.no_grad()
    def evaluation_step(self, batch, batch_idx, dataloader_id):
        """Evaluate one scene: compute metrics (plus HASPI on the dev loader).

        Delegates to prediction_step when running in predict mode. Returns a
        one-element list with a dict of metric values for this scene.
        """
        if self.args.do_predict:
            return self.prediction_step(batch, batch_idx)
        # Unwrap features from the batch
        y_mix, y_ref, y_enroll, listener_id, scene = batch
        scene = scene[0]
        listener_id = listener_id[0]
        y_ref = y_ref[:, self.ch_idx, :]
        # NOTE(review): the "_16k" names and the hard-coded 16000 below assume
        # self.sr == 16000 — confirm, or derive the write rate from self.sr.
        y_mix_16k = self.down_sample(y_mix)  # [B, 6, T]
        y_enroll_16k = self.down_sample(y_enroll)
        y_ref_16k = self.down_sample(y_ref)
        num_samples = y_mix_16k.shape[-1]
        # Forward pass
        logits, loss = self.model(y_mix_16k, y_enroll_16k, y_ref_16k)
        # Decode on the unwrapped model to reach the underlying module's method.
        y_est_16k = self.accelerator.unwrap_model(self.model).decode(logits, num_samples)
        y_est = self.up_sample(y_est_16k)  # Before converting to numpy
        # 16KHz metrics
        y_est_16k = y_est_16k.squeeze(0).detach().cpu().numpy()  # [T]
        y_ref_16k = y_ref_16k.squeeze(0).detach().cpu().numpy()  # [T]
        dns_mos = self.dns_mos(y_est_16k)
        si_sdr = self.si_sdr(y_est_16k, y_ref_16k)
        pesq_wb = self.pesq_wb(y_est_16k, y_ref_16k)
        stoi = self.stoi(y_est_16k, y_ref_16k)
        # Each metric helper apparently returns a dict of named scores; merge them.
        return_dict = dns_mos | si_sdr | pesq_wb | stoi | {"loss": loss.detach().cpu().numpy()}
        # Save a few example signals per epoch for listening checks.
        if batch_idx < 4:
            sf.write(self.enhanced_signals_dir / f"{scene}_{listener_id}_enhanced.wav", y_est_16k, 16000)
        if dataloader_id == "dev":
            # If we are evaluating on the dev set, we are able to compute the HASPI score as well
            y_est = y_est.squeeze(0).detach().cpu().numpy()  # [T]
            y_ref = y_ref.squeeze(0).detach().cpu().numpy()  # [T]
            # Amplify the signal
            # NOTE(review): always uses audiogram_left even when self.ch_idx
            # selects the right ear — confirm this is intended.
            listener = self.listener_dict[listener_id]
            y_est_amp = self.amplify_signal(y_est, listener.audiogram_left, self.enhancer, self.compressor)
            y_est_amp = y_est_amp.astype(np.float32)
            # Save amplified signals
            if batch_idx < 4:
                sf.write(self.amplified_dir / f"{scene}_{listener_id}_HA-output.wav", y_est_amp, 44100)
            # Compute the metrics on the 44.1 kHz signals
            haspi_v2_score, _ = haspi_v2(
                reference=y_ref,
                reference_sample_rate=44100,
                processed=y_est,
                processed_sample_rate=44100,
                audiogram=listener.audiogram_left,
                level1=100.0,
            )
            return_dict |= {"haspi_v2": haspi_v2_score}
        return [return_dict]

    def evaluation_epoch_end(self, outputs, log_to_tensorboard=True):
        """Aggregate per-step metric dicts, save CSVs, and return the epoch score.

        Args:
            outputs: Mapping of dataloader id -> list of step outputs (each a
                list of metric dicts, as returned by evaluation_step).
            log_to_tensorboard: If True, write per-metric means to TensorBoard.

        Returns:
            Sum over dataloaders of the mean ``metric_for_best_model`` value
            (only accumulated during training; 0.0 otherwise).
        """
        # We use this variable to store the score for the current epoch
        score = 0.0
        for dl_id, dataloader_outputs in outputs.items():
            metric_dict_list = []
            # Flatten: each step output is itself a list of metric dicts.
            for i, step_output in enumerate(dataloader_outputs):
                metric_dict_list += step_output
            # Use pandas to compute the mean of all metrics and save them to a csv file
            df_metrics = pd.DataFrame(metric_dict_list)
            df_metrics_mean = df_metrics.mean(numeric_only=True)
            df_metrics_mean_df = df_metrics_mean.to_frame().T  # Convert mean to a DataFrame
            # Timestamp the files so repeated evaluations don't overwrite each other.
            time_now = self._get_time_now()
            df_metrics.to_csv(
                self.metrics_dir / f"dl_{dl_id}_epoch_{self.state.epochs_trained}_{time_now}.csv",
                index=False,
            )
            df_metrics_mean_df.to_csv(
                self.metrics_dir / f"dl_{dl_id}_epoch_{self.state.epochs_trained}_{time_now}_mean.csv",
                index=False,
            )
            logger.info(f"\n{df_metrics_mean_df.to_markdown()}")
            # We use the `metric_for_best_model` to compute the score. In this case, it is the `si_sdr`.
            if self.is_in_train:
                score += df_metrics_mean[self.args.metric_for_best_model]
            if log_to_tensorboard:
                for metric, value in df_metrics_mean.items():
                    self.writer.add_scalar(f"metrics_{dl_id}/{metric}", value, self.state.epochs_trained)
        return score
# ==================== Main ====================
@dataclass
class Args(Serializable):
    """Top-level experiment configuration, parsed from the CLI and/or a YAML config."""

    trainer: TrainingArgs  # generic training-loop settings (output dir, do_train/do_eval/do_predict, ...)
    model: ModelArgs  # model architecture hyperparameters
    train_dataset: TrainDatasetArgs  # dataset used for optimization
    eval_train_dataset: TrainDatasetArgs  # train-split subset used during evaluation
    eval_dev_dataset: DevDatasetArgs  # dev-split dataset (enables HASPI scoring)
    predict_dataset: PredictDatasetArgs  # test-split dataset for prediction/submission
    is_left_ear: bool = True  # True: process channel 0 (left ear); False: channel 1 (right)
    loss_function: str = "SNRLoss"  # one of SNRLoss / FreqLoss / MultiResolutionL1SpecLoss
    listeners_file: str = ""  # path to the Clarity listeners JSON (audiograms)
def run(args: Args):
    """Build model, datasets, and trainer from *args*, then run exactly one mode.

    Mode precedence (first matching flag wins): evaluate, predict, train.

    Raises:
        ValueError: If none of ``do_train``/``do_eval``/``do_predict`` is set.
    """
    # Logging first, so everything below is captured in the run's log files.
    init_logging_logger(args.trainer.output_dir)
    # Persist the fully-resolved configuration alongside the run outputs.
    args.save(Path(args.trainer.output_dir) / "conf.yaml")

    # Model and datasets.
    model = Model(args.model)
    train_dataset = TrainDataset(args.train_dataset)
    eval_datasets = {"train": TrainDataset(args.eval_train_dataset), "dev": DevDataset(args.eval_dev_dataset)}
    predict_dataset = PredictDataset(args.predict_dataset)

    # In predict mode the trainer iterates the prediction set instead of the
    # evaluation dataloaders.
    if args.trainer.do_predict:
        eval_dataset = predict_dataset
    else:
        eval_dataset = eval_datasets

    trainer = Trainer(
        is_left_ear=args.is_left_ear,
        loss_function=args.loss_function,
        listeners_file=args.listeners_file,
        model=model,
        args=args.trainer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Dispatch on the first enabled mode flag, in priority order.
    dispatch = (
        (args.trainer.do_eval, trainer.evaluate),
        (args.trainer.do_predict, trainer.predict),
        (args.trainer.do_train, trainer.train),
    )
    for enabled, action in dispatch:
        if enabled:
            action()
            break
    else:
        raise ValueError("At least one of `do_train`, `do_eval`, or `do_predict` must be True.")
if __name__ == "__main__":
    # Parse CLI arguments (optionally merged from a YAML config) and launch.
    run(parse(Args, add_config_path_arg=True))