|
|
import json |
|
|
import os |
|
|
import ast |
|
|
import subprocess |
|
|
from random import randint |
|
|
from tqdm import tqdm |
|
|
import signal |
|
|
from multiprocessing import Pool, Lock, Value |
|
|
from dataloaders.ProblemState import ProblemState |
|
|
from dataloaders.TB_eval.utils import code_call_exec_success_allclose, code_kernel_profiling |
|
|
import re |
|
|
from tb_eval.evaluators.interface import get_evaluators |
|
|
|
|
|
class TritonBench: |
|
|
def __init__(self, |
|
|
statis_path, |
|
|
py_folder, |
|
|
instruction_path, |
|
|
golden_metrics, |
|
|
py_interpreter, |
|
|
perf_ref_folder, |
|
|
perf_G_path, |
|
|
result_path=None |
|
|
): |
|
|
self.statis_path = statis_path |
|
|
self.py_folder = py_folder |
|
|
self.instruction_path = instruction_path |
|
|
self.golden_metrics_folder = golden_metrics |
|
|
self.py_interpreter = py_interpreter |
|
|
self.perf_ref_folder = perf_ref_folder |
|
|
self.perf_G_path = perf_G_path |
|
|
self.result_path = result_path |
|
|
|
|
|
self.problem_states = self.load_ps(result_path) |
|
|
self.evaluator = get_evaluators['tbg']() |
|
|
|
|
|
def load_ps(self, path): |
|
|
problem_states = [] |
|
|
if path is None: |
|
|
with open(self.instruction_path, "r", encoding='utf-8') as file: |
|
|
instructions = json.load(file) |
|
|
statis_data = json.loads(open(self.statis_path, 'r', encoding='utf-8').read()) |
|
|
|
|
|
for line in instructions: |
|
|
instruction = line["instruction"] |
|
|
label = line["output"] |
|
|
|
|
|
|
|
|
g = label.replace("<|im_end|>", "").replace("<|EOT|>", "") |
|
|
tmp = False |
|
|
for item in statis_data: |
|
|
if g in item["output"]: |
|
|
file = item["file"] |
|
|
tmp = item |
|
|
break |
|
|
if tmp: |
|
|
statis_data.remove(tmp) |
|
|
elif g[50:220] == 'as tl\n\nif triton.__version__ >= "2.1.0":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录': |
|
|
file = "context_attn_nopad.py" |
|
|
path = os.path.join(self.py_folder, file) |
|
|
assert os.path.exists(path), f"{file} not exist!" |
|
|
test_code = open(path, "r", encoding="utf-8").read().split("#"*146)[-1] |
|
|
assert "def test_" in test_code, "" |
|
|
|
|
|
problemstate = ProblemState(instruction=instruction, |
|
|
label=label, |
|
|
test_code=test_code, |
|
|
filename=file, |
|
|
) |
|
|
|
|
|
problem_states.append( |
|
|
problemstate |
|
|
) |
|
|
else: |
|
|
with open(path, 'r', encoding='utf-8') as file: |
|
|
for line in file.readlines(): |
|
|
content = json.loads(line) |
|
|
problem_state = ProblemState(instruction=content["instruction"], |
|
|
label=content["label"], |
|
|
filename=content["filename"], |
|
|
) |
|
|
if "test_code" in content: |
|
|
problem_state.test_code = content["test_code"] |
|
|
if "predict" in content: |
|
|
problem_state.solution = content["predict"] |
|
|
problem_states.append(problem_state) |
|
|
return problem_states |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.problem_states) |
|
|
|
|
|
def write_file(self, file_path): |
|
|
os.makedirs(os.path.dirname(file_path), exist_ok=True) |
|
|
with open(file_path, 'w') as f: |
|
|
for ps in self.problem_states: |
|
|
output = { |
|
|
"instruction": ps.instruction, |
|
|
"label": ps.label, |
|
|
"filename": ps.filename, |
|
|
} |
|
|
if ps.test_code: |
|
|
output["test_code"] = ps.test_code |
|
|
if ps.solution: |
|
|
output["predict"] = ps.solution |
|
|
output["speedup"] = ps.speedup |
|
|
|
|
|
f.write(json.dumps(output) + "\n") |
|
|
|
|
|
def test_opt_correctness(self, code, filename, tmp_dir, exe_dir="pass_exe",gpu_id=0): |
|
|
""" |
|
|
Runs a given Python script on a specified GPU. |
|
|
""" |
|
|
pass_call, pass_exe, speedup, call_stdout, call_stderr = self.evaluator(code, tmp_dir, exe_dir, filename, atol=1e-3, rtol=1e-3, custom_tests_path=None) |
|
|
|
|
|
return pass_call, pass_exe, speedup, call_stdout, call_stderr |
|
|
|
|
|
def test_kernel_profiling(self, code, filename, tmp_dir, save_scripts=True, exe_dir="pass_exe", target_gpu=None, timeout=20*60): |
|
|
os.makedirs(exe_dir, exist_ok=True) |
|
|
profile_status, stdout_profile, stderr_profile, stdout_analyze = code_kernel_profiling(code=code, fname=filename, temp_root=tmp_dir, py_folder=self.py_folder, target_gpu=target_gpu, timeout=timeout) |
|
|
pass_prfiler = False |
|
|
if "True" in str(profile_status): |
|
|
pass_prfiler=True |
|
|
|
|
|
return pass_prfiler, stdout_profile, stderr_profile, stdout_analyze |
|
|
|
|
|
|
|
|
|