Spaces:
Sleeping
Sleeping
File size: 4,992 Bytes
87dd991 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import os,random
import numpy as np
import h5py
import cv2
from typing import List
from torch.utils.data import Dataset
from .helper.image_transform import wrap_transforms
class XGazeDataset(Dataset):
    """Dataset serving images and gaze/head labels stored in HDF5 files.

    At construction time every selected file is opened briefly to learn how
    many samples it holds, and a flat index ``idx_to_kv`` of
    ``(file_number, row)`` pairs is built.  The files are then closed again
    and re-opened lazily on first access through :attr:`archives`.

    Args:
        dataset_path: Directory that the entries of ``keys_to_use`` are
            joined onto.
        color_type: Either ``'rgb'`` or ``'bgr'``.  With ``'bgr'`` the last
            image axis is reversed in :meth:`preprocess_image`.
        images_per_frame: Number of consecutive rows that form one frame;
            frame numbers are multiplied by it to obtain row indices.
        keys_to_use: File names (relative to ``dataset_path``) to read.
            Must contain at least one entry.
        data_name: Stored on the instance as ``data_name``; not used here.
        image_size: Edge length of the square image passed to the transform.
        transform_type: Forwarded to ``wrap_transforms``.
        image_key: Name of the HDF5 dataset holding the images.
        gaze_key: Name of the HDF5 dataset holding the gaze labels.
        camera_random: ``None`` to keep every row of each frame, otherwise
            the number of rows sampled at random from each frame.
        frame_tag: ``(start_frame, end_frame)`` as a list or tuple, or the
            string ``'all'`` (treated as frames 0 to 10000).
        seed: Seed passed to ``random.seed``.

    Raises:
        ValueError: If ``color_type``, ``keys_to_use`` or ``frame_tag`` is
            invalid.
    """

    # Class-level default so ``archives`` and ``__del__`` work even on an
    # instance whose ``__init__`` did not run to completion.
    __hdfs = None

    def __init__(self,
                 dataset_path: str,
                 color_type,
                 images_per_frame,
                 keys_to_use: List[str] = None,
                 data_name=None,
                 image_size: int = 224,
                 transform_type='basic_imagenet',
                 image_key='face_patch',
                 gaze_key='face_gaze',
                 camera_random=None,
                 frame_tag=(0, 1000),
                 seed=0,
                 ):
        self.path = dataset_path
        self.hdfs = {}
        # Lazily opened handles (see ``archives``) and the handle used last.
        self.__hdfs = None
        self.hdf = None
        self.data_name = data_name
        self.images_per_frame = images_per_frame
        print('images_per_frame: ', images_per_frame)
        self.image_key = image_key
        self.gaze_key = gaze_key
        self.image_size = (image_size, image_size)
        # NOTE: this seeds the *global* ``random`` module, which also drives
        # the per-frame sampling below when ``camera_random`` is set.
        random.seed(seed)
        if color_type not in ('rgb', 'bgr'):
            raise ValueError(
                "color_type should be either 'rgb' or 'bgr', got %r" % (color_type,))
        self.color_type = color_type
        self.cameras_idx = list(range(self.images_per_frame))
        self.camera_random = camera_random

        self.selected_keys = list(keys_to_use) if keys_to_use is not None else []
        if not self.selected_keys:
            raise ValueError("keys_to_use should contain at least one file name")

        # The frame window is the same for every file, so resolve it once
        # (and before any file is opened).
        if isinstance(frame_tag, (list, tuple)):
            self.start_frame, self.end_frame = frame_tag
        elif frame_tag == 'all':
            self.start_frame, self.end_frame = 0, 10000
        else:
            raise ValueError("frame_tag should be either a list of integers or str 'all' ")

        self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys]
        self.idx_to_kv = []
        self.key_idx_dict = {}  # per-file row indices, keyed by the file name up to its first '.'
        per_frame = self.images_per_frame

        for num_i, (key_name, file_path) in enumerate(zip(self.selected_keys, self.file_paths)):
            # Open only long enough to read the sample count; the context
            # manager guarantees the handle is released even on error.
            with h5py.File(file_path, 'r', swmr=True) as h5_file:
                print('read file: ', file_path)
                assert h5_file.swmr_mode
                n = h5_file[image_key].shape[0]
            self.hdfs[num_i] = None

            this_sub = key_name.split('.')[0]
            start_idx = min(n, self.start_frame * per_frame)
            end_idx = min(n, self.end_frame * per_frame)

            if self.camera_random is None:
                # ``cameras_idx`` covers every residue modulo
                # ``images_per_frame``, so every row in the window is kept.
                rows = list(range(start_idx, end_idx))
                self.idx_to_kv.extend((num_i, i) for i in rows)
                self.key_idx_dict[this_sub] = rows
            else:
                for frame in range(start_idx // per_frame, end_idx // per_frame):
                    frame_start_idx = frame * per_frame
                    # Randomly pick ``camera_random`` positions inside this frame.
                    chosen = set(random.sample(range(per_frame), self.camera_random))
                    rows = [frame_start_idx + cam for cam in range(per_frame) if cam in chosen]
                    self.idx_to_kv.extend((num_i, i) for i in rows)
                    self.key_idx_dict.setdefault(this_sub, []).extend(rows)

        self.transform = wrap_transforms(transform_type, image_size=image_size)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.idx_to_kv)

    def __del__(self):
        """Close every HDF5 handle this instance may still hold."""
        hdfs = getattr(self, 'hdfs', None) or {}
        for num_i in list(hdfs):
            if hdfs[num_i]:
                hdfs[num_i].close()
            hdfs[num_i] = None
        # The lazily opened handles are the ones actually used for reading,
        # so they have to be released as well.
        lazy_handles = self.__hdfs
        if lazy_handles is not None:
            for handle in lazy_handles:
                if handle:
                    handle.close()
            self.__hdfs = None

    @property
    def archives(self):
        """List of open HDF5 files, one per entry of ``file_paths``.

        The files are opened on first access rather than in ``__init__``.
        """
        if self.__hdfs is None:  # lazy loading here!
            self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
        return self.__hdfs

    def preprocess_image(self, image):
        """Convert a raw image array into the transformed model input.

        The image is cast to float32, its last axis is reversed when
        ``color_type`` is ``'bgr'``, it is resized to ``image_size`` if its
        first two dimensions differ, and finally it is cast to uint8 and
        passed through ``self.transform``.
        """
        image = image.astype(np.float32)
        if self.color_type == 'bgr':
            image = image[..., ::-1]
        if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]:
            image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
        image = self.transform(image.astype(np.uint8))
        return image

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys: ``'image'`` (preprocessed image), ``'gaze'`` and ``'head'``
        (float labels, or a two-element zero array when the corresponding
        HDF5 dataset is absent), ``'key'`` (number of the source file) and
        ``'index'`` (the requested index).
        """
        key, idx = self.idx_to_kv[index]
        self.hdf = self.archives[key]
        assert self.hdf.swmr_mode

        image = self.hdf[self.image_key][idx, :]
        if self.gaze_key in self.hdf:
            gaze_label = self.hdf[self.gaze_key][idx].astype('float')
        else:
            gaze_label = np.array([0, 0]).astype('float')
        if 'face_head_pose' in self.hdf:
            head_label = self.hdf['face_head_pose'][idx].astype('float')
        else:
            head_label = np.array([0, 0]).astype('float')

        entry = {
            'image': self.preprocess_image(image),
            'gaze': gaze_label,
            'head': head_label,
            'key': key,
            'index': index
        }
        return entry
|