File size: 2,807 Bytes
23c93db
d1b7df5
 
23c93db
 
 
 
 
 
 
 
 
 
 
 
 
 
1837337
23c93db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f251d7d
 
23c93db
 
f251d7d
 
23c93db
 
f251d7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import sys
import os
# Put the current working directory on the import path so that project
# packages (e.g. ``root_gnn_base`` imported below) can be found. This must run
# before those imports. Because the directory is *appended*, entries already on
# sys.path take precedence. NOTE(review): presumably the script is meant to be
# launched from the repository root — confirm.
file_path = os.getcwd()
sys.path.append(file_path)

import root_gnn_base.utils as utils
import argparse
from root_gnn_base.batched_dataset import PreBatchedDataset
from root_gnn_base.batched_dataset import LazyPreBatchedDataset

def main():
    """Build a dataset chunk, or pre-batch the whole dataset into train/test folds.

    Command-line arguments:
      --config        path handed to ``utils.load_config``.
      --dataset       key selecting the dataset entry in ``config['Datasets']``.
      --chunk         chunk index (default 0). In plain mode it is passed as
                      ``process_chunks=[chunk]``; in shuffle mode as ``chunkno``.
      --shuffle_mode  build the full dataset and construct pre-batched train and
                      test datasets instead of building a single chunk.
      --drop_last     ``store_false`` flag: ``drop_last`` is True by default and
                      becomes False when the flag is given.

    Config keys read:
      Training.batch_size            default batch size.
      Datasets.<name>                dataset config passed to ``utils.buildFromConfig``.
      Datasets.<name>.batch_size     optional override of the batch size (shuffle mode).
      Datasets.<name>.shuffle_chunks optional, default 10 (shuffle mode).
      Datasets.<name>.padding_mode   optional, default 'STEPS' (shuffle mode).
      Datasets.<name>.folding        required in shuffle mode.
      Datasets.<name>.class          required in shuffle mode; selects the pre-batcher.
      Model.args.hid_size            forwarded as ``hidden_size`` (shuffle mode).

    Returns None. The pre-batched dataset objects are constructed only for their
    side effects; the instances are discarded. NOTE(review): presumably their
    constructors write the batched data to disk — confirm.
    """
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--config', type=str, required=True)
    add_arg('--dataset', type=str, required=True)
    add_arg('--chunk', type=int, default=0)
    add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
    add_arg('--drop_last', action='store_false', help='Set drop_last to False if the flag is provided. Defaults to True.')
    args = parser.parse_args()

    config = utils.load_config(args.config)
    dset_config = config['Datasets'][args.dataset]
    batch_size = config['Training']['batch_size']

    if not args.shuffle_mode:
        # Plain mode: only build the dataset restricted to the requested chunk;
        # nothing else is done with the result in this script.
        utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk,]})
        return

    dset = utils.buildFromConfig(dset_config)
    # A per-dataset batch size, when present, overrides the global one.
    batch_size = dset_config.get('batch_size', batch_size)

    shuffle_chunks = dset_config.get('shuffle_chunks', 10)
    padding_mode = dset_config.get('padding_mode', 'STEPS')
    fold_conf = dset_config["folding"]
    print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {args.chunk}, padding_mode = {padding_mode}")

    hidden_size = config['Model']['args']['hid_size']
    # Lazy datasets need the lazy pre-batcher; every other class uses the eager one.
    use_lazy = dset_config["class"] == "LazyMultiLabelDataset"

    # Build one pre-batched dataset per fold, train first and then test, with
    # identical settings apart from the fold mask and the file suffix.
    for fold in ("train", "test"):
        mask_fn = utils.fold_selection(fold_conf, fold)
        shared_kwargs = {
            'suffix': utils.fold_selection_name(fold_conf, fold),
            'chunks': shuffle_chunks,
            'chunkno': args.chunk,
            'padding_mode': padding_mode,
            'drop_last': args.drop_last,
            'hidden_size': hidden_size,
        }
        if use_lazy:
            LazyPreBatchedDataset(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, **shared_kwargs)
        else:
            # The first three arguments are passed positionally, as in the
            # original calls; their parameter names are not visible here.
            PreBatchedDataset(dset, batch_size, mask_fn, **shared_kwargs)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()