|
|
import ast |
|
|
import os |
|
|
import subprocess |
|
|
from random import randint |
|
|
from tqdm import tqdm |
|
|
from shutil import copyfile |
|
|
import datetime |
|
|
import json |
|
|
from parse_llm_code import extract_code_blocks |
|
|
import numpy as np |
|
|
import re |
|
|
|
|
|
def get_temp_bash_file(prefix='temp_code'):
    """Return a ``.sh`` file name of the form ``{prefix}_{N}.sh`` that does not exist yet.

    ``N`` is a random integer in ``[999, 999999]``. If ``prefix`` contains a
    directory component, the existence check is made in that directory;
    otherwise it is relative to the current working directory. The file is not
    created, so the name is only free at the moment of the check.

    Args:
        prefix: Leading part of the file name (may include a directory).

    Returns:
        A path string ending in ``.sh`` for which ``os.path.exists`` was False.
    """
    temp_file_name = f'{prefix}_{randint(999, 999999)}.sh'
    while os.path.exists(temp_file_name):
        # Re-draw the whole name. str methods never modify in place, so the
        # new value must be assigned back or the loop would never terminate.
        temp_file_name = f'{prefix}_{randint(999, 999999)}.sh'
    return temp_file_name
|
|
|
|
|
def parse_profiler_content(profile_content):
    """Split profiler text into sections keyed by their leading section number.

    The text is cut on a fixed line of dashes. In each non-blank chunk, the
    first line that looks like ``<digits>.<anything>`` (optionally preceded by
    whitespace) supplies the key: the digit string, e.g. ``"10"``. The value is
    the dash line followed by the chunk exactly as it was cut. Chunks without
    such a line are skipped. When two chunks share a number, the later one wins.
    In this file it is fed the stdout of ``rocprof-compute analyze``.

    Args:
        profile_content: Full profiler output as a string.

    Returns:
        dict mapping section-number strings to their section text.
    """
    separator = "--------------------------------------------------------------------------------"
    heading_re = re.compile(r"^\s*(\d+)\..*$", re.MULTILINE)

    sections = {}
    for chunk in profile_content.split(separator):
        stripped = chunk.strip()
        if not stripped:
            continue
        heading = heading_re.search(stripped)
        if heading is None:
            continue
        # Re-attach the separator so each section reads as it did in the report.
        sections[heading.group(1)] = separator + chunk
    return sections
|
|
|
|
|
|
|
|
def passk(n, c, k):
    """Estimate pass@k from ``n`` generations of which ``c`` are correct.

    Computes ``1 - C(n - c, k) / C(n, k)``, the probability that at least one
    of ``k`` samples drawn without replacement is correct. The binomial ratio is
    evaluated as the product ``prod_{i=n-c+1}^{n} (1 - k / i)`` to avoid large
    factorials.

    Args:
        n: Total number of generations.
        c: Number of correct generations.
        k: Number of samples drawn.

    Returns:
        1.0 when fewer than ``k`` incorrect generations exist (every draw must
        hit a correct one); otherwise a NumPy float in ``[0, 1]``.
    """
    failures = n - c
    if failures < k:
        return 1.0
    denominators = np.arange(failures + 1, n + 1)
    prob_all_fail = np.prod(1 - k / denominators)
    return 1 - prob_all_fail
|
|
|
|
|
def get_time():
    """Return the current local time as ``YYYY-MM-DD_HH-MM-SS``.

    The format contains no spaces or colons, so it can be embedded directly in
    file and directory names.
    """
    stamp_format = "%Y-%m-%d_%H-%M-%S"
    current = datetime.datetime.now()
    return current.strftime(stamp_format)
|
|
|
|
|
def get_temp_file(prefix='temp_code'):
    """Return a ``.py`` file name of the form ``{prefix}_{N}.py`` that does not exist yet.

    ``N`` is a random integer in ``[999, 999999]``. If ``prefix`` contains a
    directory component, the existence check is made in that directory;
    otherwise it is relative to the current working directory. The file is not
    created, so the name is only free at the moment of the check.

    Args:
        prefix: Leading part of the file name (may include a directory).

    Returns:
        A path string ending in ``.py`` for which ``os.path.exists`` was False.
    """
    temp_file_name = f'{prefix}_{randint(999, 999999)}.py'
    while os.path.exists(temp_file_name):
        # Re-draw the whole name. str methods never modify in place, so the
        # new value must be assigned back or the loop would never terminate.
        temp_file_name = f'{prefix}_{randint(999, 999999)}.py'
    return temp_file_name
|
|
|
|
|
def code_call_exec_success_stdout(code, fname, temp_root="tmp2", tolerance=2, verbose=False):
    """Check generated code against a TritonBench reference by comparing printed output.

    The reference ``fname`` (inside the TritonBench_G_v1 data folder) is split
    at the first line equal to 146 ``#`` characters. Lines before it are the
    reference implementation; lines after it are its test code. If no such
    line exists, the whole file is treated as the implementation and no test
    code is taken from it.

    A header is prepended to the test code. It imports the ``rand_utils``
    helpers and calls ``torch.set_printoptions(precision=tolerance, ...)``.
    Every ``torch.rand`` substring in the test code is then rewritten to
    ``torch_rand``, which also turns ``torch.randn``/``torch.randint`` into
    ``torch_randn``/``torch_randint``.

    Two scripts are written and run with ``python3`` (2-minute timeout each):

    * under ``temp_root/gen``: ``code`` followed by the processed test code;
    * under ``temp_root/triton``: the reference implementation followed by it.

    Args:
        code: Generated source code to evaluate.
        fname: File name of the reference inside the TritonBench_G_v1 folder.
        temp_root: Directory under which the ``triton`` and ``gen`` scratch
            folders are created.
        tolerance: ``precision`` passed to ``torch.set_printoptions``. Outputs
            are compared as text, so this sets how many printed decimals must agree.
        verbose: Print diagnostics.

    Returns:
        A 4-tuple ``(ran_ok, outputs_match, stdout, message)``:

        * generated script exits non-zero: ``(False, False, stdout_gen, stderr_gen)``
        * reference script exits non-zero: ``(None, None, None, None)``
        * identical stdout: ``(True, True, None, None)``
        * different stdout: ``(True, False, stdout_gen, <mismatch message>)``
        * either script times out: ``(None, None, None, "Time out")``
        * any other exception while running: ``(False, False, None, str(e))``

        When both scripts succeed, their stdout/stderr are also saved next to
        them with ``.out`` / ``.err`` suffixes.
    """
    tmp_triton_folder = os.path.join(temp_root, "triton")
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_triton_folder, exist_ok=True)
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_root = "dataloaders/TB_eval/TritonBench/data/TritonBench_G_v1"
    rand_file = os.path.join(triton_root, "rand_utils.py")
    # Each script imports rand_utils from its own directory.
    copyfile(rand_file, os.path.join(tmp_triton_folder, "rand_utils.py"))
    copyfile(rand_file, os.path.join(tmp_gen_folder, "rand_utils.py"))

    # NOTE(review): get_temp_file checks for collisions using the bare name
    # (relative to the CWD), not inside the folder it is joined into here.
    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))
    triton_file = os.path.join(triton_root, fname)
    temp_triton_file = os.path.join(tmp_triton_folder, get_temp_file(prefix=f'{fname}_temp_triton'))

    import_statement = (
        "\n\n\n"
        "from rand_utils import torch_rand, torch_randint, torch_randn\n\n\n"
        "import torch\n\n\n"
        f"torch.set_printoptions(precision={tolerance},profile='full',sci_mode=False)\n\n\n"
    )

    hash_line = "#" * 146

    with open(triton_file, 'r') as f:
        # Strip the line terminator so re-joining with "\n" does not double
        # every line break. Doubling would break backslash continuations and
        # alter triple-quoted strings.
        lines = [line.rstrip("\n") for line in f]

    # Index of the separator, or len(lines) when the file has none.
    sep_idx = next((i for i, line in enumerate(lines) if line.strip() == hash_line), len(lines))

    test_code_lines = import_statement.split('\n') + lines[sep_idx + 1:]
    # Route torch.rand* through the rand_utils helpers (presumably so both
    # scripts see the same random inputs). NOTE(review): torch.rand_like would
    # become torch_rand_like, which the header does not import.
    test_code_lines_procs = [line.replace("torch.rand", "torch_rand") for line in test_code_lines]

    with open(temp_triton_file, 'w') as f:
        triton_lines = lines[:sep_idx] + [hash_line] + test_code_lines_procs
        f.write("".join(line + "\n" for line in triton_lines))

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)
    with open(gen_file, 'w') as f:
        f.write(code)

    try:
        result_gen = subprocess.run(['python3', gen_file], capture_output=True, text=True, timeout=2 * 60)
        stdout_gen = result_gen.stdout
        stderr_gen = result_gen.stderr
        if result_gen.returncode != 0:
            if verbose:
                print(f"Error in generated code: {stderr_gen}")
            return False, False, stdout_gen, stderr_gen

        result_triton = subprocess.run(['python3', temp_triton_file], capture_output=True, text=True, timeout=2 * 60)
        stdout_triton = result_triton.stdout
        stderr_triton = result_triton.stderr
        if result_triton.returncode != 0:
            # A broken reference says nothing about the generated code.
            if verbose:
                print(f"Error in Triton code: {stderr_triton}")
            return None, None, None, None

        for path, text in (
            (gen_file + ".out", stdout_gen),
            (temp_triton_file + ".out", stdout_triton),
            (gen_file + ".err", stderr_gen),
            (temp_triton_file + ".err", stderr_triton),
        ):
            with open(path, 'w') as f:
                f.write(text)

        if stdout_gen == stdout_triton:
            return True, True, None, None
        return True, False, stdout_gen, "Error: not all test cases passed. The generated code and ground truth code produced different outputs."
    except subprocess.TimeoutExpired:
        # Listed before Exception: TimeoutExpired subclasses Exception, so the
        # generic handler would otherwise swallow it.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out"
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return False, False, None, str(e)
|
|
|
|
|
def code_kernel_profiling(code, fname, py_folder, target_gpu, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=6*60, verbose=False):
    """Profile generated kernel code with ``rocprof-compute`` and summarize the report.

    The reference file ``py_folder/fname`` is split at the first line equal to
    146 ``#`` characters. The lines after it (its test code) are appended to
    ``code`` and written to a script under ``temp_root/gen``, together with a
    bash wrapper that runs that script with ``python3``.

    The wrapper is profiled with ``rocprof-compute profile -n <stem>``. The
    workload is then analyzed with
    ``rocprof-compute analyze -p workloads/<stem>/<target_gpu>``, where
    ``<stem>`` is ``fname`` up to its first ``.``. Both commands run in the
    current working directory, each with its own ``timeout`` in seconds.

    Sections 0, 1, 2, 10, 16, 17 and 7 of the analysis output, as split by
    ``parse_profiler_content``, are assembled under explanatory headings into a
    text summary. It appears to be meant as prompt context for a model that
    writes kernels.

    Args:
        code: Generated source code to profile.
        fname: Reference file name inside ``py_folder``.
        py_folder: Folder containing the reference files.
        target_gpu: Sub-directory name under ``workloads/<stem>`` to analyze.
        temp_root: Directory under which the ``gen`` scratch folder is created.
        atol, rtol: Unused.
        timeout: Per-command timeout in seconds.
        verbose: Print diagnostics.

    Returns:
        ``(profile_status, stdout_profile, stderr_profile, stdout_analyze)``,
        where ``profile_status`` is True iff the profile command exited with 0.
        On timeout: ``(None, None, "Time out", None)``.
        If launching a command fails, or a required section is missing from the
        analysis output: ``(None, None, str(e), None)``.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_file = os.path.join(py_folder, fname)

    # NOTE(review): the temp-name helpers check for collisions using the bare
    # name (relative to the CWD), not inside tmp_gen_folder.
    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))
    fname_split = fname.split('.')[0]
    gen_bash_file = os.path.join(tmp_gen_folder, get_temp_bash_file(prefix=f'{fname_split}_gen_triton_code'))

    hash_line = "#" * 146
    with open(triton_file, 'r') as f:
        # Strip terminators so joining with "\n" does not double line breaks.
        lines = [line.rstrip("\n") for line in f]
    # Index of the separator, or len(lines) when missing (then no test code).
    sep_idx = next((i for i, line in enumerate(lines) if line.strip() == hash_line), len(lines))
    test_code_lines = lines[sep_idx + 1:]

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines)
    with open(gen_file, 'w') as f:
        f.write(code)
    with open(gen_bash_file, 'w') as f:
        f.write(f"python3 {gen_file}")

    try:
        # Argument lists (no shell) so file names are passed through verbatim.
        result_profile = subprocess.run(
            ['rocprof-compute', 'profile', '-n', fname_split, '--', '/bin/bash', gen_bash_file],
            capture_output=True, text=True, timeout=timeout,
        )
        analyze_profile = subprocess.run(
            ['rocprof-compute', 'analyze', '-p', f'workloads/{fname_split}/{target_gpu}'],
            capture_output=True, text=True, timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # Listed before Exception: TimeoutExpired subclasses Exception, so the
        # generic handler would otherwise swallow it.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, "Time out", None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, str(e), None

    profile_status = result_profile.returncode == 0
    stdout_profile = result_profile.stdout
    stderr_profile = result_profile.stderr
    if verbose:
        print("Success in profiling kernel" if profile_status else "Error in profiling kernel")

    # NOTE(review): the intro names the MI250 even though target_gpu is a
    # parameter; confirm whether it should follow target_gpu.
    intro = "\nBelow are some profiling info of this kernel generated by the tool of rocprof-compute on AMD MI250 gpu, you can reference these info to analyze and generate better kernel."
    # (heading, section numbers appended after it, in order)
    report_layout = (
        ("\n1.Overview:Briefly describe the kernel type along with its runtime and dispatch statistics, such as the main kernel name, invocation count, and average execution time.", ('0',)),
        ("\n2.Hardware & Resources:Key hardware details including model, architecture, number of CUs, capacities of LDS/SMEM/registers, and maximum workgroup size.", ('1',)),
        ("\n3.Performance Utilization & Bottlenecks:Core bottleneck indicators such as FLOPs utilization, active CUs, occupancy, and memory bandwidth/utilization.", ('2',)),
        ("\n4.Instruction Mix & Memory Access:Distribution of arithmetic, memory, and branch instructions (e.g., MFMA/FMA/VALU/VMEM), cache hit rates (L1/L2), memory bandwidth, and conflict statistics.", ('10', '16', '17')),
        ("\n5.Threading & Allocation:Wavefront/workgroup counts, allocation of VGPRs/SGPRs/LDS, thread concurrency, and resource usage per thread or workgroup.", ('7',)),
    )
    try:
        section_text = parse_profiler_content(analyze_profile.stdout)
        pieces = [intro]
        for heading, keys in report_layout:
            pieces.append(heading)
            # A missing section (e.g. analysis failed) raises KeyError here.
            pieces.extend(f"\n{section_text[key]}" for key in keys)
    except Exception as e:
        return None, None, str(e), None
    stdout_analyze = "".join(pieces)

    return profile_status, stdout_profile, stderr_profile, stdout_analyze
|
|
|
|
|
def extract_code_from_llm_output(response):
    """Return the code contained in an LLM response.

    A response without any ``` fence is returned unchanged. Otherwise the
    response goes through ``parse_llm_code.extract_code_blocks``. The
    ``'context'`` field of each entry in its ``code_dict_list`` (presumably the
    block body; defined by that helper) is concatenated, each followed by a
    newline.

    Args:
        response: Raw model output.

    Returns:
        The response itself when unfenced, the concatenated block text when at
        least one block was extracted, and None when none were.
    """
    if "```" not in response:
        return response

    code_blocks = extract_code_blocks(response)
    # Collect first and join once; starting from None made the first `+=` fail.
    snippets = [block['context'] + "\n" for block in code_blocks.code_dict_list]
    if not snippets:
        return None
    return "".join(snippets)
|
|
|
|
|
def get_fname_difficulty_from_label(label, data_path="dataloaders/TB_eval/TritonBench/data/TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json"):
    """Look up the reference file name and difficulty for a ground-truth code string.

    ``data_path`` must hold a JSON list of records with ``'output'``,
    ``'file'`` and ``'difficulty'`` keys. The file is re-read on every call.

    Args:
        label: Value compared for equality against each record's ``'output'``.
        data_path: JSON file to search. Defaults to the TritonBench-G dataset
            file previously hard-coded here.

    Returns:
        ``(file, difficulty)`` of the first matching record, or ``(None, None)``.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    match = next((item for item in data if item['output'] == label), None)
    if match is None:
        return None, None
    return match['file'], match['difficulty']
|
|
|
|
|
def process_code(code: str):
    """Reduce model-generated Python to its import statements and function definitions.

    If ``code`` contains a ``` ```python ``` fence, only the text after the last
    such fence and before the next closing ``` is kept, and the chat
    end-tokens ``<|im_end|>`` and ``<|EOT|>`` are removed from it.

    The result is parsed with ``ast``, and every ``Import``/``ImportFrom`` and
    ``FunctionDef`` node found by ``ast.walk`` is re-emitted with
    ``ast.unparse``. The output is the imports joined by newlines, a blank
    line, then the functions joined by newlines.

    NOTE(review): ``ast.walk`` also visits nested nodes. A function nested in
    another (or a method) is therefore emitted both inside its parent and again
    on its own. Classes, ``async def`` functions and other module-level
    statements are dropped.

    Args:
        code: Raw code or model response.

    Returns:
        The reduced source, or the (fence-stripped) text unchanged if it cannot
        be parsed.
    """
    if "```python" in code:
        fenced = code.split("```python")[-1]
        # Cut at the closing fence; leaving it in made ast.parse fail on every
        # properly fenced response.
        fenced = fenced.split("```")[0]
        code = fenced.replace("<|im_end|>", "").replace("<|EOT|>", "")

    try:
        tree = ast.parse(code)
        imports = []
        function_definitions = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(ast.unparse(node))
            elif isinstance(node, ast.FunctionDef):
                function_definitions.append(ast.unparse(node))
        return "\n".join(imports) + "\n\n" + "\n".join(function_definitions)
    except Exception:
        # Best effort: unparsable text (SyntaxError, null bytes, deep nesting)
        # is passed through rather than raising. Narrower than a bare except so
        # KeyboardInterrupt/SystemExit still propagate.
        return code
|
|
|
|
|
|
|
|
def code_call_exec_success_allclose(code, fname, py_folder, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=2*60, verbose=False, gpu_id=0):
    """Run generated code and check it numerically against a reference file.

    ``fname`` is mapped to its reference by removing an optional ``_<digits>``
    suffix before ``.py`` (e.g. ``softmax_3.py`` -> ``softmax.py``). Names that
    do not fit the pattern are used unchanged. The reference
    ``py_folder/<name>`` is split at the first line equal to 146 ``#``
    characters, and the lines after it (its test code) are appended to ``code``
    in a script under ``temp_root/gen``.

    That script is run directly with ``python3``. It is also passed with the
    reference to ``dataloaders/TB_eval/correctness.py`` along with ``atol`` and
    ``rtol``. Both processes get ``HIP_VISIBLE_DEVICES=gpu_id`` and ``timeout``
    seconds each. The checker's stdout/stderr are saved next to the script with
    ``.stdout`` / ``.stderr`` suffixes.

    Args:
        code: Generated source code to evaluate.
        fname: Name of the generated-code file; also used to locate the reference.
        py_folder: Folder containing the reference files.
        temp_root: Directory under which the ``gen`` scratch folder is created.
        atol, rtol: Tolerances forwarded to the correctness checker.
        timeout: Per-process timeout in seconds.
        verbose: Print diagnostics.
        gpu_id: Value for ``HIP_VISIBLE_DEVICES``.

    Returns:
        A 6-tuple. If the checker exits non-zero:
        ``(call_status, None, call_stdout, call_stderr, corr_stdout, corr_stderr)``.
        Otherwise the checker's stdout is split on ``"*#*#"`` into four fields
        (the first is discarded), giving
        ``(call_status, exec_status, call_stdout, call_stderr, gen_stdout, gen_stderr)``,
        where the last three come from those string fields.
        ``call_status`` is True iff the direct run exited with 0.
        On timeout: ``(None, None, None, "Time out", None, None)``.
        On any other launch error: ``(None, None, None, str(e), None, None)``.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)

    match = re.match(r"^([a-zA-Z0-9_]+?)(?:_\d+)?\.py$", fname)
    # Fall back to fname itself; previously `filename` was left unbound here.
    filename = match.group(1) + '.py' if match else fname
    triton_file = os.path.join(py_folder, filename)

    # NOTE(review): get_temp_file checks for collisions using the bare name
    # (relative to the CWD), not inside tmp_gen_folder.
    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))

    hash_line = "#" * 146
    with open(triton_file, 'r') as f:
        # Strip terminators so joining with "\n" does not double line breaks.
        lines = [line.rstrip("\n") for line in f]
    # Index of the separator, or len(lines) when missing (then no test code).
    sep_idx = next((i for i, line in enumerate(lines) if line.strip() == hash_line), len(lines))
    test_code_lines = lines[sep_idx + 1:]

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines)
    with open(gen_file, 'w') as f:
        f.write(code)

    # Pass the GPU selection via env= and the arguments as a list, instead of
    # interpolating them into a shell command line.
    env = dict(os.environ, HIP_VISIBLE_DEVICES=str(gpu_id))
    try:
        result_call = subprocess.run(
            ['python3', gen_file],
            capture_output=True, text=True, timeout=timeout, env=env,
        )
        call_status = result_call.returncode == 0

        result_corr = subprocess.run(
            ['python3', 'dataloaders/TB_eval/correctness.py',
             '--gen_file', gen_file, '--ref_file', triton_file,
             '--atol', str(atol), '--rtol', str(rtol)],
            capture_output=True, text=True, timeout=timeout, env=env,
        )
        stdout_corr = result_corr.stdout
        stderr_corr = result_corr.stderr
    except subprocess.TimeoutExpired:
        # Listed before Exception: TimeoutExpired subclasses Exception, so the
        # generic handler would otherwise swallow it.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out", None, None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, None, str(e), None, None

    with open(gen_file + ".stdout", 'w') as f:
        f.write(stdout_corr)
    with open(gen_file + ".stderr", 'w') as f:
        f.write(stderr_corr)

    if result_corr.returncode != 0:
        if verbose:
            print(f"Error in generated code: {stderr_corr}")
        return call_status, None, result_call.stdout, result_call.stderr, stdout_corr, stderr_corr

    if verbose:
        print(f"Success in generated code: {stdout_corr}")
    # NOTE(review): assumes the checker prints exactly three "*#*#" markers;
    # any other count raises ValueError here.
    _, exec_status, gen_stdout, gen_stderr = stdout_corr.split("*#*#")
    return call_status, exec_status, result_call.stdout, result_call.stderr, gen_stdout, gen_stderr
|
|
|
|
|
|
|
|
|
|
|
class bcolors:
    """ANSI SGR escape sequences for coloring terminal output.

    Typical use: ``bcolors.OKGREEN + text + bcolors.ENDC``.
    """

    # Foreground colors (SGR 9x codes).
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset and text attributes.
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
|
|
|
|
|
def green_or_red(status):
    """Return ``bcolors.OKGREEN`` when ``status`` is truthy, else ``bcolors.FAIL``."""
    return bcolors.OKGREEN if status else bcolors.FAIL
|
|
|
|
|
def color_end():
    """Return the escape sequence (``bcolors.ENDC``) that ends a colored span."""
    reset = bcolors.ENDC
    return reset
|
|
|
|
|
def bool_colorize(status):
    """Return ``str(status)`` wrapped in green (truthy) or red (falsy) plus a reset code."""
    color = bcolors.OKGREEN if status else bcolors.FAIL
    return "".join((color, str(status), bcolors.ENDC))