File size: 11,946 Bytes
02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
import os, sys, json
import subprocess
from shutil import copyfile
from .base import BaseEvaluator
from ..helpers import get_temp_file, get_rocm_temp_file
from ..helpers.helper import run_shell, process_code, extract_errors, extract_json_from_stdout
from ..processors.llm import LLMOutputProcessor
from ..constants import REPO_ROOT, TMP_ROOT, TBG_DATA_ROOT, ROCm_DATA_ROOT, Names, NATIVE_PERF_GOLD_ROOT
# Absolute path of the directory containing this module. It is the default
# ``MODULE_DIR`` of the evaluators defined in this file, which look for the
# correctness-checker scripts (TB_correctness.py / ROCm_correctness.py) there.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

class TestAllCloseEvaluatorTBG(BaseEvaluator):
    """Evaluator that runs generated code together with the TBG ground-truth tests.

    The generated code is concatenated with the test section of the matching
    ground-truth file (everything after ``self.tests_sep_line``), executed with
    ``python3``, and then benchmarked through the corresponding ``*_perf.py``
    script found under ``NATIVE_PERF_GOLD_ROOT``.

    NOTE(review): ``self.tests_sep_line`` is not assigned in this class; it is
    presumably provided by ``BaseEvaluator`` -- confirm.
    """

    def __init__(self, ground_truth_root: str = TBG_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """
        Args:
            ground_truth_root: Directory holding the ground-truth files.
            MODULE_DIR: Directory that contains ``TB_correctness.py``.
        """
        super().__init__("CorrectnessWithAllCloseEvaluator")

        self.ground_truth_root = ground_truth_root
        self.MODULE_DIR = MODULE_DIR

        self.llm_output_processor = LLMOutputProcessor()

    def get_gen_fpath(self, root: str, fname: str) -> str:
        """
        Returns the path to the generated file based on the given filename.

        The folder ``<root>/tmp/<Names.GEN_FOLDER>`` is created if missing.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        gen_file = get_temp_file(prefix=f'{fname}{Names.GEN_SUFFIX}')
        return os.path.join(tmp_gen_folder, gen_file)

    def get_ground_truth_fpath(self, fname: str) -> str:
        """
        Returns the path to the ground truth file based on the given filename.
        """
        return os.path.join(self.ground_truth_root, fname)

    def get_tests_code(self, fname: str) -> list[str]:
        """
        Returns the lines that follow the first test separator line in ``fname``.

        Lines are returned exactly as read (trailing newlines included).

        Raises:
            AssertionError: if the file has no separator line or nothing
                follows it (this also covers an empty file).
        """
        with open(fname, 'r') as f:
            lines = f.readlines()

        sep_index = next(
            (idx for idx, line in enumerate(lines) if line.strip() == self.tests_sep_line),
            None,
        )
        test_code_lines = [] if sep_index is None else lines[sep_index + 1:]

        # Raised explicitly (instead of ``assert``) so the check survives
        # ``python -O`` while callers still see the same exception type.
        if not test_code_lines:
            raise AssertionError(f"No test code found in {fname} with test line seperator {self.tests_sep_line}")
        return test_code_lines

    def format_gen_code(self, fname: str, code: str, tests: list[str]) -> str:
        """
        Appends the separator line and the test lines to ``code``, writes the
        result to ``fname`` and returns it.

        ``tests`` may contain newline-terminated lines (as produced by
        ``get_tests_code``); the trailing newline of each line is dropped
        before joining so that no blank line is inserted between test lines.
        """
        tests_block = '\n'.join(line.rstrip('\n') for line in tests) + '\n'
        code = code + '\n\n' + self.tests_sep_line + '\n' + '\n' + tests_block
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _call_file(self, fpath: str, timeout: int = 2*60) -> tuple[bool, str, str]:
        """
        Runs ``python3 <fpath>`` through ``run_shell`` and returns its
        ``(status, stdout, stderr)`` triple unchanged.
        """
        cmd = [f'python3 {fpath}']
        call_status, stdout, stderr = run_shell(cmd, timeout=timeout)
        return call_status, stdout, stderr

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-3, rtol: float = 1e-1, timeout: int = 2*60) -> tuple[bool, str, str]:
        """
        Executes the generated file and checks its correctness against the reference file.

        The checker's stdout and stderr are also written next to the generated
        file as ``<gen_fpath>.stdout`` and ``<gen_fpath>.stderr``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/TB_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol}']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)
        with open(gen_fpath+".stdout", 'w') as f:
            f.write(stdout)

        with open(gen_fpath+".stderr", 'w') as f:
            f.write(stderr)
        return status, stdout, stderr

    def execute(
        self,
        code: str,
        log_root: str,
        exec_root: str,
        fname: str,
        atol: float = 1e-3,
        rtol: float = 1e-1,
        timeout: int = 2*60,
        verbose: bool = False,
        custom_tests_path=None,
    ) -> tuple[bool, bool, float, str, str]:
        """
        Runs the generated ``code`` with the ground-truth tests, then its perf script.

        Args:
            code: Generated source; passed through ``process_code`` first.
            log_root: Root under which the temporary generated file is written.
            exec_root: Folder receiving the runnable file and a copy of the perf script.
            fname: Name of the ground-truth file.
            atol, rtol, custom_tests_path: Accepted but not used by this implementation.
            timeout: Timeout forwarded to each ``python3`` invocation.
            verbose: Print diagnostics on failure.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``. ``speedup``
            is the ``"speedup"`` entry of the last perf record, or 0 when
            unavailable. When an exception is caught, ``stdout`` is ``None``
            and ``stderr`` holds the exception text (``"Time out"`` for a
            ``subprocess.TimeoutExpired``).
        """
        speedup = 0

        triton_file = self.get_ground_truth_fpath(fname)
        gen_file = self.get_gen_fpath(log_root, fname)
        test_code_lines_procs = self.get_tests_code(triton_file)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            call_status, stdout, stderr = self._call_file(gen_file, timeout=timeout)

            if not call_status:
                return call_status, False, speedup, stdout, stderr

            ## copy the generated file and its perf file to exec_fpath
            exec_fpath = os.path.join(exec_root, fname)
            os.makedirs(exec_root, exist_ok=True)
            with open(exec_fpath, 'w') as f:
                f.write(code)

            perf_fname = f"{NATIVE_PERF_GOLD_ROOT}/{fname.replace('.py', '_perf.py')}"
            if not os.path.exists(perf_fname):
                # Handled by the generic handler below, which reports the message as stderr.
                raise FileNotFoundError(f"Performance file {perf_fname} does not exist. Please check the ground truth data.")

            exec_perf_fpath = os.path.join(exec_root, os.path.basename(perf_fname))
            copyfile(perf_fname, exec_perf_fpath)

            cmd_status, perf_stdout, perf_stderr = self._call_file(exec_perf_fpath, timeout=timeout)

            if not cmd_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            # The perf script signals a failed benchmark through this marker in its stdout.
            exec_status = "Failed to run benchmark for input tensor." not in perf_stdout

            if not exec_status:
                if verbose:
                    print(f"Performance file execution failed: {perf_stderr}")
                return call_status, False, speedup, perf_stdout, perf_stderr

            perf_data = extract_json_from_stdout(perf_stdout)

            if not perf_data:
                if verbose:
                    print(f"Performance data extraction failed: {perf_stderr}")
                return call_status, exec_status, speedup, perf_stdout, perf_stderr

            speedup = perf_data[-1].get("speedup", 0)
            return call_status, exec_status, speedup, perf_stdout, perf_stderr

        # TimeoutExpired derives from Exception, so it has to be handled first
        # or the generic handler below would swallow it.
        except subprocess.TimeoutExpired:
            if verbose:
                print(f"File: {fname} timed out!")
            return False, False, 0, None, "Time out"

        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            return False, False, 0, None, str(e)


class TestAllCloseEvaluatorROCm(TestAllCloseEvaluatorTBG):
    """Evaluator that checks generated code against the ROCm ground-truth files.

    The generated code is concatenated with the test section of the
    ground-truth file and handed, together with a copy of the ground-truth
    file, to ``ROCm_correctness.py``; the checker's stdout is then parsed for
    the call / execution status.

    ``get_ground_truth_fpath`` is inherited unchanged from the parent class.
    """

    def __init__(self, ground_truth_root: str = ROCm_DATA_ROOT, MODULE_DIR: str = _MODULE_DIR) -> None:
        """
        Args:
            ground_truth_root: Directory holding the ground-truth files.
            MODULE_DIR: Directory that contains ``ROCm_correctness.py``.
        """
        # The parent constructor takes (ground_truth_root, MODULE_DIR) and
        # itself registers the evaluator name with BaseEvaluator, so the
        # arguments are forwarded by keyword rather than passing the name.
        super().__init__(ground_truth_root=ground_truth_root, MODULE_DIR=MODULE_DIR)

    def get_fpath(self, root: str, fname: str, suffix: str) -> str:
        """
        Returns a path for a temporary file named after ``fname`` and ``suffix``.

        The folder ``<root>/tmp/<Names.GEN_FOLDER>`` is created if missing.
        """
        tmp_gen_folder = os.path.join(root, "tmp", Names.GEN_FOLDER)
        os.makedirs(tmp_gen_folder, exist_ok=True)
        file = get_rocm_temp_file(prefix=f'{fname}{suffix}')
        return os.path.join(tmp_gen_folder, file)

    def get_tests_code(self, fname: str) -> str:
        """
        Returns the stripped text that follows the test separator in ``fname``.

        Unlike the parent implementation this returns one string, not a list
        of lines.

        Raises:
            AssertionError: if splitting the file on the separator does not
                yield exactly two parts.
        """
        with open(fname, 'r') as f:
            content = f.read()

        snippets = content.split(self.tests_sep_line)
        # Raised explicitly (instead of ``assert``) so the check survives
        # ``python -O`` while callers still see the same exception type.
        if len(snippets) != 2:
            raise AssertionError(f"No test code found in {fname} with test line seperator {self.tests_sep_line}")

        return snippets[1].strip()

    def format_gen_code(self, fname: str, code: str, tests: str) -> str:
        """
        Appends a line of 146 ``#`` characters and ``tests`` to ``code``,
        writes the result to ``fname`` and returns it.
        """
        code = code + '\n\n' + '#'*146 + '\n\n' + tests
        with open(fname, 'w') as f:
            f.write(code)
        return code

    def _check_match(self, gen_fpath: str, ref_fpath: str, atol: float = 1e-2, rtol: float = 1e-2, timeout: int = 60*60) -> tuple[bool, str, str]:
        """
        Executes the generated file and checks its correctness against the reference file.

        ``timeout`` is passed both to the checker (``--global_timeout``) and to
        ``run_shell``. The checker's stdout is written to ``<gen_fpath>.stdout``.
        The returned stderr, also written to ``<gen_fpath>.stderr``, is extended
        with the errors extracted from the part of stdout that precedes
        ``Names.PYTEST_SEPARATOR``.
        """
        cmd = [f'python3 {self.MODULE_DIR}/ROCm_correctness.py --gen_file {gen_fpath} --ref_file {ref_fpath} --atol {atol} --rtol {rtol} --global_timeout {timeout} --verbose']
        status, stdout, stderr = run_shell(cmd, timeout=timeout)
        with open(gen_fpath+".stdout", 'w') as f:
            f.write(stdout)

        stderr += extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
        with open(gen_fpath+".stderr", 'w') as f:
            f.write(stderr)
        return status, stdout, stderr

    def execute(
        self,
        code: str,
        log_root: str,
        exec_root: str,
        fname: str,
        atol: float = 1e-3,
        rtol: float = 1e-1,
        timeout: int = 40*60,
        custom_tests_path=None,
        verbose: bool = False,
    ) -> tuple[bool, bool, float, str, str]:
        """
        Checks the generated ``code`` against the ground-truth file ``fname``.

        Args:
            code: Generated source; passed through ``process_code`` first.
            log_root: Root under which the temporary gen/ref files are written.
            exec_root: Folder receiving the generated file when it executes successfully.
            fname: Name of the ground-truth file.
            atol, rtol: Accepted, but see the NOTE in the body -- currently overridden.
            timeout: Timeout forwarded to the correctness checker.
            custom_tests_path: Optional folder with alternative test files; only
                consulted when ``code`` contains ``@triton.autotune``.
            verbose: Print diagnostics.

        Returns:
            ``(call_status, exec_status, speedup, stdout, stderr)``. ``speedup``
            is always -1 here. When an exception is caught, ``stdout`` is
            ``None`` and ``stderr`` holds the exception text (``"Time out"``
            for a ``subprocess.TimeoutExpired``).
        """
        triton_file = self.get_ground_truth_fpath(fname)

        gen_file = self.get_fpath(log_root, fname, Names.GEN_SUFFIX)
        ref_file = self.get_fpath(log_root, fname, Names.REF_SUFFIX)

        copyfile(triton_file, ref_file)

        test_code_lines_procs = self.get_tests_code(triton_file)

        if custom_tests_path is not None and "@triton.autotune" in code:
            custom_tests_file = os.path.join(custom_tests_path, fname)
            if os.path.exists(custom_tests_file):
                test_code_lines_procs = self.get_tests_code(custom_tests_file)
            elif verbose:
                print(f"Custom tests file {custom_tests_file} does not exist. Skipping custom tests.", file=sys.stderr)

        code = process_code(code)
        code = self.format_gen_code(gen_file, code, test_code_lines_procs)

        try:
            call_status = False

            # Check for correctness
            # NOTE(review): the atol/rtol arguments are deliberately left
            # overridden here to keep existing results unchanged -- confirm
            # whether callers are expected to control the tolerances.
            atol, rtol = 1e-2, 1e-2
            match_status, stdout, stderr = self._check_match(gen_file, ref_file, atol=atol, rtol=rtol, timeout=timeout)

            # Check if the generated code executed successfully
            if not match_status:
                if verbose:
                    print(f"Error in generated code: {stderr}")
                stderr += "Error in generate triton-kernel code :\n " + extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
                return call_status, False, -1, stdout, stderr

            if verbose:
                print(f"Success in generated code: {stdout}")

            # The section after the last PYTEST_SEPARATOR carries four fields
            # joined by RET_SEPERATOR: call status, exec status, stdout, stderr.
            call_status_str, exec_status_str, gen_stdout, gen_stderr = stdout.split(Names.PYTEST_SEPARATOR)[-1].split(Names.RET_SEPERATOR)
            gen_stderr += extract_errors(stdout.split(Names.PYTEST_SEPARATOR)[0])
            call_status = call_status_str.replace('\n', '').strip().lower() == str(True).lower()

            # exec_status can only be True when the call itself succeeded.
            exec_status = call_status and exec_status_str.strip().lower() == str(True).lower()

            if exec_status:
                ## The generated code executed successfully, save the file in exec folder
                exec_fpath = os.path.join(exec_root, fname)
                os.makedirs(exec_root, exist_ok=True)
                with open(exec_fpath, 'w') as f:
                    f.write(code)

            return call_status, exec_status, -1, gen_stdout, gen_stderr

        # TimeoutExpired derives from Exception, so it has to be handled first
        # or the generic handler below would swallow it.
        except subprocess.TimeoutExpired:
            if verbose:
                print(f"File: {fname} timed out!")
            # Same slot order as the normal returns: speedup (-1) third, stdout (None) fourth.
            return False, False, -1, None, "Time out"

        except Exception as e:
            if verbose:
                print(f"File: {fname}, Execution error: {e}")
            return False, False, -1, None, str(e)
# Registry mapping a dataset key to the evaluator class that handles it.
get_evaluators = {
    "tbg": TestAllCloseEvaluatorTBG,
    "rocm": TestAllCloseEvaluatorROCm,
}