Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/Rikorose/DeepFilterNet | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| from logging.handlers import TimedRotatingFileHandler | |
| import os | |
| import platform | |
| from pathlib import Path | |
| import random | |
| import sys | |
| import shutil | |
| from typing import List | |
| from fontTools.varLib.plot import stops | |
| pwd = os.path.abspath(os.path.dirname(__file__)) | |
| sys.path.append(os.path.join(pwd, "../../")) | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from torch.utils.data.dataloader import DataLoader | |
| from tqdm import tqdm | |
| from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset | |
| from toolbox.torchaudio.losses.snr import NegativeSISNRLoss | |
| from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss | |
| from toolbox.torchaudio.metrics.pesq import run_pesq_score | |
| from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig | |
| from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel | |
def get_args():
    """
    Parse the command-line options of the training script.

    Every option is optional and falls back to the default listed in the
    table below.

    :return: argparse.Namespace with the attributes train_dataset,
        valid_dataset, num_serialized_models_to_keep, patience,
        serialization_dir and config_file.
    """
    # (flag, default value, value type) for every supported option.
    option_table = (
        ("--train_dataset", "train.jsonl", str),
        ("--valid_dataset", "valid.jsonl", str),
        ("--num_serialized_models_to_keep", 15, int),
        ("--patience", 10, int),
        ("--serialization_dir", "serialization_dir", str),
        ("--config_file", "config.yaml", str),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)
    return parser.parse_args()
def logging_config(file_dir: str):
    """
    Set up console logging and a daily-rotating log file.

    The root logger is configured through ``logging.basicConfig`` at INFO
    level, and a ``TimedRotatingFileHandler`` writing ``main.log`` inside
    ``file_dir`` is attached to this module's logger.

    :param file_dir: directory in which ``main.log`` is created.
    :return: the logger named after this module, with the file handler attached.
    """
    record_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        level=logging.INFO,
        datefmt="%m/%d/%Y %H:%M:%S",
        format=record_format,
    )

    # Rotate once per day and keep the seven most recent backups.
    log_path = os.path.join(file_dir, "main.log")
    rotating_handler = TimedRotatingFileHandler(
        filename=log_path,
        when="D",
        interval=1,
        backupCount=7,
        encoding="utf-8",
    )
    rotating_handler.setFormatter(logging.Formatter(record_format))
    rotating_handler.setLevel(logging.INFO)

    module_logger = logging.getLogger(__name__)
    module_logger.addHandler(rotating_handler)
    return module_logger
class CollateFunction(object):
    """
    Callable used as the DataLoader ``collate_fn``.

    It stacks the ``speech_wave`` and ``mix_wave`` tensors of the samples in
    a batch and rejects batches containing NaN or infinite values.
    """

    def __call__(self, batch: List[dict]):
        """
        Collate a list of samples into a pair of stacked tensors.

        :param batch: list of sample dicts; each one must provide the tensors
            ``speech_wave`` (clean audio) and ``mix_wave`` (noisy audio).
        :return: tuple ``(clean_audios, noisy_audios)``, each the result of
            ``torch.stack`` over the corresponding per-sample tensors.
        :raises AssertionError: if either stacked tensor contains NaN or inf.
        """
        clean_audios = torch.stack([sample["speech_wave"] for sample in batch])
        noisy_audios = torch.stack([sample["mix_wave"] for sample in batch])

        # isfinite() is False for NaN as well as +/-inf, so one check covers both.
        if not torch.isfinite(clean_audios).all():
            raise AssertionError("nan or inf in clean_audios")
        if not torch.isfinite(noisy_audios).all():
            raise AssertionError("nan or inf in noisy_audios")

        return clean_audios, noisy_audios


collate_fn = CollateFunction()
# Names of the scalar metrics tracked for both training and evaluation,
# in the order they are displayed and written to metrics_epoch.json.
_METRIC_NAMES = (
    "pesq_score",
    "loss",
    "mr_stft_loss",
    "neg_si_snr_loss",
    "mask_loss",
    "lsnr_loss",
)


def _build_data_loader(dataset, batch_size: int) -> DataLoader:
    """Build a DataLoader whose worker settings are valid on the current host."""
    # On Linux several worker processes can load data; on Windows they cannot.
    if platform.system() == "Windows":
        num_workers = 0
    else:
        # os.cpu_count() may return None; treat that as a single core.
        num_workers = (os.cpu_count() or 1) // 2

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        # DataLoader rejects a prefetch_factor when no worker processes are used.
        prefetch_factor=2 if num_workers > 0 else None,
    )


def _find_last_step_idx(serialization_dir: Path) -> int:
    """Return the largest N among ``steps-N`` entries, or -1 when there is none."""
    last_step_idx = -1
    for checkpoint_path in serialization_dir.glob("steps-*"):
        try:
            step_idx = int(checkpoint_path.stem.split("-")[1])
        except ValueError:
            # Not a "steps-<int>" entry; ignore it instead of aborting the scan.
            continue
        last_step_idx = max(last_step_idx, step_idx)
    return last_step_idx


def _forward_losses(model, mr_stft_loss_fn, neg_si_snr_loss_fn, clean_audios, noisy_audios):
    """Run the model on a batch and return (est_wav, total loss, dict of loss terms)."""
    _, est_wav, est_mask, lsnr = model.forward(noisy_audios)

    loss_terms = {
        "mr_stft_loss": mr_stft_loss_fn.forward(est_wav, clean_audios),
        "neg_si_snr_loss": neg_si_snr_loss_fn.forward(est_wav, clean_audios),
        "mask_loss": model.mask_loss_fn(est_mask, clean_audios, noisy_audios),
        "lsnr_loss": model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios),
    }
    loss = (
        1.0 * loss_terms["mr_stft_loss"]
        + 1.0 * loss_terms["neg_si_snr_loss"]
        + 1.0 * loss_terms["mask_loss"]
        + 0.3 * loss_terms["lsnr_loss"]
    )
    return est_wav, loss, loss_terms


def _batch_pesq(est_wav, clean_audios, sample_rate):
    """Compute the narrow-band PESQ score of a batch via run_pesq_score."""
    denoise_audios_list = list(est_wav.detach().cpu().numpy())
    clean_audios_list = list(clean_audios.detach().cpu().numpy())
    return run_pesq_score(clean_audios_list, denoise_audios_list, sample_rate=sample_rate, mode="nb")


def _update_averages(totals: dict, total_batches: int, pesq_score, loss, loss_terms: dict) -> dict:
    """Add one batch to the running totals and return the averages rounded to 4 digits."""
    totals["pesq_score"] += pesq_score
    totals["loss"] += loss.item()
    for name, term in loss_terms.items():
        totals[name] += term.item()
    return {name: round(totals[name] / total_batches, 4) for name in _METRIC_NAMES}


def main():
    """
    Train a DfNet denoising model with step-based evaluation and checkpointing.

    Workflow:
      1. read CLI arguments and the model config, set up logging and seeds;
      2. build the train/valid datasets and data loaders;
      3. build the model, AdamW optimizer and LR scheduler, reloading the
         model weights from the newest ``steps-N`` checkpoint if one exists;
      4. train; every ``config.eval_steps`` steps evaluate on the validation
         set, save a checkpoint plus ``metrics_epoch.json``, copy the best
         checkpoint (highest average PESQ) to ``best`` and apply
         patience-based early stopping.

    Only the model weights are restored on resume; optimizer and scheduler
    start fresh.

    :raises AssertionError: if ``config.lr_scheduler`` is not supported.
    """
    args = get_args()

    config = DfNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = _build_data_loader(train_dataset, config.batch_size)
    valid_data_loader = _build_data_loader(valid_dataset, config.batch_size)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DfNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = _find_last_step_idx(serialization_dir)
    last_epoch = -1
    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop state
    averages = {name: 1000000000 for name in _METRIC_NAMES}

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch + 1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        totals = {name: 0.0 for name in _METRIC_NAMES}
        total_batches = 0

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_wav, loss, loss_terms = _forward_losses(
                model, mr_stft_loss_fn, neg_si_snr_loss_fn, clean_audios, noisy_audios
            )
            if not torch.isfinite(loss).all():
                # Skip the batch entirely: no optimizer step, no step count.
                logger.info("find nan or inf in loss.")
                continue

            pesq_score = _batch_pesq(est_wav, clean_audios, config.sample_rate)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_batches += 1
            averages = _update_averages(totals, total_batches, pesq_score, loss, loss_terms)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                **averages,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps != 0:
                continue

            model.eval()
            with torch.no_grad():
                torch.cuda.empty_cache()

                totals = {name: 0.0 for name in _METRIC_NAMES}
                total_batches = 0

                progress_bar_train.close()
                progress_bar_eval = tqdm(
                    desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                )
                for eval_batch in valid_data_loader:
                    clean_audios, noisy_audios = eval_batch
                    clean_audios: torch.Tensor = clean_audios.to(device)
                    noisy_audios: torch.Tensor = noisy_audios.to(device)

                    est_wav, loss, loss_terms = _forward_losses(
                        model, mr_stft_loss_fn, neg_si_snr_loss_fn, clean_audios, noisy_audios
                    )
                    if not torch.isfinite(loss).all():
                        logger.info("find nan or inf in loss.")
                        continue

                    pesq_score = _batch_pesq(est_wav, clean_audios, config.sample_rate)

                    total_batches += 1
                    averages = _update_averages(totals, total_batches, pesq_score, loss, loss_terms)

                    progress_bar_eval.update(1)
                    progress_bar_eval.set_postfix({
                        "lr": lr_scheduler.get_last_lr()[0],
                        **averages,
                    })

                # Training statistics start over after each evaluation.
                totals = {name: 0.0 for name in _METRIC_NAMES}
                total_batches = 0

                progress_bar_eval.close()
                progress_bar_train = tqdm(
                    initial=progress_bar_train.n,
                    postfix=progress_bar_train.postfix,
                    desc=progress_bar_train.desc,
                )

            # save path
            save_dir = serialization_dir / "steps-{}".format(step_idx)
            save_dir.mkdir(parents=True, exist_ok=False)

            # save models
            model.save_pretrained(save_dir.as_posix())
            model_list.append(save_dir)

            # save metric; a higher PESQ score is better.
            is_best = best_metric is None or averages["pesq_score"] >= best_metric
            if is_best:
                best_epoch_idx = epoch_idx
                best_step_idx = step_idx
                best_metric = averages["pesq_score"]

            metrics = {
                "epoch_idx": epoch_idx,
                "best_epoch_idx": best_epoch_idx,
                "best_step_idx": best_step_idx,
                **averages,
            }
            metrics_filename = save_dir / "metrics_epoch.json"
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)

            # save best
            best_dir = serialization_dir / "best"
            if is_best:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(save_dir, best_dir)

            # Prune only after the metrics file and the best copy are written,
            # and keep exactly num_serialized_models_to_keep checkpoints.
            if len(model_list) > args.num_serialized_models_to_keep:
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())

            # early stop
            early_stop_flag = False
            if is_best:
                patience_count = 0
            else:
                patience_count += 1
                if patience_count >= args.patience:
                    early_stop_flag = True

            if early_stop_flag:
                # Leaves the batch loop; the epoch loop stops on its next check.
                break
            model.train()

    return


if __name__ == "__main__":
    main()