# LLM4HEP / check_soln.py
# Author: ho22joshua — "initial commit" (cfcbbc8)
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
# Optional cosmetics: apply the ATLAS matplotlib style when the package is
# installed. Any ImportError raised while importing or applying it falls back
# to matplotlib's default style instead of aborting validation.
try:
    import atlas_mpl_style as ampl

    ampl.use_atlas_style()
    plt.rcParams.update({'font.family': 'DejaVu Sans'})
except ImportError:
    print("Warning: ATLAS style not available, using default matplotlib style")
    plt.style.use('default')
# Plotting helpers (utils_plot) are unused by the array-only checks; the import
# stays disabled so the script does not depend on them.
# from utils_plot import plot_myy_comparison, plot_scores_comparison
import argparse
# Command-line interface.
#   --out_dir : run directory whose logs/ and arrays/ subdirectories are checked.
#               Required: every existence check below joins paths onto it, and
#               os.path.join(None, ...) would otherwise crash with a TypeError.
#   --step    : validate only one pipeline step (1-5); omit to validate all.
parser = argparse.ArgumentParser(
    description="Validate LLM4HEP pipeline outputs against the reference solution.")
add_arg = parser.add_argument
add_arg('--out_dir', required=True,
        help='output directory containing the logs/ and arrays/ subdirectories to validate')
add_arg('--step', type=int, choices=[1, 2, 3, 4, 5],
        help='Validate only specific step (1-5)')
args = parser.parse_args()
out_dir = args.out_dir
specific_step = args.step
def arrays_match(generated, reference, name: str, atol: float = 1e-10) -> bool:
    """
    Compare two arrays element-wise with a strict absolute tolerance.

    - NaNs are considered equal when they appear at the same positions.
    - rtol is 0.0, so only the absolute tolerance ``atol`` matters.
    - On failure, brief statistics are printed and
      ``analyze_array_differences`` is called for a detailed report.

    Parameters
    ----------
    generated, reference : array-like
        Generated output and reference solution; converted with ``np.asarray``.
    name : str
        Label used in the printed status lines.
    atol : float
        Maximum allowed absolute difference per element.

    Returns
    -------
    bool
        True if the shapes agree and every element pair is within ``atol``
        (or both NaN); False otherwise.
    """
    print(f"Validating {name}...")
    generated = np.asarray(generated)
    reference = np.asarray(reference)
    if generated.shape != reference.shape:
        print(f" ❌ Shape mismatch: {generated.shape} vs {reference.shape}")
        return False

    # Element-wise verdict; ``close.all()`` is exactly np.allclose(...), and the
    # mask is reused below so the reported mismatch count agrees with it.
    close = np.isclose(generated, reference, rtol=0.0, atol=atol, equal_nan=True)
    if close.all():
        print(f" ✅ {name} matches (atol={atol})")
        return True

    # Brief diff stats to aid debugging
    gen_nan = np.isnan(generated)
    ref_nan = np.isnan(reference)
    nan_mask_equal = np.array_equal(gen_nan, ref_nan)
    comparable = ~gen_nan & ~ref_nan
    # Count pairs that actually exceed the tolerance (not merely unequal bits).
    mismatches = int(np.count_nonzero(~close & comparable))
    print(f" ❌ {name} differs: NaN mask equal={nan_mask_equal}, finite mismatches={mismatches}/{int(comparable.sum())}")
    # Restrict magnitude stats to finite pairs: inf - inf would yield NaN.
    both_finite = np.isfinite(generated) & np.isfinite(reference)
    if both_finite.any():
        diffs = np.abs(generated[both_finite] - reference[both_finite])
        print(f" diff stats: max={diffs.max():.6g}, mean={diffs.mean():.6g}")
    # Additional debug: show sample mismatches
    print("🔍 Running detailed mismatch analysis...")
    analyze_array_differences(generated, reference, name)
    return False
def calculate_adaptive_tolerance(values, significant_digits=4, zero_tolerance=1e-10):
    """
    Return a per-element absolute tolerance scaled to each value's magnitude.

    For a non-zero value ``v`` the tolerance is ``|v| / 10**significant_digits``.
    Exact zeros get ``zero_tolerance`` because a relative scale is undefined there.
    NaN inputs produce NaN tolerances.

    Examples (significant_digits=4):
    - 123000  -> 12.3
    - 0.00014 -> 1.4e-08
    - 0       -> 1e-10 (the default ``zero_tolerance``)

    Parameters
    ----------
    values : array-like
        Values whose magnitudes set the tolerances; any shape.
    significant_digits : int
        Tolerance divisor exponent (larger means stricter).
    zero_tolerance : float
        Tolerance used where a value is exactly zero.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``values``.
    """
    values = np.asarray(values)
    tolerances = np.full(values.shape, zero_tolerance, dtype=float)
    non_zero_mask = values != 0
    if np.any(non_zero_mask):
        tolerances[non_zero_mask] = np.abs(values[non_zero_mask]) / (10 ** significant_digits)
    return tolerances
def _format_number(value):
    """Format a scalar with 6 significant digits, rendering NaN as 'NaN'."""
    return "NaN" if np.isnan(value) else f"{value:.6g}"


def _print_sample_mismatches(flat_gen, flat_ref, flat_not_close, limit=10):
    """Print up to ``limit`` flat-index mismatches with their values and difference."""
    print(f" 📊 Sample numeric mismatches (first {limit} indices):")
    for idx in np.flatnonzero(flat_not_close)[:limit]:
        gen_val = flat_gen[idx]
        ref_val = flat_ref[idx]
        print(f" Index {idx}: gen={gen_val}, ref={ref_val}, diff={gen_val - ref_val}")


def _print_column_differences(generated, reference, not_close, max_cols=10, max_rows=5):
    """For a 2D comparison, list differing columns and show differing rows of each."""
    cols_with_diffs = np.flatnonzero(not_close.any(axis=0))
    print(f" 📊 Columns with differences: {cols_with_diffs[:max_cols]} (showing first {max_cols})")
    shown_cols = cols_with_diffs[:max_cols]
    print(f" 📋 Sample entries (up to {max_rows} differing rows for each of the first {len(shown_cols)} differing columns):")
    print(" Row | Column | Generated Value | Reference Value | Difference")
    print(" ----|--------|----------------|-----------------|------------")
    for col_idx in shown_cols:
        # Show rows that actually differ in this column, not just the first rows.
        for row_idx in np.flatnonzero(not_close[:, col_idx])[:max_rows]:
            gen_val = generated[row_idx, col_idx]
            ref_val = reference[row_idx, col_idx]
            print(f" {row_idx:3d} | {col_idx:3d} | {_format_number(gen_val):>14} | "
                  f"{_format_number(ref_val):>15} | {_format_number(gen_val - ref_val):>10}")


def _report_nan_summary(generated, reference, threshold=1000):
    """Report NaN counts when either array holds more than ``threshold`` NaNs."""
    gen_nan = np.isnan(generated)
    ref_nan = np.isnan(reference)
    nan_gen = int(gen_nan.sum())
    nan_ref = int(ref_nan.sum())
    if nan_gen <= threshold and nan_ref <= threshold:
        return
    # Only claim identical patterns when the NaN positions really coincide.
    if np.array_equal(gen_nan, ref_nan):
        print(" ✅ Data structure consistency: Identical NaN patterns in generated and reference files")
        print(f" - Both files have {nan_gen:,} NaN values (excellent consistency)")
        return
    print(" ⚠️ Special values detected:")
    if nan_gen > threshold:
        print(f" - NaN in generated: {nan_gen:,}")
    if nan_ref > threshold:
        print(f" - NaN in reference: {nan_ref:,}")


def analyze_array_differences(generated, reference, array_name, significant_digits=4):
    """
    Print a diagnostic report of where two arrays differ.

    Unlike ``arrays_match`` this uses an adaptive per-element absolute
    tolerance from ``calculate_adaptive_tolerance``. Each pair may differ by
    ``max(|generated|, |reference|) / 10**significant_digits``, or by that
    helper's small default where both values are zero. Pairs involving a NaN
    are never counted as numeric mismatches; NaN counts are summarized
    separately.

    Parameters
    ----------
    generated, reference : array-like
        Arrays to compare; a shape mismatch is reported and nothing else.
    array_name : str
        Label used in the printed report.
    significant_digits : int
        Passed to ``calculate_adaptive_tolerance``.

    Returns
    -------
    None
        Output is printed only.
    """
    generated = np.asarray(generated)
    reference = np.asarray(reference)
    print(f"\n🔍 Detailed analysis for {array_name} (using {significant_digits} significant digit tolerance):")
    print(f" Generated shape: {generated.shape}, Reference shape: {reference.shape}")
    print(f" Tolerance: Adaptive based on {significant_digits} significant digits per value")
    if generated.shape != reference.shape:
        print(f" ❌ Shape mismatch: {generated.shape} vs {reference.shape}")
        return

    # Scale each tolerance by the larger magnitude of the pair so the check is
    # symmetric in (generated, reference). fmax ignores a NaN on one side.
    magnitudes = np.fmax(np.abs(generated), np.abs(reference))
    atol_array = calculate_adaptive_tolerance(magnitudes, significant_digits)
    has_nan = np.isnan(generated) | np.isnan(reference)
    not_close = (np.abs(generated - reference) > atol_array) & ~has_nan

    total_different = int(np.count_nonzero(not_close))
    if total_different == 0:
        print(" ✅ All elements match within tolerance")
        return
    print(f" ❌ {total_different} elements differ (out of {generated.size} total)")

    _print_sample_mismatches(generated.ravel(), reference.ravel(), not_close.ravel())
    if generated.ndim == 2:
        _print_column_differences(generated, reference, not_close)
    else:
        print(f" 📊 {generated.ndim}D array - no column-by-column analysis needed")
    _report_nan_summary(generated, reference)
# Branch names that every LLM-generated root_summary.txt must mention.
_DEFAULT_REQUIRED_BRANCHES = frozenset({
    'SumWeights', 'XSection', 'channelNumber', 'ditau_m', 'eventNumber',
    'jet_E', 'jet_MV2c10', 'jet_eta', 'jet_jvt', 'jet_n', 'jet_phi', 'jet_pt',
    'jet_pt_syst', 'jet_trueflav', 'jet_truthMatched', 'largeRjet_D2', 'largeRjet_E',
    'largeRjet_eta', 'largeRjet_m', 'largeRjet_n', 'largeRjet_phi', 'largeRjet_pt',
    'largeRjet_pt_syst', 'largeRjet_tau32', 'largeRjet_truthMatched', 'lep_E',
    'lep_charge', 'lep_eta', 'lep_etcone20', 'lep_isTightID', 'lep_n', 'lep_phi',
    'lep_pt', 'lep_pt_syst', 'lep_ptcone30', 'lep_trackd0pvunbiased',
    'lep_tracksigd0pvunbiased', 'lep_trigMatched', 'lep_truthMatched', 'lep_type',
    'lep_z0', 'mcWeight', 'met_et', 'met_et_syst', 'met_phi', 'photon_E',
    'photon_convType', 'photon_eta', 'photon_etcone20', 'photon_isTightID', 'photon_n',
    'photon_phi', 'photon_pt', 'photon_pt_syst', 'photon_ptcone30', 'photon_trigMatched',
    'photon_truthMatched', 'runNumber', 'scaleFactor_BTAG', 'scaleFactor_ELE',
    'scaleFactor_LepTRIGGER', 'scaleFactor_MUON', 'scaleFactor_PHOTON', 'scaleFactor_PILEUP',
    'scaleFactor_PhotonTRIGGER', 'scaleFactor_TAU', 'tau_BDTid', 'tau_E', 'tau_charge',
    'tau_eta', 'tau_isTightID', 'tau_n', 'tau_nTracks', 'tau_phi', 'tau_pt',
    'tau_pt_syst', 'tau_trigMatched', 'tau_truthMatched', 'trigE', 'trigM', 'trigP',
})


def validate_root_summary(llm_content, ref_content, required_branches=None):
    """
    Check that root_summary.txt content mentions every required branch name.

    Only content is checked: each required name must occur as a word token
    (see ``extract_branch_names``). The layout of the summary is not compared.

    Parameters
    ----------
    llm_content : str
        Text of the LLM-generated root_summary.txt.
    ref_content : str
        Unused; kept for interface compatibility (this script passes "").
    required_branches : iterable of str, optional
        Names that must be present. Defaults to ``_DEFAULT_REQUIRED_BRANCHES``.

    Returns
    -------
    bool
        True if all required names are present. False if any is missing, no
        tokens were found, or an exception occurred (the error is printed).
    """
    try:
        # Extract all word tokens from the LLM content
        llm_branches = set(extract_branch_names(llm_content))
        required = set(_DEFAULT_REQUIRED_BRANCHES if required_branches is None else required_branches)
        print(f" 📊 LLM output has {len(llm_branches)} unique words, Required: {len(required)} branches")
        # Debug: show which required branch names were found
        found_required_branches = required & llm_branches
        if found_required_branches:
            print(f" 🔍 Required branch names found in txt file: {', '.join(sorted(found_required_branches))}")
        if not llm_branches:
            print(" ❌ No branches found in LLM output")
            return False
        missing_branches = required - llm_branches
        if missing_branches:
            print(f" ❌ Missing {len(missing_branches)} required branches:")
            for branch in sorted(missing_branches):
                print(f" - {branch}")
            return False
        print(" ✅ All required branches present in LLM output")
        return True
    except Exception as e:
        # Boundary of this check: report and fail the step instead of aborting the run.
        print(f" ❌ Error parsing root_summary: {e}")
        return False
def extract_branch_names(content):
    """
    Return the set of word tokens found in root_summary.txt content.

    A token is a maximal run of regex word characters (letters, digits,
    underscore), so names such as ``jet_pt`` stay intact. Every other
    character, including spaces, commas, colons, dashes and dots, separates
    tokens; ``photon.E`` therefore yields ``photon`` and ``E``. Callers test
    branch names by set membership.

    Parameters
    ----------
    content : str
        Raw file text.

    Returns
    -------
    set of str
        Unique tokens in ``content``.
    """
    import re
    # ``\w+`` is greedy, so each match is already word-bounded; ``\b`` anchors are redundant.
    return set(re.findall(r'\w+', content))
# Prefix-based categories for LLM-format branch names, checked in this order.
_LLM_PREFIX_CATEGORIES = (
    ('photon_', 'photon'),
    ('jet_', 'jet'),
    ('met_', 'met'),
    ('lep_', 'lep'),
    ('tau_', 'tau'),
)
# Event-level branch names grouped under 'event'.
_LLM_EVENT_BRANCHES = frozenset({
    'runNumber', 'eventNumber', 'channelNumber', 'mcWeight',
    'trigE', 'trigM', 'trigP', 'ditau_m',
})
# Key order of the category dict built for LLM-format summaries.
_LLM_CATEGORY_ORDER = ('photon', 'jet', 'met', 'lep', 'tau', 'event', 'weights')


def _new_file_record(total_objects=0, trees=0):
    """Return an empty per-file summary record."""
    return {
        'total_objects': total_objects,
        'trees': trees,
        'entries': 0,
        'total_branches': 0,
        'branches': {},
    }


def _int_after_colon(line):
    """Return the integer after the first ':' in ``line``, or None if it does not parse."""
    try:
        return int(line.split(':')[1].strip())
    except (IndexError, ValueError):
        return None


def _llm_branch_category(branch):
    """Return the category name for an LLM-format branch, or None if uncategorized."""
    for prefix, category in _LLM_PREFIX_CATEGORIES:
        if branch.startswith(prefix):
            return category
    if branch in _LLM_EVENT_BRANCHES:
        return 'event'
    if branch in ('SumWeights', 'XSection') or branch.startswith(('scaleFactor_', 'largeRjet_')):
        return 'weights'
    return None


def _categorize_llm_branches(branch_names):
    """Group branch names by category, keeping only non-empty categories in canonical order."""
    grouped = {name: [] for name in _LLM_CATEGORY_ORDER}
    for branch in branch_names:
        category = _llm_branch_category(branch)
        if category is not None:
            grouped[category].append(branch)
    return {name: members for name, members in grouped.items() if members}


def _collect_llm_branch_names(lines, start):
    """Collect '- name' entries from ``start`` up to a blank or '=' line; return (names, stop_index)."""
    names = []
    j = start
    while j < len(lines):
        stripped = lines[j].strip()
        if not stripped or stripped.startswith('='):
            break
        # Lines are stripped, so match the unindented marker ('- name').
        # A 'Branches:' header simply matches nothing and is skipped.
        if stripped.startswith('- '):
            names.append(stripped[2:].strip())
        j += 1
    return names, j


def _parse_category_lines(lines, start):
    """Parse 'category: a, b' lines until a '=' rule or 'File N:' header; return (branches, stop_index)."""
    branches = {}
    j = start
    while j < len(lines):
        stripped = lines[j].strip()
        if stripped.startswith('=') or (stripped.startswith('File ') and ':' in stripped):
            break
        if ': ' in stripped:
            category, branch_list = stripped.split(': ', 1)
            branches[category.strip().lower()] = [b.strip() for b in branch_list.split(',')]
        j += 1
    return branches, j


def parse_root_summary(content):
    """
    Parse root_summary.txt content into a per-file dict.

    Two layouts are recognized:
    - Reference format: ``File N: <name>`` headers, followed by ``Total objects:``,
      ``Trees found:`` and ``Entries:`` lines, and a ``Common branches (<count>):``
      section of ``category: a, b, ...`` lines ending at a ``=`` rule.
    - LLM format: a ``Root file: <path>`` header (keyed by basename), and a
      ``TTree: mini`` line followed by ``- branch_name`` entries.

    Returns
    -------
    dict
        Maps filename to a record with keys 'total_objects', 'trees',
        'entries', 'total_branches' and 'branches' (category -> list of names).
    """
    files = {}
    current_file = None
    lines = content.split('\n')
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        if line.startswith('File ') and ':' in line:
            # Reference-format file header
            parts = line.split(': ')
            if len(parts) >= 2:
                current_file = parts[1].strip()
                files[current_file] = _new_file_record()
        elif line.startswith('Root file: '):
            # LLM-format header: key by basename; a single tree is assumed.
            parts = line.split(': ')
            if len(parts) >= 2:
                current_file = os.path.basename(parts[1].strip())
                files[current_file] = _new_file_record(total_objects=1, trees=1)
        elif current_file and current_file in files:
            record = files[current_file]
            if 'Total objects:' in line:
                value = _int_after_colon(line)
                if value is not None:
                    record['total_objects'] = value
            elif 'Trees found:' in line:
                value = _int_after_colon(line)
                if value is not None:
                    record['trees'] = value
            elif 'Entries:' in line:
                value = _int_after_colon(line)
                if value is not None:
                    record['entries'] = value
            elif 'Common branches (' in line and ')' in line:
                # The count and the categories are common to every file seen so far.
                try:
                    common_branch_count = int(line.split('(')[1].split(')')[0])
                except ValueError:
                    common_branch_count = None
                if common_branch_count is not None:
                    for file_record in files.values():
                        file_record['total_branches'] = common_branch_count
                branches, j = _parse_category_lines(lines, i + 1)
                for file_record in files.values():
                    # Give each file its own lists so later edits don't alias.
                    file_record['branches'] = {cat: list(names) for cat, names in branches.items()}
                i = j - 1  # Skip the lines already processed
            elif line == 'TTree: mini':
                branch_names, j = _collect_llm_branch_names(lines, i + 1)
                record['branches'] = _categorize_llm_branches(branch_names)
                record['total_branches'] = len(branch_names)
                i = j - 1  # Skip the lines already processed
        i += 1
    return files
# Reference ("solution") arrays for steps 3-5. Load only the references that
# the selected step(s) will compare against. Validating e.g. step 1 alone then
# does not fail when an unrelated reference file is absent, and it does not
# read large arrays it never uses. (Step 2 references are loaded separately.)
solution_arrays_dir = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays'
if not specific_step or specific_step == 3:
    signal_soln = np.load(os.path.join(solution_arrays_dir, 'signal.npy'))
    bkgd_soln = np.load(os.path.join(solution_arrays_dir, 'bkgd.npy'))
if not specific_step or specific_step == 4:
    signal_scores_soln = np.load(os.path.join(solution_arrays_dir, 'signal_scores.npy'))
    bkgd_scores_soln = np.load(os.path.join(solution_arrays_dir, 'bkgd_scores.npy'))
if not specific_step or specific_step == 5:
    boundaries_soln = np.load(os.path.join(solution_arrays_dir, 'boundaries.npy'))
    significances_soln = np.load(os.path.join(solution_arrays_dir, 'significances.npy'))
# Generated arrays for steps 2-5 are read from <out_dir>/arrays.
base_dir = os.path.join(out_dir, 'arrays')

# One "outputs missing" flag per step: 1 summarize_root, 2 create_numpy,
# 3 preprocess, 4 scores, 5 categorization. The existence checks set them to True.
(missing_file_1, missing_file_2, missing_file_3,
 missing_file_4, missing_file_5) = (False,) * 5
def _load_generated_array(filename, missing_message):
    """
    Load ``<base_dir>/<filename>`` if it exists.

    Returns the loaded array, or None after printing ``missing_message`` when
    the file is absent.
    """
    path = os.path.join(base_dir, filename)
    if os.path.exists(path):
        return np.load(path)
    print(missing_message)
    return None


# Step 1: summarize_root must have written logs/file_list.txt and logs/root_summary.txt.
# create_numpy_modified.txt comes from the insert_root_summary rule (no LLM), so it is not validated here.
if not specific_step or specific_step == 1:
    file_list_llm_path = os.path.join(out_dir, 'logs', 'file_list.txt')
    root_summary_llm_path = os.path.join(out_dir, 'logs', 'root_summary.txt')
    if not (os.path.exists(file_list_llm_path) and os.path.exists(root_summary_llm_path)):
        print("Step 1 (summarize_root) outputs missing")
        missing_file_1 = True

# Step 2: create_numpy must have written both data_A_raw.npy and signal_WH_raw.npy.
if not specific_step or specific_step == 2:
    data_A_raw_llm_path = os.path.join(base_dir, 'data_A_raw.npy')
    signal_WH_raw_llm_path = os.path.join(base_dir, 'signal_WH_raw.npy')
    if os.path.exists(data_A_raw_llm_path) and os.path.exists(signal_WH_raw_llm_path):
        data_raw_llm = np.load(data_A_raw_llm_path)
        signal_raw_llm = np.load(signal_WH_raw_llm_path)
        print("Found required files: data_A_raw.npy and signal_WH_raw.npy")
    else:
        print("Step 2 (create_numpy) outputs missing - data_A_raw.npy and/or signal_WH_raw.npy not found")
        missing_file_2 = True

# Step 3: preprocess outputs (signal.npy, bkgd.npy)
if not specific_step or specific_step == 3:
    signal_llm = _load_generated_array('signal.npy', "LLM generated signal sample does not exist (Step 3)")
    bkgd_llm = _load_generated_array('bkgd.npy', "LLM generated background sample does not exist (Step 3)")
    if signal_llm is None or bkgd_llm is None:
        missing_file_3 = True

# Step 4: scores outputs (signal_scores.npy, bkgd_scores.npy)
if not specific_step or specific_step == 4:
    signal_scores_llm = _load_generated_array('signal_scores.npy', "LLM generated signal scores do not exist (Step 4)")
    bkgd_scores_llm = _load_generated_array('bkgd_scores.npy', "LLM generated background scores do not exist (Step 4)")
    if signal_scores_llm is None or bkgd_scores_llm is None:
        missing_file_4 = True

# Step 5: categorization outputs (boundaries.npy, significances.npy)
if not specific_step or specific_step == 5:
    boundaries_llm = _load_generated_array('boundaries.npy', "LLM generated boundaries do not exist (Step 5)")
    significances_llm = _load_generated_array('significances.npy', "LLM generated significances do not exist (Step 5)")
    if boundaries_llm is None or significances_llm is None:
        missing_file_5 = True
# Step 2 outputs (data_A_raw.npy / signal_WH_raw.npy) are located and loaded
# by the per-step existence check earlier in this script, which honours
# --step. They are not re-read here; re-reading would repeat the I/O for every
# step and set missing_file_2 even when step 2 is not being validated.
# Reference arrays for Step 2 validation. The Step 2 check looks these names up
# in globals(), so each name is bound only when its file exists. The
# "selective" names (data_A_raw / signal_WH_raw) are preferred; the "standard"
# names (data_raw / signal_raw) are a fallback. The (possibly large) loads are
# skipped unless Step 2 will actually be validated.
selective_refs_loaded = False
standard_refs_loaded = False
solution_arrays_dir = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays'
data_A_raw_soln_path = os.path.join(solution_arrays_dir, 'data_A_raw.npy')
signal_WH_raw_soln_path = os.path.join(solution_arrays_dir, 'signal_WH_raw.npy')
signal_raw_soln_path = os.path.join(solution_arrays_dir, 'signal_raw.npy')
data_raw_soln_path = os.path.join(solution_arrays_dir, 'data_raw.npy')
if (not specific_step or specific_step == 2) and not missing_file_2:
    if os.path.exists(data_A_raw_soln_path):
        data_A_raw_soln = np.load(data_A_raw_soln_path)
        selective_refs_loaded = True
    if os.path.exists(signal_WH_raw_soln_path):
        signal_WH_raw_soln = np.load(signal_WH_raw_soln_path)
        selective_refs_loaded = True
    if os.path.exists(signal_raw_soln_path):
        signal_raw_soln = np.load(signal_raw_soln_path)
        standard_refs_loaded = True
    if os.path.exists(data_raw_soln_path):
        data_raw_soln = np.load(data_raw_soln_path)
        standard_refs_loaded = True
# Step 3-5 outputs (signal/bkgd, *_scores, boundaries/significances) are
# located and loaded by the per-step existence checks earlier in this script,
# which honour --step. They are not re-read here; re-reading would load every
# array twice in a full run and read unrelated arrays when --step is given.
"""
Plotting and derived checks removed per request: validation for steps 2–5 now does
direct array comparisons only (generated vs reference).
"""
step1_success = False
step2_success = False
step3_success = False
step4_success = False
step5_success = False
def _basenames_sorted(text):
    """Return the sorted file names (basename when a '/' is present) of the non-blank lines in ``text``."""
    names = []
    for raw in text.strip().split('\n'):
        entry = raw.strip()
        if entry:
            names.append(os.path.basename(entry) if '/' in entry else entry)
    return sorted(names)


# Step 1 validation (summarize_root outputs). The file list is compared by
# file name with the reference list when that exists. root_summary.txt is
# checked only for the required branch names.
if (not specific_step or specific_step == 1) and not missing_file_1:
    try:
        print("=== Step 1 Validation (summarize_root) ===")
        ref_file_list_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/file_list.txt'
        with open(file_list_llm_path, 'r') as handle:
            file_list_llm = handle.read()
        with open(root_summary_llm_path, 'r') as handle:
            root_summary_llm = handle.read()

        if not os.path.exists(ref_file_list_path):
            file_list_match = True  # nothing to compare against
        else:
            with open(ref_file_list_path, 'r') as handle:
                ref_file_list = handle.read()
            llm_filenames = _basenames_sorted(file_list_llm)
            ref_filenames = _basenames_sorted(ref_file_list)
            file_list_match = llm_filenames == ref_filenames
            if not file_list_match:
                print(f" 📊 LLM files: {len(llm_filenames)} | Reference files: {len(ref_filenames)}")
                if len(llm_filenames) != len(ref_filenames):
                    print(f" ❌ File count mismatch: {len(llm_filenames)} vs {len(ref_filenames)}")
                else:
                    # Report only the first position where the sorted lists disagree.
                    first_diff = next(
                        ((pos, a, b) for pos, (a, b) in enumerate(zip(llm_filenames, ref_filenames)) if a != b),
                        None,
                    )
                    if first_diff is not None:
                        i, llm_file, ref_file = first_diff
                        print(f" ❌ File {i+1} mismatch: '{llm_file}' vs '{ref_file}'")

        root_summary_match = validate_root_summary(root_summary_llm, "")
        step1_success = file_list_match and root_summary_match
    except Exception as e:
        print(f"Error in Step 1 validation: {e}")
        step1_success = False
# Step 2 validation (create_numpy outputs): direct comparison of the raw data
# and signal arrays with whichever reference arrays were loaded as globals.
if (not specific_step or specific_step == 2) and not missing_file_2:
    print("=== Step 2 Validation (create_numpy) ===")
    namespace = globals()
    # Prefer the selective reference names; fall back to the standard ones.
    data_ref = namespace['data_A_raw_soln'] if 'data_A_raw_soln' in namespace else namespace.get('data_raw_soln')
    signal_ref = namespace['signal_WH_raw_soln'] if 'signal_WH_raw_soln' in namespace else namespace.get('signal_raw_soln')

    if data_ref is None:
        print(" ❌ Missing data reference array (data_A_raw.npy or data_raw.npy)")
        ok_data = False
    else:
        ok_data = arrays_match(data_raw_llm, data_ref, "data_A_raw.npy (or data_raw.npy)")

    if signal_ref is None:
        print(" ❌ Missing signal reference array (signal_WH_raw.npy or signal_raw.npy)")
        ok_signal = False
    else:
        ok_signal = arrays_match(signal_raw_llm, signal_ref, "signal_WH_raw.npy (or signal_raw.npy)")

    step2_success = ok_data and ok_signal
    verdict = 'PASS' if step2_success else 'FAIL'
    print(f"Step 2 validation: {verdict}")
def _step_requested(step):
    """Return True when this run validates every step or exactly ``step``."""
    return not specific_step or specific_step == step


# Steps 3-5 compare each generated array directly with its reference. Both
# comparisons of a step always run, so both get reported.
if _step_requested(3) and not missing_file_3:
    print("=== Step 3 Validation (preprocess) ===")
    step3_checks = [
        arrays_match(signal_llm, signal_soln, "signal.npy"),
        arrays_match(bkgd_llm, bkgd_soln, "bkgd.npy"),
    ]
    step3_success = all(step3_checks)

if _step_requested(4) and not missing_file_4:
    print("=== Step 4 Validation (scores) ===")
    step4_checks = [
        arrays_match(signal_scores_llm, signal_scores_soln, "signal_scores.npy"),
        arrays_match(bkgd_scores_llm, bkgd_scores_soln, "bkgd_scores.npy"),
    ]
    step4_success = all(step4_checks)

if _step_requested(5) and not missing_file_5:
    print("=== Step 5 Validation (categorization) ===")
    step5_checks = [
        arrays_match(boundaries_llm, boundaries_soln, "boundaries.npy"),
        arrays_match(significances_llm, significances_soln, "significances.npy"),
    ]
    step5_success = all(step5_checks)
# Per-step verdicts as 0/1 flags in step order. They are printed at the end
# and not written to disk.
success_results = [int(flag) for flag in (step1_success, step2_success, step3_success, step4_success, step5_success)]

print("\n=== VALIDATION SUMMARY ===")
if not specific_step:
    print("All steps validated")
else:
    step_names = ["summarize_root", "create_numpy", "preprocess", "scores", "categorization"]
    step_name = step_names[specific_step - 1]
    print(f"Step: {specific_step} ({step_name})")
    # Human-readable description of the files each step is checked on.
    files_by_step = {
        1: [" • file_list.txt - List of processed ROOT files",
            " • root_summary.txt - Branch structure and file metadata"],
        2: [" • data_A_raw.npy - Raw data array (must have 46 columns)",
            " • signal_WH_raw.npy - Raw signal array (must have 46 columns)"],
        3: [" • signal.npy - Preprocessed signal events",
            " • bkgd.npy - Preprocessed background events"],
        4: [" • signal_scores.npy - Signal event classification scores",
            " • bkgd_scores.npy - Background event classification scores"],
        5: [" • boundaries.npy - Category boundary thresholds",
            " • significances.npy - Statistical significance values"],
    }
    described_files = files_by_step.get(specific_step)
    if described_files:
        print("Files validated:")
        for description in described_files:
            print(description)
def _status_label(passed, outputs_missing):
    """Map a step's (success, missing-outputs) pair to MISSING / PASS / FAIL."""
    if outputs_missing:
        return "MISSING"
    return "PASS" if passed else "FAIL"


# Report the status of the requested step, or of every step in a full run.
step_labels = ["summarize_root", "create_numpy", "preprocess", "scores", "categorization"]
step_outcomes = [
    (step1_success, missing_file_1),
    (step2_success, missing_file_2),
    (step3_success, missing_file_3),
    (step4_success, missing_file_4),
    (step5_success, missing_file_5),
]
if specific_step:
    step_name = step_labels[specific_step - 1]
    status = _status_label(*step_outcomes[specific_step - 1])
    print(f"\nStep {specific_step} ({step_name}): {status}")
    if status == "PASS":
        print("✅ Validation successful")
    elif status == "FAIL":
        print("❌ Validation failed")
    else:
        print("⚠️ Step outputs missing")
else:
    step_status = [_status_label(passed, missing) for passed, missing in step_outcomes]
    for number, (label, outcome) in enumerate(zip(step_labels, step_status), start=1):
        print(f"Step {number} ({label}): {outcome}")
# Overall tally: only steps whose outputs were present count as "validated".
missing_flags = [missing_file_1, missing_file_2, missing_file_3, missing_file_4, missing_file_5]
if not specific_step:
    validated_steps = len(missing_flags) - sum(missing_flags)
    passed_steps = sum(success_results)
    print(f"Overall success: {passed_steps}/{validated_steps} validated steps passed")
else:
    step_index = specific_step - 1
    validated_steps = 1
    passed_steps = 1 if (success_results[step_index] and not missing_flags[step_index]) else 0
    print(f"\nResult: {passed_steps}/{validated_steps} step passed")
print(f"Success array: {success_results}")
# Always exit with status 0 so that Run_SMK reports PASS/FAIL from the output
# above instead of treating the validator itself as "failed to run".
sys.exit(0)