biptv3 / server_worktree /scripts /generate_train_test_split0915.py
YYYYYYUUU's picture
Backup FULL poplab work tree (source, configs, libs, scripts) excl. .pth
08cde47 verified
Raw
History Blame Contribute Delete
1.4 kB
# File: /home/lab/LAD/bitpointV3/pointcept_framework/scripts/generate_train_test_split.py
import os
def generate_split_files(data_root, train_ratio=0.8):
    """Generate train/test split list files from a class-per-folder layout.

    Every immediate subdirectory of *data_root* is treated as one class.
    Within each class, the ``.txt`` files are sorted by name. The first
    ``int(len(files) * train_ratio)`` of them go to the training list and the
    rest go to the test list.

    Each list entry is ``<class_name>_<stem>``, where *stem* is the file name
    without its trailing ``.txt``. If the stem already starts with
    ``<class_name>_`` (e.g. ``airplane/airplane_0001.txt``), it is used
    unchanged so the class name is not duplicated.

    The lists are written, newline-separated, to ``modelnet40_train.txt`` and
    ``modelnet40_test.txt`` inside *data_root*.

    Args:
        data_root: Directory containing one subdirectory per class.
        train_ratio: Fraction of each class's files assigned to training.
            Must lie within [0, 1].

    Returns:
        A ``(train_list, test_list)`` tuple holding the entries that were
        written.

    Raises:
        ValueError: If *train_ratio* is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    suffix = ".txt"
    train_list = []
    test_list = []

    # os.listdir returns entries in arbitrary order; sort the class
    # directories as well so the output is reproducible across filesystems.
    for class_name in sorted(os.listdir(data_root)):
        class_dir = os.path.join(data_root, class_name)
        if not os.path.isdir(class_dir):
            continue

        # Strip only the trailing ".txt" (str.replace would also remove any
        # ".txt" occurring elsewhere in the name). Sorting keeps the split
        # deterministic.
        stems = sorted(
            name[: -len(suffix)]
            for name in os.listdir(class_dir)
            if name.endswith(suffix)
        )

        # Avoid "airplane_airplane_0001" when files are already named
        # "<class>_<id>.txt". Presumably downstream loaders recover the class
        # from the text before the last "_"; TODO confirm against the
        # consumer of these lists.
        prefix = f"{class_name}_"
        entries = [s if s.startswith(prefix) else prefix + s for s in stems]

        split_idx = int(len(entries) * train_ratio)
        train_list.extend(entries[:split_idx])
        test_list.extend(entries[split_idx:])

    # NOTE(review): mode "w" overwrites any existing split files with these
    # names in data_root; if the dataset ships an official split under the
    # same names, it will be replaced. Confirm this is intended.
    for split_name, entries in (("train", train_list), ("test", test_list)):
        out_path = os.path.join(data_root, f"modelnet40_{split_name}.txt")
        with open(out_path, "w", encoding="utf-8") as fh:
            fh.write("\n".join(entries))

    print(f"Generated {len(train_list)} training samples and {len(test_list)} testing samples.")
    return train_list, test_list
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded invocation on the lab server;
    # both values can be overridden from the command line.
    DATA_ROOT = "/home/lab/LAD/bitpointV3/data/modelnet40_normal_resampled"

    parser = argparse.ArgumentParser(
        description="Write modelnet40_train.txt / modelnet40_test.txt from a class-per-folder dataset."
    )
    parser.add_argument(
        "data_root",
        nargs="?",
        default=DATA_ROOT,
        help="dataset root containing one subdirectory per class (default: %(default)s)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="fraction of each class's files assigned to training (default: %(default)s)",
    )
    args = parser.parse_args()

    generate_split_files(args.data_root, train_ratio=args.train_ratio)