# Source snapshot: commit f3ce0b0 (4,900 bytes).
import os
import sys
# `perf_takehome` lives one directory above this script, so prepend that
# directory to the module search path before importing from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path[0:0] = [parent_dir]  # same effect as sys.path.insert(0, parent_dir)
from perf_takehome import KernelBuilder, do_kernel_test, Tree, Input, build_mem_image, N_CORES, Machine, reference_kernel2
# Cycle count reported when a configuration fails to build/run or produces
# output that disagrees with the reference implementation.
FAILED_CYCLES = 999999

# Values of `machine.cores[0].state.value` this script reacts to. The
# PAUSED/STOPPED names follow the original inline comments; 1 is the value a
# paused core is reset to so it can resume (presumably RUNNING -- confirm
# against perf_takehome's core-state enum).
_CORE_RUNNING = 1
_CORE_PAUSED = 2
_CORE_STOPPED = 3


def _run_until_stopped(machine):
    """Run ``machine`` until core 0 reports the STOPPED state.

    If core 0 is PAUSED after ``run()`` returns, it is switched back to
    state value 1 and the machine is run again. Any other non-STOPPED state
    ends the loop, mirroring the original control flow.
    """
    while machine.cores[0].state.value != _CORE_STOPPED:
        machine.run()
        state = machine.cores[0].state
        if state.value != _CORE_PAUSED:
            break
        machine.cores[0].state = type(state)(_CORE_RUNNING)


def objective(
    active_threshold,
    mask_skip,
    scalar_offload=None,
    *,
    seed=None,
    forest_height=10,
    rounds=16,
    batch_size=256,
):
    """Build and simulate one kernel configuration and return its cycle count.

    Generates a forest and an input batch, builds the kernel with the given
    tuning knobs, runs it on ``Machine`` until core 0 stops, then checks the
    output values region against ``reference_kernel2``.

    Args:
        active_threshold: Forwarded to ``KernelBuilder.build_kernel``.
        mask_skip: Forwarded to ``KernelBuilder.build_kernel``.
        scalar_offload: Forwarded to ``build_kernel`` when not ``None``;
            otherwise the builder's own default applies (the original
            two-argument behavior).
        seed: When not ``None``, ``random.seed(seed)`` is called before the
            inputs are generated so several configurations can be compared
            on identical data. Assumes ``Tree.generate``/``Input.generate``
            draw from the ``random`` module -- confirm.
        forest_height: Height passed to ``Tree.generate``.
        rounds: Number of rounds for the input and the kernel.
        batch_size: Number of inputs passed to ``Input.generate``.

    Returns:
        ``machine.cycle`` when the simulated output matches the reference,
        otherwise ``FAILED_CYCLES``. Any exception is printed and mapped to
        ``FAILED_CYCLES`` so a parameter sweep keeps going.
    """
    try:
        if seed is not None:
            import random

            random.seed(seed)
        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)

        build_kwargs = {"active_threshold": active_threshold, "mask_skip": mask_skip}
        if scalar_offload is not None:
            build_kwargs["scalar_offload"] = scalar_offload
        kb = KernelBuilder()
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            **build_kwargs,
        )

        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False
        # Disable pausing *before* running; the original assigned this only
        # after the run loop, where it could not affect anything. The
        # resume loop below still handles a pause if one happens anyway.
        machine.enable_pause = False
        _run_until_stopped(machine)

        # Exhaust the reference generator; its last yielded value is the
        # memory image compared against.
        ref_mem = None
        for ref_mem in reference_kernel2(mem, value_trace):
            pass
        if ref_mem is None:
            print("Error: reference_kernel2 yielded no memory image")
            return FAILED_CYCLES

        # Slot 6 of the memory image is read as the start of the output
        # values (layout comes from build_mem_image -- confirm).
        inp_values_p = ref_mem[6]
        end = inp_values_p + len(inp.values)
        if machine.mem[inp_values_p:end] != ref_mem[inp_values_p:end]:
            return FAILED_CYCLES
        return machine.cycle
    except Exception as e:
        print(f"Error: {e}")
        return FAILED_CYCLES


if __name__ == "__main__":
    from itertools import product

    thresholds = [4]
    mask_skips = [True]
    scalar_offloads = [0, 2, 4, 6, 8, 10]
    # One fixed seed so every configuration is measured on the same input.
    sweep_seed = 123

    best_cycles = float("inf")
    best_config = None
    # product() keeps the original nesting order: mask_skip, then threshold,
    # then scalar_offload innermost.
    for ms, th, so in product(mask_skips, thresholds, scalar_offloads):
        print(f"Testing active_threshold={th}, mask_skip={ms}, scalar_offload={so}...")
        cycles = objective(th, ms, scalar_offload=so, seed=sweep_seed)
        print(f" -> Cycles: {cycles}")
        # Failed/incorrect runs must never be reported as the best config.
        if cycles != FAILED_CYCLES and cycles < best_cycles:
            best_cycles = cycles
            best_config = (th, ms, so)

    if best_config is None:
        print("No configuration produced a correct result.")
    else:
        print(f"Best Config: th={best_config[0]}, mask={best_config[1]}, offload={best_config[2]}")
        print(f"Best Cycles: {best_cycles}")