"""
# Anthropic's Original Performance Engineering Take-home (Release version)

Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
to publish or redistribute your solutions so it's hard to find spoilers.

# Task

- Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
  available time, as measured by test_kernel_cycles on a frozen separate copy
  of the simulator.

Validate your results using `python tests/submission_tests.py` without modifying
anything in the tests/ folder.

We recommend you look through problem.py next.
"""

from collections import defaultdict
import random
import unittest

from problem import (
    Engine,
    DebugInfo,
    SLOT_LIMITS,
    VLEN,
    N_CORES,
    SCRATCH_SIZE,
    Machine,
    Tree,
    Input,
    HASH_STAGES,
    reference_kernel,
    build_mem_image,
    reference_kernel2,
)
from scheduler import Scheduler


class KernelBuilder:
    """Emits the tree-walk kernel as a stream of ops handed to a ``Scheduler``.

    Scratch addresses are handed out by a simple bump allocator
    (``alloc_scratch``); scalar and broadcast constants are memoised in
    ``const_map`` so each distinct value is materialised only once.
    """

    # Strength-reduced hash, one entry per stage, shared by the vector and the
    # scalar emitters:
    #   ("mad", c, m)            -> val = val * m + c
    #   ("mix", c, s, op1, op2)  -> val = (val op1 c) ^ (val op2 s)
    _OPT_HASH_STAGES = (
        ("mad", 0x7ED55D16, 1 + (1 << 12)),
        ("mix", 0xC761C23C, 19, "^", ">>"),
        ("mad", 0x165667B1, 1 + (1 << 5)),
        ("mix", 0xD3A2646C, 9, "+", "<<"),
        ("mad", 0xFD7046C5, 1 + (1 << 3)),
        ("mix", 0xB55A4F09, 16, "^", ">>"),
    )

    def __init__(self):
        self.scheduler = Scheduler()
        self.scratch = {}  # name -> scratch address
        self.scratch_debug = {}  # scratch address -> (name, length)
        self.scratch_ptr = 0  # next free scratch word
        # Scalar constants are keyed by value, broadcast ones by (value, "vec").
        self.const_map = {}
        # Filled in by build_kernel(); defined here so the attribute always exists.
        self.instrs = []

    def debug_info(self):
        """Return a ``DebugInfo`` describing the named scratch allocations."""
        return DebugInfo(scratch_map=self.scratch_debug)

    def finalize(self):
        """Run the scheduler over everything queued so far and return its result."""
        return self.scheduler.schedule()

    def add_instr(self, instr_dict):
        """Queue every slot of an ``{engine: [slot, ...]}`` dict with the scheduler.

        Kept as a fallback for manually written instruction bundles; new code
        calls ``self.scheduler.add_op`` directly.
        """
        for engine, slots in instr_dict.items():
            for args in slots:
                self.scheduler.add_op(engine, args)

    def alloc_scratch(self, name=None, length=1):
        """Reserve ``length`` consecutive scratch words and return the base address.

        Only named allocations are recorded in ``scratch`` / ``scratch_debug``.
        Fails an assertion when the allocation would exceed ``SCRATCH_SIZE``.
        """
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
        return addr

    def scratch_const(self, val, name=None):
        """Return the scratch address of scalar constant ``val``.

        The first request for a value allocates one word and queues a
        ``const`` op on the load engine; later requests reuse that address
        (and ignore ``name``).
        """
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            self.scheduler.add_op("load", ("const", addr, val))
            self.const_map[val] = addr
        return self.const_map[val]

    def scratch_vec_const(self, val, name=None):
        """Return the scratch address of a ``VLEN``-wide broadcast of ``val``.

        The vector is created on first use by broadcasting the (memoised)
        scalar constant; the vector slot is allocated before the scalar one.
        """
        key = (val, "vec")
        if key not in self.const_map:
            addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
            scalar_addr = self.scratch_const(val)
            self.scheduler.add_op("valu", ("vbroadcast", addr, scalar_addr))
            self.const_map[key] = addr
        return self.const_map[key]

    def _alu_lanes(self, op, dest_vec, src1_vec, src2, src2_is_scalar=False):
        """Queue ``op`` on the ALU engine once per lane of a ``VLEN``-wide vector.

        ``src2`` is a single scalar address shared by all lanes when
        ``src2_is_scalar`` is true, otherwise the base of a second vector.
        """
        for lane in range(VLEN):
            src2_addr = src2 if src2_is_scalar else src2 + lane
            self.scheduler.add_op(
                "alu", (op, dest_vec + lane, src1_vec + lane, src2_addr)
            )

    def add_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
        """Queue the strength-reduced hash of ``val_vec`` on the VALU engine.

        ``val_vec`` is updated in place; ``tmp1_vec`` and ``tmp2_vec`` are
        clobbered. Constants are created lazily, stage by stage, right before
        the ops that consume them.
        """
        add_op = self.scheduler.add_op
        for stage, (kind, c_val, k_val, *ops) in enumerate(self._OPT_HASH_STAGES):
            c = self.scratch_vec_const(c_val, f"h{stage}_c")
            if kind == "mad":
                m = self.scratch_vec_const(k_val, f"h{stage}_m")
                add_op("valu", ("multiply_add", val_vec, val_vec, m, c))
            else:
                s = self.scratch_vec_const(k_val, f"h{stage}_s")
                op1, op2 = ops
                add_op("valu", (op1, tmp1_vec, val_vec, c))
                add_op("valu", (op2, tmp2_vec, val_vec, s))
                add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))

    def add_hash_opt_scalar(self, val_vec, tmp1_vec, tmp2_vec):
        """Scalarised twin of ``add_hash_opt``: one ALU op per lane.

        A multiply-add stage becomes ``*`` followed by ``+`` per lane, writing
        the product straight into the destination lane (safe because the
        destination is also the multiplicand and the multiplier is a constant).
        """
        add_op = self.scheduler.add_op
        for stage, (kind, c_val, k_val, *ops) in enumerate(self._OPT_HASH_STAGES):
            c = self.scratch_const(c_val, f"h{stage}_c")
            if kind == "mad":
                m = self.scratch_const(k_val, f"h{stage}_m")
                for lane in range(VLEN):
                    add_op("alu", ("*", val_vec + lane, val_vec + lane, m))
                    add_op("alu", ("+", val_vec + lane, val_vec + lane, c))
            else:
                s = self.scratch_const(k_val, f"h{stage}_s")
                op1, op2 = ops
                self._alu_lanes(op1, tmp1_vec, val_vec, c, True)
                self._alu_lanes(op2, tmp2_vec, val_vec, s, True)
                self._alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)

    def build_kernel(
        self,
        forest_height: int,
        n_nodes: int,
        batch_size: int,
        rounds: int,
        active_threshold=4,
        mask_skip=True,
        scalar_offload=2,
    ):
        """Emit the fully unrolled, vectorised wavefront kernel.

        The whole batch is cached in scratch, ``rounds`` rounds are unrolled
        over it, and the caches are stored back at the end. The scheduled
        program is left in ``self.instrs``.

        Args:
            forest_height: Tree height (not used by the emitter itself).
            n_nodes: Number of tree nodes; indices ``>= n_nodes`` wrap to 0.
            batch_size: Number of inputs; must be a multiple of ``VLEN``.
            rounds: Number of rounds to unroll.
            active_threshold: First round in which vectors may be offloaded
                to the scalar ALU engine.
            mask_skip: Skip the wrap step in rounds where
                ``(1 << (r + 2)) < n_nodes``.
            scalar_offload: How many leading vectors are routed through the
                scalar ALU path when offloading is active.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of ``VLEN``.
        """
        if batch_size % VLEN != 0:
            # vload/vstore move VLEN words at a time; a ragged tail would
            # overrun the caches and never be processed.
            raise ValueError(
                f"batch_size ({batch_size}) must be a multiple of VLEN ({VLEN})"
            )

        add_op = self.scheduler.add_op

        # Size of the temp region used by the small-active-set rounds. With at
        # most MAX_ACTIVE_NODES candidates it needs (2 + VLEN) words per
        # candidate plus 2 * VLEN per internal select node.
        ACTIVE_TEMP_WORDS = 200
        MAX_ACTIVE_NODES = 8

        # --- Memory pointers: word i of the memory image holds init_vars[i] ---
        init_vars = [
            "rounds",
            "n_nodes",
            "batch_size",
            "forest_height",
            "forest_values_p",
            "inp_indices_p",
            "inp_values_p",
        ]
        ptr_map = {}
        tmp_load = self.alloc_scratch("tmp_load")
        for i, v in enumerate(init_vars):
            addr = self.alloc_scratch(v)
            ptr_map[v] = addr
            add_op("load", ("const", tmp_load, i))
            add_op("load", ("load", addr, tmp_load))

        indices_base = self.alloc_scratch("indices_cache", batch_size)
        values_base = self.alloc_scratch("values_cache", batch_size)

        # Two temp blocks are enough:
        #   block X: gather addresses -> node values -> vtmp1
        #   block Y: vtmp2
        block_x = self.alloc_scratch("block_x", batch_size)
        block_y = self.alloc_scratch("block_y", batch_size)

        num_vecs = batch_size // VLEN

        tmp_addrs_base = block_x
        node_vals_base = block_x  # alias is safe: load dest == its address source
        vtmp1_base = block_x  # alias is safe: node values are dead after the mix
        vtmp2_base = block_y

        # Constants
        const_0_vec = self.scratch_vec_const(0)
        const_1_vec = self.scratch_vec_const(1)
        # The scalar 0/1 already exist as a side effect of the two calls above,
        # so these lookups emit nothing.
        const_0 = self.scratch_const(0)
        const_1 = self.scratch_const(1)
        global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
        add_op("valu", ("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"]))

        active_temp_base = self.alloc_scratch("active_temp", ACTIVE_TEMP_WORDS)

        # --- 1. Load input data into the scratch caches ---
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            add_op("load", ("vload", indices_base + i, tmp_load))
            add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            add_op("load", ("vload", values_base + i, tmp_load))

        # --- 2. Main loop (fully unrolled over rounds) ---
        add_op("flow", ("pause",))

        # Scratch addresses of every per-vector buffer; identical in all rounds.
        vecs = []
        for vec_i in range(num_vecs):
            offset = vec_i * VLEN
            vecs.append({
                'idx': indices_base + offset,
                'val': values_base + offset,
                'node': node_vals_base + offset,
                'tmp1': vtmp1_base + offset,
                'tmp2': vtmp2_base + offset,
                'addr': tmp_addrs_base + offset,
            })

        # Candidate tree indices at the start of the current round. Only
        # maintained while the set is small enough for the select path.
        active_indices = [0]

        for r in range(rounds):
            # ---- Fetch the node value for every lane ----
            if r == 0:
                # Round 0: only node 0 is loaded (this path relies on all
                # indices starting at 0); load once and broadcast.
                scalar_node = self.alloc_scratch("scalar_node_r0")
                add_op("load", ("load", scalar_node, ptr_map["forest_values_p"]))
                for vec in vecs:
                    add_op("valu", ("vbroadcast", vec['node'], scalar_node))
                active_indices = [0]
            elif len(active_indices) * 2 <= MAX_ACTIVE_NODES:
                # Few candidates: load each once, then pick per lane with a
                # tree of vselects instead of gathering lane by lane.
                temp_ptr = active_temp_base

                def alloc_temp(length=1):
                    """Bump-allocate ``length`` words inside the active_temp region."""
                    nonlocal temp_ptr
                    addr = temp_ptr
                    temp_ptr += length
                    assert temp_ptr <= active_temp_base + ACTIVE_TEMP_WORDS, (
                        "active_temp region exhausted"
                    )
                    return addr

                # Children of last round's candidates. An index that ran past
                # the tree was wrapped to 0 by the previous round's wrap step,
                # so mirror that here (and drop duplicates).
                children = [
                    child
                    for parent in active_indices
                    for child in (2 * parent + 1, 2 * parent + 2)
                ]
                active_indices = sorted(
                    {child if child < n_nodes else 0 for child in children}
                )

                # 1. Load and broadcast every candidate node value.
                node_map = {}  # tree index -> scratch address of broadcast value
                for uidx in active_indices:
                    s_node = alloc_temp(1)
                    s_addr = alloc_temp(1)
                    idx_c = self.scratch_const(uidx)
                    add_op("alu", ("+", s_addr, ptr_map["forest_values_p"], idx_c))
                    add_op("load", ("load", s_node, s_addr))
                    v_node = alloc_temp(VLEN)
                    add_op("valu", ("vbroadcast", v_node, s_node))
                    node_map[uidx] = v_node

                # Everything past this mark is reused by every vector.
                tree_temp_start = temp_ptr

                def select_node(idx_vec, candidates):
                    """Emit a binary tree of vselects; return the result address.

                    ``candidates`` is sorted, so ``idx < upper[0]`` decides
                    between the lower and the upper half.
                    """
                    if len(candidates) == 1:
                        return node_map[candidates[0]]
                    mid = len(candidates) // 2
                    lower, upper = candidates[:mid], candidates[mid:]
                    split_c = self.scratch_vec_const(upper[0])
                    cond = alloc_temp(VLEN)
                    add_op("valu", ("<", cond, idx_vec, split_c))
                    lower_res = select_node(idx_vec, lower)
                    upper_res = select_node(idx_vec, upper)
                    res = alloc_temp(VLEN)
                    add_op("flow", ("vselect", res, cond, lower_res, upper_res))
                    return res

                # 2. Select the right node value for each vector.
                for vec in vecs:
                    temp_ptr = tree_temp_start
                    final_res = select_node(vec['idx'], active_indices)
                    # Copy into vec['node'] (x | x == x).
                    add_op("valu", ("|", vec['node'], final_res, final_res))
            else:
                # Generic gather. Wave A: address calc for all vectors.
                for vec in vecs:
                    for lane in range(VLEN):
                        add_op(
                            "alu",
                            (
                                "+",
                                vec['addr'] + lane,
                                ptr_map["forest_values_p"],
                                vec['idx'] + lane,
                            ),
                        )
                # Wave B: load node values for all vectors.
                for vec in vecs:
                    for lane in range(VLEN):
                        add_op("load", ("load", vec['node'] + lane, vec['addr'] + lane))

            # The wrap may be skipped while (1 << (r + 2)) < n_nodes.
            do_wrap = not (mask_skip and (1 << (r + 2)) < n_nodes)

            # Only offload when NOT wrapping (avoids per-lane scalar selects).
            use_offload = (r >= active_threshold) and (not do_wrap)
            scalar_vectors = vecs[:scalar_offload] if use_offload else []
            vector_vectors = vecs[scalar_offload:] if use_offload else vecs

            # ---- Vectorised vectors (VALU engine) ----
            for vec in vector_vectors:
                add_op("valu", ("^", vec['val'], vec['val'], vec['node']))
            for vec in vector_vectors:
                self.add_hash_opt(vec['val'], vec['tmp1'], vec['tmp2'])

            # Index update: idx = 2 * idx + (val & 1) + 1
            for vec in vector_vectors:
                add_op("valu", ("&", vec['tmp1'], vec['val'], const_1_vec))
                add_op("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec))
                add_op("valu", ("+", vec['idx'], vec['idx'], vec['idx']))
                add_op("valu", ("+", vec['idx'], vec['idx'], vec['tmp1']))

            # Wrap: idx = idx if idx < n_nodes else 0
            if do_wrap:
                for vec in vector_vectors:
                    add_op("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec))
                for vec in vector_vectors:
                    add_op(
                        "flow",
                        ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec),
                    )

            # ---- Scalarised vectors (ALU engine, one op per lane) ----
            for vec in scalar_vectors:
                self._alu_lanes("^", vec['val'], vec['val'], vec['node'], False)
            for vec in scalar_vectors:
                self.add_hash_opt_scalar(vec['val'], vec['tmp1'], vec['tmp2'])

            for vec in scalar_vectors:
                self._alu_lanes("&", vec['tmp1'], vec['val'], const_1, True)
                self._alu_lanes("+", vec['tmp1'], vec['tmp1'], const_1, True)
                self._alu_lanes("+", vec['idx'], vec['idx'], vec['idx'], False)
                self._alu_lanes("+", vec['idx'], vec['idx'], vec['tmp1'], False)

            # With the current offload policy scalar_vectors is empty whenever
            # do_wrap is true; the scalar wrap is kept so the two paths stay
            # equivalent if that policy changes.
            if do_wrap:
                n_nodes_c = ptr_map["n_nodes"]  # scalar n_nodes
                for vec in scalar_vectors:
                    self._alu_lanes("<", vec['tmp1'], vec['idx'], n_nodes_c, True)
                for vec in scalar_vectors:
                    for lane in range(VLEN):
                        # flow select: dest, cond, a, b
                        add_op(
                            "flow",
                            (
                                "select",
                                vec['idx'] + lane,
                                vec['tmp1'] + lane,
                                vec['idx'] + lane,
                                const_0,
                            ),
                        )

        # --- 3. Final store ---
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            add_op("store", ("vstore", tmp_load, indices_base + i))
            add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            add_op("store", ("vstore", tmp_load, values_base + i))

        add_op("flow", ("pause",))

        self.instrs = self.scheduler.schedule()


BASELINE = 147734


def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = False,
):
    """Build the kernel, run it on the simulator and check the final values.

    Only the final ``values`` array is compared with ``reference_kernel2``;
    the indices are printed when ``prints`` is set but not asserted.

    Returns:
        The machine's cycle count.
    """
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)
    forest = Tree.generate(forest_height)
    inp = Input.generate(forest, batch_size, rounds)
    mem = build_mem_image(forest, inp)

    kb = KernelBuilder()
    kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)

    value_trace = {}
    machine = Machine(
        mem,
        kb.instrs,
        kb.debug_info(),
        n_cores=N_CORES,
        value_trace=value_trace,
        trace=trace,
    )
    machine.prints = prints

    # Numeric values of the core state enum, as labelled in the original code.
    # NOTE(review): confirm against the state enum in problem.py.
    RUNNING, PAUSED, STOPPED = 1, 2, 3

    # The kernel contains pauses, but intermediate state is never written
    # back to memory, so just resume after every pause until the core stops.
    core = machine.cores[0]
    while core.state.value != STOPPED:
        machine.run()
        if core.state.value != PAUSED:
            break
        core.state = type(core.state)(RUNNING)

    # Set after the run, so it has no effect on the execution above.
    machine.enable_pause = False

    # Drain the reference generator; the last yielded image is the final state.
    for ref_mem in reference_kernel2(mem, value_trace):
        pass

    inp_indices_p = ref_mem[5]
    if prints:
        print("INDICES (Machine):", machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
        print("INDICES (Ref): ", ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])

    inp_values_p = ref_mem[6]
    if prints:
        print("VALUES (Machine):", machine.mem[inp_values_p : inp_values_p + len(inp.values)])
        print("VALUES (Ref): ", ref_mem[inp_values_p : inp_values_p + len(inp.values)])

    print("CYCLES: ", machine.cycle)
    if hasattr(machine.cores[0], 'trace_buf'):
        # First 64 items only.
        print("TRACE BUF:", machine.cores[0].trace_buf[:64])

    assert (
        machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ), "Incorrect result on final round"

    return machine.cycle


class Tests(unittest.TestCase):
    def test_ref_kernels(self):
        """
        Test the reference kernels against each other
        """
        random.seed(123)
        for _ in range(10):
            f = Tree.generate(4)
            inp = Input.generate(f, 10, 6)
            mem = build_mem_image(f, inp)
            reference_kernel(f, inp)
            for _ in reference_kernel2(mem, {}):
                pass
            assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)]
            assert inp.values == mem[mem[6] : mem[6] + len(inp.values)]

    def test_kernel_trace(self):
        # Full-scale example for performance testing
        do_kernel_test(10, 16, 256, trace=True, prints=False)

    # Passing this test is not required for submission, see submission_tests.py
    # for the actual correctness test.
    # You can uncomment this if you think it might help you debug
    # def test_kernel_correctness(self):
    #     for batch in range(1, 3):
    #         for forest_height in range(3):
    #             do_kernel_test(
    #                 forest_height + 2, forest_height + 4, batch * 16 * VLEN * N_CORES
    #             )

    def test_kernel_cycles(self):
        do_kernel_test(10, 16, 256, prints=False)


# To run all the tests:
#    python perf_takehome.py
# To run a specific test:
#    python perf_takehome.py Tests.test_kernel_cycles
# To view a hot-reloading trace of all the instructions: **Recommended debug loop**
# NOTE: The trace hot-reloading only works in Chrome. In the worst case if
# things aren't working, drag trace.json onto https://ui.perfetto.dev/
#    python perf_takehome.py Tests.test_kernel_trace
# Then run `python watch_trace.py` in another tab, it'll open a browser tab,
# then click "Open Perfetto"
# You can then keep that open and re-run the test to see a new trace.

# To run the proper checks to see which thresholds you pass:
#    python tests/submission_tests.py

if __name__ == "__main__":
    unittest.main()