CreativeEngineer committed on
Commit
c166ffe
·
1 Parent(s): c5c47e3

Add few-shot prompt and shaped rewards

Browse files
Files changed (1) hide show
  1. app.py +98 -6
app.py CHANGED
@@ -80,6 +80,9 @@ except Exception as e:
80
  BASELINE_CYCLES = 147734
81
  TARGET_CYCLES = 1363
82
  SCORE_SCALE = 3000.0
 
 
 
83
  PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
84
  ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
85
  ADAPTER_DATASET_REPO = os.environ.get("ADAPTER_DATASET_REPO", "CreativeEngineer/vliw-optimizer-adapters")
@@ -172,9 +175,9 @@ def _try_download_adapter(add_log) -> None:
172
  fdst.write(fsrc.read())
173
  add_log(f"[OK] Downloaded adapter from dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
174
  else:
175
- add_log(" No adapter found in dataset yet")
176
  except Exception as e:
177
- add_log(f" Adapter download skipped: {str(e)[:160]}")
178
 
179
 
180
  def _try_upload_adapter(add_log) -> None:
@@ -182,11 +185,11 @@ def _try_upload_adapter(add_log) -> None:
182
  add_log("[ERR] Hub sync disabled: huggingface_hub not available")
183
  return
184
  if not _adapter_exists(ADAPTER_DIR):
185
- add_log(" No adapter to upload yet")
186
  return
187
  token = _hf_token()
188
  if token is None:
189
- add_log(" No HF token set (HF_TOKEN/HUGGINGFACE_HUB_TOKEN); skipping upload")
190
  return
191
  try:
192
  api = HfApi(token=token)
@@ -200,7 +203,7 @@ def _try_upload_adapter(add_log) -> None:
200
  )
201
  add_log(f"[OK] Uploaded adapter to dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
202
  except Exception as e:
203
- add_log(f" Adapter upload skipped: {str(e)[:160]}")
204
 
205
 
206
  def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
@@ -264,6 +267,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
264
  "correctness": 0.0,
265
  "cycles": None,
266
  "msg": "Simulator unavailable",
 
 
 
267
  }
268
 
269
  try:
@@ -274,6 +280,22 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
274
  "correctness": 0.0,
275
  "cycles": None,
276
  "msg": "Empty code",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  }
278
 
279
  if "OptimizedKernelBuilder" not in code:
@@ -282,6 +304,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
282
  "correctness": 0.0,
283
  "cycles": None,
284
  "msg": "Missing OptimizedKernelBuilder",
 
 
 
285
  }
286
 
287
  if "def run" not in code:
@@ -290,6 +315,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
290
  "correctness": 0.0,
291
  "cycles": None,
292
  "msg": "Missing run()",
 
 
 
293
  }
294
 
295
  safe_builtins = {
@@ -316,7 +344,18 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
316
  "SLOT_LIMITS": SLOT_LIMITS,
317
  }
318
 
319
- exec(code, exec_globals)
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  if "OptimizedKernelBuilder" not in exec_globals:
322
  return {
@@ -324,6 +363,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
324
  "correctness": 0.0,
325
  "cycles": None,
326
  "msg": "OptimizedKernelBuilder not defined after exec",
 
 
 
327
  }
328
 
329
  ctx = _get_eval_context(seed)
@@ -351,6 +393,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
351
  "correctness": 0.0,
352
  "cycles": int(machine.cycle),
353
  "msg": f"Exceeded cycle limit (cycles={machine.cycle})",
 
 
 
354
  }
355
  cycles = machine.cycle
356
 
@@ -360,6 +405,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
360
  "correctness": 0.0,
361
  "cycles": int(cycles),
362
  "msg": f"Suspiciously low cycles ({cycles})",
 
 
 
363
  }
364
  if cycles > 200000:
365
  return {
@@ -367,6 +415,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
367
  "correctness": 0.0,
368
  "cycles": int(cycles),
369
  "msg": f"Cycles too high ({cycles})",
 
 
 
370
  }
371
 
372
  inp_values_p = ctx["inp_values_p"]
@@ -378,6 +429,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
378
  "correctness": 0.0,
379
  "cycles": int(cycles),
380
  "msg": f"Incorrect output (cycles={cycles})",
 
 
 
381
  }
382
 
383
  score = SCORE_SCALE / cycles
@@ -386,6 +440,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
386
  "correctness": 1.0,
387
  "cycles": int(cycles),
388
  "msg": f"Success: {cycles} cycles",
 
 
 
389
  }
390
  except Exception as e:
391
  return {
@@ -393,6 +450,9 @@ def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
393
  "correctness": 0.0,
394
  "cycles": None,
395
  "msg": f"Execution error: {str(e)[:200]}",
 
 
 
396
  }
397
 
398
 
@@ -410,6 +470,13 @@ def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
410
  reward = 0.0
411
  if result.get("correctness", 0.0) > 0:
412
  reward = float(result["score"]) + 1.0
 
 
 
 
 
 
 
413
  cycles = result.get("cycles")
414
  with state_lock:
415
  if isinstance(cycles, int) and cycles < training_state["best_cycles"]:
@@ -420,6 +487,29 @@ def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
420
 
421
 
422
  # Prompt template for VLIW optimization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
424
 
425
  ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
@@ -454,6 +544,8 @@ RULES:
454
  - No imports.
455
 
456
  Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.
 
 
457
  """
458
 
459
 
 
80
  BASELINE_CYCLES = 147734
81
  TARGET_CYCLES = 1363
82
  SCORE_SCALE = 3000.0
83
+ PARSE_REWARD = 0.02
84
+ API_REWARD = 0.05
85
+ EXEC_REWARD = 0.10
86
  PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
87
  ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
88
  ADAPTER_DATASET_REPO = os.environ.get("ADAPTER_DATASET_REPO", "CreativeEngineer/vliw-optimizer-adapters")
 
175
  fdst.write(fsrc.read())
176
  add_log(f"[OK] Downloaded adapter from dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
177
  else:
178
+ add_log("[INFO] No adapter found in dataset yet")
179
  except Exception as e:
180
+ add_log(f"[INFO] Adapter download skipped: {str(e)[:160]}")
181
 
182
 
183
  def _try_upload_adapter(add_log) -> None:
 
185
  add_log("[ERR] Hub sync disabled: huggingface_hub not available")
186
  return
187
  if not _adapter_exists(ADAPTER_DIR):
188
+ add_log("[INFO] No adapter to upload yet")
189
  return
190
  token = _hf_token()
191
  if token is None:
192
+ add_log("[INFO] No HF token set (HF_TOKEN/HUGGINGFACE_HUB_TOKEN); skipping upload")
193
  return
194
  try:
195
  api = HfApi(token=token)
 
203
  )
204
  add_log(f"[OK] Uploaded adapter to dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
205
  except Exception as e:
206
+ add_log(f"[INFO] Adapter upload skipped: {str(e)[:160]}")
207
 
208
 
209
  def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
 
267
  "correctness": 0.0,
268
  "cycles": None,
269
  "msg": "Simulator unavailable",
270
+ "parse_ok": False,
271
+ "api_ok": False,
272
+ "exec_ok": False,
273
  }
274
 
275
  try:
 
280
  "correctness": 0.0,
281
  "cycles": None,
282
  "msg": "Empty code",
283
+ "parse_ok": False,
284
+ "api_ok": False,
285
+ "exec_ok": False,
286
+ }
287
+
288
+ try:
289
+ compile(code, "<string>", "exec")
290
+ except Exception as e:
291
+ return {
292
+ "score": 0.0,
293
+ "correctness": 0.0,
294
+ "cycles": None,
295
+ "msg": f"Syntax error: {str(e)[:200]}",
296
+ "parse_ok": False,
297
+ "api_ok": False,
298
+ "exec_ok": False,
299
  }
300
 
301
  if "OptimizedKernelBuilder" not in code:
 
304
  "correctness": 0.0,
305
  "cycles": None,
306
  "msg": "Missing OptimizedKernelBuilder",
307
+ "parse_ok": True,
308
+ "api_ok": False,
309
+ "exec_ok": False,
310
  }
311
 
312
  if "def run" not in code:
 
315
  "correctness": 0.0,
316
  "cycles": None,
317
  "msg": "Missing run()",
318
+ "parse_ok": True,
319
+ "api_ok": False,
320
+ "exec_ok": False,
321
  }
322
 
323
  safe_builtins = {
 
344
  "SLOT_LIMITS": SLOT_LIMITS,
345
  }
346
 
347
+ try:
348
+ exec(code, exec_globals)
349
+ except Exception as e:
350
+ return {
351
+ "score": 0.0,
352
+ "correctness": 0.0,
353
+ "cycles": None,
354
+ "msg": f"Execution error: {str(e)[:200]}",
355
+ "parse_ok": True,
356
+ "api_ok": True,
357
+ "exec_ok": False,
358
+ }
359
 
360
  if "OptimizedKernelBuilder" not in exec_globals:
361
  return {
 
363
  "correctness": 0.0,
364
  "cycles": None,
365
  "msg": "OptimizedKernelBuilder not defined after exec",
366
+ "parse_ok": True,
367
+ "api_ok": True,
368
+ "exec_ok": True,
369
  }
370
 
371
  ctx = _get_eval_context(seed)
 
393
  "correctness": 0.0,
394
  "cycles": int(machine.cycle),
395
  "msg": f"Exceeded cycle limit (cycles={machine.cycle})",
396
+ "parse_ok": True,
397
+ "api_ok": True,
398
+ "exec_ok": True,
399
  }
400
  cycles = machine.cycle
401
 
 
405
  "correctness": 0.0,
406
  "cycles": int(cycles),
407
  "msg": f"Suspiciously low cycles ({cycles})",
408
+ "parse_ok": True,
409
+ "api_ok": True,
410
+ "exec_ok": True,
411
  }
412
  if cycles > 200000:
413
  return {
 
415
  "correctness": 0.0,
416
  "cycles": int(cycles),
417
  "msg": f"Cycles too high ({cycles})",
418
+ "parse_ok": True,
419
+ "api_ok": True,
420
+ "exec_ok": True,
421
  }
422
 
423
  inp_values_p = ctx["inp_values_p"]
 
429
  "correctness": 0.0,
430
  "cycles": int(cycles),
431
  "msg": f"Incorrect output (cycles={cycles})",
432
+ "parse_ok": True,
433
+ "api_ok": True,
434
+ "exec_ok": True,
435
  }
436
 
437
  score = SCORE_SCALE / cycles
 
440
  "correctness": 1.0,
441
  "cycles": int(cycles),
442
  "msg": f"Success: {cycles} cycles",
443
+ "parse_ok": True,
444
+ "api_ok": True,
445
+ "exec_ok": True,
446
  }
447
  except Exception as e:
448
  return {
 
450
  "correctness": 0.0,
451
  "cycles": None,
452
  "msg": f"Execution error: {str(e)[:200]}",
453
+ "parse_ok": False,
454
+ "api_ok": False,
455
+ "exec_ok": False,
456
  }
457
 
458
 
 
470
  reward = 0.0
471
  if result.get("correctness", 0.0) > 0:
472
  reward = float(result["score"]) + 1.0
473
+ else:
474
+ if result.get("parse_ok"):
475
+ reward += PARSE_REWARD
476
+ if result.get("api_ok"):
477
+ reward += API_REWARD
478
+ if result.get("exec_ok"):
479
+ reward += EXEC_REWARD
480
  cycles = result.get("cycles")
481
  with state_lock:
482
  if isinstance(cycles, int) and cycles < training_state["best_cycles"]:
 
487
 
488
 
489
  # Prompt template for VLIW optimization
490
+ FEWSHOT_EXAMPLES = """Example format (not optimized):
491
+ ```python
492
+ class OptimizedKernelBuilder(KernelBuilder):
493
+ def build_kernel(self, forest_height, n_nodes, batch_size, rounds):
494
+ self.add("flow", ("halt",))
495
+
496
+ def run():
497
+ return (0,)
498
+ ```
499
+
500
+ Example with scratch + load:
501
+ ```python
502
+ class OptimizedKernelBuilder(KernelBuilder):
503
+ def build_kernel(self, forest_height, n_nodes, batch_size, rounds):
504
+ tmp = self.alloc_scratch("tmp")
505
+ self.add("load", ("const", tmp, 0))
506
+ self.add("flow", ("halt",))
507
+
508
+ def run():
509
+ return (0,)
510
+ ```
511
+ """
512
+
513
  PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
514
 
515
  ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
 
544
  - No imports.
545
 
546
  Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.
547
+
548
+ {FEWSHOT_EXAMPLES}
549
  """
550
 
551