Yajur Preetham committed
Commit c4bb20c · 1 Parent(s): 9328007

Added script to visualize all input variable distributions for a model.

root_gnn_dgl/root_gnn_base/visualize_input_distributions.py ADDED
@@ -0,0 +1,582 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import uproot
+ import yaml
+ import argparse
+ import sys
+ from pathlib import Path
+ from array import array
+ import os
+ import awkward as ak
+ import math
+
+ def tree_to_dataframe(tree_filepath, sort_by="", branches=[]):
+     """
+     Convert a ROOT tree to a pandas DataFrame (assumes the data is columnar).
+     Depends on the uproot and pandas libraries (import them beforehand).
+     """
+     data_dict = {}  # Use a dictionary instead of a list
+
+     with uproot.open(tree_filepath) as file:
+         if not branches:  # If the branches list is empty, load everything
+             keys = file.keys()
+             for key in keys:
+                 try:
+                     data_dict[key] = file[key].array(library="pd")
+                 except Exception as e:
+                     print(f"Warning: Could not load branch '{key}': {e}")
+         else:  # If specific branches are requested
+             for branch in branches:
+                 try:
+                     data_dict[branch] = file[branch].array(library="pd")
+                 except KeyError:
+                     print(f"Warning: Branch '{branch}' not found in ROOT file")
+                 except Exception as e:
+                     print(f"Warning: Could not load branch '{branch}': {e}")
+
+     # Create DataFrame from dictionary
+     data = pd.DataFrame(data_dict)
+
+     if sort_by == "":
+         return data
+     else:
+         if sort_by in data.columns:
+             data.sort_values(by=[sort_by], inplace=True)
+             data.reset_index(inplace=True, drop=True)
+         else:
+             print(f"Warning: Sort column '{sort_by}' not found in DataFrame")
+         return data
+
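+ # Example (file, tree, and branch names are illustrative): load two branches from a
+ # "path:tree" specifier, as used by the callers below, and sort by event number.
+ #   df = tree_to_dataframe("data/sample.root:nominal", sort_by="eventNumber",
+ #                          branches=["jet_pt", "jet_eta"])
+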
+ def extract_dataset_info(yaml_file_path):
+     """Read the YAML config and collect the per-dataset arguments needed for plotting."""
+     with open(yaml_file_path, 'r') as file:
+         config = yaml.safe_load(file)
+
+     datasets_info = {}
+     if "Datasets" in config:
+         for dset_name, dset_config in config['Datasets'].items():
+             if 'args' not in dset_config:
+                 continue
+             args = dset_config["args"]
+             dset_info = {}
+             if "raw_dir" in args:
+                 dset_info["raw_dir"] = args["raw_dir"]
+             if "file_names" in args:
+                 dset_info["file_names"] = args["file_names"]
+             if "node_branch_names" in args:
+                 dset_info["node_branch_names"] = args["node_branch_names"]
+             if "name" in args:
+                 dset_info["name"] = args["name"]
+             if "node_feature_scales" in args:
+                 dset_info["node_feature_scales"] = args["node_feature_scales"]
+             if "tree_name" in args:
+                 dset_info["tree_name"] = args["tree_name"]
+             if "label" in args:
+                 dset_info["label"] = args["label"]
+             # if "exclude_zeros" in args:
+             #     dset_info["exclude_zeros"] = args["exclude_zeros"]
+             # if "exclude_zeros" not in args:
+             #     print("ERROR: Please add the following variable to your config, under args for each dataset.\n"
+             #           "For example: exclude_zeros: [pt, phi, eta]\n"
+             #           "exclude_zeros should be a list of the endings of the observable names for which "
+             #           "the value 0 should be excluded when plotting histograms.")
+             #     sys.exit()
+             if dset_info:
+                 datasets_info[dset_name] = dset_info
+     return datasets_info
+
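+ # The config layout this function expects looks roughly like the following
+ # (dataset name, paths, and branch names are illustrative):
+ #   Datasets:
+ #     my_dataset:
+ #       args:
+ #         raw_dir: /path/to/ntuples/
+ #         file_names: sample.root
+ #         tree_name: nominal
+ #         node_branch_names: [[jet_pt, jet_eta, jet_phi]]
+ #         label: 0
+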
+ def adaptive_bins(data, method='auto'):
+     """Choose optimal number of bins based on data characteristics"""
+     data = np.array([x for x in data if x is not None and not np.isnan(x)])
+
+     if len(data) == 0:
+         return 10
+
+     if method == 'sturges':
+         return int(np.ceil(np.log2(len(data)) + 1))
+     elif method == 'scott':
+         h = 3.5 * np.std(data) / (len(data) ** (1/3))
+         return int(np.ceil((np.max(data) - np.min(data)) / h))
+     elif method == 'freedman':
+         iqr = np.percentile(data, 75) - np.percentile(data, 25)
+         h = 2 * iqr / (len(data) ** (1/3))
+         return int(np.ceil((np.max(data) - np.min(data)) / h)) if h > 0 else 50
+     elif method == 'sqrt':
+         return int(np.ceil(np.sqrt(len(data))))
+     else:  # 'auto'
+         return 'auto'  # Let matplotlib decide
+
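+ # For example, with the 'freedman' rule and ~10,000 roughly standard-normal values
+ # (IQR ~ 1.35, observed range ~ 7.5), the bin width is h = 2 * 1.35 / 10000**(1/3) ~ 0.13,
+ # giving on the order of 7.5 / 0.13 ~ 60 bins.
+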
+ def safe_clean_data(data, observable_name=""):
+     """Safely clean data, handling different data types.
+
+     NaN/inf values and the sentinel values -999 and -1 are dropped.
+     Zero-exclusion for specific observables is currently disabled (see commented code).
+     """
+     if data is None or len(data) == 0:
+         return []
+
+     # Convert to numpy array if it isn't already
+     if not isinstance(data, np.ndarray):
+         data = np.array(data)
+
+     # Check if we should ignore zeros
+     # ignore_zeros = observable_name.lower().endswith(exclude_zeros)
+
+     # Handle different data types
+     if data.dtype.kind in ['i', 'f']:  # integer or float
+         # Numeric data - can use isnan and isfinite
+         if data.dtype.kind == 'f':  # float
+             mask = ~np.isnan(data) & np.isfinite(data)
+             clean_data = data[mask]
+         else:  # integer
+             clean_data = data  # integers don't have NaN/inf issues
+
+         # Drop common sentinel values
+         clean_data = clean_data[(clean_data != -999) & (clean_data != -1)]
+
+         # Remove zeros if needed
+         # if ignore_zeros:
+         #     clean_data = clean_data[clean_data != 0]
+
+         return clean_data
+     else:
+         # Non-numeric data - filter manually
+         clean_list = []
+         for item in data:
+             if item is None:
+                 continue
+
+             try:
+                 # Try to convert to float to check if it's numeric
+                 float_val = float(item)
+                 if not (np.isnan(float_val) or np.isinf(float_val)):
+                     # Zero removal is disabled here, mirroring the numeric branch above
+                     # if ignore_zeros and float_val == 0:
+                     #     continue
+                     clean_list.append(float_val)
+             except (ValueError, TypeError):
+                 # Not convertible to float, skip
+                 continue
+         return np.array(clean_list) if clean_list else np.array([])
+
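+ # Example: safe_clean_data(np.array([1.0, np.nan, -999.0, 3.0])) keeps only the
+ # finite, non-sentinel entries and returns array([1., 3.]).
+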
+ def make_distributions(dset_info, output_dir, exclude_zeros):
+     os.makedirs(output_dir, exist_ok=True)
+     awk_type = ak.Array
+     list_type = type([])
+
+     for dset_name in dset_info:
+         curr_dset_info = dset_info[dset_name]
+         curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_dset_info['file_names']}:{curr_dset_info['tree_name']}")
+
+         # Collect all observables and their data for this dataset
+         observables_data = {}
+
+         for branch in curr_dset_info["node_branch_names"]:
+             if type(branch) != list_type:
+                 continue
+             for observable in branch:
+                 if type(observable) != type("str"):
+                     continue
+                 try:
+                     data = curr_df[observable]
+                     if type(data.iloc[0]) == awk_type or type(data.iloc[0]) == list_type:
+                         appended_data = []
+                         for i in range(len(data.iloc[0])):
+                             try:
+                                 ith_obs_data = np.array([x[i] if x is not None and len(x) > i else None for x in data])
+                                 # Filter out None values
+                                 ith_obs_data = ith_obs_data[ith_obs_data != None]
+                                 if len(ith_obs_data) > 0:
+                                     appended_data.append(ith_obs_data)
+                             except (IndexError, TypeError):
+                                 continue
+                         if appended_data:
+                             plot_data = np.concatenate(appended_data)
+                             observables_data[observable] = plot_data
+                     else:
+                         observables_data[observable] = data
+                 except KeyError:
+                     continue
+
+         # Create subplot grid for all observables in this dataset
+         if not observables_data:
+             print(f"No data found for {dset_name}")
+             continue
+
+         n_observables = len(observables_data)
+
+         # Calculate grid dimensions (try to make it roughly square)
+         n_cols = math.ceil(math.sqrt(n_observables))
+         n_rows = math.ceil(n_observables / n_cols)
+
+         # Create the figure with subplots
+         fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows))
+         fig.suptitle(f'All Distributions for {dset_name}', fontsize=16, y=0.98)
+
+         # Handle case where there's only one subplot
+         if n_observables == 1:
+             axes = [axes]
+         elif n_rows == 1:
+             axes = axes.reshape(1, -1)
+         elif n_cols == 1:
+             axes = axes.reshape(-1, 1)
+
+         # Flatten axes for easy iteration
+         axes_flat = axes.flatten() if n_observables > 1 else axes
+
+         # Plot each observable
+         for idx, (observable, plot_data) in enumerate(observables_data.items()):
+             ax = axes_flat[idx]
+
+             # Clean data safely
+             clean_data = safe_clean_data(plot_data, observable)
+
+             if len(clean_data) > 0:
+                 try:
+                     bins = adaptive_bins(clean_data, method="freedman")
+                     # Plot histogram with label including event count
+                     ax.hist(clean_data, histtype="step", density=True, bins=bins,
+                             label=f'N = {len(clean_data):,}')
+                     if observable.lower().endswith(exclude_zeros):
+                         ax.set_title(f'{observable} (zeros excluded)', fontsize=10)
+                     else:
+                         ax.set_title(f'{observable}', fontsize=10)
+                     ax.set_xlabel(f'{observable}', fontsize=8)
+                     ax.set_ylabel('Density', fontsize=8)
+                     ax.tick_params(axis='both', which='major', labelsize=7)
+                     ax.grid(True, alpha=0.3)
+
+                     # Add legend with event count
+                     ax.legend(fontsize=8, loc='upper right')
+
+                 except Exception as e:
+                     print(f"Error plotting {observable}: {e}")
+                     ax.text(0.5, 0.5, f'Plot error:\n{str(e)[:50]}...', ha='center', va='center',
+                             transform=ax.transAxes, fontsize=8)
+                     ax.set_title(f'{observable} (Error)', fontsize=10)
+             else:
+                 ax.text(0.5, 0.5, 'No valid data\nN = 0', ha='center', va='center',
+                         transform=ax.transAxes)
+                 ax.set_title(f'{observable} (No Data)', fontsize=10)
+
+         # Hide unused subplots
+         for idx in range(n_observables, len(axes_flat)):
+             axes_flat[idx].set_visible(False)
+
+         # Adjust layout and save
+         plt.tight_layout()
+         plt.subplots_adjust(top=0.93)  # Make room for suptitle
+         plt.savefig(f"{output_dir}/{dset_name}_all_distributions.png",
+                     dpi=300, bbox_inches='tight')
+         plt.close()
+
+         print(f"Created combined plot for {dset_name} with {n_observables} observables")
+
+ def make_distributions_comparison_grid_by_label(dset_info, output_dir, output_filename, label_names=None, use_percentile_for_xlims=False, xlim_adjustment=False):
+     """Create comparison plots grouped by label instead of dataset.
+
+     Args:
+         dset_info: Dictionary containing dataset information
+         output_dir: Directory to save output plots
+         output_filename: Name of the output image file
+         label_names: Optional list of strings to use as label names in legends.
+             If provided, must have length equal to the number of unique labels.
+             Index corresponds to label number.
+         use_percentile_for_xlims: If True, cap the x-axis at the 98th percentile of the data.
+         xlim_adjustment: If True, set the x-axis limits from the mean +/- 3 std of the data.
+     """
+     os.makedirs(output_dir, exist_ok=True)
+     awk_type = ak.Array
+     list_type = type([])
+
+     label_to_datasets = {}
+     for dset_name, curr_dset_info in dset_info.items():
+         dataset_label = curr_dset_info.get('label', 'Unknown')
+         if dataset_label not in label_to_datasets:
+             label_to_datasets[dataset_label] = []
+         label_to_datasets[dataset_label].append(dset_name)
+
+     # First, collect all data organized by observable and then by label
+     observables_by_variable = {}
+
+     for dset_name in dset_info:
+         print(f"Processing dataset: {dset_name}")
+         curr_dset_info = dset_info[dset_name]
+
+         # Get the label for this dataset
+         dataset_label = curr_dset_info.get('label', 'Unknown')
+         print(f"  Label: {dataset_label}")
+
+         if type(curr_dset_info['file_names']) == type("str"):
+             curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_dset_info['file_names']}:{curr_dset_info['tree_name']}")
+         else:
+             curr_df_list = []
+             for i in range(len(curr_dset_info['file_names'])):
+                 curr_name = curr_dset_info['file_names'][i]
+                 curr_curr_df = tree_to_dataframe(f"{curr_dset_info['raw_dir']}{curr_name}:{curr_dset_info['tree_name']}")
+                 curr_df_list.append(curr_curr_df)
+             curr_df = pd.concat(curr_df_list, ignore_index=True)
+
+         for branch in curr_dset_info["node_branch_names"]:
+             if type(branch) != list_type:
+                 continue
+             for observable in branch:
+                 if type(observable) != type("str"):
+                     continue
+                 try:
+                     data = curr_df[observable]
+
+                     # Initialize observable dict if not exists
+                     if observable not in observables_by_variable:
+                         observables_by_variable[observable] = {}
+
+                     # Initialize label dict if not exists
+                     if dataset_label not in observables_by_variable[observable]:
+                         observables_by_variable[observable][dataset_label] = []
+
+                     if type(data.iloc[0]) == awk_type or type(data.iloc[0]) == list_type:
+                         appended_data = []
+                         # Flatten the per-event lists, skipping zero-padded entries
+                         for x in data:
+                             row_data = []
+                             for i in range(len(x)):
+                                 if x[i] == 0 or x[i] == 0.0:
+                                     continue
+                                 row_data.append(x[i])
+                             row_data = np.array(row_data)
+                             row_data = row_data[row_data != None]
+                             if len(row_data) > 0:
+                                 appended_data.append(row_data)
+
+                         if appended_data:
+                             plot_data = np.concatenate(appended_data)
+                             observables_by_variable[observable][dataset_label].append(plot_data)
+                     else:
+                         observables_by_variable[observable][dataset_label].append(data)
+
+                 except KeyError:
+                     continue
+
+     # Combine data for each label (since multiple datasets might have the same label)
+     observables_by_label = {}
+     for observable, labels_data in observables_by_variable.items():
+         observables_by_label[observable] = {}
+         for label, data_list in labels_data.items():
+             if data_list:
+                 # Concatenate all data for this label
+                 combined_data = []
+                 for data in data_list:
+                     clean_data = safe_clean_data(data, observable)
+                     if len(clean_data) > 0:
+                         combined_data.extend(clean_data)
+
+                 if combined_data:
+                     observables_by_label[observable][label] = np.array(combined_data)
+
+     # Filter out observables with no data
+     observables_by_label = {k: v for k, v in observables_by_label.items() if v}
+
+     if not observables_by_label:
+         print("No observables found!")
+         return
+
+     # Get consistent colors for labels across all plots
+     all_labels = set()
+     for labels_data in observables_by_label.values():
+         all_labels.update(labels_data.keys())
+     all_labels = sorted(list(all_labels))  # Sort for consistency
+
+     print(f"Found labels: {all_labels}")
+
+     # Validate label_names parameter if provided
+     if label_names is not None:
+         if len(label_names) != len(all_labels):
+             raise ValueError(f"label_names must have length {len(all_labels)} to match the number of unique labels, but got {len(label_names)}")
+         print(f"Using custom label names: {label_names}")
+
+     # Calculate grid dimensions
+     n_observables = len(observables_by_label)
+     n_cols = math.ceil(math.sqrt(n_observables))
+     n_rows = math.ceil(n_observables / n_cols)
+
+     print(f"Creating comparison grid for {n_observables} observables ({n_rows}x{n_cols})")
+
+     # Create the big figure
+     fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
+     fig.suptitle('Distribution Comparisons Across All Labels', fontsize=20, y=0.98)
+
+     # Handle different subplot configurations
+     if n_observables == 1:
+         axes = [axes]
+     elif n_rows == 1:
+         axes = axes.reshape(1, -1)
+     elif n_cols == 1:
+         axes = axes.reshape(-1, 1)
+
+     # Flatten axes for easy iteration
+     axes_flat = axes.flatten() if n_observables > 1 else axes
+
+     # Create color map for labels
+     colors = plt.cm.tab10(np.linspace(0, 1, len(all_labels)))
+     label_colors = dict(zip(all_labels, colors))
+
+     # Plot each observable
+     for idx, (observable, labels_data) in enumerate(observables_by_label.items()):
+         ax = axes_flat[idx]
+
+         # Calculate consistent bins based on ALL data for this observable
+         all_combined_data = []
+         for label_data in labels_data.values():
+             all_combined_data.extend(label_data)
+
+         if not all_combined_data:
+             ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
+             ax.set_title(f'{observable} (No Data)', fontsize=12)
+             continue
+
+         combined_array = np.array(all_combined_data)
+         if observable == "ph_phi" or observable == "ph_eta":
+             n_bins = 10
+         elif observable == "m_jet_btag77":
+             n_bins = 4
+         else:
+             n_bins = adaptive_bins(combined_array, method="freedman")
+             if n_bins > 35:  # Cap the number of bins to control the fineness of the binning
+                 n_bins = 35
+         bin_edges = np.histogram_bin_edges(combined_array, bins=n_bins)
+
+         print(f"{observable}: Using {len(bin_edges)-1} consistent bins for {len(labels_data)} labels")
+
+         # Plot each label's distribution for this observable
+         for label, plot_data in labels_data.items():
+             try:
+                 # Determine label for legend
+                 if label_names is not None:
+                     # Use custom label name based on label index
+                     label_idx = all_labels.index(label)
+                     legend_label = f'{label_names[label_idx]} (N={len(plot_data):,})'
+                 else:
+                     # Use original format
+                     legend_label = f'Label {label} (N={len(plot_data):,})'
+
+                 ax.hist(plot_data, bins=bin_edges, histtype="step", density=True,
+                         label=legend_label,
+                         color=label_colors[label], linewidth=1.5, alpha=0.8)
+             except Exception as e:
+                 print(f"Error plotting {observable} for label {label}: {e}")
+                 continue
+
+         # Add title and labels
+         title = f'{observable}'
+         # if observable.lower().endswith(exclude_zeros):
+         #     title += ' (zeros excluded)'
+
+         if use_percentile_for_xlims and xlim_adjustment:
+             print("ERROR: Only provide one of the flags at a time, either --use_percentile_for_xlims or --xlim_adjustment")
+             return
+         if not use_percentile_for_xlims and not xlim_adjustment:
+             ax.set_xlim(bin_edges[0], bin_edges[-1])
+         elif use_percentile_for_xlims:
+             combined_array = np.array(all_combined_data)
+             ax.set_xlim(bin_edges[0], np.percentile(combined_array, 98))
+         elif xlim_adjustment:
+             combined_array = np.array(all_combined_data)
+             min_edge = max(bin_edges[0], np.mean(combined_array) - 3*np.std(combined_array))
+             max_edge = min(bin_edges[-1], np.mean(combined_array) + 3*np.std(combined_array))
+             ax.set_xlim(min_edge, max_edge)
+
+         ax.set_title(title, fontsize=12, pad=10)
+         ax.set_xlabel(f'{observable}', fontsize=10)
+         ax.set_ylabel('Density', fontsize=10)
+         ax.tick_params(axis='both', which='major', labelsize=8)
+         ax.grid(True, alpha=0.3)
+
+         # Create legend
+         if len(labels_data) <= 5:
+             if label_names is not None:
+                 # Simple legend with just custom names and counts
+                 ax.legend(fontsize=8, loc='best')
+             else:
+                 # Create custom legend labels with dataset information
+                 legend_labels = []
+                 for label in labels_data.keys():
+                     datasets = label_to_datasets.get(label, [])
+
+                     if len(datasets) == 1:
+                         # Single dataset
+                         dataset_info = datasets[0]
+                     elif len(datasets) <= 2:
+                         # Few datasets - show all names
+                         dataset_info = ', '.join(datasets)
+                     else:
+                         # Many datasets - show count
+                         dataset_info = f"{datasets[0]}, +{len(datasets)-1} more"
+
+                     legend_labels.append(f'Label {label} (N={len(labels_data[label]):,})\n{dataset_info}')
+
+                 # Get the legend handles and update their labels
+                 handles, _ = ax.get_legend_handles_labels()
+                 ax.legend(handles, legend_labels, fontsize=6, loc='best')
+         else:
+             total_events = sum(len(data) for data in labels_data.values())
+             ax.set_title(f'{title}\n(Total N={total_events:,})', fontsize=11)
+
+     # Hide unused subplots
+     for idx in range(n_observables, len(axes_flat)):
+         axes_flat[idx].set_visible(False)
+
+     # Adjust layout and save
+     plt.tight_layout()
+     plt.subplots_adjust(top=0.94, right=0.85 if len(all_labels) > 5 else 0.95)
+
+     output_path = f"{output_dir}/{output_filename}"
+     plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
+     plt.close()
+
+     print(f"Created comparison grid by label: {output_path}")
+     print(f"Grid contains {n_observables} observables across {len(all_labels)} labels")
+
+     # Print summary of what was combined
+     print("\nLabel summary:")
+     for label in all_labels:
+         datasets_with_label = [dset for dset, info in dset_info.items() if info.get('label') == label]
+         if label_names is not None:
+             label_idx = all_labels.index(label)
+             display_name = label_names[label_idx]
+         else:
+             display_name = f"Label {label}"
+         print(f"  {display_name}: {len(datasets_with_label)} datasets ({', '.join(datasets_with_label)})")
+
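+ # Example (config path, output names, and label names are illustrative):
+ #   dsets = extract_dataset_info("configs/datasets.yaml")
+ #   make_distributions_comparison_grid_by_label(dsets, "plots", "comparison.png",
+ #                                               label_names=["Background", "Signal"])
+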
546
+ def main(): ###DONT SPECIFY EXCLUDE ZEROS HERE, BUT RATHER DERIVE IT FROM THE CONFIG!!!
547
+ parser = argparse.ArgumentParser()
548
+ add_arg = parser.add_argument
549
+
550
+ add_arg("--config", type=str, required = True, help = "The path to the config.")
551
+ add_arg("--output_dir", type=str, required = True, help = "The path of the directory where you want the plots to be outputted to.")
552
+ add_arg('--label_names', nargs='+', default = ["None"], help = "A list of the names associated with each label to be displayed in the legends of the histograms.")
553
+ add_arg("--output_filename", type=str, default = "input_var_distribution_comparisons.png", help = "The name of the file you want the plots to be outputted to.")
554
+ add_arg("--use_percentile_for_xlims", action = "store_true", help = "If this flag is provided, the xlims will be set as [first bin edge, 98th percentile] rather than [first bin edge, last bin edge].")
555
+ add_arg("--xlim_adjustment", action = "store_true", help = "If this flag is provided, the xlims will be set using the mean and std of the data.")
556
+
557
+ args = parser.parse_args()
558
+
559
+ config_filepath = args.config
560
+ output_dir = args.output_dir
561
+ label_names = args.label_names
562
+ output_filename = args.output_filename
563
+ use_percentile = args.use_percentile_for_xlims
564
+ xlim_adjustment = args.xlim_adjustment
565
+
566
+ dset = extract_dataset_info(config_filepath)
567
+
568
+ # exclude_zeros_list = []
569
+ # for key in dset:
570
+ # exclude_zeros_list = dset[key]["exclude_zeros"]
571
+ # break
572
+
573
+ # exclude_zeros = tuple(exclude_zeros_list)
574
+
575
+ # make_distributions(dset, output_dir, exclude_zeros)
576
+ if label_names[0] == "None":
577
+ make_distributions_comparison_grid_by_label(dset, output_dir, output_filename, use_percentile_for_xlims=use_percentile, xlim_adjustment=xlim_adjustment)
578
+ else:
579
+ make_distributions_comparison_grid_by_label(dset, output_dir, output_filename, label_names, use_percentile, xlim_adjustment)
580
+
581
+ if __name__ == "__main__":
582
+ main()
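+
+ # Example invocation (config path, output directory, and label names are illustrative):
+ #   python visualize_input_distributions.py --config configs/datasets.yaml \
+ #       --output_dir plots --label_names Background Signal --xlim_adjustment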