File size: 8,826 Bytes
f3ce0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239

from collections import defaultdict, deque
import heapq

# Per-cycle issue-width limit for each execution engine: Scheduler.schedule()
# bundles at most this many ops of a given engine into one instruction cycle.
SLOT_LIMITS = {
    "alu": 12,
    "valu": 6,
    "load": 2,
    "store": 2,
    "flow": 1,
    "debug": 64,
}

class Node:
    """One operation in the scheduler's data-dependency DAG.

    Attributes:
        id: Creation index; doubles as the program-order tie-breaker.
        engine: Execution engine name ("alu", "valu", "load", ...).
        args: Opcode and operands, kept opaque at this level.
        desc: Optional human-readable label for debugging.
        parents: Nodes that must issue before this one.
        children: Nodes that depend on this one.
        priority: Critical-path length, filled in by the scheduler.
        latency: Cycles this op occupies (default 1).
    """

    def __init__(self, id, engine, args, desc=""):
        self.id = id
        self.engine = engine
        self.args = args  # Tuple of args
        self.desc = desc
        self.parents = []
        self.children = []
        self.priority = 0
        self.latency = 1  # Default latency

    def add_child(self, node):
        """Record a dependency edge: *node* may not issue before self."""
        self.children.append(node)
        node.parents.append(self)

    def __repr__(self):
        # Added for debuggability; the default object repr hid which op a
        # node represented when inspecting ready queues / dependency edges.
        return f"Node(id={self.id}, engine={self.engine!r}, desc={self.desc!r})"

class Scheduler:
    """Builds a data-dependency DAG of ops and list-schedules them into
    VLIW-style instruction bundles.

    Ops are appended in program order via :meth:`add_op`; RAW/WAW/WAR
    hazards on scratch addresses become dependency edges, and
    :meth:`schedule` greedily packs ready ops into cycles subject to the
    per-engine SLOT_LIMITS.
    """

    def __init__(self):
        self.nodes = []
        self.id_counter = 0  # next op id; ids encode program order
        # Per-address access history, in program order, used for hazards.
        self.scratch_reads = defaultdict(list)   # addr -> [nodes reading it]
        self.scratch_writes = defaultdict(list)  # addr -> [nodes writing it]

    def add_op(self, engine, args, desc=""):
        """Create a node for one op, wire its hazard edges, and return it.

        Edges always run from an earlier op to a later one, so the graph
        is acyclic by construction.
        """
        node = Node(self.id_counter, engine, args, desc)
        self.nodes.append(node)
        self.id_counter += 1

        reads, writes = self._get_rw(engine, args)

        # RAW (read-after-write): depend on the most recent writer of
        # every address we read.
        for addr in reads:
            writers = self.scratch_writes.get(addr)
            if writers:
                writers[-1].add_child(node)

        # WAW (write-after-write): keep writes to the same address in
        # program order by serializing against the last writer.
        for addr in writes:
            writers = self.scratch_writes.get(addr)
            if writers:
                writers[-1].add_child(node)

        # WAR (write-after-read): all earlier readers must complete
        # before this op clobbers the address.
        for addr in writes:
            for reader in self.scratch_reads.get(addr, ()):
                if reader is not node:  # don't depend on self
                    reader.add_child(node)

        # Record this op's accesses for hazard analysis of later ops.
        for addr in reads:
            self.scratch_reads[addr].append(node)
        for addr in writes:
            self.scratch_writes[addr].append(node)

        return node

    def _get_rw(self, engine, args):
        """Classify an op's operands into scratch reads and writes.

        Returns ``(reads, writes)`` as lists of scratch addresses. Vector
        operands expand to 8 consecutive addresses from the given base.
        Ops with no case here (e.g. the debug engine, pause/halt) report
        no data dependencies.
        """
        reads = []
        writes = []

        def vec(base):
            # One vector operand covers 8 consecutive scratch slots.
            return [base + i for i in range(8)]

        if engine == "alu":
            # (op, dest, a1, a2)
            _op, dest, a1, a2 = args
            writes.append(dest)
            reads.extend([a1, a2])
        elif engine == "valu":
            op = args[0]
            if op == "vbroadcast":
                # (op, dest, src): scalar src fans out to a vector dest.
                writes.extend(vec(args[1]))
                reads.append(args[2])
            elif op == "multiply_add":
                # (op, dest, a, b, c): all-vector fused multiply-add.
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
                reads.extend(vec(args[4]))
            else:
                # Generic binary vector op: (op, dest, a1, a2).
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
        elif engine == "load":
            op = args[0]
            if op == "const":
                writes.append(args[1])
            elif op == "load":
                writes.append(args[1])
                reads.append(args[2])
            elif op == "vload":
                writes.extend(vec(args[1]))
                reads.append(args[2])  # scalar address operand
            # NOTE(review): other load ops currently report no deps.
        elif engine == "store":
            op = args[0]
            if op == "vstore":
                reads.append(args[1])       # scalar address operand
                reads.extend(vec(args[2]))  # vector value being stored
            # NOTE(review): other store ops currently report no deps.
        elif engine == "flow":
            op = args[0]
            if op == "vselect":
                # (op, dest, cond, a, b), all vectors.
                writes.extend(vec(args[1]))
                reads.extend(vec(args[2]))
                reads.extend(vec(args[3]))
                reads.extend(vec(args[4]))
            elif op == "select":
                # (op, dest, cond, a, b), all scalars.
                writes.append(args[1])
                reads.extend([args[2], args[3], args[4]])
            elif op == "add_imm":
                # (op, dest, a, imm)
                writes.append(args[1])
                reads.append(args[2])
            elif op in ("cond_jump", "cond_jump_rel"):
                # (op, cond, dest): only the condition is a data read.
                reads.append(args[1])
                # NOTE(review): jumps are NOT treated as scheduling
                # barriers here — confirm that is intended.

        return reads, writes

    def schedule(self):
        """Pack the dependency DAG into cycles.

        Returns a list of bundles, one per cycle; each bundle maps an
        engine name to the list of op args issued that cycle. Within a
        cycle, ops are chosen longest-critical-path first (program order
        breaks ties) and capped per engine by SLOT_LIMITS.

        Raises:
            RuntimeError: if unissued ops remain but none are ready (a
                dependency cycle; impossible for graphs built only via
                add_op). The previous version silently DROPPED such ops
                whenever at least one cycle had already been emitted.
        """
        self._calc_priorities()

        # Min-heap of (-priority, id, node): highest priority first, with
        # the op id as a deterministic tie-breaker.
        ready = []
        in_degree = {}
        for node in self.nodes:
            in_degree[node] = len(node.parents)
            if in_degree[node] == 0:
                heapq.heappush(ready, (-node.priority, node.id, node))

        instructions = []
        remaining = len(self.nodes)  # ops not yet issued

        while remaining:
            if not ready:
                # No unissued op has all parents satisfied: progress is
                # impossible, so fail loudly instead of dropping work.
                raise RuntimeError("Deadlock or Cycle detected")

            cycle_ops = defaultdict(list)
            usage = dict.fromkeys(SLOT_LIMITS, 0)
            issued = []
            deferred = []

            # Drain the ready heap in priority order, issuing every op
            # that still fits this cycle's slot budget.
            while ready:
                entry = heapq.heappop(ready)
                node = entry[2]
                if usage[node.engine] < SLOT_LIMITS[node.engine]:
                    usage[node.engine] += 1
                    cycle_ops[node.engine].append(node.args)
                    issued.append(node)
                else:
                    deferred.append(entry)

            # Ops that did not fit stay ready for the next cycle.
            for entry in deferred:
                heapq.heappush(ready, entry)

            instructions.append(dict(cycle_ops))
            remaining -= len(issued)

            # Successors become ready no earlier than the NEXT cycle.
            for node in issued:
                for child in node.children:
                    in_degree[child] -= 1
                    if in_degree[child] == 0:
                        heapq.heappush(
                            ready, (-child.priority, child.id, child)
                        )

        return instructions

    def _calc_priorities(self):
        """Assign each node its critical-path length as its priority.

        priority(n) = n.latency + max(child priorities, default 0),
        computed iteratively in reverse-topological order (Kahn-style on
        out-degree) so long dependency chains cannot overflow Python's
        recursion limit, unlike the previous recursive version. With the
        default latency of 1 the values match the old longest-path count.
        """
        out_degree = {node: len(node.children) for node in self.nodes}
        stack = [node for node in self.nodes if not node.children]

        while stack:
            node = stack.pop()
            node.priority = node.latency + max(
                (child.priority for child in node.children), default=0
            )
            # A parent is finalized once all of its out-edges are done.
            for parent in node.parents:
                out_degree[parent] -= 1
                if out_degree[parent] == 0:
                    stack.append(parent)