| |
| import numpy as np |
| import glob |
| import time |
| import cv2 |
| import yaml |
| import os |
| import torch |
| import glob |
| import re |
| import string |
| import copy |
| import json |
| import random |
| import enum |
| import editdistance |
| import pronouncing |
|
|
| from torch.utils.data import Dataset |
|
|
| import Extractor |
| import options |
| from cvtransforms import * |
| from typing import List, Iterable |
| from helpers import * |
|
|
|
|
class CharMap(str, enum.Enum):
    """
    Names of the token vocabularies an annotation can be expressed in.

    Members are ``str`` instances as well, so every member compares
    equal to its plain string value.
    """

    # character-level transcript
    letters = 'letters'
    # character-level transcript in the second text format
    # (the member name keeps its historical 'lsr2' spelling)
    lsr2_text = 'lsr2_text'
    # phoneme-level transcript
    phonemes = 'phonemes'
    # phoneme-level transcript over the larger CMU-style symbol set
    cmu_phonemes = 'cmu_phonemes'
    # declared, but no code path implements it yet
    visemes = 'visemes'
|
|
|
|
class Datasets(str, enum.Enum):
    """
    Identifiers of the supported corpora.

    Members are ``str`` instances as well, so every member compares
    equal to its plain string value.
    """

    GRID = 'GRID'
    LRS2 = 'LRS2'
|
|
|
|
class GridDataset(Dataset):
    """
    Dataset pairing per-video frame directories with text / phoneme
    annotations.

    The class-level lists below are the token vocabularies: the position
    of a token inside its list (plus a caller-chosen offset) is the
    integer the token is encoded as, so the order of every list matters.
    """

    # space followed by A-Z
    letters = list(' ' + string.ascii_uppercase)

    # space, apostrophe, 0-9, then A-Z
    lrs2_chars = list(" '" + string.digits + string.ascii_uppercase)

    # order is significant (it defines the indices) - do not sort
    phonemes = [
        ' ', 'AE1', 'AO1', 'D', 'JH', 'Y', 'P', 'AH0', 'OW1', 'G',
        'AY1', 'TH', 'IY1', 'CH', 'T', 'AW1', 'F', 'AH1', 'Z',
        'R', 'EH1', 'UW1', 'M', 'B', 'W', 'V', 'DH', 'K', 'IH0',
        'AA1', 'IH1', 'S', 'EY1', 'N', 'OW0', 'L',
    ]

    cmu_phonemes = [
        ' ', '#',
        'AA', 'AA0', 'AA1', 'AA2',
        'AE', 'AE0', 'AE1', 'AE2',
        'AH', 'AH0', 'AH1', 'AH2',
        'AO', 'AO0', 'AO1', 'AO2',
        'AW', 'AW0', 'AW1', 'AW2',
        'AY', 'AY0', 'AY1', 'AY2',
        'B', 'CH', 'D', 'DH',
        'EH', 'EH0', 'EH1', 'EH2',
        'ER', 'ER0', 'ER1', 'ER2',
        'EY', 'EY0', 'EY1', 'EY2',
        'F', 'G', 'HH',
        'IH', 'IH0', 'IH1', 'IH2',
        'IY', 'IY0', 'IY1', 'IY2',
        'JH', 'K', 'L', 'M', 'N', 'NG',
        'OW', 'OW0', 'OW1', 'OW2',
        'OY', 'OY0', 'OY1', 'OY2',
        'P', 'R', 'S', 'SH', 'T', 'TH',
        'UH', 'UH0', 'UH1', 'UH2',
        'UW', 'UW0', 'UW1', 'UW2',
        'V', 'W', 'Y', 'Z', 'ZH',
    ]

    # presumably one stand-in symbol per phoneme, index-aligned with the
    # list passed in - TODO confirm against helpers.map_phonemes
    phoneme_chars = map_phonemes(phonemes)
    cmu_phoneme_chars = map_phonemes(cmu_phonemes)
|
|
def __init__(
    self, video_path, alignments_dir,
    phonemes_dir, file_list, vid_pad,
    image_dir, txt_pad, phase, shared_dict=None,
    char_map=CharMap.letters, base_dir='',
    frame_doubling=False, sample_all_props=False
):
    """
    Indexes the videos named in ``file_list``.

    :param video_path:
        prefix joined in front of every entry of ``file_list``
    :param alignments_dir:
        directory of text annotations, relative to ``base_dir``
    :param phonemes_dir:
        directory of phoneme annotations, relative to ``base_dir``
    :param file_list:
        either a path (str, relative to ``base_dir``) to a text file
        holding one entry per line, or an iterable of such lines
    :param vid_pad:
        length videos are padded to
    :param image_dir:
        directory of extracted frames, relative to ``base_dir``
    :param txt_pad:
        length annotation index arrays are padded to
    :param phase:
        ``'train'`` enables the horizontal flip in ``load_sample``
    :param shared_dict:
        optional mapping ``_load_vid`` can use as a frame cache
    :param char_map:
        default vocabulary for annotations
    :param base_dir:
        root the relative directories above are resolved against
    :param frame_doubling:
        if True, every frame is repeated twice when a sample is loaded
    :param sample_all_props:
        if True, ``load_sample`` switches to its "all vocabularies" mode
    :raises ValueError:
        when an entry does not have at least ``speaker/filename``
    """
    self.base_dir = base_dir
    self.sample_all_props = sample_all_props
    self.frame_doubling = frame_doubling

    self.image_dir = os.path.join(base_dir, image_dir)
    self.alignments_dir = os.path.join(base_dir, alignments_dir)
    self.phonemes_dir = os.path.join(base_dir, phonemes_dir)

    if isinstance(file_list, str):
        list_path = os.path.join(base_dir, file_list)
        # context manager so the handle is closed deterministically
        with open(list_path, 'r') as list_file:
            file_list = list_file.readlines()

    self.shared_dict = shared_dict
    self.char_map = char_map

    self.vid_pad = vid_pad
    self.txt_pad = txt_pad
    self.phase = phase

    # blank lines would otherwise turn into an entry for video_path itself
    entries = (line.strip() for line in file_list)
    self.videos = [
        os.path.join(video_path, entry)
        for entry in entries if entry
    ]

    self.data = []
    for vid in self.videos:
        items = vid.split(os.path.sep)
        if len(items) < 2:
            raise ValueError(f'BAD VID ITEM {items}')

        # the last two path components are <speaker>/<filename>
        speaker_name, filename = items[-2], items[-1]
        self.data.append((vid, speaker_name, filename))
|
|
def _fetch_anno_path(self, spk, basename):
    """
    Resolves a sample's annotation path using this dataset's own
    ``char_map``.

    :param spk: speaker directory name
    :param basename: sample file name without extension
    :return: the path produced by ``fetch_anno_path``
    """
    active_map = self.char_map
    return self.fetch_anno_path(
        spk=spk, basename=basename, char_map=active_map
    )
|
|
@classmethod
def text_to_phonemes(
    cls, text, as_str=True, char_map=CharMap.phonemes
):
    """
    Converts a sentence into phonemes, taking the first pronunciation
    ``pronouncing`` lists for every word.

    :param text:
        whitespace-separated words; case is ignored
    :param as_str:
        if True the phonemes are rendered with ``stringify``,
        otherwise the raw token list is returned
    :param char_map:
        vocabulary handed to ``stringify`` when ``as_str`` is True
    :return:
        a string, or a list of phoneme tokens in which ``' '``
        separates consecutive words (empty for empty text)
    :raises IndexError:
        when a word has no known pronunciation
    """
    sentence_phonemes = []

    # split() without an argument ignores repeated / leading / trailing
    # whitespace, so no empty "words" reach the dictionary lookup
    for word in text.upper().split():
        pronunciations = pronouncing.phones_for_word(word)
        if not pronunciations:
            # same exception type as the former bare [0], but it now
            # names the offending word
            raise IndexError(f'no pronunciation found for word {word!r}')

        if sentence_phonemes:
            sentence_phonemes.append(' ')
        sentence_phonemes.extend(pronunciations[0].split(' '))

    if as_str:
        return cls.stringify(sentence_phonemes, char_map=char_map)
    return sentence_phonemes
|
|
def fetch_anno_path(self, spk, basename, char_map):
    """
    Builds the annotation file path for one sample.

    Text vocabularies live under ``self.alignments_dir`` and phoneme
    vocabularies under ``self.phonemes_dir``; the extension depends on
    the vocabulary.

    :param spk: speaker directory name
    :param basename: sample file name without extension
    :param char_map: vocabulary whose annotation file is wanted
    :return: ``<root>/<spk>/<basename><extension>``
    :raises NotImplementedError: for any other vocabulary
    """
    if char_map == CharMap.letters:
        root, extension = self.alignments_dir, '.align'
    elif char_map == CharMap.lsr2_text:
        root, extension = self.alignments_dir, '.txt'
    elif char_map == CharMap.phonemes:
        root, extension = self.phonemes_dir, '.align'
    elif char_map == CharMap.cmu_phonemes:
        root, extension = self.phonemes_dir, '.txt'
    else:
        raise NotImplementedError

    return os.path.join(root, spk, basename + extension)
|
|
def fetch_anno_text(self, spk, basename, char_map: CharMap):
    """
    Locates and loads the annotation of one sample.

    :param spk: speaker directory name
    :param basename: sample file name without extension
    :param char_map: vocabulary to load the annotation in
    :return: the ``(tokens, indices)`` pair from ``load_anno_text``
    """
    anno_path = self.fetch_anno_path(spk, basename, char_map=char_map)
    return self.load_anno_text(anno_path, char_map=char_map)
|
|
def __getitem__(self, idx):
    """
    Loads the sample stored at position ``idx`` of ``self.data``.

    :param idx: index into ``self.data``
    :return: whatever ``load_sample`` returns for that entry
    """
    video_name, speaker_name, filename = self.data[idx]
    return self.load_sample(
        video_name=video_name,
        speaker_name=speaker_name,
        filename=filename,
    )
|
|
def load_random_sample(self, char_map=None):
    """
    Loads one entry of ``self.data`` picked with ``random.choice``.

    :param char_map: forwarded unchanged to ``load_sample``
    :return: whatever ``load_sample`` returns for the chosen entry
    """
    video_name, speaker_name, filename = random.choice(self.data)
    return self.load_sample(
        video_name=video_name,
        speaker_name=speaker_name,
        filename=filename,
        char_map=char_map,
    )
|
|
def _load_padded_anno(self, speaker_name, basename, char_map):
    """
    Loads one annotation and pads its index array to ``self.txt_pad``.

    :return:
        ``(tokens, padded index array, number of indices before padding)``
    """
    tokens, indices = self.fetch_anno_text(
        speaker_name, basename, char_map=char_map
    )
    true_len = indices.shape[0]
    indices = self._padding(indices, self.txt_pad)
    # check the array itself (checking the integer length can never fail)
    assert not np.isnan(indices).any()
    return tokens, indices, true_len

def load_sample(
    self, video_name, speaker_name, filename,
    char_map=None
):
    """
    Loads the frames and annotation(s) of one video.

    :param video_name:
        video path handed to ``load_vid``
    :param speaker_name:
        speaker directory used to locate annotation files
    :param filename:
        video file name; its extension is dropped to find annotations
    :param char_map:
        vocabulary to load; defaults to ``self.char_map``
    :return:
        ``kwargify(...)`` of ``vid`` (float tensor, axes reordered with
        ``transpose(3, 0, 1, 2)``), ``vid_len`` (frame count before
        padding) and, depending on the vocabulary,
        ``txt`` / ``txt_len`` / ``txt_anno``,
        ``phonemes`` / ``phonemes_len`` or
        ``cmu_phonemes`` / ``cmu_phonemes_len``
    """
    if char_map is None:
        char_map = self.char_map
    # NOTE(review): as read from the flattened source this flag also
    # overrides an explicitly passed char_map - confirm that is intended
    if self.sample_all_props:
        # the builtin ``all`` doubles as the "every vocabulary" sentinel
        char_map = all
    load_everything = char_map is all

    vid = self.load_vid(video_name)
    if self.frame_doubling:
        vid = np.repeat(vid, repeats=2, axis=0)

    basename, _ = os.path.splitext(filename)
    txt_results, phoneme_results, cmu_phoneme_results = {}, {}, {}

    if load_everything or char_map == CharMap.letters:
        txt_anno, txt_anno_arr, txt_anno_arr_len = self._load_padded_anno(
            speaker_name, basename, CharMap.letters
        )
        # NOTE(review): the token list is padded to the global
        # options.txt_padding while the index array uses self.txt_pad
        txt_anno += [' '] * (options.txt_padding - len(txt_anno))
        txt_results = kwargify(
            txt=torch.LongTensor(txt_anno_arr),
            txt_len=txt_anno_arr_len, txt_anno=txt_anno
        )

    if load_everything or char_map == CharMap.phonemes:
        _, phoneme_arr, phoneme_len = self._load_padded_anno(
            speaker_name, basename, CharMap.phonemes
        )
        phoneme_results = kwargify(
            phonemes=torch.LongTensor(phoneme_arr),
            phonemes_len=phoneme_len,
        )
    # NOTE(review): because this is an ``elif``, the "every vocabulary"
    # mode never loads CMU phonemes. Kept as is, since loading them
    # would require CMU annotation files for every sample - confirm
    # whether that is the intended behaviour.
    elif char_map == CharMap.cmu_phonemes:
        _, cmu_arr, cmu_len = self._load_padded_anno(
            speaker_name, basename, CharMap.cmu_phonemes
        )
        cmu_phoneme_results = kwargify(
            cmu_phonemes=torch.LongTensor(cmu_arr),
            cmu_phonemes_len=cmu_len,
        )

    if self.phase == 'train':
        vid = HorizontalFlip(vid)

    vid = ColorNormalize(vid)
    vid_len = vid.shape[0]
    vid = self._padding(vid, self.vid_pad)

    # A CTC feasibility check (video length vs. annotation length) was
    # disabled in the original code and is intentionally not restored.
    assert not np.isnan(vid).any()

    return kwargify(
        vid=torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
        vid_len=vid_len, **txt_results, **phoneme_results,
        **cmu_phoneme_results
    )
|
|
def __len__(self):
    """Returns the number of indexed samples (entries of ``self.data``)."""
    sample_count = len(self.data)
    return sample_count
|
|
| @staticmethod |
| def serialize(data: np.ndarray): |
| return torch.from_numpy(data.astype(np.uint8)) |
|
|
| @staticmethod |
| def deserialize(data: torch.Tensor): |
| return data.numpy().astype(np.float16) |
|
|
@staticmethod
def process_vid(video_path: str, to_tensor=True):
    """
    Extracts, resizes and normalizes the frames of a raw video file.

    :param video_path:
        video handed to ``Extractor.extract_frames``
    :param to_tensor:
        if True the result is returned as a float tensor with its axes
        reordered by ``transpose(3, 0, 1, 2)``; otherwise the
        normalized numpy array is returned
    :return: the processed video
    :raises ValueError: when no usable frame was extracted
    """
    frames = Extractor.extract_frames(
        video_path, recycle_landmarks=True, use_gpu=True
    )

    # one pass: drop missing frames and resize the rest to 128x64
    resized = [
        cv2.resize(frame, (128, 64), interpolation=cv2.INTER_LANCZOS4)
        for frame in frames if frame is not None
    ]
    if not resized:
        # np.stack would raise ValueError too, without naming the video
        raise ValueError(f'NO FRAMES EXTRACTED: {video_path}')

    array = np.stack(resized, axis=0).astype(np.float16)
    vid = ColorNormalize(array)

    if to_tensor:
        vid = torch.FloatTensor(vid.transpose(3, 0, 1, 2))

    return vid
|
|
def load_vid(self, video_path: str) -> np.ndarray:
    """
    Loads the frames of a video, always bypassing the shared cache.

    :param video_path: video whose frame directory is read
    :return: the array produced by ``_load_vid`` with ``cache=False``
    """
    frames = self._load_vid(video_path, cache=False)
    return frames
|
|
def _load_vid(self, video_path: str, cache=True) -> np.ndarray:
    """
    Reads the extracted ``.jpg`` frames of a video into one array.

    Frames are looked up in
    ``<self.image_dir>/<speaker dir>/<video name without extension>``,
    ordered by the integer value of their file name, resized to 128x64
    and stacked along a new first axis.

    :param video_path:
        path of the video; only its last two components are used
    :param cache:
        if True and ``self.shared_dict`` is set, frames are served from
        / stored into that mapping
    :return: ``float16`` array of stacked frames
    :raises ValueError:
        when no frame could be stacked, or a ``.jpg`` file name is not
        an integer
    """
    use_cache = cache and self.shared_dict is not None
    if use_cache and video_path in self.shared_dict:
        return self.deserialize(self.shared_dict[video_path])

    basename, _ = os.path.splitext(os.path.basename(video_path))
    speaker_dir = os.path.basename(os.path.dirname(video_path))
    image_dir = os.path.join(self.image_dir, speaker_dir, basename)

    # endswith, not find: a name that merely contains '.jpg'
    # (e.g. '1.jpg.bak') is not a frame
    frame_files = [
        file for file in os.listdir(image_dir)
        if file.endswith('.jpg')
    ]
    frame_files.sort(key=lambda file: int(os.path.splitext(file)[0]))

    images = (
        cv2.imread(os.path.join(image_dir, file))
        for file in frame_files
    )
    # cv2.imread returns None for unreadable files; those are skipped
    frames = [
        cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
        for im in images if im is not None
    ]

    try:
        array = np.stack(frames, axis=0).astype(np.float16)
    except ValueError as e:
        raise ValueError(f'BAD VIDEO PATH: {video_path}') from e

    if use_cache:
        serialized_data = self.serialize(array)
        serialized_data.share_memory_()
        self.shared_dict[video_path] = serialized_data

    return array
|
|
@classmethod
def load_anno(cls, name, char_map):
    """
    Loads an annotation file and returns only its index array.

    :param name: path of the annotation file
    :param char_map: vocabulary the file is written in
    :return: second element of ``load_anno_text``'s result
    """
    _, indices = cls.load_anno_text(name, char_map)
    return indices
|
|
@classmethod
def load_anno_text(cls, name, char_map):
    """
    Loads an annotation file as tokens and as vocabulary indices.

    :param name: path of the annotation file
    :param char_map: vocabulary the file is written in
    :return:
        ``(tokens, indices)``; indices come from ``txt2arr`` with an
        offset of 1
    """
    tokens = cls.load_sentence(name, char_map=char_map)
    return tokens, cls.txt2arr(tokens, 1, char_map=char_map)
|
|
def _load_anno(self, name):
    """
    Loads an annotation's index array using this dataset's ``char_map``.

    :param name: path of the annotation file
    :return: the result of ``load_anno``
    """
    active_map = self.char_map
    return self.load_anno(name, active_map)
|
|
@classmethod
def load_sentence(cls, name, char_map=CharMap.letters) -> List[str]:
    """
    Reads an annotation file into a list of tokens.

    * ``letters``: the third whitespace-separated field of every line
      is a word; ``SIL`` / ``SP`` entries are dropped and the remaining
      words are returned as upper-case characters
    * ``lsr2_text``: the first line, minus its first five characters
      (presumably a ``Text:`` style prefix - TODO confirm), as
      upper-case characters
    * ``phonemes`` / ``cmu_phonemes``: every line holds the phonemes of
      one word; words are separated by a ``' '`` token

    Blank lines are ignored and an empty file yields ``[]`` for every
    vocabulary.

    :param name: path of the annotation file
    :param char_map: vocabulary the file is written in
    :return: list of tokens
    :raises ValueError: for an unsupported ``char_map``
    :raises IndexError:
        when a non-blank ``letters`` line has fewer than three fields
    """
    with open(name, 'r') as f:
        if char_map == CharMap.letters:
            words = []
            for line in f:
                # split() tolerates tabs and repeated spaces
                fields = line.split()
                if not fields:
                    continue
                word = fields[2]
                if word.upper() not in ('SIL', 'SP'):
                    words.append(word)

            return [char.upper() for char in ' '.join(words)]

        elif char_map == CharMap.lsr2_text:
            # readline() returns '' for an empty file instead of raising
            text_line = f.readline()[5:].strip()
            return [char.upper() for char in text_line]

        elif char_map in (CharMap.phonemes, CharMap.cmu_phonemes):
            all_chars = []

            for line in f:
                word_phonemes = line.split()
                if not word_phonemes:
                    continue
                # separator goes between words, never after the last
                if all_chars:
                    all_chars.append(' ')
                all_chars.extend(word_phonemes)

            return all_chars
        else:
            raise ValueError(f'BAD CHAR MAP {char_map}')
|
|
@classmethod
def load_str_sentence(cls, name, char_map=CharMap.letters) -> str:
    """
    Reads an annotation file and renders it as a single string.

    :param name: path of the annotation file
    :param char_map: vocabulary the file is written in
    :return: ``stringify`` applied to the tokens from ``load_sentence``
    """
    tokens = cls.load_sentence(name=name, char_map=char_map)
    sentence = cls.stringify(tokens, char_map=char_map)
    return sentence
|
|
| @staticmethod |
| def tokenize_text(text: str, word_tokenize=False) -> List[str]: |
| """ |
| :param text: |
| :param word_tokenize: |
| whether to tokenize into words or individual characters |
| :return: |
| """ |
| if word_tokenize: |
| return text.split(' ') |
| else: |
| return list(text) |
|
|
| @staticmethod |
| def tokenize_phonemes(text: str, word_tokenize=False) -> List[str]: |
| """ |
| :param text: |
| :param word_tokenize: |
| whether to tokenize into words or individual phonemes |
| example: |
| text = 'S-EH1-T G-R-IY1-N IH0-N EH1-L S-IH1-K-S AH0-G-EH1-N' |
| word-level tokens: |
| ['S-EH1-T', 'G-R-IY1-N', 'IH0-N', 'EH1-L', 'S-IH1-K-S', 'AH0-G-EH1-N'] |
| phoneme-level tokens: |
| ['S', 'EH1', 'T', ' ', 'G', 'R', 'IY1', 'N', ' ', 'IH0', |
| 'N', ' ', 'EH1', 'L', ' ', 'S', 'IH1', 'K', 'S', ' ', |
| 'AH0', 'G', 'EH1', 'N'] |
| :return: |
| """ |
| if word_tokenize: |
| return text.split(' ') |
| else: |
| words = text.split(' ') |
| phonemes = [] |
|
|
| for word in words: |
| assert not word.startswith('-') |
| assert not word.endswith('-') |
| phonemes.extend(word.split('-')) |
| phonemes.append(' ') |
|
|
| if phonemes[-1] == ' ': |
| phonemes = phonemes[:-1] |
|
|
| return phonemes |
|
|
| @staticmethod |
| def _padding(array, length): |
| array = [array[_] for _ in range(array.shape[0])] |
| size = array[0].shape |
|
|
| for i in range(length - len(array)): |
| array.append(np.zeros(size)) |
|
|
| return np.stack(array, axis=0) |
|
|
@classmethod
def txt2arr(cls, txt, start, char_map=CharMap.letters):
    """
    Encodes tokens as vocabulary indices.

    :param txt:
        iterable of tokens (characters or phonemes)
    :param start:
        offset added to every index, i.e. the number of special
        indices reserved in front of the vocabulary
    :param char_map:
        vocabulary the tokens belong to
    :return:
        numpy array holding ``vocabulary.index(token) + start`` for
        every token
    :raises NotImplementedError: for ``visemes``
    :raises ValueError:
        for an unknown ``char_map`` or a token outside the vocabulary
    """
    if char_map == CharMap.letters:
        vocabulary = cls.letters
    elif char_map == CharMap.lsr2_text:
        vocabulary = cls.lrs2_chars
    elif char_map == CharMap.phonemes:
        vocabulary = cls.phonemes
    elif char_map == CharMap.cmu_phonemes:
        vocabulary = cls.cmu_phonemes
    elif char_map == CharMap.visemes:
        raise NotImplementedError
    else:
        raise ValueError(f'BAD CHAR MAP: {char_map}')

    # one dictionary instead of a linear list.index() scan per token
    index_of = {token: index for index, token in enumerate(vocabulary)}

    try:
        return np.array([index_of[token] + start for token in txt])
    except KeyError as e:
        # list.index raised ValueError for unknown tokens; keep that type
        raise ValueError(
            f'{e.args[0]!r} is not in the {char_map} vocabulary'
        ) from e
|
|
def arr2txt(self, arr, start, char_map=None):
    """
    Decodes vocabulary indices into a sentence.

    :param arr: iterable of token indices
    :param start: offset that was added when the indices were encoded
    :param char_map: vocabulary to use; defaults to ``self.char_map``
    :return: the sentence produced by ``_arr2txt``
    """
    if char_map is None:
        char_map = self.char_map
    return self._arr2txt(arr, start, char_map=char_map)
|
|
@classmethod
def _arr2txt(cls, arr, start, char_map=CharMap.letters):
    """
    Decodes vocabulary indices into a sentence.

    Indices below ``start`` (padding / special symbols) are skipped.

    :param arr:
        iterable of token indices
    :param start:
        offset that was added when the indices were encoded
    :param char_map:
        vocabulary the indices refer to
    :return:
        the decoded tokens rendered with ``stringify``
    :raises NotImplementedError: for ``visemes``
    :raises ValueError: for an unknown ``char_map``
    """
    # the vocabulary does not change per element, so pick it once
    if char_map == CharMap.letters:
        vocabulary = cls.letters
    elif char_map == CharMap.lsr2_text:
        vocabulary = cls.lrs2_chars
    elif char_map == CharMap.phonemes:
        vocabulary = cls.phonemes
    elif char_map == CharMap.cmu_phonemes:
        vocabulary = cls.cmu_phonemes
    elif char_map == CharMap.visemes:
        raise NotImplementedError
    else:
        raise ValueError(f'BAD CHAR MAP: {char_map}')

    txt = [vocabulary[n - start] for n in arr if n >= start]
    return cls.stringify(txt, char_map)
|
|
def get_char_mapping(self):
    """
    Returns the vocabulary list that belongs to ``self.char_map``.

    :return: the result of ``char_mapping``
    """
    active_map = self.char_map
    return self.char_mapping(active_map)
|
|
@classmethod
def char_mapping(cls, char_map):
    """
    Returns the vocabulary list of a char map.

    :param char_map: vocabulary identifier
    :return:
        ``cls.letters``, ``cls.lrs2_chars``, ``cls.phonemes`` or
        ``cls.cmu_phonemes``
    :raises NotImplementedError: for ``visemes``
    :raises ValueError: for an unknown ``char_map``
    """
    if char_map == CharMap.letters:
        return cls.letters
    if char_map == CharMap.lsr2_text:
        # this vocabulary has a table and a loader, so expose it here too
        return cls.lrs2_chars
    if char_map == CharMap.phonemes:
        return cls.phonemes
    if char_map == CharMap.cmu_phonemes:
        return cls.cmu_phonemes
    if char_map == CharMap.visemes:
        raise NotImplementedError
    raise ValueError(f'BAD CHAR MAP: {char_map}')
|
|
def ctc_decode(self, y):
    """
    Greedily decodes a batch of model outputs into sentences.

    :param y:
        scores with the batch on the first axis and the vocabulary on
        the last axis (the arg-max is taken over the last axis)
    :return:
        one decoded sentence per batch entry
    """
    best = y.argmax(-1)
    sentences = []
    for row in range(best.size(0)):
        sentences.append(self.ctc_arr2txt(best[row], start=1))
    return sentences
|
|
def ctc_decode_indices(self, y):
    """
    Greedily decodes a batch of model outputs and reports, per batch
    entry, the positions whose tokens were kept.

    :param y:
        scores with the batch on the first axis and the vocabulary on
        the last axis (the arg-max is taken over the last axis)
    :return:
        one list of kept positions per batch entry
    """
    best = y.argmax(-1)
    kept_positions = []
    for row in range(best.size(0)):
        decoded = self.ctc_arr2txt_indices(best[row], start=1)
        kept_positions.append(decoded[1])
    return kept_positions
|
|
def ctc_arr2txt(self, *args, **kwargs):
    """
    Same as ``ctc_arr2txt_pair`` but returns only the sentence.

    :param args: forwarded to ``ctc_arr2txt_pair``
    :param kwargs: forwarded to ``ctc_arr2txt_pair``
    :return: the decoded sentence
    """
    sentence, _ = self.ctc_arr2txt_pair(*args, **kwargs)
    return sentence
|
|
def ctc_arr2txt_pair(
    self, arr, start, char_map=None,
    filter_previous=True
):
    """
    Converts token indices into a sentence and the positions of the
    tokens that were kept.

    :param arr:
        array of token indices
    :param start:
        number of special characters in character set
    :param char_map:
        character set to use for tokenization
    :param filter_previous:
        if True, removes consecutive occurrences of an index / token
        e.g. THREE becomes THRE, SOON becomes SON
    :return:
        ``(sentence, indices)`` as produced by ``ctc_arr2txt_indices``
    """
    decoded = self.ctc_arr2txt_indices(
        arr=arr, start=start, char_map=char_map,
        filter_previous=filter_previous
    )
    text, positions = decoded
    return text, positions
|
|
def ctc_arr2txt_indices(
    self, arr, start, char_map=None,
    filter_previous=True
):
    """
    Converts token indices into a sentence and records which positions
    of ``arr`` contributed a token.

    Indices below ``start`` emit nothing, and a space is never emitted
    directly after another space.

    :param arr:
        array of token indices
    :param start:
        number of special characters in character set
    :param char_map:
        character set to use for tokenization;
        defaults to ``self.char_map``
    :param filter_previous:
        if True, removes consecutive occurrences of an index / token
        e.g. THREE becomes THRE, SOON becomes SON
    :return:
        ``(sentence, indices)`` where ``indices`` are the positions in
        ``arr`` of the tokens that were kept
    """
    if char_map is None:
        char_map = self.char_map

    last_seen = -1
    tokens, positions = [], []
    vocabulary = self.char_mapping(char_map)

    for position, index in enumerate(arr):
        is_fresh = (not filter_previous) or (last_seen != index)

        if index >= start:
            doubled_space = (
                bool(tokens) and tokens[-1] == ' ' and
                vocabulary[index - start] == ' '
            )

            if not doubled_space and is_fresh:
                tokens.append(vocabulary[index - start])
                positions.append(position)

        # updated for every element, including those below ``start``,
        # so a special index between two equal tokens keeps them apart
        last_seen = index

    return self.stringify(tokens, char_map), positions
|
|
@staticmethod
def stringify(txt, char_map):
    """
    Renders a list of tokens as one sentence string.

    Character vocabularies are concatenated directly. Phoneme
    vocabularies are joined with dashes inside a word while the ``' '``
    tokens remain the separators between words, e.g.
    ``['S', 'EH1', 'T', ' ', 'G']`` becomes ``'S-EH1-T G'``.

    :param txt: list of tokens
    :param char_map: vocabulary the tokens belong to
    :return: the sentence
    :raises NotImplementedError: for any other ``char_map``
    """
    if char_map in (CharMap.letters, CharMap.lsr2_text):
        return ''.join(txt).strip()

    if char_map not in (CharMap.phonemes, CharMap.cmu_phonemes):
        raise NotImplementedError

    joined = '-'.join(txt).strip()
    # drop the dashes that ended up next to a word separator
    joined = joined.replace('- ', ' ').replace(' -', ' ')

    if joined.endswith('-'):
        joined = joined[:-1]
    if joined.startswith('-'):
        joined = joined[1:]

    return joined
|
|
def _map_chars(self, chars: str):
    """
    Applies ``map_chars`` with this dataset's own ``char_map``.

    :param chars: sentence to convert
    :return: the result of ``map_chars``
    """
    active_map = self.char_map
    return self.map_chars(chars, char_map=active_map)
|
|
@classmethod
def map_chars(cls, chars: str, char_map: CharMap):
    """
    Rewrites a sentence so that every token is a single symbol, which
    lets edit-distance metrics count tokens rather than raw characters.

    Character vocabularies are returned unchanged. For phoneme
    vocabularies every dash-separated phoneme is replaced by the entry
    at the same position in the matching ``*_chars`` table, and words
    stay separated by single spaces (no trailing space is added).

    :param chars:
        sentence; for phoneme vocabularies words are space-separated
        and phonemes dash-separated
    :param char_map:
        vocabulary the sentence is written in
    :return: the converted sentence
    :raises NotImplementedError: for ``visemes``
    :raises ValueError:
        for an unknown ``char_map`` or a phoneme outside the vocabulary
    """
    if char_map in (CharMap.letters, CharMap.lsr2_text):
        return chars

    if char_map == CharMap.phonemes:
        # use the stand-in table, as the CMU branch does; mapping through
        # cls.phonemes itself would leave every phoneme unchanged
        phonemes_arr = cls.phonemes
        char_phonemes_arr = cls.phoneme_chars
    elif char_map == CharMap.cmu_phonemes:
        phonemes_arr = cls.cmu_phonemes
        char_phonemes_arr = cls.cmu_phoneme_chars
    elif char_map == CharMap.visemes:
        raise NotImplementedError
    else:
        raise ValueError(f'BAD CHAR MAP: {char_map}')

    # one dictionary instead of a linear list.index() scan per phoneme
    position_of = {
        phoneme: index for index, phoneme in enumerate(phonemes_arr)
    }

    mapped_words = []
    for word in chars.split(' '):
        phonemes = [
            phoneme for phoneme in word.split('-')
            if phoneme.strip() != ''
        ]
        try:
            mapped_words.append(''.join(
                char_phonemes_arr[position_of[phoneme]]
                for phoneme in phonemes
            ))
        except KeyError as e:
            # list.index raised ValueError for unknown phonemes
            raise ValueError(
                f'{e.args[0]!r} is not in the {char_map} vocabulary'
            ) from e

    # join instead of appending ' ' after every word: a trailing space
    # would add an empty "word" to every sentence and skew the metrics
    return ' '.join(mapped_words)
|
|
@classmethod
def map_char_lists(
    cls, char_lists: Iterable[str], char_map: CharMap
):
    """
    Applies ``map_chars`` to every sentence of an iterable.

    :param char_lists: sentences to convert
    :param char_map: vocabulary the sentences are written in
    :return: list with one converted sentence per input sentence
    """
    converted = []
    for char_seq in char_lists:
        converted.append(cls.map_chars(char_seq, char_map=char_map))
    return converted
|
|
def wer(self, raw_predict, raw_truth):
    """
    Word error rates under this dataset's own ``char_map``.

    :param raw_predict: predicted sentences
    :param raw_truth: reference sentences
    :return: the result of ``get_wer``
    """
    active_map = self.char_map
    return self.get_wer(raw_predict, raw_truth, char_map=active_map)
|
|
@classmethod
def get_wer(cls, raw_predict, raw_truth, char_map: CharMap):
    """
    Computes one word error rate per (prediction, reference) pair.

    Sentences are first converted with ``map_char_lists`` and then
    split on single spaces; the rate is the edit distance between the
    two word lists divided by the number of reference words.

    :param raw_predict: iterable of predicted sentences
    :param raw_truth: iterable of reference sentences
    :param char_map: vocabulary the sentences are written in
    :return: list of rates, one per pair
    """
    assert isinstance(raw_predict, Iterable)
    assert isinstance(raw_truth, Iterable)

    predict = cls.map_char_lists(raw_predict, char_map=char_map)
    truth = cls.map_char_lists(raw_truth, char_map=char_map)

    rates = []
    for hypothesis, reference in zip(predict, truth):
        hypothesis_words = hypothesis.split(' ')
        reference_words = reference.split(' ')
        distance = editdistance.eval(hypothesis_words, reference_words)
        rates.append(1.0 * distance / len(reference_words))

    return rates
|
|
def cer(self, raw_predict, raw_truth):
    """
    Character error rates under this dataset's own ``char_map``.

    :param raw_predict: predicted sentences
    :param raw_truth: reference sentences
    :return: the result of ``get_cer``
    """
    active_map = self.char_map
    return self.get_cer(raw_predict, raw_truth, char_map=active_map)
|
|
@classmethod
def get_cer(cls, raw_predict, raw_truth, char_map: CharMap):
    """
    Computes one character error rate per (prediction, reference) pair.

    Sentences are first converted with ``map_char_lists``; the rate is
    the edit distance between the two strings divided by the length of
    the reference.

    :param raw_predict: iterable of predicted sentences
    :param raw_truth: iterable of reference sentences
    :param char_map: vocabulary the sentences are written in
    :return:
        list of rates, one per pair; for an empty reference the rate is
        the plain edit distance (the length of the prediction)
    """
    assert isinstance(raw_predict, Iterable)
    assert isinstance(raw_truth, Iterable)

    predict = cls.map_char_lists(raw_predict, char_map=char_map)
    truth = cls.map_char_lists(raw_truth, char_map=char_map)

    # max(..., 1) keeps an empty reference from raising ZeroDivisionError
    return [
        1.0 * editdistance.eval(hypothesis, reference)
        / max(len(reference), 1)
        for hypothesis, reference in zip(predict, truth)
    ]
|
|