# oxdev's picture
# Add main app with side-by-side comparison, test cases, and auto-scoring
# 1c81810 verified
#!/usr/bin/env python3
"""
Smart Contract Security Auditor — GRPO vs Base Comparison
Side-by-side evaluation of oxdev/security-auditor-grpo vs Qwen2.5-Coder-0.5B-Instruct
"""
import re
import time
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# ── Constants ─────────────────────────────────────────────────────────────────
# Hub ids of the two checkpoints being compared.
GRPO_MODEL = "oxdev/security-auditor-grpo"
BASE_MODEL = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

# System message sent to both models. It spells out the FINDING report layout
# that score_audit() later looks for, one prompt line per literal below.
SYSTEM_PROMPT = (
    "You are an expert smart contract security auditor. "
    "Analyze the provided Solidity code for vulnerabilities.\n"
    "\n"
    "For each finding, use this format:\n"
    "FINDING | severity | bug_class\n"
    "contract: <name>\n"
    "function: <name>\n"
    "bug_class: <class>\n"
    "confidence: high/medium/low\n"
    "\n"
    "### Description\n"
    "<detailed explanation>\n"
    "\n"
    "### Impact\n"
    "<what can go wrong>\n"
    "\n"
    "### Proof of Concept\n"
    "```solidity\n"
    "<exploit code>\n"
    "```\n"
    "\n"
    "### Recommendation\n"
    "<how to fix>"
)
# ── Test Cases ────────────────────────────────────────────────────────────────
# Built-in benchmark fixtures, keyed by the label shown in the dropdown.
# Each entry holds:
#   "code"              - the Solidity source handed to both models
#   "expected_vuln"     - vulnerability class key used by score_audit()
#                         ("none" marks the deliberately clean contract)
#   "expected_severity" - severity label the audit is scored against
#   "description"       - one-line summary shown in the comparison report
# The Solidity text is part of the prompt, so it must not be reformatted.
TEST_CASES = {
"🔁 Reentrancy (Classic)": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableBank {
mapping(address => uint256) public balances;
function deposit() public payable {
balances[msg.sender] += msg.value;
}
function withdraw(uint256 amount) public {
require(balances[msg.sender] >= amount, "Insufficient balance");
(bool success, ) = msg.sender.call{value: amount}("");
require(success, "Transfer failed");
balances[msg.sender] -= amount;
}
function getBalance() public view returns (uint256) {
return address(this).balance;
}
}""",
"expected_vuln": "reentrancy",
"expected_severity": "high",
"description": "Classic reentrancy: external call before state update in withdraw()"
},
"🔑 Access Control Missing": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract TokenVault {
address public owner;
mapping(address => uint256) public balances;
bool public paused;
constructor() {
owner = msg.sender;
}
function deposit() external payable {
balances[msg.sender] += msg.value;
}
function withdrawAll() external {
// Missing: only owner should call this
uint256 balance = address(this).balance;
(bool success, ) = msg.sender.call{value: balance}("");
require(success);
}
function setPaused(bool _paused) external {
// Missing: only owner should call this
paused = _paused;
}
function emergencyWithdraw(address token, uint256 amount) external {
// Missing: only owner should call this
IERC20(token).transfer(msg.sender, amount);
}
}
interface IERC20 {
function transfer(address to, uint256 amount) external returns (bool);
}""",
"expected_vuln": "access-control",
"expected_severity": "critical",
"description": "Missing access control on withdrawAll(), setPaused(), emergencyWithdraw()"
},
"📊 Oracle Manipulation": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
interface IUniswapV2Pair {
function getReserves() external view returns (uint112, uint112, uint32);
}
contract VulnerableLending {
IUniswapV2Pair public pair;
mapping(address => uint256) public collateral;
mapping(address => uint256) public debt;
constructor(address _pair) {
pair = IUniswapV2Pair(_pair);
}
function getPrice() public view returns (uint256) {
(uint112 reserve0, uint112 reserve1, ) = pair.getReserves();
return (uint256(reserve1) * 1e18) / uint256(reserve0);
}
function deposit() external payable {
collateral[msg.sender] += msg.value;
}
function borrow(uint256 amount) external {
uint256 price = getPrice();
uint256 collateralValue = collateral[msg.sender] * price / 1e18;
require(collateralValue >= amount * 15 / 10, "Undercollateralized");
debt[msg.sender] += amount;
}
function liquidate(address user) external {
uint256 price = getPrice();
uint256 collateralValue = collateral[user] * price / 1e18;
require(collateralValue < debt[user], "Not liquidatable");
collateral[user] = 0;
debt[user] = 0;
}
}""",
"expected_vuln": "oracle",
"expected_severity": "high",
"description": "Spot price from Uniswap reserves is manipulable via flash loans"
},
"⚡ Flash Loan Attack Surface": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
interface IERC20 {
function balanceOf(address) external view returns (uint256);
function transfer(address, uint256) external returns (bool);
function transferFrom(address, address, uint256) external returns (bool);
}
contract VulnerableGovernance {
IERC20 public token;
mapping(uint256 => Proposal) public proposals;
uint256 public proposalCount;
struct Proposal {
address proposer;
string description;
uint256 forVotes;
uint256 againstVotes;
uint256 endBlock;
bool executed;
}
function propose(string calldata desc) external returns (uint256) {
require(token.balanceOf(msg.sender) >= 1000e18, "Need 1000 tokens");
proposalCount++;
proposals[proposalCount] = Proposal(msg.sender, desc, 0, 0, block.number + 100, false);
return proposalCount;
}
function vote(uint256 proposalId, bool support) external {
Proposal storage p = proposals[proposalId];
require(block.number <= p.endBlock, "Voting ended");
// Bug: uses current balance, not snapshot — flash loan can inflate votes
uint256 votes = token.balanceOf(msg.sender);
if (support) p.forVotes += votes;
else p.againstVotes += votes;
}
function execute(uint256 proposalId) external {
Proposal storage p = proposals[proposalId];
require(block.number > p.endBlock && !p.executed);
require(p.forVotes > p.againstVotes, "Not passed");
p.executed = true;
// execute proposal...
}
}""",
"expected_vuln": "flash-loan",
"expected_severity": "critical",
"description": "Governance voting uses live balance instead of snapshots — flash loan can swing votes"
},
"🔄 Unchecked Return Value": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
interface IERC20 {
function transfer(address to, uint256 amount) external returns (bool);
function transferFrom(address from, address to, uint256 amount) external returns (bool);
function approve(address spender, uint256 amount) external returns (bool);
}
contract TokenDistributor {
IERC20 public token;
address public admin;
mapping(address => uint256) public allocations;
constructor(address _token) {
token = IERC20(_token);
admin = msg.sender;
}
function setAllocation(address user, uint256 amount) external {
require(msg.sender == admin);
allocations[user] = amount;
}
function claim() external {
uint256 amount = allocations[msg.sender];
require(amount > 0, "Nothing to claim");
allocations[msg.sender] = 0;
// Bug: return value not checked — some tokens return false instead of reverting
token.transfer(msg.sender, amount);
}
function batchTransfer(address[] calldata recipients, uint256[] calldata amounts) external {
require(msg.sender == admin);
for (uint256 i = 0; i < recipients.length; i++) {
// Bug: unchecked return value
token.transfer(recipients[i], amounts[i]);
}
}
}""",
"expected_vuln": "token",
"expected_severity": "medium",
"description": "Unchecked ERC20 transfer return values — silent failure with non-standard tokens"
},
"🧮 Rounding / Precision Loss": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableVault {
uint256 public totalShares;
uint256 public totalAssets;
mapping(address => uint256) public shares;
function deposit(uint256 assets) external {
uint256 sharesToMint;
if (totalShares == 0) {
sharesToMint = assets;
} else {
// Bug: division before multiplication causes rounding loss
sharesToMint = assets / totalAssets * totalShares;
}
shares[msg.sender] += sharesToMint;
totalShares += sharesToMint;
totalAssets += assets;
}
function withdraw(uint256 shareAmount) external {
require(shares[msg.sender] >= shareAmount);
// Bug: same precision issue
uint256 assetsToReturn = shareAmount / totalShares * totalAssets;
shares[msg.sender] -= shareAmount;
totalShares -= shareAmount;
totalAssets -= assetsToReturn;
// transfer assets...
}
function previewDeposit(uint256 assets) external view returns (uint256) {
if (totalShares == 0) return assets;
return assets / totalAssets * totalShares;
}
}""",
"expected_vuln": "rounding",
"expected_severity": "high",
"description": "Division before multiplication causes severe precision loss in share calculations"
},
"✅ Clean Contract (No Bugs)": {
"code": """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/security/ReentrancyGuard.sol";
import "@openzeppelin/contracts/access/Ownable.sol";
contract SecureBank is ReentrancyGuard, Ownable {
mapping(address => uint256) private _balances;
event Deposited(address indexed user, uint256 amount);
event Withdrawn(address indexed user, uint256 amount);
constructor() Ownable(msg.sender) {}
function deposit() external payable {
require(msg.value > 0, "Must deposit > 0");
_balances[msg.sender] += msg.value;
emit Deposited(msg.sender, msg.value);
}
function withdraw(uint256 amount) external nonReentrant {
require(_balances[msg.sender] >= amount, "Insufficient balance");
_balances[msg.sender] -= amount;
(bool success, ) = msg.sender.call{value: amount}("");
require(success, "Transfer failed");
emit Withdrawn(msg.sender, amount);
}
function balanceOf(address user) external view returns (uint256) {
return _balances[user];
}
}""",
"expected_vuln": "none",
"expected_severity": "none",
"description": "Well-written contract with ReentrancyGuard, Ownable, CEI pattern, events"
}
}
# ── Scoring Engine ────────────────────────────────────────────────────────────
# Relative weight of each dimension in the "Overall" score.
_SCORE_WEIGHTS = {"Structure": 0.2, "Detection": 0.35, "Severity": 0.15, "Depth": 0.2, "Code": 0.1}

# Severity labels ordered from most to least severe.
_SEVERITY_RANKS = {"critical": 5, "high": 4, "medium": 3, "low": 2, "informational": 1, "gas": 0}
_SEVERITY_WORD = r"(critical|high|medium|low|informational|gas)\b"
# Tried in order: the declared severity in the FINDING header, then an explicit
# "severity: x" label, and only then the first severity-like word anywhere.
# Without this ordering, phrases such as "low-level call" or "confidence: high"
# ahead of the header would be mistaken for the reported severity.
_SEVERITY_PATTERNS = (
    re.compile(r"finding\s*\|\s*" + _SEVERITY_WORD, re.IGNORECASE),
    re.compile(r"\bseverity\W{0,4}" + _SEVERITY_WORD, re.IGNORECASE),
    re.compile(r"\b" + _SEVERITY_WORD, re.IGNORECASE),
)

# Regex fragments that count as evidence for each expected vulnerability class.
_VULN_PATTERNS = {
    "reentrancy": ["reentrancy", "reentrant", "re-enter", "external call before state"],
    "access-control": ["access control", "unauthorized", "missing modifier", "anyone can call", "no restriction"],
    "oracle": ["oracle", "price manipulation", "spot price", "flash loan.*price", "getReserves"],
    "flash-loan": ["flash loan", "snapshot", "live balance", "current balance", "voting power"],
    "token": ["return value", "unchecked", "non-standard", "fee-on-transfer", "bool return"],
    "rounding": ["rounding", "precision", "division before multiplication", "truncat"],
}

_FALSE_ALARM_TERMS = ("critical", "high severity", "vulnerability found", "exploit")
# "secure" is anchored on word boundaries so that "insecure" or the contract
# name "SecureBank" does not count as the model declaring the code safe.
_SAFE_PATTERNS = ("no .* vulnerabilit", "well.written", r"\bsecure\b", "good practice", "no major")

_REPORT_FIELDS = ("contract:", "function:", "bug_class:", "confidence:")
_SECTION_KEYWORDS = ("description", "impact", "proof", "recommendation", "mitigation", "fix")

_TECH_TERMS = (
    "msg.sender", "tx.origin", "delegatecall", "selfdestruct",
    "call{value", "abi.encode", "keccak256", "require(",
    "mapping", "storage", "memory", "modifier", "interface",
    "assembly", "unchecked", "payable", "fallback()", "receive()",
)
_REASONING_TERMS = (
    "because", "therefore", "this means", "this allows",
    "the attacker", "leading to", "step 1", "first,",
)


def _score_structure(text, text_lower):
    """Score adherence to the FINDING report layout (header, fields, sections)."""
    score = 0.4 if re.search(r"FINDING\s*\|", text) else 0.0
    score += 0.1 * sum(1 for field in _REPORT_FIELDS if field in text_lower)
    section_hits = sum(
        1 for kw in _SECTION_KEYWORDS
        if re.search(rf"(###?\s*{kw}|{kw}\s*:)", text, re.IGNORECASE)
    )
    # At most three section headings earn credit.
    score += 0.1 * min(section_hits, 3)
    return min(1.0, score)


def _score_detection(text_lower, expected_vuln):
    """Score whether the expected bug class was identified; None if not judgeable."""
    if expected_vuln == "none":
        # Clean contract: reward recognising it as safe, penalise false alarms.
        recognizes_safe = any(re.search(p, text_lower) for p in _SAFE_PATTERNS)
        if not recognizes_safe:
            return 0.0
        has_false_alarm = any(term in text_lower for term in _FALSE_ALARM_TERMS)
        return 0.3 if has_false_alarm else 0.8
    patterns = _VULN_PATTERNS.get(expected_vuln)
    if patterns is None:
        return None
    # Case-insensitive so mixed-case patterns such as "getReserves" can match.
    hits = sum(1 for p in patterns if re.search(p, text_lower, re.IGNORECASE))
    return min(1.0, hits * 0.35)


def _score_severity(text, expected_severity):
    """Score the reported severity against the expected one; None if not judgeable."""
    if expected_severity == "none":
        return 0.5  # N/A for clean contracts
    if expected_severity not in _SEVERITY_RANKS:
        return None
    for pattern in _SEVERITY_PATTERNS:
        match = pattern.search(text)
        if match:
            predicted = match.group(1).lower()
            distance = abs(_SEVERITY_RANKS[predicted] - _SEVERITY_RANKS[expected_severity])
            if distance == 0:
                return 1.0
            return 0.5 if distance == 1 else 0.1
    return 0.0


def _score_depth(text, text_lower):
    """Score technical depth from Solidity terminology and reasoning phrases."""
    tech_count = sum(1 for term in _TECH_TERMS if term in text)
    reason_count = sum(1 for term in _REASONING_TERMS if term in text_lower)
    return min(1.0, 0.05 * tech_count + 0.1 * reason_count)


def score_audit(text, expected_vuln, expected_severity):
    """Score an audit response on multiple dimensions.

    Args:
        text: The model's audit report.
        expected_vuln: Expected vulnerability class key, "none" for a clean
            contract, or any other value (e.g. "unknown") when there is no
            ground truth.
        expected_severity: Expected severity label, "none" for a clean
            contract, or any other value when there is no ground truth.

    Returns:
        Dict with float scores in [0, 1] under the keys "Structure",
        "Detection", "Severity", "Depth", "Code" and "Overall", in that order.
        Dimensions that cannot be judged (no ground truth) are reported as 0.0
        and left out of "Overall", which is then the weighted average of the
        remaining dimensions.
    """
    text_lower = text.lower()
    parts = {
        "Structure": _score_structure(text, text_lower),
        "Detection": _score_detection(text_lower, expected_vuln),
        "Severity": _score_severity(text, expected_severity),
        "Depth": _score_depth(text, text_lower),
        "Code": 1.0 if "```" in text else 0.0,
    }
    judged = {name: value for name, value in parts.items() if value is not None}
    overall = sum(judged[name] * _SCORE_WEIGHTS[name] for name in _SCORE_WEIGHTS if name in judged)
    if len(judged) < len(parts):
        # Renormalise so a custom contract is not penalised for dimensions
        # that have no expected answer to compare against.
        overall /= sum(_SCORE_WEIGHTS[name] for name in judged)
    scores = {name: (0.0 if value is None else value) for name, value in parts.items()}
    scores["Overall"] = overall
    return scores
def format_scores(scores):
    """Format scores as a readable markdown table.

    Args:
        scores: Mapping of metric name to a numeric score, normally in [0, 1].

    Returns:
        A markdown table with one row per metric: an emoji (when one is known
        for the metric), a 10-cell bar and the score as a percentage.
    """
    emojis = {"Structure": "📋", "Detection": "🎯", "Severity": "⚠️", "Depth": "🔬", "Code": "💻", "Overall": "⭐"}
    lines = ["| Metric | Score |", "|--------|-------|"]
    for metric, value in scores.items():
        emoji = emojis.get(metric, "")
        # Clamp so the bar is always exactly 10 cells wide, even if a score
        # falls outside [0, 1].
        filled = max(0, min(10, int(value * 10)))
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"| {emoji} {metric} | {bar} {value:.0%} |")
    return "\n".join(lines)
# ── Model Loading ─────────────────────────────────────────────────────────────
def _load_cpu_pipeline(model_id, **model_kwargs):
    """Load a checkpoint in float32 and wrap it in a CPU text-generation pipeline.

    Extra keyword arguments are forwarded to ``from_pretrained`` for the model.
    Returns the ``(model, tokenizer, pipeline)`` triple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float32, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device="cpu")
    return model, tokenizer, pipe


# Both models are loaded once at import time and shared by every request.
print("🔄 Loading GRPO model...")
grpo_model, grpo_tokenizer, grpo_pipe = _load_cpu_pipeline(GRPO_MODEL, use_cache=True)
print("✅ GRPO model loaded")

print("🔄 Loading base model...")
base_model, base_tokenizer, base_pipe = _load_cpu_pipeline(BASE_MODEL)
print("✅ Base model loaded")
# ── Inference ─────────────────────────────────────────────────────────────────
def run_single_audit(pipe, code, max_tokens=512):
    """Audit one contract with one model and return the generated report text.

    Args:
        pipe: Callable text-generation pipeline that accepts a chat message list.
        code: Solidity source to audit.
        max_tokens: Upper bound on newly generated tokens.
    """
    user_prompt = f"Audit this smart contract for security vulnerabilities:\n\n```solidity\n{code}\n```"
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    # Greedy decoding (do_sample=False) keeps the two models' runs repeatable.
    generations = pipe(chat, max_new_tokens=max_tokens, do_sample=False, return_full_text=False)
    generated = generations[0]["generated_text"]
    # If the pipeline hands back a list of chat messages, the reply is the
    # content of the last one; otherwise the value is used as plain text.
    return generated[-1]["content"] if isinstance(generated, list) else str(generated)
def _timed_audit(pipe, code, max_tokens):
    """Run one audit and return ``(report_text, elapsed_seconds)``."""
    started = time.perf_counter()
    report = run_single_audit(pipe, code, max_tokens)
    return report, time.perf_counter() - started


def run_comparison(code, test_case_name, max_tokens):
    """Run both models on ``code`` and score the results.

    Args:
        code: Solidity source from the editor.
        test_case_name: Dropdown label; selects the expected vulnerability and
            severity from TEST_CASES. Unknown labels and the custom entry are
            scored against "unknown" expectations.
        max_tokens: Generation budget per model (coerced to int).

    Returns:
        A 3-tuple of markdown strings ``(grpo_report, base_report,
        score_comparison)``, one per output component wired to the button.
        For blank input the first item is a warning and the others are empty.
    """
    if not code or not code.strip():
        # Must be exactly three values: the click handler has three outputs.
        return "⚠️ Please enter Solidity code", "", ""
    max_tokens = int(max_tokens)

    # Get expected values from test case
    tc = TEST_CASES.get(test_case_name, {})
    expected_vuln = tc.get("expected_vuln", "unknown")
    expected_severity = tc.get("expected_severity", "unknown")
    tc_desc = tc.get("description", "Custom contract — scoring against general audit quality")
    # If custom code, use "unknown" — score on structure/depth only
    if test_case_name == "Custom (paste your own)":
        expected_vuln = "unknown"
        expected_severity = "unknown"

    # Run the models one after the other, timing each separately.
    grpo_result, grpo_time = _timed_audit(grpo_pipe, code, max_tokens)
    base_result, base_time = _timed_audit(base_pipe, code, max_tokens)

    # Score both against the same expectations.
    grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
    base_scores = score_audit(base_result, expected_vuln, expected_severity)

    # Format score comparison
    comparison = (
        f"### 📊 Score Comparison\n\n**Test Case:** {test_case_name}\n"
        f"**Expected:** {expected_vuln} ({expected_severity})\n"
        f"**Description:** {tc_desc}\n\n"
        "| Metric | 🎯 GRPO | 📦 Base | Delta |\n"
        "|--------|---------|---------|-------|\n"
    )
    for k in ["Structure", "Detection", "Severity", "Depth", "Code", "Overall"]:
        g = grpo_scores[k]
        b = base_scores[k]
        delta = g - b
        # Differences within ±5 points are shown as neutral.
        arrow = "🟢" if delta > 0.05 else "🔴" if delta < -0.05 else "⚪"
        comparison += f"| {k} | {g:.0%} | {b:.0%} | {arrow} {delta:+.0%} |\n"
    comparison += f"\n⏱️ GRPO: {grpo_time:.1f}s | Base: {base_time:.1f}s"

    grpo_header = f"*Generated in {grpo_time:.1f}s — Overall: {grpo_scores['Overall']:.0%}*\n\n"
    base_header = f"*Generated in {base_time:.1f}s — Overall: {base_scores['Overall']:.0%}*\n\n"
    return grpo_header + grpo_result, base_header + base_result, comparison
def run_benchmark():
    """Audit every built-in test case with both models and tabulate the outcome.

    Returns a markdown report: the two average Overall scores followed by one
    table row per test case naming the winner (a lead of 5 points or less is a tie).
    """
    rows = []
    grpo_sum = 0
    base_sum = 0
    for case_name, case in TEST_CASES.items():
        contract = case["code"]
        vuln = case["expected_vuln"]
        severity = case["expected_severity"]

        grpo_text = run_single_audit(grpo_pipe, contract, 512)
        base_text = run_single_audit(base_pipe, contract, 512)
        g_ov = score_audit(grpo_text, vuln, severity)["Overall"]
        b_ov = score_audit(base_text, vuln, severity)["Overall"]
        grpo_sum += g_ov
        base_sum += b_ov

        if g_ov > b_ov + 0.05:
            winner = "🎯 GRPO"
        elif b_ov > g_ov + 0.05:
            winner = "📦 Base"
        else:
            winner = "🤝 Tie"
        rows.append(f"| {case_name} | {g_ov:.0%} | {b_ov:.0%} | {winner} |")

    case_count = len(rows)
    report = "## 🏆 Full Benchmark Results\n\n"
    report += f"**GRPO Average: {grpo_sum/case_count:.0%}** | **Base Average: {base_sum/case_count:.0%}**\n\n"
    report += "| Test Case | GRPO | Base | Winner |\n|-----------|------|------|--------|\n"
    return report + "\n".join(rows)
def load_test_case(name):
    """Return the Solidity source for the chosen dropdown entry.

    The custom entry and any unrecognised name yield an empty editor.
    """
    if name != "Custom (paste your own)":
        return TEST_CASES.get(name, {}).get("code", "")
    return ""
# ── UI ────────────────────────────────────────────────────────────────────────
# Gradio layout: three tabs (single audit, full benchmark, about).
with gr.Blocks(
    title="🔐 Smart Contract Security Auditor",
    theme=gr.themes.Soft(),
    css="""
.score-box { padding: 10px; border-radius: 8px; }
"""
) as demo:
    # Header text agrees with the About tab: the current checkpoint was trained
    # on the 327-sample synthetic V1 set; the real audit findings are V2.
    gr.Markdown(
        "# 🔐 Smart Contract Security Auditor\n"
        "### GRPO-Trained vs Base Model — Side-by-Side Comparison\n\n"
        "Compare [`oxdev/security-auditor-grpo`](https://huggingface.co/oxdev/security-auditor-grpo) "
        "(GRPO-trained on 327 synthetic audit samples) against "
        "[`Qwen/Qwen2.5-Coder-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct) (base).\n\n"
        "⏱️ **Note:** Running on CPU — each audit takes ~30-90 seconds per model."
    )

    with gr.Tab("🔍 Single Audit"):
        with gr.Row():
            with gr.Column(scale=2):
                test_case_dropdown = gr.Dropdown(
                    choices=list(TEST_CASES.keys()) + ["Custom (paste your own)"],
                    value="🔁 Reentrancy (Classic)",
                    label="Select Test Case",
                    interactive=True,
                )
                code_input = gr.Code(
                    label="Solidity Contract",
                    language=None,
                    lines=22,
                    value=TEST_CASES["🔁 Reentrancy (Classic)"]["code"],
                    interactive=True,
                )
            with gr.Column(scale=1):
                max_tokens_slider = gr.Slider(
                    minimum=128, maximum=1024, value=512, step=64,
                    label="Max Output Tokens",
                )
                run_btn = gr.Button("🔍 Run Audit Comparison", variant="primary", size="lg")
                gr.Markdown(
                    "**How Scoring Works:**\n"
                    "- 📋 **Structure** (20%): FINDING format, sections, fields\n"
                    "- 🎯 **Detection** (35%): Identifies the correct vulnerability\n"
                    "- ⚠️ **Severity** (15%): Correct severity level\n"
                    "- 🔬 **Depth** (20%): Technical terms, reasoning\n"
                    "- 💻 **Code** (10%): Includes code examples"
                )
        with gr.Row():
            comparison_output = gr.Markdown(label="Score Comparison")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 GRPO-Trained Auditor")
                grpo_output = gr.Markdown(label="GRPO Output")
            with gr.Column():
                gr.Markdown("### 📦 Base Qwen2.5-Coder-0.5B-Instruct")
                base_output = gr.Markdown(label="Base Output")

        # Picking a test case replaces the editor contents.
        test_case_dropdown.change(
            fn=load_test_case,
            inputs=test_case_dropdown,
            outputs=code_input,
        )
        # run_comparison must return one value per component listed in outputs.
        run_btn.click(
            fn=run_comparison,
            inputs=[code_input, test_case_dropdown, max_tokens_slider],
            outputs=[grpo_output, base_output, comparison_output],
            concurrency_limit=1,
        )

    with gr.Tab("🏆 Full Benchmark"):
        # Counts are derived from TEST_CASES so the text stays correct when
        # cases are added or removed (two inferences per case: GRPO and base).
        gr.Markdown(
            f"Run all {len(TEST_CASES)} test cases and compare aggregate performance.\n\n"
            f"⏱️ **Warning:** This takes 5-10 minutes on CPU ({2 * len(TEST_CASES)} model inferences total)."
        )
        bench_btn = gr.Button("🏆 Run Full Benchmark", variant="primary", size="lg")
        bench_output = gr.Markdown(label="Benchmark Results")
        bench_btn.click(
            fn=run_benchmark,
            outputs=bench_output,
            concurrency_limit=1,
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## Model Details
### 🎯 GRPO-Trained Auditor (`oxdev/security-auditor-grpo`)
- **Architecture:** Qwen2ForCausalLM, 0.5B parameters
- **Training:** Group Relative Policy Optimization (GRPO) on 327 synthetic smart contract audit samples
- **Reward Functions:** Format compliance, finding rate
- **Training Results:** Format reward improved 16× (0.025 → 0.40), finding rate 0% → 50-75%
### 📦 Base Model (`Qwen/Qwen2.5-Coder-0.5B-Instruct`)
- **Architecture:** Same Qwen2ForCausalLM, 0.5B parameters
- **Training:** Standard instruction tuning by Qwen team
- **Domain:** General code generation, not specialized for security
### 📊 Training Data
- **V1 (used for current model):** 327 synthetic attack vector samples
- **V2 (pending training):** [50,902 real audit findings](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2) from top security firms
### 🔬 Scoring Methodology
Each audit response is scored on 5 dimensions:
1. **Structure (20%)** — Does it use the FINDING format with required fields?
2. **Detection (35%)** — Does it identify the correct vulnerability class?
3. **Severity (15%)** — Does it assign the correct severity level?
4. **Depth (20%)** — Technical terminology, reasoning chains, specificity
5. **Code (10%)** — Includes code examples (exploit PoC, fix)
### 🔗 Resources
- [Model on Hub](https://huggingface.co/oxdev/security-auditor-grpo)
- [Training Dataset V2](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2)
- [GitHub Repository](https://github.com/0xedev/skills)
""")

# One request at a time; at most five more may wait in the queue.
demo.queue(max_size=5, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()