Spaces:
Sleeping
Sleeping
| import fastjet | |
| import os | |
| import copy | |
| import json | |
| import numpy as np | |
| import awkward as ak | |
| import torch.utils.data | |
| import time | |
| import pickle | |
| from collections import OrderedDict | |
| from functools import partial | |
| from concurrent.futures.thread import ThreadPoolExecutor | |
| from src.logger.logger import _logger, warn_once | |
| from src.data.tools import _pad, _repeat_pad, _clip, _pad_vector | |
| from src.data.fileio import _read_files | |
| from src.data.config import DataConfig, _md5 | |
| from src.data.preprocess import ( | |
| _apply_selection, | |
| _build_new_variables, | |
| _build_weights, | |
| AutoStandardizer, | |
| WeightMaker, | |
| ) | |
| from src.dataset.functions_data import to_tensor | |
| from src.layers.object_cond import calc_eta_phi | |
| from torch_scatter import scatter_sum | |
| from src.dataset.functions_graph import (create_graph, create_jets_outputs, | |
| create_jets_outputs_new, create_jets_outputs_Delphes, create_jets_outputs_Delphes2) | |
| from src.dataset.functions_data import Event, EventCollection, EventJets | |
| from src.utils.utils import CPU_Unpickler | |
| from src.dataset.functions_data import EventPFCands, concat_event_collection | |
| def get_pseudojets_fastjet(pfcands): | |
| pseudojets = [] | |
| for i in range(len(pfcands)): | |
| pseudojets.append(fastjet.PseudoJet(pfcands.pxyz[i, 0].item(), pfcands.pxyz[i, 1].item(), pfcands.pxyz[i, 2].item(), pfcands.E[i].item())) | |
| return pseudojets | |
| def _finalize_inputs(table, data_config): | |
| # transformation | |
| output = {} | |
| # transformation | |
| for k, params in data_config.preprocess_params.items(): | |
| if data_config._auto_standardization and params["center"] == "auto": | |
| raise ValueError("No valid standardization params for %s" % k) | |
| # if params["center"] is not None: | |
| # table[k] = (table[k] - params["center"]) * params["scale"] | |
| if params["length"] is not None: | |
| # if k == "hit_genlink": | |
| # pad_fn = partial(_pad_vector, value=-1) | |
| # table[k] = pad_fn(table[k]) | |
| # else: | |
| pad_fn = partial(_pad, value=0) | |
| table[k] = pad_fn(table[k], params["length"]) | |
| # stack variables for each input group | |
| for k, names in data_config.input_dicts.items(): | |
| if ( | |
| len(names) == 1 | |
| and data_config.preprocess_params[names[0]]["length"] is None | |
| ): | |
| output["_" + k] = ak.to_numpy(ak.values_astype(table[names[0]], "float32")) | |
| else: | |
| output["_" + k] = ak.to_numpy( | |
| np.stack( | |
| [ak.to_numpy(table[n]).astype("float32") for n in names], axis=1 | |
| ) | |
| ) | |
| # copy monitor variables | |
| for k in data_config.z_variables: | |
| if k not in output: | |
| output[k] = ak.to_numpy(table[k]) | |
| return output | |
| def _padlabel(table, _data_config): | |
| for k in _data_config.label_value: | |
| pad_fn = partial(_pad, value=0) | |
| table[k] = pad_fn(table[k], 400) | |
| return table | |
| def _preprocess(table, data_config, options): | |
| # apply selection | |
| table = _apply_selection( | |
| table, | |
| data_config.selection | |
| if options["training"] | |
| else data_config.test_time_selection, | |
| ) | |
| if len(table) == 0: | |
| return [] | |
| # table = _padlabel(table,data_config) | |
| # define new variables | |
| table = _build_new_variables(table, data_config.var_funcs) | |
| # else: | |
| indices = np.arange( | |
| len(table[table.fields[0]]) | |
| ) # np.arange(len(table[data_config.label_names[0]])) | |
| # shuffle | |
| if options["shuffle"]: | |
| np.random.shuffle(indices) | |
| # perform input variable standardization, clipping, padding and stacking | |
| table = _finalize_inputs(table, data_config) | |
| return table, indices | |
| def _load_next(data_config, filelist, load_range, options): | |
| table = _read_files( | |
| filelist, data_config.load_branches, load_range, treename=data_config.treename | |
| ) | |
| table, indices = _preprocess(table, data_config, options) | |
| return table, indices | |
| class _SimpleIter(object): | |
| r"""_SimpleIter | |
| Iterator object for ``SimpleIterDataset''. | |
| """ | |
| def __init__(self, **kwargs): | |
| # inherit all properties from SimpleIterDataset | |
| self.__dict__.update(**kwargs) | |
| self.iter_count = 0 # to raise StopIteration when dataset_cap is reached | |
| if "dataset_cap" in kwargs and kwargs["dataset_cap"] is not None: | |
| self.dataset_cap = kwargs["dataset_cap"] | |
| self._sampler_options["shuffle"] = False | |
| print("!!! Dataset_cap flag set, disabling shuffling") | |
| else: | |
| self.dataset_cap = None | |
| # executor to read files and run preprocessing asynchronously | |
| self.executor = ThreadPoolExecutor(max_workers=1) if self._async_load else None | |
| # init: prefetch holds table and indices for the next fetch | |
| self.prefetch = None | |
| self.table = None | |
| self.indices = [] | |
| self.cursor = 0 | |
| self._seed = None | |
| worker_info = torch.utils.data.get_worker_info() | |
| file_dict = self._init_file_dict.copy() | |
| if worker_info is not None: | |
| # in a worker process | |
| self._name += "_worker%d" % worker_info.id | |
| self._seed = worker_info.seed & 0xFFFFFFFF | |
| np.random.seed(self._seed) | |
| # split workload by files | |
| new_file_dict = {} | |
| for name, files in file_dict.items(): | |
| new_files = files[worker_info.id :: worker_info.num_workers] | |
| assert len(new_files) > 0 | |
| new_file_dict[name] = new_files | |
| file_dict = new_file_dict | |
| self.worker_file_dict = file_dict | |
| self.worker_filelist = sum(file_dict.values(), []) | |
| self.worker_info = worker_info | |
| self.restart() | |
| def restart(self): | |
| print("=== Restarting DataIter %s, seed=%s ===" % (self._name, self._seed)) | |
| # re-shuffle filelist and load range if for training | |
| filelist = self.worker_filelist.copy() | |
| if self._sampler_options["shuffle"]: | |
| np.random.shuffle(filelist) | |
| if self._file_fraction < 1: | |
| num_files = int(len(filelist) * self._file_fraction) | |
| filelist = filelist[:num_files] | |
| self.filelist = filelist | |
| if self._init_load_range_and_fraction is None: | |
| self.load_range = (0, 1) | |
| else: | |
| (start_pos, end_pos), load_frac = self._init_load_range_and_fraction | |
| interval = (end_pos - start_pos) * load_frac | |
| if self._sampler_options["shuffle"]: | |
| offset = np.random.uniform(start_pos, end_pos - interval) | |
| self.load_range = (offset, offset + interval) | |
| else: | |
| self.load_range = (start_pos, start_pos + interval) | |
| _logger.debug( | |
| "Init iter [%d], will load %d (out of %d*%s=%d) files with load_range=%s:\n%s", | |
| 0 if self.worker_info is None else self.worker_info.id, | |
| len(self.filelist), | |
| len(sum(self._init_file_dict.values(), [])), | |
| self._file_fraction, | |
| int(len(sum(self._init_file_dict.values(), [])) * self._file_fraction), | |
| str(self.load_range), | |
| ) | |
| # '\n'.join(self.filelist[: 3]) + '\n ... ' + self.filelist[-1],) | |
| _logger.info( | |
| "Restarted DataIter %s, load_range=%s, file_list:\n%s" | |
| % ( | |
| self._name, | |
| str(self.load_range), | |
| json.dumps(self.worker_file_dict, indent=2), | |
| ) | |
| ) | |
| # reset file fetching cursor | |
| self.ipos = 0 if self._fetch_by_files else self.load_range[0] | |
| # prefetch the first entry asynchronously | |
| self._try_get_next(init=True) | |
| def __next__(self): | |
| # print(self.ipos, self.cursor) | |
| graph_empty = True | |
| self.iter_count += 1 | |
| if self.dataset_cap is not None and self.iter_count > self.dataset_cap: | |
| raise StopIteration | |
| while graph_empty: | |
| if len(self.filelist) == 0: | |
| raise StopIteration | |
| try: | |
| i = self.indices[self.cursor] | |
| except IndexError: | |
| # case 1: first entry, `self.indices` is still empty | |
| # case 2: running out of entries, `self.indices` is not empty | |
| while True: | |
| if self._in_memory and len(self.indices) > 0: | |
| # only need to re-shuffle the indices, if this is not the first entry | |
| if self._sampler_options["shuffle"]: | |
| np.random.shuffle(self.indices) | |
| break | |
| if self.prefetch is None: | |
| # reaching the end as prefetch got nothing | |
| self.table = None | |
| if self._async_load: | |
| self.executor.shutdown(wait=False) | |
| raise StopIteration | |
| # get result from prefetch | |
| if self._async_load: | |
| self.table, self.indices = self.prefetch.result() | |
| else: | |
| self.table, self.indices = self.prefetch | |
| # try to load the next ones asynchronously | |
| self._try_get_next() | |
| # check if any entries are fetched (i.e., passing selection) -- if not, do another fetch | |
| if len(self.indices) > 0: | |
| break | |
| # reset cursor | |
| self.cursor = 0 | |
| i = self.indices[self.cursor] | |
| self.cursor += 1 | |
| data, graph_empty = self.get_data(i) | |
| return data | |
| def _try_get_next(self, init=False): | |
| end_of_list = ( | |
| self.ipos >= len(self.filelist) | |
| if self._fetch_by_files | |
| else self.ipos >= self.load_range[1] | |
| ) | |
| if end_of_list: | |
| if init: | |
| raise RuntimeError( | |
| "Nothing to load for worker %d" % 0 | |
| if self.worker_info is None | |
| else self.worker_info.id | |
| ) | |
| if self._infinity_mode and not self._in_memory: | |
| # infinity mode: re-start | |
| self.restart() | |
| return | |
| else: | |
| # finite mode: set prefetch to None, exit | |
| self.prefetch = None | |
| return | |
| if self._fetch_by_files: | |
| filelist = self.filelist[int(self.ipos) : int(self.ipos + self._fetch_step)] | |
| load_range = self.load_range | |
| else: | |
| filelist = self.filelist | |
| load_range = ( | |
| self.ipos, | |
| min(self.ipos + self._fetch_step, self.load_range[1]), | |
| ) | |
| # _logger.info('Start fetching next batch, len(filelist)=%d, load_range=%s'%(len(filelist), load_range)) | |
| if self._async_load: | |
| self.prefetch = self.executor.submit( | |
| _load_next, | |
| self._data_config, | |
| filelist, | |
| load_range, | |
| self._sampler_options, | |
| ) | |
| else: | |
| self.prefetch = _load_next( | |
| self._data_config, filelist, load_range, self._sampler_options | |
| ) | |
| self.ipos += self._fetch_step | |
| def get_data(self, i): | |
| # inputs | |
| X = {k: self.table["_" + k][i].copy() for k in self._data_config.input_names} | |
| if "EFlowPhoton" in X: | |
| return create_jets_outputs_Delphes(X), False | |
| elif "PFCands" in X: | |
| # v2 config | |
| return create_jets_outputs_Delphes2(X), False | |
| return create_jets_outputs_new(X), False | |
| class EventDatasetCollection(torch.utils.data.Dataset): | |
| def __init__(self, dir_list, args, aug_soft=False, aug_collinear=False, shuffle_seed=10): | |
| self.event_collections_dict = OrderedDict() | |
| if args: | |
| aug_soft = args.augment_soft_particles | |
| else: | |
| aug_soft=False | |
| for dir in dir_list: | |
| self.event_collections_dict[dir] = EventDataset.from_directory(dir, mmap=True, aug_soft=aug_soft or aug_soft, seed=0, aug_collinear=aug_collinear) | |
| self.n_events = sum([x.n_events for x in self.event_collections_dict.values()]) | |
| evt_idx = np.arange(0, self.n_events) # now shuffle this using the shuffle_seed and a separate random generator | |
| rng = np.random.default_rng(shuffle_seed) | |
| rng.shuffle(evt_idx) | |
| self.old_to_new_idx = evt_idx | |
| self.event_thresholds = [x.n_events for x in self.event_collections_dict.values()] | |
| self.event_thresholds = np.cumsum([0] + self.event_thresholds) | |
| self.dir_list = dir_list | |
| def __len__(self): | |
| return self.n_events | |
| def get_idx(self, i): | |
| assert i < self.n_events, "Index out of bounds: %d >= %d" % (i, self.n_events) | |
| for j in range(len(self.event_thresholds)-1): | |
| threshold = self.event_thresholds[j] | |
| if i >= threshold and i < self.event_thresholds[j+1]: | |
| #print("-------------", i, threshold, self.event_thresholds, j, self.dir_list[j]) | |
| return self.event_collections_dict[self.dir_list[j]][i - threshold] | |
| def getitem(self, i): | |
| return self.get_idx(i) | |
| def __iter__(self): | |
| for i in range(self.n_events): | |
| yield self.get_idx(self.old_to_new_idx[i]) | |
| def __getitem__(self, i): | |
| assert i < self.n_events, "Index out of bounds: %d >= %d" % (i, self.n_events) | |
| return self.get_idx(self.old_to_new_idx[i]) | |
| # A collection of EventDatasets. | |
| # You should use a sampler together with this, as by default it just concatenates the EventDatasets together! | |
| def get_batch_bounds(batch_idx): | |
| # batch_idx: tensor of format [0,0,0,0,1,1,1...] | |
| # returns tensor of format [0, 4, ...] | |
| print("Batch idx", batch_idx.shape, batch_idx[(batch_idx>3130) & (batch_idx < 3140)]) | |
| batches = sorted(batch_idx.unique().tolist()) | |
| skipped = [] | |
| for i in range(batch_idx.max().int().item()): | |
| if i not in batches: | |
| skipped.append(i) | |
| # reverse sort skipped | |
| skipped = sorted(skipped, reverse=True) | |
| result = torch.zeros(batch_idx.max().int().item() + 2 + len(skipped)) | |
| #for i, b in enumerate(batches): | |
| # assert i == b | |
| # result[i] = torch.where(batch_idx==b)[0].min() | |
| # result[i+1] = torch.where(batch_idx==b)[0].max() | |
| b_list = batch_idx.int().tolist() | |
| prev = -1 | |
| for i, b in enumerate(b_list): | |
| if b != prev: | |
| result[b] = i | |
| prev = b | |
| result[-1] = len(b_list) | |
| print("skipped", skipped) | |
| for s in skipped: | |
| if s == 0: | |
| result[s] = 0 | |
| else: | |
| result[s] = result[s+1] | |
| print("result", result.shape, result[3130:3140].tolist()) | |
| return result | |
| def filter_pfcands(pfcands): | |
| # filter the GenParticles so that dark matter particles are not present | |
| # dark matter particles are defined as those with abs(pdgId) > 10000 or pdgId between 50-60 | |
| # TODO: filter out high eta - temporarily this is done here, but it should be done in the ntuplizer in order to avoid big files | |
| mask = (torch.abs(pfcands.pid) < 10000) & ((torch.abs(pfcands.pid) < 50) | (torch.abs(pfcands.pid) > 60)) & (torch.abs(pfcands.eta) < 2.4) & (pfcands.pt > 0.5) #& (pfcands.pt > 0.5) | |
| pfcands.mask(mask) | |
| return pfcands | |
| class EventDataset(torch.utils.data.Dataset): | |
| def from_directory(dir, mmap=True, model_clusters_file=None, model_output_file=None, include_model_jets_unfiltered=False, fastjet_R=None, parton_level=False, gen_level=False, aug_soft=False, seed=0, aug_collinear=False, pt_jet_cutoff=100): | |
| result = {} | |
| for file in os.listdir(dir): | |
| if file == "metadata.pkl": | |
| metadata = pickle.load(open(os.path.join(dir, file), "rb")) | |
| else: | |
| print("File:", file) | |
| result[file.split(".")[0]] = np.load( | |
| os.path.join(dir, file), mmap_mode="r" if mmap else None | |
| ) | |
| dataset = EventDataset(result, metadata, model_clusters_file=model_clusters_file, | |
| model_output_file=model_output_file, | |
| include_model_jets_unfiltered=include_model_jets_unfiltered, | |
| fastjet_R=fastjet_R, parton_level=parton_level, gen_level=gen_level, aug_soft=aug_soft, | |
| seed=seed, aug_collinear=aug_collinear, pt_jet_cutoff=pt_jet_cutoff) | |
| return dataset | |
| def get_pfcands_key(self): | |
| pfcands_key = "pfcands" | |
| print("get_pfcands_key") | |
| if self.gen_level: | |
| return "final_gen_particles" | |
| if self.parton_level: | |
| return "final_parton_level_particles" | |
| if self.model_output is None: | |
| if self.gen_level: | |
| return "final_gen_particles" | |
| if self.parton_level: | |
| return "final_parton_level_particles" | |
| return pfcands_key # ignore | |
| for i in [0, 1, 2]: # try the first three if it fits | |
| start = {key: self.metadata[key + "_batch_idx"][i] for key in self.attrs} | |
| end = {key: self.metadata[key + "_batch_idx"][i + 1] for key in self.attrs} | |
| result = {key: self.events[key][start[key]:end[key]] for key in self.attrs} | |
| result = {key: EventCollection.deserialize(result[key], batch_number=None, cls=Event.evt_collections[key]) | |
| for key in self.attrs} | |
| if "final_parton_level_particles" in result: | |
| result["final_parton_level_particles"] = filter_pfcands(result["final_parton_level_particles"]) | |
| if "final_gen_particles" in result: | |
| result["final_gen_particles"] = filter_pfcands(result["final_gen_particles"]) | |
| event_filter_s, event_filter_e = self.model_output["event_idx_bounds"][i].int().item(), \ | |
| self.model_output["event_idx_bounds"][i + 1].int().item() | |
| diff = event_filter_e - event_filter_s | |
| if diff != len(result["pfcands"]): | |
| if diff == len(result["final_parton_level_particles"]): | |
| pfcands_key = "final_parton_level_particles" | |
| break | |
| if diff == len(result["final_gen_particles"]): | |
| pfcands_key = "final_gen_particles" | |
| break | |
| print("Found pfcands_key=%s" % pfcands_key) | |
| return pfcands_key | |
| def __init__(self, events, metadata, model_clusters_file=None, model_output_file=None, include_model_jets_unfiltered=False, fastjet_R=None, parton_level=False, gen_level=False, aug_soft=False, seed=0, aug_collinear=False, pt_jet_cutoff=100): | |
| # events: serialized events dict | |
| # metadata: dict with metadata | |
| self.events = events | |
| self.n_events = metadata["n_events"] | |
| self.attrs = metadata["attrs"] | |
| self.metadata = metadata | |
| self.include_model_jets_unfiltered = include_model_jets_unfiltered | |
| self.model_i = 0 | |
| self.parton_level = parton_level | |
| self.gen_level = gen_level | |
| self.augment_soft_particles = aug_soft | |
| self.aug_collinear = aug_collinear | |
| self.seed = seed | |
| self.pt_jet_cutoff = pt_jet_cutoff | |
| #self.pfcands_key = "pfcands" | |
| # set to final_parton_level_particles or final_gen_particles in case needed | |
| #for key in self.attrs: | |
| # self.evt_idx_to_batch_idx[key] = {} | |
| if model_output_file is not None: | |
| if type(model_output_file) == str: | |
| self.model_output = CPU_Unpickler(open(model_output_file, "rb")).load() | |
| else: | |
| self.model_output = model_output_file | |
| self.model_output["event_idx_bounds"] = get_batch_bounds(self.model_output["event_idx"]) | |
| self.n_events = self.model_output["event_idx"].max().int().item() # sometimes the last batch gets cut off, which causes problems | |
| if model_clusters_file is not None: | |
| self.model_clusters = to_tensor(pickle.load((open(model_clusters_file, "rb")))) | |
| else: | |
| self.model_clusters = self.model_output["model_cluster"] | |
| # model_output["batch_idx"] contains the batch index for each event. model_clusters is an array of the model labels for each event. | |
| else: | |
| self.model_output = None | |
| self.model_clusters = None | |
| if fastjet_R is not None: | |
| self.fastjet_jetdef = {r: fastjet.JetDefinition(fastjet.antikt_algorithm, r) for r in fastjet_R} | |
| ## fastjet_R is an array of radiuses for which to compute that | |
| self.pfcands_key = self.get_pfcands_key() | |
| def __len__(self): | |
| return self.n_events | |
| # def __next__(self): | |
| def add_model_output(self, model_output): | |
| if model_output is not None: | |
| if type(model_output) == str: | |
| self.model_output = CPU_Unpickler(open(model_output, "rb")).load() | |
| else: | |
| self.model_output = model_output | |
| self.model_output["event_idx_bounds"] = get_batch_bounds(self.model_output["event_idx"]) | |
| self.n_events = self.model_output["event_idx"].max().int().item() # sometimes the last batch gets cut off, which causes problems | |
| self.model_clusters = self.model_output["model_cluster"] | |
| # model_output["batch_idx"] contains the batch index for each event. model_clusters is an array of the model labels for each event. | |
| else: | |
| self.model_output = None | |
| self.model_clusters = None | |
| def pfcands_add_soft_particles(pfcands, n_soft, random_generator, add_original_particle_mapping=False): | |
| # augment the dataset with soft particles | |
| eta_bounds = [-2.4, 2.4] | |
| phi_bounds = [-3.14, 3.14] | |
| #pt_bounds = [0.02, 0.5] | |
| # choose random eta and phi | |
| # use the random generator for eta, phi | |
| eta = random_generator.uniform(eta_bounds[0], eta_bounds[1], n_soft).astype(np.double) | |
| phi = random_generator.uniform(phi_bounds[0], phi_bounds[1], n_soft).astype(np.double) | |
| #pt = random_generator.uniform(pt_bounds[0], pt_bounds[1], n_soft).astype(np.double) | |
| pt = np.ones(n_soft).astype(np.double) * 1e-2 | |
| charge = np.zeros(n_soft).astype(np.double) | |
| pid = np.zeros(n_soft).astype(np.double) | |
| mass = np.zeros(n_soft).astype(np.double) | |
| if hasattr(pfcands, "status"): | |
| status = np.zeros(n_soft) | |
| soft_pfcands = EventPFCands(pt, eta, phi, mass, charge, pid, pf_cand_jet_idx=-1 * torch.ones(n_soft), status=status) | |
| else: | |
| soft_pfcands = EventPFCands(pt, eta, phi, mass, charge, pid, pf_cand_jet_idx=-1*torch.ones(n_soft)) | |
| soft_pfcands.original_particle_mapping = torch.tensor([-1] * len(soft_pfcands)) | |
| pfcandsc = copy.deepcopy(pfcands) | |
| pfcandsc.original_particle_mapping = torch.arange(len(pfcands)) | |
| pfcandsc = concat_event_collection([pfcandsc, soft_pfcands], nobatch=1) | |
| if not add_original_particle_mapping: | |
| pfcandsc.original_particle_mapping = torch.arange(len(pfcandsc)) # For now, ignore the soft particles | |
| #print("Original PM:", pfcandsc.original_particle_mapping.max()) | |
| return pfcandsc | |
| def pfcands_split_particles(pfcands, random_generator): | |
| # Augment the dataset by spliting the harder particles | |
| # 5 highest pt particles | |
| k = min(5, len(pfcands)) | |
| highest_pt_idx = torch.topk(pfcands.pt, k)[1] | |
| weights = pfcands.pt[highest_pt_idx] | |
| # Pick a random particle to split according to weights | |
| n_to_split = random_generator.randint(0, k) | |
| #idx = random_generator.choice(highest_pt_idx, p=weights / weights.sum()) | |
| indices = highest_pt_idx[:n_to_split] | |
| pfcandsc = copy.deepcopy(pfcands) | |
| pfcandsc.original_particle_mapping = torch.arange(len(pfcands)) | |
| # assert that indices are all lower than len(pfcands) | |
| if not torch.all(indices < len(pfcands)): | |
| print("Indices:", indices) | |
| print("PFCands:", pfcands.pt) | |
| print("PFCands len:", len(pfcands.pt)) | |
| raise ValueError("Indices are out of bounds") | |
| for idx in indices: | |
| split_into = random_generator.randint(2, 5) | |
| # split the particle into | |
| eta = pfcands.eta[idx] | |
| phi = pfcands.phi[idx] | |
| pt = pfcands.pt[idx] / split_into | |
| charge = pfcands.charge[idx] | |
| mass = 0 | |
| pid = pfcands.pid[idx] | |
| colinear_pfcands = EventPFCands(pt=[pt], eta=[eta], phi=[phi], mass=[mass], charge=[charge], pid=[pid], pf_cand_jet_idx=[pfcands.pf_cand_jet_idx[idx]], original_particle_mapping=[idx]) | |
| #pfcandsc.original_particle_mapping[idx] = idx | |
| pfcandsc.pt[idx] = pt | |
| for _ in range(split_into-1): | |
| pfcandsc = concat_event_collection([pfcandsc, colinear_pfcands], nobatch=1) | |
| if pfcandsc.original_particle_mapping.max() >= len(pfcands): | |
| #print("Original PM:", pfcandsc.original_particle_mapping.max(), "len pfcands", len(pfcands)) | |
| raise ValueError("Original particle mapping is out of bounds") | |
| return pfcandsc | |
| def get_idx(self, i): | |
| #print("Getting idx", i) | |
| start = {key: self.metadata[key + "_batch_idx"][i] for key in self.attrs} | |
| end = {key: self.metadata[key + "_batch_idx"][i + 1] for key in self.attrs} | |
| result = {key: self.events[key][start[key]:end[key]] for key in self.attrs} | |
| result = {key: EventCollection.deserialize(result[key], batch_number=None, cls=Event.evt_collections[key]) for | |
| key in self.attrs} | |
| result["pfcands"] = filter_pfcands(result["pfcands"]) | |
| if "final_parton_level_particles" in result: | |
| #print("i=", i) | |
| #print("BEFORE:", len(result["final_parton_level_particles"])) | |
| result["final_parton_level_particles"] = filter_pfcands(result["final_parton_level_particles"]) | |
| #print("AFTER:", len(result["final_parton_level_particles"])) | |
| #print("------") | |
| if "final_gen_particles" in result: | |
| result["final_gen_particles"] = filter_pfcands(result["final_gen_particles"]) | |
| ## augment pfcands here | |
| if self.augment_soft_particles: | |
| random_generator = np.random.RandomState(seed=i + self.seed) | |
| #n_soft = int(random_generator.uniform(10, 1000)) | |
| n_soft = 500 | |
| #n_soft = 1000 | |
| result["pfcands"] = EventDataset.pfcands_add_soft_particles(result["pfcands"], n_soft, random_generator) | |
| if "final_parton_level_particles" in result: | |
| result["final_parton_level_particles"] = EventDataset.pfcands_add_soft_particles(result["final_parton_level_particles"], n_soft, random_generator) # Also augment parton-level event for testing | |
| if "final_gen_particles" in result: | |
| result["final_gen_particles"] = EventDataset.pfcands_add_soft_particles(result["final_gen_particles"], n_soft, random_generator) | |
| else: | |
| result["pfcands"].original_particle_mapping = torch.arange(len(result["pfcands"].pt)) | |
| if self.aug_collinear: | |
| random_generator = np.random.RandomState(seed=i + self.seed) | |
| if i % 2: # Every second one: | |
| result["pfcands"] = EventDataset.pfcands_split_particles(result["pfcands"], random_generator) | |
| if "final_parton_level_particles" in result: | |
| result["final_parton_level_particles"] = EventDataset.pfcands_split_particles( | |
| result["final_parton_level_particles"], random_generator | |
| ) | |
| # Also augment parton-level event for testing | |
| if "final_gen_particles" in result: | |
| result["final_gen_particles"] = EventDataset.pfcands_split_particles(result["final_gen_particles"], random_generator) | |
| else: | |
| n_soft = 500 | |
| result["pfcands"] = EventDataset.pfcands_add_soft_particles(result["pfcands"], n_soft, random_generator, | |
| add_original_particle_mapping=True) | |
| if "final_parton_level_particles" in result: | |
| result["final_parton_level_particles"] = EventDataset.pfcands_add_soft_particles( | |
| result["final_parton_level_particles"], n_soft, random_generator, add_original_particle_mapping=True | |
| ) | |
| # Also augment parton-level event for testing | |
| if "final_gen_particles" in result: | |
| result["final_gen_particles"] = EventDataset.pfcands_add_soft_particles( | |
| result["final_gen_particles"], | |
| n_soft, | |
| random_generator, | |
| add_original_particle_mapping=True | |
| ) | |
| if self.model_output is not None: | |
| #if "final_parton_level_particles" in result and len(result["final_parton_level_particles"]) == 0: | |
| # print("!!") | |
| # return None | |
| result["model_jets"], bc_scores_pfcands, bc_labels_pfcands = self.get_model_jets(i, pfcands=result[self.pfcands_key], include_target=1, dq=result["matrix_element_gen_particles"]) | |
| result[self.pfcands_key].bc_scores_pfcands = bc_scores_pfcands | |
| result[self.pfcands_key].bc_labels_pfcands = bc_labels_pfcands | |
| if self.include_model_jets_unfiltered: | |
| result["model_jets_unfiltered"], _, _ = self.get_model_jets(i, pfcands=result[self.pfcands_key], filter=False) | |
| if hasattr(self, "fastjet_jetdef") and self.fastjet_jetdef is not None: | |
| if self.gen_level: | |
| result["fastjet_jets"] = {key: EventDataset.get_fastjet_jets(result, self.fastjet_jetdef[key], key="final_gen_particles", pt_cutoff=self.pt_jet_cutoff) for key in self.fastjet_jetdef} | |
| elif self.parton_level: | |
| result["fastjet_jets"] = {key: EventDataset.get_fastjet_jets(result, self.fastjet_jetdef[key], key="final_parton_level_particles", pt_cutoff=self.pt_jet_cutoff) for key in self.fastjet_jetdef} | |
| else: | |
| result["fastjet_jets"] = {key: EventDataset.get_fastjet_jets(result, self.fastjet_jetdef[key], key="pfcands", pt_cutoff=self.pt_jet_cutoff) for key | |
| in self.fastjet_jetdef} | |
| if "genjets" in result: | |
| result["genjets"] = EventDataset.mask_jets(result["genjets"]) | |
| evt = Event(**result) | |
| assert evt.pfcands.original_particle_mapping.max() < len(evt.pfcands.pt), "Original particle mapping is out of bounds: " + str(evt.original_particle_mapping.max()) + " >= " + str(len(evt.pfcands.pt)) | |
| return evt | |
| def get_target_obj_score(clusters_eta, clusters_phi, clusters_pt, event_idx_clusters, dq_eta, dq_phi, dq_event_idx): | |
| # return the target scores for each cluster (reteurns list of 1's and 0's) | |
| # dq_coords: list of [eta, phi] for each dark quark | |
| # dq_event_idx: list of event_idx for each dark quarks | |
| target = [] | |
| for event in event_idx_clusters.unique(): | |
| filt = event_idx_clusters == event | |
| clusters = torch.stack([clusters_eta[filt], clusters_phi[filt], clusters_pt[filt]], dim=1) | |
| dq_coords_event = torch.stack([dq_eta[dq_event_idx == event], dq_phi[dq_event_idx == event]], dim=1) | |
| dist_matrix = torch.cdist( | |
| dq_coords_event, | |
| clusters[:, :2].to(dq_coords_event.device), | |
| p=2 | |
| ).T | |
| if len(dist_matrix) == 0: | |
| target.append(torch.zeros(len(clusters)).int().to(dist_matrix.device)) | |
| continue | |
| closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1) | |
| closest_quark_idx[closest_quark_dist > 0.8] = -1 | |
| target.append((closest_quark_idx != -1).float()) | |
| if len(target): | |
| return torch.cat(target).flatten() | |
| return torch.tensor([]) | |
| def mask_jets(jets, cutoff=100): | |
| mask = jets.pt >= cutoff | |
| return EventJets(jets.pt[mask], jets.eta[mask], jets.phi[mask], jets.mass[mask]) | |
| def get_model_jets_static(i, pfcands, model_output, model_clusters): | |
| event_filter_s, event_filter_e = model_output["event_idx_bounds"][i].int().item(), model_output["event_idx_bounds"][i + 1].int().item() | |
| pfcands_pt = pfcands.pt | |
| pfcands_pxyz = pfcands.pxyz | |
| pfcands_E = pfcands.E | |
| #assert len(pfcands_pt) == event_filter_e - event_filter_s, "Error!, len(pfcands_pt)==%d, event_filter_e-event_filter_s=%d" % (len(pfcands_pt), event_filter_e - event_filter_s) | |
| if not len(pfcands_pt) == event_filter_e - event_filter_s: | |
| return None | |
| # jets_pt = scatter_sum(to_tensor(pfcands_pt), self.model_clusters[event_filter] + 1, dim=0)[1:] | |
| jets_pxyz = scatter_sum(to_tensor(pfcands_pxyz), model_clusters[event_filter_s:event_filter_e] + 1, dim=0)[1:] | |
| jets_pt = torch.norm(jets_pxyz[:, :2], p=2, dim=-1) | |
| jets_eta, jets_phi = calc_eta_phi(jets_pxyz, False) | |
| # jets_mass = torch.zeros_like(jets_eta) | |
| jets_E = scatter_sum(to_tensor(pfcands_E), model_clusters[event_filter_s:event_filter_e] + 1, dim=0)[1:] | |
| jets_mass = torch.sqrt(jets_E ** 2 - jets_pxyz.norm(dim=-1) ** 2) | |
| cluster_labels = model_clusters[event_filter_s:event_filter_e] | |
| bc_scores = model_output["pred"][event_filter_s:event_filter_e, -1] | |
| cutoff = 100 | |
| mask = jets_pt >= cutoff | |
| return EventJets(jets_pt[mask], jets_eta[mask], jets_phi[mask], jets_mass[mask]) | |
| def get_jets_fastjets_raw_with_assignment(pfcands, jetdef, pt_cutoff=100): | |
| pt = [] | |
| eta = [] | |
| phis = [] | |
| mass = [] | |
| particle_to_jet = {} # this will map particle_idx -> jet_idx | |
| array = get_pseudojets_fastjet(pfcands) | |
| for idx, pseudojet in enumerate(array): | |
| pseudojet.set_user_index(idx) | |
| cluster = fastjet.ClusterSequence(array, jetdef) | |
| inc_jets = cluster.inclusive_jets() | |
| jet_idx = 0 | |
| for elem in inc_jets: | |
| if elem.pt() < pt_cutoff: | |
| continue | |
| # print("pt:", elem.pt(), "eta:", elem.rap(), "phi:", elem.phi())ž | |
| pt.append(elem.pt()) | |
| eta.append(elem.rap()) | |
| phi = elem.phi() | |
| if phi > np.pi: | |
| phi -= 2 * np.pi | |
| phis.append(phi) | |
| mass.append(elem.m()) | |
| # Get constituents of this jet | |
| constituents = cluster.constituents(elem) | |
| for constituent in constituents: | |
| particle_idx = constituent.user_index() | |
| particle_to_jet[particle_idx] = jet_idx | |
| jet_idx += 1 | |
| return pt, eta, phis, mass, particle_to_jet | |
| def get_jets_fastjets_raw(pfcands, jetdef, pt_cutoff=100): | |
| pt = [] | |
| eta = [] | |
| phis = [] | |
| mass = [] | |
| array = get_pseudojets_fastjet(pfcands) | |
| cluster = fastjet.ClusterSequence(array, jetdef) | |
| inc_jets = cluster.inclusive_jets() | |
| for elem in inc_jets: | |
| if elem.pt() < pt_cutoff: | |
| continue | |
| # print("pt:", elem.pt(), "eta:", elem.rap(), "phi:", elem.phi())ž | |
| pt.append(elem.pt()) | |
| eta.append(elem.rap()) | |
| phi = elem.phi() | |
| if phi > np.pi: | |
| phi -= 2 * np.pi | |
| phis.append(phi) | |
| mass.append(elem.m()) | |
| return pt, eta, phis, mass | |
| def get_fastjet_jets_with_assignment(event, jetdef, key="pfcands", pt_cutoff=100): | |
| if type(event) == dict: | |
| k = event[key] | |
| else: | |
| k = getattr(event, key) | |
| pt, eta, phi, m, assignment = EventDataset.get_jets_fastjets_raw_with_assignment(k, jetdef, pt_cutoff=pt_cutoff) | |
| return EventJets(torch.tensor(pt), torch.tensor(eta), torch.tensor(phi), torch.tensor(m)), assignment | |
| def get_fastjet_jets(event, jetdef, key="pfcands", pt_cutoff=100): | |
| if type(event) == dict: | |
| k = event[key] | |
| else: | |
| k = getattr(event, key) | |
| pt, eta, phi, m = EventDataset.get_jets_fastjets_raw(k, jetdef, pt_cutoff=pt_cutoff) | |
| return EventJets(torch.tensor(pt), torch.tensor(eta), torch.tensor(phi), torch.tensor(m)) | |
| def get_model_jets(self, i, pfcands, filter=True, dq=None, include_target=False): | |
| event_filter_s, event_filter_e = self.model_output["event_idx_bounds"][i].int().item(), self.model_output["event_idx_bounds"][i+1].int().item() | |
| pfcands_pt = pfcands.pt | |
| pfcands_pxyz = pfcands.pxyz | |
| pfcands_E = pfcands.E | |
| obj_score = None | |
| #print("Len pfcands_pt", len(pfcands_pt), "event_filter_e", event_filter_e, "event_filter_s", event_filter_s) | |
| if len(pfcands_pt) == 0: | |
| return EventJets(torch.tensor([]), torch.tensor([]), torch.tensor([]) ,torch.tensor([])), None, None | |
| assert len(pfcands_pt) == event_filter_e - event_filter_s, "Error! filter={} len(pfcands_pt)={} event_filter_e={} event_filter_s={}".format(filter, len(pfcands_pt), event_filter_e, event_filter_s) | |
| #jets_pt = scatter_sum(to_tensor(pfcands_pt), self.model_clusters[event_filter] + 1, dim=0)[1:] | |
| jets_pxyz = scatter_sum(to_tensor(pfcands_pxyz), self.model_clusters[event_filter_s:event_filter_e] + 1, dim=0)[1:] | |
| jets_pt = torch.norm(jets_pxyz[:, :2], p=2, dim=-1) | |
| jets_eta, jets_phi = calc_eta_phi(jets_pxyz, False) | |
| #jets_mass = torch.zeros_like(jets_eta) | |
| jets_E = scatter_sum(to_tensor(pfcands_E), self.model_clusters[event_filter_s:event_filter_e] + 1, dim=0)[1:] | |
| jets_mass = torch.sqrt(jets_E**2 - jets_pxyz.norm(dim=-1)**2) | |
| cluster_labels = self.model_clusters[event_filter_s:event_filter_e] | |
| bc_scores = self.model_output["pred"][event_filter_s:event_filter_e, -1] | |
| if "obj_score_pred" in self.model_output and not torch.is_tensor(self.model_output["obj_score_pred"]): | |
| self.model_output["obj_score_pred"] = torch.cat(self.model_output["obj_score_pred"]) | |
| print("Concatenated obj_score_pred") | |
| target_obj_score = None | |
| if filter: | |
| cutoff = self.pt_jet_cutoff | |
| mask = jets_pt >= cutoff | |
| if "obj_score_pred" in self.model_output: | |
| obj_score = self.model_output["obj_score_pred"][(self.model_output["event_clusters_idx"] == i)] | |
| #print("Jets pt", jets_pt, "obj score", obj_score) | |
| assert len(obj_score) == len(jets_pt), "Error! len(obj_score)=%d, len(jets_pt)=%d" % ( | |
| len(obj_score), len(jets_pt)) | |
| if include_target: | |
| target_obj_score = EventDataset.get_target_obj_score(jets_eta, jets_phi, jets_pt, torch.zeros(jets_pt.size(0)), dq.eta, dq.phi, torch.zeros(dq.eta.size(0))) | |
| else: | |
| mask = torch.ones_like(jets_pt, dtype=torch.bool) | |
| if obj_score is not None: | |
| obj_score = obj_score[mask] | |
| assert len(jets_pt[mask]) == len(obj_score), "Error! len(jets_pt[mask])=%d, len(obj_score)=%d" % (len(jets_pt[mask]), len(obj_score)) | |
| if target_obj_score is not None: | |
| target_obj_score = target_obj_score[mask] | |
| assert len(jets_pt[mask]) == len(target_obj_score), "Error! len(jets_pt[mask])=%d, len(obj_score)=%d" % (len(jets_pt[mask]), len(obj_score)) | |
| return EventJets(jets_pt[mask], jets_eta[mask], jets_phi[mask], jets_mass[mask], obj_score=obj_score, target_obj_score=target_obj_score), bc_scores, cluster_labels | |
| def get_iter(self): | |
| self.i = 0 | |
| while self.i < self.n_events: | |
| yield self.get_idx(self.i) | |
| self.i += 1 | |
| def __iter__(self): | |
| return self.get_iter() | |
| def __getitem__(self, i): | |
| assert i < self.n_events, "Index out of bounds: %d >= %d" % (i, self.n_events) | |
| return self.get_idx(i) | |
| class SimpleIterDataset(torch.utils.data.IterableDataset): | |
| r"""Base IterableDataset. | |
| Handles dataloading. | |
| Arguments: | |
| file_dict (dict): dictionary of lists of files to be loaded. | |
| data_config_file (str): YAML file containing data format information. | |
| for_training (bool): flag indicating whether the dataset is used for training or testing. | |
| When set to ``True``, will enable shuffling and sampling-based reweighting. | |
| When set to ``False``, will disable shuffling and reweighting, but will load the observer variables. | |
| load_range_and_fraction (tuple of tuples, ``((start_pos, end_pos), load_frac)``): fractional range of events to load from each file. | |
| E.g., setting load_range_and_fraction=((0, 0.8), 0.5) will randomly load 50% out of the first 80% events from each file (so load 50%*80% = 40% of the file). | |
| fetch_by_files (bool): flag to control how events are retrieved each time we fetch data from disk. | |
| When set to ``True``, will read only a small number (set by ``fetch_step``) of files each time, but load all the events in these files. | |
| When set to ``False``, will read from all input files, but load only a small fraction (set by ``fetch_step``) of events each time. | |
| Default is ``False``, which results in a more uniform sample distribution but reduces the data loading speed. | |
| fetch_step (float or int): fraction of events (when ``fetch_by_files=False``) or number of files (when ``fetch_by_files=True``) to load each time we fetch data from disk. | |
| Event shuffling and reweighting (sampling) is performed each time after we fetch data. | |
| So set this to a large enough value to avoid getting an imbalanced minibatch (due to reweighting/sampling), especially when ``fetch_by_files`` set to ``True``. | |
| Will load all events (files) at once if set to non-positive value. | |
| file_fraction (float): fraction of files to load. | |
| """ | |
| def __init__( | |
| self, | |
| file_dict, | |
| data_config_file, | |
| for_training=True, | |
| load_range_and_fraction=None, | |
| extra_selection=None, | |
| fetch_by_files=False, | |
| fetch_step=0.01, | |
| file_fraction=1, | |
| remake_weights=False, | |
| up_sample=True, | |
| weight_scale=1, | |
| max_resample=10, | |
| async_load=True, | |
| infinity_mode=False, | |
| in_memory=False, | |
| name="", | |
| laplace=False, | |
| edges=False, | |
| diffs=False, | |
| dataset_cap=None, | |
| n_noise=0, | |
| synthetic=False, | |
| synthetic_npart_min=2, | |
| synthetic_npart_max=5, | |
| jets=False, | |
| ): | |
| self._iters = {} if infinity_mode or in_memory else None | |
| _init_args = set(self.__dict__.keys()) | |
| self._init_file_dict = file_dict | |
| self._init_load_range_and_fraction = load_range_and_fraction | |
| self._fetch_by_files = fetch_by_files | |
| self._fetch_step = fetch_step | |
| self._file_fraction = file_fraction | |
| self._async_load = async_load | |
| self._infinity_mode = infinity_mode | |
| self._in_memory = in_memory | |
| self._name = name | |
| self.laplace = laplace | |
| self.edges = edges | |
| self.diffs = diffs | |
| self.synthetic = synthetic | |
| self.synthetic_npart_min = synthetic_npart_min | |
| self.synthetic_npart_max = synthetic_npart_max | |
| self.dataset_cap = dataset_cap # used to cap the dataset to some fixed number of events - used for debugging purposes | |
| self.n_noise = n_noise | |
| self.jets = jets | |
| # ==== sampling parameters ==== | |
| self._sampler_options = { | |
| "up_sample": up_sample, | |
| "weight_scale": weight_scale, | |
| "max_resample": max_resample, | |
| } | |
| if for_training: | |
| self._sampler_options.update(training=True, shuffle=False, reweight=True) | |
| else: | |
| self._sampler_options.update(training=False, shuffle=False, reweight=False) | |
| # discover auto-generated reweight file | |
| if ".auto.yaml" in data_config_file: | |
| data_config_autogen_file = data_config_file | |
| else: | |
| data_config_md5 = _md5(data_config_file) | |
| data_config_autogen_file = data_config_file.replace( | |
| ".yaml", ".%s.auto.yaml" % data_config_md5 | |
| ) | |
| if os.path.exists(data_config_autogen_file): | |
| data_config_file = data_config_autogen_file | |
| _logger.info( | |
| "Found file %s w/ auto-generated preprocessing information, will use that instead!" | |
| % data_config_file | |
| ) | |
| # load data config (w/ observers now -- so they will be included in the auto-generated yaml) | |
| self._data_config = DataConfig.load(data_config_file) | |
| if for_training: | |
| # produce variable standardization info if needed | |
| if self._data_config._missing_standardization_info: | |
| s = AutoStandardizer(file_dict, self._data_config) | |
| self._data_config = s.produce(data_config_autogen_file) | |
| # produce reweight info if needed | |
| # if self._sampler_options['reweight'] and self._data_config.weight_name and not self._data_config.use_precomputed_weights: | |
| # if remake_weights or self._data_config.reweight_hists is None: | |
| # w = WeightMaker(file_dict, self._data_config) | |
| # self._data_config = w.produce(data_config_autogen_file) | |
| # reload data_config w/o observers for training | |
| if ( | |
| os.path.exists(data_config_autogen_file) | |
| and data_config_file != data_config_autogen_file | |
| ): | |
| data_config_file = data_config_autogen_file | |
| _logger.info( | |
| "Found file %s w/ auto-generated preprocessing information, will use that instead!" | |
| % data_config_file | |
| ) | |
| self._data_config = DataConfig.load( | |
| data_config_file, load_observers=False, extra_selection=extra_selection | |
| ) | |
| else: | |
| self._data_config = DataConfig.load( | |
| data_config_file, | |
| load_reweight_info=False, | |
| extra_test_selection=extra_selection, | |
| ) | |
| # Derive all variables added to self.__dict__ | |
| self._init_args = set(self.__dict__.keys()) - _init_args | |
| def config(self): | |
| return self._data_config | |
| def __iter__(self): | |
| if self._iters is None: | |
| kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args} | |
| return _SimpleIter(**kwargs) | |
| else: | |
| worker_info = torch.utils.data.get_worker_info() | |
| worker_id = worker_info.id if worker_info is not None else 0 | |
| try: | |
| return self._iters[worker_id] | |
| except KeyError: | |
| kwargs = {k: copy.deepcopy(self.__dict__[k]) for k in self._init_args} | |
| self._iters[worker_id] = _SimpleIter(**kwargs) | |
| return self._iters[worker_id] | |