"""Evaluate the two-stage phoneme pipeline on random test samples.

For each sample, ``lipnet_predictor`` predicts a phoneme sentence, which
``phoneme_translator`` then translates to text. The script prints every
prediction next to its target and finishes with the number of exact text
matches, the number of exact phoneme matches and the average WER.
"""
import os.path
import sys

# base_dir must be on sys.path before the project imports below run
# (presumably the project modules are located there - confirm).
base_dir = '..'
sys.path.append(base_dir)

from Trainer import Trainer  # noqa: E402
from TranslatorTrainer import TranslatorTrainer  # noqa: E402
from dataset import GridDataset, CharMap  # noqa: E402


# Forwarded to TranslatorTrainer as word_tokenize; also selects which
# translator checkpoint is loaded below.
WORD_TOKENIZE = False
# Passed as filter_previous when decoding the target phoneme sequence.
PHONEME_FILTER_PREV = False
# Forwarded to phoneme_translator.translate as beam_size.
# NOTE(review): 0 presumably disables beam search - confirm against
# TranslatorTrainer.translate.
BEAM_SIZE = 0


lipnet_weights = 'saved-weights/phonemes-231207-2130/I283000-L00683-W01012-C00765.pt'

if WORD_TOKENIZE:
    translator_weights = 'saved-weights/translate-231204-1652/I160-L00047-W00000.pt'
else:
    translator_weights = 'saved-weights/translate-231204-2227/I860-L00000-W00000.pt'


lipnet_predictor = Trainer(
    write_logs=False, base_dir=base_dir,
    num_workers=0, char_map=CharMap.phonemes,
)
# NOTE(review): unlike the translator weights below, this path is passed
# without joining base_dir; presumably Trainer resolves it itself - confirm.
lipnet_predictor.load_weights(lipnet_weights)
lipnet_predictor.load_datasets()
dataset = lipnet_predictor.test_dataset

phoneme_translator = TranslatorTrainer(
    write_logs=False, base_dir=base_dir, word_tokenize=WORD_TOKENIZE,
)
phoneme_translator.load_weights(os.path.join(base_dir, translator_weights))


# Manual smoke test of the translator alone (disabled):
#
# new_phonemes = GridDataset.text_to_phonemes("Do you like fries")
# print("PRE_REV_TRANSLATE", [new_phonemes])
# pred_text = phoneme_translator.translate(new_phonemes)
# print("AFT_REV_TRANSLATE", pred_text)
#
# phoneme_sentence = 'B-IH1-N B-L-UW1 AE1-T EH1-F TH-R-IY1 S-UW1-N'
# pred_text = phoneme_translator.translate(phoneme_sentence)
# print(f'PRED_TEXT: [{pred_text}]')


total_samples = 1000
total_wer = 0
num_correct = 0
num_phonemes_correct = 0

for k in range(total_samples):
    # NOTE(review): no `all` is defined or imported in this script, so this
    # passes the Python builtin `all`. If a CharMap member was intended,
    # this needs fixing - confirm against load_random_sample.
    sample = dataset.load_random_sample(char_map=all)
    tgt_phonemes = sample['phonemes']
    tgt_text = sample['txt']

    # Decode both targets into sentences so they can be compared with the
    # predictions below.
    target_phonemes_sentence = dataset.ctc_arr2txt(
        tgt_phonemes, start=1, char_map=CharMap.phonemes,
        filter_previous=PHONEME_FILTER_PREV,
    )
    target_sentence = dataset.ctc_arr2txt(
        tgt_text, start=1, char_map=CharMap.letters,
        filter_previous=False,
    )

    # Stage 1: sample -> phoneme sentence; stage 2: phonemes -> text.
    pred_phonemes_sentence = lipnet_predictor.predict_sample(sample)[0]
    pred_text = phoneme_translator.translate(
        pred_phonemes_sentence, beam_size=BEAM_SIZE,
    )

    match_phonemes = pred_phonemes_sentence == target_phonemes_sentence
    wer = dataset.get_wer(
        [pred_text], [target_sentence], char_map=CharMap.letters,
    )[0]
    total_wer += wer

    # bool() keeps the printed flag a plain True/False regardless of what
    # the comparison returns.
    correct = bool(pred_text == target_sentence)
    if correct:
        num_correct += 1
    # Phoneme accuracy is counted independently of text accuracy.
    if match_phonemes:
        num_phonemes_correct += 1

    print(
        f'PRED-PHONEMES [{k}]',
        [pred_phonemes_sentence, target_phonemes_sentence],
        [pred_text, target_sentence], correct, wer
    )


avg_wer = total_wer / total_samples
print(f'{num_correct}/{total_samples} samples correct')
print(f'{num_phonemes_correct}/{total_samples} phoneme samples correct')
print(f'average WER: {avg_wer}')