# YYYYYYUUU's picture
# Backup FULL local core code incl. libs/ CUDA ext + all configs
# 3499c27 verified
# Raw
# History Blame Contribute Delete
# 7.32 kB
"""
ModelNet40 Dataset
Get sampled point clouds of ModelNet40 (XYZ and normal from mesh, 10k points per shape)
at "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip"
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
import glob # 🛠️ 关键修复:添加这一行
import os
import numpy as np
import copy
import pointops
import torch
from torch.utils.data import Dataset
from copy import deepcopy
from pointcept.utils.logger import get_root_logger
from .builder import DATASETS
from .transform import Compose
@DATASETS.register_module()
class ModelNetDataset(Dataset):
    """Shape-classification dataset read from per-shape ``.txt`` point files.

    Files are discovered as ``<data_root>/<class_dir>/<train|test>/*.txt``.
    Each file is parsed with ``np.loadtxt``; columns 0-2 are returned as
    ``coord`` and columns 3-5 as ``normal``. The class label is looked up
    from the name of the ``<class_dir>`` directory via ``class_names``.

    Parsed samples are kept in ``self.data`` (keyed by file basename) and can
    be persisted to ``<data_root>/modelnet40_<split>[..].pth`` so later runs
    skip parsing and sampling.

    NOTE(review): the cache is keyed by basename only, so two files with the
    same name in different class directories would share one cache entry --
    confirm that file names are unique across classes.
    """

    # Samples with fewer points than this are padded by repeating point 0.
    MIN_POINTS = 32

    def __init__(
        self,
        split="train",
        data_root="data/modelnet40",
        class_names=None,
        transform=None,
        num_points=8192,
        uniform_sampling=True,
        save_record=True,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        """Index the split, then load or build the cached sample record.

        Args:
            split: "train" reads the ``train`` sub-directories; "test" and
                "val" both read ``test``; anything else falls back to
                ``train`` with a warning.
            data_root: Directory that contains one sub-directory per class.
            class_names: Sequence of class directory names; the position of a
                name in the sequence is its integer label. Required.
            transform: Transform config handed to ``Compose``.
            num_points: Number of points kept per shape, or None to keep all.
            uniform_sampling: If True, points are selected with
                ``pointops.farthest_point_sampling``; otherwise the first
                ``num_points`` rows are kept.
            save_record: If True, a freshly built record is written to disk.
            test_mode: If True, ``__getitem__`` returns test-time voting
                samples and ``loop`` is forced to 1.
            test_cfg: Config providing ``post_transform`` and
                ``aug_transform``; only read when ``test_mode`` is True.
            loop: Number of times the file list is repeated in ``__len__``.

        Raises:
            TypeError: If ``class_names`` is None.
        """
        super().__init__()
        if class_names is None:
            # Previously this failed inside zip() with an opaque message.
            raise TypeError(
                "ModelNetDataset requires `class_names` (a sequence of class "
                "directory names), got None"
            )
        self.data_root = data_root
        self.class_names = {name: label for label, name in enumerate(class_names)}
        self.split = split
        self.num_point = num_points
        self.uniform_sampling = uniform_sampling
        self.transform = Compose(transform)
        # force make loop = 1 while in test mode
        self.loop = 1 if test_mode else loop
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        if test_mode:
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

        # check, prepare record
        record_name = f"modelnet40_{self.split}"
        if num_points is not None:
            record_name += f"_{num_points}points"
        if uniform_sampling:
            record_name += "_uniform"
        record_path = os.path.join(self.data_root, f"{record_name}.pth")
        if os.path.isfile(record_path):
            logger.info(f"Loading record: {record_name} ...")
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load record files that were produced by this class.
            self.data = torch.load(record_path, weights_only=False)
        else:
            logger.info(f"Preparing record: {record_name} ...")
            # Must exist before get_data() runs, which consults this cache.
            self.data = {}
            total = len(self.data_list)
            for idx in range(total):
                data_name = self.get_data_name(idx)
                logger.info(f"Parsing data [{idx}/{total}]: {data_name}")
                self.data[data_name] = self.get_data(idx)
            if save_record:
                torch.save(self.data, record_path)

    def get_data(self, idx):
        """Return the sample for ``idx`` as a dict of numpy arrays.

        ``idx`` is wrapped modulo the number of files. A cached sample is
        returned as a deep copy; otherwise the file is parsed, optionally
        down-sampled, labelled and padded to at least ``MIN_POINTS`` points.

        Returns:
            dict with ``coord`` (columns 0-2), ``normal`` (columns 3-5) and
            ``category`` (one-element integer array).

        Raises:
            ValueError: If the file contains no points.
            KeyError: If the file's class directory is not in ``class_names``.
        """
        data_idx = idx % len(self.data_list)
        data_name = self.get_data_name(data_idx)
        if data_name in self.data:
            return copy.deepcopy(self.data[data_name])

        logger = get_root_logger()
        data_path = self.data_list[data_idx]
        logger.debug(f"[DEBUG] Attempting to load: {data_path}")
        # ndmin=2 keeps a single-line file two-dimensional so the column
        # slicing below still works.
        data = np.loadtxt(data_path, ndmin=2).astype(np.float32)
        if len(data) == 0:
            # Nothing to sample or to pad from; fail loudly instead of
            # looping forever in the padding step.
            raise ValueError(f"No points found in {data_path}")

        if self.num_point is not None:
            if self.uniform_sampling:
                with torch.no_grad():
                    mask = pointops.farthest_point_sampling(
                        torch.tensor(data).float(),
                        torch.tensor([len(data)]).long(),
                        torch.tensor([self.num_point]).long(),
                    )
                data = data[mask.cpu()]
            else:
                data = data[: self.num_point]
        coord, normal = data[:, 0:3], data[:, 3:6]

        # <data_root>/<class>/<split>/<file>.txt -> <class>
        data_shape = os.path.basename(os.path.dirname(os.path.dirname(data_path)))
        if data_shape not in self.class_names:
            raise KeyError(
                f"Class directory '{data_shape}' (from {data_path}) is not "
                "listed in class_names"
            )
        category = np.array([self.class_names[data_shape]])

        # Safeguard: guarantee at least MIN_POINTS points by repeating the
        # first point as many times as needed, in a single allocation.
        min_points = self.MIN_POINTS
        missing = min_points - len(coord)
        if missing > 0:
            logger.warning(
                f"⚠️ [SAFEGUARD] Sample {data_name} has only {len(coord)} points. Padding to {min_points}."
            )
            coord = np.concatenate(
                [coord, np.repeat(coord[:1], missing, axis=0)], axis=0
            )
            normal = np.concatenate(
                [normal, np.repeat(normal[:1], missing, axis=0)], axis=0
            )
        return dict(coord=coord, normal=normal, category=category)

    def get_data_list(self):
        """Collect the full paths of every ``.txt`` file in the active split.

        Every sub-directory of ``data_root`` is treated as a class directory.
        Directories and files are visited in sorted order so the resulting
        index is the same on every run and machine.

        Returns:
            list of file paths (possibly empty).
        """
        logger = get_root_logger()
        logger.debug(f"[DEBUG] ModelNetDataset - data_root: {self.data_root}")
        logger.debug(f"[DEBUG] ModelNetDataset - split: {self.split}")

        # Pick the 'train' or 'test' sub-directory from the split name.
        if self.split == "train":
            subdir_name = "train"
        elif self.split in ("test", "val"):
            subdir_name = "test"
        else:
            subdir_name = "train"
            logger.warning(
                f"[WARNING] Unrecognized split '{self.split}'. Defaulting to '{subdir_name}' subdirectory."
            )

        # Always scan every class directory below data_root.
        class_names = sorted(
            name
            for name in os.listdir(self.data_root)
            if os.path.isdir(os.path.join(self.data_root, name))
        )
        logger.debug(f"[DEBUG] Found {len(class_names)} classes: {class_names[:5]}")

        data_list = []
        for class_name in class_names:
            # e.g. <data_root>/airplane/train
            class_path = os.path.join(self.data_root, class_name, subdir_name)
            if not os.path.exists(class_path):
                logger.warning(f"⚠️ Warning: Directory not found: {class_path}")
                continue
            # glob returns full paths in arbitrary order, hence the sort.
            files = sorted(glob.glob(os.path.join(class_path, "*.txt")))
            data_list.extend(files)
            logger.debug(f"[DEBUG] Found {len(files)} files in {class_path}")
        logger.debug(f"[DEBUG] Total files found: {len(data_list)}")
        return data_list

    def get_data_name(self, idx):
        """Return the file name (without directories) of sample ``idx``.

        ``idx`` is wrapped modulo the number of files.
        """
        data_idx = idx % len(self.data_list)
        return os.path.basename(self.data_list[data_idx])

    def __getitem__(self, idx):
        """Return a test-time voting sample in test mode, else a training sample."""
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __len__(self):
        """Return the number of files multiplied by ``loop``."""
        return len(self.data_list) * self.loop

    def prepare_train_data(self, idx):
        """Load sample ``idx`` and apply the configured transform to it."""
        return self.transform(self.get_data(idx))

    def prepare_test_data(self, idx):
        """Build the voting sample for ``idx``.

        The base transform is applied once with the label removed; every
        entry of ``aug_transform`` is then applied to its own deep copy,
        followed by ``post_transform``.

        Returns:
            dict with ``voting_list`` (one transformed dict per augmentation),
            ``category`` and ``name`` (the file basename).
        """
        assert idx < len(self.data_list)
        data_dict = self.get_data(idx)
        category = data_dict.pop("category")
        data_dict = self.transform(data_dict)
        # Two passes on purpose: all augmentations run before any
        # post-transform, matching the original call order.
        augmented = [aug(deepcopy(data_dict)) for aug in self.aug_transform]
        data_dict_list = [self.post_transform(item) for item in augmented]
        return dict(
            voting_list=data_dict_list,
            category=category,
            name=self.get_data_name(idx),
        )