Spaces:
Runtime error
Runtime error
| import pathlib | |
| from typing import Tuple | |
| import numpy as np | |
| import torch | |
| import pathlib | |
| try: | |
| import pyspng | |
| PYSPNG_IMPORTED = True | |
| except ImportError: | |
| PYSPNG_IMPORTED = False | |
| print("Could not load pyspng. Defaulting to pillow image backend.") | |
| from PIL import Image | |
| from tops import logger | |
| class FDFDataset: | |
| def __init__(self, | |
| dirpath, | |
| imsize: Tuple[int], | |
| load_keypoints: bool, | |
| transform): | |
| dirpath = pathlib.Path(dirpath) | |
| self.dirpath = dirpath | |
| self.transform = transform | |
| self.imsize = imsize[0] | |
| self.load_keypoints = load_keypoints | |
| assert self.dirpath.is_dir(),\ | |
| f"Did not find dataset at: {dirpath}" | |
| image_dir = self.dirpath.joinpath("images", str(self.imsize)) | |
| self.image_paths = list(image_dir.glob("*.png")) | |
| assert len(self.image_paths) > 0,\ | |
| f"Did not find images in: {image_dir}" | |
| self.image_paths.sort(key=lambda x: int(x.stem)) | |
| self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) | |
| self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch")) | |
| assert len(self.image_paths) == len(self.bounding_boxes) | |
| assert len(self.image_paths) == len(self.landmarks) | |
| logger.log( | |
| f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}") | |
| def get_mask(self, idx): | |
| mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool) | |
| bounding_box = self.bounding_boxes[idx] | |
| x0, y0, x1, y1 = bounding_box | |
| mask[:, y0:y1, x0:x1] = 0 | |
| return mask | |
| def __len__(self): | |
| return len(self.image_paths) | |
| def __getitem__(self, index): | |
| impath = self.image_paths[index] | |
| if PYSPNG_IMPORTED: | |
| with open(impath, "rb") as fp: | |
| im = pyspng.load(fp.read()) | |
| else: | |
| with Image.open(impath) as fp: | |
| im = np.array(fp) | |
| im = torch.from_numpy(np.rollaxis(im, -1, 0)) | |
| masks = self.get_mask(index) | |
| landmark = self.landmarks[index] | |
| batch = { | |
| "img": im, | |
| "mask": masks, | |
| } | |
| if self.load_keypoints: | |
| batch["keypoints"] = landmark | |
| if self.transform is None: | |
| return batch | |
| return self.transform(batch) | |
| class FDF256Dataset: | |
| def __init__(self, | |
| dirpath, | |
| load_keypoints: bool, | |
| transform): | |
| dirpath = pathlib.Path(dirpath) | |
| self.dirpath = dirpath | |
| self.transform = transform | |
| self.load_keypoints = load_keypoints | |
| assert self.dirpath.is_dir(),\ | |
| f"Did not find dataset at: {dirpath}" | |
| image_dir = self.dirpath.joinpath("images") | |
| self.image_paths = list(image_dir.glob("*.png")) | |
| assert len(self.image_paths) > 0,\ | |
| f"Did not find images in: {image_dir}" | |
| self.image_paths.sort(key=lambda x: int(x.stem)) | |
| self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) | |
| self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy"))) | |
| assert len(self.image_paths) == len(self.bounding_boxes) | |
| assert len(self.image_paths) == len(self.landmarks) | |
| logger.log( | |
| f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}") | |
| def get_mask(self, idx): | |
| mask = torch.ones((1, 256, 256), dtype=torch.bool) | |
| bounding_box = self.bounding_boxes[idx] | |
| x0, y0, x1, y1 = bounding_box | |
| mask[:, y0:y1, x0:x1] = 0 | |
| return mask | |
| def __len__(self): | |
| return len(self.image_paths) | |
| def __getitem__(self, index): | |
| impath = self.image_paths[index] | |
| if PYSPNG_IMPORTED: | |
| with open(impath, "rb") as fp: | |
| im = pyspng.load(fp.read()) | |
| else: | |
| with Image.open(impath) as fp: | |
| im = np.array(fp) | |
| im = torch.from_numpy(np.rollaxis(im, -1, 0)) | |
| masks = self.get_mask(index) | |
| landmark = self.landmarks[index] | |
| batch = { | |
| "img": im, | |
| "mask": masks, | |
| } | |
| if self.load_keypoints: | |
| batch["keypoints"] = landmark | |
| if self.transform is None: | |
| return batch | |
| return self.transform(batch) | |