File size: 5,458 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import copy
import logging
import os
import shlex
import sys
from pathlib import Path

import numpy as np

# Repository root = the parent of the directory that contains this file.
# It is put at the front of sys.path (only if not already listed), presumably
# so the project-local imports that follow resolve when this file is run
# directly as a script -- this must therefore stay above those imports.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from tools import data
from dataloader.dataset_generator import combine_dataset_dict, generate_dataset, generate_dataset_dict

# Root logger setup: show INFO and above, prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def main():
    """Parse command-line options and run the selected dataset action.

    Options are read from the command line.  If the first command-line
    argument starts with ``@``, the rest of that argument is treated as the
    path of a text file whose contents are split with ``shlex.split`` and
    parsed instead; any further command-line arguments are ignored in that
    case.

    After parsing, the CAD data is loaded through ``data.load_data`` (with
    normals when ``--compute_normals`` is given), the output directory is
    resolved (relative ``--save_path`` values are taken relative to the
    repository root), and one of ``generate_dataset``,
    ``generate_dataset_dict`` or ``combine_dataset_dict`` is called according
    to ``--action``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments
            (including when the script is started with no arguments at all).
        OSError: if an ``@file`` argument file cannot be opened.
    """
    # Set up the argument parser
    parser = argparse.ArgumentParser(description='Automate dataset generation and processing.')

    # Define arguments (change these as needed)
    parser.add_argument('--pcdPath', type=str, required=True, help='Path to the PCD file')
    parser.add_argument('--cadPath', type=str, required=True, help='Path to the CAD file')
    parser.add_argument(
        '--action',
        type=str,
        choices=['generate_dataset', 'generate_dataset_dict', 'combine_dataset_dict'],
        required=True,
        help='Action to perform',
    )
    parser.add_argument('--compute_normals', action='store_true', help='Flag to compute normals')
    parser.add_argument('--every_k_points', type=int, default=1, help='Sampling rate for points')
    # NOTE(review): --save is accepted but never read in this function.
    parser.add_argument('--save', action='store_true', help='Flag to save the generated dataset')
    parser.add_argument(
        '--save_path',
        type=str,
        default='data/simulators',
        help='Directory to save generated datasets (relative to repo root if not absolute)',
    )
    parser.add_argument('--name', type=str, required=True, help='Name identifier for the dataset (e.g., teeth, cube, etc.)')

    # Additional parameters for dataset generation (change these as needed)
    parser.add_argument('--num_transformation', type=int, default=50, help='Number of transformations')
    parser.add_argument('--angles', type=int, nargs='+', default=list(range(0, 360, 10)), help='Rotation angles')
    parser.add_argument('--translation_range', type=float, nargs=2, default=(-1, 1), help='Translation range')
    parser.add_argument('--dataset_size', type=int, default=400, help='Size of the dataset to generate')
    parser.add_argument('--index', type=int, default=0, help='Index for dataset generation')
    parser.add_argument('--noise_level', type=float, default=0, help='Noise level')
    parser.add_argument('--outlier_level', type=float, default=0, help='Outlier level')
    parser.add_argument('--outlier_bounds', type=float, nargs=2, default=(-10, 10), help='Outlier bounds')
    parser.add_argument('--occ_level', type=float, default=0, help='Occlusion level')

    # Parse the arguments.
    # Check if an argument file is being used.  The length guard matters:
    # with no arguments at all, indexing sys.argv[1] would raise IndexError
    # before argparse gets the chance to print its usage message.
    if len(sys.argv) > 1 and sys.argv[1].startswith('@'):
        args_file = sys.argv[1][1:]  # Strip the '@' from the filename
        with open(args_file, 'r') as file:
            # Read and split arguments from the file
            args = parser.parse_args(shlex.split(file.read()))
    else:
        args = parser.parse_args()

    # Print out the arguments to verify
    print(vars(args))

    # Load the data.  NumPy's global RNG is seeded with a fixed value first.
    np.random.seed(42)
    if args.compute_normals:
        _, cad, _, cad_normals = data.load_data(
            args.pcdPath,
            args.cadPath,
            every_k_points=args.every_k_points,
            same_length=True,
            compute_normals=True,
        )
        suffix = '_with_normals'
    else:
        _, cad = data.load_data(
            args.pcdPath,
            args.cadPath,
            every_k_points=args.every_k_points,
            same_length=True,
        )
        cad_normals = None
        suffix = ''
    # Work on a deep copy so the loaded CAD object itself is not handed out.
    source = copy.deepcopy(cad)

    # Resolve the output directory; relative paths are anchored at the repo
    # root.  The result always ends with a path separator because the file
    # names below are appended to it by plain string formatting.
    rp = Path(args.save_path)
    if not rp.is_absolute():
        rp = _REPO_ROOT / args.save_path
    ROOT_DIR = str(rp.resolve())
    if not ROOT_DIR.endswith(os.sep):
        ROOT_DIR += os.sep

    # Perform the selected action
    if args.action == 'generate_dataset':
        logging.info('Generating dataset...')
        generate_dataset(
            source,
            args.pcdPath,
            args.cadPath,
            args.num_transformation,
            args.angles,
            args.translation_range,
            args.index,
            args.noise_level,
            args.outlier_level,
            args.outlier_bounds,
            args.occ_level,
            save_dir=ROOT_DIR,
        )

    elif args.action == 'generate_dataset_dict':
        logging.info('Generating dataset dictionary...')
        output_train_file = f'{ROOT_DIR}data_dict_train_{args.name}{suffix}.pkl'
        output_test_file = f'{ROOT_DIR}data_dict_test_{args.name}{suffix}.pkl'
        generate_dataset_dict(source, args.dataset_size, args.index, output_train_file, output_test_file, cad_normals)

    elif args.action == 'combine_dataset_dict':
        logging.info('Combining dataset dictionaries...')
        # NOTE(review): the inputs are hard-coded to the "teeth" dataset and
        # do not use --name; extend these lists to combine more datasets.
        train_files = [
            f'{ROOT_DIR}data_dict_train_teeth{suffix}.pkl'
        ]

        test_files = [
            f'{ROOT_DIR}data_dict_test_teeth{suffix}.pkl'
        ]

        # NOTE(review): these names expand to "data_dict_train_.pkl" (no
        # normals) or "data_dict_train__with_normals.pkl" -- confirm that the
        # trailing/double underscore is what downstream consumers expect.
        output_train_file = f'{ROOT_DIR}data_dict_train_{suffix}.pkl'
        output_test_file = f'{ROOT_DIR}data_dict_test_{suffix}.pkl'

        combine_dataset_dict(train_files, test_files, output_train_file, output_test_file)

    else:
        # Defensive only: argparse's `choices` already rejects other values.
        logging.warning('No valid action selected.')

# Run the command-line entry point only when this file is executed as a
# script; importing it as a module does not call main().
if __name__ == '__main__':
    main()