import collections from collections import defaultdict, deque import heapq import random import unittest # Assumes problem.py exists in the same directory as per the original structure from problem import ( Engine, DebugInfo, SLOT_LIMITS, # Note: Scheduler re-defines this, but we keep import for safety VLEN, N_CORES, SCRATCH_SIZE, Machine, Tree, Input, HASH_STAGES, reference_kernel, build_mem_image, reference_kernel2, ) # --- Integrated Scheduler Code --- # Redefining locally to ensure scheduler uses these exact limits SCHEDULER_SLOT_LIMITS = { "alu": 12, "valu": 6, "load": 2, "store": 2, "flow": 1, "debug": 64, } class Node: def __init__(self, id, engine, args, desc=""): self.id = id self.engine = engine self.args = args # Tuple of args self.desc = desc self.parents = [] self.children = [] self.priority = 0 self.latency = 1 # Default latency def add_child(self, node): self.children.append(node) node.parents.append(self) class Scheduler: def __init__(self): self.nodes = [] self.id_counter = 0 self.scratch_reads = defaultdict(list) # addr -> [nodes reading it] self.scratch_writes = defaultdict(list) # addr -> [nodes writing it] def add_op(self, engine, args, desc=""): node = Node(self.id_counter, engine, args, desc) self.nodes.append(node) self.id_counter += 1 # Analyze dependencies reads, writes = self._get_rw(engine, args) # RAW (Read After Write): Current node reads from a previous write for r in reads: if r in self.scratch_writes and self.scratch_writes[r]: # Depend on the LAST writer last_writer = self.scratch_writes[r][-1] last_writer.add_child(node) # WAW (Write After Write): Current node writes to same addr as previous write for w in writes: if w in self.scratch_writes and self.scratch_writes[w]: last_writer = self.scratch_writes[w][-1] last_writer.add_child(node) # WAR (Write After Read): Current node writes to addr that was read previously # We must not write until previous reads are done. for w in writes: if w in self.scratch_reads and self.scratch_reads[w]: for reader in self.scratch_reads[w]: if reader != node: # Don't depend on self reader.add_child(node) # Register Access updates for r in reads: self.scratch_reads[r].append(node) for w in writes: self.scratch_writes[w].append(node) return node def _get_rw(self, engine, args): reads = [] writes = [] # Helpers def is_addr(x): return isinstance(x, int) if engine == "alu": # (op, dest, a1, a2) # Generic ALU ops usually take 3 args: dest, src1, src2 op, dest, a1, a2 = args writes.append(dest) reads.append(a1) reads.append(a2) elif engine == "valu": # varargs op = args[0] if op == "vbroadcast": # dest, src writes.extend([args[1] + i for i in range(VLEN)]) reads.append(args[2]) elif op == "multiply_add": # dest, a, b, c writes.extend([args[1] + i for i in range(VLEN)]) reads.extend([args[2] + i for i in range(VLEN)]) reads.extend([args[3] + i for i in range(VLEN)]) reads.extend([args[4] + i for i in range(VLEN)]) else: # Generic VALU op: op, dest, a1, a2 # e.g. ^, >>, +, <, & writes.extend([args[1] + i for i in range(VLEN)]) reads.extend([args[2] + i for i in range(VLEN)]) reads.extend([args[3] + i for i in range(VLEN)]) elif engine == "load": op = args[0] if op == "const": writes.append(args[1]) elif op == "load": writes.append(args[1]) reads.append(args[2]) elif op == "vload": writes.extend([args[1] + i for i in range(VLEN)]) reads.append(args[2]) # scalar addr elif engine == "store": op = args[0] if op == "vstore": reads.append(args[1]) # addr reads.extend([args[2] + i for i in range(VLEN)]) # val elif engine == "flow": op = args[0] if op == "vselect": # dest, cond, a, b writes.extend([args[1] + i for i in range(VLEN)]) reads.extend([args[2] + i for i in range(VLEN)]) reads.extend([args[3] + i for i in range(VLEN)]) reads.extend([args[4] + i for i in range(VLEN)]) elif op == "select": # dest, cond, a, b writes.append(args[1]) reads.append(args[2]) reads.append(args[3]) reads.append(args[4]) elif op == "add_imm": # dest, a, imm writes.append(args[1]) reads.append(args[2]) elif op == "cond_jump" or op == "cond_jump_rel": # cond, dest reads.append(args[1]) elif op == "pause": pass return reads, writes def schedule(self): # Calculate priorities (longest path) self._calc_priorities() ready = [] # Heap of (-priority, node) in_degree = defaultdict(int) for node in self.nodes: in_degree[node] = len(node.parents) if in_degree[node] == 0: heapq.heappush(ready, (-node.priority, node.id, node)) instructions = [] # Main Scheduling Loop while ready or any(count > 0 for count in in_degree.values()): cycle_ops = defaultdict(list) deferred = [] usage = {k:0 for k in SCHEDULER_SLOT_LIMITS} curr_cycle_nodes = [] # Greedy allocation for this cycle while ready: prio, nid, node = heapq.heappop(ready) if usage[node.engine] < SCHEDULER_SLOT_LIMITS[node.engine]: usage[node.engine] += 1 cycle_ops[node.engine].append(node.args) curr_cycle_nodes.append(node) else: deferred.append((prio, nid, node)) # Push back deferred for next cycle for item in deferred: heapq.heappush(ready, item) # Check for termination or deadlock if not curr_cycle_nodes and not ready: if any(in_degree.values()): raise Exception("Deadlock detected in scheduler") break instructions.append(dict(cycle_ops)) # Update children for NEXT cycle for node in curr_cycle_nodes: for child in node.children: in_degree[child] -= 1 if in_degree[child] == 0: heapq.heappush(ready, (-child.priority, child.id, child)) return instructions def _calc_priorities(self): memo = {} def get_dist(node): if node in memo: return memo[node] max_d = 0 for child in node.children: max_d = max(max_d, get_dist(child)) memo[node] = max_d + 1 return max_d + 1 for node in self.nodes: node.priority = get_dist(node) # --- Main Kernel Logic --- class KernelBuilder: def __init__(self): self.scheduler = Scheduler() self.scratch = {} self.scratch_debug = {} self.scratch_ptr = 0 self.const_map = {} def debug_info(self): return DebugInfo(scratch_map=self.scratch_debug) def finalize(self): return self.scheduler.schedule() def add_instr(self, instr_dict): # Compatibility wrapper for engine, slots in instr_dict.items(): for args in slots: self.scheduler.add_op(engine, args) def alloc_scratch(self, name=None, length=1): addr = self.scratch_ptr if name is not None: self.scratch[name] = addr self.scratch_debug[addr] = (name, length) self.scratch_ptr += length assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}" return addr def scratch_const(self, val, name=None): if val not in self.const_map: addr = self.alloc_scratch(name) self.scheduler.add_op("load", ("const", addr, val)) self.const_map[val] = addr return self.const_map[val] def scratch_vec_const(self, val, name=None): key = (val, "vec") if key not in self.const_map: addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN) scalar_addr = self.scratch_const(val) self.scheduler.add_op("valu", ("vbroadcast", addr, scalar_addr)) self.const_map[key] = addr return self.const_map[key] def add_hash_opt(self, val_vec, tmp1_vec, tmp2_vec): """ Adds slots for the strength-reduced hash function to scheduler. """ # Stage 0: MAD c1 = self.scratch_vec_const(0x7ED55D16, "h0_c") m1 = self.scratch_vec_const(1 + (1<<12), "h0_m") self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m1, c1)) # Stage 1: Xor, Shift, Xor c2 = self.scratch_vec_const(0xC761C23C, "h1_c") s2 = self.scratch_vec_const(19, "h1_s") # 1a self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c2)) self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s2)) # 1b self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec)) # Stage 2: MAD c3 = self.scratch_vec_const(0x165667B1, "h2_c") m3 = self.scratch_vec_const(1 + (1<<5), "h2_m") self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m3, c3)) # Stage 3: Add, Shift, Xor c4 = self.scratch_vec_const(0xD3A2646C, "h3_c") s4 = self.scratch_vec_const(9, "h3_s") self.scheduler.add_op("valu", ("+", tmp1_vec, val_vec, c4)) self.scheduler.add_op("valu", ("<<", tmp2_vec, val_vec, s4)) self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec)) # Stage 4: MAD c5 = self.scratch_vec_const(0xFD7046C5, "h4_c") m5 = self.scratch_vec_const(1 + (1<<3), "h4_m") self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m5, c5)) # Stage 5: Xor, Shift, Xor c6 = self.scratch_vec_const(0xB55A4F09, "h5_c") s6 = self.scratch_vec_const(16, "h5_s") self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c6)) self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s6)) self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec)) def add_hash_opt_scalar(self, val_vec, tmp1_vec, tmp2_vec): """ Scalarized version of hash optimization. Unrolls loop over 8 lanes and uses ALU engine. """ def add_alu_lanes(op, dest_vec, src1_vec, src2_vec, s2_is_const=False): for lane in range(VLEN): s2_addr = src2_vec if s2_is_const else src2_vec + lane self.scheduler.add_op("alu", (op, dest_vec + lane, src1_vec + lane, s2_addr)) def add_mad_lanes(dest_vec, a_vec, b_vec, c_vec, b_is_const=False, c_is_const=False): for lane in range(VLEN): b_addr = b_vec if b_is_const else b_vec + lane c_addr = c_vec if c_is_const else c_vec + lane # dest = a*b self.scheduler.add_op("alu", ("*", dest_vec + lane, a_vec + lane, b_addr)) # dest = dest+c self.scheduler.add_op("alu", ("+", dest_vec + lane, dest_vec + lane, c_addr)) # Stage 0: MAD c1 = self.scratch_const(0x7ED55D16, "h0_c") m1 = self.scratch_const(1 + (1<<12), "h0_m") add_mad_lanes(val_vec, val_vec, m1, c1, True, True) # Stage 1: Xor, Shift, Xor c2 = self.scratch_const(0xC761C23C, "h1_c") s2 = self.scratch_const(19, "h1_s") add_alu_lanes("^", tmp1_vec, val_vec, c2, True) add_alu_lanes(">>", tmp2_vec, val_vec, s2, True) add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False) # Stage 2: MAD c3 = self.scratch_const(0x165667B1, "h2_c") m3 = self.scratch_const(1 + (1<<5), "h2_m") add_mad_lanes(val_vec, val_vec, m3, c3, True, True) # Stage 3: Add, Shift, Xor c4 = self.scratch_const(0xD3A2646C, "h3_c") s4 = self.scratch_const(9, "h3_s") add_alu_lanes("+", tmp1_vec, val_vec, c4, True) add_alu_lanes("<<", tmp2_vec, val_vec, s4, True) add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False) # Stage 4: MAD c5 = self.scratch_const(0xFD7046C5, "h4_c") m5 = self.scratch_const(1 + (1<<3), "h4_m") add_mad_lanes(val_vec, val_vec, m5, c5, True, True) # Stage 5: Xor, Shift, Xor c6 = self.scratch_const(0xB55A4F09, "h5_c") s6 = self.scratch_const(16, "h5_s") add_alu_lanes("^", tmp1_vec, val_vec, c6, True) add_alu_lanes(">>", tmp2_vec, val_vec, s6, True) add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False) def build_kernel( self, forest_height: int, n_nodes: int, batch_size: int, rounds: int, active_threshold=4, mask_skip=True, scalar_offload=2 ): result_scalar_offload = scalar_offload # --- Memory Pointers --- init_vars = [ "rounds", "n_nodes", "batch_size", "forest_height", "forest_values_p", "inp_indices_p", "inp_values_p" ] ptr_map = {} tmp_load = self.alloc_scratch("tmp_load") for i, v in enumerate(init_vars): addr = self.alloc_scratch(v) ptr_map[v] = addr self.scheduler.add_op("load", ("const", tmp_load, i)) self.scheduler.add_op("load", ("load", addr, tmp_load)) indices_base = self.alloc_scratch("indices_cache", batch_size) values_base = self.alloc_scratch("values_cache", batch_size) # Memory Optimization: Reuse Scratch block_x = self.alloc_scratch("block_x", batch_size) block_y = self.alloc_scratch("block_y", batch_size) num_vecs = batch_size // VLEN tmp_addrs_base = block_x node_vals_base = block_x vtmp1_base = block_x vtmp2_base = block_y # Constants const_0_vec = self.scratch_vec_const(0) const_1_vec = self.scratch_vec_const(1) global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN) self.scheduler.add_op("valu", ("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"])) active_temp_base = self.alloc_scratch("active_temp", 200) # --- 1. Load Input Data (Wavefront) --- for i in range(0, batch_size, VLEN): i_const = self.scratch_const(i) # Indices Addr self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const)) self.scheduler.add_op("load", ("vload", indices_base + i, tmp_load)) self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const)) self.scheduler.add_op("load", ("vload", values_base + i, tmp_load)) # --- 2. Main Loop --- self.scheduler.add_op("flow", ("pause",)) active_indices = [] for r in range(rounds): # Collect register pointers for all vectors vecs = [] for vec_i in range(num_vecs): offset = vec_i * VLEN vecs.append({ 'idx': indices_base + offset, 'val': values_base + offset, 'node': node_vals_base + offset, 'tmp1': vtmp1_base + offset, 'tmp2': vtmp2_base + offset, 'addr': tmp_addrs_base + offset }) if r == 0: # Round 0: 1 Node (0) scalar_node = self.alloc_scratch("scalar_node_r0") self.scheduler.add_op("load", ("load", scalar_node, ptr_map["forest_values_p"])) for vec in vecs: self.scheduler.add_op("valu", ("vbroadcast", vec['node'], scalar_node)) active_indices = [0] elif len(active_indices) * 2 <= 8: # Threshold for next round # Reuse Scratch active_dev_ptr = active_temp_base def alloc_temp(length=1): nonlocal active_dev_ptr addr = active_dev_ptr active_dev_ptr += length assert active_dev_ptr <= active_temp_base + 512 return addr # Update active indices new_actives = [] for x in active_indices: new_actives.append(2*x + 1) new_actives.append(2*x + 2) active_indices = new_actives # Active Load Strategy node_map = {} for uidx in active_indices: s_node = alloc_temp(1) s_addr = alloc_temp(1) idx_c = self.scratch_const(uidx) # Calc Addr self.scheduler.add_op("alu", ("+", s_addr, ptr_map["forest_values_p"], idx_c)) # Load self.scheduler.add_op("load", ("load", s_node, s_addr)) # Broadcast v_node = alloc_temp(VLEN) self.scheduler.add_op("valu", ("vbroadcast", v_node, s_node)) node_map[uidx] = v_node tree_temp_start = active_dev_ptr # Select Tree for each vector for vec in vecs: active_dev_ptr = tree_temp_start def build_tree(indices): if len(indices) == 1: return node_map[indices[0]] mid = len(indices) // 2 left = indices[:mid] right = indices[mid:] split_val = right[0] split_c = self.scratch_vec_const(split_val) cond = alloc_temp(VLEN) self.scheduler.add_op("valu", ("<", cond, vec['idx'], split_c)) l_res = build_tree(left) r_res = build_tree(right) res = alloc_temp(VLEN) self.scheduler.add_op("flow", ("vselect", res, cond, l_res, r_res)) return res final_res = build_tree(active_indices) self.scheduler.add_op("valu", ("|", vec['node'], final_res, final_res)) else: # Generic Wavefront Load for vec in vecs: for lane in range(VLEN): self.scheduler.add_op("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane)) for vec in vecs: for lane in range(VLEN): self.scheduler.add_op("load", ("load", vec['node'] + lane, vec['addr'] + lane)) do_wrap = True if mask_skip and (1<<(r+2)) < n_nodes: do_wrap = False use_offload = (r >= active_threshold) and (not do_wrap) scalar_vectors = vecs[:result_scalar_offload] if use_offload else [] vector_vectors = vecs[result_scalar_offload:] if use_offload else vecs # --- VECTORIZED VECTORS --- # Mixed Hash for vec in vector_vectors: self.scheduler.add_op("valu", ("^", vec['val'], vec['val'], vec['node'])) for vec in vector_vectors: self.add_hash_opt(vec['val'], vec['tmp1'], vec['tmp2']) # Index Update for vec in vector_vectors: self.scheduler.add_op("valu", ("&", vec['tmp1'], vec['val'], const_1_vec)) self.scheduler.add_op("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec)) self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['idx'])) self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['tmp1'])) # Wrap if do_wrap: for vec in vector_vectors: self.scheduler.add_op("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec)) for vec in vector_vectors: self.scheduler.add_op("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec)) # --- SCALARIZED VECTORS --- def alu_lanes(op, dest, s1, s2, s2_c=False): for l in range(VLEN): s2_Address = s2 if s2_c else s2+l self.scheduler.add_op("alu", (op, dest+l, s1+l, s2_Address)) # Mixed Hash for vec in scalar_vectors: alu_lanes("^", vec['val'], vec['val'], vec['node'], False) for vec in scalar_vectors: self.add_hash_opt_scalar(vec['val'], vec['tmp1'], vec['tmp2']) # Index Update const_1 = self.scratch_const(1) for vec in scalar_vectors: alu_lanes("&", vec['tmp1'], vec['val'], const_1, True) alu_lanes("+", vec['tmp1'], vec['tmp1'], const_1, True) alu_lanes("+", vec['idx'], vec['idx'], vec['idx'], False) alu_lanes("+", vec['idx'], vec['idx'], vec['tmp1'], False) # Wrap if do_wrap: const_0 = self.scratch_const(0) n_nodes_c = ptr_map["n_nodes"] for vec in scalar_vectors: alu_lanes("<", vec['tmp1'], vec['idx'], n_nodes_c, True) for vec in scalar_vectors: for l in range(VLEN): self.scheduler.add_op("flow", ("select", vec['idx']+l, vec['tmp1']+l, vec['idx']+l, const_0)) # --- 3. Final Store --- for i in range(0, batch_size, VLEN): i_const = self.scratch_const(i) self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const)) self.scheduler.add_op("store", ("vstore", tmp_load, indices_base + i)) self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const)) self.scheduler.add_op("store", ("vstore", tmp_load, values_base + i)) self.scheduler.add_op("flow", ("pause",)) self.instrs = self.scheduler.schedule() BASELINE = 147734 def do_kernel_test( forest_height: int, rounds: int, batch_size: int, seed: int = 123, trace: bool = False, prints: bool = False, ): print(f"{forest_height=}, {rounds=}, {batch_size=}") random.seed(seed) forest = Tree.generate(forest_height) inp = Input.generate(forest, batch_size, rounds) mem = build_mem_image(forest, inp) kb = KernelBuilder() kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds) value_trace = {} machine = Machine( mem, kb.instrs, kb.debug_info(), n_cores=N_CORES, value_trace=value_trace, trace=trace, ) machine.prints = prints while machine.cores[0].state.value != 3: # STOPPED machine.run() if machine.cores[0].state.value == 2: # PAUSED machine.cores[0].state = machine.cores[0].state.__class__(1) # RUNNING continue break # Check FINAL result machine.enable_pause = False for ref_mem in reference_kernel2(mem, value_trace): pass inp_values_p = ref_mem[6] # DEBUG PRINT ALWAYS print("CYCLES: ", machine.cycle) if hasattr(machine.cores[0], 'trace_buf'): print("TRACE BUF:", machine.cores[0].trace_buf[:64]) assert ( machine.mem[inp_values_p : inp_values_p + len(inp.values)] == ref_mem[inp_values_p : inp_values_p + len(inp.values)] ), f"Incorrect result on final round" return machine.cycle class Tests(unittest.TestCase): def test_ref_kernels(self): random.seed(123) for i in range(10): f = Tree.generate(4) inp = Input.generate(f, 10, 6) mem = build_mem_image(f, inp) reference_kernel(f, inp) for _ in reference_kernel2(mem, {}): pass assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)] assert inp.values == mem[mem[6] : mem[6] + len(inp.values)] def test_kernel_trace(self): do_kernel_test(10, 16, 256, trace=True, prints=False) def test_kernel_cycles(self): do_kernel_test(10, 16, 256, prints=False) if __name__ == "__main__": unittest.main()