"""
ModelNet40 Dataset

get sampled point clouds of ModelNet40 (XYZ and normal from mesh, 10k points per shape)
at "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip"

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import glob  # required by ModelNetDataset.get_data_list
import os
import numpy as np
import copy
import pointops
import torch
from torch.utils.data import Dataset
from copy import deepcopy

from pointcept.utils.logger import get_root_logger
from .builder import DATASETS
from .transform import Compose


@DATASETS.register_module()
class ModelNetDataset(Dataset):
    """Point-cloud classification dataset read from per-class text files.

    Files are discovered as ``<data_root>/<class>/<train|test>/*.txt``. Each
    file is parsed with ``np.loadtxt``; the first three columns are used as
    coordinates and the next three as normals. Parsed samples are kept in an
    in-memory record keyed by file name, which can be persisted to
    ``<data_root>/<record_name>.pth`` and reloaded on the next construction.

    Args:
        split: "train" reads the ``train`` sub-directory; "test" and "val"
            read ``test``. Any other value falls back to ``train`` with a
            printed warning.
        data_root: Directory containing one sub-directory per class.
        class_names: Sequence of class names; a class label is the position
            of the name in this sequence. Required.
        transform: Transform config handed to ``Compose``.
        num_points: Number of points kept per sample, or ``None`` to keep all.
        uniform_sampling: Use farthest point sampling to pick ``num_points``
            points; otherwise the first ``num_points`` rows are taken.
        save_record: Persist a newly built record to disk.
        test_mode: Return test-time-augmentation samples from ``__getitem__``.
        test_cfg: Config providing ``post_transform`` and ``aug_transform``;
            only read when ``test_mode`` is true.
        loop: Number of times the data list is repeated per epoch (forced to
            1 in test mode).

    Raises:
        TypeError: If ``class_names`` is ``None``.
    """

    # Samples with fewer points than this are padded by repeating their
    # first point.
    MIN_POINTS = 32

    def __init__(
        self,
        split="train",
        data_root="data/modelnet40",
        class_names=None,
        transform=None,
        num_points=8192,
        uniform_sampling=True,
        save_record=True,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        super().__init__()
        if class_names is None:
            # Same exception type the unguarded len(None) would raise, but
            # with a message that says what to do about it.
            raise TypeError(
                "class_names must be a sequence of category names, got None"
            )
        self.data_root = data_root
        self.class_names = dict(zip(class_names, range(len(class_names))))
        self.split = split
        self.num_point = num_points
        self.uniform_sampling = uniform_sampling
        self.transform = Compose(transform)
        self.loop = (
            loop if not test_mode else 1
        )  # force make loop = 1 while in test mode
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        if test_mode:
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

        # check, prepare record
        record_name = f"modelnet40_{self.split}"
        if num_points is not None:
            record_name += f"_{num_points}points"
            if uniform_sampling:
                record_name += "_uniform"
        record_path = os.path.join(self.data_root, f"{record_name}.pth")
        if os.path.isfile(record_path):
            logger.info(f"Loading record: {record_name} ...")
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # records that were written by this class from a trusted data_root.
            self.data = torch.load(record_path, weights_only=False)
        else:
            logger.info(f"Preparing record: {record_name} ...")
            # Must exist before get_data() runs, which looks samples up here.
            self.data = {}
            for idx in range(len(self.data_list)):
                data_name = self.get_data_name(idx)
                logger.info(f"Parsing data [{idx}/{len(self.data_list)}]: {data_name}")
                self.data[data_name] = self.get_data(idx)
            if save_record:
                torch.save(self.data, record_path)

    def get_data(self, idx):
        """Return the sample at ``idx`` (wrapped modulo the list length).

        A sample already present in the record is returned as a deep copy,
        exactly as stored. Otherwise the file is parsed, optionally
        sub-sampled to ``num_point`` points, and padded up to ``MIN_POINTS``.

        Returns:
            dict with ``coord`` (N, 3), ``normal`` (N, 3) and ``category``
            (a length-1 array holding the class index).

        Raises:
            ValueError: If the file contains no points.
            KeyError: If the class directory name is not in ``class_names``.
        """
        data_idx = idx % len(self.data_list)
        data_name = self.get_data_name(data_idx)
        if data_name in self.data:
            # Deep copy so transforms cannot mutate the cached arrays.
            return copy.deepcopy(self.data[data_name])

        data_path = self.data_list[data_idx]
        print(f"[DEBUG] Attempting to load: {data_path}")
        # ndmin=2 keeps a one-row file two-dimensional; without it the column
        # slicing below raises IndexError before the padding can help.
        data = np.loadtxt(data_path, ndmin=2).astype(np.float32)
        if data.shape[0] == 0:
            raise ValueError(f"No points found in {data_path}")

        if self.num_point is not None:
            if self.uniform_sampling:
                # NOTE(review): the tensors are created on CPU; confirm the
                # installed pointops build accepts CPU inputs.
                with torch.no_grad():
                    mask = pointops.farthest_point_sampling(
                        torch.tensor(data).float(),
                        torch.tensor([len(data)]).long(),
                        torch.tensor([self.num_point]).long(),
                    )
                data = data[mask.cpu()]
            else:
                data = data[: self.num_point]

        coord, normal = data[:, 0:3], data[:, 3:6]
        # The class is the directory two levels above the file:
        # <data_root>/<class>/<train|test>/<file>.txt
        data_shape = os.path.basename(os.path.dirname(os.path.dirname(data_path)))
        category = np.array([self.class_names[data_shape]])

        # Safeguard: guarantee at least MIN_POINTS points per sample by
        # repeating the first point, appended in a single allocation.
        missing = self.MIN_POINTS - len(coord)
        if missing > 0:
            print(
                f"⚠️ [SAFEGUARD] Sample {data_name} has only {len(coord)} points. "
                f"Padding to {self.MIN_POINTS}."
            )
            coord = np.concatenate(
                [coord, np.repeat(coord[:1], missing, axis=0)], axis=0
            )
            normal = np.concatenate(
                [normal, np.repeat(normal[:1], missing, axis=0)], axis=0
            )

        return dict(coord=coord, normal=normal, category=category)

    def get_data_list(self):
        """Collect the ``.txt`` sample paths for the configured split.

        Every directory directly under ``data_root`` is treated as a class;
        classes without the split sub-directory are skipped with a warning.
        Directories and files are visited in sorted order so the index to
        sample mapping does not depend on filesystem enumeration order.

        Returns:
            list of file paths.
        """
        print(f"[DEBUG] ModelNetDataset - data_root: {self.data_root}")
        print(f"[DEBUG] ModelNetDataset - split: {self.split}")

        # Map the split name to the sub-directory that holds its files.
        if self.split == "train":
            subdir_name = "train"
        elif self.split in ("test", "val"):
            subdir_name = "test"
        else:
            subdir_name = "train"
            print(
                f"[WARNING] Unrecognized split '{self.split}'. "
                f"Defaulting to '{subdir_name}' subdirectory."
            )

        # Always scan every class directory found under data_root.
        class_names = sorted(
            name
            for name in os.listdir(self.data_root)
            if os.path.isdir(os.path.join(self.data_root, name))
        )
        print(f"[DEBUG] Found {len(class_names)} classes: {class_names[:5]}")

        data_list = []
        for class_name in class_names:
            # e.g. .../ModelNet40/airplane/train/
            class_path = os.path.join(self.data_root, class_name, subdir_name)
            if not os.path.exists(class_path):
                print(f"⚠️ Warning: Directory not found: {class_path}")
                continue

            # Full paths of all .txt files in this class/split directory.
            files = sorted(glob.glob(os.path.join(class_path, "*.txt")))
            data_list.extend(files)
            print(f"[DEBUG] Found {len(files)} files in {class_path}")

        print(f"[DEBUG] Total files found: {len(data_list)}")
        return data_list

    def get_data_name(self, idx):
        """Return the file name (without directories) of the sample at ``idx``."""
        data_idx = idx % len(self.data_list)
        return os.path.basename(self.data_list[data_idx])

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        else:
            return self.prepare_train_data(idx)

    def __len__(self):
        return len(self.data_list) * self.loop

    def prepare_train_data(self, idx):
        """Load the sample at ``idx`` and apply the training transform."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        return data_dict

    def prepare_test_data(self, idx):
        """Build the test-time-augmentation bundle for the sample at ``idx``.

        Returns:
            dict with ``voting_list`` (one post-transformed copy of the sample
            per entry in ``aug_transform``), ``category`` and ``name``.
        """
        assert idx < len(self.data_list)
        data_dict = self.get_data(idx)
        # The label is carried outside the transform pipeline.
        category = data_dict.pop("category")
        data_dict = self.transform(data_dict)
        # Run every augmentation first, then every post-transform, so the
        # call order matches what any stateful/random transform expects.
        augmented = [aug(deepcopy(data_dict)) for aug in self.aug_transform]
        data_dict_list = [self.post_transform(item) for item in augmented]
        data_dict = dict(
            voting_list=data_dict_list,
            category=category,
            name=self.get_data_name(idx),
        )
        return data_dict