diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ae351b4fb07331a15de0de190c78d3fcce59ad43 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/De_Nachtwacht.png filter=lfs diff=lfs merge=lfs -text +examples/The_Night_Watch_Frans_Banninck_Cocq.png filter=lfs diff=lfs merge=lfs -text diff --git a/data/face_model.txt b/data/face_model.txt new file mode 100644 index 0000000000000000000000000000000000000000..da16e6547aee4b24db8fcbced2d2b27e64851433 --- /dev/null +++ b/data/face_model.txt @@ -0,0 +1,50 @@ +5.862468481063842773e-01 7.872964477539062500e+01 2.317634201049804688e+01 +-5.711375045776367188e+01 -5.130039978027343750e+01 4.678271484375000000e+01 +-5.021675109863281250e+01 -5.602691268920898438e+01 3.416214370727539062e+01 +-3.879566955566406250e+01 -5.690497207641601562e+01 2.192905616760253906e+01 +-2.962696456909179688e+01 -5.768646621704101562e+01 1.585745716094970703e+01 +-1.556392288208007812e+01 -5.381772232055664062e+01 1.200321197509765625e+01 +1.493891811370849609e+01 -5.252636718750000000e+01 1.241601753234863281e+01 +2.762125968933105469e+01 -5.633798599243164062e+01 1.620070838928222656e+01 +3.687218856811523438e+01 -5.588240051269531250e+01 2.234012985229492188e+01 +4.801872634887695312e+01 -5.413969039916992188e+01 3.287670516967773438e+01 +5.493420410156250000e+01 -4.876091766357421875e+01 4.391139984130859375e+01 +9.755885004997253418e-01 -3.599571609497070312e+01 1.533371734619140625e+01 +1.295488834381103516e+00 -1.837105178833007812e+01 1.295253086090087891e+01 +1.169039964675903320e+00 -5.502729415893554688e+00 6.933759689331054688e+00 +1.324353933334350586e+00 5.223155975341796875e+00 3.281763553619384766e+00 +-1.061166477203369141e+01 1.295834922790527344e+01 2.162276458740234375e+01 +-5.147602558135986328e+00 1.608338356018066406e+01 1.863278388977050781e+01 +7.948544025421142578e-01 1.780137062072753906e+01 1.740065383911132812e+01 +6.404633045196533203e+00 1.649684906005859375e+01 1.887524223327636719e+01 +1.128962993621826172e+01 1.386424446105957031e+01 2.183790016174316406e+01 +-4.650949859619140625e+01 -3.832709503173828125e+01 3.641600418090820312e+01 +-3.662562179565429688e+01 -4.003409194946289062e+01 2.697853851318359375e+01 +-2.613725852966308594e+01 -4.035707473754882812e+01 2.568147850036621094e+01 +-1.776072120666503906e+01 -3.258519744873046875e+01 2.907615661621093750e+01 +-2.857307624816894531e+01 -3.133931159973144531e+01 2.851314163208007812e+01 +-3.531597518920898438e+01 -3.336409759521484375e+01 2.953546142578125000e+01 +1.804391098022460938e+01 -3.095682334899902344e+01 2.906296730041503906e+01 +2.545973777770996094e+01 -3.785017395019531250e+01 2.660374259948730469e+01 +3.494161224365234375e+01 -3.641166687011718750e+01 2.815935897827148438e+01 +4.473758697509765625e+01 -3.410787200927734375e+01 3.673243713378906250e+01 +3.460580825805664062e+01 -2.936051368713378906e+01 3.002419853210449219e+01 +2.828340530395507812e+01 -2.810362434387207031e+01 2.857681274414062500e+01 +-2.000109672546386719e+01 3.587311935424804688e+01 2.467940139770507812e+01 +-1.517112541198730469e+01 3.055978584289550781e+01 2.077887535095214844e+01 +-4.272953987121582031e+00 2.849174499511718750e+01 1.563890647888183594e+01 +9.129478931427001953e-01 2.940682983398437500e+01 1.530903434753417969e+01 +5.915512084960937500e+00 2.886590385437011719e+01 1.577433967590332031e+01 +1.609077262878417969e+01 3.112099075317382812e+01 2.045352745056152344e+01 +2.146691894531250000e+01 3.712250137329101562e+01 2.439267730712890625e+01 +1.636684226989746094e+01 4.110508346557617188e+01 2.019831085205078125e+01 +8.093836784362792969e+00 4.461882400512695312e+01 1.674007606506347656e+01 +6.376140713691711426e-01 4.504141998291015625e+01 1.691001510620117188e+01 +-6.237400531768798828e+00 4.375403594970703125e+01 1.704776954650878906e+01 +-1.439151859283447266e+01 4.025728225708007812e+01 2.136468315124511719e+01 +-9.422926902770996094e+00 3.452179336547851562e+01 2.028252601623535156e+01 +1.115690827369689941e+00 3.555863952636718750e+01 1.827753639221191406e+01 +1.108111095428466797e+01 3.538360214233398438e+01 2.027869033813476562e+01 +1.114828586578369141e+01 3.651076889038085938e+01 2.039755630493164062e+01 +9.804738759994506836e-01 3.681156921386718750e+01 1.785094261169433594e+01 +-9.598259925842285156e+00 3.567073822021484375e+01 2.036244964599609375e+01 \ No newline at end of file diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/datasets/eyediap.py b/datasets/eyediap.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7572dbf07bf2ac4ad7d33e02bc87e6428730d1 --- /dev/null +++ b/datasets/eyediap.py @@ -0,0 +1,103 @@ +import os +import numpy as np +import h5py +import cv2 +from torch.utils.data import Dataset +from typing import List +from omegaconf import OmegaConf, listconfig +from .helper.image_transform import wrap_transforms + + +class EYEDIAPDataset(Dataset): + def __init__(self, + dataset_path: str, + color_type, + keys_to_use: List[str] = None, + data_name=None, + image_size:int=224, ## <--- + transform_type='basic_imagenet', ## <--- modified + image_key='face_patch', + gaze_key='face_gaze', + ): + + self.path = dataset_path + self.hdfs = {} + self.data_name = data_name + self.image_key = image_key + self.gaze_key = gaze_key + + self.image_size = (image_size, image_size) + + assert color_type in ['rgb', 'bgr'] + self.color_type = color_type + self.selected_keys = [k for k in keys_to_use] + assert len(self.selected_keys) > 0 + + self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys] + + for num_i in range(0, len(self.selected_keys)): + file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate + self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True) + print('read file: ', os.path.join(self.path, self.selected_keys[num_i])) + assert self.hdfs[num_i].swmr_mode + + self.build_idx_to_kv() + + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + self.transform = wrap_transforms(transform_type, image_size=image_size) + self.__hdfs = None + self.hdf = None + + def __len__(self): + return len(self.idx_to_kv) + + def __del__(self): + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + def build_idx_to_kv(self): + self.idx_to_kv = [] + self.key_idx_dict = {} + for num_i in range(0, len(self.selected_keys)): + this_sub = self.selected_keys[num_i].split('.')[0] + n = self.hdfs[num_i][self.image_key].shape[0] + self.idx_to_kv += [(num_i, i) for i in range(n)] + self.key_idx_dict[this_sub] = [ i for i in range(n)] + + @property + def archives(self): + if self.__hdfs is None: # lazy loading here! + self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths] + return self.__hdfs + + + def preprocess_image(self, image): + image = image.astype(np.float32) + if self.color_type == 'bgr': + image = image[..., ::-1] + image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA) + image = self.transform(image.astype(np.uint8) ) + return image + + def __getitem__(self, index): + key, idx = self.idx_to_kv[index] + self.hdf = self.archives[key] + assert self.hdf.swmr_mode + + image = self.hdf[self.image_key][idx, :] + gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float') + head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float') + + entry = { + 'image': self.preprocess_image(image), + 'gaze': gaze_label, + 'head': head_label, + 'key': key, + 'index':index + } + return entry diff --git a/datasets/gaze360.py b/datasets/gaze360.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b23ba94b5005442d29a86a3f44a6bd88160c66 --- /dev/null +++ b/datasets/gaze360.py @@ -0,0 +1,106 @@ +import os +import numpy as np +import h5py, cv2 +from torch.utils.data import Dataset +from typing import List +from .helper.image_transform import wrap_transforms + + +class Gaze360Dataset(Dataset): + def __init__(self, + dataset_path: str, + color_type, + keys_to_use: List[str] = None, + data_name=None, + image_size:int=224, + transform_type='basic_imagenet', + image_key='face_patch', + gaze_key='face_gaze', + sample_rate_use=1, + ): + super().__init__() + self.dataset_path = dataset_path + self.hdfs = {} + self.data_name = data_name + self.image_key = image_key + self.gaze_key = gaze_key + self.image_size = (image_size, image_size) + + assert color_type in ['rgb', 'bgr'] + self.color_type = color_type + self.transform = wrap_transforms(transform_type, image_size=image_size) + + self.sample_rate_use = sample_rate_use + #### -------------------------------------------------------- read the h5 files ------------------------------------------------------- + self.selected_keys = [k for k in keys_to_use] + assert len(self.selected_keys) > 0 + self.file_paths = [os.path.join(self.dataset_path, k) for k in self.selected_keys] + for num_i in range(0, len(self.selected_keys)): + file_path = os.path.join(self.dataset_path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate + self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True) + print('read file: ', os.path.join(self.dataset_path, self.selected_keys[num_i])) + assert self.hdfs[num_i].swmr_mode + ####----------------------------------------------------------------------------------------------------------------------------------- + + self.build_idx_to_kv() + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + self.__hdfs = None + self.hdf = None + + def build_idx_to_kv(self): + self.idx_to_kv = [] + self.key_idx_dict = {} + for num_i in range(0, len(self.selected_keys)): + p_key = self.selected_keys[num_i].split('.')[0] ##p00 + n = self.hdfs[num_i][self.image_key].shape[0] + if self.sample_rate_use > 1: + indices = np.arange(0, n, self.sample_rate_use) + else: + indices = np.arange(0, n) + self.idx_to_kv += [(num_i, i) for i in indices] + self.key_idx_dict[p_key] = [i for i in indices] + + + def __len__(self): + return len(self.idx_to_kv) + + def __del__(self): + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + @property + def archives(self): + if self.__hdfs is None: # lazy loading here! + self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths] + return self.__hdfs + + def preprocess_image(self, image): + image = image.astype(np.float32) + if self.color_type == 'bgr': + image = image[..., ::-1] + if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]: + image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA) + image = self.transform(image.astype(np.uint8) ) + return image + + def __getitem__(self, index): + key, idx = self.idx_to_kv[index] + self.hdf = self.archives[key] + image = self.hdf[self.image_key][idx] + gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float') + head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float') + entry = { + 'image': self.preprocess_image(image), + 'gaze': gaze_label, + 'head': head_label, + 'key': idx, + 'index':index + } + return entry + \ No newline at end of file diff --git a/datasets/gazecapture.py b/datasets/gazecapture.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca829f967825546506c3af255a1ef7baace69bb --- /dev/null +++ b/datasets/gazecapture.py @@ -0,0 +1,132 @@ +import os +import numpy as np +import h5py +import cv2 +from torch.utils.data import Dataset +from typing import List +from omegaconf import OmegaConf, listconfig +from .helper.image_transform import wrap_transforms + + +class GazeCaptureDataset(Dataset): + def __init__(self, + dataset_path: str, + color_type, + keys_to_use: List[str] = None, + data_name=None, + image_size:int=224, ## <--- + transform_type='basic_imagenet', ## <--- modified + image_key='face_patch', + gaze_key='face_gaze', + sample_rate_use=1, + ): + + self.transform = wrap_transforms(transform_type, image_size=image_size) + + self.path = dataset_path + self.hdfs = {} + self.data_name = data_name + self.image_key = image_key + self.gaze_key = gaze_key + + self.image_size = (image_size, image_size) + + self.sample_rate_use = sample_rate_use + + assert color_type in ['rgb', 'bgr'] + self.color_type = color_type + self.selected_keys = [ k for k in keys_to_use] + assert len(self.selected_keys) > 0 + + self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys] + for num_i in range(0, len(self.selected_keys)): + file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate + self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True) + print('read file: ', os.path.join(self.path, self.selected_keys[num_i])) + assert self.hdfs[num_i].swmr_mode + + + self.build_idx_to_kv() + + + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + self.__hdfs = None + self.hdf = None + + def __len__(self): + return len(self.idx_to_kv) + + def __del__(self): + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + def build_idx_to_kv(self): + self.idx_to_kv = [] + self.key_idx_dict = {} + for num_i in range(0, len(self.selected_keys)): + this_sub = self.selected_keys[num_i].split('.')[0] + n = self.hdfs[num_i][self.image_key].shape[0] + if self.sample_rate_use > 1: + indices = np.arange(0, n, self.sample_rate_use) + else: + indices = np.arange(0, n) + self.idx_to_kv += [(num_i, i) for i in indices ] + self.key_idx_dict[this_sub] = [ i for i in indices ] + + @property + def archives(self): + if self.__hdfs is None: # lazy loading here! + self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths] + return self.__hdfs + + + def preprocess_image(self, image): + image = image.astype(np.float32) + if self.color_type == 'bgr': + image = image[..., ::-1] + image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA) + image = self.transform(image.astype(np.uint8) ) + return image + + def __getitem__(self, index): + + key, idx = self.idx_to_kv[index] + self.hdf = self.archives[key] + + # self.hdf = h5py.File(os.path.join(self.path, self.selected_keys[key]), 'r', swmr=True) + assert self.hdf.swmr_mode + + image = self.hdf[self.image_key][idx, :] + gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float') + head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float') + + entry = { + 'image': self.preprocess_image(image), + 'gaze': gaze_label, + 'head': head_label, + 'key': key, + 'index':index + } + return entry + +# class GazeCaptureDatasetSubset(GazeCaptureDataset): +# def __init__(self, images_per_person=None, **kwargs): +# self.images_per_person = images_per_person +# super().__init__(**kwargs) + +# def build_idx_to_kv(self): +# self.idx_to_kv = [] +# self.key_idx_dict = {} +# for num_i in range(0, len(self.selected_keys)): +# this_sub = self.selected_keys[num_i].split('.')[0] +# n = self.hdfs[num_i][self.image_key].shape[0] +# if self.images_per_person is not None: +# n = min(n, self.images_per_person) +# self.idx_to_kv += [(num_i, i) for i in range(n)] +# self.key_idx_dict[this_sub] = [ i for i in range(n)] diff --git a/datasets/helper/image_transform.py b/datasets/helper/image_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..81c4d5c501651239a1128cb2833e8423466a1ab5 --- /dev/null +++ b/datasets/helper/image_transform.py @@ -0,0 +1,81 @@ + +import cv2 +from torchvision import transforms +import numpy as np +import torch + +def re_normalize(image_tensor, old='[-1,1]', new='imagenet'): + """ + Re-normalizes an image tensor from one normalization scheme to another. + Args: + image_tensor (torch.Tensor): Image tensor to be re-normalized. + old (str): Old normalization scheme. Options: '[-1,1]', 'imagenet'. + new (str): New normalization scheme. Options: '[-1,1]', 'imagenet'. + Returns: + torch.Tensor: Re-normalized image tensor. + """ + # Old normalization parameters + device = image_tensor.device + if old == '[-1,1]': + old_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) + old_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) + elif old == 'imagenet': + old_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + old_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + elif old == '[0,1]': + old_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device) + old_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device) + else: + print('old normalization not implemented') + raise NotImplementedError + # New normalization parameters + if new == '[-1,1]': + new_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) + new_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) + elif new == 'imagenet': + new_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + new_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + elif new == '[0,1]': + new_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device) + new_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device) + else: + print('new normalization not implemented') + raise NotImplementedError + # Step 1: Denormalize the image tensor using the old mean and std + denormalized_image = image_tensor * old_std + old_mean + # Step 2: Normalize the image tensor using the new mean and std + normalized_image = (denormalized_image - new_mean) / new_std + + return normalized_image + + + + + + +def wrap_transforms(image_transforms_type, image_size): + + + if image_transforms_type == 'basic_imagenet': + MEAN = [0.485, 0.456, 0.406] + STD = [0.229, 0.224, 0.225] + return transforms.Compose([ + transforms.ToPILImage(), + transforms.ToTensor(), + transforms.Normalize(mean=MEAN, std=STD) + ]) + + + else: + raise NotImplementedError + + + +# def enhance_contrast_clahe(image): +# clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) +# lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) +# lab_planes = list( cv2.split(lab) ) +# lab_planes[0] = clahe.apply(lab_planes[0]) +# lab = cv2.merge(lab_planes) +# image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) +# return image diff --git a/datasets/mpiigaze.py b/datasets/mpiigaze.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa042b77ce83a06d4e111cfa2ceed7fa0151ffc --- /dev/null +++ b/datasets/mpiigaze.py @@ -0,0 +1,109 @@ +import os +import numpy as np +import h5py +import cv2 +from torch.utils.data import Dataset +from typing import List +from omegaconf import OmegaConf, listconfig +from .helper.image_transform import wrap_transforms + + +class MPIIGazeDataset(Dataset): + def __init__(self, + dataset_path: str, + color_type, + keys_to_use: List[str] = None, + data_name=None, + image_size:int=224, ## <--- + transform_type='basic_imagenet', ## <--- modified + image_key='face_patch', + gaze_key='face_gaze', + ): + + self.dataset_path = dataset_path + self.hdfs = {} + self.data_name = data_name + self.image_key = image_key + self.gaze_key = gaze_key + self.image_size = (image_size, image_size) + + assert color_type in ['rgb', 'bgr'] + self.color_type = color_type + self.transform = wrap_transforms(transform_type, image_size=image_size) + + + self.selected_keys = [k for k in keys_to_use] + assert len(self.selected_keys) > 0 + + self.file_paths = [os.path.join(self.dataset_path, k) for k in self.selected_keys] + + for num_i in range(0, len(self.selected_keys)): + file_path = os.path.join(self.dataset_path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate + self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True) + print('read file: ', os.path.join(self.dataset_path, self.selected_keys[num_i])) + assert self.hdfs[num_i].swmr_mode + + self.build_idx_to_kv() + + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + + + self.__hdfs = None + self.hdf = None + + def __len__(self): + return len(self.idx_to_kv) + + def __del__(self): + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + def build_idx_to_kv(self): + + self.idx_to_kv = [] + self.key_idx_dict = {} + for num_i in range(0, len(self.selected_keys)): + p_key = self.selected_keys[num_i].split('.')[0] ##p00 + n = self.hdfs[num_i][self.image_key].shape[0] + self.idx_to_kv += [(num_i, i) for i in range(n)] + self.key_idx_dict[p_key] = [i for i in range(n)] + @property + def archives(self): + if self.__hdfs is None: # lazy loading here! + self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths] + return self.__hdfs + + + def preprocess_image(self, image): + image = image.astype(np.float32) + if self.color_type == 'bgr': + image = image[..., ::-1] + if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]: + image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA) + image = self.transform(image.astype(np.uint8) ) + return image + + def __getitem__(self, index): + key, idx = self.idx_to_kv[index] + self.hdf = self.archives[key] + # self.hdf = h5py.File(os.path.join(self.dataset_path, self.selected_keys[key]), 'r', swmr=True) + assert self.hdf.swmr_mode + image = self.hdf[self.image_key][idx, :] + gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float') + head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float') + entry = { + 'image': self.preprocess_image(image), + 'gaze': gaze_label, + 'head': head_label, + 'key': key, + 'index':index + } + + return entry + diff --git a/datasets/xgaze.py b/datasets/xgaze.py new file mode 100644 index 0000000000000000000000000000000000000000..43652c5db7d82acf1e9f474aac4224a75058ca05 --- /dev/null +++ b/datasets/xgaze.py @@ -0,0 +1,137 @@ +import os,random +import numpy as np +import h5py +import cv2 +from typing import List +from torch.utils.data import Dataset +from .helper.image_transform import wrap_transforms + +class XGazeDataset(Dataset): + def __init__(self, + dataset_path: str, + color_type, + images_per_frame, + keys_to_use: List[str] = None, + data_name=None, + image_size:int=224, + transform_type='basic_imagenet', ## <--- modified + image_key='face_patch', + gaze_key='face_gaze', + camera_random=None, + frame_tag=[0,1000], + seed=0, + ): + + self.path = dataset_path + self.hdfs = {} + self.data_name = data_name + self.images_per_frame = images_per_frame + + print('images_per_frame: ', images_per_frame) + self.image_key = image_key + self.gaze_key = gaze_key + self.image_size = (image_size, image_size) + random.seed(seed) + + assert color_type in ['rgb', 'bgr'] + self.color_type = color_type + self.cameras_idx = list(range(self.images_per_frame)) + self.camera_random = camera_random + + #### -------------------------------------------------------- read the h5 files ------------------------------------------------------- + self.selected_keys = [k for k in keys_to_use] + assert len(self.selected_keys) > 0 + self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys] + for num_i in range(0, len(self.selected_keys)): + file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate + self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True) + print('read file: ', os.path.join(self.path, self.selected_keys[num_i])) + assert self.hdfs[num_i].swmr_mode + ####----------------------------------------------------------------------------------------------------------------------------------- + + + self.idx_to_kv = [] + self.key_idx_dict = {} ## this is for reading the second sample from the same person + for num_i in range(0, len(self.selected_keys)): + this_sub = self.selected_keys[num_i].split('.')[0] + n = self.hdfs[num_i][image_key].shape[0] + + if type(frame_tag) == list: + self.start_frame, self.end_frame = frame_tag + elif frame_tag == 'all': + self.start_frame, self.end_frame = 0, 10000 + else: + raise ValueError("frame_tag should be either a list of integers or str 'all' ") + start_idx = min(n, self.start_frame * self.images_per_frame) + end_idx = min(n, self.end_frame * self.images_per_frame) + + if self.camera_random is None: + self.idx_to_kv += [(num_i, i) for i in range(start_idx, end_idx) if (i % self.images_per_frame ) in self.cameras_idx ] + self.key_idx_dict[this_sub] = [ i for i in range(start_idx, end_idx) if (i % self.images_per_frame ) in self.cameras_idx ] + else: + for frame in range(start_idx // self.images_per_frame, end_idx // self.images_per_frame): + frame_start_idx = frame * self.images_per_frame + frame_end_idx = frame_start_idx + self.images_per_frame + + # Randomly select self.images_per_frame camera indices for this frame + random_cameras_idx = random.sample(range(self.images_per_frame), self.camera_random) + self.idx_to_kv += [(num_i, i) for i in range(frame_start_idx, frame_end_idx) if (i % self.images_per_frame) in random_cameras_idx] + self.key_idx_dict.setdefault(this_sub, []).extend( + [i for i in range(frame_start_idx, frame_end_idx) if (i % self.images_per_frame) in random_cameras_idx] + ) + + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + self.transform = wrap_transforms(transform_type, image_size=image_size) + self.__hdfs = None + self.hdf = None + + + def __len__(self): + return len(self.idx_to_kv) + + def __del__(self): + for num_i in range(0, len(self.hdfs)): + if self.hdfs[num_i]: + self.hdfs[num_i].close() + self.hdfs[num_i] = None + + + @property + def archives(self): + if self.__hdfs is None: # lazy loading here! + self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths] + return self.__hdfs + + def preprocess_image(self, image): + image = image.astype(np.float32) + if self.color_type == 'bgr': + image = image[..., ::-1] + if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]: + image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA) + + image = self.transform( image.astype(np.uint8) ) + return image + + def __getitem__(self, index): + key, idx = self.idx_to_kv[index] + self.hdf = self.archives[key] + assert self.hdf.swmr_mode + image = self.hdf[self.image_key][idx, :] + gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float') + head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float') + + entry = { + 'image': self.preprocess_image(image), + 'gaze': gaze_label, + 'head': head_label, + 'key': key, + 'index':index + } + + return entry + + diff --git a/examples/De_Nachtwacht.png b/examples/De_Nachtwacht.png new file mode 100644 index 0000000000000000000000000000000000000000..801c67d881d0c064f790a31cd38f1dadc310d777 --- /dev/null +++ b/examples/De_Nachtwacht.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f95e98d7e7a725599ae1b3f7f86978834aba8773806947a94378902540b07d58 +size 12150021 diff --git a/examples/The_Night_Watch_Frans_Banninck_Cocq.png b/examples/The_Night_Watch_Frans_Banninck_Cocq.png new file mode 100644 index 0000000000000000000000000000000000000000..517f37d7001bafd478a34facef56b2aa8cbaa25a --- /dev/null +++ b/examples/The_Night_Watch_Frans_Banninck_Cocq.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3468d4cf328e965a68e797cad000b7d3007a40fc1a5fb4d9b15620cea184ad7c +size 590858 diff --git a/gazelib/__init__.py b/gazelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/gazelib/__init__.py @@ -0,0 +1 @@ + diff --git a/gazelib/draw/__init__.py b/gazelib/draw/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/gazelib/draw/draw_image.py b/gazelib/draw/draw_image.py new file mode 100644 index 0000000000000000000000000000000000000000..caca00f5b0a86dcc54102750db814814be06a32b --- /dev/null +++ b/gazelib/draw/draw_image.py @@ -0,0 +1,69 @@ +import cv2 + + +import torch +import numpy as np + + +def recover_image( image_tensor, MEAN=[0.5, 0.5, 0.5], STD=[0.5, 0.5, 0.5]): + """ + read a tensor and recover it to image in cv2 format + args: + image_tensor: [C, H, W] or [B, C, H, W] + return: + image_save: [B, H, W, C] + """ + if image_tensor.ndim == 3: + image_tensor = image_tensor.unsqueeze(0) + + x = torch.mul(image_tensor, torch.FloatTensor(STD).view(3,1,1).to(image_tensor.device)) + x = torch.add(x, torch.FloatTensor(MEAN).view(3,1,1).to(image_tensor.device) ) + x = x.data.cpu().numpy() + # [C, H, W] -> [H, W, C] + image_rgb = np.transpose(x, (0, 2, 3, 1)) + # RGB -> BGR + image_bgr = image_rgb[:, :, :, [2,1,0]] + # float -> int + image_save = np.clip(image_bgr*255, 0, 255).astype('uint8') + + return image_save + + +def draw_lm(image, landmarks, color= (0, 0, 255), radius = 20, print_idx=False): + i = 0 + image_out = image.copy() + for x,y in landmarks: + # Radius of circle + # Line thickness of 2 px + thickness = -1 + image_out = cv2.circle(image_out, (int(x), int(y)), radius, color, thickness) + + if print_idx: + image_out = cv2.putText(image_out, + text=str(i), + org=(int(x), int(y)), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=2.0, + color=color, + thickness=2, + lineType=cv2.LINE_4) + + i += 1 + + return image_out + + +def draw_gaze(image_in, pitchyaw, thickness=2, color=(0, 0, 255)): + """Draw gaze angle on given image with a given eye positions.""" + image_out = image_in.copy() + (h, w) = image_in.shape[:2] + length = w / 2.0 + pos = (int(h / 2.0), int(w / 2.0)) + if len(image_out.shape) == 2 or image_out.shape[2] == 1: + image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR) + dx = -length * np.sin(pitchyaw[1]) * np.cos(pitchyaw[0]) + dy = -length * np.sin(pitchyaw[0]) + cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)), + tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color, + thickness, cv2.LINE_AA, tipLength=0.2) + return image_out \ No newline at end of file diff --git a/gazelib/gaze/__init__.py b/gazelib/gaze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4381959735d242fd0ad6a1bd62dd634653af2b10 --- /dev/null +++ b/gazelib/gaze/__init__.py @@ -0,0 +1 @@ +from .gaze_utils import pitchyaw_to_vector, vector_to_pitchyaw, angular_error diff --git a/gazelib/gaze/gaze_utils.py b/gazelib/gaze/gaze_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acc7795a8320a2568549d7346387f6662cc27eea --- /dev/null +++ b/gazelib/gaze/gaze_utils.py @@ -0,0 +1,166 @@ +import os +import numpy as np +import imageio +import cv2 +import h5py +import math +import torch + +def pitchyaw_to_vector(pitchyaws): + r"""Convert given yaw (:math:`\theta`) and pitch (:math:`\phi`) angles to unit gaze vectors. + + Args: + pitchyaws: Input array of yaw and pitch angles, either numpy array or tensor. + + Returns: + Output array of shape (n x 3) with 3D vectors per row, of the same type as the input. + """ + if isinstance(pitchyaws, np.ndarray): + return pitchyaw_to_vector_numpy(pitchyaws) + elif isinstance(pitchyaws, torch.Tensor): + return pitchyaw_to_vector_torch(pitchyaws) + else: + raise ValueError("Unsupported input type. Only numpy arrays and torch tensors are supported.") + +def pitchyaw_to_vector_numpy(pitchyaws): + n = pitchyaws.shape[0] + sin = np.sin(pitchyaws) + cos = np.cos(pitchyaws) + out = np.empty((n, 3)) + out[:, 0] = np.multiply(cos[:, 0], sin[:, 1]) + out[:, 1] = sin[:, 0] + out[:, 2] = np.multiply(cos[:, 0], cos[:, 1]) + return out + +def pitchyaw_to_vector_torch(pitchyaws): + n = pitchyaws.size()[0] + sin = torch.sin(pitchyaws) + cos = torch.cos(pitchyaws) + out = torch.empty((n, 3), device=pitchyaws.device) + out[:, 0] = torch.mul(cos[:, 0], sin[:, 1]) + out[:, 1] = sin[:, 0] + out[:, 2] = torch.mul(cos[:, 0], cos[:, 1]) + return out + +def vector_to_pitchyaw(vectors): + """Convert given gaze vectors to pitch (theta) and yaw (phi) angles. + + Args: + vectors: Input array of gaze vectors, either numpy array or tensor. + + Returns: + Output array of shape (n x 2) with pitch and yaw angles, of the same type as the input. + """ + if isinstance(vectors, np.ndarray): + return vector_to_pitchyaw_numpy(vectors) + elif isinstance(vectors, torch.Tensor): + return vector_to_pitchyaw_torch(vectors) + else: + raise ValueError("Unsupported input type. Only numpy arrays and torch tensors are supported.") + +def vector_to_pitchyaw_numpy(vectors): + n = vectors.shape[0] + vectors = vectors / np.linalg.norm(vectors, axis=1).reshape(n, 1) + out = np.empty((n, 2)) + out[:, 0] = np.arcsin(vectors[:, 1]) # theta + out[:, 1] = np.arctan2(vectors[:, 0], vectors[:, 2]) # phi + return out + +def vector_to_pitchyaw_torch(vectors): + n = vectors.size()[0] + vectors = vectors / torch.norm(vectors, dim=1).reshape(n, 1) + out = torch.empty((n, 2), device=vectors.device) + out[:, 0] = torch.asin(vectors[:, 1]) # theta + out[:, 1] = torch.atan2(vectors[:, 0], vectors[:, 2]) # phi + return out + + +def angular_error(a, b): + """Calculate angular error (via cosine similarity).""" + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return angular_error_numpy(a, b) + elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return angular_error_torch(a, b) + else: + raise ValueError("Input type mismatch. Both inputs should be either numpy arrays or torch tensors.") + +def angular_error_numpy(a, b): + """Calculate angular error for numpy arrays.""" + a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a + b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b + + ab = np.sum(np.multiply(a, b), axis=1) + a_norm = np.linalg.norm(a, axis=1) + b_norm = np.linalg.norm(b, axis=1) + + # Avoid zero-values (to avoid NaNs) + a_norm = np.clip(a_norm, a_min=1e-7, a_max=None) + b_norm = np.clip(b_norm, a_min=1e-7, a_max=None) + + similarity = np.divide(ab, np.multiply(a_norm, b_norm)) + + return np.arccos(similarity) * 180.0 / np.pi + +def angular_error_torch(a, b): + """Calculate angular error for torch tensors.""" + a = pitchyaw_to_vector(a) if a.size()[1] == 2 else a + b = pitchyaw_to_vector(b) if b.size()[1] == 2 else b + + ab = torch.sum(a * b, dim=1) + a_norm = torch.norm(a, dim=1) + b_norm = torch.norm(b, dim=1) + + # Avoid zero-values (to avoid NaNs) + a_norm = torch.clamp(a_norm, min=1e-7) + b_norm = torch.clamp(b_norm, min=1e-7) + + similarity = ab / (a_norm * b_norm) + + return torch.acos(similarity) * 180.0 / np.pi + + + + + + +def cos_similarity(a, b): + """Calculate angular error (via cosine similarity).""" + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return cos_similarity_numpy(a, b) + elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return cos_similarity_torch(a, b) + else: + raise ValueError("Input type mismatch. Both inputs should be either numpy arrays or torch tensors.") + +def cos_similarity_numpy(a, b): + """Calculate angular error for numpy arrays.""" + a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a + b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b + + ab = np.sum(np.multiply(a, b), axis=1) + a_norm = np.linalg.norm(a, axis=1) + b_norm = np.linalg.norm(b, axis=1) + # Avoid zero-values (to avoid NaNs) + a_norm = np.clip(a_norm, a_min=1e-7, a_max=None) + b_norm = np.clip(b_norm, a_min=1e-7, a_max=None) + similarity = np.divide(ab, np.multiply(a_norm, b_norm)) + similarity = np.clip(similarity, min=0., max=1.) + return similarity + +def cos_similarity_torch(a, b): + """Calculate angular error for torch tensors.""" + a = pitchyaw_to_vector(a) if a.size()[1] == 2 else a + b = pitchyaw_to_vector(b) if b.size()[1] == 2 else b + + ab = torch.sum(a * b, dim=1) + a_norm = torch.norm(a, dim=1) + b_norm = torch.norm(b, dim=1) + + # Avoid zero-values (to avoid NaNs) + a_norm = torch.clamp(a_norm, min=1e-7) + b_norm = torch.clamp(b_norm, min=1e-7) + + similarity = ab / (a_norm * b_norm) + similarity = torch.clamp(similarity, min=0., max=1.) + return similarity + diff --git a/gazelib/gaze/normalize.py b/gazelib/gaze/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..9fde1df59c7c0a0fabb06cd198a491fae9162789 --- /dev/null +++ b/gazelib/gaze/normalize.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- +""" +###################################################################################################################################### +This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. To view a copy of this license, +visit http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +Any publications arising from the use of this software, including but +not limited to academic journal and conference publications, technical +reports and manuals, must cite at least one of the following works: + +Revisiting Data Normalization for Appearance-Based Gaze Estimation +Xucong Zhang, Yusuke Sugano, Andreas Bulling +in Proc. International Symposium on Eye Tracking Research and Applications (ETRA), 2018 +###################################################################################################################################### +""" + +import os +import cv2 +import numpy as np +import csv +import argparse +# import dlib +import glob + + + + + +def normalize_woimg(landmarks, focal_norm, distance_norm, roi_size, center, hr, ht, cam, gc=None): + center = center.reshape(3,1) + ## universal function for data normalization + hR = cv2.Rodrigues(hr)[0] # rotation matrix + + ## ---------- normalize image ---------- + distance = np.linalg.norm(center) # actual distance between eye and original camera + + z_scale = distance_norm/distance + cam_norm = np.array([ + [focal_norm, 0, roi_size[0]/2], + [0, focal_norm, roi_size[1]/2], + [0, 0, 1.0], + ]) + S = np.array([ # scaling matrix + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, z_scale], + ]) + + hRx = hR[:,0] + forward = (center/distance).reshape(3) + down = np.cross(forward, hRx) + down /= np.linalg.norm(down) + right = np.cross(down, forward) + right /= np.linalg.norm(right) + R = np.c_[right, down, forward].T # rotation matrix R + + W = np.dot(np.dot(cam_norm, S), np.dot(R, np.linalg.inv(cam))) # transformation matrix + + ## ---------- normalize rotation ---------- + hR_norm = np.dot(R, hR) # rotation matrix in normalized space + # hr_norm = cv2.Rodrigues(hR_norm)[0] # convert rotation matrix to rotation vectors + + ## ---------- normalize gaze vector ---------- + gc_normalized = None + + num_point = landmarks.shape[0] + landmarks_warped = cv2.perspectiveTransform(landmarks.reshape(-1,1,2).astype('float32'), W) + landmarks_warped = landmarks_warped.reshape(num_point, 2) + if gc is not None: + gc_normalized = gc.reshape((3,1)) - center # gaze vector + # For modified data normalization, scaling is not applied to gaze direction (only R applied). + # For original data normalization, here should be: + # "M = np.dot(S,R) + # gc_normalized = np.dot(R, gc_normalized)" + gc_normalized = np.dot(R, gc_normalized) + gc_normalized = gc_normalized/np.linalg.norm(gc_normalized) + + return [None, R, hR_norm, gc_normalized, landmarks_warped, W] + + +def normalize(img, landmarks, focal_norm, distance_norm, roi_size, center, hr, ht, cam, gc=None): + center = center.reshape(3,1) + ## universal function for data normalization + hR = cv2.Rodrigues(hr)[0] # rotation matrix + + ## ---------- normalize image ---------- + distance = np.linalg.norm(center) # actual distance between eye and original camera + + z_scale = distance_norm/distance + cam_norm = np.array([ + [focal_norm, 0, roi_size[0]/2], + [0, focal_norm, roi_size[1]/2], + [0, 0, 1.0], + ]) + S = np.array([ # scaling matrix + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, z_scale], + ]) + + hRx = hR[:,0] + forward = (center/distance).reshape(3) + down = np.cross(forward, hRx) + down /= np.linalg.norm(down) + right = np.cross(down, forward) + right /= np.linalg.norm(right) + R = np.c_[right, down, forward].T # rotation matrix R + W = np.dot(np.dot(cam_norm, S), np.dot(R, np.linalg.inv(cam))) # transformation matrix + + # if img is not None: + # img_warped = cv2.warpPerspective(img, W, roi_size) # image normalization + # else: + # img_warped = None + + img_warped = cv2.warpPerspective(img, W, roi_size) # image normalization + ## ---------- normalize rotation ---------- + hR_norm = np.dot(R, hR) # rotation matrix in normalized space + # hr_norm = cv2.Rodrigues(hR_norm)[0] # convert rotation matrix to rotation vectors + + ## ---------- normalize gaze vector ---------- + gc_normalized = None + num_point = landmarks.shape[0] + landmarks_warped = cv2.perspectiveTransform(landmarks.reshape(-1,1,2).astype('float32'), W) + landmarks_warped = landmarks_warped.reshape(num_point, 2) + if gc is not None: + gc_normalized = gc.reshape((3,1)) - center # gaze vector + # For modified data normalization, scaling is not applied to gaze direction (only R applied). + # For original data normalization, here should be: + # "M = np.dot(S,R) + # gc_normalized = np.dot(R, gc_normalized)" + gc_normalized = np.dot(R, gc_normalized) + gc_normalized = gc_normalized/np.linalg.norm(gc_normalized) + + return [img_warped, R, hR_norm, gc_normalized, landmarks_warped, W] + +def normalize_face(img, face, hr, ht, cam, gc=None): + ## normalized camera parameters + focal_norm = 960 # focal length of normalized camera + distance_norm = 600 # normalized distance between eye and camera + roi_size = (224, 224) # size of cropped eye image + + ## compute estimated 3D positions of the landmarks + ht = ht.reshape((3,1)) + hR = cv2.Rodrigues(hr)[0] # rotation matrix + Fc = np.dot(hR, face) + ht # 3D positions of facial landmarks + # fm = np.mean(Fc, axis=1).reshape((3,1)) # center of facial landmarks + two_eye_center = np.mean(Fc[:, 0:4], axis=1).reshape((3, 1)) + nose_center = np.mean(Fc[:, 4:6], axis=1).reshape((3, 1)) + # get the face center + face_center = np.mean(np.concatenate((two_eye_center, nose_center), axis=1), axis=1).reshape((3, 1)) + # face_center = np.mean(Fc, axis=1).reshape((3,1)) + return normalize(img, focal_norm, distance_norm, roi_size, face_center, hr, ht, cam, gc) + +def normalize_eye(img, face, hr, ht, cam, gc=None): + ## normalized camera parameters + focal_norm = 960 # focal length of normalized camera + distance_norm = 600 # normalized distance between eye and camera + roi_size = (60, 36) # size of cropped eye image + + ## compute estimated 3D positions of the landmarks + ht = ht.reshape((3,1)) + hR = cv2.Rodrigues(hr)[0] # rotation matrix + Fc = np.dot(hR, face) + ht # 3D positions of facial landmarks + re = 0.5*(Fc[:,0] + Fc[:,1]).reshape((3,1)) # center of left eye + le = 0.5*(Fc[:,2] + Fc[:,3]).reshape((3,1)) # center of right eye + + ## normalize each eye + data = [ + normalize(img, focal_norm, distance_norm, roi_size, re, hr, ht, cam, gc), + normalize(img, focal_norm, distance_norm, roi_size, le, hr, ht, cam, gc) + ] + return data + +def load_calibration(calib_path): + ## load calibration data, these paramters are expected to be obtained by camera calibration functions in OpenCV + fs = cv2.FileStorage(calib_path, cv2.FILE_STORAGE_READ) + camera_matrix = fs.getNode('camera_matrix').mat() + camera_distortion = fs.getNode('dist_coeffs').mat() + return camera_matrix, camera_distortion + +def load_facemodel(model_path): + # load the generic face model, which includes 6 facial landmarks: four eye corners and two mouth corners + fs = cv2.FileStorage(model_path, cv2.FILE_STORAGE_READ) + face_model = fs.getNode('face_model').mat() + return face_model + +def read_image(img_path, camera_matrix, camera_distortion): + # load input image and undistort + img_original = cv2.imread(img_path) + img = cv2.undistort(img_original, camera_matrix, camera_distortion) + + return img + +def estimateHeadPose(landmarks, face_model, camera, distortion, iterate=True): + ret, rvec, tvec = cv2.solvePnP(face_model, landmarks, camera, distortion, flags=cv2.SOLVEPNP_EPNP) + + ## further optimize + if iterate: + ret, rvec, tvec = cv2.solvePnP(face_model, landmarks, camera, distortion, rvec, tvec, True) + + return rvec, tvec + +def detect_landmark(img, detector_path, predictor_path): + ## obtain facial landmarks using dlib + detector = dlib.cnn_face_detection_model_v1(detector_path) + dets = detector(img, 0) + + if len(dets) == 0: + return None + + predictor = dlib.shape_predictor(predictor_path) + shape = predictor(img, dets[0].rect) + + ## extract required keypoints + landmarks = np.array([ + [shape.part(36).x, shape.part(36).y], + [shape.part(39).x, shape.part(39).y], + [shape.part(42).x, shape.part(42).y], + [shape.part(45).x, shape.part(45).y], + [shape.part(48).x, shape.part(48).y], + [shape.part(54).x, shape.part(54).y] + ]) + + return landmarks + + +def read_landmark(img_path): + img_file = img_path.split(os.path.sep)[-1] + day = img_path.split(os.path.sep)[-2] + person = img_path.split(os.path.sep)[-3] + person_path = os.path.split(os.path.split(img_path)[0])[0] + + person_txt = os.path.join(person_path, person+'.txt') + index = os.path.join(day,img_file) + print(person_txt) + print(index) + + with open(person_txt) as f: + data = f.readlines() + reader = csv.reader(data) + p = {} + for row in reader: + words = row[0].split() + p[words[0]] = words[1:] + landmarks = np.array([int(i) for i in p[index][2:14]]).reshape((6,2)) + return landmarks + +# def process_image(img_path, detector_path, predictor_path, camera_matrix, camera_distortion, face_model, gc=None): +# # read input image +# img = read_image(img_path, camera_matrix, camera_distortion) + +# # detect facial landmarks +# landmarks = detect_landmark(img, detector_path, predictor_path) + +# if landmarks is not None: +# # estimate head pose +# hr, ht = estimateHeadPose(face_model, landmarks, camera_matrix, camera_distortion) + +# # data normalization for left and right eye image +# normalized_eyes = normalize_eye(img, face_model, hr, ht, camera_matrix, gc) + +# # data normalization for full face +# normalized_face = normalize_face(img, face_model, hr, ht, camera_matrix, gc) + +# # return a list of [reye, leye, face] +# return normalized_eyes + [normalized_face] diff --git a/gazelib/label_transform.py b/gazelib/label_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..25852de9e2bedcd214702c5493fe6d008b5ef5ac --- /dev/null +++ b/gazelib/label_transform.py @@ -0,0 +1,195 @@ + + +import cv2 + +import numpy as np + + +def get_eye_nose_landmarks(landmarks): + assert landmarks.shape[0]==50 or landmarks.shape[0]==68 + if landmarks.shape[0] == 50: + lm_6 = landmarks[[20, 23, 26, 29, 15, 19], :] # the eye and nose landmarks + elif landmarks.shape[0] == 68: + lm_6 = landmarks[[36, 39, 42, 45, 31, 35], :] # the eye and nose landmarks + return lm_6 +def get_eye_mouth_landmarks(landmarks): + assert landmarks.shape[0]==50 or landmarks.shape[0]==68 + if landmarks.shape[0] == 50: + lm_6 = landmarks[[20, 23, 26, 29, 32, 38], :] # the eye and nose landmarks + elif landmarks.shape[0] == 68: + lm_6 = landmarks[[36,39,42,45,48,54], :] # the eye and nose landmarks + return lm_6 + +def mean_eye_nose(landmarks): + assert landmarks.shape[0]==6 + # get the face center + two_eye_center = np.mean(landmarks[0:4, :], axis=0).reshape(1,-1) + nose_center = np.mean(landmarks[4:6, :], axis=0).reshape(1,-1) + face_center = np.mean(np.concatenate((two_eye_center, nose_center), axis=0), axis=0).reshape(1,-1) + return face_center +def mean_eye_mouth(landmarks): + assert landmarks.shape[0]==6 + face_center = np.mean(landmarks, axis=0).reshape(1,-1) + return face_center + +def get_face_center_by_nose(hR, ht, face_model_load): + face_model = get_eye_nose_landmarks(face_model_load) # the eye and nose landmarks + Fc = np.dot(hR, face_model.T) + ht # 3D positions of facial landmarks + face_center = mean_eye_nose(Fc.T).reshape((3, 1)) # get the face center + return face_center, Fc + +def get_face_center_by_mouth(hR, ht, face_model_load): + face_model = get_eye_mouth_landmarks(face_model_load) # the eye and nose landmarks + Fc = np.dot(hR, face_model.T) + ht # 3D positions of facial landmarks + face_center = mean_eye_mouth(Fc.T).reshape((3, 1)) # get the face center + return face_center, Fc + +def lm68_to_50(lm_68): + ''' + lm_68: (68,2) + ''' + lm_50 = np.zeros((50,2)) + lm_50[0] = lm_68[8] + lm_50[1:44] = lm_68[17:60] + lm_50[44:47] = lm_68[61:64] + lm_50[47:50] = lm_68[65:68] + return lm_50 + + +def lm68_subset(lm_68, NUM_KPTS_TO_USE): + ''' + lm_68: (68,2) + ''' + if NUM_KPTS_TO_USE == 6: + lm_68 = np.array(lm_68, dtype=np.float32) + return lm_68[[36, 39, 42, 45, 31, 35], :] + elif NUM_KPTS_TO_USE ==50: + return lm68_to_50(lm_68) + else: + print('not supported yet') + exit(0) + +def lm50_subset(lm_50, NUM_KPTS_TO_USE): + ''' + lm_50: (50,2) + ''' + lm_50 = lm_50.copy() + if NUM_KPTS_TO_USE == 6: + lm_50 = lm_50[[20, 23, 26, 29, 15, 19], :] + return lm_50 + elif NUM_KPTS_TO_USE ==50: + return lm_50 + else: + print('not supported yet') + exit(0) + +def get_face_center(landmarks_3d): + ''' + landmarks_3d: (3, 6) + --> + face_center: (3,1) + ''' + two_eye_center = np.mean(landmarks_3d[:, 0:4], axis=1).reshape((3, 1)) + nose_center = np.mean(landmarks_3d[:, 4:6], axis=1).reshape((3, 1)) + face_center = np.mean(np.concatenate((two_eye_center, nose_center), axis=1), axis=1).reshape((3, 1)) + return face_center + + + +def compute_R(lm6, dataname): + ''' + 6 landmarks in opencv coordinate + dataname: mpii or xgaze + the face center are computed differently + for mpii: the 6 landmarks are 4 eye + 2 mouth + for xgaze: the 6 landmarks are 4 eye + 2 nose + ''' + if dataname=='mpii': + left_center = np.mean(lm6[2:4,:],axis=0) + right_center = np.mean(lm6[:2,:],axis=0) + face_center = np.mean(lm6,axis=0) + elif dataname=='xgaze': + left_center = np.mean(lm6[2:4,:],axis=0) + right_center = np.mean(lm6[:2,:],axis=0) + nose_center = np.mean(lm6[[4,5],:],axis=0) + face_center = ( (left_center + right_center)/2 + nose_center ) /2 + + distance = np.linalg.norm(face_center) + + hRx = left_center - right_center + hRx /= np.linalg.norm(hRx) + forward = (face_center/distance).reshape(3) + down = np.cross(forward, hRx) + down /= np.linalg.norm(down) + right = np.cross(down, forward) + right /= np.linalg.norm(right) + R = np.c_[right, down, forward].T + return R + +def rotation_matrix(x, y, z): + ''' + x, y, z: roll, pitch, yaw, (radians) + ''' + Rx = np.array([[1,0,0], + [0, np.cos(x), -np.sin(x)], + [0, np.sin(x), np.cos(x)]]) + + Ry = np.array([[ np.cos(y), 0, np.sin(y)], + [ 0, 1, 0], + [-np.sin(y), 0, np.cos(y)]]) + + Rz = np.array([[np.cos(z), -np.sin(z), 0], + [np.sin(z), np.cos(z), 0], + [0,0,1]]) + return Rz@Ry@Rx +def get_rotation(from_pose, target_pose): + + rotation1 = rotation_matrix( -from_pose[0], from_pose[1], 0) + rotation2 = rotation_matrix(-target_pose[0], target_pose[1], 0) + rotation = rotation2@np.linalg.inv(rotation1) + return rotation + +def hR_2_hr(hR): + hr = np.array([np.arcsin(hR[1, 2]), + np.arctan2(hR[0, 2], hR[2, 2])]) + return hr + +def hr_2_hR(hr): + hR = rotation_matrix( -hr[0], hr[1], 0) + return hR + + + +if __name__ == '__main__': + # hr_norm = np.array([0.15, 0.2]) + # pose = np.array([-0.1, 0.3]) + # rotation1 = rotation_matrix( -hr_norm[0], hr_norm[1], 0) + # rot = cv2.Rodrigues( np.array([hr_norm[0], hr_norm[1], 0]) )[0] + def to_hR(hr_norm): + hR_norm = rotation_matrix( -hr_norm[0], hr_norm[1], 0) + return hR_norm + + hr1 = np.array([0.15, 0.2]) + hr2 = np.array([0.10, 0.2]) + + hr_t = np.array([-0.1, 0.3]) + + hR1 = to_hR(hr1) + hR2 = to_hR(hr2) + print('hR1: ', hR1) + print('hR2: ', hR2) + + R1t = get_rotation(hr1, hr_t) + + hR1_ = np.dot(R1t, hR1) + + print('rotated hR_: ', hR1_) + hr1_ = np.array([np.arcsin(hR1_[1, 2]), + np.arctan2(hR1_[0, 2], hR1_[2, 2])]) + print('rotated hr1_: ', hr1_) + print('hR t: ', to_hR(hr_t)) + hR2_ = np.dot(R1t, hR2) + print('rotated hR2_: ', hR2_) + # rotation2 = rotation_matrix( -pose[0], pose[1], 0) + + diff --git a/gazelib/utils/__init__.py b/gazelib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dff6a55c3cf771d39435af81827b78fbe2bd642 --- /dev/null +++ b/gazelib/utils/__init__.py @@ -0,0 +1,4 @@ + + +from .h5_utils import add, to_h5 + diff --git a/gazelib/utils/color_text.py b/gazelib/utils/color_text.py new file mode 100644 index 0000000000000000000000000000000000000000..953d61c3b309ac40798559e946af6c4ba127105c --- /dev/null +++ b/gazelib/utils/color_text.py @@ -0,0 +1,85 @@ +class ColorText: + """A simple text processor for printing colored text to the terminal.""" + + colors = { + 'black': '\033[30m', + 'red': '\033[31m', + 'green': '\033[32m', + 'yellow': '\033[33m', + 'blue': '\033[34m', + 'magenta': '\033[35m', + 'cyan': '\033[36m', + 'white': '\033[37m', + 'reset': '\033[0m' + } + + @classmethod + def colorize(cls, text, color): + """Colorize the given text using the specified color.""" + return f'{cls.colors[color]}{text}{cls.colors["reset"]}' + + @classmethod + def black(cls, text): + """Colorize the given text with black.""" + return cls.colorize(text, 'black') + + @classmethod + def red(cls, text): + """Colorize the given text with red.""" + return cls.colorize(text, 'red') + + @classmethod + def green(cls, text): + """Colorize the given text with green.""" + return cls.colorize(text, 'green') + + @classmethod + def yellow(cls, text): + """Colorize the given text with yellow.""" + return cls.colorize(text, 'yellow') + + @classmethod + def blue(cls, text): + """Colorize the given text with blue.""" + return cls.colorize(text, 'blue') + + @classmethod + def magenta(cls, text): + """Colorize the given text with magenta.""" + return cls.colorize(text, 'magenta') + + @classmethod + def cyan(cls, text): + """Colorize the given text with cyan.""" + return cls.colorize(text, 'cyan') + + @classmethod + def white(cls, text): + """Colorize the given text with white.""" + return cls.colorize(text, 'white') + +def print_green(*args, **kwargs): + out = ' '.join([str(arg) for arg in args]) + print(ColorText.green(out)) +def print_yellow(*args, **kwargs): + out = ' '.join([str(arg) for arg in args]) + print(ColorText.yellow(out)) +def print_magenta(*args, **kwargs): + out = ' '.join([str(arg) for arg in args]) + print(ColorText.magenta(out)) +def print_cyan(*args, **kwargs): + out = ' '.join([str(arg) for arg in args]) + print(ColorText.cyan(out)) +def print_red(*args, **kwargs): + out = ' '.join([str(arg) for arg in args]) + print(ColorText.red(out)) + +if __name__ == '__main__': + print(ColorText.red('red')) + print(ColorText.green('green')) + print(ColorText.yellow('yellow')) + print(ColorText.blue('blue')) + print(ColorText.magenta('magenta')) + print(ColorText.cyan('cyan')) + print(ColorText.white('white')) + print(ColorText.black('black')) \ No newline at end of file diff --git a/gazelib/utils/h5_utils.py b/gazelib/utils/h5_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d23fc31fc991e391f670880198eb38468812bf23 --- /dev/null +++ b/gazelib/utils/h5_utils.py @@ -0,0 +1,53 @@ +import os +import numpy as np +import imageio +import cv2 +import h5py +import math + + +def add(to_write, key, value): # noqa + if key not in to_write: + to_write[key] = [value] + else: + to_write[key].append(value) + +def to_h5(to_write, output_path): + for key, values in to_write.items(): + to_write[key] = np.asarray(values) + # print('%s: ' % key, to_write[key].shape) + + if not os.path.isfile(output_path): + with h5py.File(output_path, 'w') as f: + for key, values in to_write.items(): + print("values.shape: ", values.shape) + f.create_dataset( + key, data=values, + chunks=( + tuple([1] + list(values.shape[1:])) + if isinstance(values, np.ndarray) + else None + ), + compression='lzf', + maxshape=tuple([None] + list(values.shape[1:])), + ) + print("chunks: ", f[key].chunks) + else: + with h5py.File(output_path, 'a') as f: + for key, values in to_write.items(): + if key not in list(f.keys()): + print('write it to f {}'.format(output_path)) + f.create_dataset( + key, data=values, + chunks=( + tuple([1] + list(values.shape[1:])) + if isinstance(values, np.ndarray) + else None + ), + compression='lzf', + maxshape=tuple([None] + list(values.shape[1:])), + ) + else: + data = f[key] + data.resize(data.shape[0] + values.shape[0], axis=0) + data[-values.shape[0]:] = values \ No newline at end of file diff --git a/models/hybrid_tr.py b/models/hybrid_tr.py new file mode 100644 index 0000000000000000000000000000000000000000..64e2aa2c54fff5a73d4f291f09c6a8bf5dffc0cf --- /dev/null +++ b/models/hybrid_tr.py @@ -0,0 +1,570 @@ +import os +import sys +import torch +import torch.nn as nn +import torchvision.models as models +import numpy as np +import math +import copy +# from modules.resnet_v1 import resnet50 +import torch.utils.model_zoo as model_zoo + +from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResFeature(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict, strict=False) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" '_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" '_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" '_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" '_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" '_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class ResFeature(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResFeature, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + +class ResGazeEs(nn.Module): + def __init__(self, ): + super(ResGazeEs, self).__init__() + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(2048, 2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + +class CNN_Model(nn.Module): + def __init__(self): + super(CNN_Model, self).__init__() + self.feature = resnet50(pretrained=True) + # self.feature.load_state_dict(torch.load(pretrained_url), strict=False ) + self.gazeEs = ResGazeEs() + # self.gazeEs.load_state_dict(torch.load(pretrained_url), strict=False ) + + def forward(self, x_in): + features = self.feature(x_in) + gaze = self.gazeEs(features) + + return gaze, features + + + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, pos): + output = src + for layer in self.layers: + output = layer(output, pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = nn.ReLU(inplace=True) + + def pos_embed(self, src, pos): + batch_pos = pos.unsqueeze(1).repeat(1, src.size(1), 1) + return src + batch_pos + + + def forward(self, src, pos): + # src_mask: Optional[Tensor] = None, + # src_key_padding_mask: Optional[Tensor] = None): + # pos: Optional[Tensor] = None): + + q = k = self.pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + +class FeatureTransformer(nn.Module): + ''' + This is the end head which is included in the resnet18 (in official code) + To avoid ambiguity, extract this part out of resnet18 + ''' + def __init__(self, in_channels=512, maps=32): + super(FeatureTransformer, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, maps, 1), + nn.BatchNorm2d(maps), + nn.ReLU(inplace=True) + ) + def forward(self, x): + x = self.conv(x) + return x + +class HybridTR18(nn.Module): + def __init__(self): + super().__init__() + maps = 32 + nhead = 8 + dim_feature = 7*7 + dim_feedforward=512 + dropout = 0.1 + num_layers=6 + + self.base_model = resnet18(pretrained=True) #False, maps=maps) + self.base_model_head = FeatureTransformer(in_channels=dim_feedforward, maps=maps) + + # d_model: dim of Q, K, V + # nhead: seq num + # dim_feedforward: dim of hidden linear layers + # dropout: prob + + encoder_layer = TransformerEncoderLayer( + maps, + nhead, + dim_feedforward, + dropout) + + encoder_norm = nn.LayerNorm(maps) + # num_encoder_layer: deeps of layers + self.encoder = TransformerEncoder(encoder_layer, num_layers, encoder_norm) + self.cls_token = nn.Parameter(torch.randn(1, 1, maps)) + self.pos_embedding = nn.Embedding(dim_feature+1, maps) + self.feed = nn.Linear(maps, 2) + + + def forward(self, x_in, normalize_z=False): + output_dict = {} + # feature = self.base_model(x_in["face"]) + feature = self.base_model(x_in) + feature = self.base_model_head(feature) + batch_size = feature.size(0) + feature = feature.flatten(2) + feature = feature.permute(2, 0, 1) + cls = self.cls_token.repeat( (1, batch_size, 1)) + feature = torch.cat([cls, feature], 0) + position = torch.from_numpy(np.arange(0, 50)).cuda() + + pos_feature = self.pos_embedding(position) + # feature is [HW, batch, channel] + feature = self.encoder(feature, pos_feature) + + feature = feature.permute(1, 2, 0) + feature = feature[:,:,0] + pred_gaze = self.feed(feature) + output_dict['pred_gaze'] = pred_gaze + return output_dict + + + +class HybridTR50(nn.Module): + def __init__(self): + super().__init__() + maps = 32 + nhead = 8 + dim_feature = 7*7 + dim_feedforward=2048 + dropout = 0.1 + num_layers=6 + + self.base_model = resnet50(pretrained=True) #False, maps=maps) + self.base_model_head = FeatureTransformer(in_channels=dim_feedforward,maps=maps) + + # d_model: dim of Q, K, V + # nhead: seq num + # dim_feedforward: dim of hidden linear layers + # dropout: prob + + encoder_layer = TransformerEncoderLayer( + maps, + nhead, + dim_feedforward, + dropout) + + encoder_norm = nn.LayerNorm(maps) + # num_encoder_layer: deeps of layers + + self.encoder = TransformerEncoder(encoder_layer, num_layers, encoder_norm) + self.cls_token = nn.Parameter(torch.randn(1, 1, maps)) + self.pos_embedding = nn.Embedding(dim_feature+1, maps) + + self.feed = nn.Linear(maps, 2) + + + + + def forward(self, x_in, normalize_z=False): + output_dict = {} + feature = self.base_model(x_in) ##(batch, 2048, 7, 7) + feature = self.base_model_head(feature) ## (batch, 32, 7, 7) + batch_size = feature.size(0) ## batch size + feature = feature.flatten(2) ## (batch, 32, 49) + feature = feature.permute(2, 0, 1) ## (49, batch, 32) + cls = self.cls_token.repeat( (1, batch_size, 1)) ## (1, batch, 32) + feature = torch.cat([cls, feature], 0) ## (50, batch, 32) + position = torch.from_numpy(np.arange(0, 50)).cuda() ## (50,) + pos_feature = self.pos_embedding(position) ## (50, 32) + # feature is [HW, batch, channel] + feature = self.encoder(feature, pos_feature) ## (50, batch, 32) + feature = feature.permute(1, 2, 0) ## (batch, 32, 50) + feature = feature[:,:,0] ## (batch, 32) + pred_gaze = self.feed(feature) ## (batch, 2) + output_dict['pred_gaze'] = pred_gaze + return output_dict + + + diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1567672d166c9fc2485df51195ba83e05247d62d --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,366 @@ +import torch +import torch.nn as nn +from torch.utils.model_zoo import load_url as load_state_dict_from_url +import torch.nn.functional as F + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class DeconvBasicBlock(nn.Module): + def __init__(self, in_planes, stride=1, norm_layer=None): + super(DeconvBasicBlock, self).__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + planes = int(in_planes/stride) + + self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = norm_layer(in_planes) + + self.bn1 = norm_layer(planes) + + if stride == 1: + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = norm_layer(planes) + self.shortcut = nn.Sequential() + else: + self.conv1 = nn.ConvTranspose2d(in_planes, planes, kernel_size=3, stride=stride, bias=False, padding=1, output_padding=1) + self.bn1 = norm_layer(planes) + self.shortcut = nn.Sequential( + nn.ConvTranspose2d(in_planes, planes, kernel_size=3, stride=stride, bias=False, padding=1, output_padding=1), + norm_layer(planes) + ) + + def forward(self, x): + out = torch.relu(self.bn2(self.conv2(x))) + out = self.bn1(self.conv1(out)) + out += self.shortcut(x) + out = torch.relu(out) + return out + +class DeconvBottleneck(nn.Module): + def __init__(self, in_channels, out_channels, expansion=2, stride=1, upsample=None, norm_layer=None): + super(DeconvBottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.expansion = expansion + self.conv1 = nn.Conv2d(in_channels, out_channels, + kernel_size=1, bias=False) + self.bn1 = norm_layer(out_channels) + if stride == 1: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, + stride=stride, bias=False, padding=1) + else: + self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, + kernel_size=3, + stride=stride, bias=False, + padding=1, + output_padding=1) + self.bn2 = norm_layer(out_channels) + self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, + kernel_size=1, bias=False) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU() + self.upsample = upsample + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.upsample is not None: + shortcut = self.upsample(x) + + out += shortcut + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + + model.load_state_dict(state_dict) + return model + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" '_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" '_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" '_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" '_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +class ResNetGaze(nn.Module): + def __init__(self): + raise NotImplementedError + + def forward(self, x_in): + output_dict = {} + features = self.feature(x_in) + z = self.avgpool(features) + z = z.view(z.size(0), -1) ## (batch, dim) + pred_gaze = self.fc(z) + output_dict['pred_gaze'] = pred_gaze + return output_dict + +class Res18(ResNetGaze, nn.Module): + def __init__(self, resnet_pretrained=True): + nn.Module.__init__(self) + self.feature = resnet18(pretrained=resnet_pretrained) + self.avgpool = nn.AdaptiveAvgPool2d((1,1)) + + self.fc = nn.Linear(512, 2) + + +class Res50(ResNetGaze, nn.Module): + def __init__(self, resnet_pretrained=True): + nn.Module.__init__(self) + self.feature = resnet50(pretrained=resnet_pretrained) + self.avgpool = nn.AdaptiveAvgPool2d((1,1)) + self.fc = nn.Linear(2048, 2) + + +class Res152(ResNetGaze, nn.Module): + def __init__(self, resnet_pretrained=True): + nn.Module.__init__(self) + self.feature = resnet152(pretrained=resnet_pretrained) + self.avgpool = nn.AdaptiveAvgPool2d((1,1)) + self.fc = nn.Linear(2048, 2) \ No newline at end of file diff --git a/models/vit/mae.py b/models/vit/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a95ef728d59c9681c6004728f18ab11d51e587 --- /dev/null +++ b/models/vit/mae.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import PatchEmbed, Block + +# from util.pos_embed import get_2d_sincos_pos_embed + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +class MaskedAutoencoderViT(nn.Module): + """ Masked Autoencoder with VisionTransformer backbone + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + # Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + # Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) + for i in range(decoder_depth)]) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=.02) + torch.nn.init.normal_(self.mask_token, std=.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.75): + latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks + + + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +import timm.models.vision_transformer + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """ Vision Transformer with support for global average pooling + """ + def __init__(self, global_pool=False, **kwargs): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs['norm_layer'] + embed_dim = kwargs['embed_dim'] + self.fc_norm = norm_layer(embed_dim) + + del self.norm # remove the original norm + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model \ No newline at end of file diff --git a/models/vit/mae_gaze.py b/models/vit/mae_gaze.py new file mode 100644 index 0000000000000000000000000000000000000000..7251d98244a7d6c366bd18630e128d86c0a85db0 --- /dev/null +++ b/models/vit/mae_gaze.py @@ -0,0 +1,69 @@ + +from os import replace +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim.lr_scheduler import StepLR +import torch.utils.model_zoo as model_zoo +from torch.utils.model_zoo import load_url as load_state_dict_from_url + +from functools import partial +from torchvision.models import vit_b_16, vit_b_32, vit_l_16, vit_l_32 + +from models.vit.mae import interpolate_pos_embed, MaskedAutoencoderViT, vit_base_patch16, vit_large_patch16, vit_huge_patch14 + + + + +class MAE_Gaze(nn.Module): + + def __init__(self, model_type='vit_b_16', global_pool=False, drop_path_rate=0.1, + custom_pretrained_path=None): + + super().__init__() + if model_type == "vit_b_16": + self.vit = vit_base_patch16( global_pool=global_pool, drop_path_rate=drop_path_rate) + elif model_type == "vit_l_16": + self.vit = vit_large_patch16( global_pool=global_pool, drop_path_rate=drop_path_rate) + elif model_type == "vit_h_14": + self.vit = vit_huge_patch14( global_pool=global_pool, drop_path_rate=drop_path_rate) + else: + raise ValueError('model_type not supported') + + if custom_pretrained_path is not None: + checkpoint_model = torch.load(custom_pretrained_path, map_location='cpu')['model'] + state_dict = self.vit.state_dict() + for k in ['head.weight', 'head.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + + + # interpolate position embedding + interpolate_pos_embed(self.vit, checkpoint_model) + + # keys_in_ckpt = checkpoint_model.keys() + # print('Keys in ckpt: ', keys_in_ckpt) + self.vit.load_state_dict( checkpoint_model, strict=False) + print('Loaded custom pretrained weights from {}'.format(custom_pretrained_path)) + + # del self.decoder_embed + # del self.mask_token + # del self.decoder_pos_embed + # del self.decoder_blocks + # del self.decoder_norm + # del self.decoder_pred + + embed_dim = self.vit.embed_dim + self.gaze_fc = nn.Linear(embed_dim, 2) + + + def forward(self, input): + features = self.vit.forward_features(input) + + pred_gaze = self.gaze_fc(features) + output_dict = {} + output_dict['pred_gaze'] = pred_gaze + return output_dict + \ No newline at end of file diff --git a/models/vit/vit_gaze.py b/models/vit/vit_gaze.py new file mode 100644 index 0000000000000000000000000000000000000000..2678ae368552392c4211f1c642d3a6e4938b98ac --- /dev/null +++ b/models/vit/vit_gaze.py @@ -0,0 +1,103 @@ + +from os import replace +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim.lr_scheduler import StepLR +import torch.utils.model_zoo as model_zoo +from torch.utils.model_zoo import load_url as load_state_dict_from_url +from functools import partial +from torchvision.models import vit_b_16, vit_b_32, vit_l_16, vit_l_32 + + + +class ViTGaze(nn.Module): + def __init__(self, + vit_type="b_16", + pretrained=True, + custom_pretrained_path=None, + **kwargs + ): + super().__init__() + if vit_type == "b_16": + """ + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + """ + self.vit = vit_b_16(pretrained=pretrained ) + + + self.vit.heads = nn.Sequential( + nn.Linear(768,2) + ) + elif vit_type == "b_32": + self.vit = vit_b_32(pretrained=pretrained) + self.vit.heads = nn.Sequential( + nn.Linear(768,2) + ) + elif vit_type == "l_16": + self.vit = vit_l_16(pretrained=pretrained) + self.vit.heads = nn.Sequential( + nn.Linear(1024,2) + ) + elif vit_type == "l_32": + self.vit = vit_l_32(pretrained=pretrained) + self.vit.heads = nn.Sequential( + nn.Linear(1024,2) + ) + if custom_pretrained_path is not None: + ckpt = torch.load(custom_pretrained_path) + print('Loading custom pretrained weights from: ', custom_pretrained_path) + # self.vit.load_state_dict( ckpt['model'], strict=True) + keys_in_ckpt = ckpt.keys() + print('Keys in ckpt: ', keys_in_ckpt) + self.vit.load_state_dict( ckpt, strict=True) + + def forward(self, x_in): + out_dict = {} + gaze = self.vit(x_in) + out_dict['pred_gaze'] = gaze + return out_dict + + + +from models.vit.mae import interpolate_pos_embed, vit_huge_patch14 +class CustomViT_H14(nn.Module): + def __init__(self, global_pool=False, drop_path_rate=0.1, + custom_pretrained_path=None): + super().__init__() + self.vit = vit_huge_patch14( global_pool=global_pool, drop_path_rate=drop_path_rate) + + if custom_pretrained_path is not None: + checkpoint_model = torch.load(custom_pretrained_path, map_location='cpu') + + state_dict = self.vit.state_dict() + for k in ['head.weight', 'head.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # interpolate position embedding + interpolate_pos_embed(self.vit, checkpoint_model) + + self.vit.load_state_dict( checkpoint_model, strict=False ) + print('Loaded custom pretrained weights from {}'.format(custom_pretrained_path)) + + embed_dim = self.vit.embed_dim + self.gaze_fc = nn.Linear(embed_dim, 2) + + + def forward(self, input): + features = self.vit.forward_features(input) + + pred_gaze = self.gaze_fc(features) + output_dict = {} + output_dict['pred_gaze'] = pred_gaze + return output_dict + + + + \ No newline at end of file diff --git a/unigaze/__init__.py b/unigaze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/unigaze/configs/config.yaml b/unigaze/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc61e28d7401fcba19402b7c1a3fb7c795ccd112 --- /dev/null +++ b/unigaze/configs/config.yaml @@ -0,0 +1,38 @@ +defaults: + - _self_ + - exp: exp_224 + + + +mode: train +random_seed: 42 +num_workers: 20 + + +test_per_epoch: 1 +print_freq: 100 +data_sanity_check: false +log_wandb: false +output_dir: "./logs" +ckpt_resume: null +pretrain_ckptpath: null + + +optimizer_cfg: null +scheduler_cfg: null + +batch_size: 50 +test_batch_size: 200 + +epochs: 25 +valid_epoch: 1 +eval_epoch: 10 +save_epoch: 10 + + +use_autocast: False + +batchnorm_type: + label: clean + unlabel: aug + test: clean \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_cs.yaml b/unigaze/configs/data/eyediap_cs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71e4681edb267d45f367bd9ca4c71703acda57a2 --- /dev/null +++ b/unigaze/configs/data/eyediap_cs.yaml @@ -0,0 +1,22 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_cs + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + - 'person_1.h5' + - 'person_2.h5' + - 'person_3.h5' + - 'person_4.h5' + - 'person_5.h5' + - 'person_6.h5' + - 'person_7.h5' + - 'person_8.h5' + - 'person_9.h5' + - 'person_10.h5' + - 'person_11.h5' + - 'person_14.h5' + - 'person_15.h5' + - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_cs_test.yaml b/unigaze/configs/data/eyediap_cs_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cfb7a081b612e4e4466d389cbe5ddf3cc9174e6c --- /dev/null +++ b/unigaze/configs/data/eyediap_cs_test.yaml @@ -0,0 +1,22 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_cs + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + # - 'person_1.h5' + # - 'person_2.h5' + # - 'person_3.h5' + # - 'person_4.h5' + # - 'person_5.h5' + # - 'person_6.h5' + # - 'person_7.h5' + # - 'person_8.h5' + - 'person_9.h5' + - 'person_10.h5' + - 'person_11.h5' + - 'person_14.h5' + - 'person_15.h5' + - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_cs_train.yaml b/unigaze/configs/data/eyediap_cs_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4462d18e9cd1fb80f92a027c26c4f86c86f61186 --- /dev/null +++ b/unigaze/configs/data/eyediap_cs_train.yaml @@ -0,0 +1,22 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_cs + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + - 'person_1.h5' + - 'person_2.h5' + - 'person_3.h5' + - 'person_4.h5' + - 'person_5.h5' + - 'person_6.h5' + - 'person_7.h5' + - 'person_8.h5' + # - 'person_9.h5' + # - 'person_10.h5' + # - 'person_11.h5' + # - 'person_14.h5' + # - 'person_15.h5' + # - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_ft.yaml b/unigaze/configs/data/eyediap_ft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a703334fab43c151486f9fc245ffd38d68f3f72 --- /dev/null +++ b/unigaze/configs/data/eyediap_ft.yaml @@ -0,0 +1,24 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_ft + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + - 'person_1.h5' + - 'person_2.h5' + - 'person_3.h5' + - 'person_4.h5' + - 'person_5.h5' + - 'person_6.h5' + - 'person_7.h5' + - 'person_8.h5' + - 'person_9.h5' + - 'person_10.h5' + - 'person_11.h5' + - 'person_12.h5' + - 'person_13.h5' + - 'person_14.h5' + - 'person_15.h5' + - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_ft_test.yaml b/unigaze/configs/data/eyediap_ft_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78742f31938db777bb9a91ca5e3a39fd62ab4a2b --- /dev/null +++ b/unigaze/configs/data/eyediap_ft_test.yaml @@ -0,0 +1,24 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_ft + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + # - 'person_1.h5' + # - 'person_2.h5' + # - 'person_3.h5' + # - 'person_4.h5' + # - 'person_5.h5' + # - 'person_6.h5' + # - 'person_7.h5' + # - 'person_8.h5' + - 'person_9.h5' + - 'person_10.h5' + - 'person_11.h5' + - 'person_12.h5' + - 'person_13.h5' + - 'person_14.h5' + - 'person_15.h5' + - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/eyediap_ft_train.yaml b/unigaze/configs/data/eyediap_ft_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..166676a23b2bf6e50bbd21caf82840c0d6d37587 --- /dev/null +++ b/unigaze/configs/data/eyediap_ft_train.yaml @@ -0,0 +1,24 @@ +type: datasets.eyediap.EYEDIAPDataset +params: + data_name: eyediap_ft + color_type: bgr + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: + - 'person_1.h5' + - 'person_2.h5' + - 'person_3.h5' + - 'person_4.h5' + - 'person_5.h5' + - 'person_6.h5' + - 'person_7.h5' + - 'person_8.h5' + # - 'person_9.h5' + # - 'person_10.h5' + # - 'person_11.h5' + # - 'person_12.h5' + # - 'person_13.h5' + # - 'person_14.h5' + # - 'person_15.h5' + # - 'person_16.h5' \ No newline at end of file diff --git a/unigaze/configs/data/gaze360_test.yaml b/unigaze/configs/data/gaze360_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e82d8a53edc6fc65971d5de6b31d849d23376c2 --- /dev/null +++ b/unigaze/configs/data/gaze360_test.yaml @@ -0,0 +1,74 @@ + + + +type: datasets.gaze360.Gaze360Dataset +params: + data_name: gaze360_224_test + saved_norm_config: + focal_norm: 960 + distance_norm: 600 + roi_size: [224, 224] + norm_config: + focal_norm: 960 + distance_norm: 600 + roi_size: [224, 224] + + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + image_size: 224 + sample_rate_use: 1 + whether_crop_resize: False + keys_to_use: + - 000000.h5 + - 000001.h5 + - 000002.h5 + - 000003.h5 + - 000004.h5 + - 000010.h5 + - 000014.h5 + - 000022.h5 + - 000031.h5 + - 000032.h5 + - 000044.h5 + - 000045.h5 + - 000057.h5 + - 000058.h5 + - 000070.h5 + - 000078.h5 + - 000270.h5 + - 000277.h5 + - 000278.h5 + - 000279.h5 + - 000316.h5 + - 000364.h5 + - 000367.h5 + - 000511.h5 + - 000512.h5 + - 000513.h5 + - 000515.h5 + - 000527.h5 + - 000536.h5 + - 000543.h5 + - 000579.h5 + - 000584.h5 + - 000585.h5 + - 000600.h5 + - 000603.h5 + - 000604.h5 + - 000611.h5 + - 000614.h5 + - 000615.h5 + - 000616.h5 + - 000649.h5 + - 000650.h5 + - 000651.h5 + - 000652.h5 + - 000687.h5 + - 000723.h5 + - 000777.h5 + - 000782.h5 + - 000823.h5 + - 000907.h5 + - 000909.h5 + - 000982.h5 diff --git a/unigaze/configs/data/gaze360_train.yaml b/unigaze/configs/data/gaze360_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bdd75fb074414aa09e49bcf9e8eae5f9214d8ce3 --- /dev/null +++ b/unigaze/configs/data/gaze360_train.yaml @@ -0,0 +1,252 @@ + + + +type: datasets.gaze360.Gaze360Dataset +params: + data_name: gaze360_224_train + saved_norm_config: + focal_norm: 960 + distance_norm: 600 + roi_size: [224, 224] + norm_config: + focal_norm: 960 + distance_norm: 600 + roi_size: [224, 224] + + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + image_size: 224 + sample_rate_use: 1 + + keys_to_use: + - 000000.h5 + - 000001.h5 + - 000002.h5 + - 000003.h5 + - 000004.h5 + - 000006.h5 + - 000007.h5 + - 000009.h5 + - 000010.h5 + - 000011.h5 + - 000013.h5 + - 000016.h5 + - 000019.h5 + - 000020.h5 + - 000029.h5 + - 000030.h5 + - 000031.h5 + - 000032.h5 + - 000034.h5 + - 000035.h5 + - 000038.h5 + - 000039.h5 + - 000043.h5 + - 000048.h5 + - 000049.h5 + - 000050.h5 + - 000058.h5 + - 000060.h5 + - 000061.h5 + - 000062.h5 + - 000063.h5 + - 000072.h5 + - 000073.h5 + - 000074.h5 + - 000075.h5 + - 000076.h5 + - 000077.h5 + - 000081.h5 + - 000083.h5 + - 000084.h5 + - 000085.h5 + - 000090.h5 + - 000093.h5 + - 000094.h5 + - 000099.h5 + - 000109.h5 + - 000111.h5 + - 000112.h5 + - 000116.h5 + - 000122.h5 + - 000134.h5 + - 000146.h5 + - 000148.h5 + - 000149.h5 + - 000150.h5 + - 000151.h5 + - 000152.h5 + - 000154.h5 + - 000156.h5 + - 000158.h5 + - 000159.h5 + - 000160.h5 + - 000161.h5 + - 000162.h5 + - 000165.h5 + - 000166.h5 + - 000170.h5 + - 000171.h5 + - 000172.h5 + - 000184.h5 + - 000185.h5 + - 000186.h5 + - 000187.h5 + - 000188.h5 + - 000189.h5 + - 000190.h5 + - 000202.h5 + - 000205.h5 + - 000206.h5 + - 000208.h5 + - 000214.h5 + - 000216.h5 + - 000217.h5 + - 000219.h5 + - 000220.h5 + - 000221.h5 + - 000222.h5 + - 000228.h5 + - 000237.h5 + - 000248.h5 + - 000250.h5 + - 000255.h5 + - 000256.h5 + - 000257.h5 + - 000258.h5 + - 000262.h5 + - 000278.h5 + - 000283.h5 + - 000284.h5 + - 000287.h5 + - 000288.h5 + - 000297.h5 + - 000298.h5 + - 000299.h5 + - 000300.h5 + - 000324.h5 + - 000368.h5 + - 000369.h5 + - 000408.h5 + - 000409.h5 + - 000410.h5 + - 000418.h5 + - 000440.h5 + - 000441.h5 + - 000449.h5 + - 000457.h5 + - 000458.h5 + - 000459.h5 + - 000460.h5 + - 000461.h5 + - 000463.h5 + - 000494.h5 + - 000501.h5 + - 000502.h5 + - 000509.h5 + - 000510.h5 + - 000511.h5 + - 000512.h5 + - 000513.h5 + - 000514.h5 + - 000515.h5 + - 000517.h5 + - 000519.h5 + - 000529.h5 + - 000541.h5 + - 000547.h5 + - 000548.h5 + - 000549.h5 + - 000550.h5 + - 000551.h5 + - 000552.h5 + - 000565.h5 + - 000566.h5 + - 000569.h5 + - 000571.h5 + - 000573.h5 + - 000574.h5 + - 000586.h5 + - 000587.h5 + - 000588.h5 + - 000589.h5 + - 000592.h5 + - 000597.h5 + - 000603.h5 + - 000604.h5 + - 000605.h5 + - 000611.h5 + - 000613.h5 + - 000617.h5 + - 000620.h5 + - 000623.h5 + - 000634.h5 + - 000635.h5 + - 000636.h5 + - 000639.h5 + - 000640.h5 + - 000641.h5 + - 000642.h5 + - 000643.h5 + - 000644.h5 + - 000645.h5 + - 000650.h5 + - 000656.h5 + - 000658.h5 + - 000659.h5 + - 000660.h5 + - 000661.h5 + - 000662.h5 + - 000670.h5 + - 000671.h5 + - 000677.h5 + - 000683.h5 + - 000714.h5 + - 000721.h5 + - 000723.h5 + - 000738.h5 + - 000741.h5 + - 000742.h5 + - 000744.h5 + - 000751.h5 + - 000755.h5 + - 000761.h5 + - 000762.h5 + - 000763.h5 + - 000764.h5 + - 000765.h5 + - 000768.h5 + - 000777.h5 + - 000779.h5 + - 000780.h5 + - 000781.h5 + - 000783.h5 + - 000786.h5 + - 000787.h5 + - 000789.h5 + - 000800.h5 + - 000801.h5 + - 000802.h5 + - 000803.h5 + - 000813.h5 + - 000815.h5 + - 000816.h5 + - 000831.h5 + - 000834.h5 + - 000835.h5 + - 000838.h5 + - 000861.h5 + - 000862.h5 + - 000899.h5 + - 000900.h5 + - 000916.h5 + - 000918.h5 + - 000923.h5 + - 000935.h5 + - 000946.h5 + - 000971.h5 + - 000978.h5 + - 000990.h5 + - 000991.h5 + - 001092.h5 diff --git a/unigaze/configs/data/gazecapture_test.yaml b/unigaze/configs/data/gazecapture_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4af37a4e95fa9f219487570eee1101116e9dcb0f --- /dev/null +++ b/unigaze/configs/data/gazecapture_test.yaml @@ -0,0 +1,15 @@ +type: datasets.gazecapture.GazeCaptureDataset +params: + data_name: gazecapture_test + color_type: rgb + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + keys_to_use: [ "00010.h5", "00110.h5", "00126.h5", "00178.h5", "00190.h5", "00192.h5", "00220.h5", "00222.h5", "00233.h5", "00319.h5", "00330.h5", "00343.h5", "00382.h5", "00460.h5", "00509.h5", "00511.h5", "00546.h5", "00563.h5", "00580.h5", "00585.h5", + "00611.h5", "00616.h5", "00619.h5", "00646.h5", "00654.h5", "00680.h5", "00686.h5", "00700.h5", "00721.h5", "00741.h5", "00777.h5", "00796.h5", "00868.h5", "00921.h5", "00935.h5", "00949.h5", "00953.h5", + "00965.h5", "00968.h5", "01036.h5", "01041.h5", "01051.h5", "01091.h5", "01148.h5", "01152.h5", "01155.h5", "01183.h5", "01200.h5", "01273.h5", "01278.h5", "01286.h5", "01326.h5", "01329.h5", "01370.h5", "01376.h5", "01425.h5", "01457.h5", "01477.h5", + "01506.h5", "01517.h5", "01525.h5", "01575.h5", "01625.h5", "01672.h5", "01674.h5", "01689.h5", "01782.h5", "01794.h5", "01813.h5", "01830.h5", "01855.h5", "01863.h5", "01877.h5", "01893.h5", "01941.h5", "01959.h5", "01978.h5", + "01983.h5", "01985.h5", "01997.h5", "02006.h5", "02020.h5", "02043.h5", "02078.h5", "02091.h5", "02109.h5", "02197.h5", "02213.h5", "02239.h5", "02240.h5", "02269.h5", "02275.h5", "02281.h5", "02292.h5", + "02301.h5", "02348.h5", "02413.h5", "02419.h5", "02450.h5", "02455.h5", "02461.h5", "02480.h5", "02536.h5", "02601.h5", "02734.h5", "02755.h5", "02756.h5", "02805.h5", "02833.h5", "02851.h5", "02885.h5", "02899.h5", "02942.h5", "02966.h5", + "02986.h5", "03011.h5", "03024.h5", "03043.h5", "03117.h5", "03126.h5", "03140.h5", "03177.h5", "03183.h5", "03185.h5", "03202.h5", "03216.h5", "03223.h5", "03247.h5", "03270.h5", "03324.h5", "03326.h5", "03344.h5", "03352.h5", "03361.h5", "03366.h5", + "03404.h5", "03412.h5", "03451.h5", "03523.h5"] diff --git a/unigaze/configs/data/gazecapture_test_ds15.yaml b/unigaze/configs/data/gazecapture_test_ds15.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cae03ec679320f2a675d77933ebb645f46599287 --- /dev/null +++ b/unigaze/configs/data/gazecapture_test_ds15.yaml @@ -0,0 +1,16 @@ +type: datasets.gazecapture.GazeCaptureDataset +params: + data_name: gazecapture_test + color_type: rgb + transform_type: 'basic_imagenet' + image_size: 224 + sample_rate_use: 15 + dataset_path: null + keys_to_use: [ "00010.h5", "00110.h5", "00126.h5", "00178.h5", "00190.h5", "00192.h5", "00220.h5", "00222.h5", "00233.h5", "00319.h5", "00330.h5", "00343.h5", "00382.h5", "00460.h5", "00509.h5", "00511.h5", "00546.h5", "00563.h5", "00580.h5", "00585.h5", + "00611.h5", "00616.h5", "00619.h5", "00646.h5", "00654.h5", "00680.h5", "00686.h5", "00700.h5", "00721.h5", "00741.h5", "00777.h5", "00796.h5", "00868.h5", "00921.h5", "00935.h5", "00949.h5", "00953.h5", + "00965.h5", "00968.h5", "01036.h5", "01041.h5", "01051.h5", "01091.h5", "01148.h5", "01152.h5", "01155.h5", "01183.h5", "01200.h5", "01273.h5", "01278.h5", "01286.h5", "01326.h5", "01329.h5", "01370.h5", "01376.h5", "01425.h5", "01457.h5", "01477.h5", + "01506.h5", "01517.h5", "01525.h5", "01575.h5", "01625.h5", "01672.h5", "01674.h5", "01689.h5", "01782.h5", "01794.h5", "01813.h5", "01830.h5", "01855.h5", "01863.h5", "01877.h5", "01893.h5", "01941.h5", "01959.h5", "01978.h5", + "01983.h5", "01985.h5", "01997.h5", "02006.h5", "02020.h5", "02043.h5", "02078.h5", "02091.h5", "02109.h5", "02197.h5", "02213.h5", "02239.h5", "02240.h5", "02269.h5", "02275.h5", "02281.h5", "02292.h5", + "02301.h5", "02348.h5", "02413.h5", "02419.h5", "02450.h5", "02455.h5", "02461.h5", "02480.h5", "02536.h5", "02601.h5", "02734.h5", "02755.h5", "02756.h5", "02805.h5", "02833.h5", "02851.h5", "02885.h5", "02899.h5", "02942.h5", "02966.h5", + "02986.h5", "03011.h5", "03024.h5", "03043.h5", "03117.h5", "03126.h5", "03140.h5", "03177.h5", "03183.h5", "03185.h5", "03202.h5", "03216.h5", "03223.h5", "03247.h5", "03270.h5", "03324.h5", "03326.h5", "03344.h5", "03352.h5", "03361.h5", "03366.h5", + "03404.h5", "03412.h5", "03451.h5", "03523.h5"] diff --git a/unigaze/configs/data/gazecapture_train.yaml b/unigaze/configs/data/gazecapture_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16f3396248c25510beb1113c2af56aa29d8ce73d --- /dev/null +++ b/unigaze/configs/data/gazecapture_train.yaml @@ -0,0 +1,1189 @@ +type: datasets.gazecapture.GazeCaptureDataset +params: + data_name: gazecapture_train_224 + color_type: rgb + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + sample_rate_use: 1 + keys_to_use: + - 00002.h5 + - 00003.h5 + - 00005.h5 + - 00006.h5 + - 00024.h5 + - 00028.h5 + - 00033.h5 + - 00034.h5 + - 00087.h5 + - 00089.h5 + - 00097.h5 + - 00098.h5 + - 00099.h5 + - 00102.h5 + - 00103.h5 + - 00104.h5 + - 00114.h5 + - 00120.h5 + - 00121.h5 + - 00122.h5 + - 00123.h5 + - 00127.h5 + - 00128.h5 + - 00130.h5 + - 00132.h5 + - 00137.h5 + - 00138.h5 + - 00139.h5 + - 00140.h5 + - 00141.h5 + - 00142.h5 + - 00143.h5 + - 00144.h5 + - 00145.h5 + - 00146.h5 + - 00148.h5 + - 00149.h5 + - 00150.h5 + - 00153.h5 + - 00154.h5 + - 00156.h5 + - 00162.h5 + - 00164.h5 + - 00165.h5 + - 00173.h5 + - 00179.h5 + - 00191.h5 + - 00194.h5 + - 00200.h5 + - 00202.h5 + - 00208.h5 + - 00209.h5 + - 00210.h5 + - 00211.h5 + - 00212.h5 + - 00214.h5 + - 00218.h5 + - 00221.h5 + - 00224.h5 + - 00225.h5 + - 00226.h5 + - 00227.h5 + - 00228.h5 + - 00232.h5 + - 00234.h5 + - 00236.h5 + - 00237.h5 + - 00238.h5 + - 00239.h5 + - 00240.h5 + - 00241.h5 + - 00243.h5 + - 00245.h5 + - 00247.h5 + - 00249.h5 + - 00268.h5 + - 00269.h5 + - 00273.h5 + - 00274.h5 + - 00285.h5 + - 00288.h5 + - 00289.h5 + - 00295.h5 + - 00296.h5 + - 00299.h5 + - 00300.h5 + - 00303.h5 + - 00304.h5 + - 00305.h5 + - 00307.h5 + - 00309.h5 + - 00310.h5 + - 00311.h5 + - 00312.h5 + - 00317.h5 + - 00324.h5 + - 00325.h5 + - 00326.h5 + - 00331.h5 + - 00332.h5 + - 00339.h5 + - 00342.h5 + - 00351.h5 + - 00354.h5 + - 00355.h5 + - 00356.h5 + - 00357.h5 + - 00358.h5 + - 00359.h5 + - 00363.h5 + - 00376.h5 + - 00377.h5 + - 00459.h5 + - 00465.h5 + - 00466.h5 + - 00467.h5 + - 00469.h5 + - 00472.h5 + - 00473.h5 + - 00475.h5 + - 00477.h5 + - 00480.h5 + - 00481.h5 + - 00487.h5 + - 00488.h5 + - 00491.h5 + - 00492.h5 + - 00493.h5 + - 00494.h5 + - 00495.h5 + - 00496.h5 + - 00499.h5 + - 00501.h5 + - 00503.h5 + - 00505.h5 + - 00510.h5 + - 00512.h5 + - 00513.h5 + - 00514.h5 + - 00518.h5 + - 00519.h5 + - 00520.h5 + - 00522.h5 + - 00525.h5 + - 00531.h5 + - 00533.h5 + - 00534.h5 + - 00535.h5 + - 00539.h5 + - 00540.h5 + - 00542.h5 + - 00544.h5 + - 00545.h5 + - 00548.h5 + - 00550.h5 + - 00553.h5 + - 00554.h5 + - 00555.h5 + - 00560.h5 + - 00562.h5 + - 00565.h5 + - 00566.h5 + - 00569.h5 + - 00572.h5 + - 00574.h5 + - 00575.h5 + - 00578.h5 + - 00581.h5 + - 00584.h5 + - 00588.h5 + - 00590.h5 + - 00599.h5 + - 00600.h5 + - 00601.h5 + - 00602.h5 + - 00605.h5 + - 00606.h5 + - 00607.h5 + - 00610.h5 + - 00613.h5 + - 00617.h5 + - 00621.h5 + - 00622.h5 + - 00623.h5 + - 00624.h5 + - 00626.h5 + - 00627.h5 + - 00632.h5 + - 00633.h5 + - 00634.h5 + - 00636.h5 + - 00638.h5 + - 00641.h5 + - 00642.h5 + - 00643.h5 + - 00644.h5 + - 00645.h5 + - 00649.h5 + - 00650.h5 + - 00658.h5 + - 00661.h5 + - 00663.h5 + - 00666.h5 + - 00667.h5 + - 00668.h5 + - 00669.h5 + - 00670.h5 + - 00672.h5 + - 00675.h5 + - 00676.h5 + - 00677.h5 + - 00678.h5 + - 00679.h5 + - 00682.h5 + - 00683.h5 + - 00687.h5 + - 00688.h5 + - 00690.h5 + - 00691.h5 + - 00693.h5 + - 00694.h5 + - 00695.h5 + - 00699.h5 + - 00704.h5 + - 00706.h5 + - 00707.h5 + - 00710.h5 + - 00711.h5 + - 00712.h5 + - 00714.h5 + - 00716.h5 + - 00718.h5 + - 00719.h5 + - 00722.h5 + - 00728.h5 + - 00729.h5 + - 00730.h5 + - 00731.h5 + - 00732.h5 + - 00733.h5 + - 00737.h5 + - 00742.h5 + - 00743.h5 + - 00745.h5 + - 00747.h5 + - 00749.h5 + - 00750.h5 + - 00752.h5 + - 00753.h5 + - 00755.h5 + - 00756.h5 + - 00757.h5 + - 00764.h5 + - 00765.h5 + - 00767.h5 + - 00771.h5 + - 00772.h5 + - 00773.h5 + - 00774.h5 + - 00775.h5 + - 00789.h5 + - 00790.h5 + - 00791.h5 + - 00795.h5 + - 00798.h5 + - 00801.h5 + - 00802.h5 + - 00804.h5 + - 00806.h5 + - 00807.h5 + - 00810.h5 + - 00811.h5 + - 00812.h5 + - 00814.h5 + - 00818.h5 + - 00819.h5 + - 00820.h5 + - 00821.h5 + - 00823.h5 + - 00825.h5 + - 00827.h5 + - 00831.h5 + - 00832.h5 + - 00833.h5 + - 00835.h5 + - 00837.h5 + - 00840.h5 + - 00841.h5 + - 00842.h5 + - 00849.h5 + - 00850.h5 + - 00851.h5 + - 00852.h5 + - 00853.h5 + - 00855.h5 + - 00859.h5 + - 00864.h5 + - 00865.h5 + - 00869.h5 + - 00872.h5 + - 00873.h5 + - 00874.h5 + - 00875.h5 + - 00878.h5 + - 00881.h5 + - 00882.h5 + - 00886.h5 + - 00888.h5 + - 00889.h5 + - 00891.h5 + - 00892.h5 + - 00894.h5 + - 00896.h5 + - 00897.h5 + - 00898.h5 + - 00899.h5 + - 00900.h5 + - 00904.h5 + - 00905.h5 + - 00907.h5 + - 00911.h5 + - 00912.h5 + - 00914.h5 + - 00915.h5 + - 00923.h5 + - 00924.h5 + - 00927.h5 + - 00931.h5 + - 00933.h5 + - 00934.h5 + - 00938.h5 + - 00944.h5 + - 00945.h5 + - 00947.h5 + - 00948.h5 + - 00956.h5 + - 00961.h5 + - 00963.h5 + - 00969.h5 + - 00971.h5 + - 00974.h5 + - 00980.h5 + - 00981.h5 + - 00982.h5 + - 00983.h5 + - 00984.h5 + - 00986.h5 + - 00989.h5 + - 00991.h5 + - 00992.h5 + - 00997.h5 + - 00999.h5 + - 01000.h5 + - 01002.h5 + - 01003.h5 + - 01009.h5 + - 01010.h5 + - 01012.h5 + - 01015.h5 + - 01018.h5 + - 01019.h5 + - 01020.h5 + - 01021.h5 + - 01022.h5 + - 01024.h5 + - 01025.h5 + - 01031.h5 + - 01032.h5 + - 01034.h5 + - 01035.h5 + - 01038.h5 + - 01039.h5 + - 01042.h5 + - 01044.h5 + - 01045.h5 + - 01046.h5 + - 01050.h5 + - 01052.h5 + - 01054.h5 + - 01055.h5 + - 01056.h5 + - 01057.h5 + - 01058.h5 + - 01059.h5 + - 01060.h5 + - 01062.h5 + - 01063.h5 + - 01064.h5 + - 01065.h5 + - 01069.h5 + - 01070.h5 + - 01073.h5 + - 01075.h5 + - 01076.h5 + - 01077.h5 + - 01080.h5 + - 01081.h5 + - 01082.h5 + - 01083.h5 + - 01084.h5 + - 01085.h5 + - 01086.h5 + - 01087.h5 + - 01088.h5 + - 01089.h5 + - 01090.h5 + - 01092.h5 + - 01093.h5 + - 01095.h5 + - 01100.h5 + - 01102.h5 + - 01104.h5 + - 01105.h5 + - 01106.h5 + - 01107.h5 + - 01110.h5 + - 01118.h5 + - 01120.h5 + - 01121.h5 + - 01123.h5 + - 01127.h5 + - 01128.h5 + - 01129.h5 + - 01135.h5 + - 01138.h5 + - 01139.h5 + - 01143.h5 + - 01145.h5 + - 01146.h5 + - 01147.h5 + - 01149.h5 + - 01151.h5 + - 01156.h5 + - 01157.h5 + - 01158.h5 + - 01161.h5 + - 01162.h5 + - 01163.h5 + - 01164.h5 + - 01165.h5 + - 01166.h5 + - 01167.h5 + - 01168.h5 + - 01169.h5 + - 01170.h5 + - 01171.h5 + - 01172.h5 + - 01173.h5 + - 01174.h5 + - 01175.h5 + - 01177.h5 + - 01178.h5 + - 01180.h5 + - 01181.h5 + - 01182.h5 + - 01184.h5 + - 01186.h5 + - 01188.h5 + - 01191.h5 + - 01195.h5 + - 01199.h5 + - 01201.h5 + - 01204.h5 + - 01207.h5 + - 01208.h5 + - 01209.h5 + - 01211.h5 + - 01212.h5 + - 01213.h5 + - 01219.h5 + - 01221.h5 + - 01222.h5 + - 01231.h5 + - 01232.h5 + - 01233.h5 + - 01237.h5 + - 01243.h5 + - 01244.h5 + - 01247.h5 + - 01250.h5 + - 01252.h5 + - 01254.h5 + - 01255.h5 + - 01256.h5 + - 01259.h5 + - 01260.h5 + - 01262.h5 + - 01266.h5 + - 01269.h5 + - 01270.h5 + - 01275.h5 + - 01276.h5 + - 01279.h5 + - 01281.h5 + - 01283.h5 + - 01285.h5 + - 01293.h5 + - 01295.h5 + - 01298.h5 + - 01300.h5 + - 01301.h5 + - 01303.h5 + - 01304.h5 + - 01315.h5 + - 01316.h5 + - 01320.h5 + - 01323.h5 + - 01327.h5 + - 01328.h5 + - 01330.h5 + - 01331.h5 + - 01333.h5 + - 01340.h5 + - 01347.h5 + - 01348.h5 + - 01349.h5 + - 01351.h5 + - 01352.h5 + - 01353.h5 + - 01354.h5 + - 01356.h5 + - 01357.h5 + - 01358.h5 + - 01360.h5 + - 01361.h5 + - 01362.h5 + - 01368.h5 + - 01375.h5 + - 01377.h5 + - 01379.h5 + - 01380.h5 + - 01382.h5 + - 01383.h5 + - 01384.h5 + - 01386.h5 + - 01387.h5 + - 01388.h5 + - 01389.h5 + - 01390.h5 + - 01391.h5 + - 01393.h5 + - 01396.h5 + - 01400.h5 + - 01405.h5 + - 01406.h5 + - 01414.h5 + - 01415.h5 + - 01420.h5 + - 01421.h5 + - 01423.h5 + - 01424.h5 + - 01428.h5 + - 01430.h5 + - 01431.h5 + - 01434.h5 + - 01435.h5 + - 01438.h5 + - 01440.h5 + - 01445.h5 + - 01446.h5 + - 01448.h5 + - 01451.h5 + - 01454.h5 + - 01456.h5 + - 01459.h5 + - 01460.h5 + - 01462.h5 + - 01467.h5 + - 01470.h5 + - 01471.h5 + - 01472.h5 + - 01473.h5 + - 01478.h5 + - 01479.h5 + - 01480.h5 + - 01481.h5 + - 01482.h5 + - 01483.h5 + - 01485.h5 + - 01486.h5 + - 01487.h5 + - 01488.h5 + - 01491.h5 + - 01492.h5 + - 01496.h5 + - 01497.h5 + - 01499.h5 + - 01508.h5 + - 01510.h5 + - 01511.h5 + - 01514.h5 + - 01515.h5 + - 01516.h5 + - 01519.h5 + - 01523.h5 + - 01524.h5 + - 01528.h5 + - 01531.h5 + - 01532.h5 + - 01533.h5 + - 01534.h5 + - 01540.h5 + - 01542.h5 + - 01546.h5 + - 01551.h5 + - 01553.h5 + - 01566.h5 + - 01569.h5 + - 01574.h5 + - 01577.h5 + - 01581.h5 + - 01582.h5 + - 01583.h5 + - 01584.h5 + - 01602.h5 + - 01603.h5 + - 01604.h5 + - 01606.h5 + - 01611.h5 + - 01612.h5 + - 01613.h5 + - 01617.h5 + - 01618.h5 + - 01627.h5 + - 01630.h5 + - 01631.h5 + - 01633.h5 + - 01635.h5 + - 01636.h5 + - 01637.h5 + - 01640.h5 + - 01643.h5 + - 01644.h5 + - 01645.h5 + - 01648.h5 + - 01650.h5 + - 01651.h5 + - 01653.h5 + - 01658.h5 + - 01665.h5 + - 01669.h5 + - 01671.h5 + - 01678.h5 + - 01680.h5 + - 01681.h5 + - 01682.h5 + - 01684.h5 + - 01687.h5 + - 01690.h5 + - 01692.h5 + - 01693.h5 + - 01697.h5 + - 01698.h5 + - 01700.h5 + - 01703.h5 + - 01705.h5 + - 01706.h5 + - 01709.h5 + - 01710.h5 + - 01713.h5 + - 01717.h5 + - 01718.h5 + - 01719.h5 + - 01720.h5 + - 01726.h5 + - 01727.h5 + - 01728.h5 + - 01729.h5 + - 01730.h5 + - 01731.h5 + - 01734.h5 + - 01738.h5 + - 01741.h5 + - 01744.h5 + - 01745.h5 + - 01747.h5 + - 01748.h5 + - 01755.h5 + - 01762.h5 + - 01763.h5 + - 01768.h5 + - 01770.h5 + - 01771.h5 + - 01775.h5 + - 01778.h5 + - 01779.h5 + - 01783.h5 + - 01789.h5 + - 01792.h5 + - 01795.h5 + - 01796.h5 + - 01798.h5 + - 01802.h5 + - 01803.h5 + - 01806.h5 + - 01812.h5 + - 01816.h5 + - 01817.h5 + - 01818.h5 + - 01821.h5 + - 01823.h5 + - 01825.h5 + - 01826.h5 + - 01827.h5 + - 01828.h5 + - 01833.h5 + - 01843.h5 + - 01849.h5 + - 01858.h5 + - 01860.h5 + - 01862.h5 + - 01866.h5 + - 01867.h5 + - 01868.h5 + - 01869.h5 + - 01870.h5 + - 01874.h5 + - 01878.h5 + - 01880.h5 + - 01882.h5 + - 01883.h5 + - 01884.h5 + - 01885.h5 + - 01887.h5 + - 01888.h5 + - 01889.h5 + - 01892.h5 + - 01897.h5 + - 01900.h5 + - 01901.h5 + - 01902.h5 + - 01905.h5 + - 01906.h5 + - 01907.h5 + - 01908.h5 + - 01912.h5 + - 01915.h5 + - 01921.h5 + - 01922.h5 + - 01924.h5 + - 01925.h5 + - 01926.h5 + - 01927.h5 + - 01930.h5 + - 01933.h5 + - 01936.h5 + - 01943.h5 + - 01960.h5 + - 01961.h5 + - 01962.h5 + - 01964.h5 + - 01965.h5 + - 01966.h5 + - 01975.h5 + - 01976.h5 + - 01977.h5 + - 01979.h5 + - 01984.h5 + - 01987.h5 + - 01995.h5 + - 02009.h5 + - 02011.h5 + - 02015.h5 + - 02019.h5 + - 02022.h5 + - 02023.h5 + - 02024.h5 + - 02025.h5 + - 02026.h5 + - 02028.h5 + - 02029.h5 + - 02034.h5 + - 02035.h5 + - 02038.h5 + - 02045.h5 + - 02047.h5 + - 02051.h5 + - 02052.h5 + - 02056.h5 + - 02058.h5 + - 02059.h5 + - 02061.h5 + - 02064.h5 + - 02065.h5 + - 02077.h5 + - 02084.h5 + - 02085.h5 + - 02086.h5 + - 02087.h5 + - 02090.h5 + - 02092.h5 + - 02093.h5 + - 02099.h5 + - 02102.h5 + - 02105.h5 + - 02106.h5 + - 02112.h5 + - 02113.h5 + - 02114.h5 + - 02115.h5 + - 02118.h5 + - 02123.h5 + - 02131.h5 + - 02136.h5 + - 02137.h5 + - 02138.h5 + - 02140.h5 + - 02141.h5 + - 02142.h5 + - 02152.h5 + - 02154.h5 + - 02156.h5 + - 02159.h5 + - 02161.h5 + - 02162.h5 + - 02168.h5 + - 02170.h5 + - 02172.h5 + - 02173.h5 + - 02186.h5 + - 02187.h5 + - 02193.h5 + - 02198.h5 + - 02203.h5 + - 02204.h5 + - 02206.h5 + - 02207.h5 + - 02212.h5 + - 02216.h5 + - 02219.h5 + - 02220.h5 + - 02229.h5 + - 02230.h5 + - 02232.h5 + - 02234.h5 + - 02237.h5 + - 02241.h5 + - 02244.h5 + - 02249.h5 + - 02250.h5 + - 02255.h5 + - 02257.h5 + - 02264.h5 + - 02266.h5 + - 02267.h5 + - 02270.h5 + - 02272.h5 + - 02277.h5 + - 02278.h5 + - 02279.h5 + - 02282.h5 + - 02293.h5 + - 02297.h5 + - 02298.h5 + - 02300.h5 + - 02311.h5 + - 02314.h5 + - 02319.h5 + - 02321.h5 + - 02322.h5 + - 02324.h5 + - 02326.h5 + - 02327.h5 + - 02328.h5 + - 02332.h5 + - 02334.h5 + - 02337.h5 + - 02339.h5 + - 02342.h5 + - 02343.h5 + - 02347.h5 + - 02349.h5 + - 02350.h5 + - 02352.h5 + - 02355.h5 + - 02358.h5 + - 02359.h5 + - 02361.h5 + - 02362.h5 + - 02365.h5 + - 02366.h5 + - 02367.h5 + - 02368.h5 + - 02370.h5 + - 02371.h5 + - 02373.h5 + - 02375.h5 + - 02379.h5 + - 02394.h5 + - 02412.h5 + - 02414.h5 + - 02415.h5 + - 02418.h5 + - 02420.h5 + - 02421.h5 + - 02424.h5 + - 02426.h5 + - 02430.h5 + - 02431.h5 + - 02432.h5 + - 02434.h5 + - 02435.h5 + - 02436.h5 + - 02439.h5 + - 02440.h5 + - 02441.h5 + - 02442.h5 + - 02443.h5 + - 02445.h5 + - 02447.h5 + - 02448.h5 + - 02452.h5 + - 02454.h5 + - 02457.h5 + - 02458.h5 + - 02459.h5 + - 02462.h5 + - 02465.h5 + - 02467.h5 + - 02468.h5 + - 02469.h5 + - 02472.h5 + - 02474.h5 + - 02478.h5 + - 02510.h5 + - 02518.h5 + - 02520.h5 + - 02521.h5 + - 02522.h5 + - 02524.h5 + - 02525.h5 + - 02534.h5 + - 02535.h5 + - 02540.h5 + - 02547.h5 + - 02550.h5 + - 02552.h5 + - 02553.h5 + - 02554.h5 + - 02557.h5 + - 02559.h5 + - 02566.h5 + - 02567.h5 + - 02571.h5 + - 02573.h5 + - 02575.h5 + - 02576.h5 + - 02578.h5 + - 02581.h5 + - 02585.h5 + - 02587.h5 + - 02588.h5 + - 02590.h5 + - 02595.h5 + - 02610.h5 + - 02611.h5 + - 02613.h5 + - 02615.h5 + - 02617.h5 + - 02619.h5 + - 02629.h5 + - 02632.h5 + - 02634.h5 + - 02649.h5 + - 02663.h5 + - 02666.h5 + - 02669.h5 + - 02673.h5 + - 02681.h5 + - 02689.h5 + - 02690.h5 + - 02700.h5 + - 02705.h5 + - 02709.h5 + - 02713.h5 + - 02718.h5 + - 02721.h5 + - 02722.h5 + - 02723.h5 + - 02725.h5 + - 02729.h5 + - 02730.h5 + - 02732.h5 + - 02737.h5 + - 02740.h5 + - 02741.h5 + - 02749.h5 + - 02758.h5 + - 02760.h5 + - 02761.h5 + - 02762.h5 + - 02763.h5 + - 02764.h5 + - 02765.h5 + - 02772.h5 + - 02773.h5 + - 02774.h5 + - 02776.h5 + - 02780.h5 + - 02781.h5 + - 02785.h5 + - 02797.h5 + - 02818.h5 + - 02819.h5 + - 02827.h5 + - 02829.h5 + - 02832.h5 + - 02837.h5 + - 02841.h5 + - 02843.h5 + - 02846.h5 + - 02847.h5 + - 02852.h5 + - 02854.h5 + - 02857.h5 + - 02868.h5 + - 02872.h5 + - 02873.h5 + - 02874.h5 + - 02876.h5 + - 02877.h5 + - 02878.h5 + - 02879.h5 + - 02880.h5 + - 02882.h5 + - 02883.h5 + - 02888.h5 + - 02898.h5 + - 02902.h5 + - 02908.h5 + - 02911.h5 + - 02919.h5 + - 02920.h5 + - 02921.h5 + - 02922.h5 + - 02924.h5 + - 02925.h5 + - 02928.h5 + - 02938.h5 + - 02941.h5 + - 02944.h5 + - 02945.h5 + - 02954.h5 + - 02955.h5 + - 02956.h5 + - 02960.h5 + - 02961.h5 + - 02964.h5 + - 02967.h5 + - 02977.h5 + - 02978.h5 + - 02979.h5 + - 02980.h5 + - 02985.h5 + - 02987.h5 + - 02988.h5 + - 02989.h5 + - 02991.h5 + - 02997.h5 + - 02998.h5 + - 03003.h5 + - 03004.h5 + - 03006.h5 + - 03009.h5 + - 03012.h5 + - 03013.h5 + - 03014.h5 + - 03023.h5 + - 03026.h5 + - 03027.h5 + - 03037.h5 + - 03042.h5 + - 03051.h5 + - 03057.h5 + - 03064.h5 + - 03065.h5 + - 03079.h5 + - 03089.h5 + - 03102.h5 + - 03107.h5 + - 03116.h5 + - 03122.h5 + - 03125.h5 + - 03130.h5 + - 03133.h5 + - 03134.h5 + - 03137.h5 + - 03139.h5 + - 03160.h5 + - 03163.h5 + - 03172.h5 + - 03174.h5 + - 03178.h5 + - 03179.h5 + - 03180.h5 + - 03188.h5 + - 03189.h5 + - 03190.h5 + - 03192.h5 + - 03193.h5 + - 03197.h5 + - 03199.h5 + - 03200.h5 + - 03205.h5 + - 03206.h5 + - 03211.h5 + - 03218.h5 + - 03219.h5 + - 03222.h5 + - 03225.h5 + - 03231.h5 + - 03246.h5 + - 03248.h5 + - 03251.h5 + - 03253.h5 + - 03255.h5 + - 03259.h5 + - 03263.h5 + - 03265.h5 + - 03266.h5 + - 03273.h5 + - 03275.h5 + - 03277.h5 + - 03278.h5 + - 03282.h5 + - 03283.h5 + - 03302.h5 + - 03303.h5 + - 03304.h5 + - 03307.h5 + - 03314.h5 + - 03315.h5 + - 03327.h5 + - 03328.h5 + - 03332.h5 + - 03336.h5 + - 03340.h5 + - 03342.h5 + - 03343.h5 + - 03348.h5 + - 03351.h5 + - 03354.h5 + - 03358.h5 + - 03359.h5 + - 03360.h5 + - 03367.h5 + - 03371.h5 + - 03374.h5 + - 03375.h5 + - 03377.h5 + - 03378.h5 + - 03379.h5 + - 03381.h5 + - 03382.h5 + - 03384.h5 + - 03397.h5 + - 03403.h5 + - 03406.h5 + - 03413.h5 + - 03425.h5 + - 03431.h5 + - 03432.h5 + - 03435.h5 + - 03442.h5 + - 03453.h5 + - 03454.h5 + - 03456.h5 + - 03463.h5 + - 03465.h5 + - 03466.h5 + - 03467.h5 + - 03469.h5 + - 03473.h5 + - 03491.h5 + - 03492.h5 + - 03495.h5 + - 03498.h5 + - 03501.h5 + - 03502.h5 + diff --git a/unigaze/configs/data/gazecapture_train_ds15.yaml b/unigaze/configs/data/gazecapture_train_ds15.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9eeca10d59800fccd81a469a211ed6825b6aecdf --- /dev/null +++ b/unigaze/configs/data/gazecapture_train_ds15.yaml @@ -0,0 +1,1189 @@ +type: datasets.gazecapture.GazeCaptureDataset +params: + data_name: gazecapture_train_224 + color_type: rgb + transform_type: 'basic_imagenet' + image_size: 224 + dataset_path: null + sample_rate_use: 15 + keys_to_use: + - 00002.h5 + - 00003.h5 + - 00005.h5 + - 00006.h5 + - 00024.h5 + - 00028.h5 + - 00033.h5 + - 00034.h5 + - 00087.h5 + - 00089.h5 + - 00097.h5 + - 00098.h5 + - 00099.h5 + - 00102.h5 + - 00103.h5 + - 00104.h5 + - 00114.h5 + - 00120.h5 + - 00121.h5 + - 00122.h5 + - 00123.h5 + - 00127.h5 + - 00128.h5 + - 00130.h5 + - 00132.h5 + - 00137.h5 + - 00138.h5 + - 00139.h5 + - 00140.h5 + - 00141.h5 + - 00142.h5 + - 00143.h5 + - 00144.h5 + - 00145.h5 + - 00146.h5 + - 00148.h5 + - 00149.h5 + - 00150.h5 + - 00153.h5 + - 00154.h5 + - 00156.h5 + - 00162.h5 + - 00164.h5 + - 00165.h5 + - 00173.h5 + - 00179.h5 + - 00191.h5 + - 00194.h5 + - 00200.h5 + - 00202.h5 + - 00208.h5 + - 00209.h5 + - 00210.h5 + - 00211.h5 + - 00212.h5 + - 00214.h5 + - 00218.h5 + - 00221.h5 + - 00224.h5 + - 00225.h5 + - 00226.h5 + - 00227.h5 + - 00228.h5 + - 00232.h5 + - 00234.h5 + - 00236.h5 + - 00237.h5 + - 00238.h5 + - 00239.h5 + - 00240.h5 + - 00241.h5 + - 00243.h5 + - 00245.h5 + - 00247.h5 + - 00249.h5 + - 00268.h5 + - 00269.h5 + - 00273.h5 + - 00274.h5 + - 00285.h5 + - 00288.h5 + - 00289.h5 + - 00295.h5 + - 00296.h5 + - 00299.h5 + - 00300.h5 + - 00303.h5 + - 00304.h5 + - 00305.h5 + - 00307.h5 + - 00309.h5 + - 00310.h5 + - 00311.h5 + - 00312.h5 + - 00317.h5 + - 00324.h5 + - 00325.h5 + - 00326.h5 + - 00331.h5 + - 00332.h5 + - 00339.h5 + - 00342.h5 + - 00351.h5 + - 00354.h5 + - 00355.h5 + - 00356.h5 + - 00357.h5 + - 00358.h5 + - 00359.h5 + - 00363.h5 + - 00376.h5 + - 00377.h5 + - 00459.h5 + - 00465.h5 + - 00466.h5 + - 00467.h5 + - 00469.h5 + - 00472.h5 + - 00473.h5 + - 00475.h5 + - 00477.h5 + - 00480.h5 + - 00481.h5 + - 00487.h5 + - 00488.h5 + - 00491.h5 + - 00492.h5 + - 00493.h5 + - 00494.h5 + - 00495.h5 + - 00496.h5 + - 00499.h5 + - 00501.h5 + - 00503.h5 + - 00505.h5 + - 00510.h5 + - 00512.h5 + - 00513.h5 + - 00514.h5 + - 00518.h5 + - 00519.h5 + - 00520.h5 + - 00522.h5 + - 00525.h5 + - 00531.h5 + - 00533.h5 + - 00534.h5 + - 00535.h5 + - 00539.h5 + - 00540.h5 + - 00542.h5 + - 00544.h5 + - 00545.h5 + - 00548.h5 + - 00550.h5 + - 00553.h5 + - 00554.h5 + - 00555.h5 + - 00560.h5 + - 00562.h5 + - 00565.h5 + - 00566.h5 + - 00569.h5 + - 00572.h5 + - 00574.h5 + - 00575.h5 + - 00578.h5 + - 00581.h5 + - 00584.h5 + - 00588.h5 + - 00590.h5 + - 00599.h5 + - 00600.h5 + - 00601.h5 + - 00602.h5 + - 00605.h5 + - 00606.h5 + - 00607.h5 + - 00610.h5 + - 00613.h5 + - 00617.h5 + - 00621.h5 + - 00622.h5 + - 00623.h5 + - 00624.h5 + - 00626.h5 + - 00627.h5 + - 00632.h5 + - 00633.h5 + - 00634.h5 + - 00636.h5 + - 00638.h5 + - 00641.h5 + - 00642.h5 + - 00643.h5 + - 00644.h5 + - 00645.h5 + - 00649.h5 + - 00650.h5 + - 00658.h5 + - 00661.h5 + - 00663.h5 + - 00666.h5 + - 00667.h5 + - 00668.h5 + - 00669.h5 + - 00670.h5 + - 00672.h5 + - 00675.h5 + - 00676.h5 + - 00677.h5 + - 00678.h5 + - 00679.h5 + - 00682.h5 + - 00683.h5 + - 00687.h5 + - 00688.h5 + - 00690.h5 + - 00691.h5 + - 00693.h5 + - 00694.h5 + - 00695.h5 + - 00699.h5 + - 00704.h5 + - 00706.h5 + - 00707.h5 + - 00710.h5 + - 00711.h5 + - 00712.h5 + - 00714.h5 + - 00716.h5 + - 00718.h5 + - 00719.h5 + - 00722.h5 + - 00728.h5 + - 00729.h5 + - 00730.h5 + - 00731.h5 + - 00732.h5 + - 00733.h5 + - 00737.h5 + - 00742.h5 + - 00743.h5 + - 00745.h5 + - 00747.h5 + - 00749.h5 + - 00750.h5 + - 00752.h5 + - 00753.h5 + - 00755.h5 + - 00756.h5 + - 00757.h5 + - 00764.h5 + - 00765.h5 + - 00767.h5 + - 00771.h5 + - 00772.h5 + - 00773.h5 + - 00774.h5 + - 00775.h5 + - 00789.h5 + - 00790.h5 + - 00791.h5 + - 00795.h5 + - 00798.h5 + - 00801.h5 + - 00802.h5 + - 00804.h5 + - 00806.h5 + - 00807.h5 + - 00810.h5 + - 00811.h5 + - 00812.h5 + - 00814.h5 + - 00818.h5 + - 00819.h5 + - 00820.h5 + - 00821.h5 + - 00823.h5 + - 00825.h5 + - 00827.h5 + - 00831.h5 + - 00832.h5 + - 00833.h5 + - 00835.h5 + - 00837.h5 + - 00840.h5 + - 00841.h5 + - 00842.h5 + - 00849.h5 + - 00850.h5 + - 00851.h5 + - 00852.h5 + - 00853.h5 + - 00855.h5 + - 00859.h5 + - 00864.h5 + - 00865.h5 + - 00869.h5 + - 00872.h5 + - 00873.h5 + - 00874.h5 + - 00875.h5 + - 00878.h5 + - 00881.h5 + - 00882.h5 + - 00886.h5 + - 00888.h5 + - 00889.h5 + - 00891.h5 + - 00892.h5 + - 00894.h5 + - 00896.h5 + - 00897.h5 + - 00898.h5 + - 00899.h5 + - 00900.h5 + - 00904.h5 + - 00905.h5 + - 00907.h5 + - 00911.h5 + - 00912.h5 + - 00914.h5 + - 00915.h5 + - 00923.h5 + - 00924.h5 + - 00927.h5 + - 00931.h5 + - 00933.h5 + - 00934.h5 + - 00938.h5 + - 00944.h5 + - 00945.h5 + - 00947.h5 + - 00948.h5 + - 00956.h5 + - 00961.h5 + - 00963.h5 + - 00969.h5 + - 00971.h5 + - 00974.h5 + - 00980.h5 + - 00981.h5 + - 00982.h5 + - 00983.h5 + - 00984.h5 + - 00986.h5 + - 00989.h5 + - 00991.h5 + - 00992.h5 + - 00997.h5 + - 00999.h5 + - 01000.h5 + - 01002.h5 + - 01003.h5 + - 01009.h5 + - 01010.h5 + - 01012.h5 + - 01015.h5 + - 01018.h5 + - 01019.h5 + - 01020.h5 + - 01021.h5 + - 01022.h5 + - 01024.h5 + - 01025.h5 + - 01031.h5 + - 01032.h5 + - 01034.h5 + - 01035.h5 + - 01038.h5 + - 01039.h5 + - 01042.h5 + - 01044.h5 + - 01045.h5 + - 01046.h5 + - 01050.h5 + - 01052.h5 + - 01054.h5 + - 01055.h5 + - 01056.h5 + - 01057.h5 + - 01058.h5 + - 01059.h5 + - 01060.h5 + - 01062.h5 + - 01063.h5 + - 01064.h5 + - 01065.h5 + - 01069.h5 + - 01070.h5 + - 01073.h5 + - 01075.h5 + - 01076.h5 + - 01077.h5 + - 01080.h5 + - 01081.h5 + - 01082.h5 + - 01083.h5 + - 01084.h5 + - 01085.h5 + - 01086.h5 + - 01087.h5 + - 01088.h5 + - 01089.h5 + - 01090.h5 + - 01092.h5 + - 01093.h5 + - 01095.h5 + - 01100.h5 + - 01102.h5 + - 01104.h5 + - 01105.h5 + - 01106.h5 + - 01107.h5 + - 01110.h5 + - 01118.h5 + - 01120.h5 + - 01121.h5 + - 01123.h5 + - 01127.h5 + - 01128.h5 + - 01129.h5 + - 01135.h5 + - 01138.h5 + - 01139.h5 + - 01143.h5 + - 01145.h5 + - 01146.h5 + - 01147.h5 + - 01149.h5 + - 01151.h5 + - 01156.h5 + - 01157.h5 + - 01158.h5 + - 01161.h5 + - 01162.h5 + - 01163.h5 + - 01164.h5 + - 01165.h5 + - 01166.h5 + - 01167.h5 + - 01168.h5 + - 01169.h5 + - 01170.h5 + - 01171.h5 + - 01172.h5 + - 01173.h5 + - 01174.h5 + - 01175.h5 + - 01177.h5 + - 01178.h5 + - 01180.h5 + - 01181.h5 + - 01182.h5 + - 01184.h5 + - 01186.h5 + - 01188.h5 + - 01191.h5 + - 01195.h5 + - 01199.h5 + - 01201.h5 + - 01204.h5 + - 01207.h5 + - 01208.h5 + - 01209.h5 + - 01211.h5 + - 01212.h5 + - 01213.h5 + - 01219.h5 + - 01221.h5 + - 01222.h5 + - 01231.h5 + - 01232.h5 + - 01233.h5 + - 01237.h5 + - 01243.h5 + - 01244.h5 + - 01247.h5 + - 01250.h5 + - 01252.h5 + - 01254.h5 + - 01255.h5 + - 01256.h5 + - 01259.h5 + - 01260.h5 + - 01262.h5 + - 01266.h5 + - 01269.h5 + - 01270.h5 + - 01275.h5 + - 01276.h5 + - 01279.h5 + - 01281.h5 + - 01283.h5 + - 01285.h5 + - 01293.h5 + - 01295.h5 + - 01298.h5 + - 01300.h5 + - 01301.h5 + - 01303.h5 + - 01304.h5 + - 01315.h5 + - 01316.h5 + - 01320.h5 + - 01323.h5 + - 01327.h5 + - 01328.h5 + - 01330.h5 + - 01331.h5 + - 01333.h5 + - 01340.h5 + - 01347.h5 + - 01348.h5 + - 01349.h5 + - 01351.h5 + - 01352.h5 + - 01353.h5 + - 01354.h5 + - 01356.h5 + - 01357.h5 + - 01358.h5 + - 01360.h5 + - 01361.h5 + - 01362.h5 + - 01368.h5 + - 01375.h5 + - 01377.h5 + - 01379.h5 + - 01380.h5 + - 01382.h5 + - 01383.h5 + - 01384.h5 + - 01386.h5 + - 01387.h5 + - 01388.h5 + - 01389.h5 + - 01390.h5 + - 01391.h5 + - 01393.h5 + - 01396.h5 + - 01400.h5 + - 01405.h5 + - 01406.h5 + - 01414.h5 + - 01415.h5 + - 01420.h5 + - 01421.h5 + - 01423.h5 + - 01424.h5 + - 01428.h5 + - 01430.h5 + - 01431.h5 + - 01434.h5 + - 01435.h5 + - 01438.h5 + - 01440.h5 + - 01445.h5 + - 01446.h5 + - 01448.h5 + - 01451.h5 + - 01454.h5 + - 01456.h5 + - 01459.h5 + - 01460.h5 + - 01462.h5 + - 01467.h5 + - 01470.h5 + - 01471.h5 + - 01472.h5 + - 01473.h5 + - 01478.h5 + - 01479.h5 + - 01480.h5 + - 01481.h5 + - 01482.h5 + - 01483.h5 + - 01485.h5 + - 01486.h5 + - 01487.h5 + - 01488.h5 + - 01491.h5 + - 01492.h5 + - 01496.h5 + - 01497.h5 + - 01499.h5 + - 01508.h5 + - 01510.h5 + - 01511.h5 + - 01514.h5 + - 01515.h5 + - 01516.h5 + - 01519.h5 + - 01523.h5 + - 01524.h5 + - 01528.h5 + - 01531.h5 + - 01532.h5 + - 01533.h5 + - 01534.h5 + - 01540.h5 + - 01542.h5 + - 01546.h5 + - 01551.h5 + - 01553.h5 + - 01566.h5 + - 01569.h5 + - 01574.h5 + - 01577.h5 + - 01581.h5 + - 01582.h5 + - 01583.h5 + - 01584.h5 + - 01602.h5 + - 01603.h5 + - 01604.h5 + - 01606.h5 + - 01611.h5 + - 01612.h5 + - 01613.h5 + - 01617.h5 + - 01618.h5 + - 01627.h5 + - 01630.h5 + - 01631.h5 + - 01633.h5 + - 01635.h5 + - 01636.h5 + - 01637.h5 + - 01640.h5 + - 01643.h5 + - 01644.h5 + - 01645.h5 + - 01648.h5 + - 01650.h5 + - 01651.h5 + - 01653.h5 + - 01658.h5 + - 01665.h5 + - 01669.h5 + - 01671.h5 + - 01678.h5 + - 01680.h5 + - 01681.h5 + - 01682.h5 + - 01684.h5 + - 01687.h5 + - 01690.h5 + - 01692.h5 + - 01693.h5 + - 01697.h5 + - 01698.h5 + - 01700.h5 + - 01703.h5 + - 01705.h5 + - 01706.h5 + - 01709.h5 + - 01710.h5 + - 01713.h5 + - 01717.h5 + - 01718.h5 + - 01719.h5 + - 01720.h5 + - 01726.h5 + - 01727.h5 + - 01728.h5 + - 01729.h5 + - 01730.h5 + - 01731.h5 + - 01734.h5 + - 01738.h5 + - 01741.h5 + - 01744.h5 + - 01745.h5 + - 01747.h5 + - 01748.h5 + - 01755.h5 + - 01762.h5 + - 01763.h5 + - 01768.h5 + - 01770.h5 + - 01771.h5 + - 01775.h5 + - 01778.h5 + - 01779.h5 + - 01783.h5 + - 01789.h5 + - 01792.h5 + - 01795.h5 + - 01796.h5 + - 01798.h5 + - 01802.h5 + - 01803.h5 + - 01806.h5 + - 01812.h5 + - 01816.h5 + - 01817.h5 + - 01818.h5 + - 01821.h5 + - 01823.h5 + - 01825.h5 + - 01826.h5 + - 01827.h5 + - 01828.h5 + - 01833.h5 + - 01843.h5 + - 01849.h5 + - 01858.h5 + - 01860.h5 + - 01862.h5 + - 01866.h5 + - 01867.h5 + - 01868.h5 + - 01869.h5 + - 01870.h5 + - 01874.h5 + - 01878.h5 + - 01880.h5 + - 01882.h5 + - 01883.h5 + - 01884.h5 + - 01885.h5 + - 01887.h5 + - 01888.h5 + - 01889.h5 + - 01892.h5 + - 01897.h5 + - 01900.h5 + - 01901.h5 + - 01902.h5 + - 01905.h5 + - 01906.h5 + - 01907.h5 + - 01908.h5 + - 01912.h5 + - 01915.h5 + - 01921.h5 + - 01922.h5 + - 01924.h5 + - 01925.h5 + - 01926.h5 + - 01927.h5 + - 01930.h5 + - 01933.h5 + - 01936.h5 + - 01943.h5 + - 01960.h5 + - 01961.h5 + - 01962.h5 + - 01964.h5 + - 01965.h5 + - 01966.h5 + - 01975.h5 + - 01976.h5 + - 01977.h5 + - 01979.h5 + - 01984.h5 + - 01987.h5 + - 01995.h5 + - 02009.h5 + - 02011.h5 + - 02015.h5 + - 02019.h5 + - 02022.h5 + - 02023.h5 + - 02024.h5 + - 02025.h5 + - 02026.h5 + - 02028.h5 + - 02029.h5 + - 02034.h5 + - 02035.h5 + - 02038.h5 + - 02045.h5 + - 02047.h5 + - 02051.h5 + - 02052.h5 + - 02056.h5 + - 02058.h5 + - 02059.h5 + - 02061.h5 + - 02064.h5 + - 02065.h5 + - 02077.h5 + - 02084.h5 + - 02085.h5 + - 02086.h5 + - 02087.h5 + - 02090.h5 + - 02092.h5 + - 02093.h5 + - 02099.h5 + - 02102.h5 + - 02105.h5 + - 02106.h5 + - 02112.h5 + - 02113.h5 + - 02114.h5 + - 02115.h5 + - 02118.h5 + - 02123.h5 + - 02131.h5 + - 02136.h5 + - 02137.h5 + - 02138.h5 + - 02140.h5 + - 02141.h5 + - 02142.h5 + - 02152.h5 + - 02154.h5 + - 02156.h5 + - 02159.h5 + - 02161.h5 + - 02162.h5 + - 02168.h5 + - 02170.h5 + - 02172.h5 + - 02173.h5 + - 02186.h5 + - 02187.h5 + - 02193.h5 + - 02198.h5 + - 02203.h5 + - 02204.h5 + - 02206.h5 + - 02207.h5 + - 02212.h5 + - 02216.h5 + - 02219.h5 + - 02220.h5 + - 02229.h5 + - 02230.h5 + - 02232.h5 + - 02234.h5 + - 02237.h5 + - 02241.h5 + - 02244.h5 + - 02249.h5 + - 02250.h5 + - 02255.h5 + - 02257.h5 + - 02264.h5 + - 02266.h5 + - 02267.h5 + - 02270.h5 + - 02272.h5 + - 02277.h5 + - 02278.h5 + - 02279.h5 + - 02282.h5 + - 02293.h5 + - 02297.h5 + - 02298.h5 + - 02300.h5 + - 02311.h5 + - 02314.h5 + - 02319.h5 + - 02321.h5 + - 02322.h5 + - 02324.h5 + - 02326.h5 + - 02327.h5 + - 02328.h5 + - 02332.h5 + - 02334.h5 + - 02337.h5 + - 02339.h5 + - 02342.h5 + - 02343.h5 + - 02347.h5 + - 02349.h5 + - 02350.h5 + - 02352.h5 + - 02355.h5 + - 02358.h5 + - 02359.h5 + - 02361.h5 + - 02362.h5 + - 02365.h5 + - 02366.h5 + - 02367.h5 + - 02368.h5 + - 02370.h5 + - 02371.h5 + - 02373.h5 + - 02375.h5 + - 02379.h5 + - 02394.h5 + - 02412.h5 + - 02414.h5 + - 02415.h5 + - 02418.h5 + - 02420.h5 + - 02421.h5 + - 02424.h5 + - 02426.h5 + - 02430.h5 + - 02431.h5 + - 02432.h5 + - 02434.h5 + - 02435.h5 + - 02436.h5 + - 02439.h5 + - 02440.h5 + - 02441.h5 + - 02442.h5 + - 02443.h5 + - 02445.h5 + - 02447.h5 + - 02448.h5 + - 02452.h5 + - 02454.h5 + - 02457.h5 + - 02458.h5 + - 02459.h5 + - 02462.h5 + - 02465.h5 + - 02467.h5 + - 02468.h5 + - 02469.h5 + - 02472.h5 + - 02474.h5 + - 02478.h5 + - 02510.h5 + - 02518.h5 + - 02520.h5 + - 02521.h5 + - 02522.h5 + - 02524.h5 + - 02525.h5 + - 02534.h5 + - 02535.h5 + - 02540.h5 + - 02547.h5 + - 02550.h5 + - 02552.h5 + - 02553.h5 + - 02554.h5 + - 02557.h5 + - 02559.h5 + - 02566.h5 + - 02567.h5 + - 02571.h5 + - 02573.h5 + - 02575.h5 + - 02576.h5 + - 02578.h5 + - 02581.h5 + - 02585.h5 + - 02587.h5 + - 02588.h5 + - 02590.h5 + - 02595.h5 + - 02610.h5 + - 02611.h5 + - 02613.h5 + - 02615.h5 + - 02617.h5 + - 02619.h5 + - 02629.h5 + - 02632.h5 + - 02634.h5 + - 02649.h5 + - 02663.h5 + - 02666.h5 + - 02669.h5 + - 02673.h5 + - 02681.h5 + - 02689.h5 + - 02690.h5 + - 02700.h5 + - 02705.h5 + - 02709.h5 + - 02713.h5 + - 02718.h5 + - 02721.h5 + - 02722.h5 + - 02723.h5 + - 02725.h5 + - 02729.h5 + - 02730.h5 + - 02732.h5 + - 02737.h5 + - 02740.h5 + - 02741.h5 + - 02749.h5 + - 02758.h5 + - 02760.h5 + - 02761.h5 + - 02762.h5 + - 02763.h5 + - 02764.h5 + - 02765.h5 + - 02772.h5 + - 02773.h5 + - 02774.h5 + - 02776.h5 + - 02780.h5 + - 02781.h5 + - 02785.h5 + - 02797.h5 + - 02818.h5 + - 02819.h5 + - 02827.h5 + - 02829.h5 + - 02832.h5 + - 02837.h5 + - 02841.h5 + - 02843.h5 + - 02846.h5 + - 02847.h5 + - 02852.h5 + - 02854.h5 + - 02857.h5 + - 02868.h5 + - 02872.h5 + - 02873.h5 + - 02874.h5 + - 02876.h5 + - 02877.h5 + - 02878.h5 + - 02879.h5 + - 02880.h5 + - 02882.h5 + - 02883.h5 + - 02888.h5 + - 02898.h5 + - 02902.h5 + - 02908.h5 + - 02911.h5 + - 02919.h5 + - 02920.h5 + - 02921.h5 + - 02922.h5 + - 02924.h5 + - 02925.h5 + - 02928.h5 + - 02938.h5 + - 02941.h5 + - 02944.h5 + - 02945.h5 + - 02954.h5 + - 02955.h5 + - 02956.h5 + - 02960.h5 + - 02961.h5 + - 02964.h5 + - 02967.h5 + - 02977.h5 + - 02978.h5 + - 02979.h5 + - 02980.h5 + - 02985.h5 + - 02987.h5 + - 02988.h5 + - 02989.h5 + - 02991.h5 + - 02997.h5 + - 02998.h5 + - 03003.h5 + - 03004.h5 + - 03006.h5 + - 03009.h5 + - 03012.h5 + - 03013.h5 + - 03014.h5 + - 03023.h5 + - 03026.h5 + - 03027.h5 + - 03037.h5 + - 03042.h5 + - 03051.h5 + - 03057.h5 + - 03064.h5 + - 03065.h5 + - 03079.h5 + - 03089.h5 + - 03102.h5 + - 03107.h5 + - 03116.h5 + - 03122.h5 + - 03125.h5 + - 03130.h5 + - 03133.h5 + - 03134.h5 + - 03137.h5 + - 03139.h5 + - 03160.h5 + - 03163.h5 + - 03172.h5 + - 03174.h5 + - 03178.h5 + - 03179.h5 + - 03180.h5 + - 03188.h5 + - 03189.h5 + - 03190.h5 + - 03192.h5 + - 03193.h5 + - 03197.h5 + - 03199.h5 + - 03200.h5 + - 03205.h5 + - 03206.h5 + - 03211.h5 + - 03218.h5 + - 03219.h5 + - 03222.h5 + - 03225.h5 + - 03231.h5 + - 03246.h5 + - 03248.h5 + - 03251.h5 + - 03253.h5 + - 03255.h5 + - 03259.h5 + - 03263.h5 + - 03265.h5 + - 03266.h5 + - 03273.h5 + - 03275.h5 + - 03277.h5 + - 03278.h5 + - 03282.h5 + - 03283.h5 + - 03302.h5 + - 03303.h5 + - 03304.h5 + - 03307.h5 + - 03314.h5 + - 03315.h5 + - 03327.h5 + - 03328.h5 + - 03332.h5 + - 03336.h5 + - 03340.h5 + - 03342.h5 + - 03343.h5 + - 03348.h5 + - 03351.h5 + - 03354.h5 + - 03358.h5 + - 03359.h5 + - 03360.h5 + - 03367.h5 + - 03371.h5 + - 03374.h5 + - 03375.h5 + - 03377.h5 + - 03378.h5 + - 03379.h5 + - 03381.h5 + - 03382.h5 + - 03384.h5 + - 03397.h5 + - 03403.h5 + - 03406.h5 + - 03413.h5 + - 03425.h5 + - 03431.h5 + - 03432.h5 + - 03435.h5 + - 03442.h5 + - 03453.h5 + - 03454.h5 + - 03456.h5 + - 03463.h5 + - 03465.h5 + - 03466.h5 + - 03467.h5 + - 03469.h5 + - 03473.h5 + - 03491.h5 + - 03492.h5 + - 03495.h5 + - 03498.h5 + - 03501.h5 + - 03502.h5 + diff --git a/unigaze/configs/data/mpiigaze.yaml b/unigaze/configs/data/mpiigaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7619a307641904290b7801523efe7118f1cc0b28 --- /dev/null +++ b/unigaze/configs/data/mpiigaze.yaml @@ -0,0 +1,24 @@ +type: datasets.mpiigaze.MPIIGazeDataset +params: + data_name: mpii + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + image_size: 224 + keys_to_use: + - p00.h5 + - p01.h5 + - p02.h5 + - p03.h5 + - p04.h5 + - p05.h5 + - p06.h5 + - p07.h5 + - p08.h5 + - p09.h5 + - p10.h5 + - p11.h5 + - p12.h5 + - p13.h5 + - p14.h5 + diff --git a/unigaze/configs/data/mpiigaze_test.yaml b/unigaze/configs/data/mpiigaze_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e453d6cf7985377483e459b1ce795653f92e5c5 --- /dev/null +++ b/unigaze/configs/data/mpiigaze_test.yaml @@ -0,0 +1,24 @@ +type: datasets.mpiigaze.MPIIGazeDataset +params: + data_name: mpii + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + image_size: 224 + keys_to_use: + # - p00.h5 + # - p01.h5 + # - p02.h5 + # - p03.h5 + # - p04.h5 + # - p05.h5 + # - p06.h5 + # - p07.h5 + # - p08.h5 + # - p09.h5 + - p10.h5 + - p11.h5 + - p12.h5 + - p13.h5 + - p14.h5 + diff --git a/unigaze/configs/data/mpiigaze_train.yaml b/unigaze/configs/data/mpiigaze_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..364d1048b55daea12ed76737ec5b2cba1d102dcb --- /dev/null +++ b/unigaze/configs/data/mpiigaze_train.yaml @@ -0,0 +1,24 @@ +type: datasets.mpiigaze.MPIIGazeDataset +params: + data_name: mpii + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + image_size: 224 + keys_to_use: + - p00.h5 + - p01.h5 + - p02.h5 + - p03.h5 + - p04.h5 + - p05.h5 + - p06.h5 + - p07.h5 + - p08.h5 + - p09.h5 + # - p10.h5 + # - p11.h5 + # - p12.h5 + # - p13.h5 + # - p14.h5 + diff --git a/unigaze/configs/data/xgaze_0_60sub.yaml b/unigaze/configs/data/xgaze_0_60sub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fd79de160a46b380ae4addecf3f2cb360990062 --- /dev/null +++ b/unigaze/configs/data/xgaze_0_60sub.yaml @@ -0,0 +1,76 @@ + +type: datasets.xgaze.XGazeDataset +params: + data_name: xgaze_v2_224 + + images_per_frame: 18 + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + frame_tag: 'all' + image_size: 224 + keys_to_use: + - subject0000.h5 + - subject0003.h5 + - subject0004.h5 + - subject0005.h5 + - subject0006.h5 + - subject0007.h5 + - subject0008.h5 + - subject0009.h5 + - subject0010.h5 + - subject0013.h5 + - subject0014.h5 + - subject0015.h5 + - subject0016.h5 + - subject0018.h5 + - subject0019.h5 + - subject0021.h5 + - subject0024.h5 + - subject0026.h5 + - subject0027.h5 + - subject0028.h5 + - subject0029.h5 + - subject0030.h5 + - subject0031.h5 + - subject0032.h5 + - subject0033.h5 + - subject0035.h5 + - subject0036.h5 + - subject0038.h5 + - subject0039.h5 + - subject0040.h5 + - subject0041.h5 + - subject0043.h5 + - subject0044.h5 + - subject0045.h5 + - subject0046.h5 + - subject0048.h5 + - subject0050.h5 + - subject0051.h5 + - subject0052.h5 + - subject0055.h5 + - subject0056.h5 + - subject0057.h5 + - subject0058.h5 + - subject0059.h5 + - subject0060.h5 + - subject0061.h5 + - subject0062.h5 + - subject0063.h5 + - subject0065.h5 + - subject0066.h5 + - subject0067.h5 + - subject0069.h5 + - subject0072.h5 + - subject0073.h5 + - subject0075.h5 + - subject0076.h5 + - subject0078.h5 + - subject0079.h5 + - subject0080.h5 + - subject0081.h5 + + + + \ No newline at end of file diff --git a/unigaze/configs/data/xgaze_0_60sub_d3.yaml b/unigaze/configs/data/xgaze_0_60sub_d3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f898d5fd12368e0b47a572fe7224d8fbba18edc --- /dev/null +++ b/unigaze/configs/data/xgaze_0_60sub_d3.yaml @@ -0,0 +1,80 @@ + +type: datasets.xgaze.XGazeDataset +params: + data_name: xgaze_v2_224 + + images_per_frame: 18 + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + frame_tag: 'all' + camera_random: 3 + image_size: 224 + + + + keys_to_use: + - subject0000.h5 + - subject0003.h5 + - subject0004.h5 + - subject0005.h5 + - subject0006.h5 + - subject0007.h5 + - subject0008.h5 + - subject0009.h5 + - subject0010.h5 + - subject0013.h5 + - subject0014.h5 + - subject0015.h5 + - subject0016.h5 + - subject0018.h5 + - subject0019.h5 + - subject0021.h5 + - subject0024.h5 + - subject0026.h5 + - subject0027.h5 + - subject0028.h5 + - subject0029.h5 + - subject0030.h5 + - subject0031.h5 + - subject0032.h5 + - subject0033.h5 + - subject0035.h5 + - subject0036.h5 + - subject0038.h5 + - subject0039.h5 + - subject0040.h5 + - subject0041.h5 + - subject0043.h5 + - subject0044.h5 + - subject0045.h5 + - subject0046.h5 + - subject0048.h5 + - subject0050.h5 + - subject0051.h5 + - subject0052.h5 + - subject0055.h5 + - subject0056.h5 + - subject0057.h5 + - subject0058.h5 + - subject0059.h5 + - subject0060.h5 + - subject0061.h5 + - subject0062.h5 + - subject0063.h5 + - subject0065.h5 + - subject0066.h5 + - subject0067.h5 + - subject0069.h5 + - subject0072.h5 + - subject0073.h5 + - subject0075.h5 + - subject0076.h5 + - subject0078.h5 + - subject0079.h5 + - subject0080.h5 + - subject0081.h5 + + + + \ No newline at end of file diff --git a/unigaze/configs/data/xgaze_0_80sub.yaml b/unigaze/configs/data/xgaze_0_80sub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4450ec951e6b7955c0a3acc85d34958a33da614a --- /dev/null +++ b/unigaze/configs/data/xgaze_0_80sub.yaml @@ -0,0 +1,97 @@ + +type: datasets.xgaze.XGazeDataset +params: + data_name: xgaze_v2_224 + + images_per_frame: 18 + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + frame_tag: 'all' + image_size: 224 + + keys_to_use: + - subject0000.h5 + - subject0003.h5 + - subject0004.h5 + - subject0005.h5 + - subject0006.h5 + - subject0007.h5 + - subject0008.h5 + - subject0009.h5 + - subject0010.h5 + - subject0013.h5 + - subject0014.h5 + - subject0015.h5 + - subject0016.h5 + - subject0018.h5 + - subject0019.h5 + - subject0021.h5 + - subject0024.h5 + - subject0026.h5 + - subject0027.h5 + - subject0028.h5 + - subject0029.h5 + - subject0030.h5 + - subject0031.h5 + - subject0032.h5 + - subject0033.h5 + - subject0035.h5 + - subject0036.h5 + - subject0038.h5 + - subject0039.h5 + - subject0040.h5 + - subject0041.h5 + - subject0043.h5 + - subject0044.h5 + - subject0045.h5 + - subject0046.h5 + - subject0048.h5 + - subject0050.h5 + - subject0051.h5 + - subject0052.h5 + - subject0055.h5 + - subject0056.h5 + - subject0057.h5 + - subject0058.h5 + - subject0059.h5 + - subject0060.h5 + - subject0061.h5 + - subject0062.h5 + - subject0063.h5 + - subject0065.h5 + - subject0066.h5 + - subject0067.h5 + - subject0069.h5 + - subject0072.h5 + - subject0073.h5 + - subject0075.h5 + - subject0076.h5 + - subject0078.h5 + - subject0079.h5 + - subject0080.h5 + - subject0081.h5 + - subject0083.h5 + - subject0084.h5 + - subject0085.h5 + - subject0088.h5 + - subject0090.h5 + - subject0092.h5 + - subject0095.h5 + - subject0098.h5 + - subject0099.h5 + - subject0100.h5 + - subject0101.h5 + - subject0102.h5 + - subject0103.h5 + - subject0104.h5 + - subject0105.h5 + - subject0106.h5 + - subject0107.h5 + - subject0108.h5 + - subject0109.h5 + - subject0111.h5 + + + + \ No newline at end of file diff --git a/unigaze/configs/data/xgaze_0_80sub_d3.yaml b/unigaze/configs/data/xgaze_0_80sub_d3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..506416e9e8b12f12d428223cd2a3e6b2a13a9266 --- /dev/null +++ b/unigaze/configs/data/xgaze_0_80sub_d3.yaml @@ -0,0 +1,97 @@ + +type: datasets.xgaze.XGazeDataset +params: + data_name: xgaze_v2_224 + color_type: bgr + images_per_frame: 18 + transform_type: 'basic_imagenet' + dataset_path: null + frame_tag: 'all' + camera_random: 3 + image_size: 224 + + keys_to_use: + - subject0000.h5 + - subject0003.h5 + - subject0004.h5 + - subject0005.h5 + - subject0006.h5 + - subject0007.h5 + - subject0008.h5 + - subject0009.h5 + - subject0010.h5 + - subject0013.h5 + - subject0014.h5 + - subject0015.h5 + - subject0016.h5 + - subject0018.h5 + - subject0019.h5 + - subject0021.h5 + - subject0024.h5 + - subject0026.h5 + - subject0027.h5 + - subject0028.h5 + - subject0029.h5 + - subject0030.h5 + - subject0031.h5 + - subject0032.h5 + - subject0033.h5 + - subject0035.h5 + - subject0036.h5 + - subject0038.h5 + - subject0039.h5 + - subject0040.h5 + - subject0041.h5 + - subject0043.h5 + - subject0044.h5 + - subject0045.h5 + - subject0046.h5 + - subject0048.h5 + - subject0050.h5 + - subject0051.h5 + - subject0052.h5 + - subject0055.h5 + - subject0056.h5 + - subject0057.h5 + - subject0058.h5 + - subject0059.h5 + - subject0060.h5 + - subject0061.h5 + - subject0062.h5 + - subject0063.h5 + - subject0065.h5 + - subject0066.h5 + - subject0067.h5 + - subject0069.h5 + - subject0072.h5 + - subject0073.h5 + - subject0075.h5 + - subject0076.h5 + - subject0078.h5 + - subject0079.h5 + - subject0080.h5 + - subject0081.h5 + - subject0083.h5 + - subject0084.h5 + - subject0085.h5 + - subject0088.h5 + - subject0090.h5 + - subject0092.h5 + - subject0095.h5 + - subject0098.h5 + - subject0099.h5 + - subject0100.h5 + - subject0101.h5 + - subject0102.h5 + - subject0103.h5 + - subject0104.h5 + - subject0105.h5 + - subject0106.h5 + - subject0107.h5 + - subject0108.h5 + - subject0109.h5 + - subject0111.h5 + + + + \ No newline at end of file diff --git a/unigaze/configs/data/xgaze_60_80sub.yaml b/unigaze/configs/data/xgaze_60_80sub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be8b9545d9e5b79a6c39e234027b5ba349f0d38a --- /dev/null +++ b/unigaze/configs/data/xgaze_60_80sub.yaml @@ -0,0 +1,31 @@ +type: datasets.xgaze.XGazeDataset +params: + data_name: xgaze_v2_224 + images_per_frame: 18 + color_type: bgr + transform_type: 'basic_imagenet' + dataset_path: null + frame_tag: 'all' + image_size: 224 + keys_to_use: + - subject0083.h5 + - subject0084.h5 + - subject0085.h5 + - subject0088.h5 + - subject0090.h5 + - subject0092.h5 + - subject0095.h5 + - subject0098.h5 + - subject0099.h5 + - subject0100.h5 + - subject0101.h5 + - subject0102.h5 + - subject0103.h5 + - subject0104.h5 + - subject0105.h5 + - subject0106.h5 + - subject0107.h5 + - subject0108.h5 + - subject0109.h5 + - subject0111.h5 + \ No newline at end of file diff --git a/unigaze/configs/exp/blank.yaml b/unigaze/configs/exp/blank.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34dd91d0ff0735004867afd04436ada260c1d5d3 --- /dev/null +++ b/unigaze/configs/exp/blank.yaml @@ -0,0 +1,22 @@ + + + + +exp_name: tbd + + +exp_explanation: + +data: null + +## can be overwritten +model: null + + +trainer: null +loss: null + + +optimizer: configs/optimizers/default_Adam_e4.yaml +scheduler: configs/schedulers/default_stepLR_5.yaml + diff --git a/unigaze/configs/exp/cross/train_ED.yaml b/unigaze/configs/exp/cross/train_ED.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51f4b6ea2021bfeeed8c397b2c35989779f09bf9 --- /dev/null +++ b/unigaze/configs/exp/cross/train_ED.yaml @@ -0,0 +1,27 @@ + + + +train: + - configs/data/eyediap_cs.yaml + - configs/data/eyediap_ft.yaml + + + +val: + + - configs/data/mpiigaze_train.yaml + - configs/data/mpiigaze_test.yaml + + +test: + # - configs/data/xgaze_0_80sub.yaml + - configs/data/xgaze_0_60sub.yaml + - configs/data/xgaze_60_80sub.yaml + + - configs/data/gazecapture_train.yaml + - configs/data/gazecapture_test.yaml + - configs/data/gaze360_train.yaml + - configs/data/gaze360_test.yaml + + - configs/data/mpiigaze.yaml + - configs/data/our_mpii.yaml \ No newline at end of file diff --git a/unigaze/configs/exp/cross/train_G360.yaml b/unigaze/configs/exp/cross/train_G360.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf23a8bb8ff8495314d9bc34ec89df012f9d3d39 --- /dev/null +++ b/unigaze/configs/exp/cross/train_G360.yaml @@ -0,0 +1,32 @@ + +train: + - configs/data/gaze360_train.yaml + - configs/data/gaze360_test.yaml + + +val: + - configs/data/mpiigaze_train.yaml + - configs/data/mpiigaze_test.yaml + + - configs/data/eyediap_cs_train.yaml + - configs/data/eyediap_cs_test.yaml + - configs/data/eyediap_ft_train.yaml + - configs/data/eyediap_ft_test.yaml + + + + +test: + # - configs/data/xgaze_0_80sub.yaml + - configs/data/xgaze_0_60sub.yaml + - configs/data/xgaze_60_80sub.yaml + - configs/data/gazecapture_train.yaml + - configs/data/gazecapture_test.yaml + + - configs/data/mpiigaze.yaml + # - configs/data/our_mpii.yaml + + - configs/data/eyediap_cs.yaml + - configs/data/eyediap_ft.yaml + + diff --git a/unigaze/configs/exp/cross/train_GC.yaml b/unigaze/configs/exp/cross/train_GC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19c4036f5dcb071290d1e9cbbe183e5290730d01 --- /dev/null +++ b/unigaze/configs/exp/cross/train_GC.yaml @@ -0,0 +1,26 @@ + +train: + - configs/data/gazecapture_train_ds15.yaml + - configs/data/gazecapture_test_ds15.yaml + + + +val: + - configs/data/mpiigaze_train.yaml + - configs/data/mpiigaze_test.yaml + + - configs/data/eyediap_cs_train.yaml + - configs/data/eyediap_cs_test.yaml + - configs/data/eyediap_ft_train.yaml + - configs/data/eyediap_ft_test.yaml + + +test: + + - configs/data/xgaze_0_60sub.yaml + - configs/data/xgaze_60_80sub.yaml + - configs/data/gaze360_train.yaml + - configs/data/gaze360_test.yaml + - configs/data/mpiigaze.yaml + - configs/data/eyediap_cs.yaml + - configs/data/eyediap_ft.yaml \ No newline at end of file diff --git a/unigaze/configs/exp/cross/train_MPII.yaml b/unigaze/configs/exp/cross/train_MPII.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf3c558fd4b4ddc524bc8b6672a501a8b6269f92 --- /dev/null +++ b/unigaze/configs/exp/cross/train_MPII.yaml @@ -0,0 +1,24 @@ + + +train: + - configs/data/mpiigaze.yaml + +val: + + - configs/data/eyediap_cs_train.yaml + - configs/data/eyediap_cs_test.yaml + - configs/data/eyediap_ft_train.yaml + - configs/data/eyediap_ft_test.yaml + + +test: + - configs/data/xgaze_0_60sub.yaml + - configs/data/xgaze_60_80sub.yaml + - configs/data/gazecapture_train.yaml + - configs/data/gazecapture_test.yaml + - configs/data/gaze360_train.yaml + - configs/data/gaze360_test.yaml + + + - configs/data/eyediap_cs.yaml + - configs/data/eyediap_ft.yaml diff --git a/unigaze/configs/exp/cross/train_X.yaml b/unigaze/configs/exp/cross/train_X.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b609b8353222d9178269f6ea24a3105e37541361 --- /dev/null +++ b/unigaze/configs/exp/cross/train_X.yaml @@ -0,0 +1,27 @@ + + + +train: + - configs/data/xgaze_0_80sub_d3.yaml + + +val: + + - configs/data/mpiigaze_train.yaml + - configs/data/mpiigaze_test.yaml + + - configs/data/eyediap_cs_train.yaml + - configs/data/eyediap_cs_test.yaml + - configs/data/eyediap_ft_train.yaml + - configs/data/eyediap_ft_test.yaml + +test: + - configs/data/gazecapture_train.yaml + - configs/data/gazecapture_test.yaml + - configs/data/gaze360_train.yaml + - configs/data/gaze360_test.yaml + + - configs/data/mpiigaze.yaml + + - configs/data/eyediap_cs.yaml + - configs/data/eyediap_ft.yaml \ No newline at end of file diff --git a/unigaze/configs/exp/joint/X_GC_M_ED_g360.yaml b/unigaze/configs/exp/joint/X_GC_M_ED_g360.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d38bf65a1ee3b2792eca1ff3f8772fe434e7b2aa --- /dev/null +++ b/unigaze/configs/exp/joint/X_GC_M_ED_g360.yaml @@ -0,0 +1,23 @@ + + +train: + - configs/data/xgaze_0_60sub_d3.yaml + - configs/data/gazecapture_train_ds15.yaml + - configs/data/mpiigaze_train.yaml + - configs/data/eyediap_cs_train.yaml + - configs/data/eyediap_ft_train.yaml + - configs/data/gaze360_train.yaml + + +val: + - configs/data/eyediap_cs_test.yaml + - configs/data/eyediap_ft_test.yaml + - configs/data/gaze360_test.yaml + - configs/data/mpiigaze_test.yaml + - configs/data/our_mpii_test.yaml + +test: + - configs/data/xgaze_60_80sub.yaml + - configs/data/gazecapture_test.yaml + + diff --git a/unigaze/configs/loss/l1_loss.yaml b/unigaze/configs/loss/l1_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb52f4c8dc5abd16c854b139aff7947538eb02ed --- /dev/null +++ b/unigaze/configs/loss/l1_loss.yaml @@ -0,0 +1,4 @@ +loss_config: + type: criteria.gaze_loss.PitchYawLoss + params: + loss_type: 'l1' diff --git a/unigaze/configs/model/hybrid_tr50.yaml b/unigaze/configs/model/hybrid_tr50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66b40784c01e20611c175f3ba8083d37b660e873 --- /dev/null +++ b/unigaze/configs/model/hybrid_tr50.yaml @@ -0,0 +1,5 @@ + +net_config: + type: models.hybrid_tr.HybridTR50 + + diff --git a/unigaze/configs/model/mae_b_16_gaze.yaml b/unigaze/configs/model/mae_b_16_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5dcbafb832c9e985ab1626e8d22ebd5592c4b4ca --- /dev/null +++ b/unigaze/configs/model/mae_b_16_gaze.yaml @@ -0,0 +1,11 @@ + + + + +net_config: + type: models.vit.mae_gaze.MAE_Gaze + params: + model_type: 'vit_b_16' + global_pool: False + drop_path_rate: 0.1 + custom_pretrained_path: checkpoints/mae_b16/mae_b16_checkpoint-299.pth \ No newline at end of file diff --git a/unigaze/configs/model/mae_h_14_gaze.yaml b/unigaze/configs/model/mae_h_14_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15190914573b035a687300aeb148de5902c5bb3e --- /dev/null +++ b/unigaze/configs/model/mae_h_14_gaze.yaml @@ -0,0 +1,14 @@ +# ViT-Huge patchsize=14 +# data_yamls="${ffhqnv_yaml} ${facesyn_yaml} ${sfhq_t2i_224_yaml} ${vfhq_yaml} ${celebv_yaml} ${vgg2_yaml} ${xgaze_mvs_dense_224_yaml}" + + +net_config: + type: models.vit.mae_gaze.MAE_Gaze + params: + model_type: 'vit_h_14' + global_pool: False + drop_path_rate: 0.1 + custom_pretrained_path: checkpoints/mae_h14/mae_h14_checkpoint-299.pth + + + diff --git a/unigaze/configs/model/mae_l_16_gaze.yaml b/unigaze/configs/model/mae_l_16_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f70f621b76d602f5974447e47bfac11500c152e8 --- /dev/null +++ b/unigaze/configs/model/mae_l_16_gaze.yaml @@ -0,0 +1,11 @@ + +s +net_config: + type: models.vit.mae_gaze.MAE_Gaze + params: + model_type: 'vit_l_16' + global_pool: False + drop_path_rate: 0.1 + custom_pretrained_path: checkpoints/mae_l16/mae_h14_checkpoint-299.pth + + diff --git a/unigaze/configs/model/res18.yaml b/unigaze/configs/model/res18.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3a899339867461cc727228b8ce709350b2cab88 --- /dev/null +++ b/unigaze/configs/model/res18.yaml @@ -0,0 +1,5 @@ + +net_config: + type: models.resnet.Res18 + + diff --git a/unigaze/configs/model/res50.yaml b/unigaze/configs/model/res50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8316bee1c70c46522ddda4126702c20bd5dc0f44 --- /dev/null +++ b/unigaze/configs/model/res50.yaml @@ -0,0 +1,6 @@ + + +net_config: + type: models.resnet.Res50 + + diff --git a/unigaze/configs/model/vit_b_16_gaze.yaml b/unigaze/configs/model/vit_b_16_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca3a5975a96acc8776320090e0508802a7fb28ad --- /dev/null +++ b/unigaze/configs/model/vit_b_16_gaze.yaml @@ -0,0 +1,11 @@ + + +net_config: + type: models.vit.vit_gaze.ViTGaze + params: + vit_type: 'b_16' + + + + + diff --git a/unigaze/configs/model/vit_h_14_gaze.yaml b/unigaze/configs/model/vit_h_14_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e31c39b13b710e946521d95eca087cbf867462c --- /dev/null +++ b/unigaze/configs/model/vit_h_14_gaze.yaml @@ -0,0 +1,19 @@ + + +# net_config: +# type: models.vit.vit_gaze.CustomViT_H14 +# params: +# imagenet_pretrained_path: checkpoints/vit_pretrain/vit_h_14_swag-80465313.pth + + + + +net_config: + type: models.vit.vit_gaze.CustomViT_H14 + params: + global_pool: False + drop_path_rate: 0.1 + custom_pretrained_path: checkpoints/vit_pretrain/timm_vit_h_14_imagenet21k.bin + + + diff --git a/unigaze/configs/model/vit_l_16_gaze.yaml b/unigaze/configs/model/vit_l_16_gaze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12b7d46587d57b217097b3af441517d29d36490d --- /dev/null +++ b/unigaze/configs/model/vit_l_16_gaze.yaml @@ -0,0 +1,11 @@ + + +net_config: + type: models.vit.vit_gaze.ViTGaze + params: + vit_type: 'l_16' + + + + + diff --git a/unigaze/configs/optimizers/default_Adam_e4.yaml b/unigaze/configs/optimizers/default_Adam_e4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa408a9b98ea898253feaa9eaac863b5ab6c4fa4 --- /dev/null +++ b/unigaze/configs/optimizers/default_Adam_e4.yaml @@ -0,0 +1,7 @@ + + +optimizer_name: Adam # Optimizer type (e.g., Adam, SGD, AdamW) + +lr: 0.0001 # Initial learning rate + +weight_decay: 0.000001 # Weight decay (L2 regularization) \ No newline at end of file diff --git a/unigaze/configs/schedulers/OneCycleLR.yaml b/unigaze/configs/schedulers/OneCycleLR.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84a49afee7da1a6a474ead6728f69a31330c786f --- /dev/null +++ b/unigaze/configs/schedulers/OneCycleLR.yaml @@ -0,0 +1,7 @@ + + +scheduler_name: OneCycleLR + + +div_factor: 25.0 +final_div_factor: 10000.0 \ No newline at end of file diff --git a/unigaze/configs/schedulers/default_stepLR_5.yaml b/unigaze/configs/schedulers/default_stepLR_5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e71b37daf9bdc71346df3c0052fb47f4301c5853 --- /dev/null +++ b/unigaze/configs/schedulers/default_stepLR_5.yaml @@ -0,0 +1,7 @@ + + +scheduler_name: StepLR + +step_size: 5 + +gamma: 0.1 \ No newline at end of file diff --git a/unigaze/configs/trainer/simple_trainer.yaml b/unigaze/configs/trainer/simple_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ac10752b6adbd481a58224d21b646fe9f0909d3 --- /dev/null +++ b/unigaze/configs/trainer/simple_trainer.yaml @@ -0,0 +1,12 @@ + + +type: trainers.simple_trainer.SimpleTrainer + + + + + + + + + diff --git a/unigaze/infer_runtime.py b/unigaze/infer_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..8874a37496f06a12e26d96be1548fa3cbc938ba7 --- /dev/null +++ b/unigaze/infer_runtime.py @@ -0,0 +1,297 @@ +# unigaze/infer_runtime.py +import os +import cv2 +import time +import tempfile +from pathlib import Path +from typing import Optional, Tuple, Dict, Any, List + +import numpy as np +import torch +from omegaconf import OmegaConf + +# --- UniGaze / your repo imports (these exist in your repo) --- +from datasets.helper.image_transform import wrap_transforms +from utils import instantiate_from_cfg +from utils.util import set_seed +from gazelib.gaze.gaze_utils import pitchyaw_to_vector, vector_to_pitchyaw +from gazelib.gaze.normalize import estimateHeadPose, normalize +from gazelib.label_transform import get_face_center_by_nose +import face_alignment + +# ---------------- Helpers copied from predict_gaze_video.py ---------------- + +def draw_gaze(image_in, pitchyaw, thickness=8, color=(0, 0, 255)): + image_out = image_in.copy() + (h, w) = image_in.shape[:2] + length = w / 2.0 + pos = (int(h / 2.0), int(w / 2.0)) + if len(image_out.shape) == 2 or image_out.shape[2] == 1: + image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR) + dx = -length * np.sin(pitchyaw[1]) * np.cos(pitchyaw[0]) + dy = -length * np.sin(pitchyaw[0]) + end_point = (int(pos[0] + dx), int(pos[1] + dy)) + + shadow_offset = 2 + shadow_color = (40, 40, 40) + shadow_end = (end_point[0] + shadow_offset, end_point[1] + shadow_offset) + cv2.arrowedLine(image_out, (pos[0] + shadow_offset, pos[1] + shadow_offset), + shadow_end, shadow_color, thickness + 2, cv2.LINE_AA, tipLength=0.3) + + thickness_values = [4, 3, 2, 1] + num_layers = len(thickness_values) + for i in range(num_layers): + alpha = i / num_layers + layer_color = tuple(int((1 - alpha) * color[j] + alpha * 255) for j in range(3)) + cv2.arrowedLine(image_out, pos, end_point, layer_color, thickness_values[i], + cv2.LINE_AA, tipLength=0.3) + return image_out + +def set_dummy_camera_model(image=None): + h, w = image.shape[:2] + focal_length = w * 4 + center = (w // 2, h // 2) + camera_matrix = np.array( + [[focal_length, 0, center[0]], + [0, focal_length, center[1]], + [0, 0, 1]], dtype="double" + ) + camera_distortion = np.zeros((1, 5)) + return np.array(camera_matrix), np.array(camera_distortion) + +def denormalize_predicted_gaze(gaze_yaw_pitch, R_inv): + pred_gaze_cancel_nor = pitchyaw_to_vector(gaze_yaw_pitch.reshape(1, 2)).reshape(3, 1) + pred_gaze_cancel_nor = np.matmul(R_inv, pred_gaze_cancel_nor.reshape(3, 1)) + pred_gaze_cancel_nor = pred_gaze_cancel_nor / np.linalg.norm(pred_gaze_cancel_nor) + pred_yaw_pitch_cancel_nor = vector_to_pitchyaw(pred_gaze_cancel_nor.reshape(1, 3)) + return pred_gaze_cancel_nor, pred_yaw_pitch_cancel_nor + +def load_checkpoint(model, ckpt_key, ckpt_path): + assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}" + weights = torch.load(ckpt_path, map_location="cpu") + model_state = weights[ckpt_key] + if next(iter(model_state.keys())).startswith("module."): + model_state = {k[7:]: v for k, v in model_state.items()} + model.load_state_dict(model_state, strict=True) + del weights + +# ---------------- Runtime (model + pre/post) ---------------- + +class UniGazeRuntime: + def __init__(self, cfg_path: str, ckpt_path: str, device: str = "cpu"): + self.device = torch.device(device) + torch.set_grad_enabled(False) + set_seed(42) + + pretrained_model_cfg = OmegaConf.load(cfg_path)["net_config"] + pretrained_model_cfg.params.custom_pretrained_path = None + + self.model = instantiate_from_cfg(pretrained_model_cfg) + load_checkpoint(self.model, "model_state", ckpt_path) + self.model.eval().to(self.device) + + self.transform = wrap_transforms("basic_imagenet", image_size=224) + self.fa = face_alignment.FaceAlignment( + face_alignment.LandmarksType.TWO_D, + flip_input=False, + device=self.device.type, # 'cpu' or 'cuda' + ) + + # Constants from your script + self.focal_norm = 960 + self.distance_norm = 600 + self.roi_size = (224, 224) + self.resize_factor = 0.5 + self.arrow_colors = [(47, 255, 173)] # BGR + + # ---- One-frame inference on a BGR frame; returns annotated BGR frame ---- + def process_frame(self, image_original_bgr: np.ndarray) -> np.ndarray: + image_original = image_original_bgr.copy() + + # resize for detection + if self.resize_factor >= 1: + image_resize = image_original.copy() + else: + image_resize = cv2.resize( + image_original, dsize=None, + fx=self.resize_factor, fy=self.resize_factor, + interpolation=cv2.INTER_AREA + ) + + image_resize = cv2.cvtColor(image_resize, cv2.COLOR_BGR2RGB) + preds = self.fa.get_landmarks(image_resize) + + # no face: just return original + if preds is None: + return image_original + + # ----- keep same semantics as predict_gaze_video.py ----- + landmarks_record = {} # only add when a valid vector is produced + vector_start_end_point_list = {} # start/end 2D points for arrows + bbox_record = {} # for drawing rectangles (same idx keys) + + for idx, landmarks_in_original in enumerate(preds): + color = self.arrow_colors[idx % len(self.arrow_colors)] + + # scale landmarks back to original size + landmarks_in_original = landmarks_in_original / self.resize_factor + x_min = int(landmarks_in_original[:, 0].min()) + x_max = int(landmarks_in_original[:, 0].max()) + y_min = int(landmarks_in_original[:, 1].min()) + y_max = int(landmarks_in_original[:, 1].max()) + + # bbox for drawing (scale 1.2) + scale_factor_draw = 1.2 + bbox_width = x_max - x_min + bbox_height = y_max - y_min + bbox_center = ((x_min + x_max) // 2, (y_min + y_max) // 2) + x_min_draw = max(0, bbox_center[0] - int(bbox_width * scale_factor_draw // 2)) + x_max_draw = min(image_original.shape[1], bbox_center[0] + int(bbox_width * scale_factor_draw // 2)) + y_min_draw = max(0, bbox_center[1] - int(bbox_height * scale_factor_draw // 2)) + y_max_draw = min(image_original.shape[0], bbox_center[1] + int(bbox_height * scale_factor_draw // 2)) + bbox_record[idx] = (x_min_draw, y_min_draw, x_max_draw, y_max_draw) + + # crop for normalization & inference (scale 2.0) + scale_factor_crop = 2.0 + bbox_width = x_max - x_min + bbox_height = y_max - y_min + bbox_center = ((x_min + x_max) // 2, (y_min + y_max) // 2) + x_min_c = max(0, bbox_center[0] - int(bbox_width * scale_factor_crop // 2)) + x_max_c = min(image_original.shape[1], bbox_center[0] + int(bbox_width * scale_factor_crop // 2)) + y_min_c = max(0, bbox_center[1] - int(bbox_height * scale_factor_crop // 2)) + y_max_c = min(image_original.shape[0], bbox_center[1] + int(bbox_height * scale_factor_crop // 2)) + + image = image_original[y_min_c:y_max_c, x_min_c:x_max_c] + landmarks = landmarks_in_original - np.array([x_min_c, y_min_c]) + + # camera + head pose + camera_matrix, camera_distortion = set_dummy_camera_model(image=image) + face_model_load = np.loadtxt("data/face_model.txt") + face_model = face_model_load[[20, 23, 26, 29, 15, 19], :] + facePts = face_model.reshape(6, 1, 3) + + landmarks_sub = landmarks[[36, 39, 42, 45, 31, 35], :].astype(float).reshape(6, 1, 2) + hr, ht = estimateHeadPose(landmarks_sub, facePts, camera_matrix, camera_distortion) + hR = cv2.Rodrigues(hr)[0] + face_center_camera_cord, _ = get_face_center_by_nose(hR=hR, ht=ht, face_model_load=face_model_load) + + # normalize + img_normalized, R, hR_norm, _, _, _ = normalize( + image, landmarks, self.focal_norm, self.distance_norm, + self.roi_size, face_center_camera_cord, hr, ht, camera_matrix, gc=None + ) + + # skip bad heads (same as script) + hr_norm = np.array([np.arcsin(hR_norm[1, 2]), np.arctan2(hR_norm[0, 2], hR_norm[2, 2])]) + if np.linalg.norm(hr_norm) > 80 * np.pi / 180: + continue + + # inference + input_var = img_normalized[:, :, [2, 1, 0]] # BGR->RGB + input_var = self.transform(input_var) + input_var = torch.as_tensor(input_var, dtype=torch.float32, device=self.device).unsqueeze(0) + with torch.no_grad(): + ret = self.model(input_var) + + pred_gaze = ret["pred_gaze"][0] + pred_gaze_np = pred_gaze.detach().cpu().numpy() + + # denormalize to original camera coords, then project to 2D + R_inv = np.linalg.inv(R) + pred_gaze_cancel_nor, _ = denormalize_predicted_gaze(pred_gaze_np, R_inv) + + vec_length = pred_gaze_cancel_nor * -112 * 1.5 + gazeRay = np.concatenate( + (face_center_camera_cord.reshape(1, 3), + (face_center_camera_cord + vec_length).reshape(1, 3)), + axis=0, + ) + result = cv2.projectPoints( + gazeRay, + np.array([0, 0, 0]).reshape(3, 1).astype(float), + np.array([0, 0, 0]).reshape(3, 1).astype(float), + camera_matrix, camera_distortion + )[0].reshape(2, 2) + result += np.array([x_min_c, y_min_c]) + + vector_start_point = (int(result[0][0]), int(result[0][1])) + vector_end_point = (int(result[1][0]), int(result[1][1])) + + # only record these after a valid vector exists + vector_start_end_point_list[idx] = (vector_start_point, vector_end_point) + landmarks_record[idx] = landmarks_in_original + + # If nothing valid was produced, return original frame + if not landmarks_record: + return image_original + + # ----- draw exactly like predict_gaze_video.py (iterate over landmarks_record) ----- + for idx in list(landmarks_record.keys()): + x_min_d, y_min_d, x_max_d, y_max_d = bbox_record[idx] + color = self.arrow_colors[idx % len(self.arrow_colors)] + + cv2.rectangle(image_original, (x_min_d, y_min_d), (x_max_d, y_max_d), (0, 0, 240), 2) + + vsp, vep = vector_start_end_point_list[idx] + shadow_offset = 2 + shadow_color = (40, 40, 40) + shadow_end = (vep[0] + shadow_offset, vep[1] + shadow_offset) + cv2.arrowedLine( + image_original, + (vsp[0] + shadow_offset, vsp[1] + shadow_offset), + shadow_end, + shadow_color, + 5, + cv2.LINE_AA, + tipLength=0.2, + ) + + thickness_values = [x * 3 for x in [4, 3, 2, 1]] + num_layers = len(thickness_values) + for i in range(num_layers): + alpha = i / num_layers + layer_color = tuple(int((1 - alpha) * color[j] + alpha * 255) for j in range(3)) + cv2.arrowedLine(image_original, vsp, vep, layer_color, thickness_values[i], cv2.LINE_AA, tipLength=0.2) + + return image_original + + + # ---- Public APIs ---- + + def predict_image(self, image_rgb: np.ndarray) -> np.ndarray: + """Accepts an RGB image (HxWx3) and returns an annotated RGB image.""" + bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + out_bgr = self.process_frame(bgr) + out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB) + return out_rgb + + def predict_video(self, video_path: str) -> Tuple[Optional[str], Optional[np.ndarray], float]: + """ + Process a video file and return: + - temp MP4 path (string) for Gradio Video + - last annotated RGB frame (numpy) for Gradio Image + - total runtime seconds (float) + """ + t0 = time.time() + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + return None, None, 0.0 + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = max(1, int(cap.get(cv2.CAP_PROP_FPS)) or 25) + + tmp_mp4 = Path(tempfile.mkdtemp(prefix="unigaze_vid_")) / "out.mp4" + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(str(tmp_mp4), fourcc, fps, (width, height)) + + while True: + ret, frame_bgr = cap.read() + if not ret: + break + out_bgr = self.process_frame(frame_bgr) + writer.write(out_bgr) + + cap.release() + writer.release() + return str(tmp_mp4), float(time.time() - t0) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9693f41679570977d700895a4f02228b9afb30a4 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,3 @@ + + +from .util import * diff --git a/utils/helper.py b/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf2a5f0baf40ac41d4c753dc3c45167ca1fc833 --- /dev/null +++ b/utils/helper.py @@ -0,0 +1,79 @@ + +import torch +import numpy as np +import cv2 +import torch.nn as nn + + + +class AverageMeter(object): + """ + Computes and stores the average and + current value. + """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + worker_id = worker_info.id + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + + +def recover_image( image_tensor, MEAN=[0.5, 0.5, 0.5], STD=[0.5, 0.5, 0.5]): + """ + read a tensor and recover it to image in cv2 format + args: + image_tensor: [C, H, W] or [B, C, H, W] + return: + image_save: [B, H, W, C] + """ + if image_tensor.ndim == 3: + image_tensor = image_tensor.unsqueeze(0) + + x = torch.mul(image_tensor, torch.FloatTensor(STD).view(3,1,1).to(image_tensor.device)) + x = torch.add(x, torch.FloatTensor(MEAN).view(3,1,1).to(image_tensor.device) ) + x = x.data.cpu().numpy() + # [C, H, W] -> [H, W, C] + image_rgb = np.transpose(x, (0, 2, 3, 1)) + # RGB -> BGR + image_bgr = image_rgb[:, :, :, [2,1,0]] + # float -> int + image_save = np.clip(image_bgr*255, 0, 255).astype('uint8') + + return image_save + + + +def align_images(image_list, h, w): + if len(image_list) != h * w: + # automatically calculate the number of rows, try to make it as square as possible + h = int(np.sqrt(len(image_list))) + w = int(np.ceil(len(image_list) / h)) + ## if the number of images is not a perfect square, add blank images to the list + image_list += [np.zeros_like(image_list[0])] * (h * w - len(image_list)) + + rows = [image_list[i * w:(i + 1) * w] for i in range(h)] + row_images = [cv2.hconcat(row) for row in rows] + final_image = cv2.vconcat(row_images) + return final_image + + + diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c83dca8fed56530edb962aa26628b72e4f8e4c6b --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,122 @@ +import datetime +import torch +import builtins +import torch.distributed as dist +from torch import inf +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + # x_reduce = x.clone() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + + +def gather_tensors(tensor): + # Gather tensors across all processes + gather_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] + dist.all_gather(gather_list, tensor) + # return torch.cat(gather_list, dim=0) + + # Stack tensors along a new dimension (each GPU result along dimension 0) + tensor_gather = torch.stack(gather_list, dim=0) + # Transpose to get them in the correct order (swap dimensions 0 and 1) + tensor_gather = tensor_gather.transpose(0, 1) + # Reshape back into a 2D tensor with the correct order + tensor_gather = tensor_gather.reshape(-1, tensor.size(-1)) + return tensor_gather + + + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) \ No newline at end of file diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3e6bada7c0b7e1552fb69d0cba34d9dedfd237 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,228 @@ +import importlib +import os +import random +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont +import shutil + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + +def instantiate_from_cfg(config): + if not "type" in config: + raise KeyError("Expected key `type` to instantiate.") + return get_obj_from_str(config["type"])(**config.get("params", dict())) + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res + + +def set_seed(seed): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + # ensure reproducibility + os.environ["PYTHONHASHSEED"] = str(seed) + + +def transform_date_str(date_str): + from datetime import datetime + + # Convert the date string to a datetime object + date_obj = datetime.strptime(date_str, '%Y-%m-%d') + + # Calculate the week of the month + start_of_month = datetime(date_obj.year, date_obj.month, 1) + week_of_month = (date_obj - start_of_month).days // 7 + 1 + + return f"{date_obj.year}{date_obj.month:02}week{week_of_month}" + + + +def save_files(base_dir, run_directory, extensions=('.py', '.yaml')): + run_directory = os.path.join(run_directory, 'run') + os.makedirs(run_directory, exist_ok=True) + src_dirs = [ + "configs", + "criteria", + "datasets", + "models", + "trainers", + "utils", + ] + src_dirs = [os.path.join(base_dir, src_dir) for src_dir in src_dirs] + + for src_dir in src_dirs: + # Traverse the directory tree + for root, dirs, files in os.walk(src_dir): + # Calculate the relative path from the base directory + relative_path = os.path.relpath(root, base_dir) + dest_dir = os.path.join(run_directory, relative_path) + os.makedirs(dest_dir, exist_ok=True) + # Copy files with the specified extensions + for file in files: + if file.endswith(extensions): + src_file_path = os.path.join(root, file) + dest_file_path = os.path.join(dest_dir, file) + shutil.copy(src_file_path, dest_file_path) + # print(f"Saved {src_file_path} to {dest_file_path}") + + +def call_model_method(model, method_name, *args, **kwargs): + """ + Calls a method on the model, regardless of whether it is wrapped in DataParallel or not. + :param model: The model or DataParallel wrapped model. + :param method_name: The name of the method to call. + :param args: Positional arguments to pass to the method. + :param kwargs: Keyword arguments to pass to the method. + """ + + if isinstance(model, torch.nn.DataParallel): + target_model = model.module + else: + target_model = model + # Get the method and call it + method = getattr(target_model, method_name) + + return method(*args, **kwargs) + +def get_attributes_with_prefix(instance, prefix): + return {attr_name: getattr(instance, attr_name) for attr_name in vars(instance) if attr_name.startswith(prefix)} + + +def update_ema_params(model, ema_model, alpha, global_step): + alpha = min(1 - 1 / (global_step + 1), alpha) + # print('ema_model = ema_model * {} + (1 - {}) * model'.format(alpha, alpha)) + for ema_param, param in zip(ema_model.parameters(), model.parameters()): + # ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)