|
|
import os |
|
|
import sys |
|
|
import time |
|
|
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
|
|
import linecache |
|
|
import mmap |
|
|
import pickle as pkl |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.optim as optim |
|
|
import torchaudio |
|
|
from accelerate import Accelerator, DistributedDataParallelKwargs |
|
|
from mel_spec import get_mel_spectrogram |
|
|
from torch.distributed import destroy_process_group, init_process_group |
|
|
from torch.distributed.elastic.utils.data import ElasticDistributedSampler |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
import wandb |
|
|
from config import config |
|
|
from S2A.diff_model import DiffModel |
|
|
from S2A.flow_matching import BASECFM |
|
|
from S2A.inference import infer |
|
|
from S2A.utilities import (dynamic_range_compression, get_mask, |
|
|
get_mask_from_lengths, load_wav_to_torch, |
|
|
normalize_tacotron_mel) |
|
|
from Text import code_labels, labels, text_labels |
|
|
|
|
|
|
|
|
# Enable TensorFloat-32 matmul/cuDNN kernels (faster on Ampere+ GPUs at
# slightly reduced float32 precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Seed every RNG in use so runs are reproducible.
torch.manual_seed(config.seed_value)
np.random.seed(config.seed_value)
random.seed(config.seed_value)

# Number of mel frames cropped from each reference spectrogram
# (see get_random_portion inside Acoustic_dataset.__getitem__).
CLIP_LENGTH = config.CLIP_LENGTH

# Character-level text vocabulary: label -> id (enc) and id -> label (dec).
text_enc = {j: i for i, j in enumerate(text_labels)}
text_dec = {i: j for i, j in enumerate(text_labels)}

# Semantic/acoustic code vocabulary: label -> id (enc) and id -> label (dec).
code_enc = {j: i for i, j in enumerate(code_labels)}
code_dec = {i: j for i, j in enumerate(code_labels)}
|
|
|
|
|
|
|
|
def read_specific_line(filename, line_number):
    """Return line ``line_number`` (1-based) of ``filename``, whitespace-stripped.

    Out-of-range line numbers yield an empty string (linecache semantics).
    """
    return linecache.getline(filename, line_number).strip()
|
|
|
|
|
|
|
|
class Acoustic_dataset(Dataset):
    """Dataset pairing semantic code sequences with target mel-spectrograms.

    Two loading modes:
      * ``scale=False`` -- transcript, semantic codes and reference mels are
        fully loaded into memory from three separate files.
      * ``scale=True`` -- a single pipe-separated transcript file
        (``lang|path|text|codes`` per line) is indexed lazily via mmap so
        arbitrarily large datasets fit in memory.
    """

    def __init__(
        self,
        transcript_path,
        semantic_path=None,
        ref_mels_path=None,
        ref_k=1,
        scale=True,
        ar_active=False,
        clip=True,
        dur_=None,
    ):
        # FIX: was ``super(Acoustic_dataset).__init__()`` which initialises an
        # unbound super object and never calls Dataset.__init__.
        super().__init__()
        self.scale = scale
        self.ar_active = ar_active
        self.clip = clip
        # Duration (seconds) of the random training crop; defaults to 2 s.
        self.dur_ = 2 if dur_ is None else dur_

        if not scale:
            # Small-data mode: read everything up front.
            with open(transcript_path, "r") as file:
                data = file.read().strip("\n").split("\n")[:]

            with open(semantic_path, "r") as file:
                semb = file.read().strip("\n").split("\n")

            with open(ref_mels_path, "rb") as file:
                self.ref_mels = pkl.load(file)

            # "<key>\t<space-separated codes>" per line.
            semb = {
                i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
            }
            # "<key>|<transcript>" per line.
            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}

            self.data = [[i, semb[i], data[i]] for i in data.keys()][:]
            # FIX: data_len was only set in scale mode but is used by the
            # __getitem__ retry fallback in both modes.
            self.data_len = len(self.data)
        else:
            # Large-data mode: build a line-number -> byte-offset index over
            # the transcript so __getitem__ can seek straight to any line.
            self.transcript_path = transcript_path
            line_index = {}
            with open(transcript_path, "rb") as file:
                # The mmap keeps the mapping alive after the file handle closes.
                mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                line_number = 0
                offset = 0
                progress_bar = tqdm(desc="processing:")
                while offset < len(mmapped_file):
                    line_index[line_number] = offset
                    offset = mmapped_file.find(b"\n", offset) + 1
                    line_number += 1
                    progress_bar.update(1)
                progress_bar.close()
            self.mmapped_file = mmapped_file
            self.data_len = len(line_index)
            self.line_index = line_index

        self.ref_k = ref_k
        self.max_wav_value = config.MAX_WAV_VALUE

    def get_mel(self, filepath, semb_ids=None, align=False, ref_clip=False):
        """Load ``filepath`` and return its mel-spectrogram.

        With ``align=True`` a random ``self.dur_``-second window is cropped
        (when clipping applies) and ``semb_ids`` is cut to the matching code
        frames; returns ``(mel, energy, semb_ids)``.  With ``ref_clip=True`` a
        random 6-second window is cropped instead (reference conditioning);
        returns ``(mel, energy)``.  ``energy`` is currently always empty.
        """
        audio_norm, sampling_rate = torchaudio.load(filepath)
        dur = audio_norm.shape[-1] / sampling_rate

        if self.clip and dur > self.dur_ and align:
            max_audio_start = int(dur - self.dur_)
            if max_audio_start <= 0:
                audio_start = 0
            else:
                audio_start = np.random.randint(0, max_audio_start)

            audio_norm = audio_norm[
                :,
                audio_start * sampling_rate : (audio_start + self.dur_) * sampling_rate,
            ]
            # Semantic codes run at 50 frames/second; the final frame is
            # dropped to keep code length aligned with the mel hop count.
            semb_ids = semb_ids[audio_start * 50 : ((audio_start + self.dur_) * 50) - 1]

        if ref_clip:
            # Reference clips are longer (6 s) so speaker conditioning sees
            # more context than the training crop.
            dur_ = 6
            max_audio_start = int(dur - dur_)
            if max_audio_start <= 0:
                audio_start = 0
            else:
                audio_start = np.random.randint(0, max_audio_start)
            audio_norm = audio_norm[
                :, audio_start * sampling_rate : (audio_start + dur_) * sampling_rate
            ]

        melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
        energy = []  # energy extraction disabled; kept for return-shape compat
        if align:
            return melspec, list(energy), semb_ids
        return melspec, list(energy)

    def __len__(self):
        if self.scale:
            return self.data_len
        return len(self.data)

    def __getitem__(self, index):
        """Return one example: target mel, code ids, reference mels, text ids.

        On any loading failure the next index (wrapping to 0) is tried.
        """
        if not self.scale:
            # FIX: entries are [key, codes, text]; the original unpacked four
            # names (lang, path, semb, text) and raised ValueError.
            path, semb, text = self.data[index]
            ref_mels = self.ref_mels[path][: self.ref_k]
            semb_ids = [int(i) + 1 for i in semb]  # shift by 1 (0 = pad)
            line = path  # FIX: ``line`` was undefined in this branch
        else:
            self.mmapped_file.seek(self.line_index[index])
            line = self.mmapped_file.readline().decode("utf-8")

            lang, path, text, semb_ids = line.split("|")
            semb_ids = [int(i) + 1 for i in semb_ids.split()]
            # In scale mode the utterance itself doubles as its reference clip.
            ref_mels = [path][: self.ref_k]

        def _skip_to_next(reason):
            # Fall back to a neighbouring example when this one is unusable.
            print(index, reason)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        try:
            mel_spec, energy, semb_ids = self.get_mel(path, semb_ids, align=True)
            if len(semb_ids) == 0:
                raise Exception("Sorry, no semb ids" + str(line))
        except Exception as e:
            return _skip_to_next(e)

        if len(ref_mels) == 0:
            # FIX: the original printed ``e`` here -- a NameError, since ``e``
            # does not survive past the except block in Python 3.
            return _skip_to_next("no ref mels present")

        while len(ref_mels) < self.ref_k:
            ref_mels.append(ref_mels[-1])

        if mel_spec is None:
            return _skip_to_next("mel_spec error present")

        def get_random_portion(mel, mask_lengths):
            # Crop a random CLIP_LENGTH window from each padded reference mel;
            # shorter references keep their (noise-padded) prefix.
            clip = mask_lengths <= CLIP_LENGTH
            ref_mel = mel[:, :, :CLIP_LENGTH].clone()
            for n, z in enumerate(clip):
                if not z:
                    start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
                    ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
            return ref_mel

        try:
            ref_mels = [self.get_mel(p, ref_clip=True)[0] for p in ref_mels]
        except Exception as e:
            return _skip_to_next(str(e) + " ref_mels mel_spec error")

        # Drop failed (None) references, then repeat the last good one until
        # exactly ref_k references are available.
        ref_c = [m for m in ref_mels[: self.ref_k] if m is not None]
        if len(ref_c) == 0:
            return _skip_to_next("no refs mel spec found")
        while len(ref_c) < self.ref_k:
            ref_c.append(ref_c[-1])

        # Pad references to a common length (near-zero noise as padding),
        # then crop a random CLIP_LENGTH portion of each.
        ref_mels = ref_c
        max_target_len = max([x.size(1) for x in ref_mels])
        ref_mels_padded = (
            torch.randn((self.ref_k, config.n_mel_channels, max_target_len)) * 1e-9
        )
        mel_length = []
        for i, mel in enumerate(ref_mels):
            ref_mels_padded[i, :, : mel.size(1)] = mel
            mel_length.append(mel.shape[-1])

        ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))

        # Character ids wrapped in start/end tokens; unknown characters dropped.
        text_ids = (
            [text_enc["<S>"]]
            + [text_enc[i] for i in text.strip() if i in text_enc]
            + [text_enc["<E>"]]
        )
        if self.ar_active:
            # Autoregressive mode: re-encode codes through the AR vocabulary
            # with start/end tokens (the +1 shift is undone first).
            semb_ids = (
                [code_enc["<SST>"]]
                + [code_enc[str(i - 1)] for i in semb_ids]
                + [code_enc["<EST>"]]
            )

        return {
            "mel": mel_spec,
            "code": semb_ids,
            "path": path,
            "ref_mels": ref_mels,
            "text": text_ids,
        }
|
|
|
|
|
|
|
|
def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    """Pad a batch of list-valued sequences to a common length.

    Args:
        sequences: list of lists to pad.
        pad_random: pad with tiny random floats (~1e-9) instead of ``pad__``.
        before: prepend the padding (left-pad) instead of appending it.
        pad__: constant padding value used when ``pad_random`` is False.

    Returns:
        ``(padded_sequences, original_lengths)``.

    FIX: the original rebound entries of the caller's list in place; this
    version builds and returns fresh lists, leaving the input untouched.
    """
    max_len = max(len(s) for s in sequences)
    padded = []
    seq_len = []
    for seq in sequences:
        seq_len.append(len(seq))
        if pad_random:
            pad_ = list(np.random.rand(max_len - len(seq)) * 1e-9)
        else:
            pad_ = [pad__] * (max_len - len(seq))
        padded.append(pad_ + seq if before else seq + pad_)
    return padded, seq_len
|
|
|
|
|
|
|
|
def collate(batch):
    """Collate dataset examples into padded batch tensors.

    Returns:
        (normalized mel batch, padded code ids, mel lengths, code lengths,
        padded reference mels, padded text ids, text lengths, audio paths).
    """
    mel_specs = [b["mel"] for b in batch]
    code = [b["code"] for b in batch]
    paths = [b["path"] for b in batch]
    ref_mels = [b["ref_mels"] for b in batch]
    text_ids = [b["text"] for b in batch]

    # AR-mode code sequences end in <EST>; those are padded with the explicit
    # <PAD> id, everything else with 0.
    if code[-1][-1] == code_enc["<EST>"]:
        code, code_len = get_padded_seq(code, pad_random=False, pad__=code_enc["<PAD>"])
    else:
        code, code_len = get_padded_seq(code, pad_random=False)

    # Text is left-padded so every transcript ends at the same position.
    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=True, pad__=text_enc["<PAD>"]
    )

    # Pad reference mels into (batch, ref_k, n_mels, T) near-zero noise.
    ref_max_target_len = max(x.size(-1) for x in ref_mels)
    ref_shape = (
        len(batch),
        ref_mels[0].shape[0],
        config.n_mel_channels,
        ref_max_target_len,
    )
    ref_mels_padded = torch.randn(ref_shape) * 1e-9
    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :, : mel.size(-1)] = mel

    # Pad target mels into (batch, n_mels, T) and record true frame counts.
    max_target_len = max(x.size(-1) for x in mel_specs)
    mel_padded = torch.randn((len(batch), config.n_mel_channels, max_target_len)) * 1e-9
    mel_length = []
    for i, mel in enumerate(mel_specs):
        mel_padded[i, :, : mel.size(-1)] = mel
        mel_length.append(mel.shape[-1])

    return (
        normalize_tacotron_mel(mel_padded),
        torch.tensor(code),
        torch.tensor(mel_length),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(text_ids),
        torch.tensor(text_len),
        paths,
    )
|
|
|
|
|
|
|
|
def train(
    model,
    diffuser,
    train_dataloader,
    val_dataloader,
    schedule_sampler=None,
    rank=0,
    ar_active=False,
    m1=None,
    checkpoint_initial=None,
):
    """Train the flow-matching acoustic model via HuggingFace Accelerate.

    Args:
        model: DiffModel to optimise.
        diffuser: unused here; kept for call-site compatibility.
        train_dataloader / val_dataloader: loaders yielding ``collate`` batches.
        schedule_sampler: unused here.
        rank: unused; Accelerate determines process rank itself.
        ar_active: when True, auxiliary model ``m1`` is moved to the training
            device and forwarded to ``val``.
        m1: optional auxiliary AR model (only used when ``ar_active``).
        checkpoint_initial: optional checkpoint path to resume from.
    """
    # Not every parameter receives gradients each step under DDP.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )
    # W&B run is created on the main local process only.
    if config.sa_wandb_logs and accelerator.is_local_main_process:
        conf_ = {}
        for i, j in config.__dict__.items():
            conf_[str(i)] = str(j)
        wandb_log = wandb.init(
            project=config.wandb_project,
            entity=config.user_name,
            name=config.model_name,
            config=conf_,
        )
        wandb_log.watch(model, log_freq=100)
    else:
        wandb_log = None

    model.train()
    optimizer = optim.AdamW(
        model.parameters(), lr=config.sa_lr, weight_decay=config.sa_weight_decay
    )
    lr = config.sa_lr  # NOTE(review): never read after assignment
    min_val_loss = 1000  # best validation loss observed so far
    step_num = 0
    start_epoch = 0
    if checkpoint_initial is not None:
        # Resume model/optimizer state and epoch counter from the checkpoint.
        # NOTE(review): the checkpoint file is re-loaded from disk four times
        # below; loading it once into a variable would do the same work.
        print(checkpoint_initial)
        model.load_state_dict(
            torch.load(checkpoint_initial, map_location=torch.device("cpu"))["model"],
            strict=True,
        )
        model.train()
        optimizer.load_state_dict(
            torch.load(checkpoint_initial, map_location=torch.device("cpu"))[
                "optimizer"
            ]
        )
        step_num = int(
            torch.load(checkpoint_initial, map_location=torch.device("cpu"))["step"]
        )
        # NOTE(review): this immediately discards the restored step counter,
        # so eval/log cadence restarts from 0 on resume -- confirm intent.
        step_num = 0
        start_epoch = (
            int(
                torch.load(checkpoint_initial, map_location=torch.device("cpu"))[
                    "epoch"
                ]
            )
            + 1
        )
        print(f"resuming training from epoch {start_epoch} and step {step_num}")

    # Wrap loader/model/optimizer for (possibly distributed) execution.
    train_dataloader, model, optimizer = accelerator.prepare(
        train_dataloader, model, optimizer
    )

    FM = BASECFM()  # conditional flow-matching loss helper
    device = next(model.parameters()).device
    if ar_active:
        m1 = m1.to(device)

    loading_time = []  # NOTE(review): never populated or read
    for i in range(start_epoch, config.sa_epochs):
        epoch_loss = {"vlb": [], "mse": [], "loss": []}
        if accelerator.is_local_main_process:
            train_loader = tqdm(train_dataloader, desc="Training epoch %d" % (i))
        else:
            train_loader = train_dataloader

        for inputs in train_loader:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                # See collate() for the batch tuple layout; text ids and
                # lengths are currently unused by the loss.
                x1, code_emb, mask_lengths, _, ref_mels, text_ids, _, _ = inputs
                mask = get_mask_from_lengths(mask_lengths).unsqueeze(1)
                mask = mask.squeeze(1).float()

                loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                step_num += 1

                epoch_loss["loss"].append(loss.item())

                # Once per effective optimisation step: average the loss
                # across processes, log it, and reset the accumulator.
                if step_num % config.gradient_accumulation_steps == 0:
                    epoch_training_loss = torch.tensor(
                        sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
                    ).to(device)
                    epoch_loss = {"vlb": [], "mse": [], "loss": []}
                    epoch_training_loss = (
                        accelerator.gather_for_metrics(epoch_training_loss)
                        .mean()
                        .item()
                    )

                    if config.sa_wandb_logs and accelerator.is_local_main_process:
                        wandb_log.log({"training_loss": epoch_training_loss})

                # Periodic checkpoint + validation, on the main process only.
                if (
                    step_num % (config.sa_eval_step * config.gradient_accumulation_steps)
                    == 0
                ):
                    print(f"evaluation at step_num {step_num}")
                    if accelerator.is_local_main_process:

                        # Rolling "latest" checkpoint saved before validating.
                        # NOTE(review): saved under config.save_root_dir, not
                        # the model_name/S2A subdir created in __main__ --
                        # confirm intended location.
                        unwrapped_model = accelerator.unwrap_model(model)
                        checkpoint = {
                            "epoch": i,
                            "step": step_num // config.gradient_accumulation_steps,
                            "model": unwrapped_model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "norms": config.norms,
                        }
                        torch.save(
                            checkpoint,
                            os.path.join(config.save_root_dir, "latest.pt",),
                        )

                    if accelerator.is_local_main_process:
                        val_loss, val_mse, val_vlb, time_steps_mean = val(
                            model,
                            FM,
                            val_dataloader,
                            infer_=config.sa_infer,
                            epoch=i,
                            rank=accelerator.is_local_main_process,
                            ar_active=ar_active,
                            m1=m1,
                        )
                        # val() leaves the model in eval mode; restore train.
                        model.train()

                        print(
                            "validation loss : ",
                            val_loss,
                            "\nvalidation mse loss : ",
                            val_mse,
                            "\nvalidation vlb loss : ",
                            val_vlb,
                        )
                        if config.sa_wandb_logs:
                            wandb_log.log({"val_loss": val_loss})
                        # Keep a separate "best" checkpoint by validation loss.
                        if val_loss < min_val_loss:
                            unwrapped_model = accelerator.unwrap_model(model)
                            checkpoint = {
                                "epoch": i,
                                "step": step_num // config.gradient_accumulation_steps,
                                "model": unwrapped_model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "norms": config.norms,
                            }
                            torch.save(
                                checkpoint,
                                os.path.join(config.save_root_dir, "_best.pt"),
                            )
                            min_val_loss = val_loss

        # NOTE(review): hard stop 12 epochs after the (resume) start epoch;
        # skips wandb_log.finish() below -- confirm this is intentional.
        if i == start_epoch + 12:
            exit()
    if config.sa_wandb_logs and accelerator.is_local_main_process:
        wandb_log.finish()
|
|
|
|
|
|
|
|
def val(
    model,
    FM,
    val_dataloader,
    infer_=False,
    epoch=0,
    rank=False,
    ar_active=False,
    m1=None,
):
    """Run validation and optionally log sample inferences.

    Args:
        model: acoustic model (left in eval mode on return).
        FM: BASECFM flow-matching loss helper.
        val_dataloader: loader yielding ``collate`` batches.
        infer_: run ``infer`` on a few samples every ``sa_infer_epoch`` epochs.
        epoch: current epoch number (for progress/captions).
        rank: truthy on the main process; gates tqdm and inference.
        ar_active: selects the AR-shaped code-embedding slicing for inference.
        m1: unused in this function body.

    Returns:
        (mean loss, mean mse, mean vlb, mean sampled timestep).
        NOTE(review): mse and vlb are aliases of the same loss value here.
    """
    model.eval()
    epoch_loss = {"vlb": [], "mse": [], "loss": []}
    # Loop variables are pre-declared because the inference block after the
    # loop reuses the *last batch's* values.
    code_emb = None
    x = None
    mask_lengths = None
    time_steps_mean = []
    device = next(model.parameters()).device
    if rank:
        val_dataloader = tqdm(val_dataloader, desc="validation epoch %d" % (epoch))
    else:
        val_dataloader = val_dataloader

    with torch.no_grad():
        for inputs in val_dataloader:
            x1, code_emb, mask_lengths, code_len, ref_mels, text_ids, _, _ = inputs

            mask = get_mask_from_lengths(mask_lengths).unsqueeze(1).to(device)
            mask = mask.squeeze(1).float()
            x1 = x1.to(device)
            code_emb = code_emb.to(device)
            text_ids = text_ids.to(device)
            ref_mels = ref_mels.to(device)

            loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)
            # Track the flow-matching timesteps sampled during validation.
            time_steps_mean.extend(t.detach().cpu().squeeze(-1).squeeze(-1).tolist())
            mse = loss  # NOTE(review): placeholder -- same tensor as loss
            vlb = loss  # NOTE(review): placeholder -- same tensor as loss

            epoch_loss["loss"].append(loss.item())
            epoch_loss["mse"].append(mse.item())
            epoch_loss["vlb"].append(vlb.item())

    epoch_vlb_loss = sum(epoch_loss["vlb"]) / len(epoch_loss["vlb"])
    epoch_training_loss = sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
    epoch_mse_loss = sum(epoch_loss["mse"]) / len(epoch_loss["mse"])
    # Sample inference on the last validation batch (first k items).
    if rank and infer_ and epoch % config.sa_infer_epoch == 0:
        k = 4
        if ar_active:
            # AR codes carry an extra channel dimension.
            code_embs = [code_emb[i, :, : code_len[i]] for i in range(k)]
        else:
            code_embs = [code_emb[i, : code_len[i]] for i in range(k)]
        audio_paths, mels = infer(
            model, mask_lengths[:k], code_embs, ref_mels[:k, :], epoch
        )

        if config.sa_wandb_logs:
            # Log predicted audio/mels alongside the ground-truth mels.
            images = [
                wandb.Image(mel[0], caption="epoch: " + str(epoch)) for mel in mels
            ]
            x = [
                wandb.Image(x1[i, :, : mask_lengths[i]], caption="Actual: ")
                for i in range(k)
            ]
            wandb.log(
                {
                    "predicted audio": [
                        wandb.Audio(audio_path) for audio_path in audio_paths
                    ],
                    "predicted melspec": images,
                    "actual melspec": x,
                    "epoch": epoch,
                }
            )

    return (
        epoch_training_loss,
        epoch_mse_loss,
        epoch_vlb_loss,
        sum(time_steps_mean) / len(time_steps_mean),
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Output directory for this run's artifacts.
    # NOTE(review): train() saves checkpoints directly under
    # config.save_root_dir, not this model_name/S2A subdirectory -- confirm
    # the intended save location.
    os.makedirs(os.path.join(config.save_root_dir, config.model_name, "S2A"), exist_ok=True)

    # Flow-matching acoustic model: semantic codes -> 100-bin mel frames.
    model = DiffModel(
        input_channels=100,
        output_channels=100,
        model_channels=512,
        num_heads=8,
        dropout=0.10,
        num_layers=8,
        enable_fp16=False,
        condition_free_per=0.0,
        multispeaker=True,
        style_tokens=100,
        training=True,
        ar_active=False,
        in_latent_channels=len(code_labels),
    )
    m1 = None  # auxiliary AR model; unused in this configuration
    checkpoint = None  # set to a checkpoint path to resume training
    print("Model Loaded")
    print("batch_size :", config.sa_batch_size)
    print("Diffusion timesteps:", config.sa_timesteps_max)

    file_name_train = config.train_file
    file_name_val = config.val_file

    train_dataset = Acoustic_dataset(file_name_train, scale=config.scale)
    train_dataloader = DataLoader(
        train_dataset,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.sa_num_workers,
        batch_size=config.sa_batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate,
    )

    # Validation uses longer (5 s) crops than training.
    val_dataset = Acoustic_dataset(file_name_val, scale=config.scale, dur_=5)
    val_dataloader = DataLoader(
        val_dataset,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.sa_num_workers,
        batch_size=config.sa_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate,
    )

    train(
        model,
        diffuser=None,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        rank=0,
        ar_active=False,
        m1=m1,
        checkpoint_initial=checkpoint,
    )
|
|
|