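"""Preprocess jet datasets: iterate each input file through SimpleIterDataset,
concatenate the events, and serialize them as per-key .npy arrays plus a
pickled metadata file under the output directory."""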
import argparse
import os
import os.path as osp
import pickle
import sys
from pathlib import Path
from time import time

import numpy as np
import torch  # only used by the commented-out torch.save alternative below

from src.dataset.dataset import SimpleIterDataset
from src.dataset.functions_data import concat_events
from src.utils.paths import get_path
from src.utils.utils import to_filelist
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Name of the input data directory (resolved via get_path)")
parser.add_argument("--output", type=str, help="Name of the output directory for the preprocessed data")
parser.add_argument("--overwrite", action="store_true", help="Reprocess datasets even if the output already exists")
parser.add_argument("--dataset-cap", type=int, default=-1, help="Stop after this many events per file (-1 = no cap)")
# V2 means that the dataset also stores parton-level and genParticles
parser.add_argument("--v2", action="store_true")
parser.add_argument("--delphes", action="store_true", help="Use the Delphes variant of the v2 data config")
args = parser.parse_args()
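# Example invocation (script and directory names are illustrative, not from
# the repository):
#   python preprocess.py --input raw_jets --output jets_v2 --v2 --overwrite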
path = get_path(args.input, "data")


def remove_from_list(lst):
    """Filter out non-dataset entries. (Not called anywhere in this script.)"""
    out = []
    for item in lst:
        if item in ["hgcal", "data.txt", "test_file.root"]:
            continue
        out.append(item)
    return out
def preprocess_dataset(datasets, output_path, config_file, dataset_cap):
    # datasets = os.listdir(path)
    # datasets = [os.path.join(path, x) for x in datasets]

    class Args:
        """Minimal namespace mimicking the training CLI arguments expected by
        to_filelist and SimpleIterDataset."""

        def __init__(self):
            self.data_train = datasets
            self.data_val = datasets
            # self.data_train = files_train
            self.data_config = config_file
            self.extra_selection = None
            self.train_val_split = 1.0
            self.data_fraction = 1
            self.file_fraction = 1
            self.fetch_by_files = False
            self.fetch_step = 1
            self.steps_per_epoch = None
            self.in_memory = False
            self.local_rank = None
            self.copy_inputs = False
            self.no_remake_weights = False
            self.batch_size = 10
            self.num_workers = 0
            self.demo = False
            self.laplace = False
            self.diffs = False
            self.class_edges = False

    # Note: this shadows the module-level CLI `args` inside this function.
    args = Args()
    train_range = (0, args.train_val_split)
    train_file_dict, train_files = to_filelist(args, "train")
    train_data = SimpleIterDataset(
        train_file_dict,
        args.data_config,
        for_training=True,
        extra_selection=args.extra_selection,
        remake_weights=True,
        load_range_and_fraction=(train_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=False,
        in_memory=args.in_memory,
        async_load=False,
        name="train",
        jets=True,
    )
    # Drain the iterator, optionally stopping after `dataset_cap` events.
    iterator = iter(train_data)
    t0 = time()
    data = []
    count = 0
    while True:
        try:
            data.append(next(iterator))
            count += 1
            if dataset_cap > 0 and count >= dataset_cap:
                break
        except StopIteration:
            break
    t1 = time()
    print("Took", t1 - t0, "s -", datasets[0])
    events = concat_events(data)  # TODO: This can be done in a nicer way, using less memory (?)
    result = events.serialize()
    dir_name = datasets[0].split("/")[-2]
    save_to_dir = os.path.join(output_path, dir_name)
    Path(save_to_dir).mkdir(parents=True, exist_ok=True)
    # Arrays are written in raw .npy format (despite the .pkl extension,
    # presumably kept for compatibility with existing loaders) so they can
    # be memory-mapped later.
    for key in result[0]:
        with open(osp.join(save_to_dir, key + ".pkl"), "wb") as f:
            # pickle.dump(result[0][key], f)  # save with torch for mmap
            # torch.save(result[0][key], f)
            np.save(f, result[0][key].numpy())
    with open(osp.join(save_to_dir, "metadata.pkl"), "wb") as f:
        pickle.dump(result[1], f)
    print("Saved to", save_to_dir)
    print("Finished", dir_name)
    # Disabled timing check: read the serialized events back through EventDataset.
    # from src.dataset.functions_data import EventCollection, EventJets, Event
    # from src.dataset.dataset import EventDataset
    # t2 = time()
    # data1 = []
    # for event in EventDataset(result[0], result[1]):
    #     data1.append(event)
    # t3 = time()
    # print("Took", t3 - t2, "s")
    # print("Done")
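
# Driver: preprocess every dataset directory under the input path, one file
# at a time, writing each file's events to a separate "_part<i>" output.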
output = get_path(args.output, "preprocessed_data")
for dataset_dir in os.listdir(path):
    if args.overwrite or not os.path.exists(os.path.join(output, dataset_dir)):
        config = get_path("config_files/config_jets.yaml", "code")
        if args.v2:
            delphes_suffix = "_delphes" if args.delphes else ""
            config = get_path(f"config_files/config_jets_2{delphes_suffix}.yaml", "code")
        for i, file in enumerate(sorted(os.listdir(os.path.join(path, dataset_dir)))):
            print("Preprocessing file", file)
            preprocess_dataset(
                [os.path.join(path, dataset_dir, file)],
                output + "_part" + str(i),
                config_file=config,
                dataset_cap=args.dataset_cap,
            )
    else:
        print("Skipping", dataset_dir + ", already exists")
    # flush
    sys.stdout.flush()
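
# A minimal sketch (not part of the original pipeline) of how the serialized
# output could be read back: the per-key files hold raw .npy data despite the
# .pkl extension, so np.load can memory-map them. `load_preprocessed` is a
# hypothetical helper, not an existing repository function.
def load_preprocessed(save_to_dir):
    with open(osp.join(save_to_dir, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)
    arrays = {}
    for fname in os.listdir(save_to_dir):
        if fname == "metadata.pkl" or not fname.endswith(".pkl"):
            continue
        key = fname[: -len(".pkl")]
        # mmap_mode="r" avoids loading the full array into memory
        arrays[key] = np.load(osp.join(save_to_dir, fname), mmap_mode="r")
    return arrays, metadata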