LLM4HEP / util /compare_arrays.py
ho22joshua's picture
initial commit
cfcbbc8
#!/usr/bin/env python3
import os
import sys
import argparse
import numpy as np
def _abs_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the elementwise absolute difference ``|a - b|`` of two arrays.

    Both inputs are expected to be 1-D, of equal length and free of NaN
    (``compare_pair`` passes the mask-selected entries).  Plain
    ``np.abs(a - b)`` is wrong or fails for several dtypes, so they are
    handled explicitly here.
    """
    # NumPy refuses to subtract booleans, so treat them as small integers.
    if a.dtype.kind == "b":
        a = a.astype(np.int8)
    if b.dtype.kind == "b":
        b = b.astype(np.int8)
    if a.dtype.kind == "u" and b.dtype.kind == "u":
        # Unsigned subtraction wraps around (uint8: 1 - 3 == 254); ordering
        # the operands first keeps the result exact.
        return np.maximum(a, b) - np.minimum(a, b)
    # inf - inf evaluates to NaN (with a RuntimeWarning); silence the warning
    # and zero those entries below, since values that compare equal differ by 0.
    with np.errstate(invalid="ignore"):
        diffs = np.abs(a - b)
    diffs[a == b] = 0
    return diffs


def compare_pair(res_path: str, sol_path: str) -> str:
    """Compare two ``.npy`` arrays and return a multi-line text report.

    The report always starts with both shapes.  If the shapes differ, a
    ``DIFF: shape mismatch`` line follows and nothing else is computed.
    Otherwise the report states whether the NaN masks agree, whether the
    non-NaN entries are exactly equal, how many of them mismatch, and (when
    at least one such entry exists) the max/mean absolute difference plus
    ``np.allclose`` results at absolute tolerances 1e-10 and 1e-6.

    Note: the label "finite" in the report means "not NaN in either array";
    infinities are included in those entries.  The label text is kept as-is
    because ``main`` in this file matches on these exact strings.

    Args:
        res_path: Path to the result array, loaded with ``np.load``.
        sol_path: Path to the reference array, loaded with ``np.load``.

    Returns:
        The report lines joined with newlines.

    Raises:
        Whatever ``np.load`` or ``np.isnan`` raise for unreadable files or
        dtypes that have no NaN test (e.g. string arrays).
    """
    A = np.load(res_path)
    B = np.load(sol_path)
    lines = [f"shape: res {A.shape} vs sol {B.shape}"]
    if A.shape != B.shape:
        lines.append("DIFF: shape mismatch")
        return "\n".join(lines)

    nan_a = np.isnan(A)
    nan_b = np.isnan(B)
    nan_mask_equal = np.array_equal(nan_a, nan_b)
    # Entries that are non-NaN in BOTH arrays; only these are compared by value.
    finite_mask = ~(nan_a | nan_b)
    a_vals = A[finite_mask]
    b_vals = B[finite_mask]
    mismatch_count = int(np.count_nonzero(a_vals != b_vals))
    equal_exact = mismatch_count == 0

    lines.append(f"NaN mask equal: {nan_mask_equal}")
    lines.append(f"Exact equal (finite): {equal_exact}")
    lines.append(f"mismatches (finite): {mismatch_count} / {finite_mask.sum()}")
    if finite_mask.any():
        diffs = _abs_diff(a_vals, b_vals)
        lines.append(
            f"diff stats (finite): max={float(diffs.max()):.6g}, mean={diffs.mean():.6g}"
        )
        # rtol=0 makes these purely absolute-tolerance checks.
        lines.append(
            f"allclose atol=1e-10: {np.allclose(A, B, rtol=0.0, atol=1e-10, equal_nan=True)}"
        )
        lines.append(
            f"allclose atol=1e-6: {np.allclose(A, B, rtol=0.0, atol=1e-6, equal_nan=True)}"
        )
    else:
        lines.append("No finite entries to compare")
    return "\n".join(lines)
def main():
    """Compare every ``.npy`` file in ``<results_dir>/arrays`` to its reference.

    For each ``.npy`` file found in the results directory, the file with the
    same name in the solution directory is compared via ``compare_pair`` and
    the report is printed.  Result files with no matching reference are
    reported and skipped; they do not count as differences.

    Exit codes:
        0: no differences were detected among the compared files.
        1: at least one compared file differs (shape, NaN mask, or exact
           values) or raised an error during comparison.
        2: the results or solution arrays directory does not exist.
    """
    p = argparse.ArgumentParser(description="Compare arrays under results/arrays vs solution/arrays")
    p.add_argument("--results_dir", required=True, help="Path to results_XXX directory containing arrays/")
    p.add_argument("--solution_arrays", default="solution/arrays", help="Path to solution/arrays directory")
    args = p.parse_args()

    res_arrays = os.path.join(args.results_dir, "arrays")
    sol_arrays = args.solution_arrays
    if not os.path.isdir(res_arrays):
        print(f"ERROR: results arrays dir not found: {res_arrays}", file=sys.stderr)
        sys.exit(2)
    if not os.path.isdir(sol_arrays):
        print(f"ERROR: solution arrays dir not found: {sol_arrays}", file=sys.stderr)
        sys.exit(2)

    # Report lines that mark a pair as different.  Only exact equality and the
    # NaN mask decide the outcome; the allclose lines are informational.
    diff_markers = (
        "DIFF:",
        "Exact equal (finite): False",
        "NaN mask equal: False",
    )

    files = sorted(f for f in os.listdir(res_arrays) if f.endswith(".npy"))
    any_diff = False
    compared = 0
    skipped = 0
    print(f"Comparing {len(files)} files found in {res_arrays} against {sol_arrays}\n")
    for f in files:
        sol_path = os.path.join(sol_arrays, f)
        res_path = os.path.join(res_arrays, f)
        print(f"== {f} ==")
        if not os.path.exists(sol_path):
            print(f"- No reference: {sol_path}")
            print("")
            skipped += 1
            continue
        compared += 1
        try:
            report = compare_pair(res_path, sol_path)
        except Exception as e:
            # Top-level boundary: a file that cannot be compared is reported
            # and counted as a difference instead of aborting the whole run.
            print(f"ERROR comparing {f}: {e}")
            any_diff = True
            print("")
            continue
        print(report)
        print("")
        if any(marker in report for marker in diff_markers):
            any_diff = True

    # Make a vacuous success visible: skipped files were never checked.
    print(f"Compared {compared} file(s); skipped {skipped} with no reference")
    if any_diff:
        print("SUMMARY: Differences detected (see details above)")
        sys.exit(1)
    # The pass criterion is exact equality, not a tolerance check.
    print("SUMMARY: All compared arrays match exactly (see details above)")
    sys.exit(0)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()