CreativeEngineer committed on
Commit
b3b926b
·
1 Parent(s): 9c10799

Add VLIW simulator for cycle-count based rewards

Browse files
Files changed (2) hide show
  1. app.py +247 -59
  2. problem.py +568 -0
app.py CHANGED
@@ -1,9 +1,13 @@
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
 
3
  """
4
  import gradio as gr
5
  import threading
6
  import time
 
 
 
7
 
8
  # Check imports at startup
9
  startup_log = []
@@ -38,11 +42,30 @@ try:
38
  except Exception as e:
39
  startup_log.append(f"✗ CUDA check: {e}")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Training state
42
  training_state = {
43
  "is_training": False,
44
  "should_stop": False,
45
  "log": [],
 
 
46
  }
47
  state_lock = threading.Lock()
48
 
@@ -51,20 +74,171 @@ def get_status():
51
  return "\n".join(startup_log)
52
 
53
 
54
- def simple_reward_fn(completions, **kwargs):
55
- """Simple reward: prefer longer, code-like outputs."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  rewards = []
57
- for c in completions:
58
- text = c[0]["content"] if isinstance(c, list) else str(c)
59
- score = min(len(text) / 200.0, 1.0)
60
- if any(kw in text for kw in ["def ", "for ", "if ", "while ", "return "]):
61
- score += 0.3
62
- rewards.append(score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  return rewards
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def run_training(model_name, num_steps):
67
- """Run RL training with GRPO."""
68
  import torch
69
  from datasets import Dataset
70
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
@@ -83,9 +257,13 @@ def run_training(model_name, num_steps):
83
  training_state["is_training"] = True
84
  training_state["should_stop"] = False
85
  training_state["log"] = []
 
 
86
 
87
  try:
88
- add_log(f"Starting training: {model_name}, {num_steps} steps")
 
 
89
 
90
  # Load tokenizer
91
  add_log("Loading tokenizer...")
@@ -109,16 +287,11 @@ def run_training(model_name, num_steps):
109
  )
110
  add_log(f"✓ Model loaded on {next(model.parameters()).device}")
111
 
112
- # Create dataset
113
- add_log("Creating training dataset...")
114
- prompts = [
115
- "Write optimized VLIW assembly code for matrix multiplication using SIMD instructions",
116
- "Generate efficient parallel code for vector dot product",
117
- "Create VLIW code for memory-bound reduction operation",
118
- "Write pipelined code for element-wise array operations",
119
- ] * 8 # 32 prompts total
120
  dataset = Dataset.from_dict({"prompt": prompts})
121
- add_log(f"✓ Dataset: {len(prompts)} prompts")
122
 
123
  # LoRA config
124
  add_log("Setting up LoRA...")
@@ -131,28 +304,36 @@ def run_training(model_name, num_steps):
131
  task_type="CAUSAL_LM",
132
  )
133
 
134
- # Stop callback
135
- class StopCallback(TrainerCallback):
136
  def on_step_end(self, args, state, control, **kwargs):
137
  with state_lock:
 
138
  if training_state["should_stop"]:
139
  control.should_training_stop = True
140
  return control
141
 
 
 
 
 
 
 
 
142
  # GRPO config
143
- add_log("Creating GRPO trainer...")
144
  config = GRPOConfig(
145
- output_dir="./grpo_output",
146
  num_train_epochs=1,
147
  max_steps=num_steps,
148
- per_device_train_batch_size=2,
149
- gradient_accumulation_steps=2,
150
- learning_rate=5e-6,
151
  logging_steps=1,
152
- save_steps=999999, # Don't save checkpoints
153
  report_to="none",
154
  remove_unused_columns=False,
155
- max_completion_length=128,
156
  num_generations=4,
157
  )
158
 
@@ -160,41 +341,54 @@ def run_training(model_name, num_steps):
160
  model=model,
161
  args=config,
162
  train_dataset=dataset,
163
- reward_funcs=simple_reward_fn,
164
  peft_config=lora_config,
165
  processing_class=tokenizer,
166
- callbacks=[StopCallback()],
167
  )
168
  add_log("✓ Trainer ready")
169
 
170
  # Train
171
  add_log("Starting training loop...")
 
172
  train_result = trainer.train()
173
 
174
  metrics = train_result.metrics
175
  add_log(f"✓ Training complete!")
176
- add_log(f" Steps: {metrics.get('train_steps', 'N/A')}")
177
- add_log(f" Loss: {metrics.get('train_loss', 'N/A'):.4f}" if 'train_loss' in metrics else " Loss: N/A")
178
 
179
  # Test generation
180
  add_log("Testing trained model...")
181
- test_prompt = "Write efficient VLIW code for:"
182
- inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
183
  with torch.no_grad():
184
- outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
 
 
 
 
 
 
185
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
186
- add_log(f"Sample output: {result[:100]}...")
 
 
 
 
 
 
 
 
 
187
 
188
  add_log("\n✓ All done!")
189
 
190
  except Exception as e:
191
  import traceback
192
  add_log(f"✗ Error: {e}")
193
- add_log(traceback.format_exc()[:500])
194
  finally:
195
  with state_lock:
196
  training_state["is_training"] = False
197
- # Cleanup
198
  try:
199
  del model
200
  torch.cuda.empty_cache()
@@ -205,7 +399,7 @@ def run_training(model_name, num_steps):
205
 
206
 
207
  def start_training(model_name, num_steps):
208
- """Start training (blocking for simplicity)."""
209
  with state_lock:
210
  if training_state["is_training"]:
211
  return "Training already in progress. Please wait."
@@ -222,24 +416,18 @@ def stop_training():
222
  return "Stop requested. Training will stop after current step."
223
 
224
 
225
- def get_progress():
226
- """Get current log."""
227
- with state_lock:
228
- if not training_state["log"]:
229
- return "No training started yet"
230
- return "\n".join(training_state["log"])
231
-
232
-
233
  # Gradio UI
234
  with gr.Blocks(title="VLIW Optimizer") as demo:
235
  gr.Markdown("# VLIW Kernel Optimizer - RL Training")
236
- gr.Markdown("""
237
- Train a language model with reinforcement learning to generate optimized VLIW/SIMD code.
 
 
238
 
239
- **Instructions:**
240
- 1. Select a model (1.5B is faster, 3B may produce better results)
241
- 2. Set training steps (10-50 recommended for testing)
242
- 3. Click 'Start Training' and wait for completion
243
  """)
244
 
245
  with gr.Row():
@@ -247,7 +435,7 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
247
  status_box = gr.Textbox(
248
  label="System Status",
249
  value=get_status(),
250
- lines=9,
251
  interactive=False,
252
  )
253
 
@@ -261,10 +449,10 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
261
  label="Model",
262
  )
263
  steps_slider = gr.Slider(
264
- minimum=1,
265
  maximum=100,
266
- value=10,
267
- step=1,
268
  label="Training Steps",
269
  )
270
 
@@ -274,9 +462,9 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
274
 
275
  output_box = gr.Textbox(
276
  label="Training Log",
277
- lines=20,
278
  interactive=False,
279
- value="Click 'Start Training' to begin.",
280
  )
281
 
282
  start_btn.click(
 
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
+ Uses actual simulator for cycle-count based rewards.
4
  """
5
  import gradio as gr
6
  import threading
7
  import time
8
+ import random
9
+ import re
10
+ from copy import copy
11
 
12
  # Check imports at startup
13
  startup_log = []
 
42
  except Exception as e:
43
  startup_log.append(f"✗ CUDA check: {e}")
44
 
45
+ # Import simulator components
46
+ try:
47
+ from problem import (
48
+ Machine, Tree, Input, DebugInfo,
49
+ build_mem_image, reference_kernel2,
50
+ SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, HASH_STAGES
51
+ )
52
+ startup_log.append("✓ VLIW Simulator: OK")
53
+ SIMULATOR_AVAILABLE = True
54
+ except Exception as e:
55
+ startup_log.append(f"✗ VLIW Simulator: {e}")
56
+ SIMULATOR_AVAILABLE = False
57
+
58
+ # Constants
59
+ BASELINE_CYCLES = 147734
60
+ TARGET_CYCLES = 1363
61
+
62
  # Training state
63
  training_state = {
64
  "is_training": False,
65
  "should_stop": False,
66
  "log": [],
67
+ "best_cycles": BASELINE_CYCLES,
68
+ "step": 0,
69
  }
70
  state_lock = threading.Lock()
71
 
 
74
  return "\n".join(startup_log)
75
 
76
 
77
+ def parse_kernel_code(code_text):
78
+ """
79
+ Parse LLM-generated kernel code into simulator instructions.
80
+ Returns list of instruction dicts or None if parsing fails.
81
+ """
82
+ instructions = []
83
+
84
+ # Try to find instruction patterns in the code
85
+ # Format: {"engine": [("op", arg1, arg2, ...)]}
86
+
87
+ # Look for dict-like instruction patterns
88
+ pattern = r'\{[^}]+\}'
89
+ matches = re.findall(pattern, code_text)
90
+
91
+ for match in matches:
92
+ try:
93
+ # Try to eval as Python dict (safely)
94
+ instr = eval(match, {"__builtins__": {}})
95
+ if isinstance(instr, dict):
96
+ # Validate it's a valid instruction
97
+ valid_engines = {"alu", "valu", "load", "store", "flow", "debug"}
98
+ if any(k in valid_engines for k in instr.keys()):
99
+ instructions.append(instr)
100
+ except:
101
+ continue
102
+
103
+ return instructions if instructions else None
104
+
105
+
106
def build_simple_kernel(batch_size, rounds):
    """
    Build a simple baseline kernel for comparison.

    This is a simplified version that the model should try to beat.
    """
    program = []

    # Preamble: materialize the first 7 scratch slots, then load through them.
    for slot in range(7):
        program.append({"load": [("const", slot, slot)]})
        program.append({"load": [("load", slot, slot)]})

    program.append({"flow": [("pause",)]})

    # Main loop body (simplified; capped at 2 rounds x 4 items for testing).
    iterations = min(rounds, 2) * min(batch_size, 4)
    for _ in range(iterations):
        program.append({"alu": [("+", 10, 5, 0)]})    # addr = inp_indices_p + 0
        program.append({"load": [("load", 11, 10)]})  # idx = mem[addr]
        program.append({"alu": [("+", 12, 6, 0)]})    # addr = inp_values_p + 0
        program.append({"load": [("load", 13, 12)]})  # val = mem[addr]

    program.append({"flow": [("pause",)]})

    return program
132
+
133
+
134
def evaluate_kernel(instructions, seed=42):
    """
    Run a kernel through the VLIW simulator and return its cycle count.
    Lower is better.

    Falls back to BASELINE_CYCLES when the simulator is unavailable, and
    returns a 2x-baseline penalty when the kernel crashes the simulator.
    """
    if not SIMULATOR_AVAILABLE:
        return BASELINE_CYCLES

    try:
        # Deterministic problem instance so scores are comparable across calls.
        random.seed(seed)
        tree = Tree.generate(10)
        batch = Input.generate(tree, 256, 16)
        memory = build_mem_image(tree, batch)

        sim = Machine(
            memory,
            instructions,
            DebugInfo(scratch_map={}),
            n_cores=N_CORES,
            trace=False,
        )
        # Disable pause/debug so the program runs straight through.
        sim.enable_pause = False
        sim.enable_debug = False
        sim.run()
    except Exception:
        # Invalid code: penalize with twice the baseline cycle count.
        return BASELINE_CYCLES * 2

    return sim.cycle
167
+
168
+
169
def vliw_reward_fn(completions, prompts=None, **kwargs):
    """
    Reward function based on VLIW simulator cycle count.
    Higher reward for lower cycle count.

    Parsable kernels (more than 5 instructions) are scored by simulated
    cycles; unparsable completions get a small shaped reward that nudges
    the model toward code-like output.
    """

    def _score_cycles(cycles):
        # At or below target: maximum reward.
        if cycles <= TARGET_CYCLES:
            return 2.0
        # Between target and baseline: linear ramp from 0.5 up to 2.0.
        if cycles < BASELINE_CYCLES:
            frac = (BASELINE_CYCLES - cycles) / (BASELINE_CYCLES - TARGET_CYCLES)
            return 0.5 + 1.5 * frac
        # Worse than baseline: decays toward 0 as cycles grow.
        return 0.5 * (BASELINE_CYCLES / max(cycles, 1))

    def _score_text(text):
        # Could not parse -- give a small reward for code-like output.
        score = 0.1
        if "def " in text or "for " in text:
            score = 0.2
        if any(kw in text for kw in ["alu", "load", "store", "valu"]):
            score = 0.3
        return score

    rewards = []
    for completion in completions:
        # Extract text from chat-style (list of messages) or plain completions.
        if isinstance(completion, list):
            text = completion[0].get("content", "") if completion else ""
        else:
            text = str(completion)

        parsed = parse_kernel_code(text)
        if parsed and len(parsed) > 5:
            rewards.append(_score_cycles(evaluate_kernel(parsed)))
        else:
            rewards.append(_score_text(text))

    return rewards
212
 
213
 
214
# Prompt template for VLIW optimization.
# The per-engine slot counts below mirror SLOT_LIMITS in problem.py (alu=12,
# valu=6, load=2, store=2, flow=1, VLEN=8), and the baseline/target numbers
# match the BASELINE_CYCLES / TARGET_CYCLES constants defined above.
VLIW_PROMPT = """You are an expert in VLIW (Very Long Instruction Word) architecture optimization.

Generate optimized VLIW assembly code for a parallel tree traversal kernel.

The architecture has these engines that execute in parallel each cycle:
- alu: up to 12 scalar ALU operations per cycle
- valu: up to 6 vector ALU operations (VLEN=8 elements)
- load: up to 2 load operations per cycle
- store: up to 2 store operations per cycle
- flow: 1 control flow operation per cycle

Instructions are in Python dict format:
{"alu": [("+", dest, src1, src2), ("*", dest, src1, src2)], "load": [("load", dest, addr)]}

The kernel should:
1. Load indices and values from memory
2. Perform hash computation (6 stages using +, ^, <<, >>)
3. Update tree traversal index based on hash result
4. Store results back to memory

Optimize for minimum cycle count. Current baseline: 147,734 cycles. Target: <1,363 cycles.

Generate the optimized kernel code:"""
238
+
239
+
240
  def run_training(model_name, num_steps):
241
+ """Run RL training with VLIW simulator rewards."""
242
  import torch
243
  from datasets import Dataset
244
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
257
  training_state["is_training"] = True
258
  training_state["should_stop"] = False
259
  training_state["log"] = []
260
+ training_state["best_cycles"] = BASELINE_CYCLES
261
+ training_state["step"] = 0
262
 
263
  try:
264
+ add_log(f"Starting VLIW optimization training")
265
+ add_log(f"Model: {model_name}, Steps: {num_steps}")
266
+ add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
267
 
268
  # Load tokenizer
269
  add_log("Loading tokenizer...")
 
287
  )
288
  add_log(f"✓ Model loaded on {next(model.parameters()).device}")
289
 
290
+ # Create dataset with VLIW prompts
291
+ add_log("Creating VLIW optimization dataset...")
292
+ prompts = [VLIW_PROMPT] * 16
 
 
 
 
 
293
  dataset = Dataset.from_dict({"prompt": prompts})
294
+ add_log(f"✓ Dataset ready: {len(prompts)} prompts")
295
 
296
  # LoRA config
297
  add_log("Setting up LoRA...")
 
304
  task_type="CAUSAL_LM",
305
  )
306
 
307
+ # Custom callback for logging
308
+ class VLIWCallback(TrainerCallback):
309
  def on_step_end(self, args, state, control, **kwargs):
310
  with state_lock:
311
+ training_state["step"] = state.global_step
312
  if training_state["should_stop"]:
313
  control.should_training_stop = True
314
  return control
315
 
316
+ def on_log(self, args, state, control, logs=None, **kwargs):
317
+ if logs:
318
+ loss = logs.get("loss", "N/A")
319
+ reward = logs.get("reward", logs.get("mean_reward", "N/A"))
320
+ step = state.global_step
321
+ add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}" if isinstance(loss, float) else f"Step {step}: {logs}")
322
+
323
  # GRPO config
324
+ add_log("Creating GRPO trainer with VLIW rewards...")
325
  config = GRPOConfig(
326
+ output_dir="./grpo_vliw_output",
327
  num_train_epochs=1,
328
  max_steps=num_steps,
329
+ per_device_train_batch_size=1,
330
+ gradient_accumulation_steps=4,
331
+ learning_rate=1e-5,
332
  logging_steps=1,
333
+ save_steps=999999,
334
  report_to="none",
335
  remove_unused_columns=False,
336
+ max_completion_length=512,
337
  num_generations=4,
338
  )
339
 
 
341
  model=model,
342
  args=config,
343
  train_dataset=dataset,
344
+ reward_funcs=vliw_reward_fn,
345
  peft_config=lora_config,
346
  processing_class=tokenizer,
347
+ callbacks=[VLIWCallback()],
348
  )
349
  add_log("✓ Trainer ready")
350
 
351
  # Train
352
  add_log("Starting training loop...")
353
+ add_log("(Model will learn to generate VLIW code with lower cycle counts)")
354
  train_result = trainer.train()
355
 
356
  metrics = train_result.metrics
357
  add_log(f"✓ Training complete!")
358
+ add_log(f" Total steps: {metrics.get('train_steps', num_steps)}")
 
359
 
360
  # Test generation
361
  add_log("Testing trained model...")
362
+ inputs = tokenizer(VLIW_PROMPT[:200], return_tensors="pt").to(model.device)
 
363
  with torch.no_grad():
364
+ outputs = model.generate(
365
+ **inputs,
366
+ max_new_tokens=256,
367
+ do_sample=True,
368
+ temperature=0.7,
369
+ top_p=0.9,
370
+ )
371
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
372
+
373
+ # Try to evaluate the generated code
374
+ instructions = parse_kernel_code(result)
375
+ if instructions:
376
+ cycles = evaluate_kernel(instructions)
377
+ add_log(f"Generated kernel: {len(instructions)} instructions, {cycles:,} cycles")
378
+ speedup = BASELINE_CYCLES / max(cycles, 1)
379
+ add_log(f"Speedup: {speedup:.2f}x over baseline")
380
+ else:
381
+ add_log(f"Sample output (first 200 chars): {result[len(VLIW_PROMPT[:200]):len(VLIW_PROMPT[:200])+200]}...")
382
 
383
  add_log("\n✓ All done!")
384
 
385
  except Exception as e:
386
  import traceback
387
  add_log(f"✗ Error: {e}")
388
+ add_log(traceback.format_exc()[:800])
389
  finally:
390
  with state_lock:
391
  training_state["is_training"] = False
 
392
  try:
393
  del model
394
  torch.cuda.empty_cache()
 
399
 
400
 
401
  def start_training(model_name, num_steps):
402
+ """Start training."""
403
  with state_lock:
404
  if training_state["is_training"]:
405
  return "Training already in progress. Please wait."
 
416
  return "Stop requested. Training will stop after current step."
417
 
418
 
 
 
 
 
 
 
 
 
419
  # Gradio UI
420
  with gr.Blocks(title="VLIW Optimizer") as demo:
421
  gr.Markdown("# VLIW Kernel Optimizer - RL Training")
422
+ gr.Markdown(f"""
423
+ Train a language model with reinforcement learning to generate optimized VLIW/SIMD kernels.
424
+
425
+ **Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)
426
 
427
+ **How it works:**
428
+ 1. Model generates VLIW assembly code
429
+ 2. Simulator evaluates cycle count
430
+ 3. RL training improves model based on cycle-count rewards
431
  """)
432
 
433
  with gr.Row():
 
435
  status_box = gr.Textbox(
436
  label="System Status",
437
  value=get_status(),
438
+ lines=12,
439
  interactive=False,
440
  )
441
 
 
449
  label="Model",
450
  )
451
  steps_slider = gr.Slider(
452
+ minimum=5,
453
  maximum=100,
454
+ value=20,
455
+ step=5,
456
  label="Training Steps",
457
  )
458
 
 
462
 
463
  output_box = gr.Textbox(
464
  label="Training Log",
465
+ lines=25,
466
  interactive=False,
467
+ value="Click 'Start Training' to begin VLIW optimization.",
468
  )
469
 
470
  start_btn.click(
problem.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Read the top of perf_takehome.py for more introduction.
3
+
4
+ This file is separate mostly for ease of copying it to freeze the machine and
5
+ reference kernel for testing.
6
+ """
7
+
8
+ from copy import copy
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from typing import Any, Literal
12
+ import random
13
+
14
+ Engine = Literal["alu", "load", "store", "flow"]
15
+ Instruction = dict[Engine, list[tuple]]
16
+
17
+
18
class CoreState(Enum):
    """Lifecycle state of a simulated core."""

    RUNNING = 1
    # Set by a ("pause",) flow op when enable_pause is on; run() resumes it.
    PAUSED = 2
    # Set by ("halt",) or by the pc running past the end of the program.
    STOPPED = 3
22
+
23
+
24
@dataclass
class Core:
    """Per-core execution state: scratch space, trace buffer, program counter."""

    # Core index; exposed to programs via the ("coreid", dest) flow op.
    id: int
    # Scratch words (serves as registers / constants / manually-managed cache).
    scratch: list[int]
    # Values appended by ("trace_write", val) flow ops.
    trace_buf: list[int]
    pc: int = 0
    state: CoreState = CoreState.RUNNING
31
+
32
+
33
@dataclass
class DebugInfo:
    """
    We give you some debug info but it's up to you to use it in Machine if you
    want to. You're also welcome to add more.
    """

    # Maps scratch variable addr to (name, len) pair.
    # Fixed: ``dict[int, (str, int)]`` put a bare tuple of types inside the
    # subscript, which is not a valid type expression (type checkers reject
    # it); the correct spelling is ``tuple[str, int]``. Runtime-identical.
    scratch_map: dict[int, tuple[str, int]]
42
+
43
+
44
def cdiv(a, b):
    """Ceiling division: smallest integer >= a/b for non-negative a, positive b.

    Kept as the (a + b - 1) // b form deliberately: equivalent "ceiling"
    rewrites (e.g. -(-a // b)) diverge for negative operands.
    """
    return (a + b - 1) // b
46
+
47
+
48
# Maximum number of slots each engine may issue in a single instruction bundle.
SLOT_LIMITS = {
    "alu": 12,
    "valu": 6,
    "load": 2,
    "store": 2,
    "flow": 1,
    "debug": 64,  # debug-only bundles don't advance the cycle counter (see Machine.run)
}

# Number of 32-bit elements per vector operation.
VLEN = 8
# Older versions of the take-home used multiple cores, but this version only uses 1
N_CORES = 1
# Scratch words available per core.
SCRATCH_SIZE = 1536
# Base trace-event thread id for scratch-variable rows (see Machine.setup_trace).
BASE_ADDR_TID = 100000
62
+
63
+
64
class Machine:
    """
    Simulator for a custom VLIW SIMD architecture.

    VLIW (Very Large Instruction Word): Cores are composed of different
    "engines" each of which can execute multiple "slots" per cycle in parallel.
    How many slots each engine can execute per cycle is limited by SLOT_LIMITS.
    Effects of instructions don't take effect until the end of cycle. Each
    cycle, all engines execute all of their filled slots for that instruction.
    Effects like writes to memory take place after all the inputs are read.

    SIMD: There are instructions for acting on vectors of VLEN elements in a
    single slot. You can use vload and vstore to load multiple contiguous
    elements but not non-contiguous elements. Use vbroadcast to broadcast a
    scalar to a vector and then operate on vectors with valu instructions.

    The memory and scratch space are composed of 32-bit words. The solution is
    plucked out of the memory at the end of the program. You can think of the
    scratch space as serving the purpose of registers, constant memory, and a
    manually-managed cache.

    Here's an example of what an instruction might look like:

    {"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}

    In general every number in an instruction is a scratch address except for
    const and jump, and except for store and some flow instructions the first
    operand is the destination.

    This comment is not meant to be full ISA documentation though, for the rest
    you should look through the simulator code.
    """

    def __init__(
        self,
        mem_dump: list[int],
        program: list[Instruction],
        debug_info: DebugInfo,
        n_cores: int = 1,
        scratch_size: int = SCRATCH_SIZE,
        trace: bool = False,
        value_trace: dict[Any, int] = {},
    ):
        # NOTE(review): value_trace has a mutable default argument. It is only
        # read here (see the debug "compare"/"vcompare" slots in step()), so no
        # cross-instance sharing bug is visible, but a None sentinel would be
        # the safer idiom.
        self.cores = [
            Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
        ]
        # Copy so the simulation does not mutate the caller's memory image.
        self.mem = copy(mem_dump)
        self.program = program
        self.debug_info = debug_info
        self.value_trace = value_trace
        self.prints = False
        self.cycle = 0
        self.enable_pause = True
        self.enable_debug = True
        if trace:
            self.setup_trace()
        else:
            self.trace = None

    def rewrite_instr(self, instr):
        """
        Rewrite an instruction to use scratch addresses instead of names
        """
        res = {}
        for name, slots in instr.items():
            res[name] = []
            for slot in slots:
                res[name].append(self.rewrite_slot(slot))
        return res

    def print_step(self, instr, core):
        """Debug print: named scratch contents plus the instruction about to run."""
        # print(core.id)
        # print(core.trace_buf)
        print(self.scratch_map(core))
        print(core.pc, instr, self.rewrite_instr(instr))

    def scratch_map(self, core):
        """Resolve each named scratch region (from debug_info) to its current values."""
        res = {}
        for addr, (name, length) in self.debug_info.scratch_map.items():
            res[name] = core.scratch[addr : addr + length]
        return res

    def rewrite_slot(self, slot):
        # Substitute debug names for scratch addresses where known; other
        # operands pass through unchanged.
        return tuple(
            self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
        )

    def setup_trace(self):
        """
        The simulator generates traces in Chrome's Trace Event Format for
        visualization in Perfetto (or chrome://tracing if you prefer it). See
        the bottom of the file for info about how to use this.

        See the format docs in case you want to add more info to the trace:
        https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
        """
        self.trace = open("trace.json", "w")
        self.trace.write("[")
        tid_counter = 0
        self.tids = {}
        for ci, core in enumerate(self.cores):
            self.trace.write(
                f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
            )
            for name, limit in SLOT_LIMITS.items():
                if name == "debug":
                    continue
                for i in range(limit):
                    tid_counter += 1
                    self.trace.write(
                        f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
                    )
                    self.tids[(ci, name, i)] = tid_counter

        # Add zero-length events at the start so all slots show up in Perfetto
        for ci, core in enumerate(self.cores):
            for name, limit in SLOT_LIMITS.items():
                if name == "debug":
                    continue
                for i in range(limit):
                    tid = self.tids[(ci, name, i)]
                    self.trace.write(
                        f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
                    )
        for ci, core in enumerate(self.cores):
            self.trace.write(
                f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
            )
            for addr, (name, length) in self.debug_info.scratch_map.items():
                self.trace.write(
                    f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
                )

    def run(self):
        """Resume any paused cores, then execute until no core is RUNNING."""
        for core in self.cores:
            if core.state == CoreState.PAUSED:
                core.state = CoreState.RUNNING
        while any(c.state == CoreState.RUNNING for c in self.cores):
            has_non_debug = False
            for core in self.cores:
                if core.state != CoreState.RUNNING:
                    continue
                # Falling off the end of the program stops the core.
                if core.pc >= len(self.program):
                    core.state = CoreState.STOPPED
                    continue
                instr = self.program[core.pc]
                if self.prints:
                    self.print_step(instr, core)
                # pc is advanced before step() so jump ops can overwrite it.
                core.pc += 1
                self.step(instr, core)
                if any(name != "debug" for name in instr.keys()):
                    has_non_debug = True
            # Debug-only bundles are free: they don't advance the cycle count.
            if has_non_debug:
                self.cycle += 1

    def alu(self, core, op, dest, a1, a2):
        """Scalar ALU slot: scratch[dest] = scratch[a1] <op> scratch[a2], mod 2**32."""
        a1 = core.scratch[a1]
        a2 = core.scratch[a2]
        match op:
            case "+":
                res = a1 + a2
            case "-":
                res = a1 - a2
            case "*":
                res = a1 * a2
            case "//":
                res = a1 // a2
            case "cdiv":
                res = cdiv(a1, a2)
            case "^":
                res = a1 ^ a2
            case "&":
                res = a1 & a2
            case "|":
                res = a1 | a2
            case "<<":
                res = a1 << a2
            case ">>":
                res = a1 >> a2
            case "%":
                res = a1 % a2
            case "<":
                res = int(a1 < a2)
            case "==":
                res = int(a1 == a2)
            case _:
                raise NotImplementedError(f"Unknown alu op {op}")
        # All results wrap to 32-bit words.
        res = res % (2**32)
        self.scratch_write[dest] = res

    def valu(self, core, *slot):
        """Vector ALU slot: operates elementwise on VLEN contiguous scratch words."""
        match slot:
            case ("vbroadcast", dest, src):
                for i in range(VLEN):
                    self.scratch_write[dest + i] = core.scratch[src]
            case ("multiply_add", dest, a, b, c):
                for i in range(VLEN):
                    mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
                    self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
            case (op, dest, a1, a2):
                # Generic elementwise op: reuse the scalar ALU per lane.
                for i in range(VLEN):
                    self.alu(core, op, dest + i, a1 + i, a2 + i)
            case _:
                raise NotImplementedError(f"Unknown valu op {slot}")

    def load(self, core, *slot):
        """Load slot: memory -> scratch transfers, plus immediate constants."""
        match slot:
            case ("load", dest, addr):
                # print(dest, addr, core.scratch[addr])
                self.scratch_write[dest] = self.mem[core.scratch[addr]]
            case ("load_offset", dest, addr, offset):
                # Handy for treating vector dest and addr as a full block in the mini-compiler if you want
                self.scratch_write[dest + offset] = self.mem[
                    core.scratch[addr + offset]
                ]
            case ("vload", dest, addr):  # addr is a scalar
                addr = core.scratch[addr]
                for vi in range(VLEN):
                    self.scratch_write[dest + vi] = self.mem[addr + vi]
            case ("const", dest, val):
                self.scratch_write[dest] = (val) % (2**32)
            case _:
                raise NotImplementedError(f"Unknown load op {slot}")

    def store(self, core, *slot):
        """Store slot: scratch -> memory transfers (scalar or VLEN vector)."""
        match slot:
            case ("store", addr, src):
                addr = core.scratch[addr]
                self.mem_write[addr] = core.scratch[src]
            case ("vstore", addr, src):  # addr is a scalar
                addr = core.scratch[addr]
                for vi in range(VLEN):
                    self.mem_write[addr + vi] = core.scratch[src + vi]
            case _:
                raise NotImplementedError(f"Unknown store op {slot}")

    def flow(self, core, *slot):
        """Control-flow slot: selects, jumps, halt/pause, and trace writes."""
        match slot:
            case ("select", dest, cond, a, b):
                self.scratch_write[dest] = (
                    core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
                )
            case ("add_imm", dest, a, imm):
                self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
            case ("vselect", dest, cond, a, b):
                for vi in range(VLEN):
                    self.scratch_write[dest + vi] = (
                        core.scratch[a + vi]
                        if core.scratch[cond + vi] != 0
                        else core.scratch[b + vi]
                    )
            case ("halt",):
                core.state = CoreState.STOPPED
            case ("pause",):
                # No-op when enable_pause is off (run() also restarts paused cores).
                if self.enable_pause:
                    core.state = CoreState.PAUSED
            case ("trace_write", val):
                core.trace_buf.append(core.scratch[val])
            case ("cond_jump", cond, addr):
                if core.scratch[cond] != 0:
                    core.pc = addr
            case ("cond_jump_rel", cond, offset):
                if core.scratch[cond] != 0:
                    core.pc += offset
            case ("jump", addr):
                core.pc = addr
            case ("jump_indirect", addr):
                core.pc = core.scratch[addr]
            case ("coreid", dest):
                self.scratch_write[dest] = core.id
            case _:
                raise NotImplementedError(f"Unknown flow op {slot}")

    def trace_post_step(self, instr, core):
        # You can add extra stuff to the trace if you want!
        # Emits one trace row per named scratch region that was written this cycle.
        for addr, (name, length) in self.debug_info.scratch_map.items():
            if any((addr + vi) in self.scratch_write for vi in range(length)):
                val = str(core.scratch[addr : addr + length])
                val = val.replace("[", "").replace("]", "")
                self.trace.write(
                    f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
                )

    def trace_slot(self, core, slot, name, i):
        """Emit a one-cycle trace event for a single executed slot."""
        self.trace.write(
            f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
        )

    def step(self, instr: Instruction, core):
        """
        Execute all the slots in each engine for a single instruction bundle
        """
        ENGINE_FNS = {
            "alu": self.alu,
            "valu": self.valu,
            "load": self.load,
            "store": self.store,
            "flow": self.flow,
        }
        # Writes are buffered here and applied after every slot has read its
        # inputs, giving the end-of-cycle semantics described in the class doc.
        self.scratch_write = {}
        self.mem_write = {}
        for name, slots in instr.items():
            if name == "debug":
                if not self.enable_debug:
                    continue
                for slot in slots:
                    if slot[0] == "compare":
                        loc, key = slot[1], slot[2]
                        ref = self.value_trace[key]
                        res = core.scratch[loc]
                        assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
                    elif slot[0] == "vcompare":
                        loc, keys = slot[1], slot[2]
                        ref = [self.value_trace[key] for key in keys]
                        res = core.scratch[loc : loc + VLEN]
                        assert res == ref, (
                            f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
                        )
                continue
            assert len(slots) <= SLOT_LIMITS[name]
            for i, slot in enumerate(slots):
                if self.trace is not None:
                    self.trace_slot(core, slot, name, i)
                ENGINE_FNS[name](core, *slot)
        # Commit buffered writes (scratch first, then memory).
        for addr, val in self.scratch_write.items():
            core.scratch[addr] = val
        for addr, val in self.mem_write.items():
            self.mem[addr] = val

        if self.trace:
            self.trace_post_step(instr, core)

        del self.scratch_write
        del self.mem_write

    def __del__(self):
        # Close the JSON array opened in setup_trace so the trace file is valid.
        if self.trace is not None:
            self.trace.write("]")
            self.trace.close()
403
+
404
+
405
@dataclass
class Tree:
    """
    An implicit perfect balanced binary tree with values on the nodes.
    """

    # Tree of `height` levels below the root; node i's children are 2i+1, 2i+2.
    height: int
    values: list[int]

    @staticmethod
    def generate(height: int):
        # A perfect binary tree of this height has 2**(height+1) - 1 nodes,
        # each carrying a random 30-bit value.
        return Tree(
            height,
            [random.randint(0, 2**30 - 1) for _ in range(2 ** (height + 1) - 1)],
        )
419
+
420
+
421
@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial input
    values. We then iterate these for a specified number of rounds.
    """

    indices: list[int]
    values: list[int]
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int):
        # Every lane starts at the root (index 0) with a random 30-bit value.
        start_indices = [0] * batch_size
        start_values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
        return Input(start_indices, start_values, rounds)
437
+
438
+
439
# Each stage is (op1, const1, op2, op3, shift): the stage computes
# op2(op1(a, const1), op3(a, shift)), with every step reduced mod 2**32.
# NOTE(review): these constants appear to match a well-known 6-stage 32-bit
# integer mix (Jenkins-style) — verify against the intended reference.
HASH_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),
    ("^", 0xC761C23C, "^", ">>", 19),
    ("+", 0x165667B1, "+", "<<", 5),
    ("+", 0xD3A2646C, "^", "<<", 9),
    ("+", 0xFD7046C5, "+", "<<", 3),
    ("^", 0xB55A4F09, "^", ">>", 16),
]


def myhash(a: int) -> int:
    """A simple 32-bit hash function"""
    mask = 0xFFFFFFFF

    def op(kind, x, y):
        # Apply one primitive op and reduce to 32 bits (& mask == % 2**32).
        if kind == "+":
            return (x + y) & mask
        if kind == "^":
            return (x ^ y) & mask
        if kind == "<<":
            return (x << y) & mask
        return (x >> y) & mask

    for op1, c1, op2, op3, c3 in HASH_STAGES:
        a = op(op2, op(op1, a, c1), op(op3, a, c3))
    return a
465
+
466
+
467
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
        cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.
    Mutates inp.values and inp.indices in place.
    """
    n_nodes = len(t.values)
    for _ in range(inp.rounds):
        for lane, (idx, val) in enumerate(zip(inp.indices, inp.values)):
            new_val = myhash(val ^ t.values[idx])
            # Even hash -> left child (2i+1), odd -> right child (2i+2).
            child = 2 * idx + (1 if new_val % 2 == 0 else 2)
            # Past the leaves: wrap back to the root.
            if child >= n_nodes:
                child = 0
            inp.values[lane] = new_val
            inp.indices[lane] = child
485
+
486
+
487
def build_mem_image(t: Tree, inp: Input) -> list[int]:
    """
    Build a flat memory image of the problem.

    Layout: [header | tree values | input indices | input values | scratch].
    Header words 4-6 hold the base offsets of the three arrays so kernels
    can locate them.
    """
    header = 7
    # Scratch space past the live data, available for kernels to use freely.
    extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
    mem = [0] * (
        header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
    )
    forest_values_p = header
    inp_indices_p = forest_values_p + len(t.values)
    inp_values_p = inp_indices_p + len(inp.values)
    extra_room = inp_values_p + len(inp.values)

    mem[0] = inp.rounds
    mem[1] = len(t.values)
    mem[2] = len(inp.indices)
    mem[3] = t.height
    mem[4] = forest_values_p
    mem[5] = inp_indices_p
    mem[6] = inp_values_p
    # NOTE(review): header is 7, so mem[7] is the first tree-value slot and
    # this write is immediately overwritten by t.values below — confirm
    # whether the header was meant to be 8 words.
    mem[7] = extra_room

    mem[header:inp_indices_p] = t.values
    mem[inp_indices_p:inp_values_p] = inp.indices
    # Bound the slice so the scratch room allocated above is preserved; an
    # open-ended `mem[inp_values_p:] = inp.values` would truncate the list
    # and discard the extra_room words.
    mem[inp_values_p : inp_values_p + len(inp.values)] = inp.values
    return mem
514
+
515
+
516
def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
    """A simple 32-bit hash function"""
    mod = 2**32

    ops = {
        "+": lambda x, y: x + y,
        "^": lambda x, y: x ^ y,
        "<<": lambda x, y: x << y,
        ">>": lambda x, y: x >> y,
    }

    for stage, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
        left = ops[op1](a, val1) % mod
        right = ops[op3](a, val3) % mod
        a = ops[op2](left, right) % mod
        # Record each stage's output so debug `compare` slots can check it.
        trace[(round, batch_i, "hash_stage", stage)] = a

    return a
533
+
534
+
535
def reference_kernel2(mem: list[int], trace: dict[Any, int] | None = None):
    """
    Reference implementation of the kernel on a flat memory.

    Mutates `mem` in place; yields it once before the first round and once
    after all rounds. Intermediate values are recorded into `trace` keyed by
    (round, lane, tag) for the debug compare machinery.
    """
    # Fresh dict per call: a mutable default argument ({}) is created once at
    # def time and shared across calls, so traces from unrelated runs would
    # accumulate into the same dict.
    if trace is None:
        trace = {}
    # This is the initial memory layout
    rounds = mem[0]
    n_nodes = mem[1]
    batch_size = mem[2]
    forest_height = mem[3]  # unread here; kept to document the header layout
    # Offsets into the memory which indices get added to
    forest_values_p = mem[4]
    inp_indices_p = mem[5]
    inp_values_p = mem[6]
    yield mem
    for h in range(rounds):
        for i in range(batch_size):
            idx = mem[inp_indices_p + i]
            trace[(h, i, "idx")] = idx
            val = mem[inp_values_p + i]
            trace[(h, i, "val")] = val
            node_val = mem[forest_values_p + idx]
            trace[(h, i, "node_val")] = node_val
            val = myhash_traced(val ^ node_val, trace, h, i)
            trace[(h, i, "hashed_val")] = val
            # Even hash -> left child (2i+1), odd -> right child (2i+2).
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            trace[(h, i, "next_idx")] = idx
            # Past the leaves: wrap back to the root.
            idx = 0 if idx >= n_nodes else idx
            trace[(h, i, "wrapped_idx")] = idx
            mem[inp_values_p + i] = val
            mem[inp_indices_p + i] = idx
    # You can add new yields or move this around for debugging
    # as long as it's matched by pause instructions.
    # The submission tests evaluate only on final memory.
    yield mem