Spaces:
Running
Running
Final train patch
Browse files- patchhawk/training/train_grpo.py +75 -67
patchhawk/training/train_grpo.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
GRPO training pipeline for PatchHawk (trl 1.0.0, RTX 3060 12GB).
|
| 3 |
|
| 4 |
-
Fixed:
|
| 5 |
-
- Removed max_prompt_length / max_completion_length
|
| 6 |
-
- Disabled fp16
|
| 7 |
- Set tokenizer.model_max_length for sequence length control.
|
| 8 |
-
-
|
|
|
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
|
@@ -36,6 +39,7 @@ def _build_prompt(scenario: dict) -> str:
|
|
| 36 |
|
| 37 |
|
| 38 |
def train_agent(args):
|
|
|
|
| 39 |
if not args.dry_run:
|
| 40 |
try:
|
| 41 |
from trl import GRPOTrainer, GRPOConfig
|
|
@@ -44,11 +48,19 @@ def train_agent(args):
|
|
| 44 |
"trl not found.\nInstall: pip install trl==1.0.0 peft bitsandbytes accelerate transformers"
|
| 45 |
) from exc
|
| 46 |
|
|
|
|
| 47 |
if not args.dry_run and wandb is not None:
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
else:
|
| 50 |
print("[INFO] WandB skipped.")
|
| 51 |
|
|
|
|
| 52 |
from patchhawk.agent.environment import PatchHawkEnv
|
| 53 |
|
| 54 |
env = PatchHawkEnv(
|
|
@@ -61,6 +73,7 @@ def train_agent(args):
|
|
| 61 |
_dry_run_training(env, args)
|
| 62 |
return
|
| 63 |
|
|
|
|
| 64 |
import torch
|
| 65 |
from transformers import (
|
| 66 |
AutoModelForCausalLM,
|
|
@@ -68,12 +81,7 @@ def train_agent(args):
|
|
| 68 |
BitsAndBytesConfig,
|
| 69 |
TrainerCallback,
|
| 70 |
)
|
| 71 |
-
from peft import
|
| 72 |
-
LoraConfig,
|
| 73 |
-
TaskType,
|
| 74 |
-
get_peft_model,
|
| 75 |
-
prepare_model_for_kbit_training,
|
| 76 |
-
)
|
| 77 |
from datasets import Dataset
|
| 78 |
from trl import GRPOConfig, GRPOTrainer
|
| 79 |
|
|
@@ -84,6 +92,7 @@ def train_agent(args):
|
|
| 84 |
|
| 85 |
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 86 |
|
|
|
|
| 87 |
bnb_config = BitsAndBytesConfig(
|
| 88 |
load_in_4bit=True,
|
| 89 |
bnb_4bit_quant_type="nf4",
|
|
@@ -97,7 +106,7 @@ def train_agent(args):
|
|
| 97 |
tokenizer.pad_token = tokenizer.eos_token
|
| 98 |
tokenizer.padding_side = "left"
|
| 99 |
|
| 100 |
-
#
|
| 101 |
tokenizer.model_max_length = args.max_seq_len
|
| 102 |
|
| 103 |
base_model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -113,6 +122,7 @@ def train_agent(args):
|
|
| 113 |
use_gradient_checkpointing=True,
|
| 114 |
)
|
| 115 |
|
|
|
|
| 116 |
lora_config = LoraConfig(
|
| 117 |
task_type=TaskType.CAUSAL_LM,
|
| 118 |
r=16,
|
|
@@ -120,19 +130,14 @@ def train_agent(args):
|
|
| 120 |
lora_dropout=0.05,
|
| 121 |
bias="none",
|
| 122 |
target_modules=[
|
| 123 |
-
"q_proj",
|
| 124 |
-
"
|
| 125 |
-
"v_proj",
|
| 126 |
-
"o_proj",
|
| 127 |
-
"gate_proj",
|
| 128 |
-
"up_proj",
|
| 129 |
-
"down_proj",
|
| 130 |
],
|
| 131 |
)
|
| 132 |
model = get_peft_model(base_model, lora_config)
|
| 133 |
model.print_trainable_parameters()
|
| 134 |
|
| 135 |
-
# Reward 1: format
|
| 136 |
def format_reward(completions, **kwargs):
|
| 137 |
rewards = []
|
| 138 |
for c in completions:
|
|
@@ -154,7 +159,7 @@ def train_agent(args):
|
|
| 154 |
rewards.append(score)
|
| 155 |
return rewards
|
| 156 |
|
| 157 |
-
# Reward 2: environment
|
| 158 |
from patchhawk.env_models import PatchHawkAction
|
| 159 |
|
| 160 |
def env_reward(completions, prompts, **kwargs):
|
|
@@ -162,10 +167,8 @@ def train_agent(args):
|
|
| 162 |
for prompt, c in zip(prompts, completions):
|
| 163 |
text = c if isinstance(c, str) else str(c)
|
| 164 |
|
| 165 |
-
#
|
| 166 |
-
code_match = re.search(
|
| 167 |
-
r"<code_snippet>(.*?)</code_snippet>", prompt, re.DOTALL
|
| 168 |
-
)
|
| 169 |
if not code_match:
|
| 170 |
rewards.append(-2.0)
|
| 171 |
continue
|
|
@@ -179,30 +182,30 @@ def train_agent(args):
|
|
| 179 |
rewards.append(-2.0)
|
| 180 |
continue
|
| 181 |
|
|
|
|
| 182 |
action_match = re.search(r"<action>(\d+)</action>", text)
|
| 183 |
if not action_match:
|
| 184 |
rewards.append(-2.0)
|
| 185 |
continue
|
| 186 |
action_type = int(action_match.group(1))
|
| 187 |
|
|
|
|
| 188 |
patch = None
|
| 189 |
patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
|
| 190 |
if patch_match:
|
| 191 |
patch = patch_match.group(1).strip()
|
| 192 |
|
| 193 |
try:
|
| 194 |
-
# Reset environment to the
|
| 195 |
env.reset(scenario_idx=env.scenarios.index(scenario))
|
| 196 |
-
obs = env.step(
|
| 197 |
-
PatchHawkAction(action_type=action_type, patch_content=patch)
|
| 198 |
-
)
|
| 199 |
rewards.append(float(obs.reward or 0.0))
|
| 200 |
except Exception as exc:
|
| 201 |
print(f"env_reward crash: {exc}")
|
| 202 |
rewards.append(-3.0)
|
| 203 |
return rewards
|
| 204 |
|
| 205 |
-
#
|
| 206 |
valid = [s for s in env.scenarios if s.get("label") in ("malicious", "benign")]
|
| 207 |
random.seed(42)
|
| 208 |
random.shuffle(valid)
|
|
@@ -212,32 +215,43 @@ def train_agent(args):
|
|
| 212 |
eval_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[split:]])
|
| 213 |
print(f"Dataset β train: {len(train_ds)}, eval: {len(eval_ds)}")
|
| 214 |
|
| 215 |
-
#
|
| 216 |
grpo_config = GRPOConfig(
|
| 217 |
output_dir=args.output_dir,
|
| 218 |
learning_rate=args.learning_rate,
|
| 219 |
per_device_train_batch_size=args.batch_size,
|
| 220 |
gradient_accumulation_steps=args.grad_accum,
|
| 221 |
-
fp16=False,
|
| 222 |
gradient_checkpointing=True,
|
| 223 |
num_generations=args.group_size,
|
| 224 |
beta=args.kl_coef,
|
| 225 |
num_train_epochs=args.epochs,
|
| 226 |
warmup_steps=10,
|
| 227 |
max_grad_norm=1.0,
|
| 228 |
-
logging_steps=1,
|
|
|
|
| 229 |
save_steps=50,
|
| 230 |
report_to="wandb" if (wandb is not None and not args.dry_run) else "none",
|
| 231 |
)
|
| 232 |
|
| 233 |
-
# ββββ
|
| 234 |
-
|
| 235 |
-
# ─────────────────────────────────────────────────────────
|
| 236 |
-
class LossProgressBarCallback(TrainerCallback):
|
| 237 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 238 |
-
if
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
trainer = GRPOTrainer(
|
| 243 |
model=model,
|
|
@@ -246,21 +260,23 @@ def train_agent(args):
|
|
| 246 |
train_dataset=train_ds,
|
| 247 |
eval_dataset=eval_ds,
|
| 248 |
)
|
| 249 |
-
|
| 250 |
-
# Add the callback
|
| 251 |
-
trainer.add_callback(LossProgressBarCallback())
|
| 252 |
|
| 253 |
print("Starting GRPO training ...")
|
| 254 |
trainer.train()
|
| 255 |
|
| 256 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
out = Path(args.output_dir)
|
| 258 |
out.mkdir(parents=True, exist_ok=True)
|
| 259 |
model.save_pretrained(str(out))
|
| 260 |
tokenizer.save_pretrained(str(out))
|
| 261 |
print(f"LoRA adapter saved to {out}")
|
| 262 |
|
| 263 |
-
# Optional HF Hub upload
|
| 264 |
hf_repo = os.getenv("HF_REPO", "")
|
| 265 |
if hf_repo:
|
| 266 |
try:
|
|
@@ -271,8 +287,10 @@ def train_agent(args):
|
|
| 271 |
print(f"HF upload failed: {exc}")
|
| 272 |
|
| 273 |
|
|
|
|
|
|
|
|
|
|
| 274 |
def _dry_run_training(env, args):
|
| 275 |
-
# ... (unchanged, keep as in your original)
|
| 276 |
print("[DRY RUN] CPU simulation only β no model loaded.\n")
|
| 277 |
from patchhawk.env_models import PatchHawkAction
|
| 278 |
|
|
@@ -305,18 +323,14 @@ def _dry_run_training(env, args):
|
|
| 305 |
atype = env.current_scenario.get("attack_type", "none") or "none"
|
| 306 |
attack_success.setdefault(atype, {"correct": 0, "total": 0})
|
| 307 |
attack_success[atype]["total"] += 1
|
| 308 |
-
if (label == "malicious" and ep_reward > 0) or (
|
| 309 |
-
label == "benign" and ep_reward >= 0
|
| 310 |
-
):
|
| 311 |
attack_success[atype]["correct"] += 1
|
| 312 |
|
| 313 |
mean_r = float(np.mean(group_rewards))
|
| 314 |
std_r = float(np.std(group_rewards)) + 1e-8
|
| 315 |
advantages = [(r - mean_r) / std_r for r in group_rewards]
|
| 316 |
epoch_rewards.append(mean_r)
|
| 317 |
-
print(
|
| 318 |
-
f" Batch mean_reward={mean_r:+.2f} advantages={[f'{a:+.2f}' for a in advantages]}"
|
| 319 |
-
)
|
| 320 |
|
| 321 |
epoch_mean = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
|
| 322 |
print(f" Epoch {epoch + 1} mean_reward: {epoch_mean:+.2f}")
|
|
@@ -332,32 +346,26 @@ def _dry_run_training(env, args):
|
|
| 332 |
"loss": max(0.0, 1.0 - epoch_mean / 3.0),
|
| 333 |
}
|
| 334 |
for atype, counts in attack_success.items():
|
| 335 |
-
log_data[f"success_rate/{atype}"] = counts["correct"] / max(
|
| 336 |
-
counts["total"], 1
|
| 337 |
-
)
|
| 338 |
wandb.log(log_data)
|
| 339 |
except Exception:
|
| 340 |
pass
|
| 341 |
|
| 342 |
out = Path(args.output_dir)
|
| 343 |
out.mkdir(parents=True, exist_ok=True)
|
| 344 |
-
(out / "adapter_config.json").write_text('{"model_type":"patchhawk-grpo-
|
| 345 |
(out / "adapter_model.bin").write_bytes(b"\x00" * 64)
|
| 346 |
-
print(f"\n[DRY RUN]
|
| 347 |
|
| 348 |
|
|
|
|
|
|
|
|
|
|
| 349 |
if __name__ == "__main__":
|
| 350 |
parser = argparse.ArgumentParser(description="PatchHawk GRPO Training (trl 1.0.0)")
|
| 351 |
-
parser.add_argument(
|
| 352 |
-
"--dry-run", action="store_true", help="CPU simulation, no model"
|
| 353 |
-
)
|
| 354 |
parser.add_argument("--use-docker", action="store_true", help="Use Docker sandbox")
|
| 355 |
-
parser.add_argument(
|
| 356 |
-
"--max-seq-len",
|
| 357 |
-
type=int,
|
| 358 |
-
default=1024,
|
| 359 |
-
help="Total sequence length (prompt+completion)",
|
| 360 |
-
)
|
| 361 |
parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| 362 |
parser.add_argument("--kl-coef", type=float, default=0.01)
|
| 363 |
parser.add_argument("--batch-size", type=int, default=1)
|
|
@@ -367,4 +375,4 @@ if __name__ == "__main__":
|
|
| 367 |
parser.add_argument("--max-steps", type=int, default=200)
|
| 368 |
parser.add_argument("--output-dir", type=str, default="grpo_lora")
|
| 369 |
args = parser.parse_args()
|
| 370 |
-
train_agent(args)
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
GRPO training pipeline for PatchHawk (trl 1.0.0, RTX 3060 12GB).
|
| 4 |
|
| 5 |
+
Fixed for trl 1.0.0:
|
| 6 |
+
- Removed max_prompt_length / max_completion_length.
|
| 7 |
+
- Disabled fp16 to avoid BFloat16 AMP error.
|
| 8 |
- Set tokenizer.model_max_length for sequence length control.
|
| 9 |
+
- Forced WandB logging every step via custom callback (no step argument to avoid warnings).
|
| 10 |
+
- Loss displayed in tqdm progress bar.
|
| 11 |
+
- WandB online mode forced before init.
|
| 12 |
"""
|
| 13 |
|
| 14 |
import argparse
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def train_agent(args):
|
| 42 |
+
# Check trl availability
|
| 43 |
if not args.dry_run:
|
| 44 |
try:
|
| 45 |
from trl import GRPOTrainer, GRPOConfig
|
|
|
|
| 48 |
"trl not found.\nInstall: pip install trl==1.0.0 peft bitsandbytes accelerate transformers"
|
| 49 |
) from exc
|
| 50 |
|
| 51 |
+
# ── WandB initialisation (force online mode before init) ──
|
| 52 |
if not args.dry_run and wandb is not None:
|
| 53 |
+
os.environ["WANDB_MODE"] = "online"
|
| 54 |
+
os.environ["WANDB_SILENT"] = "false"
|
| 55 |
+
wandb.init(
|
| 56 |
+
project="patchhawk",
|
| 57 |
+
name="grpo-run",
|
| 58 |
+
config=vars(args),
|
| 59 |
+
)
|
| 60 |
else:
|
| 61 |
print("[INFO] WandB skipped.")
|
| 62 |
|
| 63 |
+
# ── Environment ──────────────────────────────────────────
|
| 64 |
from patchhawk.agent.environment import PatchHawkEnv
|
| 65 |
|
| 66 |
env = PatchHawkEnv(
|
|
|
|
| 73 |
_dry_run_training(env, args)
|
| 74 |
return
|
| 75 |
|
| 76 |
+
# ── GPU training imports ─────────────────────────────────
|
| 77 |
import torch
|
| 78 |
from transformers import (
|
| 79 |
AutoModelForCausalLM,
|
|
|
|
| 81 |
BitsAndBytesConfig,
|
| 82 |
TrainerCallback,
|
| 83 |
)
|
| 84 |
+
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
from datasets import Dataset
|
| 86 |
from trl import GRPOConfig, GRPOTrainer
|
| 87 |
|
|
|
|
| 92 |
|
| 93 |
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 94 |
|
| 95 |
+
# 4-bit quantisation config
|
| 96 |
bnb_config = BitsAndBytesConfig(
|
| 97 |
load_in_4bit=True,
|
| 98 |
bnb_4bit_quant_type="nf4",
|
|
|
|
| 106 |
tokenizer.pad_token = tokenizer.eos_token
|
| 107 |
tokenizer.padding_side = "left"
|
| 108 |
|
| 109 |
+
# Critical: set total sequence length (prompt + generation)
|
| 110 |
tokenizer.model_max_length = args.max_seq_len
|
| 111 |
|
| 112 |
base_model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 122 |
use_gradient_checkpointing=True,
|
| 123 |
)
|
| 124 |
|
| 125 |
+
# LoRA configuration
|
| 126 |
lora_config = LoraConfig(
|
| 127 |
task_type=TaskType.CAUSAL_LM,
|
| 128 |
r=16,
|
|
|
|
| 130 |
lora_dropout=0.05,
|
| 131 |
bias="none",
|
| 132 |
target_modules=[
|
| 133 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 134 |
+
"gate_proj", "up_proj", "down_proj",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
],
|
| 136 |
)
|
| 137 |
model = get_peft_model(base_model, lora_config)
|
| 138 |
model.print_trainable_parameters()
|
| 139 |
|
| 140 |
+
# ── Reward 1: XML format ─────────────────────────────────
|
| 141 |
def format_reward(completions, **kwargs):
|
| 142 |
rewards = []
|
| 143 |
for c in completions:
|
|
|
|
| 159 |
rewards.append(score)
|
| 160 |
return rewards
|
| 161 |
|
| 162 |
+
# ── Reward 2: environment feedback ───────────────────────
|
| 163 |
from patchhawk.env_models import PatchHawkAction
|
| 164 |
|
| 165 |
def env_reward(completions, prompts, **kwargs):
|
|
|
|
| 167 |
for prompt, c in zip(prompts, completions):
|
| 168 |
text = c if isinstance(c, str) else str(c)
|
| 169 |
|
| 170 |
+
# Extract code snippet from prompt to identify scenario
|
| 171 |
+
code_match = re.search(r"<code_snippet>(.*?)</code_snippet>", prompt, re.DOTALL)
|
|
|
|
|
|
|
| 172 |
if not code_match:
|
| 173 |
rewards.append(-2.0)
|
| 174 |
continue
|
|
|
|
| 182 |
rewards.append(-2.0)
|
| 183 |
continue
|
| 184 |
|
| 185 |
+
# Parse action
|
| 186 |
action_match = re.search(r"<action>(\d+)</action>", text)
|
| 187 |
if not action_match:
|
| 188 |
rewards.append(-2.0)
|
| 189 |
continue
|
| 190 |
action_type = int(action_match.group(1))
|
| 191 |
|
| 192 |
+
# Parse patch (if any)
|
| 193 |
patch = None
|
| 194 |
patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
|
| 195 |
if patch_match:
|
| 196 |
patch = patch_match.group(1).strip()
|
| 197 |
|
| 198 |
try:
|
| 199 |
+
# Reset environment to the exact scenario
|
| 200 |
env.reset(scenario_idx=env.scenarios.index(scenario))
|
| 201 |
+
obs = env.step(PatchHawkAction(action_type=action_type, patch_content=patch))
|
|
|
|
|
|
|
| 202 |
rewards.append(float(obs.reward or 0.0))
|
| 203 |
except Exception as exc:
|
| 204 |
print(f"env_reward crash: {exc}")
|
| 205 |
rewards.append(-3.0)
|
| 206 |
return rewards
|
| 207 |
|
| 208 |
+
# ── Dataset preparation ──────────────────────────────────
|
| 209 |
valid = [s for s in env.scenarios if s.get("label") in ("malicious", "benign")]
|
| 210 |
random.seed(42)
|
| 211 |
random.shuffle(valid)
|
|
|
|
| 215 |
eval_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[split:]])
|
| 216 |
print(f"Dataset β train: {len(train_ds)}, eval: {len(eval_ds)}")
|
| 217 |
|
| 218 |
+
# ── GRPO Config (trl 1.0.0 compatible) ───────────────────
|
| 219 |
grpo_config = GRPOConfig(
|
| 220 |
output_dir=args.output_dir,
|
| 221 |
learning_rate=args.learning_rate,
|
| 222 |
per_device_train_batch_size=args.batch_size,
|
| 223 |
gradient_accumulation_steps=args.grad_accum,
|
| 224 |
+
fp16=False, # avoids BFloat16 AMP error
|
| 225 |
gradient_checkpointing=True,
|
| 226 |
num_generations=args.group_size,
|
| 227 |
beta=args.kl_coef,
|
| 228 |
num_train_epochs=args.epochs,
|
| 229 |
warmup_steps=10,
|
| 230 |
max_grad_norm=1.0,
|
| 231 |
+
logging_steps=1, # log every step
|
| 232 |
+
logging_first_step=True, # log step 0 immediately
|
| 233 |
save_steps=50,
|
| 234 |
report_to="wandb" if (wandb is not None and not args.dry_run) else "none",
|
| 235 |
)
|
| 236 |
|
| 237 |
+
# ── Custom callback: force WandB logging + progress bar (no step warnings) ──
|
| 238 |
+
class ForceWandbCallback(TrainerCallback):
|
|
|
|
|
|
|
| 239 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 240 |
+
if not logs:
|
| 241 |
+
return
|
| 242 |
+
# Log everything to wandb WITHOUT step argument (avoids step warnings)
|
| 243 |
+
if wandb is not None and wandb.run is not None:
|
| 244 |
+
wandb.log(logs)
|
| 245 |
+
# Update progress bar with loss
|
| 246 |
+
loss_key = None
|
| 247 |
+
for key in ["loss", "grpo_loss", "train_loss"]:
|
| 248 |
+
if key in logs:
|
| 249 |
+
loss_key = key
|
| 250 |
+
break
|
| 251 |
+
if loss_key is not None:
|
| 252 |
+
loss_val = logs[loss_key]
|
| 253 |
+
if hasattr(state, "progress_bar") and state.progress_bar is not None:
|
| 254 |
+
state.progress_bar.set_postfix({loss_key: f"{loss_val:.4f}"})
|
| 255 |
|
| 256 |
trainer = GRPOTrainer(
|
| 257 |
model=model,
|
|
|
|
| 260 |
train_dataset=train_ds,
|
| 261 |
eval_dataset=eval_ds,
|
| 262 |
)
|
| 263 |
+
trainer.add_callback(ForceWandbCallback())
|
|
|
|
|
|
|
| 264 |
|
| 265 |
print("Starting GRPO training ...")
|
| 266 |
trainer.train()
|
| 267 |
|
| 268 |
+
# Ensure all pending logs are sent to wandb
|
| 269 |
+
if wandb is not None and wandb.run is not None:
|
| 270 |
+
wandb.finish()
|
| 271 |
+
|
| 272 |
+
# ── Save LoRA adapter ────────────────────────────────────
|
| 273 |
out = Path(args.output_dir)
|
| 274 |
out.mkdir(parents=True, exist_ok=True)
|
| 275 |
model.save_pretrained(str(out))
|
| 276 |
tokenizer.save_pretrained(str(out))
|
| 277 |
print(f"LoRA adapter saved to {out}")
|
| 278 |
|
| 279 |
+
# ── Optional HF Hub upload ───────────────────────────────
|
| 280 |
hf_repo = os.getenv("HF_REPO", "")
|
| 281 |
if hf_repo:
|
| 282 |
try:
|
|
|
|
| 287 |
print(f"HF upload failed: {exc}")
|
| 288 |
|
| 289 |
|
| 290 |
+
# ─────────────────────────────────────────────────────────────
|
| 291 |
+
# Dry-run (CPU simulation, no model)
|
| 292 |
+
# ─────────────────────────────────────────────────────────────
|
| 293 |
def _dry_run_training(env, args):
|
|
|
|
| 294 |
print("[DRY RUN] CPU simulation only β no model loaded.\n")
|
| 295 |
from patchhawk.env_models import PatchHawkAction
|
| 296 |
|
|
|
|
| 323 |
atype = env.current_scenario.get("attack_type", "none") or "none"
|
| 324 |
attack_success.setdefault(atype, {"correct": 0, "total": 0})
|
| 325 |
attack_success[atype]["total"] += 1
|
| 326 |
+
if (label == "malicious" and ep_reward > 0) or (label == "benign" and ep_reward >= 0):
|
|
|
|
|
|
|
| 327 |
attack_success[atype]["correct"] += 1
|
| 328 |
|
| 329 |
mean_r = float(np.mean(group_rewards))
|
| 330 |
std_r = float(np.std(group_rewards)) + 1e-8
|
| 331 |
advantages = [(r - mean_r) / std_r for r in group_rewards]
|
| 332 |
epoch_rewards.append(mean_r)
|
| 333 |
+
print(f" Batch mean_reward={mean_r:+.2f} advantages={[f'{a:+.2f}' for a in advantages]}")
|
|
|
|
|
|
|
| 334 |
|
| 335 |
epoch_mean = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
|
| 336 |
print(f" Epoch {epoch + 1} mean_reward: {epoch_mean:+.2f}")
|
|
|
|
| 346 |
"loss": max(0.0, 1.0 - epoch_mean / 3.0),
|
| 347 |
}
|
| 348 |
for atype, counts in attack_success.items():
|
| 349 |
+
log_data[f"success_rate/{atype}"] = counts["correct"] / max(counts["total"], 1)
|
|
|
|
|
|
|
| 350 |
wandb.log(log_data)
|
| 351 |
except Exception:
|
| 352 |
pass
|
| 353 |
|
| 354 |
out = Path(args.output_dir)
|
| 355 |
out.mkdir(parents=True, exist_ok=True)
|
| 356 |
+
(out / "adapter_config.json").write_text('{"model_type":"patchhawk-grpo-dry-run"}')
|
| 357 |
(out / "adapter_model.bin").write_bytes(b"\x00" * 64)
|
| 358 |
+
print(f"\n[DRY RUN] Dummy adapter written to {args.output_dir}/")
|
| 359 |
|
| 360 |
|
| 361 |
+
# ─────────────────────────────────────────────────────────────
|
| 362 |
+
# CLI entry point
|
| 363 |
+
# ─────────────────────────────────────────────────────────────
|
| 364 |
if __name__ == "__main__":
|
| 365 |
parser = argparse.ArgumentParser(description="PatchHawk GRPO Training (trl 1.0.0)")
|
| 366 |
+
parser.add_argument("--dry-run", action="store_true", help="CPU simulation, no model")
|
|
|
|
|
|
|
| 367 |
parser.add_argument("--use-docker", action="store_true", help="Use Docker sandbox")
|
| 368 |
+
parser.add_argument("--max-seq-len", type=int, default=1024, help="Total sequence length (prompt+completion)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| 370 |
parser.add_argument("--kl-coef", type=float, default=0.01)
|
| 371 |
parser.add_argument("--batch-size", type=int, default=1)
|
|
|
|
| 375 |
parser.add_argument("--max-steps", type=int, default=200)
|
| 376 |
parser.add_argument("--output-dir", type=str, default="grpo_lora")
|
| 377 |
args = parser.parse_args()
|
| 378 |
+
train_agent(args)
|