File size: 16,040 Bytes
e19878b
72bc633
58f6308
72bc633
e19878b
 
 
58f6308
e19878b
 
 
72bc633
 
 
 
 
58f6308
72bc633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58f6308
 
d6abea2
58f6308
 
72bc633
 
 
 
e19878b
58f6308
 
 
 
 
 
 
72bc633
e19878b
72bc633
e19878b
 
 
 
 
 
 
72bc633
58f6308
72bc633
e19878b
72bc633
 
 
 
 
 
58f6308
72bc633
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
 
d6abea2
 
 
 
58f6308
e19878b
58f6308
 
 
 
 
 
 
 
 
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
e19878b
 
58f6308
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
 
 
d6abea2
 
 
 
58f6308
 
 
 
 
 
 
0e0badb
58f6308
 
 
 
e19878b
58f6308
 
 
 
 
 
 
e19878b
 
58f6308
 
 
 
 
 
 
 
 
 
 
 
 
e19878b
58f6308
 
 
 
 
 
e19878b
58f6308
 
 
 
d6abea2
 
 
58f6308
 
e19878b
d6abea2
 
 
 
 
 
 
 
 
 
58f6308
 
 
 
 
e19878b
58f6308
 
 
 
 
 
 
 
 
e19878b
58f6308
 
 
 
 
e19878b
58f6308
 
 
 
 
 
e19878b
 
58f6308
 
 
 
e19878b
 
58f6308
e19878b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58f6308
 
 
 
 
 
 
 
e19878b
58f6308
 
 
 
e19878b
 
 
 
 
58f6308
 
 
 
 
 
e19878b
58f6308
 
72bc633
58f6308
 
 
 
 
72bc633
 
e19878b
 
 
72bc633
58f6308
72bc633
 
 
 
 
 
 
 
 
 
58f6308
 
72bc633
58f6308
72bc633
58f6308
72bc633
58f6308
72bc633
 
 
58f6308
 
 
72bc633
 
 
 
 
58f6308
72bc633
e19878b
72bc633
 
 
 
 
 
e19878b
72bc633
 
 
 
 
 
 
58f6308
72bc633
 
 
 
58f6308
72bc633
 
e19878b
72bc633
 
 
 
58f6308
 
e19878b
58f6308
e19878b
72bc633
 
e19878b
 
 
72bc633
58f6308
e19878b
58f6308
e19878b
0e0badb
 
58f6308
 
 
 
 
72bc633
 
e19878b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
#!/usr/bin/env python3
"""
GRPO training pipeline for PatchHawk (trl 1.0.0, RTX 3060 12GB).

Fixed for trl 1.0.0:
- Removed max_prompt_length / max_completion_length.
- Disabled fp16 to avoid BFloat16 AMP error.
- Set tokenizer.model_max_length for sequence length control.
- Forced WandB logging every step via custom callback (no step argument to avoid warnings).
- Loss displayed in tqdm progress bar.
- WandB online mode forced before init.
"""

import argparse
import os
import random
import re
from pathlib import Path

import numpy as np

try:
    import wandb
except ImportError:
    wandb = None

# Directory three levels above this file; train_agent() uses it as the base
# for locating patchhawk/data/scenarios.json.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


def _build_prompt(scenario: dict) -> str:
    """Render the analysis prompt for a single scenario.

    The prompt is made of three parts: a one-line instruction, the scenario's
    ``code_snippet`` wrapped in ``<code_snippet>`` tags, and a description of
    the XML response format the model is asked to follow.

    Args:
        scenario: Mapping that must contain a ``"code_snippet"`` entry.

    Returns:
        The complete prompt text, terminated by a newline.

    Raises:
        KeyError: If ``scenario`` has no ``"code_snippet"`` key.
    """
    instruction = "Analyze this Python code for supply-chain vulnerabilities.\n"
    code_section = f"<code_snippet>\n{scenario['code_snippet']}\n</code_snippet>\n"
    # Response template, one tag per line, in the order the model should emit them.
    response_format = "".join(
        [
            "Respond in STRICT XML:\n",
            "<thought>...</thought>\n",
            "<risk_score>0.0 to 1.0</risk_score>\n",
            "<action>0-4</action>\n",
            "<patch>...</patch> (ONLY if action=3)\n",
        ]
    )
    return instruction + code_section + response_format


def train_agent(args):
    """Run GRPO fine-tuning of the PatchHawk policy model, or a CPU dry run.

    With ``args.dry_run`` set, only the environment is built and control is
    handed to ``_dry_run_training``; trl is not imported, no WandB run is
    started and no model is loaded.

    Otherwise the model named by the ``GRPO_POLICY_MODEL`` environment
    variable (default ``Qwen/Qwen2.5-Coder-3B-Instruct``) is loaded in 4-bit,
    wrapped in a LoRA adapter and trained with ``GRPOTrainer`` using two
    reward functions (XML format and environment feedback). The adapter and
    tokenizer are then saved to ``args.output_dir`` and, if the ``HF_REPO``
    environment variable is set, pushed to the Hugging Face Hub on a
    best-effort basis.

    Args:
        args: Parsed CLI namespace. This function reads ``dry_run``,
            ``use_docker``, ``max_seq_len``, ``learning_rate``,
            ``batch_size``, ``grad_accum``, ``group_size``, ``kl_coef``,
            ``epochs`` and ``output_dir``.

    Raises:
        RuntimeError: If trl cannot be imported for a real training run.
        ValueError: If there are too few labelled scenarios to form a
            non-empty training split.
    """
    # Check trl availability up front so the failure message is actionable.
    if not args.dry_run:
        try:
            from trl import GRPOTrainer, GRPOConfig
        except Exception as exc:
            raise RuntimeError(
                "trl not found.\nInstall: pip install trl==1.0.0 peft bitsandbytes accelerate transformers"
            ) from exc

    # ── WandB initialisation (force online mode before init) ──
    if not args.dry_run and wandb is not None:
        os.environ["WANDB_MODE"] = "online"
        os.environ["WANDB_SILENT"] = "false"
        wandb.init(
            project="patchhawk",
            name="grpo-run",
            config=vars(args),
        )
    else:
        print("[INFO] WandB skipped.")

    # ── Environment ──────────────────────────────────────────
    from patchhawk.agent.environment import PatchHawkEnv

    env = PatchHawkEnv(
        scenarios_path=str(_PROJECT_ROOT / "patchhawk" / "data" / "scenarios.json"),
        use_docker=args.use_docker,
    )
    print(f"Loaded {len(env.scenarios)} scenarios.")

    if args.dry_run:
        _dry_run_training(env, args)
        return

    # ── GPU training imports ─────────────────────────────────
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainerCallback,
    )
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("No GPU found — training will be slow.")

    from dotenv import load_dotenv
    load_dotenv()

    MODEL_NAME = os.getenv("GRPO_POLICY_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")

    # 4-bit quantisation config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    print(f"Loading {MODEL_NAME} in 4-bit ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Critical: set total sequence length (prompt + generation)
    tokenizer.model_max_length = args.max_seq_len

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    base_model = prepare_model_for_kbit_training(
        base_model,
        use_gradient_checkpointing=True,
    )

    # LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()

    # ── Reward 1: XML format ─────────────────────────────────
    def format_reward(completions, **kwargs):
        """Score each completion on how well it follows the XML template.

        Per completion: +0.5 for a <thought> block (else -1.0), +0.5 for a
        <risk_score> tag (else -1.0), +0.5 for an <action> in 0-4 (else -1.5),
        and, only when the action is 3, +0.5 for a <patch> block (else -2.0).
        """
        rewards = []
        for c in completions:
            text = c if isinstance(c, str) else str(c)
            score = 0.0
            if re.search(r"<thought>.*?</thought>", text, re.DOTALL):
                score += 0.5
            else:
                score -= 1.0
            if re.search(r"<risk_score>[\d\.]+</risk_score>", text):
                score += 0.5
            else:
                score -= 1.0
            if re.search(r"<action>[0-4]</action>", text):
                score += 0.5
            else:
                score -= 1.5
            if "<action>3</action>" in text:
                if re.search(r"<patch>.*?</patch>", text, re.DOTALL):
                    score += 0.5
                else:
                    score -= 2.0
            rewards.append(score)
        return rewards

    # ── Reward 2: environment feedback ───────────────────────
    from patchhawk.env_models import PatchHawkAction

    # Map each stripped code snippet to its scenario once, instead of scanning
    # env.scenarios for every completion. setdefault keeps the first scenario
    # for a duplicated snippet, matching a first-match linear search.
    scenario_by_snippet = {}
    for s in env.scenarios:
        scenario_by_snippet.setdefault(s.get("code_snippet", "").strip(), s)

    def env_reward(completions, prompts, **kwargs):
        """Replay each completion's action in the environment and return its reward.

        Returns -2.0 when the prompt's scenario cannot be identified or the
        completion has no <action> tag, and -3.0 when the environment raises.
        """
        rewards = []
        for prompt, c in zip(prompts, completions):
            text = c if isinstance(c, str) else str(c)

            # Extract code snippet from prompt to identify scenario
            code_match = re.search(r"<code_snippet>(.*?)</code_snippet>", prompt, re.DOTALL)
            if not code_match:
                rewards.append(-2.0)
                continue
            scenario = scenario_by_snippet.get(code_match.group(1).strip())
            if scenario is None:
                rewards.append(-2.0)
                continue

            # Parse action
            action_match = re.search(r"<action>(\d+)</action>", text)
            if not action_match:
                rewards.append(-2.0)
                continue
            action_type = int(action_match.group(1))

            # Parse patch (if any)
            patch = None
            patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
            if patch_match:
                patch = patch_match.group(1).strip()

            # The pattern also matches non-numbers such as "..." or "1.2.3";
            # treat those as "no risk given" rather than letting float()
            # abort the whole training run.
            predicted_risk = None
            risk_match = re.search(r"<risk_score>([\d\.]+)</risk_score>", text)
            if risk_match:
                try:
                    predicted_risk = float(risk_match.group(1))
                except ValueError:
                    predicted_risk = None

            try:
                # Reset environment to the exact scenario
                env.reset(scenario=scenario)
                obs = env.step(PatchHawkAction(
                    action_type=action_type,
                    patch_content=patch,
                    predicted_risk=predicted_risk
                ))
                reward_val = float(obs.reward or 0.0)
                rewards.append(reward_val)
                val_msg = obs.metadata.get('validation') or ("Telemetry Extracted" if obs.metadata.get('telemetry') else "None")
                print(f"[Env Reward] Action: {action_type} | Reward: {reward_val:+.2f} | Docker: {val_msg}")
            except Exception as exc:
                # Best effort: a failing environment step is penalised, not fatal.
                print(f"env_reward crash: {exc}")
                rewards.append(-3.0)
        return rewards

    # ── Dataset preparation ──────────────────────────────────
    valid = [s for s in env.scenarios if s.get("label") in ("malicious", "benign")]
    random.seed(42)  # fixed seed keeps the train/eval split reproducible
    random.shuffle(valid)

    split = int(0.8 * len(valid))
    if split == 0:
        # int(0.8 * n) is 0 for n < 2, which would give an empty train set.
        raise ValueError(
            f"Need at least 2 scenarios labelled 'malicious' or 'benign' to train; found {len(valid)}."
        )
    train_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[:split]])
    eval_ds = Dataset.from_list([{"prompt": _build_prompt(s)} for s in valid[split:]])
    print(f"Dataset — train: {len(train_ds)}, eval: {len(eval_ds)}")

    # ── GRPO Config (trl 1.0.0 compatible) ───────────────────
    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        fp16=False,                     # avoids BFloat16 AMP error
        gradient_checkpointing=True,
        num_generations=args.group_size,
        beta=args.kl_coef,
        num_train_epochs=args.epochs,
        warmup_steps=10,
        max_grad_norm=1.0,
        logging_steps=1,                # log every step
        logging_first_step=True,        # log step 0 immediately
        save_steps=50,
        report_to="wandb" if (wandb is not None and not args.dry_run) else "none",
    )

    # ── Custom callback: force WandB logging + progress bar (no step warnings) ──
    class ForceWandbCallback(TrainerCallback):
        """Mirror every trainer log record to WandB and show the loss in the progress bar."""

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            # Log everything to wandb WITHOUT step argument (avoids step warnings)
            if wandb is not None and wandb.run is not None:
                wandb.log(logs)
            # Update progress bar with loss (first matching key wins)
            loss_key = None
            for key in ["loss", "grpo_loss", "train_loss"]:
                if key in logs:
                    loss_key = key
                    break
            if loss_key is not None:
                loss_val = logs[loss_key]
                # Only has an effect when the state object exposes a progress bar.
                if hasattr(state, "progress_bar") and state.progress_bar is not None:
                    state.progress_bar.set_postfix({loss_key: f"{loss_val:.4f}"})

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[format_reward, env_reward],
        args=grpo_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        # Hand over the tokenizer configured above. Per the TRL docs, when
        # this is omitted the trainer loads its own from the model name, so
        # the padding and max-length settings made here would not be used.
        processing_class=tokenizer,
    )
    trainer.add_callback(ForceWandbCallback())

    print("Starting GRPO training ...")
    try:
        trainer.train()
    finally:
        # Ensure all pending logs are sent to wandb, even if training fails.
        if wandb is not None and wandb.run is not None:
            wandb.finish()

    # ── Save LoRA adapter ────────────────────────────────────
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(out))
    tokenizer.save_pretrained(str(out))
    print(f"LoRA adapter saved to {out}")

    # ── Optional HF Hub upload ───────────────────────────────
    hf_repo = os.getenv("HF_REPO", "")
    if hf_repo:
        try:
            model.push_to_hub(hf_repo)
            tokenizer.push_to_hub(hf_repo)
            print(f"Uploaded to https://huggingface.co/{hf_repo}")
        except Exception as exc:
            # The adapter is already saved locally; an upload failure is non-fatal.
            print(f"HF upload failed: {exc}")


# ─────────────────────────────────────────────────────────────
# Dry-run (CPU simulation, no model)
# ─────────────────────────────────────────────────────────────
def _dry_run_training(env, args):
    """Simulate the training loop on CPU with a fixed heuristic policy.

    No model is loaded. For each epoch, groups of ``args.group_size`` episodes
    are rolled out in ``env``; per-group mean reward and normalised advantages
    are printed, and per-attack-type success counts are collected. Epoch
    metrics are sent to WandB only when a run is active. Finally a placeholder
    adapter (``adapter_config.json`` plus a 64-byte ``adapter_model.bin``) is
    written to ``args.output_dir``.

    Args:
        env: Environment object. Uses ``scenarios``, ``max_steps``,
            ``current_scenario``, ``reset()``, ``step()`` and the
            ``ACTION_BLOCK_PR`` / ``ACTION_EXECUTE_SANDBOX`` /
            ``ACTION_REQUEST_REVIEW`` constants.
        args: Parsed CLI namespace. Reads ``epochs``, ``max_steps``,
            ``group_size`` and ``output_dir``.
    """
    print("[DRY RUN] CPU simulation only — no model loaded.\n")
    from patchhawk.env_models import PatchHawkAction

    def heuristic_policy(obs):
        """Pick an action from the observation's risk score alone."""
        risk = obs.risk_score
        if risk > 0.5:
            return PatchHawkAction(action_type=env.ACTION_BLOCK_PR)
        elif risk > 0.2:
            return PatchHawkAction(action_type=env.ACTION_EXECUTE_SANDBOX)
        return PatchHawkAction(action_type=env.ACTION_REQUEST_REVIEW)

    for epoch in range(args.epochs):
        print(f"── Epoch {epoch + 1}/{args.epochs} ──")
        epoch_rewards = []
        attack_success = {}

        # One iteration per group; the number of groups is bounded by both the
        # scenario count and args.max_steps.
        for _batch in range(0, min(len(env.scenarios), args.max_steps), args.group_size):
            group_rewards = []
            for _episode in range(args.group_size):
                obs = env.reset()
                ep_reward = 0.0
                steps = 0
                while not obs.done and steps < env.max_steps:
                    obs = env.step(heuristic_policy(obs))
                    ep_reward += float(obs.reward or 0.0)
                    steps += 1
                group_rewards.append(ep_reward)

                label = env.current_scenario.get("label", "benign")
                atype = env.current_scenario.get("attack_type", "none") or "none"
                attack_success.setdefault(atype, {"correct": 0, "total": 0})
                attack_success[atype]["total"] += 1
                # Malicious episodes need a strictly positive return to count
                # as correct; benign ones only need a non-negative return.
                if (label == "malicious" and ep_reward > 0) or (label == "benign" and ep_reward >= 0):
                    attack_success[atype]["correct"] += 1

            mean_r = float(np.mean(group_rewards))
            std_r = float(np.std(group_rewards)) + 1e-8  # epsilon avoids division by zero
            advantages = [(r - mean_r) / std_r for r in group_rewards]
            epoch_rewards.append(mean_r)
            print(f"  Batch mean_reward={mean_r:+.2f}  advantages={[f'{a:+.2f}' for a in advantages]}")

        epoch_mean = float(np.mean(epoch_rewards)) if epoch_rewards else 0.0
        print(f"  Epoch {epoch + 1} mean_reward: {epoch_mean:+.2f}")
        for atype, counts in attack_success.items():
            rate = counts["correct"] / max(counts["total"], 1)
            print(f"    {atype}: {rate:.0%} ({counts['correct']}/{counts['total']})")

        # Log only when a WandB run is actually active; calling wandb.log
        # without one fails. Logging stays best-effort, but a failure is
        # reported instead of being silently discarded.
        if wandb is not None and wandb.run is not None:
            try:
                log_data = {
                    "epoch": epoch + 1,
                    "mean_reward": epoch_mean,
                    # Synthetic stand-in derived from the reward; no real loss exists here.
                    "loss": max(0.0, 1.0 - epoch_mean / 3.0),
                }
                for atype, counts in attack_success.items():
                    log_data[f"success_rate/{atype}"] = counts["correct"] / max(counts["total"], 1)
                wandb.log(log_data)
            except Exception as exc:
                print(f"[DRY RUN] WandB logging failed: {exc}")

    # Write placeholder adapter files to the output directory.
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "adapter_config.json").write_text('{"model_type":"patchhawk-grpo-dry-run"}')
    (out / "adapter_model.bin").write_bytes(b"\x00" * 64)
    print(f"\n[DRY RUN] Dummy adapter written to {args.output_dir}/")


# ─────────────────────────────────────────────────────────────
# CLI entry point
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Options are registered in a fixed order so the --help listing is stable.
    cli = argparse.ArgumentParser(description="PatchHawk GRPO Training (trl 1.0.0)")

    # Boolean switches: off unless given on the command line.
    for switch, switch_help in (
        ("--dry-run", "CPU simulation, no model"),
        ("--use-docker", "Use Docker sandbox"),
    ):
        cli.add_argument(switch, action="store_true", help=switch_help)

    cli.add_argument(
        "--max-seq-len",
        type=int,
        default=1024,
        help="Total sequence length (prompt+completion)",
    )

    # Plain valued options without help text: (flag, type, default).
    for option, caster, fallback in (
        ("--learning-rate", float, 5e-6),
        ("--kl-coef", float, 0.01),
        ("--batch-size", int, 1),
        ("--grad-accum", int, 8),
        ("--group-size", int, 4),
        ("--epochs", int, 3),
        ("--max-steps", int, 200),
        ("--output-dir", str, "grpo_lora"),
    ):
        cli.add_argument(option, type=caster, default=fallback)

    train_agent(cli.parse_args())