| """ |
| Default Datasets |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
| Please cite our work if the code is helpful to you. |
| """ |
|
|
| import os |
| import glob |
| import json |
| import numpy as np |
| from copy import deepcopy |
| from torch.utils.data import Dataset |
| from collections.abc import Sequence |
|
|
| from pointcept.utils.logger import get_root_logger |
| from pointcept.utils.cache import shared_dict |
|
|
| from .builder import DATASETS, build_dataset |
| from .transform import Compose, TRANSFORMS |
|
|
|
|
@DATASETS.register_module()
class DefaultDataset(Dataset):
    """Dataset whose samples are scene directories of per-asset ``.npy`` files.

    Each entry of ``self.data_list`` is a path to one scene directory. Files in
    that directory named ``<asset>.npy`` are loaded when ``<asset>`` is listed
    in ``VALID_ASSETS``; everything else is ignored.

    Args:
        split: A split name, or a sequence of split names. Each name is joined
            onto ``data_root``; if the result is a regular file it is read as a
            list of scene names (one per line), otherwise it is treated as a
            directory whose children are the scenes.
        data_root: Root directory of the dataset.
        transform: Transform configuration handed to ``Compose``.
        test_mode: When true, ``__getitem__`` returns the test-time structure
            built by ``prepare_test_data`` and ``loop`` is forced to 1.
        test_cfg: Test configuration; only kept when ``test_mode`` is true, in
            which case its ``voxelize``, ``crop``, ``post_transform`` and
            ``aug_transform`` attributes are read.
        cache: When true, ``get_data`` returns ``shared_dict(...)`` instead of
            reading from disk.
        ignore_index: Label value used to fill ``segment`` when no segment
            asset is available.
        loop: Number of times the data list is repeated for ``__len__``.
    """

    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "strength",
        "segment",
        "segment20",
        "segment200",
        "instance",
        "pose",
        "superpoint",
        "spt",
    ]

    def __init__(
        self,
        split="train",
        data_root="data/dataset",
        transform=None,
        test_mode=False,
        test_cfg=None,
        cache=False,
        ignore_index=-1,
        loop=1,
    ):
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.transform = Compose(transform)
        self.cache = cache
        self.ignore_index = ignore_index
        # Looping is a training-only feature; in test mode each scene is
        # visited exactly once.
        self.loop = loop if not test_mode else 1
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None

        if test_mode:
            # `prepare_test_data` already supports running without a voxelizer
            # (`self.test_voxelize is None`), so a `None` config must not be
            # passed to the registry.
            self.test_voxelize = (
                TRANSFORMS.build(self.test_cfg.voxelize)
                if self.test_cfg.voxelize is not None
                else None
            )
            self.test_crop = (
                TRANSFORMS.build(self.test_cfg.crop) if self.test_cfg.crop else None
            )
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} {} set.".format(
                len(self.data_list), self.loop, os.path.basename(self.data_root), split
            )
        )

        if len(self.data_list) > 0:
            logger.info(f"[DEBUG] First 3 data paths: {self.data_list[:3]}")

    def get_data_list(self):
        """Collect the scene directory paths for every requested split.

        Returns:
            A list of scene paths, in split order.

        Raises:
            NotImplementedError: If ``self.split`` is neither a string nor a
                sequence.
        """
        if isinstance(self.split, str):
            split_list = [self.split]
        elif isinstance(self.split, Sequence):
            split_list = self.split
        else:
            raise NotImplementedError

        logger = get_root_logger()
        data_list = []
        for split in split_list:
            split_path = os.path.join(self.data_root, split)
            if os.path.isfile(split_path):
                logger.info(f"[INFO] Loading split from file: {split_path}")
                data_list += self._read_split_file(split, split_path, logger)
            else:
                logger.info(f"[INFO] Listing scenes from directory: {split_path}")
                data_list += glob.glob(os.path.join(split_path, "*"))

        return data_list

    def _read_split_file(self, split, split_path, logger):
        """Turn a split file (one scene name per line) into scene paths."""
        with open(split_path, "r") as f:
            scene_names = [line.strip() for line in f]
        scene_names = [scene_name for scene_name in scene_names if scene_name]

        # The sub-directory depends only on the split name, so it is resolved
        # once per file rather than once per scene; this also keeps the
        # fallback warning from being repeated for every line.
        for candidate in ("train", "val", "test"):
            if candidate in split:
                subdir = candidate
                break
        else:
            subdir = "train"
            if scene_names:
                logger.warning(
                    f"[WARNING] Cannot infer subdir from split name '{split}', defaulting to 'train'."
                )

        return [
            os.path.join(self.data_root, subdir, scene_name)
            for scene_name in scene_names
        ]

    def get_data(self, idx):
        """Load the raw (untransformed) sample for ``idx``.

        ``idx`` is wrapped with modulo over ``self.data_list``. When
        ``self.cache`` is set the result of ``shared_dict`` is returned as-is
        (presumably a pre-populated copy of the sample; verify against the
        cache utility). Otherwise the assets are read from the scene directory
        and normalised:

        * ``coord`` / ``color`` / ``normal`` are cast to ``float32``; missing
          ``coord`` becomes a single zero point and missing ``color`` becomes
          zeros shaped like ``coord``.
        * ``segment`` is taken from ``segment20``, ``segment200`` or
          ``segment`` (first one present), flattened and cast to ``int32``;
          if none exists it is filled with ``self.ignore_index``.
        * ``instance`` is flattened to ``int32`` or filled with ``-1``.
        * ``superpoint`` (or else ``spt``) is flattened to ``int32``.

        A missing scene directory yields a one-point placeholder sample instead
        of raising.

        Returns:
            A dict of arrays plus ``name`` and ``split`` entries.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        split = self.get_split_name(idx)
        if self.cache:
            cache_name = f"pointcept-{name}"
            return shared_dict(cache_name)

        # Diagnostics go through the project logger (like the rest of the
        # class) so the per-sample trace can be silenced via the log level.
        logger = get_root_logger()
        logger.debug(f"[DEBUG] Loading data from: {data_path}")

        data_dict = {}

        if not os.path.exists(data_path):
            logger.error(f"❌ Error: Data directory not found: {data_path}")
            data_dict["coord"] = np.zeros((1, 3), dtype=np.float32)
            data_dict["segment"] = np.array([self.ignore_index], dtype=np.int32)
            data_dict["color"] = np.zeros((1, 3), dtype=np.float32)
            # Regular samples always carry `instance`; keep the placeholder's
            # keys consistent with them.
            data_dict["instance"] = np.array([-1], dtype=np.int32)
            data_dict["name"] = name
            data_dict["split"] = split
            return data_dict

        for asset in os.listdir(data_path):
            stem, ext = os.path.splitext(asset)
            if ext != ".npy" or stem not in self.VALID_ASSETS:
                continue
            data_dict[stem] = np.load(os.path.join(data_path, asset))

        data_dict["name"] = name
        data_dict["split"] = split

        if "coord" in data_dict:
            data_dict["coord"] = data_dict["coord"].astype(np.float32)
        else:
            logger.error(
                f"❌ Error: 'coord.npy' not found in {data_path}, using dummy data."
            )
            data_dict["coord"] = np.zeros((1, 3), dtype=np.float32)

        if "color" in data_dict:
            data_dict["color"] = data_dict["color"].astype(np.float32)
        else:
            logger.warning(
                f"⚠️ Warning: 'color.npy' not found in {data_path}, using zeros."
            )
            data_dict["color"] = np.zeros(data_dict["coord"].shape, dtype=np.float32)

        if "normal" in data_dict:
            data_dict["normal"] = data_dict["normal"].astype(np.float32)

        num_points = data_dict["coord"].shape[0]

        # First available label set wins, in this priority order.
        segment_key = next(
            (
                key
                for key in ("segment20", "segment200", "segment")
                if key in data_dict
            ),
            None,
        )
        if segment_key is not None:
            data_dict["segment"] = (
                data_dict[segment_key].reshape([-1]).astype(np.int32)
            )
        else:
            logger.error(
                f"❌ Error: No segment label found in {data_path}, using ignore index."
            )
            data_dict["segment"] = (
                np.ones(num_points, dtype=np.int32) * self.ignore_index
            )

        if "instance" in data_dict:
            data_dict["instance"] = (
                data_dict["instance"].reshape([-1]).astype(np.int32)
            )
        else:
            data_dict["instance"] = np.ones(num_points, dtype=np.int32) * -1

        if "superpoint" in data_dict:
            data_dict["superpoint"] = (
                data_dict["superpoint"].reshape([-1]).astype(np.int32)
            )
        elif "spt" in data_dict:
            data_dict["spt"] = data_dict["spt"].reshape([-1]).astype(np.int32)

        return data_dict

    def get_data_name(self, idx):
        """Return the scene name: the last path component of the sample path."""
        return os.path.basename(self.data_list[idx % len(self.data_list)])

    def get_split_name(self, idx):
        """Return the name of the directory that contains the scene directory."""
        return os.path.basename(
            os.path.dirname(self.data_list[idx % len(self.data_list)])
        )

    def prepare_train_data(self, idx):
        """Load sample ``idx`` and run it through ``self.transform``."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        return data_dict

    def prepare_test_data(self, idx):
        """Build the test-time structure for sample ``idx``.

        The sample is transformed, its ``segment`` and ``name`` (and, when
        present, ``origin_segment`` together with ``inverse``) are moved into
        the result, and the remaining data is expanded by every entry of
        ``self.aug_transform``. Each augmented copy is then split by
        ``self.test_voxelize`` (if any) and ``self.test_crop`` (if any), and
        every resulting fragment goes through ``self.post_transform``.

        Returns:
            A dict with ``segment``, ``name``, ``fragment_list`` and optionally
            ``origin_segment`` / ``inverse``.
        """
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        result_dict = dict(segment=data_dict.pop("segment"), name=data_dict.pop("name"))
        if "origin_segment" in data_dict:
            assert "inverse" in data_dict
            result_dict["origin_segment"] = data_dict.pop("origin_segment")
            result_dict["inverse"] = data_dict.pop("inverse")

        # Each augmentation works on its own deep copy so they cannot affect
        # one another.
        data_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

        fragment_list = []
        for data in data_dict_list:
            if self.test_voxelize is not None:
                data_part_list = self.test_voxelize(data)
            else:
                # Without voxelization the whole cloud is one part that
                # indexes every point.
                data["index"] = np.arange(data["coord"].shape[0])
                data_part_list = [data]
            for data_part in data_part_list:
                if self.test_crop is not None:
                    data_part = self.test_crop(data_part)
                else:
                    data_part = [data_part]
                fragment_list += data_part

        fragment_list = [self.post_transform(fragment) for fragment in fragment_list]
        result_dict["fragment_list"] = fragment_list
        return result_dict

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        else:
            return self.prepare_train_data(idx)

    def __len__(self):
        return len(self.data_list) * self.loop
|
|
|
|
@DATASETS.register_module()
class ConcatDataset(Dataset):
    """Present several datasets as one flat, optionally repeated, dataset.

    Args:
        datasets: Iterable of dataset configurations; each one is passed to
            ``build_dataset``.
        loop: Number of times the combined index list is repeated for
            ``__len__``; indices wrap around via modulo on lookup.
    """

    def __init__(self, datasets, loop=1):
        super().__init__()
        self.datasets = list(map(build_dataset, datasets))
        self.loop = loop
        self.data_list = self.get_data_list()
        summary = "Totally {} x {} samples in the concat set.".format(
            len(self.data_list), self.loop
        )
        get_root_logger().info(summary)

    def get_data_list(self):
        """Return ``(dataset index, sample index)`` pairs covering every sample."""
        pairs = []
        for position, member in enumerate(self.datasets):
            size = len(member)
            owners = np.full(size, position, dtype=int)
            pairs.extend(zip(owners, np.arange(size)))
        return pairs

    def _locate(self, idx):
        """Map a flat (possibly looped) index to its member dataset and local index."""
        dataset_idx, data_idx = self.data_list[idx % len(self.data_list)]
        return self.datasets[dataset_idx], data_idx

    def get_data(self, idx):
        """Fetch the sample at flat index ``idx`` from the dataset that owns it."""
        member, data_idx = self._locate(idx)
        return member[data_idx]

    def get_data_name(self, idx):
        """Return the owning dataset's name for the sample at flat index ``idx``."""
        member, data_idx = self._locate(idx)
        return member.get_data_name(data_idx)

    def __getitem__(self, idx):
        return self.get_data(idx)

    def __len__(self):
        return self.loop * len(self.data_list)