# File size: 4,057 Bytes
# 87dd991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import numpy as np
import h5py
import cv2
from torch.utils.data import Dataset
from typing import List
from omegaconf import OmegaConf, listconfig
from .helper.image_transform import wrap_transforms


class GazeCaptureDataset(Dataset):
	"""Map-style dataset that serves image / gaze / head-pose samples from HDF5 files.

	Every file listed in ``keys_to_use`` is opened once during construction,
	only to read how many samples it holds, and is closed again right away.
	The files are re-opened lazily on the first sample access (see
	``archives``) and stay open until ``close()`` is called or the instance is
	garbage-collected.
	NOTE(review): the lazy re-open presumably exists so that every DataLoader
	worker process ends up with its own HDF5 handles -- confirm.
	"""

	# Class-level defaults keep close() / __del__ safe on instances whose
	# __init__ raised before the attributes were assigned.
	hdfs = None
	__hdfs = None

	def __init__(self,
			dataset_path: str,
			color_type,
			keys_to_use: List[str] = None,
			data_name=None,
			image_size:int=224,
			transform_type='basic_imagenet',
			image_key='face_patch',
			gaze_key='face_gaze',
			sample_rate_use=1,
			):
		"""Validate the arguments and build the flat sample index.

		Args:
			dataset_path: Directory that each entry of ``keys_to_use`` is joined onto.
			color_type: ``'rgb'`` or ``'bgr'``. With ``'bgr'`` the last image axis
				is reversed before resizing (see ``preprocess_image``).
			keys_to_use: HDF5 file names relative to ``dataset_path``. Required
				and must be non-empty, despite the ``None`` default.
			data_name: Stored on the instance as ``data_name``; not used by this class.
			image_size: Side length of the square the image is resized to; also
				forwarded to ``wrap_transforms``.
			transform_type: Forwarded to ``wrap_transforms`` to pick the transform.
			image_key: Name of the HDF5 dataset holding the images.
			gaze_key: Name of the HDF5 dataset holding the gaze labels.
			sample_rate_use: When greater than 1, only every
				``sample_rate_use``-th sample of each file is indexed.

		Raises:
			ValueError: If ``color_type`` is not ``'rgb'``/``'bgr'`` or
				``keys_to_use`` is ``None`` or empty.
		"""
		# Validate before touching the file system or building transforms, so a
		# bad call fails fast and leaves nothing open.
		if color_type not in ('rgb', 'bgr'):
			raise ValueError(f"color_type must be 'rgb' or 'bgr', got {color_type!r}")
		selected = list(keys_to_use) if keys_to_use is not None else []
		if not selected:
			raise ValueError("keys_to_use must contain at least one file name")

		super().__init__()

		self.transform = wrap_transforms(transform_type, image_size=image_size)

		self.path = dataset_path
		self.hdfs = {}
		self.data_name = data_name
		self.image_key = image_key
		self.gaze_key = gaze_key

		self.image_size = (image_size, image_size)

		self.sample_rate_use = sample_rate_use

		self.color_type = color_type
		self.selected_keys = selected

		self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys]

		# Handles used for sample access; filled lazily by `archives`.
		self.__hdfs = None
		self.hdf = None

		try:
			for num_i, file_path in enumerate(self.file_paths):
				self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
				print('read file: ', file_path)
				assert self.hdfs[num_i].swmr_mode
			self.build_idx_to_kv()
		finally:
			# The handles were only needed for counting samples. Release them
			# even when opening a later file or building the index failed.
			self.close()

	def __len__(self):
		"""Return the number of indexed samples across all files."""
		return len(self.idx_to_kv)

	def close(self):
		"""Close every HDF5 handle still held by this instance.

		Covers both the construction-time handles in ``hdfs`` and the lazily
		opened ones behind ``archives``. Safe to call repeatedly; a later
		sample access simply re-opens the files.
		"""
		index_handles = self.hdfs
		if index_handles:
			for slot, handle in list(index_handles.items()):
				if handle:
					handle.close()
				index_handles[slot] = None

		lazy_handles = self.__hdfs
		if lazy_handles is not None:
			for handle in lazy_handles:
				if handle:
					handle.close()
			self.__hdfs = None

	def __del__(self):
		"""Release any open HDF5 handles when the instance is collected."""
		self.close()

	def build_idx_to_kv(self):
		"""Build ``idx_to_kv`` and ``key_idx_dict`` from the files open in ``hdfs``.

		``idx_to_kv`` is a flat list of ``(file_number, row)`` pairs, one per
		indexed sample. ``key_idx_dict`` maps each file key, cut at its first
		``'.'``, to the rows indexed for that file.
		"""
		self.idx_to_kv = []
		self.key_idx_dict = {}
		for num_i, file_key in enumerate(self.selected_keys):
			this_sub = file_key.split('.')[0]
			n = self.hdfs[num_i][self.image_key].shape[0]
			# A rate of 1 or below means "use every sample".
			step = self.sample_rate_use if self.sample_rate_use > 1 else 1
			indices = np.arange(0, n, step)
			self.idx_to_kv += [(num_i, i) for i in indices]
			self.key_idx_dict[this_sub] = list(indices)

	@property
	def archives(self):
		"""Open HDF5 handles, one per entry of ``file_paths``, opened on first use."""
		if self.__hdfs is None: # lazy loading here!
			self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
		return self.__hdfs

	def preprocess_image(self, image):
		"""Reorder channels if needed, resize to ``image_size`` and apply ``transform``.

		Args:
			image: Array read from the HDF5 image dataset.

		Returns:
			Whatever ``self.transform`` returns for the resized ``uint8`` image.
		"""
		image = image.astype(np.float32)
		if self.color_type == 'bgr':
			# Reverse the last axis.
			image = image[..., ::-1]
		image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
		# The transform is fed uint8 data; the cast truncates fractional values.
		image = self.transform(image.astype(np.uint8))
		return image

	def __getitem__(self, index):
		"""Return the sample at flat position ``index``.

		Returns:
			dict with keys ``'image'`` (preprocessed image), ``'gaze'`` and
			``'head'`` (float arrays, or ``[0., 0.]`` when the file lacks the
			corresponding dataset), ``'key'`` (file number) and ``'index'``
			(the ``index`` argument).
		"""
		key, idx = self.idx_to_kv[index]
		self.hdf = self.archives[key]

		assert self.hdf.swmr_mode

		image = self.hdf[self.image_key][idx, :]
		if self.gaze_key in self.hdf:
			gaze_label = self.hdf[self.gaze_key][idx].astype('float')
		else:
			gaze_label = np.array([0,0]).astype('float')
		if 'face_head_pose' in self.hdf:
			head_label = self.hdf['face_head_pose'][idx].astype('float')
		else:
			head_label = np.array([0,0]).astype('float')

		entry = {
			'image': self.preprocess_image(image),
			'gaze': gaze_label,
			'head': head_label,
			'key': key,
			'index':index
		}
		return entry

# class GazeCaptureDatasetSubset(GazeCaptureDataset):
# 	def __init__(self, images_per_person=None, **kwargs):
# 		self.images_per_person = images_per_person
# 		super().__init__(**kwargs)

# 	def build_idx_to_kv(self):
# 		self.idx_to_kv = []
# 		self.key_idx_dict = {}
# 		for num_i in range(0, len(self.selected_keys)):
# 			this_sub = self.selected_keys[num_i].split('.')[0]
# 			n = self.hdfs[num_i][self.image_key].shape[0] 
# 			if self.images_per_person is not None:
# 				n = min(n, self.images_per_person)
# 			self.idx_to_kv += [(num_i, i) for i in range(n)]
# 			self.key_idx_dict[this_sub] = [ i for i in range(n)]