File size: 3,174 Bytes
f3ce0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

import os
import sys

# Make the local project and the vendored ray checkout importable.
# NOTE: these sys.path inserts must happen BEFORE the first `import ray`.
# The original code imported ray up top and then re-imported it after the
# path insert — but by then ray was already in sys.modules from the
# installed location, so the vendored checkout under ./ray/python was
# silently never used.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)  # parent dir: provides perf_takehome
ray_path = os.path.join(parent_dir, "ray", "python")
sys.path.insert(0, ray_path)  # vendored ray checkout

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch

def objective(config):
    """Ray Tune trainable: build and simulate the kernel with the tuned
    parameters and report the cycle count.

    Args:
        config: Tune trial config; reads ``config["active_threshold"]``
            (int) and ``config["mask_skip"]`` (bool), forwarded to
            ``KernelBuilder.build_kernel``.

    Returns:
        dict with ``"cycles"`` (simulated cycle count, or the sentinel
        999999 on failure/mismatch) and ``"correct"`` (bool).

    NOTE(review): relies on names not imported in this file (``Tree``,
    ``Input``, ``build_mem_image``, ``KernelBuilder``, ``Machine``,
    ``N_CORES``, ``reference_kernel2``) — presumably provided by
    ``perf_takehome`` via the sys.path setup above; confirm.
    """
    # Earlier design notes: do_kernel_test calls KernelBuilder().build_kernel(...)
    # directly, so rather than monkeypatching its defaults, this replicates
    # do_kernel_test's setup inline and passes the tuned kwargs explicitly.
    
    try:
        # Fixed workload shape for every trial; only the kernel-build
        # parameters vary across the sweep.
        forest_height = 10
        rounds = 16
        batch_size = 256
        
        # Setup similar to do_kernel_test: generate a forest, a batch of
        # inputs, and the initial memory image the machine runs against.
        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)
        
        kb = KernelBuilder()
        # Pass tuned parameters through to kernel construction.
        kb.build_kernel(
            forest.height, 
            len(forest.values), 
            len(inp.indices), 
            rounds,
            active_threshold=config["active_threshold"],
            mask_skip=config["mask_skip"]
        )
        
        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False
        
        # Run until core 0 stops. State values: presumably an enum where
        # 1 = running, 2 = paused, 3 = stopped — TODO confirm against the
        # Machine implementation. If run() returns with the core paused we
        # flip it back to running (state enum value 1) and loop; any other
        # non-stopped state breaks out.
        while machine.cores[0].state.value != 3: # STOPPED
            machine.run()
            if machine.cores[0].state.value == 2: # PAUSED
                machine.cores[0].state = machine.cores[0].state.__class__(1)
                continue
            break
            
        # NOTE(review): enable_pause is cleared only AFTER the run loop —
        # looks like it was intended to take effect before machine.run();
        # verify against do_kernel_test.
        machine.enable_pause = False
        # Drain the reference-kernel generator; ref_mem below relies on the
        # for-loop variable leaking its last yielded value (NameError if the
        # generator yields nothing).
        for ref_mem in reference_kernel2(mem, value_trace):
            pass
            
        # Validate: compare the output value region of simulated memory
        # against the reference. ref_mem[6] presumably holds the pointer to
        # the input-values region — confirm against build_mem_image's layout.
        inp_values_p = ref_mem[6]
        if machine.mem[inp_values_p : inp_values_p + len(inp.values)] != ref_mem[inp_values_p : inp_values_p + len(inp.values)]:
             return {"cycles": 999999, "correct": False}
             
        return {"cycles": machine.cycle, "correct": True}
        
    except Exception as e:
        # Broad catch is deliberate here: a crashing trial should report a
        # sentinel-bad score to Tune rather than kill the sweep.
        print(f"Error: {e}")
        return {"cycles": 999999, "correct": False}

if __name__ == "__main__":
    ray.init()

    # Sweep only the activity threshold for now; mask_skip stays pinned.
    # Grid over [True, False] as well once the skip-overhead logic is
    # confirmed correct.
    search_space = {
        "active_threshold": tune.grid_search([4, 8, 16]),
        "mask_skip": True,
    }

    analysis = tune.run(
        objective,
        config=search_space,
        metric="cycles",
        mode="min",
        num_samples=1,
    )

    # Report the winning configuration and its cycle count.
    best_config = analysis.get_best_config(metric="cycles", mode="min")
    print("Best config: ", best_config)
    print("Best cycles: ", analysis.best_result["cycles"])