# scheduler.py — VLIW dependency-DAG builder and list scheduler
# (anthropic-kernel / atempt_2; uploaded in "Upload 39 files", commit f3ce0b0)
from collections import defaultdict, deque
import heapq
# Per-cycle issue-width limits: the maximum number of ops of each engine
# type that can be packed into one VLIW instruction bundle.  Enforced per
# cycle by Scheduler.schedule().
SLOT_LIMITS = {
    "alu": 12,
    "valu": 6,
    "load": 2,
    "store": 2,
    "flow": 1,
    "debug": 64,
}
class Node:
    """One operation in the scheduler's dependency graph.

    Carries the op's engine (issue-slot family), its argument tuple, and
    the parent/child edges that the list scheduler uses for ordering.
    """

    def __init__(self, id, engine, args, desc=""):
        # Identity and payload.
        self.id = id          # creation order; also used as a heap tie-breaker
        self.engine = engine  # which SLOT_LIMITS engine this op occupies
        self.args = args      # tuple of operands
        self.desc = desc      # optional human-readable label
        # Dependency edges; populated via add_child().
        self.parents, self.children = [], []
        # Scheduling metadata.
        self.priority = 0     # critical-path height, set by the scheduler
        self.latency = 1      # default result latency, in cycles

    def add_child(self, node):
        """Record a dependency edge: *node* must issue after this node."""
        self.children.append(node)
        node.parents.append(self)
class Scheduler:
    """Builds a dependency DAG of ops and list-schedules them into VLIW
    instruction bundles.

    Dependencies are derived from scratch-memory hazards (RAW, WAW, WAR)
    on the addresses each op reads and writes; per-cycle issue width is
    capped per engine by SLOT_LIMITS.
    """

    def __init__(self):
        self.nodes = []
        self.id_counter = 0
        # Hazard tracking: scratch address -> nodes that touched it, in
        # program order.
        self.scratch_reads = defaultdict(list)   # addr -> [nodes reading it]
        self.scratch_writes = defaultdict(list)  # addr -> [nodes writing it]

    def add_op(self, engine, args, desc=""):
        """Append an op and wire up its dependency edges.

        Edges are inferred from the addresses returned by _get_rw:
          * RAW: depend on the last writer of each address we read.
          * WAW: depend on the last writer of each address we write.
          * WAR: depend on every earlier reader of each address we write.

        Returns the new Node (so callers can attach extra edges manually).
        """
        node = Node(self.id_counter, engine, args, desc)
        self.nodes.append(node)
        self.id_counter += 1

        reads, writes = self._get_rw(engine, args)

        # RAW (Read After Write): current node reads a previous write.
        for r in reads:
            writers = self.scratch_writes.get(r)
            if writers:
                writers[-1].add_child(node)
        # WAW (Write After Write): keep writes to one address in order.
        for w in writes:
            writers = self.scratch_writes.get(w)
            if writers:
                writers[-1].add_child(node)
        # WAR (Write After Read): don't overwrite an address until every
        # earlier reader has issued.
        for w in writes:
            for reader in self.scratch_reads.get(w, ()):
                if reader is not node:  # never self-depend
                    reader.add_child(node)

        # Record this op's accesses for future hazard checks (after edge
        # creation, so the node never sees itself as a prior accessor).
        for r in reads:
            self.scratch_reads[r].append(node)
        for w in writes:
            self.scratch_writes[w].append(node)
        return node

    def _get_rw(self, engine, args):
        """Return (reads, writes): scratch addresses this op touches.

        Vector ops operate on VEC consecutive lanes starting at the given
        base address.  NOTE(review): only the ops used so far are modeled;
        unrecognized ops are treated as having no data dependencies, and
        control-flow ops are not modeled as barriers.
        """
        VEC = 8  # vector width, in scratch words
        reads = []
        writes = []
        if engine == "alu":
            # (op, dest, a1, a2)
            _op, dest, a1, a2 = args
            writes.append(dest)
            reads.append(a1)
            reads.append(a2)
        elif engine == "valu":
            op = args[0]
            if op == "vbroadcast":
                # (op, dest, src): scalar src fanned out to VEC lanes
                writes.extend(args[1] + i for i in range(VEC))
                reads.append(args[2])
            elif op == "multiply_add":
                # (op, dest, a, b, c): all vector operands
                writes.extend(args[1] + i for i in range(VEC))
                reads.extend(args[2] + i for i in range(VEC))
                reads.extend(args[3] + i for i in range(VEC))
                reads.extend(args[4] + i for i in range(VEC))
            else:
                # (op, dest, a1, a2): generic binary vector op
                writes.extend(args[1] + i for i in range(VEC))
                reads.extend(args[2] + i for i in range(VEC))
                reads.extend(args[3] + i for i in range(VEC))
        elif engine == "load":
            op = args[0]
            if op == "const":
                # (op, dest): immediate operand, no scratch reads
                writes.append(args[1])
            elif op == "load":
                # (op, dest, addr)
                writes.append(args[1])
                reads.append(args[2])
            elif op == "vload":
                # (op, dest, addr): scalar address, vector destination
                writes.extend(args[1] + i for i in range(VEC))
                reads.append(args[2])
            # Add others as needed.
        elif engine == "store":
            op = args[0]
            if op == "vstore":
                # (op, addr, val): reads the address operand and VEC lanes
                reads.append(args[1])
                reads.extend(args[2] + i for i in range(VEC))
            # Add others as needed.
        elif engine == "flow":
            op = args[0]
            if op == "vselect":
                # (op, dest, cond, a, b): all vector operands
                writes.extend(args[1] + i for i in range(VEC))
                reads.extend(args[2] + i for i in range(VEC))
                reads.extend(args[3] + i for i in range(VEC))
                reads.extend(args[4] + i for i in range(VEC))
            elif op == "select":
                # (op, dest, cond, a, b)
                writes.append(args[1])
                reads.append(args[2])
                reads.append(args[3])
                reads.append(args[4])
            elif op == "add_imm":
                # (op, dest, a, imm)
                writes.append(args[1])
                reads.append(args[2])
            elif op in ("cond_jump", "cond_jump_rel"):
                # (op, cond, dest): the condition is the only data input.
                # NOTE(review): no control-flow barrier is modeled here.
                reads.append(args[1])
        # pause, halt, etc. have no data dependencies (barriers not modeled).
        return reads, writes

    def schedule(self):
        """Greedy list-schedule the DAG into instruction bundles.

        Returns a list of cycles, each a dict {engine: [args, ...]} of the
        ops issued that cycle.  Within a cycle, ops are chosen by highest
        critical-path priority (ties broken by creation order) subject to
        SLOT_LIMITS; a child becomes eligible no earlier than the cycle
        after its last parent issues (1-cycle latency).

        Raises:
            Exception: if the dependency graph contains a cycle, i.e. ops
            remain but none are ready.  (Previously a cycle detected after
            the first bundle silently dropped the remaining ops.)
        """
        self._calc_priorities()

        in_degree = {}
        ready = []  # min-heap of (-priority, id, node); id makes pops deterministic
        for node in self.nodes:
            in_degree[node] = len(node.parents)
            if not node.parents:
                heapq.heappush(ready, (-node.priority, node.id, node))

        instructions = []
        unscheduled = len(self.nodes)
        while unscheduled:
            if not ready:
                # Work remains but nothing can issue: dependency cycle.
                raise Exception("Deadlock or Cycle detected")

            # Start a new cycle: greedily drain the ready heap, taking each
            # op that still fits its engine's slot budget.
            usage = {engine: 0 for engine in SLOT_LIMITS}
            cycle_ops = defaultdict(list)
            issued = []    # nodes placed in this bundle
            deferred = []  # ready nodes that did not fit this cycle
            while ready:
                entry = heapq.heappop(ready)
                node = entry[2]
                if usage[node.engine] < SLOT_LIMITS[node.engine]:
                    usage[node.engine] += 1
                    cycle_ops[node.engine].append(node.args)
                    issued.append(node)
                else:
                    deferred.append(entry)
            for entry in deferred:
                heapq.heappush(ready, entry)

            # ready was non-empty, and every engine budget is >= 1, so at
            # least one op issued: the bundle is never empty.
            instructions.append(dict(cycle_ops))
            unscheduled -= len(issued)

            # Release children whose parents have all issued; they become
            # eligible starting next cycle.
            for node in issued:
                for child in node.children:
                    in_degree[child] -= 1
                    if in_degree[child] == 0:
                        heapq.heappush(ready, (-child.priority, child.id, child))
        return instructions

    def _calc_priorities(self):
        """Set node.priority to the node's height in the DAG (length in ops
        of the longest path to a sink), so critical-path ops issue first.

        Implemented as an iterative post-order DFS: the previous recursive
        version hit RecursionError on dependency chains longer than the
        interpreter's recursion limit.  Nodes on a dependency cycle get a
        best-effort value here; the cycle itself is reported by schedule().
        """
        dist = {}  # node -> height; None marks a node currently on the stack
        for root in self.nodes:
            if root in dist:
                continue
            stack = [(root, False)]
            while stack:
                node, children_done = stack.pop()
                if children_done:
                    dist[node] = 1 + max(
                        (dist[c] for c in node.children if dist.get(c) is not None),
                        default=0,
                    )
                elif node not in dist:
                    dist[node] = None  # in-progress marker (breaks cycles)
                    stack.append((node, True))
                    for child in node.children:
                        if child not in dist:
                            stack.append((child, False))
        for node in self.nodes:
            node.priority = dist[node]