# FALCON / solver.py
# Uploaded by MLSpeech (commit d8968a6, verified; 11 kB).
# Commit message: Add solver.py (referenced by checkpoint peak_detection_params on load)
import glob
from collections import OrderedDict, defaultdict
import os
import dill
import wandb
import torch_optimizer as optim_extra
import torch
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader
import torchaudio
import numpy as np
from tqdm import tqdm
from dataloader import (TrainTestDataset,
TrainValTestDataset, collate_fn_padd, spectral_size)
from next_frame_classifier import NextFrameClassifier
from utils import (PrecisionRecallMetric, StatsMeter,
detect_peaks, line, max_min_norm, replicate_first_k_frames)
import dutch_preprocess
from utils import timit_to_leehon_map_MACRO
class Solver(torch.nn.Module):
    """Training / evaluation driver built around a ``NextFrameClassifier``.

    Constructing a ``Solver`` builds the model (``self.NFC``), the
    train/valid/test datasets and loaders, and the optimizer plus a ``StepLR``
    scheduler, all from the configuration object passed as ``cfg``.

    Per-split running statistics are kept in ``self.stats`` (a mapping
    ``name -> {"train" | "val" | "test" -> StatsMeter}``) and are printed and
    sent to Weights & Biases by ``generic_eval_end``.
    """

    # While ``current_epoch`` is below this value ``forward`` uses
    # ``NFC.loss_ph``; from this epoch on it uses ``NFC.total_loss``.
    LOSS_PH_EPOCHS = 5

    def __init__(self, cfg):
        """Store the config and build model, data pipeline and optimizer.

        Args:
            cfg: configuration object; the attributes read by this class are
                ``data``, ``timit_path``, ``buckeye_path``, ``buckeye_percent``,
                ``batch_size``, ``dataloader_n_workers``, ``optimizer``, ``lr``,
                ``lr_anneal_step``, ``lr_anneal_gamma`` and ``wd``.
        """
        super().__init__()
        # ``hp`` and ``hparams`` are two names for the same config object.
        self.hp = cfg
        self.hparams = cfg
        self.current_epoch = 0

        # NOTE(review): saved checkpoints reportedly reference this module when
        # ``peak_detection_params`` is loaded, so the default factories below
        # are intentionally left as lambdas -- confirm before refactoring them.
        self.peak_detection_params = defaultdict(lambda: {
            "width": None,
            "distance": None
        })
        self.pr = defaultdict(lambda: {
            "train": PrecisionRecallMetric(),
            "val": PrecisionRecallMetric(),
            "test": PrecisionRecallMetric()
        })
        # Each entry is a 2-tuple whose first element is the distance compared
        # in ``generic_eval_end``.
        self.best_l1_dist = defaultdict(lambda: {
            "train": (float("inf"), 0),
            "val": (float("inf"), 0),
            "test": (float("inf"), 0)
        })
        self.overall_best_l1_dist = 0
        self.stats = defaultdict(lambda: {
            "train": StatsMeter(),
            "val": StatsMeter(),
            "test": StatsMeter()
        })

        self.build_model()
        self.prepare_data()
        self.configure_optimizers()

    def build_model(self):
        """Instantiate the ``NextFrameClassifier`` as ``self.NFC``."""
        print("MODEL:")
        self.NFC = NextFrameClassifier(self.hp)
        line()

    def _make_loader(self, dataset, shuffle):
        """Build a ``DataLoader`` with the settings shared by all splits."""
        n_workers = self.hp.dataloader_n_workers
        return DataLoader(dataset,
                          batch_size=self.hp.batch_size,
                          shuffle=shuffle,
                          collate_fn=collate_fn_padd,
                          num_workers=n_workers,
                          pin_memory=True,
                          # persistent workers require at least one worker
                          persistent_workers=n_workers > 0)

    def prepare_data(self):
        """Create the datasets selected by ``hp.data`` and their loaders.

        ``hp.data`` is tested with ``in`` for "timit" and "buckeye"; when both
        match, the corresponding splits are concatenated.

        Raises:
            ValueError: if ``hp.data`` mentions neither dataset.
        """
        use_timit = "timit" in self.hp.data
        use_buckeye = "buckeye" in self.hp.data

        if use_timit and use_buckeye:
            # joint training on TIMIT + Buckeye (data: timit_buckeye)
            t_train, t_val, t_test = TrainTestDataset.get_datasets(path=self.hp.timit_path)
            b_train, b_val, b_test = TrainValTestDataset.get_datasets(
                path=self.hp.buckeye_path, percent=self.hp.buckeye_percent)
            train = ConcatDataset([t_train, b_train])
            val = ConcatDataset([t_val, b_val])
            test = ConcatDataset([t_test, b_test])
            # ConcatDataset has no ``path``; attach one for the summary below.
            train.path = f"TIMIT+Buckeye train ({len(t_train)}+{len(b_train)})"
            val.path = f"TIMIT+Buckeye val ({len(t_val)}+{len(b_val)})"
            test.path = f"TIMIT+Buckeye test ({len(t_test)}+{len(b_test)})"
        elif use_timit:
            train, val, test = TrainTestDataset.get_datasets(path=self.hp.timit_path)
        elif use_buckeye:
            train, val, test = TrainValTestDataset.get_datasets(
                path=self.hp.buckeye_path, percent=self.hp.buckeye_percent)
        else:
            raise ValueError(f"no such training data! (data={self.hp.data!r})")

        self.train_dataset = train
        self.valid_dataset = val
        self.test_dataset = test

        line()
        print("DATA:")
        print(f"train: {self.train_dataset.path} ({len(self.train_dataset)})")
        print(f"valid: {self.valid_dataset.path} ({len(self.valid_dataset)})")
        print(f"test: {self.test_dataset.path} ({len(self.test_dataset)})")
        line()

        # Only the training split is shuffled.
        self.train_loader = self._make_loader(self.train_dataset, shuffle=True)
        self.valid_loader = self._make_loader(self.valid_dataset, shuffle=False)
        self.test_loader = self._make_loader(self.test_dataset, shuffle=False)

    def configure_optimizers(self):
        """Create ``self.optimizer`` and ``self.scheduler``.

        The optimizer is chosen by ``hp.optimizer`` ("sgd", "adam", "adamW" or
        "ranger") and receives only the parameters of ``self.NFC`` that
        require gradients.

        Returns:
            A one-element list holding the optimizer.

        Raises:
            ValueError: if ``hp.optimizer`` is not one of the names above.
        """
        parameters = [p for p in self.NFC.parameters() if p.requires_grad]
        name = self.hp.optimizer
        lr = self.hp.lr

        if name == "sgd":
            self.optimizer = optim.SGD(parameters, lr=lr, momentum=0.9, weight_decay=5e-4)
        elif name == "adam":
            self.optimizer = optim.Adam(parameters, lr=lr, weight_decay=5e-4)
        elif name == "adamW":
            self.optimizer = optim.AdamW(parameters, lr=lr)
        elif name == "ranger":
            self.optimizer = optim_extra.Ranger(parameters, lr=lr, alpha=0.5, k=6,
                                                N_sma_threshhold=5, betas=(.95, 0.999),
                                                eps=1e-5, weight_decay=0)
        else:
            raise ValueError(f"unknown optimizer: {name!r}")

        print(f"optimizer: {self.optimizer}")
        line()
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=self.hp.lr_anneal_step,
                                                   gamma=self.hp.lr_anneal_gamma)
        return [self.optimizer]

    def train_one_epoch(self):
        """Run one training epoch and return the 'nfc_loss' train statistics."""
        self.run_epoch("train", self.current_epoch)
        return self.stats['nfc_loss']['train'].get_stats()

    def validate(self):
        """Evaluate on the validation split and return its 'nfc_loss' statistics."""
        self.evaluate("val", self.current_epoch)
        return self.stats['nfc_loss']['val'].get_stats()

    def run_epoch(self, mode, epoch):
        """Iterate once over the loader for ``mode``.

        In "train" mode every batch is followed by a backward pass and an
        optimizer step, and the LR scheduler is stepped once after the loop.
        In any other mode the batches are only forwarded, with autograd
        disabled.

        Args:
            mode: "train", "val" or "test".
            epoch: epoch index; stored in ``self.current_epoch``.
        """
        self.current_epoch = epoch
        is_train = mode == "train"
        # The validation loader is stored as ``valid_loader`` while the mode
        # key used everywhere else is "val".
        loader_name = "valid_loader" if mode == "val" else f"{mode}_loader"
        loader = getattr(self, loader_name)
        self.NFC.train(is_train)

        with torch.set_grad_enabled(is_train):
            for batch_i, batch in enumerate(tqdm(loader)):
                if is_train:
                    self.optimizer.zero_grad()
                result = self.forward(batch, batch_i, mode)
                if is_train:
                    result['loss'].backward()
                    self.optimizer.step()

        if is_train:
            self.scheduler.step()
        self.generic_eval_end(mode, epoch)

    def evaluate(self, mode, epoch=None):
        """Forward every batch of a held-out split without gradients.

        Args:
            mode: 'val' selects the validation loader; any other value selects
                the test loader.
            epoch: epoch reported to ``generic_eval_end``; defaults to
                ``self.current_epoch``.
        """
        loader = self.valid_loader if mode == 'val' else self.test_loader
        self.NFC.eval()
        with torch.no_grad():
            for batch_i, batch in enumerate(tqdm(loader)):
                self.forward(batch, batch_i, mode)
        self.generic_eval_end(mode, epoch if epoch is not None else self.current_epoch)

    def test(self):
        """Evaluate on the test split, reporting it as epoch ``-1``."""
        self.evaluate('test', epoch=-1)

    def forward(self, data_batch, batch_i, mode):
        """Run the model on one batch, compute the loss and update the stats.

        Args:
            data_batch: 5-tuple ``(audio, seg, phonemes, length, fname)``;
                ``fname`` is not used here.
            batch_i: batch index (unused).
            mode: "train", "val" or "test"; selects train/eval behaviour and
                the stats bucket that is updated.

        Returns:
            ``OrderedDict`` with a single entry: key ``"loss"`` in train mode,
            ``f"{mode}_loss"`` otherwise; the value is the total loss.
        """
        audio, seg, phonemes, length, fname = data_batch

        language = "english"  # "dutch"
        if language == "dutch":
            # NOTE(review): unreachable while ``language`` is hard-coded above.
            # The nesting of the debug prints below was reconstructed -- confirm
            # against the intended behaviour before enabling this branch.
            lh39_ph = []
            for phoneme_seq in phonemes:
                lh39_ph_seq = []
                for IFA_ph in phoneme_seq:
                    print(f"\nINPUT: {IFA_ph}")
                    output = dutch_preprocess.aligner_pipeline(
                        timit_to_leehon_map_MACRO[IFA_ph.lower()])
                    lh39_ph_seq.append([x["lh39"] for x in output])
                    if not output:
                        print("Results: None")
                print(phoneme_seq)
                print(f"Dutch IPA to LH39 mapping: {lh39_ph_seq}")
                lh39_ph.append(np.hstack(lh39_ph_seq).tolist())
                print(f"Dutch IPA to LH39 total: {lh39_ph}")
            phonemes = lh39_ph

        if mode in ["test", "val"]:
            with torch.no_grad():
                self.NFC.eval()
                # No segmentation is passed to the model outside training.
                preds, original_lengths, probs, frame_labels, _, preds_peaks, w_phi = \
                    self.NFC(audio, None, phonemes, length)
        else:
            self.NFC.train()
            preds, original_lengths, probs, frame_labels, _, preds_peaks, w_phi = \
                self.NFC(audio, seg, phonemes, length)

        # ``loss_ph`` for the first LOSS_PH_EPOCHS epochs, ``total_loss`` after.
        if self.current_epoch < self.LOSS_PH_EPOCHS:
            loss_fn = self.NFC.loss_ph
        else:
            loss_fn = self.NFC.total_loss
        total_loss, ph_loss, loss_nce, sum_mse, w_pos_neg, w_phi = loss_fn(
            preds, original_lengths, probs, frame_labels, seg, preds_peaks, w_phi, phonemes)
        # InfoNCE LOSS ABLATION -
        # total_loss, ph_loss, loss_nce, sum_mse, w_pos_neg, w_phi = self.NFC.loss_InfoNCE_classic(preds, original_lengths, probs, frame_labels, seg, preds_peaks, w_phi, phonemes)

        # Order matters: it fixes the order in which new keys appear in
        # ``self.stats`` and therefore the order of the printed metrics.
        tracked = (
            ('w_pos', w_pos_neg[0]),
            ('w_neg', w_pos_neg[1]),
            ('w_phi1', w_phi[0]),
            ('w_phi2', w_phi[1]),
            ('ph_loss', ph_loss),
            ('nce_loss', loss_nce),
            ('softDPmse_loss', sum_mse),
            ('nfc_loss', total_loss),
        )
        for stat_name, value in tracked:
            self.stats[stat_name][mode].update(torch.tensor(value.item()))

        loss_key = "loss" if mode == "train" else f"{mode}_loss"
        return OrderedDict({
            loss_key: total_loss
        })

    def generic_eval_end(self, mode, epoch):
        """Collect, print and log the metrics of ``mode`` for ``epoch``.

        Args:
            mode: "train", "val" or "test".
            epoch: 0-based epoch index (``-1`` is used by ``test``); reported
                as ``epoch + 1`` in the metrics.

        Returns:
            ``OrderedDict`` with a single key ``'log'`` holding the metrics dict.
        """
        metrics = {}
        for k in self.stats:
            metrics[f"{mode}_{k}"] = self.stats[k][mode].get_stats()
        metrics['epoch'] = epoch + 1
        metrics['current_lr'] = self.optimizer.param_groups[0]['lr']
        line()

        # get best_l1_dist from all l1_dist types and all epochs
        best_overall_l1_dist = float("inf")
        for l1_dist in self.best_l1_dist.values():
            if l1_dist[mode][0] < best_overall_l1_dist:
                best_overall_l1_dist = l1_dist[mode][0]
        metrics[f'{mode}_min_l1_dist'] = best_overall_l1_dist

        for k, v in metrics.items():
            print(f"\t{k:<30} -- {v}")
        line()

        # wandb expects non-decreasing steps, so the ``-1`` used by ``test``
        # would cause the entry to be discarded; let wandb choose the step
        # itself in that case.
        wandb.log(metrics, step=epoch if epoch >= 0 else None)

        return OrderedDict({
            'log': metrics
        })

    def get_ckpt_path(self):
        """Return the path of a ``*.ckpt`` file located directly in ``hp.wd``.

        When several checkpoints exist, the first one reported by ``glob`` is
        returned; that order is not guaranteed to be sorted.

        Raises:
            IndexError: if the directory contains no ``*.ckpt`` file.
        """
        work_dir = os.fspath(self.hp.wd)
        # Escape the directory so characters such as "[" in its name are
        # matched literally rather than interpreted as glob syntax.
        pattern = os.path.join(glob.escape(work_dir), "*.ckpt")
        matches = glob.glob(pattern)
        if not matches:
            raise IndexError(f"no '*.ckpt' file found in {work_dir!r}")
        return matches[0]