|
|
|
|
|
import os, sys, json |
|
|
import subprocess |
|
|
from shutil import copyfile |
|
|
from .base import BaseEvaluator |
|
|
from ..helpers import get_temp_file, get_rocm_temp_file |
|
|
from ..helpers.helper import run_shell, process_code, extract_errors, extract_json_from_stdout |
|
|
from ..processors.llm import LLMOutputProcessor |
|
|
from ..constants import REPO_ROOT, TMP_ROOT, TBG_DATA_ROOT, ROCm_DATA_ROOT, Names, NATIVE_PERF_GOLD_ROOT |
|
|
# Absolute directory that contains this module. The evaluators use it as the
# default location of the correctness-checker scripts they launch.
_MODULE_DIR = os.path.split(os.path.abspath(__file__))[0]
|
|
|
|
|
class TestAllCloseEvaluatorTBG(BaseEvaluator):
    """Run generated kernel code together with the TBG ground-truth tests.

    The evaluator appends the test section of a ground-truth file to the
    generated code, runs the combined script, and, when that succeeds, runs
    the matching performance script to obtain a speedup figure.

    NOTE(review): ``self.tests_sep_line`` is read by several methods but is
    not assigned in this class; it is presumably provided by
    ``BaseEvaluator`` -- confirm.
    """

    def __init__(self, ground_truth_root: str = TBG_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """Store the data locations and create the LLM output processor.

        Args:
            ground_truth_root: Directory holding the ground-truth files.
            MODULE_DIR: Directory holding the correctness-checker scripts.
        """
        super().__init__("CorrectnessWithAllCloseEvaluator")

        self.ground_truth_root = ground_truth_root
        self.MODULE_DIR = MODULE_DIR

        self.llm_output_processor = LLMOutputProcessor()

    def get_gen_fpath(self, root: str, fname: str) -> str:
        """
        Returns the path to the generated file based on the given filename.

        The folder ``<root>/tmp/<Names.GEN_FOLDER>`` is created if needed; the
        file name itself comes from ``get_temp_file``.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        gen_file = get_temp_file(prefix=f'{fname}{Names.GEN_SUFFIX}')
        return os.path.join(tmp_gen_folder, gen_file)

    def get_ground_truth_fpath(self, fname: str) -> str:
        """
        Returns the path to the ground truth file based on the given filename.
        """
        return os.path.join(self.ground_truth_root, fname)

    def get_tests_code(self, fname: str) -> list[str]:
        """Return the lines that follow the test separator line in ``fname``.

        Args:
            fname: Path of the file to read.

        Returns:
            The lines (newline terminators preserved) after the first line
            whose stripped content equals ``self.tests_sep_line``.

        Raises:
            AssertionError: If the separator is missing, the file is empty,
                or nothing follows the separator.
        """
        with open(fname, 'r') as f:
            lines = f.readlines()

        sep_index = None
        for idx, line in enumerate(lines):
            if line.strip() == self.tests_sep_line:
                sep_index = idx
                break

        # A missing separator (including an empty file) yields no test code.
        test_code_lines = [] if sep_index is None else lines[sep_index + 1:]
        if not test_code_lines:
            # Raised explicitly so the check also holds when Python runs with -O.
            raise AssertionError(
                f"No test code found in {fname} with test line seperator {self.tests_sep_line}"
            )
        return test_code_lines

    def format_gen_code(self, fname: str, code: str, tests: list[str]) -> str:
        """Append the test lines to ``code``, write the result, and return it.

        Args:
            fname: Destination path; it is overwritten.
            code: Generated code placed before the separator line.
            tests: Test code lines, with or without trailing newlines.

        Returns:
            The full text written to ``fname``.
        """
        # Lines coming from ``readlines()`` already end in a newline. Joining
        # them with another newline would double every line break and split
        # backslash continuations, so each line is made to end in exactly one.
        tests_text = ''.join(
            line if line.endswith('\n') else line + '\n' for line in tests
        )
        code = code + '\n\n' + self.tests_sep_line + '\n' + '\n' + tests_text
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _call_file(self, fpath: str, timeout: int = 2*60) -> tuple[bool, str, str]:
        """Run ``fpath`` with ``python3`` through ``run_shell``.

        Returns:
            The ``(status, stdout, stderr)`` triple produced by ``run_shell``.
        """
        # NOTE(review): the path is interpolated unquoted into a single
        # command string; paths containing spaces may need quoting -- confirm
        # against ``run_shell``.
        cmd = [f'python3 {fpath}']
        call_status, stdout, stderr = run_shell(cmd, timeout=timeout)
        return call_status, stdout, stderr

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 2*60) -> tuple[bool, str, str]:
        """
        Executes the generated file and checks its correctness against the reference file.

        The checker's stdout and stderr are also written next to the generated
        file as ``<gen_fpath>.stdout`` and ``<gen_fpath>.stderr``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/TB_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol}']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)

        with open(gen_fpath+".stdout", 'w') as f:
            f.write(stdout)

        with open(gen_fpath+".stderr", 'w') as f:
            f.write(stderr)

        return status, stdout, stderr

    def execute(self, code: str, log_root: str, exec_root: str, fname: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 2*60, verbose: bool = False, custom_tests_path=None) -> tuple[bool, bool, float, str, str]:
        """Run the generated code with the ground-truth tests, then benchmark it.

        Args:
            code: Generated source code to evaluate.
            log_root: Root under which the temporary generated file is written.
            exec_root: Directory receiving the runnable file and the perf script.
            fname: Name of the ground-truth file (also used for the outputs).
            atol: Accepted for interface compatibility; not used here.
            rtol: Accepted for interface compatibility; not used here.
            timeout: Timeout passed to each script invocation.
            verbose: Print diagnostics on failure.
            custom_tests_path: Accepted for interface compatibility; not used here.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``. ``speedup``
            stays 0 unless performance data could be extracted. On an
            exception or timeout, ``stdout`` is ``None`` and ``stderr``
            carries the message.
        """
        speedup = 0

        triton_file = self.get_ground_truth_fpath(fname)
        gen_file = self.get_gen_fpath(log_root, fname)
        test_code_lines_procs = self.get_tests_code(triton_file)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            call_status, stdout, stderr = self._call_file(gen_file, timeout=timeout)
            if not call_status:
                return call_status, False, speedup, stdout, stderr

            # The script ran: keep a copy under exec_root for the perf script.
            exec_fpath = os.path.join(exec_root, fname)
            os.makedirs(exec_root, exist_ok=True)
            with open(exec_fpath, 'w') as f:
                f.write(code)

            perf_fname = f"{NATIVE_PERF_GOLD_ROOT}/{fname.replace('.py', '_perf.py')}"
            if not os.path.exists(perf_fname):
                # Handled by the generic handler below, like the assert it replaces.
                raise FileNotFoundError(
                    f"Performance file {perf_fname} does not exist. Please check the ground truth data."
                )

            exec_perf_fpath = os.path.join(exec_root, os.path.basename(perf_fname))
            copyfile(perf_fname, exec_perf_fpath)

            cmd_status, perf_stdout, perf_stderr = self._call_file(exec_perf_fpath, timeout=timeout)
            if not cmd_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            # The perf script reports benchmark failures through this stdout marker.
            exec_status = "Failed to run benchmark for input tensor." not in perf_stdout
            if not exec_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            perf_data = extract_json_from_stdout(perf_stdout)
            if not perf_data:
                if verbose:
                    print(f"Performance data extraction failed: {perf_stderr}")
                return call_status, exec_status, speedup, perf_stdout, perf_stderr

            # The speedup is read from the last extracted entry.
            speedup = perf_data[-1].get("speedup", 0)
            return call_status, exec_status, speedup, perf_stdout, perf_stderr

        # TimeoutExpired is a subclass of Exception, so it must be handled
        # first; otherwise the generic handler would always win.
        except subprocess.TimeoutExpired:
            if verbose:
                print(f"File: {fname} timed out!")
            return False, False, 0, None, "Time out"

        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            return False, False, 0, None, str(e)
|
|
|
|
|
|
|
|
class TestAllCloseEvaluatorROCm(TestAllCloseEvaluatorTBG):
    """Evaluate generated kernel code against the ROCm ground-truth files.

    Unlike the parent class, correctness is delegated to the
    ``ROCm_correctness.py`` script, whose stdout is parsed for the call and
    execution statuses. ``get_ground_truth_fpath`` is inherited unchanged.
    """

    def __init__(self, ground_truth_root: str = ROCm_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """Initialise the evaluator with the ROCm data locations.

        Args:
            ground_truth_root: Directory holding the ground-truth files.
            MODULE_DIR: Directory holding the correctness-checker scripts.
        """
        # Forward the real arguments: passing the evaluator name positionally
        # would land in the parent's ``ground_truth_root`` parameter.
        super().__init__(ground_truth_root=ground_truth_root, MODULE_DIR=MODULE_DIR)

    def get_fpath(self, root: str, fname: str, suffix: str) -> str:
        """
        Returns the path to a temporary file based on the given filename and suffix.

        The folder ``<root>/tmp/<Names.GEN_FOLDER>`` is created if needed; the
        file name itself comes from ``get_rocm_temp_file``.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        file = get_rocm_temp_file(prefix=f'{fname}{suffix}')
        return os.path.join(tmp_gen_folder, file)

    def get_tests_code(self, fname: str) -> str:
        """Return the text that follows the test separator in ``fname``.

        Args:
            fname: Path of the file to read.

        Returns:
            The stripped text after ``self.tests_sep_line``.

        Raises:
            AssertionError: If the separator does not occur exactly once.
        """
        with open(fname, 'r') as f:
            content = f.read()

        snippets = content.split(self.tests_sep_line)
        if len(snippets) != 2:
            # Raised explicitly so the check also holds when Python runs with -O.
            raise AssertionError(
                f"No test code found in {fname} with test line seperator {self.tests_sep_line}"
            )

        return snippets[1].strip()

    def format_gen_code(self, fname: str, code: str, tests: str) -> str:
        """Append the test text to ``code``, write the result, and return it.

        Args:
            fname: Destination path; it is overwritten.
            code: Generated code placed before the separator.
            tests: Test code as a single string.

        Returns:
            The full text written to ``fname``.
        """
        # NOTE(review): the separator is hard-coded here while get_tests_code
        # reads ``self.tests_sep_line``; presumably both are the same string
        # -- confirm.
        code = code + '\n\n' + '#'*146 + '\n\n' + tests
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-2, rtol: float = 1e-2, timeout: int = 60*60) -> tuple[bool, str, str]:
        """
        Executes the generated file and checks its correctness against the reference file.

        The checker's stdout is written to ``<gen_fpath>.stdout``. Errors
        extracted from the part of stdout before ``Names.PYTEST_SEPARATOR``
        are appended to stderr, which is written to ``<gen_fpath>.stderr``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/ROCm_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol} --global_timeout {timeout} --verbose']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)

        with open(gen_fpath+".stdout", 'w') as f:
            f.write(stdout)

        stderr += extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
        with open(gen_fpath+".stderr", 'w') as f:
            f.write(stderr)

        return status, stdout, stderr

    def execute(self, code: str, log_root: str, exec_root: str, fname: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 40*60, custom_tests_path=None, verbose: bool = False) -> tuple[bool, bool, float, str, str]:
        """Check the generated code against the ROCm reference file.

        Args:
            code: Generated source code to evaluate.
            log_root: Root under which the temporary files are written.
            exec_root: Directory receiving the code when it executes correctly.
            fname: Name of the ground-truth file (also used for the outputs).
            atol: Currently ignored; see the note in the body.
            rtol: Currently ignored; see the note in the body.
            timeout: Timeout passed to the correctness checker.
            custom_tests_path: Optional directory with replacement test files,
                used only when ``code`` contains ``@triton.autotune``.
            verbose: Print diagnostics.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``. ``speedup``
            is always -1 because this evaluator does not benchmark. On an
            exception or timeout, ``stdout`` is ``None`` and ``stderr``
            carries the message.
        """
        triton_file = self.get_ground_truth_fpath(fname)
        gen_file = self.get_fpath(log_root, fname, Names.GEN_SUFFIX)
        ref_file = self.get_fpath(log_root, fname, Names.REF_SUFFIX)

        copyfile(triton_file, ref_file)

        test_code_lines_procs = self.get_tests_code(triton_file)

        if custom_tests_path is not None and "@triton.autotune" in code:
            custom_tests_file = os.path.join(custom_tests_path, fname)
            if os.path.exists(custom_tests_file):
                test_code_lines_procs = self.get_tests_code(custom_tests_file)
            elif verbose:
                print(f"Custom tests file {custom_tests_file} does not exist. Skipping custom tests.", file=sys.stderr)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            # Reported as the call status when the checker itself fails.
            call_status = False

            # NOTE(review): the caller's atol/rtol are overridden with fixed
            # values here; kept as-is because changing it would alter
            # evaluation results -- confirm this is intentional.
            atol, rtol = 1e-2, 1e-2
            match_status, stdout, stderr = self._check_match(gen_file, ref_file, atol=atol, rtol=rtol, timeout=timeout)

            if not match_status:
                if verbose:
                    print(f"Error in generated code: {stderr}")
                stderr += "Error in generate triton-kernel code :\n " + extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
                return call_status, False, -1, stdout, stderr

            if verbose:
                print(f"Success in generated code: {stdout}")

            # The section after the last pytest separator carries four fields
            # joined by Names.RET_SEPERATOR.
            sections = stdout.split(Names.PYTEST_SEPARATOR)
            call_status_str, exec_status_str, gen_stdout, gen_stderr = sections[-1].split(Names.RET_SEPERATOR)
            gen_stderr += extract_errors(sections[0])

            call_status = call_status_str.replace('\n', '').strip().lower() == str(True).lower()
            # Execution can only count as successful when the call succeeded.
            exec_status = call_status and exec_status_str.strip().lower() == str(True).lower()

            if exec_status:
                exec_fpath = os.path.join(exec_root, fname)
                os.makedirs(exec_root, exist_ok=True)
                with open(exec_fpath, 'w') as f:
                    f.write(code)

            return call_status, exec_status, -1, gen_stdout, gen_stderr

        # TimeoutExpired is a subclass of Exception, so it must be handled
        # first; otherwise the generic handler would always win.
        except subprocess.TimeoutExpired:
            if verbose:
                print(f"File: {fname} timed out!")
            # Same field order as the success path: speedup, then stdout.
            return False, False, -1, None, "Time out"

        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            return False, False, -1, None, str(e)
|
|
# Registry mapping a dataset key to the evaluator class that handles it.
# Despite the name, this is a plain dict of classes, not a function.
get_evaluators = dict(
    tbg=TestAllCloseEvaluatorTBG,
    rocm=TestAllCloseEvaluatorROCm,
)
|
|
|