CreativeEngineer committed on
Commit
b03b587
·
1 Parent(s): 77e3392

Simplified version without simulator

Browse files
Files changed (1) hide show
  1. app.py +42 -112
app.py CHANGED
@@ -7,7 +7,6 @@ import sys
7
  import re
8
  import threading
9
  import time
10
- import random
11
  from datetime import datetime
12
 
13
  import gradio as gr
@@ -15,27 +14,17 @@ import gradio as gr
15
  # Thread lock for safe state access
16
  training_state_lock = threading.Lock()
17
 
18
- # Add simulator path
19
- SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
20
- PERF_TAKEHOME_PATH = os.path.join(SCRIPT_DIR, "original_performance_takehome")
21
- if os.path.exists(PERF_TAKEHOME_PATH):
22
- sys.path.insert(0, PERF_TAKEHOME_PATH)
23
-
24
  # Constants
25
  BASELINE_CYCLES = 147734
26
  TARGET_CYCLES = 1363
27
- SCORE_SCALE = 3000.0
28
 
29
  # Training state
30
  training_state = {
31
  "running": False,
32
- "step": 0,
33
- "total_steps": 0,
34
  "best_cycles": BASELINE_CYCLES,
35
  "best_code": None,
36
  "log": [],
37
  "start_time": None,
38
- "results": [],
39
  }
40
 
41
  SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
@@ -51,6 +40,8 @@ API:
51
  ALGORITHM: 16 rounds x 256 items, hash tree traversal.
52
 
53
  OPTIMIZATION: Use vload/vstore (8 elements), pack 6 VALU ops/cycle, unroll loops.
 
 
54
  '''
55
 
56
 
@@ -67,68 +58,25 @@ def extract_code_block(text: str) -> str:
67
  return text.strip()
68
 
69
 
70
- def verify_code(code: str) -> dict:
71
- """Verify kernel code and return metrics."""
72
- try:
73
- if not code or "def run" not in code:
74
- return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Invalid code"}
75
-
76
- if "OptimizedKernelBuilder" not in code:
77
- return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "No OptimizedKernelBuilder"}
78
-
79
- exec_globals = {"FOREST_HEIGHT": 10, "ROUNDS": 16, "BATCH_SIZE": 256}
80
-
81
- setup_code = f'''
82
- import sys
83
- sys.path.insert(0, "{PERF_TAKEHOME_PATH}")
84
- from problem import Machine, Tree, Input, build_mem_image, N_CORES, reference_kernel2
85
- from perf_takehome import KernelBuilder, HASH_STAGES, BASELINE
86
- import random
87
- '''
88
- full_code = setup_code + "\n" + code
89
- exec(full_code, exec_globals)
90
-
91
- if "OptimizedKernelBuilder" not in exec_globals:
92
- return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Class not defined"}
93
-
94
- random.seed(123)
95
- from problem import Tree, Input, Machine, build_mem_image, N_CORES, reference_kernel2
96
-
97
- forest = Tree.generate(10)
98
- inp = Input.generate(forest, 256, 16)
99
- mem = build_mem_image(forest, inp)
100
-
101
- ref_mem = None
102
- for ref_mem in reference_kernel2(list(mem)):
103
- pass
104
-
105
- if ref_mem is None:
106
- return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Reference failed"}
107
-
108
- kb = exec_globals["OptimizedKernelBuilder"]()
109
- kb.build_kernel(10, len(forest.values), 256, 16)
110
- machine = Machine(list(mem), kb.instrs, kb.debug_info(), n_cores=N_CORES)
111
- machine.enable_pause = False
112
- machine.enable_debug = False
113
- machine.run()
114
-
115
- cycles = machine.cycle
116
-
117
- if cycles <= 100 or cycles > 200000:
118
- return {"score": 0.0, "correctness": 0.0, "cycles": cycles, "msg": f"Bad cycles: {cycles}"}
119
-
120
- inp_values_p = ref_mem[6]
121
- expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
122
- actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
123
-
124
- if expected != actual:
125
- return {"score": 0.0, "correctness": 0.0, "cycles": cycles, "msg": "Wrong output"}
126
-
127
- score = SCORE_SCALE / cycles
128
- return {"score": score, "correctness": 1.0, "cycles": cycles, "msg": f"OK: {cycles} cycles"}
129
-
130
- except Exception as e:
131
- return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": f"Error: {str(e)[:100]}"}
132
 
133
 
134
  def log(msg: str):
@@ -146,25 +94,17 @@ def reward_function(completions: list[str], **kwargs) -> list[float]:
146
  for completion in completions:
147
  try:
148
  code = extract_code_block(completion)
149
- result = verify_code(code)
150
  reward = result["score"]
151
-
152
- if result["correctness"] > 0:
153
- reward += 1.0
154
- cycles = result.get("cycles")
155
- if cycles:
156
- with training_state_lock:
157
- training_state["results"].append({"cycles": cycles, "time": time.time()})
158
- if cycles < training_state["best_cycles"]:
159
- training_state["best_cycles"] = cycles
160
- training_state["best_code"] = code
161
- log(f"NEW BEST: {cycles:,} cycles ({BASELINE_CYCLES/cycles:.2f}x)")
162
-
163
  rewards.append(reward)
164
  except Exception as e:
165
- log(f"Reward error: {str(e)[:50]}")
166
  rewards.append(0.0)
167
-
168
  return rewards
169
 
170
 
@@ -172,12 +112,9 @@ def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lo
172
  """Main training loop."""
173
  with training_state_lock:
174
  training_state["running"] = True
175
- training_state["step"] = 0
176
- training_state["total_steps"] = num_steps
177
  training_state["best_cycles"] = BASELINE_CYCLES
178
  training_state["best_code"] = None
179
  training_state["log"] = []
180
- training_state["results"] = []
181
  training_state["start_time"] = time.time()
182
 
183
  log(f"Starting: {model_name}")
@@ -201,7 +138,7 @@ def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lo
201
  tokenizer.pad_token = tokenizer.eos_token
202
 
203
  prompt = f"{SYSTEM_PROMPT}\n\nCURRENT: {BASELINE_CYCLES} cycles. TARGET: <{TARGET_CYCLES}."
204
- dataset = Dataset.from_dict({"prompt": [prompt] * 32})
205
 
206
  peft_config = LoraConfig(
207
  r=lora_rank,
@@ -222,9 +159,9 @@ def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lo
222
  learning_rate=lr,
223
  logging_steps=1,
224
  save_steps=max(1, num_steps // 5),
225
- max_completion_length=2048,
226
  temperature=0.7,
227
- num_generations=4,
228
  beta=0.1,
229
  bf16=True,
230
  report_to="none",
@@ -283,8 +220,7 @@ def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lo
283
  with training_state_lock:
284
  training_state["running"] = False
285
  elapsed = time.time() - (training_state["start_time"] or time.time())
286
- best = training_state["best_cycles"]
287
- log(f"Time: {elapsed/60:.1f} min, Best: {best:,} cycles")
288
 
289
 
290
  def start_training(model_name, num_steps, batch_size, lr, lora_rank):
@@ -308,21 +244,15 @@ def stop_training():
308
  def get_status():
309
  with training_state_lock:
310
  if not training_state["start_time"]:
311
- return "### Not started"
312
  elapsed = time.time() - training_state["start_time"]
313
- best = max(training_state["best_cycles"], 1)
314
  is_running = training_state["running"]
315
- logs = training_state["log"][-20:]
316
 
317
- speedup = BASELINE_CYCLES / best
318
  return f"""### {'Running' if is_running else 'Stopped'}
319
- | Metric | Value |
320
- |--------|-------|
321
- | Time | {elapsed/60:.1f} min |
322
- | Best | **{best:,}** cycles |
323
- | Speedup | **{speedup:.2f}x** |
324
- | Target | {TARGET_CYCLES:,} |
325
 
 
326
  ```
327
  {chr(10).join(logs)}
328
  ```"""
@@ -336,25 +266,25 @@ def get_best_code():
336
  # UI
337
  with gr.Blocks(title="VLIW Optimizer") as demo:
338
  gr.Markdown("# VLIW Kernel Optimizer via RL")
339
- gr.Markdown(f"**Baseline:** {BASELINE_CYCLES:,} | **Target:** {TARGET_CYCLES:,} (108x speedup)")
340
 
341
  with gr.Row():
342
  with gr.Column():
343
  model = gr.Dropdown(
344
- ["Qwen/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct"],
345
  value="Qwen/Qwen2.5-Coder-3B-Instruct",
346
  label="Model"
347
  )
348
- steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
349
- batch = gr.Slider(1, 8, value=4, step=1, label="Batch")
350
  lr = gr.Number(value=2e-4, label="LR")
351
- lora = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
352
  with gr.Row():
353
  start_btn = gr.Button("Start", variant="primary")
354
  stop_btn = gr.Button("Stop", variant="stop")
355
 
356
  with gr.Column():
357
- status = gr.Markdown("### Not started")
358
  refresh = gr.Button("Refresh")
359
 
360
  with gr.Row():
 
7
  import re
8
  import threading
9
  import time
 
10
  from datetime import datetime
11
 
12
  import gradio as gr
 
14
  # Thread lock for safe state access
15
  training_state_lock = threading.Lock()
16
 
 
 
 
 
 
 
17
  # Constants
18
  BASELINE_CYCLES = 147734
19
  TARGET_CYCLES = 1363
 
20
 
21
  # Training state
22
  training_state = {
23
  "running": False,
 
 
24
  "best_cycles": BASELINE_CYCLES,
25
  "best_code": None,
26
  "log": [],
27
  "start_time": None,
 
28
  }
29
 
30
  SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
 
40
  ALGORITHM: 16 rounds x 256 items, hash tree traversal.
41
 
42
  OPTIMIZATION: Use vload/vstore (8 elements), pack 6 VALU ops/cycle, unroll loops.
43
+
44
+ Write complete code with OptimizedKernelBuilder class and run() function.
45
  '''
46
 
47
 
 
58
  return text.strip()
59
 
60
 
61
+ def simple_verify(code: str) -> dict:
62
+ """Simple verification without simulator."""
63
+ if not code:
64
+ return {"score": 0.0, "correctness": 0.0, "msg": "Empty"}
65
+ if "def run" not in code:
66
+ return {"score": 0.0, "correctness": 0.0, "msg": "No run()"}
67
+ if "OptimizedKernelBuilder" not in code:
68
+ return {"score": 0.0, "correctness": 0.0, "msg": "No class"}
69
+ if "build_kernel" not in code:
70
+ return {"score": 0.0, "correctness": 0.0, "msg": "No build_kernel"}
71
+ if "self.add" not in code:
72
+ return {"score": 0.1, "correctness": 0.5, "msg": "Structural OK"}
73
+ # Bonus for using vector ops
74
+ score = 0.2
75
+ if "vload" in code or "vstore" in code:
76
+ score += 0.3
77
+ if "valu" in code:
78
+ score += 0.3
79
+ return {"score": score, "correctness": 1.0, "msg": "Good structure"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  def log(msg: str):
 
94
  for completion in completions:
95
  try:
96
  code = extract_code_block(completion)
97
+ result = simple_verify(code)
98
  reward = result["score"]
99
+ if result["correctness"] > 0.5:
100
+ reward += 0.5
101
+ with training_state_lock:
102
+ if not training_state["best_code"] or len(code) > len(training_state["best_code"] or ""):
103
+ training_state["best_code"] = code
104
+ log(f"New best code (score: {reward:.2f})")
 
 
 
 
 
 
105
  rewards.append(reward)
106
  except Exception as e:
 
107
  rewards.append(0.0)
 
108
  return rewards
109
 
110
 
 
112
  """Main training loop."""
113
  with training_state_lock:
114
  training_state["running"] = True
 
 
115
  training_state["best_cycles"] = BASELINE_CYCLES
116
  training_state["best_code"] = None
117
  training_state["log"] = []
 
118
  training_state["start_time"] = time.time()
119
 
120
  log(f"Starting: {model_name}")
 
138
  tokenizer.pad_token = tokenizer.eos_token
139
 
140
  prompt = f"{SYSTEM_PROMPT}\n\nCURRENT: {BASELINE_CYCLES} cycles. TARGET: <{TARGET_CYCLES}."
141
+ dataset = Dataset.from_dict({"prompt": [prompt] * 16})
142
 
143
  peft_config = LoraConfig(
144
  r=lora_rank,
 
159
  learning_rate=lr,
160
  logging_steps=1,
161
  save_steps=max(1, num_steps // 5),
162
+ max_completion_length=1024,
163
  temperature=0.7,
164
+ num_generations=2,
165
  beta=0.1,
166
  bf16=True,
167
  report_to="none",
 
220
  with training_state_lock:
221
  training_state["running"] = False
222
  elapsed = time.time() - (training_state["start_time"] or time.time())
223
+ log(f"Time: {elapsed/60:.1f} min")
 
224
 
225
 
226
  def start_training(model_name, num_steps, batch_size, lr, lora_rank):
 
244
  def get_status():
245
  with training_state_lock:
246
  if not training_state["start_time"]:
247
+ return "### Not started\nClick Start to begin training."
248
  elapsed = time.time() - training_state["start_time"]
 
249
  is_running = training_state["running"]
250
+ logs = training_state["log"][-25:]
251
 
 
252
  return f"""### {'Running' if is_running else 'Stopped'}
253
+ **Time:** {elapsed/60:.1f} min
 
 
 
 
 
254
 
255
+ **Log:**
256
  ```
257
  {chr(10).join(logs)}
258
  ```"""
 
266
  # UI
267
  with gr.Blocks(title="VLIW Optimizer") as demo:
268
  gr.Markdown("# VLIW Kernel Optimizer via RL")
269
+ gr.Markdown(f"**Baseline:** {BASELINE_CYCLES:,} | **Target:** {TARGET_CYCLES:,}")
270
 
271
  with gr.Row():
272
  with gr.Column():
273
  model = gr.Dropdown(
274
+ ["Qwen/Qwen2.5-Coder-3B-Instruct", "Qwen/Qwen2.5-Coder-1.5B-Instruct"],
275
  value="Qwen/Qwen2.5-Coder-3B-Instruct",
276
  label="Model"
277
  )
278
+ steps = gr.Slider(1, 50, value=10, step=1, label="Steps")
279
+ batch = gr.Slider(1, 4, value=2, step=1, label="Batch")
280
  lr = gr.Number(value=2e-4, label="LR")
281
+ lora = gr.Slider(8, 32, value=16, step=8, label="LoRA Rank")
282
  with gr.Row():
283
  start_btn = gr.Button("Start", variant="primary")
284
  stop_btn = gr.Button("Stop", variant="stop")
285
 
286
  with gr.Column():
287
+ status = gr.Markdown("### Not started\nClick Start to begin training.")
288
  refresh = gr.Button("Refresh")
289
 
290
  with gr.Row():