"""Ray Tune sweep over KernelBuilder parameters.

Runs the kernel under a grid of ``active_threshold`` values and reports the
cycle count of each configuration so the cheapest *correct* kernel can be
found by a min-cycles search.
"""

import os
import sys

# Make the project root (parent of this file's directory) importable so the
# perf_takehome modules can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

# Prefer the vendored Ray checkout in <repo>/ray/python over any installed
# copy (inserted after parent_dir, so it ends up first on sys.path).
ray_path = os.path.join(parent_dir, "ray", "python")
sys.path.insert(0, ray_path)

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch  # noqa: F401  (kept for optional search-alg use)

# NOTE(review): the original script referenced these names without ever
# importing them, which raised NameError inside objective().  They appear to
# live in the perf_takehome module mentioned in the path-setup comment above
# -- TODO: confirm the module and attribute names against the project.
from perf_takehome import (
    Input,
    KernelBuilder,
    Machine,
    N_CORES,
    Tree,
    build_mem_image,
    reference_kernel2,
)

# Sentinel result reported when a trial crashes or produces wrong output, so
# a "min cycles" search never selects a broken configuration.
_FAILURE = {"cycles": 999999, "correct": False}


def objective(config):
    """Ray Tune trainable: build the kernel with the tuned parameters, run
    it, validate against the reference implementation, and report cycles.

    Args:
        config: dict supplied by Ray Tune with keys ``active_threshold``
            (int) and ``mask_skip`` (bool), forwarded to
            ``KernelBuilder.build_kernel``.

    Returns:
        dict with ``cycles`` (int; large sentinel on failure) and
        ``correct`` (bool).
    """
    try:
        forest_height = 10
        rounds = 16
        batch_size = 256

        # Setup mirrors do_kernel_test, but passes the tuned parameters
        # through to build_kernel instead of using its defaults.
        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)

        kb = KernelBuilder()
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            active_threshold=config["active_threshold"],
            mask_skip=config["mask_skip"],
        )

        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False

        # Run to completion, resuming whenever the machine pauses.
        # State values used here: 1 = RUNNING, 2 = PAUSED, 3 = STOPPED
        # (per the original code's inline comments -- confirm against the
        # Machine state enum).
        while machine.cores[0].state.value != 3:
            machine.run()
            if machine.cores[0].state.value == 2:
                # Flip PAUSED back to RUNNING and keep going.
                machine.cores[0].state = machine.cores[0].state.__class__(1)
                continue
            break
        machine.enable_pause = False

        # Drive the reference implementation to completion; keep only its
        # final memory image.  (Original used the bare loop variable after
        # the loop, which raises NameError if the generator yields nothing.)
        ref_mem = None
        for ref_mem in reference_kernel2(mem, value_trace):
            pass
        if ref_mem is None:
            # Reference produced no states -- nothing to validate against.
            return _FAILURE

        # Validate the kernel's output region against the reference image.
        inp_values_p = ref_mem[6]
        end = inp_values_p + len(inp.values)
        if machine.mem[inp_values_p:end] != ref_mem[inp_values_p:end]:
            return _FAILURE

        return {"cycles": machine.cycle, "correct": True}
    except Exception as e:
        # A failing configuration should score badly, not abort the sweep.
        print(f"Error: {e}")
        return _FAILURE


if __name__ == "__main__":
    ray.init()
    analysis = tune.run(
        objective,
        config={
            "active_threshold": tune.grid_search([4, 8, 16]),
            # "mask_skip": tune.grid_search([True, False]),
            # mask_skip=True looked best previously; re-enable the grid
            # search above if the overhead logic changes.
            "mask_skip": True,
        },
        mode="min",
        metric="cycles",
        num_samples=1,
    )
    print("Best config: ", analysis.get_best_config(metric="cycles", mode="min"))
    print("Best cycles: ", analysis.best_result["cycles"])