CreativeEngineer committed on
Commit 1d07708 · 1 Parent(s): 3aa84d6

Initial commit: VLIW kernel optimizer via RL

Files changed (40)
  1. README.md +26 -5
  2. __pycache__/app.cpython-314.pyc +0 -0
  3. app.py +549 -0
  4. original_performance_takehome/.git_backup/HEAD +1 -0
  5. original_performance_takehome/.git_backup/config +13 -0
  6. original_performance_takehome/.git_backup/description +1 -0
  7. original_performance_takehome/.git_backup/hooks/applypatch-msg.sample +15 -0
  8. original_performance_takehome/.git_backup/hooks/commit-msg.sample +24 -0
  9. original_performance_takehome/.git_backup/hooks/fsmonitor-watchman.sample +174 -0
  10. original_performance_takehome/.git_backup/hooks/post-update.sample +8 -0
  11. original_performance_takehome/.git_backup/hooks/pre-applypatch.sample +14 -0
  12. original_performance_takehome/.git_backup/hooks/pre-commit.sample +49 -0
  13. original_performance_takehome/.git_backup/hooks/pre-merge-commit.sample +13 -0
  14. original_performance_takehome/.git_backup/hooks/pre-push.sample +53 -0
  15. original_performance_takehome/.git_backup/hooks/pre-rebase.sample +169 -0
  16. original_performance_takehome/.git_backup/hooks/pre-receive.sample +24 -0
  17. original_performance_takehome/.git_backup/hooks/prepare-commit-msg.sample +42 -0
  18. original_performance_takehome/.git_backup/hooks/push-to-checkout.sample +78 -0
  19. original_performance_takehome/.git_backup/hooks/sendemail-validate.sample +77 -0
  20. original_performance_takehome/.git_backup/hooks/update.sample +128 -0
  21. original_performance_takehome/.git_backup/index +0 -0
  22. original_performance_takehome/.git_backup/info/exclude +6 -0
  23. original_performance_takehome/.git_backup/logs/HEAD +1 -0
  24. original_performance_takehome/.git_backup/logs/refs/heads/main +1 -0
  25. original_performance_takehome/.git_backup/logs/refs/remotes/origin/HEAD +1 -0
  26. original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.idx +0 -0
  27. original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.pack +0 -0
  28. original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.rev +0 -0
  29. original_performance_takehome/.git_backup/packed-refs +4 -0
  30. original_performance_takehome/.git_backup/refs/heads/main +1 -0
  31. original_performance_takehome/.git_backup/refs/remotes/origin/HEAD +1 -0
  32. original_performance_takehome/.gitignore +4 -0
  33. original_performance_takehome/Readme.md +39 -0
  34. original_performance_takehome/perf_takehome.py +275 -0
  35. original_performance_takehome/problem.py +568 -0
  36. original_performance_takehome/tests/frozen_problem.py +568 -0
  37. original_performance_takehome/tests/submission_tests.py +119 -0
  38. original_performance_takehome/watch_trace.html +132 -0
  39. original_performance_takehome/watch_trace.py +84 -0
  40. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,33 @@
  ---
- title: Vliw Optimizer
+ title: VLIW Kernel Optimizer
- emoji: 📈
+ emoji: "⚡"
  colorFrom: blue
- colorTo: indigo
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.4.0
+ sdk_version: 5.0.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # VLIW Kernel Optimization via Reinforcement Learning
+
+ Train a language model to generate optimized VLIW/SIMD kernels using test-time RL training.
+
+ ## Goal
+ - **Baseline:** 147,734 cycles
+ - **Target:** 1,363 cycles (108x speedup)
+
+ ## How it works
+ 1. Model generates kernel code
+ 2. Simulator evaluates cycle count
+ 3. RL training improves the model based on rewards
+
+ ## Usage
+ 1. Select a model (Qwen2.5-Coder-7B recommended)
+ 2. Configure training steps (50 recommended)
+ 3. Click "Start Training"
+ 4. Monitor progress - training continues even if you close the browser
+
+ ## Hardware
+ Requires A10G GPU (HF Spaces Pro)
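For readers skimming the README numbers above: the reward that drives training (defined in this commit's app.py) is essentially inverse cycle count plus a correctness bonus. The sketch below restates it standalone; `kernel_reward` is a hypothetical wrapper name for illustration, but the constants match app.py.

```python
# Sketch of the reward shaping used by this Space (constants from app.py).
BASELINE_CYCLES = 147_734   # unoptimized kernel
TARGET_CYCLES = 1_363       # goal (108x speedup)
SCORE_SCALE = 3000.0

def kernel_reward(cycles: int, correct: bool) -> float:
    """Reward grows as cycle count shrinks; correct kernels get a +1 bonus."""
    if not correct:
        return 0.0
    return SCORE_SCALE / cycles + 1.0

# A correct kernel at the baseline earns barely more than the bonus,
# while one at the target earns roughly SCORE_SCALE / TARGET_CYCLES ≈ 2.2 extra.
```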
__pycache__/app.cpython-314.pyc ADDED
Binary file (25.2 kB).
app.py ADDED
@@ -0,0 +1,549 @@
+ """
+ HF Spaces app for VLIW kernel optimization via RL.
+ Deploy to HF Spaces Pro (A10G GPU).
+
+ This is self-contained - includes verification logic inline.
+ """
+ import os
+ import sys
+ import re
+ import threading
+ import time
+ import random
+ from datetime import datetime
+
+ import gradio as gr
+
+ # Re-entrant lock for safe state access (log() may run while it is held)
+ training_state_lock = threading.RLock()
+
+ # Add simulator path
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ PERF_TAKEHOME_PATH = os.path.join(SCRIPT_DIR, "original_performance_takehome")
+ if os.path.exists(PERF_TAKEHOME_PATH):
+     sys.path.insert(0, PERF_TAKEHOME_PATH)
+
+ # Constants
+ BASELINE_CYCLES = 147734
+ TARGET_CYCLES = 1363
+ SCORE_SCALE = 3000.0
+
+ # Training state (global)
+ training_state = {
+     "running": False,
+     "step": 0,
+     "total_steps": 0,
+     "best_cycles": BASELINE_CYCLES,
+     "best_code": None,
+     "log": [],
+     "start_time": None,
+     "results": [],
+ }
+
+ SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
+
+ ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
+
+ API:
+ - alloc_scratch(name, length) -> addr
+ - scratch_const(val, name) -> addr
+ - add(engine, slot): engine in {alu, valu, load, store, flow}
+ - alu: (op, dst, src1, src2) where op in {+,-,*,/,%,^,&,|,==,!=,<,>,<=,>=}
+ - valu: same ops but on vectors (VLEN=8)
+ - load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
+ - store: (store,addr,src), (vstore,addr,src)
+ - flow: (select,dst,cond,t,f), (jump,label), (jump_if_zero,cond,label), (halt,)
+ - label(name): mark code position
+ - build(slots, vliw=True): pack slots into VLIW bundle
+
+ MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)
+
+ ALGORITHM: 16 rounds x 256 items: load idx,val; val=hash(val^tree[idx]); idx=2*idx+(1 or 2 based on val%2); store. Hash is 16 stages using HASH_STAGES constant.
+
+ OPTIMIZATION:
+ 1. Use vload/vstore: process 8 elements per instruction (256/8 = 32 vector iterations)
+ 2. Pack ops: 6 VALU slots = 6 vector ops per cycle
+ 3. Unroll: minimize loop overhead
+ 4. Pipeline: overlap loads with compute
+
+ You MUST override build_kernel() with actual instructions. Do NOT just call super().
+ '''
+
+
+ def extract_code_block(text: str) -> str:
+     """Extract python code from markdown code blocks."""
+     pattern = r"```python\s*(.*?)```"
+     matches = re.findall(pattern, text, re.DOTALL)
+     if matches:
+         return matches[-1].strip()
+     pattern = r"```\s*(.*?)```"
+     matches = re.findall(pattern, text, re.DOTALL)
+     if matches:
+         return matches[-1].strip()
+     return text.strip()
+
+
+ def verify_perf_takehome(generation: str, score_scale: float = SCORE_SCALE) -> dict:
+     """
+     Verify kernel code and return score.
+     Self-contained verification using the simulator.
+     """
+     try:
+         code = generation.strip()
+
+         if not code:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": "Empty code", "cycles": None}
+
+         if "def run" not in code:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": "No 'run' function defined", "cycles": None}
+
+         # Build execution environment
+         exec_globals = {
+             "FOREST_HEIGHT": 10,
+             "ROUNDS": 16,
+             "BATCH_SIZE": 256,
+         }
+
+         # Setup imports
+         setup_code = f'''
+ import sys
+ sys.path.insert(0, "{PERF_TAKEHOME_PATH}")
+ from problem import Machine, Tree, Input, build_mem_image, N_CORES, VLEN, reference_kernel2
+ from perf_takehome import KernelBuilder, HASH_STAGES, BASELINE
+ import random
+ '''
+         full_code = setup_code + "\n" + code
+         exec(full_code, exec_globals)
+
+         if "run" not in exec_globals:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": "No 'run' function after exec", "cycles": None}
+
+         # Require OptimizedKernelBuilder
+         if "OptimizedKernelBuilder" not in exec_globals:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": "No OptimizedKernelBuilder class", "cycles": None}
+
+         # Run verification
+         random.seed(123)
+         from problem import Tree, Input, Machine, build_mem_image, N_CORES, reference_kernel2
+
+         forest = Tree.generate(10)
+         inp = Input.generate(forest, 256, 16)
+         mem = build_mem_image(forest, inp)
+
+         # Get reference output
+         ref_mem = None
+         for ref_mem in reference_kernel2(list(mem)):
+             pass
+
+         if ref_mem is None:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": "Reference kernel failed", "cycles": None}
+
+         # Run submitted kernel
+         kb = exec_globals["OptimizedKernelBuilder"]()
+         kb.build_kernel(10, len(forest.values), 256, 16)
+         machine = Machine(list(mem), kb.instrs, kb.debug_info(), n_cores=N_CORES)
+         machine.enable_pause = False
+         machine.enable_debug = False
+         machine.run()
+
+         cycles = machine.cycle
+
+         # Validate cycles
+         if cycles <= 100:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": f"Suspiciously low cycles ({cycles})", "cycles": cycles}
+
+         if cycles > 200000:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": f"Cycles too high: {cycles}", "cycles": cycles}
+
+         # Compare outputs
+         inp_values_p = ref_mem[6]
+         expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
+         actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
+
+         if expected != actual:
+             return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                     "msg": f"Incorrect output (cycles={cycles})", "cycles": cycles}
+
+         # Success!
+         score = score_scale / cycles
+         return {
+             "score": score,
+             "correctness": 1.0,
+             "performance": -cycles,
+             "msg": f"Success: {cycles} cycles",
+             "cycles": cycles,
+         }
+
+     except Exception as e:
+         import traceback
+         tb = traceback.format_exc()
+         error_line = tb.strip().split('\n')[-1][:200]
+         return {"score": 0.0, "correctness": 0.0, "performance": -1000000,
+                 "msg": f"Error: {error_line}", "cycles": None}
+
+
+ def log(msg: str):
+     """Add to training log (thread-safe)."""
+     timestamp = datetime.now().strftime("%H:%M:%S")
+     formatted = f"[{timestamp}] {msg}"
+     with training_state_lock:
+         training_state["log"].append(formatted)
+     print(formatted)
+
+
+ def reward_function(completions: list[str], **kwargs) -> list[float]:
+     """Compute rewards for completions."""
+     rewards = []
+     for completion in completions:
+         try:
+             code = extract_code_block(completion)
+             result = verify_perf_takehome(code)
+             reward = result["score"]
+
+             if result["correctness"] > 0:
+                 reward += 1.0
+                 cycles = result.get("cycles")
+                 if cycles:
+                     with training_state_lock:
+                         training_state["results"].append({
+                             "step": training_state["step"],
+                             "cycles": cycles,
+                             "time": time.time() - (training_state["start_time"] or time.time())
+                         })
+                         if cycles < training_state["best_cycles"]:
+                             training_state["best_cycles"] = cycles
+                             training_state["best_code"] = code
+                             speedup = BASELINE_CYCLES / cycles
+                             log(f"NEW BEST: {cycles:,} cycles ({speedup:.2f}x speedup)")
+
+             rewards.append(reward)
+
+         except Exception as e:
+             log(f"Reward error: {str(e)[:100]}")
+             rewards.append(0.0)
+
+     return rewards
+
+
+ def build_prompt(current_cycles: int = BASELINE_CYCLES, last_code: str = "") -> str:
+     """Build training prompt."""
+     prompt = f"""{SYSTEM_PROMPT}
+
+ CURRENT: {current_cycles:,} cycles. TARGET: <{TARGET_CYCLES:,} cycles (need {current_cycles//TARGET_CYCLES}x speedup).
+ """
+     if last_code:
+         prompt += f"""
+ Previous best attempt:
+ ```python
+ {last_code[:2000]}
+ ```
+
+ Improve this code to reduce cycles further.
+ """
+     else:
+         prompt += """
+ Write a complete solution with:
+ 1. A run() function that returns (cycles, code_string)
+ 2. An OptimizedKernelBuilder class with build_kernel() method
+ """
+     return prompt
+
+
+ def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lora_rank: int):
+     """Main training loop."""
+     global training_state
+
+     with training_state_lock:
+         training_state["running"] = True
+         training_state["step"] = 0
+         training_state["total_steps"] = num_steps
+         training_state["best_cycles"] = BASELINE_CYCLES
+         training_state["best_code"] = None
+         training_state["log"] = []
+         training_state["results"] = []
+         training_state["start_time"] = time.time()
+
+     log(f"Starting training: {model_name}")
+     log(f"Steps: {num_steps}, Batch: {batch_size}, LR: {lr}, LoRA rank: {lora_rank}")
+
+     try:
+         import torch
+         from datasets import Dataset
+         from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
+         from peft import LoraConfig
+         from trl import GRPOConfig, GRPOTrainer
+
+         # Check GPU
+         if torch.cuda.is_available():
+             gpu_name = torch.cuda.get_device_name(0)
+             gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
+             log(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")
+         else:
+             log("WARNING: No GPU detected!")
+
+         log("Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Create dataset
+         prompt = build_prompt(BASELINE_CYCLES, "")
+         dataset = Dataset.from_dict({"prompt": [prompt] * 64})
+
+         # LoRA config
+         peft_config = LoraConfig(
+             r=lora_rank,
+             lora_alpha=lora_rank * 2,
+             lora_dropout=0.05,
+             bias="none",
+             task_type="CAUSAL_LM",
+             target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                             "gate_proj", "up_proj", "down_proj"],
+         )
+
+         # Training config
+         output_dir = f"./output/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
+         os.makedirs(output_dir, exist_ok=True)
+
+         training_args = GRPOConfig(
+             output_dir=output_dir,
+             max_steps=num_steps,  # cap optimizer steps; the UI slider is "Training Steps"
+             per_device_train_batch_size=batch_size,
+             gradient_accumulation_steps=4,
+             learning_rate=lr,
+             logging_steps=1,
+             save_steps=10,
+             max_completion_length=2048,
+             max_prompt_length=2048,
+             temperature=0.7,
+             num_generations=4,
+             beta=0.1,
+             bf16=True,
+             report_to="none",
+         )
+
+         # Quantization for 7B model on A10G
+         quant_config = None
+         if "7B" in model_name or "7b" in model_name:
+             log("Using 4-bit quantization for 7B model")
+             quant_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4",
+             )
+
+         log("Loading model (this may take a few minutes)...")
+
+         model_kwargs = {}
+         if quant_config:
+             model_kwargs["quantization_config"] = quant_config
+
+         # Create stop callback
+         class StopCallback(TrainerCallback):
+             def on_step_end(self, args, state, control, **kwargs):
+                 if not training_state["running"]:
+                     log("Stop signal received, halting training...")
+                     control.should_training_stop = True
+                 return control
+
+         trainer = GRPOTrainer(
+             model=model_name,
+             reward_funcs=[reward_function],
+             args=training_args,
+             train_dataset=dataset,
+             peft_config=peft_config,
+             processing_class=tokenizer,
+             model_init_kwargs=model_kwargs,
+             callbacks=[StopCallback()],
+         )
+
+         log("Model loaded! Starting training...")
+
+         # Train
+         trainer.train()
+
+         log("Training complete!")
+
+         # Save
+         trainer.save_model(os.path.join(output_dir, "final"))
+         log(f"Model saved to {output_dir}/final")
+
+         # Save best code
+         if training_state["best_code"]:
+             with open(os.path.join(output_dir, "best_code.py"), "w") as f:
+                 f.write(training_state["best_code"])
+             log("Best code saved!")
+
+     except Exception as e:
+         import traceback
+         log(f"ERROR: {str(e)}")
+         log(traceback.format_exc())
+
+     finally:
+         with training_state_lock:
+             training_state["running"] = False
+             elapsed = time.time() - training_state["start_time"]
+             best = training_state["best_cycles"]
+         log(f"Total time: {elapsed/60:.1f} minutes")
+         log(f"Best result: {best:,} cycles")
+
+
+ def start_training(model_name, num_steps, batch_size, lr, lora_rank):
+     """Start training in background."""
+     if training_state["running"]:
+         return "Training already running!"
+
+     thread = threading.Thread(
+         target=run_training,
+         args=(model_name, int(num_steps), int(batch_size), float(lr), int(lora_rank)),
+         daemon=False  # Non-daemon to ensure training completes
+     )
+     thread.start()
+     return "Training started! Monitor progress below."
+
+
+ def stop_training():
+     """Signal training to stop."""
+     with training_state_lock:
+         training_state["running"] = False
+     return "Stop signal sent. Training will stop after current step."
+
+
+ def get_status():
+     """Get current status as markdown."""
+     if not training_state["start_time"]:
+         return "### Status: Not started\n\nConfigure settings and click Start Training."
+
+     with training_state_lock:
+         elapsed = time.time() - training_state["start_time"]
+         elapsed_str = f"{elapsed/60:.1f} min"
+         best_cycles = max(training_state["best_cycles"], 1)  # Prevent division by zero
+         is_running = training_state["running"]
+         log_lines = training_state["log"][-15:]
+
+     speedup = BASELINE_CYCLES / best_cycles
+     progress_pct = (1 - best_cycles / BASELINE_CYCLES) * 100
+
+     status = f"""### Status: {'Running' if is_running else 'Stopped'}
+
+ | Metric | Value |
+ |--------|-------|
+ | Elapsed | {elapsed_str} |
+ | Best Cycles | **{best_cycles:,}** |
+ | Speedup | **{speedup:.2f}x** |
+ | Progress to Target | {progress_pct:.1f}% |
+ | Target | {TARGET_CYCLES:,} cycles |
+
+ ---
+
+ ### Recent Log
+ ```
+ {chr(10).join(log_lines)}
+ ```
+ """
+     return status
+
+
+ def get_best_code():
+     """Get best code found."""
+     with training_state_lock:
+         best_code = training_state["best_code"]
+     if best_code:
+         return best_code
+     return "# No valid code found yet.\n# Start training to generate optimized kernels."
+
+
+ def get_results_chart():
+     """Get results as simple text chart."""
+     with training_state_lock:
+         results = list(training_state["results"][-20:])
+
+     if not results:
+         return "No results yet."
+
+     lines = ["Cycles over time:", ""]
+     for r in results:
+         bar_len = max(1, int(50 * r["cycles"] / BASELINE_CYCLES))
+         bar = "#" * bar_len
+         lines.append(f"{r['cycles']:>7,} | {bar}")
+
+     return "\n".join(lines)
+
+
+ # Build Gradio UI
+ with gr.Blocks(title="VLIW Kernel Optimizer", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # VLIW Kernel Optimization via Reinforcement Learning
+
+     Train a language model to generate optimized VLIW/SIMD kernels.
+
+     | Baseline | Target | Goal |
+     |----------|--------|------|
+     | 147,734 cycles | 1,363 cycles | 108x speedup |
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### Configuration")
+
+             model_dropdown = gr.Dropdown(
+                 choices=[
+                     "Qwen/Qwen2.5-Coder-7B-Instruct",
+                     "Qwen/Qwen2.5-Coder-3B-Instruct",
+                     "Qwen/Qwen2.5-Coder-1.5B-Instruct",
+                     "deepseek-ai/deepseek-coder-6.7b-instruct",
+                     "codellama/CodeLlama-7b-Instruct-hf",
+                 ],
+                 value="Qwen/Qwen2.5-Coder-7B-Instruct",
+                 label="Model"
+             )
+
+             steps_slider = gr.Slider(1, 100, value=50, step=1, label="Training Steps")
+             batch_slider = gr.Slider(1, 8, value=4, step=1, label="Batch Size")
+             lr_input = gr.Number(value=2e-4, label="Learning Rate")
+             lora_slider = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
+
+             with gr.Row():
+                 start_btn = gr.Button("Start Training", variant="primary", size="lg")
+                 stop_btn = gr.Button("Stop", variant="stop")
+
+         with gr.Column(scale=2):
+             status_md = gr.Markdown("### Status: Not started")
+             refresh_btn = gr.Button("Refresh", size="sm")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### Best Code Found")
+             code_output = gr.Code(language="python", lines=25)
+             code_btn = gr.Button("Show Best Code")
+
+         with gr.Column():
+             gr.Markdown("### Results")
+             results_output = gr.Textbox(lines=15, label="Cycles Progress")
+             results_btn = gr.Button("Show Results")
+
+     # Event handlers
+     start_btn.click(
+         start_training,
+         inputs=[model_dropdown, steps_slider, batch_slider, lr_input, lora_slider],
+         outputs=[status_md]
+     )
+     stop_btn.click(stop_training, outputs=[status_md])
+     refresh_btn.click(get_status, outputs=[status_md])
+     code_btn.click(get_best_code, outputs=[code_output])
+     results_btn.click(get_results_chart, outputs=[results_output])
+
+     # Auto-refresh
+     demo.load(get_status, outputs=[status_md], every=5)


+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
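The code-extraction helper committed above is pure standard library and worth sanity-checking on its own. The sketch below reproduces its logic verbatim, with a small usage check appended:

```python
import re

def extract_code_block(text: str) -> str:
    """Return the last ```python fenced block; fall back to any fence, then raw text."""
    matches = re.findall(r"```python\s*(.*?)```", text, re.DOTALL)
    if matches:
        return matches[-1].strip()
    matches = re.findall(r"```\s*(.*?)```", text, re.DOTALL)
    if matches:
        return matches[-1].strip()
    return text.strip()

# Usage: pull the kernel code out of a model reply.
reply = "Here is the kernel:\n```python\nprint('hi')\n```\nDone."
assert extract_code_block(reply) == "print('hi')"
```

Taking the *last* match matters here: models often echo the prompt's example fence before emitting their real answer, and the final block is the one to score.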
original_performance_takehome/.git_backup/HEAD ADDED
@@ -0,0 +1 @@
+ ref: refs/heads/main
original_performance_takehome/.git_backup/config ADDED
@@ -0,0 +1,13 @@
+ [core]
+ 	repositoryformatversion = 0
+ 	filemode = true
+ 	bare = false
+ 	logallrefupdates = true
+ 	ignorecase = true
+ 	precomposeunicode = true
+ [remote "origin"]
+ 	url = https://github.com/anthropics/original_performance_takehome.git
+ 	fetch = +refs/heads/*:refs/remotes/origin/*
+ [branch "main"]
+ 	remote = origin
+ 	merge = refs/heads/main
original_performance_takehome/.git_backup/description ADDED
@@ -0,0 +1 @@
+ Unnamed repository; edit this file 'description' to name the repository.
original_performance_takehome/.git_backup/hooks/applypatch-msg.sample ADDED
@@ -0,0 +1,15 @@
+ #!/bin/sh
+ #
+ # An example hook script to check the commit log message taken by
+ # applypatch from an e-mail message.
+ #
+ # The hook should exit with non-zero status after issuing an
+ # appropriate message if it wants to stop the commit. The hook is
+ # allowed to edit the commit message file.
+ #
+ # To enable this hook, rename this file to "applypatch-msg".
+
+ . git-sh-setup
+ commitmsg="$(git rev-parse --git-path hooks/commit-msg)"
+ test -x "$commitmsg" && exec "$commitmsg" ${1+"$@"}
+ :
original_performance_takehome/.git_backup/hooks/commit-msg.sample ADDED
@@ -0,0 +1,24 @@
+ #!/bin/sh
+ #
+ # An example hook script to check the commit log message.
+ # Called by "git commit" with one argument, the name of the file
+ # that has the commit message. The hook should exit with non-zero
+ # status after issuing an appropriate message if it wants to stop the
+ # commit. The hook is allowed to edit the commit message file.
+ #
+ # To enable this hook, rename this file to "commit-msg".
+
+ # Uncomment the below to add a Signed-off-by line to the message.
+ # Doing this in a hook is a bad idea in general, but the prepare-commit-msg
+ # hook is more suited to it.
+ #
+ # SOB=$(git var GIT_AUTHOR_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p')
+ # grep -qs "^$SOB" "$1" || echo "$SOB" >> "$1"
+
+ # This example catches duplicate Signed-off-by lines.
+
+ test "" = "$(grep '^Signed-off-by: ' "$1" |
+ 	sort | uniq -c | sed -e '/^[ ]*1[ ]/d')" || {
+ 	echo >&2 Duplicate Signed-off-by lines.
+ 	exit 1
+ }
original_performance_takehome/.git_backup/hooks/fsmonitor-watchman.sample ADDED
@@ -0,0 +1,174 @@
+ #!/usr/bin/perl
+
+ use strict;
+ use warnings;
+ use IPC::Open2;
+
+ # An example hook script to integrate Watchman
+ # (https://facebook.github.io/watchman/) with git to speed up detecting
+ # new and modified files.
+ #
+ # The hook is passed a version (currently 2) and last update token
+ # formatted as a string and outputs to stdout a new update token and
+ # all files that have been modified since the update token. Paths must
+ # be relative to the root of the working tree and separated by a single NUL.
+ #
+ # To enable this hook, rename this file to "query-watchman" and set
+ # 'git config core.fsmonitor .git/hooks/query-watchman'
+ #
+ my ($version, $last_update_token) = @ARGV;
+
+ # Uncomment for debugging
+ # print STDERR "$0 $version $last_update_token\n";
+
+ # Check the hook interface version
+ if ($version ne 2) {
+ 	die "Unsupported query-fsmonitor hook version '$version'.\n" .
+ 	    "Falling back to scanning...\n";
+ }
+
+ my $git_work_tree = get_working_dir();
+
+ my $retry = 1;
+
+ my $json_pkg;
+ eval {
+ 	require JSON::XS;
+ 	$json_pkg = "JSON::XS";
+ 	1;
+ } or do {
+ 	require JSON::PP;
+ 	$json_pkg = "JSON::PP";
+ };
+
+ launch_watchman();
+
+ sub launch_watchman {
+ 	my $o = watchman_query();
+ 	if (is_work_tree_watched($o)) {
+ 		output_result($o->{clock}, @{$o->{files}});
+ 	}
+ }
+
+ sub output_result {
+ 	my ($clockid, @files) = @_;
+
+ 	# Uncomment for debugging watchman output
+ 	# open (my $fh, ">", ".git/watchman-output.out");
+ 	# binmode $fh, ":utf8";
+ 	# print $fh "$clockid\n@files\n";
+ 	# close $fh;
+
+ 	binmode STDOUT, ":utf8";
+ 	print $clockid;
+ 	print "\0";
+ 	local $, = "\0";
+ 	print @files;
+ }
+
+ sub watchman_clock {
+ 	my $response = qx/watchman clock "$git_work_tree"/;
+ 	die "Failed to get clock id on '$git_work_tree'.\n" .
+ 		"Falling back to scanning...\n" if $? != 0;
+
+ 	return $json_pkg->new->utf8->decode($response);
+ }
+
+ sub watchman_query {
+ 	my $pid = open2(\*CHLD_OUT, \*CHLD_IN, 'watchman -j --no-pretty')
+ 	or die "open2() failed: $!\n" .
+ 	"Falling back to scanning...\n";
+
+ 	# In the query expression below we're asking for names of files that
+ 	# changed since $last_update_token but not from the .git folder.
+ 	#
+ 	# To accomplish this, we're using the "since" generator to use the
+ 	# recency index to select candidate nodes and "fields" to limit the
+ 	# output to file names only. Then we're using the "expression" term to
+ 	# further constrain the results.
+ 	my $last_update_line = "";
+ 	if (substr($last_update_token, 0, 1) eq "c") {
+ 		$last_update_token = "\"$last_update_token\"";
+ 		$last_update_line = qq[\n"since": $last_update_token,];
+ 	}
+ 	my $query = <<"	END";
+ 		["query", "$git_work_tree", {$last_update_line
+ 			"fields": ["name"],
+ 			"expression": ["not", ["dirname", ".git"]]
+ 		}]
+ 	END
+
+ 	# Uncomment for debugging the watchman query
+ 	# open (my $fh, ">", ".git/watchman-query.json");
+ 	# print $fh $query;
+ 	# close $fh;
+
+ 	print CHLD_IN $query;
+ 	close CHLD_IN;
+ 	my $response = do {local $/; <CHLD_OUT>};
+
+ 	# Uncomment for debugging the watch response
+ 	# open ($fh, ">", ".git/watchman-response.json");
+ 	# print $fh $response;
+ 	# close $fh;
+
+ 	die "Watchman: command returned no output.\n" .
+ 	"Falling back to scanning...\n" if $response eq "";
+ 	die "Watchman: command returned invalid output: $response\n" .
+ 	"Falling back to scanning...\n" unless $response =~ /^\{/;
+
+ 	return $json_pkg->new->utf8->decode($response);
+ }
+
+ sub is_work_tree_watched {
+ 	my ($output) = @_;
+ 	my $error = $output->{error};
+ 	if ($retry > 0 and $error and $error =~ m/unable to resolve root .* directory (.*) is not watched/) {
+ 		$retry--;
+ 		my $response = qx/watchman watch "$git_work_tree"/;
+ 		die "Failed to make watchman watch '$git_work_tree'.\n" .
+ 		    "Falling back to scanning...\n" if $? != 0;
+ 		$output = $json_pkg->new->utf8->decode($response);
+ 		$error = $output->{error};
+ 		die "Watchman: $error.\n" .
+ 		"Falling back to scanning...\n" if $error;
+
+ 		# Uncomment for debugging watchman output
+ 		# open (my $fh, ">", ".git/watchman-output.out");
+ 		# close $fh;
+
+ 		# Watchman will always return all files on the first query so
+ 		# return the fast "everything is dirty" flag to git and do the
+ 		# Watchman query just to get it over with now so we won't pay
+ 		# the cost in git to look up each individual file.
+ 		my $o = watchman_clock();
+ 		$error = $output->{error};
+
+ 		die "Watchman: $error.\n" .
+ 		"Falling back to scanning...\n" if $error;
+
+ 		output_result($o->{clock}, ("/"));
+ 		$last_update_token = $o->{clock};
+
+ 		eval { launch_watchman() };
+ 		return 0;
+ 	}
+
+ 	die "Watchman: $error.\n" .
+ 	"Falling back to scanning...\n" if $error;
+
+ 	return 1;
+ }
+
+ sub get_working_dir {
+ 	my $working_dir;
+ 	if ($^O =~ 'msys' || $^O =~ 'cygwin') {
+ 		$working_dir = Win32::GetCwd();
+ 		$working_dir =~ tr/\\/\//;
+ 	} else {
+ 		require Cwd;
+ 		$working_dir = Cwd::cwd();
+ 	}
+
+ 	return $working_dir;
+ }
original_performance_takehome/.git_backup/hooks/post-update.sample ADDED
@@ -0,0 +1,8 @@
+ #!/bin/sh
+ #
+ # An example hook script to prepare a packed repository for use over
+ # dumb transports.
+ #
+ # To enable this hook, rename this file to "post-update".
+
+ exec git update-server-info
original_performance_takehome/.git_backup/hooks/pre-applypatch.sample ADDED
@@ -0,0 +1,14 @@
+ #!/bin/sh
+ #
+ # An example hook script to verify what is about to be committed
+ # by applypatch from an e-mail message.
+ #
+ # The hook should exit with non-zero status after issuing an
+ # appropriate message if it wants to stop the commit.
+ #
+ # To enable this hook, rename this file to "pre-applypatch".
+
+ . git-sh-setup
+ precommit="$(git rev-parse --git-path hooks/pre-commit)"
+ test -x "$precommit" && exec "$precommit" ${1+"$@"}
+ :
original_performance_takehome/.git_backup/hooks/pre-commit.sample ADDED
@@ -0,0 +1,49 @@
+ #!/bin/sh
+ #
+ # An example hook script to verify what is about to be committed.
+ # Called by "git commit" with no arguments. The hook should
+ # exit with non-zero status after issuing an appropriate message if
+ # it wants to stop the commit.
+ #
+ # To enable this hook, rename this file to "pre-commit".
+
+ if git rev-parse --verify HEAD >/dev/null 2>&1
+ then
+ against=HEAD
+ else
+ # Initial commit: diff against an empty tree object
+ against=$(git hash-object -t tree /dev/null)
+ fi
+
+ # If you want to allow non-ASCII filenames set this variable to true.
+ allownonascii=$(git config --type=bool hooks.allownonascii)
+
+ # Redirect output to stderr.
+ exec 1>&2
+
+ # Cross platform projects tend to avoid non-ASCII filenames; prevent
+ # them from being added to the repository. We exploit the fact that the
+ # printable range starts at the space character and ends with tilde.
+ if [ "$allownonascii" != "true" ] &&
+ # Note that the use of brackets around a tr range is ok here, (it's
+ # even required, for portability to Solaris 10's /usr/bin/tr), since
+ # the square bracket bytes happen to fall in the designated range.
+ test $(git diff-index --cached --name-only --diff-filter=A -z $against |
+ LC_ALL=C tr -d '[ -~]\0' | wc -c) != 0
+ then
+ cat <<\EOF
+ Error: Attempt to add a non-ASCII file name.
+
+ This can cause problems if you want to work with people on other platforms.
+
+ To be portable it is advisable to rename the file.
+
+ If you know what you are doing you can disable this check using:
+
+ git config hooks.allownonascii true
+ EOF
+ exit 1
+ fi
+
+ # If there are whitespace errors, print the offending file names and fail.
+ exec git diff-index --check --cached $against --
original_performance_takehome/.git_backup/hooks/pre-merge-commit.sample ADDED
@@ -0,0 +1,13 @@
+ #!/bin/sh
+ #
+ # An example hook script to verify what is about to be committed.
+ # Called by "git merge" with no arguments. The hook should
+ # exit with non-zero status after issuing an appropriate message to
+ # stderr if it wants to stop the merge commit.
+ #
+ # To enable this hook, rename this file to "pre-merge-commit".
+
+ . git-sh-setup
+ test -x "$GIT_DIR/hooks/pre-commit" &&
+ exec "$GIT_DIR/hooks/pre-commit"
+ :
original_performance_takehome/.git_backup/hooks/pre-push.sample ADDED
@@ -0,0 +1,53 @@
+ #!/bin/sh
+
+ # An example hook script to verify what is about to be pushed. Called by "git
+ # push" after it has checked the remote status, but before anything has been
+ # pushed. If this script exits with a non-zero status nothing will be pushed.
+ #
+ # This hook is called with the following parameters:
+ #
+ # $1 -- Name of the remote to which the push is being done
+ # $2 -- URL to which the push is being done
+ #
+ # If pushing without using a named remote those arguments will be equal.
+ #
+ # Information about the commits which are being pushed is supplied as lines to
+ # the standard input in the form:
+ #
+ # <local ref> <local oid> <remote ref> <remote oid>
+ #
+ # This sample shows how to prevent push of commits where the log message starts
+ # with "WIP" (work in progress).
+
+ remote="$1"
+ url="$2"
+
+ zero=$(git hash-object --stdin </dev/null | tr '[0-9a-f]' '0')
+
+ while read local_ref local_oid remote_ref remote_oid
+ do
+ if test "$local_oid" = "$zero"
+ then
+ # Handle delete
+ :
+ else
+ if test "$remote_oid" = "$zero"
+ then
+ # New branch, examine all commits
+ range="$local_oid"
+ else
+ # Update to existing branch, examine new commits
+ range="$remote_oid..$local_oid"
+ fi
+
+ # Check for WIP commit
+ commit=$(git rev-list -n 1 --grep '^WIP' "$range")
+ if test -n "$commit"
+ then
+ echo >&2 "Found WIP commit in $local_ref, not pushing"
+ exit 1
+ fi
+ fi
+ done
+
+ exit 0
original_performance_takehome/.git_backup/hooks/pre-rebase.sample ADDED
@@ -0,0 +1,169 @@
+ #!/bin/sh
+ #
+ # Copyright (c) 2006, 2008 Junio C Hamano
+ #
+ # The "pre-rebase" hook is run just before "git rebase" starts doing
+ # its job, and can prevent the command from running by exiting with
+ # non-zero status.
+ #
+ # The hook is called with the following parameters:
+ #
+ # $1 -- the upstream the series was forked from.
+ # $2 -- the branch being rebased (or empty when rebasing the current branch).
+ #
+ # This sample shows how to prevent topic branches that are already
+ # merged to 'next' branch from getting rebased, because allowing it
+ # would result in rebasing already published history.
+
+ publish=next
+ basebranch="$1"
+ if test "$#" = 2
+ then
+ topic="refs/heads/$2"
+ else
+ topic=`git symbolic-ref HEAD` ||
+ exit 0 ;# we do not interrupt rebasing detached HEAD
+ fi
+
+ case "$topic" in
+ refs/heads/??/*)
+ ;;
+ *)
+ exit 0 ;# we do not interrupt others.
+ ;;
+ esac
+
+ # Now we are dealing with a topic branch being rebased
+ # on top of master. Is it OK to rebase it?
+
+ # Does the topic really exist?
+ git show-ref -q "$topic" || {
+ echo >&2 "No such branch $topic"
+ exit 1
+ }
+
+ # Is topic fully merged to master?
+ not_in_master=`git rev-list --pretty=oneline ^master "$topic"`
+ if test -z "$not_in_master"
+ then
+ echo >&2 "$topic is fully merged to master; better remove it."
+ exit 1 ;# we could allow it, but there is no point.
+ fi
+
+ # Is topic ever merged to next? If so you should not be rebasing it.
+ only_next_1=`git rev-list ^master "^$topic" ${publish} | sort`
+ only_next_2=`git rev-list ^master ${publish} | sort`
+ if test "$only_next_1" = "$only_next_2"
+ then
+ not_in_topic=`git rev-list "^$topic" master`
+ if test -z "$not_in_topic"
+ then
+ echo >&2 "$topic is already up to date with master"
+ exit 1 ;# we could allow it, but there is no point.
+ else
+ exit 0
+ fi
+ else
+ not_in_next=`git rev-list --pretty=oneline ^${publish} "$topic"`
+ /usr/bin/perl -e '
+ my $topic = $ARGV[0];
+ my $msg = "* $topic has commits already merged to public branch:\n";
+ my (%not_in_next) = map {
+ /^([0-9a-f]+) /;
+ ($1 => 1);
+ } split(/\n/, $ARGV[1]);
+ for my $elem (map {
+ /^([0-9a-f]+) (.*)$/;
+ [$1 => $2];
+ } split(/\n/, $ARGV[2])) {
+ if (!exists $not_in_next{$elem->[0]}) {
+ if ($msg) {
+ print STDERR $msg;
+ undef $msg;
+ }
+ print STDERR " $elem->[1]\n";
+ }
+ }
+ ' "$topic" "$not_in_next" "$not_in_master"
+ exit 1
+ fi
+
+ <<\DOC_END
+
+ This sample hook safeguards topic branches that have been
+ published from being rewound.
+
+ The workflow assumed here is:
+
+ * Once a topic branch forks from "master", "master" is never
+ merged into it again (either directly or indirectly).
+
+ * Once a topic branch is fully cooked and merged into "master",
+ it is deleted. If you need to build on top of it to correct
+ earlier mistakes, a new topic branch is created by forking at
+ the tip of the "master". This is not strictly necessary, but
+ it makes it easier to keep your history simple.
+
+ * Whenever you need to test or publish your changes to topic
+ branches, merge them into "next" branch.
+
+ The script, being an example, hardcodes the publish branch name
+ to be "next", but it is trivial to make it configurable via
+ $GIT_DIR/config mechanism.
+
+ With this workflow, you would want to know:
+
+ (1) ... if a topic branch has ever been merged to "next". Young
+ topic branches can have stupid mistakes you would rather
+ clean up before publishing, and things that have not been
+ merged into other branches can be easily rebased without
+ affecting other people. But once it is published, you would
+ not want to rewind it.
+
+ (2) ... if a topic branch has been fully merged to "master".
+ Then you can delete it. More importantly, you should not
+ build on top of it -- other people may already want to
+ change things related to the topic as patches against your
+ "master", so if you need further changes, it is better to
+ fork the topic (perhaps with the same name) afresh from the
+ tip of "master".
+
+ Let's look at this example:
+
+ o---o---o---o---o---o---o---o---o---o "next"
+ / / / /
+ / a---a---b A / /
+ / / / /
+ / / c---c---c---c B /
+ / / / \ /
+ / / / b---b C \ /
+ / / / / \ /
+ ---o---o---o---o---o---o---o---o---o---o---o "master"
+
+
+ A, B and C are topic branches.
+
+ * A has one fix since it was merged up to "next".
+
+ * B has finished. It has been fully merged up to "master" and "next",
+ and is ready to be deleted.
+
+ * C has not merged to "next" at all.
+
+ We would want to allow C to be rebased, refuse A, and encourage
+ B to be deleted.
+
+ To compute (1):
+
+ git rev-list ^master ^topic next
+ git rev-list ^master next
+
+ if these match, topic has not merged in next at all.
+
+ To compute (2):
+
+ git rev-list master..topic
+
+ if this is empty, it is fully merged to "master".
+
+ DOC_END
original_performance_takehome/.git_backup/hooks/pre-receive.sample ADDED
@@ -0,0 +1,24 @@
+ #!/bin/sh
+ #
+ # An example hook script to make use of push options.
+ # The example simply echoes all push options that start with 'echoback='
+ # and rejects all pushes when the "reject" push option is used.
+ #
+ # To enable this hook, rename this file to "pre-receive".
+
+ if test -n "$GIT_PUSH_OPTION_COUNT"
+ then
+ i=0
+ while test "$i" -lt "$GIT_PUSH_OPTION_COUNT"
+ do
+ eval "value=\$GIT_PUSH_OPTION_$i"
+ case "$value" in
+ echoback=*)
+ echo "echo from the pre-receive-hook: ${value#*=}" >&2
+ ;;
+ reject)
+ exit 1
+ esac
+ i=$((i + 1))
+ done
+ fi
original_performance_takehome/.git_backup/hooks/prepare-commit-msg.sample ADDED
@@ -0,0 +1,42 @@
+ #!/bin/sh
+ #
+ # An example hook script to prepare the commit log message.
+ # Called by "git commit" with the name of the file that has the
+ # commit message, followed by the description of the commit
+ # message's source. The hook's purpose is to edit the commit
+ # message file. If the hook fails with a non-zero status,
+ # the commit is aborted.
+ #
+ # To enable this hook, rename this file to "prepare-commit-msg".
+
+ # This hook includes three examples. The first one removes the
+ # "# Please enter the commit message..." help message.
+ #
+ # The second includes the output of "git diff --name-status -r"
+ # into the message, just before the "git status" output. It is
+ # commented because it doesn't cope with --amend or with squashed
+ # commits.
+ #
+ # The third example adds a Signed-off-by line to the message, that can
+ # still be edited. This is rarely a good idea.
+
+ COMMIT_MSG_FILE=$1
+ COMMIT_SOURCE=$2
+ SHA1=$3
+
+ /usr/bin/perl -i.bak -ne 'print unless(m/^. Please enter the commit message/..m/^#$/)' "$COMMIT_MSG_FILE"
+
+ # case "$COMMIT_SOURCE,$SHA1" in
+ # ,|template,)
+ # /usr/bin/perl -i.bak -pe '
+ # print "\n" . `git diff --cached --name-status -r`
+ # if /^#/ && $first++ == 0' "$COMMIT_MSG_FILE" ;;
+ # *) ;;
+ # esac
+
+ # SOB=$(git var GIT_COMMITTER_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p')
+ # git interpret-trailers --in-place --trailer "$SOB" "$COMMIT_MSG_FILE"
+ # if test -z "$COMMIT_SOURCE"
+ # then
+ # /usr/bin/perl -i.bak -pe 'print "\n" if !$first_line++' "$COMMIT_MSG_FILE"
+ # fi
original_performance_takehome/.git_backup/hooks/push-to-checkout.sample ADDED
@@ -0,0 +1,78 @@
+ #!/bin/sh
+
+ # An example hook script to update a checked-out tree on a git push.
+ #
+ # This hook is invoked by git-receive-pack(1) when it reacts to git
+ # push and updates reference(s) in its repository, and when the push
+ # tries to update the branch that is currently checked out and the
+ # receive.denyCurrentBranch configuration variable is set to
+ # updateInstead.
+ #
+ # By default, such a push is refused if the working tree and the index
+ # of the remote repository has any difference from the currently
+ # checked out commit; when both the working tree and the index match
+ # the current commit, they are updated to match the newly pushed tip
+ # of the branch. This hook is to be used to override the default
+ # behaviour; however the code below reimplements the default behaviour
+ # as a starting point for convenient modification.
+ #
+ # The hook receives the commit with which the tip of the current
+ # branch is going to be updated:
+ commit=$1
+
+ # It can exit with a non-zero status to refuse the push (when it does
+ # so, it must not modify the index or the working tree).
+ die () {
+ echo >&2 "$*"
+ exit 1
+ }
+
+ # Or it can make any necessary changes to the working tree and to the
+ # index to bring them to the desired state when the tip of the current
+ # branch is updated to the new commit, and exit with a zero status.
+ #
+ # For example, the hook can simply run git read-tree -u -m HEAD "$1"
+ # in order to emulate git fetch that is run in the reverse direction
+ # with git push, as the two-tree form of git read-tree -u -m is
+ # essentially the same as git switch or git checkout that switches
+ # branches while keeping the local changes in the working tree that do
+ # not interfere with the difference between the branches.
+
+ # The below is a more-or-less exact translation to shell of the C code
+ # for the default behaviour for git's push-to-checkout hook defined in
+ # the push_to_deploy() function in builtin/receive-pack.c.
+ #
+ # Note that the hook will be executed from the repository directory,
+ # not from the working tree, so if you want to perform operations on
+ # the working tree, you will have to adapt your code accordingly, e.g.
+ # by adding "cd .." or using relative paths.
+
+ if ! git update-index -q --ignore-submodules --refresh
+ then
+ die "Up-to-date check failed"
+ fi
+
+ if ! git diff-files --quiet --ignore-submodules --
+ then
+ die "Working directory has unstaged changes"
+ fi
+
+ # This is a rough translation of:
+ #
+ # head_has_history() ? "HEAD" : EMPTY_TREE_SHA1_HEX
+ if git cat-file -e HEAD 2>/dev/null
+ then
+ head=HEAD
+ else
+ head=$(git hash-object -t tree --stdin </dev/null)
+ fi
+
+ if ! git diff-index --quiet --cached --ignore-submodules $head --
+ then
+ die "Working directory has staged changes"
+ fi
+
+ if ! git read-tree -u -m "$commit"
+ then
+ die "Could not update working tree to new HEAD"
+ fi
original_performance_takehome/.git_backup/hooks/sendemail-validate.sample ADDED
@@ -0,0 +1,77 @@
+ #!/bin/sh
+
+ # An example hook script to validate a patch (and/or patch series) before
+ # sending it via email.
+ #
+ # The hook should exit with non-zero status after issuing an appropriate
+ # message if it wants to prevent the email(s) from being sent.
+ #
+ # To enable this hook, rename this file to "sendemail-validate".
+ #
+ # By default, it will only check that the patch(es) can be applied on top of
+ # the default upstream branch without conflicts in a secondary worktree. After
+ # validation (successful or not) of the last patch of a series, the worktree
+ # will be deleted.
+ #
+ # The following config variables can be set to change the default remote and
+ # remote ref that are used to apply the patches against:
+ #
+ # sendemail.validateRemote (default: origin)
+ # sendemail.validateRemoteRef (default: HEAD)
+ #
+ # Replace the TODO placeholders with appropriate checks according to your
+ # needs.
+
+ validate_cover_letter () {
+ file="$1"
+ # TODO: Replace with appropriate checks (e.g. spell checking).
+ true
+ }
+
+ validate_patch () {
+ file="$1"
+ # Ensure that the patch applies without conflicts.
+ git am -3 "$file" || return
+ # TODO: Replace with appropriate checks for this patch
+ # (e.g. checkpatch.pl).
+ true
+ }
+
+ validate_series () {
+ # TODO: Replace with appropriate checks for the whole series
+ # (e.g. quick build, coding style checks, etc.).
+ true
+ }
+
+ # main -------------------------------------------------------------------------
+
+ if test "$GIT_SENDEMAIL_FILE_COUNTER" = 1
+ then
+ remote=$(git config --default origin --get sendemail.validateRemote) &&
+ ref=$(git config --default HEAD --get sendemail.validateRemoteRef) &&
+ worktree=$(mktemp --tmpdir -d sendemail-validate.XXXXXXX) &&
+ git worktree add -fd --checkout "$worktree" "refs/remotes/$remote/$ref" &&
+ git config --replace-all sendemail.validateWorktree "$worktree"
+ else
+ worktree=$(git config --get sendemail.validateWorktree)
+ fi || {
+ echo "sendemail-validate: error: failed to prepare worktree" >&2
+ exit 1
+ }
+
+ unset GIT_DIR GIT_WORK_TREE
+ cd "$worktree" &&
+
+ if grep -q "^diff --git " "$1"
+ then
+ validate_patch "$1"
+ else
+ validate_cover_letter "$1"
+ fi &&
+
+ if test "$GIT_SENDEMAIL_FILE_COUNTER" = "$GIT_SENDEMAIL_FILE_TOTAL"
+ then
+ git config --unset-all sendemail.validateWorktree &&
+ trap 'git worktree remove -ff "$worktree"' EXIT &&
+ validate_series
+ fi
original_performance_takehome/.git_backup/hooks/update.sample ADDED
@@ -0,0 +1,128 @@
+ #!/bin/sh
+ #
+ # An example hook script to block unannotated tags from entering.
+ # Called by "git receive-pack" with arguments: refname sha1-old sha1-new
+ #
+ # To enable this hook, rename this file to "update".
+ #
+ # Config
+ # ------
+ # hooks.allowunannotated
+ # This boolean sets whether unannotated tags will be allowed into the
+ # repository. By default they won't be.
+ # hooks.allowdeletetag
+ # This boolean sets whether deleting tags will be allowed in the
+ # repository. By default they won't be.
+ # hooks.allowmodifytag
+ # This boolean sets whether a tag may be modified after creation. By default
+ # it won't be.
+ # hooks.allowdeletebranch
+ # This boolean sets whether deleting branches will be allowed in the
+ # repository. By default they won't be.
+ # hooks.denycreatebranch
+ # This boolean sets whether remotely creating branches will be denied
+ # in the repository. By default this is allowed.
+ #
+
+ # --- Command line
+ refname="$1"
+ oldrev="$2"
+ newrev="$3"
+
+ # --- Safety check
+ if [ -z "$GIT_DIR" ]; then
+ echo "Don't run this script from the command line." >&2
+ echo " (if you want, you could supply GIT_DIR then run" >&2
+ echo " $0 <ref> <oldrev> <newrev>)" >&2
+ exit 1
+ fi
+
+ if [ -z "$refname" -o -z "$oldrev" -o -z "$newrev" ]; then
+ echo "usage: $0 <ref> <oldrev> <newrev>" >&2
+ exit 1
+ fi
+
+ # --- Config
+ allowunannotated=$(git config --type=bool hooks.allowunannotated)
+ allowdeletebranch=$(git config --type=bool hooks.allowdeletebranch)
+ denycreatebranch=$(git config --type=bool hooks.denycreatebranch)
+ allowdeletetag=$(git config --type=bool hooks.allowdeletetag)
+ allowmodifytag=$(git config --type=bool hooks.allowmodifytag)
+
+ # check for no description
+ projectdesc=$(sed -e '1q' "$GIT_DIR/description")
+ case "$projectdesc" in
+ "Unnamed repository"* | "")
+ echo "*** Project description file hasn't been set" >&2
+ exit 1
+ ;;
+ esac
+
+ # --- Check types
+ # if $newrev is 0000...0000, it's a commit to delete a ref.
+ zero=$(git hash-object --stdin </dev/null | tr '[0-9a-f]' '0')
+ if [ "$newrev" = "$zero" ]; then
+ newrev_type=delete
+ else
+ newrev_type=$(git cat-file -t $newrev)
+ fi
+
+ case "$refname","$newrev_type" in
+ refs/tags/*,commit)
+ # un-annotated tag
+ short_refname=${refname##refs/tags/}
+ if [ "$allowunannotated" != "true" ]; then
+ echo "*** The un-annotated tag, $short_refname, is not allowed in this repository" >&2
+ echo "*** Use 'git tag [ -a | -s ]' for tags you want to propagate." >&2
+ exit 1
+ fi
+ ;;
+ refs/tags/*,delete)
+ # delete tag
+ if [ "$allowdeletetag" != "true" ]; then
+ echo "*** Deleting a tag is not allowed in this repository" >&2
+ exit 1
+ fi
+ ;;
+ refs/tags/*,tag)
+ # annotated tag
+ if [ "$allowmodifytag" != "true" ] && git rev-parse $refname > /dev/null 2>&1
+ then
+ echo "*** Tag '$refname' already exists." >&2
+ echo "*** Modifying a tag is not allowed in this repository." >&2
+ exit 1
+ fi
+ ;;
+ refs/heads/*,commit)
+ # branch
+ if [ "$oldrev" = "$zero" -a "$denycreatebranch" = "true" ]; then
+ echo "*** Creating a branch is not allowed in this repository" >&2
+ exit 1
+ fi
+ ;;
+ refs/heads/*,delete)
+ # delete branch
+ if [ "$allowdeletebranch" != "true" ]; then
+ echo "*** Deleting a branch is not allowed in this repository" >&2
+ exit 1
+ fi
+ ;;
+ refs/remotes/*,commit)
+ # tracking branch
+ ;;
+ refs/remotes/*,delete)
+ # delete tracking branch
+ if [ "$allowdeletebranch" != "true" ]; then
+ echo "*** Deleting a tracking branch is not allowed in this repository" >&2
+ exit 1
+ fi
+ ;;
+ *)
+ # Anything else (is there anything else?)
+ echo "*** Update hook: unknown type of update to ref $refname of type $newrev_type" >&2
+ exit 1
+ ;;
+ esac
+
+ # --- Finished
+ exit 0
original_performance_takehome/.git_backup/index ADDED
Binary file (743 Bytes).
original_performance_takehome/.git_backup/info/exclude ADDED
@@ -0,0 +1,6 @@
+ # git ls-files --others --exclude-from=.git/info/exclude
+ # Lines that start with '#' are comments.
+ # For a project mostly in C, the following would be a good set of
+ # exclude patterns (uncomment them if you want to use them):
+ # *.[oa]
+ # *~
original_performance_takehome/.git_backup/logs/HEAD ADDED
@@ -0,0 +1 @@
+ 0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
original_performance_takehome/.git_backup/logs/refs/heads/main ADDED
@@ -0,0 +1 @@
+ 0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
original_performance_takehome/.git_backup/logs/refs/remotes/origin/HEAD ADDED
@@ -0,0 +1 @@
+ 0000000000000000000000000000000000000000 5452f74bd977807ac2e74f3d29432b9df6f25197 Jung Dae Suh <jungdaesuh1221@gmail.com> 1769316765 +0900 clone: from https://github.com/anthropics/original_performance_takehome.git
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.idx ADDED
Binary file (1.8 kB).
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.pack ADDED
Binary file (20.2 kB).
original_performance_takehome/.git_backup/objects/pack/pack-813c2c470e2abf2cfcfb6aa8ba6478e559e69577.rev ADDED
Binary file (156 Bytes).
original_performance_takehome/.git_backup/packed-refs ADDED
@@ -0,0 +1,4 @@
+ # pack-refs with: peeled fully-peeled sorted
+ 5452f74bd977807ac2e74f3d29432b9df6f25197 refs/remotes/origin/main
+ d45812f96a6740086db7f2aa78925d9a0b7389dd refs/remotes/origin/tristan/add-warning
+ 3697cecc2a093b4df01de46e6a61b3b56d3ad6be refs/remotes/origin/tristan/update-readme
original_performance_takehome/.git_backup/refs/heads/main ADDED
@@ -0,0 +1 @@
+ 5452f74bd977807ac2e74f3d29432b9df6f25197
original_performance_takehome/.git_backup/refs/remotes/origin/HEAD ADDED
@@ -0,0 +1 @@
+ ref: refs/remotes/origin/main
original_performance_takehome/.gitignore ADDED
@@ -0,0 +1,4 @@
+ trace.json
+ **/*.pyc
+ .hypothesis
+ .DS_Store
original_performance_takehome/Readme.md ADDED
@@ -0,0 +1,39 @@
+ # Anthropic's Original Performance Take-Home
+
+ This repo contains a version of Anthropic's original performance take-home, from before Claude Opus 4.5 started doing better than humans given only 2 hours.
+
+ The original take-home was a 4-hour exercise that started close to the contents of this repo. After Claude Opus 4 beat most humans at it, it was updated to a 2-hour version whose starter code already achieved 18532 cycles (7.97x faster than where this repo starts you). This repo is based on the newer take-home, which has a few more instructions and better debugging tools, but with the starter code reverted to the slowest baseline. After Claude Opus 4.5, we moved to a different base for our time-limited take-homes.
+
+ Now you can try to beat Claude Opus 4.5 given unlimited time!
+
+ ## Performance benchmarks
+
+ Measured in clock cycles on the simulated machine. All of these numbers are for models doing the 2-hour version, which started at 18532 cycles:
+
+ - **2164 cycles**: Claude Opus 4 after many hours in the test-time compute harness
+ - **1790 cycles**: Claude Opus 4.5 in a casual Claude Code session, approximately matching the best human performance in 2 hours
+ - **1579 cycles**: Claude Opus 4.5 after 2 hours in our test-time compute harness
+ - **1548 cycles**: Claude Sonnet 4.5 after many more than 2 hours of test-time compute
+ - **1487 cycles**: Claude Opus 4.5 after 11.5 hours in the harness
+ - **1363 cycles**: Claude Opus 4.5 in an improved test-time compute harness
+ - **??? cycles**: The best human performance ever is substantially better than the above, but we won't say by how much.
+
+ While it's no longer a good time-limited test, you can still use this repo to get us excited about hiring you! If you optimize below 1487 cycles, beating Claude Opus 4.5's best performance at launch, email us at performance-recruiting@anthropic.com with your code (and ideally a resume) so we can be appropriately impressed, especially if you get near the best solution we've seen. New model releases may change what threshold impresses us, though, and we make no guarantees that we will keep this readme updated with the latest on that.
+
+ Run `python tests/submission_tests.py` to see which thresholds you pass.
+
+ ## Warning: LLMs can cheat
+
+ None of the sub-1300-cycle solutions we received on the first day post-release were valid. In each case, a language model had modified the tests to make the problem easier.
+
+ If you use an AI agent, we recommend instructing it not to change the `tests/` folder and to use `tests/submission_tests.py` for verification.
+
+ Please run the following commands to validate your submission, and mention that you did so when submitting:
+ ```
+ # This should be empty, the tests folder must be unchanged
+ git diff origin/main tests/
+ # You should pass some of these tests and use the cycle count this prints
+ python tests/submission_tests.py
+ ```
+
+ An example of this kind of hack: a model notices that `problem.py` has multicore support, implements multicore as an optimization, notices there's no speedup, "debugs" the fact that `N_CORES = 1`, and "fixes" the core count so it gets a speedup. Multicore is disabled intentionally in this version.
original_performance_takehome/perf_takehome.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ # Anthropic's Original Performance Engineering Take-home (Release version)
+
+ Copyright Anthropic PBC 2026. Permission is granted to modify and use, but not
+ to publish or redistribute your solutions so it's hard to find spoilers.
+
+ # Task
+
+ - Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
+   available time, as measured by test_kernel_cycles on a frozen separate copy
+   of the simulator.
+
+ Validate your results using `python tests/submission_tests.py` without modifying
+ anything in the tests/ folder.
+
+ We recommend you look through problem.py next.
+ """
+
+ from collections import defaultdict
+ import random
+ import unittest
+
+ from problem import (
+     Engine,
+     DebugInfo,
+     SLOT_LIMITS,
+     VLEN,
+     N_CORES,
+     SCRATCH_SIZE,
+     Machine,
+     Tree,
+     Input,
+     HASH_STAGES,
+     reference_kernel,
+     build_mem_image,
+     reference_kernel2,
+ )
+
+
+ class KernelBuilder:
+     def __init__(self):
+         self.instrs = []
+         self.scratch = {}
+         self.scratch_debug = {}
+         self.scratch_ptr = 0
+         self.const_map = {}
+
+     def debug_info(self):
+         return DebugInfo(scratch_map=self.scratch_debug)
+
+     def build(self, slots: list[tuple[Engine, tuple]], vliw: bool = False):
+         # Simple slot packing that just uses one slot per instruction bundle
+         instrs = []
+         for engine, slot in slots:
+             instrs.append({engine: [slot]})
+         return instrs
+
+     def add(self, engine, slot):
+         self.instrs.append({engine: [slot]})
+
+     def alloc_scratch(self, name=None, length=1):
+         addr = self.scratch_ptr
+         if name is not None:
+             self.scratch[name] = addr
+             self.scratch_debug[addr] = (name, length)
+         self.scratch_ptr += length
+         assert self.scratch_ptr <= SCRATCH_SIZE, "Out of scratch space"
+         return addr
+
+     def scratch_const(self, val, name=None):
+         if val not in self.const_map:
+             addr = self.alloc_scratch(name)
+             self.add("load", ("const", addr, val))
+             self.const_map[val] = addr
+         return self.const_map[val]
+
+     def build_hash(self, val_hash_addr, tmp1, tmp2, round, i):
+         slots = []
+
+         for hi, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
+             slots.append(("alu", (op1, tmp1, val_hash_addr, self.scratch_const(val1))))
+             slots.append(("alu", (op3, tmp2, val_hash_addr, self.scratch_const(val3))))
+             slots.append(("alu", (op2, val_hash_addr, tmp1, tmp2)))
+             slots.append(("debug", ("compare", val_hash_addr, (round, i, "hash_stage", hi))))
+
+         return slots
+
+     def build_kernel(
+         self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
+     ):
+         """
+         Like reference_kernel2 but building actual instructions.
+         Scalar implementation using only scalar ALU and load/store.
+         """
+         tmp1 = self.alloc_scratch("tmp1")
+         tmp2 = self.alloc_scratch("tmp2")
+         tmp3 = self.alloc_scratch("tmp3")
+         # Scratch space addresses
+         init_vars = [
+             "rounds",
+             "n_nodes",
+             "batch_size",
+             "forest_height",
+             "forest_values_p",
+             "inp_indices_p",
+             "inp_values_p",
+         ]
+         for v in init_vars:
+             self.alloc_scratch(v, 1)
+         for i, v in enumerate(init_vars):
+             self.add("load", ("const", tmp1, i))
+             self.add("load", ("load", self.scratch[v], tmp1))
+
+         zero_const = self.scratch_const(0)
+         one_const = self.scratch_const(1)
+         two_const = self.scratch_const(2)
+
+         # Pause instructions are matched up with yield statements in the reference
+         # kernel to let you debug at intermediate steps. The testing harness in this
+         # file requires these match up to the reference kernel's yields, but the
+         # submission harness ignores them.
+         self.add("flow", ("pause",))
+         # Any debug engine instruction is ignored by the submission simulator
+         self.add("debug", ("comment", "Starting loop"))
+
+         body = []  # array of slots
+
+         # Scalar scratch registers
+         tmp_idx = self.alloc_scratch("tmp_idx")
+         tmp_val = self.alloc_scratch("tmp_val")
+         tmp_node_val = self.alloc_scratch("tmp_node_val")
+         tmp_addr = self.alloc_scratch("tmp_addr")
+
+         for round in range(rounds):
+             for i in range(batch_size):
+                 i_const = self.scratch_const(i)
+                 # idx = mem[inp_indices_p + i]
+                 body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
+                 body.append(("load", ("load", tmp_idx, tmp_addr)))
+                 body.append(("debug", ("compare", tmp_idx, (round, i, "idx"))))
+                 # val = mem[inp_values_p + i]
+                 body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
+                 body.append(("load", ("load", tmp_val, tmp_addr)))
+                 body.append(("debug", ("compare", tmp_val, (round, i, "val"))))
+                 # node_val = mem[forest_values_p + idx]
+                 body.append(("alu", ("+", tmp_addr, self.scratch["forest_values_p"], tmp_idx)))
+                 body.append(("load", ("load", tmp_node_val, tmp_addr)))
+                 body.append(("debug", ("compare", tmp_node_val, (round, i, "node_val"))))
+                 # val = myhash(val ^ node_val)
+                 body.append(("alu", ("^", tmp_val, tmp_val, tmp_node_val)))
+                 body.extend(self.build_hash(tmp_val, tmp1, tmp2, round, i))
+                 body.append(("debug", ("compare", tmp_val, (round, i, "hashed_val"))))
+                 # idx = 2*idx + (1 if val % 2 == 0 else 2)
+                 body.append(("alu", ("%", tmp1, tmp_val, two_const)))
+                 body.append(("alu", ("==", tmp1, tmp1, zero_const)))
+                 body.append(("flow", ("select", tmp3, tmp1, one_const, two_const)))
+                 body.append(("alu", ("*", tmp_idx, tmp_idx, two_const)))
+                 body.append(("alu", ("+", tmp_idx, tmp_idx, tmp3)))
+                 body.append(("debug", ("compare", tmp_idx, (round, i, "next_idx"))))
+                 # idx = 0 if idx >= n_nodes else idx
+                 body.append(("alu", ("<", tmp1, tmp_idx, self.scratch["n_nodes"])))
+                 body.append(("flow", ("select", tmp_idx, tmp1, tmp_idx, zero_const)))
+                 body.append(("debug", ("compare", tmp_idx, (round, i, "wrapped_idx"))))
+                 # mem[inp_indices_p + i] = idx
+                 body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
+                 body.append(("store", ("store", tmp_addr, tmp_idx)))
+                 # mem[inp_values_p + i] = val
+                 body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
+                 body.append(("store", ("store", tmp_addr, tmp_val)))
+
+         body_instrs = self.build(body)
+         self.instrs.extend(body_instrs)
+         # Required to match with the yield in reference_kernel2
+         self.instrs.append({"flow": [("pause",)]})
+
+ BASELINE = 147734
+
+ def do_kernel_test(
+     forest_height: int,
+     rounds: int,
+     batch_size: int,
+     seed: int = 123,
+     trace: bool = False,
+     prints: bool = False,
+ ):
+     print(f"{forest_height=}, {rounds=}, {batch_size=}")
+     random.seed(seed)
+     forest = Tree.generate(forest_height)
+     inp = Input.generate(forest, batch_size, rounds)
+     mem = build_mem_image(forest, inp)
+
+     kb = KernelBuilder()
+     kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)
+     # print(kb.instrs)
+
+     value_trace = {}
+     machine = Machine(
+         mem,
+         kb.instrs,
+         kb.debug_info(),
+         n_cores=N_CORES,
+         value_trace=value_trace,
+         trace=trace,
+     )
+     machine.prints = prints
+     for i, ref_mem in enumerate(reference_kernel2(mem, value_trace)):
+         machine.run()
+         inp_values_p = ref_mem[6]
+         if prints:
+             print(machine.mem[inp_values_p : inp_values_p + len(inp.values)])
+             print(ref_mem[inp_values_p : inp_values_p + len(inp.values)])
+         assert (
+             machine.mem[inp_values_p : inp_values_p + len(inp.values)]
+             == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
+         ), f"Incorrect result on round {i}"
+         inp_indices_p = ref_mem[5]
+         if prints:
+             print(machine.mem[inp_indices_p : inp_indices_p + len(inp.indices)])
+             print(ref_mem[inp_indices_p : inp_indices_p + len(inp.indices)])
+         # Updating these in memory isn't required, but you can enable this check for debugging
+         # assert machine.mem[inp_indices_p:inp_indices_p+len(inp.indices)] == ref_mem[inp_indices_p:inp_indices_p+len(inp.indices)]
+
+     print("CYCLES: ", machine.cycle)
+     print("Speedup over baseline: ", BASELINE / machine.cycle)
+     return machine.cycle
+
+
+ class Tests(unittest.TestCase):
+     def test_ref_kernels(self):
+         """
+         Test the reference kernels against each other
+         """
+         random.seed(123)
+         for i in range(10):
+             f = Tree.generate(4)
+             inp = Input.generate(f, 10, 6)
+             mem = build_mem_image(f, inp)
+             reference_kernel(f, inp)
+             for _ in reference_kernel2(mem, {}):
+                 pass
+             assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)]
+             assert inp.values == mem[mem[6] : mem[6] + len(inp.values)]
+
+     def test_kernel_trace(self):
+         # Full-scale example for performance testing
+         do_kernel_test(10, 16, 256, trace=True, prints=False)
+
+     # Passing this test is not required for submission, see submission_tests.py for the actual correctness test
+     # You can uncomment this if you think it might help you debug
+     # def test_kernel_correctness(self):
+     #     for batch in range(1, 3):
+     #         for forest_height in range(3):
+     #             do_kernel_test(
+     #                 forest_height + 2, forest_height + 4, batch * 16 * VLEN * N_CORES
+     #             )
+
+     def test_kernel_cycles(self):
+         do_kernel_test(10, 16, 256)
+
+
+ # To run all the tests:
+ #   python perf_takehome.py
+ # To run a specific test:
+ #   python perf_takehome.py Tests.test_kernel_cycles
+ # To view a hot-reloading trace of all the instructions: **Recommended debug loop**
+ # NOTE: The trace hot-reloading only works in Chrome. In the worst case if things aren't working, drag trace.json onto https://ui.perfetto.dev/
+ #   python perf_takehome.py Tests.test_kernel_trace
+ # Then run `python watch_trace.py` in another tab, it'll open a browser tab, then click "Open Perfetto"
+ # You can then keep that open and re-run the test to see a new trace.
+
+ # To run the proper checks to see which thresholds you pass:
+ #   python tests/submission_tests.py
+
+ if __name__ == "__main__":
+     unittest.main()
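The `KernelBuilder.build` method above deliberately wastes the VLIW width by emitting one slot per bundle. The natural first optimization is to pack multiple slots into one bundle up to `SLOT_LIMITS`. The sketch below is a hypothetical greedy packer, not part of this repo: it ignores data dependencies entirely, so it is only valid for slot lists already known to be independent (a real scheduler must also check read-after-write hazards between slots in the same bundle).

```python
# SLOT_LIMITS copied from problem.py: max slots per engine per cycle.
SLOT_LIMITS = {"alu": 12, "valu": 6, "load": 2, "store": 2, "flow": 1, "debug": 64}


def pack_greedy(slots):
    """Greedily pack (engine, slot) pairs into VLIW bundles.

    Hypothetical simplification: data dependencies are ignored, so this is
    only safe for slots that are mutually independent. A bundle is closed as
    soon as any engine hits its limit, which can leave other engines
    under-filled; a smarter scheduler would keep filling them.
    """
    bundles = []
    current = {}  # engine name -> list of slots in the bundle being built
    for engine, slot in slots:
        if len(current.get(engine, [])) >= SLOT_LIMITS[engine]:
            bundles.append(current)  # engine is full: start a new bundle
            current = {}
        current.setdefault(engine, []).append(slot)
    if current:
        bundles.append(current)
    return bundles
```

With 12 ALU slots, 2 loads, and 2 stores per cycle, even this naive packing collapses many of the scalar kernel's one-slot bundles into shared cycles.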
original_performance_takehome/problem.py ADDED
@@ -0,0 +1,568 @@
+ """
+ Read the top of perf_takehome.py for more introduction.
+
+ This file is separate mostly for ease of copying it to freeze the machine and
+ reference kernel for testing.
+ """
+
+ from copy import copy
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any, Literal
+ import random
+
+ Engine = Literal["alu", "load", "store", "flow"]
+ Instruction = dict[Engine, list[tuple]]
+
+
+ class CoreState(Enum):
+     RUNNING = 1
+     PAUSED = 2
+     STOPPED = 3
+
+
+ @dataclass
+ class Core:
+     id: int
+     scratch: list[int]
+     trace_buf: list[int]
+     pc: int = 0
+     state: CoreState = CoreState.RUNNING
+
+
+ @dataclass
+ class DebugInfo:
+     """
+     We give you some debug info but it's up to you to use it in Machine if you
+     want to. You're also welcome to add more.
+     """
+
+     # Maps scratch variable addr to (name, len) pair
+     scratch_map: dict[int, (str, int)]
+
+
+ def cdiv(a, b):
+     return (a + b - 1) // b
+
+
+ SLOT_LIMITS = {
+     "alu": 12,
+     "valu": 6,
+     "load": 2,
+     "store": 2,
+     "flow": 1,
+     "debug": 64,
+ }
+
+ VLEN = 8
+ # Older versions of the take-home used multiple cores, but this version only uses 1
+ N_CORES = 1
+ SCRATCH_SIZE = 1536
+ BASE_ADDR_TID = 100000
+
+
+ class Machine:
+     """
+     Simulator for a custom VLIW SIMD architecture.
+
+     VLIW (Very Long Instruction Word): Cores are composed of different
+     "engines", each of which can execute multiple "slots" per cycle in parallel.
+     How many slots each engine can execute per cycle is limited by SLOT_LIMITS.
+     Effects of instructions don't take effect until the end of the cycle. Each
+     cycle, all engines execute all of their filled slots for that instruction.
+     Effects like writes to memory take place after all the inputs are read.
+
+     SIMD: There are instructions for acting on vectors of VLEN elements in a
+     single slot. You can use vload and vstore to load multiple contiguous
+     elements but not non-contiguous elements. Use vbroadcast to broadcast a
+     scalar to a vector and then operate on vectors with valu instructions.
+
+     The memory and scratch space are composed of 32-bit words. The solution is
+     plucked out of the memory at the end of the program. You can think of the
+     scratch space as serving the purpose of registers, constant memory, and a
+     manually-managed cache.
+
+     Here's an example of what an instruction might look like:
+
+     {"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}
+
+     In general every number in an instruction is a scratch address except for
+     const and jump, and except for store and some flow instructions the first
+     operand is the destination.
+
+     This comment is not meant to be full ISA documentation though, for the rest
+     you should look through the simulator code.
+     """
+
+     def __init__(
+         self,
+         mem_dump: list[int],
+         program: list[Instruction],
+         debug_info: DebugInfo,
+         n_cores: int = 1,
+         scratch_size: int = SCRATCH_SIZE,
+         trace: bool = False,
+         value_trace: dict[Any, int] = {},
+     ):
+         self.cores = [
+             Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
+         ]
+         self.mem = copy(mem_dump)
+         self.program = program
+         self.debug_info = debug_info
+         self.value_trace = value_trace
+         self.prints = False
+         self.cycle = 0
+         self.enable_pause = True
+         self.enable_debug = True
+         if trace:
+             self.setup_trace()
+         else:
+             self.trace = None
+
+     def rewrite_instr(self, instr):
+         """
+         Rewrite an instruction to use scratch names instead of addresses
+         """
+         res = {}
+         for name, slots in instr.items():
+             res[name] = []
+             for slot in slots:
+                 res[name].append(self.rewrite_slot(slot))
+         return res
+
+     def print_step(self, instr, core):
+         # print(core.id)
+         # print(core.trace_buf)
+         print(self.scratch_map(core))
+         print(core.pc, instr, self.rewrite_instr(instr))
+
+     def scratch_map(self, core):
+         res = {}
+         for addr, (name, length) in self.debug_info.scratch_map.items():
+             res[name] = core.scratch[addr : addr + length]
+         return res
+
+     def rewrite_slot(self, slot):
+         return tuple(
+             self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
+         )
+
+     def setup_trace(self):
+         """
+         The simulator generates traces in Chrome's Trace Event Format for
+         visualization in Perfetto (or chrome://tracing if you prefer it). See
+         the bottom of the file for info about how to use this.
+
+         See the format docs in case you want to add more info to the trace:
+         https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
+         """
+         self.trace = open("trace.json", "w")
+         self.trace.write("[")
+         tid_counter = 0
+         self.tids = {}
+         for ci, core in enumerate(self.cores):
+             self.trace.write(
+                 f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
+             )
+             for name, limit in SLOT_LIMITS.items():
+                 if name == "debug":
+                     continue
+                 for i in range(limit):
+                     tid_counter += 1
+                     self.trace.write(
+                         f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
+                     )
+                     self.tids[(ci, name, i)] = tid_counter
+
+         # Add zero-length events at the start so all slots show up in Perfetto
+         for ci, core in enumerate(self.cores):
+             for name, limit in SLOT_LIMITS.items():
+                 if name == "debug":
+                     continue
+                 for i in range(limit):
+                     tid = self.tids[(ci, name, i)]
+                     self.trace.write(
+                         f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
+                     )
+         for ci, core in enumerate(self.cores):
+             self.trace.write(
+                 f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
+             )
+             for addr, (name, length) in self.debug_info.scratch_map.items():
+                 self.trace.write(
+                     f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
+                 )
+
+     def run(self):
+         for core in self.cores:
+             if core.state == CoreState.PAUSED:
+                 core.state = CoreState.RUNNING
+         while any(c.state == CoreState.RUNNING for c in self.cores):
+             has_non_debug = False
+             for core in self.cores:
+                 if core.state != CoreState.RUNNING:
+                     continue
+                 if core.pc >= len(self.program):
+                     core.state = CoreState.STOPPED
+                     continue
+                 instr = self.program[core.pc]
+                 if self.prints:
+                     self.print_step(instr, core)
+                 core.pc += 1
+                 self.step(instr, core)
+                 if any(name != "debug" for name in instr.keys()):
+                     has_non_debug = True
+             if has_non_debug:
+                 self.cycle += 1
+
+     def alu(self, core, op, dest, a1, a2):
+         a1 = core.scratch[a1]
+         a2 = core.scratch[a2]
+         match op:
+             case "+":
+                 res = a1 + a2
+             case "-":
+                 res = a1 - a2
+             case "*":
+                 res = a1 * a2
+             case "//":
+                 res = a1 // a2
+             case "cdiv":
+                 res = cdiv(a1, a2)
+             case "^":
+                 res = a1 ^ a2
+             case "&":
+                 res = a1 & a2
+             case "|":
+                 res = a1 | a2
+             case "<<":
+                 res = a1 << a2
+             case ">>":
+                 res = a1 >> a2
+             case "%":
+                 res = a1 % a2
+             case "<":
+                 res = int(a1 < a2)
+             case "==":
+                 res = int(a1 == a2)
+             case _:
+                 raise NotImplementedError(f"Unknown alu op {op}")
+         res = res % (2**32)
+         self.scratch_write[dest] = res
+
+     def valu(self, core, *slot):
+         match slot:
+             case ("vbroadcast", dest, src):
+                 for i in range(VLEN):
+                     self.scratch_write[dest + i] = core.scratch[src]
+             case ("multiply_add", dest, a, b, c):
+                 for i in range(VLEN):
+                     mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
+                     self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
+             case (op, dest, a1, a2):
+                 for i in range(VLEN):
+                     self.alu(core, op, dest + i, a1 + i, a2 + i)
+             case _:
+                 raise NotImplementedError(f"Unknown valu op {slot}")
+
+     def load(self, core, *slot):
+         match slot:
+             case ("load", dest, addr):
+                 # print(dest, addr, core.scratch[addr])
+                 self.scratch_write[dest] = self.mem[core.scratch[addr]]
+             case ("load_offset", dest, addr, offset):
+                 # Handy for treating vector dest and addr as a full block in the mini-compiler if you want
+                 self.scratch_write[dest + offset] = self.mem[
+                     core.scratch[addr + offset]
+                 ]
+             case ("vload", dest, addr):  # addr is a scalar
+                 addr = core.scratch[addr]
+                 for vi in range(VLEN):
+                     self.scratch_write[dest + vi] = self.mem[addr + vi]
+             case ("const", dest, val):
+                 self.scratch_write[dest] = (val) % (2**32)
+             case _:
+                 raise NotImplementedError(f"Unknown load op {slot}")
+
+     def store(self, core, *slot):
+         match slot:
+             case ("store", addr, src):
+                 addr = core.scratch[addr]
+                 self.mem_write[addr] = core.scratch[src]
+             case ("vstore", addr, src):  # addr is a scalar
+                 addr = core.scratch[addr]
+                 for vi in range(VLEN):
+                     self.mem_write[addr + vi] = core.scratch[src + vi]
+             case _:
+                 raise NotImplementedError(f"Unknown store op {slot}")
+
+     def flow(self, core, *slot):
+         match slot:
+             case ("select", dest, cond, a, b):
+                 self.scratch_write[dest] = (
+                     core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
+                 )
+             case ("add_imm", dest, a, imm):
+                 self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
+             case ("vselect", dest, cond, a, b):
+                 for vi in range(VLEN):
+                     self.scratch_write[dest + vi] = (
+                         core.scratch[a + vi]
+                         if core.scratch[cond + vi] != 0
+                         else core.scratch[b + vi]
+                     )
+             case ("halt",):
+                 core.state = CoreState.STOPPED
+             case ("pause",):
+                 if self.enable_pause:
+                     core.state = CoreState.PAUSED
+             case ("trace_write", val):
+                 core.trace_buf.append(core.scratch[val])
+             case ("cond_jump", cond, addr):
+                 if core.scratch[cond] != 0:
+                     core.pc = addr
+             case ("cond_jump_rel", cond, offset):
+                 if core.scratch[cond] != 0:
+                     core.pc += offset
+             case ("jump", addr):
+                 core.pc = addr
+             case ("jump_indirect", addr):
+                 core.pc = core.scratch[addr]
+             case ("coreid", dest):
+                 self.scratch_write[dest] = core.id
+             case _:
+                 raise NotImplementedError(f"Unknown flow op {slot}")
+
+     def trace_post_step(self, instr, core):
+         # You can add extra stuff to the trace if you want!
+         for addr, (name, length) in self.debug_info.scratch_map.items():
+             if any((addr + vi) in self.scratch_write for vi in range(length)):
+                 val = str(core.scratch[addr : addr + length])
+                 val = val.replace("[", "").replace("]", "")
+                 self.trace.write(
+                     f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
+                 )
+
+     def trace_slot(self, core, slot, name, i):
+         self.trace.write(
+             f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
+         )
+
+     def step(self, instr: Instruction, core):
+         """
+         Execute all the slots in each engine for a single instruction bundle
+         """
+         ENGINE_FNS = {
+             "alu": self.alu,
+             "valu": self.valu,
+             "load": self.load,
+             "store": self.store,
+             "flow": self.flow,
+         }
+         self.scratch_write = {}
+         self.mem_write = {}
+         for name, slots in instr.items():
+             if name == "debug":
+                 if not self.enable_debug:
+                     continue
+                 for slot in slots:
+                     if slot[0] == "compare":
+                         loc, key = slot[1], slot[2]
+                         ref = self.value_trace[key]
+                         res = core.scratch[loc]
+                         assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
+                     elif slot[0] == "vcompare":
+                         loc, keys = slot[1], slot[2]
+                         ref = [self.value_trace[key] for key in keys]
+                         res = core.scratch[loc : loc + VLEN]
+                         assert res == ref, (
+                             f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
+                         )
+                 continue
+             assert len(slots) <= SLOT_LIMITS[name]
+             for i, slot in enumerate(slots):
+                 if self.trace is not None:
+                     self.trace_slot(core, slot, name, i)
+                 ENGINE_FNS[name](core, *slot)
+         for addr, val in self.scratch_write.items():
+             core.scratch[addr] = val
+         for addr, val in self.mem_write.items():
+             self.mem[addr] = val
+
+         if self.trace:
+             self.trace_post_step(instr, core)
+
+         del self.scratch_write
+         del self.mem_write
+
+     def __del__(self):
+         if self.trace is not None:
+             self.trace.write("]")
+             self.trace.close()
+
+
+ @dataclass
+ class Tree:
+     """
+     An implicit perfect balanced binary tree with values on the nodes.
+     """
+
+     height: int
+     values: list[int]
+
+     @staticmethod
+     def generate(height: int):
+         n_nodes = 2 ** (height + 1) - 1
+         values = [random.randint(0, 2**30 - 1) for _ in range(n_nodes)]
+         return Tree(height, values)
+
+
+ @dataclass
+ class Input:
+     """
+     A batch of inputs, indices to nodes (starting as 0) and initial input
+     values. We then iterate these for a specified number of rounds.
+     """
+
+     indices: list[int]
+     values: list[int]
+     rounds: int
+
+     @staticmethod
+     def generate(forest: Tree, batch_size: int, rounds: int):
+         indices = [0 for _ in range(batch_size)]
+         values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
+         return Input(indices, values, rounds)
+
+
+ HASH_STAGES = [
+     ("+", 0x7ED55D16, "+", "<<", 12),
+     ("^", 0xC761C23C, "^", ">>", 19),
+     ("+", 0x165667B1, "+", "<<", 5),
+     ("+", 0xD3A2646C, "^", "<<", 9),
+     ("+", 0xFD7046C5, "+", "<<", 3),
+     ("^", 0xB55A4F09, "^", ">>", 16),
+ ]
+
+
+ def myhash(a: int) -> int:
+     """A simple 32-bit hash function"""
+     fns = {
+         "+": lambda x, y: x + y,
+         "^": lambda x, y: x ^ y,
+         "<<": lambda x, y: x << y,
+         ">>": lambda x, y: x >> y,
+     }
+
+     def r(x):
+         return x % (2**32)
+
+     for op1, val1, op2, op3, val3 in HASH_STAGES:
+         a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
+
+     return a
+
+
+ def reference_kernel(t: Tree, inp: Input):
+     """
+     Reference implementation of the kernel.
+
+     A parallel tree traversal where at each node we set
+         cur_inp_val = myhash(cur_inp_val ^ node_val)
+     and then choose the left branch if cur_inp_val is even.
+     If we reach the bottom of the tree we wrap around to the top.
+     """
+     for h in range(inp.rounds):
+         for i in range(len(inp.indices)):
+             idx = inp.indices[i]
+             val = inp.values[i]
+             val = myhash(val ^ t.values[idx])
+             idx = 2 * idx + (1 if val % 2 == 0 else 2)
+             idx = 0 if idx >= len(t.values) else idx
+             inp.values[i] = val
+             inp.indices[i] = idx
+
+
+ def build_mem_image(t: Tree, inp: Input) -> list[int]:
+     """
+     Build a flat memory image of the problem.
+     """
+     header = 7
+     extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
+     mem = [0] * (
+         header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
+     )
+     forest_values_p = header
+     inp_indices_p = forest_values_p + len(t.values)
+     inp_values_p = inp_indices_p + len(inp.values)
+     extra_room = inp_values_p + len(inp.values)
+
+     mem[0] = inp.rounds
+     mem[1] = len(t.values)
+     mem[2] = len(inp.indices)
+     mem[3] = t.height
+     mem[4] = forest_values_p
+     mem[5] = inp_indices_p
+     mem[6] = inp_values_p
+     mem[7] = extra_room
+
+     mem[header:inp_indices_p] = t.values
+     mem[inp_indices_p:inp_values_p] = inp.indices
+     mem[inp_values_p:] = inp.values
+     return mem
+
+
+ def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
+     """A simple 32-bit hash function"""
+     fns = {
+         "+": lambda x, y: x + y,
+         "^": lambda x, y: x ^ y,
+         "<<": lambda x, y: x << y,
+         ">>": lambda x, y: x >> y,
+     }
+
+     def r(x):
+         return x % (2**32)
+
+     for i, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
+         a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
+         trace[(round, batch_i, "hash_stage", i)] = a
+
+     return a
+
+
+ def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
+     """
+     Reference implementation of the kernel on a flat memory.
+     """
+     # This is the initial memory layout
+     rounds = mem[0]
+     n_nodes = mem[1]
+     batch_size = mem[2]
+     forest_height = mem[3]
+     # Offsets into the memory which indices get added to
+     forest_values_p = mem[4]
+     inp_indices_p = mem[5]
+     inp_values_p = mem[6]
+     yield mem
+     for h in range(rounds):
+         for i in range(batch_size):
+             idx = mem[inp_indices_p + i]
+             trace[(h, i, "idx")] = idx
+             val = mem[inp_values_p + i]
+             trace[(h, i, "val")] = val
+             node_val = mem[forest_values_p + idx]
+             trace[(h, i, "node_val")] = node_val
+             val = myhash_traced(val ^ node_val, trace, h, i)
+             trace[(h, i, "hashed_val")] = val
+             idx = 2 * idx + (1 if val % 2 == 0 else 2)
+             trace[(h, i, "next_idx")] = idx
+             idx = 0 if idx >= n_nodes else idx
+             trace[(h, i, "wrapped_idx")] = idx
+             mem[inp_values_p + i] = val
+             mem[inp_indices_p + i] = idx
+         # You can add new yields or move this around for debugging
+         # as long as it's matched by pause instructions.
+         # The submission tests evaluate only on final memory.
+         yield mem
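The six-stage hash in `HASH_STAGES` can be exercised outside the simulator; this is a standalone restatement of `myhash` from problem.py (each stage combines two transforms of the running value and reduces mod 2**32 after every operation):

```python
# Stage table copied from problem.py: (op1, const, combine_op, shift_op, shift_amount)
HASH_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),
    ("^", 0xC761C23C, "^", ">>", 19),
    ("+", 0x165667B1, "+", "<<", 5),
    ("+", 0xD3A2646C, "^", "<<", 9),
    ("+", 0xFD7046C5, "+", "<<", 3),
    ("^", 0xB55A4F09, "^", ">>", 16),
]

FNS = {
    "+": lambda x, y: x + y,
    "^": lambda x, y: x ^ y,
    "<<": lambda x, y: x << y,
    ">>": lambda x, y: x >> y,
}


def myhash(a: int) -> int:
    """Apply the six hash stages, reducing mod 2**32 after every op."""
    r = lambda x: x % (2**32)
    for op1, val1, op2, op3, val3 in HASH_STAGES:
        # combine_op(op1(a, const), shift_op(a, shift_amount)), all mod 2**32
        a = r(FNS[op2](r(FNS[op1](a, val1)), r(FNS[op3](a, val3))))
    return a
```

The low bit of the hashed value is what drives the left/right branch choice in both reference kernels, so each stage must match the simulator bit-for-bit, including the intermediate mod-2**32 reductions.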
original_performance_takehome/tests/frozen_problem.py ADDED
@@ -0,0 +1,568 @@
1
+ """
2
+ Read the top of perf_takehome.py for more introduction.
3
+
4
+ This file is separate mostly for ease of copying it to freeze the machine and
5
+ reference kernel for testing.
6
+ """
7
+
8
+ from copy import copy
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from typing import Any, Literal
12
+ import random
13
+
14
+ Engine = Literal["alu", "load", "store", "flow"]
15
+ Instruction = dict[Engine, list[tuple]]
16
+
17
+
18
+ class CoreState(Enum):
19
+ RUNNING = 1
20
+ PAUSED = 2
21
+ STOPPED = 3
22
+
23
+
24
+ @dataclass
25
+ class Core:
26
+ id: int
27
+ scratch: list[int]
28
+ trace_buf: list[int]
29
+ pc: int = 0
30
+ state: CoreState = CoreState.RUNNING
31
+
32
+
33
+ @dataclass
34
+ class DebugInfo:
35
+ """
36
+ We give you some debug info but it's up to you to use it in Machine if you
37
+ want to. You're also welcome to add more.
38
+ """
39
+
40
+ # Maps scratch variable addr to (name, len) pair
41
+ scratch_map: dict[int, (str, int)]
42
+
43
+
44
+ def cdiv(a, b):
45
+ return (a + b - 1) // b
46
+
47
+
48
+ SLOT_LIMITS = {
49
+ "alu": 12,
50
+ "valu": 6,
51
+ "load": 2,
52
+ "store": 2,
53
+ "flow": 1,
54
+ "debug": 64,
55
+ }
56
+
57
+ VLEN = 8
58
+ # Older versions of the take-home used multiple cores, but this version only uses 1
59
+ N_CORES = 1
60
+ SCRATCH_SIZE = 1536
61
+ BASE_ADDR_TID = 100000
62
+
63
+
+ class Machine:
+     """
+     Simulator for a custom VLIW SIMD architecture.
+
+     VLIW (Very Long Instruction Word): Cores are composed of different
+     "engines", each of which can execute multiple "slots" per cycle in
+     parallel. How many slots each engine can execute per cycle is limited by
+     SLOT_LIMITS. Effects of instructions don't take effect until the end of
+     the cycle. Each cycle, all engines execute all of their filled slots for
+     that instruction. Effects like writes to memory take place after all the
+     inputs are read.
+
+     SIMD: There are instructions for acting on vectors of VLEN elements in a
+     single slot. You can use vload and vstore to load multiple contiguous
+     elements but not non-contiguous elements. Use vbroadcast to broadcast a
+     scalar to a vector and then operate on vectors with valu instructions.
+
+     The memory and scratch space are composed of 32-bit words. The solution is
+     plucked out of the memory at the end of the program. You can think of the
+     scratch space as serving the purpose of registers, constant memory, and a
+     manually-managed cache.
+
+     Here's an example of what an instruction might look like:
+
+         {"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}
+
+     In general, every number in an instruction is a scratch address, except
+     for const and jump; and except for store and some flow instructions, the
+     first operand is the destination.
+
+     This comment is not meant to be full ISA documentation, though; for the
+     rest you should look through the simulator code.
+     """
+
+     def __init__(
+         self,
+         mem_dump: list[int],
+         program: list[Instruction],
+         debug_info: DebugInfo,
+         n_cores: int = 1,
+         scratch_size: int = SCRATCH_SIZE,
+         trace: bool = False,
+         value_trace: dict[Any, int] = {},
+     ):
+         self.cores = [
+             Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
+         ]
+         self.mem = copy(mem_dump)
+         self.program = program
+         self.debug_info = debug_info
+         self.value_trace = value_trace
+         self.prints = False
+         self.cycle = 0
+         self.enable_pause = True
+         self.enable_debug = True
+         if trace:
+             self.setup_trace()
+         else:
+             self.trace = None
+
+     def rewrite_instr(self, instr):
+         """
+         Rewrite an instruction to use names instead of scratch addresses
+         """
+         res = {}
+         for name, slots in instr.items():
+             res[name] = []
+             for slot in slots:
+                 res[name].append(self.rewrite_slot(slot))
+         return res
+
+     def print_step(self, instr, core):
+         # print(core.id)
+         # print(core.trace_buf)
+         print(self.scratch_map(core))
+         print(core.pc, instr, self.rewrite_instr(instr))
+
+     def scratch_map(self, core):
+         res = {}
+         for addr, (name, length) in self.debug_info.scratch_map.items():
+             res[name] = core.scratch[addr : addr + length]
+         return res
+
+     def rewrite_slot(self, slot):
+         return tuple(
+             self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
+         )
+
+     def setup_trace(self):
+         """
+         The simulator generates traces in Chrome's Trace Event Format for
+         visualization in Perfetto (or chrome://tracing if you prefer it). See
+         the bottom of the file for info about how to use this.
+
+         See the format docs in case you want to add more info to the trace:
+         https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
+         """
+         self.trace = open("trace.json", "w")
+         self.trace.write("[")
+         tid_counter = 0
+         self.tids = {}
+         for ci, core in enumerate(self.cores):
+             self.trace.write(
+                 f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
+             )
+             for name, limit in SLOT_LIMITS.items():
+                 if name == "debug":
+                     continue
+                 for i in range(limit):
+                     tid_counter += 1
+                     self.trace.write(
+                         f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
+                     )
+                     self.tids[(ci, name, i)] = tid_counter
+
+         # Add zero-length events at the start so all slots show up in Perfetto
+         for ci, core in enumerate(self.cores):
+             for name, limit in SLOT_LIMITS.items():
+                 if name == "debug":
+                     continue
+                 for i in range(limit):
+                     tid = self.tids[(ci, name, i)]
+                     self.trace.write(
+                         f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
+                     )
+         for ci, core in enumerate(self.cores):
+             self.trace.write(
+                 f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
+             )
+             for addr, (name, length) in self.debug_info.scratch_map.items():
+                 self.trace.write(
+                     f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
+                 )
+
+     def run(self):
+         for core in self.cores:
+             if core.state == CoreState.PAUSED:
+                 core.state = CoreState.RUNNING
+         while any(c.state == CoreState.RUNNING for c in self.cores):
+             has_non_debug = False
+             for core in self.cores:
+                 if core.state != CoreState.RUNNING:
+                     continue
+                 if core.pc >= len(self.program):
+                     core.state = CoreState.STOPPED
+                     continue
+                 instr = self.program[core.pc]
+                 if self.prints:
+                     self.print_step(instr, core)
+                 core.pc += 1
+                 self.step(instr, core)
+                 if any(name != "debug" for name in instr.keys()):
+                     has_non_debug = True
+             if has_non_debug:
+                 self.cycle += 1
+
+     def alu(self, core, op, dest, a1, a2):
+         a1 = core.scratch[a1]
+         a2 = core.scratch[a2]
+         match op:
+             case "+":
+                 res = a1 + a2
+             case "-":
+                 res = a1 - a2
+             case "*":
+                 res = a1 * a2
+             case "//":
+                 res = a1 // a2
+             case "cdiv":
+                 res = cdiv(a1, a2)
+             case "^":
+                 res = a1 ^ a2
+             case "&":
+                 res = a1 & a2
+             case "|":
+                 res = a1 | a2
+             case "<<":
+                 res = a1 << a2
+             case ">>":
+                 res = a1 >> a2
+             case "%":
+                 res = a1 % a2
+             case "<":
+                 res = int(a1 < a2)
+             case "==":
+                 res = int(a1 == a2)
+             case _:
+                 raise NotImplementedError(f"Unknown alu op {op}")
+         res = res % (2**32)
+         self.scratch_write[dest] = res
+
+     def valu(self, core, *slot):
+         match slot:
+             case ("vbroadcast", dest, src):
+                 for i in range(VLEN):
+                     self.scratch_write[dest + i] = core.scratch[src]
+             case ("multiply_add", dest, a, b, c):
+                 for i in range(VLEN):
+                     mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
+                     self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
+             case (op, dest, a1, a2):
+                 for i in range(VLEN):
+                     self.alu(core, op, dest + i, a1 + i, a2 + i)
+             case _:
+                 raise NotImplementedError(f"Unknown valu op {slot}")
+
+     def load(self, core, *slot):
+         match slot:
+             case ("load", dest, addr):
+                 # print(dest, addr, core.scratch[addr])
+                 self.scratch_write[dest] = self.mem[core.scratch[addr]]
+             case ("load_offset", dest, addr, offset):
+                 # Handy for treating vector dest and addr as a full block in the mini-compiler if you want
+                 self.scratch_write[dest + offset] = self.mem[
+                     core.scratch[addr + offset]
+                 ]
+             case ("vload", dest, addr):  # addr is a scalar
+                 addr = core.scratch[addr]
+                 for vi in range(VLEN):
+                     self.scratch_write[dest + vi] = self.mem[addr + vi]
+             case ("const", dest, val):
+                 self.scratch_write[dest] = val % (2**32)
+             case _:
+                 raise NotImplementedError(f"Unknown load op {slot}")
+
+     def store(self, core, *slot):
+         match slot:
+             case ("store", addr, src):
+                 addr = core.scratch[addr]
+                 self.mem_write[addr] = core.scratch[src]
+             case ("vstore", addr, src):  # addr is a scalar
+                 addr = core.scratch[addr]
+                 for vi in range(VLEN):
+                     self.mem_write[addr + vi] = core.scratch[src + vi]
+             case _:
+                 raise NotImplementedError(f"Unknown store op {slot}")
+
+     def flow(self, core, *slot):
+         match slot:
+             case ("select", dest, cond, a, b):
+                 self.scratch_write[dest] = (
+                     core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
+                 )
+             case ("add_imm", dest, a, imm):
+                 self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
+             case ("vselect", dest, cond, a, b):
+                 for vi in range(VLEN):
+                     self.scratch_write[dest + vi] = (
+                         core.scratch[a + vi]
+                         if core.scratch[cond + vi] != 0
+                         else core.scratch[b + vi]
+                     )
+             case ("halt",):
+                 core.state = CoreState.STOPPED
+             case ("pause",):
+                 if self.enable_pause:
+                     core.state = CoreState.PAUSED
+             case ("trace_write", val):
+                 core.trace_buf.append(core.scratch[val])
+             case ("cond_jump", cond, addr):
+                 if core.scratch[cond] != 0:
+                     core.pc = addr
+             case ("cond_jump_rel", cond, offset):
+                 if core.scratch[cond] != 0:
+                     core.pc += offset
+             case ("jump", addr):
+                 core.pc = addr
+             case ("jump_indirect", addr):
+                 core.pc = core.scratch[addr]
+             case ("coreid", dest):
+                 self.scratch_write[dest] = core.id
+             case _:
+                 raise NotImplementedError(f"Unknown flow op {slot}")
+
+     def trace_post_step(self, instr, core):
+         # You can add extra stuff to the trace if you want!
+         for addr, (name, length) in self.debug_info.scratch_map.items():
+             if any((addr + vi) in self.scratch_write for vi in range(length)):
+                 val = str(core.scratch[addr : addr + length])
+                 val = val.replace("[", "").replace("]", "")
+                 self.trace.write(
+                     f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
+                 )
+
+     def trace_slot(self, core, slot, name, i):
+         self.trace.write(
+             f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
+         )
+
+     def step(self, instr: Instruction, core):
+         """
+         Execute all the slots in each engine for a single instruction bundle
+         """
+         ENGINE_FNS = {
+             "alu": self.alu,
+             "valu": self.valu,
+             "load": self.load,
+             "store": self.store,
+             "flow": self.flow,
+         }
+         self.scratch_write = {}
+         self.mem_write = {}
+         for name, slots in instr.items():
+             if name == "debug":
+                 if not self.enable_debug:
+                     continue
+                 for slot in slots:
+                     if slot[0] == "compare":
+                         loc, key = slot[1], slot[2]
+                         ref = self.value_trace[key]
+                         res = core.scratch[loc]
+                         assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
+                     elif slot[0] == "vcompare":
+                         loc, keys = slot[1], slot[2]
+                         ref = [self.value_trace[key] for key in keys]
+                         res = core.scratch[loc : loc + VLEN]
+                         assert res == ref, (
+                             f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
+                         )
+                 continue
+             assert len(slots) <= SLOT_LIMITS[name]
+             for i, slot in enumerate(slots):
+                 if self.trace is not None:
+                     self.trace_slot(core, slot, name, i)
+                 ENGINE_FNS[name](core, *slot)
+         for addr, val in self.scratch_write.items():
+             core.scratch[addr] = val
+         for addr, val in self.mem_write.items():
+             self.mem[addr] = val
+
+         if self.trace:
+             self.trace_post_step(instr, core)
+
+         del self.scratch_write
+         del self.mem_write
+
+     def __del__(self):
+         if self.trace is not None:
+             self.trace.write("]")
+             self.trace.close()
+
+
+ @dataclass
+ class Tree:
+     """
+     An implicit perfect balanced binary tree with values on the nodes.
+     """
+
+     height: int
+     values: list[int]
+
+     @staticmethod
+     def generate(height: int):
+         n_nodes = 2 ** (height + 1) - 1
+         values = [random.randint(0, 2**30 - 1) for _ in range(n_nodes)]
+         return Tree(height, values)
+
+
+ @dataclass
+ class Input:
+     """
+     A batch of inputs: indices of nodes (starting at 0) and initial input
+     values. We then iterate these for a specified number of rounds.
+     """
+
+     indices: list[int]
+     values: list[int]
+     rounds: int
+
+     @staticmethod
+     def generate(forest: Tree, batch_size: int, rounds: int):
+         indices = [0 for _ in range(batch_size)]
+         values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
+         return Input(indices, values, rounds)
+
+
+ HASH_STAGES = [
+     ("+", 0x7ED55D16, "+", "<<", 12),
+     ("^", 0xC761C23C, "^", ">>", 19),
+     ("+", 0x165667B1, "+", "<<", 5),
+     ("+", 0xD3A2646C, "^", "<<", 9),
+     ("+", 0xFD7046C5, "+", "<<", 3),
+     ("^", 0xB55A4F09, "^", ">>", 16),
+ ]
+
+
+ def myhash(a: int) -> int:
+     """A simple 32-bit hash function"""
+     fns = {
+         "+": lambda x, y: x + y,
+         "^": lambda x, y: x ^ y,
+         "<<": lambda x, y: x << y,
+         ">>": lambda x, y: x >> y,
+     }
+
+     def r(x):
+         return x % (2**32)
+
+     for op1, val1, op2, op3, val3 in HASH_STAGES:
+         a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
+
+     return a
+
+
+ def reference_kernel(t: Tree, inp: Input):
+     """
+     Reference implementation of the kernel.
+
+     A parallel tree traversal where at each node we set
+         cur_inp_val = myhash(cur_inp_val ^ node_val)
+     and then choose the left branch if cur_inp_val is even.
+     If we reach the bottom of the tree, we wrap around to the top.
+     """
+     for h in range(inp.rounds):
+         for i in range(len(inp.indices)):
+             idx = inp.indices[i]
+             val = inp.values[i]
+             val = myhash(val ^ t.values[idx])
+             idx = 2 * idx + (1 if val % 2 == 0 else 2)
+             idx = 0 if idx >= len(t.values) else idx
+             inp.values[i] = val
+             inp.indices[i] = idx
+
+
+ def build_mem_image(t: Tree, inp: Input) -> list[int]:
+     """
+     Build a flat memory image of the problem.
+     """
+     header = 7
+     extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
+     mem = [0] * (
+         header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
+     )
+     forest_values_p = header
+     inp_indices_p = forest_values_p + len(t.values)
+     inp_values_p = inp_indices_p + len(inp.values)
+     extra_room = inp_values_p + len(inp.values)
+
+     mem[0] = inp.rounds
+     mem[1] = len(t.values)
+     mem[2] = len(inp.indices)
+     mem[3] = t.height
+     mem[4] = forest_values_p
+     mem[5] = inp_indices_p
+     mem[6] = inp_values_p
+     mem[7] = extra_room
+
+     mem[header:inp_indices_p] = t.values
+     mem[inp_indices_p:inp_values_p] = inp.indices
+     mem[inp_values_p:] = inp.values
+     return mem
+
+
+ def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
+     """A simple 32-bit hash function"""
+     fns = {
+         "+": lambda x, y: x + y,
+         "^": lambda x, y: x ^ y,
+         "<<": lambda x, y: x << y,
+         ">>": lambda x, y: x >> y,
+     }
+
+     def r(x):
+         return x % (2**32)
+
+     for i, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
+         a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))
+         trace[(round, batch_i, "hash_stage", i)] = a
+
+     return a
+
+
+ def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
+     """
+     Reference implementation of the kernel on a flat memory.
+     """
+     # This is the initial memory layout
+     rounds = mem[0]
+     n_nodes = mem[1]
+     batch_size = mem[2]
+     forest_height = mem[3]
+     # Offsets into the memory which indices get added to
+     forest_values_p = mem[4]
+     inp_indices_p = mem[5]
+     inp_values_p = mem[6]
+     yield mem
+     for h in range(rounds):
+         for i in range(batch_size):
+             idx = mem[inp_indices_p + i]
+             trace[(h, i, "idx")] = idx
+             val = mem[inp_values_p + i]
+             trace[(h, i, "val")] = val
+             node_val = mem[forest_values_p + idx]
+             trace[(h, i, "node_val")] = node_val
+             val = myhash_traced(val ^ node_val, trace, h, i)
+             trace[(h, i, "hashed_val")] = val
+             idx = 2 * idx + (1 if val % 2 == 0 else 2)
+             trace[(h, i, "next_idx")] = idx
+             idx = 0 if idx >= n_nodes else idx
+             trace[(h, i, "wrapped_idx")] = idx
+             mem[inp_values_p + i] = val
+             mem[inp_indices_p + i] = idx
+     # You can add new yields or move this around for debugging
+     # as long as it's matched by pause instructions.
+     # The submission tests evaluate only on final memory.
+     yield mem
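The end-of-cycle write semantics described in Machine's docstring above can be sketched standalone: every slot in a bundle reads the pre-cycle scratch state, and all writes land together afterwards. This is a minimal illustrative sketch with made-up addresses and ops, not part of the frozen simulator:

```python
# Minimal sketch of one VLIW cycle: all slots read the old scratch state,
# and writes are buffered and applied together at the end of the cycle.
scratch = [1, 2]
bundle = [("+", 0, 0, 1), ("-", 1, 0, 1)]  # (op, dest, a1, a2), all scratch addrs
writes = {}
for op, dest, a1, a2 in bundle:
    x, y = scratch[a1], scratch[a2]
    writes[dest] = (x + y if op == "+" else x - y) % (2**32)
for addr, val in writes.items():
    scratch[addr] = val
print(scratch)  # both slots saw scratch = [1, 2], so the result is [3, 4294967295]
```

Because both slots observed the pre-cycle values, the bundle computes a sum and a difference of the same pair in one cycle with no read-after-write hazard.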
original_performance_takehome/tests/submission_tests.py ADDED
@@ -0,0 +1,119 @@
+ import os, sys, inspect
+
+ currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.insert(0, parentdir)
+
+ from functools import lru_cache
+ import unittest
+ import random
+
+ from frozen_problem import (
+     Machine,
+     build_mem_image,
+     reference_kernel2,
+     Tree,
+     Input,
+     N_CORES,
+     VLEN,
+ )
+ from perf_takehome import KernelBuilder
+
+
+ @lru_cache(maxsize=None)
+ def kernel_builder(forest_height: int, n_nodes: int, batch_size: int, rounds: int):
+     kb = KernelBuilder()
+     kb.build_kernel(forest_height, n_nodes, batch_size, rounds)
+     return kb
+
+
+ def do_kernel_test(forest_height: int, rounds: int, batch_size: int):
+     print(f"Testing {forest_height=}, {rounds=}, {batch_size=}")
+     # Note the random generator is not seeded here
+     forest = Tree.generate(forest_height)
+     inp = Input.generate(forest, batch_size, rounds)
+     mem = build_mem_image(forest, inp)
+
+     kb = kernel_builder(forest.height, len(forest.values), len(inp.indices), rounds)
+     # print(kb.instrs)
+
+     machine = Machine(mem, kb.instrs, kb.debug_info(), n_cores=N_CORES)
+     machine.enable_pause = False
+     machine.enable_debug = False
+     machine.run()
+
+     for ref_mem in reference_kernel2(mem):
+         pass
+
+     inp_values_p = ref_mem[6]
+     assert (
+         machine.mem[inp_values_p : inp_values_p + len(inp.values)]
+         == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
+     ), "Incorrect output values"
+     print("CYCLES: ", machine.cycle)
+     return machine.cycle
+
+
+ class CorrectnessTests(unittest.TestCase):
+     def test_kernel_correctness(self):
+         for i in range(8):
+             do_kernel_test(10, 16, 256)
+
+
+ BASELINE = 147734
+
+
+ @lru_cache(maxsize=None)
+ def cycles():
+     try:
+         res = do_kernel_test(10, 16, 256)
+         print("Speedup over baseline: ", BASELINE / res)
+         return res
+     except AssertionError:
+         return BASELINE * 2
+
+
+ class SpeedTests(unittest.TestCase):
+     """
+     You very much don't need to pass all of these to pass the interview.
+     The impressiveness also isn't linear in the number of tests passed.
+
+     These are just so that the test pass rate gets translated into a number
+     on the CodeSignal UI.
+     """
+
+     def test_kernel_speedup(self):
+         assert cycles() < BASELINE
+
+     def test_kernel_updated_starting_point(self):
+         # The updated version of this take-home given to candidates contained starter code that started them at this point
+         assert cycles() < 18532
+
+     def test_opus4_many_hours(self):
+         # Claude Opus 4 after many hours in the test-time compute harness
+         assert cycles() < 2164
+
+     def test_opus45_casual(self):
+         # Claude Opus 4.5 in a casual Claude Code session, approximately matching
+         # the best human performance in 2 hours
+         assert cycles() < 1790
+
+     def test_opus45_2hr(self):
+         # Claude Opus 4.5 after 2 hours in our test-time compute harness
+         assert cycles() < 1579
+
+     def test_sonnet45_many_hours(self):
+         # Claude Sonnet 4.5 after many more than 2 hours of test-time compute
+         assert cycles() < 1548
+
+     def test_opus45_11hr(self):
+         # Claude Opus 4.5 after 11.5 hours in the harness
+         assert cycles() < 1487
+
+     def test_opus45_improved_harness(self):
+         # Claude Opus 4.5 in an improved test-time compute harness
+         assert cycles() < 1363
+
+
+ if __name__ == "__main__":
+     unittest.main()
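The update rule these tests verify (from `reference_kernel2` above) can be sketched for a single input. The hash here is a stand-in multiplier so the branch and wraparound rules are visible on their own; the real kernel uses `myhash` from `frozen_problem`, so the concrete numbers below are illustrative only:

```python
# One traversal step from the reference kernel, with a stand-in hash.
n_nodes = 7                   # perfect tree of height 2
tree = [5, 3, 8, 1, 4, 7, 9]
idx, val = 0, 6
val = ((val ^ tree[idx]) * 2654435761) % (2**32)  # stand-in for myhash
idx = 2 * idx + (1 if val % 2 == 0 else 2)        # even -> left child, odd -> right
idx = 0 if idx >= n_nodes else idx                # falling off the bottom wraps to the root
print(idx, val)
```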
original_performance_takehome/watch_trace.html ADDED
@@ -0,0 +1,132 @@
+ <!doctype html>
+ <html lang="en-us">
+   <link rel="shortcut icon" href="data:image/x-icon;," type="image/x-icon" />
+
+   <body>
+     <style>
+       pre {
+         border: 1px solid #eee;
+         margin: 10px 0;
+         font-family: monospace;
+         font-size: 10px;
+         min-height: 100px;
+       }
+
+       body > * {
+         margin: 20px;
+       }
+
+       #btn_fetch {
+         font-size: 14px;
+       }
+     </style>
+
+     <select id="source" size="4">
+       <option selected>/trace.json</option>
+     </select>
+
+     <br />
+
+     <button type="button" id="btn_fetch">Open Perfetto</button>
+
+     <br />
+
+     <pre id="logs" cols="80" rows="20"></pre>
+
+     <script type="text/javascript">
+       // const ORIGIN = 'http://localhost:8000/perfetto/';
+       const ORIGIN = "https://ui.perfetto.dev";
+
+       const logs = document.getElementById("logs");
+       const btnFetch = document.getElementById("btn_fetch");
+
+       async function getMtime() {
+         const mtime_resp = await fetch("/mtime");
+         const mtime = await mtime_resp.text();
+         return mtime;
+       }
+
+       async function fetchAndOpen(traceUrl) {
+         logs.innerText += `Fetching trace from ${traceUrl}...\n`;
+         const mtime = await getMtime();
+         const resp = await fetch(traceUrl);
+         // Error checking is left as an exercise to the reader.
+         const blob = await resp.blob();
+         const arrayBuffer = await blob.arrayBuffer();
+         logs.innerText += `fetch() complete, now passing to ui.perfetto.dev\n`;
+         openTrace(arrayBuffer, traceUrl, mtime);
+       }
+
+       async function repoll(win, traceUrl, mtime) {
+         const newMtime = await getMtime();
+         console.log(newMtime, mtime);
+         if (newMtime !== mtime) {
+           logs.innerText += `Trace updated, fetching new version...\n`;
+           const resp = await fetch(traceUrl);
+           const blob = await resp.blob();
+           const arrayBuffer = await blob.arrayBuffer();
+           logs.innerText += `New trace fetched, opening...\n`;
+           sendTrace(win, arrayBuffer, traceUrl);
+         }
+
+         setTimeout(() => repoll(win, traceUrl, newMtime), 500);
+       }
+
+       function sendTrace(win, arrayBuffer, traceUrl) {
+         const reopenUrl = new URL(location.href);
+         reopenUrl.hash = `#reopen=${traceUrl}`;
+         logs.innerText += `Sending trace to UI\n`;
+         win.postMessage(
+           {
+             perfetto: {
+               buffer: arrayBuffer,
+               title: "trace.json",
+               url: reopenUrl.toString(),
+               keepApiOpen: true,
+             },
+           },
+           ORIGIN,
+         );
+       }
+
+       function openTrace(arrayBuffer, traceUrl, mtime) {
+         const win = window.open(ORIGIN);
+         if (!win) {
+           btnFetch.style.background = "#f3ca63";
+           btnFetch.onclick = () => openTrace(arrayBuffer, traceUrl, mtime);
+           logs.innerText += `Popups blocked, you need to manually click the button`;
+           btnFetch.innerText =
+             "Popups blocked, click here to open the trace file";
+           return;
+         }
+
+         const timer = setInterval(
+           () => win.postMessage("PING", ORIGIN),
+           50,
+         );
+
+         const onMessageHandler = (evt) => {
+           if (evt.data !== "PONG") return;
+
+           // We got a PONG, the UI is ready.
+           window.clearInterval(timer);
+           window.removeEventListener("message", onMessageHandler);
+
+           sendTrace(win, arrayBuffer, traceUrl);
+           setTimeout(() => repoll(win, traceUrl, mtime), 500);
+         };
+
+         window.addEventListener("message", onMessageHandler);
+       }
+
+       // This is triggered when following the link from the Perfetto UI's sidebar.
+       if (location.hash.startsWith("#reopen=")) {
+         const traceUrl = location.hash.substr(8);
+         fetchAndOpen(traceUrl);
+       }
+
+       btnFetch.onclick = () =>
+         fetchAndOpen(document.getElementById("source").value);
+     </script>
+   </body>
+ </html>
original_performance_takehome/watch_trace.py ADDED
@@ -0,0 +1,84 @@
+ import http.server
+ import os
+ from datetime import datetime
+ import webbrowser
+ import urllib.request
+
+
+ # Define a handler class
+ class MyHandler(http.server.BaseHTTPRequestHandler):
+     def do_GET(self):
+         try:
+             # Serve a string constant at the index
+             if self.path == "/":
+                 self.send_response(200)
+                 self.send_header("Content-type", "text/html")
+                 self.end_headers()
+                 with open("watch_trace.html", "rb") as file:
+                     self.wfile.write(file.read())
+
+             # Stream the contents of 'trace.json' at '/trace.json'
+             elif self.path == "/trace.json":
+                 self.send_response(200)
+                 self.send_header("Content-type", "application/json")
+                 self.end_headers()
+                 with open("trace.json", "rb") as file:
+                     while chunk := file.read(8192):
+                         self.wfile.write(chunk)
+
+             # Serve the file modification time of 'trace.json' at '/mtime'
+             elif self.path == "/mtime":
+                 mtime = os.path.getmtime("trace.json")
+                 last_modified_date = datetime.fromtimestamp(mtime).strftime(
+                     "%Y-%m-%d %H:%M:%S"
+                 )
+                 self.send_response(200)
+                 self.send_header("Content-type", "text/plain")
+                 self.end_headers()
+                 self.wfile.write(last_modified_date.encode())
+
+             elif self.path.startswith("/perfetto"):
+                 proxy_url = "https://ui.perfetto.dev" + self.path[len("/perfetto") :]
+                 print("Proxying request to " + proxy_url)
+                 with urllib.request.urlopen(proxy_url) as response:
+                     self.send_response(response.status)
+
+                     res = response.read()
+                     if self.path.endswith("frontend_bundle.js"):
+                         print("Activating replacement")
+                         # Fix a bug in Perfetto that they haven't deployed the fix for yet but have fixed internally
+                         res = res.replace(
+                             b"throw new Error(`EngineProxy ${this.tag} was disposed.`);",
+                             b"return null;",
+                         )
+                         # Auto-expand tracks by default
+                         res = res.replace(b"collapsed: true", b"collapsed: false")
+                         res = res.replace(
+                             b"collapsed: !hasHeapProfiles", b"collapsed: false"
+                         )
+                     # Headers must be sent before end_headers(); Content-Length is
+                     # corrected since the body may have been rewritten above
+                     for header in response.headers:
+                         if header == "Content-Length":
+                             self.send_header(header, len(res))
+                         else:
+                             self.send_header(header, response.headers[header])
+                     self.end_headers()
+                     self.wfile.write(res)
+
+             else:
+                 self.send_error(404, "File Not Found: {}".format(self.path))
+
+         except IOError:
+             self.send_error(404, "File Not Found: {}".format(self.path))
+
+
+ # Start the server
+ def run(server_class=http.server.HTTPServer, handler_class=MyHandler):
+     server_address = ("", 8000)
+     httpd = server_class(server_address, handler_class)
+     print("Starting httpd...")
+     webbrowser.open("http://localhost:8000")
+     httpd.serve_forever()
+
+
+ # Run the server
+ if __name__ == "__main__":
+     run()
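The trace.json served above uses Chrome's Trace Event Format, which is also what `Machine.setup_trace` and `trace_slot` emit. A minimal, hypothetical payload mirroring those metadata ("M") and complete ("X") events might look like:

```python
import json

# Hypothetical two-event trace: one "M" metadata event naming a track,
# and one "X" complete event placing a 1-cycle op on that track.
events = [
    {"name": "thread_name", "ph": "M", "pid": 0, "tid": 1,
     "args": {"name": "alu-0"}},
    {"name": "+", "cat": "op", "ph": "X", "pid": 0, "tid": 1, "ts": 0, "dur": 1},
]
payload = json.dumps(events)
print(payload)
```

A JSON array of such objects is directly loadable in Perfetto or chrome://tracing.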
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.1.0
+ transformers>=4.40.0
+ datasets>=2.18.0
+ peft>=0.10.0
+ trl>=0.8.0
+ accelerate>=0.28.0
+ bitsandbytes>=0.43.0
+ gradio>=4.0.0