sam25kat Claude Sonnet 4.6 committed on
Commit
443f900
·
1 Parent(s): a28dc6a

Add HF training Space: Gradio UI + GRPO train script

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

training_space/README.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SecureReview GRPO Trainer
3
+ emoji: 🔐
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: "4.40.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ hardware: t4-small
11
+ ---
12
+
13
+ # SecureReview — GRPO Trainer
14
+
15
+ Trains `Qwen2.5-1.5B-Instruct` using Group Relative Policy Optimization (GRPO) on the [SecureReview](https://sam25kat-securereview.hf.space) environment.
16
+
17
+ ## What this does
18
+
19
+ - Loads the model in 4-bit QLoRA (via Unsloth)
20
+ - Connects to the live SecureReview environment as a reward oracle
21
+ - Runs 150 GRPO training steps — reward = F1-based score from graded vulnerability findings
22
+ - Produces `plots/reward_curve.png` and `plots/before_after.png`
23
+
24
+ ## Usage
25
+
26
+ Click **Run Training** in the Gradio UI. Training takes about 20 minutes on a T4 GPU.
27
+
28
+ ## Environment
29
+
30
+ The reward signal comes from [sam25kat/securereview](https://huggingface.co/spaces/sam25kat/securereview) — a live OpenEnv environment that grades security findings against ground-truth scenarios.
training_space/app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import sys
4
+ import os
5
+ import json
6
+ import threading
7
+
8
# Directory that holds the plots and the results file read back by this app.
PLOTS_DIR = "./plots"
# JSON summary (baseline/trained means) expected inside PLOTS_DIR.
RESULTS_FILE = PLOTS_DIR + "/results.json"
10
+
11
+
12
def run_training():
    """Run ``train.py`` in a subprocess and stream its output to the UI.

    This is a generator. Every yielded value is a 5-tuple in the order
    ``(log text, reward-curve image path, before/after image path,
    summary markdown, run-button update)``.

    While the subprocess is running only the log text is populated. The
    final yield carries the plot paths and the results summary (only when
    the script exited successfully) and re-enables the run button with its
    original label.
    """
    os.makedirs(PLOTS_DIR, exist_ok=True)
    proc = subprocess.Popen(
        [sys.executable, "train.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave errors with the normal log
        text=True,
        bufsize=1,  # line-buffered so progress shows up as it is printed
    )
    output = []
    try:
        for line in proc.stdout:
            output.append(line.rstrip())
            yield "\n".join(output), None, None, None, None
        returncode = proc.wait()
    finally:
        # If the consumer stops iterating early (e.g. the stream is closed),
        # do not leave the training process running with nobody draining
        # its pipe.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        proc.stdout.close()

    succeeded = returncode == 0
    reward_path = f"{PLOTS_DIR}/reward_curve.png"
    ba_path = f"{PLOTS_DIR}/before_after.png"

    # Artifacts are only trusted after a successful run; otherwise files left
    # over from an earlier run would be shown as if they were fresh results.
    reward_img = reward_path if succeeded and os.path.exists(reward_path) else None
    ba_img = ba_path if succeeded and os.path.exists(ba_path) else None

    summary = ""
    if not succeeded:
        output.append(f"[train.py exited with code {returncode}]")
        summary = f"**Training failed** (exit code {returncode}). See the log for details."
    elif os.path.exists(RESULTS_FILE):
        try:
            with open(RESULTS_FILE, encoding="utf-8") as f:
                r = json.load(f)
            summary = (
                f"**Baseline mean:** {r['baseline_mean']:.3f}\n\n"
                f"**Trained mean:** {r['trained_mean']:.3f}\n\n"
                f"**Improvement:** {r['improvement']:+.3f}"
            )
        except (OSError, ValueError, KeyError, TypeError) as exc:
            # A malformed results file must not leave the button disabled.
            summary = f"Training finished, but the results file could not be read: {exc}"

    # Restore the label as well as interactivity; the click handler replaced
    # the label with "Training in progress..." when the run started.
    yield (
        "\n".join(output),
        reward_img,
        ba_img,
        summary,
        gr.update(interactive=True, value="▶ Run Training"),
    )
42
+
43
+
44
def start_training(btn_state):
    """Click handler: lock the run button, then stream the training run.

    ``btn_state`` is the value Gradio passes for the handler's input
    component; it is not used. The first yield shows a start-up banner,
    clears the result widgets and disables the button; every later value
    comes straight from ``run_training()``.
    """
    banner = "Starting training... this takes ~20 minutes on T4.\n"
    locked_button = gr.update(interactive=False, value="Training in progress...")
    yield banner, None, None, "", locked_button
    yield from run_training()
53
+
54
+
55
# Page layout: intro text, a run button, then the streaming log next to the
# result widgets (summary and the two plots).
with gr.Blocks(theme=gr.themes.Monochrome(), title="SecureReview GRPO Trainer") as demo:
    gr.Markdown(
        """# SecureReview — GRPO Training

Trains `Qwen2.5-1.5B-Instruct` via GRPO on the [SecureReview](https://sam25kat-securereview.hf.space) environment.
The model learns to identify security vulnerabilities in dependency files — reward comes from a live graded environment, not a static dataset.

**Hardware:** T4 GPU · **Time:** ~20 min · **Steps:** 150"""
    )

    with gr.Row():
        run_btn = gr.Button("▶ Run Training", variant="primary", scale=1)

    with gr.Row():
        # Left: live training log.
        log_box = gr.Textbox(
            label="Training Log",
            lines=30,
            max_lines=60,
            autoscroll=True,
            interactive=False,
            scale=2,
        )
        # Right: results, filled in once training has finished.
        with gr.Column(scale=1):
            summary_md = gr.Markdown(label="Results Summary")
            reward_img = gr.Image(label="Reward Curve", type="filepath")
            ba_img = gr.Image(label="Before vs After", type="filepath")

    # The handler is a generator; each yielded tuple maps onto these outputs
    # in order, the last element being an update for the button itself.
    run_btn.click(
        fn=start_training,
        inputs=[run_btn],
        outputs=[log_box, reward_img, ba_img, summary_md, run_btn],
    )

if __name__ == "__main__":
    demo.launch()
training_space/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
2
+ trl>=0.12.0
3
+ datasets>=2.18.0
4
+ transformers>=4.40.0
5
+ accelerate>=0.29.0
6
+ bitsandbytes>=0.43.0
7
+ peft>=0.10.0
8
+ requests>=2.31.0
9
+ matplotlib>=3.8.0
10
+ numpy>=1.26.0
11
+ gradio>=4.40.0
training_space/train.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import re
5
+ import time
6
+ import requests
7
+ import functools
8
+
9
# Flush on every print so that a process reading this script's stdout
# through a pipe receives progress output line by line.
print = functools.partial(print, flush=True)

# --- Reward environment ------------------------------------------------------
ENV_URL = os.getenv("ENV_URL", "https://sam25kat-securereview.hf.space")
TASK_ID = "dependency_review"

# --- Model and generation ----------------------------------------------------
MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct"
MAX_SEQ_LEN = 2048
NUM_GENERATIONS = 4       # completions sampled per prompt during GRPO
MAX_NEW_TOKENS = 600      # cap on completion length (training and evaluation)

# --- Optimisation ------------------------------------------------------------
TRAIN_STEPS = 150
LEARNING_RATE = 2e-5
LORA_RANK = 16            # also used as lora_alpha
GRAD_ACCUM_STEPS = 4

# --- Output locations --------------------------------------------------------
OUTPUT_DIR = "./securereview-grpo"
PLOTS_DIR = "./plots"

for _artifact_dir in (PLOTS_DIR, OUTPUT_DIR):
    os.makedirs(_artifact_dir, exist_ok=True)

# System message sent with every prompt, both in training and in evaluation.
SYSTEM_PROMPT = """You are a senior security engineer reviewing dependency files for vulnerabilities.

Identify ALL security issues including:
- Typosquatted packages (names that misspell popular libraries, e.g. 'reqeusts' instead of 'requests')
- Known CVE-vulnerable versions (e.g. requests<2.20.0 has CVE-2018-18074)
- Hallucinated / non-existent packages that don't exist on PyPI or npm
- Suspicious or malicious packages

Output ONLY a valid JSON array of findings. Each finding must have:
file, line (integer or null), rule_id (e.g. DEP-001), severity (critical/high/medium/low/info), description

Example output:
[
{"file": "requirements.txt", "line": 3, "rule_id": "DEP-001", "severity": "critical", "description": "Typosquat: 'reqeusts' misspells 'requests'"}
]

If no issues found, output: []
Output ONLY the JSON array. No explanations, no markdown prose."""
45
+
46
+
47
+ # ── Environment helpers ───────────────────────────────────────────────────────
48
+
49
def env_reset(task_id, scenario_id=None):
    """POST to the environment's ``/reset`` endpoint and return its JSON body.

    ``scenario_id`` is only included in the request when it is truthy.
    Raises ``requests.HTTPError`` on a non-success status code.
    """
    body = {"task_id": task_id}
    if scenario_id:
        body.update(scenario_id=scenario_id)
    response = requests.post(f"{ENV_URL}/reset", json=body, timeout=30)
    response.raise_for_status()
    return response.json()
56
+
57
+
58
def env_step(action):
    """POST ``action`` to the environment's ``/step`` endpoint.

    The action is wrapped as ``{"action": action}``. Returns the decoded
    JSON response; raises ``requests.HTTPError`` on a non-success status.
    """
    step_url = f"{ENV_URL}/step"
    response = requests.post(step_url, json=dict(action=action), timeout=30)
    response.raise_for_status()
    return response.json()
62
+
63
+
64
def parse_findings(text):
    """Extract the JSON array of findings from a model completion.

    Candidates are tried in this order, and the first one that decodes as
    JSON is returned:

    1. arrays wrapped in a Markdown code fence (```` ``` ```` or
       ```` ```json ````);
    2. bare arrays of objects, i.e. a ``[`` followed by ``{``.

    Every candidate is tried, so an invalid bracketed snippet earlier in
    the text does not hide a valid array that follows it. Bare arrays are
    decoded with a real JSON parser from their opening bracket, so ``}]``
    inside a string or a nested array does not cut the array short.

    Returns the decoded list, or ``[]`` when nothing could be decoded.
    """
    # 1) Fenced arrays take priority.
    for fenced in re.finditer(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL):
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            continue

    # 2) Bare arrays of objects: let the JSON decoder find where the array
    #    ends instead of guessing with a non-greedy regex.
    decoder = json.JSONDecoder()
    for opener in re.finditer(r'\[\s*\{', text):
        try:
            parsed, _ = decoder.raw_decode(text[opener.start():])
        except json.JSONDecodeError:
            continue
        return parsed

    return []
77
+
78
+
79
_VALID_SEVERITIES = frozenset({"critical", "high", "medium", "low", "info"})


def _coerce_line(value):
    """Return ``value`` as an ``int`` line number, or ``None`` if it is not one."""
    if isinstance(value, bool):  # bool is an int subclass; not a line number
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str):
        try:
            return int(value.strip())
        except ValueError:
            return None
    return None


def _normalize_finding(raw):
    """Build the finding payload sent to the environment from a parsed dict."""
    severity = raw.get("severity")
    # Only strings can name a severity. Checking the type first also avoids a
    # TypeError from testing an unhashable value (list/dict) against the set.
    severity = severity.strip().lower() if isinstance(severity, str) else None
    if severity not in _VALID_SEVERITIES:
        severity = "medium"
    return {
        "file": str(raw.get("file", "requirements.txt")),
        "line": _coerce_line(raw.get("line")),
        "rule_id": str(raw.get("rule_id", "DEP-001")),
        "severity": severity,
        "description": str(raw.get("description", ""))[:400],
    }


def run_episode(completion, scenario_id):
    """Score one model completion against one environment scenario.

    Parses the findings out of ``completion``, resets the environment to
    ``scenario_id``, reports each finding, marks the episode complete and
    returns the reward from the final step as a float.

    Non-dict entries are skipped. Each finding is sanitised before it is
    sent: ``line`` becomes an int or ``None``, ``severity`` is lower-cased
    and falls back to ``"medium"`` when unknown, and the description is
    capped at 400 characters. This keeps one malformed finding from
    aborting the whole episode.

    Returns 0.01 if any environment call fails or the reward is unusable.
    """
    findings = parse_findings(completion)
    try:
        env_reset(TASK_ID, scenario_id)
        for raw in findings:
            if not isinstance(raw, dict):
                continue
            env_step({"action_type": "report_finding", "finding": _normalize_finding(raw)})
        result = env_step({"action_type": "mark_complete"})
        return float(result.get("reward", 0.01))
    except Exception as e:
        # Deliberately broad: this feeds the training reward function, which
        # must return a number rather than crash the run.
        print(f" [env error] {e}")
        return 0.01
100
+
101
+
102
def build_prompt(obs):
    """Render an environment observation as the user prompt text.

    Reads ``obs["observation"]["context"]``: its ``task_description``
    becomes the header line, each entry of ``files`` is appended as a
    ``--- filename ---`` section followed by the file content, and a closing
    instruction asks for the issues as a JSON array.
    """
    context = obs["observation"]["context"]
    file_entries = context["files"]
    header = f"Task: {context['task_description']}\n"
    sections = "".join(
        f"\n--- {entry['filename']} ---\n{entry['content']}" for entry in file_entries
    )
    return header + sections + "\nList all security issues as a JSON array:"
110
+
111
+
112
+ # ── Main training entry point ─────────────────────────────────────────────────
113
+
114
def _mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


def _save_plots(reward_log, scenario_ids, b_vals, t_vals):
    """Write ``reward_curve.png`` and ``before_after.png`` into ``PLOTS_DIR``.

    ``reward_log`` is a list of ``{"step", "mean_reward"}`` dicts;
    ``b_vals`` / ``t_vals`` are the before/after scores, aligned with
    ``scenario_ids``.
    """
    import matplotlib
    matplotlib.use("Agg")  # render to files; no display is needed
    import matplotlib.pyplot as plt
    import numpy as np

    plt.style.use("dark_background")

    # Reward curve: raw values plus a moving average once enough points exist.
    steps = [e["step"] for e in reward_log]
    rewards = [e["mean_reward"] for e in reward_log]
    window = 5
    if len(rewards) >= window:
        smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")
        smooth_steps = steps[window - 1:]
    else:
        smoothed, smooth_steps = rewards, steps

    fig, ax = plt.subplots(figsize=(11, 4))
    ax.plot(steps, rewards, color="#ff6b35", alpha=0.3, linewidth=1, label="Raw")
    ax.plot(smooth_steps, smoothed, color="#ff6b35", linewidth=2.5, label=f"Smoothed (w={window})")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Episode Reward")
    ax.set_title("SecureReview — GRPO Training Reward Curve", fontweight="bold")
    ax.set_ylim(0, 1)
    ax.legend()
    ax.grid(True, alpha=0.2)
    fig.tight_layout()
    fig.savefig(f"{PLOTS_DIR}/reward_curve.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {PLOTS_DIR}/reward_curve.png")

    # Before/after bars per scenario, annotated with the per-scenario delta.
    x = np.arange(len(scenario_ids))
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(x - 0.175, b_vals, 0.35, label="Before", color="#444444")
    ax.bar(x + 0.175, t_vals, 0.35, label="After", color="#ff6b35")
    for i, (b, t) in enumerate(zip(b_vals, t_vals)):
        ax.text(i + 0.175, t + 0.02, f"{t-b:+.2f}", ha="center", fontsize=9,
                color="#22d3ee" if t >= b else "#ef4444")
    ax.set_xticks(x)
    ax.set_xticklabels([s.replace("dep_", "Dep ") for s in scenario_ids], rotation=15)
    ax.set_ylim(0, 1)
    ax.legend()
    ax.set_title("SecureReview — Before vs After GRPO", fontweight="bold")
    mb = _mean(b_vals)
    mt = _mean(t_vals)
    ax.text(0.98, 0.92, f"Mean: {mb:.2f} → {mt:.2f} ({mt-mb:+.2f})",
            transform=ax.transAxes, ha="right", fontsize=10, color="white",
            bbox=dict(boxstyle="round", facecolor="#1a1a1a", alpha=0.8))
    fig.tight_layout()
    fig.savefig(f"{PLOTS_DIR}/before_after.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {PLOTS_DIR}/before_after.png")


def main():
    """Run the full pipeline: check env, load model, train, evaluate, report.

    Steps, in order:
      1. check the environment's ``/health`` endpoint;
      2. load the 4-bit base model and attach LoRA adapters;
      3. build the prompt dataset from scenarios ``dep_001`` .. ``dep_006``;
      4. score the untrained model on every scenario;
      5. run GRPO training with the environment as the reward function;
      6. score the trained model, then write both plots and
         ``results.json`` into ``PLOTS_DIR``.

    Raises:
        requests.HTTPError: if the health check returns an error status.
        RuntimeError: if no scenario could be loaded for training.
    """
    print("=" * 60)
    print(" SecureReview GRPO Training")
    print(f" Model : {MODEL_NAME}")
    print(f" Task : {TASK_ID}")
    print(f" Steps : {TRAIN_STEPS}")
    print("=" * 60)

    # Verify environment
    print("\n[1/6] Checking environment connection...")
    health = requests.get(f"{ENV_URL}/health", timeout=15)
    health.raise_for_status()  # fail early and clearly if the env is unhealthy
    print(f" Health: {health.json()}")

    # Load model
    print("\n[2/6] Loading model (this takes ~2 min)...")
    from unsloth import FastLanguageModel
    import torch

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LEN,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=LORA_RANK,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    print(" Trainable params: ", end="")
    model.print_trainable_parameters()

    # Build dataset
    print("\n[3/6] Building training dataset...")
    from datasets import Dataset

    scenario_ids = [f"dep_{i:03d}" for i in range(1, 7)]
    examples = []
    for sid in scenario_ids:
        try:
            obs = env_reset(TASK_ID, sid)
            prompt = build_prompt(obs)
            examples.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                "scenario_id": sid,
            })
            print(f" Loaded {sid}")
        except Exception as e:
            print(f" Skipping {sid}: {e}")

    if not examples:
        # Without this guard the repeat computation below divides by zero.
        raise RuntimeError(
            "No training scenarios could be loaded from the environment; "
            "cannot build a dataset."
        )

    # Repeat the few prompts so the dataset is long enough for the run.
    repeats = max(1, TRAIN_STEPS // (len(examples) * NUM_GENERATIONS) + 1)
    examples = examples * repeats
    dataset = Dataset.from_list(examples)
    print(f" Dataset: {len(examples)} examples")

    # Baseline evaluation
    print("\n[4/6] Baseline evaluation (before training)...")
    FastLanguageModel.for_inference(model)

    def evaluate(sids, label):
        """Generate one completion per scenario and return ``{sid: score}``.

        A scenario whose observation cannot be fetched is logged and left
        out of the result instead of aborting the run; the reporting code
        below treats a missing score as 0.
        """
        scores = {}
        for sid in sids:
            try:
                obs = env_reset(TASK_ID, sid)
            except requests.RequestException as exc:
                print(f" [{label}] {sid}: skipped ({exc})")
                continue
            prompt_text = build_prompt(obs)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt_text},
            ]
            inputs = tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
            ).to("cuda")
            with torch.no_grad():
                outputs = model.generate(
                    inputs, max_new_tokens=MAX_NEW_TOKENS,
                    temperature=0.1, do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens, not the prompt.
            completion = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            score = run_episode(completion, sid)
            scores[sid] = score
            print(f" [{label}] {sid}: {score:.3f}")
            time.sleep(0.3)  # brief pause between scenarios
        return scores

    baseline_scores = evaluate(scenario_ids, "before")
    print(f" Baseline mean: {_mean(baseline_scores.values()):.3f}")

    # GRPO training
    print("\n[5/6] GRPO training...")
    FastLanguageModel.for_training(model)

    from trl import GRPOConfig, GRPOTrainer

    reward_log = []
    step_counter = [0]  # boxed int so the closure below can increment it

    def reward_fn(completions, prompts=None, **kwargs):
        """Score a batch of completions by replaying each in the environment.

        Scenario ids are read from the ``scenario_id`` keyword argument;
        when it is absent the first scenario id is used for every completion.
        """
        sids = kwargs.get("scenario_id", [scenario_ids[0]] * len(completions))
        if isinstance(sids, str):
            sids = [sids] * len(completions)
        rewards = []
        for completion, sid in zip(completions, sids):
            # Completions are either plain strings or chat-style message lists.
            text = completion if isinstance(completion, str) else completion[-1]["content"]
            rewards.append(run_episode(text, sid))
        step_counter[0] += 1
        mean_r = _mean(rewards)
        reward_log.append({"step": step_counter[0], "mean_reward": mean_r})
        print(f" Step {step_counter[0]:>4d} | {[round(r,3) for r in rewards]} | mean {mean_r:.3f}")
        return rewards

    training_args = GRPOConfig(
        num_generations=NUM_GENERATIONS,
        max_completion_length=MAX_NEW_TOKENS,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        max_steps=TRAIN_STEPS,
        logging_steps=5,
        save_steps=50,
        output_dir=OUTPUT_DIR,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        seed=42,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_fn,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # Post-training evaluation
    print("\n[6/6] Post-training evaluation...")
    FastLanguageModel.for_inference(model)
    trained_scores = evaluate(scenario_ids, "after")
    print(f" Trained mean: {_mean(trained_scores.values()):.3f}")

    print("\n=== Improvement Summary ===")
    for sid in scenario_ids:
        b = baseline_scores.get(sid, 0)
        t = trained_scores.get(sid, 0)
        arrow = "▲" if t > b else ("▼" if t < b else "—")
        print(f" {sid}: {b:.3f} → {t:.3f} {arrow} {t-b:+.3f}")

    # Plots
    b_vals = [baseline_scores.get(s, 0) for s in scenario_ids]
    t_vals = [trained_scores.get(s, 0) for s in scenario_ids]
    _save_plots(reward_log, scenario_ids, b_vals, t_vals)

    # Save results JSON for the Gradio app to read
    baseline_mean = _mean(b_vals)
    trained_mean = _mean(t_vals)
    results = {
        "baseline_mean": baseline_mean,
        "trained_mean": trained_mean,
        "improvement": trained_mean - baseline_mean,
        "baseline_scores": baseline_scores,
        "trained_scores": trained_scores,
    }
    with open(f"{PLOTS_DIR}/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print(f" DONE — Mean {baseline_mean:.3f} → {trained_mean:.3f}")
    print("=" * 60)


if __name__ == "__main__":
    main()