from collections import defaultdict, deque
import heapq

# Per-cycle issue-width limit for each execution engine (VLIW slot model).
SLOT_LIMITS = {
    "alu": 12,
    "valu": 6,
    "load": 2,
    "store": 2,
    "flow": 1,
    "debug": 64,
}

# Vector ops operate on this many consecutive scratch addresses.
_VLEN = 8


class Node:
    """A single operation (vertex) in the dependency DAG."""

    def __init__(self, id, engine, args, desc=""):
        self.id = id
        self.engine = engine
        self.args = args          # tuple of operands, engine-specific layout
        self.desc = desc
        self.parents = []         # ops that must issue before this one
        self.children = []        # ops that must issue after this one
        self.priority = 0         # critical-path height, set by _calc_priorities
        self.latency = 1          # reserved; scheduler currently assumes 1-cycle latency

    def add_child(self, node):
        """Record a dependency edge self -> node (node waits on self)."""
        self.children.append(node)
        node.parents.append(self)


class Scheduler:
    """Builds a dependency DAG of ops and greedily packs them into VLIW cycles."""

    def __init__(self):
        self.nodes = []
        self.id_counter = 0
        # Scratch-memory access history, used to derive dependency edges.
        self.scratch_reads = defaultdict(list)   # addr -> [nodes reading it]
        self.scratch_writes = defaultdict(list)  # addr -> [nodes writing it]

    def add_op(self, engine, args, desc=""):
        """Append an op and wire its RAW / WAW / WAR dependency edges.

        Returns the created Node so callers can inspect or extend it.
        """
        node = Node(self.id_counter, engine, args, desc)
        self.nodes.append(node)
        self.id_counter += 1

        reads, writes = self._get_rw(engine, args)

        # RAW (read after write): depend on the last writer of each read addr.
        for addr in reads:
            writers = self.scratch_writes.get(addr)
            if writers:
                writers[-1].add_child(node)

        # WAW (write after write): preserve ordering of writes to the same addr.
        for addr in writes:
            writers = self.scratch_writes.get(addr)
            if writers:
                writers[-1].add_child(node)

        # WAR (write after read): must not overwrite an addr until every
        # earlier reader has issued.
        for addr in writes:
            for reader in self.scratch_reads.get(addr, ()):
                if reader is not node:
                    reader.add_child(node)

        # Record this op's accesses for future dependency analysis.
        for addr in reads:
            self.scratch_reads[addr].append(node)
        for addr in writes:
            self.scratch_writes[addr].append(node)

        return node

    def _get_rw(self, engine, args):
        """Return (reads, writes): the scratch addresses one op reads / writes.

        Vector operands (valu / vload / vstore / vselect) cover _VLEN
        consecutive addresses starting at the named base address.
        """
        reads = []
        writes = []

        def vec(base):
            # Expand a vector base address to its full lane range.
            return [base + i for i in range(_VLEN)]

        if engine == "alu":
            # (op, dest, a1, a2)
            op, dest, a1, a2 = args
            writes.append(dest)
            reads.extend([a1, a2])
        elif engine == "valu":
            op = args[0]
            if op == "vbroadcast":
                # (op, dest, src): scalar src replicated across vector dest
                writes.extend(vec(args[1]))
                reads.append(args[2])
            elif op == "multiply_add":
                # (op, dest, a, b, c): all four operands are vectors
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
                reads.extend(vec(args[4]))
            else:
                # (op, dest, a1, a2): generic two-source vector op
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
        elif engine == "load":
            op = args[0]
            if op == "const":
                # (op, dest): immediate load, no scratch reads
                writes.append(args[1])
            elif op == "load":
                # (op, dest, addr)
                writes.append(args[1])
                reads.append(args[2])
            elif op == "vload":
                # (op, dest, addr): vector dest, scalar address operand
                writes.extend(vec(args[1]))
                reads.append(args[2])
            # Add others as needed
        elif engine == "store":
            op = args[0]
            if op == "vstore":
                # (op, addr, val): scalar address operand, vector value
                reads.append(args[1])
                reads.extend(vec(args[2]))
            # Add others as needed
        elif engine == "flow":
            op = args[0]
            if op == "vselect":
                # (op, dest, cond, a, b): all vectors
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
                reads.extend(vec(args[4]))
            elif op == "select":
                # (op, dest, cond, a, b): all scalars
                writes.append(args[1])
                reads.extend([args[2], args[3], args[4]])
            elif op == "add_imm":
                # (op, dest, a, imm)
                writes.append(args[1])
                reads.append(args[2])
            elif op in ("cond_jump", "cond_jump_rel"):
                # (op, cond, dest) -- only the condition is a data dependency.
                # NOTE(review): treated as data-only here; a control-flow
                # barrier may be required -- confirm against ISA semantics.
                reads.append(args[1])
        # pause, halt, etc. have no data dependencies (may still need barriers)
        return reads, writes

    def schedule(self):
        """Greedy priority list scheduling into SLOT_LIMITS-bounded cycles.

        Returns a list of cycles; each cycle maps engine -> [op args] issued
        in that cycle. Raises RuntimeError if the graph cannot make progress
        (impossible for DAGs built through add_op, but fail loudly rather
        than silently dropping ops).
        """
        self._calc_priorities()

        in_degree = {node: len(node.parents) for node in self.nodes}
        ready = []  # min-heap of (-priority, id, node); id breaks ties stably
        for node in self.nodes:
            if in_degree[node] == 0:
                heapq.heappush(ready, (-node.priority, node.id, node))

        instructions = []
        remaining = len(self.nodes)

        while remaining:
            if not ready:
                raise RuntimeError(
                    "Scheduler stuck: %d ops blocked with empty ready list"
                    % remaining
                )

            usage = dict.fromkeys(SLOT_LIMITS, 0)
            cycle_ops = defaultdict(list)
            issued = []
            deferred = []

            # Drain the heap in priority order; ops that do not fit this
            # cycle's slot limits are deferred back for the next cycle.
            while ready:
                entry = heapq.heappop(ready)
                node = entry[2]
                if usage[node.engine] < SLOT_LIMITS[node.engine]:
                    usage[node.engine] += 1
                    cycle_ops[node.engine].append(node.args)
                    issued.append(node)
                else:
                    deferred.append(entry)

            for entry in deferred:
                heapq.heappush(ready, entry)

            instructions.append(dict(cycle_ops))
            remaining -= len(issued)

            # Wake children whose last outstanding dependency just issued.
            for node in issued:
                for child in node.children:
                    in_degree[child] -= 1
                    if in_degree[child] == 0:
                        heapq.heappush(
                            ready, (-child.priority, child.id, child)
                        )

        return instructions

    def _calc_priorities(self):
        """Set node.priority to the node's critical-path height (leaf = 1).

        Iterative reverse-topological sweep: children are finalized before
        their parents. Avoids the recursion-depth limit a naive recursive
        longest-path computation hits on long dependency chains.
        """
        pending = {node: len(node.children) for node in self.nodes}
        stack = [node for node in self.nodes if pending[node] == 0]
        while stack:
            node = stack.pop()
            node.priority = 1 + max(
                (child.priority for child in node.children), default=0
            )
            for parent in node.parents:
                pending[parent] -= 1
                if pending[parent] == 0:
                    stack.append(parent)