#!/usr/bin/env python3
"""
Smart Contract Security Auditor — GRPO vs Base Comparison

Side-by-side evaluation of oxdev/security-auditor-grpo vs
Qwen2.5-Coder-0.5B-Instruct.
"""

import re
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ── Constants ─────────────────────────────────────────────────────────────────

GRPO_MODEL = "oxdev/security-auditor-grpo"
BASE_MODEL = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

SYSTEM_PROMPT = (
    "You are an expert smart contract security auditor. Analyze the provided Solidity code "
    "for vulnerabilities.\n\nFor each finding, use this format:\n"
    "FINDING | severity | bug_class\n"
    "contract: \nfunction: \nbug_class: \nconfidence: high/medium/low\n\n"
    "### Description\n\n\n### Impact\n\n\n"
    "### Proof of Concept\n```solidity\n\n```\n\n"
    "### Recommendation\n"
)

# ── Test Cases ────────────────────────────────────────────────────────────────
# Each entry pairs a Solidity contract with the vulnerability class / severity
# the auditor is expected to flag, used by score_audit() as ground truth.

TEST_CASES = {
    "🔁 Reentrancy (Classic)": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract VulnerableBank {
    mapping(address => uint256) public balances;

    function deposit() public payable {
        balances[msg.sender] += msg.value;
    }

    function withdraw(uint256 amount) public {
        require(balances[msg.sender] >= amount, "Insufficient balance");
        (bool success, ) = msg.sender.call{value: amount}("");
        require(success, "Transfer failed");
        balances[msg.sender] -= amount;
    }

    function getBalance() public view returns (uint256) {
        return address(this).balance;
    }
}""",
        "expected_vuln": "reentrancy",
        "expected_severity": "high",
        "description": "Classic reentrancy: external call before state update in withdraw()",
    },
    "🔑 Access Control Missing": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract TokenVault {
    address public owner;
    mapping(address => uint256) public balances;
    bool public paused;

    constructor() {
        owner = msg.sender;
    }

    function deposit() external payable {
        balances[msg.sender] += msg.value;
    }

    function withdrawAll() external {
        // Missing: only owner should call this
        uint256 balance = address(this).balance;
        (bool success, ) = msg.sender.call{value: balance}("");
        require(success);
    }

    function setPaused(bool _paused) external {
        // Missing: only owner should call this
        paused = _paused;
    }

    function emergencyWithdraw(address token, uint256 amount) external {
        // Missing: only owner should call this
        IERC20(token).transfer(msg.sender, amount);
    }
}

interface IERC20 {
    function transfer(address to, uint256 amount) external returns (bool);
}""",
        "expected_vuln": "access-control",
        "expected_severity": "critical",
        "description": "Missing access control on withdrawAll(), setPaused(), emergencyWithdraw()",
    },
    "📊 Oracle Manipulation": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

interface IUniswapV2Pair {
    function getReserves() external view returns (uint112, uint112, uint32);
}

contract VulnerableLending {
    IUniswapV2Pair public pair;
    mapping(address => uint256) public collateral;
    mapping(address => uint256) public debt;

    constructor(address _pair) {
        pair = IUniswapV2Pair(_pair);
    }

    function getPrice() public view returns (uint256) {
        (uint112 reserve0, uint112 reserve1, ) = pair.getReserves();
        return (uint256(reserve1) * 1e18) / uint256(reserve0);
    }

    function deposit() external payable {
        collateral[msg.sender] += msg.value;
    }

    function borrow(uint256 amount) external {
        uint256 price = getPrice();
        uint256 collateralValue = collateral[msg.sender] * price / 1e18;
        require(collateralValue >= amount * 15 / 10, "Undercollateralized");
        debt[msg.sender] += amount;
    }

    function liquidate(address user) external {
        uint256 price = getPrice();
        uint256 collateralValue = collateral[user] * price / 1e18;
        require(collateralValue < debt[user], "Not liquidatable");
        collateral[user] = 0;
        debt[user] = 0;
    }
}""",
        "expected_vuln": "oracle",
        "expected_severity": "high",
        "description": "Spot price from Uniswap reserves is manipulable via flash loans",
    },
    "⚡ Flash Loan Attack Surface": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

interface IERC20 {
    function balanceOf(address) external view returns (uint256);
    function transfer(address, uint256) external returns (bool);
    function transferFrom(address, address, uint256) external returns (bool);
}

contract VulnerableGovernance {
    IERC20 public token;
    mapping(uint256 => Proposal) public proposals;
    uint256 public proposalCount;

    struct Proposal {
        address proposer;
        string description;
        uint256 forVotes;
        uint256 againstVotes;
        uint256 endBlock;
        bool executed;
    }

    function propose(string calldata desc) external returns (uint256) {
        require(token.balanceOf(msg.sender) >= 1000e18, "Need 1000 tokens");
        proposalCount++;
        proposals[proposalCount] = Proposal(msg.sender, desc, 0, 0, block.number + 100, false);
        return proposalCount;
    }

    function vote(uint256 proposalId, bool support) external {
        Proposal storage p = proposals[proposalId];
        require(block.number <= p.endBlock, "Voting ended");
        // Bug: uses current balance, not snapshot — flash loan can inflate votes
        uint256 votes = token.balanceOf(msg.sender);
        if (support) p.forVotes += votes;
        else p.againstVotes += votes;
    }

    function execute(uint256 proposalId) external {
        Proposal storage p = proposals[proposalId];
        require(block.number > p.endBlock && !p.executed);
        require(p.forVotes > p.againstVotes, "Not passed");
        p.executed = true;
        // execute proposal...
    }
}""",
        "expected_vuln": "flash-loan",
        "expected_severity": "critical",
        "description": "Governance voting uses live balance instead of snapshots — flash loan can swing votes",
    },
    "🔄 Unchecked Return Value": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

interface IERC20 {
    function transfer(address to, uint256 amount) external returns (bool);
    function transferFrom(address from, address to, uint256 amount) external returns (bool);
    function approve(address spender, uint256 amount) external returns (bool);
}

contract TokenDistributor {
    IERC20 public token;
    address public admin;
    mapping(address => uint256) public allocations;

    constructor(address _token) {
        token = IERC20(_token);
        admin = msg.sender;
    }

    function setAllocation(address user, uint256 amount) external {
        require(msg.sender == admin);
        allocations[user] = amount;
    }

    function claim() external {
        uint256 amount = allocations[msg.sender];
        require(amount > 0, "Nothing to claim");
        allocations[msg.sender] = 0;
        // Bug: return value not checked — some tokens return false instead of reverting
        token.transfer(msg.sender, amount);
    }

    function batchTransfer(address[] calldata recipients, uint256[] calldata amounts) external {
        require(msg.sender == admin);
        for (uint256 i = 0; i < recipients.length; i++) {
            // Bug: unchecked return value
            token.transfer(recipients[i], amounts[i]);
        }
    }
}""",
        "expected_vuln": "token",
        "expected_severity": "medium",
        "description": "Unchecked ERC20 transfer return values — silent failure with non-standard tokens",
    },
    "🧮 Rounding / Precision Loss": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract VulnerableVault {
    uint256 public totalShares;
    uint256 public totalAssets;
    mapping(address => uint256) public shares;

    function deposit(uint256 assets) external {
        uint256 sharesToMint;
        if (totalShares == 0) {
            sharesToMint = assets;
        } else {
            // Bug: division before multiplication causes rounding loss
            sharesToMint = assets / totalAssets * totalShares;
        }
        shares[msg.sender] += sharesToMint;
        totalShares += sharesToMint;
        totalAssets += assets;
    }

    function withdraw(uint256 shareAmount) external {
        require(shares[msg.sender] >= shareAmount);
        // Bug: same precision issue
        uint256 assetsToReturn = shareAmount / totalShares * totalAssets;
        shares[msg.sender] -= shareAmount;
        totalShares -= shareAmount;
        totalAssets -= assetsToReturn;
        // transfer assets...
    }

    function previewDeposit(uint256 assets) external view returns (uint256) {
        if (totalShares == 0) return assets;
        return assets / totalAssets * totalShares;
    }
}""",
        "expected_vuln": "rounding",
        "expected_severity": "high",
        "description": "Division before multiplication causes severe precision loss in share calculations",
    },
    "✅ Clean Contract (No Bugs)": {
        "code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import "@openzeppelin/contracts/security/ReentrancyGuard.sol";
import "@openzeppelin/contracts/access/Ownable.sol";

contract SecureBank is ReentrancyGuard, Ownable {
    mapping(address => uint256) private _balances;

    event Deposited(address indexed user, uint256 amount);
    event Withdrawn(address indexed user, uint256 amount);

    constructor() Ownable(msg.sender) {}

    function deposit() external payable {
        require(msg.value > 0, "Must deposit > 0");
        _balances[msg.sender] += msg.value;
        emit Deposited(msg.sender, msg.value);
    }

    function withdraw(uint256 amount) external nonReentrant {
        require(_balances[msg.sender] >= amount, "Insufficient balance");
        _balances[msg.sender] -= amount;
        (bool success, ) = msg.sender.call{value: amount}("");
        require(success, "Transfer failed");
        emit Withdrawn(msg.sender, amount);
    }

    function balanceOf(address user) external view returns (uint256) {
        return _balances[user];
    }
}""",
        "expected_vuln": "none",
        "expected_severity": "none",
        "description": "Well-written contract with ReentrancyGuard, Ownable, CEI pattern, events",
    },
}

# ── Scoring Engine ────────────────────────────────────────────────────────────


def score_audit(text, expected_vuln, expected_severity):
    """Score an audit response on multiple dimensions.

    Args:
        text: Raw model output to grade.
        expected_vuln: Ground-truth bug class key ("reentrancy", ...,
            "none" for clean contracts, "unknown" for custom code).
        expected_severity: Ground-truth severity ("critical"..."gas",
            "none", or "unknown").

    Returns:
        Dict of per-dimension scores in [0, 1] plus a weighted "Overall".
    """
    text_lower = text.lower()
    scores = {}

    # 1. Structure score (0-1): Does it follow the FINDING format?
    struct_score = 0.0
    if re.search(r'FINDING\s*\|', text):
        struct_score += 0.4
    fields = ['contract:', 'function:', 'bug_class:', 'confidence:']
    field_hits = sum(1 for f in fields if f in text_lower)
    struct_score += 0.1 * field_hits
    section_kws = ['description', 'impact', 'proof', 'recommendation', 'mitigation', 'fix']
    sect_hits = sum(1 for k in section_kws if re.search(rf'(?i)(###?\s*{k}|{k}\s*:)', text))
    struct_score += 0.1 * min(sect_hits, 3)
    scores["Structure"] = min(1.0, struct_score)

    # 2. Vulnerability detection (0-1): Did it find the right bug?
    # NOTE: patterns are matched against text_lower, so they must be
    # lowercase themselves ("getreserves", not "getReserves").
    vuln_keywords = {
        "reentrancy": ["reentrancy", "reentrant", "re-enter", "external call before state"],
        "access-control": ["access control", "unauthorized", "missing modifier",
                           "anyone can call", "no restriction"],
        "oracle": ["oracle", "price manipulation", "spot price", "flash loan.*price",
                   "getreserves"],  # fixed: was mixed-case and could never match
        "flash-loan": ["flash loan", "snapshot", "live balance", "current balance",
                       "voting power"],
        "token": ["return value", "unchecked", "non-standard", "fee-on-transfer",
                  "bool return"],
        "rounding": ["rounding", "precision", "division before multiplication", "truncat"],
        "none": [],
    }
    if expected_vuln == "none":
        # For clean contracts, reward saying "no major issues" / penalize false positives
        false_alarm_terms = ["critical", "high severity", "vulnerability found", "exploit"]
        has_false_alarm = any(t in text_lower for t in false_alarm_terms)
        safe_terms = ["no .* vulnerabilit", "well.written", "secure", "good practice", "no major"]
        recognizes_safe = any(re.search(t, text_lower) for t in safe_terms)
        scores["Detection"] = (
            0.8 if recognizes_safe and not has_false_alarm
            else 0.3 if recognizes_safe
            else 0.0
        )
    else:
        kws = vuln_keywords.get(expected_vuln, [])
        hits = sum(1 for kw in kws if re.search(kw, text_lower))
        scores["Detection"] = min(1.0, hits * 0.35)

    # 3. Severity accuracy (0-1)
    ranks = {"critical": 5, "high": 4, "medium": 3, "low": 2, "informational": 1, "gas": 0}
    if expected_severity not in ranks:
        # "none" (clean contract) or "unknown" (custom code): severity N/A.
        scores["Severity"] = 0.5
    else:
        sev_match = re.search(r'(?i)\b(critical|high|medium|low|informational|gas)\b', text_lower)
        if sev_match:
            pred = sev_match.group(1).lower()
            diff = abs(ranks.get(pred, 0) - ranks.get(expected_severity, 0))
            scores["Severity"] = 1.0 if diff == 0 else 0.5 if diff == 1 else 0.1
        else:
            scores["Severity"] = 0.0

    # 4. Technical depth (0-1): Solidity-specific terms + reasoning connectives.
    tech_terms = [
        'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct', 'call{value',
        'abi.encode', 'keccak256', 'require(', 'mapping', 'storage', 'memory',
        'modifier', 'interface', 'assembly', 'unchecked', 'payable',
        'fallback()', 'receive()',
    ]
    tech_count = sum(1 for t in tech_terms if t in text)
    reasoning_terms = ['because', 'therefore', 'this means', 'this allows',
                       'the attacker', 'leading to', 'step 1', 'first,']
    reason_count = sum(1 for r in reasoning_terms if r.lower() in text_lower)
    scores["Depth"] = min(1.0, 0.05 * tech_count + 0.1 * reason_count)

    # 5. Code presence (0-1): any fenced code block counts.
    scores["Code"] = 1.0 if '```' in text else 0.0

    # Overall weighted score
    weights = {"Structure": 0.2, "Detection": 0.35, "Severity": 0.15, "Depth": 0.2, "Code": 0.1}
    scores["Overall"] = sum(scores[k] * weights[k] for k in weights)
    return scores


def format_scores(scores):
    """Format scores as a readable markdown table."""
    lines = ["| Metric | Score |", "|--------|-------|"]
    emojis = {"Structure": "📋", "Detection": "🎯", "Severity": "⚠️",
              "Depth": "🔬", "Code": "💻", "Overall": "⭐"}
    for k, v in scores.items():
        emoji = emojis.get(k, "")
        bar = "█" * int(v * 10) + "░" * (10 - int(v * 10))
        lines.append(f"| {emoji} {k} | {bar} {v:.0%} |")
    return "\n".join(lines)


# ── Model Loading ─────────────────────────────────────────────────────────────
# Both models are loaded eagerly at import time so the first request is fast;
# CPU + float32 keeps this runnable on Spaces hardware without a GPU.

print("🔄 Loading GRPO model...")
grpo_model = AutoModelForCausalLM.from_pretrained(
    GRPO_MODEL,
    use_cache=True,
    torch_dtype=torch.float32,
)
grpo_tokenizer = AutoTokenizer.from_pretrained(GRPO_MODEL)
grpo_pipe = pipeline("text-generation", model=grpo_model, tokenizer=grpo_tokenizer, device="cpu")
print("✅ GRPO model loaded")

print("🔄 Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
)
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_pipe = pipeline("text-generation", model=base_model, tokenizer=base_tokenizer, device="cpu")
print("✅ Base model loaded")

# ── Inference ─────────────────────────────────────────────────────────────────


def run_single_audit(pipe, code, max_tokens=512):
    """Run audit with one model; returns the generated text only."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",
         "content": f"Audit this smart contract for security vulnerabilities:\n\n```solidity\n{code}\n```"},
    ]
    result = pipe(messages, max_new_tokens=max_tokens, do_sample=False, return_full_text=False)
    output = result[0]["generated_text"]
    # Chat pipelines may return the full message list; take the last turn.
    if isinstance(output, list):
        return output[-1]["content"]
    return str(output)


def run_comparison(code, test_case_name, max_tokens):
    """Run both models on `code` and score results.

    Returns (grpo_markdown, base_markdown, comparison_markdown) — exactly
    three values to match the three Gradio outputs wired to this handler.
    """
    if not code or not code.strip():
        # Fixed: previously returned five values against three outputs,
        # which would error inside Gradio on the empty-input path.
        return "⚠️ Please enter Solidity code", "", ""

    max_tokens = int(max_tokens)

    # Get expected values from test case
    tc = TEST_CASES.get(test_case_name, {})
    expected_vuln = tc.get("expected_vuln", "unknown")
    expected_severity = tc.get("expected_severity", "unknown")
    tc_desc = tc.get("description", "Custom contract — scoring against general audit quality")

    # If custom code, use "unknown" — score on structure/depth only
    if test_case_name == "Custom (paste your own)":
        expected_vuln = "unknown"
        expected_severity = "unknown"

    # Run GRPO model
    t0 = time.time()
    grpo_result = run_single_audit(grpo_pipe, code, max_tokens)
    grpo_time = time.time() - t0

    # Run base model
    t0 = time.time()
    base_result = run_single_audit(base_pipe, code, max_tokens)
    base_time = time.time() - t0

    # Score both
    grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
    base_scores = score_audit(base_result, expected_vuln, expected_severity)

    # Format score comparison
    comparison = f"### 📊 Score Comparison\n\n**Test Case:** {test_case_name}\n"
    comparison += f"**Expected:** {expected_vuln} ({expected_severity})\n"
    comparison += f"**Description:** {tc_desc}\n\n"
    comparison += "| Metric | 🎯 GRPO | 📦 Base | Delta |\n"
    comparison += "|--------|---------|---------|-------|\n"
    for k in ["Structure", "Detection", "Severity", "Depth", "Code", "Overall"]:
        g = grpo_scores[k]
        b = base_scores[k]
        delta = g - b
        arrow = "🟢" if delta > 0.05 else "🔴" if delta < -0.05 else "⚪"
        comparison += f"| {k} | {g:.0%} | {b:.0%} | {arrow} {delta:+.0%} |\n"
    comparison += f"\n⏱️ GRPO: {grpo_time:.1f}s | Base: {base_time:.1f}s"

    grpo_header = f"*Generated in {grpo_time:.1f}s — Overall: {grpo_scores['Overall']:.0%}*\n\n"
    base_header = f"*Generated in {base_time:.1f}s — Overall: {base_scores['Overall']:.0%}*\n\n"
    return grpo_header + grpo_result, base_header + base_result, comparison


def run_benchmark():
    """Run all test cases and return aggregate scores as markdown."""
    results = []
    grpo_total = 0
    base_total = 0
    n = 0
    for name, tc in TEST_CASES.items():
        code = tc["code"]
        expected_vuln = tc["expected_vuln"]
        expected_severity = tc["expected_severity"]

        grpo_result = run_single_audit(grpo_pipe, code, 512)
        base_result = run_single_audit(base_pipe, code, 512)

        grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
        base_scores = score_audit(base_result, expected_vuln, expected_severity)

        grpo_total += grpo_scores["Overall"]
        base_total += base_scores["Overall"]
        n += 1

        g_ov = grpo_scores["Overall"]
        b_ov = base_scores["Overall"]
        winner = "🎯 GRPO" if g_ov > b_ov + 0.05 else "📦 Base" if b_ov > g_ov + 0.05 else "🤝 Tie"
        results.append(f"| {name} | {g_ov:.0%} | {b_ov:.0%} | {winner} |")

    header = "## 🏆 Full Benchmark Results\n\n"
    header += f"**GRPO Average: {grpo_total/n:.0%}** | **Base Average: {base_total/n:.0%}**\n\n"
    header += "| Test Case | GRPO | Base | Winner |\n|-----------|------|------|--------|\n"
    return header + "\n".join(results)


def load_test_case(name):
    """Load a test case into the code editor (empty string for custom)."""
    if name == "Custom (paste your own)":
        return ""
    tc = TEST_CASES.get(name, {})
    return tc.get("code", "")


# ── UI ────────────────────────────────────────────────────────────────────────

with gr.Blocks(
    title="🔐 Smart Contract Security Auditor",
    theme=gr.themes.Soft(),
    css="""
    .score-box { padding: 10px; border-radius: 8px; }
    """,
) as demo:
    gr.Markdown(
        "# 🔐 Smart Contract Security Auditor\n"
        "### GRPO-Trained vs Base Model — Side-by-Side Comparison\n\n"
        "Compare [`oxdev/security-auditor-grpo`](https://huggingface.co/oxdev/security-auditor-grpo) "
        "(GRPO-trained on 327 real audit findings) against "
        "[`Qwen/Qwen2.5-Coder-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct) (base).\n\n"
        "⏱️ **Note:** Running on CPU — each audit takes ~30-90 seconds per model."
    )

    with gr.Tab("🔍 Single Audit"):
        with gr.Row():
            with gr.Column(scale=2):
                test_case_dropdown = gr.Dropdown(
                    choices=list(TEST_CASES.keys()) + ["Custom (paste your own)"],
                    value="🔁 Reentrancy (Classic)",
                    label="Select Test Case",
                    interactive=True,
                )
                code_input = gr.Code(
                    label="Solidity Contract",
                    language=None,  # gr.Code has no Solidity highlighter
                    lines=22,
                    value=TEST_CASES["🔁 Reentrancy (Classic)"]["code"],
                    interactive=True,
                )
            with gr.Column(scale=1):
                max_tokens_slider = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Output Tokens",
                )
                run_btn = gr.Button("🔍 Run Audit Comparison", variant="primary", size="lg")
                gr.Markdown(
                    "**How Scoring Works:**\n"
                    "- 📋 **Structure** (20%): FINDING format, sections, fields\n"
                    "- 🎯 **Detection** (35%): Identifies the correct vulnerability\n"
                    "- ⚠️ **Severity** (15%): Correct severity level\n"
                    "- 🔬 **Depth** (20%): Technical terms, reasoning\n"
                    "- 💻 **Code** (10%): Includes code examples"
                )

        with gr.Row():
            comparison_output = gr.Markdown(label="Score Comparison")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 GRPO-Trained Auditor")
                grpo_output = gr.Markdown(label="GRPO Output")
            with gr.Column():
                gr.Markdown("### 📦 Base Qwen2.5-Coder-0.5B-Instruct")
                base_output = gr.Markdown(label="Base Output")

        test_case_dropdown.change(
            fn=load_test_case,
            inputs=test_case_dropdown,
            outputs=code_input,
        )
        run_btn.click(
            fn=run_comparison,
            inputs=[code_input, test_case_dropdown, max_tokens_slider],
            outputs=[grpo_output, base_output, comparison_output],
            concurrency_limit=1,
        )

    with gr.Tab("🏆 Full Benchmark"):
        gr.Markdown(
            "Run all 7 test cases and compare aggregate performance.\n\n"
            "⏱️ **Warning:** This takes 5-10 minutes on CPU (14 model inferences total)."
        )
        bench_btn = gr.Button("🏆 Run Full Benchmark", variant="primary", size="lg")
        bench_output = gr.Markdown(label="Benchmark Results")
        bench_btn.click(
            fn=run_benchmark,
            outputs=bench_output,
            concurrency_limit=1,
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## Model Details

### 🎯 GRPO-Trained Auditor (`oxdev/security-auditor-grpo`)
- **Architecture:** Qwen2ForCausalLM, 0.5B parameters
- **Training:** Group Relative Policy Optimization (GRPO) on 327 synthetic smart contract audit samples
- **Reward Functions:** Format compliance, finding rate
- **Training Results:** Format reward improved 16× (0.025 → 0.40), finding rate 0% → 50-75%

### 📦 Base Model (`Qwen/Qwen2.5-Coder-0.5B-Instruct`)
- **Architecture:** Same Qwen2ForCausalLM, 0.5B parameters
- **Training:** Standard instruction tuning by Qwen team
- **Domain:** General code generation, not specialized for security

### 📊 Training Data
- **V1 (used for current model):** 327 synthetic attack vector samples
- **V2 (pending training):** [50,902 real audit findings](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2) from top security firms

### 🔬 Scoring Methodology
Each audit response is scored on 5 dimensions:
1. **Structure (20%)** — Does it use the FINDING format with required fields?
2. **Detection (35%)** — Does it identify the correct vulnerability class?
3. **Severity (15%)** — Does it assign the correct severity level?
4. **Depth (20%)** — Technical terminology, reasoning chains, specificity
5. **Code (10%)** — Includes code examples (exploit PoC, fix)

### 🔗 Resources
- [Model on Hub](https://huggingface.co/oxdev/security-auditor-grpo)
- [Training Dataset V2](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2)
- [GitHub Repository](https://github.com/0xedev/skills)
""")

demo.queue(max_size=5, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()