| # File: /home/lab/LAD/bitpointV3/pointcept_framework/scripts/generate_train_test_split.py | |
| import os | |
def generate_split_files(data_root, train_ratio=0.8):
    """Generate ``modelnet40_train.txt`` and ``modelnet40_test.txt`` from a folder layout.

    Every immediate sub-directory of *data_root* is treated as one class. The
    ``.txt`` files inside each class directory are sorted by name. The first
    ``int(len(files) * train_ratio)`` of them go to the training list and the
    rest go to the test list.

    Each entry is written as ``<class_name>_<stem>``. If a file stem already
    starts with ``<class_name>_`` (e.g. ``airplane_0001.txt`` inside
    ``airplane/``), the stem is used as-is instead of being prefixed a second
    time.

    Both split files are written into *data_root* (overwriting existing ones),
    one sample name per line, with no trailing newline.

    Args:
        data_root: Directory containing one sub-directory per class.
        train_ratio: Fraction of each class's files assigned to the training
            split; must lie in ``[0, 1]``.

    Raises:
        ValueError: If *train_ratio* is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    suffix = ".txt"
    train_list = []
    test_list = []
    # os.listdir order is arbitrary; sort so the generated splits are
    # reproducible across runs and filesystems.
    for class_name in sorted(os.listdir(data_root)):
        class_dir = os.path.join(data_root, class_name)
        if not os.path.isdir(class_dir):
            continue
        # Strip only the trailing extension (str.replace would also drop a
        # ".txt" occurring elsewhere in the name). Sorting keeps the split
        # deterministic.
        stems = sorted(
            name[: -len(suffix)]
            for name in os.listdir(class_dir)
            if name.endswith(suffix)
        )
        # Avoid doubling the class prefix when the file name already has it.
        prefix = f"{class_name}_"
        samples = [s if s.startswith(prefix) else prefix + s for s in stems]

        # Per-class split: first part -> train, remainder -> test.
        split_idx = int(len(samples) * train_ratio)
        train_list.extend(samples[:split_idx])
        test_list.extend(samples[split_idx:])

    for out_name, entries in (
        ("modelnet40_train.txt", train_list),
        ("modelnet40_test.txt", test_list),
    ):
        with open(os.path.join(data_root, out_name), "w", encoding="utf-8") as fh:
            fh.write("\n".join(entries))

    print(f"Generated {len(train_list)} training samples and {len(test_list)} testing samples.")
if __name__ == "__main__":
    # Local import: only needed when run as a script.
    import argparse

    # Default dataset root (machine-specific); override with --data-root.
    DATA_ROOT = "/home/lab/LAD/bitpointV3/data/modelnet40_normal_resampled"

    parser = argparse.ArgumentParser(
        description=(
            "Generate modelnet40_train.txt and modelnet40_test.txt from a "
            "class-per-folder dataset directory."
        )
    )
    parser.add_argument(
        "--data-root",
        default=DATA_ROOT,
        help="Directory containing one sub-directory per class (default: %(default)s).",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Fraction of each class assigned to the training split (default: %(default)s).",
    )
    args = parser.parse_args()
    generate_split_files(args.data_root, train_ratio=args.train_ratio)