| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import soundfile as sf |
| import torch |
| import torchaudio |
| from accelerate.logging import get_logger |
| from clarity.enhancer.compressor import Compressor |
| from clarity.enhancer.nalr import NALR |
| from clarity.evaluator.haspi import haspi_v2 |
| from clarity.utils.audiogram import Audiogram, Listener |
| from dataset import DevDataset, DevDatasetArgs, PredictDataset, PredictDatasetArgs, TrainDataset, TrainDatasetArgs |
| from loss import FreqLoss, MultiResolutionL1SpecLoss, SNRLoss |
| from model import Model, ModelArgs |
| from simple_parsing import Serializable, parse |
|
|
| from audiozen.logger import init_logging_logger |
| from audiozen.metric import DNSMOS, PESQ, SISDR, STOI |
| from audiozen.trainer_args import TrainingArgs |
| from audiozen.trainer_v2 import Trainer as BaseTrainer |
|
|
|
|
# Module-level logger from `accelerate.logging` (safe to use in multi-process runs).
logger = get_logger(__name__)
|
|
|
|
class Trainer(BaseTrainer):
    """Single-ear speech-enhancement trainer for a Clarity-style hearing-aid task.

    Extends the audiozen base trainer with:
      * a selectable training loss (SNR / frequency-domain / multi-resolution L1 spectral),
      * intrusive/non-intrusive metrics (DNSMOS, SI-SDR, PESQ-WB, STOI) at ``self.sr``,
      * 44.1 kHz <-> ``self.sr`` resampling around the model,
      * NAL-R amplification + compression and HASPI scoring for the "dev" loader.
    """

    def __init__(self, *args, is_left_ear, loss_function, listeners_file, **kwargs):
        """Build trainer state.

        Args:
            is_left_ear: if True, channel index 0 of stereo references/mixtures is
                used; otherwise channel index 1.
            loss_function: one of "SNRLoss", "FreqLoss", "MultiResolutionL1SpecLoss".
            listeners_file: path passed to ``Listener.load_listener_dict`` to map
                listener ids to audiograms.

        Raises:
            ValueError: if ``loss_function`` is not one of the three known names.
        """
        super().__init__(*args, **kwargs)

        # Which channel of the stereo signals this run operates on.
        self.ch_idx = 0 if is_left_ear else 1

        # Select the training loss by name; unknown names fail fast.
        if loss_function == "SNRLoss":
            self.loss_function = SNRLoss()
        elif loss_function == "FreqLoss":
            self.loss_function = FreqLoss()
        elif loss_function == "MultiResolutionL1SpecLoss":
            self.loss_function = MultiResolutionL1SpecLoss()
        else:
            raise ValueError(f"Invalid loss function: {loss_function}")

        # Objective metrics, all configured for the model sample rate ``self.sr``.
        self.dns_mos = DNSMOS(input_sr=self.sr, device=self.process_index)
        self.si_sdr = SISDR()
        self.pesq_wb = PESQ(sr=self.sr)
        self.stoi = STOI(sr=self.sr)

        # Source audio is 44.1 kHz; the model runs at ``self.sr`` (downsample in,
        # upsample out). Both transforms are moved to the trainer's device.
        self.down_sample = torchaudio.transforms.Resample(44100, self.sr, resampling_method="sinc_interp_hann")
        self.down_sample.to(self.device)

        self.up_sample = torchaudio.transforms.Resample(self.sr, 44100, resampling_method="sinc_interp_hann")
        self.up_sample.to(self.device)

        # Output folders for raw enhanced signals and listener-amplified signals.
        self.enhanced_signals_dir = self.output_dir / "enhanced_signals"
        self.amplified_dir = self.output_dir / "amplified_signals"
        self.amplified_dir.mkdir(parents=True, exist_ok=True)
        self.enhanced_signals_dir.mkdir(parents=True, exist_ok=True)
        # listener_id -> Listener (audiograms for left/right ear).
        self.listener_dict = Listener.load_listener_dict(listeners_file)
        # NAL-R prescription filter and dynamic-range compressor operate at 44.1 kHz.
        self.enhancer = NALR(nfir=220, sample_rate=44100)
        self.compressor = Compressor(threshold=0.35, attenuation=0.1, attack=50, release=1000, rms_buffer_size=0.064)

    def training_step(self, batch, batch_idx):
        """One optimization step: forward, backward, clip, optimizer update.

        Returns a dict of numpy scalars: the loss, and the pre-clipping gradient
        norm (0.0 on accumulation steps where gradients were not synchronized).
        """
        mix, ref, enroll, scene, length_list = batch

        # Keep only the ear this run is trained for.
        # NOTE(review): only ``ref`` is channel-selected here, ``mix`` is not —
        # presumably ``mix`` is already single-channel in the train set; confirm
        # against TrainDataset (prediction/eval paths do select mix's channel).
        ref = ref[:, self.ch_idx, :]

        # 44.1 kHz -> model rate.
        mix = self.down_sample(mix)
        ref = self.down_sample(ref)
        enroll = self.down_sample(enroll)

        # Model computes its own loss when given the reference; only the last
        # element of the returned tuple is used here.
        *_, loss = self.model(mix, enroll, ref)

        self.accelerator.backward(loss)

        # Gradient norm is only available when accelerate actually synchronized
        # gradients this step (i.e. not in the middle of accumulation).
        if self.accelerator.sync_gradients:
            norm_before = self.accelerator.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

        return {
            "loss": loss.detach().cpu().numpy(),
            # Safe: the conditional short-circuits, so ``norm_before`` is only
            # read when it was bound above under the same flag.
            "norm_before": norm_before.detach().cpu().numpy() if self.accelerator.sync_gradients else 0.0,
        }

    @staticmethod
    def amplify_signal(signal, audiogram: Audiogram, enhancer: NALR, compressor: Compressor):
        """Amplify signal for a given audiogram: NAL-R filtering then compression."""
        nalr_fir, _ = enhancer.build(audiogram)
        out = enhancer.apply(nalr_fir, signal)
        out, _, _ = compressor.process(out)
        return out

    @torch.no_grad()
    def prediction_step(self, batch, batch_idx):
        """Enhance one scene and write enhanced + listener-amplified wavs.

        Side effects only (two 44.1 kHz wav files per scene); returns None.
        Assumes batch size 1 (``scene[0]`` / ``listener_id[0]``).
        """
        y_mix, y_enroll, listener_id, scene = batch
        scene = scene[0]
        listener_id = listener_id[0]

        # Select the ear's channel, then resample to the model rate.
        y_mix = y_mix[:, self.ch_idx, :]

        y_mix = self.down_sample(y_mix)
        y_enroll = self.down_sample(y_enroll)

        # Inference call (no reference): model returns the enhanced signal.
        y_est = self.model(y_mix, y_enroll)

        # Back to 44.1 kHz for file output and amplification.
        y_est = self.up_sample(y_est)

        # Duplicate the mono estimate into a 2-channel signal (same waveform in
        # both ears before per-ear amplification below).
        y_est = torch.stack([y_est, y_est], dim=1)

        # Drop the batch dim -> numpy array of shape (2, num_samples).
        y_est = y_est.squeeze(0).detach().cpu().numpy()

        # soundfile expects (frames, channels), hence the transpose.
        sf.write(self.enhanced_signals_dir / f"{scene}_{listener_id}_enhanced.wav", y_est.T, 44100)

        listener = self.listener_dict[listener_id]

        # Per-ear NAL-R + compression using the listener's audiograms.
        y_est_amp_l = self.amplify_signal(y_est[0], listener.audiogram_left, self.enhancer, self.compressor)
        y_est_amp_r = self.amplify_signal(y_est[1], listener.audiogram_right, self.enhancer, self.compressor)

        y_est_amp = np.stack([y_est_amp_l, y_est_amp_r], axis=0)
        y_est_amp = y_est_amp.astype(np.float32)

        # Challenge-style output naming: "<scene>_<listener>_HA-output.wav".
        sf.write(self.amplified_dir / f"{scene}_{listener_id}_HA-output.wav", y_est_amp.T, 44100)

    @torch.no_grad()
    def evaluation_step(self, batch, batch_idx, dataloader_id):
        """Evaluate one scene: loss + metrics, plus HASPI on the "dev" loader.

        Returns a single-element list containing the metric dict (the epoch-end
        hook concatenates these lists). In predict mode, delegates to
        ``prediction_step`` (which returns None).
        """
        if self.args.do_predict:
            return self.prediction_step(batch, batch_idx)

        # Batch size 1 assumed, as in prediction_step.
        y_mix, y_ref, y_enroll, listener_id, scene = batch
        scene = scene[0]
        listener_id = listener_id[0]

        y_ref = y_ref[:, self.ch_idx, :]

        # Resample everything to the model rate for the forward pass/metrics.
        y_mix_16k = self.down_sample(y_mix)
        y_enroll_16k = self.down_sample(y_enroll)
        y_ref_16k = self.down_sample(y_ref)
        num_samples = y_mix_16k.shape[-1]

        # With a reference provided, the model returns (logits, loss); the
        # waveform is reconstructed via the unwrapped model's ``decode``.
        logits, loss = self.model(y_mix_16k, y_enroll_16k, y_ref_16k)
        y_est_16k = self.accelerator.unwrap_model(self.model).decode(logits, num_samples)

        y_est = self.up_sample(y_est_16k)

        # Metrics are computed at the model rate on numpy arrays.
        y_est_16k = y_est_16k.squeeze(0).detach().cpu().numpy()
        y_ref_16k = y_ref_16k.squeeze(0).detach().cpu().numpy()
        dns_mos = self.dns_mos(y_est_16k)
        si_sdr = self.si_sdr(y_est_16k, y_ref_16k)
        pesq_wb = self.pesq_wb(y_est_16k, y_ref_16k)
        stoi = self.stoi(y_est_16k, y_ref_16k)

        # Merge all metric dicts with the scalar loss into one flat dict.
        return_dict = dns_mos | si_sdr | pesq_wb | stoi | {"loss": loss.detach().cpu().numpy()}

        # Keep a few audio examples per epoch for listening checks.
        if batch_idx < 4:
            sf.write(self.enhanced_signals_dir / f"{scene}_{listener_id}_enhanced.wav", y_est_16k, 16000)

        if dataloader_id == "dev":
            # HASPI is computed at 44.1 kHz on the amplified signal's reference ear.
            y_est = y_est.squeeze(0).detach().cpu().numpy()
            y_ref = y_ref.squeeze(0).detach().cpu().numpy()

            # NOTE(review): this always uses ``audiogram_left`` for both
            # amplification and HASPI, even when ``is_left_ear`` is False
            # (ch_idx == 1). Confirm whether the right-ear run should use
            # ``listener.audiogram_right`` here.
            listener = self.listener_dict[listener_id]
            y_est_amp = self.amplify_signal(y_est, listener.audiogram_left, self.enhancer, self.compressor)
            y_est_amp = y_est_amp.astype(np.float32)

            if batch_idx < 4:
                sf.write(self.amplified_dir / f"{scene}_{listener_id}_HA-output.wav", y_est_amp, 44100)

            # NOTE(review): HASPI is scored on the *unamplified* estimate
            # ``y_est``, not ``y_est_amp`` — verify this is intentional.
            haspi_v2_score, _ = haspi_v2(
                reference=y_ref,
                reference_sample_rate=44100,
                processed=y_est,
                processed_sample_rate=44100,
                audiogram=listener.audiogram_left,
                level1=100.0,
            )

            return_dict |= {"haspi_v2": haspi_v2_score}

        return [return_dict]

    def evaluation_epoch_end(self, outputs, log_to_tensorboard=True):
        """Aggregate per-step metric dicts, write per-loader CSVs, log means.

        Args:
            outputs: mapping of dataloader id -> list of evaluation_step outputs
                (each a list of metric dicts).
            log_to_tensorboard: if True, write each mean metric as a scalar.

        Returns:
            Sum over dataloaders of the mean ``metric_for_best_model`` value
            when training, otherwise 0.0.
        """
        score = 0.0

        for dl_id, dataloader_outputs in outputs.items():
            metric_dict_list = []
            # Each step output is itself a list (see evaluation_step), so extend.
            for i, step_output in enumerate(dataloader_outputs):
                metric_dict_list += step_output

            # Per-scene table plus its column-wise mean (numeric columns only).
            df_metrics = pd.DataFrame(metric_dict_list)
            df_metrics_mean = df_metrics.mean(numeric_only=True)
            df_metrics_mean_df = df_metrics_mean.to_frame().T

            # Persist both the full table and the mean row, timestamped.
            time_now = self._get_time_now()
            df_metrics.to_csv(
                self.metrics_dir / f"dl_{dl_id}_epoch_{self.state.epochs_trained}_{time_now}.csv",
                index=False,
            )
            df_metrics_mean_df.to_csv(
                self.metrics_dir / f"dl_{dl_id}_epoch_{self.state.epochs_trained}_{time_now}_mean.csv",
                index=False,
            )

            logger.info(f"\n{df_metrics_mean_df.to_markdown()}")

            # During training, the best-model selection score accumulates the
            # chosen metric across every evaluation dataloader.
            if self.is_in_train:
                score += df_metrics_mean[self.args.metric_for_best_model]

            if log_to_tensorboard:
                for metric, value in df_metrics_mean.items():
                    self.writer.add_scalar(f"metrics_{dl_id}/{metric}", value, self.state.epochs_trained)

        return score
|
|
|
|
| |
@dataclass
class Args(Serializable):
    """Top-level experiment configuration, parsed from CLI/config by simple_parsing."""

    # Trainer/runtime options (output dir, do_train/do_eval/do_predict, etc.).
    trainer: TrainingArgs
    # Model hyperparameters.
    model: ModelArgs
    # Dataset used for optimization.
    train_dataset: TrainDatasetArgs
    # Held-out slice of the training data evaluated under the "train" loader id.
    eval_train_dataset: TrainDatasetArgs
    # Dev set evaluated under the "dev" loader id (HASPI is computed there).
    eval_dev_dataset: DevDatasetArgs
    # Dataset used when do_predict is set.
    predict_dataset: PredictDatasetArgs
    # True -> process channel 0 (left ear); False -> channel 1 (right ear).
    is_left_ear: bool = True
    # One of "SNRLoss", "FreqLoss", "MultiResolutionL1SpecLoss".
    loss_function: str = "SNRLoss"
    # Path to the listener/audiogram JSON consumed by Listener.load_listener_dict.
    listeners_file: str = ""
|
|
|
|
def run(args: Args):
    """Wire up logging, config persistence, model, datasets, and the trainer,
    then dispatch to exactly one of evaluate / predict / train.

    Raises:
        ValueError: when none of ``do_train``/``do_eval``/``do_predict`` is set.
    """
    init_logging_logger(args.trainer.output_dir)

    # Snapshot the fully-resolved configuration alongside the run outputs.
    args.save(Path(args.trainer.output_dir) / "conf.yaml")

    model = Model(args.model)

    # Build every dataset up front; which ones the trainer receives depends on
    # whether this is a prediction run.
    train_set = TrainDataset(args.train_dataset)
    eval_sets = {
        "train": TrainDataset(args.eval_train_dataset),
        "dev": DevDataset(args.eval_dev_dataset),
    }
    predict_set = PredictDataset(args.predict_dataset)

    trainer = Trainer(
        is_left_ear=args.is_left_ear,
        loss_function=args.loss_function,
        listeners_file=args.listeners_file,
        model=model,
        args=args.trainer,
        train_dataset=train_set,
        eval_dataset=predict_set if args.trainer.do_predict else eval_sets,
    )

    # Mutually exclusive dispatch; eval takes precedence, then predict, then train.
    if args.trainer.do_eval:
        trainer.evaluate()
    elif args.trainer.do_predict:
        trainer.predict()
    elif args.trainer.do_train:
        trainer.train()
    else:
        raise ValueError("At least one of `do_train`, `do_eval`, or `do_predict` must be True.")
|
|
|
|
if __name__ == "__main__":
    # Parse Args from the command line (and optional --config_path yaml), then run.
    args = parse(Args, add_config_path_arg=True)
    run(args)
|
|