| import glob |
| from collections import OrderedDict, defaultdict |
| import os |
| import dill |
| import wandb |
| import torch_optimizer as optim_extra |
| import torch |
| from torch import optim |
| from torch.utils.data import ConcatDataset, DataLoader |
| import torchaudio |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from dataloader import (TrainTestDataset, |
| TrainValTestDataset, collate_fn_padd, spectral_size) |
| from next_frame_classifier import NextFrameClassifier |
| from utils import (PrecisionRecallMetric, StatsMeter, |
| detect_peaks, line, max_min_norm, replicate_first_k_frames) |
|
|
| import dutch_preprocess |
| from utils import timit_to_leehon_map_MACRO |
|
|
|
|
| class Solver(torch.nn.Module): |
| def __init__(self, cfg): |
| super().__init__() |
| hp = cfg |
| self.hp = hp |
| self.hparams = hp |
| self.current_epoch = 0 |
| |
| self.peak_detection_params = defaultdict(lambda: { |
| "width": None, |
| "distance": None |
| }) |
| self.pr = defaultdict(lambda: { |
| "train": PrecisionRecallMetric(), |
| "val": PrecisionRecallMetric(), |
| "test": PrecisionRecallMetric() |
| }) |
| |
| self.best_l1_dist= defaultdict(lambda: { |
| "train": (float("inf"), 0), |
| "val": (float("inf"), 0), |
| "test": (float("inf"), 0) |
| }) |
|
|
| self.overall_best_l1_dist = 0 |
| self.stats = defaultdict(lambda: { |
| "train": StatsMeter(), |
| "val": StatsMeter(), |
| "test": StatsMeter() |
| }) |
| |
| self.build_model() |
| self.prepare_data() |
| self.configure_optimizers() |
| |
| def build_model(self): |
| print("MODEL:") |
| self.NFC = NextFrameClassifier(self.hp) |
| line() |
|
|
| def prepare_data(self): |
| if "timit" in self.hp.data and "buckeye" in self.hp.data: |
| |
| t_train, t_val, t_test = TrainTestDataset.get_datasets(path=self.hp.timit_path) |
| b_train, b_val, b_test = TrainValTestDataset.get_datasets(path=self.hp.buckeye_path, percent=self.hp.buckeye_percent) |
| train = ConcatDataset([t_train, b_train]) |
| val = ConcatDataset([t_val, b_val]) |
| test = ConcatDataset([t_test, b_test]) |
| train.path = f"TIMIT+Buckeye train ({len(t_train)}+{len(b_train)})" |
| val.path = f"TIMIT+Buckeye val ({len(t_val)}+{len(b_val)})" |
| test.path = f"TIMIT+Buckeye test ({len(t_test)}+{len(b_test)})" |
| elif "timit" in self.hp.data: |
| train, val, test = TrainTestDataset.get_datasets(path=self.hp.timit_path) |
| elif "buckeye" in self.hp.data: |
| train, val, test = TrainValTestDataset.get_datasets(path=self.hp.buckeye_path, percent=self.hp.buckeye_percent) |
| else: |
| raise Exception("no such training data!") |
|
|
| self.train_dataset = train |
| self.valid_dataset = val |
| self.test_dataset = test |
| |
| line() |
| print("DATA:") |
| print(f"train: {self.train_dataset.path} ({len(self.train_dataset)})") |
| print(f"valid: {self.valid_dataset.path} ({len(self.valid_dataset)})") |
| print(f"test: {self.test_dataset.path} ({len(self.test_dataset)})") |
| line() |
| |
| _persistent = self.hp.dataloader_n_workers > 0 |
| self.train_loader = DataLoader(self.train_dataset, |
| batch_size=self.hp.batch_size, |
| shuffle=True, |
| collate_fn=collate_fn_padd, |
| num_workers=self.hp.dataloader_n_workers, |
| pin_memory=True, |
| persistent_workers=_persistent) |
|
|
| self.valid_loader = DataLoader(self.valid_dataset, |
| batch_size=self.hp.batch_size, |
| shuffle=False, |
| collate_fn=collate_fn_padd, |
| num_workers=self.hp.dataloader_n_workers, |
| pin_memory=True, |
| persistent_workers=_persistent) |
|
|
| self.test_loader = DataLoader(self.test_dataset, |
| batch_size=self.hp.batch_size, |
| shuffle=False, |
| collate_fn=collate_fn_padd, |
| num_workers=self.hp.dataloader_n_workers, |
| pin_memory=True, |
| persistent_workers=_persistent) |
|
|
| def configure_optimizers(self): |
| parameters = filter(lambda p: p.requires_grad, self.NFC.parameters()) |
| if self.hp.optimizer == "sgd": |
| self.optimizer = optim.SGD(parameters, lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) |
| elif self.hp.optimizer == "adam": |
| self.optimizer = optim.Adam(parameters, lr=self.hparams.lr, weight_decay=5e-4) |
| elif self.hp.optimizer == "adamW": |
| self.optimizer = optim.AdamW(parameters, lr=self.hparams.lr) |
| elif self.hp.optimizer == "ranger": |
| self.optimizer = optim_extra.Ranger(parameters, lr=self.hparams.lr, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0) |
| else: |
| raise Exception("unknown optimizer") |
| print(f"optimizer: {self.optimizer}") |
| line() |
| self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, |
| step_size=self.hp.lr_anneal_step, |
| gamma=self.hp.lr_anneal_gamma) |
| return [self.optimizer] |
|
|
| def train_one_epoch(self): |
| self.run_epoch("train", self.current_epoch) |
| return self.stats['nfc_loss']['train'].get_stats() |
| |
| def validate(self): |
| self.evaluate("val", self.current_epoch) |
| return self.stats['nfc_loss']['val'].get_stats() |
| |
| def run_epoch(self, mode, epoch): |
| self.current_epoch = epoch |
| is_train = mode=="train" |
| loader = getattr(self, f"{mode}_loader") |
| self.NFC.train(mode=="train") |
| |
| for batch_i, batch in enumerate(tqdm(loader)): |
| if is_train: |
| self.optimizer.zero_grad() |
| result = self.forward(batch,batch_i,mode) |
| |
| if is_train: |
| result['loss'].backward() |
| self.optimizer.step() |
| if is_train: |
| self.scheduler.step() |
| self.generic_eval_end(mode, epoch) |
|
|
| |
| def evaluate(self, mode, epoch=None): |
| loader = self.valid_loader if mode == 'val' else self.test_loader |
| self.NFC.eval() |
| with torch.no_grad(): |
| for batch_i, batch in enumerate(tqdm(loader)): |
| self.forward(batch, batch_i, mode) |
| self.generic_eval_end(mode, epoch if epoch is not None else self.current_epoch) |
| |
| def test(self): |
| self.evaluate('test', epoch=-1) |
| |
| def forward(self, data_batch, batch_i, mode): |
| loss = 0 |
| |
| |
| audio, seg, phonemes, length, fname = data_batch |
| |
| language = "english" |
| if language == "dutch": |
| lh39_ph = [] |
| for phoneme_seq in phonemes: |
| lh39_ph_seq = [] |
| for IFA_ph in phoneme_seq: |
| print(f"\nINPUT: {IFA_ph}") |
| output = dutch_preprocess.aligner_pipeline(timit_to_leehon_map_MACRO[IFA_ph.lower()]) |
| lh39_ph_seq.append([x["lh39"] for x in output]) |
| if not output: |
| print("Results: None") |
| print(phoneme_seq) |
| print(f"Dutch IPA to LH39 mapping: {lh39_ph_seq}") |
| lh39_ph.append(np.hstack(lh39_ph_seq).tolist()) |
| print(f"Dutch IPA to LH39 total: {lh39_ph}") |
| phonemes = lh39_ph |
| |
| |
| |
| |
| if mode in ["test", "val"]: |
| with torch.no_grad(): |
| self.NFC.eval() |
| preds,original_lengths, probs, frame_labels, _,preds_peaks, w_phi = self.NFC(audio,None,phonemes,length) |
| else: |
| self.NFC.train() |
| preds,original_lengths, probs, frame_labels,_,preds_peaks, w_phi = self.NFC(audio,seg,phonemes,length) |
| |
| epoch = self.current_epoch |
| if epoch < 5: |
| total_loss, ph_loss, loss_nce, sum_mse, w_pos_neg, w_phi = self.NFC.loss_ph(preds,original_lengths, probs, frame_labels, seg,preds_peaks, w_phi, phonemes) |
| else: |
| total_loss, ph_loss, loss_nce, sum_mse, w_pos_neg, w_phi = self.NFC.total_loss(preds,original_lengths, probs, frame_labels, seg,preds_peaks, w_phi, phonemes) |
|
|
| |
| |
| |
| NFC_loss = total_loss |
| |
| self.stats['w_pos'][mode].update(torch.tensor(w_pos_neg[0].item())) |
| self.stats['w_neg'][mode].update(torch.tensor(w_pos_neg[1].item())) |
| self.stats['w_phi1'][mode].update(torch.tensor(w_phi[0].item())) |
| self.stats['w_phi2'][mode].update(torch.tensor(w_phi[1].item())) |
| self.stats['ph_loss'][mode].update(torch.tensor(ph_loss.item())) |
| self.stats['nce_loss'][mode].update(torch.tensor(loss_nce.item())) |
| self.stats['softDPmse_loss'][mode].update(torch.tensor(sum_mse.item())) |
| self.stats['nfc_loss'][mode].update(torch.tensor(NFC_loss.item())) |
| loss += NFC_loss |
| loss_key = "loss" if mode == "train" else f"{mode}_loss" |
| return OrderedDict({ |
| loss_key: loss |
| }) |
|
|
| def generic_eval_end(self, mode, epoch): |
| metrics = {} |
| data = self.hp.data |
|
|
| for k, v in self.stats.items(): |
| metrics[f"{mode}_{k}"] = self.stats[k][mode].get_stats() |
| metrics['epoch'] = epoch+1 |
| metrics['current_lr'] = self.optimizer.param_groups[0]['lr'] |
|
|
| line() |
| |
| best_overall_l1_dist = float("inf") |
| for pred_type, l1_dist in self.best_l1_dist.items(): |
| if l1_dist[mode][0] < best_overall_l1_dist: |
| best_overall_l1_dist = l1_dist[mode][0] |
| metrics[f'{mode}_min_l1_dist'] = best_overall_l1_dist |
| for k, v in metrics.items(): |
| print(f"\t{k:<30} -- {v}") |
| line() |
| wandb.log(metrics, step=epoch) |
|
|
| output = OrderedDict({ |
| 'log': metrics |
| }) |
|
|
| return output |
| |
| def get_ckpt_path(self): |
| |
| return glob.glob(os.path.join(self.hp.wd + "/*.ckpt"))[0] |