import math
from dataclasses import dataclass
from typing import List, Any, Union

import awkward as ak
import dgl
import numpy as np
import torch
|
|
| from dataset.Graphs import Graphs, save_graphs |
| from dataset import Normalization |
|
|
@dataclass
class ChunkConfig:
    """Configuration for converting one chunk of events into saved graphs.

    Consumed by process_chunk / process_single_entry below.
    """

    name: str                    # dataset name, embedded in output file names
    label: Union[str, int]       # branch name to read per-event labels from, or a fixed label value
    chunk_id: int                # index of this chunk (zero-padded into file names)
    batch_size: int
    arrays: List[Any]            # NOTE(review): indexed by branch name in process_single_entry
                                 # (cfg.arrays[branch][i]), so this behaves like a mapping /
                                 # awkward record array rather than a plain list — confirm type
    features: List[Any]          # NOTE(review): .items() is called on this — it is used as a
                                 # mapping {particle name: [branch specs]} despite the List hint
    globals: List[Any]           # per-event global feature branches ("NUM_NODES" is special-cased)
    weights: Union[str, float]   # branch name for per-event weights, or a fixed weight value
    tracking: List[Any]          # extra branches carried through unchanged as metadata
    branches: List[Any]
    dtype: torch.dtype           # dtype used for every tensor produced from this config
    save_path: str               # output directory for graph files and normalization stats
    prebatch: dict               # e.g. {"enabled": bool, "chunk_size": int}
|
|
def process_chunk(cfg: ChunkConfig):
    """Convert every entry of a chunk into a graph, then save graphs and stats.

    Builds one DGL graph per event via process_single_entry, stacks the
    per-event metadata into one tensor per key, writes normalization
    statistics, optionally shuffles/pre-batches, and saves to disk.
    """
    graph_list = []
    meta_dict = {
        'globals': [],
        'label': [],
        'weight': [],
        'tracking': [],
        'batch_num_nodes': [],
        'batch_num_edges': [],
    }

    for i in range(len(cfg.arrays)):
        g, meta = process_single_entry(cfg, i)
        graph_list.append(g)
        for k in meta_dict:
            meta_dict[k].append(meta[k])

    # Stack per-event metadata into one tensor per key.
    # NOTE(review): torch.stack requires every event to produce same-shaped
    # tensors for a given key (e.g. equal-length 'globals') — assumed here.
    for k in meta_dict:
        meta_dict[k] = torch.stack(meta_dict[k])

    graphs = Graphs(graphs=graph_list, metadata=meta_dict)
    Normalization.save_stats(graphs, f"{cfg.save_path}/stats/{cfg.name}_{cfg.chunk_id:04d}.json")

    # BUG FIX: cfg.prebatch is a plain dict (see ChunkConfig), so the old
    # getattr(cfg.prebatch, "enabled", False) always returned the default
    # False and pre-batching never ran. Use dict lookup instead.
    if cfg.prebatch.get("enabled", False):
        graphs.shuffle()
        graphs.batch(cfg.prebatch["chunk_size"])

    save_graphs(graphs, f"{cfg.save_path}/{cfg.name}_{cfg.chunk_id:04d}.bin")
|
|
def process_single_entry(cfg, i):
    """Build one graph and its metadata dict for event *i* of the chunk."""

    def _to_tensor(branch_name):
        # Read one branch for this event and convert awkward -> torch tensor.
        return torch.from_numpy(ak.to_numpy(cfg.arrays[branch_name][i])).to(cfg.dtype)

    # --- per-node features, one block per particle collection --------------
    feature_blocks: List[torch.Tensor] = []
    for particle, branch_list in cfg.features.items():
        columns: List[torch.Tensor] = []
        for branch in branch_list:
            if branch == "CALC_E":
                # Derived energy pT*cosh(eta): assumes columns[0] is pT and
                # columns[1] is eta, i.e. they appear earlier in branch_list
                # — TODO confirm against the feature configuration.
                columns.append(columns[0] * torch.cosh(columns[1]))
            elif isinstance(branch, str):
                columns.append(_to_tensor(branch))
            else:
                # Numeric constant: broadcast to the collection's length.
                columns.append(torch.full((columns[0].shape[0],), float(branch), dtype=cfg.dtype))
        # Skip empty collections so torch.cat below sees only non-empty blocks.
        if columns and columns[0].numel() > 0:
            feature_blocks.append(torch.stack(columns, dim=1))

    if feature_blocks:
        node_features = torch.cat(feature_blocks, dim=0)
    else:
        node_features = torch.empty((0, len(cfg.features)), dtype=cfg.dtype)

    # --- global (per-event) features ---------------------------------------
    global_feat_list: List[torch.Tensor] = []
    for b in cfg.globals:
        if b == "NUM_NODES":
            global_feat_list.append(torch.tensor([len(node_features)], dtype=cfg.dtype))
        else:
            global_feat_list.append(_to_tensor(b))
    if global_feat_list:
        global_feat = torch.cat(global_feat_list, dim=0)
    else:
        global_feat = torch.zeros((1,), dtype=cfg.dtype)

    # --- tracking metadata --------------------------------------------------
    tracking_list = [_to_tensor(b) for b in cfg.tracking]
    tracking = torch.cat(tracking_list, dim=0) if tracking_list else torch.zeros((1,), dtype=cfg.dtype)

    # --- scalar weight and label -------------------------------------------
    if isinstance(cfg.weights, str):
        weight = torch.tensor(float(cfg.arrays[cfg.weights][i]), dtype=cfg.dtype)
    else:
        weight = torch.tensor(cfg.weights, dtype=cfg.dtype)

    if isinstance(cfg.label, str):
        label = torch.tensor(float(cfg.arrays[cfg.label][i]), dtype=cfg.dtype)
    else:
        label = torch.tensor(cfg.label, dtype=cfg.dtype)

    g = make_graph(node_features, dtype=cfg.dtype)

    meta = {
        'globals': global_feat,
        'label': label,
        'weight': weight,
        'tracking': tracking,
        'batch_num_nodes': g.batch_num_nodes(),
        'batch_num_edges': g.batch_num_edges(),
    }
    return g, meta
|
|
# Cache of fully-connected (src, dst) index pairs, keyed by node count.
src_dst_cache = {}


def get_src_dst(num_nodes):
    """Return flattened (src, dst) indices of the complete directed graph on
    *num_nodes* nodes (self-loops included), memoized per node count."""
    cached = src_dst_cache.get(num_nodes)
    if cached is None:
        idx = torch.arange(num_nodes)
        # Equivalent to flattening an indexing='ij' meshgrid:
        # src varies slowly ([0,0,...,1,1,...]), dst varies fast ([0,1,...,0,1,...]).
        cached = (idx.repeat_interleave(num_nodes), idx.repeat(num_nodes))
        src_dst_cache[num_nodes] = cached
    return cached
|
|
@torch.jit.script
def compute_edge_features(eta, phi, src, dst):
    """Edge features [dR, deta, dphi] for each (src, dst) node pair.

    dphi is wrapped into (-pi, pi]; dR is sqrt(deta^2 + dphi^2).
    """
    deta = eta[src] - eta[dst]
    dphi = phi[src] - phi[dst]
    # Wrap delta-phi into (-pi, pi]. FIX: use math.pi instead of np.pi —
    # TorchScript supports the math module's constants, while references to
    # the numpy module are not supported inside scripted functions.
    dphi = torch.remainder(dphi + math.pi, 2 * math.pi) - math.pi
    dR = torch.sqrt(deta ** 2 + dphi ** 2)
    return torch.stack([dR, deta, dphi], dim=1)
|
|
def make_graph(node_features: torch.Tensor, dtype=torch.float32):
    """Build a fully-connected DGL graph (self-loops included) over the nodes.

    node_features: (num_nodes, num_features); columns 1 and 2 are taken to be
    eta and phi respectively for the edge features — TODO confirm this column
    ordering against the feature configuration.

    FIX: annotation corrected from torch.tensor (a function) to torch.Tensor,
    and the redundant .flatten() calls removed — get_src_dst already returns
    flattened index tensors.
    """
    num_nodes = node_features.shape[0]
    if num_nodes == 0:
        # Empty event: a graph with no nodes/edges but correctly-shaped data.
        g = dgl.graph(([], []))
        g.ndata['features'] = node_features
        g.edata['features'] = torch.empty((0, 3), dtype=dtype)
        # NOTE(review): .globals is only ever set on this empty-graph path,
        # so downstream code cannot rely on it in general — kept as-is for
        # backward compatibility.
        g.globals = torch.tensor([0], dtype=dtype)
        return g

    src, dst = get_src_dst(num_nodes)
    g = dgl.graph((src, dst))
    g.ndata['features'] = node_features

    eta = node_features[:, 1]
    phi = node_features[:, 2]
    g.edata['features'] = compute_edge_features(eta, phi, src, dst)

    return g