Spaces:
Sleeping
Sleeping
| # A simple wrapper to run the L-GATr model on HuggingFace spaces | |
| import shutil | |
| import glob | |
| import argparse | |
| import functools | |
| import numpy as np | |
| import math | |
| import torch | |
| import sys | |
| import os | |
| import wandb | |
| import time | |
| from pathlib import Path | |
| from src.layers.object_cond import calc_eta_phi | |
| torch.autograd.set_detect_anomaly(True) | |
| from src.dataset.functions_data import get_batch | |
| from src.dataset.functions_data import concat_events, Event, EventPFCands | |
| from src.plotting.plot_event import plot_event | |
| from src.dataset.dataset import EventDataset | |
| from src.jetfinder.clustering import get_clustering_labels | |
| from torch_scatter import scatter_sum | |
| from src.utils.train_utils import ( | |
| to_filelist, | |
| train_load, | |
| test_load, | |
| get_model, | |
| get_optimizer_and_scheduler, | |
| get_model_obj_score | |
| ) | |
| from src.utils.paths import get_path | |
| import warnings | |
| import pickle | |
| import os | |
| import fastjet | |
def _build_demo_args(model_path):
    """Return the fixed L-GATr configuration used by this demo.

    Only ``load_model_weights`` depends on the call; every other value is the
    hard-coded setting the demo runs the checkpoints with.

    Args:
        model_path: Path of the checkpoint to load.

    Returns:
        An ``argparse.Namespace`` holding the model/clustering settings.
    """
    # argparse.Namespace is the plain attribute container; an ArgumentParser
    # instance (used previously) carries unrelated parser state.
    return argparse.Namespace(
        spatial_part_only=True,  # LGATr
        load_model_weights=model_path,
        aug_soft=True,  # LGATr_GP etc.
        # NOTE(review): "1models" looks unusual for a source directory - confirm the path.
        network_config="src/1models/LGATr/lgatr.py",
        beta_type="pt+bc",
        embed_as_vectors=False,
        debug=False,
        epsilon=0.3,
        gen_level=False,
        parton_level=False,
        global_features_obj_score=False,
        gt_radius=0.8,
        no_pid=True,
        hidden_mv_channels=16,
        hidden_s_channels=64,
        internal_dim=128,
        lorentz_norm=False,
        min_cluster_size=2,
        min_samples=1,
        n_heads=4,
        num_blocks=10,
        scalars_oc=False,
    )


def _parse_numeric_rows(text, n_columns, what):
    """Parse whitespace-separated numeric rows from a free-text input box.

    Blank lines are skipped and columns beyond ``n_columns`` are ignored.

    Args:
        text: Multi-line user input; ``None`` is treated as empty.
        n_columns: Number of leading columns every non-blank row must provide.
        what: Human-readable name of the input, used in error messages.

    Returns:
        A list of rows, each a list of exactly ``n_columns`` floats.

    Raises:
        ValueError: If a row has too few columns or a value is not numeric.
    """
    rows = []
    for line_no, line in enumerate((text or "").splitlines(), start=1):
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines anywhere in the box
        if len(fields) < n_columns:
            raise ValueError(
                f"{what}, line {line_no}: expected {n_columns} values, "
                f"got {len(fields)} ({line.strip()!r})"
            )
        try:
            rows.append([float(field) for field in fields[:n_columns]])
        except ValueError as err:
            raise ValueError(
                f"{what}, line {line_no}: non-numeric value in {line.strip()!r}"
            ) from err
    return rows


def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
    """Cluster one event with an L-GATr checkpoint and compare it to anti-kt jets.

    Args:
        loss_str: Sub-directory of ``models/`` holding the checkpoint. If it
            contains ``"GP"``, 500 soft particles (fixed seed 0) are added to
            the event before inference.
        train_dataset_str: Checkpoint file stem; the model is loaded from
            ``models/{loss_str}/{train_dataset_str}.ckpt``.
        input_text: One particle per line as ``pt eta phi mass charge``
            (whitespace-separated; blank lines are ignored).
        input_text_quarks: One quark per line as ``pt eta phi``; may be empty.
            Quarks are only drawn (red triangles, marker size = pt), never
            clustered.

    Returns:
        ``(model_jets, ak_jets, fig)``: lists of ``{"pt", "eta", "phi"}`` dicts
        for the model jets and the anti-kt (R=0.8) jets passing the pt cut of
        30, and a matplotlib figure with three panels: particles colored by AK
        cluster, by model cluster, and by model cluster in the model's
        clustering-space coordinates.

    Raises:
        ValueError: If the particle input is empty or malformed, or the quark
            input is malformed.
    """
    pt_cut = 30  # jet pt threshold for both the model and the anti-kt jets
    circle_pt_cut = 100  # jets at least this hard also get a circle drawn
    ak_radius = 0.8

    # Parse the user input first so malformed text fails before the model load.
    particle_rows = _parse_numeric_rows(input_text, 5, "Particle input")
    if not particle_rows:
        raise ValueError(
            "Particle input is empty: expected one 'pt eta phi mass charge' row per particle."
        )
    pt, eta, phi, mass, charge = map(list, zip(*particle_rows))
    charge = [int(q) for q in charge]
    quark_rows = _parse_numeric_rows(input_text_quarks, 3, "Quark input")
    pt_quarks = [row[0] for row in quark_rows]
    eta_quarks = [row[1] for row in quark_rows]
    phi_quarks = [row[2] for row in quark_rows]

    # NOTE(review): loss_str / train_dataset_str are spliced into a filesystem
    # path. If they can come from free user input, validate them against the
    # known checkpoints; the .ckpt is presumably unpickled inside get_model.
    model_path = f"models/{loss_str}/{train_dataset_str}.ckpt"
    args = _build_demo_args(model_path)
    dev = torch.device("cpu")
    model = get_model(args, dev)
    if "lgatr" in args.network_config.lower():
        batch_config = {"use_four_momenta": True}
    else:
        batch_config = {"use_p_xyz": True, "use_four_momenta": False}
    batch_config["no_pid"] = True
    print("batch_config:", batch_config)
    model.eval()

    pid = torch.zeros(len(pt))
    pf_cand_jet_idx = [-1] * len(pt)
    pfcands = EventPFCands(pt, eta, phi, mass, charge, pid, pf_cand_jet_idx=pf_cand_jet_idx)
    n_soft = 500 if "GP" in loss_str else 0
    if n_soft > 0:
        pfcands = EventDataset.pfcands_add_soft_particles(
            pfcands, n_soft, random_generator=np.random.RandomState(seed=0)
        )
    event = Event(pfcands=pfcands)
    event_batch = concat_events([event])
    batch, _ = get_batch(event_batch, batch_config, torch.zeros(len(pfcands)), test=True)
    with torch.no_grad():
        # !!! Only use cpu_demo with batch size of 1 (quick fix for unavailability of xformers attention on CPU)
        coords = model(batch, cpu_demo=True)[:, 1:4]
    clust_labels = get_clustering_labels(
        coords.detach().cpu().numpy(),
        batch.batch_idx,
        min_cluster_size=args.min_cluster_size,
        min_samples=args.min_samples,
        epsilon=args.epsilon,
    )

    # Sum particle momenta per cluster. Labels are shifted by one so that noise
    # (label -1) lands in row 0, which is then dropped; row j is cluster j.
    jets_pxyz = scatter_sum(torch.tensor(pfcands.pxyz), torch.tensor(clust_labels + 1), dim=0)[1:]
    jets_pt = torch.norm(jets_pxyz[:, :2], p=2, dim=-1)
    filt = torch.where(jets_pt > pt_cut)[0].tolist()
    jets_eta, jets_phi = calc_eta_phi(jets_pxyz, False)
    # Cluster label -> position among the jets passing the pt cut (O(1) lookup).
    label_to_jet = {label: rank for rank, label in enumerate(filt)}
    clust_assignment = {}
    for i, label in enumerate(clust_labels):
        rank = label_to_jet.get(int(label))
        if rank is not None:
            clust_assignment[i] = rank
    jets_pt = jets_pt[filt]
    jets_eta = jets_eta[filt]
    jets_phi = jets_phi[filt]
    ak_pt, ak_eta, ak_phi, _, ak_assignment = EventDataset.get_jets_fastjets_raw_with_assignment(
        pfcands, fastjet.JetDefinition(fastjet.antikt_algorithm, ak_radius), pt_cutoff=pt_cut
    )
    model_coords = calc_eta_phi(coords, return_stacked=0)

    palette = ['#1f78b4', '#b3df8a', '#33a02c', '#fb9a99', '#e31a1c', '#fdbe6f', '#ff7f00',
               '#cab2d6', '#6a3d9a', '#ffff99', '#b15928']
    # Use the whole palette, and give jets beyond it one shared overflow color in
    # both panels (previously model jets past index 7 were drawn in the same
    # gray as unclustered particles).
    colors = {-1: "gray", **dict(enumerate(palette))}
    overflow_color = "purple"
    c_ak = [
        colors.get(ak_assignment[i], overflow_color) if i in ak_assignment else "gray"
        for i in range(len(pfcands))
    ]
    c = [
        colors.get(clust_assignment[i], overflow_color) if i in clust_assignment else "gray"
        for i in range(len(pfcands))
    ]

    import matplotlib.pyplot as plt

    # Panels: AK colors, model colors, model colors in clustering space.
    # NOTE(review): the figure stays registered with pyplot; a long-running
    # server should plt.close(fig) once it has been rendered.
    fig, ax = plt.subplots(1, 3, figsize=(10, 3.33))
    ax[0].set_title("Colors: AK clusters")
    ax[1].set_title("Colors: Model clusters")
    ax[2].set_title("Colors: Model clusters in cl. space")
    plot_event(event, colors=c_ak, ax=ax[0], jets=0)
    plot_event(event, colors=c, ax=ax[1], jets=0)
    plot_event(event, colors=c, ax=ax[2], custom_coords=model_coords, jets=0)

    ak_jets = []
    for jet_pt, jet_eta, jet_phi in zip(ak_pt, ak_eta, ak_phi):
        if jet_pt >= pt_cut:
            ax[0].text(jet_eta + 0.1, jet_phi + 0.1,
                       "pt=" + str(round(jet_pt, 1)), color="blue", fontsize=6, alpha=0.5)
            ak_jets.append({"pt": jet_pt, "eta": jet_eta, "phi": jet_phi})
        if jet_pt >= circle_pt_cut:
            for axis in ax:
                # A matplotlib artist can belong to only one Axes: one circle per panel.
                axis.add_artist(plt.Circle((jet_eta, jet_phi), ak_radius, color="green", fill=False, alpha=.7))

    model_jets = []
    for j in range(len(jets_pt)):
        # Every remaining model jet already passed the pt cut above.
        jet_pt = jets_pt[j].item()
        jet_eta = jets_eta[j].item()
        jet_phi = jets_phi[j].item()
        ax[1].text(jet_eta + 0.1, jet_phi + 0.1,
                   "pt=" + str(round(jet_pt, 1)), color="blue", fontsize=6, alpha=0.5)
        model_jets.append({"pt": jet_pt, "eta": jet_eta, "phi": jet_phi})
        if jet_pt >= circle_pt_cut:
            for axis in ax:
                axis.add_artist(plt.Circle((jet_eta, jet_phi), 0.7, color="blue", fill=False, alpha=.7))

    for axis in ax:
        axis.scatter(eta_quarks, phi_quarks, s=pt_quarks, c="red", marker="^", alpha=0.3)
        # Raw strings: "\e" / "\p" are invalid escapes in normal string literals.
        axis.set_xlabel(r"$\eta$")
        axis.set_ylabel(r"$\phi$")
    fig.tight_layout()
    return model_jets, ak_jets, fig