CreativeEngineer committed on
Commit
a7473b9
·
1 Parent(s): cccd413

Simpler test version for model loading

Browse files
Files changed (1) hide show
  1. app.py +39 -155
app.py CHANGED
@@ -37,55 +37,33 @@ try:
37
  except Exception as e:
38
  startup_log.append(f"✗ CUDA check: {e}")
39
 
40
- # Training state
41
- training_state = {
42
- "is_training": False,
43
- "should_stop": False,
44
- "current_step": 0,
45
- "best_cycles": float("inf"),
46
- "log": [],
47
- }
48
- training_lock = threading.Lock()
49
 
50
 
51
  def get_status():
52
  return "\n".join(startup_log)
53
 
54
 
55
- def reward_fn(completions, **kwargs):
56
- """Simple reward function for testing."""
57
- rewards = []
58
- for completion in completions:
59
- # Reward longer, code-like completions
60
- text = completion[0]["content"] if isinstance(completion, list) else str(completion)
61
- score = min(len(text) / 100.0, 1.0) # Simple length-based reward
62
- if "def " in text or "for " in text or "if " in text:
63
- score += 0.5
64
- rewards.append(score)
65
- return rewards
66
 
 
 
 
67
 
68
- def run_training(model_name, num_steps, progress_callback):
69
- """Run RL training."""
70
- import torch
71
- from datasets import Dataset
72
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
73
- from peft import LoraConfig
74
- from trl import GRPOConfig, GRPOTrainer
75
-
76
- with training_lock:
77
- training_state["is_training"] = True
78
- training_state["should_stop"] = False
79
- training_state["current_step"] = 0
80
- training_state["log"] = ["Starting training..."]
81
 
82
- try:
83
- progress_callback("Loading tokenizer...")
84
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
85
  if tokenizer.pad_token is None:
86
  tokenizer.pad_token = tokenizer.eos_token
 
87
 
88
- progress_callback("Loading model with 4-bit quantization...")
89
  bnb_config = BitsAndBytesConfig(
90
  load_in_4bit=True,
91
  bnb_4bit_quant_type="nf4",
@@ -97,118 +75,40 @@ def run_training(model_name, num_steps, progress_callback):
97
  device_map="auto",
98
  trust_remote_code=True,
99
  )
 
100
 
101
- progress_callback("Creating dataset...")
102
- prompts = [
103
- "Write optimized VLIW assembly for matrix multiplication",
104
- "Generate SIMD code for vector addition",
105
- "Create parallel code for reduction operation",
106
- "Write efficient loop for memory copy",
107
- ] * 4 # 16 prompts
108
-
109
- dataset = Dataset.from_dict({"prompt": prompts})
110
-
111
- progress_callback("Setting up LoRA config...")
112
- lora_config = LoraConfig(
113
- r=16,
114
- lora_alpha=32,
115
- target_modules=["q_proj", "v_proj"],
116
- lora_dropout=0.05,
117
- bias="none",
118
- task_type="CAUSAL_LM",
119
- )
120
-
121
- progress_callback("Creating trainer...")
122
- config = GRPOConfig(
123
- output_dir="./grpo_output",
124
- num_train_epochs=1,
125
- max_steps=num_steps,
126
- per_device_train_batch_size=1,
127
- gradient_accumulation_steps=4,
128
- learning_rate=1e-5,
129
- logging_steps=1,
130
- report_to="none",
131
- remove_unused_columns=False,
132
- )
133
-
134
- trainer = GRPOTrainer(
135
- model=model,
136
- args=config,
137
- train_dataset=dataset,
138
- reward_funcs=reward_fn,
139
- peft_config=lora_config,
140
- processing_class=tokenizer,
141
- )
142
 
143
- progress_callback("Starting training loop...")
144
- for step in range(num_steps):
145
- with training_lock:
146
- if training_state["should_stop"]:
147
- progress_callback("Training stopped by user")
148
- break
149
- training_state["current_step"] = step + 1
150
 
151
- # Run one step
152
- try:
153
- trainer.train()
154
- progress_callback(f"Step {step + 1}/{num_steps} completed")
155
- except Exception as e:
156
- progress_callback(f"Step {step + 1} error: {str(e)[:100]}")
157
- break
158
-
159
- progress_callback("Training complete!")
160
 
161
  except Exception as e:
162
  import traceback
163
- progress_callback(f"Error: {e}\n{traceback.format_exc()}")
164
- finally:
165
- with training_lock:
166
- training_state["is_training"] = False
167
-
168
-
169
- def start_training(model_name, num_steps):
170
- """Start training in background thread."""
171
- with training_lock:
172
- if training_state["is_training"]:
173
- return "Training already in progress"
174
 
175
- log_messages = []
176
-
177
- def progress_callback(msg):
178
- log_messages.append(msg)
179
- with training_lock:
180
- training_state["log"] = log_messages.copy()
181
-
182
- thread = threading.Thread(
183
- target=run_training,
184
- args=(model_name, int(num_steps), progress_callback),
185
- daemon=False,
186
- )
187
- thread.start()
188
- return "Training started! Check progress below."
189
 
190
 
191
- def stop_training():
192
- """Request training stop."""
193
- with training_lock:
194
- if not training_state["is_training"]:
195
- return "No training in progress"
196
- training_state["should_stop"] = True
197
- return "Stop requested..."
198
-
199
-
200
- def get_progress():
201
- """Get current training progress."""
202
- with training_lock:
203
- if not training_state["log"]:
204
- return "No training started yet"
205
- return "\n".join(training_state["log"][-20:]) # Last 20 messages
206
 
207
 
208
  # Gradio UI
209
  with gr.Blocks(title="VLIW Optimizer") as demo:
210
- gr.Markdown("# VLIW Kernel Optimizer - RL Training")
211
- gr.Markdown("Train a model to generate optimized VLIW/SIMD kernels using reinforcement learning.")
212
 
213
  with gr.Row():
214
  with gr.Column(scale=1):
@@ -228,33 +128,17 @@ with gr.Blocks(title="VLIW Optimizer") as demo:
228
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
229
  label="Model",
230
  )
231
- steps_slider = gr.Slider(
232
- minimum=1,
233
- maximum=100,
234
- value=10,
235
- step=1,
236
- label="Training Steps",
237
- )
238
 
239
- with gr.Row():
240
- start_btn = gr.Button("Start Training", variant="primary")
241
- stop_btn = gr.Button("Stop Training", variant="stop")
242
 
243
  output_box = gr.Textbox(
244
- label="Training Progress",
245
  lines=15,
246
  interactive=False,
 
247
  )
248
 
249
- # Auto-refresh progress
250
- refresh_btn = gr.Button("Refresh Progress")
251
-
252
- start_btn.click(start_training, [model_dropdown, steps_slider], [output_box])
253
- stop_btn.click(stop_training, [], [output_box])
254
- refresh_btn.click(get_progress, [], [output_box])
255
-
256
- # Auto-refresh every 5 seconds when training
257
- demo.load(get_progress, [], [output_box], every=5)
258
 
259
  if __name__ == "__main__":
260
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
37
  except Exception as e:
38
  startup_log.append(f"✗ CUDA check: {e}")
39
 
40
+ # Global state
41
+ training_log = []
42
+ is_training = False
 
 
 
 
 
 
43
 
44
 
45
  def get_status():
46
  return "\n".join(startup_log)
47
 
48
 
49
+ def test_model_load(model_name):
50
+ """Test loading the model."""
51
+ global training_log
52
+ training_log = []
 
 
 
 
 
 
 
53
 
54
+ try:
55
+ import torch
56
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
57
 
58
+ training_log.append(f"Testing model: {model_name}")
59
+ training_log.append("Loading tokenizer...")
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
62
  if tokenizer.pad_token is None:
63
  tokenizer.pad_token = tokenizer.eos_token
64
+ training_log.append("✓ Tokenizer loaded")
65
 
66
+ training_log.append("Loading model with 4-bit quantization...")
67
  bnb_config = BitsAndBytesConfig(
68
  load_in_4bit=True,
69
  bnb_4bit_quant_type="nf4",
 
75
  device_map="auto",
76
  trust_remote_code=True,
77
  )
78
+ training_log.append("✓ Model loaded")
79
 
80
+ # Quick test
81
+ training_log.append("Testing generation...")
82
+ inputs = tokenizer("def hello():", return_tensors="pt").to(model.device)
83
+ outputs = model.generate(**inputs, max_new_tokens=20)
84
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+ training_log.append(f" Generation test: {result[:50]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ training_log.append("\n✓ All tests passed!")
 
 
 
 
 
 
88
 
89
+ # Cleanup
90
+ del model
91
+ torch.cuda.empty_cache()
 
 
 
 
 
 
92
 
93
  except Exception as e:
94
  import traceback
95
+ training_log.append(f"Error: {e}")
96
+ training_log.append(traceback.format_exc())
 
 
 
 
 
 
 
 
 
97
 
98
+ return "\n".join(training_log)
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
+ def get_log():
102
+ """Return current log."""
103
+ if not training_log:
104
+ return "No operations run yet. Click 'Test Model Loading' to start."
105
+ return "\n".join(training_log)
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  # Gradio UI
109
  with gr.Blocks(title="VLIW Optimizer") as demo:
110
+ gr.Markdown("# VLIW Kernel Optimizer - Test Mode")
111
+ gr.Markdown("Testing model loading and generation before full training.")
112
 
113
  with gr.Row():
114
  with gr.Column(scale=1):
 
128
  value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
129
  label="Model",
130
  )
 
 
 
 
 
 
 
131
 
132
+ test_btn = gr.Button("Test Model Loading", variant="primary")
 
 
133
 
134
  output_box = gr.Textbox(
135
+ label="Output",
136
  lines=15,
137
  interactive=False,
138
+ value="Click 'Test Model Loading' to verify the setup.",
139
  )
140
 
141
+ test_btn.click(test_model_load, [model_dropdown], [output_box])
 
 
 
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
  demo.launch(server_name="0.0.0.0", server_port=7860)