File size: 3,138 Bytes
cfcbbc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
import os
import sys
import argparse
import numpy as np


def compare_pair(res_path: str, sol_path: str) -> str:
    """Compare a result array with its reference and return a text report.

    Both files are loaded with ``np.load``. The report always starts with a
    ``shape:`` line; if the shapes differ it ends with ``DIFF: shape mismatch``
    and no further checks run. Otherwise it states whether the NaN positions
    coincide, whether the entries that are NaN in neither array are exactly
    equal (and how many differ), and -- when at least one such entry exists --
    the max/mean absolute difference plus ``np.allclose`` verdicts at two
    absolute tolerances.

    The "(finite)" labels mean "NaN in neither array"; +/-inf entries are
    included and compared by value (equal infinities count as zero difference).

    Args:
        res_path: Path to the result ``.npy`` file.
        sol_path: Path to the reference ``.npy`` file.

    Returns:
        The report lines joined with newlines.

    Raises:
        Whatever ``np.load`` raises for unreadable files, and ``TypeError``
        from ``np.isnan`` for dtypes it does not support (e.g. strings).
    """
    A = np.load(res_path)
    B = np.load(sol_path)
    lines = [f"shape: res {A.shape} vs sol {B.shape}"]
    if A.shape != B.shape:
        lines.append("DIFF: shape mismatch")
        return "\n".join(lines)

    nan_a = np.isnan(A)
    nan_b = np.isnan(B)
    nan_mask_equal = np.array_equal(nan_a, nan_b)
    # Entries that are NaN in neither array (infinities are kept).
    finite_mask = ~nan_a & ~nan_b
    a = A[finite_mask]
    b = B[finite_mask]
    # Exactness is judged on the original dtypes so no precision is lost.
    same = a == b
    mismatch_count = int(np.count_nonzero(~same))
    equal_exact = mismatch_count == 0

    lines.append(f"NaN mask equal: {nan_mask_equal}")
    lines.append(f"Exact equal (finite): {equal_exact}")
    lines.append(f"mismatches (finite): {mismatch_count} / {a.size}")

    if a.size:
        # Promote to a floating (or complex) type before subtracting: raw
        # unsigned ints wrap around, signed ints can overflow, and bool arrays
        # reject the '-' operator altogether.
        work_dtype = np.result_type(a.dtype, b.dtype, np.float64)
        with np.errstate(invalid="ignore", over="ignore"):
            diffs = np.abs(a.astype(work_dtype) - b.astype(work_dtype))
            # inf - inf yields NaN; identical entries must read as zero.
            diffs[same] = 0
            max_diff = diffs.max()
            mean_diff = diffs.mean()
        lines.append(f"diff stats (finite): max={max_diff:.6g}, mean={mean_diff:.6g}")
        lines.append(f"allclose atol=1e-10: {np.allclose(A, B, rtol=0.0, atol=1e-10, equal_nan=True)}")
        lines.append(f"allclose atol=1e-6: {np.allclose(A, B, rtol=0.0, atol=1e-6, equal_nan=True)}")
    else:
        lines.append("No finite entries to compare")

    return "\n".join(lines)


def main():
    """Command-line entry point: compare result arrays with reference arrays.

    Every ``*.npy`` file in ``<results_dir>/arrays`` is compared with the
    file of the same name in ``--solution_arrays`` via ``compare_pair``, and
    each report is printed. Result files without a reference are reported
    and skipped. Reference files with no counterpart in the results are
    listed and counted as differences, so an empty or truncated results
    directory cannot pass.

    A report counts as a difference when it contains a shape mismatch, a
    NaN-mask mismatch, or any inexact non-NaN entry; the allclose lines are
    informational only. Exceptions raised while comparing are printed and
    also counted as differences.

    Exits with status 0 when no difference was found, 1 when at least one
    was, and 2 when either arrays directory does not exist.
    """
    p = argparse.ArgumentParser(description="Compare arrays under results/arrays vs solution/arrays")
    p.add_argument("--results_dir", required=True, help="Path to results_XXX directory containing arrays/")
    p.add_argument("--solution_arrays", default="solution/arrays", help="Path to solution/arrays directory")
    args = p.parse_args()

    res_arrays = os.path.join(args.results_dir, "arrays")
    sol_arrays = args.solution_arrays

    if not os.path.isdir(res_arrays):
        print(f"ERROR: results arrays dir not found: {res_arrays}", file=sys.stderr)
        sys.exit(2)
    if not os.path.isdir(sol_arrays):
        print(f"ERROR: solution arrays dir not found: {sol_arrays}", file=sys.stderr)
        sys.exit(2)

    # Substrings of a compare_pair report that mark a real difference.
    diff_markers = ("DIFF:", "Exact equal (finite): False", "NaN mask equal: False")

    files = sorted(f for f in os.listdir(res_arrays) if f.endswith('.npy'))
    any_diff = False
    print(f"Comparing {len(files)} files found in {res_arrays} against {sol_arrays}\n")
    for f in files:
        sol_path = os.path.join(sol_arrays, f)
        res_path = os.path.join(res_arrays, f)
        print(f"== {f} ==")
        if not os.path.exists(sol_path):
            print(f"- No reference: {sol_path}")
            print("")
            continue
        try:
            report = compare_pair(res_path, sol_path)
        except Exception as e:
            # Top-level boundary: report the failure, count it, keep going.
            print(f"ERROR comparing {f}: {e}")
            any_diff = True
            print("")
            continue
        print(report)
        print("")
        if any(marker in report for marker in diff_markers):
            any_diff = True

    # A reference array the run never produced cannot "match"; without this
    # check a missing or truncated set of outputs would report success.
    produced = set(files)
    missing = sorted(
        f for f in os.listdir(sol_arrays) if f.endswith('.npy') and f not in produced
    )
    if missing:
        any_diff = True
        print(f"== Missing from results: {len(missing)} reference file(s) ==")
        for f in missing:
            print(f"- No result: {os.path.join(res_arrays, f)}")
        print("")

    if any_diff:
        print("SUMMARY: Differences detected (see details above)")
        sys.exit(1)
    else:
        print("SUMMARY: All compared arrays match within tolerance (see details above)")
        sys.exit(0)


if __name__ == "__main__":
    # Run the command-line comparison only when executed as a script,
    # never as a side effect of importing this module.
    main()