Spaces:
Running
on
A10G
Running
on
A10G
| """ | |
| Read the top of perf_takehome.py for more introduction. | |
| This file is separate mostly for ease of copying it to freeze the machine and | |
| reference kernel for testing. | |
| """ | |
| from copy import copy | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import Any, Literal | |
| import random | |
| Engine = Literal["alu", "load", "store", "flow"] | |
| Instruction = dict[Engine, list[tuple]] | |
| class CoreState(Enum): | |
| RUNNING = 1 | |
| PAUSED = 2 | |
| STOPPED = 3 | |
| class Core: | |
| id: int | |
| scratch: list[int] | |
| trace_buf: list[int] | |
| pc: int = 0 | |
| state: CoreState = CoreState.RUNNING | |
| class DebugInfo: | |
| """ | |
| We give you some debug info but it's up to you to use it in Machine if you | |
| want to. You're also welcome to add more. | |
| """ | |
| # Maps scratch variable addr to (name, len) pair | |
| scratch_map: dict[int, (str, int)] | |
| def cdiv(a, b): | |
| return (a + b - 1) // b | |
| SLOT_LIMITS = { | |
| "alu": 12, | |
| "valu": 6, | |
| "load": 2, | |
| "store": 2, | |
| "flow": 1, | |
| "debug": 64, | |
| } | |
| VLEN = 8 | |
| # Older versions of the take-home used multiple cores, but this version only uses 1 | |
| N_CORES = 1 | |
| SCRATCH_SIZE = 1536 | |
| BASE_ADDR_TID = 100000 | |
| class Machine: | |
| """ | |
| Simulator for a custom VLIW SIMD architecture. | |
| VLIW (Very Large Instruction Word): Cores are composed of different | |
| "engines" each of which can execute multiple "slots" per cycle in parallel. | |
| How many slots each engine can execute per cycle is limited by SLOT_LIMITS. | |
| Effects of instructions don't take effect until the end of cycle. Each | |
| cycle, all engines execute all of their filled slots for that instruction. | |
| Effects like writes to memory take place after all the inputs are read. | |
| SIMD: There are instructions for acting on vectors of VLEN elements in a | |
| single slot. You can use vload and vstore to load multiple contiguous | |
| elements but not non-contiguous elements. Use vbroadcast to broadcast a | |
| scalar to a vector and then operate on vectors with valu instructions. | |
| The memory and scratch space are composed of 32-bit words. The solution is | |
| plucked out of the memory at the end of the program. You can think of the | |
| scratch space as serving the purpose of registers, constant memory, and a | |
| manually-managed cache. | |
| Here's an example of what an instruction might look like: | |
| {"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]} | |
| In general every number in an instruction is a scratch address except for | |
| const and jump, and except for store and some flow instructions the first | |
| operand is the destination. | |
| This comment is not meant to be full ISA documentation though, for the rest | |
| you should look through the simulator code. | |
| """ | |
| def __init__( | |
| self, | |
| mem_dump: list[int], | |
| program: list[Instruction], | |
| debug_info: DebugInfo, | |
| n_cores: int = 1, | |
| scratch_size: int = SCRATCH_SIZE, | |
| trace: bool = False, | |
| value_trace: dict[Any, int] = {}, | |
| ): | |
| self.cores = [ | |
| Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores) | |
| ] | |
| self.mem = copy(mem_dump) | |
| self.program = program | |
| self.debug_info = debug_info | |
| self.value_trace = value_trace | |
| self.prints = False | |
| self.cycle = 0 | |
| self.enable_pause = True | |
| self.enable_debug = True | |
| if trace: | |
| self.setup_trace() | |
| else: | |
| self.trace = None | |
| def rewrite_instr(self, instr): | |
| """ | |
| Rewrite an instruction to use scratch addresses instead of names | |
| """ | |
| res = {} | |
| for name, slots in instr.items(): | |
| res[name] = [] | |
| for slot in slots: | |
| res[name].append(self.rewrite_slot(slot)) | |
| return res | |
| def print_step(self, instr, core): | |
| # print(core.id) | |
| # print(core.trace_buf) | |
| print(self.scratch_map(core)) | |
| print(core.pc, instr, self.rewrite_instr(instr)) | |
| def scratch_map(self, core): | |
| res = {} | |
| for addr, (name, length) in self.debug_info.scratch_map.items(): | |
| res[name] = core.scratch[addr : addr + length] | |
| return res | |
| def rewrite_slot(self, slot): | |
| return tuple( | |
| self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot | |
| ) | |
| def setup_trace(self): | |
| """ | |
| The simulator generates traces in Chrome's Trace Event Format for | |
| visualization in Perfetto (or chrome://tracing if you prefer it). See | |
| the bottom of the file for info about how to use this. | |
| See the format docs in case you want to add more info to the trace: | |
| https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview | |
| """ | |
| self.trace = open("trace.json", "w") | |
| self.trace.write("[") | |
| tid_counter = 0 | |
| self.tids = {} | |
| for ci, core in enumerate(self.cores): | |
| self.trace.write( | |
| f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n' | |
| ) | |
| for name, limit in SLOT_LIMITS.items(): | |
| if name == "debug": | |
| continue | |
| for i in range(limit): | |
| tid_counter += 1 | |
| self.trace.write( | |
| f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n' | |
| ) | |
| self.tids[(ci, name, i)] = tid_counter | |
| # Add zero-length events at the start so all slots show up in Perfetto | |
| for ci, core in enumerate(self.cores): | |
| for name, limit in SLOT_LIMITS.items(): | |
| if name == "debug": | |
| continue | |
| for i in range(limit): | |
| tid = self.tids[(ci, name, i)] | |
| self.trace.write( | |
| f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n' | |
| ) | |
| for ci, core in enumerate(self.cores): | |
| self.trace.write( | |
| f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n' | |
| ) | |
| for addr, (name, length) in self.debug_info.scratch_map.items(): | |
| self.trace.write( | |
| f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n' | |
| ) | |
| def run(self): | |
| for core in self.cores: | |
| if core.state == CoreState.PAUSED: | |
| core.state = CoreState.RUNNING | |
| while any(c.state == CoreState.RUNNING for c in self.cores): | |
| has_non_debug = False | |
| for core in self.cores: | |
| if core.state != CoreState.RUNNING: | |
| continue | |
| if core.pc >= len(self.program): | |
| core.state = CoreState.STOPPED | |
| continue | |
| instr = self.program[core.pc] | |
| if self.prints: | |
| self.print_step(instr, core) | |
| core.pc += 1 | |
| self.step(instr, core) | |
| if any(name != "debug" for name in instr.keys()): | |
| has_non_debug = True | |
| if has_non_debug: | |
| self.cycle += 1 | |
| def alu(self, core, op, dest, a1, a2): | |
| a1 = core.scratch[a1] | |
| a2 = core.scratch[a2] | |
| match op: | |
| case "+": | |
| res = a1 + a2 | |
| case "-": | |
| res = a1 - a2 | |
| case "*": | |
| res = a1 * a2 | |
| case "//": | |
| res = a1 // a2 | |
| case "cdiv": | |
| res = cdiv(a1, a2) | |
| case "^": | |
| res = a1 ^ a2 | |
| case "&": | |
| res = a1 & a2 | |
| case "|": | |
| res = a1 | a2 | |
| case "<<": | |
| res = a1 << a2 | |
| case ">>": | |
| res = a1 >> a2 | |
| case "%": | |
| res = a1 % a2 | |
| case "<": | |
| res = int(a1 < a2) | |
| case "==": | |
| res = int(a1 == a2) | |
| case _: | |
| raise NotImplementedError(f"Unknown alu op {op}") | |
| res = res % (2**32) | |
| self.scratch_write[dest] = res | |
| def valu(self, core, *slot): | |
| match slot: | |
| case ("vbroadcast", dest, src): | |
| for i in range(VLEN): | |
| self.scratch_write[dest + i] = core.scratch[src] | |
| case ("multiply_add", dest, a, b, c): | |
| for i in range(VLEN): | |
| mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32) | |
| self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32) | |
| case (op, dest, a1, a2): | |
| for i in range(VLEN): | |
| self.alu(core, op, dest + i, a1 + i, a2 + i) | |
| case _: | |
| raise NotImplementedError(f"Unknown valu op {slot}") | |
| def load(self, core, *slot): | |
| match slot: | |
| case ("load", dest, addr): | |
| # print(dest, addr, core.scratch[addr]) | |
| self.scratch_write[dest] = self.mem[core.scratch[addr]] | |
| case ("load_offset", dest, addr, offset): | |
| # Handy for treating vector dest and addr as a full block in the mini-compiler if you want | |
| self.scratch_write[dest + offset] = self.mem[ | |
| core.scratch[addr + offset] | |
| ] | |
| case ("vload", dest, addr): # addr is a scalar | |
| addr = core.scratch[addr] | |
| for vi in range(VLEN): | |
| self.scratch_write[dest + vi] = self.mem[addr + vi] | |
| case ("const", dest, val): | |
| self.scratch_write[dest] = (val) % (2**32) | |
| case _: | |
| raise NotImplementedError(f"Unknown load op {slot}") | |
| def store(self, core, *slot): | |
| match slot: | |
| case ("store", addr, src): | |
| addr = core.scratch[addr] | |
| self.mem_write[addr] = core.scratch[src] | |
| case ("vstore", addr, src): # addr is a scalar | |
| addr = core.scratch[addr] | |
| for vi in range(VLEN): | |
| self.mem_write[addr + vi] = core.scratch[src + vi] | |
| case _: | |
| raise NotImplementedError(f"Unknown store op {slot}") | |
| def flow(self, core, *slot): | |
| match slot: | |
| case ("select", dest, cond, a, b): | |
| self.scratch_write[dest] = ( | |
| core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b] | |
| ) | |
| case ("add_imm", dest, a, imm): | |
| self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32) | |
| case ("vselect", dest, cond, a, b): | |
| for vi in range(VLEN): | |
| self.scratch_write[dest + vi] = ( | |
| core.scratch[a + vi] | |
| if core.scratch[cond + vi] != 0 | |
| else core.scratch[b + vi] | |
| ) | |
| case ("halt",): | |
| core.state = CoreState.STOPPED | |
| case ("pause",): | |
| if self.enable_pause: | |
| core.state = CoreState.PAUSED | |
| case ("trace_write", val): | |
| core.trace_buf.append(core.scratch[val]) | |
| case ("cond_jump", cond, addr): | |
| if core.scratch[cond] != 0: | |
| core.pc = addr | |
| case ("cond_jump_rel", cond, offset): | |
| if core.scratch[cond] != 0: | |
| core.pc += offset | |
| case ("jump", addr): | |
| core.pc = addr | |
| case ("jump_indirect", addr): | |
| core.pc = core.scratch[addr] | |
| case ("coreid", dest): | |
| self.scratch_write[dest] = core.id | |
| case _: | |
| raise NotImplementedError(f"Unknown flow op {slot}") | |
| def trace_post_step(self, instr, core): | |
| # You can add extra stuff to the trace if you want! | |
| for addr, (name, length) in self.debug_info.scratch_map.items(): | |
| if any((addr + vi) in self.scratch_write for vi in range(length)): | |
| val = str(core.scratch[addr : addr + length]) | |
| val = val.replace("[", "").replace("]", "") | |
| self.trace.write( | |
| f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n' | |
| ) | |
| def trace_slot(self, core, slot, name, i): | |
| self.trace.write( | |
| f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n' | |
| ) | |
| def step(self, instr: Instruction, core): | |
| """ | |
| Execute all the slots in each engine for a single instruction bundle | |
| """ | |
| ENGINE_FNS = { | |
| "alu": self.alu, | |
| "valu": self.valu, | |
| "load": self.load, | |
| "store": self.store, | |
| "flow": self.flow, | |
| } | |
| self.scratch_write = {} | |
| self.mem_write = {} | |
| for name, slots in instr.items(): | |
| if name == "debug": | |
| if not self.enable_debug: | |
| continue | |
| for slot in slots: | |
| if slot[0] == "compare": | |
| loc, key = slot[1], slot[2] | |
| ref = self.value_trace[key] | |
| res = core.scratch[loc] | |
| assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}" | |
| elif slot[0] == "vcompare": | |
| loc, keys = slot[1], slot[2] | |
| ref = [self.value_trace[key] for key in keys] | |
| res = core.scratch[loc : loc + VLEN] | |
| assert res == ref, ( | |
| f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}" | |
| ) | |
| continue | |
| assert len(slots) <= SLOT_LIMITS[name] | |
| for i, slot in enumerate(slots): | |
| if self.trace is not None: | |
| self.trace_slot(core, slot, name, i) | |
| ENGINE_FNS[name](core, *slot) | |
| for addr, val in self.scratch_write.items(): | |
| core.scratch[addr] = val | |
| for addr, val in self.mem_write.items(): | |
| self.mem[addr] = val | |
| if self.trace: | |
| self.trace_post_step(instr, core) | |
| del self.scratch_write | |
| del self.mem_write | |
| def __del__(self): | |
| if self.trace is not None: | |
| self.trace.write("]") | |
| self.trace.close() | |
| class Tree: | |
| """ | |
| An implicit perfect balanced binary tree with values on the nodes. | |
| """ | |
| height: int | |
| values: list[int] | |
| def generate(height: int): | |
| n_nodes = 2 ** (height + 1) - 1 | |
| values = [random.randint(0, 2**30 - 1) for _ in range(n_nodes)] | |
| return Tree(height, values) | |
| class Input: | |
| """ | |
| A batch of inputs, indices to nodes (starting as 0) and initial input | |
| values. We then iterate these for a specified number of rounds. | |
| """ | |
| indices: list[int] | |
| values: list[int] | |
| rounds: int | |
| def generate(forest: Tree, batch_size: int, rounds: int): | |
| indices = [0 for _ in range(batch_size)] | |
| values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)] | |
| return Input(indices, values, rounds) | |
| HASH_STAGES = [ | |
| ("+", 0x7ED55D16, "+", "<<", 12), | |
| ("^", 0xC761C23C, "^", ">>", 19), | |
| ("+", 0x165667B1, "+", "<<", 5), | |
| ("+", 0xD3A2646C, "^", "<<", 9), | |
| ("+", 0xFD7046C5, "+", "<<", 3), | |
| ("^", 0xB55A4F09, "^", ">>", 16), | |
| ] | |
| def myhash(a: int) -> int: | |
| """A simple 32-bit hash function""" | |
| fns = { | |
| "+": lambda x, y: x + y, | |
| "^": lambda x, y: x ^ y, | |
| "<<": lambda x, y: x << y, | |
| ">>": lambda x, y: x >> y, | |
| } | |
| def r(x): | |
| return x % (2**32) | |
| for op1, val1, op2, op3, val3 in HASH_STAGES: | |
| a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3)))) | |
| return a | |
| def reference_kernel(t: Tree, inp: Input): | |
| """ | |
| Reference implementation of the kernel. | |
| A parallel tree traversal where at each node we set | |
| cur_inp_val = myhash(cur_inp_val ^ node_val) | |
| and then choose the left branch if cur_inp_val is even. | |
| If we reach the bottom of the tree we wrap around to the top. | |
| """ | |
| for h in range(inp.rounds): | |
| for i in range(len(inp.indices)): | |
| idx = inp.indices[i] | |
| val = inp.values[i] | |
| val = myhash(val ^ t.values[idx]) | |
| idx = 2 * idx + (1 if val % 2 == 0 else 2) | |
| idx = 0 if idx >= len(t.values) else idx | |
| inp.values[i] = val | |
| inp.indices[i] = idx | |
| def build_mem_image(t: Tree, inp: Input) -> list[int]: | |
| """ | |
| Build a flat memory image of the problem. | |
| """ | |
| header = 7 | |
| extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32 | |
| mem = [0] * ( | |
| header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room | |
| ) | |
| forest_values_p = header | |
| inp_indices_p = forest_values_p + len(t.values) | |
| inp_values_p = inp_indices_p + len(inp.values) | |
| extra_room = inp_values_p + len(inp.values) | |
| mem[0] = inp.rounds | |
| mem[1] = len(t.values) | |
| mem[2] = len(inp.indices) | |
| mem[3] = t.height | |
| mem[4] = forest_values_p | |
| mem[5] = inp_indices_p | |
| mem[6] = inp_values_p | |
| mem[7] = extra_room | |
| mem[header:inp_indices_p] = t.values | |
| mem[inp_indices_p:inp_values_p] = inp.indices | |
| mem[inp_values_p:] = inp.values | |
| return mem | |
| def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int: | |
| """A simple 32-bit hash function""" | |
| fns = { | |
| "+": lambda x, y: x + y, | |
| "^": lambda x, y: x ^ y, | |
| "<<": lambda x, y: x << y, | |
| ">>": lambda x, y: x >> y, | |
| } | |
| def r(x): | |
| return x % (2**32) | |
| for i, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES): | |
| a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3)))) | |
| trace[(round, batch_i, "hash_stage", i)] = a | |
| return a | |
| def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}): | |
| """ | |
| Reference implementation of the kernel on a flat memory. | |
| """ | |
| # This is the initial memory layout | |
| rounds = mem[0] | |
| n_nodes = mem[1] | |
| batch_size = mem[2] | |
| forest_height = mem[3] | |
| # Offsets into the memory which indices get added to | |
| forest_values_p = mem[4] | |
| inp_indices_p = mem[5] | |
| inp_values_p = mem[6] | |
| yield mem | |
| for h in range(rounds): | |
| for i in range(batch_size): | |
| idx = mem[inp_indices_p + i] | |
| trace[(h, i, "idx")] = idx | |
| val = mem[inp_values_p + i] | |
| trace[(h, i, "val")] = val | |
| node_val = mem[forest_values_p + idx] | |
| trace[(h, i, "node_val")] = node_val | |
| val = myhash_traced(val ^ node_val, trace, h, i) | |
| trace[(h, i, "hashed_val")] = val | |
| idx = 2 * idx + (1 if val % 2 == 0 else 2) | |
| trace[(h, i, "next_idx")] = idx | |
| idx = 0 if idx >= n_nodes else idx | |
| trace[(h, i, "wrapped_idx")] = idx | |
| mem[inp_values_p + i] = val | |
| mem[inp_indices_p + i] = idx | |
| # You can add new yields or move this around for debugging | |
| # as long as it's matched by pause instructions. | |
| # The submission tests evaluate only on final memory. | |
| yield mem | |