llmll's picture
Upload folder using huggingface_hub
02c783d verified
import json
import os
import ast
import subprocess
from random import randint
from tqdm import tqdm
import signal
from multiprocessing import Pool, Lock, Value
from dataloaders.ProblemState import ProblemState
from dataloaders.TB_eval.utils import code_call_exec_success_allclose, code_kernel_profiling
import re
from tb_eval.evaluators.interface import get_evaluators
class TritonBench:
def __init__(self,
statis_path,
py_folder,
instruction_path,
golden_metrics,
py_interpreter,
perf_ref_folder,
perf_G_path,
result_path=None
):
self.statis_path = statis_path
self.py_folder = py_folder
self.instruction_path = instruction_path
self.golden_metrics_folder = golden_metrics
self.py_interpreter = py_interpreter
self.perf_ref_folder = perf_ref_folder
self.perf_G_path = perf_G_path
self.result_path = result_path
self.problem_states = self.load_ps(result_path)
self.evaluator = get_evaluators['tbg']()
def load_ps(self, path):
problem_states = []
if path is None:
with open(self.instruction_path, "r", encoding='utf-8') as file:
instructions = json.load(file)
statis_data = json.loads(open(self.statis_path, 'r', encoding='utf-8').read())
for line in instructions:
instruction = line["instruction"]
label = line["output"]
# get test code
g = label.replace("<|im_end|>", "").replace("<|EOT|>", "")
tmp = False
for item in statis_data:
if g in item["output"]:
file = item["file"]
tmp = item
break
if tmp:
statis_data.remove(tmp)
elif g[50:220] == 'as tl\n\nif triton.__version__ >= "2.1.0":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录':
file = "context_attn_nopad.py"
path = os.path.join(self.py_folder, file)
assert os.path.exists(path), f"{file} not exist!"
test_code = open(path, "r", encoding="utf-8").read().split("#"*146)[-1]
assert "def test_" in test_code, ""
problemstate = ProblemState(instruction=instruction,
label=label,
test_code=test_code,
filename=file,
)
problem_states.append(
problemstate
)
else:
with open(path, 'r', encoding='utf-8') as file:
for line in file.readlines():
content = json.loads(line)
problem_state = ProblemState(instruction=content["instruction"],
label=content["label"],
filename=content["filename"],
)
if "test_code" in content:
problem_state.test_code = content["test_code"]
if "predict" in content:
problem_state.solution = content["predict"]
problem_states.append(problem_state)
return problem_states
def __len__(self):
return len(self.problem_states)
def write_file(self, file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as f:
for ps in self.problem_states:
output = {
"instruction": ps.instruction,
"label": ps.label,
"filename": ps.filename,
}
if ps.test_code:
output["test_code"] = ps.test_code
if ps.solution:
output["predict"] = ps.solution
output["speedup"] = ps.speedup
f.write(json.dumps(output) + "\n")
def test_opt_correctness(self, code, filename, tmp_dir, exe_dir="pass_exe",gpu_id=0):
"""
Runs a given Python script on a specified GPU.
"""
pass_call, pass_exe, speedup, call_stdout, call_stderr = self.evaluator(code, tmp_dir, exe_dir, filename, atol=1e-3, rtol=1e-3, custom_tests_path=None)
return pass_call, pass_exe, speedup, call_stdout, call_stderr
def test_kernel_profiling(self, code, filename, tmp_dir, save_scripts=True, exe_dir="pass_exe", target_gpu=None, timeout=20*60):
os.makedirs(exe_dir, exist_ok=True)
profile_status, stdout_profile, stderr_profile, stdout_analyze = code_kernel_profiling(code=code, fname=filename, temp_root=tmp_dir, py_folder=self.py_folder, target_gpu=target_gpu, timeout=timeout)
pass_prfiler = False
if "True" in str(profile_status):
pass_prfiler=True
return pass_prfiler, stdout_profile, stderr_profile, stdout_analyze