| """ |
| ModelNet40 Dataset |
| |
| Loads the sampled point clouds of ModelNet40 (XYZ and normals from mesh, 10k points per shape), |
| available at "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip". |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
| Please cite our work if the code is helpful to you. |
| """ |
| import glob |
| import os |
| import numpy as np |
| import copy |
| import pointops |
| import torch |
| from torch.utils.data import Dataset |
| from copy import deepcopy |
|
|
| from pointcept.utils.logger import get_root_logger |
| from .builder import DATASETS |
| from .transform import Compose |
|
|
|
|
@DATASETS.register_module()
class ModelNetDataset(Dataset):
    """ModelNet40 classification dataset read from per-class text files.

    Samples are discovered as ``<data_root>/<class>/<train|test>/*.txt``.
    Each file is parsed with ``np.loadtxt``; columns 0-2 are used as
    coordinates and columns 3-5 as normals. Parsed samples are kept in an
    in-memory dict keyed by file basename and, when ``save_record`` is set,
    written to a ``.pth`` record under ``data_root`` so that later runs can
    skip parsing.
    """

    # Samples with fewer points than this are padded by repeating their
    # first point.
    MIN_POINTS = 32

    def __init__(
        self,
        split="train",
        data_root="data/modelnet40",
        class_names=None,
        transform=None,
        num_points=8192,
        uniform_sampling=True,
        save_record=True,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        """Index the split and load (or build) the cached sample record.

        Args:
            split: "train", "test" or "val" ("val" reads the "test"
                sub-directory); any other value falls back to "train".
            data_root: Directory containing one sub-directory per category.
            class_names: Ordered category names; the position of a name is
                its integer label. Required.
            transform: Transform config handed to ``Compose``.
            num_points: Number of points kept per sample, or None to keep
                every point in the file.
            uniform_sampling: If True, select points with
                ``pointops.farthest_point_sampling``; otherwise keep the
                first ``num_points`` rows.
            save_record: If True, save a freshly built record to disk.
            test_mode: If True, ``__getitem__`` returns voting data built
                from ``test_cfg`` and ``loop`` is forced to 1.
            test_cfg: Config exposing ``post_transform`` and
                ``aug_transform``; only read when ``test_mode`` is True.
            loop: How many times the sample list is repeated per epoch.

        Raises:
            TypeError: If ``class_names`` is None.
        """
        super().__init__()
        if class_names is None:
            # Fail early with a readable message instead of ``len(None)``.
            raise TypeError(
                "class_names must be a sequence of category names, got None."
            )
        self.data_root = data_root
        self.class_names = {name: label for label, name in enumerate(class_names)}
        self.split = split
        self.num_point = num_points
        self.uniform_sampling = uniform_sampling
        self.transform = Compose(transform)
        self.loop = 1 if test_mode else loop
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        if test_mode:
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

        # The record name encodes the sampling settings, so records built
        # with different settings never overwrite each other.
        record_name = f"modelnet40_{self.split}"
        if num_points is not None:
            record_name += f"_{num_points}points"
            if uniform_sampling:
                record_name += "_uniform"
        record_path = os.path.join(self.data_root, f"{record_name}.pth")
        if os.path.isfile(record_path):
            logger.info(f"Loading record: {record_name} ...")
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # record files that were produced locally by this class.
            self.data = torch.load(record_path, weights_only=False)
        else:
            logger.info(f"Preparing record: {record_name} ...")
            # Must exist before get_data() is called: it consults the cache.
            self.data = {}
            total = len(self.data_list)
            for idx in range(total):
                data_name = self.get_data_name(idx)
                logger.info(f"Parsing data [{idx}/{total}]: {data_name}")
                self.data[data_name] = self.get_data(idx)
            if save_record:
                torch.save(self.data, record_path)

    def get_data(self, idx):
        """Return the sample at ``idx`` (wrapped modulo the list length).

        A cached sample is returned as a deep copy so that transforms cannot
        mutate the cache. Otherwise the file is parsed, optionally
        sub-sampled to ``self.num_point`` points, and padded up to
        ``MIN_POINTS`` points.

        Returns:
            dict with ``coord`` (columns 0-2), ``normal`` (columns 3-5) and
            ``category`` (one-element integer array).

        Raises:
            ValueError: If the file contains no points.
            KeyError: If the file's category directory is not in
                ``class_names``.
        """
        data_idx = idx % len(self.data_list)
        data_name = self.get_data_name(data_idx)
        # NOTE: the cache key is the file basename, so two files with the
        # same basename in different class directories would share an entry.
        if data_name in self.data:
            return copy.deepcopy(self.data[data_name])

        data_path = self.data_list[data_idx]
        get_root_logger().debug("Attempting to load: %s", data_path)
        # ndmin=2 keeps a single-line file as shape (1, C) rather than (C,),
        # so the column slicing below stays valid.
        data = np.loadtxt(data_path, ndmin=2).astype(np.float32)
        if data.size == 0:
            # Without this guard the padding step has no row to repeat.
            raise ValueError(f"No points could be read from {data_path}")

        if self.num_point is not None:
            if self.uniform_sampling:
                # NOTE(review): the tensors are created on the CPU here;
                # confirm the installed pointops build accepts CPU input.
                with torch.no_grad():
                    mask = pointops.farthest_point_sampling(
                        torch.tensor(data).float(),
                        torch.tensor([len(data)]).long(),
                        torch.tensor([self.num_point]).long(),
                    )
                data = data[mask.cpu()]
            else:
                data = data[: self.num_point]

        coord, normal = data[:, 0:3], data[:, 3:6]
        # Path layout is <data_root>/<category>/<split>/<file>, so the
        # category is the grandparent directory of the file.
        data_shape = os.path.basename(os.path.dirname(os.path.dirname(data_path)))
        try:
            label = self.class_names[data_shape]
        except KeyError:
            raise KeyError(
                f"Category '{data_shape}' (from {data_path}) is not listed in class_names."
            ) from None
        category = np.array([label])

        shortfall = self.MIN_POINTS - len(coord)
        if shortfall > 0:
            get_root_logger().warning(
                "Sample %s has only %d points. Padding to %d.",
                data_name,
                len(coord),
                self.MIN_POINTS,
            )
            # Repeat the first point in one step instead of concatenating
            # once per missing point.
            coord = np.concatenate(
                [coord, np.repeat(coord[:1], shortfall, axis=0)], axis=0
            )
            normal = np.concatenate(
                [normal, np.repeat(normal[:1], shortfall, axis=0)], axis=0
            )

        return dict(coord=coord, normal=normal, category=category)

    def get_data_list(self):
        """Collect the ``.txt`` sample paths of the current split.

        Every sub-directory of ``data_root`` is treated as a category and
        its ``train`` or ``test`` child is searched for ``*.txt`` files.
        Directory and file names are sorted so that the index-to-sample
        mapping is the same on every run and machine.

        Returns:
            list of file paths.
        """
        logger = get_root_logger()
        logger.debug("ModelNetDataset - data_root: %s", self.data_root)
        logger.debug("ModelNetDataset - split: %s", self.split)

        if self.split == "train":
            subdir_name = "train"
        elif self.split in ("test", "val"):
            subdir_name = "test"
        else:
            subdir_name = "train"
            logger.warning(
                "Unrecognized split '%s'. Defaulting to '%s' subdirectory.",
                self.split,
                subdir_name,
            )

        # os.listdir returns entries in arbitrary order; sort for determinism.
        class_names = sorted(
            name
            for name in os.listdir(self.data_root)
            if os.path.isdir(os.path.join(self.data_root, name))
        )
        logger.debug("Found %d classes: %s", len(class_names), class_names[:5])

        data_list = []
        for class_name in class_names:
            class_path = os.path.join(self.data_root, class_name, subdir_name)
            if not os.path.exists(class_path):
                logger.warning("Directory not found: %s", class_path)
                continue
            # glob.glob order is arbitrary as well.
            files = sorted(glob.glob(os.path.join(class_path, "*.txt")))
            data_list.extend(files)
            logger.debug("Found %d files in %s", len(files), class_path)

        logger.debug("Total files found: %d", len(data_list))
        return data_list

    def get_data_name(self, idx):
        """Return the file basename of the sample at ``idx`` (wrapped modulo the list length)."""
        data_idx = idx % len(self.data_list)
        return os.path.basename(self.data_list[data_idx])

    def __getitem__(self, idx):
        """Return test-time voting data in test mode, else a transformed training sample."""
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __len__(self):
        """Return the number of samples per epoch: list length times ``loop``."""
        return len(self.data_list) * self.loop

    def prepare_train_data(self, idx):
        """Load the sample at ``idx`` and apply the training transform."""
        data_dict = self.get_data(idx)
        return self.transform(data_dict)

    def prepare_test_data(self, idx):
        """Build the test-time voting entry for the sample at ``idx``.

        The category is removed before the base transform runs; each
        augmentation in ``aug_transform`` is applied to its own deep copy of
        the transformed sample, and ``post_transform`` is applied to every
        augmented copy.

        Returns:
            dict with ``voting_list``, ``category`` and ``name``.

        Raises:
            IndexError: If ``idx`` is not smaller than the number of samples.
        """
        if idx >= len(self.data_list):
            # IndexError (not assert) survives ``python -O`` and is what the
            # sequence protocol expects from ``__getitem__``.
            raise IndexError(
                f"index {idx} is out of range for {len(self.data_list)} samples"
            )
        data_dict = self.get_data(idx)
        category = data_dict.pop("category")
        data_dict = self.transform(data_dict)
        # Two separate passes keep the original call order (all augmentations
        # first, then all post-transforms).
        augmented = [aug(deepcopy(data_dict)) for aug in self.aug_transform]
        data_dict_list = [self.post_transform(item) for item in augmented]
        return dict(
            voting_list=data_dict_list,
            category=category,
            name=self.get_data_name(idx),
        )