# llmll's picture
# Upload folder using huggingface_hub
# 02c783d verified
import json
import argparse
import os
import sys
from random import randint
from glob import glob
import importlib.util
import random
import numpy as np
import torch
from collections import namedtuple
# Print tensors in full (no truncation) so verbose dumps written below are complete.
torch.set_printoptions(profile="full")
def set_seed(seed: int = 42) -> None:
    """Seed every RNG in use and force deterministic PyTorch behavior.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and any CUDA
    devices), disables cuDNN autotuning, and pins ``PYTHONHASHSEED``.

    Args:
        seed: Seed value applied to all libraries. Defaults to 42.
    """
    # Seed the Python, NumPy and CPU-side PyTorch generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Seed GPU generators when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Fix string-hash randomization as well.
    os.environ['PYTHONHASHSEED'] = str(seed)
def import_variable_from_file(file_path, variable_name):
    """
    Dynamically imports a variable from a Python file.

    The global RNG state is re-seeded (via ``set_seed``) before executing the
    module so that modules producing random results are reproducible.

    Parameters:
    - file_path (str): The path to the Python file.
    - variable_name (str): The name of the variable to import.

    Returns:
    - The value of the specified variable, or None if not found.

    Raises:
    - FileNotFoundError: If ``file_path`` does not exist.
    - ImportError: If the module spec cannot be built, the module cannot be
      created, or executing the module raises any exception.
    """
    set_seed()
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No file found at {file_path}")
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    # Create a module spec
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Could not load specification for module {module_name}")
    # Create a new module based on the spec
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Could not create module {module_name} from spec")
    # Execute the module in its own namespace
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # Bug fix: chain the original exception (`from e`) so the real
        # traceback from the executed module is not lost.
        raise ImportError(f"Could not execute module {module_name} due to {e}") from e
    # Retrieve the variable from the module
    return getattr(module, variable_name, None)
def _compare(tri, pyt, fname, atol=1e-3, rtol=1e-3, verbose=False):
if type(pyt) == np.ndarray:
if np.allclose(tri, pyt, atol=atol, rtol=rtol):
if verbose:
print(f"PyTorch and Triton matched for file: {fname}")
return True
else:
if verbose:
diff = np.amax(np.abs(tri - pyt))
print(f"Test failed for file: {fname} with abs max diff: {diff}")
return False
elif type(pyt) == torch.Tensor:
if torch.allclose(tri, pyt, atol=atol, rtol=rtol):
if verbose:
print(f"PyTorch and Triton matched for file: {fname}")
return True
else:
if verbose:
diff = (tri - pyt).abs().max()
print(f"Test failed for file: {fname} with abs max diff: {diff}")
return False
elif type(pyt) == namedtuple:
if tri._fields != pyt._fields:
return False
for field in tri._fields:
if not torch.equal(getattr(tri, field), getattr(pyt, field)):
return False
return True
else:
return tri == pyt
return None
def compare(ref, gen, fname, atol=1e-3, rtol=1e-3, verbose=False):
    """Recursively compare a reference value against a generated value.

    Lists/tuples are compared element-wise, dicts by their values (after a
    key check), and leaves are delegated to ``_compare``. Namedtuples have
    ``type(...) != tuple`` and therefore fall through to ``_compare``,
    which handles them field-by-field.

    Args:
        ref: Reference value (possibly nested).
        gen: Generated value (possibly nested); its type drives dispatch.
        fname: File name used in diagnostic messages.
        atol: Absolute tolerance for numeric comparisons.
        rtol: Relative tolerance for numeric comparisons.
        verbose: When True, print match/mismatch diagnostics.

    Returns:
        bool: True when every nested pair matches.
    """
    ret_val = True
    if type(gen) in (list, tuple):
        # Robustness: zip() silently truncates, so differing lengths used
        # to compare only the common prefix.
        if len(ref) != len(gen):
            return False
        for tri, pyt in zip(ref, gen):
            # Bug fix: atol/rtol/verbose were previously dropped on
            # recursion, so nested values were compared with the defaults.
            ret_val &= compare(tri, pyt, fname, atol=atol, rtol=rtol, verbose=verbose)
    elif type(gen) == dict:
        # Robustness: values were compared positionally even when the key
        # sets differed; require matching keys first.
        if ref.keys() != gen.keys():
            return False
        return compare(list(ref.values()), list(gen.values()), fname, atol=atol, rtol=rtol, verbose=verbose)
    else:
        ret_val &= _compare(ref, gen, fname, atol=atol, rtol=rtol, verbose=verbose)
    return ret_val
def test_correctness(ref_file, gen_file, var_name, atol=1e-3, rtol=1e-3, verbose=False):
    """Load ``var_name`` from both files and compare the two values.

    Returns a 4-tuple ``(gen_call_acc, exec_acc, stdout, stderr)`` where
    ``gen_call_acc`` is True iff the generated file executed successfully,
    ``exec_acc`` is True iff the values matched (None if loading failed),
    ``stdout`` is always None, and ``stderr`` carries the exception or an
    explanatory message on failure (None on success).
    """
    fname = os.path.basename(gen_file)

    def _load(path, suffix):
        # Execute the file and fetch the target variable; in verbose mode
        # dump a readable snapshot next to gen_file.
        value = import_variable_from_file(path, var_name)
        if verbose:
            with open(gen_file + suffix, "w") as fh:
                fh.write(f"file: {fname}\n")
                json.dump(str(value), fh)
                fh.write("\n\n\n")
                fh.write("#" * 146)
        return value

    # Generated file first: a failure here means gen_call_acc stays False.
    try:
        gen_result_golden = _load(gen_file, ".out")
    except Exception as err:
        return False, None, None, err

    # Reference file: gen already loaded fine, so gen_call_acc is True.
    try:
        ref_result_golden = _load(ref_file, ".out_ref")
    except Exception as err:
        return True, None, None, err

    # Sanity checks before the value comparison, in priority order.
    if gen_result_golden is None:
        return True, False, None, "Generated output is None"
    if ref_result_golden is None:
        return True, False, None, "Reference output is None"
    if type(gen_result_golden) != type(ref_result_golden):
        return True, False, None, f"Reference and Generated output results should be of the same type but generated is: {type(gen_result_golden)}, and reference is: {type(ref_result_golden)}"

    exec_acc = compare(ref_result_golden, gen_result_golden, fname, atol=atol, rtol=rtol, verbose=verbose)
    stderr = None if exec_acc else f"Generated output does not match reference output for file: {fname}"
    return True, exec_acc, None, stderr
def parse_args():
    """Parse command-line options for the correctness checker.

    Returns:
        argparse.Namespace with gen_file, ref_file, var_name, atol,
        rtol and verbose attributes.
    """
    parser = argparse.ArgumentParser()
    # (flags, keyword arguments) for every supported option.
    option_table = (
        (("--gen_file", "-pf"), dict(type=str, required=True)),
        (("--ref_file", "-tf"), dict(type=str, required=True)),
        (("--var_name",), dict(type=str, default="result_gold")),
        (("--atol",), dict(type=float, default=1e-3)),
        (("--rtol",), dict(type=float, default=1e-1)),
        (("--verbose",), dict(action="store_true", help="Enable verbose output")),
    )
    for flags, kwargs in option_table:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
# Script entry point: compare the generated file's `var_name` against the
# reference file and print the result tuple (gen_call_acc, exec_acc, stdout,
# gen_stderr) joined by the "*#*#" delimiter for a calling harness to split.
if __name__ == "__main__":
    args = parse_args()
    gen_call_acc, exec_acc, stdout, gen_stderr = test_correctness(args.ref_file, args.gen_file, args.var_name, atol=args.atol, rtol=args.rtol, verbose=args.verbose)
    print(f"{gen_call_acc}*#*#{exec_acc}*#*#{stdout}*#*#{gen_stderr}")