# toothless-esnet / train.py
# rossijakob's picture
# Upload folder using huggingface_hub
# 18b9615 verified
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
sys.path.append("..")
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from dataset import Dataset, mag_pha_stft, mag_pha_istft, get_dataset_filelist
from models.model import MPNet, pesq_score, phase_losses
from models.discriminator import MetricDiscriminator, batch_pesq
from utils import scan_checkpoint, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
def train(rank, a, h):
    """Run the adversarial training loop on one GPU / one distributed rank.

    Args:
        rank: Index of the CUDA device used by this process; also passed as
            the process rank to ``init_process_group`` when ``h.num_gpus > 1``.
            Rank 0 additionally prints progress, writes checkpoints and
            TensorBoard summaries, and runs validation.
        a: Parsed command-line arguments (data paths, checkpoint path, epoch
            count and the stdout / checkpoint / summary / validation intervals).
        h: Hyper-parameter object built from the JSON config (attribute access).

    Side effects:
        On rank 0, creates ``a.checkpoint_path`` and its ``logs`` sub-directory,
        writes ``g_<steps>``, ``do_<steps>`` and ``g_best`` checkpoints there,
        and writes TensorBoard event files under ``logs``.
    """
    distributed = h.num_gpus > 1
    if distributed:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = MPNet(h).to(device)
    discriminator = MetricDiscriminator().to(device)

    if rank == 0:
        print(generator)
        num_params = sum(p.numel() for p in generator.parameters())
        print('Total Parameters: {:.3f}M'.format(num_params/1e6))
        os.makedirs(a.checkpoint_path, exist_ok=True)
        os.makedirs(os.path.join(a.checkpoint_path, 'logs'), exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Default to "no checkpoint found" so the names are always bound, even when
    # the directory does not exist (only rank 0 creates it, so other ranks may
    # get here before it appears).
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        discriminator.load_state_dict(state_dict_do['discriminator'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if distributed:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        # Both optimizer states are stored in the "do_" checkpoint.
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_indexes, validation_indexes = get_dataset_filelist(a)

    # With several GPUs the DistributedSampler is in charge of ordering, so the
    # dataset's own shuffle flag is switched off in that case.
    trainset = Dataset(training_indexes, a.input_clean_wavs_dir, a.input_noisy_wavs_dir, h.segment_size, h.sampling_rate,
                       split=True, n_cache_reuse=0, shuffle=not distributed, device=device)
    train_sampler = DistributedSampler(trainset) if distributed else None
    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = Dataset(validation_indexes, a.input_clean_wavs_dir, a.input_noisy_wavs_dir, h.segment_size, h.sampling_rate,
                           split=False, shuffle=False, n_cache_reuse=0, device=device)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)
        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    # Positional arguments shared by every STFT / iSTFT call below.
    stft_args = (h.n_fft, h.hop_size, h.win_size, h.compress_factor)

    generator.train()
    discriminator.train()

    best_pesq = 0

    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if distributed:
            train_sampler.set_epoch(epoch)

        for batch in train_loader:
            if rank == 0:
                start_b = time.time()

            clean_audio, noisy_audio = batch
            # torch.autograd.Variable is a deprecated no-op wrapper, so the
            # tensors are used directly.
            clean_audio = clean_audio.to(device, non_blocking=True)
            noisy_audio = noisy_audio.to(device, non_blocking=True)
            # drop_last=True on the loader keeps every batch at h.batch_size.
            one_labels = torch.ones(h.batch_size).to(device, non_blocking=True)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
            noisy_mag, noisy_pha, _ = mag_pha_stft(noisy_audio, *stft_args)

            mag_g, pha_g, com_g = generator(noisy_mag, noisy_pha)

            # Re-analyse the synthesised waveform: its spectrum feeds the
            # discriminator and the consistency loss.
            audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
            mag_g_hat, _, com_g_hat = mag_pha_stft(audio_g, *stft_args)

            audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
            batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator(clean_mag, clean_mag)
            metric_g = discriminator(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if batch_pesq_score is not None:
                loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
            else:
                # No PESQ target for this batch: train on the clean pair only.
                print('pesq is None!')
                loss_disc_g = 0

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            metric_g = discriminator(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                log_stdout = steps % a.stdout_interval == 0
                log_summary = steps % a.summary_interval == 0

                # The per-component errors are needed by both the stdout and the
                # TensorBoard branch, so compute them whenever either is due.
                # Otherwise the summary branch would read stale or undefined
                # values when the two intervals do not line up.
                if log_stdout or log_summary:
                    with torch.no_grad():
                        metric_error = F.mse_loss(metric_g.flatten(), one_labels).item()
                        mag_error = F.mse_loss(clean_mag, mag_g).item()
                        ip_error, gd_error, iaf_error = phase_losses(clean_pha, pha_g)
                        pha_error = (ip_error + gd_error + iaf_error).item()
                        com_error = F.mse_loss(clean_com, com_g).item()
                        time_error = F.l1_loss(clean_audio, audio_g).item()
                        stft_error = F.mse_loss(com_g, com_g_hat).item()

                # STDOUT logging
                if log_stdout:
                    print('Steps : {:d}, Gen Loss: {:4.3f}, Disc Loss: {:4.3f}, Metric loss: {:4.3f}, Magnitude Loss : {:4.3f}, Phase Loss : {:4.3f}, Complex Loss : {:4.3f}, Time Loss : {:4.3f}, STFT Loss : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, loss_disc_all, metric_error, mag_error, pha_error, com_error, time_error, stft_error, time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if distributed else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'discriminator': (discriminator.module if distributed else discriminator).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                if log_summary:
                    sw.add_scalar("Training/Generator Loss", loss_gen_all, steps)
                    sw.add_scalar("Training/Discriminator Loss", loss_disc_all, steps)
                    sw.add_scalar("Training/Metric Loss", metric_error, steps)
                    sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
                    sw.add_scalar("Training/Phase Loss", pha_error, steps)
                    sw.add_scalar("Training/Complex Loss", com_error, steps)
                    sw.add_scalar("Training/Time Loss", time_error, steps)
                    sw.add_scalar("Training/Consistency Loss", stft_error, steps)

                # Validation
                if steps % a.validation_interval == 0 and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    audios_r, audios_g = [], []
                    val_mag_err_tot = 0
                    val_pha_err_tot = 0
                    val_com_err_tot = 0
                    val_stft_err_tot = 0
                    with torch.no_grad():
                        for j, val_batch in enumerate(validation_loader):
                            clean_audio, noisy_audio = val_batch
                            clean_audio = clean_audio.to(device, non_blocking=True)
                            noisy_audio = noisy_audio.to(device, non_blocking=True)
                            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
                            noisy_mag, noisy_pha, _ = mag_pha_stft(noisy_audio, *stft_args)

                            mag_g, pha_g, com_g = generator(noisy_mag, noisy_pha)

                            audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
                            _, _, com_g_hat = mag_pha_stft(audio_g, *stft_args)
                            audios_r += torch.split(clean_audio, 1, dim=0)  # [1, T] * B
                            audios_g += torch.split(audio_g, 1, dim=0)

                            val_mag_err_tot += F.mse_loss(clean_mag, mag_g).item()
                            val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                            val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
                            val_com_err_tot += F.mse_loss(clean_com, com_g).item()
                            val_stft_err_tot += F.mse_loss(com_g, com_g_hat).item()

                        # j is the index of the last validation batch, so j + 1
                        # is the number of batches averaged over.
                        val_mag_err = val_mag_err_tot / (j+1)
                        val_pha_err = val_pha_err_tot / (j+1)
                        val_com_err = val_com_err_tot / (j+1)
                        val_stft_err = val_stft_err_tot / (j+1)
                        val_pesq_score = pesq_score(audios_r, audios_g, h).item()
                        print('Steps : {:d}, PESQ Score: {:4.3f}, s/b : {:4.3f}'.
                              format(steps, val_pesq_score, time.time() - start_b))
                        sw.add_scalar("Validation/PESQ Score", val_pesq_score, steps)
                        sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
                        sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
                        sw.add_scalar("Validation/Complex Loss", val_com_err, steps)
                        sw.add_scalar("Validation/Consistency Loss", val_stft_err, steps)

                    # Only start tracking the best generator after the warm-up
                    # epoch given on the command line.
                    if epoch >= a.best_checkpoint_start_epoch and val_pesq_score > best_pesq:
                        best_pesq = val_pesq_score
                        best_checkpoint_path = "{}/g_best".format(a.checkpoint_path)
                        save_checkpoint(best_checkpoint_path,
                                        {'generator': (generator.module if distributed else generator).state_dict()})

                    generator.train()

            steps += 1

        # Learning rates decay once per epoch.
        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
def main():
    """Parse command-line options, load the JSON config and start training.

    The config file named by ``--config`` is parsed into an ``AttrDict`` of
    hyper-parameters. When CUDA is available, ``num_gpus`` is set to the number
    of visible devices and ``batch_size`` is divided by it to give the per-GPU
    batch size. Training is then run in-process on device 0, or spawned as one
    process per GPU when more than one is present.
    """
    print('Initializing Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_clean_wavs_dir', default='VoiceBank+DEMAND/wavs_clean')
    parser.add_argument('--input_noisy_wavs_dir', default='VoiceBank+DEMAND/wavs_noisy')
    parser.add_argument('--input_training_file', default='VoiceBank+DEMAND/training.txt')
    parser.add_argument('--input_validation_file', default='VoiceBank+DEMAND/test.txt')
    parser.add_argument('--checkpoint_path', default='cp_model')
    parser.add_argument('--config', default='')
    parser.add_argument('--training_epochs', default=400, type=int)
    parser.add_argument('--stdout_interval', default=5, type=int)
    parser.add_argument('--checkpoint_interval', default=5000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=5000, type=int)
    parser.add_argument('--best_checkpoint_start_epoch', default=40, type=int)

    a = parser.parse_args()

    # The default of '' would otherwise surface as an opaque FileNotFoundError
    # from open(''); report the missing option through argparse instead.
    if not a.config:
        parser.error('--config is required (path to the JSON hyper-parameter file)')

    with open(a.config) as f:
        json_config = json.load(f)

    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        # The configured batch size is the global one; split it across GPUs.
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)

    # NOTE(review): without CUDA, num_gpus is whatever the config supplies -
    # presumably the config defines it; verify against the config file.
    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
    else:
        train(0, a, h)
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()