|
|
""" |
|
|
Read the top of perf_takehome.py for more introduction. |
|
|
|
|
|
This file is separate mostly for ease of copying it to freeze the machine and |
|
|
reference kernel for testing. |
|
|
""" |
|
|
|
|
|
from copy import copy |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
from typing import Any, Literal |
|
|
import random |
|
|
|
|
|
# Names of the execution engines an instruction bundle may address. "valu"
# and "debug" are listed because SLOT_LIMITS and Machine.step accept them as
# bundle keys alongside the scalar engines.
Engine = Literal["alu", "valu", "load", "store", "flow", "debug"]
# One VLIW instruction bundle: engine name -> list of slot tuples for it.
Instruction = dict[Engine, list[tuple]]
|
|
|
|
|
|
|
|
class CoreState(Enum):
    """Execution state of a simulated core."""

    # The core fetches and executes instructions in Machine.run().
    RUNNING = 1
    # Set by the "pause" flow op; Machine.run() flips it back to RUNNING
    # the next time it is called.
    PAUSED = 2
    # Set by the "halt" flow op or when pc runs past the end of the program.
    STOPPED = 3
|
|
|
|
|
|
|
|
@dataclass
class Core:
    """Mutable per-core state used by the simulator."""

    # Core number; written to scratch by the "coreid" flow op and used as the
    # process id in trace output.
    id: int
    # Scratch space words, indexed by scratch address.
    scratch: list[int]
    # Values appended by the "trace_write" flow op.
    trace_buf: list[int]
    # Index into the program of the next instruction to execute.
    pc: int = 0
    # Current execution state; cores start out running.
    state: CoreState = CoreState.RUNNING
|
|
|
|
|
|
|
|
@dataclass
class DebugInfo:
    """
    We give you some debug info but it's up to you to use it in Machine if you
    want to. You're also welcome to add more.
    """

    # Maps a scratch address to the (name, length) pair of the variable that
    # starts there. Annotated with tuple[str, int]; the previous "(str, int)"
    # was a plain tuple object rather than a valid type expression.
    scratch_map: dict[int, tuple[str, int]]
|
|
|
|
|
|
|
|
def cdiv(a, b):
    """Return ``a / b`` rounded up, for integer ``a`` and positive integer ``b``."""
    # Adding b - 1 before flooring bumps any non-exact quotient up by one.
    padded = a + b - 1
    return padded // b
|
|
|
|
|
|
|
|
# Maximum number of slots per engine in a single instruction bundle.
# Machine.step asserts this limit for every engine except "debug", and
# Machine.setup_trace creates one trace thread per non-debug slot.
SLOT_LIMITS = dict(
    alu=12,
    valu=6,
    load=2,
    store=2,
    flow=1,
    debug=64,
)
|
|
|
|
|
# Number of elements processed by one vector slot (vload, vstore, valu, ...).
VLEN = 8

# Number of cores; Machine's n_cores parameter has the same default of 1.
N_CORES = 1
# Default number of scratch words allocated per core.
SCRATCH_SIZE = 1536
# Offset added to a scratch address to form the trace thread id of the
# scratch variable that starts at that address.
BASE_ADDR_TID = 100000
|
|
|
|
|
|
|
|
class Machine: |
|
|
""" |
|
|
Simulator for a custom VLIW SIMD architecture. |
|
|
|
|
|
VLIW (Very Large Instruction Word): Cores are composed of different |
|
|
"engines" each of which can execute multiple "slots" per cycle in parallel. |
|
|
How many slots each engine can execute per cycle is limited by SLOT_LIMITS. |
|
|
Effects of instructions don't take effect until the end of cycle. Each |
|
|
cycle, all engines execute all of their filled slots for that instruction. |
|
|
Effects like writes to memory take place after all the inputs are read. |
|
|
|
|
|
SIMD: There are instructions for acting on vectors of VLEN elements in a |
|
|
single slot. You can use vload and vstore to load multiple contiguous |
|
|
elements but not non-contiguous elements. Use vbroadcast to broadcast a |
|
|
scalar to a vector and then operate on vectors with valu instructions. |
|
|
|
|
|
The memory and scratch space are composed of 32-bit words. The solution is |
|
|
plucked out of the memory at the end of the program. You can think of the |
|
|
scratch space as serving the purpose of registers, constant memory, and a |
|
|
manually-managed cache. |
|
|
|
|
|
Here's an example of what an instruction might look like: |
|
|
|
|
|
{"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]} |
|
|
|
|
|
In general every number in an instruction is a scratch address except for |
|
|
const and jump, and except for store and some flow instructions the first |
|
|
operand is the destination. |
|
|
|
|
|
This comment is not meant to be full ISA documentation though, for the rest |
|
|
you should look through the simulator code. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
mem_dump: list[int], |
|
|
program: list[Instruction], |
|
|
debug_info: DebugInfo, |
|
|
n_cores: int = 1, |
|
|
scratch_size: int = SCRATCH_SIZE, |
|
|
trace: bool = False, |
|
|
value_trace: dict[Any, int] = {}, |
|
|
): |
|
|
self.cores = [ |
|
|
Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores) |
|
|
] |
|
|
self.mem = copy(mem_dump) |
|
|
self.program = program |
|
|
self.debug_info = debug_info |
|
|
self.value_trace = value_trace |
|
|
self.prints = False |
|
|
self.cycle = 0 |
|
|
self.enable_pause = True |
|
|
self.enable_debug = True |
|
|
if trace: |
|
|
self.setup_trace() |
|
|
else: |
|
|
self.trace = None |
|
|
|
|
|
def rewrite_instr(self, instr): |
|
|
""" |
|
|
Rewrite an instruction to use scratch addresses instead of names |
|
|
""" |
|
|
res = {} |
|
|
for name, slots in instr.items(): |
|
|
res[name] = [] |
|
|
for slot in slots: |
|
|
res[name].append(self.rewrite_slot(slot)) |
|
|
return res |
|
|
|
|
|
def print_step(self, instr, core): |
|
|
|
|
|
|
|
|
print(self.scratch_map(core)) |
|
|
print(core.pc, instr, self.rewrite_instr(instr)) |
|
|
|
|
|
def scratch_map(self, core): |
|
|
res = {} |
|
|
for addr, (name, length) in self.debug_info.scratch_map.items(): |
|
|
res[name] = core.scratch[addr : addr + length] |
|
|
return res |
|
|
|
|
|
def rewrite_slot(self, slot): |
|
|
return tuple( |
|
|
self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot |
|
|
) |
|
|
|
|
|
def setup_trace(self): |
|
|
""" |
|
|
The simulator generates traces in Chrome's Trace Event Format for |
|
|
visualization in Perfetto (or chrome://tracing if you prefer it). See |
|
|
the bottom of the file for info about how to use this. |
|
|
|
|
|
See the format docs in case you want to add more info to the trace: |
|
|
https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview |
|
|
""" |
|
|
self.trace = open("trace.json", "w") |
|
|
self.trace.write("[") |
|
|
tid_counter = 0 |
|
|
self.tids = {} |
|
|
for ci, core in enumerate(self.cores): |
|
|
self.trace.write( |
|
|
f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n' |
|
|
) |
|
|
for name, limit in SLOT_LIMITS.items(): |
|
|
if name == "debug": |
|
|
continue |
|
|
for i in range(limit): |
|
|
tid_counter += 1 |
|
|
self.trace.write( |
|
|
f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n' |
|
|
) |
|
|
self.tids[(ci, name, i)] = tid_counter |
|
|
|
|
|
|
|
|
for ci, core in enumerate(self.cores): |
|
|
for name, limit in SLOT_LIMITS.items(): |
|
|
if name == "debug": |
|
|
continue |
|
|
for i in range(limit): |
|
|
tid = self.tids[(ci, name, i)] |
|
|
self.trace.write( |
|
|
f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n' |
|
|
) |
|
|
for ci, core in enumerate(self.cores): |
|
|
self.trace.write( |
|
|
f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n' |
|
|
) |
|
|
for addr, (name, length) in self.debug_info.scratch_map.items(): |
|
|
self.trace.write( |
|
|
f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n' |
|
|
) |
|
|
|
|
|
def run(self): |
|
|
for core in self.cores: |
|
|
if core.state == CoreState.PAUSED: |
|
|
core.state = CoreState.RUNNING |
|
|
while any(c.state == CoreState.RUNNING for c in self.cores): |
|
|
has_non_debug = False |
|
|
for core in self.cores: |
|
|
if core.state != CoreState.RUNNING: |
|
|
continue |
|
|
if core.pc >= len(self.program): |
|
|
core.state = CoreState.STOPPED |
|
|
continue |
|
|
instr = self.program[core.pc] |
|
|
if self.prints: |
|
|
self.print_step(instr, core) |
|
|
core.pc += 1 |
|
|
self.step(instr, core) |
|
|
if any(name != "debug" for name in instr.keys()): |
|
|
has_non_debug = True |
|
|
if has_non_debug: |
|
|
self.cycle += 1 |
|
|
|
|
|
def alu(self, core, op, dest, a1, a2): |
|
|
a1 = core.scratch[a1] |
|
|
a2 = core.scratch[a2] |
|
|
match op: |
|
|
case "+": |
|
|
res = a1 + a2 |
|
|
case "-": |
|
|
res = a1 - a2 |
|
|
case "*": |
|
|
res = a1 * a2 |
|
|
case "//": |
|
|
res = a1 // a2 |
|
|
case "cdiv": |
|
|
res = cdiv(a1, a2) |
|
|
case "^": |
|
|
res = a1 ^ a2 |
|
|
case "&": |
|
|
res = a1 & a2 |
|
|
case "|": |
|
|
res = a1 | a2 |
|
|
case "<<": |
|
|
res = a1 << a2 |
|
|
case ">>": |
|
|
res = a1 >> a2 |
|
|
case "%": |
|
|
res = a1 % a2 |
|
|
case "<": |
|
|
res = int(a1 < a2) |
|
|
case "==": |
|
|
res = int(a1 == a2) |
|
|
case _: |
|
|
raise NotImplementedError(f"Unknown alu op {op}") |
|
|
res = res % (2**32) |
|
|
self.scratch_write[dest] = res |
|
|
|
|
|
def valu(self, core, *slot): |
|
|
match slot: |
|
|
case ("vbroadcast", dest, src): |
|
|
for i in range(VLEN): |
|
|
self.scratch_write[dest + i] = core.scratch[src] |
|
|
case ("multiply_add", dest, a, b, c): |
|
|
for i in range(VLEN): |
|
|
mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32) |
|
|
self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32) |
|
|
case (op, dest, a1, a2): |
|
|
for i in range(VLEN): |
|
|
self.alu(core, op, dest + i, a1 + i, a2 + i) |
|
|
case _: |
|
|
raise NotImplementedError(f"Unknown valu op {slot}") |
|
|
|
|
|
def load(self, core, *slot): |
|
|
match slot: |
|
|
case ("load", dest, addr): |
|
|
|
|
|
self.scratch_write[dest] = self.mem[core.scratch[addr]] |
|
|
case ("load_offset", dest, addr, offset): |
|
|
|
|
|
self.scratch_write[dest + offset] = self.mem[ |
|
|
core.scratch[addr + offset] |
|
|
] |
|
|
case ("vload", dest, addr): |
|
|
addr = core.scratch[addr] |
|
|
for vi in range(VLEN): |
|
|
self.scratch_write[dest + vi] = self.mem[addr + vi] |
|
|
case ("const", dest, val): |
|
|
self.scratch_write[dest] = (val) % (2**32) |
|
|
case _: |
|
|
raise NotImplementedError(f"Unknown load op {slot}") |
|
|
|
|
|
def store(self, core, *slot): |
|
|
match slot: |
|
|
case ("store", addr, src): |
|
|
addr = core.scratch[addr] |
|
|
self.mem_write[addr] = core.scratch[src] |
|
|
case ("vstore", addr, src): |
|
|
addr = core.scratch[addr] |
|
|
for vi in range(VLEN): |
|
|
self.mem_write[addr + vi] = core.scratch[src + vi] |
|
|
case _: |
|
|
raise NotImplementedError(f"Unknown store op {slot}") |
|
|
|
|
|
def flow(self, core, *slot): |
|
|
match slot: |
|
|
case ("select", dest, cond, a, b): |
|
|
self.scratch_write[dest] = ( |
|
|
core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b] |
|
|
) |
|
|
case ("add_imm", dest, a, imm): |
|
|
self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32) |
|
|
case ("vselect", dest, cond, a, b): |
|
|
for vi in range(VLEN): |
|
|
self.scratch_write[dest + vi] = ( |
|
|
core.scratch[a + vi] |
|
|
if core.scratch[cond + vi] != 0 |
|
|
else core.scratch[b + vi] |
|
|
) |
|
|
case ("halt",): |
|
|
core.state = CoreState.STOPPED |
|
|
case ("pause",): |
|
|
if self.enable_pause: |
|
|
core.state = CoreState.PAUSED |
|
|
case ("trace_write", val): |
|
|
core.trace_buf.append(core.scratch[val]) |
|
|
case ("cond_jump", cond, addr): |
|
|
if core.scratch[cond] != 0: |
|
|
core.pc = addr |
|
|
case ("cond_jump_rel", cond, offset): |
|
|
if core.scratch[cond] != 0: |
|
|
core.pc += offset |
|
|
case ("jump", addr): |
|
|
core.pc = addr |
|
|
case ("jump_indirect", addr): |
|
|
core.pc = core.scratch[addr] |
|
|
case ("coreid", dest): |
|
|
self.scratch_write[dest] = core.id |
|
|
case _: |
|
|
raise NotImplementedError(f"Unknown flow op {slot}") |
|
|
|
|
|
def trace_post_step(self, instr, core): |
|
|
|
|
|
for addr, (name, length) in self.debug_info.scratch_map.items(): |
|
|
if any((addr + vi) in self.scratch_write for vi in range(length)): |
|
|
val = str(core.scratch[addr : addr + length]) |
|
|
val = val.replace("[", "").replace("]", "") |
|
|
self.trace.write( |
|
|
f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n' |
|
|
) |
|
|
|
|
|
def trace_slot(self, core, slot, name, i): |
|
|
self.trace.write( |
|
|
f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n' |
|
|
) |
|
|
|
|
|
def step(self, instr: Instruction, core): |
|
|
""" |
|
|
Execute all the slots in each engine for a single instruction bundle |
|
|
""" |
|
|
ENGINE_FNS = { |
|
|
"alu": self.alu, |
|
|
"valu": self.valu, |
|
|
"load": self.load, |
|
|
"store": self.store, |
|
|
"flow": self.flow, |
|
|
} |
|
|
self.scratch_write = {} |
|
|
self.mem_write = {} |
|
|
for name, slots in instr.items(): |
|
|
if name == "debug": |
|
|
if not self.enable_debug: |
|
|
continue |
|
|
for slot in slots: |
|
|
if slot[0] == "compare": |
|
|
loc, key = slot[1], slot[2] |
|
|
ref = self.value_trace[key] |
|
|
res = core.scratch[loc] |
|
|
assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}" |
|
|
elif slot[0] == "vcompare": |
|
|
loc, keys = slot[1], slot[2] |
|
|
ref = [self.value_trace[key] for key in keys] |
|
|
res = core.scratch[loc : loc + VLEN] |
|
|
assert res == ref, ( |
|
|
f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}" |
|
|
) |
|
|
continue |
|
|
assert len(slots) <= SLOT_LIMITS[name] |
|
|
for i, slot in enumerate(slots): |
|
|
if self.trace is not None: |
|
|
self.trace_slot(core, slot, name, i) |
|
|
ENGINE_FNS[name](core, *slot) |
|
|
for addr, val in self.scratch_write.items(): |
|
|
core.scratch[addr] = val |
|
|
for addr, val in self.mem_write.items(): |
|
|
self.mem[addr] = val |
|
|
|
|
|
if self.trace: |
|
|
self.trace_post_step(instr, core) |
|
|
|
|
|
del self.scratch_write |
|
|
del self.mem_write |
|
|
|
|
|
def __del__(self): |
|
|
if self.trace is not None: |
|
|
self.trace.write("]") |
|
|
self.trace.close() |
|
|
|
|
|
|
|
|
@dataclass
class Tree:
    """
    An implicit perfect balanced binary tree with values on the nodes.
    """

    # Height given to generate(); such a tree holds 2 ** (height + 1) - 1 nodes.
    height: int
    # One value per node, stored as a flat list.
    values: list[int]

    @staticmethod
    def generate(height: int):
        """Build a tree of the given height filled with random 30-bit values."""
        node_count = 2 ** (height + 1) - 1
        largest = 2**30 - 1
        node_values = []
        for _ in range(node_count):
            node_values.append(random.randint(0, largest))
        return Tree(height, node_values)
|
|
|
|
|
|
|
|
@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial input
    values. We then iterate these for a specified number of rounds.
    """

    # Current node index of each batch element.
    indices: list[int]
    # Current value of each batch element.
    values: list[int]
    # Number of rounds to iterate the batch for.
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int):
        """
        Create a batch of ``batch_size`` inputs that all start at node 0 with
        random 30-bit values. ``forest`` is accepted but not consulted.
        """
        start_indices = [0] * batch_size
        largest = 2**30 - 1
        start_values = []
        for _ in range(batch_size):
            start_values.append(random.randint(0, largest))
        return Input(start_indices, start_values, rounds)
|
|
|
|
|
|
|
|
# Stages of the hash used by myhash / myhash_traced. Each entry is
# (op1, val1, op2, op3, val3) and maps a to (a op1 val1) op2 (a op3 val3),
# with every intermediate result reduced modulo 2**32.
HASH_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),
    ("^", 0xC761C23C, "^", ">>", 19),
    ("+", 0x165667B1, "+", "<<", 5),
    ("+", 0xD3A2646C, "^", "<<", 9),
    ("+", 0xFD7046C5, "+", "<<", 3),
    ("^", 0xB55A4F09, "^", ">>", 16),
]
|
|
|
|
|
|
|
|
def myhash(a: int) -> int:
    """A simple 32-bit hash function"""

    def add(x, y):
        return x + y

    def xor(x, y):
        return x ^ y

    def shl(x, y):
        return x << y

    def shr(x, y):
        return x >> y

    ops = {"+": add, "^": xor, "<<": shl, ">>": shr}
    # Masking with 2**32 - 1 is the same reduction as taking x % 2**32.
    word_mask = 2**32 - 1

    for op1, val1, op2, op3, val3 in HASH_STAGES:
        combine = ops[op2]
        left = ops[op1](a, val1) & word_mask
        right = ops[op3](a, val3) & word_mask
        a = combine(left, right) & word_mask

    return a
|
|
|
|
|
|
|
|
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.

    Updates ``inp.indices`` and ``inp.values`` in place; returns nothing.
    """
    for _ in range(inp.rounds):
        for slot in range(len(inp.indices)):
            node = inp.indices[slot]
            hashed = myhash(inp.values[slot] ^ t.values[node])
            # Left child is 2 * node + 1, right child is 2 * node + 2.
            if hashed % 2 == 0:
                child = 2 * node + 1
            else:
                child = 2 * node + 2
            # Walking off the bottom of the tree restarts at the root.
            if child >= len(t.values):
                child = 0
            inp.values[slot] = hashed
            inp.indices[slot] = child
|
|
|
|
|
|
|
|
def build_mem_image(t: Tree, inp: Input) -> list[int]:
    """
    Build a flat memory image of the problem.

    Layout (word addresses):
        0: rounds            1: number of tree nodes   2: batch size
        3: tree height       4: pointer to tree values
        5: pointer to input indices                    6: pointer to input values
    followed by the tree values, the input indices, the input values and a
    zero-filled extra region of
    ``len(t.values) + 2 * len(inp.indices) + 2 * VLEN + 32`` words.

    The header is exactly 7 words, so no pointer to the extra region is
    stored: word 7 is already the first tree value. The extra region starts
    right after the input values.

    Returns:
        The memory image as a new list of ints.
    """
    header = 7
    n_nodes = len(t.values)
    n_indices = len(inp.indices)
    n_values = len(inp.values)
    extra_len = n_nodes + n_indices * 2 + VLEN * 2 + 32

    forest_values_p = header
    inp_indices_p = forest_values_p + n_nodes
    # Offset by the length of the indices region that precedes the values.
    inp_values_p = inp_indices_p + n_indices
    extra_room_p = inp_values_p + n_values

    mem = [0] * (extra_room_p + extra_len)

    mem[0] = inp.rounds
    mem[1] = n_nodes
    mem[2] = n_indices
    mem[3] = t.height
    mem[4] = forest_values_p
    mem[5] = inp_indices_p
    mem[6] = inp_values_p

    mem[forest_values_p:inp_indices_p] = t.values
    mem[inp_indices_p:inp_values_p] = inp.indices
    # Bounded slice: an open-ended "mem[inp_values_p:] = ..." would replace
    # the whole tail and drop the extra region allocated above.
    mem[inp_values_p:extra_room_p] = inp.values
    return mem
|
|
|
|
|
|
|
|
def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
    """
    A simple 32-bit hash function

    Also records the value after each stage in ``trace`` under the key
    ``(round, batch_i, "hash_stage", stage_index)``.
    """

    def add(x, y):
        return x + y

    def xor(x, y):
        return x ^ y

    def shl(x, y):
        return x << y

    def shr(x, y):
        return x >> y

    ops = {"+": add, "^": xor, "<<": shl, ">>": shr}
    # Masking with 2**32 - 1 is the same reduction as taking x % 2**32.
    word_mask = 2**32 - 1

    for stage, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
        combine = ops[op2]
        left = ops[op1](a, val1) & word_mask
        right = ops[op3](a, val3) & word_mask
        a = combine(left, right) & word_mask
        trace[(round, batch_i, "hash_stage", stage)] = a

    return a
|
|
|
|
|
|
|
|
def reference_kernel2(mem: list[int], trace: dict[Any, int] | None = None):
    """
    Reference implementation of the kernel on a flat memory.

    This is a generator: it yields ``mem`` once before any work and once
    after all rounds have been applied. ``mem`` is updated in place.

    Args:
        mem: Memory image; words 0-2 hold rounds, node count and batch size,
            words 4-6 hold the pointers to the tree values, the input indices
            and the input values.
        trace: Optional dict that receives every intermediate value, keyed by
            ``(round, batch_index, label)``. A fresh dict is used when omitted,
            so separate calls never share trace entries.
    """
    if trace is None:
        trace = {}

    rounds = mem[0]
    n_nodes = mem[1]
    batch_size = mem[2]

    forest_values_p = mem[4]
    inp_indices_p = mem[5]
    inp_values_p = mem[6]
    yield mem
    for h in range(rounds):
        for i in range(batch_size):
            idx = mem[inp_indices_p + i]
            trace[(h, i, "idx")] = idx
            val = mem[inp_values_p + i]
            trace[(h, i, "val")] = val
            node_val = mem[forest_values_p + idx]
            trace[(h, i, "node_val")] = node_val
            val = myhash_traced(val ^ node_val, trace, h, i)
            trace[(h, i, "hashed_val")] = val
            # Even hash -> left child (2*idx + 1), odd -> right child (2*idx + 2).
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            trace[(h, i, "next_idx")] = idx
            # Past the last node: wrap around to the root.
            idx = 0 if idx >= n_nodes else idx
            trace[(h, i, "wrapped_idx")] = idx
            mem[inp_values_p + i] = val
            mem[inp_indices_p + i] = idx
    yield mem
|
|
|