Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/yxlu-0102/MP-SENet/blob/main/train.py | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| from logging.handlers import TimedRotatingFileHandler | |
| import os | |
| import platform | |
| from pathlib import Path | |
| import random | |
| import sys | |
| import shutil | |
| from typing import List | |
| pwd = os.path.abspath(os.path.dirname(__file__)) | |
| sys.path.append(os.path.join(pwd, "../../")) | |
| import numpy as np | |
| import torch | |
| from torch.nn import functional as F | |
| from torch.utils.data.dataloader import DataLoader | |
| from tqdm import tqdm | |
| from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset | |
| from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig | |
| from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel | |
| from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses | |
| from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft | |
| from toolbox.torchaudio.models.mpnet.metrics import run_batch_pesq, run_pesq_score | |
def get_args(argv: List[str] = None):
    """Parse command-line arguments for the training run.

    :param argv: optional explicit argument list. Defaults to ``None``,
        in which case ``argparse`` reads ``sys.argv[1:]`` as before; passing
        a list makes this function unit-testable.
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--config_file", default="config.yaml", type=str)
    args = parser.parse_args(args=argv)
    return args
def logging_config(file_dir: str):
    """Set up console logging plus a daily-rotating log file and return a logger.

    :param file_dir: directory in which ``main.log`` is created (a ``Path``
        also works, since it is only joined via ``os.path.join``).
    :return: a module-level logger with the rotating file handler attached.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # Console output via the root logger.
    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # File output: one file per day, keeping the last 7 rotations.
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
| class CollateFunction(object): | |
| def __init__(self): | |
| pass | |
| def __call__(self, batch: List[dict]): | |
| clean_audios = list() | |
| noisy_audios = list() | |
| for sample in batch: | |
| # noise_wave: torch.Tensor = sample["noise_wave"] | |
| clean_audio: torch.Tensor = sample["speech_wave"] | |
| noisy_audio: torch.Tensor = sample["mix_wave"] | |
| # snr_db: float = sample["snr_db"] | |
| clean_audios.append(clean_audio) | |
| noisy_audios.append(noisy_audio) | |
| clean_audios = torch.stack(clean_audios) | |
| noisy_audios = torch.stack(noisy_audios) | |
| # assert | |
| if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): | |
| raise AssertionError("nan or inf in clean_audios") | |
| if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): | |
| raise AssertionError("nan or inf in noisy_audios") | |
| return clean_audios, noisy_audios | |
| collate_fn = CollateFunction() | |
def main():
    """Train the MP-SENet generator/discriminator pair for speech denoising.

    Per epoch: train both models on the train split (the discriminator
    regresses a normalized PESQ score, the generator minimizes a weighted
    sum of magnitude/phase/complex/consistency/time/metric losses), evaluate
    PESQ and spectral errors on the valid split, checkpoint to
    ``serialization_dir/epoch-{i}``, mirror the best epoch (highest PESQ) to
    ``serialization_dir/best``, and early-stop after ``--patience`` epochs
    without improvement. Training resumes automatically from the newest
    ``epoch-*`` checkpoint found in the serialization directory.
    """
    args = get_args()

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed all RNG sources for reproducibility.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # Worker subprocesses for data loading are unreliable on Windows;
        # load in the main process there, otherwise use half the CPU cores.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # Worker subprocesses for data loading are unreliable on Windows;
        # load in the main process there, otherwise use half the CPU cores.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = MPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminatorPretrainedModel(config).to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler")
    num_params = 0
    for p in generator.parameters():
        num_params += p.numel()
    logger.info("total parameters (generator): {:.3f}M".format(num_params / 1e6))

    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    # resume training: find the newest epoch-* checkpoint on disk, if any.
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_i = Path(epoch_i)
        epoch_idx = epoch_i.stem.split("-")[1]
        epoch_idx = int(epoch_idx)
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
        discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
        optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
        optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"

        logger.info("load state dict for generator.")
        with open(generator_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            generator.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for discriminator.")
        with open(discriminator_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            discriminator.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optim_g.")
        with open(optim_g_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            optim_g.load_state_dict(state_dict)

        logger.info("load state dict for optim_d.")
        with open(optim_d_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
            optim_d.load_state_dict(state_dict)

    # last_epoch=-1 starts the schedule fresh; otherwise it continues the decay.
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)

    # training loop

    # state: sentinel "worst" values so the metrics dict is well defined
    # even if a data split yields no batches.
    loss_d = 10000000000
    loss_g = 10000000000
    pesq_metric = 10000000000
    mag_err = 10000000000
    pha_err = 10000000000
    com_err = 10000000000
    stft_err = 10000000000

    model_list = list()

    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    early_stop_flag = False
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        if early_stop_flag:
            break

        # train
        generator.train()
        discriminator.train()

        total_loss_d = 0.
        total_loss_g = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audio, noisy_audio = batch
            clean_audio = clean_audio.to(device)
            noisy_audio = noisy_audio.to(device)
            one_labels = torch.ones(clean_audio.shape[0]).to(device)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

            # Re-analyze the synthesized waveform so the consistency loss and
            # the discriminator see an STFT-consistent spectrum.
            audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
            pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator.forward(clean_mag, clean_mag)
            metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if -1 in pesq_score_list:
                # a PESQ computation failed (-1 sentinel) for some utterance:
                # skip the generated-side discriminator loss for this step.
                loss_disc_g = 0
            else:
                # normalize PESQ (roughly [1.0, 4.5] for "nb") into [0, 1] as
                # the discriminator's regression target.
                pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
                loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()
            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss: push the discriminator's prediction toward 1
            # (i.e. toward maximal PESQ) for the enhanced spectrum.
            metric_g = discriminator.forward(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            total_loss_d += loss_disc_all.item()
            total_loss_g += loss_gen_all.item()
            total_batches += 1

            loss_d = round(total_loss_d / total_batches, 4)
            loss_g = round(total_loss_g / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "loss_d": loss_d,
                "loss_g": loss_g,
            })

        # evaluation
        generator.eval()
        discriminator.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_mag_err = 0.
        total_pha_err = 0.
        total_com_err = 0.
        total_stft_err = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audio, noisy_audio = batch
                clean_audio = clean_audio.to(device)
                noisy_audio = noisy_audio.to(device)

                clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

                audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                clean_audio_list = torch.split(clean_audio, 1, dim=0)
                enhanced_audio_list = torch.split(audio_g, 1, dim=0)
                clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
                enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]

                pesq_score = run_pesq_score(
                    clean_audio_list,
                    enhanced_audio_list,
                    sample_rate=config.sample_rate,
                    mode="nb",
                )
                total_pesq_score += pesq_score

                total_mag_err += F.mse_loss(clean_mag, mag_g).item()
                val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
                total_com_err += F.mse_loss(clean_com, com_g).item()
                total_stft_err += F.mse_loss(com_g, com_g_hat).item()

                total_batches += 1

                pesq_metric = round(total_pesq_score / total_batches, 4)
                mag_err = round(total_mag_err / total_batches, 4)
                pha_err = round(total_pha_err / total_batches, 4)
                com_err = round(total_com_err / total_batches, 4)
                stft_err = round(total_stft_err / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_metric": pesq_metric,
                    "mag_err": mag_err,
                    "pha_err": pha_err,
                    "com_err": com_err,
                    "stft_err": stft_err,
                })

        # scheduler: learning-rate decay is applied once per epoch.
        scheduler_g.step()
        scheduler_d.step()

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        generator.save_pretrained(epoch_dir.as_posix())
        discriminator.save_pretrained(epoch_dir.as_posix())

        # save optim
        torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
        torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())

        model_list.append(epoch_dir)
        if len(model_list) > args.num_serialized_models_to_keep:
            # drop the oldest checkpoint so at most
            # `num_serialized_models_to_keep` epoch dirs remain on disk
            # (the previous `>=` kept one fewer than requested).
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = pesq_metric
        elif pesq_metric > best_metric:
            # greater is better.
            best_idx_epoch = idx_epoch
            best_metric = pesq_metric
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "loss_d": loss_d,
            "loss_g": loss_g,

            "pesq_metric": pesq_metric,
            "mag_err": mag_err,
            "pha_err": pha_err,
            "com_err": com_err,
            "stft_err": stft_err,
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best: mirror the winning epoch dir to a stable `best` path.
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break
    return
| if __name__ == "__main__": | |
| main() | |