# File size: 5,432 Bytes
# 02c783d
import json
import os
import ast
import subprocess
from random import randint
from tqdm import tqdm
import signal
from multiprocessing import Pool, Lock, Value
from dataloaders.ProblemState import ProblemState
from dataloaders.TB_eval.utils import code_call_exec_success_allclose, code_kernel_profiling
import re
from tb_eval.evaluators.interface import get_evaluators
class TritonBench:
def __init__(self,
statis_path,
py_folder,
instruction_path,
golden_metrics,
py_interpreter,
perf_ref_folder,
perf_G_path,
result_path=None
):
self.statis_path = statis_path
self.py_folder = py_folder
self.instruction_path = instruction_path
self.golden_metrics_folder = golden_metrics
self.py_interpreter = py_interpreter
self.perf_ref_folder = perf_ref_folder
self.perf_G_path = perf_G_path
self.result_path = result_path
self.problem_states = self.load_ps(result_path)
self.evaluator = get_evaluators['tbg']()
def load_ps(self, path):
problem_states = []
if path is None:
with open(self.instruction_path, "r", encoding='utf-8') as file:
instructions = json.load(file)
statis_data = json.loads(open(self.statis_path, 'r', encoding='utf-8').read())
for line in instructions:
instruction = line["instruction"]
label = line["output"]
# get test code
g = label.replace("<|im_end|>", "").replace("<|EOT|>", "")
tmp = False
for item in statis_data:
if g in item["output"]:
file = item["file"]
tmp = item
break
if tmp:
statis_data.remove(tmp)
elif g[50:220] == 'as tl\n\nif triton.__version__ >= "2.1.0":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录':
file = "context_attn_nopad.py"
path = os.path.join(self.py_folder, file)
assert os.path.exists(path), f"{file} not exist!"
test_code = open(path, "r", encoding="utf-8").read().split("#"*146)[-1]
assert "def test_" in test_code, ""
problemstate = ProblemState(instruction=instruction,
label=label,
test_code=test_code,
filename=file,
)
problem_states.append(
problemstate
)
else:
with open(path, 'r', encoding='utf-8') as file:
for line in file.readlines():
content = json.loads(line)
problem_state = ProblemState(instruction=content["instruction"],
label=content["label"],
filename=content["filename"],
)
if "test_code" in content:
problem_state.test_code = content["test_code"]
if "predict" in content:
problem_state.solution = content["predict"]
problem_states.append(problem_state)
return problem_states
def __len__(self):
return len(self.problem_states)
def write_file(self, file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as f:
for ps in self.problem_states:
output = {
"instruction": ps.instruction,
"label": ps.label,
"filename": ps.filename,
}
if ps.test_code:
output["test_code"] = ps.test_code
if ps.solution:
output["predict"] = ps.solution
output["speedup"] = ps.speedup
f.write(json.dumps(output) + "\n")
def test_opt_correctness(self, code, filename, tmp_dir, exe_dir="pass_exe",gpu_id=0):
"""
Runs a given Python script on a specified GPU.
"""
pass_call, pass_exe, speedup, call_stdout, call_stderr = self.evaluator(code, tmp_dir, exe_dir, filename, atol=1e-3, rtol=1e-3, custom_tests_path=None)
return pass_call, pass_exe, speedup, call_stdout, call_stderr
def test_kernel_profiling(self, code, filename, tmp_dir, save_scripts=True, exe_dir="pass_exe", target_gpu=None, timeout=20*60):
os.makedirs(exe_dir, exist_ok=True)
profile_status, stdout_profile, stderr_profile, stdout_analyze = code_kernel_profiling(code=code, fname=filename, temp_root=tmp_dir, py_folder=self.py_folder, target_gpu=target_gpu, timeout=timeout)
pass_prfiler = False
if "True" in str(profile_status):
pass_prfiler=True
return pass_prfiler, stdout_profile, stderr_profile, stdout_analyze
|