Spaces:
Runtime error
Runtime error
| """ | |
| ScanNet20 / ScanNet200 / ScanNet Data Efficient Dataset | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| import os | |
| import glob | |
| import numpy as np | |
| import torch | |
| from copy import deepcopy | |
| from torch.utils.data import Dataset | |
| from collections.abc import Sequence | |
| from pointcept.utils.logger import get_root_logger | |
| from pointcept.utils.cache import shared_dict | |
| from .builder import DATASETS | |
| from .defaults import DefaultDataset | |
| from .transform import Compose, TRANSFORMS | |
| from .preprocessing.scannet.meta_data.scannet200_constants import ( | |
| VALID_CLASS_IDS_20, | |
| VALID_CLASS_IDS_200, | |
| ) | |
class ScanNetDataset(DefaultDataset):
    """ScanNet dataset (20-class labels) with optional data-efficient subsets.

    Scenes are directories of ``<asset>.npy`` files; only assets named in
    ``VALID_ASSETS`` are loaded.  Two optional files narrow the supervision:

    * ``lr_file`` replaces the default data list with the scenes it names
      (looked up under ``<data_root>/train``).
    * ``la_file`` provides, per scene name, the indices of points whose
      labels are kept; all other labels are set to ``self.ignore_index``.

    ``data_root``, ``data_list``, ``cache``, ``ignore_index`` and
    ``get_data_name`` are not defined here; they are expected to come from
    ``DefaultDataset``.
    """

    # Base names of the ``.npy`` files that ``get_data`` is allowed to load.
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment20",
        "instance",
    ]
    # Class ids from VALID_CLASS_IDS_20 as a NumPy array.
    class2id = np.array(VALID_CLASS_IDS_20)

    def __init__(
        self,
        lr_file=None,
        la_file=None,
        **kwargs,
    ):
        """Load the optional subset files, then initialise the base dataset.

        Args:
            lr_file: Path to a text file of scene names read with
                ``np.loadtxt(dtype=str)``, or ``None`` to use the default
                data list.
            la_file: Path to a file read with ``torch.load`` that is indexed
                by scene name to obtain the labelled point indices, or
                ``None`` to keep every label.
            **kwargs: Forwarded unchanged to ``DefaultDataset.__init__``.
        """
        # ``np.loadtxt`` yields a 0-d array for a file with a single entry;
        # ``atleast_1d`` keeps ``self.lr`` iterable in that case as well.
        self.lr = (
            np.atleast_1d(np.loadtxt(lr_file, dtype=str))
            if lr_file is not None
            else None
        )
        # NOTE(security): ``torch.load`` unpickles the file, so ``la_file``
        # must come from a trusted source.
        self.la = torch.load(la_file) if la_file is not None else None
        # Set ``lr``/``la`` first: presumably the base constructor already
        # calls ``get_data_list`` -- TODO confirm against DefaultDataset.
        super().__init__(**kwargs)

    def get_data_list(self):
        """Return the scene directories this dataset iterates over.

        Without ``lr_file`` the base-class list is used; otherwise each name
        from the file is joined onto ``<data_root>/train``.
        """
        if self.lr is None:
            return super().get_data_list()
        return [os.path.join(self.data_root, "train", name) for name in self.lr]

    def get_data(self, idx):
        """Load one scene and return it as a dictionary.

        Args:
            idx: Sample index; wrapped modulo ``len(self.data_list)``.

        Returns:
            The shared cache entry ``pointcept-<name>`` when ``self.cache`` is
            truthy.  Otherwise a dict with ``name``, float32 ``coord`` /
            ``color`` / ``normal``, flat int32 ``segment`` and ``instance``
            (filled with ``-1`` when the asset is absent) and, when a
            ``la_file`` was given, ``sampled_index``.

        Raises:
            KeyError: If ``coord``, ``color`` or ``normal`` was not found in
                the scene directory, or the scene name is missing from the
                ``la_file`` contents.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        if self.cache:
            return shared_dict(f"pointcept-{name}")

        data_dict = {}
        for asset in os.listdir(data_path):
            if not asset.endswith(".npy"):
                continue
            key = asset[:-4]  # strip the ".npy" suffix
            if key not in self.VALID_ASSETS:
                continue
            data_dict[key] = np.load(os.path.join(data_path, asset))

        data_dict["name"] = name
        for key in ("coord", "color", "normal"):
            data_dict[key] = data_dict[key].astype(np.float32)

        num_points = data_dict["coord"].shape[0]

        # Whichever label asset was loaded becomes "segment"; "segment20"
        # takes precedence, and -1 marks "no label available".
        for key in ("segment20", "segment200"):
            if key in data_dict:
                data_dict["segment"] = (
                    data_dict.pop(key).reshape([-1]).astype(np.int32)
                )
                break
        else:
            data_dict["segment"] = np.full(num_points, -1, dtype=np.int32)

        if "instance" in data_dict:
            data_dict["instance"] = (
                data_dict.pop("instance").reshape([-1]).astype(np.int32)
            )
        else:
            data_dict["instance"] = np.full(num_points, -1, dtype=np.int32)

        # Explicit None check: an empty or array-like container must not be
        # mistaken for "no limited-annotation file configured".
        if self.la is not None:
            sampled_index = self.la[name]
            # Everything that is NOT a sampled point loses its label.
            mask = np.ones_like(data_dict["segment"], dtype=bool)
            mask[sampled_index] = False
            data_dict["segment"][mask] = self.ignore_index
            data_dict["sampled_index"] = sampled_index
        return data_dict
class ScanNet200Dataset(ScanNetDataset):
    """ScanNet200 flavour of ``ScanNetDataset``.

    Inherits all behaviour from the parent class and overrides only the two
    class attributes below: the label asset is ``segment200`` and the class-id
    table is built from ``VALID_CLASS_IDS_200``.
    """

    # Base names of the ``.npy`` assets considered when reading a scene.
    VALID_ASSETS = ["coord", "color", "normal", "segment200", "instance"]
    # Class ids from VALID_CLASS_IDS_200 as a NumPy array.
    class2id = np.array(VALID_CLASS_IDS_200)