Spaces:
Sleeping
Sleeping
| """ | |
| This file contains a modified version of the dataloader originally from: | |
| weaver-core | |
| https://github.com/hqucms/weaver-core | |
| The original implementation has been adapted and extended for the needs of this project. | |
| Please refer to the original repository for the base implementation and license details. | |
| Changes in this version: | |
| - Adapted to read parquet files | |
| - Modified batching logic to build graphs on the fly | |
| - No reweighting or standardization of dataset | |
| """ | |
| import os | |
| import copy | |
| import json | |
| import numpy as np | |
| import awkward as ak | |
| import torch.utils.data | |
| import time | |
| from functools import partial | |
| from concurrent.futures.thread import ThreadPoolExecutor | |
| from src.data.tools import _pad | |
| from src.data.fileio import _read_files | |
| from src.data.preprocess import ( | |
| AutoStandardizer, | |
| WeightMaker, | |
| ) | |
| from src.dataset.functions_graph import create_graph | |
| def _preprocess(table, options): | |
| indices = np.arange( | |
| len(table["X_track"]) | |
| ) | |
| if options["shuffle"]: | |
| np.random.shuffle(indices) | |
| return table, indices | |
def _load_next(filelist, load_range, options):
    """Read one chunk of data and prepare its index array.

    Args:
        filelist: files to read, forwarded to ``_read_files``.
        load_range: range specification forwarded to ``_read_files``.
        options: sampler options forwarded to ``_preprocess``.

    Returns:
        The ``(table, indices)`` pair produced by ``_preprocess`` for the
        table returned by ``_read_files``.
    """
    loaded = _read_files(filelist, load_range)
    loaded, order = _preprocess(loaded, options)
    return loaded, order
| class _SimpleIter(object): | |
| r"""_SimpleIter | |
| Iterator object for ``SimpleIterDataset''. | |
| """ | |
| def __init__(self, **kwargs): | |
| # inherit all properties from SimpleIterDataset | |
| self.__dict__.update(**kwargs) | |
| self.iter_count = 0 | |
| # executor to read files and run preprocessing asynchronously | |
| self.executor = ThreadPoolExecutor(max_workers=1) if self._async_load else None | |
| # init: prefetch holds table and indices for the next fetch | |
| self.prefetch = None | |
| self.table = None | |
| self.indices = [] | |
| self.cursor = 0 | |
| self._seed = None | |
| worker_info = torch.utils.data.get_worker_info() | |
| file_dict = self._init_file_dict.copy() | |
| if worker_info is not None: | |
| # in a worker process | |
| self._name += "_worker%d" % worker_info.id | |
| self._seed = worker_info.seed & 0xFFFFFFFF | |
| np.random.seed(self._seed) | |
| # split workload by files | |
| new_file_dict = {} | |
| for name, files in file_dict.items(): | |
| new_files = files[worker_info.id :: worker_info.num_workers] | |
| assert len(new_files) > 0 | |
| new_file_dict[name] = new_files | |
| file_dict = new_file_dict | |
| self.worker_file_dict = file_dict | |
| self.worker_filelist = sum(file_dict.values(), []) | |
| self.worker_info = worker_info | |
| self.restart() | |
| def restart(self): | |
| print("=== Restarting DataIter %s, seed=%s ===" % (self._name, self._seed)) | |
| # re-shuffle filelist and load range if for training | |
| filelist = self.worker_filelist.copy() | |
| if self._sampler_options["shuffle"]: | |
| np.random.shuffle(filelist) | |
| if self._file_fraction < 1: | |
| num_files = int(len(filelist) * self._file_fraction) | |
| filelist = filelist[:num_files] | |
| self.filelist = filelist | |
| if self._init_load_range_and_fraction is None: | |
| self.load_range = (0, 1) | |
| else: | |
| (start_pos, end_pos), load_frac = self._init_load_range_and_fraction | |
| interval = (end_pos - start_pos) * load_frac | |
| if self._sampler_options["shuffle"]: | |
| offset = np.random.uniform(start_pos, end_pos - interval) | |
| self.load_range = (offset, offset + interval) | |
| else: | |
| self.load_range = (start_pos, start_pos + interval) | |
| self.ipos = 0 if self._fetch_by_files else self.load_range[0] | |
| # prefetch the first entry asynchronously | |
| self._try_get_next(init=True) | |
| def __next__(self): | |
| graph_empty = True | |
| self.iter_count += 1 | |
| while graph_empty: | |
| if len(self.filelist) == 0: | |
| raise StopIteration | |
| try: | |
| i = self.indices[self.cursor] | |
| except IndexError: | |
| # case 1: first entry, `self.indices` is still empty | |
| # case 2: running out of entries, `self.indices` is not empty | |
| while True: | |
| if self.prefetch is None: | |
| # reaching the end as prefetch got nothing | |
| self.table = None | |
| if self._async_load: | |
| self.executor.shutdown(wait=False) | |
| raise StopIteration | |
| # get result from prefetch | |
| if self._async_load: | |
| self.table, self.indices = self.prefetch.result() | |
| else: | |
| self.table, self.indices = self.prefetch | |
| # try to load the next ones asynchronously | |
| self._try_get_next() | |
| # check if any entries are fetched (i.e., passing selection) -- if not, do another fetch | |
| if len(self.indices) > 0: | |
| break | |
| # reset cursor | |
| self.cursor = 0 | |
| i = self.indices[self.cursor] | |
| self.cursor += 1 | |
| data, graph_empty = self.get_data(i) | |
| return data | |
| def _try_get_next(self, init=False): | |
| end_of_list = ( | |
| self.ipos >= len(self.filelist) | |
| if self._fetch_by_files | |
| else self.ipos >= self.load_range[1] | |
| ) | |
| if end_of_list: | |
| if init: | |
| raise RuntimeError( | |
| "Nothing to load for worker %d" % 0 | |
| if self.worker_info is None | |
| else self.worker_info.id | |
| ) | |
| if self._infinity_mode and not self._in_memory: | |
| # infinity mode: re-start | |
| self.restart() | |
| return | |
| else: | |
| # finite mode: set prefetch to None, exit | |
| self.prefetch = None | |
| return | |
| if self._fetch_by_files: | |
| filelist = self.filelist[int(self.ipos) : int(self.ipos + self._fetch_step)] | |
| load_range = self.load_range | |
| else: | |
| filelist = self.filelist | |
| load_range = ( | |
| self.ipos, | |
| min(self.ipos + self._fetch_step, self.load_range[1]), | |
| ) | |
| print('Start fetching next batch, len(filelist)=%d, load_range=%s'%(len(filelist), load_range)) | |
| if self._async_load: | |
| self.prefetch = self.executor.submit( | |
| _load_next, | |
| filelist, | |
| load_range, | |
| self._sampler_options, | |
| ) | |
| else: | |
| self.prefetch = _load_next( | |
| filelist, load_range, self._sampler_options | |
| ) | |
| self.ipos += self._fetch_step | |
| def get_data(self, i): | |
| # inputs | |
| self.args_parse.prediction = (not self.for_training) | |
| # X = {k: self.table["_" + k][i].copy() for k in self._data_config.input_names} | |
| X = {k: self.table[k][i] for k in self.table.fields} | |
| [g, features_partnn], graph_empty = create_graph( | |
| X, self.for_training, self.args_parse | |
| ) | |
| return [g, features_partnn], graph_empty | |
| # return X, False | |
class SimpleIterDataset(torch.utils.data.IterableDataset):
    r"""Iterable dataset that hands out ``_SimpleIter`` iterators.

    The constructor only records its configuration; all loading happens in
    the iterator, which receives a deep copy of every attribute set here.

    Arguments:
        file_dict (dict): dictionary of lists of files to be loaded.
        data_config_file: accepted for interface compatibility; not used by
            the code in this class.
        for_training (bool): when ``True`` the sampler options ``training``,
            ``shuffle`` and ``reweight`` are all set to ``True``; otherwise
            all three are ``False``.
        load_range_and_fraction (tuple of tuples, ``((start_pos, end_pos), load_frac)``):
            fractional range of events to load from each file.
            E.g., ``((0, 0.8), 0.5)`` selects an interval covering 50% of the
            first 80% of each file.
        extra_selection: accepted for interface compatibility; not used by
            the code in this class.
        fetch_by_files (bool): when ``True`` each fetch reads ``fetch_step``
            files; when ``False`` each fetch reads a ``fetch_step``-wide slice
            of the load range from all files.
        fetch_step (float or int): fraction of events (``fetch_by_files=False``)
            or number of files (``fetch_by_files=True``) per fetch.
            NOTE(review): no special handling of non-positive values is
            visible in this class.
        file_fraction (float): fraction of files to load.
        remake_weights: accepted for interface compatibility; not used by
            the code in this class.
        up_sample, weight_scale, max_resample: stored verbatim in the sampler
            options.
        async_load (bool): whether the iterator prefetches on a background
            thread.
        infinity_mode (bool): when ``True`` one iterator per worker id is
            cached and reused across ``__iter__`` calls.
        name (str): label used by the iterator in its log output.
        args_parse: object forwarded to the iterator for graph construction.
    """

    def __init__(
        self,
        file_dict,
        data_config_file,
        for_training=True,
        load_range_and_fraction=None,
        extra_selection=None,
        fetch_by_files=False,
        fetch_step=0.01,
        file_fraction=1,
        remake_weights=False,
        up_sample=True,
        weight_scale=1,
        max_resample=10,
        async_load=True,
        infinity_mode=False,
        name="",
        args_parse=None
    ):
        # per-worker iterator cache, only kept in infinity mode
        self._iters = {} if infinity_mode else None

        # snapshot of attribute names that must NOT be forwarded to the iterator
        attrs_before = set(self.__dict__)

        self._init_file_dict = file_dict
        self._init_load_range_and_fraction = load_range_and_fraction
        self._fetch_by_files = fetch_by_files
        self._fetch_step = fetch_step
        self._file_fraction = file_fraction
        self._async_load = async_load
        self._infinity_mode = infinity_mode
        self._name = name
        self.for_training = for_training
        self.args_parse = args_parse

        # ==== sampling parameters ====
        training_mode = True if for_training else False
        self._sampler_options = {
            "up_sample": up_sample,
            "weight_scale": weight_scale,
            "max_resample": max_resample,
            "training": training_mode,
            "shuffle": training_mode,
            "reweight": training_mode,
        }

        # everything assigned since the snapshot is handed to the iterator
        attrs_after = set(self.__dict__)
        self._init_args = attrs_after - attrs_before

    def __iter__(self):
        """Return a new iterator, or the cached per-worker one in infinity mode."""

        def build_iterator():
            # each iterator works on its own deep copy of the configuration
            config = {}
            for key in self._init_args:
                config[key] = copy.deepcopy(self.__dict__[key])
            return _SimpleIter(**config)

        if self._iters is None:
            return build_iterator()

        info = torch.utils.data.get_worker_info()
        slot = 0 if info is None else info.id
        if slot not in self._iters:
            self._iters[slot] = build_iterator()
        return self._iters[slot]