| import torch.nn as nn |
| import functools |
| import torch.optim as optim |
| import options as opt |
| import time |
|
|
| from helpers import * |
| from dataset import GridDataset, CharMap |
| from datetime import datetime as Datetime |
| from models.LipNet import LipNet |
| from tqdm.auto import tqdm |
| from PauseChecker import PauseChecker |
| from torch.utils.data import DataLoader |
| from torch.multiprocessing import Manager |
| from BaseTrainer import BaseTrainer |
|
|
|
|
class Trainer(BaseTrainer):
    """Train and evaluate a LipNet lip-reading model with CTC loss.

    Wires together the GRID dataset, model construction, the training
    loop, periodic validation, and checkpointing of the best model.
    Most configuration defaults come from the project-wide ``opt`` module.
    """

    def __init__(
        self, name=opt.run_name, write_logs=True,
        num_workers=None, base_dir='', char_map=opt.char_map,
        pre_gru_repeats=None
    ):
        """Set up run configuration, optional video cache, and logging.

        :param name: run name, forwarded to BaseTrainer (log/weight dirs)
        :param write_logs: when True, initialise tensorboard logging
        :param num_workers: dataloader workers; defaults to opt.num_workers
        :param base_dir: base directory prepended to dataset/weight paths
        :param char_map: output alphabet (letters / phonemes / cmu_phonemes)
        :param pre_gru_repeats: temporal upsampling factor applied before
            the GRU; defaults to opt.pre_gru_repeats, must be a positive int
        """
        super().__init__(name=name, base_dir=base_dir)

        images_dir = opt.images_dir
        if opt.use_lip_crops:
            images_dir = opt.crop_images_dir
        if num_workers is None:
            num_workers = opt.num_workers
        if pre_gru_repeats is None:
            pre_gru_repeats = opt.pre_gru_repeats

        # CTC needs inputs longer than targets; repeating frames pre-GRU
        # only makes sense for a whole, positive repeat count.
        assert pre_gru_repeats >= 1
        assert isinstance(pre_gru_repeats, int)

        self.images_dir = images_dir
        self.num_workers = num_workers
        self.pre_gru_repeats = pre_gru_repeats
        self.char_map = char_map

        # A Manager dict is shared across dataloader worker processes so
        # decoded videos are cached once instead of per-worker.
        manager = Manager()
        if opt.cache_videos:
            shared_dict = manager.dict()
        else:
            shared_dict = None

        self.shared_dict = shared_dict
        self.dataset_kwargs = self.get_dataset_kwargs(
            shared_dict=shared_dict, base_dir=self.base_dir,
            char_map=self.char_map
        )

        # Datasets and models are built lazily by load_datasets()
        # and create_model().
        self.best_test_loss = float('inf')
        self.train_dataset = None
        self.test_dataset = None
        self.model = None
        self.net = None

        if write_logs:
            self.init_tensorboard()

    def load_datasets(self):
        """Lazily build the train/test GridDatasets (idempotent)."""
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list
            )

    def create_model(self):
        """Lazily build the LipNet model and its DataParallel wrapper.

        Requires load_datasets() to have run first, since the output layer
        size comes from the training dataset's character mapping.
        """
        output_classes = len(self.train_dataset.get_char_mapping())

        if self.model is None:
            self.model = LipNet(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats
            )
            self.model = self.model.cuda()
        if self.net is None:
            # self.net wraps self.model for multi-GPU forward passes;
            # optimizer and checkpoints still use self.model directly.
            self.net = nn.DataParallel(self.model).cuda()

    def load_weights(self, weights_path):
        """Load checkpoint weights whose name AND shape match the model.

        Entries that are absent from the checkpoint (or have a different
        shape) keep their freshly-initialised values and are reported.

        :param weights_path: checkpoint path relative to self.base_dir
        """
        self.load_datasets()
        self.create_model()

        weights_path = os.path.join(self.base_dir, weights_path)
        pretrained_dict = torch.load(weights_path)
        model_dict = self.model.state_dict()
        # Keep only parameters that exist in the model with matching shape.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if
            k in model_dict and v.size() == model_dict[k].size()
        }

        missed_params = [
            k for k in model_dict
            if k not in pretrained_dict
        ]

        print('loaded params/tot params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('miss matched params:{}'.format(missed_params))
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    @staticmethod
    def make_date_stamp():
        """Return a compact timestamp string, e.g. '240131-1542'."""
        return Datetime.now().strftime("%y%m%d-%H%M")

    @staticmethod
    def dataset2dataloader(
        dataset, num_workers, shuffle=True
    ):
        """Wrap a dataset in a DataLoader using the configured batch size."""
        return DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=False
        )

    def test(self):
        """Run validation over the test dataset.

        Assumes load_datasets()/create_model() have already run
        (train() does this before calling here).

        :return: tuple of (mean CTC loss, mean WER, mean CER)
        """
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()
            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )

            loss_list = []
            wer = []
            cer = []
            # zero_infinity guards against inf losses from too-short inputs.
            crit = nn.CTCLoss(zero_infinity=True)
            tic = time.time()
            print('RUNNING VALIDATION')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)
                y = self.net(vid)

                # CTC input length must exceed target length (with margin
                # for repeated labels needing blank separators).
                assert (
                    self.pre_gru_repeats * vid_len.view(-1) >
                    2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                ).detach().cpu().numpy()

                loss_list.append(loss)
                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]

                wer.extend(dataset.wer(pred_txt, truth_txt))
                cer.extend(dataset.cer(pred_txt, truth_txt))

                if i_iter % opt.display == 0:
                    # Rough ETA in hours based on mean time per iteration.
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=10)
                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(161 * '-')

            return (
                np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean()
            )

    def extract_char_output(self, input_sample):
        """Extract the target character sequence from input_sample.

        The target is text if char_map is CharMap.letters, phonemes if
        char_map is CharMap.phonemes, and CMU phonemes if char_map is
        CharMap.cmu_phonemes.

        :param input_sample: batch dict produced by GridDataset
        :return: (targets, target_lengths) tensors, both on GPU
        :raises ValueError: for an unsupported char_map
        """
        if self.char_map == CharMap.letters:
            txt = input_sample.get('txt').cuda()
            txt_len = input_sample.get('txt_len').cuda()
        elif self.char_map == CharMap.phonemes:
            txt = input_sample.get('phonemes').cuda()
            txt_len = input_sample.get('phonemes_len').cuda()
        elif self.char_map == CharMap.cmu_phonemes:
            txt = input_sample.get('cmu_phonemes').cuda()
            txt_len = input_sample.get('cmu_phonemes_len').cuda()
        else:
            raise ValueError(f'UNSUPPORTED CHAR_MAP: {self.char_map}')

        return txt, txt_len

    def train(self):
        """Run the full training loop for opt.max_epoch epochs.

        Batches with non-finite losses or NaN gradients are skipped.
        Validation and best-model checkpointing run every opt.test_step
        iterations via run_test().
        """
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )
        # RMSprop was chosen over Adam (amsgrad) here — presumably for
        # training stability; confirm before switching optimizers.
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))

        crit = nn.CTCLoss(zero_infinity=True)
        tic = time.time()

        train_wer = []
        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                self.model.train()
                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)

                optimizer.zero_grad()
                y = self.net(vid)
                assert not contains_nan_or_inf(y)
                # CTC input length must exceed target length (with margin
                # for repeated labels needing blank separators).
                assert (
                    self.pre_gru_repeats * vid_len.view(-1) >
                    2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                )

                if contains_nan_or_inf(loss):
                    print(f'LOSS IS INVALID. SKIPPING {i_iter}')
                    continue

                loss.backward()
                # p.grad is None for any parameter unused in the forward
                # pass; guard against that before the isnan check, which
                # would otherwise raise a TypeError.
                if any(
                    p.grad is not None and torch.isnan(p.grad).any()
                    for p in self.model.parameters()
                ):
                    optimizer.zero_grad()
                    print('SKIPPING NAN GRADS')
                    continue

                if opt.is_optimize:
                    optimizer.step()

                assert not contains_nan_or_inf(self.model.conv1.weight)
                tot_iter = i_iter + epoch * len(loader)
                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]
                train_wer.extend(dataset.wer(pred_txt, truth_txt))

                if tot_iter % opt.display == 0:
                    # Rough ETA in hours for the rest of this epoch.
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)

                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=3)
                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                        .format(
                            epoch, tot_iter, eta, loss,
                            np.array(train_wer).mean()
                        )
                    )
                    print(161 * '-')

                if (tot_iter > 0) and (tot_iter % opt.test_step == 0):
                    self.run_test(tot_iter, optimizer)

    @staticmethod
    def log_pred_texts(pred_txt, truth_txt, pad=80, sub_samples=None):
        """Print predicted vs ground-truth transcripts side by side.

        :param pred_txt: list of predicted transcript strings
        :param truth_txt: list of ground-truth transcript strings
        :param pad: column width for each side of the table
        :param sub_samples: if set, only print the first N pairs
        """
        line_length = 2 * pad + 1
        print(line_length * '-')
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))

        print(line_length * '-')
        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))

        print(line_length * '-')

    def run_test(self, tot_iter, optimizer):
        """Validate, log metrics, and checkpoint on a new best loss.

        :param tot_iter: global iteration count (used in logs/filenames)
        :param optimizer: current optimizer (only for LR reporting)
        :raises SystemExit: after one evaluation when opt.is_optimize is
            off (evaluate-only mode)
        """
        log_scalar = functools.partial(self.log_scalar, label='test')

        (loss, wer, cer) = self.test()
        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )

            # Strip dots so the metrics don't read as file extensions.
            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)

            (save_dir, name) = os.path.split(savepath)
            # exist_ok avoids the check-then-create race.
            os.makedirs(save_dir, exist_ok=True)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            # Evaluate-only mode: stop after a single validation pass.
            # raise SystemExit is equivalent to exit() without relying
            # on the site module.
            raise SystemExit

    def predict_sample(self, input_sample):
        """Decode the transcript for one dataset sample dict."""
        self.model.eval()
        vid = input_sample.get('vid').cuda()
        return self.predict_video(vid)

    def predict_video(self, video):
        """Decode the transcript for a single (unbatched) video tensor.

        Requires load_datasets() to have run, since decoding uses the
        training dataset's character mapping.
        """
        video = video.cuda()
        vid = video.unsqueeze(0)
        y = self.net(vid)
        pred_txt = self.train_dataset.ctc_decode(y)
        return pred_txt