RAMCr7 committed on
Commit
e19878b
·
1 Parent(s): 58f6308

Final train patch

Browse files
Files changed (1) hide show
  1. patchhawk/training/train_grpo.py +75 -67
patchhawk/training/train_grpo.py CHANGED
@@ -1,11 +1,14 @@
 
1
  """
2
  GRPO training pipeline for PatchHawk (trl 1.0.0, RTX 3060 12GB).
3
 
4
- Fixed:
5
- - Removed max_prompt_length / max_completion_length (unsupported in trl 1.0.0).
6
- - Disabled fp16 in GRPOConfig to avoid BFloat16 AMP error.
7
  - Set tokenizer.model_max_length for sequence length control.
8
- - Added custom callback to show loss in progress bar.
 
 
9
  """
10
 
11
  import argparse
@@ -36,6 +39,7 @@ def _build_prompt(scenario: dict) -> str:
36
 
37
 
38
  def train_agent(args):
 
39
  if not args.dry_run:
40
  try:
41
  from trl import GRPOTrainer, GRPOConfig
@@ -44,11 +48,19 @@ def train_agent(args):
44
  "trl not found.\nInstall: pip install trl==1.0.0 peft bitsandbytes accelerate transformers"
45
  ) from exc
46
 
 
47
  if not args.dry_run and wandb is not None:
48
- wandb.init(project="patchhawk", name="grpo-run", config=vars(args))
 
 
 
 
 
 
49
  else:
50
  print("[INFO] WandB skipped.")
51
 
 
52
  from patchhawk.agent.environment import PatchHawkEnv
53
 
54
  env = PatchHawkEnv(
@@ -61,6 +73,7 @@ def train_agent(args):
61
  _dry_run_training(env, args)
62
  return
63
 
 
64
  import torch
65
  from transformers import (
66
  AutoModelForCausalLM,
@@ -68,12 +81,7 @@ def train_agent(args):
68
  BitsAndBytesConfig,
69
  TrainerCallback,
70
  )
71
- from peft import (
72
- LoraConfig,
73
- TaskType,
74
- get_peft_model,
75
- prepare_model_for_kbit_training,
76
- )
77
  from datasets import Dataset
78
  from trl import GRPOConfig, GRPOTrainer
79
 
@@ -84,6 +92,7 @@ def train_agent(args):
84
 
85
  MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
86
 
 
87
  bnb_config = BitsAndBytesConfig(
88
  load_in_4bit=True,
89
  bnb_4bit_quant_type="nf4",
@@ -97,7 +106,7 @@ def train_agent(args):
97
  tokenizer.pad_token = tokenizer.eos_token
98
  tokenizer.padding_side = "left"
99
 
100
- # Set maximum total sequence length (prompt + generation)
101
  tokenizer.model_max_length = args.max_seq_len
102
 
103
  base_model = AutoModelForCausalLM.from_pretrained(
@@ -113,6 +122,7 @@ def train_agent(args):
113
  use_gradient_checkpointing=True,
114
  )
115
 
 
116
  lora_config = LoraConfig(
117
  task_type=TaskType.CAUSAL_LM,
118
  r=16,
@@ -120,19 +130,14 @@ def train_agent(args):
120
  lora_dropout=0.05,
121
  bias="none",
122
  target_modules=[
123
- "q_proj",
124
- "k_proj",
125
- "v_proj",
126
- "o_proj",
127
- "gate_proj",
128
- "up_proj",
129
- "down_proj",
130
  ],
131
  )
132
  model = get_peft_model(base_model, lora_config)
133
  model.print_trainable_parameters()
134
 
135
- # Reward 1: format (trl 1.0.0 expects completions as list of strings)
136
  def format_reward(completions, **kwargs):
137
  rewards = []
138
  for c in completions:
@@ -154,7 +159,7 @@ def train_agent(args):
154
  rewards.append(score)
155
  return rewards
156
 
157
- # Reward 2: environment
158
  from patchhawk.env_models import PatchHawkAction
159
 
160
  def env_reward(completions, prompts, **kwargs):
@@ -162,10 +167,8 @@ def train_agent(args):
162
  for prompt, c in zip(prompts, completions):
163
  text = c if isinstance(c, str) else str(c)
164
 
165
- # Find scenario by code snippet in prompt
166
- code_match = re.search(
167
- r"<code_snippet>(.*?)</code_snippet>", prompt, re.DOTALL
168
- )
169
  if not code_match:
170
  rewards.append(-2.0)
171
  continue
@@ -179,30 +182,30 @@ def train_agent(args):
179
  rewards.append(-2.0)
180
  continue
181
 
 
182
  action_match = re.search(r"<action>(\d+)</action>", text)
183
  if not action_match:
184
  rewards.append(-2.0)
185
  continue
186
  action_type = int(action_match.group(1))
187
 
 
188
  patch = None
189
  patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
190
  if patch_match:
191
  patch = patch_match.group(1).strip()
192
 
193
  try:
194
- # Reset environment to the specific scenario
195
  env.reset(scenario_idx=env.scenarios.index(scenario))
196
- obs = env.step(
197
- PatchHawkAction(action_type=action_type, patch_content=patch)
198
- )
199
  rewards.append(float(obs.reward or 0.0))
200
  except Exception as exc:
201
  print(f"env_reward crash: {exc}")
202
  rewards.append(-3.0)
203
  return rewards
204
 
205
- # Prepare dataset
206
  valid = [s for s in env.scenarios if s.get("label") in ("malicious", "benign")]
207
  random.seed(42)
208
  random.shuffle(valid)
@@ -212,32 +215,43 @@ def train_agent(args):
212
  eval_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[split:]])
213
  print(f"Dataset — train: {len(train_ds)}, eval: {len(eval_ds)}")
214
 
215
- # GRPOConfig – no max_prompt_length, no fp16, logging_steps=1 for frequent loss updates
216
  grpo_config = GRPOConfig(
217
  output_dir=args.output_dir,
218
  learning_rate=args.learning_rate,
219
  per_device_train_batch_size=args.batch_size,
220
  gradient_accumulation_steps=args.grad_accum,
221
- fp16=False, # FIX: disable mixed precision
222
  gradient_checkpointing=True,
223
  num_generations=args.group_size,
224
  beta=args.kl_coef,
225
  num_train_epochs=args.epochs,
226
  warmup_steps=10,
227
  max_grad_norm=1.0,
228
- logging_steps=1, # ← log loss every step
 
229
  save_steps=50,
230
  report_to="wandb" if (wandb is not None and not args.dry_run) else "none",
231
  )
232
 
233
- # ─────────────────────────────────────────────────────────
234
- # Custom callback to show loss in the tqdm progress bar
235
- # ─────────────────────────────────────────────────────────
236
- class LossProgressBarCallback(TrainerCallback):
237
  def on_log(self, args, state, control, logs=None, **kwargs):
238
- if logs and "loss" in logs:
239
- if hasattr(state, "progress_bar"):
240
- state.progress_bar.set_postfix({"loss": f"{logs['loss']:.4f}"})
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  trainer = GRPOTrainer(
243
  model=model,
@@ -246,21 +260,23 @@ def train_agent(args):
246
  train_dataset=train_ds,
247
  eval_dataset=eval_ds,
248
  )
249
-
250
- # Add the callback
251
- trainer.add_callback(LossProgressBarCallback())
252
 
253
  print("Starting GRPO training ...")
254
  trainer.train()
255
 
256
- # Save adapter
 
 
 
 
257
  out = Path(args.output_dir)
258
  out.mkdir(parents=True, exist_ok=True)
259
  model.save_pretrained(str(out))
260
  tokenizer.save_pretrained(str(out))
261
  print(f"LoRA adapter saved to {out}")
262
 
263
- # Optional HF Hub upload
264
  hf_repo = os.getenv("HF_REPO", "")
265
  if hf_repo:
266
  try:
@@ -271,8 +287,10 @@ def train_agent(args):
271
  print(f"HF upload failed: {exc}")
272
 
273
 
 
 
 
274
  def _dry_run_training(env, args):
275
- # ... (unchanged, keep as in your original)
276
  print("[DRY RUN] CPU simulation only — no model loaded.\n")
277
  from patchhawk.env_models import PatchHawkAction
278
 
@@ -305,18 +323,14 @@ def _dry_run_training(env, args):
305
  atype = env.current_scenario.get("attack_type", "none") or "none"
306
  attack_success.setdefault(atype, {"correct": 0, "total": 0})
307
  attack_success[atype]["total"] += 1
308
- if (label == "malicious" and ep_reward > 0) or (
309
- label == "benign" and ep_reward >= 0
310
- ):
311
  attack_success[atype]["correct"] += 1
312
 
313
  mean_r = float(np.mean(group_rewards))
314
  std_r = float(np.std(group_rewards)) + 1e-8
315
  advantages = [(r - mean_r) / std_r for r in group_rewards]
316
  epoch_rewards.append(mean_r)
317
- print(
318
- f" Batch mean_reward={mean_r:+.2f} advantages={[f'{a:+.2f}' for a in advantages]}"
319
- )
320
 
321
  epoch_mean = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
322
  print(f" Epoch {epoch + 1} mean_reward: {epoch_mean:+.2f}")
@@ -332,32 +346,26 @@ def _dry_run_training(env, args):
332
  "loss": max(0.0, 1.0 - epoch_mean / 3.0),
333
  }
334
  for atype, counts in attack_success.items():
335
- log_data[f"success_rate/{atype}"] = counts["correct"] / max(
336
- counts["total"], 1
337
- )
338
  wandb.log(log_data)
339
  except Exception:
340
  pass
341
 
342
  out = Path(args.output_dir)
343
  out.mkdir(parents=True, exist_ok=True)
344
- (out / "adapter_config.json").write_text('{"model_type":"patchhawk-grpo-null-baseline"}')
345
  (out / "adapter_model.bin").write_bytes(b"\x00" * 64)
346
- print(f"\n[DRY RUN] Baseline constraint adapter written to {args.output_dir}/")
347
 
348
 
 
 
 
349
  if __name__ == "__main__":
350
  parser = argparse.ArgumentParser(description="PatchHawk GRPO Training (trl 1.0.0)")
351
- parser.add_argument(
352
- "--dry-run", action="store_true", help="CPU simulation, no model"
353
- )
354
  parser.add_argument("--use-docker", action="store_true", help="Use Docker sandbox")
355
- parser.add_argument(
356
- "--max-seq-len",
357
- type=int,
358
- default=1024,
359
- help="Total sequence length (prompt+completion)",
360
- )
361
  parser.add_argument("--learning-rate", type=float, default=5e-6)
362
  parser.add_argument("--kl-coef", type=float, default=0.01)
363
  parser.add_argument("--batch-size", type=int, default=1)
@@ -367,4 +375,4 @@ if __name__ == "__main__":
367
  parser.add_argument("--max-steps", type=int, default=200)
368
  parser.add_argument("--output-dir", type=str, default="grpo_lora")
369
  args = parser.parse_args()
370
- train_agent(args)
 
1
+ #!/usr/bin/env python3
2
  """
3
  GRPO training pipeline for PatchHawk (trl 1.0.0, RTX 3060 12GB).
4
 
5
+ Fixed for trl 1.0.0:
6
+ - Removed max_prompt_length / max_completion_length.
7
+ - Disabled fp16 to avoid BFloat16 AMP error.
8
  - Set tokenizer.model_max_length for sequence length control.
9
+ - Forced WandB logging every step via custom callback (no step argument to avoid warnings).
10
+ - Loss displayed in tqdm progress bar.
11
+ - WandB online mode forced before init.
12
  """
13
 
14
  import argparse
 
39
 
40
 
41
  def train_agent(args):
42
+ # Check trl availability
43
  if not args.dry_run:
44
  try:
45
  from trl import GRPOTrainer, GRPOConfig
 
48
  "trl not found.\nInstall: pip install trl==1.0.0 peft bitsandbytes accelerate transformers"
49
  ) from exc
50
 
51
+ # ── WandB initialisation (force online mode before init) ──
52
  if not args.dry_run and wandb is not None:
53
+ os.environ["WANDB_MODE"] = "online"
54
+ os.environ["WANDB_SILENT"] = "false"
55
+ wandb.init(
56
+ project="patchhawk",
57
+ name="grpo-run",
58
+ config=vars(args),
59
+ )
60
  else:
61
  print("[INFO] WandB skipped.")
62
 
63
+ # ── Environment ──────────────────────────────────────────
64
  from patchhawk.agent.environment import PatchHawkEnv
65
 
66
  env = PatchHawkEnv(
 
73
  _dry_run_training(env, args)
74
  return
75
 
76
+ # ── GPU training imports ─────────────────────────────────
77
  import torch
78
  from transformers import (
79
  AutoModelForCausalLM,
 
81
  BitsAndBytesConfig,
82
  TrainerCallback,
83
  )
84
+ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
 
 
 
 
85
  from datasets import Dataset
86
  from trl import GRPOConfig, GRPOTrainer
87
 
 
92
 
93
  MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
94
 
95
+ # 4‑bit quantisation config
96
  bnb_config = BitsAndBytesConfig(
97
  load_in_4bit=True,
98
  bnb_4bit_quant_type="nf4",
 
106
  tokenizer.pad_token = tokenizer.eos_token
107
  tokenizer.padding_side = "left"
108
 
109
+ # Critical: set total sequence length (prompt + generation)
110
  tokenizer.model_max_length = args.max_seq_len
111
 
112
  base_model = AutoModelForCausalLM.from_pretrained(
 
122
  use_gradient_checkpointing=True,
123
  )
124
 
125
+ # LoRA configuration
126
  lora_config = LoraConfig(
127
  task_type=TaskType.CAUSAL_LM,
128
  r=16,
 
130
  lora_dropout=0.05,
131
  bias="none",
132
  target_modules=[
133
+ "q_proj", "k_proj", "v_proj", "o_proj",
134
+ "gate_proj", "up_proj", "down_proj",
 
 
 
 
 
135
  ],
136
  )
137
  model = get_peft_model(base_model, lora_config)
138
  model.print_trainable_parameters()
139
 
140
+ # ── Reward 1: XML format ─────────────────────────────────
141
  def format_reward(completions, **kwargs):
142
  rewards = []
143
  for c in completions:
 
159
  rewards.append(score)
160
  return rewards
161
 
162
+ # ── Reward 2: environment feedback ───────────────────────
163
  from patchhawk.env_models import PatchHawkAction
164
 
165
  def env_reward(completions, prompts, **kwargs):
 
167
  for prompt, c in zip(prompts, completions):
168
  text = c if isinstance(c, str) else str(c)
169
 
170
+ # Extract code snippet from prompt to identify scenario
171
+ code_match = re.search(r"<code_snippet>(.*?)</code_snippet>", prompt, re.DOTALL)
 
 
172
  if not code_match:
173
  rewards.append(-2.0)
174
  continue
 
182
  rewards.append(-2.0)
183
  continue
184
 
185
+ # Parse action
186
  action_match = re.search(r"<action>(\d+)</action>", text)
187
  if not action_match:
188
  rewards.append(-2.0)
189
  continue
190
  action_type = int(action_match.group(1))
191
 
192
+ # Parse patch (if any)
193
  patch = None
194
  patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
195
  if patch_match:
196
  patch = patch_match.group(1).strip()
197
 
198
  try:
199
+ # Reset environment to the exact scenario
200
  env.reset(scenario_idx=env.scenarios.index(scenario))
201
+ obs = env.step(PatchHawkAction(action_type=action_type, patch_content=patch))
 
 
202
  rewards.append(float(obs.reward or 0.0))
203
  except Exception as exc:
204
  print(f"env_reward crash: {exc}")
205
  rewards.append(-3.0)
206
  return rewards
207
 
208
+ # ── Dataset preparation ──────────────────────────────────
209
  valid = [s for s in env.scenarios if s.get("label") in ("malicious", "benign")]
210
  random.seed(42)
211
  random.shuffle(valid)
 
215
  eval_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[split:]])
216
  print(f"Dataset — train: {len(train_ds)}, eval: {len(eval_ds)}")
217
 
218
+ # ── GRPO Config (trl 1.0.0 compatible) ───────────────────
219
  grpo_config = GRPOConfig(
220
  output_dir=args.output_dir,
221
  learning_rate=args.learning_rate,
222
  per_device_train_batch_size=args.batch_size,
223
  gradient_accumulation_steps=args.grad_accum,
224
+ fp16=False, # avoids BFloat16 AMP error
225
  gradient_checkpointing=True,
226
  num_generations=args.group_size,
227
  beta=args.kl_coef,
228
  num_train_epochs=args.epochs,
229
  warmup_steps=10,
230
  max_grad_norm=1.0,
231
+ logging_steps=1, # log every step
232
+ logging_first_step=True, # log step 0 immediately
233
  save_steps=50,
234
  report_to="wandb" if (wandb is not None and not args.dry_run) else "none",
235
  )
236
 
237
+ # ── Custom callback: force WandB logging + progress bar (no step warnings) ──
238
+ class ForceWandbCallback(TrainerCallback):
 
 
239
  def on_log(self, args, state, control, logs=None, **kwargs):
240
+ if not logs:
241
+ return
242
+ # Log everything to wandb WITHOUT step argument (avoids step warnings)
243
+ if wandb is not None and wandb.run is not None:
244
+ wandb.log(logs)
245
+ # Update progress bar with loss
246
+ loss_key = None
247
+ for key in ["loss", "grpo_loss", "train_loss"]:
248
+ if key in logs:
249
+ loss_key = key
250
+ break
251
+ if loss_key is not None:
252
+ loss_val = logs[loss_key]
253
+ if hasattr(state, "progress_bar") and state.progress_bar is not None:
254
+ state.progress_bar.set_postfix({loss_key: f"{loss_val:.4f}"})
255
 
256
  trainer = GRPOTrainer(
257
  model=model,
 
260
  train_dataset=train_ds,
261
  eval_dataset=eval_ds,
262
  )
263
+ trainer.add_callback(ForceWandbCallback())
 
 
264
 
265
  print("Starting GRPO training ...")
266
  trainer.train()
267
 
268
+ # Ensure all pending logs are sent to wandb
269
+ if wandb is not None and wandb.run is not None:
270
+ wandb.finish()
271
+
272
+ # ── Save LoRA adapter ────────────────────────────────────
273
  out = Path(args.output_dir)
274
  out.mkdir(parents=True, exist_ok=True)
275
  model.save_pretrained(str(out))
276
  tokenizer.save_pretrained(str(out))
277
  print(f"LoRA adapter saved to {out}")
278
 
279
+ # ── Optional HF Hub upload ───────────────────────────────
280
  hf_repo = os.getenv("HF_REPO", "")
281
  if hf_repo:
282
  try:
 
287
  print(f"HF upload failed: {exc}")
288
 
289
 
290
+ # ─────────────────────────────────────────────────────────────
291
+ # Dry-run (CPU simulation, no model)
292
+ # ─────────────────────────────────────────────────────────────
293
  def _dry_run_training(env, args):
 
294
  print("[DRY RUN] CPU simulation only — no model loaded.\n")
295
  from patchhawk.env_models import PatchHawkAction
296
 
 
323
  atype = env.current_scenario.get("attack_type", "none") or "none"
324
  attack_success.setdefault(atype, {"correct": 0, "total": 0})
325
  attack_success[atype]["total"] += 1
326
+ if (label == "malicious" and ep_reward > 0) or (label == "benign" and ep_reward >= 0):
 
 
327
  attack_success[atype]["correct"] += 1
328
 
329
  mean_r = float(np.mean(group_rewards))
330
  std_r = float(np.std(group_rewards)) + 1e-8
331
  advantages = [(r - mean_r) / std_r for r in group_rewards]
332
  epoch_rewards.append(mean_r)
333
+ print(f" Batch mean_reward={mean_r:+.2f} advantages={[f'{a:+.2f}' for a in advantages]}")
 
 
334
 
335
  epoch_mean = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
336
  print(f" Epoch {epoch + 1} mean_reward: {epoch_mean:+.2f}")
 
346
  "loss": max(0.0, 1.0 - epoch_mean / 3.0),
347
  }
348
  for atype, counts in attack_success.items():
349
+ log_data[f"success_rate/{atype}"] = counts["correct"] / max(counts["total"], 1)
 
 
350
  wandb.log(log_data)
351
  except Exception:
352
  pass
353
 
354
  out = Path(args.output_dir)
355
  out.mkdir(parents=True, exist_ok=True)
356
+ (out / "adapter_config.json").write_text('{"model_type":"patchhawk-grpo-dry-run"}')
357
  (out / "adapter_model.bin").write_bytes(b"\x00" * 64)
358
+ print(f"\n[DRY RUN] Dummy adapter written to {args.output_dir}/")
359
 
360
 
361
+ # ─────────────────────────────────────────────────────────────
362
+ # CLI entry point
363
+ # ─────────────────────────────────────────────────────────────
364
  if __name__ == "__main__":
365
  parser = argparse.ArgumentParser(description="PatchHawk GRPO Training (trl 1.0.0)")
366
+ parser.add_argument("--dry-run", action="store_true", help="CPU simulation, no model")
 
 
367
  parser.add_argument("--use-docker", action="store_true", help="Use Docker sandbox")
368
+ parser.add_argument("--max-seq-len", type=int, default=1024, help="Total sequence length (prompt+completion)")
 
 
 
 
 
369
  parser.add_argument("--learning-rate", type=float, default=5e-6)
370
  parser.add_argument("--kl-coef", type=float, default=0.01)
371
  parser.add_argument("--batch-size", type=int, default=1)
 
375
  parser.add_argument("--max-steps", type=int, default=200)
376
  parser.add_argument("--output-dir", type=str, default="grpo_lora")
377
  args = parser.parse_args()
378
+ train_agent(args)