| import os | |
| import random | |
| import glob | |
| import cv2 | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
class UnlabeledVideoDataset(Dataset):
    """Dataset that decodes every frame of each ``.mp4`` file under a root directory.

    Args:
        root_dir: Directory that the relative video paths are resolved against.
        content: Optional list of video paths relative to ``root_dir``. When
            given it is used as-is (not re-sorted). When ``None`` the directory
            is searched recursively for ``*.mp4`` files and the relative paths
            found are stored in sorted order.
        transform: Optional callable invoked as ``transform(frame)`` on every
            decoded frame.
    """

    def __init__(self, root_dir, content=None, transform=None):
        self.root_dir = os.path.normpath(root_dir)
        self.transform = transform

        if content is not None:
            self.content = content
        else:
            pattern = os.path.join(self.root_dir, '**', '*.mp4')
            # relpath instead of slicing by len(root_dir) + 1: slicing drops a
            # character when the normalised root already ends with a separator
            # (for example the filesystem root).
            self.content = sorted(
                os.path.relpath(path, self.root_dir)
                for path in glob.iglob(pattern, recursive=True)
            )

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.content)

    def __getitem__(self, idx):
        """Decode the video at position ``idx``.

        Returns:
            A dict with ``'frames'`` (list of frames in read order, each passed
            through ``transform`` when one is set) and ``'index'`` (``idx``).
            ``'frames'`` is empty when the capture could not be opened.
        """
        path = os.path.join(self.root_dir, self.content[idx])

        capture = cv2.VideoCapture(path)
        frames = []
        try:
            if capture.isOpened():
                while True:
                    ret, frame = capture.read()
                    if not ret:
                        break

                    if self.transform is not None:
                        frame = self.transform(frame)

                    frames.append(frame)
        finally:
            # Free the decoder handle even if the transform raises; otherwise
            # each item would leave a capture open.
            capture.release()

        return {
            'frames': frames,
            'index': idx
        }
class FaceDataset(Dataset):
    """Dataset of face images stored as files below a root directory.

    Args:
        root_dir: Directory that the entries of ``content`` are relative to.
        content: Sequence of image paths relative to ``root_dir``.
        labels: Optional sequence indexed in parallel with ``content``.
        transform: Optional callable invoked as ``transform(image=face)``; its
            result is indexed with ``'image'`` (presumably an
            albumentations-style transform -- confirm against callers).
    """

    def __init__(self, root_dir, content, labels=None, transform=None):
        self.root_dir = os.path.normpath(root_dir)
        self.content, self.labels = content, labels
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.content)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it from BGR to RGB and apply the transform.

        Returns:
            A dict with ``'face'`` and ``'index'``, plus ``'label'`` when labels
            were supplied.
        """
        image_path = os.path.join(self.root_dir, self.content[idx])
        rgb = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            rgb = self.transform(image=rgb)['image']

        if self.labels is None:
            return {'face': rgb, 'index': idx}
        return {'face': rgb, 'index': idx, 'label': self.labels[idx]}
class TrackPairDataset(Dataset):
    """Dataset of (real, fake) track pairs stored as numbered PNG frames.

    Each line of the pairs file holds two comma-separated track directories
    (real first, fake second), both relative to ``tracks_root``. A track
    directory contains frames named ``<index>.png``.

    Args:
        tracks_root: Directory that the track paths are relative to.
        pairs_path: Text file listing one ``real,fake`` pair per line.
        indices: Frame indices to sample from each track.
        track_length: Nominal track length that ``indices`` refer to; used to
            rescale the indices when a track transform changes the length.
        track_transform: Optional object providing
            ``get_params(fps, height, width)`` and callable as
            ``track_transform(track_path, fps, *params)``. The same parameters
            are used for both tracks of a pair.
        image_transform: Optional callable invoked as
            ``image_transform(image=img)``; its result is indexed with
            ``'image'``.
        sequence_mode: When true, the ``random`` module state is rewound before
            every frame so each frame sees the same random draws; when false it
            is rewound only once per track.
    """

    FPS = 30

    def __init__(self, tracks_root, pairs_path, indices, track_length, track_transform=None, image_transform=None,
                 sequence_mode=True):
        self.tracks_root = os.path.normpath(tracks_root)
        self.track_transform = track_transform
        self.image_transform = image_transform
        self.indices = np.asarray(indices, dtype=np.int32)
        self.track_length = track_length
        self.sequence_mode = sequence_mode

        self.pairs = []
        with open(pairs_path, 'r') as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Blank lines (e.g. a trailing empty line) carry no pair
                    # and would otherwise fail the two-way unpack below.
                    continue
                real_track, fake_track = entry.split(',')
                self.pairs.append((real_track, fake_track))

    def __len__(self):
        """Return the number of track pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``{'real': ..., 'fake': ...}``."""
        real_track_path, fake_track_path = self.pairs[idx]

        real_track_path = os.path.join(self.tracks_root, real_track_path)
        fake_track_path = os.path.join(self.tracks_root, fake_track_path)

        if self.track_transform is not None:
            # The first real frame supplies the source size from which the
            # transform parameters shared by both tracks are drawn.
            img = self.load_img(real_track_path, 0)
            src_height, src_width = img.shape[:2]
            track_transform_params = self.track_transform.get_params(self.FPS, src_height, src_width)
        else:
            track_transform_params = None

        real_track = self.load_track(real_track_path, self.indices, track_transform_params)
        fake_track = self.load_track(fake_track_path, self.indices, track_transform_params)

        if self.image_transform is not None:
            # Both tracks start from the same RNG state so that transforms
            # driven by the ``random`` module make matching draws.
            prev_state = random.getstate()
            real_track = self._transform_track(real_track, prev_state)
            fake_track = self._transform_track(fake_track, prev_state)

        return {
            'real': real_track,
            'fake': fake_track
        }

    def _transform_track(self, track, state):
        """Apply ``image_transform`` to every frame, starting from RNG ``state``."""
        random.setstate(state)
        transformed = []
        for img in track:
            if self.sequence_mode:
                # Rewind before each frame so every frame gets the same draws.
                random.setstate(state)
            transformed.append(self.image_transform(image=img)['image'])
        return transformed

    def load_img(self, track_path, idx):
        """Read ``<track_path>/<idx>.png`` and convert it from BGR to RGB."""
        img = cv2.imread(os.path.join(track_path, '{}.png'.format(idx)))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img

    def load_track(self, track_path, indices, transform_params):
        """Load the frames of one track selected by ``indices``.

        Without transform parameters the frames named by ``indices`` are read
        from disk and stacked. Otherwise the track transform produces the whole
        track and ``indices`` are rescaled from ``track_length`` to the
        produced length, rounded and clipped into range before indexing.
        """
        if transform_params is None:
            track = np.stack([self.load_img(track_path, idx) for idx in indices])
        else:
            track = self.track_transform(track_path, self.FPS, *transform_params)
            indices = (indices.astype(np.float32) / self.track_length) * len(track)
            indices = np.round(indices).astype(np.int32).clip(0, len(track) - 1)
            track = track[indices]

        return track