# llmll's picture
# Upload folder using huggingface_hub
# 02c783d verified
import ast
import os
import subprocess
from random import randint
from tqdm import tqdm
from shutil import copyfile
import datetime
import json
from parse_llm_code import extract_code_blocks
import numpy as np
import re
def get_temp_bash_file(prefix='temp_code'):
    """Return a ``.sh`` file name that does not exist yet.

    The name has the form ``{prefix}_{N}.sh`` where ``N`` is a random integer
    in ``[999, 999999]``. Existence is tested with ``os.path.exists`` on the
    bare name, so a relative ``prefix`` is checked against the current
    working directory. No file is created by this function.

    Args:
        prefix: Leading part of the generated file name.

    Returns:
        A file name for which ``os.path.exists`` returned ``False``.
    """
    while True:
        # Draw a brand-new candidate on every attempt. The earlier version
        # called ``str.replace`` and threw the result away, so a single name
        # collision made the loop spin forever.
        temp_file_name = f'{prefix}_{randint(999, 999999)}.sh'
        if not os.path.exists(temp_file_name):
            return temp_file_name
def parse_profiler_content(profile_content):
    """Split profiler text into numbered sections.

    The text is cut at every 80-dash rule line. For each non-blank chunk, the
    first line of the form ``<digits>.<anything>`` supplies the section
    number. Chunks without such a line are ignored.

    Args:
        profile_content: Full profiler output as a single string.

    Returns:
        A dict mapping the section number (as a string) to the rule line
        followed by the unstripped chunk. If two chunks yield the same
        number, the later one replaces the earlier one.
    """
    rule = "--------------------------------------------------------------------------------"
    numbered_heading = re.compile(r"^\s*(\d+)\..*$", re.MULTILINE)

    sections = {}
    for chunk in profile_content.split(rule):
        body = chunk.strip()
        # Blank chunks are never searched; they cannot carry a heading.
        heading = numbered_heading.search(body) if body else None
        if heading is None:
            continue
        # Re-attach the rule that ``split`` consumed in front of the chunk.
        sections[heading.group(1)] = rule + chunk
    return sections
## Implementation from https://arxiv.org/pdf/2107.03374
def passk(n, c, k):
    """Compute ``1 - prod(1 - k / i)`` for ``i`` in ``n-c+1 .. n``.

    This follows the pass@k estimator of the paper cited above this
    function; there ``n`` is presumably the number of generated samples and
    ``c`` the number of correct ones (confirm against callers).

    Args:
        n: Upper bound of the product range.
        c: Number of terms in the product.
        k: The ``k`` in pass@k.

    Returns:
        ``1.0`` when ``n - c < k``, otherwise the value of the formula above.
    """
    remaining = n - c
    if remaining < k:
        return 1.0
    denominators = np.arange(remaining + 1, n + 1)
    return 1 - np.prod(1 - k / denominators)
def get_time():
    """Return the current time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d_%H-%M-%S}"
def get_temp_file(prefix='temp_code'):
    """Return a ``.py`` file name that does not exist yet.

    The name has the form ``{prefix}_{N}.py`` where ``N`` is a random integer
    in ``[999, 999999]``. Existence is tested with ``os.path.exists`` on the
    bare name, so a relative ``prefix`` is checked against the current
    working directory. No file is created by this function.

    Args:
        prefix: Leading part of the generated file name.

    Returns:
        A file name for which ``os.path.exists`` returned ``False``.
    """
    while True:
        # Draw a brand-new candidate on every attempt. The earlier version
        # called ``str.replace`` and threw the result away, so a single name
        # collision made the loop spin forever.
        temp_file_name = f'{prefix}_{randint(999, 999999)}.py'
        if not os.path.exists(temp_file_name):
            return temp_file_name
def code_call_exec_success_stdout(code, fname, temp_root="tmp2", tolerance=2, verbose=False):
    """Run generated code and the reference file on the same tests and compare stdout.

    The reference file ``fname`` (looked up under the hard-coded TritonBench
    directory) is split at the first line consisting of 146 ``#`` characters.
    Everything after that line is treated as test code; it is prefixed with an
    import header and every ``torch.rand`` occurrence is rewritten to
    ``torch_rand`` (which also turns ``torch.randn``/``torch.randint`` into
    ``torch_randn``/``torch_randint``). Two scripts are written under
    ``temp_root``: the reference kernel plus tests, and ``code`` plus tests.
    Both are executed with ``python3`` and their stdout is compared verbatim.

    NOTE(review): if no separator line exists, the loop leaves the index on
    the last line, so that line is dropped and no test code is extracted.

    Args:
        code: Generated source to evaluate.
        fname: File name of the reference implementation.
        temp_root: Directory under which ``triton/`` and ``gen/`` are created.
        tolerance: Value passed as ``precision`` to ``torch.set_printoptions``
            in the injected header.
        verbose: Print diagnostics when ``True``.

    Returns:
        A 4-tuple ``(call_status, exec_status, stdout, stderr)``:

        * generated script exits non-zero:
          ``(False, False, gen_stdout, gen_stderr)``
        * reference script exits non-zero: ``(None, None, None, None)``
        * both succeed and stdout matches: ``(True, True, None, None)``
        * both succeed and stdout differs:
          ``(True, False, gen_stdout, <mismatch message>)``
        * either run times out: ``(None, None, None, "Time out")``
        * any other exception while running: ``(False, False, None, str(e))``

    The temporary scripts and output files are not removed.
    """
    tmp_triton_folder = os.path.join(temp_root, "triton")
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_triton_folder, exist_ok=True)
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_root = "dataloaders/TB_eval/TritonBench/data/TritonBench_G_v1"
    RAND_FILE = os.path.join(triton_root, "rand_utils.py")
    # Both scripts import rand_utils, so it must sit next to each of them.
    copyfile(RAND_FILE, os.path.join(tmp_triton_folder, "rand_utils.py"))
    copyfile(RAND_FILE, os.path.join(tmp_gen_folder, "rand_utils.py"))

    triton_file = os.path.join(triton_root, fname)
    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))
    temp_triton_file = os.path.join(tmp_triton_folder, get_temp_file(prefix=f'{fname}_temp_triton'))

    IMPORT_STATEMENT = f"""
from rand_utils import torch_rand, torch_randint, torch_randn
import torch
torch.set_printoptions(precision={tolerance},profile='full',sci_mode=False)
"""
    hash_line = "#"*146

    with open(triton_file, 'r') as f:
        # Strip the line terminators so every line can be re-joined with
        # exactly one '\n'. Keeping them and joining with '\n' again doubled
        # every newline, which breaks backslash continuations and alters
        # multi-line string literals in the copied code.
        lines = [line.rstrip('\n') for line in f]

    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break

    test_code_lines = IMPORT_STATEMENT.split('\n') + lines[iL+1:]
    test_code_lines_procs = [line.replace("torch.rand", "torch_rand") for line in test_code_lines]

    with open(temp_triton_file, 'w') as f:
        triton_lines = lines[:iL] + [hash_line] + test_code_lines_procs
        f.write('\n'.join(triton_lines) + '\n')

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)
    with open(gen_file, 'w') as f:
        f.write(code)

    try:
        # Execute the generated code
        result_gen = subprocess.run(['python3', gen_file], capture_output=True, text=True, timeout=2*60)
        stdout_gen = result_gen.stdout
        stderr_gen = result_gen.stderr

        if result_gen.returncode != 0:
            if verbose:
                print(f"Error in generated code: {stderr_gen}")
            return False, False, stdout_gen, stderr_gen

        # Execute the reference code
        result_triton = subprocess.run(['python3', temp_triton_file], capture_output=True, text=True, timeout=2*60)
        stdout_triton = result_triton.stdout
        stderr_triton = result_triton.stderr

        if result_triton.returncode != 0:
            # A broken reference says nothing about the generated code.
            if verbose:
                print(f"Error in Triton code: {stderr_triton}")
            return None, None, None, None

        with open(gen_file+".out", 'w') as f:
            f.write(stdout_gen)
        with open(temp_triton_file+".out", 'w') as f:
            f.write(stdout_triton)
        with open(gen_file+".err", 'w') as f:
            f.write(stderr_gen)
        with open(temp_triton_file+".err", 'w') as f:
            f.write(stderr_triton)

        if stdout_gen == stdout_triton:
            return True, True, None, None
        return True, False, stdout_gen, "Error: not all test cases passed. The generated code and ground truth code produced different outputs."
    except subprocess.TimeoutExpired:
        # Must come before ``except Exception``: TimeoutExpired is an
        # Exception subclass and would otherwise never reach this handler.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out"
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return False, False, None, str(e)
def code_kernel_profiling(code, fname, py_folder, target_gpu, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=6*60, verbose=False):
    """Profile generated code with ``rocprof-compute`` and summarise the report.

    The test code following the 146-``#`` separator line in
    ``py_folder/fname`` is appended to ``code`` and written to a script under
    ``temp_root/gen``. A one-line bash wrapper that runs the script with
    ``python3`` is then profiled via ``rocprof-compute profile`` and analysed
    via ``rocprof-compute analyze`` on ``workloads/<name>/<target_gpu>``.
    Sections 0, 1, 2, 10, 16, 17 and 7 of the analysis output are stitched
    into a text summary.

    NOTE(review): if no separator line exists, the loop leaves the index on
    the last line, so no test code is extracted.

    Args:
        code: Generated source to profile.
        fname: File name of the reference implementation inside ``py_folder``.
        py_folder: Directory containing the reference files.
        target_gpu: Sub-directory name of the profiler workload to analyse.
        temp_root: Directory under which ``gen/`` is created.
        atol: Accepted for signature compatibility; not used here.
        rtol: Accepted for signature compatibility; not used here.
        timeout: Timeout in seconds applied to each profiler invocation.
        verbose: Print diagnostics when ``True``.

    Returns:
        ``(profile_status, stdout_profile, stderr_profile, stdout_analyze)``
        on success, where ``profile_status`` is whether the profile command
        exited with code 0. On timeout ``(None, None, "Time out", None)``;
        on any other error, including a missing report section,
        ``(None, None, str(e), None)``.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)

    triton_file = os.path.join(py_folder, fname)
    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))
    fname_split = fname.split('.')[0]
    gen_bash_file = os.path.join(tmp_gen_folder, get_temp_bash_file(prefix=f'{fname_split}_gen_triton_code'))

    hash_line = "#"*146
    with open(triton_file, 'r') as f:
        # Strip terminators so the lines can be joined with a single '\n';
        # joining terminated lines with '\n' doubled every newline.
        lines = [line.rstrip('\n') for line in f]

    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break
    test_code_lines_procs = lines[iL+1:]

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)
    code_bash = f"python3 {gen_file}"
    with open(gen_file, 'w') as f:
        f.write(code)
    with open(gen_bash_file, 'w') as f:
        f.write(code_bash)

    try:
        # Argument lists (no shell) keep file names with spaces or shell
        # metacharacters from being re-interpreted.
        result_profile = subprocess.run(
            ['rocprof-compute', 'profile', '-n', fname_split, '--', '/bin/bash', gen_bash_file],
            capture_output=True, text=True, timeout=timeout)
        analyze_profile = subprocess.run(
            ['rocprof-compute', 'analyze', '-p', f'workloads/{fname_split}/{target_gpu}'],
            capture_output=True, text=True, timeout=timeout)
        profile_status = result_profile.returncode == 0
        stdout_profile = result_profile.stdout
        stderr_profile = result_profile.stderr
    except subprocess.TimeoutExpired:
        # Must come before ``except Exception``: TimeoutExpired is an
        # Exception subclass and would otherwise never reach this handler.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, "Time out", None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, str(e), None

    if verbose:
        if result_profile.returncode != 0:
            print(f"Error in profiling kernel")
        else:
            print(f"Success in in profiling kernel")

    try:
        section_text = parse_profiler_content(analyze_profile.stdout)
        stdout_analyze = "\nBelow are some profiling info of this kernel generated by the tool of rocprof-compute on AMD MI250 gpu, you can reference these info to analyze and generate better kernel."
        stdout_analyze += "\n1.Overview:Briefly describe the kernel type along with its runtime and dispatch statistics, such as the main kernel name, invocation count, and average execution time."
        stdout_analyze += f"\n{section_text['0']}"
        stdout_analyze += "\n2.Hardware & Resources:Key hardware details including model, architecture, number of CUs, capacities of LDS/SMEM/registers, and maximum workgroup size."
        stdout_analyze += f"\n{section_text['1']}"
        stdout_analyze += "\n3.Performance Utilization & Bottlenecks:Core bottleneck indicators such as FLOPs utilization, active CUs, occupancy, and memory bandwidth/utilization."
        stdout_analyze += f"\n{section_text['2']}"
        stdout_analyze += "\n4.Instruction Mix & Memory Access:Distribution of arithmetic, memory, and branch instructions (e.g., MFMA/FMA/VALU/VMEM), cache hit rates (L1/L2), memory bandwidth, and conflict statistics."
        stdout_analyze += f"\n{section_text['10']}"
        stdout_analyze += f"\n{section_text['16']}"
        stdout_analyze += f"\n{section_text['17']}"
        stdout_analyze += "\n5.Threading & Allocation:Wavefront/workgroup counts, allocation of VGPRs/SGPRs/LDS, thread concurrency, and resource usage per thread or workgroup."
        stdout_analyze += f"\n{section_text['7']}"
    except Exception as e:
        # A missing section surfaces here as a KeyError.
        return None, None, str(e), None

    return profile_status, stdout_profile, stderr_profile, stdout_analyze
def extract_code_from_llm_output(response):
    """Return the code contained in an LLM response.

    A response without any triple-backtick fence is returned unchanged.
    Otherwise the fenced blocks are obtained from ``extract_code_blocks`` and
    the ``'context'`` entry of each block is concatenated, each followed by a
    newline.

    Args:
        response: Raw text produced by the model.

    Returns:
        The response itself (no fence), the concatenated block contents, or
        ``None`` when a fence is present but no block was extracted.
    """
    if "```" not in response:
        return response

    code_blocks = extract_code_blocks(response)
    # Collect the pieces first. The earlier version started from ``None`` and
    # used ``+=``, which raised TypeError as soon as one block was found.
    pieces = [block['context'] + "\n" for block in code_blocks.code_dict_list]
    if not pieces:
        return None
    return "".join(pieces)
def get_fname_difficulty_from_label(label):
    """Look up the file name and difficulty recorded for ``label``.

    The TritonBench JSON index is read on every call and scanned in order for
    the first entry whose ``'output'`` field equals ``label``.

    Args:
        label: Value to match against each entry's ``'output'`` field.

    Returns:
        ``(entry['file'], entry['difficulty'])`` of the first match, or
        ``(None, None)`` when no entry matches.
    """
    index_path = "dataloaders/TB_eval/TritonBench/data/TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json"
    with open(index_path, 'r') as fh:
        entries = json.load(fh)

    hit = next((entry for entry in entries if entry['output'] == label), None)
    if hit is None:
        return None, None
    return hit['file'], hit['difficulty']
def process_code(code: str):
    """Reduce model output to its import statements and function definitions.

    If the text contains a ```` ```python ```` fence, only the content of the
    last such fence is kept (up to the closing fence, if any) and the
    ``<|im_end|>`` / ``<|EOT|>`` markers are removed. The result is parsed
    and every ``import`` / ``from ... import`` statement and every ``def``
    found anywhere in the tree is re-emitted.

    Because the whole tree is walked, a nested function or method appears
    both inside its parent and as a separate entry. Classes, ``async def``
    and other statements are not collected.

    Args:
        code: Raw model output or plain Python source.

    Returns:
        The imports joined by newlines, a blank line, then the function
        definitions joined by newlines. If the text cannot be parsed, the
        fence-stripped text is returned as is.
    """
    if "```python" in code:
        code = code.split("```python")[-1]
        # Drop the closing fence and anything after it. Left in place, the
        # fence made every fully fenced answer unparsable.
        code = code.split("```")[0]
        code = code.replace("<|im_end|>", "").replace("<|EOT|>", "")

    try:
        tree = ast.parse(code)
        imports = []
        function_definitions = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(ast.unparse(node))
            elif isinstance(node, ast.FunctionDef):
                function_definitions.append(ast.unparse(node))
        return "\n".join(imports) + "\n\n" + "\n".join(function_definitions)
    except Exception:
        # Best effort: unparsable text is handed back untouched. ``Exception``
        # (not a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        return code
def code_call_exec_success_allclose(code, fname, py_folder, temp_root="tmp2", atol=1e-3, rtol=1e-1, timeout=2*60, verbose=False, gpu_id=0):
    """Run generated code, then check it against the reference with ``correctness.py``.

    A trailing ``_<digits>`` suffix is removed from ``fname`` to obtain the
    reference file name inside ``py_folder`` (``fname`` is used unchanged if
    it does not fit that pattern). The test code after the 146-``#``
    separator line of the reference is appended to ``code`` and written
    under ``temp_root/gen``. The script is executed once on its own, then
    ``dataloaders/TB_eval/correctness.py`` is run on the generated and
    reference files. Both run with ``HIP_VISIBLE_DEVICES`` set to ``gpu_id``.

    NOTE(review): if no separator line exists, the loop leaves the index on
    the last line, so no test code is extracted.

    Args:
        code: Generated source to evaluate.
        fname: Name of the generated sample, e.g. ``op.py`` or ``op_3.py``.
        py_folder: Directory containing the reference files.
        temp_root: Directory under which ``gen/`` is created.
        atol: Forwarded to ``correctness.py`` as ``--atol``.
        rtol: Forwarded to ``correctness.py`` as ``--rtol``.
        timeout: Timeout in seconds applied to each subprocess.
        verbose: Print diagnostics when ``True``.
        gpu_id: Value exported as ``HIP_VISIBLE_DEVICES``.

    Returns:
        A 6-tuple. When the correctness script exits with code 0 and its
        stdout splits into four ``*#*#``-separated fields:
        ``(call_status, exec_status, call_stdout, call_stderr, gen_stdout,
        gen_stderr)`` where the last three and ``exec_status`` are taken
        from those fields (``exec_status`` is a string). When it exits
        non-zero or the output has a different shape:
        ``(call_status, None, call_stdout, call_stderr, corr_stdout,
        corr_stderr)``. On timeout
        ``(None, None, None, "Time out", None, None)``; on any other error
        while running ``(None, None, None, str(e), None, None)``.
    """
    tmp_gen_folder = os.path.join(temp_root, "gen")
    os.makedirs(tmp_gen_folder, exist_ok=True)

    match = re.match(r"^([a-zA-Z0-9_]+?)(?:_\d+)?\.py$", fname)
    # Fall back to the given name. Previously ``filename`` stayed unbound
    # (NameError) whenever the pattern did not match.
    filename = match.group(1) + '.py' if match else fname
    triton_file = os.path.join(py_folder, filename)

    gen_file = os.path.join(tmp_gen_folder, get_temp_file(prefix=f'{fname}_gen_triton_code'))

    hash_line = "#"*146
    with open(triton_file, 'r') as f:
        # Strip terminators so the lines can be joined with a single '\n';
        # joining terminated lines with '\n' doubled every newline.
        lines = [line.rstrip('\n') for line in f]

    for iL, line in enumerate(lines):
        if line.strip() == hash_line:
            break
    test_code_lines_procs = lines[iL+1:]

    code = code + '\n\n' + hash_line + '\n' + '\n' + '\n'.join(test_code_lines_procs)
    with open(gen_file, 'w') as f:
        f.write(code)

    # Pass the device selection through the environment instead of a shell
    # prefix so the commands can be plain argument lists.
    run_env = dict(os.environ, HIP_VISIBLE_DEVICES=str(gpu_id))

    try:
        result_call = subprocess.run(
            ['python3', gen_file],
            capture_output=True, text=True, timeout=timeout, env=run_env)
        call_status = result_call.returncode == 0

        result_corr = subprocess.run(
            ['python3', 'dataloaders/TB_eval/correctness.py',
             '--gen_file', gen_file, '--ref_file', triton_file,
             '--atol', str(atol), '--rtol', str(rtol)],
            capture_output=True, text=True, timeout=timeout, env=run_env)
        stdout_corr = result_corr.stdout
        stderr_corr = result_corr.stderr
    except subprocess.TimeoutExpired:
        # Must come before ``except Exception``: TimeoutExpired is an
        # Exception subclass and would otherwise never reach this handler.
        if verbose:
            print(f"File: {fname} timed out!")
        return None, None, None, "Time out", None, None
    except Exception as e:
        if verbose:
            print(f"File: {fname}, Execution error: {e}")
        return None, None, None, str(e), None, None

    with open(gen_file+".stdout", 'w') as f:
        f.write(stdout_corr)
    with open(gen_file+".stderr", 'w') as f:
        f.write(stderr_corr)

    if result_corr.returncode != 0:
        if verbose:
            print(f"Error in generated code: {stderr_corr}")
        return call_status, None, result_call.stdout, result_call.stderr, stdout_corr, stderr_corr

    if verbose:
        print(f"Success in generated code: {stdout_corr}")

    fields = stdout_corr.split("*#*#")
    if len(fields) != 4:
        # Unexpected report shape: hand back the raw output rather than
        # failing on the tuple unpacking below.
        return call_status, None, result_call.stdout, result_call.stderr, stdout_corr, stderr_corr
    _, exec_status, gen_stdout, gen_stderr = fields
    return call_status, exec_status, result_call.stdout, result_call.stderr, gen_stdout, gen_stderr
class bcolors:
    """ANSI escape sequences used to colour and style terminal output."""

    # Foreground colours
    HEADER = "\x1b[95m"
    OKBLUE = "\x1b[94m"
    OKCYAN = "\x1b[96m"
    OKGREEN = "\x1b[92m"
    WARNING = "\x1b[93m"
    FAIL = "\x1b[91m"
    # Reset all attributes
    ENDC = "\x1b[0m"
    # Text attributes
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"
def green_or_red(status):
    """Return ``bcolors.OKGREEN`` for a truthy ``status``, else ``bcolors.FAIL``."""
    return bcolors.OKGREEN if status else bcolors.FAIL
def color_end():
    """Return ``bcolors.ENDC``, the sequence that ends a coloured span."""
    reset_sequence = bcolors.ENDC
    return reset_sequence
def bool_colorize(status):
    """Return ``str(status)`` wrapped in a colour and ``bcolors.ENDC``.

    A truthy ``status`` uses ``bcolors.OKGREEN``; a falsy one uses
    ``bcolors.FAIL``.
    """
    colour = bcolors.OKGREEN if status else bcolors.FAIL
    return "".join((colour, str(status), bcolors.ENDC))