File size: 37,220 Bytes
cfcbbc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# ATLAS style only needed for plotting; degrade gracefully when the optional
# atlas_mpl_style package is not installed.
try:
    import atlas_mpl_style as ampl
    ampl.use_atlas_style()
    plt.rcParams['font.family'] = 'DejaVu Sans'
except ImportError:
    print("Warning: ATLAS style not available, using default matplotlib style")
    plt.style.use('default')

# Plotting helpers are not used in array-only validation, keep import disabled to reduce deps
# from utils_plot import plot_myy_comparison, plot_scores_comparison

# Command-line interface: output directory plus an optional single step to run.
import argparse
parser = argparse.ArgumentParser()
add_arg = parser.add_argument
add_arg('--out_dir', help='output directory')
add_arg('--step', type=int, choices=[1, 2, 3, 4, 5], 
        help='Validate only specific step (1-5)')
args = parser.parse_args()
out_dir = args.out_dir  # base directory containing the generated outputs
specific_step = args.step  # None means validate every step


def arrays_match(generated, reference, name: str, atol: float = 1e-10) -> bool:
    """
    Element-wise comparison of two numpy arrays under a strict absolute tolerance.

    NaNs count as equal when they occur at identical positions, and rtol is
    fixed at 0.0 so only ``atol`` matters. A short status line is printed;
    on mismatch, brief diff statistics and a detailed analysis follow.
    Returns True when the arrays agree, False otherwise.
    """
    print(f"Validating {name}...")

    # Differing shapes can never match; report and bail out immediately.
    if generated.shape != reference.shape:
        print(f"  ❌ Shape mismatch: {generated.shape} vs {reference.shape}")
        return False

    if np.allclose(generated, reference, rtol=0.0, atol=atol, equal_nan=True):
        print(f"  ✅ {name} matches (atol={atol})")
        return True

    # Mismatch path: summarize how the arrays disagree to aid debugging.
    nan_mask_equal = np.array_equal(np.isnan(generated), np.isnan(reference))
    finite = ~(np.isnan(generated) | np.isnan(reference))
    mismatches = int(np.sum(generated[finite] != reference[finite]))
    print(f"  ❌ {name} differs: NaN mask equal={nan_mask_equal}, finite mismatches={mismatches}/{int(finite.sum())}")

    if finite.any():
        diffs = np.abs(generated[finite] - reference[finite])
        print(f"     diff stats: max={diffs.max():.6g}, mean={diffs.mean():.6g}")
        # Additional debug: show sample mismatches
        print("🔍 Running detailed mismatch analysis...")
        analyze_array_differences(generated, reference, name)
    return False

def calculate_adaptive_tolerance(values, significant_digits=4):
    """
    Calculate per-element absolute tolerances that preserve the requested
    number of significant digits of each value.

    For a non-zero value v the tolerance is |v| / 10**significant_digits;
    exact zeros get a small fixed default of 1e-10 (a magnitude-relative
    tolerance would collapse to 0 there and reject any difference).

    Examples (significant_digits=4):
    - Value 123000  -> tolerance = 12.3   (123000 / 1e4)
    - Value 0.00014 -> tolerance = 1.4e-8 (0.00014 / 1e4)
    - Value 0       -> tolerance = 1e-10  (small default)

    Parameters
    ----------
    values : np.ndarray
        Values to derive tolerances for.
    significant_digits : int
        Number of significant digits to preserve (default 4).

    Returns
    -------
    np.ndarray
        Float array, same shape as ``values``, one tolerance per element.
    """
    # Start with the small default everywhere; it stays in place for zeros.
    non_zero_mask = values != 0
    tolerances = np.full_like(values, 1e-10, dtype=float)

    if np.any(non_zero_mask):
        # |v| / 10**digits keeps roughly `significant_digits` leading digits
        # of agreement between two values compared with this tolerance.
        abs_values = np.abs(values[non_zero_mask])
        tolerances[non_zero_mask] = abs_values / (10 ** significant_digits)

    return tolerances

def analyze_array_differences(generated, reference, array_name, significant_digits=4):
    """
    Analyze differences between generated and reference numpy arrays.
    Uses adaptive tolerance based on significant digits rather than fixed tolerance.

    Prints a human-readable diagnostic report to stdout: sample mismatching
    elements, per-column mismatch locations for 2D arrays, and NaN-count
    consistency between the two arrays. Returns None. Elements where either
    array holds NaN are excluded from the mismatch counts.
    """
    print(f"\n🔍 Detailed analysis for {array_name} (using {significant_digits} significant digit tolerance):")
    print(f"  Generated shape: {generated.shape}, Reference shape: {reference.shape}")
    print(f"  Tolerance: Adaptive based on {significant_digits} significant digits per value")

    # Check for shape differences first
    if generated.shape != reference.shape:
        print(f"  ❌ Shape mismatch: {generated.shape} vs {reference.shape}")
        return

    # Calculate adaptive tolerances for each element.
    # NOTE(review): both arrays are concatenated here, but only the first
    # generated.size tolerances are kept below, so each element's tolerance
    # is effectively derived from `generated` alone — confirm this is intended.
    combined_values = np.abs(np.concatenate([generated.flatten(), reference.flatten()]))
    adaptive_tolerances = calculate_adaptive_tolerance(combined_values, significant_digits)
    
    # Reshape tolerances to match original arrays
    atol_array = adaptive_tolerances[:generated.size].reshape(generated.shape)
    
    # Use absolute tolerance only (relative tolerance not used)

    # Find differences and identify where tolerances are exceeded
    diff = generated - reference
    abs_diff = np.abs(diff)
    not_close = abs_diff > atol_array
    # Remove any comparisons involving NaNs (gen or ref)
    invalid = np.isnan(generated) | np.isnan(reference)
    not_close = not_close & ~invalid
    
    total_different = np.sum(not_close)

    if total_different == 0:
        print("  ✅ All elements match within tolerance")
        return

    print(f"  ❌ {total_different} elements differ (out of {generated.size} total)")
    
    # Show numeric mismatches only (exclude any NaN comparisons)
    flat_gen = generated.flatten()
    flat_ref = reference.flatten()
    flat_not_close = not_close.flatten()
    # Mask to include only finite mismatches
    numeric_mask = (~np.isnan(flat_gen)) & (~np.isnan(flat_ref))
    mismatch_mask = flat_not_close & numeric_mask
    if np.any(mismatch_mask):
        # Cap the sample at the first 10 flat indices to keep output readable.
        diff_indices = np.where(mismatch_mask)[0][:10]
        print("  📊 Sample numeric mismatches (first 10 indices):")
        for idx in diff_indices:
            gen_val = flat_gen[idx]
            ref_val = flat_ref[idx]
            diff_val = gen_val - ref_val
            print(f"    Index {idx}: gen={gen_val}, ref={ref_val}, diff={diff_val}")
    else:
        print("  ✅ No numeric mismatches (all differences involve NaNs)")
    
    # Skip overall statistics for now - they may not be meaningful for all data types

    # Analyze differences by column (if 2D array)
    if generated.ndim == 2:
        # Count mismatches per column to localize the problem.
        col_diffs = np.sum(not_close, axis=0)
        cols_with_diffs = np.where(col_diffs > 0)[0]

        if len(cols_with_diffs) > 0:
            print(f"  📊 Columns with differences: {cols_with_diffs[:10]} (showing first 10)")
            
            # Show side-by-side entries for first 10 differing columns
            num_cols_to_show = min(10, len(cols_with_diffs))
            num_rows_to_show = min(5, generated.shape[0])  # Show first 5 rows
            
            print(f"  📋 Sample entries (first {num_rows_to_show} rows, first {num_cols_to_show} differing columns):")
            print("     Row | Column | Generated Value | Reference Value | Difference")
            print("     ----|--------|----------------|-----------------|------------")
            
            for col_idx in cols_with_diffs[:num_cols_to_show]:
                for row_idx in range(num_rows_to_show):
                    gen_val = generated[row_idx, col_idx]
                    ref_val = reference[row_idx, col_idx]
                    diff = gen_val - ref_val
                    
                    # Format values nicely
                    gen_str = f"{gen_val:.6g}" if not np.isnan(gen_val) else "NaN"
                    ref_str = f"{ref_val:.6g}" if not np.isnan(ref_val) else "NaN"
                    diff_str = f"{diff:.6g}" if not np.isnan(diff) else "NaN"
                    
                    print(f"     {row_idx:3d} |   {col_idx:3d} | {gen_str:>14} | {ref_str:>15} | {diff_str:>10}")
        else:
            print("  ✅ All columns match within tolerance")
    else:
        print("  📊 1D array - no column-by-column analysis needed")

    # Check for special values - only warn if there's a significant difference
    nan_gen = np.sum(np.isnan(generated))
    nan_ref = np.sum(np.isnan(reference))

    # Only report NaN bookkeeping when counts are large enough to matter.
    if nan_gen > 1000 or nan_ref > 1000:  # Only show if significant number of NaNs
        # Check if NaN counts are very similar (within 1% difference)
        if nan_gen > 0 and nan_ref > 0:
            nan_ratio = min(nan_gen, nan_ref) / max(nan_gen, nan_ref)
            if nan_ratio > 0.99:  # NaN counts are essentially identical
                print("  ✅ Data structure consistency: Identical NaN patterns in generated and reference files")
                print(f"     - Both files have {nan_gen:,} NaN values (excellent consistency)")
            else:
                print("  ⚠️  Special values detected:")
                if nan_gen > 1000:
                    print(f"    - NaN in generated: {nan_gen:,}")
                if nan_ref > 1000:
                    print(f"    - NaN in reference: {nan_ref:,}")
        else:
            print("  ⚠️  Special values detected:")
            if nan_gen > 1000:
                print(f"    - NaN in generated: {nan_gen:,}")
            if nan_ref > 1000:
                print(f"    - NaN in reference: {nan_ref:,}")
def validate_root_summary(llm_content, ref_content):
    """
    Check that an LLM-produced root_summary.txt mentions every required branch.

    Tokenizes ``llm_content`` into words and verifies the full required-branch
    set is present, reporting progress to stdout. Content matters here, not
    exact formatting; ``ref_content`` is accepted but not consulted.
    Returns True when all required branches appear, False otherwise.
    """
    try:
        # Every word in the summary is a candidate branch name.
        llm_branches = set(extract_branch_names(llm_content))

        # The fixed set of branch names every valid summary must contain.
        required_branches = {
            'SumWeights', 'XSection', 'channelNumber', 'ditau_m', 'eventNumber', 
            'jet_E', 'jet_MV2c10', 'jet_eta', 'jet_jvt', 'jet_n', 'jet_phi', 'jet_pt', 
            'jet_pt_syst', 'jet_trueflav', 'jet_truthMatched', 'largeRjet_D2', 'largeRjet_E', 
            'largeRjet_eta', 'largeRjet_m', 'largeRjet_n', 'largeRjet_phi', 'largeRjet_pt', 
            'largeRjet_pt_syst', 'largeRjet_tau32', 'largeRjet_truthMatched', 'lep_E', 
            'lep_charge', 'lep_eta', 'lep_etcone20', 'lep_isTightID', 'lep_n', 'lep_phi', 
            'lep_pt', 'lep_pt_syst', 'lep_ptcone30', 'lep_trackd0pvunbiased', 
            'lep_tracksigd0pvunbiased', 'lep_trigMatched', 'lep_truthMatched', 'lep_type', 
            'lep_z0', 'mcWeight', 'met_et', 'met_et_syst', 'met_phi', 'photon_E', 
            'photon_convType', 'photon_eta', 'photon_etcone20', 'photon_isTightID', 'photon_n', 
            'photon_phi', 'photon_pt', 'photon_pt_syst', 'photon_ptcone30', 'photon_trigMatched', 
            'photon_truthMatched', 'runNumber', 'scaleFactor_BTAG', 'scaleFactor_ELE', 
            'scaleFactor_LepTRIGGER', 'scaleFactor_MUON', 'scaleFactor_PHOTON', 'scaleFactor_PILEUP', 
            'scaleFactor_PhotonTRIGGER', 'scaleFactor_TAU', 'tau_BDTid', 'tau_E', 'tau_charge', 
            'tau_eta', 'tau_isTightID', 'tau_n', 'tau_nTracks', 'tau_phi', 'tau_pt', 
            'tau_pt_syst', 'tau_trigMatched', 'tau_truthMatched', 'trigE', 'trigM', 'trigP'
        }

        print(f"  📊 LLM output has {len(llm_branches)} unique words, Required: {len(required_branches)} branches")

        # Debug aid: list which required names were actually located.
        found_required_branches = required_branches & llm_branches
        if found_required_branches:
            sorted_found = sorted(found_required_branches)
            print(f"  🔍 Required branch names found in txt file: {', '.join(sorted_found)}")

        # Guard: an empty token set means the summary is unusable.
        if not llm_branches:
            print("  ❌ No branches found in LLM output")
            return False

        missing_branches = required_branches - llm_branches
        if not missing_branches:
            print("  ✅ All required branches present in LLM output")
            return True

        print(f"  ❌ Missing {len(missing_branches)} required branches:")
        for branch in sorted(missing_branches):
            print(f"     - {branch}")
        return False

    except Exception as e:
        # Any parsing failure counts as a validation failure, not a crash.
        print(f"  ❌ Error parsing root_summary: {e}")
        return False

def extract_branch_names(content):
    """
    Tokenize root_summary.txt content into a set of unique words.

    Words (including underscored identifiers such as branch names) are pulled
    out with a regex, so branch names can be matched as whole tokens
    regardless of the surrounding punctuation or layout.
    """
    import re

    # \b\w+\b captures identifier-like runs; the set removes duplicates and
    # gives O(1) membership tests for the caller.
    return set(re.findall(r'\b\w+\b', content))

def parse_root_summary(content):
    """
    Parse root_summary.txt content into structured per-file data.

    Two layouts are supported:

    * reference format — one ``File N: <name>`` header per file, followed by
      ``Total objects:``, ``Trees found:``, ``Entries:`` lines and a
      ``Common branches (N):`` section listing ``category: b1, b2, ...``;
    * LLM format — a single ``Root file: <path>`` header with a
      ``TTree: mini`` section whose branches appear as ``- <name>`` items
      under a ``Branches:`` sub-header.

    Parameters
    ----------
    content : str
        Full text of a root_summary.txt file.

    Returns
    -------
    dict
        Maps filename -> {'total_objects', 'trees', 'entries',
        'total_branches', 'branches'}, where 'branches' maps a category
        name (photon/jet/met/lep/tau/event/weights, or a reference-format
        category) to a list of branch names.
    """
    files = {}
    current_file = None
    lines = content.split('\n')
    i = 0

    while i < len(lines):
        line = lines[i].strip()

        # Reference-format header: "File N: <filename>"
        if line.startswith('File ') and ':' in line:
            # Extract filename
            parts = line.split(': ')
            if len(parts) >= 2:
                filename = parts[1].strip()
                current_file = filename
                files[current_file] = {
                    'total_objects': 0,
                    'trees': 0,
                    'entries': 0,
                    'total_branches': 0,
                    'branches': {}
                }

        # LLM-format header: "Root file: <full path>"
        # (the old extra "':' in line" test was a tautology and was dropped)
        elif line.startswith('Root file: '):
            # Extract filename from path
            parts = line.split(': ')
            if len(parts) >= 2:
                full_path = parts[1].strip()
                filename = os.path.basename(full_path)
                current_file = filename
                files[current_file] = {
                    'total_objects': 1,  # Assume 1 tree
                    'trees': 1,
                    'entries': 0,  # Will be set if found
                    'total_branches': 0,
                    'branches': {}
                }

        # Parse per-file data once a header has been seen
        elif current_file and current_file in files:
            if 'Total objects:' in line:
                try:
                    files[current_file]['total_objects'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Trees found:' in line:
                try:
                    files[current_file]['trees'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Entries:' in line:
                try:
                    files[current_file]['entries'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Common branches (' in line and ')' in line:
                # Extract the total branch count from "Common branches (N):"
                try:
                    count_part = line.split('(')[1].split(')')[0]
                    # These branches are common to all files, so the count
                    # applies to every file seen so far.
                    common_branch_count = int(count_part)
                    for filename in files:
                        files[filename]['total_branches'] = common_branch_count
                except Exception:
                    pass

                # Parse "category: b1, b2, ..." lines until a '=' divider
                branches = {}
                j = i + 1
                while j < len(lines) and not lines[j].strip().startswith('='):
                    branch_line = lines[j].strip()
                    if ': ' in branch_line:
                        category, branch_list = branch_line.split(': ', 1)
                        category = category.strip().lower()
                        branch_names = [b.strip() for b in branch_list.split(',')]
                        branches[category] = branch_names
                    j += 1

                files[current_file]['branches'] = branches
                i = j - 1  # Skip the lines we already processed

            # Handle LLM-format branch parsing ("- branch_name" items)
            elif line == 'TTree: mini':
                branches = {}
                branch_lines = []
                j = i + 1
                while j < len(lines) and lines[j].strip() and not lines[j].strip().startswith('='):
                    branch_line = lines[j].strip()
                    # BUG FIX: branch_line is already stripped of leading
                    # whitespace, so the old checks for '  Branches:' and
                    # '    - ' prefixes could never match and every branch
                    # was silently dropped (total_branches stayed 0).
                    if branch_line.startswith('Branches:'):
                        # Skip the "Branches:" sub-header
                        j += 1
                        continue
                    elif branch_line.startswith('- '):
                        # Extract the branch name from a "- branch_name" item
                        branch_name = branch_line[2:].strip()
                        branch_lines.append(branch_name)
                    j += 1

                # Categorize branches for the LLM format
                photon_branches = []
                jet_branches = []
                met_branches = []
                lep_branches = []
                tau_branches = []
                event_branches = []
                weights_branches = []

                for branch in branch_lines:
                    if branch.startswith('photon_'):
                        photon_branches.append(branch)
                    elif branch.startswith('jet_'):
                        jet_branches.append(branch)
                    elif branch.startswith('met_'):
                        met_branches.append(branch)
                    elif branch.startswith('lep_'):
                        lep_branches.append(branch)
                    elif branch.startswith('tau_'):
                        tau_branches.append(branch)
                    elif branch in ['runNumber', 'eventNumber', 'channelNumber', 'mcWeight', 'trigE', 'trigM', 'trigP', 'ditau_m']:
                        event_branches.append(branch)
                    elif branch in ['SumWeights', 'XSection'] or branch.startswith('scaleFactor_') or branch.startswith('largeRjet_'):
                        weights_branches.append(branch)

                # Only record non-empty categories
                if photon_branches:
                    branches['photon'] = photon_branches
                if jet_branches:
                    branches['jet'] = jet_branches
                if met_branches:
                    branches['met'] = met_branches
                if lep_branches:
                    branches['lep'] = lep_branches
                if tau_branches:
                    branches['tau'] = tau_branches
                if event_branches:
                    branches['event'] = event_branches
                if weights_branches:
                    branches['weights'] = weights_branches

                files[current_file]['branches'] = branches
                files[current_file]['total_branches'] = len(branch_lines)
                i = j - 1  # Skip the lines we already processed

        i += 1

    return files

# Load reference solution files for steps 1 and 2 - only load what's needed
# This will be done after mode detection below

# Load existing reference files for steps 3, 4, 5
# NOTE(review): hard-coded NERSC CFS paths; these unconditional loads raise
# FileNotFoundError when the reference arrays are unavailable.
signal_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/signal.npy')
bkgd_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/bkgd.npy')
signal_scores_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/signal_scores.npy')
bkgd_scores_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/bkgd_scores.npy')
boundaries_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/boundaries.npy')
significances_soln = np.load('/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/significances.npy')

# Directory holding the LLM-generated .npy outputs that are validated below.
base_dir = os.path.join(out_dir, 'arrays')

# Per-step "outputs missing" flags; flipped to True by the existence checks
# below when a step's expected files are absent.
missing_file_1 = False  # Step 1: summarize_root files
missing_file_2 = False  # Step 2: create_numpy files  
missing_file_3 = False  # Step 3: preprocess files
missing_file_4 = False  # Step 4: scores files
missing_file_5 = False  # Step 5: categorization files

# Step 1: Check summarize_root outputs (file_list.txt, root_summary.txt)
if not specific_step or specific_step == 1:
    file_list_llm_path = os.path.join(out_dir, 'logs', 'file_list.txt')
    root_summary_llm_path = os.path.join(out_dir, 'logs', 'root_summary.txt')
    # Note: create_numpy_modified.txt comes from insert_root_summary rule (no LLM), so we don't validate it for step 1

    if not (os.path.exists(file_list_llm_path) and os.path.exists(root_summary_llm_path)):
        # The outer guard already restricts this to step-1 validation, so the
        # previously-nested repeat of the same condition was removed.
        print("Step 1 (summarize_root) outputs missing")
        missing_file_1 = True

# Step 2: Check create_numpy outputs (data_A_raw.npy and signal_WH_raw.npy)
if not specific_step or specific_step == 2:
    # Check for the specific files requested: data_A_raw.npy and signal_WH_raw.npy
    data_A_raw_llm_path = os.path.join(base_dir, 'data_A_raw.npy')
    signal_WH_raw_llm_path = os.path.join(base_dir, 'signal_WH_raw.npy')

    if os.path.exists(data_A_raw_llm_path) and os.path.exists(signal_WH_raw_llm_path):
        data_raw_llm = np.load(data_A_raw_llm_path)
        signal_raw_llm = np.load(signal_WH_raw_llm_path)
        # The outer guard already restricts this to step-2 validation, so the
        # previously-nested repeats of the same condition were removed.
        print("Found required files: data_A_raw.npy and signal_WH_raw.npy")
    else:
        print("Step 2 (create_numpy) outputs missing - data_A_raw.npy and/or signal_WH_raw.npy not found")
        missing_file_2 = True

# Step 3: Check preprocess outputs (signal.npy, bkgd.npy)
if not specific_step or specific_step == 3:
    signal_llm_path = os.path.join(base_dir, 'signal.npy')
    if os.path.exists(signal_llm_path):
        signal_llm = np.load(signal_llm_path)
    else:
        # The outer guard already restricts this to step-3 validation, so the
        # previously-nested repeat of the same condition was removed.
        print("LLM generated signal sample does not exist (Step 3)")
        missing_file_3 = True

    bkgd_llm_path = os.path.join(base_dir, 'bkgd.npy')
    if os.path.exists(bkgd_llm_path):
        bkgd_llm = np.load(bkgd_llm_path)
    else:
        print("LLM generated background sample does not exist (Step 3)")
        missing_file_3 = True

# Step 4: Check scores outputs (signal_scores.npy, bkgd_scores.npy)
if not specific_step or specific_step == 4:
    signal_scores_llm_path = os.path.join(base_dir, 'signal_scores.npy')
    if os.path.exists(signal_scores_llm_path):
        signal_scores_llm = np.load(signal_scores_llm_path)
    else:
        # The outer guard already restricts this to step-4 validation, so the
        # previously-nested repeat of the same condition was removed.
        print("LLM generated signal scores do not exist (Step 4)")
        missing_file_4 = True

    bkgd_scores_llm_path = os.path.join(base_dir, 'bkgd_scores.npy')
    if os.path.exists(bkgd_scores_llm_path):
        bkgd_scores_llm = np.load(bkgd_scores_llm_path)
    else:
        print("LLM generated background scores do not exist (Step 4)")
        missing_file_4 = True

# Step 5: Check categorization outputs (boundaries.npy, significances.npy)
if not specific_step or specific_step == 5:
    boundaries_llm_path = os.path.join(base_dir, 'boundaries.npy')
    if os.path.exists(boundaries_llm_path):
        boundaries_llm = np.load(boundaries_llm_path)
    else:
        # The outer guard already restricts this to step-5 validation, so the
        # previously-nested repeat of the same condition was removed.
        print("LLM generated boundaries do not exist (Step 5)")
        missing_file_5 = True

    significances_llm_path = os.path.join(base_dir, 'significances.npy')
    if os.path.exists(significances_llm_path):
        significances_llm = np.load(significances_llm_path)
    else:
        print("LLM generated significances do not exist (Step 5)")
        missing_file_5 = True

# Step 2: Check create_numpy outputs (data_A_raw.npy and signal_WH_raw.npy)
# NOTE(review): this repeats the specific_step-guarded Step 2 check above, but
# unconditionally — it re-loads the arrays and can set missing_file_2 even
# when a different step was requested. Confirm whether the repetition is
# intentional before removing it.
signal_raw_llm_path = os.path.join(base_dir, 'signal_raw.npy')  # NOTE(review): not referenced below
data_raw_llm_path = os.path.join(base_dir, 'data_raw.npy')  # NOTE(review): not referenced below

# Check for the specific files requested: data_A_raw.npy and signal_WH_raw.npy
data_A_raw_llm_path = os.path.join(base_dir, 'data_A_raw.npy')
signal_WH_raw_llm_path = os.path.join(base_dir, 'signal_WH_raw.npy')

if os.path.exists(data_A_raw_llm_path) and os.path.exists(signal_WH_raw_llm_path):
    data_raw_llm = np.load(data_A_raw_llm_path)
    signal_raw_llm = np.load(signal_WH_raw_llm_path)
else:
    missing_file_2 = True

# Load reference files for Step 2 validation.
# Two reference naming schemes exist: "selective" (data_A_raw / signal_WH_raw)
# and "standard" (data_raw / signal_raw). The Step 2 validation below prefers
# the selective arrays and falls back to the standard ones via globals() checks.
selective_refs_loaded = False  # NOTE(review): set but not read in this section
standard_refs_loaded = False  # NOTE(review): set but not read in this section

data_A_raw_soln_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/data_A_raw.npy'
signal_WH_raw_soln_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/signal_WH_raw.npy'
signal_raw_soln_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/signal_raw.npy'
data_raw_soln_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/data_raw.npy'

# Try to load selective reference files first
if os.path.exists(data_A_raw_soln_path):
    data_A_raw_soln = np.load(data_A_raw_soln_path)
    selective_refs_loaded = True
if os.path.exists(signal_WH_raw_soln_path):
    signal_WH_raw_soln = np.load(signal_WH_raw_soln_path)
    selective_refs_loaded = True
    
# Also try to load standard reference files
if os.path.exists(signal_raw_soln_path):
    signal_raw_soln = np.load(signal_raw_soln_path)
    standard_refs_loaded = True
if os.path.exists(data_raw_soln_path):
    data_raw_soln = np.load(data_raw_soln_path)
    standard_refs_loaded = True

# Step 3: Check preprocess outputs (signal.npy, bkgd.npy)
# NOTE(review): this and the Step 4/5 sections below repeat the
# specific_step-guarded checks above, but unconditionally (and without the
# diagnostic prints). Confirm the duplication is intentional before removing.
signal_llm_path = os.path.join(base_dir, 'signal.npy')
if os.path.exists(signal_llm_path):
    signal_llm = np.load(signal_llm_path)
else:
    missing_file_3 = True

bkgd_llm_path = os.path.join(base_dir, 'bkgd.npy')
if os.path.exists(bkgd_llm_path):
    bkgd_llm = np.load(bkgd_llm_path)
else:
    missing_file_3 = True

# Step 4: Check scores outputs (signal_scores.npy, bkgd_scores.npy)
signal_scores_llm_path = os.path.join(base_dir, 'signal_scores.npy')
if os.path.exists(signal_scores_llm_path):
    signal_scores_llm = np.load(signal_scores_llm_path)
else:
    missing_file_4 = True

bkgd_scores_llm_path = os.path.join(base_dir, 'bkgd_scores.npy')
if os.path.exists(bkgd_scores_llm_path):
    bkgd_scores_llm = np.load(bkgd_scores_llm_path)
else:
    missing_file_4 = True

# Step 5: Check categorization outputs (boundaries.npy, significances.npy)
boundaries_llm_path = os.path.join(base_dir, 'boundaries.npy')
if os.path.exists(boundaries_llm_path):
    boundaries_llm = np.load(boundaries_llm_path)
else:
    missing_file_5 = True

significances_llm_path = os.path.join(base_dir, 'significances.npy')
if os.path.exists(significances_llm_path):
    significances_llm = np.load(significances_llm_path)
else:
    missing_file_5 = True

"""
Plotting and derived checks removed per request: validation for steps 2–5 now does
direct array comparisons only (generated vs reference).
"""

step1_success = False
step2_success = False
step3_success = False
step4_success = False
step5_success = False

# Step 1 validation (summarize_root outputs): compare the LLM-generated file
# list against the reference (when available) and check that the generated
# root_summary.txt names every required branch.
if (not specific_step or specific_step == 1) and not missing_file_1:
    try:
        print("=== Step 1 Validation (summarize_root) ===")
        # Load reference files for comparison
        ref_file_list_path = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays/file_list.txt'
        # ref_root_summary_path no longer needed since we don't compare to reference
        
        # Load LLM-generated files
        with open(file_list_llm_path, 'r') as f:
            file_list_llm = f.read()
        with open(root_summary_llm_path, 'r') as f:
            root_summary_llm = f.read()
        
        # Standard mode: compare content with reference
        if os.path.exists(ref_file_list_path):
            with open(ref_file_list_path, 'r') as f:
                ref_file_list = f.read()
            
            # Extract filenames from both files for comparison
            # Handle both full paths and just filenames
            def extract_filenames(content):
                """Return the sorted basenames of the non-empty lines in *content*."""
                lines = [line.strip() for line in content.strip().split('\n') if line.strip()]
                filenames = []
                for line in lines:
                    # Extract filename from path or use as-is
                    filename = os.path.basename(line) if '/' in line else line
                    filenames.append(filename)
                return sorted(filenames)
            
            llm_filenames = extract_filenames(file_list_llm)
            ref_filenames = extract_filenames(ref_file_list)
            # Sorted-basename comparison: ordering and directory prefixes don't matter.
            file_list_match = llm_filenames == ref_filenames
            
            if not file_list_match:
                print(f"  📊 LLM files: {len(llm_filenames)} | Reference files: {len(ref_filenames)}")
                if len(llm_filenames) != len(ref_filenames):
                    print(f"  ❌ File count mismatch: {len(llm_filenames)} vs {len(ref_filenames)}")
                else:
                    # Show only the first differing entry, then stop
                    for i, (llm_file, ref_file) in enumerate(zip(llm_filenames, ref_filenames)):
                        if llm_file != ref_file:
                            print(f"  ❌ File {i+1} mismatch: '{llm_file}' vs '{ref_file}'")
                            break
        else:
            file_list_match = True  # No reference to compare
        
        # Use detailed root_summary validation
        # Only check that required branches are present (no reference comparison needed);
        # the second argument (reference content) is unused, hence the empty string.
        root_summary_match = validate_root_summary(root_summary_llm, "")
        
        step1_success = file_list_match and root_summary_match
        # Removed duplicate printing - summary will be shown in VALIDATION SUMMARY section
    except Exception as e:
        # Broad catch: any I/O or parsing problem marks step 1 as failed
        # instead of aborting the whole validation run.
        print(f"Error in Step 1 validation: {e}")
        step1_success = False

# Step 2 validation (create_numpy outputs) - direct array comparisons
if (not specific_step or specific_step == 2) and not missing_file_2:
    print("=== Step 2 Validation (create_numpy) ===")
    # Pick the reference arrays: the selective names (data_A_raw / signal_WH_raw)
    # take precedence; the standard names are the fallback. Either may be
    # absent from the module namespace, in which case the ref stays None.
    scope = globals()
    data_ref = scope['data_A_raw_soln'] if 'data_A_raw_soln' in scope else scope.get('data_raw_soln')
    signal_ref = scope['signal_WH_raw_soln'] if 'signal_WH_raw_soln' in scope else scope.get('signal_raw_soln')

    ok_data = False
    ok_signal = False
    if data_ref is None:
        print("  ❌ Missing data reference array (data_A_raw.npy or data_raw.npy)")
    else:
        ok_data = arrays_match(data_raw_llm, data_ref, "data_A_raw.npy (or data_raw.npy)")
    if signal_ref is None:
        print("  ❌ Missing signal reference array (signal_WH_raw.npy or signal_raw.npy)")
    else:
        ok_signal = arrays_match(signal_raw_llm, signal_ref, "signal_WH_raw.npy (or signal_raw.npy)")

    step2_success = ok_data and ok_signal
    print(f"Step 2 validation: {'PASS' if step2_success else 'FAIL'}")

# Step 3 validation (preprocess outputs) - direct array comparisons
if (not specific_step or specific_step == 3) and not missing_file_3:
    print("=== Step 3 Validation (preprocess) ===")
    # Evaluate both comparisons eagerly so every mismatch gets reported.
    sig_ok = arrays_match(signal_llm, signal_soln, "signal.npy")
    bkg_ok = arrays_match(bkgd_llm, bkgd_soln, "bkgd.npy")
    step3_success = sig_ok and bkg_ok

# Step 4 validation (scores) - direct array comparisons
if (not specific_step or specific_step == 4) and not missing_file_4:
    print("=== Step 4 Validation (scores) ===")
    sig_scores_ok = arrays_match(signal_scores_llm, signal_scores_soln, "signal_scores.npy")
    bkg_scores_ok = arrays_match(bkgd_scores_llm, bkgd_scores_soln, "bkgd_scores.npy")
    step4_success = sig_scores_ok and bkg_scores_ok

# Step 5 validation (categorization outputs) - direct array comparisons
if (not specific_step or specific_step == 5) and not missing_file_5:
    print("=== Step 5 Validation (categorization) ===")
    bounds_ok = arrays_match(boundaries_llm, boundaries_soln, "boundaries.npy")
    signif_ok = arrays_match(significances_llm, significances_soln, "significances.npy")
    step5_success = bounds_ok and signif_ok

# Aggregate per-step results (1 = PASS) in step order 1..5.  Results are
# printed to the console only; nothing is persisted to disk.
success_results = [int(step1_success), int(step2_success), int(step3_success), int(step4_success), int(step5_success)]

# Shared lookup tables for the summary sections below.  step_names was
# previously defined twice, and the missing-flag list literal was repeated
# inline in two places; both are now defined once.
step_names = ["summarize_root", "create_numpy", "preprocess", "scores", "categorization"]
missing_flags = [missing_file_1, missing_file_2, missing_file_3, missing_file_4, missing_file_5]
step_successes = [step1_success, step2_success, step3_success, step4_success, step5_success]

print("\n=== VALIDATION SUMMARY ===")
if specific_step:
    step_name = step_names[specific_step - 1]
    print(f"Step: {specific_step} ({step_name})")
    # Describe the files checked by the selected step.
    if specific_step == 1:
        print("Files validated:")
        print("  • file_list.txt - List of processed ROOT files")
        print("  • root_summary.txt - Branch structure and file metadata")
    elif specific_step == 2:
        print("Files validated:")
        print("  • data_A_raw.npy - Raw data array (must have 46 columns)")
        print("  • signal_WH_raw.npy - Raw signal array (must have 46 columns)")
    elif specific_step == 3:
        print("Files validated:")
        print("  • signal.npy - Preprocessed signal events")
        print("  • bkgd.npy - Preprocessed background events")
    elif specific_step == 4:
        print("Files validated:")
        print("  • signal_scores.npy - Signal event classification scores")
        print("  • bkgd_scores.npy - Background event classification scores")
    elif specific_step == 5:
        print("Files validated:")
        print("  • boundaries.npy - Category boundary thresholds")
        print("  • significances.npy - Statistical significance values")
else:
    print("All steps validated")

# Show only the relevant step status (or all five in full-validation mode)
if specific_step:
    step_name = step_names[specific_step - 1]
    if missing_flags[specific_step - 1]:
        status = "MISSING"
    else:
        status = "PASS" if step_successes[specific_step - 1] else "FAIL"

    print(f"\nStep {specific_step} ({step_name}): {status}")

    if status == "PASS":
        print("✅ Validation successful")
    elif status == "FAIL":
        print("❌ Validation failed")
    else:
        print("⚠️  Step outputs missing")
else:
    # Full validation: one status line per step
    step_status = []
    for success, missing in zip(step_successes, missing_flags):
        if missing:
            step_status.append("MISSING")
        elif success:
            step_status.append("PASS")
        else:
            step_status.append("FAIL")

    for idx, name in enumerate(step_names):
        print(f"Step {idx + 1} ({name}): {step_status[idx]}")

# Only count actually validated steps for overall success
if specific_step:
    validated_steps = 1
    passed_steps = 1 if success_results[specific_step - 1] and not missing_flags[specific_step - 1] else 0
    print(f"\nResult: {passed_steps}/{validated_steps} step passed")
else:
    validated_steps = sum(1 for missing in missing_flags if not missing)
    passed_steps = sum(success_results)
    print(f"Overall success: {passed_steps}/{validated_steps} validated steps passed")
    print(f"Success array: {success_results}")

# Ensure the validation script exits zero so Run_SMK prints PASS/FAIL instead
# of 'failed to run'.
sys.exit(0)