CreativeEngineer committed on
Commit
648e193
·
1 Parent(s): b3b926b

Switch to correctness-gated GRPO LoRA with persistence

Browse files
Files changed (1) hide show
  1. app.py +403 -197
app.py CHANGED
@@ -1,7 +1,9 @@
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
- Uses actual simulator for cycle-count based rewards.
4
  """
 
 
5
  import gradio as gr
6
  import threading
7
  import time
@@ -42,13 +44,21 @@ try:
42
  except Exception as e:
43
  startup_log.append(f"✗ CUDA check: {e}")
44
 
 
 
 
 
 
 
 
45
  # Import simulator components
46
  try:
47
  from problem import (
48
  Machine, Tree, Input, DebugInfo,
49
  build_mem_image, reference_kernel2,
50
- SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, HASH_STAGES
51
  )
 
52
  startup_log.append("✓ VLIW Simulator: OK")
53
  SIMULATOR_AVAILABLE = True
54
  except Exception as e:
@@ -58,6 +68,9 @@ except Exception as e:
58
  # Constants
59
  BASELINE_CYCLES = 147734
60
  TARGET_CYCLES = 1363
 
 
 
61
 
62
  # Training state
63
  training_state = {
@@ -65,184 +78,291 @@ training_state = {
65
  "should_stop": False,
66
  "log": [],
67
  "best_cycles": BASELINE_CYCLES,
 
68
  "step": 0,
69
  }
70
  state_lock = threading.Lock()
71
 
 
 
72
 
73
  def get_status():
74
  return "\n".join(startup_log)
75
 
76
 
77
- def parse_kernel_code(code_text):
78
- """
79
- Parse LLM-generated kernel code into simulator instructions.
80
- Returns list of instruction dicts or None if parsing fails.
81
- """
82
- instructions = []
83
-
84
- # Try to find instruction patterns in the code
85
- # Format: {"engine": [("op", arg1, arg2, ...)]}
86
-
87
- # Look for dict-like instruction patterns
88
- pattern = r'\{[^}]+\}'
89
- matches = re.findall(pattern, code_text)
90
-
91
- for match in matches:
92
- try:
93
- # Try to eval as Python dict (safely)
94
- instr = eval(match, {"__builtins__": {}})
95
- if isinstance(instr, dict):
96
- # Validate it's a valid instruction
97
- valid_engines = {"alu", "valu", "load", "store", "flow", "debug"}
98
- if any(k in valid_engines for k in instr.keys()):
99
- instructions.append(instr)
100
- except:
101
- continue
102
-
103
- return instructions if instructions else None
104
-
105
-
106
- def build_simple_kernel(batch_size, rounds):
107
- """
108
- Build a simple baseline kernel for comparison.
109
- This is a simplified version that the model should try to beat.
110
- """
111
- instructions = []
112
-
113
- # Initialize scratch space addresses
114
- for i in range(7):
115
- instructions.append({"load": [("const", i, i)]})
116
- instructions.append({"load": [("load", i, i)]})
117
-
118
- instructions.append({"flow": [("pause",)]})
119
-
120
- # Main loop body (simplified)
121
- for r in range(min(rounds, 2)): # Limit for testing
122
- for i in range(min(batch_size, 4)): # Limit for testing
123
- # Load index and value
124
- instructions.append({"alu": [("+", 10, 5, 0)]}) # addr = inp_indices_p + 0
125
- instructions.append({"load": [("load", 11, 10)]}) # idx = mem[addr]
126
- instructions.append({"alu": [("+", 12, 6, 0)]}) # addr = inp_values_p + 0
127
- instructions.append({"load": [("load", 13, 12)]}) # val = mem[addr]
128
-
129
- instructions.append({"flow": [("pause",)]})
130
-
131
- return instructions
 
 
 
 
 
 
 
 
 
132
 
133
 
134
- def evaluate_kernel(instructions, seed=42):
135
- """
136
- Run kernel through simulator and return cycle count.
137
- Lower is better.
138
- """
139
  if not SIMULATOR_AVAILABLE:
140
- return BASELINE_CYCLES
 
 
 
 
 
141
 
142
  try:
143
- random.seed(seed)
144
- forest = Tree.generate(10)
145
- inp = Input.generate(forest, 256, 16)
146
- mem = build_mem_image(forest, inp)
147
-
148
- debug_info = DebugInfo(scratch_map={})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  machine = Machine(
151
- mem,
152
- instructions,
153
- debug_info,
154
  n_cores=N_CORES,
155
  trace=False,
156
  )
157
  machine.enable_pause = False
158
  machine.enable_debug = False
159
 
160
- # Run the machine
161
- machine.run()
162
-
163
- return machine.cycle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  except Exception as e:
165
- # Return high cycle count for invalid code
166
- return BASELINE_CYCLES * 2
 
 
 
 
167
 
168
 
169
- def vliw_reward_fn(completions, prompts=None, **kwargs):
170
- """
171
- Reward function based on VLIW simulator cycle count.
172
- Higher reward for lower cycle count.
173
- """
174
  rewards = []
175
-
176
  for completion in completions:
177
- # Extract text from completion
178
  if isinstance(completion, list):
179
  text = completion[0].get("content", "") if completion else ""
180
  else:
181
  text = str(completion)
182
 
183
- # Try to parse as kernel instructions
184
- instructions = parse_kernel_code(text)
185
-
186
- if instructions and len(instructions) > 5:
187
- # Evaluate with simulator
188
- cycles = evaluate_kernel(instructions)
189
-
190
- # Reward: normalized improvement over baseline
191
- # Max reward when cycles <= TARGET_CYCLES
192
- if cycles <= TARGET_CYCLES:
193
- reward = 2.0 # Maximum reward
194
- elif cycles < BASELINE_CYCLES:
195
- # Linear scale between baseline and target
196
- improvement = (BASELINE_CYCLES - cycles) / (BASELINE_CYCLES - TARGET_CYCLES)
197
- reward = 0.5 + 1.5 * improvement
198
- else:
199
- # Below baseline performance
200
- reward = 0.5 * (BASELINE_CYCLES / max(cycles, 1))
201
- else:
202
- # Could not parse - give small reward for code-like output
203
- reward = 0.1
204
- if "def " in text or "for " in text:
205
- reward = 0.2
206
- if any(kw in text for kw in ["alu", "load", "store", "valu"]):
207
- reward = 0.3
208
-
209
- rewards.append(reward)
210
-
211
  return rewards
212
 
213
 
214
  # Prompt template for VLIW optimization
215
- VLIW_PROMPT = """You are an expert in VLIW (Very Long Instruction Word) architecture optimization.
216
-
217
- Generate optimized VLIW assembly code for a parallel tree traversal kernel.
218
-
219
- The architecture has these engines that execute in parallel each cycle:
220
- - alu: up to 12 scalar ALU operations per cycle
221
- - valu: up to 6 vector ALU operations (VLEN=8 elements)
222
- - load: up to 2 load operations per cycle
223
- - store: up to 2 store operations per cycle
224
- - flow: 1 control flow operation per cycle
225
-
226
- Instructions are in Python dict format:
227
- {"alu": [("+", dest, src1, src2), ("*", dest, src1, src2)], "load": [("load", dest, addr)]}
228
-
229
- The kernel should:
230
- 1. Load indices and values from memory
231
- 2. Perform hash computation (6 stages using +, ^, <<, >>)
232
- 3. Update tree traversal index based on hash result
233
- 4. Store results back to memory
234
-
235
- Optimize for minimum cycle count. Current baseline: 147,734 cycles. Target: <1,363 cycles.
236
-
237
- Generate the optimized kernel code:"""
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
 
240
- def run_training(model_name, num_steps):
241
- """Run RL training with VLIW simulator rewards."""
242
  import torch
243
  from datasets import Dataset
244
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
245
  from peft import LoraConfig
 
246
  from trl import GRPOConfig, GRPOTrainer
247
  from transformers import TrainerCallback
248
 
@@ -258,12 +378,16 @@ def run_training(model_name, num_steps):
258
  training_state["should_stop"] = False
259
  training_state["log"] = []
260
  training_state["best_cycles"] = BASELINE_CYCLES
 
261
  training_state["step"] = 0
262
 
263
  try:
264
  add_log(f"Starting VLIW optimization training")
265
- add_log(f"Model: {model_name}, Steps: {num_steps}")
 
 
266
  add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
 
267
 
268
  # Load tokenizer
269
  add_log("Loading tokenizer...")
@@ -279,17 +403,25 @@ def run_training(model_name, num_steps):
279
  bnb_4bit_quant_type="nf4",
280
  bnb_4bit_compute_dtype=torch.bfloat16,
281
  )
282
- model = AutoModelForCausalLM.from_pretrained(
283
  model_name,
284
  quantization_config=bnb_config,
285
  device_map="auto",
286
  trust_remote_code=True,
287
  )
288
- add_log(f"✓ Model loaded on {next(model.parameters()).device}")
289
 
290
- # Create dataset with VLIW prompts
 
 
 
 
 
 
 
 
291
  add_log("Creating VLIW optimization dataset...")
292
- prompts = [VLIW_PROMPT] * 16
293
  dataset = Dataset.from_dict({"prompt": prompts})
294
  add_log(f"✓ Dataset ready: {len(prompts)} prompts")
295
 
@@ -304,62 +436,102 @@ def run_training(model_name, num_steps):
304
  task_type="CAUSAL_LM",
305
  )
306
 
307
- # Custom callback for logging
 
 
 
 
 
308
  class VLIWCallback(TrainerCallback):
309
  def on_step_end(self, args, state, control, **kwargs):
310
  with state_lock:
311
- training_state["step"] = state.global_step
 
312
  if training_state["should_stop"]:
313
  control.should_training_stop = True
 
 
314
  return control
315
 
316
  def on_log(self, args, state, control, logs=None, **kwargs):
317
  if logs:
318
  loss = logs.get("loss", "N/A")
319
  reward = logs.get("reward", logs.get("mean_reward", "N/A"))
320
- step = state.global_step
321
  add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}" if isinstance(loss, float) else f"Step {step}: {logs}")
322
 
323
- # GRPO config
324
- add_log("Creating GRPO trainer with VLIW rewards...")
325
- config = GRPOConfig(
326
- output_dir="./grpo_vliw_output",
327
- num_train_epochs=1,
328
- max_steps=num_steps,
329
- per_device_train_batch_size=1,
330
- gradient_accumulation_steps=4,
331
- learning_rate=1e-5,
332
- logging_steps=1,
333
- save_steps=999999,
334
- report_to="none",
335
- remove_unused_columns=False,
336
- max_completion_length=512,
337
- num_generations=4,
338
- )
339
-
340
- trainer = GRPOTrainer(
341
- model=model,
342
- args=config,
343
- train_dataset=dataset,
344
- reward_funcs=vliw_reward_fn,
345
- peft_config=lora_config,
346
- processing_class=tokenizer,
347
- callbacks=[VLIWCallback()],
348
- )
349
- add_log("✓ Trainer ready")
350
 
351
- # Train
352
  add_log("Starting training loop...")
353
- add_log("(Model will learn to generate VLIW code with lower cycle counts)")
354
- train_result = trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
- metrics = train_result.metrics
357
- add_log(f"✓ Training complete!")
358
- add_log(f" Total steps: {metrics.get('train_steps', num_steps)}")
 
 
 
 
 
 
 
359
 
360
  # Test generation
361
  add_log("Testing trained model...")
362
- inputs = tokenizer(VLIW_PROMPT[:200], return_tensors="pt").to(model.device)
363
  with torch.no_grad():
364
  outputs = model.generate(
365
  **inputs,
@@ -370,15 +542,15 @@ def run_training(model_name, num_steps):
370
  )
371
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
372
 
373
- # Try to evaluate the generated code
374
- instructions = parse_kernel_code(result)
375
- if instructions:
376
- cycles = evaluate_kernel(instructions)
377
- add_log(f"Generated kernel: {len(instructions)} instructions, {cycles:,} cycles")
378
- speedup = BASELINE_CYCLES / max(cycles, 1)
379
  add_log(f"Speedup: {speedup:.2f}x over baseline")
380
  else:
381
- add_log(f"Sample output (first 200 chars): {result[len(VLIW_PROMPT[:200]):len(VLIW_PROMPT[:200])+200]}...")
382
 
383
  add_log("\n✓ All done!")
384
 
@@ -398,13 +570,25 @@ def run_training(model_name, num_steps):
398
  return "\n".join(log)
399
 
400
 
401
- def start_training(model_name, num_steps):
402
  """Start training."""
403
  with state_lock:
404
  if training_state["is_training"]:
405
- return "Training already in progress. Please wait."
406
 
407
- return run_training(model_name, int(num_steps))
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
 
410
  def stop_training():
@@ -420,14 +604,14 @@ def stop_training():
420
  with gr.Blocks(title="VLIW Optimizer") as demo:
421
  gr.Markdown("# VLIW Kernel Optimizer - RL Training")
422
  gr.Markdown(f"""
423
- Train a language model with reinforcement learning to generate optimized VLIW/SIMD kernels.
424
 
425
  **Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)
426
 
427
  **How it works:**
428
- 1. Model generates VLIW assembly code
429
- 2. Simulator evaluates cycle count
430
- 3. RL training improves model based on cycle-count rewards
431
  """)
432
 
433
  with gr.Row():
@@ -448,12 +632,28 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
448
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
449
  label="Model",
450
  )
451
- steps_slider = gr.Slider(
452
  minimum=5,
453
  maximum=100,
454
  value=20,
455
  step=5,
456
- label="Training Steps",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  )
458
 
459
  with gr.Row():
@@ -467,12 +667,18 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
467
  value="Click 'Start Training' to begin VLIW optimization.",
468
  )
469
 
 
 
 
 
470
  start_btn.click(
471
  start_training,
472
- [model_dropdown, steps_slider],
473
  [output_box],
474
  )
475
  stop_btn.click(stop_training, [], [output_box])
476
 
 
 
477
  if __name__ == "__main__":
478
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
+ Uses actual simulator for correctness-gated cycle-count rewards.
4
  """
5
+ import os
6
+ import sys
7
  import gradio as gr
8
  import threading
9
  import time
 
44
  except Exception as e:
45
  startup_log.append(f"✗ CUDA check: {e}")
46
 
47
+ # Prefer simulator + KernelBuilder from bundled original_performance_takehome.
48
+ # In Spaces, this keeps evaluation consistent and enables correctness checks.
49
+ THIS_DIR = os.path.dirname(os.path.abspath(__file__))
50
+ PERF_TAKEHOME_PATH = os.path.join(THIS_DIR, "original_performance_takehome")
51
+ if os.path.isdir(PERF_TAKEHOME_PATH):
52
+ sys.path.insert(0, PERF_TAKEHOME_PATH)
53
+
54
  # Import simulator components
55
  try:
56
  from problem import (
57
  Machine, Tree, Input, DebugInfo,
58
  build_mem_image, reference_kernel2,
59
+ SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, CoreState
60
  )
61
+ from perf_takehome import KernelBuilder, HASH_STAGES
62
  startup_log.append("✓ VLIW Simulator: OK")
63
  SIMULATOR_AVAILABLE = True
64
  except Exception as e:
 
68
  # Constants
69
  BASELINE_CYCLES = 147734
70
  TARGET_CYCLES = 1363
71
+ SCORE_SCALE = 3000.0
72
+ PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
73
+ ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
74
 
75
  # Training state
76
  training_state = {
 
78
  "should_stop": False,
79
  "log": [],
80
  "best_cycles": BASELINE_CYCLES,
81
+ "best_code": None,
82
  "step": 0,
83
  }
84
  state_lock = threading.Lock()
85
 
86
+ _eval_context = {}
87
+
88
 
89
  def get_status():
90
  return "\n".join(startup_log)
91
 
92
 
93
+ def extract_code_block(text: str) -> str:
94
+ pattern = r"```python\s*(.*?)```"
95
+ matches = re.findall(pattern, text, re.DOTALL)
96
+ if matches:
97
+ return matches[-1].strip()
98
+ pattern = r"```\s*(.*?)```"
99
+ matches = re.findall(pattern, text, re.DOTALL)
100
+ if matches:
101
+ return matches[-1].strip()
102
+ return text.strip()
103
+
104
+
105
+ def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
106
+ for core in machine.cores:
107
+ if core.state == CoreState.PAUSED:
108
+ core.state = CoreState.RUNNING
109
+ while any(c.state == CoreState.RUNNING for c in machine.cores):
110
+ has_non_debug = False
111
+ for core in machine.cores:
112
+ if core.state != CoreState.RUNNING:
113
+ continue
114
+ if core.pc >= len(machine.program):
115
+ core.state = CoreState.STOPPED
116
+ continue
117
+ instr = machine.program[core.pc]
118
+ core.pc += 1
119
+ machine.step(instr, core)
120
+ if any(name != "debug" for name in instr.keys()):
121
+ has_non_debug = True
122
+ if has_non_debug:
123
+ machine.cycle += 1
124
+ if machine.cycle >= max_cycles:
125
+ for core in machine.cores:
126
+ core.state = CoreState.STOPPED
127
+ return False
128
+ return True
129
+
130
+
131
+ def _get_eval_context(seed: int) -> dict:
132
+ with state_lock:
133
+ cached = _eval_context.get(seed)
134
+ if cached is not None:
135
+ return cached
136
+ random.seed(seed)
137
+ forest = Tree.generate(10)
138
+ inp = Input.generate(forest, 256, 16)
139
+ mem0 = build_mem_image(forest, inp)
140
+ ref_mem = None
141
+ for ref_mem in reference_kernel2(list(mem0)):
142
+ pass
143
+ if ref_mem is None:
144
+ raise RuntimeError("Reference kernel produced no output")
145
+ inp_values_p = ref_mem[6]
146
+ expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
147
+ ctx = {
148
+ "forest": forest,
149
+ "inp": inp,
150
+ "mem0": mem0,
151
+ "expected": expected,
152
+ "inp_values_p": inp_values_p,
153
+ }
154
+ with state_lock:
155
+ _eval_context[seed] = ctx
156
+ return ctx
157
 
158
 
159
+ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
 
 
 
 
160
  if not SIMULATOR_AVAILABLE:
161
+ return {
162
+ "score": 0.0,
163
+ "correctness": 0.0,
164
+ "cycles": None,
165
+ "msg": "Simulator unavailable",
166
+ }
167
 
168
  try:
169
+ code = code.strip()
170
+ if not code:
171
+ return {
172
+ "score": 0.0,
173
+ "correctness": 0.0,
174
+ "cycles": None,
175
+ "msg": "Empty code",
176
+ }
177
+
178
+ if "OptimizedKernelBuilder" not in code:
179
+ return {
180
+ "score": 0.0,
181
+ "correctness": 0.0,
182
+ "cycles": None,
183
+ "msg": "Missing OptimizedKernelBuilder",
184
+ }
185
+
186
+ if "def run" not in code:
187
+ return {
188
+ "score": 0.0,
189
+ "correctness": 0.0,
190
+ "cycles": None,
191
+ "msg": "Missing run()",
192
+ }
193
+
194
+ safe_builtins = {
195
+ "abs": abs,
196
+ "all": all,
197
+ "any": any,
198
+ "dict": dict,
199
+ "enumerate": enumerate,
200
+ "int": int,
201
+ "len": len,
202
+ "list": list,
203
+ "max": max,
204
+ "min": min,
205
+ "range": range,
206
+ "sum": sum,
207
+ "tuple": tuple,
208
+ "zip": zip,
209
+ }
210
+ exec_globals = {
211
+ "__builtins__": safe_builtins,
212
+ "KernelBuilder": KernelBuilder,
213
+ "HASH_STAGES": HASH_STAGES,
214
+ "VLEN": VLEN,
215
+ "SLOT_LIMITS": SLOT_LIMITS,
216
+ }
217
+
218
+ exec(code, exec_globals)
219
+
220
+ if "OptimizedKernelBuilder" not in exec_globals:
221
+ return {
222
+ "score": 0.0,
223
+ "correctness": 0.0,
224
+ "cycles": None,
225
+ "msg": "OptimizedKernelBuilder not defined after exec",
226
+ }
227
+
228
+ ctx = _get_eval_context(seed)
229
+ forest = ctx["forest"]
230
+ inp = ctx["inp"]
231
+ mem0 = ctx["mem0"]
232
+
233
+ kb = exec_globals["OptimizedKernelBuilder"]()
234
+ kb.build_kernel(10, len(forest.values), 256, 16)
235
 
236
  machine = Machine(
237
+ list(mem0),
238
+ kb.instrs,
239
+ kb.debug_info(),
240
  n_cores=N_CORES,
241
  trace=False,
242
  )
243
  machine.enable_pause = False
244
  machine.enable_debug = False
245
 
246
+ ok = _run_machine_with_cycle_limit(machine, max_cycles=250000)
247
+ if not ok:
248
+ return {
249
+ "score": 0.0,
250
+ "correctness": 0.0,
251
+ "cycles": int(machine.cycle),
252
+ "msg": f"Exceeded cycle limit (cycles={machine.cycle})",
253
+ }
254
+ cycles = machine.cycle
255
+
256
+ if cycles <= 100:
257
+ return {
258
+ "score": 0.0,
259
+ "correctness": 0.0,
260
+ "cycles": int(cycles),
261
+ "msg": f"Suspiciously low cycles ({cycles})",
262
+ }
263
+ if cycles > 200000:
264
+ return {
265
+ "score": 0.0,
266
+ "correctness": 0.0,
267
+ "cycles": int(cycles),
268
+ "msg": f"Cycles too high ({cycles})",
269
+ }
270
+
271
+ inp_values_p = ctx["inp_values_p"]
272
+ expected = ctx["expected"]
273
+ actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
274
+ if expected != actual:
275
+ return {
276
+ "score": 0.0,
277
+ "correctness": 0.0,
278
+ "cycles": int(cycles),
279
+ "msg": f"Incorrect output (cycles={cycles})",
280
+ }
281
+
282
+ score = SCORE_SCALE / cycles
283
+ return {
284
+ "score": float(score),
285
+ "correctness": 1.0,
286
+ "cycles": int(cycles),
287
+ "msg": f"Success: {cycles} cycles",
288
+ }
289
  except Exception as e:
290
+ return {
291
+ "score": 0.0,
292
+ "correctness": 0.0,
293
+ "cycles": None,
294
+ "msg": f"Execution error: {str(e)[:200]}",
295
+ }
296
 
297
 
298
+ def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
 
 
 
 
299
  rewards = []
 
300
  for completion in completions:
 
301
  if isinstance(completion, list):
302
  text = completion[0].get("content", "") if completion else ""
303
  else:
304
  text = str(completion)
305
 
306
+ code = extract_code_block(text)
307
+ result = verify_perf_takehome_code(code)
308
+
309
+ reward = 0.0
310
+ if result.get("correctness", 0.0) > 0:
311
+ reward = float(result["score"]) + 1.0
312
+ cycles = result.get("cycles")
313
+ with state_lock:
314
+ if isinstance(cycles, int) and cycles < training_state["best_cycles"]:
315
+ training_state["best_cycles"] = cycles
316
+ training_state["best_code"] = code
317
+ rewards.append(float(reward))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  return rewards
319
 
320
 
321
  # Prompt template for VLIW optimization
322
+ PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
323
+
324
+ ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
325
+
326
+ API (KernelBuilder):
327
+ - alloc_scratch(name, length) -> addr
328
+ - scratch_const(val, name) -> addr
329
+ - add(engine, slot): engine in {{alu, valu, load, store, flow}}
330
+ - alu: (op, dst, src1, src2) where op in {{+,-,*,//,%,^,&,|,<<,>>,<,==,!=,<=,>=,>}}
331
+ - valu: same ops but on vectors (VLEN=8)
332
+ - load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
333
+ - store: (store,addr,src), (vstore,addr,src)
334
+ - flow: (select,dst,cond,t,f), (vselect,dst,cond,t,f), (cond_jump,cond,pc), (jump,pc), (halt,)
335
+ - label(name): mark code position
336
+ - build(slots, vliw=True): pack slots into VLIW bundle
337
+
338
+ MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)
339
+
340
+ ALGORITHM: 16 rounds x 256 items:
341
+ load idx,val
342
+ node = tree[idx]
343
+ val = hash(val ^ node) using HASH_STAGES
344
+ idx = 2*idx + (1 if val%2==0 else 2)
345
+ idx = 0 if idx >= n_nodes else idx
346
+ store idx,val
347
+
348
+ RULES:
349
+ - Output exactly one python code block.
350
+ - The code block must define:
351
+ - class OptimizedKernelBuilder(KernelBuilder): override build_kernel() and emit instructions using add()/build()
352
+ - def run(): return any tuple (ignored), but must exist
353
+ - No imports.
354
+
355
+ Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.
356
+ """
357
 
358
 
359
+ def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
360
+ """Run GRPO + LoRA training with correctness-gated perf_takehome rewards."""
361
  import torch
362
  from datasets import Dataset
363
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
364
  from peft import LoraConfig
365
+ from peft import PeftModel
366
  from trl import GRPOConfig, GRPOTrainer
367
  from transformers import TrainerCallback
368
 
 
378
  training_state["should_stop"] = False
379
  training_state["log"] = []
380
  training_state["best_cycles"] = BASELINE_CYCLES
381
+ training_state["best_code"] = None
382
  training_state["step"] = 0
383
 
384
  try:
385
  add_log(f"Starting VLIW optimization training")
386
+ add_log(f"Model: {model_name}")
387
+ add_log(f"Chunk steps: {chunk_steps}")
388
+ add_log(f"Auto-continue: {auto_continue} (max_total_steps={max_total_steps}, max_minutes={max_minutes})")
389
  add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
390
+ add_log(f"Adapter dir: {ADAPTER_DIR}")
391
 
392
  # Load tokenizer
393
  add_log("Loading tokenizer...")
 
403
  bnb_4bit_quant_type="nf4",
404
  bnb_4bit_compute_dtype=torch.bfloat16,
405
  )
406
+ base_model = AutoModelForCausalLM.from_pretrained(
407
  model_name,
408
  quantization_config=bnb_config,
409
  device_map="auto",
410
  trust_remote_code=True,
411
  )
412
+ add_log(f"✓ Base model loaded on {next(base_model.parameters()).device}")
413
 
414
+ # Resume LoRA adapter if present
415
+ if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
416
+ add_log("Loading existing LoRA adapter (resume)...")
417
+ model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
418
+ add_log("✓ Adapter loaded")
419
+ else:
420
+ model = base_model
421
+
422
+ # Create dataset with prompts
423
  add_log("Creating VLIW optimization dataset...")
424
+ prompts = [PERF_TAKEHOME_PROMPT] * 16
425
  dataset = Dataset.from_dict({"prompt": prompts})
426
  add_log(f"✓ Dataset ready: {len(prompts)} prompts")
427
 
 
436
  task_type="CAUSAL_LM",
437
  )
438
 
439
+ progress = {"step": 0}
440
+ start_time = time.time()
441
+ max_seconds = float(max_minutes) * 60.0 if auto_continue else float("inf")
442
+ total_target_steps = int(max_total_steps) if auto_continue else int(chunk_steps)
443
+
444
+ # Custom callback for logging + early stop
445
  class VLIWCallback(TrainerCallback):
446
  def on_step_end(self, args, state, control, **kwargs):
447
  with state_lock:
448
+ progress["step"] += 1
449
+ training_state["step"] = progress["step"]
450
  if training_state["should_stop"]:
451
  control.should_training_stop = True
452
+ if training_state["best_cycles"] <= TARGET_CYCLES:
453
+ control.should_training_stop = True
454
  return control
455
 
456
  def on_log(self, args, state, control, logs=None, **kwargs):
457
  if logs:
458
  loss = logs.get("loss", "N/A")
459
  reward = logs.get("reward", logs.get("mean_reward", "N/A"))
460
+ step = progress["step"]
461
  add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}" if isinstance(loss, float) else f"Step {step}: {logs}")
462
 
463
+ add_log("Creating GRPO trainer with perf_takehome rewards...")
464
+ output_dir = os.path.join(PERSIST_DIR, "grpo_perf_takehome_output")
465
+ os.makedirs(output_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
+ add_log("✓ Trainer config ready")
468
  add_log("Starting training loop...")
469
+ add_log("(Stops early if target reached; can auto-continue in chunks)")
470
+
471
+ chunk_idx = 0
472
+ while True:
473
+ with state_lock:
474
+ if training_state["should_stop"]:
475
+ break
476
+ if training_state["best_cycles"] <= TARGET_CYCLES:
477
+ break
478
+
479
+ if progress["step"] >= total_target_steps:
480
+ break
481
+ if (time.time() - start_time) >= max_seconds:
482
+ break
483
+
484
+ remaining = total_target_steps - progress["step"]
485
+ this_chunk_steps = min(int(chunk_steps), int(remaining))
486
+ if this_chunk_steps <= 0:
487
+ break
488
+
489
+ chunk_idx += 1
490
+ add_log(f"Chunk {chunk_idx}: training {this_chunk_steps} steps...")
491
+
492
+ config = GRPOConfig(
493
+ output_dir=output_dir,
494
+ num_train_epochs=1,
495
+ max_steps=this_chunk_steps,
496
+ per_device_train_batch_size=1,
497
+ gradient_accumulation_steps=4,
498
+ learning_rate=1e-5,
499
+ logging_steps=1,
500
+ save_steps=999999,
501
+ report_to="none",
502
+ remove_unused_columns=False,
503
+ max_completion_length=512,
504
+ num_generations=4,
505
+ )
506
+
507
+ trainer = GRPOTrainer(
508
+ model=model,
509
+ args=config,
510
+ train_dataset=dataset,
511
+ reward_funcs=perf_takehome_reward_fn,
512
+ peft_config=lora_config,
513
+ processing_class=tokenizer,
514
+ callbacks=[VLIWCallback()],
515
+ )
516
+
517
+ train_result = trainer.train()
518
+ metrics = train_result.metrics
519
+ add_log(f"Chunk {chunk_idx} done: steps={metrics.get('train_steps', this_chunk_steps)}")
520
 
521
+ # Save adapter after each chunk so it persists across restarts
522
+ try:
523
+ os.makedirs(os.path.dirname(ADAPTER_DIR), exist_ok=True)
524
+ trainer.save_model(ADAPTER_DIR)
525
+ add_log(f"✓ Saved adapter to {ADAPTER_DIR}")
526
+ except Exception as e:
527
+ add_log(f"✗ Failed to save adapter: {str(e)[:120]}")
528
+
529
+ if not auto_continue:
530
+ break
531
 
532
  # Test generation
533
  add_log("Testing trained model...")
534
+ inputs = tokenizer(PERF_TAKEHOME_PROMPT, return_tensors="pt").to(model.device)
535
  with torch.no_grad():
536
  outputs = model.generate(
537
  **inputs,
 
542
  )
543
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
544
 
545
+ code = extract_code_block(result)
546
+ verify_out = verify_perf_takehome_code(code)
547
+ if verify_out.get("correctness", 0.0) > 0:
548
+ cycles = verify_out.get("cycles")
549
+ add_log(f"Generated kernel verified: {cycles:,} cycles")
550
+ speedup = BASELINE_CYCLES / max(int(cycles), 1) if isinstance(cycles, int) else 0.0
551
  add_log(f"Speedup: {speedup:.2f}x over baseline")
552
  else:
553
+ add_log(f"Generated kernel invalid: {verify_out.get('msg', '')[:160]}")
554
 
555
  add_log("\n✓ All done!")
556
 
 
570
  return "\n".join(log)
571
 
572
 
573
+ def start_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
574
  """Start training."""
575
  with state_lock:
576
  if training_state["is_training"]:
577
+ return "\n".join(training_state["log"][-200:]) or "Training already in progress. Please wait."
578
 
579
+ thread = threading.Thread(
580
+ target=run_training,
581
+ args=(
582
+ model_name,
583
+ int(chunk_steps),
584
+ int(max_total_steps),
585
+ float(max_minutes),
586
+ bool(auto_continue),
587
+ ),
588
+ daemon=True,
589
+ )
590
+ thread.start()
591
+ return "Training started. Logs will stream below."
592
 
593
 
594
  def stop_training():
 
604
  with gr.Blocks(title="VLIW Optimizer") as demo:
605
  gr.Markdown("# VLIW Kernel Optimizer - RL Training")
606
  gr.Markdown(f"""
607
+ Train a language model with reinforcement learning (LoRA) at test time to generate correct, fast VLIW/SIMD kernels.
608
 
609
  **Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)
610
 
611
  **How it works:**
612
+ 1. Model generates Python kernel builder code
613
+ 2. Simulator checks correctness vs reference and measures cycles
614
+ 3. GRPO updates LoRA weights; adapter is saved and reloaded from `{ADAPTER_DIR}`
615
  """)
616
 
617
  with gr.Row():
 
632
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
633
  label="Model",
634
  )
635
+ chunk_steps_slider = gr.Slider(
636
  minimum=5,
637
  maximum=100,
638
  value=20,
639
  step=5,
640
+ label="Chunk Steps",
641
+ )
642
+ auto_continue_checkbox = gr.Checkbox(
643
+ value=False,
644
+ label="Auto-continue (chain chunks)",
645
+ )
646
+ max_total_steps_slider = gr.Slider(
647
+ minimum=5,
648
+ maximum=500,
649
+ value=100,
650
+ step=5,
651
+ label="Max Total Steps",
652
+ )
653
+ max_minutes_number = gr.Number(
654
+ value=60,
655
+ precision=0,
656
+ label="Max Minutes",
657
  )
658
 
659
  with gr.Row():
 
667
  value="Click 'Start Training' to begin VLIW optimization.",
668
  )
669
 
670
+ def poll_log():
671
+ with state_lock:
672
+ return "\n".join(training_state["log"][-400:]) if training_state["log"] else ""
673
+
674
  start_btn.click(
675
  start_training,
676
+ [model_dropdown, chunk_steps_slider, max_total_steps_slider, max_minutes_number, auto_continue_checkbox],
677
  [output_box],
678
  )
679
  stop_btn.click(stop_training, [], [output_box])
680
 
681
+ gr.Timer(1.0).tick(poll_log, outputs=[output_box])
682
+
683
  if __name__ == "__main__":
684
  demo.launch(server_name="0.0.0.0", server_port=7860)