| |
| """ |
| Benchmark Suite for Shield Agents. |
| |
| Tests the scanner against known vulnerable code samples to verify |
| detection accuracy. Uses OWASP WebGoat-inspired test cases and |
| other standard vulnerability benchmarks. |
| |
| Usage: |
| python -m benchmarks.benchmark |
| python -m benchmarks.benchmark --verbose |
| python -m benchmarks.benchmark --category injection |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| |
# Prepend the directory two levels above this file (presumably the project
# root that contains both `benchmarks/` and `shield_agents/`) to sys.path so
# the `shield_agents` imports below resolve even when this file is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))


# NOTE(review): Orchestrator is imported but not referenced anywhere in this
# module; only the two scanners are used by the benchmark runner.
from shield_agents.config import ShieldConfig
from shield_agents.orchestrator import Orchestrator
from shield_agents.scanners import SASTScanner, SecretsScanner


# Module-level logger for the benchmark suite.
# NOTE(review): nothing in this module logs through it; output is via print().
logger = logging.getLogger("shield_agents.benchmark")
|
|
|
|
@dataclass
class BenchmarkCase:
    """Describe one benchmark sample and the findings it should provoke.

    A case bundles a snippet of source code with the finding categories the
    scanners are expected to report for it. An empty ``expected_findings``
    list marks a deliberately clean sample.
    """

    # Unique identifier shown in the PASS/FAIL output.
    name: str
    # Grouping used for --category filtering and the summary breakdown.
    category: str
    # Sample source code that gets written to a temporary file and scanned.
    code: str
    # Language of ``code``; determines the temporary file's extension.
    language: str
    # Finding categories the scanners should report (compared case-insensitively).
    expected_findings: List[str]
    # Minimum severity the sample is meant to trigger.
    # NOTE(review): not read anywhere in this module.
    expected_min_severity: str = field(default="MEDIUM")
    # Human-readable summary of what the sample demonstrates.
    description: str = field(default="")
|
|
|
|
@dataclass
class BenchmarkResult:
    """Capture the outcome of scanning a single benchmark case.

    Category collections are kept as plain lists of strings so a result can
    be printed or serialized without further conversion.
    """

    # Identity of the case this result belongs to.
    name: str
    category: str
    # Whether the case met its pass criterion.
    passed: bool
    # Categories derived from the scanners' findings.
    detected_categories: List[str]
    # Categories the case asked for.
    expected_categories: List[str]
    # Expected categories that were not detected.
    missed_categories: List[str]
    # Detected categories the runner counts as unexpected.
    false_positives: List[str]
    # Number of raw findings reported across all scanners.
    findings_count: int
    # Seconds spent running the case.
    duration: float
    # Optional free-form notes.
    details: str = field(default="")
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Benchmark corpus.
#
# Each BenchmarkCase pairs a code sample with the finding categories that the
# SAST and secrets scanners are expected to report for it. Expected categories
# are lower-cased before comparison, and an empty expected_findings list marks
# a clean-code case.
# ---------------------------------------------------------------------------
BENCHMARK_CASES: List[BenchmarkCase] = [
    # --- Injection: SQL ---------------------------------------------------
    BenchmarkCase(
        name="sql_injection_string_concat",
        category="injection",
        language="python",
        code='''import sqlite3

def get_user(username):
    conn = sqlite3.connect("users.db")
    cursor = conn.cursor()
    query = "SELECT * FROM users WHERE username = '" + username + "'"
    cursor.execute(query)
    return cursor.fetchone()

def search_products(term):
    conn = sqlite3.connect("shop.db")
    cursor = conn.cursor()
    cursor.execute(f"SELECT * FROM products WHERE name LIKE '%{term}%'")
    return cursor.fetchall()
''',
        expected_findings=["injection", "sql_injection"],
        description="SQL injection via string concatenation and f-strings",
    ),
    # NOTE(review): `cursor.raw(...)` on a DB-API cursor looks like it was meant
    # to mirror Django's `Model.objects.raw(...)`; confirm the sample is realistic.
    BenchmarkCase(
        name="sql_injection_format",
        category="injection",
        language="python",
        code='''from django.db import connection

def lookup_user(user_id):
    cursor = connection.cursor()
    cursor.execute("SELECT * FROM auth_user WHERE id = %s" % user_id)
    return cursor.fetchone()

def raw_query(table):
    cursor = connection.cursor()
    cursor.raw("SELECT * FROM {}".format(table))
''',
        expected_findings=["injection", "sql_injection"],
        description="SQL injection via % formatting and .format()",
    ),

    # --- Cross-site scripting ---------------------------------------------
    BenchmarkCase(
        name="xss_dom_based",
        category="xss",
        language="javascript",
        code='''function displaySearchResults() {
    const params = new URLSearchParams(window.location.search);
    const query = params.get('q');
    document.getElementById('results').innerHTML = query;
}

function logAction(action) {
    document.write('<p>Action: ' + action + '</p>');
}
''',
        expected_findings=["xss"],
        description="DOM-based XSS via innerHTML and document.write",
    ),
    # NOTE(review): in the `profile` sample, `{bio | safe}` sits inside a Python
    # f-string, so it is a bitwise OR with an undefined name `safe`, not the
    # Jinja `|safe` filter the description refers to. Confirm the sample still
    # exercises the intended pattern.
    BenchmarkCase(
        name="xss_template_injection",
        category="xss",
        language="python",
        code='''from flask import Flask, request, render_template_string

app = Flask(__name__)

@app.route('/greet')
def greet():
    name = request.args.get('name', 'World')
    template = f"<h1>Hello {name}!</h1>"
    return render_template_string(template)

@app.route('/profile')
def profile():
    bio = request.args.get('bio', '')
    return f"<div>{bio | safe}</div>"
''',
        expected_findings=["xss", "injection"],
        description="Server-side XSS via template injection and |safe filter",
    ),

    # --- Injection: OS commands -------------------------------------------
    BenchmarkCase(
        name="command_injection_os_system",
        category="injection",
        language="python",
        code='''import os
import subprocess

def ping_host(host):
    os.system(f"ping -c 4 {host}")

def run_tool(input_file):
    subprocess.call(f"convert {input_file} output.png", shell=True)

def execute_command(cmd):
    subprocess.Popen(cmd, shell=True)
''',
        expected_findings=["injection", "command_injection"],
        description="OS command injection via os.system and subprocess with shell=True",
    ),

    # --- Path traversal ---------------------------------------------------
    BenchmarkCase(
        name="path_traversal_open",
        category="path-traversal",
        language="python",
        code='''from flask import Flask, request, send_file

app = Flask(__name__)

@app.route('/download')
def download_file():
    filename = request.args.get('file')
    with open('/var/files/' + filename, 'r') as f:
        return f.read()

@app.route('/view')
def view_file():
    path = request.args.get('path')
    return send_file(path)
''',
        expected_findings=["path-traversal"],
        description="Path traversal via user-controlled file paths",
    ),

    # --- Insecure deserialization -----------------------------------------
    BenchmarkCase(
        name="insecure_deserialization_pickle",
        category="deserialization",
        language="python",
        code='''import pickle
import yaml

def load_session(data):
    return pickle.loads(data)

def load_config(content):
    return yaml.load(content) # Missing Loader parameter

def load_cache(raw):
    import marshal
    return marshal.loads(raw)
''',
        expected_findings=["deserialization"],
        description="Insecure deserialization via pickle, yaml, and marshal",
    ),

    # --- Weak cryptography ------------------------------------------------
    BenchmarkCase(
        name="weak_cryptography",
        category="cryptography",
        language="python",
        code='''import hashlib
import random

def hash_password(password):
    return hashlib.md5(password.encode()).hexdigest()

def generate_token():
    return str(random.randint(100000, 999999))

def hash_data(data):
    return hashlib.sha1(data.encode()).hexdigest()
''',
        expected_findings=["cryptography"],
        description="Weak cryptography: MD5, SHA1, and random module for security",
    ),

    # --- Hardcoded secrets ------------------------------------------------
    # The secret values are PLACEHOLDER strings rather than realistic key
    # formats. NOTE(review): value-pattern based secret rules may not match
    # them; confirm which of the expected categories the secrets scanner can
    # actually produce for this sample.
    BenchmarkCase(
        name="hardcoded_secrets",
        category="credentials",
        language="python",
        code='''# Database configuration
DB_HOST = "db.prodserver.com"
DB_PASSWORD = "SuperSecret123!"
API_KEY = "PLACEHOLDER_STRIPE_KEY_FOR_TESTING_ONLY"

import requests

def call_api():
    headers = {"Authorization": "Bearer PLACEHOLDER_JWT_TOKEN_FOR_TESTING_ONLY"}
    return requests.get("https://api.example.com/data", headers=headers)

AWS_ACCESS_KEY = "PLACEHOLDER_AWS_KEY_FOR_TESTING_ONLY"
AWS_SECRET_KEY = "PLACEHOLDER_AWS_SECRET_FOR_TESTING_ONLY"
''',
        expected_findings=["credentials", "cloud-credentials", "auth-tokens", "generic-secrets"],
        description="Hardcoded passwords, API keys, JWTs, and AWS credentials",
    ),

    # --- Security misconfiguration ----------------------------------------
    BenchmarkCase(
        name="ssl_verification_disabled",
        category="security-misconfiguration",
        language="python",
        code='''import requests
import ssl

def fetch_data(url):
    return requests.get(url, verify=False)

def create_context():
    ctx = ssl._create_unverified_context()
    return ctx

def disable_cert_check():
    ssl._create_default_https_context = ssl._create_unverified_context
''',
        expected_findings=["security-misconfiguration"],
        description="SSL/TLS certificate verification disabled",
    ),

    # --- Server-side request forgery --------------------------------------
    BenchmarkCase(
        name="ssrf_requests",
        category="ssrf",
        language="python",
        code='''import requests
from flask import Flask, request

app = Flask(__name__)

@app.route('/fetch')
def fetch_url():
    url = request.args.get('url')
    response = requests.get(url)
    return response.text

@app.route('/proxy')
def proxy_request():
    target = request.args.get('target')
    return requests.post(target, data=request.form).text
''',
        expected_findings=["ssrf"],
        description="SSRF via user-controlled URL in requests",
    ),

    # --- Authentication ---------------------------------------------------
    BenchmarkCase(
        name="auth_bypass",
        category="authentication",
        language="python",
        code='''from flask import Flask, session

app = Flask(__name__)

@app.route('/admin')
def admin_panel():
    assert session.get('is_admin')
    return "Admin panel"

@app.route('/login', methods=['POST'])
def login():
    session['authenticated'] = True
    return "Logged in"
''',
        expected_findings=["authentication"],
        description="Auth bypass via assertion and session manipulation",
    ),

    # --- Clean baseline ---------------------------------------------------
    # Empty expected_findings: this case is judged on the severity of whatever
    # the scanners report rather than on matching categories.
    # NOTE(review): the sample's hash_password uses a single SHA-256 pass and
    # does not return the random salt it generates, so its hashes could not be
    # verified later. It is fine as scan input, but it is not a model of secure
    # password hashing.
    BenchmarkCase(
        name="clean_code_minimal_findings",
        category="clean",
        language="python",
        code='''import os
import hashlib
import secrets
from typing import Optional

def get_database_url() -> str:
    """Get database URL from environment variable."""
    return os.environ.get("DATABASE_URL", "sqlite:///default.db")

def hash_password(password: str, salt: Optional[bytes] = None) -> str:
    """Hash password using SHA-256 with salt."""
    if salt is None:
        salt = secrets.token_bytes(32)
    return hashlib.sha256(salt + password.encode()).hexdigest()

def generate_session_token() -> str:
    """Generate a cryptographically secure session token."""
    return secrets.token_urlsafe(32)

def validate_input(data: str, max_length: int = 1000) -> str:
    """Validate and sanitize user input."""
    if len(data) > max_length:
        raise ValueError(f"Input exceeds maximum length of {max_length}")
    return data.strip()
''',
        expected_findings=[],
        description="Clean, secure code that should produce minimal/no findings",
    ),
]
|
|
|
|
class BenchmarkRunner:
    """Run benchmark test cases against Shield Agents.

    Each case's sample code is written to a temporary file and scanned with
    ``SASTScanner`` and ``SecretsScanner``. The categories of the resulting
    findings are then compared with the categories the case expects. Results
    accumulate in ``self.results`` across calls to :meth:`run_all`.
    """

    # Extension used for the temporary sample file, per benchmark language.
    # The scanners only receive a file path, so the extension is presumably
    # what tells them which language rules to apply (TODO confirm). Languages
    # not listed fall back to using the language name itself as the extension.
    _LANGUAGE_SUFFIXES = {
        "python": "py",
        "javascript": "js",
        "typescript": "ts",
        "ruby": "rb",
    }

    # Detected categories that are never reported as false positives.
    _IGNORED_FALSE_POSITIVES = frozenset({"generic-secrets", "credentials"})

    # Severities that make a clean-code case (no expected findings) fail.
    _BLOCKING_SEVERITIES = ("CRITICAL", "HIGH")

    def __init__(self, config: Optional[ShieldConfig] = None):
        """Create a runner.

        Args:
            config: Scanner configuration. A default ``ShieldConfig()`` is used
                when omitted (or when a falsy value is passed).
        """
        self.config = config or ShieldConfig()
        self.results: List[BenchmarkResult] = []

    async def run_all(self, category: Optional[str] = None, verbose: bool = False) -> List[BenchmarkResult]:
        """Run all benchmark cases, printing one PASS/FAIL line per case.

        Args:
            category: Only run cases whose ``category`` equals this value.
            verbose: For failing cases, also print the expected, detected,
                missed and false-positive categories.

        Returns:
            ``self.results``, which also contains results from earlier calls.
        """
        cases = BENCHMARK_CASES
        if category:
            cases = [c for c in cases if c.category == category]

        print(f"\n{'='*60}")
        print("Shield Agents Benchmark Suite")
        print(f"Running {len(cases)} test cases...")
        print(f"{'='*60}\n")

        for case in cases:
            result = await self.run_case(case, verbose)
            self.results.append(result)
            status = "PASS" if result.passed else "FAIL"
            print(f" [{status}] {case.name} ({case.category})")
            if not result.passed and verbose:
                print(f" Expected: {result.expected_categories}")
                print(f" Detected: {result.detected_categories}")
                print(f" Missed: {result.missed_categories}")
                if result.false_positives:
                    print(f" False +: {result.false_positives}")

        return self.results

    async def run_case(self, case: BenchmarkCase, verbose: bool = False) -> BenchmarkResult:
        """Run a single benchmark case.

        A case with expected categories passes when at least one of them is
        detected. A case with no expected categories passes when no CRITICAL
        or HIGH finding is reported.

        Args:
            case: Benchmark case to run.
            verbose: Accepted for symmetry with :meth:`run_all`; not used here.

        Returns:
            Benchmark result.
        """
        import tempfile

        start_time = time.perf_counter()

        suffix = "." + self._LANGUAGE_SUFFIXES.get(case.language, case.language)
        # Create the file, close our handle so the scanners can reopen it by
        # path on any platform, and guarantee removal even if writing or
        # scanning fails.
        fd, temp_path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        try:
            Path(temp_path).write_text(case.code, encoding="utf-8")

            sast = SASTScanner(self.config)
            secrets_scanner = SecretsScanner(self.config)
            sast_findings = sast.scan_file(temp_path)
            secrets_findings = secrets_scanner.scan_file(temp_path)
        finally:
            os.unlink(temp_path)

        all_findings = sast_findings + secrets_findings

        detected = self._detected_categories(all_findings)
        expected = {c.lower() for c in case.expected_findings}
        missed = expected - detected
        false_positives = detected - expected - self._IGNORED_FALSE_POSITIVES

        if expected:
            # Expected lists often name the same issue several ways (e.g.
            # "injection" and "sql_injection"), so any single hit passes.
            passed = len(missed) < len(expected)
        else:
            # Clean-code case: only serious findings count against it.
            passed = not any(
                (f.get("severity") or "").upper() in self._BLOCKING_SEVERITIES
                for f in all_findings
            )

        duration = time.perf_counter() - start_time

        return BenchmarkResult(
            name=case.name,
            category=case.category,
            passed=passed,
            detected_categories=sorted(detected),
            expected_categories=sorted(expected),
            missed_categories=sorted(missed),
            false_positives=sorted(false_positives),
            findings_count=len(all_findings),
            duration=duration,
        )

    @staticmethod
    def _detected_categories(findings: List[Dict[str, Any]]) -> set:
        """Collect lower-cased category names (plus title-based aliases) from findings.

        Findings are assumed to be dict-like (read via ``.get``). A missing,
        ``None`` or empty category is skipped rather than recorded as ``""``.
        """
        detected = set()
        for finding in findings:
            category = (finding.get("category") or "").lower()
            if category:
                detected.add(category)

            # Map keywords in the title onto the alias names used by the
            # benchmark cases' expected_findings.
            title = (finding.get("title") or "").lower()
            if "sql" in title:
                detected.update(("sql_injection", "injection"))
            if "xss" in title:
                detected.add("xss")
            if "command" in title or "os system" in title:
                detected.update(("command_injection", "injection"))
        return detected

    def print_summary(self):
        """Print totals, the failed cases and a per-category breakdown.

        Returns:
            ``(passed, total)`` counts over ``self.results``.
        """
        total = len(self.results)
        passed = sum(1 for r in self.results if r.passed)
        failed = total - passed

        print(f"\n{'='*60}")
        print("Benchmark Summary")
        print(f"{'='*60}")
        print(f"Total: {total}")
        print(f"Passed: {passed}")
        print(f"Failed: {failed}")
        print(f"Rate: {(passed/total*100):.1f}%" if total > 0 else "Rate: N/A")

        if failed > 0:
            print("\nFailed cases:")
            for r in self.results:
                if not r.passed:
                    print(f" - {r.name} (missed: {r.missed_categories})")

        categories: Dict[str, Dict[str, int]] = {}
        for r in self.results:
            stats = categories.setdefault(r.category, {"passed": 0, "failed": 0})
            stats["passed" if r.passed else "failed"] += 1

        print("\nCategory Breakdown:")
        for cat, stats in sorted(categories.items()):
            total_cat = stats["passed"] + stats["failed"]
            print(f" {cat}: {stats['passed']}/{total_cat} passed")

        return passed, total
|
|
|
|
async def main():
    """Parse command-line options, run the benchmark suite and report results.

    Options:
        --verbose / -v: print category details for failing cases.
        --category / -c: run only cases whose ``category`` matches exactly.
        --json: emit a JSON array of per-case results on stdout. In this mode
            the progress lines and summary are written to stderr so that
            stdout contains only the JSON document.
    """
    import argparse
    import contextlib

    parser = argparse.ArgumentParser(description="Shield Agents Benchmark Suite")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--category", "-c", help="Run only specific category")
    parser.add_argument("--json", action="store_true", help="Output results as JSON")
    args = parser.parse_args()

    runner = BenchmarkRunner()

    # Without this, the JSON would follow the human-readable report on the
    # same stream and could not be parsed by downstream tools.
    report_stream = (
        contextlib.redirect_stdout(sys.stderr) if args.json else contextlib.nullcontext()
    )
    with report_stream:
        await runner.run_all(category=args.category, verbose=args.verbose)
        runner.print_summary()

    if args.json:
        results_json = [
            {
                "name": r.name,
                "category": r.category,
                "passed": r.passed,
                "detected": r.detected_categories,
                "expected": r.expected_categories,
                "missed": r.missed_categories,
                "false_positives": r.false_positives,
                "findings_count": r.findings_count,
                "duration": r.duration,
            }
            for r in runner.results
        ]
        print(json.dumps(results_json, indent=2))


if __name__ == "__main__":
    asyncio.run(main())
|
|