| | import os, json, argparse, yaml |
| | import numpy as np |
| | from tqdm import tqdm |
| | import librosa |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from torch.cuda.amp import autocast, GradScaler |
| |
|
| | from dataset.diffpitch import DiffPitch |
| | from models.transformer import PitchFormer |
| | from utils import minmax_norm_diff, reverse_minmax_norm_diff, save_curve_plot |
| |
|
| |
|
def _str2bool(value):
    """Parse a CLI boolean string into a bool.

    Needed because ``type=bool`` in argparse treats ANY non-empty string
    (including ``"False"``) as True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('-config', type=str, default='config/DiffPitch.yaml')

parser.add_argument('-seed', type=int, default=9811)
# BUG FIX: was type=bool, which made '-amp False' evaluate to True.
parser.add_argument('-amp', type=_str2bool, default=False)
parser.add_argument('-compile', type=_str2bool, default=False)

parser.add_argument('-data_dir', type=str, default='data/')
parser.add_argument('-content_dir', type=str, default='world')

parser.add_argument('-train_frames', type=int, default=256)
parser.add_argument('-test_frames', type=int, default=256)
parser.add_argument('-batch_size', type=int, default=32)
parser.add_argument('-test_size', type=int, default=1)
parser.add_argument('-num_workers', type=int, default=4)
parser.add_argument('-lr', type=float, default=5e-5)
# BUG FIX: was type=int; int('1e-6') raises ValueError whenever the flag
# is actually passed on the command line.
parser.add_argument('-weight_decay', type=float, default=1e-6)

parser.add_argument('-epochs', type=int, default=1)
parser.add_argument('-save_every', type=int, default=20)
parser.add_argument('-log_step', type=int, default=100)
parser.add_argument('-log_dir', type=str, default='logs_transformer_pitch')
parser.add_argument('-ckpt_dir', type=str, default='ckpt_transformer_pitch')

args = parser.parse_args()
args.save_ori = True

# Experiment configuration (trusted local file, so FullLoader is acceptable).
# BUG FIX: use a context manager so the config file handle is not leaked.
with open(args.config) as _cfg_file:
    config = yaml.load(_cfg_file, Loader=yaml.FullLoader)
mel_cfg = config['logmel']
ddpm_cfg = config['ddpm']
| | |
| |
|
| |
|
def RMSE(gen_f0, gt_f0):
    """Root-mean-square error of log2(F0) over mutually voiced frames.

    Args:
        gen_f0: predicted F0 curve, shape (1, T); only the first row is used.
        gt_f0: ground-truth F0 curve, same shape.

    Returns:
        RMSE of log2(F0) restricted to frames where BOTH curves are
        non-zero (voiced), or 0 when no such frame exists.
    """
    pred = np.asarray(gen_f0)[0]
    target = np.asarray(gt_f0)[0]

    # A frame contributes only when both curves are voiced (F0 != 0);
    # log2 is undefined at zero.
    voiced = (pred != 0) & (target != 0)
    if not voiced.any():
        return 0

    log_diff = np.log2(pred[voiced]) - np.log2(target[voiced])
    return np.sqrt(np.mean(log_diff ** 2))
| |
|
| |
|
if __name__ == "__main__":
    # ----- Reproducibility and device selection -----
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            # BUG FIX: benchmark=True lets cuDNN autotune kernels at runtime,
            # which defeats the deterministic flag set on the previous line.
            torch.backends.cudnn.benchmark = False
    else:
        args.device = 'cpu'

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.ckpt_dir, exist_ok=True)

    print('Initializing data loaders...')
    # BUG FIX: honor the -data_dir CLI flag instead of the hard-coded 'data/'
    # (the flag's default is 'data/', so default behavior is unchanged).
    trainset = DiffPitch(args.data_dir, 'train', args.train_frames, shift=True)
    train_loader = DataLoader(trainset, batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              drop_last=True, shuffle=True)

    val_set = DiffPitch(args.data_dir, 'val', args.test_frames, shift=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    test_set = DiffPitch(args.data_dir, 'test', args.test_frames, shift=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    real_set = DiffPitch(args.data_dir, 'real', args.test_frames, shift=False)
    real_loader = DataLoader(real_set, batch_size=1, shuffle=False)

    print('Initializing and loading models...')
    model = PitchFormer(mel_cfg['n_mels'], 512).to(args.device)
    # Resume from a fixed checkpoint (NOTE(review): hard-coded path/step --
    # consider exposing as a CLI flag).
    # BUG FIX: map_location so a CUDA-saved checkpoint also loads on CPU.
    ckpt = torch.load('ckpt_transformer_pitch/transformer_pitch_460.pt',
                      map_location=args.device)
    model.load_state_dict(ckpt)

    print('Initializing optimizers...')
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    scaler = GradScaler()

    if args.compile:
        model = torch.compile(model)

    # Plausible F0 range used to post-process predictions: anything below C2
    # is treated as unvoiced (set to 0), anything above C6 is clipped.
    # Hoisted out of the loops so librosa is not re-queried per batch.
    f0_floor = librosa.note_to_hz('C2')
    f0_ceil = librosa.note_to_hz('C6')

    print('Start training.')
    global_step = 0
    for epoch in range(1, args.epochs + 1):
        print(f'Epoch: {epoch} [iteration: {global_step}]')
        model.train()
        losses = []

        for step, batch in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            mel, midi, f0 = batch
            mel = mel.to(args.device)
            midi = midi.to(args.device)
            f0 = f0.to(args.device)

            if args.amp:
                with autocast():
                    f0_pred = model(sp=mel, midi=midi)
                    # BUG FIX: the AMP path trained with MSE while the non-AMP
                    # path and all evaluation use L1; unify on L1 so both paths
                    # optimize the same objective.
                    loss = F.l1_loss(f0_pred, f0)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                f0_pred = model(sp=mel, midi=midi)
                loss = F.l1_loss(f0_pred, f0)
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            global_step += 1

            if global_step % args.log_step == 0:
                msg = '\nEpoch: [{}][{}]\t' \
                      'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(
                          epoch, args.epochs, step + 1,
                          len(train_loader), np.mean(losses))
                with open(f'{args.log_dir}/train_dec.log', 'a') as f:
                    f.write(msg)
                losses = []  # the running mean resets at every log point

        # Checkpoint + evaluation only every `save_every` epochs.
        if epoch % args.save_every > 0:
            continue

        print('Saving model...\n')
        torch.save(model.state_dict(),
                   f=f"{args.ckpt_dir}/transformer_pitch_{epoch}.pt")

        print('Inference...\n')
        model.eval()
        with torch.no_grad():
            # val and test share the exact same evaluation procedure; only the
            # loader, the plot filename suffix, and the log file differ.
            for tag, loader, log_file in (('val', val_loader, 'eval_dec.log'),
                                          ('test', test_loader, 'test_dec.log')):
                eval_loss = []
                eval_rmse = []
                for i, batch in enumerate(loader):
                    mel, midi, f0 = batch
                    mel = mel.to(args.device)
                    midi = midi.to(args.device)
                    f0 = f0.to(args.device)

                    f0_pred = model(sp=mel, midi=midi)

                    # Post-process: below C2 -> unvoiced, above C6 -> clip.
                    f0_pred[f0_pred < f0_floor] = 0
                    f0_pred[f0_pred > f0_ceil] = f0_ceil

                    eval_loss.append(F.l1_loss(f0_pred, f0).item())
                    eval_rmse.append(RMSE(f0_pred.cpu().numpy(),
                                          f0.cpu().numpy()))

                    # Plot the first 5 samples for visual inspection.
                    if i <= 4:
                        save_path = f'{args.log_dir}/pic/{i}/{epoch}_{tag}.png'
                        os.makedirs(os.path.dirname(save_path), exist_ok=True)
                        save_curve_plot(f0_pred.cpu().squeeze(),
                                        midi.cpu().squeeze(),
                                        f0.cpu().squeeze(), save_path)

                msg = '\nEpoch: [{}][{}]\tLoss: {:.6f}\tRMSE:{:.6f}\n'.format(
                    epoch, args.epochs, np.mean(eval_loss), np.mean(eval_rmse))
                with open(f'{args.log_dir}/{log_file}', 'a') as f:
                    f.write(msg)

            # 'real' split: plots only, no metrics. Predicted F0 is forced to
            # 0 wherever the reference frame is unvoiced.
            for i, batch in enumerate(real_loader):
                mel, midi, f0 = batch
                mel = mel.to(args.device)
                midi = midi.to(args.device)
                f0 = f0.to(args.device)

                f0_pred = model(sp=mel, midi=midi)
                f0_pred[f0 == 0] = 0
                f0_pred[f0_pred < f0_floor] = 0
                f0_pred[f0_pred > f0_ceil] = f0_ceil

                save_path = f'{args.log_dir}/pic/{i}/{epoch}_real.png'
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                save_curve_plot(f0_pred.cpu().squeeze(), midi.cpu().squeeze(),
                                f0.cpu().squeeze(), save_path)
| |
|
| |
|
| |
|
| |
|
| |
|