| import numpy as np |
| import cv2 |
| import os |
| from torch.utils.data import Dataset |
| from cvtransforms import * |
| import torch |
| import editdistance |
| import json |
|
|
|
|
class MyDataset(Dataset):
    """Dataset pairing frame-directory videos with text labels and coordinates.

    Each sample is described by a directory of ``<index>.jpg`` frames, an
    alignment file at ``<anno_path>/<spk>/align/<name>.align`` and a JSON file
    of per-frame coordinates at ``<coords_path>/<spk>/<name>.json``, where
    ``<name>`` is the last component of the video directory path and ``<spk>``
    is the component four levels from the end.
    """

    # Index 0 is the space character; indices 1..26 are "A".."Z".
    letters = list(" ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    def __init__(
        self,
        video_path,
        anno_path,
        coords_path,
        file_list,
        vid_pad,
        txt_pad,
        phase,
    ):
        """Read the list of video directories and index them.

        Args:
            video_path: Root directory prepended to every entry of ``file_list``.
            anno_path: Root directory of the ``.align`` annotation files.
            coords_path: Root directory of the coordinate ``.json`` files.
            file_list: Text file with one relative video directory per line.
            vid_pad: Number of frames that videos and coordinates are padded to.
            txt_pad: Length that label arrays are padded to.
            phase: ``"train"`` enables the horizontal-flip transform.

        Raises:
            IndexError: If an entry has fewer than four path components.
        """
        self.anno_path = anno_path
        self.coords_path = coords_path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, "r") as f:
            # Blank lines (e.g. a trailing newline) would otherwise become a
            # bogus sample pointing at ``video_path`` itself.
            self.videos = [
                os.path.join(video_path, line.strip()) for line in f if line.strip()
            ]

        self.data = []
        for vid in self.videos:
            # os.path.join uses os.sep, so normalise before splitting on "/".
            items = vid.replace(os.sep, "/").split("/")
            self.data.append((vid, items[-4], items[-1]))

    def __getitem__(self, idx):
        """Load, transform and pad the sample at position ``idx``.

        Returns:
            dict with keys ``"vid"`` (FloatTensor, frame array with axes
            reordered by ``transpose(3, 0, 1, 2)``), ``"txt"`` (LongTensor of
            padded label indices), ``"coord"`` (FloatTensor of padded
            coordinates), ``"txt_len"`` and ``"vid_len"`` (lengths before
            padding).
        """
        (vid_dir, spk, name) = self.data[idx]
        vid = self._load_vid(vid_dir)
        anno = self._load_anno(
            os.path.join(self.anno_path, spk, "align", name + ".align")
        )
        coord = self._load_coords(os.path.join(self.coords_path, spk, name + ".json"))

        # HorizontalFlip / ColorNormalize come from the cvtransforms module;
        # presumably augmentation and intensity scaling -- see that module.
        if self.phase == "train":
            vid = HorizontalFlip(vid)

        vid = ColorNormalize(vid)

        # Record the true lengths before padding hides them.
        vid_len = vid.shape[0]
        anno_len = anno.shape[0]
        vid = self._padding(vid, self.vid_pad)
        anno = self._padding(anno, self.txt_pad)
        coord = self._padding(coord, self.vid_pad)

        return {
            "vid": torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
            "txt": torch.LongTensor(anno),
            "coord": torch.FloatTensor(coord),
            "txt_len": anno_len,
            "vid_len": vid_len,
        }

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def _load_vid(self, p):
        """Load every ``<int>.jpg`` frame in directory ``p`` as one float32 array.

        Frames are ordered by the integer value of their file stem, unreadable
        images are skipped, and each frame is resized with ``dsize=(128, 64)``.

        Raises:
            ValueError: If a ``.jpg`` stem is not an integer, or if no frame
                could be read.
        """
        # endswith() rather than a substring search, so "1.jpg.bak" is ignored
        # instead of crashing the integer sort below.
        names = [name for name in os.listdir(p) if name.endswith(".jpg")]
        names.sort(key=lambda name: int(os.path.splitext(name)[0]))

        frames = []
        for name in names:
            im = cv2.imread(os.path.join(p, name))
            if im is None:  # cv2.imread signals an unreadable file with None
                continue
            frames.append(
                cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
            )

        if not frames:
            raise ValueError("no readable .jpg frames found in %r" % (p,))

        return np.stack(frames, axis=0).astype(np.float32)

    def _load_anno(self, name):
        """Read an alignment file and encode its words as label indices.

        The third whitespace-separated field of every non-blank line is taken
        as a token; ``sil`` and ``sp`` tokens (any case) are dropped and the
        rest are joined with single spaces, upper-cased and encoded with
        ``txt2arr(..., 1)``.

        Raises:
            IndexError: If a non-blank line has fewer than three fields.
            ValueError: If the text contains a character not in ``letters``.
        """
        with open(name, "r") as f:
            # Blank lines carry no token; skip them instead of failing on [2].
            rows = [line.split() for line in f if line.strip()]
        words = [row[2] for row in rows]
        words = [w for w in words if w.upper() not in ("SIL", "SP")]
        return MyDataset.txt2arr(" ".join(words).upper(), 1)

    def _load_coords(self, name, img_width=600, img_height=500):
        """Load per-frame coordinates from JSON and scale them by image size.

        The JSON object maps frame-index strings to a pair
        ``[x_values, y_values]``. Frames are ordered by the integer value of
        their key.

        Args:
            name: Path of the JSON file.
            img_width: Divisor applied to every x value (default 600).
            img_height: Divisor applied to every y value (default 500).

        Returns:
            float32 array; with a uniform number of points per frame its shape
            is ``(frames, points, 2)`` holding ``(x / img_width, y / img_height)``.
        """
        with open(name, "r") as f:
            coords_data = json.load(f)

        coords = []
        for frame in sorted(coords_data.keys(), key=int):
            xs, ys = coords_data[frame][0], coords_data[frame][1]
            coords.append([(x / img_width, y / img_height) for x, y in zip(xs, ys)])
        return np.array(coords, dtype=np.float32)

    def _padding(self, array, length):
        """Zero-pad ``array`` along axis 0 up to ``length`` entries.

        The padding keeps the input dtype, so integer label arrays stay
        integer and float32 video stays float32. An input that is already
        ``length`` or longer is returned as an unchanged copy (no truncation).
        An empty input yields an all-zero array with ``length`` rows.
        """
        array = np.asarray(array)
        missing = length - array.shape[0]
        if missing <= 0:
            return array.copy()
        filler = np.zeros((missing,) + array.shape[1:], dtype=array.dtype)
        return np.concatenate([array, filler], axis=0)

    @staticmethod
    def txt2arr(txt, start):
        """Encode ``txt`` as an int64 array of ``letters`` indices plus ``start``.

        Raises:
            ValueError: If a character of ``txt`` is not in ``letters``.
        """
        # Explicit dtype so that empty text still yields an integer array.
        return np.array(
            [MyDataset.letters.index(c) + start for c in txt], dtype=np.int64
        )

    @staticmethod
    def arr2txt(arr, start):
        """Decode label indices to text, ignoring values below ``start``.

        Leading and trailing whitespace is stripped from the result.
        """
        return "".join(
            MyDataset.letters[n - start] for n in arr if n >= start
        ).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        """Decode a label sequence, collapsing immediate repeats.

        A value is emitted only when it differs from the previous element and
        is at least ``start``. A space is never appended directly after another
        space. The previous element is updated on every step, so a value below
        ``start`` between two equal labels lets both of them be emitted.
        """
        pre = -1
        txt = []
        for n in arr:
            if pre != n and n >= start:
                char = MyDataset.letters[n - start]
                # Avoid doubled spaces in the decoded text.
                if not (txt and txt[-1] == " " and char == " "):
                    txt.append(char)
            pre = n
        return "".join(txt).strip()

    @staticmethod
    def wer(predict, truth):
        """Return the per-pair word error rate for parallel lists of strings.

        Each string is split on single spaces; the edit distance between the
        word lists is divided by the number of words in the truth.
        """
        word_pairs = [(p.split(" "), t.split(" ")) for p, t in zip(predict, truth)]
        return [1.0 * editdistance.eval(p, t) / len(t) for p, t in word_pairs]

    @staticmethod
    def cer(predict, truth):
        """Return the per-pair character error rate for parallel lists of strings.

        Raises:
            ZeroDivisionError: If a truth string is empty.
        """
        return [
            1.0 * editdistance.eval(p, t) / len(t) for p, t in zip(predict, truth)
        ]
|
|