CreativeEngineer committed on
Commit
9c10799
·
1 Parent(s): a7473b9

Add GRPO training with proper state management

Browse files
Files changed (1) hide show
  1. app.py +187 -41
app.py CHANGED
@@ -3,6 +3,7 @@ HF Spaces app for VLIW kernel optimization via RL.
3
  """
4
  import gradio as gr
5
  import threading
 
6
 
7
  # Check imports at startup
8
  startup_log = []
@@ -37,33 +38,64 @@ try:
37
  except Exception as e:
38
  startup_log.append(f"✗ CUDA check: {e}")
39
 
40
- # Global state
41
- training_log = []
42
- is_training = False
 
 
 
 
43
 
44
 
45
  def get_status():
46
  return "\n".join(startup_log)
47
 
48
 
49
- def test_model_load(model_name):
50
- """Test loading the model."""
51
- global training_log
52
- training_log = []
 
 
 
 
 
 
53
 
54
- try:
55
- import torch
56
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
57
 
58
- training_log.append(f"Testing model: {model_name}")
59
- training_log.append("Loading tokenizer...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
62
  if tokenizer.pad_token is None:
63
  tokenizer.pad_token = tokenizer.eos_token
64
- training_log.append("✓ Tokenizer loaded")
65
 
66
- training_log.append("Loading model with 4-bit quantization...")
 
67
  bnb_config = BitsAndBytesConfig(
68
  load_in_4bit=True,
69
  bnb_4bit_quant_type="nf4",
@@ -75,47 +107,147 @@ def test_model_load(model_name):
75
  device_map="auto",
76
  trust_remote_code=True,
77
  )
78
- training_log.append("✓ Model loaded")
79
 
80
- # Quick test
81
- training_log.append("Testing generation...")
82
- inputs = tokenizer("def hello():", return_tensors="pt").to(model.device)
83
- outputs = model.generate(**inputs, max_new_tokens=20)
84
- result = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
- training_log.append(f" Generation test: {result[:50]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- training_log.append("\n✓ All tests passed!")
 
 
 
 
 
 
88
 
89
- # Cleanup
90
- del model
91
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  except Exception as e:
94
  import traceback
95
- training_log.append(f"✗ Error: {e}")
96
- training_log.append(traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- return "\n".join(training_log)
99
 
 
 
 
 
 
100
 
101
- def get_log():
102
- """Return current log."""
103
- if not training_log:
104
- return "No operations run yet. Click 'Test Model Loading' to start."
105
- return "\n".join(training_log)
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  # Gradio UI
109
  with gr.Blocks(title="VLIW Optimizer") as demo:
110
- gr.Markdown("# VLIW Kernel Optimizer - Test Mode")
111
- gr.Markdown("Testing model loading and generation before full training.")
 
 
 
 
 
 
 
112
 
113
  with gr.Row():
114
  with gr.Column(scale=1):
115
  status_box = gr.Textbox(
116
  label="System Status",
117
  value=get_status(),
118
- lines=10,
119
  interactive=False,
120
  )
121
 
@@ -128,17 +260,31 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
128
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
129
  label="Model",
130
  )
 
 
 
 
 
 
 
131
 
132
- test_btn = gr.Button("Test Model Loading", variant="primary")
 
 
133
 
134
  output_box = gr.Textbox(
135
- label="Output",
136
- lines=15,
137
  interactive=False,
138
- value="Click 'Test Model Loading' to verify the setup.",
139
  )
140
 
141
- test_btn.click(test_model_load, [model_dropdown], [output_box])
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  """
4
  import gradio as gr
5
  import threading
6
+ import time
7
 
8
  # Check imports at startup
9
  startup_log = []
 
38
  except Exception as e:
39
  startup_log.append(f"✗ CUDA check: {e}")
40
 
41
+ # Training state
42
+ training_state = {
43
+ "is_training": False,
44
+ "should_stop": False,
45
+ "log": [],
46
+ }
47
+ state_lock = threading.Lock()
48
 
49
 
50
def get_status():
    """Return the accumulated startup/system-check log, one entry per line."""
    status_lines = list(startup_log)
    return "\n".join(status_lines)
52
 
53
 
54
def simple_reward_fn(completions, **kwargs):
    """Simple reward: prefer longer, code-like outputs."""
    code_markers = ("def ", "for ", "if ", "while ", "return ")

    def score_one(completion):
        # Chat-style completions arrive as a list of message dicts;
        # anything else is coerced to a plain string.
        if isinstance(completion, list):
            text = completion[0]["content"]
        else:
            text = str(completion)
        # Length component, capped at 1.0 (reached at 200 characters).
        reward = min(len(text) / 200.0, 1.0)
        # Flat bonus when the output looks like code.
        if any(marker in text for marker in code_markers):
            reward += 0.3
        return reward

    return [score_one(completion) for completion in completions]
64
 
 
 
 
65
 
66
+ def run_training(model_name, num_steps):
67
+ """Run RL training with GRPO."""
68
+ import torch
69
+ from datasets import Dataset
70
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
71
+ from peft import LoraConfig
72
+ from trl import GRPOConfig, GRPOTrainer
73
+ from transformers import TrainerCallback
74
+
75
+ log = []
76
+
77
+ def add_log(msg):
78
+ log.append(f"[{time.strftime('%H:%M:%S')}] {msg}")
79
+ with state_lock:
80
+ training_state["log"] = log.copy()
81
+
82
+ with state_lock:
83
+ training_state["is_training"] = True
84
+ training_state["should_stop"] = False
85
+ training_state["log"] = []
86
+
87
+ try:
88
+ add_log(f"Starting training: {model_name}, {num_steps} steps")
89
 
90
+ # Load tokenizer
91
+ add_log("Loading tokenizer...")
92
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
93
  if tokenizer.pad_token is None:
94
  tokenizer.pad_token = tokenizer.eos_token
95
+ add_log("✓ Tokenizer ready")
96
 
97
+ # Load model with 4-bit quantization
98
+ add_log("Loading model (4-bit quantization)...")
99
  bnb_config = BitsAndBytesConfig(
100
  load_in_4bit=True,
101
  bnb_4bit_quant_type="nf4",
 
107
  device_map="auto",
108
  trust_remote_code=True,
109
  )
110
+ add_log(f"✓ Model loaded on {next(model.parameters()).device}")
111
 
112
+ # Create dataset
113
+ add_log("Creating training dataset...")
114
+ prompts = [
115
+ "Write optimized VLIW assembly code for matrix multiplication using SIMD instructions",
116
+ "Generate efficient parallel code for vector dot product",
117
+ "Create VLIW code for memory-bound reduction operation",
118
+ "Write pipelined code for element-wise array operations",
119
+ ] * 8 # 32 prompts total
120
+ dataset = Dataset.from_dict({"prompt": prompts})
121
+ add_log(f"✓ Dataset: {len(prompts)} prompts")
122
+
123
+ # LoRA config
124
+ add_log("Setting up LoRA...")
125
+ lora_config = LoraConfig(
126
+ r=16,
127
+ lora_alpha=32,
128
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
129
+ lora_dropout=0.05,
130
+ bias="none",
131
+ task_type="CAUSAL_LM",
132
+ )
133
 
134
+ # Stop callback
135
+ class StopCallback(TrainerCallback):
136
+ def on_step_end(self, args, state, control, **kwargs):
137
+ with state_lock:
138
+ if training_state["should_stop"]:
139
+ control.should_training_stop = True
140
+ return control
141
 
142
+ # GRPO config
143
+ add_log("Creating GRPO trainer...")
144
+ config = GRPOConfig(
145
+ output_dir="./grpo_output",
146
+ num_train_epochs=1,
147
+ max_steps=num_steps,
148
+ per_device_train_batch_size=2,
149
+ gradient_accumulation_steps=2,
150
+ learning_rate=5e-6,
151
+ logging_steps=1,
152
+ save_steps=999999, # Don't save checkpoints
153
+ report_to="none",
154
+ remove_unused_columns=False,
155
+ max_completion_length=128,
156
+ num_generations=4,
157
+ )
158
+
159
+ trainer = GRPOTrainer(
160
+ model=model,
161
+ args=config,
162
+ train_dataset=dataset,
163
+ reward_funcs=simple_reward_fn,
164
+ peft_config=lora_config,
165
+ processing_class=tokenizer,
166
+ callbacks=[StopCallback()],
167
+ )
168
+ add_log("✓ Trainer ready")
169
+
170
+ # Train
171
+ add_log("Starting training loop...")
172
+ train_result = trainer.train()
173
+
174
+ metrics = train_result.metrics
175
+ add_log(f"✓ Training complete!")
176
+ add_log(f" Steps: {metrics.get('train_steps', 'N/A')}")
177
+ add_log(f" Loss: {metrics.get('train_loss', 'N/A'):.4f}" if 'train_loss' in metrics else " Loss: N/A")
178
+
179
+ # Test generation
180
+ add_log("Testing trained model...")
181
+ test_prompt = "Write efficient VLIW code for:"
182
+ inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
183
+ with torch.no_grad():
184
+ outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
185
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
186
+ add_log(f"Sample output: {result[:100]}...")
187
+
188
+ add_log("\n✓ All done!")
189
 
190
  except Exception as e:
191
  import traceback
192
+ add_log(f"✗ Error: {e}")
193
+ add_log(traceback.format_exc()[:500])
194
+ finally:
195
+ with state_lock:
196
+ training_state["is_training"] = False
197
+ # Cleanup
198
+ try:
199
+ del model
200
+ torch.cuda.empty_cache()
201
+ except:
202
+ pass
203
+
204
+ return "\n".join(log)
205
 
 
206
 
207
def start_training(model_name, num_steps):
    """Start training (blocking for simplicity).

    Rejects the request when a run is already active. The busy flag is
    claimed inside the same critical section as the check, so two
    concurrent clicks cannot both pass the check and both start a run
    (the original released the lock between the check and the start,
    a check-then-act race). ``run_training`` clears the flag in its
    ``finally`` block on every normal path.
    """
    with state_lock:
        if training_state["is_training"]:
            return "Training already in progress. Please wait."
        # Claim the slot before releasing the lock to close the race.
        training_state["is_training"] = True

    try:
        return run_training(model_name, int(num_steps))
    except BaseException:
        # run_training can raise before reaching its own try/finally
        # (e.g. an import failure); don't leave the flag stuck True.
        with state_lock:
            training_state["is_training"] = False
        raise
214
+
215
+
216
def stop_training():
    """Request stop."""
    with state_lock:
        if training_state["is_training"]:
            # Cooperative stop: the trainer's step-end callback polls
            # this flag and halts the loop at the next step boundary.
            training_state["should_stop"] = True
            return "Stop requested. Training will stop after current step."
    return "No training in progress"
223
+
224
+
225
def get_progress():
    """Get current log."""
    with state_lock:
        entries = training_state["log"]
        return "\n".join(entries) if entries else "No training started yet"
231
 
232
 
233
  # Gradio UI
234
  with gr.Blocks(title="VLIW Optimizer") as demo:
235
+ gr.Markdown("# VLIW Kernel Optimizer - RL Training")
236
+ gr.Markdown("""
237
+ Train a language model with reinforcement learning to generate optimized VLIW/SIMD code.
238
+
239
+ **Instructions:**
240
+ 1. Select a model (1.5B is faster, 3B may produce better results)
241
+ 2. Set training steps (10-50 recommended for testing)
242
+ 3. Click 'Start Training' and wait for completion
243
+ """)
244
 
245
  with gr.Row():
246
  with gr.Column(scale=1):
247
  status_box = gr.Textbox(
248
  label="System Status",
249
  value=get_status(),
250
+ lines=9,
251
  interactive=False,
252
  )
253
 
 
260
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
261
  label="Model",
262
  )
263
+ steps_slider = gr.Slider(
264
+ minimum=1,
265
+ maximum=100,
266
+ value=10,
267
+ step=1,
268
+ label="Training Steps",
269
+ )
270
 
271
+ with gr.Row():
272
+ start_btn = gr.Button("Start Training", variant="primary")
273
+ stop_btn = gr.Button("Stop", variant="stop")
274
 
275
  output_box = gr.Textbox(
276
+ label="Training Log",
277
+ lines=20,
278
  interactive=False,
279
+ value="Click 'Start Training' to begin.",
280
  )
281
 
282
+ start_btn.click(
283
+ start_training,
284
+ [model_dropdown, steps_slider],
285
+ [output_box],
286
+ )
287
+ stop_btn.click(stop_training, [], [output_box])
288
 
289
  if __name__ == "__main__":
290
  demo.launch(server_name="0.0.0.0", server_port=7860)