| import yaml
|
| import random
|
| import argparse
|
| import os
|
| import time
|
| from tqdm import tqdm
|
| from pathlib import Path
|
|
|
| import torch
|
| from torch.utils.data import DataLoader
|
|
|
| from accelerate import Accelerator
|
| from diffusers import DDIMScheduler
|
|
|
| from configs.plugin import get_params
|
| from model.p2e_cross import P2E_Cross
|
| from modules.speaker_encoder.encoder import inference as spk_encoder
|
| from transformers import T5Tokenizer, T5EncoderModel, AutoModel
|
| from inference_freevc import eval_plugin
|
| from dataset.dreamvc import DreamData
|
|
|
| from freevc_wrapper import get_freevc_models
|
| from utils import minmax_norm_diff, reverse_minmax_norm_diff, scale_shift
|
|
|
# Command-line interface: checkpoint/model paths, training hyper-parameters,
# and logging/saving locations. Flags and defaults are the script's contract.
parser = argparse.ArgumentParser()

_cli_options = [
    # model / checkpoint paths
    ('--config-name', dict(type=str, default='Plugin_freevc')),
    ('--vc-unet-path', dict(type=str, default='freevc')),
    ('--speaker-path', dict(type=str, default='speaker_encoder/ckpt/pretrained_bak_5805000.pt')),
    # training hyper-parameters
    ('--amp', dict(type=str, default='fp16')),
    ('--epochs', dict(type=int, default=200)),
    ('--batch-size', dict(type=int, default=32)),
    ('--num-workers', dict(type=int, default=8)),
    ('--num-threads', dict(type=int, default=1)),
    ('--save-every', dict(type=int, default=10)),
    # reproducibility / logging
    ('--random-seed', dict(type=int, default=2023)),
    ('--log-step', dict(type=int, default=200)),
    ('--log-dir', dict(type=str, default='../logs/')),
    ('--save-dir', dict(type=str, default='../ckpts/')),
]
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
# Experiment configuration resolved by name (learning rate, diffusion betas, ...).
params = get_params(args.config_name)
# Per-configuration log directory, e.g. '../logs/Plugin_freevc/'.
args.log_dir = args.log_dir + args.config_name + '/'
|
|
|
# Load the diffusion-model architecture configuration.
with open('model/p2e_cross.yaml', 'r') as fp:
    config = yaml.safe_load(fp)

# exist_ok=True replaces the race-prone `if not os.path.exists(...)` /
# `os.makedirs(...)` pattern (another process could create the directory
# between the check and the call); behavior is otherwise identical.
os.makedirs(args.save_dir + args.config_name, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)
|
|
|
if __name__ == '__main__':
    # Seed every RNG for reproducibility and configure compute backends.
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.set_num_threads(args.num_threads)

    if not torch.cuda.is_available():
        args.device = 'cpu'
    else:
        args.device = 'cuda'
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        # TF32 matmuls trade a little precision for speed on Ampere+ GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = False
|
|
|
| train_set = DreamData(data_dir='../prepare_freevc/spk/', meta_dir='../prepare/plugin_meta.csv',
|
| subset='train', prompt_dir='../prepare/prompts.csv',)
|
| train_loader = DataLoader(train_set, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
|
|
|
|
|
| accelerator = Accelerator(mixed_precision=args.amp)
|
|
|
|
|
|
|
|
|
|
|
|
|
| freevc_24, cmodel, _, hps = get_freevc_models(args.vc_unet_path, args.speaker_path, accelerator.device)
|
|
|
|
|
|
|
|
|
| tokenizer = T5Tokenizer.from_pretrained(params.text_encoder.model)
|
| text_encoder = T5EncoderModel.from_pretrained(params.text_encoder.model).to(accelerator.device)
|
| text_encoder.eval()
|
|
|
|
|
| model = P2E_Cross(config['diffwrap']).to(accelerator.device)
|
| model.load_state_dict(torch.load('../ckpts/Plugin_freevc/49.pt')['model'])
|
|
|
| total_params = sum([param.nelement() for param in model.parameters()])
|
| print("Number of parameter: %.2fM" % (total_params / 1e6))
|
|
|
| if params.diff.v_prediction:
|
| print('v prediction')
|
| noise_scheduler = DDIMScheduler(num_train_timesteps=params.diff.num_train_steps,
|
| beta_start=params.diff.beta_start, beta_end=params.diff.beta_end,
|
| rescale_betas_zero_snr=True,
|
| timestep_spacing="trailing",
|
| clip_sample=False,
|
| prediction_type='v_prediction')
|
| else:
|
| print('noise prediction')
|
| noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_steps,
|
| beta_start=args.beta_start, beta_end=args.beta_end,
|
| clip_sample=False,
|
| prediction_type='epsilon')
|
|
|
| optimizer = torch.optim.AdamW(model.parameters(),
|
| lr=params.opt.learning_rate,
|
| betas=(params.opt.beta1, params.opt.beta2),
|
| weight_decay=params.opt.weight_decay,
|
| eps=params.opt.adam_epsilon,
|
| )
|
| loss_func = torch.nn.MSELoss()
|
|
|
| model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
|
|
|
    global_step = 0  # total optimizer steps taken, across all epochs
    losses = 0       # running loss accumulator, reset after each log window

    # Sanity-check generation with the resumed checkpoint before training starts.
    # Runs on the main process only; the other processes wait at the barrier.
    # (1, 256, 1) is the shape handed to eval_plugin for a single generated
    # embedding; random_seed=None leaves sampling unseeded for this smoke test.
    if accelerator.is_main_process:
        eval_plugin(freevc_24, cmodel, [tokenizer, text_encoder],
                    model, noise_scheduler, (1, 256, 1),
                    val_meta='../prepare/val_meta.csv',
                    val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
                    guidance_scale=3.0, guidance_rescale=0.0,
                    ddim_steps=100, eta=1, random_seed=None,
                    device=accelerator.device,
                    epoch='test', save_path=args.log_dir + 'output/', val_num=10)
    accelerator.wait_for_everyone()
|
|
|
| for epoch in range(args.epochs):
|
| model.train()
|
| for step, batch in enumerate(tqdm(train_loader)):
|
| spk_embed, prompt = batch
|
| spk_embed = spk_embed.unsqueeze(-1)
|
|
|
| with torch.no_grad():
|
| text_batch = tokenizer(prompt,
|
| max_length=32,
|
| padding='max_length', truncation=True, return_tensors="pt")
|
| text, text_mask = text_batch.input_ids.to(spk_embed.device), \
|
| text_batch.attention_mask.to(spk_embed.device)
|
| text = text_encoder(input_ids=text, attention_mask=text_mask)[0]
|
|
|
| spk_embed = scale_shift(spk_embed, 20, -0.035)
|
|
|
|
|
|
|
|
|
|
|
| noise = torch.randn(spk_embed.shape).to(accelerator.device)
|
| timesteps = torch.randint(0, params.diff.num_train_steps, (noise.shape[0],),
|
| device=accelerator.device, ).long()
|
| noisy_target = noise_scheduler.add_noise(spk_embed, noise, timesteps)
|
|
|
| velocity = noise_scheduler.get_velocity(spk_embed, noise, timesteps)
|
|
|
|
|
| pred = model(noisy_target, timesteps, text, text_mask, train_cfg=True, cfg_prob=0.25)
|
|
|
| if params.diff.v_prediction:
|
| loss = loss_func(pred, velocity)
|
| else:
|
| loss = loss_func(pred, noise)
|
|
|
| accelerator.backward(loss)
|
| optimizer.step()
|
| optimizer.zero_grad()
|
|
|
| global_step += 1
|
| losses += loss.item()
|
|
|
| if accelerator.is_main_process:
|
| if global_step % args.log_step == 0:
|
| n = open(args.log_dir + 'diff_vc.txt', mode='a')
|
| n.write(time.asctime(time.localtime(time.time())))
|
| n.write('\n')
|
| n.write('Epoch: [{}][{}] Batch: [{}][{}] Loss: {:.6f}\n'.format(
|
| epoch + 1, args.epochs, step + 1, len(train_loader), losses / args.log_step))
|
| n.close()
|
| losses = 0.0
|
|
|
| accelerator.wait_for_everyone()
|
|
|
| if (epoch + 1) % args.save_every == 0:
|
| if accelerator.is_main_process:
|
| eval_plugin(freevc_24, cmodel, [tokenizer, text_encoder],
|
| model, noise_scheduler, (1, 256, 1),
|
| val_meta='../prepare/val_meta.csv',
|
| val_folder='/home/jerry/Projects/Dataset/Speech/vctk_libritts/',
|
| guidance_scale=3, guidance_rescale=0.0,
|
| ddim_steps=50, eta=1, random_seed=2024,
|
| device=accelerator.device,
|
| epoch=epoch, save_path=args.log_dir + 'output/', val_num=10)
|
|
|
| unwrapped_unet = accelerator.unwrap_model(model)
|
| accelerator.save({
|
| "model": unwrapped_unet.state_dict(),
|
| }, args.save_dir + args.config_name + '/' + str(epoch) + '.pt')
|
|
|