oxdev commited on
Commit
1c81810
·
verified ·
1 Parent(s): 6266ba6

Add main app with side-by-side comparison, test cases, and auto-scoring

Browse files
Files changed (1) hide show
  1. app.py +660 -0
app.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Smart Contract Security Auditor — GRPO vs Base Comparison
4
+ Side-by-side evaluation of oxdev/security-auditor-grpo vs Qwen2.5-Coder-0.5B-Instruct
5
+ """
6
+
7
+ import re
8
+ import time
9
+ import gradio as gr
10
+ import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
12
+
13
+ # ── Constants ─────────────────────────────────────────────────────────────────
14
+ GRPO_MODEL = "oxdev/security-auditor-grpo"
15
+ BASE_MODEL = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
16
+
17
+ SYSTEM_PROMPT = (
18
+ "You are an expert smart contract security auditor. Analyze the provided Solidity code "
19
+ "for vulnerabilities.\n\nFor each finding, use this format:\n"
20
+ "FINDING | severity | bug_class\n"
21
+ "contract: <name>\nfunction: <name>\nbug_class: <class>\nconfidence: high/medium/low\n\n"
22
+ "### Description\n<detailed explanation>\n\n### Impact\n<what can go wrong>\n\n"
23
+ "### Proof of Concept\n```solidity\n<exploit code>\n```\n\n"
24
+ "### Recommendation\n<how to fix>"
25
+ )
26
+
27
+ # ── Test Cases ────────────────────────────────────────────────────────────────
28
+ TEST_CASES = {
29
+ "🔁 Reentrancy (Classic)": {
30
+ "code": """// SPDX-License-Identifier: MIT
31
+ pragma solidity ^0.8.0;
32
+
33
+ contract VulnerableBank {
34
+ mapping(address => uint256) public balances;
35
+
36
+ function deposit() public payable {
37
+ balances[msg.sender] += msg.value;
38
+ }
39
+
40
+ function withdraw(uint256 amount) public {
41
+ require(balances[msg.sender] >= amount, "Insufficient balance");
42
+ (bool success, ) = msg.sender.call{value: amount}("");
43
+ require(success, "Transfer failed");
44
+ balances[msg.sender] -= amount;
45
+ }
46
+
47
+ function getBalance() public view returns (uint256) {
48
+ return address(this).balance;
49
+ }
50
+ }""",
51
+ "expected_vuln": "reentrancy",
52
+ "expected_severity": "high",
53
+ "description": "Classic reentrancy: external call before state update in withdraw()"
54
+ },
55
+
56
+ "🔑 Access Control Missing": {
57
+ "code": """// SPDX-License-Identifier: MIT
58
+ pragma solidity ^0.8.0;
59
+
60
+ contract TokenVault {
61
+ address public owner;
62
+ mapping(address => uint256) public balances;
63
+ bool public paused;
64
+
65
+ constructor() {
66
+ owner = msg.sender;
67
+ }
68
+
69
+ function deposit() external payable {
70
+ balances[msg.sender] += msg.value;
71
+ }
72
+
73
+ function withdrawAll() external {
74
+ // Missing: only owner should call this
75
+ uint256 balance = address(this).balance;
76
+ (bool success, ) = msg.sender.call{value: balance}("");
77
+ require(success);
78
+ }
79
+
80
+ function setPaused(bool _paused) external {
81
+ // Missing: only owner should call this
82
+ paused = _paused;
83
+ }
84
+
85
+ function emergencyWithdraw(address token, uint256 amount) external {
86
+ // Missing: only owner should call this
87
+ IERC20(token).transfer(msg.sender, amount);
88
+ }
89
+ }
90
+
91
+ interface IERC20 {
92
+ function transfer(address to, uint256 amount) external returns (bool);
93
+ }""",
94
+ "expected_vuln": "access-control",
95
+ "expected_severity": "critical",
96
+ "description": "Missing access control on withdrawAll(), setPaused(), emergencyWithdraw()"
97
+ },
98
+
99
+ "📊 Oracle Manipulation": {
100
+ "code": """// SPDX-License-Identifier: MIT
101
+ pragma solidity ^0.8.0;
102
+
103
+ interface IUniswapV2Pair {
104
+ function getReserves() external view returns (uint112, uint112, uint32);
105
+ }
106
+
107
+ contract VulnerableLending {
108
+ IUniswapV2Pair public pair;
109
+ mapping(address => uint256) public collateral;
110
+ mapping(address => uint256) public debt;
111
+
112
+ constructor(address _pair) {
113
+ pair = IUniswapV2Pair(_pair);
114
+ }
115
+
116
+ function getPrice() public view returns (uint256) {
117
+ (uint112 reserve0, uint112 reserve1, ) = pair.getReserves();
118
+ return (uint256(reserve1) * 1e18) / uint256(reserve0);
119
+ }
120
+
121
+ function deposit() external payable {
122
+ collateral[msg.sender] += msg.value;
123
+ }
124
+
125
+ function borrow(uint256 amount) external {
126
+ uint256 price = getPrice();
127
+ uint256 collateralValue = collateral[msg.sender] * price / 1e18;
128
+ require(collateralValue >= amount * 15 / 10, "Undercollateralized");
129
+ debt[msg.sender] += amount;
130
+ }
131
+
132
+ function liquidate(address user) external {
133
+ uint256 price = getPrice();
134
+ uint256 collateralValue = collateral[user] * price / 1e18;
135
+ require(collateralValue < debt[user], "Not liquidatable");
136
+ collateral[user] = 0;
137
+ debt[user] = 0;
138
+ }
139
+ }""",
140
+ "expected_vuln": "oracle",
141
+ "expected_severity": "high",
142
+ "description": "Spot price from Uniswap reserves is manipulable via flash loans"
143
+ },
144
+
145
+ "⚡ Flash Loan Attack Surface": {
146
+ "code": """// SPDX-License-Identifier: MIT
147
+ pragma solidity ^0.8.0;
148
+
149
+ interface IERC20 {
150
+ function balanceOf(address) external view returns (uint256);
151
+ function transfer(address, uint256) external returns (bool);
152
+ function transferFrom(address, address, uint256) external returns (bool);
153
+ }
154
+
155
+ contract VulnerableGovernance {
156
+ IERC20 public token;
157
+ mapping(uint256 => Proposal) public proposals;
158
+ uint256 public proposalCount;
159
+
160
+ struct Proposal {
161
+ address proposer;
162
+ string description;
163
+ uint256 forVotes;
164
+ uint256 againstVotes;
165
+ uint256 endBlock;
166
+ bool executed;
167
+ }
168
+
169
+ function propose(string calldata desc) external returns (uint256) {
170
+ require(token.balanceOf(msg.sender) >= 1000e18, "Need 1000 tokens");
171
+ proposalCount++;
172
+ proposals[proposalCount] = Proposal(msg.sender, desc, 0, 0, block.number + 100, false);
173
+ return proposalCount;
174
+ }
175
+
176
+ function vote(uint256 proposalId, bool support) external {
177
+ Proposal storage p = proposals[proposalId];
178
+ require(block.number <= p.endBlock, "Voting ended");
179
+ // Bug: uses current balance, not snapshot — flash loan can inflate votes
180
+ uint256 votes = token.balanceOf(msg.sender);
181
+ if (support) p.forVotes += votes;
182
+ else p.againstVotes += votes;
183
+ }
184
+
185
+ function execute(uint256 proposalId) external {
186
+ Proposal storage p = proposals[proposalId];
187
+ require(block.number > p.endBlock && !p.executed);
188
+ require(p.forVotes > p.againstVotes, "Not passed");
189
+ p.executed = true;
190
+ // execute proposal...
191
+ }
192
+ }""",
193
+ "expected_vuln": "flash-loan",
194
+ "expected_severity": "critical",
195
+ "description": "Governance voting uses live balance instead of snapshots — flash loan can swing votes"
196
+ },
197
+
198
+ "🔄 Unchecked Return Value": {
199
+ "code": """// SPDX-License-Identifier: MIT
200
+ pragma solidity ^0.8.0;
201
+
202
+ interface IERC20 {
203
+ function transfer(address to, uint256 amount) external returns (bool);
204
+ function transferFrom(address from, address to, uint256 amount) external returns (bool);
205
+ function approve(address spender, uint256 amount) external returns (bool);
206
+ }
207
+
208
+ contract TokenDistributor {
209
+ IERC20 public token;
210
+ address public admin;
211
+ mapping(address => uint256) public allocations;
212
+
213
+ constructor(address _token) {
214
+ token = IERC20(_token);
215
+ admin = msg.sender;
216
+ }
217
+
218
+ function setAllocation(address user, uint256 amount) external {
219
+ require(msg.sender == admin);
220
+ allocations[user] = amount;
221
+ }
222
+
223
+ function claim() external {
224
+ uint256 amount = allocations[msg.sender];
225
+ require(amount > 0, "Nothing to claim");
226
+ allocations[msg.sender] = 0;
227
+ // Bug: return value not checked — some tokens return false instead of reverting
228
+ token.transfer(msg.sender, amount);
229
+ }
230
+
231
+ function batchTransfer(address[] calldata recipients, uint256[] calldata amounts) external {
232
+ require(msg.sender == admin);
233
+ for (uint256 i = 0; i < recipients.length; i++) {
234
+ // Bug: unchecked return value
235
+ token.transfer(recipients[i], amounts[i]);
236
+ }
237
+ }
238
+ }""",
239
+ "expected_vuln": "token",
240
+ "expected_severity": "medium",
241
+ "description": "Unchecked ERC20 transfer return values — silent failure with non-standard tokens"
242
+ },
243
+
244
+ "🧮 Rounding / Precision Loss": {
245
+ "code": """// SPDX-License-Identifier: MIT
246
+ pragma solidity ^0.8.0;
247
+
248
+ contract VulnerableVault {
249
+ uint256 public totalShares;
250
+ uint256 public totalAssets;
251
+ mapping(address => uint256) public shares;
252
+
253
+ function deposit(uint256 assets) external {
254
+ uint256 sharesToMint;
255
+ if (totalShares == 0) {
256
+ sharesToMint = assets;
257
+ } else {
258
+ // Bug: division before multiplication causes rounding loss
259
+ sharesToMint = assets / totalAssets * totalShares;
260
+ }
261
+ shares[msg.sender] += sharesToMint;
262
+ totalShares += sharesToMint;
263
+ totalAssets += assets;
264
+ }
265
+
266
+ function withdraw(uint256 shareAmount) external {
267
+ require(shares[msg.sender] >= shareAmount);
268
+ // Bug: same precision issue
269
+ uint256 assetsToReturn = shareAmount / totalShares * totalAssets;
270
+ shares[msg.sender] -= shareAmount;
271
+ totalShares -= shareAmount;
272
+ totalAssets -= assetsToReturn;
273
+ // transfer assets...
274
+ }
275
+
276
+ function previewDeposit(uint256 assets) external view returns (uint256) {
277
+ if (totalShares == 0) return assets;
278
+ return assets / totalAssets * totalShares;
279
+ }
280
+ }""",
281
+ "expected_vuln": "rounding",
282
+ "expected_severity": "high",
283
+ "description": "Division before multiplication causes severe precision loss in share calculations"
284
+ },
285
+
286
+ "✅ Clean Contract (No Bugs)": {
287
+ "code": """// SPDX-License-Identifier: MIT
288
+ pragma solidity ^0.8.0;
289
+
290
+ import "@openzeppelin/contracts/security/ReentrancyGuard.sol";
291
+ import "@openzeppelin/contracts/access/Ownable.sol";
292
+
293
+ contract SecureBank is ReentrancyGuard, Ownable {
294
+ mapping(address => uint256) private _balances;
295
+
296
+ event Deposited(address indexed user, uint256 amount);
297
+ event Withdrawn(address indexed user, uint256 amount);
298
+
299
+ constructor() Ownable(msg.sender) {}
300
+
301
+ function deposit() external payable {
302
+ require(msg.value > 0, "Must deposit > 0");
303
+ _balances[msg.sender] += msg.value;
304
+ emit Deposited(msg.sender, msg.value);
305
+ }
306
+
307
+ function withdraw(uint256 amount) external nonReentrant {
308
+ require(_balances[msg.sender] >= amount, "Insufficient balance");
309
+ _balances[msg.sender] -= amount;
310
+ (bool success, ) = msg.sender.call{value: amount}("");
311
+ require(success, "Transfer failed");
312
+ emit Withdrawn(msg.sender, amount);
313
+ }
314
+
315
+ function balanceOf(address user) external view returns (uint256) {
316
+ return _balances[user];
317
+ }
318
+ }""",
319
+ "expected_vuln": "none",
320
+ "expected_severity": "none",
321
+ "description": "Well-written contract with ReentrancyGuard, Ownable, CEI pattern, events"
322
+ }
323
+ }
324
+
325
+
326
+ # ── Scoring Engine ────────────────────────────────────────────────────────────
327
+ def score_audit(text, expected_vuln, expected_severity):
328
+ """Score an audit response on multiple dimensions. Returns dict of scores."""
329
+ text_lower = text.lower()
330
+ scores = {}
331
+
332
+ # 1. Structure score (0-1): Does it follow the FINDING format?
333
+ struct_score = 0.0
334
+ if re.search(r'FINDING\s*\|', text):
335
+ struct_score += 0.4
336
+ fields = ['contract:', 'function:', 'bug_class:', 'confidence:']
337
+ field_hits = sum(1 for f in fields if f in text_lower)
338
+ struct_score += 0.1 * field_hits
339
+ section_kws = ['description', 'impact', 'proof', 'recommendation', 'mitigation', 'fix']
340
+ sect_hits = sum(1 for k in section_kws if re.search(rf'(?i)(###?\s*{k}|{k}\s*:)', text))
341
+ struct_score += 0.1 * min(sect_hits, 3)
342
+ scores["Structure"] = min(1.0, struct_score)
343
+
344
+ # 2. Vulnerability detection (0-1): Did it find the right bug?
345
+ vuln_keywords = {
346
+ "reentrancy": ["reentrancy", "reentrant", "re-enter", "external call before state"],
347
+ "access-control": ["access control", "unauthorized", "missing modifier", "anyone can call", "no restriction"],
348
+ "oracle": ["oracle", "price manipulation", "spot price", "flash loan.*price", "getReserves"],
349
+ "flash-loan": ["flash loan", "snapshot", "live balance", "current balance", "voting power"],
350
+ "token": ["return value", "unchecked", "non-standard", "fee-on-transfer", "bool return"],
351
+ "rounding": ["rounding", "precision", "division before multiplication", "truncat"],
352
+ "none": [],
353
+ }
354
+ if expected_vuln == "none":
355
+ # For clean contracts, reward saying "no major issues" / penalize false positives
356
+ false_alarm_terms = ["critical", "high severity", "vulnerability found", "exploit"]
357
+ has_false_alarm = any(t in text_lower for t in false_alarm_terms)
358
+ safe_terms = ["no .* vulnerabilit", "well.written", "secure", "good practice", "no major"]
359
+ recognizes_safe = any(re.search(t, text_lower) for t in safe_terms)
360
+ scores["Detection"] = 0.8 if recognizes_safe and not has_false_alarm else 0.3 if recognizes_safe else 0.0
361
+ else:
362
+ kws = vuln_keywords.get(expected_vuln, [])
363
+ hits = sum(1 for kw in kws if re.search(kw, text_lower))
364
+ scores["Detection"] = min(1.0, hits * 0.35)
365
+
366
+ # 3. Severity accuracy (0-1)
367
+ if expected_severity == "none":
368
+ scores["Severity"] = 0.5 # N/A for clean contracts
369
+ else:
370
+ sev_match = re.search(r'(?i)\b(critical|high|medium|low|informational|gas)\b', text_lower)
371
+ if sev_match:
372
+ pred = sev_match.group(1).lower()
373
+ ranks = {"critical": 5, "high": 4, "medium": 3, "low": 2, "informational": 1, "gas": 0}
374
+ diff = abs(ranks.get(pred, 0) - ranks.get(expected_severity, 0))
375
+ scores["Severity"] = 1.0 if diff == 0 else 0.5 if diff == 1 else 0.1
376
+ else:
377
+ scores["Severity"] = 0.0
378
+
379
+ # 4. Technical depth (0-1)
380
+ tech_terms = [
381
+ 'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct',
382
+ 'call{value', 'abi.encode', 'keccak256', 'require(',
383
+ 'mapping', 'storage', 'memory', 'modifier', 'interface',
384
+ 'assembly', 'unchecked', 'payable', 'fallback()', 'receive()',
385
+ ]
386
+ tech_count = sum(1 for t in tech_terms if t in text)
387
+ reasoning_terms = ['because', 'therefore', 'this means', 'this allows',
388
+ 'the attacker', 'leading to', 'step 1', 'first,']
389
+ reason_count = sum(1 for r in reasoning_terms if r.lower() in text_lower)
390
+ scores["Depth"] = min(1.0, 0.05 * tech_count + 0.1 * reason_count)
391
+
392
+ # 5. Code presence (0-1)
393
+ has_code = 1.0 if '```' in text else 0.0
394
+ scores["Code"] = has_code
395
+
396
+ # Overall weighted score
397
+ weights = {"Structure": 0.2, "Detection": 0.35, "Severity": 0.15, "Depth": 0.2, "Code": 0.1}
398
+ scores["Overall"] = sum(scores[k] * weights[k] for k in weights)
399
+
400
+ return scores
401
+
402
+
403
+ def format_scores(scores):
404
+ """Format scores as a readable markdown table."""
405
+ lines = ["| Metric | Score |", "|--------|-------|"]
406
+ emojis = {"Structure": "📋", "Detection": "🎯", "Severity": "⚠️", "Depth": "🔬", "Code": "💻", "Overall": "⭐"}
407
+ for k, v in scores.items():
408
+ emoji = emojis.get(k, "")
409
+ bar = "█" * int(v * 10) + "░" * (10 - int(v * 10))
410
+ lines.append(f"| {emoji} {k} | {bar} {v:.0%} |")
411
+ return "\n".join(lines)
412
+
413
+
414
+ # ── Model Loading ─────────────────────────────────────────────────────────────
415
+ print("🔄 Loading GRPO model...")
416
+ grpo_model = AutoModelForCausalLM.from_pretrained(
417
+ GRPO_MODEL, use_cache=True, torch_dtype=torch.float32,
418
+ )
419
+ grpo_tokenizer = AutoTokenizer.from_pretrained(GRPO_MODEL)
420
+ grpo_pipe = pipeline("text-generation", model=grpo_model, tokenizer=grpo_tokenizer, device="cpu")
421
+ print("✅ GRPO model loaded")
422
+
423
+ print("🔄 Loading base model...")
424
+ base_model = AutoModelForCausalLM.from_pretrained(
425
+ BASE_MODEL, torch_dtype=torch.float32,
426
+ )
427
+ base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
428
+ base_pipe = pipeline("text-generation", model=base_model, tokenizer=base_tokenizer, device="cpu")
429
+ print("✅ Base model loaded")
430
+
431
+
432
+ # ── Inference ─────────────────────────────────────────────────────────────────
433
+ def run_single_audit(pipe, code, max_tokens=512):
434
+ """Run audit with one model."""
435
+ messages = [
436
+ {"role": "system", "content": SYSTEM_PROMPT},
437
+ {"role": "user", "content": f"Audit this smart contract for security vulnerabilities:\n\n```solidity\n{code}\n```"},
438
+ ]
439
+ result = pipe(messages, max_new_tokens=max_tokens, do_sample=False, return_full_text=False)
440
+ output = result[0]["generated_text"]
441
+ if isinstance(output, list):
442
+ return output[-1]["content"]
443
+ return str(output)
444
+
445
+
446
+ def run_comparison(code, test_case_name, max_tokens):
447
+ """Run both models and score results."""
448
+ if not code or not code.strip():
449
+ return "⚠️ Please enter Solidity code", "", "", "", ""
450
+
451
+ max_tokens = int(max_tokens)
452
+
453
+ # Get expected values from test case
454
+ tc = TEST_CASES.get(test_case_name, {})
455
+ expected_vuln = tc.get("expected_vuln", "unknown")
456
+ expected_severity = tc.get("expected_severity", "unknown")
457
+ tc_desc = tc.get("description", "Custom contract — scoring against general audit quality")
458
+
459
+ # If custom code, use "unknown" — score on structure/depth only
460
+ if test_case_name == "Custom (paste your own)":
461
+ expected_vuln = "unknown"
462
+ expected_severity = "unknown"
463
+
464
+ # Run GRPO model
465
+ t0 = time.time()
466
+ grpo_result = run_single_audit(grpo_pipe, code, max_tokens)
467
+ grpo_time = time.time() - t0
468
+
469
+ # Run base model
470
+ t0 = time.time()
471
+ base_result = run_single_audit(base_pipe, code, max_tokens)
472
+ base_time = time.time() - t0
473
+
474
+ # Score both
475
+ grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
476
+ base_scores = score_audit(base_result, expected_vuln, expected_severity)
477
+
478
+ # Format score comparison
479
+ comparison = f"### 📊 Score Comparison\n\n**Test Case:** {test_case_name}\n"
480
+ comparison += f"**Expected:** {expected_vuln} ({expected_severity})\n"
481
+ comparison += f"**Description:** {tc_desc}\n\n"
482
+ comparison += f"| Metric | 🎯 GRPO | 📦 Base | Delta |\n"
483
+ comparison += f"|--------|---------|---------|-------|\n"
484
+ for k in ["Structure", "Detection", "Severity", "Depth", "Code", "Overall"]:
485
+ g = grpo_scores[k]
486
+ b = base_scores[k]
487
+ delta = g - b
488
+ arrow = "🟢" if delta > 0.05 else "🔴" if delta < -0.05 else "⚪"
489
+ comparison += f"| {k} | {g:.0%} | {b:.0%} | {arrow} {delta:+.0%} |\n"
490
+ comparison += f"\n⏱️ GRPO: {grpo_time:.1f}s | Base: {base_time:.1f}s"
491
+
492
+ grpo_header = f"*Generated in {grpo_time:.1f}s — Overall: {grpo_scores['Overall']:.0%}*\n\n"
493
+ base_header = f"*Generated in {base_time:.1f}s — Overall: {base_scores['Overall']:.0%}*\n\n"
494
+
495
+ return grpo_header + grpo_result, base_header + base_result, comparison
496
+
497
+
498
+ def run_benchmark():
499
+ """Run all test cases and return aggregate scores."""
500
+ results = []
501
+ grpo_total = 0
502
+ base_total = 0
503
+ n = 0
504
+
505
+ for name, tc in TEST_CASES.items():
506
+ code = tc["code"]
507
+ expected_vuln = tc["expected_vuln"]
508
+ expected_severity = tc["expected_severity"]
509
+
510
+ grpo_result = run_single_audit(grpo_pipe, code, 512)
511
+ base_result = run_single_audit(base_pipe, code, 512)
512
+
513
+ grpo_scores = score_audit(grpo_result, expected_vuln, expected_severity)
514
+ base_scores = score_audit(base_result, expected_vuln, expected_severity)
515
+
516
+ grpo_total += grpo_scores["Overall"]
517
+ base_total += base_scores["Overall"]
518
+ n += 1
519
+
520
+ g_ov = grpo_scores["Overall"]
521
+ b_ov = base_scores["Overall"]
522
+ winner = "🎯 GRPO" if g_ov > b_ov + 0.05 else "📦 Base" if b_ov > g_ov + 0.05 else "🤝 Tie"
523
+ results.append(f"| {name} | {g_ov:.0%} | {b_ov:.0%} | {winner} |")
524
+
525
+ header = "## 🏆 Full Benchmark Results\n\n"
526
+ header += f"**GRPO Average: {grpo_total/n:.0%}** | **Base Average: {base_total/n:.0%}**\n\n"
527
+ header += "| Test Case | GRPO | Base | Winner |\n|-----------|------|------|--------|\n"
528
+ return header + "\n".join(results)
529
+
530
+
531
+ def load_test_case(name):
532
+ """Load a test case into the code editor."""
533
+ if name == "Custom (paste your own)":
534
+ return ""
535
+ tc = TEST_CASES.get(name, {})
536
+ return tc.get("code", "")
537
+
538
+
539
+ # ── UI ────────────────────────────────────────────────────────────────────────
540
+ with gr.Blocks(
541
+ title="🔐 Smart Contract Security Auditor",
542
+ theme=gr.themes.Soft(),
543
+ css="""
544
+ .score-box { padding: 10px; border-radius: 8px; }
545
+ """
546
+ ) as demo:
547
+ gr.Markdown(
548
+ "# 🔐 Smart Contract Security Auditor\n"
549
+ "### GRPO-Trained vs Base Model — Side-by-Side Comparison\n\n"
550
+ "Compare [`oxdev/security-auditor-grpo`](https://huggingface.co/oxdev/security-auditor-grpo) "
551
+ "(GRPO-trained on 327 real audit findings) against "
552
+ "[`Qwen/Qwen2.5-Coder-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct) (base).\n\n"
553
+ "⏱️ **Note:** Running on CPU — each audit takes ~30-90 seconds per model."
554
+ )
555
+
556
+ with gr.Tab("🔍 Single Audit"):
557
+ with gr.Row():
558
+ with gr.Column(scale=2):
559
+ test_case_dropdown = gr.Dropdown(
560
+ choices=list(TEST_CASES.keys()) + ["Custom (paste your own)"],
561
+ value="🔁 Reentrancy (Classic)",
562
+ label="Select Test Case",
563
+ interactive=True,
564
+ )
565
+ code_input = gr.Code(
566
+ label="Solidity Contract",
567
+ language=None,
568
+ lines=22,
569
+ value=TEST_CASES["🔁 Reentrancy (Classic)"]["code"],
570
+ interactive=True,
571
+ )
572
+ with gr.Column(scale=1):
573
+ max_tokens_slider = gr.Slider(
574
+ minimum=128, maximum=1024, value=512, step=64,
575
+ label="Max Output Tokens",
576
+ )
577
+ run_btn = gr.Button("🔍 Run Audit Comparison", variant="primary", size="lg")
578
+ gr.Markdown(
579
+ "**How Scoring Works:**\n"
580
+ "- 📋 **Structure** (20%): FINDING format, sections, fields\n"
581
+ "- 🎯 **Detection** (35%): Identifies the correct vulnerability\n"
582
+ "- ⚠️ **Severity** (15%): Correct severity level\n"
583
+ "- 🔬 **Depth** (20%): Technical terms, reasoning\n"
584
+ "- 💻 **Code** (10%): Includes code examples"
585
+ )
586
+
587
+ with gr.Row():
588
+ comparison_output = gr.Markdown(label="Score Comparison")
589
+
590
+ with gr.Row():
591
+ with gr.Column():
592
+ gr.Markdown("### 🎯 GRPO-Trained Auditor")
593
+ grpo_output = gr.Markdown(label="GRPO Output")
594
+ with gr.Column():
595
+ gr.Markdown("### 📦 Base Qwen2.5-Coder-0.5B-Instruct")
596
+ base_output = gr.Markdown(label="Base Output")
597
+
598
+ test_case_dropdown.change(
599
+ fn=load_test_case,
600
+ inputs=test_case_dropdown,
601
+ outputs=code_input,
602
+ )
603
+
604
+ run_btn.click(
605
+ fn=run_comparison,
606
+ inputs=[code_input, test_case_dropdown, max_tokens_slider],
607
+ outputs=[grpo_output, base_output, comparison_output],
608
+ concurrency_limit=1,
609
+ )
610
+
611
+ with gr.Tab("🏆 Full Benchmark"):
612
+ gr.Markdown(
613
+ "Run all 7 test cases and compare aggregate performance.\n\n"
614
+ "⏱️ **Warning:** This takes 5-10 minutes on CPU (14 model inferences total)."
615
+ )
616
+ bench_btn = gr.Button("🏆 Run Full Benchmark", variant="primary", size="lg")
617
+ bench_output = gr.Markdown(label="Benchmark Results")
618
+ bench_btn.click(
619
+ fn=run_benchmark,
620
+ outputs=bench_output,
621
+ concurrency_limit=1,
622
+ )
623
+
624
+ with gr.Tab("ℹ️ About"):
625
+ gr.Markdown("""
626
+ ## Model Details
627
+
628
+ ### 🎯 GRPO-Trained Auditor (`oxdev/security-auditor-grpo`)
629
+ - **Architecture:** Qwen2ForCausalLM, 0.5B parameters
630
+ - **Training:** Group Relative Policy Optimization (GRPO) on 327 synthetic smart contract audit samples
631
+ - **Reward Functions:** Format compliance, finding rate
632
+ - **Training Results:** Format reward improved 16× (0.025 → 0.40), finding rate 0% → 50-75%
633
+
634
+ ### 📦 Base Model (`Qwen/Qwen2.5-Coder-0.5B-Instruct`)
635
+ - **Architecture:** Same Qwen2ForCausalLM, 0.5B parameters
636
+ - **Training:** Standard instruction tuning by Qwen team
637
+ - **Domain:** General code generation, not specialized for security
638
+
639
+ ### 📊 Training Data
640
+ - **V1 (used for current model):** 327 synthetic attack vector samples
641
+ - **V2 (pending training):** [50,902 real audit findings](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2) from top security firms
642
+
643
+ ### 🔬 Scoring Methodology
644
+ Each audit response is scored on 5 dimensions:
645
+ 1. **Structure (20%)** — Does it use the FINDING format with required fields?
646
+ 2. **Detection (35%)** — Does it identify the correct vulnerability class?
647
+ 3. **Severity (15%)** — Does it assign the correct severity level?
648
+ 4. **Depth (20%)** — Technical terminology, reasoning chains, specificity
649
+ 5. **Code (10%)** — Includes code examples (exploit PoC, fix)
650
+
651
+ ### 🔗 Resources
652
+ - [Model on Hub](https://huggingface.co/oxdev/security-auditor-grpo)
653
+ - [Training Dataset V2](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2)
654
+ - [GitHub Repository](https://github.com/0xedev/skills)
655
+ """)
656
+
657
+ demo.queue(max_size=5, default_concurrency_limit=1)
658
+
659
+ if __name__ == "__main__":
660
+ demo.launch()