"""Build the pre-batched train and test datasets for one dataset selected from a config file."""
import sys
import os
# Put the current working directory on the import path so that the
# `root_gnn_base` package imported below can be resolved. This presumably
# assumes the script is launched from the directory that contains
# `root_gnn_base` — TODO confirm. It must stay above those imports.
file_path = os.getcwd()
sys.path.append(file_path)
import root_gnn_base.utils as utils
import argparse
from root_gnn_base.batched_dataset import PreBatchedDataset
from root_gnn_base.batched_dataset import LazyPreBatchedDataset
def main():
    """Construct the pre-batched train and test datasets for one configured dataset.

    Command-line arguments:
        --config        Path handed to ``utils.load_config``.
        --dataset       Key under ``config['Datasets']`` selecting the dataset.
        --chunk         Chunk index (default 0). Without ``--shuffle_mode`` it
                        is used as ``process_chunks`` when building the dataset.
                        It is always passed as ``chunkno`` to the pre-batched
                        outputs.
        --shuffle_mode  Build the dataset without restricting it to ``--chunk``.
        --no_drop_last  Pass ``drop_last=False`` (the default is True).
                        ``--drop_last`` is kept as a legacy alias with the SAME
                        effect, i.e. it also disables drop_last.

    For each fold in ("train", "test"), the function constructs either a
    ``LazyPreBatchedDataset`` (when the dataset ``class`` is
    ``LazyMultiLabelDataset``) or a ``PreBatchedDataset``. The constructed
    objects are discarded. Presumably their constructors persist the batched
    data as a side effect — TODO confirm.
    """
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--config', type=str, required=True)
    add_arg('--dataset', type=str, required=True)
    add_arg('--chunk', type=int, default=0)
    add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
    # `--drop_last` is a store_false switch: passing it DISABLES drop_last,
    # which its name does not suggest. It is kept for backward compatibility,
    # and `--no_drop_last` is added as the clearly named spelling. Both write
    # to the same `drop_last` destination.
    add_arg(
        '--drop_last', '--no_drop_last',
        dest='drop_last',
        action='store_false',
        help='Set drop_last to False if the flag is provided. Defaults to True. '
             '--no_drop_last is an alias with the same effect.',
    )
    args = parser.parse_args()

    config = utils.load_config(args.config)
    dset_config = config['Datasets'][args.dataset]
    batch_size = config['Training']['batch_size']

    # Outside shuffle mode, only the requested chunk is processed.
    if args.shuffle_mode:
        dset = utils.buildFromConfig(dset_config)
    else:
        dset = utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk]})

    # A per-dataset batch size overrides the global training batch size.
    if 'batch_size' in dset_config:
        batch_size = dset_config['batch_size']
    shuffle_chunks = dset_config.get('shuffle_chunks', 10)
    padding_mode = dset_config.get('padding_mode', 'STEPS')
    fold_conf = dset_config["folding"]
    print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")

    is_lazy = dset_config["class"] == "LazyMultiLabelDataset"
    # Options shared by every pre-batched output, looked up once instead of
    # being repeated in each constructor call.
    common_kwargs = {
        'chunks': shuffle_chunks,
        'chunkno': args.chunk,
        'padding_mode': padding_mode,
        'drop_last': args.drop_last,
        'hidden_size': config['Model']['args']['hid_size'],
    }
    for fold in ("train", "test"):
        mask_fn = utils.fold_selection(fold_conf, fold)
        suffix = utils.fold_selection_name(fold_conf, fold)
        if is_lazy:
            LazyPreBatchedDataset(
                start_dataset=dset,
                batch_size=batch_size,
                mask_fn=mask_fn,
                suffix=suffix,
                **common_kwargs,
            )
        else:
            # PreBatchedDataset receives the dataset, batch size and mask
            # positionally, exactly as in the original call form.
            PreBatchedDataset(dset, batch_size, mask_fn, suffix=suffix, **common_kwargs)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()