CreativeEngineer committed on
Commit
77e3392
·
1 Parent(s): f4dcd8f

Full training app with verified imports

Browse files
Files changed (1) hide show
  1. app.py +349 -79
app.py CHANGED
@@ -1,101 +1,371 @@
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
- Minimal version for debugging.
4
  """
5
  import os
6
  import sys
 
 
 
 
 
 
7
  import gradio as gr
8
 
9
- # Check imports on startup
10
- startup_log = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def check_import(name, import_fn):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  try:
14
- result = import_fn()
15
- startup_log.append(f" {name}: {result}")
16
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  except Exception as e:
18
- startup_log.append(f" {name}: {str(e)[:100]}")
19
- return False
20
-
21
- # Test imports
22
- check_import("torch", lambda: __import__("torch").__version__)
23
- check_import("transformers", lambda: __import__("transformers").__version__)
24
- check_import("datasets", lambda: __import__("datasets").__version__)
25
- check_import("peft", lambda: __import__("peft").__version__)
26
- check_import("trl", lambda: __import__("trl").__version__)
27
- check_import("accelerate", lambda: __import__("accelerate").__version__)
28
-
29
- # Try GRPO import
30
- try:
31
- from trl import GRPOConfig, GRPOTrainer
32
- startup_log.append("✓ GRPOTrainer: imported from trl")
33
- except ImportError as e:
34
- startup_log.append(f"✗ GRPOTrainer from trl: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
- from trl.trainer.grpo_trainer import GRPOConfig, GRPOTrainer
37
- startup_log.append("✓ GRPOTrainer: imported from trl.trainer.grpo_trainer")
38
- except ImportError as e2:
39
- startup_log.append(f"✗ GRPOTrainer alt: {e2}")
40
-
41
- # Check CUDA
42
- try:
43
- import torch
44
- if torch.cuda.is_available():
45
- startup_log.append(f"✓ CUDA: {torch.cuda.get_device_name(0)}")
46
- else:
47
- startup_log.append("✗ CUDA: Not available")
48
- except:
49
- startup_log.append("✗ CUDA: Could not check")
50
-
51
- # Check simulator
52
- try:
53
- SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
54
- PERF_PATH = os.path.join(SCRIPT_DIR, "original_performance_takehome")
55
- if os.path.exists(PERF_PATH):
56
- sys.path.insert(0, PERF_PATH)
57
- from problem import Machine, Tree
58
- startup_log.append("✓ Simulator: loaded")
59
- else:
60
- startup_log.append(f"✗ Simulator: path not found ({PERF_PATH})")
61
- except Exception as e:
62
- startup_log.append(f"✗ Simulator: {e}")
63
-
64
-
65
- def get_startup_log():
66
- return "\n".join(startup_log)
67
-
68
-
69
- def dummy_train(model, steps):
70
- return f"Would train {model} for {steps} steps\n\nImport status:\n" + get_startup_log()
71
-
72
-
73
- # Simple UI
74
- with gr.Blocks(title="VLIW Optimizer") as demo:
75
- gr.Markdown("# VLIW Kernel Optimizer - Debug Mode")
76
- gr.Markdown("Checking if all imports work...")
77
 
78
- with gr.Row():
79
- with gr.Column():
80
- status = gr.Textbox(
81
- label="Startup Log",
82
- value=get_startup_log(),
83
- lines=20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  )
85
- refresh_btn = gr.Button("Refresh Status")
86
- refresh_btn.click(get_startup_log, outputs=[status])
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  with gr.Column():
89
  model = gr.Dropdown(
90
- choices=["Qwen/Qwen2.5-Coder-1.5B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct"],
91
- value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
92
  label="Model"
93
  )
94
- steps = gr.Slider(1, 10, value=3, label="Steps")
95
- train_btn = gr.Button("Test Train", variant="primary")
96
- output = gr.Textbox(label="Output", lines=10)
97
- train_btn.click(dummy_train, inputs=[model, steps], outputs=[output])
 
 
 
 
 
 
 
 
 
 
 
98
 
 
 
 
 
 
99
 
100
  if __name__ == "__main__":
101
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
+ Deploy to HF Spaces Pro (A10G GPU).
4
  """
5
  import os
6
  import sys
7
+ import re
8
+ import threading
9
+ import time
10
+ import random
11
+ from datetime import datetime
12
+
13
  import gradio as gr
14
 
15
# Thread lock for safe state access; guards every read/write of training_state.
# NOTE: threading.Lock is NOT reentrant -- never call a lock-taking helper
# (e.g. log()) while already holding it.
training_state_lock = threading.Lock()

# Add simulator path so `from problem import ...` resolves.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PERF_TAKEHOME_PATH = os.path.join(SCRIPT_DIR, "original_performance_takehome")
if os.path.exists(PERF_TAKEHOME_PATH):
    sys.path.insert(0, PERF_TAKEHOME_PATH)

# Constants
BASELINE_CYCLES = 147734  # cycle count of the unoptimized reference kernel
TARGET_CYCLES = 1363      # goal cycle count (~108x speedup over baseline)
SCORE_SCALE = 3000.0      # reward for a correct kernel = SCORE_SCALE / cycles

# Training state, shared between the Gradio UI thread and the training thread.
# All access must go through training_state_lock.
training_state = {
    "running": False,                # True while run_training is active
    "step": 0,
    "total_steps": 0,
    "best_cycles": BASELINE_CYCLES,  # lowest verified cycle count seen so far
    "best_code": None,               # source of the best verified kernel
    "log": [],                       # timestamped log lines (see log())
    "start_time": None,              # epoch seconds when training started
    "results": [],                   # one {cycles, time} record per verified kernel
}

SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.

ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle.

API:
- alloc_scratch(name, length) -> addr
- add(engine, slot): engine in {alu, valu, load, store, flow}
- valu ops work on 8 elements at once
- build(slots, vliw=True): pack into VLIW bundle

ALGORITHM: 16 rounds x 256 items, hash tree traversal.

OPTIMIZATION: Use vload/vstore (8 elements), pack 6 VALU ops/cycle, unroll loops.
'''
55
+
56
 
57
def extract_code_block(text: str) -> str:
    """Pull Python source out of a markdown-formatted completion.

    Prefers the last ```python fenced block, then the last anonymous
    ``` fence, and finally falls back to the whole text, stripped.
    """
    for fence in (r"```python\s*(.*?)```", r"```\s*(.*?)```"):
        found = re.findall(fence, text, re.DOTALL)
        if found:
            return found[-1].strip()
    return text.strip()
68
+
69
+
70
def verify_code(code: str) -> dict:
    """Run candidate kernel code in the VLIW simulator and score it.

    Returns a dict with keys:
        score       -- SCORE_SCALE / cycles for a correct kernel, else 0.0
        correctness -- 1.0 if the kernel output matches the reference, else 0.0
        cycles      -- simulated cycle count (None when the kernel never ran)
        msg         -- human-readable status
    """
    try:
        # Cheap structural checks before paying for exec + simulation.
        if not code or "def run" not in code:
            return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Invalid code"}

        if "OptimizedKernelBuilder" not in code:
            return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "No OptimizedKernelBuilder"}

        exec_globals = {"FOREST_HEIGHT": 10, "ROUNDS": 16, "BATCH_SIZE": 256}

        # !r renders the path as a proper Python literal, so quotes or
        # backslashes in the path cannot break the generated source
        # (the original bare f-string interpolation could).
        setup_code = f'''
import sys
sys.path.insert(0, {PERF_TAKEHOME_PATH!r})
from problem import Machine, Tree, Input, build_mem_image, N_CORES, reference_kernel2
from perf_takehome import KernelBuilder, HASH_STAGES, BASELINE
import random
'''
        full_code = setup_code + "\n" + code
        # SECURITY: this executes model-generated code with full interpreter
        # privileges. Only run inside a disposable sandbox (e.g. the HF Space
        # container), never on a trusted host.
        exec(full_code, exec_globals)

        if "OptimizedKernelBuilder" not in exec_globals:
            return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Class not defined"}

        # Fixed seed -> identical forest/input every call, so scores are
        # comparable across completions and training steps.
        random.seed(123)
        from problem import Tree, Input, Machine, build_mem_image, N_CORES, reference_kernel2

        forest = Tree.generate(10)
        inp = Input.generate(forest, 256, 16)
        mem = build_mem_image(forest, inp)

        # Drive the generator-style reference kernel to completion; its last
        # yielded value is the expected final memory image.
        ref_mem = None
        for ref_mem in reference_kernel2(list(mem)):
            pass

        if ref_mem is None:
            return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": "Reference failed"}

        kb = exec_globals["OptimizedKernelBuilder"]()
        kb.build_kernel(10, len(forest.values), 256, 16)
        machine = Machine(list(mem), kb.instrs, kb.debug_info(), n_cores=N_CORES)
        machine.enable_pause = False
        machine.enable_debug = False
        machine.run()

        cycles = machine.cycle

        # Sanity window: <=100 cycles means the kernel did no real work;
        # >200000 means it is slower than the unoptimized baseline range.
        if cycles <= 100 or cycles > 200000:
            return {"score": 0.0, "correctness": 0.0, "cycles": cycles, "msg": f"Bad cycles: {cycles}"}

        # ref_mem[6] appears to hold the pointer to the output-values region
        # (memory layout defined by build_mem_image -- TODO confirm index).
        inp_values_p = ref_mem[6]
        expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
        actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]

        if expected != actual:
            return {"score": 0.0, "correctness": 0.0, "cycles": cycles, "msg": "Wrong output"}

        score = SCORE_SCALE / cycles
        return {"score": score, "correctness": 1.0, "cycles": cycles, "msg": f"OK: {cycles} cycles"}

    except Exception as e:
        # Any failure in candidate code is a zero-reward result, not a crash
        # of the training loop; truncate the message for the UI log.
        return {"score": 0.0, "correctness": 0.0, "cycles": None, "msg": f"Error: {str(e)[:100]}"}
132
+
133
+
134
def log(msg: str):
    """Append a timestamped line to the shared training log and echo it to stdout.

    Acquires training_state_lock -- callers must NOT already hold the lock.
    """
    line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    with training_state_lock:
        training_state["log"].append(line)
    print(line)
141
+
142
+
143
def reward_function(completions: list[str], **kwargs) -> list[float]:
    """Score each completion: verifier score plus a +1.0 correctness bonus.

    Also tracks the best (lowest-cycle) verified kernel in training_state.
    Returns one float per completion; any failure scores 0.0.
    """
    rewards = []
    for completion in completions:
        try:
            code = extract_code_block(completion)
            result = verify_code(code)
            reward = result["score"]

            if result["correctness"] > 0:
                reward += 1.0  # flat bonus for any correct kernel
                cycles = result.get("cycles")
                if cycles:
                    new_best = False
                    with training_state_lock:
                        training_state["results"].append({"cycles": cycles, "time": time.time()})
                        if cycles < training_state["best_cycles"]:
                            training_state["best_cycles"] = cycles
                            training_state["best_code"] = code
                            new_best = True
                    # BUGFIX: log() acquires training_state_lock itself, and
                    # threading.Lock is not reentrant -- calling it while still
                    # holding the lock (as the original did) deadlocks on the
                    # first "NEW BEST". Log only after releasing the lock.
                    if new_best:
                        log(f"NEW BEST: {cycles:,} cycles ({BASELINE_CYCLES/cycles:.2f}x)")

            rewards.append(reward)
        except Exception as e:
            log(f"Reward error: {str(e)[:50]}")
            rewards.append(0.0)

    return rewards
169
+
170
+
171
def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lora_rank: int):
    """Main training loop.

    Runs GRPO fine-tuning of `model_name` (with a LoRA adapter) against the
    VLIW-kernel reward function, streaming progress into the module-level
    `training_state`. Intended to run in a background thread; see
    `start_training` / `stop_training`.
    """
    # Reset all shared state for a fresh run.
    with training_state_lock:
        training_state["running"] = True
        training_state["step"] = 0
        training_state["total_steps"] = num_steps
        training_state["best_cycles"] = BASELINE_CYCLES
        training_state["best_code"] = None
        training_state["log"] = []
        training_state["results"] = []
        training_state["start_time"] = time.time()

    log(f"Starting: {model_name}")
    log(f"Steps: {num_steps}, Batch: {batch_size}, LR: {lr}")

    try:
        # Heavy ML imports are deferred so the UI can come up (and report
        # errors) even when these packages are missing or broken.
        import torch
        from datasets import Dataset
        from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
        from peft import LoraConfig
        from trl import GRPOConfig, GRPOTrainer

        if torch.cuda.is_available():
            log(f"GPU: {torch.cuda.get_device_name(0)}")
        else:
            log("WARNING: No GPU!")

        log("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Models without a dedicated pad token reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        # Single fixed prompt repeated 32 times -- GRPO samples multiple
        # generations per prompt and compares them against each other.
        prompt = f"{SYSTEM_PROMPT}\n\nCURRENT: {BASELINE_CYCLES} cycles. TARGET: <{TARGET_CYCLES}."
        dataset = Dataset.from_dict({"prompt": [prompt] * 32})

        peft_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank * 2,  # common alpha = 2*r heuristic
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        output_dir = f"./output/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): the UI slider is labeled "Steps" but is passed as
        # num_train_epochs here -- confirm this is intentional.
        training_args = GRPOConfig(
            output_dir=output_dir,
            num_train_epochs=num_steps,
            per_device_train_batch_size=batch_size,
            learning_rate=lr,
            logging_steps=1,
            save_steps=max(1, num_steps // 5),  # ~5 checkpoints per run
            max_completion_length=2048,
            temperature=0.7,
            num_generations=4,  # completions compared per prompt by GRPO
            beta=0.1,           # KL penalty coefficient
            bf16=True,
            report_to="none",
        )

        # 7B models need 4-bit quantization to fit the Space's GPU memory.
        quant_config = None
        if "7B" in model_name or "7b" in model_name:
            log("Using 4-bit quantization")
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

        log("Loading model...")
        model_kwargs = {}
        if quant_config:
            model_kwargs["quantization_config"] = quant_config

        # Cooperative stop: checks the shared flag after every optimizer step.
        # NOTE(review): reads training_state["running"] without the lock --
        # a single dict read is atomic under the GIL, but take the lock if
        # this ever grows beyond one key.
        class StopCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                if not training_state["running"]:
                    log("Stopping...")
                    control.should_training_stop = True
                return control

        # Passing the model name as a string lets GRPOTrainer handle loading,
        # applying model_init_kwargs (quantization) and the LoRA adapter.
        trainer = GRPOTrainer(
            model=model_name,
            reward_funcs=[reward_function],
            args=training_args,
            train_dataset=dataset,
            peft_config=peft_config,
            processing_class=tokenizer,
            model_init_kwargs=model_kwargs,
            callbacks=[StopCallback()],
        )

        log("Model loaded! Training...")
        trainer.train()
        log("Training complete!")

        trainer.save_model(os.path.join(output_dir, "final"))
        log(f"Saved to {output_dir}")

        # Persist the best verified kernel alongside the adapter weights.
        if training_state["best_code"]:
            with open(os.path.join(output_dir, "best_code.py"), "w") as f:
                f.write(training_state["best_code"])

    except Exception as e:
        import traceback
        log(f"ERROR: {e}")
        log(traceback.format_exc()[:500])  # truncated for the UI log panel

    finally:
        # Always clear the running flag and report the run summary.
        with training_state_lock:
            training_state["running"] = False
            elapsed = time.time() - (training_state["start_time"] or time.time())
            best = training_state["best_cycles"]
        # log() re-acquires training_state_lock, so it must be called only
        # after the with-block above has released it.
        log(f"Time: {elapsed/60:.1f} min, Best: {best:,} cycles")
288
+
289
+
290
def start_training(model_name, num_steps, batch_size, lr, lora_rank):
    """Spawn run_training in a background thread.

    Returns a short status string for the UI. Refuses to start a second
    run while one is in progress.
    """
    # BUGFIX: check-and-set under the lock. The original read
    # training_state["running"] without the lock, so two rapid clicks could
    # both pass the check before run_training set the flag, launching two
    # concurrent training threads.
    with training_state_lock:
        if training_state["running"]:
            return "Already running!"
        training_state["running"] = True
    thread = threading.Thread(
        target=run_training,
        args=(model_name, int(num_steps), int(batch_size), float(lr), int(lora_rank)),
        daemon=False,  # non-daemon so an in-flight checkpoint save can finish
    )
    thread.start()
    return "Training started!"
300
+
301
+
302
def stop_training():
    """Ask the training loop to halt at the next step boundary (see StopCallback)."""
    with training_state_lock:
        training_state["running"] = False
    return "Stop signal sent."
306
+
307
+
308
def get_status():
    """Render current training progress as a markdown panel for the UI."""
    # Snapshot everything we need under the lock, then format lock-free.
    with training_state_lock:
        started = training_state["start_time"]
        if not started:
            return "### Not started"
        elapsed = time.time() - started
        best = max(training_state["best_cycles"], 1)  # guard divide-by-zero
        is_running = training_state["running"]
        logs = training_state["log"][-20:]  # last 20 lines only

    state_word = "Running" if is_running else "Stopped"
    speedup = BASELINE_CYCLES / best
    log_text = "\n".join(logs)
    return (
        f"### {state_word}\n"
        "| Metric | Value |\n"
        "|--------|-------|\n"
        f"| Time | {elapsed/60:.1f} min |\n"
        f"| Best | **{best:,}** cycles |\n"
        f"| Speedup | **{speedup:.2f}x** |\n"
        f"| Target | {TARGET_CYCLES:,} |\n"
        "\n"
        "```\n"
        f"{log_text}\n"
        "```"
    )
329
+
330
+
331
def get_best_code():
    """Return the best verified kernel source, or a placeholder comment."""
    with training_state_lock:
        best = training_state["best_code"]
    # `or` (not an is-None check) keeps the original falsy semantics.
    return best or "# No valid code yet"
334
+
335
+
336
# UI: left column = hyperparameters + start/stop, right column = live status,
# bottom row = best verified kernel source.
with gr.Blocks(title="VLIW Optimizer") as demo:
    gr.Markdown("# VLIW Kernel Optimizer via RL")
    gr.Markdown(f"**Baseline:** {BASELINE_CYCLES:,} | **Target:** {TARGET_CYCLES:,} (108x speedup)")

    with gr.Row():
        with gr.Column():
            model = gr.Dropdown(
                ["Qwen/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct"],
                value="Qwen/Qwen2.5-Coder-3B-Instruct",
                label="Model"
            )
            steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
            batch = gr.Slider(1, 8, value=4, step=1, label="Batch")
            lr = gr.Number(value=2e-4, label="LR")
            lora = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
            with gr.Row():
                start_btn = gr.Button("Start", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")

        with gr.Column():
            status = gr.Markdown("### Not started")
            refresh = gr.Button("Refresh")

    with gr.Row():
        code_out = gr.Code(label="Best Code", language="python", lines=20)
        code_btn = gr.Button("Show Best Code")

    # Wire controls to the module-level handlers defined above.
    start_btn.click(start_training, [model, steps, batch, lr, lora], [status])
    stop_btn.click(stop_training, outputs=[status])
    refresh.click(get_status, outputs=[status])
    code_btn.click(get_best_code, outputs=[code_out])
    # NOTE(review): `every=5` polls get_status every 5s; this keyword works in
    # Gradio 4.x but later versions moved polling to gr.Timer -- confirm the
    # pinned gradio version supports it.
    demo.load(get_status, outputs=[status], every=5)
369
 
370
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 -- the address HF Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)