import os
import numpy as np
import functools
import shutil

from typing import List

import torch
from tqdm.auto import tqdm

from torch.utils.data import DataLoader
from torchtext.datasets import Multi30k

import options
from Loader import GridLoader
from PauseChecker import PauseChecker
from dataset import GridDataset, CharMap, Datasets
from datetime import datetime as Datetime

from models.PhonemeTransformer import *
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from BaseTrainer import BaseTrainer


class TranslationDataset(GridDataset):
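    """GridDataset variant that pairs each sample's input annotation
    (e.g. phonemes) with its output annotation (e.g. letters) for
    sequence-to-sequence translation.
    """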
    def __init__(
        self, input_char_map: CharMap,
        output_char_map: CharMap, **kwargs
    ):
        super().__init__(**kwargs)
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map

    def __getitem__(self, idx):
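        """Load the (input_str, output_str) annotation pair for sample ``idx``."""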
        (vid, spk, name) = self.data[idx]
        basename, _ = os.path.splitext(name)

        input_filepath = self.fetch_anno_path(
            spk, basename, char_map=self.input_char_map
        )
        output_filepath = self.fetch_anno_path(
            spk, basename, char_map=self.output_char_map
        )

        input_str = self.load_str_sentence(
            input_filepath, char_map=self.input_char_map
        )
        output_str = self.load_str_sentence(
            output_filepath, char_map=self.output_char_map
        )
        return input_str, output_str


class TranslatorTrainer(BaseTrainer):
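    """Trainer for a Seq2SeqTransformer that translates phoneme sequences
    into text, with greedy and beam-search decoding for inference.
    """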
    def __init__(
        self, dataset_type: Datasets = options.dataset,
        batch_size=128, validate_every=20, display_every=10,
        name='translate', write_logs=True, base_dir='',
        word_tokenize=False, vocab_files=None,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters
    ):
        super().__init__(name=name, base_dir=base_dir)

        self.batch_size = batch_size
        self.validate_every = validate_every
        self.display_every = display_every
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if vocab_files is None:
            vocabs = self.load_vocabs(self.base_dir)
            self.phonemes_vocab, self.text_vocab = vocabs
        else:
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes, self.phonemes_vocab,
            self.tensor_transform
        )
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer, self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()

    def load_vocabs(self, base_dir):
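        """Build phoneme and text vocabularies from the dataset's
        phoneme/text annotation maps (GRID or LRS2).
        """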
        loader = GridLoader(base_dir=base_dir)

        if self.dataset_type == Datasets.GRID:
            phonemes_text_map = loader.load_grid_phonemes_text_map(
                phonemes_char_map=self.input_char_map,
                text_char_map=self.output_char_map
            )
        elif self.dataset_type == Datasets.LRS2:
            phonemes_text_map = loader.load_lsr2_phonemes_text_map(
                phonemes_char_map=self.input_char_map,
                text_char_map=self.output_char_map
            )
        else:
            raise NotImplementedError

        phonemes_map = phonemes_text_map[self.input_char_map]
        text_map = phonemes_text_map[self.output_char_map]

        phonemes_vocab = self.build_vocab(
            phonemes_map, tokenizer=GridDataset.tokenize_phonemes
        )
        text_vocab = self.build_vocab(
            text_map, tokenizer=self.text_tokenizer
        )

        return phonemes_vocab, text_vocab

    def save_vocabs(
        self, phoneme_vocab_path, text_vocab_path
    ):
        torch.save(self.phonemes_vocab, phoneme_vocab_path)
        torch.save(self.text_vocab, text_vocab_path)

    def load_weights(self, weights):
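        """Create the model, then load all checkpoint parameters whose names
        and shapes match the current model; report the parameters skipped.
        """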
        self.create_model()

        pretrained_dict = torch.load(weights)
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if
            k in model_dict.keys() and v.size() == model_dict[k].size()
        }

        missed_params = [
            k for k, v in model_dict.items()
            if k not in pretrained_dict.keys()
        ]

        print('loaded params/total params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('mismatched params: {}'.format(missed_params))
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    def create_model(self):
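        """Instantiate the Seq2SeqTransformer and its Adam optimizer on the
        selected device.
        """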
        self.model = Seq2SeqTransformer(
            src_vocab_size=len(self.phonemes_vocab),
            tgt_vocab_size=len(self.text_vocab)
        )

        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=0.0001, betas=(0.9, 0.98), eps=1e-9
        )

    def collate_tgt_fn(self, batch):
        tgt_batch = []
        for tgt_sample in batch:
            tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n")))

        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return tgt_batch

    def collate_fn(self, batch):
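        """Encode a batch of (source, target) sentence pairs and pad them
        into (max_seq_len, batch_size) index tensors.
        """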
        src_batch, tgt_batch = [], []
        for src_sample, tgt_sample in batch:
            src_batch.append(self.phonemes_encoder(src_sample.rstrip("\n")))
            tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n")))

        src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return src_batch, tgt_batch

    def train(self, max_iters=10 * 1000):
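        """Run the training loop for up to ``max_iters`` iterations,
        periodically displaying train predictions and validating on the
        test split.
        """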
        assert self.writer is not None
        assert self.display_every < self.validate_every

        self.create_model()
        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')
        self.model.train()
        losses = 0

        dataset_kwargs = self.get_dataset_kwargs(
            input_char_map=self.input_char_map,
            char_map=self.output_char_map,
            output_char_map=self.output_char_map,
            file_list=options.train_list
        )

        train_iter = TranslationDataset(**dataset_kwargs, phase='train')
        test_iter = TranslationDataset(**dataset_kwargs, phase='test')

        train_dataloader = DataLoader(
            train_iter, batch_size=self.batch_size
        )
        test_dataloader = DataLoader(
            test_iter, batch_size=self.batch_size
        )

        tot_iters = 0
        pbar = tqdm(total=max_iters)

        while tot_iters < max_iters:
            for train_pair in train_dataloader:
                PauseChecker.check()

                raw_src, raw_tgt = train_pair
                src, tgt = self.collate_fn(zip(raw_src, raw_tgt))
                # pad_sequence defaults to batch_first=False, so the collated
                # tensors are shaped (max_seq_len, batch_size)
                max_seq_len, batch_size = src.shape

                src = src.to(self.device)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]
                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(src, tgt_input, self.device)

                logits = self.model(
                    src, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                self.optimizer.zero_grad()

                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                loss.backward()
                self.optimizer.step()
                loss_item = loss.item()

                with torch.no_grad():
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                    gap = ' ' if self.word_tokenize else ''
                    pred_sentences = self.batch_indices_to_text(
                        token_indices, batch_size=batch_size, gap=gap
                    )
                    wer = np.mean(GridDataset.get_wer(
                        pred_sentences, raw_tgt, char_map=self.output_char_map
                    ))

                desc = f'loss: {loss_item:.4f}, wer: {wer:.4f}'
                pbar.desc = desc

                losses += loss_item
                tot_iters += 1
                pbar.update(1)

                run_validation = (
                    (tot_iters > 0) and
                    (tot_iters % self.validate_every == 0)
                )
                run_display = (
                    (tot_iters > 0) and
                    (tot_iters % self.display_every == 0)
                )

                if run_validation:
                    self.run_test(test_dataloader, tot_iters=tot_iters)
                elif run_display:
                    print('TRAIN PREDICTIONS')
                    self.show_sentences(pred_sentences, raw_tgt, batch_size)

                if self.writer is not None:
                    log_scalar('loss', loss_item, tot_iters)
                    log_scalar('wer', wer, tot_iters)

                # stop exactly at max_iters instead of finishing the epoch
                if tot_iters >= max_iters:
                    break

        return losses / len(list(train_dataloader))

    @staticmethod
    def show_sentences(
        pred_sentences, target_sentences, batch_size, pad=40
    ):
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'target', pad=pad
        ))

        line_length = 2 * pad + 1
        print(line_length * '-')

        for k in range(batch_size):
            pred_sentence = pred_sentences[k]
            target_sentence = target_sentences[k]
            print('{:<{pad}}|{:>{pad}}'.format(
                pred_sentence, target_sentence, pad=pad
            ))

        print(line_length * '-')

    def run_test(self, test_dataloader, tot_iters):
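        """Evaluate a single batch from the test dataloader, log loss/WER,
        and checkpoint the model whenever the test loss improves.
        """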
        log_scalar = functools.partial(self.log_scalar, label='test')

        with torch.no_grad():
            self.model.eval()

            # grab a single batch from the test dataloader
            for batch in test_dataloader:
                break

            raw_src, raw_tgt = batch
            src, tgt = self.collate_fn(zip(raw_src, raw_tgt))
            max_seq_len, batch_size = src.shape
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            tgt_input = tgt[:-1, :]
            (
                src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask
            ) = create_mask(src, tgt_input, self.device)

            logits = self.model(
                src, tgt_input, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask, src_padding_mask
            )

            tgt_out = tgt[1:, :]
            loss = self.loss_fn(
                logits.reshape(-1, logits.shape[-1]),
                tgt_out.reshape(-1)
            )

            loss_item = loss.item()

            probs = torch.softmax(logits, dim=-1)
            token_indices = torch.argmax(probs, dim=-1)

            gap = ' ' if self.word_tokenize else ''
            pred_sentences = self.batch_indices_to_text(
                token_indices, batch_size=batch_size, gap=gap
            )
            wer = np.mean(GridDataset.get_wer(
                pred_sentences, raw_tgt, char_map=self.output_char_map
            ))

            log_scalar('loss', loss_item, tot_iters)
            log_scalar('wer', wer, tot_iters)
            print(f'TEST PREDS [loss={loss_item:.4f}, wer={wer:.4f}]')
            self.show_sentences(pred_sentences, raw_tgt, batch_size)

            if loss_item < self.best_test_loss:
                print(f'NEW BEST LOSS: {loss_item}')
                self.best_test_loss = loss_item
                savename = 'I{}-L{:.4f}-W{:.4f}'.format(
                    tot_iters, loss_item, wer
                )

                savename = savename.replace('.', '') + '.pt'
                savepath = os.path.join(self.weights_dir, savename)

                (save_dir, name) = os.path.split(savepath)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)

                torch.save(self.model.state_dict(), savepath)
                print(f'best model saved at {savepath}')

        # switch back to training mode before resuming the training loop
        self.model.train()

    def batch_indices_to_text(
        self, indices_tensor, batch_size, gap=''
    ):
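        """Convert a (max_seq_len, batch_size) tensor of token indices back
        into sentences, stopping at EOS and skipping special tokens.
        """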
        sentences = []

        for k in range(batch_size):
            tokens = []

            for indices_row in indices_tensor:
                idx = indices_row[k]

                if idx == EOS_IDX:
                    break
                if idx in [PAD_IDX, BOS_IDX, EOS_IDX]:
                    continue

                token = self.text_vocab.lookup_token(idx)
                tokens.append(token)

            sentence = gap.join(tokens)
            sentences.append(sentence)

        return sentences

    @staticmethod
    def batch_tokenize_text(batch_sentences, word_tokenize=False):
        return [
            GridDataset.tokenize_text(
                sentence, word_tokenize=word_tokenize
            ) for sentence in batch_sentences
        ]

    def evaluate(self, model):
        model.eval()
        losses = 0

        language_pair = (str(CharMap.phonemes), str(CharMap.letters))
        val_iter = Multi30k(
            split='valid', language_pair=language_pair
        )
        val_dataloader = DataLoader(
            val_iter, batch_size=self.batch_size,
            collate_fn=self.collate_fn
        )

        for src, tgt in val_dataloader:
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            tgt_input = tgt[:-1, :]
            (
                src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask
            ) = create_mask(src, tgt_input, self.device)

            logits = model(
                src, tgt_input, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask, src_padding_mask
            )

            tgt_out = tgt[1:, :]
            loss = self.loss_fn(
                logits.reshape(-1, logits.shape[-1]),
                tgt_out.reshape(-1)
            )
            losses += loss.item()

        return losses / len(list(val_dataloader))

    def translate(
        self, phoneme_sentence: str, beam_size=0
    ):
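        """Translate a single phoneme sentence into text, using beam search
        when ``beam_size`` > 0 and greedy decoding otherwise.
        """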
        self.model.eval()
        # any valid vocab token works as a placeholder target for collate_fn;
        # the encoded target is discarded
        dummy_sentence = self.text_vocab.lookup_token(
            len(self.text_vocab) - 1
        )
        src, _ = self.collate_fn(zip(
            [phoneme_sentence], [dummy_sentence]
        ))

        max_seq_len, batch_size = src.shape
        src = src.to(self.device)

        num_tokens = src.shape[0]
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        max_len = num_tokens + 5

        if beam_size > 0:
            tgt_tokens = self.beam_search_decode(
                src, src_mask, max_len=max_len,
                start_symbol=BOS_IDX, beam_size=beam_size
            )
        else:
            tgt_tokens = self.greedy_decode(
                src, src_mask, max_len=max_len,
                start_symbol=BOS_IDX
            )

        gap = ' ' if self.word_tokenize else ''
        pred_sentence = self.batch_indices_to_text(
            tgt_tokens, batch_size=batch_size, gap=gap
        )[0]
        return pred_sentence

    def greedy_decode(self, src, src_mask, max_len, start_symbol):
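        """Autoregressively pick the most probable next token until EOS is
        produced or ``max_len`` tokens have been generated.
        """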
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        memory = self.model.encode(src, src_mask)
        ys = (
            torch.ones(1, 1).fill_(start_symbol).
            type(torch.long).to(self.device)
        )

        for i in range(max_len - 1):
            memory = memory.to(self.device)
            tgt_mask = (
                generate_square_subsequent_mask(
                    ys.size(0), device=self.device
                ).type(torch.bool)
            ).to(self.device)

            out = self.model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            prob = self.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([
                ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)
            ], dim=0)

            if next_word == EOS_IDX:
                break

        return ys

    def beam_search_decode(
        self, src, src_mask, max_len, start_symbol, beam_size=5
    ):
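        """Beam search over the decoder: at each step keep the ``beam_size``
        hypotheses with the lowest cumulative negative log-likelihood.
        """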
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        memory = self.model.encode(src, src_mask)
        ys = (
            torch.ones(1, 1).fill_(start_symbol).
            type(torch.long).to(self.device)
        )

        # each hypothesis is a (sequence, cumulative negative log-likelihood)
        hypotheses = [(ys, 0.0)]

        for _ in range(max_len - 1):
            new_hypotheses = []

            for seq, score in hypotheses:
                # finished hypotheses are carried over unchanged
                if seq[-1] == EOS_IDX:
                    new_hypotheses.append((seq, score))
                    continue

                tgt_mask = generate_square_subsequent_mask(
                    seq.size(0), device=self.device
                ).type(torch.bool)

                out = self.model.decode(seq, memory, tgt_mask)
                out = out.transpose(0, 1)
                prob = self.model.generator(out[:, -1])
                # normalise to log-probabilities so scores accumulate properly
                log_prob = torch.log_softmax(prob, dim=-1)
                topk_prob, topk_indices = torch.topk(log_prob, beam_size)

                for i in range(beam_size):
                    next_word = topk_indices[0][i].item()
                    # lower score = higher likelihood
                    next_score = score - topk_prob[0][i].item()
                    new_seq = torch.cat([
                        seq, torch.ones(1, 1).type_as(src.data).fill_(next_word)
                    ], dim=0)

                    new_hypotheses.append((new_seq, next_score))

            if len(new_hypotheses) == 0:
                break

            # keep the beam_size best (lowest-score) hypotheses
            hypotheses = sorted(
                new_hypotheses, key=lambda x: x[1]
            )[:beam_size]

        # return the best-scoring hypothesis sequence
        return hypotheses[0][0]

    @staticmethod
    def yield_tokens(sequence_map, tokenizer):
        for key in sequence_map:
            yield tokenizer(sequence_map[key])

    def build_vocab(self, sequence_map, tokenizer):
        return build_vocab_from_iterator(
            self.yield_tokens(sequence_map, tokenizer),
            min_freq=1, specials=SPECIAL_SYMBOLS,
            special_first=True
        )

    @staticmethod
    def sequential_transforms(*transforms):
        def func(txt_input):
            for transform in transforms:
                txt_input = transform(txt_input)

            return txt_input

        return func

    @staticmethod
    def tensor_transform(token_ids: List[int]):
        return torch.cat((
            torch.tensor([BOS_IDX]), torch.tensor(token_ids),
            torch.tensor([EOS_IDX])
        ))


if __name__ == '__main__':
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )
    # alternative vocabularies for the LRS2 dataset:
    # vocab_filepaths = (
    #     'data/lsr2_phoneme_vocab.pth',
    #     'data/lsr2_text_char_vocab.pth'
    # )

    trainer = TranslatorTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=options.char_map,
        output_char_map=options.text_char_map
    )

    trainer.train()
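
    # Example inference once training has produced a checkpoint. The weights
    # filename and the phoneme string below are hypothetical placeholders,
    # not files or data shipped with this repo:
    # trainer.load_weights('weights/translate/I1000-L01234-W00500.pt')
    # print(trainer.translate('b ih n b l uw ae t eh f t uw n aw', beam_size=5))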