import os
import json

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from spiga.data.loaders.transforms import get_transformers


class AlignmentsDataset(Dataset):
    '''Loads datasets of images with landmarks and bounding boxes.'''
    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: DatabaseStruct instance containing the specifics of the database.
        :param json_file: path to the JSON file containing the image names, landmarks, bounding boxes, etc.
        :param images_dir: path to the directory containing the images.
        :param image_size: target image size as a tuple, e.g. (128, 128).
        :param transform: composition of transformations applied to each sample.
        :param indices: optional list of indices restricting the dataset to that subset of items;
                        if None, the whole set is used.
        :param debug: if True, the untransformed annotations are kept in each sample for debugging.
        """
        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        self._imgs_dict = None
        self.debug = debug

        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        '''Returns the length of the dataset.'''
        if self.indices is None:
            return len(self.data)
        else:
            return len(self.indices)

    def __getitem__(self, sample_idx):
        '''Returns the sample of the dataset at index sample_idx.'''
        # Allow working with a subset of the data
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]

        # Load the sample image
        img_name = os.path.join(self.images_dir, self.data[sample_idx]['imgpath'])
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are grayscale; make sure every image has three channels
        if len(image_cv.shape) == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)
        # Some images have an alpha channel; drop it
        image_cv = image_cv[:, :, :3]
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_cv)

        # Load sample annotations
        ids = np.array(self.data[sample_idx]['ids'])
        landmarks = np.array(self.data[sample_idx]['landmarks'])
        bbox = self.data[sample_idx]['bbox']
        vis = np.array(self.data[sample_idx]['visible'])
        headpose = self.data[sample_idx]['headpose']

        # Generate the bbox from the visible landmarks if it is missing
        if bbox is None:
            aux = landmarks[vis == 1.0]
            bbox = np.zeros(4)
            bbox[0] = min(aux[:, 0])
            bbox[1] = min(aux[:, 1])
            bbox[2] = max(aux[:, 0]) - bbox[0]
            bbox[3] = max(aux[:, 1]) - bbox[1]
        else:
            bbox = np.array(bbox)

        # Remap the landmarks to the database landmark ids and mask the missing ones
        mask_ldm = np.ones(self.database.num_landmarks)
        if self.database.ldm_ids != ids.tolist():
            new_ldm = np.zeros((self.database.num_landmarks, 2))
            new_vis = np.zeros(self.database.num_landmarks)
            xyv = np.hstack((landmarks, vis[:, np.newaxis]))
            ids_dict = dict(zip(ids.astype(int).astype(str), xyv))
            for pos, identifier in enumerate(self.database.ldm_ids):
                if str(identifier) in ids_dict:
                    x, y, v = ids_dict[str(identifier)]
                    new_ldm[pos] = [x, y]
                    new_vis[pos] = v
                else:
                    mask_ldm[pos] = 0
            landmarks = new_ldm
            vis = new_vis

        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  'bbox_raw': bbox,
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': self.data[sample_idx]['imgpath'],
                  }

        # In debug mode, keep the untransformed annotations as well
        if self.debug:
            sample['landmarks_ori'] = landmarks
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)
        return sample
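
# Example of a single annotation entry expected in ``json_file`` -- a sketch
# inferred from the keys read in ``__getitem__``; the concrete values below are
# illustrative assumptions, not taken from any real annotation file:
#
#   {
#       "imgpath": "subset_a/img_0001.jpg",
#       "ids": [0, 1, 2],                                          # landmark identifiers
#       "landmarks": [[45.0, 60.0], [48.5, 61.2], [52.0, 62.1]],   # one (x, y) per id
#       "bbox": [40.0, 55.0, 20.0, 15.0],                          # [x, y, width, height]; may be null
#       "visible": [1.0, 1.0, 0.0],                                # per-landmark visibility flags
#       "headpose": null                                           # pose annotation, optional
#   }
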

def get_dataset(data_config, pretreat=None, debug=False):
    '''Builds an AlignmentsDataset from a data configuration, composing its
    transformers with an optional pretreatment transform.'''
    augmentors = get_transformers(data_config)
    if pretreat is not None:
        augmentors.append(pretreat)

    dataset = AlignmentsDataset(data_config.database,
                                data_config.anns_file,
                                data_config.image_dir,
                                image_size=data_config.image_size,
                                transform=transforms.Compose(augmentors),
                                indices=data_config.ids,
                                debug=debug)
    return dataset
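

# --- Illustrative usage (a sketch; the stand-in database object, the file paths
# and the loader settings below are assumptions for demonstration only, not part
# of this module) ---
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    class _ToyDatabase:
        '''Hypothetical stand-in exposing only the attributes AlignmentsDataset reads.'''
        num_landmarks = 3
        ldm_ids = [0, 1, 2]

    dataset = AlignmentsDataset(_ToyDatabase(),
                                json_file='annotations/train.json',  # hypothetical path
                                images_dir='images/',                # hypothetical path
                                image_size=(128, 128),
                                transform=None)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    sample = dataset[0]
    print(sample['imgpath'], sample['landmarks'].shape, int(sample['mask_ldm'].sum()))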