Spaces:
Runtime error
Runtime error
| import librosa.display as lbd | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.multiprocessing | |
| from torch.cuda.amp import GradScaler | |
| from torch.cuda.amp import autocast | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torch.utils.data.dataloader import DataLoader | |
| from tqdm import tqdm | |
| from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend | |
| from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id | |
| from Utility.WarmupScheduler import WarmupScheduler | |
| from Utility.path_to_transcript_dicts import * | |
| from Utility.utils import cumsum_durations | |
| from Utility.utils import delete_old_checkpoints | |
| from Utility.utils import get_most_recent_checkpoint | |
| def train_loop(net, | |
| datasets, | |
| device, | |
| save_directory, | |
| batch_size, | |
| steps, | |
| steps_per_checkpoint, | |
| lr, | |
| path_to_checkpoint, | |
| resume=False, | |
| warmup_steps=4000): | |
| # ============ | |
| # Preparations | |
| # ============ | |
| net = net.to(device) | |
| torch.multiprocessing.set_sharing_strategy('file_system') | |
| train_loaders = list() | |
| train_iters = list() | |
| for dataset in datasets: | |
| train_loaders.append(DataLoader(batch_size=batch_size, | |
| dataset=dataset, | |
| drop_last=True, | |
| num_workers=2, | |
| pin_memory=True, | |
| shuffle=True, | |
| prefetch_factor=5, | |
| collate_fn=collate_and_pad, | |
| persistent_workers=True)) | |
| train_iters.append(iter(train_loaders[-1])) | |
| default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None} | |
| for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]): | |
| default_embedding = None | |
| for datapoint in datasets[index]: | |
| if default_embedding is None: | |
| default_embedding = datapoint[7].squeeze() | |
| else: | |
| default_embedding = default_embedding + datapoint[7].squeeze() | |
| default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device) | |
| optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0) | |
| grad_scaler = GradScaler() | |
| scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps) | |
| if resume: | |
| previous_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory) | |
| if previous_checkpoint is not None: | |
| path_to_checkpoint = previous_checkpoint | |
| else: | |
| raise RuntimeError(f"No checkpoint found that can be resumed from in {save_directory}") | |
| step_counter = 0 | |
| train_losses_total = list() | |
| if path_to_checkpoint is not None: | |
| check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device) | |
| net.load_state_dict(check_dict["model"]) | |
| if resume: | |
| optimizer.load_state_dict(check_dict["optimizer"]) | |
| step_counter = check_dict["step_counter"] | |
| grad_scaler.load_state_dict(check_dict["scaler"]) | |
| scheduler.load_state_dict(check_dict["scheduler"]) | |
| if step_counter > steps: | |
| print("Desired steps already reached in loaded checkpoint.") | |
| return | |
| net.train() | |
| # ============================= | |
| # Actual train loop starts here | |
| # ============================= | |
| for step in tqdm(range(step_counter, steps)): | |
| batches = [] | |
| for index in range(len(datasets)): | |
| # we get one batch for each task (i.e. language in this case) | |
| try: | |
| batch = next(train_iters[index]) | |
| batches.append(batch) | |
| except StopIteration: | |
| train_iters[index] = iter(train_loaders[index]) | |
| batch = next(train_iters[index]) | |
| batches.append(batch) | |
| train_loss = 0.0 | |
| for batch in batches: | |
| with autocast(): | |
| # we sum the loss for each task, as we would do for the | |
| # second order regular MAML, but we do it only over one | |
| # step (i.e. iterations of inner loop = 1) | |
| train_loss = train_loss + net(text_tensors=batch[0].to(device), | |
| text_lengths=batch[1].to(device), | |
| gold_speech=batch[2].to(device), | |
| speech_lengths=batch[3].to(device), | |
| gold_durations=batch[4].to(device), | |
| gold_pitch=batch[6].to(device), # mind the switched order | |
| gold_energy=batch[5].to(device), # mind the switched order | |
| utterance_embedding=batch[7].to(device), | |
| lang_ids=batch[8].to(device), | |
| return_mels=False) | |
| # then we directly update our meta-parameters without | |
| # the need for any task specific parameters | |
| train_losses_total.append(train_loss.item()) | |
| optimizer.zero_grad() | |
| grad_scaler.scale(train_loss).backward() | |
| grad_scaler.unscale_(optimizer) | |
| torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False) | |
| grad_scaler.step(optimizer) | |
| grad_scaler.update() | |
| scheduler.step() | |
| if step % steps_per_checkpoint == 0: | |
| # ============================== | |
| # Enough steps for some insights | |
| # ============================== | |
| net.eval() | |
| print(f"Total Loss: {round(sum(train_losses_total) / len(train_losses_total), 3)}") | |
| train_losses_total = list() | |
| torch.save({ | |
| "model" : net.state_dict(), | |
| "optimizer" : optimizer.state_dict(), | |
| "scaler" : grad_scaler.state_dict(), | |
| "scheduler" : scheduler.state_dict(), | |
| "step_counter": step, | |
| "default_emb" : default_embeddings["en"] | |
| }, | |
| os.path.join(save_directory, "checkpoint_{}.pt".format(step))) | |
| delete_old_checkpoints(save_directory, keep=5) | |
| for lang in ["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]: | |
| plot_progress_spec(net=net, | |
| device=device, | |
| lang=lang, | |
| save_dir=save_directory, | |
| step=step, | |
| utt_embeds=default_embeddings) | |
| net.train() | |
| def plot_progress_spec(net, device, save_dir, step, lang, utt_embeds): | |
| tf = ArticulatoryCombinedTextFrontend(language=lang) | |
| sentence = "" | |
| default_embed = utt_embeds[lang] | |
| if lang == "en": | |
| sentence = "This is a complex sentence, it even has a pause!" | |
| elif lang == "de": | |
| sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!" | |
| elif lang == "el": | |
| sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!" | |
| elif lang == "es": | |
| sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!" | |
| elif lang == "fi": | |
| sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!" | |
| elif lang == "ru": | |
| sentence = "Это сложное предложение, в нем даже есть пауза!" | |
| elif lang == "hu": | |
| sentence = "Ez egy összetett mondat, még szünet is van benne!" | |
| elif lang == "nl": | |
| sentence = "Dit is een complexe zin, er zit zelfs een pauze in!" | |
| elif lang == "fr": | |
| sentence = "C'est une phrase complexe, elle a même une pause !" | |
| phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device) | |
| spec, durations, *_ = net.inference(text=phoneme_vector, | |
| return_duration_pitch_energy=True, | |
| utterance_embedding=default_embed, | |
| lang_id=get_language_id(lang).to(device)) | |
| spec = spec.transpose(0, 1).to("cpu").numpy() | |
| duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) | |
| if not os.path.exists(os.path.join(save_dir, "spec")): | |
| os.makedirs(os.path.join(save_dir, "spec")) | |
| fig, ax = plt.subplots(nrows=1, ncols=1) | |
| lbd.specshow(spec, | |
| ax=ax, | |
| sr=16000, | |
| cmap='GnBu', | |
| y_axis='mel', | |
| x_axis=None, | |
| hop_length=256) | |
| ax.yaxis.set_visible(False) | |
| ax.set_xticks(duration_splits, minor=True) | |
| ax.xaxis.grid(True, which='minor') | |
| ax.set_xticks(label_positions, minor=False) | |
| ax.set_xticklabels(tf.get_phone_string(sentence)) | |
| ax.set_title(sentence) | |
| plt.savefig(os.path.join(os.path.join(save_dir, "spec"), f"{step}_{lang}.png")) | |
| plt.clf() | |
| plt.close() | |
| def collate_and_pad(batch): | |
| # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id | |
| return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True), | |
| torch.stack([datapoint[1] for datapoint in batch]).squeeze(1), | |
| pad_sequence([datapoint[2] for datapoint in batch], batch_first=True), | |
| torch.stack([datapoint[3] for datapoint in batch]).squeeze(1), | |
| pad_sequence([datapoint[4] for datapoint in batch], batch_first=True), | |
| pad_sequence([datapoint[5] for datapoint in batch], batch_first=True), | |
| pad_sequence([datapoint[6] for datapoint in batch], batch_first=True), | |
| torch.stack([datapoint[7] for datapoint in batch]).squeeze(), | |
| torch.stack([datapoint[8] for datapoint in batch])) | |