| |
|
|
| import os |
|
|
def generate_split_files(data_root, train_ratio=0.8):
    """Generate ``modelnet40_train.txt`` and ``modelnet40_test.txt`` from the folder layout.

    Every sub-directory of ``data_root`` is treated as one class. The ``.txt``
    files inside it are sorted by name (extension removed); the first
    ``int(count * train_ratio)`` go to the training list and the rest to the
    test list. Each entry is written as ``<class_name>_<file stem>``, one per
    line, into two files placed directly in ``data_root``.

    Args:
        data_root: Directory containing one sub-directory per class.
        train_ratio: Fraction of each class's files assigned to training.

    Returns:
        None. The two split files are written and a summary is printed.
    """
    suffix = '.txt'
    train_list = []
    test_list = []

    # os.listdir() returns entries in arbitrary order; sort the classes so the
    # generated files are identical from run to run and machine to machine.
    for class_name in sorted(os.listdir(data_root)):
        class_dir = os.path.join(data_root, class_name)
        if not os.path.isdir(class_dir):
            # Skips plain files, including split files from an earlier run.
            continue

        # Strip only the trailing extension; str.replace() would also remove
        # any '.txt' occurring in the middle of a file name.
        files = sorted(
            name[:-len(suffix)]
            for name in os.listdir(class_dir)
            if name.endswith(suffix)
        )

        # Per-class split so every class keeps the same train/test proportion.
        split_idx = int(len(files) * train_ratio)

        # NOTE(review): if the files are already named '<class>_<id>.txt', the
        # class prefix ends up duplicated in the entry -- confirm against the
        # loader that consumes these lists.
        train_list.extend(f"{class_name}_{stem}" for stem in files[:split_idx])
        test_list.extend(f"{class_name}_{stem}" for stem in files[split_idx:])

    # Explicit encoding keeps the output independent of the platform default.
    with open(os.path.join(data_root, "modelnet40_train.txt"), 'w', encoding='utf-8') as f:
        f.write('\n'.join(train_list))

    with open(os.path.join(data_root, "modelnet40_test.txt"), 'w', encoding='utf-8') as f:
        f.write('\n'.join(test_list))

    print(f"Generated {len(train_list)} training samples and {len(test_list)} testing samples.")
|
|
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    # Default dataset location, used when no path is given on the command line.
    DATA_ROOT = "/home/lab/LAD/bitpointV3/data/modelnet40_normal_resampled"

    parser = argparse.ArgumentParser(
        description="Generate modelnet40_train.txt / modelnet40_test.txt from a class-per-folder dataset."
    )
    parser.add_argument(
        "data_root",
        nargs="?",
        default=DATA_ROOT,
        help="dataset directory with one sub-directory per class (default: %(default)s)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="fraction of each class assigned to the training split (default: %(default)s)",
    )
    args = parser.parse_args()

    generate_split_files(args.data_root, train_ratio=args.train_ratio)