Spaces:
Runtime error
Runtime error
| import pickle | |
| import torchvision | |
| import torch | |
| import pathlib | |
| import numpy as np | |
| from typing import Callable, Optional, Union | |
| from torch.hub import get_dir as get_hub_dir | |
def cache_embed_stats(embed_map: torch.Tensor):
    """Persist ``embed_map`` together with its dim-0 statistics.

    The statistics are taken over the first dimension (``keepdim=True``):
    ``mean`` is the plain average and ``rstd`` is the reciprocal square root
    of the population variance plus ``1e-8`` (the epsilon guards against
    division by zero for constant columns).

    The result is written with ``torch.save`` as a dict with the keys
    ``"mean"``, ``"rstd"`` and ``"embed_map"`` to
    ``<torch hub dir>/embed_map_stats.torch``; the parent directory is
    created when missing and an existing file is overwritten.

    Args:
        embed_map: Tensor whose statistics are computed along dim 0.
    """
    centre = embed_map.mean(dim=0, keepdim=True)
    variance = (embed_map - centre).square().mean(dim=0, keepdim=True)
    inv_std = (variance + 1e-8).rsqrt()

    target = pathlib.Path(get_hub_dir()) / "embed_map_stats.torch"
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {"mean": centre, "rstd": inv_std, "embed_map": embed_map}
    torch.save(payload, target)
class CocoCSE(torch.utils.data.Dataset):
    """Dataset of PNG images paired with per-image ``.npy`` embedding files.

    Layout read from ``dirpath``:

    * ``images/*.png`` -- the images (sorted by path to fix the sample order);
    * ``embedding/<image stem>.npy`` -- one array per image, split into three
      equal parts along its last axis: vertices, mask and border;
    * ``embed_map.npy`` -- an embedding map shared by every sample.
    """

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 normalize_E: bool,):
        """Index the dataset and load the standardised embedding map.

        Args:
            dirpath: Root directory of the dataset.
            transform: Optional callable applied to every sample dict; when
                given, ``__getitem__`` returns whatever it returns.
            normalize_E: Accepted for interface compatibility but not read:
                the embedding map is always standardised.
                NOTE(review): confirm whether this flag was meant to gate
                the standardisation below.

        Raises:
            AssertionError: If ``dirpath`` is not an existing directory.
        """
        super().__init__()
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not self.dirpath.is_dir():
            raise AssertionError(f"Did not find dataset at: {dirpath}")

        self.image_paths, self.embedding_paths = self._load_impaths()

        # Standardise along dim 0: zero mean, unit (population) variance.
        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
        mean = self.embed_map.mean(dim=0, keepdim=True)
        rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True) + 1e-8).rsqrt()
        self.embed_map = (self.embed_map - mean) * rstd

        # The cache receives the already-standardised map, so the statistics
        # it stores describe that map rather than the raw ``embed_map.npy``.
        # NOTE(review): presumably intentional -- confirm against readers of
        # the cache file.
        cache_embed_stats(self.embed_map)

    def _load_impaths(self):
        """Return ``(image_paths, embedding_paths)`` as two parallel lists.

        Image paths are every ``*.png`` directly inside ``images/``, sorted.
        Each embedding path is ``embedding/<stem>.npy`` for the image at the
        same index; the existence of the embedding files is not checked here.
        """
        image_dir = self.dirpath.joinpath("images")
        image_paths = sorted(image_dir.glob("*.png"))
        embedding_paths = [
            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
        ]
        return image_paths, embedding_paths

    def __len__(self) -> int:
        """Return the number of images found in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict (optionally transformed).

        The dict holds ``"img"`` (as returned by
        ``torchvision.io.read_image``), ``"vertices"`` (long), ``"mask"``,
        ``"border"`` and ``"E_mask"`` (float) -- each of the latter four with
        a new leading axis -- and ``"embed_map"``, the shared standardised
        embedding map (the same tensor object for every sample).
        """
        im = torchvision.io.read_image(str(self.image_paths[idx]))

        # The stored array packs three equally sized parts along its last axis.
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        # ``squeeze()`` drops every size-1 axis, not only the split axis.
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()

        # What remains after removing both the mask and the border.
        E_mask = 1 - mask - border

        batch = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "embed_map": self.embed_map,
            "border": border[None],
            "E_mask": E_mask[None]
        }
        if self.transform is None:
            return batch
        return self.transform(batch)