| | """
|
| | # Anthropic's Original Performance Engineering Take-home (Release version)
|
| |
|
| | Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
|
| | to publish or redistribute your solutions so it's hard to find spoilers.
|
| |
|
| | # Task
|
| |
|
| | - Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
|
| | available time, as measured by test_kernel_cycles on a frozen separate copy
|
| | of the simulator.
|
| |
|
| | Validate your results using `python tests/submission_tests.py` without modifying
|
| | anything in the tests/ folder.
|
| |
|
| | We recommend you look through problem.py next.
|
| | """
|
| |
|
| | from collections import defaultdict
|
| | import random
|
| | import unittest
|
| |
|

from problem import (
    Engine,
    DebugInfo,
    SLOT_LIMITS,
    VLEN,
    N_CORES,
    SCRATCH_SIZE,
    Machine,
    Tree,
    Input,
    HASH_STAGES,
    reference_kernel,
    build_mem_image,
    reference_kernel2,
)

from scheduler import Scheduler


class KernelBuilder:
    def __init__(self):
        self.scheduler = Scheduler()
        self.scratch = {}
        self.scratch_debug = {}
        self.scratch_ptr = 0
        self.const_map = {}

    def debug_info(self):
        return DebugInfo(scratch_map=self.scratch_debug)

    def finalize(self):
        return self.scheduler.schedule()

    def add_instr(self, instr_dict):
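        # Queue every slot of an {engine: [slot_args, ...]} bundle on the
        # scheduler, one op at a time.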
        for engine, slots in instr_dict.items():
            for args in slots:
                self.scheduler.add_op(engine, args)

    def alloc_scratch(self, name=None, length=1):
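        # Simple bump allocator over scratch memory; named allocations are
        # recorded so they show up in the debug scratch map.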
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
        return addr

    def scratch_const(self, val, name=None):
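        # Materialize each scalar constant once and memoize its address.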
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            self.scheduler.add_op("load", ("const", addr, val))
            self.const_map[val] = addr
        return self.const_map[val]

    def scratch_vec_const(self, val, name=None):
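        # Vector constants are built by broadcasting the memoized scalar
        # constant into a VLEN-wide scratch slot.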
        key = (val, "vec")
        if key not in self.const_map:
            addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
            scalar_addr = self.scratch_const(val)
            self.scheduler.add_op("valu", ("vbroadcast", addr, scalar_addr))
            self.const_map[key] = addr
        return self.const_map[key]

    def add_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
        """
        Adds slots for the strength-reduced hash function to the scheduler.
        """
        c1 = self.scratch_vec_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_vec_const(1 + (1 << 12), "h0_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m1, c1))

        c2 = self.scratch_vec_const(0xC761C23C, "h1_c")
        s2 = self.scratch_vec_const(19, "h1_s")
        self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c2))
        self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s2))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))

        c3 = self.scratch_vec_const(0x165667B1, "h2_c")
        m3 = self.scratch_vec_const(1 + (1 << 5), "h2_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m3, c3))

        c4 = self.scratch_vec_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_vec_const(9, "h3_s")
        self.scheduler.add_op("valu", ("+", tmp1_vec, val_vec, c4))
        self.scheduler.add_op("valu", ("<<", tmp2_vec, val_vec, s4))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))

        c5 = self.scratch_vec_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_vec_const(1 + (1 << 3), "h4_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m5, c5))

        c6 = self.scratch_vec_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_vec_const(16, "h5_s")
        self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c6))
        self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s6))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))

    def add_hash_opt_scalar(self, val_vec, tmp1_vec, tmp2_vec):
        """
        Scalarized version of the hash optimization.
        Unrolls the loop over the 8 lanes and uses the ALU engine.
        """

        def add_alu_lanes(op, dest_vec, src1_vec, src2_vec, s2_is_const=False):
            for lane in range(VLEN):
                s2_addr = src2_vec if s2_is_const else src2_vec + lane
                self.scheduler.add_op("alu", (op, dest_vec + lane, src1_vec + lane, s2_addr))

        def add_mad_lanes(dest_vec, a_vec, b_vec, c_vec, b_is_const=False, c_is_const=False):
            for lane in range(VLEN):
                b_addr = b_vec if b_is_const else b_vec + lane
                c_addr = c_vec if c_is_const else c_vec + lane
                self.scheduler.add_op("alu", ("*", dest_vec + lane, a_vec + lane, b_addr))
                self.scheduler.add_op("alu", ("+", dest_vec + lane, dest_vec + lane, c_addr))

        c1 = self.scratch_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_const(1 + (1 << 12), "h0_m")
        add_mad_lanes(val_vec, val_vec, m1, c1, True, True)

        c2 = self.scratch_const(0xC761C23C, "h1_c")
        s2 = self.scratch_const(19, "h1_s")
        add_alu_lanes("^", tmp1_vec, val_vec, c2, True)
        add_alu_lanes(">>", tmp2_vec, val_vec, s2, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)

        c3 = self.scratch_const(0x165667B1, "h2_c")
        m3 = self.scratch_const(1 + (1 << 5), "h2_m")
        add_mad_lanes(val_vec, val_vec, m3, c3, True, True)

        c4 = self.scratch_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_const(9, "h3_s")
        add_alu_lanes("+", tmp1_vec, val_vec, c4, True)
        add_alu_lanes("<<", tmp2_vec, val_vec, s4, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)

        c5 = self.scratch_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_const(1 + (1 << 3), "h4_m")
        add_mad_lanes(val_vec, val_vec, m5, c5, True, True)

        c6 = self.scratch_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_const(16, "h5_s")
        add_alu_lanes("^", tmp1_vec, val_vec, c6, True)
        add_alu_lanes(">>", tmp2_vec, val_vec, s6, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)

    def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int,
        active_threshold=4, mask_skip=True, scalar_offload=2
    ):
        """
        Vectorized wavefront implementation.
        """
        result_scalar_offload = scalar_offload

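        # Read the header words of the memory image: word i holds the value
        # named by init_vars[i].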
        init_vars = [
            "rounds", "n_nodes", "batch_size", "forest_height",
            "forest_values_p", "inp_indices_p", "inp_values_p"
        ]
        ptr_map = {}
        tmp_load = self.alloc_scratch("tmp_load")

        for i, v in enumerate(init_vars):
            addr = self.alloc_scratch(v)
            ptr_map[v] = addr
            self.add_instr({"load": [("const", tmp_load, i)]})
            self.add_instr({"load": [("load", addr, tmp_load)]})

        indices_base = self.alloc_scratch("indices_cache", batch_size)
        values_base = self.alloc_scratch("values_cache", batch_size)

        block_x = self.alloc_scratch("block_x", batch_size)
        block_y = self.alloc_scratch("block_y", batch_size)

        num_vecs = batch_size // VLEN

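        # tmp_addrs/node_vals/vtmp1 all alias block_x (vtmp2 aliases block_y):
        # their uses are ordered so each value is consumed before its slot is
        # overwritten, which halves the scratch needed for temporaries.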
        tmp_addrs_base = block_x
        node_vals_base = block_x
        vtmp1_base = block_x
        vtmp2_base = block_y

        const_0_vec = self.scratch_vec_const(0)
        const_1_vec = self.scratch_vec_const(1)
        global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
        self.add_instr({"valu": [("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"])]})

        active_temp_base = self.alloc_scratch("active_temp", 200)

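        # Prefetch the whole batch of indices and values into scratch; they
        # stay resident there until the final write-back.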
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            self.scheduler.add_op("load", ("vload", indices_base + i, tmp_load))
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            self.scheduler.add_op("load", ("vload", values_base + i, tmp_load))

        self.scheduler.add_op("flow", ("pause",))

        for r in range(rounds):
            vecs = []
            for vec_i in range(num_vecs):
                offset = vec_i * VLEN
                vecs.append({
                    'idx': indices_base + offset,
                    'val': values_base + offset,
                    'node': node_vals_base + offset,
                    'tmp1': vtmp1_base + offset,
                    'tmp2': vtmp2_base + offset,
                    'addr': tmp_addrs_base + offset
                })

            if r == 0:
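                # Round 0: every lane is still at the root, so one scalar load
                # plus broadcasts replaces a full gather.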
                scalar_node = self.alloc_scratch("scalar_node_r0")
                self.scheduler.add_op("load", ("load", scalar_node, ptr_map["forest_values_p"]))
                for vec in vecs:
                    self.scheduler.add_op("valu", ("vbroadcast", vec['node'], scalar_node))
                active_indices = [0]
            elif len(active_indices) * 2 <= 8:
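                # Narrow wavefront: at most 8 distinct nodes are reachable in
                # this round, so load and broadcast each candidate once, then
                # pick per lane with a vselect tree instead of gathering.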
                active_dev_ptr = active_temp_base

                def alloc_temp(length=1):
                    nonlocal active_dev_ptr
                    addr = active_dev_ptr
                    active_dev_ptr += length
                    # Bound matches the 200-word 'active_temp' allocation above.
                    assert active_dev_ptr <= active_temp_base + 200
                    return addr

                new_actives = []
                for x in active_indices:
                    new_actives.append(2 * x + 1)
                    new_actives.append(2 * x + 2)
                active_indices = new_actives

                node_map = {}
                for uidx in active_indices:
                    s_node = alloc_temp(1)
                    s_addr = alloc_temp(1)
                    idx_c = self.scratch_const(uidx)
                    self.scheduler.add_op("alu", ("+", s_addr, ptr_map["forest_values_p"], idx_c))
                    self.scheduler.add_op("load", ("load", s_node, s_addr))
                    v_node = alloc_temp(VLEN)
                    self.scheduler.add_op("valu", ("vbroadcast", v_node, s_node))
                    node_map[uidx] = v_node

                tree_temp_start = active_dev_ptr

                for vec in vecs:
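                    # Temporaries past tree_temp_start are per-vector, so
                    # rewind the bump pointer for each vector.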
                    active_dev_ptr = tree_temp_start

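                    # Binary select tree over the sorted active indices:
                    # compare idx against the split point of each half, then
                    # vselect the matching node values upward.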
                    def build_tree(indices):
                        if len(indices) == 1:
                            return node_map[indices[0]]

                        mid = len(indices) // 2
                        left = indices[:mid]
                        right = indices[mid:]
                        split_val = right[0]

                        split_c = self.scratch_vec_const(split_val)
                        cond = alloc_temp(VLEN)
                        self.scheduler.add_op("valu", ("<", cond, vec['idx'], split_c))

                        l_res = build_tree(left)
                        r_res = build_tree(right)

                        res = alloc_temp(VLEN)
                        self.scheduler.add_op("flow", ("vselect", res, cond, l_res, r_res))
                        return res

                    final_res = build_tree(active_indices)

                    # "x | x" is a move: copy the selected values into vec['node'].
                    self.scheduler.add_op("valu", ("|", vec['node'], final_res, final_res))

            else:
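                # Wide wavefront: fall back to a per-lane gather, computing
                # addresses on the scalar ALU and loading each lane's node
                # value individually.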
                for vec in vecs:
                    for lane in range(VLEN):
                        self.scheduler.add_op("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane))

                for vec in vecs:
                    for lane in range(VLEN):
                        self.scheduler.add_op("load", ("load", vec['node'] + lane, vec['addr'] + lane))

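            # The largest index reachable after round r is 2**(r+2) - 2, so
            # while 2**(r+2) < n_nodes the wrap-around check can be skipped.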
            do_wrap = True
            if mask_skip and (1 << (r + 2)) < n_nodes:
                do_wrap = False

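            # Once wrapping is unnecessary and r >= active_threshold, peel the
            # first result_scalar_offload vectors off to the scalar ALU so it
            # runs in parallel with the vector engine.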
            use_offload = (r >= active_threshold) and (not do_wrap)
            scalar_vectors = vecs[:result_scalar_offload] if use_offload else []
            vector_vectors = vecs[result_scalar_offload:] if use_offload else vecs

            for vec in vector_vectors:
                self.scheduler.add_op("valu", ("^", vec['val'], vec['val'], vec['node']))
            for vec in vector_vectors:
                self.add_hash_opt(vec['val'], vec['tmp1'], vec['tmp2'])

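            # Descend one level: idx = 2*idx + 1 + (val & 1).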
            for vec in vector_vectors:
                self.scheduler.add_op("valu", ("&", vec['tmp1'], vec['val'], const_1_vec))
                self.scheduler.add_op("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec))
                self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['idx']))
                self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['tmp1']))

            if do_wrap:
                for vec in vector_vectors:
                    self.scheduler.add_op("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec))
                for vec in vector_vectors:
                    self.scheduler.add_op("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec))

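            # Scalar-ALU copy of the same hash/descend/wrap sequence for the
            # offloaded vectors.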
            def alu_lanes(op, dest, s1, s2, s2_c=False):
                for l in range(VLEN):
                    s2_addr = s2 if s2_c else s2 + l
                    self.scheduler.add_op("alu", (op, dest + l, s1 + l, s2_addr))

            for vec in scalar_vectors:
                alu_lanes("^", vec['val'], vec['val'], vec['node'], False)
            for vec in scalar_vectors:
                self.add_hash_opt_scalar(vec['val'], vec['tmp1'], vec['tmp2'])

            const_1 = self.scratch_const(1)
            for vec in scalar_vectors:
                alu_lanes("&", vec['tmp1'], vec['val'], const_1, True)
                alu_lanes("+", vec['tmp1'], vec['tmp1'], const_1, True)
                alu_lanes("+", vec['idx'], vec['idx'], vec['idx'], False)
                alu_lanes("+", vec['idx'], vec['idx'], vec['tmp1'], False)

            if do_wrap:
                const_0 = self.scratch_const(0)
                n_nodes_c = ptr_map["n_nodes"]

                for vec in scalar_vectors:
                    alu_lanes("<", vec['tmp1'], vec['idx'], n_nodes_c, True)

                for vec in scalar_vectors:
                    for l in range(VLEN):
                        self.scheduler.add_op("flow", ("select", vec['idx'] + l, vec['tmp1'] + l, vec['idx'] + l, const_0))

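        # Write the cached indices and values back to memory.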
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            self.scheduler.add_op("store", ("vstore", tmp_load, indices_base + i))
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            self.scheduler.add_op("store", ("vstore", tmp_load, values_base + i))

        self.scheduler.add_op("flow", ("pause",))

        self.instrs = self.scheduler.schedule()

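
# Cycle count of the baseline kernel, kept for reference.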
BASELINE = 147734


def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = False,
):
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)
    forest = Tree.generate(forest_height)
    inp = Input.generate(forest, batch_size, rounds)
    mem = build_mem_image(forest, inp)

    kb = KernelBuilder()
    kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)

    value_trace = {}
    machine = Machine(
        mem,
        kb.instrs,
        kb.debug_info(),
        n_cores=N_CORES,
        value_trace=value_trace,
        trace=trace,
    )
    machine.prints = prints

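    # Core state values come from the simulator's state enum; here 2 appears
    # to mean paused, 3 finished, and 1 running. Resume paused cores until
    # the kernel signals completion.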
    while machine.cores[0].state.value != 3:
        machine.run()
        if machine.cores[0].state.value == 2:
            machine.cores[0].state = machine.cores[0].state.__class__(1)
            continue
        break

    machine.enable_pause = False

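    # Drive the reference kernel generator to completion; ref_mem ends up
    # holding its final memory image.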
    for ref_mem in reference_kernel2(mem, value_trace):
        pass

    inp_indices_p = ref_mem[5]
    if prints:
        print("INDICES (Machine):", machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
        print("INDICES (Ref):    ", ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])

    inp_values_p = ref_mem[6]
    if prints:
        print("VALUES (Machine):", machine.mem[inp_values_p : inp_values_p + len(inp.values)])
        print("VALUES (Ref):    ", ref_mem[inp_values_p : inp_values_p + len(inp.values)])

    print("CYCLES: ", machine.cycle)
    if hasattr(machine.cores[0], 'trace_buf'):
        print("TRACE BUF:", machine.cores[0].trace_buf[:64])

    assert (
        machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ), "Incorrect result on final round"

    return machine.cycle


class Tests(unittest.TestCase):
    def test_ref_kernels(self):
        """
        Test the reference kernels against each other.
        """
        random.seed(123)
        for i in range(10):
            f = Tree.generate(4)
            inp = Input.generate(f, 10, 6)
            mem = build_mem_image(f, inp)
            reference_kernel(f, inp)
            for _ in reference_kernel2(mem, {}):
                pass
            assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)]
            assert inp.values == mem[mem[6] : mem[6] + len(inp.values)]

    def test_kernel_trace(self):
        do_kernel_test(10, 16, 256, trace=True, prints=False)

    def test_kernel_cycles(self):
        do_kernel_test(10, 16, 256, prints=False)


if __name__ == "__main__":
    unittest.main()
|