"""Sweep kernel-builder tuning parameters and report the cheapest configuration.

Each configuration is built with ``KernelBuilder.build_kernel``, executed on the
simulated ``Machine``, and checked against ``reference_kernel2``.  The score of a
configuration is ``machine.cycle``; configurations whose output does not match
the reference score ``_FAILURE_CYCLES``.
"""
import os
import sys

# Add parent dir to path to import perf_takehome
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

from perf_takehome import (  # noqa: E402  (must follow the sys.path tweak above)
    KernelBuilder,
    do_kernel_test,
    Tree,
    Input,
    build_mem_image,
    N_CORES,
    Machine,
    reference_kernel2,
)

# Score assigned to a configuration that produced wrong output or crashed.
_FAILURE_CYCLES = 999999

# Fixed problem size used for every configuration in the sweep.
_FOREST_HEIGHT = 10
_ROUNDS = 16
_BATCH_SIZE = 256


def _simulate(active_threshold, mask_skip, scalar_offload=None):
    """Build and run one kernel configuration; return its cycle count.

    Args:
        active_threshold: Forwarded to ``build_kernel(active_threshold=...)``.
        mask_skip: Forwarded to ``build_kernel(mask_skip=...)``.
        scalar_offload: Forwarded to ``build_kernel(scalar_offload=...)`` when
            not ``None``; when ``None`` the keyword is omitted entirely.

    Returns:
        ``machine.cycle`` when the machine's output values equal the
        reference's, otherwise ``_FAILURE_CYCLES``.

    Raises:
        RuntimeError: If ``reference_kernel2`` yields no memory image.
        Exception: Anything raised by the builder or simulator propagates.
    """
    forest = Tree.generate(_FOREST_HEIGHT)
    inp = Input.generate(forest, _BATCH_SIZE, _ROUNDS)
    mem = build_mem_image(forest, inp)

    build_options = {"active_threshold": active_threshold, "mask_skip": mask_skip}
    if scalar_offload is not None:
        build_options["scalar_offload"] = scalar_offload

    kb = KernelBuilder()
    kb.build_kernel(
        forest.height,
        len(forest.values),
        len(inp.indices),
        _ROUNDS,
        **build_options,
    )

    value_trace = {}
    machine = Machine(
        mem,
        kb.instrs,
        kb.debug_info(),
        n_cores=N_CORES,
        value_trace=value_trace,
        trace=False,
    )
    machine.prints = False

    # Keep running until core 0 reports state 3 (STOPPED).  Whenever it comes
    # back in state 2 (PAUSED), reset it to state 1 and resume.
    # NOTE(review): state value 1 is presumably RUNNING - confirm in perf_takehome.
    while machine.cores[0].state.value != 3:  # STOPPED
        machine.run()
        core = machine.cores[0]
        if core.state.value == 2:  # PAUSED
            core.state = core.state.__class__(1)
            continue
        break
    # NOTE(review): set after the run loop, and the machine is not run again
    # below - kept to preserve the original behaviour; confirm whether it was
    # meant to be set before running.
    machine.enable_pause = False

    # Only the last memory image yielded by the reference is compared.
    ref_mem = None
    for ref_mem in reference_kernel2(mem, value_trace):
        pass
    if ref_mem is None:
        raise RuntimeError("reference_kernel2 produced no memory image")

    # NOTE(review): slot 6 presumably holds the pointer to the input values
    # (per the name inp_values_p) - confirm against build_mem_image.
    inp_values_p = ref_mem[6]
    values_end = inp_values_p + len(inp.values)
    if machine.mem[inp_values_p:values_end] != ref_mem[inp_values_p:values_end]:
        return _FAILURE_CYCLES
    return machine.cycle


def objective(active_threshold, mask_skip, scalar_offload=None):
    """Score one configuration; lower is better.

    Args:
        active_threshold: Forwarded to ``build_kernel(active_threshold=...)``.
        mask_skip: Forwarded to ``build_kernel(mask_skip=...)``.
        scalar_offload: Optional; forwarded to ``build_kernel`` only when given,
            so existing two-argument callers behave exactly as before.

    Returns:
        The simulated cycle count, or 999999 if the output mismatched the
        reference or any exception occurred (the exception is printed).
    """
    try:
        return _simulate(active_threshold, mask_skip, scalar_offload)
    except Exception as e:
        print(f"Error: {e}")
        return _FAILURE_CYCLES


def main():
    """Run the sweep over all parameter combinations and print the best one."""
    thresholds = [4]
    mask_skips = [True]
    scalar_offloads = [0, 2, 4, 6, 8, 10]

    best_cycles = float('inf')
    best_config = None

    for ms in mask_skips:
        for th in thresholds:
            for so in scalar_offloads:
                print(f"Testing active_threshold={th}, mask_skip={ms}, scalar_offload={so}...")
                # Call _simulate directly (not objective) so that a crashing
                # configuration is reported but never recorded as the best.
                try:
                    cycles = _simulate(th, ms, so)
                except Exception as e:
                    print(f"Error: {e}")
                    continue

                print(f" -> Cycles: {cycles}")
                if cycles < best_cycles:
                    best_cycles = cycles
                    best_config = (th, ms, so)

    if best_config is None:
        # Every configuration raised; there is nothing to index into.
        print("No configuration completed successfully.")
        return

    print(f"Best Config: th={best_config[0]}, mask={best_config[1]}, offload={best_config[2]}")
    print(f"Best Cycles: {best_cycles}")


if __name__ == "__main__":
    main()