| |
|
| | import os
|
| | import sys
|
| | import ray
|
| | from ray import tune
|
| | from ray.tune.search.optuna import OptunaSearch
|
| |
|
| |
|
# Directory containing this script, and the directory one level above it.
current_dir = os.path.split(os.path.abspath(__file__))[0]
parent_dir = os.path.split(current_dir)[0]

# "<parent_dir>/ray/python" is also put on the import path.
# NOTE(review): `ray` is already imported above this point in the file, so
# this entry cannot change which `ray` package later imports resolve to --
# presumably it was meant to select a local Ray checkout; confirm the intent.
ray_path = os.path.join(parent_dir, "ray", "python")

# Prepend both entries at once. The resulting order (ray_path first, then
# parent_dir) is the same as two successive insert(0, ...) calls would give.
sys.path[:0] = [ray_path, parent_dir]
|
| |
|
| | import ray
|
| | from ray import tune
|
| |
|
def objective(config):
    """Evaluate one kernel configuration and return its cycle count.

    A forest (height 10) and an input batch (256 entries, 16 rounds) are
    generated, a kernel is built with the two tunables from ``config``, the
    kernel is run on a ``Machine``, and the machine's memory is compared
    against the last memory image yielded by ``reference_kernel2``.

    Args:
        config: Mapping that provides the keys ``"active_threshold"`` and
            ``"mask_skip"``; both are forwarded to
            ``KernelBuilder.build_kernel``.

    Returns:
        ``{"cycles": machine.cycle, "correct": True}`` when the compared
        memory range matches the reference, otherwise
        ``{"cycles": 999999, "correct": False}``. The latter is also returned
        (after printing the error) when any exception is raised.
    """
    # Sentinel result: a huge cycle count so a "min" search never prefers it.
    failure = {"cycles": 999999, "correct": False}

    try:
        # Fixed problem size used for every trial.
        forest_height, rounds, batch_size = 10, 16, 256

        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)

        kb = KernelBuilder()
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            active_threshold=config["active_threshold"],
            mask_skip=config["mask_skip"],
        )

        # The same dict is handed to the machine and to the reference kernel.
        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False

        # Keep running until core 0 reports state value 3. Whenever a run
        # leaves it at state value 2, reset it to state value 1 and run again;
        # any other state ends the loop.
        # NOTE(review): presumably 1/2/3 are running/paused/stopped -- confirm
        # against the state enum used by Machine.
        while machine.cores[0].state.value != 3:
            machine.run()
            state_now = machine.cores[0].state
            if state_now.value != 2:
                break
            machine.cores[0].state = state_now.__class__(1)

        # NOTE(review): assigned only after the run loop has finished, so it
        # looks like it cannot influence the run above -- confirm whether it
        # was meant to be set before running.
        machine.enable_pause = False

        # Exhaust the reference generator; only its final image is used.
        for ref_mem in reference_kernel2(mem, value_trace):
            pass

        # Slot 6 of the reference image is used as the start offset of the
        # range to compare; the range is as long as inp.values.
        inp_values_p = ref_mem[6]
        compared = slice(inp_values_p, inp_values_p + len(inp.values))
        if machine.mem[compared] == ref_mem[compared]:
            return {"cycles": machine.cycle, "correct": True}
        return failure

    except Exception as e:
        # Best effort: a failing trial is reported as the sentinel result
        # instead of aborting the whole sweep.
        print(f"Error: {e}")
        return failure
|
| |
|
if __name__ == "__main__":
    # Initialise Ray with its default settings before launching trials.
    ray.init()

    # One trial per active_threshold value; mask_skip is held at True.
    search_space = {
        "active_threshold": tune.grid_search([4, 8, 16]),
        "mask_skip": True,
    }

    # Minimise the "cycles" value returned by objective().
    analysis = tune.run(
        objective,
        config=search_space,
        metric="cycles",
        mode="min",
        num_samples=1,
    )

    best_config = analysis.get_best_config(metric="cycles", mode="min")
    print("Best config: ", best_config)

    best_cycles = analysis.best_result["cycles"]
    print("Best cycles: ", best_cycles)
|
| |
|