| | |
| | import os |
| | import sys |
| | import argparse |
| | import numpy as np |
| |
|
| |
|
def compare_pair(res_path: str, sol_path: str) -> str:
    """Compare a result ``.npy`` array with its reference and describe the outcome.

    Both files are loaded with ``np.load``. The report always starts with a
    shape line. If the shapes differ, a ``DIFF: shape mismatch`` line is
    appended and no value comparison is attempted. Otherwise the report states:

    * whether the NaN positions of the two arrays agree,
    * whether all entries that are non-NaN in *both* arrays are exactly equal,
    * how many of those entries differ,
    * max/mean absolute difference over those entries, and
    * ``np.allclose`` (``rtol=0``, ``equal_nan=True``) at ``atol=1e-10`` and ``1e-6``.

    "finite" in the report labels means "non-NaN in both arrays"; infinities
    are included in that set, and matching infinities count as equal.

    Args:
        res_path: Path to the result array file.
        sol_path: Path to the reference (solution) array file.

    Returns:
        A multi-line, human-readable report. The labels ``"DIFF:"``,
        ``"Exact equal (finite): ..."`` and ``"NaN mask equal: ..."`` are
        scanned for by the caller, so their wording must stay stable.

    Raises:
        Whatever ``np.load`` raises for unreadable files, and ``TypeError``
        from ``np.isnan`` for dtypes it does not support (e.g. strings).
    """
    A = np.load(res_path)
    B = np.load(sol_path)
    lines = [f"shape: res {A.shape} vs sol {B.shape}"]
    if A.shape != B.shape:
        lines.append("DIFF: shape mismatch")
        return "\n".join(lines)

    nan_a = np.isnan(A)
    nan_b = np.isnan(B)
    nan_mask_equal = np.array_equal(nan_a, nan_b)
    # Entries that are non-NaN in both arrays (infinities are kept).
    finite_mask = ~nan_a & ~nan_b
    a_vals = A[finite_mask]
    b_vals = B[finite_mask]
    differs = a_vals != b_vals
    mismatch_count = int(np.sum(differs))
    equal_exact = mismatch_count == 0

    lines.append(f"NaN mask equal: {nan_mask_equal}")
    lines.append(f"Exact equal (finite): {equal_exact}")
    lines.append(f"mismatches (finite): {mismatch_count} / {finite_mask.sum()}")

    if finite_mask.any():
        # Subtract in a float (or complex) dtype: unsigned ints would otherwise
        # wrap around (uint8 3 - 5 == 254) and numpy refuses bool - bool.
        work_dtype = np.result_type(a_vals.dtype, b_vals.dtype, np.float64)
        with np.errstate(invalid="ignore"):
            # Equal entries contribute exactly zero; this also keeps matching
            # infinities (where inf - inf would be NaN) from poisoning the stats.
            diffs = np.where(
                differs,
                np.abs(a_vals.astype(work_dtype) - b_vals.astype(work_dtype)),
                0.0,
            )
        lines.append(f"diff stats (finite): max={diffs.max():.6g}, mean={diffs.mean():.6g}")
        lines.append(f"allclose atol=1e-10: {np.allclose(A, B, rtol=0.0, atol=1e-10, equal_nan=True)}")
        lines.append(f"allclose atol=1e-6: {np.allclose(A, B, rtol=0.0, atol=1e-6, equal_nan=True)}")
    else:
        lines.append("No finite entries to compare")

    return "\n".join(lines)
| |
|
| |
|
def main():
    """Command-line entry point: compare every result array with its reference.

    Each ``*.npy`` file in ``<results_dir>/arrays`` is compared against the
    file of the same name in ``--solution_arrays`` via ``compare_pair``, and
    the report is printed. Result files without a reference are listed and
    skipped. Reference files with no counterpart among the results are listed
    and counted as differences, so missing outputs cannot pass silently.

    Exit status: 2 if either directory does not exist; 1 if any difference,
    comparison error, or missing result was found; 0 otherwise.
    """
    p = argparse.ArgumentParser(description="Compare arrays under results/arrays vs solution/arrays")
    p.add_argument("--results_dir", required=True, help="Path to results_XXX directory containing arrays/")
    p.add_argument("--solution_arrays", default="solution/arrays", help="Path to solution/arrays directory")
    args = p.parse_args()

    res_arrays = os.path.join(args.results_dir, "arrays")
    sol_arrays = args.solution_arrays

    if not os.path.isdir(res_arrays):
        print(f"ERROR: results arrays dir not found: {res_arrays}", file=sys.stderr)
        sys.exit(2)
    if not os.path.isdir(sol_arrays):
        print(f"ERROR: solution arrays dir not found: {sol_arrays}", file=sys.stderr)
        sys.exit(2)

    # Substrings of a compare_pair report that mean the pair is not identical.
    diff_markers = ("DIFF:", "Exact equal (finite): False", "NaN mask equal: False")

    files = sorted(f for f in os.listdir(res_arrays) if f.endswith(".npy"))
    # The comparison loop only walks result files, so reference arrays that
    # were never produced must be detected separately.
    missing_results = sorted(
        {f for f in os.listdir(sol_arrays) if f.endswith(".npy")} - set(files)
    )

    any_diff = False
    print(f"Comparing {len(files)} files found in {res_arrays} against {sol_arrays}\n")
    for f in files:
        sol_path = os.path.join(sol_arrays, f)
        res_path = os.path.join(res_arrays, f)
        print(f"== {f} ==")
        if not os.path.exists(sol_path):
            # Extra result arrays without a reference are reported, not failed.
            print(f"- No reference: {sol_path}")
            print("")
            continue
        try:
            report = compare_pair(res_path, sol_path)
        except Exception as e:  # per-file boundary: report the failure and keep going
            print(f"ERROR comparing {f}: {e}")
            any_diff = True
        else:
            print(report)
            if any(marker in report for marker in diff_markers):
                any_diff = True
        print("")

    for f in missing_results:
        print(f"== {f} ==")
        print(f"- Missing result: {os.path.join(res_arrays, f)}")
        print("")
        any_diff = True

    if any_diff:
        print("SUMMARY: Differences detected (see details above)")
        sys.exit(1)
    else:
        # NOTE: success actually requires identical NaN positions and exact
        # equality of all other entries; the allclose lines are informational.
        print("SUMMARY: All compared arrays match within tolerance (see details above)")
        sys.exit(0)
| |
|
| |
|
# Run the command-line comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|