Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Smart Contract Security Auditor — GRPO vs Base Comparison | |
| Side-by-side evaluation of oxdev/security-auditor-grpo vs Qwen2.5-Coder-0.5B-Instruct | |
| """ | |
| import re | |
| import time | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| # ── Constants ───────────────────────────────────────────────────────────────── | |
# Hub repository ids of the two checkpoints compared by this app.
GRPO_MODEL = "oxdev/security-auditor-grpo"
BASE_MODEL = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

# System prompt sent to both models. It is written one output line per list
# entry and joined with newlines; empty entries produce the blank separator
# lines of the expected report layout.
SYSTEM_PROMPT = "\n".join([
    "You are an expert smart contract security auditor. Analyze the provided Solidity code for vulnerabilities.",
    "",
    "For each finding, use this format:",
    "FINDING | severity | bug_class",
    "contract: <name>",
    "function: <name>",
    "bug_class: <class>",
    "confidence: high/medium/low",
    "",
    "### Description",
    "<detailed explanation>",
    "",
    "### Impact",
    "<what can go wrong>",
    "",
    "### Proof of Concept",
    "```solidity",
    "<exploit code>",
    "```",
    "",
    "### Recommendation",
    "<how to fix>",
])
| # ── Test Cases ──────────────────────────────────────────────────────────────── | |
| # Built-in benchmark fixtures, keyed by the label shown in the test-case dropdown. | |
| # Each entry is a dict with: | |
| #   "code"              - Solidity source that is sent to both models | |
| #   "expected_vuln"     - bug-class key that score_audit looks up in its keyword table ("none" marks a clean contract) | |
| #   "expected_severity" - severity that score_audit compares against ("none" marks a clean contract) | |
| #   "description"       - one-line summary shown in the score comparison | |
| TEST_CASES = { | |
| "🔁 Reentrancy (Classic)": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| contract VulnerableBank { | |
| mapping(address => uint256) public balances; | |
| function deposit() public payable { | |
| balances[msg.sender] += msg.value; | |
| } | |
| function withdraw(uint256 amount) public { | |
| require(balances[msg.sender] >= amount, "Insufficient balance"); | |
| (bool success, ) = msg.sender.call{value: amount}(""); | |
| require(success, "Transfer failed"); | |
| balances[msg.sender] -= amount; | |
| } | |
| function getBalance() public view returns (uint256) { | |
| return address(this).balance; | |
| } | |
| }""", | |
| "expected_vuln": "reentrancy", | |
| "expected_severity": "high", | |
| "description": "Classic reentrancy: external call before state update in withdraw()" | |
| }, | |
| "🔑 Access Control Missing": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| contract TokenVault { | |
| address public owner; | |
| mapping(address => uint256) public balances; | |
| bool public paused; | |
| constructor() { | |
| owner = msg.sender; | |
| } | |
| function deposit() external payable { | |
| balances[msg.sender] += msg.value; | |
| } | |
| function withdrawAll() external { | |
| // Missing: only owner should call this | |
| uint256 balance = address(this).balance; | |
| (bool success, ) = msg.sender.call{value: balance}(""); | |
| require(success); | |
| } | |
| function setPaused(bool _paused) external { | |
| // Missing: only owner should call this | |
| paused = _paused; | |
| } | |
| function emergencyWithdraw(address token, uint256 amount) external { | |
| // Missing: only owner should call this | |
| IERC20(token).transfer(msg.sender, amount); | |
| } | |
| } | |
| interface IERC20 { | |
| function transfer(address to, uint256 amount) external returns (bool); | |
| }""", | |
| "expected_vuln": "access-control", | |
| "expected_severity": "critical", | |
| "description": "Missing access control on withdrawAll(), setPaused(), emergencyWithdraw()" | |
| }, | |
| "📊 Oracle Manipulation": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| interface IUniswapV2Pair { | |
| function getReserves() external view returns (uint112, uint112, uint32); | |
| } | |
| contract VulnerableLending { | |
| IUniswapV2Pair public pair; | |
| mapping(address => uint256) public collateral; | |
| mapping(address => uint256) public debt; | |
| constructor(address _pair) { | |
| pair = IUniswapV2Pair(_pair); | |
| } | |
| function getPrice() public view returns (uint256) { | |
| (uint112 reserve0, uint112 reserve1, ) = pair.getReserves(); | |
| return (uint256(reserve1) * 1e18) / uint256(reserve0); | |
| } | |
| function deposit() external payable { | |
| collateral[msg.sender] += msg.value; | |
| } | |
| function borrow(uint256 amount) external { | |
| uint256 price = getPrice(); | |
| uint256 collateralValue = collateral[msg.sender] * price / 1e18; | |
| require(collateralValue >= amount * 15 / 10, "Undercollateralized"); | |
| debt[msg.sender] += amount; | |
| } | |
| function liquidate(address user) external { | |
| uint256 price = getPrice(); | |
| uint256 collateralValue = collateral[user] * price / 1e18; | |
| require(collateralValue < debt[user], "Not liquidatable"); | |
| collateral[user] = 0; | |
| debt[user] = 0; | |
| } | |
| }""", | |
| "expected_vuln": "oracle", | |
| "expected_severity": "high", | |
| "description": "Spot price from Uniswap reserves is manipulable via flash loans" | |
| }, | |
| "⚡ Flash Loan Attack Surface": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| interface IERC20 { | |
| function balanceOf(address) external view returns (uint256); | |
| function transfer(address, uint256) external returns (bool); | |
| function transferFrom(address, address, uint256) external returns (bool); | |
| } | |
| contract VulnerableGovernance { | |
| IERC20 public token; | |
| mapping(uint256 => Proposal) public proposals; | |
| uint256 public proposalCount; | |
| struct Proposal { | |
| address proposer; | |
| string description; | |
| uint256 forVotes; | |
| uint256 againstVotes; | |
| uint256 endBlock; | |
| bool executed; | |
| } | |
| function propose(string calldata desc) external returns (uint256) { | |
| require(token.balanceOf(msg.sender) >= 1000e18, "Need 1000 tokens"); | |
| proposalCount++; | |
| proposals[proposalCount] = Proposal(msg.sender, desc, 0, 0, block.number + 100, false); | |
| return proposalCount; | |
| } | |
| function vote(uint256 proposalId, bool support) external { | |
| Proposal storage p = proposals[proposalId]; | |
| require(block.number <= p.endBlock, "Voting ended"); | |
| // Bug: uses current balance, not snapshot — flash loan can inflate votes | |
| uint256 votes = token.balanceOf(msg.sender); | |
| if (support) p.forVotes += votes; | |
| else p.againstVotes += votes; | |
| } | |
| function execute(uint256 proposalId) external { | |
| Proposal storage p = proposals[proposalId]; | |
| require(block.number > p.endBlock && !p.executed); | |
| require(p.forVotes > p.againstVotes, "Not passed"); | |
| p.executed = true; | |
| // execute proposal... | |
| } | |
| }""", | |
| "expected_vuln": "flash-loan", | |
| "expected_severity": "critical", | |
| "description": "Governance voting uses live balance instead of snapshots — flash loan can swing votes" | |
| }, | |
| "🔄 Unchecked Return Value": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| interface IERC20 { | |
| function transfer(address to, uint256 amount) external returns (bool); | |
| function transferFrom(address from, address to, uint256 amount) external returns (bool); | |
| function approve(address spender, uint256 amount) external returns (bool); | |
| } | |
| contract TokenDistributor { | |
| IERC20 public token; | |
| address public admin; | |
| mapping(address => uint256) public allocations; | |
| constructor(address _token) { | |
| token = IERC20(_token); | |
| admin = msg.sender; | |
| } | |
| function setAllocation(address user, uint256 amount) external { | |
| require(msg.sender == admin); | |
| allocations[user] = amount; | |
| } | |
| function claim() external { | |
| uint256 amount = allocations[msg.sender]; | |
| require(amount > 0, "Nothing to claim"); | |
| allocations[msg.sender] = 0; | |
| // Bug: return value not checked — some tokens return false instead of reverting | |
| token.transfer(msg.sender, amount); | |
| } | |
| function batchTransfer(address[] calldata recipients, uint256[] calldata amounts) external { | |
| require(msg.sender == admin); | |
| for (uint256 i = 0; i < recipients.length; i++) { | |
| // Bug: unchecked return value | |
| token.transfer(recipients[i], amounts[i]); | |
| } | |
| } | |
| }""", | |
| "expected_vuln": "token", | |
| "expected_severity": "medium", | |
| "description": "Unchecked ERC20 transfer return values — silent failure with non-standard tokens" | |
| }, | |
| "🧮 Rounding / Precision Loss": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| contract VulnerableVault { | |
| uint256 public totalShares; | |
| uint256 public totalAssets; | |
| mapping(address => uint256) public shares; | |
| function deposit(uint256 assets) external { | |
| uint256 sharesToMint; | |
| if (totalShares == 0) { | |
| sharesToMint = assets; | |
| } else { | |
| // Bug: division before multiplication causes rounding loss | |
| sharesToMint = assets / totalAssets * totalShares; | |
| } | |
| shares[msg.sender] += sharesToMint; | |
| totalShares += sharesToMint; | |
| totalAssets += assets; | |
| } | |
| function withdraw(uint256 shareAmount) external { | |
| require(shares[msg.sender] >= shareAmount); | |
| // Bug: same precision issue | |
| uint256 assetsToReturn = shareAmount / totalShares * totalAssets; | |
| shares[msg.sender] -= shareAmount; | |
| totalShares -= shareAmount; | |
| totalAssets -= assetsToReturn; | |
| // transfer assets... | |
| } | |
| function previewDeposit(uint256 assets) external view returns (uint256) { | |
| if (totalShares == 0) return assets; | |
| return assets / totalAssets * totalShares; | |
| } | |
| }""", | |
| "expected_vuln": "rounding", | |
| "expected_severity": "high", | |
| "description": "Division before multiplication causes severe precision loss in share calculations" | |
| }, | |
| "✅ Clean Contract (No Bugs)": { | |
| "code": """// SPDX-License-Identifier: MIT | |
| pragma solidity ^0.8.0; | |
| import "@openzeppelin/contracts/security/ReentrancyGuard.sol"; | |
| import "@openzeppelin/contracts/access/Ownable.sol"; | |
| contract SecureBank is ReentrancyGuard, Ownable { | |
| mapping(address => uint256) private _balances; | |
| event Deposited(address indexed user, uint256 amount); | |
| event Withdrawn(address indexed user, uint256 amount); | |
| constructor() Ownable(msg.sender) {} | |
| function deposit() external payable { | |
| require(msg.value > 0, "Must deposit > 0"); | |
| _balances[msg.sender] += msg.value; | |
| emit Deposited(msg.sender, msg.value); | |
| } | |
| function withdraw(uint256 amount) external nonReentrant { | |
| require(_balances[msg.sender] >= amount, "Insufficient balance"); | |
| _balances[msg.sender] -= amount; | |
| (bool success, ) = msg.sender.call{value: amount}(""); | |
| require(success, "Transfer failed"); | |
| emit Withdrawn(msg.sender, amount); | |
| } | |
| function balanceOf(address user) external view returns (uint256) { | |
| return _balances[user]; | |
| } | |
| }""", | |
| "expected_vuln": "none", | |
| "expected_severity": "none", | |
| "description": "Well-written contract with ReentrancyGuard, Ownable, CEI pattern, events" | |
| } | |
| } | |
| # ── Scoring Engine ──────────────────────────────────────────────────────────── | |
# Severity labels recognised in model output, ranked so that adjacent labels differ by 1.
_SEVERITY_RANKS = {"critical": 5, "high": 4, "medium": 3, "low": 2, "informational": 1, "gas": 0}
_SEVERITY_ALTERNATION = r'(critical|high|medium|low|informational|gas)'

# Regex patterns (all lowercase) that count as evidence for each expected bug class.
# They are matched case-insensitively, so a mixed-case identifier such as
# `getReserves` in the response is still recognised.
_VULN_KEYWORDS = {
    "reentrancy": ["reentrancy", "reentrant", "re-enter", "external call before state"],
    "access-control": ["access control", "unauthorized", "missing modifier", "anyone can call", "no restriction"],
    "oracle": ["oracle", "price manipulation", "spot price", "flash loan.*price", "getreserves"],
    "flash-loan": ["flash loan", "snapshot", "live balance", "current balance", "voting power"],
    "token": ["return value", "unchecked", "non-standard", "fee-on-transfer", "bool return"],
    "rounding": ["rounding", "precision", "division before multiplication", "truncat"],
}

# Relative weight of each metric in the "Overall" score.
_SCORE_WEIGHTS = {"Structure": 0.2, "Detection": 0.35, "Severity": 0.15, "Depth": 0.2, "Code": 0.1}


def _structure_score(text, text_lower):
    """Score adherence to the FINDING report layout (header, fields, sections)."""
    score = 0.0
    if re.search(r'FINDING\s*\|', text):
        score += 0.4
    fields = ['contract:', 'function:', 'bug_class:', 'confidence:']
    score += 0.1 * sum(1 for f in fields if f in text_lower)
    section_kws = ['description', 'impact', 'proof', 'recommendation', 'mitigation', 'fix']
    sect_hits = sum(1 for k in section_kws if re.search(rf'(?i)(###?\s*{k}|{k}\s*:)', text))
    # At most three sections are rewarded.
    score += 0.1 * min(sect_hits, 3)
    return min(1.0, score)


def _detection_score(text_lower, expected_vuln):
    """Score whether the expected bug class was identified; None when there is no ground truth."""
    if expected_vuln == "none":
        # Clean contract: reward recognising it as safe, penalise alarmist wording.
        false_alarm_terms = ["critical", "high severity", "vulnerability found", "exploit"]
        has_false_alarm = any(t in text_lower for t in false_alarm_terms)
        safe_terms = ["no .* vulnerabilit", "well.written", "secure", "good practice", "no major"]
        recognizes_safe = any(re.search(t, text_lower) for t in safe_terms)
        if not recognizes_safe:
            return 0.0
        return 0.3 if has_false_alarm else 0.8
    patterns = _VULN_KEYWORDS.get(expected_vuln)
    if patterns is None:
        return None
    hits = sum(1 for p in patterns if re.search(p, text_lower, re.IGNORECASE))
    return min(1.0, hits * 0.35)


def _severity_score(text, expected_severity):
    """Score the reported severity against the expected one; None when there is no ground truth."""
    if expected_severity == "none":
        return 0.5  # N/A for clean contracts
    expected_rank = _SEVERITY_RANKS.get(expected_severity)
    if expected_rank is None:
        return None
    # Prefer the severity stated in the `FINDING | <severity> |` header; only
    # fall back to the first severity word anywhere in the text. Otherwise an
    # incidental word ("low-level call") ahead of the header would win.
    match = (
        re.search(rf'(?i)FINDING\s*\|\W*{_SEVERITY_ALTERNATION}\b', text)
        or re.search(rf'(?i)\b{_SEVERITY_ALTERNATION}\b', text)
    )
    if match is None:
        return 0.0
    diff = abs(_SEVERITY_RANKS[match.group(1).lower()] - expected_rank)
    return 1.0 if diff == 0 else 0.5 if diff == 1 else 0.1


def _depth_score(text, text_lower):
    """Score technical depth from Solidity terminology and reasoning phrases."""
    tech_terms = [
        'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct',
        'call{value', 'abi.encode', 'keccak256', 'require(',
        'mapping', 'storage', 'memory', 'modifier', 'interface',
        'assembly', 'unchecked', 'payable', 'fallback()', 'receive()',
    ]
    tech_count = sum(1 for t in tech_terms if t in text)
    reasoning_terms = ['because', 'therefore', 'this means', 'this allows',
                       'the attacker', 'leading to', 'step 1', 'first,']
    reason_count = sum(1 for r in reasoning_terms if r.lower() in text_lower)
    return min(1.0, 0.05 * tech_count + 0.1 * reason_count)


def score_audit(text, expected_vuln, expected_severity):
    """Score an audit response on several dimensions.

    Args:
        text: The model's audit response.
        expected_vuln: Expected bug-class key ("reentrancy", "oracle", ...),
            "none" for a clean contract, or any other value (e.g. "unknown")
            when there is no ground truth.
        expected_severity: Expected severity label, "none" for a clean
            contract, or any other value when there is no ground truth.

    Returns:
        dict with float scores in [0, 1] under the keys "Structure",
        "Detection", "Severity", "Depth", "Code" and "Overall" (in that order).
        When a metric has no ground truth it is reported as 0.0 and left out
        of "Overall", whose remaining weights are renormalised. A custom
        contract is then judged on structure, depth and code presence only,
        instead of being penalised for a comparison that cannot be made.
    """
    text = text or ""
    text_lower = text.lower()

    detection = _detection_score(text_lower, expected_vuln)
    severity = _severity_score(text, expected_severity)

    scores = {
        "Structure": _structure_score(text, text_lower),
        "Detection": 0.0 if detection is None else detection,
        "Severity": 0.0 if severity is None else severity,
        "Depth": _depth_score(text, text_lower),
        "Code": 1.0 if '```' in text else 0.0,
    }

    # Metrics without ground truth do not take part in the weighted average.
    unscored = {
        name for name, value in (("Detection", detection), ("Severity", severity))
        if value is None
    }
    overall = sum(scores[k] * w for k, w in _SCORE_WEIGHTS.items() if k not in unscored)
    if unscored:
        overall /= sum(w for k, w in _SCORE_WEIGHTS.items() if k not in unscored)
    scores["Overall"] = overall
    return scores
def format_scores(scores):
    """Render a metric -> score mapping as a two-column markdown table.

    Each row shows the metric's emoji (if one is registered), its name, a
    ten-cell bar filled in proportion to the score, and the score as a
    whole-number percentage. Rows follow the mapping's iteration order.
    """
    icon_for = {"Structure": "📋", "Detection": "🎯", "Severity": "⚠️", "Depth": "🔬", "Code": "💻", "Overall": "⭐"}

    def _row(metric, value):
        filled = int(value * 10)
        gauge = "█" * filled + "░" * (10 - filled)
        return f"| {icon_for.get(metric, '')} {metric} | {gauge} {value:.0%} |"

    table = ["| Metric | Score |", "|--------|-------|"]
    table.extend(_row(metric, value) for metric, value in scores.items())
    return "\n".join(table)
| # ── Model Loading ───────────────────────────────────────────────────────────── | |
def _load_cpu_pipeline(model_id, **model_kwargs):
    """Load a causal LM and its tokenizer and wrap them in a CPU text-generation pipeline.

    Extra keyword arguments are forwarded to `AutoModelForCausalLM.from_pretrained`.
    Returns the (model, tokenizer, pipeline) triple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device="cpu")
    return model, tokenizer, pipe


# Both checkpoints are loaded once at import time and reused by every request.
print("🔄 Loading GRPO model...")
grpo_model, grpo_tokenizer, grpo_pipe = _load_cpu_pipeline(GRPO_MODEL, use_cache=True)
print("✅ GRPO model loaded")

print("🔄 Loading base model...")
base_model, base_tokenizer, base_pipe = _load_cpu_pipeline(BASE_MODEL)
print("✅ Base model loaded")
| # ── Inference ───────────────────────────────────────────────────────────────── | |
def run_single_audit(pipe, code, max_tokens=512):
    """Ask one text-generation pipeline to audit `code` and return its reply as a string.

    The request is a two-message chat (shared system prompt plus the contract
    in a solidity code fence), decoded greedily with at most `max_tokens` new
    tokens.
    """
    user_prompt = f"Audit this smart contract for security vulnerabilities:\n\n```solidity\n{code}\n```"
    chat = [
        dict(role="system", content=SYSTEM_PROMPT),
        dict(role="user", content=user_prompt),
    ]
    generated = pipe(
        chat, max_new_tokens=max_tokens, do_sample=False, return_full_text=False
    )[0]["generated_text"]
    # The pipeline may hand back a chat transcript (list of messages) instead
    # of plain text; in that case the reply is the last message's content.
    if not isinstance(generated, list):
        return str(generated)
    return generated[-1]["content"]
def _timed_audit(pipe, code, max_tokens):
    """Run one audit and return (response_text, elapsed_seconds)."""
    started = time.perf_counter()
    text = run_single_audit(pipe, code, max_tokens)
    return text, time.perf_counter() - started


def run_comparison(code, test_case_name, max_tokens):
    """Audit `code` with both models, score the responses and build the report.

    Args:
        code: Solidity source to audit.
        test_case_name: Key into TEST_CASES, or "Custom (paste your own)".
        max_tokens: Generation budget per model (coerced to int).

    Returns:
        A 3-tuple of markdown strings `(grpo_markdown, base_markdown,
        comparison_markdown)`, matching the three outputs bound to the click
        handler. For blank input the first element carries a warning and the
        other two are empty.
    """
    if not code or not code.strip():
        # Must be exactly three values: the UI binds three output components,
        # and a different count makes the event handler fail.
        return "⚠️ Please enter Solidity code", "", ""
    max_tokens = int(max_tokens)

    # Expected values come from the selected test case; anything that is not a
    # known test case (including the custom entry) has no ground truth.
    tc = TEST_CASES.get(test_case_name, {})
    expected_vuln = tc.get("expected_vuln", "unknown")
    expected_severity = tc.get("expected_severity", "unknown")
    tc_desc = tc.get("description", "Custom contract — scoring against general audit quality")
    if test_case_name == "Custom (paste your own)":
        expected_vuln = "unknown"
        expected_severity = "unknown"

    # GRPO first, then base; each run is timed separately.
    grpo_result, grpo_time = _timed_audit(grpo_pipe, code, max_tokens)
    base_result, base_time = _timed_audit(base_pipe, code, max_tokens)

    grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
    base_scores = score_audit(base_result, expected_vuln, expected_severity)

    parts = [
        f"### 📊 Score Comparison\n\n**Test Case:** {test_case_name}\n",
        f"**Expected:** {expected_vuln} ({expected_severity})\n",
        f"**Description:** {tc_desc}\n\n",
        "| Metric | 🎯 GRPO | 📦 Base | Delta |\n",
        "|--------|---------|---------|-------|\n",
    ]
    for k in ["Structure", "Detection", "Severity", "Depth", "Code", "Overall"]:
        g = grpo_scores[k]
        b = base_scores[k]
        delta = g - b
        # Differences within ±5 points are shown as neutral.
        arrow = "🟢" if delta > 0.05 else "🔴" if delta < -0.05 else "⚪"
        parts.append(f"| {k} | {g:.0%} | {b:.0%} | {arrow} {delta:+.0%} |\n")
    parts.append(f"\n⏱️ GRPO: {grpo_time:.1f}s | Base: {base_time:.1f}s")
    comparison = "".join(parts)

    grpo_header = f"*Generated in {grpo_time:.1f}s — Overall: {grpo_scores['Overall']:.0%}*\n\n"
    base_header = f"*Generated in {base_time:.1f}s — Overall: {base_scores['Overall']:.0%}*\n\n"
    return grpo_header + grpo_result, base_header + base_result, comparison
def run_benchmark():
    """Audit every built-in test case with both models and summarise the overall scores.

    Returns a markdown report: the two average "Overall" scores followed by a
    table with one row per test case and the winner of each (a lead of 5
    points or less counts as a tie).
    """
    rows = []
    grpo_overalls = []
    base_overalls = []

    for label, case in TEST_CASES.items():
        source = case["code"]
        vuln = case["expected_vuln"]
        severity = case["expected_severity"]

        grpo_text = run_single_audit(grpo_pipe, source, 512)
        base_text = run_single_audit(base_pipe, source, 512)
        grpo_overall = score_audit(grpo_text, vuln, severity)["Overall"]
        base_overall = score_audit(base_text, vuln, severity)["Overall"]

        grpo_overalls.append(grpo_overall)
        base_overalls.append(base_overall)

        if grpo_overall > base_overall + 0.05:
            verdict = "🎯 GRPO"
        elif base_overall > grpo_overall + 0.05:
            verdict = "📦 Base"
        else:
            verdict = "🤝 Tie"
        rows.append(f"| {label} | {grpo_overall:.0%} | {base_overall:.0%} | {verdict} |")

    grpo_avg = sum(grpo_overalls) / len(grpo_overalls)
    base_avg = sum(base_overalls) / len(base_overalls)
    report_head = (
        "## 🏆 Full Benchmark Results\n\n"
        + f"**GRPO Average: {grpo_avg:.0%}** | **Base Average: {base_avg:.0%}**\n\n"
        + "| Test Case | GRPO | Base | Winner |\n|-----------|------|------|--------|\n"
    )
    return report_head + "\n".join(rows)
def load_test_case(name):
    """Return the Solidity source for the selected test case.

    The custom entry and any unrecognised name yield an empty string so the
    editor is cleared.
    """
    if name != "Custom (paste your own)":
        return TEST_CASES.get(name, {}).get("code", "")
    return ""
| # ── UI ──────────────────────────────────────────────────────────────────────── | |
| # Gradio UI: a single-audit tab, a full-benchmark tab and an About tab. | |
| with gr.Blocks( | |
| title="🔐 Smart Contract Security Auditor", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .score-box { padding: 10px; border-radius: 8px; } | |
| """ | |
| ) as demo: | |
| # NOTE(review): this intro says "327 real audit findings", while the About tab below | |
| # describes the same 327 samples as synthetic — confirm which is right and align the wording. | |
| gr.Markdown( | |
| "# 🔐 Smart Contract Security Auditor\n" | |
| "### GRPO-Trained vs Base Model — Side-by-Side Comparison\n\n" | |
| "Compare [`oxdev/security-auditor-grpo`](https://huggingface.co/oxdev/security-auditor-grpo) " | |
| "(GRPO-trained on 327 real audit findings) against " | |
| "[`Qwen/Qwen2.5-Coder-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct) (base).\n\n" | |
| "⏱️ **Note:** Running on CPU — each audit takes ~30-90 seconds per model." | |
| ) | |
| with gr.Tab("🔍 Single Audit"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # Choices are the built-in test-case labels plus one entry for user-supplied code. | |
| test_case_dropdown = gr.Dropdown( | |
| choices=list(TEST_CASES.keys()) + ["Custom (paste your own)"], | |
| value="🔁 Reentrancy (Classic)", | |
| label="Select Test Case", | |
| interactive=True, | |
| ) | |
| # Editor is pre-filled with the default test case selected above. | |
| code_input = gr.Code( | |
| label="Solidity Contract", | |
| language=None, | |
| lines=22, | |
| value=TEST_CASES["🔁 Reentrancy (Classic)"]["code"], | |
| interactive=True, | |
| ) | |
| with gr.Column(scale=1): | |
| max_tokens_slider = gr.Slider( | |
| minimum=128, maximum=1024, value=512, step=64, | |
| label="Max Output Tokens", | |
| ) | |
| run_btn = gr.Button("🔍 Run Audit Comparison", variant="primary", size="lg") | |
| # The percentages below mirror the weights used by score_audit. | |
| gr.Markdown( | |
| "**How Scoring Works:**\n" | |
| "- 📋 **Structure** (20%): FINDING format, sections, fields\n" | |
| "- 🎯 **Detection** (35%): Identifies the correct vulnerability\n" | |
| "- ⚠️ **Severity** (15%): Correct severity level\n" | |
| "- 🔬 **Depth** (20%): Technical terms, reasoning\n" | |
| "- 💻 **Code** (10%): Includes code examples" | |
| ) | |
| with gr.Row(): | |
| comparison_output = gr.Markdown(label="Score Comparison") | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### 🎯 GRPO-Trained Auditor") | |
| grpo_output = gr.Markdown(label="GRPO Output") | |
| with gr.Column(): | |
| gr.Markdown("### 📦 Base Qwen2.5-Coder-0.5B-Instruct") | |
| base_output = gr.Markdown(label="Base Output") | |
| # Selecting a test case replaces the editor contents with that case's source | |
| # (load_test_case returns "" for the custom entry). | |
| test_case_dropdown.change( | |
| fn=load_test_case, | |
| inputs=test_case_dropdown, | |
| outputs=code_input, | |
| ) | |
| # run_comparison's return values are mapped positionally onto these three outputs. | |
| run_btn.click( | |
| fn=run_comparison, | |
| inputs=[code_input, test_case_dropdown, max_tokens_slider], | |
| outputs=[grpo_output, base_output, comparison_output], | |
| concurrency_limit=1, | |
| ) | |
| with gr.Tab("🏆 Full Benchmark"): | |
| gr.Markdown( | |
| "Run all 7 test cases and compare aggregate performance.\n\n" | |
| "⏱️ **Warning:** This takes 5-10 minutes on CPU (14 model inferences total)." | |
| ) | |
| bench_btn = gr.Button("🏆 Run Full Benchmark", variant="primary", size="lg") | |
| bench_output = gr.Markdown(label="Benchmark Results") | |
| # run_benchmark takes no inputs and returns one markdown report. | |
| bench_btn.click( | |
| fn=run_benchmark, | |
| outputs=bench_output, | |
| concurrency_limit=1, | |
| ) | |
| with gr.Tab("ℹ️ About"): | |
| gr.Markdown(""" | |
| ## Model Details | |
| ### 🎯 GRPO-Trained Auditor (`oxdev/security-auditor-grpo`) | |
| - **Architecture:** Qwen2ForCausalLM, 0.5B parameters | |
| - **Training:** Group Relative Policy Optimization (GRPO) on 327 synthetic smart contract audit samples | |
| - **Reward Functions:** Format compliance, finding rate | |
| - **Training Results:** Format reward improved 16× (0.025 → 0.40), finding rate 0% → 50-75% | |
| ### 📦 Base Model (`Qwen/Qwen2.5-Coder-0.5B-Instruct`) | |
| - **Architecture:** Same Qwen2ForCausalLM, 0.5B parameters | |
| - **Training:** Standard instruction tuning by Qwen team | |
| - **Domain:** General code generation, not specialized for security | |
| ### 📊 Training Data | |
| - **V1 (used for current model):** 327 synthetic attack vector samples | |
| - **V2 (pending training):** [50,902 real audit findings](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2) from top security firms | |
| ### 🔬 Scoring Methodology | |
| Each audit response is scored on 5 dimensions: | |
| 1. **Structure (20%)** — Does it use the FINDING format with required fields? | |
| 2. **Detection (35%)** — Does it identify the correct vulnerability class? | |
| 3. **Severity (15%)** — Does it assign the correct severity level? | |
| 4. **Depth (20%)** — Technical terminology, reasoning chains, specificity | |
| 5. **Code (10%)** — Includes code examples (exploit PoC, fix) | |
| ### 🔗 Resources | |
| - [Model on Hub](https://huggingface.co/oxdev/security-auditor-grpo) | |
| - [Training Dataset V2](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2) | |
| - [GitHub Repository](https://github.com/0xedev/skills) | |
| """) | |
| # Queue is configured with max_size=5 and a default concurrency limit of 1. | |
| demo.queue(max_size=5, default_concurrency_limit=1) | |
| # Start the server only when this file is executed as a script. | |
| if __name__ == "__main__": | |
| demo.launch() | |