# anthropic-kernel / attempt_2 / perf_takehome.py
# (scraped page header: uploaded via "Upload 39 files", commit f3ce0b0, verified)
"""
# Anthropic's Original Performance Engineering Take-home (Release version)
Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
to publish or redistribute your solutions so it's hard to find spoilers.
# Task
- Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
available time, as measured by test_kernel_cycles on a frozen separate copy
of the simulator.
Validate your results using `python tests/submission_tests.py` without modifying
anything in the tests/ folder.
We recommend you look through problem.py next.
"""
from collections import defaultdict
import random
import unittest
from problem import (
Engine,
DebugInfo,
SLOT_LIMITS,
VLEN,
N_CORES,
SCRATCH_SIZE,
Machine,
Tree,
Input,
HASH_STAGES,
reference_kernel,
build_mem_image,
reference_kernel2,
)
from scheduler import Scheduler
class KernelBuilder:
    """Accumulates kernel micro-ops into a Scheduler and manages scratch memory.

    Scratch is a bump-allocated region of SCRATCH_SIZE words. Scalar and
    broadcast-vector constants are deduplicated through ``const_map`` so each
    distinct value is materialized at most once.
    """

    def __init__(self):
        self.scheduler = Scheduler()
        self.scratch = {}        # name -> base scratch address
        self.scratch_debug = {}  # address -> (name, length), consumed by DebugInfo
        self.scratch_ptr = 0     # next free scratch word (bump allocator)
        self.const_map = {}      # value or (value, "vec") -> scratch address

    def debug_info(self):
        """Return DebugInfo mapping scratch addresses to (name, length)."""
        return DebugInfo(scratch_map=self.scratch_debug)

    def finalize(self):
        """Run the scheduler and return the packed instruction stream."""
        return self.scheduler.schedule()

    def add_instr(self, instr_dict):
        """Feed a ``{engine: [op_args, ...]}`` dict into the scheduler.

        Fallback for manual addition (rarely used now); most code calls
        ``self.scheduler.add_op`` directly.
        """
        for engine, slots in instr_dict.items():
            for args in slots:
                self.scheduler.add_op(engine, args)

    def alloc_scratch(self, name=None, length=1):
        """Bump-allocate ``length`` scratch words and return the base address.

        Named allocations are recorded for the debug scratch map.
        """
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
        return addr

    def scratch_const(self, val, name=None):
        """Return the scratch address holding scalar constant ``val``.

        Emits a 'const' op on the load engine only on first use; later
        requests reuse the cached address (``name`` is then ignored).
        """
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            # Constants can only be materialized via the 'const' op on the
            # load engine (or flow add_imm); 'const' is the simplest.
            self.scheduler.add_op("load", ("const", addr, val))
            self.const_map[val] = addr
        return self.const_map[val]

    def scratch_vec_const(self, val, name=None):
        """Return the base address of a VLEN-wide vector filled with ``val``.

        Broadcasts from the (deduplicated) scalar constant on first use.
        """
        key = (val, "vec")
        if key not in self.const_map:
            addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
            scalar_addr = self.scratch_const(val)
            self.scheduler.add_op("valu", ("vbroadcast", addr, scalar_addr))
            self.const_map[key] = addr
        return self.const_map[key]

    def add_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
        """
        Adds slots for the strength-reduced hash function to the scheduler
        (vector form, one valu op per stage step).
        """
        # Stage 0: MAD (val = val * (1 + 2^12) + c)
        c1 = self.scratch_vec_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_vec_const(1 + (1 << 12), "h0_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m1, c1))
        # Stage 1: Xor, Shift, Xor
        c2 = self.scratch_vec_const(0xC761C23C, "h1_c")
        s2 = self.scratch_vec_const(19, "h1_s")
        self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c2))
        self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s2))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))
        # Stage 2: MAD
        c3 = self.scratch_vec_const(0x165667B1, "h2_c")
        m3 = self.scratch_vec_const(1 + (1 << 5), "h2_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m3, c3))
        # Stage 3: Add, Shift, Xor
        c4 = self.scratch_vec_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_vec_const(9, "h3_s")
        self.scheduler.add_op("valu", ("+", tmp1_vec, val_vec, c4))
        self.scheduler.add_op("valu", ("<<", tmp2_vec, val_vec, s4))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))
        # Stage 4: MAD
        c5 = self.scratch_vec_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_vec_const(1 + (1 << 3), "h4_m")
        self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m5, c5))
        # Stage 5: Xor, Shift, Xor
        c6 = self.scratch_vec_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_vec_const(16, "h5_s")
        self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c6))
        self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s6))
        self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))

    def add_hash_opt_scalar(self, val_vec, tmp1_vec, tmp2_vec):
        """
        Scalarized version of the hash: unrolls all VLEN lanes onto the ALU
        engine so it can run concurrently with valu work on other vectors.
        """
        def add_alu_lanes(op, dest_vec, src1_vec, src2_vec, s2_is_const=False):
            # src2_vec is a single scalar address when s2_is_const, otherwise
            # a vector base that is indexed per lane.
            for lane in range(VLEN):
                s2_addr = src2_vec if s2_is_const else src2_vec + lane
                self.scheduler.add_op("alu", (op, dest_vec + lane, src1_vec + lane, s2_addr))

        def add_mad_lanes(dest_vec, a_vec, b_vec, c_vec, b_is_const=False, c_is_const=False):
            # mad(d, a, b, c) -> d = a*b + c, expanded to two scalar ops.
            # Writing the multiply result into dest first is safe here because
            # dest aliases a (val_vec), which is consumed by the multiply.
            for lane in range(VLEN):
                b_addr = b_vec if b_is_const else b_vec + lane
                c_addr = c_vec if c_is_const else c_vec + lane
                self.scheduler.add_op("alu", ("*", dest_vec + lane, a_vec + lane, b_addr))
                self.scheduler.add_op("alu", ("+", dest_vec + lane, dest_vec + lane, c_addr))

        # Stage 0: MAD (scalar: val = val * m1 + c1)
        c1 = self.scratch_const(0x7ED55D16, "h0_c")
        m1 = self.scratch_const(1 + (1 << 12), "h0_m")
        add_mad_lanes(val_vec, val_vec, m1, c1, True, True)
        # Stage 1: Xor, Shift, Xor
        c2 = self.scratch_const(0xC761C23C, "h1_c")
        s2 = self.scratch_const(19, "h1_s")
        add_alu_lanes("^", tmp1_vec, val_vec, c2, True)
        add_alu_lanes(">>", tmp2_vec, val_vec, s2, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)
        # Stage 2: MAD
        c3 = self.scratch_const(0x165667B1, "h2_c")
        m3 = self.scratch_const(1 + (1 << 5), "h2_m")
        add_mad_lanes(val_vec, val_vec, m3, c3, True, True)
        # Stage 3: Add, Shift, Xor
        c4 = self.scratch_const(0xD3A2646C, "h3_c")
        s4 = self.scratch_const(9, "h3_s")
        add_alu_lanes("+", tmp1_vec, val_vec, c4, True)
        add_alu_lanes("<<", tmp2_vec, val_vec, s4, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)
        # Stage 4: MAD
        c5 = self.scratch_const(0xFD7046C5, "h4_c")
        m5 = self.scratch_const(1 + (1 << 3), "h4_m")
        add_mad_lanes(val_vec, val_vec, m5, c5, True, True)
        # Stage 5: Xor, Shift, Xor
        c6 = self.scratch_const(0xB55A4F09, "h5_c")
        s6 = self.scratch_const(16, "h5_s")
        add_alu_lanes("^", tmp1_vec, val_vec, c6, True)
        add_alu_lanes(">>", tmp2_vec, val_vec, s6, True)
        add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)

    def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int,
        active_threshold=4, mask_skip=True, scalar_offload=2
    ):
        """
        Vectorized Wavefront implementation.

        Builds the full unrolled kernel: loads the batch into scratch, runs
        ``rounds`` iterations of (node load, hash mix, index update, optional
        wrap), then stores results back. ``scalar_offload`` vectors per round
        may be routed to the scalar ALU engine once ``r >= active_threshold``
        and wrapping is skippable (``mask_skip``).
        """
        result_scalar_offload = scalar_offload
        # --- Memory Pointers ---
        # The first 7 memory words are the kernel header; cache each in scratch.
        init_vars = [
            "rounds", "n_nodes", "batch_size", "forest_height",
            "forest_values_p", "inp_indices_p", "inp_values_p"
        ]
        ptr_map = {}
        tmp_load = self.alloc_scratch("tmp_load")
        for i, v in enumerate(init_vars):
            addr = self.alloc_scratch(v)
            ptr_map[v] = addr
            self.add_instr({"load": [("const", tmp_load, i)]})
            self.add_instr({"load": [("load", addr, tmp_load)]})
        indices_base = self.alloc_scratch("indices_cache", batch_size)
        values_base = self.alloc_scratch("values_cache", batch_size)
        # Memory Optimization: Reuse Scratch.
        # Two temp blocks suffice because lifetimes do not overlap:
        #   Block X: tmp_addrs -> node_vals -> vtmp1
        #   Block Y: vtmp2
        block_x = self.alloc_scratch("block_x", batch_size)
        block_y = self.alloc_scratch("block_y", batch_size)
        num_vecs = batch_size // VLEN
        tmp_addrs_base = block_x
        node_vals_base = block_x  # Alias safe (load dest same as addr source)
        vtmp1_base = block_x      # Alias safe (node_vals dead after Mix)
        vtmp2_base = block_y
        # Constants
        const_0_vec = self.scratch_vec_const(0)
        const_1_vec = self.scratch_vec_const(1)
        global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
        self.add_instr({"valu": [("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"])]})
        # Scratch pool for the small-active-set path. The overflow assert in
        # alloc_temp() must match this length (previously it checked 512
        # against a 200-word allocation, which would have let temps silently
        # clobber later scratch).
        ACTIVE_TEMP_LEN = 200
        active_temp_base = self.alloc_scratch("active_temp", ACTIVE_TEMP_LEN)
        # --- 1. Load Input Data (Wavefront) ---
        # Address calc + vector load for indices and values.
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            self.scheduler.add_op("load", ("vload", indices_base + i, tmp_load))
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            self.scheduler.add_op("load", ("vload", values_base + i, tmp_load))
        # --- 2. Main Loop ---
        self.scheduler.add_op("flow", ("pause",))
        # Per-vector register pointers. Loop-invariant, so built once and
        # hoisted out of the unrolled rounds loop.
        vecs = []
        for vec_i in range(num_vecs):
            offset = vec_i * VLEN
            vecs.append({
                'idx': indices_base + offset,
                'val': values_base + offset,
                'node': node_vals_base + offset,
                'tmp1': vtmp1_base + offset,
                'tmp2': vtmp2_base + offset,
                'addr': tmp_addrs_base + offset
            })
        # Unrolled loop over 'rounds'
        for r in range(rounds):
            if r == 0:
                # Round 0: every lane is at the root, so load node 0 once and
                # broadcast it to all vectors.
                scalar_node = self.alloc_scratch("scalar_node_r0")
                self.scheduler.add_op("load", ("load", scalar_node, ptr_map["forest_values_p"]))
                for vec in vecs:
                    self.scheduler.add_op("valu", ("vbroadcast", vec['node'], scalar_node))
                active_indices = [0]
            elif len(active_indices) * 2 <= 8:  # Threshold for next round
                # Small active set: load each possible node once, then use a
                # vselect tree to route the right value into each lane.
                active_dev_ptr = active_temp_base

                def alloc_temp(length=1):
                    nonlocal active_dev_ptr
                    addr = active_dev_ptr
                    active_dev_ptr += length
                    assert active_dev_ptr <= active_temp_base + ACTIVE_TEMP_LEN
                    return addr

                # Update active indices for the CURRENT round: the indices
                # reachable at the start of this round are the two children of
                # each index active at the start of the previous round.
                new_actives = []
                for x in active_indices:
                    new_actives.append(2 * x + 1)
                    new_actives.append(2 * x + 2)
                active_indices = new_actives
                # 1. Load all unique nodes once and broadcast each.
                node_map = {}  # uidx -> vector reg holding that node's value
                for uidx in active_indices:
                    s_node = alloc_temp(1)
                    s_addr = alloc_temp(1)
                    idx_c = self.scratch_const(uidx)
                    self.scheduler.add_op("alu", ("+", s_addr, ptr_map["forest_values_p"], idx_c))
                    self.scheduler.add_op("load", ("load", s_node, s_addr))
                    v_node = alloc_temp(VLEN)
                    self.scheduler.add_op("valu", ("vbroadcast", v_node, s_node))
                    node_map[uidx] = v_node
                # Temps above this point (the node map) persist for the round.
                tree_temp_start = active_dev_ptr
                # 2. Select the right node per lane via a binary vselect tree.
                for vec in vecs:
                    # Reset per-vector temps; node_map storage is preserved.
                    active_dev_ptr = tree_temp_start

                    def build_tree(indices):
                        # Binary search over the sorted active indices:
                        # cond = (idx < split) picks left or right subtree.
                        if len(indices) == 1:
                            return node_map[indices[0]]
                        mid = len(indices) // 2
                        left = indices[:mid]
                        right = indices[mid:]
                        split_val = right[0]
                        split_c = self.scratch_vec_const(split_val)
                        cond = alloc_temp(VLEN)
                        self.scheduler.add_op("valu", ("<", cond, vec['idx'], split_c))
                        l_res = build_tree(left)
                        r_res = build_tree(right)
                        res = alloc_temp(VLEN)
                        self.scheduler.add_op("flow", ("vselect", res, cond, l_res, r_res))
                        return res

                    final_res = build_tree(active_indices)
                    # Move final_res into vec['node'] (self-OR acts as a copy).
                    self.scheduler.add_op("valu", ("|", vec['node'], final_res, final_res))
            else:
                # Generic wavefront gather: per-lane address calc then load.
                # Wave A: Address Calc (All Vecs)
                for vec in vecs:
                    for lane in range(VLEN):
                        self.scheduler.add_op("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane))
                # Wave B: Load Node Vals (All Vecs)
                for vec in vecs:
                    for lane in range(VLEN):
                        self.scheduler.add_op("load", ("load", vec['node'] + lane, vec['addr'] + lane))
            # Wrap is only needed once 2*idx+2 can reach n_nodes; skip while
            # indices provably stay in range.
            do_wrap = True
            if mask_skip and (1 << (r + 2)) < n_nodes:
                do_wrap = False
            # Only offload to scalar ALU when NOT wrapping (to avoid the
            # per-lane scalar select overhead of the wrap).
            use_offload = (r >= active_threshold) and (not do_wrap)
            scalar_vectors = vecs[:result_scalar_offload] if use_offload else []
            vector_vectors = vecs[result_scalar_offload:] if use_offload else vecs
            # --- VECTORIZED VECTORS ---
            # Mix: val ^= node, then hash.
            for vec in vector_vectors:
                self.scheduler.add_op("valu", ("^", vec['val'], vec['val'], vec['node']))
            for vec in vector_vectors:
                self.add_hash_opt(vec['val'], vec['tmp1'], vec['tmp2'])
            # Index Update: idx = 2*idx + 1 + (val & 1)
            for vec in vector_vectors:
                self.scheduler.add_op("valu", ("&", vec['tmp1'], vec['val'], const_1_vec))
                self.scheduler.add_op("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec))
                self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['idx']))
                self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['tmp1']))
            # Wrap: idx = idx < n_nodes ? idx : 0
            if do_wrap:
                for vec in vector_vectors:
                    self.scheduler.add_op("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec))
                for vec in vector_vectors:
                    self.scheduler.add_op("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec))
            # --- SCALARIZED VECTORS ---
            def alu_lanes(op, dest, s1, s2, s2_c=False):
                for l in range(VLEN):
                    s2_address = s2 if s2_c else s2 + l
                    self.scheduler.add_op("alu", (op, dest + l, s1 + l, s2_address))

            # Mix + hash on the ALU engine.
            for vec in scalar_vectors:
                alu_lanes("^", vec['val'], vec['val'], vec['node'], False)
            for vec in scalar_vectors:
                self.add_hash_opt_scalar(vec['val'], vec['tmp1'], vec['tmp2'])
            # Index Update
            const_1 = self.scratch_const(1)
            for vec in scalar_vectors:
                alu_lanes("&", vec['tmp1'], vec['val'], const_1, True)
                alu_lanes("+", vec['tmp1'], vec['tmp1'], const_1, True)
                alu_lanes("+", vec['idx'], vec['idx'], vec['idx'], False)
                alu_lanes("+", vec['idx'], vec['idx'], vec['tmp1'], False)
            # Wrap
            if do_wrap:
                const_0 = self.scratch_const(0)
                n_nodes_c = ptr_map["n_nodes"]  # Scalar n_nodes
                # Mask
                for vec in scalar_vectors:
                    alu_lanes("<", vec['tmp1'], vec['idx'], n_nodes_c, True)
                # Select using scalar flow 'select' (dest, cond, a, b)
                for vec in scalar_vectors:
                    for l in range(VLEN):
                        self.scheduler.add_op("flow", ("select", vec['idx'] + l, vec['tmp1'] + l, vec['idx'] + l, const_0))
        # End Unrolled Loop
        # --- 3. Final Store ---
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
            self.scheduler.add_op("store", ("vstore", tmp_load, indices_base + i))
            self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
            self.scheduler.add_op("store", ("vstore", tmp_load, values_base + i))
        self.scheduler.add_op("flow", ("pause",))
        self.instrs = self.scheduler.schedule()
# Cycle count of the unoptimized baseline kernel — presumably the reference
# figure optimizations are measured against (TODO confirm against the tests).
BASELINE = 147734
def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = False,
):
    """Build the kernel, execute it on the simulated Machine, and check the
    final values region against reference_kernel2.

    Returns the machine's cycle count; raises AssertionError on mismatch.
    """
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)
    forest = Tree.generate(forest_height)
    inp = Input.generate(forest, batch_size, rounds)
    mem = build_mem_image(forest, inp)
    kb = KernelBuilder()
    kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)
    # final_instrs = kb.finalize()
    # print(final_instrs)
    value_trace = {}
    machine = Machine(
        mem,
        kb.instrs,
        kb.debug_info(),
        n_cores=N_CORES,
        value_trace=value_trace,
    trace=trace,
    )
    machine.prints = prints
    # machine.enable_pause = False # If we want to skip pauses like submission_tests
    # Run fully.
    # Since we have pauses, we can loop, but checking intermediate state fails if we don't write to mem.
    # So we just run until done: keep resuming through PAUSED states until STOPPED.
    # NOTE(review): state codes 3 == STOPPED, 2 == PAUSED, 1 == RUNNING are
    # assumed from the enum in problem.py — confirm if that enum changes.
    while machine.cores[0].state.value != 3: # STOPPED
        # print(f"Run. Start State: {machine.cores[0].state} PC: {machine.cores[0].pc}")
        machine.run()
        # print(f"Run. End State: {machine.cores[0].state} PC: {machine.cores[0].pc}")
        # If paused, unpause by rebuilding the enum member for RUNNING.
        if machine.cores[0].state.value == 2: # PAUSED
            machine.cores[0].state = machine.cores[0].state.__class__(1) # RUNNING
            continue
        break
    # Check FINAL result
    machine.enable_pause = False
    # Drain the reference generator to obtain the final reference memory image.
    for ref_mem in reference_kernel2(mem, value_trace):
        pass
    # mem[5]/mem[6] hold the input indices/values pointers — presumably the
    # layout written by build_mem_image; verify against problem.py.
    inp_indices_p = ref_mem[5]
    if prints:
        print("INDICES (Machine):", machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
        print("INDICES (Ref):    ", ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])
    inp_values_p = ref_mem[6]
    if prints:
        print("VALUES (Machine):", machine.mem[inp_values_p : inp_values_p + len(inp.values)])
        print("VALUES (Ref):    ", ref_mem[inp_values_p : inp_values_p + len(inp.values)])
    # DEBUG PRINT ALWAYS
    print("CYCLES: ", machine.cycle)
    if hasattr(machine.cores[0], 'trace_buf'):
        print("TRACE BUF:", machine.cores[0].trace_buf[:64]) # Print first 64 items (Round 0)
    # Correctness is judged on the values region only (indices printed above
    # are informational).
    assert (
        machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ), f"Incorrect result on final round"
    return machine.cycle
class Tests(unittest.TestCase):
    def test_ref_kernels(self):
        """
        Cross-check the two reference kernels against each other.
        """
        random.seed(123)
        for _ in range(10):
            forest = Tree.generate(4)
            sample = Input.generate(forest, 10, 6)
            mem = build_mem_image(forest, sample)
            reference_kernel(forest, sample)
            for _step in reference_kernel2(mem, {}):
                pass
            assert sample.indices == mem[mem[5] : mem[5] + len(sample.indices)]
            assert sample.values == mem[mem[6] : mem[6] + len(sample.values)]

    def test_kernel_trace(self):
        # Full-scale example for performance testing, with tracing on.
        do_kernel_test(10, 16, 256, trace=True, prints=False)

    # Passing this test is not required for submission; submission_tests.py
    # holds the actual correctness test. Uncomment if it helps you debug:
    # def test_kernel_correctness(self):
    #     for batch in range(1, 3):
    #         for forest_height in range(3):
    #             do_kernel_test(
    #                 forest_height + 2, forest_height + 4, batch * 16 * VLEN * N_CORES
    #             )

    def test_kernel_cycles(self):
        do_kernel_test(10, 16, 256, prints=False)
# To run all the tests:
# python perf_takehome.py
# To run a specific test:
# python perf_takehome.py Tests.test_kernel_cycles
# To view a hot-reloading trace of all the instructions: **Recommended debug loop**
# NOTE: The trace hot-reloading only works in Chrome. In the worst case if things aren't working, drag trace.json onto https://ui.perfetto.dev/
# python perf_takehome.py Tests.test_kernel_trace
# Then run `python watch_trace.py` in another tab, it'll open a browser tab, then click "Open Perfetto"
# You can then keep that open and re-run the test to see a new trace.
# To run the proper checks to see which thresholds you pass:
# python tests/submission_tests.py
# Script entry point: discover and run the unittest cases defined above.
if __name__ == "__main__":
    unittest.main()