File size: 4,900 Bytes
f3ce0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

import os
import sys

# perf_takehome lives one directory above this script. Prepend that directory
# to the module search path so the import below resolves no matter which
# working directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.split(current_dir)[0]
sys.path.insert(0, parent_dir)

from perf_takehome import KernelBuilder, do_kernel_test, Tree, Input, build_mem_image, N_CORES, Machine, reference_kernel2

# Cycle count reported by ``objective`` for a configuration that failed to
# build or run, or whose output disagrees with the reference kernel. It is
# large so that any working configuration beats it in a minimisation sweep.
FAILED_CYCLES = 999999

# Numeric values of the core-state enum, as used by the run loop below
# (1 = RUNNING, 2 = PAUSED, 3 = STOPPED).
_STATE_RUNNING = 1
_STATE_PAUSED = 2
_STATE_STOPPED = 3


def _run_to_completion(machine):
    """Run *machine* until core 0 leaves the PAUSED state (normally STOPPED).

    Only core 0's state is inspected. NOTE(review): this presumably assumes
    N_CORES == 1 or that all cores finish together; confirm.
    """
    # Setting this after the run (as before) had no effect. It must be set
    # before running to disable pausing.
    machine.enable_pause = False
    while machine.cores[0].state.value != _STATE_STOPPED:
        machine.run()
        state = machine.cores[0].state
        if state.value == _STATE_PAUSED:
            # Safety net in case a pause still happens: resume the core by
            # rebuilding the state from the same enum class, then keep going.
            machine.cores[0].state = type(state)(_STATE_RUNNING)
            continue
        break


def _evaluate(active_threshold, mask_skip, scalar_offload=None, seed=None):
    """Simulate one kernel configuration.

    Returns the machine's cycle count, or None on failure. Failure means any
    of these:
    - building or simulating raised (the exception is printed);
    - the reference kernel yielded no memory image;
    - the kernel's final input-values region differs from the reference.
    """
    forest_height = 10
    rounds = 16
    batch_size = 256
    # Any exception is reported and scored as a failed configuration, so a
    # single bad configuration never aborts a whole sweep.
    try:
        if seed is not None:
            # NOTE(review): assumes Tree.generate / Input.generate draw from
            # the stdlib ``random`` module. Confirm; otherwise this is a no-op.
            import random
            random.seed(seed)

        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)

        build_kwargs = {"active_threshold": active_threshold, "mask_skip": mask_skip}
        if scalar_offload is not None:
            # Forward only when given, so the builder's own default applies
            # otherwise (matching the previous behaviour of ``objective``).
            build_kwargs["scalar_offload"] = scalar_offload

        kb = KernelBuilder()
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            **build_kwargs,
        )

        # The same dict is handed to both the Machine and reference_kernel2.
        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False
        _run_to_completion(machine)

        # Drain the reference generator. Only its final memory image is
        # compared.
        ref_mem = None
        for ref_mem in reference_kernel2(mem, value_trace):
            pass
        if ref_mem is None:
            print("Error: reference kernel yielded no memory image")
            return None

        # ref_mem[6] is used as the base address of the input-values region.
        inp_values_p = ref_mem[6]
        region = slice(inp_values_p, inp_values_p + len(inp.values))
        if machine.mem[region] != ref_mem[region]:
            return None
        return machine.cycle
    except Exception as e:
        print(f"Error: {e}")
        return None


def objective(active_threshold, mask_skip, scalar_offload=None, seed=None):
    """Score one configuration by its simulated cycle count (lower is better).

    Args:
        active_threshold: Forwarded to ``KernelBuilder.build_kernel``.
        mask_skip: Forwarded to ``KernelBuilder.build_kernel``.
        scalar_offload: Forwarded to ``build_kernel`` when not None. When
            None the keyword is omitted and the builder's default is used.
        seed: When not None, the stdlib RNG is seeded before data
            generation, so different configurations can be compared on
            identical inputs.

    Returns:
        The cycle count, or ``FAILED_CYCLES`` (999999) if the configuration
        raised or produced output that differs from the reference.
    """
    cycles = _evaluate(active_threshold, mask_skip, scalar_offload=scalar_offload, seed=seed)
    return FAILED_CYCLES if cycles is None else cycles


if __name__ == "__main__":
    import itertools

    thresholds = [4]
    mask_skips = [True]
    scalar_offloads = [0, 2, 4, 6, 8, 10]
    # Every configuration runs on the same generated forest and input, so the
    # cycle counts are directly comparable.
    sweep_seed = 123

    best_cycles = float("inf")
    best_config = None

    for ms, th, so in itertools.product(mask_skips, thresholds, scalar_offloads):
        print(f"Testing active_threshold={th}, mask_skip={ms}, scalar_offload={so}...")
        cycles = _evaluate(th, ms, scalar_offload=so, seed=sweep_seed)
        if cycles is None:
            # A failing configuration must never be reported as the best one.
            print("  -> FAILED")
            continue
        print(f"  -> Cycles: {cycles}")
        if cycles < best_cycles:
            best_cycles = cycles
            best_config = (th, ms, so)

    if best_config is None:
        print("No configuration produced a correct result.")
    else:
        print(f"Best Config: th={best_config[0]}, mask={best_config[1]}, offload={best_config[2]}")
        print(f"Best Cycles: {best_cycles}")