# security-auditor-grpo / train_grpo_v2.py
# oxdev's picture
# v2: 5K subset for A10G, fix escaping
# 3c818d7 verified
#!/usr/bin/env python3
"""
train_grpo_v2.py — GRPO training on 50K real audit findings.
V2 improvements over V1:
- 155x more data (50,902 vs 327)
- 4 reward functions with ground-truth severity/category matching
- Reference-based semantic similarity reward
- Better exploration via higher num_generations
"""
import logging
import os
import re
import shutil
from collections import Counter
import torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
# Configure root logging at import time: INFO level, "<timestamp> [<LEVEL>] <message>" lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger used by main() for progress reporting.
logger = logging.getLogger(__name__)
# ─── Config ───────────────────────────────────────────────────────────────────
# Model identifier handed to GRPOTrainer(model=...) in main().
MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
# Dataset identifier loaded with load_dataset(..., split="train") in main().
DATASET_ID = "oxdev/smart-contract-security-audit-v2"
# Local directory used as the trainer output_dir, the save_model target and the upload source.
OUTPUT_DIR = "/tmp/grpo_v2_output"
# Hub repository id used for create_repo / upload_folder when HF_TOKEN is set.
HUB_MODEL_ID = "oxdev/security-auditor-grpo"
# ─── Reward Function 1: Structure & Format (weight: 0.25) ────────────────────
def format_reward(prompts, completions, completion_ids=None, **kwargs):
    """Reward structured FINDING blocks and well-formatted audit output.

    Scoring per completion (before clamping):
      * +0.30 if a ``FINDING |`` header is present
      * +0.05 for each of the four required field labels found
      * +0.15 if a ```` ```solidity ```` code fence is present
      * +0.05 per recognised section heading/label, counted up to 3
      * -0.30 if the text is shorter than 50 characters, -0.10 if longer than 4000

    Args:
        prompts: Prompts for the batch (not used by this reward).
        completions: Generated completions. Each item is either a list whose
            first element is a dict with a ``"content"`` key, or any object
            that is converted with ``str()``.
        completion_ids: Not used by this reward.
        **kwargs: Additional keyword arguments; ignored.

    Returns:
        A list with one float per completion, clamped to ``[-1.0, 1.0]``.
    """
    # Loop-invariant lookups are built once per call rather than per completion.
    required_fields = ('contract:', 'function:', 'bug_class:', 'confidence:')
    section_patterns = [
        re.compile(rf'(?i)(###?\s*{kw}|{kw}:)')
        for kw in ('description', 'impact', 'proof', 'fix', 'recommendation', 'mitigation')
    ]

    rewards = []
    for completion in completions:
        if isinstance(completion, list):
            # An empty message list is treated as an empty completion instead of raising.
            text = completion[0]["content"] if completion else ""
        else:
            text = str(completion)
        # A missing (None) content is scored as empty text rather than crashing re.search.
        text = text or ""

        reward = 0.0

        # FINDING block present
        if re.search(r'FINDING\s*\|', text):
            reward += 0.3

        # Required fields: up to 0.2 more
        reward += 0.05 * sum(1 for field in required_fields if field in text)

        # Has code block
        if '```solidity' in text:
            reward += 0.15

        # Has structured sections: up to 0.15
        section_count = sum(1 for pattern in section_patterns if pattern.search(text))
        reward += 0.05 * min(section_count, 3)

        # Penalize very short or very long output
        if len(text) < 50:
            reward -= 0.3
        elif len(text) > 4000:
            reward -= 0.1

        rewards.append(max(-1.0, min(1.0, reward)))
    return rewards
# ─── Reward Function 2: Severity Match (weight: 0.25) ────────────────────────
def severity_reward(prompts, completions, completion_ids=None, severity=None, **kwargs):
"""Reward for correctly identifying the severity level."""
rewards = []
if severity is None:
return [0.0] * len(completions)
# Handle batch: severity may be a list
if isinstance(severity, list):
sev_list = severity
else:
sev_list = [severity] * len(completions)
for i, completion in enumerate(completions):
text = completion[0]["content"] if isinstance(completion, list) else str(completion)
text_lower = text.lower()
gt_sev = sev_list[i] if i < len(sev_list) else "unknown"
if gt_sev == "unknown":
rewards.append(0.0)
continue
# Extract predicted severity
pred_sev = None
sev_match = re.search(r'(?i)(critical|high|medium|low|informational|gas)', text_lower)
if sev_match:
pred_sev = sev_match.group(1).lower()
if pred_sev is None:
rewards.append(-0.3)
elif pred_sev == gt_sev:
rewards.append(1.0) # Exact match
elif abs(_sev_rank(pred_sev) - _sev_rank(gt_sev)) == 1:
rewards.append(0.3) # Off by one level
else:
rewards.append(-0.5) # Way off
return rewards
def _sev_rank(sev):
ranks = {"critical": 5, "high": 4, "medium": 3, "low": 2, "informational": 1, "gas": 0}
return ranks.get(sev, -1)
# ─── Reward Function 3: Vulnerability Category (weight: 0.25) ────────────────
# Keywords (lower-case substrings) that indicate each vulnerability category.
CATEGORY_KEYWORDS = {
    "reentrancy": ["reentrancy", "reentrant", "re-enter", "callback"],
    "access-control": ["access control", "unauthorized", "permission", "onlyowner", "role", "privilege"],
    "oracle": ["oracle", "price feed", "chainlink", "twap", "price manipulation"],
    "flash-loan": ["flash loan", "flashloan"],
    "overflow": ["overflow", "underflow", "arithmetic"],
    "front-running": ["front-run", "frontrun", "sandwich", "mev"],
    "dos": ["denial of service", "dos", "gas limit", "unbounded", "out of gas"],
    "token": ["erc20", "erc721", "token", "fee-on-transfer", "rebasing"],
    "storage": ["storage collision", "delegatecall", "proxy", "slot"],
    "cross-chain": ["bridge", "cross-chain", "relay", "message passing"],
    "liquidation": ["liquidation", "collateral", "health factor"],
    "signature": ["signature", "ecrecover", "replay", "nonce", "eip712"],
    "initialization": ["initialize", "constructor", "uninitialized"],
    "rounding": ["rounding", "precision", "truncation", "decimal"],
    "logic": ["logic error", "incorrect calculation", "business logic"],
}


def category_reward(prompts, completions, completion_ids=None, category=None, **kwargs):
    """Reward completions that mention keywords of the ground-truth category.

    Args:
        prompts: Prompts for the batch (not used by this reward).
        completions: Generated completions. Each item is either a list whose
            first element is a dict with a ``"content"`` key, or any object
            that is converted with ``str()``.
        completion_ids: Not used by this reward.
        category: Ground-truth category, either one value applied to every
            completion or a list/tuple aligned with ``completions``. Labels are
            matched against ``CATEGORY_KEYWORDS`` after lower-casing and
            replacing underscores/spaces with hyphens. ``None`` disables the
            reward.
        **kwargs: Additional keyword arguments; ignored.

    Returns:
        A list with one float per completion:
          *  1.0  two or more ground-truth keywords found
          *  0.5  exactly one ground-truth keyword found
          * -0.2  none found, but a keyword of some other category is present
          * -0.5  no category keyword at all
          *  0.0  ground truth is "other"/"unknown", missing, or has no keywords
    """
    if category is None:
        return [0.0] * len(completions)

    if isinstance(category, (list, tuple)):
        cat_list = category
    else:
        cat_list = [category] * len(completions)

    # Flattened once per call; used for the "at least trying" check below.
    all_keywords = [kw for kws in CATEGORY_KEYWORDS.values() for kw in kws]

    rewards = []
    for i, completion in enumerate(completions):
        if isinstance(completion, list):
            # An empty message list is treated as an empty completion instead of raising.
            text = completion[0]["content"] if completion else ""
        else:
            text = str(completion)
        text_lower = (text or "").lower()

        gt_cat = cat_list[i] if i < len(cat_list) else "other"
        if isinstance(gt_cat, str):
            # Tolerate label variants such as "Access_Control" or "flash loan".
            gt_cat = gt_cat.strip().lower().replace("_", "-").replace(" ", "-")
        else:
            gt_cat = "other"

        # "other"/"unknown" and labels without a keyword list can't be
        # evaluated, so they get a neutral reward.
        gt_keywords = CATEGORY_KEYWORDS.get(gt_cat, [])
        if gt_cat in ("other", "unknown") or not gt_keywords:
            rewards.append(0.0)
            continue

        # Check if the model mentions keywords from the ground truth category
        hits = sum(1 for kw in gt_keywords if kw in text_lower)
        if hits >= 2:
            rewards.append(1.0)
        elif hits == 1:
            rewards.append(0.5)
        else:
            # Check if it mentions ANY vulnerability category (at least trying)
            any_hit = any(kw in text_lower for kw in all_keywords)
            rewards.append(-0.2 if any_hit else -0.5)
    return rewards
# ─── Reward Function 4: Content Quality (weight: 0.25) ───────────────────────
def quality_reward(prompts, completions, completion_ids=None, **kwargs):
    """Reward technical depth, visible reasoning and actionable fixes.

    Scoring per completion (before clamping):
      * +0.03 per Solidity/EVM term found (case-sensitive), capped at 0.30
      * +0.06 per reasoning phrase found (case-insensitive), capped at 0.30
      * +0.05 per fix/recommendation cue found (case-insensitive), capped at 0.20
      * +0.10 if a line reference (``line 12``, ``L12``, ``#L12``) is present
      * +0.10 if a ``function name(`` reference is present
      * -0.50 if a generic/refusal phrase is present

    Args:
        prompts: Prompts for the batch (not used by this reward).
        completions: Generated completions. Each item is either a list whose
            first element is a dict with a ``"content"`` key, or any object
            that is converted with ``str()``.
        completion_ids: Not used by this reward.
        **kwargs: Additional keyword arguments; ignored.

    Returns:
        A list with one float per completion, clamped to ``[-1.0, 1.0]``.
    """
    # Constant vocabularies are defined once per call, not once per completion.
    technical_terms = (
        'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct',
        'transfer', 'call.value', 'abi.encode', 'keccak256',
        'require(', 'assert(', 'revert', 'mapping', 'storage',
        'memory', 'calldata', 'modifier', 'interface', 'pragma',
        'assembly', 'unchecked', 'payable', 'receive()', 'fallback()',
    )
    reasoning_indicators = (
        'because', 'therefore', 'this means', 'as a result',
        'the attacker can', 'this allows', 'leading to',
        'step 1', 'step 2', 'first,', 'then,', 'finally,',
    )
    fix_indicators = ('fix:', 'recommendation:', 'mitigation:', 'should', 'consider', 'instead')
    generic_phrases = ('i cannot', 'i don\'t', 'no vulnerabilities found', 'the code looks safe')

    rewards = []
    for completion in completions:
        if isinstance(completion, list):
            # An empty message list is treated as an empty completion instead of raising.
            text = completion[0]["content"] if completion else ""
        else:
            text = str(completion)
        text = text or ""
        # Lower-case once; every case-insensitive check below reuses it.
        text_lower = text.lower()

        reward = 0.0

        # Technical indicators (matched against the original casing)
        tech_count = sum(1 for term in technical_terms if term in text)
        reward += min(0.3, 0.03 * tech_count)

        # Explanation quality (has reasoning)
        reasoning_count = sum(1 for phrase in reasoning_indicators if phrase in text_lower)
        reward += min(0.3, 0.06 * reasoning_count)

        # Actionable fix provided
        fix_count = sum(1 for cue in fix_indicators if cue in text_lower)
        reward += min(0.2, 0.05 * fix_count)

        # Code reference specificity
        if re.search(r'line\s+\d+|L\d+|#L\d+', text):
            reward += 0.1
        if re.search(r'function\s+\w+\s*\(', text):
            reward += 0.1

        # Penalize generic/unhelpful responses
        if any(phrase in text_lower for phrase in generic_phrases):
            reward -= 0.5

        rewards.append(max(-1.0, min(1.0, reward)))
    return rewards
# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Run the GRPO v2 training job end to end.

    Steps:
      1. Log the model, dataset and GPU in use.
      2. Load the training split and pick a prioritised subset of at most
         5,000 rows (high/critical with code, then rows with a PoC, then
         medium with code).
      3. Train with GRPOTrainer using the four reward functions of this module.
      4. Save the model to OUTPUT_DIR and, when HF_TOKEN is set, upload that
         folder to HUB_MODEL_ID. A failed upload is logged, not raised.
    """
    # Hard cap on the training subset; applied across all priority tiers.
    max_train_samples = 5000

    logger.info("=" * 60)
    logger.info("GRPO V2 Training — 50K Real Audit Findings")
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Dataset: {DATASET_ID}")
    logger.info(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    logger.info("=" * 60)

    # Load dataset
    logger.info("Loading dataset...")
    dataset = load_dataset(DATASET_ID, split="train")
    logger.info(f"Dataset: {len(dataset)} samples, columns={dataset.column_names}")

    # Log severity distribution
    severities = dataset['severity']
    sev_dist = Counter(severities)
    logger.info(f"Severity distribution: {dict(sev_dist)}")

    # Subsample — 5K highest-value samples for A10G (fits in ~6hrs)
    # Focus on HIGH+CRITICAL with code — most valuable training signal
    logger.info("Selecting high-quality training subset (5K for A10G)...")

    # One pass over the three columns that drive selection, instead of three
    # row-by-row passes over the whole dataset. Each row lands in at most one
    # tier, and tiers are concatenated in priority order.
    high_critical_with_code = []  # Priority 1: HIGH+CRITICAL severity with code
    with_poc = []                 # Priority 2: any remaining row with a PoC reference
    medium_with_code = []         # Priority 3: remaining MEDIUM rows with code
    for i, (sev, has_code, has_poc) in enumerate(
        zip(severities, dataset['has_code'], dataset['has_poc'])
    ):
        if sev in ('high', 'critical') and has_code:
            high_critical_with_code.append(i)
        elif has_poc:
            with_poc.append(i)
        elif sev == 'medium' and has_code:
            medium_with_code.append(i)

    logger.info(f"  HIGH+CRITICAL with code: {len(high_critical_with_code)}")
    logger.info(f"  + Has PoC: {len(high_critical_with_code) + len(with_poc)}")

    # Enforce the cap over every tier, not only the last one, so that
    # higher-priority tiers cannot push the subset past the 5K budget.
    indices = (high_critical_with_code + with_poc + medium_with_code)[:max_train_samples]
    logger.info(f"  Final subset: {len(indices)} samples")
    train_dataset = dataset.select(indices)

    # Log final stats
    final_sev = Counter(train_dataset['severity'])
    final_src = Counter(train_dataset['source'])
    logger.info(f"Training severity: {dict(final_sev)}")
    logger.info(f"Training sources: {dict(final_src)}")

    # GRPO config for the 0.5B model on the capped subset.
    config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,                    # single pass over the capped subset
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,         # 2 per device x 4 accumulation steps
        num_generations=2,
        max_completion_length=768,             # room for detailed findings
        learning_rate=1e-6,
        beta=0.04,                             # KL penalty coefficient
        scale_rewards=True,
        reward_weights=[0.25, 0.25, 0.25, 0.25],  # equal weight across 4 rewards
        gradient_checkpointing=True,
        bf16=True,
        logging_steps=10,
        logging_first_step=True,
        logging_strategy="steps",
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        push_to_hub=False,                     # the upload is done explicitly below
        log_completions=False,
        report_to="none",
        seed=42,
    )

    logger.info("Initializing GRPOTrainer with 4 reward functions...")
    trainer = GRPOTrainer(
        model=MODEL_NAME,
        args=config,
        reward_funcs=[format_reward, severity_reward, category_reward, quality_reward],
        train_dataset=train_dataset,
    )
    logger.info("GRPOTrainer initialized!")

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training complete!")

    # Save
    logger.info(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)

    # Push to Hub
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        logger.info(f"Pushing to hub: {HUB_MODEL_ID}")
        try:
            from huggingface_hub import HfApi

            api = HfApi(token=hf_token)
            try:
                api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)
            except Exception as e:
                # Best effort: the upload below is still attempted.
                logger.warning(f"create_repo: {e}")
            api.upload_folder(
                folder_path=OUTPUT_DIR,
                repo_id=HUB_MODEL_ID,
                commit_message="GRPO v2 — trained on 50K real audit findings, 4 reward functions",
            )
            logger.info(f"✅ Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
        except Exception as e:
            # The model is already saved locally; record the traceback and
            # finish instead of failing the whole run.
            logger.exception(f"Push failed: {e}")
    else:
        logger.warning("No HF_TOKEN — model saved locally only")

    logger.info("DONE")


if __name__ == "__main__":
    main()