oxdev commited on
Commit
74022f8
Β·
verified Β·
1 Parent(s): c4b5a68

Upload train_grpo_job.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_grpo_job.py +220 -0
train_grpo_job.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ train_grpo_job.py β€” Self-contained GRPO training job for HF Jobs.
4
+
5
+ Loads dataset from HF Hub, runs GRPO training with custom reward functions,
6
+ pushes model to Hub on completion.
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import re
12
+ import shutil
13
+ import subprocess
14
+ import tempfile
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ from datasets import load_dataset
19
+ from trl import GRPOTrainer, GRPOConfig
20
+
21
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ─── Config ───────────────────────────────────────────────────────────────────
25
+ MODEL_NAME = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
26
+ DATASET_ID = "oxdev/smart-contract-security-sft"
27
+ OUTPUT_DIR = "/tmp/grpo_output"
28
+ HUB_MODEL_ID = "oxdev/security-auditor-grpo"
29
+
30
+ FORGE_AVAILABLE = shutil.which("forge") is not None
31
+
32
+ # ─── Reward Functions ─────────────────────────────────────────────────────────
33
+
34
+ def extract_finding_block(text: str) -> dict | None:
35
+ pattern = re.compile(
36
+ r'FINDING\s*\|\s*contract:\s*(\S+)\s*\|\s*function:\s*(\S+)\s*\|'
37
+ r'\s*bug_class:\s*(\S+)\s*\|\s*confidence:\s*(\d+)',
38
+ re.IGNORECASE
39
+ )
40
+ match = pattern.search(text)
41
+ if not match:
42
+ return None
43
+ return {
44
+ "contract": match.group(1),
45
+ "function": match.group(2),
46
+ "bug_class": match.group(3),
47
+ "confidence": int(match.group(4)),
48
+ }
49
+
50
+
51
+ def extract_solidity_poc(text: str) -> str | None:
52
+ pattern = re.compile(r'```solidity\s*\n(.*?)```', re.DOTALL)
53
+ matches = pattern.findall(text)
54
+ if not matches:
55
+ return None
56
+ for code in matches:
57
+ if "is Test" in code or "function test_" in code:
58
+ return code.strip()
59
+ return max(matches, key=len).strip() if matches else None
60
+
61
+
62
+ def _check_solidity_syntax(code: str) -> bool:
63
+ required = [r'pragma\s+solidity', r'contract\s+\w+', r'function\s+\w+']
64
+ return all(re.search(p, code) for p in required)
65
+
66
+
67
+ def run_forge_test(poc_code: str, timeout: int = 30) -> dict:
68
+ if not FORGE_AVAILABLE:
69
+ return {
70
+ "compiled": False,
71
+ "test_passed": False,
72
+ "syntax_valid": _check_solidity_syntax(poc_code),
73
+ }
74
+
75
+ tmpdir = tempfile.mkdtemp(prefix="forge_poc_")
76
+ try:
77
+ test_dir = Path(tmpdir) / "test"
78
+ test_dir.mkdir()
79
+ (Path(tmpdir) / "foundry.toml").write_text('[profile.default]\nsrc = "src"\nout = "out"\nlibs = ["lib"]\nsolc_version = "0.8.24"\n')
80
+ (Path(tmpdir) / "src").mkdir()
81
+
82
+ try:
83
+ subprocess.run(
84
+ ["forge", "install", "foundry-rs/forge-std", "--no-git", "--no-commit"],
85
+ cwd=tmpdir, capture_output=True, timeout=60,
86
+ )
87
+ except Exception:
88
+ pass
89
+
90
+ (Path(tmpdir) / "remappings.txt").write_text("forge-std/=lib/forge-std/src/\n")
91
+ (test_dir / "PoC.t.sol").write_text(poc_code)
92
+
93
+ build = subprocess.run(["forge", "build"], cwd=tmpdir, capture_output=True, text=True, timeout=timeout)
94
+ if build.returncode != 0:
95
+ return {"compiled": False, "test_passed": False}
96
+
97
+ test = subprocess.run(["forge", "test", "-vv"], cwd=tmpdir, capture_output=True, text=True, timeout=timeout)
98
+ return {"compiled": True, "test_passed": test.returncode == 0 and "PASS" in test.stdout}
99
+
100
+ except Exception:
101
+ return {"compiled": False, "test_passed": False}
102
+ finally:
103
+ shutil.rmtree(tmpdir, ignore_errors=True)
104
+
105
+
106
+ def security_audit_reward(completions, log_extra=None, log_metric=None, **kwargs):
107
+ """Primary reward: FINDING block + PoC compilation + exploit verification."""
108
+ rewards = []
109
+ finding_count = compile_count = pass_count = 0
110
+
111
+ for completion in completions:
112
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
113
+ reward = -1.0
114
+
115
+ finding = extract_finding_block(text)
116
+ if finding:
117
+ finding_count += 1
118
+ reward = 0.0
119
+ poc = extract_solidity_poc(text)
120
+ if poc:
121
+ reward = 0.2
122
+ result = run_forge_test(poc)
123
+ if result.get("compiled") or result.get("syntax_valid", False):
124
+ compile_count += 1
125
+ reward = 0.5
126
+ if result.get("test_passed"):
127
+ pass_count += 1
128
+ reward = 1.0
129
+ elif any(kw in text.lower() for kw in ["vulnerability", "exploit", "bug", "finding"]):
130
+ reward = -0.5
131
+
132
+ rewards.append(reward)
133
+
134
+ if log_metric and rewards:
135
+ log_metric("finding_rate", finding_count / len(rewards))
136
+ log_metric("compile_rate", compile_count / len(rewards))
137
+ log_metric("exploit_rate", pass_count / len(rewards))
138
+
139
+ return rewards
140
+
141
+
142
+ def format_reward(completions, **kwargs):
143
+ """Secondary reward: structural format compliance."""
144
+ rewards = []
145
+ for completion in completions:
146
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
147
+ reward = 0.0
148
+ if re.search(r'FINDING\s*\|', text):
149
+ fields = sum(bool(re.search(p, text)) for p in [r'path:', r'proof:', r'description:', r'fix:'])
150
+ reward = 0.3 + (0.05 * fields)
151
+ if re.search(r'```solidity', text):
152
+ reward += 0.1
153
+ rewards.append(reward)
154
+ return rewards
155
+
156
+
157
+ # ─── Main ─────────────────────────────────────────────────────────────────────
158
+
159
+ def main():
160
+ logger.info("=" * 60)
161
+ logger.info("GRPO Training β€” Smart Contract Security Auditor")
162
+ logger.info(f"Model: {MODEL_NAME}")
163
+ logger.info(f"Dataset: {DATASET_ID}")
164
+ logger.info(f"Forge available: {FORGE_AVAILABLE}")
165
+ logger.info(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
166
+ logger.info("=" * 60)
167
+
168
+ # Load dataset
169
+ logger.info("Loading dataset from HF Hub...")
170
+ dataset = load_dataset(DATASET_ID, split="train")
171
+ logger.info(f"Dataset: {len(dataset)} samples, columns={dataset.column_names}")
172
+
173
+ # Configure GRPO
174
+ config = GRPOConfig(
175
+ output_dir=OUTPUT_DIR,
176
+ num_train_epochs=2,
177
+ per_device_train_batch_size=2,
178
+ num_generations=4,
179
+ max_completion_length=1536,
180
+ learning_rate=5e-7,
181
+ beta=0.0,
182
+ scale_rewards=True,
183
+ reward_weights=[0.7, 0.3],
184
+ gradient_checkpointing=True,
185
+ bf16=True,
186
+ logging_steps=5,
187
+ logging_first_step=True,
188
+ logging_strategy="steps",
189
+ disable_tqdm=True,
190
+ save_strategy="steps",
191
+ save_steps=50,
192
+ save_total_limit=2,
193
+ log_completions=True,
194
+ push_to_hub=True,
195
+ hub_model_id=HUB_MODEL_ID,
196
+ report_to="none",
197
+ seed=42,
198
+ )
199
+
200
+ # Train
201
+ logger.info("Initializing GRPOTrainer...")
202
+ trainer = GRPOTrainer(
203
+ model=MODEL_NAME,
204
+ args=config,
205
+ reward_funcs=[security_audit_reward, format_reward],
206
+ train_dataset=dataset,
207
+ )
208
+
209
+ logger.info("Starting training...")
210
+ trainer.train()
211
+
212
+ logger.info("Saving model...")
213
+ trainer.save_model(OUTPUT_DIR)
214
+ trainer.push_to_hub()
215
+
216
+ logger.info(f"βœ… Done! Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
217
+
218
+
219
+ if __name__ == "__main__":
220
+ main()