llmll's picture
Upload folder using huggingface_hub
02c783d verified
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
import os, sys, json
import subprocess
from shutil import copyfile
from .base import BaseEvaluator
from ..helpers import get_temp_file, get_rocm_temp_file
from ..helpers.helper import run_shell, process_code, extract_errors, extract_json_from_stdout
from ..processors.llm import LLMOutputProcessor
from ..constants import REPO_ROOT, TMP_ROOT, TBG_DATA_ROOT, ROCm_DATA_ROOT, Names, NATIVE_PERF_GOLD_ROOT
# Absolute path of the directory that contains this module. It is the default
# value of the evaluators' MODULE_DIR argument, which is used to locate the
# correctness scripts (e.g. TB_correctness.py / ROCm_correctness.py).
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
class TestAllCloseEvaluatorTBG(BaseEvaluator):
    """Evaluator that runs generated code together with the ground-truth tests.

    The generated code is concatenated with the test section of the matching
    ground-truth file, executed once as a smoke test, and then benchmarked with
    the corresponding ``*_perf.py`` script found under ``NATIVE_PERF_GOLD_ROOT``.

    NOTE(review): ``self.tests_sep_line`` is read by several methods but is not
    assigned in this class; presumably ``BaseEvaluator`` provides it -- confirm.
    """

    def __init__(self, ground_truth_root: str = TBG_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """Store the ground-truth root and the directory of the helper scripts.

        Args:
            ground_truth_root: Directory that holds the reference files.
            MODULE_DIR: Directory that holds the correctness-check scripts.
        """
        super().__init__("CorrectnessWithAllCloseEvaluator")
        self.ground_truth_root = ground_truth_root
        self.MODULE_DIR = MODULE_DIR
        self.llm_output_processor = LLMOutputProcessor()

    def get_gen_fpath(self, root: str, fname: str) -> str:
        """Return a path for the generated file inside ``<root>/tmp/<GEN_FOLDER>``.

        The folder is created if it does not exist; the file name itself comes
        from ``get_temp_file``.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        gen_file = get_temp_file(prefix=f'{fname}{Names.GEN_SUFFIX}')
        return os.path.join(tmp_gen_folder, gen_file)

    def get_ground_truth_fpath(self, fname: str) -> str:
        """Return the path of the ground-truth file named ``fname``."""
        return os.path.join(self.ground_truth_root, fname)

    def get_tests_code(self, fname: str) -> list[str]:
        """Return the lines that follow the first test separator line in ``fname``.

        Raises:
            AssertionError: If the separator is missing or nothing follows it
                (this also covers an empty file).
        """
        with open(fname, 'r') as f:
            lines = f.readlines()
        # Index of the first line that is exactly the separator, if any.
        sep_idx = next(
            (idx for idx, line in enumerate(lines) if line.strip() == self.tests_sep_line),
            None,
        )
        test_code_lines = [] if sep_idx is None else lines[sep_idx + 1:]
        if not test_code_lines:
            # Raised explicitly (instead of `assert`) so the check survives `python -O`.
            raise AssertionError(
                f"No test code found in {fname} with test line seperator {self.tests_sep_line}"
            )
        return test_code_lines

    def format_gen_code(self, fname: str, code: str, tests: list[str]) -> str:
        """Append the separator and the test lines to ``code`` and write it to ``fname``.

        Args:
            fname: Destination file path.
            code: Generated source code.
            tests: Test source lines, with or without trailing newlines.

        Returns:
            The full text that was written.
        """
        # Lines coming from readlines() already end with '\n'; strip it before
        # joining so line breaks are not doubled (a doubled break corrupts
        # backslash continuations and multi-line strings in the test code).
        tests_text = '\n'.join(line.rstrip('\n') for line in tests)
        code = code + '\n\n' + self.tests_sep_line + '\n' + '\n' + tests_text
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _call_file(self, fpath: str, timeout: int = 2 * 60) -> tuple[bool, str, str]:
        """Run ``python3 <fpath>`` through ``run_shell`` and return its result triple."""
        cmd = [f'python3 {fpath}']
        call_status, stdout, stderr = run_shell(cmd, timeout=timeout)
        return call_status, stdout, stderr

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 2 * 60) -> tuple[bool, str, str]:
        """Run ``TB_correctness.py`` on the generated and reference files.

        The captured stdout and stderr are also saved next to the generated
        file as ``<gen_fpath>.stdout`` and ``<gen_fpath>.stderr``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/TB_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol}']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)
        with open(gen_fpath + ".stdout", 'w') as f:
            f.write(stdout)
        with open(gen_fpath + ".stderr", 'w') as f:
            f.write(stderr)
        return status, stdout, stderr

    def execute(self, code: str, log_root: str, exec_root: str, fname: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 2 * 60, verbose: bool = False, custom_tests_path=None) -> tuple[bool, bool, float, str, str]:
        """Run the generated code with the reference tests, then benchmark it.

        Args:
            code: Generated source code (passed through ``process_code``).
            log_root: Root under which the temporary generated file is written.
            exec_root: Folder that receives the runnable file and its perf script.
            fname: Name of the ground-truth file to evaluate against.
            atol, rtol, custom_tests_path: Accepted for interface compatibility;
                not used by this implementation.
            timeout: Timeout handed to every ``run_shell`` invocation.
            verbose: Print diagnostics when a step fails.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``. When an
            exception is caught, ``stdout`` is ``None`` and ``stderr`` carries
            the error text.
        """
        speedup = 0
        triton_file = self.get_ground_truth_fpath(fname)
        gen_file = self.get_gen_fpath(log_root, fname)
        test_code_lines_procs = self.get_tests_code(triton_file)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            # Smoke test: the generated file must at least run.
            call_status, stdout, stderr = self._call_file(gen_file, timeout=timeout)
            if not call_status:
                return call_status, False, speedup, stdout, stderr

            # Copy the generated code and its perf script into exec_root.
            exec_fpath = os.path.join(exec_root, fname)
            os.makedirs(exec_root, exist_ok=True)
            with open(exec_fpath, 'w') as f:
                f.write(code)

            perf_fname = f"{NATIVE_PERF_GOLD_ROOT}/{fname.replace('.py', '_perf.py')}"
            if not os.path.exists(perf_fname):
                # Handled by the generic handler below and reported via stderr.
                raise FileNotFoundError(
                    f"Performance file {perf_fname} does not exist. Please check the ground truth data."
                )
            exec_perf_fpath = os.path.join(exec_root, os.path.basename(perf_fname))
            copyfile(perf_fname, exec_perf_fpath)

            cmd_status, perf_stdout, perf_stderr = self._call_file(exec_perf_fpath, timeout=timeout)
            if not cmd_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            # The perf script reports a failed benchmark through this marker in stdout.
            exec_status = "Failed to run benchmark for input tensor." not in perf_stdout
            if not exec_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            perf_data = extract_json_from_stdout(perf_stdout)
            if not perf_data:
                if verbose:
                    print(f"Performance data extraction failed: {perf_stderr}")
                return call_status, exec_status, speedup, perf_stdout, perf_stderr

            # The speedup is read from the last extracted record.
            speedup = perf_data[-1].get("speedup", 0)
            return call_status, exec_status, speedup, perf_stdout, perf_stderr

        except subprocess.TimeoutExpired:
            # Must precede `except Exception`: TimeoutExpired is a subclass of
            # Exception, so the generic handler would otherwise shadow it.
            if verbose:
                print(f"File: {fname} timed out!")
            return False, False, 0, None, "Time out"
        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            return False, False, 0, None, str(e)
class TestAllCloseEvaluatorROCm(TestAllCloseEvaluatorTBG):
    """Evaluator for the ROCm data set.

    The generated code is combined with the reference tests and checked with
    ``ROCm_correctness.py``; the outcome is parsed from that script's stdout.
    ``get_ground_truth_fpath`` is inherited unchanged from the parent class.
    """

    def __init__(self, ground_truth_root: str = ROCm_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """Initialise via the parent, which stores both arguments on ``self``.

        Args:
            ground_truth_root: Directory that holds the reference files.
            MODULE_DIR: Directory that holds the correctness-check scripts.
        """
        # Pass the real arguments by keyword; the evaluator name is supplied to
        # the base evaluator by the parent constructor itself.
        super().__init__(ground_truth_root=ground_truth_root, MODULE_DIR=MODULE_DIR)

    def get_fpath(self, root: str, fname: str, suffix: str) -> str:
        """Return a temp-file path inside ``<root>/tmp/<GEN_FOLDER>``.

        The folder is created if needed; the file name comes from
        ``get_rocm_temp_file`` with the prefix ``fname + suffix``.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        file = get_rocm_temp_file(prefix=f'{fname}{suffix}')
        return os.path.join(tmp_gen_folder, file)

    def get_tests_code(self, fname: str) -> str:
        """Return the stripped text that follows the test separator in ``fname``.

        Raises:
            AssertionError: If the file does not contain the separator exactly once.
        """
        with open(fname, 'r') as f:
            content = f.read()
        snippets = content.split(self.tests_sep_line)
        if len(snippets) != 2:
            # Raised explicitly (instead of `assert`) so the check survives `python -O`.
            raise AssertionError(
                f"No test code found in {fname} with test line seperator {self.tests_sep_line}"
            )
        return snippets[1].strip()

    def format_gen_code(self, fname: str, code: str, tests: str) -> str:
        """Append a 146-character ``#`` rule and ``tests`` to ``code``; write to ``fname``.

        Returns:
            The full text that was written.
        """
        code = code + '\n\n' + '#' * 146 + '\n\n' + tests
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-2, rtol: float = 1e-2, timeout: int = 60 * 60) -> tuple[bool, str, str]:
        """Run ``ROCm_correctness.py`` on the generated and reference files.

        stdout is saved to ``<gen_fpath>.stdout``. stderr is extended with the
        errors extracted from the part of stdout before the pytest separator,
        then saved to ``<gen_fpath>.stderr``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/ROCm_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol} --global_timeout {timeout} --verbose']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)
        with open(gen_fpath + ".stdout", 'w') as f:
            f.write(stdout)
        stderr += extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
        with open(gen_fpath + ".stderr", 'w') as f:
            f.write(stderr)
        return status, stdout, stderr

    def execute(self, code: str, log_root: str, exec_root: str, fname: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 40 * 60, custom_tests_path=None, verbose: bool = False) -> tuple[bool, bool, float, str, str]:
        """Check the generated code against the reference with the ROCm script.

        Args:
            code: Generated source code (passed through ``process_code``).
            log_root: Root under which the temporary gen/ref files are written.
            exec_root: Folder that receives the code when it executes successfully.
            fname: Name of the ground-truth file to evaluate against.
            atol, rtol: Accepted for interface compatibility; see the note in the body.
            timeout: Timeout handed to the correctness script.
            custom_tests_path: Optional folder with replacement tests, consulted
                only when ``code`` contains ``@triton.autotune``.
            verbose: Print diagnostics.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``; the speedup
            slot is always ``-1`` here. When an exception is caught, ``stdout``
            is ``None`` and ``stderr`` carries the error text.
        """
        triton_file = self.get_ground_truth_fpath(fname)
        gen_file = self.get_fpath(log_root, fname, Names.GEN_SUFFIX)
        ref_file = self.get_fpath(log_root, fname, Names.REF_SUFFIX)
        copyfile(triton_file, ref_file)

        test_code_lines_procs = self.get_tests_code(triton_file)
        if custom_tests_path is not None and "@triton.autotune" in code:
            custom_tests_file = os.path.join(custom_tests_path, fname)
            if os.path.exists(custom_tests_file):
                test_code_lines_procs = self.get_tests_code(custom_tests_file)
            elif verbose:
                print(f"Custom tests file {custom_tests_file} does not exist. Skipping custom tests.", file=sys.stderr)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            call_status = False
            # NOTE(review): the tolerances are pinned here, so the `atol`/`rtol`
            # arguments have no effect -- confirm this is intentional.
            atol, rtol = 1e-2, 1e-2
            match_status, stdout, stderr = self._check_match(gen_file, ref_file, atol=atol, rtol=rtol, timeout=timeout)

            if not match_status:
                if verbose:
                    print(f"Error in generated code: {stderr}")
                stderr += "Error in generate triton-kernel code :\n " + extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
                return call_status, False, -1, stdout, stderr

            if verbose:
                print(f"Success in generated code: {stdout}")

            # The segment after the last pytest separator is split into four
            # fields: call status, exec status, generated stdout, generated stderr.
            # A different field count raises ValueError, handled below.
            call_status_str, exec_status_str, gen_stdout, gen_stderr = stdout.split(Names.PYTEST_SEPARATOR)[-1].split(Names.RET_SEPERATOR)
            gen_stderr += extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])

            true_str = str(True).lower()
            call_status = call_status_str.replace('\n', '').strip().lower() == true_str
            # exec_status can only be True when the call itself succeeded.
            exec_status = call_status and exec_status_str.strip().lower() == true_str

            if exec_status:
                # The generated code executed successfully; keep a copy in exec_root.
                exec_fpath = os.path.join(exec_root, fname)
                os.makedirs(exec_root, exist_ok=True)
                with open(exec_fpath, 'w') as f:
                    f.write(code)

            return call_status, exec_status, -1, gen_stdout, gen_stderr

        except subprocess.TimeoutExpired:
            # Must precede `except Exception`: TimeoutExpired is a subclass of
            # Exception, so the generic handler would otherwise shadow it.
            if verbose:
                print(f"File: {fname} timed out!")
            return False, False, -1, None, "Time out"
        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            # Same slot order as every other return: speedup, then stdout.
            return False, False, -1, None, str(e)
# Registry of evaluator classes (not instances), keyed by a short name.
get_evaluators = dict(
    tbg=TestAllCloseEvaluatorTBG,
    rocm=TestAllCloseEvaluatorROCm,
)