# anthropic-kernel/atempt_2/manual_tuner.py
# Uploaded by algorembrant ("Upload 39 files", commit f3ce0b0, verified).
import itertools
import os
import random
import sys
# Add parent dir to path to import perf_takehome
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from perf_takehome import KernelBuilder, do_kernel_test, Tree, Input, build_mem_image, N_CORES, Machine, reference_kernel2
# Cycle count reported for a configuration that raised or produced wrong output.
# Kept at the historical value so callers comparing against it keep working.
FAILED_CYCLES = 999999

# Fixed seed used by main() so every configuration is timed on identical inputs.
TUNING_SEED = 123

# Core state values as used by the simulator loop (per the original inline
# comments: 2 = PAUSED, 3 = STOPPED; 1 is the value written to resume a core).
# NOTE(review): confirm against the state enum in perf_takehome.
_CORE_RUNNING = 1
_CORE_PAUSED = 2
_CORE_STOPPED = 3


def _run_to_completion(machine):
    """Run ``machine`` until core 0 stops, resuming it each time it pauses.

    If a run leaves core 0 in any state other than PAUSED, the loop exits
    without further runs (same control flow as the original tuner loop).
    """
    while machine.cores[0].state.value != _CORE_STOPPED:
        machine.run()
        if machine.cores[0].state.value != _CORE_PAUSED:
            break
        # Rebuild the state from its own class so the enum need not be imported.
        machine.cores[0].state = machine.cores[0].state.__class__(_CORE_RUNNING)


def _final_reference_mem(mem, value_trace):
    """Exhaust ``reference_kernel2`` and return the last memory image it yields.

    Raises:
        RuntimeError: if the generator yields nothing (instead of a NameError
            on an unbound loop variable).
    """
    ref_mem = None
    for ref_mem in reference_kernel2(mem, value_trace):
        pass
    if ref_mem is None:
        raise RuntimeError("reference_kernel2 produced no memory image")
    return ref_mem


def objective(
    active_threshold,
    mask_skip,
    *,
    scalar_offload=None,
    forest_height=10,
    rounds=16,
    batch_size=256,
    seed=None,
):
    """Build, simulate and verify one kernel configuration; return its cycle count.

    Args:
        active_threshold: forwarded to ``KernelBuilder.build_kernel``.
        mask_skip: forwarded to ``KernelBuilder.build_kernel``.
        scalar_offload: forwarded to ``build_kernel`` only when not None, so the
            builder's own default applies otherwise.
        forest_height, rounds, batch_size: problem size (defaults match the
            values this tuner always used).
        seed: if not None, ``random.seed(seed)`` is called before generating the
            forest and inputs, making the run reproducible.

    Returns:
        ``machine.cycle`` when the simulated output values match the reference
        kernel, otherwise ``FAILED_CYCLES``. Any exception is printed and also
        reported as ``FAILED_CYCLES`` (best-effort: one bad config must not
        abort a sweep).
    """
    try:
        if seed is not None:
            random.seed(seed)
        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)

        tuning = {"active_threshold": active_threshold, "mask_skip": mask_skip}
        if scalar_offload is not None:
            tuning["scalar_offload"] = scalar_offload
        kb = KernelBuilder()
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            **tuning,
        )

        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False
        _run_to_completion(machine)
        # NOTE(review): this runs after the simulation finishes, so it has no
        # effect on the measured cycles; kept to mirror the original tuner.
        machine.enable_pause = False

        ref_mem = _final_reference_mem(mem, value_trace)
        # Slot 6 of the memory image is read as the input-values pointer, as the
        # original tuner did (presumably laid out by build_mem_image — confirm).
        values_start = ref_mem[6]
        values_end = values_start + len(inp.values)
        if machine.mem[values_start:values_end] != ref_mem[values_start:values_end]:
            return FAILED_CYCLES
        return machine.cycle
    except Exception as e:
        print(f"Error: {e}")
        return FAILED_CYCLES


def pick_best(results):
    """Return the ``(config, cycles)`` pair with the fewest cycles, or None.

    Entries whose cycles equal ``FAILED_CYCLES`` are ignored, so a broken
    configuration can never be reported as best. Ties keep the earliest entry.
    """
    successes = [(config, cycles) for config, cycles in results if cycles != FAILED_CYCLES]
    return min(successes, key=lambda item: item[1], default=None)


def main():
    """Sweep the tuning grid and print the fastest correct configuration."""
    thresholds = [4]
    mask_skips = [True]
    scalar_offloads = [0, 2, 4, 6, 8, 10]

    results = []
    # product() iterates mask_skip outermost, scalar_offload innermost, matching
    # the original nested loops.
    for ms, th, so in itertools.product(mask_skips, thresholds, scalar_offloads):
        print(f"Testing active_threshold={th}, mask_skip={ms}, scalar_offload={so}...")
        cycles = objective(th, ms, scalar_offload=so, seed=TUNING_SEED)
        if cycles == FAILED_CYCLES:
            print(" -> FAILED")
        else:
            print(f" -> Cycles: {cycles}")
        results.append(((th, ms, so), cycles))

    best = pick_best(results)
    if best is None:
        print("No configuration produced a correct result.")
        return
    (th, ms, so), cycles = best
    print(f"Best Config: th={th}, mask={ms}, offload={so}")
    print(f"Best Cycles: {cycles}")


if __name__ == "__main__":
    main()