File size: 7,073 Bytes
02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import json
import argparse
import os
import sys
from random import randint
from glob import glob
import importlib.util
import random
import numpy as np
import torch
from collections import namedtuple

torch.set_printoptions(profile="full")

def set_seed(seed: int = 42) -> None:
    """
    Set the random seed for reproducibility across multiple libraries and configure PyTorch for deterministic behavior.

    Args:
        seed (int): The seed value to set. Default is 42.
    """
    # Set seed for Python's built-in random module
    random.seed(seed)
    # Set seed for NumPy
    np.random.seed(seed)
    # Set seed for PyTorch on CPU
    torch.manual_seed(seed)
    # Set seed for PyTorch on all GPUs (if available)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior in PyTorch
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set environment variable for hash-based operations
    os.environ['PYTHONHASHSEED'] = str(seed)


def import_variable_from_file(file_path, variable_name):
    """
    Dynamically imports a variable from a Python file.

    Parameters:
    - file_path (str): The path to the Python file.
    - variable_name (str): The name of the variable to import.

    Returns:
    - The value of the specified variable, or None if not found.
    """
    set_seed()

    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No file found at {file_path}")

    module_name = os.path.splitext(os.path.basename(file_path))[0]

    # Create a module spec
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Could not load specification for module {module_name}")

    # Create a new module based on the spec
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Could not create module {module_name} from spec")

    # Execute the module in its own namespace
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        raise ImportError(f"Could not execute module {module_name} due to {e}")

    # Retrieve the variable from the module
    return getattr(module, variable_name, None)

def _compare(tri, pyt, fname, atol=1e-3, rtol=1e-3, verbose=False):
    if type(pyt) == np.ndarray:
        if np.allclose(tri, pyt,  atol=atol, rtol=rtol):
            if verbose:
                print(f"PyTorch and Triton matched for file: {fname}")
            return True
        else:
            if verbose:
                diff = np.amax(np.abs(tri - pyt))
                print(f"Test failed for file: {fname} with abs max diff: {diff}")
            return False
    elif type(pyt) == torch.Tensor:
        if torch.allclose(tri, pyt,  atol=atol, rtol=rtol):
            if verbose:
                print(f"PyTorch and Triton matched for file: {fname}")
            return True
        else:
            if verbose:
                diff = (tri - pyt).abs().max()
                print(f"Test failed for file: {fname} with abs max diff: {diff}")
            return False
    elif type(pyt) == namedtuple:
        if tri._fields != pyt._fields:
            return False
        for field in tri._fields:
            if not torch.equal(getattr(tri, field), getattr(pyt, field)):
                return False
        return True
    else:
        return tri == pyt
    return None

def compare(ref, gen, fname, atol=1e-3, rtol=1e-3, verbose=False):
    ret_val = True
    # import pdb; pdb.set_trace()
    if (type(gen) == list) or (type(gen) == tuple):
        for tri, pyt in zip(ref, gen):
            ret_val &= compare(tri, pyt, fname)
    elif type(gen) == dict:
        return compare( list(ref.values()), list(gen.values()), fname, atol=atol, rtol=rtol, verbose=verbose)
    else:
        ret_val &= _compare(ref, gen, fname, atol=atol, rtol=rtol, verbose=verbose)
    return ret_val

def test_correctness(ref_file, gen_file, var_name, atol=1e-3, rtol=1e-3, verbose=False):
    fname = os.path.basename(gen_file)
    gen_call_acc, ref_call_acc = False, False
    gen_stderr, ref_stderr = None, None
    try:
        gen_result_golden = import_variable_from_file(gen_file, var_name)
        if verbose:
            with open(gen_file+".out", "w") as f:
                f.write(f"file: {fname}\n")
                json.dump(str(gen_result_golden), f)
                f.write("\n\n\n")
                f.write("#"*146)
        gen_call_acc = True
    except Exception as e:
        gen_stderr = e
        return gen_call_acc, None, None, gen_stderr
    try:
        ref_result_golden = import_variable_from_file(ref_file, var_name)
        if verbose:
            with open(gen_file+".out_ref", "w") as f:
                f.write(f"file: {fname}\n")
                json.dump(str(ref_result_golden), f)
                f.write("\n\n\n")
                f.write("#"*146)
        ref_call_acc = True
    except Exception as e:
        ref_stderr = e
        return gen_call_acc, None, None, ref_stderr
    # assert (ref_result_golden is not None), f'Reference output is None for file: {ref_file}'
    # assert (gen_result_golden is not None), f'Generated output is None for file: {gen_file}'
    # assert type(gen_result_golden) == type(ref_result_golden), f"Reference and Generated output results should be of the same type but generated is: {type(gen_result_golden)}, and reference is: {type(ref_result_golden)}"
    if gen_result_golden is None:
        return gen_call_acc, False, None, "Generated output is None"
    if ref_result_golden is None:
        return gen_call_acc, False, None, "Reference output is None"
    if type(gen_result_golden) != type(ref_result_golden):
        return gen_call_acc, False, None, f"Reference and Generated output results should be of the same type but generated is: {type(gen_result_golden)}, and reference is: {type(ref_result_golden)}"
    exec_acc = compare(ref_result_golden, gen_result_golden, fname, atol=atol, rtol=rtol, verbose=verbose)
    if not exec_acc:
        gen_stderr = f"Generated output does not match reference output for file: {fname}"
    return gen_call_acc, exec_acc, None, gen_stderr

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--gen_file", "-pf", type=str, required=True)
    parser.add_argument("--ref_file", "-tf", type=str, required=True)
    parser.add_argument("--var_name", type=str, default="result_gold")
    parser.add_argument("--atol", type=float, default=1e-3)
    parser.add_argument("--rtol", type=float, default=1e-1)
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    gen_call_acc, exec_acc, stdout, gen_stderr = test_correctness(args.ref_file, args.gen_file, args.var_name, atol=args.atol, rtol=args.rtol, verbose=args.verbose)
    print(f"{gen_call_acc}*#*#{exec_acc}*#*#{stdout}*#*#{gen_stderr}")