# datasets.py
import os
import random
import glob
import cv2
import numpy as np
from torch.utils.data import Dataset
class UnlabeledVideoDataset(Dataset):
    """Dataset yielding all decoded frames of .mp4 videos under a root directory.

    Each sample is a dict with:
        'frames': list of frames as returned by OpenCV (BGR uint8 arrays,
                  unless ``transform`` changes that),
        'index':  the dataset index of the sample.
    A video that cannot be opened yields an empty ``frames`` list.
    """

    def __init__(self, root_dir, content=None, transform=None):
        """
        Args:
            root_dir: directory scanned recursively for '*.mp4' files.
            content: optional explicit list of video paths relative to
                ``root_dir``; when given, no directory scan is performed.
            transform: optional callable applied to every decoded frame.
        """
        self.root_dir = os.path.normpath(root_dir)
        self.transform = transform
        if content is not None:
            self.content = content
        else:
            # Collect root-relative paths of all .mp4 files; sort for a
            # deterministic ordering across runs and filesystems.
            self.content = []
            for path in glob.iglob(os.path.join(self.root_dir, '**', '*.mp4'), recursive=True):
                rel_path = path[len(self.root_dir) + 1:]
                self.content.append(rel_path)
            self.content = sorted(self.content)

    def __len__(self):
        return len(self.content)

    def __getitem__(self, idx):
        rel_path = self.content[idx]
        path = os.path.join(self.root_dir, rel_path)
        capture = cv2.VideoCapture(path)
        frames = []
        try:
            if capture.isOpened():
                while True:
                    ret, frame = capture.read()
                    if not ret:
                        # End of stream (or read error): stop decoding.
                        break
                    if self.transform is not None:
                        frame = self.transform(frame)
                    frames.append(frame)
        finally:
            # Bug fix: the capture was never released, leaking the underlying
            # video handle/decoder on every sample read.
            capture.release()
        sample = {
            'frames': frames,
            'index': idx
        }
        return sample
class FaceDataset(Dataset):
    """Dataset of face-crop images stored as files under a root directory.

    Each sample is a dict with:
        'face':  the RGB image (after ``transform``, if given),
        'index': the dataset index,
        'label': present only when ``labels`` was supplied.
    """

    def __init__(self, root_dir, content, labels=None, transform=None):
        """
        Args:
            root_dir: directory containing the face images.
            content: list of image paths relative to ``root_dir``.
            labels: optional sequence of labels aligned with ``content``.
            transform: optional albumentations-style transform, called as
                ``transform(image=face)['image']``.
        """
        self.root_dir = os.path.normpath(root_dir)
        self.content = content
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.content)

    def __getitem__(self, idx):
        rel_path = self.content[idx]
        path = os.path.join(self.root_dir, rel_path)
        face = cv2.imread(path, cv2.IMREAD_COLOR)
        # Bug fix: cv2.imread silently returns None for a missing/unreadable
        # file, which previously crashed cvtColor with a cryptic OpenCV error.
        if face is None:
            raise FileNotFoundError('Unable to read image: {}'.format(path))
        # OpenCV loads BGR; models downstream expect RGB.
        face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            face = self.transform(image=face)['image']
        sample = {
            'face': face,
            'index': idx
        }
        if self.labels is not None:
            sample['label'] = self.labels[idx]
        return sample
class TrackPairDataset(Dataset):
    """Dataset of aligned (real, fake) face-track pairs.

    Each line of the pairs file is "real_track,fake_track" — two directories
    (relative to ``tracks_root``) whose frames are stored as "<i>.png".
    Both tracks of a pair are loaded with the SAME augmentation parameters so
    that real and fake stay pixel-aligned for training.

    Each sample is a dict: {'real': real_track, 'fake': fake_track}.
    """

    # Nominal frame rate passed to the track-level transform when sampling
    # its parameters. NOTE(review): assumed constant for all source videos.
    FPS = 30

    def __init__(self, tracks_root, pairs_path, indices, track_length, track_transform=None, image_transform=None,
                 sequence_mode=True):
        """
        Args:
            tracks_root: root directory containing the track folders.
            pairs_path: text file with one "real,fake" pair per line.
            indices: frame indices (into a track of ``track_length`` frames)
                to load for every sample.
            track_length: nominal number of frames per track; used to rescale
                ``indices`` when a track transform changes the track length.
            track_transform: optional track-level transform; must provide
                ``get_params(fps, height, width)`` and be callable as
                ``track_transform(path, fps, *params)``.
            image_transform: optional albumentations-style per-frame transform.
            sequence_mode: when True, the RNG state is reset before every
                frame so all frames of a track get identical augmentation.
        """
        self.tracks_root = os.path.normpath(tracks_root)
        self.track_transform = track_transform
        self.image_transform = image_transform
        self.indices = np.asarray(indices, dtype=np.int32)
        self.track_length = track_length
        self.sequence_mode = sequence_mode
        self.pairs = []
        with open(pairs_path, 'r') as f:
            for line in f:
                real_track, fake_track = line.strip().split(',')
                self.pairs.append((real_track, fake_track))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        real_track_path, fake_track_path = self.pairs[idx]
        real_track_path = os.path.join(self.tracks_root, real_track_path)
        fake_track_path = os.path.join(self.tracks_root, fake_track_path)
        if self.track_transform is not None:
            # Sample transform parameters once, from the first real frame's
            # size, and reuse them for BOTH tracks so the pair stays aligned.
            img = self.load_img(real_track_path, 0)
            src_height, src_width = img.shape[:2]
            track_transform_params = self.track_transform.get_params(self.FPS, src_height, src_width)
        else:
            track_transform_params = None
        real_track = self.load_track(real_track_path, self.indices, track_transform_params)
        fake_track = self.load_track(fake_track_path, self.indices, track_transform_params)
        if self.image_transform is not None:
            # Capture the RNG state so the same random augmentations can be
            # replayed: per frame when sequence_mode is on, and again for the
            # fake track so it receives the same augmentations as the real.
            prev_state = random.getstate()
            transformed_real_track = []
            for img in real_track:
                if self.sequence_mode:
                    random.setstate(prev_state)
                transformed_real_track.append(self.image_transform(image=img)['image'])
            real_track = transformed_real_track
            # Rewind the RNG so the fake track replays the real track's draws.
            random.setstate(prev_state)
            transformed_fake_track = []
            for img in fake_track:
                if self.sequence_mode:
                    random.setstate(prev_state)
                transformed_fake_track.append(self.image_transform(image=img)['image'])
            fake_track = transformed_fake_track
        sample = {
            'real': real_track,
            'fake': fake_track
        }
        return sample

    def load_img(self, track_path, idx):
        """Load frame ``idx`` of a track ("<idx>.png") as an RGB array."""
        img = cv2.imread(os.path.join(track_path, '{}.png'.format(idx)))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def load_track(self, track_path, indices, transform_params):
        """Load the frames at ``indices`` from a track, stacked into an array.

        With transform params, the track transform loads (and may re-time) the
        whole track itself, so the requested indices are rescaled from the
        nominal ``track_length`` onto the transformed track's actual length.
        """
        if transform_params is None:
            track = np.stack([self.load_img(track_path, idx) for idx in indices])
        else:
            track = self.track_transform(track_path, self.FPS, *transform_params)
            indices = (indices.astype(np.float32) / self.track_length) * len(track)
            indices = np.round(indices).astype(np.int32).clip(0, len(track) - 1)
            track = track[indices]
        return track