CreativeEngineer committed on
Commit e721a4b · 1 Parent(s): b03b587

Minimal test version

Files changed (1)
  1. app.py +64 -264
app.py CHANGED
@@ -1,301 +1,101 @@
  """
  HF Spaces app for VLIW kernel optimization via RL.
- Deploy to HF Spaces Pro (A10G GPU).
  """
- import os
- import sys
- import re
- import threading
- import time
- from datetime import datetime
-
  import gradio as gr

- # Thread lock for safe state access
- training_state_lock = threading.Lock()
-
- # Constants
- BASELINE_CYCLES = 147734
- TARGET_CYCLES = 1363
-
- # Training state
- training_state = {
-     "running": False,
-     "best_cycles": BASELINE_CYCLES,
-     "best_code": None,
-     "log": [],
-     "start_time": None,
- }
-
- SYSTEM_PROMPT = '''Write optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
-
- ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle.
-
- API:
- - alloc_scratch(name, length) -> addr
- - add(engine, slot): engine in {alu, valu, load, store, flow}
- - valu ops work on 8 elements at once
- - build(slots, vliw=True): pack into VLIW bundle
-
- ALGORITHM: 16 rounds x 256 items, hash tree traversal.
-
- OPTIMIZATION: Use vload/vstore (8 elements), pack 6 VALU ops/cycle, unroll loops.
-
- Write complete code with OptimizedKernelBuilder class and run() function.
- '''
-
-
- def extract_code_block(text: str) -> str:
-     """Extract python code from markdown."""
-     pattern = r"```python\s*(.*?)```"
-     matches = re.findall(pattern, text, re.DOTALL)
-     if matches:
-         return matches[-1].strip()
-     pattern = r"```\s*(.*?)```"
-     matches = re.findall(pattern, text, re.DOTALL)
-     if matches:
-         return matches[-1].strip()
-     return text.strip()

-
- def simple_verify(code: str) -> dict:
-     """Simple verification without simulator."""
-     if not code:
-         return {"score": 0.0, "correctness": 0.0, "msg": "Empty"}
-     if "def run" not in code:
-         return {"score": 0.0, "correctness": 0.0, "msg": "No run()"}
-     if "OptimizedKernelBuilder" not in code:
-         return {"score": 0.0, "correctness": 0.0, "msg": "No class"}
-     if "build_kernel" not in code:
-         return {"score": 0.0, "correctness": 0.0, "msg": "No build_kernel"}
-     if "self.add" not in code:
-         return {"score": 0.1, "correctness": 0.5, "msg": "Structural OK"}
-     # Bonus for using vector ops
-     score = 0.2
-     if "vload" in code or "vstore" in code:
-         score += 0.3
-     if "valu" in code:
-         score += 0.3
-     return {"score": score, "correctness": 1.0, "msg": "Good structure"}
-
-
- def log(msg: str):
-     """Thread-safe logging."""
-     timestamp = datetime.now().strftime("%H:%M:%S")
-     formatted = f"[{timestamp}] {msg}"
-     with training_state_lock:
-         training_state["log"].append(formatted)
-     print(formatted)
-
-
- def reward_function(completions: list[str], **kwargs) -> list[float]:
-     """Compute rewards."""
-     rewards = []
-     for completion in completions:
-         try:
-             code = extract_code_block(completion)
-             result = simple_verify(code)
-             reward = result["score"]
-             if result["correctness"] > 0.5:
-                 reward += 0.5
-             with training_state_lock:
-                 if not training_state["best_code"] or len(code) > len(training_state["best_code"] or ""):
-                     training_state["best_code"] = code
-                     log(f"New best code (score: {reward:.2f})")
-             rewards.append(reward)
-         except Exception as e:
-             rewards.append(0.0)
-     return rewards


- def run_training(model_name: str, num_steps: int, batch_size: int, lr: float, lora_rank: int):
-     """Main training loop."""
-     with training_state_lock:
-         training_state["running"] = True
-         training_state["best_cycles"] = BASELINE_CYCLES
-         training_state["best_code"] = None
-         training_state["log"] = []
-         training_state["start_time"] = time.time()

-     log(f"Starting: {model_name}")
-     log(f"Steps: {num_steps}, Batch: {batch_size}, LR: {lr}")

      try:
          import torch
          from datasets import Dataset
-         from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
          from peft import LoraConfig
          from trl import GRPOConfig, GRPOTrainer

-         if torch.cuda.is_available():
-             log(f"GPU: {torch.cuda.get_device_name(0)}")
-         else:
-             log("WARNING: No GPU!")

-         log("Loading tokenizer...")
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
          if tokenizer.pad_token is None:
              tokenizer.pad_token = tokenizer.eos_token

-         prompt = f"{SYSTEM_PROMPT}\n\nCURRENT: {BASELINE_CYCLES} cycles. TARGET: <{TARGET_CYCLES}."
-         dataset = Dataset.from_dict({"prompt": [prompt] * 16})
-
-         peft_config = LoraConfig(
-             r=lora_rank,
-             lora_alpha=lora_rank * 2,
-             lora_dropout=0.05,
-             bias="none",
-             task_type="CAUSAL_LM",
-             target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-         )
-
-         output_dir = f"./output/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
-         os.makedirs(output_dir, exist_ok=True)

-         training_args = GRPOConfig(
-             output_dir=output_dir,
-             num_train_epochs=num_steps,
-             per_device_train_batch_size=batch_size,
-             learning_rate=lr,
-             logging_steps=1,
-             save_steps=max(1, num_steps // 5),
-             max_completion_length=1024,
-             temperature=0.7,
-             num_generations=2,
-             beta=0.1,
-             bf16=True,
              report_to="none",
          )

-         quant_config = None
-         if "7B" in model_name or "7b" in model_name:
-             log("Using 4-bit quantization")
-             quant_config = BitsAndBytesConfig(
-                 load_in_4bit=True,
-                 bnb_4bit_compute_dtype=torch.bfloat16,
-                 bnb_4bit_use_double_quant=True,
-                 bnb_4bit_quant_type="nf4",
-             )
-
-         log("Loading model...")
-         model_kwargs = {}
-         if quant_config:
-             model_kwargs["quantization_config"] = quant_config
-
-         class StopCallback(TrainerCallback):
-             def on_step_end(self, args, state, control, **kwargs):
-                 if not training_state["running"]:
-                     log("Stopping...")
-                     control.should_training_stop = True
-                 return control
-
-         trainer = GRPOTrainer(
-             model=model_name,
-             reward_funcs=[reward_function],
-             args=training_args,
-             train_dataset=dataset,
-             peft_config=peft_config,
-             processing_class=tokenizer,
-             model_init_kwargs=model_kwargs,
-             callbacks=[StopCallback()],
-         )
-
-         log("Model loaded! Training...")
-         trainer.train()
-         log("Training complete!")
-
-         trainer.save_model(os.path.join(output_dir, "final"))
-         log(f"Saved to {output_dir}")
-
-         if training_state["best_code"]:
-             with open(os.path.join(output_dir, "best_code.py"), "w") as f:
-                 f.write(training_state["best_code"])

      except Exception as e:
          import traceback
-         log(f"ERROR: {e}")
-         log(traceback.format_exc()[:500])
-
-     finally:
-         with training_state_lock:
-             training_state["running"] = False
-         elapsed = time.time() - (training_state["start_time"] or time.time())
-         log(f"Time: {elapsed/60:.1f} min")
-
-
- def start_training(model_name, num_steps, batch_size, lr, lora_rank):
-     if training_state["running"]:
-         return "Already running!"
-     thread = threading.Thread(
-         target=run_training,
-         args=(model_name, int(num_steps), int(batch_size), float(lr), int(lora_rank)),
-         daemon=False
-     )
-     thread.start()
-     return "Training started!"
-

- def stop_training():
-     with training_state_lock:
-         training_state["running"] = False
-     return "Stop signal sent."

-
- def get_status():
-     with training_state_lock:
-         if not training_state["start_time"]:
-             return "### Not started\nClick Start to begin training."
-         elapsed = time.time() - training_state["start_time"]
-         is_running = training_state["running"]
-         logs = training_state["log"][-25:]
-
-         return f"""### {'Running' if is_running else 'Stopped'}
- **Time:** {elapsed/60:.1f} min
-
- **Log:**
- ```
- {chr(10).join(logs)}
- ```"""
-
-
- def get_best_code():
-     with training_state_lock:
-         return training_state["best_code"] or "# No valid code yet"
-
-
- # UI
- with gr.Blocks(title="VLIW Optimizer") as demo:
-     gr.Markdown("# VLIW Kernel Optimizer via RL")
-     gr.Markdown(f"**Baseline:** {BASELINE_CYCLES:,} | **Target:** {TARGET_CYCLES:,}")

      with gr.Row():
-         with gr.Column():
-             model = gr.Dropdown(
-                 ["Qwen/Qwen2.5-Coder-3B-Instruct", "Qwen/Qwen2.5-Coder-1.5B-Instruct"],
-                 value="Qwen/Qwen2.5-Coder-3B-Instruct",
-                 label="Model"
-             )
-             steps = gr.Slider(1, 50, value=10, step=1, label="Steps")
-             batch = gr.Slider(1, 4, value=2, step=1, label="Batch")
-             lr = gr.Number(value=2e-4, label="LR")
-             lora = gr.Slider(8, 32, value=16, step=8, label="LoRA Rank")
-             with gr.Row():
-                 start_btn = gr.Button("Start", variant="primary")
-                 stop_btn = gr.Button("Stop", variant="stop")
-
-         with gr.Column():
-             status = gr.Markdown("### Not started\nClick Start to begin training.")
-             refresh = gr.Button("Refresh")

      with gr.Row():
-         code_out = gr.Code(label="Best Code", language="python", lines=20)
-         code_btn = gr.Button("Show Best Code")

-     start_btn.click(start_training, [model, steps, batch, lr, lora], [status])
-     stop_btn.click(stop_training, outputs=[status])
-     refresh.click(get_status, outputs=[status])
-     code_btn.click(get_best_code, outputs=[code_out])
-     demo.load(get_status, outputs=[status], every=5)

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)
  """
  HF Spaces app for VLIW kernel optimization via RL.
+ Minimal working version.
  """
  import gradio as gr

+ # Check all imports at startup
+ startup_log = []

+ def check_import(name, import_fn):
+     try:
+         result = import_fn()
+         startup_log.append(f"✓ {name}: {result}")
+         return True
+     except Exception as e:
+         startup_log.append(f"✗ {name}: {str(e)[:80]}")
+         return False
+
+ check_import("torch", lambda: __import__("torch").__version__)
+ check_import("transformers", lambda: __import__("transformers").__version__)
+ check_import("datasets", lambda: __import__("datasets").__version__)
+ check_import("peft", lambda: __import__("peft").__version__)
+ check_import("trl", lambda: __import__("trl").__version__)
+
+ try:
+     from trl import GRPOConfig, GRPOTrainer
+     startup_log.append("✓ GRPOTrainer: OK")
+ except Exception as e:
+     startup_log.append(f"✗ GRPOTrainer: {e}")
+
+ try:
+     import torch
+     if torch.cuda.is_available():
+         startup_log.append(f"✓ CUDA: {torch.cuda.get_device_name(0)}")
+     else:
+         startup_log.append("✗ CUDA: Not available")
+ except Exception as e:
+     startup_log.append(f"✗ CUDA check: {e}")


+ def get_status():
+     return "\n".join(startup_log)


+ def test_train(model_name, steps):
+     """Test training function."""
      try:
          import torch
          from datasets import Dataset
+         from transformers import AutoTokenizer
          from peft import LoraConfig
          from trl import GRPOConfig, GRPOTrainer

+         log = [f"Testing with {model_name}, {steps} steps"]
+         log.append(f"CUDA available: {torch.cuda.is_available()}")

+         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
          if tokenizer.pad_token is None:
              tokenizer.pad_token = tokenizer.eos_token
+         log.append("Tokenizer loaded")

+         dataset = Dataset.from_dict({"prompt": ["Write hello world"] * 4})
+         log.append("Dataset created")

+         # Just check config creation
+         config = GRPOConfig(
+             output_dir="./test",
+             num_train_epochs=1,
+             per_device_train_batch_size=1,
              report_to="none",
          )
+         log.append("Config created")

+         return "\n".join(log)

      except Exception as e:
          import traceback
+         return f"Error: {e}\n\n{traceback.format_exc()}"


+ with gr.Blocks(title="VLIW Test") as demo:
+     gr.Markdown("# VLIW Optimizer - Test Mode")

      with gr.Row():
+         status_box = gr.Textbox(label="Startup Status", value=get_status(), lines=15)

      with gr.Row():
+         model = gr.Dropdown(
+             ["Qwen/Qwen2.5-Coder-1.5B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct"],
+             value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
+             label="Model"
+         )
+         steps = gr.Slider(1, 5, value=1, label="Steps")
+
+         test_btn = gr.Button("Test Training Setup", variant="primary")
+         output = gr.Textbox(label="Output", lines=15)

+     test_btn.click(test_train, [model, steps], [output])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)