File size: 3,359 Bytes
87dd991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import numpy as np
import h5py
import cv2
from torch.utils.data import Dataset
from typing import List
from omegaconf import OmegaConf, listconfig
from .helper.image_transform import wrap_transforms


class MPIIGazeDataset(Dataset):
	"""Map-style dataset serving image/gaze samples from one or more HDF5 files.

	At construction time every selected file is opened once, only to count
	the rows of ``image_key`` and build a flat ``index -> (file, row)``
	table; the files are closed again right away. The handles used for
	reading samples are opened lazily on first access (see ``archives``).

	Each item is a dict with the keys ``'image'``, ``'gaze'``, ``'head'``,
	``'key'`` (position of the source file in ``selected_keys``) and
	``'index'`` (the flat index that was requested).
	"""

	# Lazily opened read handles. The class-level default keeps ``close`` and
	# ``__del__`` safe on an instance whose ``__init__`` did not finish.
	__hdfs = None

	def __init__(self,
				dataset_path: str,
				color_type,
				keys_to_use: List[str] = None,
				data_name=None,
				image_size:int=224,
				transform_type='basic_imagenet',
				image_key='face_patch',
				gaze_key='face_gaze',
				):
		"""Index the selected HDF5 files and prepare the image transform.

		Args:
			dataset_path: Directory the entries of ``keys_to_use`` are joined to.
			color_type: ``'rgb'`` or ``'bgr'``; with ``'bgr'`` the last image
				axis is reversed before the transform is applied.
			keys_to_use: File names (relative to ``dataset_path``) to read.
				A single string is treated as a one-element list.
			data_name: Stored on the instance as-is; not used by this class.
			image_size: Side length of the square the image is resized to.
			transform_type: Forwarded to ``wrap_transforms``.
			image_key: Name of the HDF5 dataset holding the images.
			gaze_key: Name of the HDF5 dataset holding the gaze labels.

		Raises:
			ValueError: If ``color_type`` is not supported or ``keys_to_use``
				is missing or empty.
		"""
		# Validate first, with real exceptions: ``assert`` is stripped under
		# ``python -O`` and would let a bad color_type through silently.
		if color_type not in ('rgb', 'bgr'):
			raise ValueError(
				"color_type must be 'rgb' or 'bgr', got {!r}".format(color_type))
		if keys_to_use is None:
			raise ValueError("keys_to_use must name at least one HDF5 file")
		if isinstance(keys_to_use, str):
			# Iterating a bare string would yield one "key" per character.
			keys_to_use = [keys_to_use]
		selected_keys = list(keys_to_use)
		if not selected_keys:
			raise ValueError("keys_to_use must name at least one HDF5 file")

		self.dataset_path = dataset_path
		self.hdfs = {}
		self.data_name = data_name
		self.image_key = image_key
		self.gaze_key = gaze_key
		self.image_size = (image_size, image_size)
		self.color_type = color_type
		self.__hdfs = None
		self.hdf = None

		self.transform = wrap_transforms(transform_type, image_size=image_size)

		self.selected_keys = selected_keys
		# the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
		self.file_paths = [os.path.join(self.dataset_path, k) for k in self.selected_keys]

		# Open the files only long enough to build the index; ``finally``
		# guarantees they are closed even when indexing fails part-way.
		try:
			for num_i, file_path in enumerate(self.file_paths):
				self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
				print('read file: ', file_path)
				assert self.hdfs[num_i].swmr_mode
			self.build_idx_to_kv()
		finally:
			self._release_index_handles()

	def __len__(self):
		"""Return the total number of samples across all selected files."""
		return len(self.idx_to_kv)

	def __del__(self):
		"""Best-effort release of every open file handle."""
		# A finalizer must never raise (the object may be half-built or the
		# interpreter may be shutting down), so failures are swallowed here
		# on purpose; call ``close()`` directly to see errors.
		try:
			self.close()
		except Exception:
			pass

	def _release_index_handles(self):
		"""Close the handles opened for indexing and reset their slots to None."""
		hdfs = getattr(self, 'hdfs', None)
		if not hdfs:
			return
		for num_i, handle in list(hdfs.items()):
			if handle:
				handle.close()
			hdfs[num_i] = None

	def close(self):
		"""Close all file handles held by the dataset; safe to call repeatedly.

		The lazily opened handles are dropped, so the next access to
		``archives`` reopens the files.
		"""
		self._release_index_handles()
		handles, self.__hdfs = self.__hdfs, None
		self.hdf = None
		for handle in handles or ():
			if handle:
				handle.close()

	def build_idx_to_kv(self):
		"""Build the flat sample index from the handles currently in ``self.hdfs``.

		Sets ``self.idx_to_kv`` to a list of ``(file_number, row)`` pairs and
		``self.key_idx_dict`` to a mapping from each key's text before the
		first ``'.'`` (e.g. ``p00``) to the list of its row numbers.
		"""
		self.idx_to_kv = []
		self.key_idx_dict = {}
		for num_i, selected_key in enumerate(self.selected_keys):
			p_key = selected_key.split('.')[0]  ##p00
			n = self.hdfs[num_i][self.image_key].shape[0]
			self.idx_to_kv.extend((num_i, i) for i in range(n))
			self.key_idx_dict[p_key] = list(range(n))

	@property
	def archives(self):
		"""Return the list of read handles, opening them on first access."""
		if self.__hdfs is None: # lazy loading here!
			opened = []
			try:
				for h5_path in self.file_paths:
					opened.append(h5py.File(h5_path, "r", swmr=True))
			except Exception:
				# Do not leak the files that did open before the failure.
				for handle in opened:
					handle.close()
				raise
			self.__hdfs = opened
		return self.__hdfs

	def preprocess_image(self, image):
		"""Reorder channels if needed, resize to ``image_size`` and transform.

		Args:
			image: Array-like image; presumably (H, W, C) — TODO confirm
				against the stored patches.

		Returns:
			Whatever ``self.transform`` returns for the uint8 image.
		"""
		image = image.astype(np.float32)
		if self.color_type == 'bgr':
			# Reverse the last axis (presumably BGR -> RGB; verify with data).
			image = image[..., ::-1]
		if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]:
			image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
		image = self.transform(image.astype(np.uint8))
		return image

	def __getitem__(self, index):
		"""Return the sample at flat position ``index`` as a dict.

		``'gaze'`` and ``'head'`` fall back to a zero vector of length 2 when
		``gaze_key`` / ``'face_head_pose'`` is absent from the file.
		"""
		key, idx = self.idx_to_kv[index]
		hdf = self.archives[key]
		# Kept on the instance because the attribute existed before.
		self.hdf = hdf
		assert hdf.swmr_mode
		image = hdf[self.image_key][idx, :]
		if self.gaze_key in hdf:
			gaze_label = hdf[self.gaze_key][idx].astype('float')
		else:
			gaze_label = np.array([0, 0]).astype('float')
		if 'face_head_pose' in hdf:
			head_label = hdf['face_head_pose'][idx].astype('float')
		else:
			head_label = np.array([0, 0]).astype('float')
		entry = {
			'image': self.preprocess_image(image),
			'gaze': gaze_label,
			'head': head_label,
			'key': key,
			'index': index,
		}

		return entry