#!/usr/bin/env python3
"""
train_grpo_v2.py — GRPO training on 50K real audit findings.

V2 improvements over V1:
- 155x more data (50,902 vs 327)
- 4 reward functions with ground-truth severity/category matching
- Reference-based semantic similarity reward
- Better exploration via higher num_generations

NOTE(review): the last two bullets do not match the code in this file. Only
four keyword/regex based rewards are defined (no semantic-similarity reward),
and ``num_generations`` is set to 2 in ``main()``. Confirm the intended
settings and update either the bullets or the config.
"""

import logging
import os
import re
import shutil
from collections import Counter

# NOTE: torch / datasets / trl are only needed by main() and are imported
# there, so the reward functions below can be imported (and unit-tested)
# without the full training stack installed.

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# ─── Config ───────────────────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
DATASET_ID = "oxdev/smart-contract-security-audit-v2"
OUTPUT_DIR = "/tmp/grpo_v2_output"
HUB_MODEL_ID = "oxdev/security-auditor-grpo"
# Upper bound on the number of samples selected for training.
MAX_TRAIN_SAMPLES = 5000


# ─── Shared helpers ───────────────────────────────────────────────────────────
def _completion_text(completion):
    """Return the text of a single completion.

    A list is treated as a conversation: the ``"content"`` of its first
    message is used. An empty list yields ``""`` instead of raising
    ``IndexError``. Anything else is stringified.
    """
    if isinstance(completion, list):
        if not completion:
            return ""
        first = completion[0]
        if isinstance(first, dict):
            return str(first.get("content") or "")
        return str(first)
    return str(completion)


def _clamp(value, low=-1.0, high=1.0):
    """Clamp ``value`` into the closed interval ``[low, high]``."""
    return max(low, min(high, value))


def _normalize_label(value):
    """Lower-case and strip a ground-truth label; ``None`` becomes ``"unknown"``."""
    if value is None:
        return "unknown"
    return str(value).strip().lower()


def _broadcast(value, length):
    """Return ``value`` as a per-completion sequence.

    Lists/tuples are used as-is; a scalar is repeated ``length`` times.
    """
    if isinstance(value, (list, tuple)):
        return value
    return [value] * length


# ─── Reward Function 1: Structure & Format (weight: 0.25) ────────────────────
_FINDING_RE = re.compile(r'FINDING\s*\|')
_SOLIDITY_FENCE_RE = re.compile(r'```solidity')
_REQUIRED_FIELDS = ('contract:', 'function:', 'bug_class:', 'confidence:')
_SECTION_KEYWORDS = ('description', 'impact', 'proof', 'fix', 'recommendation', 'mitigation')
# One pattern per section keyword: a markdown heading ("## fix") or a label ("fix:").
_SECTION_RES = tuple(
    re.compile(rf'(?i)(###?\s*{kw}|{kw}:)') for kw in _SECTION_KEYWORDS
)


def format_reward(prompts, completions, completion_ids=None, **kwargs):
    """Reward for producing structured FINDING blocks and proper formatting.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Completions to score; plain strings or message lists.
        completion_ids: Unused.
        **kwargs: Ignored.

    Returns:
        One float per completion, clamped to ``[-1.0, 1.0]``.
    """
    rewards = []
    for completion in completions:
        text = _completion_text(completion)
        reward = 0.0

        # FINDING block present
        if _FINDING_RE.search(text):
            reward += 0.3

        # Required fields: 0.05 each, up to 0.2 more
        field_count = sum(1 for field in _REQUIRED_FIELDS if field in text)
        reward += 0.05 * field_count

        # Has code block
        if _SOLIDITY_FENCE_RE.search(text):
            reward += 0.15

        # Has structured sections: 0.05 each, capped at 3 sections (0.15)
        section_count = sum(1 for pattern in _SECTION_RES if pattern.search(text))
        reward += 0.05 * min(section_count, 3)

        # Penalize very short or very long
        if len(text) < 50:
            reward -= 0.3
        elif len(text) > 4000:
            reward -= 0.1

        rewards.append(_clamp(reward))
    return rewards


# ─── Reward Function 2: Severity Match (weight: 0.25) ────────────────────────
_SEVERITY_RANKS = {
    "critical": 5,
    "high": 4,
    "medium": 3,
    "low": 2,
    "informational": 1,
    "gas": 0,
}
_SEVERITY_ALTERNATION = "critical|high|medium|low|informational|gas"
# An explicit label such as "severity: high" or "**severity** | high".
_SEVERITY_LABELLED_RE = re.compile(
    rf'\bseverity\b[^a-z0-9\n]{{0,10}}({_SEVERITY_ALTERNATION})\b'
)
# Fallback: first stand-alone severity word. The word boundaries matter:
# without them "low" matches inside "overflow"/"allows" and "gas" inside
# longer words.
_SEVERITY_WORD_RE = re.compile(rf'\b({_SEVERITY_ALTERNATION})\b')


def _sev_rank(sev):
    """Return the numeric rank of a severity label, or -1 if unrecognised."""
    return _SEVERITY_RANKS.get(sev, -1)


def _extract_severity(text_lower):
    """Return the severity named in lower-cased ``text_lower``, or ``None``.

    An explicit ``severity: <level>`` label wins over a bare severity word
    appearing elsewhere in the text.
    """
    match = _SEVERITY_LABELLED_RE.search(text_lower) or _SEVERITY_WORD_RE.search(text_lower)
    return match.group(1) if match else None


def severity_reward(prompts, completions, completion_ids=None, severity=None, **kwargs):
    """Reward for correctly identifying the severity level.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Completions to score; plain strings or message lists.
        completion_ids: Unused.
        severity: Ground-truth severity, either one label for the whole batch
            or a sequence with one label per completion. ``None`` disables
            the reward.
        **kwargs: Ignored.

    Returns:
        One float per completion: 1.0 exact match, 0.3 off by one level,
        -0.5 further off, -0.3 when no severity is stated, and 0.0 when the
        ground truth is missing or not a recognised level.
    """
    if severity is None:
        return [0.0] * len(completions)

    sev_list = _broadcast(severity, len(completions))

    rewards = []
    for i, completion in enumerate(completions):
        gt_sev = _normalize_label(sev_list[i]) if i < len(sev_list) else "unknown"
        # Unrecognised ground truth ("unknown", None, other labels) cannot be
        # scored; stay neutral rather than ranking it as -1, which would make
        # a "gas" prediction look like an off-by-one hit.
        if gt_sev not in _SEVERITY_RANKS:
            rewards.append(0.0)
            continue

        pred_sev = _extract_severity(_completion_text(completion).lower())

        if pred_sev is None:
            rewards.append(-0.3)
        elif pred_sev == gt_sev:
            rewards.append(1.0)  # Exact match
        elif abs(_sev_rank(pred_sev) - _sev_rank(gt_sev)) == 1:
            rewards.append(0.3)  # Off by one level
        else:
            rewards.append(-0.5)  # Way off
    return rewards


# ─── Reward Function 3: Vulnerability Category (weight: 0.25) ────────────────
CATEGORY_KEYWORDS = {
    "reentrancy": ["reentrancy", "reentrant", "re-enter", "callback"],
    "access-control": ["access control", "unauthorized", "permission", "onlyowner", "role", "privilege"],
    "oracle": ["oracle", "price feed", "chainlink", "twap", "price manipulation"],
    "flash-loan": ["flash loan", "flashloan"],
    "overflow": ["overflow", "underflow", "arithmetic"],
    "front-running": ["front-run", "frontrun", "sandwich", "mev"],
    "dos": ["denial of service", "dos", "gas limit", "unbounded", "out of gas"],
    "token": ["erc20", "erc721", "token", "fee-on-transfer", "rebasing"],
    "storage": ["storage collision", "delegatecall", "proxy", "slot"],
    "cross-chain": ["bridge", "cross-chain", "relay", "message passing"],
    "liquidation": ["liquidation", "collateral", "health factor"],
    "signature": ["signature", "ecrecover", "replay", "nonce", "eip712"],
    "initialization": ["initialize", "constructor", "uninitialized"],
    "rounding": ["rounding", "precision", "truncation", "decimal"],
    "logic": ["logic error", "incorrect calculation", "business logic"],
}


def category_reward(prompts, completions, completion_ids=None, category=None, **kwargs):
    """Reward for identifying the correct vulnerability category.

    Matching is a case-insensitive substring search of the completion for the
    ground-truth category's keywords in ``CATEGORY_KEYWORDS``.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Completions to score; plain strings or message lists.
        completion_ids: Unused.
        category: Ground-truth category, either one label for the whole batch
            or a sequence with one label per completion. ``None`` disables
            the reward.
        **kwargs: Ignored.

    Returns:
        One float per completion: 1.0 for two or more keyword hits, 0.5 for
        one, -0.2 when only other categories' keywords appear, -0.5 when no
        category keyword appears at all, and 0.0 when the ground truth is
        "other", "unknown", missing, or has no keyword list.
    """
    if category is None:
        return [0.0] * len(completions)

    cat_list = _broadcast(category, len(completions))

    rewards = []
    for i, completion in enumerate(completions):
        gt_cat = _normalize_label(cat_list[i]) if i < len(cat_list) else "other"

        # "other"/"unknown"/unlisted categories have no keywords — can't
        # evaluate, so give a neutral reward.
        gt_keywords = CATEGORY_KEYWORDS.get(gt_cat)
        if not gt_keywords:
            rewards.append(0.0)
            continue

        text_lower = _completion_text(completion).lower()
        hits = sum(1 for kw in gt_keywords if kw in text_lower)
        if hits >= 2:
            rewards.append(1.0)
        elif hits == 1:
            rewards.append(0.5)
        else:
            # Check if it mentions ANY vulnerability category (at least trying)
            any_hit = any(
                kw in text_lower
                for keywords in CATEGORY_KEYWORDS.values()
                for kw in keywords
            )
            rewards.append(-0.2 if any_hit else -0.5)
    return rewards


# ─── Reward Function 4: Content Quality (weight: 0.25) ───────────────────────
# Matched case-sensitively against the raw completion text.
_TECHNICAL_TERMS = (
    'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct', 'transfer',
    'call.value', 'abi.encode', 'keccak256', 'require(', 'assert(',
    'revert', 'mapping', 'storage', 'memory', 'calldata',
    'modifier', 'interface', 'pragma', 'assembly', 'unchecked',
    'payable', 'receive()', 'fallback()',
)
# The three lists below are matched against the lower-cased completion text.
_REASONING_INDICATORS = (
    'because', 'therefore', 'this means', 'as a result',
    'the attacker can', 'this allows', 'leading to',
    'step 1', 'step 2', 'first,', 'then,', 'finally,',
)
_FIX_INDICATORS = ('fix:', 'recommendation:', 'mitigation:', 'should', 'consider', 'instead')
_GENERIC_PHRASES = ('i cannot', "i don't", 'no vulnerabilities found', 'the code looks safe')
_LINE_REF_RE = re.compile(r'line\s+\d+|L\d+|#L\d+')
_FUNCTION_REF_RE = re.compile(r'function\s+\w+\s*\(')


def quality_reward(prompts, completions, completion_ids=None, **kwargs):
    """Reward for overall response quality: technical depth, actionability.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Completions to score; plain strings or message lists.
        completion_ids: Unused.
        **kwargs: Ignored.

    Returns:
        One float per completion, clamped to ``[-1.0, 1.0]``.
    """
    rewards = []
    for completion in completions:
        text = _completion_text(completion)
        text_lower = text.lower()
        reward = 0.0

        # Technical indicators: 0.03 each, capped at 0.3
        tech_count = sum(1 for term in _TECHNICAL_TERMS if term in text)
        reward += min(0.3, 0.03 * tech_count)

        # Explanation quality (has reasoning): 0.06 each, capped at 0.3
        reasoning_count = sum(1 for phrase in _REASONING_INDICATORS if phrase in text_lower)
        reward += min(0.3, 0.06 * reasoning_count)

        # Actionable fix provided: 0.05 each, capped at 0.2
        fix_count = sum(1 for phrase in _FIX_INDICATORS if phrase in text_lower)
        reward += min(0.2, 0.05 * fix_count)

        # Code reference specificity
        if _LINE_REF_RE.search(text):
            reward += 0.1
        if _FUNCTION_REF_RE.search(text):
            reward += 0.1

        # Penalize generic/unhelpful responses
        if any(phrase in text_lower for phrase in _GENERIC_PHRASES):
            reward -= 0.5

        rewards.append(_clamp(reward))
    return rewards


# ─── Training subset selection ───────────────────────────────────────────────
def select_training_indices(severities, has_code, has_poc, cap=MAX_TRAIN_SAMPLES):
    """Pick up to ``cap`` row indices, highest-priority samples first.

    Priority tiers (each row lands in the first tier it qualifies for):
      1. severity "high"/"critical" with code,
      2. any row with a PoC,
      3. severity "medium" with code.

    Within a tier, rows keep their original order. The cap is enforced across
    all tiers, so the result never exceeds ``cap`` entries.

    Args:
        severities: Per-row severity labels.
        has_code: Per-row truthy flag, row contains code.
        has_poc: Per-row truthy flag, row has a PoC.
        cap: Maximum number of indices to return.

    Returns:
        ``(indices, tier_sizes)`` where ``indices`` is the selected row index
        list and ``tier_sizes`` is a 3-tuple with the number of candidate rows
        in each tier before the cap was applied.
    """
    tiers = ([], [], [])
    for idx, (sev, code, poc) in enumerate(zip(severities, has_code, has_poc)):
        if sev in ('high', 'critical') and code:
            tiers[0].append(idx)
        elif poc:
            tiers[1].append(idx)
        elif sev == 'medium' and code:
            tiers[2].append(idx)

    selected = []
    for tier in tiers:
        room = cap - len(selected)
        if room <= 0:
            break
        selected.extend(tier[:room])
    return selected, tuple(len(tier) for tier in tiers)


# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Load the dataset, select a training subset, run GRPO, save and push."""
    # Deferred imports: see the note next to the module-level imports.
    import torch
    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    logger.info("=" * 60)
    logger.info("GRPO V2 Training — 50K Real Audit Findings")
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Dataset: {DATASET_ID}")
    logger.info(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    logger.info("=" * 60)

    # Load dataset
    logger.info("Loading dataset...")
    dataset = load_dataset(DATASET_ID, split="train")
    logger.info(f"Dataset: {len(dataset)} samples, columns={dataset.column_names}")

    # For GRPO we only need 'prompt' column + metadata columns for reward
    # The reward functions access metadata via kwargs passed from the dataset

    # Log severity distribution
    severities = dataset['severity']
    sev_dist = Counter(severities)
    logger.info(f"Severity distribution: {dict(sev_dist)}")

    # Subsample — 5K highest-value samples for A10G (fits in ~6hrs)
    # Focus on HIGH+CRITICAL with code — most valuable training signal
    logger.info("Selecting high-quality training subset (5K for A10G)...")
    # Read the three columns once instead of decoding every row three times.
    indices, tier_sizes = select_training_indices(
        severities, dataset['has_code'], dataset['has_poc'], cap=MAX_TRAIN_SAMPLES
    )
    logger.info(f"  HIGH+CRITICAL with code: {tier_sizes[0]}")
    logger.info(f"  + Has PoC: {tier_sizes[0] + tier_sizes[1]}")
    logger.info(f"  Final subset: {len(indices)} samples")
    if not indices:
        raise RuntimeError("No training samples matched the selection criteria")

    train_dataset = dataset.select(indices)

    # Log final stats
    final_sev = Counter(train_dataset['severity'])
    final_src = Counter(train_dataset['source'])
    logger.info(f"Training severity: {dict(final_sev)}")
    logger.info(f"Training sources: {dict(final_src)}")

    # bf16 needs hardware support; fall back to fp16 on other CUDA devices
    # rather than failing at trainer construction.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = torch.cuda.is_available() and not use_bf16

    # GRPO Config — tuned for the 0.5B model on a single A10G
    # NOTE(review): an earlier comment here said "T4 (16GB VRAM)"; the subset
    # log above targets A10G — confirm the target GPU.
    config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,  # 1 epoch over the capped subset (<= MAX_TRAIN_SAMPLES)
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch = 8
        num_generations=2,
        max_completion_length=768,  # more room for detailed findings
        learning_rate=1e-6,  # slightly higher lr for more data
        beta=0.04,  # small KL penalty to prevent mode collapse with large dataset
        scale_rewards=True,
        reward_weights=[0.25, 0.25, 0.25, 0.25],  # equal weight across 4 rewards
        gradient_checkpointing=True,
        bf16=use_bf16,
        fp16=use_fp16,
        logging_steps=10,
        logging_first_step=True,
        logging_strategy="steps",
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        push_to_hub=False,
        log_completions=False,
        report_to="none",
        seed=42,
    )

    logger.info("Initializing GRPOTrainer with 4 reward functions...")
    trainer = GRPOTrainer(
        model=MODEL_NAME,
        args=config,
        reward_funcs=[format_reward, severity_reward, category_reward, quality_reward],
        train_dataset=train_dataset,
    )
    logger.info("GRPOTrainer initialized!")

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training complete!")

    # Save
    logger.info(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)

    # Push to Hub
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        logger.info(f"Pushing to hub: {HUB_MODEL_ID}")
        try:
            from huggingface_hub import HfApi

            api = HfApi(token=hf_token)
            try:
                api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)
            except Exception as e:
                # Best effort: the upload below still reports a hard failure.
                logger.warning(f"create_repo: {e}")
            api.upload_folder(
                folder_path=OUTPUT_DIR,
                repo_id=HUB_MODEL_ID,
                commit_message="GRPO v2 — trained on 50K real audit findings, 4 reward functions",
                # OUTPUT_DIR also holds intermediate checkpoint-* directories
                # written by save_strategy="steps"; publish only the final model.
                ignore_patterns=["checkpoint-*"],
            )
            logger.info(f"✅ Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
        except Exception as e:
            # The model is already saved locally, so a failed push is logged
            # (with traceback) rather than raised.
            logger.exception(f"Push failed: {e}")
    else:
        logger.warning("No HF_TOKEN — model saved locally only")

    logger.info("DONE")


if __name__ == "__main__":
    main()