Spaces:
Running
on
A10G
Running
on
A10G
Commit
·
1d07708
1
Parent(s):
3aa84d6
Initial commit: VLIW kernel optimizer via RL
Browse files- README.md +26 -5
- __pycache__/app.cpython-314.pyc +0 -0
- app.py +549 -0
- original_performance_takehome/.git_backup/HEAD +1 -0
- original_performance_takehome/.git_backup/config +13 -0
- original_performance_takehome/.git_backup/description +1 -0
- original_performance_takehome/.git_backup/hooks/applypatch-msg.sample +15 -0
- original_performance_takehome/.git_backup/hooks/commit-msg.sample +24 -0
- original_performance_takehome/.git_backup/hooks/fsmonitor-watchman.sample +174 -0
- original_performance_takehome/.git_backup/hooks/post-update.sample +8 -0
- original_performance_takehome/.git_backup/hooks/pre-applypatch.sample +14 -0
- original_performance_takehome/.git_backup/hooks/pre-commit.sample +49 -0
- original_performance_takehome/.git_backup/hooks/pre-merge-commit.sample +13 -0
- original_performance_takehome/.git_backup/hooks/pre-push.sample +53 -0
- original_performance_takehome/.git_backup/hooks/pre-rebase.sample +169 -0
- original_performance_takehome/.git_backup/hooks/pre-receive.sample +24 -0
- original_performance_takehome/.git_backup/hooks/prepare-commit-msg.sample +42 -0
- original_performance_takehome/.git_backup/hooks/push-to-checkout.sample +78 -0
- original_performance_takehome/.git_backup/hooks/sendemail-validate.sample +77 -0
- original_performance_takehome/.git_backup/hooks/update.sample +128 -0
- original_performance_takehome/.git_backup/index +0 -0
- original_performance_takehome/.git_backup/info/exclude +6 -0
- original_performance_takehome/.git_backup/logs/HEAD +1 -0
- original_performance_takehome/.git_backup/logs/refs/heads/main +1 -0
- original_performance_takehome/.git_backup/logs/refs/remotes/origin/HEAD +1 -0
- original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.idx +0 -0
- original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.pack +0 -0
- original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.rev +0 -0
- original_performance_takehome/.git_backup/packed-refs +4 -0
- original_performance_takehome/.git_backup/refs/heads/main +1 -0
- original_performance_takehome/.git_backup/refs/remotes/origin/HEAD +1 -0
- original_performance_takehome/.gitignore +4 -0
- original_performance_takehome/Readme.md +39 -0
- original_performance_takehome/perf_takehome.py +275 -0
- original_performance_takehome/problem.py +568 -0
- original_performance_takehome/tests/frozen_problem.py +568 -0
- original_performance_takehome/tests/submission_tests.py +119 -0
- original_performance_takehome/watch_trace.html +132 -0
- original_performance_takehome/watch_trace.py +84 -0
- requirements.txt +8 -0
README.md
CHANGED
|
@@ -1,12 +1,33 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: VLIW Kernel Optimizer
|
| 3 |
+
emoji: "⚡"
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.0.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# VLIW Kernel Optimization via Reinforcement Learning
|
| 14 |
+
|
| 15 |
+
Train a language model to generate optimized VLIW/SIMD kernels using test-time RL training.
|
| 16 |
+
|
| 17 |
+
## Goal
|
| 18 |
+
- **Baseline:** 147,734 cycles
|
| 19 |
+
- **Target:** 1,363 cycles (108x speedup)
|
| 20 |
+
|
| 21 |
+
## How it works
|
| 22 |
+
1. Model generates kernel code
|
| 23 |
+
2. Simulator evaluates cycle count
|
| 24 |
+
3. RL training improves the model based on rewards
|
| 25 |
+
|
| 26 |
+
## Usage
|
| 27 |
+
1. Select a model (Qwen2.5-Coder-7B recommended)
|
| 28 |
+
2. Configure training steps (50 recommended)
|
| 29 |
+
3. Click "Start Training"
|
| 30 |
+
4. Monitor progress - training continues even if you close the browser
|
| 31 |
+
|
| 32 |
+
## Hardware
|
| 33 |
+
Requires A10G GPU (HF Spaces Pro)
|
__pycache__/app.cpython-314.pyc
ADDED
|
Binary file (25.2 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF Spaces app for VLIW kernel optimization via RL.
|
| 3 |
+
Deploy to HF Spaces Pro (A10G GPU).
|
| 4 |
+
|
| 5 |
+
This is self-contained - includes verification logic inline.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import re
|
| 10 |
+
import threading
|
| 11 |
+
import time
|
| 12 |
+
import random
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
import gradio as gr
|
| 16 |
+
|
| 17 |
+
# Thread lock for safe state access
|
| 18 |
+
training_state_lock = threading.Lock()
|
| 19 |
+
|
| 20 |
+
# Add simulator path
|
| 21 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
PERF_TAKEHOME_PATH = os.path.join(SCRIPT_DIR, "original_performance_takehome")
|
| 23 |
+
if os.path.exists(PERF_TAKEHOME_PATH):
|
| 24 |
+
sys.path.insert(0, PERF_TAKEHOME_PATH)
|
| 25 |
+
|
| 26 |
+
# Constants
|
| 27 |
+
BASELINE_CYCLES = 147734
|
| 28 |
+
TARGET_CYCLES = 1363
|
| 29 |
+
SCORE_SCALE = 3000.0
|
| 30 |
+
|
| 31 |
+
# Training state (global)
|
| 32 |
+
training_state = {
|
| 33 |
+
"running": False,
|
| 34 |
+
"step": 0,
|
| 35 |
+
"total_steps": 0,
|
| 36 |
+
"best_cycles": BASELINE_CYCLES,
|
| 37 |
+
"best_code": None,
|
| 38 |
+
"log": [],
|
| 39 |
+
"start_time": None,
|
| 40 |
+
"results": [],
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
|
| 44 |
+
|
| 45 |
+
ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
|
| 46 |
+
|
| 47 |
+
API:
|
| 48 |
+
- alloc_scratch(name, length) -> addr
|
| 49 |
+
- scratch_const(val, name) -> addr
|
| 50 |
+
- add(engine, slot): engine in {alu, valu, load, store, flow}
|
| 51 |
+
- alu: (op, dst, src1, src2) where op in {+,-,*,/,%,^,&,|,==,!=,<,>,<=,>=}
|
| 52 |
+
- valu: same ops but on vectors (VLEN=8)
|
| 53 |
+
- load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
|
| 54 |
+
- store: (store,addr,src), (vstore,addr,src)
|
| 55 |
+
- flow: (select,dst,cond,t,f), (jump,label), (jump_if_zero,cond,label), (halt,)
|
| 56 |
+
- label(name): mark code position
|
| 57 |
+
- build(slots, vliw=True): pack slots into VLIW bundle
|
| 58 |
+
|
| 59 |
+
MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)
|
| 60 |
+
|
| 61 |
+
ALGORITHM: 16 rounds x 256 items: load idx,val; val=hash(val^tree[idx]); idx=2*idx+(1 or 2 based on val%2); store. Hash is 16 stages using HASH_STAGES constant.
|
| 62 |
+
|
| 63 |
+
OPTIMIZATION:
|
| 64 |
+
1. Use vload/vstore: process 8 elements per instruction (256/8 = 32 vector iterations)
|
| 65 |
+
2. Pack ops: 6 VALU slots = 6 vector ops per cycle
|
| 66 |
+
3. Unroll: minimize loop overhead
|
| 67 |
+
4. Pipeline: overlap loads with compute
|
| 68 |
+
|
| 69 |
+
You MUST override build_kernel() with actual instructions. Do NOT just call super().
|
| 70 |
+
'''
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_code_block(text: str) -> str:
|
| 74 |
+
"""Extract python code from markdown code blocks."""
|
| 75 |
+
pattern = r"```python\s*(.*?)```"
|
| 76 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 77 |
+
if matches:
|
| 78 |
+
return matches[-1].strip()
|
| 79 |
+
pattern = r"```\s*(.*?)```"
|
| 80 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 81 |
+
if matches:
|
| 82 |
+
return matches[-1].strip()
|
| 83 |
+
return text.strip()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def verify_perf_takehome(generation: str, score_scale: float = SCORE_SCALE) -> dict:
|
| 87 |
+
"""
|
| 88 |
+
Verify kernel code and return score.
|
| 89 |
+
Self-contained verification using the simulator.
|
| 90 |
+
"""
|
| 91 |
+
try:
|
| 92 |
+
code = generation.strip()
|
| 93 |
+
|
| 94 |
+
if not code:
|
| 95 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 96 |
+
"msg": "Empty code", "cycles": None}
|
| 97 |
+
|
| 98 |
+
if "def run" not in code:
|
| 99 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 100 |
+
"msg": "No 'run' function defined", "cycles": None}
|
| 101 |
+
|
| 102 |
+
# Build execution environment
|
| 103 |
+
exec_globals = {
|
| 104 |
+
"FOREST_HEIGHT": 10,
|
| 105 |
+
"ROUNDS": 16,
|
| 106 |
+
"BATCH_SIZE": 256,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
# Setup imports
|
| 110 |
+
setup_code = f'''
|
| 111 |
+
import sys
|
| 112 |
+
sys.path.insert(0, "{PERF_TAKEHOME_PATH}")
|
| 113 |
+
from problem import Machine, Tree, Input, build_mem_image, N_CORES, VLEN, reference_kernel2
|
| 114 |
+
from perf_takehome import KernelBuilder, HASH_STAGES, BASELINE
|
| 115 |
+
import random
|
| 116 |
+
'''
|
| 117 |
+
full_code = setup_code + "\n" + code
|
| 118 |
+
exec(full_code, exec_globals)
|
| 119 |
+
|
| 120 |
+
if "run" not in exec_globals:
|
| 121 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 122 |
+
"msg": "No 'run' function after exec", "cycles": None}
|
| 123 |
+
|
| 124 |
+
# Require OptimizedKernelBuilder
|
| 125 |
+
if "OptimizedKernelBuilder" not in exec_globals:
|
| 126 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 127 |
+
"msg": "No OptimizedKernelBuilder class", "cycles": None}
|
| 128 |
+
|
| 129 |
+
# Run verification
|
| 130 |
+
random.seed(123)
|
| 131 |
+
from problem import Tree, Input, Machine, build_mem_image, N_CORES, reference_kernel2
|
| 132 |
+
|
| 133 |
+
forest = Tree.generate(10)
|
| 134 |
+
inp = Input.generate(forest, 256, 16)
|
| 135 |
+
mem = build_mem_image(forest, inp)
|
| 136 |
+
|
| 137 |
+
# Get reference output
|
| 138 |
+
ref_mem = None
|
| 139 |
+
for ref_mem in reference_kernel2(list(mem)):
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
if ref_mem is None:
|
| 143 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 144 |
+
"msg": "Reference kernel failed", "cycles": None}
|
| 145 |
+
|
| 146 |
+
# Run submitted kernel
|
| 147 |
+
kb = exec_globals["OptimizedKernelBuilder"]()
|
| 148 |
+
kb.build_kernel(10, len(forest.values), 256, 16)
|
| 149 |
+
machine = Machine(list(mem), kb.instrs, kb.debug_info(), n_cores=N_CORES)
|
| 150 |
+
machine.enable_pause = False
|
| 151 |
+
machine.enable_debug = False
|
| 152 |
+
machine.run()
|
| 153 |
+
|
| 154 |
+
cycles = machine.cycle
|
| 155 |
+
|
| 156 |
+
# Validate cycles
|
| 157 |
+
if cycles <= 100:
|
| 158 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 159 |
+
"msg": f"Suspiciously low cycles ({cycles})", "cycles": cycles}
|
| 160 |
+
|
| 161 |
+
if cycles > 200000:
|
| 162 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 163 |
+
"msg": f"Cycles too high: {cycles}", "cycles": cycles}
|
| 164 |
+
|
| 165 |
+
# Compare outputs
|
| 166 |
+
inp_values_p = ref_mem[6]
|
| 167 |
+
expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 168 |
+
actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 169 |
+
|
| 170 |
+
if expected != actual:
|
| 171 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 172 |
+
"msg": f"Incorrect output (cycles={cycles})", "cycles": cycles}
|
| 173 |
+
|
| 174 |
+
# Success!
|
| 175 |
+
score = score_scale / cycles
|
| 176 |
+
return {
|
| 177 |
+
"score": score,
|
| 178 |
+
"correctness": 1.0,
|
| 179 |
+
"performance": -cycles,
|
| 180 |
+
"msg": f"Success: {cycles} cycles",
|
| 181 |
+
"cycles": cycles,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
except Exception as e:
|
| 185 |
+
import traceback
|
| 186 |
+
tb = traceback.format_exc()
|
| 187 |
+
error_line = tb.strip().split('\n')[-1][:200]
|
| 188 |
+
return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
|
| 189 |
+
"msg": f"Error: {error_line}", "cycles": None}
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def log(msg: str):
|
| 193 |
+
"""Add to training log (thread-safe)."""
|
| 194 |
+
timestamp = datetime.now().strftime("%H:%M:%S")
|
| 195 |
+
formatted = f"[{timestamp}] {msg}"
|
| 196 |
+
with training_state_lock:
|
| 197 |
+
training_state["log"].append(formatted)
|
| 198 |
+
print(formatted)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def reward_function(completions: list[str], **kwargs) -> list[float]:
|
| 202 |
+
"""Compute rewards for completions."""
|
| 203 |
+
rewards = []
|
| 204 |
+
for completion in completions:
|
| 205 |
+
try:
|
| 206 |
+
code = extract_code_block(completion)
|
| 207 |
+
result = verify_perf_takehome(code)
|
| 208 |
+
reward = result["score"]
|
| 209 |
+
|
| 210 |
+
if result["correctness"] > 0:
|
| 211 |
+
reward += 1.0
|
| 212 |
+
cycles = result.get("cycles")
|
| 213 |
+
if cycles:
|
| 214 |
+
with training_state_lock:
|
| 215 |
+
training_state["results"].append({
|
| 216 |
+
"step": training_state["step"],
|
| 217 |
+
"cycles": cycles,
|
| 218 |
+
"time": time.time() - (training_state["start_time"] or time.time())
|
| 219 |
+
})
|
| 220 |
+
if cycles < training_state["best_cycles"]:
|
| 221 |
+
training_state["best_cycles"] = cycles
|
| 222 |
+
training_state["best_code"] = code
|
| 223 |
+
speedup = BASELINE_CYCLES / cycles
|
| 224 |
+
log(f"NEW BEST: {cycles:,} cycles ({speedup:.2f}x speedup)")
|
| 225 |
+
|
| 226 |
+
rewards.append(reward)
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
log(f"Reward error: {str(e)[:100]}")
|
| 230 |
+
rewards.append(0.0)
|
| 231 |
+
|
| 232 |
+
return rewards
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def build_prompt(current_cycles: int = BASELINE_CYCLES, last_code: str = "") -> str:
|
| 236 |
+
"""Build training prompt."""
|
| 237 |
+
prompt = f"""{SYSTEM_PROMPT}
|
| 238 |
+
|
| 239 |
+
CURRENT: {current_cycles:,} cycles. TARGET: <{TARGET_CYCLES:,} cycles (need {current_cycles//TARGET_CYCLES}x speedup).
|
| 240 |
+
"""
|
| 241 |
+
if last_code:
|
| 242 |
+
prompt += f"""
|
| 243 |
+
Previous best attempt:
|
| 244 |
+
```python
|
| 245 |
+
{last_code[:2000]}
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
Improve this code to reduce cycles further.
|
| 249 |
+
"""
|
| 250 |
+
else:
|
| 251 |
+
prompt += """
|
| 252 |
+
Write a complete solution with:
|
| 253 |
+
1. A run() function that returns (cycles, code_string)
|
| 254 |
+
2. An OptimizedKernelBuilder class with build_kernel() method
|
| 255 |
+
"""
|
| 256 |
+
return prompt
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lora_rank: int):
|
| 260 |
+
"""Main training loop."""
|
| 261 |
+
global training_state
|
| 262 |
+
|
| 263 |
+
with training_state_lock:
|
| 264 |
+
training_state["running"] = True
|
| 265 |
+
training_state["step"] = 0
|
| 266 |
+
training_state["total_steps"] = num_steps
|
| 267 |
+
training_state["best_cycles"] = BASELINE_CYCLES
|
| 268 |
+
training_state["best_code"] = None
|
| 269 |
+
training_state["log"] = []
|
| 270 |
+
training_state["results"] = []
|
| 271 |
+
training_state["start_time"] = time.time()
|
| 272 |
+
|
| 273 |
+
log(f"Starting training: {model_name}")
|
| 274 |
+
log(f"Steps: {num_steps}, Batch: {batch_size}, LR: {lr}, LoRA rank: {lora_rank}")
|
| 275 |
+
|
| 276 |
+
try:
|
| 277 |
+
import torch
|
| 278 |
+
from datasets import Dataset
|
| 279 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
|
| 280 |
+
from peft import LoraConfig
|
| 281 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 282 |
+
|
| 283 |
+
# Check GPU
|
| 284 |
+
if torch.cuda.is_available():
|
| 285 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 286 |
+
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 287 |
+
log(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
|
| 288 |
+
else:
|
| 289 |
+
log("WARNING: No GPU detected!")
|
| 290 |
+
|
| 291 |
+
log("Loading tokenizer...")
|
| 292 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 293 |
+
if tokenizer.pad_token is None:
|
| 294 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 295 |
+
|
| 296 |
+
# Create dataset
|
| 297 |
+
prompt = build_prompt(BASELINE_CYCLES, "")
|
| 298 |
+
dataset = Dataset.from_dict({"prompt": [prompt] * 64})
|
| 299 |
+
|
| 300 |
+
# LoRA config
|
| 301 |
+
peft_config = LoraConfig(
|
| 302 |
+
r=lora_rank,
|
| 303 |
+
lora_alpha=lora_rank * 2,
|
| 304 |
+
lora_dropout=0.05,
|
| 305 |
+
bias="none",
|
| 306 |
+
task_type="CAUSAL_LM",
|
| 307 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 308 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Training config
|
| 312 |
+
output_dir = f"./output/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
| 313 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 314 |
+
|
| 315 |
+
training_args = GRPOConfig(
|
| 316 |
+
output_dir=output_dir,
|
| 317 |
+
num_train_epochs=num_steps,
|
| 318 |
+
per_device_train_batch_size=batch_size,
|
| 319 |
+
gradient_accumulation_steps=4,
|
| 320 |
+
learning_rate=lr,
|
| 321 |
+
logging_steps=1,
|
| 322 |
+
save_steps=10,
|
| 323 |
+
max_completion_length=2048,
|
| 324 |
+
max_prompt_length=2048,
|
| 325 |
+
temperature=0.7,
|
| 326 |
+
num_generations=4,
|
| 327 |
+
beta=0.1,
|
| 328 |
+
bf16=True,
|
| 329 |
+
report_to="none",
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Quantization for 7B model on A10G
|
| 333 |
+
quant_config = None
|
| 334 |
+
if "7B" in model_name or "7b" in model_name:
|
| 335 |
+
log("Using 4-bit quantization for 7B model")
|
| 336 |
+
quant_config = BitsAndBytesConfig(
|
| 337 |
+
load_in_4bit=True,
|
| 338 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 339 |
+
bnb_4bit_use_double_quant=True,
|
| 340 |
+
bnb_4bit_quant_type="nf4",
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
log("Loading model (this may take a few minutes)...")
|
| 344 |
+
|
| 345 |
+
model_kwargs = {}
|
| 346 |
+
if quant_config:
|
| 347 |
+
model_kwargs["quantization_config"] = quant_config
|
| 348 |
+
|
| 349 |
+
# Create stop callback
|
| 350 |
+
class StopCallback(TrainerCallback):
|
| 351 |
+
def on_step_end(self, args, state, control, **kwargs):
|
| 352 |
+
if not training_state["running"]:
|
| 353 |
+
log("Stop signal received, halting training...")
|
| 354 |
+
control.should_training_stop = True
|
| 355 |
+
return control
|
| 356 |
+
|
| 357 |
+
trainer = GRPOTrainer(
|
| 358 |
+
model=model_name,
|
| 359 |
+
reward_funcs=[reward_function],
|
| 360 |
+
args=training_args,
|
| 361 |
+
train_dataset=dataset,
|
| 362 |
+
peft_config=peft_config,
|
| 363 |
+
processing_class=tokenizer,
|
| 364 |
+
model_init_kwargs=model_kwargs,
|
| 365 |
+
callbacks=[StopCallback()],
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
log("Model loaded! Starting training...")
|
| 369 |
+
|
| 370 |
+
# Train
|
| 371 |
+
trainer.train()
|
| 372 |
+
|
| 373 |
+
log("Training complete!")
|
| 374 |
+
|
| 375 |
+
# Save
|
| 376 |
+
trainer.save_model(os.path.join(output_dir, "final"))
|
| 377 |
+
log(f"Model saved to {output_dir}/final")
|
| 378 |
+
|
| 379 |
+
# Save best code
|
| 380 |
+
if training_state["best_code"]:
|
| 381 |
+
with open(os.path.join(output_dir, "best_code.py"), "w") as f:
|
| 382 |
+
f.write(training_state["best_code"])
|
| 383 |
+
log("Best code saved!")
|
| 384 |
+
|
| 385 |
+
except Exception as e:
|
| 386 |
+
import traceback
|
| 387 |
+
log(f"ERROR: {str(e)}")
|
| 388 |
+
log(traceback.format_exc())
|
| 389 |
+
|
| 390 |
+
finally:
|
| 391 |
+
with training_state_lock:
|
| 392 |
+
training_state["running"] = False
|
| 393 |
+
elapsed = time.time() - training_state["start_time"]
|
| 394 |
+
best = training_state["best_cycles"]
|
| 395 |
+
log(f"Total time: {elapsed/60:.1f} minutes")
|
| 396 |
+
log(f"Best result: {best:,} cycles")
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def start_training(model_name, num_steps, batch_size, lr, lora_rank):
|
| 400 |
+
"""Start training in background."""
|
| 401 |
+
if training_state["running"]:
|
| 402 |
+
return "Training already running!"
|
| 403 |
+
|
| 404 |
+
thread = threading.Thread(
|
| 405 |
+
target=run_training,
|
| 406 |
+
args=(model_name, int(num_steps), int(batch_size), float(lr), int(lora_rank)),
|
| 407 |
+
daemon=False # Non-daemon to ensure training completes
|
| 408 |
+
)
|
| 409 |
+
thread.start()
|
| 410 |
+
return "Training started! Monitor progress below."
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def stop_training():
|
| 414 |
+
"""Signal training to stop."""
|
| 415 |
+
with training_state_lock:
|
| 416 |
+
training_state["running"] = False
|
| 417 |
+
return "Stop signal sent. Training will stop after current step."
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def get_status():
|
| 421 |
+
"""Get current status as markdown."""
|
| 422 |
+
if not training_state["start_time"]:
|
| 423 |
+
return "### Status: Not started\n\nConfigure settings and click Start Training."
|
| 424 |
+
|
| 425 |
+
with training_state_lock:
|
| 426 |
+
elapsed = time.time() - training_state["start_time"]
|
| 427 |
+
elapsed_str = f"{elapsed/60:.1f} min"
|
| 428 |
+
best_cycles = max(training_state["best_cycles"], 1) # Prevent division by zero
|
| 429 |
+
is_running = training_state["running"]
|
| 430 |
+
log_lines = training_state["log"][-15:]
|
| 431 |
+
|
| 432 |
+
speedup = BASELINE_CYCLES / best_cycles
|
| 433 |
+
progress_pct = (1 - best_cycles / BASELINE_CYCLES) * 100
|
| 434 |
+
|
| 435 |
+
status = f"""### Status: {'Running' if is_running else 'Stopped'}
|
| 436 |
+
|
| 437 |
+
| Metric | Value |
|
| 438 |
+
|--------|-------|
|
| 439 |
+
| Elapsed | {elapsed_str} |
|
| 440 |
+
| Best Cycles | **{best_cycles:,}** |
|
| 441 |
+
| Speedup | **{speedup:.2f}x** |
|
| 442 |
+
| Progress to Target | {progress_pct:.1f}% |
|
| 443 |
+
| Target | {TARGET_CYCLES:,} cycles |
|
| 444 |
+
|
| 445 |
+
---
|
| 446 |
+
|
| 447 |
+
### Recent Log
|
| 448 |
+
```
|
| 449 |
+
{chr(10).join(log_lines)}
|
| 450 |
+
```
|
| 451 |
+
"""
|
| 452 |
+
return status
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def get_best_code():
|
| 456 |
+
"""Get best code found."""
|
| 457 |
+
with training_state_lock:
|
| 458 |
+
best_code = training_state["best_code"]
|
| 459 |
+
if best_code:
|
| 460 |
+
return best_code
|
| 461 |
+
return "# No valid code found yet.\n# Start training to generate optimized kernels."
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def get_results_chart():
|
| 465 |
+
"""Get results as simple text chart."""
|
| 466 |
+
with training_state_lock:
|
| 467 |
+
results = list(training_state["results"][-20:])
|
| 468 |
+
|
| 469 |
+
if not results:
|
| 470 |
+
return "No results yet."
|
| 471 |
+
|
| 472 |
+
lines = ["Cycles over time:", ""]
|
| 473 |
+
for r in results:
|
| 474 |
+
bar_len = max(1, int(50 * r["cycles"] / BASELINE_CYCLES))
|
| 475 |
+
bar = "#" * bar_len
|
| 476 |
+
lines.append(f"{r['cycles']:>7,} | {bar}")
|
| 477 |
+
|
| 478 |
+
return "\n".join(lines)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# Build Gradio UI
|
| 482 |
+
with gr.Blocks(title="VLIW Kernel Optimizer", theme=gr.themes.Soft()) as demo:
|
| 483 |
+
gr.Markdown("""
|
| 484 |
+
# VLIW Kernel Optimization via Reinforcement Learning
|
| 485 |
+
|
| 486 |
+
Train a language model to generate optimized VLIW/SIMD kernels.
|
| 487 |
+
|
| 488 |
+
| Baseline | Target | Goal |
|
| 489 |
+
|----------|--------|------|
|
| 490 |
+
| 147,734 cycles | 1,363 cycles | 108x speedup |
|
| 491 |
+
""")
|
| 492 |
+
|
| 493 |
+
with gr.Row():
|
| 494 |
+
with gr.Column(scale=1):
|
| 495 |
+
gr.Markdown("### Configuration")
|
| 496 |
+
|
| 497 |
+
model_dropdown = gr.Dropdown(
|
| 498 |
+
choices=[
|
| 499 |
+
"Qwen/Qwen2.5-Coder-7B-Instruct",
|
| 500 |
+
"Qwen/Qwen2.5-Coder-3B-Instruct",
|
| 501 |
+
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
|
| 502 |
+
"deepseek-ai/deepseek-coder-6.7b-instruct",
|
| 503 |
+
"codellama/CodeLlama-7b-Instruct-hf",
|
| 504 |
+
],
|
| 505 |
+
value="Qwen/Qwen2.5-Coder-7B-Instruct",
|
| 506 |
+
label="Model"
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
steps_slider = gr.Slider(1, 100, value=50, step=1, label="Training Steps")
|
| 510 |
+
batch_slider = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
|
| 511 |
+
lr_input = gr.Number(value=2e-4, label="Learning Rate")
|
| 512 |
+
lora_slider = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
|
| 513 |
+
|
| 514 |
+
with gr.Row():
|
| 515 |
+
start_btn = gr.Button("Start Training", variant="primary", size="lg")
|
| 516 |
+
stop_btn = gr.Button("Stop", variant="stop")
|
| 517 |
+
|
| 518 |
+
with gr.Column(scale=2):
|
| 519 |
+
status_md = gr.Markdown("### Status: Not started")
|
| 520 |
+
refresh_btn = gr.Button("Refresh", size="sm")
|
| 521 |
+
|
| 522 |
+
with gr.Row():
|
| 523 |
+
with gr.Column():
|
| 524 |
+
gr.Markdown("### Best Code Found")
|
| 525 |
+
code_output = gr.Code(language="python", lines=25)
|
| 526 |
+
code_btn = gr.Button("Show Best Code")
|
| 527 |
+
|
| 528 |
+
with gr.Column():
|
| 529 |
+
gr.Markdown("### Results")
|
| 530 |
+
results_output = gr.Textbox(lines=15, label="Cycles Progress")
|
| 531 |
+
results_btn = gr.Button("Show Results")
|
| 532 |
+
|
| 533 |
+
# Event handlers
|
| 534 |
+
start_btn.click(
|
| 535 |
+
start_training,
|
| 536 |
+
inputs=[model_dropdown, steps_slider, batch_slider, lr_input, lora_slider],
|
| 537 |
+
outputs=[status_md]
|
| 538 |
+
)
|
| 539 |
+
stop_btn.click(stop_training, outputs=[status_md])
|
| 540 |
+
refresh_btn.click(get_status, outputs=[status_md])
|
| 541 |
+
code_btn.click(get_best_code, outputs=[code_output])
|
| 542 |
+
results_btn.click(get_results_chart, outputs=[results_output])
|
| 543 |
+
|
| 544 |
+
# Auto-refresh
|
| 545 |
+
demo.load(get_status, outputs=[status_md], every=5)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
if __name__ == "__main__":
|
| 549 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
original_performance_takehome/.git_backup/HEAD
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ref: refs/heads/main
|
original_performance_takehome/.git_backup/config
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[core]
|
| 2 |
+
repositoryformatversion = 0
|
| 3 |
+
filemode = true
|
| 4 |
+
bare = false
|
| 5 |
+
logallrefupdates = true
|
| 6 |
+
ignorecase = true
|
| 7 |
+
precomposeunicode = true
|
| 8 |
+
[remote "origin"]
|
| 9 |
+
url = https://github.com/anthropics/original_performance_takehome.git
|
| 10 |
+
fetch = +refs/heads/*:refs/remotes/origin/*
|
| 11 |
+
[branch "main"]
|
| 12 |
+
remote = origin
|
| 13 |
+
merge = refs/heads/main
|
original_performance_takehome/.git_backup/description
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Unnamed repository; edit this file 'description' to name the repository.
|
original_performance_takehome/.git_backup/hooks/applypatch-msg.sample
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to check the commit log message taken by
|
| 4 |
+
# applypatch from an e-mail message.
|
| 5 |
+
#
|
| 6 |
+
# The hook should exit with non-zero status after issuing an
|
| 7 |
+
# appropriate message if it wants to stop the commit. The hook is
|
| 8 |
+
# allowed to edit the commit message file.
|
| 9 |
+
#
|
| 10 |
+
# To enable this hook, rename this file to "applypatch-msg".
|
| 11 |
+
|
| 12 |
+
. git-sh-setup
|
| 13 |
+
commitmsg="$(git rev-parse --git-path hooks/commit-msg)"
|
| 14 |
+
test -x "$commitmsg" && exec "$commitmsg" ${1+"$@"}
|
| 15 |
+
:
|
original_performance_takehome/.git_backup/hooks/commit-msg.sample
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to check the commit log message.
|
| 4 |
+
# Called by "git commit" with one argument, the name of the file
|
| 5 |
+
# that has the commit message. The hook should exit with non-zero
|
| 6 |
+
# status after issuing an appropriate message if it wants to stop the
|
| 7 |
+
# commit. The hook is allowed to edit the commit message file.
|
| 8 |
+
#
|
| 9 |
+
# To enable this hook, rename this file to "commit-msg".
|
| 10 |
+
|
| 11 |
+
# Uncomment the below to add a Signed-off-by line to the message.
|
| 12 |
+
# Doing this in a hook is a bad idea in general, but the prepare-commit-msg
|
| 13 |
+
# hook is more suited to it.
|
| 14 |
+
#
|
| 15 |
+
# SOB=$(git var GIT_AUTHOR_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p')
|
| 16 |
+
# grep -qs "^$SOB" "$1" || echo "$SOB" >> "$1"
|
| 17 |
+
|
| 18 |
+
# This example catches duplicate Signed-off-by lines.
|
| 19 |
+
|
| 20 |
+
test "" = "$(grep '^Signed-off-by: ' "$1" |
|
| 21 |
+
sort | uniq -c | sed -e '/^[ ]*1[ ]/d')" || {
|
| 22 |
+
echo >&2 Duplicate Signed-off-by lines.
|
| 23 |
+
exit 1
|
| 24 |
+
}
|
original_performance_takehome/.git_backup/hooks/fsmonitor-watchman.sample
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/perl
|
| 2 |
+
|
| 3 |
+
use strict;
|
| 4 |
+
use warnings;
|
| 5 |
+
use IPC::Open2;
|
| 6 |
+
|
| 7 |
+
# An example hook script to integrate Watchman
|
| 8 |
+
# (https://facebook.github.io/watchman/) with git to speed up detecting
|
| 9 |
+
# new and modified files.
|
| 10 |
+
#
|
| 11 |
+
# The hook is passed a version (currently 2) and last update token
|
| 12 |
+
# formatted as a string and outputs to stdout a new update token and
|
| 13 |
+
# all files that have been modified since the update token. Paths must
|
| 14 |
+
# be relative to the root of the working tree and separated by a single NUL.
|
| 15 |
+
#
|
| 16 |
+
# To enable this hook, rename this file to "query-watchman" and set
|
| 17 |
+
# 'git config core.fsmonitor .git/hooks/query-watchman'
|
| 18 |
+
#
|
| 19 |
+
my ($version, $last_update_token) = @ARGV;
|
| 20 |
+
|
| 21 |
+
# Uncomment for debugging
|
| 22 |
+
# print STDERR "$0 $version $last_update_token\n";
|
| 23 |
+
|
| 24 |
+
# Check the hook interface version
|
| 25 |
+
if ($version ne 2) {
|
| 26 |
+
die "Unsupported query-fsmonitor hook version '$version'.\n" .
|
| 27 |
+
"Falling back to scanning...\n";
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
my $git_work_tree = get_working_dir();
|
| 31 |
+
|
| 32 |
+
my $retry = 1;
|
| 33 |
+
|
| 34 |
+
my $json_pkg;
|
| 35 |
+
eval {
|
| 36 |
+
require JSON::XS;
|
| 37 |
+
$json_pkg = "JSON::XS";
|
| 38 |
+
1;
|
| 39 |
+
} or do {
|
| 40 |
+
require JSON::PP;
|
| 41 |
+
$json_pkg = "JSON::PP";
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
launch_watchman();
|
| 45 |
+
|
| 46 |
+
sub launch_watchman {
|
| 47 |
+
my $o = watchman_query();
|
| 48 |
+
if (is_work_tree_watched($o)) {
|
| 49 |
+
output_result($o->{clock}, @{$o->{files}});
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
sub output_result {
|
| 54 |
+
my ($clockid, @files) = @_;
|
| 55 |
+
|
| 56 |
+
# Uncomment for debugging watchman output
|
| 57 |
+
# open (my $fh, ">", ".git/watchman-output.out");
|
| 58 |
+
# binmode $fh, ":utf8";
|
| 59 |
+
# print $fh "$clockid\n@files\n";
|
| 60 |
+
# close $fh;
|
| 61 |
+
|
| 62 |
+
binmode STDOUT, ":utf8";
|
| 63 |
+
print $clockid;
|
| 64 |
+
print "\0";
|
| 65 |
+
local $, = "\0";
|
| 66 |
+
print @files;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
sub watchman_clock {
|
| 70 |
+
my $response = qx/watchman clock "$git_work_tree"/;
|
| 71 |
+
die "Failed to get clock id on '$git_work_tree'.\n" .
|
| 72 |
+
"Falling back to scanning...\n" if $? != 0;
|
| 73 |
+
|
| 74 |
+
return $json_pkg->new->utf8->decode($response);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
sub watchman_query {
|
| 78 |
+
my $pid = open2(\*CHLD_OUT, \*CHLD_IN, 'watchman -j --no-pretty')
|
| 79 |
+
or die "open2() failed: $!\n" .
|
| 80 |
+
"Falling back to scanning...\n";
|
| 81 |
+
|
| 82 |
+
# In the query expression below we're asking for names of files that
|
| 83 |
+
# changed since $last_update_token but not from the .git folder.
|
| 84 |
+
#
|
| 85 |
+
# To accomplish this, we're using the "since" generator to use the
|
| 86 |
+
# recency index to select candidate nodes and "fields" to limit the
|
| 87 |
+
# output to file names only. Then we're using the "expression" term to
|
| 88 |
+
# further constrain the results.
|
| 89 |
+
my $last_update_line = "";
|
| 90 |
+
if (substr($last_update_token, 0, 1) eq "c") {
|
| 91 |
+
$last_update_token = "\"$last_update_token\"";
|
| 92 |
+
$last_update_line = qq[\n"since": $last_update_token,];
|
| 93 |
+
}
|
| 94 |
+
my $query = <<" END";
|
| 95 |
+
["query", "$git_work_tree", {$last_update_line
|
| 96 |
+
"fields": ["name"],
|
| 97 |
+
"expression": ["not", ["dirname", ".git"]]
|
| 98 |
+
}]
|
| 99 |
+
END
|
| 100 |
+
|
| 101 |
+
# Uncomment for debugging the watchman query
|
| 102 |
+
# open (my $fh, ">", ".git/watchman-query.json");
|
| 103 |
+
# print $fh $query;
|
| 104 |
+
# close $fh;
|
| 105 |
+
|
| 106 |
+
print CHLD_IN $query;
|
| 107 |
+
close CHLD_IN;
|
| 108 |
+
my $response = do {local $/; <CHLD_OUT>};
|
| 109 |
+
|
| 110 |
+
# Uncomment for debugging the watch response
|
| 111 |
+
# open ($fh, ">", ".git/watchman-response.json");
|
| 112 |
+
# print $fh $response;
|
| 113 |
+
# close $fh;
|
| 114 |
+
|
| 115 |
+
die "Watchman: command returned no output.\n" .
|
| 116 |
+
"Falling back to scanning...\n" if $response eq "";
|
| 117 |
+
die "Watchman: command returned invalid output: $response\n" .
|
| 118 |
+
"Falling back to scanning...\n" unless $response =~ /^\{/;
|
| 119 |
+
|
| 120 |
+
return $json_pkg->new->utf8->decode($response);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
sub is_work_tree_watched {
|
| 124 |
+
my ($output) = @_;
|
| 125 |
+
my $error = $output->{error};
|
| 126 |
+
if ($retry > 0 and $error and $error =~ m/unable to resolve root .* directory (.*) is not watched/) {
|
| 127 |
+
$retry--;
|
| 128 |
+
my $response = qx/watchman watch "$git_work_tree"/;
|
| 129 |
+
die "Failed to make watchman watch '$git_work_tree'.\n" .
|
| 130 |
+
"Falling back to scanning...\n" if $? != 0;
|
| 131 |
+
$output = $json_pkg->new->utf8->decode($response);
|
| 132 |
+
$error = $output->{error};
|
| 133 |
+
die "Watchman: $error.\n" .
|
| 134 |
+
"Falling back to scanning...\n" if $error;
|
| 135 |
+
|
| 136 |
+
# Uncomment for debugging watchman output
|
| 137 |
+
# open (my $fh, ">", ".git/watchman-output.out");
|
| 138 |
+
# close $fh;
|
| 139 |
+
|
| 140 |
+
# Watchman will always return all files on the first query so
|
| 141 |
+
# return the fast "everything is dirty" flag to git and do the
|
| 142 |
+
# Watchman query just to get it over with now so we won't pay
|
| 143 |
+
# the cost in git to look up each individual file.
|
| 144 |
+
my $o = watchman_clock();
|
| 145 |
+
$error = $output->{error};
|
| 146 |
+
|
| 147 |
+
die "Watchman: $error.\n" .
|
| 148 |
+
"Falling back to scanning...\n" if $error;
|
| 149 |
+
|
| 150 |
+
output_result($o->{clock}, ("/"));
|
| 151 |
+
$last_update_token = $o->{clock};
|
| 152 |
+
|
| 153 |
+
eval { launch_watchman() };
|
| 154 |
+
return 0;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
die "Watchman: $error.\n" .
|
| 158 |
+
"Falling back to scanning...\n" if $error;
|
| 159 |
+
|
| 160 |
+
return 1;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
sub get_working_dir {
|
| 164 |
+
my $working_dir;
|
| 165 |
+
if ($^O =~ 'msys' || $^O =~ 'cygwin') {
|
| 166 |
+
$working_dir = Win32::GetCwd();
|
| 167 |
+
$working_dir =~ tr/\\/\//;
|
| 168 |
+
} else {
|
| 169 |
+
require Cwd;
|
| 170 |
+
$working_dir = Cwd::cwd();
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
return $working_dir;
|
| 174 |
+
}
|
original_performance_takehome/.git_backup/hooks/post-update.sample
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to prepare a packed repository for use over
|
| 4 |
+
# dumb transports.
|
| 5 |
+
#
|
| 6 |
+
# To enable this hook, rename this file to "post-update".
|
| 7 |
+
|
| 8 |
+
exec git update-server-info
|
original_performance_takehome/.git_backup/hooks/pre-applypatch.sample
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to verify what is about to be committed
|
| 4 |
+
# by applypatch from an e-mail message.
|
| 5 |
+
#
|
| 6 |
+
# The hook should exit with non-zero status after issuing an
|
| 7 |
+
# appropriate message if it wants to stop the commit.
|
| 8 |
+
#
|
| 9 |
+
# To enable this hook, rename this file to "pre-applypatch".
|
| 10 |
+
|
| 11 |
+
. git-sh-setup
|
| 12 |
+
precommit="$(git rev-parse --git-path hooks/pre-commit)"
|
| 13 |
+
test -x "$precommit" && exec "$precommit" ${1+"$@"}
|
| 14 |
+
:
|
original_performance_takehome/.git_backup/hooks/pre-commit.sample
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to verify what is about to be committed.
|
| 4 |
+
# Called by "git commit" with no arguments. The hook should
|
| 5 |
+
# exit with non-zero status after issuing an appropriate message if
|
| 6 |
+
# it wants to stop the commit.
|
| 7 |
+
#
|
| 8 |
+
# To enable this hook, rename this file to "pre-commit".
|
| 9 |
+
|
| 10 |
+
if git rev-parse --verify HEAD >/dev/null 2>&1
|
| 11 |
+
then
|
| 12 |
+
against=HEAD
|
| 13 |
+
else
|
| 14 |
+
# Initial commit: diff against an empty tree object
|
| 15 |
+
against=$(git hash-object -t tree /dev/null)
|
| 16 |
+
fi
|
| 17 |
+
|
| 18 |
+
# If you want to allow non-ASCII filenames set this variable to true.
|
| 19 |
+
allownonascii=$(git config --type=bool hooks.allownonascii)
|
| 20 |
+
|
| 21 |
+
# Redirect output to stderr.
|
| 22 |
+
exec 1>&2
|
| 23 |
+
|
| 24 |
+
# Cross platform projects tend to avoid non-ASCII filenames; prevent
|
| 25 |
+
# them from being added to the repository. We exploit the fact that the
|
| 26 |
+
# printable range starts at the space character and ends with tilde.
|
| 27 |
+
if [ "$allownonascii" != "true" ] &&
|
| 28 |
+
# Note that the use of brackets around a tr range is ok here, (it's
|
| 29 |
+
# even required, for portability to Solaris 10's /usr/bin/tr), since
|
| 30 |
+
# the square bracket bytes happen to fall in the designated range.
|
| 31 |
+
test $(git diff-index --cached --name-only --diff-filter=A -z $against |
|
| 32 |
+
LC_ALL=C tr -d '[ -~]\0' | wc -c) != 0
|
| 33 |
+
then
|
| 34 |
+
cat <<\EOF
|
| 35 |
+
Error: Attempt to add a non-ASCII file name.
|
| 36 |
+
|
| 37 |
+
This can cause problems if you want to work with people on other platforms.
|
| 38 |
+
|
| 39 |
+
To be portable it is advisable to rename the file.
|
| 40 |
+
|
| 41 |
+
If you know what you are doing you can disable this check using:
|
| 42 |
+
|
| 43 |
+
git config hooks.allownonascii true
|
| 44 |
+
EOF
|
| 45 |
+
exit 1
|
| 46 |
+
fi
|
| 47 |
+
|
| 48 |
+
# If there are whitespace errors, print the offending file names and fail.
|
| 49 |
+
exec git diff-index --check --cached $against --
|
original_performance_takehome/.git_backup/hooks/pre-merge-commit.sample
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to verify what is about to be committed.
|
| 4 |
+
# Called by "git merge" with no arguments. The hook should
|
| 5 |
+
# exit with non-zero status after issuing an appropriate message to
|
| 6 |
+
# stderr if it wants to stop the merge commit.
|
| 7 |
+
#
|
| 8 |
+
# To enable this hook, rename this file to "pre-merge-commit".
|
| 9 |
+
|
| 10 |
+
. git-sh-setup
|
| 11 |
+
test -x "$GIT_DIR/hooks/pre-commit" &&
|
| 12 |
+
exec "$GIT_DIR/hooks/pre-commit"
|
| 13 |
+
:
|
original_performance_takehome/.git_backup/hooks/pre-push.sample
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
# An example hook script to verify what is about to be pushed. Called by "git
|
| 4 |
+
# push" after it has checked the remote status, but before anything has been
|
| 5 |
+
# pushed. If this script exits with a non-zero status nothing will be pushed.
|
| 6 |
+
#
|
| 7 |
+
# This hook is called with the following parameters:
|
| 8 |
+
#
|
| 9 |
+
# $1 -- Name of the remote to which the push is being done
|
| 10 |
+
# $2 -- URL to which the push is being done
|
| 11 |
+
#
|
| 12 |
+
# If pushing without using a named remote those arguments will be equal.
|
| 13 |
+
#
|
| 14 |
+
# Information about the commits which are being pushed is supplied as lines to
|
| 15 |
+
# the standard input in the form:
|
| 16 |
+
#
|
| 17 |
+
# <local ref> <local oid> <remote ref> <remote oid>
|
| 18 |
+
#
|
| 19 |
+
# This sample shows how to prevent push of commits where the log message starts
|
| 20 |
+
# with "WIP" (work in progress).
|
| 21 |
+
|
| 22 |
+
remote="$1"
|
| 23 |
+
url="$2"
|
| 24 |
+
|
| 25 |
+
zero=$(git hash-object --stdin </dev/null | tr '[0-9a-f]' '0')
|
| 26 |
+
|
| 27 |
+
while read local_ref local_oid remote_ref remote_oid
|
| 28 |
+
do
|
| 29 |
+
if test "$local_oid" = "$zero"
|
| 30 |
+
then
|
| 31 |
+
# Handle delete
|
| 32 |
+
:
|
| 33 |
+
else
|
| 34 |
+
if test "$remote_oid" = "$zero"
|
| 35 |
+
then
|
| 36 |
+
# New branch, examine all commits
|
| 37 |
+
range="$local_oid"
|
| 38 |
+
else
|
| 39 |
+
# Update to existing branch, examine new commits
|
| 40 |
+
range="$remote_oid..$local_oid"
|
| 41 |
+
fi
|
| 42 |
+
|
| 43 |
+
# Check for WIP commit
|
| 44 |
+
commit=$(git rev-list -n 1 --grep '^WIP' "$range")
|
| 45 |
+
if test -n "$commit"
|
| 46 |
+
then
|
| 47 |
+
echo >&2 "Found WIP commit in $local_ref, not pushing"
|
| 48 |
+
exit 1
|
| 49 |
+
fi
|
| 50 |
+
fi
|
| 51 |
+
done
|
| 52 |
+
|
| 53 |
+
exit 0
|
original_performance_takehome/.git_backup/hooks/pre-rebase.sample
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2006, 2008 Junio C Hamano
|
| 4 |
+
#
|
| 5 |
+
# The "pre-rebase" hook is run just before "git rebase" starts doing
|
| 6 |
+
# its job, and can prevent the command from running by exiting with
|
| 7 |
+
# non-zero status.
|
| 8 |
+
#
|
| 9 |
+
# The hook is called with the following parameters:
|
| 10 |
+
#
|
| 11 |
+
# $1 -- the upstream the series was forked from.
|
| 12 |
+
# $2 -- the branch being rebased (or empty when rebasing the current branch).
|
| 13 |
+
#
|
| 14 |
+
# This sample shows how to prevent topic branches that are already
|
| 15 |
+
# merged to 'next' branch from getting rebased, because allowing it
|
| 16 |
+
# would result in rebasing already published history.
|
| 17 |
+
|
| 18 |
+
publish=next
|
| 19 |
+
basebranch="$1"
|
| 20 |
+
if test "$#" = 2
|
| 21 |
+
then
|
| 22 |
+
topic="refs/heads/$2"
|
| 23 |
+
else
|
| 24 |
+
topic=`git symbolic-ref HEAD` ||
|
| 25 |
+
exit 0 ;# we do not interrupt rebasing detached HEAD
|
| 26 |
+
fi
|
| 27 |
+
|
| 28 |
+
case "$topic" in
|
| 29 |
+
refs/heads/??/*)
|
| 30 |
+
;;
|
| 31 |
+
*)
|
| 32 |
+
exit 0 ;# we do not interrupt others.
|
| 33 |
+
;;
|
| 34 |
+
esac
|
| 35 |
+
|
| 36 |
+
# Now we are dealing with a topic branch being rebased
|
| 37 |
+
# on top of master. Is it OK to rebase it?
|
| 38 |
+
|
| 39 |
+
# Does the topic really exist?
|
| 40 |
+
git show-ref -q "$topic" || {
|
| 41 |
+
echo >&2 "No such branch $topic"
|
| 42 |
+
exit 1
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# Is topic fully merged to master?
|
| 46 |
+
not_in_master=`git rev-list --pretty=oneline ^master "$topic"`
|
| 47 |
+
if test -z "$not_in_master"
|
| 48 |
+
then
|
| 49 |
+
echo >&2 "$topic is fully merged to master; better remove it."
|
| 50 |
+
exit 1 ;# we could allow it, but there is no point.
|
| 51 |
+
fi
|
| 52 |
+
|
| 53 |
+
# Is topic ever merged to next? If so you should not be rebasing it.
|
| 54 |
+
only_next_1=`git rev-list ^master "^$topic" ${publish} | sort`
|
| 55 |
+
only_next_2=`git rev-list ^master ${publish} | sort`
|
| 56 |
+
if test "$only_next_1" = "$only_next_2"
|
| 57 |
+
then
|
| 58 |
+
not_in_topic=`git rev-list "^$topic" master`
|
| 59 |
+
if test -z "$not_in_topic"
|
| 60 |
+
then
|
| 61 |
+
echo >&2 "$topic is already up to date with master"
|
| 62 |
+
exit 1 ;# we could allow it, but there is no point.
|
| 63 |
+
else
|
| 64 |
+
exit 0
|
| 65 |
+
fi
|
| 66 |
+
else
|
| 67 |
+
not_in_next=`git rev-list --pretty=oneline ^${publish} "$topic"`
|
| 68 |
+
/usr/bin/perl -e '
|
| 69 |
+
my $topic = $ARGV[0];
|
| 70 |
+
my $msg = "* $topic has commits already merged to public branch:\n";
|
| 71 |
+
my (%not_in_next) = map {
|
| 72 |
+
/^([0-9a-f]+) /;
|
| 73 |
+
($1 => 1);
|
| 74 |
+
} split(/\n/, $ARGV[1]);
|
| 75 |
+
for my $elem (map {
|
| 76 |
+
/^([0-9a-f]+) (.*)$/;
|
| 77 |
+
[$1 => $2];
|
| 78 |
+
} split(/\n/, $ARGV[2])) {
|
| 79 |
+
if (!exists $not_in_next{$elem->[0]}) {
|
| 80 |
+
if ($msg) {
|
| 81 |
+
print STDERR $msg;
|
| 82 |
+
undef $msg;
|
| 83 |
+
}
|
| 84 |
+
print STDERR " $elem->[1]\n";
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
' "$topic" "$not_in_next" "$not_in_master"
|
| 88 |
+
exit 1
|
| 89 |
+
fi
|
| 90 |
+
|
| 91 |
+
<<\DOC_END
|
| 92 |
+
|
| 93 |
+
This sample hook safeguards topic branches that have been
|
| 94 |
+
published from being rewound.
|
| 95 |
+
|
| 96 |
+
The workflow assumed here is:
|
| 97 |
+
|
| 98 |
+
* Once a topic branch forks from "master", "master" is never
|
| 99 |
+
merged into it again (either directly or indirectly).
|
| 100 |
+
|
| 101 |
+
* Once a topic branch is fully cooked and merged into "master",
|
| 102 |
+
it is deleted. If you need to build on top of it to correct
|
| 103 |
+
earlier mistakes, a new topic branch is created by forking at
|
| 104 |
+
the tip of the "master". This is not strictly necessary, but
|
| 105 |
+
it makes it easier to keep your history simple.
|
| 106 |
+
|
| 107 |
+
* Whenever you need to test or publish your changes to topic
|
| 108 |
+
branches, merge them into "next" branch.
|
| 109 |
+
|
| 110 |
+
The script, being an example, hardcodes the publish branch name
|
| 111 |
+
to be "next", but it is trivial to make it configurable via
|
| 112 |
+
$GIT_DIR/config mechanism.
|
| 113 |
+
|
| 114 |
+
With this workflow, you would want to know:
|
| 115 |
+
|
| 116 |
+
(1) ... if a topic branch has ever been merged to "next". Young
|
| 117 |
+
topic branches can have stupid mistakes you would rather
|
| 118 |
+
clean up before publishing, and things that have not been
|
| 119 |
+
merged into other branches can be easily rebased without
|
| 120 |
+
affecting other people. But once it is published, you would
|
| 121 |
+
not want to rewind it.
|
| 122 |
+
|
| 123 |
+
(2) ... if a topic branch has been fully merged to "master".
|
| 124 |
+
Then you can delete it. More importantly, you should not
|
| 125 |
+
build on top of it -- other people may already want to
|
| 126 |
+
change things related to the topic as patches against your
|
| 127 |
+
"master", so if you need further changes, it is better to
|
| 128 |
+
fork the topic (perhaps with the same name) afresh from the
|
| 129 |
+
tip of "master".
|
| 130 |
+
|
| 131 |
+
Let's look at this example:
|
| 132 |
+
|
| 133 |
+
o---o---o---o---o---o---o---o---o---o "next"
|
| 134 |
+
/ / / /
|
| 135 |
+
/ a---a---b A / /
|
| 136 |
+
/ / / /
|
| 137 |
+
/ / c---c---c---c B /
|
| 138 |
+
/ / / \ /
|
| 139 |
+
/ / / b---b C \ /
|
| 140 |
+
/ / / / \ /
|
| 141 |
+
---o---o---o---o---o---o---o---o---o---o---o "master"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
A, B and C are topic branches.
|
| 145 |
+
|
| 146 |
+
* A has one fix since it was merged up to "next".
|
| 147 |
+
|
| 148 |
+
* B has finished. It has been fully merged up to "master" and "next",
|
| 149 |
+
and is ready to be deleted.
|
| 150 |
+
|
| 151 |
+
* C has not merged to "next" at all.
|
| 152 |
+
|
| 153 |
+
We would want to allow C to be rebased, refuse A, and encourage
|
| 154 |
+
B to be deleted.
|
| 155 |
+
|
| 156 |
+
To compute (1):
|
| 157 |
+
|
| 158 |
+
git rev-list ^master ^topic next
|
| 159 |
+
git rev-list ^master next
|
| 160 |
+
|
| 161 |
+
if these match, topic has not merged in next at all.
|
| 162 |
+
|
| 163 |
+
To compute (2):
|
| 164 |
+
|
| 165 |
+
git rev-list master..topic
|
| 166 |
+
|
| 167 |
+
if this is empty, it is fully merged to "master".
|
| 168 |
+
|
| 169 |
+
DOC_END
|
original_performance_takehome/.git_backup/hooks/pre-receive.sample
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to make use of push options.
|
| 4 |
+
# The example simply echoes all push options that start with 'echoback='
|
| 5 |
+
# and rejects all pushes when the "reject" push option is used.
|
| 6 |
+
#
|
| 7 |
+
# To enable this hook, rename this file to "pre-receive".
|
| 8 |
+
|
| 9 |
+
if test -n "$GIT_PUSH_OPTION_COUNT"
|
| 10 |
+
then
|
| 11 |
+
i=0
|
| 12 |
+
while test "$i" -lt "$GIT_PUSH_OPTION_COUNT"
|
| 13 |
+
do
|
| 14 |
+
eval "value=\$GIT_PUSH_OPTION_$i"
|
| 15 |
+
case "$value" in
|
| 16 |
+
echoback=*)
|
| 17 |
+
echo "echo from the pre-receive-hook: ${value#*=}" >&2
|
| 18 |
+
;;
|
| 19 |
+
reject)
|
| 20 |
+
exit 1
|
| 21 |
+
esac
|
| 22 |
+
i=$((i + 1))
|
| 23 |
+
done
|
| 24 |
+
fi
|
original_performance_takehome/.git_backup/hooks/prepare-commit-msg.sample
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to prepare the commit log message.
|
| 4 |
+
# Called by "git commit" with the name of the file that has the
|
| 5 |
+
# commit message, followed by the description of the commit
|
| 6 |
+
# message's source. The hook's purpose is to edit the commit
|
| 7 |
+
# message file. If the hook fails with a non-zero status,
|
| 8 |
+
# the commit is aborted.
|
| 9 |
+
#
|
| 10 |
+
# To enable this hook, rename this file to "prepare-commit-msg".
|
| 11 |
+
|
| 12 |
+
# This hook includes three examples. The first one removes the
|
| 13 |
+
# "# Please enter the commit message..." help message.
|
| 14 |
+
#
|
| 15 |
+
# The second includes the output of "git diff --name-status -r"
|
| 16 |
+
# into the message, just before the "git status" output. It is
|
| 17 |
+
# commented because it doesn't cope with --amend or with squashed
|
| 18 |
+
# commits.
|
| 19 |
+
#
|
| 20 |
+
# The third example adds a Signed-off-by line to the message, that can
|
| 21 |
+
# still be edited. This is rarely a good idea.
|
| 22 |
+
|
| 23 |
+
COMMIT_MSG_FILE=$1
|
| 24 |
+
COMMIT_SOURCE=$2
|
| 25 |
+
SHA1=$3
|
| 26 |
+
|
| 27 |
+
/usr/bin/perl -i.bak -ne 'print unless(m/^. Please enter the commit message/..m/^#$/)' "$COMMIT_MSG_FILE"
|
| 28 |
+
|
| 29 |
+
# case "$COMMIT_SOURCE,$SHA1" in
|
| 30 |
+
# ,|template,)
|
| 31 |
+
# /usr/bin/perl -i.bak -pe '
|
| 32 |
+
# print "\n" . `git diff --cached --name-status -r`
|
| 33 |
+
# if /^#/ && $first++ == 0' "$COMMIT_MSG_FILE" ;;
|
| 34 |
+
# *) ;;
|
| 35 |
+
# esac
|
| 36 |
+
|
| 37 |
+
# SOB=$(git var GIT_COMMITTER_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p')
|
| 38 |
+
# git interpret-trailers --in-place --trailer "$SOB" "$COMMIT_MSG_FILE"
|
| 39 |
+
# if test -z "$COMMIT_SOURCE"
|
| 40 |
+
# then
|
| 41 |
+
# /usr/bin/perl -i.bak -pe 'print "\n" if !$first_line++' "$COMMIT_MSG_FILE"
|
| 42 |
+
# fi
|
original_performance_takehome/.git_backup/hooks/push-to-checkout.sample
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
# An example hook script to update a checked-out tree on a git push.
|
| 4 |
+
#
|
| 5 |
+
# This hook is invoked by git-receive-pack(1) when it reacts to git
|
| 6 |
+
# push and updates reference(s) in its repository, and when the push
|
| 7 |
+
# tries to update the branch that is currently checked out and the
|
| 8 |
+
# receive.denyCurrentBranch configuration variable is set to
|
| 9 |
+
# updateInstead.
|
| 10 |
+
#
|
| 11 |
+
# By default, such a push is refused if the working tree and the index
|
| 12 |
+
# of the remote repository has any difference from the currently
|
| 13 |
+
# checked out commit; when both the working tree and the index match
|
| 14 |
+
# the current commit, they are updated to match the newly pushed tip
|
| 15 |
+
# of the branch. This hook is to be used to override the default
|
| 16 |
+
# behaviour; however the code below reimplements the default behaviour
|
| 17 |
+
# as a starting point for convenient modification.
|
| 18 |
+
#
|
| 19 |
+
# The hook receives the commit with which the tip of the current
|
| 20 |
+
# branch is going to be updated:
|
| 21 |
+
commit=$1
|
| 22 |
+
|
| 23 |
+
# It can exit with a non-zero status to refuse the push (when it does
|
| 24 |
+
# so, it must not modify the index or the working tree).
|
| 25 |
+
die () {
|
| 26 |
+
echo >&2 "$*"
|
| 27 |
+
exit 1
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
# Or it can make any necessary changes to the working tree and to the
|
| 31 |
+
# index to bring them to the desired state when the tip of the current
|
| 32 |
+
# branch is updated to the new commit, and exit with a zero status.
|
| 33 |
+
#
|
| 34 |
+
# For example, the hook can simply run git read-tree -u -m HEAD "$1"
|
| 35 |
+
# in order to emulate git fetch that is run in the reverse direction
|
| 36 |
+
# with git push, as the two-tree form of git read-tree -u -m is
|
| 37 |
+
# essentially the same as git switch or git checkout that switches
|
| 38 |
+
# branches while keeping the local changes in the working tree that do
|
| 39 |
+
# not interfere with the difference between the branches.
|
| 40 |
+
|
| 41 |
+
# The below is a more-or-less exact translation to shell of the C code
|
| 42 |
+
# for the default behaviour for git's push-to-checkout hook defined in
|
| 43 |
+
# the push_to_deploy() function in builtin/receive-pack.c.
|
| 44 |
+
#
|
| 45 |
+
# Note that the hook will be executed from the repository directory,
|
| 46 |
+
# not from the working tree, so if you want to perform operations on
|
| 47 |
+
# the working tree, you will have to adapt your code accordingly, e.g.
|
| 48 |
+
# by adding "cd .." or using relative paths.
|
| 49 |
+
|
| 50 |
+
if ! git update-index -q --ignore-submodules --refresh
|
| 51 |
+
then
|
| 52 |
+
die "Up-to-date check failed"
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
if ! git diff-files --quiet --ignore-submodules --
|
| 56 |
+
then
|
| 57 |
+
die "Working directory has unstaged changes"
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
# This is a rough translation of:
|
| 61 |
+
#
|
| 62 |
+
# head_has_history() ? "HEAD" : EMPTY_TREE_SHA1_HEX
|
| 63 |
+
if git cat-file -e HEAD 2>/dev/null
|
| 64 |
+
then
|
| 65 |
+
head=HEAD
|
| 66 |
+
else
|
| 67 |
+
head=$(git hash-object -t tree --stdin </dev/null)
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
if ! git diff-index --quiet --cached --ignore-submodules $head --
|
| 71 |
+
then
|
| 72 |
+
die "Working directory has staged changes"
|
| 73 |
+
fi
|
| 74 |
+
|
| 75 |
+
if ! git read-tree -u -m "$commit"
|
| 76 |
+
then
|
| 77 |
+
die "Could not update working tree to new HEAD"
|
| 78 |
+
fi
|
original_performance_takehome/.git_backup/hooks/sendemail-validate.sample
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
|
| 3 |
+
# An example hook script to validate a patch (and/or patch series) before
|
| 4 |
+
# sending it via email.
|
| 5 |
+
#
|
| 6 |
+
# The hook should exit with non-zero status after issuing an appropriate
|
| 7 |
+
# message if it wants to prevent the email(s) from being sent.
|
| 8 |
+
#
|
| 9 |
+
# To enable this hook, rename this file to "sendemail-validate".
|
| 10 |
+
#
|
| 11 |
+
# By default, it will only check that the patch(es) can be applied on top of
|
| 12 |
+
# the default upstream branch without conflicts in a secondary worktree. After
|
| 13 |
+
# validation (successful or not) of the last patch of a series, the worktree
|
| 14 |
+
# will be deleted.
|
| 15 |
+
#
|
| 16 |
+
# The following config variables can be set to change the default remote and
|
| 17 |
+
# remote ref that are used to apply the patches against:
|
| 18 |
+
#
|
| 19 |
+
# sendemail.validateRemote (default: origin)
|
| 20 |
+
# sendemail.validateRemoteRef (default: HEAD)
|
| 21 |
+
#
|
| 22 |
+
# Replace the TODO placeholders with appropriate checks according to your
|
| 23 |
+
# needs.
|
| 24 |
+
|
| 25 |
+
validate_cover_letter () {
|
| 26 |
+
file="$1"
|
| 27 |
+
# TODO: Replace with appropriate checks (e.g. spell checking).
|
| 28 |
+
true
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
validate_patch () {
|
| 32 |
+
file="$1"
|
| 33 |
+
# Ensure that the patch applies without conflicts.
|
| 34 |
+
git am -3 "$file" || return
|
| 35 |
+
# TODO: Replace with appropriate checks for this patch
|
| 36 |
+
# (e.g. checkpatch.pl).
|
| 37 |
+
true
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
validate_series () {
|
| 41 |
+
# TODO: Replace with appropriate checks for the whole series
|
| 42 |
+
# (e.g. quick build, coding style checks, etc.).
|
| 43 |
+
true
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
# main -------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
if test "$GIT_SENDEMAIL_FILE_COUNTER" = 1
|
| 49 |
+
then
|
| 50 |
+
remote=$(git config --default origin --get sendemail.validateRemote) &&
|
| 51 |
+
ref=$(git config --default HEAD --get sendemail.validateRemoteRef) &&
|
| 52 |
+
worktree=$(mktemp --tmpdir -d sendemail-validate.XXXXXXX) &&
|
| 53 |
+
git worktree add -fd --checkout "$worktree" "refs/remotes/$remote/$ref" &&
|
| 54 |
+
git config --replace-all sendemail.validateWorktree "$worktree"
|
| 55 |
+
else
|
| 56 |
+
worktree=$(git config --get sendemail.validateWorktree)
|
| 57 |
+
fi || {
|
| 58 |
+
echo "sendemail-validate: error: failed to prepare worktree" >&2
|
| 59 |
+
exit 1
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
unset GIT_DIR GIT_WORK_TREE
|
| 63 |
+
cd "$worktree" &&
|
| 64 |
+
|
| 65 |
+
if grep -q "^diff --git " "$1"
|
| 66 |
+
then
|
| 67 |
+
validate_patch "$1"
|
| 68 |
+
else
|
| 69 |
+
validate_cover_letter "$1"
|
| 70 |
+
fi &&
|
| 71 |
+
|
| 72 |
+
if test "$GIT_SENDEMAIL_FILE_COUNTER" = "$GIT_SENDEMAIL_FILE_TOTAL"
|
| 73 |
+
then
|
| 74 |
+
git config --unset-all sendemail.validateWorktree &&
|
| 75 |
+
trap 'git worktree remove -ff "$worktree"' EXIT &&
|
| 76 |
+
validate_series
|
| 77 |
+
fi
|
original_performance_takehome/.git_backup/hooks/update.sample
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
#
|
| 3 |
+
# An example hook script to block unannotated tags from entering.
|
| 4 |
+
# Called by "git receive-pack" with arguments: refname sha1-old sha1-new
|
| 5 |
+
#
|
| 6 |
+
# To enable this hook, rename this file to "update".
|
| 7 |
+
#
|
| 8 |
+
# Config
|
| 9 |
+
# ------
|
| 10 |
+
# hooks.allowunannotated
|
| 11 |
+
# This boolean sets whether unannotated tags will be allowed into the
|
| 12 |
+
# repository. By default they won't be.
|
| 13 |
+
# hooks.allowdeletetag
|
| 14 |
+
# This boolean sets whether deleting tags will be allowed in the
|
| 15 |
+
# repository. By default they won't be.
|
| 16 |
+
# hooks.allowmodifytag
|
| 17 |
+
# This boolean sets whether a tag may be modified after creation. By default
|
| 18 |
+
# it won't be.
|
| 19 |
+
# hooks.allowdeletebranch
|
| 20 |
+
# This boolean sets whether deleting branches will be allowed in the
|
| 21 |
+
# repository. By default they won't be.
|
| 22 |
+
# hooks.denycreatebranch
|
| 23 |
+
# This boolean sets whether remotely creating branches will be denied
|
| 24 |
+
# in the repository. By default this is allowed.
|
| 25 |
+
#
|
| 26 |
+
|
| 27 |
+
# --- Command line
|
| 28 |
+
refname="$1"
|
| 29 |
+
oldrev="$2"
|
| 30 |
+
newrev="$3"
|
| 31 |
+
|
| 32 |
+
# --- Safety check
|
| 33 |
+
if [ -z "$GIT_DIR" ]; then
|
| 34 |
+
echo "Don't run this script from the command line." >&2
|
| 35 |
+
echo " (if you want, you could supply GIT_DIR then run" >&2
|
| 36 |
+
echo " $0 <ref> <oldrev> <newrev>)" >&2
|
| 37 |
+
exit 1
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
if [ -z "$refname" -o -z "$oldrev" -o -z "$newrev" ]; then
|
| 41 |
+
echo "usage: $0 <ref> <oldrev> <newrev>" >&2
|
| 42 |
+
exit 1
|
| 43 |
+
fi
|
| 44 |
+
|
| 45 |
+
# --- Config
|
| 46 |
+
allowunannotated=$(git config --type=bool hooks.allowunannotated)
|
| 47 |
+
allowdeletebranch=$(git config --type=bool hooks.allowdeletebranch)
|
| 48 |
+
denycreatebranch=$(git config --type=bool hooks.denycreatebranch)
|
| 49 |
+
allowdeletetag=$(git config --type=bool hooks.allowdeletetag)
|
| 50 |
+
allowmodifytag=$(git config --type=bool hooks.allowmodifytag)
|
| 51 |
+
|
| 52 |
+
# check for no description
|
| 53 |
+
projectdesc=$(sed -e '1q' "$GIT_DIR/description")
|
| 54 |
+
case "$projectdesc" in
|
| 55 |
+
"Unnamed repository"* | "")
|
| 56 |
+
echo "*** Project description file hasn't been set" >&2
|
| 57 |
+
exit 1
|
| 58 |
+
;;
|
| 59 |
+
esac
|
| 60 |
+
|
| 61 |
+
# --- Check types
|
| 62 |
+
# if $newrev is 0000...0000, it's a commit to delete a ref.
|
| 63 |
+
zero=$(git hash-object --stdin </dev/null | tr '[0-9a-f]' '0')
|
| 64 |
+
if [ "$newrev" = "$zero" ]; then
|
| 65 |
+
newrev_type=delete
|
| 66 |
+
else
|
| 67 |
+
newrev_type=$(git cat-file -t $newrev)
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
case "$refname","$newrev_type" in
|
| 71 |
+
refs/tags/*,commit)
|
| 72 |
+
# un-annotated tag
|
| 73 |
+
short_refname=${refname##refs/tags/}
|
| 74 |
+
if [ "$allowunannotated" != "true" ]; then
|
| 75 |
+
echo "*** The un-annotated tag, $short_refname, is not allowed in this repository" >&2
|
| 76 |
+
echo "*** Use 'git tag [ -a | -s ]' for tags you want to propagate." >&2
|
| 77 |
+
exit 1
|
| 78 |
+
fi
|
| 79 |
+
;;
|
| 80 |
+
refs/tags/*,delete)
|
| 81 |
+
# delete tag
|
| 82 |
+
if [ "$allowdeletetag" != "true" ]; then
|
| 83 |
+
echo "*** Deleting a tag is not allowed in this repository" >&2
|
| 84 |
+
exit 1
|
| 85 |
+
fi
|
| 86 |
+
;;
|
| 87 |
+
refs/tags/*,tag)
|
| 88 |
+
# annotated tag
|
| 89 |
+
if [ "$allowmodifytag" != "true" ] && git rev-parse $refname > /dev/null 2>&1
|
| 90 |
+
then
|
| 91 |
+
echo "*** Tag '$refname' already exists." >&2
|
| 92 |
+
echo "*** Modifying a tag is not allowed in this repository." >&2
|
| 93 |
+
exit 1
|
| 94 |
+
fi
|
| 95 |
+
;;
|
| 96 |
+
refs/heads/*,commit)
|
| 97 |
+
# branch
|
| 98 |
+
if [ "$oldrev" = "$zero" -a "$denycreatebranch" = "true" ]; then
|
| 99 |
+
echo "*** Creating a branch is not allowed in this repository" >&2
|
| 100 |
+
exit 1
|
| 101 |
+
fi
|
| 102 |
+
;;
|
| 103 |
+
refs/heads/*,delete)
|
| 104 |
+
# delete branch
|
| 105 |
+
if [ "$allowdeletebranch" != "true" ]; then
|
| 106 |
+
echo "*** Deleting a branch is not allowed in this repository" >&2
|
| 107 |
+
exit 1
|
| 108 |
+
fi
|
| 109 |
+
;;
|
| 110 |
+
refs/remotes/*,commit)
|
| 111 |
+
# tracking branch
|
| 112 |
+
;;
|
| 113 |
+
refs/remotes/*,delete)
|
| 114 |
+
# delete tracking branch
|
| 115 |
+
if [ "$allowdeletebranch" != "true" ]; then
|
| 116 |
+
echo "*** Deleting a tracking branch is not allowed in this repository" >&2
|
| 117 |
+
exit 1
|
| 118 |
+
fi
|
| 119 |
+
;;
|
| 120 |
+
*)
|
| 121 |
+
# Anything else (is there anything else?)
|
| 122 |
+
echo "*** Update hook: unknown type of update to ref $refname of type $newrev_type" >&2
|
| 123 |
+
exit 1
|
| 124 |
+
;;
|
| 125 |
+
esac
|
| 126 |
+
|
| 127 |
+
# --- Finished
|
| 128 |
+
exit 0
|
original_performance_takehome/.git_backup/index
ADDED
|
Binary file (743 Bytes). View file
|
|
|
original_performance_takehome/.git_backup/info/exclude
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# git ls-files --others --exclude-from=.git/info/exclude
|
| 2 |
+
# Lines that start with '#' are comments.
|
| 3 |
+
# For a project mostly in C, the following would be a good set of
|
| 4 |
+
# exclude patterns (uncomment them if you want to use them):
|
| 5 |
+
# *.[oa]
|
| 6 |
+
# *~
|
original_performance_takehome/.git_backup/logs/HEAD
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
|
original_performance_takehome/.git_backup/logs/refs/heads/main
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
|
original_performance_takehome/.git_backup/logs/refs/remotes/origin/HEAD
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
|
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.idx
ADDED
|
Binary file (1.8 kB). View file
|
|
|
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.pack
ADDED
|
Binary file (20.2 kB). View file
|
|
|
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.rev
ADDED
|
Binary file (156 Bytes). View file
|
|
|
original_performance_takehome/.git_backup/packed-refs
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pack-refs with: peeled fully-peeled sorted
|
| 2 |
+
5452f74bd977807ac2e74f3d29432b9df6f25197 refs/remotes/origin/main
|
| 3 |
+
d45812f96a6740086db7f2aa78925d9a0b7389dd refs/remotes/origin/tristan/add-warning
|
| 4 |
+
3697cecc2a093b4df01de46e6a61b3b56d3ad6be refs/remotes/origin/tristan/update-readme
|
original_performance_takehome/.git_backup/refs/heads/main
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
5452f74bd977807ac2e74f3d29432b9df6f25197
|
original_performance_takehome/.git_backup/refs/remotes/origin/HEAD
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ref: refs/remotes/origin/main
|
original_performance_takehome/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
trace.json
|
| 2 |
+
**/*.pyc
|
| 3 |
+
.hypothesis
|
| 4 |
+
.DS_Store
|
original_performance_takehome/Readme.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Anthropic's Original Performance Take-Home
|
| 2 |
+
|
| 3 |
+
This repo contains a version of Anthropic's original performance take-home, before Claude Opus 4.5 started doing better than humans given only 2 hours.
|
| 4 |
+
|
| 5 |
+
The original take-home was a 4-hour one that starts close to the contents of this repo, after Claude Opus 4 beat most humans at that, it was updated to a 2-hour one which started with code which achieved 18532 cycles (7.97x faster than this repo starts you). This repo is based on the newer take-home which has a few more instructions and comes with better debugging tools, but has the starter code reverted to the slowest baseline. After Claude Opus 4.5 we started using a different base for our time-limited take-homes.
|
| 6 |
+
|
| 7 |
+
Now you can try to beat Claude Opus 4.5 given unlimited time!
|
| 8 |
+
|
| 9 |
+
## Performance benchmarks
|
| 10 |
+
|
| 11 |
+
Measured in clock cycles from the simulated machine. All of these numbers are for models doing the 2 hour version which started at 18532 cycles:
|
| 12 |
+
|
| 13 |
+
- **2164 cycles**: Claude Opus 4 after many hours in the test-time compute harness
|
| 14 |
+
- **1790 cycles**: Claude Opus 4.5 in a casual Claude Code session, approximately matching the best human performance in 2 hours
|
| 15 |
+
- **1579 cycles**: Claude Opus 4.5 after 2 hours in our test-time compute harness
|
| 16 |
+
- **1548 cycles**: Claude Sonnet 4.5 after many more than 2 hours of test-time compute
|
| 17 |
+
- **1487 cycles**: Claude Opus 4.5 after 11.5 hours in the harness
|
| 18 |
+
- **1363 cycles**: Claude Opus 4.5 in an improved test time compute harness
|
| 19 |
+
- **??? cycles**: Best human performance ever is substantially better than the above, but we won't say how much.
|
| 20 |
+
|
| 21 |
+
While it's no longer a good time-limited test, you can still use this test to get us excited about hiring you! If you optimize below 1487 cycles, beating Claude Opus 4.5's best performance at launch, email us at performance-recruiting@anthropic.com with your code (and ideally a resume) so we can be appropriately impressed, especially if you get near the best solution we've seen. New model releases may change what threshold impresses us though, and no guarantees that we keep this readme updated with the latest on that.
|
| 22 |
+
|
| 23 |
+
Run `python tests/submission_tests.py` to see which thresholds you pass.
|
| 24 |
+
|
| 25 |
+
## Warning: LLMs can cheat
|
| 26 |
+
|
| 27 |
+
None of the solutions we received on the first day post-release below 1300 cycles were valid solutions. In each case, a language model modified the tests to make the problem easier.
|
| 28 |
+
|
| 29 |
+
If you use an AI agent, we recommend instructing it not to change the `tests/` folder and to use `tests/submission_tests.py` for verification.
|
| 30 |
+
|
| 31 |
+
Please run the following commands to validate your submission, and mention that you did so when submitting:
|
| 32 |
+
```
|
| 33 |
+
# This should be empty, the tests folder must be unchanged
|
| 34 |
+
git diff origin/main tests/
|
| 35 |
+
# You should pass some of these tests and use the cycle count this prints
|
| 36 |
+
python tests/submission_tests.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
An example of this kind of hack is a model noticing that `problem.py` has multicore support, implementing multicore as an optimization, noticing there's no speedup and "debugging" that `N_CORES = 1` and "fixing" the core count so they get a speedup. Multicore is disabled intentionally in this version.
|
original_performance_takehome/perf_takehome.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
# Anthropic's Original Performance Engineering Take-home (Release version)
|
| 3 |
+
|
| 4 |
+
Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
|
| 5 |
+
to publish or redistribute your solutions so it's hard to find spoilers.
|
| 6 |
+
|
| 7 |
+
# Task
|
| 8 |
+
|
| 9 |
+
- Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
|
| 10 |
+
available time, as measured by test_kernel_cycles on a frozen separate copy
|
| 11 |
+
of the simulator.
|
| 12 |
+
|
| 13 |
+
Validate your results using `python tests/submission_tests.py` without modifying
|
| 14 |
+
anything in the tests/ folder.
|
| 15 |
+
|
| 16 |
+
We recommend you look through problem.py next.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
import random
|
| 21 |
+
import unittest
|
| 22 |
+
|
| 23 |
+
from problem import (
|
| 24 |
+
Engine,
|
| 25 |
+
DebugInfo,
|
| 26 |
+
SLOT_LIMITS,
|
| 27 |
+
VLEN,
|
| 28 |
+
N_CORES,
|
| 29 |
+
SCRATCH_SIZE,
|
| 30 |
+
Machine,
|
| 31 |
+
Tree,
|
| 32 |
+
Input,
|
| 33 |
+
HASH_STAGES,
|
| 34 |
+
reference_kernel,
|
| 35 |
+
build_mem_image,
|
| 36 |
+
reference_kernel2,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class KernelBuilder:
|
| 41 |
+
def __init__(self):
|
| 42 |
+
self.instrs = []
|
| 43 |
+
self.scratch = {}
|
| 44 |
+
self.scratch_debug = {}
|
| 45 |
+
self.scratch_ptr = 0
|
| 46 |
+
self.const_map = {}
|
| 47 |
+
|
| 48 |
+
def debug_info(self):
|
| 49 |
+
return DebugInfo(scratch_map=self.scratch_debug)
|
| 50 |
+
|
| 51 |
+
def build(self, slots: list[tuple[Engine, tuple]], vliw: bool = False):
|
| 52 |
+
# Simple slot packing that just uses one slot per instruction bundle
|
| 53 |
+
instrs = []
|
| 54 |
+
for engine, slot in slots:
|
| 55 |
+
instrs.append({engine: [slot]})
|
| 56 |
+
return instrs
|
| 57 |
+
|
| 58 |
+
def add(self, engine, slot):
|
| 59 |
+
self.instrs.append({engine: [slot]})
|
| 60 |
+
|
| 61 |
+
def alloc_scratch(self, name=None, length=1):
|
| 62 |
+
addr = self.scratch_ptr
|
| 63 |
+
if name is not None:
|
| 64 |
+
self.scratch[name] = addr
|
| 65 |
+
self.scratch_debug[addr] = (name, length)
|
| 66 |
+
self.scratch_ptr += length
|
| 67 |
+
assert self.scratch_ptr <= SCRATCH_SIZE, "Out of scratch space"
|
| 68 |
+
return addr
|
| 69 |
+
|
| 70 |
+
def scratch_const(self, val, name=None):
|
| 71 |
+
if val not in self.const_map:
|
| 72 |
+
addr = self.alloc_scratch(name)
|
| 73 |
+
self.add("load", ("const", addr, val))
|
| 74 |
+
self.const_map[val] = addr
|
| 75 |
+
return self.const_map[val]
|
| 76 |
+
|
| 77 |
+
def build_hash(self, val_hash_addr, tmp1, tmp2, round, i):
|
| 78 |
+
slots = []
|
| 79 |
+
|
| 80 |
+
for hi, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
|
| 81 |
+
slots.append(("alu", (op1, tmp1, val_hash_addr, self.scratch_const(val1))))
|
| 82 |
+
slots.append(("alu", (op3, tmp2, val_hash_addr, self.scratch_const(val3))))
|
| 83 |
+
slots.append(("alu", (op2, val_hash_addr, tmp1, tmp2)))
|
| 84 |
+
slots.append(("debug", ("compare", val_hash_addr, (round, i, "hash_stage", hi))))
|
| 85 |
+
|
| 86 |
+
return slots
|
| 87 |
+
|
| 88 |
+
def build_kernel(
|
| 89 |
+
self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
|
| 90 |
+
):
|
| 91 |
+
"""
|
| 92 |
+
Like reference_kernel2 but building actual instructions.
|
| 93 |
+
Scalar implementation using only scalar ALU and load/store.
|
| 94 |
+
"""
|
| 95 |
+
tmp1 = self.alloc_scratch("tmp1")
|
| 96 |
+
tmp2 = self.alloc_scratch("tmp2")
|
| 97 |
+
tmp3 = self.alloc_scratch("tmp3")
|
| 98 |
+
# Scratch space addresses
|
| 99 |
+
init_vars = [
|
| 100 |
+
"rounds",
|
| 101 |
+
"n_nodes",
|
| 102 |
+
"batch_size",
|
| 103 |
+
"forest_height",
|
| 104 |
+
"forest_values_p",
|
| 105 |
+
"inp_indices_p",
|
| 106 |
+
"inp_values_p",
|
| 107 |
+
]
|
| 108 |
+
for v in init_vars:
|
| 109 |
+
self.alloc_scratch(v, 1)
|
| 110 |
+
for i, v in enumerate(init_vars):
|
| 111 |
+
self.add("load", ("const", tmp1, i))
|
| 112 |
+
self.add("load", ("load", self.scratch[v], tmp1))
|
| 113 |
+
|
| 114 |
+
zero_const = self.scratch_const(0)
|
| 115 |
+
one_const = self.scratch_const(1)
|
| 116 |
+
two_const = self.scratch_const(2)
|
| 117 |
+
|
| 118 |
+
# Pause instructions are matched up with yield statements in the reference
|
| 119 |
+
# kernel to let you debug at intermediate steps. The testing harness in this
|
| 120 |
+
# file requires these match up to the reference kernel's yields, but the
|
| 121 |
+
# submission harness ignores them.
|
| 122 |
+
self.add("flow", ("pause",))
|
| 123 |
+
# Any debug engine instruction is ignored by the submission simulator
|
| 124 |
+
self.add("debug", ("comment", "Starting loop"))
|
| 125 |
+
|
| 126 |
+
body = [] # array of slots
|
| 127 |
+
|
| 128 |
+
# Scalar scratch registers
|
| 129 |
+
tmp_idx = self.alloc_scratch("tmp_idx")
|
| 130 |
+
tmp_val = self.alloc_scratch("tmp_val")
|
| 131 |
+
tmp_node_val = self.alloc_scratch("tmp_node_val")
|
| 132 |
+
tmp_addr = self.alloc_scratch("tmp_addr")
|
| 133 |
+
|
| 134 |
+
for round in range(rounds):
|
| 135 |
+
for i in range(batch_size):
|
| 136 |
+
i_const = self.scratch_const(i)
|
| 137 |
+
# idx = mem[inp_indices_p + i]
|
| 138 |
+
body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
|
| 139 |
+
body.append(("load", ("load", tmp_idx, tmp_addr)))
|
| 140 |
+
body.append(("debug", ("compare", tmp_idx, (round, i, "idx"))))
|
| 141 |
+
# val = mem[inp_values_p + i]
|
| 142 |
+
body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
|
| 143 |
+
body.append(("load", ("load", tmp_val, tmp_addr)))
|
| 144 |
+
body.append(("debug", ("compare", tmp_val, (round, i, "val"))))
|
| 145 |
+
# node_val = mem[forest_values_p + idx]
|
| 146 |
+
body.append(("alu", ("+", tmp_addr, self.scratch["forest_values_p"], tmp_idx)))
|
| 147 |
+
body.append(("load", ("load", tmp_node_val, tmp_addr)))
|
| 148 |
+
body.append(("debug", ("compare", tmp_node_val, (round, i, "node_val"))))
|
| 149 |
+
# val = myhash(val ^ node_val)
|
| 150 |
+
body.append(("alu", ("^", tmp_val, tmp_val, tmp_node_val)))
|
| 151 |
+
body.extend(self.build_hash(tmp_val, tmp1, tmp2, round, i))
|
| 152 |
+
body.append(("debug", ("compare", tmp_val, (round, i, "hashed_val"))))
|
| 153 |
+
# idx = 2*idx + (1 if val % 2 == 0 else 2)
|
| 154 |
+
body.append(("alu", ("%", tmp1, tmp_val, two_const)))
|
| 155 |
+
body.append(("alu", ("==", tmp1, tmp1, zero_const)))
|
| 156 |
+
body.append(("flow", ("select", tmp3, tmp1, one_const, two_const)))
|
| 157 |
+
body.append(("alu", ("*", tmp_idx, tmp_idx, two_const)))
|
| 158 |
+
body.append(("alu", ("+", tmp_idx, tmp_idx, tmp3)))
|
| 159 |
+
body.append(("debug", ("compare", tmp_idx, (round, i, "next_idx"))))
|
| 160 |
+
# idx = 0 if idx >= n_nodes else idx
|
| 161 |
+
body.append(("alu", ("<", tmp1, tmp_idx, self.scratch["n_nodes"])))
|
| 162 |
+
body.append(("flow", ("select", tmp_idx, tmp1, tmp_idx, zero_const)))
|
| 163 |
+
body.append(("debug", ("compare", tmp_idx, (round, i, "wrapped_idx"))))
|
| 164 |
+
# mem[inp_indices_p + i] = idx
|
| 165 |
+
body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
|
| 166 |
+
body.append(("store", ("store", tmp_addr, tmp_idx)))
|
| 167 |
+
# mem[inp_values_p + i] = val
|
| 168 |
+
body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
|
| 169 |
+
body.append(("store", ("store", tmp_addr, tmp_val)))
|
| 170 |
+
|
| 171 |
+
body_instrs = self.build(body)
|
| 172 |
+
self.instrs.extend(body_instrs)
|
| 173 |
+
# Required to match with the yield in reference_kernel2
|
| 174 |
+
self.instrs.append({"flow": [("pause",)]})
|
| 175 |
+
|
| 176 |
+
BASELINE = 147734
|
| 177 |
+
|
| 178 |
+
def do_kernel_test(
|
| 179 |
+
forest_height: int,
|
| 180 |
+
rounds: int,
|
| 181 |
+
batch_size: int,
|
| 182 |
+
seed: int = 123,
|
| 183 |
+
trace: bool = False,
|
| 184 |
+
prints: bool = False,
|
| 185 |
+
):
|
| 186 |
+
print(f"{forest_height=}, {rounds=}, {batch_size=}")
|
| 187 |
+
random.seed(seed)
|
| 188 |
+
forest = Tree.generate(forest_height)
|
| 189 |
+
inp = Input.generate(forest, batch_size, rounds)
|
| 190 |
+
mem = build_mem_image(forest, inp)
|
| 191 |
+
|
| 192 |
+
kb = KernelBuilder()
|
| 193 |
+
kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)
|
| 194 |
+
# print(kb.instrs)
|
| 195 |
+
|
| 196 |
+
value_trace = {}
|
| 197 |
+
machine = Machine(
|
| 198 |
+
mem,
|
| 199 |
+
kb.instrs,
|
| 200 |
+
kb.debug_info(),
|
| 201 |
+
n_cores=N_CORES,
|
| 202 |
+
value_trace=value_trace,
|
| 203 |
+
trace=trace,
|
| 204 |
+
)
|
| 205 |
+
machine.prints = prints
|
| 206 |
+
for i, ref_mem in enumerate(reference_kernel2(mem, value_trace)):
|
| 207 |
+
machine.run()
|
| 208 |
+
inp_values_p = ref_mem[6]
|
| 209 |
+
if prints:
|
| 210 |
+
print(machine.mem[inp_values_p : inp_values_p + len(inp.values)])
|
| 211 |
+
print(ref_mem[inp_values_p : inp_values_p + len(inp.values)])
|
| 212 |
+
assert (
|
| 213 |
+
machine.mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 214 |
+
== ref_mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 215 |
+
), f"Incorrect result on round {i}"
|
| 216 |
+
inp_indices_p = ref_mem[5]
|
| 217 |
+
if prints:
|
| 218 |
+
print(machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
|
| 219 |
+
print(ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])
|
| 220 |
+
# Updating these in memory isn't required, but you can enable this check for debugging
|
| 221 |
+
# assert machine.mem[inp_indices_p:inp_indices_p+len(inp.indices)] == ref_mem[inp_indices_p:inp_indices_p+len(inp.indices)]
|
| 222 |
+
|
| 223 |
+
print("CYCLES: ", machine.cycle)
|
| 224 |
+
print("Speedup over baseline: ", BASELINE / machine.cycle)
|
| 225 |
+
return machine.cycle
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class Tests(unittest.TestCase):
|
| 229 |
+
def test_ref_kernels(self):
|
| 230 |
+
"""
|
| 231 |
+
Test the reference kernels against each other
|
| 232 |
+
"""
|
| 233 |
+
random.seed(123)
|
| 234 |
+
for i in range(10):
|
| 235 |
+
f = Tree.generate(4)
|
| 236 |
+
inp = Input.generate(f, 10, 6)
|
| 237 |
+
mem = build_mem_image(f, inp)
|
| 238 |
+
reference_kernel(f, inp)
|
| 239 |
+
for _ in reference_kernel2(mem, {}):
|
| 240 |
+
pass
|
| 241 |
+
assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)]
|
| 242 |
+
assert inp.values == mem[mem[6] : mem[6] + len(inp.values)]
|
| 243 |
+
|
| 244 |
+
def test_kernel_trace(self):
|
| 245 |
+
# Full-scale example for performance testing
|
| 246 |
+
do_kernel_test(10, 16, 256, trace=True, prints=False)
|
| 247 |
+
|
| 248 |
+
# Passing this test is not required for submission, see submission_tests.py for the actual correctness test
|
| 249 |
+
# You can uncomment this if you think it might help you debug
|
| 250 |
+
# def test_kernel_correctness(self):
|
| 251 |
+
# for batch in range(1, 3):
|
| 252 |
+
# for forest_height in range(3):
|
| 253 |
+
# do_kernel_test(
|
| 254 |
+
# forest_height + 2, forest_height + 4, batch * 16 * VLEN * N_CORES
|
| 255 |
+
# )
|
| 256 |
+
|
| 257 |
+
def test_kernel_cycles(self):
|
| 258 |
+
do_kernel_test(10, 16, 256)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# To run all the tests:
|
| 262 |
+
# python perf_takehome.py
|
| 263 |
+
# To run a specific test:
|
| 264 |
+
# python perf_takehome.py Tests.test_kernel_cycles
|
| 265 |
+
# To view a hot-reloading trace of all the instructions: **Recommended debug loop**
|
| 266 |
+
# NOTE: The trace hot-reloading only works in Chrome. In the worst case if things aren't working, drag trace.json onto https://ui.perfetto.dev/
|
| 267 |
+
# python perf_takehome.py Tests.test_kernel_trace
|
| 268 |
+
# Then run `python watch_trace.py` in another tab, it'll open a browser tab, then click "Open Perfetto"
|
| 269 |
+
# You can then keep that open and re-run the test to see a new trace.
|
| 270 |
+
|
| 271 |
+
# To run the proper checks to see which thresholds you pass:
|
| 272 |
+
# python tests/submission_tests.py
|
| 273 |
+
|
| 274 |
+
if __name__ == "__main__":
|
| 275 |
+
unittest.main()
|
original_performance_takehome/problem.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Read the top of perf_takehome.py for more introduction.
|
| 3 |
+
|
| 4 |
+
This file is separate mostly for ease of copying it to freeze the machine and
|
| 5 |
+
reference kernel for testing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from copy import copy
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Any, Literal
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
Engine = Literal["alu", "load", "store", "flow"]
|
| 15 |
+
Instruction = dict[Engine, list[tuple]]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CoreState(Enum):
|
| 19 |
+
RUNNING = 1
|
| 20 |
+
PAUSED = 2
|
| 21 |
+
STOPPED = 3
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class Core:
|
| 26 |
+
id: int
|
| 27 |
+
scratch: list[int]
|
| 28 |
+
trace_buf: list[int]
|
| 29 |
+
pc: int = 0
|
| 30 |
+
state: CoreState = CoreState.RUNNING
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class DebugInfo:
|
| 35 |
+
"""
|
| 36 |
+
We give you some debug info but it's up to you to use it in Machine if you
|
| 37 |
+
want to. You're also welcome to add more.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
# Maps scratch variable addr to (name, len) pair
|
| 41 |
+
scratch_map: dict[int, (str, int)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def cdiv(a, b):
|
| 45 |
+
return (a + b - 1) // b
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
SLOT_LIMITS = {
|
| 49 |
+
"alu": 12,
|
| 50 |
+
"valu": 6,
|
| 51 |
+
"load": 2,
|
| 52 |
+
"store": 2,
|
| 53 |
+
"flow": 1,
|
| 54 |
+
"debug": 64,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
VLEN = 8
|
| 58 |
+
# Older versions of the take-home used multiple cores, but this version only uses 1
|
| 59 |
+
N_CORES = 1
|
| 60 |
+
SCRATCH_SIZE = 1536
|
| 61 |
+
BASE_ADDR_TID = 100000
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Machine:
|
| 65 |
+
"""
|
| 66 |
+
Simulator for a custom VLIW SIMD architecture.
|
| 67 |
+
|
| 68 |
+
VLIW (Very Large Instruction Word): Cores are composed of different
|
| 69 |
+
"engines" each of which can execute multiple "slots" per cycle in parallel.
|
| 70 |
+
How many slots each engine can execute per cycle is limited by SLOT_LIMITS.
|
| 71 |
+
Effects of instructions don't take effect until the end of cycle. Each
|
| 72 |
+
cycle, all engines execute all of their filled slots for that instruction.
|
| 73 |
+
Effects like writes to memory take place after all the inputs are read.
|
| 74 |
+
|
| 75 |
+
SIMD: There are instructions for acting on vectors of VLEN elements in a
|
| 76 |
+
single slot. You can use vload and vstore to load multiple contiguous
|
| 77 |
+
elements but not non-contiguous elements. Use vbroadcast to broadcast a
|
| 78 |
+
scalar to a vector and then operate on vectors with valu instructions.
|
| 79 |
+
|
| 80 |
+
The memory and scratch space are composed of 32-bit words. The solution is
|
| 81 |
+
plucked out of the memory at the end of the program. You can think of the
|
| 82 |
+
scratch space as serving the purpose of registers, constant memory, and a
|
| 83 |
+
manually-managed cache.
|
| 84 |
+
|
| 85 |
+
Here's an example of what an instruction might look like:
|
| 86 |
+
|
| 87 |
+
{"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}
|
| 88 |
+
|
| 89 |
+
In general every number in an instruction is a scratch address except for
|
| 90 |
+
const and jump, and except for store and some flow instructions the first
|
| 91 |
+
operand is the destination.
|
| 92 |
+
|
| 93 |
+
This comment is not meant to be full ISA documentation though, for the rest
|
| 94 |
+
you should look through the simulator code.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
mem_dump: list[int],
|
| 100 |
+
program: list[Instruction],
|
| 101 |
+
debug_info: DebugInfo,
|
| 102 |
+
n_cores: int = 1,
|
| 103 |
+
scratch_size: int = SCRATCH_SIZE,
|
| 104 |
+
trace: bool = False,
|
| 105 |
+
value_trace: dict[Any, int] = {},
|
| 106 |
+
):
|
| 107 |
+
self.cores = [
|
| 108 |
+
Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
|
| 109 |
+
]
|
| 110 |
+
self.mem = copy(mem_dump)
|
| 111 |
+
self.program = program
|
| 112 |
+
self.debug_info = debug_info
|
| 113 |
+
self.value_trace = value_trace
|
| 114 |
+
self.prints = False
|
| 115 |
+
self.cycle = 0
|
| 116 |
+
self.enable_pause = True
|
| 117 |
+
self.enable_debug = True
|
| 118 |
+
if trace:
|
| 119 |
+
self.setup_trace()
|
| 120 |
+
else:
|
| 121 |
+
self.trace = None
|
| 122 |
+
|
| 123 |
+
def rewrite_instr(self, instr):
|
| 124 |
+
"""
|
| 125 |
+
Rewrite an instruction to use scratch addresses instead of names
|
| 126 |
+
"""
|
| 127 |
+
res = {}
|
| 128 |
+
for name, slots in instr.items():
|
| 129 |
+
res[name] = []
|
| 130 |
+
for slot in slots:
|
| 131 |
+
res[name].append(self.rewrite_slot(slot))
|
| 132 |
+
return res
|
| 133 |
+
|
| 134 |
+
def print_step(self, instr, core):
|
| 135 |
+
# print(core.id)
|
| 136 |
+
# print(core.trace_buf)
|
| 137 |
+
print(self.scratch_map(core))
|
| 138 |
+
print(core.pc, instr, self.rewrite_instr(instr))
|
| 139 |
+
|
| 140 |
+
def scratch_map(self, core):
|
| 141 |
+
res = {}
|
| 142 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 143 |
+
res[name] = core.scratch[addr : addr + length]
|
| 144 |
+
return res
|
| 145 |
+
|
| 146 |
+
def rewrite_slot(self, slot):
|
| 147 |
+
return tuple(
|
| 148 |
+
self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def setup_trace(self):
|
| 152 |
+
"""
|
| 153 |
+
The simulator generates traces in Chrome's Trace Event Format for
|
| 154 |
+
visualization in Perfetto (or chrome://tracing if you prefer it). See
|
| 155 |
+
the bottom of the file for info about how to use this.
|
| 156 |
+
|
| 157 |
+
See the format docs in case you want to add more info to the trace:
|
| 158 |
+
https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
| 159 |
+
"""
|
| 160 |
+
self.trace = open("trace.json", "w")
|
| 161 |
+
self.trace.write("[")
|
| 162 |
+
tid_counter = 0
|
| 163 |
+
self.tids = {}
|
| 164 |
+
for ci, core in enumerate(self.cores):
|
| 165 |
+
self.trace.write(
|
| 166 |
+
f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
|
| 167 |
+
)
|
| 168 |
+
for name, limit in SLOT_LIMITS.items():
|
| 169 |
+
if name == "debug":
|
| 170 |
+
continue
|
| 171 |
+
for i in range(limit):
|
| 172 |
+
tid_counter += 1
|
| 173 |
+
self.trace.write(
|
| 174 |
+
f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
|
| 175 |
+
)
|
| 176 |
+
self.tids[(ci, name, i)] = tid_counter
|
| 177 |
+
|
| 178 |
+
# Add zero-length events at the start so all slots show up in Perfetto
|
| 179 |
+
for ci, core in enumerate(self.cores):
|
| 180 |
+
for name, limit in SLOT_LIMITS.items():
|
| 181 |
+
if name == "debug":
|
| 182 |
+
continue
|
| 183 |
+
for i in range(limit):
|
| 184 |
+
tid = self.tids[(ci, name, i)]
|
| 185 |
+
self.trace.write(
|
| 186 |
+
f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
|
| 187 |
+
)
|
| 188 |
+
for ci, core in enumerate(self.cores):
|
| 189 |
+
self.trace.write(
|
| 190 |
+
f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
|
| 191 |
+
)
|
| 192 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 193 |
+
self.trace.write(
|
| 194 |
+
f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def run(self):
|
| 198 |
+
for core in self.cores:
|
| 199 |
+
if core.state == CoreState.PAUSED:
|
| 200 |
+
core.state = CoreState.RUNNING
|
| 201 |
+
while any(c.state == CoreState.RUNNING for c in self.cores):
|
| 202 |
+
has_non_debug = False
|
| 203 |
+
for core in self.cores:
|
| 204 |
+
if core.state != CoreState.RUNNING:
|
| 205 |
+
continue
|
| 206 |
+
if core.pc >= len(self.program):
|
| 207 |
+
core.state = CoreState.STOPPED
|
| 208 |
+
continue
|
| 209 |
+
instr = self.program[core.pc]
|
| 210 |
+
if self.prints:
|
| 211 |
+
self.print_step(instr, core)
|
| 212 |
+
core.pc += 1
|
| 213 |
+
self.step(instr, core)
|
| 214 |
+
if any(name != "debug" for name in instr.keys()):
|
| 215 |
+
has_non_debug = True
|
| 216 |
+
if has_non_debug:
|
| 217 |
+
self.cycle += 1
|
| 218 |
+
|
| 219 |
+
def alu(self, core, op, dest, a1, a2):
|
| 220 |
+
a1 = core.scratch[a1]
|
| 221 |
+
a2 = core.scratch[a2]
|
| 222 |
+
match op:
|
| 223 |
+
case "+":
|
| 224 |
+
res = a1 + a2
|
| 225 |
+
case "-":
|
| 226 |
+
res = a1 - a2
|
| 227 |
+
case "*":
|
| 228 |
+
res = a1 * a2
|
| 229 |
+
case "//":
|
| 230 |
+
res = a1 // a2
|
| 231 |
+
case "cdiv":
|
| 232 |
+
res = cdiv(a1, a2)
|
| 233 |
+
case "^":
|
| 234 |
+
res = a1 ^ a2
|
| 235 |
+
case "&":
|
| 236 |
+
res = a1 & a2
|
| 237 |
+
case "|":
|
| 238 |
+
res = a1 | a2
|
| 239 |
+
case "<<":
|
| 240 |
+
res = a1 << a2
|
| 241 |
+
case ">>":
|
| 242 |
+
res = a1 >> a2
|
| 243 |
+
case "%":
|
| 244 |
+
res = a1 % a2
|
| 245 |
+
case "<":
|
| 246 |
+
res = int(a1 < a2)
|
| 247 |
+
case "==":
|
| 248 |
+
res = int(a1 == a2)
|
| 249 |
+
case _:
|
| 250 |
+
raise NotImplementedError(f"Unknown alu op {op}")
|
| 251 |
+
res = res % (2**32)
|
| 252 |
+
self.scratch_write[dest] = res
|
| 253 |
+
|
| 254 |
+
def valu(self, core, *slot):
|
| 255 |
+
match slot:
|
| 256 |
+
case ("vbroadcast", dest, src):
|
| 257 |
+
for i in range(VLEN):
|
| 258 |
+
self.scratch_write[dest + i] = core.scratch[src]
|
| 259 |
+
case ("multiply_add", dest, a, b, c):
|
| 260 |
+
for i in range(VLEN):
|
| 261 |
+
mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
|
| 262 |
+
self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
|
| 263 |
+
case (op, dest, a1, a2):
|
| 264 |
+
for i in range(VLEN):
|
| 265 |
+
self.alu(core, op, dest + i, a1 + i, a2 + i)
|
| 266 |
+
case _:
|
| 267 |
+
raise NotImplementedError(f"Unknown valu op {slot}")
|
| 268 |
+
|
| 269 |
+
def load(self, core, *slot):
|
| 270 |
+
match slot:
|
| 271 |
+
case ("load", dest, addr):
|
| 272 |
+
# print(dest, addr, core.scratch[addr])
|
| 273 |
+
self.scratch_write[dest] = self.mem[core.scratch[addr]]
|
| 274 |
+
case ("load_offset", dest, addr, offset):
|
| 275 |
+
# Handy for treating vector dest and addr as a full block in the mini-compiler if you want
|
| 276 |
+
self.scratch_write[dest + offset] = self.mem[
|
| 277 |
+
core.scratch[addr + offset]
|
| 278 |
+
]
|
| 279 |
+
case ("vload", dest, addr): # addr is a scalar
|
| 280 |
+
addr = core.scratch[addr]
|
| 281 |
+
for vi in range(VLEN):
|
| 282 |
+
self.scratch_write[dest + vi] = self.mem[addr + vi]
|
| 283 |
+
case ("const", dest, val):
|
| 284 |
+
self.scratch_write[dest] = (val) % (2**32)
|
| 285 |
+
case _:
|
| 286 |
+
raise NotImplementedError(f"Unknown load op {slot}")
|
| 287 |
+
|
| 288 |
+
def store(self, core, *slot):
|
| 289 |
+
match slot:
|
| 290 |
+
case ("store", addr, src):
|
| 291 |
+
addr = core.scratch[addr]
|
| 292 |
+
self.mem_write[addr] = core.scratch[src]
|
| 293 |
+
case ("vstore", addr, src): # addr is a scalar
|
| 294 |
+
addr = core.scratch[addr]
|
| 295 |
+
for vi in range(VLEN):
|
| 296 |
+
self.mem_write[addr + vi] = core.scratch[src + vi]
|
| 297 |
+
case _:
|
| 298 |
+
raise NotImplementedError(f"Unknown store op {slot}")
|
| 299 |
+
|
| 300 |
+
def flow(self, core, *slot):
|
| 301 |
+
match slot:
|
| 302 |
+
case ("select", dest, cond, a, b):
|
| 303 |
+
self.scratch_write[dest] = (
|
| 304 |
+
core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
|
| 305 |
+
)
|
| 306 |
+
case ("add_imm", dest, a, imm):
|
| 307 |
+
self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
|
| 308 |
+
case ("vselect", dest, cond, a, b):
|
| 309 |
+
for vi in range(VLEN):
|
| 310 |
+
self.scratch_write[dest + vi] = (
|
| 311 |
+
core.scratch[a + vi]
|
| 312 |
+
if core.scratch[cond + vi] != 0
|
| 313 |
+
else core.scratch[b + vi]
|
| 314 |
+
)
|
| 315 |
+
case ("halt",):
|
| 316 |
+
core.state = CoreState.STOPPED
|
| 317 |
+
case ("pause",):
|
| 318 |
+
if self.enable_pause:
|
| 319 |
+
core.state = CoreState.PAUSED
|
| 320 |
+
case ("trace_write", val):
|
| 321 |
+
core.trace_buf.append(core.scratch[val])
|
| 322 |
+
case ("cond_jump", cond, addr):
|
| 323 |
+
if core.scratch[cond] != 0:
|
| 324 |
+
core.pc = addr
|
| 325 |
+
case ("cond_jump_rel", cond, offset):
|
| 326 |
+
if core.scratch[cond] != 0:
|
| 327 |
+
core.pc += offset
|
| 328 |
+
case ("jump", addr):
|
| 329 |
+
core.pc = addr
|
| 330 |
+
case ("jump_indirect", addr):
|
| 331 |
+
core.pc = core.scratch[addr]
|
| 332 |
+
case ("coreid", dest):
|
| 333 |
+
self.scratch_write[dest] = core.id
|
| 334 |
+
case _:
|
| 335 |
+
raise NotImplementedError(f"Unknown flow op {slot}")
|
| 336 |
+
|
| 337 |
+
def trace_post_step(self, instr, core):
|
| 338 |
+
# You can add extra stuff to the trace if you want!
|
| 339 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 340 |
+
if any((addr + vi) in self.scratch_write for vi in range(length)):
|
| 341 |
+
val = str(core.scratch[addr : addr + length])
|
| 342 |
+
val = val.replace("[", "").replace("]", "")
|
| 343 |
+
self.trace.write(
|
| 344 |
+
f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def trace_slot(self, core, slot, name, i):
|
| 348 |
+
self.trace.write(
|
| 349 |
+
f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def step(self, instr: Instruction, core):
|
| 353 |
+
"""
|
| 354 |
+
Execute all the slots in each engine for a single instruction bundle
|
| 355 |
+
"""
|
| 356 |
+
ENGINE_FNS = {
|
| 357 |
+
"alu": self.alu,
|
| 358 |
+
"valu": self.valu,
|
| 359 |
+
"load": self.load,
|
| 360 |
+
"store": self.store,
|
| 361 |
+
"flow": self.flow,
|
| 362 |
+
}
|
| 363 |
+
self.scratch_write = {}
|
| 364 |
+
self.mem_write = {}
|
| 365 |
+
for name, slots in instr.items():
|
| 366 |
+
if name == "debug":
|
| 367 |
+
if not self.enable_debug:
|
| 368 |
+
continue
|
| 369 |
+
for slot in slots:
|
| 370 |
+
if slot[0] == "compare":
|
| 371 |
+
loc, key = slot[1], slot[2]
|
| 372 |
+
ref = self.value_trace[key]
|
| 373 |
+
res = core.scratch[loc]
|
| 374 |
+
assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
|
| 375 |
+
elif slot[0] == "vcompare":
|
| 376 |
+
loc, keys = slot[1], slot[2]
|
| 377 |
+
ref = [self.value_trace[key] for key in keys]
|
| 378 |
+
res = core.scratch[loc : loc + VLEN]
|
| 379 |
+
assert res == ref, (
|
| 380 |
+
f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
|
| 381 |
+
)
|
| 382 |
+
continue
|
| 383 |
+
assert len(slots) <= SLOT_LIMITS[name]
|
| 384 |
+
for i, slot in enumerate(slots):
|
| 385 |
+
if self.trace is not None:
|
| 386 |
+
self.trace_slot(core, slot, name, i)
|
| 387 |
+
ENGINE_FNS[name](core, *slot)
|
| 388 |
+
for addr, val in self.scratch_write.items():
|
| 389 |
+
core.scratch[addr] = val
|
| 390 |
+
for addr, val in self.mem_write.items():
|
| 391 |
+
self.mem[addr] = val
|
| 392 |
+
|
| 393 |
+
if self.trace:
|
| 394 |
+
self.trace_post_step(instr, core)
|
| 395 |
+
|
| 396 |
+
del self.scratch_write
|
| 397 |
+
del self.mem_write
|
| 398 |
+
|
| 399 |
+
def __del__(self):
|
| 400 |
+
if self.trace is not None:
|
| 401 |
+
self.trace.write("]")
|
| 402 |
+
self.trace.close()
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@dataclass
|
| 406 |
+
class Tree:
|
| 407 |
+
"""
|
| 408 |
+
An implicit perfect balanced binary tree with values on the nodes.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
height: int
|
| 412 |
+
values: list[int]
|
| 413 |
+
|
| 414 |
+
@staticmethod
|
| 415 |
+
def generate(height: int):
|
| 416 |
+
n_nodes = 2 ** (height + 1) - 1
|
| 417 |
+
values = [random.randint(0, 2**30 - 1) for _ in range(n_nodes)]
|
| 418 |
+
return Tree(height, values)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@dataclass
|
| 422 |
+
class Input:
|
| 423 |
+
"""
|
| 424 |
+
A batch of inputs, indices to nodes (starting as 0) and initial input
|
| 425 |
+
values. We then iterate these for a specified number of rounds.
|
| 426 |
+
"""
|
| 427 |
+
|
| 428 |
+
indices: list[int]
|
| 429 |
+
values: list[int]
|
| 430 |
+
rounds: int
|
| 431 |
+
|
| 432 |
+
@staticmethod
|
| 433 |
+
def generate(forest: Tree, batch_size: int, rounds: int):
|
| 434 |
+
indices = [0 for _ in range(batch_size)]
|
| 435 |
+
values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
|
| 436 |
+
return Input(indices, values, rounds)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
HASH_STAGES = [
|
| 440 |
+
("+", 0x7ED55D16, "+", "<<", 12),
|
| 441 |
+
("^", 0xC761C23C, "^", ">>", 19),
|
| 442 |
+
("+", 0x165667B1, "+", "<<", 5),
|
| 443 |
+
("+", 0xD3A2646C, "^", "<<", 9),
|
| 444 |
+
("+", 0xFD7046C5, "+", "<<", 3),
|
| 445 |
+
("^", 0xB55A4F09, "^", ">>", 16),
|
| 446 |
+
]
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def myhash(a: int) -> int:
|
| 450 |
+
"""A simple 32-bit hash function"""
|
| 451 |
+
fns = {
|
| 452 |
+
"+": lambda x, y: x + y,
|
| 453 |
+
"^": lambda x, y: x ^ y,
|
| 454 |
+
"<<": lambda x, y: x << y,
|
| 455 |
+
">>": lambda x, y: x >> y,
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
def r(x):
|
| 459 |
+
return x % (2**32)
|
| 460 |
+
|
| 461 |
+
for op1, val1, op2, op3, val3 in HASH_STAGES:
|
| 462 |
+
a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
|
| 463 |
+
|
| 464 |
+
return a
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def reference_kernel(t: Tree, inp: Input):
|
| 468 |
+
"""
|
| 469 |
+
Reference implementation of the kernel.
|
| 470 |
+
|
| 471 |
+
A parallel tree traversal where at each node we set
|
| 472 |
+
cur_inp_val = myhash(cur_inp_val ^ node_val)
|
| 473 |
+
and then choose the left branch if cur_inp_val is even.
|
| 474 |
+
If we reach the bottom of the tree we wrap around to the top.
|
| 475 |
+
"""
|
| 476 |
+
for h in range(inp.rounds):
|
| 477 |
+
for i in range(len(inp.indices)):
|
| 478 |
+
idx = inp.indices[i]
|
| 479 |
+
val = inp.values[i]
|
| 480 |
+
val = myhash(val ^ t.values[idx])
|
| 481 |
+
idx = 2 * idx + (1 if val % 2 == 0 else 2)
|
| 482 |
+
idx = 0 if idx >= len(t.values) else idx
|
| 483 |
+
inp.values[i] = val
|
| 484 |
+
inp.indices[i] = idx
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def build_mem_image(t: Tree, inp: Input) -> list[int]:
|
| 488 |
+
"""
|
| 489 |
+
Build a flat memory image of the problem.
|
| 490 |
+
"""
|
| 491 |
+
header = 7
|
| 492 |
+
extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
|
| 493 |
+
mem = [0] * (
|
| 494 |
+
header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
|
| 495 |
+
)
|
| 496 |
+
forest_values_p = header
|
| 497 |
+
inp_indices_p = forest_values_p + len(t.values)
|
| 498 |
+
inp_values_p = inp_indices_p + len(inp.values)
|
| 499 |
+
extra_room = inp_values_p + len(inp.values)
|
| 500 |
+
|
| 501 |
+
mem[0] = inp.rounds
|
| 502 |
+
mem[1] = len(t.values)
|
| 503 |
+
mem[2] = len(inp.indices)
|
| 504 |
+
mem[3] = t.height
|
| 505 |
+
mem[4] = forest_values_p
|
| 506 |
+
mem[5] = inp_indices_p
|
| 507 |
+
mem[6] = inp_values_p
|
| 508 |
+
mem[7] = extra_room
|
| 509 |
+
|
| 510 |
+
mem[header:inp_indices_p] = t.values
|
| 511 |
+
mem[inp_indices_p:inp_values_p] = inp.indices
|
| 512 |
+
mem[inp_values_p:] = inp.values
|
| 513 |
+
return mem
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
|
| 517 |
+
"""A simple 32-bit hash function"""
|
| 518 |
+
fns = {
|
| 519 |
+
"+": lambda x, y: x + y,
|
| 520 |
+
"^": lambda x, y: x ^ y,
|
| 521 |
+
"<<": lambda x, y: x << y,
|
| 522 |
+
">>": lambda x, y: x >> y,
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
def r(x):
|
| 526 |
+
return x % (2**32)
|
| 527 |
+
|
| 528 |
+
for i, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
|
| 529 |
+
a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
|
| 530 |
+
trace[(round, batch_i, "hash_stage", i)] = a
|
| 531 |
+
|
| 532 |
+
return a
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
|
| 536 |
+
"""
|
| 537 |
+
Reference implementation of the kernel on a flat memory.
|
| 538 |
+
"""
|
| 539 |
+
# This is the initial memory layout
|
| 540 |
+
rounds = mem[0]
|
| 541 |
+
n_nodes = mem[1]
|
| 542 |
+
batch_size = mem[2]
|
| 543 |
+
forest_height = mem[3]
|
| 544 |
+
# Offsets into the memory which indices get added to
|
| 545 |
+
forest_values_p = mem[4]
|
| 546 |
+
inp_indices_p = mem[5]
|
| 547 |
+
inp_values_p = mem[6]
|
| 548 |
+
yield mem
|
| 549 |
+
for h in range(rounds):
|
| 550 |
+
for i in range(batch_size):
|
| 551 |
+
idx = mem[inp_indices_p + i]
|
| 552 |
+
trace[(h, i, "idx")] = idx
|
| 553 |
+
val = mem[inp_values_p + i]
|
| 554 |
+
trace[(h, i, "val")] = val
|
| 555 |
+
node_val = mem[forest_values_p + idx]
|
| 556 |
+
trace[(h, i, "node_val")] = node_val
|
| 557 |
+
val = myhash_traced(val ^ node_val, trace, h, i)
|
| 558 |
+
trace[(h, i, "hashed_val")] = val
|
| 559 |
+
idx = 2 * idx + (1 if val % 2 == 0 else 2)
|
| 560 |
+
trace[(h, i, "next_idx")] = idx
|
| 561 |
+
idx = 0 if idx >= n_nodes else idx
|
| 562 |
+
trace[(h, i, "wrapped_idx")] = idx
|
| 563 |
+
mem[inp_values_p + i] = val
|
| 564 |
+
mem[inp_indices_p + i] = idx
|
| 565 |
+
# You can add new yields or move this around for debugging
|
| 566 |
+
# as long as it's matched by pause instructions.
|
| 567 |
+
# The submission tests evaluate only on final memory.
|
| 568 |
+
yield mem
|
original_performance_takehome/tests/frozen_problem.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Read the top of perf_takehome.py for more introduction.
|
| 3 |
+
|
| 4 |
+
This file is separate mostly for ease of copying it to freeze the machine and
|
| 5 |
+
reference kernel for testing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from copy import copy
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Any, Literal
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
Engine = Literal["alu", "load", "store", "flow"]
|
| 15 |
+
Instruction = dict[Engine, list[tuple]]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CoreState(Enum):
|
| 19 |
+
RUNNING = 1
|
| 20 |
+
PAUSED = 2
|
| 21 |
+
STOPPED = 3
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class Core:
|
| 26 |
+
id: int
|
| 27 |
+
scratch: list[int]
|
| 28 |
+
trace_buf: list[int]
|
| 29 |
+
pc: int = 0
|
| 30 |
+
state: CoreState = CoreState.RUNNING
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class DebugInfo:
|
| 35 |
+
"""
|
| 36 |
+
We give you some debug info but it's up to you to use it in Machine if you
|
| 37 |
+
want to. You're also welcome to add more.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
# Maps scratch variable addr to (name, len) pair
|
| 41 |
+
scratch_map: dict[int, (str, int)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def cdiv(a, b):
|
| 45 |
+
return (a + b - 1) // b
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
SLOT_LIMITS = {
|
| 49 |
+
"alu": 12,
|
| 50 |
+
"valu": 6,
|
| 51 |
+
"load": 2,
|
| 52 |
+
"store": 2,
|
| 53 |
+
"flow": 1,
|
| 54 |
+
"debug": 64,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
VLEN = 8
|
| 58 |
+
# Older versions of the take-home used multiple cores, but this version only uses 1
|
| 59 |
+
N_CORES = 1
|
| 60 |
+
SCRATCH_SIZE = 1536
|
| 61 |
+
BASE_ADDR_TID = 100000
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Machine:
|
| 65 |
+
"""
|
| 66 |
+
Simulator for a custom VLIW SIMD architecture.
|
| 67 |
+
|
| 68 |
+
VLIW (Very Large Instruction Word): Cores are composed of different
|
| 69 |
+
"engines" each of which can execute multiple "slots" per cycle in parallel.
|
| 70 |
+
How many slots each engine can execute per cycle is limited by SLOT_LIMITS.
|
| 71 |
+
Effects of instructions don't take effect until the end of cycle. Each
|
| 72 |
+
cycle, all engines execute all of their filled slots for that instruction.
|
| 73 |
+
Effects like writes to memory take place after all the inputs are read.
|
| 74 |
+
|
| 75 |
+
SIMD: There are instructions for acting on vectors of VLEN elements in a
|
| 76 |
+
single slot. You can use vload and vstore to load multiple contiguous
|
| 77 |
+
elements but not non-contiguous elements. Use vbroadcast to broadcast a
|
| 78 |
+
scalar to a vector and then operate on vectors with valu instructions.
|
| 79 |
+
|
| 80 |
+
The memory and scratch space are composed of 32-bit words. The solution is
|
| 81 |
+
plucked out of the memory at the end of the program. You can think of the
|
| 82 |
+
scratch space as serving the purpose of registers, constant memory, and a
|
| 83 |
+
manually-managed cache.
|
| 84 |
+
|
| 85 |
+
Here's an example of what an instruction might look like:
|
| 86 |
+
|
| 87 |
+
{"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}
|
| 88 |
+
|
| 89 |
+
In general every number in an instruction is a scratch address except for
|
| 90 |
+
const and jump, and except for store and some flow instructions the first
|
| 91 |
+
operand is the destination.
|
| 92 |
+
|
| 93 |
+
This comment is not meant to be full ISA documentation though, for the rest
|
| 94 |
+
you should look through the simulator code.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
mem_dump: list[int],
|
| 100 |
+
program: list[Instruction],
|
| 101 |
+
debug_info: DebugInfo,
|
| 102 |
+
n_cores: int = 1,
|
| 103 |
+
scratch_size: int = SCRATCH_SIZE,
|
| 104 |
+
trace: bool = False,
|
| 105 |
+
value_trace: dict[Any, int] = {},
|
| 106 |
+
):
|
| 107 |
+
self.cores = [
|
| 108 |
+
Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
|
| 109 |
+
]
|
| 110 |
+
self.mem = copy(mem_dump)
|
| 111 |
+
self.program = program
|
| 112 |
+
self.debug_info = debug_info
|
| 113 |
+
self.value_trace = value_trace
|
| 114 |
+
self.prints = False
|
| 115 |
+
self.cycle = 0
|
| 116 |
+
self.enable_pause = True
|
| 117 |
+
self.enable_debug = True
|
| 118 |
+
if trace:
|
| 119 |
+
self.setup_trace()
|
| 120 |
+
else:
|
| 121 |
+
self.trace = None
|
| 122 |
+
|
| 123 |
+
def rewrite_instr(self, instr):
|
| 124 |
+
"""
|
| 125 |
+
Rewrite an instruction to use scratch addresses instead of names
|
| 126 |
+
"""
|
| 127 |
+
res = {}
|
| 128 |
+
for name, slots in instr.items():
|
| 129 |
+
res[name] = []
|
| 130 |
+
for slot in slots:
|
| 131 |
+
res[name].append(self.rewrite_slot(slot))
|
| 132 |
+
return res
|
| 133 |
+
|
| 134 |
+
def print_step(self, instr, core):
|
| 135 |
+
# print(core.id)
|
| 136 |
+
# print(core.trace_buf)
|
| 137 |
+
print(self.scratch_map(core))
|
| 138 |
+
print(core.pc, instr, self.rewrite_instr(instr))
|
| 139 |
+
|
| 140 |
+
def scratch_map(self, core):
|
| 141 |
+
res = {}
|
| 142 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 143 |
+
res[name] = core.scratch[addr : addr + length]
|
| 144 |
+
return res
|
| 145 |
+
|
| 146 |
+
def rewrite_slot(self, slot):
|
| 147 |
+
return tuple(
|
| 148 |
+
self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def setup_trace(self):
|
| 152 |
+
"""
|
| 153 |
+
The simulator generates traces in Chrome's Trace Event Format for
|
| 154 |
+
visualization in Perfetto (or chrome://tracing if you prefer it). See
|
| 155 |
+
the bottom of the file for info about how to use this.
|
| 156 |
+
|
| 157 |
+
See the format docs in case you want to add more info to the trace:
|
| 158 |
+
https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
| 159 |
+
"""
|
| 160 |
+
self.trace = open("trace.json", "w")
|
| 161 |
+
self.trace.write("[")
|
| 162 |
+
tid_counter = 0
|
| 163 |
+
self.tids = {}
|
| 164 |
+
for ci, core in enumerate(self.cores):
|
| 165 |
+
self.trace.write(
|
| 166 |
+
f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
|
| 167 |
+
)
|
| 168 |
+
for name, limit in SLOT_LIMITS.items():
|
| 169 |
+
if name == "debug":
|
| 170 |
+
continue
|
| 171 |
+
for i in range(limit):
|
| 172 |
+
tid_counter += 1
|
| 173 |
+
self.trace.write(
|
| 174 |
+
f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
|
| 175 |
+
)
|
| 176 |
+
self.tids[(ci, name, i)] = tid_counter
|
| 177 |
+
|
| 178 |
+
# Add zero-length events at the start so all slots show up in Perfetto
|
| 179 |
+
for ci, core in enumerate(self.cores):
|
| 180 |
+
for name, limit in SLOT_LIMITS.items():
|
| 181 |
+
if name == "debug":
|
| 182 |
+
continue
|
| 183 |
+
for i in range(limit):
|
| 184 |
+
tid = self.tids[(ci, name, i)]
|
| 185 |
+
self.trace.write(
|
| 186 |
+
f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
|
| 187 |
+
)
|
| 188 |
+
for ci, core in enumerate(self.cores):
|
| 189 |
+
self.trace.write(
|
| 190 |
+
f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
|
| 191 |
+
)
|
| 192 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 193 |
+
self.trace.write(
|
| 194 |
+
f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def run(self):
|
| 198 |
+
for core in self.cores:
|
| 199 |
+
if core.state == CoreState.PAUSED:
|
| 200 |
+
core.state = CoreState.RUNNING
|
| 201 |
+
while any(c.state == CoreState.RUNNING for c in self.cores):
|
| 202 |
+
has_non_debug = False
|
| 203 |
+
for core in self.cores:
|
| 204 |
+
if core.state != CoreState.RUNNING:
|
| 205 |
+
continue
|
| 206 |
+
if core.pc >= len(self.program):
|
| 207 |
+
core.state = CoreState.STOPPED
|
| 208 |
+
continue
|
| 209 |
+
instr = self.program[core.pc]
|
| 210 |
+
if self.prints:
|
| 211 |
+
self.print_step(instr, core)
|
| 212 |
+
core.pc += 1
|
| 213 |
+
self.step(instr, core)
|
| 214 |
+
if any(name != "debug" for name in instr.keys()):
|
| 215 |
+
has_non_debug = True
|
| 216 |
+
if has_non_debug:
|
| 217 |
+
self.cycle += 1
|
| 218 |
+
|
| 219 |
+
def alu(self, core, op, dest, a1, a2):
|
| 220 |
+
a1 = core.scratch[a1]
|
| 221 |
+
a2 = core.scratch[a2]
|
| 222 |
+
match op:
|
| 223 |
+
case "+":
|
| 224 |
+
res = a1 + a2
|
| 225 |
+
case "-":
|
| 226 |
+
res = a1 - a2
|
| 227 |
+
case "*":
|
| 228 |
+
res = a1 * a2
|
| 229 |
+
case "//":
|
| 230 |
+
res = a1 // a2
|
| 231 |
+
case "cdiv":
|
| 232 |
+
res = cdiv(a1, a2)
|
| 233 |
+
case "^":
|
| 234 |
+
res = a1 ^ a2
|
| 235 |
+
case "&":
|
| 236 |
+
res = a1 & a2
|
| 237 |
+
case "|":
|
| 238 |
+
res = a1 | a2
|
| 239 |
+
case "<<":
|
| 240 |
+
res = a1 << a2
|
| 241 |
+
case ">>":
|
| 242 |
+
res = a1 >> a2
|
| 243 |
+
case "%":
|
| 244 |
+
res = a1 % a2
|
| 245 |
+
case "<":
|
| 246 |
+
res = int(a1 < a2)
|
| 247 |
+
case "==":
|
| 248 |
+
res = int(a1 == a2)
|
| 249 |
+
case _:
|
| 250 |
+
raise NotImplementedError(f"Unknown alu op {op}")
|
| 251 |
+
res = res % (2**32)
|
| 252 |
+
self.scratch_write[dest] = res
|
| 253 |
+
|
| 254 |
+
def valu(self, core, *slot):
|
| 255 |
+
match slot:
|
| 256 |
+
case ("vbroadcast", dest, src):
|
| 257 |
+
for i in range(VLEN):
|
| 258 |
+
self.scratch_write[dest + i] = core.scratch[src]
|
| 259 |
+
case ("multiply_add", dest, a, b, c):
|
| 260 |
+
for i in range(VLEN):
|
| 261 |
+
mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
|
| 262 |
+
self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
|
| 263 |
+
case (op, dest, a1, a2):
|
| 264 |
+
for i in range(VLEN):
|
| 265 |
+
self.alu(core, op, dest + i, a1 + i, a2 + i)
|
| 266 |
+
case _:
|
| 267 |
+
raise NotImplementedError(f"Unknown valu op {slot}")
|
| 268 |
+
|
| 269 |
+
def load(self, core, *slot):
|
| 270 |
+
match slot:
|
| 271 |
+
case ("load", dest, addr):
|
| 272 |
+
# print(dest, addr, core.scratch[addr])
|
| 273 |
+
self.scratch_write[dest] = self.mem[core.scratch[addr]]
|
| 274 |
+
case ("load_offset", dest, addr, offset):
|
| 275 |
+
# Handy for treating vector dest and addr as a full block in the mini-compiler if you want
|
| 276 |
+
self.scratch_write[dest + offset] = self.mem[
|
| 277 |
+
core.scratch[addr + offset]
|
| 278 |
+
]
|
| 279 |
+
case ("vload", dest, addr): # addr is a scalar
|
| 280 |
+
addr = core.scratch[addr]
|
| 281 |
+
for vi in range(VLEN):
|
| 282 |
+
self.scratch_write[dest + vi] = self.mem[addr + vi]
|
| 283 |
+
case ("const", dest, val):
|
| 284 |
+
self.scratch_write[dest] = (val) % (2**32)
|
| 285 |
+
case _:
|
| 286 |
+
raise NotImplementedError(f"Unknown load op {slot}")
|
| 287 |
+
|
| 288 |
+
def store(self, core, *slot):
|
| 289 |
+
match slot:
|
| 290 |
+
case ("store", addr, src):
|
| 291 |
+
addr = core.scratch[addr]
|
| 292 |
+
self.mem_write[addr] = core.scratch[src]
|
| 293 |
+
case ("vstore", addr, src): # addr is a scalar
|
| 294 |
+
addr = core.scratch[addr]
|
| 295 |
+
for vi in range(VLEN):
|
| 296 |
+
self.mem_write[addr + vi] = core.scratch[src + vi]
|
| 297 |
+
case _:
|
| 298 |
+
raise NotImplementedError(f"Unknown store op {slot}")
|
| 299 |
+
|
| 300 |
+
def flow(self, core, *slot):
|
| 301 |
+
match slot:
|
| 302 |
+
case ("select", dest, cond, a, b):
|
| 303 |
+
self.scratch_write[dest] = (
|
| 304 |
+
core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
|
| 305 |
+
)
|
| 306 |
+
case ("add_imm", dest, a, imm):
|
| 307 |
+
self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
|
| 308 |
+
case ("vselect", dest, cond, a, b):
|
| 309 |
+
for vi in range(VLEN):
|
| 310 |
+
self.scratch_write[dest + vi] = (
|
| 311 |
+
core.scratch[a + vi]
|
| 312 |
+
if core.scratch[cond + vi] != 0
|
| 313 |
+
else core.scratch[b + vi]
|
| 314 |
+
)
|
| 315 |
+
case ("halt",):
|
| 316 |
+
core.state = CoreState.STOPPED
|
| 317 |
+
case ("pause",):
|
| 318 |
+
if self.enable_pause:
|
| 319 |
+
core.state = CoreState.PAUSED
|
| 320 |
+
case ("trace_write", val):
|
| 321 |
+
core.trace_buf.append(core.scratch[val])
|
| 322 |
+
case ("cond_jump", cond, addr):
|
| 323 |
+
if core.scratch[cond] != 0:
|
| 324 |
+
core.pc = addr
|
| 325 |
+
case ("cond_jump_rel", cond, offset):
|
| 326 |
+
if core.scratch[cond] != 0:
|
| 327 |
+
core.pc += offset
|
| 328 |
+
case ("jump", addr):
|
| 329 |
+
core.pc = addr
|
| 330 |
+
case ("jump_indirect", addr):
|
| 331 |
+
core.pc = core.scratch[addr]
|
| 332 |
+
case ("coreid", dest):
|
| 333 |
+
self.scratch_write[dest] = core.id
|
| 334 |
+
case _:
|
| 335 |
+
raise NotImplementedError(f"Unknown flow op {slot}")
|
| 336 |
+
|
| 337 |
+
def trace_post_step(self, instr, core):
|
| 338 |
+
# You can add extra stuff to the trace if you want!
|
| 339 |
+
for addr, (name, length) in self.debug_info.scratch_map.items():
|
| 340 |
+
if any((addr + vi) in self.scratch_write for vi in range(length)):
|
| 341 |
+
val = str(core.scratch[addr : addr + length])
|
| 342 |
+
val = val.replace("[", "").replace("]", "")
|
| 343 |
+
self.trace.write(
|
| 344 |
+
f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def trace_slot(self, core, slot, name, i):
|
| 348 |
+
self.trace.write(
|
| 349 |
+
f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def step(self, instr: Instruction, core):
|
| 353 |
+
"""
|
| 354 |
+
Execute all the slots in each engine for a single instruction bundle
|
| 355 |
+
"""
|
| 356 |
+
ENGINE_FNS = {
|
| 357 |
+
"alu": self.alu,
|
| 358 |
+
"valu": self.valu,
|
| 359 |
+
"load": self.load,
|
| 360 |
+
"store": self.store,
|
| 361 |
+
"flow": self.flow,
|
| 362 |
+
}
|
| 363 |
+
self.scratch_write = {}
|
| 364 |
+
self.mem_write = {}
|
| 365 |
+
for name, slots in instr.items():
|
| 366 |
+
if name == "debug":
|
| 367 |
+
if not self.enable_debug:
|
| 368 |
+
continue
|
| 369 |
+
for slot in slots:
|
| 370 |
+
if slot[0] == "compare":
|
| 371 |
+
loc, key = slot[1], slot[2]
|
| 372 |
+
ref = self.value_trace[key]
|
| 373 |
+
res = core.scratch[loc]
|
| 374 |
+
assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
|
| 375 |
+
elif slot[0] == "vcompare":
|
| 376 |
+
loc, keys = slot[1], slot[2]
|
| 377 |
+
ref = [self.value_trace[key] for key in keys]
|
| 378 |
+
res = core.scratch[loc : loc + VLEN]
|
| 379 |
+
assert res == ref, (
|
| 380 |
+
f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
|
| 381 |
+
)
|
| 382 |
+
continue
|
| 383 |
+
assert len(slots) <= SLOT_LIMITS[name]
|
| 384 |
+
for i, slot in enumerate(slots):
|
| 385 |
+
if self.trace is not None:
|
| 386 |
+
self.trace_slot(core, slot, name, i)
|
| 387 |
+
ENGINE_FNS[name](core, *slot)
|
| 388 |
+
for addr, val in self.scratch_write.items():
|
| 389 |
+
core.scratch[addr] = val
|
| 390 |
+
for addr, val in self.mem_write.items():
|
| 391 |
+
self.mem[addr] = val
|
| 392 |
+
|
| 393 |
+
if self.trace:
|
| 394 |
+
self.trace_post_step(instr, core)
|
| 395 |
+
|
| 396 |
+
del self.scratch_write
|
| 397 |
+
del self.mem_write
|
| 398 |
+
|
| 399 |
+
def __del__(self):
|
| 400 |
+
if self.trace is not None:
|
| 401 |
+
self.trace.write("]")
|
| 402 |
+
self.trace.close()
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@dataclass
|
| 406 |
+
class Tree:
|
| 407 |
+
"""
|
| 408 |
+
An implicit perfect balanced binary tree with values on the nodes.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
height: int
|
| 412 |
+
values: list[int]
|
| 413 |
+
|
| 414 |
+
@staticmethod
|
| 415 |
+
def generate(height: int):
|
| 416 |
+
n_nodes = 2 ** (height + 1) - 1
|
| 417 |
+
values = [random.randint(0, 2**30 - 1) for _ in range(n_nodes)]
|
| 418 |
+
return Tree(height, values)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@dataclass
|
| 422 |
+
class Input:
|
| 423 |
+
"""
|
| 424 |
+
A batch of inputs, indices to nodes (starting as 0) and initial input
|
| 425 |
+
values. We then iterate these for a specified number of rounds.
|
| 426 |
+
"""
|
| 427 |
+
|
| 428 |
+
indices: list[int]
|
| 429 |
+
values: list[int]
|
| 430 |
+
rounds: int
|
| 431 |
+
|
| 432 |
+
@staticmethod
|
| 433 |
+
def generate(forest: Tree, batch_size: int, rounds: int):
|
| 434 |
+
indices = [0 for _ in range(batch_size)]
|
| 435 |
+
values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
|
| 436 |
+
return Input(indices, values, rounds)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
HASH_STAGES = [
|
| 440 |
+
("+", 0x7ED55D16, "+", "<<", 12),
|
| 441 |
+
("^", 0xC761C23C, "^", ">>", 19),
|
| 442 |
+
("+", 0x165667B1, "+", "<<", 5),
|
| 443 |
+
("+", 0xD3A2646C, "^", "<<", 9),
|
| 444 |
+
("+", 0xFD7046C5, "+", "<<", 3),
|
| 445 |
+
("^", 0xB55A4F09, "^", ">>", 16),
|
| 446 |
+
]
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def myhash(a: int) -> int:
|
| 450 |
+
"""A simple 32-bit hash function"""
|
| 451 |
+
fns = {
|
| 452 |
+
"+": lambda x, y: x + y,
|
| 453 |
+
"^": lambda x, y: x ^ y,
|
| 454 |
+
"<<": lambda x, y: x << y,
|
| 455 |
+
">>": lambda x, y: x >> y,
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
def r(x):
|
| 459 |
+
return x % (2**32)
|
| 460 |
+
|
| 461 |
+
for op1, val1, op2, op3, val3 in HASH_STAGES:
|
| 462 |
+
a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
|
| 463 |
+
|
| 464 |
+
return a
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def reference_kernel(t: Tree, inp: Input):
|
| 468 |
+
"""
|
| 469 |
+
Reference implementation of the kernel.
|
| 470 |
+
|
| 471 |
+
A parallel tree traversal where at each node we set
|
| 472 |
+
cur_inp_val = myhash(cur_inp_val ^ node_val)
|
| 473 |
+
and then choose the left branch if cur_inp_val is even.
|
| 474 |
+
If we reach the bottom of the tree we wrap around to the top.
|
| 475 |
+
"""
|
| 476 |
+
for h in range(inp.rounds):
|
| 477 |
+
for i in range(len(inp.indices)):
|
| 478 |
+
idx = inp.indices[i]
|
| 479 |
+
val = inp.values[i]
|
| 480 |
+
val = myhash(val ^ t.values[idx])
|
| 481 |
+
idx = 2 * idx + (1 if val % 2 == 0 else 2)
|
| 482 |
+
idx = 0 if idx >= len(t.values) else idx
|
| 483 |
+
inp.values[i] = val
|
| 484 |
+
inp.indices[i] = idx
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def build_mem_image(t: Tree, inp: Input) -> list[int]:
|
| 488 |
+
"""
|
| 489 |
+
Build a flat memory image of the problem.
|
| 490 |
+
"""
|
| 491 |
+
header = 7
|
| 492 |
+
extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
|
| 493 |
+
mem = [0] * (
|
| 494 |
+
header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
|
| 495 |
+
)
|
| 496 |
+
forest_values_p = header
|
| 497 |
+
inp_indices_p = forest_values_p + len(t.values)
|
| 498 |
+
inp_values_p = inp_indices_p + len(inp.values)
|
| 499 |
+
extra_room = inp_values_p + len(inp.values)
|
| 500 |
+
|
| 501 |
+
mem[0] = inp.rounds
|
| 502 |
+
mem[1] = len(t.values)
|
| 503 |
+
mem[2] = len(inp.indices)
|
| 504 |
+
mem[3] = t.height
|
| 505 |
+
mem[4] = forest_values_p
|
| 506 |
+
mem[5] = inp_indices_p
|
| 507 |
+
mem[6] = inp_values_p
|
| 508 |
+
mem[7] = extra_room
|
| 509 |
+
|
| 510 |
+
mem[header:inp_indices_p] = t.values
|
| 511 |
+
mem[inp_indices_p:inp_values_p] = inp.indices
|
| 512 |
+
mem[inp_values_p:] = inp.values
|
| 513 |
+
return mem
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
|
| 517 |
+
"""A simple 32-bit hash function"""
|
| 518 |
+
fns = {
|
| 519 |
+
"+": lambda x, y: x + y,
|
| 520 |
+
"^": lambda x, y: x ^ y,
|
| 521 |
+
"<<": lambda x, y: x << y,
|
| 522 |
+
">>": lambda x, y: x >> y,
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
def r(x):
|
| 526 |
+
return x % (2**32)
|
| 527 |
+
|
| 528 |
+
for i, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
|
| 529 |
+
a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
|
| 530 |
+
trace[(round, batch_i, "hash_stage", i)] = a
|
| 531 |
+
|
| 532 |
+
return a
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
|
| 536 |
+
"""
|
| 537 |
+
Reference implementation of the kernel on a flat memory.
|
| 538 |
+
"""
|
| 539 |
+
# This is the initial memory layout
|
| 540 |
+
rounds = mem[0]
|
| 541 |
+
n_nodes = mem[1]
|
| 542 |
+
batch_size = mem[2]
|
| 543 |
+
forest_height = mem[3]
|
| 544 |
+
# Offsets into the memory which indices get added to
|
| 545 |
+
forest_values_p = mem[4]
|
| 546 |
+
inp_indices_p = mem[5]
|
| 547 |
+
inp_values_p = mem[6]
|
| 548 |
+
yield mem
|
| 549 |
+
for h in range(rounds):
|
| 550 |
+
for i in range(batch_size):
|
| 551 |
+
idx = mem[inp_indices_p + i]
|
| 552 |
+
trace[(h, i, "idx")] = idx
|
| 553 |
+
val = mem[inp_values_p + i]
|
| 554 |
+
trace[(h, i, "val")] = val
|
| 555 |
+
node_val = mem[forest_values_p + idx]
|
| 556 |
+
trace[(h, i, "node_val")] = node_val
|
| 557 |
+
val = myhash_traced(val ^ node_val, trace, h, i)
|
| 558 |
+
trace[(h, i, "hashed_val")] = val
|
| 559 |
+
idx = 2 * idx + (1 if val % 2 == 0 else 2)
|
| 560 |
+
trace[(h, i, "next_idx")] = idx
|
| 561 |
+
idx = 0 if idx >= n_nodes else idx
|
| 562 |
+
trace[(h, i, "wrapped_idx")] = idx
|
| 563 |
+
mem[inp_values_p + i] = val
|
| 564 |
+
mem[inp_indices_p + i] = idx
|
| 565 |
+
# You can add new yields or move this around for debugging
|
| 566 |
+
# as long as it's matched by pause instructions.
|
| 567 |
+
# The submission tests evaluate only on final memory.
|
| 568 |
+
yield mem
|
original_performance_takehome/tests/submission_tests.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, inspect
|
| 2 |
+
|
| 3 |
+
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
| 4 |
+
parentdir = os.path.dirname(currentdir)
|
| 5 |
+
sys.path.insert(0, parentdir)
|
| 6 |
+
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
import unittest
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from frozen_problem import (
|
| 12 |
+
Machine,
|
| 13 |
+
build_mem_image,
|
| 14 |
+
reference_kernel2,
|
| 15 |
+
Tree,
|
| 16 |
+
Input,
|
| 17 |
+
N_CORES,
|
| 18 |
+
VLEN,
|
| 19 |
+
)
|
| 20 |
+
from perf_takehome import KernelBuilder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@lru_cache(maxsize=None)
|
| 24 |
+
def kernel_builder(forest_height: int, n_nodes: int, batch_size: int, rounds: int):
|
| 25 |
+
kb = KernelBuilder()
|
| 26 |
+
kb.build_kernel(forest_height, n_nodes, batch_size, rounds)
|
| 27 |
+
return kb
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def do_kernel_test(forest_height: int, rounds: int, batch_size: int):
|
| 31 |
+
print(f"Testing {forest_height=}, {rounds=}, {batch_size=}")
|
| 32 |
+
# Note the random generator is not seeded here
|
| 33 |
+
forest = Tree.generate(forest_height)
|
| 34 |
+
inp = Input.generate(forest, batch_size, rounds)
|
| 35 |
+
mem = build_mem_image(forest, inp)
|
| 36 |
+
|
| 37 |
+
kb = kernel_builder(forest.height, len(forest.values), len(inp.indices), rounds)
|
| 38 |
+
# print(kb.instrs)
|
| 39 |
+
|
| 40 |
+
machine = Machine(mem, kb.instrs, kb.debug_info(), n_cores=N_CORES)
|
| 41 |
+
machine.enable_pause = False
|
| 42 |
+
machine.enable_debug = False
|
| 43 |
+
machine.run()
|
| 44 |
+
|
| 45 |
+
for ref_mem in reference_kernel2(mem):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
inp_values_p = ref_mem[6]
|
| 49 |
+
assert (
|
| 50 |
+
machine.mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 51 |
+
== ref_mem[inp_values_p : inp_values_p + len(inp.values)]
|
| 52 |
+
), "Incorrect output values"
|
| 53 |
+
print("CYCLES: ", machine.cycle)
|
| 54 |
+
return machine.cycle
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CorrectnessTests(unittest.TestCase):
|
| 58 |
+
def test_kernel_correctness(self):
|
| 59 |
+
for i in range(8):
|
| 60 |
+
do_kernel_test(10, 16, 256)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
BASELINE = 147734
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@lru_cache(maxsize=None)
|
| 67 |
+
def cycles():
|
| 68 |
+
try:
|
| 69 |
+
res = do_kernel_test(10, 16, 256)
|
| 70 |
+
print("Speedup over baseline: ", BASELINE / res)
|
| 71 |
+
return res
|
| 72 |
+
except AssertionError as e:
|
| 73 |
+
return BASELINE * 2
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class SpeedTests(unittest.TestCase):
|
| 77 |
+
"""
|
| 78 |
+
You very much don't need to pass all of these to pass the interview.
|
| 79 |
+
The impressiveness also isn't linear in number of tests passed.
|
| 80 |
+
|
| 81 |
+
These are just so that test pass rate gets translated into a number
|
| 82 |
+
on the CodeSignal UI.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def test_kernel_speedup(self):
|
| 86 |
+
assert cycles() < BASELINE
|
| 87 |
+
|
| 88 |
+
def test_kernel_updated_starting_point(self):
|
| 89 |
+
# The updated version of this take-home given to candidates contained starter code that started them at this point
|
| 90 |
+
assert cycles() < 18532
|
| 91 |
+
|
| 92 |
+
def test_opus4_many_hours(self):
|
| 93 |
+
# Claude Opus 4 after many hours in the test-time compute harness
|
| 94 |
+
assert cycles() < 2164
|
| 95 |
+
|
| 96 |
+
def test_opus45_casual(self):
|
| 97 |
+
# Claude Opus 4.5 in a casual Claude Code session, approximately matching
|
| 98 |
+
# the best human performance in 2 hours
|
| 99 |
+
assert cycles() < 1790
|
| 100 |
+
|
| 101 |
+
def test_opus45_2hr(self):
|
| 102 |
+
# Claude Opus 4.5 after 2 hours in our test-time compute harness
|
| 103 |
+
assert cycles() < 1579
|
| 104 |
+
|
| 105 |
+
def test_sonnet45_many_hours(self):
|
| 106 |
+
# Claude Sonnet 4.5 after many more than 2 hours of test-time compute
|
| 107 |
+
assert cycles() < 1548
|
| 108 |
+
|
| 109 |
+
def test_opus45_11hr(self):
|
| 110 |
+
# Claude Opus 4.5 after 11.5 hours in the harness
|
| 111 |
+
assert cycles() < 1487
|
| 112 |
+
|
| 113 |
+
def test_opus45_improved_harness(self):
|
| 114 |
+
# Claude Opus 4.5 in an improved test time compute harness
|
| 115 |
+
assert cycles() < 1363
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
unittest.main()
|
original_performance_takehome/watch_trace.html
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en-us">
|
| 3 |
+
<link rel="shortcut icon" href="data:image/x-icon;," type="image/x-icon" />
|
| 4 |
+
|
| 5 |
+
<body>
|
| 6 |
+
<style>
|
| 7 |
+
pre {
|
| 8 |
+
border: 1px solid #eee;
|
| 9 |
+
margin: 10px 0;
|
| 10 |
+
font-family: monospace;
|
| 11 |
+
font-size: 10px;
|
| 12 |
+
min-height: 100px;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
body > * {
|
| 16 |
+
margin: 20px;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
#btn_fetch {
|
| 20 |
+
font-size: 14px;
|
| 21 |
+
}
|
| 22 |
+
</style>
|
| 23 |
+
|
| 24 |
+
<select id="source" size="4">
|
| 25 |
+
<option selected>/trace.json</option>
|
| 26 |
+
</select>
|
| 27 |
+
|
| 28 |
+
<br />
|
| 29 |
+
|
| 30 |
+
<button type="button" id="btn_fetch">Open Perfetto</button>
|
| 31 |
+
|
| 32 |
+
<br />
|
| 33 |
+
|
| 34 |
+
<pre id="logs" cols="80" rows="20"></pre>
|
| 35 |
+
|
| 36 |
+
<script type="text/javascript">
|
| 37 |
+
// const ORIGIN = 'http://localhost:8000/perfetto/';
|
| 38 |
+
const ORIGIN = "https://ui.perfetto.dev";
|
| 39 |
+
|
| 40 |
+
const logs = document.getElementById("logs");
|
| 41 |
+
const btnFetch = document.getElementById("btn_fetch");
|
| 42 |
+
|
| 43 |
+
async function getMtime() {
|
| 44 |
+
const mtime_resp = await fetch("/mtime");
|
| 45 |
+
const mtime = await mtime_resp.text();
|
| 46 |
+
return mtime;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
async function fetchAndOpen(traceUrl) {
|
| 50 |
+
logs.innerText += `Fetching trace from ${traceUrl}...\n`;
|
| 51 |
+
const mtime = await getMtime();
|
| 52 |
+
const resp = await fetch(traceUrl);
|
| 53 |
+
// Error checcking is left as an exercise to the reader.
|
| 54 |
+
const blob = await resp.blob();
|
| 55 |
+
const arrayBuffer = await blob.arrayBuffer();
|
| 56 |
+
logs.innerText += `fetch() complete, now passing to ui.perfetto.dev\n`;
|
| 57 |
+
openTrace(arrayBuffer, traceUrl, mtime);
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
async function repoll(win, traceUrl, mtime) {
|
| 61 |
+
const newMtime = await getMtime();
|
| 62 |
+
console.log(newMtime, mtime);
|
| 63 |
+
if (newMtime !== mtime) {
|
| 64 |
+
logs.innerText += `Trace updated, fetching new version...\n`;
|
| 65 |
+
const resp = await fetch(traceUrl);
|
| 66 |
+
const blob = await resp.blob();
|
| 67 |
+
const arrayBuffer = await blob.arrayBuffer();
|
| 68 |
+
logs.innerText += `New trace fetched, opening...\n`;
|
| 69 |
+
sendTrace(win, arrayBuffer, traceUrl);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
setTimeout(() => repoll(win, traceUrl, newMtime), 500);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
function sendTrace(win, arrayBuffer, traceUrl) {
|
| 76 |
+
const reopenUrl = new URL(location.href);
|
| 77 |
+
reopenUrl.hash = `#reopen=${traceUrl}`;
|
| 78 |
+
logs.innerText += `Sending trace to UI\n`;
|
| 79 |
+
win.postMessage(
|
| 80 |
+
{
|
| 81 |
+
perfetto: {
|
| 82 |
+
buffer: arrayBuffer,
|
| 83 |
+
title: "trace.json",
|
| 84 |
+
url: reopenUrl.toString(),
|
| 85 |
+
keepApiOpen: true,
|
| 86 |
+
},
|
| 87 |
+
},
|
| 88 |
+
ORIGIN,
|
| 89 |
+
);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
function openTrace(arrayBuffer, traceUrl, mtime) {
|
| 93 |
+
const win = window.open(ORIGIN);
|
| 94 |
+
if (!win) {
|
| 95 |
+
btnFetch.style.background = "#f3ca63";
|
| 96 |
+
btnFetch.onclick = () => openTrace(arrayBuffer);
|
| 97 |
+
logs.innerText += `Popups blocked, you need to manually click the button`;
|
| 98 |
+
btnFetch.innerText =
|
| 99 |
+
"Popups blocked, click here to open the trace file";
|
| 100 |
+
return;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
const timer = setInterval(
|
| 104 |
+
() => win.postMessage("PING", ORIGIN),
|
| 105 |
+
50,
|
| 106 |
+
);
|
| 107 |
+
|
| 108 |
+
const onMessageHandler = (evt) => {
|
| 109 |
+
if (evt.data !== "PONG") return;
|
| 110 |
+
|
| 111 |
+
// We got a PONG, the UI is ready.
|
| 112 |
+
window.clearInterval(timer);
|
| 113 |
+
window.removeEventListener("message", onMessageHandler);
|
| 114 |
+
|
| 115 |
+
sendTrace(win, arrayBuffer, traceUrl);
|
| 116 |
+
setTimeout(() => repoll(win, traceUrl, mtime), 500);
|
| 117 |
+
};
|
| 118 |
+
|
| 119 |
+
window.addEventListener("message", onMessageHandler);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// This is triggered when following the link from the Perfetto UI's sidebar.
|
| 123 |
+
if (location.hash.startsWith("#reopen=")) {
|
| 124 |
+
const traceUrl = location.hash.substr(8);
|
| 125 |
+
fetchAndOpen(traceUrl);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
btnFetch.onclick = () =>
|
| 129 |
+
fetchAndOpen(document.getElementById("source").value);
|
| 130 |
+
</script>
|
| 131 |
+
</body>
|
| 132 |
+
</html>
|
original_performance_takehome/watch_trace.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import http.server
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import webbrowser
|
| 5 |
+
import urllib.request
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Define a handler class
|
| 9 |
+
class MyHandler(http.server.BaseHTTPRequestHandler):
|
| 10 |
+
def do_GET(self):
|
| 11 |
+
try:
|
| 12 |
+
# Serve a string constant at the index
|
| 13 |
+
if self.path == "/":
|
| 14 |
+
self.send_response(200)
|
| 15 |
+
self.send_header("Content-type", "text/html")
|
| 16 |
+
self.end_headers()
|
| 17 |
+
with open("watch_trace.html", "rb") as file:
|
| 18 |
+
self.wfile.write(file.read())
|
| 19 |
+
|
| 20 |
+
# Stream the contents of 'trace.json' at '/trace.json'
|
| 21 |
+
elif self.path == "/trace.json":
|
| 22 |
+
self.send_response(200)
|
| 23 |
+
self.send_header("Content-type", "application/json")
|
| 24 |
+
self.end_headers()
|
| 25 |
+
with open("trace.json", "rb") as file:
|
| 26 |
+
while chunk := file.read(8192):
|
| 27 |
+
self.wfile.write(chunk)
|
| 28 |
+
|
| 29 |
+
# Serve the file modification time of 'trace.json' at '/mtime'
|
| 30 |
+
elif self.path == "/mtime":
|
| 31 |
+
mtime = os.path.getmtime("trace.json")
|
| 32 |
+
last_modified_date = datetime.fromtimestamp(mtime).strftime(
|
| 33 |
+
"%Y-%m-%d %H:%M:%S"
|
| 34 |
+
)
|
| 35 |
+
self.send_response(200)
|
| 36 |
+
self.send_header("Content-type", "text/plain")
|
| 37 |
+
self.end_headers()
|
| 38 |
+
self.wfile.write(last_modified_date.encode())
|
| 39 |
+
|
| 40 |
+
elif self.path.startswith("/perfetto"):
|
| 41 |
+
proxy_url = "https://ui.perfetto.dev" + self.path[len("/perfetto") :]
|
| 42 |
+
print("Proxying request to " + proxy_url)
|
| 43 |
+
with urllib.request.urlopen(proxy_url) as response:
|
| 44 |
+
self.send_response(response.status)
|
| 45 |
+
|
| 46 |
+
self.end_headers()
|
| 47 |
+
res = response.read()
|
| 48 |
+
if self.path.endswith("frontend_bundle.js"):
|
| 49 |
+
print("Activating replacement")
|
| 50 |
+
# Fix a bug in Perfetto that they haven't deployed the fix for yet but have fixed internally
|
| 51 |
+
res = res.replace(
|
| 52 |
+
b"throw new Error(`EngineProxy ${this.tag} was disposed.`);",
|
| 53 |
+
b"return null;",
|
| 54 |
+
)
|
| 55 |
+
# Auto-expand tracks by default
|
| 56 |
+
res = res.replace(b"collapsed: true", b"collapsed: false")
|
| 57 |
+
res = res.replace(
|
| 58 |
+
b"collapsed: !hasHeapProfiles", b"collapsed: false"
|
| 59 |
+
)
|
| 60 |
+
for header in response.headers:
|
| 61 |
+
if header == "Content-Length":
|
| 62 |
+
self.send_header(header, len(res))
|
| 63 |
+
self.send_header(header, response.headers[header])
|
| 64 |
+
self.wfile.write(res)
|
| 65 |
+
|
| 66 |
+
else:
|
| 67 |
+
self.send_error(404, "File Not Found: {}".format(self.path))
|
| 68 |
+
|
| 69 |
+
except IOError:
|
| 70 |
+
self.send_error(404, "File Not Found: {}".format(self.path))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Start the server
|
| 74 |
+
def run(server_class=http.server.HTTPServer, handler_class=MyHandler):
|
| 75 |
+
server_address = ("", 8000)
|
| 76 |
+
httpd = server_class(server_address, handler_class)
|
| 77 |
+
print("Starting httpd...")
|
| 78 |
+
webbrowser.open("http://localhost:8000")
|
| 79 |
+
httpd.serve_forever()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Run the server
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
run()
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
transformers>=4.40.0
|
| 3 |
+
datasets>=2.18.0
|
| 4 |
+
peft>=0.10.0
|
| 5 |
+
trl>=0.8.0
|
| 6 |
+
accelerate>=0.28.0
|
| 7 |
+
bitsandbytes>=0.43.0
|
| 8 |
+
gradio>=4.0.0
|