import os
import sys
import time

sys.path.append('../models')

import torch
import functools
# numpy and torch.nn are used directly below (np.array, nn.DataParallel)
import numpy as np
import options as opt

from torch import nn, optim
from tqdm.auto import tqdm

from PauseChecker import PauseChecker
from Trainer import Trainer
from models.LipNetPlus import LipNetPlus
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap, Datasets
from helpers import contains_nan_or_inf
from models.PhonemeTransformer import *
from helpers import *


class TransformerTrainer(Trainer, TranslatorTrainer):
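    """Trains a sequence-to-sequence transformer head on top of LipNetPlus GRU features.

    For every batch the GRU front-end of LipNetPlus is run without gradient
    tracking, its CTC predictions are decoded into per-character source
    embeddings (see make_transformer_embeds), and the transformer is trained
    with cross-entropy against the target text tokens.
    """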
    def __init__(
        self, batch_size=opt.batch_size, word_tokenize=False,
        dataset_type: Datasets = opt.dataset, embeds_size=256,
        vocab_files=None, write_logs=True,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters,
        name='embeds-transformer-v2',
        **kwargs
    ):
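        """Set up tokenizers, vocabularies, device and transformer training state.

        If vocab_files is given, it must be a (phoneme_vocab_path, text_vocab_path)
        pair of torch-saved vocabularies; otherwise the vocabularies are loaded
        from self.base_dir via load_vocabs.
        """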
        super().__init__(**kwargs, name=name)

        self.batch_size = batch_size
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type
        self.embeds_size = embeds_size

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if vocab_files is None:
            vocabs = self.load_vocabs(self.base_dir)
            self.phonemes_vocab, self.text_vocab = vocabs
        else:
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        """
        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes, self.phonemes_vocab,
            self.tensor_transform
        )
        """
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer, self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()

    def create_model(self):
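        """Build the LipNetPlus model and its DataParallel wrapper the first time it is needed."""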
        if self.model is None:
            output_classes = len(self.train_dataset.get_char_mapping())

            self.model = LipNetPlus(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats,
                embeds_size=self.embeds_size,
                output_vocab_size=len(self.text_vocab)
            )
            self.model = self.model.cuda()
        if self.net is None:
            self.net = nn.DataParallel(self.model).cuda()

    def load_datasets(self):
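        """Instantiate the train and validation GridDataset objects if not already set."""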
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list,
                sample_all_props=True
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list,
                sample_all_props=True
            )

    def train(self):
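        """Run the transformer training loop.

        Each iteration runs the GRU front-end without gradient tracking, packs
        its output into padded source embeddings, trains the transformer with
        teacher forcing on the shifted target tokens, and periodically logs
        loss/WER and runs validation.
        """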
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )
        """
        optimizer = optim.Adam(
            self.model.parameters(), lr=opt.base_lr,
            weight_decay=0., amsgrad=True
        )
        """
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))

        tic = time.time()

        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')
            train_wer = []

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                self.model.train()
                vid = input_sample.get('vid').cuda()

                # rebuild the ground-truth sentences from the per-character annotation
                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)

                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

                # teacher forcing: the decoder input is the target shifted right by one token
                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

                # the GRU front-end is run without gradient tracking
                with torch.no_grad():
                    gru_output = self.model.forward_gru(vid)
                    y = self.model.predict_from_gru_out(gru_output)

                src_embeds = self.model.make_src_embeds(gru_output)
                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )

                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                optimizer.zero_grad()

                # cross-entropy against the target shifted left by one token
                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                tot_iter = i_iter + epoch * len(loader)

                loss.backward()
                optimizer.step()

                # greedy-decode the predictions to track the training WER
                with torch.no_grad():
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                    gap = ' ' if self.word_tokenize else ''
                    pred_sentences = self.batch_indices_to_text(
                        token_indices, batch_size=batch_size, gap=gap
                    )
                    wer = np.mean(GridDataset.get_wer(
                        pred_sentences, batch_sentences,
                        char_map=self.output_char_map
                    ))
                    train_wer.append(wer)

                if tot_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)
                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=3
                    )

                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                          .format(
                              epoch, tot_iter, eta, loss,
                              np.array(train_wer).mean()
                          ))
                    print('-' * 161)

                if (tot_iter > -1) and (tot_iter % opt.test_step == 0):
                    self.run_test(tot_iter, optimizer)

    def make_transformer_embeds(
        self, dataset, src_embeds, y, batch_size
    ):
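        """Pack CTC-decoded GRU embeddings into a padded transformer source batch.

        Returns a (max_sentence_len + 2, batch_size, embed_dim) tensor that is
        initialised with the PAD embedding and then filled with the BOS
        embedding, the selected per-character source embeddings and the EOS
        embedding, plus an int index mask holding PAD_IDX at padded positions
        and UNK_IDX elsewhere, which is passed to create_mask in place of real
        source token indices.
        """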
        batch_indices = dataset.ctc_decode_indices(y)
        filter_batch_embeds = []

        pad_embed = self.model.src_tok_emb(
            torch.IntTensor([PAD_IDX]).to(self.device)
        )
        begin_embed = self.model.src_tok_emb(
            torch.IntTensor([BOS_IDX]).to(self.device)
        )
        end_embed = self.model.src_tok_emb(
            torch.IntTensor([EOS_IDX]).to(self.device)
        )
        max_sentence_len = max([len(x) for x in batch_indices])

        # clone after expand so every position owns its memory and can be
        # overwritten in place below (an expanded view shares storage)
        transformer_src_embeds = pad_embed.expand(
            max_sentence_len + 2, batch_size, pad_embed.shape[1]
        ).clone()

        src_idx_mask = torch.full(
            transformer_src_embeds.shape[:2], PAD_IDX,
            dtype=torch.int
        )

        for k, sentence_indices in enumerate(batch_indices):
            filter_sentence_embeds = []
            for sentence_index in sentence_indices:
                filter_sentence_embeds.append(
                    src_embeds[sentence_index][k]
                )

            sentence_length = len(filter_sentence_embeds)
            filter_batch_embeds.append(filter_sentence_embeds)

            transformer_src_embeds[0][k] = begin_embed
            src_idx_mask[0][k] = UNK_IDX

            for i, char_embed in enumerate(filter_sentence_embeds):
                transformer_src_embeds[i + 1][k] = char_embed
                src_idx_mask[i + 1][k] = UNK_IDX

            transformer_src_embeds[sentence_length + 1][k] = end_embed
            src_idx_mask[sentence_length + 1][k] = UNK_IDX

        return transformer_src_embeds, src_idx_mask

    @staticmethod
    def log_pred_texts(
        pred_txt, truth_txt, pad=80, sub_samples=None
    ):
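        """Print predicted and ground-truth sentences side by side, optionally limited to sub_samples rows."""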
        line_length = 2 * pad + 1
        print('-' * line_length)
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))

        print('-' * line_length)
        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))

        print('-' * line_length)

    def test(self):
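        """Evaluate the current model on the validation set and return (mean loss, mean WER, mean CER)."""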
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()
            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )

            loss_list = []
            wer = []
            cer = []
            tic = time.time()
            print('RUNNING VALIDATION')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()
                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)

                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

                gru_output = self.model.forward_gru(vid)
                y = self.model.predict_from_gru_out(gru_output)

                src_embeds = self.model.make_src_embeds(gru_output)
                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )

                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                # greedy-decode the predictions for WER/CER computation
                probs = torch.softmax(logits, dim=-1)
                token_indices = torch.argmax(probs, dim=-1)

                gap = ' ' if self.word_tokenize else ''
                pred_sentences = self.batch_indices_to_text(
                    token_indices, batch_size=batch_size, gap=gap
                )

                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                loss_item = loss.detach().cpu().numpy()
                loss_list.append(loss_item)

                wer.extend(GridDataset.get_wer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))
                cer.extend(GridDataset.get_cer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))

                if i_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=10
                    )

                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print('-' * 161)

            return (
                np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean()
            )

    def run_test(self, tot_iter, optimizer):
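        """Run validation at iteration tot_iter, log loss/WER/CER, and checkpoint on a new best loss.

        When opt.is_optimize is false, the process exits after this single test run.
        """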
        log_scalar = functools.partial(self.log_scalar, label='test')

        (loss, wer, cer) = self.test()
        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )

            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)

            (save_dir, name) = os.path.split(savepath)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            sys.exit()


if __name__ == '__main__':
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )
    """
    vocab_filepaths = (
        'data/lsr2_phoneme_vocab.pth',
        'data/lsr2_text_char_vocab.pth'
    )
    """

    trainer = TransformerTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=opt.char_map,
        output_char_map=opt.text_char_map
    )

    if hasattr(opt, 'weights'):
        trainer.load_weights(opt.weights)

    trainer.train()