CreativeEngineer committed on
Commit
cccd413
·
1 Parent(s): e721a4b

Add training functionality with GRPO

Browse files
Files changed (1) hide show
  1. app.py +192 -33
app.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
3
- Minimal working version.
4
  """
5
  import gradio as gr
 
6
 
7
- # Check all imports at startup
8
  startup_log = []
9
 
10
  def check_import(name, import_fn):
@@ -37,65 +37,224 @@ try:
37
  except Exception as e:
38
  startup_log.append(f"✗ CUDA check: {e}")
39
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def get_status():
42
  return "\n".join(startup_log)
43
 
44
 
45
- def test_train(model_name, steps):
46
- """Test training function."""
47
- try:
48
- import torch
49
- from datasets import Dataset
50
- from transformers import AutoTokenizer
51
- from peft import LoraConfig
52
- from trl import GRPOConfig, GRPOTrainer
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- log = [f"Testing with {model_name}, {steps} steps"]
55
- log.append(f"CUDA available: {torch.cuda.is_available()}")
 
 
 
56
 
 
 
57
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
58
  if tokenizer.pad_token is None:
59
  tokenizer.pad_token = tokenizer.eos_token
60
- log.append("Tokenizer loaded")
61
 
62
- dataset = Dataset.from_dict({"prompt": ["Write hello world"] * 4})
63
- log.append("Dataset created")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Just check config creation
66
  config = GRPOConfig(
67
- output_dir="./test",
68
  num_train_epochs=1,
 
69
  per_device_train_batch_size=1,
 
 
 
70
  report_to="none",
 
 
 
 
 
 
 
 
 
 
71
  )
72
- log.append("Config created")
73
 
74
- return "\n".join(log)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  except Exception as e:
77
  import traceback
78
- return f"Error: {e}\n\n{traceback.format_exc()}"
 
 
 
79
 
80
 
81
- with gr.Blocks(title="VLIW Test") as demo:
82
- gr.Markdown("# VLIW Optimizer - Test Mode")
 
 
 
83
 
84
- with gr.Row():
85
- status_box = gr.Textbox(label="Startup Status", value=get_status(), lines=15)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  with gr.Row():
88
- model = gr.Dropdown(
89
- ["Qwen/Qwen2.5-Coder-1.5B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct"],
90
- value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
91
- label="Model"
92
- )
93
- steps = gr.Slider(1, 5, value=1, label="Steps")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- test_btn = gr.Button("Test Training Setup", variant="primary")
96
- output = gr.Textbox(label="Output", lines=15)
 
97
 
98
- test_btn.click(test_train, [model, steps], [output])
 
99
 
100
  if __name__ == "__main__":
101
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  """
2
  HF Spaces app for VLIW kernel optimization via RL.
 
3
  """
4
  import gradio as gr
5
+ import threading
6
 
7
+ # Check imports at startup
8
  startup_log = []
9
 
10
  def check_import(name, import_fn):
 
37
  except Exception as e:
38
  startup_log.append(f"✗ CUDA check: {e}")
39
 
40
# Shared mutable training state. It is read and written from both the Gradio
# request threads and the background training thread, so every access must
# hold training_lock.
training_state = {
    "is_training": False,  # True while run_training is active
    "should_stop": False,  # set by stop_training to request a cooperative halt
    "current_step": 0,     # last completed training step
    # presumably the best kernel cycle count found so far — never updated in
    # the visible code; TODO confirm intended use
    "best_cycles": float("inf"),
    "log": [],             # progress messages shown in the UI
}
training_lock = threading.Lock()
49
+
50
 
51
def get_status():
    """Return the startup import/CUDA check log as one newline-joined string."""
    return "\n".join(startup_log)
53
 
54
 
55
+ def reward_fn(completions, **kwargs):
56
+ """Simple reward function for testing."""
57
+ rewards = []
58
+ for completion in completions:
59
+ # Reward longer, code-like completions
60
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
61
+ score = min(len(text) / 100.0, 1.0) # Simple length-based reward
62
+ if "def " in text or "for " in text or "if " in text:
63
+ score += 0.5
64
+ rewards.append(score)
65
+ return rewards
66
+
67
+
68
def run_training(model_name, num_steps, progress_callback):
    """Run GRPO RL fine-tuning in the current thread.

    Loads the model with 4-bit quantization, attaches LoRA adapters, and
    trains for `num_steps` optimizer steps, reporting status strings through
    `progress_callback`. Shared `training_state` is updated under
    `training_lock` so UI threads can poll progress and request a stop.

    Args:
        model_name: HF Hub id of the causal LM to fine-tune.
        num_steps: Number of optimizer steps (becomes GRPOConfig.max_steps).
        progress_callback: Callable taking one status-message string.
    """
    import torch
    from datasets import Dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainerCallback,
    )
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    with training_lock:
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["current_step"] = 0
        training_state["log"] = ["Starting training..."]

    class _ProgressCallback(TrainerCallback):
        """Per-step progress reporting and cooperative stop handling."""

        def on_step_end(self, args, state, control, **kwargs):
            with training_lock:
                training_state["current_step"] = state.global_step
                should_stop = training_state["should_stop"]
            progress_callback(f"Step {state.global_step}/{num_steps} completed")
            if should_stop:
                progress_callback("Training stopped by user")
                control.should_training_stop = True
            return control

    try:
        progress_callback("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        progress_callback("Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        progress_callback("Creating dataset...")
        prompts = [
            "Write optimized VLIW assembly for matrix multiplication",
            "Generate SIMD code for vector addition",
            "Create parallel code for reduction operation",
            "Write efficient loop for memory copy",
        ] * 4  # 16 prompts
        dataset = Dataset.from_dict({"prompt": prompts})

        progress_callback("Setting up LoRA config...")
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        progress_callback("Creating trainer...")
        config = GRPOConfig(
            output_dir="./grpo_output",
            num_train_epochs=1,
            max_steps=num_steps,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=1e-5,
            logging_steps=1,
            report_to="none",
            remove_unused_columns=False,
        )
        trainer = GRPOTrainer(
            model=model,
            args=config,
            train_dataset=dataset,
            reward_funcs=reward_fn,
            peft_config=lora_config,
            processing_class=tokenizer,
            callbacks=[_ProgressCallback()],
        )

        progress_callback("Starting training loop...")
        # BUG FIX: trainer.train() runs the WHOLE schedule (max_steps) in one
        # call. The previous per-step loop re-ran full training up to
        # num_steps times and only checked the stop flag between full runs.
        # Per-step reporting and stop handling now live in _ProgressCallback.
        trainer.train()
        progress_callback("Training complete!")

    except Exception as e:
        import traceback
        progress_callback(f"Error: {e}\n{traceback.format_exc()}")
    finally:
        # Always clear the flag so the UI can start a new run even after a crash.
        with training_lock:
            training_state["is_training"] = False
167
 
168
 
169
def start_training(model_name, num_steps):
    """Launch a training run on a background thread.

    Refuses to start when a run is already active. Progress messages from the
    worker are mirrored into training_state["log"] under training_lock so the
    UI can poll them.
    """
    with training_lock:
        if training_state["is_training"]:
            return "Training already in progress"

    messages = []

    def record_progress(msg):
        messages.append(msg)
        with training_lock:
            training_state["log"] = list(messages)

    worker = threading.Thread(
        target=run_training,
        args=(model_name, int(num_steps), record_progress),
        daemon=False,
    )
    worker.start()
    return "Training started! Check progress below."
189
+
190
+
191
def stop_training():
    """Ask the background training loop to halt at its next opportunity."""
    with training_lock:
        if training_state["is_training"]:
            training_state["should_stop"] = True
            return "Stop requested..."
    return "No training in progress"
198
+
199
+
200
def get_progress():
    """Render the most recent training log messages for the UI textbox."""
    with training_lock:
        messages = training_state["log"]
        if not messages:
            return "No training started yet"
        # Keep the display short: only the 20 newest messages.
        return "\n".join(messages[-20:])
206
+
207
+
208
# Gradio UI: status panel, model/steps controls, start/stop buttons, and a
# progress textbox that polls get_progress.
with gr.Blocks(title="VLIW Optimizer") as demo:
    gr.Markdown("# VLIW Kernel Optimizer - RL Training")
    gr.Markdown("Train a model to generate optimized VLIW/SIMD kernels using reinforcement learning.")

    with gr.Row():
        with gr.Column(scale=1):
            # Read-only view of the import/CUDA checks collected at startup.
            status_box = gr.Textbox(
                label="System Status",
                value=get_status(),
                lines=10,
                interactive=False,
            )

        with gr.Column(scale=2):
            model_dropdown = gr.Dropdown(
                choices=[
                    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
                    "Qwen/Qwen2.5-Coder-3B-Instruct",
                ],
                value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
                label="Model",
            )
            steps_slider = gr.Slider(
                minimum=1,
                maximum=100,
                value=10,
                step=1,
                label="Training Steps",
            )

    with gr.Row():
        start_btn = gr.Button("Start Training", variant="primary")
        stop_btn = gr.Button("Stop Training", variant="stop")

    output_box = gr.Textbox(
        label="Training Progress",
        lines=15,
        interactive=False,
    )

    # Manual refresh as a fallback to the timed auto-refresh below.
    refresh_btn = gr.Button("Refresh Progress")

    start_btn.click(start_training, [model_dropdown, steps_slider], [output_box])
    stop_btn.click(stop_training, [], [output_box])
    refresh_btn.click(get_progress, [], [output_box])

    # Auto-refresh every 5 seconds when training
    # NOTE(review): the `every=` kwarg on Blocks.load is deprecated/removed in
    # newer Gradio releases (replaced by gr.Timer) — confirm the pinned
    # gradio version supports it.
    demo.load(get_progress, [], [output_box], every=5)
258
 
259
if __name__ == "__main__":
    # Bind to all interfaces on 7860, the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)